diff --git a/crates/tui/src/mcp.rs b/crates/tui/src/mcp.rs index b9a595e4..b9356c60 100644 --- a/crates/tui/src/mcp.rs +++ b/crates/tui/src/mcp.rs @@ -5,7 +5,7 @@ //! - Automatic tool discovery via `tools/list` //! - Configurable timeouts per-server and globally -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use std::fs; use std::path::Path; use std::sync::atomic::{AtomicU64, Ordering}; @@ -359,6 +359,7 @@ pub struct SseTransport { base_url: String, endpoint_url: Option, receiver: tokio::sync::mpsc::UnboundedReceiver, + pending_messages: VecDeque, } impl SseTransport { @@ -366,10 +367,12 @@ impl SseTransport { client: reqwest::Client, url: String, cancel_token: tokio_util::sync::CancellationToken, + endpoint_timeout: Duration, ) -> Result { let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); let client_clone = client.clone(); let url_clone = url.clone(); + let wait_cancel_token = cancel_token.clone(); tokio::spawn(async move { if cancel_token.is_cancelled() { @@ -402,12 +405,17 @@ impl SseTransport { } }); - Ok(Self { + let mut transport = Self { client, base_url: url, endpoint_url: None, receiver: rx, - }) + pending_messages: VecDeque::new(), + }; + transport + .wait_for_endpoint(&wait_cancel_token, endpoint_timeout) + .await?; + Ok(transport) } async fn run_sse_loop( @@ -491,6 +499,56 @@ impl SseTransport { } Ok(()) } + + async fn wait_for_endpoint( + &mut self, + cancel_token: &tokio_util::sync::CancellationToken, + endpoint_timeout: Duration, + ) -> Result<()> { + let timeout = tokio::time::sleep(endpoint_timeout); + tokio::pin!(timeout); + + loop { + let msg = tokio::select! { + _ = cancel_token.cancelled() => { + anyhow::bail!("SSE transport cancelled before endpoint was discovered"); + } + _ = &mut timeout => { + anyhow::bail!( + "SSE endpoint not received within {}ms", + endpoint_timeout.as_millis() + ); + } + msg = self.receiver.recv() => { + msg.context("SSE transport closed before endpoint was discovered")? + } + }; + + if self.store_endpoint_from_internal_message(&msg)? { + return Ok(()); + } + + self.pending_messages.push_back(msg); + } + } + + fn store_endpoint_from_internal_message(&mut self, msg: &serde_json::Value) -> Result { + let Some(endpoint) = msg.get("__internal_sse_endpoint__") else { + return Ok(false); + }; + let url_str = endpoint.as_str().context("Invalid endpoint format")?; + self.endpoint_url = Some(Self::resolve_endpoint_url(&self.base_url, url_str)?); + Ok(true) + } + + fn resolve_endpoint_url(base_url: &str, endpoint_url: &str) -> Result { + if endpoint_url.starts_with("http://") || endpoint_url.starts_with("https://") { + return Ok(endpoint_url.to_string()); + } + let base = reqwest::Url::parse(base_url)?; + let joined = base.join(endpoint_url)?; + Ok(joined.to_string()) + } } #[async_trait::async_trait] @@ -509,17 +567,12 @@ impl McpTransport for SseTransport { async fn recv(&mut self) -> Result { loop { - let msg = self.receiver.recv().await.context("SSE transport closed")?; - if let Some(endpoint) = msg.get("__internal_sse_endpoint__") { - let url_str = endpoint.as_str().context("Invalid endpoint format")?; - // Handle relative vs absolute URLs - if url_str.starts_with("http") { - self.endpoint_url = Some(url_str.to_string()); - } else { - let base = reqwest::Url::parse(&self.base_url)?; - let joined = base.join(url_str)?; - self.endpoint_url = Some(joined.to_string()); - } + let msg = if let Some(msg) = self.pending_messages.pop_front() { + msg + } else { + self.receiver.recv().await.context("SSE transport closed")? + }; + if self.store_endpoint_from_internal_message(&msg)? { continue; } return Ok(msg); @@ -583,7 +636,15 @@ impl McpConnection { let client = reqwest::Client::builder() .timeout(Duration::from_secs(connect_timeout_secs)) .build()?; - Box::new(SseTransport::connect(client, url.clone(), cancel_token.clone()).await?) + Box::new( + SseTransport::connect( + client, + url.clone(), + cancel_token.clone(), + Duration::from_secs(connect_timeout_secs), + ) + .await?, + ) } else if let Some(command) = &config.command { let mut cmd = tokio::process::Command::new(command); cmd.args(&config.args) @@ -2102,4 +2163,90 @@ mod tests { "child {pid} survived StdioTransport::shutdown — SIGTERM not delivered" ); } + + #[tokio::test] + async fn sse_connect_waits_for_endpoint_before_first_send() { + use std::sync::{ + Arc, + atomic::{AtomicBool, Ordering as AtomicOrdering}, + }; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpListener; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let post_seen = Arc::new(AtomicBool::new(false)); + let server_post_seen = Arc::clone(&post_seen); + let cancel_token = tokio_util::sync::CancellationToken::new(); + let server_cancel = cancel_token.clone(); + + let server = tokio::spawn(async move { + loop { + let Ok((mut socket, _)) = listener.accept().await else { + break; + }; + let post_seen = Arc::clone(&server_post_seen); + let server_cancel = server_cancel.clone(); + tokio::spawn(async move { + let mut request = Vec::new(); + let mut buf = [0; 1024]; + loop { + let n = socket.read(&mut buf).await.unwrap(); + if n == 0 { + return; + } + request.extend_from_slice(&buf[..n]); + if request.windows(4).any(|window| window == b"\r\n\r\n") { + break; + } + } + let request = String::from_utf8_lossy(&request); + if request.starts_with("GET /sse ") { + socket + .write_all( + b"HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\n\r\n", + ) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(150)).await; + socket + .write_all(b"event: endpoint\ndata: /messages\n\n") + .await + .unwrap(); + server_cancel.cancelled().await; + } else if request.starts_with("POST /messages ") { + post_seen.store(true, AtomicOrdering::SeqCst); + socket + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n") + .await + .unwrap(); + } + }); + } + }); + + let client = reqwest::Client::new(); + let url = format!("http://{addr}/sse"); + let mut transport = + SseTransport::connect(client, url, cancel_token.clone(), Duration::from_secs(2)) + .await + .unwrap(); + + transport + .send(serde_json::json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize" + })) + .await + .unwrap(); + + assert!( + post_seen.load(AtomicOrdering::SeqCst), + "first SSE send should POST to the discovered endpoint" + ); + + cancel_token.cancel(); + server.abort(); + } }