diff --git a/crates/tui/src/client.rs b/crates/tui/src/client.rs index 5aed06f2..8bcfbef3 100644 --- a/crates/tui/src/client.rs +++ b/crates/tui/src/client.rs @@ -133,22 +133,18 @@ pub struct DeepSeekClient { use_chat_completions: AtomicBool, /// Counter of chat-completions requests since last experimental Responses API probe. /// After RESPONSES_RECOVERY_INTERVAL requests, we retry the Responses API when - /// `DEEPSEEK_EXPERIMENTAL_RESPONSES_API` is set. - chat_fallback_counter: AtomicU32, connection_health: Arc>, rate_limiter: Arc>, } /// After this many chat-completions requests, retry the experimental Responses /// API to see if it has recovered. -const RESPONSES_RECOVERY_INTERVAL: u32 = 20; const CONNECTION_FAILURE_THRESHOLD: u32 = 2; const RECOVERY_PROBE_COOLDOWN: Duration = Duration::from_secs(15); const DEFAULT_CLIENT_RATE_LIMIT_RPS: f64 = 8.0; const DEFAULT_CLIENT_RATE_LIMIT_BURST: f64 = 16.0; const ALLOW_INSECURE_HTTP_ENV: &str = "DEEPSEEK_ALLOW_INSECURE_HTTP"; -const EXPERIMENTAL_RESPONSES_API_ENV: &str = "DEEPSEEK_EXPERIMENTAL_RESPONSES_API"; pub(super) const SSE_BACKPRESSURE_HIGH_WATERMARK: usize = 8 * 1024 * 1024; // 8 MB pub(super) const SSE_BACKPRESSURE_SLEEP_MS: u64 = 10; @@ -309,1187 +305,6 @@ impl Clone for DeepSeekClient { use_chat_completions: AtomicBool::new( self.use_chat_completions.load(Ordering::Relaxed), ), - chat_fallback_counter: AtomicU32::new( - self.chat_fallback_counter.load(Ordering::Relaxed), - ), - connection_health: self.connection_health.clone(), - rate_limiter: self.rate_limiter.clone(), - } - } -} - -// === Helpers === - -/// Maximum bytes to read from an error response body (64 KB). -pub(super) const ERROR_BODY_MAX_BYTES: usize = 64 * 1024; - -/// Read an error response body with a size limit to prevent unbounded allocation. -pub(super) async fn bounded_error_text(response: reqwest::Response, max_bytes: usize) -> String { - use futures_util::StreamExt; - let mut stream = response.bytes_stream(); - let mut buf = Vec::with_capacity(max_bytes.min(8192)); - while let Some(chunk) = stream.next().await { - let Ok(chunk) = chunk else { break }; - let remaining = max_bytes.saturating_sub(buf.len()); - if remaining == 0 { - break; - } - buf.extend_from_slice(&chunk[..chunk.len().min(remaining)]); - } - String::from_utf8_lossy(&buf).into_owned() -} - -fn validate_base_url_security(base_url: &str) -> Result<()> { - if base_url.starts_with("https://") - || base_url.starts_with("http://localhost") - || base_url.starts_with("http://127.0.0.1") - || base_url.starts_with("http://[::1]") - { - return Ok(()); - } - - if base_url.starts_with("http://") - && std::env::var(ALLOW_INSECURE_HTTP_ENV) - .ok() - .as_deref() - .is_some_and(|v| v == "1" || v.eq_ignore_ascii_case("true")) - { - logging::warn(format!( - "Using insecure HTTP base URL because {} is set", - ALLOW_INSECURE_HTTP_ENV - )); - return Ok(()); - } - - if base_url.starts_with("http://") { - anyhow::bail!( - "Refusing insecure base URL '{}'. Use HTTPS or set {}=1 to override for trusted environments.", - base_url, - ALLOW_INSECURE_HTTP_ENV - ); - } - - anyhow::bail!( - "Refusing base URL '{}': only HTTPS (or explicitly allowed HTTP) URLs are supported.", - base_url, - ) -} - -fn experimental_responses_api_enabled() -> bool { - std::env::var(EXPERIMENTAL_RESPONSES_API_ENV) - .ok() - .as_deref() - .is_some_and(|v| v == "1" || v.eq_ignore_ascii_case("true")) -} - -pub(super) fn versioned_base_url(base_url: &str) -> String { - let trimmed = base_url.trim_end_matches('/'); - if trimmed.ends_with("/v1") || trimmed.ends_with("/beta") { - trimmed.to_string() - } else { - format!("{trimmed}/v1") - } -} - -pub(super) fn api_url(base_url: &str, path: &str) -> String { - format!( - "{}/{}", - versioned_base_url(base_url).trim_end_matches('/'), - path.trim_start_matches('/') - ) -} - -// === DeepSeekClient === - -/// Returns true when DEEPSEEK_FORCE_HTTP1 is set to a truthy value -/// (`1`, `true`, `yes`, `on`, case-insensitive). Used by `build_http_client` -/// to opt out of HTTP/2 entirely when DeepSeek's edge mishandles long-lived H2 -/// streams (#103). Anything else (unset, `0`, `false`, ...) leaves HTTP/2 on. -fn force_http1_from_env() -> bool { - std::env::var("DEEPSEEK_FORCE_HTTP1") - .ok() - .map(|v| v.trim().to_ascii_lowercase()) - .is_some_and(|v| matches!(v.as_str(), "1" | "true" | "yes" | "on")) -} - -/// Read `SSL_CERT_FILE` and add its contents as extra root -/// certificates on the reqwest builder (#418). Tries the PEM-bundle -/// parser first (covers single-cert files too), then falls back to -/// DER. All failures log a warning and return the builder unchanged -/// so a malformed env var degrades gracefully. -fn add_extra_root_certs( - mut builder: reqwest::ClientBuilder, - cert_path: &str, -) -> reqwest::ClientBuilder { - let bytes = match std::fs::read(cert_path) { - Ok(b) => b, - Err(err) => { - logging::warn(format!( - "SSL_CERT_FILE={cert_path} could not be read: {err}" - )); - return builder; - } - }; - - // PEM bundle handles both single-cert and multi-cert files; try - // it first since `BEGIN CERTIFICATE` framing is the common case. - if let Ok(certs) = reqwest::Certificate::from_pem_bundle(&bytes) { - let added = certs.len(); - for cert in certs { - builder = builder.add_root_certificate(cert); - } - logging::info(format!( - "SSL_CERT_FILE={cert_path} loaded ({added} cert(s))" - )); - return builder; - } - - // Single-cert DER fallback. - match reqwest::Certificate::from_der(&bytes) { - Ok(cert) => { - builder = builder.add_root_certificate(cert); - logging::info(format!("SSL_CERT_FILE={cert_path} loaded (1 DER cert)")); - } - Err(err) => { - logging::warn(format!( - "SSL_CERT_FILE={cert_path} could not be parsed as PEM bundle or DER: {err}" - )); - } - } - builder -} - -impl DeepSeekClient { - /// Create a DeepSeek client from CLI configuration. - pub fn new(config: &Config) -> Result { - let api_key = config.deepseek_api_key()?; - let base_url = config.deepseek_base_url(); - let api_provider = config.api_provider(); - validate_base_url_security(&base_url)?; - let retry = config.retry_policy(); - let default_model = config.default_model(); - - logging::info(format!("API provider: {}", api_provider.as_str())); - logging::info(format!("API base URL: {base_url}")); - logging::info(format!( - "Retry policy: enabled={}, max_retries={}, initial_delay={}s, max_delay={}s", - retry.enabled, retry.max_retries, retry.initial_delay, retry.max_delay - )); - - let http_client = Self::build_http_client(&api_key)?; - - Ok(Self { - http_client, - api_key, - base_url, - api_provider, - retry, - default_model, - use_chat_completions: AtomicBool::new(false), - chat_fallback_counter: AtomicU32::new(0), - connection_health: Arc::new(AsyncMutex::new(ConnectionHealth::default())), - rate_limiter: Arc::new(AsyncMutex::new(TokenBucket::from_env())), - }) - } - - fn build_http_client(api_key: &str) -> Result { - let mut headers = HeaderMap::new(); - headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); - if !api_key.trim().is_empty() { - headers.insert( - AUTHORIZATION, - HeaderValue::from_str(&format!("Bearer {api_key}"))?, - ); - } - let mut builder = reqwest::Client::builder() - .default_headers(headers) - .connect_timeout(Duration::from_secs(30)) - // The blanket 300s request timeout was incompatible with V4-pro - // thinking turns that legitimately exceed that wall-clock window - // (see #103). Drop it; per-chunk and per-stream guards in - // engine.rs already bound how long we'll wait without progress. - .tcp_keepalive(Some(Duration::from_secs(30))) - .http2_keep_alive_interval(Some(Duration::from_secs(15))) - .http2_keep_alive_timeout(Duration::from_secs(20)) - .min_tls_version(reqwest::tls::Version::TLS_1_2); - // Escape hatch (#103): some DeepSeek edge nodes mishandle long-lived - // HTTP/2 streams. Setting DEEPSEEK_FORCE_HTTP1=1 pins the client to - // HTTP/1.1 so users can experiment without us committing to that - // path as the default. - if force_http1_from_env() { - logging::info("DEEPSEEK_FORCE_HTTP1=1 — pinning HTTP client to HTTP/1.1"); - builder = builder.http1_only(); - } - // #418: corporate-proxy / MITM-inspector CA support. When - // `SSL_CERT_FILE` is set, load the cert(s) it points at and - // add them as trusted roots alongside the platform's system - // store. We try PEM bundle first (the common case for - // multi-cert files), then fall back to single-cert PEM, then - // DER. Failures log a warning and continue — the existing - // system roots still apply, so a malformed env var won't - // bring down the launch. - if let Ok(cert_path) = std::env::var("SSL_CERT_FILE") - && !cert_path.is_empty() - { - builder = add_extra_root_certs(builder, &cert_path); - } - builder.build().map_err(Into::into) - } - - /// List available models from the provider. - pub async fn list_models(&self) -> Result> { - let url = api_url(&self.base_url, "models"); - let response = self.send_with_retry(|| self.http_client.get(&url)).await?; - - let status = response.status(); - if !status.is_success() { - let error_text = bounded_error_text(response, ERROR_BODY_MAX_BYTES).await; - anyhow::bail!("Failed to list models: HTTP {status}: {error_text}"); - } - let response_text = response.text().await.unwrap_or_default(); - - parse_models_response(&response_text) - } - - async fn wait_for_rate_limit(&self) { - let maybe_delay = { - let mut limiter = self.rate_limiter.lock().await; - limiter.delay_until_available(1.0) - }; - if let Some(delay) = maybe_delay { - tokio::time::sleep(delay).await; - } - } - - async fn mark_request_success(&self) { - let mut health = self.connection_health.lock().await; - if apply_request_success(&mut health, Instant::now()) { - logging::info("Connection recovered"); - } - } - - async fn mark_request_failure(&self, reason: &str) { - let mut health = self.connection_health.lock().await; - apply_request_failure(&mut health, Instant::now()); - logging::warn(format!( - "Connection degraded (failures={}): {}", - health.consecutive_failures, reason - )); - } - - async fn maybe_probe_recovery(&self) { - let should_probe = { - let mut health = self.connection_health.lock().await; - mark_recovery_probe_if_due(&mut health, Instant::now()) - }; - if !should_probe { - return; - } - let health_url = api_url(&self.base_url, "models"); - let probe = self.http_client.get(health_url).send().await; - match probe { - Ok(resp) if resp.status().is_success() => { - self.mark_request_success().await; - logging::info("Recovery probe succeeded"); - } - Ok(resp) => { - self.mark_request_failure(&format!("probe status={}", resp.status())) - .await; - } - Err(err) => { - self.mark_request_failure(&format!("probe error={err}")) - .await; - } - } - } - - pub(super) async fn send_with_retry(&self, mut build: F) -> Result - where - F: FnMut() -> reqwest::RequestBuilder, - { - let retry_cfg: LlmRetryConfig = self.retry.clone().into(); - let request_result = with_retry( - &retry_cfg, - || { - let request = build(); - async move { - self.wait_for_rate_limit().await; - let response = request - .send() - .await - .map_err(|err| LlmError::from_reqwest(&err))?; - let status = response.status(); - if status.is_success() { - return Ok(response); - } - let retryable = status.as_u16() == 429 || status.is_server_error(); - if !retryable { - return Ok(response); - } - let retry_after = extract_retry_after(response.headers()); - let body = bounded_error_text(response, ERROR_BODY_MAX_BYTES).await; - Err(LlmError::from_http_response_with_retry_after( - status.as_u16(), - &body, - retry_after, - )) - } - }, - Some(Box::new(|err, attempt, delay| { - let (reason_label, human_reason) = retry_reason_label_and_human(err); - logging::warn(format!( - "HTTP retry reason={} attempt={} delay={:.2}s", - reason_label, - attempt + 1, - delay.as_secs_f64(), - )); - // Light up the foreground retry banner (#499). `attempt` - // here is 0-indexed for the *failed* attempt; surface the - // 1-indexed *upcoming* attempt the user is waiting on. - crate::retry_status::start(attempt + 1, delay, human_reason); - })), - ) - .await; - - match request_result { - Ok(response) => { - crate::retry_status::succeeded(); - self.mark_request_success().await; - Ok(response) - } - Err(err) => { - let last = err.last_error.to_string(); - // Only mark the retry banner failed if at least one - // retry actually fired — non-retryable errors should - // surface as turn errors, not as retry-banner failures. - if err.attempts > 1 { - crate::retry_status::failed(last.clone()); - } else { - crate::retry_status::clear(); - } - self.mark_request_failure(&last).await; - self.maybe_probe_recovery().await; - Err(anyhow::anyhow!(last)) - } - } - } -} - -/// Translate the structured `LlmError` into both a categorical label -/// (for structured logs / metrics) and a short human reason string -/// (for the retry banner). Returning both from one match avoids the -/// double-classification we had before. -fn retry_reason_label_and_human(err: &LlmError) -> (&'static str, String) { - match err { - LlmError::RateLimited { retry_after, .. } => { - let human = if let Some(after) = retry_after { - format!("rate limited (Retry-After {}s)", after.as_secs()) - } else { - "rate limited".to_string() - }; - ("rate_limited", human) - } - LlmError::ServerError { status, .. } => ("server_error", format!("upstream {status}")), - LlmError::NetworkError(_) => ("network_error", "network error".to_string()), - LlmError::Timeout(_) => ("timeout", "timeout".to_string()), - _ => ("other", "other".to_string()), - } -} - -impl LlmClient for DeepSeekClient { - fn provider_name(&self) -> &'static str { - self.api_provider.as_str() - } - - fn model(&self) -> &str { - &self.default_model - } - - async fn health_check(&self) -> Result { - let health_url = api_url(&self.base_url, "models"); - self.wait_for_rate_limit().await; - let response = self.http_client.get(health_url).send().await; - match response { - Ok(resp) if resp.status().is_success() => { - self.mark_request_success().await; - Ok(true) - } - Ok(resp) => { - self.mark_request_failure(&format!("health status={}", resp.status())) - .await; - Ok(false) - } - Err(err) => { - self.mark_request_failure(&format!("health error={err}")) - .await; - Ok(false) - } - } - } - - async fn create_message(&self, request: MessageRequest) -> Result { - if !experimental_responses_api_enabled() { - return self.create_message_chat(&request).await; - } - - // Check if it's time to probe Responses API recovery - if self.use_chat_completions.load(Ordering::Relaxed) { - let count = self.chat_fallback_counter.fetch_add(1, Ordering::Relaxed); - if count > 0 && count.is_multiple_of(RESPONSES_RECOVERY_INTERVAL) { - logging::info("Probing Responses API recovery..."); - let request_clone = request.clone(); - match self.create_message_responses(&request).await? { - Ok(message) => { - logging::info("Responses API recovered! Switching back."); - self.use_chat_completions.store(false, Ordering::Relaxed); - self.chat_fallback_counter.store(0, Ordering::Relaxed); - return Ok(message); - } - Err(_) => { - logging::info("Responses API still unavailable, continuing with chat."); - } - } - return self.create_message_chat(&request_clone).await; - } - return self.create_message_chat(&request).await; - } - - let request_clone = request.clone(); - match self.create_message_responses(&request).await? { - Ok(message) => Ok(message), - Err(fallback) => { - logging::warn(format!( - "Responses API unavailable (HTTP {}). Falling back to chat completions.", - fallback.status - )); - logging::info(format!( - "Responses fallback body: {}", - crate::utils::truncate_with_ellipsis(&fallback.body, 500, "...") - )); - self.use_chat_completions.store(true, Ordering::Relaxed); - self.chat_fallback_counter.store(0, Ordering::Relaxed); - self.create_message_chat(&request_clone).await - } - } - } - - async fn create_message_stream(&self, request: MessageRequest) -> Result { - self.handle_chat_completion_stream(request).await - } -} - -#[derive(Debug, Deserialize)] -struct ModelsListResponse { - data: Vec, -} - -#[derive(Debug, Deserialize)] -struct ModelListItem { - id: String, - #[serde(default)] - owned_by: Option, - #[serde(default)] - created: Option, -} - -pub(super) fn parse_models_response(payload: &str) -> Result> { - let parsed: ModelsListResponse = - serde_json::from_str(payload).context("Failed to parse model list JSON")?; - - let mut models = parsed - .data - .into_iter() - .map(|item| AvailableModel { - id: item.id, - owned_by: item.owned_by, - created: item.created, - }) - .collect::>(); - models.sort_by(|a, b| a.id.cmp(&b.id)); - models.dedup_by(|a, b| a.id == b.id); - Ok(models) -} - -pub(super) fn system_to_instructions(system: Option) -> Option { - match system { - Some(SystemPrompt::Text(text)) => Some(text), - Some(SystemPrompt::Blocks(blocks)) => { - let joined = blocks - .into_iter() - .map(|b| b.text) - .collect::>() - .join("\n\n---\n\n"); - if joined.trim().is_empty() { - None - } else { - Some(joined) - } - } - None => None, - } -} - -pub(super) fn apply_reasoning_effort( - body: &mut Value, - effort: Option<&str>, - provider: ApiProvider, -) { - let Some(effort) = effort else { - return; - }; - let normalized = effort.trim().to_ascii_lowercase(); - match normalized.as_str() { - "off" | "disabled" | "none" | "false" => match provider { - // OpenRouter / Novita relay the same DeepSeek V4 payload shape - // as DeepSeek native; they pass through `thinking` / `reasoning_effort`. - ApiProvider::Deepseek - | ApiProvider::DeepseekCN - | ApiProvider::Openrouter - | ApiProvider::Novita - | ApiProvider::Fireworks - | ApiProvider::Sglang => { - body["thinking"] = json!({ "type": "disabled" }); - } - ApiProvider::NvidiaNim => { - body["chat_template_kwargs"] = json!({ - "thinking": false, - }); - } - }, - "low" | "minimal" | "medium" | "mid" | "high" | "" => match provider { - ApiProvider::Deepseek - | ApiProvider::DeepseekCN - | ApiProvider::Openrouter - | ApiProvider::Novita - | ApiProvider::Fireworks - | ApiProvider::Sglang => { - body["reasoning_effort"] = json!("high"); - body["thinking"] = json!({ "type": "enabled" }); - } - ApiProvider::NvidiaNim => { - body["chat_template_kwargs"] = json!({ - "thinking": true, - "reasoning_effort": "high", - }); - } - }, - "xhigh" | "max" | "highest" => match provider { - ApiProvider::Deepseek - | ApiProvider::DeepseekCN - | ApiProvider::Openrouter - | ApiProvider::Novita - | ApiProvider::Fireworks - | ApiProvider::Sglang => { - body["reasoning_effort"] = json!("max"); - body["thinking"] = json!({ "type": "enabled" }); - } - ApiProvider::NvidiaNim => { - body["chat_template_kwargs"] = json!({ - "thinking": true, - "reasoning_effort": "max", - }); - } - }, - _ => { - // Unknown value — do not mutate the request, let the provider - // apply its own defaults. - } - } -} - -pub(super) fn parse_usage(usage: Option<&Value>) -> Usage { - let input_tokens = usage - .and_then(|u| u.get("input_tokens").or_else(|| u.get("prompt_tokens"))) - .and_then(Value::as_u64) - .unwrap_or(0); - let output_tokens = usage - .and_then(|u| { - u.get("output_tokens") - .or_else(|| u.get("completion_tokens")) - }) - .and_then(Value::as_u64) - .unwrap_or(0); - let prompt_cache_hit_tokens = usage - .and_then(|u| u.get("prompt_cache_hit_tokens")) - .and_then(Value::as_u64) - .map(|v| v as u32); - let prompt_cache_miss_tokens = usage - .and_then(|u| u.get("prompt_cache_miss_tokens")) - .and_then(Value::as_u64) - .map(|v| v as u32); - let reasoning_tokens = usage - .and_then(|u| u.get("completion_tokens_details")) - .and_then(|details| details.get("reasoning_tokens")) - .and_then(Value::as_u64) - .map(|v| v as u32); - - let server_tool_use = usage.and_then(|u| u.get("server_tool_use")).map(|server| { - let code_execution_requests = server - .get("code_execution_requests") - .and_then(Value::as_u64) - .map(|v| v as u32); - let tool_search_requests = server - .get("tool_search_requests") - .and_then(Value::as_u64) - .map(|v| v as u32); - ServerToolUsage { - code_execution_requests, - tool_search_requests, - } - }); - - Usage { - input_tokens: input_tokens as u32, - output_tokens: output_tokens as u32, - prompt_cache_hit_tokens, - prompt_cache_miss_tokens, - reasoning_tokens, - reasoning_replay_tokens: None, - server_tool_use, - } -} - -impl DeepSeekClient { - /// Call the DeepSeek `/beta/completions` FIM endpoint. - /// - /// Returns the generated text (the "middle" between `prompt` and `suffix`). - pub async fn fim_completion( - &self, - model: &str, - prompt: &str, - suffix: &str, - max_tokens: u32, - ) -> anyhow::Result { - let url = api_url(&self.base_url, "beta/completions"); - let body = json!({ - "model": model, - "prompt": prompt, - "suffix": suffix, - "max_tokens": max_tokens, - }); - let response = self - .send_with_retry(|| self.http_client.post(&url).json(&body)) - .await?; - let status = response.status(); - if !status.is_success() { - let error_text = bounded_error_text(response, ERROR_BODY_MAX_BYTES).await; - anyhow::bail!("FIM API error: HTTP {status}: {error_text}"); - } - let response_text = response.text().await.unwrap_or_default(); - let value: serde_json::Value = - serde_json::from_str(&response_text).context("Failed to parse FIM API response")?; - let text = value - .pointer("/choices/0/text") - .and_then(serde_json::Value::as_str) - .ok_or_else(|| anyhow::anyhow!("FIM response missing choices[0].text"))?; - Ok(text.to_string()) - } -} - -mod chat; -mod responses; - -#[cfg(test)] -mod tests { - use super::*; - use crate::client::chat::{ - build_chat_messages, build_chat_messages_for_request, count_reasoning_replay_chars, - parse_chat_message, parse_sse_chunk, sanitize_thinking_mode_messages, tool_to_chat, - }; - use crate::models::{ContentBlock, ContentBlockStart, Delta, Message, StreamEvent, Tool}; - use serde_json::json; - - #[test] - fn tool_name_roundtrip_dot() { - let original = "multi_tool_use.parallel"; - let encoded = to_api_tool_name(original); - assert_eq!(encoded, "multi_tool_use-x00002E-parallel"); - let decoded = from_api_tool_name(&encoded); - assert_eq!(decoded, original); - } - - #[test] - fn tool_name_decode_mangled_dot_prefix() { - // Model replaces leading `-` with `.` in `-x00002E-` - let mangled = "multi_tool_use.x00002E-parallel"; - let decoded = from_api_tool_name(mangled); - assert_eq!(decoded, "multi_tool_use..parallel"); - } - - #[test] - fn tool_name_decode_bare_hex_no_trailing_dash() { - // Bare hex without trailing dash - let mangled = "foo_x00002Ebar"; - let decoded = from_api_tool_name(mangled); - assert_eq!(decoded, "foo_.bar"); - } - - #[test] - fn tool_name_bare_hex_preserves_alnum() { - // x000041 = 'A' — should NOT be decoded (alphanumeric) - let input = "foox000041bar"; - let decoded = from_api_tool_name(input); - assert_eq!(decoded, input); - } - - #[test] - fn tool_name_bare_hex_preserves_underscore() { - // x00005F = '_' — should NOT be decoded - let input = "foox00005Fbar"; - let decoded = from_api_tool_name(input); - assert_eq!(decoded, input); - } - - #[test] - fn tool_name_roundtrip_colon() { - let original = "mcp__server:tool_name"; - let encoded = to_api_tool_name(original); - let decoded = from_api_tool_name(&encoded); - assert_eq!(decoded, original); - } - - #[test] - fn api_url_handles_default_v1_and_beta_base_urls() { - assert_eq!( - api_url("https://api.deepseek.com", "chat/completions"), - "https://api.deepseek.com/v1/chat/completions" - ); - assert_eq!( - api_url("https://api.deepseek.com/v1", "chat/completions"), - "https://api.deepseek.com/v1/chat/completions" - ); - assert_eq!( - api_url("https://api.deepseek.com/beta", "chat/completions"), - "https://api.deepseek.com/beta/chat/completions" - ); - } - - #[test] - fn chat_messages_keep_reasoning_content_on_all_assistant_messages() { - let message = Message { - role: "assistant".to_string(), - content: vec![ - ContentBlock::Thinking { - thinking: "plan".to_string(), - }, - ContentBlock::Text { - text: "done".to_string(), - cache_control: None, - }, - ], - }; - let out = build_chat_messages(None, &[message], "deepseek-v4-pro"); - let assistant = out - .iter() - .find(|value| value.get("role").and_then(Value::as_str) == Some("assistant")) - .expect("assistant message"); - assert_eq!( - assistant.get("content").and_then(Value::as_str), - Some("done") - ); - assert_eq!( - assistant.get("reasoning_content").and_then(Value::as_str), - Some("plan"), - "thinking-mode models must keep reasoning_content on ALL assistant messages" - ); - } - - #[test] - fn chat_messages_keep_thinking_only_assistant_for_v4_flash() { - let message = Message { - role: "assistant".to_string(), - content: vec![ContentBlock::Thinking { - thinking: "plan".to_string(), - }], - }; - let out = build_chat_messages(None, &[message], "deepseek-v4-flash"); - let assistant = out - .iter() - .find(|value| value.get("role").and_then(Value::as_str) == Some("assistant")) - .expect("thinking-only assistant kept for V4 model"); - assert_eq!( - assistant.get("reasoning_content").and_then(Value::as_str), - Some("plan") - ); - } - - #[test] - fn chat_messages_keep_thinking_only_assistant_for_v4_pro() { - let message = Message { - role: "assistant".to_string(), - content: vec![ContentBlock::Thinking { - thinking: "plan".to_string(), - }], - }; - let out = build_chat_messages(None, &[message], "deepseek-v4-pro"); - let assistant = out - .iter() - .find(|value| value.get("role").and_then(Value::as_str) == Some("assistant")) - .expect("thinking-only assistant kept for V4 model"); - assert_eq!( - assistant.get("reasoning_content").and_then(Value::as_str), - Some("plan") - ); - } - - #[test] - fn chat_messages_keep_thinking_only_assistant_for_r_series_model() { - let message = Message { - role: "assistant".to_string(), - content: vec![ContentBlock::Thinking { - thinking: "plan".to_string(), - }], - }; - let out = build_chat_messages(None, &[message], "deepseek-r2-lite-preview"); - let assistant = out - .iter() - .find(|value| value.get("role").and_then(Value::as_str) == Some("assistant")) - .expect("thinking-only assistant kept for R-series model"); - assert_eq!( - assistant.get("reasoning_content").and_then(Value::as_str), - Some("plan") - ); - } - - #[test] - fn chat_messages_preserve_current_tool_round_reasoning_for_reasoner_model() { - let messages = vec![ - Message { - role: "user".to_string(), - content: vec![ContentBlock::Text { - text: "Need the date".to_string(), - cache_control: None, - }], - }, - Message { - role: "assistant".to_string(), - content: vec![ - ContentBlock::Thinking { - thinking: "Need to call a tool".to_string(), - }, - ContentBlock::ToolUse { - id: "tool-1".to_string(), - name: "get_date".to_string(), - input: json!({}), - caller: None, - }, - ], - }, - Message { - role: "user".to_string(), - content: vec![ContentBlock::ToolResult { - tool_use_id: "tool-1".to_string(), - content: "2026-04-23".to_string(), - is_error: None, - content_blocks: None, - }], - }, - ]; - let out = build_chat_messages(None, &messages, "deepseek-v4-pro"); - let assistant = out - .iter() - .find(|value| value.get("role").and_then(Value::as_str) == Some("assistant")) - .expect("assistant message"); - assert_eq!(assistant.get("content").and_then(Value::as_str), Some("")); - assert_eq!( - assistant.get("reasoning_content").and_then(Value::as_str), - Some("Need to call a tool") - ); - } - - #[test] - fn chat_messages_replay_prior_tool_round_reasoning_after_new_user_turn() { - let messages = vec![ - Message { - role: "user".to_string(), - content: vec![ContentBlock::Text { - text: "Need the date".to_string(), - cache_control: None, - }], - }, - Message { - role: "assistant".to_string(), - content: vec![ - ContentBlock::Thinking { - thinking: "Need to call a tool".to_string(), - }, - ContentBlock::ToolUse { - id: "tool-1".to_string(), - name: "get_date".to_string(), - input: json!({}), - caller: None, - }, - ], - }, - Message { - role: "user".to_string(), - content: vec![ContentBlock::ToolResult { - tool_use_id: "tool-1".to_string(), - content: "2026-04-23".to_string(), - is_error: None, - content_blocks: None, - }], - }, - Message { - role: "assistant".to_string(), - content: vec![ContentBlock::Text { - text: "It is 2026-04-23.".to_string(), - cache_control: None, - }], - }, - Message { - role: "user".to_string(), - content: vec![ContentBlock::Text { - text: "Thanks. Next question.".to_string(), - cache_control: None, - }], - }, - ]; - let out = build_chat_messages(None, &messages, "deepseek-v4-pro"); - let tool_assistant = out - .iter() - .find(|value| { - value.get("role").and_then(Value::as_str) == Some("assistant") - && value.get("tool_calls").is_some() - }) - .expect("tool-call assistant message"); - assert_eq!( - tool_assistant - .get("reasoning_content") - .and_then(Value::as_str), - Some("Need to call a tool"), - "DeepSeek thinking mode requires reasoning_content to be replayed for tool-call rounds across all subsequent user turns" - ); - } - - #[test] - fn chat_messages_replay_completed_tool_round_reasoning_after_final_answer() { - let messages = vec![ - Message { - role: "user".to_string(), - content: vec![ContentBlock::Text { - text: "Need the date".to_string(), - cache_control: None, - }], - }, - Message { - role: "assistant".to_string(), - content: vec![ - ContentBlock::Thinking { - thinking: "Need to call a tool".to_string(), - }, - ContentBlock::ToolUse { - id: "tool-1".to_string(), - name: "get_date".to_string(), - input: json!({}), - caller: None, - }, - ], - }, - Message { - role: "user".to_string(), - content: vec![ContentBlock::ToolResult { - tool_use_id: "tool-1".to_string(), - content: "2026-04-23".to_string(), - is_error: None, - content_blocks: None, - }], - }, - Message { - role: "assistant".to_string(), - content: vec![ContentBlock::Text { - text: "It is 2026-04-23.".to_string(), - cache_control: None, - }], - }, - ]; - let out = build_chat_messages(None, &messages, "deepseek-v4-pro"); - let tool_assistant = out - .iter() - .find(|value| { - value.get("role").and_then(Value::as_str) == Some("assistant") - && value.get("tool_calls").is_some() - }) - .expect("tool-call assistant message"); - assert_eq!( - tool_assistant - .get("reasoning_content") - .and_then(Value::as_str), - Some("Need to call a tool") - ); - let final_assistant = out - .iter() - .rfind(|value| value.get("role").and_then(Value::as_str) == Some("assistant")) - .expect("final assistant message"); - assert!( - final_assistant - .get("reasoning_content") - .and_then(Value::as_str) - .is_some_and(|s| !s.trim().is_empty()), - "all assistant messages must carry reasoning_content in thinking mode" - ); - } - - #[test] - fn chat_messages_replay_v4_tool_round_reasoning_after_new_user_turn() { - let messages = vec![ - Message { - role: "user".to_string(), - content: vec![ContentBlock::Text { - text: "Use a tool".to_string(), - cache_control: None, - }], - }, - Message { - role: "assistant".to_string(), - content: vec![ - ContentBlock::Thinking { - thinking: "Need a tool for this".to_string(), - }, - ContentBlock::ToolUse { - id: "call-1".to_string(), - name: "read_file".to_string(), - input: json!({"path": "Cargo.toml"}), - caller: None, - }, - ], - }, - Message { - role: "user".to_string(), - content: vec![ContentBlock::ToolResult { - tool_use_id: "call-1".to_string(), - content: "workspace manifest".to_string(), - is_error: None, - content_blocks: None, - }], - }, - Message { - role: "assistant".to_string(), - content: vec![ContentBlock::Text { - text: "Read it.".to_string(), - cache_control: None, - }], - }, - Message { - role: "user".to_string(), - content: vec![ContentBlock::Text { - text: "Now continue.".to_string(), - cache_control: None, - }], - }, - ]; - - let out = build_chat_messages(None, &messages, "deepseek-v4-pro"); - let tool_assistant = out - .iter() - .find(|value| { - value.get("role").and_then(Value::as_str) == Some("assistant") - && value.get("tool_calls").is_some() - }) - .expect("tool-call assistant message"); - assert_eq!( - tool_assistant - .get("reasoning_content") - .and_then(Value::as_str), - Some("Need a tool for this") - ); - } - - #[test] - fn chat_messages_substitute_placeholder_when_v4_tool_round_missing_reasoning() { - let messages = vec![ - Message { - role: "user".to_string(), - content: vec![ContentBlock::Text { - text: "Use a tool".to_string(), - cache_control: None, - }], - }, - Message { - role: "assistant".to_string(), - content: vec![ContentBlock::ToolUse { - id: "call-without-reasoning".to_string(), - name: "read_file".to_string(), - input: json!({"path": "Cargo.toml"}), - caller: None, - }], - }, - Message { - role: "user".to_string(), - content: vec![ContentBlock::ToolResult { - tool_use_id: "call-without-reasoning".to_string(), - content: "workspace manifest".to_string(), - is_error: None, - content_blocks: None, - }], - }, - ]; - - let out = build_chat_messages(None, &messages, "deepseek-v4-pro"); - - let assistant = out - .iter() - .find(|value| { - value.get("role").and_then(Value::as_str) == Some("assistant") - && value.get("tool_calls").is_some() - }) - .expect("tool-call assistant message should be retained with placeholder"); - assert!( - assistant - .get("reasoning_content") - .and_then(Value::as_str) - .is_some_and(|value| !value.trim().is_empty()), - "missing reasoning_content should be substituted with a non-empty placeholder so the API accepts the request" - ); - assert!( - out.iter() - .any(|value| value.get("role").and_then(Value::as_str) == Some("tool")), - "matching tool_result must remain so the conversation chain stays intact" - ); - } - - #[test] - fn chat_messages_allow_tool_round_without_reasoning_when_thinking_disabled() { - let request = MessageRequest { - model: "deepseek-v4-pro".to_string(), - messages: vec![ - Message { - role: "assistant".to_string(), - content: vec![ContentBlock::ToolUse { - id: "call-no-thinking".to_string(), - name: "read_file".to_string(), - input: json!({"path": "Cargo.toml"}), - caller: None, - }], - }, - Message { - role: "user".to_string(), - content: vec![ContentBlock::ToolResult { - tool_use_id: "call-no-thinking".to_string(), - content: "workspace manifest".to_string(), - is_error: None, - content_blocks: None, - }], - }, - ], - max_tokens: 1024, - system: None, - tools: None, - tool_choice: None, - metadata: None, - thinking: None, - reasoning_effort: Some("off".to_string()), - stream: None, - temperature: None, - top_p: None, - }; - - let out = build_chat_messages_for_request(&request); - assert!( - out.iter().any( - |value| value.get("role").and_then(Value::as_str) == Some("assistant") - && value.get("tool_calls").is_some() - ), "tool calls remain valid when thinking mode is disabled" ); assert!( diff --git a/crates/tui/src/client/responses.rs b/crates/tui/src/client/responses.rs deleted file mode 100644 index 7a5a0169..00000000 --- a/crates/tui/src/client/responses.rs +++ /dev/null @@ -1,406 +0,0 @@ -//! Responses API helpers for the experimental DeepSeek endpoint. -//! -//! Gated behind `DEEPSEEK_EXPERIMENTAL_RESPONSES_API`. Normal traffic uses -//! chat completions via `crate::client::chat`. - -use anyhow::{Context, Result}; -use serde_json::{Value, json}; - -use crate::models::{ContentBlock, Message, MessageRequest, MessageResponse, Tool, ToolCaller}; - -use super::{ - DeepSeekClient, ERROR_BODY_MAX_BYTES, api_url, apply_reasoning_effort, bounded_error_text, - from_api_tool_name, parse_usage, system_to_instructions, to_api_tool_name, -}; - -#[derive(Debug)] -pub(super) struct ResponsesFallback { - pub(super) status: u16, - pub(super) body: String, -} - -impl DeepSeekClient { - pub(super) async fn create_message_responses( - &self, - request: &MessageRequest, - ) -> Result> { - let mut body = json!({ - "model": request.model, - "input": build_responses_input(&request.messages), - "store": false, - "max_output_tokens": request.max_tokens, - }); - - if let Some(instructions) = system_to_instructions(request.system.clone()) { - body["instructions"] = json!(instructions); - } - if let Some(temperature) = request.temperature { - body["temperature"] = json!(temperature); - } - if let Some(top_p) = request.top_p { - body["top_p"] = json!(top_p); - } - if let Some(tools) = request.tools.as_ref() { - body["tools"] = json!(tools.iter().map(tool_to_responses).collect::>()); - } - if let Some(choice) = request.tool_choice.as_ref() { - body["tool_choice"] = choice.clone(); - } - apply_reasoning_effort( - &mut body, - request.reasoning_effort.as_deref(), - self.api_provider, - ); - - let url = api_url(&self.base_url, "responses"); - let response = self - .send_with_retry(|| self.http_client.post(&url).json(&body)) - .await?; - - let status = response.status(); - - if status.as_u16() == 404 || status.as_u16() == 405 { - let body = bounded_error_text(response, ERROR_BODY_MAX_BYTES).await; - return Ok(Err(ResponsesFallback { - status: status.as_u16(), - body, - })); - } - - if !status.is_success() { - let error_text = bounded_error_text(response, ERROR_BODY_MAX_BYTES).await; - anyhow::bail!("Failed to call DeepSeek Responses API: HTTP {status}: {error_text}"); - } - - let response_text = response.text().await.unwrap_or_default(); - let value: Value = - serde_json::from_str(&response_text).context("Failed to parse Responses API JSON")?; - let message = parse_responses_message(&value)?; - Ok(Ok(message)) - } -} - -fn build_responses_input(messages: &[Message]) -> Vec { - let mut items = Vec::new(); - - for message in messages { - let role = message.role.as_str(); - let text_type = if role == "user" { - "input_text" - } else { - "output_text" - }; - - for block in &message.content { - match block { - ContentBlock::Text { text, .. } => { - items.push(json!({ - "type": "message", - "role": role, - "content": [{ - "type": text_type, - "text": text, - }] - })); - } - ContentBlock::ToolUse { - id, - name, - input, - caller, - } => { - let args = serde_json::to_string(input).unwrap_or_else(|_| input.to_string()); - let mut item = json!({ - "type": "function_call", - "call_id": id, - "name": to_api_tool_name(name), - "arguments": args, - }); - if let Some(caller) = caller { - item["caller"] = json!({ - "type": caller.caller_type, - "tool_id": caller.tool_id, - }); - } - items.push(item); - } - ContentBlock::ToolResult { - tool_use_id, - content, - is_error, - .. - } => { - let mut item = json!({ - "type": "function_call_output", - "call_id": tool_use_id, - "output": content, - }); - if let Some(is_error) = is_error { - item["is_error"] = json!(is_error); - } - items.push(item); - } - ContentBlock::Thinking { .. } => {} - ContentBlock::ServerToolUse { id, name, input } => { - items.push(json!({ - "type": "server_tool_use", - "id": id, - "name": name, - "input": input, - })); - } - ContentBlock::ToolSearchToolResult { - tool_use_id, - content, - } => { - items.push(json!({ - "type": "tool_search_tool_result", - "tool_use_id": tool_use_id, - "content": content, - })); - } - ContentBlock::CodeExecutionToolResult { - tool_use_id, - content, - } => { - items.push(json!({ - "type": "code_execution_tool_result", - "tool_use_id": tool_use_id, - "content": content, - })); - } - } - } - } - - items -} - -fn tool_to_responses(tool: &Tool) -> Value { - let tool_type = tool.tool_type.as_deref().unwrap_or("function"); - let mut value = if tool_type == "function" { - json!({ - "type": "function", - "name": to_api_tool_name(&tool.name), - "description": tool.description, - "parameters": tool.input_schema, - }) - } else if tool_type == "code_execution_20250825" { - json!({ - "type": tool_type, - "name": to_api_tool_name(&tool.name), - }) - } else { - json!({ - "type": tool_type, - "name": to_api_tool_name(&tool.name), - "description": tool.description, - "input_schema": tool.input_schema, - }) - }; - - if let Some(allowed_callers) = &tool.allowed_callers { - value["allowed_callers"] = json!(allowed_callers); - } - if let Some(defer_loading) = tool.defer_loading { - value["defer_loading"] = json!(defer_loading); - } - if let Some(input_examples) = &tool.input_examples { - value["input_examples"] = json!(input_examples); - } - if let Some(strict) = tool.strict { - value["strict"] = json!(strict); - } - value -} - -fn parse_responses_message(payload: &Value) -> Result { - let id = payload - .get("id") - .and_then(Value::as_str) - .unwrap_or("response") - .to_string(); - let model = payload - .get("model") - .and_then(Value::as_str) - .unwrap_or("unknown") - .to_string(); - - let usage = parse_usage(payload.get("usage")); - let mut content = Vec::new(); - - if let Some(output) = payload.get("output").and_then(Value::as_array) { - for item in output { - let item_type = item.get("type").and_then(Value::as_str).unwrap_or(""); - match item_type { - "message" => { - if let Some(role) = item.get("role").and_then(Value::as_str) - && role != "assistant" - { - continue; - } - if let Some(content_items) = item.get("content").and_then(Value::as_array) { - for content_item in content_items { - let content_type = content_item - .get("type") - .and_then(Value::as_str) - .unwrap_or("output_text"); - if content_type != "output_text" && content_type != "text" { - continue; - } - if let Some(text) = content_item.get("text").and_then(Value::as_str) - && !text.trim().is_empty() - { - content.push(ContentBlock::Text { - text: text.to_string(), - cache_control: None, - }); - } - } - } - } - "function_call" => { - let call_id = item - .get("call_id") - .or_else(|| item.get("id")) - .and_then(Value::as_str) - .unwrap_or("tool_call") - .to_string(); - let name = item - .get("name") - .and_then(Value::as_str) - .unwrap_or("tool") - .to_string(); - let input = match item.get("arguments") { - Some(Value::String(raw)) => { - serde_json::from_str(raw).unwrap_or_else(|_| Value::String(raw.clone())) - } - Some(other) => other.clone(), - None => Value::Null, - }; - let caller = item.get("caller").and_then(|v| { - v.get("type") - .and_then(Value::as_str) - .map(|caller_type| ToolCaller { - caller_type: caller_type.to_string(), - tool_id: v - .get("tool_id") - .and_then(Value::as_str) - .map(std::string::ToString::to_string), - }) - }); - content.push(ContentBlock::ToolUse { - id: call_id, - name: from_api_tool_name(&name), - input, - caller, - }); - } - "function_call_output" => { - let tool_use_id = item - .get("call_id") - .or_else(|| item.get("tool_use_id")) - .and_then(Value::as_str) - .unwrap_or("tool_call") - .to_string(); - let content_text = item - .get("output") - .or_else(|| item.get("content")) - .map(|v| { - if let Some(s) = v.as_str() { - s.to_string() - } else { - v.to_string() - } - }) - .unwrap_or_default(); - let is_error = item.get("is_error").and_then(Value::as_bool); - content.push(ContentBlock::ToolResult { - tool_use_id, - content: content_text, - is_error, - content_blocks: None, - }); - } - "server_tool_use" => { - let id = item - .get("id") - .and_then(Value::as_str) - .unwrap_or("server_tool") - .to_string(); - let name = item - .get("name") - .and_then(Value::as_str) - .unwrap_or("server_tool") - .to_string(); - let input = item.get("input").cloned().unwrap_or(Value::Null); - content.push(ContentBlock::ServerToolUse { id, name, input }); - } - "tool_search_tool_result" => { - let tool_use_id = item - .get("tool_use_id") - .and_then(Value::as_str) - .unwrap_or("tool_search") - .to_string(); - let content_value = item.get("content").cloned().unwrap_or(Value::Null); - content.push(ContentBlock::ToolSearchToolResult { - tool_use_id, - content: content_value, - }); - } - "code_execution_tool_result" => { - let tool_use_id = item - .get("tool_use_id") - .and_then(Value::as_str) - .unwrap_or("code_execution") - .to_string(); - let content_value = item.get("content").cloned().unwrap_or(Value::Null); - content.push(ContentBlock::CodeExecutionToolResult { - tool_use_id, - content: content_value, - }); - } - "reasoning" => { - if let Some(summary) = item.get("summary").and_then(Value::as_array) { - let summary_text = summary - .iter() - .filter_map(|s| s.get("text").and_then(Value::as_str)) - .collect::>() - .join("\n"); - if !summary_text.trim().is_empty() { - content.push(ContentBlock::Thinking { - thinking: summary_text, - }); - } - } - } - _ => {} - } - } - } - - if content.is_empty() - && let Some(text) = payload.get("output_text").and_then(Value::as_str) - && !text.trim().is_empty() - { - content.push(ContentBlock::Text { - text: text.to_string(), - cache_control: None, - }); - } - - Ok(MessageResponse { - id, - r#type: "message".to_string(), - role: "assistant".to_string(), - content, - model, - stop_reason: None, - stop_sequence: None, - container: payload - .get("container") - .cloned() - .and_then(|v| serde_json::from_value(v).ok()), - usage, - }) -} diff --git a/crates/tui/src/config.rs b/crates/tui/src/config.rs index 4dd7a202..2b11ab0d 100644 --- a/crates/tui/src/config.rs +++ b/crates/tui/src/config.rs @@ -144,8 +144,6 @@ pub struct ProviderCapability { pub enum RequestPayloadMode { /// Standard OpenAI-compatible `/v1/chat/completions` payload. ChatCompletions, - /// Anthropic-style Responses API (DeepSeek experimental). - ResponsesApi, } diff --git a/crates/tui/src/main.rs b/crates/tui/src/main.rs index d7f48d21..612c1473 100644 --- a/crates/tui/src/main.rs +++ b/crates/tui/src/main.rs @@ -48,7 +48,6 @@ mod project_context; mod project_doc; mod prompts; pub mod repl; -mod responses_api_proxy; mod retry_status; pub mod rlm; mod runtime_api; @@ -243,9 +242,6 @@ enum Commands { #[arg(long = "last", default_value_t = false, conflicts_with = "session_id")] last: bool, }, - /// Internal: run the responses API proxy. - #[command(hide = true)] - ResponsesApiProxy(responses_api_proxy::Args), } #[derive(Args, Debug, Clone)] @@ -726,10 +722,6 @@ async fn main() -> Result<()> { let new_session_id = fork_session(session_id, last)?; run_interactive(&cli, &config, Some(new_session_id), None).await } - Commands::ResponsesApiProxy(args) => { - responses_api_proxy::run_main(args)?; - Ok(()) - } }; } diff --git a/crates/tui/src/responses_api_proxy/mod.rs b/crates/tui/src/responses_api_proxy/mod.rs deleted file mode 100644 index abeae463..00000000 --- a/crates/tui/src/responses_api_proxy/mod.rs +++ /dev/null @@ -1,226 +0,0 @@ -use std::fs::{self, File}; -use std::io::Write; -use std::net::{SocketAddr, TcpListener}; -use std::path::{Path, PathBuf}; -use std::sync::Arc; -use std::time::Duration; - -use anyhow::{Context, Result, anyhow}; -use clap::Parser; -use reqwest::Url; -use reqwest::blocking::Client; -use reqwest::header::{AUTHORIZATION, HOST, HeaderMap, HeaderName, HeaderValue}; -use serde::Serialize; -use tiny_http::{Header, Method, Request, Response, Server, StatusCode}; - -mod read_api_key; -use read_api_key::read_auth_header_from_stdin; - -/// CLI arguments for the proxy. -#[derive(Debug, Clone, Parser)] -#[command( - name = "responses-api-proxy", - about = "Minimal DeepSeek responses proxy" -)] -pub struct Args { - /// Port to listen on. If not set, an ephemeral port is used. - #[arg(long)] - pub port: Option, - - /// Path to a JSON file to write startup info (single line). Includes `{"port": 12345}`. - #[arg(long, value_name = "FILE")] - pub server_info: Option, - - /// Enable HTTP shutdown endpoint at GET /shutdown - #[arg(long)] - pub http_shutdown: bool, - - /// Absolute URL the proxy should forward requests to. - #[arg(long, default_value = "https://api.deepseek.com/v1/responses")] - pub upstream_url: String, -} - -#[derive(Serialize)] -struct ServerInfo { - port: u16, - pid: u32, -} - -struct ForwardConfig { - upstream_url: Url, - host_header: HeaderValue, -} - -/// Entry point for the proxy server. -pub fn run_main(args: Args) -> Result<()> { - let auth_header = read_auth_header_from_stdin()?; - - let upstream_url = Url::parse(&args.upstream_url).context("parsing --upstream-url")?; - let host = match (upstream_url.host_str(), upstream_url.port()) { - (Some(host), Some(port)) => format!("{host}:{port}"), - (Some(host), None) => host.to_string(), - _ => return Err(anyhow!("upstream URL must include a host")), - }; - let host_header = - HeaderValue::from_str(&host).context("constructing Host header from upstream URL")?; - - let forward_config = Arc::new(ForwardConfig { - upstream_url, - host_header, - }); - - let (listener, bound_addr) = bind_listener(args.port)?; - if let Some(path) = args.server_info.as_ref() { - write_server_info(path, bound_addr.port())?; - } - let server = Server::from_listener(listener, None) - .map_err(|err| anyhow!("creating HTTP server: {err}"))?; - let client = Arc::new( - Client::builder() - // Disable reqwest's 30s default so long-lived response streams keep flowing. - .timeout(None::) - .build() - .context("building reqwest client")?, - ); - - eprintln!("responses-api-proxy listening on {bound_addr}"); - - let http_shutdown = args.http_shutdown; - for request in server.incoming_requests() { - let client = client.clone(); - let forward_config = forward_config.clone(); - std::thread::spawn(move || { - if http_shutdown && request.method() == &Method::Get && request.url() == "/shutdown" { - let _ = request.respond(Response::new_empty(StatusCode(200))); - std::process::exit(0); - } - - if let Err(e) = forward_request(&client, auth_header, &forward_config, request) { - eprintln!("forwarding error: {e}"); - } - }); - } - - Err(anyhow!("server stopped unexpectedly")) -} - -fn bind_listener(port: Option) -> Result<(TcpListener, SocketAddr)> { - let addr = SocketAddr::from(([127, 0, 0, 1], port.unwrap_or(0))); - let listener = TcpListener::bind(addr).with_context(|| format!("failed to bind {addr}"))?; - let bound = listener.local_addr().context("failed to read local_addr")?; - Ok((listener, bound)) -} - -fn write_server_info(path: &Path, port: u16) -> Result<()> { - if let Some(parent) = path.parent() - && !parent.as_os_str().is_empty() - { - fs::create_dir_all(parent)?; - } - - let info = ServerInfo { - port, - pid: std::process::id(), - }; - let mut data = serde_json::to_string(&info)?; - data.push('\n'); - let mut f = File::create(path)?; - f.write_all(data.as_bytes())?; - Ok(()) -} - -fn forward_request( - client: &Client, - auth_header: &'static str, - config: &ForwardConfig, - mut req: Request, -) -> Result<()> { - // Only allow POST /v1/responses exactly, no query string. - let method = req.method().clone(); - let url_path = req.url().to_string(); - let allow = method == Method::Post && url_path == "/v1/responses"; - - if !allow { - let resp = Response::new_empty(StatusCode(403)); - let _ = req.respond(resp); - return Ok(()); - } - - // Read request body - let mut body = Vec::new(); - let mut reader = req.as_reader(); - std::io::Read::read_to_end(&mut reader, &mut body)?; - - // Build headers for upstream, forwarding everything from the incoming - // request except Authorization (we replace it below). - let mut headers = HeaderMap::new(); - for header in req.headers() { - let name_ascii = header.field.as_str(); - let lower = name_ascii.to_ascii_lowercase(); - if lower.as_str() == "authorization" || lower.as_str() == "host" { - continue; - } - - let header_name = match HeaderName::from_bytes(lower.as_bytes()) { - Ok(name) => name, - Err(_) => continue, - }; - if let Ok(value) = HeaderValue::from_bytes(header.value.as_bytes()) { - headers.append(header_name, value); - } - } - - // As part of our effort to keep `auth_header` secret, we use a - // combination of `from_static()` and `set_sensitive(true)`. - let mut auth_header_value = HeaderValue::from_static(auth_header); - auth_header_value.set_sensitive(true); - headers.insert(AUTHORIZATION, auth_header_value); - - headers.insert(HOST, config.host_header.clone()); - - let upstream_resp = client - .post(config.upstream_url.clone()) - .headers(headers) - .body(body) - .send() - .context("forwarding request to upstream")?; - - // We have to create an adapter between a `reqwest::blocking::Response` - // and a `tiny_http::Response`. Fortunately, `reqwest::blocking::Response` - // implements `Read`, so we can use it directly as the body of the - // `tiny_http::Response`. - let status = upstream_resp.status(); - let mut response_headers = Vec::new(); - for (name, value) in upstream_resp.headers().iter() { - // Skip headers that tiny_http manages itself. - if matches!( - name.as_str(), - "content-length" | "transfer-encoding" | "connection" | "trailer" | "upgrade" - ) { - continue; - } - - if let Ok(header) = Header::from_bytes(name.as_str().as_bytes(), value.as_bytes()) { - response_headers.push(header); - } - } - - let content_length = upstream_resp.content_length().and_then(|len| { - if len <= usize::MAX as u64 { - Some(len as usize) - } else { - None - } - }); - - let response = Response::new( - StatusCode(status.as_u16()), - response_headers, - upstream_resp, - content_length, - None, - ); - - let _ = req.respond(response); - Ok(()) -} diff --git a/crates/tui/src/responses_api_proxy/read_api_key.rs b/crates/tui/src/responses_api_proxy/read_api_key.rs deleted file mode 100644 index 2684241c..00000000 --- a/crates/tui/src/responses_api_proxy/read_api_key.rs +++ /dev/null @@ -1,217 +0,0 @@ -use anyhow::{Context, Result, anyhow}; -use zeroize::Zeroize; - -/// Use a generous buffer size to avoid truncation and to allow for longer API -/// keys in the future. -const BUFFER_SIZE: usize = 1024; -const AUTH_HEADER_PREFIX: &[u8] = b"Bearer "; - -/// Reads the auth token from stdin and returns a static `Authorization` header -/// value with the auth token used with `Bearer`. The header value is returned -/// as a `&'static str` whose bytes are locked in memory to avoid accidental -/// exposure. -#[cfg(unix)] -pub(crate) fn read_auth_header_from_stdin() -> Result<&'static str> { - read_auth_header_with(read_from_unix_stdin) -} - -#[cfg(windows)] -pub(crate) fn read_auth_header_from_stdin() -> Result<&'static str> { - use std::io::Read; - - // Use of `stdio::io::stdin()` has the problem mentioned in the docstring on - // the UNIX version of `read_from_unix_stdin()`, so this should ultimately - // be replaced the low-level Windows equivalent. Because we do not have an - // equivalent of mlock() on Windows right now, it is not pressing until we - // address that issue. - read_auth_header_with(|buffer| std::io::stdin().read(buffer)) -} - -/// We perform a low-level read with `read(2)` because `stdio::io::stdin()` has -/// an internal BufReader: -/// -/// -/// -/// that can end up retaining a copy of stdin data in memory with no way to zero -/// it out, whereas we aim to guarantee there is exactly one copy of the API key -/// in memory, protected by mlock(2). -#[cfg(unix)] -fn read_from_unix_stdin(buffer: &mut [u8]) -> std::io::Result { - use libc::c_void; - use libc::read; - - // Perform a single read(2) call into the provided buffer slice. - // Looping and newline/EOF handling are managed by the caller. - loop { - let result = unsafe { - read( - libc::STDIN_FILENO, - buffer.as_mut_ptr().cast::(), - buffer.len(), - ) - }; - - if result == 0 { - return Ok(0); - } - - if result < 0 { - let err = std::io::Error::last_os_error(); - if err.kind() == std::io::ErrorKind::Interrupted { - continue; - } - return Err(err); - } - - return Ok(result as usize); - } -} - -fn read_auth_header_with(mut read_fn: F) -> Result<&'static str> -where - F: FnMut(&mut [u8]) -> std::io::Result, -{ - // TAKE CARE WHEN MODIFYING THIS CODE!!! - // - // This function goes to great lengths to avoid leaving the API key in - // memory longer than necessary and to avoid copying it around. We read - // directly into a stack buffer so the only heap allocation should be the - // one to create the String (with the exact size) for the header value, - // which we then immediately protect with mlock(2). - let mut buf = [0u8; BUFFER_SIZE]; - buf[..AUTH_HEADER_PREFIX.len()].copy_from_slice(AUTH_HEADER_PREFIX); - - let prefix_len = AUTH_HEADER_PREFIX.len(); - let capacity = buf.len() - prefix_len; - let mut total_read = 0usize; // number of bytes read into the token region - let mut saw_newline = false; - let mut saw_eof = false; - - while total_read < capacity { - let slice = &mut buf[prefix_len + total_read..]; - let read = match read_fn(slice) { - Ok(n) => n, - Err(err) => { - buf.zeroize(); - return Err(err.into()); - } - }; - - if read == 0 { - saw_eof = true; - break; - } - - // Search only the newly written region for a newline. - let newly_written = &slice[..read]; - if let Some(pos) = newly_written.iter().position(|&b| b == b'\n') { - total_read += pos + 1; // include the newline for trimming below - saw_newline = true; - break; - } - - total_read += read; - - // Continue loop; if buffer fills without newline/EOF we'll error below. - } - - // If buffer filled and we did not see newline or EOF, error out. - if total_read == capacity && !saw_newline && !saw_eof { - buf.zeroize(); - return Err(anyhow!( - "API key is too large to fit in the {BUFFER_SIZE}-byte buffer" - )); - } - - let mut total = prefix_len + total_read; - while total > prefix_len && (buf[total - 1] == b'\n' || buf[total - 1] == b'\r') { - total -= 1; - } - - if total == AUTH_HEADER_PREFIX.len() { - buf.zeroize(); - return Err(anyhow!( - "API key must be provided via stdin (e.g. printenv DEEPSEEK_API_KEY | deepseek responses-api-proxy)" - )); - } - - if let Err(err) = validate_auth_header_bytes(&buf[AUTH_HEADER_PREFIX.len()..total]) { - buf.zeroize(); - return Err(err); - } - - let header_str = match std::str::from_utf8(&buf[..total]) { - Ok(value) => value, - Err(err) => { - // In theory, validate_auth_header_bytes() should have caught - // any invalid UTF-8 sequences, but just in case... - buf.zeroize(); - return Err(err).context("reading Authorization header from stdin as UTF-8"); - } - }; - - let header_value = String::from(header_str); - buf.zeroize(); - - let leaked: &'static mut str = header_value.leak(); - mlock_str(leaked); - - Ok(leaked) -} - -#[cfg(unix)] -fn mlock_str(value: &str) { - use libc::_SC_PAGESIZE; - use libc::c_void; - use libc::mlock; - use libc::sysconf; - - if value.is_empty() { - return; - } - - let page_size = unsafe { sysconf(_SC_PAGESIZE) }; - if page_size <= 0 { - return; - } - let page_size = page_size as usize; - if page_size == 0 { - return; - } - - let addr = value.as_ptr() as usize; - let len = value.len(); - let start = addr & !(page_size - 1); - let addr_end = match addr.checked_add(len) { - Some(v) => match v.checked_add(page_size - 1) { - Some(total) => total, - None => return, - }, - None => return, - }; - let end = addr_end & !(page_size - 1); - let size = end.saturating_sub(start); - if size == 0 { - return; - } - - let _ = unsafe { mlock(start as *const c_void, size) }; -} - -#[cfg(not(unix))] -fn mlock_str(_value: &str) {} - -/// The key should match /^[A-Za-z0-9\-_]+$/. Ensure there is no funny business -/// with NUL characters and whatnot. -fn validate_auth_header_bytes(key_bytes: &[u8]) -> Result<()> { - if key_bytes - .iter() - .all(|byte| byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_')) - { - return Ok(()); - } - - Err(anyhow!( - "API key may only contain ASCII letters, numbers, '-' or '_'" - )) -}