Files
codewhale/crates/tui/src/mcp.rs
T

5878 lines
213 KiB
Rust

//! Async MCP (Model Context Protocol) Implementation
//!
//! This module provides full async support for MCP servers with:
//! - Connection pooling for server reuse
//! - Automatic tool discovery via `tools/list`
//! - Configurable timeouts per-server and globally
use std::collections::{HashMap, VecDeque};
use std::fs;
use std::path::{Component, Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use anyhow::{Context, Result};
use reqwest::StatusCode;
use reqwest::header::{ACCEPT, CONTENT_TYPE};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
use tokio::process::{Child, ChildStdin, ChildStdout};
use tokio::sync::Mutex as TokioMutex;
use crate::child_env;
use crate::network_policy::{Decision, NetworkPolicyDecider, host_from_url};
use crate::utils::write_atomic;
// === Error diagnostics helpers (#71) ===
/// Bytes of a non-2xx response body to surface in connection errors.
const ERROR_BODY_PREVIEW_BYTES: usize = 200;
const MCP_HTTP_ACCEPT: &str = "application/json, text/event-stream";
fn with_default_mcp_http_headers(
request: reqwest::RequestBuilder,
json_body: bool,
) -> reqwest::RequestBuilder {
let request = request.header(ACCEPT, MCP_HTTP_ACCEPT);
if json_body {
request.header(CONTENT_TYPE, "application/json")
} else {
request
}
}
fn validate_mcp_config_path(path: &Path) -> Result<()> {
if path.as_os_str().is_empty() {
anyhow::bail!("MCP config path cannot be empty");
}
if path
.components()
.any(|component| matches!(component, Component::ParentDir))
{
anyhow::bail!("MCP config path cannot contain '..' components");
}
Ok(())
}
/// Predicate for [`StreamableHttpTransport::send`]'s custom-header pass.
///
/// We accept whatever reqwest's `HeaderName::try_from` /
/// `HeaderValue::try_from` would accept, but with three extra rules:
///
/// 1. Reject empty / whitespace-only keys — these would surface as a
/// request-builder error mid-send and abort the whole connection.
/// 2. Reject keys that duplicate the framing we already emit
/// (`Accept`, `Content-Type`). The MCP Streamable HTTP transport
/// relies on those exact values for protocol negotiation; a stray
/// user override could silently break tool discovery.
/// 3. Reject values containing ASCII CR or LF. reqwest already
/// rejects those, but the explicit check makes the failure path
/// visible (a `tracing::warn!` instead of an obscure
/// builder error) and documents the response-splitting
/// defense.
///
/// Returning `false` means "skip this header"; the rest of the
/// request still goes out.
fn is_safe_custom_header(key: &str, value: &str) -> bool {
let trimmed = key.trim();
if trimmed.is_empty() {
return false;
}
if trimmed.eq_ignore_ascii_case("accept") || trimmed.eq_ignore_ascii_case("content-type") {
return false;
}
!value.contains('\r') && !value.contains('\n')
}
fn apply_safe_custom_headers(
mut request: reqwest::RequestBuilder,
headers: &HashMap<String, String>,
) -> reqwest::RequestBuilder {
for (key, value) in headers {
if !is_safe_custom_header(key, value) {
tracing::warn!(
target: "mcp",
"skipping unsafe MCP header {:?} (empty/control-char/reserved)",
key
);
continue;
}
request = request.header(key.as_str(), value.as_str());
}
request
}
/// Mask a URL so any embedded credentials in the userinfo portion (e.g.
/// `https://user:secret@host`) are replaced with `***`. Failures fall back to
/// the original string so we don't lose context — we never want masking to
/// produce an empty error.
fn mask_url_secrets(url: &str) -> String {
if let Ok(parsed) = reqwest::Url::parse(url) {
let mut clone = parsed.clone();
if !parsed.username().is_empty() || parsed.password().is_some() {
let _ = clone.set_username("***");
let _ = clone.set_password(Some("***"));
}
return clone.to_string();
}
url.to_string()
}
/// Redact the userinfo segment (`username[:password]@…` portion) from
/// a proxy URL so it can be safely included in `tracing::warn!` output
/// without leaking the
/// password into the on-disk log. URLs without userinfo are returned
/// unchanged. Garbage input (no `://` scheme separator) is also returned
/// unchanged — the malformed-URL warning path is the only caller, so an
/// unparseable input is already the failure case.
fn redact_proxy_userinfo(proxy_url: &str) -> String {
let Some(scheme_end) = proxy_url.find("://") else {
return proxy_url.to_string();
};
let after_scheme = scheme_end + 3;
// The userinfo segment ends at the next `@`, but only if that `@`
// comes before the next `/`, `?`, or `#` (otherwise the `@` is in a
// path / query and the URL has no userinfo at all).
let rest = &proxy_url[after_scheme..];
let at_idx = rest.find('@');
let path_idx = rest.find(['/', '?', '#']);
let userinfo_end = match (at_idx, path_idx) {
(Some(a), Some(p)) if a < p => Some(a),
(Some(a), None) => Some(a),
_ => None,
};
if let Some(end) = userinfo_end {
let mut out = String::with_capacity(proxy_url.len());
out.push_str(&proxy_url[..after_scheme]);
out.push_str("***@");
out.push_str(&rest[end + 1..]);
out
} else {
proxy_url.to_string()
}
}
/// Mask any obvious token-like substrings in a body excerpt before surfacing
/// it. Conservative: replaces `Bearer <token>` and `api_key=...` shapes.
fn redact_body_preview(body: &str) -> String {
let mut out = body.to_string();
if let Some(idx) = out.to_lowercase().find("bearer ") {
let tail_start = idx + "bearer ".len();
if tail_start < out.len() {
let end = out[tail_start..]
.find(|c: char| c.is_whitespace() || c == '"' || c == ',')
.map_or(out.len(), |off| tail_start + off);
out.replace_range(tail_start..end, "***");
}
}
for needle in ["api_key=", "apikey=", "api-key=", "token="] {
if let Some(idx) = out.to_lowercase().find(needle) {
let tail_start = idx + needle.len();
let end = out[tail_start..]
.find(|c: char| c.is_whitespace() || c == '&' || c == '"' || c == ',')
.map_or(out.len(), |off| tail_start + off);
out.replace_range(tail_start..end, "***");
}
}
out
}
/// Read up to `max_bytes` of a reqwest Response body and produce a single-line
/// excerpt suitable for an error message. Best-effort — if the body can't be
/// read, returns the literal string `<no body>`.
async fn bounded_body_excerpt(response: reqwest::Response, max_bytes: usize) -> String {
let body_text = response.text().await.unwrap_or_default();
if body_text.is_empty() {
return "<no body>".to_string();
}
let trimmed: String = body_text.chars().take(max_bytes).collect();
let suffix = if body_text.len() > trimmed.len() {
""
} else {
""
};
let one_line = trimmed.replace(['\n', '\r'], " ");
format!("{}{}", redact_body_preview(&one_line), suffix)
}
fn invalid_json_preview(bytes: &[u8]) -> String {
let body_text = String::from_utf8_lossy(bytes);
if body_text.is_empty() {
return "<empty>".to_string();
}
let trimmed: String = body_text.chars().take(ERROR_BODY_PREVIEW_BYTES).collect();
let suffix = if body_text.chars().count() > ERROR_BODY_PREVIEW_BYTES {
""
} else {
""
};
let one_line = trimmed.replace(['\n', '\r'], " ");
format!("{}{}", redact_body_preview(&one_line), suffix)
}
// === Configuration Types ===
/// Full MCP configuration from mcp.json
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct McpConfig {
#[serde(default)]
pub timeouts: McpTimeouts,
#[serde(default, alias = "mcpServers")]
pub servers: HashMap<String, McpServerConfig>,
}
/// Global timeout configuration
#[derive(Debug, Clone, Copy, Deserialize, Serialize)]
#[allow(clippy::struct_field_names)]
pub struct McpTimeouts {
#[serde(default = "default_connect_timeout")]
pub connect_timeout: u64,
#[serde(default = "default_execute_timeout")]
pub execute_timeout: u64,
#[serde(default = "default_read_timeout")]
pub read_timeout: u64,
}
fn default_connect_timeout() -> u64 {
10
}
fn default_execute_timeout() -> u64 {
60
}
fn default_read_timeout() -> u64 {
120
}
impl Default for McpTimeouts {
fn default() -> Self {
Self {
connect_timeout: default_connect_timeout(),
execute_timeout: default_execute_timeout(),
read_timeout: default_read_timeout(),
}
}
}
/// Configuration for a single MCP server
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpServerConfig {
pub command: Option<String>,
#[serde(default)]
pub args: Vec<String>,
#[serde(default)]
pub env: HashMap<String, String>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub cwd: Option<PathBuf>,
pub url: Option<String>,
/// Optional explicit HTTP transport override.
///
/// By default URL-based MCP servers use Streamable HTTP first and fall
/// back to legacy SSE only when the server rejects Streamable HTTP with
/// a known incompatible status. Set this to `"sse"` for legacy SSE
/// endpoints that must start with a long-lived GET endpoint discovery
/// stream and cannot accept an initial POST to the configured URL.
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub transport: Option<String>,
#[serde(default)]
pub connect_timeout: Option<u64>,
#[serde(default)]
pub execute_timeout: Option<u64>,
#[serde(default)]
pub read_timeout: Option<u64>,
#[serde(default)]
pub disabled: bool,
#[serde(default = "default_enabled")]
pub enabled: bool,
#[serde(default)]
pub required: bool,
#[serde(default)]
pub enabled_tools: Vec<String>,
#[serde(default)]
pub disabled_tools: Vec<String>,
/// Extra HTTP headers sent with every request to this MCP server.
/// Only the HTTP transports (streamable HTTP today; SSE in a
/// follow-up) honor this — `command`-based stdio servers ignore it.
///
/// Mirrors the `headers` field that Claude Code, Codex, and
/// OpenCode already accept in their MCP config formats. Use it to
/// authenticate against gateways that require a Bearer token or
/// API key, e.g.:
///
/// ```jsonc
/// "huggingface": {
/// "url": "https://huggingface.co/api/mcp",
/// "headers": { "Authorization": "Bearer ${HF_TOKEN}" }
/// }
/// ```
///
/// Header keys and values are passed through as-is — we do not
/// substitute environment variables in v0.8.31. If you store a
/// real token here, the value lives in plain text in
/// `~/.deepseek/mcp.json`; treat that file with the same care
/// as any other secret-bearing config.
#[serde(default)]
#[serde(skip_serializing_if = "HashMap::is_empty")]
pub headers: HashMap<String, String>,
}
fn default_enabled() -> bool {
true
}
impl McpServerConfig {
pub fn effective_connect_timeout(&self, global: &McpTimeouts) -> u64 {
self.connect_timeout.unwrap_or(global.connect_timeout)
}
pub fn effective_execute_timeout(&self, global: &McpTimeouts) -> u64 {
self.execute_timeout.unwrap_or(global.execute_timeout)
}
pub fn effective_read_timeout(&self, global: &McpTimeouts) -> u64 {
self.read_timeout.unwrap_or(global.read_timeout)
}
pub fn is_enabled(&self) -> bool {
self.enabled && !self.disabled
}
pub fn is_tool_enabled(&self, tool_name: &str) -> bool {
let allowed = if self.enabled_tools.is_empty() {
true
} else {
self.enabled_tools.iter().any(|t| t == tool_name)
};
if !allowed {
return false;
}
!self.disabled_tools.iter().any(|t| t == tool_name)
}
}
// === MCP Tool Definition ===
/// Tool discovered from an MCP server
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpTool {
pub name: String,
#[serde(default)]
pub description: Option<String>,
#[serde(rename = "inputSchema", default)]
pub input_schema: serde_json::Value,
}
/// Resource discovered from an MCP server
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpResource {
pub uri: String,
pub name: String,
#[serde(default)]
pub description: Option<String>,
#[serde(rename = "mimeType", default)]
pub mime_type: Option<String>,
}
/// Resource template discovered from an MCP server
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpResourceTemplate {
#[serde(rename = "uriTemplate")]
pub uri_template: String,
pub name: String,
#[serde(default)]
pub description: Option<String>,
#[serde(rename = "mimeType", default)]
pub mime_type: Option<String>,
}
/// Prompt discovered from an MCP server
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpPrompt {
pub name: String,
#[serde(default)]
pub description: Option<String>,
#[serde(default)]
pub arguments: Vec<McpPromptArgument>,
}
/// Argument for an MCP prompt
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpPromptArgument {
pub name: String,
#[serde(default)]
pub description: Option<String>,
#[serde(default)]
pub required: bool,
}
// === Connection State ===
/// State of an MCP connection
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
Connecting,
Ready,
Disconnected,
}
// === McpConnection - Async Connection Management ===
// === Transport Trait ===
#[async_trait::async_trait]
pub trait McpTransport: Send + Sync {
async fn send(&mut self, msg: Vec<u8>) -> Result<()>;
async fn recv(&mut self) -> Result<Vec<u8>>;
/// Graceful shutdown — stdio transports send SIGTERM to the child and
/// give it a brief window to exit before tokio's `kill_on_drop` fires
/// SIGKILL as the backstop. Default is a no-op for non-stdio transports
/// that have no child process. Whalescale#420.
async fn shutdown(&mut self) {}
}
pub struct StdioTransport {
child: Child,
stdin: ChildStdin,
reader: tokio::io::BufReader<ChildStdout>,
/// Tail of stderr lines from the spawned MCP server. A background task
/// drains the child's stderr into this buffer so a mid-run crash leaves
/// some context behind instead of `Stdio::null` swallowing it.
stderr_tail: Arc<StderrTail>,
}
/// How long `StdioTransport::shutdown` waits for the child to exit on SIGTERM
/// before `kill_on_drop` fires SIGKILL. Tuned short so a hung MCP server
/// can't stall TUI exit; well-behaved servers almost always exit within
/// a few hundred ms.
const STDIO_SHUTDOWN_GRACE: Duration = Duration::from_millis(2_000);
/// How many lines of MCP-server stderr to keep around for crash diagnostics.
/// Bounded so a chatty server can't grow this without limit; large enough to
/// catch typical Node/Python startup or panic output.
const STDERR_TAIL_CAPACITY: usize = 64;
/// Bounded ring buffer for the most recent stderr lines from a spawned MCP
/// server. Used by `StdioTransport` to surface server-side context when the
/// transport read side fails (server crashed, exited early, etc).
#[derive(Default)]
pub struct StderrTail {
lines: TokioMutex<VecDeque<String>>,
}
impl StderrTail {
fn new() -> Arc<Self> {
Arc::new(Self {
lines: TokioMutex::new(VecDeque::with_capacity(STDERR_TAIL_CAPACITY)),
})
}
async fn push(&self, line: String) {
let mut buf = self.lines.lock().await;
if buf.len() >= STDERR_TAIL_CAPACITY {
buf.pop_front();
}
buf.push_back(line);
}
async fn snapshot(&self) -> Vec<String> {
self.lines.lock().await.iter().cloned().collect()
}
}
/// Format the captured stderr tail for inclusion in an error message. Empty
/// tails return `None` so the caller can fall back to its original message.
async fn format_stderr_context(tail: &StderrTail) -> Option<String> {
let lines = tail.snapshot().await;
if lines.is_empty() {
return None;
}
Some(format!(
"MCP server stderr (last {} line{}):\n{}",
lines.len(),
if lines.len() == 1 { "" } else { "s" },
lines.join("\n"),
))
}
/// Best-effort SIGTERM. On Unix uses `libc::kill`; on Windows there's no
/// equivalent so we let `kill_on_drop` (TerminateProcess) handle it via the
/// subsequent Drop. Returns whether a signal was actually sent.
fn send_sigterm(child: &Child) -> bool {
#[cfg(unix)]
{
if let Some(pid) = child.id() {
// SAFETY: pid was just obtained from `child.id()`. `libc::kill`
// with `SIGTERM` is async-signal-safe and never observes invalid
// memory. Worst case (pid wrap / process already gone) returns
// ESRCH, which we deliberately ignore.
unsafe {
let _ = libc::kill(pid as i32, libc::SIGTERM);
}
return true;
}
false
}
#[cfg(not(unix))]
{
let _ = child;
false
}
}
#[async_trait::async_trait]
impl McpTransport for StdioTransport {
async fn send(&mut self, mut msg: Vec<u8>) -> Result<()> {
msg.push(b'\n');
self.stdin.write_all(&msg).await?;
self.stdin.flush().await?;
Ok(())
}
async fn recv(&mut self) -> Result<Vec<u8>> {
let mut line = String::new();
loop {
line.clear();
let bytes = match self.reader.read_line(&mut line).await {
Ok(b) => b,
Err(err) => {
if let Some(stderr) = format_stderr_context(&self.stderr_tail).await {
anyhow::bail!("Stdio transport read error: {err}\n{stderr}");
}
return Err(err.into());
}
};
if bytes == 0 {
if let Some(stderr) = format_stderr_context(&self.stderr_tail).await {
anyhow::bail!("Stdio transport closed\n{stderr}");
}
anyhow::bail!("Stdio transport closed");
}
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
return Ok(trimmed.as_bytes().to_vec());
}
}
/// Send SIGTERM and wait up to `STDIO_SHUTDOWN_GRACE` for graceful exit
/// before letting Drop / `kill_on_drop` fire SIGKILL as the backstop.
async fn shutdown(&mut self) {
send_sigterm(&self.child);
// Give the child a window to exit cleanly. Discard the result —
// either it exits (success) or the timeout fires (Drop will SIGKILL).
let _ = tokio::time::timeout(STDIO_SHUTDOWN_GRACE, self.child.wait()).await;
}
}
/// Drop fallback (#420): if `shutdown` was never called explicitly, still
/// fire SIGTERM before tokio's `kill_on_drop` sends SIGKILL. The two
/// signals arrive back-to-back so well-behaved servers at least see the
/// SIGTERM first; misbehaving ones get SIGKILL'd anyway.
impl Drop for StdioTransport {
fn drop(&mut self) {
send_sigterm(&self.child);
}
}
pub struct SseTransport {
client: reqwest::Client,
base_url: String,
headers: HashMap<String, String>,
endpoint_url: Option<String>,
receiver: tokio::sync::mpsc::UnboundedReceiver<SseInbound>,
pending_messages: VecDeque<Vec<u8>>,
#[allow(dead_code)]
sse_task: tokio::task::JoinHandle<()>,
}
enum SseInbound {
Endpoint(String),
Message(Vec<u8>),
}
struct HttpTransport {
mode: HttpTransportMode,
client: reqwest::Client,
base_url: String,
headers: HashMap<String, String>,
cancel_token: tokio_util::sync::CancellationToken,
endpoint_timeout: Duration,
}
enum HttpTransportMode {
Streamable(StreamableHttpTransport),
Sse(SseTransport),
}
struct StreamableHttpTransport {
client: reqwest::Client,
url: String,
/// Extra headers applied to every outbound POST. Populated from
/// [`McpServerConfig::headers`]; an empty map is the no-auth
/// default. See `apply_custom_headers` for the filtering pass that
/// runs before each request.
headers: HashMap<String, String>,
pending_messages: VecDeque<Vec<u8>>,
/// Per-spec MCP session identifier returned by the server in the
/// first response (typically the `initialize` response). Attached
/// as the `Mcp-Session-Id` header on every subsequent outbound
/// request so the server can correlate messages within the same
/// session.
session_id: Option<String>,
}
#[derive(Debug)]
enum StreamableSendError {
Incompatible(String),
StaleSession(String),
Other(anyhow::Error),
}
impl SseTransport {
pub async fn connect(
client: reqwest::Client,
url: String,
headers: HashMap<String, String>,
cancel_token: tokio_util::sync::CancellationToken,
endpoint_timeout: Duration,
) -> Result<Self> {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let client_clone = client.clone();
let url_clone = url.clone();
let headers_clone = headers.clone();
let wait_cancel_token = cancel_token.clone();
let sse_task = tokio::spawn(async move {
if cancel_token.is_cancelled() {
return;
}
use futures_util::FutureExt;
let result = std::panic::AssertUnwindSafe(Self::run_sse_loop(
client_clone,
url_clone,
headers_clone,
tx,
cancel_token,
))
.catch_unwind()
.await;
match result {
Ok(res) => {
if let Err(e) = res {
tracing::error!("SSE loop error: {}", e);
}
}
Err(panic_err) => {
if let Some(msg) = panic_err.downcast_ref::<&str>() {
tracing::error!("SSE loop panicked: {}", msg);
} else if let Some(msg) = panic_err.downcast_ref::<String>() {
tracing::error!("SSE loop panicked: {}", msg);
} else {
tracing::error!("SSE loop panicked with unknown error");
}
}
}
});
let mut transport = Self {
client,
base_url: url,
headers,
endpoint_url: None,
receiver: rx,
pending_messages: VecDeque::new(),
sse_task,
};
transport
.wait_for_endpoint(&wait_cancel_token, endpoint_timeout)
.await?;
Ok(transport)
}
async fn run_sse_loop(
client: reqwest::Client,
url: String,
headers: HashMap<String, String>,
tx: tokio::sync::mpsc::UnboundedSender<SseInbound>,
cancel_token: tokio_util::sync::CancellationToken,
) -> Result<()> {
let response = apply_safe_custom_headers(
with_default_mcp_http_headers(client.get(&url), false),
&headers,
)
.send()
.await
.with_context(|| {
format!(
"MCP SSE connect failed (transport=http url={})",
mask_url_secrets(&url),
)
})?;
let status = response.status();
if !status.is_success() {
let body_excerpt = bounded_body_excerpt(response, ERROR_BODY_PREVIEW_BYTES).await;
anyhow::bail!(
"MCP SSE rejected (transport=http url={} status={}): {}",
mask_url_secrets(&url),
status,
body_excerpt,
);
}
let mut stream = response.bytes_stream();
use futures_util::StreamExt;
let mut buffer = String::new();
loop {
if cancel_token.is_cancelled() {
tracing::debug!("SSE loop cancelled");
break;
}
let item = tokio::select! {
_ = cancel_token.cancelled() => {
tracing::debug!("SSE loop shutting down");
break;
}
item = stream.next() => {
match item {
Some(i) => i,
None => break,
}
}
};
let chunk = item?;
let s = String::from_utf8_lossy(&chunk);
buffer.push_str(&s);
while let Some((pos, separator_len)) = find_sse_event_separator(&buffer) {
let event_block = buffer[..pos].to_string();
buffer = buffer[pos + separator_len..].to_string();
let mut event_type = "message";
let mut data = String::new();
for line in event_block.lines() {
if let Some(value) = sse_field_value(line, "event:") {
event_type = value;
} else if let Some(value) = sse_field_value(line, "data:") {
if !data.is_empty() {
data.push('\n');
}
data.push_str(value);
}
}
match event_type {
"endpoint" => {
let _ = tx.send(SseInbound::Endpoint(data));
}
"message" if !data.trim().is_empty() => {
let _ = tx.send(SseInbound::Message(data.into_bytes()));
}
_ => {}
}
}
}
Ok(())
}
async fn wait_for_endpoint(
&mut self,
cancel_token: &tokio_util::sync::CancellationToken,
endpoint_timeout: Duration,
) -> Result<()> {
let timeout = tokio::time::sleep(endpoint_timeout);
tokio::pin!(timeout);
loop {
let msg = tokio::select! {
_ = cancel_token.cancelled() => {
anyhow::bail!("SSE transport cancelled before endpoint was discovered");
}
_ = &mut timeout => {
anyhow::bail!(
"SSE endpoint not received within {}ms",
endpoint_timeout.as_millis()
);
}
msg = self.receiver.recv() => {
msg.context("SSE transport closed before endpoint was discovered")?
}
};
match msg {
SseInbound::Endpoint(endpoint) => {
self.store_endpoint(&endpoint)?;
return Ok(());
}
SseInbound::Message(msg) => self.pending_messages.push_back(msg),
}
}
}
fn store_endpoint(&mut self, endpoint: &str) -> Result<()> {
self.endpoint_url = Some(Self::resolve_endpoint_url(&self.base_url, endpoint)?);
Ok(())
}
fn resolve_endpoint_url(base_url: &str, endpoint_url: &str) -> Result<String> {
if endpoint_url.starts_with("http://") || endpoint_url.starts_with("https://") {
return Ok(endpoint_url.to_string());
}
let base = reqwest::Url::parse(base_url)?;
let joined = base.join(endpoint_url)?;
Ok(joined.to_string())
}
}
impl HttpTransport {
fn new(
client: reqwest::Client,
url: String,
headers: HashMap<String, String>,
cancel_token: tokio_util::sync::CancellationToken,
endpoint_timeout: Duration,
) -> Self {
Self {
mode: HttpTransportMode::Streamable(StreamableHttpTransport::new(
client.clone(),
url.clone(),
headers.clone(),
)),
client,
base_url: url,
headers,
cancel_token,
endpoint_timeout,
}
}
async fn switch_to_sse_and_send(&mut self, msg: Vec<u8>) -> Result<()> {
let mut sse = SseTransport::connect(
self.client.clone(),
self.base_url.clone(),
self.headers.clone(),
self.cancel_token.clone(),
self.endpoint_timeout,
)
.await?;
sse.send(msg).await?;
self.mode = HttpTransportMode::Sse(sse);
Ok(())
}
/// Best-effort session-establishment GET preflight.
///
/// Per the Streamable HTTP spec, the server may return an
/// `Mcp-Session-Id` header on the `initialize` response (the normal
/// path handled inside [`StreamableHttpTransport::send`] above).
/// However some servers (e.g. Hindsight, #1629) **require** a session
/// ID on every POST including `initialize`, creating a chicken-and-egg
/// problem. For those servers we send a short-lived GET before the
/// first POST: if the server returns a session ID in the GET response
/// it will be captured by the header-reading code in
/// [`StreamableHttpTransport::send`] just as if it came from a POST
/// response.
///
/// This is intentionally best-effort:
/// * The GET uses a tight per-request inner timeout so it never
/// blocks connection startup for long.
/// * If the server doesn't support GET (405, 404, …) we log a debug
/// line and move on — the `initialize` POST will proceed without a
/// session ID.
/// * If the server opens an SSE stream in response (the GET from old
/// SSE transport), we read only the headers, then discard the body
/// so the SSE stream is torn down. The actual SSE path uses a
/// dedicated `SseTransport` and is triggered by the incompatible-
/// status fallback in [`HttpTransport::send`].
async fn try_establish_session(&mut self) -> Result<()> {
let transport = match &mut self.mode {
HttpTransportMode::Streamable(t) => t,
// Already on SSE — session is implicit via the long-lived GET.
HttpTransportMode::Sse(_) => return Ok(()),
};
let request = apply_safe_custom_headers(
with_default_mcp_http_headers(transport.client.get(&transport.url), false),
&transport.headers,
);
let response = tokio::time::timeout(Duration::from_secs(5), request.send())
.await
.map_err(|_| anyhow::anyhow!("GET timeout"))?
.map_err(|e| anyhow::anyhow!("GET error: {e}"))?;
// Capture session ID from the GET response so subsequent POSTs
// (including `initialize`) can include it. This is the same
// header-reading logic that would be hit inside
// `StreamableHttpTransport::send` for POST responses, but since
// the GET is sent before any POST we do it here directly.
if let Some(sid) = response
.headers()
.get("Mcp-Session-Id")
.and_then(|v| v.to_str().ok())
&& transport.session_id.as_deref() != Some(sid)
{
tracing::debug!(target: "mcp", session_id = %sid, "captured MCP session ID via GET preflight");
transport.session_id = Some(sid.to_string());
}
// We only care about the response headers — discard the body.
// If the server opened an SSE stream in response (some servers
// do this on GET), it will be torn down when response is dropped.
drop(response);
Ok(())
}
}
#[async_trait::async_trait]
impl McpTransport for HttpTransport {
async fn send(&mut self, msg: Vec<u8>) -> Result<()> {
match &mut self.mode {
HttpTransportMode::Streamable(transport) => match transport.send(msg.clone()).await {
Ok(()) => Ok(()),
Err(StreamableSendError::Incompatible(detail)) => {
tracing::debug!(
"MCP Streamable HTTP unavailable; falling back to SSE endpoint discovery: {}",
detail
);
self.switch_to_sse_and_send(msg).await
}
Err(StreamableSendError::StaleSession(detail)) => {
if let HttpTransportMode::Streamable(transport) = &mut self.mode {
tracing::debug!(
target: "mcp",
error = %detail,
"MCP Streamable HTTP session expired; clearing cached session ID"
);
transport.session_id = None;
}
Err(anyhow::anyhow!(
"MCP Streamable HTTP session expired; retry with a new session required ({detail})"
))
}
Err(StreamableSendError::Other(err)) => Err(err),
},
HttpTransportMode::Sse(transport) => transport.send(msg).await,
}
}
async fn recv(&mut self) -> Result<Vec<u8>> {
match &mut self.mode {
HttpTransportMode::Streamable(transport) => transport.recv().await,
HttpTransportMode::Sse(transport) => transport.recv().await,
}
}
async fn shutdown(&mut self) {
if let HttpTransportMode::Sse(transport) = &mut self.mode {
transport.shutdown().await;
}
}
}
impl StreamableHttpTransport {
fn new(client: reqwest::Client, url: String, headers: HashMap<String, String>) -> Self {
Self {
client,
url,
headers,
pending_messages: VecDeque::new(),
session_id: None,
}
}
async fn send(&mut self, msg: Vec<u8>) -> std::result::Result<(), StreamableSendError> {
// Apply user-configured custom headers after protocol framing so
// reserved Accept / Content-Type overrides can be filtered out.
let mut request = apply_safe_custom_headers(
with_default_mcp_http_headers(self.client.post(&self.url), true),
&self.headers,
);
// Attach any previously captured session ID per the Streamable
// HTTP spec so the server can correlate this request to the
// existing session.
if let Some(ref sid) = self.session_id {
request = request.header("Mcp-Session-Id", sid.as_str());
}
let response = request
.body(msg)
.send()
.await
.map_err(|err| StreamableSendError::Other(err.into()))?;
let status = response.status();
// Capture session ID from any response (2xx, 202, 4xx, …). The
// server may return it on the `initialize` response or on a
// best-effort GET preflight below.
if let Some(sid) = response
.headers()
.get("Mcp-Session-Id")
.and_then(|v| v.to_str().ok())
&& self.session_id.as_deref() != Some(sid)
{
tracing::debug!(target: "mcp", session_id = %sid, "captured MCP session ID");
self.session_id = Some(sid.to_string());
}
if status == StatusCode::ACCEPTED || status == StatusCode::NO_CONTENT {
return Ok(());
}
if !status.is_success() {
let body_excerpt = bounded_body_excerpt(response, ERROR_BODY_PREVIEW_BYTES).await;
if self.session_id.is_some()
&& is_streamable_http_stale_session_status(status, &body_excerpt)
{
return Err(StreamableSendError::StaleSession(format!(
"status={status} body={body_excerpt}"
)));
}
if is_streamable_http_incompatible_status(status) {
return Err(StreamableSendError::Incompatible(format!(
"status={status} body={body_excerpt}"
)));
}
return Err(StreamableSendError::Other(anyhow::anyhow!(
"MCP Streamable HTTP rejected (transport=http url={} status={}): {}",
mask_url_secrets(&self.url),
status,
body_excerpt,
)));
}
let content_type = response
.headers()
.get(CONTENT_TYPE)
.and_then(|value| value.to_str().ok())
.map(str::to_string);
let body = response
.text()
.await
.map_err(|err| StreamableSendError::Other(err.into()))?;
self.store_response_body(content_type.as_deref(), &body)
.map_err(StreamableSendError::Other)
}
async fn recv(&mut self) -> Result<Vec<u8>> {
self.pending_messages
.pop_front()
.context("MCP Streamable HTTP response queue is empty")
}
fn store_response_body(&mut self, content_type: Option<&str>, body: &str) -> Result<()> {
if body.trim().is_empty() {
return Ok(());
}
let is_event_stream = content_type
.map(|value| value.to_ascii_lowercase().contains("text/event-stream"))
.unwrap_or(false)
|| body.trim_start().starts_with("event:")
|| body.trim_start().starts_with("data:");
if is_event_stream {
for msg in parse_sse_message_data(body) {
self.pending_messages.push_back(msg);
}
return Ok(());
}
self.pending_messages.push_back(body.as_bytes().to_vec());
Ok(())
}
}
fn is_streamable_http_incompatible_status(status: StatusCode) -> bool {
matches!(
status,
StatusCode::NOT_FOUND
| StatusCode::METHOD_NOT_ALLOWED
| StatusCode::NOT_ACCEPTABLE
| StatusCode::UNSUPPORTED_MEDIA_TYPE
| StatusCode::NOT_IMPLEMENTED
)
}
fn is_streamable_http_stale_session_status(status: StatusCode, body_excerpt: &str) -> bool {
if status == StatusCode::NOT_FOUND {
return true;
}
if status != StatusCode::BAD_REQUEST && status != StatusCode::UNAUTHORIZED {
return false;
}
let body = body_excerpt.to_ascii_lowercase();
body.contains("session") && (body.contains("expired") || body.contains("invalid"))
}
fn is_mcp_stale_session_body(body: &str) -> bool {
let body = body.to_ascii_lowercase();
body.contains("session") && (body.contains("expired") || body.contains("invalid"))
}
fn is_mcp_stale_session_error(err: &anyhow::Error) -> bool {
let err = format!("{err:#}");
let lower_err = err.to_ascii_lowercase();
err.contains("MCP Streamable HTTP session expired")
|| err.contains("MCP session expired")
|| err.contains("SSE transport closed")
|| (err.contains("MCP SSE POST send failed") && is_connection_closed_error_text(&lower_err))
|| is_mcp_stale_session_body(&err)
}
fn is_connection_closed_error_text(err: &str) -> bool {
err.contains("connection closed")
|| err.contains("connection reset")
|| err.contains("broken pipe")
|| err.contains("unexpected eof")
|| err.contains("forcibly closed")
}
fn parse_sse_message_data(body: &str) -> Vec<Vec<u8>> {
let normalized = body.replace("\r\n", "\n");
let mut messages = Vec::new();
for block in normalized.split("\n\n") {
let mut event_type = "message";
let mut data = String::new();
for line in block.lines() {
if let Some(value) = sse_field_value(line, "event:") {
event_type = value;
} else if let Some(value) = sse_field_value(line, "data:") {
if !data.is_empty() {
data.push('\n');
}
data.push_str(value);
}
}
if event_type != "message" || data.trim().is_empty() {
continue;
}
messages.push(data.trim().as_bytes().to_vec());
}
messages
}
fn find_sse_event_separator(buffer: &str) -> Option<(usize, usize)> {
match (buffer.find("\n\n"), buffer.find("\r\n\r\n")) {
(Some(lf), Some(crlf)) if crlf < lf => Some((crlf, 4)),
(Some(lf), _) => Some((lf, 2)),
(_, Some(crlf)) => Some((crlf, 4)),
_ => None,
}
}
fn sse_field_value<'a>(line: &'a str, field: &str) -> Option<&'a str> {
let value = line.strip_prefix(field)?;
Some(value.strip_prefix(' ').unwrap_or(value))
}
fn is_legacy_sse_transport(config: &McpServerConfig) -> bool {
config
.transport
.as_deref()
.map(|transport| transport.trim().eq_ignore_ascii_case("sse"))
.unwrap_or(false)
}
fn validate_mcp_transport(transport: Option<&str>) -> Result<()> {
let Some(transport) = transport else {
return Ok(());
};
if transport.trim().eq_ignore_ascii_case("sse") {
return Ok(());
}
anyhow::bail!("Unsupported MCP transport '{transport}'. Supported values: sse");
}
fn response_id_matches(id: Option<&serde_json::Value>, expected_id: &str) -> bool {
let Some(id) = id else {
return false;
};
if id.as_str() == Some(expected_id) {
return true;
}
id.as_u64()
.map(|id| id.to_string() == expected_id)
.unwrap_or(false)
}
#[async_trait::async_trait]
impl McpTransport for SseTransport {
async fn send(&mut self, msg: Vec<u8>) -> Result<()> {
let endpoint = self
.endpoint_url
.as_ref()
.context("SSE endpoint not yet discovered")?;
let response = apply_safe_custom_headers(
with_default_mcp_http_headers(self.client.post(endpoint), true),
&self.headers,
)
.body(msg)
.send()
.await
.with_context(|| {
format!(
"MCP SSE POST send failed (transport=sse endpoint={})",
mask_url_secrets(endpoint)
)
})?;
let status = response.status();
if !status.is_success() {
let body_excerpt = bounded_body_excerpt(response, ERROR_BODY_PREVIEW_BYTES).await;
if is_mcp_stale_session_body(&body_excerpt) {
anyhow::bail!(
"MCP session expired (transport=sse endpoint={} status={}): {}",
mask_url_secrets(endpoint),
status,
body_excerpt
);
}
anyhow::bail!(
"MCP SSE POST rejected (transport=sse endpoint={} status={}): {}",
mask_url_secrets(endpoint),
status,
body_excerpt
);
}
Ok(())
}
async fn recv(&mut self) -> Result<Vec<u8>> {
loop {
if let Some(msg) = self.pending_messages.pop_front() {
return Ok(msg);
}
match self.receiver.recv().await.context("SSE transport closed")? {
SseInbound::Endpoint(endpoint) => {
self.store_endpoint(&endpoint)?;
}
SseInbound::Message(msg) => return Ok(msg),
}
}
}
}
// === McpConnection - Async Connection Management ===
/// Manages a single async connection to an MCP server
pub struct McpConnection {
name: String,
transport: Box<dyn McpTransport>,
tools: Vec<McpTool>,
resources: Vec<McpResource>,
resource_templates: Vec<McpResourceTemplate>,
prompts: Vec<McpPrompt>,
request_id: AtomicU64,
state: ConnectionState,
config: McpServerConfig,
cancel_token: tokio_util::sync::CancellationToken,
}
impl McpConnection {
/// Connect to an MCP server and initialize it.
///
/// `network_policy` (added in v0.7.0 for #135) is consulted for HTTP/SSE
/// transports only — STDIO transports are unaffected. Pass `None` to
/// match pre-v0.7.0 permissive behavior.
pub async fn connect_with_policy(
name: String,
config: McpServerConfig,
global_timeouts: &McpTimeouts,
network_policy: Option<&NetworkPolicyDecider>,
) -> Result<Self> {
let connect_timeout_secs = config.effective_connect_timeout(global_timeouts);
let cancel_token = tokio_util::sync::CancellationToken::new();
let transport: Box<dyn McpTransport> = if let Some(url) = &config.url {
// Per-domain network policy gate (#135). Only the HTTP/SSE transport
// is gated; STDIO MCP servers run as local subprocesses and never
// touch the network from this code path.
if let Some(decider) = network_policy
&& let Some(host) = host_from_url(url)
{
match decider.evaluate(&host, "mcp") {
Decision::Allow => {}
Decision::Deny => {
anyhow::bail!(
"MCP server '{name}' connection to '{host}' blocked by network policy"
);
}
Decision::Prompt => {
anyhow::bail!(
"MCP server '{name}' connection to '{host}' requires approval; \
re-run after `/network allow {host}` or set network.default = \"allow\" in config"
);
}
}
}
// Honor the standard `HTTP_PROXY` / `HTTPS_PROXY` (and their
// lowercase equivalents) plus `NO_PROXY` env vars when
// reaching MCP HTTP servers (#1408). Reqwest 0.13 does not
// auto-detect these by default, so users behind corporate
// proxies, on China-mainland connections routing through a
// local Clash / Shadowsocks tunnel, etc. previously had MCP
// HTTP traffic bypass the proxy entirely while every other
// tool on the box (curl, npm, …) used it.
let mut client_builder = crate::tls::reqwest_client_builder()
.timeout(Duration::from_secs(connect_timeout_secs));
let env_proxy_url = std::env::var("HTTPS_PROXY")
.or_else(|_| std::env::var("https_proxy"))
.or_else(|_| std::env::var("HTTP_PROXY"))
.or_else(|_| std::env::var("http_proxy"))
.ok()
.filter(|s| !s.trim().is_empty());
if let Some(proxy_url) = env_proxy_url {
match reqwest::Proxy::all(&proxy_url) {
Ok(proxy) => {
let proxy = proxy.no_proxy(reqwest::NoProxy::from_env());
client_builder = client_builder.proxy(proxy);
}
Err(err) => {
// Redact userinfo (the `username[:password]@…`
// portion of the URL) before logging so an
// HTTPS_PROXY that embeds credentials
// (common in corporate setups) doesn't leak the
// password to the on-disk `~/.deepseek/logs/`.
let proxy_redacted = redact_proxy_userinfo(&proxy_url);
tracing::warn!(
target: "mcp",
?err,
proxy = %proxy_redacted,
"ignoring malformed HTTP(S)_PROXY env var; MCP connection will bypass proxy"
);
}
}
}
let client = client_builder.build()?;
if is_legacy_sse_transport(&config) {
Box::new(
SseTransport::connect(
client,
url.clone(),
config.headers.clone(),
cancel_token.clone(),
Duration::from_secs(connect_timeout_secs),
)
.await?,
)
} else {
let mut http = HttpTransport::new(
client,
url.clone(),
config.headers.clone(),
cancel_token.clone(),
Duration::from_secs(connect_timeout_secs),
);
// Best-effort session preflight for servers that require
// a session ID on every POST including `initialize`
// (e.g. Hindsight, #1629). Failures are non-fatal — the
// `initialize` POST will proceed and may capture a session
// ID from the response instead.
if let Err(e) = http.try_establish_session().await {
tracing::debug!(
target: "mcp",
server = %name,
error = %e,
"session-establishment GET skipped; proceeding with POST initialize"
);
}
Box::new(http)
}
} else if let Some(command) = &config.command {
let mut cmd = tokio::process::Command::new(command);
cmd.args(&config.args)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true);
if let Some(cwd) = &config.cwd {
cmd.current_dir(cwd);
}
// MCP stdio servers are user-configured integrations. Use the
// wider MCP allowlist so common Node/Python/proxy/CA-bundle
// bootstrap variables (NVM_DIR, NODE_OPTIONS, NPM_CONFIG_*,
// HTTP(S)_PROXY, …) reach the child. See `sanitized_mcp_env`
// and #1244 for context.
child_env::apply_to_tokio_command_mcp(&mut cmd, child_env::string_map_env(&config.env));
let mut child = cmd.spawn().with_context(|| {
let env_keys: Vec<&str> = config.env.keys().map(String::as_str).collect();
format!(
"MCP stdio spawn failed (transport=stdio server={name} cmd={command:?} args={:?} env_keys={env_keys:?})",
config.args,
)
})?;
let stdin = child.stdin.take().context("Failed to get MCP stdin")?;
let stdout = child.stdout.take().context("Failed to get MCP stdout")?;
let stderr = child.stderr.take().context("Failed to get MCP stderr")?;
// Drain stderr into a bounded ring buffer so a crash mid-run
// leaves diagnostic breadcrumbs instead of disappearing into
// `Stdio::null`. The task exits naturally when the child closes
// its stderr (kill_on_drop / exit / explicit shutdown).
let stderr_tail = StderrTail::new();
{
let tail = Arc::clone(&stderr_tail);
tokio::spawn(async move {
let mut lines = tokio::io::BufReader::new(stderr).lines();
while let Ok(Some(line)) = lines.next_line().await {
tail.push(line).await;
}
});
}
Box::new(StdioTransport {
child,
stdin,
reader: tokio::io::BufReader::new(stdout),
stderr_tail,
})
} else {
anyhow::bail!("MCP server '{name}' config must have either 'command' or 'url'");
};
let mut conn = Self {
name: name.clone(),
transport,
tools: Vec::new(),
resources: Vec::new(),
resource_templates: Vec::new(),
prompts: Vec::new(),
request_id: AtomicU64::new(1),
state: ConnectionState::Connecting,
config,
cancel_token,
};
// Initialize with timeout
tokio::time::timeout(Duration::from_secs(connect_timeout_secs), conn.initialize())
.await
.with_context(|| format!("MCP server '{name}' initialization timed out"))??;
// Discover tools, resources, and prompts with timeout
tokio::time::timeout(
Duration::from_secs(connect_timeout_secs),
conn.discover_all(),
)
.await
.with_context(|| format!("MCP server '{name}' discovery timed out"))??;
conn.state = ConnectionState::Ready;
Ok(conn)
}
/// Send initialize request and wait for response
async fn initialize(&mut self) -> Result<()> {
let init_id = self.next_id();
self.send(serde_json::json!({
"jsonrpc": "2.0",
"id": &init_id,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"clientInfo": {
"name": "codewhale-tui",
"version": env!("CARGO_PKG_VERSION")
},
"capabilities": {
"tools": {},
"resources": {},
"prompts": {}
}
}
}))
.await?;
self.recv(init_id).await?;
// Send initialized notification (no id, no response expected)
self.send(serde_json::json!({
"jsonrpc": "2.0",
"method": "notifications/initialized"
}))
.await?;
Ok(())
}
/// Discover tools, resources, and prompts
async fn discover_all(&mut self) -> Result<()> {
// We use join! to discover everything concurrently if possible,
// but for now let's keep it sequential for simplicity in error handling
self.discover_tools().await?;
self.discover_resources().await?;
self.discover_resource_templates().await?;
self.discover_prompts().await?;
Ok(())
}
/// Discover available tools from the MCP server
async fn discover_tools(&mut self) -> Result<()> {
let mut cursor: Option<String> = None;
loop {
let list_id = self.next_id();
let params = match &cursor {
Some(c) => serde_json::json!({ "cursor": c }),
None => serde_json::json!({}),
};
self.send(serde_json::json!({
"jsonrpc": "2.0",
"id": &list_id,
"method": "tools/list",
"params": params
}))
.await?;
let response = self.recv(list_id).await?;
let Some(result) = response.get("result") else {
break;
};
if let Some(arr) = result.get("tools").and_then(|t| t.as_array()) {
for item in arr {
match serde_json::from_value::<McpTool>(item.clone()) {
Ok(tool) => self.tools.push(tool),
Err(err) => {
// Skip individual malformed entries instead of
// dropping the whole page (#1410). The old
// `unwrap_or_default()` would silently throw
// away every tool when one was misshapen.
tracing::debug!(target: "mcp", ?err, "skipping malformed tool item");
}
}
}
}
cursor = result
.get("nextCursor")
.and_then(|v| v.as_str())
.map(str::to_owned);
if cursor.is_none() {
break;
}
}
// Sort by tool name so the order the model sees doesn't depend on
// server-side pagination ordering — keeps the prompt prefix stable
// for cache-hit purposes (#1319).
self.tools.sort_by(|a, b| a.name.cmp(&b.name));
Ok(())
}
/// Discover available resources from the MCP server
async fn discover_resources(&mut self) -> Result<()> {
let mut cursor: Option<String> = None;
loop {
let list_id = self.next_id();
let params = match &cursor {
Some(c) => serde_json::json!({ "cursor": c }),
None => serde_json::json!({}),
};
self.send(serde_json::json!({
"jsonrpc": "2.0",
"id": &list_id,
"method": "resources/list",
"params": params
}))
.await?;
let response = self.recv(list_id).await?;
let Some(result) = response.get("result") else {
break;
};
if let Some(arr) = result.get("resources").and_then(|r| r.as_array()) {
for item in arr {
match serde_json::from_value::<McpResource>(item.clone()) {
Ok(resource) => self.resources.push(resource),
Err(err) => {
tracing::debug!(target: "mcp", ?err, "skipping malformed resource item");
}
}
}
}
cursor = result
.get("nextCursor")
.and_then(|v| v.as_str())
.map(str::to_owned);
if cursor.is_none() {
break;
}
}
Ok(())
}
/// Discover available resource templates from the MCP server
async fn discover_resource_templates(&mut self) -> Result<()> {
let mut cursor: Option<String> = None;
loop {
let list_id = self.next_id();
let params = match &cursor {
Some(c) => serde_json::json!({ "cursor": c }),
None => serde_json::json!({}),
};
self.send(serde_json::json!({
"jsonrpc": "2.0",
"id": &list_id,
"method": "resources/templates/list",
"params": params
}))
.await?;
let response = self.recv(list_id).await?;
let Some(result) = response.get("result") else {
break;
};
let templates = result
.get("resourceTemplates")
.or_else(|| result.get("templates"))
.or_else(|| result.get("resource_templates"));
if let Some(arr) = templates.and_then(|t| t.as_array()) {
for item in arr {
match serde_json::from_value::<McpResourceTemplate>(item.clone()) {
Ok(tmpl) => self.resource_templates.push(tmpl),
Err(err) => {
tracing::debug!(target: "mcp", ?err, "skipping malformed resource_template item");
}
}
}
}
cursor = result
.get("nextCursor")
.and_then(|v| v.as_str())
.map(str::to_owned);
if cursor.is_none() {
break;
}
}
Ok(())
}
/// Discover available prompts from the MCP server
async fn discover_prompts(&mut self) -> Result<()> {
let mut cursor: Option<String> = None;
loop {
let list_id = self.next_id();
let params = match &cursor {
Some(c) => serde_json::json!({ "cursor": c }),
None => serde_json::json!({}),
};
self.send(serde_json::json!({
"jsonrpc": "2.0",
"id": &list_id,
"method": "prompts/list",
"params": params
}))
.await?;
let response = self.recv(list_id).await?;
let Some(result) = response.get("result") else {
break;
};
if let Some(arr) = result.get("prompts").and_then(|p| p.as_array()) {
for item in arr {
match serde_json::from_value::<McpPrompt>(item.clone()) {
Ok(prompt) => self.prompts.push(prompt),
Err(err) => {
tracing::debug!(target: "mcp", ?err, "skipping malformed prompt item");
}
}
}
}
cursor = result
.get("nextCursor")
.and_then(|v| v.as_str())
.map(str::to_owned);
if cursor.is_none() {
break;
}
}
Ok(())
}
/// Call a tool on this MCP server
pub async fn call_tool(
&mut self,
tool_name: &str,
arguments: serde_json::Value,
timeout_secs: u64,
) -> Result<serde_json::Value> {
self.call_method(
"tools/call",
serde_json::json!({
"name": tool_name,
"arguments": arguments
}),
timeout_secs,
)
.await
}
/// Read a resource from this MCP server
pub async fn read_resource(
&mut self,
uri: &str,
timeout_secs: u64,
) -> Result<serde_json::Value> {
self.call_method(
"resources/read",
serde_json::json!({
"uri": uri
}),
timeout_secs,
)
.await
}
/// Get a prompt from this MCP server
pub async fn get_prompt(
&mut self,
prompt_name: &str,
arguments: serde_json::Value,
timeout_secs: u64,
) -> Result<serde_json::Value> {
self.call_method(
"prompts/get",
serde_json::json!({
"name": prompt_name,
"arguments": arguments
}),
timeout_secs,
)
.await
}
/// Generic method to call an MCP method
async fn call_method(
&mut self,
method: &str,
params: serde_json::Value,
timeout_secs: u64,
) -> Result<serde_json::Value> {
if self.state != ConnectionState::Ready {
anyhow::bail!(
"Failed to call MCP method '{}': connection '{}' is not ready",
method,
self.name
);
}
let call_id = self.next_id();
self.send(serde_json::json!({
"jsonrpc": "2.0",
"id": &call_id,
"method": method,
"params": params
}))
.await?;
let response = tokio::time::timeout(Duration::from_secs(timeout_secs), self.recv(call_id))
.await
.with_context(|| {
format!(
"MCP method '{}' on server '{}' timed out after {}s",
method, self.name, timeout_secs
)
})??;
if let Some(error) = response.get("error") {
return Err(anyhow::anyhow!(
"MCP error in '{}': {}",
method,
serde_json::to_string_pretty(error)?
));
}
Ok(response
.get("result")
.cloned()
.unwrap_or(serde_json::json!(null)))
}
/// Get discovered tools
pub fn tools(&self) -> &[McpTool] {
&self.tools
}
/// Get discovered resources
pub fn resources(&self) -> &[McpResource] {
&self.resources
}
/// Get discovered resource templates
pub fn resource_templates(&self) -> &[McpResourceTemplate] {
&self.resource_templates
}
/// Get discovered prompts
pub fn prompts(&self) -> &[McpPrompt] {
&self.prompts
}
/// Get server name
#[allow(dead_code)] // Public API for MCP consumers
pub fn name(&self) -> &str {
&self.name
}
/// Check if connection is ready
pub fn is_ready(&self) -> bool {
self.state == ConnectionState::Ready
}
/// Get server config
pub fn config(&self) -> &McpServerConfig {
&self.config
}
/// Get connection state
#[allow(dead_code)] // Public API for MCP consumers
pub fn state(&self) -> ConnectionState {
self.state
}
fn next_id(&self) -> String {
self.request_id.fetch_add(1, Ordering::SeqCst).to_string()
}
async fn send(&mut self, msg: serde_json::Value) -> Result<()> {
let bytes = serde_json::to_vec(&msg).context("Failed to serialize MCP JSON-RPC message")?;
self.transport.send(bytes).await
}
async fn recv(&mut self, expected_id: String) -> Result<serde_json::Value> {
loop {
let bytes = self.transport.recv().await.inspect_err(|_e| {
self.state = ConnectionState::Disconnected;
})?;
let value: serde_json::Value = serde_json::from_slice(&bytes).with_context(|| {
format!(
"Invalid MCP JSON-RPC message from server '{}': {}",
self.name,
invalid_json_preview(&bytes)
)
})?;
// Check if this is a response with the expected id. We emit
// string IDs because some MCP gateways reject numeric JSON-RPC
// IDs, but accept numeric echoes for compatibility with older
// servers and tests.
if response_id_matches(value.get("id"), &expected_id) {
if let Some(error) = value.get("error")
&& is_mcp_stale_session_body(&error.to_string())
{
anyhow::bail!("MCP session expired: {error}");
}
return Ok(value);
}
// Skip notifications (no id) and responses with different ids
}
}
/// Gracefully close the connection
#[allow(dead_code)] // Public API for MCP consumers
pub fn close(&mut self) {
self.cancel_token.cancel();
self.state = ConnectionState::Disconnected;
}
}
impl Drop for McpConnection {
fn drop(&mut self) {
self.cancel_token.cancel();
}
}
// === McpPool - Connection Pool Management ===
/// Pool of MCP connections for reuse
pub struct McpPool {
connections: HashMap<String, McpConnection>,
config: McpConfig,
network_policy: Option<NetworkPolicyDecider>,
/// Source paths the config was loaded from. Empty for pools constructed
/// directly via `new` (tests, ad-hoc snapshots). Workspace-aware pools
/// track both global and project-level MCP config paths so lazy reload sees
/// either file appear or change.
config_sources: Vec<PathBuf>,
workspace: Option<PathBuf>,
/// 64-bit content hash of the active config (`hash_mcp_config`). Compared
/// against the freshly-loaded config after an mtime change to skip
/// reloading when the file was merely touched.
config_hash: u64,
/// Most recently observed mtime for `config_sources`.
last_mtimes: Vec<Option<std::time::SystemTime>>,
}
impl McpPool {
/// Create a new pool with the given configuration
pub fn new(config: McpConfig) -> Self {
let config_hash = hash_mcp_config(&config);
Self {
connections: HashMap::new(),
config,
network_policy: None,
config_sources: Vec::new(),
workspace: None,
config_hash,
last_mtimes: Vec::new(),
}
}
/// Create a pool from a configuration file path.
#[cfg(test)]
pub fn from_config_path(path: &std::path::Path) -> Result<Self> {
let config = load_config(path)?;
let mut pool = Self::new(config);
pool.config_sources = vec![path.to_path_buf()];
pool.last_mtimes = vec![mcp_config_mtime(path)];
Ok(pool)
}
/// Create a pool from global MCP config plus workspace-local
/// `.codewhale/mcp.json`. Project servers override same-name global
/// servers and default stdio `cwd` to the workspace root.
pub fn from_config_path_with_workspace(
path: &std::path::Path,
workspace: &Path,
) -> Result<Self> {
let config = load_config_with_workspace(path, workspace)?;
let mut pool = Self::new(config);
pool.config_sources = vec![path.to_path_buf(), workspace_mcp_config_path(workspace)];
pool.config_sources
.extend(crate::config::workspace_trust_config_candidate_paths());
pool.last_mtimes = pool
.config_sources
.iter()
.map(|source| mcp_config_mtime(source))
.collect();
pool.workspace = Some(workspace.to_path_buf());
Ok(pool)
}
/// Attach a per-domain network policy (#135). When set, HTTP/SSE
/// transports are gated through it; STDIO transports are unaffected.
pub fn with_network_policy(mut self, policy: NetworkPolicyDecider) -> Self {
self.network_policy = Some(policy);
self
}
/// If the source config file's mtime has changed since the last check,
/// re-read it and (only when the content hash also changed) drop all
/// existing connections so the next `get_or_connect` reattaches under
/// the new config. No-op when the pool was constructed via [`McpPool::new`]
/// (no source path), when stat fails, or when the file content is
/// byte-identical to what we last loaded. Returns `Ok(true)` if any
/// connections were dropped, `Ok(false)` otherwise.
///
/// This is the lazy half of the auto-reload story for #1267: instead of a
/// long-lived file watcher, the next tool invocation pays a single `stat`
/// call (and only re-reads the file when the mtime moved). On networked
/// or remote filesystems where mtime granularity is poor, the hash
/// compare keeps us from churning connections on every check.
pub async fn reload_if_config_changed(&mut self) -> Result<bool> {
if self.config_sources.is_empty() {
return Ok(false);
}
let current_mtimes: Vec<_> = self
.config_sources
.iter()
.map(|path| mcp_config_mtime(path))
.collect();
if current_mtimes == self.last_mtimes {
return Ok(false);
}
// mtime moved — we owe a re-read.
let primary = self
.config_sources
.first()
.context("MCP config source list unexpectedly empty")?;
let new_config = if let Some(workspace) = self.workspace.as_deref() {
load_config_with_workspace(primary, workspace)?
} else {
load_config(primary)?
};
let new_hash = hash_mcp_config(&new_config);
// Always advance mtimes so a touched-but-unchanged file doesn't
// make us re-read on every subsequent call.
self.last_mtimes = current_mtimes;
if new_hash == self.config_hash {
return Ok(false);
}
// Real content change — drop all live connections so the next
// get_or_connect picks up the new config (sandbox flags, env, args).
self.connections.clear();
self.config = new_config;
self.config_hash = new_hash;
Ok(true)
}
/// Get or create a connection to a server
pub async fn get_or_connect(&mut self, server_name: &str) -> Result<&mut McpConnection> {
// Lazy auto-reload (#1267 part 2): cheap mtime-then-hash check before
// each connection lookup. Transient FS errors are logged but not
// propagated so a brief hiccup can't take down the whole tool dispatch.
if let Err(e) = self.reload_if_config_changed().await {
tracing::warn!("MCP config reload check failed: {e:#}");
}
let is_ready = self
.connections
.get(server_name)
.map(|conn| conn.is_ready())
.unwrap_or(false);
if is_ready {
return self
.connections
.get_mut(server_name)
.ok_or_else(|| anyhow::anyhow!("MCP connection disappeared for {server_name}"));
}
self.connections.remove(server_name);
let server_config = self
.config
.servers
.get(server_name)
.ok_or_else(|| anyhow::anyhow!("Failed to find MCP server: {server_name}"))?
.clone();
if !server_config.is_enabled() {
anyhow::bail!("Failed to connect MCP server '{server_name}': server is disabled");
}
let connection = McpConnection::connect_with_policy(
server_name.to_string(),
server_config,
&self.config.timeouts,
self.network_policy.as_ref(),
)
.await?;
self.connections.insert(server_name.to_string(), connection);
self.connections
.get_mut(server_name)
.ok_or_else(|| anyhow::anyhow!("Failed to store MCP connection for {server_name}"))
}
/// Connect to all enabled servers, returning errors for failed connections
pub async fn connect_all(&mut self) -> Vec<(String, anyhow::Error)> {
let mut errors = Vec::new();
let names: Vec<String> = self
.config
.servers
.keys()
.filter(|n| self.config.servers[*n].is_enabled())
.cloned()
.collect();
for name in names {
if let Err(e) = self.get_or_connect(&name).await {
errors.push((name, e));
}
}
for (name, server_cfg) in &self.config.servers {
if server_cfg.required
&& server_cfg.is_enabled()
&& !self
.connections
.get(name)
.is_some_and(McpConnection::is_ready)
{
errors.push((
name.clone(),
anyhow::anyhow!("required MCP server failed to initialize"),
));
}
}
errors
}
/// Get all discovered tools with server-prefixed names
pub fn all_tools(&self) -> Vec<(String, &McpTool)> {
let mut tools = Vec::new();
for (server, conn) in &self.connections {
for tool in conn.tools() {
if !conn.config().is_tool_enabled(&tool.name) {
continue;
}
// Format: mcp_{server}_{tool}
tools.push((format!("mcp_{}_{}", server, tool.name), tool));
}
}
// Sort by prefixed name so iteration order across servers is
// deterministic for prefix-cache stability (#1319).
tools.sort_by(|a, b| a.0.cmp(&b.0));
tools
}
/// Get all discovered resources with server-prefixed names
pub fn all_resources(&self) -> Vec<(String, &McpResource)> {
let mut resources = Vec::new();
for (server, conn) in &self.connections {
for resource in conn.resources() {
// Format: mcp_{server}_{resource_name}
// Note: resource names might contain spaces, we should probably slugify them
let safe_name = resource.name.replace(' ', "_").to_lowercase();
resources.push((format!("mcp_{server}_{safe_name}"), resource));
}
}
resources
}
/// Get all discovered resource templates with server-prefixed names
#[allow(dead_code)] // Public API for MCP resource discovery
pub fn all_resource_templates(&self) -> Vec<(String, &McpResourceTemplate)> {
let mut templates = Vec::new();
for (server, conn) in &self.connections {
for template in conn.resource_templates() {
let safe_name = template.name.replace(' ', "_").to_lowercase();
templates.push((format!("mcp_{server}_{safe_name}"), template));
}
}
templates
}
async fn list_resources(&mut self, server: Option<String>) -> Result<Vec<serde_json::Value>> {
if let Some(server_name) = server {
let conn = self.get_or_connect(&server_name).await?;
let resources = conn
.resources()
.iter()
.map(|resource| {
serde_json::json!({
"server": server_name.clone(),
"uri": resource.uri,
"name": resource.name,
"description": resource.description,
"mime_type": resource.mime_type,
})
})
.collect();
return Ok(resources);
}
let errors = self.connect_all().await;
for (server, err) in errors {
tracing::warn!("Failed to connect MCP server '{server}' for resources: {err:#}");
}
let mut items = Vec::new();
for (server, conn) in &self.connections {
for resource in conn.resources() {
items.push(serde_json::json!({
"server": server,
"uri": resource.uri,
"name": resource.name,
"description": resource.description,
"mime_type": resource.mime_type,
}));
}
}
Ok(items)
}
async fn list_resource_templates(
&mut self,
server: Option<String>,
) -> Result<Vec<serde_json::Value>> {
if let Some(server_name) = server {
let conn = self.get_or_connect(&server_name).await?;
let templates = conn
.resource_templates()
.iter()
.map(|template| {
serde_json::json!({
"server": server_name.clone(),
"uri_template": template.uri_template,
"name": template.name,
"description": template.description,
"mime_type": template.mime_type,
})
})
.collect();
return Ok(templates);
}
let errors = self.connect_all().await;
for (server, err) in errors {
tracing::warn!(
"Failed to connect MCP server '{server}' for resource templates: {err:#}"
);
}
let mut items = Vec::new();
for (server, conn) in &self.connections {
for template in conn.resource_templates() {
items.push(serde_json::json!({
"server": server,
"uri_template": template.uri_template,
"name": template.name,
"description": template.description,
"mime_type": template.mime_type,
}));
}
}
Ok(items)
}
/// Get all discovered prompts with server-prefixed names
pub fn all_prompts(&self) -> Vec<(String, &McpPrompt)> {
let mut prompts = Vec::new();
for (server, conn) in &self.connections {
for prompt in conn.prompts() {
// Format: mcp_{server}_{prompt}
prompts.push((format!("mcp_{}_{}", server, prompt.name), prompt));
}
}
prompts
}
/// Read a resource from a specific server
pub async fn read_resource(
&mut self,
server_name: &str,
uri: &str,
) -> Result<serde_json::Value> {
let global_timeouts = self.config.timeouts;
let conn = self.get_or_connect(server_name).await?;
let timeout = conn.config().effective_read_timeout(&global_timeouts);
conn.read_resource(uri, timeout).await
}
/// Get a prompt from a specific server
pub async fn get_prompt(
&mut self,
server_name: &str,
prompt_name: &str,
arguments: serde_json::Value,
) -> Result<serde_json::Value> {
let global_timeouts = self.config.timeouts;
let conn = self.get_or_connect(server_name).await?;
let timeout = conn.config().effective_execute_timeout(&global_timeouts);
conn.get_prompt(prompt_name, arguments, timeout).await
}
/// Parse a prefixed name into (server_name, tool_name)
pub(crate) fn parse_prefixed_name<'a>(
&self,
prefixed_name: &'a str,
) -> Result<(&'a str, &'a str)> {
let Some(rest) = prefixed_name.strip_prefix("mcp_") else {
anyhow::bail!("Invalid MCP tool name: {prefixed_name}");
};
let mut best_match: Option<(&str, &str)> = None;
for server in self.connections.keys().chain(self.config.servers.keys()) {
let Some(tool) = rest
.strip_prefix(server)
.and_then(|tail| tail.strip_prefix('_'))
else {
continue;
};
if tool.is_empty() {
continue;
}
if best_match.is_none_or(|(matched, _)| server.len() > matched.len()) {
best_match = Some((&rest[..server.len()], tool));
}
}
if let Some((server, tool)) = best_match {
return Ok((server, tool));
}
let Some((server, tool)) = rest.split_once('_') else {
anyhow::bail!("Invalid MCP tool name format: {prefixed_name}");
};
Ok((server, tool))
}
/// Convert discovered tools to API Tool format
pub fn to_api_tools(&self) -> Vec<crate::models::Tool> {
let mut api_tools = Vec::new();
// Add regular tools
for (name, tool) in self.all_tools() {
api_tools.push(crate::models::Tool {
tool_type: None,
name,
description: tool.description.clone().unwrap_or_default(),
input_schema: tool.input_schema.clone(),
allowed_callers: Some(vec!["direct".to_string()]),
defer_loading: Some(false),
input_examples: None,
strict: None,
cache_control: None,
});
}
if !self.config.servers.is_empty() {
api_tools.push(crate::models::Tool {
tool_type: None,
name: "list_mcp_resources".to_string(),
description: "List available MCP resources across servers (optionally filtered by server).".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"server": { "type": "string", "description": "Optional MCP server name to filter by" }
}
}),
allowed_callers: Some(vec!["direct".to_string()]),
defer_loading: Some(false),
input_examples: None,
strict: None,
cache_control: None,
});
api_tools.push(crate::models::Tool {
tool_type: None,
name: "list_mcp_resource_templates".to_string(),
description: "List available MCP resource templates across servers (optionally filtered by server).".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"server": { "type": "string", "description": "Optional MCP server name to filter by" }
}
}),
allowed_callers: Some(vec!["direct".to_string()]),
defer_loading: Some(false),
input_examples: None,
strict: None,
cache_control: None,
});
}
// Add resource reading tools if resources exist
let resources = self.all_resources();
if !resources.is_empty() {
api_tools.push(crate::models::Tool {
tool_type: None,
name: "mcp_read_resource".to_string(),
description: "Read a resource from an MCP server using its URI".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"server": { "type": "string", "description": "The name of the MCP server" },
"uri": { "type": "string", "description": "The URI of the resource to read" }
},
"required": ["server", "uri"]
}),
allowed_callers: Some(vec!["direct".to_string()]),
defer_loading: Some(false),
input_examples: None,
strict: None,
cache_control: None,
});
api_tools.push(crate::models::Tool {
tool_type: None,
name: "read_mcp_resource".to_string(),
description: "Alias for mcp_read_resource.".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"server": { "type": "string", "description": "The name of the MCP server" },
"uri": { "type": "string", "description": "The URI of the resource to read" }
},
"required": ["server", "uri"]
}),
allowed_callers: Some(vec!["direct".to_string()]),
defer_loading: Some(false),
input_examples: None,
strict: None,
cache_control: None,
});
}
// Add prompt getting tools if prompts exist
let prompts = self.all_prompts();
if !prompts.is_empty() {
api_tools.push(crate::models::Tool {
tool_type: None,
name: "mcp_get_prompt".to_string(),
description: "Get a prompt from an MCP server".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"server": { "type": "string", "description": "The name of the MCP server" },
"name": { "type": "string", "description": "The name of the prompt" },
"arguments": {
"type": "object",
"description": "Optional arguments for the prompt",
"additionalProperties": { "type": "string" }
}
},
"required": ["server", "name"]
}),
allowed_callers: Some(vec!["direct".to_string()]),
defer_loading: Some(false),
input_examples: None,
strict: None,
cache_control: None,
});
}
// Sort by name for prefix-cache stability — the tool block sent to
// the model needs to be deterministic across runs (#1319).
api_tools.sort_by(|a, b| a.name.cmp(&b.name));
api_tools
}
/// Call a tool by its prefixed name (mcp_{server}_{tool})
pub async fn call_tool(
&mut self,
prefixed_name: &str,
arguments: serde_json::Value,
) -> Result<serde_json::Value> {
if prefixed_name == "list_mcp_resources" {
let server = arguments
.get("server")
.and_then(|v| v.as_str())
.map(str::to_string);
let resources = self.list_resources(server).await?;
return Ok(serde_json::json!({ "resources": resources }));
}
if prefixed_name == "list_mcp_resource_templates" {
let server = arguments
.get("server")
.and_then(|v| v.as_str())
.map(str::to_string);
let templates = self.list_resource_templates(server).await?;
return Ok(serde_json::json!({ "templates": templates }));
}
if prefixed_name == "mcp_read_resource" {
let server_name = arguments
.get("server")
.and_then(|v| v.as_str())
.context("Missing 'server' argument")?;
let uri = arguments
.get("uri")
.and_then(|v| v.as_str())
.context("Missing 'uri' argument")?;
return self.read_resource(server_name, uri).await;
}
if prefixed_name == "read_mcp_resource" {
let server_name = arguments
.get("server")
.and_then(|v| v.as_str())
.context("Missing 'server' argument")?;
let uri = arguments
.get("uri")
.and_then(|v| v.as_str())
.context("Missing 'uri' argument")?;
return self.read_resource(server_name, uri).await;
}
if prefixed_name == "mcp_get_prompt" {
let server_name = arguments
.get("server")
.and_then(|v| v.as_str())
.context("Missing 'server' argument")?;
let name = arguments
.get("name")
.and_then(|v| v.as_str())
.context("Missing 'name' argument")?;
let args = arguments
.get("arguments")
.cloned()
.unwrap_or(serde_json::json!({}));
return self.get_prompt(server_name, name, args).await;
}
let (server_name, tool_name) = self.parse_prefixed_name(prefixed_name)?;
// Copy the global timeouts to avoid borrow conflict
let global_timeouts = self.config.timeouts;
let conn = self.get_or_connect(server_name).await?;
if !conn.config().is_tool_enabled(tool_name) {
anyhow::bail!("MCP tool '{tool_name}' is disabled for server '{server_name}'");
}
let timeout = conn.config().effective_execute_timeout(&global_timeouts);
match conn.call_tool(tool_name, arguments.clone(), timeout).await {
Ok(result) => Ok(result),
Err(err) if is_mcp_stale_session_error(&err) => {
tracing::debug!(
target: "mcp",
server = server_name,
tool = tool_name,
error = %err,
"retrying MCP tool call after stale session"
);
self.connections.remove(server_name);
let conn = self.get_or_connect(server_name).await?;
if !conn.config().is_tool_enabled(tool_name) {
anyhow::bail!("MCP tool '{tool_name}' is disabled for server '{server_name}'");
}
let timeout = conn.config().effective_execute_timeout(&global_timeouts);
conn.call_tool(tool_name, arguments, timeout).await
}
Err(err) => Err(err),
}
}
/// Get list of configured server names
#[allow(dead_code)] // Public API for MCP consumers
pub fn server_names(&self) -> Vec<&str> {
self.config
.servers
.keys()
.map(std::string::String::as_str)
.collect()
}
/// Get list of connected server names
pub fn connected_servers(&self) -> Vec<&str> {
self.connections
.iter()
.filter(|(_, c)| c.is_ready())
.map(|(n, _)| n.as_str())
.collect()
}
/// Disconnect all connections
#[allow(dead_code)] // Public API for MCP lifecycle management
pub fn disconnect_all(&mut self) {
self.connections.clear();
}
/// Graceful shutdown of every connection in the pool: send SIGTERM to
/// each stdio child and give them a short grace period before drop
/// fires SIGKILL. Whalescale#420.
///
/// Call from the TUI exit path *before* dropping the pool to give
/// MCP servers a chance to flush state. The fallback Drop on
/// `StdioTransport` still sends SIGTERM if this never runs, so even
/// abnormal exits avoid leaking PIDs without a signal.
#[allow(dead_code)] // Wired in by callers that want graceful shutdown
pub async fn shutdown_all(&mut self) {
let names: Vec<String> = self.connections.keys().cloned().collect();
for name in names {
if let Some(conn) = self.connections.get_mut(&name) {
conn.transport.shutdown().await;
}
}
self.connections.clear();
}
/// Get the underlying configuration
#[allow(dead_code)] // Public API for MCP consumers
pub fn config(&self) -> &McpConfig {
&self.config
}
/// Check if a tool name is an MCP tool
pub fn is_mcp_tool(name: &str) -> bool {
name.starts_with("mcp_")
|| matches!(
name,
"list_mcp_resources" | "list_mcp_resource_templates" | "read_mcp_resource"
)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum McpWriteStatus {
Created,
Overwritten,
SkippedExists,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpDiscoveredItem {
pub name: String,
pub model_name: String,
pub description: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpServerSnapshot {
pub name: String,
pub enabled: bool,
pub required: bool,
pub transport: String,
pub command_or_url: String,
pub connect_timeout: u64,
pub execute_timeout: u64,
pub read_timeout: u64,
pub connected: bool,
pub error: Option<String>,
pub tools: Vec<McpDiscoveredItem>,
pub resources: Vec<McpDiscoveredItem>,
pub prompts: Vec<McpDiscoveredItem>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpManagerSnapshot {
pub config_path: std::path::PathBuf,
pub config_exists: bool,
pub restart_required: bool,
pub servers: Vec<McpServerSnapshot>,
}
pub fn load_config(path: &Path) -> Result<McpConfig> {
validate_mcp_config_path(path)?;
if !path.exists() {
return Ok(McpConfig::default());
}
let contents = fs::read_to_string(path)
.with_context(|| format!("Failed to read MCP config {}", path.display()))?;
serde_json::from_str(&contents)
.with_context(|| format!("Failed to parse MCP config {}", path.display()))
}
pub fn workspace_mcp_config_path(workspace: &Path) -> PathBuf {
normalize_workspace_path(workspace)
.join(".codewhale")
.join("mcp.json")
}
pub fn load_config_with_workspace(global_path: &Path, workspace: &Path) -> Result<McpConfig> {
let mut merged = load_config(global_path)?;
let workspace = normalize_workspace_path(workspace);
let project_path = workspace_mcp_config_path(&workspace);
if !project_path.exists() || paths_refer_to_same_config(global_path, &project_path) {
return Ok(merged);
}
// Workspace-local MCP can spawn stdio servers, so it is only honored after
// the user has trusted this workspace in user-owned config. Do not accept
// project-local legacy trust markers here: a repository could carry those
// files itself and silently reintroduce the project-scope `mcp_config_path`
// risk denied in #417.
if !workspace_allows_project_mcp_config(&workspace) {
return Ok(merged);
}
let mut project = load_config(&project_path)?;
for server in project.servers.values_mut() {
if server.command.is_some() && server.url.is_none() {
let cwd = match server.cwd.as_deref() {
Some(cwd) if cwd.is_relative() => normalize_path_components(&workspace.join(cwd)),
Some(cwd) => normalize_path_components(cwd),
None => workspace.to_path_buf(),
};
if !cwd.starts_with(&workspace) {
anyhow::bail!(
"Project MCP server cwd must stay within workspace: {}",
cwd.display()
);
}
server.cwd = Some(cwd);
}
}
merged.servers.extend(project.servers);
Ok(merged)
}
fn workspace_allows_project_mcp_config(workspace: &Path) -> bool {
crate::config::is_workspace_trusted(workspace)
}
fn normalize_workspace_path(workspace: &Path) -> PathBuf {
if let Ok(canonical) = workspace.canonicalize() {
return canonical;
}
let absolute = if workspace.is_absolute() {
workspace.to_path_buf()
} else {
std::env::current_dir()
.unwrap_or_else(|_| PathBuf::from("."))
.join(workspace)
};
normalize_path_components(&absolute)
}
fn normalize_path_components(path: &Path) -> PathBuf {
let mut normalized = PathBuf::new();
for component in path.components() {
match component {
Component::Prefix(_) | Component::RootDir => {
normalized.push(component.as_os_str());
}
Component::CurDir => {}
Component::ParentDir => {
normalized.pop();
}
Component::Normal(part) => normalized.push(part),
}
}
if normalized.as_os_str().is_empty() {
PathBuf::from(".")
} else {
normalized
}
}
fn paths_refer_to_same_config(left: &Path, right: &Path) -> bool {
match (left.canonicalize(), right.canonicalize()) {
(Ok(left), Ok(right)) => left == right,
_ => normalize_workspace_path(left) == normalize_workspace_path(right),
}
}
/// 64-bit content hash of an [`McpConfig`]. Used by [`McpPool`] to decide
/// whether a freshly-read config differs from the one currently driving the
/// live connections. Hashing the JSON serialization avoids forcing every
/// nested config type to derive `Hash` (the timeouts struct, network policy
/// stubs, etc.). The hash is stable across runs of the same Rust toolchain
/// for byte-identical input.
fn hash_mcp_config(config: &McpConfig) -> u64 {
use std::hash::{Hash, Hasher};
let bytes = serde_json::to_vec(config).unwrap_or_default();
let mut hasher = std::collections::hash_map::DefaultHasher::new();
bytes.hash(&mut hasher);
hasher.finish()
}
/// Best-effort fetch of the MCP config file's last-modified time. Returns
/// `None` when the file is missing, when stat fails, when the platform
/// doesn't expose mtime, or when the path fails the same allow-list check
/// that `load_config` / `save_config` apply. The lazy-reload check in
/// `McpPool::get_or_connect` treats `None` as "skip the check this turn",
/// so a rejected path simply degrades to "no auto-reload" rather than an
/// error path. Callers already validate via `validate_mcp_config_path` at
/// construction time; the redundant validation here keeps this helper
/// safe-by-construction for any future caller and ties the validation to
/// the call site rather than relying on cross-function reasoning.
fn mcp_config_mtime(path: &Path) -> Option<std::time::SystemTime> {
validate_mcp_config_path(path).ok()?;
fs::metadata(path).ok()?.modified().ok()
}
pub fn save_config(path: &Path, cfg: &McpConfig) -> Result<()> {
validate_mcp_config_path(path)?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).with_context(|| {
format!("Failed to create MCP config directory {}", parent.display())
})?;
}
let rendered = serde_json::to_string_pretty(cfg).context("Failed to serialize MCP config")?;
write_atomic(path, rendered.as_bytes())
.with_context(|| format!("Failed to write MCP config {}", path.display()))?;
Ok(())
}
fn mcp_template_json() -> Result<String> {
let mut cfg = McpConfig::default();
cfg.servers.insert(
"example".to_string(),
McpServerConfig {
command: Some("node".to_string()),
args: vec!["./path/to/your-mcp-server.js".to_string()],
env: HashMap::new(),
cwd: None,
url: None,
transport: None,
connect_timeout: None,
execute_timeout: None,
read_timeout: None,
disabled: true,
enabled: true,
required: false,
enabled_tools: Vec::new(),
disabled_tools: Vec::new(),
headers: HashMap::new(),
},
);
serde_json::to_string_pretty(&cfg).context("Failed to render MCP template JSON")
}
pub fn init_config(path: &Path, force: bool) -> Result<McpWriteStatus> {
if path.exists() && !force {
return Ok(McpWriteStatus::SkippedExists);
}
let status = if path.exists() {
McpWriteStatus::Overwritten
} else {
McpWriteStatus::Created
};
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).with_context(|| {
format!("Failed to create MCP config directory {}", parent.display())
})?;
}
let template = mcp_template_json()?;
write_atomic(path, template.as_bytes())
.with_context(|| format!("Failed to write MCP config {}", path.display()))?;
Ok(status)
}
pub fn add_server_config(
path: &Path,
name: String,
command: Option<String>,
url: Option<String>,
args: Vec<String>,
transport: Option<String>,
) -> Result<()> {
if command.is_none() && url.is_none() {
anyhow::bail!("Provide either a command or URL for MCP server '{name}'.");
}
validate_mcp_transport(transport.as_deref())?;
let mut cfg = load_config(path)?;
cfg.servers.insert(
name,
McpServerConfig {
command,
args,
env: HashMap::new(),
cwd: None,
url,
transport,
connect_timeout: None,
execute_timeout: None,
read_timeout: None,
disabled: false,
enabled: true,
required: false,
enabled_tools: Vec::new(),
disabled_tools: Vec::new(),
headers: HashMap::new(),
},
);
save_config(path, &cfg)
}
pub fn remove_server_config(path: &Path, name: &str) -> Result<()> {
let mut cfg = load_config(path)?;
if cfg.servers.remove(name).is_none() {
anyhow::bail!("MCP server '{name}' not found");
}
save_config(path, &cfg)
}
pub fn set_server_enabled(path: &Path, name: &str, enabled: bool) -> Result<()> {
let mut cfg = load_config(path)?;
let server = cfg
.servers
.get_mut(name)
.ok_or_else(|| anyhow::anyhow!("MCP server '{name}' not found"))?;
server.enabled = enabled;
server.disabled = !enabled;
save_config(path, &cfg)
}
#[cfg(test)]
pub fn manager_snapshot_from_config(
path: &Path,
restart_required: bool,
) -> Result<McpManagerSnapshot> {
let cfg = load_config(path)?;
Ok(snapshot_from_config(
path,
path.exists(),
restart_required,
&cfg,
None,
))
}
pub fn manager_snapshot_from_config_with_workspace(
path: &Path,
workspace: &Path,
restart_required: bool,
) -> Result<McpManagerSnapshot> {
let cfg = load_config_with_workspace(path, workspace)?;
Ok(snapshot_from_config(
path,
path.exists(),
restart_required,
&cfg,
None,
))
}
#[cfg(test)]
pub async fn discover_manager_snapshot(
path: &Path,
network_policy: Option<NetworkPolicyDecider>,
restart_required: bool,
) -> Result<McpManagerSnapshot> {
let cfg = load_config(path)?;
let mut pool = McpPool::new(cfg.clone());
if let Some(policy) = network_policy {
pool = pool.with_network_policy(policy);
}
let errors = pool
.connect_all()
.await
.into_iter()
.map(|(name, err)| (name, format!("{err:#}")))
.collect::<HashMap<_, _>>();
Ok(snapshot_from_config(
path,
path.exists(),
restart_required,
&cfg,
Some((&pool, &errors)),
))
}
pub async fn discover_manager_snapshot_with_workspace(
path: &Path,
workspace: &Path,
network_policy: Option<NetworkPolicyDecider>,
restart_required: bool,
) -> Result<McpManagerSnapshot> {
let cfg = load_config_with_workspace(path, workspace)?;
let mut pool = McpPool::new(cfg.clone());
if let Some(policy) = network_policy {
pool = pool.with_network_policy(policy);
}
let errors = pool
.connect_all()
.await
.into_iter()
.map(|(name, err)| (name, format!("{err:#}")))
.collect::<HashMap<_, _>>();
Ok(snapshot_from_config(
path,
path.exists(),
restart_required,
&cfg,
Some((&pool, &errors)),
))
}
fn snapshot_from_config(
path: &Path,
config_exists: bool,
restart_required: bool,
cfg: &McpConfig,
discovery: Option<(&McpPool, &HashMap<String, String>)>,
) -> McpManagerSnapshot {
let mut servers = cfg
.servers
.iter()
.map(|(name, server)| {
let transport = if server.url.is_some() {
if is_legacy_sse_transport(server) {
"sse"
} else {
"http/sse"
}
} else {
"stdio"
};
let command_or_url = server.url.clone().unwrap_or_else(|| {
let mut command = server
.command
.clone()
.unwrap_or_else(|| "(missing)".to_string());
if !server.args.is_empty() {
command.push(' ');
command.push_str(&server.args.join(" "));
}
command
});
let mut snapshot = McpServerSnapshot {
name: name.clone(),
enabled: server.is_enabled(),
required: server.required,
transport: transport.to_string(),
command_or_url,
connect_timeout: server.effective_connect_timeout(&cfg.timeouts),
execute_timeout: server.effective_execute_timeout(&cfg.timeouts),
read_timeout: server.effective_read_timeout(&cfg.timeouts),
connected: false,
error: if server.is_enabled() {
None
} else {
Some("disabled".to_string())
},
tools: Vec::new(),
resources: Vec::new(),
prompts: Vec::new(),
};
if let Some((pool, errors)) = discovery {
if let Some(error) = errors.get(name) {
snapshot.error = Some(error.clone());
}
if let Some(conn) = pool.connections.get(name) {
snapshot.connected = conn.is_ready();
snapshot.tools = conn
.tools()
.iter()
.filter(|tool| conn.config().is_tool_enabled(&tool.name))
.map(|tool| McpDiscoveredItem {
name: tool.name.clone(),
model_name: format!("mcp_{}_{}", name, tool.name),
description: tool.description.clone(),
})
.collect();
snapshot.resources =
conn.resources()
.iter()
.map(|resource| McpDiscoveredItem {
name: resource.name.clone(),
model_name: format!(
"mcp_{}_{}",
name,
resource.name.replace(' ', "_").to_lowercase()
),
description: resource.description.clone(),
})
.chain(conn.resource_templates().iter().map(|template| {
McpDiscoveredItem {
name: template.name.clone(),
model_name: format!(
"mcp_{}_{}",
name,
template.name.replace(' ', "_").to_lowercase()
),
description: template.description.clone(),
}
}))
.collect();
snapshot.prompts = conn
.prompts()
.iter()
.map(|prompt| McpDiscoveredItem {
name: prompt.name.clone(),
model_name: format!("mcp_{}_{}", name, prompt.name),
description: prompt.description.clone(),
})
.collect();
}
}
snapshot
})
.collect::<Vec<_>>();
servers.sort_by(|a, b| a.name.cmp(&b.name));
McpManagerSnapshot {
config_path: path.to_path_buf(),
config_exists,
restart_required,
servers,
}
}
// === Helper Functions ===
/// Format MCP tool result for display
#[allow(dead_code)] // Will be used when MCP tool results are displayed in TUI
pub fn format_tool_result(result: &serde_json::Value) -> String {
let is_error = result
.get("isError")
.and_then(serde_json::Value::as_bool)
.unwrap_or(false);
let content = result
.get("content")
.and_then(|v| v.as_array())
.map_or_else(
|| serde_json::to_string_pretty(result).unwrap_or_default(),
|arr| {
arr.iter()
.filter_map(|item| match item.get("type")?.as_str()? {
"text" => item.get("text")?.as_str().map(String::from),
other => Some(format!("[{other} content]")),
})
.collect::<Vec<_>>()
.join("\n")
},
);
if is_error {
format!("Error: {content}")
} else {
content
}
}
// === Unit Tests ===
#[cfg(test)]
mod tests {
use super::*;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
use std::sync::{Arc, Mutex, OnceLock};
fn test_http_client() -> reqwest::Client {
let _ = rustls::crypto::ring::default_provider().install_default();
crate::tls::reqwest_client()
}
async fn lock_mcp_loopback_tests() -> tokio::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<tokio::sync::Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| tokio::sync::Mutex::new(()))
.lock()
.await
}
struct WorkspaceTrustConfigGuard {
config_path: PathBuf,
_codewhale_config_path: crate::test_support::EnvVarGuard,
_deepseek_config_path: crate::test_support::EnvVarGuard,
_env_lock: std::sync::MutexGuard<'static, ()>,
}
fn workspace_trust_config_guard(workspace: &Path) -> WorkspaceTrustConfigGuard {
let env_lock = crate::test_support::lock_test_env();
let config_path = workspace
.parent()
.unwrap_or(workspace)
.join("user-config")
.join("config.toml");
if let Some(parent) = config_path.parent() {
fs::create_dir_all(parent).unwrap();
}
let codewhale_config_path =
crate::test_support::EnvVarGuard::set("CODEWHALE_CONFIG_PATH", config_path.as_os_str());
let deepseek_config_path = crate::test_support::EnvVarGuard::remove("DEEPSEEK_CONFIG_PATH");
WorkspaceTrustConfigGuard {
config_path,
_codewhale_config_path: codewhale_config_path,
_deepseek_config_path: deepseek_config_path,
_env_lock: env_lock,
}
}
fn write_workspace_trust_config(config_path: &Path, workspace: &Path) {
let workspace = workspace
.canonicalize()
.unwrap_or_else(|_| workspace.to_path_buf());
let key = workspace
.to_string_lossy()
.replace('\\', "\\\\")
.replace('"', "\\\"");
fs::write(
config_path,
format!("[projects.\"{key}\"]\ntrust_level = \"trusted\"\n"),
)
.unwrap();
}
fn mark_workspace_trusted(workspace: &Path) -> WorkspaceTrustConfigGuard {
let guard = workspace_trust_config_guard(workspace);
write_workspace_trust_config(&guard.config_path, workspace);
guard
}
#[test]
fn test_mcp_config_defaults() {
let config = McpConfig::default();
assert_eq!(config.timeouts.connect_timeout, 10);
assert_eq!(config.timeouts.execute_timeout, 60);
assert_eq!(config.timeouts.read_timeout, 120);
assert!(config.servers.is_empty());
}
#[test]
fn test_mcp_config_parse() {
let json = r#"{
"timeouts": {
"connect_timeout": 15,
"execute_timeout": 90
},
"servers": {
"test": {
"command": "node",
"args": ["server.js"],
"env": {"FOO": "bar"}
}
}
}"#;
let config: McpConfig = serde_json::from_str(json).unwrap();
assert_eq!(config.timeouts.connect_timeout, 15);
assert_eq!(config.timeouts.execute_timeout, 90);
assert_eq!(config.timeouts.read_timeout, 120); // default
assert!(config.servers.contains_key("test"));
let server = config.servers.get("test").unwrap();
assert_eq!(server.command, Some("node".to_string()));
assert_eq!(server.args, vec!["server.js"]);
assert_eq!(server.env.get("FOO"), Some(&"bar".to_string()));
}
#[test]
fn mcp_pool_parse_prefixed_name_preserves_registered_underscored_server() {
let config: McpConfig = serde_json::from_str(
r#"{
"servers": {
"my": {"command": "node"},
"my_db": {"command": "node"}
}
}"#,
)
.unwrap();
let pool = McpPool::new(config);
let (server, tool) = pool
.parse_prefixed_name("mcp_my_db_execute_sql")
.expect("registered underscored server should parse");
assert_eq!(server, "my_db");
assert_eq!(tool, "execute_sql");
}
#[test]
fn mcp_server_config_parses_custom_headers() {
let json = r#"{
"servers": {
"hf": {
"url": "https://example.invalid/mcp",
"headers": {
"Authorization": "Bearer tok",
"X-Org": "anthropic"
}
}
}
}"#;
let cfg: McpConfig = serde_json::from_str(json).unwrap();
let hf = cfg.servers.get("hf").expect("server present");
assert_eq!(
hf.headers.get("Authorization"),
Some(&"Bearer tok".to_string())
);
assert_eq!(hf.headers.get("X-Org"), Some(&"anthropic".to_string()));
}
#[test]
fn mcp_server_config_omits_headers_when_empty() {
// Empty headers map should not appear in the serialized output —
// older mcp.json files written before v0.8.31 must round-trip
// unchanged so a `mcp save` from a fresh install doesn't add
// dead keys.
let cfg = McpServerConfig {
command: Some("node".into()),
args: vec!["server.js".into()],
env: HashMap::new(),
cwd: None,
url: None,
transport: None,
connect_timeout: None,
execute_timeout: None,
read_timeout: None,
disabled: false,
enabled: true,
required: false,
enabled_tools: Vec::new(),
disabled_tools: Vec::new(),
headers: HashMap::new(),
};
let serialized = serde_json::to_string(&cfg).unwrap();
assert!(
!serialized.contains("\"headers\""),
"empty headers must be omitted: {serialized}"
);
}
#[test]
fn is_safe_custom_header_accepts_normal_auth_pairs() {
assert!(is_safe_custom_header("Authorization", "Bearer tok"));
assert!(is_safe_custom_header("X-Api-Key", "deadbeef"));
assert!(is_safe_custom_header("x-org", "anthropic"));
}
#[test]
fn is_safe_custom_header_rejects_empty_or_whitespace_key() {
assert!(!is_safe_custom_header("", "value"));
assert!(!is_safe_custom_header(" ", "value"));
}
#[test]
fn is_safe_custom_header_rejects_response_splitting_values() {
assert!(
!is_safe_custom_header("X-Foo", "abc\r\nSet-Cookie: evil=1"),
"CRLF in value must reject — response-splitting defense"
);
assert!(
!is_safe_custom_header("X-Foo", "abc\nbar"),
"bare LF in value must reject"
);
assert!(
!is_safe_custom_header("X-Foo", "abc\rbar"),
"bare CR in value must reject"
);
}
#[test]
fn is_safe_custom_header_rejects_protocol_framing_overrides() {
// The MCP Streamable HTTP transport relies on its own
// Accept / Content-Type values for protocol negotiation;
// a stray user override would silently break tool discovery.
assert!(!is_safe_custom_header("Accept", "text/plain"));
assert!(!is_safe_custom_header("accept", "text/plain"));
assert!(!is_safe_custom_header("Content-Type", "text/plain"));
assert!(!is_safe_custom_header("CONTENT-TYPE", "x/y"));
}
#[test]
fn default_mcp_http_get_accepts_json_and_event_stream() {
let client = test_http_client();
let request =
with_default_mcp_http_headers(client.get("https://example.invalid/mcp"), false)
.build()
.unwrap();
assert_eq!(
request.headers().get(ACCEPT).and_then(|v| v.to_str().ok()),
Some(MCP_HTTP_ACCEPT)
);
assert!(
request.headers().get(CONTENT_TYPE).is_none(),
"SSE GET requests should not advertise a JSON request body"
);
}
#[test]
fn default_mcp_http_post_accepts_json_and_event_stream() {
let client = test_http_client();
let request =
with_default_mcp_http_headers(client.post("https://example.invalid/mcp"), true)
.build()
.unwrap();
assert_eq!(
request.headers().get(ACCEPT).and_then(|v| v.to_str().ok()),
Some(MCP_HTTP_ACCEPT)
);
assert_eq!(
request
.headers()
.get(CONTENT_TYPE)
.and_then(|v| v.to_str().ok()),
Some("application/json")
);
}
#[test]
fn streamable_http_transport_stores_headers() {
let client = test_http_client();
let mut headers = HashMap::new();
headers.insert("Authorization".to_string(), "Bearer xyz".to_string());
let transport = StreamableHttpTransport::new(
client,
"https://example.invalid/mcp".to_string(),
headers.clone(),
);
assert_eq!(transport.headers, headers);
}
#[test]
fn test_mcp_config_parse_mcp_servers_alias_and_snapshot() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("mcp.json");
fs::write(
&path,
r#"{
"mcpServers": {
"disabled": {
"command": "node",
"args": ["server.js"],
"disabled": true
}
}
}"#,
)
.unwrap();
let cfg = load_config(&path).unwrap();
assert!(cfg.servers.contains_key("disabled"));
let snapshot = manager_snapshot_from_config(&path, true).unwrap();
assert!(snapshot.restart_required);
assert_eq!(snapshot.servers[0].name, "disabled");
assert!(!snapshot.servers[0].enabled);
assert_eq!(snapshot.servers[0].error.as_deref(), Some("disabled"));
}
#[test]
fn workspace_mcp_config_merges_with_project_overrides() {
let dir = tempfile::tempdir().unwrap();
let global_path = dir.path().join("global-mcp.json");
let workspace = dir.path().join("workspace");
let project_dir = workspace.join(".codewhale");
fs::create_dir_all(&project_dir).unwrap();
let _trust = mark_workspace_trusted(&workspace);
fs::write(
&global_path,
r#"{
"servers": {
"global": {"command": "node", "args": ["global.js"]},
"shared": {"command": "node", "args": ["global-shared.js"]}
}
}"#,
)
.unwrap();
fs::write(
project_dir.join("mcp.json"),
r#"{
"servers": {
"project": {"command": "php", "args": ["artisan", "boost:mcp"]},
"shared": {"command": "php", "args": ["artisan", "shared:mcp"]}
}
}"#,
)
.unwrap();
let cfg = load_config_with_workspace(&global_path, &workspace).unwrap();
let workspace = workspace.canonicalize().unwrap();
assert!(cfg.servers.contains_key("global"));
let project = cfg.servers.get("project").unwrap();
assert_eq!(project.command.as_deref(), Some("php"));
assert_eq!(project.cwd.as_deref(), Some(workspace.as_path()));
let shared = cfg.servers.get("shared").unwrap();
assert_eq!(shared.args, vec!["artisan", "shared:mcp"]);
assert_eq!(shared.cwd.as_deref(), Some(workspace.as_path()));
}
#[test]
fn workspace_manager_snapshot_counts_global_and_project_servers() {
let dir = tempfile::tempdir().unwrap();
let global_path = dir.path().join("global-mcp.json");
let workspace = dir.path().join("workspace");
let project_dir = workspace.join(".codewhale");
fs::create_dir_all(&project_dir).unwrap();
let _trust = mark_workspace_trusted(&workspace);
fs::write(
&global_path,
r#"{
"servers": {
"chrome-devtools": {"command": "npx", "args": ["-y", "chrome-devtools-mcp@latest"]},
"context7": {"command": "npx", "args": ["-y", "@upstash/context7-mcp@latest"]}
}
}"#,
)
.unwrap();
fs::write(
project_dir.join("mcp.json"),
r#"{
"servers": {
"laravel-boost": {"command": "php", "args": ["artisan", "boost:mcp"]}
}
}"#,
)
.unwrap();
let plain = manager_snapshot_from_config(&global_path, false).unwrap();
let merged =
manager_snapshot_from_config_with_workspace(&global_path, &workspace, false).unwrap();
assert_eq!(plain.servers.len(), 2);
assert_eq!(merged.servers.len(), 3);
assert!(
merged
.servers
.iter()
.any(|server| server.name == "laravel-boost"),
"workspace-aware snapshots must include trusted project MCP servers"
);
}
#[test]
fn workspace_mcp_config_ignores_project_file_until_workspace_trusted() {
let dir = tempfile::tempdir().unwrap();
let global_path = dir.path().join("global-mcp.json");
let workspace = dir.path().join("workspace");
let project_dir = workspace.join(".codewhale");
fs::create_dir_all(&project_dir).unwrap();
fs::write(
&global_path,
r#"{"servers": {"global": {"command": "node", "args": ["global.js"]}}}"#,
)
.unwrap();
fs::write(
project_dir.join("mcp.json"),
r#"{"servers": {"project": {"command": "php", "args": ["artisan", "boost:mcp"]}}}"#,
)
.unwrap();
let cfg = load_config_with_workspace(&global_path, &workspace).unwrap();
assert!(cfg.servers.contains_key("global"));
assert!(!cfg.servers.contains_key("project"));
}
#[test]
fn workspace_mcp_config_ignores_project_local_legacy_trust_marker() {
let dir = tempfile::tempdir().unwrap();
let global_path = dir.path().join("global-mcp.json");
let workspace = dir.path().join("workspace");
let project_dir = workspace.join(".codewhale");
fs::create_dir_all(&project_dir).unwrap();
fs::create_dir_all(workspace.join(".deepseek")).unwrap();
fs::write(workspace.join(".deepseek").join("trusted"), "").unwrap();
fs::write(
&global_path,
r#"{"servers": {"global": {"command": "node", "args": ["global.js"]}}}"#,
)
.unwrap();
fs::write(
project_dir.join("mcp.json"),
r#"{"servers": {"project": {"command": "php", "args": ["artisan", "boost:mcp"]}}}"#,
)
.unwrap();
let cfg = load_config_with_workspace(&global_path, &workspace).unwrap();
assert!(cfg.servers.contains_key("global"));
assert!(!cfg.servers.contains_key("project"));
}
#[test]
fn workspace_mcp_config_ignores_invalid_untrusted_project_file() {
let dir = tempfile::tempdir().unwrap();
let global_path = dir.path().join("global-mcp.json");
let workspace = dir.path().join("workspace");
let project_dir = workspace.join(".codewhale");
fs::create_dir_all(&project_dir).unwrap();
fs::write(&global_path, r#"{"servers": {}}"#).unwrap();
fs::write(project_dir.join("mcp.json"), "{ not json").unwrap();
let cfg = load_config_with_workspace(&global_path, &workspace).unwrap();
assert!(cfg.servers.is_empty());
}
#[test]
fn workspace_mcp_config_normalizes_parent_components() {
let dir = tempfile::tempdir().unwrap();
let global_path = dir.path().join("global-mcp.json");
let workspace = dir.path().join("workspace");
let project_dir = workspace.join(".codewhale");
fs::create_dir_all(&project_dir).unwrap();
let _trust = mark_workspace_trusted(&workspace);
fs::write(&global_path, r#"{"servers": {}}"#).unwrap();
fs::write(
project_dir.join("mcp.json"),
r#"{"servers": {"project": {"command": "node", "args": ["server.js"]}}}"#,
)
.unwrap();
let workspace_with_parent = workspace.join("..").join("workspace");
let cfg = load_config_with_workspace(&global_path, &workspace_with_parent).unwrap();
let workspace = workspace.canonicalize().unwrap();
assert!(cfg.servers.contains_key("project"));
let project = cfg.servers.get("project").unwrap();
assert_eq!(project.cwd.as_deref(), Some(workspace.as_path()));
}
#[test]
fn workspace_mcp_config_resolves_relative_cwd_from_workspace() {
let dir = tempfile::tempdir().unwrap();
let global_path = dir.path().join("global-mcp.json");
let workspace = dir.path().join("workspace");
let project_dir = workspace.join(".codewhale");
fs::create_dir_all(&project_dir).unwrap();
let _trust = mark_workspace_trusted(&workspace);
fs::write(&global_path, r#"{"servers": {}}"#).unwrap();
fs::write(
project_dir.join("mcp.json"),
r#"{"servers": {"project": {"command": "node", "args": ["server.js"], "cwd": "tools/mcp"}}}"#,
)
.unwrap();
let cfg = load_config_with_workspace(&global_path, &workspace).unwrap();
let workspace = workspace.canonicalize().unwrap();
let project = cfg.servers.get("project").unwrap();
assert_eq!(
project.cwd.as_deref(),
Some(workspace.join("tools/mcp").as_path())
);
}
#[test]
fn workspace_mcp_config_rejects_project_cwd_escape() {
let dir = tempfile::tempdir().unwrap();
let global_path = dir.path().join("global-mcp.json");
let workspace = dir.path().join("workspace");
let project_dir = workspace.join(".codewhale");
fs::create_dir_all(&project_dir).unwrap();
let _trust = mark_workspace_trusted(&workspace);
fs::write(&global_path, r#"{"servers": {}}"#).unwrap();
fs::write(
project_dir.join("mcp.json"),
r#"{"servers": {"project": {"command": "node", "args": ["server.js"], "cwd": "../outside"}}}"#,
)
.unwrap();
let err = load_config_with_workspace(&global_path, &workspace)
.expect_err("project MCP cwd escape must be rejected");
assert!(
err.to_string()
.contains("Project MCP server cwd must stay within workspace"),
"unexpected error: {err}"
);
}
#[tokio::test]
async fn workspace_mcp_pool_reload_picks_up_project_config_creation() {
let dir = tempfile::tempdir().unwrap();
let global_path = dir.path().join("global-mcp.json");
let workspace = dir.path().join("workspace");
let project_dir = workspace.join(".codewhale");
fs::create_dir_all(&workspace).unwrap();
let _trust = mark_workspace_trusted(&workspace);
fs::write(
&global_path,
r#"{"servers": {"global": {"command": "node", "args": ["global.js"]}}}"#,
)
.unwrap();
let mut pool = McpPool::from_config_path_with_workspace(&global_path, &workspace).unwrap();
assert_eq!(pool.server_names(), vec!["global"]);
fs::create_dir_all(&project_dir).unwrap();
fs::write(
project_dir.join("mcp.json"),
r#"{"servers": {"project": {"command": "php", "args": ["artisan", "boost:mcp"]}}}"#,
)
.unwrap();
assert!(pool.reload_if_config_changed().await.unwrap());
let names: std::collections::BTreeSet<_> = pool.server_names().into_iter().collect();
let expected: std::collections::BTreeSet<_> = ["global", "project"].into_iter().collect();
assert_eq!(names, expected);
}
#[tokio::test]
async fn workspace_mcp_pool_reload_picks_up_project_config_after_workspace_trust() {
let dir = tempfile::tempdir().unwrap();
let global_path = dir.path().join("global-mcp.json");
let workspace = dir.path().join("workspace");
let project_dir = workspace.join(".codewhale");
fs::create_dir_all(&project_dir).unwrap();
let trust_env = workspace_trust_config_guard(&workspace);
fs::write(
&global_path,
r#"{"servers": {"global": {"command": "node", "args": ["global.js"]}}}"#,
)
.unwrap();
fs::write(
project_dir.join("mcp.json"),
r#"{"servers": {"project": {"command": "php", "args": ["artisan", "boost:mcp"]}}}"#,
)
.unwrap();
let mut pool = McpPool::from_config_path_with_workspace(&global_path, &workspace).unwrap();
assert_eq!(pool.server_names(), vec!["global"]);
write_workspace_trust_config(&trust_env.config_path, &workspace);
assert!(pool.reload_if_config_changed().await.unwrap());
let names: std::collections::BTreeSet<_> = pool.server_names().into_iter().collect();
let expected: std::collections::BTreeSet<_> = ["global", "project"].into_iter().collect();
assert_eq!(names, expected);
}
#[tokio::test]
async fn workspace_mcp_pool_reload_drops_project_config_after_workspace_trust_removed() {
let dir = tempfile::tempdir().unwrap();
let global_path = dir.path().join("global-mcp.json");
let workspace = dir.path().join("workspace");
let project_dir = workspace.join(".codewhale");
fs::create_dir_all(&project_dir).unwrap();
let trust = mark_workspace_trusted(&workspace);
fs::write(
&global_path,
r#"{"servers": {"global": {"command": "node", "args": ["global.js"]}}}"#,
)
.unwrap();
fs::write(
project_dir.join("mcp.json"),
r#"{"servers": {"project": {"command": "php", "args": ["artisan", "boost:mcp"]}}}"#,
)
.unwrap();
let mut pool = McpPool::from_config_path_with_workspace(&global_path, &workspace).unwrap();
let names: std::collections::BTreeSet<_> = pool.server_names().into_iter().collect();
let expected: std::collections::BTreeSet<_> = ["global", "project"].into_iter().collect();
assert_eq!(names, expected);
fs::remove_file(&trust.config_path).unwrap();
assert!(pool.reload_if_config_changed().await.unwrap());
assert_eq!(pool.server_names(), vec!["global"]);
}
#[tokio::test]
async fn workspace_mcp_pool_reload_drops_project_config_after_deletion() {
let dir = tempfile::tempdir().unwrap();
let global_path = dir.path().join("global-mcp.json");
let workspace = dir.path().join("workspace");
let project_dir = workspace.join(".codewhale");
fs::create_dir_all(&project_dir).unwrap();
let _trust = mark_workspace_trusted(&workspace);
fs::write(
&global_path,
r#"{"servers": {"global": {"command": "node", "args": ["global.js"]}}}"#,
)
.unwrap();
let project_path = project_dir.join("mcp.json");
fs::write(
&project_path,
r#"{"servers": {"project": {"command": "php", "args": ["artisan", "boost:mcp"]}}}"#,
)
.unwrap();
let mut pool = McpPool::from_config_path_with_workspace(&global_path, &workspace).unwrap();
let names: std::collections::BTreeSet<_> = pool.server_names().into_iter().collect();
let expected: std::collections::BTreeSet<_> = ["global", "project"].into_iter().collect();
assert_eq!(names, expected);
fs::remove_file(project_path).unwrap();
assert!(pool.reload_if_config_changed().await.unwrap());
assert_eq!(pool.server_names(), vec!["global"]);
}
#[test]
fn test_mcp_config_rejects_traversal_path() {
let err = load_config(Path::new("../mcp.json")).expect_err("traversal path should fail");
assert!(
format!("{err:#}").contains("cannot contain '..'"),
"got: {err:#}"
);
}
#[test]
fn test_mcp_config_manager_actions_round_trip() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("mcp.json");
assert_eq!(init_config(&path, false).unwrap(), McpWriteStatus::Created);
assert_eq!(
init_config(&path, false).unwrap(),
McpWriteStatus::SkippedExists
);
add_server_config(
&path,
"local".to_string(),
Some("node".to_string()),
None,
vec!["server.js".to_string()],
None,
)
.unwrap();
set_server_enabled(&path, "local", false).unwrap();
let disabled = manager_snapshot_from_config(&path, true).unwrap();
let local = disabled
.servers
.iter()
.find(|server| server.name == "local")
.unwrap();
assert!(!local.enabled);
assert_eq!(local.transport, "stdio");
remove_server_config(&path, "local").unwrap();
let removed = manager_snapshot_from_config(&path, true).unwrap();
assert!(removed.servers.iter().all(|server| server.name != "local"));
}
#[test]
fn test_mcp_config_adds_explicit_sse_transport() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("mcp.json");
add_server_config(
&path,
"legacy".to_string(),
None,
Some("https://example.com/v1/mcp/sse".to_string()),
Vec::new(),
Some("sse".to_string()),
)
.unwrap();
let cfg = load_config(&path).unwrap();
assert_eq!(
cfg.servers
.get("legacy")
.and_then(|server| server.transport.as_deref()),
Some("sse")
);
let snapshot = manager_snapshot_from_config(&path, false).unwrap();
assert_eq!(snapshot.servers[0].transport, "sse");
}
#[test]
fn test_mcp_config_rejects_unknown_transport() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("mcp.json");
let err = add_server_config(
&path,
"bad".to_string(),
None,
Some("https://example.com/mcp".to_string()),
Vec::new(),
Some("streamable".to_string()),
)
.expect_err("unknown transport should fail");
assert!(
format!("{err:#}").contains("Unsupported MCP transport"),
"got: {err:#}"
);
}
#[test]
fn test_server_effective_timeouts() {
let global = McpTimeouts::default();
let server_with_override = McpServerConfig {
command: Some("test".to_string()),
args: vec![],
env: HashMap::new(),
cwd: None,
url: None,
transport: None,
connect_timeout: Some(20),
execute_timeout: None,
read_timeout: Some(180),
disabled: false,
enabled: true,
required: false,
enabled_tools: Vec::new(),
disabled_tools: Vec::new(),
headers: HashMap::new(),
};
assert_eq!(server_with_override.effective_connect_timeout(&global), 20);
assert_eq!(server_with_override.effective_execute_timeout(&global), 60); // global default
assert_eq!(server_with_override.effective_read_timeout(&global), 180);
}
#[test]
fn test_mcp_pool_is_mcp_tool() {
assert!(McpPool::is_mcp_tool("mcp_filesystem_read"));
assert!(McpPool::is_mcp_tool("mcp_git_status"));
assert!(McpPool::is_mcp_tool("list_mcp_resources"));
assert!(McpPool::is_mcp_tool("list_mcp_resource_templates"));
assert!(McpPool::is_mcp_tool("read_mcp_resource"));
assert!(!McpPool::is_mcp_tool("read_file"));
assert!(!McpPool::is_mcp_tool("exec_shell"));
}
#[test]
fn test_format_tool_result_text() {
let result = serde_json::json!({
"content": [
{"type": "text", "text": "Hello, world!"}
]
});
assert_eq!(format_tool_result(&result), "Hello, world!");
}
#[test]
fn test_format_tool_result_error() {
let result = serde_json::json!({
"isError": true,
"content": [
{"type": "text", "text": "Something went wrong"}
]
});
assert_eq!(format_tool_result(&result), "Error: Something went wrong");
}
#[test]
fn test_format_tool_result_multiple_content() {
let result = serde_json::json!({
"content": [
{"type": "text", "text": "Line 1"},
{"type": "text", "text": "Line 2"},
{"type": "image", "data": "base64..."}
]
});
let formatted = format_tool_result(&result);
assert!(formatted.contains("Line 1"));
assert!(formatted.contains("Line 2"));
assert!(formatted.contains("[image content]"));
}
struct ScriptedValueTransport {
sent: Arc<Mutex<Vec<serde_json::Value>>>,
responses: VecDeque<Vec<u8>>,
}
#[async_trait::async_trait]
impl McpTransport for ScriptedValueTransport {
async fn send(&mut self, msg: Vec<u8>) -> Result<()> {
self.sent
.lock()
.unwrap()
.push(serde_json::from_slice(&msg)?);
Ok(())
}
async fn recv(&mut self) -> Result<Vec<u8>> {
self.responses
.pop_front()
.context("scripted transport exhausted")
}
}
struct HangingValueTransport {
sent: Arc<Mutex<Vec<serde_json::Value>>>,
}
#[async_trait::async_trait]
impl McpTransport for HangingValueTransport {
async fn send(&mut self, msg: Vec<u8>) -> Result<()> {
self.sent
.lock()
.unwrap()
.push(serde_json::from_slice(&msg)?);
Ok(())
}
async fn recv(&mut self) -> Result<Vec<u8>> {
std::future::pending().await
}
}
fn test_server_config() -> McpServerConfig {
McpServerConfig {
command: Some("mock".to_string()),
args: Vec::new(),
env: HashMap::new(),
cwd: None,
url: None,
transport: None,
connect_timeout: None,
execute_timeout: None,
read_timeout: None,
disabled: false,
enabled: true,
required: false,
enabled_tools: Vec::new(),
disabled_tools: Vec::new(),
headers: HashMap::new(),
}
}
fn test_connection(transport: Box<dyn McpTransport>) -> McpConnection {
McpConnection {
name: "mock".to_string(),
transport,
tools: Vec::new(),
resources: Vec::new(),
resource_templates: Vec::new(),
prompts: Vec::new(),
request_id: AtomicU64::new(1),
state: ConnectionState::Ready,
config: test_server_config(),
cancel_token: tokio_util::sync::CancellationToken::new(),
}
}
fn json_frame(value: serde_json::Value) -> Vec<u8> {
serde_json::to_vec(&value).unwrap()
}
#[tokio::test]
async fn call_method_skips_notifications_and_unmatched_responses() {
let sent = Arc::new(Mutex::new(Vec::new()));
let transport = ScriptedValueTransport {
sent: Arc::clone(&sent),
responses: VecDeque::from([
json_frame(serde_json::json!({
"jsonrpc": "2.0",
"method": "notifications/progress",
"params": {"progress": 0.5}
})),
json_frame(serde_json::json!({
"jsonrpc": "2.0",
"id": 99,
"result": {"ignored": true}
})),
json_frame(serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"result": {"ok": true}
})),
]),
};
let mut conn = test_connection(Box::new(transport));
let result = conn
.call_method("tools/call", serde_json::json!({"name": "echo"}), 1)
.await
.unwrap();
assert_eq!(result, serde_json::json!({"ok": true}));
let sent = sent.lock().unwrap();
assert_eq!(sent.len(), 1);
assert_eq!(sent[0]["jsonrpc"], "2.0");
assert_eq!(sent[0]["id"], "1");
assert_eq!(sent[0]["method"], "tools/call");
}
#[tokio::test]
async fn call_method_invalid_json_includes_server_output_preview() {
let sent = Arc::new(Mutex::new(Vec::new()));
let transport = ScriptedValueTransport {
sent: Arc::clone(&sent),
responses: VecDeque::from([b"Allow Burp MCP connection? [y/N]".to_vec()]),
};
let mut conn = test_connection(Box::new(transport));
let err = conn
.call_method("tools/call", serde_json::json!({"name": "burp"}), 1)
.await
.expect_err("non-json MCP stdout should fail");
let msg = err.to_string();
assert!(msg.contains("Invalid MCP JSON-RPC message from server 'mock'"));
assert!(msg.contains("Allow Burp MCP connection"));
}
#[tokio::test]
async fn call_method_times_out_while_waiting_for_response() {
let sent = Arc::new(Mutex::new(Vec::new()));
let mut conn = test_connection(Box::new(HangingValueTransport {
sent: Arc::clone(&sent),
}));
let err = conn
.call_method("tools/call", serde_json::json!({"name": "echo"}), 0)
.await
.expect_err("hung receive should time out");
assert!(
err.to_string()
.contains("MCP method 'tools/call' on server 'mock' timed out after 0s"),
"unexpected error: {err:#}"
);
assert_eq!(sent.lock().unwrap().len(), 1);
}
#[tokio::test]
async fn test_mcp_pool_empty_config() {
let pool = McpPool::new(McpConfig::default());
assert!(pool.server_names().is_empty());
assert!(pool.all_tools().is_empty());
}
/// #1267 part 2: a pool built without a source path has no file to watch,
/// so `reload_if_config_changed` must short-circuit instead of trying
/// to stat `/`.
#[tokio::test]
async fn reload_if_config_changed_is_noop_without_source_path() {
let mut pool = McpPool::new(McpConfig::default());
let reloaded = pool.reload_if_config_changed().await.unwrap();
assert!(!reloaded, "no source path → no reload");
}
/// #1267 part 2: when the on-disk config is byte-unchanged, the lazy
/// reload must not drop connections — every call to `get_or_connect`
/// would otherwise pay a full reconnect cycle on networked filesystems
/// where mtime granularity is coarse.
#[tokio::test]
async fn reload_if_config_changed_skips_when_content_unchanged() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("mcp.json");
std::fs::write(&path, r#"{"servers":{}}"#).unwrap();
let mut pool = McpPool::from_config_path(&path).unwrap();
// Force the mtime to advance without changing content.
std::thread::sleep(std::time::Duration::from_millis(10));
std::fs::write(&path, r#"{"servers":{}}"#).unwrap();
let reloaded = pool.reload_if_config_changed().await.unwrap();
assert!(
!reloaded,
"content-unchanged config must not trigger a reload"
);
}
/// #1267 part 2: when the on-disk config changes content, the next
/// `reload_if_config_changed` call must swap in the new config and
/// (would) drop all live connections. We can't stand up a real
/// `McpConnection` in a unit test, so we observe the swap via the
/// publicly-readable side: server names go from empty to non-empty.
#[tokio::test]
async fn reload_if_config_changed_swaps_config_on_content_change() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("mcp.json");
std::fs::write(&path, r#"{"servers":{}}"#).unwrap();
let mut pool = McpPool::from_config_path(&path).unwrap();
assert!(pool.server_names().is_empty());
// Mutate the file so both the mtime and the hash change.
std::thread::sleep(std::time::Duration::from_millis(10));
std::fs::write(
&path,
r#"{"servers":{"new":{"command":"echo","args":["hi"]}}}"#,
)
.unwrap();
let reloaded = pool.reload_if_config_changed().await.unwrap();
assert!(reloaded, "content-changed config must trigger reload");
let names = pool.server_names();
assert!(
names.contains(&"new"),
"expected new server in pool after reload, got {names:?}"
);
}
/// #1267 part 2: hash-based comparison must be stable for byte-identical
/// configs and distinct for differing configs.
#[test]
fn hash_mcp_config_is_stable_and_change_sensitive() {
let a = McpConfig::default();
let b = McpConfig::default();
assert_eq!(hash_mcp_config(&a), hash_mcp_config(&b));
let mut c = McpConfig::default();
c.servers.insert(
"x".into(),
McpServerConfig {
command: Some("/bin/echo".into()),
args: vec!["hi".into()],
env: Default::default(),
cwd: None,
url: None,
transport: None,
connect_timeout: None,
execute_timeout: None,
read_timeout: None,
disabled: false,
enabled: true,
required: false,
enabled_tools: Vec::new(),
disabled_tools: Vec::new(),
headers: HashMap::new(),
},
);
assert_ne!(
hash_mcp_config(&a),
hash_mcp_config(&c),
"hash must change when servers map changes"
);
}
/// #1319: discovered tools must be sorted by name so the prompt prefix
/// is stable across runs (cache-hit stability), even when the server
/// returns them in arbitrary or paginated order.
#[tokio::test]
async fn discover_tools_sorts_by_name_for_cache_stability() {
let sent = Arc::new(Mutex::new(Vec::new()));
let transport = ScriptedValueTransport {
sent: Arc::clone(&sent),
responses: VecDeque::from([
json_frame(serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"result": {
"tools": [
{ "name": "zeta", "inputSchema": {} },
{ "name": "alpha", "inputSchema": {} }
],
"nextCursor": "page-2"
}
})),
json_frame(serde_json::json!({
"jsonrpc": "2.0",
"id": 2,
"result": {
"tools": [
{ "name": "mu", "inputSchema": {} },
{ "name": "beta", "inputSchema": {} }
]
}
})),
]),
};
let mut conn = test_connection(Box::new(transport));
conn.discover_tools().await.expect("discover");
let names: Vec<&str> = conn.tools.iter().map(|t| t.name.as_str()).collect();
assert_eq!(
names,
vec!["alpha", "beta", "mu", "zeta"],
"tools must be sorted by name regardless of server order or pagination"
);
}
#[tokio::test]
async fn mcp_pool_call_tool_preserves_tool_names_with_dashes() {
let sent = Arc::new(Mutex::new(Vec::new()));
let transport = ScriptedValueTransport {
sent: Arc::clone(&sent),
responses: VecDeque::from([json_frame(serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"result": {"ok": true}
}))]),
};
let mut conn = test_connection(Box::new(transport));
conn.name = "dephy".to_string();
conn.tools = vec![McpTool {
name: "company--search".to_string(),
description: None,
input_schema: serde_json::json!({}),
}];
let mut pool = McpPool::new(McpConfig {
timeouts: McpTimeouts::default(),
servers: HashMap::new(),
});
pool.connections.insert("dephy".to_string(), conn);
let result = pool
.call_tool(
"mcp_dephy_company--search",
serde_json::json!({"query": "dephy"}),
)
.await
.unwrap();
assert_eq!(result, serde_json::json!({"ok": true}));
let sent = sent.lock().unwrap();
assert_eq!(sent[0]["method"], "tools/call");
assert_eq!(sent[0]["params"]["name"], "company--search");
assert_eq!(
sent[0]["params"]["arguments"],
serde_json::json!({"query": "dephy"})
);
}
#[tokio::test]
async fn mcp_pool_call_tool_preserves_server_names_with_underscores() {
let sent = Arc::new(Mutex::new(Vec::new()));
let transport = ScriptedValueTransport {
sent: Arc::clone(&sent),
responses: VecDeque::from([json_frame(serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"result": {"ok": true}
}))]),
};
let mut conn = test_connection(Box::new(transport));
conn.name = "my_db".to_string();
conn.tools = vec![McpTool {
name: "execute_sql".to_string(),
description: None,
input_schema: serde_json::json!({}),
}];
let mut pool = McpPool::new(McpConfig {
timeouts: McpTimeouts::default(),
servers: HashMap::new(),
});
pool.connections.insert("my_db".to_string(), conn);
let result = pool
.call_tool(
"mcp_my_db_execute_sql",
serde_json::json!({"query": "select 1"}),
)
.await
.unwrap();
assert_eq!(result, serde_json::json!({"ok": true}));
let sent = sent.lock().unwrap();
assert_eq!(sent[0]["method"], "tools/call");
assert_eq!(sent[0]["params"]["name"], "execute_sql");
assert_eq!(
sent[0]["params"]["arguments"],
serde_json::json!({"query": "select 1"})
);
}
#[tokio::test]
async fn mcp_pool_call_tool_prefers_longest_matching_server_name() {
let sent_short = Arc::new(Mutex::new(Vec::new()));
let short_transport = ScriptedValueTransport {
sent: Arc::clone(&sent_short),
responses: VecDeque::from([json_frame(serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"result": {"short": true}
}))]),
};
let mut short_conn = test_connection(Box::new(short_transport));
short_conn.name = "my".to_string();
short_conn.tools = vec![McpTool {
name: "db_execute_sql".to_string(),
description: None,
input_schema: serde_json::json!({}),
}];
let sent_long = Arc::new(Mutex::new(Vec::new()));
let long_transport = ScriptedValueTransport {
sent: Arc::clone(&sent_long),
responses: VecDeque::from([json_frame(serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"result": {"long": true}
}))]),
};
let mut long_conn = test_connection(Box::new(long_transport));
long_conn.name = "my_db".to_string();
long_conn.tools = vec![McpTool {
name: "execute_sql".to_string(),
description: None,
input_schema: serde_json::json!({}),
}];
let mut pool = McpPool::new(McpConfig {
timeouts: McpTimeouts::default(),
servers: HashMap::new(),
});
pool.connections.insert("my".to_string(), short_conn);
pool.connections.insert("my_db".to_string(), long_conn);
let result = pool
.call_tool(
"mcp_my_db_execute_sql",
serde_json::json!({"query": "select 1"}),
)
.await
.unwrap();
assert_eq!(result, serde_json::json!({"long": true}));
assert!(
sent_short.lock().unwrap().is_empty(),
"the shorter server name must not receive the tool call"
);
let sent_long = sent_long.lock().unwrap();
assert_eq!(sent_long[0]["method"], "tools/call");
assert_eq!(sent_long[0]["params"]["name"], "execute_sql");
assert_eq!(
sent_long[0]["params"]["arguments"],
serde_json::json!({"query": "select 1"})
);
}
#[tokio::test]
async fn json_rpc_session_error_is_marked_stale() {
let sent = Arc::new(Mutex::new(Vec::new()));
let transport = ScriptedValueTransport {
sent: Arc::clone(&sent),
responses: VecDeque::from([json_frame(serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"error": {
"code": -32001,
"message": "MCP session expired"
}
}))]),
};
let mut conn = test_connection(Box::new(transport));
let err = conn
.call_tool("search", serde_json::json!({"query": "dephy"}), 1)
.await
.expect_err("session error should fail");
assert!(
is_mcp_stale_session_error(&err),
"JSON-RPC session error should be retryable, got: {err:#}"
);
}
#[test]
fn sse_transport_closed_is_retryable() {
let err = anyhow::anyhow!("SSE transport closed");
assert!(
is_mcp_stale_session_error(&err),
"closed SSE stream should force reconnect before retry"
);
}
#[test]
fn legacy_sse_post_disconnect_is_retryable() {
let err = anyhow::anyhow!(
"MCP SSE POST send failed (transport=sse endpoint=http://127.0.0.1:123/messages): connection closed before message completed"
);
assert!(
is_mcp_stale_session_error(&err),
"closed legacy SSE POST should force reconnect before retry"
);
let err = anyhow::anyhow!(
"MCP SSE POST send failed (transport=sse endpoint=http://127.0.0.1:123/messages): connection reset by peer"
);
assert!(
is_mcp_stale_session_error(&err),
"reset legacy SSE POST should force reconnect before retry"
);
let err = anyhow::anyhow!(
"MCP SSE POST send failed (transport=sse endpoint=http://127.0.0.1:123/messages): An existing connection was forcibly closed by the remote host."
);
assert!(
is_mcp_stale_session_error(&err),
"Windows reset wording should force reconnect before retry"
);
}
#[tokio::test]
async fn discover_all_ignores_unsupported_optional_capabilities() {
let sent = Arc::new(Mutex::new(Vec::new()));
let transport = ScriptedValueTransport {
sent: Arc::clone(&sent),
responses: VecDeque::from([
json_frame(serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"result": {
"tools": [
{ "name": "search", "inputSchema": {} }
]
}
})),
json_frame(serde_json::json!({
"jsonrpc": "2.0",
"id": 2,
"error": {
"code": -32601,
"message": "resources not supported"
}
})),
json_frame(serde_json::json!({
"jsonrpc": "2.0",
"id": 3,
"error": {
"code": -32601,
"message": "resource templates not supported"
}
})),
json_frame(serde_json::json!({
"jsonrpc": "2.0",
"id": 4,
"error": {
"code": -32601,
"message": "prompts not supported"
}
})),
]),
};
let mut conn = test_connection(Box::new(transport));
conn.discover_all().await.expect("discover");
assert_eq!(conn.tools.len(), 1);
assert_eq!(conn.tools[0].name, "search");
assert!(conn.resources.is_empty());
assert!(conn.resource_templates.is_empty());
assert!(conn.prompts.is_empty());
}
/// #1244: when an MCP stdio server fails to spawn, the underlying OS
/// error (e.g. ENOENT for a missing binary) must reach the user via the
/// snapshot.error string. Regression test for `err.to_string()` dropping
/// the anyhow chain — without `{err:#}` the user sees only the opaque
/// wrapper "MCP stdio spawn failed (...)" and has nothing to act on.
#[tokio::test]
async fn discover_snapshot_includes_underlying_spawn_error_in_chain() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("mcp.json");
fs::write(
&path,
r#"{
"mcpServers": {
"broken": {
"command": "codewhale-tui-test-this-binary-does-not-exist-9f8e7d6c5b4a",
"args": []
}
}
}"#,
)
.unwrap();
let snapshot = discover_manager_snapshot(&path, None, false).await.unwrap();
let server = snapshot
.servers
.iter()
.find(|s| s.name == "broken")
.expect("broken server should appear in snapshot");
let err = server
.error
.as_deref()
.expect("broken server should have an error");
let lowered = err.to_lowercase();
assert!(
lowered.contains("os error")
|| lowered.contains("not found")
|| lowered.contains("no such"),
"expected underlying spawn error in chain, got: {err}"
);
}
#[test]
fn parse_sse_message_data_extracts_message_events() {
let body = "event: message\r\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}\r\n\r\n";
let messages = parse_sse_message_data(body);
assert_eq!(messages.len(), 1);
let value: serde_json::Value = serde_json::from_slice(&messages[0]).unwrap();
assert_eq!(value["id"], 1);
assert!(value.get("result").is_some());
}
#[test]
fn response_id_matches_string_and_numeric_echoes() {
assert!(response_id_matches(Some(&serde_json::json!("1")), "1"));
assert!(response_id_matches(Some(&serde_json::json!(1)), "1"));
assert!(!response_id_matches(Some(&serde_json::json!("2")), "1"));
}
#[test]
fn legacy_sse_transport_requires_explicit_config() {
let mut server = test_server_config();
server.url = Some("https://example.com/mcp/abc/sse".to_string());
assert!(
!is_legacy_sse_transport(&server),
"/sse paths must not force legacy SSE without an explicit transport override"
);
server.transport = Some("sse".to_string());
assert!(is_legacy_sse_transport(&server));
server.transport = Some("SSE".to_string());
assert!(is_legacy_sse_transport(&server));
server.transport = Some("http".to_string());
assert!(!is_legacy_sse_transport(&server));
}
#[test]
fn find_sse_event_separator_accepts_lf_and_crlf() {
assert_eq!(
find_sse_event_separator("event: endpoint\n\n"),
Some((15, 2))
);
assert_eq!(
find_sse_event_separator("event: endpoint\r\n\r\n"),
Some((15, 4))
);
}
#[tokio::test]
#[ignore = "flaky: requires a live TCP listener and is sensitive to port allocation races"]
async fn mcp_connection_supports_streamable_http_event_stream_responses() {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
async fn read_http_request(socket: &mut TcpStream) -> String {
let mut request = Vec::new();
let mut buf = [0; 1024];
let header_end = loop {
let n = socket.read(&mut buf).await.unwrap();
assert!(n > 0, "client closed before headers completed");
request.extend_from_slice(&buf[..n]);
if let Some(pos) = request.windows(4).position(|window| window == b"\r\n\r\n") {
break pos + 4;
}
};
let headers = String::from_utf8_lossy(&request[..header_end]);
let content_length = headers
.lines()
.find_map(|line| {
let (name, value) = line.split_once(':')?;
name.eq_ignore_ascii_case("content-length")
.then(|| value.trim().parse::<usize>().ok())
.flatten()
})
.unwrap_or(0);
let total_len = header_end + content_length;
while request.len() < total_len {
let n = socket.read(&mut buf).await.unwrap();
assert!(n > 0, "client closed before body completed");
request.extend_from_slice(&buf[..n]);
}
String::from_utf8(request).unwrap()
}
async fn write_json_sse(socket: &mut TcpStream, response: serde_json::Value) {
let body = format!("event: message\ndata: {response}\n\n");
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nContent-Length: {}\r\n\r\n{}",
body.len(),
body
);
socket.write_all(response.as_bytes()).await.unwrap();
}
let _lock = lock_mcp_loopback_tests().await;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let server = tokio::spawn(async move {
loop {
let Ok((mut socket, _)) = listener.accept().await else {
break;
};
tokio::spawn(async move {
let request = read_http_request(&mut socket).await;
assert!(request.starts_with("POST /mcp "));
assert!(
request.contains("Accept: application/json, text/event-stream")
|| request.contains("accept: application/json, text/event-stream")
);
let body = request.split("\r\n\r\n").nth(1).unwrap_or("");
let value: serde_json::Value = serde_json::from_str(body).unwrap();
let method = value["method"].as_str().unwrap();
if method == "notifications/initialized" {
socket
.write_all(b"HTTP/1.1 202 Accepted\r\nContent-Length: 0\r\n\r\n")
.await
.unwrap();
return;
}
let id = value["id"].clone();
let result = match method {
"initialize" => serde_json::json!({
"protocolVersion": "2024-11-05",
"serverInfo": {"name": "mock-streamable", "version": "1.0.0"},
"capabilities": {"tools": {}, "resources": {}, "prompts": {}}
}),
"tools/list" => serde_json::json!({
"tools": [{
"name": "read_wiki_structure",
"description": "Read wiki structure",
"inputSchema": {"type": "object"}
}]
}),
"resources/list" => serde_json::json!({"resources": []}),
"resources/templates/list" => {
serde_json::json!({"resourceTemplates": []})
}
"prompts/list" => serde_json::json!({"prompts": []}),
other => panic!("unexpected method: {other}"),
};
write_json_sse(
&mut socket,
serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"result": result
}),
)
.await;
});
}
});
let config = McpServerConfig {
command: None,
args: vec![],
env: HashMap::new(),
cwd: None,
url: Some(format!("http://{addr}/mcp")),
transport: None,
connect_timeout: Some(2),
execute_timeout: None,
read_timeout: None,
disabled: false,
enabled: true,
required: false,
enabled_tools: Vec::new(),
disabled_tools: Vec::new(),
headers: HashMap::new(),
};
let conn = McpConnection::connect_with_policy(
"deepwiki".to_string(),
config,
&McpTimeouts::default(),
None,
)
.await
.unwrap();
assert_eq!(conn.state(), ConnectionState::Ready);
assert_eq!(conn.tools().len(), 1);
assert_eq!(conn.tools()[0].name, "read_wiki_structure");
server.abort();
}
#[test]
fn mask_url_secrets_strips_userinfo() {
let masked = mask_url_secrets("https://user:s3cret@host.example/api?foo=bar");
assert!(masked.contains("***"), "expected masked userinfo: {masked}");
assert!(!masked.contains("s3cret"), "secret leaked: {masked}");
assert!(masked.contains("host.example"), "host preserved: {masked}");
}
#[test]
fn mask_url_secrets_passes_through_clean_url() {
assert_eq!(
mask_url_secrets("https://api.example.com/mcp"),
"https://api.example.com/mcp"
);
}
#[test]
fn redact_body_preview_masks_bearer_token() {
let redacted = redact_body_preview("Authorization: Bearer abc.def.ghi end");
assert!(redacted.contains("Bearer ***"), "redacted: {redacted}");
assert!(!redacted.contains("abc.def.ghi"), "leaked: {redacted}");
}
#[test]
fn redact_proxy_userinfo_strips_password() {
// Corporate-style proxy URL with embedded creds — the
// password must never reach the on-disk log file. URL strings
// are assembled from placeholder constants via `format!` so the
// literal source never contains a scheme-prefixed username +
// password pair (colon-separated, `@`-terminated) that
// GitGuardian's "Basic Auth String" detector would flag as a
// committed credential.
let (placeholder_user, placeholder_pass) = ("PLACEHOLDER_USER", "PLACEHOLDER_PASS");
let with_creds = format!("http://{placeholder_user}:{placeholder_pass}@proxy.example/");
let redacted = redact_proxy_userinfo(&with_creds);
assert_eq!(redacted, "http://***@proxy.example/");
assert!(!redacted.contains(placeholder_pass));
assert!(!redacted.contains(placeholder_user));
// User only (no password) — still redacted.
let with_user_only = format!("https://{placeholder_user}@proxy.example:8080");
let redacted = redact_proxy_userinfo(&with_user_only);
assert_eq!(redacted, "https://***@proxy.example:8080");
// No userinfo segment — pass through.
let redacted = redact_proxy_userinfo("http://proxy.example:3128/");
assert_eq!(redacted, "http://proxy.example:3128/");
// `@` appears only in the path, not as userinfo separator —
// must not be mistaken for credentials.
let redacted = redact_proxy_userinfo("http://proxy.example/path@thing");
assert_eq!(redacted, "http://proxy.example/path@thing");
// Garbage input (no `://`) returned unchanged — the
// surrounding warning log is the only caller and is already
// handling the malformed-URL case.
assert_eq!(redact_proxy_userinfo("not-a-url"), "not-a-url");
}
#[test]
fn redact_body_preview_masks_api_key_param() {
let redacted = redact_body_preview("error message api_key=sk-12345&other=val");
assert!(redacted.contains("api_key=***"), "redacted: {redacted}");
assert!(!redacted.contains("sk-12345"), "leaked: {redacted}");
assert!(
redacted.contains("other=val"),
"non-secret preserved: {redacted}"
);
}
#[test]
fn invalid_json_preview_collapses_lines_and_redacts_secrets() {
let preview = invalid_json_preview(
b"Authorization: Bearer PLACEHOLDER_TOKEN\nAllow connection? api_key=PLACEHOLDER_KEY",
);
assert!(
preview.contains("Authorization: Bearer *** Allow connection? api_key=***"),
"preview: {preview}"
);
assert!(
!preview.contains('\n'),
"preview should be single-line: {preview}"
);
assert!(
!preview.contains("PLACEHOLDER_TOKEN") && !preview.contains("PLACEHOLDER_KEY"),
"secret leaked: {preview}"
);
}
/// #420: `StdioTransport::shutdown` reaps the child process by sending
/// SIGTERM and giving it a brief grace period before drop fires SIGKILL.
/// The test spawns `cat` (which exits immediately on stdin EOF / SIGTERM)
/// and verifies the transport tears down cleanly. Unix-only because
/// SIGTERM doesn't exist on Windows; on Windows the test would just
/// duplicate the kill_on_drop path.
#[cfg(unix)]
#[tokio::test]
async fn stdio_transport_shutdown_terminates_child() {
use tokio::process::Command as TokioCommand;
let mut cmd = TokioCommand::new("cat");
cmd.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::null())
.kill_on_drop(true);
let mut child = cmd.spawn().expect("spawn cat");
let pid = child.id().expect("child pid");
let stdin = child.stdin.take().expect("child stdin");
let stdout = child.stdout.take().expect("child stdout");
let mut transport = StdioTransport {
child,
stdin,
reader: tokio::io::BufReader::new(stdout),
stderr_tail: StderrTail::new(),
};
// shutdown() should send SIGTERM and complete within the grace window.
let start = std::time::Instant::now();
transport.shutdown().await;
let elapsed = start.elapsed();
assert!(
elapsed < STDIO_SHUTDOWN_GRACE + Duration::from_millis(500),
"shutdown blocked beyond grace window: {elapsed:?}"
);
// The child should be reaped — kill(pid, 0) returning ESRCH means
// the pid is gone. If it's still alive, kill(0) returns 0, which
// means our shutdown didn't terminate it.
// SAFETY: pid was just collected from a tokio Child we spawned.
// libc::kill with signal 0 only checks pid existence and is
// async-signal-safe.
let still_alive = unsafe { libc::kill(pid as i32, 0) } == 0;
assert!(
!still_alive,
"child {pid} survived StdioTransport::shutdown — SIGTERM not delivered"
);
}
/// Mid-run MCP server crash: the v0.8.x spawn path used `Stdio::null` for
/// stderr, so a server that died with a useful stderr message left the
/// caller with only "Stdio transport closed". Now stderr is piped into a
/// bounded ring buffer and surfaced when the read side fails.
#[cfg(unix)]
#[tokio::test]
async fn stdio_transport_recv_error_includes_stderr_tail() {
use tokio::process::Command as TokioCommand;
let mut cmd = TokioCommand::new("sh");
cmd.arg("-c")
.arg("echo 'mcp-server: failed to load plugin' 1>&2; exit 1")
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true);
let mut child = cmd.spawn().expect("spawn sh");
let stdin = child.stdin.take().expect("stdin");
let stdout = child.stdout.take().expect("stdout");
let stderr = child.stderr.take().expect("stderr");
let stderr_tail = StderrTail::new();
{
let tail = Arc::clone(&stderr_tail);
tokio::spawn(async move {
let mut lines = tokio::io::BufReader::new(stderr).lines();
while let Ok(Some(line)) = lines.next_line().await {
tail.push(line).await;
}
});
}
let mut transport = StdioTransport {
child,
stdin,
reader: tokio::io::BufReader::new(stdout),
stderr_tail,
};
// Give the subprocess time to write its stderr line and exit.
tokio::time::sleep(Duration::from_millis(300)).await;
let err = transport
.recv()
.await
.expect_err("expected transport closed error");
let err_str = format!("{err}");
assert!(
err_str.contains("Stdio transport closed"),
"missing closed marker in: {err_str}"
);
assert!(
err_str.contains("mcp-server: failed to load plugin"),
"stderr context missing from error: {err_str}"
);
}
#[tokio::test]
async fn sse_connect_waits_for_endpoint_before_first_send() {
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering as AtomicOrdering},
};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
let _lock = lock_mcp_loopback_tests().await;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let post_seen = Arc::new(AtomicBool::new(false));
let server_post_seen = Arc::clone(&post_seen);
let cancel_token = tokio_util::sync::CancellationToken::new();
let server_cancel = cancel_token.clone();
let server = tokio::spawn(async move {
loop {
let Ok((mut socket, _)) = listener.accept().await else {
break;
};
let post_seen = Arc::clone(&server_post_seen);
let server_cancel = server_cancel.clone();
tokio::spawn(async move {
let mut request = Vec::new();
let mut buf = [0; 1024];
loop {
let n = socket.read(&mut buf).await.unwrap();
if n == 0 {
return;
}
request.extend_from_slice(&buf[..n]);
if request.windows(4).any(|window| window == b"\r\n\r\n") {
break;
}
}
let request = String::from_utf8_lossy(&request);
if request.starts_with("GET /sse ") {
socket
.write_all(
b"HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\n\r\n",
)
.await
.unwrap();
tokio::time::sleep(Duration::from_millis(150)).await;
socket
.write_all(b"event: endpoint\ndata: /messages\n\n")
.await
.unwrap();
server_cancel.cancelled().await;
} else if request.starts_with("POST /messages ") {
post_seen.store(true, AtomicOrdering::SeqCst);
socket
.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
.await
.unwrap();
}
});
}
});
let client = test_http_client();
let url = format!("http://{addr}/sse");
let mut transport = SseTransport::connect(
client,
url,
HashMap::new(),
cancel_token.clone(),
Duration::from_secs(2),
)
.await
.unwrap();
transport
.send(json_frame(serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize"
})))
.await
.unwrap();
assert!(
post_seen.load(AtomicOrdering::SeqCst),
"first SSE send should POST to the discovered endpoint"
);
cancel_token.cancel();
server.abort();
}
#[tokio::test]
async fn sse_connect_accepts_crlf_endpoint_events() {
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering as AtomicOrdering},
};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
let _lock = lock_mcp_loopback_tests().await;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let post_seen = Arc::new(AtomicBool::new(false));
let server_post_seen = Arc::clone(&post_seen);
let cancel_token = tokio_util::sync::CancellationToken::new();
let server_cancel = cancel_token.clone();
let server = tokio::spawn(async move {
loop {
let Ok((mut socket, _)) = listener.accept().await else {
break;
};
let post_seen = Arc::clone(&server_post_seen);
let server_cancel = server_cancel.clone();
tokio::spawn(async move {
let mut request = Vec::new();
let mut buf = [0; 1024];
loop {
let n = socket.read(&mut buf).await.unwrap();
if n == 0 {
return;
}
request.extend_from_slice(&buf[..n]);
if request.windows(4).any(|window| window == b"\r\n\r\n") {
break;
}
}
let request = String::from_utf8_lossy(&request);
if request.starts_with("GET /sse ") {
socket
.write_all(
b"HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\n\r\n",
)
.await
.unwrap();
socket
.write_all(b"event: endpoint\r\ndata: /messages\r\n\r\n")
.await
.unwrap();
server_cancel.cancelled().await;
} else if request.starts_with("POST /messages ") {
post_seen.store(true, AtomicOrdering::SeqCst);
socket
.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
.await
.unwrap();
}
});
}
});
let client = test_http_client();
let url = format!("http://{addr}/sse");
let mut transport = SseTransport::connect(
client,
url,
HashMap::new(),
cancel_token.clone(),
Duration::from_secs(2),
)
.await
.unwrap();
transport
.send(json_frame(serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize"
})))
.await
.unwrap();
assert!(
post_seen.load(AtomicOrdering::SeqCst),
"first SSE send should POST to the CRLF-discovered endpoint"
);
cancel_token.cancel();
server.abort();
}
#[tokio::test]
async fn sse_transport_applies_custom_headers_to_get_and_post() {
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering as AtomicOrdering},
};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
let _lock = lock_mcp_loopback_tests().await;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let get_header_seen = Arc::new(AtomicBool::new(false));
let post_header_seen = Arc::new(AtomicBool::new(false));
let server_get_header_seen = Arc::clone(&get_header_seen);
let server_post_header_seen = Arc::clone(&post_header_seen);
let cancel_token = tokio_util::sync::CancellationToken::new();
let server_cancel = cancel_token.clone();
let server = tokio::spawn(async move {
loop {
let Ok((mut socket, _)) = listener.accept().await else {
break;
};
let get_header_seen = Arc::clone(&server_get_header_seen);
let post_header_seen = Arc::clone(&server_post_header_seen);
let server_cancel = server_cancel.clone();
tokio::spawn(async move {
let mut request = Vec::new();
let mut buf = [0; 1024];
loop {
let n = socket.read(&mut buf).await.unwrap();
if n == 0 {
return;
}
request.extend_from_slice(&buf[..n]);
if request.windows(4).any(|window| window == b"\r\n\r\n") {
break;
}
}
let request = String::from_utf8_lossy(&request);
let request_lower = request.to_lowercase();
if request.starts_with("GET /sse ") {
if request_lower.contains("x-custom-auth: my-test-token") {
get_header_seen.store(true, AtomicOrdering::SeqCst);
}
socket
.write_all(
b"HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\n\r\n",
)
.await
.unwrap();
socket
.write_all(b"event: endpoint\ndata: /messages\n\n")
.await
.unwrap();
server_cancel.cancelled().await;
} else if request.starts_with("POST /messages ") {
if request_lower.contains("x-custom-auth: my-test-token") {
post_header_seen.store(true, AtomicOrdering::SeqCst);
}
socket
.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
.await
.unwrap();
}
});
}
});
let client = test_http_client();
let url = format!("http://{addr}/sse");
let mut headers = HashMap::new();
headers.insert("X-Custom-Auth".to_string(), "my-test-token".to_string());
let mut transport = SseTransport::connect(
client,
url,
headers,
cancel_token.clone(),
Duration::from_secs(2),
)
.await
.unwrap();
transport
.send(json_frame(serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize"
})))
.await
.unwrap();
assert!(
get_header_seen.load(AtomicOrdering::SeqCst),
"legacy SSE GET must include user-configured custom headers"
);
assert!(
post_header_seen.load(AtomicOrdering::SeqCst),
"legacy SSE POST must include user-configured custom headers"
);
cancel_token.cancel();
server.abort();
}
#[tokio::test]
async fn sse_post_error_includes_response_body_excerpt() {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
let _lock = lock_mcp_loopback_tests().await;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let cancel_token = tokio_util::sync::CancellationToken::new();
let server_cancel = cancel_token.clone();
let server = tokio::spawn(async move {
loop {
let Ok((mut socket, _)) = listener.accept().await else {
break;
};
let server_cancel = server_cancel.clone();
tokio::spawn(async move {
let mut request = Vec::new();
let mut buf = [0; 1024];
loop {
let n = socket.read(&mut buf).await.unwrap();
if n == 0 {
return;
}
request.extend_from_slice(&buf[..n]);
if request.windows(4).any(|window| window == b"\r\n\r\n") {
break;
}
}
let request = String::from_utf8_lossy(&request);
if request.starts_with("GET /sse ") {
socket
.write_all(
b"HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\n\r\n",
)
.await
.unwrap();
socket
.write_all(b"event: endpoint\ndata: /messages\n\n")
.await
.unwrap();
server_cancel.cancelled().await;
} else if request.starts_with("POST /messages ") {
socket
.write_all(
b"HTTP/1.1 400 Bad Request\r\nContent-Type: application/json\r\nContent-Length: 25\r\n\r\n{\"error\":\"missing query\"}",
)
.await
.unwrap();
}
});
}
});
let client = test_http_client();
let url = format!("http://{addr}/sse");
let mut transport = SseTransport::connect(
client,
url,
HashMap::new(),
cancel_token.clone(),
Duration::from_secs(2),
)
.await
.unwrap();
let err = transport
.send(json_frame(serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize"
})))
.await
.expect_err("POST rejection should be returned");
let err = format!("{err:#}");
assert!(
err.contains("400 Bad Request") && err.contains("missing query"),
"SSE POST error should include status and body, got: {err}"
);
cancel_token.cancel();
server.abort();
}
#[tokio::test]
async fn streamable_http_stale_session_reconnects_and_retries_tool_call() {
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
async fn write_response(socket: &mut tokio::net::TcpStream, response: &[u8]) {
socket.write_all(response).await.unwrap();
socket.flush().await.unwrap();
socket.shutdown().await.unwrap();
}
let _lock = lock_mcp_loopback_tests().await;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let get_count = Arc::new(AtomicUsize::new(0));
let stale_seen = Arc::new(AtomicBool::new(false));
let success_seen = Arc::new(AtomicBool::new(false));
let server_get_count = Arc::clone(&get_count);
let server_stale_seen = Arc::clone(&stale_seen);
let server_success_seen = Arc::clone(&success_seen);
let server = tokio::spawn(async move {
loop {
let Ok((mut socket, _)) = listener.accept().await else {
break;
};
let get_count = Arc::clone(&server_get_count);
let stale_seen = Arc::clone(&server_stale_seen);
let success_seen = Arc::clone(&server_success_seen);
tokio::spawn(async move {
let mut request = Vec::new();
let mut buf = [0; 4096];
let header_end = loop {
let n = socket.read(&mut buf).await.unwrap();
if n == 0 {
return;
}
request.extend_from_slice(&buf[..n]);
if let Some(pos) = request.windows(4).position(|w| w == b"\r\n\r\n") {
break pos + 4;
}
};
let headers = String::from_utf8_lossy(&request[..header_end]).to_string();
let content_length = headers
.lines()
.find_map(|line| {
let (name, value) = line.split_once(':')?;
name.eq_ignore_ascii_case("content-length")
.then(|| value.trim().parse::<usize>().ok())
.flatten()
})
.unwrap_or(0);
while request.len() < header_end + content_length {
let n = socket.read(&mut buf).await.unwrap();
if n == 0 {
return;
}
request.extend_from_slice(&buf[..n]);
}
let body = &request[header_end..header_end + content_length];
let session_header = headers.lines().find_map(|line| {
let (name, value) = line.split_once(':')?;
name.eq_ignore_ascii_case("mcp-session-id")
.then(|| value.trim().to_string())
});
if headers.starts_with("GET /mcp ") {
let count = get_count.fetch_add(1, AtomicOrdering::SeqCst);
let session = if count == 0 { "sess-old" } else { "sess-new" };
let response = format!(
"HTTP/1.1 200 OK\r\nMcp-Session-Id: {session}\r\nContent-Length: 0\r\n\r\n"
);
write_response(&mut socket, response.as_bytes()).await;
return;
}
let request_json: serde_json::Value = serde_json::from_slice(body).unwrap();
let method = request_json
.get("method")
.and_then(serde_json::Value::as_str)
.unwrap_or("");
let id = request_json
.get("id")
.cloned()
.unwrap_or_else(|| serde_json::json!("0"));
if method == "tools/call" && session_header.as_deref() == Some("sess-old") {
stale_seen.store(true, AtomicOrdering::SeqCst);
write_response(
&mut socket,
b"HTTP/1.1 404 Not Found\r\nContent-Type: application/json\r\nContent-Length: 27\r\n\r\n{\"error\":\"session expired\"}",
)
.await;
return;
}
let result = match method {
"initialize" => serde_json::json!({
"protocolVersion": "2024-11-05",
"capabilities": {}
}),
"tools/list" => serde_json::json!({
"tools": [
{ "name": "search", "inputSchema": {} }
]
}),
"resources/list" => serde_json::json!({ "resources": [] }),
"resources/templates/list" => {
serde_json::json!({ "resourceTemplates": [] })
}
"prompts/list" => serde_json::json!({ "prompts": [] }),
"tools/call" => {
assert_eq!(session_header.as_deref(), Some("sess-new"));
success_seen.store(true, AtomicOrdering::SeqCst);
serde_json::json!({ "content": [{ "type": "text", "text": "ok" }] })
}
_ => {
write_response(
&mut socket,
b"HTTP/1.1 202 Accepted\r\nContent-Length: 0\r\n\r\n",
)
.await;
return;
}
};
let response_body = serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"result": result
})
.to_string();
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
response_body.len(),
response_body
);
write_response(&mut socket, response.as_bytes()).await;
});
}
});
let mut cfg = McpConfig::default();
cfg.servers.insert(
"dephy".to_string(),
McpServerConfig {
command: None,
args: Vec::new(),
env: HashMap::new(),
cwd: None,
url: Some(format!("http://{addr}/mcp")),
transport: None,
connect_timeout: Some(10),
execute_timeout: Some(10),
read_timeout: None,
disabled: false,
enabled: true,
required: false,
enabled_tools: Vec::new(),
disabled_tools: Vec::new(),
headers: HashMap::new(),
},
);
let mut pool = McpPool::new(cfg);
let result = pool
.call_tool("mcp_dephy_search", serde_json::json!({ "query": "dephy" }))
.await
.unwrap();
assert_eq!(
result,
serde_json::json!({ "content": [{ "type": "text", "text": "ok" }] })
);
assert!(stale_seen.load(AtomicOrdering::SeqCst));
assert!(success_seen.load(AtomicOrdering::SeqCst));
assert_eq!(get_count.load(AtomicOrdering::SeqCst), 2);
server.abort();
}
#[tokio::test]
async fn legacy_sse_session_expiry_is_marked_stale() {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::sync::mpsc;
let _lock = lock_mcp_loopback_tests().await;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let server = tokio::spawn(async move {
let (mut socket, _) = listener.accept().await.unwrap();
let mut request = Vec::new();
let mut buf = [0; 4096];
let header_end = loop {
let n = socket.read(&mut buf).await.unwrap();
if n == 0 {
return;
}
request.extend_from_slice(&buf[..n]);
if let Some(pos) = request.windows(4).position(|w| w == b"\r\n\r\n") {
break pos + 4;
}
};
let headers = String::from_utf8_lossy(&request[..header_end]);
assert!(headers.starts_with("POST /messages "));
socket
.write_all(
b"HTTP/1.1 400 Bad Request\r\nContent-Type: application/json\r\nContent-Length: 27\r\n\r\n{\"error\":\"session expired\"}",
)
.await
.unwrap();
});
let (_sender, receiver) = mpsc::unbounded_channel();
let sse_task = tokio::spawn(async {});
let mut transport = SseTransport {
client: test_http_client(),
base_url: format!("http://{addr}/sse"),
headers: HashMap::new(),
endpoint_url: Some(format!("http://{addr}/messages")),
receiver,
pending_messages: VecDeque::new(),
sse_task,
};
let err = transport
.send(br#"{"jsonrpc":"2.0","id":1,"method":"tools/call"}"#.to_vec())
.await
.expect_err("expired SSE session should fail");
assert!(
is_mcp_stale_session_error(&err),
"SSE session expiry should be retryable, got: {err:#}"
);
server.abort();
}
#[tokio::test]
async fn legacy_sse_closed_stream_reconnects_and_retries_tool_call() {
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;
async fn read_http_request(socket: &mut TcpStream) -> (String, serde_json::Value) {
let mut request = Vec::new();
let mut buf = [0; 4096];
let header_end = loop {
let n = socket.read(&mut buf).await.unwrap();
if n == 0 {
return (String::new(), serde_json::Value::Null);
}
request.extend_from_slice(&buf[..n]);
if let Some(pos) = request.windows(4).position(|w| w == b"\r\n\r\n") {
break pos + 4;
}
};
let headers = String::from_utf8_lossy(&request[..header_end]).to_string();
let content_length = headers
.lines()
.find_map(|line| {
let (name, value) = line.split_once(':')?;
name.eq_ignore_ascii_case("content-length")
.then(|| value.trim().parse::<usize>().ok())
.flatten()
})
.unwrap_or(0);
while request.len() < header_end + content_length {
let n = socket.read(&mut buf).await.unwrap();
if n == 0 {
return (headers, serde_json::Value::Null);
}
request.extend_from_slice(&buf[..n]);
}
let body = &request[header_end..header_end + content_length];
let json = if body.is_empty() {
serde_json::Value::Null
} else {
serde_json::from_slice(body).unwrap()
};
(headers, json)
}
let _lock = lock_mcp_loopback_tests().await;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let active_sse = Arc::new(Mutex::new(None::<mpsc::UnboundedSender<Option<String>>>));
let get_count = Arc::new(AtomicUsize::new(0));
let tool_call_count = Arc::new(AtomicUsize::new(0));
let success_seen = Arc::new(AtomicBool::new(false));
let server_active_sse = Arc::clone(&active_sse);
let server_get_count = Arc::clone(&get_count);
let server_tool_call_count = Arc::clone(&tool_call_count);
let server_success_seen = Arc::clone(&success_seen);
let server = tokio::spawn(async move {
loop {
let Ok((mut socket, _)) = listener.accept().await else {
break;
};
let active_sse = Arc::clone(&server_active_sse);
let get_count = Arc::clone(&server_get_count);
let tool_call_count = Arc::clone(&server_tool_call_count);
let success_seen = Arc::clone(&server_success_seen);
tokio::spawn(async move {
let (headers, request_json) = read_http_request(&mut socket).await;
if headers.starts_with("GET /sse ") {
get_count.fetch_add(1, AtomicOrdering::SeqCst);
let (tx, mut rx) = mpsc::unbounded_channel::<Option<String>>();
*active_sse.lock().unwrap() = Some(tx);
socket
.write_all(
b"HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\n\r\n",
)
.await
.unwrap();
socket
.write_all(b"event: endpoint\ndata: /messages\n\n")
.await
.unwrap();
while let Some(message) = rx.recv().await {
let Some(message) = message else {
return;
};
let event = format!("event: message\ndata: {message}\n\n");
socket.write_all(event.as_bytes()).await.unwrap();
}
return;
}
if !headers.starts_with("POST /messages ") {
return;
}
socket
.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
.await
.unwrap();
let method = request_json
.get("method")
.and_then(serde_json::Value::as_str)
.unwrap_or("");
if method == "notifications/initialized" {
return;
}
let id = request_json
.get("id")
.cloned()
.unwrap_or_else(|| serde_json::json!("0"));
if method == "tools/call" {
let count = tool_call_count.fetch_add(1, AtomicOrdering::SeqCst);
if count == 0 {
if let Some(tx) = active_sse.lock().unwrap().take() {
let _ = tx.send(None);
}
return;
}
}
let result = match method {
"initialize" => serde_json::json!({
"protocolVersion": "2024-11-05",
"capabilities": {}
}),
"tools/list" => serde_json::json!({
"tools": [
{ "name": "search", "inputSchema": {} }
]
}),
"resources/list" => serde_json::json!({ "resources": [] }),
"resources/templates/list" => {
serde_json::json!({ "resourceTemplates": [] })
}
"prompts/list" => serde_json::json!({ "prompts": [] }),
"tools/call" => {
success_seen.store(true, AtomicOrdering::SeqCst);
serde_json::json!({ "content": [{ "type": "text", "text": "ok" }] })
}
other => panic!("unexpected method: {other}"),
};
let response = serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"result": result
})
.to_string();
// Deliver the response over the *current* SSE channel. The
// retry tool call can race ahead of the reconnecting GET
// /sse that re-stores the sender; under parallel load those
// two server tasks are scheduled in either order, so wait
// briefly for the channel instead of dropping the response
// (which left the client hanging until timeout) (#2597).
let send_deadline =
std::time::Instant::now() + std::time::Duration::from_secs(5);
let tx = loop {
if let Some(tx) = active_sse.lock().unwrap().as_ref().cloned() {
break Some(tx);
}
if std::time::Instant::now() >= send_deadline {
break None;
}
tokio::time::sleep(std::time::Duration::from_millis(5)).await;
};
if let Some(tx) = tx {
let _ = tx.send(Some(response));
}
});
}
});
let mut cfg = McpConfig::default();
cfg.servers.insert(
"dephy".to_string(),
McpServerConfig {
command: None,
args: Vec::new(),
env: HashMap::new(),
cwd: None,
url: Some(format!("http://{addr}/sse")),
transport: Some("sse".to_string()),
connect_timeout: Some(10),
execute_timeout: Some(10),
read_timeout: None,
disabled: false,
enabled: true,
required: false,
enabled_tools: Vec::new(),
disabled_tools: Vec::new(),
headers: HashMap::new(),
},
);
let mut pool = McpPool::new(cfg);
let result = pool
.call_tool("mcp_dephy_search", serde_json::json!({ "query": "dephy" }))
.await
.unwrap();
assert_eq!(
result,
serde_json::json!({ "content": [{ "type": "text", "text": "ok" }] })
);
assert_eq!(tool_call_count.load(AtomicOrdering::SeqCst), 2);
assert_eq!(get_count.load(AtomicOrdering::SeqCst), 2);
assert!(success_seen.load(AtomicOrdering::SeqCst));
server.abort();
}
#[test]
fn session_id_starts_none() {
let transport = StreamableHttpTransport::new(
test_http_client(),
"https://example.invalid/mcp".to_string(),
HashMap::new(),
);
assert!(transport.session_id.is_none());
}
/// Session ID captured from a POST response is replayed on the next POST.
#[tokio::test]
async fn session_id_captured_from_post_response_and_replayed() {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
let _lock = lock_mcp_loopback_tests().await;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let server = tokio::spawn(async move {
let (mut socket, _) = listener.accept().await.unwrap();
let mut buf = [0u8; 4096];
let n = socket.read(&mut buf).await.unwrap();
let req = String::from_utf8_lossy(&buf[..n]);
assert!(req.starts_with("POST "), "expected POST, got: {req}");
// First POST: return a session ID so the transport captures it.
socket
.write_all(
b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nMcp-Session-Id: sess-abc-123\r\nContent-Length: 2\r\n\r\n{}",
)
.await
.unwrap();
socket.flush().await.unwrap();
// Read the second POST — should contain the session ID.
let mut buf2 = [0u8; 4096];
let n2 = socket.read(&mut buf2).await.unwrap();
let req2 = String::from_utf8_lossy(&buf2[..n2]);
// reqwest lower-cases header names.
let req2_lower = req2.to_lowercase();
assert!(
req2_lower.contains("mcp-session-id: sess-abc-123"),
"second POST must replay captured session ID, got:\n{req2}"
);
socket
.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
.await
.unwrap();
});
let client = test_http_client();
let url = format!("http://{addr}/mcp");
let mut transport = StreamableHttpTransport::new(client, url, HashMap::new());
// First send: server returns Mcp-Session-Id.
transport
.send(json_frame(serde_json::json!({
"jsonrpc": "2.0", "id": 1,
"method": "initialize",
"params": {}
})))
.await
.unwrap();
assert_eq!(
transport.session_id.as_deref(),
Some("sess-abc-123"),
"session ID should be captured from response"
);
// Second send: should replay the session ID.
transport
.send(json_frame(serde_json::json!({
"jsonrpc": "2.0", "id": 2,
"method": "tools/list",
"params": {}
})))
.await
.unwrap();
server.abort();
}
/// Custom headers configured in McpServerConfig are applied to the GET
/// preflight so servers that require auth on session-establishment GET
/// (e.g. Hindsight, #1629) can authenticate it.
#[tokio::test]
async fn custom_headers_applied_to_get_preflight() {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
let _lock = lock_mcp_loopback_tests().await;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
// The test signals success by writing to this flag — the GET handler
// sets it when it sees the expected header.
let header_seen = Arc::new(AtomicBool::new(false));
let header_seen_srv = Arc::clone(&header_seen);
let server = tokio::spawn(async move {
let (mut socket, _) = listener.accept().await.unwrap();
let mut buf = [0u8; 4096];
let n = socket.read(&mut buf).await.unwrap();
let req = String::from_utf8_lossy(&buf[..n]);
// reqwest lower-cases header names.
if req.starts_with("GET ")
&& req.to_lowercase().contains("x-custom-auth: my-test-token")
{
header_seen_srv.store(true, AtomicOrdering::SeqCst);
}
socket
.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
.await
.unwrap();
});
let client = test_http_client();
let url = format!("http://{addr}/mcp");
let mut headers = HashMap::new();
headers.insert("X-Custom-Auth".to_string(), "my-test-token".to_string());
let mut transport = HttpTransport::new(
client,
url,
headers,
tokio_util::sync::CancellationToken::new(),
Duration::from_secs(10),
);
transport.try_establish_session().await.unwrap();
server.abort();
assert!(
header_seen.load(AtomicOrdering::SeqCst),
"GET preflight must include user-configured custom headers"
);
}
}