test(rlm): make bridge client seam mockable

This commit is contained in:
Hunter Bown
2026-05-01 04:10:59 -05:00
parent 2cf0c20c76
commit d2c007833f
3 changed files with 202 additions and 84 deletions
+4 -1
View File
@@ -58,7 +58,10 @@ pub trait LlmClient: Send + Sync {
fn model(&self) -> &str;
/// Creates a non-streaming message completion
async fn create_message(&self, request: MessageRequest) -> Result<MessageResponse>;
fn create_message(
&self,
request: MessageRequest,
) -> impl Future<Output = Result<MessageResponse>> + Send;
/// Creates a streaming message completion
///
+147 -58
View File
@@ -13,13 +13,14 @@
use std::sync::Arc;
use std::time::Duration;
use std::{future::Future, pin::Pin};
use anyhow::Result;
use futures_util::future::join_all;
use tokio::sync::Mutex;
use crate::client::DeepSeekClient;
use crate::llm_client::LlmClient as _;
use crate::models::{ContentBlock, Message, MessageRequest, SystemPrompt, Usage};
use crate::llm_client::LlmClient;
use crate::models::{ContentBlock, Message, MessageRequest, MessageResponse, SystemPrompt, Usage};
use crate::repl::runtime::{BatchResp, RpcDispatcher, RpcRequest, RpcResponse, SingleResp};
/// Per-child completion timeout — same as the previous sidecar default.
@@ -29,18 +30,46 @@ const DEFAULT_CHILD_MAX_TOKENS: u32 = 4096;
/// Hard cap on prompts per batch RPC.
pub const MAX_BATCH: usize = 16;
/// Object-safe slice of the LLM client interface that the RLM bridge needs.
///
/// `LlmClient` itself uses native async trait methods, which are not dyn-safe.
/// The bridge only needs non-streaming completions, so this boxed-future shim
/// gives tests a clean mock seam without changing the wider provider trait.
pub(crate) trait RlmLlmClient: Send + Sync {
fn create_message_boxed(
&self,
request: MessageRequest,
) -> Pin<Box<dyn Future<Output = Result<MessageResponse>> + Send + '_>>;
}
impl<T> RlmLlmClient for T
where
T: LlmClient + Send + Sync,
{
fn create_message_boxed(
&self,
request: MessageRequest,
) -> Pin<Box<dyn Future<Output = Result<MessageResponse>> + Send + '_>> {
Box::pin(self.create_message(request))
}
}
/// State shared with the bridge across all RPC calls in one turn.
pub struct RlmBridge {
pub client: DeepSeekClient,
pub child_model: String,
client: Arc<dyn RlmLlmClient>,
child_model: String,
/// Recursion budget remaining for `Rlm` / `RlmBatch` requests. When
/// zero, those requests fall back to plain `Llm` completions.
pub depth_remaining: u32,
pub usage: Arc<Mutex<Usage>>,
depth_remaining: u32,
usage: Arc<Mutex<Usage>>,
}
impl RlmBridge {
pub fn new(client: DeepSeekClient, child_model: String, depth_remaining: u32) -> Self {
pub(crate) fn new(
client: Arc<dyn RlmLlmClient>,
child_model: String,
depth_remaining: u32,
) -> Self {
Self {
client,
child_model,
@@ -83,7 +112,7 @@ impl RlmBridge {
top_p: Some(0.9_f32),
};
let fut = self.client.create_message(request);
let fut = self.client.create_message_boxed(request);
let response =
match tokio::time::timeout(Duration::from_secs(CHILD_TIMEOUT_SECS), fut).await {
Ok(Ok(r)) => r,
@@ -165,7 +194,7 @@ impl RlmBridge {
// Recursive call. The dyn-erasure on `run_rlm_turn_inner` breaks
// the `bridge → turn → bridge` opaque-future cycle.
let result = super::turn::run_rlm_turn_inner(
&self.client,
Arc::clone(&self.client),
child_model.clone(),
prompt,
None,
@@ -254,43 +283,32 @@ impl RpcDispatcher for RlmBridge {
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{Config, ProviderConfig, ProvidersConfig};
use serde_json::json;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
use crate::llm_client::mock::MockLlmClient;
fn client_for(server: &MockServer) -> DeepSeekClient {
let config = Config {
provider: Some("sglang".to_string()),
providers: Some(ProvidersConfig {
sglang: ProviderConfig {
base_url: Some(server.uri()),
..ProviderConfig::default()
},
..ProvidersConfig::default()
}),
..Config::default()
};
DeepSeekClient::new(&config).expect("test client")
fn mock_response(text: &str, input_tokens: u32, output_tokens: u32) -> MessageResponse {
MessageResponse {
id: "mock_msg".to_string(),
r#type: "message".to_string(),
role: "assistant".to_string(),
content: vec![ContentBlock::Text {
text: text.to_string(),
cache_control: None,
}],
model: "mock-model".to_string(),
stop_reason: Some("end_turn".to_string()),
stop_sequence: None,
container: None,
usage: Usage {
input_tokens,
output_tokens,
..Usage::default()
},
}
}
fn chat_response(text: &str) -> serde_json::Value {
json!({
"id": "chatcmpl-test",
"model": "test-model",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": text,
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 3,
"completion_tokens": 5
}
})
fn bridge_for(mock: Arc<MockLlmClient>, depth_remaining: u32) -> RlmBridge {
let client: Arc<dyn RlmLlmClient> = mock;
RlmBridge::new(client, "child-model".to_string(), depth_remaining)
}
#[test]
@@ -318,22 +336,89 @@ mod tests {
}
#[tokio::test]
async fn rlm_dispatch_at_depth_zero_falls_back_to_plain_llm_query() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(ResponseTemplate::new(404).set_body_string("responses unavailable"))
.mount(&server)
.await;
Mock::given(method("POST"))
.and(path("/v1/chat/completions"))
.respond_with(
ResponseTemplate::new(200).set_body_json(chat_response("fallback answer")),
)
.mount(&server)
async fn llm_dispatch_uses_trait_backed_mock_client() {
let mock = Arc::new(MockLlmClient::new(Vec::new()));
mock.push_message_response(mock_response("child answer", 7, 11));
let bridge = bridge_for(Arc::clone(&mock), 1);
let response = bridge
.dispatch(RpcRequest::Llm {
prompt: "child prompt".to_string(),
model: Some("override-model".to_string()),
max_tokens: Some(123),
system: Some("child system".to_string()),
})
.await;
let bridge = RlmBridge::new(client_for(&server), "child-model".to_string(), 0);
match response {
RpcResponse::Single(single) => {
assert_eq!(single.text, "child answer");
assert!(single.error.is_none());
}
other => panic!("expected single response, got {other:?}"),
}
let captured = mock.captured_requests();
assert_eq!(captured.len(), 1);
assert_eq!(captured[0].model, "override-model");
assert_eq!(captured[0].max_tokens, 123);
assert_eq!(
captured[0].system,
Some(SystemPrompt::Text("child system".to_string()))
);
let usage = bridge.usage.lock().await;
assert_eq!(usage.input_tokens, 7);
assert_eq!(usage.output_tokens, 11);
}
#[tokio::test]
async fn llm_batch_dispatch_preserves_result_count_and_usage() {
let mock = Arc::new(MockLlmClient::new(Vec::new()));
mock.push_message_response(mock_response("one", 1, 2));
mock.push_message_response(mock_response("two", 3, 4));
mock.push_message_response(mock_response("three", 5, 6));
let bridge = bridge_for(Arc::clone(&mock), 1);
let response = bridge
.dispatch(RpcRequest::LlmBatch {
prompts: vec!["a".to_string(), "b".to_string(), "c".to_string()],
model: Some("batch-model".to_string()),
})
.await;
match response {
RpcResponse::Batch(batch) => {
let texts: Vec<_> = batch
.results
.iter()
.map(|result| result.text.as_str())
.collect();
assert_eq!(texts, ["one", "two", "three"]);
assert!(batch.results.iter().all(|result| result.error.is_none()));
}
other => panic!("expected batch response, got {other:?}"),
}
let captured = mock.captured_requests();
assert_eq!(captured.len(), 3);
assert!(
captured
.iter()
.all(|request| request.model == "batch-model")
);
let usage = bridge.usage.lock().await;
assert_eq!(usage.input_tokens, 9);
assert_eq!(usage.output_tokens, 12);
}
#[tokio::test]
async fn rlm_dispatch_at_depth_zero_falls_back_to_plain_llm_query() {
let mock = Arc::new(MockLlmClient::new(Vec::new()));
mock.push_message_response(mock_response("fallback answer", 3, 5));
let bridge = bridge_for(Arc::clone(&mock), 0);
let response = bridge
.dispatch(RpcRequest::Rlm {
prompt: "nested prompt".to_string(),
@@ -352,5 +437,9 @@ mod tests {
let usage = bridge.usage.lock().await;
assert_eq!(usage.input_tokens, 3);
assert_eq!(usage.output_tokens, 5);
let captured = mock.captured_requests();
assert_eq!(captured.len(), 1);
assert_eq!(captured[0].model, "override-model");
}
}
+51 -25
View File
@@ -2,6 +2,7 @@
//! subprocess + stdin/stdout RPC bridge (no HTTP sidecar).
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
@@ -9,11 +10,10 @@ use uuid::Uuid;
use crate::client::DeepSeekClient;
use crate::core::events::Event;
use crate::llm_client::LlmClient;
use crate::models::{ContentBlock, Message, MessageRequest, Usage};
use crate::models::{ContentBlock, Message, MessageRequest, SystemPrompt, Usage};
use crate::repl::PythonRuntime;
use super::bridge::RlmBridge;
use super::bridge::{RlmBridge, RlmLlmClient};
use super::prompt::rlm_system_prompt;
// ---------------------------------------------------------------------------
@@ -99,7 +99,7 @@ pub async fn run_rlm_turn(
max_depth: u32,
) -> RlmTurnResult {
run_rlm_turn_inner(
client,
Arc::new(client.clone()),
model,
prompt,
None,
@@ -122,7 +122,7 @@ pub async fn run_rlm_turn_with_root(
max_depth: u32,
) -> RlmTurnResult {
run_rlm_turn_inner(
client,
Arc::new(client.clone()),
model,
prompt,
root_prompt,
@@ -136,15 +136,15 @@ pub async fn run_rlm_turn_with_root(
/// Inner entry point — also used by the bridge when it recurses. Returns
/// a boxed future to break the recursive opaque-future-type cycle:
/// `run_rlm_turn_inner` → `RlmBridge::dispatch` → `run_rlm_turn_inner`.
pub(crate) fn run_rlm_turn_inner<'a>(
client: &'a DeepSeekClient,
pub(crate) fn run_rlm_turn_inner(
client: Arc<dyn RlmLlmClient>,
model: String,
prompt: String,
root_prompt: Option<String>,
child_model: String,
tx_event: mpsc::Sender<Event>,
max_depth: u32,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = RlmTurnResult> + Send + 'a>> {
) -> std::pin::Pin<Box<dyn std::future::Future<Output = RlmTurnResult> + Send>> {
Box::pin(run_rlm_turn_impl(
client,
model,
@@ -161,7 +161,7 @@ pub(crate) fn run_rlm_turn_inner<'a>(
// ---------------------------------------------------------------------------
async fn run_rlm_turn_impl(
client: &DeepSeekClient,
client: Arc<dyn RlmLlmClient>,
model: String,
prompt: String,
root_prompt: Option<String>,
@@ -212,7 +212,7 @@ async fn run_rlm_turn_impl(
};
// 3. Build the bridge that services llm_query / rlm_query RPCs.
let bridge = RlmBridge::new(client.clone(), child_model.clone(), max_depth);
let bridge = RlmBridge::new(Arc::clone(&client), child_model.clone(), max_depth);
let usage_handle = bridge.usage_handle();
let _ = tx_event
@@ -262,22 +262,9 @@ async fn run_rlm_turn_impl(
.await;
// 4a. Root LLM generates code from metadata-only context.
let request = MessageRequest {
model: model.clone(),
messages: messages.clone(),
max_tokens: ROOT_MAX_TOKENS,
system: Some(system.clone()),
tools: None,
tool_choice: None,
metadata: None,
thinking: None,
reasoning_effort: None,
stream: Some(false),
temperature: Some(ROOT_TEMPERATURE),
top_p: Some(0.9_f32),
};
let request = build_root_request(&model, &messages, &system);
let response = match client.create_message(request).await {
let response = match client.create_message_boxed(request).await {
Ok(r) => r,
Err(e) => {
break 'turn RlmTurnResult {
@@ -551,6 +538,23 @@ fn write_context_file(prompt: &str) -> std::io::Result<PathBuf> {
Ok(path)
}
fn build_root_request(model: &str, messages: &[Message], system: &SystemPrompt) -> MessageRequest {
MessageRequest {
model: model.to_string(),
messages: messages.to_vec(),
max_tokens: ROOT_MAX_TOKENS,
system: Some(system.clone()),
tools: None,
tool_choice: None,
metadata: None,
thinking: None,
reasoning_effort: None,
stream: Some(false),
temperature: Some(ROOT_TEMPERATURE),
top_p: Some(0.9_f32),
}
}
/// Build `Metadata(state)` from the paper. Surfaces:
/// - the small `root_prompt` (if any) — repeated each iteration
/// - `context` length + preview
@@ -851,6 +855,28 @@ mod tests {
);
}
#[test]
fn build_root_request_keeps_context_tail_out_of_root_payload() {
let secret_tail = "DO_NOT_LEAK_ROOT_REQUEST";
let prompt = format!("{}{}", "a".repeat(PROMPT_PREVIEW_LEN + 100), secret_tail);
let messages = vec![build_metadata_message(
&prompt,
Some("answer from the long context"),
0,
None,
None,
)];
let request = build_root_request("root-model", &messages, &rlm_system_prompt());
let payload = serde_json::to_string(&request).expect("request should serialize");
assert!(payload.contains(&format!("- Length: {} chars", prompt.chars().count())));
assert!(
!payload.contains(secret_tail),
"root LLM request leaked the non-preview tail of context"
);
}
#[test]
fn build_metadata_with_iteration_shows_previous_code() {
let msg = build_metadata_message("Test prompt", None, 3, Some("print('hi')"), Some("hi"));