feat(mcp): support Streamable HTTP MCP endpoints with SSE fallback (#1300)
Closes #1266 (DeepWiki at https://mcp.deepwiki.com/mcp). URL-based MCP servers now try the modern Streamable HTTP transport first — POST JSON-RPC to the base URL with `Accept: application/json, text/event-stream`, accept either JSON or SSE response — and fall back to the older SSE endpoint-discovery flow on incompatible status codes (404/405/406/415/501). Existing SSE servers keep working via the fallback. Single-file change in `crates/tui/src/mcp.rs` with a tokio-based end-to-end test that exercises the full handshake. Thanks @reidliu41.
This commit is contained in:
+376
-9
@@ -12,6 +12,8 @@ use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use reqwest::StatusCode;
|
||||
use reqwest::header::{ACCEPT, CONTENT_TYPE};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
|
||||
use tokio::process::{Child, ChildStdin, ChildStdout};
|
||||
@@ -376,6 +378,30 @@ pub struct SseTransport {
|
||||
pending_messages: VecDeque<serde_json::Value>,
|
||||
}
|
||||
|
||||
struct HttpTransport {
|
||||
mode: HttpTransportMode,
|
||||
client: reqwest::Client,
|
||||
base_url: String,
|
||||
cancel_token: tokio_util::sync::CancellationToken,
|
||||
endpoint_timeout: Duration,
|
||||
}
|
||||
|
||||
enum HttpTransportMode {
|
||||
Streamable(StreamableHttpTransport),
|
||||
Sse(SseTransport),
|
||||
}
|
||||
|
||||
struct StreamableHttpTransport {
|
||||
client: reqwest::Client,
|
||||
url: String,
|
||||
pending_messages: VecDeque<serde_json::Value>,
|
||||
}
|
||||
|
||||
enum StreamableSendError {
|
||||
Incompatible(String),
|
||||
Other(anyhow::Error),
|
||||
}
|
||||
|
||||
impl SseTransport {
|
||||
pub async fn connect(
|
||||
client: reqwest::Client,
|
||||
@@ -565,6 +591,205 @@ impl SseTransport {
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpTransport {
|
||||
fn new(
|
||||
client: reqwest::Client,
|
||||
url: String,
|
||||
cancel_token: tokio_util::sync::CancellationToken,
|
||||
endpoint_timeout: Duration,
|
||||
) -> Self {
|
||||
Self {
|
||||
mode: HttpTransportMode::Streamable(StreamableHttpTransport::new(
|
||||
client.clone(),
|
||||
url.clone(),
|
||||
)),
|
||||
client,
|
||||
base_url: url,
|
||||
cancel_token,
|
||||
endpoint_timeout,
|
||||
}
|
||||
}
|
||||
|
||||
async fn switch_to_sse_and_send(&mut self, msg: serde_json::Value) -> Result<()> {
|
||||
let mut sse = SseTransport::connect(
|
||||
self.client.clone(),
|
||||
self.base_url.clone(),
|
||||
self.cancel_token.clone(),
|
||||
self.endpoint_timeout,
|
||||
)
|
||||
.await?;
|
||||
sse.send(msg).await?;
|
||||
self.mode = HttpTransportMode::Sse(sse);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl McpTransport for HttpTransport {
|
||||
async fn send(&mut self, msg: serde_json::Value) -> Result<()> {
|
||||
match &mut self.mode {
|
||||
HttpTransportMode::Streamable(transport) => match transport.send(msg.clone()).await {
|
||||
Ok(()) => Ok(()),
|
||||
Err(StreamableSendError::Incompatible(detail)) => {
|
||||
tracing::debug!(
|
||||
"MCP Streamable HTTP unavailable; falling back to SSE endpoint discovery: {}",
|
||||
detail
|
||||
);
|
||||
self.switch_to_sse_and_send(msg).await
|
||||
}
|
||||
Err(StreamableSendError::Other(err)) => Err(err),
|
||||
},
|
||||
HttpTransportMode::Sse(transport) => transport.send(msg).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn recv(&mut self) -> Result<serde_json::Value> {
|
||||
match &mut self.mode {
|
||||
HttpTransportMode::Streamable(transport) => transport.recv().await,
|
||||
HttpTransportMode::Sse(transport) => transport.recv().await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn shutdown(&mut self) {
|
||||
if let HttpTransportMode::Sse(transport) = &mut self.mode {
|
||||
transport.shutdown().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamableHttpTransport {
|
||||
fn new(client: reqwest::Client, url: String) -> Self {
|
||||
Self {
|
||||
client,
|
||||
url,
|
||||
pending_messages: VecDeque::new(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&mut self,
|
||||
msg: serde_json::Value,
|
||||
) -> std::result::Result<(), StreamableSendError> {
|
||||
let response = self
|
||||
.client
|
||||
.post(&self.url)
|
||||
.header(ACCEPT, "application/json, text/event-stream")
|
||||
.json(&msg)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| StreamableSendError::Other(err.into()))?;
|
||||
|
||||
let status = response.status();
|
||||
if status == StatusCode::ACCEPTED || status == StatusCode::NO_CONTENT {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if !status.is_success() {
|
||||
let body_excerpt = bounded_body_excerpt(response, ERROR_BODY_PREVIEW_BYTES).await;
|
||||
if is_streamable_http_incompatible_status(status) {
|
||||
return Err(StreamableSendError::Incompatible(format!(
|
||||
"status={status} body={body_excerpt}"
|
||||
)));
|
||||
}
|
||||
return Err(StreamableSendError::Other(anyhow::anyhow!(
|
||||
"MCP Streamable HTTP rejected (transport=http url={} status={}): {}",
|
||||
mask_url_secrets(&self.url),
|
||||
status,
|
||||
body_excerpt,
|
||||
)));
|
||||
}
|
||||
|
||||
let content_type = response
|
||||
.headers()
|
||||
.get(CONTENT_TYPE)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(str::to_string);
|
||||
let body = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(|err| StreamableSendError::Other(err.into()))?;
|
||||
self.store_response_body(content_type.as_deref(), &body)
|
||||
.map_err(StreamableSendError::Other)
|
||||
}
|
||||
|
||||
async fn recv(&mut self) -> Result<serde_json::Value> {
|
||||
self.pending_messages
|
||||
.pop_front()
|
||||
.context("MCP Streamable HTTP response queue is empty")
|
||||
}
|
||||
|
||||
fn store_response_body(&mut self, content_type: Option<&str>, body: &str) -> Result<()> {
|
||||
if body.trim().is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let is_event_stream = content_type
|
||||
.map(|value| value.to_ascii_lowercase().contains("text/event-stream"))
|
||||
.unwrap_or(false)
|
||||
|| body.trim_start().starts_with("event:")
|
||||
|| body.trim_start().starts_with("data:");
|
||||
|
||||
if is_event_stream {
|
||||
for msg in parse_sse_json_messages(body)? {
|
||||
self.pending_messages.push_back(msg);
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.pending_messages
|
||||
.push_back(serde_json::from_str(body).context("Invalid MCP Streamable HTTP JSON")?);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn is_streamable_http_incompatible_status(status: StatusCode) -> bool {
|
||||
matches!(
|
||||
status,
|
||||
StatusCode::NOT_FOUND
|
||||
| StatusCode::METHOD_NOT_ALLOWED
|
||||
| StatusCode::NOT_ACCEPTABLE
|
||||
| StatusCode::UNSUPPORTED_MEDIA_TYPE
|
||||
| StatusCode::NOT_IMPLEMENTED
|
||||
)
|
||||
}
|
||||
|
||||
fn parse_sse_json_messages(body: &str) -> Result<Vec<serde_json::Value>> {
|
||||
let normalized = body.replace("\r\n", "\n");
|
||||
let mut messages = Vec::new();
|
||||
|
||||
for block in normalized.split("\n\n") {
|
||||
let mut event_type = "message";
|
||||
let mut data = String::new();
|
||||
|
||||
for line in block.lines() {
|
||||
if let Some(value) = sse_field_value(line, "event:") {
|
||||
event_type = value;
|
||||
} else if let Some(value) = sse_field_value(line, "data:") {
|
||||
if !data.is_empty() {
|
||||
data.push('\n');
|
||||
}
|
||||
data.push_str(value);
|
||||
}
|
||||
}
|
||||
|
||||
if event_type != "message" || data.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
messages.push(
|
||||
serde_json::from_str(data.trim())
|
||||
.with_context(|| format!("Invalid MCP SSE message data: {}", data.trim()))?,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
fn sse_field_value<'a>(line: &'a str, field: &str) -> Option<&'a str> {
|
||||
let value = line.strip_prefix(field)?;
|
||||
Some(value.strip_prefix(' ').unwrap_or(value))
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl McpTransport for SseTransport {
|
||||
async fn send(&mut self, msg: serde_json::Value) -> Result<()> {
|
||||
@@ -650,15 +875,12 @@ impl McpConnection {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(connect_timeout_secs))
|
||||
.build()?;
|
||||
Box::new(
|
||||
SseTransport::connect(
|
||||
client,
|
||||
url.clone(),
|
||||
cancel_token.clone(),
|
||||
Duration::from_secs(connect_timeout_secs),
|
||||
)
|
||||
.await?,
|
||||
)
|
||||
Box::new(HttpTransport::new(
|
||||
client,
|
||||
url.clone(),
|
||||
cancel_token.clone(),
|
||||
Duration::from_secs(connect_timeout_secs),
|
||||
))
|
||||
} else if let Some(command) = &config.command {
|
||||
let mut cmd = tokio::process::Command::new(command);
|
||||
cmd.args(&config.args)
|
||||
@@ -2219,6 +2441,151 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_sse_json_messages_extracts_message_events() {
|
||||
let body = "event: message\r\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}\r\n\r\n";
|
||||
let messages = parse_sse_json_messages(body).unwrap();
|
||||
assert_eq!(messages.len(), 1);
|
||||
assert_eq!(messages[0]["id"], 1);
|
||||
assert!(messages[0].get("result").is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mcp_connection_supports_streamable_http_event_stream_responses() {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
async fn read_http_request(socket: &mut TcpStream) -> String {
|
||||
let mut request = Vec::new();
|
||||
let mut buf = [0; 1024];
|
||||
let header_end = loop {
|
||||
let n = socket.read(&mut buf).await.unwrap();
|
||||
assert!(n > 0, "client closed before headers completed");
|
||||
request.extend_from_slice(&buf[..n]);
|
||||
if let Some(pos) = request.windows(4).position(|window| window == b"\r\n\r\n") {
|
||||
break pos + 4;
|
||||
}
|
||||
};
|
||||
|
||||
let headers = String::from_utf8_lossy(&request[..header_end]);
|
||||
let content_length = headers
|
||||
.lines()
|
||||
.find_map(|line| {
|
||||
let (name, value) = line.split_once(':')?;
|
||||
name.eq_ignore_ascii_case("content-length")
|
||||
.then(|| value.trim().parse::<usize>().ok())
|
||||
.flatten()
|
||||
})
|
||||
.unwrap_or(0);
|
||||
let total_len = header_end + content_length;
|
||||
while request.len() < total_len {
|
||||
let n = socket.read(&mut buf).await.unwrap();
|
||||
assert!(n > 0, "client closed before body completed");
|
||||
request.extend_from_slice(&buf[..n]);
|
||||
}
|
||||
|
||||
String::from_utf8(request).unwrap()
|
||||
}
|
||||
|
||||
async fn write_json_sse(socket: &mut TcpStream, response: serde_json::Value) {
|
||||
let body = format!("event: message\ndata: {response}\n\n");
|
||||
let response = format!(
|
||||
"HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nContent-Length: {}\r\n\r\n{}",
|
||||
body.len(),
|
||||
body
|
||||
);
|
||||
socket.write_all(response.as_bytes()).await.unwrap();
|
||||
}
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let server = tokio::spawn(async move {
|
||||
for _ in 0..6 {
|
||||
let (mut socket, _) = listener.accept().await.unwrap();
|
||||
tokio::spawn(async move {
|
||||
let request = read_http_request(&mut socket).await;
|
||||
assert!(request.starts_with("POST /mcp "));
|
||||
assert!(
|
||||
request.contains("Accept: application/json, text/event-stream")
|
||||
|| request.contains("accept: application/json, text/event-stream")
|
||||
);
|
||||
let body = request.split("\r\n\r\n").nth(1).unwrap_or("");
|
||||
let value: serde_json::Value = serde_json::from_str(body).unwrap();
|
||||
let method = value["method"].as_str().unwrap();
|
||||
|
||||
if method == "notifications/initialized" {
|
||||
socket
|
||||
.write_all(b"HTTP/1.1 202 Accepted\r\nContent-Length: 0\r\n\r\n")
|
||||
.await
|
||||
.unwrap();
|
||||
return;
|
||||
}
|
||||
|
||||
let id = value["id"].clone();
|
||||
let result = match method {
|
||||
"initialize" => serde_json::json!({
|
||||
"protocolVersion": "2024-11-05",
|
||||
"serverInfo": {"name": "mock-streamable", "version": "1.0.0"},
|
||||
"capabilities": {"tools": {}, "resources": {}, "prompts": {}}
|
||||
}),
|
||||
"tools/list" => serde_json::json!({
|
||||
"tools": [{
|
||||
"name": "read_wiki_structure",
|
||||
"description": "Read wiki structure",
|
||||
"inputSchema": {"type": "object"}
|
||||
}]
|
||||
}),
|
||||
"resources/list" => serde_json::json!({"resources": []}),
|
||||
"resources/templates/list" => {
|
||||
serde_json::json!({"resourceTemplates": []})
|
||||
}
|
||||
"prompts/list" => serde_json::json!({"prompts": []}),
|
||||
other => panic!("unexpected method: {other}"),
|
||||
};
|
||||
write_json_sse(
|
||||
&mut socket,
|
||||
serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": result
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
let config = McpServerConfig {
|
||||
command: None,
|
||||
args: vec![],
|
||||
env: HashMap::new(),
|
||||
url: Some(format!("http://{addr}/mcp")),
|
||||
connect_timeout: Some(2),
|
||||
execute_timeout: None,
|
||||
read_timeout: None,
|
||||
disabled: false,
|
||||
enabled: true,
|
||||
required: false,
|
||||
enabled_tools: Vec::new(),
|
||||
disabled_tools: Vec::new(),
|
||||
};
|
||||
|
||||
let conn = McpConnection::connect_with_policy(
|
||||
"deepwiki".to_string(),
|
||||
config,
|
||||
&McpTimeouts::default(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(conn.state(), ConnectionState::Ready);
|
||||
assert_eq!(conn.tools().len(), 1);
|
||||
assert_eq!(conn.tools()[0].name, "read_wiki_structure");
|
||||
|
||||
server.abort();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mask_url_secrets_strips_userinfo() {
|
||||
let masked = mask_url_secrets("https://user:s3cret@host.example/api?foo=bar");
|
||||
|
||||
Reference in New Issue
Block a user