From f8b3c1e48137351ef56c35d7a186c70ec5ea6bee Mon Sep 17 00:00:00 2001 From: Hunter B Date: Sat, 13 Jun 2026 08:35:29 -0700 Subject: [PATCH] feat(engine): parallelize read-only shell calls Closes #2983. Allow exec_shell to opt into the existing parallel tool lane when its concrete input is a conservative read-only shell command. The whitelist covers direct git/list/search/read commands plus simple bash/sh/zsh -c wrappers whose inner command passes the same classifier; anything effectful, interactive, redirected, piped, backgrounded, or stdin-fed remains serial and approval-gated. Parallel exec_shell fanout is capped at four workers, and explicit multi_tool_use.parallel now preserves request-order results while using the same input-aware checks. Co-Authored-By: Claude Opus 4.8 (1M context) --- crates/tui/src/command_safety.rs | 120 +++++++++++++++++++ crates/tui/src/core/engine.rs | 2 + crates/tui/src/core/engine/tests.rs | 42 +++++++ crates/tui/src/core/engine/tool_execution.rs | 119 +++++++++++------- crates/tui/src/core/engine/turn_loop.rs | 15 ++- crates/tui/src/tools/shell.rs | 40 ++++++- crates/tui/src/tools/shell/tests.rs | 35 ++++++ crates/tui/src/tools/spec.rs | 15 +++ 8 files changed, 343 insertions(+), 45 deletions(-) diff --git a/crates/tui/src/command_safety.rs b/crates/tui/src/command_safety.rs index a2e1fc79..91ad4b83 100644 --- a/crates/tui/src/command_safety.rs +++ b/crates/tui/src/command_safety.rs @@ -355,6 +355,90 @@ pub fn prefix_allow_matches(pattern: &str, command: &str) -> bool { command_norm == pattern_norm || command_norm.starts_with(&format!("{pattern_norm} ")) } +const PARALLEL_READONLY_PREFIXES: &[&str] = &[ + "git status", + "git log", + "git diff", + "git show", + "git ls-files", + "git blame", + "git grep", + "ls", + "pwd", + "cat", + "head", + "tail", + "wc", + "which", + "stat", + "file", + "du", + "df", + "grep", + "rg", + "fd", +]; + +/// Return `true` when a shell command is safe to auto-approve and run in a +/// parallel read-only chunk. +pub fn is_parallel_readonly_command(command: &str) -> bool { + let trimmed = command.trim(); + if trimmed.is_empty() { + return false; + } + if trimmed.contains("$(") + || trimmed + .chars() + .any(|ch| matches!(ch, '\n' | '\r' | ';' | '&' | '|' | '>' | '<' | '`')) + { + return false; + } + + let tokens = shell_words(trimmed); + let Some(start) = primary_token_index(&tokens) else { + return false; + }; + let command_tokens = tokens[start..].to_vec(); + + if let Some(inner_command) = readonly_shell_wrapper_inner_command(&command_tokens) { + return is_parallel_readonly_command(inner_command); + } + + let command_refs = command_tokens + .iter() + .map(String::as_str) + .collect::>(); + let canonical = classify_command(&command_refs); + if canonical == "tail" + && command_refs.iter().skip(1).any(|token| { + *token == "-f" + || *token == "-F" + || *token == "--follow" + || token.starts_with("--follow=") + }) + { + return false; + } + + PARALLEL_READONLY_PREFIXES + .iter() + .any(|prefix| *prefix == canonical) +} + +fn readonly_shell_wrapper_inner_command(tokens: &[String]) -> Option<&str> { + let shell = tokens.first()?.as_str(); + if !matches!(shell, "bash" | "sh" | "zsh") { + return None; + } + if tokens.len() != 3 { + return None; + } + if !matches!(tokens[1].as_str(), "-c" | "-lc") { + return None; + } + Some(tokens[2].as_str()) +} + /// Safety classification of a command #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum SafetyLevel { @@ -1037,6 +1121,42 @@ mod tests { ); } + #[test] + fn parallel_readonly_command_classifier_is_strict() { + for command in [ + "git status -s", + "git log --oneline -5", + "rg foo crates/", + "ls -la", + "cat Cargo.toml", + "bash -lc 'git status -s'", + "sh -c 'rg foo crates/'", + ] { + assert!( + is_parallel_readonly_command(command), + "{command} should be parallel read-only" + ); + } + + for command in [ + "git status && rm -rf /", + "cat a > b", + "git push", + "cargo build", + "tail -f log", + "rg foo | head", + "find . -delete", + "sleep 5 &", + "bash -lc 'git status && rm -rf /'", + "bash -lc 'rg foo | head'", + ] { + assert!( + !is_parallel_readonly_command(command), + "{command} should not be parallel read-only" + ); + } + } + #[test] fn test_workspace_safe_commands() { assert_eq!( diff --git a/crates/tui/src/core/engine.rs b/crates/tui/src/core/engine.rs index 0c8e4cbb..6b81745e 100644 --- a/crates/tui/src/core/engine.rs +++ b/crates/tui/src/core/engine.rs @@ -2849,6 +2849,8 @@ mod tool_setup; mod turn_loop; pub(crate) use token_estimate_cache::TokenEstimateCache; +pub(super) const MAX_PARALLEL_SHELL_EXEC: usize = 4; + pub(crate) fn default_active_native_tool_names() -> &'static [&'static str] { tool_catalog::DEFAULT_ACTIVE_NATIVE_TOOLS } diff --git a/crates/tui/src/core/engine/tests.rs b/crates/tui/src/core/engine/tests.rs index 042fc718..014c72d8 100644 --- a/crates/tui/src/core/engine/tests.rs +++ b/crates/tui/src/core/engine/tests.rs @@ -488,6 +488,48 @@ fn tool_execution_batches_use_serial_barriers() { } } +#[test] +fn shell_readonly_plans_batch_around_serial_barrier() { + let mut shell_a = make_plan_at(0, true, true, false, false); + shell_a.name = "exec_shell".to_string(); + shell_a.input = json!({"command": "git status -s"}); + let mut shell_b = make_plan_at(1, true, true, false, false); + shell_b.name = "exec_shell".to_string(); + shell_b.input = json!({"command": "git log --oneline -5"}); + let mut write_shell = make_plan_at(2, false, false, true, false); + write_shell.name = "exec_shell".to_string(); + write_shell.input = json!({"command": "cargo build"}); + let mut shell_c = make_plan_at(3, true, true, false, false); + shell_c.name = "exec_shell".to_string(); + shell_c.input = json!({"command": "bash -lc 'rg TODO crates/tui/src/core'"}); + + let batches = plan_tool_execution_batches(vec![shell_a, shell_b, write_shell, shell_c]); + assert_eq!(batches.len(), 3); + + match &batches[0] { + ToolExecutionBatch::Parallel(plans) => { + assert_eq!( + plans.iter().map(|plan| plan.index).collect::>(), + vec![0, 1] + ); + } + ToolExecutionBatch::Serial(_) => panic!("first batch should be parallel"), + } + match &batches[1] { + ToolExecutionBatch::Serial(plan) => assert_eq!(plan.index, 2), + ToolExecutionBatch::Parallel(_) => panic!("write shell should be a serial barrier"), + } + match &batches[2] { + ToolExecutionBatch::Parallel(plans) => { + assert_eq!( + plans.iter().map(|plan| plan.index).collect::>(), + vec![3] + ); + } + ToolExecutionBatch::Serial(_) => panic!("third batch should be parallel"), + } +} + #[test] fn successful_update_plan_ends_plan_mode_turn_immediately() { assert!(should_stop_after_plan_tool( diff --git a/crates/tui/src/core/engine/tool_execution.rs b/crates/tui/src/core/engine/tool_execution.rs index ae5e1b5a..7022bc87 100644 --- a/crates/tui/src/core/engine/tool_execution.rs +++ b/crates/tui/src/core/engine/tool_execution.rs @@ -185,8 +185,10 @@ impl Engine { )); }; + let result_count = calls.len(); let mut tasks = FuturesUnordered::new(); - for (tool_name, tool_input) in calls { + let shell_permits = Arc::new(tokio::sync::Semaphore::new(MAX_PARALLEL_SHELL_EXEC)); + for (index, (tool_name, tool_input)) in calls.into_iter().enumerate() { if tool_name == MULTI_TOOL_PARALLEL_NAME { return Err(ToolError::invalid_input( "multi_tool_use.parallel cannot call itself", @@ -206,17 +208,17 @@ impl Engine { "tool '{tool_name}' is not registered" ))); }; - if !spec.is_read_only() { + if !spec.is_read_only_for(&tool_input) { return Err(ToolError::invalid_input(format!( "Tool '{tool_name}' is not read-only and cannot run in parallel" ))); } - if spec.approval_requirement() != ApprovalRequirement::Auto { + if spec.approval_requirement_for(&tool_input) != ApprovalRequirement::Auto { return Err(ToolError::invalid_input(format!( "Tool '{tool_name}' requires approval and cannot run in parallel" ))); } - if !spec.supports_parallel() { + if !spec.supports_parallel_for(&tool_input) { return Err(ToolError::invalid_input(format!( "Tool '{tool_name}' does not support parallel execution" ))); @@ -227,7 +229,13 @@ impl Engine { let lock = tool_exec_lock.clone(); let tx_event = self.tx_event.clone(); let mcp_pool = mcp_pool.clone(); + let shell_permits = shell_permits.clone(); tasks.push(async move { + let _shell_permit = if tool_name == "exec_shell" { + shell_permits.acquire_owned().await.ok() + } else { + None + }; let result = Engine::execute_tool_with_lock( lock, true, @@ -240,36 +248,39 @@ impl Engine { None, ) .await; - (tool_name, result) + (index, tool_name, result) }); } - let mut results = Vec::new(); - while let Some((tool_name, result)) = tasks.next().await { - match result { + let mut results: Vec> = Vec::with_capacity(result_count); + results.resize_with(result_count, || None); + while let Some((index, tool_name, result)) = tasks.next().await { + let entry = match result { Ok(output) => { let mut error = None; if !output.success { error = Some(output.content.clone()); } - results.push(ParallelToolResultEntry { + ParallelToolResultEntry { tool_name, success: output.success, content: output.content, error, - }); + } } Err(err) => { let message = format!("{err}"); - results.push(ParallelToolResultEntry { + ParallelToolResultEntry { tool_name, success: false, content: format!("Error: {message}"), error: Some(message), - }); + } } - } + }; + results[index] = Some(entry); } + let results = results.into_iter().flatten().collect(); ToolResult::json(&ParallelToolResult { results }) .map_err(|e| ToolError::execution_failed(e.to_string())) @@ -381,7 +392,7 @@ impl Engine { mod tests { use super::*; use serde_json::json; - use std::{sync::Mutex, time::Duration}; + use std::{ffi::OsString, path::Path, sync::Mutex, time::Duration}; /// Tests in this module mutate `DEEPSEEK_TOOL_AUDIT_LOG` which is /// process-global; serialise through this guard so the parallel @@ -392,6 +403,43 @@ mod tests { AUDIT_TEST_GUARD.lock().unwrap_or_else(|e| e.into_inner()) } + struct AuditEnvGuard { + previous: Option, + } + + impl AuditEnvGuard { + fn set(path: &Path) -> Self { + let previous = std::env::var_os("DEEPSEEK_TOOL_AUDIT_LOG"); + // SAFETY: serialised by the guard above. + unsafe { + std::env::set_var("DEEPSEEK_TOOL_AUDIT_LOG", path); + } + Self { previous } + } + + fn unset() -> Self { + let previous = std::env::var_os("DEEPSEEK_TOOL_AUDIT_LOG"); + // SAFETY: serialised by the guard above. + unsafe { + std::env::remove_var("DEEPSEEK_TOOL_AUDIT_LOG"); + } + Self { previous } + } + } + + impl Drop for AuditEnvGuard { + fn drop(&mut self) { + // SAFETY: callers hold AUDIT_TEST_GUARD for this guard's lifetime. + unsafe { + if let Some(previous) = self.previous.take() { + std::env::set_var("DEEPSEEK_TOOL_AUDIT_LOG", previous); + } else { + std::env::remove_var("DEEPSEEK_TOOL_AUDIT_LOG"); + } + } + } + } + #[tokio::test] async fn terminal_guard_queues_resume_when_event_channel_is_full() { let (tx, mut rx) = mpsc::channel(1); @@ -443,29 +491,35 @@ mod tests { let _g = audit_test_guard(); let tmp = tempfile::tempdir().expect("tempdir"); let path = tmp.path().join("audit.log"); - // SAFETY: serialised by the guard above. - unsafe { - std::env::set_var("DEEPSEEK_TOOL_AUDIT_LOG", &path); - } + let _env = AuditEnvGuard::set(&path); + let marker = path.display().to_string(); emit_tool_audit(json!({ "event": "tool.spillover", + "test_marker": marker, "tool_id": "call-abc", "tool_name": "exec_shell", "path": "/tmp/foo.txt", })); emit_tool_audit(json!({ "event": "tool.result", + "test_marker": marker, "tool_id": "call-xyz", "success": true, })); let body = std::fs::read_to_string(&path).expect("audit log written"); - let lines: Vec<&str> = body.lines().collect(); - assert_eq!(lines.len(), 2, "two emits → two lines"); + let entries: Vec = body + .lines() + .map(|line| serde_json::from_str(line).expect("audit line is JSON")) + .filter(|entry: &serde_json::Value| { + entry.get("test_marker").and_then(|v| v.as_str()) == Some(marker.as_str()) + }) + .collect(); + assert_eq!(entries.len(), 2, "two marked emits -> two lines"); // Each line round-trips as JSON, has the expected event key. - let first: serde_json::Value = serde_json::from_str(lines[0]).expect("first line is JSON"); + let first = &entries[0]; assert_eq!( first.get("event").and_then(|v| v.as_str()), Some("tool.spillover") @@ -475,26 +529,17 @@ mod tests { Some("call-abc") ); - let second: serde_json::Value = - serde_json::from_str(lines[1]).expect("second line is JSON"); + let second = &entries[1]; assert_eq!( second.get("event").and_then(|v| v.as_str()), Some("tool.result") ); - - // SAFETY: cleanup under the guard. - unsafe { - std::env::remove_var("DEEPSEEK_TOOL_AUDIT_LOG"); - } } #[test] fn emit_tool_audit_is_noop_when_env_var_unset() { let _g = audit_test_guard(); - // SAFETY: serialised by the guard above. - unsafe { - std::env::remove_var("DEEPSEEK_TOOL_AUDIT_LOG"); - } + let _env = AuditEnvGuard::unset(); // Should not panic and should not create any file. We can't // assert "no file written" without knowing where one might be // written, but the contract is "do nothing", which we verify @@ -510,16 +555,8 @@ mod tests { // Path with a parent that doesn't exist yet — the writer // should create it. let nested = tmp.path().join("nested").join("dir").join("audit.log"); - // SAFETY: serialised by the guard above. - unsafe { - std::env::set_var("DEEPSEEK_TOOL_AUDIT_LOG", &nested); - } + let _env = AuditEnvGuard::set(&nested); emit_tool_audit(json!({"event": "test"})); assert!(nested.exists(), "writer should mkdir -p the parent chain"); - - // SAFETY: cleanup under the guard. - unsafe { - std::env::remove_var("DEEPSEEK_TOOL_AUDIT_LOG"); - } } } diff --git a/crates/tui/src/core/engine/turn_loop.rs b/crates/tui/src/core/engine/turn_loop.rs index 267a89b0..af72e7e1 100644 --- a/crates/tui/src/core/engine/turn_loop.rs +++ b/crates/tui/src/core/engine/turn_loop.rs @@ -1530,10 +1530,11 @@ impl Engine { } else if let Some(registry) = tool_registry && let Some(spec) = registry.get(&tool_name) { - approval_required = spec.approval_requirement() != ApprovalRequirement::Auto; + approval_required = + spec.approval_requirement_for(&tool_input) != ApprovalRequirement::Auto; approval_description = spec.description().to_string(); - supports_parallel = spec.supports_parallel(); - read_only = spec.is_read_only(); + supports_parallel = spec.supports_parallel_for(&tool_input); + read_only = spec.is_read_only_for(&tool_input); } else if tool_name == CODE_EXECUTION_TOOL_NAME { approval_required = true; approval_description = @@ -1666,6 +1667,8 @@ impl Engine { if parallel_allowed { let mut tool_tasks = FuturesUnordered::new(); + let shell_permits = + Arc::new(tokio::sync::Semaphore::new(MAX_PARALLEL_SHELL_EXEC)); for plan in plans { if let Some(result) = plan.guard_result.clone() { let result = Ok(result); @@ -1704,8 +1707,14 @@ impl Engine { let tx_event = self.tx_event.clone(); let session_id = self.session.id.clone(); let started_at = Instant::now(); + let shell_permits = shell_permits.clone(); tool_tasks.push(async move { + let _shell_permit = if plan.name == "exec_shell" { + shell_permits.acquire_owned().await.ok() + } else { + None + }; let mut result = Engine::execute_tool_with_lock( lock, plan.supports_parallel, diff --git a/crates/tui/src/tools/shell.rs b/crates/tui/src/tools/shell.rs index c949dd40..6cb933d6 100644 --- a/crates/tui/src/tools/shell.rs +++ b/crates/tui/src/tools/shell.rs @@ -1768,7 +1768,9 @@ pub fn new_shared_shell_manager(workspace: PathBuf) -> SharedShellManager { // === ToolSpec Implementations === -use crate::command_safety::{SafetyLevel, analyze_command, extract_primary_command}; +use crate::command_safety::{ + SafetyLevel, analyze_command, extract_primary_command, is_parallel_readonly_command, +}; use crate::execpolicy::{ExecPolicyDecision, load_default_policy}; use crate::features::Feature; use crate::tools::cargo_failure_summary::summarize_cargo_failure; @@ -1913,6 +1915,26 @@ fn shell_network_restricted_hint<'a>( } } +fn exec_shell_input_is_parallel_readonly(input: &serde_json::Value) -> bool { + let Some(command) = input.get("command").and_then(serde_json::Value::as_str) else { + return false; + }; + if ["background", "interactive", "tty", "combined_output"] + .iter() + .any(|key| input.get(*key).and_then(serde_json::Value::as_bool) == Some(true)) + { + return false; + } + if ["stdin", "input", "data"] + .iter() + .any(|key| input.get(*key).is_some()) + { + return false; + } + + is_parallel_readonly_command(command) +} + async fn execute_foreground_via_background( context: &ToolContext, command: &str, @@ -2061,6 +2083,22 @@ impl ToolSpec for ExecShellTool { ApprovalRequirement::Required } + fn approval_requirement_for(&self, input: &serde_json::Value) -> ApprovalRequirement { + if exec_shell_input_is_parallel_readonly(input) { + ApprovalRequirement::Auto + } else { + self.approval_requirement() + } + } + + fn is_read_only_for(&self, input: &serde_json::Value) -> bool { + exec_shell_input_is_parallel_readonly(input) + } + + fn supports_parallel_for(&self, input: &serde_json::Value) -> bool { + exec_shell_input_is_parallel_readonly(input) + } + async fn execute( &self, input: serde_json::Value, diff --git a/crates/tui/src/tools/shell/tests.rs b/crates/tui/src/tools/shell/tests.rs index 8da55eb8..9ed59820 100644 --- a/crates/tui/src/tools/shell/tests.rs +++ b/crates/tui/src/tools/shell/tests.rs @@ -148,6 +148,41 @@ fn wait_for_completed_shell(manager: &mut ShellManager, task_id: &str) -> ShellR } } +#[test] +fn exec_shell_parallel_flags_are_input_aware() { + let tool = ExecShellTool; + let readonly = json!({"command": "git status -s"}); + assert!(tool.supports_parallel_for(&readonly)); + assert!(tool.is_read_only_for(&readonly)); + assert_eq!( + tool.approval_requirement_for(&readonly), + ApprovalRequirement::Auto + ); + + let bash_readonly = json!({"command": "bash -lc 'rg TODO crates/tui/src/tools'"}); + assert!(tool.supports_parallel_for(&bash_readonly)); + assert!(tool.is_read_only_for(&bash_readonly)); + assert_eq!( + tool.approval_requirement_for(&bash_readonly), + ApprovalRequirement::Auto + ); + + for input in [ + json!({"command": "git status -s", "background": true}), + json!({"command": "git status -s", "stdin": ""}), + json!({"command": "cargo build"}), + json!({"command": "bash -lc 'rg TODO crates | head'"}), + ] { + assert!(!tool.supports_parallel_for(&input), "{input:?}"); + assert!(!tool.is_read_only_for(&input), "{input:?}"); + assert_eq!( + tool.approval_requirement_for(&input), + ApprovalRequirement::Required, + "{input:?}" + ); + } +} + #[test] #[cfg(unix)] fn shell_execution_scrubs_parent_env_and_keeps_explicit_env() { diff --git a/crates/tui/src/tools/spec.rs b/crates/tui/src/tools/spec.rs index 63ac165b..3c7f95dc 100644 --- a/crates/tui/src/tools/spec.rs +++ b/crates/tui/src/tools/spec.rs @@ -643,6 +643,11 @@ pub trait ToolSpec: Send + Sync { } } + /// Returns the approval requirement for this concrete tool input. + fn approval_requirement_for(&self, _input: &Value) -> ApprovalRequirement { + self.approval_requirement() + } + /// Returns whether this tool is sandboxable. #[allow(dead_code)] fn is_sandboxable(&self) -> bool { @@ -657,11 +662,21 @@ pub trait ToolSpec: Send + Sync { && !caps.contains(&ToolCapability::ExecutesCode) } + /// Returns whether this concrete tool input is read-only. + fn is_read_only_for(&self, _input: &Value) -> bool { + self.is_read_only() + } + /// Returns whether this tool can be executed in parallel with others. fn supports_parallel(&self) -> bool { false } + /// Returns whether this concrete tool input can run in parallel. + fn supports_parallel_for(&self, _input: &Value) -> bool { + self.supports_parallel() + } + /// Returns whether this tool should be excluded from the model-visible /// tool catalog (deferred loading). Tools marked `true` are registered /// but not sent to the model until explicitly activated via tool search.