Files
codewhale/crates/tui/tests/cache_guard.rs
T

345 lines
9.7 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//! Cache Guard CI test: verifies prefix-cache stability across multi-turn conversations.
//!
//! Runs 8 test cases × 14-24 turns each, checking that the tail average
//! hit rate stays above a configurable threshold (default 40%).
//!
//! Environment variables:
//!   CODEWHALE_CACHE_GUARD=1              Enable the guard; `true` also works (default: disabled)
//!   CODEWHALE_CACHE_GUARD_THRESHOLD=90   Tail hit-rate threshold in percent, 0-100 (default: 40)
//!   CODEWHALE_CACHE_GUARD_STRICT=1       Fail on threshold violation (default: warn only)
//!
//! Usage:
//! CODEWHALE_CACHE_GUARD=1 cargo test --test cache_guard
//! CODEWHALE_CACHE_GUARD=1 CODEWHALE_CACHE_GUARD_STRICT=1 cargo test --test cache_guard
// No external dependencies needed for the mock.
// === Configuration ===
/// Threshold, in percent, that `threshold()` falls back to when the override
/// variable is unset or its value does not parse as a number.
const DEFAULT_THRESHOLD: f64 = 40.0;
/// Environment variable read by `guard_enabled()`; the guard tests only run
/// their checks when it is set to `1` or `true`.
const ENABLED_ENV: &str = "CODEWHALE_CACHE_GUARD";
/// Environment variable read by `threshold()` to override `DEFAULT_THRESHOLD`.
const THRESHOLD_ENV: &str = "CODEWHALE_CACHE_GUARD_THRESHOLD";
/// Environment variable read by `strict()`; when set to `1` or `true`, a
/// violation panics instead of only printing a warning.
const STRICT_ENV: &str = "CODEWHALE_CACHE_GUARD_STRICT";
/// Reports whether the cache guard is switched on.
///
/// Only the exact values `1` and `true` in `ENABLED_ENV` enable it; an unset
/// (or unreadable) variable and every other value leave the guard disabled.
fn guard_enabled() -> bool {
    match std::env::var(ENABLED_ENV) {
        Ok(value) => matches!(value.as_str(), "1" | "true"),
        Err(_) => false,
    }
}
/// Returns the tail-average hit-rate threshold, in percent.
///
/// The value comes from `THRESHOLD_ENV` when that variable holds a usable
/// number, otherwise from `DEFAULT_THRESHOLD`. Surrounding whitespace is
/// ignored, and finite values outside the documented 0-100 range are clamped
/// into it.
fn threshold() -> f64 {
    std::env::var(THRESHOLD_ENV)
        .ok()
        .and_then(|raw| raw.trim().parse::<f64>().ok())
        // `f64::from_str` accepts "NaN" and "inf". A NaN threshold would make
        // every `tail_avg >= threshold` comparison false, so treat non-finite
        // input as "not configured" rather than failing every case.
        .filter(|value| value.is_finite())
        .map(|value| value.clamp(0.0, 100.0))
        .unwrap_or(DEFAULT_THRESHOLD)
}
/// Reports whether strict mode is on, i.e. whether violations should panic.
///
/// Strict mode requires `STRICT_ENV` to be exactly `1` or `true`; anything
/// else, including an unset variable, means warn-only.
fn strict() -> bool {
    std::env::var(STRICT_ENV).is_ok_and(|flag| flag == "1" || flag == "true")
}
// === Mock Prefix Cache ===
/// Mock of DeepSeek's server-side prefix cache.
///
/// Matching is done on raw bytes: the number of leading bytes the current
/// request shares with the immediately preceding request is counted as the
/// cache hit for that turn.
struct MockPrefixCache {
    /// Body of the most recently submitted request.
    previous_body: Vec<u8>,
    /// Sum of the lengths of all submitted bodies.
    total_input_bytes: u64,
    /// Sum of the shared-prefix lengths of all submitted bodies.
    hit_bytes: u64,
    /// Hit rate (0.0..=1.0) of each turn, in submission order.
    per_turn_hit_rates: Vec<f64>,
}

impl MockPrefixCache {
    /// Creates an empty cache with no previous request and zeroed counters.
    fn new() -> Self {
        Self {
            previous_body: Vec::new(),
            total_input_bytes: 0,
            hit_bytes: 0,
            per_turn_hit_rates: Vec::new(),
        }
    }

    /// Records one request: measures the prefix it shares with the previous
    /// body, updates the counters, and remembers `body` for the next turn.
    ///
    /// An empty body is recorded with a hit rate of 1.0.
    fn submit(&mut self, body: &[u8]) {
        // Index of the first differing byte, or the shorter length when one
        // body is a prefix of the other.
        let shared = body
            .iter()
            .zip(&self.previous_body)
            .position(|(current, previous)| current != previous)
            .unwrap_or_else(|| body.len().min(self.previous_body.len()));

        self.total_input_bytes += body.len() as u64;
        self.hit_bytes += shared as u64;

        let rate = match body.len() {
            0 => 1.0,
            len => shared as f64 / len as f64,
        };
        self.per_turn_hit_rates.push(rate);

        self.previous_body.clear();
        self.previous_body.extend_from_slice(body);
    }

    /// Mean per-turn hit rate over the last `n` recorded turns (or over all
    /// of them when fewer than `n` exist); 0.0 when that window is empty.
    fn tail_avg(&self, n: usize) -> f64 {
        let rates = &self.per_turn_hit_rates;
        let window = &rates[rates.len().saturating_sub(n)..];
        if window.is_empty() {
            return 0.0;
        }
        window.iter().sum::<f64>() / window.len() as f64
    }

    /// Byte-weighted hit rate across every submitted body; 0.0 when no bytes
    /// have been submitted.
    fn overall_hit_rate(&self) -> f64 {
        match self.total_input_bytes {
            0 => 0.0,
            total => self.hit_bytes as f64 / total as f64,
        }
    }
}
// === Test Case Generators ===
/// Builds the simulated request body for one plain-dialogue turn.
///
/// The body is the system prompt, immediately followed by a fixed reasoning
/// marker when `with_reasoning` is set, then a history header, the user
/// message for `turn`, and the trailing assistant cue.
fn plain_dialogue_body(turn: usize, with_reasoning: bool) -> Vec<u8> {
    let mut text = String::from("You are a helpful assistant. Answer concisely and accurately.");
    if with_reasoning {
        text.push_str("[reasoning: analyzing the user's question carefully...]");
    }
    text.push_str("\n\nConversation history:\n");
    text.push_str(&format!("User message turn {turn} — please respond to this query."));
    text.push_str("\nAssistant:");
    text.into_bytes()
}
/// Builds the simulated request body for one tool-loop turn.
///
/// Even turns call `read_file`, odd turns call `write_file`; the tool argument
/// and the user request both embed `turn`. When `with_reasoning` is set, a
/// fixed reasoning marker directly follows the system prompt.
fn tool_loop_body(turn: usize, with_reasoning: bool) -> Vec<u8> {
    let tool_name = match turn % 2 {
        0 => "read_file",
        _ => "write_file",
    };
    let tool_args = format!(r#"{{"path": "/tmp/file_{turn}.txt"}}"#);
    let user_msg = format!("User request turn {turn}");

    let mut text = String::from("You are a helpful assistant with tool access.");
    if with_reasoning {
        text.push_str("[reasoning: deciding which tool to use...]");
    }
    text.push_str("\n\nTools: read_file, write_file, exec_shell\n");
    text.push_str(&format!("User: {user_msg}\n"));
    text.push_str(&format!("Assistant: I'll use {tool_name}({tool_args})\n"));
    text.push_str("Result: success\nAssistant:");
    text.into_bytes()
}
/// Builds a simulated request body whose user message size cycles with
/// `turn % 4`: short, medium, long (a repeated filler sentence), then a
/// minimal question.
fn mixed_size_body(turn: usize) -> Vec<u8> {
    let slot = turn % 4;
    let question = if slot == 0 {
        format!("Short question {turn}")
    } else if slot == 1 {
        format!(
            "Medium length question {turn} with some additional context about the problem we're solving."
        )
    } else if slot == 2 {
        let filler = "Lorem ipsum dolor sit amet. ".repeat(20);
        format!("Long question {turn} with extensive context: {filler}")
    } else {
        format!("Question {turn}")
    };

    let mut text = String::from("You are a helpful assistant.");
    text.push_str("\n\nUser: ");
    text.push_str(&question);
    text.push_str("\nAssistant:");
    text.into_bytes()
}
// === Test Runner ===
/// Outcome of one simulated conversation, as produced by `run_case` and
/// consumed by `report_and_assert`.
struct CaseResult {
// Case label used in the report lines.
name: String,
// Average per-turn hit rate over the last 5 turns, in percent.
tail_avg: f64,
// Byte-weighted hit rate over the whole conversation, in percent.
overall: f64,
// Number of turns that were simulated.
turns: usize,
// Whether `tail_avg` met the threshold in effect when the case ran.
passed: bool,
}
/// Simulates a `turns`-turn conversation against a fresh `MockPrefixCache`
/// and summarises the resulting hit rates.
///
/// Body selection per turn: `mixed_sizes` takes precedence (and ignores
/// `with_reasoning`), then `tool_loop`, otherwise a plain dialogue body.
/// The result's percentages are scaled to 0-100, and `passed` compares the
/// 5-turn tail average against the current `threshold()`.
fn run_case(
    name: &str,
    turns: usize,
    with_reasoning: bool,
    tool_loop: bool,
    mixed_sizes: bool,
) -> CaseResult {
    let mut cache = MockPrefixCache::new();
    (0..turns)
        .map(|turn| match (mixed_sizes, tool_loop) {
            (true, _) => mixed_size_body(turn),
            (false, true) => tool_loop_body(turn, with_reasoning),
            (false, false) => plain_dialogue_body(turn, with_reasoning),
        })
        .for_each(|body| cache.submit(&body));

    let tail_avg = cache.tail_avg(5) * 100.0;
    let overall = cache.overall_hit_rate() * 100.0;
    CaseResult {
        name: name.to_owned(),
        tail_avg,
        overall,
        turns,
        passed: tail_avg >= threshold(),
    }
}
// === 8 Test Cases ===
/// Guarded case: 14-turn plain dialogue with the reasoning marker.
#[test]
fn case_plain_dialogue() {
    if guard_enabled() {
        report_and_assert(&run_case("plain-dialogue", 14, true, false, false));
    }
}
/// Guarded case: 14-turn plain dialogue without the reasoning marker.
#[test]
fn case_plain_dialogue_no_reasoning() {
    if guard_enabled() {
        report_and_assert(&run_case(
            "plain-dialogue-no-reasoning",
            14,
            false,
            false,
            false,
        ));
    }
}
/// Guarded case: 18-turn plain dialogue with the reasoning marker.
#[test]
fn case_long_dialogue() {
    if guard_enabled() {
        report_and_assert(&run_case("long-dialogue", 18, true, false, false));
    }
}
/// Guarded case: 20 turns of bodies with cycling message sizes.
#[test]
fn case_mixed_message_sizes() {
    if guard_enabled() {
        report_and_assert(&run_case("mixed-message-sizes", 20, true, false, true));
    }
}
/// Guarded case: 14-turn tool loop with the reasoning marker.
#[test]
fn case_tool_loop() {
    if guard_enabled() {
        report_and_assert(&run_case("tool-loop", 14, true, true, false));
    }
}
/// Guarded case: 14-turn tool loop without the reasoning marker.
#[test]
fn case_tool_loop_no_reasoning() {
    if guard_enabled() {
        report_and_assert(&run_case("tool-loop-no-reasoning", 14, false, true, false));
    }
}
/// Guarded case: 24-turn tool loop with the reasoning marker.
#[test]
fn case_long_tool_loop() {
    if guard_enabled() {
        report_and_assert(&run_case("long-tool-loop", 24, true, true, false));
    }
}
/// Guarded case: 24-turn tool loop without the reasoning marker.
#[test]
fn case_long_tool_loop_no_reasoning() {
    if guard_enabled() {
        report_and_assert(&run_case(
            "long-tool-loop-no-reasoning",
            24,
            false,
            true,
            false,
        ));
    }
}
// === Hard Error Guard ===
/// Checks that a simulated compaction (the system prompt changing part-way
/// through the conversation) shows up as a cache miss in the mock.
///
/// Skipped unless the guard is enabled. In strict mode a missing miss fails
/// the test; otherwise it only prints a warning.
#[test]
fn compaction_must_cause_at_least_one_miss() {
    /// First turn that uses the shortened, post-compaction system prompt.
    const COMPACTION_TURN: usize = 20;
    /// Total number of simulated turns.
    const TOTAL_TURNS: usize = 30;
    /// A post-compaction turn counts as a significant miss when its hit rate
    /// is below this fraction. It is relaxed because the mock does not
    /// perfectly simulate DeepSeek's radix-tree prefix cache: with these
    /// fixtures the compaction turn shares only the leading
    /// "You are a helpful assistant" with its predecessor.
    const MISS_CEILING: f64 = 0.8;

    if !guard_enabled() {
        return;
    }

    let mut cache = MockPrefixCache::new();
    let system = "You are a helpful assistant with a very long system prompt that gets compacted.";
    for turn in 0..TOTAL_TURNS {
        let body = if turn < COMPACTION_TURN {
            format!("{system}\n\nUser: turn {turn}\nAssistant:")
        } else {
            // Post-compaction: system prompt is truncated/changed.
            format!("You are a helpful assistant.\n\nUser: turn {turn}\nAssistant:")
        };
        cache.submit(body.as_bytes());
    }

    // Inspect the recorded rates in place; no copy is needed just to scan them.
    let has_significant_miss = cache.per_turn_hit_rates[COMPACTION_TURN..]
        .iter()
        .any(|&rate| rate < MISS_CEILING);

    if strict() {
        // The message is derived from the same constant as the check so the
        // reported percentage cannot drift from the actual cut-off.
        assert!(
            has_significant_miss,
            "Compaction should cause at least one cache miss below {:.0}%",
            MISS_CEILING * 100.0
        );
    } else if !has_significant_miss {
        eprintln!("[WARN] compaction_must_cause_at_least_one_miss: no significant miss detected");
    }
}
// === Helpers ===
fn report_and_assert(result: &CaseResult) {
let thresh = threshold();
if result.passed {
eprintln!(
"[OK] {}: tail_avg={:.1}% (overall={:.1}%, {} turns)",
result.name, result.tail_avg, result.overall, result.turns
);
} else {
eprintln!(
"[WARN] {}: tail_avg={:.1}% < threshold={:.1}% (overall={:.1}%, {} turns)",
result.name, result.tail_avg, thresh, result.overall, result.turns
);
if strict() {
panic!(
"[STRICT] {} failed: tail_avg={:.1}% < threshold={:.1}%",
result.name, result.tail_avg, thresh
);
}
}
}