fix(rlm): preserve prompt cache usage (#1127)
* fix(rlm): preserve prompt cache usage
* refactor(rlm): share prompt cache usage helper
This commit is contained in:
@@ -145,8 +145,7 @@ impl RlmBridge {
|
||||
|
||||
{
|
||||
let mut u = self.usage.lock().await;
|
||||
u.input_tokens = u.input_tokens.saturating_add(response.usage.input_tokens);
|
||||
u.output_tokens = u.output_tokens.saturating_add(response.usage.output_tokens);
|
||||
super::add_usage_with_prompt_cache(&mut u, &response.usage);
|
||||
}
|
||||
|
||||
SingleResp { text, error: None }
|
||||
@@ -209,8 +208,7 @@ impl RlmBridge {
|
||||
|
||||
{
|
||||
let mut u = self.usage.lock().await;
|
||||
u.input_tokens = u.input_tokens.saturating_add(result.usage.input_tokens);
|
||||
u.output_tokens = u.output_tokens.saturating_add(result.usage.output_tokens);
|
||||
super::add_usage_with_prompt_cache(&mut u, &result.usage);
|
||||
}
|
||||
|
||||
SingleResp {
|
||||
@@ -284,7 +282,7 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::llm_client::mock::MockLlmClient;
|
||||
|
||||
fn mock_response(text: &str, input_tokens: u32, output_tokens: u32) -> MessageResponse {
|
||||
fn mock_response_with_usage(text: &str, usage: Usage) -> MessageResponse {
|
||||
MessageResponse {
|
||||
id: "mock_msg".to_string(),
|
||||
r#type: "message".to_string(),
|
||||
@@ -297,12 +295,19 @@ mod tests {
|
||||
stop_reason: Some("end_turn".to_string()),
|
||||
stop_sequence: None,
|
||||
container: None,
|
||||
usage: Usage {
|
||||
usage,
|
||||
}
|
||||
}
|
||||
|
||||
fn mock_response(text: &str, input_tokens: u32, output_tokens: u32) -> MessageResponse {
|
||||
mock_response_with_usage(
|
||||
text,
|
||||
Usage {
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
..Usage::default()
|
||||
},
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
fn bridge_for(mock: Arc<MockLlmClient>, depth_remaining: u32) -> RlmBridge {
|
||||
@@ -371,6 +376,45 @@ mod tests {
|
||||
assert_eq!(usage.output_tokens, 11);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn llm_dispatch_preserves_prompt_cache_usage() {
|
||||
let mock = Arc::new(MockLlmClient::new(Vec::new()));
|
||||
mock.push_message_response(mock_response_with_usage(
|
||||
"cached child answer",
|
||||
Usage {
|
||||
input_tokens: 1000,
|
||||
output_tokens: 100,
|
||||
prompt_cache_hit_tokens: Some(800),
|
||||
prompt_cache_miss_tokens: Some(200),
|
||||
..Usage::default()
|
||||
},
|
||||
));
|
||||
let bridge = bridge_for(Arc::clone(&mock), 1);
|
||||
|
||||
let response = bridge
|
||||
.dispatch(RpcRequest::Llm {
|
||||
prompt: "child prompt".to_string(),
|
||||
model: None,
|
||||
max_tokens: None,
|
||||
system: None,
|
||||
})
|
||||
.await;
|
||||
|
||||
match response {
|
||||
RpcResponse::Single(single) => {
|
||||
assert_eq!(single.text, "cached child answer");
|
||||
assert!(single.error.is_none());
|
||||
}
|
||||
other => panic!("expected single response, got {other:?}"),
|
||||
}
|
||||
|
||||
let usage = bridge.usage.lock().await;
|
||||
assert_eq!(usage.input_tokens, 1000);
|
||||
assert_eq!(usage.output_tokens, 100);
|
||||
assert_eq!(usage.prompt_cache_hit_tokens, Some(800));
|
||||
assert_eq!(usage.prompt_cache_miss_tokens, Some(200));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn llm_batch_dispatch_pins_configured_child_model() {
|
||||
let mock = Arc::new(MockLlmClient::new(Vec::new()));
|
||||
|
||||
@@ -22,6 +22,8 @@
|
||||
//! - Code rounds and sub-LLM calls travel over a single stdin/stdout
|
||||
//! pipe to a long-lived `python3 -u` subprocess. No HTTP sidecar.
|
||||
|
||||
use crate::models::Usage;
|
||||
|
||||
pub mod bridge;
|
||||
pub mod prompt;
|
||||
pub mod turn;
|
||||
@@ -29,3 +31,53 @@ pub mod turn;
|
||||
pub use bridge::RlmBridge;
|
||||
pub use prompt::rlm_system_prompt;
|
||||
pub use turn::{RlmTermination, RlmTurnResult, run_rlm_turn, run_rlm_turn_with_root};
|
||||
|
||||
fn add_usage_with_prompt_cache(total: &mut Usage, delta: &Usage) {
|
||||
total.input_tokens = total.input_tokens.saturating_add(delta.input_tokens);
|
||||
total.output_tokens = total.output_tokens.saturating_add(delta.output_tokens);
|
||||
total.prompt_cache_hit_tokens =
|
||||
add_optional_usage(total.prompt_cache_hit_tokens, delta.prompt_cache_hit_tokens);
|
||||
total.prompt_cache_miss_tokens = add_optional_usage(
|
||||
total.prompt_cache_miss_tokens,
|
||||
delta.prompt_cache_miss_tokens,
|
||||
);
|
||||
}
|
||||
|
||||
/// Merges two optional token counters.
///
/// Returns the saturating sum when both sides are present, whichever side
/// is present when only one is, and `None` only when neither side reported
/// a value. Absence is therefore not treated as an explicit zero.
fn add_optional_usage(total: Option<u32>, delta: Option<u32>) -> Option<u32> {
    match (total, delta) {
        // Both known: accumulate, clamping at `u32::MAX` instead of wrapping.
        (Some(total), Some(delta)) => Some(total.saturating_add(delta)),
        // Only one side known: carry it through unchanged.
        (None, Some(delta)) => Some(delta),
        (Some(total), None) => Some(total),
        // Nothing reported on either side: stay "unknown".
        (None, None) => None,
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Adding a delta whose cache counters are set to a total whose cache
    /// counters are also set must sum all four fields.
    #[test]
    fn add_usage_with_prompt_cache_preserves_cache_counts() {
        let mut total = Usage {
            input_tokens: 100,
            output_tokens: 10,
            prompt_cache_hit_tokens: Some(80),
            prompt_cache_miss_tokens: Some(20),
            ..Usage::default()
        };
        let delta = Usage {
            input_tokens: 50,
            output_tokens: 5,
            prompt_cache_hit_tokens: Some(30),
            prompt_cache_miss_tokens: Some(20),
            ..Usage::default()
        };

        add_usage_with_prompt_cache(&mut total, &delta);

        assert_eq!(total.input_tokens, 150);
        assert_eq!(total.output_tokens, 15);
        assert_eq!(total.prompt_cache_hit_tokens, Some(110));
        assert_eq!(total.prompt_cache_miss_tokens, Some(40));
    }
}
|
||||
|
||||
@@ -284,12 +284,7 @@ async fn run_rlm_turn_impl(
|
||||
}
|
||||
};
|
||||
|
||||
total_usage.input_tokens = total_usage
|
||||
.input_tokens
|
||||
.saturating_add(response.usage.input_tokens);
|
||||
total_usage.output_tokens = total_usage
|
||||
.output_tokens
|
||||
.saturating_add(response.usage.output_tokens);
|
||||
super::add_usage_with_prompt_cache(&mut total_usage, &response.usage);
|
||||
|
||||
let response_text = extract_text_blocks(&response.content);
|
||||
last_response_text = response_text.clone();
|
||||
@@ -510,12 +505,7 @@ async fn run_rlm_turn_impl(
|
||||
// Fold bridge usage (children + nested sub_rlm) into totals.
|
||||
let bridge_usage = usage_handle.lock().await;
|
||||
let mut final_usage = result.usage.clone();
|
||||
final_usage.input_tokens = final_usage
|
||||
.input_tokens
|
||||
.saturating_add(bridge_usage.input_tokens);
|
||||
final_usage.output_tokens = final_usage
|
||||
.output_tokens
|
||||
.saturating_add(bridge_usage.output_tokens);
|
||||
super::add_usage_with_prompt_cache(&mut final_usage, &bridge_usage);
|
||||
drop(bridge_usage);
|
||||
|
||||
repl.shutdown().await;
|
||||
|
||||
Reference in New Issue
Block a user