fix(mcp): persist Streamable HTTP session IDs
Capture and replay Mcp-Session-Id for Streamable HTTP transports, and apply configured custom headers to the GET preflight.\n\nCloses #1629.\n\nCo-authored-by: Zhiping <2716057626@qq.com>
This commit is contained in:
+256
-2
@@ -569,8 +569,15 @@ struct StreamableHttpTransport {
|
||||
/// runs before each request.
|
||||
headers: HashMap<String, String>,
|
||||
pending_messages: VecDeque<Vec<u8>>,
|
||||
/// Per-spec MCP session identifier returned by the server in the
|
||||
/// first response (typically the `initialize` response). Attached
|
||||
/// as the `Mcp-Session-Id` header on every subsequent outbound
|
||||
/// request so the server can correlate messages within the same
|
||||
/// session.
|
||||
session_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum StreamableSendError {
|
||||
Incompatible(String),
|
||||
Other(anyhow::Error),
|
||||
@@ -797,6 +804,78 @@ impl HttpTransport {
|
||||
self.mode = HttpTransportMode::Sse(sse);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Best-effort session-establishment GET preflight.
|
||||
///
|
||||
/// Per the Streamable HTTP spec, the server may return an
|
||||
/// `Mcp-Session-Id` header on the `initialize` response (the normal
|
||||
/// path handled inside [`StreamableHttpTransport::send`] above).
|
||||
/// However some servers (e.g. Hindsight, #1629) **require** a session
|
||||
/// ID on every POST including `initialize`, creating a chicken-and-egg
|
||||
/// problem. For those servers we send a short-lived GET before the
|
||||
/// first POST: if the server returns a session ID in the GET response
|
||||
/// it will be captured by the header-reading code in
|
||||
/// [`StreamableHttpTransport::send`] just as if it came from a POST
|
||||
/// response.
|
||||
///
|
||||
/// This is intentionally best-effort:
|
||||
/// * The GET uses a tight per-request inner timeout so it never
|
||||
/// blocks connection startup for long.
|
||||
/// * If the server doesn't support GET (405, 404, …) we log a debug
|
||||
/// line and move on — the `initialize` POST will proceed without a
|
||||
/// session ID.
|
||||
/// * If the server opens an SSE stream in response (the GET from old
|
||||
/// SSE transport), we read only the headers, then discard the body
|
||||
/// so the SSE stream is torn down. The actual SSE path uses a
|
||||
/// dedicated `SseTransport` and is triggered by the incompatible-
|
||||
/// status fallback in [`HttpTransport::send`].
|
||||
async fn try_establish_session(&mut self) -> Result<()> {
|
||||
let transport = match &mut self.mode {
|
||||
HttpTransportMode::Streamable(t) => t,
|
||||
// Already on SSE — session is implicit via the long-lived GET.
|
||||
HttpTransportMode::Sse(_) => return Ok(()),
|
||||
};
|
||||
|
||||
let mut request = transport.client.get(&transport.url);
|
||||
request = with_default_mcp_http_headers(request, false);
|
||||
for (key, value) in &transport.headers {
|
||||
if !is_safe_custom_header(key, value) {
|
||||
tracing::warn!(
|
||||
target: "mcp",
|
||||
"skipping unsafe MCP header {:?} (empty/control-char/reserved)",
|
||||
key
|
||||
);
|
||||
continue;
|
||||
}
|
||||
request = request.header(key.as_str(), value.as_str());
|
||||
}
|
||||
let response = tokio::time::timeout(Duration::from_secs(5), request.send())
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("GET timeout"))?
|
||||
.map_err(|e| anyhow::anyhow!("GET error: {e}"))?;
|
||||
|
||||
// Capture session ID from the GET response so subsequent POSTs
|
||||
// (including `initialize`) can include it. This is the same
|
||||
// header-reading logic that would be hit inside
|
||||
// `StreamableHttpTransport::send` for POST responses, but since
|
||||
// the GET is sent before any POST we do it here directly.
|
||||
if let Some(sid) = response
|
||||
.headers()
|
||||
.get("Mcp-Session-Id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
&& transport.session_id.as_deref() != Some(sid)
|
||||
{
|
||||
tracing::debug!(target: "mcp", session_id = %sid, "captured MCP session ID via GET preflight");
|
||||
transport.session_id = Some(sid.to_string());
|
||||
}
|
||||
|
||||
// We only care about the response headers — discard the body.
|
||||
// If the server opened an SSE stream in response (some servers
|
||||
// do this on GET), it will be torn down when response is dropped.
|
||||
drop(response);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@@ -839,6 +918,7 @@ impl StreamableHttpTransport {
|
||||
url,
|
||||
headers,
|
||||
pending_messages: VecDeque::new(),
|
||||
session_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -866,6 +946,12 @@ impl StreamableHttpTransport {
|
||||
}
|
||||
request = request.header(key.as_str(), value.as_str());
|
||||
}
|
||||
// Attach any previously captured session ID per the Streamable
|
||||
// HTTP spec so the server can correlate this request to the
|
||||
// existing session.
|
||||
if let Some(ref sid) = self.session_id {
|
||||
request = request.header("Mcp-Session-Id", sid.as_str());
|
||||
}
|
||||
let response = request
|
||||
.body(msg)
|
||||
.send()
|
||||
@@ -873,6 +959,19 @@ impl StreamableHttpTransport {
|
||||
.map_err(|err| StreamableSendError::Other(err.into()))?;
|
||||
|
||||
let status = response.status();
|
||||
|
||||
// Capture session ID from any response (2xx, 202, 4xx, …). The
|
||||
// server may return it on the `initialize` response or on a
|
||||
// best-effort GET preflight below.
|
||||
if let Some(sid) = response
|
||||
.headers()
|
||||
.get("Mcp-Session-Id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
&& self.session_id.as_deref() != Some(sid)
|
||||
{
|
||||
tracing::debug!(target: "mcp", session_id = %sid, "captured MCP session ID");
|
||||
self.session_id = Some(sid.to_string());
|
||||
}
|
||||
if status == StatusCode::ACCEPTED || status == StatusCode::NO_CONTENT {
|
||||
return Ok(());
|
||||
}
|
||||
@@ -1113,13 +1212,27 @@ impl McpConnection {
|
||||
}
|
||||
}
|
||||
let client = client_builder.build()?;
|
||||
Box::new(HttpTransport::new(
|
||||
let mut http = HttpTransport::new(
|
||||
client,
|
||||
url.clone(),
|
||||
config.headers.clone(),
|
||||
cancel_token.clone(),
|
||||
Duration::from_secs(connect_timeout_secs),
|
||||
))
|
||||
);
|
||||
// Best-effort session preflight for servers that require
|
||||
// a session ID on every POST including `initialize`
|
||||
// (e.g. Hindsight, #1629). Failures are non-fatal — the
|
||||
// `initialize` POST will proceed and may capture a session
|
||||
// ID from the response instead.
|
||||
if let Err(e) = http.try_establish_session().await {
|
||||
tracing::debug!(
|
||||
target: "mcp",
|
||||
server = %name,
|
||||
error = %e,
|
||||
"session-establishment GET skipped; proceeding with POST initialize"
|
||||
);
|
||||
}
|
||||
Box::new(http)
|
||||
} else if let Some(command) = &config.command {
|
||||
let mut cmd = tokio::process::Command::new(command);
|
||||
cmd.args(&config.args)
|
||||
@@ -2617,6 +2730,7 @@ pub fn format_tool_result(result: &serde_json::Value) -> String {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
#[test]
|
||||
@@ -3764,4 +3878,144 @@ mod tests {
|
||||
cancel_token.cancel();
|
||||
server.abort();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_id_starts_none() {
|
||||
let transport = StreamableHttpTransport::new(
|
||||
reqwest::Client::new(),
|
||||
"https://example.invalid/mcp".to_string(),
|
||||
HashMap::new(),
|
||||
);
|
||||
assert!(transport.session_id.is_none());
|
||||
}
|
||||
|
||||
/// Session ID captured from a POST response is replayed on the next POST.
|
||||
#[tokio::test]
|
||||
async fn session_id_captured_from_post_response_and_replayed() {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let server = tokio::spawn(async move {
|
||||
let (mut socket, _) = listener.accept().await.unwrap();
|
||||
let mut buf = [0u8; 4096];
|
||||
let n = socket.read(&mut buf).await.unwrap();
|
||||
let req = String::from_utf8_lossy(&buf[..n]);
|
||||
assert!(req.starts_with("POST "), "expected POST, got: {req}");
|
||||
|
||||
// First POST: return a session ID so the transport captures it.
|
||||
socket
|
||||
.write_all(
|
||||
b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nMcp-Session-Id: sess-abc-123\r\nContent-Length: 2\r\n\r\n{}",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
socket.flush().await.unwrap();
|
||||
|
||||
// Read the second POST — should contain the session ID.
|
||||
let mut buf2 = [0u8; 4096];
|
||||
let n2 = socket.read(&mut buf2).await.unwrap();
|
||||
let req2 = String::from_utf8_lossy(&buf2[..n2]);
|
||||
// reqwest lower-cases header names.
|
||||
let req2_lower = req2.to_lowercase();
|
||||
assert!(
|
||||
req2_lower.contains("mcp-session-id: sess-abc-123"),
|
||||
"second POST must replay captured session ID, got:\n{req2}"
|
||||
);
|
||||
|
||||
socket
|
||||
.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!("http://{addr}/mcp");
|
||||
let mut transport = StreamableHttpTransport::new(client, url, HashMap::new());
|
||||
|
||||
// First send: server returns Mcp-Session-Id.
|
||||
transport
|
||||
.send(json_frame(serde_json::json!({
|
||||
"jsonrpc": "2.0", "id": 1,
|
||||
"method": "initialize",
|
||||
"params": {}
|
||||
})))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
transport.session_id.as_deref(),
|
||||
Some("sess-abc-123"),
|
||||
"session ID should be captured from response"
|
||||
);
|
||||
|
||||
// Second send: should replay the session ID.
|
||||
transport
|
||||
.send(json_frame(serde_json::json!({
|
||||
"jsonrpc": "2.0", "id": 2,
|
||||
"method": "tools/list",
|
||||
"params": {}
|
||||
})))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
server.abort();
|
||||
}
|
||||
|
||||
/// Custom headers configured in McpServerConfig are applied to the GET
|
||||
/// preflight so servers that require auth on session-establishment GET
|
||||
/// (e.g. Hindsight, #1629) can authenticate it.
|
||||
#[tokio::test]
|
||||
async fn custom_headers_applied_to_get_preflight() {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
// The test signals success by writing to this flag — the GET handler
|
||||
// sets it when it sees the expected header.
|
||||
let header_seen = Arc::new(AtomicBool::new(false));
|
||||
let header_seen_srv = Arc::clone(&header_seen);
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let (mut socket, _) = listener.accept().await.unwrap();
|
||||
let mut buf = [0u8; 4096];
|
||||
let n = socket.read(&mut buf).await.unwrap();
|
||||
let req = String::from_utf8_lossy(&buf[..n]);
|
||||
|
||||
// reqwest lower-cases header names.
|
||||
if req.starts_with("GET ")
|
||||
&& req.to_lowercase().contains("x-custom-auth: my-test-token")
|
||||
{
|
||||
header_seen_srv.store(true, AtomicOrdering::SeqCst);
|
||||
}
|
||||
|
||||
socket
|
||||
.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!("http://{addr}/mcp");
|
||||
let mut headers = HashMap::new();
|
||||
headers.insert("X-Custom-Auth".to_string(), "my-test-token".to_string());
|
||||
|
||||
let mut transport = HttpTransport::new(
|
||||
client,
|
||||
url,
|
||||
headers,
|
||||
tokio_util::sync::CancellationToken::new(),
|
||||
Duration::from_secs(10),
|
||||
);
|
||||
|
||||
transport.try_establish_session().await.unwrap();
|
||||
|
||||
server.abort();
|
||||
|
||||
assert!(
|
||||
header_seen.load(AtomicOrdering::SeqCst),
|
||||
"GET preflight must include user-configured custom headers"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user