diff --git a/crates/tui/src/mcp.rs b/crates/tui/src/mcp.rs index 34394090..d25c0734 100644 --- a/crates/tui/src/mcp.rs +++ b/crates/tui/src/mcp.rs @@ -569,8 +569,15 @@ struct StreamableHttpTransport { /// runs before each request. headers: HashMap, pending_messages: VecDeque>, + /// Per-spec MCP session identifier returned by the server in the + /// first response (typically the `initialize` response). Attached + /// as the `Mcp-Session-Id` header on every subsequent outbound + /// request so the server can correlate messages within the same + /// session. + session_id: Option, } +#[derive(Debug)] enum StreamableSendError { Incompatible(String), Other(anyhow::Error), @@ -797,6 +804,78 @@ impl HttpTransport { self.mode = HttpTransportMode::Sse(sse); Ok(()) } + + /// Best-effort session-establishment GET preflight. + /// + /// Per the Streamable HTTP spec, the server may return an + /// `Mcp-Session-Id` header on the `initialize` response (the normal + /// path handled inside [`StreamableHttpTransport::send`] above). + /// However some servers (e.g. Hindsight, #1629) **require** a session + /// ID on every POST including `initialize`, creating a chicken-and-egg + /// problem. For those servers we send a short-lived GET before the + /// first POST: if the server returns a session ID in the GET response + /// it will be captured by the header-reading code in + /// [`StreamableHttpTransport::send`] just as if it came from a POST + /// response. + /// + /// This is intentionally best-effort: + /// * The GET uses a tight per-request inner timeout so it never + /// blocks connection startup for long. + /// * If the server doesn't support GET (405, 404, …) we log a debug + /// line and move on — the `initialize` POST will proceed without a + /// session ID. + /// * If the server opens an SSE stream in response (the GET from old + /// SSE transport), we read only the headers, then discard the body + /// so the SSE stream is torn down. The actual SSE path uses a + /// dedicated `SseTransport` and is triggered by the incompatible- + /// status fallback in [`HttpTransport::send`]. + async fn try_establish_session(&mut self) -> Result<()> { + let transport = match &mut self.mode { + HttpTransportMode::Streamable(t) => t, + // Already on SSE — session is implicit via the long-lived GET. + HttpTransportMode::Sse(_) => return Ok(()), + }; + + let mut request = transport.client.get(&transport.url); + request = with_default_mcp_http_headers(request, false); + for (key, value) in &transport.headers { + if !is_safe_custom_header(key, value) { + tracing::warn!( + target: "mcp", + "skipping unsafe MCP header {:?} (empty/control-char/reserved)", + key + ); + continue; + } + request = request.header(key.as_str(), value.as_str()); + } + let response = tokio::time::timeout(Duration::from_secs(5), request.send()) + .await + .map_err(|_| anyhow::anyhow!("GET timeout"))? + .map_err(|e| anyhow::anyhow!("GET error: {e}"))?; + + // Capture session ID from the GET response so subsequent POSTs + // (including `initialize`) can include it. This is the same + // header-reading logic that would be hit inside + // `StreamableHttpTransport::send` for POST responses, but since + // the GET is sent before any POST we do it here directly. + if let Some(sid) = response + .headers() + .get("Mcp-Session-Id") + .and_then(|v| v.to_str().ok()) + && transport.session_id.as_deref() != Some(sid) + { + tracing::debug!(target: "mcp", session_id = %sid, "captured MCP session ID via GET preflight"); + transport.session_id = Some(sid.to_string()); + } + + // We only care about the response headers — discard the body. + // If the server opened an SSE stream in response (some servers + // do this on GET), it will be torn down when response is dropped. + drop(response); + + Ok(()) + } } #[async_trait::async_trait] @@ -839,6 +918,7 @@ impl StreamableHttpTransport { url, headers, pending_messages: VecDeque::new(), + session_id: None, } } @@ -866,6 +946,12 @@ impl StreamableHttpTransport { } request = request.header(key.as_str(), value.as_str()); } + // Attach any previously captured session ID per the Streamable + // HTTP spec so the server can correlate this request to the + // existing session. + if let Some(ref sid) = self.session_id { + request = request.header("Mcp-Session-Id", sid.as_str()); + } let response = request .body(msg) .send() @@ -873,6 +959,19 @@ impl StreamableHttpTransport { .map_err(|err| StreamableSendError::Other(err.into()))?; let status = response.status(); + + // Capture session ID from any response (2xx, 202, 4xx, …). The + // server may return it on the `initialize` response or on a + // best-effort GET preflight below. + if let Some(sid) = response + .headers() + .get("Mcp-Session-Id") + .and_then(|v| v.to_str().ok()) + && self.session_id.as_deref() != Some(sid) + { + tracing::debug!(target: "mcp", session_id = %sid, "captured MCP session ID"); + self.session_id = Some(sid.to_string()); + } if status == StatusCode::ACCEPTED || status == StatusCode::NO_CONTENT { return Ok(()); } @@ -1113,13 +1212,27 @@ impl McpConnection { } } let client = client_builder.build()?; - Box::new(HttpTransport::new( + let mut http = HttpTransport::new( client, url.clone(), config.headers.clone(), cancel_token.clone(), Duration::from_secs(connect_timeout_secs), - )) + ); + // Best-effort session preflight for servers that require + // a session ID on every POST including `initialize` + // (e.g. Hindsight, #1629). Failures are non-fatal — the + // `initialize` POST will proceed and may capture a session + // ID from the response instead. + if let Err(e) = http.try_establish_session().await { + tracing::debug!( + target: "mcp", + server = %name, + error = %e, + "session-establishment GET skipped; proceeding with POST initialize" + ); + } + Box::new(http) } else if let Some(command) = &config.command { let mut cmd = tokio::process::Command::new(command); cmd.args(&config.args) @@ -2617,6 +2730,7 @@ pub fn format_tool_result(result: &serde_json::Value) -> String { mod tests { use super::*; use std::collections::VecDeque; + use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering}; use std::sync::{Arc, Mutex}; #[test] @@ -3764,4 +3878,144 @@ mod tests { cancel_token.cancel(); server.abort(); } + + #[test] + fn session_id_starts_none() { + let transport = StreamableHttpTransport::new( + reqwest::Client::new(), + "https://example.invalid/mcp".to_string(), + HashMap::new(), + ); + assert!(transport.session_id.is_none()); + } + + /// Session ID captured from a POST response is replayed on the next POST. + #[tokio::test] + async fn session_id_captured_from_post_response_and_replayed() { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpListener; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let server = tokio::spawn(async move { + let (mut socket, _) = listener.accept().await.unwrap(); + let mut buf = [0u8; 4096]; + let n = socket.read(&mut buf).await.unwrap(); + let req = String::from_utf8_lossy(&buf[..n]); + assert!(req.starts_with("POST "), "expected POST, got: {req}"); + + // First POST: return a session ID so the transport captures it. + socket + .write_all( + b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nMcp-Session-Id: sess-abc-123\r\nContent-Length: 2\r\n\r\n{}", + ) + .await + .unwrap(); + socket.flush().await.unwrap(); + + // Read the second POST — should contain the session ID. + let mut buf2 = [0u8; 4096]; + let n2 = socket.read(&mut buf2).await.unwrap(); + let req2 = String::from_utf8_lossy(&buf2[..n2]); + // reqwest lower-cases header names. + let req2_lower = req2.to_lowercase(); + assert!( + req2_lower.contains("mcp-session-id: sess-abc-123"), + "second POST must replay captured session ID, got:\n{req2}" + ); + + socket + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n") + .await + .unwrap(); + }); + + let client = reqwest::Client::new(); + let url = format!("http://{addr}/mcp"); + let mut transport = StreamableHttpTransport::new(client, url, HashMap::new()); + + // First send: server returns Mcp-Session-Id. + transport + .send(json_frame(serde_json::json!({ + "jsonrpc": "2.0", "id": 1, + "method": "initialize", + "params": {} + }))) + .await + .unwrap(); + assert_eq!( + transport.session_id.as_deref(), + Some("sess-abc-123"), + "session ID should be captured from response" + ); + + // Second send: should replay the session ID. + transport + .send(json_frame(serde_json::json!({ + "jsonrpc": "2.0", "id": 2, + "method": "tools/list", + "params": {} + }))) + .await + .unwrap(); + + server.abort(); + } + + /// Custom headers configured in McpServerConfig are applied to the GET + /// preflight so servers that require auth on session-establishment GET + /// (e.g. Hindsight, #1629) can authenticate it. + #[tokio::test] + async fn custom_headers_applied_to_get_preflight() { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpListener; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + // The test signals success by writing to this flag — the GET handler + // sets it when it sees the expected header. + let header_seen = Arc::new(AtomicBool::new(false)); + let header_seen_srv = Arc::clone(&header_seen); + + let server = tokio::spawn(async move { + let (mut socket, _) = listener.accept().await.unwrap(); + let mut buf = [0u8; 4096]; + let n = socket.read(&mut buf).await.unwrap(); + let req = String::from_utf8_lossy(&buf[..n]); + + // reqwest lower-cases header names. + if req.starts_with("GET ") + && req.to_lowercase().contains("x-custom-auth: my-test-token") + { + header_seen_srv.store(true, AtomicOrdering::SeqCst); + } + + socket + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n") + .await + .unwrap(); + }); + + let client = reqwest::Client::new(); + let url = format!("http://{addr}/mcp"); + let mut headers = HashMap::new(); + headers.insert("X-Custom-Auth".to_string(), "my-test-token".to_string()); + + let mut transport = HttpTransport::new( + client, + url, + headers, + tokio_util::sync::CancellationToken::new(), + Duration::from_secs(10), + ); + + transport.try_establish_session().await.unwrap(); + + server.abort(); + + assert!( + header_seen.load(AtomicOrdering::SeqCst), + "GET preflight must include user-configured custom headers" + ); + } }