Retry MCP calls after stale SSE connections

This commit is contained in:
zhuang biaowei
2026-05-28 21:29:58 +08:00
committed by Hunter Bown
parent 58c57cb798
commit d26c2128b8
+591 -1
View File
@@ -610,6 +610,7 @@ struct StreamableHttpTransport {
/// Failure categories for one Streamable HTTP send attempt, so the caller can
/// choose a recovery path instead of treating every failure alike.
#[derive(Debug)]
enum StreamableSendError {
/// The response status marks the endpoint as incompatible with Streamable
/// HTTP. Carries a `status=... body=...` detail string.
Incompatible(String),
/// A request sent with a cached session ID was rejected in a way that looks
/// like an expired or invalid session. Carries a `status=... body=...`
/// detail string; the visible handler clears the cached session ID and
/// returns an error asking for a retry with a new session.
StaleSession(String),
/// Any other failure; the visible handler returns the wrapped error unchanged.
Other(anyhow::Error),
}
@@ -922,6 +923,19 @@ impl McpTransport for HttpTransport {
);
self.switch_to_sse_and_send(msg).await
}
Err(StreamableSendError::StaleSession(detail)) => {
if let HttpTransportMode::Streamable(transport) = &mut self.mode {
tracing::debug!(
target: "mcp",
error = %detail,
"MCP Streamable HTTP session expired; clearing cached session ID"
);
transport.session_id = None;
}
Err(anyhow::anyhow!(
"MCP Streamable HTTP session expired; retry with a new session required ({detail})"
))
}
Err(StreamableSendError::Other(err)) => Err(err),
},
HttpTransportMode::Sse(transport) => transport.send(msg).await,
@@ -992,6 +1006,13 @@ impl StreamableHttpTransport {
if !status.is_success() {
let body_excerpt = bounded_body_excerpt(response, ERROR_BODY_PREVIEW_BYTES).await;
if self.session_id.is_some()
&& is_streamable_http_stale_session_status(status, &body_excerpt)
{
return Err(StreamableSendError::StaleSession(format!(
"status={status} body={body_excerpt}"
)));
}
if is_streamable_http_incompatible_status(status) {
return Err(StreamableSendError::Incompatible(format!(
"status={status} body={body_excerpt}"
@@ -1058,6 +1079,30 @@ fn is_streamable_http_incompatible_status(status: StatusCode) -> bool {
)
}
/// Decides whether a non-success Streamable HTTP response looks like the
/// server rejecting a session it no longer recognises.
///
/// * `404 Not Found` always counts, regardless of the body.
/// * `400 Bad Request` and `401 Unauthorized` count only when the body excerpt
///   mentions "session" together with "expired" or "invalid"
///   (ASCII case-insensitive).
/// * Every other status is never treated as a stale session.
///
/// The visible caller only consults this when a session ID is currently cached.
fn is_streamable_http_stale_session_status(status: StatusCode, body_excerpt: &str) -> bool {
    match status {
        StatusCode::NOT_FOUND => true,
        StatusCode::BAD_REQUEST | StatusCode::UNAUTHORIZED => {
            let lowered = body_excerpt.to_ascii_lowercase();
            let mentions_session = lowered.contains("session");
            mentions_session
                && ["expired", "invalid"]
                    .iter()
                    .any(|marker| lowered.contains(*marker))
        }
        _ => false,
    }
}
/// Heuristic text check for a stale-session report.
///
/// Returns `true` when `body`, compared ASCII case-insensitively, contains the
/// word "session" and also contains "expired" or "invalid". The words may
/// appear anywhere in the text and in any order.
fn is_mcp_stale_session_body(body: &str) -> bool {
    let lowered = body.to_ascii_lowercase();
    if !lowered.contains("session") {
        return false;
    }
    ["expired", "invalid"]
        .iter()
        .any(|marker| lowered.contains(*marker))
}
/// Reports whether `err` should be treated as a stale MCP session or connection.
///
/// The error is rendered with its full cause chain (`{:#}`) and accepted when
/// the text contains one of the fixed markers below, or when it passes the
/// generic `is_mcp_stale_session_body` heuristic.
fn is_mcp_stale_session_error(err: &anyhow::Error) -> bool {
    // Case-sensitive substrings matched against the rendered error chain.
    const STALE_MARKERS: [&str; 3] = [
        "MCP Streamable HTTP session expired",
        "MCP session expired",
        "SSE transport closed",
    ];
    let rendered = format!("{err:#}");
    STALE_MARKERS
        .iter()
        .any(|marker| rendered.contains(*marker))
        || is_mcp_stale_session_body(&rendered)
}
fn parse_sse_message_data(body: &str) -> Vec<Vec<u8>> {
let normalized = body.replace("\r\n", "\n");
let mut messages = Vec::new();
@@ -1148,6 +1193,14 @@ impl McpTransport for SseTransport {
let status = response.status();
if !status.is_success() {
let body_excerpt = bounded_body_excerpt(response, ERROR_BODY_PREVIEW_BYTES).await;
if is_mcp_stale_session_body(&body_excerpt) {
anyhow::bail!(
"MCP session expired (transport=sse endpoint={} status={}): {}",
mask_url_secrets(endpoint),
status,
body_excerpt
);
}
anyhow::bail!(
"MCP SSE POST rejected (transport=sse endpoint={} status={}): {}",
mask_url_secrets(endpoint),
@@ -1779,6 +1832,11 @@ impl McpConnection {
// IDs, but accept numeric echoes for compatibility with older
// servers and tests.
if response_id_matches(value.get("id"), &expected_id) {
if let Some(error) = value.get("error") {
if is_mcp_stale_session_body(&error.to_string()) {
anyhow::bail!("MCP session expired: {error}");
}
}
return Ok(value);
}
// Skip notifications (no id) and responses with different ids
@@ -2359,7 +2417,26 @@ impl McpPool {
anyhow::bail!("MCP tool '{tool_name}' is disabled for server '{server_name}'");
}
let timeout = conn.config().effective_execute_timeout(&global_timeouts);
conn.call_tool(tool_name, arguments, timeout).await
match conn.call_tool(tool_name, arguments.clone(), timeout).await {
Ok(result) => Ok(result),
Err(err) if is_mcp_stale_session_error(&err) => {
tracing::debug!(
target: "mcp",
server = server_name,
tool = tool_name,
error = %err,
"retrying MCP tool call after stale session"
);
self.connections.remove(server_name);
let conn = self.get_or_connect(server_name).await?;
if !conn.config().is_tool_enabled(tool_name) {
anyhow::bail!("MCP tool '{tool_name}' is disabled for server '{server_name}'");
}
let timeout = conn.config().effective_execute_timeout(&global_timeouts);
conn.call_tool(tool_name, arguments, timeout).await
}
Err(err) => Err(err),
}
}
/// Get list of configured server names
@@ -3500,6 +3577,94 @@ mod tests {
);
}
/// A JSON-RPC error reply whose message reports an expired session must make
/// the tool call fail with an error that `is_mcp_stale_session_error` accepts.
#[tokio::test]
async fn json_rpc_session_error_is_marked_stale() {
    // The only scripted reply: an error answering request id 1.
    let expired_reply = serde_json::json!({
        "jsonrpc": "2.0",
        "id": 1,
        "error": {
            "code": -32001,
            "message": "MCP session expired"
        }
    });
    let recorded = Arc::new(Mutex::new(Vec::new()));
    let scripted = ScriptedValueTransport {
        sent: Arc::clone(&recorded),
        responses: VecDeque::from([json_frame(expired_reply)]),
    };
    let mut connection = test_connection(Box::new(scripted));

    let outcome = connection
        .call_tool("search", serde_json::json!({"query": "dephy"}), 1)
        .await;
    let err = outcome.expect_err("session error should fail");

    assert!(
        is_mcp_stale_session_error(&err),
        "JSON-RPC session error should be retryable, got: {err:#}"
    );
}
/// The literal "SSE transport closed" error must be classified as stale.
#[test]
fn sse_transport_closed_is_retryable() {
    let closed_stream = anyhow::Error::msg("SSE transport closed");
    let retryable = is_mcp_stale_session_error(&closed_stream);
    assert!(
        retryable,
        "closed SSE stream should force reconnect before retry"
    );
}
/// `discover_all` must succeed when only the tools listing works and the three
/// later requests (ids 2-4) are answered with JSON-RPC error -32601.
#[tokio::test]
async fn discover_all_ignores_unsupported_optional_capabilities() {
    // Builds a scripted -32601 error reply for request `id`.
    let unsupported = |id: u64, message: &str| {
        json_frame(serde_json::json!({
            "jsonrpc": "2.0",
            "id": id,
            "error": {
                "code": -32601,
                "message": message
            }
        }))
    };
    // Reply to the first request: a single tool named "search".
    let tools_reply = json_frame(serde_json::json!({
        "jsonrpc": "2.0",
        "id": 1,
        "result": {
            "tools": [
                { "name": "search", "inputSchema": {} }
            ]
        }
    }));

    let recorded = Arc::new(Mutex::new(Vec::new()));
    let scripted = ScriptedValueTransport {
        sent: Arc::clone(&recorded),
        responses: VecDeque::from([
            tools_reply,
            unsupported(2, "resources not supported"),
            unsupported(3, "resource templates not supported"),
            unsupported(4, "prompts not supported"),
        ]),
    };
    let mut connection = test_connection(Box::new(scripted));

    connection.discover_all().await.expect("discover");

    assert_eq!(connection.tools.len(), 1);
    assert_eq!(connection.tools[0].name, "search");
    assert!(connection.resources.is_empty());
    assert!(connection.resource_templates.is_empty());
    assert!(connection.prompts.is_empty());
}
/// #1244: when an MCP stdio server fails to spawn, the underlying OS
/// error (e.g. ENOENT for a missing binary) must reach the user via the
/// snapshot.error string. Regression test for `err.to_string()` dropping
@@ -4277,6 +4442,431 @@ mod tests {
server.abort();
}
/// End-to-end test against a hand-rolled HTTP server on a loopback port.
///
/// The server hands out session "sess-old" on the first `GET /mcp` and
/// "sess-new" on later ones, and answers a "tools/call" carrying "sess-old"
/// with `404` and a "session expired" body. The test expects a single
/// `McpPool::call_tool` invocation to still return the successful result,
/// having observed both the stale rejection and a "tools/call" with "sess-new".
#[tokio::test]
async fn streamable_http_stale_session_reconnects_and_retries_tool_call() {
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
// Port 0 lets the OS pick a free port; `addr` is used to build the server URL below.
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
// Shared observations: GET count, stale rejection issued, successful tools/call served.
let get_count = Arc::new(AtomicUsize::new(0));
let stale_seen = Arc::new(AtomicBool::new(false));
let success_seen = Arc::new(AtomicBool::new(false));
let server_get_count = Arc::clone(&get_count);
let server_stale_seen = Arc::clone(&stale_seen);
let server_success_seen = Arc::clone(&success_seen);
// Accept loop: each connection gets its own task, which serves exactly one
// request and then returns (dropping the socket).
let server = tokio::spawn(async move {
loop {
let Ok((mut socket, _)) = listener.accept().await else {
break;
};
let get_count = Arc::clone(&server_get_count);
let stale_seen = Arc::clone(&server_stale_seen);
let success_seen = Arc::clone(&server_success_seen);
tokio::spawn(async move {
let mut request = Vec::new();
let mut buf = [0; 4096];
// Read until the blank line ending the header section; `header_end` is the
// offset of the first body byte.
let header_end = loop {
let n = socket.read(&mut buf).await.unwrap();
if n == 0 {
return;
}
request.extend_from_slice(&buf[..n]);
if let Some(pos) = request.windows(4).position(|w| w == b"\r\n\r\n") {
break pos + 4;
}
};
let headers = String::from_utf8_lossy(&request[..header_end]).to_string();
// Content-Length is matched case-insensitively; missing or unparsable means 0.
let content_length = headers
.lines()
.find_map(|line| {
let (name, value) = line.split_once(':')?;
name.eq_ignore_ascii_case("content-length")
.then(|| value.trim().parse::<usize>().ok())
.flatten()
})
.unwrap_or(0);
// Keep reading until the whole declared body has arrived.
while request.len() < header_end + content_length {
let n = socket.read(&mut buf).await.unwrap();
if n == 0 {
return;
}
request.extend_from_slice(&buf[..n]);
}
let body = &request[header_end..header_end + content_length];
// Session ID the client sent with this request, if any.
let session_header = headers.lines().find_map(|line| {
let (name, value) = line.split_once(':')?;
name.eq_ignore_ascii_case("mcp-session-id")
.then(|| value.trim().to_string())
});
// GET /mcp: reply with an empty 200 carrying a session ID; only the very
// first GET receives "sess-old".
if headers.starts_with("GET /mcp ") {
let count = get_count.fetch_add(1, AtomicOrdering::SeqCst);
let session = if count == 0 { "sess-old" } else { "sess-new" };
let response = format!(
"HTTP/1.1 200 OK\r\nMcp-Session-Id: {session}\r\nContent-Length: 0\r\n\r\n"
);
socket.write_all(response.as_bytes()).await.unwrap();
return;
}
// Everything else is parsed as a JSON-RPC request body.
let request_json: serde_json::Value = serde_json::from_slice(body).unwrap();
let method = request_json
.get("method")
.and_then(serde_json::Value::as_str)
.unwrap_or("");
let id = request_json
.get("id")
.cloned()
.unwrap_or_else(|| serde_json::json!("0"));
// A tools/call with the old session is rejected with 404; the declared
// Content-Length of 27 matches the JSON body that follows.
if method == "tools/call" && session_header.as_deref() == Some("sess-old") {
stale_seen.store(true, AtomicOrdering::SeqCst);
socket
.write_all(
b"HTTP/1.1 404 Not Found\r\nContent-Type: application/json\r\nContent-Length: 27\r\n\r\n{\"error\":\"session expired\"}",
)
.await
.unwrap();
return;
}
// Canned results per method.
let result = match method {
"initialize" => serde_json::json!({
"protocolVersion": "2024-11-05",
"capabilities": {}
}),
"tools/list" => serde_json::json!({
"tools": [
{ "name": "search", "inputSchema": {} }
]
}),
"resources/list" => serde_json::json!({ "resources": [] }),
"resources/templates/list" => {
serde_json::json!({ "resourceTemplates": [] })
}
"prompts/list" => serde_json::json!({ "prompts": [] }),
"tools/call" => {
// Only reachable when the session is not "sess-old"; it must be "sess-new".
assert_eq!(session_header.as_deref(), Some("sess-new"));
success_seen.store(true, AtomicOrdering::SeqCst);
serde_json::json!({ "content": [{ "type": "text", "text": "ok" }] })
}
// Any other method is acknowledged with an empty 202 and no JSON-RPC reply.
_ => {
socket
.write_all(b"HTTP/1.1 202 Accepted\r\nContent-Length: 0\r\n\r\n")
.await
.unwrap();
return;
}
};
// Wrap the canned result in a JSON-RPC response echoing the request id.
let response_body = serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"result": result
})
.to_string();
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
response_body.len(),
response_body
);
socket.write_all(response.as_bytes()).await.unwrap();
});
}
});
// Single server "dephy" pointing at the test listener; no explicit transport is set.
let mut cfg = McpConfig::default();
cfg.servers.insert(
"dephy".to_string(),
McpServerConfig {
command: None,
args: Vec::new(),
env: HashMap::new(),
url: Some(format!("http://{addr}/mcp")),
transport: None,
connect_timeout: Some(2),
execute_timeout: Some(2),
read_timeout: None,
disabled: false,
enabled: true,
required: false,
enabled_tools: Vec::new(),
disabled_tools: Vec::new(),
headers: HashMap::new(),
},
);
let mut pool = McpPool::new(cfg);
// NOTE(review): "mcp_dephy_search" presumably resolves to tool "search" on
// server "dephy" — the name mapping is not visible here.
let result = pool
.call_tool("mcp_dephy_search", serde_json::json!({ "query": "dephy" }))
.await
.unwrap();
assert_eq!(
result,
serde_json::json!({ "content": [{ "type": "text", "text": "ok" }] })
);
assert!(stale_seen.load(AtomicOrdering::SeqCst));
assert!(success_seen.load(AtomicOrdering::SeqCst));
// Exactly two GET /mcp requests are expected over the whole test — presumably
// one per connection the pool establishes (initial connect plus reconnect); TODO confirm.
assert_eq!(get_count.load(AtomicOrdering::SeqCst), 2);
server.abort();
}
/// Sends one message through a hand-built `SseTransport` to a server that
/// answers the POST with `400` and a "session expired" body, then checks that
/// the resulting error is accepted by `is_mcp_stale_session_error`.
#[tokio::test]
async fn legacy_sse_session_expiry_is_marked_stale() {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::sync::mpsc;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
// One-shot server: accepts a single connection and serves a single request.
let server = tokio::spawn(async move {
let (mut socket, _) = listener.accept().await.unwrap();
let mut request = Vec::new();
let mut buf = [0; 4096];
// Only the header section is awaited; the request body is not read before replying.
let header_end = loop {
let n = socket.read(&mut buf).await.unwrap();
if n == 0 {
return;
}
request.extend_from_slice(&buf[..n]);
if let Some(pos) = request.windows(4).position(|w| w == b"\r\n\r\n") {
break pos + 4;
}
};
let headers = String::from_utf8_lossy(&request[..header_end]);
assert!(headers.starts_with("POST /messages "));
// Declared Content-Length of 27 matches the JSON body that follows.
socket
.write_all(
b"HTTP/1.1 400 Bad Request\r\nContent-Type: application/json\r\nContent-Length: 27\r\n\r\n{\"error\":\"session expired\"}",
)
.await
.unwrap();
});
// `_sender` is a named binding, so the sending half stays alive until the end
// of the test and the receiver's channel is not closed.
let (_sender, receiver) = mpsc::unbounded_channel();
// Built directly with `endpoint_url` already set, so `send` posts to
// /messages without any prior SSE GET.
let mut transport = SseTransport {
client: reqwest::Client::new(),
base_url: format!("http://{addr}/sse"),
headers: HashMap::new(),
endpoint_url: Some(format!("http://{addr}/messages")),
receiver,
pending_messages: VecDeque::new(),
};
let err = transport
.send(br#"{"jsonrpc":"2.0","id":1,"method":"tools/call"}"#.to_vec())
.await
.expect_err("expired SSE session should fail");
assert!(
is_mcp_stale_session_error(&err),
"SSE session expiry should be retryable, got: {err:#}"
);
server.abort();
}
/// End-to-end test of the legacy SSE transport against a hand-rolled server.
///
/// The server closes the active SSE stream, without replying, on the first
/// "tools/call" it receives. The test expects a single `McpPool::call_tool`
/// invocation to still return the successful result, with the server having
/// seen two "tools/call" requests and two `GET /sse` connections in total.
#[tokio::test]
async fn legacy_sse_closed_stream_reconnects_and_retries_tool_call() {
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;
// Reads one HTTP request from `socket` and returns (raw header text, JSON body).
// Returns an empty header string and `Null` if the peer closes before the
// headers complete; returns `Null` for the body when it is empty or truncated.
async fn read_http_request(socket: &mut TcpStream) -> (String, serde_json::Value) {
let mut request = Vec::new();
let mut buf = [0; 4096];
let header_end = loop {
let n = socket.read(&mut buf).await.unwrap();
if n == 0 {
return (String::new(), serde_json::Value::Null);
}
request.extend_from_slice(&buf[..n]);
if let Some(pos) = request.windows(4).position(|w| w == b"\r\n\r\n") {
break pos + 4;
}
};
let headers = String::from_utf8_lossy(&request[..header_end]).to_string();
// Content-Length is matched case-insensitively; missing or unparsable means 0.
let content_length = headers
.lines()
.find_map(|line| {
let (name, value) = line.split_once(':')?;
name.eq_ignore_ascii_case("content-length")
.then(|| value.trim().parse::<usize>().ok())
.flatten()
})
.unwrap_or(0);
while request.len() < header_end + content_length {
let n = socket.read(&mut buf).await.unwrap();
if n == 0 {
return (headers, serde_json::Value::Null);
}
request.extend_from_slice(&buf[..n]);
}
let body = &request[header_end..header_end + content_length];
let json = if body.is_empty() {
serde_json::Value::Null
} else {
serde_json::from_slice(body).unwrap()
};
(headers, json)
}
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
// Sender feeding the currently open SSE stream. `Some(text)` is emitted as an
// SSE message event; `None` tells the stream handler to close the connection.
let active_sse = Arc::new(Mutex::new(None::<mpsc::UnboundedSender<Option<String>>>));
let get_count = Arc::new(AtomicUsize::new(0));
let tool_call_count = Arc::new(AtomicUsize::new(0));
let success_seen = Arc::new(AtomicBool::new(false));
let server_active_sse = Arc::clone(&active_sse);
let server_get_count = Arc::clone(&get_count);
let server_tool_call_count = Arc::clone(&tool_call_count);
let server_success_seen = Arc::clone(&success_seen);
// Accept loop: one task per connection, each handling a single request.
let server = tokio::spawn(async move {
loop {
let Ok((mut socket, _)) = listener.accept().await else {
break;
};
let active_sse = Arc::clone(&server_active_sse);
let get_count = Arc::clone(&server_get_count);
let tool_call_count = Arc::clone(&server_tool_call_count);
let success_seen = Arc::clone(&server_success_seen);
tokio::spawn(async move {
let (headers, request_json) = read_http_request(&mut socket).await;
// GET /sse: register a fresh channel as the active stream (replacing any
// previous one), announce the POST endpoint, then forward queued messages.
if headers.starts_with("GET /sse ") {
get_count.fetch_add(1, AtomicOrdering::SeqCst);
let (tx, mut rx) = mpsc::unbounded_channel::<Option<String>>();
*active_sse.lock().unwrap() = Some(tx);
socket
.write_all(
b"HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\n\r\n",
)
.await
.unwrap();
socket
.write_all(b"event: endpoint\ndata: /messages\n\n")
.await
.unwrap();
while let Some(message) = rx.recv().await {
// `None` sentinel: return, dropping the socket and ending the stream.
let Some(message) = message else {
return;
};
let event = format!("event: message\ndata: {message}\n\n");
socket.write_all(event.as_bytes()).await.unwrap();
}
return;
}
// Anything other than GET /sse or POST /messages is dropped without a reply.
if !headers.starts_with("POST /messages ") {
return;
}
// Acknowledge the POST right away; the JSON-RPC reply, if any, goes over the SSE stream.
socket
.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
.await
.unwrap();
let method = request_json
.get("method")
.and_then(serde_json::Value::as_str)
.unwrap_or("");
// The initialized notification gets no reply.
if method == "notifications/initialized" {
return;
}
let id = request_json
.get("id")
.cloned()
.unwrap_or_else(|| serde_json::json!("0"));
if method == "tools/call" {
let count = tool_call_count.fetch_add(1, AtomicOrdering::SeqCst);
// First tools/call only: take the active sender, signal the stream to
// close, and never answer this request.
if count == 0 {
if let Some(tx) = active_sse.lock().unwrap().take() {
let _ = tx.send(None);
}
return;
}
}
// Canned results per method; an unexpected method panics this handler task.
let result = match method {
"initialize" => serde_json::json!({
"protocolVersion": "2024-11-05",
"capabilities": {}
}),
"tools/list" => serde_json::json!({
"tools": [
{ "name": "search", "inputSchema": {} }
]
}),
"resources/list" => serde_json::json!({ "resources": [] }),
"resources/templates/list" => {
serde_json::json!({ "resourceTemplates": [] })
}
"prompts/list" => serde_json::json!({ "prompts": [] }),
"tools/call" => {
success_seen.store(true, AtomicOrdering::SeqCst);
serde_json::json!({ "content": [{ "type": "text", "text": "ok" }] })
}
other => panic!("unexpected method: {other}"),
};
let response = serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"result": result
})
.to_string();
// Queue the reply on whichever stream is currently registered; if none is
// registered the reply is silently dropped.
if let Some(tx) = active_sse.lock().unwrap().as_ref() {
let _ = tx.send(Some(response));
}
});
}
});
// Single server "dephy" configured with the "sse" transport.
let mut cfg = McpConfig::default();
cfg.servers.insert(
"dephy".to_string(),
McpServerConfig {
command: None,
args: Vec::new(),
env: HashMap::new(),
url: Some(format!("http://{addr}/sse")),
transport: Some("sse".to_string()),
connect_timeout: Some(2),
execute_timeout: Some(2),
read_timeout: None,
disabled: false,
enabled: true,
required: false,
enabled_tools: Vec::new(),
disabled_tools: Vec::new(),
headers: HashMap::new(),
},
);
let mut pool = McpPool::new(cfg);
// NOTE(review): "mcp_dephy_search" presumably resolves to tool "search" on
// server "dephy" — the name mapping is not visible here.
let result = pool
.call_tool("mcp_dephy_search", serde_json::json!({ "query": "dephy" }))
.await
.unwrap();
assert_eq!(
result,
serde_json::json!({ "content": [{ "type": "text", "text": "ok" }] })
);
// Two tools/call requests (the dropped one plus the retry) and two SSE connects.
assert_eq!(tool_call_count.load(AtomicOrdering::SeqCst), 2);
assert_eq!(get_count.load(AtomicOrdering::SeqCst), 2);
assert!(success_seen.load(AtomicOrdering::SeqCst));
server.abort();
}
#[test]
fn session_id_starts_none() {
let transport = StreamableHttpTransport::new(