Harvest PR #2773: activate provider fallback chain
Harvested from PR #2773 by @idling11 Co-authored-by: Hanmiao Li <894876246@qq.com>
This commit is contained in:
@@ -566,9 +566,9 @@ pub struct ConfigToml {
|
||||
pub tools: Option<ToolsToml>,
|
||||
#[serde(default)]
|
||||
pub providers: ProvidersToml,
|
||||
/// Provider fallback chain (#2574). TUI runtime code may advance through
/// these providers after recoverable provider errors; config resolution
/// itself still reports the selected primary provider and does not
/// auto-switch routes.
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub fallback_providers: Vec<ProviderKind>,
|
||||
/// Per-domain network policy (#135). When absent, network tools fall back
|
||||
@@ -901,7 +901,10 @@ impl ProviderChain {
|
||||
|
||||
#[must_use]
|
||||
pub fn current(&self) -> ProviderKind {
|
||||
self.providers[self.position]
|
||||
self.providers
|
||||
.get(self.position)
|
||||
.copied()
|
||||
.unwrap_or(self.providers[0])
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
@@ -917,6 +920,10 @@ impl ProviderChain {
|
||||
Some(self.current())
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
self.position = 0;
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_fallback_active(&self) -> bool {
|
||||
self.position > 0
|
||||
|
||||
@@ -28,6 +28,10 @@ pub fn provider(app: &mut App, args: Option<&str>) -> CommandResult {
|
||||
let name = parts.next().unwrap_or("");
|
||||
let model_arg = parts.next();
|
||||
|
||||
if name.eq_ignore_ascii_case("fallback") {
|
||||
return provider_fallback(app, model_arg);
|
||||
}
|
||||
|
||||
let Some(target) = ApiProvider::parse(name) else {
|
||||
return CommandResult::error(format!(
|
||||
"Unknown provider '{name}'. Expected: {}.",
|
||||
@@ -70,6 +74,54 @@ pub fn provider(app: &mut App, args: Option<&str>) -> CommandResult {
|
||||
})
|
||||
}
|
||||
|
||||
fn provider_fallback(app: &mut App, subcommand: Option<&str>) -> CommandResult {
|
||||
match subcommand {
|
||||
Some("reset") => {
|
||||
let Some((_, primary, _)) = app.fallback_chain_entries().first().copied() else {
|
||||
return CommandResult::message(
|
||||
"No fallback providers configured. Add `fallback_providers` to your config.",
|
||||
);
|
||||
};
|
||||
CommandResult::with_message_and_action(
|
||||
format!("Fallback chain reset to primary provider: {}.", primary.as_str()),
|
||||
AppAction::SwitchProvider {
|
||||
provider: primary,
|
||||
model: None,
|
||||
},
|
||||
)
|
||||
}
|
||||
Some(other) => CommandResult::error(format!(
|
||||
"Unknown fallback command '{other}'. Usage: /provider fallback [reset]"
|
||||
)),
|
||||
None => {
|
||||
let entries = app.fallback_chain_entries();
|
||||
if entries.is_empty() {
|
||||
return CommandResult::message(
|
||||
"No fallback providers configured. Add `fallback_providers` to your config.",
|
||||
);
|
||||
}
|
||||
|
||||
let mut lines = vec![
|
||||
format!("Current provider: {}", app.api_provider.as_str()),
|
||||
"Fallback chain:".to_string(),
|
||||
];
|
||||
for (index, provider, is_current) in entries {
|
||||
let role = if index == 0 { "primary" } else { "fallback" };
|
||||
let marker = if is_current { " <- current" } else { "" };
|
||||
lines.push(format!(
|
||||
" [{index}] {} ({role}){marker}",
|
||||
provider.as_str()
|
||||
));
|
||||
}
|
||||
if let Some(reason) = app.last_fallback_reason.as_deref() {
|
||||
lines.push(format!("Last fallback: {reason}"));
|
||||
}
|
||||
lines.push("Use `/provider fallback reset` to return to the primary provider.".into());
|
||||
CommandResult::message(lines.join("\n"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn expand_model_alias_for_provider(provider: ApiProvider, name: &str) -> String {
|
||||
let lower = name.trim().to_ascii_lowercase();
|
||||
if matches!(provider, ApiProvider::XiaomiMimo) {
|
||||
@@ -412,6 +464,31 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
/// `/provider fallback` lists the configured chain, and
/// `/provider fallback reset` asks for a switch back to the primary provider.
#[test]
fn provider_fallback_status_and_reset_use_configured_chain() {
    let mut app = create_test_app();
    let chain = codewhale_config::ProviderChain::new(
        codewhale_config::ProviderKind::Deepseek,
        &[codewhale_config::ProviderKind::Openrouter],
    );
    app.provider_chain = Some(chain);

    // Listing: the primary is marked as current, the fallback is listed after it.
    let listing = provider(&mut app, Some("fallback"))
        .message
        .expect("fallback status");
    for fragment in [
        "Current provider: deepseek",
        "[0] deepseek (primary) <- current",
        "[1] openrouter (fallback)",
    ] {
        assert!(listing.contains(fragment));
    }

    // Reset: names the primary and carries a switch action without a model.
    let reset = provider(&mut app, Some("fallback reset"));
    let reset_text = reset.message.as_deref().unwrap_or("");
    assert!(reset_text.contains("deepseek"));
    assert!(matches!(
        reset.action,
        Some(AppAction::SwitchProvider {
            provider: ApiProvider::Deepseek,
            model: None
        })
    ));
}
|
||||
|
||||
#[test]
|
||||
fn invalid_model_returns_error() {
|
||||
let mut app = create_test_app();
|
||||
|
||||
@@ -1685,6 +1685,8 @@ pub struct Config {
|
||||
pub prompt_suggestion: Option<bool>,
|
||||
pub approval_policy: Option<String>,
|
||||
pub sandbox_mode: Option<String>,
|
||||
#[serde(default)]
|
||||
pub fallback_providers: Vec<codewhale_config::ProviderKind>,
|
||||
pub yolo: Option<bool>,
|
||||
pub verbosity: Option<String>,
|
||||
/// External sandbox backend: `"none"` or `"opensandbox"`.
|
||||
@@ -4784,6 +4786,11 @@ fn merge_config(base: Config, override_cfg: Config) -> Config {
|
||||
verbosity: override_cfg.verbosity.or(base.verbosity),
|
||||
approval_policy: override_cfg.approval_policy.or(base.approval_policy),
|
||||
sandbox_mode: override_cfg.sandbox_mode.or(base.sandbox_mode),
|
||||
fallback_providers: if override_cfg.fallback_providers.is_empty() {
|
||||
base.fallback_providers
|
||||
} else {
|
||||
override_cfg.fallback_providers
|
||||
},
|
||||
sandbox_backend: override_cfg.sandbox_backend.or(base.sandbox_backend),
|
||||
sandbox_url: override_cfg.sandbox_url.or(base.sandbox_url),
|
||||
sandbox_api_key: override_cfg.sandbox_api_key.or(base.sandbox_api_key),
|
||||
|
||||
@@ -9,6 +9,8 @@ use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use thiserror::Error;
|
||||
|
||||
use codewhale_config::ProviderChain;
|
||||
|
||||
use crate::artifacts::ArtifactRecord;
|
||||
use crate::client::{CacheWarmupKey, PromptInspection};
|
||||
use crate::compaction::CompactionConfig;
|
||||
@@ -1393,6 +1395,10 @@ pub struct App {
|
||||
/// Updated by `/provider` switches so the UI/commands can read the
|
||||
/// active backend without re-deriving it from the live config.
|
||||
pub api_provider: ApiProvider,
|
||||
/// Primary provider plus configured fallback providers for this session.
|
||||
pub provider_chain: Option<ProviderChain>,
|
||||
/// Human-readable description of the last provider fallback event.
|
||||
pub last_fallback_reason: Option<String>,
|
||||
/// True when the active provider/base URL accepts arbitrary model IDs
|
||||
/// verbatim rather than DeepSeek-only aliases.
|
||||
pub model_ids_passthrough: bool,
|
||||
@@ -2018,6 +2024,10 @@ impl App {
|
||||
let mut effective_auth_config = config.clone();
|
||||
effective_auth_config.provider = Some(provider.as_str().to_string());
|
||||
let model_ids_passthrough = effective_auth_config.model_ids_pass_through();
|
||||
let provider_chain = provider
|
||||
.kind()
|
||||
.map(|kind| ProviderChain::new(kind, &config.fallback_providers))
|
||||
.filter(|chain| chain.providers().len() > 1);
|
||||
|
||||
// Check if the effective provider has an API key. This must happen
|
||||
// after settings.default_provider is applied; otherwise a saved
|
||||
@@ -2231,6 +2241,8 @@ impl App {
|
||||
auto_model,
|
||||
last_effective_model: None,
|
||||
api_provider: provider,
|
||||
provider_chain,
|
||||
last_fallback_reason: None,
|
||||
model_ids_passthrough,
|
||||
pending_provider_switch: None,
|
||||
reasoning_effort,
|
||||
@@ -5168,6 +5180,56 @@ impl App {
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fallback_chain_entries(&self) -> Vec<(usize, ApiProvider, bool)> {
|
||||
let Some(chain) = &self.provider_chain else {
|
||||
return Vec::new();
|
||||
};
|
||||
let position = chain.position();
|
||||
chain
|
||||
.providers()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(index, provider)| (index, ApiProvider::from_kind(*provider), index == position))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn fallback_chain_position(&self) -> Option<usize> {
|
||||
self.provider_chain.as_ref().map(ProviderChain::position)
|
||||
}
|
||||
|
||||
pub fn fallback_chain_len(&self) -> usize {
|
||||
self.provider_chain
|
||||
.as_ref()
|
||||
.map_or(0, |chain| chain.providers().len())
|
||||
}
|
||||
|
||||
pub fn advance_fallback(&mut self, reason: impl Into<String>) -> Option<ApiProvider> {
|
||||
let reason = reason.into();
|
||||
let Some(chain) = self.provider_chain.as_mut() else {
|
||||
return None;
|
||||
};
|
||||
let Some(next_kind) = chain.advance() else {
|
||||
self.last_fallback_reason = Some(format!(
|
||||
"Fallback chain exhausted after {} provider(s): {reason}",
|
||||
chain.providers().len()
|
||||
));
|
||||
return None;
|
||||
};
|
||||
let next_provider = ApiProvider::from_kind(next_kind);
|
||||
self.api_provider = next_provider;
|
||||
self.last_fallback_reason = Some(format!(
|
||||
"Fell back to {} after recoverable provider error: {reason}",
|
||||
next_provider.as_str()
|
||||
));
|
||||
Some(next_provider)
|
||||
}
|
||||
|
||||
pub fn is_fallback_active(&self) -> bool {
|
||||
self.provider_chain
|
||||
.as_ref()
|
||||
.is_some_and(ProviderChain::is_fallback_active)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn media_attachment_reference(kind: &str, path: &Path, description: Option<&str>) -> String {
|
||||
|
||||
@@ -930,6 +930,9 @@ pub(crate) fn footer_status_line_spans(app: &App, max_width: usize) -> Vec<Span<
|
||||
}
|
||||
|
||||
pub(crate) fn footer_state_label(app: &App) -> (&'static str, ratatui::style::Color) {
|
||||
if app.is_fallback_active() {
|
||||
return ("fallback ->", app.ui_theme.status_warning);
|
||||
}
|
||||
if app.is_compacting {
|
||||
return ("compacting \u{238B}", app.ui_theme.status_warning);
|
||||
}
|
||||
|
||||
@@ -1425,6 +1425,7 @@ async fn run_event_loop(
|
||||
let mut transcript_batch_updated = false;
|
||||
let mut queued_to_send: Option<QueuedMessage> = None;
|
||||
let mut respawn_after_provider_rollback: Option<String> = None;
|
||||
let mut fallback_after_engine_error: Option<ApiProvider> = None;
|
||||
{
|
||||
let mut rx = engine_handle.rx_event.write().await;
|
||||
loop {
|
||||
@@ -2154,12 +2155,16 @@ async fn run_event_loop(
|
||||
envelope,
|
||||
recoverable: _,
|
||||
} => {
|
||||
let provider_before_error = app.api_provider;
|
||||
let rollback_after_auth_failure =
|
||||
matches!(
|
||||
envelope.category,
|
||||
crate::error_taxonomy::ErrorCategory::Authentication
|
||||
) && app.pending_provider_switch.is_some();
|
||||
apply_engine_error_to_app(app, envelope);
|
||||
if app.api_provider != provider_before_error && app.is_fallback_active() {
|
||||
fallback_after_engine_error = Some(provider_before_error);
|
||||
}
|
||||
if rollback_after_auth_failure
|
||||
&& let Some(rollback_warning) =
|
||||
rollback_provider_after_auth_failure(app, config)
|
||||
@@ -2620,6 +2625,10 @@ async fn run_event_loop(
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(previous_provider) = fallback_after_engine_error {
|
||||
apply_provider_fallback_switch(app, &mut engine_handle, config, previous_provider)
|
||||
.await;
|
||||
}
|
||||
if let Some(rollback_warning) = respawn_after_provider_rollback {
|
||||
let _ = engine_handle.send(Op::Shutdown).await;
|
||||
let engine_config = build_engine_config(app, config);
|
||||
@@ -4969,6 +4978,24 @@ pub(crate) fn apply_engine_error_to_app(
|
||||
);
|
||||
return;
|
||||
}
|
||||
if recoverable
|
||||
&& matches!(
|
||||
envelope.category,
|
||||
crate::error_taxonomy::ErrorCategory::Network
|
||||
| crate::error_taxonomy::ErrorCategory::RateLimit
|
||||
| crate::error_taxonomy::ErrorCategory::Timeout
|
||||
)
|
||||
&& app.advance_fallback(message.clone()).is_some()
|
||||
{
|
||||
let position = app.fallback_chain_position().unwrap_or(0);
|
||||
let total = app.fallback_chain_len();
|
||||
app.status_message = Some(format!(
|
||||
"Switched to {} (fallback {position}/{}) after recoverable provider error.",
|
||||
app.api_provider.as_str(),
|
||||
total.saturating_sub(1)
|
||||
));
|
||||
return;
|
||||
}
|
||||
if !recoverable {
|
||||
app.offline_mode = true;
|
||||
}
|
||||
@@ -6076,6 +6103,11 @@ async fn switch_provider(
|
||||
let new_endpoint = display_base_url_host(&new_base_url);
|
||||
let cache_scope_changed = previous_provider != target || previous_model != new_model;
|
||||
app.api_provider = target;
|
||||
app.provider_chain = target
|
||||
.kind()
|
||||
.map(|kind| codewhale_config::ProviderChain::new(kind, &config.fallback_providers))
|
||||
.filter(|chain| chain.providers().len() > 1);
|
||||
app.last_fallback_reason = None;
|
||||
app.model_ids_passthrough = config.model_ids_pass_through();
|
||||
app.reasoning_effort = app.reasoning_effort.normalize_for_provider(target);
|
||||
app.set_model_selection(new_model.clone());
|
||||
@@ -6158,6 +6190,105 @@ async fn switch_provider(
|
||||
app.status_message = Some(status_message);
|
||||
}
|
||||
|
||||
async fn apply_provider_fallback_switch(
|
||||
app: &mut App,
|
||||
engine_handle: &mut EngineHandle,
|
||||
config: &mut Config,
|
||||
previous_provider: ApiProvider,
|
||||
) {
|
||||
let target = app.api_provider;
|
||||
let previous_config = config.clone();
|
||||
let previous_model = app.model.clone();
|
||||
|
||||
config.provider = Some(target.as_str().to_string());
|
||||
if matches!(target, ApiProvider::NvidiaNim)
|
||||
&& config
|
||||
.base_url
|
||||
.as_deref()
|
||||
.map(|base| !base.contains("integrate.api.nvidia.com"))
|
||||
.unwrap_or(true)
|
||||
{
|
||||
config.base_url = Some(DEFAULT_NVIDIA_NIM_BASE_URL.to_string());
|
||||
}
|
||||
if matches!(target, ApiProvider::Deepseek | ApiProvider::DeepseekCN)
|
||||
&& config
|
||||
.base_url
|
||||
.as_deref()
|
||||
.map(root_base_url_belongs_to_non_deepseek_provider)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
config.base_url = None;
|
||||
}
|
||||
|
||||
if let Err(err) = DeepSeekClient::new(config) {
|
||||
*config = previous_config;
|
||||
app.api_provider = previous_provider;
|
||||
app.last_fallback_reason = Some(format!(
|
||||
"Fallback provider {} was unavailable: {err}",
|
||||
target.as_str()
|
||||
));
|
||||
app.status_message = Some(format!(
|
||||
"Fallback provider {} unavailable; provider remains {}.",
|
||||
target.as_str(),
|
||||
previous_provider.as_str()
|
||||
));
|
||||
return;
|
||||
}
|
||||
|
||||
let new_model = config.default_model();
|
||||
let new_base_url = config.deepseek_base_url();
|
||||
let new_endpoint = display_base_url_host(&new_base_url);
|
||||
let cache_scope_changed = previous_provider != target || previous_model != new_model;
|
||||
app.model_ids_passthrough = config.model_ids_pass_through();
|
||||
app.reasoning_effort = app.reasoning_effort.normalize_for_provider(target);
|
||||
app.set_model_selection(new_model.clone());
|
||||
app.update_model_compaction_budget();
|
||||
if cache_scope_changed {
|
||||
app.clear_model_scoped_telemetry();
|
||||
} else {
|
||||
app.session.last_prompt_tokens = None;
|
||||
app.session.last_completion_tokens = None;
|
||||
}
|
||||
|
||||
let _ = engine_handle.send(Op::Shutdown).await;
|
||||
let engine_config = build_engine_config(app, config);
|
||||
*engine_handle = spawn_engine(engine_config, config);
|
||||
|
||||
if !app.api_messages.is_empty() {
|
||||
let _ = engine_handle
|
||||
.send(Op::SyncSession {
|
||||
session_id: app.current_session_id.clone(),
|
||||
messages: app.api_messages.clone(),
|
||||
system_prompt: app.system_prompt.clone(),
|
||||
system_prompt_override: false,
|
||||
model: app.model.clone(),
|
||||
workspace: app.workspace.clone(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
let _ = engine_handle
|
||||
.send(Op::SetCompaction {
|
||||
config: app.compaction_config(),
|
||||
})
|
||||
.await;
|
||||
|
||||
app.add_message(HistoryCell::System {
|
||||
content: format!(
|
||||
"Provider fallback: {} -> {}\nModel: {} -> {}\nEndpoint: {}",
|
||||
previous_provider.as_str(),
|
||||
target.as_str(),
|
||||
previous_model,
|
||||
new_model,
|
||||
new_endpoint
|
||||
),
|
||||
});
|
||||
app.status_message = Some(format!(
|
||||
"Fallback provider: {} via {}",
|
||||
target.as_str(),
|
||||
new_endpoint
|
||||
));
|
||||
}
|
||||
|
||||
fn root_base_url_belongs_to_non_deepseek_provider(base_url: &str) -> bool {
|
||||
let lower = base_url.to_ascii_lowercase();
|
||||
[
|
||||
|
||||
@@ -8100,6 +8100,45 @@ fn recoverable_engine_error_does_not_enter_offline_mode() {
|
||||
let _ = ErrorEnvelope::transient("");
|
||||
}
|
||||
|
||||
/// A recoverable rate-limit error moves the app to the next provider in the
/// chain instead of entering offline mode.
#[test]
fn recoverable_provider_error_advances_fallback_chain() {
    use crate::error_taxonomy::{ErrorCategory, ErrorEnvelope, ErrorSeverity};

    let mut app = create_test_app();
    app.api_provider = ApiProvider::Deepseek;
    let chain = codewhale_config::ProviderChain::new(
        codewhale_config::ProviderKind::Deepseek,
        &[codewhale_config::ProviderKind::Openrouter],
    );
    app.provider_chain = Some(chain);

    let rate_limited = ErrorEnvelope::new(
        ErrorCategory::RateLimit,
        ErrorSeverity::Warning,
        true,
        "rate_limit",
        "provider returned 429",
    );
    apply_engine_error_to_app(&mut app, rate_limited);

    // The app now points at the fallback and stays online.
    assert_eq!(app.api_provider, ApiProvider::Openrouter);
    assert!(app.is_fallback_active());
    assert!(!app.offline_mode);

    // Both user-facing messages mention the switch and its cause.
    let status = app.status_message.as_deref().unwrap_or_default();
    assert!(status.contains("Switched to openrouter"));
    let reason = app.last_fallback_reason.as_deref().unwrap_or_default();
    assert!(reason.contains("provider returned 429"));
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn provider_switch_auth_error_restores_previous_provider_and_model() {
|
||||
use crate::error_taxonomy::ErrorEnvelope;
|
||||
|
||||
Reference in New Issue
Block a user