fix(rlm): preserve prompt cache usage (#1127)

* fix(rlm): preserve prompt cache usage

* refactor(rlm): share prompt cache usage helper
This commit is contained in:
Sun
2026-05-08 14:51:24 +08:00
committed by GitHub
parent fa32e7ac53
commit 2904d817fa
3 changed files with 105 additions and 19 deletions
+51 -7
View File
@@ -145,8 +145,7 @@ impl RlmBridge {
{
let mut u = self.usage.lock().await;
u.input_tokens = u.input_tokens.saturating_add(response.usage.input_tokens);
u.output_tokens = u.output_tokens.saturating_add(response.usage.output_tokens);
super::add_usage_with_prompt_cache(&mut u, &response.usage);
}
SingleResp { text, error: None }
@@ -209,8 +208,7 @@ impl RlmBridge {
{
let mut u = self.usage.lock().await;
u.input_tokens = u.input_tokens.saturating_add(result.usage.input_tokens);
u.output_tokens = u.output_tokens.saturating_add(result.usage.output_tokens);
super::add_usage_with_prompt_cache(&mut u, &result.usage);
}
SingleResp {
@@ -284,7 +282,7 @@ mod tests {
use super::*;
use crate::llm_client::mock::MockLlmClient;
fn mock_response(text: &str, input_tokens: u32, output_tokens: u32) -> MessageResponse {
fn mock_response_with_usage(text: &str, usage: Usage) -> MessageResponse {
MessageResponse {
id: "mock_msg".to_string(),
r#type: "message".to_string(),
@@ -297,12 +295,19 @@ mod tests {
stop_reason: Some("end_turn".to_string()),
stop_sequence: None,
container: None,
usage: Usage {
usage,
}
}
fn mock_response(text: &str, input_tokens: u32, output_tokens: u32) -> MessageResponse {
mock_response_with_usage(
text,
Usage {
input_tokens,
output_tokens,
..Usage::default()
},
}
)
}
fn bridge_for(mock: Arc<MockLlmClient>, depth_remaining: u32) -> RlmBridge {
@@ -371,6 +376,45 @@ mod tests {
assert_eq!(usage.output_tokens, 11);
}
/// A single `Llm` dispatch whose mocked reply reports prompt-cache hit/miss
/// counts must leave those counts, alongside the plain input/output token
/// counts, in the bridge's accumulated usage.
#[tokio::test]
async fn llm_dispatch_preserves_prompt_cache_usage() {
    // Usage attached to the mocked child reply: both cache counters are set.
    let cached_usage = Usage {
        input_tokens: 1000,
        output_tokens: 100,
        prompt_cache_hit_tokens: Some(800),
        prompt_cache_miss_tokens: Some(200),
        ..Usage::default()
    };

    let mock = Arc::new(MockLlmClient::new(Vec::new()));
    mock.push_message_response(mock_response_with_usage(
        "cached child answer",
        cached_usage,
    ));

    let bridge = bridge_for(Arc::clone(&mock), 1);
    let request = RpcRequest::Llm {
        prompt: "child prompt".to_string(),
        model: None,
        max_tokens: None,
        system: None,
    };
    let reply = bridge.dispatch(request).await;

    // The dispatch must come back as a successful single response.
    match reply {
        RpcResponse::Single(single) => {
            assert_eq!(single.text, "cached child answer");
            assert!(single.error.is_none());
        }
        other => panic!("expected single response, got {other:?}"),
    }

    // The bridge's running totals must mirror the mocked usage exactly.
    let totals = bridge.usage.lock().await;
    assert_eq!(totals.input_tokens, 1000);
    assert_eq!(totals.output_tokens, 100);
    assert_eq!(totals.prompt_cache_hit_tokens, Some(800));
    assert_eq!(totals.prompt_cache_miss_tokens, Some(200));
}
#[tokio::test]
async fn llm_batch_dispatch_pins_configured_child_model() {
let mock = Arc::new(MockLlmClient::new(Vec::new()));
+52
View File
@@ -22,6 +22,8 @@
//! - Code rounds and sub-LLM calls travel over a single stdin/stdout
//! pipe to a long-lived `python3 -u` subprocess. No HTTP sidecar.
use crate::models::Usage;
pub mod bridge;
pub mod prompt;
pub mod turn;
@@ -29,3 +31,53 @@ pub mod turn;
pub use bridge::RlmBridge;
pub use prompt::rlm_system_prompt;
pub use turn::{RlmTermination, RlmTurnResult, run_rlm_turn, run_rlm_turn_with_root};
/// Folds `delta` into `total`.
///
/// The input and output token counts are added with saturation. The
/// prompt-cache hit and miss counters are optional and are merged through
/// `add_optional_usage`. No other field of `total` is written.
fn add_usage_with_prompt_cache(total: &mut Usage, delta: &Usage) {
    // Work out the merged cache counters first, then write everything back.
    let cache_hits = add_optional_usage(
        total.prompt_cache_hit_tokens,
        delta.prompt_cache_hit_tokens,
    );
    let cache_misses = add_optional_usage(
        total.prompt_cache_miss_tokens,
        delta.prompt_cache_miss_tokens,
    );

    total.prompt_cache_hit_tokens = cache_hits;
    total.prompt_cache_miss_tokens = cache_misses;
    total.output_tokens = total.output_tokens.saturating_add(delta.output_tokens);
    total.input_tokens = total.input_tokens.saturating_add(delta.input_tokens);
}
/// Merges two optional token counters.
///
/// Returns `None` only when both sides are absent. When exactly one side is
/// present it is passed through unchanged; when both are present they are
/// added, saturating at `u32::MAX`.
fn add_optional_usage(total: Option<u32>, delta: Option<u32>) -> Option<u32> {
    match delta {
        // Nothing to add: whatever was accumulated so far stands.
        None => total,
        // Start from the increment itself when there is no running total yet.
        Some(increment) => {
            Some(total.map_or(increment, |running| running.saturating_add(increment)))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// When both the running total and the increment carry prompt-cache
    /// counters, every counter (plain and cache) ends up as the sum.
    #[test]
    fn add_usage_with_prompt_cache_preserves_cache_counts() {
        let increment = Usage {
            input_tokens: 50,
            output_tokens: 5,
            prompt_cache_hit_tokens: Some(30),
            prompt_cache_miss_tokens: Some(20),
            ..Usage::default()
        };
        let mut running = Usage {
            input_tokens: 100,
            output_tokens: 10,
            prompt_cache_hit_tokens: Some(80),
            prompt_cache_miss_tokens: Some(20),
            ..Usage::default()
        };

        add_usage_with_prompt_cache(&mut running, &increment);

        // 100 + 50, 10 + 5, 80 + 30, 20 + 20.
        assert_eq!(running.input_tokens, 150);
        assert_eq!(running.output_tokens, 15);
        assert_eq!(running.prompt_cache_hit_tokens, Some(110));
        assert_eq!(running.prompt_cache_miss_tokens, Some(40));
    }
}
+2 -12
View File
@@ -284,12 +284,7 @@ async fn run_rlm_turn_impl(
}
};
total_usage.input_tokens = total_usage
.input_tokens
.saturating_add(response.usage.input_tokens);
total_usage.output_tokens = total_usage
.output_tokens
.saturating_add(response.usage.output_tokens);
super::add_usage_with_prompt_cache(&mut total_usage, &response.usage);
let response_text = extract_text_blocks(&response.content);
last_response_text = response_text.clone();
@@ -510,12 +505,7 @@ async fn run_rlm_turn_impl(
// Fold bridge usage (children + nested sub_rlm) into totals.
let bridge_usage = usage_handle.lock().await;
let mut final_usage = result.usage.clone();
final_usage.input_tokens = final_usage
.input_tokens
.saturating_add(bridge_usage.input_tokens);
final_usage.output_tokens = final_usage
.output_tokens
.saturating_add(bridge_usage.output_tokens);
super::add_usage_with_prompt_cache(&mut final_usage, &bridge_usage);
drop(bridge_usage);
repl.shutdown().await;