fix: parse CRLF SSE MCP events
Accept both LF and CRLF SSE event separators in the MCP SSE transport so uvicorn and FastMCP servers can publish endpoint events correctly. Add regression coverage for CRLF endpoint discovery.
This commit is contained in:
+115
-6
@@ -608,18 +608,21 @@ impl SseTransport {
|
||||
let s = String::from_utf8_lossy(&chunk);
|
||||
buffer.push_str(&s);
|
||||
|
||||
while let Some(pos) = buffer.find("\n\n") {
|
||||
while let Some((pos, separator_len)) = find_sse_event_separator(&buffer) {
|
||||
let event_block = buffer[..pos].to_string();
|
||||
buffer = buffer[pos + 2..].to_string();
|
||||
buffer = buffer[pos + separator_len..].to_string();
|
||||
|
||||
let mut event_type = "message";
|
||||
let mut data = String::new();
|
||||
|
||||
for line in event_block.lines() {
|
||||
if let Some(stripped) = line.strip_prefix("event: ") {
|
||||
event_type = stripped;
|
||||
} else if let Some(stripped) = line.strip_prefix("data: ") {
|
||||
data.push_str(stripped);
|
||||
if let Some(value) = sse_field_value(line, "event:") {
|
||||
event_type = value;
|
||||
} else if let Some(value) = sse_field_value(line, "data:") {
|
||||
if !data.is_empty() {
|
||||
data.push('\n');
|
||||
}
|
||||
data.push_str(value);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -874,6 +877,15 @@ fn parse_sse_message_data(body: &str) -> Vec<Vec<u8>> {
|
||||
messages
|
||||
}
|
||||
|
||||
fn find_sse_event_separator(buffer: &str) -> Option<(usize, usize)> {
|
||||
match (buffer.find("\n\n"), buffer.find("\r\n\r\n")) {
|
||||
(Some(lf), Some(crlf)) if crlf < lf => Some((crlf, 4)),
|
||||
(Some(lf), _) => Some((lf, 2)),
|
||||
(_, Some(crlf)) => Some((crlf, 4)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn sse_field_value<'a>(line: &'a str, field: &str) -> Option<&'a str> {
|
||||
let value = line.strip_prefix(field)?;
|
||||
Some(value.strip_prefix(' ').unwrap_or(value))
|
||||
@@ -3014,6 +3026,18 @@ mod tests {
|
||||
assert!(value.get("result").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_sse_event_separator_accepts_lf_and_crlf() {
|
||||
assert_eq!(
|
||||
find_sse_event_separator("event: endpoint\n\n"),
|
||||
Some((15, 2))
|
||||
);
|
||||
assert_eq!(
|
||||
find_sse_event_separator("event: endpoint\r\n\r\n"),
|
||||
Some((15, 4))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mcp_connection_supports_streamable_http_event_stream_responses() {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
@@ -3411,4 +3435,89 @@ mod tests {
|
||||
cancel_token.cancel();
|
||||
server.abort();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sse_connect_accepts_crlf_endpoint_events() {
|
||||
use std::sync::{
|
||||
Arc,
|
||||
atomic::{AtomicBool, Ordering as AtomicOrdering},
|
||||
};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let post_seen = Arc::new(AtomicBool::new(false));
|
||||
let server_post_seen = Arc::clone(&post_seen);
|
||||
let cancel_token = tokio_util::sync::CancellationToken::new();
|
||||
let server_cancel = cancel_token.clone();
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
loop {
|
||||
let Ok((mut socket, _)) = listener.accept().await else {
|
||||
break;
|
||||
};
|
||||
let post_seen = Arc::clone(&server_post_seen);
|
||||
let server_cancel = server_cancel.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut request = Vec::new();
|
||||
let mut buf = [0; 1024];
|
||||
loop {
|
||||
let n = socket.read(&mut buf).await.unwrap();
|
||||
if n == 0 {
|
||||
return;
|
||||
}
|
||||
request.extend_from_slice(&buf[..n]);
|
||||
if request.windows(4).any(|window| window == b"\r\n\r\n") {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let request = String::from_utf8_lossy(&request);
|
||||
if request.starts_with("GET /sse ") {
|
||||
socket
|
||||
.write_all(
|
||||
b"HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\n\r\n",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
socket
|
||||
.write_all(b"event: endpoint\r\ndata: /messages\r\n\r\n")
|
||||
.await
|
||||
.unwrap();
|
||||
server_cancel.cancelled().await;
|
||||
} else if request.starts_with("POST /messages ") {
|
||||
post_seen.store(true, AtomicOrdering::SeqCst);
|
||||
socket
|
||||
.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!("http://{addr}/sse");
|
||||
let mut transport =
|
||||
SseTransport::connect(client, url, cancel_token.clone(), Duration::from_secs(2))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.send(json_frame(serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize"
|
||||
})))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
post_seen.load(AtomicOrdering::SeqCst),
|
||||
"first SSE send should POST to the CRLF-discovered endpoint"
|
||||
);
|
||||
|
||||
cancel_token.cancel();
|
||||
server.abort();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user