fix(mcp): wait for SSE endpoint before connect returns (#1225)
This commit is contained in:
+162
-15
@@ -5,7 +5,7 @@
|
||||
//! - Automatic tool discovery via `tools/list`
|
||||
//! - Configurable timeouts per-server and globally
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
@@ -359,6 +359,7 @@ pub struct SseTransport {
|
||||
base_url: String,
|
||||
endpoint_url: Option<String>,
|
||||
receiver: tokio::sync::mpsc::UnboundedReceiver<serde_json::Value>,
|
||||
pending_messages: VecDeque<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl SseTransport {
|
||||
@@ -366,10 +367,12 @@ impl SseTransport {
|
||||
client: reqwest::Client,
|
||||
url: String,
|
||||
cancel_token: tokio_util::sync::CancellationToken,
|
||||
endpoint_timeout: Duration,
|
||||
) -> Result<Self> {
|
||||
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
let client_clone = client.clone();
|
||||
let url_clone = url.clone();
|
||||
let wait_cancel_token = cancel_token.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if cancel_token.is_cancelled() {
|
||||
@@ -402,12 +405,17 @@ impl SseTransport {
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
let mut transport = Self {
|
||||
client,
|
||||
base_url: url,
|
||||
endpoint_url: None,
|
||||
receiver: rx,
|
||||
})
|
||||
pending_messages: VecDeque::new(),
|
||||
};
|
||||
transport
|
||||
.wait_for_endpoint(&wait_cancel_token, endpoint_timeout)
|
||||
.await?;
|
||||
Ok(transport)
|
||||
}
|
||||
|
||||
async fn run_sse_loop(
|
||||
@@ -491,6 +499,56 @@ impl SseTransport {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn wait_for_endpoint(
|
||||
&mut self,
|
||||
cancel_token: &tokio_util::sync::CancellationToken,
|
||||
endpoint_timeout: Duration,
|
||||
) -> Result<()> {
|
||||
let timeout = tokio::time::sleep(endpoint_timeout);
|
||||
tokio::pin!(timeout);
|
||||
|
||||
loop {
|
||||
let msg = tokio::select! {
|
||||
_ = cancel_token.cancelled() => {
|
||||
anyhow::bail!("SSE transport cancelled before endpoint was discovered");
|
||||
}
|
||||
_ = &mut timeout => {
|
||||
anyhow::bail!(
|
||||
"SSE endpoint not received within {}ms",
|
||||
endpoint_timeout.as_millis()
|
||||
);
|
||||
}
|
||||
msg = self.receiver.recv() => {
|
||||
msg.context("SSE transport closed before endpoint was discovered")?
|
||||
}
|
||||
};
|
||||
|
||||
if self.store_endpoint_from_internal_message(&msg)? {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.pending_messages.push_back(msg);
|
||||
}
|
||||
}
|
||||
|
||||
fn store_endpoint_from_internal_message(&mut self, msg: &serde_json::Value) -> Result<bool> {
|
||||
let Some(endpoint) = msg.get("__internal_sse_endpoint__") else {
|
||||
return Ok(false);
|
||||
};
|
||||
let url_str = endpoint.as_str().context("Invalid endpoint format")?;
|
||||
self.endpoint_url = Some(Self::resolve_endpoint_url(&self.base_url, url_str)?);
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
fn resolve_endpoint_url(base_url: &str, endpoint_url: &str) -> Result<String> {
|
||||
if endpoint_url.starts_with("http://") || endpoint_url.starts_with("https://") {
|
||||
return Ok(endpoint_url.to_string());
|
||||
}
|
||||
let base = reqwest::Url::parse(base_url)?;
|
||||
let joined = base.join(endpoint_url)?;
|
||||
Ok(joined.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@@ -509,17 +567,12 @@ impl McpTransport for SseTransport {
|
||||
|
||||
async fn recv(&mut self) -> Result<serde_json::Value> {
|
||||
loop {
|
||||
let msg = self.receiver.recv().await.context("SSE transport closed")?;
|
||||
if let Some(endpoint) = msg.get("__internal_sse_endpoint__") {
|
||||
let url_str = endpoint.as_str().context("Invalid endpoint format")?;
|
||||
// Handle relative vs absolute URLs
|
||||
if url_str.starts_with("http") {
|
||||
self.endpoint_url = Some(url_str.to_string());
|
||||
} else {
|
||||
let base = reqwest::Url::parse(&self.base_url)?;
|
||||
let joined = base.join(url_str)?;
|
||||
self.endpoint_url = Some(joined.to_string());
|
||||
}
|
||||
let msg = if let Some(msg) = self.pending_messages.pop_front() {
|
||||
msg
|
||||
} else {
|
||||
self.receiver.recv().await.context("SSE transport closed")?
|
||||
};
|
||||
if self.store_endpoint_from_internal_message(&msg)? {
|
||||
continue;
|
||||
}
|
||||
return Ok(msg);
|
||||
@@ -583,7 +636,15 @@ impl McpConnection {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(connect_timeout_secs))
|
||||
.build()?;
|
||||
Box::new(SseTransport::connect(client, url.clone(), cancel_token.clone()).await?)
|
||||
Box::new(
|
||||
SseTransport::connect(
|
||||
client,
|
||||
url.clone(),
|
||||
cancel_token.clone(),
|
||||
Duration::from_secs(connect_timeout_secs),
|
||||
)
|
||||
.await?,
|
||||
)
|
||||
} else if let Some(command) = &config.command {
|
||||
let mut cmd = tokio::process::Command::new(command);
|
||||
cmd.args(&config.args)
|
||||
@@ -2102,4 +2163,90 @@ mod tests {
|
||||
"child {pid} survived StdioTransport::shutdown — SIGTERM not delivered"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sse_connect_waits_for_endpoint_before_first_send() {
|
||||
use std::sync::{
|
||||
Arc,
|
||||
atomic::{AtomicBool, Ordering as AtomicOrdering},
|
||||
};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let post_seen = Arc::new(AtomicBool::new(false));
|
||||
let server_post_seen = Arc::clone(&post_seen);
|
||||
let cancel_token = tokio_util::sync::CancellationToken::new();
|
||||
let server_cancel = cancel_token.clone();
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
loop {
|
||||
let Ok((mut socket, _)) = listener.accept().await else {
|
||||
break;
|
||||
};
|
||||
let post_seen = Arc::clone(&server_post_seen);
|
||||
let server_cancel = server_cancel.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut request = Vec::new();
|
||||
let mut buf = [0; 1024];
|
||||
loop {
|
||||
let n = socket.read(&mut buf).await.unwrap();
|
||||
if n == 0 {
|
||||
return;
|
||||
}
|
||||
request.extend_from_slice(&buf[..n]);
|
||||
if request.windows(4).any(|window| window == b"\r\n\r\n") {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let request = String::from_utf8_lossy(&request);
|
||||
if request.starts_with("GET /sse ") {
|
||||
socket
|
||||
.write_all(
|
||||
b"HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\n\r\n",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
tokio::time::sleep(Duration::from_millis(150)).await;
|
||||
socket
|
||||
.write_all(b"event: endpoint\ndata: /messages\n\n")
|
||||
.await
|
||||
.unwrap();
|
||||
server_cancel.cancelled().await;
|
||||
} else if request.starts_with("POST /messages ") {
|
||||
post_seen.store(true, AtomicOrdering::SeqCst);
|
||||
socket
|
||||
.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!("http://{addr}/sse");
|
||||
let mut transport =
|
||||
SseTransport::connect(client, url, cancel_token.clone(), Duration::from_secs(2))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.send(serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
post_seen.load(AtomicOrdering::SeqCst),
|
||||
"first SSE send should POST to the discovered endpoint"
|
||||
);
|
||||
|
||||
cancel_token.cancel();
|
||||
server.abort();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user