diff --git a/crates/tui/src/rlm/bridge.rs b/crates/tui/src/rlm/bridge.rs index 0b8e7754..36ec6680 100644 --- a/crates/tui/src/rlm/bridge.rs +++ b/crates/tui/src/rlm/bridge.rs @@ -121,19 +121,8 @@ impl RlmBridge { } async fn dispatch_llm_batch(&self, prompts: Vec, model: Option) -> BatchResp { - if prompts.is_empty() { - return BatchResp { results: vec![] }; - } - if prompts.len() > MAX_BATCH { - return BatchResp { - results: prompts - .iter() - .map(|_| SingleResp { - text: String::new(), - error: Some(format!("batch too large: {} > {MAX_BATCH}", prompts.len())), - }) - .collect(), - }; + if let Some(resp) = batch_guard(prompts.len()) { + return resp; } let model = Arc::new( @@ -201,19 +190,8 @@ impl RlmBridge { } async fn dispatch_rlm_batch(&self, prompts: Vec, model: Option) -> BatchResp { - if prompts.is_empty() { - return BatchResp { results: vec![] }; - } - if prompts.len() > MAX_BATCH { - return BatchResp { - results: prompts - .iter() - .map(|_| SingleResp { - text: String::new(), - error: Some(format!("batch too large: {} > {MAX_BATCH}", prompts.len())), - }) - .collect(), - }; + if let Some(resp) = batch_guard(prompts.len()) { + return resp; } let model = Arc::new(model); @@ -227,6 +205,23 @@ impl RlmBridge { } } +fn batch_guard(prompt_count: usize) -> Option { + if prompt_count == 0 { + return Some(BatchResp { results: vec![] }); + } + if prompt_count > MAX_BATCH { + return Some(BatchResp { + results: (0..prompt_count) + .map(|_| SingleResp { + text: String::new(), + error: Some(format!("batch too large: {prompt_count} > {MAX_BATCH}")), + }) + .collect(), + }); + } + None +} + impl RpcDispatcher for RlmBridge { fn dispatch<'a>( &'a self, @@ -255,3 +250,107 @@ impl RpcDispatcher for RlmBridge { }) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{Config, ProviderConfig, ProvidersConfig}; + use serde_json::json; + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + fn client_for(server: &MockServer) -> DeepSeekClient { + let config = Config { + provider: Some("sglang".to_string()), + providers: Some(ProvidersConfig { + sglang: ProviderConfig { + base_url: Some(server.uri()), + ..ProviderConfig::default() + }, + ..ProvidersConfig::default() + }), + ..Config::default() + }; + DeepSeekClient::new(&config).expect("test client") + } + + fn chat_response(text: &str) -> serde_json::Value { + json!({ + "id": "chatcmpl-test", + "model": "test-model", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": text, + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 3, + "completion_tokens": 5 + } + }) + } + + #[test] + fn batch_guard_allows_non_empty_batches_at_the_cap() { + assert!(batch_guard(MAX_BATCH).is_none()); + } + + #[test] + fn batch_guard_returns_empty_response_for_empty_batches() { + let response = batch_guard(0).expect("empty batch should be handled"); + assert!(response.results.is_empty()); + } + + #[test] + fn batch_guard_returns_one_error_per_oversized_prompt() { + let response = batch_guard(MAX_BATCH + 2).expect("oversized batch should be handled"); + assert_eq!(response.results.len(), MAX_BATCH + 2); + assert!(response.results.iter().all(|result| { + result.text.is_empty() + && result + .error + .as_deref() + .is_some_and(|err| err.contains("batch too large")) + })); + } + + #[tokio::test] + async fn rlm_dispatch_at_depth_zero_falls_back_to_plain_llm_query() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/responses")) + .respond_with(ResponseTemplate::new(404).set_body_string("responses unavailable")) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with( + ResponseTemplate::new(200).set_body_json(chat_response("fallback answer")), + ) + .mount(&server) + .await; + + let bridge = RlmBridge::new(client_for(&server), "child-model".to_string(), 0); + let response = bridge + .dispatch(RpcRequest::Rlm { + prompt: "nested prompt".to_string(), + model: Some("override-model".to_string()), + }) + .await; + + match response { + RpcResponse::Single(single) => { + assert_eq!(single.text, "fallback answer"); + assert!(single.error.is_none()); + } + other => panic!("expected single response, got {other:?}"), + } + + let usage = bridge.usage.lock().await; + assert_eq!(usage.input_tokens, 3); + assert_eq!(usage.output_tokens, 5); + } +} diff --git a/crates/tui/src/rlm/turn.rs b/crates/tui/src/rlm/turn.rs index 08d59818..bc73dbf5 100644 --- a/crates/tui/src/rlm/turn.rs +++ b/crates/tui/src/rlm/turn.rs @@ -835,6 +835,22 @@ mod tests { assert!(text.contains("FINAL")); } + #[test] + fn build_metadata_truncates_long_context_without_leaking_tail() { + let secret_tail = "DO_NOT_LEAK_CONTEXT_TAIL"; + let prompt = format!("{}{}", "a".repeat(PROMPT_PREVIEW_LEN + 100), secret_tail); + let msg = build_metadata_message(&prompt, None, 0, None, None); + let text = extract_text_blocks(&msg.content); + + assert!(text.contains(&format!("- Length: {} chars", prompt.chars().count()))); + assert!(text.contains("- Preview: \"")); + assert!(text.contains("...")); + assert!( + !text.contains(secret_tail), + "metadata leaked the non-preview tail of context" + ); + } + #[test] fn build_metadata_with_iteration_shows_previous_code() { let msg = build_metadata_message("Test prompt", None, 3, Some("print('hi')"), Some("hi"));