From 662a459ee555592413eefc2ac7004a9a450a3595 Mon Sep 17 00:00:00 2001 From: CodeWhale Agent Date: Fri, 12 Jun 2026 14:12:51 -0700 Subject: [PATCH] Harvest PR #2773: activate provider fallback chain Harvested from PR #2773 by @idling11 Co-authored-by: Hanmiao Li <894876246@qq.com> --- crates/config/src/lib.rs | 15 +- .../tui/src/commands/groups/core/provider.rs | 77 ++++++++++ crates/tui/src/config.rs | 7 + crates/tui/src/tui/app.rs | 62 +++++++++ crates/tui/src/tui/footer_ui.rs | 3 + crates/tui/src/tui/ui.rs | 131 ++++++++++++++++++ crates/tui/src/tui/ui/tests.rs | 39 ++++++ 7 files changed, 330 insertions(+), 4 deletions(-) diff --git a/crates/config/src/lib.rs b/crates/config/src/lib.rs index bdcb9324..52ec563e 100644 --- a/crates/config/src/lib.rs +++ b/crates/config/src/lib.rs @@ -566,9 +566,9 @@ pub struct ConfigToml { pub tools: Option, #[serde(default)] pub providers: ProvidersToml, - /// Dormant provider fallback chain (#2574). This is parsed and preserved - /// for future provider-routing work; current runtime resolution still uses - /// the selected primary provider and does not auto-switch routes. + /// Provider fallback chain (#2574). TUI runtime code may advance through + /// these providers after recoverable provider errors; config resolution + /// itself still reports the selected primary provider. #[serde(default, skip_serializing_if = "Vec::is_empty")] pub fallback_providers: Vec, /// Per-domain network policy (#135). When absent, network tools fall back @@ -901,7 +901,10 @@ impl ProviderChain { #[must_use] pub fn current(&self) -> ProviderKind { - self.providers[self.position] + self.providers + .get(self.position) + .copied() + .unwrap_or(self.providers[0]) } #[must_use] @@ -917,6 +920,10 @@ impl ProviderChain { Some(self.current()) } + pub fn reset(&mut self) { + self.position = 0; + } + #[must_use] pub fn is_fallback_active(&self) -> bool { self.position > 0 diff --git a/crates/tui/src/commands/groups/core/provider.rs b/crates/tui/src/commands/groups/core/provider.rs index e5b002c9..47b32060 100644 --- a/crates/tui/src/commands/groups/core/provider.rs +++ b/crates/tui/src/commands/groups/core/provider.rs @@ -28,6 +28,10 @@ pub fn provider(app: &mut App, args: Option<&str>) -> CommandResult { let name = parts.next().unwrap_or(""); let model_arg = parts.next(); + if name.eq_ignore_ascii_case("fallback") { + return provider_fallback(app, model_arg); + } + let Some(target) = ApiProvider::parse(name) else { return CommandResult::error(format!( "Unknown provider '{name}'. Expected: {}.", @@ -70,6 +74,54 @@ pub fn provider(app: &mut App, args: Option<&str>) -> CommandResult { }) } +fn provider_fallback(app: &mut App, subcommand: Option<&str>) -> CommandResult { + match subcommand { + Some("reset") => { + let Some((_, primary, _)) = app.fallback_chain_entries().first().copied() else { + return CommandResult::message( + "No fallback providers configured. Add `fallback_providers` to your config.", + ); + }; + CommandResult::with_message_and_action( + format!("Fallback chain reset to primary provider: {}.", primary.as_str()), + AppAction::SwitchProvider { + provider: primary, + model: None, + }, + ) + } + Some(other) => CommandResult::error(format!( + "Unknown fallback command '{other}'. Usage: /provider fallback [reset]" + )), + None => { + let entries = app.fallback_chain_entries(); + if entries.is_empty() { + return CommandResult::message( + "No fallback providers configured. Add `fallback_providers` to your config.", + ); + } + + let mut lines = vec![ + format!("Current provider: {}", app.api_provider.as_str()), + "Fallback chain:".to_string(), + ]; + for (index, provider, is_current) in entries { + let role = if index == 0 { "primary" } else { "fallback" }; + let marker = if is_current { " <- current" } else { "" }; + lines.push(format!( + " [{index}] {} ({role}){marker}", + provider.as_str() + )); + } + if let Some(reason) = app.last_fallback_reason.as_deref() { + lines.push(format!("Last fallback: {reason}")); + } + lines.push("Use `/provider fallback reset` to return to the primary provider.".into()); + CommandResult::message(lines.join("\n")) + } + } +} + fn expand_model_alias_for_provider(provider: ApiProvider, name: &str) -> String { let lower = name.trim().to_ascii_lowercase(); if matches!(provider, ApiProvider::XiaomiMimo) { @@ -412,6 +464,31 @@ mod tests { } } + #[test] + fn provider_fallback_status_and_reset_use_configured_chain() { + let mut app = create_test_app(); + app.provider_chain = Some(codewhale_config::ProviderChain::new( + codewhale_config::ProviderKind::Deepseek, + &[codewhale_config::ProviderKind::Openrouter], + )); + + let status = provider(&mut app, Some("fallback")); + let message = status.message.expect("fallback status"); + assert!(message.contains("Current provider: deepseek")); + assert!(message.contains("[0] deepseek (primary) <- current")); + assert!(message.contains("[1] openrouter (fallback)")); + + let reset = provider(&mut app, Some("fallback reset")); + assert!(reset.message.as_deref().unwrap_or("").contains("deepseek")); + assert!(matches!( + reset.action, + Some(AppAction::SwitchProvider { + provider: ApiProvider::Deepseek, + model: None + }) + )); + } + #[test] fn invalid_model_returns_error() { let mut app = create_test_app(); diff --git a/crates/tui/src/config.rs b/crates/tui/src/config.rs index 2fe244e6..e59124ab 100644 --- a/crates/tui/src/config.rs +++ b/crates/tui/src/config.rs @@ -1685,6 +1685,8 @@ pub struct Config { pub prompt_suggestion: Option, pub approval_policy: Option, pub sandbox_mode: Option, + #[serde(default)] + pub fallback_providers: Vec, pub yolo: Option, pub verbosity: Option, /// External sandbox backend: `"none"` or `"opensandbox"`. @@ -4784,6 +4786,11 @@ fn merge_config(base: Config, override_cfg: Config) -> Config { verbosity: override_cfg.verbosity.or(base.verbosity), approval_policy: override_cfg.approval_policy.or(base.approval_policy), sandbox_mode: override_cfg.sandbox_mode.or(base.sandbox_mode), + fallback_providers: if override_cfg.fallback_providers.is_empty() { + base.fallback_providers + } else { + override_cfg.fallback_providers + }, sandbox_backend: override_cfg.sandbox_backend.or(base.sandbox_backend), sandbox_url: override_cfg.sandbox_url.or(base.sandbox_url), sandbox_api_key: override_cfg.sandbox_api_key.or(base.sandbox_api_key), diff --git a/crates/tui/src/tui/app.rs b/crates/tui/src/tui/app.rs index f75a2d3a..fb6a3ad4 100644 --- a/crates/tui/src/tui/app.rs +++ b/crates/tui/src/tui/app.rs @@ -9,6 +9,8 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use thiserror::Error; +use codewhale_config::ProviderChain; + use crate::artifacts::ArtifactRecord; use crate::client::{CacheWarmupKey, PromptInspection}; use crate::compaction::CompactionConfig; @@ -1393,6 +1395,10 @@ pub struct App { /// Updated by `/provider` switches so the UI/commands can read the /// active backend without re-deriving it from the live config. pub api_provider: ApiProvider, + /// Primary provider plus configured fallback providers for this session. + pub provider_chain: Option, + /// Human-readable description of the last provider fallback event. + pub last_fallback_reason: Option, /// True when the active provider/base URL accepts arbitrary model IDs /// verbatim rather than DeepSeek-only aliases. pub model_ids_passthrough: bool, @@ -2018,6 +2024,10 @@ impl App { let mut effective_auth_config = config.clone(); effective_auth_config.provider = Some(provider.as_str().to_string()); let model_ids_passthrough = effective_auth_config.model_ids_pass_through(); + let provider_chain = provider + .kind() + .map(|kind| ProviderChain::new(kind, &config.fallback_providers)) + .filter(|chain| chain.providers().len() > 1); // Check if the effective provider has an API key. This must happen // after settings.default_provider is applied; otherwise a saved @@ -2231,6 +2241,8 @@ impl App { auto_model, last_effective_model: None, api_provider: provider, + provider_chain, + last_fallback_reason: None, model_ids_passthrough, pending_provider_switch: None, reasoning_effort, @@ -5168,6 +5180,56 @@ impl App { ..Default::default() } } + + pub fn fallback_chain_entries(&self) -> Vec<(usize, ApiProvider, bool)> { + let Some(chain) = &self.provider_chain else { + return Vec::new(); + }; + let position = chain.position(); + chain + .providers() + .iter() + .enumerate() + .map(|(index, provider)| (index, ApiProvider::from_kind(*provider), index == position)) + .collect() + } + + pub fn fallback_chain_position(&self) -> Option { + self.provider_chain.as_ref().map(ProviderChain::position) + } + + pub fn fallback_chain_len(&self) -> usize { + self.provider_chain + .as_ref() + .map_or(0, |chain| chain.providers().len()) + } + + pub fn advance_fallback(&mut self, reason: impl Into) -> Option { + let reason = reason.into(); + let Some(chain) = self.provider_chain.as_mut() else { + return None; + }; + let Some(next_kind) = chain.advance() else { + self.last_fallback_reason = Some(format!( + "Fallback chain exhausted after {} provider(s): {reason}", + chain.providers().len() + )); + return None; + }; + let next_provider = ApiProvider::from_kind(next_kind); + self.api_provider = next_provider; + self.last_fallback_reason = Some(format!( + "Fell back to {} after recoverable provider error: {reason}", + next_provider.as_str() + )); + Some(next_provider) + } + + pub fn is_fallback_active(&self) -> bool { + self.provider_chain + .as_ref() + .is_some_and(ProviderChain::is_fallback_active) + } } pub fn media_attachment_reference(kind: &str, path: &Path, description: Option<&str>) -> String { diff --git a/crates/tui/src/tui/footer_ui.rs b/crates/tui/src/tui/footer_ui.rs index 79d5d85d..f3ef54f2 100644 --- a/crates/tui/src/tui/footer_ui.rs +++ b/crates/tui/src/tui/footer_ui.rs @@ -930,6 +930,9 @@ pub(crate) fn footer_status_line_spans(app: &App, max_width: usize) -> Vec (&'static str, ratatui::style::Color) { + if app.is_fallback_active() { + return ("fallback ->", app.ui_theme.status_warning); + } if app.is_compacting { return ("compacting \u{238B}", app.ui_theme.status_warning); } diff --git a/crates/tui/src/tui/ui.rs b/crates/tui/src/tui/ui.rs index 88c4a970..add27d97 100644 --- a/crates/tui/src/tui/ui.rs +++ b/crates/tui/src/tui/ui.rs @@ -1425,6 +1425,7 @@ async fn run_event_loop( let mut transcript_batch_updated = false; let mut queued_to_send: Option = None; let mut respawn_after_provider_rollback: Option = None; + let mut fallback_after_engine_error: Option = None; { let mut rx = engine_handle.rx_event.write().await; loop { @@ -2154,12 +2155,16 @@ async fn run_event_loop( envelope, recoverable: _, } => { + let provider_before_error = app.api_provider; let rollback_after_auth_failure = matches!( envelope.category, crate::error_taxonomy::ErrorCategory::Authentication ) && app.pending_provider_switch.is_some(); apply_engine_error_to_app(app, envelope); + if app.api_provider != provider_before_error && app.is_fallback_active() { + fallback_after_engine_error = Some(provider_before_error); + } if rollback_after_auth_failure && let Some(rollback_warning) = rollback_provider_after_auth_failure(app, config) @@ -2620,6 +2625,10 @@ async fn run_event_loop( } } } + if let Some(previous_provider) = fallback_after_engine_error { + apply_provider_fallback_switch(app, &mut engine_handle, config, previous_provider) + .await; + } if let Some(rollback_warning) = respawn_after_provider_rollback { let _ = engine_handle.send(Op::Shutdown).await; let engine_config = build_engine_config(app, config); @@ -4969,6 +4978,24 @@ pub(crate) fn apply_engine_error_to_app( ); return; } + if recoverable + && matches!( + envelope.category, + crate::error_taxonomy::ErrorCategory::Network + | crate::error_taxonomy::ErrorCategory::RateLimit + | crate::error_taxonomy::ErrorCategory::Timeout + ) + && app.advance_fallback(message.clone()).is_some() + { + let position = app.fallback_chain_position().unwrap_or(0); + let total = app.fallback_chain_len(); + app.status_message = Some(format!( + "Switched to {} (fallback {position}/{}) after recoverable provider error.", + app.api_provider.as_str(), + total.saturating_sub(1) + )); + return; + } if !recoverable { app.offline_mode = true; } @@ -6076,6 +6103,11 @@ async fn switch_provider( let new_endpoint = display_base_url_host(&new_base_url); let cache_scope_changed = previous_provider != target || previous_model != new_model; app.api_provider = target; + app.provider_chain = target + .kind() + .map(|kind| codewhale_config::ProviderChain::new(kind, &config.fallback_providers)) + .filter(|chain| chain.providers().len() > 1); + app.last_fallback_reason = None; app.model_ids_passthrough = config.model_ids_pass_through(); app.reasoning_effort = app.reasoning_effort.normalize_for_provider(target); app.set_model_selection(new_model.clone()); @@ -6158,6 +6190,105 @@ async fn switch_provider( app.status_message = Some(status_message); } +async fn apply_provider_fallback_switch( + app: &mut App, + engine_handle: &mut EngineHandle, + config: &mut Config, + previous_provider: ApiProvider, +) { + let target = app.api_provider; + let previous_config = config.clone(); + let previous_model = app.model.clone(); + + config.provider = Some(target.as_str().to_string()); + if matches!(target, ApiProvider::NvidiaNim) + && config + .base_url + .as_deref() + .map(|base| !base.contains("integrate.api.nvidia.com")) + .unwrap_or(true) + { + config.base_url = Some(DEFAULT_NVIDIA_NIM_BASE_URL.to_string()); + } + if matches!(target, ApiProvider::Deepseek | ApiProvider::DeepseekCN) + && config + .base_url + .as_deref() + .map(root_base_url_belongs_to_non_deepseek_provider) + .unwrap_or(false) + { + config.base_url = None; + } + + if let Err(err) = DeepSeekClient::new(config) { + *config = previous_config; + app.api_provider = previous_provider; + app.last_fallback_reason = Some(format!( + "Fallback provider {} was unavailable: {err}", + target.as_str() + )); + app.status_message = Some(format!( + "Fallback provider {} unavailable; provider remains {}.", + target.as_str(), + previous_provider.as_str() + )); + return; + } + + let new_model = config.default_model(); + let new_base_url = config.deepseek_base_url(); + let new_endpoint = display_base_url_host(&new_base_url); + let cache_scope_changed = previous_provider != target || previous_model != new_model; + app.model_ids_passthrough = config.model_ids_pass_through(); + app.reasoning_effort = app.reasoning_effort.normalize_for_provider(target); + app.set_model_selection(new_model.clone()); + app.update_model_compaction_budget(); + if cache_scope_changed { + app.clear_model_scoped_telemetry(); + } else { + app.session.last_prompt_tokens = None; + app.session.last_completion_tokens = None; + } + + let _ = engine_handle.send(Op::Shutdown).await; + let engine_config = build_engine_config(app, config); + *engine_handle = spawn_engine(engine_config, config); + + if !app.api_messages.is_empty() { + let _ = engine_handle + .send(Op::SyncSession { + session_id: app.current_session_id.clone(), + messages: app.api_messages.clone(), + system_prompt: app.system_prompt.clone(), + system_prompt_override: false, + model: app.model.clone(), + workspace: app.workspace.clone(), + }) + .await; + } + let _ = engine_handle + .send(Op::SetCompaction { + config: app.compaction_config(), + }) + .await; + + app.add_message(HistoryCell::System { + content: format!( + "Provider fallback: {} -> {}\nModel: {} -> {}\nEndpoint: {}", + previous_provider.as_str(), + target.as_str(), + previous_model, + new_model, + new_endpoint + ), + }); + app.status_message = Some(format!( + "Fallback provider: {} via {}", + target.as_str(), + new_endpoint + )); +} + fn root_base_url_belongs_to_non_deepseek_provider(base_url: &str) -> bool { let lower = base_url.to_ascii_lowercase(); [ diff --git a/crates/tui/src/tui/ui/tests.rs b/crates/tui/src/tui/ui/tests.rs index 9d7778e0..d9c7007b 100644 --- a/crates/tui/src/tui/ui/tests.rs +++ b/crates/tui/src/tui/ui/tests.rs @@ -8100,6 +8100,45 @@ fn recoverable_engine_error_does_not_enter_offline_mode() { let _ = ErrorEnvelope::transient(""); } +#[test] +fn recoverable_provider_error_advances_fallback_chain() { + use crate::error_taxonomy::{ErrorCategory, ErrorEnvelope, ErrorSeverity}; + + let mut app = create_test_app(); + app.api_provider = ApiProvider::Deepseek; + app.provider_chain = Some(codewhale_config::ProviderChain::new( + codewhale_config::ProviderKind::Deepseek, + &[codewhale_config::ProviderKind::Openrouter], + )); + + apply_engine_error_to_app( + &mut app, + ErrorEnvelope::new( + ErrorCategory::RateLimit, + ErrorSeverity::Warning, + true, + "rate_limit", + "provider returned 429", + ), + ); + + assert_eq!(app.api_provider, ApiProvider::Openrouter); + assert!(app.is_fallback_active()); + assert!(!app.offline_mode); + assert!( + app.status_message + .as_deref() + .unwrap_or_default() + .contains("Switched to openrouter") + ); + assert!( + app.last_fallback_reason + .as_deref() + .unwrap_or_default() + .contains("provider returned 429") + ); +} + #[tokio::test] async fn provider_switch_auth_error_restores_previous_provider_and_model() { use crate::error_taxonomy::ErrorEnvelope;