feat(rlm_query): verify parallel fan-out + per-child prompt rendering (closes #60)

Introduce `RlmChildClient` — a dyn-compatible `#[async_trait]` wrapper around
the single create_message operation — so tests can inject a `MockRlmClient`
without a live API key. This replaces the direct `Arc<DeepSeekClient>` field
with `Arc<dyn RlmChildClient>`, wired transparently via `RlmQueryTool::new`.

Concurrency regression test (`rlm_parallel_fanout_overlaps_not_serialized`):
fires N=4 children, each sleeping 50 ms, through `join_all`. Asserts that total
elapsed time is < 4×50 ms (the serial bound) and that all start timestamps fall
within 50 ms of each other. First run: total_elapsed=54 ms, start_spread=141 µs —
fan-out was already correct; no serialization fix needed.

UI wiring tests (`rlm_query_tool_cell_wired_with_prompts_on_start` etc.) verify
that `handle_tool_call_started` with `rlm_query` populates `GenericToolCell.prompts`
from the `prompts` (array) and `prompt` (singular) input shapes, and that
non-fan-out tools leave `prompts: None`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Hunter Bown
2026-04-26 14:21:43 -05:00
parent e9970fcad3
commit 49673d2ea3
2 changed files with 307 additions and 5 deletions
+211 -5
View File
@@ -14,8 +14,8 @@ use serde_json::{Value, json};
use tracing::debug;
use crate::client::DeepSeekClient;
use crate::llm_client::LlmClient;
use crate::models::{ContentBlock, Message, MessageRequest, SystemPrompt};
use crate::llm_client::LlmClient as _;
use crate::models::{ContentBlock, Message, MessageRequest, MessageResponse, SystemPrompt};
use crate::tools::spec::{
ApprovalRequirement, ToolCapability, ToolContext, ToolError, ToolResult, ToolSpec,
optional_str, optional_u64,
@@ -28,16 +28,56 @@ const DEFAULT_MAX_TOKENS: u32 = 4096;
/// Hard cap on parallel children — protects against runaway fan-out.
const MAX_PARALLEL: usize = 16;
// ---------------------------------------------------------------------------
// RlmChildClient — dyn-compatible wrapper around LLM completion.
//
// The workspace's `LlmClient` trait uses native `async fn`, which is not dyn
// compatible in stable Rust (RPITIT vtable limitations). We define a small
// local trait with `#[async_trait]` that IS dyn-compatible, implement it for
// `DeepSeekClient`, and also implement it in tests for `MockRlmClient`. This
// avoids touching `llm_client.rs` or adding a new dep.
// ---------------------------------------------------------------------------
/// Minimal dyn-compatible async interface for the single RLM child-completion
/// operation. `#[async_trait]` desugars the async method into a boxed future
/// so the trait is object-safe.
///
/// The `Send + Sync` supertraits let an implementor be shared as
/// `Arc<dyn RlmChildClient>`, which is how `RlmQueryTool` stores it.
#[async_trait]
pub(crate) trait RlmChildClient: Send + Sync {
/// Submit one `MessageRequest` and return the resulting `MessageResponse`,
/// or an `anyhow` error if the completion fails.
async fn complete(&self, request: MessageRequest) -> anyhow::Result<MessageResponse>;
}
/// Production implementation: a `DeepSeekClient` acts as a child client by
/// forwarding the request, unchanged, to its own `create_message` call.
/// (This is a concrete impl for one type, not a generic blanket impl.)
#[async_trait]
impl RlmChildClient for DeepSeekClient {
async fn complete(&self, request: MessageRequest) -> anyhow::Result<MessageResponse> {
// Pure delegation — the result (success or error) is returned as-is.
self.create_message(request).await
}
}
/// Tool: `rlm_query`. Runs one or more prompts in parallel and joins the
/// results. Structured tool call so the model can trigger fan-out reliably.
pub struct RlmQueryTool {
client: Option<DeepSeekClient>,
/// Shared, type-erased child client — `Arc<dyn RlmChildClient>` lets tests
/// inject a mock without going through a real HTTP connection. `None` when
/// no API key is configured.
client: Option<Arc<dyn RlmChildClient>>,
/// Model name used for child requests; constructors initialise it from
/// `DEFAULT_CHILD_MODEL`.
default_model: String,
}
impl RlmQueryTool {
/// Production constructor. Takes an optional concrete `DeepSeekClient` and
/// stores it type-erased behind `Arc<dyn RlmChildClient>`; `None` is kept as
/// `None`. The default child model is taken from `DEFAULT_CHILD_MODEL`.
#[must_use]
pub fn new(client: Option<DeepSeekClient>) -> Self {
    // The closure's declared return type performs the unsizing coercion from
    // `Arc<DeepSeekClient>` to the trait object.
    let client = client.map(|concrete| -> Arc<dyn RlmChildClient> { Arc::new(concrete) });
    let default_model = DEFAULT_CHILD_MODEL.to_string();
    Self {
        client,
        default_model,
    }
}
/// Construct with an already type-erased `Arc<dyn RlmChildClient>` — used by
/// tests to inject a `MockRlmClient` without an active API connection.
#[cfg(test)]
pub(crate) fn new_with_arc(client: Option<Arc<dyn RlmChildClient>>) -> Self {
Self {
client,
default_model: DEFAULT_CHILD_MODEL.to_string(),
@@ -141,7 +181,7 @@ impl ToolSpec for RlmQueryTool {
)));
}
let client = Arc::new(client);
// client is already Arc<dyn RlmChildClient> — clone the Arc, not the client.
let model = Arc::new(model);
let system = Arc::new(system);
let total = prompts.len();
@@ -191,7 +231,7 @@ impl ToolSpec for RlmQueryTool {
temperature: Some(0.4),
top_p: Some(0.9),
};
let response = client.create_message(request).await;
let response = client.complete(request).await;
let elapsed_ms = started.elapsed().as_millis() as u64;
in_flight.fetch_sub(1, Ordering::Relaxed);
debug!(
@@ -258,8 +298,10 @@ fn extract_text(blocks: &[ContentBlock]) -> String {
#[cfg(test)]
mod tests {
use super::*;
use crate::models::{MessageResponse, Usage};
use crate::tools::spec::ToolContext;
use std::path::PathBuf;
use std::sync::Mutex;
fn ctx() -> ToolContext {
ToolContext::with_auto_approve(
@@ -275,6 +317,170 @@ mod tests {
RlmQueryTool::new(None)
}
// -----------------------------------------------------------------------
// MockRlmClient — in-process stub for concurrency tests.
//
// Records the wall-clock instant each call *starts* and sleeps
// `call_delay` before returning. With join_all the N futures run
// concurrently on a single-threaded executor, so all starts happen
// before any sleep expires — demonstrating true overlap.
// -----------------------------------------------------------------------
/// In-process stand-in for a child client: it remembers when each call began
/// and waits a fixed amount of time before answering.
struct MockRlmClient {
    /// Artificial latency applied to every call so overlap shows up on a
    /// wall clock.
    call_delay: std::time::Duration,
    /// Shared log of the instants at which each `complete` call started.
    start_times: Arc<Mutex<Vec<Instant>>>,
}

impl MockRlmClient {
    /// Create a mock with the given per-call delay and an empty start log.
    fn new(call_delay: std::time::Duration) -> Self {
        // `Arc<Mutex<Vec<_>>>::default()` is an Arc around an unlocked mutex
        // holding an empty vector.
        let start_times = Arc::default();
        Self {
            call_delay,
            start_times,
        }
    }
}
#[async_trait]
impl RlmChildClient for MockRlmClient {
    /// Log the call's start instant, wait `call_delay`, then answer with a
    /// canned assistant message whose text is `echo: <prompt>`, where
    /// `<prompt>` is the first text block of the first message (empty if the
    /// request has no such block).
    async fn complete(&self, request: MessageRequest) -> anyhow::Result<MessageResponse> {
        // Stamp the start before sleeping. The guard lives only inside this
        // block, so the lock is released before the await below.
        {
            let mut log = self.start_times.lock().unwrap();
            log.push(Instant::now());
        }
        tokio::time::sleep(self.call_delay).await;

        // Pull the prompt text out of the request so the reply mirrors it.
        let leading_block = request
            .messages
            .first()
            .and_then(|message| message.content.first());
        let prompt_text = match leading_block {
            Some(ContentBlock::Text { text, .. }) => text.clone(),
            _ => String::new(),
        };

        let reply = ContentBlock::Text {
            text: format!("echo: {prompt_text}"),
            cache_control: None,
        };
        Ok(MessageResponse {
            id: "mock-id".to_string(),
            r#type: "message".to_string(),
            role: "assistant".to_string(),
            content: vec![reply],
            model: "mock-model".to_string(),
            stop_reason: Some("end_turn".to_string()),
            stop_sequence: None,
            container: None,
            usage: Usage::default(),
        })
    }
}
// -----------------------------------------------------------------------
// Concurrency regression test
//
// With N=4 prompts and a 50 ms per-call sleep, *serial* execution would
// take ≥ 4×50 = 200 ms. True join_all fan-out means all calls start
// before any completes, so total wall time is ~50 ms (one sleep, not
// four). We assert: total_elapsed < N × call_delay, i.e. the calls
// must overlap.
//
// The test also verifies the mock records N start timestamps all clustered
// within one call_delay window — double-confirming overlap is real.
// -----------------------------------------------------------------------
#[tokio::test]
async fn rlm_parallel_fanout_overlaps_not_serialized() {
    const N: usize = 4;
    const CALL_DELAY_MS: u64 = 50;
    let delay = std::time::Duration::from_millis(CALL_DELAY_MS);

    let mock = Arc::new(MockRlmClient::new(delay));
    // Keep a handle on the start log; `mock` itself is moved into the tool.
    let start_times_ref = Arc::clone(&mock.start_times);
    let tool = RlmQueryTool::new_with_arc(Some(mock as Arc<dyn RlmChildClient>));

    // Typed as `[&str; N]` so the prompt count cannot drift from N: changing
    // one without the other is a compile error rather than a confusing
    // assertion failure further down.
    let prompts: [&str; N] = ["a", "b", "c", "d"];

    let overall_start = Instant::now();
    let result = tool
        .execute(json!({ "prompts": prompts }), &ctx())
        .await
        .expect("mock tool should succeed");
    let total_elapsed = overall_start.elapsed();

    // Sanity: the tool reports overall success for the fan-out.
    assert!(result.success, "all children should succeed");

    // Overlap assertion: total wall time must stay below the serial bound
    // N×delay (4×50 ms = 200 ms). A concurrent run takes roughly one delay,
    // which leaves about 3×delay of headroom for slow CI machines while
    // still catching a fully serialized fan-out.
    let serial_time = delay * u32::try_from(N).unwrap();
    assert!(
        total_elapsed < serial_time,
        "fan-out looks serialized: elapsed {total_elapsed:?} >= serial bound {serial_time:?}"
    );

    // Secondary confirmation: the mock recorded N start timestamps that
    // are within one call_delay of each other, proving actual concurrency.
    let starts = start_times_ref.lock().unwrap();
    assert_eq!(starts.len(), N, "expected exactly {N} child calls");
    let min_start = *starts.iter().min().unwrap();
    let max_start = *starts.iter().max().unwrap();

    // All starts must cluster within one call_delay window — if they were
    // serial, max_start - min_start would be ≥ (N-1) × delay.
    let start_spread = max_start.duration_since(min_start);
    assert!(
        start_spread < delay,
        "child starts are spread over {start_spread:?}, expected < {delay:?} \
(suggests serialization rather than concurrent fan-out)"
    );

    // Surface numbers for test output (visible with --nocapture or on
    // failure). `max_concurrent` is reported as N rather than measured: the
    // mock sleeps for `delay` after stamping its start, so the spread
    // assertion above implies every call was still in flight when the last
    // one began.
    eprintln!(
        "[rlm_parallel_fanout] total_elapsed={total_elapsed:?} \
start_spread={start_spread:?} \
max_concurrent={N} \
per_call_delay={delay:?}"
    );
}
/// A lone `prompt` (singular form) run against the mock client produces one
/// child whose reply comes back as bare text — no `[0]` index prefix.
#[tokio::test]
async fn rlm_single_prompt_returns_plain_text() {
    let one_ms = std::time::Duration::from_millis(1);
    let client: Arc<dyn RlmChildClient> = Arc::new(MockRlmClient::new(one_ms));
    let tool = RlmQueryTool::new_with_arc(Some(client));

    let input = json!({ "prompt": "hello" });
    let outcome = tool
        .execute(input, &ctx())
        .await
        .expect("single-prompt mock should succeed");

    // With exactly one child the output is the reply itself, unindexed.
    let body = &outcome.content;
    assert!(!body.starts_with("[0]"), "N=1 must not add index prefix");
    assert!(body.contains("echo: hello"), "text must echo the prompt");
}
/// Several `prompts` (plural form, N>1) run against the mock client come back
/// as one indexed block per child.
#[tokio::test]
async fn rlm_multi_prompt_returns_indexed_blocks() {
    let one_ms = std::time::Duration::from_millis(1);
    let client: Arc<dyn RlmChildClient> = Arc::new(MockRlmClient::new(one_ms));
    let tool = RlmQueryTool::new_with_arc(Some(client));

    let input = json!({ "prompts": ["alpha", "beta"] });
    let outcome = tool
        .execute(input, &ctx())
        .await
        .expect("multi-prompt mock should succeed");

    // Each child gets its own index marker and its echoed prompt.
    let body = &outcome.content;
    assert!(body.contains("[0]"), "first block must be indexed [0]");
    assert!(body.contains("[1]"), "second block must be indexed [1]");
    assert!(body.contains("echo: alpha"));
    assert!(body.contains("echo: beta"));
}
#[test]
fn schema_advertises_both_shapes() {
let schema = tool_without_client().input_schema();
+96
View File
@@ -1790,3 +1790,99 @@ fn second_thinking_block_appends_new_entry_in_same_active_cell() {
"the group still hasn't flushed — no prose yet"
);
}
// ---- rlm_query per-child prompt wiring ----
//
// When `handle_tool_call_started` receives an `rlm_query` call with a
// `prompts` array, the resulting `GenericToolCell` must carry the parsed
// prompts so the TUI can render one row per child (see
// `GenericToolCell::lines_with_motion` and the `show_prompts` branch in
// `history.rs`).
#[test]
fn rlm_query_tool_cell_wired_with_prompts_on_start() {
    // Tool input using the plural `prompts` shape with three children.
    let input = serde_json::json!({
        "prompts": [
            "What is the capital of France?",
            "List all public types in client.rs",
            "Summarize the README"
        ]
    });

    let mut app = create_test_app();
    handle_tool_call_started(&mut app, "rlm-1", "rlm_query", &input);

    // The turn has not completed, so the cell is expected in `active_cell`.
    let active = app.active_cell.as_ref().expect("active cell present");
    let generic = match &active.entries()[0] {
        HistoryCell::Tool(ToolCell::Generic(cell)) => cell,
        _ => panic!("expected GenericToolCell for rlm_query"),
    };
    assert_eq!(generic.name, "rlm_query");
    assert_eq!(generic.status, ToolStatus::Running);

    // The point of the test: the JSON prompts were copied onto the cell,
    // in order.
    let prompts = generic
        .prompts
        .as_ref()
        .expect("rlm_query cell must have prompts populated");
    assert_eq!(prompts.len(), 3);
    assert_eq!(prompts[0], "What is the capital of France?");
    assert_eq!(prompts[1], "List all public types in client.rs");
    assert_eq!(prompts[2], "Summarize the README");
}
#[test]
fn rlm_query_singular_prompt_wired_as_single_element_vec() {
    // The model may send `prompt` (singular) rather than `prompts`; the cell
    // is still expected to carry a one-element prompts vec so the renderer
    // can show the child's question.
    let input = serde_json::json!({ "prompt": "Explain the engine loop" });

    let mut app = create_test_app();
    handle_tool_call_started(&mut app, "rlm-2", "rlm_query", &input);

    let active = app.active_cell.as_ref().expect("active cell present");
    let generic = match &active.entries()[0] {
        HistoryCell::Tool(ToolCell::Generic(cell)) => cell,
        _ => panic!("expected GenericToolCell for rlm_query"),
    };

    let prompts = generic
        .prompts
        .as_ref()
        .expect("singular prompt must populate prompts vec");
    assert_eq!(prompts.len(), 1);
    assert_eq!(prompts[0], "Explain the engine loop");
}
#[test]
fn non_fanout_tool_does_not_populate_prompts() {
    // Only rlm_query gets a prompts vec; every other tool keeps using the
    // standard `args:` summary rendering path.
    let input = serde_json::json!({ "query": "client.rs" });

    let mut app = create_test_app();
    handle_tool_call_started(&mut app, "fs-1", "file_search", &input);

    let active = app.active_cell.as_ref().expect("active cell present");
    let generic = match &active.entries()[0] {
        HistoryCell::Tool(ToolCell::Generic(cell)) => cell,
        _ => panic!("expected GenericToolCell for file_search"),
    };

    assert!(
        generic.prompts.is_none(),
        "non-fan-out tool must not populate prompts"
    );
}