fix(mcp): wait for SSE endpoint before connect returns (#1225)

This commit is contained in:
Hunter Bown
2026-05-08 11:00:01 -05:00
committed by GitHub
parent 360438f0c9
commit f29d1a3a21
+162 -15
View File
@@ -5,7 +5,7 @@
//! - Automatic tool discovery via `tools/list`
//! - Configurable timeouts per-server and globally
use std::collections::HashMap;
use std::collections::{HashMap, VecDeque};
use std::fs;
use std::path::Path;
use std::sync::atomic::{AtomicU64, Ordering};
@@ -359,6 +359,7 @@ pub struct SseTransport {
base_url: String,
endpoint_url: Option<String>,
receiver: tokio::sync::mpsc::UnboundedReceiver<serde_json::Value>,
pending_messages: VecDeque<serde_json::Value>,
}
impl SseTransport {
@@ -366,10 +367,12 @@ impl SseTransport {
client: reqwest::Client,
url: String,
cancel_token: tokio_util::sync::CancellationToken,
endpoint_timeout: Duration,
) -> Result<Self> {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let client_clone = client.clone();
let url_clone = url.clone();
let wait_cancel_token = cancel_token.clone();
tokio::spawn(async move {
if cancel_token.is_cancelled() {
@@ -402,12 +405,17 @@ impl SseTransport {
}
});
Ok(Self {
let mut transport = Self {
client,
base_url: url,
endpoint_url: None,
receiver: rx,
})
pending_messages: VecDeque::new(),
};
transport
.wait_for_endpoint(&wait_cancel_token, endpoint_timeout)
.await?;
Ok(transport)
}
async fn run_sse_loop(
@@ -491,6 +499,56 @@ impl SseTransport {
}
Ok(())
}
async fn wait_for_endpoint(
&mut self,
cancel_token: &tokio_util::sync::CancellationToken,
endpoint_timeout: Duration,
) -> Result<()> {
let timeout = tokio::time::sleep(endpoint_timeout);
tokio::pin!(timeout);
loop {
let msg = tokio::select! {
_ = cancel_token.cancelled() => {
anyhow::bail!("SSE transport cancelled before endpoint was discovered");
}
_ = &mut timeout => {
anyhow::bail!(
"SSE endpoint not received within {}ms",
endpoint_timeout.as_millis()
);
}
msg = self.receiver.recv() => {
msg.context("SSE transport closed before endpoint was discovered")?
}
};
if self.store_endpoint_from_internal_message(&msg)? {
return Ok(());
}
self.pending_messages.push_back(msg);
}
}
fn store_endpoint_from_internal_message(&mut self, msg: &serde_json::Value) -> Result<bool> {
let Some(endpoint) = msg.get("__internal_sse_endpoint__") else {
return Ok(false);
};
let url_str = endpoint.as_str().context("Invalid endpoint format")?;
self.endpoint_url = Some(Self::resolve_endpoint_url(&self.base_url, url_str)?);
Ok(true)
}
fn resolve_endpoint_url(base_url: &str, endpoint_url: &str) -> Result<String> {
if endpoint_url.starts_with("http://") || endpoint_url.starts_with("https://") {
return Ok(endpoint_url.to_string());
}
let base = reqwest::Url::parse(base_url)?;
let joined = base.join(endpoint_url)?;
Ok(joined.to_string())
}
}
#[async_trait::async_trait]
@@ -509,17 +567,12 @@ impl McpTransport for SseTransport {
async fn recv(&mut self) -> Result<serde_json::Value> {
loop {
let msg = self.receiver.recv().await.context("SSE transport closed")?;
if let Some(endpoint) = msg.get("__internal_sse_endpoint__") {
let url_str = endpoint.as_str().context("Invalid endpoint format")?;
// Handle relative vs absolute URLs
if url_str.starts_with("http") {
self.endpoint_url = Some(url_str.to_string());
} else {
let base = reqwest::Url::parse(&self.base_url)?;
let joined = base.join(url_str)?;
self.endpoint_url = Some(joined.to_string());
}
let msg = if let Some(msg) = self.pending_messages.pop_front() {
msg
} else {
self.receiver.recv().await.context("SSE transport closed")?
};
if self.store_endpoint_from_internal_message(&msg)? {
continue;
}
return Ok(msg);
@@ -583,7 +636,15 @@ impl McpConnection {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(connect_timeout_secs))
.build()?;
Box::new(SseTransport::connect(client, url.clone(), cancel_token.clone()).await?)
Box::new(
SseTransport::connect(
client,
url.clone(),
cancel_token.clone(),
Duration::from_secs(connect_timeout_secs),
)
.await?,
)
} else if let Some(command) = &config.command {
let mut cmd = tokio::process::Command::new(command);
cmd.args(&config.args)
@@ -2102,4 +2163,90 @@ mod tests {
"child {pid} survived StdioTransport::shutdown — SIGTERM not delivered"
);
}
#[tokio::test]
async fn sse_connect_waits_for_endpoint_before_first_send() {
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering as AtomicOrdering},
};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let post_seen = Arc::new(AtomicBool::new(false));
let server_post_seen = Arc::clone(&post_seen);
let cancel_token = tokio_util::sync::CancellationToken::new();
let server_cancel = cancel_token.clone();
let server = tokio::spawn(async move {
loop {
let Ok((mut socket, _)) = listener.accept().await else {
break;
};
let post_seen = Arc::clone(&server_post_seen);
let server_cancel = server_cancel.clone();
tokio::spawn(async move {
let mut request = Vec::new();
let mut buf = [0; 1024];
loop {
let n = socket.read(&mut buf).await.unwrap();
if n == 0 {
return;
}
request.extend_from_slice(&buf[..n]);
if request.windows(4).any(|window| window == b"\r\n\r\n") {
break;
}
}
let request = String::from_utf8_lossy(&request);
if request.starts_with("GET /sse ") {
socket
.write_all(
b"HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\n\r\n",
)
.await
.unwrap();
tokio::time::sleep(Duration::from_millis(150)).await;
socket
.write_all(b"event: endpoint\ndata: /messages\n\n")
.await
.unwrap();
server_cancel.cancelled().await;
} else if request.starts_with("POST /messages ") {
post_seen.store(true, AtomicOrdering::SeqCst);
socket
.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
.await
.unwrap();
}
});
}
});
let client = reqwest::Client::new();
let url = format!("http://{addr}/sse");
let mut transport =
SseTransport::connect(client, url, cancel_token.clone(), Duration::from_secs(2))
.await
.unwrap();
transport
.send(serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize"
}))
.await
.unwrap();
assert!(
post_seen.load(AtomicOrdering::SeqCst),
"first SSE send should POST to the discovered endpoint"
);
cancel_token.cancel();
server.abort();
}
}