refactor(engine): split tool execution helpers
This commit is contained in:
@@ -11,7 +11,6 @@ use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, Mutex as StdMutex};
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{fs::OpenOptions, io::Write};
|
||||
|
||||
use anyhow::Result;
|
||||
use futures_util::StreamExt;
|
||||
@@ -427,67 +426,6 @@ fn caller_type_for_tool_use(caller: Option<&ToolCaller>) -> &str {
|
||||
caller.map_or("direct", |c| c.caller_type.as_str())
|
||||
}
|
||||
|
||||
/// #136: derive the file path(s) edited by a tool call. Returns the empty
|
||||
/// vec for tools that don't modify files. We intentionally only handle the
|
||||
/// three known edit tools — adding more (e.g. specialized refactor tools)
|
||||
/// is a one-line change here.
|
||||
fn edited_paths_for_tool(tool_name: &str, input: &serde_json::Value) -> Vec<PathBuf> {
|
||||
match tool_name {
|
||||
"edit_file" | "write_file" => {
|
||||
if let Some(path) = input.get("path").and_then(|v| v.as_str()) {
|
||||
vec![PathBuf::from(path)]
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
"apply_patch" => {
|
||||
// `apply_patch` accepts either a `path` override or a list of
|
||||
// `files` (each `{path, content}`). We try both shapes.
|
||||
let mut out = Vec::new();
|
||||
if let Some(path) = input.get("path").and_then(|v| v.as_str()) {
|
||||
out.push(PathBuf::from(path));
|
||||
}
|
||||
if let Some(files) = input.get("files").and_then(|v| v.as_array()) {
|
||||
for entry in files {
|
||||
if let Some(path) = entry.get("path").and_then(|v| v.as_str()) {
|
||||
out.push(PathBuf::from(path));
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fallback: parse `---`/`+++` headers from a unified diff payload.
|
||||
if out.is_empty()
|
||||
&& let Some(patch) = input.get("patch").and_then(|v| v.as_str())
|
||||
{
|
||||
out.extend(parse_patch_paths(patch));
|
||||
}
|
||||
out
|
||||
}
|
||||
_ => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Lightweight parser for `+++ b/<path>` lines in a unified diff. Used as a
|
||||
/// fallback when `apply_patch` is invoked with raw `patch` text and no
|
||||
/// `path`/`files` override. We deliberately keep this dumb — the real
|
||||
/// `apply_patch` tool already validates the patch shape; we only need a
|
||||
/// best-effort hint for the LSP hook.
|
||||
fn parse_patch_paths(patch: &str) -> Vec<PathBuf> {
|
||||
let mut out = Vec::new();
|
||||
for line in patch.lines() {
|
||||
if let Some(rest) = line.strip_prefix("+++ ") {
|
||||
let trimmed = rest.trim();
|
||||
// Strip leading `b/` per git diff conventions.
|
||||
let path = trimmed.strip_prefix("b/").unwrap_or(trimmed);
|
||||
// Skip `/dev/null` (deletion).
|
||||
if path == "/dev/null" {
|
||||
continue;
|
||||
}
|
||||
out.push(PathBuf::from(path));
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn caller_allowed_for_tool(caller: Option<&ToolCaller>, tool_def: Option<&Tool>) -> bool {
|
||||
let requested = caller_type_for_tool_use(caller);
|
||||
if let Some(def) = tool_def
|
||||
@@ -533,23 +471,6 @@ fn format_tool_error(err: &ToolError, tool_name: &str) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_tool_audit(event: serde_json::Value) {
|
||||
let Some(path) = std::env::var_os("DEEPSEEK_TOOL_AUDIT_LOG") else {
|
||||
return;
|
||||
};
|
||||
let line = match serde_json::to_string(&event) {
|
||||
Ok(line) => line,
|
||||
Err(_) => return,
|
||||
};
|
||||
let path = PathBuf::from(path);
|
||||
if let Some(parent) = path.parent() {
|
||||
let _ = std::fs::create_dir_all(parent);
|
||||
}
|
||||
if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) {
|
||||
let _ = writeln!(file, "{line}");
|
||||
}
|
||||
}
|
||||
|
||||
impl Engine {
|
||||
fn reset_cancel_token(&mut self) {
|
||||
let token = CancellationToken::new();
|
||||
@@ -889,57 +810,6 @@ impl Engine {
|
||||
self.emit_session_updated().await;
|
||||
}
|
||||
|
||||
/// #136: post-edit hook. Inspects the tool name + input, derives the
|
||||
/// edited file path, and asks the LSP manager for diagnostics. The
|
||||
/// rendered block is queued in `pending_lsp_blocks` and flushed to the
|
||||
/// session message stream just before the next API request. Failure is
|
||||
/// silent by design — a missing/crashing LSP server must never block
|
||||
/// the agent.
|
||||
async fn run_post_edit_lsp_hook(&mut self, tool_name: &str, tool_input: &serde_json::Value) {
|
||||
if !self.lsp_manager.config().enabled {
|
||||
return;
|
||||
}
|
||||
let paths = edited_paths_for_tool(tool_name, tool_input);
|
||||
for path in paths {
|
||||
let absolute = if path.is_absolute() {
|
||||
path.clone()
|
||||
} else {
|
||||
self.session.workspace.join(&path)
|
||||
};
|
||||
// Use a short edit-sequence based on the existing turn counter so
|
||||
// log output stays correlated even though we do not currently
|
||||
// batch by sequence.
|
||||
let seq = self.turn_counter;
|
||||
if let Some(block) = self.lsp_manager.diagnostics_for(&absolute, seq).await {
|
||||
self.pending_lsp_blocks.push(block);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Drain `pending_lsp_blocks` into a single synthetic user message so the
|
||||
/// model sees the diagnostics on its next request. Skips when nothing is
|
||||
/// pending. The message uses the standard `text` content block shape
|
||||
/// (the same shape as the post-tool steer messages) so we don't need to
|
||||
/// invent a new envelope.
|
||||
async fn flush_pending_lsp_diagnostics(&mut self) {
|
||||
if self.pending_lsp_blocks.is_empty() {
|
||||
return;
|
||||
}
|
||||
let blocks = std::mem::take(&mut self.pending_lsp_blocks);
|
||||
let rendered = crate::lsp::render_blocks(&blocks);
|
||||
if rendered.is_empty() {
|
||||
return;
|
||||
}
|
||||
self.add_session_message(Message {
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentBlock::Text {
|
||||
text: rendered,
|
||||
cache_control: None,
|
||||
}],
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Handle a send message operation
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn handle_send_message(
|
||||
@@ -1550,175 +1420,6 @@ impl Engine {
|
||||
pool.to_api_tools()
|
||||
}
|
||||
|
||||
async fn execute_mcp_tool_with_pool(
|
||||
pool: Arc<AsyncMutex<McpPool>>,
|
||||
name: &str,
|
||||
input: serde_json::Value,
|
||||
) -> Result<ToolResult, ToolError> {
|
||||
let mut pool = pool.lock().await;
|
||||
let result = pool
|
||||
.call_tool(name, input)
|
||||
.await
|
||||
.map_err(|e| ToolError::execution_failed(format!("MCP tool failed: {e}")))?;
|
||||
let content = serde_json::to_string_pretty(&result).unwrap_or_else(|_| result.to_string());
|
||||
Ok(ToolResult::success(content))
|
||||
}
|
||||
|
||||
async fn execute_parallel_tool(
|
||||
&mut self,
|
||||
input: serde_json::Value,
|
||||
tool_registry: Option<&crate::tools::ToolRegistry>,
|
||||
tool_exec_lock: Arc<RwLock<()>>,
|
||||
) -> Result<ToolResult, ToolError> {
|
||||
let calls = parse_parallel_tool_calls(&input)?;
|
||||
let mcp_pool = if calls.iter().any(|(tool, _)| McpPool::is_mcp_tool(tool)) {
|
||||
Some(self.ensure_mcp_pool().await?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let Some(registry) = tool_registry else {
|
||||
return Err(ToolError::not_available(
|
||||
"tool registry unavailable for multi_tool_use.parallel",
|
||||
));
|
||||
};
|
||||
|
||||
let mut tasks = FuturesUnordered::new();
|
||||
for (tool_name, tool_input) in calls {
|
||||
if tool_name == MULTI_TOOL_PARALLEL_NAME {
|
||||
return Err(ToolError::invalid_input(
|
||||
"multi_tool_use.parallel cannot call itself",
|
||||
));
|
||||
}
|
||||
if McpPool::is_mcp_tool(&tool_name) {
|
||||
if !mcp_tool_is_parallel_safe(&tool_name) {
|
||||
return Err(ToolError::invalid_input(format!(
|
||||
"Tool '{tool_name}' is an MCP tool and cannot run in parallel. \
|
||||
Allowed MCP tools: list_mcp_resources, list_mcp_resource_templates, \
|
||||
mcp_read_resource, read_mcp_resource, mcp_get_prompt."
|
||||
)));
|
||||
}
|
||||
} else {
|
||||
let Some(spec) = registry.get(&tool_name) else {
|
||||
return Err(ToolError::not_available(format!(
|
||||
"tool '{tool_name}' is not registered"
|
||||
)));
|
||||
};
|
||||
if !spec.is_read_only() {
|
||||
return Err(ToolError::invalid_input(format!(
|
||||
"Tool '{tool_name}' is not read-only and cannot run in parallel"
|
||||
)));
|
||||
}
|
||||
if spec.approval_requirement() != ApprovalRequirement::Auto {
|
||||
return Err(ToolError::invalid_input(format!(
|
||||
"Tool '{tool_name}' requires approval and cannot run in parallel"
|
||||
)));
|
||||
}
|
||||
if !spec.supports_parallel() {
|
||||
return Err(ToolError::invalid_input(format!(
|
||||
"Tool '{tool_name}' does not support parallel execution"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let registry_ref = registry;
|
||||
let lock = tool_exec_lock.clone();
|
||||
let tx_event = self.tx_event.clone();
|
||||
let mcp_pool = mcp_pool.clone();
|
||||
tasks.push(async move {
|
||||
let result = Engine::execute_tool_with_lock(
|
||||
lock,
|
||||
true,
|
||||
false,
|
||||
tx_event,
|
||||
tool_name.clone(),
|
||||
tool_input.clone(),
|
||||
Some(registry_ref),
|
||||
mcp_pool,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
(tool_name, result)
|
||||
});
|
||||
}
|
||||
|
||||
let mut results = Vec::new();
|
||||
while let Some((tool_name, result)) = tasks.next().await {
|
||||
match result {
|
||||
Ok(output) => {
|
||||
let mut error = None;
|
||||
if !output.success {
|
||||
error = Some(output.content.clone());
|
||||
}
|
||||
results.push(ParallelToolResultEntry {
|
||||
tool_name,
|
||||
success: output.success,
|
||||
content: output.content,
|
||||
error,
|
||||
});
|
||||
}
|
||||
Err(err) => {
|
||||
let message = format!("{err}");
|
||||
results.push(ParallelToolResultEntry {
|
||||
tool_name,
|
||||
success: false,
|
||||
content: format!("Error: {message}"),
|
||||
error: Some(message),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ToolResult::json(&ParallelToolResult { results })
|
||||
.map_err(|e| ToolError::execution_failed(e.to_string()))
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn execute_tool_with_lock(
|
||||
lock: Arc<RwLock<()>>,
|
||||
supports_parallel: bool,
|
||||
interactive: bool,
|
||||
tx_event: mpsc::Sender<Event>,
|
||||
tool_name: String,
|
||||
tool_input: serde_json::Value,
|
||||
registry: Option<&crate::tools::ToolRegistry>,
|
||||
mcp_pool: Option<Arc<AsyncMutex<McpPool>>>,
|
||||
context_override: Option<crate::tools::ToolContext>,
|
||||
) -> Result<ToolResult, ToolError> {
|
||||
let _guard = if supports_parallel {
|
||||
ToolExecGuard::Read(lock.read().await)
|
||||
} else {
|
||||
ToolExecGuard::Write(lock.write().await)
|
||||
};
|
||||
|
||||
if interactive {
|
||||
let _ = tx_event.send(Event::PauseEvents).await;
|
||||
}
|
||||
|
||||
let result = if McpPool::is_mcp_tool(&tool_name) {
|
||||
if let Some(pool) = mcp_pool {
|
||||
Engine::execute_mcp_tool_with_pool(pool, &tool_name, tool_input).await
|
||||
} else {
|
||||
Err(ToolError::not_available(format!(
|
||||
"tool '{tool_name}' is not registered"
|
||||
)))
|
||||
}
|
||||
} else if let Some(registry) = registry {
|
||||
registry
|
||||
.execute_full_with_context(&tool_name, tool_input, context_override.as_ref())
|
||||
.await
|
||||
} else {
|
||||
Err(ToolError::not_available(format!(
|
||||
"tool '{tool_name}' is not registered"
|
||||
)))
|
||||
};
|
||||
|
||||
if interactive {
|
||||
let _ = tx_event.send(Event::ResumeEvents).await;
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Handle a turn using the DeepSeek API.
|
||||
#[allow(clippy::too_many_lines)]
|
||||
/// Run the pre-request layered-context checkpoint (#159). Checks whether
|
||||
@@ -2148,7 +1849,9 @@ use context::{
|
||||
turn_response_headroom_tokens,
|
||||
};
|
||||
mod dispatch;
|
||||
mod lsp_hooks;
|
||||
mod tool_catalog;
|
||||
mod tool_execution;
|
||||
mod tool_setup;
|
||||
mod turn_loop;
|
||||
|
||||
@@ -2159,6 +1862,8 @@ use self::dispatch::{
|
||||
mcp_tool_is_read_only, parse_parallel_tool_calls, parse_tool_input,
|
||||
should_force_update_plan_first, should_parallelize_tool_batch, should_stop_after_plan_tool,
|
||||
};
|
||||
#[cfg(test)]
|
||||
use self::lsp_hooks::{edited_paths_for_tool, parse_patch_paths};
|
||||
use self::tool_catalog::{
|
||||
CODE_EXECUTION_TOOL_NAME, MULTI_TOOL_PARALLEL_NAME, REQUEST_USER_INPUT_NAME,
|
||||
active_tools_for_step, build_model_tool_catalog, ensure_advanced_tooling,
|
||||
@@ -2167,6 +1872,7 @@ use self::tool_catalog::{
|
||||
};
|
||||
#[cfg(test)]
|
||||
use self::tool_catalog::{TOOL_SEARCH_BM25_NAME, should_default_defer_tool};
|
||||
use self::tool_execution::emit_tool_audit;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
//! Post-edit LSP diagnostics hooks for engine tool execution.
|
||||
//!
|
||||
//! The turn loop only needs to ask "did a successful edit produce diagnostics?"
|
||||
//! This module owns the tool-input path extraction and the synthetic diagnostic
|
||||
//! message injection so the top-level engine module stays focused on session
|
||||
//! orchestration.
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::*;
|
||||
|
||||
/// #136: derive the file path(s) edited by a tool call. Returns the empty
|
||||
/// vec for tools that don't modify files. We intentionally only handle the
|
||||
/// three known edit tools — adding more (e.g. specialized refactor tools)
|
||||
/// is a one-line change here.
|
||||
pub(super) fn edited_paths_for_tool(tool_name: &str, input: &serde_json::Value) -> Vec<PathBuf> {
|
||||
match tool_name {
|
||||
"edit_file" | "write_file" => {
|
||||
if let Some(path) = input.get("path").and_then(|v| v.as_str()) {
|
||||
vec![PathBuf::from(path)]
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
"apply_patch" => {
|
||||
// `apply_patch` accepts either a `path` override or a list of
|
||||
// `files` (each `{path, content}`). We try both shapes.
|
||||
let mut out = Vec::new();
|
||||
if let Some(path) = input.get("path").and_then(|v| v.as_str()) {
|
||||
out.push(PathBuf::from(path));
|
||||
}
|
||||
if let Some(files) = input.get("files").and_then(|v| v.as_array()) {
|
||||
for entry in files {
|
||||
if let Some(path) = entry.get("path").and_then(|v| v.as_str()) {
|
||||
out.push(PathBuf::from(path));
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fallback: parse `---`/`+++` headers from a unified diff payload.
|
||||
if out.is_empty()
|
||||
&& let Some(patch) = input.get("patch").and_then(|v| v.as_str())
|
||||
{
|
||||
out.extend(parse_patch_paths(patch));
|
||||
}
|
||||
out
|
||||
}
|
||||
_ => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Lightweight parser for `+++ b/<path>` lines in a unified diff. Used as a
|
||||
/// fallback when `apply_patch` is invoked with raw `patch` text and no
|
||||
/// `path`/`files` override. We deliberately keep this dumb — the real
|
||||
/// `apply_patch` tool already validates the patch shape; we only need a
|
||||
/// best-effort hint for the LSP hook.
|
||||
pub(super) fn parse_patch_paths(patch: &str) -> Vec<PathBuf> {
|
||||
let mut out = Vec::new();
|
||||
for line in patch.lines() {
|
||||
if let Some(rest) = line.strip_prefix("+++ ") {
|
||||
let trimmed = rest.trim();
|
||||
// Strip leading `b/` per git diff conventions.
|
||||
let path = trimmed.strip_prefix("b/").unwrap_or(trimmed);
|
||||
// Skip `/dev/null` (deletion).
|
||||
if path == "/dev/null" {
|
||||
continue;
|
||||
}
|
||||
out.push(PathBuf::from(path));
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
impl Engine {
|
||||
/// #136: post-edit hook. Inspects the tool name + input, derives the
|
||||
/// edited file path, and asks the LSP manager for diagnostics. The
|
||||
/// rendered block is queued in `pending_lsp_blocks` and flushed to the
|
||||
/// session message stream just before the next API request. Failure is
|
||||
/// silent by design — a missing/crashing LSP server must never block
|
||||
/// the agent.
|
||||
pub(super) async fn run_post_edit_lsp_hook(
|
||||
&mut self,
|
||||
tool_name: &str,
|
||||
tool_input: &serde_json::Value,
|
||||
) {
|
||||
if !self.lsp_manager.config().enabled {
|
||||
return;
|
||||
}
|
||||
let paths = edited_paths_for_tool(tool_name, tool_input);
|
||||
for path in paths {
|
||||
let absolute = if path.is_absolute() {
|
||||
path.clone()
|
||||
} else {
|
||||
self.session.workspace.join(&path)
|
||||
};
|
||||
// Use a short edit-sequence based on the existing turn counter so
|
||||
// log output stays correlated even though we do not currently
|
||||
// batch by sequence.
|
||||
let seq = self.turn_counter;
|
||||
if let Some(block) = self.lsp_manager.diagnostics_for(&absolute, seq).await {
|
||||
self.pending_lsp_blocks.push(block);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Drain `pending_lsp_blocks` into a single synthetic user message so the
|
||||
/// model sees the diagnostics on its next request. Skips when nothing is
|
||||
/// pending. The message uses the standard `text` content block shape
|
||||
/// (the same shape as the post-tool steer messages) so we don't need to
|
||||
/// invent a new envelope.
|
||||
pub(super) async fn flush_pending_lsp_diagnostics(&mut self) {
|
||||
if self.pending_lsp_blocks.is_empty() {
|
||||
return;
|
||||
}
|
||||
let blocks = std::mem::take(&mut self.pending_lsp_blocks);
|
||||
let rendered = crate::lsp::render_blocks(&blocks);
|
||||
if rendered.is_empty() {
|
||||
return;
|
||||
}
|
||||
self.add_session_message(Message {
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentBlock::Text {
|
||||
text: rendered,
|
||||
cache_control: None,
|
||||
}],
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,197 @@
|
||||
//! Low-level tool execution helpers for the engine turn loop.
|
||||
//!
|
||||
//! This module keeps the mechanics of MCP dispatch, execution locking, and
|
||||
//! parallel-tool fanout out of `engine.rs`; the turn loop still owns planning,
|
||||
//! approval, and how tool results are written back into session state.
|
||||
|
||||
use std::{fs::OpenOptions, io::Write};
|
||||
|
||||
use super::*;
|
||||
|
||||
pub(super) fn emit_tool_audit(event: serde_json::Value) {
|
||||
let Some(path) = std::env::var_os("DEEPSEEK_TOOL_AUDIT_LOG") else {
|
||||
return;
|
||||
};
|
||||
let line = match serde_json::to_string(&event) {
|
||||
Ok(line) => line,
|
||||
Err(_) => return,
|
||||
};
|
||||
let path = PathBuf::from(path);
|
||||
if let Some(parent) = path.parent() {
|
||||
let _ = std::fs::create_dir_all(parent);
|
||||
}
|
||||
if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) {
|
||||
let _ = writeln!(file, "{line}");
|
||||
}
|
||||
}
|
||||
|
||||
impl Engine {
|
||||
pub(super) async fn execute_mcp_tool_with_pool(
|
||||
pool: Arc<AsyncMutex<McpPool>>,
|
||||
name: &str,
|
||||
input: serde_json::Value,
|
||||
) -> Result<ToolResult, ToolError> {
|
||||
let mut pool = pool.lock().await;
|
||||
let result = pool
|
||||
.call_tool(name, input)
|
||||
.await
|
||||
.map_err(|e| ToolError::execution_failed(format!("MCP tool failed: {e}")))?;
|
||||
let content = serde_json::to_string_pretty(&result).unwrap_or_else(|_| result.to_string());
|
||||
Ok(ToolResult::success(content))
|
||||
}
|
||||
|
||||
pub(super) async fn execute_parallel_tool(
|
||||
&mut self,
|
||||
input: serde_json::Value,
|
||||
tool_registry: Option<&crate::tools::ToolRegistry>,
|
||||
tool_exec_lock: Arc<RwLock<()>>,
|
||||
) -> Result<ToolResult, ToolError> {
|
||||
let calls = parse_parallel_tool_calls(&input)?;
|
||||
let mcp_pool = if calls.iter().any(|(tool, _)| McpPool::is_mcp_tool(tool)) {
|
||||
Some(self.ensure_mcp_pool().await?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let Some(registry) = tool_registry else {
|
||||
return Err(ToolError::not_available(
|
||||
"tool registry unavailable for multi_tool_use.parallel",
|
||||
));
|
||||
};
|
||||
|
||||
let mut tasks = FuturesUnordered::new();
|
||||
for (tool_name, tool_input) in calls {
|
||||
if tool_name == MULTI_TOOL_PARALLEL_NAME {
|
||||
return Err(ToolError::invalid_input(
|
||||
"multi_tool_use.parallel cannot call itself",
|
||||
));
|
||||
}
|
||||
if McpPool::is_mcp_tool(&tool_name) {
|
||||
if !mcp_tool_is_parallel_safe(&tool_name) {
|
||||
return Err(ToolError::invalid_input(format!(
|
||||
"Tool '{tool_name}' is an MCP tool and cannot run in parallel. \
|
||||
Allowed MCP tools: list_mcp_resources, list_mcp_resource_templates, \
|
||||
mcp_read_resource, read_mcp_resource, mcp_get_prompt."
|
||||
)));
|
||||
}
|
||||
} else {
|
||||
let Some(spec) = registry.get(&tool_name) else {
|
||||
return Err(ToolError::not_available(format!(
|
||||
"tool '{tool_name}' is not registered"
|
||||
)));
|
||||
};
|
||||
if !spec.is_read_only() {
|
||||
return Err(ToolError::invalid_input(format!(
|
||||
"Tool '{tool_name}' is not read-only and cannot run in parallel"
|
||||
)));
|
||||
}
|
||||
if spec.approval_requirement() != ApprovalRequirement::Auto {
|
||||
return Err(ToolError::invalid_input(format!(
|
||||
"Tool '{tool_name}' requires approval and cannot run in parallel"
|
||||
)));
|
||||
}
|
||||
if !spec.supports_parallel() {
|
||||
return Err(ToolError::invalid_input(format!(
|
||||
"Tool '{tool_name}' does not support parallel execution"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let registry_ref = registry;
|
||||
let lock = tool_exec_lock.clone();
|
||||
let tx_event = self.tx_event.clone();
|
||||
let mcp_pool = mcp_pool.clone();
|
||||
tasks.push(async move {
|
||||
let result = Engine::execute_tool_with_lock(
|
||||
lock,
|
||||
true,
|
||||
false,
|
||||
tx_event,
|
||||
tool_name.clone(),
|
||||
tool_input.clone(),
|
||||
Some(registry_ref),
|
||||
mcp_pool,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
(tool_name, result)
|
||||
});
|
||||
}
|
||||
|
||||
let mut results = Vec::new();
|
||||
while let Some((tool_name, result)) = tasks.next().await {
|
||||
match result {
|
||||
Ok(output) => {
|
||||
let mut error = None;
|
||||
if !output.success {
|
||||
error = Some(output.content.clone());
|
||||
}
|
||||
results.push(ParallelToolResultEntry {
|
||||
tool_name,
|
||||
success: output.success,
|
||||
content: output.content,
|
||||
error,
|
||||
});
|
||||
}
|
||||
Err(err) => {
|
||||
let message = format!("{err}");
|
||||
results.push(ParallelToolResultEntry {
|
||||
tool_name,
|
||||
success: false,
|
||||
content: format!("Error: {message}"),
|
||||
error: Some(message),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ToolResult::json(&ParallelToolResult { results })
|
||||
.map_err(|e| ToolError::execution_failed(e.to_string()))
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(super) async fn execute_tool_with_lock(
|
||||
lock: Arc<RwLock<()>>,
|
||||
supports_parallel: bool,
|
||||
interactive: bool,
|
||||
tx_event: mpsc::Sender<Event>,
|
||||
tool_name: String,
|
||||
tool_input: serde_json::Value,
|
||||
registry: Option<&crate::tools::ToolRegistry>,
|
||||
mcp_pool: Option<Arc<AsyncMutex<McpPool>>>,
|
||||
context_override: Option<crate::tools::ToolContext>,
|
||||
) -> Result<ToolResult, ToolError> {
|
||||
let _guard = if supports_parallel {
|
||||
ToolExecGuard::Read(lock.read().await)
|
||||
} else {
|
||||
ToolExecGuard::Write(lock.write().await)
|
||||
};
|
||||
|
||||
if interactive {
|
||||
let _ = tx_event.send(Event::PauseEvents).await;
|
||||
}
|
||||
|
||||
let result = if McpPool::is_mcp_tool(&tool_name) {
|
||||
if let Some(pool) = mcp_pool {
|
||||
Engine::execute_mcp_tool_with_pool(pool, &tool_name, tool_input).await
|
||||
} else {
|
||||
Err(ToolError::not_available(format!(
|
||||
"tool '{tool_name}' is not registered"
|
||||
)))
|
||||
}
|
||||
} else if let Some(registry) = registry {
|
||||
registry
|
||||
.execute_full_with_context(&tool_name, tool_input, context_override.as_ref())
|
||||
.await
|
||||
} else {
|
||||
Err(ToolError::not_available(format!(
|
||||
"tool '{tool_name}' is not registered"
|
||||
)))
|
||||
};
|
||||
|
||||
if interactive {
|
||||
let _ = tx_event.send(Event::ResumeEvents).await;
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user