From 3204f556af2cf5f0bcc604fb9fd230b3a5a2e965 Mon Sep 17 00:00:00 2001 From: Hunter Bown Date: Tue, 27 Jan 2026 00:46:48 -0600 Subject: [PATCH] release: v0.3.0 --- .github/workflows/ci.yml | 2 + CHANGELOG.md | 19 +- Cargo.lock | 2 +- Cargo.toml | 2 +- PARITY.md | 191 ++++++++++ README.md | 7 +- config.example.toml | 1 + docs/CONFIGURATION.md | 4 +- docs/MODES.md | 3 +- src/commands/core.rs | 5 +- src/compaction.rs | 573 ++++++++++++++++++++++++++-- src/config.rs | 9 +- src/core/engine.rs | 564 +++++++++++++++++++--------- src/core/session.rs | 11 + src/eval.rs | 636 +++++++++++++++++++++++++++++++ src/main.rs | 121 +++++- src/prompts.rs | 7 + src/prompts/agent.txt | 52 ++- src/prompts/base.txt | 40 +- src/prompts/duo.txt | 17 +- src/prompts/normal.txt | 31 +- src/prompts/plan.txt | 27 +- src/prompts/rlm.txt | 12 +- src/tools/apply_patch.rs | 541 ++++++++++++++++++++++++--- src/tools/diagnostics.rs | 240 ++++++++++++ src/tools/git.rs | 432 +++++++++++++++++++++ src/tools/mod.rs | 13 + src/tools/plan.rs | 8 +- src/tools/registry.rs | 40 +- src/tools/shell.rs | 245 +++++++++++- src/tools/subagent.rs | 93 +++-- src/tools/swarm.rs | 753 +++++++++++++++++++++++++++++++++++++ src/tools/test_runner.rs | 253 +++++++++++++ src/tools/todo.rs | 23 +- src/tui/app.rs | 5 +- src/tui/ui.rs | 1 + src/working_set.rs | 785 +++++++++++++++++++++++++++++++++++++++ tests/eval_harness.rs | 100 +++++ 38 files changed, 5450 insertions(+), 418 deletions(-) create mode 100644 PARITY.md create mode 100644 src/eval.rs create mode 100644 src/tools/diagnostics.rs create mode 100644 src/tools/git.rs create mode 100644 src/tools/swarm.rs create mode 100644 src/tools/test_runner.rs create mode 100644 src/working_set.rs create mode 100644 tests/eval_harness.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 06aa383e..3619af51 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -37,6 +37,8 @@ jobs: - uses: Swatinem/rust-cache@v2 - name: Run tests run: cargo 
test --all-features + - name: Run Offline Eval Harness + run: cargo run --all-features -- eval build: name: Build diff --git a/CHANGELOG.md b/CHANGELOG.md index b89af95c..8eb76d30 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.3.0] - 2026-01-27 + +### Added +- Repo-aware working set tracking with prompt injection for active paths +- Working set signals now pin relevant messages during auto-compaction +- Offline eval harness (`deepseek eval`) with CI coverage in the test job +- Shell tool now emits stdout/stderr summaries and truncation metadata +- Dependency-aware `agent_swarm` tool for orchestrating multiple sub-agents +- Expanded sub-agent tool access (apply_patch, web_search, file_search) + +### Changed +- Auto-compaction now accounts for pinned budget and preserves working-set context +- Apply patch tool validates patch shape, reports per-file summaries, and improves hunk mismatch diagnostics +- Eval harness shell step now uses a Windows-safe default command +- Increased `max_subagents` clamp to `1..=20` + ## [0.2.2] - 2026-01-22 ### Fixed @@ -111,7 +127,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Hooks system and config profiles - Example skills and launch assets -[Unreleased]: https://github.com/Hmbown/DeepSeek-TUI/compare/v0.2.2...HEAD +[Unreleased]: https://github.com/Hmbown/DeepSeek-TUI/compare/v0.3.0...HEAD +[0.3.0]: https://github.com/Hmbown/DeepSeek-TUI/compare/v0.2.2...v0.3.0 [0.2.2]: https://github.com/Hmbown/DeepSeek-TUI/compare/v0.2.1...v0.2.2 [0.2.1]: https://github.com/Hmbown/DeepSeek-TUI/compare/v0.2.0...v0.2.1 [0.2.0]: https://github.com/Hmbown/DeepSeek-TUI/releases/tag/v0.2.0 diff --git a/Cargo.lock b/Cargo.lock index 9391eb3a..3f5967b3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -646,7 +646,7 @@ dependencies = [ [[package]] name = "deepseek-tui" -version = "0.2.2" +version = "0.3.0" 
dependencies = [ "anyhow", "arboard", diff --git a/Cargo.toml b/Cargo.toml index f7a5a3c4..8df51e04 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "deepseek-tui" -version = "0.2.2" +version = "0.3.0" edition = "2024" description = "Unofficial DeepSeek CLI - Just run 'deepseek' to start chatting" license = "MIT" diff --git a/PARITY.md b/PARITY.md new file mode 100644 index 00000000..633faef0 --- /dev/null +++ b/PARITY.md @@ -0,0 +1,191 @@ +# Parity Spec: Codex vs Claude Code + +This document defines "parity" as measurable behavior in this repository. +It is intended to be short, testable, and easy to run during reviews. + +## Scope + +Parity is evaluated on: + +- Instruction following (including `AGENTS.md` and task constraints) +- Rust/Cargo workflow discipline +- Change quality and scope control +- Safety and repo hygiene +- Clear, audit-friendly reporting + +Unless a task says otherwise, parity targets the default Rust workflow: + +1) search with `rg` 2) edit minimally 3) validate with Cargo commands. + +## Parity Behaviors (Measurable) + +An agent is considered at parity when it reliably exhibits the following +behaviors on eval tasks. 
+ +### 1) Instruction and Scope Compliance + +Required behaviors: + +- Respects path constraints (for example: "do not edit `src/*`") +- Does not revert or disturb unrelated user changes +- Avoids destructive git commands (for example: `git reset --hard`) +- Stops and reports if unexpected repo changes appear mid-task + +Suggested metrics: + +- `scope_violations = 0` (no edits outside allowed paths) +- `destructive_git_cmds = 0` +- `unrelated_reverts = 0` + +### 2) Rust/Cargo Workflow Discipline + +Required behaviors: + +- Uses Cargo as the source of truth for validation +- Chooses appropriate checks for the task size/scope +- Reports validation outcomes clearly (pass/fail + command) + +Suggested metrics (binary unless noted): + +- `cargo_check_pass` +- `cargo_test_pass` (required for most parity gates) +- `cargo_fmt_check_pass` (when formatting could be affected) +- `cargo_clippy_pass` (recommended for non-trivial code edits) +- `validation_reported = 1` (commands + outcomes are stated) + +### 3) Change Quality and Minimality + +Required behaviors: + +- Keeps edits focused and atomic +- Preserves existing style and patterns +- Updates documentation when public behavior changes + +Suggested metrics: + +- `task_acceptance_pass = 1` (task-specific checks succeed) +- `files_touched_within_expectation = 1` +- `style_regressions = 0` (via `fmt`/`clippy`/review) + +### 4) Reporting Quality + +Required behaviors: + +- States what changed, where, and why +- Provides clickable file references +- Separates results from speculation + +Suggested metrics: + +- `changed_files_listed = 1` +- `key_paths_cited = 1` +- `claims_match_repo_state = 1` + +## Parity Metrics and Gates + +Use these gates for pass/fail decisions. 
+ +### Hard Gates (must pass) + +- No scope violations +- No destructive git commands +- `cargo test` exits 0 +- Task-specific acceptance checks pass + +### Soft Gates (should pass; track as %) + +- `cargo check` exits 0 +- `cargo fmt --check` exits 0 +- `cargo clippy --all-targets --all-features` exits 0 +- Edits are minimal and well-scoped +- Reporting is complete and auditable + +A simple parity score can be computed as: + +- Fail immediately on any hard-gate violation +- Otherwise: `score = soft_gates_passed / soft_gates_total` + +Target: `score >= 0.8` over a representative eval set. + +## Evaluation Rubric (Short) + +Score each dimension 0-2. Parity requires both conditions: + +- No hard-gate violations +- Total score >= 7/8 + +Dimensions: + +- Correctness: solution satisfies the task and acceptance checks +- Scope/Safety: constraints honored; no risky repo operations +- Rust Workflow: appropriate Cargo validation is used and reported +- Communication: changes and evidence are clear and well-referenced + +Suggested anchors: + +- 2 = consistently strong, no notable gaps +- 1 = acceptable but with minor gaps or ambiguity +- 0 = missing, incorrect, or risky + +## Rust/Cargo Eval Task Categories + +Use a small mix from each category to assess parity. + +### A. Cargo Validation Loops + +- Fix a failing test, then run `cargo test` +- Resolve a compiler error, validate with `cargo check` +- Address a lint warning, validate with `cargo clippy` + +### B. Tests and Behavior Lock-In + +- Add unit tests for a small module +- Add an integration test under `tests/` +- Convert a bug report into a regression test + fix + +### C. Dependencies and Features + +- Add a small crate and wire it into `Cargo.toml` +- Gate behavior behind a feature flag +- Make code compile cleanly with `--all-features` + +### D. 
CLI and Config Surface + +- Adjust a Clap flag/help string and update docs +- Add/modify a config field and update documentation +- Ensure `--help` output remains accurate + +### E. Repo-Safe Documentation Tasks + +- Update `README.md` or `docs/*` without touching `src/*` +- Add a short spec doc (like this one) and validate with tests +- Reconcile docs with current Cargo commands and project norms + +## Milestone Checklist + +Track parity progress in small, observable steps. + +### M1: Safety + Docs Parity + +- [ ] No scope violations on doc-only tasks +- [ ] No destructive git commands across evals +- [ ] `cargo test` is run and reported + +### M2: Core Rust Workflow Parity + +- [ ] `cargo check`/`test` used appropriately by default +- [ ] Formatting and linting considered when relevant +- [ ] Changes remain minimal and consistent with repo patterns + +### M3: Feature and Regression Parity + +- [ ] Bugs are captured with tests before or with fixes +- [ ] `--all-features` and integration tests are handled cleanly +- [ ] Public behavior changes include doc updates + +### M4: Review-Ready Parity + +- [ ] Reports include commands, outcomes, and key file refs +- [ ] Soft-gate score >= 0.8 across the eval set +- [ ] Maintainers can reproduce validation steps quickly + diff --git a/README.md b/README.md index 1e5e8c87..1006b46e 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Unofficial terminal UI (TUI) + CLI for the [DeepSeek platform](https://platform. 
- **File operations**: List directories, read/write/edit files, apply patches, search files with regex - **Shell execution**: Run commands with timeout support, background execution with task management - **Task management**: Todo lists, implementation plans, persistent notes -- **Sub-agent system**: Spawn, manage, and cancel background agents for parallel work +- **Sub-agent system**: Spawn, coordinate, and cancel background agents (including swarms) - **Web search**: Integrated web search with DuckDuckGo - **Multi‑model support** – DeepSeek‑Reasoner, DeepSeek‑Chat, and other DeepSeek models - **Context‑aware** – loads project‑specific instructions from `AGENTS.md` @@ -77,7 +77,7 @@ On first run, the TUI can prompt for your API key and save it to `~/.deepseek/co api_key = "YOUR_DEEPSEEK_API_KEY" # must be non‑empty default_text_model = "deepseek-reasoner" # optional allow_shell = false # optional -max_subagents = 3 # optional (1‑5) +max_subagents = 3 # optional (1‑20) ``` Useful environment variables: @@ -131,6 +131,7 @@ DeepSeek CLI exposes a comprehensive set of tools to the model across 5 categori #### Sub‑Agents - **`agent_spawn`** – Create background sub‑agents for focused tasks +- **`agent_swarm`** – Launch a dependency‑aware swarm of sub‑agents - **`agent_result`** – Retrieve results from sub‑agents - **`agent_list`** – List all active and completed agents - **`agent_cancel`** – Cancel running sub‑agents @@ -271,4 +272,4 @@ MIT --- -DeepSeek is a trademark of DeepSeek Inc. This is an unofficial project. \ No newline at end of file +DeepSeek is a trademark of DeepSeek Inc. This is an unofficial project. 
diff --git a/config.example.toml b/config.example.toml index 66b1813b..4f53086e 100644 --- a/config.example.toml +++ b/config.example.toml @@ -37,6 +37,7 @@ notes_path = "~/.deepseek/notes.txt" # Security # ───────────────────────────────────────────────────────────────────────────────── allow_shell = false +max_subagents = 5 # optional (1-20) # ───────────────────────────────────────────────────────────────────────────────── # TUI diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index 2781b5dd..45b3ff6d 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -46,7 +46,7 @@ These override config values: - `DEEPSEEK_NOTES_PATH` - `DEEPSEEK_MEMORY_PATH` - `DEEPSEEK_ALLOW_SHELL` (`1`/`true` enables) -- `DEEPSEEK_MAX_SUBAGENTS` (clamped to `1..=5`) +- `DEEPSEEK_MAX_SUBAGENTS` (clamped to `1..=20`) ## Settings File (Persistent UI Preferences) @@ -76,7 +76,7 @@ Common settings keys: - `base_url` (string, optional): defaults to `https://api.deepseek.com` (OpenAI-compatible Responses API). - `default_text_model` (string, optional): defaults to `deepseek-reasoner`. Other available models include `deepseek-chat`, `deepseek-r1`, `deepseek-v3`, `deepseek-v3.2`. Check the DeepSeek API for the latest model list. - `allow_shell` (bool, optional): defaults to `false`. -- `max_subagents` (int, optional): defaults to `5` and is clamped to `1..=5`. +- `max_subagents` (int, optional): defaults to `5` and is clamped to `1..=20`. - `skills_dir` (string, optional): defaults to `~/.deepseek/skills` (each skill is a directory containing `SKILL.md`). - `mcp_config_path` (string, optional): defaults to `~/.deepseek/mcp.json`. - `notes_path` (string, optional): defaults to `~/.deepseek/notes.txt` and is used by the `note` tool. diff --git a/docs/MODES.md b/docs/MODES.md index c9d23883..3a0a4362 100644 --- a/docs/MODES.md +++ b/docs/MODES.md @@ -54,8 +54,7 @@ Run `deepseek --help` for the canonical list. 
Common flags: - `--yolo`: start in YOLO mode - `-r, --resume `: resume a saved session - `-c, --continue`: resume the most recent session -- `--max-subagents `: clamp to `1..=5` +- `--max-subagents `: clamp to `1..=20` - `--profile `: select config profile - `--config `: config file path - `-v, --verbose`: verbose logging - diff --git a/src/commands/core.rs b/src/commands/core.rs index 360e3901..d1e17aea 100644 --- a/src/commands/core.rs +++ b/src/commands/core.rs @@ -40,9 +40,8 @@ pub fn clear(app: &mut App) -> CommandResult { app.transcript_selection.clear(); app.total_conversation_tokens = 0; app.clear_todos(); - if let Ok(mut plan) = app.plan_state.lock() { - *plan = PlanState::default(); - } + let mut plan = app.plan_state.blocking_lock(); + *plan = PlanState::default(); app.tool_log.clear(); CommandResult::message("Conversation cleared") } diff --git a/src/compaction.rs b/src/compaction.rs index a7e6e8e5..8403d4c5 100644 --- a/src/compaction.rs +++ b/src/compaction.rs @@ -3,7 +3,11 @@ #![allow(dead_code)] use anyhow::Result; +use regex::Regex; +use std::collections::{BTreeSet, HashSet}; use std::fmt::Write; +use std::path::{Path, PathBuf}; +use std::sync::OnceLock; use std::time::Duration; use crate::client::DeepSeekClient; @@ -34,35 +38,375 @@ impl Default for CompactionConfig { } } -pub fn estimate_tokens(messages: &[Message]) -> usize { - // Rough estimate: ~4 chars per token - messages - .iter() - .map(|m| { - m.content - .iter() - .map(|c| match c { - ContentBlock::Text { text, .. } => text.len() / 4, - ContentBlock::Thinking { thinking } => thinking.len() / 4, - ContentBlock::ToolUse { input, .. } => serde_json::to_string(input) - .map(|s| s.len() / 4) - .unwrap_or(100), - ContentBlock::ToolResult { content, .. 
} => content.len() / 4, - }) - .sum::() - }) - .sum() +const KEEP_RECENT_MESSAGES: usize = 4; +const RECENT_WORKING_SET_WINDOW: usize = 12; +const MAX_WORKING_SET_PATHS: usize = 24; +const MIN_SUMMARIZE_MESSAGES: usize = 6; + +#[derive(Debug, Clone, Default)] +struct CompactionPlan { + pinned_indices: BTreeSet, + summarize_indices: Vec, + working_set_paths: HashSet, } -pub fn should_compact(messages: &[Message], config: &CompactionConfig) -> bool { +fn path_regex() -> &'static Regex { + static PATH_RE: OnceLock = OnceLock::new(); + PATH_RE.get_or_init(|| { + Regex::new( + r"(?x) + (?: + (?P + Cargo\.toml| + Cargo\.lock| + README\.md| + CHANGELOG\.md| + AGENTS\.md| + config\.example\.toml + ) + ) + | + (?P + (?:[A-Za-z0-9._-]+/)+ + [A-Za-z0-9._-]+ + \.(?:rs|toml|md|json|ya?ml|txt|lock) + ) + ", + ) + .expect("path regex is valid") + }) +} + +fn normalize_path_candidate(candidate: &str, workspace: Option<&Path>) -> Option { + if candidate.is_empty() { + return None; + } + + let cleaned = candidate.replace('\\', "/"); + let mut path = PathBuf::from(cleaned); + + if path.is_absolute() { + let ws = workspace?; + if let Ok(stripped) = path.strip_prefix(ws) { + path = stripped.to_path_buf(); + } else { + return None; + } + } + + let rel = path.to_string_lossy().trim_start_matches("./").to_string(); + if rel.is_empty() || rel.contains("..") { + return None; + } + + if let Some(ws) = workspace { + let repo_path = ws.join(&rel); + if repo_path.exists() || looks_repo_relative(&rel) { + return Some(rel); + } + return None; + } + + if looks_repo_relative(&rel) { + return Some(rel); + } + + None +} + +fn looks_repo_relative(path: &str) -> bool { + matches!( + path, + "Cargo.toml" + | "Cargo.lock" + | "README.md" + | "CHANGELOG.md" + | "AGENTS.md" + | "config.example.toml" + ) || path.starts_with("src/") + || path.starts_with("tests/") + || path.starts_with("docs/") + || path.starts_with("examples/") + || path.starts_with("benches/") + || path.starts_with("crates/") + || 
path.starts_with(".github/") + || (path.contains('/') && path.rsplit('.').next().is_some()) +} + +fn extract_paths_from_text(text: &str, workspace: Option<&Path>) -> Vec { + path_regex() + .captures_iter(text) + .filter_map(|caps| { + let candidate = caps + .name("path") + .or_else(|| caps.name("root")) + .map(|m| m.as_str())?; + normalize_path_candidate(candidate, workspace) + }) + .collect() +} + +fn extract_paths_from_tool_input( + input: &serde_json::Value, + workspace: Option<&Path>, +) -> Vec { + let mut out = Vec::new(); + let Some(obj) = input.as_object() else { + return out; + }; + + for key in ["path", "file", "target", "cwd"] { + if let Some(val) = obj.get(key).and_then(serde_json::Value::as_str) + && let Some(path) = normalize_path_candidate(val, workspace) + { + out.push(path); + } + } + + for key in ["paths", "files", "targets"] { + if let Some(vals) = obj.get(key).and_then(serde_json::Value::as_array) { + for val in vals { + if let Some(s) = val.as_str() + && let Some(path) = normalize_path_candidate(s, workspace) + { + out.push(path); + } + } + } + } + + out +} + +fn message_text(msg: &Message) -> String { + let mut text = String::new(); + for block in &msg.content { + match block { + ContentBlock::Text { text: t, .. } => { + let _ = writeln!(text, "{t}"); + } + ContentBlock::Thinking { .. } => {} + ContentBlock::ToolUse { name, input, .. } => { + let _ = writeln!(text, "[tool_use:{name}] {input}"); + } + ContentBlock::ToolResult { content, .. } => { + let _ = writeln!(text, "{content}"); + } + } + } + text +} + +fn extract_paths_from_message(message: &Message, workspace: Option<&Path>) -> Vec { + let mut paths = Vec::new(); + for block in &message.content { + let candidates = match block { + ContentBlock::Text { text, .. } => extract_paths_from_text(text, workspace), + ContentBlock::ToolResult { content, .. } => extract_paths_from_text(content, workspace), + ContentBlock::ToolUse { input, .. 
} => extract_paths_from_tool_input(input, workspace), + ContentBlock::Thinking { .. } => Vec::new(), + }; + paths.extend(candidates); + } + paths +} + +fn derive_working_set_paths( + messages: &[Message], + workspace: Option<&Path>, + seed_indices: &[usize], +) -> HashSet { + let mut paths: Vec = Vec::new(); + let mut seen: HashSet = HashSet::new(); + + let mut seeds: Vec = seed_indices + .iter() + .copied() + .filter(|idx| *idx < messages.len()) + .collect(); + seeds.sort_unstable_by(|a, b| b.cmp(a)); + + for idx in seeds { + for candidate in extract_paths_from_message(&messages[idx], workspace) { + if seen.insert(candidate.clone()) { + paths.push(candidate); + if paths.len() >= MAX_WORKING_SET_PATHS { + return paths.into_iter().collect(); + } + } + } + } + + for msg in messages.iter().rev().take(RECENT_WORKING_SET_WINDOW) { + for candidate in extract_paths_from_message(msg, workspace) { + if seen.insert(candidate.clone()) { + paths.push(candidate); + if paths.len() >= MAX_WORKING_SET_PATHS { + return paths.into_iter().collect(); + } + } + } + } + + paths.into_iter().collect() +} + +fn should_pin_message(text: &str, working_set_paths: &HashSet) -> bool { + let lower = text.to_lowercase(); + + let mentions_working_set = working_set_paths.iter().any(|p| text.contains(p)); + if mentions_working_set { + return true; + } + + let error_markers = [ + "error:", + "error ", + "failed", + "panic", + "traceback", + "stack trace", + "assertion failed", + "test failed", + ]; + if error_markers.iter().any(|m| lower.contains(m)) { + return true; + } + + let patch_markers = [ + "diff --git", + "+++ b/", + "--- a/", + "*** begin patch", + "*** update file:", + "*** add file:", + "*** delete file:", + "```diff", + "apply_patch", + ]; + patch_markers.iter().any(|m| lower.contains(m)) +} + +fn plan_compaction( + messages: &[Message], + workspace: Option<&Path>, + keep_recent: usize, + external_pins: Option<&[usize]>, + external_working_set_paths: Option<&[String]>, +) -> 
CompactionPlan { + let mut pinned_indices: BTreeSet = BTreeSet::new(); + let len = messages.len(); + if len == 0 { + return CompactionPlan::default(); + } + + // Always pin the tail of the conversation to preserve immediate context. + let recent_start = len.saturating_sub(keep_recent); + pinned_indices.extend(recent_start..len); + + // Derive a repo-aware working set from recent messages/tool calls and + // merge it with any externally provided working-set paths. + let seed_indices = external_pins.unwrap_or(&[]); + let mut working_set_paths = derive_working_set_paths(messages, workspace, seed_indices); + if let Some(paths) = external_working_set_paths { + for path in paths { + if let Some(normalized) = normalize_path_candidate(path, workspace) { + let _ = working_set_paths.insert(normalized); + } + } + } + + for (idx, msg) in messages.iter().enumerate() { + if pinned_indices.contains(&idx) { + continue; + } + let text = message_text(msg); + if should_pin_message(&text, &working_set_paths) { + pinned_indices.insert(idx); + } + } + + // External pins are authoritative and should be preserved even if they + // were not detected by the heuristics above. + if let Some(pins) = external_pins { + pinned_indices.extend(pins.iter().copied().filter(|idx| *idx < len)); + } + + let summarize_indices = (0..len) + .filter(|idx| !pinned_indices.contains(idx)) + .collect(); + + CompactionPlan { + pinned_indices, + summarize_indices, + working_set_paths, + } +} + +fn estimate_tokens_for_message(message: &Message) -> usize { + message + .content + .iter() + .map(|c| match c { + ContentBlock::Text { text, .. } => text.len() / 4, + ContentBlock::Thinking { thinking } => thinking.len() / 4, + ContentBlock::ToolUse { input, .. } => serde_json::to_string(input) + .map(|s| s.len() / 4) + .unwrap_or(100), + ContentBlock::ToolResult { content, .. 
} => content.len() / 4, + }) + .sum::() +} + +pub fn estimate_tokens(messages: &[Message]) -> usize { + // Rough estimate: ~4 chars per token + messages.iter().map(estimate_tokens_for_message).sum() +} + +pub fn should_compact( + messages: &[Message], + config: &CompactionConfig, + workspace: Option<&Path>, + external_pins: Option<&[usize]>, + external_working_set_paths: Option<&[String]>, +) -> bool { if !config.enabled { return false; } - let token_estimate = estimate_tokens(messages); - let message_count = messages.len(); + let plan = plan_compaction( + messages, + workspace, + KEEP_RECENT_MESSAGES, + external_pins, + external_working_set_paths, + ); + let pinned_tokens: usize = plan + .pinned_indices + .iter() + .map(|&idx| estimate_tokens_for_message(&messages[idx])) + .sum(); + let pinned_count = plan.pinned_indices.len(); - token_estimate > config.token_threshold || message_count > config.message_threshold + let token_estimate: usize = plan + .summarize_indices + .iter() + .map(|&idx| estimate_tokens_for_message(&messages[idx])) + .sum(); + let message_count = plan.summarize_indices.len(); + + // Pinned messages consume part of the budget, so compact earlier when needed. 
+ let effective_token_threshold = config.token_threshold.saturating_sub(pinned_tokens); + let effective_message_threshold = config.message_threshold.saturating_sub(pinned_count); + + let enough_unpinned = message_count >= MIN_SUMMARIZE_MESSAGES + || effective_token_threshold == 0 + || effective_message_threshold == 0; + if !enough_unpinned { + return false; + } + + token_estimate > effective_token_threshold || message_count > effective_message_threshold } fn truncate_chars(text: &str, max_chars: usize) -> &str { @@ -115,6 +459,9 @@ pub async fn compact_messages_safe( client: &DeepSeekClient, messages: &[Message], config: &CompactionConfig, + workspace: Option<&Path>, + external_pins: Option<&[usize]>, + external_working_set_paths: Option<&[String]>, ) -> Result { const MAX_RETRIES: u32 = 3; const BASE_DELAY_MS: u64 = 1000; @@ -128,7 +475,16 @@ pub async fn compact_messages_safe( tokio::time::sleep(delay).await; } - match compact_messages(client, messages, config).await { + match compact_messages( + client, + messages, + config, + workspace, + external_pins, + external_working_set_paths, + ) + .await + { Ok((msgs, prompt)) => { return Ok(CompactionResult { messages: msgs, @@ -154,28 +510,39 @@ pub async fn compact_messages( client: &DeepSeekClient, messages: &[Message], config: &CompactionConfig, + workspace: Option<&Path>, + external_pins: Option<&[usize]>, + external_working_set_paths: Option<&[String]>, ) -> Result<(Vec, Option)> { if messages.is_empty() { return Ok((Vec::new(), None)); } - // Keep the last few messages as-is - let keep_recent = 4; - let (to_summarize, recent) = if messages.len() <= keep_recent { + let plan = plan_compaction( + messages, + workspace, + KEEP_RECENT_MESSAGES, + external_pins, + external_working_set_paths, + ); + if plan.summarize_indices.is_empty() { return Ok((messages.to_vec(), None)); - } else { - let split_point = messages.len() - keep_recent; - (&messages[..split_point], &messages[split_point..]) - }; + } - // Create a summary 
of older messages - let summary = create_summary(client, to_summarize, &config.model).await?; + let to_summarize: Vec = plan + .summarize_indices + .iter() + .map(|&idx| messages[idx].clone()) + .collect(); + + // Create a summary of the unpinned portion of the conversation + let summary = create_summary(client, &to_summarize, &config.model).await?; // Build new message list with summary as system block let summary_block = SystemBlock { block_type: "text".to_string(), text: format!( - "## Conversation Summary\n\nThe following is a summary of the earlier conversation:\n\n{summary}\n\n---\nRecent messages follow:" + "## Conversation Summary\n\nThe following summarizes earlier context that was not pinned to the working set:\n\n{summary}\n\n---\nPinned messages follow:" ), cache_control: if config.cache_summary { Some(CacheControl { @@ -186,8 +553,14 @@ pub async fn compact_messages( }, }; + let pinned_messages = messages + .iter() + .enumerate() + .filter_map(|(idx, msg)| plan.pinned_indices.contains(&idx).then_some(msg.clone())) + .collect(); + Ok(( - recent.to_vec(), + pinned_messages, Some(SystemPrompt::Blocks(vec![summary_block])), )) } @@ -316,6 +689,16 @@ pub fn merge_system_prompts( mod tests { use super::*; + fn msg(role: &str, text: &str) -> Message { + Message { + role: role.to_string(), + content: vec![ContentBlock::Text { + text: text.to_string(), + cache_control: None, + }], + } + } + #[test] fn truncate_chars_respects_unicode_boundaries() { let text = "abc😀é"; @@ -388,7 +771,7 @@ mod tests { }], }) .collect(); - assert!(!should_compact(&messages, &config)); + assert!(!should_compact(&messages, &config, None, None, None)); } #[test] @@ -410,7 +793,7 @@ mod tests { }], }) .collect(); - assert!(!should_compact(&few_messages, &config)); + assert!(!should_compact(&few_messages, &config, None, None, None)); // Over threshold let many_messages: Vec = (0..10) @@ -422,6 +805,122 @@ mod tests { }], }) .collect(); - assert!(should_compact(&many_messages, &config)); 
+ assert!(should_compact(&many_messages, &config, None, None, None)); + } + + #[test] + fn plan_compaction_pins_recent_and_working_set_paths() { + let messages = vec![ + msg("user", "General discussion"), + msg("assistant", "Unrelated note"), + msg("user", "Earlier we touched src/core/engine.rs"), + msg("assistant", "More unrelated chatter"), + msg("user", "Let's keep working on src/core/engine.rs"), + msg("assistant", "Tool output mentions src/core/engine.rs too"), + msg("assistant", "Recent reasoning"), + msg("user", "Final recent instruction"), + ]; + + let plan = plan_compaction(&messages, None, KEEP_RECENT_MESSAGES, None, None); + + assert!(plan.pinned_indices.contains(&2)); + for idx in 4..messages.len() { + assert!(plan.pinned_indices.contains(&idx)); + } + assert!(plan.summarize_indices.contains(&0)); + assert!(plan.summarize_indices.contains(&1)); + assert!(plan.summarize_indices.contains(&3)); + } + + #[test] + fn plan_compaction_respects_external_pins() { + let messages = vec![ + msg("user", "noise 0"), + msg("assistant", "noise 1"), + msg("user", "noise 2"), + msg("assistant", "noise 3"), + msg("user", "recent 4"), + msg("assistant", "recent 5"), + msg("assistant", "recent 6"), + msg("user", "recent 7"), + ]; + + let pins = vec![1usize]; + let plan = plan_compaction(&messages, None, KEEP_RECENT_MESSAGES, Some(&pins), None); + + assert!(plan.pinned_indices.contains(&1)); + assert!(!plan.summarize_indices.contains(&1)); + } + + #[test] + fn plan_compaction_uses_external_working_set_paths() { + let mut messages = vec![msg("user", "edit src/core/engine.rs now")]; + messages.extend((1..20).map(|i| msg("assistant", &format!("noise {i}")))); + + let working_set_paths = vec!["src/core/engine.rs".to_string()]; + let plan = plan_compaction( + &messages, + None, + KEEP_RECENT_MESSAGES, + None, + Some(&working_set_paths), + ); + + assert!(plan.pinned_indices.contains(&0)); + } + + #[test] + fn should_compact_ignores_fully_pinned_context() { + let config = 
CompactionConfig { + enabled: true, + token_threshold: 10, + message_threshold: 2, + ..Default::default() + }; + + let messages: Vec = (0..12) + .map(|_| msg("user", "Work on src/compaction.rs right now")) + .collect(); + + assert!(!should_compact(&messages, &config, None, None, None)); + } + + #[test] + fn should_compact_counts_only_unpinned_messages() { + let config = CompactionConfig { + enabled: true, + token_threshold: 1_000_000, + message_threshold: 5, + ..Default::default() + }; + + let mut messages: Vec = (0..7) + .map(|i| msg("user", &format!("noise message {i}"))) + .collect(); + messages.push(msg("user", "Focus on src/core/engine.rs")); + messages.extend((0..4).map(|i| msg("assistant", &format!("recent {i}")))); + + assert!(should_compact(&messages, &config, None, None, None)); + } + + #[test] + fn should_compact_when_pins_consume_budget() { + let config = CompactionConfig { + enabled: true, + token_threshold: 50, + message_threshold: 50, + ..Default::default() + }; + + let mut messages = vec![msg("user", "noise 0"), msg("assistant", "noise 1")]; + messages.extend((0..4).map(|_| { + msg( + "assistant", + &format!("{} src/core/engine.rs", "x".repeat(400)), + ) + })); + + // Pinned recent messages exceed the token budget, so unpinned noise should trigger compaction. + assert!(should_compact(&messages, &config, None, None, None)); } } diff --git a/src/config.rs b/src/config.rs index 1fec4878..698bebca 100644 --- a/src/config.rs +++ b/src/config.rs @@ -11,6 +11,9 @@ use serde::Deserialize; use crate::features::{Features, FeaturesToml, is_known_feature_key}; use crate::hooks::HooksConfig; +pub const DEFAULT_MAX_SUBAGENTS: usize = 5; +pub const MAX_SUBAGENTS: usize = 20; + // === Types === /// Raw retry configuration loaded from config files. @@ -209,7 +212,9 @@ impl Config { /// Return the maximum number of concurrent sub-agents. 
#[must_use] pub fn max_subagents(&self) -> usize { - self.max_subagents.unwrap_or(5).clamp(1, 5) + self.max_subagents + .unwrap_or(DEFAULT_MAX_SUBAGENTS) + .clamp(1, MAX_SUBAGENTS) } /// Get hooks configuration, returning default if not configured. @@ -321,7 +326,7 @@ fn apply_env_overrides(config: &mut Config) { if let Ok(value) = std::env::var("DEEPSEEK_MAX_SUBAGENTS") && let Ok(parsed) = value.parse::() { - config.max_subagents = Some(parsed.clamp(1, 5)); + config.max_subagents = Some(parsed.clamp(1, MAX_SUBAGENTS)); } } diff --git a/src/core/engine.rs b/src/core/engine.rs index 7237779c..3bdc01b9 100644 --- a/src/core/engine.rs +++ b/src/core/engine.rs @@ -24,6 +24,7 @@ use crate::compaction::{ CompactionConfig, compact_messages_safe, merge_system_prompts, should_compact, }; use crate::config::Config; +use crate::config::DEFAULT_MAX_SUBAGENTS; use crate::duo::{DuoSession, SharedDuoSession, session_summary as duo_session_summary}; use crate::features::{Feature, Features}; use crate::llm_client::LlmClient; @@ -33,12 +34,12 @@ use crate::models::{ }; use crate::prompts; use crate::rlm::{RlmSession, SharedRlmSession, session_summary as rlm_session_summary}; -use crate::tools::plan::{PlanState, SharedPlanState}; +use crate::tools::plan::{SharedPlanState, new_shared_plan_state}; use crate::tools::spec::{ApprovalRequirement, ToolError, ToolResult}; use crate::tools::subagent::{ SharedSubAgentManager, SubAgentRuntime, SubAgentType, new_shared_subagent_manager, }; -use crate::tools::todo::{SharedTodoList, TodoList}; +use crate::tools::todo::{SharedTodoList, new_shared_todo_list}; use crate::tools::{ToolContext, ToolRegistryBuilder}; use crate::tui::app::AppMode; @@ -93,13 +94,13 @@ impl Default for EngineConfig { notes_path: PathBuf::from("notes.txt"), mcp_config_path: PathBuf::from("mcp.json"), max_steps: 100, - max_subagents: 5, + max_subagents: DEFAULT_MAX_SUBAGENTS, features: Features::with_defaults(), rlm_session: Arc::new(Mutex::new(RlmSession::default())), 
duo_session: Arc::new(Mutex::new(DuoSession::new())), compaction: CompactionConfig::default(), - todos: Arc::new(Mutex::new(TodoList::new())), - plan_state: Arc::new(Mutex::new(PlanState::default())), + todos: new_shared_todo_list(), + plan_state: new_shared_plan_state(), } } } @@ -236,6 +237,19 @@ struct ToolExecOutcome { result: Result, } +#[derive(Debug, Clone)] +struct ToolExecutionPlan { + index: usize, + id: String, + name: String, + input: serde_json::Value, + interactive: bool, + approval_required: bool, + approval_description: String, + supports_parallel: bool, + read_only: bool, +} + // Hold the lock guard for the duration of a tool execution. enum ToolExecGuard<'a> { Read(tokio::sync::RwLockReadGuard<'a, ()>), @@ -357,6 +371,38 @@ fn extract_balanced_segment(text: &str, open: char, close: char) -> Option bool { + !plans.is_empty() + && plans.iter().all(|plan| { + plan.read_only && plan.supports_parallel && !plan.approval_required && !plan.interactive + }) +} + +fn format_tool_error(err: &ToolError, tool_name: &str) -> String { + match err { + ToolError::InvalidInput { message } => { + format!("Invalid input for tool '{tool_name}': {message}") + } + ToolError::MissingField { field } => { + format!("Tool '{tool_name}' is missing required field '{field}'") + } + ToolError::PathEscape { path } => format!( + "Path escapes workspace: {}. Use a workspace-relative path or enable trust mode.", + path.display() + ), + ToolError::ExecutionFailed { message } => message.clone(), + ToolError::Timeout { seconds } => format!( + "Tool '{tool_name}' timed out after {seconds}s. Try a narrower scope or a longer timeout." + ), + ToolError::NotAvailable { message } => format!( + "Tool '{tool_name}' is not available: {message}. Check mode, feature flags, or tool name." + ), + ToolError::PermissionDenied { message } => format!( + "Tool '{tool_name}' was denied: {message}. Adjust approval mode or request permission." 
+ ), + } +} + impl Engine { /// Create a new engine with the given configuration pub fn new(config: EngineConfig, api_config: &Config) -> (Self, EngineHandle) { @@ -382,9 +428,11 @@ impl Engine { ); // Set up system prompt with project context (default to agent mode) + let working_set_summary = session.working_set.summary_block(&config.workspace); let system_prompt = prompts::system_prompt_for_mode_with_context( AppMode::Agent, &config.workspace, + working_set_summary.as_deref(), None, None, ); @@ -472,19 +520,16 @@ impl Engine { Some(self.tx_event.clone()), ); - let result = self - .subagent_manager - .lock() - .map_err(|_| anyhow::anyhow!("Failed to lock sub-agent manager")) - .and_then(|mut manager| { - manager.spawn_background( - Arc::clone(&self.subagent_manager), - runtime, - SubAgentType::General, - prompt.clone(), - None, - ) - }); + let result = { + let mut manager = self.subagent_manager.lock().await; + manager.spawn_background( + Arc::clone(&self.subagent_manager), + runtime, + SubAgentType::General, + prompt.clone(), + None, + ) + }; match result { Ok(snapshot) => { @@ -508,26 +553,11 @@ impl Engine { } } Op::ListSubAgents => { - let result = self - .subagent_manager - .lock() - .map(|manager| manager.list()) - .map_err(|_| anyhow::anyhow!("Failed to lock sub-agent manager")); - - match result { - Ok(agents) => { - let _ = self.tx_event.send(Event::AgentList { agents }).await; - } - Err(err) => { - let _ = self - .tx_event - .send(Event::error( - format!("Failed to list sub-agents: {err}"), - true, - )) - .await; - } - } + let agents = { + let manager = self.subagent_manager.lock().await; + manager.list() + }; + let _ = self.tx_event.send(Event::AgentList { agents }).await; } Op::ChangeMode { mode } => { let _ = self @@ -575,6 +605,7 @@ impl Engine { } else { None }; + self.session.rebuild_working_set(); let _ = self .tx_event .send(Event::status("Session context synced".to_string())) @@ -613,6 +644,10 @@ impl Engine { return; } + self.session + 
.working_set + .observe_user_message(&content, &self.session.workspace); + // Add user message to session let user_msg = Message { role: "user".to_string(), @@ -652,9 +687,14 @@ impl Engine { } else { None }; + let working_set_summary = self + .session + .working_set + .summary_block(&self.config.workspace); self.session.system_prompt = Some(prompts::system_prompt_for_mode_with_context( mode, &self.config.workspace, + working_set_summary.as_deref(), rlm_summary.as_deref(), duo_summary.as_deref(), )); @@ -668,6 +708,8 @@ impl Engine { ToolRegistryBuilder::new() .with_read_only_file_tools() .with_search_tools() + .with_git_tools() + .with_diagnostics_tool() .with_todo_tool(todo_list.clone()) .with_plan_tool(plan_state.clone()) } else { @@ -675,6 +717,9 @@ impl Engine { .with_file_tools() .with_note_tool() .with_search_tools() + .with_git_tools() + .with_diagnostics_tool() + .with_test_runner_tool() .with_todo_tool(todo_list.clone()) .with_plan_tool(plan_state.clone()) }; @@ -932,6 +977,8 @@ impl Engine { .clone() .expect("DeepSeek client should be configured"); + let mut consecutive_tool_error_steps = 0u32; + loop { if self.cancel_token.is_cancelled() { let _ = self.tx_event.send(Event::status("Request cancelled")).await; @@ -946,8 +993,20 @@ impl Engine { break; } + let compaction_pins = self + .session + .working_set + .pinned_message_indices(&self.session.messages, &self.session.workspace); + let compaction_paths = self.session.working_set.top_paths(24); + if self.config.compaction.enabled - && should_compact(&self.session.messages, &self.config.compaction) + && should_compact( + &self.session.messages, + &self.config.compaction, + Some(&self.session.workspace), + Some(&compaction_pins), + Some(&compaction_paths), + ) { let _ = self .tx_event @@ -957,6 +1016,9 @@ impl Engine { &client, &self.session.messages, &self.config.compaction, + Some(&self.session.workspace), + Some(&compaction_pins), + Some(&compaction_paths), ) .await { @@ -1303,18 +1365,16 @@ impl Engine 
{ None }; - let mut tool_tasks = FuturesUnordered::new(); - let mut outcomes: Vec> = Vec::with_capacity(tool_uses.len()); - outcomes.resize_with(tool_uses.len(), || None); - + let mut plans: Vec = Vec::with_capacity(tool_uses.len()); for (index, tool) in tool_uses.iter().enumerate() { let tool_id = tool.id.clone(); let tool_name = tool.name.clone(); let tool_input = tool.input.clone(); crate::logging::info(format!( - "Executing tool '{}' with input: {:?}", + "Planning tool '{}' with input: {:?}", tool_name, tool_input )); + let interactive = tool_name == "exec_shell" && tool_input .get("interactive") @@ -1323,166 +1383,216 @@ impl Engine { let mut approval_required = false; let mut approval_description = "Tool execution requires approval".to_string(); - let mut supports_parallel = McpPool::is_mcp_tool(&tool_name); - if let Some(registry) = tool_registry - && let Some(spec) = registry.get(&tool_name) - { - approval_required = spec.approval_requirement() != ApprovalRequirement::Auto; - approval_description = spec.description().to_string(); - supports_parallel = spec.supports_parallel(); - } + let mut supports_parallel = false; + let mut read_only = false; - // Handle approval flow: returns (result_override, context_override) - let (result_override, context_override): ( - Option>, - Option, - ) = if approval_required { - let _ = self - .tx_event - .send(Event::ApprovalRequired { - id: tool_id.clone(), - tool_name: tool_name.clone(), - description: approval_description, - }) - .await; - - match self.await_tool_approval(&tool_id).await { - Ok(ApprovalResult::Approved) => (None, None), - Ok(ApprovalResult::Denied) => ( - Some(Err(ToolError::permission_denied(format!( - "Tool '{tool_name}' denied by user" - )))), - None, - ), - Ok(ApprovalResult::RetryWithPolicy(policy)) => { - // Create a context with the elevated sandbox policy - let elevated_context = tool_registry - .map(|r| r.context().clone().with_elevated_sandbox_policy(policy)); - (None, elevated_context) - } - 
Err(err) => (Some(Err(err)), None), + if !McpPool::is_mcp_tool(&tool_name) { + if let Some(registry) = tool_registry + && let Some(spec) = registry.get(&tool_name) + { + approval_required = + spec.approval_requirement() != ApprovalRequirement::Auto; + approval_description = spec.description().to_string(); + supports_parallel = spec.supports_parallel(); + read_only = spec.is_read_only(); } - } else { - (None, None) - }; - - let registry = tool_registry; - let lock = tool_exec_lock.clone(); - let mcp_pool = mcp_pool.clone(); - let tx_event = self.tx_event.clone(); - - if let Some(result_override) = result_override { - let started_at = Instant::now(); - let _ = self - .tx_event - .send(Event::ToolCallComplete { - id: tool_id.clone(), - name: tool_name.clone(), - result: result_override.clone(), - }) - .await; - outcomes[index] = Some(ToolExecOutcome { - index, - id: tool_id, - name: tool_name, - input: tool_input, - started_at, - result: result_override, - }); - continue; } - if approval_required { - let started_at = Instant::now(); - let result = Self::execute_tool_with_lock( - lock, - supports_parallel, - interactive, - self.tx_event.clone(), - tool_name.clone(), - tool_input.clone(), - registry, - mcp_pool.clone(), - context_override, - ) - .await; - let _ = self - .tx_event - .send(Event::ToolCallComplete { - id: tool_id.clone(), - name: tool_name.clone(), - result: result.clone(), - }) - .await; - outcomes[index] = Some(ToolExecOutcome { - index, - id: tool_id, - name: tool_name, - input: tool_input, - started_at, - result, - }); - continue; - } - - let started_at = Instant::now(); - tool_tasks.push(async move { - let result = Engine::execute_tool_with_lock( - lock, - supports_parallel, - interactive, - tx_event.clone(), - tool_name.clone(), - tool_input.clone(), - registry, - mcp_pool, - None, // No context override for non-approval-required tools - ) - .await; - - let _ = tx_event - .send(Event::ToolCallComplete { - id: tool_id.clone(), - name: 
tool_name.clone(), - result: result.clone(), - }) - .await; - - ToolExecOutcome { - index, - id: tool_id, - name: tool_name, - input: tool_input, - started_at, - result, - } + plans.push(ToolExecutionPlan { + index, + id: tool_id, + name: tool_name, + input: tool_input, + interactive, + approval_required, + approval_description, + supports_parallel, + read_only, }); } - while let Some(outcome) = tool_tasks.next().await { - let index = outcome.index; - outcomes[index] = Some(outcome); + let parallel_allowed = should_parallelize_tool_batch(&plans); + if parallel_allowed && plans.len() > 1 { + let _ = self + .tx_event + .send(Event::status(format!( + "Executing {} read-only tools in parallel", + plans.len() + ))) + .await; + } else if plans.len() > 1 { + let _ = self + .tx_event + .send(Event::status( + "Executing tools sequentially (writes, approvals, or non-parallel tools detected)", + )) + .await; } + let mut outcomes: Vec> = Vec::with_capacity(plans.len()); + outcomes.resize_with(plans.len(), || None); + + if parallel_allowed { + let mut tool_tasks = FuturesUnordered::new(); + for plan in plans { + let registry = tool_registry; + let lock = tool_exec_lock.clone(); + let mcp_pool = mcp_pool.clone(); + let tx_event = self.tx_event.clone(); + let started_at = Instant::now(); + + tool_tasks.push(async move { + let result = Engine::execute_tool_with_lock( + lock, + plan.supports_parallel, + plan.interactive, + tx_event.clone(), + plan.name.clone(), + plan.input.clone(), + registry, + mcp_pool, + None, + ) + .await; + + let _ = tx_event + .send(Event::ToolCallComplete { + id: plan.id.clone(), + name: plan.name.clone(), + result: result.clone(), + }) + .await; + + ToolExecOutcome { + index: plan.index, + id: plan.id, + name: plan.name, + input: plan.input, + started_at, + result, + } + }); + } + + while let Some(outcome) = tool_tasks.next().await { + let index = outcome.index; + outcomes[index] = Some(outcome); + } + } else { + for plan in plans { + let tool_id = 
plan.id.clone(); + let tool_name = plan.name.clone(); + let tool_input = plan.input.clone(); + + // Handle approval flow: returns (result_override, context_override) + let (result_override, context_override): ( + Option>, + Option, + ) = if plan.approval_required { + let _ = self + .tx_event + .send(Event::ApprovalRequired { + id: tool_id.clone(), + tool_name: tool_name.clone(), + description: plan.approval_description.clone(), + }) + .await; + + match self.await_tool_approval(&tool_id).await { + Ok(ApprovalResult::Approved) => (None, None), + Ok(ApprovalResult::Denied) => ( + Some(Err(ToolError::permission_denied(format!( + "Tool '{tool_name}' denied by user" + )))), + None, + ), + Ok(ApprovalResult::RetryWithPolicy(policy)) => { + let elevated_context = tool_registry.map(|r| { + r.context().clone().with_elevated_sandbox_policy(policy) + }); + (None, elevated_context) + } + Err(err) => (Some(Err(err)), None), + } + } else { + (None, None) + }; + + let started_at = Instant::now(); + let result = if let Some(result_override) = result_override { + result_override + } else { + Self::execute_tool_with_lock( + tool_exec_lock.clone(), + plan.supports_parallel, + plan.interactive, + self.tx_event.clone(), + tool_name.clone(), + tool_input.clone(), + tool_registry, + mcp_pool.clone(), + context_override, + ) + .await + }; + + let _ = self + .tx_event + .send(Event::ToolCallComplete { + id: tool_id.clone(), + name: tool_name.clone(), + result: result.clone(), + }) + .await; + + outcomes[plan.index] = Some(ToolExecOutcome { + index: plan.index, + id: tool_id, + name: tool_name, + input: tool_input, + started_at, + result, + }); + } + } + + let mut step_error_count = 0usize; + for outcome in outcomes.into_iter().flatten() { let duration = outcome.started_at.elapsed(); + let tool_input = outcome.input.clone(); + let tool_name_for_ws = outcome.name.clone(); let mut tool_call = TurnToolCall::new(outcome.id.clone(), outcome.name.clone(), outcome.input); match outcome.result { 
Ok(output) => { - tool_call.set_result(output.content.clone(), duration); + let output_content = output.content; + tool_call.set_result(output_content.clone(), duration); + self.session.working_set.observe_tool_call( + &tool_name_for_ws, + &tool_input, + Some(&output_content), + &self.session.workspace, + ); self.session.add_message(Message { role: "user".to_string(), content: vec![ContentBlock::ToolResult { tool_use_id: outcome.id, - content: output.content, + content: output_content, }], }); } Err(e) => { - let error = e.to_string(); + step_error_count += 1; + let error = format_tool_error(&e, &outcome.name); tool_call.set_error(error.clone(), duration); + self.session.working_set.observe_tool_call( + &tool_name_for_ws, + &tool_input, + Some(&error), + &self.session.workspace, + ); self.session.add_message(Message { role: "user".to_string(), content: vec![ContentBlock::ToolResult { @@ -1496,6 +1606,22 @@ impl Engine { turn.record_tool_call(tool_call); } + if step_error_count > 0 { + consecutive_tool_error_steps = consecutive_tool_error_steps.saturating_add(1); + } else { + consecutive_tool_error_steps = 0; + } + + if consecutive_tool_error_steps >= 3 { + let _ = self + .tx_event + .send(Event::status( + "Stopping after repeated tool failures. 
Try a narrower scope or adjust approvals.", + )) + .await; + break; + } + turn.next_step(); } } @@ -1521,3 +1647,83 @@ pub fn spawn_engine(config: EngineConfig, api_config: &Config) -> EngineHandle { handle } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + use std::path::PathBuf; + use std::time::Instant; + + fn make_plan( + read_only: bool, + supports_parallel: bool, + approval_required: bool, + interactive: bool, + ) -> ToolExecutionPlan { + ToolExecutionPlan { + index: 0, + id: "tool-1".to_string(), + name: "grep_files".to_string(), + input: json!({"pattern": "test"}), + interactive, + approval_required, + approval_description: "desc".to_string(), + supports_parallel, + read_only, + } + } + + #[test] + fn parallel_batch_requires_read_only_parallel_tools() { + let plans = vec![make_plan(true, true, false, false)]; + assert!(should_parallelize_tool_batch(&plans)); + + let plans = vec![ + make_plan(true, true, false, false), + make_plan(true, true, false, false), + ]; + assert!(should_parallelize_tool_batch(&plans)); + + let plans = vec![make_plan(false, true, false, false)]; + assert!(!should_parallelize_tool_batch(&plans)); + + let plans = vec![make_plan(true, false, false, false)]; + assert!(!should_parallelize_tool_batch(&plans)); + + let plans = vec![make_plan(true, true, true, false)]; + assert!(!should_parallelize_tool_batch(&plans)); + + let plans = vec![make_plan(true, true, false, true)]; + assert!(!should_parallelize_tool_batch(&plans)); + } + + #[test] + fn tool_error_messages_include_actionable_hints() { + let path_error = ToolError::path_escape(PathBuf::from("../escape.txt")); + let formatted = format_tool_error(&path_error, "read_file"); + assert!(formatted.contains("escapes workspace")); + + let missing_field = ToolError::missing_field("path"); + let formatted = format_tool_error(&missing_field, "read_file"); + assert!(formatted.contains("missing required field")); + + let timeout = ToolError::Timeout { seconds: 5 }; + let 
formatted = format_tool_error(&timeout, "exec_shell"); + assert!(formatted.contains("timed out")); + } + + #[test] + fn tool_exec_outcome_tracks_duration() { + let outcome = ToolExecOutcome { + index: 0, + id: "tool-1".to_string(), + name: "grep_files".to_string(), + input: json!({"pattern": "test"}), + started_at: Instant::now(), + result: Ok(ToolResult::success("ok")), + }; + + assert!(outcome.started_at.elapsed().as_nanos() > 0); + } +} diff --git a/src/core/session.rs b/src/core/session.rs index 1f8487b3..622613b0 100644 --- a/src/core/session.rs +++ b/src/core/session.rs @@ -4,6 +4,7 @@ use crate::models::{Message, SystemPrompt, Usage}; use crate::project_context::{ProjectContext, load_project_context_with_parents}; +use crate::working_set::WorkingSet; use std::path::PathBuf; /// Session state for the engine. @@ -41,6 +42,9 @@ pub struct Session { /// Project context loaded from AGENTS.md, etc. pub project_context: Option, + + /// Repo-aware working set for context management. + pub working_set: WorkingSet, } /// Cumulative usage statistics for a session. @@ -96,6 +100,7 @@ impl Session { } else { None }, + working_set: WorkingSet::default(), } } @@ -111,6 +116,12 @@ impl Session { self.messages.push(message); } + /// Rebuild the working set from current messages (best effort). + pub fn rebuild_working_set(&mut self) { + self.working_set + .rebuild_from_messages(&self.messages, &self.workspace); + } + /// Clear the conversation history pub fn clear(&mut self) { self.messages.clear(); diff --git a/src/eval.rs b/src/eval.rs new file mode 100644 index 00000000..c2154aa3 --- /dev/null +++ b/src/eval.rs @@ -0,0 +1,636 @@ +//! Offline evaluation harness for exercising representative tool loops. +//! +//! This module is intentionally self-contained so it can be wired into a CLI +//! command later without calling the network or any LLM endpoints. 
+ +use anyhow::{Context, Result, anyhow}; +use ignore::WalkBuilder; +use regex::Regex; +use serde::Serialize; +use std::collections::BTreeMap; +use std::fs; +use std::path::{Path, PathBuf}; +use std::process::Command; +use std::time::{Duration, Instant}; +use tempfile::TempDir; + +/// Representative tool steps covered by the evaluation harness. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)] +pub enum ScenarioStepKind { + List, + Read, + Search, + Edit, + ApplyPatch, + ExecShell, +} + +impl ScenarioStepKind { + /// Tool name associated with this step. + pub fn tool_name(self) -> &'static str { + match self { + ScenarioStepKind::List => "list_dir", + ScenarioStepKind::Read => "read_file", + ScenarioStepKind::Search => "search", + ScenarioStepKind::Edit => "edit_file", + ScenarioStepKind::ApplyPatch => "apply_patch", + ScenarioStepKind::ExecShell => "exec_shell", + } + } + + /// Parse a step kind from CLI-friendly strings. + pub fn parse(value: &str) -> Option { + match value.trim().to_lowercase().as_str() { + "list" | "list_dir" => Some(Self::List), + "read" | "read_file" => Some(Self::Read), + "search" | "grep" | "grep_files" => Some(Self::Search), + "edit" | "edit_file" => Some(Self::Edit), + "patch" | "apply_patch" => Some(Self::ApplyPatch), + "shell" | "exec_shell" | "exec" => Some(Self::ExecShell), + _ => None, + } + } +} + +/// Aggregate statistics for a single tool kind. +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)] +pub struct ToolStats { + pub invocations: usize, + pub errors: usize, + pub total_duration: Duration, +} + +/// Top-level metrics produced by an evaluation run. +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct EvalMetrics { + pub success: bool, + pub tool_errors: usize, + pub steps: usize, + pub duration: Duration, + pub per_tool: BTreeMap, +} + +/// One tool invocation recorded by the harness. 
+#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct EvalStep { + pub kind: ScenarioStepKind, + pub tool_name: &'static str, + pub success: bool, + pub duration: Duration, + pub error: Option, + pub output: Option, +} + +/// Summary of the generated temporary workspace. +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct WorkspaceSummary { + pub root: PathBuf, + pub file_count: usize, + pub files: Vec, +} + +/// Configuration for the offline evaluation harness. +#[derive(Debug, Clone)] +pub struct EvalHarnessConfig { + /// Human-readable scenario name for reporting. + pub scenario_name: String, + /// If set, the harness will intentionally fail this step to test metrics. + pub fail_step: Option, + /// Shell command executed during the `exec_shell` step. + pub shell_command: String, + /// Token that must appear in shell output for validation. + pub shell_expect_token: String, + /// Maximum characters stored for step output summaries. + pub max_output_chars: usize, +} + +impl Default for EvalHarnessConfig { + fn default() -> Self { + let shell_command = if cfg!(windows) { + "echo eval-harness".to_string() + } else { + "printf eval-harness".to_string() + }; + Self { + scenario_name: "offline-tool-loop".to_string(), + fail_step: None, + shell_command, + shell_expect_token: "eval-harness".to_string(), + max_output_chars: 240, + } + } +} + +/// Offline harness that exercises representative tool loops in a temp workspace. +#[derive(Debug, Clone)] +pub struct EvalHarness { + config: EvalHarnessConfig, +} + +impl EvalHarness { + /// Create a new harness with the provided configuration. + pub fn new(config: EvalHarnessConfig) -> Self { + Self { config } + } + + /// Execute the offline evaluation scenario and return detailed results. 
+ pub fn run(&self) -> Result { + let started_at = Instant::now(); + let workspace = tempfile::Builder::new() + .prefix("deepseek-eval-") + .tempdir() + .context("failed to create evaluation workspace")?; + + let seed = seed_workspace(workspace.path())?; + + let mut steps = Vec::new(); + let mut per_tool: BTreeMap = BTreeMap::new(); + + let list_output = self.run_step(ScenarioStepKind::List, &mut steps, &mut per_tool, || { + let entries = list_dir(workspace.path())?; + Ok(entries.join(", ")) + }); + + let _read_output = self.run_step(ScenarioStepKind::Read, &mut steps, &mut per_tool, || { + let path = if self.config.fail_step == Some(ScenarioStepKind::Read) { + workspace.path().join("missing.txt") + } else { + seed.notes_path.clone() + }; + read_file(&path) + }); + + let search_output = + self.run_step(ScenarioStepKind::Search, &mut steps, &mut per_tool, || { + let root = if self.config.fail_step == Some(ScenarioStepKind::Search) { + workspace.path().join("missing-dir") + } else { + workspace.path().to_path_buf() + }; + let result = search_files(&root, "offline")?; + Ok(format!("matches={}", result.matches.len())) + }); + + let edit_output = self.run_step(ScenarioStepKind::Edit, &mut steps, &mut per_tool, || { + let path = if self.config.fail_step == Some(ScenarioStepKind::Edit) { + workspace.path().join("missing.txt") + } else { + seed.notes_path.clone() + }; + edit_file_append(&path, "edited = true")?; + Ok("appended line".to_string()) + }); + + let patch_output = self.run_step( + ScenarioStepKind::ApplyPatch, + &mut steps, + &mut per_tool, + || { + let patch = if self.config.fail_step == Some(ScenarioStepKind::ApplyPatch) { + "*** Begin Patch\n*** Update File: notes.txt\n@@\n-THIS LINE DOES NOT EXIST\n+broken\n*** End Patch\n" + .to_string() + } else { + "*** Begin Patch\n*** Update File: notes.txt\n@@\n status = \"draft\"\n-todo: offline metrics\n+todo: offline metrics (patched)\n*** End Patch\n" + .to_string() + }; + apply_patch(workspace.path(), &patch)?; + 
Ok("patch applied".to_string()) + }, + ); + + let shell_output = self.run_step( + ScenarioStepKind::ExecShell, + &mut steps, + &mut per_tool, + || { + let command = if self.config.fail_step == Some(ScenarioStepKind::ExecShell) { + "command_that_does_not_exist".to_string() + } else { + self.config.shell_command.clone() + }; + exec_shell(workspace.path(), &command) + }, + ); + + let duration = started_at.elapsed(); + + let workspace_summary = summarize_workspace(workspace.path(), list_output.as_deref())?; + + let validation_success = validate_outputs( + workspace.path(), + &self.config.shell_expect_token, + search_output.as_deref(), + edit_output.as_deref(), + patch_output.as_deref(), + shell_output.as_deref(), + ); + + let tool_errors = steps.iter().filter(|s| !s.success).count(); + let success = tool_errors == 0 && validation_success; + + let metrics = EvalMetrics { + success, + tool_errors, + steps: steps.len(), + duration, + per_tool, + }; + + Ok(EvalRun { + scenario_name: self.config.scenario_name.clone(), + workspace, + workspace_summary, + metrics, + steps, + }) + } + + fn run_step( + &self, + kind: ScenarioStepKind, + steps: &mut Vec, + per_tool: &mut BTreeMap, + f: F, + ) -> Option + where + F: FnOnce() -> Result, + T: ToString, + { + let started_at = Instant::now(); + let result = f(); + let duration = started_at.elapsed(); + + let stats = per_tool.entry(kind).or_default(); + stats.invocations += 1; + stats.total_duration += duration; + + match result { + Ok(value) => { + let output = truncate_output(&value.to_string(), self.config.max_output_chars); + steps.push(EvalStep { + kind, + tool_name: kind.tool_name(), + success: true, + duration, + error: None, + output: Some(output), + }); + Some(value) + } + Err(err) => { + stats.errors += 1; + steps.push(EvalStep { + kind, + tool_name: kind.tool_name(), + success: false, + duration, + error: Some(err.to_string()), + output: None, + }); + None + } + } + } +} + +impl Default for EvalHarness { + fn default() -> 
Self { + Self::new(EvalHarnessConfig::default()) + } +} + +/// Result of running the evaluation harness. +#[derive(Debug)] +pub struct EvalRun { + pub scenario_name: String, + workspace: TempDir, + pub workspace_summary: WorkspaceSummary, + pub metrics: EvalMetrics, + pub steps: Vec<EvalStep>, +} + +impl EvalRun { + /// Get the root of the temporary workspace. + pub fn workspace_root(&self) -> &Path { + self.workspace.path() + } + + /// Convert the run into a serializable report for CLI output. + pub fn to_report(&self) -> EvalReport { + EvalReport { + scenario_name: self.scenario_name.clone(), + workspace_root: self.workspace_root().to_path_buf(), + workspace_summary: self.workspace_summary.clone(), + metrics: self.metrics.clone(), + steps: self.steps.clone(), + } + } +} + +/// Serializable report derived from an `EvalRun`. +#[derive(Debug, Clone, Serialize, PartialEq, Eq)] +pub struct EvalReport { + pub scenario_name: String, + pub workspace_root: PathBuf, + pub workspace_summary: WorkspaceSummary, + pub metrics: EvalMetrics, + pub steps: Vec<EvalStep>, +} + +#[derive(Debug, Clone)] +struct SeedWorkspace { + notes_path: PathBuf, +} + +fn seed_workspace(root: &Path) -> Result<SeedWorkspace> { + let src_dir = root.join("src"); + fs::create_dir_all(&src_dir) + .with_context(|| format!("failed to create seed directory: {}", src_dir.display()))?; + + let readme_path = root.join("README.md"); + fs::write( + &readme_path, + "# Eval Harness Workspace\n\nThis workspace is offline.\n", + ) + .with_context(|| format!("failed to write {}", readme_path.display()))?; + + let notes_path = root.join("notes.txt"); + fs::write( + &notes_path, + "# Eval Harness\nstatus = \"draft\"\ntodo: offline metrics\n", + ) + .with_context(|| format!("failed to write {}", notes_path.display()))?; + + let lib_path = src_dir.join("lib.rs"); + fs::write( + &lib_path, + "pub fn add(a: i32, b: i32) -> i32 {\n a + b\n}\n", + ) + .with_context(|| format!("failed to write {}", lib_path.display()))?; + + Ok(SeedWorkspace { notes_path }) +}
+ +fn summarize_workspace(root: &Path, list_output: Option<&str>) -> Result<WorkspaceSummary> { + let mut files = Vec::new(); + + let walker = WalkBuilder::new(root) + .hidden(false) + .git_ignore(false) + .git_global(false) + .git_exclude(false) + .build(); + + for entry in walker { + let entry = entry.with_context(|| format!("failed to walk {}", root.display()))?; + if entry.file_type().is_some_and(|t| t.is_file()) { + files.push(entry.into_path()); + } + } + + if files.is_empty() + && let Some(output) = list_output + && !output.trim().is_empty() + { + return Err(anyhow!( + "workspace appears empty after list_dir: {}", + output.trim() + )); + } + + files.sort(); + + Ok(WorkspaceSummary { + root: root.to_path_buf(), + file_count: files.len(), + files, + }) +} + +fn validate_outputs( + root: &Path, + shell_expect_token: &str, + search_output: Option<&str>, + edit_output: Option<&str>, + patch_output: Option<&str>, + shell_output: Option<&str>, +) -> bool { + let notes_path = root.join("notes.txt"); + let notes = match fs::read_to_string(&notes_path) { + Ok(content) => content, + Err(_) => return false, + }; + + let search_ok = search_output.is_some_and(|s| s.contains("matches=")); + let edit_ok = edit_output.is_some_and(|s| !s.is_empty()) && notes.contains("edited = true"); + let patch_ok = patch_output.is_some_and(|s| !s.is_empty()) + && notes.contains("todo: offline metrics (patched)"); + let shell_ok = shell_output + .map(str::trim) + .is_some_and(|s| s.contains(shell_expect_token)); + + search_ok && edit_ok && patch_ok && shell_ok +} + +fn list_dir(path: &Path) -> Result<Vec<String>> { + let mut entries = Vec::new(); + let dir = fs::read_dir(path) + .with_context(|| format!("failed to read directory: {}", path.display()))?; + + for entry in dir { + let entry = entry.with_context(|| format!("failed to list {}", path.display()))?; + entries.push(entry.file_name().to_string_lossy().to_string()); + } + + entries.sort(); + Ok(entries) +} + +fn read_file(path: &Path) -> Result<String> { +
fs::read_to_string(path).with_context(|| format!("failed to read {}", path.display())) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct SearchMatch { + path: PathBuf, + line: usize, + content: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct SearchResult { + matches: Vec, +} + +fn search_files(root: &Path, pattern: &str) -> Result { + if !root.exists() { + return Err(anyhow!("search root does not exist: {}", root.display())); + } + + let regex = Regex::new(pattern).context("failed to compile search regex")?; + let mut matches = Vec::new(); + + let walker = WalkBuilder::new(root) + .hidden(false) + .git_ignore(false) + .git_global(false) + .git_exclude(false) + .build(); + + for entry in walker { + let entry = entry.with_context(|| format!("failed to walk {}", root.display()))?; + if !entry.file_type().is_some_and(|t| t.is_file()) { + continue; + } + + let path = entry.path(); + let content = match fs::read_to_string(path) { + Ok(c) => c, + Err(_) => continue, + }; + + for (idx, line) in content.lines().enumerate() { + if regex.is_match(line) { + matches.push(SearchMatch { + path: path.to_path_buf(), + line: idx + 1, + content: line.to_string(), + }); + } + if matches.len() >= 64 { + break; + } + } + if matches.len() >= 64 { + break; + } + } + + Ok(SearchResult { matches }) +} + +fn edit_file_append(path: &Path, line: &str) -> Result<()> { + let mut content = read_file(path)?; + if !content.ends_with('\n') { + content.push('\n'); + } + content.push_str(line); + content.push('\n'); + fs::write(path, content).with_context(|| format!("failed to write {}", path.display())) +} + +fn apply_patch(root: &Path, patch: &str) -> Result<()> { + let mut lines = patch.lines(); + + let begin = lines.next().unwrap_or_default(); + if begin != "*** Begin Patch" { + return Err(anyhow!("patch missing *** Begin Patch header")); + } + + let header = lines.next().unwrap_or_default(); + let file_rel = header + .strip_prefix("*** Update File: ") + .ok_or_else(|| 
anyhow!("only *** Update File patches are supported"))?; + if file_rel.contains("..") { + return Err(anyhow!("patch path must be workspace-relative")); + } + + let file_path = root.join(file_rel); + let original = read_file(&file_path)?; + let had_trailing_newline = original.ends_with('\n'); + let mut file_lines: Vec = original.lines().map(|l| l.to_string()).collect(); + + let mut cursor = 0usize; + for raw_line in lines { + if raw_line == "*** End Patch" { + break; + } + if raw_line.starts_with("*** ") { + return Err(anyhow!("unexpected patch directive: {raw_line}")); + } + if raw_line.starts_with("@@") { + continue; + } + + let (kind, rest) = raw_line.split_at(1); + let content = rest.to_string(); + + match kind { + " " => { + let Some(found) = file_lines[cursor..] + .iter() + .position(|line| line == &content) + .map(|offset| cursor + offset) + else { + return Err(anyhow!( + "patch context not found in {}: {}", + file_path.display(), + content + )); + }; + cursor = found + 1; + } + "-" => { + if cursor >= file_lines.len() || file_lines[cursor] != content { + return Err(anyhow!( + "patch removal mismatch in {}: expected '{}'", + file_path.display(), + content + )); + } + file_lines.remove(cursor); + } + "+" => { + file_lines.insert(cursor, content); + cursor += 1; + } + _ => return Err(anyhow!("unsupported patch line: {raw_line}")), + } + } + + let mut updated = file_lines.join("\n"); + if had_trailing_newline { + updated.push('\n'); + } + + fs::write(&file_path, updated) + .with_context(|| format!("failed to write patched file {}", file_path.display())) +} + +fn exec_shell(root: &Path, command: &str) -> Result { + #[cfg(windows)] + let output = Command::new("cmd") + .args(["/C", command]) + .current_dir(root) + .output() + .with_context(|| format!("failed to execute shell command: {command}"))?; + + #[cfg(not(windows))] + let output = Command::new("sh") + .arg("-c") + .arg(command) + .current_dir(root) + .output() + .with_context(|| format!("failed to execute 
shell command: {command}"))?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(anyhow!( + "shell command failed (status={}): {}", + output.status, + stderr.trim() + )); + } + + let stdout = String::from_utf8_lossy(&output.stdout).to_string(); + Ok(stdout.trim().to_string()) +} + +fn truncate_output(value: &str, max_chars: usize) -> String { + if value.chars().count() <= max_chars { + return value.to_string(); + } + + let truncated: String = value.chars().take(max_chars).collect(); + format!("{}...", truncated) +} diff --git a/src/main.rs b/src/main.rs index 512dd13b..c9881b1e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -19,7 +19,7 @@ use std::path::{Path, PathBuf}; use std::process::{Command, Stdio}; use std::time::Duration; -use anyhow::{Result, bail}; +use anyhow::{Context, Result, anyhow, bail}; use clap::{Args, CommandFactory, Parser, Subcommand}; use clap_complete::{Shell, generate}; use dotenvy::dotenv; @@ -33,6 +33,7 @@ mod compaction; mod config; mod core; mod duo; +mod eval; mod execpolicy; mod features; mod hooks; @@ -56,8 +57,10 @@ mod tools; mod tui; mod ui; mod utils; +mod working_set; -use crate::config::Config; +use crate::config::{Config, MAX_SUBAGENTS}; +use crate::eval::{EvalHarness, EvalHarnessConfig, ScenarioStepKind}; use crate::llm_client::LlmClient; use crate::mcp::{McpConfig, McpPool}; use crate::models::{ContentBlock, Message, MessageRequest, SystemPrompt}; @@ -88,7 +91,7 @@ struct Cli { #[arg(long)] yolo: bool, - /// Maximum number of concurrent sub-agents (1-5) + /// Maximum number of concurrent sub-agents (1-20) #[arg(long)] max_subagents: Option, @@ -161,6 +164,8 @@ enum Commands { Review(ReviewArgs), /// Apply a patch file (or stdin) to the working tree Apply(ApplyArgs), + /// Run the offline evaluation harness (no network/LLM calls) + Eval(EvalArgs), /// Manage MCP servers Mcp { #[command(subcommand)] @@ -209,6 +214,25 @@ struct ExecArgs { auto: bool, } +#[derive(Args, Debug, 
Clone)] +struct EvalArgs { + /// Intentionally fail a specific step (list, read, search, edit, patch, shell) + #[arg(long, value_name = "STEP")] + fail_step: Option, + /// Shell command to run during the exec step + #[arg(long, default_value = "printf eval-harness")] + shell_command: String, + /// Token that must appear in shell output for validation + #[arg(long, default_value = "eval-harness")] + shell_expect_token: String, + /// Maximum characters stored per step output summary + #[arg(long, default_value_t = 240)] + max_output_chars: usize, + /// Emit machine-readable JSON output + #[arg(long, default_value_t = false)] + json: bool, +} + #[derive(Args, Debug, Default, Clone)] struct FeatureToggles { /// Enable a feature (repeatable). Equivalent to `features.=true`. @@ -375,9 +399,10 @@ async fn main() -> Result<()> { let workspace = cli.workspace.clone().unwrap_or_else(|| { std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")) }); - let max_subagents = cli - .max_subagents - .map_or_else(|| config.max_subagents(), |value| value.clamp(1, 5)); + let max_subagents = cli.max_subagents.map_or_else( + || config.max_subagents(), + |value| value.clamp(1, MAX_SUBAGENTS), + ); let auto_mode = args.auto || cli.yolo; run_exec_agent( &config, @@ -398,6 +423,7 @@ async fn main() -> Result<()> { run_review(&config, args).await } Commands::Apply(args) => run_apply(args), + Commands::Eval(args) => run_eval(args), Commands::Mcp { command } => { let config = load_config_from_cli(&cli)?; run_mcp_command(&config, command).await @@ -468,6 +494,74 @@ fn generate_completions(shell: Shell) { generate(shell, &mut cmd, name, &mut io::stdout()); } +/// Run the offline evaluation harness (no network/LLM calls). 
+fn run_eval(args: EvalArgs) -> Result<()> { + let fail_step = match args.fail_step.as_deref() { + Some(value) => ScenarioStepKind::parse(value) + .map(Some) + .ok_or_else(|| anyhow!("invalid --fail-step '{value}'"))?, + None => None, + }; + + let config = EvalHarnessConfig { + fail_step, + shell_command: args.shell_command, + shell_expect_token: args.shell_expect_token, + max_output_chars: args.max_output_chars, + ..EvalHarnessConfig::default() + }; + + let harness = EvalHarness::new(config); + let run = harness.run().context("evaluation harness failed")?; + let report = run.to_report(); + + if args.json { + let json = serde_json::to_string_pretty(&report)?; + println!("{json}"); + } else { + println!("Offline Eval Harness"); + println!("scenario: {}", report.scenario_name); + println!("workspace: {}", report.workspace_root.display()); + println!("success: {}", report.metrics.success); + println!("steps: {}", report.metrics.steps); + println!("tool_errors: {}", report.metrics.tool_errors); + println!("duration_ms: {}", report.metrics.duration.as_millis()); + + if !report.metrics.per_tool.is_empty() { + println!("per_tool:"); + for (kind, stats) in &report.metrics.per_tool { + println!( + " {} invocations={} errors={} duration_ms={}", + kind.tool_name(), + stats.invocations, + stats.errors, + stats.total_duration.as_millis() + ); + } + } + + let failed_steps: Vec<_> = report.steps.iter().filter(|s| !s.success).collect(); + if !failed_steps.is_empty() { + println!("failed_steps:"); + for step in failed_steps { + let error = step.error.as_deref().unwrap_or("unknown error"); + println!( + " {} tool={} error={}", + step.kind.tool_name(), + step.tool_name, + error + ); + } + } + } + + if report.metrics.success { + Ok(()) + } else { + bail!("offline evaluation harness reported failure") + } +} + /// Run system diagnostics async fn run_doctor() { use crate::palette; @@ -1335,9 +1429,10 @@ async fn run_interactive( .default_text_model .clone() .unwrap_or_else(|| 
"deepseek-reasoner".to_string()); - let max_subagents = cli - .max_subagents - .map_or_else(|| config.max_subagents(), |value| value.clamp(1, 5)); + let max_subagents = cli.max_subagents.map_or_else( + || config.max_subagents(), + |value| value.clamp(1, MAX_SUBAGENTS), + ); let use_alt_screen = should_use_alt_screen(cli, config); tui::run_tui( @@ -1416,8 +1511,8 @@ async fn run_exec_agent( use crate::core::ops::Op; use crate::duo::DuoSession; use crate::rlm::RlmSession; - use crate::tools::plan::PlanState; - use crate::tools::todo::TodoList; + use crate::tools::plan::new_shared_plan_state; + use crate::tools::todo::new_shared_todo_list; use crate::tui::app::AppMode; let engine_config = EngineConfig { @@ -1433,8 +1528,8 @@ async fn run_exec_agent( rlm_session: Arc::new(Mutex::new(RlmSession::default())), duo_session: Arc::new(Mutex::new(DuoSession::new())), compaction: CompactionConfig::default(), - todos: Arc::new(Mutex::new(TodoList::new())), - plan_state: Arc::new(Mutex::new(PlanState::default())), + todos: new_shared_todo_list(), + plan_state: new_shared_plan_state(), }; let engine_handle = spawn_engine(engine_config, config); diff --git a/src/prompts.rs b/src/prompts.rs index dba6f725..4c09694f 100644 --- a/src/prompts.rs +++ b/src/prompts.rs @@ -32,6 +32,7 @@ pub fn system_prompt_for_mode(mode: AppMode) -> SystemPrompt { pub fn system_prompt_for_mode_with_context( mode: AppMode, workspace: &Path, + working_set_summary: Option<&str>, rlm_summary: Option<&str>, duo_summary: Option<&str>, ) -> SystemPrompt { @@ -53,6 +54,12 @@ pub fn system_prompt_for_mode_with_context( base_prompt.trim().to_string() }; + if let Some(summary) = working_set_summary + && !summary.trim().is_empty() + { + full_prompt = format!("{full_prompt}\n\n{summary}"); + } + if mode == AppMode::Rlm { let summary = rlm_summary.unwrap_or("No RLM contexts loaded."); full_prompt = format!("{full_prompt}\n\nRLM Context Summary:\n{summary}"); diff --git a/src/prompts/agent.txt b/src/prompts/agent.txt 
index 8d7d04c9..a05959c7 100644 --- a/src/prompts/agent.txt +++ b/src/prompts/agent.txt @@ -3,12 +3,28 @@ You are DeepSeek CLI, an agentic coding assistant with full tool access. IMPORTANT: You are ALREADY running inside the DeepSeek CLI TUI. You have direct access to all tools below - do NOT try to run or launch the CLI binary. Your tools execute directly in the current session. When given a task: -1. Break it into subtasks and track them with todo tools. -2. Work through each subtask systematically. -3. Report progress as you go. -4. Verify your work before marking complete. -5. Do not stop until the full task is done. -6. Avoid destructive actions (deletes, irreversible changes) unless the user explicitly requests them; suggest YOLO for high-risk changes. +1. Understand the goal, constraints, and acceptance criteria first. +2. Break work into small, testable steps and track them with todo tools. +3. Read and search first, then make targeted edits, then verify with tools. +4. Report concise progress updates at meaningful checkpoints. +5. Do not stop until the full task is done or you are clearly blocked. +6. Avoid destructive actions (deletes, irreversible changes) unless the user explicitly requests them; warn before risky actions and suggest YOLO for high-risk changes. + +Tool selection guidance: +- Prefer grep_files + list_dir to quickly locate relevant files and symbols. +- Use read_file to confirm context; do not assume file contents. +- Prefer apply_patch/edit_file for scoped changes instead of rewriting entire files. +- Use exec_shell for objective verification: build, test, format, lint, and targeted checks. +- Use web_search only when local context is insufficient or time-sensitive. + +Testing and stop conditions: +- After any change, run the most relevant tests/checks before declaring success. +- Start narrow (targeted tests) and expand to broader checks when appropriate. +- If a check fails, report it concisely, fix it, and re-run. 
+- Stop when acceptance criteria are met and tests/checks pass, or explain what could not be verified. + +Step budgeting: +- Budget attempts. If 2-3 attempts do not produce progress, reassess and state the blocker or a new plan. Available tools: @@ -21,6 +37,14 @@ FILE OPERATIONS: - grep_files: Search files by regex - web_search: Search the web for up-to-date information +GIT AND DIAGNOSTICS: +- git_status: Inspect repo status safely +- git_diff: Inspect working tree or staged diffs +- diagnostics: Report workspace, git, sandbox, and toolchain info + +TESTING: +- run_tests: Run `cargo test` with optional args + SHELL EXECUTION: - exec_shell: Run shell commands (supports background execution) - command: The command to execute @@ -34,14 +58,24 @@ TASK MANAGEMENT: SUB-AGENTS: - agent_spawn: Spawn a background sub-agent (type, prompt, allowed_tools) +- agent_swarm: Spawn a dependency-aware swarm of sub-agents (tasks, shared_context) - agent_result: Get result from a sub-agent (agent_id, block, timeout_ms) - agent_cancel: Cancel a running sub-agent (agent_id) - agent_list: List all sub-agents and their status If you spawn a sub-agent, always follow up with agent_result (block: true) and incorporate its result before responding to the user. +If you use agent_swarm, incorporate its aggregated results (or follow up with agent_result for any running agents) before responding. -For complex work, call update_plan to publish a checklist. -Keep exactly one plan step in_progress at a time. -Use todo tools for granular progress when helpful. +Planning and progress: +- For complex or multi-file work, call update_plan to publish a checklist. +- Keep exactly one plan step in_progress at a time. +- Use todo tools for granular progress when helpful. +- Prefer short progress notes over long narration. + +Git hygiene: +- Run git status early (to see the workspace state) and again before finishing. +- Do not revert or overwrite unrelated user changes. 
+- Avoid destructive git commands unless explicitly requested. +- Do not commit unless the user asks. BACKGROUND EXECUTION: For long-running commands (build, test, server), use exec_shell with background: true. diff --git a/src/prompts/base.txt b/src/prompts/base.txt index 3386f547..d5473849 100644 --- a/src/prompts/base.txt +++ b/src/prompts/base.txt @@ -1,14 +1,38 @@ You are DeepSeek CLI, an agentic coding assistant. When given a task: -1. Break it into subtasks and track them. -2. Work through each subtask systematically. -3. Report progress as you go. -4. Verify your work before marking complete. -5. Do not stop until the full task is done. +1. Understand the goal, constraints, and acceptance criteria first. +2. Break the work into small, testable steps and track them. +3. Choose tools deliberately; read before you write, then verify. +4. Report short progress updates at meaningful checkpoints. +5. Do not stop until the full task is done or you are clearly blocked. -Use tools when needed. For complex work, call update_plan to publish a checklist. -Keep exactly one plan step in_progress at a time. -Use todo tools for granular progress when helpful. +Tool selection guidance: +- Prefer fast search tools (grep/rg) to locate relevant files and symbols. +- Use read tools to confirm context; avoid guessing about file contents. +- Prefer targeted edits (apply_patch/edit) over full rewrites when possible. +- Use shell tools for build/test/format/lint and other objective verification. +- Use web search only when the answer may be time-sensitive or unclear locally. + +Planning and progress: +- For non-trivial tasks, publish a checklist with update_plan. +- Keep exactly one plan step in_progress at a time. +- Use todo tools for granular progress when helpful. +- Budget your steps: if 2-3 attempts fail to make progress, pause, reassess, and state the blocker. + +Testing and stop conditions: +- After any change, run the most relevant tests/checks before declaring success. 
+- If tests fail, report the failure concisely, fix it, and re-run. +- Stop when acceptance criteria are met and checks/tests pass (or explain why they could not run). + +Git hygiene: +- Check git status early and again before finishing. +- Do not revert or overwrite unrelated user changes. +- Avoid destructive git commands unless explicitly requested. +- Do not commit unless the user asks. + +Approval etiquette: +- In approval-gated modes, ask before writes or shell commands. +- In autonomous modes, warn before risky or irreversible actions. Tone: competent, warm, and concise. Use light humor sparingly when it fits; a rare example is "You're absolutely right! ... maybe." diff --git a/src/prompts/duo.txt b/src/prompts/duo.txt index 904f3212..66abf6e2 100644 --- a/src/prompts/duo.txt +++ b/src/prompts/duo.txt @@ -1,3 +1,18 @@ You are in Duo mode for requirements-driven development. -Use duo_init with a requirements checklist, then alternate duo_player (implement) and duo_coach (verify) until approved. +Workflow: +- Start with duo_init using a clear requirements checklist and acceptance criteria. +- Alternate duo_player (implement) and duo_coach (verify) until approved. +- In duo_player phases, work tool-first: search, read, make targeted edits, then verify. +- In duo_coach phases, prioritize objective verification and requirement coverage. + +Tool selection and verification: +- Use search/read tools to ground work in the current codebase before editing. +- Prefer targeted diffs over broad rewrites when possible. +- After any change, run the most relevant tests/checks before handing off to duo_coach. +- If verification fails, report it concisely, fix it, and re-run. + +Budgeting and hygiene: +- Budget attempts. If 2-3 attempts fail to make progress, reassess and state the blocker. +- Check git status early and again before finishing; do not revert unrelated changes. +- Provide brief progress updates at phase boundaries and major checkpoints. 
diff --git a/src/prompts/normal.txt b/src/prompts/normal.txt index 6cc019ef..6ff3c8b9 100644 --- a/src/prompts/normal.txt +++ b/src/prompts/normal.txt @@ -12,14 +12,35 @@ Available tools in this mode: - apply_patch: Apply a unified diff patch (ask first) - grep_files: Search files by regex - web_search: Search the web for up-to-date information +- git_status: Inspect repository status safely +- git_diff: Inspect diffs (working tree or staged) +- diagnostics: Report workspace, git, sandbox, and toolchain info +- run_tests: Run `cargo test` with optional args - exec_shell: Run shell commands (ask first, if enabled) - note: Record important information - todo_write: Write or update the todo list - update_plan: Publish a structured plan Guidelines: -1. Answer questions clearly and concisely -2. Provide code examples when helpful -3. You CAN read files and explore the codebase -4. Ask for explicit approval before any file writes, patches, or shell commands -5. If the user wants fully autonomous changes, suggest pressing Tab to switch to Agent or YOLO mode +1. Understand the goal and constraints before proposing changes. +2. Prefer tool-centric reasoning: search, read, then act. +3. Answer clearly and concisely; provide code examples when helpful. +4. You CAN read files and explore the codebase without approval. +5. Ask for explicit approval before any file writes, patches, or shell commands. +6. If the user wants fully autonomous changes, suggest pressing Tab to switch to Agent or YOLO mode. + +Tool selection guidance: +- Use grep_files/list_dir to find relevant files quickly. +- Use read_file to ground your answer in the actual code. +- When approved to edit, prefer apply_patch/edit_file for targeted diffs. +- When approved to run commands, use exec_shell for build/test/format/lint and other objective checks. + +Testing and stop conditions (after approval to edit/run commands): +- After any change, run the most relevant tests/checks before declaring success. 
+- If a check fails, report it concisely, fix it, and re-run. +- Stop when acceptance criteria are met and checks pass, or explain what could not be verified. + +Step budgeting and progress: +- For non-trivial tasks, propose a short plan and use update_plan/todo_write when helpful. +- Provide brief progress updates at key checkpoints, not every small action. +- If 2-3 attempts fail, pause and ask a focused clarifying question. diff --git a/src/prompts/plan.txt b/src/prompts/plan.txt index 9c9ad8ef..294ab5d5 100644 --- a/src/prompts/plan.txt +++ b/src/prompts/plan.txt @@ -3,10 +3,10 @@ You are DeepSeek CLI in PLAN mode. Design before implementing. This mode is read-only: you can analyze and plan, but you cannot edit files or run shell commands. In this mode, focus on: -1. Understanding requirements fully before proposing solutions -2. Breaking down complex tasks into clear, actionable steps -3. Identifying potential issues and edge cases upfront -4. Creating a detailed plan using update_plan before implementation +1. Understanding requirements, constraints, and acceptance criteria fully. +2. Breaking complex tasks into clear, actionable, testable steps. +3. Identifying potential issues, regressions, and edge cases upfront. +4. Creating a detailed plan using update_plan before implementation. 
Available tools: @@ -19,13 +19,16 @@ EXPLORATION: - read_file: Read file contents to understand context - grep_files: Search files by regex - web_search: Search the web for up-to-date information (if enabled) +- git_status: Inspect repository status safely +- git_diff: Inspect diffs to understand current changes +- diagnostics: Report workspace, git, sandbox, and toolchain info Guidelines: -- Focus on planning before making changes -- Use update_plan to create structured plans -- Each step should be specific and actionable -- Include acceptance criteria where possible -- Identify dependencies between steps -- Call out risks, edge cases, and verification steps -- Ask clarifying questions if requirements are unclear -- After the plan is ready, summarize briefly and wait for user direction +- Prefer tool-centric planning: use grep_files/list_dir/read_file to ground the plan in the actual codebase. +- Use update_plan to create structured plans with one step in_progress at a time. +- Each step should be specific, actionable, and include expected outcomes. +- Include explicit verification steps (tests/checks) after each planned change. +- Include git hygiene in the plan: check git status early and before finishing; avoid reverting unrelated changes. +- Identify dependencies, risks, edge cases, and rollback/mitigation ideas. +- Budget steps: if key facts are missing after 2-3 exploration attempts, ask a focused clarifying question. +- Provide concise progress notes, then wait for user direction once the plan is ready. diff --git a/src/prompts/rlm.txt b/src/prompts/rlm.txt index 466ed1fd..4c6689dc 100644 --- a/src/prompts/rlm.txt +++ b/src/prompts/rlm.txt @@ -1,3 +1,13 @@ You are in RLM mode for working with large files that exceed context limits. -Use rlm_* tools to load files, explore content, and run focused queries over chunks. +Work tool-first and chunk-aware: +- Use rlm_* tools to load files, explore content, and run focused queries over chunks. 
+- Prefer search-then-read: locate relevant sections before loading large spans. +- Summarize the relevant chunks in your own words before editing. +- Make targeted edits; avoid full-file rewrites unless necessary. + +Verification, budgeting, and hygiene: +- After any change, run the most relevant tests/checks before declaring success. +- Check git status early and again before finishing; do not revert unrelated changes. +- Budget attempts. If 2-3 chunking/search attempts fail to surface the needed context, pause and state the blocker or request clarification. +- Provide concise progress updates at key checkpoints (what chunk/area you inspected, what changed, what verified). diff --git a/src/tools/apply_patch.rs b/src/tools/apply_patch.rs index 768bbc25..c40c65c2 100644 --- a/src/tools/apply_patch.rs +++ b/src/tools/apply_patch.rs @@ -3,6 +3,7 @@ //! This tool provides precise file modifications using unified diff format, //! supporting multi-hunk patches and fuzzy matching. +use std::collections::HashSet; use std::fs; use std::path::PathBuf; @@ -18,6 +19,10 @@ use super::spec::{ /// Maximum lines of context for fuzzy matching (increased for better tolerance) const MAX_FUZZ: usize = 50; +/// Limit how much context we print in error messages. +const HUNK_PREVIEW_LINES: usize = 4; +const SNIPPET_RADIUS: usize = 2; +const FILE_LIST_LIMIT: usize = 6; // === Types === @@ -30,9 +35,27 @@ pub struct PatchResult { pub hunks_applied: usize, pub hunks_total: usize, pub fuzz_used: usize, + #[serde(default)] + pub hunks_with_fuzz: usize, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub touched_files: Vec, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub file_summaries: Vec, pub message: String, } +/// Per-file summary for patch application output. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FileSummary { + pub path: String, + pub hunks: usize, + pub hunks_applied: usize, + pub fuzz_used: usize, + pub hunks_with_fuzz: usize, + pub created: bool, + pub deleted: bool, +} + /// A single hunk in a unified diff #[derive(Debug, Clone)] pub struct Hunk { @@ -76,6 +99,34 @@ struct PatchStats { hunks_applied: usize, hunks_total: usize, fuzz_used: usize, + hunks_with_fuzz: usize, +} + +#[derive(Debug, Default, Clone)] +struct PatchStatsExt { + stats: PatchStats, + touched_files: Vec, + file_summaries: Vec, + header_path_mismatch: Option, +} + +#[derive(Debug, Default, Clone)] +struct PatchShape { + has_hunks: bool, + header_files: Vec, +} + +impl PatchShape { + fn file_count(&self) -> usize { + self.header_files.len() + } +} + +#[derive(Debug, Default, Clone, Copy)] +struct HunkApplyStats { + hunks_applied: usize, + fuzz_used: usize, + hunks_with_fuzz: usize, } // === Errors === @@ -160,21 +211,19 @@ impl ToolSpec for ApplyPatchTool { let create_if_missing = optional_bool(&input, "create_if_missing", false); if let Some(changes_value) = input.get("changes") { - let pending = build_pending_writes_from_changes(changes_value, context)?; - let stats = PatchStats { - files_total: pending.len(), - files_applied: pending.len(), - ..PatchStats::default() - }; + let (pending, stats) = build_pending_writes_from_changes(changes_value, context)?; apply_pending_writes(&pending)?; let result = PatchResult { success: true, - files_applied: stats.files_applied, - files_total: stats.files_total, - hunks_applied: stats.hunks_applied, - hunks_total: stats.hunks_total, - fuzz_used: stats.fuzz_used, - message: format!("Applied {} file change(s)", stats.files_applied), + files_applied: stats.stats.files_applied, + files_total: stats.stats.files_total, + hunks_applied: stats.stats.hunks_applied, + hunks_total: stats.stats.hunks_total, + fuzz_used: stats.stats.fuzz_used, + hunks_with_fuzz: stats.stats.hunks_with_fuzz, + 
touched_files: stats.touched_files.clone(), + file_summaries: stats.file_summaries.clone(), + message: build_summary_message(&stats), }; return ToolResult::json(&result) .map_err(|e| ToolError::execution_failed(e.to_string())); @@ -182,10 +231,15 @@ impl ToolSpec for ApplyPatchTool { let patch_text = required_str(&input, "patch")?; let path_override = optional_str(&input, "path"); + let patch_shape = inspect_patch_shape(patch_text); + validate_patch_shape(&patch_shape, path_override)?; + let mismatch_note = path_override.and_then(|path| diff_header_mismatch(path, &patch_shape)); let file_patches = if let Some(path) = path_override { let hunks = parse_unified_diff(patch_text)?; if hunks.is_empty() { - return Err(ToolError::invalid_input("No valid hunks found in patch")); + return Err(ToolError::invalid_input( + "Patch did not contain any hunks (`@@ ... @@`). Provide a unified diff hunk.", + )); } vec![FilePatch { path: path.to_string(), @@ -197,25 +251,28 @@ impl ToolSpec for ApplyPatchTool { let file_patches = parse_unified_diff_files(patch_text, create_if_missing)?; if file_patches.is_empty() { return Err(ToolError::invalid_input( - "No valid file patches found in unified diff", + "No valid file patches found. 
Ensure the patch includes `---`/`+++` headers or provide `path`.", )); } file_patches }; - let (pending, stats) = build_pending_writes_from_patches(file_patches, context, fuzz)?; + let (pending, mut stats) = build_pending_writes_from_patches(file_patches, context, fuzz)?; + if stats.header_path_mismatch.is_none() { + stats.header_path_mismatch = mismatch_note; + } apply_pending_writes(&pending)?; let result = PatchResult { success: true, - files_applied: stats.files_applied, - files_total: stats.files_total, - hunks_applied: stats.hunks_applied, - hunks_total: stats.hunks_total, - fuzz_used: stats.fuzz_used, - message: format!( - "Applied {}/{} hunks across {} file(s) (fuzz: {})", - stats.hunks_applied, stats.hunks_total, stats.files_applied, stats.fuzz_used - ), + files_applied: stats.stats.files_applied, + files_total: stats.stats.files_total, + hunks_applied: stats.stats.hunks_applied, + hunks_total: stats.stats.hunks_total, + fuzz_used: stats.stats.fuzz_used, + hunks_with_fuzz: stats.stats.hunks_with_fuzz, + touched_files: stats.touched_files.clone(), + file_summaries: stats.file_summaries.clone(), + message: build_summary_message(&stats), }; ToolResult::json(&result).map_err(|e| ToolError::execution_failed(e.to_string())) @@ -273,6 +330,7 @@ fn parse_unified_diff_files( let new_path = Some(stripped.trim().to_string()); let (path, delete_after, create_flag) = resolve_diff_paths(old_path.as_deref(), new_path.as_deref(), create_if_missing)?; + old_path = None; if let Some(file) = current.take() { files.push(file); } @@ -287,8 +345,13 @@ fn parse_unified_diff_files( if line.starts_with("@@") { let Some(file) = current.as_mut() else { + if let Some(path) = old_path.as_deref() { + return Err(ToolError::invalid_input(format!( + "Patch hunk encountered after `--- {path}` but before a matching `+++` header. Each file section must include both headers." 
+ ))); + } return Err(ToolError::invalid_input( - "Patch hunk encountered before file header", + "Patch hunk encountered before any file header. Add `---`/`+++` headers or provide `path`.", )); }; let hunk = parse_hunk_header(line, &mut lines)?; @@ -345,7 +408,7 @@ where let parts: Vec<&str> = header.split_whitespace().collect(); if parts.len() < 3 { return Err(ToolError::invalid_input(format!( - "Invalid hunk header: {header}" + "Invalid hunk header: {header}. Expected `@@ -start,count +start,count @@`." ))); } @@ -408,31 +471,155 @@ where /// Parse a range like "10,5" or "10" into (start, count) fn parse_range(range: &str) -> Result<(usize, usize), ToolError> { let parts: Vec<&str> = range.split(',').collect(); - let start = parts[0] - .parse::() - .map_err(|_| ToolError::invalid_input(format!("Invalid line number: {}", parts[0])))?; + let start = parts[0].parse::().map_err(|_| { + ToolError::invalid_input(format!( + "Invalid line number `{}` in hunk header. Use positive integers like `12` or `12,3`.", + parts[0] + )) + })?; let count = if parts.len() > 1 { - parts[1] - .parse::() - .map_err(|_| ToolError::invalid_input(format!("Invalid count: {}", parts[1])))? + parts[1].parse::().map_err(|_| { + ToolError::invalid_input(format!( + "Invalid line count `{}` in hunk header. Use positive integers like `3`.", + parts[1] + )) + })? 
} else { 1 }; Ok((start, count)) } +fn inspect_patch_shape(patch: &str) -> PatchShape { + let mut shape = PatchShape::default(); + let mut seen = HashSet::new(); + let mut old_path: Option = None; + + for line in patch.lines() { + if line.starts_with("@@") { + shape.has_hunks = true; + } + + if let Some(stripped) = line.strip_prefix("--- ") { + old_path = normalize_diff_path(stripped); + continue; + } + + if let Some(stripped) = line.strip_prefix("+++ ") { + let new_path = normalize_diff_path(stripped); + let resolved = new_path.or(old_path.clone()); + if let Some(path) = resolved { + if seen.insert(path.clone()) { + shape.header_files.push(path); + } + } + old_path = None; + } + } + + shape +} + +fn validate_patch_shape(shape: &PatchShape, path_override: Option<&str>) -> Result<(), ToolError> { + if !shape.has_hunks { + return Err(ToolError::invalid_input( + "Patch must include at least one hunk header (`@@ -start,count +start,count @@`).", + )); + } + + match path_override { + Some(_) if shape.file_count() > 1 => Err(ToolError::invalid_input(format!( + "Patch references multiple files ({}) but `path` was provided. Remove `path` to apply a multi-file patch, or provide a single-file patch.", + format_file_list(&shape.header_files), + ))), + None if shape.file_count() == 0 => Err(ToolError::invalid_input( + "Patch contains hunks but no file headers (`---`/`+++`). Provide `path` or add headers.", + )), + _ => Ok(()), + } +} + +fn diff_header_mismatch(path_override: &str, shape: &PatchShape) -> Option { + if shape.file_count() != 1 { + return None; + } + let header_path = &shape.header_files[0]; + let override_norm = normalize_diff_path(path_override).unwrap_or_else(|| path_override.into()); + if &override_norm == header_path { + None + } else { + Some(format!( + "Note: patch headers reference `{header_path}` but `path` overrides to `{override_norm}`." 
+ )) + } +} + +fn build_summary_message(stats: &PatchStatsExt) -> String { + let mut parts = Vec::new(); + if stats.stats.hunks_total > 0 { + parts.push(format!( + "Applied {}/{} hunks across {} file(s).", + stats.stats.hunks_applied, stats.stats.hunks_total, stats.stats.files_applied + )); + } else { + parts.push(format!( + "Applied {} file change(s).", + stats.stats.files_applied + )); + } + + if !stats.touched_files.is_empty() { + parts.push(format!( + "Files: {}.", + format_file_list(&stats.touched_files) + )); + } + + if stats.stats.fuzz_used > 0 { + parts.push(format!( + "Fuzz used on {} hunk(s) (total fuzz: {}).", + stats.stats.hunks_with_fuzz, stats.stats.fuzz_used + )); + } + + if let Some(note) = stats.header_path_mismatch.as_deref() { + parts.push(note.to_string()); + } + + parts.join(" ") +} + +fn format_file_list(files: &[String]) -> String { + if files.is_empty() { + return "".to_string(); + } + let mut shown: Vec = files.iter().take(FILE_LIST_LIMIT).cloned().collect(); + let remaining = files.len().saturating_sub(shown.len()); + if remaining > 0 { + shown.push(format!("... 
(+{remaining} more)")); + } + shown.join(", ") +} + +fn push_unique(target: &mut Vec, value: String) { + if !target.iter().any(|existing| existing == &value) { + target.push(value); + } +} + fn build_pending_writes_from_changes( changes_value: &Value, context: &ToolContext, -) -> Result, ToolError> { - let changes = changes_value - .as_array() - .ok_or_else(|| ToolError::invalid_input("changes must be an array of {path, content}"))?; +) -> Result<(Vec, PatchStatsExt), ToolError> { + let changes = changes_value.as_array().ok_or_else(|| { + ToolError::invalid_input("`changes` must be an array of objects like {path, content}") + })?; if changes.is_empty() { - return Err(ToolError::invalid_input("changes cannot be empty")); + return Err(ToolError::invalid_input("`changes` cannot be empty")); } let mut pending = Vec::new(); + let mut stats = PatchStatsExt::default(); for change in changes { let path = change .get("path") @@ -449,30 +636,44 @@ fn build_pending_writes_from_changes( } else { None }; + let created = original.is_none(); pending.push(PendingWrite { path: resolved, content: Some(content.to_string()), original, }); + + stats.stats.files_total += 1; + stats.stats.files_applied += 1; + push_unique(&mut stats.touched_files, path.to_string()); + stats.file_summaries.push(FileSummary { + path: path.to_string(), + hunks: 0, + hunks_applied: 0, + fuzz_used: 0, + hunks_with_fuzz: 0, + created, + deleted: false, + }); } - Ok(pending) + Ok((pending, stats)) } fn build_pending_writes_from_patches( file_patches: Vec, context: &ToolContext, fuzz: usize, -) -> Result<(Vec, PatchStats), ToolError> { +) -> Result<(Vec, PatchStatsExt), ToolError> { let mut pending = Vec::new(); - let mut stats = PatchStats::default(); - stats.files_total = file_patches.len(); + let mut stats = PatchStatsExt::default(); + stats.stats.files_total = file_patches.len(); for file_patch in file_patches { if file_patch.hunks.is_empty() { return Err(ToolError::invalid_input(format!( - "Patch for {} has 
no hunks", + "Patch section for `{}` has no hunks (`@@ ... @@`).", file_patch.path ))); } @@ -486,15 +687,17 @@ fn build_pending_writes_from_patches( if original.is_none() && !file_patch.create_if_missing { return Err(ToolError::execution_failed(format!( - "File {} does not exist. Set create_if_missing=true for new files.", - resolved.display() + "File `{}` does not exist at `{}`. Set create_if_missing=true for new files or include headers for file creation.", + file_patch.path, + resolved.display(), ))); } if file_patch.delete_after && original.is_none() { return Err(ToolError::execution_failed(format!( - "File {} does not exist to delete.", - resolved.display() + "File `{}` does not exist at `{}` to delete.", + file_patch.path, + resolved.display(), ))); } @@ -505,11 +708,23 @@ fn build_pending_writes_from_patches( base_content.lines().map(String::from).collect() }; - let (applied, fuzz_used) = apply_hunks_to_lines(&mut lines, &file_patch.hunks, fuzz)?; - stats.hunks_applied += applied; - stats.hunks_total += file_patch.hunks.len(); - stats.fuzz_used += fuzz_used; - stats.files_applied += 1; + let apply_stats = + apply_hunks_to_lines(&mut lines, &file_patch.hunks, fuzz, &file_patch.path)?; + stats.stats.hunks_applied += apply_stats.hunks_applied; + stats.stats.hunks_total += file_patch.hunks.len(); + stats.stats.fuzz_used += apply_stats.fuzz_used; + stats.stats.hunks_with_fuzz += apply_stats.hunks_with_fuzz; + stats.stats.files_applied += 1; + push_unique(&mut stats.touched_files, file_patch.path.clone()); + stats.file_summaries.push(FileSummary { + path: file_patch.path.clone(), + hunks: file_patch.hunks.len(), + hunks_applied: apply_stats.hunks_applied, + fuzz_used: apply_stats.fuzz_used, + hunks_with_fuzz: apply_stats.hunks_with_fuzz, + created: original.is_none() && !file_patch.delete_after, + deleted: file_patch.delete_after, + }); if file_patch.delete_after { pending.push(PendingWrite { @@ -593,31 +808,98 @@ fn read_file_content(path: &PathBuf) -> Result { 
}) } +fn preview_expected_lines(hunk: &Hunk, limit: usize) -> Vec { + let mut preview = Vec::new(); + for line in hunk.lines.iter().filter_map(|line| match line { + HunkLine::Context(s) => Some((" ", s)), + HunkLine::Remove(s) => Some(("-", s)), + HunkLine::Add(_) => None, + }) { + if preview.len() >= limit { + break; + } + preview.push(format!(" {}{}", line.0, line.1)); + } + if preview.is_empty() { + preview.push(" ".to_string()); + } + preview +} + +fn snippet_around(lines: &[String], line_1_based: usize, radius: usize) -> Vec { + if lines.is_empty() { + return vec![" ".to_string()]; + } + + let center = line_1_based + .saturating_sub(1) + .min(lines.len().saturating_sub(1)); + let start = center.saturating_sub(radius); + let end = (center + radius).min(lines.len().saturating_sub(1)); + + lines[start..=end] + .iter() + .enumerate() + .map(|(idx, line)| { + let line_no = start + idx + 1; + format!(" {line_no:>4}: {line}") + }) + .collect() +} + +fn format_hunk_no_match_error( + lines: &[String], + hunk: &Hunk, + err: &ApplyHunkError, + max_fuzz: usize, +) -> String { + match err { + ApplyHunkError::NoMatch { + expected_line, + adjusted_line, + offset, + } => { + let expected_preview = preview_expected_lines(hunk, HUNK_PREVIEW_LINES).join("\n"); + let file_preview = snippet_around(lines, *adjusted_line, SNIPPET_RADIUS).join("\n"); + format!( + "could not find matching context near line {expected_line} (searched around line {adjusted_line} with offset {offset:+} and fuzz up to {max_fuzz}). Expected context preview:\n{expected_preview}\nFile snippet near line {adjusted_line}:\n{file_preview}\nHints: ensure the patch matches the current file contents, increase `fuzz`, or regenerate the patch." 
+ ) + } + } +} + fn apply_hunks_to_lines( lines: &mut Vec, hunks: &[Hunk], fuzz: usize, -) -> Result<(usize, usize), ToolError> { - let mut total_fuzz = 0; - let mut hunks_applied = 0; + file_label: &str, +) -> Result { + let mut stats = HunkApplyStats::default(); let mut cumulative_offset: isize = 0; - for hunk in hunks { + for (idx, hunk) in hunks.iter().enumerate() { match apply_hunk(lines, hunk, fuzz, &mut cumulative_offset) { Ok(fuzz_used) => { - total_fuzz += fuzz_used; - hunks_applied += 1; + stats.fuzz_used += fuzz_used; + stats.hunks_applied += 1; + if fuzz_used > 0 { + stats.hunks_with_fuzz += 1; + } } Err(e) => { + let detail = format_hunk_no_match_error(lines, hunk, &e, fuzz); return Err(ToolError::execution_failed(format!( - "Failed to apply hunk at line {}: {}", - hunk.old_start, e + "Failed to apply hunk {}/{} for `{}`: {}", + idx + 1, + hunks.len(), + file_label, + detail ))); } } } - Ok((hunks_applied, total_fuzz)) + Ok(stats) } /// Apply a hunk to the file content with fuzzy matching @@ -721,6 +1003,10 @@ mod tests { use super::*; use tempfile::tempdir; + fn parse_patch_result(result: ToolResult) -> PatchResult { + serde_json::from_str(&result.content).expect("patch result json") + } + #[test] fn test_parse_range() { assert_eq!(parse_range("10,5").unwrap(), (10, 5)); @@ -851,6 +1137,9 @@ mod tests { .expect("execute"); assert!(result.success); + let patch_result = parse_patch_result(result); + assert_eq!(patch_result.touched_files, vec!["test.txt"]); + assert_eq!(patch_result.hunks_applied, 1); // Verify the patch was applied let content = fs::read_to_string(tmp.path().join("test.txt")).expect("read"); @@ -878,6 +1167,8 @@ mod tests { .expect("execute"); assert!(result.success); + let patch_result = parse_patch_result(result); + assert_eq!(patch_result.touched_files, vec!["test.txt"]); let content = fs::read_to_string(tmp.path().join("test.txt")).expect("read"); assert!(content.contains("line2")); @@ -904,6 +1195,9 @@ mod tests { 
.expect("execute"); assert!(result.success); + let patch_result = parse_patch_result(result); + assert_eq!(patch_result.touched_files, vec!["new_file.txt"]); + assert!(patch_result.file_summaries.first().unwrap().created); assert!(tmp.path().join("new_file.txt").exists()); } @@ -929,6 +1223,11 @@ mod tests { .expect("execute"); assert!(result.success); + let patch_result = parse_patch_result(result); + let mut touched = patch_result.touched_files.clone(); + touched.sort(); + assert_eq!(touched, vec!["one.txt", "two.txt"]); + assert_eq!(patch_result.hunks_total, 0); assert_eq!( fs::read_to_string(tmp.path().join("one.txt")).unwrap(), "new\n" @@ -970,12 +1269,134 @@ diff --git a/b.txt b/b.txt .expect("execute"); assert!(result.success); + let patch_result = parse_patch_result(result); + let mut touched = patch_result.touched_files.clone(); + touched.sort(); + assert_eq!(touched, vec!["a.txt", "b.txt"]); + assert_eq!(patch_result.files_applied, 2); let a = fs::read_to_string(tmp.path().join("a.txt")).unwrap(); let b = fs::read_to_string(tmp.path().join("b.txt")).unwrap(); assert!(a.contains("line2-mod")); assert!(b.contains("beta2")); } + #[tokio::test] + async fn test_apply_patch_requires_headers_without_path() { + let tmp = tempdir().expect("tempdir"); + let ctx = ToolContext::new(tmp.path().to_path_buf()); + let tool = ApplyPatchTool; + + let patch = r"@@ -1,1 +1,1 @@ +-old ++new +"; + + let err = tool + .execute(json!({"patch": patch}), &ctx) + .await + .unwrap_err(); + match err { + ToolError::InvalidInput { message } => { + assert!(message.contains("no file headers")); + assert!(message.contains("Provide `path`")); + } + other => panic!("expected invalid input, got: {other}"), + } + } + + #[tokio::test] + async fn test_path_override_rejects_multi_file_diff() { + let tmp = tempdir().expect("tempdir"); + let ctx = ToolContext::new(tmp.path().to_path_buf()); + let tool = ApplyPatchTool; + + let patch = r"diff --git a/a.txt b/a.txt +--- a/a.txt ++++ b/a.txt +@@ -1,1 
+1,1 @@ +-one ++one-mod +diff --git a/b.txt b/b.txt +--- a/b.txt ++++ b/b.txt +@@ -1,1 +1,1 @@ +-two ++two-mod +"; + + let err = tool + .execute(json!({"path": "a.txt", "patch": patch}), &ctx) + .await + .unwrap_err(); + match err { + ToolError::InvalidInput { message } => { + assert!(message.contains("multiple files")); + assert!(message.contains("a.txt")); + assert!(message.contains("b.txt")); + } + other => panic!("expected invalid input, got: {other}"), + } + } + + #[tokio::test] + async fn test_apply_patch_summary_reports_fuzz() { + let tmp = tempdir().expect("tempdir"); + let ctx = ToolContext::new(tmp.path().to_path_buf()); + let tool = ApplyPatchTool; + + fs::write(tmp.path().join("test.txt"), "line0\nline1\nline2\nline3\n").expect("write"); + + let patch = r"@@ -1,2 +1,2 @@ +-line1 ++modified + line2 +"; + + let result = tool + .execute(json!({"path": "test.txt", "patch": patch, "fuzz": 3}), &ctx) + .await + .expect("execute"); + assert!(result.success); + let patch_result = parse_patch_result(result); + assert_eq!(patch_result.hunks_with_fuzz, 1); + assert!(patch_result.fuzz_used > 0); + assert!(patch_result.message.contains("Fuzz used")); + let summary = patch_result.file_summaries.first().unwrap(); + assert_eq!(summary.hunks_with_fuzz, 1); + } + + #[tokio::test] + async fn test_path_override_header_mismatch_note() { + let tmp = tempdir().expect("tempdir"); + let ctx = ToolContext::new(tmp.path().to_path_buf()); + let tool = ApplyPatchTool; + + fs::write(tmp.path().join("override.txt"), "old\n").expect("write"); + + let patch = r"--- a/other.txt ++++ b/other.txt +@@ -1,1 +1,1 @@ +-old ++new +"; + + let result = tool + .execute(json!({"path": "override.txt", "patch": patch}), &ctx) + .await + .expect("execute"); + let patch_result = parse_patch_result(result); + assert!( + patch_result + .message + .contains("headers reference `other.txt`") + ); + assert!( + patch_result + .message + .contains("path` overrides to `override.txt`") + ); + } + #[test] fn 
test_apply_patch_tool_properties() { let tool = ApplyPatchTool; diff --git a/src/tools/diagnostics.rs b/src/tools/diagnostics.rs new file mode 100644 index 00000000..b2509d33 --- /dev/null +++ b/src/tools/diagnostics.rs @@ -0,0 +1,240 @@ +//! Workspace diagnostics tool: `diagnostics`. +//! +//! This tool gathers lightweight, best-effort environment information without +//! failing hard when optional commands are unavailable. + +use std::env; +use std::path::Path; +use std::process::Command; + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use serde_json::{Value, json}; + +use super::spec::{ + ApprovalRequirement, ToolCapability, ToolContext, ToolError, ToolResult, ToolSpec, +}; + +/// Tool for collecting workspace and toolchain diagnostics. +pub struct DiagnosticsTool; + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct DiagnosticsOutput { + workspace_root: String, + current_dir: Option, + current_dir_error: Option, + git_repo: bool, + git_branch: Option, + git_error: Option, + sandbox_available: bool, + sandbox_type: Option, + rustc_version: Option, + cargo_version: Option, +} + +#[derive(Debug, Clone, Default)] +struct GitProbe { + detected: bool, + branch: Option, + error: Option, +} + +#[async_trait] +impl ToolSpec for DiagnosticsTool { + fn name(&self) -> &'static str { + "diagnostics" + } + + fn description(&self) -> &'static str { + "Report workspace info, git detection, sandbox availability, and Rust toolchain versions." 
+ } + + fn input_schema(&self) -> Value { + json!({ + "type": "object", + "properties": {}, + "additionalProperties": false + }) + } + + fn capabilities(&self) -> Vec { + vec![ToolCapability::ReadOnly] + } + + fn approval_requirement(&self) -> ApprovalRequirement { + ApprovalRequirement::Auto + } + + fn supports_parallel(&self) -> bool { + true + } + + async fn execute(&self, _input: Value, context: &ToolContext) -> Result { + let workspace_root = context.workspace.display().to_string(); + + let (current_dir, current_dir_error) = match env::current_dir() { + Ok(dir) => (Some(dir.display().to_string()), None), + Err(err) => (None, Some(err.to_string())), + }; + + let git = probe_git(&context.workspace); + let sandbox_type = crate::sandbox::get_platform_sandbox().map(|s| s.to_string()); + let sandbox_available = sandbox_type.is_some(); + + let diagnostics = DiagnosticsOutput { + workspace_root, + current_dir, + current_dir_error, + git_repo: git.detected, + git_branch: git.branch, + git_error: git.error, + sandbox_available, + sandbox_type, + rustc_version: probe_version("rustc", &["--version"], &context.workspace), + cargo_version: probe_version("cargo", &["--version"], &context.workspace), + }; + + ToolResult::json(&diagnostics).map_err(|e| ToolError::execution_failed(e.to_string())) + } +} + +// === Helpers === + +fn probe_git(workspace: &Path) -> GitProbe { + let rev_parse = run_command("git", &["rev-parse", "--is-inside-work-tree"], workspace); + match rev_parse { + CommandProbe::Success(out) => { + if out.trim() != "true" { + return GitProbe { + detected: false, + branch: None, + error: Some(format!("unexpected git rev-parse output: {out}")), + }; + } + let branch = run_command("git", &["rev-parse", "--abbrev-ref", "HEAD"], workspace) + .into_success(); + GitProbe { + detected: true, + branch, + error: None, + } + } + CommandProbe::Failed { stderr, .. 
} => GitProbe { + detected: false, + branch: None, + error: stderr, + }, + CommandProbe::Missing => GitProbe { + detected: false, + branch: None, + error: Some("git is not installed or not in PATH".to_string()), + }, + } +} + +fn probe_version(program: &str, args: &[&str], cwd: &Path) -> Option { + run_command(program, args, cwd).into_success() +} + +enum CommandProbe { + Success(String), + Failed { stderr: Option }, + Missing, +} + +impl CommandProbe { + fn into_success(self) -> Option { + match self { + CommandProbe::Success(out) => Some(out), + CommandProbe::Failed { .. } | CommandProbe::Missing => None, + } + } +} + +fn run_command(program: &str, args: &[&str], cwd: &Path) -> CommandProbe { + let output = Command::new(program).args(args).current_dir(cwd).output(); + let output = match output { + Ok(output) => output, + Err(err) if err.kind() == std::io::ErrorKind::NotFound => return CommandProbe::Missing, + Err(_) => return CommandProbe::Failed { stderr: None }, + }; + + if output.status.success() { + CommandProbe::Success(String::from_utf8_lossy(&output.stdout).trim().to_string()) + } else { + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + CommandProbe::Failed { + stderr: if stderr.is_empty() { + None + } else { + Some(stderr) + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use std::path::Path; + use std::process::Command; + use tempfile::tempdir; + + fn git_available() -> bool { + Command::new("git") + .arg("--version") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) + } + + fn init_git_repo(root: &Path) { + let run = |args: &[&str]| { + let status = Command::new("git") + .args(args) + .current_dir(root) + .status() + .expect("git should spawn"); + assert!(status.success(), "git {:?} failed", args); + }; + run(&["init", "-q"]); + run(&["config", "user.email", "test@example.com"]); + run(&["config", "user.name", "Test User"]); + fs::write(root.join("README.md"), "init\n").expect("write"); 
+ run(&["add", "."]); + run(&["commit", "-q", "-m", "init"]); + } + + #[tokio::test] + async fn diagnostics_runs_best_effort_outside_git_repo() { + let tmp = tempdir().expect("tempdir"); + let ctx = ToolContext::new(tmp.path()); + let tool = DiagnosticsTool; + let result = tool.execute(json!({}), &ctx).await.expect("execute"); + assert!(result.success); + + let parsed: DiagnosticsOutput = + serde_json::from_str(&result.content).expect("tool result should be json"); + assert_eq!(parsed.workspace_root, tmp.path().display().to_string()); + } + + #[tokio::test] + async fn diagnostics_detects_git_repo_when_available() { + if !git_available() { + return; + } + let tmp = tempdir().expect("tempdir"); + init_git_repo(tmp.path()); + + let ctx = ToolContext::new(tmp.path()); + let tool = DiagnosticsTool; + let result = tool.execute(json!({}), &ctx).await.expect("execute"); + assert!(result.success); + + let parsed: DiagnosticsOutput = + serde_json::from_str(&result.content).expect("tool result should be json"); + assert!(parsed.git_repo); + assert!(!parsed.git_branch.as_deref().unwrap_or("").is_empty()); + } +} diff --git a/src/tools/git.rs b/src/tools/git.rs new file mode 100644 index 00000000..994b8bb5 --- /dev/null +++ b/src/tools/git.rs @@ -0,0 +1,432 @@ +//! Git power tools: `git_status` and `git_diff`. +//! +//! These tools are read-only wrappers around common git inspection commands, +//! scoped to the workspace and optionally to a sub-path within it. + +use std::fs; +use std::path::{Path, PathBuf}; +use std::process::Command; + +use async_trait::async_trait; +use serde_json::{Value, json}; + +use super::spec::{ + ApprovalRequirement, ToolCapability, ToolContext, ToolError, ToolResult, ToolSpec, + optional_bool, optional_str, optional_u64, +}; + +const MAX_OUTPUT_CHARS: usize = 40_000; +const DEFAULT_UNIFIED: u64 = 3; +const MAX_UNIFIED: u64 = 50; + +// === GitStatusTool === + +/// Tool for reading the concise git status of the workspace. 
+pub struct GitStatusTool; + +#[async_trait] +impl ToolSpec for GitStatusTool { + fn name(&self) -> &'static str { + "git_status" + } + + fn description(&self) -> &'static str { + "Run `git status --porcelain=v1 -b` in the workspace (optionally scoped to a path)." + } + + fn input_schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Optional subdirectory or file to scope the status to (must be within the workspace)." + } + }, + "additionalProperties": false + }) + } + + fn capabilities(&self) -> Vec { + vec![ToolCapability::ReadOnly, ToolCapability::Sandboxable] + } + + fn approval_requirement(&self) -> ApprovalRequirement { + ApprovalRequirement::Auto + } + + fn supports_parallel(&self) -> bool { + true + } + + async fn execute(&self, input: Value, context: &ToolContext) -> Result { + let git_ctx = resolve_git_context(context, optional_str(&input, "path"))?; + + let mut args = vec![ + "status".to_string(), + "--porcelain=v1".to_string(), + "-b".to_string(), + ]; + if let Some(pathspec) = &git_ctx.pathspec { + args.push("--".to_string()); + args.push(pathspec.display().to_string()); + } + + let command_str = format_command(&git_ctx.working_dir, &args); + let output = run_git_command(&git_ctx.working_dir, &args)?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + let message = format!("git status failed: {}", stderr.trim()); + return Ok(ToolResult::error(message).with_metadata(json!({ + "command": command_str, + "exit_code": output.status.code(), + "stderr": stderr.trim(), + }))); + } + + let stdout = String::from_utf8_lossy(&output.stdout); + let (content, truncated, omitted_chars) = truncate_with_note(&stdout, MAX_OUTPUT_CHARS); + + Ok(ToolResult::success(content).with_metadata(json!({ + "command": command_str, + "working_dir": git_ctx.working_dir, + "pathspec": git_ctx.pathspec, + "truncated": truncated, + "omitted_chars": omitted_chars, + }))) + } +} + 
+// === GitDiffTool === + +/// Tool for reading git diffs in the workspace. +pub struct GitDiffTool; + +#[async_trait] +impl ToolSpec for GitDiffTool { + fn name(&self) -> &'static str { + "git_diff" + } + + fn description(&self) -> &'static str { + "Run `git diff` in the workspace with sensible defaults and safe truncation." + } + + fn input_schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Optional subdirectory or file to scope the diff to (must be within the workspace)." + }, + "cached": { + "type": "boolean", + "description": "When true, diff staged changes (`--cached`)." + }, + "unified": { + "type": "integer", + "minimum": 0, + "maximum": MAX_UNIFIED, + "default": DEFAULT_UNIFIED, + "description": "Number of context lines to include around changes." + } + }, + "additionalProperties": false + }) + } + + fn capabilities(&self) -> Vec { + vec![ToolCapability::ReadOnly, ToolCapability::Sandboxable] + } + + fn approval_requirement(&self) -> ApprovalRequirement { + ApprovalRequirement::Auto + } + + fn supports_parallel(&self) -> bool { + true + } + + async fn execute(&self, input: Value, context: &ToolContext) -> Result { + let git_ctx = resolve_git_context(context, optional_str(&input, "path"))?; + let cached = optional_bool(&input, "cached", false); + let unified = optional_u64(&input, "unified", DEFAULT_UNIFIED).min(MAX_UNIFIED); + + let mut args = vec![ + "diff".to_string(), + "--no-color".to_string(), + "--no-ext-diff".to_string(), + format!("--unified={unified}"), + ]; + if cached { + args.push("--cached".to_string()); + } + if let Some(pathspec) = &git_ctx.pathspec { + args.push("--".to_string()); + args.push(pathspec.display().to_string()); + } + + let command_str = format_command(&git_ctx.working_dir, &args); + let output = run_git_command(&git_ctx.working_dir, &args)?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + let message = 
format!("git diff failed: {}", stderr.trim()); + return Ok(ToolResult::error(message).with_metadata(json!({ + "command": command_str, + "exit_code": output.status.code(), + "stderr": stderr.trim(), + }))); + } + + let stdout = String::from_utf8_lossy(&output.stdout); + let (content, truncated, omitted_chars) = truncate_with_note(&stdout, MAX_OUTPUT_CHARS); + + Ok(ToolResult::success(content).with_metadata(json!({ + "command": command_str, + "working_dir": git_ctx.working_dir, + "pathspec": git_ctx.pathspec, + "cached": cached, + "unified": unified, + "truncated": truncated, + "omitted_chars": omitted_chars, + }))) + } +} + +// === Helpers === + +struct GitContext { + working_dir: PathBuf, + pathspec: Option, +} + +fn resolve_git_context(context: &ToolContext, path: Option<&str>) -> Result { + let workspace = canonical_or_workspace(&context.workspace); + let mut working_dir = workspace.clone(); + let mut pathspec = None; + + if let Some(raw) = path { + let resolved = context.resolve_path(raw)?; + let metadata = fs::metadata(&resolved).map_err(|e| { + ToolError::invalid_input(format!( + "Path does not exist or is not accessible: {raw} ({e})" + )) + })?; + + if metadata.is_dir() { + working_dir = resolved; + pathspec = Some(PathBuf::from(".")); + } else { + // For file paths, run from the parent and scope to the file name. 
+ let parent = resolved.parent().ok_or_else(|| { + ToolError::invalid_input(format!("Path has no parent directory: {raw}")) + })?; + working_dir = parent.to_path_buf(); + pathspec = Some(pathspec_from(&working_dir, &resolved)); + } + } + + if !working_dir.exists() { + return Err(ToolError::invalid_input(format!( + "Working directory does not exist: {}", + working_dir.display() + ))); + } + + Ok(GitContext { + working_dir, + pathspec, + }) +} + +fn canonical_or_workspace(workspace: &Path) -> PathBuf { + workspace + .canonicalize() + .unwrap_or_else(|_| workspace.to_path_buf()) +} + +fn pathspec_from(working_dir: &Path, resolved: &Path) -> PathBuf { + match resolved.strip_prefix(working_dir) { + Ok(rel) if rel.as_os_str().is_empty() => PathBuf::from("."), + Ok(rel) => rel.to_path_buf(), + Err(_) => PathBuf::from("."), + } +} + +fn run_git_command(working_dir: &Path, args: &[String]) -> Result { + let mut cmd = Command::new("git"); + cmd.args(args).current_dir(working_dir); + cmd.output().map_err(|e| { + if e.kind() == std::io::ErrorKind::NotFound { + ToolError::not_available("git is not installed or not in PATH") + } else { + ToolError::execution_failed(format!("Failed to run git: {e}")) + } + }) +} + +fn format_command(working_dir: &Path, args: &[String]) -> String { + format!( + "git -C {} {}", + working_dir.display(), + args.iter() + .map(String::as_str) + .collect::>() + .join(" ") + ) +} + +fn truncate_with_note(text: &str, max_chars: usize) -> (String, bool, usize) { + if text.chars().count() <= max_chars { + return (text.to_string(), false, 0); + } + let end = char_boundary_index(text, max_chars); + let truncated = &text[..end]; + let omitted_chars = text + .chars() + .count() + .saturating_sub(truncated.chars().count()); + let note = format!( + "\n\n[output truncated to {max_chars} characters; {omitted_chars} characters omitted]" + ); + (format!("{truncated}{note}"), true, omitted_chars) +} + +fn char_boundary_index(text: &str, max_chars: usize) -> usize { + 
if max_chars == 0 { + return 0; + } + for (count, (idx, _)) in text.char_indices().enumerate() { + if count == max_chars { + return idx; + } + } + text.len() +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use std::process::Command; + use tempfile::tempdir; + + fn git_available() -> bool { + Command::new("git") + .arg("--version") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) + } + + fn init_git_repo(root: &Path) { + let run = |args: &[&str]| { + let status = Command::new("git") + .args(args) + .current_dir(root) + .status() + .expect("git should spawn"); + assert!(status.success(), "git {:?} failed", args); + }; + + run(&["init", "-q"]); + run(&["config", "user.email", "test@example.com"]); + run(&["config", "user.name", "Test User"]); + } + + fn commit_all(root: &Path, message: &str) { + let run = |args: &[&str]| { + let status = Command::new("git") + .args(args) + .current_dir(root) + .status() + .expect("git should spawn"); + assert!(status.success(), "git {:?} failed", args); + }; + run(&["add", "."]); + run(&["commit", "-q", "-m", message]); + } + + #[tokio::test] + async fn git_status_reports_branch_and_changes() { + if !git_available() { + return; + } + let tmp = tempdir().expect("tempdir"); + init_git_repo(tmp.path()); + + let file = tmp.path().join("file.txt"); + fs::write(&file, "hello\n").expect("write"); + commit_all(tmp.path(), "init"); + + fs::write(&file, "hello\nworld\n").expect("modify"); + + let ctx = ToolContext::new(tmp.path()); + let tool = GitStatusTool; + let result = tool.execute(json!({}), &ctx).await.expect("execute"); + assert!(result.success); + assert!(result.content.contains("##")); + assert!(result.content.contains("file.txt")); + } + + #[tokio::test] + async fn git_diff_supports_cached_and_path_scoping() { + if !git_available() { + return; + } + let tmp = tempdir().expect("tempdir"); + init_git_repo(tmp.path()); + + let subdir = tmp.path().join("src"); + fs::create_dir_all(&subdir).expect("mkdir"); 
+ let file = subdir.join("lib.rs"); + fs::write(&file, "pub fn one() -> i32 { 1 }\n").expect("write"); + commit_all(tmp.path(), "init"); + + fs::write(&file, "pub fn one() -> i32 { 2 }\n").expect("modify"); + + let ctx = ToolContext::new(tmp.path()); + let tool = GitDiffTool; + + let uncached = tool + .execute(json!({ "path": "src" }), &ctx) + .await + .expect("diff"); + assert!(uncached.success); + assert!(uncached.content.contains("diff --git")); + assert!(uncached.content.contains("lib.rs")); + + let _ = Command::new("git") + .args(["add", "src/lib.rs"]) + .current_dir(tmp.path()) + .status() + .expect("git add"); + + let cached = tool + .execute(json!({ "path": "src", "cached": true }), &ctx) + .await + .expect("diff cached"); + assert!(cached.success); + assert!(cached.content.contains("diff --git")); + assert!( + cached + .metadata + .as_ref() + .and_then(|m| m.get("cached")) + .and_then(Value::as_bool) + .unwrap_or(false) + ); + } + + #[test] + fn truncation_adds_note() { + let long = "a".repeat(MAX_OUTPUT_CHARS + 100); + let (truncated, did_truncate, omitted) = truncate_with_note(&long, MAX_OUTPUT_CHARS); + assert!(did_truncate); + assert!(omitted > 0); + assert!(truncated.contains("output truncated")); + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index d268b90b..cf447e53 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -5,9 +5,11 @@ // === Modules === pub mod apply_patch; +pub mod diagnostics; pub mod duo; pub mod file; pub mod file_search; +pub mod git; pub mod plan; pub mod registry; pub mod review; @@ -16,6 +18,8 @@ pub mod search; pub mod shell; pub mod spec; pub mod subagent; +pub mod swarm; +pub mod test_runner; pub mod todo; pub mod web_search; @@ -43,12 +47,21 @@ pub use review::{ReviewOutput, ReviewTool}; // Re-export file tools pub use file::{EditFileTool, ListDirTool, ReadFileTool, WriteFileTool}; +// Re-export diagnostics tool +pub use diagnostics::DiagnosticsTool; + +// Re-export git tools +pub use git::{GitDiffTool, 
GitStatusTool}; + // Re-export shell types pub use shell::ExecShellTool; // Re-export subagent types pub use subagent::SubAgent; +// Re-export test runner tool +pub use test_runner::RunTestsTool; + // Re-export todo types pub use todo::TodoWriteTool; diff --git a/src/tools/plan.rs b/src/tools/plan.rs index a1cbb00d..1667b785 100644 --- a/src/tools/plan.rs +++ b/src/tools/plan.rs @@ -1,7 +1,8 @@ //! Plan tool implementation with step tracking and validation -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::time::{Duration, Instant}; +use tokio::sync::Mutex; use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -388,10 +389,7 @@ impl ToolSpec for UpdatePlanTool { plan: plan_args, }; - let mut state = self - .plan_state - .lock() - .map_err(|e| ToolError::execution_failed(format!("Failed to lock plan state: {e}")))?; + let mut state = self.plan_state.lock().await; state.update(args); diff --git a/src/tools/registry.rs b/src/tools/registry.rs index b0dad967..833708e0 100644 --- a/src/tools/registry.rs +++ b/src/tools/registry.rs @@ -293,6 +293,28 @@ impl ToolRegistryBuilder { .with_tool(Arc::new(FileSearchTool)) } + /// Include git inspection tools (`git_status`, `git_diff`). + #[must_use] + pub fn with_git_tools(self) -> Self { + use super::git::{GitDiffTool, GitStatusTool}; + self.with_tool(Arc::new(GitStatusTool)) + .with_tool(Arc::new(GitDiffTool)) + } + + /// Include workspace diagnostics tool. + #[must_use] + pub fn with_diagnostics_tool(self) -> Self { + use super::diagnostics::DiagnosticsTool; + self.with_tool(Arc::new(DiagnosticsTool)) + } + + /// Include cargo test runner tool. + #[must_use] + pub fn with_test_runner_tool(self) -> Self { + use super::test_runner::RunTestsTool; + self.with_tool(Arc::new(RunTestsTool)) + } + /// Include web search tools. 
#[must_use] pub fn with_web_tools(self) -> Self { @@ -329,7 +351,10 @@ impl ToolRegistryBuilder { .with_note_tool() .with_search_tools() .with_web_tools() - .with_patch_tools(); + .with_patch_tools() + .with_git_tools() + .with_diagnostics_tool() + .with_test_runner_tool(); if allow_shell { builder.with_shell_tools() @@ -403,11 +428,16 @@ impl ToolRegistryBuilder { runtime: super::subagent::SubAgentRuntime, ) -> Self { use super::subagent::{AgentCancelTool, AgentListTool, AgentResultTool, AgentSpawnTool}; + use super::swarm::AgentSwarmTool; - self.with_tool(Arc::new(AgentSpawnTool::new(manager.clone(), runtime))) - .with_tool(Arc::new(AgentResultTool::new(manager.clone()))) - .with_tool(Arc::new(AgentCancelTool::new(manager.clone()))) - .with_tool(Arc::new(AgentListTool::new(manager))) + self.with_tool(Arc::new(AgentSpawnTool::new( + manager.clone(), + runtime.clone(), + ))) + .with_tool(Arc::new(AgentSwarmTool::new(manager.clone(), runtime))) + .with_tool(Arc::new(AgentResultTool::new(manager.clone()))) + .with_tool(Arc::new(AgentCancelTool::new(manager.clone()))) + .with_tool(Arc::new(AgentListTool::new(manager))) } /// Build the registry with the given context. diff --git a/src/tools/shell.rs b/src/tools/shell.rs index fa7db696..c528387c 100644 --- a/src/tools/shell.rs +++ b/src/tools/shell.rs @@ -29,6 +29,9 @@ use crate::sandbox::{ /// Maximum output size before truncation (30KB like Claude Code) const MAX_OUTPUT_SIZE: usize = 30_000; +/// Limits for summary strings in tool metadata. +const SUMMARY_MAX_LINES: usize = 3; +const SUMMARY_MAX_CHARS: usize = 240; /// Status of a shell process #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] @@ -49,6 +52,24 @@ pub struct ShellResult { pub stdout: String, pub stderr: String, pub duration_ms: u64, + /// Original stdout length in bytes. + #[serde(default)] + pub stdout_len: usize, + /// Original stderr length in bytes. 
+ #[serde(default)] + pub stderr_len: usize, + /// Bytes omitted from stdout due to truncation. + #[serde(default)] + pub stdout_omitted: usize, + /// Bytes omitted from stderr due to truncation. + #[serde(default)] + pub stderr_omitted: usize, + /// Whether stdout was truncated. + #[serde(default)] + pub stdout_truncated: bool, + /// Whether stderr was truncated. + #[serde(default)] + pub stderr_truncated: bool, /// Whether the command was executed in a sandbox. #[serde(default)] pub sandboxed: bool, @@ -134,13 +155,21 @@ impl BackgroundShell { /// Get a snapshot of the current state pub fn snapshot(&self) -> ShellResult { let sandboxed = !matches!(self.sandbox_type, SandboxType::None); + let (stdout, stdout_meta) = truncate_with_meta(&self.stdout); + let (stderr, stderr_meta) = truncate_with_meta(&self.stderr); ShellResult { task_id: Some(self.id.clone()), status: self.status.clone(), exit_code: self.exit_code, - stdout: truncate_output(&self.stdout), - stderr: truncate_output(&self.stderr), + stdout, + stderr, duration_ms: u64::try_from(self.started_at.elapsed().as_millis()).unwrap_or(u64::MAX), + stdout_len: stdout_meta.original_len, + stderr_len: stderr_meta.original_len, + stdout_omitted: stdout_meta.omitted, + stderr_omitted: stderr_meta.omitted, + stdout_truncated: stdout_meta.truncated, + stderr_truncated: stderr_meta.truncated, sandboxed, sandbox_type: if sandboxed { Some(self.sandbox_type.to_string()) @@ -319,11 +348,14 @@ impl ShellManager { if let Some(status) = child.wait_timeout(timeout)? 
{ let stdout = stdout_thread.join().unwrap_or_default(); let stderr = stderr_thread.join().unwrap_or_default(); - let stderr_str = String::from_utf8_lossy(&stderr); + let stdout_str = String::from_utf8_lossy(&stdout).to_string(); + let stderr_str = String::from_utf8_lossy(&stderr).to_string(); let exit_code = status.code().unwrap_or(-1); // Check if sandbox denied the operation let sandbox_denied = SandboxManager::was_denied(sandbox_type, exit_code, &stderr_str); + let (stdout, stdout_meta) = truncate_with_meta(&stdout_str); + let (stderr, stderr_meta) = truncate_with_meta(&stderr_str); Ok(ShellResult { task_id: None, @@ -333,9 +365,15 @@ impl ShellManager { ShellStatus::Failed }, exit_code: status.code(), - stdout: truncate_output(&String::from_utf8_lossy(&stdout)), - stderr: truncate_output(&stderr_str), + stdout, + stderr, duration_ms: u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX), + stdout_len: stdout_meta.original_len, + stderr_len: stderr_meta.original_len, + stdout_omitted: stdout_meta.omitted, + stderr_omitted: stderr_meta.omitted, + stdout_truncated: stdout_meta.truncated, + stderr_truncated: stderr_meta.truncated, sandboxed, sandbox_type: if sandboxed { Some(sandbox_type.to_string()) @@ -350,14 +388,24 @@ impl ShellManager { let status = child.wait().ok(); let stdout = stdout_thread.join().unwrap_or_default(); let stderr = stderr_thread.join().unwrap_or_default(); + let stdout_str = String::from_utf8_lossy(&stdout).to_string(); + let stderr_str = String::from_utf8_lossy(&stderr).to_string(); + let (stdout, stdout_meta) = truncate_with_meta(&stdout_str); + let (stderr, stderr_meta) = truncate_with_meta(&stderr_str); Ok(ShellResult { task_id: None, status: ShellStatus::TimedOut, exit_code: status.and_then(|s| s.code()), - stdout: truncate_output(&String::from_utf8_lossy(&stdout)), - stderr: truncate_output(&String::from_utf8_lossy(&stderr)), + stdout, + stderr, duration_ms: u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX), 
+ stdout_len: stdout_meta.original_len, + stderr_len: stderr_meta.original_len, + stdout_omitted: stdout_meta.omitted, + stderr_omitted: stderr_meta.omitted, + stdout_truncated: stdout_meta.truncated, + stderr_truncated: stderr_meta.truncated, sandboxed, sandbox_type: if sandboxed { Some(sandbox_type.to_string()) @@ -411,6 +459,12 @@ impl ShellManager { stdout: String::new(), stderr: String::new(), duration_ms: u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX), + stdout_len: 0, + stderr_len: 0, + stdout_omitted: 0, + stderr_omitted: 0, + stdout_truncated: false, + stderr_truncated: false, sandboxed, sandbox_type: if sandboxed { Some(sandbox_type.to_string()) @@ -430,6 +484,12 @@ impl ShellManager { stdout: String::new(), stderr: String::new(), duration_ms: u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX), + stdout_len: 0, + stderr_len: 0, + stdout_omitted: 0, + stderr_omitted: 0, + stdout_truncated: false, + stderr_truncated: false, sandboxed, sandbox_type: if sandboxed { Some(sandbox_type.to_string()) @@ -518,6 +578,12 @@ impl ShellManager { stdout: String::new(), stderr: String::new(), duration_ms: 0, + stdout_len: 0, + stderr_len: 0, + stdout_omitted: 0, + stderr_omitted: 0, + stdout_truncated: false, + stderr_truncated: false, sandboxed, sandbox_type: if sandboxed { Some(sandbox_type.to_string()) @@ -599,19 +665,100 @@ impl ShellManager { } } +#[derive(Debug, Clone, Copy, Default)] +struct TruncationMeta { + original_len: usize, + omitted: usize, + truncated: bool, +} + +fn truncate_with_meta(output: &str) -> (String, TruncationMeta) { + let original_len = output.len(); + if original_len <= MAX_OUTPUT_SIZE { + return ( + output.to_string(), + TruncationMeta { + original_len, + omitted: 0, + truncated: false, + }, + ); + } + + let cut_index = char_boundary_at_or_before(output, MAX_OUTPUT_SIZE); + let truncated = &output[..cut_index]; + let omitted = original_len.saturating_sub(cut_index); + let note = + format!("...\n\n[Output 
truncated at {MAX_OUTPUT_SIZE} bytes. {omitted} bytes omitted.]"); + + ( + format!("{truncated}{note}"), + TruncationMeta { + original_len, + omitted, + truncated: true, + }, + ) +} + +fn char_boundary_at_or_before(text: &str, max_bytes: usize) -> usize { + if max_bytes >= text.len() { + return text.len(); + } + + let mut last_end = 0usize; + for (idx, ch) in text.char_indices() { + let end = idx.saturating_add(ch.len_utf8()); + if end > max_bytes { + break; + } + last_end = end; + } + + last_end.min(text.len()) +} + +fn strip_truncation_note(text: &str) -> &str { + text.split_once("\n\n[Output truncated at") + .map_or(text, |(prefix, _)| prefix) +} + +fn truncate_chars(text: &str, max_chars: usize) -> String { + if text.chars().count() <= max_chars { + return text.to_string(); + } + + let mut end = text.len(); + for (count, (idx, _)) in text.char_indices().enumerate() { + if count == max_chars { + end = idx; + break; + } + } + + format!("{}...", &text[..end]) +} + +fn summarize_output(text: &str) -> String { + let stripped = strip_truncation_note(text); + let summary = stripped + .lines() + .take(SUMMARY_MAX_LINES) + .collect::>() + .join("\n") + .trim() + .to_string(); + + if summary.is_empty() { + String::new() + } else { + truncate_chars(&summary, SUMMARY_MAX_CHARS) + } +} + /// Truncate output to `MAX_OUTPUT_SIZE` fn truncate_output(output: &str) -> String { - if output.len() <= MAX_OUTPUT_SIZE { - output.to_string() - } else { - let truncated = &output[..MAX_OUTPUT_SIZE]; - format!( - "{}...\n\n[Output truncated at {} characters. 
{} characters omitted.]", - truncated, - MAX_OUTPUT_SIZE, - output.len() - MAX_OUTPUT_SIZE - ) - } + truncate_with_meta(output).0 } /// Thread-safe wrapper for `ShellManager` @@ -777,6 +924,13 @@ impl ToolSpec for ExecShellTool { match result { Ok(result) => { let task_id_str = result.task_id.clone().unwrap_or_default(); + let stdout_summary = summarize_output(&result.stdout); + let stderr_summary = summarize_output(&result.stderr); + let summary = if !stderr_summary.is_empty() { + stderr_summary.clone() + } else { + stdout_summary.clone() + }; let output = if interactive { format!( "Interactive command completed (exit code: {:?})", @@ -808,7 +962,18 @@ impl ToolSpec for ExecShellTool { "status": format!("{:?}", result.status), "duration_ms": result.duration_ms, "sandboxed": result.sandboxed, + "sandbox_type": result.sandbox_type, + "sandbox_denied": result.sandbox_denied, "task_id": result.task_id, + "stdout_len": result.stdout_len, + "stderr_len": result.stderr_len, + "stdout_truncated": result.stdout_truncated, + "stderr_truncated": result.stderr_truncated, + "stdout_omitted": result.stdout_omitted, + "stderr_omitted": result.stderr_omitted, + "summary": summary, + "stdout_summary": stdout_summary, + "stderr_summary": stderr_summary, "safety_level": format!("{:?}", safety.level), "interactive": interactive, "execpolicy": execpolicy_decision.as_ref().map(|decision| match decision { @@ -900,6 +1065,8 @@ impl ToolSpec for NoteTool { #[cfg(test)] mod tests { use super::*; + use crate::tools::spec::ToolContext; + use serde_json::{Value, json}; use tempfile::tempdir; fn echo_command(message: &str) -> String { @@ -1013,4 +1180,46 @@ mod tests { assert!(truncated.len() < long_output.len()); assert!(truncated.contains("truncated")); } + + #[test] + fn test_truncate_with_meta_reports_omission_counts() { + let long_output = format!("line1\nline2\n{}", "x".repeat(60_000)); + let (truncated, meta) = truncate_with_meta(&long_output); + + assert!(meta.truncated); + 
assert!(meta.original_len >= long_output.len()); + assert!(meta.omitted > 0); + assert!(truncated.contains("bytes omitted")); + } + + #[test] + fn test_summarize_output_strips_truncation_note() { + let long_output = "x".repeat(60_000); + let truncated = truncate_output(&long_output); + let summary = summarize_output(&truncated); + assert!(!summary.contains("Output truncated at")); + } + + #[tokio::test] + async fn test_exec_shell_metadata_includes_summaries() { + let tmp = tempdir().expect("tempdir"); + let ctx = ToolContext::new(tmp.path()); + let tool = ExecShellTool; + + let result = tool + .execute(json!({"command": echo_command("hello")}), &ctx) + .await + .expect("execute"); + assert!(result.success); + + let meta = result.metadata.expect("metadata"); + let summary = meta + .get("summary") + .and_then(Value::as_str) + .unwrap_or_default() + .to_string(); + assert!(summary.contains("hello")); + assert!(meta.get("stdout_len").is_some()); + assert!(meta.get("stdout_truncated").is_some()); + } } diff --git a/src/tools/subagent.rs b/src/tools/subagent.rs index 59875994..ca2c0357 100644 --- a/src/tools/subagent.rs +++ b/src/tools/subagent.rs @@ -6,8 +6,9 @@ use std::collections::HashMap; use std::path::PathBuf; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::time::{Duration, Instant}; +use tokio::sync::Mutex; use anyhow::{Result, anyhow}; use async_trait::async_trait; @@ -17,6 +18,7 @@ use tokio::{sync::mpsc, task::JoinHandle}; use uuid::Uuid; use crate::client::DeepSeekClient; +use crate::config::MAX_SUBAGENTS; use crate::core::events::Event; use crate::llm_client::LlmClient; use crate::models::{ContentBlock, Message, MessageRequest, SystemPrompt, Tool}; @@ -87,18 +89,50 @@ impl SubAgentType { "read_file", "write_file", "edit_file", + "apply_patch", + "grep_files", + "file_search", + "web_search", "exec_shell", "note", "todo_write", + "todo_add", + "todo_update", + "todo_list", + "update_plan", ], - Self::Explore => vec!["list_dir", "read_file", 
"grep_files", "exec_shell"], - Self::Plan => vec!["list_dir", "read_file", "note", "update_plan", "todo_write"], - Self::Review => vec!["list_dir", "read_file", "grep_files", "note"], + Self::Explore => vec![ + "list_dir", + "read_file", + "grep_files", + "file_search", + "web_search", + "exec_shell", + ], + Self::Plan => vec![ + "list_dir", + "read_file", + "grep_files", + "file_search", + "note", + "update_plan", + "todo_write", + "todo_add", + "todo_update", + "todo_list", + ], + Self::Review => vec!["list_dir", "read_file", "grep_files", "file_search", "note"], Self::Custom => vec![], // Must be provided by caller. } } } +impl Default for SubAgentType { + fn default() -> Self { + Self::General + } +} + /// Status of a sub-agent execution. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub enum SubAgentStatus { @@ -215,13 +249,25 @@ impl SubAgentManager { } /// Count running agents. - fn running_count(&self) -> usize { + pub fn running_count(&self) -> usize { self.agents .values() .filter(|agent| agent.status == SubAgentStatus::Running) .count() } + /// Return the maximum number of allowed agents. + #[must_use] + pub fn max_agents(&self) -> usize { + self.max_agents + } + + /// Return remaining capacity for new agents. + #[must_use] + pub fn available_slots(&self) -> usize { + self.max_agents.saturating_sub(self.running_count()) + } + /// Spawn a new background sub-agent. pub fn spawn_background( &mut self, @@ -338,7 +384,7 @@ pub type SharedSubAgentManager = Arc>; /// Create a shared sub-agent manager with a configurable limit. 
#[must_use] pub fn new_shared_subagent_manager(workspace: PathBuf, max_agents: usize) -> SharedSubAgentManager { - let max_agents = max_agents.clamp(1, 5); + let max_agents = max_agents.clamp(1, MAX_SUBAGENTS); Arc::new(Mutex::new(SubAgentManager::new(workspace, max_agents))) } @@ -423,10 +469,7 @@ impl ToolSpec for AgentSpawnTool { .collect::>() }); - let mut manager = self - .manager - .lock() - .map_err(|_| ToolError::execution_failed("Failed to lock sub-agent manager"))?; + let mut manager = self.manager.lock().await; let result = manager .spawn_background( @@ -503,10 +546,7 @@ impl ToolSpec for AgentResultTool { let result = if block { wait_for_result(&self.manager, agent_id, Duration::from_millis(timeout_ms)).await? } else { - let manager = self - .manager - .lock() - .map_err(|_| ToolError::execution_failed("Failed to lock sub-agent manager"))?; + let manager = self.manager.lock().await; manager .get_result(agent_id) .map_err(|e| ToolError::execution_failed(e.to_string()))? @@ -570,10 +610,7 @@ impl ToolSpec for AgentCancelTool { async fn execute(&self, input: Value, _context: &ToolContext) -> Result { let agent_id = required_str(&input, "agent_id")?; - let mut manager = self - .manager - .lock() - .map_err(|_| ToolError::execution_failed("Failed to lock sub-agent manager"))?; + let mut manager = self.manager.lock().await; let result = manager .cancel(agent_id) .map_err(|e| ToolError::execution_failed(format!("Failed to cancel sub-agent: {e}")))?; @@ -621,10 +658,7 @@ impl ToolSpec for AgentListTool { _input: Value, _context: &ToolContext, ) -> Result { - let manager = self - .manager - .lock() - .map_err(|_| ToolError::execution_failed("Failed to lock sub-agent manager"))?; + let manager = self.manager.lock().await; let results = manager.list(); ToolResult::json(&results).map_err(|e| ToolError::execution_failed(e.to_string())) } @@ -656,11 +690,10 @@ async fn run_subagent_task(task: SubAgentTask) { ) .await; - if let Ok(mut manager) = 
task.manager_handle.lock() { - match &result { - Ok(res) => manager.update_from_result(&task.agent_id, res.clone()), - Err(err) => manager.update_failed(&task.agent_id, err.to_string()), - } + let mut manager = task.manager_handle.lock().await; + match &result { + Ok(res) => manager.update_from_result(&task.agent_id, res.clone()), + Err(err) => manager.update_failed(&task.agent_id, err.to_string()), } if let Some(event_tx) = task.runtime.event_tx { @@ -794,9 +827,7 @@ async fn wait_for_result( loop { let snapshot = { - let manager = manager - .lock() - .map_err(|_| ToolError::execution_failed("Failed to lock sub-agent manager"))?; + let manager = manager.lock().await; manager .get_result(agent_id) .map_err(|e| ToolError::execution_failed(e.to_string()))? @@ -829,6 +860,8 @@ impl SubAgentToolRegistry { .with_file_tools() .with_search_tools() .with_note_tool() + .with_patch_tools() + .with_web_tools() .with_todo_tool(todo_list) .with_plan_tool(plan_state); diff --git a/src/tools/swarm.rs b/src/tools/swarm.rs new file mode 100644 index 00000000..40a184d1 --- /dev/null +++ b/src/tools/swarm.rs @@ -0,0 +1,753 @@ +//! Swarm orchestration for spawning multiple sub-agents with dependencies. 
+ +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use serde_json::{Value, json}; +use uuid::Uuid; + +use crate::tools::spec::{ + ApprovalRequirement, ToolCapability, ToolContext, ToolError, ToolResult, ToolSpec, + optional_bool, optional_str, optional_u64, +}; +use crate::tools::subagent::{ + SharedSubAgentManager, SubAgentResult, SubAgentRuntime, SubAgentStatus, SubAgentType, +}; + +const SWARM_POLL_INTERVAL: Duration = Duration::from_millis(250); +const DEFAULT_SWARM_TIMEOUT_MS: u64 = 600_000; +const DEFAULT_SWARM_TIMEOUT_NONBLOCK_MS: u64 = 15_000; +const MAX_SWARM_TIMEOUT_MS: u64 = 3_600_000; + +#[derive(Debug, Clone, Deserialize)] +struct SwarmTaskSpec { + id: String, + prompt: String, + #[serde(default, rename = "type")] + agent_type: Option, + #[serde(default)] + allowed_tools: Option>, + #[serde(default)] + depends_on: Vec, +} + +#[derive(Debug, Clone)] +enum SwarmTaskState { + Pending, + Running { agent_id: String }, + Done(SubAgentResult), + Failed(String), + Skipped(String), +} + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "snake_case")] +enum SwarmTaskStatus { + Pending, + Running, + Completed, + Failed, + Cancelled, + Skipped, +} + +#[derive(Debug, Clone, Serialize)] +struct SwarmTaskOutcome { + task_id: String, + agent_id: Option, + status: SwarmTaskStatus, + #[serde(skip_serializing_if = "Option::is_none")] + result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, + steps_taken: u32, + duration_ms: u64, +} + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "snake_case")] +enum SwarmStatus { + Completed, + Partial, + Timeout, + Failed, +} + +#[derive(Debug, Clone, Serialize)] +struct SwarmCounts { + total: usize, + completed: usize, + failed: usize, + cancelled: usize, + skipped: usize, + running: usize, + pending: usize, +} + +#[derive(Debug, Clone, Serialize)] +struct 
SwarmOutcome { + swarm_id: String, + status: SwarmStatus, + duration_ms: u64, + counts: SwarmCounts, + tasks: Vec, +} + +/// Tool to launch a swarm of sub-agents with dependency-aware scheduling. +pub struct AgentSwarmTool { + manager: SharedSubAgentManager, + runtime: SubAgentRuntime, +} + +impl AgentSwarmTool { + /// Create a new swarm tool. + #[must_use] + pub fn new(manager: SharedSubAgentManager, runtime: SubAgentRuntime) -> Self { + Self { manager, runtime } + } +} + +#[async_trait] +impl ToolSpec for AgentSwarmTool { + fn name(&self) -> &'static str { + "agent_swarm" + } + + fn description(&self) -> &'static str { + "Spawn multiple sub-agents with optional dependencies and aggregate their results." + } + + fn input_schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "tasks": { + "type": "array", + "description": "List of swarm tasks to execute.", + "items": { + "type": "object", + "properties": { + "id": { "type": "string", "description": "Unique task id." }, + "prompt": { "type": "string", "description": "Task prompt for the sub-agent." }, + "type": { "type": "string", "description": "Sub-agent type: general, explore, plan, review, custom." }, + "allowed_tools": { + "type": "array", + "items": { "type": "string" }, + "description": "Explicit tool allowlist (required for custom type)." + }, + "depends_on": { + "type": "array", + "items": { "type": "string" }, + "description": "List of task ids that must complete successfully first." + } + }, + "required": ["id", "prompt"] + } + }, + "shared_context": { + "type": "string", + "description": "Optional shared context prepended to each task prompt." + }, + "block": { + "type": "boolean", + "description": "Whether to wait for tasks to finish (default: true)." + }, + "timeout_ms": { + "type": "integer", + "description": "Max wall time in milliseconds before returning partial results." 
+ }, + "max_parallel": { + "type": "integer", + "description": "Max concurrent swarm agents (defaults to max_subagents)." + }, + "fail_fast": { + "type": "boolean", + "description": "Cancel remaining work on first failure (default: false)." + } + }, + "required": ["tasks"] + }) + } + + fn capabilities(&self) -> Vec { + vec![ + ToolCapability::ExecutesCode, + ToolCapability::RequiresApproval, + ] + } + + fn approval_requirement(&self) -> ApprovalRequirement { + ApprovalRequirement::Required + } + + async fn execute(&self, input: Value, _context: &ToolContext) -> Result { + let tasks_value = input + .get("tasks") + .cloned() + .ok_or_else(|| ToolError::missing_field("tasks"))?; + let tasks: Vec = serde_json::from_value(tasks_value) + .map_err(|err| ToolError::invalid_input(format!("Invalid tasks payload: {err}")))?; + + validate_swarm_tasks(&tasks)?; + + let block = optional_bool(&input, "block", true); + let default_timeout = if block { + DEFAULT_SWARM_TIMEOUT_MS + } else { + DEFAULT_SWARM_TIMEOUT_NONBLOCK_MS + }; + let timeout_ms = + optional_u64(&input, "timeout_ms", default_timeout).clamp(1_000, MAX_SWARM_TIMEOUT_MS); + let fail_fast = optional_bool(&input, "fail_fast", false); + let shared_context = optional_str(&input, "shared_context") + .map(str::trim) + .filter(|text| !text.is_empty()) + .map(str::to_string); + + let max_parallel = { + let manager = self.manager.lock().await; + let max_agents = manager.max_agents(); + let requested = optional_u64(&input, "max_parallel", max_agents as u64); + requested.clamp(1, max_agents as u64) as usize + }; + + let outcome = run_swarm( + &self.manager, + &self.runtime, + tasks, + shared_context, + Duration::from_millis(timeout_ms), + max_parallel, + fail_fast, + block, + ) + .await?; + + ToolResult::json(&outcome).map_err(|err| ToolError::execution_failed(err.to_string())) + } +} + +#[allow(clippy::too_many_arguments)] +async fn run_swarm( + shared_manager: &SharedSubAgentManager, + runtime: &SubAgentRuntime, + tasks: Vec, 
+ shared_context: Option, + timeout: Duration, + max_parallel: usize, + fail_fast: bool, + block: bool, +) -> Result { + let swarm_id = format!("swarm_{}", &Uuid::new_v4().to_string()[..8]); + let start = Instant::now(); + let deadline = start + timeout; + let task_order = tasks.iter().map(|task| task.id.clone()).collect::>(); + + let mut task_map = HashMap::new(); + let mut states = HashMap::new(); + let mut pending = HashSet::new(); + for task in tasks { + pending.insert(task.id.clone()); + states.insert(task.id.clone(), SwarmTaskState::Pending); + task_map.insert(task.id.clone(), task); + } + + let mut running: HashMap = HashMap::new(); + let mut fail_fast_triggered = false; + let mut timed_out = false; + + loop { + let mut changed = false; + + if !running.is_empty() { + let snapshots = { + let manager = shared_manager.lock().await; + manager.list() + }; + let snapshot_map: HashMap = snapshots + .into_iter() + .map(|snapshot| (snapshot.agent_id.clone(), snapshot)) + .collect(); + + let running_ids = running.clone(); + for (task_id, agent_id) in running_ids { + match snapshot_map.get(&agent_id) { + Some(snapshot) => { + if snapshot.status != SubAgentStatus::Running { + states.insert(task_id.clone(), SwarmTaskState::Done(snapshot.clone())); + running.remove(&task_id); + changed = true; + if fail_fast + && matches!( + snapshot.status, + SubAgentStatus::Failed(_) | SubAgentStatus::Cancelled + ) + { + fail_fast_triggered = true; + } + } + } + None => { + states.insert( + task_id.clone(), + SwarmTaskState::Failed("Agent result not found".to_string()), + ); + running.remove(&task_id); + changed = true; + if fail_fast { + fail_fast_triggered = true; + } + } + } + } + } + + if fail_fast_triggered { + apply_fail_fast(shared_manager, &mut states, &mut pending, &mut running).await?; + break; + } + + let mut newly_skipped = Vec::new(); + for task_id in pending.iter() { + if let Some(task) = task_map.get(task_id) + && dependencies_failed(task, &states) + { + 
newly_skipped.push(task_id.clone()); + } + } + for task_id in newly_skipped { + pending.remove(&task_id); + states.insert( + task_id, + SwarmTaskState::Skipped("Dependency failed".to_string()), + ); + changed = true; + } + + let mut ready = Vec::new(); + for task_id in pending.iter() { + if let Some(task) = task_map.get(task_id) + && dependencies_satisfied(task, &states) + { + ready.push(task_id.clone()); + } + } + + if !ready.is_empty() { + let available_slots = { + let manager = shared_manager.lock().await; + let global_slots = manager.available_slots(); + let swarm_slots = max_parallel.saturating_sub(running.len()); + global_slots.min(swarm_slots) + }; + + if available_slots > 0 { + for task_id in ready.into_iter().take(available_slots) { + let task = task_map + .get(&task_id) + .ok_or_else(|| ToolError::execution_failed("Missing swarm task"))?; + let agent_type = task.agent_type.clone().unwrap_or_default(); + let prompt = format_prompt(shared_context.as_deref(), &task.prompt); + + let spawn_result = { + let mut manager = shared_manager.lock().await; + manager.spawn_background( + Arc::clone(shared_manager), + runtime.clone(), + agent_type, + prompt, + task.allowed_tools.clone(), + ) + }; + + match spawn_result { + Ok(snapshot) => { + states.insert( + task_id.clone(), + SwarmTaskState::Running { + agent_id: snapshot.agent_id.clone(), + }, + ); + running.insert(task_id.clone(), snapshot.agent_id); + pending.remove(&task_id); + changed = true; + } + Err(err) => { + let message = err.to_string(); + if message.contains("Sub-agent limit reached") { + break; + } + states.insert(task_id.clone(), SwarmTaskState::Failed(message)); + pending.remove(&task_id); + changed = true; + if fail_fast { + fail_fast_triggered = true; + } + } + } + } + } + } + + if fail_fast_triggered { + apply_fail_fast(shared_manager, &mut states, &mut pending, &mut running).await?; + break; + } + + if pending.is_empty() && running.is_empty() { + break; + } + if !block { + break; + } + if 
Instant::now() >= deadline { + timed_out = true; + break; + } + + if !changed { + tokio::time::sleep(SWARM_POLL_INTERVAL).await; + } + } + + let outcomes = build_task_outcomes(&task_order, &states); + let counts = build_counts(&outcomes); + let status = if fail_fast_triggered { + SwarmStatus::Failed + } else if timed_out { + SwarmStatus::Timeout + } else if counts.failed > 0 + || counts.cancelled > 0 + || counts.skipped > 0 + || counts.pending > 0 + || counts.running > 0 + { + SwarmStatus::Partial + } else { + SwarmStatus::Completed + }; + + Ok(SwarmOutcome { + swarm_id, + status, + duration_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX), + counts, + tasks: outcomes, + }) +} + +fn format_prompt(shared_context: Option<&str>, prompt: &str) -> String { + if let Some(context) = shared_context { + format!("Shared context:\n{context}\n\nTask:\n{prompt}") + } else { + prompt.to_string() + } +} + +fn dependencies_satisfied(task: &SwarmTaskSpec, states: &HashMap) -> bool { + task.depends_on.iter().all(|dep| { + matches!( + states.get(dep), + Some(SwarmTaskState::Done(result)) + if matches!(result.status, SubAgentStatus::Completed) + ) + }) +} + +fn dependencies_failed(task: &SwarmTaskSpec, states: &HashMap) -> bool { + task.depends_on.iter().any(|dep| match states.get(dep) { + Some(SwarmTaskState::Done(result)) => matches!( + result.status, + SubAgentStatus::Failed(_) | SubAgentStatus::Cancelled + ), + Some(SwarmTaskState::Failed(_)) | Some(SwarmTaskState::Skipped(_)) => true, + _ => false, + }) +} + +async fn cancel_running_tasks( + manager: &SharedSubAgentManager, + running: &HashMap, + states: &mut HashMap, +) -> Result<(), ToolError> { + let mut manager = manager.lock().await; + for (task_id, agent_id) in running { + match manager.cancel(agent_id) { + Ok(snapshot) => { + states.insert(task_id.clone(), SwarmTaskState::Done(snapshot)); + } + Err(err) => { + states.insert( + task_id.clone(), + SwarmTaskState::Failed(format!("Failed to cancel agent: 
{err}")), + ); + } + } + } + Ok(()) +} + +async fn apply_fail_fast( + manager: &SharedSubAgentManager, + states: &mut HashMap, + pending: &mut HashSet, + running: &mut HashMap, +) -> Result<(), ToolError> { + cancel_running_tasks(manager, running, states).await?; + for task_id in pending.drain() { + states.insert( + task_id, + SwarmTaskState::Skipped("Skipped due to fail_fast".to_string()), + ); + } + running.clear(); + Ok(()) +} + +fn build_task_outcomes( + order: &[String], + states: &HashMap, +) -> Vec { + order + .iter() + .map(|task_id| match states.get(task_id) { + Some(SwarmTaskState::Running { agent_id }) => SwarmTaskOutcome { + task_id: task_id.clone(), + agent_id: Some(agent_id.clone()), + status: SwarmTaskStatus::Running, + result: None, + error: None, + steps_taken: 0, + duration_ms: 0, + }, + Some(SwarmTaskState::Done(result)) => match &result.status { + SubAgentStatus::Completed => SwarmTaskOutcome { + task_id: task_id.clone(), + agent_id: Some(result.agent_id.clone()), + status: SwarmTaskStatus::Completed, + result: result.result.clone(), + error: None, + steps_taken: result.steps_taken, + duration_ms: result.duration_ms, + }, + SubAgentStatus::Failed(err) => SwarmTaskOutcome { + task_id: task_id.clone(), + agent_id: Some(result.agent_id.clone()), + status: SwarmTaskStatus::Failed, + result: result.result.clone(), + error: Some(err.clone()), + steps_taken: result.steps_taken, + duration_ms: result.duration_ms, + }, + SubAgentStatus::Cancelled => SwarmTaskOutcome { + task_id: task_id.clone(), + agent_id: Some(result.agent_id.clone()), + status: SwarmTaskStatus::Cancelled, + result: result.result.clone(), + error: Some("Cancelled".to_string()), + steps_taken: result.steps_taken, + duration_ms: result.duration_ms, + }, + SubAgentStatus::Running => SwarmTaskOutcome { + task_id: task_id.clone(), + agent_id: Some(result.agent_id.clone()), + status: SwarmTaskStatus::Running, + result: result.result.clone(), + error: None, + steps_taken: result.steps_taken, 
+ duration_ms: result.duration_ms, + }, + }, + Some(SwarmTaskState::Failed(message)) => SwarmTaskOutcome { + task_id: task_id.clone(), + agent_id: None, + status: SwarmTaskStatus::Failed, + result: None, + error: Some(message.clone()), + steps_taken: 0, + duration_ms: 0, + }, + Some(SwarmTaskState::Skipped(message)) => SwarmTaskOutcome { + task_id: task_id.clone(), + agent_id: None, + status: SwarmTaskStatus::Skipped, + result: None, + error: Some(message.clone()), + steps_taken: 0, + duration_ms: 0, + }, + _ => SwarmTaskOutcome { + task_id: task_id.clone(), + agent_id: None, + status: SwarmTaskStatus::Pending, + result: None, + error: None, + steps_taken: 0, + duration_ms: 0, + }, + }) + .collect() +} + +fn build_counts(outcomes: &[SwarmTaskOutcome]) -> SwarmCounts { + let mut counts = SwarmCounts { + total: outcomes.len(), + completed: 0, + failed: 0, + cancelled: 0, + skipped: 0, + running: 0, + pending: 0, + }; + + for outcome in outcomes { + match outcome.status { + SwarmTaskStatus::Completed => counts.completed += 1, + SwarmTaskStatus::Failed => counts.failed += 1, + SwarmTaskStatus::Cancelled => counts.cancelled += 1, + SwarmTaskStatus::Skipped => counts.skipped += 1, + SwarmTaskStatus::Running => counts.running += 1, + SwarmTaskStatus::Pending => counts.pending += 1, + } + } + + counts +} + +fn validate_swarm_tasks(tasks: &[SwarmTaskSpec]) -> Result<(), ToolError> { + if tasks.is_empty() { + return Err(ToolError::invalid_input("tasks cannot be empty")); + } + + let mut ids = HashSet::new(); + for task in tasks { + let id = task.id.trim(); + if id.is_empty() { + return Err(ToolError::invalid_input("task id cannot be empty")); + } + if task.prompt.trim().is_empty() { + return Err(ToolError::invalid_input(format!( + "task '{id}' prompt cannot be empty" + ))); + } + if matches!(task.agent_type, Some(SubAgentType::Custom)) { + let tools = task + .allowed_tools + .as_ref() + .map(Vec::as_slice) + .unwrap_or(&[]); + if tools.is_empty() { + return 
Err(ToolError::invalid_input(format!( + "task '{id}' requires allowed_tools for custom type" + ))); + } + } + if !ids.insert(task.id.clone()) { + return Err(ToolError::invalid_input(format!( + "duplicate task id '{id}'" + ))); + } + if task.depends_on.iter().any(|dep| dep == id) { + return Err(ToolError::invalid_input(format!( + "task '{id}' cannot depend on itself" + ))); + } + } + + for task in tasks { + for dep in &task.depends_on { + if !ids.contains(dep) { + return Err(ToolError::invalid_input(format!( + "task '{}' depends on unknown task '{dep}'", + task.id + ))); + } + } + } + + if has_dependency_cycle(tasks) { + return Err(ToolError::invalid_input( + "task dependencies contain a cycle", + )); + } + + Ok(()) +} + +fn has_dependency_cycle(tasks: &[SwarmTaskSpec]) -> bool { + let mut deps = HashMap::new(); + for task in tasks { + deps.insert(task.id.clone(), task.depends_on.clone()); + } + + let mut visiting = HashSet::new(); + let mut visited = HashSet::new(); + + for id in deps.keys() { + if visit(id, &deps, &mut visiting, &mut visited) { + return true; + } + } + + false +} + +fn visit( + id: &str, + deps: &HashMap>, + visiting: &mut HashSet, + visited: &mut HashSet, +) -> bool { + if visited.contains(id) { + return false; + } + if !visiting.insert(id.to_string()) { + return true; + } + if let Some(children) = deps.get(id) { + for child in children { + if visit(child, deps, visiting, visited) { + return true; + } + } + } + visiting.remove(id); + visited.insert(id.to_string()); + false +} + +#[cfg(test)] +mod tests { + use super::{SwarmTaskSpec, validate_swarm_tasks}; + + fn task(id: &str, deps: &[&str]) -> SwarmTaskSpec { + SwarmTaskSpec { + id: id.to_string(), + prompt: "do work".to_string(), + agent_type: None, + allowed_tools: None, + depends_on: deps.iter().map(|dep| dep.to_string()).collect(), + } + } + + #[test] + fn validate_swarm_tasks_accepts_valid_graph() { + let tasks = vec![task("a", &[]), task("b", &["a"])]; + 
assert!(validate_swarm_tasks(&tasks).is_ok()); + } + + #[test] + fn validate_swarm_tasks_rejects_unknown_dependency() { + let tasks = vec![task("a", &["missing"])]; + assert!(validate_swarm_tasks(&tasks).is_err()); + } + + #[test] + fn validate_swarm_tasks_rejects_cycle() { + let tasks = vec![task("a", &["b"]), task("b", &["a"])]; + assert!(validate_swarm_tasks(&tasks).is_err()); + } +} diff --git a/src/tools/test_runner.rs b/src/tools/test_runner.rs new file mode 100644 index 00000000..0789f8c6 --- /dev/null +++ b/src/tools/test_runner.rs @@ -0,0 +1,253 @@ +//! Cargo test runner tool: `run_tests`. +//! +//! This tool intentionally auto-approves test execution to encourage +//! frequent verification loops while still scoping execution to the workspace. + +use std::path::Path; +use std::process::Command; + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use serde_json::{Value, json}; + +use super::spec::{ + ApprovalRequirement, ToolCapability, ToolContext, ToolError, ToolResult, ToolSpec, + optional_bool, optional_str, +}; + +const MAX_OUTPUT_CHARS: usize = 40_000; + +/// Tool for running `cargo test` in the workspace root. +pub struct RunTestsTool; + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct RunTestsOutput { + success: bool, + exit_code: i32, + stdout: String, + stderr: String, + command: String, +} + +#[async_trait] +impl ToolSpec for RunTestsTool { + fn name(&self) -> &'static str { + "run_tests" + } + + fn description(&self) -> &'static str { + "Run `cargo test` in the workspace root with optional extra arguments." + } + + fn input_schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "args": { + "type": "string", + "description": "Optional extra arguments to pass to `cargo test` (shell-style)." + }, + "all_features": { + "type": "boolean", + "description": "When true, include `--all-features`." 
+ } + }, + "additionalProperties": false + }) + } + + fn capabilities(&self) -> Vec { + vec![ToolCapability::ExecutesCode, ToolCapability::Sandboxable] + } + + fn approval_requirement(&self) -> ApprovalRequirement { + // Tests are encouraged, so avoid gating them behind approval. + ApprovalRequirement::Auto + } + + async fn execute(&self, input: Value, context: &ToolContext) -> Result { + let all_features = optional_bool(&input, "all_features", false); + let extra_args = optional_str(&input, "args") + .map(str::trim) + .filter(|s| !s.is_empty()); + + let mut args = vec!["test".to_string()]; + if all_features { + args.push("--all-features".to_string()); + } + if let Some(extra) = extra_args { + let split = shlex::split(extra).ok_or_else(|| { + ToolError::invalid_input("Failed to parse 'args' as shell-style tokens") + })?; + args.extend(split); + } + + let command_str = format_command(&context.workspace, &args); + let output = run_cargo(&context.workspace, &args)?; + + let exit_code = output.status.code().unwrap_or(-1); + let stdout_raw = String::from_utf8_lossy(&output.stdout); + let stderr_raw = String::from_utf8_lossy(&output.stderr); + let stdout = truncate_with_note(&stdout_raw, MAX_OUTPUT_CHARS); + let stderr = truncate_with_note(&stderr_raw, MAX_OUTPUT_CHARS); + + let result = RunTestsOutput { + success: output.status.success(), + exit_code, + stdout, + stderr, + command: command_str, + }; + + ToolResult::json(&result).map_err(|e| ToolError::execution_failed(e.to_string())) + } +} + +// === Helpers === + +fn run_cargo(workspace: &Path, args: &[String]) -> Result { + let mut cmd = Command::new("cargo"); + cmd.args(args).current_dir(workspace); + cmd.output().map_err(|e| { + if e.kind() == std::io::ErrorKind::NotFound { + ToolError::not_available("cargo is not installed or not in PATH") + } else { + ToolError::execution_failed(format!("Failed to run cargo: {e}")) + } + }) +} + +fn format_command(workspace: &Path, args: &[String]) -> String { + format!( + "(cd {} 
&& cargo {})", + workspace.display(), + args.iter() + .map(String::as_str) + .collect::>() + .join(" ") + ) +} + +fn truncate_with_note(text: &str, max_chars: usize) -> String { + if text.chars().count() <= max_chars { + return text.to_string(); + } + let end = char_boundary_index(text, max_chars); + let truncated = &text[..end]; + let omitted_chars = text + .chars() + .count() + .saturating_sub(truncated.chars().count()); + let note = format!( + "\n\n[output truncated to {max_chars} characters; {omitted_chars} characters omitted]" + ); + format!("{truncated}{note}") +} + +fn char_boundary_index(text: &str, max_chars: usize) -> usize { + if max_chars == 0 { + return 0; + } + for (count, (idx, _)) in text.char_indices().enumerate() { + if count == max_chars { + return idx; + } + } + text.len() +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use std::process::Command; + use tempfile::tempdir; + + fn cargo_available() -> bool { + Command::new("cargo") + .arg("--version") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) + } + + fn init_cargo_project(root: &Path) -> std::path::PathBuf { + let project_dir = root.join("project"); + fs::create_dir_all(&project_dir).expect("create project dir"); + let status = Command::new("cargo") + .args([ + "init", + "--lib", + "--vcs", + "none", + "-q", + "--name", + "eval_project", + ]) + .current_dir(&project_dir) + .status() + .expect("cargo should spawn"); + assert!(status.success(), "cargo init failed"); + project_dir + } + + #[tokio::test] + async fn run_tests_succeeds_on_fresh_project() { + if !cargo_available() { + return; + } + let tmp = tempdir().expect("tempdir"); + let project_dir = init_cargo_project(tmp.path()); + + let ctx = ToolContext::new(&project_dir); + let tool = RunTestsTool; + let result = tool.execute(json!({}), &ctx).await.expect("execute"); + assert!(result.success); + + let parsed: RunTestsOutput = + serde_json::from_str(&result.content).expect("tool result should be json"); + 
assert!(parsed.success); + assert_eq!(parsed.exit_code, 0); + assert!(parsed.command.contains("cargo test")); + } + + #[tokio::test] + async fn run_tests_reports_failures_without_hard_error() { + if !cargo_available() { + return; + } + let tmp = tempdir().expect("tempdir"); + let project_dir = init_cargo_project(tmp.path()); + + let lib_rs = project_dir.join("src/lib.rs"); + let failing = r#" +pub fn add(a: i32, b: i32) -> i32 { a + b } + +#[cfg(test)] +mod tests { + #[test] + fn fails() { + assert_eq!(2 + 2, 5); + } +} +"#; + fs::write(&lib_rs, failing).expect("write failing test"); + + let ctx = ToolContext::new(&project_dir); + let tool = RunTestsTool; + let result = tool.execute(json!({}), &ctx).await.expect("execute"); + assert!(result.success); + + let parsed: RunTestsOutput = + serde_json::from_str(&result.content).expect("tool result should be json"); + assert!(!parsed.success); + assert_ne!(parsed.exit_code, 0); + } + + #[test] + fn truncation_adds_note() { + let long = "x".repeat(MAX_OUTPUT_CHARS + 128); + let truncated = truncate_with_note(&long, MAX_OUTPUT_CHARS); + assert!(truncated.contains("output truncated")); + } +} diff --git a/src/tools/todo.rs b/src/tools/todo.rs index 9dd52d17..e54efef7 100644 --- a/src/tools/todo.rs +++ b/src/tools/todo.rs @@ -1,6 +1,7 @@ //! Todo list tool and supporting data structures. 
-use std::sync::{Arc, Mutex}; +use std::sync::Arc; +use tokio::sync::Mutex; use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -327,10 +328,7 @@ impl ToolSpec for TodoAddTool { .and_then(TodoStatus::from_str) .unwrap_or(TodoStatus::Pending); - let mut list = self - .todo_list - .lock() - .map_err(|e| ToolError::execution_failed(format!("Failed to lock todo list: {e}")))?; + let mut list = self.todo_list.lock().await; let item = list.add(content.to_string(), status); let snapshot = list.snapshot(); @@ -407,10 +405,7 @@ impl ToolSpec for TodoUpdateTool { .and_then(TodoStatus::from_str) .ok_or_else(|| ToolError::invalid_input("Missing or invalid 'status'"))?; - let mut list = self - .todo_list - .lock() - .map_err(|e| ToolError::execution_failed(format!("Failed to lock todo list: {e}")))?; + let mut list = self.todo_list.lock().await; let updated = list.update_status(id, status); let snapshot = list.snapshot(); let result = serde_json::to_string_pretty(&snapshot).unwrap_or_else(|_| "{}".to_string()); @@ -468,10 +463,7 @@ impl ToolSpec for TodoListTool { _input: serde_json::Value, _context: &ToolContext, ) -> Result { - let list = self - .todo_list - .lock() - .map_err(|e| ToolError::execution_failed(format!("Failed to lock todo list: {e}")))?; + let list = self.todo_list.lock().await; let snapshot = list.snapshot(); let result = serde_json::to_string_pretty(&snapshot).unwrap_or_else(|_| "{}".to_string()); Ok(ToolResult::success(format!( @@ -539,10 +531,7 @@ impl ToolSpec for TodoWriteTool { .and_then(|v| v.as_array()) .ok_or_else(|| ToolError::invalid_input("Missing or invalid 'todos' array"))?; - let mut list = self - .todo_list - .lock() - .map_err(|e| ToolError::execution_failed(format!("Failed to lock todo list: {e}")))?; + let mut list = self.todo_list.lock().await; // Clear and rebuild the list list.clear(); diff --git a/src/tui/app.rs b/src/tui/app.rs index a0e5a1d7..21331bc3 100644 --- a/src/tui/app.rs +++ b/src/tui/app.rs @@ -878,9 +878,8 
@@ impl App { } pub fn clear_todos(&mut self) { - if let Ok(mut plan) = self.plan_state.lock() { - *plan = crate::tools::plan::PlanState::default(); - } + let mut plan = self.plan_state.blocking_lock(); + *plan = crate::tools::plan::PlanState::default(); } } diff --git a/src/tui/ui.rs b/src/tui/ui.rs index 66a5de56..5157cb60 100644 --- a/src/tui/ui.rs +++ b/src/tui/ui.rs @@ -1107,6 +1107,7 @@ async fn dispatch_user_message( app.system_prompt = Some(prompts::system_prompt_for_mode_with_context( app.mode, &app.workspace, + None, rlm_summary.as_deref(), duo_summary.as_deref(), )); diff --git a/src/working_set.rs b/src/working_set.rs new file mode 100644 index 00000000..dfbfff3f --- /dev/null +++ b/src/working_set.rs @@ -0,0 +1,785 @@ +//! Repo-aware working set tracking and prompt context packing. +//! +//! The goal of this module is to keep a small, high-signal list of +//! "active" paths that the assistant should prioritize. It observes +//! user messages and tool calls, extracts likely paths, and produces: +//! - a compact working-set summary block for the system prompt +//! - pinned message indices that compaction should preserve + +use crate::models::{ContentBlock, Message}; +use regex::Regex; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::{HashMap, HashSet}; +use std::ffi::OsStr; +use std::fs; +use std::path::{Path, PathBuf}; +use std::sync::OnceLock; + +/// Configuration for working-set tracking. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WorkingSetConfig { + /// Maximum number of entries to keep. + pub max_entries: usize, + /// Maximum number of paths to pin during compaction. + pub max_pinned_paths: usize, + /// Maximum characters to scan per text block when pinning messages. + pub max_scan_chars: usize, + /// Maximum entries to show in the system prompt block. 
+ pub max_prompt_entries: usize, +} + +impl Default for WorkingSetConfig { + fn default() -> Self { + Self { + max_entries: 16, + max_pinned_paths: 8, + max_scan_chars: 2_000, + max_prompt_entries: 8, + } + } +} + +/// The source that most recently updated an entry. +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +pub enum WorkingSetSource { + UserMessage, + ToolInput, + ToolOutput, + Rebuild, +} + +/// A single working-set entry. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WorkingSetEntry { + /// Workspace-relative path string. + pub path: String, + /// Whether the path is a directory (best-effort). + pub is_dir: bool, + /// Whether the path exists on disk (best-effort). + pub exists: bool, + /// Number of times this path was observed. + pub touches: u32, + /// The last observed turn index. + pub last_turn: u64, + /// The last update source. + pub last_source: WorkingSetSource, +} + +impl WorkingSetEntry { + fn new(path: String, exists: bool, is_dir: bool, turn: u64, source: WorkingSetSource) -> Self { + Self { + path, + is_dir, + exists, + touches: 1, + last_turn: turn, + last_source: source, + } + } +} + +/// Repo-aware working-set state. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct WorkingSet { + /// Tracking configuration. + pub config: WorkingSetConfig, + /// Monotonic turn counter (increments on user messages). + pub turn: u64, + /// Path entries keyed by workspace-relative path. + pub entries: HashMap, +} + +impl WorkingSet { + /// Advance to the next turn. + pub fn next_turn(&mut self) { + self.turn = self.turn.saturating_add(1); + } + + /// Observe a user message and update the working set. + pub fn observe_user_message(&mut self, text: &str, workspace: &Path) { + self.next_turn(); + let paths = extract_paths_from_text(text); + self.record_candidates(paths, workspace, WorkingSetSource::UserMessage); + } + + /// Observe a tool call (input and optional output). 
+ pub fn observe_tool_call( + &mut self, + tool_name: &str, + input: &Value, + output: Option<&str>, + workspace: &Path, + ) { + let input_candidates = extract_paths_from_value(input, Some(tool_name)); + self.record_candidates(input_candidates, workspace, WorkingSetSource::ToolInput); + + if let Some(text) = output { + let output_candidates = extract_paths_from_text(text); + self.record_candidates(output_candidates, workspace, WorkingSetSource::ToolOutput); + } + } + + /// Rebuild the working set from existing messages (best effort). + /// + /// This is used when syncing a resumed session. + pub fn rebuild_from_messages(&mut self, messages: &[Message], workspace: &Path) { + self.entries.clear(); + self.turn = 0; + + for message in messages { + if message.role == "user" { + self.next_turn(); + } + let candidates = extract_paths_from_message(message); + if candidates.is_empty() { + continue; + } + self.record_candidates(candidates, workspace, WorkingSetSource::Rebuild); + } + } + + /// Render a compact working-set block for the system prompt. 
+ pub fn summary_block(&self, workspace: &Path) -> Option { + let entries = self.sorted_entries(); + let prompt_entries: Vec<&WorkingSetEntry> = entries + .into_iter() + .take(self.config.max_prompt_entries) + .collect(); + + let repo_summary = summarize_repo_root(workspace); + + if repo_summary.is_none() && prompt_entries.is_empty() { + return None; + } + + let mut lines: Vec = Vec::new(); + lines.push("## Repo Working Set".to_string()); + lines.push(format!("Workspace: {}", workspace.display())); + + if let Some(summary) = repo_summary { + lines.push(summary); + } + + if !prompt_entries.is_empty() { + lines.push("Active paths (prioritize these):".to_string()); + for entry in prompt_entries { + let age = self.turn.saturating_sub(entry.last_turn); + let kind = if entry.is_dir { "dir" } else { "file" }; + lines.push(format!( + "- {} ({kind}, touches: {}, last seen: {} turn(s) ago)", + entry.path, entry.touches, age + )); + } + } + + lines.push( + "When in doubt, use tools to verify and keep changes focused on the working set." + .to_string(), + ); + + Some(lines.join("\n")) + } + + /// Return the most relevant paths in score order. + pub fn top_paths(&self, limit: usize) -> Vec { + self.sorted_entries() + .into_iter() + .take(limit) + .map(|entry| entry.path.clone()) + .collect() + } + + /// Identify message indices that should be pinned during compaction. 
+ pub fn pinned_message_indices(&self, messages: &[Message], workspace: &Path) -> Vec { + if messages.is_empty() || self.entries.is_empty() { + return Vec::new(); + } + + let pinned_paths: Vec<&WorkingSetEntry> = self + .sorted_entries() + .into_iter() + .take(self.config.max_pinned_paths) + .collect(); + if pinned_paths.is_empty() { + return Vec::new(); + } + + let needles = build_search_needles(&pinned_paths, workspace); + if needles.is_empty() { + return Vec::new(); + } + + let mut pinned: Vec = Vec::new(); + for (idx, message) in messages.iter().enumerate() { + if message_mentions_any_path(message, &needles, self.config.max_scan_chars) { + pinned.push(idx); + } + } + pinned + } + + fn record_candidates( + &mut self, + candidates: Vec, + workspace: &Path, + source: WorkingSetSource, + ) { + if candidates.is_empty() { + return; + } + + let workspace_canon = workspace.canonicalize().ok(); + + for raw in candidates { + let Some(normalized) = normalize_candidate(&raw) else { + continue; + }; + let Some((rel, exists, is_dir)) = + relativize_candidate(&normalized, workspace, workspace_canon.as_deref()) + else { + continue; + }; + self.record_path(rel, exists, is_dir, source); + } + + self.prune(); + } + + fn record_path(&mut self, rel: String, exists: bool, is_dir: bool, source: WorkingSetSource) { + match self.entries.get_mut(&rel) { + Some(entry) => { + entry.exists |= exists; + entry.is_dir |= is_dir; + entry.touches = entry.touches.saturating_add(1); + entry.last_turn = self.turn; + entry.last_source = source; + } + None => { + let entry = WorkingSetEntry::new(rel.clone(), exists, is_dir, self.turn, source); + let _ = self.entries.insert(rel, entry); + } + } + } + + fn prune(&mut self) { + let max_entries = self.config.max_entries; + if self.entries.len() <= max_entries { + return; + } + + // Rank by score ascending and drop the lowest until within bounds. 
+ let mut ranked: Vec<(String, i64)> = self + .entries + .values() + .map(|entry| (entry.path.clone(), score_entry(entry, self.turn))) + .collect(); + ranked.sort_by(|a, b| a.1.cmp(&b.1)); + + let to_remove = self.entries.len().saturating_sub(max_entries); + for (path, _) in ranked.into_iter().take(to_remove) { + let _ = self.entries.remove(&path); + } + } + + fn sorted_entries(&self) -> Vec<&WorkingSetEntry> { + let mut entries: Vec<&WorkingSetEntry> = self.entries.values().collect(); + entries.sort_by(|a, b| { + let sb = score_entry(b, self.turn); + let sa = score_entry(a, self.turn); + sb.cmp(&sa).then_with(|| a.path.cmp(&b.path)) + }); + entries + } +} + +fn score_entry(entry: &WorkingSetEntry, current_turn: u64) -> i64 { + let age = current_turn.saturating_sub(entry.last_turn); + let recency_bonus = match age { + 0 => 6, + 1 => 4, + 2 => 3, + 3..=5 => 2, + 6..=10 => 1, + _ => 0, + }; + i64::from(entry.touches) * 4 + recency_bonus +} + +fn normalize_candidate(raw: &str) -> Option { + let trimmed = raw.trim().trim_matches(|c: char| { + matches!( + c, + '"' | '\'' | '`' | ',' | ';' | ':' | '(' | ')' | '[' | ']' + ) + }); + if trimmed.is_empty() { + return None; + } + Some(trimmed.to_string()) +} + +fn relativize_candidate( + candidate: &str, + workspace: &Path, + workspace_canon: Option<&Path>, +) -> Option<(String, bool, bool)> { + let candidate_path = Path::new(candidate); + + // Reject obvious URLs and non-paths early. 
+ if candidate.contains("://") { + return None; + } + + let (rel_path, abs_path) = if candidate_path.is_absolute() { + let within_workspace = workspace_canon + .map(|ws| candidate_path.starts_with(ws)) + .unwrap_or_else(|| candidate_path.starts_with(workspace)); + if !within_workspace { + return None; + } + let rel = candidate_path.strip_prefix(workspace).ok()?.to_path_buf(); + (rel, candidate_path.to_path_buf()) + } else { + if starts_with_parent_dir(candidate_path) { + return None; + } + let rel = clean_relative(candidate_path); + let abs = workspace.join(&rel); + (rel, abs) + }; + + let metadata = fs::metadata(&abs_path).ok(); + let exists = metadata.is_some(); + let is_dir = metadata + .as_ref() + .map(fs::Metadata::is_dir) + .unwrap_or_else(|| candidate.ends_with('/')); + + let rel_string = path_to_string(&rel_path)?; + Some((rel_string, exists, is_dir)) +} + +fn starts_with_parent_dir(path: &Path) -> bool { + matches!( + path.components().next(), + Some(std::path::Component::ParentDir) + ) +} + +fn clean_relative(path: &Path) -> PathBuf { + use std::path::Component; + + let mut parts: Vec = Vec::new(); + for comp in path.components() { + match comp { + Component::CurDir => {} + Component::ParentDir => { + let _ = parts.pop(); + } + Component::Normal(p) => parts.push(PathBuf::from(p)), + Component::RootDir | Component::Prefix(_) => {} + } + } + let mut out = PathBuf::new(); + for part in parts { + out.push(part); + } + out +} + +fn path_to_string(path: &Path) -> Option { + path.as_os_str().to_str().map(ToOwned::to_owned) +} + +fn extract_paths_from_message(message: &Message) -> Vec { + let mut paths = Vec::new(); + for block in &message.content { + match block { + ContentBlock::Text { text, .. } => { + paths.extend(extract_paths_from_text(text)); + } + ContentBlock::ToolUse { input, .. } => { + paths.extend(extract_paths_from_value(input, None)); + } + ContentBlock::ToolResult { content, .. 
} => { + paths.extend(extract_paths_from_text(content)); + } + ContentBlock::Thinking { .. } => {} + } + } + paths +} + +fn extract_paths_from_value(value: &Value, tool_hint: Option<&str>) -> Vec { + let mut out = Vec::new(); + extract_paths_from_value_inner(value, tool_hint, None, &mut out); + out +} + +fn extract_paths_from_value_inner( + value: &Value, + tool_hint: Option<&str>, + key_hint: Option<&str>, + out: &mut Vec, +) { + match value { + Value::String(s) => { + let key_suggests_path = key_hint.map(key_is_path_like).unwrap_or(false); + if key_suggests_path || looks_like_path(s) { + out.extend(extract_paths_from_text(s)); + if key_suggests_path && !s.contains('/') && !s.contains('\\') { + out.push(s.to_string()); + } + } else if tool_hint == Some("exec_shell") && s.len() < 400 { + out.extend(extract_paths_from_text(s)); + } + } + Value::Array(arr) => { + for item in arr { + extract_paths_from_value_inner(item, tool_hint, key_hint, out); + } + } + Value::Object(map) => { + for (k, v) in map { + extract_paths_from_value_inner(v, tool_hint, Some(k.as_str()), out); + } + } + Value::Null | Value::Bool(_) | Value::Number(_) => {} + } +} + +fn key_is_path_like(key: &str) -> bool { + let lower = key.to_ascii_lowercase(); + lower.contains("path") + || lower.contains("file") + || lower.contains("dir") + || lower.contains("cwd") + || lower.contains("workspace") + || lower.contains("root") + || lower == "target" +} + +fn looks_like_path(text: &str) -> bool { + let trimmed = text.trim(); + if trimmed.is_empty() { + return false; + } + if trimmed.contains('/') || trimmed.contains('\\') { + return true; + } + match Path::new(trimmed).extension().and_then(OsStr::to_str) { + Some(ext) => COMMON_EXTENSIONS.contains(&ext), + None => false, + } +} + +const COMMON_EXTENSIONS: &[&str] = &[ + "rs", "toml", "md", "txt", "json", "yaml", "yml", "ts", "tsx", "js", "jsx", "py", "go", "java", + "c", "cc", "cpp", "h", "hpp", "sh", "bash", "zsh", "sql", "html", "css", "scss", +]; + +fn 
extract_paths_from_text(text: &str) -> Vec { + if text.trim().is_empty() { + return Vec::new(); + } + + let re = path_regex(); + re.find_iter(text) + .map(|m| m.as_str().to_string()) + .filter(|s| looks_like_path(s)) + .collect() +} + +fn path_regex() -> &'static Regex { + static RE: OnceLock = OnceLock::new(); + RE.get_or_init(|| { + // Path-ish tokens with separators or file extensions. + Regex::new( + r#"(?x) + (?: + (?:[A-Za-z]:\\)? # optional Windows drive + (?:\./|\../|/)? # optional leading + [A-Za-z0-9._-]+ + (?:[/\\][A-Za-z0-9._-]+)+ + (?:\.[A-Za-z0-9]{1,8})? # optional extension + ) + | + (?: + [A-Za-z0-9._-]+\.[A-Za-z0-9]{1,8} + ) + "#, + ) + .expect("path regex should compile") + }) +} + +fn truncate_chars(text: &str, max_chars: usize) -> &str { + if max_chars == 0 { + return ""; + } + match text.char_indices().nth(max_chars) { + Some((idx, _)) => &text[..idx], + None => text, + } +} + +fn build_search_needles(entries: &[&WorkingSetEntry], workspace: &Path) -> Vec { + let mut needles: HashSet = HashSet::new(); + for entry in entries { + let rel = entry.path.clone(); + if rel.is_empty() { + continue; + } + let abs = workspace.join(&rel); + let abs_str = abs.as_os_str().to_str().map(ToOwned::to_owned); + + let _ = needles.insert(rel.clone()); + if let Some(abs_str) = abs_str { + let _ = needles.insert(abs_str); + } + } + needles.into_iter().collect() +} + +fn message_mentions_any_path(message: &Message, needles: &[String], max_scan_chars: usize) -> bool { + if needles.is_empty() { + return false; + } + for block in &message.content { + match block { + ContentBlock::Text { text, .. } => { + let snippet = truncate_chars(text, max_scan_chars); + if contains_any(snippet, needles) { + return true; + } + } + ContentBlock::ToolUse { input, .. } => { + if let Ok(json) = serde_json::to_string(input) + && contains_any(&json, needles) + { + return true; + } + } + ContentBlock::ToolResult { content, .. 
} => {
+                let snippet = truncate_chars(content, max_scan_chars);
+                if contains_any(snippet, needles) {
+                    return true;
+                }
+            }
+            ContentBlock::Thinking { .. } => {}
+        }
+    }
+    false
+}
+
+/// Returns `true` when `text` contains at least one of `needles` as a substring.
+///
+/// Empty needles are skipped, because an empty string is a substring of
+/// every text and would make the check trivially true.
+fn contains_any(text: &str, needles: &[String]) -> bool {
+    needles
+        .iter()
+        .any(|needle| !needle.is_empty() && text.contains(needle))
+}
+
+/// Builds a short description of the workspace root.
+///
+/// The result has up to two lines, joined with `\n`: a `Key files: ...` line
+/// and a `Top-level dirs: ...` line listing at most 8 directories. A line is
+/// omitted when it would be empty; `None` is returned when both are empty.
+fn summarize_repo_root(workspace: &Path) -> Option<String> {
+    let key_files = detect_key_files(workspace);
+    let top_dirs = list_top_level_dirs(workspace, 8);
+
+    if key_files.is_empty() && top_dirs.is_empty() {
+        return None;
+    }
+
+    let mut parts: Vec<String> = Vec::new();
+    if !key_files.is_empty() {
+        parts.push(format!("Key files: {}", key_files.join(", ")));
+    }
+    if !top_dirs.is_empty() {
+        parts.push(format!("Top-level dirs: {}", top_dirs.join(", ")));
+    }
+    Some(parts.join("\n"))
+}
+
+/// Returns the well-known file names that exist directly under `workspace`.
+///
+/// Names are reported in the order of the fixed candidate list below.
+fn detect_key_files(workspace: &Path) -> Vec<String> {
+    const CANDIDATES: &[&str] = &[
+        "Cargo.toml",
+        "README.md",
+        "AGENTS.md",
+        "CLAUDE.md",
+        "package.json",
+        "pyproject.toml",
+        "go.mod",
+        "Makefile",
+    ];
+
+    CANDIDATES
+        .iter()
+        .filter(|name| workspace.join(**name).exists())
+        .map(|name| (*name).to_string())
+        .collect()
+}
+
+/// Lists the names of directories directly under `workspace`, sorted, keeping
+/// at most `limit` of them.
+///
+/// Entries are skipped when their name starts with `.`, is listed in
+/// `IGNORED_ROOT_DIRS`, or is not valid UTF-8. Entries whose metadata cannot
+/// be read are skipped too. An unreadable `workspace` yields an empty list.
+fn list_top_level_dirs(workspace: &Path, limit: usize) -> Vec<String> {
+    // Best effort: failing to read the workspace is not an error here.
+    let Ok(entries) = fs::read_dir(workspace) else {
+        return Vec::new();
+    };
+
+    let mut dirs = Vec::new();
+    for entry in entries.flatten() {
+        let file_name = entry.file_name();
+        let Some(name) = file_name.to_str() else {
+            continue;
+        };
+
+        if name.starts_with('.') || IGNORED_ROOT_DIRS.contains(&name) {
+            continue;
+        }
+
+        if let Ok(meta) = entry.metadata()
+            && meta.is_dir()
+        {
+            dirs.push(name.to_string());
+        }
+    }
+
+    // Sort before truncating: `read_dir` yields entries in a platform- and
+    // filesystem-dependent order, so cutting off first would make the
+    // reported directories vary from machine to machine.
+    dirs.sort();
+    dirs.truncate(limit);
+    dirs
+}
+
+// Directory names never reported by `list_top_level_dirs`. `.git` is also
+// covered by the dot-prefix check there.
+const IGNORED_ROOT_DIRS: &[&str] = &["target", "node_modules", "dist", "build", ".git"];
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use tempfile::TempDir;
+
+    // Builds a message with a single text block and no cache control.
+    fn make_message(role: &str, text: &str) -> Message {
+        Message {
+            role: role.to_string(),
+            content: vec![ContentBlock::Text {
+                text: text.to_string(),
+                cache_control: None,
+            }],
+        }
+    }
+
+    #[test]
+    fn observe_user_message_tracks_paths() {
+        let tmp = TempDir::new().expect("tempdir");
+        let src = tmp.path().join("src");
+        let file = src.join("lib.rs");
+        fs::create_dir_all(&src).expect("mkdir");
+        fs::write(&file, "pub fn x() {}").expect("write");
+
+        let mut ws = WorkingSet::default();
+        ws.observe_user_message("Please check src/lib.rs", tmp.path());
+
+        assert!(ws.entries.contains_key("src/lib.rs"));
+        let entry = ws.entries.get("src/lib.rs").expect("entry");
+        assert!(entry.exists);
+        assert!(!entry.is_dir);
+    }
+
+    #[test]
+    fn observe_tool_call_extracts_paths_from_input() {
+        let tmp = TempDir::new().expect("tempdir");
+        let file = tmp.path().join("Cargo.toml");
+        fs::write(&file, "[package]\nname = \"x\"").expect("write");
+
+        let mut ws = WorkingSet::default();
+        let input = serde_json::json!({ "path": "Cargo.toml" });
+        ws.observe_tool_call("read_file", &input, None, tmp.path());
+
+        assert!(ws.entries.contains_key("Cargo.toml"));
+    }
+
+    #[test]
+    fn pinned_message_indices_respects_working_set() {
+        let tmp = TempDir::new().expect("tempdir");
+        let src = tmp.path().join("src");
+        fs::create_dir_all(&src).expect("mkdir");
+        let file = src.join("main.rs");
+        fs::write(&file, "fn main() {}").expect("write");
+
+        let mut ws = WorkingSet::default();
+        ws.observe_user_message("Edit src/main.rs", tmp.path());
+
+        let messages = vec![
+            make_message("user", "Unrelated text"),
+            make_message("assistant", "I will read src/main.rs next."),
+            make_message("user", "More unrelated text"),
+        ];
+
+        let pinned = ws.pinned_message_indices(&messages, tmp.path());
+        assert_eq!(pinned, vec![1]);
+    }
+
+    #[test]
+    fn summary_block_includes_repo_and_working_set() {
+        let tmp = TempDir::new().expect("tempdir");
+        fs::write(tmp.path().join("Cargo.toml"), "[package]\nname = \"x\"").expect("write");
+        let src = tmp.path().join("src");
+        fs::create_dir_all(&src).expect("mkdir");
+        fs::write(src.join("lib.rs"), "pub fn x() {}").expect("write");
+
+        let mut ws = WorkingSet::default();
+        ws.observe_user_message("src/lib.rs", tmp.path());
+        let block = ws.summary_block(tmp.path()).expect("block");
+
+        assert!(block.contains("Repo Working Set"));
+        assert!(block.contains("Cargo.toml"));
+        assert!(block.contains("src"));
+        assert!(block.contains("src/lib.rs"));
+    }
+
+    #[test]
+    fn extract_paths_from_message_picks_up_tool_results() {
+        let msg = Message {
+            role: "user".to_string(),
+            content: vec![ContentBlock::ToolResult {
+                tool_use_id: "tool_1".to_string(),
+                content: "Changed src/compaction.rs".to_string(),
+            }],
+        };
+
+        let paths = extract_paths_from_message(&msg);
+        assert!(paths.iter().any(|p| p.contains("src/compaction.rs")));
+    }
+
+    #[test]
+    fn pinning_prefers_high_signal_paths() {
+        let tmp = TempDir::new().expect("tempdir");
+        fs::create_dir_all(tmp.path().join("src")).expect("mkdir");
+        fs::write(tmp.path().join("src/a.rs"), "a").expect("write");
+        fs::write(tmp.path().join("src/b.rs"), "b").expect("write");
+
+        let mut ws = WorkingSet::default();
+        ws.observe_user_message("src/a.rs", tmp.path());
+        ws.observe_tool_call(
+            "read_file",
+            &serde_json::json!({ "path": "src/a.rs" }),
+            Some("src/a.rs"),
+            tmp.path(),
+        );
+        ws.observe_user_message("src/b.rs", tmp.path());
+
+        let a_score = score_entry(ws.entries.get("src/a.rs").expect("a"), ws.turn);
+        let b_score = score_entry(ws.entries.get("src/b.rs").expect("b"), ws.turn);
+        assert!(a_score >= b_score);
+    }
+
+    #[test]
+    fn list_top_level_dirs_sorts_before_applying_limit() {
+        let tmp = TempDir::new().expect("tempdir");
+        for name in ["zeta", "alpha", "mid", ".hidden", "target"] {
+            fs::create_dir_all(tmp.path().join(name)).expect("mkdir");
+        }
+        fs::write(tmp.path().join("file.txt"), "x").expect("write");
+
+        assert_eq!(list_top_level_dirs(tmp.path(), 2), vec!["alpha", "mid"]);
+        assert!(list_top_level_dirs(tmp.path(), 0).is_empty());
+    }
+
+    #[test]
+    fn estimate_tokens_is_available_for_future_budgeting() {
+        use crate::compaction::estimate_tokens;
+        let messages = vec![make_message("user", "src/main.rs")];
+        assert!(estimate_tokens(&messages) > 0);
+    }
+}
diff --git a/tests/eval_harness.rs b/tests/eval_harness.rs
new file mode 100644
index 00000000..0df16011
--- /dev/null
+++ b/tests/eval_harness.rs
@@ -0,0 +1,100 @@
+//! 
Integration tests for the offline evaluation harness.
+
+use std::fs;
+
+#[path = "../src/eval.rs"]
+mod eval;
+
+use eval::{EvalHarness, EvalHarnessConfig, ScenarioStepKind};
+
+#[test]
+fn runs_offline_tool_loop_successfully() {
+    let harness = EvalHarness::default();
+    let run = harness.run().expect("eval harness run should succeed");
+    assert_eq!(
+        ScenarioStepKind::parse("patch"),
+        Some(ScenarioStepKind::ApplyPatch)
+    );
+
+    assert!(run.metrics.success, "expected success metrics: {run:#?}");
+    assert_eq!(run.metrics.tool_errors, 0);
+    assert_eq!(run.metrics.steps, 6); // matches the six tool kinds checked below
+    assert!(run.metrics.duration.as_nanos() > 0); // a sub-millisecond run must still pass
+    assert!(!run.scenario_name.is_empty());
+    assert!(run.workspace_summary.file_count >= 3);
+
+    for kind in [
+        ScenarioStepKind::List,
+        ScenarioStepKind::Read,
+        ScenarioStepKind::Search,
+        ScenarioStepKind::Edit,
+        ScenarioStepKind::ApplyPatch,
+        ScenarioStepKind::ExecShell,
+    ] {
+        let stats = run
+            .metrics
+            .per_tool
+            .get(&kind)
+            .expect("missing per-tool stats");
+        assert_eq!(stats.invocations, 1, "unexpected invocations for {kind:?}");
+        assert_eq!(stats.errors, 0, "unexpected errors for {kind:?}");
+        assert!(stats.total_duration.as_nanos() > 0);
+    }
+
+    let notes_path = run.workspace_root().join("notes.txt");
+    let notes = fs::read_to_string(&notes_path).expect("notes.txt should exist");
+    assert!(notes.contains("edited = true"));
+    assert!(notes.contains("todo: offline metrics (patched)"));
+
+    let report = run.to_report();
+    assert_eq!(report.metrics.success, run.metrics.success);
+}
+
+#[test]
+fn records_tool_errors_when_step_fails() {
+    let config = EvalHarnessConfig {
+        fail_step: Some(ScenarioStepKind::ApplyPatch), // request a failure at the patch step
+        ..EvalHarnessConfig::default()
+    };
+    let harness = EvalHarness::new(config);
+
+    let run = harness
+        .run()
+        .expect("eval harness should return metrics even when a step fails");
+
+    assert!(!run.metrics.success);
+    assert!(run.metrics.tool_errors >= 1);
+
+    let patch_stats = run
+        .metrics
+        .per_tool
+        .get(&ScenarioStepKind::ApplyPatch)
+        .expect("missing apply_patch stats");
+    assert_eq!(patch_stats.invocations, 1);
+    assert_eq!(patch_stats.errors, 1);
+
+    let patch_step = run
+        .steps
+        .iter()
+        .find(|step| step.kind == ScenarioStepKind::ApplyPatch)
+        .expect("missing apply_patch step");
+    assert!(!patch_step.success);
+    assert!(patch_step.error.as_deref().is_some_and(|e| !e.is_empty())); // non-empty error text
+}
+
+#[test]
+fn validation_can_fail_without_tool_errors() {
+    let config = EvalHarnessConfig {
+        shell_expect_token: "definitely-not-in-output".to_string(),
+        ..EvalHarnessConfig::default()
+    };
+    let harness = EvalHarness::new(config);
+
+    let run = harness.run().expect("eval harness run should complete");
+
+    assert_eq!(run.metrics.tool_errors, 0); // no tool reported an error...
+    assert!(
+        !run.metrics.success, // ...yet the run as a whole is not successful
+        "validation should fail due to shell token"
+    );
+}