diff --git a/README.md b/README.md index 63294da6..b3ac357e 100644 --- a/README.md +++ b/README.md @@ -303,6 +303,7 @@ Key environment variables: |---|---| | `DEEPSEEK_API_KEY` | API key | | `DEEPSEEK_BASE_URL` | API base URL | +| `DEEPSEEK_HTTP_HEADERS` | Optional custom model request headers, e.g. `X-Model-Provider-Id=your-model-provider` | | `DEEPSEEK_MODEL` | Default model | | `DEEPSEEK_PROVIDER` | `deepseek` (default), `nvidia-nim`, `fireworks`, `sglang`, `vllm` | | `DEEPSEEK_PROFILE` | Config profile name | diff --git a/config.example.toml b/config.example.toml index 41d085bb..3a9aee2f 100644 --- a/config.example.toml +++ b/config.example.toml @@ -20,6 +20,9 @@ api_key = "YOUR_DEEPSEEK_API_KEY" # must be non-empty base_url = "https://api.deepseek.com" # base_url = "https://api.deepseeki.com" # China users # base_url = "https://api.deepseek.com/beta" # DeepSeek beta features such as strict tool mode +# Optional custom model request headers for OpenAI-compatible gateways. +# Authorization and Content-Type are managed by the client and cannot be overridden here. +# http_headers = { "X-Model-Provider-Id" = "your-model-provider" } # ───────────────────────────────────────────────────────────────────────────────── # Default Models @@ -161,6 +164,7 @@ max_subagents = 10 # optional (1-20) # api_key = "YOUR_DEEPSEEK_API_KEY" # base_url = "https://api.deepseek.com" # model = "deepseek-v4-pro" +# http_headers = { "X-Model-Provider-Id" = "your-model-provider" } # optional custom request headers # NVIDIA NIM-hosted DeepSeek V4 (https://build.nvidia.com) [providers.nvidia_nim] diff --git a/crates/cli/src/lib.rs b/crates/cli/src/lib.rs index 0a942435..c3c8faf6 100644 --- a/crates/cli/src/lib.rs +++ b/crates/cli/src/lib.rs @@ -1294,6 +1294,15 @@ fn build_tui_command( cmd.env("DEEPSEEK_MODEL", &resolved_runtime.model); cmd.env("DEEPSEEK_BASE_URL", &resolved_runtime.base_url); cmd.env("DEEPSEEK_PROVIDER", resolved_runtime.provider.as_str()); + if !resolved_runtime.http_headers.is_empty() { + let encoded = resolved_runtime + .http_headers + .iter() + .map(|(name, value)| format!("{}={}", name.trim(), value.trim())) + .collect::>() + .join(","); + cmd.env("DEEPSEEK_HTTP_HEADERS", encoded); + } if let Some(api_key) = resolved_runtime.api_key.as_ref() { cmd.env("DEEPSEEK_API_KEY", api_key); let source = resolved_runtime diff --git a/crates/config/src/lib.rs b/crates/config/src/lib.rs index cd0b5ab7..48077512 100644 --- a/crates/config/src/lib.rs +++ b/crates/config/src/lib.rs @@ -86,6 +86,8 @@ pub struct ProviderConfigToml { pub api_key: Option, pub base_url: Option, pub model: Option, + #[serde(default)] + pub http_headers: BTreeMap, } #[derive(Debug, Clone, Serialize, Deserialize, Default)] @@ -144,6 +146,9 @@ pub struct ConfigToml { pub api_key: Option, /// TUI-compatible DeepSeek base URL. pub base_url: Option, + /// Optional extra HTTP headers forwarded to model API requests. + #[serde(default)] + pub http_headers: BTreeMap, /// TUI-compatible default DeepSeek model. pub default_text_model: Option, #[serde(default)] @@ -294,6 +299,9 @@ impl ConfigToml { if project.base_url.is_some() { self.base_url = project.base_url; } + if !project.http_headers.is_empty() { + self.http_headers = project.http_headers; + } if project.default_text_model.is_some() { self.default_text_model = project.default_text_model; } @@ -359,6 +367,7 @@ impl ConfigToml { "provider" => Some(self.provider.as_str().to_string()), "api_key" => self.api_key.clone(), "base_url" => self.base_url.clone(), + "http_headers" => serialize_http_headers(&self.http_headers), "default_text_model" => self.default_text_model.clone(), "model" => self.model.clone(), "auth.mode" => self.auth_mode.clone(), @@ -372,27 +381,51 @@ impl ConfigToml { "providers.deepseek.api_key" => self.providers.deepseek.api_key.clone(), "providers.deepseek.base_url" => self.providers.deepseek.base_url.clone(), "providers.deepseek.model" => self.providers.deepseek.model.clone(), + "providers.deepseek.http_headers" => { + serialize_http_headers(&self.providers.deepseek.http_headers) + } "providers.nvidia_nim.api_key" => self.providers.nvidia_nim.api_key.clone(), "providers.nvidia_nim.base_url" => self.providers.nvidia_nim.base_url.clone(), "providers.nvidia_nim.model" => self.providers.nvidia_nim.model.clone(), + "providers.nvidia_nim.http_headers" => { + serialize_http_headers(&self.providers.nvidia_nim.http_headers) + } "providers.openai.api_key" => self.providers.openai.api_key.clone(), "providers.openai.base_url" => self.providers.openai.base_url.clone(), "providers.openai.model" => self.providers.openai.model.clone(), + "providers.openai.http_headers" => { + serialize_http_headers(&self.providers.openai.http_headers) + } "providers.openrouter.api_key" => self.providers.openrouter.api_key.clone(), "providers.openrouter.base_url" => self.providers.openrouter.base_url.clone(), "providers.openrouter.model" => self.providers.openrouter.model.clone(), + "providers.openrouter.http_headers" => { + serialize_http_headers(&self.providers.openrouter.http_headers) + } "providers.novita.api_key" => self.providers.novita.api_key.clone(), "providers.novita.base_url" => self.providers.novita.base_url.clone(), "providers.novita.model" => self.providers.novita.model.clone(), + "providers.novita.http_headers" => { + serialize_http_headers(&self.providers.novita.http_headers) + } "providers.fireworks.api_key" => self.providers.fireworks.api_key.clone(), "providers.fireworks.base_url" => self.providers.fireworks.base_url.clone(), "providers.fireworks.model" => self.providers.fireworks.model.clone(), + "providers.fireworks.http_headers" => { + serialize_http_headers(&self.providers.fireworks.http_headers) + } "providers.sglang.api_key" => self.providers.sglang.api_key.clone(), "providers.sglang.base_url" => self.providers.sglang.base_url.clone(), "providers.sglang.model" => self.providers.sglang.model.clone(), + "providers.sglang.http_headers" => { + serialize_http_headers(&self.providers.sglang.http_headers) + } "providers.vllm.api_key" => self.providers.vllm.api_key.clone(), "providers.vllm.base_url" => self.providers.vllm.base_url.clone(), "providers.vllm.model" => self.providers.vllm.model.clone(), + "providers.vllm.http_headers" => { + serialize_http_headers(&self.providers.vllm.http_headers) + } _ => self.extras.get(key).map(toml::Value::to_string), } } @@ -405,6 +438,7 @@ impl ConfigToml { } "api_key" => self.api_key = Some(value.to_string()), "base_url" => self.base_url = Some(value.to_string()), + "http_headers" => self.http_headers = parse_http_headers(value)?, "default_text_model" => self.default_text_model = Some(value.to_string()), "model" => self.model = Some(value.to_string()), "auth.mode" => self.auth_mode = Some(value.to_string()), @@ -432,9 +466,17 @@ impl ConfigToml { self.providers.deepseek.model = Some(value.clone()); self.default_text_model = Some(value); } + "providers.deepseek.http_headers" => { + let headers = parse_http_headers(value)?; + self.providers.deepseek.http_headers = headers.clone(); + self.http_headers = headers; + } "providers.openai.api_key" => self.providers.openai.api_key = Some(value.to_string()), "providers.openai.base_url" => self.providers.openai.base_url = Some(value.to_string()), "providers.openai.model" => self.providers.openai.model = Some(value.to_string()), + "providers.openai.http_headers" => { + self.providers.openai.http_headers = parse_http_headers(value)?; + } "providers.nvidia_nim.api_key" => { self.providers.nvidia_nim.api_key = Some(value.to_string()); } @@ -444,6 +486,9 @@ impl ConfigToml { "providers.nvidia_nim.model" => { self.providers.nvidia_nim.model = Some(value.to_string()); } + "providers.nvidia_nim.http_headers" => { + self.providers.nvidia_nim.http_headers = parse_http_headers(value)?; + } "providers.openrouter.api_key" => { self.providers.openrouter.api_key = Some(value.to_string()); } @@ -453,6 +498,9 @@ impl ConfigToml { "providers.openrouter.model" => { self.providers.openrouter.model = Some(value.to_string()); } + "providers.openrouter.http_headers" => { + self.providers.openrouter.http_headers = parse_http_headers(value)?; + } "providers.novita.api_key" => { self.providers.novita.api_key = Some(value.to_string()); } @@ -462,6 +510,9 @@ impl ConfigToml { "providers.novita.model" => { self.providers.novita.model = Some(value.to_string()); } + "providers.novita.http_headers" => { + self.providers.novita.http_headers = parse_http_headers(value)?; + } "providers.fireworks.api_key" => { self.providers.fireworks.api_key = Some(value.to_string()); } @@ -471,6 +522,9 @@ impl ConfigToml { "providers.fireworks.model" => { self.providers.fireworks.model = Some(value.to_string()); } + "providers.fireworks.http_headers" => { + self.providers.fireworks.http_headers = parse_http_headers(value)?; + } "providers.sglang.api_key" => { self.providers.sglang.api_key = Some(value.to_string()); } @@ -480,6 +534,9 @@ impl ConfigToml { "providers.sglang.model" => { self.providers.sglang.model = Some(value.to_string()); } + "providers.sglang.http_headers" => { + self.providers.sglang.http_headers = parse_http_headers(value)?; + } "providers.vllm.api_key" => { self.providers.vllm.api_key = Some(value.to_string()); } @@ -489,6 +546,9 @@ impl ConfigToml { "providers.vllm.model" => { self.providers.vllm.model = Some(value.to_string()); } + "providers.vllm.http_headers" => { + self.providers.vllm.http_headers = parse_http_headers(value)?; + } _ => { self.extras .insert(key.to_string(), toml::Value::String(value.to_string())); @@ -502,6 +562,7 @@ impl ConfigToml { "provider" => self.provider = ProviderKind::Deepseek, "api_key" => self.api_key = None, "base_url" => self.base_url = None, + "http_headers" => self.http_headers.clear(), "default_text_model" => self.default_text_model = None, "model" => self.model = None, "auth.mode" => self.auth_mode = None, @@ -524,27 +585,38 @@ impl ConfigToml { self.providers.deepseek.model = None; self.default_text_model = None; } + "providers.deepseek.http_headers" => { + self.providers.deepseek.http_headers.clear(); + self.http_headers.clear(); + } "providers.openai.api_key" => self.providers.openai.api_key = None, "providers.openai.base_url" => self.providers.openai.base_url = None, "providers.openai.model" => self.providers.openai.model = None, + "providers.openai.http_headers" => self.providers.openai.http_headers.clear(), "providers.nvidia_nim.api_key" => self.providers.nvidia_nim.api_key = None, "providers.nvidia_nim.base_url" => self.providers.nvidia_nim.base_url = None, "providers.nvidia_nim.model" => self.providers.nvidia_nim.model = None, + "providers.nvidia_nim.http_headers" => self.providers.nvidia_nim.http_headers.clear(), "providers.openrouter.api_key" => self.providers.openrouter.api_key = None, "providers.openrouter.base_url" => self.providers.openrouter.base_url = None, "providers.openrouter.model" => self.providers.openrouter.model = None, + "providers.openrouter.http_headers" => self.providers.openrouter.http_headers.clear(), "providers.novita.api_key" => self.providers.novita.api_key = None, "providers.novita.base_url" => self.providers.novita.base_url = None, "providers.novita.model" => self.providers.novita.model = None, + "providers.novita.http_headers" => self.providers.novita.http_headers.clear(), "providers.fireworks.api_key" => self.providers.fireworks.api_key = None, "providers.fireworks.base_url" => self.providers.fireworks.base_url = None, "providers.fireworks.model" => self.providers.fireworks.model = None, + "providers.fireworks.http_headers" => self.providers.fireworks.http_headers.clear(), "providers.sglang.api_key" => self.providers.sglang.api_key = None, "providers.sglang.base_url" => self.providers.sglang.base_url = None, "providers.sglang.model" => self.providers.sglang.model = None, + "providers.sglang.http_headers" => self.providers.sglang.http_headers.clear(), "providers.vllm.api_key" => self.providers.vllm.api_key = None, "providers.vllm.base_url" => self.providers.vllm.base_url = None, "providers.vllm.model" => self.providers.vllm.model = None, + "providers.vllm.http_headers" => self.providers.vllm.http_headers.clear(), _ => { self.extras.remove(key); } @@ -563,6 +635,9 @@ impl ConfigToml { if let Some(v) = self.base_url.as_ref() { out.insert("base_url".to_string(), v.clone()); } + if let Some(v) = serialize_http_headers(&self.http_headers) { + out.insert("http_headers".to_string(), v); + } if let Some(v) = self.default_text_model.as_ref() { out.insert("default_text_model".to_string(), v.clone()); } @@ -602,6 +677,9 @@ impl ConfigToml { if let Some(v) = self.providers.deepseek.model.as_ref() { out.insert("providers.deepseek.model".to_string(), v.clone()); } + if let Some(v) = serialize_http_headers(&self.providers.deepseek.http_headers) { + out.insert("providers.deepseek.http_headers".to_string(), v); + } if let Some(v) = self.providers.openai.api_key.as_ref() { out.insert("providers.openai.api_key".to_string(), redact_secret(v)); } @@ -611,6 +689,9 @@ impl ConfigToml { if let Some(v) = self.providers.openai.model.as_ref() { out.insert("providers.openai.model".to_string(), v.clone()); } + if let Some(v) = serialize_http_headers(&self.providers.openai.http_headers) { + out.insert("providers.openai.http_headers".to_string(), v); + } if let Some(v) = self.providers.nvidia_nim.api_key.as_ref() { out.insert("providers.nvidia_nim.api_key".to_string(), redact_secret(v)); } @@ -620,6 +701,9 @@ impl ConfigToml { if let Some(v) = self.providers.nvidia_nim.model.as_ref() { out.insert("providers.nvidia_nim.model".to_string(), v.clone()); } + if let Some(v) = serialize_http_headers(&self.providers.nvidia_nim.http_headers) { + out.insert("providers.nvidia_nim.http_headers".to_string(), v); + } if let Some(v) = self.providers.openrouter.api_key.as_ref() { out.insert("providers.openrouter.api_key".to_string(), redact_secret(v)); } @@ -629,6 +713,9 @@ impl ConfigToml { if let Some(v) = self.providers.openrouter.model.as_ref() { out.insert("providers.openrouter.model".to_string(), v.clone()); } + if let Some(v) = serialize_http_headers(&self.providers.openrouter.http_headers) { + out.insert("providers.openrouter.http_headers".to_string(), v); + } if let Some(v) = self.providers.novita.api_key.as_ref() { out.insert("providers.novita.api_key".to_string(), redact_secret(v)); } @@ -638,6 +725,9 @@ impl ConfigToml { if let Some(v) = self.providers.novita.model.as_ref() { out.insert("providers.novita.model".to_string(), v.clone()); } + if let Some(v) = serialize_http_headers(&self.providers.novita.http_headers) { + out.insert("providers.novita.http_headers".to_string(), v); + } if let Some(v) = self.providers.fireworks.api_key.as_ref() { out.insert("providers.fireworks.api_key".to_string(), redact_secret(v)); } @@ -647,6 +737,9 @@ impl ConfigToml { if let Some(v) = self.providers.fireworks.model.as_ref() { out.insert("providers.fireworks.model".to_string(), v.clone()); } + if let Some(v) = serialize_http_headers(&self.providers.fireworks.http_headers) { + out.insert("providers.fireworks.http_headers".to_string(), v); + } if let Some(v) = self.providers.sglang.api_key.as_ref() { out.insert("providers.sglang.api_key".to_string(), redact_secret(v)); } @@ -656,6 +749,9 @@ impl ConfigToml { if let Some(v) = self.providers.sglang.model.as_ref() { out.insert("providers.sglang.model".to_string(), v.clone()); } + if let Some(v) = serialize_http_headers(&self.providers.sglang.http_headers) { + out.insert("providers.sglang.http_headers".to_string(), v); + } if let Some(v) = self.providers.vllm.api_key.as_ref() { out.insert("providers.vllm.api_key".to_string(), redact_secret(v)); } @@ -665,6 +761,9 @@ impl ConfigToml { if let Some(v) = self.providers.vllm.model.as_ref() { out.insert("providers.vllm.model".to_string(), v.clone()); } + if let Some(v) = serialize_http_headers(&self.providers.vllm.http_headers) { + out.insert("providers.vllm.http_headers".to_string(), v); + } for (k, v) in &self.extras { out.insert(k.clone(), v.to_string()); @@ -763,6 +862,13 @@ impl ConfigToml { }); let model = normalize_model_for_provider(provider, &model); + let mut http_headers = self.http_headers.clone(); + http_headers.extend(provider_cfg.http_headers.clone()); + if let Some(env_headers) = env.http_headers { + http_headers.extend(env_headers); + } + http_headers.retain(|name, value| !name.trim().is_empty() && !value.trim().is_empty()); + let output_mode = cli .output_mode .clone() @@ -806,6 +912,7 @@ impl ConfigToml { telemetry, approval_policy, sandbox_mode, + http_headers, } } } @@ -820,6 +927,9 @@ fn merge_provider_config(target: &mut ProviderConfigToml, source: &ProviderConfi if source.model.is_some() { target.model = source.model.clone(); } + if !source.http_headers.is_empty() { + target.http_headers = source.http_headers.clone(); + } } /// Load a project-level config from `$WORKSPACE/.deepseek/config.toml`. @@ -930,6 +1040,7 @@ pub struct ResolvedRuntimeOptions { pub telemetry: bool, pub approval_policy: Option, pub sandbox_mode: Option, + pub http_headers: BTreeMap, } #[derive(Debug, Clone)] @@ -1049,6 +1160,42 @@ fn parse_bool(raw: &str) -> Result { } } +fn parse_http_headers(raw: &str) -> Result> { + let mut headers = BTreeMap::new(); + for pair in raw.trim().split(',') { + let pair = pair.trim(); + if pair.is_empty() { + continue; + } + let Some((name, value)) = pair.split_once('=') else { + bail!("invalid header pair '{pair}', expected name=value"); + }; + let name = name.trim(); + let value = value.trim(); + if name.is_empty() { + bail!("header name cannot be empty"); + } + if value.is_empty() { + continue; + } + headers.insert(name.to_string(), value.to_string()); + } + Ok(headers) +} + +fn serialize_http_headers(headers: &BTreeMap) -> Option { + if headers.is_empty() { + return None; + } + Some( + headers + .iter() + .map(|(name, value)| format!("{name}={value}")) + .collect::>() + .join(","), + ) +} + fn redact_secret(secret: &str) -> String { if secret.len() <= 16 { return "********".to_string(); @@ -1066,6 +1213,7 @@ struct EnvRuntimeOverrides { telemetry: Option, approval_policy: Option, sandbox_mode: Option, + http_headers: Option>, deepseek_base_url: Option, nvidia_base_url: Option, openai_base_url: Option, @@ -1091,6 +1239,10 @@ impl EnvRuntimeOverrides { .and_then(|v| parse_bool(&v).ok()), approval_policy: std::env::var("DEEPSEEK_APPROVAL_POLICY").ok(), sandbox_mode: std::env::var("DEEPSEEK_SANDBOX_MODE").ok(), + http_headers: std::env::var("DEEPSEEK_HTTP_HEADERS") + .ok() + .and_then(|value| parse_http_headers(&value).ok()) + .filter(|headers| !headers.is_empty()), deepseek_base_url: std::env::var("DEEPSEEK_BASE_URL") .ok() .filter(|v| !v.trim().is_empty()), @@ -1151,6 +1303,7 @@ mod tests { struct EnvGuard { deepseek_api_key: Option, deepseek_base_url: Option, + deepseek_http_headers: Option, deepseek_model: Option, deepseek_provider: Option, nvidia_api_key: Option, @@ -1175,6 +1328,7 @@ mod tests { let guard = Self { deepseek_api_key: env::var_os("DEEPSEEK_API_KEY"), deepseek_base_url: env::var_os("DEEPSEEK_BASE_URL"), + deepseek_http_headers: env::var_os("DEEPSEEK_HTTP_HEADERS"), deepseek_model: env::var_os("DEEPSEEK_MODEL"), deepseek_provider: env::var_os("DEEPSEEK_PROVIDER"), nvidia_api_key: env::var_os("NVIDIA_API_KEY"), @@ -1197,6 +1351,7 @@ mod tests { unsafe { env::remove_var("DEEPSEEK_API_KEY"); env::remove_var("DEEPSEEK_BASE_URL"); + env::remove_var("DEEPSEEK_HTTP_HEADERS"); env::remove_var("DEEPSEEK_MODEL"); env::remove_var("DEEPSEEK_PROVIDER"); env::remove_var("NVIDIA_API_KEY"); @@ -1233,6 +1388,7 @@ mod tests { unsafe { Self::restore_var("DEEPSEEK_API_KEY", self.deepseek_api_key.take()); Self::restore_var("DEEPSEEK_BASE_URL", self.deepseek_base_url.take()); + Self::restore_var("DEEPSEEK_HTTP_HEADERS", self.deepseek_http_headers.take()); Self::restore_var("DEEPSEEK_MODEL", self.deepseek_model.take()); Self::restore_var("DEEPSEEK_PROVIDER", self.deepseek_provider.take()); Self::restore_var("NVIDIA_API_KEY", self.nvidia_api_key.take()); @@ -1294,6 +1450,75 @@ mod tests { assert_eq!(resolved.model, "deepseek-v4-flash"); } + #[test] + fn provider_http_headers_override_root_headers() { + let _lock = env_lock(); + let _env = EnvGuard::without_deepseek_runtime_overrides(); + let mut config = ConfigToml { + api_key: Some("root-key".to_string()), + base_url: Some("https://api.deepseek.com".to_string()), + default_text_model: Some("deepseek-v4-pro".to_string()), + ..ConfigToml::default() + }; + config.providers.deepseek.api_key = Some("provider-key".to_string()); + config.providers.deepseek.base_url = Some("https://api.deepseeki.com".to_string()); + config.providers.deepseek.model = Some("deepseek-v4-flash".to_string()); + config + .http_headers + .insert("X-Shared".to_string(), "root".to_string()); + config + .providers + .deepseek + .http_headers + .insert("X-Model-Provider-Id".to_string(), "tongyi".to_string()); + config + .providers + .deepseek + .http_headers + .insert("X-Shared".to_string(), "provider".to_string()); + + let resolved = config.resolve_runtime_options(&CliRuntimeOverrides::default()); + + assert_eq!(resolved.api_key.as_deref(), Some("provider-key")); + assert_eq!(resolved.base_url, "https://api.deepseeki.com"); + assert_eq!(resolved.model, "deepseek-v4-flash"); + assert_eq!( + resolved + .http_headers + .get("X-Model-Provider-Id") + .map(String::as_str), + Some("tongyi") + ); + assert_eq!( + resolved.http_headers.get("X-Shared").map(String::as_str), + Some("provider") + ); + } + + #[test] + fn http_headers_env_overrides_config() { + let _lock = env_lock(); + let _env = EnvGuard::without_deepseek_runtime_overrides(); + let mut config = ConfigToml::default(); + config + .http_headers + .insert("X-Model-Provider-Id".to_string(), "from-file".to_string()); + // Safety: test-only environment mutation guarded by a module mutex. + unsafe { + env::set_var("DEEPSEEK_HTTP_HEADERS", "X-Model-Provider-Id=from-env"); + } + + let resolved = config.resolve_runtime_options(&CliRuntimeOverrides::default()); + + assert_eq!( + resolved + .http_headers + .get("X-Model-Provider-Id") + .map(String::as_str), + Some("from-env") + ); + } + #[test] fn nvidia_nim_provider_defaults_to_catalog_endpoint_and_model() { let _lock = env_lock(); diff --git a/crates/tui/src/client.rs b/crates/tui/src/client.rs index 832c3ae0..bfa6719d 100644 --- a/crates/tui/src/client.rs +++ b/crates/tui/src/client.rs @@ -3,11 +3,12 @@ //! DeepSeek documents `/chat/completions` as the primary endpoint, and this //! client now routes all normal traffic through that surface. +use std::collections::HashMap; use std::sync::{Arc, Mutex as StdMutex, OnceLock}; use std::time::{Duration, Instant}; use anyhow::{Context, Result}; -use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue}; +use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue}; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; use tokio::sync::Mutex as AsyncMutex; @@ -440,15 +441,22 @@ impl DeepSeekClient { validate_base_url_security(&base_url)?; let retry = config.retry_policy(); let default_model = config.default_model(); + let http_headers = config.http_headers(); logging::info(format!("API provider: {}", api_provider.as_str())); logging::info(format!("API base URL: {base_url}")); + if !http_headers.is_empty() { + logging::info(format!( + "{} custom HTTP header(s) configured", + http_headers.len() + )); + } logging::info(format!( "Retry policy: enabled={}, max_retries={}, initial_delay={}s, max_delay={}s", retry.enabled, retry.max_retries, retry.initial_delay, retry.max_delay )); - let http_client = Self::build_http_client(&api_key)?; + let http_client = Self::build_http_client(&api_key, &http_headers)?; Ok(Self { http_client, @@ -462,15 +470,11 @@ impl DeepSeekClient { }) } - fn build_http_client(api_key: &str) -> Result { - let mut headers = HeaderMap::new(); - headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); - if !api_key.trim().is_empty() { - headers.insert( - AUTHORIZATION, - HeaderValue::from_str(&format!("Bearer {api_key}"))?, - ); - } + fn build_http_client( + api_key: &str, + extra_headers: &HashMap, + ) -> Result { + let headers = build_default_headers(api_key, extra_headers)?; let mut builder = reqwest::Client::builder() .default_headers(headers) .connect_timeout(Duration::from_secs(30)) @@ -490,6 +494,43 @@ impl DeepSeekClient { builder.build().map_err(Into::into) } + #[cfg(test)] + fn default_headers( + api_key: &str, + extra_headers: &HashMap, + ) -> Result { + build_default_headers(api_key, extra_headers) + } +} + +fn build_default_headers( + api_key: &str, + extra_headers: &HashMap, +) -> Result { + let mut headers = HeaderMap::new(); + headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + if !api_key.trim().is_empty() { + headers.insert( + AUTHORIZATION, + HeaderValue::from_str(&format!("Bearer {api_key}"))?, + ); + } + for (name, value) in extra_headers { + let name = name.trim(); + let value = value.trim(); + if name.is_empty() || value.is_empty() { + continue; + } + let header_name = HeaderName::from_bytes(name.as_bytes())?; + if header_name == AUTHORIZATION || header_name == CONTENT_TYPE { + continue; + } + headers.insert(header_name, HeaderValue::from_str(value)?); + } + Ok(headers) +} + +impl DeepSeekClient { /// List available models from the provider. pub async fn list_models(&self) -> Result> { let url = api_url(&self.base_url, "models"); @@ -977,6 +1018,27 @@ mod tests { ); } + #[test] + fn default_headers_include_custom_headers_when_configured() { + let mut extra = HashMap::new(); + extra.insert("X-Model-Provider-Id".to_string(), "tongyi".to_string()); + let headers = DeepSeekClient::default_headers("sk-test", &extra).expect("headers"); + assert_eq!( + headers + .get("x-model-provider-id") + .and_then(|value| value.to_str().ok()), + Some("tongyi") + ); + } + + #[test] + fn default_headers_ignore_blank_custom_headers() { + let mut extra = HashMap::new(); + extra.insert("X-Blank".to_string(), " ".to_string()); + let headers = DeepSeekClient::default_headers("sk-test", &extra).expect("headers"); + assert!(headers.get("x-blank").is_none()); + } + #[test] fn chat_messages_keep_reasoning_content_on_all_assistant_messages() { let message = Message { diff --git a/crates/tui/src/config.rs b/crates/tui/src/config.rs index 6e195fca..da462752 100644 --- a/crates/tui/src/config.rs +++ b/crates/tui/src/config.rs @@ -645,6 +645,8 @@ pub struct Config { pub provider: Option, pub api_key: Option, pub base_url: Option, + /// Optional extra HTTP headers sent to model API requests. + pub http_headers: Option>, pub default_text_model: Option, /// DeepSeek reasoning-effort tier: `"off" | "low" | "medium" | "high" | "max"`. /// Defaults to `"max"` at runtime if unset. @@ -896,6 +898,7 @@ pub struct ProviderConfig { pub api_key: Option, pub base_url: Option, pub model: Option, + pub http_headers: Option>, } #[derive(Debug, Clone, Default, Deserialize)] @@ -1101,6 +1104,19 @@ impl Config { self.provider_config_for(self.api_provider()) } + #[must_use] + pub fn http_headers(&self) -> HashMap { + let mut headers = self.http_headers.clone().unwrap_or_default(); + if let Some(provider_headers) = self + .provider_config() + .and_then(|provider| provider.http_headers.as_ref()) + { + headers.extend(provider_headers.clone()); + } + headers.retain(|name, value| !name.trim().is_empty() && !value.trim().is_empty()); + headers + } + #[must_use] pub fn default_model(&self) -> String { let provider = self.api_provider(); @@ -1784,6 +1800,32 @@ fn apply_env_overrides(config: &mut Config) { .vllm .base_url = Some(value); } + if let Ok(value) = std::env::var("DEEPSEEK_HTTP_HEADERS") + && let Ok(headers) = parse_http_headers(&value) + && !headers.is_empty() + { + let mut root_headers = config.http_headers.clone().unwrap_or_default(); + root_headers.extend(headers.clone()); + config.http_headers = Some(root_headers); + + let provider = config.api_provider(); + let providers = config + .providers + .get_or_insert_with(ProvidersConfig::default); + let entry = match provider { + ApiProvider::Deepseek => &mut providers.deepseek, + ApiProvider::DeepseekCN => &mut providers.deepseek_cn, + ApiProvider::NvidiaNim => &mut providers.nvidia_nim, + ApiProvider::Openrouter => &mut providers.openrouter, + ApiProvider::Novita => &mut providers.novita, + ApiProvider::Fireworks => &mut providers.fireworks, + ApiProvider::Sglang => &mut providers.sglang, + ApiProvider::Vllm => &mut providers.vllm, + }; + let mut provider_headers = entry.http_headers.clone().unwrap_or_default(); + provider_headers.extend(headers); + entry.http_headers = Some(provider_headers); + } if matches!(config.api_provider(), ApiProvider::Sglang) && let Ok(value) = std::env::var("SGLANG_MODEL") { @@ -2061,6 +2103,29 @@ fn normalize_base_url(base: &str) -> String { trimmed.to_string() } +fn parse_http_headers(raw: &str) -> Result> { + let mut headers = HashMap::new(); + for pair in raw.trim().split(',') { + let pair = pair.trim(); + if pair.is_empty() { + continue; + } + let Some((name, value)) = pair.split_once('=') else { + anyhow::bail!("invalid header pair '{pair}', expected name=value"); + }; + let name = name.trim(); + let value = value.trim(); + if name.is_empty() { + anyhow::bail!("header name cannot be empty"); + } + if value.is_empty() { + continue; + } + headers.insert(name.to_string(), value.to_string()); + } + Ok(headers) +} + fn apply_profile(config: ConfigFile, profile: Option<&str>) -> Result { if let Some(profile_name) = profile { let profiles = config.profiles.as_ref(); @@ -2095,6 +2160,7 @@ fn merge_config(base: Config, override_cfg: Config) -> Config { provider: override_cfg.provider.or(base.provider), api_key: override_cfg.api_key.or(base.api_key), base_url: override_cfg.base_url.or(base.base_url), + http_headers: override_cfg.http_headers.or(base.http_headers), default_text_model: override_cfg.default_text_model.or(base.default_text_model), reasoning_effort: override_cfg.reasoning_effort.or(base.reasoning_effort), tools_file: override_cfg.tools_file.or(base.tools_file), @@ -2166,6 +2232,7 @@ fn merge_provider_config(base: ProviderConfig, override_cfg: ProviderConfig) -> api_key: override_cfg.api_key.or(base.api_key), base_url: override_cfg.base_url.or(base.base_url), model: override_cfg.model.or(base.model), + http_headers: override_cfg.http_headers.or(base.http_headers), } } @@ -2818,6 +2885,7 @@ mod tests { deepseek_provider: Option, deepseek_api_key: Option, deepseek_base_url: Option, + deepseek_http_headers: Option, deepseek_model: Option, deepseek_default_text_model: Option, nvidia_api_key: Option, @@ -2851,6 +2919,7 @@ mod tests { let deepseek_provider_prev = env::var_os("DEEPSEEK_PROVIDER"); let api_key_prev = env::var_os("DEEPSEEK_API_KEY"); let base_url_prev = env::var_os("DEEPSEEK_BASE_URL"); + let http_headers_prev = env::var_os("DEEPSEEK_HTTP_HEADERS"); let model_prev = env::var_os("DEEPSEEK_MODEL"); let default_text_model_prev = env::var_os("DEEPSEEK_DEFAULT_TEXT_MODEL"); let nvidia_api_key_prev = env::var_os("NVIDIA_API_KEY"); @@ -2879,6 +2948,7 @@ mod tests { env::remove_var("DEEPSEEK_PROVIDER"); env::remove_var("DEEPSEEK_API_KEY"); env::remove_var("DEEPSEEK_BASE_URL"); + env::remove_var("DEEPSEEK_HTTP_HEADERS"); env::remove_var("DEEPSEEK_MODEL"); env::remove_var("DEEPSEEK_DEFAULT_TEXT_MODEL"); env::remove_var("NVIDIA_API_KEY"); @@ -2907,6 +2977,7 @@ mod tests { deepseek_provider: deepseek_provider_prev, deepseek_api_key: api_key_prev, deepseek_base_url: base_url_prev, + deepseek_http_headers: http_headers_prev, deepseek_model: model_prev, deepseek_default_text_model: default_text_model_prev, nvidia_api_key: nvidia_api_key_prev, @@ -2941,6 +3012,7 @@ mod tests { Self::restore_var("DEEPSEEK_PROVIDER", self.deepseek_provider.take()); Self::restore_var("DEEPSEEK_API_KEY", self.deepseek_api_key.take()); Self::restore_var("DEEPSEEK_BASE_URL", self.deepseek_base_url.take()); + Self::restore_var("DEEPSEEK_HTTP_HEADERS", self.deepseek_http_headers.take()); Self::restore_var("DEEPSEEK_MODEL", self.deepseek_model.take()); Self::restore_var( "DEEPSEEK_DEFAULT_TEXT_MODEL", @@ -3774,6 +3846,110 @@ api_key = "old-openrouter-key" Ok(()) } + #[test] + fn http_headers_load_from_root_config() -> Result<()> { + let _lock = lock_test_env(); + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos(); + let temp_root = env::temp_dir().join(format!( + "deepseek-tui-http-headers-root-{}-{}", + std::process::id(), + nanos + )); + fs::create_dir_all(&temp_root)?; + let _guard = EnvGuard::new(&temp_root); + + let config_path = temp_root.join(".deepseek").join("config.toml"); + ensure_parent_dir(&config_path)?; + fs::write( + &config_path, + r#" +api_key = "test-key" +http_headers = { "X-Model-Provider-Id" = "tongyi" } +"#, + )?; + + let config = Config::load(None, None)?; + assert_eq!( + config + .http_headers() + .get("X-Model-Provider-Id") + .map(String::as_str), + Some("tongyi") + ); + Ok(()) + } + + #[test] + fn provider_http_headers_extend_and_override_root_config() { + let mut providers = ProvidersConfig::default(); + providers.deepseek.http_headers = Some(HashMap::from([ + ("X-Model-Provider-Id".to_string(), "tongyi".to_string()), + ("X-Shared".to_string(), "provider".to_string()), + ])); + let config = Config { + http_headers: Some(HashMap::from([ + ("X-Root".to_string(), "root".to_string()), + ("X-Shared".to_string(), "root".to_string()), + ])), + providers: Some(providers), + ..Default::default() + }; + + let headers = config.http_headers(); + assert_eq!( + headers.get("X-Model-Provider-Id").map(String::as_str), + Some("tongyi") + ); + assert_eq!(headers.get("X-Root").map(String::as_str), Some("root")); + assert_eq!( + headers.get("X-Shared").map(String::as_str), + Some("provider") + ); + } + + #[test] + fn http_headers_env_overrides_config() -> Result<()> { + let _lock = lock_test_env(); + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos(); + let temp_root = env::temp_dir().join(format!( + "deepseek-tui-http-headers-env-{}-{}", + std::process::id(), + nanos + )); + fs::create_dir_all(&temp_root)?; + let _guard = EnvGuard::new(&temp_root); + + let config_path = temp_root.join(".deepseek").join("config.toml"); + ensure_parent_dir(&config_path)?; + fs::write( + &config_path, + r#" +api_key = "test-key" +http_headers = { "X-Model-Provider-Id" = "from-file" } +"#, + )?; + // Safety: test-only environment mutation guarded by a global mutex. + unsafe { + env::set_var("DEEPSEEK_HTTP_HEADERS", "X-Model-Provider-Id=from-env"); + } + + let config = Config::load(None, None)?; + assert_eq!( + config + .http_headers() + .get("X-Model-Provider-Id") + .map(String::as_str), + Some("from-env") + ); + Ok(()) + } + #[test] fn nvidia_nim_provider_uses_nim_defaults() -> Result<()> { let config = Config { diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index 25b1af30..9ce0b494 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -64,6 +64,16 @@ key, base URL, provider, and model to the TUI process. Use save hosted-provider keys through the facade. SGLang and vLLM are self-hosted and can run without an API key by default. +Third-party OpenAI-compatible gateways that need extra request headers can set +`http_headers = { "X-Model-Provider-Id" = "your-model-provider" }` at the top +level or under a provider table such as `[providers.deepseek]`. When configured, +DeepSeek TUI sends those custom headers on model API requests. The equivalent +environment override is `DEEPSEEK_HTTP_HEADERS`, using comma-separated +`name=value` pairs such as +`X-Model-Provider-Id=your-model-provider,X-Gateway-Route=dev`. `Authorization` +and `Content-Type` are managed by the client and are not overridden by this +setting. + To bootstrap MCP and skills directories at their resolved paths, run `deepseek-tui setup`. To only scaffold MCP, run `deepseek-tui mcp init`. @@ -119,6 +129,7 @@ These override config values: - `DEEPSEEK_API_KEY` - `DEEPSEEK_BASE_URL` +- `DEEPSEEK_HTTP_HEADERS` (custom model request headers, comma-separated `name=value` pairs) - `DEEPSEEK_PROVIDER` (`deepseek|nvidia-nim|openrouter|novita|fireworks|sglang|vllm`) - `DEEPSEEK_MODEL` or `DEEPSEEK_DEFAULT_TEXT_MODEL` - `NVIDIA_API_KEY` or `NVIDIA_NIM_API_KEY` (preferred when provider is `nvidia-nim`; falls back to `DEEPSEEK_API_KEY`)