Merge pull request #2280 from Hmbown/fix/1572-custom-model-switch

fix(tui): accept custom model IDs in /model for non-DeepSeek providers (#1572)
This commit is contained in:
Hunter Bown
2026-05-31 00:42:22 -07:00
committed by GitHub
12 changed files with 147 additions and 68 deletions
+50 -6
View File
@@ -3,7 +3,9 @@
use std::fmt::Write;
use std::path::PathBuf;
use crate::config::{COMMON_DEEPSEEK_MODELS, normalize_model_name_for_provider};
use crate::config::{
COMMON_DEEPSEEK_MODELS, normalize_custom_model_id, normalize_model_name_for_provider,
};
use crate::localization::{MessageId, tr};
use crate::tui::app::{App, AppAction, AppMode, ReasoningEffort};
use crate::tui::views::{HelpView, ModalKind, SubAgentsView, subagent_view_agents};
@@ -135,11 +137,21 @@ pub fn model(app: &mut App, model_name: Option<&str>) -> CommandResult {
AppAction::UpdateCompaction(app.compaction_config()),
);
}
let Some(model_id) = normalize_model_name_for_provider(app.api_provider, name) else {
return CommandResult::error(format!(
"Invalid model '{name}'. Expected auto or a DeepSeek model ID. Common models: {}",
COMMON_DEEPSEEK_MODELS.join(", ")
));
let model_id = if app.accepts_custom_model_ids() {
let Some(model_id) = normalize_custom_model_id(name) else {
return CommandResult::error(format!(
"Invalid model '{name}'. Expected a non-empty model ID."
));
};
model_id
} else {
let Some(model_id) = normalize_model_name_for_provider(app.api_provider, name) else {
return CommandResult::error(format!(
"Invalid model '{name}'. Expected auto or a DeepSeek model ID. Common models: {}",
COMMON_DEEPSEEK_MODELS.join(", ")
));
};
model_id
};
let old_model = app.model_display_label();
let model_changed = app.auto_model || app.model != model_id;
@@ -729,6 +741,38 @@ mod tests {
));
}
#[test]
fn test_model_change_accepts_custom_id_for_openai_compatible_provider() {
let mut app = create_test_app();
app.api_provider = crate::config::ApiProvider::Openai;
app.model_ids_passthrough = true;
let result = model(&mut app, Some("opencode-go/glm-5.1"));
assert!(result.message.is_some());
assert_eq!(app.model, "opencode-go/glm-5.1");
assert!(!app.auto_model);
assert!(matches!(
result.action,
Some(AppAction::UpdateCompaction(_))
));
}
#[test]
fn test_model_change_accepts_custom_id_for_custom_base_url() {
let mut app = create_test_app();
app.model_ids_passthrough = true;
let result = model(&mut app, Some("opencode-go/kimi-k2.6"));
assert!(result.message.is_some());
assert_eq!(app.model, "opencode-go/kimi-k2.6");
assert!(matches!(
result.action,
Some(AppAction::UpdateCompaction(_))
));
}
#[test]
fn test_model_change_rejects_invalid_model() {
let mut app = create_test_app();
+2 -4
View File
@@ -85,10 +85,8 @@ fn close_hunt(app: &mut App, verdict: HuntVerdict) -> CommandResult {
let prev = app.hunt.verdict;
let should_write_trophy = prev != verdict || !matches!(verdict, HuntVerdict::Hunted);
if should_write_trophy {
if let Err(err) = write_trophy_card(app, verdict) {
return CommandResult::error(err);
}
if should_write_trophy && let Err(err) = write_trophy_card(app, verdict) {
return CommandResult::error(err);
}
app.hunt.verdict = verdict;
+16
View File
@@ -429,6 +429,16 @@ pub fn normalize_model_name(model: &str) -> Option<String> {
None
}
#[must_use]
pub(crate) fn normalize_custom_model_id(model: &str) -> Option<String> {
let trimmed = model.trim();
if trimmed.is_empty() || trimmed.chars().any(char::is_control) {
None
} else {
Some(trimmed.to_string())
}
}
fn canonical_official_deepseek_model_id(model: &str) -> Option<&'static str> {
match model.trim().to_ascii_lowercase().as_str() {
"deepseek-v4-pro"
@@ -1966,6 +1976,12 @@ impl Config {
provider_preserves_custom_base_url_model(provider, &self.deepseek_base_url())
}
pub(crate) fn model_ids_pass_through(&self) -> bool {
let provider = self.api_provider();
provider_passes_model_through(provider)
|| self.active_provider_preserves_custom_base_url_model()
}
/// Read the API key.
///
/// Precedence: **explicit in-memory override → provider/root config
+5 -5
View File
@@ -2002,11 +2002,11 @@ impl Engine {
// system prompt so the agent can autonomously review them before
// claiming the task is done (#2127).
let gate_block = self.slop_ledger_gate_block();
if let Some(ref block) = gate_block {
if let Some(SystemPrompt::Text(prompt_text)) = &mut stable_prompt {
prompt_text.push_str("\n\n");
prompt_text.push_str(block);
}
if let Some(ref block) = gate_block
&& let Some(SystemPrompt::Text(prompt_text)) = &mut stable_prompt
{
prompt_text.push_str("\n\n");
prompt_text.push_str(block);
}
let stable_hash = system_prompt_hash(stable_prompt.as_ref());
+4 -4
View File
@@ -4268,10 +4268,10 @@ async fn run_mcp_command(config: &Config, command: McpCommand) -> Result<()> {
if command.is_none() && url.is_none() {
bail!("Provide either --command or --url for `mcp add`.");
}
if let Some(transport) = transport.as_deref() {
if !transport.trim().eq_ignore_ascii_case("sse") {
bail!("Unsupported MCP transport '{transport}'. Supported values: sse");
}
if let Some(transport) = transport.as_deref()
&& !transport.trim().eq_ignore_ascii_case("sse")
{
bail!("Unsupported MCP transport '{transport}'. Supported values: sse");
}
let mut cfg = load_mcp_config(&config_path)?;
cfg.servers.insert(
+4 -4
View File
@@ -1832,10 +1832,10 @@ impl McpConnection {
// IDs, but accept numeric echoes for compatibility with older
// servers and tests.
if response_id_matches(value.get("id"), &expected_id) {
if let Some(error) = value.get("error") {
if is_mcp_stale_session_body(&error.to_string()) {
anyhow::bail!("MCP session expired: {error}");
}
if let Some(error) = value.get("error")
&& is_mcp_stale_session_body(&error.to_string())
{
anyhow::bail!("MCP session expired: {error}");
}
return Ok(value);
}
+16 -17
View File
@@ -263,7 +263,7 @@ impl SlopLedger {
pub fn default_path() -> io::Result<PathBuf> {
codewhale_config::resolve_state_dir("slop_ledger")
.map(|p| p.join("slop_ledger.json"))
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
.map_err(io::Error::other)
}
/// Load ledger from the default path, returning an empty ledger if the
@@ -297,9 +297,8 @@ impl SlopLedger {
if let Some(parent) = self.ledger_path.parent() {
fs::create_dir_all(parent)?;
}
let data = serde_json::to_string_pretty(self).map_err(|e| {
io::Error::new(io::ErrorKind::Other, format!("serialization error: {e}"))
})?;
let data = serde_json::to_string_pretty(self)
.map_err(|e| io::Error::other(format!("serialization error: {e}")))?;
crate::utils::write_atomic(&self.ledger_path, data.as_bytes())
}
@@ -330,20 +329,20 @@ impl SlopLedger {
.entries
.iter()
.filter(|e| {
if let Some(bucket) = &filter.bucket {
if e.bucket != *bucket {
return false;
}
if let Some(bucket) = &filter.bucket
&& e.bucket != *bucket
{
return false;
}
if let Some(severity) = &filter.severity {
if e.severity != *severity {
return false;
}
if let Some(severity) = &filter.severity
&& e.severity != *severity
{
return false;
}
if let Some(status) = &filter.status {
if e.status != *status {
return false;
}
if let Some(status) = &filter.status
&& e.status != *status
{
return false;
}
if let Some(search) = &filter.search {
let q = search.to_lowercase();
@@ -408,7 +407,7 @@ impl SlopLedger {
let mut out = format!("# {heading}\n\n");
out.push_str(&format!(
"_Generated at {} — {} entries_\n\n",
chrono::Utc::now().format("%Y-%m-%d %H:%M UTC").to_string(),
chrono::Utc::now().format("%Y-%m-%d %H:%M UTC"),
entries.len()
));
+12 -7
View File
@@ -1003,21 +1003,16 @@ impl Default for ViewportState {
}
/// Verdict for a hunt (#2092).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HuntVerdict {
#[default]
Hunting,
Hunted,
Wounded,
Escaped,
}
impl Default for HuntVerdict {
fn default() -> Self {
Self::Hunting
}
}
/// Hunt tracking state (#2092 — was GoalState).
#[derive(Debug, Clone, Default)]
pub struct HuntState {
@@ -1163,6 +1158,9 @@ pub struct App {
/// Updated by `/provider` switches so the UI/commands can read the
/// active backend without re-deriving it from the live config.
pub api_provider: ApiProvider,
/// True when the active provider/base URL accepts arbitrary model IDs
/// verbatim rather than DeepSeek-only aliases.
pub model_ids_passthrough: bool,
/// Current reasoning-effort tier for DeepSeek thinking mode.
/// Cycled via Shift+Tab; initialized from config at startup.
pub reasoning_effort: ReasoningEffort,
@@ -1716,6 +1714,7 @@ impl App {
}
let mut effective_auth_config = config.clone();
effective_auth_config.provider = Some(provider.as_str().to_string());
let model_ids_passthrough = effective_auth_config.model_ids_pass_through();
// Check if the effective provider has an API key. This must happen
// after settings.default_provider is applied; otherwise a saved
@@ -1906,6 +1905,7 @@ impl App {
auto_model,
last_effective_model: None,
api_provider: provider,
model_ids_passthrough,
reasoning_effort,
last_effective_reasoning_effort: None,
workspace,
@@ -4680,6 +4680,11 @@ impl App {
}
}
pub fn accepts_custom_model_ids(&self) -> bool {
self.model_ids_passthrough
|| crate::config::provider_passes_model_through(self.api_provider)
}
pub fn effective_model_for_budget(&self) -> &str {
if self.auto_model {
return self
+22 -5
View File
@@ -73,12 +73,13 @@ pub struct ModelPickerView {
impl ModelPickerView {
#[must_use]
pub fn new(app: &App) -> Self {
let hide_deepseek_models = crate::config::provider_passes_model_through(app.api_provider);
let hide_deepseek_models = app.accepts_custom_model_ids();
// Whale routes are DeepSeek-specific — only official providers get them.
let show_whale_routes = matches!(
app.api_provider,
crate::config::ApiProvider::Deepseek | crate::config::ApiProvider::DeepseekCN
);
let show_whale_routes = !hide_deepseek_models
&& matches!(
app.api_provider,
crate::config::ApiProvider::Deepseek | crate::config::ApiProvider::DeepseekCN
);
let initial_model = if app.auto_model {
"auto".to_string()
} else {
@@ -594,6 +595,7 @@ mod tests {
app.auto_model = false;
app.reasoning_effort = ReasoningEffort::Max;
app.api_provider = crate::config::ApiProvider::Deepseek;
app.model_ids_passthrough = false;
(app, lock)
}
@@ -672,6 +674,21 @@ mod tests {
assert_eq!(view.resolved_model(), "deepseek-v4-pro-2026-04-XX");
}
#[test]
fn picker_uses_pass_through_layout_for_custom_base_url_model_ids() {
let (mut app, _lock) = create_test_app();
app.model_ids_passthrough = true;
app.model = "opencode-go/glm-5.1".to_string();
app.auto_model = false;
let view = ModelPickerView::new(&app);
assert!(view.hide_deepseek_models);
assert!(!view.show_whale_routes);
assert!(view.show_custom_model_row);
assert_eq!(view.resolved_model(), "opencode-go/glm-5.1");
}
#[test]
fn arrow_keys_move_within_whale_routes() {
let (app, _lock) = create_test_app();
+10 -11
View File
@@ -1687,16 +1687,14 @@ async fn run_event_loop(
// required — so the agent can't forget to check.
if let Ok(ledger) = crate::slop_ledger::SlopLedger::load()
&& ledger.has_open_entries()
&& let Some(gate_msg) = ledger.completion_gate_summary()
{
if let Some(gate_msg) = ledger.completion_gate_summary() {
let short =
gate_msg.lines().nth(4).unwrap_or("review before done");
app.push_status_toast(
format!("⚠️ SlopLedger: {short}"),
crate::tui::app::StatusToastLevel::Warning,
Some(12_000),
);
}
let short = gate_msg.lines().nth(4).unwrap_or("review before done");
app.push_status_toast(
format!("⚠️ SlopLedger: {short}"),
crate::tui::app::StatusToastLevel::Warning,
Some(12_000),
);
}
let tool_count = app.tool_evidence.len();
@@ -1739,7 +1737,7 @@ async fn run_event_loop(
// adding latency to any request path.
let balance_cooldown_expired = app
.last_balance_fetch
.map_or(true, |t| t.elapsed() >= BALANCE_FETCH_COOLDOWN);
.is_none_or(|t| t.elapsed() >= BALANCE_FETCH_COOLDOWN);
if balance_cooldown_expired && should_fetch_deepseek_balance(app) {
let cell = app.balance_cell.clone();
let api_key = config.deepseek_api_key().unwrap_or_default();
@@ -4889,6 +4887,7 @@ async fn switch_provider(
let new_model = config.default_model();
let cache_scope_changed = previous_provider != target || previous_model != new_model;
app.api_provider = target;
app.model_ids_passthrough = config.model_ids_pass_through();
app.set_model_selection(new_model.clone());
app.update_model_compaction_budget();
if cache_scope_changed {
@@ -5112,7 +5111,7 @@ async fn apply_command_result(
// Refresh balance after provider switch.
let balance_cooldown_expired = app
.last_balance_fetch
.map_or(true, |t| t.elapsed() >= BALANCE_FETCH_COOLDOWN);
.is_none_or(|t| t.elapsed() >= BALANCE_FETCH_COOLDOWN);
if balance_cooldown_expired && should_fetch_deepseek_balance(app) {
let cell = app.balance_cell.clone();
let api_key = config.deepseek_api_key().unwrap_or_default();
+1
View File
@@ -1360,6 +1360,7 @@ fn create_test_options() -> TuiOptions {
}
#[tokio::test]
#[allow(clippy::await_holding_lock)]
async fn tool_result_api_content_receipts_large_live_output() {
let _guard = crate::tools::truncate::TEST_SPILLOVER_GUARD
.lock()
+5 -5
View File
@@ -2363,11 +2363,11 @@ fn push_command_entry(
}
}
}
if let Some(hint) = argument_hint {
if !hint.trim().is_empty() {
description.push_str(" ");
description.push_str(hint.trim());
}
if let Some(hint) = argument_hint
&& !hint.trim().is_empty()
{
description.push_str(" ");
description.push_str(hint.trim());
}
(description, None)
};