From d2c007833f316c5e9ef95cbfefa6e479e6f5cb22 Mon Sep 17 00:00:00 2001 From: Hunter Bown Date: Fri, 1 May 2026 04:10:59 -0500 Subject: [PATCH] test(rlm): make bridge client seam mockable --- crates/tui/src/llm_client/mod.rs | 5 +- crates/tui/src/rlm/bridge.rs | 205 ++++++++++++++++++++++--------- crates/tui/src/rlm/turn.rs | 76 ++++++++---- 3 files changed, 202 insertions(+), 84 deletions(-) diff --git a/crates/tui/src/llm_client/mod.rs b/crates/tui/src/llm_client/mod.rs index 48b28b18..009f701b 100644 --- a/crates/tui/src/llm_client/mod.rs +++ b/crates/tui/src/llm_client/mod.rs @@ -58,7 +58,10 @@ pub trait LlmClient: Send + Sync { fn model(&self) -> &str; /// Creates a non-streaming message completion - async fn create_message(&self, request: MessageRequest) -> Result; + fn create_message( + &self, + request: MessageRequest, + ) -> impl Future> + Send; /// Creates a streaming message completion /// diff --git a/crates/tui/src/rlm/bridge.rs b/crates/tui/src/rlm/bridge.rs index 36ec6680..6338762b 100644 --- a/crates/tui/src/rlm/bridge.rs +++ b/crates/tui/src/rlm/bridge.rs @@ -13,13 +13,14 @@ use std::sync::Arc; use std::time::Duration; +use std::{future::Future, pin::Pin}; +use anyhow::Result; use futures_util::future::join_all; use tokio::sync::Mutex; -use crate::client::DeepSeekClient; -use crate::llm_client::LlmClient as _; -use crate::models::{ContentBlock, Message, MessageRequest, SystemPrompt, Usage}; +use crate::llm_client::LlmClient; +use crate::models::{ContentBlock, Message, MessageRequest, MessageResponse, SystemPrompt, Usage}; use crate::repl::runtime::{BatchResp, RpcDispatcher, RpcRequest, RpcResponse, SingleResp}; /// Per-child completion timeout — same as the previous sidecar default. @@ -29,18 +30,46 @@ const DEFAULT_CHILD_MAX_TOKENS: u32 = 4096; /// Hard cap on prompts per batch RPC. pub const MAX_BATCH: usize = 16; +/// Object-safe slice of the LLM client interface that the RLM bridge needs. +/// +/// `LlmClient` itself uses native async trait methods, which are not dyn-safe. +/// The bridge only needs non-streaming completions, so this boxed-future shim +/// gives tests a clean mock seam without changing the wider provider trait. +pub(crate) trait RlmLlmClient: Send + Sync { + fn create_message_boxed( + &self, + request: MessageRequest, + ) -> Pin> + Send + '_>>; +} + +impl RlmLlmClient for T +where + T: LlmClient + Send + Sync, +{ + fn create_message_boxed( + &self, + request: MessageRequest, + ) -> Pin> + Send + '_>> { + Box::pin(self.create_message(request)) + } +} + /// State shared with the bridge across all RPC calls in one turn. pub struct RlmBridge { - pub client: DeepSeekClient, - pub child_model: String, + client: Arc, + child_model: String, /// Recursion budget remaining for `Rlm` / `RlmBatch` requests. When /// zero, those requests fall back to plain `Llm` completions. - pub depth_remaining: u32, - pub usage: Arc>, + depth_remaining: u32, + usage: Arc>, } impl RlmBridge { - pub fn new(client: DeepSeekClient, child_model: String, depth_remaining: u32) -> Self { + pub(crate) fn new( + client: Arc, + child_model: String, + depth_remaining: u32, + ) -> Self { Self { client, child_model, @@ -83,7 +112,7 @@ impl RlmBridge { top_p: Some(0.9_f32), }; - let fut = self.client.create_message(request); + let fut = self.client.create_message_boxed(request); let response = match tokio::time::timeout(Duration::from_secs(CHILD_TIMEOUT_SECS), fut).await { Ok(Ok(r)) => r, @@ -165,7 +194,7 @@ impl RlmBridge { // Recursive call. The dyn-erasure on `run_rlm_turn_inner` breaks // the `bridge → turn → bridge` opaque-future cycle. let result = super::turn::run_rlm_turn_inner( - &self.client, + Arc::clone(&self.client), child_model.clone(), prompt, None, @@ -254,43 +283,32 @@ impl RpcDispatcher for RlmBridge { #[cfg(test)] mod tests { use super::*; - use crate::config::{Config, ProviderConfig, ProvidersConfig}; - use serde_json::json; - use wiremock::matchers::{method, path}; - use wiremock::{Mock, MockServer, ResponseTemplate}; + use crate::llm_client::mock::MockLlmClient; - fn client_for(server: &MockServer) -> DeepSeekClient { - let config = Config { - provider: Some("sglang".to_string()), - providers: Some(ProvidersConfig { - sglang: ProviderConfig { - base_url: Some(server.uri()), - ..ProviderConfig::default() - }, - ..ProvidersConfig::default() - }), - ..Config::default() - }; - DeepSeekClient::new(&config).expect("test client") + fn mock_response(text: &str, input_tokens: u32, output_tokens: u32) -> MessageResponse { + MessageResponse { + id: "mock_msg".to_string(), + r#type: "message".to_string(), + role: "assistant".to_string(), + content: vec![ContentBlock::Text { + text: text.to_string(), + cache_control: None, + }], + model: "mock-model".to_string(), + stop_reason: Some("end_turn".to_string()), + stop_sequence: None, + container: None, + usage: Usage { + input_tokens, + output_tokens, + ..Usage::default() + }, + } } - fn chat_response(text: &str) -> serde_json::Value { - json!({ - "id": "chatcmpl-test", - "model": "test-model", - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": text, - }, - "finish_reason": "stop" - }], - "usage": { - "prompt_tokens": 3, - "completion_tokens": 5 - } - }) + fn bridge_for(mock: Arc, depth_remaining: u32) -> RlmBridge { + let client: Arc = mock; + RlmBridge::new(client, "child-model".to_string(), depth_remaining) } #[test] @@ -318,22 +336,89 @@ mod tests { } #[tokio::test] - async fn rlm_dispatch_at_depth_zero_falls_back_to_plain_llm_query() { - let server = MockServer::start().await; - Mock::given(method("POST")) - .and(path("/v1/responses")) - .respond_with(ResponseTemplate::new(404).set_body_string("responses unavailable")) - .mount(&server) - .await; - Mock::given(method("POST")) - .and(path("/v1/chat/completions")) - .respond_with( - ResponseTemplate::new(200).set_body_json(chat_response("fallback answer")), - ) - .mount(&server) + async fn llm_dispatch_uses_trait_backed_mock_client() { + let mock = Arc::new(MockLlmClient::new(Vec::new())); + mock.push_message_response(mock_response("child answer", 7, 11)); + let bridge = bridge_for(Arc::clone(&mock), 1); + + let response = bridge + .dispatch(RpcRequest::Llm { + prompt: "child prompt".to_string(), + model: Some("override-model".to_string()), + max_tokens: Some(123), + system: Some("child system".to_string()), + }) .await; - let bridge = RlmBridge::new(client_for(&server), "child-model".to_string(), 0); + match response { + RpcResponse::Single(single) => { + assert_eq!(single.text, "child answer"); + assert!(single.error.is_none()); + } + other => panic!("expected single response, got {other:?}"), + } + + let captured = mock.captured_requests(); + assert_eq!(captured.len(), 1); + assert_eq!(captured[0].model, "override-model"); + assert_eq!(captured[0].max_tokens, 123); + assert_eq!( + captured[0].system, + Some(SystemPrompt::Text("child system".to_string())) + ); + + let usage = bridge.usage.lock().await; + assert_eq!(usage.input_tokens, 7); + assert_eq!(usage.output_tokens, 11); + } + + #[tokio::test] + async fn llm_batch_dispatch_preserves_result_count_and_usage() { + let mock = Arc::new(MockLlmClient::new(Vec::new())); + mock.push_message_response(mock_response("one", 1, 2)); + mock.push_message_response(mock_response("two", 3, 4)); + mock.push_message_response(mock_response("three", 5, 6)); + let bridge = bridge_for(Arc::clone(&mock), 1); + + let response = bridge + .dispatch(RpcRequest::LlmBatch { + prompts: vec!["a".to_string(), "b".to_string(), "c".to_string()], + model: Some("batch-model".to_string()), + }) + .await; + + match response { + RpcResponse::Batch(batch) => { + let texts: Vec<_> = batch + .results + .iter() + .map(|result| result.text.as_str()) + .collect(); + assert_eq!(texts, ["one", "two", "three"]); + assert!(batch.results.iter().all(|result| result.error.is_none())); + } + other => panic!("expected batch response, got {other:?}"), + } + + let captured = mock.captured_requests(); + assert_eq!(captured.len(), 3); + assert!( + captured + .iter() + .all(|request| request.model == "batch-model") + ); + + let usage = bridge.usage.lock().await; + assert_eq!(usage.input_tokens, 9); + assert_eq!(usage.output_tokens, 12); + } + + #[tokio::test] + async fn rlm_dispatch_at_depth_zero_falls_back_to_plain_llm_query() { + let mock = Arc::new(MockLlmClient::new(Vec::new())); + mock.push_message_response(mock_response("fallback answer", 3, 5)); + let bridge = bridge_for(Arc::clone(&mock), 0); + let response = bridge .dispatch(RpcRequest::Rlm { prompt: "nested prompt".to_string(), @@ -352,5 +437,9 @@ mod tests { let usage = bridge.usage.lock().await; assert_eq!(usage.input_tokens, 3); assert_eq!(usage.output_tokens, 5); + + let captured = mock.captured_requests(); + assert_eq!(captured.len(), 1); + assert_eq!(captured[0].model, "override-model"); } } diff --git a/crates/tui/src/rlm/turn.rs b/crates/tui/src/rlm/turn.rs index bc73dbf5..c97f5ad1 100644 --- a/crates/tui/src/rlm/turn.rs +++ b/crates/tui/src/rlm/turn.rs @@ -2,6 +2,7 @@ //! subprocess + stdin/stdout RPC bridge (no HTTP sidecar). use std::path::PathBuf; +use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::mpsc; @@ -9,11 +10,10 @@ use uuid::Uuid; use crate::client::DeepSeekClient; use crate::core::events::Event; -use crate::llm_client::LlmClient; -use crate::models::{ContentBlock, Message, MessageRequest, Usage}; +use crate::models::{ContentBlock, Message, MessageRequest, SystemPrompt, Usage}; use crate::repl::PythonRuntime; -use super::bridge::RlmBridge; +use super::bridge::{RlmBridge, RlmLlmClient}; use super::prompt::rlm_system_prompt; // --------------------------------------------------------------------------- @@ -99,7 +99,7 @@ pub async fn run_rlm_turn( max_depth: u32, ) -> RlmTurnResult { run_rlm_turn_inner( - client, + Arc::new(client.clone()), model, prompt, None, @@ -122,7 +122,7 @@ pub async fn run_rlm_turn_with_root( max_depth: u32, ) -> RlmTurnResult { run_rlm_turn_inner( - client, + Arc::new(client.clone()), model, prompt, root_prompt, @@ -136,15 +136,15 @@ pub async fn run_rlm_turn_with_root( /// Inner entry point — also used by the bridge when it recurses. Returns /// a boxed future to break the recursive opaque-future-type cycle: /// `run_rlm_turn_inner` → `RlmBridge::dispatch` → `run_rlm_turn_inner`. -pub(crate) fn run_rlm_turn_inner<'a>( - client: &'a DeepSeekClient, +pub(crate) fn run_rlm_turn_inner( + client: Arc, model: String, prompt: String, root_prompt: Option, child_model: String, tx_event: mpsc::Sender, max_depth: u32, -) -> std::pin::Pin + Send + 'a>> { +) -> std::pin::Pin + Send>> { Box::pin(run_rlm_turn_impl( client, model, @@ -161,7 +161,7 @@ pub(crate) fn run_rlm_turn_inner<'a>( // --------------------------------------------------------------------------- async fn run_rlm_turn_impl( - client: &DeepSeekClient, + client: Arc, model: String, prompt: String, root_prompt: Option, @@ -212,7 +212,7 @@ async fn run_rlm_turn_impl( }; // 3. Build the bridge that services llm_query / rlm_query RPCs. - let bridge = RlmBridge::new(client.clone(), child_model.clone(), max_depth); + let bridge = RlmBridge::new(Arc::clone(&client), child_model.clone(), max_depth); let usage_handle = bridge.usage_handle(); let _ = tx_event @@ -262,22 +262,9 @@ async fn run_rlm_turn_impl( .await; // 4a. Root LLM generates code from metadata-only context. - let request = MessageRequest { - model: model.clone(), - messages: messages.clone(), - max_tokens: ROOT_MAX_TOKENS, - system: Some(system.clone()), - tools: None, - tool_choice: None, - metadata: None, - thinking: None, - reasoning_effort: None, - stream: Some(false), - temperature: Some(ROOT_TEMPERATURE), - top_p: Some(0.9_f32), - }; + let request = build_root_request(&model, &messages, &system); - let response = match client.create_message(request).await { + let response = match client.create_message_boxed(request).await { Ok(r) => r, Err(e) => { break 'turn RlmTurnResult { @@ -551,6 +538,23 @@ fn write_context_file(prompt: &str) -> std::io::Result { Ok(path) } +fn build_root_request(model: &str, messages: &[Message], system: &SystemPrompt) -> MessageRequest { + MessageRequest { + model: model.to_string(), + messages: messages.to_vec(), + max_tokens: ROOT_MAX_TOKENS, + system: Some(system.clone()), + tools: None, + tool_choice: None, + metadata: None, + thinking: None, + reasoning_effort: None, + stream: Some(false), + temperature: Some(ROOT_TEMPERATURE), + top_p: Some(0.9_f32), + } +} + /// Build `Metadata(state)` from the paper. Surfaces: /// - the small `root_prompt` (if any) — repeated each iteration /// - `context` length + preview @@ -851,6 +855,28 @@ mod tests { ); } + #[test] + fn build_root_request_keeps_context_tail_out_of_root_payload() { + let secret_tail = "DO_NOT_LEAK_ROOT_REQUEST"; + let prompt = format!("{}{}", "a".repeat(PROMPT_PREVIEW_LEN + 100), secret_tail); + let messages = vec![build_metadata_message( + &prompt, + Some("answer from the long context"), + 0, + None, + None, + )]; + + let request = build_root_request("root-model", &messages, &rlm_system_prompt()); + let payload = serde_json::to_string(&request).expect("request should serialize"); + + assert!(payload.contains(&format!("- Length: {} chars", prompt.chars().count()))); + assert!( + !payload.contains(secret_tail), + "root LLM request leaked the non-preview tail of context" + ); + } + #[test] fn build_metadata_with_iteration_shows_previous_code() { let msg = build_metadata_message("Test prompt", None, 3, Some("print('hi')"), Some("hi"));