From e30583ab45829b321a0e623c81f6e8f17a09698c Mon Sep 17 00:00:00 2001 From: reidliu41 Date: Sat, 9 May 2026 18:12:09 +0800 Subject: [PATCH] Support Streamable HTTP MCP endpoints Add direct POST-based MCP transport for Streamable HTTP servers while keeping the existing SSE endpoint-discovery path as a fallback. Parse both JSON and text/event-stream responses so servers like DeepWiki can be validated and used. --- crates/tui/src/mcp.rs | 385 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 376 insertions(+), 9 deletions(-) diff --git a/crates/tui/src/mcp.rs b/crates/tui/src/mcp.rs index 57807ae2..ecf54b76 100644 --- a/crates/tui/src/mcp.rs +++ b/crates/tui/src/mcp.rs @@ -12,6 +12,8 @@ use std::sync::atomic::{AtomicU64, Ordering}; use std::time::Duration; use anyhow::{Context, Result}; +use reqwest::StatusCode; +use reqwest::header::{ACCEPT, CONTENT_TYPE}; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; use tokio::process::{Child, ChildStdin, ChildStdout}; @@ -376,6 +378,30 @@ pub struct SseTransport { pending_messages: VecDeque, } +struct HttpTransport { + mode: HttpTransportMode, + client: reqwest::Client, + base_url: String, + cancel_token: tokio_util::sync::CancellationToken, + endpoint_timeout: Duration, +} + +enum HttpTransportMode { + Streamable(StreamableHttpTransport), + Sse(SseTransport), +} + +struct StreamableHttpTransport { + client: reqwest::Client, + url: String, + pending_messages: VecDeque, +} + +enum StreamableSendError { + Incompatible(String), + Other(anyhow::Error), +} + impl SseTransport { pub async fn connect( client: reqwest::Client, @@ -565,6 +591,205 @@ impl SseTransport { } } +impl HttpTransport { + fn new( + client: reqwest::Client, + url: String, + cancel_token: tokio_util::sync::CancellationToken, + endpoint_timeout: Duration, + ) -> Self { + Self { + mode: HttpTransportMode::Streamable(StreamableHttpTransport::new( + client.clone(), + url.clone(), + )), + client, + base_url: url, + cancel_token, + endpoint_timeout, + } + } + + async fn switch_to_sse_and_send(&mut self, msg: serde_json::Value) -> Result<()> { + let mut sse = SseTransport::connect( + self.client.clone(), + self.base_url.clone(), + self.cancel_token.clone(), + self.endpoint_timeout, + ) + .await?; + sse.send(msg).await?; + self.mode = HttpTransportMode::Sse(sse); + Ok(()) + } +} + +#[async_trait::async_trait] +impl McpTransport for HttpTransport { + async fn send(&mut self, msg: serde_json::Value) -> Result<()> { + match &mut self.mode { + HttpTransportMode::Streamable(transport) => match transport.send(msg.clone()).await { + Ok(()) => Ok(()), + Err(StreamableSendError::Incompatible(detail)) => { + tracing::debug!( + "MCP Streamable HTTP unavailable; falling back to SSE endpoint discovery: {}", + detail + ); + self.switch_to_sse_and_send(msg).await + } + Err(StreamableSendError::Other(err)) => Err(err), + }, + HttpTransportMode::Sse(transport) => transport.send(msg).await, + } + } + + async fn recv(&mut self) -> Result { + match &mut self.mode { + HttpTransportMode::Streamable(transport) => transport.recv().await, + HttpTransportMode::Sse(transport) => transport.recv().await, + } + } + + async fn shutdown(&mut self) { + if let HttpTransportMode::Sse(transport) = &mut self.mode { + transport.shutdown().await; + } + } +} + +impl StreamableHttpTransport { + fn new(client: reqwest::Client, url: String) -> Self { + Self { + client, + url, + pending_messages: VecDeque::new(), + } + } + + async fn send( + &mut self, + msg: serde_json::Value, + ) -> std::result::Result<(), StreamableSendError> { + let response = self + .client + .post(&self.url) + .header(ACCEPT, "application/json, text/event-stream") + .json(&msg) + .send() + .await + .map_err(|err| StreamableSendError::Other(err.into()))?; + + let status = response.status(); + if status == StatusCode::ACCEPTED || status == StatusCode::NO_CONTENT { + return Ok(()); + } + + if !status.is_success() { + let body_excerpt = bounded_body_excerpt(response, ERROR_BODY_PREVIEW_BYTES).await; + if is_streamable_http_incompatible_status(status) { + return Err(StreamableSendError::Incompatible(format!( + "status={status} body={body_excerpt}" + ))); + } + return Err(StreamableSendError::Other(anyhow::anyhow!( + "MCP Streamable HTTP rejected (transport=http url={} status={}): {}", + mask_url_secrets(&self.url), + status, + body_excerpt, + ))); + } + + let content_type = response + .headers() + .get(CONTENT_TYPE) + .and_then(|value| value.to_str().ok()) + .map(str::to_string); + let body = response + .text() + .await + .map_err(|err| StreamableSendError::Other(err.into()))?; + self.store_response_body(content_type.as_deref(), &body) + .map_err(StreamableSendError::Other) + } + + async fn recv(&mut self) -> Result { + self.pending_messages + .pop_front() + .context("MCP Streamable HTTP response queue is empty") + } + + fn store_response_body(&mut self, content_type: Option<&str>, body: &str) -> Result<()> { + if body.trim().is_empty() { + return Ok(()); + } + + let is_event_stream = content_type + .map(|value| value.to_ascii_lowercase().contains("text/event-stream")) + .unwrap_or(false) + || body.trim_start().starts_with("event:") + || body.trim_start().starts_with("data:"); + + if is_event_stream { + for msg in parse_sse_json_messages(body)? { + self.pending_messages.push_back(msg); + } + return Ok(()); + } + + self.pending_messages + .push_back(serde_json::from_str(body).context("Invalid MCP Streamable HTTP JSON")?); + Ok(()) + } +} + +fn is_streamable_http_incompatible_status(status: StatusCode) -> bool { + matches!( + status, + StatusCode::NOT_FOUND + | StatusCode::METHOD_NOT_ALLOWED + | StatusCode::NOT_ACCEPTABLE + | StatusCode::UNSUPPORTED_MEDIA_TYPE + | StatusCode::NOT_IMPLEMENTED + ) +} + +fn parse_sse_json_messages(body: &str) -> Result> { + let normalized = body.replace("\r\n", "\n"); + let mut messages = Vec::new(); + + for block in normalized.split("\n\n") { + let mut event_type = "message"; + let mut data = String::new(); + + for line in block.lines() { + if let Some(value) = sse_field_value(line, "event:") { + event_type = value; + } else if let Some(value) = sse_field_value(line, "data:") { + if !data.is_empty() { + data.push('\n'); + } + data.push_str(value); + } + } + + if event_type != "message" || data.trim().is_empty() { + continue; + } + + messages.push( + serde_json::from_str(data.trim()) + .with_context(|| format!("Invalid MCP SSE message data: {}", data.trim()))?, + ); + } + + Ok(messages) +} + +fn sse_field_value<'a>(line: &'a str, field: &str) -> Option<&'a str> { + let value = line.strip_prefix(field)?; + Some(value.strip_prefix(' ').unwrap_or(value)) +} + #[async_trait::async_trait] impl McpTransport for SseTransport { async fn send(&mut self, msg: serde_json::Value) -> Result<()> { @@ -650,15 +875,12 @@ impl McpConnection { let client = reqwest::Client::builder() .timeout(Duration::from_secs(connect_timeout_secs)) .build()?; - Box::new( - SseTransport::connect( - client, - url.clone(), - cancel_token.clone(), - Duration::from_secs(connect_timeout_secs), - ) - .await?, - ) + Box::new(HttpTransport::new( + client, + url.clone(), + cancel_token.clone(), + Duration::from_secs(connect_timeout_secs), + )) } else if let Some(command) = &config.command { let mut cmd = tokio::process::Command::new(command); cmd.args(&config.args) @@ -2219,6 +2441,151 @@ mod tests { ); } + #[test] + fn parse_sse_json_messages_extracts_message_events() { + let body = "event: message\r\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}\r\n\r\n"; + let messages = parse_sse_json_messages(body).unwrap(); + assert_eq!(messages.len(), 1); + assert_eq!(messages[0]["id"], 1); + assert!(messages[0].get("result").is_some()); + } + + #[tokio::test] + async fn mcp_connection_supports_streamable_http_event_stream_responses() { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::{TcpListener, TcpStream}; + + async fn read_http_request(socket: &mut TcpStream) -> String { + let mut request = Vec::new(); + let mut buf = [0; 1024]; + let header_end = loop { + let n = socket.read(&mut buf).await.unwrap(); + assert!(n > 0, "client closed before headers completed"); + request.extend_from_slice(&buf[..n]); + if let Some(pos) = request.windows(4).position(|window| window == b"\r\n\r\n") { + break pos + 4; + } + }; + + let headers = String::from_utf8_lossy(&request[..header_end]); + let content_length = headers + .lines() + .find_map(|line| { + let (name, value) = line.split_once(':')?; + name.eq_ignore_ascii_case("content-length") + .then(|| value.trim().parse::().ok()) + .flatten() + }) + .unwrap_or(0); + let total_len = header_end + content_length; + while request.len() < total_len { + let n = socket.read(&mut buf).await.unwrap(); + assert!(n > 0, "client closed before body completed"); + request.extend_from_slice(&buf[..n]); + } + + String::from_utf8(request).unwrap() + } + + async fn write_json_sse(socket: &mut TcpStream, response: serde_json::Value) { + let body = format!("event: message\ndata: {response}\n\n"); + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nContent-Length: {}\r\n\r\n{}", + body.len(), + body + ); + socket.write_all(response.as_bytes()).await.unwrap(); + } + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let server = tokio::spawn(async move { + for _ in 0..6 { + let (mut socket, _) = listener.accept().await.unwrap(); + tokio::spawn(async move { + let request = read_http_request(&mut socket).await; + assert!(request.starts_with("POST /mcp ")); + assert!( + request.contains("Accept: application/json, text/event-stream") + || request.contains("accept: application/json, text/event-stream") + ); + let body = request.split("\r\n\r\n").nth(1).unwrap_or(""); + let value: serde_json::Value = serde_json::from_str(body).unwrap(); + let method = value["method"].as_str().unwrap(); + + if method == "notifications/initialized" { + socket + .write_all(b"HTTP/1.1 202 Accepted\r\nContent-Length: 0\r\n\r\n") + .await + .unwrap(); + return; + } + + let id = value["id"].clone(); + let result = match method { + "initialize" => serde_json::json!({ + "protocolVersion": "2024-11-05", + "serverInfo": {"name": "mock-streamable", "version": "1.0.0"}, + "capabilities": {"tools": {}, "resources": {}, "prompts": {}} + }), + "tools/list" => serde_json::json!({ + "tools": [{ + "name": "read_wiki_structure", + "description": "Read wiki structure", + "inputSchema": {"type": "object"} + }] + }), + "resources/list" => serde_json::json!({"resources": []}), + "resources/templates/list" => { + serde_json::json!({"resourceTemplates": []}) + } + "prompts/list" => serde_json::json!({"prompts": []}), + other => panic!("unexpected method: {other}"), + }; + write_json_sse( + &mut socket, + serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": result + }), + ) + .await; + }); + } + }); + + let config = McpServerConfig { + command: None, + args: vec![], + env: HashMap::new(), + url: Some(format!("http://{addr}/mcp")), + connect_timeout: Some(2), + execute_timeout: None, + read_timeout: None, + disabled: false, + enabled: true, + required: false, + enabled_tools: Vec::new(), + disabled_tools: Vec::new(), + }; + + let conn = McpConnection::connect_with_policy( + "deepwiki".to_string(), + config, + &McpTimeouts::default(), + None, + ) + .await + .unwrap(); + + assert_eq!(conn.state(), ConnectionState::Ready); + assert_eq!(conn.tools().len(), 1); + assert_eq!(conn.tools()[0].name, "read_wiki_structure"); + + server.abort(); + } + #[test] fn mask_url_secrets_strips_userinfo() { let masked = mask_url_secrets("https://user:s3cret@host.example/api?foo=bar");