diff --git a/crates/tui/src/rlm/bridge.rs b/crates/tui/src/rlm/bridge.rs index 904ddef7..3e30dd74 100644 --- a/crates/tui/src/rlm/bridge.rs +++ b/crates/tui/src/rlm/bridge.rs @@ -145,8 +145,7 @@ impl RlmBridge { { let mut u = self.usage.lock().await; - u.input_tokens = u.input_tokens.saturating_add(response.usage.input_tokens); - u.output_tokens = u.output_tokens.saturating_add(response.usage.output_tokens); + super::add_usage_with_prompt_cache(&mut u, &response.usage); } SingleResp { text, error: None } @@ -209,8 +208,7 @@ impl RlmBridge { { let mut u = self.usage.lock().await; - u.input_tokens = u.input_tokens.saturating_add(result.usage.input_tokens); - u.output_tokens = u.output_tokens.saturating_add(result.usage.output_tokens); + super::add_usage_with_prompt_cache(&mut u, &result.usage); } SingleResp { @@ -284,7 +282,7 @@ mod tests { use super::*; use crate::llm_client::mock::MockLlmClient; - fn mock_response(text: &str, input_tokens: u32, output_tokens: u32) -> MessageResponse { + fn mock_response_with_usage(text: &str, usage: Usage) -> MessageResponse { MessageResponse { id: "mock_msg".to_string(), r#type: "message".to_string(), @@ -297,12 +295,19 @@ mod tests { stop_reason: Some("end_turn".to_string()), stop_sequence: None, container: None, - usage: Usage { + usage, + } + } + + fn mock_response(text: &str, input_tokens: u32, output_tokens: u32) -> MessageResponse { + mock_response_with_usage( + text, + Usage { input_tokens, output_tokens, ..Usage::default() }, - } + ) } fn bridge_for(mock: Arc<MockLlmClient>, depth_remaining: u32) -> RlmBridge { @@ -371,6 +376,45 @@ mod tests { assert_eq!(usage.output_tokens, 11); } + #[tokio::test] + async fn llm_dispatch_preserves_prompt_cache_usage() { + let mock = Arc::new(MockLlmClient::new(Vec::new())); + mock.push_message_response(mock_response_with_usage( + "cached child answer", + Usage { + input_tokens: 1000, + output_tokens: 100, + prompt_cache_hit_tokens: Some(800), + prompt_cache_miss_tokens: Some(200), + ..Usage::default() + 
}, + )); + let bridge = bridge_for(Arc::clone(&mock), 1); + + let response = bridge + .dispatch(RpcRequest::Llm { + prompt: "child prompt".to_string(), + model: None, + max_tokens: None, + system: None, + }) + .await; + + match response { + RpcResponse::Single(single) => { + assert_eq!(single.text, "cached child answer"); + assert!(single.error.is_none()); + } + other => panic!("expected single response, got {other:?}"), + } + + let usage = bridge.usage.lock().await; + assert_eq!(usage.input_tokens, 1000); + assert_eq!(usage.output_tokens, 100); + assert_eq!(usage.prompt_cache_hit_tokens, Some(800)); + assert_eq!(usage.prompt_cache_miss_tokens, Some(200)); + } + #[tokio::test] async fn llm_batch_dispatch_pins_configured_child_model() { let mock = Arc::new(MockLlmClient::new(Vec::new())); diff --git a/crates/tui/src/rlm/mod.rs b/crates/tui/src/rlm/mod.rs index fbd76776..44959983 100644 --- a/crates/tui/src/rlm/mod.rs +++ b/crates/tui/src/rlm/mod.rs @@ -22,6 +22,8 @@ //! - Code rounds and sub-LLM calls travel over a single stdin/stdout //! pipe to a long-lived `python3 -u` subprocess. No HTTP sidecar. 
+use crate::models::Usage; + pub mod bridge; pub mod prompt; pub mod turn; @@ -29,3 +31,53 @@ pub mod turn; pub use bridge::RlmBridge; pub use prompt::rlm_system_prompt; pub use turn::{RlmTermination, RlmTurnResult, run_rlm_turn, run_rlm_turn_with_root}; + +fn add_usage_with_prompt_cache(total: &mut Usage, delta: &Usage) { + total.input_tokens = total.input_tokens.saturating_add(delta.input_tokens); + total.output_tokens = total.output_tokens.saturating_add(delta.output_tokens); + total.prompt_cache_hit_tokens = + add_optional_usage(total.prompt_cache_hit_tokens, delta.prompt_cache_hit_tokens); + total.prompt_cache_miss_tokens = add_optional_usage( + total.prompt_cache_miss_tokens, + delta.prompt_cache_miss_tokens, + ); +} + +fn add_optional_usage(total: Option<u32>, delta: Option<u32>) -> Option<u32> { + match (total, delta) { + (Some(total), Some(delta)) => Some(total.saturating_add(delta)), + (None, Some(delta)) => Some(delta), + (Some(total), None) => Some(total), + (None, None) => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn add_usage_with_prompt_cache_preserves_cache_counts() { + let mut total = Usage { + input_tokens: 100, + output_tokens: 10, + prompt_cache_hit_tokens: Some(80), + prompt_cache_miss_tokens: Some(20), + ..Usage::default() + }; + let delta = Usage { + input_tokens: 50, + output_tokens: 5, + prompt_cache_hit_tokens: Some(30), + prompt_cache_miss_tokens: Some(20), + ..Usage::default() + }; + + add_usage_with_prompt_cache(&mut total, &delta); + + assert_eq!(total.input_tokens, 150); + assert_eq!(total.output_tokens, 15); + assert_eq!(total.prompt_cache_hit_tokens, Some(110)); + assert_eq!(total.prompt_cache_miss_tokens, Some(40)); + } +} diff --git a/crates/tui/src/rlm/turn.rs b/crates/tui/src/rlm/turn.rs index 676d9e1e..5f8a3881 100644 --- a/crates/tui/src/rlm/turn.rs +++ b/crates/tui/src/rlm/turn.rs @@ -284,12 +284,7 @@ async fn run_rlm_turn_impl( } }; - total_usage.input_tokens = total_usage - .input_tokens - 
.saturating_add(response.usage.input_tokens); - total_usage.output_tokens = total_usage - .output_tokens - .saturating_add(response.usage.output_tokens); + super::add_usage_with_prompt_cache(&mut total_usage, &response.usage); let response_text = extract_text_blocks(&response.content); last_response_text = response_text.clone(); @@ -510,12 +505,7 @@ async fn run_rlm_turn_impl( // Fold bridge usage (children + nested sub_rlm) into totals. let bridge_usage = usage_handle.lock().await; let mut final_usage = result.usage.clone(); - final_usage.input_tokens = final_usage - .input_tokens - .saturating_add(bridge_usage.input_tokens); - final_usage.output_tokens = final_usage - .output_tokens - .saturating_add(bridge_usage.output_tokens); + super::add_usage_with_prompt_cache(&mut final_usage, &bridge_usage); drop(bridge_usage); repl.shutdown().await;