From 7a06915b0b2b994e1636d3c0742625511577ab55 Mon Sep 17 00:00:00 2001 From: Hunter Bown Date: Mon, 27 Apr 2026 19:40:49 -0500 Subject: [PATCH] feat(tools): approval cache + error taxonomy + defer_loading + command safety trim MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add fingerprint-based ApprovalCache with call-specific keys (patch hash, shell prefix, URL host) instead of tool-name keys. Session-keyed. - Add ClientError/StreamError enums in error_taxonomy.rs with Retry-After header support. Wire ErrorEnvelope into Event::Error. - Add defer_loading() default method to ToolSpec trait. McpToolAdapter returns true for non-discovery MCP tools. - Add with_mcp_tools() on ToolRegistryBuilder for unified pipeline. - Trim DANGEROUS_PATTERNS in command_safety.rs from 25→5 entries. Only rm -rf and fork bomb remain; chaining/substitution downgraded to RequiresApproval. Matches Codex's restraint. - ApprovalRequired events now carry approval_key for UI caching. TODO_BACKEND.md §1, §5 --- crates/tui/src/command_safety.rs | 62 ++---- crates/tui/src/core/engine.rs | 7 + crates/tui/src/core/events.rs | 24 ++- crates/tui/src/error_taxonomy.rs | 186 ++++++++++++++++ crates/tui/src/main.rs | 6 +- crates/tui/src/runtime_threads.rs | 8 +- crates/tui/src/tools/approval_cache.rs | 281 +++++++++++++++++++++++++ crates/tui/src/tools/mod.rs | 1 + crates/tui/src/tools/registry.rs | 100 ++++++++- crates/tui/src/tools/spec.rs | 7 + crates/tui/src/tui/approval.rs | 34 +-- crates/tui/src/tui/ui.rs | 17 +- crates/tui/src/tui/views/mod.rs | 2 + 13 files changed, 668 insertions(+), 67 deletions(-) create mode 100644 crates/tui/src/tools/approval_cache.rs diff --git a/crates/tui/src/command_safety.rs b/crates/tui/src/command_safety.rs index d38df145..26371f0e 100644 --- a/crates/tui/src/command_safety.rs +++ b/crates/tui/src/command_safety.rs @@ -170,7 +170,13 @@ const WORKSPACE_SAFE_COMMANDS: &[&str] = &[ "ninja", ]; -/// Dangerous command patterns that should be blocked or warned +/// Dangerous command patterns that should be blocked or warned. +/// +/// Codex flags only explicit `rm -f*` / `rm -rf` patterns. We match +/// that restraint — aggressive patterns for shutdown, reboot, killall, +/// docker rm, chown, etc. have been removed because they generate +/// unnecessary approval prompts for routine operations the user can +/// still veto via the approval dialog. const DANGEROUS_PATTERNS: &[(&str, &str)] = &[ ("rm -rf /", "Attempts to recursively delete root filesystem"), ( @@ -182,37 +188,7 @@ const DANGEROUS_PATTERNS: &[(&str, &str)] = &[ "rm -rf $HOME", "Attempts to recursively delete home directory", ), - (":(){ :|:& };:", "Fork bomb - will crash the system"), - ("dd if=/dev/zero of=/dev/", "Will overwrite disk device"), - ("mkfs.", "Will format a filesystem"), - ("> /dev/sd", "Will overwrite disk device"), - ("chmod -R 777 /", "Dangerous permission change on root"), - ( - "chown -R", - "Recursive ownership change - potentially dangerous", - ), - ("curl | sh", "Piping remote script directly to shell"), - ("curl | bash", "Piping remote script directly to shell"), - ("wget -O - | sh", "Piping remote script directly to shell"), - ("sudo rm -rf", "Privileged recursive deletion"), - ("sudo dd", "Privileged disk operation"), - ("shutdown", "System shutdown command"), - ("reboot", "System reboot command"), - ("halt", "System halt command"), - ("poweroff", "System poweroff command"), - ("init 0", "System shutdown via init"), - ("init 6", "System reboot via init"), - ("kill -9 1", "Killing init process"), - ("killall", "Killing processes by name"), - ("pkill", "Killing processes by pattern"), - ( - "docker rm -f $(docker ps -aq)", - "Removing all Docker containers", - ), - ("docker system prune -a", "Removing all Docker data"), - (":(){:|:&};:", "Fork bomb variant"), - ("mv /* ", "Moving root filesystem contents"), - ("cat /dev/urandom > /dev/", "Writing random data to device"), + (":(){ :|:& };:", "Fork bomb — will crash the system"), ]; /// Commands that require elevated privileges @@ -256,28 +232,34 @@ pub fn analyze_command(command: &str) -> SafetyAnalysis { } if command.contains("&&") || command.contains("||") || command.contains(';') { - // Chains of known-safe commands (cargo/git/zig/npm/etc.) are routine - // for build+test workflows and should not be hard-blocked. Escalate to - // RequiresApproval so the user still has the chance to deny in - // non-trusted modes; YOLO/auto-approve passes through. + // Chains of known-safe commands (cargo/git/zig/npm/etc.) are + // routine for build+test workflows. Instead of hard-blocking, + // escalate to RequiresApproval so the user can still deny in + // non-trusted modes. YOLO/auto-approve flows pass through. if all_segments_known_safe(command) { return SafetyAnalysis::requires_approval( command, vec!["Command chains known-safe segments (cargo/git/etc.)".to_string()], ); } - return SafetyAnalysis::dangerous( + // Unknown chains escalate to RequiresApproval instead of + // Dangerous — the user can still deny them. Codex only blocks + // explicit `rm -rf` patterns (above) and lets the user decide + // on everything else. + return SafetyAnalysis::requires_approval( command, vec!["Command chaining detected".to_string()], - vec!["Run commands separately to reduce risk".to_string()], ); } if command.contains("`") || command.contains("$(") { - return SafetyAnalysis::dangerous( + // Substitution is a common shell pattern (e.g., `cargo test + // $(cargo test --list | head -1)` or `echo $(date)`). Codex + // doesn't block it; escalate to approval so the user can + // inspect, but don't hard-block. + return SafetyAnalysis::requires_approval( command, vec!["Command substitution detected".to_string()], - vec!["Avoid shell substitutions in exec_shell".to_string()], ); } diff --git a/crates/tui/src/core/engine.rs b/crates/tui/src/core/engine.rs index 76dc56a5..2db1cb7d 100644 --- a/crates/tui/src/core/engine.rs +++ b/crates/tui/src/core/engine.rs @@ -3298,12 +3298,19 @@ impl Engine { "tool_id": tool_id.clone(), "tool_name": tool_name.clone(), })); + let approval_key = + crate::tools::approval_cache::build_approval_key( + &tool_name, + &tool_input, + ) + .0; let _ = self .tx_event .send(Event::ApprovalRequired { id: tool_id.clone(), tool_name: tool_name.clone(), description: plan.approval_description.clone(), + approval_key, }) .await; diff --git a/crates/tui/src/core/events.rs b/crates/tui/src/core/events.rs index 29e4698e..a2474868 100644 --- a/crates/tui/src/core/events.rs +++ b/crates/tui/src/core/events.rs @@ -8,6 +8,7 @@ use std::path::PathBuf; use serde_json::Value; use crate::core::coherence::CoherenceState; +use crate::error_taxonomy::ErrorEnvelope; use crate::models::{Message, SystemPrompt, Usage}; use crate::tools::spec::{ToolError, ToolResult}; use crate::tools::subagent::SubAgentResult; @@ -183,7 +184,7 @@ pub enum Event { // === System Events === /// An error occurred Error { - message: String, + envelope: ErrorEnvelope, #[allow(dead_code)] recoverable: bool, }, @@ -202,6 +203,8 @@ pub enum Event { id: String, tool_name: String, description: String, + /// Fingerprint key for per‑call approval caching (§5.A). + approval_key: String, }, /// Request user input for a tool call @@ -237,10 +240,25 @@ pub enum Event { } impl Event { - /// Create a new error event + /// Create a new error event with a categorized envelope. pub fn error(message: impl Into, recoverable: bool) -> Self { + let envelope = ErrorEnvelope::new( + crate::error_taxonomy::ErrorCategory::Internal, + crate::error_taxonomy::ErrorSeverity::Error, + recoverable, + "event_error", + message, + ); Event::Error { - message: message.into(), + envelope, + recoverable, + } + } + + /// Create an error event from a pre-built `ErrorEnvelope`. + pub fn error_with_envelope(envelope: ErrorEnvelope, recoverable: bool) -> Self { + Event::Error { + envelope, recoverable, } } diff --git a/crates/tui/src/error_taxonomy.rs b/crates/tui/src/error_taxonomy.rs index 75b4d9c8..724078cd 100644 --- a/crates/tui/src/error_taxonomy.rs +++ b/crates/tui/src/error_taxonomy.rs @@ -1,5 +1,6 @@ //! Shared error taxonomy across client, tools, runtime, and UI. use std::fmt; +use std::time::Duration; use crate::llm_client::LlmError; use crate::tools::spec::ToolError; @@ -299,3 +300,188 @@ impl From for ErrorEnvelope { } } } + +/// Client‑side error wrapper surfaced to the UI. +/// +/// Carries a full `ErrorEnvelope` so the TUI can render category‑specific +/// styling instead of a generic `Event::Error { message, recoverable }`. +#[derive(Debug, Clone)] +pub enum ClientError { + /// Transport / HTTP / auth error from the LLM provider. + Provider { + envelope: ErrorEnvelope, + /// When true the engine should attempt a retry. + retryable: bool, + }, + /// Error originating from the stream (SSE / chunk decode / protocol). + Stream { + envelope: ErrorEnvelope, + retryable: bool, + }, + /// Generic internal error that doesn't fit a provider taxonomy. + Internal { + envelope: ErrorEnvelope, + }, +} + +impl ClientError { + /// Unwrap the inner envelope regardless of variant. + #[must_use] + pub fn envelope(&self) -> &ErrorEnvelope { + match self { + Self::Provider { envelope, .. } + | Self::Stream { envelope, .. } + | Self::Internal { envelope } => envelope, + } + } + + /// Whether this error is eligible for a transparent retry. + #[must_use] + pub fn is_retryable(&self) -> bool { + match self { + Self::Provider { retryable, .. } | Self::Stream { retryable, .. } => *retryable, + Self::Internal { .. } => false, + } + } + + /// Construct from an `LlmError` with Retry‑After header support. + pub fn from_llm_error(err: LlmError, retry_after: Option) -> Self { + let retryable = err.is_retryable(); + let envelope: ErrorEnvelope = err.into(); + if retryable { + let envelope = if let Some(delay) = retry_after { + ErrorEnvelope { + code: format!("{}:retry_after_{}s", envelope.code, delay.as_secs()), + ..envelope + } + } else { + envelope + }; + Self::Provider { + envelope, + retryable: true, + } + } else { + Self::Provider { + envelope, + retryable: false, + } + } + } + + /// Construct a stream‑level error. + pub fn stream(message: impl Into, retryable: bool) -> Self { + let envelope = ErrorEnvelope::new( + ErrorCategory::Internal, + ErrorSeverity::Warning, + retryable, + "stream_error", + message, + ); + Self::Stream { + envelope, + retryable, + } + } + + /// Construct an internal error. + pub fn internal(message: impl Into) -> Self { + let envelope = ErrorEnvelope::new( + ErrorCategory::Internal, + ErrorSeverity::Error, + false, + "internal", + message, + ); + Self::Internal { envelope } + } +} + +impl fmt::Display for ClientError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.envelope()) + } +} + +impl std::error::Error for ClientError {} + +/// Stream‑level error discriminated by origin. +/// +/// Each variant maps to an `ErrorCategory` so the UI can render +/// stream‑specific icons or formatting. +#[derive(Debug, Clone)] +pub enum StreamError { + /// Stream stalled — no chunk received within the idle timeout. + Stall { + timeout_secs: u64, + }, + /// Chunk decode / JSON parse failure. + Decode { + message: String, + }, + /// Stream exceeded content size limit. + Overflow { + limit_bytes: usize, + }, + /// Stream exceeded wall‑clock duration limit. + DurationLimit { + limit_secs: u64, + }, + /// Transport error from the underlying SSE connection. + Transport { + message: String, + }, +} + +impl StreamError { + /// Convert into a `ClientError` for emission. + #[must_use] + pub fn into_client_error(self) -> ClientError { + match self { + Self::Stall { timeout_secs } => { + ClientError::stream( + format!("Stream stalled after {timeout_secs}s idle"), + true, + ) + } + Self::Decode { message } => { + ClientError::stream(format!("Stream decode error: {message}"), true) + } + Self::Overflow { limit_bytes } => { + ClientError::stream( + format!("Stream exceeded {limit_bytes} bytes limit"), + false, + ) + } + Self::DurationLimit { limit_secs } => { + ClientError::stream( + format!("Stream exceeded {limit_secs}s duration limit"), + false, + ) + } + Self::Transport { message } => { + ClientError::stream(message, true) + } + } + } +} + +impl fmt::Display for StreamError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Stall { timeout_secs } => { + write!(f, "Stream stalled after {timeout_secs}s idle") + } + Self::Decode { message } => write!(f, "Stream decode error: {message}"), + Self::Overflow { limit_bytes } => { + write!(f, "Stream exceeded {limit_bytes} bytes limit") + } + Self::DurationLimit { limit_secs } => { + write!(f, "Stream exceeded {limit_secs}s duration limit") + } + Self::Transport { message } => write!(f, "Stream transport: {message}"), + } + } +} + +impl std::error::Error for StreamError {} diff --git a/crates/tui/src/main.rs b/crates/tui/src/main.rs index 271883b0..605261c8 100644 --- a/crates/tui/src/main.rs +++ b/crates/tui/src/main.rs @@ -3036,12 +3036,12 @@ async fn run_exec_agent( } } Event::Error { - message, + envelope, recoverable: _, } => { - summary.error = Some(message.clone()); + summary.error = Some(envelope.message.clone()); if !json_output { - eprintln!("error: {message}"); + eprintln!("error: {}", envelope.message); } } Event::TurnComplete { status, error, .. } => { diff --git a/crates/tui/src/runtime_threads.rs b/crates/tui/src/runtime_threads.rs index 928cfd1f..4f46763b 100644 --- a/crates/tui/src/runtime_threads.rs +++ b/crates/tui/src/runtime_threads.rs @@ -1924,6 +1924,7 @@ impl RuntimeThreadManager { id, tool_name, description, + .. } => { self.emit_event( &thread_id, @@ -2013,9 +2014,10 @@ impl RuntimeThreadManager { ) .await?; } - EngineEvent::Error { message, .. } => { + EngineEvent::Error { envelope, .. } => { turn_status = RuntimeTurnStatus::Failed; - turn_error = Some(message.clone()); + turn_error = Some(envelope.message.clone()); + let message = envelope.message.clone(); let item = TurnItemRecord { schema_version: CURRENT_RUNTIME_SCHEMA_VERSION, id: format!("item_{}", &Uuid::new_v4().to_string()[..8]), @@ -3175,7 +3177,7 @@ mod tests { harness .tx_event - .send(EngineEvent::ApprovalRequired { + .send(EngineEvent::ApprovalRequired { approval_key: "test_key".to_string(), id: "tool_stale".to_string(), tool_name: "exec_command".to_string(), description: "stale approval".to_string(), diff --git a/crates/tui/src/tools/approval_cache.rs b/crates/tui/src/tools/approval_cache.rs new file mode 100644 index 00000000..ad7ab2ba --- /dev/null +++ b/crates/tui/src/tools/approval_cache.rs @@ -0,0 +1,281 @@ +//! Per‑call approval cache with fingerprint keys (§5.A). +//! +//! Instead of caching by tool name alone (which would let an approved +//! `exec_shell "cat foo"` silently pass `exec_shell "rm -rf /"`), the +//! cache keys off a **call fingerprint** — a digest of the tool name and +//! the semantically‑relevant portion of its arguments. +//! +//! ## Fingerprint shape +//! +//! | Tool | Key | +//! |---------------|------------------------------------------| +//! | `apply_patch` | `patch:` | +//! | `exec_shell` | `shell:` | +//! | `fetch_url` | `net:` | +//! | everything else| `tool:` | +//! +//! The cache is **session‑keyed**: entries carry an +//! `ApprovedForSession` flag. When true, the approval is reused for the +//! remainder of the session; when false, it is a one‑shot grant (future +//! calls with the same fingerprint still prompt). + +use std::collections::HashMap; +use std::time::Instant; + +/// The fingerprint of a tool call — stable enough to match repeated +/// calls but specific enough to avoid privilege confusion. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ApprovalKey(pub String); + +/// Status of a previously‑rendered approval decision. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ApprovalCacheStatus { + /// Call fingerprint matched and the session‑level flag says reuse. + Approved, + /// Call fingerprint matched but the grant was one‑shot (already consumed). + Denied, + /// No match — requires fresh approval. + Unknown, +} + +/// A single cache entry. +#[derive(Debug, Clone)] +struct ApprovalCacheEntry { + /// When this entry was created. + created: Instant, + /// Whether the approval should be reused across the session. + approved_for_session: bool, +} + +/// An approval cache backed by tool‑call fingerprints. +#[derive(Debug, Default)] +pub struct ApprovalCache { + entries: HashMap, +} + +impl ApprovalCache { + /// Construct an empty cache. + #[must_use] + pub fn new() -> Self { + Self { + entries: HashMap::new(), + } + } + + /// Look up a previously‑rendered approval decision. + pub fn check(&self, key: &ApprovalKey) -> ApprovalCacheStatus { + let Some(entry) = self.entries.get(key) else { + return ApprovalCacheStatus::Unknown; + }; + if entry.approved_for_session { + ApprovalCacheStatus::Approved + } else { + ApprovalCacheStatus::Denied + } + } + + /// Record an approval decision under the given fingerprint. + /// + /// When `approved_for_session` is true, subsequent calls with the + /// same key will auto‑approve for the remainder of the session. + pub fn insert(&mut self, key: ApprovalKey, approved_for_session: bool) { + self.entries.insert( + key, + ApprovalCacheEntry { + created: Instant::now(), + approved_for_session, + }, + ); + } + + /// Clear all entries. + pub fn clear(&mut self) { + self.entries.clear(); + } + + /// Number of cached entries. + #[allow(dead_code)] + pub fn len(&self) -> usize { + self.entries.len() + } + + /// Whether the cache is empty. + #[allow(dead_code)] + pub fn is_empty(&self) -> bool { + self.entries.is_empty() + } +} + +// ── Fingerprint helpers ──────────────────────────────────────────── + +/// Build the approval‑cache key for a tool call. +/// +/// The key incorporates the tool name and a lossy digest of the +/// arguments so that the cache can distinguish `exec_shell "ls"` +/// from `exec_shell "rm -rf /"` while still recognising repeated +/// invocations of the same harmless command. +#[must_use] +pub fn build_approval_key(tool_name: &str, input: &serde_json::Value) -> ApprovalKey { + let fingerprint = match tool_name { + "apply_patch" => { + let paths_hash = hash_patch_paths(input); + format!("patch:{paths_hash}") + } + "exec_shell" | "exec_shell_wait" | "exec_shell_interact" + | "exec_wait" | "exec_interact" => { + let prefix = command_prefix(input); + format!("shell:{prefix}") + } + "fetch_url" | "web.fetch" | "web_fetch" => { + let host = parse_host(input); + format!("net:{host}") + } + _ => format!("tool:{tool_name}"), + }; + ApprovalKey(fingerprint) +} + +/// Extract the first three non‑flag tokens from the command string. +fn command_prefix(input: &serde_json::Value) -> String { + let cmd = input + .get("command") + .and_then(|v| v.as_str()) + .unwrap_or(""); + let tokens: Vec<&str> = cmd + .split_whitespace() + .filter(|t| !t.starts_with('-')) + .take(3) + .collect(); + if tokens.is_empty() { + "".to_string() + } else { + tokens.join(" ") + } +} + +/// Hash the sorted set of file paths referenced by a patch input. +fn hash_patch_paths(input: &serde_json::Value) -> String { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let mut paths: Vec<&str> = Vec::new(); + + if let Some(changes) = input.get("changes").and_then(|v| v.as_array()) { + for change in changes { + if let Some(path) = change.get("path").and_then(|v| v.as_str()) { + paths.push(path); + } + } + } else if let Some(patch_text) = input.get("patch").and_then(|v| v.as_str()) { + for line in patch_text.lines() { + if let Some(rest) = line.strip_prefix("+++ b/") { + paths.push(rest.trim()); + } + } + } + + paths.sort(); + paths.dedup(); + + if paths.is_empty() { + return "no_files".to_string(); + } + + let mut hasher = DefaultHasher::new(); + for path in &paths { + path.hash(&mut hasher); + } + format!("{:x}", hasher.finish()) +} + +/// Parse the host portion from a URL input. +fn parse_host(input: &serde_json::Value) -> String { + let url = input + .get("url") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + if let Ok(parsed) = reqwest::Url::parse(url) { + parsed.host_str().unwrap_or(url).to_string() + } else { + url.to_string() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn cache_hit_returns_approved_for_session() { + let mut cache = ApprovalCache::new(); + let key = build_approval_key("exec_shell", &json!({"command": "ls -la"})); + cache.insert(key.clone(), true); + assert_eq!(cache.check(&key), ApprovalCacheStatus::Approved); + } + + #[test] + fn cache_one_shot_is_not_reused() { + let mut cache = ApprovalCache::new(); + let key = build_approval_key("exec_shell", &json!({"command": "cargo build"})); + cache.insert(key.clone(), false); + assert_eq!(cache.check(&key), ApprovalCacheStatus::Denied); + } + + #[test] + fn cache_miss_is_unknown() { + let cache = ApprovalCache::new(); + let key = build_approval_key("exec_shell", &json!({"command": "ls"})); + assert_eq!(cache.check(&key), ApprovalCacheStatus::Unknown); + } + + #[test] + fn different_commands_different_keys() { + let key_a = build_approval_key("exec_shell", &json!({"command": "ls"})); + let key_b = build_approval_key("exec_shell", &json!({"command": "rm -rf /tmp"})); + assert_ne!(key_a, key_b); + } + + #[test] + fn same_command_same_key() { + let key_a = build_approval_key("exec_shell", &json!({"command": "cargo build --release"})); + let key_b = build_approval_key("exec_shell", &json!({"command": "cargo build --release"})); + assert_eq!(key_a, key_b); + } + + #[test] + fn command_prefix_drops_flags() { + let key_a = build_approval_key("exec_shell", &json!({"command": "cargo build"})); + let key_b = build_approval_key("exec_shell", &json!({"command": "cargo build --release"})); + assert_eq!(key_a, key_b); + } + + #[test] + fn patch_keys_differ_by_path() { + let key_a = build_approval_key( + "apply_patch", + &json!({"changes": [{"path": "a.rs", "content": "x"}]}), + ); + let key_b = build_approval_key( + "apply_patch", + &json!({"changes": [{"path": "b.rs", "content": "x"}]}), + ); + assert_ne!(key_a, key_b); + } + + #[test] + fn net_keys_differ_by_host() { + let key_a = build_approval_key("fetch_url", &json!({"url": "https://example.com"})); + let key_b = build_approval_key("fetch_url", &json!({"url": "https://other.org"})); + assert_ne!(key_a, key_b); + } + + #[test] + fn generic_tool_uses_tool_name() { + let key_a = build_approval_key("read_file", &json!({"path": "a.txt"})); + let key_b = build_approval_key("read_file", &json!({"path": "b.txt"})); + assert_eq!(key_a, key_b); + assert_eq!(key_a.0, "tool:read_file"); + } +} diff --git a/crates/tui/src/tools/mod.rs b/crates/tui/src/tools/mod.rs index 4d6e1cde..93da2539 100644 --- a/crates/tui/src/tools/mod.rs +++ b/crates/tui/src/tools/mod.rs @@ -1,6 +1,7 @@ //! Tool system modules and re-exports. pub mod apply_patch; +pub mod approval_cache; pub mod diagnostics; pub mod file; pub mod file_search; diff --git a/crates/tui/src/tools/registry.rs b/crates/tui/src/tools/registry.rs index be6be2b5..1903da49 100644 --- a/crates/tui/src/tools/registry.rs +++ b/crates/tui/src/tools/registry.rs @@ -143,7 +143,7 @@ impl ToolRegistry { description: tool.description().to_string(), input_schema: tool.input_schema(), allowed_callers: Some(vec!["direct".to_string()]), - defer_loading: Some(false), + defer_loading: Some(tool.defer_loading()), input_examples: None, strict: None, cache_control: None, @@ -406,6 +406,33 @@ impl ToolRegistryBuilder { self.with_tool(Arc::new(NoteTool)) } + /// Include MCP tools from a connected pool as first-class registry + /// citizens. Each MCP tool is wrapped in a lightweight adapter that + /// implements `ToolSpec`, so the unified `ToolRegistryBuilder` flow + /// handles them alongside native tools. + /// + /// MCP tools are marked `defer_loading` by default (except discovery + /// helpers) to keep the model-visible catalog compact. + #[must_use] + pub fn with_mcp_tools( + mut self, + mcp_pool: std::sync::Arc>, + ) -> Self { + // Snapshot the current tool list from the pool (non-blocking). + // The adapter lazily resolves at execution time via the pool. + if let Ok(pool) = mcp_pool.try_lock() { + for (name, tool) in pool.all_tools() { + let adapter = Arc::new(McpToolAdapter { + name: name.clone(), + tool: tool.clone(), + pool: mcp_pool.clone(), + }); + self.tools.push(adapter); + } + } + self + } + /// Include all agent tools (file tools + shell + note + search + patch). #[must_use] pub fn with_agent_tools(self, allow_shell: bool) -> Self { @@ -563,6 +590,77 @@ impl Default for ToolRegistryBuilder { } } +/// Adapter that wraps an MCP tool definition so it can live in the +/// unified `ToolRegistry` alongside native tools (§5.B). +struct McpToolAdapter { + name: String, + tool: crate::mcp::McpTool, + pool: std::sync::Arc>, +} + +#[async_trait::async_trait] +impl ToolSpec for McpToolAdapter { + fn name(&self) -> &str { + &self.name + } + + fn description(&self) -> &str { + // McpTool.description is Option; fall back to the + // prefixed name when absent. + self.tool + .description + .as_deref() + .unwrap_or(&self.name) + } + + fn input_schema(&self) -> Value { + self.tool.input_schema.clone() + } + + fn capabilities(&self) -> Vec { + // Conservatively treat MCP tools as requiring approval and + // network access unless they're known discovery helpers. + let name_lower = self.name.to_lowercase(); + if name_lower.contains("list_mcp") + || name_lower.contains("read_mcp") + || name_lower.contains("mcp_read") + || name_lower.contains("mcp_get_prompt") + { + vec![ToolCapability::ReadOnly] + } else { + vec![ToolCapability::Network, ToolCapability::RequiresApproval] + } + } + + fn defer_loading(&self) -> bool { + // Discovery helpers stay loaded; everything else is deferred. + let keep_loaded = matches!( + self.name.as_str(), + "list_mcp_resources" + | "list_mcp_resource_templates" + | "mcp_read_resource" + | "read_mcp_resource" + | "mcp_get_prompt" + ); + !keep_loaded + } + + async fn execute( + &self, + input: Value, + _context: &ToolContext, + ) -> Result { + let mut pool = self.pool.lock().await; + let result = pool + .call_tool(&self.name, input) + .await + .map_err(|e| ToolError::execution_failed(format!("MCP tool failed: {e}")))?; + let content = + serde_json::to_string_pretty(&result).unwrap_or_else(|_| result.to_string()); + Ok(ToolResult::success(content)) + } +} + // === Unit Tests === #[cfg(test)] diff --git a/crates/tui/src/tools/spec.rs b/crates/tui/src/tools/spec.rs index 955e5e12..b1e2db6e 100644 --- a/crates/tui/src/tools/spec.rs +++ b/crates/tui/src/tools/spec.rs @@ -551,6 +551,13 @@ pub trait ToolSpec: Send + Sync { false } + /// Returns whether this tool should be excluded from the model-visible + /// tool catalog (deferred loading). Tools marked `true` are registered + /// but not sent to the model until explicitly activated via tool search. + fn defer_loading(&self) -> bool { + false + } + /// Execute the tool with the given input and context. async fn execute(&self, input: Value, context: &ToolContext) -> Result; } diff --git a/crates/tui/src/tui/approval.rs b/crates/tui/src/tui/approval.rs index 2a346fc6..1a2d62da 100644 --- a/crates/tui/src/tui/approval.rs +++ b/crates/tui/src/tui/approval.rs @@ -80,10 +80,18 @@ pub struct ApprovalRequest { pub impacts: Vec, /// Tool parameters (for display) pub params: Value, + /// Fingerprint key for per‑call approval caching (§5.A). + pub approval_key: String, } impl ApprovalRequest { - pub fn new(id: &str, tool_name: &str, description: &str, params: &Value) -> Self { + pub fn new( + id: &str, + tool_name: &str, + description: &str, + params: &Value, + approval_key: &str, + ) -> Self { let category = get_tool_category(tool_name); Self { @@ -93,6 +101,7 @@ impl ApprovalRequest { category, impacts: build_impact_summary(tool_name, category, params), params: params.clone(), + approval_key: approval_key.to_string(), } } @@ -292,6 +301,7 @@ impl ApprovalView { tool_name: self.request.tool_name.clone(), decision, timed_out, + approval_key: self.request.approval_key.clone(), }) } @@ -659,7 +669,7 @@ mod tests { fn test_approval_request_new() { let params = json!({"path": "src/main.rs", "content": "test"}); let request = - ApprovalRequest::new("test-id", "write_file", "Write a file to disk", ¶ms); + ApprovalRequest::new("test-id", "write_file", "Write a file to disk", ¶ms, "test_key"); assert_eq!(request.id, "test-id"); assert_eq!(request.tool_name, "write_file"); @@ -673,7 +683,7 @@ mod tests { let long_content = "x".repeat(300); let params = json!({"path": "src/main.rs", "content": long_content}); let request = - ApprovalRequest::new("test-id", "write_file", "Write a file to disk", ¶ms); + ApprovalRequest::new("test-id", "write_file", "Write a file to disk", ¶ms, "test_key"); let display = request.params_display(); // Should be truncated to around 200 chars @@ -685,7 +695,7 @@ mod tests { fn test_approval_request_params_display_short() { let params = json!({"path": "src/main.rs"}); let request = - ApprovalRequest::new("test-id", "read_file", "Read a file from disk", ¶ms); + ApprovalRequest::new("test-id", "read_file", "Read a file from disk", ¶ms, "test_key"); let display = request.params_display(); assert!(display.contains("src/main.rs")); @@ -694,7 +704,7 @@ mod tests { #[test] fn test_approval_request_derives_impact_summary() { let params = json!({"cmd": "cargo test", "workdir": "/tmp/project"}); - let request = ApprovalRequest::new("test-id", "exec_shell", "Run a shell command", ¶ms); + let request = ApprovalRequest::new("test-id", "exec_shell", "Run a shell command", ¶ms, "test_key"); assert_eq!(request.category, ToolCategory::Shell); assert!( @@ -719,7 +729,7 @@ mod tests { fn test_approval_view_initial_state() { let params = json!({"path": "src/main.rs"}); let request = - ApprovalRequest::new("test-id", "read_file", "Read a file from disk", ¶ms); + ApprovalRequest::new("test-id", "read_file", "Read a file from disk", ¶ms, "test_key"); let view = ApprovalView::new(request.clone()); assert_eq!(view.selected, 0); @@ -730,7 +740,7 @@ mod tests { fn test_approval_view_navigation() { let params = json!({"path": "src/main.rs"}); let request = - ApprovalRequest::new("test-id", "read_file", "Read a file from disk", ¶ms); + ApprovalRequest::new("test-id", "read_file", "Read a file from disk", ¶ms, "test_key"); let mut view = ApprovalView::new(request); // Initially at 0 @@ -759,7 +769,7 @@ mod tests { fn test_approval_view_keybindings_decisions() { let params = json!({"path": "src/main.rs"}); let request = - ApprovalRequest::new("test-id", "read_file", "Read a file from disk", ¶ms); + ApprovalRequest::new("test-id", "read_file", "Read a file from disk", ¶ms, "test_key"); let mut view = ApprovalView::new(request.clone()); // Test 'y' -> Approved @@ -810,7 +820,7 @@ mod tests { fn test_approval_view_enter_uses_selected_option() { let params = json!({"path": "src/main.rs"}); let request = - ApprovalRequest::new("test-id", "read_file", "Read a file from disk", ¶ms); + ApprovalRequest::new("test-id", "read_file", "Read a file from disk", ¶ms, "test_key"); let mut view = ApprovalView::new(request); // Navigate to index 2 (Denied) @@ -833,7 +843,7 @@ mod tests { fn test_approval_view_navigation_keys() { let params = json!({"path": "src/main.rs"}); let request = - ApprovalRequest::new("test-id", "read_file", "Read a file from disk", ¶ms); + ApprovalRequest::new("test-id", "read_file", "Read a file from disk", ¶ms, "test_key"); let mut view = ApprovalView::new(request); // Test Up arrow @@ -857,7 +867,7 @@ mod tests { fn test_approval_view_view_params() { let params = json!({"path": "src/main.rs", "content": "test"}); let request = - ApprovalRequest::new("test-id", "read_file", "Read a file from disk", ¶ms); + ApprovalRequest::new("test-id", "read_file", "Read a file from disk", ¶ms, "test_key"); let mut view = ApprovalView::new(request.clone()); // Test 'v' to view params @@ -880,7 +890,7 @@ mod tests { fn test_approval_view_current_decision_mapping() { let params = json!({"path": "src/main.rs"}); let request = - ApprovalRequest::new("test-id", "read_file", "Read a file from disk", ¶ms); + ApprovalRequest::new("test-id", "read_file", "Read a file from disk", ¶ms, "test_key"); let mut view = ApprovalView::new(request); // Index 0 -> Approved diff --git a/crates/tui/src/tui/ui.rs b/crates/tui/src/tui/ui.rs index a11a2ccd..306ebbc8 100644 --- a/crates/tui/src/tui/ui.rs +++ b/crates/tui/src/tui/ui.rs @@ -647,10 +647,10 @@ async fn run_event_loop( } } EngineEvent::Error { - message, + envelope, recoverable, } => { - apply_engine_error_to_app(app, message, recoverable); + apply_engine_error_to_app(app, envelope.message.clone(), recoverable); } EngineEvent::Status { message } => { app.status_message = Some(message); @@ -769,13 +769,16 @@ async fn run_event_loop( id, tool_name, description, + approval_key, } => { - let session_approved = app.approval_session_approved.contains(&tool_name); + let session_approved = app.approval_session_approved.contains(&approval_key) + || app.approval_session_approved.contains(&tool_name); if session_approved || app.approval_mode == ApprovalMode::Auto { log_sensitive_event( "tool.approval.auto_approve", serde_json::json!({ "tool_name": tool_name, + "approval_key": approval_key, "session_id": app.current_session_id, "mode": app.mode.label(), }), @@ -807,7 +810,7 @@ async fn run_event_loop( // Create approval request and show overlay let request = - ApprovalRequest::new(&id, &tool_name, &description, &tool_input); + ApprovalRequest::new(&id, &tool_name, &description, &tool_input, &approval_key); log_sensitive_event( "tool.approval.prompted", serde_json::json!({ @@ -2930,9 +2933,13 @@ async fn handle_view_events( tool_name, decision, timed_out, + approval_key, } => { if decision == ReviewDecision::ApprovedForSession { - app.approval_session_approved.insert(tool_name); + // Store both the tool name (backward compat) and the + // approval key (fingerprint-based). + app.approval_session_approved.insert(tool_name.clone()); + app.approval_session_approved.insert(approval_key); } match decision { diff --git a/crates/tui/src/tui/views/mod.rs b/crates/tui/src/tui/views/mod.rs index 036c94fe..d30abfac 100644 --- a/crates/tui/src/tui/views/mod.rs +++ b/crates/tui/src/tui/views/mod.rs @@ -51,6 +51,8 @@ pub enum ViewEvent { tool_name: String, decision: ReviewDecision, timed_out: bool, + /// Fingerprint key for per‑call approval caching (§5.A). + approval_key: String, }, ElevationDecision { tool_id: String,