test(rlm): make bridge client seam mockable
This commit is contained in:
@@ -58,7 +58,10 @@ pub trait LlmClient: Send + Sync {
|
||||
fn model(&self) -> &str;
|
||||
|
||||
/// Creates a non-streaming message completion
|
||||
async fn create_message(&self, request: MessageRequest) -> Result<MessageResponse>;
|
||||
fn create_message(
|
||||
&self,
|
||||
request: MessageRequest,
|
||||
) -> impl Future<Output = Result<MessageResponse>> + Send;
|
||||
|
||||
/// Creates a streaming message completion
|
||||
///
|
||||
|
||||
+147
-58
@@ -13,13 +13,14 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::{future::Future, pin::Pin};
|
||||
|
||||
use anyhow::Result;
|
||||
use futures_util::future::join_all;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::client::DeepSeekClient;
|
||||
use crate::llm_client::LlmClient as _;
|
||||
use crate::models::{ContentBlock, Message, MessageRequest, SystemPrompt, Usage};
|
||||
use crate::llm_client::LlmClient;
|
||||
use crate::models::{ContentBlock, Message, MessageRequest, MessageResponse, SystemPrompt, Usage};
|
||||
use crate::repl::runtime::{BatchResp, RpcDispatcher, RpcRequest, RpcResponse, SingleResp};
|
||||
|
||||
/// Per-child completion timeout — same as the previous sidecar default.
|
||||
@@ -29,18 +30,46 @@ const DEFAULT_CHILD_MAX_TOKENS: u32 = 4096;
|
||||
/// Hard cap on prompts per batch RPC.
|
||||
pub const MAX_BATCH: usize = 16;
|
||||
|
||||
/// Object-safe slice of the LLM client interface that the RLM bridge needs.
|
||||
///
|
||||
/// `LlmClient` itself uses native async trait methods, which are not dyn-safe.
|
||||
/// The bridge only needs non-streaming completions, so this boxed-future shim
|
||||
/// gives tests a clean mock seam without changing the wider provider trait.
|
||||
pub(crate) trait RlmLlmClient: Send + Sync {
|
||||
fn create_message_boxed(
|
||||
&self,
|
||||
request: MessageRequest,
|
||||
) -> Pin<Box<dyn Future<Output = Result<MessageResponse>> + Send + '_>>;
|
||||
}
|
||||
|
||||
impl<T> RlmLlmClient for T
|
||||
where
|
||||
T: LlmClient + Send + Sync,
|
||||
{
|
||||
fn create_message_boxed(
|
||||
&self,
|
||||
request: MessageRequest,
|
||||
) -> Pin<Box<dyn Future<Output = Result<MessageResponse>> + Send + '_>> {
|
||||
Box::pin(self.create_message(request))
|
||||
}
|
||||
}
|
||||
|
||||
/// State shared with the bridge across all RPC calls in one turn.
|
||||
pub struct RlmBridge {
|
||||
pub client: DeepSeekClient,
|
||||
pub child_model: String,
|
||||
client: Arc<dyn RlmLlmClient>,
|
||||
child_model: String,
|
||||
/// Recursion budget remaining for `Rlm` / `RlmBatch` requests. When
|
||||
/// zero, those requests fall back to plain `Llm` completions.
|
||||
pub depth_remaining: u32,
|
||||
pub usage: Arc<Mutex<Usage>>,
|
||||
depth_remaining: u32,
|
||||
usage: Arc<Mutex<Usage>>,
|
||||
}
|
||||
|
||||
impl RlmBridge {
|
||||
pub fn new(client: DeepSeekClient, child_model: String, depth_remaining: u32) -> Self {
|
||||
pub(crate) fn new(
|
||||
client: Arc<dyn RlmLlmClient>,
|
||||
child_model: String,
|
||||
depth_remaining: u32,
|
||||
) -> Self {
|
||||
Self {
|
||||
client,
|
||||
child_model,
|
||||
@@ -83,7 +112,7 @@ impl RlmBridge {
|
||||
top_p: Some(0.9_f32),
|
||||
};
|
||||
|
||||
let fut = self.client.create_message(request);
|
||||
let fut = self.client.create_message_boxed(request);
|
||||
let response =
|
||||
match tokio::time::timeout(Duration::from_secs(CHILD_TIMEOUT_SECS), fut).await {
|
||||
Ok(Ok(r)) => r,
|
||||
@@ -165,7 +194,7 @@ impl RlmBridge {
|
||||
// Recursive call. The dyn-erasure on `run_rlm_turn_inner` breaks
|
||||
// the `bridge → turn → bridge` opaque-future cycle.
|
||||
let result = super::turn::run_rlm_turn_inner(
|
||||
&self.client,
|
||||
Arc::clone(&self.client),
|
||||
child_model.clone(),
|
||||
prompt,
|
||||
None,
|
||||
@@ -254,43 +283,32 @@ impl RpcDispatcher for RlmBridge {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::{Config, ProviderConfig, ProvidersConfig};
|
||||
use serde_json::json;
|
||||
use wiremock::matchers::{method, path};
|
||||
use wiremock::{Mock, MockServer, ResponseTemplate};
|
||||
use crate::llm_client::mock::MockLlmClient;
|
||||
|
||||
fn client_for(server: &MockServer) -> DeepSeekClient {
|
||||
let config = Config {
|
||||
provider: Some("sglang".to_string()),
|
||||
providers: Some(ProvidersConfig {
|
||||
sglang: ProviderConfig {
|
||||
base_url: Some(server.uri()),
|
||||
..ProviderConfig::default()
|
||||
},
|
||||
..ProvidersConfig::default()
|
||||
}),
|
||||
..Config::default()
|
||||
};
|
||||
DeepSeekClient::new(&config).expect("test client")
|
||||
fn mock_response(text: &str, input_tokens: u32, output_tokens: u32) -> MessageResponse {
|
||||
MessageResponse {
|
||||
id: "mock_msg".to_string(),
|
||||
r#type: "message".to_string(),
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentBlock::Text {
|
||||
text: text.to_string(),
|
||||
cache_control: None,
|
||||
}],
|
||||
model: "mock-model".to_string(),
|
||||
stop_reason: Some("end_turn".to_string()),
|
||||
stop_sequence: None,
|
||||
container: None,
|
||||
usage: Usage {
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
..Usage::default()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn chat_response(text: &str) -> serde_json::Value {
|
||||
json!({
|
||||
"id": "chatcmpl-test",
|
||||
"model": "test-model",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": text,
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 3,
|
||||
"completion_tokens": 5
|
||||
}
|
||||
})
|
||||
fn bridge_for(mock: Arc<MockLlmClient>, depth_remaining: u32) -> RlmBridge {
|
||||
let client: Arc<dyn RlmLlmClient> = mock;
|
||||
RlmBridge::new(client, "child-model".to_string(), depth_remaining)
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -318,22 +336,89 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn rlm_dispatch_at_depth_zero_falls_back_to_plain_llm_query() {
|
||||
let server = MockServer::start().await;
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(ResponseTemplate::new(404).set_body_string("responses unavailable"))
|
||||
.mount(&server)
|
||||
.await;
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/chat/completions"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200).set_body_json(chat_response("fallback answer")),
|
||||
)
|
||||
.mount(&server)
|
||||
async fn llm_dispatch_uses_trait_backed_mock_client() {
|
||||
let mock = Arc::new(MockLlmClient::new(Vec::new()));
|
||||
mock.push_message_response(mock_response("child answer", 7, 11));
|
||||
let bridge = bridge_for(Arc::clone(&mock), 1);
|
||||
|
||||
let response = bridge
|
||||
.dispatch(RpcRequest::Llm {
|
||||
prompt: "child prompt".to_string(),
|
||||
model: Some("override-model".to_string()),
|
||||
max_tokens: Some(123),
|
||||
system: Some("child system".to_string()),
|
||||
})
|
||||
.await;
|
||||
|
||||
let bridge = RlmBridge::new(client_for(&server), "child-model".to_string(), 0);
|
||||
match response {
|
||||
RpcResponse::Single(single) => {
|
||||
assert_eq!(single.text, "child answer");
|
||||
assert!(single.error.is_none());
|
||||
}
|
||||
other => panic!("expected single response, got {other:?}"),
|
||||
}
|
||||
|
||||
let captured = mock.captured_requests();
|
||||
assert_eq!(captured.len(), 1);
|
||||
assert_eq!(captured[0].model, "override-model");
|
||||
assert_eq!(captured[0].max_tokens, 123);
|
||||
assert_eq!(
|
||||
captured[0].system,
|
||||
Some(SystemPrompt::Text("child system".to_string()))
|
||||
);
|
||||
|
||||
let usage = bridge.usage.lock().await;
|
||||
assert_eq!(usage.input_tokens, 7);
|
||||
assert_eq!(usage.output_tokens, 11);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn llm_batch_dispatch_preserves_result_count_and_usage() {
|
||||
let mock = Arc::new(MockLlmClient::new(Vec::new()));
|
||||
mock.push_message_response(mock_response("one", 1, 2));
|
||||
mock.push_message_response(mock_response("two", 3, 4));
|
||||
mock.push_message_response(mock_response("three", 5, 6));
|
||||
let bridge = bridge_for(Arc::clone(&mock), 1);
|
||||
|
||||
let response = bridge
|
||||
.dispatch(RpcRequest::LlmBatch {
|
||||
prompts: vec!["a".to_string(), "b".to_string(), "c".to_string()],
|
||||
model: Some("batch-model".to_string()),
|
||||
})
|
||||
.await;
|
||||
|
||||
match response {
|
||||
RpcResponse::Batch(batch) => {
|
||||
let texts: Vec<_> = batch
|
||||
.results
|
||||
.iter()
|
||||
.map(|result| result.text.as_str())
|
||||
.collect();
|
||||
assert_eq!(texts, ["one", "two", "three"]);
|
||||
assert!(batch.results.iter().all(|result| result.error.is_none()));
|
||||
}
|
||||
other => panic!("expected batch response, got {other:?}"),
|
||||
}
|
||||
|
||||
let captured = mock.captured_requests();
|
||||
assert_eq!(captured.len(), 3);
|
||||
assert!(
|
||||
captured
|
||||
.iter()
|
||||
.all(|request| request.model == "batch-model")
|
||||
);
|
||||
|
||||
let usage = bridge.usage.lock().await;
|
||||
assert_eq!(usage.input_tokens, 9);
|
||||
assert_eq!(usage.output_tokens, 12);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn rlm_dispatch_at_depth_zero_falls_back_to_plain_llm_query() {
|
||||
let mock = Arc::new(MockLlmClient::new(Vec::new()));
|
||||
mock.push_message_response(mock_response("fallback answer", 3, 5));
|
||||
let bridge = bridge_for(Arc::clone(&mock), 0);
|
||||
|
||||
let response = bridge
|
||||
.dispatch(RpcRequest::Rlm {
|
||||
prompt: "nested prompt".to_string(),
|
||||
@@ -352,5 +437,9 @@ mod tests {
|
||||
let usage = bridge.usage.lock().await;
|
||||
assert_eq!(usage.input_tokens, 3);
|
||||
assert_eq!(usage.output_tokens, 5);
|
||||
|
||||
let captured = mock.captured_requests();
|
||||
assert_eq!(captured.len(), 1);
|
||||
assert_eq!(captured[0].model, "override-model");
|
||||
}
|
||||
}
|
||||
|
||||
+51
-25
@@ -2,6 +2,7 @@
|
||||
//! subprocess + stdin/stdout RPC bridge (no HTTP sidecar).
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
@@ -9,11 +10,10 @@ use uuid::Uuid;
|
||||
|
||||
use crate::client::DeepSeekClient;
|
||||
use crate::core::events::Event;
|
||||
use crate::llm_client::LlmClient;
|
||||
use crate::models::{ContentBlock, Message, MessageRequest, Usage};
|
||||
use crate::models::{ContentBlock, Message, MessageRequest, SystemPrompt, Usage};
|
||||
use crate::repl::PythonRuntime;
|
||||
|
||||
use super::bridge::RlmBridge;
|
||||
use super::bridge::{RlmBridge, RlmLlmClient};
|
||||
use super::prompt::rlm_system_prompt;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -99,7 +99,7 @@ pub async fn run_rlm_turn(
|
||||
max_depth: u32,
|
||||
) -> RlmTurnResult {
|
||||
run_rlm_turn_inner(
|
||||
client,
|
||||
Arc::new(client.clone()),
|
||||
model,
|
||||
prompt,
|
||||
None,
|
||||
@@ -122,7 +122,7 @@ pub async fn run_rlm_turn_with_root(
|
||||
max_depth: u32,
|
||||
) -> RlmTurnResult {
|
||||
run_rlm_turn_inner(
|
||||
client,
|
||||
Arc::new(client.clone()),
|
||||
model,
|
||||
prompt,
|
||||
root_prompt,
|
||||
@@ -136,15 +136,15 @@ pub async fn run_rlm_turn_with_root(
|
||||
/// Inner entry point — also used by the bridge when it recurses. Returns
|
||||
/// a boxed future to break the recursive opaque-future-type cycle:
|
||||
/// `run_rlm_turn_inner` → `RlmBridge::dispatch` → `run_rlm_turn_inner`.
|
||||
pub(crate) fn run_rlm_turn_inner<'a>(
|
||||
client: &'a DeepSeekClient,
|
||||
pub(crate) fn run_rlm_turn_inner(
|
||||
client: Arc<dyn RlmLlmClient>,
|
||||
model: String,
|
||||
prompt: String,
|
||||
root_prompt: Option<String>,
|
||||
child_model: String,
|
||||
tx_event: mpsc::Sender<Event>,
|
||||
max_depth: u32,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = RlmTurnResult> + Send + 'a>> {
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = RlmTurnResult> + Send>> {
|
||||
Box::pin(run_rlm_turn_impl(
|
||||
client,
|
||||
model,
|
||||
@@ -161,7 +161,7 @@ pub(crate) fn run_rlm_turn_inner<'a>(
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn run_rlm_turn_impl(
|
||||
client: &DeepSeekClient,
|
||||
client: Arc<dyn RlmLlmClient>,
|
||||
model: String,
|
||||
prompt: String,
|
||||
root_prompt: Option<String>,
|
||||
@@ -212,7 +212,7 @@ async fn run_rlm_turn_impl(
|
||||
};
|
||||
|
||||
// 3. Build the bridge that services llm_query / rlm_query RPCs.
|
||||
let bridge = RlmBridge::new(client.clone(), child_model.clone(), max_depth);
|
||||
let bridge = RlmBridge::new(Arc::clone(&client), child_model.clone(), max_depth);
|
||||
let usage_handle = bridge.usage_handle();
|
||||
|
||||
let _ = tx_event
|
||||
@@ -262,22 +262,9 @@ async fn run_rlm_turn_impl(
|
||||
.await;
|
||||
|
||||
// 4a. Root LLM generates code from metadata-only context.
|
||||
let request = MessageRequest {
|
||||
model: model.clone(),
|
||||
messages: messages.clone(),
|
||||
max_tokens: ROOT_MAX_TOKENS,
|
||||
system: Some(system.clone()),
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
metadata: None,
|
||||
thinking: None,
|
||||
reasoning_effort: None,
|
||||
stream: Some(false),
|
||||
temperature: Some(ROOT_TEMPERATURE),
|
||||
top_p: Some(0.9_f32),
|
||||
};
|
||||
let request = build_root_request(&model, &messages, &system);
|
||||
|
||||
let response = match client.create_message(request).await {
|
||||
let response = match client.create_message_boxed(request).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
break 'turn RlmTurnResult {
|
||||
@@ -551,6 +538,23 @@ fn write_context_file(prompt: &str) -> std::io::Result<PathBuf> {
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
fn build_root_request(model: &str, messages: &[Message], system: &SystemPrompt) -> MessageRequest {
|
||||
MessageRequest {
|
||||
model: model.to_string(),
|
||||
messages: messages.to_vec(),
|
||||
max_tokens: ROOT_MAX_TOKENS,
|
||||
system: Some(system.clone()),
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
metadata: None,
|
||||
thinking: None,
|
||||
reasoning_effort: None,
|
||||
stream: Some(false),
|
||||
temperature: Some(ROOT_TEMPERATURE),
|
||||
top_p: Some(0.9_f32),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build `Metadata(state)` from the paper. Surfaces:
|
||||
/// - the small `root_prompt` (if any) — repeated each iteration
|
||||
/// - `context` length + preview
|
||||
@@ -851,6 +855,28 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
fn build_root_request_keeps_context_tail_out_of_root_payload() {
    // Pad the prompt past `PROMPT_PREVIEW_LEN` and append a marker, so the
    // marker sits in the part the payload is expected to omit.
    let secret_tail = "DO_NOT_LEAK_ROOT_REQUEST";
    let mut prompt = "a".repeat(PROMPT_PREVIEW_LEN + 100);
    prompt.push_str(secret_tail);

    let metadata = build_metadata_message(
        &prompt,
        Some("answer from the long context"),
        0,
        None,
        None,
    );
    let messages = vec![metadata];
    let system = rlm_system_prompt();

    let request = build_root_request("root-model", &messages, &system);
    let payload = serde_json::to_string(&request).expect("request should serialize");

    // The full length is reported, but the tail itself must not appear.
    let length_line = format!("- Length: {} chars", prompt.chars().count());
    assert!(payload.contains(&length_line));
    assert!(
        !payload.contains(secret_tail),
        "root LLM request leaked the non-preview tail of context"
    );
}
|
||||
|
||||
#[test]
|
||||
fn build_metadata_with_iteration_shows_previous_code() {
|
||||
let msg = build_metadata_message("Test prompt", None, 3, Some("print('hi')"), Some("hi"));
|
||||
|
||||
Reference in New Issue
Block a user