feat(runtime): bridge desktop approvals and skills
This commit is contained in:
@@ -22,6 +22,11 @@ next round of TUI fixes can be verified against real terminal behaviour.
|
||||
resize, and assert on the parsed terminal frame plus the workspace
|
||||
filesystem. Initial scenarios cover boot smoke and the #1073 paste regression.
|
||||
Adding-a-scenario walkthrough lives in `crates/tui/tests/support/qa_harness/README.md`.
|
||||
- **Whalescale desktop runtime bridge** — the local runtime API now exposes
|
||||
`POST /v1/approvals/{id}`, `GET /v1/runtime/info`, `enabled` flags on
|
||||
`GET /v1/skills`, and `POST /v1/skills/{name}` toggles. Runtime thread
|
||||
events also carry `agent_reasoning` items so desktop clients can render
|
||||
thinking separately from assistant text.
|
||||
|
||||
### Changed
|
||||
- **`deepseek-cn` provider preset now defaults to the official
|
||||
|
||||
@@ -15,8 +15,9 @@ use deepseek_mcp::{
|
||||
};
|
||||
use deepseek_protocol::{
|
||||
AppResponse, EventFrame, ExecApprovalRequestEvent, PromptRequest, PromptResponse,
|
||||
ReviewDecision, Thread, ThreadForkParams, ThreadListParams, ThreadReadParams, ThreadRequest,
|
||||
ThreadResponse, ThreadResumeParams, ThreadSetNameParams, ThreadStatus, ToolPayload,
|
||||
ResponseChannel, ReviewDecision, Thread, ThreadForkParams, ThreadListParams, ThreadReadParams,
|
||||
ThreadRequest, ThreadResponse, ThreadResumeParams, ThreadSetNameParams, ThreadStatus,
|
||||
ToolPayload,
|
||||
};
|
||||
use deepseek_state::{
|
||||
JobStateRecord, JobStateStatus, SessionSource, StateStore, ThreadListFilters, ThreadMetadata,
|
||||
@@ -913,6 +914,7 @@ impl Runtime {
|
||||
EventFrame::ResponseDelta {
|
||||
response_id: response_id.clone(),
|
||||
delta: "queued".to_string(),
|
||||
channel: ResponseChannel::Text,
|
||||
},
|
||||
EventFrame::ResponseEnd { response_id },
|
||||
],
|
||||
@@ -992,6 +994,7 @@ impl Runtime {
|
||||
EventFrame::ResponseDelta {
|
||||
response_id: response_id.clone(),
|
||||
delta: "model-selected".to_string(),
|
||||
channel: ResponseChannel::Text,
|
||||
},
|
||||
EventFrame::ResponseEnd { response_id },
|
||||
],
|
||||
@@ -1252,6 +1255,7 @@ impl Runtime {
|
||||
"at": entry.at
|
||||
})
|
||||
.to_string(),
|
||||
channel: ResponseChannel::Text,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
@@ -366,6 +366,27 @@ pub struct ExecApprovalRequestEvent {
|
||||
pub available_decisions: Vec<ReviewDecision>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ResponseChannel {
|
||||
#[default]
|
||||
Text,
|
||||
Reasoning,
|
||||
}
|
||||
|
||||
impl ResponseChannel {
|
||||
pub const fn is_text(&self) -> bool {
|
||||
matches!(self, ResponseChannel::Text)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ApprovalDecisionRequest {
|
||||
pub decision: String,
|
||||
#[serde(default)]
|
||||
pub remember: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "event", rename_all = "snake_case")]
|
||||
pub enum EventFrame {
|
||||
@@ -375,6 +396,8 @@ pub enum EventFrame {
|
||||
ResponseDelta {
|
||||
response_id: String,
|
||||
delta: String,
|
||||
#[serde(default, skip_serializing_if = "ResponseChannel::is_text")]
|
||||
channel: ResponseChannel,
|
||||
},
|
||||
ResponseEnd {
|
||||
response_id: String,
|
||||
|
||||
@@ -58,6 +58,7 @@ mod schema_migration;
|
||||
mod seam_manager;
|
||||
mod session_manager;
|
||||
mod settings;
|
||||
mod skill_state;
|
||||
mod skills;
|
||||
mod snapshot;
|
||||
mod task_manager;
|
||||
|
||||
@@ -33,11 +33,13 @@ use crate::automation_manager::{
|
||||
use crate::config::{Config, DEFAULT_TEXT_MODEL};
|
||||
use crate::mcp::{McpConfig, McpPool};
|
||||
use crate::runtime_threads::{
|
||||
CompactThreadRequest, CreateThreadRequest, RuntimeThreadManager, RuntimeThreadManagerConfig,
|
||||
SharedRuntimeThreadManager, StartTurnRequest, SteerTurnRequest, ThreadDetail, ThreadListFilter,
|
||||
ThreadRecord, TurnItemKind, TurnRecord, UpdateThreadRequest, UsageGroupBy,
|
||||
CompactThreadRequest, CreateThreadRequest, ExternalApprovalDecision, RuntimeThreadManager,
|
||||
RuntimeThreadManagerConfig, SharedRuntimeThreadManager, StartTurnRequest, SteerTurnRequest,
|
||||
ThreadDetail, ThreadListFilter, ThreadRecord, TurnItemKind, TurnRecord, UpdateThreadRequest,
|
||||
UsageGroupBy,
|
||||
};
|
||||
use crate::session_manager::{SavedSession, SessionManager, SessionMetadata, default_sessions_dir};
|
||||
use crate::skill_state::SkillStateStore;
|
||||
use crate::skills::SkillRegistry;
|
||||
use crate::task_manager::{
|
||||
NewTaskRequest, SharedTaskManager, TaskManager, TaskManagerConfig, TaskRecord, TaskSummary,
|
||||
@@ -54,6 +56,10 @@ pub struct RuntimeApiState {
|
||||
mcp_config_path: PathBuf,
|
||||
automations: SharedAutomationManager,
|
||||
runtime_token: Option<String>,
|
||||
skill_state: Arc<Mutex<SkillStateStore>>,
|
||||
auth_required: bool,
|
||||
bind_host: String,
|
||||
bind_port: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -207,6 +213,7 @@ struct SkillEntry {
|
||||
name: String,
|
||||
description: String,
|
||||
path: PathBuf,
|
||||
enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
@@ -216,6 +223,40 @@ struct SkillsResponse {
|
||||
skills: Vec<SkillEntry>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SetSkillEnabledRequest {
|
||||
enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct SetSkillEnabledResponse {
|
||||
name: String,
|
||||
enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DecideApprovalBody {
|
||||
decision: String,
|
||||
#[serde(default)]
|
||||
remember: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct DecideApprovalResponse {
|
||||
ok: bool,
|
||||
approval_id: String,
|
||||
decision: String,
|
||||
delivered: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct RuntimeInfoResponse {
|
||||
bind_host: String,
|
||||
port: u16,
|
||||
auth_required: bool,
|
||||
version: &'static str,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct McpServerEntry {
|
||||
name: String,
|
||||
@@ -313,6 +354,13 @@ pub async fn run_http_server(
|
||||
.or_else(|| std::env::var("DEEPSEEK_RUNTIME_TOKEN").ok())
|
||||
.filter(|token| !token.trim().is_empty());
|
||||
let auth_enabled = runtime_token.is_some();
|
||||
let skill_state = SkillStateStore::load_default().unwrap_or_else(|err| {
|
||||
tracing::warn!(
|
||||
"Failed to load skills_state.toml ({}); treating all skills as enabled",
|
||||
err
|
||||
);
|
||||
SkillStateStore::default()
|
||||
});
|
||||
let state = RuntimeApiState {
|
||||
config: config.clone(),
|
||||
workspace,
|
||||
@@ -323,6 +371,10 @@ pub async fn run_http_server(
|
||||
mcp_config_path: config.mcp_config_path(),
|
||||
automations,
|
||||
runtime_token,
|
||||
skill_state: Arc::new(Mutex::new(skill_state)),
|
||||
auth_required: auth_enabled,
|
||||
bind_host: options.host.clone(),
|
||||
bind_port: options.port,
|
||||
};
|
||||
let app = build_router(state);
|
||||
|
||||
@@ -334,7 +386,26 @@ pub async fn run_http_server(
|
||||
.with_context(|| format!("Failed to bind {addr}"))?;
|
||||
|
||||
println!("Runtime API listening on http://{addr}");
|
||||
println!("Security: this server is local-first. Do not expose it to untrusted networks.");
|
||||
let is_loopback = options.host == "127.0.0.1" || options.host == "::1";
|
||||
if is_loopback {
|
||||
println!("Security: this server is local-first. Do not expose it to untrusted networks.");
|
||||
} else {
|
||||
println!(
|
||||
"Security: bound to {host}; reachable from any peer that can route to this address.",
|
||||
host = options.host
|
||||
);
|
||||
if !auth_enabled {
|
||||
println!(
|
||||
" WARNING: --auth-token (or DEEPSEEK_RUNTIME_TOKEN) is unset. Anyone on the network can call /v1/* without authentication."
|
||||
);
|
||||
}
|
||||
println!(
|
||||
" /v1/runtime/info reports bind_host={host:?}, port={port}, auth_required={auth}.",
|
||||
host = options.host,
|
||||
port = options.port,
|
||||
auth = auth_enabled,
|
||||
);
|
||||
}
|
||||
if auth_enabled {
|
||||
println!("Runtime API auth: bearer token required for /v1/* routes.");
|
||||
}
|
||||
@@ -372,10 +443,12 @@ pub fn build_router(state: RuntimeApiState) -> Router {
|
||||
)
|
||||
.route("/v1/threads/{id}/compact", post(compact_thread))
|
||||
.route("/v1/threads/{id}/events", get(stream_thread_events))
|
||||
.route("/v1/approvals/{approval_id}", post(decide_approval))
|
||||
.route("/v1/tasks", get(list_tasks).post(create_task))
|
||||
.route("/v1/tasks/{id}", get(get_task))
|
||||
.route("/v1/tasks/{id}/cancel", post(cancel_task))
|
||||
.route("/v1/skills", get(list_skills))
|
||||
.route("/v1/skills/{name}", post(set_skill_enabled))
|
||||
.route("/v1/apps/mcp/servers", get(list_mcp_servers))
|
||||
.route("/v1/apps/mcp/tools", get(list_mcp_tools))
|
||||
.route(
|
||||
@@ -400,6 +473,7 @@ pub fn build_router(state: RuntimeApiState) -> Router {
|
||||
|
||||
Router::new()
|
||||
.route("/health", get(health))
|
||||
.route("/v1/runtime/info", get(runtime_info))
|
||||
.merge(api_routes)
|
||||
.layer(cors_layer(&state.cors_origins))
|
||||
.with_state(state)
|
||||
@@ -777,6 +851,7 @@ async fn list_skills(
|
||||
) -> Result<Json<SkillsResponse>, ApiError> {
|
||||
let skills_dir = resolve_skills_dir(&state.config, &state.workspace);
|
||||
let registry = SkillRegistry::discover(&skills_dir);
|
||||
let skill_state = state.skill_state.lock().await;
|
||||
let skills = registry
|
||||
.list()
|
||||
.iter()
|
||||
@@ -784,6 +859,7 @@ async fn list_skills(
|
||||
name: skill.name.clone(),
|
||||
description: skill.description.clone(),
|
||||
path: skills_dir.join(&skill.name).join("SKILL.md"),
|
||||
enabled: skill_state.is_enabled(&skill.name),
|
||||
})
|
||||
.collect();
|
||||
Ok(Json(SkillsResponse {
|
||||
@@ -793,6 +869,74 @@ async fn list_skills(
|
||||
}))
|
||||
}
|
||||
|
||||
async fn set_skill_enabled(
|
||||
State(state): State<RuntimeApiState>,
|
||||
Path(name): Path<String>,
|
||||
Json(req): Json<SetSkillEnabledRequest>,
|
||||
) -> Result<Json<SetSkillEnabledResponse>, ApiError> {
|
||||
let skills_dir = resolve_skills_dir(&state.config, &state.workspace);
|
||||
let registry = SkillRegistry::discover(&skills_dir);
|
||||
let exists = registry.list().iter().any(|skill| skill.name == name);
|
||||
if !exists {
|
||||
return Err(ApiError::not_found(format!(
|
||||
"skill '{name}' not found under {}",
|
||||
skills_dir.display()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut store = state.skill_state.lock().await;
|
||||
store
|
||||
.set_enabled(&name, req.enabled)
|
||||
.map_err(|err| ApiError::internal(format!("persist skill state: {err}")))?;
|
||||
Ok(Json(SetSkillEnabledResponse {
|
||||
name,
|
||||
enabled: req.enabled,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn decide_approval(
|
||||
State(state): State<RuntimeApiState>,
|
||||
Path(approval_id): Path<String>,
|
||||
Json(req): Json<DecideApprovalBody>,
|
||||
) -> Result<Json<DecideApprovalResponse>, ApiError> {
|
||||
let decision = match req.decision.as_str() {
|
||||
"allow" => ExternalApprovalDecision::Allow {
|
||||
remember: req.remember,
|
||||
},
|
||||
"deny" => ExternalApprovalDecision::Deny {
|
||||
remember: req.remember,
|
||||
},
|
||||
other => {
|
||||
return Err(ApiError::bad_request(format!(
|
||||
"invalid decision '{other}'; expected \"allow\" or \"deny\""
|
||||
)));
|
||||
}
|
||||
};
|
||||
let delivered = state
|
||||
.runtime_threads
|
||||
.deliver_external_approval(&approval_id, decision);
|
||||
if !delivered {
|
||||
return Err(ApiError::not_found(format!(
|
||||
"no pending approval with id '{approval_id}'"
|
||||
)));
|
||||
}
|
||||
Ok(Json(DecideApprovalResponse {
|
||||
ok: true,
|
||||
approval_id,
|
||||
decision: req.decision,
|
||||
delivered,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn runtime_info(State(state): State<RuntimeApiState>) -> Json<RuntimeInfoResponse> {
|
||||
Json(RuntimeInfoResponse {
|
||||
bind_host: state.bind_host.clone(),
|
||||
port: state.bind_port,
|
||||
auth_required: state.auth_required,
|
||||
version: env!("CARGO_PKG_VERSION"),
|
||||
})
|
||||
}
|
||||
|
||||
async fn list_mcp_servers(
|
||||
State(state): State<RuntimeApiState>,
|
||||
) -> Result<Json<McpServersResponse>, ApiError> {
|
||||
@@ -1769,6 +1913,7 @@ mod tests {
|
||||
)?));
|
||||
runtime_threads.attach_automation_manager(automations.clone());
|
||||
|
||||
let auth_required = runtime_token.is_some();
|
||||
let state = RuntimeApiState {
|
||||
config: Config::default(),
|
||||
workspace: PathBuf::from("."),
|
||||
@@ -1779,6 +1924,12 @@ mod tests {
|
||||
mcp_config_path: root.join("mcp.json"),
|
||||
automations,
|
||||
runtime_token,
|
||||
skill_state: Arc::new(Mutex::new(
|
||||
SkillStateStore::load_from(root.join("skills_state.toml")).unwrap_or_default(),
|
||||
)),
|
||||
auth_required,
|
||||
bind_host: "127.0.0.1".to_string(),
|
||||
bind_port: 0,
|
||||
};
|
||||
let app = build_router(state);
|
||||
let listener = match TcpListener::bind("127.0.0.1:0").await {
|
||||
@@ -3323,4 +3474,128 @@ mod tests {
|
||||
handle.abort();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn runtime_info_reports_bind_state() -> Result<()> {
|
||||
let Some((addr, _runtime_threads, handle)) = spawn_test_server().await? else {
|
||||
return Ok(());
|
||||
};
|
||||
let client = reqwest::Client::new();
|
||||
let info: serde_json::Value = client
|
||||
.get(format!("http://{addr}/v1/runtime/info"))
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?
|
||||
.json()
|
||||
.await?;
|
||||
assert_eq!(info["bind_host"], "127.0.0.1");
|
||||
assert_eq!(info["auth_required"], false);
|
||||
assert!(info["version"].is_string());
|
||||
|
||||
handle.abort();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn decide_approval_404s_when_nothing_pending() -> Result<()> {
|
||||
let Some((addr, _runtime_threads, handle)) = spawn_test_server().await? else {
|
||||
return Ok(());
|
||||
};
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(format!("http://{addr}/v1/approvals/no_such_id"))
|
||||
.json(&json!({ "decision": "allow" }))
|
||||
.send()
|
||||
.await?;
|
||||
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
|
||||
|
||||
handle.abort();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn decide_approval_400s_on_bad_decision() -> Result<()> {
|
||||
let Some((addr, _runtime_threads, handle)) = spawn_test_server().await? else {
|
||||
return Ok(());
|
||||
};
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(format!("http://{addr}/v1/approvals/whatever"))
|
||||
.json(&json!({ "decision": "yolo" }))
|
||||
.send()
|
||||
.await?;
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
handle.abort();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn decide_approval_delivers_to_runtime() -> Result<()> {
|
||||
let Some((addr, runtime_threads, handle)) = spawn_test_server().await? else {
|
||||
return Ok(());
|
||||
};
|
||||
let client = reqwest::Client::new();
|
||||
let rx = runtime_threads.register_pending_approval_for_test("ext_id");
|
||||
|
||||
let resp = client
|
||||
.post(format!("http://{addr}/v1/approvals/ext_id"))
|
||||
.json(&json!({ "decision": "allow", "remember": false }))
|
||||
.send()
|
||||
.await?;
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
assert_eq!(body["ok"], true);
|
||||
assert_eq!(body["decision"], "allow");
|
||||
assert_eq!(body["delivered"], true);
|
||||
|
||||
let received = tokio::time::timeout(Duration::from_secs(1), rx).await??;
|
||||
assert_eq!(
|
||||
received,
|
||||
ExternalApprovalDecision::Allow { remember: false }
|
||||
);
|
||||
|
||||
handle.abort();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn skills_endpoint_includes_enabled_field() -> Result<()> {
|
||||
let Some((addr, _runtime_threads, handle)) = spawn_test_server().await? else {
|
||||
return Ok(());
|
||||
};
|
||||
let client = reqwest::Client::new();
|
||||
let body: serde_json::Value = client
|
||||
.get(format!("http://{addr}/v1/skills"))
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?
|
||||
.json()
|
||||
.await?;
|
||||
if let Some(skills) = body["skills"].as_array() {
|
||||
for skill in skills {
|
||||
assert!(skill.get("enabled").is_some());
|
||||
}
|
||||
}
|
||||
|
||||
handle.abort();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn skill_toggle_endpoint_404s_for_unknown_skill() -> Result<()> {
|
||||
let Some((addr, _runtime_threads, handle)) = spawn_test_server().await? else {
|
||||
return Ok(());
|
||||
};
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(format!("http://{addr}/v1/skills/no-such-skill"))
|
||||
.json(&json!({ "enabled": false }))
|
||||
.send()
|
||||
.await?;
|
||||
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
|
||||
|
||||
handle.abort();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,12 +8,13 @@ use std::fs::{self, File, OpenOptions};
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::{Arc, Mutex as StdMutex};
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Context, Result, anyhow, bail};
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Value, json};
|
||||
use tokio::sync::{Mutex, broadcast};
|
||||
use tokio::sync::{Mutex, broadcast, oneshot};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -40,6 +41,7 @@ const SUMMARY_LIMIT: usize = 280;
|
||||
/// might misinterpret message counts; bumping is the safe choice.
|
||||
const CURRENT_RUNTIME_SCHEMA_VERSION: u32 = 2;
|
||||
const RUNTIME_RESTART_REASON: &str = "Interrupted by process restart";
|
||||
const APPROVAL_DECISION_TIMEOUT: Duration = Duration::from_secs(300);
|
||||
|
||||
const fn default_runtime_schema_version() -> u32 {
|
||||
CURRENT_RUNTIME_SCHEMA_VERSION
|
||||
@@ -61,6 +63,7 @@ pub enum RuntimeTurnStatus {
|
||||
pub enum TurnItemKind {
|
||||
UserMessage,
|
||||
AgentMessage,
|
||||
AgentReasoning,
|
||||
ToolCall,
|
||||
FileChange,
|
||||
CommandExecution,
|
||||
@@ -685,6 +688,7 @@ pub struct RuntimeThreadManager {
|
||||
cancel_token: CancellationToken,
|
||||
task_manager: Arc<StdMutex<Option<crate::task_manager::SharedTaskManager>>>,
|
||||
automations: Arc<StdMutex<Option<crate::automation_manager::SharedAutomationManager>>>,
|
||||
pending_approvals: Arc<StdMutex<HashMap<String, oneshot::Sender<ExternalApprovalDecision>>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@@ -694,6 +698,12 @@ enum RuntimeApprovalDecision {
|
||||
RetryWithFullAccess,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ExternalApprovalDecision {
|
||||
Allow { remember: bool },
|
||||
Deny { remember: bool },
|
||||
}
|
||||
|
||||
impl RuntimeThreadManager {
|
||||
pub fn open(
|
||||
config: Config,
|
||||
@@ -712,6 +722,7 @@ impl RuntimeThreadManager {
|
||||
cancel_token: CancellationToken::new(),
|
||||
task_manager: Arc::new(StdMutex::new(None)),
|
||||
automations: Arc::new(StdMutex::new(None)),
|
||||
pending_approvals: Arc::new(StdMutex::new(HashMap::new())),
|
||||
};
|
||||
manager.recover_interrupted_state()?;
|
||||
Ok(manager)
|
||||
@@ -738,6 +749,9 @@ impl RuntimeThreadManager {
|
||||
#[allow(dead_code)] // Public API for external callers (runtime API, task manager)
|
||||
pub fn shutdown(&self) {
|
||||
self.cancel_token.cancel();
|
||||
if let Ok(mut map) = self.pending_approvals.lock() {
|
||||
map.clear();
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)] // Public API for external callers
|
||||
@@ -745,6 +759,72 @@ impl RuntimeThreadManager {
|
||||
self.cancel_token.is_cancelled()
|
||||
}
|
||||
|
||||
fn register_pending_approval(
|
||||
&self,
|
||||
approval_id: &str,
|
||||
) -> oneshot::Receiver<ExternalApprovalDecision> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
if let Ok(mut map) = self.pending_approvals.lock() {
|
||||
map.insert(approval_id.to_string(), tx);
|
||||
}
|
||||
rx
|
||||
}
|
||||
|
||||
fn cancel_pending_approval(&self, approval_id: &str) {
|
||||
if let Ok(mut map) = self.pending_approvals.lock() {
|
||||
map.remove(approval_id);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deliver_external_approval(
|
||||
&self,
|
||||
approval_id: &str,
|
||||
decision: ExternalApprovalDecision,
|
||||
) -> bool {
|
||||
let sender = match self.pending_approvals.lock() {
|
||||
Ok(mut map) => map.remove(approval_id),
|
||||
Err(_) => return false,
|
||||
};
|
||||
match sender {
|
||||
Some(tx) => tx.send(decision).is_ok(),
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn pending_approvals_count(&self) -> usize {
|
||||
self.pending_approvals
|
||||
.lock()
|
||||
.map(|map| map.len())
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn register_pending_approval_for_test(
|
||||
&self,
|
||||
approval_id: &str,
|
||||
) -> oneshot::Receiver<ExternalApprovalDecision> {
|
||||
self.register_pending_approval(approval_id)
|
||||
}
|
||||
|
||||
async fn remember_thread_auto_approve(&self, thread_id: &str) {
|
||||
let Ok(mut thread) = self.store.load_thread(thread_id) else {
|
||||
return;
|
||||
};
|
||||
if thread.auto_approve {
|
||||
return;
|
||||
}
|
||||
thread.auto_approve = true;
|
||||
thread.updated_at = Utc::now();
|
||||
if let Err(err) = self.store.save_thread(&thread) {
|
||||
tracing::warn!(
|
||||
"Failed to persist auto_approve flip for thread {}: {}",
|
||||
thread_id,
|
||||
err
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn subscribe_events(&self) -> broadcast::Receiver<RuntimeEventRecord> {
|
||||
self.event_tx.subscribe()
|
||||
@@ -1921,6 +2001,7 @@ impl RuntimeThreadManager {
|
||||
engine: EngineHandle,
|
||||
) -> Result<()> {
|
||||
let mut current_message_item: Option<(String, String)> = None;
|
||||
let mut current_reasoning_item: Option<(String, String)> = None;
|
||||
let mut tool_items: HashMap<String, String> = HashMap::new();
|
||||
let mut compaction_items: HashMap<String, String> = HashMap::new();
|
||||
let mut turn_usage: Option<Usage> = None;
|
||||
@@ -2012,6 +2093,64 @@ impl RuntimeThreadManager {
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
EngineEvent::ThinkingStarted { .. } => {
|
||||
let item_id = format!("item_{}", &Uuid::new_v4().to_string()[..8]);
|
||||
let item = TurnItemRecord {
|
||||
schema_version: CURRENT_RUNTIME_SCHEMA_VERSION,
|
||||
id: item_id.clone(),
|
||||
turn_id: turn_id.clone(),
|
||||
kind: TurnItemKind::AgentReasoning,
|
||||
status: TurnItemLifecycleStatus::InProgress,
|
||||
summary: String::new(),
|
||||
detail: Some(String::new()),
|
||||
metadata: None,
|
||||
artifact_refs: Vec::new(),
|
||||
started_at: Some(Utc::now()),
|
||||
ended_at: None,
|
||||
};
|
||||
self.store.save_item(&item)?;
|
||||
self.attach_item_to_turn(&turn_id, &item.id)?;
|
||||
self.emit_event(
|
||||
&thread_id,
|
||||
Some(&turn_id),
|
||||
Some(&item_id),
|
||||
"item.started",
|
||||
json!({ "item": item }),
|
||||
)
|
||||
.await?;
|
||||
current_reasoning_item = Some((item_id, String::new()));
|
||||
}
|
||||
EngineEvent::ThinkingDelta { content, .. } => {
|
||||
if let Some((item_id, text)) = current_reasoning_item.as_mut() {
|
||||
text.push_str(&content);
|
||||
self.emit_event(
|
||||
&thread_id,
|
||||
Some(&turn_id),
|
||||
Some(item_id),
|
||||
"item.delta",
|
||||
json!({ "delta": content, "kind": "agent_reasoning" }),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
EngineEvent::ThinkingComplete { .. } => {
|
||||
if let Some((item_id, text)) = current_reasoning_item.take() {
|
||||
let mut item = self.store.load_item(&item_id)?;
|
||||
item.status = TurnItemLifecycleStatus::Completed;
|
||||
item.summary = summarize_text(&text, SUMMARY_LIMIT);
|
||||
item.detail = Some(text);
|
||||
item.ended_at = Some(Utc::now());
|
||||
self.store.save_item(&item)?;
|
||||
self.emit_event(
|
||||
&thread_id,
|
||||
Some(&turn_id),
|
||||
Some(&item_id),
|
||||
"item.completed",
|
||||
json!({ "item": item }),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
EngineEvent::ToolCallStarted { id, name, input } => {
|
||||
let item_id = format!("item_{}", &Uuid::new_v4().to_string()[..8]);
|
||||
tool_items.insert(id.clone(), item_id.clone());
|
||||
@@ -2447,22 +2586,88 @@ impl RuntimeThreadManager {
|
||||
"approval.required",
|
||||
json!({
|
||||
"id": id,
|
||||
"approval_id": id,
|
||||
"tool_name": tool_name,
|
||||
"description": description,
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let (auto_approve, trust_mode) = self
|
||||
.active_turn_flags(&thread_id, &turn_id)
|
||||
.await
|
||||
.unwrap_or((false, false));
|
||||
match Self::approval_decision(auto_approve, trust_mode, false) {
|
||||
RuntimeApprovalDecision::ApproveTool => {
|
||||
let Some((auto_approve, trust_mode)) =
|
||||
self.active_turn_flags(&thread_id, &turn_id).await
|
||||
else {
|
||||
let _ = engine.deny_tool_call(id).await;
|
||||
continue;
|
||||
};
|
||||
|
||||
if auto_approve || trust_mode {
|
||||
match Self::approval_decision(auto_approve, trust_mode, false) {
|
||||
RuntimeApprovalDecision::ApproveTool => {
|
||||
let _ = engine.approve_tool_call(id).await;
|
||||
}
|
||||
RuntimeApprovalDecision::DenyTool
|
||||
| RuntimeApprovalDecision::RetryWithFullAccess => {
|
||||
let _ = engine.deny_tool_call(id).await;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
let rx = self.register_pending_approval(&id);
|
||||
match tokio::time::timeout(APPROVAL_DECISION_TIMEOUT, rx).await {
|
||||
Ok(Ok(ExternalApprovalDecision::Allow { remember })) => {
|
||||
if remember {
|
||||
self.remember_thread_auto_approve(&thread_id).await;
|
||||
}
|
||||
self.emit_event(
|
||||
&thread_id,
|
||||
Some(&turn_id),
|
||||
None,
|
||||
"approval.decided",
|
||||
json!({
|
||||
"approval_id": id,
|
||||
"decision": "allow",
|
||||
"remember": remember,
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.ok();
|
||||
let _ = engine.approve_tool_call(id).await;
|
||||
}
|
||||
RuntimeApprovalDecision::DenyTool
|
||||
| RuntimeApprovalDecision::RetryWithFullAccess => {
|
||||
Ok(Ok(ExternalApprovalDecision::Deny { remember })) => {
|
||||
self.emit_event(
|
||||
&thread_id,
|
||||
Some(&turn_id),
|
||||
None,
|
||||
"approval.decided",
|
||||
json!({
|
||||
"approval_id": id,
|
||||
"decision": "deny",
|
||||
"remember": remember,
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.ok();
|
||||
let _ = engine.deny_tool_call(id).await;
|
||||
}
|
||||
Ok(Err(_recv_err)) => {
|
||||
self.cancel_pending_approval(&id);
|
||||
let _ = engine.deny_tool_call(id).await;
|
||||
}
|
||||
Err(_timeout) => {
|
||||
self.cancel_pending_approval(&id);
|
||||
self.emit_event(
|
||||
&thread_id,
|
||||
Some(&turn_id),
|
||||
None,
|
||||
"approval.timeout",
|
||||
json!({
|
||||
"approval_id": id,
|
||||
"timeout_secs": APPROVAL_DECISION_TIMEOUT.as_secs(),
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.ok();
|
||||
let _ = engine.deny_tool_call(id).await;
|
||||
}
|
||||
}
|
||||
@@ -2610,6 +2815,31 @@ impl RuntimeThreadManager {
|
||||
.await?;
|
||||
}
|
||||
|
||||
if let Some((item_id, text)) = current_reasoning_item.take() {
|
||||
let mut item = self.store.load_item(&item_id)?;
|
||||
if turn_status == RuntimeTurnStatus::Interrupted {
|
||||
item.status = TurnItemLifecycleStatus::Interrupted;
|
||||
} else {
|
||||
item.status = TurnItemLifecycleStatus::Completed;
|
||||
}
|
||||
item.summary = summarize_text(&text, SUMMARY_LIMIT);
|
||||
item.detail = Some(text);
|
||||
item.ended_at = Some(Utc::now());
|
||||
self.store.save_item(&item)?;
|
||||
self.emit_event(
|
||||
&thread_id,
|
||||
Some(&turn_id),
|
||||
Some(&item_id),
|
||||
if item.status == TurnItemLifecycleStatus::Interrupted {
|
||||
"item.interrupted"
|
||||
} else {
|
||||
"item.completed"
|
||||
},
|
||||
json!({ "item": item }),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let ended_at = Utc::now();
|
||||
let mut turn = self.store.load_turn(&turn_id)?;
|
||||
turn.status = turn_status;
|
||||
@@ -3863,6 +4093,340 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn approval_required_awaits_external_decision_allow() -> Result<()> {
|
||||
let manager = test_manager(test_runtime_dir())?;
|
||||
let thread = manager
|
||||
.create_thread(CreateThreadRequest {
|
||||
model: None,
|
||||
workspace: None,
|
||||
mode: None,
|
||||
allow_shell: None,
|
||||
trust_mode: None,
|
||||
auto_approve: None,
|
||||
archived: false,
|
||||
system_prompt: None,
|
||||
task_id: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let mut harness = install_mock_engine(&manager, &thread.id).await;
|
||||
let _turn = manager
|
||||
.start_turn(
|
||||
&thread.id,
|
||||
StartTurnRequest {
|
||||
prompt: "needs approval".to_string(),
|
||||
input_summary: None,
|
||||
model: None,
|
||||
mode: None,
|
||||
allow_shell: None,
|
||||
trust_mode: None,
|
||||
auto_approve: None,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
assert!(matches!(
|
||||
harness.rx_op.recv().await,
|
||||
Some(Op::SendMessage { .. })
|
||||
));
|
||||
|
||||
harness
|
||||
.tx_event
|
||||
.send(EngineEvent::ApprovalRequired {
|
||||
approval_key: "key1".to_string(),
|
||||
id: "tool_external_allow".to_string(),
|
||||
tool_name: "exec_command".to_string(),
|
||||
description: "external allow".to_string(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let deadline = Instant::now() + Duration::from_secs(2);
|
||||
while Instant::now() < deadline && manager.pending_approvals_count() == 0 {
|
||||
sleep(Duration::from_millis(20)).await;
|
||||
}
|
||||
assert_eq!(manager.pending_approvals_count(), 1);
|
||||
|
||||
assert!(manager.deliver_external_approval(
|
||||
"tool_external_allow",
|
||||
ExternalApprovalDecision::Allow { remember: false },
|
||||
));
|
||||
assert_eq!(
|
||||
harness.recv_approval_event().await,
|
||||
Some(MockApprovalEvent::Approved {
|
||||
id: "tool_external_allow".to_string(),
|
||||
})
|
||||
);
|
||||
assert_eq!(manager.pending_approvals_count(), 0);
|
||||
|
||||
harness
|
||||
.tx_event
|
||||
.send(EngineEvent::TurnComplete {
|
||||
usage: Usage::default(),
|
||||
status: TurnOutcomeStatus::Completed,
|
||||
error: None,
|
||||
})
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn approval_required_external_deny_is_denied() -> Result<()> {
|
||||
let manager = test_manager(test_runtime_dir())?;
|
||||
let thread = manager
|
||||
.create_thread(CreateThreadRequest {
|
||||
model: None,
|
||||
workspace: None,
|
||||
mode: None,
|
||||
allow_shell: None,
|
||||
trust_mode: None,
|
||||
auto_approve: None,
|
||||
archived: false,
|
||||
system_prompt: None,
|
||||
task_id: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let mut harness = install_mock_engine(&manager, &thread.id).await;
|
||||
let _turn = manager
|
||||
.start_turn(
|
||||
&thread.id,
|
||||
StartTurnRequest {
|
||||
prompt: "needs approval".to_string(),
|
||||
input_summary: None,
|
||||
model: None,
|
||||
mode: None,
|
||||
allow_shell: None,
|
||||
trust_mode: None,
|
||||
auto_approve: None,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
assert!(matches!(
|
||||
harness.rx_op.recv().await,
|
||||
Some(Op::SendMessage { .. })
|
||||
));
|
||||
|
||||
harness
|
||||
.tx_event
|
||||
.send(EngineEvent::ApprovalRequired {
|
||||
approval_key: "key2".to_string(),
|
||||
id: "tool_external_deny".to_string(),
|
||||
tool_name: "exec_command".to_string(),
|
||||
description: "external deny".to_string(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let deadline = Instant::now() + Duration::from_secs(2);
|
||||
while Instant::now() < deadline && manager.pending_approvals_count() == 0 {
|
||||
sleep(Duration::from_millis(20)).await;
|
||||
}
|
||||
assert_eq!(manager.pending_approvals_count(), 1);
|
||||
|
||||
assert!(manager.deliver_external_approval(
|
||||
"tool_external_deny",
|
||||
ExternalApprovalDecision::Deny { remember: false },
|
||||
));
|
||||
assert_eq!(
|
||||
harness.recv_approval_event().await,
|
||||
Some(MockApprovalEvent::Denied {
|
||||
id: "tool_external_deny".to_string(),
|
||||
})
|
||||
);
|
||||
|
||||
harness
|
||||
.tx_event
|
||||
.send(EngineEvent::TurnComplete {
|
||||
usage: Usage::default(),
|
||||
status: TurnOutcomeStatus::Completed,
|
||||
error: None,
|
||||
})
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn thinking_delta_emits_agent_reasoning_item() -> Result<()> {
|
||||
let manager = test_manager(test_runtime_dir())?;
|
||||
let thread = manager
|
||||
.create_thread(CreateThreadRequest {
|
||||
model: None,
|
||||
workspace: None,
|
||||
mode: None,
|
||||
allow_shell: None,
|
||||
trust_mode: None,
|
||||
auto_approve: Some(true),
|
||||
archived: false,
|
||||
system_prompt: None,
|
||||
task_id: None,
|
||||
})
|
||||
.await?;
|
||||
let mut harness = install_mock_engine(&manager, &thread.id).await;
|
||||
let mut event_rx = manager.subscribe_events();
|
||||
let _turn = manager
|
||||
.start_turn(
|
||||
&thread.id,
|
||||
StartTurnRequest {
|
||||
prompt: "show your thinking".to_string(),
|
||||
input_summary: None,
|
||||
model: None,
|
||||
mode: None,
|
||||
allow_shell: None,
|
||||
trust_mode: None,
|
||||
auto_approve: Some(true),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
assert!(matches!(
|
||||
harness.rx_op.recv().await,
|
||||
Some(Op::SendMessage { .. })
|
||||
));
|
||||
|
||||
harness
|
||||
.tx_event
|
||||
.send(EngineEvent::ThinkingStarted { index: 0 })
|
||||
.await?;
|
||||
harness
|
||||
.tx_event
|
||||
.send(EngineEvent::ThinkingDelta {
|
||||
index: 0,
|
||||
content: "Let me reason about this.".to_string(),
|
||||
})
|
||||
.await?;
|
||||
harness
|
||||
.tx_event
|
||||
.send(EngineEvent::ThinkingComplete { index: 0 })
|
||||
.await?;
|
||||
harness
|
||||
.tx_event
|
||||
.send(EngineEvent::TurnComplete {
|
||||
usage: Usage::default(),
|
||||
status: TurnOutcomeStatus::Completed,
|
||||
error: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let deadline = Instant::now() + Duration::from_secs(2);
|
||||
let mut delta_seen = false;
|
||||
let mut completed_seen = false;
|
||||
while Instant::now() < deadline && (!delta_seen || !completed_seen) {
|
||||
match tokio::time::timeout(Duration::from_millis(200), event_rx.recv()).await {
|
||||
Ok(Ok(record)) => {
|
||||
if record.event == "item.delta"
|
||||
&& record.payload.get("kind").and_then(|v| v.as_str())
|
||||
== Some("agent_reasoning")
|
||||
{
|
||||
delta_seen = true;
|
||||
assert_eq!(
|
||||
record.payload.get("delta").and_then(|v| v.as_str()),
|
||||
Some("Let me reason about this.")
|
||||
);
|
||||
}
|
||||
if record.event == "item.completed"
|
||||
&& record
|
||||
.payload
|
||||
.get("item")
|
||||
.and_then(|v| v.get("kind"))
|
||||
.and_then(|v| v.as_str())
|
||||
== Some("agent_reasoning")
|
||||
{
|
||||
completed_seen = true;
|
||||
}
|
||||
}
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
assert!(delta_seen, "expected item.delta with kind=agent_reasoning");
|
||||
assert!(
|
||||
completed_seen,
|
||||
"expected item.completed for the reasoning item"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn deliver_external_approval_for_unknown_id_returns_false() {
|
||||
let manager = test_manager(test_runtime_dir()).expect("manager");
|
||||
assert!(!manager.deliver_external_approval(
|
||||
"no_such_approval",
|
||||
ExternalApprovalDecision::Allow { remember: false },
|
||||
));
|
||||
assert_eq!(manager.pending_approvals_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn approval_required_remember_flips_thread_auto_approve() -> Result<()> {
|
||||
let manager = test_manager(test_runtime_dir())?;
|
||||
let thread = manager
|
||||
.create_thread(CreateThreadRequest {
|
||||
model: None,
|
||||
workspace: None,
|
||||
mode: None,
|
||||
allow_shell: None,
|
||||
trust_mode: None,
|
||||
auto_approve: None,
|
||||
archived: false,
|
||||
system_prompt: None,
|
||||
task_id: None,
|
||||
})
|
||||
.await?;
|
||||
assert!(!manager.store.load_thread(&thread.id)?.auto_approve);
|
||||
|
||||
let mut harness = install_mock_engine(&manager, &thread.id).await;
|
||||
let _turn = manager
|
||||
.start_turn(
|
||||
&thread.id,
|
||||
StartTurnRequest {
|
||||
prompt: "needs approval".to_string(),
|
||||
input_summary: None,
|
||||
model: None,
|
||||
mode: None,
|
||||
allow_shell: None,
|
||||
trust_mode: None,
|
||||
auto_approve: None,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
assert!(matches!(
|
||||
harness.rx_op.recv().await,
|
||||
Some(Op::SendMessage { .. })
|
||||
));
|
||||
|
||||
harness
|
||||
.tx_event
|
||||
.send(EngineEvent::ApprovalRequired {
|
||||
approval_key: "key3".to_string(),
|
||||
id: "tool_remember".to_string(),
|
||||
tool_name: "exec_command".to_string(),
|
||||
description: "remember=true".to_string(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let deadline = Instant::now() + Duration::from_secs(2);
|
||||
while Instant::now() < deadline && manager.pending_approvals_count() == 0 {
|
||||
sleep(Duration::from_millis(20)).await;
|
||||
}
|
||||
assert!(manager.deliver_external_approval(
|
||||
"tool_remember",
|
||||
ExternalApprovalDecision::Allow { remember: true },
|
||||
));
|
||||
let _ = harness.recv_approval_event().await;
|
||||
|
||||
assert!(
|
||||
manager.store.load_thread(&thread.id)?.auto_approve,
|
||||
"remember=true should flip thread auto_approve"
|
||||
);
|
||||
|
||||
harness
|
||||
.tx_event
|
||||
.send(EngineEvent::TurnComplete {
|
||||
usage: Usage::default(),
|
||||
status: TurnOutcomeStatus::Completed,
|
||||
error: None,
|
||||
})
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn elevation_required_with_stale_active_turn_is_denied() -> Result<()> {
|
||||
let manager = test_manager(test_runtime_dir())?;
|
||||
|
||||
@@ -0,0 +1,191 @@
|
||||
//! Persistent enable/disable state for runtime API skill listings.
|
||||
//!
|
||||
//! Backs `GET /v1/skills` (`enabled` field per skill) and
|
||||
//! `POST /v1/skills/{name}` (toggle). This is separate from the
|
||||
//! filesystem-discovered `SkillRegistry`: the registry tells us which skills
|
||||
//! exist on disk, and this store tells API clients which ones are marked active.
|
||||
//!
|
||||
//! Storage shape (TOML at `~/.deepseek/skills_state.toml`):
|
||||
//!
|
||||
//! ```toml
|
||||
//! disabled = ["skill-name-1", "skill-name-2"]
|
||||
//! ```
|
||||
//!
|
||||
//! Default state when the file does not exist: empty list (everything enabled).
|
||||
//! A corrupt file is logged and treated as the default, so upgrades never
|
||||
//! accidentally hide every skill.
|
||||
|
||||
use std::collections::BTreeSet;
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
const STATE_FILE_NAME: &str = "skills_state.toml";
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct SkillStateStore {
|
||||
path: Option<PathBuf>,
|
||||
disabled: BTreeSet<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
struct OnDiskState {
|
||||
#[serde(default)]
|
||||
disabled: Vec<String>,
|
||||
}
|
||||
|
||||
impl SkillStateStore {
|
||||
pub fn load_default() -> Result<Self> {
|
||||
let path = default_state_path()?;
|
||||
Self::load_from(path)
|
||||
}
|
||||
|
||||
pub fn load_from(path: PathBuf) -> Result<Self> {
|
||||
if !path.exists() {
|
||||
return Ok(Self {
|
||||
path: Some(path),
|
||||
disabled: BTreeSet::new(),
|
||||
});
|
||||
}
|
||||
|
||||
let raw = fs::read_to_string(&path)
|
||||
.with_context(|| format!("read skill state at {}", path.display()))?;
|
||||
let parsed: OnDiskState = match toml::from_str(&raw) {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
tracing::warn!(
|
||||
"skills_state.toml at {} is malformed ({}); treating all skills as enabled",
|
||||
path.display(),
|
||||
err
|
||||
);
|
||||
OnDiskState::default()
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
path: Some(path),
|
||||
disabled: parsed.disabled.into_iter().collect(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn is_enabled(&self, skill_name: &str) -> bool {
|
||||
!self.disabled.contains(skill_name)
|
||||
}
|
||||
|
||||
pub fn set_enabled(&mut self, skill_name: &str, enabled: bool) -> Result<()> {
|
||||
let changed = if enabled {
|
||||
self.disabled.remove(skill_name)
|
||||
} else {
|
||||
self.disabled.insert(skill_name.to_string())
|
||||
};
|
||||
if !changed {
|
||||
return Ok(());
|
||||
}
|
||||
self.persist()
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn disabled(&self) -> Vec<String> {
|
||||
self.disabled.iter().cloned().collect()
|
||||
}
|
||||
|
||||
fn persist(&self) -> Result<()> {
|
||||
let Some(path) = self.path.as_ref() else {
|
||||
return Ok(());
|
||||
};
|
||||
let on_disk = OnDiskState {
|
||||
disabled: self.disabled.iter().cloned().collect(),
|
||||
};
|
||||
let body = toml::to_string_pretty(&on_disk).context("serialize skill state")?;
|
||||
atomic_write(path, body.as_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
fn default_state_path() -> Result<PathBuf> {
|
||||
let home = dirs::home_dir().context("could not resolve $HOME for ~/.deepseek")?;
|
||||
let dir = home.join(".deepseek");
|
||||
fs::create_dir_all(&dir)
|
||||
.with_context(|| format!("create deepseek state dir at {}", dir.display()))?;
|
||||
Ok(dir.join(STATE_FILE_NAME))
|
||||
}
|
||||
|
||||
fn atomic_write(path: &Path, bytes: &[u8]) -> Result<()> {
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.with_context(|| format!("create parent dir for {}", path.display()))?;
|
||||
}
|
||||
let tmp = path.with_extension("toml.tmp");
|
||||
fs::write(&tmp, bytes).with_context(|| format!("write tmp at {}", tmp.display()))?;
|
||||
fs::rename(&tmp, path).with_context(|| format!("rename tmp into {}", path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn fresh() -> (TempDir, SkillStateStore) {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let path = dir.path().join(STATE_FILE_NAME);
|
||||
let store = SkillStateStore::load_from(path).unwrap();
|
||||
(dir, store)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_file_defaults_to_everything_enabled() {
|
||||
let (_dir, store) = fresh();
|
||||
assert!(store.is_enabled("anything"));
|
||||
assert!(store.disabled().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn disable_then_reload_persists() {
|
||||
let (dir, mut store) = fresh();
|
||||
store.set_enabled("foo", false).unwrap();
|
||||
assert!(!store.is_enabled("foo"));
|
||||
|
||||
let reloaded = SkillStateStore::load_from(dir.path().join(STATE_FILE_NAME)).unwrap();
|
||||
assert!(!reloaded.is_enabled("foo"));
|
||||
assert!(reloaded.is_enabled("bar"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enable_removes_from_disabled_list() {
|
||||
let (_dir, mut store) = fresh();
|
||||
store.set_enabled("foo", false).unwrap();
|
||||
store.set_enabled("foo", true).unwrap();
|
||||
assert!(store.is_enabled("foo"));
|
||||
assert!(store.disabled().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn redundant_toggle_is_noop() {
|
||||
let (_dir, mut store) = fresh();
|
||||
store.set_enabled("foo", true).unwrap();
|
||||
assert!(store.disabled().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn malformed_file_falls_back_to_default() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let path = dir.path().join(STATE_FILE_NAME);
|
||||
fs::write(&path, b"this is not toml = { broken").unwrap();
|
||||
let store = SkillStateStore::load_from(path).unwrap();
|
||||
assert!(store.is_enabled("anything"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn disabled_list_is_deterministic_order() {
|
||||
let (_dir, mut store) = fresh();
|
||||
store.set_enabled("zeta", false).unwrap();
|
||||
store.set_enabled("alpha", false).unwrap();
|
||||
store.set_enabled("mu", false).unwrap();
|
||||
assert_eq!(
|
||||
store.disabled(),
|
||||
vec!["alpha".to_string(), "mu".to_string(), "zeta".to_string()]
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user