Support Streamable HTTP MCP endpoints

Add direct POST-based MCP transport for Streamable HTTP servers while keeping
  the existing SSE endpoint-discovery path as a fallback. Parse both JSON and
  text/event-stream responses so servers like DeepWiki can be validated and used.
This commit is contained in:
reidliu41
2026-05-09 18:12:09 +08:00
parent cd27e6ceef
commit e30583ab45
+376 -9
View File
@@ -12,6 +12,8 @@ use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use anyhow::{Context, Result};
use reqwest::StatusCode;
use reqwest::header::{ACCEPT, CONTENT_TYPE};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
use tokio::process::{Child, ChildStdin, ChildStdout};
@@ -376,6 +378,30 @@ pub struct SseTransport {
pending_messages: VecDeque<serde_json::Value>,
}
struct HttpTransport {
mode: HttpTransportMode,
client: reqwest::Client,
base_url: String,
cancel_token: tokio_util::sync::CancellationToken,
endpoint_timeout: Duration,
}
enum HttpTransportMode {
Streamable(StreamableHttpTransport),
Sse(SseTransport),
}
struct StreamableHttpTransport {
client: reqwest::Client,
url: String,
pending_messages: VecDeque<serde_json::Value>,
}
enum StreamableSendError {
Incompatible(String),
Other(anyhow::Error),
}
impl SseTransport {
pub async fn connect(
client: reqwest::Client,
@@ -565,6 +591,205 @@ impl SseTransport {
}
}
impl HttpTransport {
fn new(
client: reqwest::Client,
url: String,
cancel_token: tokio_util::sync::CancellationToken,
endpoint_timeout: Duration,
) -> Self {
Self {
mode: HttpTransportMode::Streamable(StreamableHttpTransport::new(
client.clone(),
url.clone(),
)),
client,
base_url: url,
cancel_token,
endpoint_timeout,
}
}
async fn switch_to_sse_and_send(&mut self, msg: serde_json::Value) -> Result<()> {
let mut sse = SseTransport::connect(
self.client.clone(),
self.base_url.clone(),
self.cancel_token.clone(),
self.endpoint_timeout,
)
.await?;
sse.send(msg).await?;
self.mode = HttpTransportMode::Sse(sse);
Ok(())
}
}
#[async_trait::async_trait]
impl McpTransport for HttpTransport {
async fn send(&mut self, msg: serde_json::Value) -> Result<()> {
match &mut self.mode {
HttpTransportMode::Streamable(transport) => match transport.send(msg.clone()).await {
Ok(()) => Ok(()),
Err(StreamableSendError::Incompatible(detail)) => {
tracing::debug!(
"MCP Streamable HTTP unavailable; falling back to SSE endpoint discovery: {}",
detail
);
self.switch_to_sse_and_send(msg).await
}
Err(StreamableSendError::Other(err)) => Err(err),
},
HttpTransportMode::Sse(transport) => transport.send(msg).await,
}
}
async fn recv(&mut self) -> Result<serde_json::Value> {
match &mut self.mode {
HttpTransportMode::Streamable(transport) => transport.recv().await,
HttpTransportMode::Sse(transport) => transport.recv().await,
}
}
async fn shutdown(&mut self) {
if let HttpTransportMode::Sse(transport) = &mut self.mode {
transport.shutdown().await;
}
}
}
impl StreamableHttpTransport {
fn new(client: reqwest::Client, url: String) -> Self {
Self {
client,
url,
pending_messages: VecDeque::new(),
}
}
async fn send(
&mut self,
msg: serde_json::Value,
) -> std::result::Result<(), StreamableSendError> {
let response = self
.client
.post(&self.url)
.header(ACCEPT, "application/json, text/event-stream")
.json(&msg)
.send()
.await
.map_err(|err| StreamableSendError::Other(err.into()))?;
let status = response.status();
if status == StatusCode::ACCEPTED || status == StatusCode::NO_CONTENT {
return Ok(());
}
if !status.is_success() {
let body_excerpt = bounded_body_excerpt(response, ERROR_BODY_PREVIEW_BYTES).await;
if is_streamable_http_incompatible_status(status) {
return Err(StreamableSendError::Incompatible(format!(
"status={status} body={body_excerpt}"
)));
}
return Err(StreamableSendError::Other(anyhow::anyhow!(
"MCP Streamable HTTP rejected (transport=http url={} status={}): {}",
mask_url_secrets(&self.url),
status,
body_excerpt,
)));
}
let content_type = response
.headers()
.get(CONTENT_TYPE)
.and_then(|value| value.to_str().ok())
.map(str::to_string);
let body = response
.text()
.await
.map_err(|err| StreamableSendError::Other(err.into()))?;
self.store_response_body(content_type.as_deref(), &body)
.map_err(StreamableSendError::Other)
}
async fn recv(&mut self) -> Result<serde_json::Value> {
self.pending_messages
.pop_front()
.context("MCP Streamable HTTP response queue is empty")
}
fn store_response_body(&mut self, content_type: Option<&str>, body: &str) -> Result<()> {
if body.trim().is_empty() {
return Ok(());
}
let is_event_stream = content_type
.map(|value| value.to_ascii_lowercase().contains("text/event-stream"))
.unwrap_or(false)
|| body.trim_start().starts_with("event:")
|| body.trim_start().starts_with("data:");
if is_event_stream {
for msg in parse_sse_json_messages(body)? {
self.pending_messages.push_back(msg);
}
return Ok(());
}
self.pending_messages
.push_back(serde_json::from_str(body).context("Invalid MCP Streamable HTTP JSON")?);
Ok(())
}
}
fn is_streamable_http_incompatible_status(status: StatusCode) -> bool {
matches!(
status,
StatusCode::NOT_FOUND
| StatusCode::METHOD_NOT_ALLOWED
| StatusCode::NOT_ACCEPTABLE
| StatusCode::UNSUPPORTED_MEDIA_TYPE
| StatusCode::NOT_IMPLEMENTED
)
}
fn parse_sse_json_messages(body: &str) -> Result<Vec<serde_json::Value>> {
let normalized = body.replace("\r\n", "\n");
let mut messages = Vec::new();
for block in normalized.split("\n\n") {
let mut event_type = "message";
let mut data = String::new();
for line in block.lines() {
if let Some(value) = sse_field_value(line, "event:") {
event_type = value;
} else if let Some(value) = sse_field_value(line, "data:") {
if !data.is_empty() {
data.push('\n');
}
data.push_str(value);
}
}
if event_type != "message" || data.trim().is_empty() {
continue;
}
messages.push(
serde_json::from_str(data.trim())
.with_context(|| format!("Invalid MCP SSE message data: {}", data.trim()))?,
);
}
Ok(messages)
}
fn sse_field_value<'a>(line: &'a str, field: &str) -> Option<&'a str> {
let value = line.strip_prefix(field)?;
Some(value.strip_prefix(' ').unwrap_or(value))
}
#[async_trait::async_trait]
impl McpTransport for SseTransport {
async fn send(&mut self, msg: serde_json::Value) -> Result<()> {
@@ -650,15 +875,12 @@ impl McpConnection {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(connect_timeout_secs))
.build()?;
Box::new(
SseTransport::connect(
client,
url.clone(),
cancel_token.clone(),
Duration::from_secs(connect_timeout_secs),
)
.await?,
)
Box::new(HttpTransport::new(
client,
url.clone(),
cancel_token.clone(),
Duration::from_secs(connect_timeout_secs),
))
} else if let Some(command) = &config.command {
let mut cmd = tokio::process::Command::new(command);
cmd.args(&config.args)
@@ -2219,6 +2441,151 @@ mod tests {
);
}
#[test]
fn parse_sse_json_messages_extracts_message_events() {
let body = "event: message\r\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}\r\n\r\n";
let messages = parse_sse_json_messages(body).unwrap();
assert_eq!(messages.len(), 1);
assert_eq!(messages[0]["id"], 1);
assert!(messages[0].get("result").is_some());
}
#[tokio::test]
async fn mcp_connection_supports_streamable_http_event_stream_responses() {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
async fn read_http_request(socket: &mut TcpStream) -> String {
let mut request = Vec::new();
let mut buf = [0; 1024];
let header_end = loop {
let n = socket.read(&mut buf).await.unwrap();
assert!(n > 0, "client closed before headers completed");
request.extend_from_slice(&buf[..n]);
if let Some(pos) = request.windows(4).position(|window| window == b"\r\n\r\n") {
break pos + 4;
}
};
let headers = String::from_utf8_lossy(&request[..header_end]);
let content_length = headers
.lines()
.find_map(|line| {
let (name, value) = line.split_once(':')?;
name.eq_ignore_ascii_case("content-length")
.then(|| value.trim().parse::<usize>().ok())
.flatten()
})
.unwrap_or(0);
let total_len = header_end + content_length;
while request.len() < total_len {
let n = socket.read(&mut buf).await.unwrap();
assert!(n > 0, "client closed before body completed");
request.extend_from_slice(&buf[..n]);
}
String::from_utf8(request).unwrap()
}
async fn write_json_sse(socket: &mut TcpStream, response: serde_json::Value) {
let body = format!("event: message\ndata: {response}\n\n");
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nContent-Length: {}\r\n\r\n{}",
body.len(),
body
);
socket.write_all(response.as_bytes()).await.unwrap();
}
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let server = tokio::spawn(async move {
for _ in 0..6 {
let (mut socket, _) = listener.accept().await.unwrap();
tokio::spawn(async move {
let request = read_http_request(&mut socket).await;
assert!(request.starts_with("POST /mcp "));
assert!(
request.contains("Accept: application/json, text/event-stream")
|| request.contains("accept: application/json, text/event-stream")
);
let body = request.split("\r\n\r\n").nth(1).unwrap_or("");
let value: serde_json::Value = serde_json::from_str(body).unwrap();
let method = value["method"].as_str().unwrap();
if method == "notifications/initialized" {
socket
.write_all(b"HTTP/1.1 202 Accepted\r\nContent-Length: 0\r\n\r\n")
.await
.unwrap();
return;
}
let id = value["id"].clone();
let result = match method {
"initialize" => serde_json::json!({
"protocolVersion": "2024-11-05",
"serverInfo": {"name": "mock-streamable", "version": "1.0.0"},
"capabilities": {"tools": {}, "resources": {}, "prompts": {}}
}),
"tools/list" => serde_json::json!({
"tools": [{
"name": "read_wiki_structure",
"description": "Read wiki structure",
"inputSchema": {"type": "object"}
}]
}),
"resources/list" => serde_json::json!({"resources": []}),
"resources/templates/list" => {
serde_json::json!({"resourceTemplates": []})
}
"prompts/list" => serde_json::json!({"prompts": []}),
other => panic!("unexpected method: {other}"),
};
write_json_sse(
&mut socket,
serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"result": result
}),
)
.await;
});
}
});
let config = McpServerConfig {
command: None,
args: vec![],
env: HashMap::new(),
url: Some(format!("http://{addr}/mcp")),
connect_timeout: Some(2),
execute_timeout: None,
read_timeout: None,
disabled: false,
enabled: true,
required: false,
enabled_tools: Vec::new(),
disabled_tools: Vec::new(),
};
let conn = McpConnection::connect_with_policy(
"deepwiki".to_string(),
config,
&McpTimeouts::default(),
None,
)
.await
.unwrap();
assert_eq!(conn.state(), ConnectionState::Ready);
assert_eq!(conn.tools().len(), 1);
assert_eq!(conn.tools()[0].name, "read_wiki_structure");
server.abort();
}
#[test]
fn mask_url_secrets_strips_userinfo() {
let masked = mask_url_secrets("https://user:s3cret@host.example/api?foo=bar");