Merge pull request #2280 from Hmbown/fix/1572-custom-model-switch
fix(tui): accept custom model IDs in /model for non-DeepSeek providers (#1572)
This commit is contained in:
@@ -3,7 +3,9 @@
|
||||
use std::fmt::Write;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use crate::config::{COMMON_DEEPSEEK_MODELS, normalize_model_name_for_provider};
|
||||
use crate::config::{
|
||||
COMMON_DEEPSEEK_MODELS, normalize_custom_model_id, normalize_model_name_for_provider,
|
||||
};
|
||||
use crate::localization::{MessageId, tr};
|
||||
use crate::tui::app::{App, AppAction, AppMode, ReasoningEffort};
|
||||
use crate::tui::views::{HelpView, ModalKind, SubAgentsView, subagent_view_agents};
|
||||
@@ -135,11 +137,21 @@ pub fn model(app: &mut App, model_name: Option<&str>) -> CommandResult {
|
||||
AppAction::UpdateCompaction(app.compaction_config()),
|
||||
);
|
||||
}
|
||||
let Some(model_id) = normalize_model_name_for_provider(app.api_provider, name) else {
|
||||
return CommandResult::error(format!(
|
||||
"Invalid model '{name}'. Expected auto or a DeepSeek model ID. Common models: {}",
|
||||
COMMON_DEEPSEEK_MODELS.join(", ")
|
||||
));
|
||||
let model_id = if app.accepts_custom_model_ids() {
|
||||
let Some(model_id) = normalize_custom_model_id(name) else {
|
||||
return CommandResult::error(format!(
|
||||
"Invalid model '{name}'. Expected a non-empty model ID."
|
||||
));
|
||||
};
|
||||
model_id
|
||||
} else {
|
||||
let Some(model_id) = normalize_model_name_for_provider(app.api_provider, name) else {
|
||||
return CommandResult::error(format!(
|
||||
"Invalid model '{name}'. Expected auto or a DeepSeek model ID. Common models: {}",
|
||||
COMMON_DEEPSEEK_MODELS.join(", ")
|
||||
));
|
||||
};
|
||||
model_id
|
||||
};
|
||||
let old_model = app.model_display_label();
|
||||
let model_changed = app.auto_model || app.model != model_id;
|
||||
@@ -729,6 +741,38 @@ mod tests {
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_change_accepts_custom_id_for_openai_compatible_provider() {
|
||||
let mut app = create_test_app();
|
||||
app.api_provider = crate::config::ApiProvider::Openai;
|
||||
app.model_ids_passthrough = true;
|
||||
|
||||
let result = model(&mut app, Some("opencode-go/glm-5.1"));
|
||||
|
||||
assert!(result.message.is_some());
|
||||
assert_eq!(app.model, "opencode-go/glm-5.1");
|
||||
assert!(!app.auto_model);
|
||||
assert!(matches!(
|
||||
result.action,
|
||||
Some(AppAction::UpdateCompaction(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_change_accepts_custom_id_for_custom_base_url() {
|
||||
let mut app = create_test_app();
|
||||
app.model_ids_passthrough = true;
|
||||
|
||||
let result = model(&mut app, Some("opencode-go/kimi-k2.6"));
|
||||
|
||||
assert!(result.message.is_some());
|
||||
assert_eq!(app.model, "opencode-go/kimi-k2.6");
|
||||
assert!(matches!(
|
||||
result.action,
|
||||
Some(AppAction::UpdateCompaction(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_change_rejects_invalid_model() {
|
||||
let mut app = create_test_app();
|
||||
|
||||
@@ -85,10 +85,8 @@ fn close_hunt(app: &mut App, verdict: HuntVerdict) -> CommandResult {
|
||||
|
||||
let prev = app.hunt.verdict;
|
||||
let should_write_trophy = prev != verdict || !matches!(verdict, HuntVerdict::Hunted);
|
||||
if should_write_trophy {
|
||||
if let Err(err) = write_trophy_card(app, verdict) {
|
||||
return CommandResult::error(err);
|
||||
}
|
||||
if should_write_trophy && let Err(err) = write_trophy_card(app, verdict) {
|
||||
return CommandResult::error(err);
|
||||
}
|
||||
app.hunt.verdict = verdict;
|
||||
|
||||
|
||||
@@ -429,6 +429,16 @@ pub fn normalize_model_name(model: &str) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub(crate) fn normalize_custom_model_id(model: &str) -> Option<String> {
|
||||
let trimmed = model.trim();
|
||||
if trimmed.is_empty() || trimmed.chars().any(char::is_control) {
|
||||
None
|
||||
} else {
|
||||
Some(trimmed.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
fn canonical_official_deepseek_model_id(model: &str) -> Option<&'static str> {
|
||||
match model.trim().to_ascii_lowercase().as_str() {
|
||||
"deepseek-v4-pro"
|
||||
@@ -1966,6 +1976,12 @@ impl Config {
|
||||
provider_preserves_custom_base_url_model(provider, &self.deepseek_base_url())
|
||||
}
|
||||
|
||||
pub(crate) fn model_ids_pass_through(&self) -> bool {
|
||||
let provider = self.api_provider();
|
||||
provider_passes_model_through(provider)
|
||||
|| self.active_provider_preserves_custom_base_url_model()
|
||||
}
|
||||
|
||||
/// Read the API key.
|
||||
///
|
||||
/// Precedence: **explicit in-memory override → provider/root config
|
||||
|
||||
@@ -2002,11 +2002,11 @@ impl Engine {
|
||||
// system prompt so the agent can autonomously review them before
|
||||
// claiming the task is done (#2127).
|
||||
let gate_block = self.slop_ledger_gate_block();
|
||||
if let Some(ref block) = gate_block {
|
||||
if let Some(SystemPrompt::Text(prompt_text)) = &mut stable_prompt {
|
||||
prompt_text.push_str("\n\n");
|
||||
prompt_text.push_str(block);
|
||||
}
|
||||
if let Some(ref block) = gate_block
|
||||
&& let Some(SystemPrompt::Text(prompt_text)) = &mut stable_prompt
|
||||
{
|
||||
prompt_text.push_str("\n\n");
|
||||
prompt_text.push_str(block);
|
||||
}
|
||||
|
||||
let stable_hash = system_prompt_hash(stable_prompt.as_ref());
|
||||
|
||||
@@ -4268,10 +4268,10 @@ async fn run_mcp_command(config: &Config, command: McpCommand) -> Result<()> {
|
||||
if command.is_none() && url.is_none() {
|
||||
bail!("Provide either --command or --url for `mcp add`.");
|
||||
}
|
||||
if let Some(transport) = transport.as_deref() {
|
||||
if !transport.trim().eq_ignore_ascii_case("sse") {
|
||||
bail!("Unsupported MCP transport '{transport}'. Supported values: sse");
|
||||
}
|
||||
if let Some(transport) = transport.as_deref()
|
||||
&& !transport.trim().eq_ignore_ascii_case("sse")
|
||||
{
|
||||
bail!("Unsupported MCP transport '{transport}'. Supported values: sse");
|
||||
}
|
||||
let mut cfg = load_mcp_config(&config_path)?;
|
||||
cfg.servers.insert(
|
||||
|
||||
@@ -1832,10 +1832,10 @@ impl McpConnection {
|
||||
// IDs, but accept numeric echoes for compatibility with older
|
||||
// servers and tests.
|
||||
if response_id_matches(value.get("id"), &expected_id) {
|
||||
if let Some(error) = value.get("error") {
|
||||
if is_mcp_stale_session_body(&error.to_string()) {
|
||||
anyhow::bail!("MCP session expired: {error}");
|
||||
}
|
||||
if let Some(error) = value.get("error")
|
||||
&& is_mcp_stale_session_body(&error.to_string())
|
||||
{
|
||||
anyhow::bail!("MCP session expired: {error}");
|
||||
}
|
||||
return Ok(value);
|
||||
}
|
||||
|
||||
@@ -263,7 +263,7 @@ impl SlopLedger {
|
||||
pub fn default_path() -> io::Result<PathBuf> {
|
||||
codewhale_config::resolve_state_dir("slop_ledger")
|
||||
.map(|p| p.join("slop_ledger.json"))
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
.map_err(io::Error::other)
|
||||
}
|
||||
|
||||
/// Load ledger from the default path, returning an empty ledger if the
|
||||
@@ -297,9 +297,8 @@ impl SlopLedger {
|
||||
if let Some(parent) = self.ledger_path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
let data = serde_json::to_string_pretty(self).map_err(|e| {
|
||||
io::Error::new(io::ErrorKind::Other, format!("serialization error: {e}"))
|
||||
})?;
|
||||
let data = serde_json::to_string_pretty(self)
|
||||
.map_err(|e| io::Error::other(format!("serialization error: {e}")))?;
|
||||
crate::utils::write_atomic(&self.ledger_path, data.as_bytes())
|
||||
}
|
||||
|
||||
@@ -330,20 +329,20 @@ impl SlopLedger {
|
||||
.entries
|
||||
.iter()
|
||||
.filter(|e| {
|
||||
if let Some(bucket) = &filter.bucket {
|
||||
if e.bucket != *bucket {
|
||||
return false;
|
||||
}
|
||||
if let Some(bucket) = &filter.bucket
|
||||
&& e.bucket != *bucket
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if let Some(severity) = &filter.severity {
|
||||
if e.severity != *severity {
|
||||
return false;
|
||||
}
|
||||
if let Some(severity) = &filter.severity
|
||||
&& e.severity != *severity
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if let Some(status) = &filter.status {
|
||||
if e.status != *status {
|
||||
return false;
|
||||
}
|
||||
if let Some(status) = &filter.status
|
||||
&& e.status != *status
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if let Some(search) = &filter.search {
|
||||
let q = search.to_lowercase();
|
||||
@@ -408,7 +407,7 @@ impl SlopLedger {
|
||||
let mut out = format!("# {heading}\n\n");
|
||||
out.push_str(&format!(
|
||||
"_Generated at {} — {} entries_\n\n",
|
||||
chrono::Utc::now().format("%Y-%m-%d %H:%M UTC").to_string(),
|
||||
chrono::Utc::now().format("%Y-%m-%d %H:%M UTC"),
|
||||
entries.len()
|
||||
));
|
||||
|
||||
|
||||
@@ -1003,21 +1003,16 @@ impl Default for ViewportState {
|
||||
}
|
||||
|
||||
/// Verdict for a hunt (#2092).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum HuntVerdict {
|
||||
#[default]
|
||||
Hunting,
|
||||
Hunted,
|
||||
Wounded,
|
||||
Escaped,
|
||||
}
|
||||
|
||||
impl Default for HuntVerdict {
|
||||
fn default() -> Self {
|
||||
Self::Hunting
|
||||
}
|
||||
}
|
||||
|
||||
/// Hunt tracking state (#2092 — was GoalState).
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct HuntState {
|
||||
@@ -1163,6 +1158,9 @@ pub struct App {
|
||||
/// Updated by `/provider` switches so the UI/commands can read the
|
||||
/// active backend without re-deriving it from the live config.
|
||||
pub api_provider: ApiProvider,
|
||||
/// True when the active provider/base URL accepts arbitrary model IDs
|
||||
/// verbatim rather than DeepSeek-only aliases.
|
||||
pub model_ids_passthrough: bool,
|
||||
/// Current reasoning-effort tier for DeepSeek thinking mode.
|
||||
/// Cycled via Shift+Tab; initialized from config at startup.
|
||||
pub reasoning_effort: ReasoningEffort,
|
||||
@@ -1716,6 +1714,7 @@ impl App {
|
||||
}
|
||||
let mut effective_auth_config = config.clone();
|
||||
effective_auth_config.provider = Some(provider.as_str().to_string());
|
||||
let model_ids_passthrough = effective_auth_config.model_ids_pass_through();
|
||||
|
||||
// Check if the effective provider has an API key. This must happen
|
||||
// after settings.default_provider is applied; otherwise a saved
|
||||
@@ -1906,6 +1905,7 @@ impl App {
|
||||
auto_model,
|
||||
last_effective_model: None,
|
||||
api_provider: provider,
|
||||
model_ids_passthrough,
|
||||
reasoning_effort,
|
||||
last_effective_reasoning_effort: None,
|
||||
workspace,
|
||||
@@ -4680,6 +4680,11 @@ impl App {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn accepts_custom_model_ids(&self) -> bool {
|
||||
self.model_ids_passthrough
|
||||
|| crate::config::provider_passes_model_through(self.api_provider)
|
||||
}
|
||||
|
||||
pub fn effective_model_for_budget(&self) -> &str {
|
||||
if self.auto_model {
|
||||
return self
|
||||
|
||||
@@ -73,12 +73,13 @@ pub struct ModelPickerView {
|
||||
impl ModelPickerView {
|
||||
#[must_use]
|
||||
pub fn new(app: &App) -> Self {
|
||||
let hide_deepseek_models = crate::config::provider_passes_model_through(app.api_provider);
|
||||
let hide_deepseek_models = app.accepts_custom_model_ids();
|
||||
// Whale routes are DeepSeek-specific — only official providers get them.
|
||||
let show_whale_routes = matches!(
|
||||
app.api_provider,
|
||||
crate::config::ApiProvider::Deepseek | crate::config::ApiProvider::DeepseekCN
|
||||
);
|
||||
let show_whale_routes = !hide_deepseek_models
|
||||
&& matches!(
|
||||
app.api_provider,
|
||||
crate::config::ApiProvider::Deepseek | crate::config::ApiProvider::DeepseekCN
|
||||
);
|
||||
let initial_model = if app.auto_model {
|
||||
"auto".to_string()
|
||||
} else {
|
||||
@@ -594,6 +595,7 @@ mod tests {
|
||||
app.auto_model = false;
|
||||
app.reasoning_effort = ReasoningEffort::Max;
|
||||
app.api_provider = crate::config::ApiProvider::Deepseek;
|
||||
app.model_ids_passthrough = false;
|
||||
(app, lock)
|
||||
}
|
||||
|
||||
@@ -672,6 +674,21 @@ mod tests {
|
||||
assert_eq!(view.resolved_model(), "deepseek-v4-pro-2026-04-XX");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn picker_uses_pass_through_layout_for_custom_base_url_model_ids() {
|
||||
let (mut app, _lock) = create_test_app();
|
||||
app.model_ids_passthrough = true;
|
||||
app.model = "opencode-go/glm-5.1".to_string();
|
||||
app.auto_model = false;
|
||||
|
||||
let view = ModelPickerView::new(&app);
|
||||
|
||||
assert!(view.hide_deepseek_models);
|
||||
assert!(!view.show_whale_routes);
|
||||
assert!(view.show_custom_model_row);
|
||||
assert_eq!(view.resolved_model(), "opencode-go/glm-5.1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn arrow_keys_move_within_whale_routes() {
|
||||
let (app, _lock) = create_test_app();
|
||||
|
||||
+10
-11
@@ -1687,16 +1687,14 @@ async fn run_event_loop(
|
||||
// required — so the agent can't forget to check.
|
||||
if let Ok(ledger) = crate::slop_ledger::SlopLedger::load()
|
||||
&& ledger.has_open_entries()
|
||||
&& let Some(gate_msg) = ledger.completion_gate_summary()
|
||||
{
|
||||
if let Some(gate_msg) = ledger.completion_gate_summary() {
|
||||
let short =
|
||||
gate_msg.lines().nth(4).unwrap_or("review before done");
|
||||
app.push_status_toast(
|
||||
format!("⚠️ SlopLedger: {short}"),
|
||||
crate::tui::app::StatusToastLevel::Warning,
|
||||
Some(12_000),
|
||||
);
|
||||
}
|
||||
let short = gate_msg.lines().nth(4).unwrap_or("review before done");
|
||||
app.push_status_toast(
|
||||
format!("⚠️ SlopLedger: {short}"),
|
||||
crate::tui::app::StatusToastLevel::Warning,
|
||||
Some(12_000),
|
||||
);
|
||||
}
|
||||
|
||||
let tool_count = app.tool_evidence.len();
|
||||
@@ -1739,7 +1737,7 @@ async fn run_event_loop(
|
||||
// adding latency to any request path.
|
||||
let balance_cooldown_expired = app
|
||||
.last_balance_fetch
|
||||
.map_or(true, |t| t.elapsed() >= BALANCE_FETCH_COOLDOWN);
|
||||
.is_none_or(|t| t.elapsed() >= BALANCE_FETCH_COOLDOWN);
|
||||
if balance_cooldown_expired && should_fetch_deepseek_balance(app) {
|
||||
let cell = app.balance_cell.clone();
|
||||
let api_key = config.deepseek_api_key().unwrap_or_default();
|
||||
@@ -4889,6 +4887,7 @@ async fn switch_provider(
|
||||
let new_model = config.default_model();
|
||||
let cache_scope_changed = previous_provider != target || previous_model != new_model;
|
||||
app.api_provider = target;
|
||||
app.model_ids_passthrough = config.model_ids_pass_through();
|
||||
app.set_model_selection(new_model.clone());
|
||||
app.update_model_compaction_budget();
|
||||
if cache_scope_changed {
|
||||
@@ -5112,7 +5111,7 @@ async fn apply_command_result(
|
||||
// Refresh balance after provider switch.
|
||||
let balance_cooldown_expired = app
|
||||
.last_balance_fetch
|
||||
.map_or(true, |t| t.elapsed() >= BALANCE_FETCH_COOLDOWN);
|
||||
.is_none_or(|t| t.elapsed() >= BALANCE_FETCH_COOLDOWN);
|
||||
if balance_cooldown_expired && should_fetch_deepseek_balance(app) {
|
||||
let cell = app.balance_cell.clone();
|
||||
let api_key = config.deepseek_api_key().unwrap_or_default();
|
||||
|
||||
@@ -1360,6 +1360,7 @@ fn create_test_options() -> TuiOptions {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[allow(clippy::await_holding_lock)]
|
||||
async fn tool_result_api_content_receipts_large_live_output() {
|
||||
let _guard = crate::tools::truncate::TEST_SPILLOVER_GUARD
|
||||
.lock()
|
||||
|
||||
@@ -2363,11 +2363,11 @@ fn push_command_entry(
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(hint) = argument_hint {
|
||||
if !hint.trim().is_empty() {
|
||||
description.push_str(" ");
|
||||
description.push_str(hint.trim());
|
||||
}
|
||||
if let Some(hint) = argument_hint
|
||||
&& !hint.trim().is_empty()
|
||||
{
|
||||
description.push_str(" ");
|
||||
description.push_str(hint.trim());
|
||||
}
|
||||
(description, None)
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user