Harvest PR #2773: activate provider fallback chain

Harvested from PR #2773 by @idling11

Co-authored-by: Hanmiao Li <894876246@qq.com>
This commit is contained in:
CodeWhale Agent
2026-06-12 14:12:51 -07:00
parent 4200b64365
commit 662a459ee5
7 changed files with 330 additions and 4 deletions
+11 -4
View File
@@ -566,9 +566,9 @@ pub struct ConfigToml {
pub tools: Option<ToolsToml>,
#[serde(default)]
pub providers: ProvidersToml,
/// Dormant provider fallback chain (#2574). This is parsed and preserved
/// for future provider-routing work; current runtime resolution still uses
/// the selected primary provider and does not auto-switch routes.
/// Provider fallback chain (#2574). TUI runtime code may advance through
/// these providers after recoverable provider errors; config resolution
/// itself still reports the selected primary provider.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub fallback_providers: Vec<ProviderKind>,
/// Per-domain network policy (#135). When absent, network tools fall back
@@ -901,7 +901,10 @@ impl ProviderChain {
#[must_use]
pub fn current(&self) -> ProviderKind {
self.providers[self.position]
self.providers
.get(self.position)
.copied()
.unwrap_or(self.providers[0])
}
#[must_use]
@@ -917,6 +920,10 @@ impl ProviderChain {
Some(self.current())
}
pub fn reset(&mut self) {
self.position = 0;
}
#[must_use]
pub fn is_fallback_active(&self) -> bool {
self.position > 0
@@ -28,6 +28,10 @@ pub fn provider(app: &mut App, args: Option<&str>) -> CommandResult {
let name = parts.next().unwrap_or("");
let model_arg = parts.next();
if name.eq_ignore_ascii_case("fallback") {
return provider_fallback(app, model_arg);
}
let Some(target) = ApiProvider::parse(name) else {
return CommandResult::error(format!(
"Unknown provider '{name}'. Expected: {}.",
@@ -70,6 +74,54 @@ pub fn provider(app: &mut App, args: Option<&str>) -> CommandResult {
})
}
fn provider_fallback(app: &mut App, subcommand: Option<&str>) -> CommandResult {
match subcommand {
Some("reset") => {
let Some((_, primary, _)) = app.fallback_chain_entries().first().copied() else {
return CommandResult::message(
"No fallback providers configured. Add `fallback_providers` to your config.",
);
};
CommandResult::with_message_and_action(
format!("Fallback chain reset to primary provider: {}.", primary.as_str()),
AppAction::SwitchProvider {
provider: primary,
model: None,
},
)
}
Some(other) => CommandResult::error(format!(
"Unknown fallback command '{other}'. Usage: /provider fallback [reset]"
)),
None => {
let entries = app.fallback_chain_entries();
if entries.is_empty() {
return CommandResult::message(
"No fallback providers configured. Add `fallback_providers` to your config.",
);
}
let mut lines = vec![
format!("Current provider: {}", app.api_provider.as_str()),
"Fallback chain:".to_string(),
];
for (index, provider, is_current) in entries {
let role = if index == 0 { "primary" } else { "fallback" };
let marker = if is_current { " <- current" } else { "" };
lines.push(format!(
" [{index}] {} ({role}){marker}",
provider.as_str()
));
}
if let Some(reason) = app.last_fallback_reason.as_deref() {
lines.push(format!("Last fallback: {reason}"));
}
lines.push("Use `/provider fallback reset` to return to the primary provider.".into());
CommandResult::message(lines.join("\n"))
}
}
}
fn expand_model_alias_for_provider(provider: ApiProvider, name: &str) -> String {
let lower = name.trim().to_ascii_lowercase();
if matches!(provider, ApiProvider::XiaomiMimo) {
@@ -412,6 +464,31 @@ mod tests {
}
}
#[test]
fn provider_fallback_status_and_reset_use_configured_chain() {
let mut app = create_test_app();
app.provider_chain = Some(codewhale_config::ProviderChain::new(
codewhale_config::ProviderKind::Deepseek,
&[codewhale_config::ProviderKind::Openrouter],
));
let status = provider(&mut app, Some("fallback"));
let message = status.message.expect("fallback status");
assert!(message.contains("Current provider: deepseek"));
assert!(message.contains("[0] deepseek (primary) <- current"));
assert!(message.contains("[1] openrouter (fallback)"));
let reset = provider(&mut app, Some("fallback reset"));
assert!(reset.message.as_deref().unwrap_or("").contains("deepseek"));
assert!(matches!(
reset.action,
Some(AppAction::SwitchProvider {
provider: ApiProvider::Deepseek,
model: None
})
));
}
#[test]
fn invalid_model_returns_error() {
let mut app = create_test_app();
+7
View File
@@ -1685,6 +1685,8 @@ pub struct Config {
pub prompt_suggestion: Option<bool>,
pub approval_policy: Option<String>,
pub sandbox_mode: Option<String>,
#[serde(default)]
pub fallback_providers: Vec<codewhale_config::ProviderKind>,
pub yolo: Option<bool>,
pub verbosity: Option<String>,
/// External sandbox backend: `"none"` or `"opensandbox"`.
@@ -4784,6 +4786,11 @@ fn merge_config(base: Config, override_cfg: Config) -> Config {
verbosity: override_cfg.verbosity.or(base.verbosity),
approval_policy: override_cfg.approval_policy.or(base.approval_policy),
sandbox_mode: override_cfg.sandbox_mode.or(base.sandbox_mode),
fallback_providers: if override_cfg.fallback_providers.is_empty() {
base.fallback_providers
} else {
override_cfg.fallback_providers
},
sandbox_backend: override_cfg.sandbox_backend.or(base.sandbox_backend),
sandbox_url: override_cfg.sandbox_url.or(base.sandbox_url),
sandbox_api_key: override_cfg.sandbox_api_key.or(base.sandbox_api_key),
+62
View File
@@ -9,6 +9,8 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;
use codewhale_config::ProviderChain;
use crate::artifacts::ArtifactRecord;
use crate::client::{CacheWarmupKey, PromptInspection};
use crate::compaction::CompactionConfig;
@@ -1393,6 +1395,10 @@ pub struct App {
/// Updated by `/provider` switches so the UI/commands can read the
/// active backend without re-deriving it from the live config.
pub api_provider: ApiProvider,
/// Primary provider plus configured fallback providers for this session.
pub provider_chain: Option<ProviderChain>,
/// Human-readable description of the last provider fallback event.
pub last_fallback_reason: Option<String>,
/// True when the active provider/base URL accepts arbitrary model IDs
/// verbatim rather than DeepSeek-only aliases.
pub model_ids_passthrough: bool,
@@ -2018,6 +2024,10 @@ impl App {
let mut effective_auth_config = config.clone();
effective_auth_config.provider = Some(provider.as_str().to_string());
let model_ids_passthrough = effective_auth_config.model_ids_pass_through();
let provider_chain = provider
.kind()
.map(|kind| ProviderChain::new(kind, &config.fallback_providers))
.filter(|chain| chain.providers().len() > 1);
// Check if the effective provider has an API key. This must happen
// after settings.default_provider is applied; otherwise a saved
@@ -2231,6 +2241,8 @@ impl App {
auto_model,
last_effective_model: None,
api_provider: provider,
provider_chain,
last_fallback_reason: None,
model_ids_passthrough,
pending_provider_switch: None,
reasoning_effort,
@@ -5168,6 +5180,56 @@ impl App {
..Default::default()
}
}
pub fn fallback_chain_entries(&self) -> Vec<(usize, ApiProvider, bool)> {
let Some(chain) = &self.provider_chain else {
return Vec::new();
};
let position = chain.position();
chain
.providers()
.iter()
.enumerate()
.map(|(index, provider)| (index, ApiProvider::from_kind(*provider), index == position))
.collect()
}
pub fn fallback_chain_position(&self) -> Option<usize> {
self.provider_chain.as_ref().map(ProviderChain::position)
}
pub fn fallback_chain_len(&self) -> usize {
self.provider_chain
.as_ref()
.map_or(0, |chain| chain.providers().len())
}
pub fn advance_fallback(&mut self, reason: impl Into<String>) -> Option<ApiProvider> {
let reason = reason.into();
let Some(chain) = self.provider_chain.as_mut() else {
return None;
};
let Some(next_kind) = chain.advance() else {
self.last_fallback_reason = Some(format!(
"Fallback chain exhausted after {} provider(s): {reason}",
chain.providers().len()
));
return None;
};
let next_provider = ApiProvider::from_kind(next_kind);
self.api_provider = next_provider;
self.last_fallback_reason = Some(format!(
"Fell back to {} after recoverable provider error: {reason}",
next_provider.as_str()
));
Some(next_provider)
}
pub fn is_fallback_active(&self) -> bool {
self.provider_chain
.as_ref()
.is_some_and(ProviderChain::is_fallback_active)
}
}
pub fn media_attachment_reference(kind: &str, path: &Path, description: Option<&str>) -> String {
+3
View File
@@ -930,6 +930,9 @@ pub(crate) fn footer_status_line_spans(app: &App, max_width: usize) -> Vec<Span<
}
pub(crate) fn footer_state_label(app: &App) -> (&'static str, ratatui::style::Color) {
if app.is_fallback_active() {
return ("fallback ->", app.ui_theme.status_warning);
}
if app.is_compacting {
return ("compacting \u{238B}", app.ui_theme.status_warning);
}
+131
View File
@@ -1425,6 +1425,7 @@ async fn run_event_loop(
let mut transcript_batch_updated = false;
let mut queued_to_send: Option<QueuedMessage> = None;
let mut respawn_after_provider_rollback: Option<String> = None;
let mut fallback_after_engine_error: Option<ApiProvider> = None;
{
let mut rx = engine_handle.rx_event.write().await;
loop {
@@ -2154,12 +2155,16 @@ async fn run_event_loop(
envelope,
recoverable: _,
} => {
let provider_before_error = app.api_provider;
let rollback_after_auth_failure =
matches!(
envelope.category,
crate::error_taxonomy::ErrorCategory::Authentication
) && app.pending_provider_switch.is_some();
apply_engine_error_to_app(app, envelope);
if app.api_provider != provider_before_error && app.is_fallback_active() {
fallback_after_engine_error = Some(provider_before_error);
}
if rollback_after_auth_failure
&& let Some(rollback_warning) =
rollback_provider_after_auth_failure(app, config)
@@ -2620,6 +2625,10 @@ async fn run_event_loop(
}
}
}
if let Some(previous_provider) = fallback_after_engine_error {
apply_provider_fallback_switch(app, &mut engine_handle, config, previous_provider)
.await;
}
if let Some(rollback_warning) = respawn_after_provider_rollback {
let _ = engine_handle.send(Op::Shutdown).await;
let engine_config = build_engine_config(app, config);
@@ -4969,6 +4978,24 @@ pub(crate) fn apply_engine_error_to_app(
);
return;
}
if recoverable
&& matches!(
envelope.category,
crate::error_taxonomy::ErrorCategory::Network
| crate::error_taxonomy::ErrorCategory::RateLimit
| crate::error_taxonomy::ErrorCategory::Timeout
)
&& app.advance_fallback(message.clone()).is_some()
{
let position = app.fallback_chain_position().unwrap_or(0);
let total = app.fallback_chain_len();
app.status_message = Some(format!(
"Switched to {} (fallback {position}/{}) after recoverable provider error.",
app.api_provider.as_str(),
total.saturating_sub(1)
));
return;
}
if !recoverable {
app.offline_mode = true;
}
@@ -6076,6 +6103,11 @@ async fn switch_provider(
let new_endpoint = display_base_url_host(&new_base_url);
let cache_scope_changed = previous_provider != target || previous_model != new_model;
app.api_provider = target;
app.provider_chain = target
.kind()
.map(|kind| codewhale_config::ProviderChain::new(kind, &config.fallback_providers))
.filter(|chain| chain.providers().len() > 1);
app.last_fallback_reason = None;
app.model_ids_passthrough = config.model_ids_pass_through();
app.reasoning_effort = app.reasoning_effort.normalize_for_provider(target);
app.set_model_selection(new_model.clone());
@@ -6158,6 +6190,105 @@ async fn switch_provider(
app.status_message = Some(status_message);
}
async fn apply_provider_fallback_switch(
app: &mut App,
engine_handle: &mut EngineHandle,
config: &mut Config,
previous_provider: ApiProvider,
) {
let target = app.api_provider;
let previous_config = config.clone();
let previous_model = app.model.clone();
config.provider = Some(target.as_str().to_string());
if matches!(target, ApiProvider::NvidiaNim)
&& config
.base_url
.as_deref()
.map(|base| !base.contains("integrate.api.nvidia.com"))
.unwrap_or(true)
{
config.base_url = Some(DEFAULT_NVIDIA_NIM_BASE_URL.to_string());
}
if matches!(target, ApiProvider::Deepseek | ApiProvider::DeepseekCN)
&& config
.base_url
.as_deref()
.map(root_base_url_belongs_to_non_deepseek_provider)
.unwrap_or(false)
{
config.base_url = None;
}
if let Err(err) = DeepSeekClient::new(config) {
*config = previous_config;
app.api_provider = previous_provider;
app.last_fallback_reason = Some(format!(
"Fallback provider {} was unavailable: {err}",
target.as_str()
));
app.status_message = Some(format!(
"Fallback provider {} unavailable; provider remains {}.",
target.as_str(),
previous_provider.as_str()
));
return;
}
let new_model = config.default_model();
let new_base_url = config.deepseek_base_url();
let new_endpoint = display_base_url_host(&new_base_url);
let cache_scope_changed = previous_provider != target || previous_model != new_model;
app.model_ids_passthrough = config.model_ids_pass_through();
app.reasoning_effort = app.reasoning_effort.normalize_for_provider(target);
app.set_model_selection(new_model.clone());
app.update_model_compaction_budget();
if cache_scope_changed {
app.clear_model_scoped_telemetry();
} else {
app.session.last_prompt_tokens = None;
app.session.last_completion_tokens = None;
}
let _ = engine_handle.send(Op::Shutdown).await;
let engine_config = build_engine_config(app, config);
*engine_handle = spawn_engine(engine_config, config);
if !app.api_messages.is_empty() {
let _ = engine_handle
.send(Op::SyncSession {
session_id: app.current_session_id.clone(),
messages: app.api_messages.clone(),
system_prompt: app.system_prompt.clone(),
system_prompt_override: false,
model: app.model.clone(),
workspace: app.workspace.clone(),
})
.await;
}
let _ = engine_handle
.send(Op::SetCompaction {
config: app.compaction_config(),
})
.await;
app.add_message(HistoryCell::System {
content: format!(
"Provider fallback: {} -> {}\nModel: {} -> {}\nEndpoint: {}",
previous_provider.as_str(),
target.as_str(),
previous_model,
new_model,
new_endpoint
),
});
app.status_message = Some(format!(
"Fallback provider: {} via {}",
target.as_str(),
new_endpoint
));
}
fn root_base_url_belongs_to_non_deepseek_provider(base_url: &str) -> bool {
let lower = base_url.to_ascii_lowercase();
[
+39
View File
@@ -8100,6 +8100,45 @@ fn recoverable_engine_error_does_not_enter_offline_mode() {
let _ = ErrorEnvelope::transient("");
}
#[test]
fn recoverable_provider_error_advances_fallback_chain() {
use crate::error_taxonomy::{ErrorCategory, ErrorEnvelope, ErrorSeverity};
let mut app = create_test_app();
app.api_provider = ApiProvider::Deepseek;
app.provider_chain = Some(codewhale_config::ProviderChain::new(
codewhale_config::ProviderKind::Deepseek,
&[codewhale_config::ProviderKind::Openrouter],
));
apply_engine_error_to_app(
&mut app,
ErrorEnvelope::new(
ErrorCategory::RateLimit,
ErrorSeverity::Warning,
true,
"rate_limit",
"provider returned 429",
),
);
assert_eq!(app.api_provider, ApiProvider::Openrouter);
assert!(app.is_fallback_active());
assert!(!app.offline_mode);
assert!(
app.status_message
.as_deref()
.unwrap_or_default()
.contains("Switched to openrouter")
);
assert!(
app.last_fallback_reason
.as_deref()
.unwrap_or_default()
.contains("provider returned 429")
);
}
#[tokio::test]
async fn provider_switch_auth_error_restores_previous_provider_and_model() {
use crate::error_taxonomy::ErrorEnvelope;