fix(mcp): persist Streamable HTTP session IDs

Capture and replay Mcp-Session-Id for Streamable HTTP transports, and apply configured custom headers to the GET preflight.\n\nCloses #1629.\n\nCo-authored-by: Zhiping <2716057626@qq.com>
This commit is contained in:
ZzzPL
2026-05-15 02:25:22 +08:00
committed by GitHub
parent 13e7957621
commit ece805568b
+256 -2
View File
@@ -569,8 +569,15 @@ struct StreamableHttpTransport {
/// runs before each request.
headers: HashMap<String, String>,
pending_messages: VecDeque<Vec<u8>>,
/// Per-spec MCP session identifier returned by the server in the
/// first response (typically the `initialize` response). Attached
/// as the `Mcp-Session-Id` header on every subsequent outbound
/// request so the server can correlate messages within the same
/// session.
session_id: Option<String>,
}
#[derive(Debug)]
enum StreamableSendError {
Incompatible(String),
Other(anyhow::Error),
@@ -797,6 +804,78 @@ impl HttpTransport {
self.mode = HttpTransportMode::Sse(sse);
Ok(())
}
/// Best-effort session-establishment GET preflight.
///
/// Per the Streamable HTTP spec, the server may return an
/// `Mcp-Session-Id` header on the `initialize` response (the normal
/// path handled inside [`StreamableHttpTransport::send`] above).
/// However some servers (e.g. Hindsight, #1629) **require** a session
/// ID on every POST including `initialize`, creating a chicken-and-egg
/// problem. For those servers we send a short-lived GET before the
/// first POST: if the server returns a session ID in the GET response
/// it will be captured by the header-reading code in
/// [`StreamableHttpTransport::send`] just as if it came from a POST
/// response.
///
/// This is intentionally best-effort:
/// * The GET uses a tight per-request inner timeout so it never
/// blocks connection startup for long.
/// * If the server doesn't support GET (405, 404, …) we log a debug
/// line and move on — the `initialize` POST will proceed without a
/// session ID.
/// * If the server opens an SSE stream in response (the GET from old
/// SSE transport), we read only the headers, then discard the body
/// so the SSE stream is torn down. The actual SSE path uses a
/// dedicated `SseTransport` and is triggered by the incompatible-
/// status fallback in [`HttpTransport::send`].
async fn try_establish_session(&mut self) -> Result<()> {
let transport = match &mut self.mode {
HttpTransportMode::Streamable(t) => t,
// Already on SSE — session is implicit via the long-lived GET.
HttpTransportMode::Sse(_) => return Ok(()),
};
let mut request = transport.client.get(&transport.url);
request = with_default_mcp_http_headers(request, false);
for (key, value) in &transport.headers {
if !is_safe_custom_header(key, value) {
tracing::warn!(
target: "mcp",
"skipping unsafe MCP header {:?} (empty/control-char/reserved)",
key
);
continue;
}
request = request.header(key.as_str(), value.as_str());
}
let response = tokio::time::timeout(Duration::from_secs(5), request.send())
.await
.map_err(|_| anyhow::anyhow!("GET timeout"))?
.map_err(|e| anyhow::anyhow!("GET error: {e}"))?;
// Capture session ID from the GET response so subsequent POSTs
// (including `initialize`) can include it. This is the same
// header-reading logic that would be hit inside
// `StreamableHttpTransport::send` for POST responses, but since
// the GET is sent before any POST we do it here directly.
if let Some(sid) = response
.headers()
.get("Mcp-Session-Id")
.and_then(|v| v.to_str().ok())
&& transport.session_id.as_deref() != Some(sid)
{
tracing::debug!(target: "mcp", session_id = %sid, "captured MCP session ID via GET preflight");
transport.session_id = Some(sid.to_string());
}
// We only care about the response headers — discard the body.
// If the server opened an SSE stream in response (some servers
// do this on GET), it will be torn down when response is dropped.
drop(response);
Ok(())
}
}
#[async_trait::async_trait]
@@ -839,6 +918,7 @@ impl StreamableHttpTransport {
url,
headers,
pending_messages: VecDeque::new(),
session_id: None,
}
}
@@ -866,6 +946,12 @@ impl StreamableHttpTransport {
}
request = request.header(key.as_str(), value.as_str());
}
// Attach any previously captured session ID per the Streamable
// HTTP spec so the server can correlate this request to the
// existing session.
if let Some(ref sid) = self.session_id {
request = request.header("Mcp-Session-Id", sid.as_str());
}
let response = request
.body(msg)
.send()
@@ -873,6 +959,19 @@ impl StreamableHttpTransport {
.map_err(|err| StreamableSendError::Other(err.into()))?;
let status = response.status();
// Capture session ID from any response (2xx, 202, 4xx, …). The
// server may return it on the `initialize` response or on a
// best-effort GET preflight below.
if let Some(sid) = response
.headers()
.get("Mcp-Session-Id")
.and_then(|v| v.to_str().ok())
&& self.session_id.as_deref() != Some(sid)
{
tracing::debug!(target: "mcp", session_id = %sid, "captured MCP session ID");
self.session_id = Some(sid.to_string());
}
if status == StatusCode::ACCEPTED || status == StatusCode::NO_CONTENT {
return Ok(());
}
@@ -1113,13 +1212,27 @@ impl McpConnection {
}
}
let client = client_builder.build()?;
Box::new(HttpTransport::new(
let mut http = HttpTransport::new(
client,
url.clone(),
config.headers.clone(),
cancel_token.clone(),
Duration::from_secs(connect_timeout_secs),
))
);
// Best-effort session preflight for servers that require
// a session ID on every POST including `initialize`
// (e.g. Hindsight, #1629). Failures are non-fatal — the
// `initialize` POST will proceed and may capture a session
// ID from the response instead.
if let Err(e) = http.try_establish_session().await {
tracing::debug!(
target: "mcp",
server = %name,
error = %e,
"session-establishment GET skipped; proceeding with POST initialize"
);
}
Box::new(http)
} else if let Some(command) = &config.command {
let mut cmd = tokio::process::Command::new(command);
cmd.args(&config.args)
@@ -2617,6 +2730,7 @@ pub fn format_tool_result(result: &serde_json::Value) -> String {
mod tests {
use super::*;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
use std::sync::{Arc, Mutex};
#[test]
@@ -3764,4 +3878,144 @@ mod tests {
cancel_token.cancel();
server.abort();
}
#[test]
fn session_id_starts_none() {
let transport = StreamableHttpTransport::new(
reqwest::Client::new(),
"https://example.invalid/mcp".to_string(),
HashMap::new(),
);
assert!(transport.session_id.is_none());
}
/// Session ID captured from a POST response is replayed on the next POST.
#[tokio::test]
async fn session_id_captured_from_post_response_and_replayed() {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let server = tokio::spawn(async move {
let (mut socket, _) = listener.accept().await.unwrap();
let mut buf = [0u8; 4096];
let n = socket.read(&mut buf).await.unwrap();
let req = String::from_utf8_lossy(&buf[..n]);
assert!(req.starts_with("POST "), "expected POST, got: {req}");
// First POST: return a session ID so the transport captures it.
socket
.write_all(
b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nMcp-Session-Id: sess-abc-123\r\nContent-Length: 2\r\n\r\n{}",
)
.await
.unwrap();
socket.flush().await.unwrap();
// Read the second POST — should contain the session ID.
let mut buf2 = [0u8; 4096];
let n2 = socket.read(&mut buf2).await.unwrap();
let req2 = String::from_utf8_lossy(&buf2[..n2]);
// reqwest lower-cases header names.
let req2_lower = req2.to_lowercase();
assert!(
req2_lower.contains("mcp-session-id: sess-abc-123"),
"second POST must replay captured session ID, got:\n{req2}"
);
socket
.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
.await
.unwrap();
});
let client = reqwest::Client::new();
let url = format!("http://{addr}/mcp");
let mut transport = StreamableHttpTransport::new(client, url, HashMap::new());
// First send: server returns Mcp-Session-Id.
transport
.send(json_frame(serde_json::json!({
"jsonrpc": "2.0", "id": 1,
"method": "initialize",
"params": {}
})))
.await
.unwrap();
assert_eq!(
transport.session_id.as_deref(),
Some("sess-abc-123"),
"session ID should be captured from response"
);
// Second send: should replay the session ID.
transport
.send(json_frame(serde_json::json!({
"jsonrpc": "2.0", "id": 2,
"method": "tools/list",
"params": {}
})))
.await
.unwrap();
server.abort();
}
/// Custom headers configured in McpServerConfig are applied to the GET
/// preflight so servers that require auth on session-establishment GET
/// (e.g. Hindsight, #1629) can authenticate it.
#[tokio::test]
async fn custom_headers_applied_to_get_preflight() {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
// The test signals success by writing to this flag — the GET handler
// sets it when it sees the expected header.
let header_seen = Arc::new(AtomicBool::new(false));
let header_seen_srv = Arc::clone(&header_seen);
let server = tokio::spawn(async move {
let (mut socket, _) = listener.accept().await.unwrap();
let mut buf = [0u8; 4096];
let n = socket.read(&mut buf).await.unwrap();
let req = String::from_utf8_lossy(&buf[..n]);
// reqwest lower-cases header names.
if req.starts_with("GET ")
&& req.to_lowercase().contains("x-custom-auth: my-test-token")
{
header_seen_srv.store(true, AtomicOrdering::SeqCst);
}
socket
.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
.await
.unwrap();
});
let client = reqwest::Client::new();
let url = format!("http://{addr}/mcp");
let mut headers = HashMap::new();
headers.insert("X-Custom-Auth".to_string(), "my-test-token".to_string());
let mut transport = HttpTransport::new(
client,
url,
headers,
tokio_util::sync::CancellationToken::new(),
Duration::from_secs(10),
);
transport.try_establish_session().await.unwrap();
server.abort();
assert!(
header_seen.load(AtomicOrdering::SeqCst),
"GET preflight must include user-configured custom headers"
);
}
}