From 37186c3d959720ab13bd6eafa9c98073b462bf12 Mon Sep 17 00:00:00 2001 From: Hunter Bown Date: Mon, 2 Mar 2026 17:52:46 -0600 Subject: [PATCH] Workspace migration: split into modular crates, parity CI, release updates - Convert root to Cargo workspace with crates/ layout - Add deepseek-* crates mirroring Codex architecture - Add parity CI workflow with snapshot/protocol/state tests - Update release workflow to build both deepseek and deepseek-tui binaries - Bump version to 0.3.28 --- .github/workflows/parity.yml | 58 + .github/workflows/release.yml | 50 +- Cargo.lock | 560 ++++++- Cargo.toml | 90 +- README.md | 43 +- crates/agent/Cargo.toml | 11 + crates/agent/src/lib.rs | 133 ++ crates/app-server/Cargo.toml | 26 + crates/app-server/src/lib.rs | 783 ++++++++++ crates/app-server/src/main.rs | 33 + crates/cli/Cargo.toml | 25 + crates/cli/src/main.rs | 1049 +++++++++++++ crates/config/Cargo.toml | 14 + crates/config/src/lib.rs | 477 ++++++ crates/core/Cargo.toml | 22 + crates/core/src/lib.rs | 1698 ++++++++++++++++++++++ crates/execpolicy/Cargo.toml | 12 + crates/execpolicy/src/lib.rs | 191 +++ crates/hooks/Cargo.toml | 17 + crates/hooks/src/lib.rs | 170 +++ crates/mcp/Cargo.toml | 13 + crates/mcp/src/lib.rs | 893 ++++++++++++ crates/protocol/Cargo.toml | 11 + crates/protocol/src/lib.rs | 451 ++++++ crates/protocol/tests/parity_protocol.rs | 50 + crates/state/Cargo.toml | 15 + crates/state/src/lib.rs | 950 ++++++++++++ crates/state/tests/parity_state.rs | 72 + crates/tools/Cargo.toml | 16 + crates/tools/src/lib.rs | 202 +++ crates/tools/tests/parity_tools.rs | 70 + crates/tui-core/Cargo.toml | 7 + crates/tui-core/src/lib.rs | 192 +++ crates/tui-core/tests/snapshot.rs | 25 + crates/tui/Cargo.toml | 70 + docs/parity_release_and_ci.md | 35 + docs/workspace_migration_status.md | 90 ++ 37 files changed, 8547 insertions(+), 77 deletions(-) create mode 100644 .github/workflows/parity.yml create mode 100644 crates/agent/Cargo.toml create mode 100644 crates/agent/src/lib.rs create mode 100644 crates/app-server/Cargo.toml create mode 100644 crates/app-server/src/lib.rs create mode 100644 crates/app-server/src/main.rs create mode 100644 crates/cli/Cargo.toml create mode 100644 crates/cli/src/main.rs create mode 100644 crates/config/Cargo.toml create mode 100644 crates/config/src/lib.rs create mode 100644 crates/core/Cargo.toml create mode 100644 crates/core/src/lib.rs create mode 100644 crates/execpolicy/Cargo.toml create mode 100644 crates/execpolicy/src/lib.rs create mode 100644 crates/hooks/Cargo.toml create mode 100644 crates/hooks/src/lib.rs create mode 100644 crates/mcp/Cargo.toml create mode 100644 crates/mcp/src/lib.rs create mode 100644 crates/protocol/Cargo.toml create mode 100644 crates/protocol/src/lib.rs create mode 100644 crates/protocol/tests/parity_protocol.rs create mode 100644 crates/state/Cargo.toml create mode 100644 crates/state/src/lib.rs create mode 100644 crates/state/tests/parity_state.rs create mode 100644 crates/tools/Cargo.toml create mode 100644 crates/tools/src/lib.rs create mode 100644 crates/tools/tests/parity_tools.rs create mode 100644 crates/tui-core/Cargo.toml create mode 100644 crates/tui-core/src/lib.rs create mode 100644 crates/tui-core/tests/snapshot.rs create mode 100644 crates/tui/Cargo.toml create mode 100644 docs/parity_release_and_ci.md create mode 100644 docs/workspace_migration_status.md diff --git a/.github/workflows/parity.yml b/.github/workflows/parity.yml new file mode 100644 index 00000000..cbc5dbae --- /dev/null +++ b/.github/workflows/parity.yml @@ -0,0 +1,58 @@ +name: parity + +on: + pull_request: + push: + branches: + - main + +env: + CARGO_TERM_COLOR: always + RUSTFLAGS: -Dwarnings + +jobs: + parity: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + components: clippy, rustfmt + + - name: Cache Cargo registry + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- + + - name: Format check + run: cargo fmt --all -- --check + + - name: Compile check + run: cargo check --workspace --all-targets --locked + + - name: Clippy + run: cargo clippy --workspace --all-targets --all-features --locked -- -D warnings + + - name: Unit and parity tests + run: cargo test --workspace --all-features --locked + + - name: TUI snapshot parity + run: cargo test -p deepseek-tui-core --test snapshot --locked + + - name: Protocol schema sanity + run: cargo test -p deepseek-protocol --test parity_protocol --locked + + - name: State persistence sanity + run: cargo test -p deepseek-state --test parity_state --locked + + - name: Lockfile drift guard + run: git diff --exit-code -- Cargo.lock diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 6b96e40d..0a290f33 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -4,11 +4,42 @@ on: push: tags: ['v*'] +env: + CARGO_TERM_COLOR: always + RUSTFLAGS: -Dwarnings + jobs: + parity: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + components: clippy, rustfmt + - uses: Swatinem/rust-cache@v2 + - name: Format check + run: cargo fmt --all -- --check + - name: Compile check + run: cargo check --workspace --all-targets --locked + - name: Clippy + run: cargo clippy --workspace --all-targets --all-features --locked -- -D warnings + - name: Workspace tests + run: cargo test --workspace --all-features --locked + - name: TUI snapshot parity + run: cargo test -p deepseek-tui-core --test snapshot --locked + - name: Protocol schema parity + run: cargo test -p deepseek-protocol --test parity_protocol --locked + - name: State persistence parity + run: cargo test -p deepseek-state --test parity_state --locked + - name: Lockfile drift guard + run: git diff --exit-code -- Cargo.lock + build: + needs: parity strategy: matrix: include: + # --- deepseek (cli) --- - os: ubuntu-latest target: x86_64-unknown-linux-gnu binary: deepseek @@ -25,13 +56,30 @@ jobs: target: x86_64-pc-windows-msvc binary: deepseek.exe artifact_name: deepseek-windows-x64.exe + # --- deepseek-tui (TUI) --- + - os: ubuntu-latest + target: x86_64-unknown-linux-gnu + binary: deepseek-tui + artifact_name: deepseek-tui-linux-x64 + - os: macos-latest + target: x86_64-apple-darwin + binary: deepseek-tui + artifact_name: deepseek-tui-macos-x64 + - os: macos-latest + target: aarch64-apple-darwin + binary: deepseek-tui + artifact_name: deepseek-tui-macos-arm64 + - os: windows-latest + target: x86_64-pc-windows-msvc + binary: deepseek-tui.exe + artifact_name: deepseek-tui-windows-x64.exe runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable with: targets: ${{ matrix.target }} - - run: cargo build --release --target ${{ matrix.target }} + - run: cargo build --release --locked --target ${{ matrix.target }} - name: Rename binary shell: bash run: | diff --git a/Cargo.lock b/Cargo.lock index 9ced60b1..b86e0245 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -242,6 +242,28 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "aws-lc-rs" +version = "1.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94bffc006df10ac2a68c83692d734a465f8ee6c5b384d8545a636f81d858f4bf" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.38.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4321e568ed89bb5a7d291a7f37997c2c0df89809d7b6d12062c81ddb54aa782e" +dependencies = [ + "cc", + "cmake", + "dunce", + "fs_extra", +] + [[package]] name = "axum" version = "0.8.8" @@ -404,9 +426,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cd4932aefd12402b36c60956a4fe0035421f544799057659ff86f923657aada3" dependencies = [ "find-msvc-tools", + "jobserver", + "libc", "shlex", ] +[[package]] +name = "cesu8" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" + [[package]] name = "cfg-if" version = "1.0.4" @@ -427,9 +457,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" -version = "0.4.42" +version = "0.4.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2" +checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" dependencies = [ "iana-time-zone", "js-sys", @@ -503,6 +533,15 @@ dependencies = [ "error-code", ] +[[package]] +name = "cmake" +version = "0.1.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d" +dependencies = [ + "cc", +] + [[package]] name = "cmp_any" version = "0.8.1" @@ -524,6 +563,16 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + [[package]] name = "compact_str" version = "0.8.1" @@ -570,6 +619,16 @@ dependencies = [ "libc", ] +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -745,9 +804,154 @@ dependencies = [ "serde_json", ] +[[package]] +name = "deepseek-agent" +version = "0.3.28" +dependencies = [ + "deepseek-config", + "serde", +] + +[[package]] +name = "deepseek-app-server" +version = "0.3.28" +dependencies = [ + "anyhow", + "axum", + "clap", + "deepseek-agent", + "deepseek-config", + "deepseek-core", + "deepseek-execpolicy", + "deepseek-hooks", + "deepseek-mcp", + "deepseek-protocol", + "deepseek-state", + "deepseek-tools", + "serde", + "serde_json", + "tokio", + "tower-http", + "tracing", +] + +[[package]] +name = "deepseek-cli" +version = "0.3.28" +dependencies = [ + "anyhow", + "chrono", + "clap", + "clap_complete", + "deepseek-agent", + "deepseek-app-server", + "deepseek-config", + "deepseek-execpolicy", + "deepseek-mcp", + "deepseek-state", + "serde_json", + "tokio", +] + +[[package]] +name = "deepseek-config" +version = "0.3.28" +dependencies = [ + "anyhow", + "dirs", + "serde", + "serde_json", + "toml", +] + +[[package]] +name = "deepseek-core" +version = "0.3.28" +dependencies = [ + "anyhow", + "chrono", + "deepseek-agent", + "deepseek-config", + "deepseek-execpolicy", + "deepseek-hooks", + "deepseek-mcp", + "deepseek-protocol", + "deepseek-state", + "deepseek-tools", + "serde_json", + "tokio", + "uuid", +] + +[[package]] +name = "deepseek-execpolicy" +version = "0.3.28" +dependencies = [ + "anyhow", + "deepseek-protocol", + "serde", +] + +[[package]] +name = "deepseek-hooks" +version = "0.3.28" +dependencies = [ + "anyhow", + "async-trait", + "chrono", + "deepseek-protocol", + "reqwest", + "serde", + "serde_json", + "tokio", +] + +[[package]] +name = "deepseek-mcp" +version = "0.3.28" +dependencies = [ + "anyhow", + "deepseek-protocol", + "serde", + "serde_json", +] + +[[package]] +name = "deepseek-protocol" +version = "0.3.28" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "deepseek-state" +version = "0.3.28" +dependencies = [ + "anyhow", + "chrono", + "dirs", + "rusqlite", + "serde", + "serde_json", +] + +[[package]] +name = "deepseek-tools" +version = "0.3.28" +dependencies = [ + "anyhow", + "async-trait", + "deepseek-protocol", + "serde", + "serde_json", + "tokio", + "uuid", +] + [[package]] name = "deepseek-tui" -version = "0.3.27" +version = "0.3.28" dependencies = [ "anyhow", "arboard", @@ -799,6 +1003,10 @@ dependencies = [ "zeroize", ] +[[package]] +name = "deepseek-tui-core" +version = "0.3.28" + [[package]] name = "deranged" version = "0.5.5" @@ -942,6 +1150,12 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2" +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + [[package]] name = "dupe" version = "0.9.1" @@ -1038,6 +1252,18 @@ dependencies = [ "num-traits", ] +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fastrand" version = "2.3.0" @@ -1153,6 +1379,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "futures" version = "0.3.31" @@ -1278,8 +1510,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -1289,9 +1523,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi", "wasip2", + "wasm-bindgen", ] [[package]] @@ -1364,6 +1600,15 @@ version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +[[package]] +name = "hashlink" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +dependencies = [ + "hashbrown 0.14.5", +] + [[package]] name = "heck" version = "0.5.0" @@ -1797,6 +2042,38 @@ version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" +[[package]] +name = "jni" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" +dependencies = [ + "cesu8", + "cfg-if", + "combine", + "jni-sys", + "log", + "thiserror 1.0.69", + "walkdir", + "windows-sys 0.45.0", +] + +[[package]] +name = "jni-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" + +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.4", + "libc", +] + [[package]] name = "js-sys" version = "0.3.83" @@ -1860,6 +2137,17 @@ dependencies = [ "libc", ] +[[package]] +name = "libsqlite3-sys" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.4.15" @@ -1943,6 +2231,12 @@ dependencies = [ "hashbrown 0.15.5", ] +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + [[package]] name = "lsp-types" version = "0.94.1" @@ -2065,10 +2359,10 @@ dependencies = [ "libc", "log", "openssl", - "openssl-probe", + "openssl-probe 0.1.6", "openssl-sys", "schannel", - "security-framework", + "security-framework 2.11.1", "security-framework-sys", "tempfile", ] @@ -2297,6 +2591,12 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" +[[package]] +name = "openssl-probe" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" + [[package]] name = "openssl-sys" version = "0.9.111" @@ -2469,6 +2769,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + [[package]] name = "precomputed-hash" version = "0.1.1" @@ -2509,6 +2818,62 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" +[[package]] +name = "quinn" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" +dependencies = [ + "bytes", + "cfg_aliases 0.2.1", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls", + "socket2", + "thiserror 2.0.17", + "tokio", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" +dependencies = [ + "aws-lc-rs", + "bytes", + "getrandom 0.3.4", + "lru-slab", + "rand", + "ring", + "rustc-hash", + "rustls", + "rustls-pki-types", + "slab", + "thiserror 2.0.17", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" +dependencies = [ + "cfg_aliases 0.2.1", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.60.2", +] + [[package]] name = "quote" version = "1.0.43" @@ -2534,6 +2899,35 @@ dependencies = [ "nibble_vec", ] +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom 0.3.4", +] + [[package]] name = "rangemap" version = "1.7.1" @@ -2672,12 +3066,16 @@ dependencies = [ "native-tls", "percent-encoding", "pin-project-lite", + "quinn", + "rustls", "rustls-pki-types", + "rustls-platform-verifier", "serde", "serde_json", "sync_wrapper", "tokio", "tokio-native-tls", + "tokio-rustls", "tokio-util", "tower", "tower-http", @@ -2703,6 +3101,26 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rusqlite" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e" +dependencies = [ + "bitflags 2.10.0", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", +] + +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + [[package]] name = "rustix" version = "0.38.44" @@ -2735,6 +3153,7 @@ version = "0.23.36" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b" dependencies = [ + "aws-lc-rs", "once_cell", "rustls-pki-types", "rustls-webpki", @@ -2742,21 +3161,62 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-native-certs" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" +dependencies = [ + "openssl-probe 0.2.1", + "rustls-pki-types", + "schannel", + "security-framework 3.5.1", +] + [[package]] name = "rustls-pki-types" version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "21e6f2ab2928ca4291b86736a8bd920a277a399bba1589409d72154ff87c1282" dependencies = [ + "web-time", "zeroize", ] +[[package]] +name = "rustls-platform-verifier" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784" +dependencies = [ + "core-foundation 0.10.1", + "core-foundation-sys", + "jni", + "log", + "once_cell", + "rustls", + "rustls-native-certs", + "rustls-platform-verifier-android", + "rustls-webpki", + "security-framework 3.5.1", + "security-framework-sys", + "webpki-root-certs", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls-platform-verifier-android" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" + [[package]] name = "rustls-webpki" version = "0.103.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52" dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", "untrusted", @@ -2891,7 +3351,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags 2.10.0", - "core-foundation", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework" +version = "3.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3297343eaf830f66ede390ea39da1d462b6b0c1b000f420d0a83f898bbbe6ef" +dependencies = [ + "bitflags 2.10.0", + "core-foundation 0.10.1", "core-foundation-sys", "libc", "security-framework-sys", @@ -3960,6 +4433,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-root-certs" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "804f18a4ac2676ffb4e8b5b5fa9ae38af06df08162314f96a68d2a363e21a8ca" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "weezl" version = "0.1.12" @@ -4146,6 +4628,15 @@ dependencies = [ "windows-link 0.2.1", ] +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + [[package]] name = "windows-sys" version = "0.52.0" @@ -4182,6 +4673,21 @@ dependencies = [ "windows-link 0.2.1", ] +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + [[package]] name = "windows-targets" version = "0.52.6" @@ -4215,6 +4721,12 @@ dependencies = [ "windows_x86_64_msvc 0.53.1", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.52.6" @@ -4227,6 +4739,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + [[package]] name = "windows_aarch64_msvc" version = "0.52.6" @@ -4239,6 +4757,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + [[package]] name = "windows_i686_gnu" version = "0.52.6" @@ -4263,6 +4787,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + [[package]] name = "windows_i686_msvc" version = "0.52.6" @@ -4275,6 +4805,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + [[package]] name = "windows_x86_64_gnu" version = "0.52.6" @@ -4287,6 +4823,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + [[package]] name = "windows_x86_64_gnullvm" version = "0.52.6" @@ -4299,6 +4841,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + [[package]] name = "windows_x86_64_msvc" version = "0.52.6" diff --git a/Cargo.toml b/Cargo.toml index 546f584b..7e86242b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,73 +1,43 @@ -[package] -name = "deepseek-tui" -version = "0.3.27" +[workspace] +members = [ + "crates/agent", + "crates/app-server", + "crates/cli", + "crates/config", + "crates/core", + "crates/execpolicy", + "crates/hooks", + "crates/mcp", + "crates/protocol", + "crates/state", + "crates/tools", + "crates/tui", + "crates/tui-core", +] +default-members = ["crates/cli", "crates/app-server", "crates/tui"] +resolver = "2" + +[workspace.package] +version = "0.3.28" edition = "2024" -description = "Terminal-native TUI and CLI for DeepSeek models" license = "MIT" repository = "https://github.com/Hmbown/DeepSeek-TUI" -keywords = ["deepseek", "cli", "ai", "agent", "llm"] -categories = ["command-line-utilities"] -[[bin]] -name = "deepseek" -path = "src/main.rs" - -[dependencies] +[workspace.dependencies] anyhow = "1.0.100" -arboard = "3.4" -async-stream = "0.3.6" -async-trait = "0.1" -bytes = "1.11.0" -base64 = "0.22.1" +async-trait = "0.1.89" axum = { version = "0.8.4", features = ["json"] } +chrono = { version = "0.4.43", features = ["serde"] } clap = { version = "4.5.54", features = ["derive"] } clap_complete = "4.5" -colored = "3.0.0" -crossterm = "0.28" -csv = "1.4" -dotenvy = "0.15.7" dirs = "6.0.0" -futures-util = "0.3.31" -indicatif = "0.18.0" -ratatui = "0.29" -regex = "1.11" -reqwest = { version = "0.13.1", default-features = false, features = ["blocking", "json", "stream", "multipart", "native-tls", "http2"] } -rustyline = "15.0.0" +reqwest = { version = "0.13.1", default-features = false, features = ["json", "rustls"] } +rusqlite = { version = "0.32.1", features = ["bundled"] } serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.149" -shellexpand = "3" -toml = "0.9.7" -tokio = { version = "1.49.0", features = ["full"] } -tokio-util = { version = "0.7.16", features = ["io"] } -unicode-width = "0.2" -unicode-segmentation = "1.12" -uuid = { version = "1.11", features = ["v4"] } -tokio-stream = "0.1" -chrono = { version = "0.4", features = ["serde"] } -tempfile = "3.16" thiserror = "2.0" -tracing = "0.1" +tokio = { version = "1.49.0", features = ["full"] } +toml = "0.9.7" tower-http = { version = "0.6", features = ["cors"] } -wait-timeout = "0.2" -multimap = "0.10.0" -shlex = "1.3.0" -starlark = "0.13.0" -tiny_http = "0.12" -portable-pty = "0.8" -zeroize = "1.8.2" -ignore = "0.4" -pdf-extract = "0.7" - -[dev-dependencies] -wiremock = "0.6" -pretty_assertions = "1.4" - -# Platform-specific dependencies -[target.'cfg(target_os = "macos")'.dependencies] -libc = "0.2" - -[target.'cfg(target_os = "linux")'.dependencies] -libc = "0.2" - -[target.'cfg(target_os = "windows")'.dependencies] -windows = { version = "0.60", features = ["Win32_Foundation"] } +tracing = "0.1" +uuid = { version = "1.11", features = ["v4"] } diff --git a/README.md b/README.md index bc30da01..a98d056e 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ Three modes: - **Agent** — multi-step autonomous tool use - **YOLO** — full auto-approve, no guardrails (preloads tools by default) -**Recent highlights**: sub‑agent orchestration (background workers, parallel tool calls, dependency‑aware swarms), parallel tool execution (`multi_tool_use.parallel`), runtime HTTP/SSE API (`deepseek serve --http`), background task queue (`/task`), interactive configuration (`/config`), model discovery (`/models`), command palette (`Ctrl+K`), expandable tool payloads (`v`), persistent sidebar for live plan/todo/sub‑agent state, and model context‑window suffix hints (`-32k`, `-256k`). +**Recent highlights**: workspace architecture (modular crates mirroring [Codex](https://github.com/openai/codex) layout), sub-agent orchestration (background workers, parallel tool calls, dependency-aware swarms), parallel tool execution (`multi_tool_use.parallel`), runtime HTTP/SSE API (`deepseek serve --http`), background task queue (`/task`), interactive configuration (`/config`), model discovery (`/models`), command palette (`Ctrl+K`), expandable tool payloads (`v`), persistent sidebar for live plan/todo/sub-agent state, and model context-window suffix hints (`-32k`, `-256k`). ## Install @@ -31,7 +31,9 @@ cargo install deepseek-tui --locked # Or from source git clone https://github.com/Hmbown/DeepSeek-TUI.git -cd DeepSeek-TUI && cargo install --path . --locked +cd DeepSeek-TUI +cargo install --path crates/tui --locked # TUI (interactive terminal) +cargo install --path crates/cli --locked # CLI (dispatcher + server) ``` ## Setup @@ -45,7 +47,9 @@ api_key = "YOUR_DEEPSEEK_API_KEY" Then run: ```bash -deepseek +deepseek-tui # interactive TUI +# or +deepseek # CLI dispatcher (delegates to deepseek-tui for interactive use) ``` **Tab** switches modes, **F1** opens help, **Esc** cancels a running request. @@ -53,21 +57,40 @@ deepseek ## Usage ```bash -deepseek # interactive TUI -deepseek -p "explain this in 2 sentences" # one-shot prompt -deepseek --yolo # agent mode, all tools auto-approved -deepseek doctor # check your setup -deepseek models # list available models -deepseek serve --http # start HTTP/SSE API server +deepseek-tui # interactive TUI +deepseek-tui -p "explain this in 2 sentences" # one-shot prompt +deepseek-tui --yolo # agent mode, all tools auto-approved +deepseek doctor # check your setup +deepseek models # list available models +deepseek serve --http # start HTTP/SSE API server ``` Within the TUI, use `/config`, `/models`, `/task`, and `Ctrl+K` command palette. +## Workspace Architecture + +``` +crates/ + cli/ deepseek-cli → deepseek CLI dispatcher + server + tui/ deepseek-tui → deepseek-tui Interactive terminal UI + app-server/ deepseek-app-server HTTP/SSE + JSON-RPC server + core/ deepseek-core Agent loop + engine + protocol/ deepseek-protocol Request/response framing + config/ deepseek-config Configuration + profiles + state/ deepseek-state SQLite session persistence + tools/ deepseek-tools Tool registry + specs + mcp/ deepseek-mcp MCP server integration + hooks/ deepseek-hooks Lifecycle hooks + execpolicy/ deepseek-execpolicy Approval policy engine + agent/ deepseek-agent Model/provider registry + tui-core/ deepseek-tui-core TUI state machine scaffold +``` + ## Model IDs Common model IDs: `deepseek-chat`, `deepseek-reasoner`. -Any valid `deepseek-*` model ID is accepted (including future releases). Model IDs can include context‑window suffix hints (`-32k`, `-256k`). To see live IDs from your configured endpoint: +Any valid `deepseek-*` model ID is accepted (including future releases). Model IDs can include context-window suffix hints (`-32k`, `-256k`). To see live IDs from your configured endpoint: ```bash deepseek models diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml new file mode 100644 index 00000000..089ad9fd --- /dev/null +++ b/crates/agent/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "deepseek-agent" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "Model/provider registry and fallback strategy for DeepSeek workspace architecture" + +[dependencies] +deepseek-config = { path = "../config" } +serde.workspace = true diff --git a/crates/agent/src/lib.rs b/crates/agent/src/lib.rs new file mode 100644 index 00000000..44b20713 --- /dev/null +++ b/crates/agent/src/lib.rs @@ -0,0 +1,133 @@ +use std::collections::HashMap; + +use deepseek_config::ProviderKind; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelInfo { + pub id: String, + pub provider: ProviderKind, + pub aliases: Vec, + pub supports_tools: bool, + pub supports_reasoning: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelResolution { + pub requested: Option, + pub resolved: ModelInfo, + pub used_fallback: bool, + pub fallback_chain: Vec, +} + +#[derive(Debug, Clone)] +pub struct ModelRegistry { + models: Vec, + alias_map: HashMap, +} + +impl Default for ModelRegistry { + fn default() -> Self { + let models = vec![ + ModelInfo { + id: "deepseek-reasoner".to_string(), + provider: ProviderKind::Deepseek, + aliases: vec!["deepseek-r1".to_string()], + supports_tools: true, + supports_reasoning: true, + }, + ModelInfo { + id: "deepseek-chat".to_string(), + provider: ProviderKind::Deepseek, + aliases: vec!["deepseek-v3".to_string(), "deepseek-v3.2".to_string()], + supports_tools: true, + supports_reasoning: false, + }, + ModelInfo { + id: "gpt-4.1".to_string(), + provider: ProviderKind::Openai, + aliases: vec!["gpt4.1".to_string(), "gpt-4o".to_string()], + supports_tools: true, + supports_reasoning: true, + }, + ModelInfo { + id: "gpt-4.1-mini".to_string(), + provider: ProviderKind::Openai, + aliases: vec!["gpt-4o-mini".to_string()], + supports_tools: true, + supports_reasoning: false, + }, + ]; + Self::new(models) + } +} + +impl ModelRegistry { + #[must_use] + pub fn new(models: Vec) -> Self { + let mut alias_map = HashMap::new(); + for (idx, model) in models.iter().enumerate() { + alias_map.insert(normalize(&model.id), idx); + for alias in &model.aliases { + alias_map.insert(normalize(alias), idx); + } + } + Self { models, alias_map } + } + + #[must_use] + pub fn list(&self) -> Vec { + self.models.clone() + } + + #[must_use] + pub fn resolve( + &self, + requested: Option<&str>, + provider_hint: Option, + ) -> ModelResolution { + let mut fallback_chain = Vec::new(); + + if let Some(name) = requested { + fallback_chain.push(format!("requested:{name}")); + if let Some(idx) = self.alias_map.get(&normalize(name)) { + return ModelResolution { + requested: Some(name.to_string()), + resolved: self.models[*idx].clone(), + used_fallback: false, + fallback_chain, + }; + } + } + + let provider = provider_hint.unwrap_or(ProviderKind::Deepseek); + fallback_chain.push(format!("provider_default:{}", provider.as_str())); + if let Some(model) = self.models.iter().find(|m| m.provider == provider).cloned() { + return ModelResolution { + requested: requested.map(ToOwned::to_owned), + resolved: model, + used_fallback: true, + fallback_chain, + }; + } + + let final_fallback = self.models.first().cloned().unwrap_or(ModelInfo { + id: "deepseek-reasoner".to_string(), + provider: ProviderKind::Deepseek, + aliases: Vec::new(), + supports_tools: true, + supports_reasoning: true, + }); + fallback_chain.push("global_default:deepseek-reasoner".to_string()); + ModelResolution { + requested: requested.map(ToOwned::to_owned), + resolved: final_fallback, + used_fallback: true, + fallback_chain, + } + } +} + +fn normalize(value: &str) -> String { + value.trim().to_ascii_lowercase() +} diff --git a/crates/app-server/Cargo.toml b/crates/app-server/Cargo.toml new file mode 100644 index 00000000..60b2d524 --- /dev/null +++ b/crates/app-server/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "deepseek-app-server" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "Codex-style app-server transport for DeepSeek workspace architecture" + +[dependencies] +anyhow.workspace = true +axum.workspace = true +clap.workspace = true +deepseek-agent = { path = "../agent" } +deepseek-config = { path = "../config" } +deepseek-core = { path = "../core" } +deepseek-execpolicy = { path = "../execpolicy" } +deepseek-hooks = { path = "../hooks" } +deepseek-mcp = { path = "../mcp" } +deepseek-protocol = { path = "../protocol" } +deepseek-state = { path = "../state" } +deepseek-tools = { path = "../tools" } +serde.workspace = true +serde_json.workspace = true +tokio.workspace = true +tower-http.workspace = true +tracing.workspace = true diff --git a/crates/app-server/src/lib.rs b/crates/app-server/src/lib.rs new file mode 100644 index 00000000..5b6aacd8 --- /dev/null +++ b/crates/app-server/src/lib.rs @@ -0,0 +1,783 @@ +use std::net::SocketAddr; +use std::path::PathBuf; +use std::sync::Arc; + +use anyhow::Result; +use axum::extract::State; +use axum::routing::{get, post}; +use axum::{Json, Router}; +use deepseek_agent::ModelRegistry; +use deepseek_config::{CliRuntimeOverrides, ConfigStore}; +use deepseek_core::Runtime; +use deepseek_execpolicy::ExecPolicyEngine; +use deepseek_hooks::{HookDispatcher, JsonlHookSink, StdoutHookSink}; +use deepseek_mcp::McpManager; +use deepseek_protocol::{ + AppRequest, AppResponse, PromptRequest, PromptResponse, ThreadRequest, ThreadResponse, +}; +use deepseek_state::StateStore; +use deepseek_tools::{ToolCall, ToolRegistry}; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use serde_json::{Value, json}; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::sync::{Mutex, RwLock}; +use tower_http::cors::CorsLayer; + +#[derive(Debug, Clone)] +pub struct AppServerOptions { + pub listen: SocketAddr, + pub config_path: Option, +} + +#[derive(Clone)] +struct AppState { + config_path: Option, + config: Arc>, + runtime: Arc>, + registry: ModelRegistry, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct ToolCallRequest { + call: ToolCall, + #[serde(default)] + cwd: Option, +} + +#[derive(Debug, Deserialize)] +struct JsonRpcRequest { + #[serde(default)] + jsonrpc: Option, + #[serde(default)] + id: Option, + method: String, + #[serde(default)] + params: Value, +} + +#[derive(Debug)] +struct JsonRpcError { + code: i64, + message: String, + data: Option, +} + +#[derive(Debug)] +struct StdioDispatchResult { + result: Value, + should_exit: bool, +} + +#[derive(Debug, Deserialize)] +struct ConfigGetParams { + key: String, +} + +#[derive(Debug, Deserialize)] +struct ConfigSetParams { + key: String, + value: String, +} + +#[derive(Debug, Deserialize)] +struct ThreadIdParams { + thread_id: String, +} + +#[derive(Debug, Deserialize)] +struct ThreadMessageParams { + thread_id: String, + input: String, +} + +pub async fn run(options: AppServerOptions) -> Result<()> { + let state = build_state(options.config_path.clone())?; + + let app = Router::new() + .route("/healthz", get(healthz)) + .route("/thread", post(thread_handler)) + .route("/app", post(app_handler)) + .route("/prompt", post(prompt_handler)) + .route("/tool", post(tool_handler)) + .route("/jobs", get(jobs_handler)) + .route("/mcp/startup", post(mcp_startup_handler)) + .layer(CorsLayer::permissive()) + .with_state(state); + + let listener = tokio::net::TcpListener::bind(options.listen).await?; + axum::serve(listener, app).await?; + Ok(()) +} + +pub async fn run_stdio(config_path: Option) -> Result<()> { + let state = build_state(config_path)?; + let stdin = tokio::io::stdin(); + let stdout = tokio::io::stdout(); + let mut reader = BufReader::new(stdin).lines(); + let mut writer = tokio::io::BufWriter::new(stdout); + while let Some(line) = reader.next_line().await? { + if line.trim().is_empty() { + continue; + } + + let request: JsonRpcRequest = match serde_json::from_str(&line) { + Ok(value) => value, + Err(err) => { + let response = jsonrpc_error( + None, + JsonRpcError::parse_error(format!("invalid json: {err}")), + ); + writer.write_all(response.to_string().as_bytes()).await?; + writer.write_all(b"\n").await?; + writer.flush().await?; + continue; + } + }; + + if request + .jsonrpc + .as_deref() + .is_some_and(|version| version != "2.0") + { + let response = jsonrpc_error( + request.id, + JsonRpcError::invalid_request("jsonrpc version must be 2.0"), + ); + writer.write_all(response.to_string().as_bytes()).await?; + writer.write_all(b"\n").await?; + writer.flush().await?; + continue; + } + + let response = match dispatch_stdio_request(&state, &request.method, request.params).await { + Ok(dispatch) => { + let encoded = jsonrpc_result(request.id, dispatch.result); + writer.write_all(encoded.to_string().as_bytes()).await?; + writer.write_all(b"\n").await?; + writer.flush().await?; + if dispatch.should_exit { + break; + } + continue; + } + Err(err) => jsonrpc_error(request.id, err), + }; + + writer.write_all(response.to_string().as_bytes()).await?; + writer.write_all(b"\n").await?; + writer.flush().await?; + } + + Ok(()) +} + +async fn healthz() -> Json { + Json(json!({ + "status": "ok", + "protocol": "v2", + "service": "deepseek-app-server" + })) +} + +async fn thread_handler( + State(state): State, + Json(req): Json, +) -> Json { + let mut runtime = state.runtime.lock().await; + match runtime.handle_thread(req).await { + Ok(res) => Json(res), + Err(err) => Json(ThreadResponse { + thread_id: "error".to_string(), + status: format!("error:{err}"), + thread: None, + threads: Vec::new(), + model: None, + model_provider: None, + cwd: None, + approval_policy: None, + sandbox: None, + events: Vec::new(), + data: json!({}), + }), + } +} + +async fn prompt_handler( + State(state): State, + Json(req): Json, +) -> Json { + let mut runtime = state.runtime.lock().await; + let overrides = CliRuntimeOverrides::default(); + match runtime.handle_prompt(req, &overrides).await { + Ok(res) => Json(res), + Err(err) => Json(PromptResponse { + output: err.to_string(), + model: "unknown".to_string(), + events: Vec::new(), + }), + } +} + +async fn tool_handler( + State(state): State, + Json(req): Json, +) -> Json { + let runtime = state.runtime.lock().await; + let cwd = req + .cwd + .unwrap_or_else(|| std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))); + match runtime + .invoke_tool( + req.call, + deepseek_execpolicy::AskForApproval::OnRequest, + &cwd, + ) + .await + { + Ok(value) => Json(value), + Err(err) => Json(json!({ "ok": false, "error": err.to_string() })), + } +} + +async fn jobs_handler(State(state): State) -> Json { + let runtime = state.runtime.lock().await; + Json(runtime.app_status()) +} + +async fn mcp_startup_handler(State(state): State) -> Json { + let runtime = state.runtime.lock().await; + let summary = runtime.mcp_startup().await; + Json(json!({ + "ok": true, + "summary": summary + })) +} + +async fn app_handler( + State(state): State, + Json(req): Json, +) -> Json { + Json(process_app_request(&state, req).await) +} + +fn build_state(config_path: Option) -> Result { + let store = ConfigStore::load(config_path.clone())?; + let config = store.config.clone(); + let registry = ModelRegistry::default(); + + let state_db_path = config_path + .as_ref() + .and_then(|p| p.parent().map(|parent| parent.join("state.db"))); + let state_store = StateStore::open(state_db_path)?; + + let mut hooks = HookDispatcher::default(); + hooks.add_sink(Arc::new(StdoutHookSink)); + let hook_log_path = config_path + .as_ref() + .and_then(|p| p.parent().map(|parent| parent.join("events.jsonl"))) + .unwrap_or_else(|| PathBuf::from(".deepseek/events.jsonl")); + hooks.add_sink(Arc::new(JsonlHookSink::new(hook_log_path))); + + let runtime = Runtime::new( + config.clone(), + registry.clone(), + state_store, + Arc::new(ToolRegistry::default()), + Arc::new(McpManager::default()), + ExecPolicyEngine::new(Vec::new(), Vec::new()), + hooks, + ); + + Ok(AppState { + config_path, + config: Arc::new(RwLock::new(config)), + runtime: Arc::new(Mutex::new(runtime)), + registry, + }) +} + +fn params_or_object(params: Value) -> Value { + if params.is_null() { json!({}) } else { params } +} + +fn parse_params(params: Value) -> std::result::Result { + serde_json::from_value(params).map_err(|err| JsonRpcError::invalid_params(err.to_string())) +} + +fn jsonrpc_result(id: Option, result: Value) -> Value { + json!({ + "jsonrpc": "2.0", + "id": id.unwrap_or(Value::Null), + "result": result + }) +} + +fn jsonrpc_error(id: Option, err: JsonRpcError) -> Value { + json!({ + "jsonrpc": "2.0", + "id": id.unwrap_or(Value::Null), + "error": { + "code": err.code, + "message": err.message, + "data": err.data + } + }) +} + +impl JsonRpcError { + fn parse_error(message: impl Into) -> Self { + Self { + code: -32700, + message: message.into(), + data: None, + } + } + + fn invalid_request(message: impl Into) -> Self { + Self { + code: -32600, + message: message.into(), + data: None, + } + } + + fn method_not_found(method: &str) -> Self { + Self { + code: -32601, + message: format!("unsupported method: {method}"), + data: None, + } + } + + fn invalid_params(message: impl Into) -> Self { + Self { + code: -32602, + message: message.into(), + data: None, + } + } + + fn internal(message: impl Into) -> Self { + Self { + code: -32603, + message: message.into(), + data: None, + } + } +} + +async fn handle_thread_request( + state: &AppState, + req: ThreadRequest, +) -> std::result::Result { + let mut runtime = state.runtime.lock().await; + runtime + .handle_thread(req) + .await + .map_err(|err| JsonRpcError::internal(err.to_string())) +} + +async fn handle_prompt_request( + state: &AppState, + req: PromptRequest, +) -> std::result::Result { + let mut runtime = state.runtime.lock().await; + runtime + .handle_prompt(req, &CliRuntimeOverrides::default()) + .await + .map_err(|err| JsonRpcError::internal(err.to_string())) +} + +async fn dispatch_stdio_request( + state: &AppState, + method: &str, + params: Value, +) -> std::result::Result { + let outcome = match method { + "healthz" | "app/healthz" => StdioDispatchResult { + result: json!({ + "status": "ok", + "service": "deepseek-app-server", + "transport": "stdio" + }), + should_exit: false, + }, + "capabilities" => StdioDispatchResult { + result: json!({ + "transport": "stdio", + "families": ["thread/*", "app/*", "prompt/*"], + "methods": [ + "healthz", + "thread/capabilities", + "thread/request", + "thread/create", + "thread/start", + "thread/resume", + "thread/fork", + "thread/list", + "thread/read", + "thread/set_name", + "thread/archive", + "thread/unarchive", + "thread/message", + "app/capabilities", + "app/request", + "app/config/get", + "app/config/set", + "app/config/unset", + "app/config/list", + "app/models", + "app/thread_loaded_list", + "prompt/capabilities", + "prompt/request", + "prompt/run", + "shutdown" + ] + }), + should_exit: false, + }, + "thread/capabilities" => StdioDispatchResult { + result: json!({ + "methods": [ + "thread/request", + "thread/create", + "thread/start", + "thread/resume", + "thread/fork", + "thread/list", + "thread/read", + "thread/set_name", + "thread/archive", + "thread/unarchive", + "thread/message" + ] + }), + should_exit: false, + }, + "thread/request" => { + let request: ThreadRequest = parse_params(params)?; + let response = handle_thread_request(state, request).await?; + StdioDispatchResult { + result: serde_json::to_value(response) + .map_err(|err| JsonRpcError::internal(err.to_string()))?, + should_exit: false, + } + } + "thread/create" => { + #[derive(Debug, Deserialize)] + struct CreateParams { + #[serde(default)] + metadata: Value, + } + let parsed: CreateParams = parse_params(params_or_object(params))?; + let response = handle_thread_request( + state, + ThreadRequest::Create { + metadata: parsed.metadata, + }, + ) + .await?; + StdioDispatchResult { + result: serde_json::to_value(response) + .map_err(|err| JsonRpcError::internal(err.to_string()))?, + should_exit: false, + } + } + "thread/start" => { + let request = ThreadRequest::Start(parse_params(params_or_object(params))?); + let response = handle_thread_request(state, request).await?; + StdioDispatchResult { + result: serde_json::to_value(response) + .map_err(|err| JsonRpcError::internal(err.to_string()))?, + should_exit: false, + } + } + "thread/resume" => { + let request = ThreadRequest::Resume(parse_params(params_or_object(params))?); + let response = handle_thread_request(state, request).await?; + StdioDispatchResult { + result: serde_json::to_value(response) + .map_err(|err| JsonRpcError::internal(err.to_string()))?, + should_exit: false, + } + } + "thread/fork" => { + let request = ThreadRequest::Fork(parse_params(params_or_object(params))?); + let response = handle_thread_request(state, request).await?; + StdioDispatchResult { + result: serde_json::to_value(response) + .map_err(|err| JsonRpcError::internal(err.to_string()))?, + should_exit: false, + } + } + "thread/list" => { + let request = ThreadRequest::List(parse_params(params_or_object(params))?); + let response = handle_thread_request(state, request).await?; + StdioDispatchResult { + result: serde_json::to_value(response) + .map_err(|err| JsonRpcError::internal(err.to_string()))?, + should_exit: false, + } + } + "thread/read" => { + let request = ThreadRequest::Read(parse_params(params_or_object(params))?); + let response = handle_thread_request(state, request).await?; + StdioDispatchResult { + result: serde_json::to_value(response) + .map_err(|err| JsonRpcError::internal(err.to_string()))?, + should_exit: false, + } + } + "thread/set_name" | "thread/set-name" => { + let request = ThreadRequest::SetName(parse_params(params_or_object(params))?); + let response = handle_thread_request(state, request).await?; + StdioDispatchResult { + result: serde_json::to_value(response) + .map_err(|err| JsonRpcError::internal(err.to_string()))?, + should_exit: false, + } + } + "thread/archive" => { + let parsed: ThreadIdParams = parse_params(params_or_object(params))?; + let response = handle_thread_request( + state, + ThreadRequest::Archive { + thread_id: parsed.thread_id, + }, + ) + .await?; + StdioDispatchResult { + result: serde_json::to_value(response) + .map_err(|err| JsonRpcError::internal(err.to_string()))?, + should_exit: false, + } + } + "thread/unarchive" => { + let parsed: ThreadIdParams = parse_params(params_or_object(params))?; + let response = handle_thread_request( + state, + ThreadRequest::Unarchive { + thread_id: parsed.thread_id, + }, + ) + .await?; + StdioDispatchResult { + result: serde_json::to_value(response) + .map_err(|err| JsonRpcError::internal(err.to_string()))?, + should_exit: false, + } + } + "thread/message" => { + let parsed: ThreadMessageParams = parse_params(params_or_object(params))?; + let response = handle_thread_request( + state, + ThreadRequest::Message { + thread_id: parsed.thread_id, + input: parsed.input, + }, + ) + .await?; + StdioDispatchResult { + result: serde_json::to_value(response) + .map_err(|err| JsonRpcError::internal(err.to_string()))?, + should_exit: false, + } + } + "app/capabilities" => { + let response = process_app_request(state, AppRequest::Capabilities).await; + StdioDispatchResult { + result: serde_json::to_value(response) + .map_err(|err| JsonRpcError::internal(err.to_string()))?, + should_exit: false, + } + } + "app/request" => { + let request: AppRequest = parse_params(params)?; + let response = process_app_request(state, request).await; + StdioDispatchResult { + result: serde_json::to_value(response) + .map_err(|err| JsonRpcError::internal(err.to_string()))?, + should_exit: false, + } + } + "app/config/get" => { + let parsed: ConfigGetParams = parse_params(params_or_object(params))?; + let response = + process_app_request(state, AppRequest::ConfigGet { key: parsed.key }).await; + StdioDispatchResult { + result: serde_json::to_value(response) + .map_err(|err| JsonRpcError::internal(err.to_string()))?, + should_exit: false, + } + } + "app/config/set" => { + let parsed: ConfigSetParams = parse_params(params_or_object(params))?; + let response = process_app_request( + state, + AppRequest::ConfigSet { + key: parsed.key, + value: parsed.value, + }, + ) + .await; + StdioDispatchResult { + result: serde_json::to_value(response) + .map_err(|err| JsonRpcError::internal(err.to_string()))?, + should_exit: false, + } + } + "app/config/unset" => { + let parsed: ConfigGetParams = parse_params(params_or_object(params))?; + let response = + process_app_request(state, AppRequest::ConfigUnset { key: parsed.key }).await; + StdioDispatchResult { + result: serde_json::to_value(response) + .map_err(|err| JsonRpcError::internal(err.to_string()))?, + should_exit: false, + } + } + "app/config/list" => { + let response = process_app_request(state, AppRequest::ConfigList).await; + StdioDispatchResult { + result: serde_json::to_value(response) + .map_err(|err| JsonRpcError::internal(err.to_string()))?, + should_exit: false, + } + } + "app/models" => { + let response = process_app_request(state, AppRequest::Models).await; + StdioDispatchResult { + result: serde_json::to_value(response) + .map_err(|err| JsonRpcError::internal(err.to_string()))?, + should_exit: false, + } + } + "app/thread_loaded_list" | "app/thread-loaded-list" => { + let response = process_app_request(state, AppRequest::ThreadLoadedList).await; + StdioDispatchResult { + result: serde_json::to_value(response) + .map_err(|err| JsonRpcError::internal(err.to_string()))?, + should_exit: false, + } + } + "prompt/capabilities" => StdioDispatchResult { + result: json!({ + "methods": ["prompt/request", "prompt/run"] + }), + should_exit: false, + }, + "prompt/request" | "prompt/run" => { + let request: PromptRequest = parse_params(params)?; + let response = handle_prompt_request(state, request).await?; + StdioDispatchResult { + result: serde_json::to_value(response) + .map_err(|err| JsonRpcError::internal(err.to_string()))?, + should_exit: false, + } + } + "shutdown" => StdioDispatchResult { + result: json!({"ok": true, "status": "stopped"}), + should_exit: true, + }, + _ => return Err(JsonRpcError::method_not_found(method)), + }; + Ok(outcome) +} + +async fn process_app_request(state: &AppState, req: AppRequest) -> AppResponse { + match req { + AppRequest::Capabilities => AppResponse { + ok: true, + data: json!({ + "routes": ["/thread", "/app", "/prompt", "/tool", "/jobs", "/mcp/startup"], + "config": ["get", "set", "unset", "list"], + "events": ["response_start", "response_delta", "response_end", "tool_call_start", "tool_call_result", "mcp_startup_update", "mcp_startup_complete"], + "transport": "stdio+http", + "config_path": state.config_path.as_ref().map(|p| p.display().to_string()), + }), + events: Vec::new(), + }, + AppRequest::ConfigGet { key } => { + let cfg = state.config.read().await; + AppResponse { + ok: true, + data: json!({ "key": key, "value": cfg.get_value(&key) }), + events: Vec::new(), + } + } + AppRequest::ConfigSet { key, value } => { + let mut cfg = state.config.write().await; + let result = cfg.set_value(&key, &value); + let ok = result.is_ok(); + let message = result.err().map(|e| e.to_string()); + let snapshot = cfg.clone(); + drop(cfg); + let _ = persist_config(state, snapshot).await; + AppResponse { + ok, + data: json!({ "key": key, "value": value, "error": message }), + events: Vec::new(), + } + } + AppRequest::ConfigUnset { key } => { + let mut cfg = state.config.write().await; + let result = cfg.unset_value(&key); + let ok = result.is_ok(); + let message = result.err().map(|e| e.to_string()); + let snapshot = cfg.clone(); + drop(cfg); + let _ = persist_config(state, snapshot).await; + AppResponse { + ok, + data: json!({ "key": key, "error": message }), + events: Vec::new(), + } + } + AppRequest::ConfigList => { + let cfg = state.config.read().await; + AppResponse { + ok: true, + data: json!({ "values": cfg.list_values() }), + events: Vec::new(), + } + } + AppRequest::Models => AppResponse { + ok: true, + data: json!({ "models": state.registry.list() }), + events: Vec::new(), + }, + AppRequest::ThreadLoadedList => { + let mut runtime = state.runtime.lock().await; + let response = runtime + .handle_thread(deepseek_protocol::ThreadRequest::List( + deepseek_protocol::ThreadListParams { + include_archived: false, + limit: Some(50), + }, + )) + .await; + match response { + Ok(thread_resp) => AppResponse { + ok: true, + data: json!({ "threads": thread_resp.threads }), + events: thread_resp.events, + }, + Err(err) => AppResponse { + ok: false, + data: json!({ "error": err.to_string() }), + events: Vec::new(), + }, + } + } + } +} + +async fn persist_config(state: &AppState, config: deepseek_config::ConfigToml) -> Result<()> { + if state.config_path.is_none() { + return Ok(()); + } + let mut store = ConfigStore::load(state.config_path.clone())?; + store.config = config; + store.save() +} diff --git a/crates/app-server/src/main.rs b/crates/app-server/src/main.rs new file mode 100644 index 00000000..b8f31168 --- /dev/null +++ b/crates/app-server/src/main.rs @@ -0,0 +1,33 @@ +use std::net::SocketAddr; +use std::path::PathBuf; + +use anyhow::{Context, Result}; +use clap::Parser; +use deepseek_app_server::{AppServerOptions, run}; + +#[derive(Debug, Parser)] +#[command( + name = "deepseek-app-server", + about = "Run the DeepSeek app-server transport" +)] +struct Cli { + #[arg(long, default_value = "127.0.0.1")] + host: String, + #[arg(long, default_value_t = 8787)] + port: u16, + #[arg(long)] + config: Option, +} + +#[tokio::main] +async fn main() -> Result<()> { + let cli = Cli::parse(); + let listen: SocketAddr = format!("{}:{}", cli.host, cli.port) + .parse() + .with_context(|| format!("invalid listen address {}:{}", cli.host, cli.port))?; + run(AppServerOptions { + listen, + config_path: cli.config, + }) + .await +} diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml new file mode 100644 index 00000000..091d3a45 --- /dev/null +++ b/crates/cli/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "deepseek-cli" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "Codex-style CLI facade for DeepSeek workspace architecture" + +[[bin]] +name = "deepseek" +path = "src/main.rs" + +[dependencies] +anyhow.workspace = true +clap.workspace = true +clap_complete.workspace = true +deepseek-agent = { path = "../agent" } +deepseek-app-server = { path = "../app-server" } +deepseek-config = { path = "../config" } +deepseek-execpolicy = { path = "../execpolicy" } +deepseek-mcp = { path = "../mcp" } +deepseek-state = { path = "../state" } +chrono.workspace = true +serde_json.workspace = true +tokio.workspace = true diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs new file mode 100644 index 00000000..32869513 --- /dev/null +++ b/crates/cli/src/main.rs @@ -0,0 +1,1049 @@ +use std::io::{self, Read}; +use std::net::SocketAddr; +use std::path::PathBuf; +use std::process::{Command, ExitCode}; + +use anyhow::{Context, Result, bail}; +use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum}; +use clap_complete::{Shell, generate}; +use deepseek_agent::ModelRegistry; +use deepseek_app_server::{ + AppServerOptions, run as run_app_server, run_stdio as run_app_server_stdio, +}; +use deepseek_config::{CliRuntimeOverrides, ConfigStore, ProviderKind}; +use deepseek_execpolicy::{AskForApproval, ExecPolicyContext, ExecPolicyEngine}; +use deepseek_mcp::{McpServerDefinition, run_stdio_server}; +use deepseek_state::{StateStore, ThreadListFilters}; + +#[derive(Debug, Clone, Copy, ValueEnum)] +enum ProviderArg { + Deepseek, + Openai, +} + +impl From for ProviderKind { + fn from(value: ProviderArg) -> Self { + match value { + ProviderArg::Deepseek => ProviderKind::Deepseek, + ProviderArg::Openai => ProviderKind::Openai, + } + } +} + +#[derive(Debug, Parser)] +#[command( + name = "deepseek", + version, + bin_name = "deepseek", + override_usage = "deepseek [OPTIONS] [PROMPT]\n deepseek [OPTIONS] [ARGS]" +)] +struct Cli { + #[arg(long)] + config: Option, + #[arg(long)] + profile: Option, + #[arg(long, value_enum)] + provider: Option, + #[arg(long)] + model: Option, + #[arg(long = "output-mode")] + output_mode: Option, + #[arg(long = "log-level")] + log_level: Option, + #[arg(long)] + telemetry: Option, + #[arg(long)] + approval_policy: Option, + #[arg(long)] + sandbox_mode: Option, + #[arg(long)] + api_key: Option, + #[arg(long)] + base_url: Option, + #[arg(value_name = "PROMPT")] + prompt: Option, + #[command(subcommand)] + command: Option, +} + +#[derive(Debug, Subcommand)] +enum Commands { + /// Run interactive/non-interactive flows via the TUI binary. + Run(RunArgs), + /// Login using API key, ChatGPT token, or device code style session. + Login(LoginArgs), + /// Remove saved authentication state. + Logout, + /// Manage authentication credentials and provider mode. + Auth(AuthArgs), + /// Run MCP server mode over stdio. + McpServer, + /// Read/write/list config values. + Config(ConfigArgs), + /// Resolve or list available models across providers. + Model(ModelArgs), + /// Manage thread/session metadata and resume/fork flows. + Thread(ThreadArgs), + /// Evaluate sandbox/approval policy decisions. + Sandbox(SandboxArgs), + /// Run the app-server transport. + AppServer(AppServerArgs), + /// Generate shell completions. + Completion { + #[arg(value_enum)] + shell: Shell, + }, +} + +#[derive(Debug, Args)] +struct RunArgs { + #[arg(trailing_var_arg = true, allow_hyphen_values = true)] + args: Vec, +} + +#[derive(Debug, Args)] +struct LoginArgs { + #[arg(long, value_enum, default_value_t = ProviderArg::Deepseek)] + provider: ProviderArg, + #[arg(long)] + api_key: Option, + #[arg(long, default_value_t = false)] + chatgpt: bool, + #[arg(long, default_value_t = false)] + device_code: bool, + #[arg(long)] + token: Option, +} + +#[derive(Debug, Args)] +struct AuthArgs { + #[command(subcommand)] + command: AuthCommand, +} + +#[derive(Debug, Subcommand)] +enum AuthCommand { + Status, + Set { + #[arg(long, value_enum)] + provider: ProviderArg, + #[arg(long)] + api_key: Option, + }, + Clear { + #[arg(long, value_enum)] + provider: ProviderArg, + }, +} + +#[derive(Debug, Args)] +struct ConfigArgs { + #[command(subcommand)] + command: ConfigCommand, +} + +#[derive(Debug, Subcommand)] +enum ConfigCommand { + Get { key: String }, + Set { key: String, value: String }, + Unset { key: String }, + List, + Path, +} + +#[derive(Debug, Args)] +struct ModelArgs { + #[command(subcommand)] + command: ModelCommand, +} + +#[derive(Debug, Subcommand)] +enum ModelCommand { + List { + #[arg(long, value_enum)] + provider: Option, + }, + Resolve { + model: Option, + #[arg(long, value_enum)] + provider: Option, + }, +} + +#[derive(Debug, Args)] +struct ThreadArgs { + #[command(subcommand)] + command: ThreadCommand, +} + +#[derive(Debug, Subcommand)] +enum ThreadCommand { + List { + #[arg(long, default_value_t = false)] + all: bool, + #[arg(long)] + limit: Option, + }, + Read { + thread_id: String, + }, + Resume { + thread_id: String, + }, + Fork { + thread_id: String, + }, + Archive { + thread_id: String, + }, + Unarchive { + thread_id: String, + }, + SetName { + thread_id: String, + name: String, + }, +} + +#[derive(Debug, Args)] +struct SandboxArgs { + #[command(subcommand)] + command: SandboxCommand, +} + +#[derive(Debug, Subcommand)] +enum SandboxCommand { + Check { + command: String, + #[arg(long, value_enum, default_value_t = ApprovalModeArg::OnRequest)] + ask: ApprovalModeArg, + }, +} + +#[derive(Debug, Clone, Copy, ValueEnum)] +enum ApprovalModeArg { + UnlessTrusted, + OnFailure, + OnRequest, + Never, +} + +impl From for AskForApproval { + fn from(value: ApprovalModeArg) -> Self { + match value { + ApprovalModeArg::UnlessTrusted => AskForApproval::UnlessTrusted, + ApprovalModeArg::OnFailure => AskForApproval::OnFailure, + ApprovalModeArg::OnRequest => AskForApproval::OnRequest, + ApprovalModeArg::Never => AskForApproval::Never, + } + } +} + +#[derive(Debug, Args)] +struct AppServerArgs { + #[arg(long, default_value = "127.0.0.1")] + host: String, + #[arg(long, default_value_t = 8787)] + port: u16, + #[arg(long)] + config: Option, + #[arg(long, default_value_t = false)] + stdio: bool, +} + +const MCP_SERVER_DEFINITIONS_KEY: &str = "mcp.server_definitions"; + +fn main() -> ExitCode { + match run() { + Ok(()) => ExitCode::SUCCESS, + Err(err) => { + eprintln!("error: {err}"); + ExitCode::FAILURE + } + } +} + +fn run() -> Result<()> { + let mut cli = Cli::parse(); + + let mut store = ConfigStore::load(cli.config.clone())?; + let runtime_overrides = CliRuntimeOverrides { + provider: cli.provider.map(Into::into), + model: cli.model.clone(), + api_key: cli.api_key.clone(), + base_url: cli.base_url.clone(), + auth_mode: None, + output_mode: cli.output_mode.clone(), + log_level: cli.log_level.clone(), + telemetry: cli.telemetry, + approval_policy: cli.approval_policy.clone(), + sandbox_mode: cli.sandbox_mode.clone(), + }; + let _resolved_runtime = store.config.resolve_runtime_options(&runtime_overrides); + + let command = cli.command.take(); + + match command { + Some(Commands::Run(args)) => delegate_to_tui(&cli, args.args), + Some(Commands::Login(args)) => run_login_command(&mut store, args), + Some(Commands::Logout) => run_logout_command(&mut store), + Some(Commands::Auth(args)) => run_auth_command(&mut store, args.command), + Some(Commands::McpServer) => run_mcp_server_command(&mut store), + Some(Commands::Config(args)) => run_config_command(&mut store, args.command), + Some(Commands::Model(args)) => run_model_command(args.command), + Some(Commands::Thread(args)) => run_thread_command(args.command), + Some(Commands::Sandbox(args)) => run_sandbox_command(args.command), + Some(Commands::AppServer(args)) => run_app_server_command(args), + Some(Commands::Completion { shell }) => { + let mut cmd = Cli::command(); + generate(shell, &mut cmd, "deepseek", &mut io::stdout()); + Ok(()) + } + None => { + let mut forwarded = Vec::new(); + if let Some(prompt) = cli.prompt.clone() { + forwarded.push("--prompt".to_string()); + forwarded.push(prompt); + } + delegate_to_tui(&cli, forwarded) + } + } +} + +fn run_login_command(store: &mut ConfigStore, args: LoginArgs) -> Result<()> { + let provider: ProviderKind = args.provider.into(); + store.config.provider = provider; + + if args.chatgpt { + let token = match args.token { + Some(token) => token, + None => read_api_key_from_stdin()?, + }; + store.config.auth_mode = Some("chatgpt".to_string()); + store.config.chatgpt_access_token = Some(token); + store.config.device_code_session = None; + store.save()?; + println!("logged in using chatgpt token mode ({})", provider.as_str()); + return Ok(()); + } + + if args.device_code { + let token = match args.token { + Some(token) => token, + None => read_api_key_from_stdin()?, + }; + store.config.auth_mode = Some("device_code".to_string()); + store.config.device_code_session = Some(token); + store.config.chatgpt_access_token = None; + store.save()?; + println!( + "logged in using device code session mode ({})", + provider.as_str() + ); + return Ok(()); + } + + let api_key = match args.api_key { + Some(v) => v, + None => read_api_key_from_stdin()?, + }; + store.config.auth_mode = Some("api_key".to_string()); + store.config.providers.for_provider_mut(provider).api_key = Some(api_key); + store.save()?; + println!("logged in using API key mode ({})", provider.as_str()); + Ok(()) +} + +fn run_logout_command(store: &mut ConfigStore) -> Result<()> { + store.config.providers.deepseek.api_key = None; + store.config.providers.openai.api_key = None; + store.config.auth_mode = None; + store.config.chatgpt_access_token = None; + store.config.device_code_session = None; + store.save()?; + println!("logged out"); + Ok(()) +} + +fn run_auth_command(store: &mut ConfigStore, command: AuthCommand) -> Result<()> { + match command { + AuthCommand::Status => { + let deepseek_env = std::env::var("DEEPSEEK_API_KEY") + .ok() + .filter(|v| !v.trim().is_empty()) + .is_some(); + let openai_env = std::env::var("OPENAI_API_KEY") + .ok() + .filter(|v| !v.trim().is_empty()) + .is_some(); + let deepseek_file = store + .config + .providers + .deepseek + .api_key + .as_ref() + .is_some_and(|v| !v.trim().is_empty()); + let openai_file = store + .config + .providers + .openai + .api_key + .as_ref() + .is_some_and(|v| !v.trim().is_empty()); + + println!("provider: {}", store.config.provider.as_str()); + println!( + "deepseek auth: env={}, config={}", + deepseek_env, deepseek_file + ); + println!("openai auth: env={}, config={}", openai_env, openai_file); + Ok(()) + } + AuthCommand::Set { provider, api_key } => { + let provider: ProviderKind = provider.into(); + let api_key = match api_key { + Some(v) => v, + None => read_api_key_from_stdin()?, + }; + store.config.provider = provider; + store.config.providers.for_provider_mut(provider).api_key = Some(api_key); + store.save()?; + println!("saved API key for {}", provider.as_str()); + Ok(()) + } + AuthCommand::Clear { provider } => { + let provider: ProviderKind = provider.into(); + store.config.providers.for_provider_mut(provider).api_key = None; + store.save()?; + println!("cleared API key for {}", provider.as_str()); + Ok(()) + } + } +} + +fn run_config_command(store: &mut ConfigStore, command: ConfigCommand) -> Result<()> { + match command { + ConfigCommand::Get { key } => { + if let Some(value) = store.config.get_value(&key) { + println!("{value}"); + return Ok(()); + } + bail!("key not found: {key}"); + } + ConfigCommand::Set { key, value } => { + store.config.set_value(&key, &value)?; + store.save()?; + println!("set {key}"); + Ok(()) + } + ConfigCommand::Unset { key } => { + store.config.unset_value(&key)?; + store.save()?; + println!("unset {key}"); + Ok(()) + } + ConfigCommand::List => { + for (key, value) in store.config.list_values() { + println!("{key} = {value}"); + } + Ok(()) + } + ConfigCommand::Path => { + println!("{}", store.path().display()); + Ok(()) + } + } +} + +fn run_model_command(command: ModelCommand) -> Result<()> { + let registry = ModelRegistry::default(); + match command { + ModelCommand::List { provider } => { + let filter = provider.map(ProviderKind::from); + for model in registry.list().into_iter().filter(|m| match filter { + Some(p) => m.provider == p, + None => true, + }) { + println!("{} ({})", model.id, model.provider.as_str()); + } + Ok(()) + } + ModelCommand::Resolve { model, provider } => { + let resolved = registry.resolve(model.as_deref(), provider.map(ProviderKind::from)); + println!("requested: {}", resolved.requested.unwrap_or_default()); + println!("resolved: {}", resolved.resolved.id); + println!("provider: {}", resolved.resolved.provider.as_str()); + println!("used_fallback: {}", resolved.used_fallback); + Ok(()) + } + } +} + +fn run_thread_command(command: ThreadCommand) -> Result<()> { + let state = StateStore::open(None)?; + match command { + ThreadCommand::List { all, limit } => { + let threads = state.list_threads(ThreadListFilters { + include_archived: all, + limit, + })?; + for thread in threads { + println!( + "{} | {} | {} | {}", + thread.id, + thread + .name + .clone() + .unwrap_or_else(|| "(unnamed)".to_string()), + thread.model_provider, + thread.cwd.display() + ); + } + Ok(()) + } + ThreadCommand::Read { thread_id } => { + let thread = state.get_thread(&thread_id)?; + println!("{}", serde_json::to_string_pretty(&thread)?); + Ok(()) + } + ThreadCommand::Resume { thread_id } => { + let args = vec!["resume".to_string(), thread_id]; + delegate_simple_tui(args) + } + ThreadCommand::Fork { thread_id } => { + let args = vec!["fork".to_string(), thread_id]; + delegate_simple_tui(args) + } + ThreadCommand::Archive { thread_id } => { + state.mark_archived(&thread_id)?; + println!("archived {thread_id}"); + Ok(()) + } + ThreadCommand::Unarchive { thread_id } => { + state.mark_unarchived(&thread_id)?; + println!("unarchived {thread_id}"); + Ok(()) + } + ThreadCommand::SetName { thread_id, name } => { + let mut thread = state + .get_thread(&thread_id)? + .with_context(|| format!("thread not found: {thread_id}"))?; + thread.name = Some(name); + thread.updated_at = chrono::Utc::now().timestamp(); + state.upsert_thread(&thread)?; + println!("renamed {thread_id}"); + Ok(()) + } + } +} + +fn run_sandbox_command(command: SandboxCommand) -> Result<()> { + match command { + SandboxCommand::Check { command, ask } => { + let engine = ExecPolicyEngine::new(Vec::new(), vec!["rm -rf".to_string()]); + let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")); + let decision = engine.check(ExecPolicyContext { + command: &command, + cwd: &cwd.display().to_string(), + ask_for_approval: ask.into(), + sandbox_mode: Some("workspace-write"), + })?; + println!("{}", serde_json::to_string_pretty(&decision)?); + Ok(()) + } + } +} + +fn run_app_server_command(args: AppServerArgs) -> Result<()> { + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .context("failed to create tokio runtime")?; + if args.stdio { + return runtime.block_on(run_app_server_stdio(args.config)); + } + let listen: SocketAddr = format!("{}:{}", args.host, args.port) + .parse() + .with_context(|| { + format!( + "invalid app-server listen address {}:{}", + args.host, args.port + ) + })?; + runtime.block_on(run_app_server(AppServerOptions { + listen, + config_path: args.config, + })) +} + +fn run_mcp_server_command(store: &mut ConfigStore) -> Result<()> { + let persisted = load_mcp_server_definitions(store); + let updated = run_stdio_server(persisted)?; + persist_mcp_server_definitions(store, &updated) +} + +fn load_mcp_server_definitions(store: &ConfigStore) -> Vec { + let Some(raw) = store.config.get_value(MCP_SERVER_DEFINITIONS_KEY) else { + return Vec::new(); + }; + + match parse_mcp_server_definitions(&raw) { + Ok(definitions) => definitions, + Err(err) => { + eprintln!( + "warning: failed to parse persisted MCP server definitions ({}): {}", + MCP_SERVER_DEFINITIONS_KEY, err + ); + Vec::new() + } + } +} + +fn parse_mcp_server_definitions(raw: &str) -> Result> { + if let Ok(parsed) = serde_json::from_str::>(raw) { + return Ok(parsed); + } + + let unwrapped: String = serde_json::from_str(raw) + .with_context(|| format!("invalid JSON payload at key {MCP_SERVER_DEFINITIONS_KEY}"))?; + serde_json::from_str::>(&unwrapped).with_context(|| { + format!("invalid MCP server definition list in key {MCP_SERVER_DEFINITIONS_KEY}") + }) +} + +fn persist_mcp_server_definitions( + store: &mut ConfigStore, + definitions: &[McpServerDefinition], +) -> Result<()> { + let encoded = + serde_json::to_string(definitions).context("failed to encode MCP server definitions")?; + store + .config + .set_value(MCP_SERVER_DEFINITIONS_KEY, &encoded)?; + store.save() +} + +fn delegate_to_tui(cli: &Cli, passthrough: Vec) -> Result<()> { + let current = std::env::current_exe().context("failed to locate current executable path")?; + let tui = current.with_file_name("deepseek-tui"); + if !tui.exists() { + bail!( + "deepseek-tui binary not found at {}. Build workspace default members to install it.", + tui.display() + ); + } + + let mut cmd = Command::new(tui); + if let Some(config) = cli.config.as_ref() { + cmd.arg("--config").arg(config); + } + if let Some(profile) = cli.profile.as_ref() { + cmd.arg("--profile").arg(profile); + } + cmd.args(passthrough); + + if let Some(provider) = cli.provider { + cmd.env("DEEPSEEK_PROVIDER", ProviderKind::from(provider).as_str()); + } + if let Some(model) = cli.model.as_ref() { + cmd.env("DEEPSEEK_MODEL", model); + } + if let Some(output_mode) = cli.output_mode.as_ref() { + cmd.env("DEEPSEEK_OUTPUT_MODE", output_mode); + } + if let Some(log_level) = cli.log_level.as_ref() { + cmd.env("DEEPSEEK_LOG_LEVEL", log_level); + } + if let Some(telemetry) = cli.telemetry { + cmd.env("DEEPSEEK_TELEMETRY", telemetry.to_string()); + } + if let Some(policy) = cli.approval_policy.as_ref() { + cmd.env("DEEPSEEK_APPROVAL_POLICY", policy); + } + if let Some(mode) = cli.sandbox_mode.as_ref() { + cmd.env("DEEPSEEK_SANDBOX_MODE", mode); + } + if let Some(api_key) = cli.api_key.as_ref() { + cmd.env("DEEPSEEK_API_KEY", api_key); + } + if let Some(base_url) = cli.base_url.as_ref() { + cmd.env("DEEPSEEK_BASE_URL", base_url); + } + + let status = cmd.status().context("failed to spawn deepseek-tui")?; + match status.code() { + Some(code) => std::process::exit(code), + None => bail!("deepseek-tui terminated by signal"), + } +} + +fn delegate_simple_tui(args: Vec) -> Result<()> { + let current = std::env::current_exe().context("failed to locate current executable path")?; + let tui = current.with_file_name("deepseek-tui"); + if !tui.exists() { + bail!( + "deepseek-tui binary not found at {}. Build workspace default members to install it.", + tui.display() + ); + } + let status = Command::new(tui).args(args).status()?; + match status.code() { + Some(code) => std::process::exit(code), + None => bail!("deepseek-tui terminated by signal"), + } +} + +fn read_api_key_from_stdin() -> Result { + let mut input = String::new(); + io::stdin() + .read_to_string(&mut input) + .context("failed to read api key from stdin")?; + let key = input.trim().to_string(); + if key.is_empty() { + bail!("empty API key provided"); + } + Ok(key) +} + +#[cfg(test)] +mod tests { + use super::*; + use clap::error::ErrorKind; + + fn parse_ok(argv: &[&str]) -> Cli { + Cli::try_parse_from(argv).unwrap_or_else(|err| panic!("parse failed for {argv:?}: {err}")) + } + + fn help_for(argv: &[&str]) -> String { + let err = Cli::try_parse_from(argv).expect_err("expected --help to short-circuit parsing"); + assert_eq!(err.kind(), ErrorKind::DisplayHelp); + err.to_string() + } + + #[test] + fn clap_command_definition_is_consistent() { + Cli::command().debug_assert(); + } + + #[test] + fn parses_config_command_matrix() { + let cli = parse_ok(&["deepseek", "config", "get", "provider"]); + assert!(matches!( + cli.command, + Some(Commands::Config(ConfigArgs { + command: ConfigCommand::Get { ref key } + })) if key == "provider" + )); + + let cli = parse_ok(&["deepseek", "config", "set", "model", "deepseek-chat"]); + assert!(matches!( + cli.command, + Some(Commands::Config(ConfigArgs { + command: ConfigCommand::Set { ref key, ref value } + })) if key == "model" && value == "deepseek-chat" + )); + + let cli = parse_ok(&["deepseek", "config", "unset", "model"]); + assert!(matches!( + cli.command, + Some(Commands::Config(ConfigArgs { + command: ConfigCommand::Unset { ref key } + })) if key == "model" + )); + + assert!(matches!( + parse_ok(&["deepseek", "config", "list"]).command, + Some(Commands::Config(ConfigArgs { + command: ConfigCommand::List + })) + )); + assert!(matches!( + parse_ok(&["deepseek", "config", "path"]).command, + Some(Commands::Config(ConfigArgs { + command: ConfigCommand::Path + })) + )); + } + + #[test] + fn parses_model_command_matrix() { + let cli = parse_ok(&["deepseek", "model", "list"]); + assert!(matches!( + cli.command, + Some(Commands::Model(ModelArgs { + command: ModelCommand::List { provider: None } + })) + )); + + let cli = parse_ok(&["deepseek", "model", "list", "--provider", "openai"]); + assert!(matches!( + cli.command, + Some(Commands::Model(ModelArgs { + command: ModelCommand::List { + provider: Some(ProviderArg::Openai) + } + })) + )); + + let cli = parse_ok(&["deepseek", "model", "resolve", "deepseek-chat"]); + assert!(matches!( + cli.command, + Some(Commands::Model(ModelArgs { + command: ModelCommand::Resolve { + model: Some(ref model), + provider: None + } + })) if model == "deepseek-chat" + )); + + let cli = parse_ok(&[ + "deepseek", + "model", + "resolve", + "--provider", + "deepseek", + "deepseek-reasoner", + ]); + assert!(matches!( + cli.command, + Some(Commands::Model(ModelArgs { + command: ModelCommand::Resolve { + model: Some(ref model), + provider: Some(ProviderArg::Deepseek) + } + })) if model == "deepseek-reasoner" + )); + } + + #[test] + fn parses_thread_command_matrix() { + let cli = parse_ok(&["deepseek", "thread", "list", "--all", "--limit", "50"]); + assert!(matches!( + cli.command, + Some(Commands::Thread(ThreadArgs { + command: ThreadCommand::List { + all: true, + limit: Some(50) + } + })) + )); + + let cli = parse_ok(&["deepseek", "thread", "read", "thread-1"]); + assert!(matches!( + cli.command, + Some(Commands::Thread(ThreadArgs { + command: ThreadCommand::Read { ref thread_id } + })) if thread_id == "thread-1" + )); + + let cli = parse_ok(&["deepseek", "thread", "resume", "thread-2"]); + assert!(matches!( + cli.command, + Some(Commands::Thread(ThreadArgs { + command: ThreadCommand::Resume { ref thread_id } + })) if thread_id == "thread-2" + )); + + let cli = parse_ok(&["deepseek", "thread", "fork", "thread-3"]); + assert!(matches!( + cli.command, + Some(Commands::Thread(ThreadArgs { + command: ThreadCommand::Fork { ref thread_id } + })) if thread_id == "thread-3" + )); + + let cli = parse_ok(&["deepseek", "thread", "archive", "thread-4"]); + assert!(matches!( + cli.command, + Some(Commands::Thread(ThreadArgs { + command: ThreadCommand::Archive { ref thread_id } + })) if thread_id == "thread-4" + )); + + let cli = parse_ok(&["deepseek", "thread", "unarchive", "thread-5"]); + assert!(matches!( + cli.command, + Some(Commands::Thread(ThreadArgs { + command: ThreadCommand::Unarchive { ref thread_id } + })) if thread_id == "thread-5" + )); + + let cli = parse_ok(&["deepseek", "thread", "set-name", "thread-6", "My Thread"]); + assert!(matches!( + cli.command, + Some(Commands::Thread(ThreadArgs { + command: ThreadCommand::SetName { + ref thread_id, + ref name + } + })) if thread_id == "thread-6" && name == "My Thread" + )); + } + + #[test] + fn parses_sandbox_app_server_and_completion_matrix() { + let cli = parse_ok(&[ + "deepseek", + "sandbox", + "check", + "echo hello", + "--ask", + "on-failure", + ]); + assert!(matches!( + cli.command, + Some(Commands::Sandbox(SandboxArgs { + command: SandboxCommand::Check { + ref command, + ask: ApprovalModeArg::OnFailure + } + })) if command == "echo hello" + )); + + let cli = parse_ok(&[ + "deepseek", + "app-server", + "--host", + "0.0.0.0", + "--port", + "9999", + ]); + assert!(matches!( + cli.command, + Some(Commands::AppServer(AppServerArgs { + ref host, + port: 9999, + stdio: false, + .. + })) if host == "0.0.0.0" + )); + + let cli = parse_ok(&["deepseek", "app-server", "--stdio"]); + assert!(matches!( + cli.command, + Some(Commands::AppServer(AppServerArgs { stdio: true, .. })) + )); + + let cli = parse_ok(&["deepseek", "completion", "bash"]); + assert!(matches!( + cli.command, + Some(Commands::Completion { shell: Shell::Bash }) + )); + } + + #[test] + fn parses_global_override_flags() { + let cli = parse_ok(&[ + "deepseek", + "--provider", + "openai", + "--config", + "/tmp/deepseek.toml", + "--profile", + "work", + "--model", + "gpt-4.1", + "--output-mode", + "json", + "--log-level", + "debug", + "--telemetry", + "true", + "--approval-policy", + "on-request", + "--sandbox-mode", + "workspace-write", + "--base-url", + "https://api.openai.com/v1", + "--api-key", + "sk-test", + "model", + "resolve", + "gpt-4.1", + ]); + + assert!(matches!(cli.provider, Some(ProviderArg::Openai))); + assert_eq!(cli.config, Some(PathBuf::from("/tmp/deepseek.toml"))); + assert_eq!(cli.profile.as_deref(), Some("work")); + assert_eq!(cli.model.as_deref(), Some("gpt-4.1")); + assert_eq!(cli.output_mode.as_deref(), Some("json")); + assert_eq!(cli.log_level.as_deref(), Some("debug")); + assert_eq!(cli.telemetry, Some(true)); + assert_eq!(cli.approval_policy.as_deref(), Some("on-request")); + assert_eq!(cli.sandbox_mode.as_deref(), Some("workspace-write")); + assert_eq!(cli.base_url.as_deref(), Some("https://api.openai.com/v1")); + assert_eq!(cli.api_key.as_deref(), Some("sk-test")); + } + + #[test] + fn root_help_surface_contains_expected_subcommands_and_globals() { + let rendered = help_for(&["deepseek", "--help"]); + + for token in [ + "run", + "login", + "logout", + "auth", + "mcp-server", + "config", + "model", + "thread", + "sandbox", + "app-server", + "completion", + "--provider", + "--model", + "--config", + "--profile", + "--output-mode", + "--log-level", + "--telemetry", + "--base-url", + "--api-key", + "--approval-policy", + "--sandbox-mode", + ] { + assert!( + rendered.contains(token), + "expected help to contain token: {token}" + ); + } + } + + #[test] + fn subcommand_help_surfaces_are_stable() { + let cases = [ + ("config", vec!["get", "set", "unset", "list", "path"]), + ("model", vec!["list", "resolve"]), + ( + "thread", + vec![ + "list", + "read", + "resume", + "fork", + "archive", + "unarchive", + "set-name", + ], + ), + ("sandbox", vec!["check"]), + ( + "app-server", + vec!["--host", "--port", "--config", "--stdio"], + ), + ("completion", vec!["", "bash"]), + ]; + + for (subcommand, expected_tokens) in cases { + let argv = ["deepseek", subcommand, "--help"]; + let rendered = help_for(&argv); + for token in expected_tokens { + assert!( + rendered.contains(token), + "expected help for `{subcommand}` to include `{token}`" + ); + } + } + } +} diff --git a/crates/config/Cargo.toml b/crates/config/Cargo.toml new file mode 100644 index 00000000..c3804c1e --- /dev/null +++ b/crates/config/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "deepseek-config" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "Config schema and precedence model for DeepSeek workspace architecture" + +[dependencies] +anyhow.workspace = true +dirs.workspace = true +serde.workspace = true +serde_json.workspace = true +toml.workspace = true diff --git a/crates/config/src/lib.rs b/crates/config/src/lib.rs new file mode 100644 index 00000000..61214775 --- /dev/null +++ b/crates/config/src/lib.rs @@ -0,0 +1,477 @@ +use std::collections::BTreeMap; +use std::fs; +use std::path::{Path, PathBuf}; + +use anyhow::{Context, Result, bail}; +use serde::{Deserialize, Serialize}; + +pub const CONFIG_FILE_NAME: &str = "config.toml"; +const DEFAULT_DEEPSEEK_MODEL: &str = "deepseek-reasoner"; +const DEFAULT_OPENAI_MODEL: &str = "gpt-4.1"; +const DEFAULT_DEEPSEEK_BASE_URL: &str = "https://api.deepseek.com"; +const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1"; + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)] +#[serde(rename_all = "kebab-case")] +pub enum ProviderKind { + #[default] + Deepseek, + Openai, +} + +impl ProviderKind { + #[must_use] + pub fn as_str(self) -> &'static str { + match self { + Self::Deepseek => "deepseek", + Self::Openai => "openai", + } + } + + #[must_use] + pub fn parse(value: &str) -> Option { + match value.trim().to_ascii_lowercase().as_str() { + "deepseek" | "deep-seek" => Some(Self::Deepseek), + "openai" | "open-ai" => Some(Self::Openai), + _ => None, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ProviderConfigToml { + pub api_key: Option, + pub base_url: Option, + pub model: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ProvidersToml { + #[serde(default)] + pub deepseek: ProviderConfigToml, + #[serde(default)] + pub openai: ProviderConfigToml, +} + +impl ProvidersToml { + #[must_use] + pub fn for_provider(&self, provider: ProviderKind) -> &ProviderConfigToml { + match provider { + ProviderKind::Deepseek => &self.deepseek, + ProviderKind::Openai => &self.openai, + } + } + + pub fn for_provider_mut(&mut self, provider: ProviderKind) -> &mut ProviderConfigToml { + match provider { + ProviderKind::Deepseek => &mut self.deepseek, + ProviderKind::Openai => &mut self.openai, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ConfigToml { + #[serde(default)] + pub provider: ProviderKind, + pub model: Option, + pub auth_mode: Option, + pub chatgpt_access_token: Option, + pub device_code_session: Option, + pub output_mode: Option, + pub log_level: Option, + pub telemetry: Option, + pub approval_policy: Option, + pub sandbox_mode: Option, + #[serde(default)] + pub providers: ProvidersToml, + #[serde(flatten)] + pub extras: BTreeMap, +} + +impl ConfigToml { + #[must_use] + pub fn get_value(&self, key: &str) -> Option { + match key { + "provider" => Some(self.provider.as_str().to_string()), + "model" => self.model.clone(), + "auth.mode" => self.auth_mode.clone(), + "auth.chatgpt_access_token" => self.chatgpt_access_token.clone(), + "auth.device_code_session" => self.device_code_session.clone(), + "output_mode" => self.output_mode.clone(), + "log_level" => self.log_level.clone(), + "telemetry" => self.telemetry.map(|v| v.to_string()), + "approval_policy" => self.approval_policy.clone(), + "sandbox_mode" => self.sandbox_mode.clone(), + "providers.deepseek.api_key" => self.providers.deepseek.api_key.clone(), + "providers.deepseek.base_url" => self.providers.deepseek.base_url.clone(), + "providers.deepseek.model" => self.providers.deepseek.model.clone(), + "providers.openai.api_key" => self.providers.openai.api_key.clone(), + "providers.openai.base_url" => self.providers.openai.base_url.clone(), + "providers.openai.model" => self.providers.openai.model.clone(), + _ => self.extras.get(key).map(toml::Value::to_string), + } + } + + pub fn set_value(&mut self, key: &str, value: &str) -> Result<()> { + match key { + "provider" => { + self.provider = ProviderKind::parse(value) + .with_context(|| format!("unknown provider '{value}'"))?; + } + "model" => self.model = Some(value.to_string()), + "auth.mode" => self.auth_mode = Some(value.to_string()), + "auth.chatgpt_access_token" => self.chatgpt_access_token = Some(value.to_string()), + "auth.device_code_session" => self.device_code_session = Some(value.to_string()), + "output_mode" => self.output_mode = Some(value.to_string()), + "log_level" => self.log_level = Some(value.to_string()), + "telemetry" => { + self.telemetry = Some(parse_bool(value)?); + } + "approval_policy" => self.approval_policy = Some(value.to_string()), + "sandbox_mode" => self.sandbox_mode = Some(value.to_string()), + "providers.deepseek.api_key" => { + self.providers.deepseek.api_key = Some(value.to_string()) + } + "providers.deepseek.base_url" => { + self.providers.deepseek.base_url = Some(value.to_string()); + } + "providers.deepseek.model" => self.providers.deepseek.model = Some(value.to_string()), + "providers.openai.api_key" => self.providers.openai.api_key = Some(value.to_string()), + "providers.openai.base_url" => self.providers.openai.base_url = Some(value.to_string()), + "providers.openai.model" => self.providers.openai.model = Some(value.to_string()), + _ => { + self.extras + .insert(key.to_string(), toml::Value::String(value.to_string())); + } + } + Ok(()) + } + + pub fn unset_value(&mut self, key: &str) -> Result<()> { + match key { + "provider" => self.provider = ProviderKind::Deepseek, + "model" => self.model = None, + "auth.mode" => self.auth_mode = None, + "auth.chatgpt_access_token" => self.chatgpt_access_token = None, + "auth.device_code_session" => self.device_code_session = None, + "output_mode" => self.output_mode = None, + "log_level" => self.log_level = None, + "telemetry" => self.telemetry = None, + "approval_policy" => self.approval_policy = None, + "sandbox_mode" => self.sandbox_mode = None, + "providers.deepseek.api_key" => self.providers.deepseek.api_key = None, + "providers.deepseek.base_url" => self.providers.deepseek.base_url = None, + "providers.deepseek.model" => self.providers.deepseek.model = None, + "providers.openai.api_key" => self.providers.openai.api_key = None, + "providers.openai.base_url" => self.providers.openai.base_url = None, + "providers.openai.model" => self.providers.openai.model = None, + _ => { + self.extras.remove(key); + } + } + Ok(()) + } + + #[must_use] + pub fn list_values(&self) -> BTreeMap { + let mut out = BTreeMap::new(); + out.insert("provider".to_string(), self.provider.as_str().to_string()); + + if let Some(v) = self.model.as_ref() { + out.insert("model".to_string(), v.clone()); + } + if let Some(v) = self.auth_mode.as_ref() { + out.insert("auth.mode".to_string(), v.clone()); + } + if let Some(v) = self.chatgpt_access_token.as_ref() { + out.insert("auth.chatgpt_access_token".to_string(), redact_secret(v)); + } + if let Some(v) = self.device_code_session.as_ref() { + out.insert("auth.device_code_session".to_string(), redact_secret(v)); + } + if let Some(v) = self.output_mode.as_ref() { + out.insert("output_mode".to_string(), v.clone()); + } + if let Some(v) = self.log_level.as_ref() { + out.insert("log_level".to_string(), v.clone()); + } + if let Some(v) = self.telemetry { + out.insert("telemetry".to_string(), v.to_string()); + } + if let Some(v) = self.approval_policy.as_ref() { + out.insert("approval_policy".to_string(), v.clone()); + } + if let Some(v) = self.sandbox_mode.as_ref() { + out.insert("sandbox_mode".to_string(), v.clone()); + } + if let Some(v) = self.providers.deepseek.api_key.as_ref() { + out.insert("providers.deepseek.api_key".to_string(), redact_secret(v)); + } + if let Some(v) = self.providers.deepseek.base_url.as_ref() { + out.insert("providers.deepseek.base_url".to_string(), v.clone()); + } + if let Some(v) = self.providers.deepseek.model.as_ref() { + out.insert("providers.deepseek.model".to_string(), v.clone()); + } + if let Some(v) = self.providers.openai.api_key.as_ref() { + out.insert("providers.openai.api_key".to_string(), redact_secret(v)); + } + if let Some(v) = self.providers.openai.base_url.as_ref() { + out.insert("providers.openai.base_url".to_string(), v.clone()); + } + if let Some(v) = self.providers.openai.model.as_ref() { + out.insert("providers.openai.model".to_string(), v.clone()); + } + + for (k, v) in &self.extras { + out.insert(k.clone(), v.to_string()); + } + out + } + + #[must_use] + pub fn resolve_runtime_options(&self, cli: &CliRuntimeOverrides) -> ResolvedRuntimeOptions { + let env = EnvRuntimeOverrides::load(); + let provider = cli.provider.or(env.provider).unwrap_or(self.provider); + + let provider_cfg = self.providers.for_provider(provider); + let api_key = cli + .api_key + .clone() + .or_else(|| env.api_key_for(provider)) + .or_else(|| provider_cfg.api_key.clone()); + + let base_url = cli + .base_url + .clone() + .or_else(|| env.base_url_for(provider)) + .or_else(|| provider_cfg.base_url.clone()) + .unwrap_or_else(|| match provider { + ProviderKind::Deepseek => DEFAULT_DEEPSEEK_BASE_URL.to_string(), + ProviderKind::Openai => DEFAULT_OPENAI_BASE_URL.to_string(), + }); + + let model = cli + .model + .clone() + .or_else(|| env.model.clone()) + .or_else(|| provider_cfg.model.clone()) + .or_else(|| self.model.clone()) + .unwrap_or_else(|| match provider { + ProviderKind::Deepseek => DEFAULT_DEEPSEEK_MODEL.to_string(), + ProviderKind::Openai => DEFAULT_OPENAI_MODEL.to_string(), + }); + + let output_mode = cli + .output_mode + .clone() + .or_else(|| env.output_mode.clone()) + .or_else(|| self.output_mode.clone()); + let auth_mode = cli + .auth_mode + .clone() + .or_else(|| env.auth_mode.clone()) + .or_else(|| self.auth_mode.clone()); + let log_level = cli + .log_level + .clone() + .or_else(|| env.log_level.clone()) + .or_else(|| self.log_level.clone()); + let telemetry = cli + .telemetry + .or(env.telemetry) + .or(self.telemetry) + .unwrap_or(false); + let approval_policy = cli + .approval_policy + .clone() + .or_else(|| env.approval_policy.clone()) + .or_else(|| self.approval_policy.clone()); + let sandbox_mode = cli + .sandbox_mode + .clone() + .or_else(|| env.sandbox_mode.clone()) + .or_else(|| self.sandbox_mode.clone()); + + ResolvedRuntimeOptions { + provider, + model, + api_key, + base_url, + auth_mode, + output_mode, + log_level, + telemetry, + approval_policy, + sandbox_mode, + } + } +} + +#[derive(Debug, Clone, Default)] +pub struct CliRuntimeOverrides { + pub provider: Option, + pub model: Option, + pub api_key: Option, + pub base_url: Option, + pub auth_mode: Option, + pub output_mode: Option, + pub log_level: Option, + pub telemetry: Option, + pub approval_policy: Option, + pub sandbox_mode: Option, +} + +#[derive(Debug, Clone)] +pub struct ResolvedRuntimeOptions { + pub provider: ProviderKind, + pub model: String, + pub api_key: Option, + pub base_url: String, + pub auth_mode: Option, + pub output_mode: Option, + pub log_level: Option, + pub telemetry: bool, + pub approval_policy: Option, + pub sandbox_mode: Option, +} + +#[derive(Debug, Clone)] +pub struct ConfigStore { + path: PathBuf, + pub config: ConfigToml, +} + +impl ConfigStore { + pub fn load(path: Option) -> Result { + let path = resolve_config_path(path)?; + if !path.exists() { + return Ok(Self { + path, + config: ConfigToml::default(), + }); + } + + let raw = fs::read_to_string(&path) + .with_context(|| format!("failed to read config at {}", path.display()))?; + let parsed: ConfigToml = toml::from_str(&raw) + .with_context(|| format!("failed to parse config at {}", path.display()))?; + + Ok(Self { + path, + config: parsed, + }) + } + + pub fn save(&self) -> Result<()> { + if let Some(parent) = self.path.parent() { + fs::create_dir_all(parent).with_context(|| { + format!("failed to create config directory {}", parent.display()) + })?; + } + let body = toml::to_string_pretty(&self.config).context("failed to serialize config")?; + fs::write(&self.path, body) + .with_context(|| format!("failed to write config at {}", self.path.display()))?; + Ok(()) + } + + #[must_use] + pub fn path(&self) -> &Path { + &self.path + } +} + +pub fn resolve_config_path(explicit: Option) -> Result { + if let Some(path) = explicit { + return Ok(path); + } + if let Ok(path) = std::env::var("DEEPSEEK_CONFIG_PATH") { + let trimmed = path.trim(); + if !trimmed.is_empty() { + return Ok(PathBuf::from(trimmed)); + } + } + default_config_path() +} + +pub fn default_config_path() -> Result { + let home = dirs::home_dir().context("failed to resolve home directory for config path")?; + Ok(home.join(".deepseek").join(CONFIG_FILE_NAME)) +} + +fn parse_bool(raw: &str) -> Result { + match raw.trim().to_ascii_lowercase().as_str() { + "1" | "true" | "yes" | "on" | "enabled" => Ok(true), + "0" | "false" | "no" | "off" | "disabled" => Ok(false), + _ => bail!("invalid boolean '{raw}'"), + } +} + +fn redact_secret(secret: &str) -> String { + if secret.len() <= 8 { + return "********".to_string(); + } + format!("{}***{}", &secret[..4], &secret[secret.len() - 4..]) +} + +#[derive(Debug, Clone, Default)] +struct EnvRuntimeOverrides { + provider: Option, + model: Option, + output_mode: Option, + auth_mode: Option, + log_level: Option, + telemetry: Option, + approval_policy: Option, + sandbox_mode: Option, + deepseek_api_key: Option, + openai_api_key: Option, + deepseek_base_url: Option, + openai_base_url: Option, +} + +impl EnvRuntimeOverrides { + fn load() -> Self { + Self { + provider: std::env::var("DEEPSEEK_PROVIDER") + .ok() + .and_then(|v| ProviderKind::parse(&v)), + model: std::env::var("DEEPSEEK_MODEL").ok(), + output_mode: std::env::var("DEEPSEEK_OUTPUT_MODE").ok(), + auth_mode: std::env::var("DEEPSEEK_AUTH_MODE").ok(), + log_level: std::env::var("DEEPSEEK_LOG_LEVEL").ok(), + telemetry: std::env::var("DEEPSEEK_TELEMETRY") + .ok() + .and_then(|v| parse_bool(&v).ok()), + approval_policy: std::env::var("DEEPSEEK_APPROVAL_POLICY").ok(), + sandbox_mode: std::env::var("DEEPSEEK_SANDBOX_MODE").ok(), + deepseek_api_key: std::env::var("DEEPSEEK_API_KEY") + .ok() + .filter(|v| !v.trim().is_empty()), + openai_api_key: std::env::var("OPENAI_API_KEY") + .ok() + .filter(|v| !v.trim().is_empty()), + deepseek_base_url: std::env::var("DEEPSEEK_BASE_URL") + .ok() + .filter(|v| !v.trim().is_empty()), + openai_base_url: std::env::var("OPENAI_BASE_URL") + .ok() + .filter(|v| !v.trim().is_empty()), + } + } + + fn api_key_for(&self, provider: ProviderKind) -> Option { + match provider { + ProviderKind::Deepseek => self.deepseek_api_key.clone(), + ProviderKind::Openai => self.openai_api_key.clone(), + } + } + + fn base_url_for(&self, provider: ProviderKind) -> Option { + match provider { + ProviderKind::Deepseek => self.deepseek_base_url.clone(), + ProviderKind::Openai => self.openai_base_url.clone(), + } + } +} diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml new file mode 100644 index 00000000..837789a1 --- /dev/null +++ b/crates/core/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "deepseek-core" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "Core runtime boundaries for DeepSeek workspace architecture" + +[dependencies] +anyhow.workspace = true +chrono.workspace = true +deepseek-agent = { path = "../agent" } +deepseek-config = { path = "../config" } +deepseek-execpolicy = { path = "../execpolicy" } +deepseek-hooks = { path = "../hooks" } +deepseek-mcp = { path = "../mcp" } +deepseek-protocol = { path = "../protocol" } +deepseek-state = { path = "../state" } +deepseek-tools = { path = "../tools" } +serde_json.workspace = true +tokio.workspace = true +uuid.workspace = true diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs new file mode 100644 index 00000000..e9469815 --- /dev/null +++ b/crates/core/src/lib.rs @@ -0,0 +1,1698 @@ +use std::collections::HashMap; +use std::path::{Path, PathBuf}; +use std::sync::Arc; + +use anyhow::Result; +use deepseek_agent::ModelRegistry; +use deepseek_config::{CliRuntimeOverrides, ConfigToml, ProviderKind}; +use deepseek_execpolicy::{ + AskForApproval, ExecApprovalRequirement, ExecPolicyContext, ExecPolicyDecision, + ExecPolicyEngine, +}; +use deepseek_hooks::{HookDispatcher, HookEvent}; +use deepseek_mcp::{ + McpManager, McpStartupCompleteEvent, McpStartupStatus as McpManagerStartupStatus, +}; +use deepseek_protocol::{ + AppResponse, EventFrame, ExecApprovalRequestEvent, PromptRequest, PromptResponse, + ReviewDecision, Thread, ThreadForkParams, ThreadListParams, ThreadReadParams, ThreadRequest, + ThreadResponse, ThreadResumeParams, ThreadSetNameParams, ThreadStatus, ToolPayload, +}; +use deepseek_state::{ + JobStateRecord, JobStateStatus, SessionSource, StateStore, ThreadListFilters, ThreadMetadata, + ThreadStatus as PersistedThreadStatus, +}; +use deepseek_tools::{ToolCall, ToolRegistry}; +use serde_json::{Value, json}; +use uuid::Uuid; + +#[derive(Debug, Clone)] +pub enum InitialHistory { + New, + Forked(Vec), + Resumed { + conversation_id: String, + history: Vec, + rollout_path: PathBuf, + }, +} + +#[derive(Debug, Clone)] +pub struct NewThread { + pub thread: Thread, + pub model: String, + pub model_provider: String, + pub cwd: PathBuf, + pub approval_policy: Option, + pub sandbox: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum JobStatus { + Queued, + Running, + Paused, + Completed, + Failed, + Cancelled, +} + +const JOB_DETAIL_SCHEMA_VERSION: u8 = 1; +const DEFAULT_JOB_MAX_ATTEMPTS: u32 = 3; +const DEFAULT_JOB_BACKOFF_BASE_MS: u64 = 500; +const MAX_JOB_HISTORY_ENTRIES: usize = 64; + +#[derive(Debug, Clone)] +pub struct JobRetryMetadata { + pub attempt: u32, + pub max_attempts: u32, + pub backoff_base_ms: u64, + pub next_backoff_ms: u64, + pub next_retry_at: Option, +} + +impl Default for JobRetryMetadata { + fn default() -> Self { + Self { + attempt: 0, + max_attempts: DEFAULT_JOB_MAX_ATTEMPTS, + backoff_base_ms: DEFAULT_JOB_BACKOFF_BASE_MS, + next_backoff_ms: 0, + next_retry_at: None, + } + } +} + +#[derive(Debug, Clone)] +pub struct JobHistoryEntry { + pub at: i64, + pub phase: String, + pub status: JobStatus, + pub progress: Option, + pub detail: Option, + pub retry: JobRetryMetadata, +} + +#[derive(Debug, Clone)] +struct PersistedJobDetail { + pub status: JobStatus, + pub detail: Option, + pub retry: JobRetryMetadata, + pub history: Vec, +} + +#[derive(Debug, Clone)] +pub struct JobRecord { + pub id: String, + pub name: String, + pub status: JobStatus, + pub progress: Option, + pub detail: Option, + pub retry: JobRetryMetadata, + pub history: Vec, + pub created_at: i64, + pub updated_at: i64, +} + +#[derive(Debug, Default)] +pub struct JobManager { + jobs: HashMap, +} + +impl JobManager { + fn now_ts() -> i64 { + chrono::Utc::now().timestamp() + } + + fn deterministic_backoff_ms(retry: &JobRetryMetadata) -> u64 { + if retry.attempt == 0 { + return 0; + } + let exponent = retry.attempt.saturating_sub(1).min(20); + let multiplier = 1u64.checked_shl(exponent).unwrap_or(u64::MAX); + retry.backoff_base_ms.saturating_mul(multiplier) + } + + fn clear_retry_schedule(retry: &mut JobRetryMetadata) { + retry.next_backoff_ms = 0; + retry.next_retry_at = None; + } + + fn push_history(job: &mut JobRecord, phase: &str) { + job.history.push(JobHistoryEntry { + at: job.updated_at, + phase: phase.to_string(), + status: job.status, + progress: job.progress, + detail: job.detail.clone(), + retry: job.retry.clone(), + }); + if job.history.len() > MAX_JOB_HISTORY_ENTRIES { + let to_drain = job.history.len() - MAX_JOB_HISTORY_ENTRIES; + job.history.drain(0..to_drain); + } + } + + fn parse_persisted_detail(raw: Option<&str>) -> Option { + let raw = raw?; + let parsed: Value = serde_json::from_str(raw).ok()?; + let status = parsed + .get("status") + .and_then(Value::as_str) + .and_then(job_status_from_str)?; + let detail = parsed.get("detail").and_then(json_optional_string); + let retry = parse_retry_metadata(parsed.get("retry")); + let history = parsed + .get("history") + .and_then(Value::as_array) + .map(|items| { + items + .iter() + .filter_map(parse_history_entry) + .collect::>() + }) + .unwrap_or_default(); + Some(PersistedJobDetail { + status, + detail, + retry, + history, + }) + } + + fn encode_persisted_detail(job: &JobRecord) -> Result> { + let encoded = json!({ + "schema_version": JOB_DETAIL_SCHEMA_VERSION, + "status": job_status_to_str(job.status), + "detail": job.detail.clone(), + "retry": job_retry_to_value(&job.retry), + "history": job.history.iter().map(job_history_to_value).collect::>() + }) + .to_string(); + Ok(Some(encoded)) + } + + pub fn enqueue(&mut self, name: impl Into) -> JobRecord { + let now = Self::now_ts(); + let id = format!("job-{}", Uuid::new_v4()); + let mut job = JobRecord { + id: id.clone(), + name: name.into(), + status: JobStatus::Queued, + progress: Some(0), + detail: None, + retry: JobRetryMetadata::default(), + history: Vec::new(), + created_at: now, + updated_at: now, + }; + Self::push_history(&mut job, "created"); + self.jobs.insert(id, job.clone()); + job + } + + pub fn set_running(&mut self, id: &str) { + if let Some(job) = self.jobs.get_mut(id) { + job.status = JobStatus::Running; + Self::clear_retry_schedule(&mut job.retry); + job.updated_at = Self::now_ts(); + Self::push_history(job, "running"); + } + } + + pub fn update_progress(&mut self, id: &str, progress: u8, detail: Option) { + if let Some(job) = self.jobs.get_mut(id) { + job.progress = Some(progress.min(100)); + job.detail = detail; + job.updated_at = Self::now_ts(); + Self::push_history(job, "progress_updated"); + } + } + + pub fn complete(&mut self, id: &str) { + if let Some(job) = self.jobs.get_mut(id) { + job.status = JobStatus::Completed; + job.progress = Some(100); + Self::clear_retry_schedule(&mut job.retry); + job.updated_at = Self::now_ts(); + Self::push_history(job, "completed"); + } + } + + pub fn fail(&mut self, id: &str, detail: impl Into) { + if let Some(job) = self.jobs.get_mut(id) { + let now = Self::now_ts(); + job.status = JobStatus::Failed; + job.detail = Some(detail.into()); + if job.retry.attempt < job.retry.max_attempts { + job.retry.attempt += 1; + job.retry.next_backoff_ms = Self::deterministic_backoff_ms(&job.retry); + let delay_secs = ((job.retry.next_backoff_ms.saturating_add(999)) / 1000) + .min(i64::MAX as u64) as i64; + job.retry.next_retry_at = Some(now.saturating_add(delay_secs)); + } else { + Self::clear_retry_schedule(&mut job.retry); + } + job.updated_at = now; + Self::push_history(job, "failed"); + } + } + + pub fn cancel(&mut self, id: &str) { + if let Some(job) = self.jobs.get_mut(id) { + job.status = JobStatus::Cancelled; + Self::clear_retry_schedule(&mut job.retry); + job.updated_at = Self::now_ts(); + Self::push_history(job, "cancelled"); + } + } + + pub fn pause(&mut self, id: &str, detail: Option) { + if let Some(job) = self.jobs.get_mut(id) { + job.status = JobStatus::Paused; + if detail.is_some() { + job.detail = detail; + } + job.updated_at = Self::now_ts(); + Self::push_history(job, "paused"); + } + } + + pub fn resume(&mut self, id: &str, detail: Option) { + if let Some(job) = self.jobs.get_mut(id) { + job.status = JobStatus::Running; + if detail.is_some() { + job.detail = detail; + } + Self::clear_retry_schedule(&mut job.retry); + job.updated_at = Self::now_ts(); + Self::push_history(job, "resumed"); + } + } + + pub fn list(&self) -> Vec { + let mut out = self.jobs.values().cloned().collect::>(); + out.sort_by_key(|job| -job.updated_at); + out + } + + pub fn history(&self, id: &str) -> Vec { + self.jobs + .get(id) + .map(|job| job.history.clone()) + .unwrap_or_default() + } + + pub fn resume_pending(&mut self) -> Vec { + let mut resumed = Vec::new(); + for job in self.jobs.values_mut() { + if matches!(job.status, JobStatus::Queued | JobStatus::Running) { + job.status = JobStatus::Queued; + job.updated_at = Self::now_ts(); + Self::push_history(job, "queued_after_resume"); + resumed.push(job.clone()); + } + } + resumed + } + + pub fn load_from_store(&mut self, store: &StateStore) -> Result<()> { + let persisted = store.list_jobs(Some(500))?; + for job in persisted { + let fallback_status = job_state_status_to_runtime(job.status.clone()); + let parsed = Self::parse_persisted_detail(job.detail.as_deref()); + let (status, detail, retry, history) = if let Some(detail_state) = parsed { + ( + detail_state.status, + detail_state.detail, + detail_state.retry, + detail_state.history, + ) + } else { + ( + fallback_status, + job.detail, + JobRetryMetadata::default(), + Vec::new(), + ) + }; + self.jobs.insert( + job.id.clone(), + JobRecord { + id: job.id, + name: job.name, + status, + progress: job.progress, + detail, + retry, + history, + created_at: job.created_at, + updated_at: job.updated_at, + }, + ); + } + Ok(()) + } + + pub fn persist_job(&self, store: &StateStore, id: &str) -> Result<()> { + let Some(job) = self.jobs.get(id) else { + return Ok(()); + }; + let encoded_detail = Self::encode_persisted_detail(job)?; + store.upsert_job(&JobStateRecord { + id: job.id.clone(), + name: job.name.clone(), + status: runtime_status_to_job_state(job.status), + progress: job.progress, + detail: encoded_detail, + created_at: job.created_at, + updated_at: job.updated_at, + }) + } + + pub fn persist_all(&self, store: &StateStore) -> Result<()> { + for id in self.jobs.keys() { + self.persist_job(store, id)?; + } + Ok(()) + } +} + +pub struct ThreadManager { + store: StateStore, + running_threads: HashMap, + cli_version: String, +} + +impl ThreadManager { + pub fn new(store: StateStore) -> Self { + Self { + store, + running_threads: HashMap::new(), + cli_version: env!("CARGO_PKG_VERSION").to_string(), + } + } + + pub fn state_store(&self) -> &StateStore { + &self.store + } + + pub fn spawn_thread_with_history( + &mut self, + model_provider: String, + cwd: PathBuf, + initial_history: InitialHistory, + persist_extended_history: bool, + ) -> Result { + let id = format!("thread-{}", Uuid::new_v4()); + let now = chrono::Utc::now().timestamp(); + let preview = preview_from_initial_history(&initial_history); + let source = match initial_history { + InitialHistory::New => SessionSource::Interactive, + InitialHistory::Forked(_) => SessionSource::Fork, + InitialHistory::Resumed { .. } => SessionSource::Resume, + }; + let thread = Thread { + id: id.clone(), + preview, + ephemeral: !persist_extended_history, + model_provider: model_provider.clone(), + created_at: now, + updated_at: now, + status: ThreadStatus::Running, + path: None, + cwd: cwd.clone(), + cli_version: self.cli_version.clone(), + source: match source { + SessionSource::Interactive => deepseek_protocol::SessionSource::Interactive, + SessionSource::Resume => deepseek_protocol::SessionSource::Resume, + SessionSource::Fork => deepseek_protocol::SessionSource::Fork, + SessionSource::Api => deepseek_protocol::SessionSource::Api, + SessionSource::Unknown => deepseek_protocol::SessionSource::Unknown, + }, + name: None, + }; + self.persist_thread(&thread, None)?; + match &initial_history { + InitialHistory::Forked(items) => { + for item in items { + self.store.append_message( + &thread.id, + "history", + &item.to_string(), + Some(item.clone()), + )?; + } + } + InitialHistory::Resumed { history, .. } => { + for item in history { + self.store.append_message( + &thread.id, + "history", + &item.to_string(), + Some(item.clone()), + )?; + } + } + InitialHistory::New => {} + } + self.running_threads + .insert(thread.id.clone(), thread.clone()); + Ok(NewThread { + thread, + model: "auto".to_string(), + model_provider, + cwd, + approval_policy: None, + sandbox: None, + }) + } + + pub fn resume_thread_with_history( + &mut self, + params: &ThreadResumeParams, + fallback_cwd: &Path, + model_provider: String, + ) -> Result> { + if params.history.is_none() + && let Some(thread) = self.running_threads.get(¶ms.thread_id).cloned() + { + return Ok(Some(NewThread { + model: params.model.clone().unwrap_or_else(|| "auto".to_string()), + model_provider: params.model_provider.clone().unwrap_or(model_provider), + cwd: params.cwd.clone().unwrap_or_else(|| thread.cwd.clone()), + approval_policy: params.approval_policy.clone(), + sandbox: params.sandbox.clone(), + thread, + })); + } + + let persisted = self.store.get_thread(¶ms.thread_id)?; + let Some(metadata) = persisted else { + return Ok(None); + }; + let mut thread = to_protocol_thread(metadata); + thread.status = ThreadStatus::Running; + thread.updated_at = chrono::Utc::now().timestamp(); + thread.cwd = params + .cwd + .clone() + .unwrap_or_else(|| fallback_cwd.to_path_buf()); + self.persist_thread(&thread, None)?; + self.running_threads + .insert(thread.id.clone(), thread.clone()); + if let Some(history) = params.history.as_ref() { + for item in history { + self.store.append_message( + &thread.id, + "history", + &item.to_string(), + Some(item.clone()), + )?; + } + } + + Ok(Some(NewThread { + model: params.model.clone().unwrap_or_else(|| "auto".to_string()), + model_provider: params.model_provider.clone().unwrap_or(model_provider), + cwd: thread.cwd.clone(), + approval_policy: params.approval_policy.clone(), + sandbox: params.sandbox.clone(), + thread, + })) + } + + pub fn fork_thread( + &mut self, + params: &ThreadForkParams, + fallback_cwd: &Path, + ) -> Result> { + let parent = self.store.get_thread(¶ms.thread_id)?; + let Some(parent) = parent else { + return Ok(None); + }; + let parent_thread = to_protocol_thread(parent); + let new = self.spawn_thread_with_history( + params + .model_provider + .clone() + .unwrap_or_else(|| parent_thread.model_provider.clone()), + params + .cwd + .clone() + .unwrap_or_else(|| fallback_cwd.to_path_buf()), + InitialHistory::Forked(vec![json!({ + "type": "fork", + "from_thread_id": parent_thread.id + })]), + params.persist_extended_history, + )?; + Ok(Some(new)) + } + + pub fn list_threads(&self, params: &ThreadListParams) -> Result> { + let list = self.store.list_threads(ThreadListFilters { + include_archived: params.include_archived, + limit: params.limit, + })?; + Ok(list.into_iter().map(to_protocol_thread).collect()) + } + + pub fn read_thread(&self, params: &ThreadReadParams) -> Result> { + Ok(self + .store + .get_thread(¶ms.thread_id)? + .map(to_protocol_thread)) + } + + pub fn set_thread_name(&mut self, params: &ThreadSetNameParams) -> Result> { + let Some(mut metadata) = self.store.get_thread(¶ms.thread_id)? else { + return Ok(None); + }; + metadata.name = Some(params.name.clone()); + metadata.updated_at = chrono::Utc::now().timestamp(); + self.store.upsert_thread(&metadata)?; + let updated = to_protocol_thread(metadata); + self.running_threads + .insert(updated.id.clone(), updated.clone()); + Ok(Some(updated)) + } + + pub fn archive_thread(&mut self, thread_id: &str) -> Result<()> { + self.store.mark_archived(thread_id)?; + if let Some(thread) = self.running_threads.get_mut(thread_id) { + thread.status = ThreadStatus::Archived; + } + Ok(()) + } + + pub fn unarchive_thread(&mut self, thread_id: &str) -> Result<()> { + self.store.mark_unarchived(thread_id)?; + Ok(()) + } + + pub fn touch_message(&mut self, thread_id: &str, input: &str) -> Result<()> { + let Some(mut metadata) = self.store.get_thread(thread_id)? else { + return Ok(()); + }; + metadata.updated_at = chrono::Utc::now().timestamp(); + metadata.preview = truncate_preview(input); + metadata.status = PersistedThreadStatus::Running; + self.store.upsert_thread(&metadata)?; + if let Some(thread) = self.running_threads.get_mut(thread_id) { + thread.updated_at = metadata.updated_at; + thread.preview = metadata.preview; + thread.status = ThreadStatus::Running; + } + let message_id = self.store.append_message(thread_id, "user", input, None)?; + self.store.save_checkpoint( + thread_id, + "latest", + &json!({ + "reason": "thread_message", + "message_id": message_id, + "role": "user", + "preview": truncate_preview(input), + "updated_at": metadata.updated_at + }), + )?; + Ok(()) + } + + fn persist_thread(&self, thread: &Thread, rollout_path: Option) -> Result<()> { + self.store.upsert_thread(&ThreadMetadata { + id: thread.id.clone(), + rollout_path, + preview: thread.preview.clone(), + ephemeral: thread.ephemeral, + model_provider: thread.model_provider.clone(), + created_at: thread.created_at, + updated_at: thread.updated_at, + status: to_persisted_status(&thread.status), + path: thread.path.clone(), + cwd: thread.cwd.clone(), + cli_version: thread.cli_version.clone(), + source: to_persisted_source(&thread.source), + name: thread.name.clone(), + sandbox_policy: None, + approval_mode: None, + archived: matches!(thread.status, ThreadStatus::Archived), + archived_at: None, + git_sha: None, + git_branch: None, + git_origin_url: None, + memory_mode: None, + }) + } +} + +pub struct Runtime { + pub config: ConfigToml, + pub model_registry: ModelRegistry, + pub thread_manager: ThreadManager, + pub tool_registry: Arc, + pub mcp_manager: Arc, + pub exec_policy: ExecPolicyEngine, + pub hooks: HookDispatcher, + pub jobs: JobManager, +} + +impl Runtime { + pub fn new( + config: ConfigToml, + model_registry: ModelRegistry, + state: StateStore, + tool_registry: Arc, + mcp_manager: Arc, + exec_policy: ExecPolicyEngine, + hooks: HookDispatcher, + ) -> Self { + let mut jobs = JobManager::default(); + let _ = jobs.load_from_store(&state); + Self { + config, + model_registry, + thread_manager: ThreadManager::new(state), + tool_registry, + mcp_manager, + exec_policy, + hooks, + jobs, + } + } + + fn persisted_thread_data(&self, thread_id: &str) -> Result { + let history = self + .thread_manager + .state_store() + .list_messages(thread_id, Some(500))? + .into_iter() + .map(|message| { + json!({ + "id": message.id, + "role": message.role, + "content": message.content, + "item": message.item, + "created_at": message.created_at + }) + }) + .collect::>(); + + let checkpoint = self + .thread_manager + .state_store() + .load_checkpoint(thread_id, None)? + .map(|record| { + json!({ + "checkpoint_id": record.checkpoint_id, + "state": record.state, + "created_at": record.created_at + }) + }); + + Ok(json!({ + "history": history, + "checkpoint": checkpoint + })) + } + + fn persist_latest_checkpoint(&self, thread_id: &str, reason: &str, state: Value) -> Result<()> { + self.thread_manager.state_store().save_checkpoint( + thread_id, + "latest", + &json!({ + "reason": reason, + "saved_at": chrono::Utc::now().timestamp(), + "state": state + }), + ) + } + + pub async fn handle_thread(&mut self, req: ThreadRequest) -> Result { + match req { + ThreadRequest::Create { .. } => { + let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")); + let new = self.thread_manager.spawn_thread_with_history( + "deepseek".to_string(), + cwd, + InitialHistory::New, + false, + )?; + let mut response = thread_response_from_new("created", new); + response.data = self.persisted_thread_data(&response.thread_id)?; + Ok(response) + } + ThreadRequest::Start(params) => { + let cwd = params.cwd.clone().unwrap_or_else(|| { + std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")) + }); + let new = self.thread_manager.spawn_thread_with_history( + params + .model_provider + .clone() + .unwrap_or_else(|| "deepseek".to_string()), + cwd, + InitialHistory::New, + params.persist_extended_history, + )?; + let mut response = thread_response_from_new("started", new); + response.data = self.persisted_thread_data(&response.thread_id)?; + Ok(response) + } + ThreadRequest::Resume(params) => { + let fallback_cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")); + if let Some(new) = self.thread_manager.resume_thread_with_history( + ¶ms, + &fallback_cwd, + "deepseek".to_string(), + )? { + let mut response = thread_response_from_new("resumed", new); + response.data = self.persisted_thread_data(&response.thread_id)?; + Ok(response) + } else { + Ok(ThreadResponse { + thread_id: params.thread_id, + status: "missing".to_string(), + thread: None, + threads: Vec::new(), + model: None, + model_provider: None, + cwd: None, + approval_policy: params.approval_policy, + sandbox: params.sandbox, + events: Vec::new(), + data: json!({"error":"thread not found"}), + }) + } + } + ThreadRequest::Fork(params) => { + let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")); + if let Some(new) = self.thread_manager.fork_thread(¶ms, &cwd)? { + let mut response = thread_response_from_new("forked", new); + response.data = self.persisted_thread_data(&response.thread_id)?; + Ok(response) + } else { + Ok(ThreadResponse { + thread_id: params.thread_id, + status: "missing".to_string(), + thread: None, + threads: Vec::new(), + model: None, + model_provider: None, + cwd: None, + approval_policy: params.approval_policy, + sandbox: params.sandbox, + events: Vec::new(), + data: json!({"error":"thread not found"}), + }) + } + } + ThreadRequest::List(params) => Ok(ThreadResponse { + thread_id: "list".to_string(), + status: "ok".to_string(), + thread: None, + threads: self.thread_manager.list_threads(¶ms)?, + model: None, + model_provider: None, + cwd: None, + approval_policy: None, + sandbox: None, + events: Vec::new(), + data: json!({}), + }), + ThreadRequest::Read(params) => { + let id = params.thread_id.clone(); + let data = self.persisted_thread_data(&id)?; + Ok(ThreadResponse { + thread_id: id, + status: "ok".to_string(), + thread: self.thread_manager.read_thread(¶ms)?, + threads: Vec::new(), + model: None, + model_provider: None, + cwd: None, + approval_policy: None, + sandbox: None, + events: Vec::new(), + data, + }) + } + ThreadRequest::SetName(params) => Ok(ThreadResponse { + thread_id: params.thread_id.clone(), + status: "ok".to_string(), + thread: self.thread_manager.set_thread_name(¶ms)?, + threads: Vec::new(), + model: None, + model_provider: None, + cwd: None, + approval_policy: None, + sandbox: None, + events: Vec::new(), + data: json!({}), + }), + ThreadRequest::Archive { thread_id } => { + self.thread_manager.archive_thread(&thread_id)?; + Ok(ThreadResponse { + thread_id, + status: "archived".to_string(), + thread: None, + threads: Vec::new(), + model: None, + model_provider: None, + cwd: None, + approval_policy: None, + sandbox: None, + events: Vec::new(), + data: json!({}), + }) + } + ThreadRequest::Unarchive { thread_id } => { + self.thread_manager.unarchive_thread(&thread_id)?; + Ok(ThreadResponse { + thread_id, + status: "unarchived".to_string(), + thread: None, + threads: Vec::new(), + model: None, + model_provider: None, + cwd: None, + approval_policy: None, + sandbox: None, + events: Vec::new(), + data: json!({}), + }) + } + ThreadRequest::Message { thread_id, input } => { + self.thread_manager.touch_message(&thread_id, &input)?; + let response_id = format!("{thread_id}:{}", input.len()); + self.hooks + .emit(HookEvent::ResponseStart { + response_id: response_id.clone(), + }) + .await; + self.hooks + .emit(HookEvent::ResponseEnd { + response_id: response_id.clone(), + }) + .await; + + Ok(ThreadResponse { + thread_id, + status: "accepted".to_string(), + thread: None, + threads: Vec::new(), + model: None, + model_provider: None, + cwd: None, + approval_policy: None, + sandbox: None, + events: vec![ + EventFrame::ResponseStart { + response_id: response_id.clone(), + }, + EventFrame::ResponseDelta { + response_id: response_id.clone(), + delta: "queued".to_string(), + }, + EventFrame::ResponseEnd { response_id }, + ], + data: json!({}), + }) + } + } + } + + pub async fn handle_prompt( + &mut self, + req: PromptRequest, + cli_overrides: &CliRuntimeOverrides, + ) -> Result { + let resolved = self.config.resolve_runtime_options(cli_overrides); + let requested_model = req.model.clone().unwrap_or_else(|| resolved.model.clone()); + let selection = self + .model_registry + .resolve(Some(&requested_model), Some(resolved.provider)); + let resolved_model = selection.resolved.id.clone(); + let response_id = format!("resp-{}", Uuid::new_v4()); + + self.hooks + .emit(HookEvent::ResponseStart { + response_id: response_id.clone(), + }) + .await; + self.hooks + .emit(HookEvent::ResponseDelta { + response_id: response_id.clone(), + delta: "model-selected".to_string(), + }) + .await; + self.hooks + .emit(HookEvent::ResponseEnd { + response_id: response_id.clone(), + }) + .await; + + let payload = json!({ + "provider": resolved.provider.as_str(), + "model": resolved_model.clone(), + "prompt": req.prompt, + "telemetry": resolved.telemetry, + "base_url": resolved.base_url, + "has_api_key": resolved.api_key.as_ref().is_some_and(|k| !k.trim().is_empty()), + "approval_policy": resolved.approval_policy, + "sandbox_mode": resolved.sandbox_mode + }); + if let Some(thread_id) = req.thread_id.as_ref() { + self.thread_manager.touch_message(thread_id, &req.prompt)?; + let assistant_message_id = self.thread_manager.store.append_message( + thread_id, + "assistant", + &payload.to_string(), + Some(payload.clone()), + )?; + self.persist_latest_checkpoint( + thread_id, + "prompt_response", + json!({ + "response_id": response_id.clone(), + "model": resolved_model.clone(), + "provider": resolved.provider.as_str(), + "assistant_message_id": assistant_message_id + }), + )?; + } + + Ok(PromptResponse { + output: payload.to_string(), + model: resolved_model, + events: vec![ + EventFrame::ResponseStart { + response_id: response_id.clone(), + }, + EventFrame::ResponseDelta { + response_id: response_id.clone(), + delta: "model-selected".to_string(), + }, + EventFrame::ResponseEnd { response_id }, + ], + }) + } + + pub async fn invoke_tool( + &self, + call: ToolCall, + approval_mode: AskForApproval, + cwd: &Path, + ) -> Result { + let fallback_cwd = cwd.display().to_string(); + let (command, policy_cwd, execution_kind) = call.execution_subject(&fallback_cwd); + let decision = self.exec_policy.check(ExecPolicyContext { + command: &command, + cwd: &policy_cwd, + ask_for_approval: approval_mode, + sandbox_mode: None, + })?; + let precheck = policy_precheck_payload(&decision, &command, &policy_cwd, execution_kind); + let response_id = format!("tool-{}", Uuid::new_v4()); + let call_id = call + .raw_tool_call_id + .clone() + .unwrap_or_else(|| format!("tool-call-{}", Uuid::new_v4())); + self.hooks + .emit(HookEvent::ToolLifecycle { + response_id: response_id.clone(), + tool_name: call.name.clone(), + phase: "precheck".to_string(), + payload: precheck.clone(), + }) + .await; + + if !decision.allow { + let reason = decision.reason().to_string(); + let approval_id = format!("approval-{}", Uuid::new_v4()); + let error_frame = EventFrame::Error { + response_id: response_id.clone(), + message: reason.clone(), + }; + self.hooks + .emit(HookEvent::ApprovalLifecycle { + approval_id, + phase: "denied".to_string(), + reason: Some(reason.clone()), + }) + .await; + self.hooks + .emit(HookEvent::GenericEventFrame { + frame: error_frame.clone(), + }) + .await; + return Ok(json!({ + "ok": false, + "status": "denied", + "execution_kind": execution_kind, + "response_id": response_id, + "precheck": precheck, + "error": reason, + "events": [event_frame_payload(&error_frame)], + })); + } + + if decision.requires_approval { + let approval_id = format!("approval-{}", Uuid::new_v4()); + let reason = decision.reason().to_string(); + let maybe_approval_frame = approval_request_frame( + &decision.requirement, + call_id, + approval_id.clone(), + response_id.clone(), + command.clone(), + policy_cwd.clone(), + ); + self.hooks + .emit(HookEvent::ApprovalLifecycle { + approval_id: approval_id.clone(), + phase: "requested".to_string(), + reason: Some(reason.clone()), + }) + .await; + let mut events = Vec::new(); + if let Some(frame) = maybe_approval_frame { + self.hooks + .emit(HookEvent::GenericEventFrame { + frame: frame.clone(), + }) + .await; + events.push(event_frame_payload(&frame)); + } + return Ok(json!({ + "ok": false, + "status": "approval_required", + "execution_kind": execution_kind, + "response_id": response_id, + "approval_id": approval_id, + "precheck": precheck, + "error": reason, + "events": events, + })); + } + + let start_frame = EventFrame::ToolCallStart { + response_id: response_id.clone(), + tool_name: call.name.clone(), + arguments: tool_payload_value(&call.payload), + }; + self.hooks + .emit(HookEvent::GenericEventFrame { + frame: start_frame.clone(), + }) + .await; + self.hooks + .emit(HookEvent::ToolLifecycle { + response_id: response_id.clone(), + tool_name: call.name.clone(), + phase: "dispatching".to_string(), + payload: json!({ + "call_id": call_id, + "execution_kind": execution_kind + }), + }) + .await; + + match self.tool_registry.dispatch(call.clone(), true).await { + Ok(tool_output) => { + let result_frame = EventFrame::ToolCallResult { + response_id: response_id.clone(), + tool_name: call.name.clone(), + output: tool_output_value(&tool_output), + }; + self.hooks + .emit(HookEvent::GenericEventFrame { + frame: result_frame.clone(), + }) + .await; + self.hooks + .emit(HookEvent::ToolLifecycle { + response_id: response_id.clone(), + tool_name: call.name, + phase: "completed".to_string(), + payload: json!({ "ok": true }), + }) + .await; + Ok(json!({ + "ok": true, + "status": "completed", + "execution_kind": execution_kind, + "response_id": response_id, + "precheck": precheck, + "output": tool_output, + "events": [ + event_frame_payload(&start_frame), + event_frame_payload(&result_frame) + ] + })) + } + Err(err) => { + let message = format!("{err:?}"); + let error_frame = EventFrame::Error { + response_id: response_id.clone(), + message: message.clone(), + }; + self.hooks + .emit(HookEvent::GenericEventFrame { + frame: error_frame.clone(), + }) + .await; + self.hooks + .emit(HookEvent::ToolLifecycle { + response_id: response_id.clone(), + tool_name: call.name, + phase: "failed".to_string(), + payload: json!({ "error": message.clone() }), + }) + .await; + Ok(json!({ + "ok": false, + "status": "failed", + "execution_kind": execution_kind, + "response_id": response_id, + "precheck": precheck, + "error": message, + "events": [ + event_frame_payload(&start_frame), + event_frame_payload(&error_frame) + ] + })) + } + } + } + + pub async fn mcp_startup(&self) -> McpStartupCompleteEvent { + let mut updates = Vec::new(); + let summary = self.mcp_manager.start_all(|update| { + updates.push(update); + }); + for update in updates { + let status = match update.status { + McpManagerStartupStatus::Starting => deepseek_protocol::McpStartupStatus::Starting, + McpManagerStartupStatus::Ready => deepseek_protocol::McpStartupStatus::Ready, + McpManagerStartupStatus::Failed { error } => { + deepseek_protocol::McpStartupStatus::Failed { error } + } + McpManagerStartupStatus::Cancelled => { + deepseek_protocol::McpStartupStatus::Cancelled + } + }; + self.hooks + .emit(HookEvent::GenericEventFrame { + frame: EventFrame::McpStartupUpdate { + update: deepseek_protocol::McpStartupUpdateEvent { + server_name: update.server_name, + status, + }, + }, + }) + .await; + } + self.hooks + .emit(HookEvent::GenericEventFrame { + frame: EventFrame::McpStartupComplete { + summary: deepseek_protocol::McpStartupCompleteEvent { + ready: summary.ready.clone(), + failed: summary + .failed + .iter() + .map(|f| deepseek_protocol::McpStartupFailure { + server_name: f.server_name.clone(), + error: f.error.clone(), + }) + .collect(), + cancelled: summary.cancelled.clone(), + }, + }, + }) + .await; + summary + } + + pub fn app_status(&self) -> AppResponse { + let jobs = self.jobs.list(); + let events = jobs + .iter() + .flat_map(|job| { + job.history.iter().map(|entry| EventFrame::ResponseDelta { + response_id: job.id.clone(), + delta: json!({ + "kind": "job_transition", + "job_id": job.id.clone(), + "phase": entry.phase.clone(), + "status": job_status_to_str(entry.status), + "progress": entry.progress, + "detail": entry.detail.clone(), + "retry": job_retry_to_value(&entry.retry), + "at": entry.at + }) + .to_string(), + }) + }) + .collect::>(); + AppResponse { + ok: true, + data: json!({ + "jobs": jobs.into_iter().map(|job| { + json!({ + "id": job.id, + "name": job.name, + "status": job_status_to_str(job.status), + "progress": job.progress, + "detail": job.detail, + "retry": job_retry_to_value(&job.retry), + "history": job.history.iter().map(job_history_to_value).collect::>() + }) + }).collect::>() + }), + events, + } + } + + pub fn provider_default(&self) -> ProviderKind { + self.config.provider + } + + pub fn save_thread_checkpoint( + &self, + thread_id: &str, + checkpoint_id: &str, + state: &Value, + ) -> Result<()> { + self.thread_manager + .state_store() + .save_checkpoint(thread_id, checkpoint_id, state) + } + + pub fn load_thread_checkpoint( + &self, + thread_id: &str, + checkpoint_id: Option<&str>, + ) -> Result> { + Ok(self + .thread_manager + .state_store() + .load_checkpoint(thread_id, checkpoint_id)? + .map(|checkpoint| checkpoint.state)) + } + + pub fn enqueue_job(&mut self, name: impl Into) -> Result { + let job = self.jobs.enqueue(name); + self.jobs + .persist_job(self.thread_manager.state_store(), &job.id)?; + Ok(job) + } + + pub fn set_job_running(&mut self, job_id: &str) -> Result<()> { + self.jobs.set_running(job_id); + self.jobs + .persist_job(self.thread_manager.state_store(), job_id) + } + + pub fn update_job_progress( + &mut self, + job_id: &str, + progress: u8, + detail: Option, + ) -> Result<()> { + self.jobs.update_progress(job_id, progress, detail); + self.jobs + .persist_job(self.thread_manager.state_store(), job_id) + } + + pub fn complete_job(&mut self, job_id: &str) -> Result<()> { + self.jobs.complete(job_id); + self.jobs + .persist_job(self.thread_manager.state_store(), job_id) + } + + pub fn fail_job(&mut self, job_id: &str, detail: impl Into) -> Result<()> { + self.jobs.fail(job_id, detail); + self.jobs + .persist_job(self.thread_manager.state_store(), job_id) + } + + pub fn cancel_job(&mut self, job_id: &str) -> Result<()> { + self.jobs.cancel(job_id); + self.jobs + .persist_job(self.thread_manager.state_store(), job_id) + } + + pub fn pause_job(&mut self, job_id: &str, detail: Option) -> Result<()> { + self.jobs.pause(job_id, detail); + self.jobs + .persist_job(self.thread_manager.state_store(), job_id) + } + + pub fn resume_job(&mut self, job_id: &str, detail: Option) -> Result<()> { + self.jobs.resume(job_id, detail); + self.jobs + .persist_job(self.thread_manager.state_store(), job_id) + } + + pub fn job_history(&self, job_id: &str) -> Vec { + self.jobs.history(job_id) + } +} + +fn thread_response_from_new(status: &str, new: NewThread) -> ThreadResponse { + ThreadResponse { + thread_id: new.thread.id.clone(), + status: status.to_string(), + thread: Some(new.thread), + threads: Vec::new(), + model: Some(new.model), + model_provider: Some(new.model_provider), + cwd: Some(new.cwd), + approval_policy: new.approval_policy, + sandbox: new.sandbox, + events: Vec::new(), + data: json!({}), + } +} + +fn preview_from_initial_history(initial_history: &InitialHistory) -> String { + match initial_history { + InitialHistory::New => "New conversation".to_string(), + InitialHistory::Forked(items) => truncate_preview( + &items + .first() + .map(Value::to_string) + .unwrap_or_else(|| "Forked conversation".to_string()), + ), + InitialHistory::Resumed { history, .. } => truncate_preview( + &history + .first() + .map(Value::to_string) + .unwrap_or_else(|| "Resumed conversation".to_string()), + ), + } +} + +fn truncate_preview(value: &str) -> String { + value.chars().take(120).collect() +} + +fn to_protocol_thread(thread: ThreadMetadata) -> Thread { + Thread { + id: thread.id, + preview: thread.preview, + ephemeral: thread.ephemeral, + model_provider: thread.model_provider, + created_at: thread.created_at, + updated_at: thread.updated_at, + status: match thread.status { + PersistedThreadStatus::Running => ThreadStatus::Running, + PersistedThreadStatus::Idle => ThreadStatus::Idle, + PersistedThreadStatus::Completed => ThreadStatus::Completed, + PersistedThreadStatus::Failed => ThreadStatus::Failed, + PersistedThreadStatus::Paused => ThreadStatus::Paused, + PersistedThreadStatus::Archived => ThreadStatus::Archived, + }, + path: thread.path, + cwd: thread.cwd, + cli_version: thread.cli_version, + source: match thread.source { + SessionSource::Interactive => deepseek_protocol::SessionSource::Interactive, + SessionSource::Resume => deepseek_protocol::SessionSource::Resume, + SessionSource::Fork => deepseek_protocol::SessionSource::Fork, + SessionSource::Api => deepseek_protocol::SessionSource::Api, + SessionSource::Unknown => deepseek_protocol::SessionSource::Unknown, + }, + name: thread.name, + } +} + +fn to_persisted_status(status: &ThreadStatus) -> PersistedThreadStatus { + match status { + ThreadStatus::Running => PersistedThreadStatus::Running, + ThreadStatus::Idle => PersistedThreadStatus::Idle, + ThreadStatus::Completed => PersistedThreadStatus::Completed, + ThreadStatus::Failed => PersistedThreadStatus::Failed, + ThreadStatus::Paused => PersistedThreadStatus::Paused, + ThreadStatus::Archived => PersistedThreadStatus::Archived, + } +} + +fn to_persisted_source(source: &deepseek_protocol::SessionSource) -> SessionSource { + match source { + deepseek_protocol::SessionSource::Interactive => SessionSource::Interactive, + deepseek_protocol::SessionSource::Resume => SessionSource::Resume, + deepseek_protocol::SessionSource::Fork => SessionSource::Fork, + deepseek_protocol::SessionSource::Api => SessionSource::Api, + deepseek_protocol::SessionSource::Unknown => SessionSource::Unknown, + } +} + +fn approval_request_frame( + requirement: &ExecApprovalRequirement, + call_id: String, + approval_id: String, + turn_id: String, + command: String, + cwd: String, +) -> Option { + let ExecApprovalRequirement::NeedsApproval { + reason, + proposed_execpolicy_amendment, + proposed_network_policy_amendments, + } = requirement + else { + return None; + }; + + let mut available_decisions = vec![ + ReviewDecision::Approved, + ReviewDecision::ApprovedForSession, + ReviewDecision::Denied, + ReviewDecision::Abort, + ]; + if proposed_execpolicy_amendment + .as_ref() + .is_some_and(|amendment| !amendment.prefixes.is_empty()) + { + available_decisions.push(ReviewDecision::ApprovedExecpolicyAmendment); + } + available_decisions.extend(proposed_network_policy_amendments.iter().cloned().map( + |amendment| ReviewDecision::NetworkPolicyAmendment { + host: amendment.host, + action: amendment.action, + }, + )); + + Some(EventFrame::ExecApprovalRequest { + request: ExecApprovalRequestEvent { + call_id, + approval_id, + turn_id, + command, + cwd, + reason: reason.clone(), + network_approval_context: None, + proposed_execpolicy_amendment: proposed_execpolicy_amendment + .as_ref() + .map(|amendment| amendment.prefixes.clone()) + .unwrap_or_default(), + proposed_network_policy_amendments: proposed_network_policy_amendments.clone(), + additional_permissions: Vec::new(), + available_decisions, + }, + }) +} + +fn approval_requirement_payload(requirement: &ExecApprovalRequirement) -> Value { + match requirement { + ExecApprovalRequirement::Skip { + bypass_sandbox, + proposed_execpolicy_amendment, + } => json!({ + "type": "skip", + "bypass_sandbox": bypass_sandbox, + "reason": requirement.reason(), + "proposed_execpolicy_amendment": proposed_execpolicy_amendment + .as_ref() + .map(|amendment| amendment.prefixes.clone()) + .unwrap_or_default() + }), + ExecApprovalRequirement::NeedsApproval { + reason, + proposed_execpolicy_amendment, + proposed_network_policy_amendments, + } => json!({ + "type": "needs_approval", + "reason": reason, + "proposed_execpolicy_amendment": proposed_execpolicy_amendment + .as_ref() + .map(|amendment| amendment.prefixes.clone()) + .unwrap_or_default(), + "proposed_network_policy_amendments": proposed_network_policy_amendments + }), + ExecApprovalRequirement::Forbidden { reason } => json!({ + "type": "forbidden", + "reason": reason + }), + } +} + +fn policy_precheck_payload( + decision: &ExecPolicyDecision, + command: &str, + cwd: &str, + execution_kind: &str, +) -> Value { + json!({ + "execution_kind": execution_kind, + "command": command, + "cwd": cwd, + "allow": decision.allow, + "requires_approval": decision.requires_approval, + "matched_rule": decision.matched_rule.clone(), + "phase": decision.requirement.phase(), + "reason": decision.reason(), + "requirement": approval_requirement_payload(&decision.requirement) + }) +} + +fn tool_payload_value(payload: &ToolPayload) -> Value { + serde_json::to_value(payload).unwrap_or_else( + |_| json!({"type":"serialization_error","message":"tool payload unavailable"}), + ) +} + +fn tool_output_value(output: &deepseek_protocol::ToolOutput) -> Value { + serde_json::to_value(output).unwrap_or_else( + |_| json!({"type":"serialization_error","message":"tool output unavailable"}), + ) +} + +fn event_frame_payload(frame: &EventFrame) -> Value { + serde_json::to_value(frame) + .unwrap_or_else(|_| json!({"event":"error","message":"failed to encode event frame"})) +} + +fn json_optional_string(value: &Value) -> Option { + if value.is_null() { + None + } else { + value.as_str().map(ToString::to_string) + } +} + +fn parse_retry_metadata(value: Option<&Value>) -> JobRetryMetadata { + let Some(value) = value else { + return JobRetryMetadata::default(); + }; + JobRetryMetadata { + attempt: value + .get("attempt") + .and_then(Value::as_u64) + .unwrap_or(0) + .min(u32::MAX as u64) as u32, + max_attempts: value + .get("max_attempts") + .and_then(Value::as_u64) + .unwrap_or(DEFAULT_JOB_MAX_ATTEMPTS as u64) + .min(u32::MAX as u64) as u32, + backoff_base_ms: value + .get("backoff_base_ms") + .and_then(Value::as_u64) + .unwrap_or(DEFAULT_JOB_BACKOFF_BASE_MS), + next_backoff_ms: value + .get("next_backoff_ms") + .and_then(Value::as_u64) + .unwrap_or(0), + next_retry_at: value.get("next_retry_at").and_then(Value::as_i64), + } +} + +fn parse_history_entry(value: &Value) -> Option { + let status = value + .get("status") + .and_then(Value::as_str) + .and_then(job_status_from_str)?; + Some(JobHistoryEntry { + at: value.get("at").and_then(Value::as_i64).unwrap_or(0), + phase: value + .get("phase") + .and_then(Value::as_str) + .unwrap_or("unknown") + .to_string(), + status, + progress: value + .get("progress") + .and_then(Value::as_u64) + .map(|v| v.min(u8::MAX as u64) as u8), + detail: value.get("detail").and_then(json_optional_string), + retry: parse_retry_metadata(value.get("retry")), + }) +} + +fn job_status_to_str(status: JobStatus) -> &'static str { + match status { + JobStatus::Queued => "queued", + JobStatus::Running => "running", + JobStatus::Paused => "paused", + JobStatus::Completed => "completed", + JobStatus::Failed => "failed", + JobStatus::Cancelled => "cancelled", + } +} + +fn job_status_from_str(value: &str) -> Option { + match value { + "queued" => Some(JobStatus::Queued), + "running" => Some(JobStatus::Running), + "paused" => Some(JobStatus::Paused), + "completed" => Some(JobStatus::Completed), + "failed" => Some(JobStatus::Failed), + "cancelled" => Some(JobStatus::Cancelled), + _ => None, + } +} + +fn job_retry_to_value(retry: &JobRetryMetadata) -> Value { + json!({ + "attempt": retry.attempt, + "max_attempts": retry.max_attempts, + "backoff_base_ms": retry.backoff_base_ms, + "next_backoff_ms": retry.next_backoff_ms, + "next_retry_at": retry.next_retry_at + }) +} + +fn job_history_to_value(entry: &JobHistoryEntry) -> Value { + json!({ + "at": entry.at, + "phase": entry.phase.clone(), + "status": job_status_to_str(entry.status), + "progress": entry.progress, + "detail": entry.detail.clone(), + "retry": job_retry_to_value(&entry.retry) + }) +} + +fn runtime_status_to_job_state(status: JobStatus) -> JobStateStatus { + match status { + JobStatus::Queued => JobStateStatus::Queued, + JobStatus::Running => JobStateStatus::Running, + JobStatus::Paused => JobStateStatus::Running, + JobStatus::Completed => JobStateStatus::Completed, + JobStatus::Failed => JobStateStatus::Failed, + JobStatus::Cancelled => JobStateStatus::Cancelled, + } +} + +fn job_state_status_to_runtime(status: JobStateStatus) -> JobStatus { + match status { + JobStateStatus::Queued => JobStatus::Queued, + JobStateStatus::Running => JobStatus::Running, + JobStateStatus::Completed => JobStatus::Completed, + JobStateStatus::Failed => JobStatus::Failed, + JobStateStatus::Cancelled => JobStatus::Cancelled, + } +} diff --git a/crates/execpolicy/Cargo.toml b/crates/execpolicy/Cargo.toml new file mode 100644 index 00000000..2015636a --- /dev/null +++ b/crates/execpolicy/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "deepseek-execpolicy" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "Execution policy and approval model parity for DeepSeek workspace architecture" + +[dependencies] +anyhow.workspace = true +deepseek-protocol = { path = "../protocol" } +serde.workspace = true diff --git a/crates/execpolicy/src/lib.rs b/crates/execpolicy/src/lib.rs new file mode 100644 index 00000000..608cdc8a --- /dev/null +++ b/crates/execpolicy/src/lib.rs @@ -0,0 +1,191 @@ +use std::collections::HashSet; + +use anyhow::Result; +use deepseek_protocol::{NetworkPolicyAmendment, NetworkPolicyRuleAction}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum AskForApproval { + UnlessTrusted, + OnFailure, + OnRequest, + Reject { + sandbox_approval: bool, + rules: bool, + mcp_elicitations: bool, + }, + Never, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct ExecPolicyAmendment { + pub prefixes: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum ExecApprovalRequirement { + Skip { + bypass_sandbox: bool, + proposed_execpolicy_amendment: Option, + }, + NeedsApproval { + reason: String, + proposed_execpolicy_amendment: Option, + proposed_network_policy_amendments: Vec, + }, + Forbidden { + reason: String, + }, +} + +impl ExecApprovalRequirement { + pub fn reason(&self) -> &str { + match self { + ExecApprovalRequirement::Skip { .. } => "Execution allowed by policy.", + ExecApprovalRequirement::NeedsApproval { reason, .. } => reason, + ExecApprovalRequirement::Forbidden { reason } => reason, + } + } + + pub fn phase(&self) -> &'static str { + match self { + ExecApprovalRequirement::Skip { .. } => "allowed", + ExecApprovalRequirement::NeedsApproval { .. } => "needs_approval", + ExecApprovalRequirement::Forbidden { .. } => "forbidden", + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct ExecPolicyDecision { + pub allow: bool, + pub requires_approval: bool, + pub requirement: ExecApprovalRequirement, + pub matched_rule: Option, +} + +impl ExecPolicyDecision { + pub fn reason(&self) -> &str { + self.requirement.reason() + } +} + +#[derive(Debug, Clone)] +pub struct ExecPolicyContext<'a> { + pub command: &'a str, + pub cwd: &'a str, + pub ask_for_approval: AskForApproval, + pub sandbox_mode: Option<&'a str>, +} + +#[derive(Debug, Clone, Default)] +pub struct ExecPolicyEngine { + trusted_prefixes: Vec, + denied_prefixes: Vec, + approved_for_session: HashSet, +} + +impl ExecPolicyEngine { + pub fn new(trusted_prefixes: Vec, denied_prefixes: Vec) -> Self { + Self { + trusted_prefixes, + denied_prefixes, + approved_for_session: HashSet::new(), + } + } + + pub fn remember_session_approval(&mut self, approval_key: String) { + self.approved_for_session.insert(approval_key); + } + + pub fn is_session_approved(&self, approval_key: &str) -> bool { + self.approved_for_session.contains(approval_key) + } + + pub fn check(&self, ctx: ExecPolicyContext<'_>) -> Result { + let normalized = normalize_command(ctx.command); + if let Some(rule) = self + .denied_prefixes + .iter() + .find(|rule| normalized.starts_with(&normalize_command(rule))) + { + return Ok(ExecPolicyDecision { + allow: false, + requires_approval: false, + matched_rule: Some(rule.clone()), + requirement: ExecApprovalRequirement::Forbidden { + reason: format!("Command blocked by denied prefix rule '{rule}'"), + }, + }); + } + + let trusted_rule = self + .trusted_prefixes + .iter() + .find(|rule| normalized.starts_with(&normalize_command(rule))) + .cloned(); + let is_trusted = trusted_rule.is_some(); + + let requirement = match ctx.ask_for_approval { + AskForApproval::Never => ExecApprovalRequirement::Skip { + bypass_sandbox: false, + proposed_execpolicy_amendment: None, + }, + AskForApproval::UnlessTrusted if is_trusted => ExecApprovalRequirement::Skip { + bypass_sandbox: false, + proposed_execpolicy_amendment: None, + }, + AskForApproval::OnFailure => ExecApprovalRequirement::Skip { + bypass_sandbox: false, + proposed_execpolicy_amendment: None, + }, + AskForApproval::Reject { rules, .. } if rules => ExecApprovalRequirement::Forbidden { + reason: "Policy is configured to reject rule-exceptions.".to_string(), + }, + _ => ExecApprovalRequirement::NeedsApproval { + reason: if is_trusted { + "Approval requested by policy mode.".to_string() + } else { + "Unmatched command prefix requires approval.".to_string() + }, + proposed_execpolicy_amendment: if is_trusted { + None + } else { + Some(ExecPolicyAmendment { + prefixes: vec![first_token(ctx.command)], + }) + }, + proposed_network_policy_amendments: vec![NetworkPolicyAmendment { + host: ctx.cwd.to_string(), + action: NetworkPolicyRuleAction::Allow, + }], + }, + }; + + let (allow, requires_approval) = match requirement { + ExecApprovalRequirement::Skip { .. } => (true, false), + ExecApprovalRequirement::NeedsApproval { .. } => (true, true), + ExecApprovalRequirement::Forbidden { .. } => (false, false), + }; + + Ok(ExecPolicyDecision { + allow, + requires_approval, + matched_rule: trusted_rule, + requirement, + }) + } +} + +fn normalize_command(value: &str) -> String { + value.trim().to_ascii_lowercase() +} + +fn first_token(command: &str) -> String { + command + .split_whitespace() + .next() + .unwrap_or_default() + .to_string() +} diff --git a/crates/hooks/Cargo.toml b/crates/hooks/Cargo.toml new file mode 100644 index 00000000..8a0fdbb8 --- /dev/null +++ b/crates/hooks/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "deepseek-hooks" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "Hook dispatch and notifications parity for DeepSeek workspace architecture" + +[dependencies] +anyhow.workspace = true +async-trait.workspace = true +chrono.workspace = true +deepseek-protocol = { path = "../protocol" } +reqwest.workspace = true +serde.workspace = true +serde_json.workspace = true +tokio.workspace = true diff --git a/crates/hooks/src/lib.rs b/crates/hooks/src/lib.rs new file mode 100644 index 00000000..2abe5fb3 --- /dev/null +++ b/crates/hooks/src/lib.rs @@ -0,0 +1,170 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use anyhow::{Context, Result}; +use async_trait::async_trait; +use chrono::Utc; +use deepseek_protocol::EventFrame; +use serde::{Deserialize, Serialize}; +use serde_json::{Value, json}; +use tokio::io::AsyncWriteExt; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum HookEvent { + ResponseStart { + response_id: String, + }, + ResponseDelta { + response_id: String, + delta: String, + }, + ResponseEnd { + response_id: String, + }, + ToolLifecycle { + response_id: String, + tool_name: String, + phase: String, + payload: Value, + }, + JobLifecycle { + job_id: String, + phase: String, + progress: Option, + detail: Option, + }, + ApprovalLifecycle { + approval_id: String, + phase: String, + reason: Option, + }, + GenericEventFrame { + frame: EventFrame, + }, +} + +impl HookEvent { + pub fn to_json(&self) -> Value { + serde_json::to_value(self).unwrap_or_else(|_| json!({"type":"serialization_error"})) + } +} + +#[async_trait] +pub trait HookSink: Send + Sync { + async fn emit(&self, event: &HookEvent) -> Result<()>; +} + +#[derive(Default)] +pub struct StdoutHookSink; + +#[async_trait] +impl HookSink for StdoutHookSink { + async fn emit(&self, event: &HookEvent) -> Result<()> { + println!("{}", event.to_json()); + Ok(()) + } +} + +pub struct JsonlHookSink { + path: PathBuf, +} + +impl JsonlHookSink { + pub fn new(path: PathBuf) -> Self { + Self { path } + } +} + +#[async_trait] +impl HookSink for JsonlHookSink { + async fn emit(&self, event: &HookEvent) -> Result<()> { + if let Some(parent) = self.path.parent() { + tokio::fs::create_dir_all(parent).await.with_context(|| { + format!("failed to create hook log directory {}", parent.display()) + })?; + } + let mut file = tokio::fs::OpenOptions::new() + .create(true) + .append(true) + .open(&self.path) + .await + .with_context(|| format!("failed to open hook log {}", self.path.display()))?; + let payload = json!({ + "at": Utc::now().to_rfc3339(), + "event": event + }); + let encoded = serde_json::to_string(&payload).context("failed to encode hook event")?; + file.write_all(encoded.as_bytes()) + .await + .context("failed to write hook event")?; + file.write_all(b"\n") + .await + .context("failed to write hook event newline")?; + Ok(()) + } +} + +pub struct WebhookHookSink { + url: String, + client: reqwest::Client, +} + +impl WebhookHookSink { + pub fn new(url: String) -> Self { + Self { + url, + client: reqwest::Client::new(), + } + } +} + +#[async_trait] +impl HookSink for WebhookHookSink { + async fn emit(&self, event: &HookEvent) -> Result<()> { + let mut retries = 0usize; + loop { + let resp = self + .client + .post(&self.url) + .json(&json!({ + "at": Utc::now().to_rfc3339(), + "event": event, + })) + .send() + .await; + match resp { + Ok(response) if response.status().is_success() => return Ok(()), + Ok(response) => { + if retries >= 2 { + anyhow::bail!("webhook returned non-success status {}", response.status()); + } + } + Err(err) => { + if retries >= 2 { + return Err(err).context("webhook request failed"); + } + } + } + retries += 1; + tokio::time::sleep(std::time::Duration::from_millis(200 * retries as u64)).await; + } + } +} + +#[derive(Default, Clone)] +pub struct HookDispatcher { + sinks: Vec>, +} + +impl HookDispatcher { + pub fn add_sink(&mut self, sink: Arc) { + self.sinks.push(sink); + } + + pub async fn emit(&self, event: HookEvent) { + for sink in &self.sinks { + let _ = sink.emit(&event).await; + } + } +} diff --git a/crates/mcp/Cargo.toml b/crates/mcp/Cargo.toml new file mode 100644 index 00000000..2af8a0e0 --- /dev/null +++ b/crates/mcp/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "deepseek-mcp" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "MCP server lifecycle and tool proxy compatibility for DeepSeek workspace architecture" + +[dependencies] +anyhow.workspace = true +deepseek-protocol = { path = "../protocol" } +serde.workspace = true +serde_json.workspace = true diff --git a/crates/mcp/src/lib.rs b/crates/mcp/src/lib.rs new file mode 100644 index 00000000..f7392760 --- /dev/null +++ b/crates/mcp/src/lib.rs @@ -0,0 +1,893 @@ +use std::collections::HashMap; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; + +use anyhow::{Context, Result, bail}; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use serde_json::{Value, json}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpServerConfig { + pub name: String, + pub command: String, + #[serde(default)] + pub args: Vec, + #[serde(default)] + pub env: HashMap, + #[serde(default = "default_true")] + pub enabled: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ToolFilter { + #[serde(default)] + pub allow: Vec, + #[serde(default)] + pub deny: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpServerDefinition { + pub config: McpServerConfig, + #[serde(default)] + pub filter: ToolFilter, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum McpStartupStatus { + Starting, + Ready, + Failed { error: String }, + Cancelled, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpStartupUpdateEvent { + pub server_name: String, + pub status: McpStartupStatus, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpStartupFailure { + pub server_name: String, + pub error: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpStartupCompleteEvent { + pub ready: Vec, + pub failed: Vec, + pub cancelled: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpToolDescriptor { + pub server_name: String, + pub tool_name: String, + pub qualified_name: String, + pub description: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpResourceDescriptor { + pub server_name: String, + pub uri: String, + pub description: Option, +} + +pub trait McpManagedClient: Send + Sync { + fn list_tools(&self) -> Result>; + fn call_tool(&self, tool_name: &str, arguments: Value) -> Result; + fn list_resources(&self) -> Result>; + fn read_resource(&self, uri: &str) -> Result; +} + +#[derive(Debug, Default)] +pub struct InMemoryMcpClient { + tools: HashMap, + resources: HashMap, +} + +impl InMemoryMcpClient { + pub fn with_tool(mut self, name: &str, sample_result: Value) -> Self { + self.tools.insert(name.to_string(), sample_result); + self + } + + pub fn with_resource(mut self, uri: &str, data: Value) -> Self { + self.resources.insert(uri.to_string(), data); + self + } +} + +impl McpManagedClient for InMemoryMcpClient { + fn list_tools(&self) -> Result> { + Ok(self + .tools + .keys() + .map(|name| McpToolDescriptor { + server_name: "in-memory".to_string(), + tool_name: name.clone(), + qualified_name: name.clone(), + description: None, + }) + .collect()) + } + + fn call_tool(&self, tool_name: &str, _arguments: Value) -> Result { + self.tools + .get(tool_name) + .cloned() + .with_context(|| format!("tool '{tool_name}' not found")) + } + + fn list_resources(&self) -> Result> { + Ok(self + .resources + .keys() + .map(|uri| McpResourceDescriptor { + server_name: "in-memory".to_string(), + uri: uri.clone(), + description: None, + }) + .collect()) + } + + fn read_resource(&self, uri: &str) -> Result { + self.resources + .get(uri) + .cloned() + .with_context(|| format!("resource '{uri}' not found")) + } +} + +#[derive(Default)] +pub struct McpManager { + configs: HashMap, + clients: HashMap>, +} + +impl McpManager { + pub fn register_server( + &mut self, + config: McpServerConfig, + filter: ToolFilter, + client: Box, + ) { + self.clients.insert(config.name.clone(), client); + self.configs.insert(config.name.clone(), (config, filter)); + } + + pub fn start_all(&self, mut emit: F) -> McpStartupCompleteEvent + where + F: FnMut(McpStartupUpdateEvent), + { + let mut ready = Vec::new(); + let mut failed = Vec::new(); + let mut cancelled = Vec::new(); + for (server_name, (cfg, _)) in &self.configs { + if !cfg.enabled { + emit(McpStartupUpdateEvent { + server_name: server_name.clone(), + status: McpStartupStatus::Cancelled, + }); + cancelled.push(server_name.clone()); + continue; + } + emit(McpStartupUpdateEvent { + server_name: server_name.clone(), + status: McpStartupStatus::Starting, + }); + if self.clients.contains_key(server_name) { + emit(McpStartupUpdateEvent { + server_name: server_name.clone(), + status: McpStartupStatus::Ready, + }); + ready.push(server_name.clone()); + } else { + let error = "client not registered".to_string(); + emit(McpStartupUpdateEvent { + server_name: server_name.clone(), + status: McpStartupStatus::Failed { + error: error.clone(), + }, + }); + failed.push(McpStartupFailure { + server_name: server_name.clone(), + error, + }); + } + } + McpStartupCompleteEvent { + ready, + failed, + cancelled, + } + } + + pub fn stop_server(&mut self, server_name: &str) -> Result<()> { + self.clients + .remove(server_name) + .with_context(|| format!("server '{server_name}' is not running"))?; + Ok(()) + } + + pub fn unregister_server(&mut self, server_name: &str) -> Result<()> { + let had_config = self.configs.remove(server_name).is_some(); + self.clients.remove(server_name); + if !had_config { + bail!("server '{server_name}' is not registered"); + } + Ok(()) + } + + pub fn list_tools(&self) -> Result> { + let mut out = Vec::new(); + for (server_name, (_, filter)) in &self.configs { + let Some(client) = self.clients.get(server_name) else { + continue; + }; + let tools = client.list_tools()?; + for tool in tools { + if !allowed_by_filter(&tool.tool_name, filter) { + continue; + } + let qualified_name = qualify_tool_name(server_name, &tool.tool_name); + out.push(McpToolDescriptor { + server_name: server_name.clone(), + tool_name: tool.tool_name, + qualified_name, + description: tool.description, + }); + } + } + Ok(out) + } + + pub fn call_tool(&self, server_name: &str, tool_name: &str, arguments: Value) -> Result { + let client = self + .clients + .get(server_name) + .with_context(|| format!("MCP server '{server_name}' not available"))?; + client.call_tool(tool_name, arguments) + } + + pub fn call_qualified_tool( + &self, + qualified_tool_name: &str, + arguments: Value, + ) -> Result { + let (server_name, tool_name) = parse_qualified_tool_name(qualified_tool_name) + .with_context(|| format!("invalid qualified MCP tool name: {qualified_tool_name}"))?; + self.call_tool(&server_name, &tool_name, arguments) + } + + pub fn list_resources(&self) -> Result> { + let mut out = Vec::new(); + for server_name in self.configs.keys() { + let Some(client) = self.clients.get(server_name) else { + continue; + }; + for mut resource in client.list_resources()? { + resource.server_name = server_name.clone(); + out.push(resource); + } + } + Ok(out) + } + + pub fn read_resource(&self, server_name: &str, uri: &str) -> Result { + let client = self + .clients + .get(server_name) + .with_context(|| format!("MCP server '{server_name}' not available"))?; + client.read_resource(uri) + } + + pub fn update_sandbox_state(&self, sandbox_mode: &str, cwd: &str) -> Result> { + let mut notices = Vec::new(); + for server_name in self.configs.keys() { + notices.push(json!({ + "server_name": server_name, + "method": "codex/sandbox-state/update", + "params": { + "sandbox_mode": sandbox_mode, + "cwd": cwd + } + })); + } + Ok(notices) + } +} + +fn default_true() -> bool { + true +} + +fn allowed_by_filter(name: &str, filter: &ToolFilter) -> bool { + if filter.deny.iter().any(|pattern| pattern == name) { + return false; + } + if filter.allow.is_empty() { + return true; + } + filter.allow.iter().any(|pattern| pattern == name) +} + +fn sanitize_component(value: &str) -> String { + value + .chars() + .map(|ch| { + if ch.is_ascii_alphanumeric() || ch == '_' { + ch.to_ascii_lowercase() + } else { + '_' + } + }) + .collect() +} + +fn qualify_tool_name(server: &str, tool: &str) -> String { + let mut name = format!( + "mcp__{}__{}", + sanitize_component(server), + sanitize_component(tool) + ); + if name.len() > 64 { + let mut hasher = DefaultHasher::new(); + name.hash(&mut hasher); + let hash = format!("{:x}", hasher.finish()); + name.truncate(48); + name.push('_'); + name.push_str(&hash[..12]); + } + name +} + +fn parse_qualified_tool_name(value: &str) -> Result<(String, String)> { + let Some(stripped) = value.strip_prefix("mcp__") else { + bail!("missing mcp__ prefix"); + }; + let mut split = stripped.splitn(2, "__"); + let server = split + .next() + .filter(|s| !s.is_empty()) + .map(ToOwned::to_owned) + .context("missing server segment")?; + let tool = split + .next() + .filter(|s| !s.is_empty()) + .map(ToOwned::to_owned) + .context("missing tool segment")?; + Ok((server, tool)) +} + +#[derive(Debug, Deserialize)] +struct JsonRpcRequest { + #[serde(default)] + jsonrpc: Option, + #[serde(default)] + id: Option, + method: String, + #[serde(default)] + params: Value, +} + +#[derive(Debug)] +struct JsonRpcError { + code: i64, + message: String, + data: Option, +} + +#[derive(Debug, Deserialize)] +struct ToolsListParams { + #[serde(default)] + server: Option, +} + +#[derive(Debug, Deserialize)] +struct ToolsCallParams { + #[serde(default)] + name: Option, + #[serde(default)] + tool: Option, + #[serde(default)] + server: Option, + #[serde(default)] + arguments: Value, +} + +#[derive(Debug, Deserialize)] +struct ResourcesListParams { + #[serde(default)] + server: Option, +} + +#[derive(Debug, Deserialize)] +struct ResourcesReadParams { + #[serde(default)] + server: Option, + uri: String, +} + +#[derive(Debug, Deserialize)] +struct ServerRegisterParams { + server: McpServerConfig, + #[serde(default)] + filter: ToolFilter, + #[serde(default = "default_true")] + start: bool, +} + +#[derive(Debug, Deserialize)] +struct ServerNameParams { + name: String, +} + +struct StdioMcpState { + manager: McpManager, + definitions: HashMap, + running: HashMap, + lifecycle_state: String, +} + +pub fn run_stdio_server( + initial_definitions: Vec, +) -> Result> { + use std::io::{self, BufRead, Write}; + + let stdin = io::stdin(); + let mut stdout = io::stdout(); + let mut stderr = io::stderr(); + let mut state = build_stdio_state(initial_definitions); + + for line in stdin.lock().lines() { + let line = line.context("failed to read stdio line")?; + if line.trim().is_empty() { + continue; + } + + let request: JsonRpcRequest = match serde_json::from_str(&line) { + Ok(value) => value, + Err(err) => { + let msg = jsonrpc_error( + None, + JsonRpcError::parse_error(format!("invalid json: {err}")), + ); + writeln!(stdout, "{msg}")?; + stdout.flush()?; + continue; + } + }; + + if request + .jsonrpc + .as_deref() + .is_some_and(|version| version != "2.0") + { + let response = jsonrpc_error( + request.id, + JsonRpcError::invalid_request("jsonrpc version must be 2.0"), + ); + writeln!(stdout, "{response}")?; + stdout.flush()?; + continue; + } + + let response = match dispatch_stdio_request(&mut state, &request.method, request.params) { + Ok((result, should_exit)) => { + let payload = jsonrpc_result(request.id, result); + writeln!(stdout, "{payload}")?; + stdout.flush()?; + if should_exit { + break; + } + continue; + } + Err(err) => jsonrpc_error(request.id, err), + }; + + writeln!(stdout, "{response}")?; + stdout.flush()?; + } + + state.lifecycle_state = "stopped".to_string(); + let _ = writeln!(stderr, "deepseek-mcp stdio server exited"); + let mut definitions: Vec = state.definitions.into_values().collect(); + definitions.sort_by(|a, b| a.config.name.cmp(&b.config.name)); + Ok(definitions) +} + +fn build_stdio_state(initial_definitions: Vec) -> StdioMcpState { + let mut manager = McpManager::default(); + let mut definitions = HashMap::new(); + let mut running = HashMap::new(); + + for definition in initial_definitions { + let name = definition.config.name.clone(); + let should_start = definition.config.enabled; + definitions.insert(name.clone(), definition.clone()); + if should_start { + manager.register_server( + definition.config.clone(), + definition.filter.clone(), + default_stdio_client(&name), + ); + running.insert(name, true); + } else { + running.insert(name, false); + } + } + + StdioMcpState { + manager, + definitions, + running, + lifecycle_state: "running".to_string(), + } +} + +fn default_stdio_client(server_name: &str) -> Box { + let health_uri = format!("mcp://{server_name}/health"); + let capabilities_uri = format!("mcp://{server_name}/capabilities"); + Box::new( + InMemoryMcpClient::default() + .with_tool( + "health", + json!({ + "status": "ok", + "server_name": server_name + }), + ) + .with_tool( + "capabilities", + json!({ + "tools": ["health", "capabilities"], + "resources": [health_uri.clone(), capabilities_uri.clone()] + }), + ) + .with_resource( + &health_uri, + json!({ + "status": "ok", + "server_name": server_name + }), + ) + .with_resource( + &capabilities_uri, + json!({ + "server_name": server_name, + "methods": [ + "tools/list", + "tools/call", + "resources/list", + "resources/read", + "server/list", + "server/register", + "server/start", + "server/stop", + "server/unregister" + ] + }), + ), + ) +} + +fn default_rpc_methods() -> Vec<&'static str> { + vec![ + "initialize", + "healthz", + "capabilities", + "tools/list", + "tools/call", + "resources/list", + "resources/read", + "server/list", + "server/register", + "server/start", + "server/stop", + "server/unregister", + "shutdown", + ] +} + +fn lifecycle_snapshot(state: &StdioMcpState) -> Value { + let mut servers: Vec = state + .definitions + .iter() + .map(|(name, definition)| { + let is_running = state.running.get(name).copied().unwrap_or(false); + json!({ + "name": name, + "enabled": definition.config.enabled, + "running": is_running, + "command": definition.config.command.clone(), + "args": definition.config.args.clone(), + }) + }) + .collect(); + servers.sort_by(|a, b| { + let a_name = a.get("name").and_then(Value::as_str).unwrap_or_default(); + let b_name = b.get("name").and_then(Value::as_str).unwrap_or_default(); + a_name.cmp(b_name) + }); + + let running_count = state.running.values().filter(|running| **running).count(); + json!({ + "status": state.lifecycle_state, + "servers": servers, + "counts": { + "defined": state.definitions.len(), + "running": running_count + } + }) +} + +fn params_or_object(params: Value) -> Value { + if params.is_null() { json!({}) } else { params } +} + +fn parse_params(params: Value) -> std::result::Result { + serde_json::from_value(params).map_err(|err| JsonRpcError::invalid_params(err.to_string())) +} + +fn parse_server_from_uri(uri: &str) -> Option { + let stripped = uri.strip_prefix("mcp://")?; + let server = stripped.split('/').next()?; + if server.is_empty() { + None + } else { + Some(server.to_string()) + } +} + +fn dispatch_stdio_request( + state: &mut StdioMcpState, + method: &str, + params: Value, +) -> std::result::Result<(Value, bool), JsonRpcError> { + match method { + "initialize" | "capabilities" => Ok(( + json!({ + "server": "deepseek-mcp", + "transport": "stdio", + "methods": default_rpc_methods(), + "lifecycle": lifecycle_snapshot(state) + }), + false, + )), + "healthz" => Ok(( + json!({ + "status": "ok", + "service": "deepseek-mcp", + "transport": "stdio", + "lifecycle": lifecycle_snapshot(state) + }), + false, + )), + "tools/list" => { + let parsed: ToolsListParams = parse_params(params_or_object(params))?; + let mut tools = state + .manager + .list_tools() + .map_err(|err| JsonRpcError::internal(err.to_string()))?; + if let Some(server) = parsed.server { + tools.retain(|tool| tool.server_name == server); + } + Ok((json!({ "tools": tools }), false)) + } + "tools/call" => { + let parsed: ToolsCallParams = parse_params(params_or_object(params))?; + let ToolsCallParams { + name, + tool, + server, + arguments, + } = parsed; + let tool_name = name + .or(tool) + .context("missing tool name") + .map_err(|err| JsonRpcError::invalid_params(err.to_string()))?; + let arguments = if arguments.is_null() { + json!({}) + } else { + arguments + }; + let result = if tool_name.starts_with("mcp__") { + state + .manager + .call_qualified_tool(&tool_name, arguments) + .map_err(|err| JsonRpcError::internal(err.to_string()))? + } else { + let server = server + .context("missing server for unqualified tool") + .map_err(|err| JsonRpcError::invalid_params(err.to_string()))?; + state + .manager + .call_tool(&server, &tool_name, arguments) + .map_err(|err| JsonRpcError::internal(err.to_string()))? + }; + Ok((json!({ "result": result }), false)) + } + "resources/list" => { + let parsed: ResourcesListParams = parse_params(params_or_object(params))?; + let mut resources = state + .manager + .list_resources() + .map_err(|err| JsonRpcError::internal(err.to_string()))?; + if let Some(server) = parsed.server { + resources.retain(|resource| resource.server_name == server); + } + Ok((json!({ "resources": resources }), false)) + } + "resources/read" => { + let parsed: ResourcesReadParams = parse_params(params_or_object(params))?; + let ResourcesReadParams { server, uri } = parsed; + let server_name = server + .or_else(|| parse_server_from_uri(&uri)) + .context("missing server for resource read") + .map_err(|err| JsonRpcError::invalid_params(err.to_string()))?; + let value = state + .manager + .read_resource(&server_name, &uri) + .map_err(|err| JsonRpcError::internal(err.to_string()))?; + Ok((json!({ "resource": value }), false)) + } + "server/list" | "servers/list" => { + Ok((json!({ "lifecycle": lifecycle_snapshot(state) }), false)) + } + "server/register" | "servers/register" => { + let parsed: ServerRegisterParams = parse_params(params_or_object(params))?; + let name = parsed.server.name.clone(); + if name.trim().is_empty() { + return Err(JsonRpcError::invalid_params( + "server.name must not be empty", + )); + } + + if state.definitions.contains_key(&name) { + let _ = state.manager.unregister_server(&name); + } + state.definitions.insert( + name.clone(), + McpServerDefinition { + config: parsed.server.clone(), + filter: parsed.filter.clone(), + }, + ); + let should_run = parsed.start && parsed.server.enabled; + if should_run { + state.manager.register_server( + parsed.server.clone(), + parsed.filter.clone(), + default_stdio_client(&name), + ); + } + state.running.insert(name, should_run); + Ok((json!({ "lifecycle": lifecycle_snapshot(state) }), false)) + } + "server/start" | "servers/start" => { + let parsed: ServerNameParams = parse_params(params_or_object(params))?; + let definition = state + .definitions + .get(&parsed.name) + .cloned() + .with_context(|| format!("server '{}' is not defined", parsed.name)) + .map_err(|err| JsonRpcError::invalid_params(err.to_string()))?; + if !definition.config.enabled { + return Err(JsonRpcError::invalid_params(format!( + "server '{}' is disabled", + parsed.name + ))); + } + if !state.running.get(&parsed.name).copied().unwrap_or(false) { + state.manager.register_server( + definition.config.clone(), + definition.filter.clone(), + default_stdio_client(&parsed.name), + ); + state.running.insert(parsed.name, true); + } + Ok((json!({ "lifecycle": lifecycle_snapshot(state) }), false)) + } + "server/stop" | "servers/stop" => { + let parsed: ServerNameParams = parse_params(params_or_object(params))?; + if state.running.get(&parsed.name).copied().unwrap_or(false) { + state + .manager + .stop_server(&parsed.name) + .map_err(|err| JsonRpcError::internal(err.to_string()))?; + } + state.running.insert(parsed.name, false); + Ok((json!({ "lifecycle": lifecycle_snapshot(state) }), false)) + } + "server/unregister" | "servers/unregister" => { + let parsed: ServerNameParams = parse_params(params_or_object(params))?; + if state.definitions.remove(&parsed.name).is_none() { + return Err(JsonRpcError::invalid_params(format!( + "server '{}' is not defined", + parsed.name + ))); + } + let _ = state.manager.unregister_server(&parsed.name); + state.running.remove(&parsed.name); + Ok((json!({ "lifecycle": lifecycle_snapshot(state) }), false)) + } + "shutdown" => { + state.lifecycle_state = "shutting_down".to_string(); + Ok(( + json!({ + "ok": true, + "lifecycle": lifecycle_snapshot(state) + }), + true, + )) + } + _ => Err(JsonRpcError::method_not_found(method)), + } +} + +fn jsonrpc_result(id: Option, result: Value) -> Value { + json!({ + "jsonrpc": "2.0", + "id": id.unwrap_or(Value::Null), + "result": result + }) +} + +fn jsonrpc_error(id: Option, err: JsonRpcError) -> Value { + json!({ + "jsonrpc": "2.0", + "id": id.unwrap_or(Value::Null), + "error": { + "code": err.code, + "message": err.message, + "data": err.data + } + }) +} + +impl JsonRpcError { + fn parse_error(message: impl Into) -> Self { + Self { + code: -32700, + message: message.into(), + data: None, + } + } + + fn invalid_request(message: impl Into) -> Self { + Self { + code: -32600, + message: message.into(), + data: None, + } + } + + fn method_not_found(method: &str) -> Self { + Self { + code: -32601, + message: format!("unsupported method: {method}"), + data: None, + } + } + + fn invalid_params(message: impl Into) -> Self { + Self { + code: -32602, + message: message.into(), + data: None, + } + } + + fn internal(message: impl Into) -> Self { + Self { + code: -32603, + message: message.into(), + data: None, + } + } +} diff --git a/crates/protocol/Cargo.toml b/crates/protocol/Cargo.toml new file mode 100644 index 00000000..045c7ff4 --- /dev/null +++ b/crates/protocol/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "deepseek-protocol" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "Codex-style app-server protocol frames for DeepSeek workspace architecture" + +[dependencies] +serde.workspace = true +serde_json.workspace = true diff --git a/crates/protocol/src/lib.rs b/crates/protocol/src/lib.rs new file mode 100644 index 00000000..5b7a722a --- /dev/null +++ b/crates/protocol/src/lib.rs @@ -0,0 +1,451 @@ +use std::path::PathBuf; + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Envelope { + pub request_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub thread_id: Option, + pub body: T, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum ThreadStatus { + Running, + Idle, + Completed, + Failed, + Paused, + Archived, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum SessionSource { + Interactive, + Resume, + Fork, + Api, + Unknown, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Thread { + pub id: String, + pub preview: String, + pub ephemeral: bool, + pub model_provider: String, + pub created_at: i64, + pub updated_at: i64, + pub status: ThreadStatus, + #[serde(skip_serializing_if = "Option::is_none")] + pub path: Option, + pub cwd: PathBuf, + pub cli_version: String, + pub source: SessionSource, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThreadStartParams { + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub model_provider: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub cwd: Option, + #[serde(default)] + pub persist_extended_history: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThreadResumeParams { + pub thread_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub history: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub path: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub model_provider: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub cwd: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub approval_policy: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub sandbox: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub config: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub base_instructions: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub developer_instructions: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub personality: Option, + #[serde(default)] + pub persist_extended_history: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThreadForkParams { + pub thread_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub path: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub model_provider: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub cwd: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub approval_policy: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub sandbox: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub config: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub base_instructions: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub developer_instructions: Option, + #[serde(default)] + pub persist_extended_history: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThreadListParams { + #[serde(default)] + pub include_archived: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub limit: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThreadReadParams { + pub thread_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThreadSetNameParams { + pub thread_id: String, + pub name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum ThreadRequest { + Create { + #[serde(default)] + metadata: Value, + }, + Start(ThreadStartParams), + Resume(ThreadResumeParams), + Fork(ThreadForkParams), + List(ThreadListParams), + Read(ThreadReadParams), + SetName(ThreadSetNameParams), + Archive { + thread_id: String, + }, + Unarchive { + thread_id: String, + }, + Message { + thread_id: String, + input: String, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThreadResponse { + pub thread_id: String, + pub status: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub thread: Option, + #[serde(default)] + pub threads: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub model_provider: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub cwd: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub approval_policy: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub sandbox: Option, + #[serde(default)] + pub events: Vec, + #[serde(default)] + pub data: Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum AppRequest { + Capabilities, + ConfigGet { key: String }, + ConfigSet { key: String, value: String }, + ConfigUnset { key: String }, + ConfigList, + Models, + ThreadLoadedList, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AppResponse { + pub ok: bool, + pub data: Value, + #[serde(default)] + pub events: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PromptRequest { + #[serde(skip_serializing_if = "Option::is_none")] + pub thread_id: Option, + pub prompt: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PromptResponse { + pub output: String, + pub model: String, + #[serde(default)] + pub events: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum AskForApproval { + UnlessTrusted, + OnFailure, + OnRequest, + Reject { + sandbox_approval: bool, + rules: bool, + mcp_elicitations: bool, + }, + Never, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum ToolKind { + Function, + Mcp, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LocalShellParams { + pub command: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub cwd: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub timeout_ms: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ToolPayload { + Function { + arguments: String, + }, + Custom { + input: String, + }, + LocalShell { + params: LocalShellParams, + }, + Mcp { + server: String, + tool: String, + raw_arguments: Value, + #[serde(skip_serializing_if = "Option::is_none")] + raw_tool_call_id: Option, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ToolOutput { + Function { + #[serde(skip_serializing_if = "Option::is_none")] + body: Option, + success: bool, + }, + Mcp { + result: Value, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum NetworkPolicyRuleAction { + Allow, + Deny, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct NetworkPolicyAmendment { + pub host: String, + pub action: NetworkPolicyRuleAction, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ReviewDecision { + Approved, + ApprovedExecpolicyAmendment, + ApprovedForSession, + NetworkPolicyAmendment { + host: String, + action: NetworkPolicyRuleAction, + }, + Denied, + Abort, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum McpStartupStatus { + Starting, + Ready, + Failed { error: String }, + Cancelled, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpStartupUpdateEvent { + pub server_name: String, + pub status: McpStartupStatus, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpStartupFailure { + pub server_name: String, + pub error: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpStartupCompleteEvent { + pub ready: Vec, + pub failed: Vec, + pub cancelled: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NetworkApprovalContext { + pub host: String, + pub protocol: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ExecApprovalRequestEvent { + pub call_id: String, + pub approval_id: String, + pub turn_id: String, + pub command: String, + pub cwd: String, + pub reason: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub network_approval_context: Option, + #[serde(default)] + pub proposed_execpolicy_amendment: Vec, + #[serde(default)] + pub proposed_network_policy_amendments: Vec, + #[serde(default)] + pub additional_permissions: Vec, + #[serde(default)] + pub available_decisions: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "event", rename_all = "snake_case")] +pub enum EventFrame { + ResponseStart { + response_id: String, + }, + ResponseDelta { + response_id: String, + delta: String, + }, + ResponseEnd { + response_id: String, + }, + ToolCallStart { + response_id: String, + tool_name: String, + arguments: Value, + }, + ToolCallResult { + response_id: String, + tool_name: String, + output: Value, + }, + McpStartupUpdate { + update: McpStartupUpdateEvent, + }, + McpStartupComplete { + summary: McpStartupCompleteEvent, + }, + McpToolCallBegin { + server_name: String, + tool_name: String, + }, + McpToolCallEnd { + server_name: String, + tool_name: String, + ok: bool, + }, + ExecApprovalRequest { + request: ExecApprovalRequestEvent, + }, + ApplyPatchApprovalRequest { + request: ExecApprovalRequestEvent, + }, + ElicitationRequest { + server_name: String, + request_id: String, + prompt: String, + }, + ExecCommandBegin { + command: String, + cwd: String, + }, + ExecCommandOutputDelta { + command: String, + delta: String, + }, + ExecCommandEnd { + command: String, + exit_code: i32, + }, + PatchApplyBegin { + path: String, + }, + PatchApplyEnd { + path: String, + ok: bool, + }, + TurnStarted { + turn_id: String, + }, + TurnComplete { + turn_id: String, + }, + TurnAborted { + turn_id: String, + reason: String, + }, + Error { + response_id: String, + message: String, + }, +} diff --git a/crates/protocol/tests/parity_protocol.rs b/crates/protocol/tests/parity_protocol.rs new file mode 100644 index 00000000..12eceaf5 --- /dev/null +++ b/crates/protocol/tests/parity_protocol.rs @@ -0,0 +1,50 @@ +use deepseek_protocol::{EventFrame, ThreadListParams, ThreadRequest, ThreadResumeParams}; + +#[test] +fn thread_resume_params_round_trip() { + let request = ThreadRequest::Resume(ThreadResumeParams { + thread_id: "thread-123".to_string(), + history: None, + path: None, + model: Some("deepseek-reasoner".to_string()), + model_provider: Some("deepseek".to_string()), + cwd: None, + approval_policy: Some("on-request".to_string()), + sandbox: Some("workspace-write".to_string()), + config: None, + base_instructions: Some("base".to_string()), + developer_instructions: Some("dev".to_string()), + personality: Some("default".to_string()), + persist_extended_history: true, + }); + + let encoded = serde_json::to_string(&request).expect("serialize request"); + let decoded: ThreadRequest = serde_json::from_str(&encoded).expect("deserialize request"); + match decoded { + ThreadRequest::Resume(params) => { + assert_eq!(params.thread_id, "thread-123"); + assert_eq!(params.model.as_deref(), Some("deepseek-reasoner")); + assert!(params.persist_extended_history); + } + other => panic!("unexpected request: {other:?}"), + } +} + +#[test] +fn thread_list_params_defaults_are_serializable() { + let request = ThreadRequest::List(ThreadListParams { + include_archived: false, + limit: Some(20), + }); + let encoded = serde_json::to_string_pretty(&request).expect("serialize list request"); + assert!(encoded.contains("include_archived")); +} + +#[test] +fn event_frame_serialization_contains_expected_tag() { + let frame = EventFrame::TurnComplete { + turn_id: "turn-1".to_string(), + }; + let encoded = serde_json::to_string(&frame).expect("serialize frame"); + assert!(encoded.contains("turn_complete")); +} diff --git a/crates/state/Cargo.toml b/crates/state/Cargo.toml new file mode 100644 index 00000000..2e3d58dc --- /dev/null +++ b/crates/state/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "deepseek-state" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "Session/thread persistence and recovery model for DeepSeek workspace architecture" + +[dependencies] +anyhow.workspace = true +chrono.workspace = true +dirs.workspace = true +rusqlite.workspace = true +serde.workspace = true +serde_json.workspace = true diff --git a/crates/state/src/lib.rs b/crates/state/src/lib.rs new file mode 100644 index 00000000..9bad8a16 --- /dev/null +++ b/crates/state/src/lib.rs @@ -0,0 +1,950 @@ +use std::collections::HashMap; +use std::fs::{self, OpenOptions}; +use std::io::{BufRead, BufReader, Write}; +use std::path::{Path, PathBuf}; + +use anyhow::{Context, Result}; +use chrono::Utc; +use rusqlite::{Connection, OptionalExtension, params}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum ThreadStatus { + Running, + Idle, + Completed, + Failed, + Paused, + Archived, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum SessionSource { + Interactive, + Resume, + Fork, + Api, + Unknown, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThreadMetadata { + pub id: String, + pub rollout_path: Option, + pub preview: String, + pub ephemeral: bool, + pub model_provider: String, + pub created_at: i64, + pub updated_at: i64, + pub status: ThreadStatus, + pub path: Option, + pub cwd: PathBuf, + pub cli_version: String, + pub source: SessionSource, + pub name: Option, + pub sandbox_policy: Option, + pub approval_mode: Option, + pub archived: bool, + pub archived_at: Option, + pub git_sha: Option, + pub git_branch: Option, + pub git_origin_url: Option, + pub memory_mode: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DynamicToolRecord { + pub position: i64, + pub name: String, + pub description: Option, + pub input_schema: Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageRecord { + pub id: i64, + pub thread_id: String, + pub role: String, + pub content: String, + pub item: Option, + pub created_at: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CheckpointRecord { + pub thread_id: String, + pub checkpoint_id: String, + pub state: Value, + pub created_at: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum JobStateStatus { + Queued, + Running, + Completed, + Failed, + Cancelled, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JobStateRecord { + pub id: String, + pub name: String, + pub status: JobStateStatus, + pub progress: Option, + pub detail: Option, + pub created_at: i64, + pub updated_at: i64, +} + +#[derive(Debug, Clone)] +pub struct ThreadListFilters { + pub include_archived: bool, + pub limit: Option, +} + +impl Default for ThreadListFilters { + fn default() -> Self { + Self { + include_archived: false, + limit: Some(50), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct SessionIndexEntry { + thread_id: String, + thread_name: Option, + updated_at: i64, + rollout_path: Option, +} + +#[derive(Debug, Clone)] +pub struct StateStore { + db_path: PathBuf, + session_index_path: PathBuf, +} + +impl StateStore { + pub fn open(path: Option) -> Result { + let db_path = path.unwrap_or_else(default_state_db_path); + let session_index_path = db_path + .parent() + .unwrap_or_else(|| Path::new(".")) + .join("session_index.jsonl"); + if let Some(parent) = db_path.parent() { + fs::create_dir_all(parent).with_context(|| { + format!("failed to create state directory {}", parent.display()) + })?; + } + let store = Self { + db_path, + session_index_path, + }; + store.init_schema()?; + Ok(store) + } + + pub fn db_path(&self) -> &Path { + &self.db_path + } + + fn conn(&self) -> Result { + Connection::open(&self.db_path) + .with_context(|| format!("failed to open state db {}", self.db_path.display())) + } + + fn init_schema(&self) -> Result<()> { + let conn = self.conn()?; + conn.execute_batch( + r#" + CREATE TABLE IF NOT EXISTS threads ( + id TEXT PRIMARY KEY, + rollout_path TEXT, + preview TEXT NOT NULL, + ephemeral INTEGER NOT NULL, + model_provider TEXT NOT NULL, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL, + status TEXT NOT NULL, + path TEXT, + cwd TEXT NOT NULL, + cli_version TEXT NOT NULL, + source TEXT NOT NULL, + title TEXT, + sandbox_policy TEXT, + approval_mode TEXT, + archived INTEGER NOT NULL DEFAULT 0, + archived_at INTEGER, + git_sha TEXT, + git_branch TEXT, + git_origin_url TEXT, + memory_mode TEXT + ); + CREATE INDEX IF NOT EXISTS idx_threads_updated_at ON threads(updated_at DESC); + CREATE INDEX IF NOT EXISTS idx_threads_archived_at ON threads(archived_at DESC); + CREATE INDEX IF NOT EXISTS idx_threads_archived_updated ON threads(archived, updated_at DESC); + + CREATE TABLE IF NOT EXISTS thread_dynamic_tools ( + thread_id TEXT NOT NULL, + position INTEGER NOT NULL, + name TEXT NOT NULL, + description TEXT, + input_schema TEXT NOT NULL, + PRIMARY KEY (thread_id, position), + FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE + ); + + CREATE TABLE IF NOT EXISTS messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + thread_id TEXT NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + item_json TEXT, + created_at INTEGER NOT NULL, + FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE + ); + CREATE INDEX IF NOT EXISTS idx_messages_thread_created_at ON messages(thread_id, created_at ASC); + + CREATE TABLE IF NOT EXISTS checkpoints ( + thread_id TEXT NOT NULL, + checkpoint_id TEXT NOT NULL, + state_json TEXT NOT NULL, + created_at INTEGER NOT NULL, + PRIMARY KEY(thread_id, checkpoint_id), + FOREIGN KEY(thread_id) REFERENCES threads(id) ON DELETE CASCADE + ); + CREATE INDEX IF NOT EXISTS idx_checkpoints_thread_created_at ON checkpoints(thread_id, created_at DESC); + + CREATE TABLE IF NOT EXISTS jobs ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + status TEXT NOT NULL, + progress INTEGER, + detail TEXT, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_jobs_updated_at ON jobs(updated_at DESC); + "#, + ) + .context("failed to initialize thread schema")?; + Ok(()) + } + + pub fn upsert_thread(&self, thread: &ThreadMetadata) -> Result<()> { + let conn = self.conn()?; + conn.execute( + r#" + INSERT INTO threads ( + id, rollout_path, preview, ephemeral, model_provider, created_at, updated_at, status, path, cwd, + cli_version, source, title, sandbox_policy, approval_mode, archived, archived_at, + git_sha, git_branch, git_origin_url, memory_mode + ) VALUES ( + ?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, + ?11, ?12, ?13, ?14, ?15, ?16, ?17, + ?18, ?19, ?20, ?21 + ) + ON CONFLICT(id) DO UPDATE SET + rollout_path=excluded.rollout_path, + preview=excluded.preview, + ephemeral=excluded.ephemeral, + model_provider=excluded.model_provider, + created_at=excluded.created_at, + updated_at=excluded.updated_at, + status=excluded.status, + path=excluded.path, + cwd=excluded.cwd, + cli_version=excluded.cli_version, + source=excluded.source, + title=excluded.title, + sandbox_policy=excluded.sandbox_policy, + approval_mode=excluded.approval_mode, + archived=excluded.archived, + archived_at=excluded.archived_at, + git_sha=excluded.git_sha, + git_branch=excluded.git_branch, + git_origin_url=excluded.git_origin_url, + memory_mode=excluded.memory_mode + "#, + params![ + thread.id, + path_to_opt_string(thread.rollout_path.as_deref()), + thread.preview, + bool_to_i64(thread.ephemeral), + thread.model_provider, + thread.created_at, + thread.updated_at, + thread_status_to_str(&thread.status), + path_to_opt_string(thread.path.as_deref()), + thread.cwd.display().to_string(), + thread.cli_version, + session_source_to_str(&thread.source), + thread.name, + thread.sandbox_policy, + thread.approval_mode, + bool_to_i64(thread.archived), + thread.archived_at, + thread.git_sha, + thread.git_branch, + thread.git_origin_url, + thread.memory_mode, + ], + ) + .context("failed to upsert thread metadata")?; + + self.append_thread_name( + &thread.id, + thread.name.clone(), + thread.updated_at, + thread.rollout_path.clone(), + )?; + Ok(()) + } + + pub fn get_thread(&self, id: &str) -> Result> { + let conn = self.conn()?; + conn.query_row( + r#" + SELECT id, rollout_path, preview, ephemeral, model_provider, created_at, updated_at, status, path, cwd, + cli_version, source, title, sandbox_policy, approval_mode, archived, archived_at, + git_sha, git_branch, git_origin_url, memory_mode + FROM threads + WHERE id = ?1 + "#, + params![id], + row_to_thread, + ) + .optional() + .context("failed to read thread") + } + + pub fn list_threads(&self, filters: ThreadListFilters) -> Result> { + let conn = self.conn()?; + let sql = if filters.include_archived { + "SELECT id, rollout_path, preview, ephemeral, model_provider, created_at, updated_at, status, path, cwd, cli_version, source, title, sandbox_policy, approval_mode, archived, archived_at, git_sha, git_branch, git_origin_url, memory_mode FROM threads ORDER BY updated_at DESC LIMIT ?1" + } else { + "SELECT id, rollout_path, preview, ephemeral, model_provider, created_at, updated_at, status, path, cwd, cli_version, source, title, sandbox_policy, approval_mode, archived, archived_at, git_sha, git_branch, git_origin_url, memory_mode FROM threads WHERE archived = 0 ORDER BY updated_at DESC LIMIT ?1" + }; + + let mut stmt = conn.prepare(sql).context("failed to prepare list query")?; + let limit = i64::try_from(filters.limit.unwrap_or(50)).unwrap_or(50); + let mut rows = stmt + .query(params![limit]) + .context("failed to query threads")?; + let mut out = Vec::new(); + while let Some(row) = rows.next().context("failed to iterate thread rows")? { + out.push(row_to_thread(row)?); + } + Ok(out) + } + + pub fn mark_archived(&self, id: &str) -> Result<()> { + let conn = self.conn()?; + conn.execute( + "UPDATE threads SET archived = 1, archived_at = ?2, status = ?3 WHERE id = ?1", + params![ + id, + Utc::now().timestamp(), + thread_status_to_str(&ThreadStatus::Archived) + ], + ) + .context("failed to archive thread")?; + Ok(()) + } + + pub fn mark_unarchived(&self, id: &str) -> Result<()> { + let conn = self.conn()?; + conn.execute( + "UPDATE threads SET archived = 0, archived_at = NULL WHERE id = ?1", + params![id], + ) + .context("failed to unarchive thread")?; + Ok(()) + } + + pub fn delete_thread(&self, id: &str) -> Result<()> { + let conn = self.conn()?; + conn.execute("DELETE FROM threads WHERE id = ?1", params![id]) + .context("failed to delete thread")?; + Ok(()) + } + + pub fn set_thread_memory_mode(&self, id: &str, mode: Option<&str>) -> Result<()> { + let conn = self.conn()?; + conn.execute( + "UPDATE threads SET memory_mode = ?2 WHERE id = ?1", + params![id, mode], + ) + .context("failed to update thread memory mode")?; + Ok(()) + } + + pub fn get_thread_memory_mode(&self, id: &str) -> Result> { + let conn = self.conn()?; + conn.query_row( + "SELECT memory_mode FROM threads WHERE id = ?1", + params![id], + |row| row.get::<_, Option>(0), + ) + .optional() + .context("failed to read thread memory mode") + .map(Option::flatten) + } + + pub fn persist_dynamic_tools( + &self, + thread_id: &str, + tools: &[DynamicToolRecord], + ) -> Result<()> { + let mut conn = self.conn()?; + let tx = conn + .transaction() + .context("failed to begin dynamic tools transaction")?; + tx.execute( + "DELETE FROM thread_dynamic_tools WHERE thread_id = ?1", + params![thread_id], + ) + .context("failed to clear dynamic tools")?; + for tool in tools { + tx.execute( + "INSERT INTO thread_dynamic_tools(thread_id, position, name, description, input_schema) VALUES (?1, ?2, ?3, ?4, ?5)", + params![ + thread_id, + tool.position, + tool.name, + tool.description, + tool.input_schema.to_string() + ], + ) + .with_context(|| format!("failed to persist dynamic tool {}", tool.name))?; + } + tx.commit().context("failed to commit dynamic tools")?; + Ok(()) + } + + pub fn get_dynamic_tools(&self, thread_id: &str) -> Result> { + let conn = self.conn()?; + let mut stmt = conn + .prepare( + "SELECT position, name, description, input_schema FROM thread_dynamic_tools WHERE thread_id = ?1 ORDER BY position ASC", + ) + .context("failed to prepare get dynamic tools query")?; + let mut rows = stmt + .query(params![thread_id]) + .context("failed to query dynamic tools")?; + let mut out = Vec::new(); + while let Some(row) = rows.next().context("failed to iterate dynamic tools")? { + let input_schema_raw: String = + row.get(3).context("failed to read tool input schema")?; + let input_schema: Value = + serde_json::from_str(&input_schema_raw).with_context(|| { + format!("failed to parse input schema for dynamic tool in thread {thread_id}") + })?; + out.push(DynamicToolRecord { + position: row.get(0).context("failed to read tool position")?, + name: row.get(1).context("failed to read tool name")?, + description: row.get(2).context("failed to read tool description")?, + input_schema, + }); + } + Ok(out) + } + + pub fn append_message( + &self, + thread_id: &str, + role: &str, + content: &str, + item: Option, + ) -> Result { + let conn = self.conn()?; + let created_at = Utc::now().timestamp(); + let item_json = item + .as_ref() + .map(serde_json::to_string) + .transpose() + .context("failed to serialize message item payload")?; + conn.execute( + "INSERT INTO messages(thread_id, role, content, item_json, created_at) VALUES (?1, ?2, ?3, ?4, ?5)", + params![thread_id, role, content, item_json, created_at], + ) + .with_context(|| format!("failed to append message for thread {thread_id}"))?; + Ok(conn.last_insert_rowid()) + } + + pub fn list_messages( + &self, + thread_id: &str, + limit: Option, + ) -> Result> { + let conn = self.conn()?; + let limit = i64::try_from(limit.unwrap_or(500)).unwrap_or(500); + let mut stmt = conn + .prepare( + "SELECT id, thread_id, role, content, item_json, created_at FROM messages WHERE thread_id = ?1 ORDER BY created_at ASC LIMIT ?2", + ) + .context("failed to prepare message listing query")?; + let mut rows = stmt + .query(params![thread_id, limit]) + .with_context(|| format!("failed to list messages for thread {thread_id}"))?; + let mut out = Vec::new(); + while let Some(row) = rows.next().context("failed to iterate message rows")? { + let item_json: Option = row.get(4).context("failed to read item json")?; + let item = item_json + .as_deref() + .map(serde_json::from_str) + .transpose() + .with_context(|| { + format!("failed to parse message item json in thread {thread_id}") + })?; + out.push(MessageRecord { + id: row.get(0).context("failed to read message id")?, + thread_id: row.get(1).context("failed to read message thread id")?, + role: row.get(2).context("failed to read message role")?, + content: row.get(3).context("failed to read message content")?, + item, + created_at: row.get(5).context("failed to read message timestamp")?, + }); + } + Ok(out) + } + + pub fn clear_messages(&self, thread_id: &str) -> Result { + let conn = self.conn()?; + conn.execute( + "DELETE FROM messages WHERE thread_id = ?1", + params![thread_id], + ) + .with_context(|| format!("failed to clear messages for thread {thread_id}")) + } + + pub fn save_checkpoint( + &self, + thread_id: &str, + checkpoint_id: &str, + state: &Value, + ) -> Result<()> { + let conn = self.conn()?; + let state_json = + serde_json::to_string(state).context("failed to encode checkpoint state")?; + conn.execute( + r#" + INSERT INTO checkpoints(thread_id, checkpoint_id, state_json, created_at) + VALUES (?1, ?2, ?3, ?4) + ON CONFLICT(thread_id, checkpoint_id) DO UPDATE SET + state_json = excluded.state_json, + created_at = excluded.created_at + "#, + params![thread_id, checkpoint_id, state_json, Utc::now().timestamp()], + ) + .with_context(|| { + format!("failed to save checkpoint {checkpoint_id} for thread {thread_id}") + })?; + Ok(()) + } + + pub fn load_checkpoint( + &self, + thread_id: &str, + checkpoint_id: Option<&str>, + ) -> Result> { + let conn = self.conn()?; + if let Some(checkpoint_id) = checkpoint_id { + let row = conn + .query_row( + "SELECT thread_id, checkpoint_id, state_json, created_at FROM checkpoints WHERE thread_id = ?1 AND checkpoint_id = ?2", + params![thread_id, checkpoint_id], + |row| { + let state_json: String = row.get(2)?; + let state = serde_json::from_str(&state_json).unwrap_or(Value::Null); + Ok(CheckpointRecord { + thread_id: row.get(0)?, + checkpoint_id: row.get(1)?, + state, + created_at: row.get(3)?, + }) + }, + ) + .optional() + .with_context(|| { + format!("failed to load checkpoint {checkpoint_id} for thread {thread_id}") + })?; + return Ok(row); + } + + conn.query_row( + "SELECT thread_id, checkpoint_id, state_json, created_at FROM checkpoints WHERE thread_id = ?1 ORDER BY created_at DESC LIMIT 1", + params![thread_id], + |row| { + let state_json: String = row.get(2)?; + let state = serde_json::from_str(&state_json).unwrap_or(Value::Null); + Ok(CheckpointRecord { + thread_id: row.get(0)?, + checkpoint_id: row.get(1)?, + state, + created_at: row.get(3)?, + }) + }, + ) + .optional() + .with_context(|| format!("failed to load latest checkpoint for thread {thread_id}")) + } + + pub fn list_checkpoints( + &self, + thread_id: &str, + limit: Option, + ) -> Result> { + let conn = self.conn()?; + let limit = i64::try_from(limit.unwrap_or(100)).unwrap_or(100); + let mut stmt = conn + .prepare( + "SELECT thread_id, checkpoint_id, state_json, created_at FROM checkpoints WHERE thread_id = ?1 ORDER BY created_at DESC LIMIT ?2", + ) + .context("failed to prepare checkpoint list query")?; + let mut rows = stmt + .query(params![thread_id, limit]) + .with_context(|| format!("failed to list checkpoints for thread {thread_id}"))?; + + let mut out = Vec::new(); + while let Some(row) = rows.next().context("failed to iterate checkpoint rows")? { + let state_json: String = row.get(2).context("failed to read checkpoint state json")?; + let state = serde_json::from_str(&state_json).unwrap_or(Value::Null); + out.push(CheckpointRecord { + thread_id: row.get(0).context("failed to read checkpoint thread id")?, + checkpoint_id: row.get(1).context("failed to read checkpoint id")?, + state, + created_at: row.get(3).context("failed to read checkpoint timestamp")?, + }); + } + Ok(out) + } + + pub fn delete_checkpoint(&self, thread_id: &str, checkpoint_id: &str) -> Result<()> { + let conn = self.conn()?; + conn.execute( + "DELETE FROM checkpoints WHERE thread_id = ?1 AND checkpoint_id = ?2", + params![thread_id, checkpoint_id], + ) + .with_context(|| { + format!("failed to delete checkpoint {checkpoint_id} for thread {thread_id}") + })?; + Ok(()) + } + + pub fn upsert_job(&self, job: &JobStateRecord) -> Result<()> { + let conn = self.conn()?; + conn.execute( + r#" + INSERT INTO jobs(id, name, status, progress, detail, created_at, updated_at) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7) + ON CONFLICT(id) DO UPDATE SET + name = excluded.name, + status = excluded.status, + progress = excluded.progress, + detail = excluded.detail, + created_at = excluded.created_at, + updated_at = excluded.updated_at + "#, + params![ + job.id, + job.name, + job_state_status_to_str(&job.status), + job.progress.map(i64::from), + job.detail, + job.created_at, + job.updated_at + ], + ) + .with_context(|| format!("failed to upsert job {}", job.id))?; + Ok(()) + } + + pub fn get_job(&self, id: &str) -> Result> { + let conn = self.conn()?; + conn.query_row( + "SELECT id, name, status, progress, detail, created_at, updated_at FROM jobs WHERE id = ?1", + params![id], + |row| { + let status_raw: String = row.get(2)?; + let progress: Option = row.get(3)?; + Ok(JobStateRecord { + id: row.get(0)?, + name: row.get(1)?, + status: job_state_status_from_str(&status_raw), + progress: progress.and_then(|v| u8::try_from(v).ok()), + detail: row.get(4)?, + created_at: row.get(5)?, + updated_at: row.get(6)?, + }) + }, + ) + .optional() + .with_context(|| format!("failed to read job {id}")) + } + + pub fn list_jobs(&self, limit: Option) -> Result> { + let conn = self.conn()?; + let limit = i64::try_from(limit.unwrap_or(100)).unwrap_or(100); + let mut stmt = conn + .prepare( + "SELECT id, name, status, progress, detail, created_at, updated_at FROM jobs ORDER BY updated_at DESC LIMIT ?1", + ) + .context("failed to prepare job list query")?; + let mut rows = stmt + .query(params![limit]) + .context("failed to query persisted jobs")?; + let mut out = Vec::new(); + while let Some(row) = rows.next().context("failed to iterate persisted jobs")? { + let status_raw: String = row.get(2).context("failed to read job status")?; + let progress: Option = row.get(3).context("failed to read job progress")?; + out.push(JobStateRecord { + id: row.get(0).context("failed to read job id")?, + name: row.get(1).context("failed to read job name")?, + status: job_state_status_from_str(&status_raw), + progress: progress.and_then(|v| u8::try_from(v).ok()), + detail: row.get(4).context("failed to read job detail")?, + created_at: row.get(5).context("failed to read job created_at")?, + updated_at: row.get(6).context("failed to read job updated_at")?, + }); + } + Ok(out) + } + + pub fn delete_job(&self, id: &str) -> Result<()> { + let conn = self.conn()?; + conn.execute("DELETE FROM jobs WHERE id = ?1", params![id]) + .with_context(|| format!("failed to delete job {id}"))?; + Ok(()) + } + + pub fn find_rollout_path_by_id(&self, id: &str) -> Result> { + let conn = self.conn()?; + conn.query_row( + "SELECT rollout_path FROM threads WHERE id = ?1", + params![id], + |row| row.get::<_, Option>(0), + ) + .optional() + .context("failed to lookup rollout path") + .map(|opt| opt.flatten().map(PathBuf::from)) + } + + pub fn append_thread_name( + &self, + thread_id: &str, + thread_name: Option, + updated_at: i64, + rollout_path: Option, + ) -> Result<()> { + if let Some(parent) = self.session_index_path.parent() { + fs::create_dir_all(parent).with_context(|| { + format!( + "failed to create session index directory {}", + parent.display() + ) + })?; + } + let entry = SessionIndexEntry { + thread_id: thread_id.to_string(), + thread_name, + updated_at, + rollout_path, + }; + let encoded = + serde_json::to_string(&entry).context("failed to serialize session index entry")?; + let mut file = OpenOptions::new() + .create(true) + .append(true) + .open(&self.session_index_path) + .with_context(|| { + format!( + "failed to open session index {}", + self.session_index_path.display() + ) + })?; + writeln!(file, "{encoded}").context("failed to append session index entry")?; + Ok(()) + } + + pub fn find_thread_name_by_id(&self, thread_id: &str) -> Result> { + let map = self.session_index_map()?; + Ok(map + .get(thread_id) + .and_then(|entry| entry.thread_name.clone())) + } + + pub fn find_thread_names_by_ids( + &self, + ids: &[String], + ) -> Result>> { + let map = self.session_index_map()?; + let mut out = HashMap::new(); + for id in ids { + let name = map.get(id).and_then(|entry| entry.thread_name.clone()); + out.insert(id.clone(), name); + } + Ok(out) + } + + pub fn find_thread_path_by_name_str(&self, name: &str) -> Result> { + let map = self.session_index_map()?; + let matched = map + .values() + .filter(|entry| { + entry + .thread_name + .as_deref() + .is_some_and(|n| n.eq_ignore_ascii_case(name)) + }) + .max_by_key(|entry| entry.updated_at); + Ok(matched.and_then(|entry| entry.rollout_path.clone())) + } + + fn session_index_map(&self) -> Result> { + if !self.session_index_path.exists() { + return Ok(HashMap::new()); + } + let file = OpenOptions::new() + .read(true) + .open(&self.session_index_path) + .with_context(|| { + format!( + "failed to read session index {}", + self.session_index_path.display() + ) + })?; + let reader = BufReader::new(file); + let mut latest = HashMap::::new(); + for line in reader.lines() { + let line = line.context("failed to read session index line")?; + if line.trim().is_empty() { + continue; + } + let parsed: SessionIndexEntry = + serde_json::from_str(&line).context("failed to parse session index entry")?; + latest.insert(parsed.thread_id.clone(), parsed); + } + Ok(latest) + } +} + +fn default_state_db_path() -> PathBuf { + dirs::home_dir() + .unwrap_or_else(|| PathBuf::from(".")) + .join(".deepseek") + .join("state.db") +} + +fn bool_to_i64(value: bool) -> i64 { + if value { 1 } else { 0 } +} + +fn i64_to_bool(value: i64) -> bool { + value != 0 +} + +fn thread_status_to_str(status: &ThreadStatus) -> &'static str { + match status { + ThreadStatus::Running => "running", + ThreadStatus::Idle => "idle", + ThreadStatus::Completed => "completed", + ThreadStatus::Failed => "failed", + ThreadStatus::Paused => "paused", + ThreadStatus::Archived => "archived", + } +} + +fn thread_status_from_str(value: &str) -> ThreadStatus { + match value { + "running" => ThreadStatus::Running, + "idle" => ThreadStatus::Idle, + "completed" => ThreadStatus::Completed, + "failed" => ThreadStatus::Failed, + "paused" => ThreadStatus::Paused, + "archived" => ThreadStatus::Archived, + _ => ThreadStatus::Idle, + } +} + +fn session_source_to_str(source: &SessionSource) -> &'static str { + match source { + SessionSource::Interactive => "interactive", + SessionSource::Resume => "resume", + SessionSource::Fork => "fork", + SessionSource::Api => "api", + SessionSource::Unknown => "unknown", + } +} + +fn session_source_from_str(value: &str) -> SessionSource { + match value { + "interactive" => SessionSource::Interactive, + "resume" => SessionSource::Resume, + "fork" => SessionSource::Fork, + "api" => SessionSource::Api, + _ => SessionSource::Unknown, + } +} + +fn path_to_opt_string(path: Option<&Path>) -> Option { + path.map(|p| p.display().to_string()) +} + +fn job_state_status_to_str(status: &JobStateStatus) -> &'static str { + match status { + JobStateStatus::Queued => "queued", + JobStateStatus::Running => "running", + JobStateStatus::Completed => "completed", + JobStateStatus::Failed => "failed", + JobStateStatus::Cancelled => "cancelled", + } +} + +fn job_state_status_from_str(value: &str) -> JobStateStatus { + match value { + "queued" => JobStateStatus::Queued, + "running" => JobStateStatus::Running, + "completed" => JobStateStatus::Completed, + "failed" => JobStateStatus::Failed, + "cancelled" => JobStateStatus::Cancelled, + _ => JobStateStatus::Queued, + } +} + +fn row_to_thread(row: &rusqlite::Row<'_>) -> rusqlite::Result { + let status_raw: String = row.get(7)?; + let source_raw: String = row.get(11)?; + let rollout_path: Option = row.get(1)?; + let path: Option = row.get(8)?; + Ok(ThreadMetadata { + id: row.get(0)?, + rollout_path: rollout_path.map(PathBuf::from), + preview: row.get(2)?, + ephemeral: i64_to_bool(row.get(3)?), + model_provider: row.get(4)?, + created_at: row.get(5)?, + updated_at: row.get(6)?, + status: thread_status_from_str(&status_raw), + path: path.map(PathBuf::from), + cwd: PathBuf::from(row.get::<_, String>(9)?), + cli_version: row.get(10)?, + source: session_source_from_str(&source_raw), + name: row.get(12)?, + sandbox_policy: row.get(13)?, + approval_mode: row.get(14)?, + archived: i64_to_bool(row.get(15)?), + archived_at: row.get(16)?, + git_sha: row.get(17)?, + git_branch: row.get(18)?, + git_origin_url: row.get(19)?, + memory_mode: row.get(20)?, + }) +} diff --git a/crates/state/tests/parity_state.rs b/crates/state/tests/parity_state.rs new file mode 100644 index 00000000..4ff6a604 --- /dev/null +++ b/crates/state/tests/parity_state.rs @@ -0,0 +1,72 @@ +use std::path::PathBuf; + +use deepseek_state::{SessionSource, StateStore, ThreadListFilters, ThreadMetadata, ThreadStatus}; + +fn temp_state_path(label: &str) -> PathBuf { + std::env::temp_dir().join(format!( + "deepseek_state_test_{}_{}_{}.db", + label, + std::process::id(), + chrono::Utc::now().timestamp_nanos_opt().unwrap_or(0) + )) +} + +#[test] +fn upsert_and_resume_thread_metadata() { + let path = temp_state_path("upsert_resume"); + let store = StateStore::open(Some(path.clone())).expect("open state store"); + let now = chrono::Utc::now().timestamp(); + let thread = ThreadMetadata { + id: "thread-test-1".to_string(), + rollout_path: Some(PathBuf::from("/tmp/rollout.jsonl")), + preview: "hello".to_string(), + ephemeral: false, + model_provider: "deepseek".to_string(), + created_at: now, + updated_at: now, + status: ThreadStatus::Running, + path: Some(PathBuf::from("/tmp/project")), + cwd: PathBuf::from("/tmp/project"), + cli_version: "0.0.0-test".to_string(), + source: SessionSource::Interactive, + name: Some("Test Thread".to_string()), + sandbox_policy: Some("workspace-write".to_string()), + approval_mode: Some("on-request".to_string()), + archived: false, + archived_at: None, + git_sha: None, + git_branch: None, + git_origin_url: None, + memory_mode: Some("extended".to_string()), + }; + store.upsert_thread(&thread).expect("upsert thread"); + + let loaded = store + .get_thread("thread-test-1") + .expect("read thread") + .expect("thread must exist"); + assert_eq!(loaded.id, "thread-test-1"); + assert_eq!(loaded.name.as_deref(), Some("Test Thread")); + assert_eq!(loaded.memory_mode.as_deref(), Some("extended")); + assert_eq!( + loaded.rollout_path, + Some(PathBuf::from("/tmp/rollout.jsonl")) + ); + + store + .mark_archived("thread-test-1") + .expect("archive thread"); + let archived = store + .get_thread("thread-test-1") + .expect("read archived thread") + .expect("thread exists after archive"); + assert!(archived.archived); + + let listed = store + .list_threads(ThreadListFilters { + include_archived: true, + limit: Some(10), + }) + .expect("list threads"); + assert!(!listed.is_empty()); +} diff --git a/crates/tools/Cargo.toml b/crates/tools/Cargo.toml new file mode 100644 index 00000000..243e42e6 --- /dev/null +++ b/crates/tools/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "deepseek-tools" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "Tool invocation lifecycle, schema validation, and scheduler parallelism for DeepSeek workspace architecture" + +[dependencies] +anyhow.workspace = true +async-trait.workspace = true +deepseek-protocol = { path = "../protocol" } +serde.workspace = true +serde_json.workspace = true +tokio.workspace = true +uuid.workspace = true diff --git a/crates/tools/src/lib.rs b/crates/tools/src/lib.rs new file mode 100644 index 00000000..cec96628 --- /dev/null +++ b/crates/tools/src/lib.rs @@ -0,0 +1,202 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use anyhow::Result; +use async_trait::async_trait; +use deepseek_protocol::{ToolKind, ToolOutput, ToolPayload}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use tokio::sync::RwLock; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolSpec { + pub name: String, + pub input_schema: Value, + pub output_schema: Value, + pub supports_parallel_tool_calls: bool, + pub timeout_ms: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConfiguredToolSpec { + pub spec: ToolSpec, + pub supports_parallel_tool_calls: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ToolCallSource { + Direct, + JsRepl, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCall { + pub name: String, + pub payload: ToolPayload, + pub source: ToolCallSource, + pub raw_tool_call_id: Option, +} + +impl ToolCall { + pub fn execution_subject(&self, fallback_cwd: &str) -> (String, String, &'static str) { + match &self.payload { + ToolPayload::LocalShell { params } => ( + params.command.clone(), + params + .cwd + .clone() + .unwrap_or_else(|| fallback_cwd.to_string()), + "shell", + ), + _ => (self.name.clone(), fallback_cwd.to_string(), "tool"), + } + } +} + +#[derive(Debug, Clone)] +pub struct ToolInvocation { + pub call_id: String, + pub tool_name: String, + pub payload: ToolPayload, + pub source: ToolCallSource, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum FunctionCallError { + ToolNotFound { name: String }, + KindMismatch { expected: ToolKind, got: ToolKind }, + MutatingToolRejected { name: String }, + TimedOut { name: String, timeout_ms: u64 }, + Cancelled { name: String }, + ExecutionFailed { name: String, error: String }, +} + +#[async_trait] +pub trait ToolHandler: Send + Sync { + fn kind(&self) -> ToolKind; + fn matches_kind(&self, kind: ToolKind) -> bool { + self.kind() == kind + } + fn is_mutating(&self) -> bool { + false + } + async fn handle( + &self, + invocation: ToolInvocation, + ) -> std::result::Result; +} + +#[derive(Debug, Default)] +pub struct ToolCallRuntime { + pub parallel_execution: Arc>, +} + +#[derive(Default)] +pub struct ToolRegistry { + handlers: HashMap>, + specs: HashMap, + runtime: ToolCallRuntime, +} + +impl ToolRegistry { + pub fn register(&mut self, spec: ToolSpec, handler: Arc) -> Result<()> { + let name = spec.name.clone(); + self.specs.insert( + name.clone(), + ConfiguredToolSpec { + supports_parallel_tool_calls: spec.supports_parallel_tool_calls, + spec, + }, + ); + self.handlers.insert(name, handler); + Ok(()) + } + + pub fn list_specs(&self) -> Vec { + self.specs.values().cloned().collect() + } + + pub async fn dispatch( + &self, + call: ToolCall, + allow_mutating: bool, + ) -> std::result::Result { + let handler = self.handlers.get(&call.name).cloned().ok_or_else(|| { + FunctionCallError::ToolNotFound { + name: call.name.clone(), + } + })?; + let configured = + self.specs + .get(&call.name) + .cloned() + .ok_or_else(|| FunctionCallError::ToolNotFound { + name: call.name.clone(), + })?; + + let payload_kind = tool_payload_kind(&call.payload); + let expected = handler.kind(); + if !handler.matches_kind(payload_kind) { + return Err(FunctionCallError::KindMismatch { + expected, + got: payload_kind, + }); + } + if handler.is_mutating() && !allow_mutating { + return Err(FunctionCallError::MutatingToolRejected { name: call.name }); + } + + let invocation = ToolInvocation { + call_id: call + .raw_tool_call_id + .clone() + .unwrap_or_else(|| format!("tool-call-{}", uuid::Uuid::new_v4())), + tool_name: call.name.clone(), + payload: call.payload, + source: call.source, + }; + + if configured.supports_parallel_tool_calls { + let _guard = self.runtime.parallel_execution.read().await; + self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation) + .await + } else { + let _guard = self.runtime.parallel_execution.write().await; + self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation) + .await + } + } + + async fn execute_with_timeout( + &self, + handler: Arc, + timeout_ms: Option, + invocation: ToolInvocation, + ) -> std::result::Result { + if let Some(timeout_ms) = timeout_ms { + let name = invocation.tool_name.clone(); + match tokio::time::timeout( + Duration::from_millis(timeout_ms), + handler.handle(invocation), + ) + .await + { + Ok(result) => result, + Err(_) => Err(FunctionCallError::TimedOut { name, timeout_ms }), + } + } else { + handler.handle(invocation).await + } + } +} + +fn tool_payload_kind(payload: &ToolPayload) -> ToolKind { + match payload { + ToolPayload::Mcp { .. } => ToolKind::Mcp, + ToolPayload::Function { .. } + | ToolPayload::Custom { .. } + | ToolPayload::LocalShell { .. } => ToolKind::Function, + } +} diff --git a/crates/tools/tests/parity_tools.rs b/crates/tools/tests/parity_tools.rs new file mode 100644 index 00000000..799deed0 --- /dev/null +++ b/crates/tools/tests/parity_tools.rs @@ -0,0 +1,70 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use deepseek_protocol::{ToolKind, ToolOutput, ToolPayload}; +use deepseek_tools::{ + ToolCall, ToolCallSource, ToolHandler, ToolInvocation, ToolRegistry, ToolSpec, +}; +use serde_json::json; + +struct EchoHandler; + +#[async_trait] +impl ToolHandler for EchoHandler { + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + fn is_mutating(&self) -> bool { + false + } + + async fn handle( + &self, + invocation: ToolInvocation, + ) -> std::result::Result { + Ok(ToolOutput::Function { + body: Some(json!({ + "tool": invocation.tool_name, + "call_id": invocation.call_id + })), + success: true, + }) + } +} + +#[tokio::test] +async fn dispatches_function_tool_with_parallel_flag() { + let mut registry = ToolRegistry::default(); + registry + .register( + ToolSpec { + name: "echo".to_string(), + input_schema: json!({"type":"object"}), + output_schema: json!({"type":"object"}), + supports_parallel_tool_calls: true, + timeout_ms: Some(1000), + }, + Arc::new(EchoHandler), + ) + .expect("register tool"); + + let output = registry + .dispatch( + ToolCall { + name: "echo".to_string(), + payload: ToolPayload::Function { + arguments: "{\"message\":\"hi\"}".to_string(), + }, + source: ToolCallSource::Direct, + raw_tool_call_id: Some("call-1".to_string()), + }, + true, + ) + .await + .expect("dispatch tool"); + match output { + ToolOutput::Function { success, .. } => assert!(success), + other => panic!("unexpected output: {other:?}"), + } +} diff --git a/crates/tui-core/Cargo.toml b/crates/tui-core/Cargo.toml new file mode 100644 index 00000000..b7d70e05 --- /dev/null +++ b/crates/tui-core/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "deepseek-tui-core" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "Event-driven TUI state machine scaffold for DeepSeek workspace architecture" diff --git a/crates/tui-core/src/lib.rs b/crates/tui-core/src/lib.rs new file mode 100644 index 00000000..f25128f5 --- /dev/null +++ b/crates/tui-core/src/lib.rs @@ -0,0 +1,192 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Pane { + Chat, + Diff, + Tasks, + Agents, + Status, + Jobs, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum UiEvent { + KeyPressed(char), + PromptSubmitted(String), + ResponseDelta(String), + ToolStarted(String), + ToolFinished(String), + JobQueued(String), + JobProgress { job_id: String, progress: u8 }, + JobCompleted(String), + ApprovalRequested(String), + ApprovalResolved(String), + PauseRequested, + ResumeRequested, + Tick, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum UiEffect { + Render, + PersistCheckpoint, + ScheduleBackgroundRefresh, + EmitStatusLine(String), +} + +#[derive(Debug, Clone)] +pub struct UiState { + pub active_pane: Pane, + pub paused: bool, + pub last_response_delta: Option, + pub active_tool: Option, + pub pending_tasks: usize, + pub active_jobs: usize, + pub pending_approvals: usize, + pub status_line: String, +} + +impl Default for UiState { + fn default() -> Self { + Self { + active_pane: Pane::Chat, + paused: false, + last_response_delta: None, + active_tool: None, + pending_tasks: 0, + active_jobs: 0, + pending_approvals: 0, + status_line: "ready".to_string(), + } + } +} + +impl UiState { + pub fn reduce(&mut self, event: UiEvent) -> Vec { + match event { + UiEvent::KeyPressed('1') => { + self.active_pane = Pane::Chat; + vec![UiEffect::Render] + } + UiEvent::KeyPressed('2') => { + self.active_pane = Pane::Diff; + vec![UiEffect::Render] + } + UiEvent::KeyPressed('3') => { + self.active_pane = Pane::Tasks; + vec![UiEffect::Render] + } + UiEvent::KeyPressed('4') => { + self.active_pane = Pane::Agents; + vec![UiEffect::Render] + } + UiEvent::KeyPressed('5') => { + self.active_pane = Pane::Jobs; + vec![UiEffect::Render] + } + UiEvent::PromptSubmitted(_) => { + self.pending_tasks = self.pending_tasks.saturating_add(1); + self.status_line = "prompt submitted".to_string(); + vec![ + UiEffect::Render, + UiEffect::PersistCheckpoint, + UiEffect::EmitStatusLine(self.status_line.clone()), + ] + } + UiEvent::ResponseDelta(delta) => { + self.last_response_delta = Some(delta); + self.status_line = "streaming response".to_string(); + vec![ + UiEffect::Render, + UiEffect::EmitStatusLine(self.status_line.clone()), + ] + } + UiEvent::ToolStarted(name) => { + self.active_tool = Some(name.clone()); + self.status_line = format!("tool running: {name}"); + vec![ + UiEffect::Render, + UiEffect::EmitStatusLine(self.status_line.clone()), + ] + } + UiEvent::ToolFinished(name) => { + self.active_tool = None; + self.pending_tasks = self.pending_tasks.saturating_sub(1); + self.status_line = format!("tool finished: {name}"); + vec![ + UiEffect::Render, + UiEffect::PersistCheckpoint, + UiEffect::EmitStatusLine(self.status_line.clone()), + ] + } + UiEvent::JobQueued(_) => { + self.active_jobs = self.active_jobs.saturating_add(1); + self.status_line = "job queued".to_string(); + vec![UiEffect::Render, UiEffect::PersistCheckpoint] + } + UiEvent::JobProgress { progress, .. } => { + self.status_line = format!("job progress: {}%", progress.min(100)); + vec![ + UiEffect::Render, + UiEffect::EmitStatusLine(self.status_line.clone()), + ] + } + UiEvent::JobCompleted(_) => { + self.active_jobs = self.active_jobs.saturating_sub(1); + self.status_line = "job completed".to_string(); + vec![ + UiEffect::Render, + UiEffect::PersistCheckpoint, + UiEffect::EmitStatusLine(self.status_line.clone()), + ] + } + UiEvent::ApprovalRequested(_) => { + self.pending_approvals = self.pending_approvals.saturating_add(1); + self.status_line = "approval requested".to_string(); + vec![ + UiEffect::Render, + UiEffect::EmitStatusLine(self.status_line.clone()), + ] + } + UiEvent::ApprovalResolved(_) => { + self.pending_approvals = self.pending_approvals.saturating_sub(1); + self.status_line = "approval resolved".to_string(); + vec![ + UiEffect::Render, + UiEffect::PersistCheckpoint, + UiEffect::EmitStatusLine(self.status_line.clone()), + ] + } + UiEvent::PauseRequested => { + self.paused = true; + self.status_line = "paused".to_string(); + vec![ + UiEffect::Render, + UiEffect::EmitStatusLine(self.status_line.clone()), + ] + } + UiEvent::ResumeRequested => { + self.paused = false; + self.status_line = "resumed".to_string(); + vec![ + UiEffect::Render, + UiEffect::EmitStatusLine(self.status_line.clone()), + ] + } + UiEvent::Tick => vec![UiEffect::ScheduleBackgroundRefresh], + UiEvent::KeyPressed(_) => Vec::new(), + } + } + + pub fn snapshot(&self) -> String { + format!( + "pane={:?};paused={};pending_tasks={};active_jobs={};pending_approvals={};active_tool={};status={}", + self.active_pane, + self.paused, + self.pending_tasks, + self.active_jobs, + self.pending_approvals, + self.active_tool.clone().unwrap_or_default(), + self.status_line + ) + } +} diff --git a/crates/tui-core/tests/snapshot.rs b/crates/tui-core/tests/snapshot.rs new file mode 100644 index 00000000..e4961a3f --- /dev/null +++ b/crates/tui-core/tests/snapshot.rs @@ -0,0 +1,25 @@ +use deepseek_tui_core::{Pane, UiEvent, UiState}; + +#[test] +fn reducer_produces_stable_snapshot_for_core_workflow() { + let mut state = UiState::default(); + state.reduce(UiEvent::PromptSubmitted("hello".to_string())); + state.reduce(UiEvent::ToolStarted("web.search".to_string())); + state.reduce(UiEvent::ResponseDelta("partial".to_string())); + state.reduce(UiEvent::ToolFinished("web.search".to_string())); + state.reduce(UiEvent::ApprovalRequested("approval-1".to_string())); + state.reduce(UiEvent::ApprovalResolved("approval-1".to_string())); + state.reduce(UiEvent::JobQueued("job-1".to_string())); + state.reduce(UiEvent::JobProgress { + job_id: "job-1".to_string(), + progress: 60, + }); + state.reduce(UiEvent::JobCompleted("job-1".to_string())); + state.reduce(UiEvent::KeyPressed('5')); + + assert_eq!(state.active_pane, Pane::Jobs); + assert_eq!( + state.snapshot(), + "pane=Jobs;paused=false;pending_tasks=0;active_jobs=0;pending_approvals=0;active_tool=;status=job completed" + ); +} diff --git a/crates/tui/Cargo.toml b/crates/tui/Cargo.toml new file mode 100644 index 00000000..46d66816 --- /dev/null +++ b/crates/tui/Cargo.toml @@ -0,0 +1,70 @@ +[package] +name = "deepseek-tui" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "Terminal UI for DeepSeek" + +[[bin]] +name = "deepseek-tui" +path = "../../src/main.rs" + +[dependencies] +anyhow = "1.0.100" +arboard = "3.4" +async-stream = "0.3.6" +async-trait = "0.1" +bytes = "1.11.0" +base64 = "0.22.1" +axum = { version = "0.8.4", features = ["json"] } +clap = { version = "4.5.54", features = ["derive"] } +clap_complete = "4.5" +colored = "3.0.0" +crossterm = "0.28" +csv = "1.4" +dotenvy = "0.15.7" +dirs = "6.0.0" +futures-util = "0.3.31" +indicatif = "0.18.0" +ratatui = "0.29" +regex = "1.11" +reqwest = { version = "0.13.1", default-features = false, features = ["blocking", "json", "stream", "multipart", "native-tls", "http2"] } +rustyline = "15.0.0" +serde = { version = "1.0.228", features = ["derive"] } +serde_json = "1.0.149" +shellexpand = "3" +toml = "0.9.7" +tokio = { version = "1.49.0", features = ["full"] } +tokio-util = { version = "0.7.16", features = ["io"] } +unicode-width = "0.2" +unicode-segmentation = "1.12" +uuid = { version = "1.11", features = ["v4"] } +tokio-stream = "0.1" +chrono = { version = "0.4", features = ["serde"] } +tempfile = "3.16" +thiserror = "2.0" +tracing = "0.1" +tower-http = { version = "0.6", features = ["cors"] } +wait-timeout = "0.2" +multimap = "0.10.0" +shlex = "1.3.0" +starlark = "0.13.0" +tiny_http = "0.12" +portable-pty = "0.8" +zeroize = "1.8.2" +ignore = "0.4" +pdf-extract = "0.7" + +[dev-dependencies] +wiremock = "0.6" +pretty_assertions = "1.4" + +[target.'cfg(target_os = "macos")'.dependencies] +libc = "0.2" + +[target.'cfg(target_os = "linux")'.dependencies] +libc = "0.2" + +[target.'cfg(target_os = "windows")'.dependencies] +windows = { version = "0.60", features = ["Win32_Foundation"] } diff --git a/docs/parity_release_and_ci.md b/docs/parity_release_and_ci.md new file mode 100644 index 00000000..30491c28 --- /dev/null +++ b/docs/parity_release_and_ci.md @@ -0,0 +1,35 @@ +# Parity CI and release checks + +This repository now includes parity-oriented CI checks under `.github/workflows/parity.yml`. + +## Workflow coverage + +- `cargo fmt --all -- --check` +- `cargo check --workspace --all-targets --locked` +- `cargo clippy --workspace --all-targets --all-features --locked -- -D warnings` +- `cargo test --workspace --all-features --locked` +- TUI snapshot parity test: + - `cargo test -p deepseek-tui-core --test snapshot --locked` +- protocol parity smoke test: + - `cargo test -p deepseek-protocol --test parity_protocol --locked` +- state persistence parity smoke test: + - `cargo test -p deepseek-state --test parity_state --locked` +- lockfile drift guard: + - `git diff --exit-code -- Cargo.lock` + +The tag-based release workflow now runs the same parity preflight before building artifacts. + +## Expected contributor flow + +1. Update workspace crates (`core`, `app-server`, `protocol`, `state`, `tools`, `mcp`, `execpolicy`, `hooks`, `tui`, `cli`). +2. Keep protocol and persistence tests green for parity-sensitive changes. +3. Ensure thread/tool/mcp event contracts remain backward-compatible across app-server endpoints. + +## Release readiness checklist + +- CLI and app-server binaries compile from workspace members. +- Session persistence schema changes include migration-safe SQL updates. +- Protocol changes include test updates in `crates/protocol/tests`. +- New tool lifecycle behavior includes tests in `crates/tools/tests`. +- TUI reducer changes include deterministic snapshot updates in `crates/tui/tests`. +- Release artifacts include `deepseek` (CLI) and `deepseek-tui` (TUI) binaries for all platforms. diff --git a/docs/workspace_migration_status.md b/docs/workspace_migration_status.md new file mode 100644 index 00000000..aeb2dc3b --- /dev/null +++ b/docs/workspace_migration_status.md @@ -0,0 +1,90 @@ +# DeepSeek Workspace Migration Status + +This document maps the initial workspace migration implementation to Linear issues `SHA-1554` to `SHA-1568`. + +## Implemented in this patch + +- `SHA-1554`: + - Root converted to Cargo workspace. + - New crate boundaries added: + - `crates/core` + - `crates/cli` + - `crates/app-server` + - `crates/protocol` + - `crates/config` + - `crates/agent` + - `crates/tui` + - `crates/tui` (TUI binary pointing at monolith source) + - Stable entry binaries now follow `cli` + `app-server` + `tui` split. + +- `SHA-1555`: + - Added `deepseek-config` crate with `ConfigToml` schema. + - Added provider-aware env precedence (`DEEPSEEK_API_KEY`, `OPENAI_API_KEY`, provider/base-url/model overrides). + - Added config read/write/list/set/unset operations. + +- `SHA-1556`: + - Added codex-style command grouping in `deepseek` CLI: + - `run` + - `auth` + - `config` + - `model` + - `app-server` + - `completion` + - Added global runtime override flags (`provider`, `model`, logging/telemetry/output/sandbox/approval controls). + +- `SHA-1557`: + - Added dual-provider auth model (`deepseek` + `openai`) with clear precedence and CLI management commands. + - Added `auth status|set|clear` command flow. + +- `SHA-1558`: + - Added `deepseek-protocol` crate with `thread/app/prompt` request-response framing and event frames. + - Added `deepseek-app-server` with `/thread`, `/app`, `/prompt`, `/healthz`. + - Added `/tool`, `/jobs`, and `/mcp/startup` transport endpoints for tool/job/MCP parity flows. + - Added stdio JSON-RPC 2.0 parity framing (`id`/`method`/`params` -> `result`/`error`) for `thread/*`, `app/*`, `prompt/*`, plus `healthz`/capabilities handlers. + +- `SHA-1560`: + - Added `deepseek-agent` model/provider registry with alias resolution and fallback strategy. + +- `SHA-1564`: + - Added `deepseek-tui-core` event-driven state machine scaffold (`UiState::reduce`). + - Expanded reducer with job/approval states and deterministic snapshot support. + +- `SHA-1559`: + - Added `deepseek-state` crate with persistent thread/session metadata in SQLite. + - Added thread list/read/archive/unarchive/name persistence operations and session index mirror. + +- `SHA-1561`: + - Added `deepseek-tools` crate with typed tool specs, call lifecycle, mutating gate, timeout handling, and read/write lock parallelism model. + +- `SHA-1562`: + - Added `deepseek-mcp` crate with server lifecycle events, qualified tool naming, filter support, resource listing/reads, and proxy call API. + - Added MCP stdio JSON-RPC 2.0 server mode parity for `tools/list`, `tools/call`, `resources/list`, `resources/read`, and server lifecycle operations. + - Added persisted MCP server definition round-trip through existing config APIs so server-mode definitions survive restarts. + +- `SHA-1563`: + - Added `deepseek-execpolicy` crate with approval mode model and policy decision/requirement evaluation. + +- `SHA-1565`: + - Added durable-style `JobManager` abstraction in core for queue/progress/cancel/recovery semantics. + +- `SHA-1566`: + - Added `deepseek-hooks` crate with stdout/jsonl/webhook sinks and standardized lifecycle events. + +- `SHA-1567`: + - Added parity tests for protocol/state/tools and TUI snapshot behavior. + +- `SHA-1568`: + - Added parity CI workflow at `.github/workflows/parity.yml` with workspace fmt/check/clippy/test gates, lockfile drift guard, and explicit snapshot/protocol/state parity tests. + - Added matching release preflight parity gates in `.github/workflows/release.yml`. + - Updated release artifact naming to include explicit `deepseek` entrypoint compatibility. + +## Not yet implemented in this patch + +- Codex-level protocol field-by-field parity for every `thread/*` operation remains in progress. +- MCP transport now provides stdio JSON-RPC compatibility flows; external subprocess execution remains scaffolded. +- Execution policy supports decision modeling and command gating; full user-interactive approval UX remains in progress. +- Background jobs are persisted conceptually at runtime boundary; cross-process recovery orchestration is still in progress. + +## Migration strategy note + +`crates/tui` intentionally points at existing `src/main.rs` to preserve current behavior while new workspace crates are phased in. This enables incremental replacement without blocking ongoing feature work.