refactor(core): carve approval + dispatch helpers out of engine.rs (P1.3)
Splits `core/engine.rs` (4670 → 4314 lines) into a small folder module: - `engine/approval.rs` (~125 lines) — `ApprovalDecision`, `UserInputDecision`, `ApprovalResult`, plus the two handshake methods `Engine::await_tool_approval` and `Engine::await_user_input`. - `engine/dispatch.rs` (~300 lines) — tool-input parsing (`final_tool_input`, `parse_tool_input`, fenced/JSON segment helpers), `multi_tool_use.parallel` payload parser, dispatch policy predicates (`should_parallelize_tool_batch`, `should_force_update_plan_first`, `should_stop_after_plan_tool`, the read-only MCP tool helpers), and the `ToolExecutionPlan`/`ToolExecOutcome`/`ParallelToolResult*`/ `ToolExecGuard` types the batch driver passes around. The public engine surface (`EngineConfig`, `EngineHandle`, `spawn_engine`, `MockEngineHandle`, `mock_engine_handle`, `compact_tool_result_for_context`, `TOOL_CALL_*_MARKERS`, `FAKE_WRAPPER_NOTICE`) stays in `engine.rs` — every external user imports unchanged. Not split this round: the 1268-line `handle_deepseek_turn` method. Carving its inline parallel/sequential dispatch and approval handshake arms requires extracting two new methods from a borrow-heavy turn loop; flagged in the v0.6.0 audit doc as future work. Workspace tests: 1011/1011 still green. No clippy regressions. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
+11
-367
@@ -239,43 +239,6 @@ pub struct Engine {
|
||||
turn_counter: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum ApprovalDecision {
|
||||
Approved {
|
||||
id: String,
|
||||
},
|
||||
Denied {
|
||||
id: String,
|
||||
},
|
||||
/// Retry a tool with an elevated sandbox policy.
|
||||
RetryWithPolicy {
|
||||
id: String,
|
||||
policy: crate::sandbox::SandboxPolicy,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum UserInputDecision {
|
||||
Submitted {
|
||||
id: String,
|
||||
response: UserInputResponse,
|
||||
},
|
||||
Cancelled {
|
||||
id: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// Result of awaiting tool approval from the user.
|
||||
#[derive(Debug)]
|
||||
enum ApprovalResult {
|
||||
/// User approved the tool execution.
|
||||
Approved,
|
||||
/// User denied the tool execution.
|
||||
Denied,
|
||||
/// User requested retry with an elevated sandbox policy.
|
||||
RetryWithPolicy(crate::sandbox::SandboxPolicy),
|
||||
}
|
||||
|
||||
// === Internal stream helpers ===
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
@@ -294,51 +257,6 @@ struct ToolUseState {
|
||||
input_buffer: String,
|
||||
}
|
||||
|
||||
struct ToolExecOutcome {
|
||||
index: usize,
|
||||
id: String,
|
||||
name: String,
|
||||
input: serde_json::Value,
|
||||
started_at: Instant,
|
||||
result: Result<ToolResult, ToolError>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct ToolExecutionPlan {
|
||||
index: usize,
|
||||
id: String,
|
||||
name: String,
|
||||
input: serde_json::Value,
|
||||
caller: Option<ToolCaller>,
|
||||
interactive: bool,
|
||||
approval_required: bool,
|
||||
approval_description: String,
|
||||
supports_parallel: bool,
|
||||
read_only: bool,
|
||||
blocked_error: Option<ToolError>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
struct ParallelToolResultEntry {
|
||||
tool_name: String,
|
||||
success: bool,
|
||||
content: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
struct ParallelToolResult {
|
||||
results: Vec<ParallelToolResultEntry>,
|
||||
}
|
||||
|
||||
// Hold the lock guard for the duration of a tool execution.
|
||||
// The inner guards are held for RAII purposes (dropped when the guard is dropped).
|
||||
enum ToolExecGuard<'a> {
|
||||
Read(#[allow(dead_code)] tokio::sync::RwLockReadGuard<'a, ()>),
|
||||
Write(#[allow(dead_code)] tokio::sync::RwLockWriteGuard<'a, ()>),
|
||||
}
|
||||
|
||||
/// Maximum time to wait for a single stream chunk before assuming a stall.
|
||||
const STREAM_CHUNK_TIMEOUT_SECS: u64 = 90;
|
||||
/// Maximum total bytes of text/thinking content before aborting the stream.
|
||||
@@ -457,125 +375,6 @@ pub(crate) fn filter_tool_call_delta(delta: &str, in_tool_call: &mut bool) -> St
|
||||
/// 3. `input_buffer` non-empty but unparseable → fall back to `input`
|
||||
/// (the per-delta parser has already mirrored the most recent valid
|
||||
/// partial parse into `tool_state.input`).
|
||||
fn final_tool_input(state: &ToolUseState) -> serde_json::Value {
|
||||
if !state.input_buffer.trim().is_empty()
|
||||
&& let Some(parsed) = parse_tool_input(&state.input_buffer)
|
||||
{
|
||||
return parsed;
|
||||
}
|
||||
state.input.clone()
|
||||
}
|
||||
|
||||
fn parse_tool_input(buffer: &str) -> Option<serde_json::Value> {
|
||||
let trimmed = buffer.trim();
|
||||
if trimmed.is_empty() {
|
||||
return None;
|
||||
}
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(trimmed) {
|
||||
return Some(value);
|
||||
}
|
||||
if let Some(stripped) = strip_code_fences(trimmed)
|
||||
&& let Ok(value) = serde_json::from_str::<serde_json::Value>(&stripped)
|
||||
{
|
||||
return Some(value);
|
||||
}
|
||||
if let Ok(serde_json::Value::String(inner)) = serde_json::from_str::<serde_json::Value>(trimmed)
|
||||
&& let Ok(value) = serde_json::from_str::<serde_json::Value>(&inner)
|
||||
{
|
||||
return Some(value);
|
||||
}
|
||||
extract_json_segment(trimmed)
|
||||
.and_then(|segment| serde_json::from_str::<serde_json::Value>(&segment).ok())
|
||||
}
|
||||
|
||||
fn strip_code_fences(text: &str) -> Option<String> {
|
||||
if !text.contains("```") {
|
||||
return None;
|
||||
}
|
||||
let mut lines = Vec::new();
|
||||
for line in text.lines() {
|
||||
if line.trim_start().starts_with("```") {
|
||||
continue;
|
||||
}
|
||||
lines.push(line);
|
||||
}
|
||||
let stripped = lines.join("\n");
|
||||
let stripped = stripped.trim();
|
||||
if stripped.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(stripped.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_json_segment(text: &str) -> Option<String> {
|
||||
extract_balanced_segment(text, '{', '}').or_else(|| extract_balanced_segment(text, '[', ']'))
|
||||
}
|
||||
|
||||
fn extract_balanced_segment(text: &str, open: char, close: char) -> Option<String> {
|
||||
let start = text.find(open)?;
|
||||
let mut depth = 0i32;
|
||||
let mut end = None;
|
||||
for (offset, ch) in text[start..].char_indices() {
|
||||
if ch == open {
|
||||
depth += 1;
|
||||
} else if ch == close {
|
||||
depth -= 1;
|
||||
if depth == 0 {
|
||||
end = Some(start + offset + ch.len_utf8());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
end.map(|end_idx| text[start..end_idx].to_string())
|
||||
}
|
||||
|
||||
fn normalize_parallel_tool_name(raw: &str) -> String {
|
||||
let mut name = raw.trim();
|
||||
for prefix in ["functions.", "tools.", "tool."] {
|
||||
if let Some(stripped) = name.strip_prefix(prefix) {
|
||||
name = stripped;
|
||||
break;
|
||||
}
|
||||
}
|
||||
name.to_string()
|
||||
}
|
||||
|
||||
fn parse_parallel_tool_calls(
|
||||
input: &serde_json::Value,
|
||||
) -> Result<Vec<(String, serde_json::Value)>, ToolError> {
|
||||
let tool_uses = input
|
||||
.get("tool_uses")
|
||||
.and_then(|v| v.as_array())
|
||||
.ok_or_else(|| ToolError::missing_field("tool_uses"))?;
|
||||
if tool_uses.is_empty() {
|
||||
return Err(ToolError::invalid_input(
|
||||
"multi_tool_use.parallel requires at least one tool call",
|
||||
));
|
||||
}
|
||||
|
||||
let mut calls = Vec::with_capacity(tool_uses.len());
|
||||
for item in tool_uses {
|
||||
let name = item
|
||||
.get("recipient_name")
|
||||
.or_else(|| item.get("tool_name"))
|
||||
.or_else(|| item.get("name"))
|
||||
.or_else(|| item.get("tool"))
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| ToolError::missing_field("recipient_name"))?;
|
||||
let params = item
|
||||
.get("parameters")
|
||||
.or_else(|| item.get("input"))
|
||||
.or_else(|| item.get("args"))
|
||||
.or_else(|| item.get("arguments"))
|
||||
.cloned()
|
||||
.unwrap_or_else(|| json!({}));
|
||||
calls.push((normalize_parallel_tool_name(name), params));
|
||||
}
|
||||
|
||||
Ok(calls)
|
||||
}
|
||||
|
||||
fn is_tool_search_tool(name: &str) -> bool {
|
||||
matches!(name, TOOL_SEARCH_REGEX_NAME | TOOL_SEARCH_BM25_NAME)
|
||||
}
|
||||
@@ -978,99 +777,6 @@ fn caller_allowed_for_tool(caller: Option<&ToolCaller>, tool_def: Option<&Tool>)
|
||||
requested == "direct"
|
||||
}
|
||||
|
||||
fn should_parallelize_tool_batch(plans: &[ToolExecutionPlan]) -> bool {
|
||||
!plans.is_empty()
|
||||
&& plans.iter().all(|plan| {
|
||||
plan.read_only && plan.supports_parallel && !plan.approval_required && !plan.interactive
|
||||
})
|
||||
}
|
||||
|
||||
fn should_stop_after_plan_tool(
|
||||
mode: AppMode,
|
||||
tool_name: &str,
|
||||
result: &Result<ToolResult, ToolError>,
|
||||
) -> bool {
|
||||
mode == AppMode::Plan && tool_name == "update_plan" && result.is_ok()
|
||||
}
|
||||
|
||||
fn should_force_update_plan_first(mode: AppMode, content: &str) -> bool {
|
||||
if mode != AppMode::Plan {
|
||||
return false;
|
||||
}
|
||||
|
||||
let lower = content.to_ascii_lowercase();
|
||||
let asks_for_direct_plan = [
|
||||
"quick plan",
|
||||
"short plan",
|
||||
"simple plan",
|
||||
"3-step plan",
|
||||
"3 step plan",
|
||||
"three-step plan",
|
||||
"three step plan",
|
||||
"high-level plan",
|
||||
"high level plan",
|
||||
"give me a plan",
|
||||
"make a plan",
|
||||
"outline a plan",
|
||||
"draft a plan",
|
||||
]
|
||||
.iter()
|
||||
.any(|needle| lower.contains(needle));
|
||||
|
||||
if !asks_for_direct_plan {
|
||||
return false;
|
||||
}
|
||||
|
||||
let asks_for_repo_exploration = [
|
||||
"inspect the repo",
|
||||
"inspect the code",
|
||||
"explore the repo",
|
||||
"search the repo",
|
||||
"read the code",
|
||||
"review the code",
|
||||
"analyze the code",
|
||||
"investigate",
|
||||
"look through",
|
||||
"understand the current",
|
||||
"ground it in the codebase",
|
||||
"based on the codebase",
|
||||
]
|
||||
.iter()
|
||||
.any(|needle| lower.contains(needle));
|
||||
|
||||
!asks_for_repo_exploration
|
||||
}
|
||||
|
||||
fn mcp_tool_is_parallel_safe(name: &str) -> bool {
|
||||
matches!(
|
||||
name,
|
||||
"list_mcp_resources"
|
||||
| "list_mcp_resource_templates"
|
||||
| "mcp_read_resource"
|
||||
| "read_mcp_resource"
|
||||
| "mcp_get_prompt"
|
||||
)
|
||||
}
|
||||
|
||||
fn mcp_tool_is_read_only(name: &str) -> bool {
|
||||
matches!(
|
||||
name,
|
||||
"list_mcp_resources"
|
||||
| "list_mcp_resource_templates"
|
||||
| "mcp_read_resource"
|
||||
| "read_mcp_resource"
|
||||
| "mcp_get_prompt"
|
||||
)
|
||||
}
|
||||
|
||||
fn mcp_tool_approval_description(name: &str) -> String {
|
||||
if mcp_tool_is_read_only(name) {
|
||||
format!("Read-only MCP tool '{name}'")
|
||||
} else {
|
||||
format!("MCP tool '{name}' may have side effects")
|
||||
}
|
||||
}
|
||||
|
||||
fn format_tool_error(err: &ToolError, tool_name: &str) -> String {
|
||||
match err {
|
||||
ToolError::InvalidInput { message } => {
|
||||
@@ -2291,79 +1997,6 @@ impl Engine {
|
||||
result
|
||||
}
|
||||
|
||||
async fn await_tool_approval(&mut self, tool_id: &str) -> Result<ApprovalResult, ToolError> {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = self.cancel_token.cancelled() => {
|
||||
return Err(ToolError::execution_failed(
|
||||
"Request cancelled while awaiting approval".to_string(),
|
||||
));
|
||||
}
|
||||
decision = self.rx_approval.recv() => {
|
||||
let Some(decision) = decision else {
|
||||
return Err(ToolError::execution_failed(
|
||||
"Approval channel closed".to_string(),
|
||||
));
|
||||
};
|
||||
match decision {
|
||||
ApprovalDecision::Approved { id } if id == tool_id => {
|
||||
return Ok(ApprovalResult::Approved);
|
||||
}
|
||||
ApprovalDecision::Denied { id } if id == tool_id => {
|
||||
return Ok(ApprovalResult::Denied);
|
||||
}
|
||||
ApprovalDecision::RetryWithPolicy { id, policy } if id == tool_id => {
|
||||
return Ok(ApprovalResult::RetryWithPolicy(policy));
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn await_user_input(
|
||||
&mut self,
|
||||
tool_id: &str,
|
||||
request: UserInputRequest,
|
||||
) -> Result<UserInputResponse, ToolError> {
|
||||
let _ = self
|
||||
.tx_event
|
||||
.send(Event::UserInputRequired {
|
||||
id: tool_id.to_string(),
|
||||
request,
|
||||
})
|
||||
.await;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = self.cancel_token.cancelled() => {
|
||||
return Err(ToolError::execution_failed(
|
||||
"Request cancelled while awaiting user input".to_string(),
|
||||
));
|
||||
}
|
||||
decision = self.rx_user_input.recv() => {
|
||||
let Some(decision) = decision else {
|
||||
return Err(ToolError::execution_failed(
|
||||
"User input channel closed".to_string(),
|
||||
));
|
||||
};
|
||||
match decision {
|
||||
UserInputDecision::Submitted { id, response } if id == tool_id => {
|
||||
return Ok(response);
|
||||
}
|
||||
UserInputDecision::Cancelled { id } if id == tool_id => {
|
||||
return Err(ToolError::execution_failed(
|
||||
"User input cancelled".to_string(),
|
||||
));
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a turn using the DeepSeek API.
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn handle_deepseek_turn(
|
||||
@@ -4666,5 +4299,16 @@ pub(crate) fn mock_engine_handle() -> MockEngineHandle {
|
||||
}
|
||||
}
|
||||
|
||||
mod approval;
|
||||
mod dispatch;
|
||||
|
||||
use self::approval::{ApprovalDecision, ApprovalResult, UserInputDecision};
|
||||
use self::dispatch::{
|
||||
ParallelToolResult, ParallelToolResultEntry, ToolExecGuard, ToolExecOutcome, ToolExecutionPlan,
|
||||
final_tool_input, mcp_tool_approval_description, mcp_tool_is_parallel_safe,
|
||||
mcp_tool_is_read_only, parse_parallel_tool_calls, parse_tool_input,
|
||||
should_force_update_plan_first, should_parallelize_tool_batch, should_stop_after_plan_tool,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
//! Approval + user-input handshake for the agent loop.
|
||||
//!
|
||||
//! Extracted from `core/engine.rs` (P1.3). The agent loop blocks on these
|
||||
//! two futures whenever a tool requires explicit approval (`await_tool_approval`)
|
||||
//! or whenever a tool requests live user input (`await_user_input`). Channels
|
||||
//! and engine state stay private to the parent module.
|
||||
|
||||
use crate::core::events::Event;
|
||||
use crate::tools::spec::ToolError;
|
||||
use crate::tools::user_input::{UserInputRequest, UserInputResponse};
|
||||
|
||||
use super::Engine;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(super) enum ApprovalDecision {
|
||||
Approved {
|
||||
id: String,
|
||||
},
|
||||
Denied {
|
||||
id: String,
|
||||
},
|
||||
/// Retry a tool with an elevated sandbox policy.
|
||||
RetryWithPolicy {
|
||||
id: String,
|
||||
policy: crate::sandbox::SandboxPolicy,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(super) enum UserInputDecision {
|
||||
Submitted {
|
||||
id: String,
|
||||
response: UserInputResponse,
|
||||
},
|
||||
Cancelled {
|
||||
id: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// Result of awaiting tool approval from the user.
|
||||
#[derive(Debug)]
|
||||
pub(super) enum ApprovalResult {
|
||||
/// User approved the tool execution.
|
||||
Approved,
|
||||
/// User denied the tool execution.
|
||||
Denied,
|
||||
/// User requested retry with an elevated sandbox policy.
|
||||
RetryWithPolicy(crate::sandbox::SandboxPolicy),
|
||||
}
|
||||
|
||||
impl Engine {
|
||||
pub(super) async fn await_tool_approval(
|
||||
&mut self,
|
||||
tool_id: &str,
|
||||
) -> Result<ApprovalResult, ToolError> {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = self.cancel_token.cancelled() => {
|
||||
return Err(ToolError::execution_failed(
|
||||
"Request cancelled while awaiting approval".to_string(),
|
||||
));
|
||||
}
|
||||
decision = self.rx_approval.recv() => {
|
||||
let Some(decision) = decision else {
|
||||
return Err(ToolError::execution_failed(
|
||||
"Approval channel closed".to_string(),
|
||||
));
|
||||
};
|
||||
match decision {
|
||||
ApprovalDecision::Approved { id } if id == tool_id => {
|
||||
return Ok(ApprovalResult::Approved);
|
||||
}
|
||||
ApprovalDecision::Denied { id } if id == tool_id => {
|
||||
return Ok(ApprovalResult::Denied);
|
||||
}
|
||||
ApprovalDecision::RetryWithPolicy { id, policy } if id == tool_id => {
|
||||
return Ok(ApprovalResult::RetryWithPolicy(policy));
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn await_user_input(
|
||||
&mut self,
|
||||
tool_id: &str,
|
||||
request: UserInputRequest,
|
||||
) -> Result<UserInputResponse, ToolError> {
|
||||
let _ = self
|
||||
.tx_event
|
||||
.send(Event::UserInputRequired {
|
||||
id: tool_id.to_string(),
|
||||
request,
|
||||
})
|
||||
.await;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = self.cancel_token.cancelled() => {
|
||||
return Err(ToolError::execution_failed(
|
||||
"Request cancelled while awaiting user input".to_string(),
|
||||
));
|
||||
}
|
||||
decision = self.rx_user_input.recv() => {
|
||||
let Some(decision) = decision else {
|
||||
return Err(ToolError::execution_failed(
|
||||
"User input channel closed".to_string(),
|
||||
));
|
||||
};
|
||||
match decision {
|
||||
UserInputDecision::Submitted { id, response } if id == tool_id => {
|
||||
return Ok(response);
|
||||
}
|
||||
UserInputDecision::Cancelled { id } if id == tool_id => {
|
||||
return Err(ToolError::execution_failed(
|
||||
"User input cancelled".to_string(),
|
||||
));
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,300 @@
|
||||
//! Tool dispatch — plan/execute helpers for the per-turn tool batch.
|
||||
//!
|
||||
//! Extracted from `core/engine.rs` (P1.3). The high-level ordering still
|
||||
//! lives in `Engine::handle_deepseek_turn`; this module owns:
|
||||
//!
|
||||
//! * Streaming-buffer parsing into a finalized `serde_json::Value` tool input
|
||||
//! (`final_tool_input`, `parse_tool_input`, fenced/JSON segment helpers).
|
||||
//! * The `multi_tool_use.parallel` payload parser.
|
||||
//! * Policy predicates the turn loop consults — when a batch can run in
|
||||
//! parallel, when an `update_plan` step should stop the turn, when a Plan
|
||||
//! prompt should force a plan-first hop, and the small set of read-only
|
||||
//! MCP tools that are safe to run in parallel.
|
||||
//! * The tool execution plan/outcome types the batch driver passes around.
|
||||
//!
|
||||
//! All items are `pub(super)`-only: the public engine surface (Op/Event,
|
||||
//! `EngineHandle`, `spawn_engine`) stays in `engine/mod.rs`.
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
use crate::models::ToolCaller;
|
||||
use crate::tools::spec::{ToolError, ToolResult};
|
||||
use crate::tui::app::AppMode;
|
||||
|
||||
use super::ToolUseState;
|
||||
|
||||
// === Types ============================================================
|
||||
|
||||
#[allow(dead_code)] // `index` mirrors batch order for diagnostic ergonomics.
|
||||
pub(super) struct ToolExecOutcome {
|
||||
pub(super) index: usize,
|
||||
pub(super) id: String,
|
||||
pub(super) name: String,
|
||||
pub(super) input: serde_json::Value,
|
||||
pub(super) started_at: std::time::Instant,
|
||||
pub(super) result: Result<ToolResult, ToolError>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(super) struct ToolExecutionPlan {
|
||||
pub(super) index: usize,
|
||||
pub(super) id: String,
|
||||
pub(super) name: String,
|
||||
pub(super) input: serde_json::Value,
|
||||
pub(super) caller: Option<ToolCaller>,
|
||||
pub(super) interactive: bool,
|
||||
pub(super) approval_required: bool,
|
||||
pub(super) approval_description: String,
|
||||
pub(super) supports_parallel: bool,
|
||||
pub(super) read_only: bool,
|
||||
pub(super) blocked_error: Option<ToolError>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub(super) struct ParallelToolResultEntry {
|
||||
pub(super) tool_name: String,
|
||||
pub(super) success: bool,
|
||||
pub(super) content: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(super) error: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub(super) struct ParallelToolResult {
|
||||
pub(super) results: Vec<ParallelToolResultEntry>,
|
||||
}
|
||||
|
||||
// Hold the lock guard for the duration of a tool execution.
|
||||
// The inner guards are held for RAII purposes (dropped when the guard is dropped).
|
||||
pub(super) enum ToolExecGuard<'a> {
|
||||
Read(#[allow(dead_code)] tokio::sync::RwLockReadGuard<'a, ()>),
|
||||
Write(#[allow(dead_code)] tokio::sync::RwLockWriteGuard<'a, ()>),
|
||||
}
|
||||
|
||||
// === Streaming-buffer parsing =========================================
|
||||
|
||||
/// Promote a streaming `ToolUseState` to a finalized JSON input.
|
||||
///
|
||||
/// Order of preference:
|
||||
///
|
||||
/// 1. `input_buffer` (the raw streamed delta concatenation) — parsed as
|
||||
/// JSON. This is the most authoritative because it's what the model
|
||||
/// actually emitted.
|
||||
/// 2. `input` (the per-delta best-effort parse mirror) — used when the
|
||||
/// buffer is empty (pre-streaming tool calls take this path).
|
||||
/// 3. `input_buffer` non-empty but unparseable → fall back to `input`
|
||||
/// (the per-delta parser has already mirrored the most recent valid
|
||||
/// partial parse into `tool_state.input`).
|
||||
pub(super) fn final_tool_input(state: &ToolUseState) -> serde_json::Value {
|
||||
if !state.input_buffer.trim().is_empty()
|
||||
&& let Some(parsed) = parse_tool_input(&state.input_buffer)
|
||||
{
|
||||
return parsed;
|
||||
}
|
||||
state.input.clone()
|
||||
}
|
||||
|
||||
pub(super) fn parse_tool_input(buffer: &str) -> Option<serde_json::Value> {
|
||||
let trimmed = buffer.trim();
|
||||
if trimmed.is_empty() {
|
||||
return None;
|
||||
}
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(trimmed) {
|
||||
return Some(value);
|
||||
}
|
||||
if let Some(stripped) = strip_code_fences(trimmed)
|
||||
&& let Ok(value) = serde_json::from_str::<serde_json::Value>(&stripped)
|
||||
{
|
||||
return Some(value);
|
||||
}
|
||||
if let Ok(serde_json::Value::String(inner)) = serde_json::from_str::<serde_json::Value>(trimmed)
|
||||
&& let Ok(value) = serde_json::from_str::<serde_json::Value>(&inner)
|
||||
{
|
||||
return Some(value);
|
||||
}
|
||||
extract_json_segment(trimmed)
|
||||
.and_then(|segment| serde_json::from_str::<serde_json::Value>(&segment).ok())
|
||||
}
|
||||
|
||||
fn strip_code_fences(text: &str) -> Option<String> {
|
||||
if !text.contains("```") {
|
||||
return None;
|
||||
}
|
||||
let mut lines = Vec::new();
|
||||
for line in text.lines() {
|
||||
if line.trim_start().starts_with("```") {
|
||||
continue;
|
||||
}
|
||||
lines.push(line);
|
||||
}
|
||||
let stripped = lines.join("\n");
|
||||
let stripped = stripped.trim();
|
||||
if stripped.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(stripped.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_json_segment(text: &str) -> Option<String> {
|
||||
extract_balanced_segment(text, '{', '}').or_else(|| extract_balanced_segment(text, '[', ']'))
|
||||
}
|
||||
|
||||
fn extract_balanced_segment(text: &str, open: char, close: char) -> Option<String> {
|
||||
let start = text.find(open)?;
|
||||
let mut depth = 0i32;
|
||||
let mut end = None;
|
||||
for (offset, ch) in text[start..].char_indices() {
|
||||
if ch == open {
|
||||
depth += 1;
|
||||
} else if ch == close {
|
||||
depth -= 1;
|
||||
if depth == 0 {
|
||||
end = Some(start + offset + ch.len_utf8());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
end.map(|end_idx| text[start..end_idx].to_string())
|
||||
}
|
||||
|
||||
fn normalize_parallel_tool_name(raw: &str) -> String {
|
||||
let mut name = raw.trim();
|
||||
for prefix in ["functions.", "tools.", "tool."] {
|
||||
if let Some(stripped) = name.strip_prefix(prefix) {
|
||||
name = stripped;
|
||||
break;
|
||||
}
|
||||
}
|
||||
name.to_string()
|
||||
}
|
||||
|
||||
pub(super) fn parse_parallel_tool_calls(
|
||||
input: &serde_json::Value,
|
||||
) -> Result<Vec<(String, serde_json::Value)>, ToolError> {
|
||||
let tool_uses = input
|
||||
.get("tool_uses")
|
||||
.and_then(|v| v.as_array())
|
||||
.ok_or_else(|| ToolError::missing_field("tool_uses"))?;
|
||||
if tool_uses.is_empty() {
|
||||
return Err(ToolError::invalid_input(
|
||||
"multi_tool_use.parallel requires at least one tool call",
|
||||
));
|
||||
}
|
||||
|
||||
let mut calls = Vec::with_capacity(tool_uses.len());
|
||||
for item in tool_uses {
|
||||
let name = item
|
||||
.get("recipient_name")
|
||||
.or_else(|| item.get("tool_name"))
|
||||
.or_else(|| item.get("name"))
|
||||
.or_else(|| item.get("tool"))
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| ToolError::missing_field("recipient_name"))?;
|
||||
let params = item
|
||||
.get("parameters")
|
||||
.or_else(|| item.get("input"))
|
||||
.or_else(|| item.get("args"))
|
||||
.or_else(|| item.get("arguments"))
|
||||
.cloned()
|
||||
.unwrap_or_else(|| json!({}));
|
||||
calls.push((normalize_parallel_tool_name(name), params));
|
||||
}
|
||||
|
||||
Ok(calls)
|
||||
}
|
||||
|
||||
// === Dispatch policy ==================================================
|
||||
|
||||
pub(super) fn should_parallelize_tool_batch(plans: &[ToolExecutionPlan]) -> bool {
|
||||
!plans.is_empty()
|
||||
&& plans.iter().all(|plan| {
|
||||
plan.read_only && plan.supports_parallel && !plan.approval_required && !plan.interactive
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn should_stop_after_plan_tool(
|
||||
mode: AppMode,
|
||||
tool_name: &str,
|
||||
result: &Result<ToolResult, ToolError>,
|
||||
) -> bool {
|
||||
mode == AppMode::Plan && tool_name == "update_plan" && result.is_ok()
|
||||
}
|
||||
|
||||
pub(super) fn should_force_update_plan_first(mode: AppMode, content: &str) -> bool {
|
||||
if mode != AppMode::Plan {
|
||||
return false;
|
||||
}
|
||||
|
||||
let lower = content.to_ascii_lowercase();
|
||||
let asks_for_direct_plan = [
|
||||
"quick plan",
|
||||
"short plan",
|
||||
"simple plan",
|
||||
"3-step plan",
|
||||
"3 step plan",
|
||||
"three-step plan",
|
||||
"three step plan",
|
||||
"high-level plan",
|
||||
"high level plan",
|
||||
"give me a plan",
|
||||
"make a plan",
|
||||
"outline a plan",
|
||||
"draft a plan",
|
||||
]
|
||||
.iter()
|
||||
.any(|needle| lower.contains(needle));
|
||||
|
||||
if !asks_for_direct_plan {
|
||||
return false;
|
||||
}
|
||||
|
||||
let asks_for_repo_exploration = [
|
||||
"inspect the repo",
|
||||
"inspect the code",
|
||||
"explore the repo",
|
||||
"search the repo",
|
||||
"read the code",
|
||||
"review the code",
|
||||
"analyze the code",
|
||||
"investigate",
|
||||
"look through",
|
||||
"understand the current",
|
||||
"ground it in the codebase",
|
||||
"based on the codebase",
|
||||
]
|
||||
.iter()
|
||||
.any(|needle| lower.contains(needle));
|
||||
|
||||
!asks_for_repo_exploration
|
||||
}
|
||||
|
||||
pub(super) fn mcp_tool_is_parallel_safe(name: &str) -> bool {
|
||||
matches!(
|
||||
name,
|
||||
"list_mcp_resources"
|
||||
| "list_mcp_resource_templates"
|
||||
| "mcp_read_resource"
|
||||
| "read_mcp_resource"
|
||||
| "mcp_get_prompt"
|
||||
)
|
||||
}
|
||||
|
||||
pub(super) fn mcp_tool_is_read_only(name: &str) -> bool {
|
||||
matches!(
|
||||
name,
|
||||
"list_mcp_resources"
|
||||
| "list_mcp_resource_templates"
|
||||
| "mcp_read_resource"
|
||||
| "read_mcp_resource"
|
||||
| "mcp_get_prompt"
|
||||
)
|
||||
}
|
||||
|
||||
pub(super) fn mcp_tool_approval_description(name: &str) -> String {
|
||||
if mcp_tool_is_read_only(name) {
|
||||
format!("Read-only MCP tool '{name}'")
|
||||
} else {
|
||||
format!("MCP tool '{name}' may have side effects")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user