merge: rlm_query parallelism verification + per-child UI (closes #60)
This commit is contained in:
@@ -14,8 +14,8 @@ use serde_json::{Value, json};
|
||||
use tracing::debug;
|
||||
|
||||
use crate::client::DeepSeekClient;
|
||||
use crate::llm_client::LlmClient;
|
||||
use crate::models::{ContentBlock, Message, MessageRequest, SystemPrompt};
|
||||
use crate::llm_client::LlmClient as _;
|
||||
use crate::models::{ContentBlock, Message, MessageRequest, MessageResponse, SystemPrompt};
|
||||
use crate::tools::spec::{
|
||||
ApprovalRequirement, ToolCapability, ToolContext, ToolError, ToolResult, ToolSpec,
|
||||
optional_str, optional_u64,
|
||||
@@ -28,16 +28,56 @@ const DEFAULT_MAX_TOKENS: u32 = 4096;
|
||||
/// Hard cap on parallel children — protects against runaway fan-out.
|
||||
const MAX_PARALLEL: usize = 16;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RlmChildClient — dyn-compatible wrapper around LLM completion.
|
||||
//
|
||||
// The workspace's `LlmClient` trait uses native `async fn`, which is not dyn
|
||||
// compatible in stable Rust (RPITIT vtable limitations). We define a small
|
||||
// local trait with `#[async_trait]` that IS dyn-compatible, implement it for
|
||||
// `DeepSeekClient`, and also implement it in tests for `MockLlmClient`. This
|
||||
// avoids touching `llm_client.rs` or adding a new dep.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Minimal dyn-compatible async interface for the single RLM child-completion
|
||||
/// operation. `#[async_trait]` desugars the async method into a boxed future
|
||||
/// so the trait is object-safe.
|
||||
#[async_trait]
|
||||
pub(crate) trait RlmChildClient: Send + Sync {
|
||||
async fn complete(&self, request: MessageRequest) -> anyhow::Result<MessageResponse>;
|
||||
}
|
||||
|
||||
/// Blanket impl: any `DeepSeekClient` is a valid child client.
|
||||
#[async_trait]
|
||||
impl RlmChildClient for DeepSeekClient {
|
||||
async fn complete(&self, request: MessageRequest) -> anyhow::Result<MessageResponse> {
|
||||
self.create_message(request).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Tool: `rlm_query`. Runs one or more prompts in parallel and joins the
|
||||
/// results. Structured tool call so the model can trigger fan-out reliably.
|
||||
pub struct RlmQueryTool {
|
||||
client: Option<DeepSeekClient>,
|
||||
/// Boxed child client — `Arc<dyn RlmChildClient>` lets tests inject a
|
||||
/// mock without going through a real HTTP connection. `None` when no API
|
||||
/// key is configured.
|
||||
client: Option<Arc<dyn RlmChildClient>>,
|
||||
default_model: String,
|
||||
}
|
||||
|
||||
impl RlmQueryTool {
|
||||
/// Construct with a concrete `DeepSeekClient` (production path).
|
||||
#[must_use]
|
||||
pub fn new(client: Option<DeepSeekClient>) -> Self {
|
||||
Self {
|
||||
client: client.map(|c| Arc::new(c) as Arc<dyn RlmChildClient>),
|
||||
default_model: DEFAULT_CHILD_MODEL.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct with a pre-boxed `RlmChildClient` — used by tests to inject
|
||||
/// a `MockRlmClient` without an active API connection.
|
||||
#[cfg(test)]
|
||||
pub(crate) fn new_with_arc(client: Option<Arc<dyn RlmChildClient>>) -> Self {
|
||||
Self {
|
||||
client,
|
||||
default_model: DEFAULT_CHILD_MODEL.to_string(),
|
||||
@@ -142,7 +182,7 @@ impl ToolSpec for RlmQueryTool {
|
||||
)));
|
||||
}
|
||||
|
||||
let client = Arc::new(client);
|
||||
// client is already Arc<dyn RlmChildClient> — clone the Arc, not the client.
|
||||
let model = Arc::new(model);
|
||||
let system = Arc::new(system);
|
||||
let total = prompts.len();
|
||||
@@ -192,7 +232,7 @@ impl ToolSpec for RlmQueryTool {
|
||||
temperature: Some(0.4),
|
||||
top_p: Some(0.9),
|
||||
};
|
||||
let response = client.create_message(request).await;
|
||||
let response = client.complete(request).await;
|
||||
let elapsed_ms = started.elapsed().as_millis() as u64;
|
||||
in_flight.fetch_sub(1, Ordering::Relaxed);
|
||||
debug!(
|
||||
@@ -259,8 +299,10 @@ fn extract_text(blocks: &[ContentBlock]) -> String {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::models::{MessageResponse, Usage};
|
||||
use crate::tools::spec::ToolContext;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Mutex;
|
||||
|
||||
fn ctx() -> ToolContext {
|
||||
ToolContext::with_auto_approve(
|
||||
@@ -276,6 +318,170 @@ mod tests {
|
||||
RlmQueryTool::new(None)
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// MockRlmClient — in-process stub for concurrency tests.
|
||||
//
|
||||
// Records the wall-clock instant each call *starts* and sleeps
|
||||
// `call_delay` before returning. With join_all the N futures run
|
||||
// concurrently on a single-threaded executor, so all starts happen
|
||||
// before any sleep expires — demonstrating true overlap.
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
struct MockRlmClient {
|
||||
/// Per-call sleep to make overlap visible against a wall clock.
|
||||
call_delay: std::time::Duration,
|
||||
/// Timestamps recorded at the start of each `complete` call.
|
||||
start_times: Arc<Mutex<Vec<Instant>>>,
|
||||
}
|
||||
|
||||
impl MockRlmClient {
|
||||
fn new(call_delay: std::time::Duration) -> Self {
|
||||
Self {
|
||||
call_delay,
|
||||
start_times: Arc::new(Mutex::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl RlmChildClient for MockRlmClient {
|
||||
async fn complete(&self, request: MessageRequest) -> anyhow::Result<MessageResponse> {
|
||||
// Record start time before sleeping.
|
||||
self.start_times.lock().unwrap().push(Instant::now());
|
||||
|
||||
tokio::time::sleep(self.call_delay).await;
|
||||
|
||||
// Return a minimal valid response that mirrors the incoming prompt.
|
||||
let prompt_text = request
|
||||
.messages
|
||||
.first()
|
||||
.and_then(|m| m.content.first())
|
||||
.and_then(|b| match b {
|
||||
ContentBlock::Text { text, .. } => Some(text.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
Ok(MessageResponse {
|
||||
id: "mock-id".to_string(),
|
||||
r#type: "message".to_string(),
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentBlock::Text {
|
||||
text: format!("echo: {prompt_text}"),
|
||||
cache_control: None,
|
||||
}],
|
||||
model: "mock-model".to_string(),
|
||||
stop_reason: Some("end_turn".to_string()),
|
||||
stop_sequence: None,
|
||||
container: None,
|
||||
usage: Usage::default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Concurrency regression test
|
||||
//
|
||||
// With N=4 prompts and a 50 ms per-call sleep, *serial* execution would
|
||||
// take ≥ 4×50 = 200 ms. True join_all fan-out means all calls start
|
||||
// before any completes, so total wall time is ~50 ms (one sleep, not
|
||||
// four). We assert: total_elapsed < 4 × call_delay, i.e. the calls
|
||||
// must overlap.
|
||||
//
|
||||
// The test also verifies the mock records N start timestamps all clustered
|
||||
// within one call_delay window — double-confirming overlap is real.
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
#[tokio::test]
|
||||
async fn rlm_parallel_fanout_overlaps_not_serialized() {
|
||||
const N: usize = 4;
|
||||
const CALL_DELAY_MS: u64 = 50;
|
||||
|
||||
let delay = std::time::Duration::from_millis(CALL_DELAY_MS);
|
||||
let mock = Arc::new(MockRlmClient::new(delay));
|
||||
let start_times_ref = Arc::clone(&mock.start_times);
|
||||
|
||||
let tool = RlmQueryTool::new_with_arc(Some(mock as Arc<dyn RlmChildClient>));
|
||||
let prompts: Vec<&str> = vec!["a", "b", "c", "d"];
|
||||
|
||||
let overall_start = Instant::now();
|
||||
let result = tool
|
||||
.execute(json!({ "prompts": prompts }), &ctx())
|
||||
.await
|
||||
.expect("mock tool should succeed");
|
||||
let total_elapsed = overall_start.elapsed();
|
||||
|
||||
// Sanity: all 4 children returned text.
|
||||
assert!(result.success, "all children should succeed");
|
||||
|
||||
// Overlap assertion: total wall time must be well under 4×delay.
|
||||
// We allow 3×delay as a generous upper bound (plenty of headroom for
|
||||
// slow CI machines) while still catching serialization bugs.
|
||||
let serial_time = delay * u32::try_from(N).unwrap();
|
||||
assert!(
|
||||
total_elapsed < serial_time,
|
||||
"fan-out looks serialized: elapsed {total_elapsed:?} >= serial bound {serial_time:?}"
|
||||
);
|
||||
|
||||
// Secondary confirmation: the mock recorded N start timestamps that
|
||||
// are within one call_delay of each other, proving actual concurrency.
|
||||
let starts = start_times_ref.lock().unwrap();
|
||||
assert_eq!(starts.len(), N, "expected exactly {N} child calls");
|
||||
let min_start = *starts.iter().min().unwrap();
|
||||
let max_start = *starts.iter().max().unwrap();
|
||||
// All starts must cluster within one call_delay window — if they were
|
||||
// serial, max_start - min_start would be ≥ (N-1) × delay.
|
||||
let start_spread = max_start.duration_since(min_start);
|
||||
assert!(
|
||||
start_spread < delay,
|
||||
"child starts are spread over {start_spread:?}, expected < {delay:?} \
|
||||
(suggests serialization rather than concurrent fan-out)"
|
||||
);
|
||||
|
||||
// Surface numbers for test output (visible with --nocapture or on
|
||||
// failure). This is the same information the issue asked to emit.
|
||||
eprintln!(
|
||||
"[rlm_parallel_fanout] total_elapsed={total_elapsed:?} \
|
||||
start_spread={start_spread:?} \
|
||||
max_concurrent={N} \
|
||||
per_call_delay={delay:?}"
|
||||
);
|
||||
}
|
||||
|
||||
/// With a mock client, `prompt` (singular) still fans out as a single
|
||||
/// child and returns plain text (no `[0]` prefix for N=1).
|
||||
#[tokio::test]
|
||||
async fn rlm_single_prompt_returns_plain_text() {
|
||||
let mock = Arc::new(MockRlmClient::new(std::time::Duration::from_millis(1)));
|
||||
let tool = RlmQueryTool::new_with_arc(Some(mock as Arc<dyn RlmChildClient>));
|
||||
|
||||
let result = tool
|
||||
.execute(json!({ "prompt": "hello" }), &ctx())
|
||||
.await
|
||||
.expect("single-prompt mock should succeed");
|
||||
let text = &result.content;
|
||||
// N=1 returns bare text, no "[0]" index prefix.
|
||||
assert!(!text.starts_with("[0]"), "N=1 must not add index prefix");
|
||||
assert!(text.contains("echo: hello"), "text must echo the prompt");
|
||||
}
|
||||
|
||||
/// With a mock client, `prompts` (plural, N>1) returns indexed blocks.
|
||||
#[tokio::test]
|
||||
async fn rlm_multi_prompt_returns_indexed_blocks() {
|
||||
let mock = Arc::new(MockRlmClient::new(std::time::Duration::from_millis(1)));
|
||||
let tool = RlmQueryTool::new_with_arc(Some(mock as Arc<dyn RlmChildClient>));
|
||||
|
||||
let result = tool
|
||||
.execute(json!({ "prompts": ["alpha", "beta"] }), &ctx())
|
||||
.await
|
||||
.expect("multi-prompt mock should succeed");
|
||||
let text = &result.content;
|
||||
assert!(text.contains("[0]"), "first block must be indexed [0]");
|
||||
assert!(text.contains("[1]"), "second block must be indexed [1]");
|
||||
assert!(text.contains("echo: alpha"));
|
||||
assert!(text.contains("echo: beta"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn schema_advertises_both_shapes() {
|
||||
let schema = tool_without_client().input_schema();
|
||||
|
||||
@@ -1791,3 +1791,99 @@ fn second_thinking_block_appends_new_entry_in_same_active_cell() {
|
||||
"the group still hasn't flushed — no prose yet"
|
||||
);
|
||||
}
|
||||
|
||||
// ---- rlm_query per-child prompt wiring ----
|
||||
//
|
||||
// When `handle_tool_call_started` receives an `rlm_query` call with a
|
||||
// `prompts` array, the resulting `GenericToolCell` must carry the parsed
|
||||
// prompts so the TUI can render one row per child (see
|
||||
// `GenericToolCell::lines_with_motion` and the `show_prompts` branch in
|
||||
// `history.rs`).
|
||||
|
||||
#[test]
|
||||
fn rlm_query_tool_cell_wired_with_prompts_on_start() {
|
||||
let mut app = create_test_app();
|
||||
|
||||
handle_tool_call_started(
|
||||
&mut app,
|
||||
"rlm-1",
|
||||
"rlm_query",
|
||||
&serde_json::json!({
|
||||
"prompts": [
|
||||
"What is the capital of France?",
|
||||
"List all public types in client.rs",
|
||||
"Summarize the README"
|
||||
]
|
||||
}),
|
||||
);
|
||||
|
||||
// The cell must be live in the active_cell slot (turn not yet complete).
|
||||
let active = app.active_cell.as_ref().expect("active cell present");
|
||||
let HistoryCell::Tool(ToolCell::Generic(generic)) = &active.entries()[0] else {
|
||||
panic!("expected GenericToolCell for rlm_query");
|
||||
};
|
||||
|
||||
assert_eq!(generic.name, "rlm_query");
|
||||
assert_eq!(generic.status, ToolStatus::Running);
|
||||
|
||||
// Core assertion: prompts populated from the JSON input.
|
||||
let prompts = generic
|
||||
.prompts
|
||||
.as_ref()
|
||||
.expect("rlm_query cell must have prompts populated");
|
||||
assert_eq!(prompts.len(), 3);
|
||||
assert_eq!(prompts[0], "What is the capital of France?");
|
||||
assert_eq!(prompts[1], "List all public types in client.rs");
|
||||
assert_eq!(prompts[2], "Summarize the README");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rlm_query_singular_prompt_wired_as_single_element_vec() {
|
||||
// When the model passes `prompt` (singular) instead of `prompts`,
|
||||
// the cell should still populate a one-element prompts vec so the
|
||||
// renderer shows the child's question.
|
||||
let mut app = create_test_app();
|
||||
|
||||
handle_tool_call_started(
|
||||
&mut app,
|
||||
"rlm-2",
|
||||
"rlm_query",
|
||||
&serde_json::json!({ "prompt": "Explain the engine loop" }),
|
||||
);
|
||||
|
||||
let active = app.active_cell.as_ref().expect("active cell present");
|
||||
let HistoryCell::Tool(ToolCell::Generic(generic)) = &active.entries()[0] else {
|
||||
panic!("expected GenericToolCell for rlm_query");
|
||||
};
|
||||
|
||||
let prompts = generic
|
||||
.prompts
|
||||
.as_ref()
|
||||
.expect("singular prompt must populate prompts vec");
|
||||
assert_eq!(prompts.len(), 1);
|
||||
assert_eq!(prompts[0], "Explain the engine loop");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_fanout_tool_does_not_populate_prompts() {
|
||||
// Tools other than rlm_query must not get a prompts vec — they use
|
||||
// the standard `args:` summary rendering path.
|
||||
let mut app = create_test_app();
|
||||
|
||||
handle_tool_call_started(
|
||||
&mut app,
|
||||
"fs-1",
|
||||
"file_search",
|
||||
&serde_json::json!({ "query": "client.rs" }),
|
||||
);
|
||||
|
||||
let active = app.active_cell.as_ref().expect("active cell present");
|
||||
let HistoryCell::Tool(ToolCell::Generic(generic)) = &active.entries()[0] else {
|
||||
panic!("expected GenericToolCell for file_search");
|
||||
};
|
||||
|
||||
assert!(
|
||||
generic.prompts.is_none(),
|
||||
"non-fan-out tool must not populate prompts"
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user