refactor: extract neutral command support

This commit is contained in:
Paulo Aboim Pinto
2026-06-07 02:44:29 +02:00
parent 8e8b45a20e
commit 18df8db056
13 changed files with 1096 additions and 1412 deletions
+4 -198
View File
@@ -1,5 +1,3 @@
#![allow(dead_code)]
//! Command safety analysis for shell execution
//!
//! This module provides pre-execution analysis of shell commands to detect
@@ -374,43 +372,38 @@ pub enum SafetyLevel {
#[derive(Debug, Clone)]
pub struct SafetyAnalysis {
pub level: SafetyLevel,
pub command: String,
pub reasons: Vec<String>,
pub suggestions: Vec<String>,
}
impl SafetyAnalysis {
pub fn safe(command: &str) -> Self {
pub fn safe(_command: &str) -> Self {
Self {
level: SafetyLevel::Safe,
command: command.to_string(),
reasons: vec!["Command is read-only".to_string()],
suggestions: vec![],
}
}
pub fn workspace_safe(command: &str, reason: &str) -> Self {
pub fn workspace_safe(_command: &str, reason: &str) -> Self {
Self {
level: SafetyLevel::WorkspaceSafe,
command: command.to_string(),
reasons: vec![reason.to_string()],
suggestions: vec![],
}
}
pub fn requires_approval(command: &str, reasons: Vec<String>) -> Self {
pub fn requires_approval(_command: &str, reasons: Vec<String>) -> Self {
Self {
level: SafetyLevel::RequiresApproval,
command: command.to_string(),
reasons,
suggestions: vec![],
}
}
pub fn dangerous(command: &str, reasons: Vec<String>, suggestions: Vec<String>) -> Self {
pub fn dangerous(_command: &str, reasons: Vec<String>, suggestions: Vec<String>) -> Self {
Self {
level: SafetyLevel::Dangerous,
command: command.to_string(),
reasons,
suggestions,
}
@@ -1012,72 +1005,6 @@ fn is_workspace_safe_command(command: &str) -> bool {
false
}
/// Check if a path escapes the workspace
pub fn path_escapes_workspace(path: &str, workspace: &str) -> bool {
let path_lower = normalize_safety_path(path);
let workspace_lower = normalize_safety_path(workspace);
// Check for obvious escape patterns
if path_lower.starts_with("~/") || path_lower.starts_with("$home") {
return true;
}
if is_absolute_safety_path(&path_lower) {
let path_components = lexical_components(&path_lower);
let workspace_components = lexical_components(&workspace_lower);
return !components_start_with(&path_components, &workspace_components);
}
// Walk the path components. Track depth relative to the workspace root:
// non-`..` components increment depth, `..` components decrement it.
// If depth ever goes negative, the path escapes the workspace boundary.
// This correctly distinguishes genuine traversal like `../outside` from
// names that happen to contain consecutive dots like `foo..bar`.
let mut depth: i32 = 0;
for component in path_lower.split('/') {
match component {
"" | "." => {}
".." => depth -= 1,
_ => depth += 1,
}
if depth < 0 {
return true;
}
}
false
}
fn normalize_safety_path(path: &str) -> String {
path.trim().replace('\\', "/").to_lowercase()
}
fn is_absolute_safety_path(path: &str) -> bool {
path.starts_with('/')
|| path
.as_bytes()
.get(1..3)
.is_some_and(|bytes| bytes[0] == b':' && bytes[1] == b'/')
}
fn lexical_components(path: &str) -> Vec<&str> {
let mut components = Vec::new();
for component in path.split('/') {
match component {
"" | "." => {}
".." => {
components.pop();
}
_ => components.push(component),
}
}
components
}
fn components_start_with(path: &[&str], prefix: &[&str]) -> bool {
path.len() >= prefix.len() && path.iter().zip(prefix.iter()).all(|(a, b)| a == b)
}
/// Parse a command and extract the primary command name
pub fn extract_primary_command(command: &str) -> Option<&str> {
let trimmed = command.trim();
@@ -1093,56 +1020,6 @@ pub fn extract_primary_command(command: &str) -> Option<&str> {
}
}
/// Categorize commands into groups
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommandCategory {
FileSystem,
Network,
Process,
Package,
Git,
Build,
System,
Shell,
Other,
}
/// Get the category of a command
pub fn categorize_command(command: &str) -> CommandCategory {
let primary = match extract_primary_command(command) {
Some(cmd) => cmd.to_lowercase(),
None => return CommandCategory::Other,
};
match primary.as_str() {
"ls" | "dir" | "cat" | "head" | "tail" | "less" | "more" | "cp" | "mv" | "rm" | "mkdir"
| "rmdir" | "touch" | "chmod" | "chown" | "ln" | "find" | "fd" | "locate" | "stat"
| "file" => CommandCategory::FileSystem,
"curl" | "wget" | "fetch" | "nc" | "netcat" | "ssh" | "scp" | "sftp" | "rsync" | "ftp"
| "ping" | "traceroute" | "nslookup" | "dig" | "host" | "nmap" => CommandCategory::Network,
"ps" | "top" | "htop" | "kill" | "killall" | "pkill" | "pgrep" | "nice" | "renice"
| "nohup" | "timeout" => CommandCategory::Process,
"npm" | "yarn" | "pnpm" | "pip" | "pip3" | "brew" | "apt" | "apt-get" | "yum" | "dnf"
| "pacman" => CommandCategory::Package,
"git" | "gh" | "hub" => CommandCategory::Git,
"make" | "cmake" | "ninja" | "meson" | "cargo" | "go" | "gcc" | "g++" | "clang"
| "rustc" | "javac" | "tsc" => CommandCategory::Build,
"sudo" | "su" | "systemctl" | "service" | "shutdown" | "reboot" | "mount" | "umount"
| "fdisk" | "parted" => CommandCategory::System,
"bash" | "sh" | "zsh" | "fish" | "csh" | "tcsh" | "dash" | "source" | "." | "exec"
| "eval" => CommandCategory::Shell,
_ => CommandCategory::Other,
}
}
// === Unit Tests ===
#[cfg(test)]
@@ -1321,62 +1198,6 @@ mod tests {
);
}
#[test]
fn test_path_escapes_workspace() {
assert!(path_escapes_workspace("/etc/passwd", "/home/user/project"));
assert!(path_escapes_workspace("~/secret", "/home/user/project"));
assert!(!path_escapes_workspace(
"./src/main.rs",
"/home/user/project"
));
}
#[test]
fn test_path_escapes_workspace_doesnt_flag_double_dot_in_names() {
// Names like `foo..bar` should NOT be flagged as path traversal
assert!(!path_escapes_workspace(
"some..file.txt",
"/home/user/project"
));
assert!(!path_escapes_workspace(
"./dir..name/file.txt",
"/home/user/project"
));
}
#[test]
fn test_path_escapes_workspace_detects_genuine_traversal() {
assert!(path_escapes_workspace("../outside", "/home/user/project"));
assert!(path_escapes_workspace(
"..\\outside",
"C:\\Users\\me\\project"
));
assert!(path_escapes_workspace(
"./subdir/../../etc/passwd",
"/home/user/project"
));
assert!(path_escapes_workspace(
"/home/user/project/../secret",
"/home/user/project"
));
assert!(path_escapes_workspace(
"C:\\Users\\me\\project\\..\\secret",
"C:\\Users\\me\\project"
));
}
#[test]
fn test_path_escapes_workspace_allows_absolute_workspace_children() {
assert!(!path_escapes_workspace(
"/home/user/project/src/main.rs",
"/home/user/project"
));
assert!(!path_escapes_workspace(
"C:\\Users\\me\\project\\src\\main.rs",
"C:\\Users\\me\\project"
));
}
#[test]
fn test_extract_primary_command() {
assert_eq!(extract_primary_command("ls -la"), Some("ls"));
@@ -1387,21 +1208,6 @@ mod tests {
assert_eq!(extract_primary_command(" git status "), Some("git"));
}
#[test]
fn test_categorize_command() {
assert_eq!(categorize_command("ls -la"), CommandCategory::FileSystem);
assert_eq!(
categorize_command("curl https://example.com"),
CommandCategory::Network
);
assert_eq!(categorize_command("git status"), CommandCategory::Git);
assert_eq!(categorize_command("npm install"), CommandCategory::Package);
assert_eq!(
categorize_command("sudo apt update"),
CommandCategory::System
);
}
// ── classify_command tests ────────────────────────────────────────────────
/// Helper: split a string on whitespace into a `Vec<&str>` and call
File diff suppressed because it is too large Load Diff
-68
View File
@@ -78,7 +78,6 @@ impl CommandResult {
}
/// Create a result with both message and action
#[allow(dead_code)]
pub fn with_message_and_action(msg: impl Into<String>, action: AppAction) -> Self {
Self {
message: Some(msg.into()),
@@ -710,37 +709,9 @@ pub fn set_config_value(app: &mut App, key: &str, value: &str, persist: bool) ->
config::set_config_value(app, key, value, persist)
}
/// Persist the user's chosen footer items to `~/.deepseek/config.toml` under
/// `tui.status_items`. See [`config::persist_status_items`] for details.
pub fn persist_status_items(
items: &[crate::config::StatusItem],
) -> anyhow::Result<std::path::PathBuf> {
config::persist_status_items(items)
}
/// Persist a root-level string key in `config.toml`.
pub fn persist_root_string_key(
config_path: Option<&std::path::Path>,
key: &str,
value: &str,
) -> anyhow::Result<std::path::PathBuf> {
config::persist_root_string_key(config_path, key, value)
}
pub fn switch_mode(app: &mut App, mode: crate::tui::app::AppMode) -> String {
config::switch_mode(app, mode)
}
/// Auto-select a model based on request complexity.
pub fn auto_model_heuristic(input: &str, current_model: &str) -> String {
config::auto_model_heuristic(input, current_model)
}
pub use config::{
AutoRouteRecommendation, AutoRouteSelection, normalize_auto_route_effort,
parse_auto_route_recommendation, resolve_auto_route_with_flash,
};
/// Execute a Recursive Language Model (RLM) turn — Algorithm 1 from
/// Zhang et al. (arXiv:2512.24601).
///
@@ -1006,45 +977,6 @@ pub fn get_command_info(name: &str) -> Option<&'static CommandInfo> {
.find(|cmd| cmd.name == name || cmd.aliases.contains(&name))
}
/// Get all command names matching a prefix, including both built-in
/// static commands and user-defined commands, formatted as `/name`.
///
/// `workspace` is used to also scan workspace-local command directories;
/// pass `None` when no workspace context is available.
#[allow(dead_code)]
pub fn all_command_names_matching(
prefix: &str,
workspace: Option<&std::path::Path>,
) -> Vec<String> {
let prefix = prefix.strip_prefix('/').unwrap_or(prefix).to_lowercase();
let mut result: Vec<String> = COMMANDS
.iter()
.filter(|cmd| {
cmd.name.starts_with(&prefix) || cmd.aliases.iter().any(|a| a.starts_with(&prefix))
})
.map(|cmd| format!("/{}", cmd.name))
.collect();
// Add user-defined commands
result.extend(user_commands::user_commands_matching(&prefix, workspace));
result.sort();
result.dedup();
result
}
/// Get all commands matching a prefix (for autocomplete)
#[allow(dead_code)]
pub fn commands_matching(prefix: &str) -> Vec<&'static CommandInfo> {
let prefix = prefix.strip_prefix('/').unwrap_or(prefix).to_lowercase();
COMMANDS
.iter()
.filter(|cmd| {
cmd.name.starts_with(&prefix) || cmd.aliases.iter().any(|a| a.starts_with(&prefix))
})
.collect()
}
fn edit_distance(a: &str, b: &str) -> usize {
if a == b {
return 0;
+3 -3
View File
@@ -70,7 +70,7 @@ enum NetworkEdit {
}
fn list_policy() -> anyhow::Result<String> {
let path = super::config::config_toml_path(None)?;
let path = crate::config_persistence::config_toml_path(None)?;
let doc = load_config_doc(&path)?;
let network = doc.get("network").and_then(Value::as_table);
let default = network
@@ -97,7 +97,7 @@ fn list_policy() -> anyhow::Result<String> {
}
fn update_host(edit: NetworkEdit, host: &str) -> anyhow::Result<String> {
let path = super::config::config_toml_path(None)?;
let path = crate::config_persistence::config_toml_path(None)?;
let mut doc = load_config_doc(&path)?;
let network = network_table_mut(&mut doc)?;
@@ -136,7 +136,7 @@ fn update_default(value: &str) -> anyhow::Result<String> {
_ => bail!("Usage: /network default <allow|deny|prompt>"),
};
let path = super::config::config_toml_path(None)?;
let path = crate::config_persistence::config_toml_path(None)?;
let mut doc = load_config_doc(&path)?;
let network = network_table_mut(&mut doc)?;
network.insert("default".to_string(), Value::String(normalized.to_string()));
-39
View File
@@ -232,22 +232,6 @@ pub fn try_dispatch_user_command(app: &mut App, input: &str) -> Option<CommandRe
None
}
/// Get user command names that match a given prefix (for autocomplete).
///
/// The prefix should be the command name portion only (after `/`).
/// Returns entries formatted as `/name`.
///
/// `workspace` is used to also scan workspace-local command directories;
/// pass `None` when no workspace context is available.
pub fn user_commands_matching(prefix: &str, workspace: Option<&Path>) -> Vec<String> {
let prefix = prefix.to_lowercase();
load_user_commands(workspace)
.into_iter()
.filter(|(name, _)| name.starts_with(&prefix))
.map(|(name, _)| format!("/{name}"))
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
@@ -307,12 +291,6 @@ mod tests {
assert!(result.is_none());
}
#[test]
fn test_user_commands_matching_with_prefix_no_workspace() {
let matches = user_commands_matching("zzzznotfound", None);
assert!(matches.is_empty());
}
// ── Workspace-local commands tests ─────────────────────────────────
fn write_command(dir: &Path, name: &str, body: &str) {
@@ -474,23 +452,6 @@ mod tests {
}
}
#[test]
fn user_commands_matching_with_workspace() {
let tmp = TempDir::new().unwrap();
let ws = tmp.path();
write_command(
&ws.join(".deepseek").join("commands"),
"project-cmd",
"body",
);
let matches = user_commands_matching("project", Some(ws));
assert!(
matches.contains(&"/project-cmd".to_string()),
"got: {matches:?}"
);
}
#[test]
fn frontmatter_is_stripped_before_dispatch() {
use crate::config::Config;
+461
View File
@@ -0,0 +1,461 @@
//! Config file path resolution and TOML persistence helpers.
//!
//! These helpers are used by command handlers and non-command UI code, so
//! persistence lives outside the command tree.
use std::path::{Path, PathBuf};
use crate::config::{ApiProvider, StatusItem, effective_home_dir, expand_path};
pub(crate) fn persist_status_items(items: &[StatusItem]) -> anyhow::Result<PathBuf> {
use anyhow::Context;
use std::fs;
let path = config_toml_path(None)?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("failed to create config directory {}", parent.display()))?;
}
let mut doc: toml::Value = if path.exists() {
let raw = fs::read_to_string(&path)
.with_context(|| format!("failed to read config at {}", path.display()))?;
toml::from_str(&raw)
.with_context(|| format!("failed to parse config at {}", path.display()))?
} else {
toml::Value::Table(toml::value::Table::new())
};
let table = doc
.as_table_mut()
.context("config.toml root must be a table")?;
let tui_entry = table
.entry("tui".to_string())
.or_insert_with(|| toml::Value::Table(toml::value::Table::new()));
let tui_table = tui_entry
.as_table_mut()
.context("`tui` section in config.toml must be a table")?;
let array = items
.iter()
.map(|item| toml::Value::String(item.key().to_string()))
.collect::<Vec<_>>();
tui_table.insert("status_items".to_string(), toml::Value::Array(array));
let body = toml::to_string_pretty(&doc).context("failed to serialize config.toml")?;
fs::write(&path, body)
.with_context(|| format!("failed to write config at {}", path.display()))?;
Ok(path)
}
pub(crate) fn persist_root_string_key(
config_path: Option<&Path>,
key: &str,
value: &str,
) -> anyhow::Result<PathBuf> {
use anyhow::Context;
use std::fs;
let path = config_toml_path(config_path)?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("failed to create config directory {}", parent.display()))?;
}
let mut doc: toml::Value = if path.exists() {
let raw = fs::read_to_string(&path)
.with_context(|| format!("failed to read config at {}", path.display()))?;
toml::from_str(&raw)
.with_context(|| format!("failed to parse config at {}", path.display()))?
} else {
toml::Value::Table(toml::value::Table::new())
};
let table = doc
.as_table_mut()
.context("config.toml root must be a table")?;
table.insert(key.to_string(), toml::Value::String(value.to_string()));
let body = toml::to_string_pretty(&doc).context("failed to serialize config.toml")?;
fs::write(&path, body)
.with_context(|| format!("failed to write config at {}", path.display()))?;
Ok(path)
}
pub(crate) fn persist_root_bool_key(
config_path: Option<&Path>,
key: &str,
value: bool,
) -> anyhow::Result<PathBuf> {
use anyhow::Context;
use std::fs;
let path = config_toml_path(config_path)?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("failed to create config directory {}", parent.display()))?;
}
let mut doc: toml::Value = if path.exists() {
let raw = fs::read_to_string(&path)
.with_context(|| format!("failed to read config at {}", path.display()))?;
toml::from_str(&raw)
.with_context(|| format!("failed to parse config at {}", path.display()))?
} else {
toml::Value::Table(toml::value::Table::new())
};
let table = doc
.as_table_mut()
.context("config.toml root must be a table")?;
table.insert(key.to_string(), toml::Value::Boolean(value));
let body = toml::to_string_pretty(&doc).context("failed to serialize config.toml")?;
fs::write(&path, body)
.with_context(|| format!("failed to write config at {}", path.display()))?;
Ok(path)
}
pub(crate) fn persist_tui_integer_key(
config_path: Option<&Path>,
key: &str,
value: u64,
) -> anyhow::Result<PathBuf> {
use anyhow::Context;
use std::fs;
let path = config_toml_path(config_path)?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("failed to create config directory {}", parent.display()))?;
}
let mut doc: toml::Value = if path.exists() {
let raw = fs::read_to_string(&path)
.with_context(|| format!("failed to read config at {}", path.display()))?;
toml::from_str(&raw)
.with_context(|| format!("failed to parse config at {}", path.display()))?
} else {
toml::Value::Table(toml::value::Table::new())
};
let table = doc
.as_table_mut()
.context("config.toml root must be a table")?;
let tui_entry = table
.entry("tui".to_string())
.or_insert_with(|| toml::Value::Table(toml::value::Table::new()));
let tui_table = tui_entry
.as_table_mut()
.context("`tui` section in config.toml must be a table")?;
let value = i64::try_from(value).context("integer value is too large for TOML")?;
tui_table.insert(key.to_string(), toml::Value::Integer(value));
let body = toml::to_string_pretty(&doc).context("failed to serialize config.toml")?;
fs::write(&path, body)
.with_context(|| format!("failed to write config at {}", path.display()))?;
Ok(path)
}
pub(crate) fn persist_provider_base_url_key(
config_path: Option<&Path>,
provider: ApiProvider,
value: &str,
) -> anyhow::Result<PathBuf> {
use anyhow::Context;
use std::fs;
let path = config_toml_path(config_path)?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("failed to create config directory {}", parent.display()))?;
}
let mut doc: toml::Value = if path.exists() {
let raw = fs::read_to_string(&path)
.with_context(|| format!("failed to read config at {}", path.display()))?;
toml::from_str(&raw)
.with_context(|| format!("failed to parse config at {}", path.display()))?
} else {
toml::Value::Table(toml::value::Table::new())
};
let table = doc
.as_table_mut()
.context("config.toml root must be a table")?;
let providers = table
.entry("providers".to_string())
.or_insert_with(|| toml::Value::Table(toml::value::Table::new()))
.as_table_mut()
.context("`providers` must be a table")?;
let provider_key = provider_base_url_table_key(provider)?;
let entry = providers
.entry(provider_key.to_string())
.or_insert_with(|| toml::Value::Table(toml::value::Table::new()))
.as_table_mut()
.with_context(|| format!("`providers.{provider_key}` must be a table"))?;
entry.insert(
"base_url".to_string(),
toml::Value::String(value.to_string()),
);
let body = toml::to_string_pretty(&doc).context("failed to serialize config.toml")?;
fs::write(&path, body)
.with_context(|| format!("failed to write config at {}", path.display()))?;
Ok(path)
}
fn provider_base_url_table_key(provider: ApiProvider) -> anyhow::Result<&'static str> {
match provider {
ApiProvider::Deepseek | ApiProvider::DeepseekCN => {
anyhow::bail!("DeepSeek uses the root base_url setting")
}
ApiProvider::NvidiaNim => Ok("nvidia_nim"),
ApiProvider::Openai => Ok("openai"),
ApiProvider::Atlascloud => Ok("atlascloud"),
ApiProvider::WanjieArk => Ok("wanjie_ark"),
ApiProvider::Volcengine => Ok("volcengine"),
ApiProvider::Openrouter => Ok("openrouter"),
ApiProvider::XiaomiMimo => Ok("xiaomi_mimo"),
ApiProvider::Novita => Ok("novita"),
ApiProvider::Fireworks => Ok("fireworks"),
ApiProvider::Siliconflow | ApiProvider::SiliconflowCn => Ok("siliconflow"),
ApiProvider::Arcee => Ok("arcee"),
ApiProvider::Huggingface => Ok("huggingface"),
ApiProvider::Moonshot => Ok("moonshot"),
ApiProvider::Sglang => Ok("sglang"),
ApiProvider::Vllm => Ok("vllm"),
ApiProvider::Ollama => Ok("ollama"),
}
}
pub(crate) fn config_toml_path(config_path: Option<&Path>) -> anyhow::Result<PathBuf> {
use anyhow::Context;
if let Some(path) = config_path {
return Ok(expand_path(path.to_string_lossy().as_ref()));
}
if let Ok(env) = std::env::var("CODEWHALE_CONFIG_PATH") {
let trimmed = env.trim();
if !trimmed.is_empty() {
return Ok(PathBuf::from(trimmed));
}
}
if let Ok(env) = std::env::var("DEEPSEEK_CONFIG_PATH") {
let trimmed = env.trim();
if !trimmed.is_empty() {
return Ok(PathBuf::from(trimmed));
}
}
let home =
effective_home_dir().context("failed to resolve home directory for config.toml path")?;
let primary = home.join(".codewhale").join("config.toml");
if primary.exists() {
return Ok(primary);
}
let legacy = home.join(".deepseek").join("config.toml");
if legacy.exists() {
return Ok(legacy);
}
Ok(primary)
}
#[cfg(test)]
mod tests {
use super::*;
use std::env;
use std::ffi::OsString;
use std::fs;
use std::path::Path;
use std::time::{SystemTime, UNIX_EPOCH};
struct EnvGuard {
home: Option<OsString>,
userprofile: Option<OsString>,
codewhale_config_path: Option<OsString>,
deepseek_config_path: Option<OsString>,
_lock: std::sync::MutexGuard<'static, ()>,
}
impl EnvGuard {
fn new(home: &Path) -> Self {
let lock = crate::test_support::lock_test_env();
let home_str = OsString::from(home.as_os_str());
let config_path = home.join(".deepseek").join("config.toml");
let config_str = OsString::from(config_path.as_os_str());
let home_prev = env::var_os("HOME");
let userprofile_prev = env::var_os("USERPROFILE");
let codewhale_config_prev = env::var_os("CODEWHALE_CONFIG_PATH");
let deepseek_config_prev = env::var_os("DEEPSEEK_CONFIG_PATH");
// Safety: test-only environment mutation guarded by process-wide mutex.
unsafe {
env::set_var("HOME", &home_str);
env::set_var("USERPROFILE", &home_str);
env::remove_var("CODEWHALE_CONFIG_PATH");
env::set_var("DEEPSEEK_CONFIG_PATH", &config_str);
}
Self {
home: home_prev,
userprofile: userprofile_prev,
codewhale_config_path: codewhale_config_prev,
deepseek_config_path: deepseek_config_prev,
_lock: lock,
}
}
}
impl Drop for EnvGuard {
fn drop(&mut self) {
if let Some(value) = self.home.take() {
// Safety: test-only environment mutation guarded by a global mutex.
unsafe {
env::set_var("HOME", value);
}
} else {
// Safety: test-only environment mutation guarded by a global mutex.
unsafe {
env::remove_var("HOME");
}
}
if let Some(value) = self.userprofile.take() {
// Safety: test-only environment mutation guarded by a global mutex.
unsafe {
env::set_var("USERPROFILE", value);
}
} else {
// Safety: test-only environment mutation guarded by a global mutex.
unsafe {
env::remove_var("USERPROFILE");
}
}
if let Some(value) = self.codewhale_config_path.take() {
// Safety: test-only environment mutation guarded by a global mutex.
unsafe {
env::set_var("CODEWHALE_CONFIG_PATH", value);
}
} else {
// Safety: test-only environment mutation guarded by a global mutex.
unsafe {
env::remove_var("CODEWHALE_CONFIG_PATH");
}
}
if let Some(value) = self.deepseek_config_path.take() {
// Safety: test-only environment mutation guarded by a global mutex.
unsafe {
env::set_var("DEEPSEEK_CONFIG_PATH", value);
}
} else {
// Safety: test-only environment mutation guarded by a global mutex.
unsafe {
env::remove_var("DEEPSEEK_CONFIG_PATH");
}
}
}
}
fn temp_root(prefix: &str) -> std::path::PathBuf {
let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos();
env::temp_dir().join(format!("{prefix}-{}-{nanos}", std::process::id()))
}
#[test]
fn persist_status_items_writes_tui_section_to_config_toml() {
let temp_root = temp_root("codewhale-statusline-persist");
fs::create_dir_all(&temp_root).unwrap();
let _guard = EnvGuard::new(&temp_root);
let items = vec![
crate::config::StatusItem::Mode,
crate::config::StatusItem::Model,
crate::config::StatusItem::Cost,
];
let path = persist_status_items(&items).expect("persist should succeed");
let body = fs::read_to_string(&path).expect("written file should be readable");
assert!(body.contains("[tui]"), "expected [tui] section in {body}");
assert!(
body.contains("status_items"),
"expected status_items key in {body}"
);
assert!(body.contains("\"mode\""), "expected mode key in {body}");
assert!(body.contains("\"cost\""), "expected cost key in {body}");
}
#[test]
fn config_toml_path_uses_codewhale_home_for_fresh_installs() {
let temp_root = temp_root("codewhale-config-path-fresh");
fs::create_dir_all(&temp_root).unwrap();
let _guard = EnvGuard::new(&temp_root);
unsafe {
env::remove_var("DEEPSEEK_CONFIG_PATH");
}
assert_eq!(
config_toml_path(None).unwrap(),
temp_root.join(".codewhale").join("config.toml")
);
}
#[test]
fn config_toml_path_preserves_legacy_config_when_it_exists() {
let temp_root = temp_root("codewhale-config-path-legacy");
let legacy_config = temp_root.join(".deepseek").join("config.toml");
fs::create_dir_all(legacy_config.parent().unwrap()).unwrap();
fs::write(&legacy_config, "").unwrap();
let _guard = EnvGuard::new(&temp_root);
unsafe {
env::remove_var("DEEPSEEK_CONFIG_PATH");
}
assert_eq!(config_toml_path(None).unwrap(), legacy_config);
}
#[test]
fn config_toml_path_prefers_codewhale_env_over_legacy_env() {
let temp_root = temp_root("codewhale-config-path-env");
fs::create_dir_all(&temp_root).unwrap();
let _guard = EnvGuard::new(&temp_root);
let preferred = temp_root.join("preferred.toml");
let legacy = temp_root.join("legacy.toml");
unsafe {
env::set_var("CODEWHALE_CONFIG_PATH", &preferred);
env::set_var("DEEPSEEK_CONFIG_PATH", &legacy);
}
assert_eq!(config_toml_path(None).unwrap(), preferred);
}
#[test]
fn persist_status_items_preserves_existing_unrelated_keys() {
let temp_root = temp_root("codewhale-statusline-preserve");
fs::create_dir_all(&temp_root).unwrap();
let _guard = EnvGuard::new(&temp_root);
let path = temp_root.join(".deepseek").join("config.toml");
fs::create_dir_all(path.parent().unwrap()).unwrap();
fs::write(
&path,
"api_key = \"sentinel-key\"\nmodel = \"deepseek-v4-pro\"\n",
)
.unwrap();
let written = persist_status_items(&[crate::config::StatusItem::Mode])
.expect("persist should succeed");
let body = fs::read_to_string(&written).expect("written file should be readable");
assert!(
body.contains("api_key = \"sentinel-key\""),
"round-trip lost api_key: {body}"
);
assert!(
body.contains("model = \"deepseek-v4-pro\""),
"round-trip lost model: {body}"
);
assert!(
body.contains("status_items"),
"expected status_items in {body}"
);
}
}
+2 -2
View File
@@ -596,7 +596,7 @@ pub fn apply_document(
app.status_items = new_status_items.clone();
app.needs_redraw = true;
if persist {
let path = commands::persist_status_items(&new_status_items)?;
let path = crate::config_persistence::persist_status_items(&new_status_items)?;
notes.push(format!("status_items saved to {}", path.display()));
} else {
notes.push("status_items updated for this session".to_string());
@@ -685,7 +685,7 @@ fn apply_reasoning_effort(
app.last_effective_reasoning_effort = None;
app.update_model_compaction_budget();
if persist {
commands::persist_root_string_key(
crate::config_persistence::persist_root_string_key(
app.config_path.as_deref(),
"reasoning_effort",
effort.as_setting(),
+9 -1
View File
@@ -27,6 +27,7 @@ mod compaction;
mod composer_history;
mod composer_stash;
mod config;
mod config_persistence;
mod config_ui;
mod core;
mod cost_status;
@@ -46,6 +47,7 @@ mod lsp;
mod mcp;
mod mcp_server;
mod memory;
mod model_routing;
mod models;
mod network_policy;
mod palette;
@@ -5505,7 +5507,7 @@ struct CliAutoRoute {
async fn resolve_cli_auto_route(config: &Config, model: &str, prompt: &str) -> CliAutoRoute {
if model.trim().eq_ignore_ascii_case("auto") {
let selection =
commands::resolve_auto_route_with_flash(config, prompt, "", "auto", "auto").await;
model_routing::resolve_auto_route_with_flash(config, prompt, "", "auto", "auto").await;
CliAutoRoute {
model: selection.model,
reasoning_effort: selection.reasoning_effort,
@@ -6709,6 +6711,12 @@ mod terminal_mode_tests {
.args(["config", "user.email", "codewhale@example.invalid"])
.status()
.expect("git config user.email");
std::process::Command::new("git")
.arg("-C")
.arg(repo)
.args(["config", "core.autocrlf", "false"])
.status()
.expect("git config core.autocrlf");
std::fs::write(
repo.join("math_utils.py"),
"def add(a, b):\n return a - b\n",
+569
View File
@@ -0,0 +1,569 @@
//! Model selection and auto-routing.
//!
//! The CLI, TUI, runtime threads, subagents, and command handlers all need
//! this behavior, so it intentionally lives outside the command tree.
use std::time::Duration;
use anyhow::Result;
use crate::client::DeepSeekClient;
use crate::config::Config;
use crate::llm_client::LlmClient;
use crate::models::{ContentBlock, Message, MessageRequest, MessageResponse, SystemPrompt};
use crate::tui::app::ReasoningEffort;
/// Auto-select a model based on request complexity.
///
/// Short messages (<100 chars) go to Flash. Long messages and requests with
/// complex keywords go to Pro. The fallback is Flash.
pub(crate) fn auto_model_heuristic(input: &str, current_model: &str) -> String {
auto_model_heuristic_with_bias(input, current_model, false)
}
fn auto_model_heuristic_with_bias(input: &str, current_model: &str, cost_saving: bool) -> String {
auto_model_heuristic_selection_with_bias(input, current_model, cost_saving).model
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AutoModelHeuristicConfidence {
Decisive,
Ambiguous,
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct AutoModelHeuristicSelection {
model: String,
confidence: AutoModelHeuristicConfidence,
}
fn auto_model_heuristic_selection_with_bias(
input: &str,
_current_model: &str,
cost_saving: bool,
) -> AutoModelHeuristicSelection {
let len = input.chars().count();
let lower = input.to_lowercase();
let borderline_pro_keywords: &[&str] = &[
"implement",
"analyze",
"\u{5b9e}\u{73b0}",
"\u{5206}\u{6790}",
"\u{5be6}\u{73fe}",
];
let strong_match = COMPLEX_KEYWORDS
.iter()
.any(|kw| !borderline_pro_keywords.contains(kw) && lower.contains(kw));
let borderline_match = borderline_pro_keywords.iter().any(|kw| lower.contains(kw));
let pro_match = strong_match || (!cost_saving && borderline_match);
if pro_match {
return AutoModelHeuristicSelection {
model: "deepseek-v4-pro".to_string(),
confidence: AutoModelHeuristicConfidence::Decisive,
};
}
if len < 100 {
return AutoModelHeuristicSelection {
model: "deepseek-v4-flash".to_string(),
confidence: AutoModelHeuristicConfidence::Decisive,
};
}
let long_threshold = if cost_saving { 1_000 } else { 500 };
if len > long_threshold {
return AutoModelHeuristicSelection {
model: "deepseek-v4-pro".to_string(),
confidence: AutoModelHeuristicConfidence::Decisive,
};
}
AutoModelHeuristicSelection {
model: "deepseek-v4-flash".to_string(),
confidence: AutoModelHeuristicConfidence::Ambiguous,
}
}
const COMPLEX_KEYWORDS: &[&str] = &[
"refactor",
"architecture",
"design",
"debug",
"security",
"review",
"audit",
"migrate",
"optimize",
"rewrite",
"implement",
"analyze",
"\u{91cd}\u{6784}",
"\u{67b6}\u{6784}",
"\u{8bbe}\u{8ba1}",
"\u{8c03}\u{8bd5}",
"\u{5b89}\u{5168}",
"\u{5ba1}\u{67e5}",
"\u{5ba1}\u{8ba1}",
"\u{8fc1}\u{79fb}",
"\u{4f18}\u{5316}",
"\u{91cd}\u{5199}",
"\u{5b9e}\u{73b0}",
"\u{5206}\u{6790}",
"\u{91cd}\u{69cb}",
"\u{67b6}\u{69cb}",
"\u{8a2d}\u{8a08}",
"\u{8abf}\u{8a66}",
"\u{5be9}\u{67e5}",
"\u{5be9}\u{8a08}",
"\u{9077}\u{79fb}",
"\u{512a}\u{5316}",
"\u{91cd}\u{5beb}",
"\u{5be6}\u{73fe}",
];
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct AutoRouteRecommendation {
pub(crate) model: String,
pub(crate) reasoning_effort: Option<ReasoningEffort>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum AutoRouteSource {
FlashRouter,
Heuristic,
}
impl AutoRouteSource {
#[must_use]
pub(crate) fn label(self) -> &'static str {
match self {
AutoRouteSource::FlashRouter => "flash-router",
AutoRouteSource::Heuristic => "heuristic",
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct AutoRouteSelection {
pub(crate) model: String,
pub(crate) reasoning_effort: Option<ReasoningEffort>,
pub(crate) source: AutoRouteSource,
}
const AUTO_MODEL_ROUTER_SYSTEM_PROMPT: &str = "\
You are the codewhale auto-routing classifier. Return only compact JSON: \
{\"model\":\"deepseek-v4-flash|deepseek-v4-pro\",\"thinking\":\"off|high|max\"}. \
Use deepseek-v4-flash for trivial, conversational, status, or single-step work. \
Use deepseek-v4-pro for coding, debugging, release work, multi-step tasks, high-risk decisions, \
tool-heavy work, ambiguous requests, or anything that benefits from deeper reasoning. \
Use thinking off only for trivial no-tool answers, high for ordinary reasoning, and max for \
agentic, coding, multi-file, release, architecture, debugging, security, tool-heavy, or uncertain work.";
const AUTO_MODEL_ROUTER_COST_SAVING_ADDENDUM: &str = "\
\n\nCost-saving mode is ON. Prefer deepseek-v4-flash for any request that is \
not unmistakably agentic, multi-step, architecture/design, security review, \
debugging, or otherwise clearly out of Flash's capability. Resolve ambiguous \
cases in favour of deepseek-v4-flash, not deepseek-v4-pro.";
pub(crate) fn parse_auto_route_recommendation(raw: &str) -> Option<AutoRouteRecommendation> {
let json = extract_first_json_object(raw)?;
let value: serde_json::Value = serde_json::from_str(json).ok()?;
let model = value.get("model").and_then(serde_json::Value::as_str)?;
let model = normalize_auto_route_model(model)?;
let reasoning_effort = value
.get("thinking")
.or_else(|| value.get("reasoning_effort"))
.or_else(|| value.get("effort"))
.and_then(serde_json::Value::as_str)
.and_then(parse_auto_route_reasoning_effort);
Some(AutoRouteRecommendation {
model: model.to_string(),
reasoning_effort,
})
}
fn extract_first_json_object(raw: &str) -> Option<&str> {
let start = raw.find('{')?;
let end = raw.rfind('}')?;
(end >= start).then_some(&raw[start..=end])
}
fn normalize_auto_route_model(model: &str) -> Option<&'static str> {
match model.trim().to_ascii_lowercase().as_str() {
"deepseek-v4-pro" | "v4-pro" | "pro" => Some("deepseek-v4-pro"),
"deepseek-v4-flash" | "v4-flash" | "flash" => Some("deepseek-v4-flash"),
_ => None,
}
}
fn parse_auto_route_reasoning_effort(effort: &str) -> Option<ReasoningEffort> {
match effort.trim().to_ascii_lowercase().as_str() {
"off" | "disabled" | "none" | "false" => Some(ReasoningEffort::Off),
"low" | "minimal" | "medium" | "mid" => Some(ReasoningEffort::High),
"high" => Some(ReasoningEffort::High),
"max" | "maximum" | "xhigh" => Some(ReasoningEffort::Max),
_ => None,
}
}
#[must_use]
pub(crate) fn normalize_auto_route_effort(effort: ReasoningEffort) -> ReasoningEffort {
match effort {
ReasoningEffort::Low | ReasoningEffort::Medium => ReasoningEffort::High,
other => other,
}
}
pub(crate) async fn resolve_auto_route_with_flash(
config: &Config,
latest_request: &str,
recent_context: &str,
selected_model_mode: &str,
selected_thinking_mode: &str,
) -> AutoRouteSelection {
let cost_saving = config.auto_cost_saving();
let heuristic =
auto_model_heuristic_selection_with_bias(latest_request, selected_model_mode, cost_saving);
if heuristic.confidence == AutoModelHeuristicConfidence::Decisive {
return auto_route_from_heuristic(latest_request, heuristic);
}
match auto_route_flash_recommendation(
config,
latest_request,
recent_context,
selected_model_mode,
selected_thinking_mode,
)
.await
{
Ok(Some(recommendation)) => AutoRouteSelection {
model: recommendation.model,
reasoning_effort: recommendation.reasoning_effort,
source: AutoRouteSource::FlashRouter,
},
Ok(None) | Err(_) => auto_route_from_heuristic(latest_request, heuristic),
}
}
fn auto_route_from_heuristic(
latest_request: &str,
heuristic: AutoModelHeuristicSelection,
) -> AutoRouteSelection {
AutoRouteSelection {
model: heuristic.model,
reasoning_effort: Some(normalize_auto_route_effort(crate::auto_reasoning::select(
false,
latest_request,
))),
source: AutoRouteSource::Heuristic,
}
}
async fn auto_route_flash_recommendation(
config: &Config,
latest_request: &str,
recent_context: &str,
selected_model_mode: &str,
selected_thinking_mode: &str,
) -> Result<Option<AutoRouteRecommendation>> {
if cfg!(test) {
return Ok(None);
}
let client = DeepSeekClient::new(config)?;
let mut router_system = AUTO_MODEL_ROUTER_SYSTEM_PROMPT.to_string();
if config.auto_cost_saving() {
router_system.push_str(AUTO_MODEL_ROUTER_COST_SAVING_ADDENDUM);
}
let request = MessageRequest {
model: "deepseek-v4-flash".to_string(),
messages: vec![Message {
role: "user".to_string(),
content: vec![ContentBlock::Text {
text: auto_route_prompt(
latest_request,
recent_context,
selected_model_mode,
selected_thinking_mode,
),
cache_control: None,
}],
}],
max_tokens: 96,
system: Some(SystemPrompt::Text(router_system)),
tools: None,
tool_choice: None,
metadata: None,
thinking: None,
reasoning_effort: Some("off".to_string()),
stream: Some(false),
temperature: Some(0.0),
top_p: None,
};
let response =
tokio::time::timeout(Duration::from_secs(4), client.create_message(request)).await??;
Ok(parse_auto_route_recommendation(&message_response_text(
&response,
)))
}
fn auto_route_prompt(
latest_request: &str,
recent_context: &str,
selected_model_mode: &str,
selected_thinking_mode: &str,
) -> String {
format!(
"Session mode: agent\nSelected model mode: {}\nSelected thinking mode: {}\n\nRecent context:\n{}\n\nLatest user request:\n{}\n\nReturn JSON only.",
selected_model_mode,
selected_thinking_mode,
if recent_context.trim().is_empty() {
"No prior context."
} else {
recent_context
},
truncate_for_auto_router(latest_request, 4_000)
)
}
fn message_response_text(response: &MessageResponse) -> String {
let mut out = String::new();
for block in &response.content {
match block {
ContentBlock::Text { text, .. } | ContentBlock::ToolResult { content: text, .. } => {
append_router_text(&mut out, text);
}
ContentBlock::Thinking { thinking } => {
append_router_text(&mut out, thinking);
}
ContentBlock::ToolUse { name, .. } => {
append_router_text(&mut out, &format!("[tool call: {name}]"));
}
_ => {}
}
}
out
}
fn append_router_text(out: &mut String, text: &str) {
if !out.is_empty() {
out.push('\n');
}
out.push_str(text);
}
fn truncate_for_auto_router(text: &str, max_chars: usize) -> String {
let mut chars = text.chars();
let truncated: String = chars.by_ref().take(max_chars).collect();
if chars.next().is_some() {
format!("{truncated}...")
} else {
truncated
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn auto_model_heuristic_chinese_keywords_route_to_pro() {
for msg in [
"\u{5e2e}\u{6211}\u{91cd}\u{6784}\u{8fd9}\u{4e2a}\u{6a21}\u{5757}",
"\u{8bbe}\u{8ba1}\u{6570}\u{636e}\u{5e93}\u{67b6}\u{6784}",
"\u{8c03}\u{8bd5}\u{5d29}\u{6e83}\u{95ee}\u{9898}",
"\u{5ba1}\u{8ba1}\u{5b89}\u{5168}\u{6f0f}\u{6d1e}",
"\u{8fc1}\u{79fb}\u{5230}\u{65b0}\u{6846}\u{67b6}",
"\u{4f18}\u{5316}\u{6027}\u{80fd}\u{74f6}\u{9888}",
"\u{5206}\u{6790}\u{8fd9}\u{6bb5}\u{4ee3}\u{7801}",
] {
assert_eq!(
auto_model_heuristic(msg, "auto"),
"deepseek-v4-pro",
"expected Pro for `{msg}`",
);
}
}
#[test]
fn auto_model_heuristic_traditional_chinese_keywords_route_to_pro() {
for msg in [
"\u{8acb}\u{91cd}\u{69cb}\u{6b64}\u{6a21}\u{7d44}",
"\u{67b6}\u{69cb}\u{8a2d}\u{8a08}",
"\u{4ee3}\u{78bc}\u{8abf}\u{8a66}",
"\u{5be9}\u{8a08}\u{6f0f}\u{6d1e}",
"\u{9077}\u{79fb}\u{5230}\u{65b0}\u{67b6}\u{69cb}",
"\u{512a}\u{5316}\u{6027}\u{80fd}",
"\u{91cd}\u{5beb}\u{4ee3}\u{78bc}",
"\u{5be6}\u{73fe}\u{65b0}\u{529f}\u{80fd}",
] {
assert_eq!(
auto_model_heuristic(msg, "auto"),
"deepseek-v4-pro",
"expected Pro for `{msg}`",
);
}
}
#[test]
fn auto_model_heuristic_short_chinese_chat_stays_on_flash() {
assert_eq!(
auto_model_heuristic("\u{4f60}\u{597d}", "auto"),
"deepseek-v4-flash",
);
}
#[test]
fn auto_heuristic_selection_marks_short_and_complex_routes_decisive() {
let short = auto_model_heuristic_selection_with_bias("yes", "auto", false);
assert_eq!(short.model, "deepseek-v4-flash");
assert_eq!(
short.confidence,
AutoModelHeuristicConfidence::Decisive,
"trivial replies should skip the Flash router"
);
let complex = auto_model_heuristic_selection_with_bias(
"Please review the auth migration",
"auto",
false,
);
assert_eq!(complex.model, "deepseek-v4-pro");
assert_eq!(
complex.confidence,
AutoModelHeuristicConfidence::Decisive,
"strong complexity keywords should skip the Flash router"
);
}
#[test]
fn auto_heuristic_selection_leaves_default_branch_ambiguous_for_router() {
let request =
"Please update the configuration notes so each option has a clearer label. ".repeat(3);
assert!(
(100..500).contains(&request.chars().count()),
"test request must stay in the default grey zone"
);
let selection = auto_model_heuristic_selection_with_bias(&request, "auto", false);
assert_eq!(selection.model, "deepseek-v4-flash");
assert_eq!(
selection.confidence,
AutoModelHeuristicConfidence::Ambiguous,
"only the grey-zone default branch should invoke the Flash router"
);
}
#[test]
fn auto_route_recommendation_parses_strict_json() {
let rec =
parse_auto_route_recommendation(r#"{"model":"deepseek-v4-pro","thinking":"max"}"#)
.expect("valid router response should parse");
assert_eq!(rec.model, "deepseek-v4-pro");
assert_eq!(rec.reasoning_effort, Some(ReasoningEffort::Max));
}
#[test]
fn auto_route_recommendation_accepts_wrapped_json_aliases() {
let rec =
parse_auto_route_recommendation(r#"route: {"model":"flash","reasoning_effort":"off"}"#)
.expect("wrapped router response should parse");
assert_eq!(rec.model, "deepseek-v4-flash");
assert_eq!(rec.reasoning_effort, Some(ReasoningEffort::Off));
}
#[test]
fn auto_route_recommendation_normalizes_legacy_low_medium_to_high() {
let rec = parse_auto_route_recommendation(
r#"{"model":"deepseek-v4-pro","reasoning_effort":"medium"}"#,
)
.expect("medium should parse for back-compat");
assert_eq!(rec.model, "deepseek-v4-pro");
assert_eq!(rec.reasoning_effort, Some(ReasoningEffort::High));
}
#[test]
fn auto_route_recommendation_rejects_unknown_model() {
assert!(
parse_auto_route_recommendation(r#"{"model":"some-other-model","thinking":"max"}"#,)
.is_none()
);
}
#[test]
fn auto_heuristic_default_routes_implement_to_pro() {
assert_eq!(
auto_model_heuristic_with_bias("Please implement a binary search", "auto", false),
"deepseek-v4-pro"
);
}
#[test]
fn auto_heuristic_cost_saving_keeps_borderline_keywords_on_flash() {
assert_eq!(
auto_model_heuristic_with_bias("Please implement a binary search", "auto", true),
"deepseek-v4-flash"
);
assert_eq!(
auto_model_heuristic_with_bias("analyze this snippet", "auto", true),
"deepseek-v4-flash"
);
}
#[test]
fn auto_heuristic_strong_keywords_still_route_to_pro_under_cost_saving() {
for kw in [
"refactor",
"architecture",
"design",
"debug",
"security",
"review",
"audit",
"migrate",
"optimize",
"rewrite",
] {
let req = format!("Please {kw} this module");
assert_eq!(
auto_model_heuristic_with_bias(&req, "auto", true),
"deepseek-v4-pro",
"expected Pro for strong keyword `{kw}` even in cost-saving mode"
);
}
}
#[test]
fn auto_heuristic_cost_saving_raises_long_message_threshold() {
let body = "filler sentence. ".repeat(40);
assert_eq!(
auto_model_heuristic_with_bias(&body, "auto", false),
"deepseek-v4-pro"
);
assert_eq!(
auto_model_heuristic_with_bias(&body, "auto", true),
"deepseek-v4-flash"
);
}
#[test]
fn config_auto_cost_saving_defaults_to_false() {
let cfg = Config::default();
assert!(!cfg.auto_cost_saving());
}
#[test]
fn config_auto_cost_saving_reads_table() {
let cfg = Config {
auto: Some(crate::config::AutoConfig {
cost_saving: Some(true),
}),
..Default::default()
};
assert!(cfg.auto_cost_saving());
}
}
+1 -1
View File
@@ -1660,7 +1660,7 @@ impl RuntimeThreadManager {
let requested_model = req.model.unwrap_or_else(|| thread.model.clone());
let auto_model = requested_model.trim().eq_ignore_ascii_case("auto");
let (model, reasoning_effort) = if auto_model {
let selection = crate::commands::resolve_auto_route_with_flash(
let selection = crate::model_routing::resolve_auto_route_with_flash(
&self.config,
&prompt,
"",
+3 -3
View File
@@ -5140,7 +5140,7 @@ fn fallback_subagent_assignment_route(
let model = if let Some(model) = configured_model {
model
} else if runtime.auto_model {
crate::commands::auto_model_heuristic(prompt, &runtime.model)
crate::model_routing::auto_model_heuristic(prompt, &runtime.model)
} else {
runtime.model.clone()
};
@@ -5166,7 +5166,7 @@ fn fallback_subagent_assignment_route(
async fn subagent_flash_router(
runtime: &SubAgentRuntime,
prompt: &str,
) -> Result<Option<crate::commands::AutoRouteRecommendation>> {
) -> Result<Option<crate::model_routing::AutoRouteRecommendation>> {
if cfg!(test) {
return Ok(None);
}
@@ -5199,7 +5199,7 @@ async fn subagent_flash_router(
runtime.client.create_message(request),
)
.await??;
Ok(crate::commands::parse_auto_route_recommendation(
Ok(crate::model_routing::parse_auto_route_recommendation(
&message_response_text(&response.content),
))
}
+5 -5
View File
@@ -4,12 +4,12 @@
//! The TUI calls `resolve_auto_model_selection` once per user turn when
//! `app.auto_model` is set. The async function builds a recent-context
//! summary from `api_messages` (capped to six rows of up to 900 chars
//! each), passes it through `commands::resolve_auto_route_with_flash`,
//! each), passes it through `model_routing::resolve_auto_route_with_flash`,
//! and returns the selection (model + reasoning effort). The remaining
//! helpers are pure transforms used to build that summary.
use crate::commands;
use crate::config::Config;
use crate::model_routing;
use crate::models::{ContentBlock, Message};
use crate::tui::app::{App, QueuedMessage, ReasoningEffort};
@@ -25,13 +25,13 @@ pub(super) async fn resolve_auto_model_selection(
config: &Config,
message: &QueuedMessage,
latest_content: &str,
) -> commands::AutoRouteSelection {
) -> model_routing::AutoRouteSelection {
let latest_request = if latest_content.trim().is_empty() {
message.display.as_str()
} else {
latest_content
};
commands::resolve_auto_route_with_flash(
model_routing::resolve_auto_route_with_flash(
config,
latest_request,
&recent_auto_router_context(&app.api_messages),
@@ -43,7 +43,7 @@ pub(super) async fn resolve_auto_model_selection(
/// Normalize the heuristic effort to the canonical auto-route effort.
pub(super) fn normalize_auto_routed_effort(effort: ReasoningEffort) -> ReasoningEffort {
commands::normalize_auto_route_effort(effort)
model_routing::normalize_auto_route_effort(effort)
}
/// Build a compact recent-context summary for the auto-route prompt.
+10 -4
View File
@@ -4769,7 +4769,7 @@ fn rollback_provider_after_auth_failure(app: &mut App, config: &mut Config) -> O
app.api_key_env_only = previous_api_key_env_only;
let persistence_error = (|| -> anyhow::Result<()> {
commands::persist_root_string_key(
crate::config_persistence::persist_root_string_key(
app.config_path.as_deref(),
"provider",
previous_provider.as_str(),
@@ -5348,7 +5348,9 @@ async fn dispatch_user_message(
auto_selection
.as_ref()
.map(|selection| selection.model.clone())
.unwrap_or_else(|| commands::auto_model_heuristic(&message.display, &app.model))
.unwrap_or_else(|| {
crate::model_routing::auto_model_heuristic(&message.display, &app.model)
})
} else {
app.model.clone()
};
@@ -5813,7 +5815,11 @@ async fn switch_provider(
.await;
let persist_warning = (|| -> anyhow::Result<()> {
commands::persist_root_string_key(app.config_path.as_deref(), "provider", target.as_str())?;
crate::config_persistence::persist_root_string_key(
app.config_path.as_deref(),
"provider",
target.as_str(),
)?;
let mut settings = crate::settings::Settings::load()?;
settings.default_provider = Some(target.as_str().to_string());
@@ -7732,7 +7738,7 @@ async fn handle_view_events(
app.status_items = items.clone();
app.needs_redraw = true;
if final_save {
match commands::persist_status_items(&items) {
match crate::config_persistence::persist_status_items(&items) {
Ok(path) => {
app.status_message =
Some(format!("Status line saved to {}", path.display()));