diff --git a/crates/tui/src/mcp.rs b/crates/tui/src/mcp.rs index 9e98e281..91c6f2b9 100644 --- a/crates/tui/src/mcp.rs +++ b/crates/tui/src/mcp.rs @@ -608,18 +608,21 @@ impl SseTransport { let s = String::from_utf8_lossy(&chunk); buffer.push_str(&s); - while let Some(pos) = buffer.find("\n\n") { + while let Some((pos, separator_len)) = find_sse_event_separator(&buffer) { let event_block = buffer[..pos].to_string(); - buffer = buffer[pos + 2..].to_string(); + buffer = buffer[pos + separator_len..].to_string(); let mut event_type = "message"; let mut data = String::new(); for line in event_block.lines() { - if let Some(stripped) = line.strip_prefix("event: ") { - event_type = stripped; - } else if let Some(stripped) = line.strip_prefix("data: ") { - data.push_str(stripped); + if let Some(value) = sse_field_value(line, "event:") { + event_type = value; + } else if let Some(value) = sse_field_value(line, "data:") { + if !data.is_empty() { + data.push('\n'); + } + data.push_str(value); } } @@ -874,6 +877,15 @@ fn parse_sse_message_data(body: &str) -> Vec> { messages } +fn find_sse_event_separator(buffer: &str) -> Option<(usize, usize)> { + match (buffer.find("\n\n"), buffer.find("\r\n\r\n")) { + (Some(lf), Some(crlf)) if crlf < lf => Some((crlf, 4)), + (Some(lf), _) => Some((lf, 2)), + (_, Some(crlf)) => Some((crlf, 4)), + _ => None, + } +} + fn sse_field_value<'a>(line: &'a str, field: &str) -> Option<&'a str> { let value = line.strip_prefix(field)?; Some(value.strip_prefix(' ').unwrap_or(value)) @@ -3014,6 +3026,18 @@ mod tests { assert!(value.get("result").is_some()); } + #[test] + fn find_sse_event_separator_accepts_lf_and_crlf() { + assert_eq!( + find_sse_event_separator("event: endpoint\n\n"), + Some((15, 2)) + ); + assert_eq!( + find_sse_event_separator("event: endpoint\r\n\r\n"), + Some((15, 4)) + ); + } + #[tokio::test] async fn mcp_connection_supports_streamable_http_event_stream_responses() { use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -3411,4 +3435,89 @@ mod tests { cancel_token.cancel(); server.abort(); } + + #[tokio::test] + async fn sse_connect_accepts_crlf_endpoint_events() { + use std::sync::{ + Arc, + atomic::{AtomicBool, Ordering as AtomicOrdering}, + }; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpListener; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let post_seen = Arc::new(AtomicBool::new(false)); + let server_post_seen = Arc::clone(&post_seen); + let cancel_token = tokio_util::sync::CancellationToken::new(); + let server_cancel = cancel_token.clone(); + + let server = tokio::spawn(async move { + loop { + let Ok((mut socket, _)) = listener.accept().await else { + break; + }; + let post_seen = Arc::clone(&server_post_seen); + let server_cancel = server_cancel.clone(); + tokio::spawn(async move { + let mut request = Vec::new(); + let mut buf = [0; 1024]; + loop { + let n = socket.read(&mut buf).await.unwrap(); + if n == 0 { + return; + } + request.extend_from_slice(&buf[..n]); + if request.windows(4).any(|window| window == b"\r\n\r\n") { + break; + } + } + let request = String::from_utf8_lossy(&request); + if request.starts_with("GET /sse ") { + socket + .write_all( + b"HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\n\r\n", + ) + .await + .unwrap(); + socket + .write_all(b"event: endpoint\r\ndata: /messages\r\n\r\n") + .await + .unwrap(); + server_cancel.cancelled().await; + } else if request.starts_with("POST /messages ") { + post_seen.store(true, AtomicOrdering::SeqCst); + socket + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n") + .await + .unwrap(); + } + }); + } + }); + + let client = reqwest::Client::new(); + let url = format!("http://{addr}/sse"); + let mut transport = + SseTransport::connect(client, url, cancel_token.clone(), Duration::from_secs(2)) + .await + .unwrap(); + + transport + .send(json_frame(serde_json::json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize" + }))) + .await + .unwrap(); + + assert!( + post_seen.load(AtomicOrdering::SeqCst), + "first SSE send should POST to the CRLF-discovered endpoint" + ); + + cancel_token.cancel(); + server.abort(); + } }