diff --git a/crates/tui/src/mcp.rs b/crates/tui/src/mcp.rs index ecf54b76..b0f2bcca 100644 --- a/crates/tui/src/mcp.rs +++ b/crates/tui/src/mcp.rs @@ -274,8 +274,8 @@ pub enum ConnectionState { #[async_trait::async_trait] pub trait McpTransport: Send + Sync { - async fn send(&mut self, msg: serde_json::Value) -> Result<()>; - async fn recv(&mut self) -> Result; + async fn send(&mut self, msg: Vec) -> Result<()>; + async fn recv(&mut self) -> Result>; /// Graceful shutdown — stdio transports send SIGTERM to the child and /// give it a brief window to exit before tokio's `kill_on_drop` fires @@ -323,14 +323,14 @@ fn send_sigterm(child: &Child) -> bool { #[async_trait::async_trait] impl McpTransport for StdioTransport { - async fn send(&mut self, msg: serde_json::Value) -> Result<()> { - let line = serde_json::to_string(&msg)? + "\n"; - self.stdin.write_all(line.as_bytes()).await?; + async fn send(&mut self, mut msg: Vec) -> Result<()> { + msg.push(b'\n'); + self.stdin.write_all(&msg).await?; self.stdin.flush().await?; Ok(()) } - async fn recv(&mut self) -> Result { + async fn recv(&mut self) -> Result> { let mut line = String::new(); loop { line.clear(); @@ -344,9 +344,7 @@ impl McpTransport for StdioTransport { continue; } - if let Ok(value) = serde_json::from_str::(trimmed) { - return Ok(value); - } + return Ok(trimmed.as_bytes().to_vec()); } } @@ -374,8 +372,13 @@ pub struct SseTransport { client: reqwest::Client, base_url: String, endpoint_url: Option, - receiver: tokio::sync::mpsc::UnboundedReceiver, - pending_messages: VecDeque, + receiver: tokio::sync::mpsc::UnboundedReceiver, + pending_messages: VecDeque>, +} + +enum SseInbound { + Endpoint(String), + Message(Vec), } struct HttpTransport { @@ -394,7 +397,7 @@ enum HttpTransportMode { struct StreamableHttpTransport { client: reqwest::Client, url: String, - pending_messages: VecDeque, + pending_messages: VecDeque>, } enum StreamableSendError { @@ -461,7 +464,7 @@ impl SseTransport { async fn run_sse_loop( client: reqwest::Client, url: String, - tx: tokio::sync::mpsc::UnboundedSender, + tx: tokio::sync::mpsc::UnboundedSender, cancel_token: tokio_util::sync::CancellationToken, ) -> Result<()> { let response = client.get(&url).send().await.with_context(|| { @@ -523,14 +526,11 @@ impl SseTransport { match event_type { "endpoint" => { - // Special internal message to set endpoint - let _ = tx.send(serde_json::json!({ - "__internal_sse_endpoint__": data - })); + let _ = tx.send(SseInbound::Endpoint(data)); } "message" => { - if let Ok(val) = serde_json::from_str::(&data) { - let _ = tx.send(val); + if !data.trim().is_empty() { + let _ = tx.send(SseInbound::Message(data.into_bytes())); } } _ => {} @@ -564,21 +564,19 @@ impl SseTransport { } }; - if self.store_endpoint_from_internal_message(&msg)? { - return Ok(()); + match msg { + SseInbound::Endpoint(endpoint) => { + self.store_endpoint(&endpoint)?; + return Ok(()); + } + SseInbound::Message(msg) => self.pending_messages.push_back(msg), } - - self.pending_messages.push_back(msg); } } - fn store_endpoint_from_internal_message(&mut self, msg: &serde_json::Value) -> Result { - let Some(endpoint) = msg.get("__internal_sse_endpoint__") else { - return Ok(false); - }; - let url_str = endpoint.as_str().context("Invalid endpoint format")?; - self.endpoint_url = Some(Self::resolve_endpoint_url(&self.base_url, url_str)?); - Ok(true) + fn store_endpoint(&mut self, endpoint: &str) -> Result<()> { + self.endpoint_url = Some(Self::resolve_endpoint_url(&self.base_url, endpoint)?); + Ok(()) } fn resolve_endpoint_url(base_url: &str, endpoint_url: &str) -> Result { @@ -610,7 +608,7 @@ impl HttpTransport { } } - async fn switch_to_sse_and_send(&mut self, msg: serde_json::Value) -> Result<()> { + async fn switch_to_sse_and_send(&mut self, msg: Vec) -> Result<()> { let mut sse = SseTransport::connect( self.client.clone(), self.base_url.clone(), @@ -626,7 +624,7 @@ impl HttpTransport { #[async_trait::async_trait] impl McpTransport for HttpTransport { - async fn send(&mut self, msg: serde_json::Value) -> Result<()> { + async fn send(&mut self, msg: Vec) -> Result<()> { match &mut self.mode { HttpTransportMode::Streamable(transport) => match transport.send(msg.clone()).await { Ok(()) => Ok(()), @@ -643,7 +641,7 @@ impl McpTransport for HttpTransport { } } - async fn recv(&mut self) -> Result { + async fn recv(&mut self) -> Result> { match &mut self.mode { HttpTransportMode::Streamable(transport) => transport.recv().await, HttpTransportMode::Sse(transport) => transport.recv().await, @@ -666,15 +664,13 @@ impl StreamableHttpTransport { } } - async fn send( - &mut self, - msg: serde_json::Value, - ) -> std::result::Result<(), StreamableSendError> { + async fn send(&mut self, msg: Vec) -> std::result::Result<(), StreamableSendError> { let response = self .client .post(&self.url) .header(ACCEPT, "application/json, text/event-stream") - .json(&msg) + .header(CONTENT_TYPE, "application/json") + .body(msg) .send() .await .map_err(|err| StreamableSendError::Other(err.into()))?; @@ -712,7 +708,7 @@ impl StreamableHttpTransport { .map_err(StreamableSendError::Other) } - async fn recv(&mut self) -> Result { + async fn recv(&mut self) -> Result> { self.pending_messages .pop_front() .context("MCP Streamable HTTP response queue is empty") @@ -730,14 +726,13 @@ impl StreamableHttpTransport { || body.trim_start().starts_with("data:"); if is_event_stream { - for msg in parse_sse_json_messages(body)? { + for msg in parse_sse_message_data(body) { self.pending_messages.push_back(msg); } return Ok(()); } - self.pending_messages - .push_back(serde_json::from_str(body).context("Invalid MCP Streamable HTTP JSON")?); + self.pending_messages.push_back(body.as_bytes().to_vec()); Ok(()) } } @@ -753,7 +748,7 @@ fn is_streamable_http_incompatible_status(status: StatusCode) -> bool { ) } -fn parse_sse_json_messages(body: &str) -> Result> { +fn parse_sse_message_data(body: &str) -> Vec> { let normalized = body.replace("\r\n", "\n"); let mut messages = Vec::new(); @@ -776,13 +771,10 @@ fn parse_sse_json_messages(body: &str) -> Result> { continue; } - messages.push( - serde_json::from_str(data.trim()) - .with_context(|| format!("Invalid MCP SSE message data: {}", data.trim()))?, - ); + messages.push(data.trim().as_bytes().to_vec()); } - Ok(messages) + messages } fn sse_field_value<'a>(line: &'a str, field: &str) -> Option<&'a str> { @@ -792,29 +784,36 @@ fn sse_field_value<'a>(line: &'a str, field: &str) -> Option<&'a str> { #[async_trait::async_trait] impl McpTransport for SseTransport { - async fn send(&mut self, msg: serde_json::Value) -> Result<()> { + async fn send(&mut self, msg: Vec) -> Result<()> { let endpoint = self .endpoint_url .as_ref() .context("SSE endpoint not yet discovered")?; - let response = self.client.post(endpoint).json(&msg).send().await?; + let response = self + .client + .post(endpoint) + .header(CONTENT_TYPE, "application/json") + .body(msg) + .send() + .await?; if !response.status().is_success() { anyhow::bail!("Failed to send message via SSE POST: {}", response.status()); } Ok(()) } - async fn recv(&mut self) -> Result { + async fn recv(&mut self) -> Result> { loop { - let msg = if let Some(msg) = self.pending_messages.pop_front() { - msg - } else { - self.receiver.recv().await.context("SSE transport closed")? - }; - if self.store_endpoint_from_internal_message(&msg)? { - continue; + if let Some(msg) = self.pending_messages.pop_front() { + return Ok(msg); + } + + match self.receiver.recv().await.context("SSE transport closed")? { + SseInbound::Endpoint(endpoint) => { + self.store_endpoint(&endpoint)?; + } + SseInbound::Message(msg) => return Ok(msg), } - return Ok(msg); } } } @@ -1299,14 +1298,18 @@ impl McpConnection { } async fn send(&mut self, msg: serde_json::Value) -> Result<()> { - self.transport.send(msg).await + let bytes = serde_json::to_vec(&msg).context("Failed to serialize MCP JSON-RPC message")?; + self.transport.send(bytes).await } async fn recv(&mut self, expected_id: u64) -> Result { loop { - let value = self.transport.recv().await.inspect_err(|_e| { + let bytes = self.transport.recv().await.inspect_err(|_e| { self.state = ConnectionState::Disconnected; })?; + let value: serde_json::Value = serde_json::from_slice(&bytes).with_context(|| { + format!("Invalid MCP JSON-RPC message from server '{}'", self.name) + })?; // Check if this is a response with the expected id if value.get("id").and_then(serde_json::Value::as_u64) == Some(expected_id) { @@ -2214,6 +2217,8 @@ pub fn format_tool_result(result: &serde_json::Value) -> String { #[cfg(test)] mod tests { use super::*; + use std::collections::VecDeque; + use std::sync::{Arc, Mutex}; #[test] fn test_mcp_config_defaults() { @@ -2393,6 +2398,141 @@ mod tests { assert!(formatted.contains("[image content]")); } + struct ScriptedValueTransport { + sent: Arc>>, + responses: VecDeque>, + } + + #[async_trait::async_trait] + impl McpTransport for ScriptedValueTransport { + async fn send(&mut self, msg: Vec) -> Result<()> { + self.sent + .lock() + .unwrap() + .push(serde_json::from_slice(&msg)?); + Ok(()) + } + + async fn recv(&mut self) -> Result> { + self.responses + .pop_front() + .context("scripted transport exhausted") + } + } + + struct HangingValueTransport { + sent: Arc>>, + } + + #[async_trait::async_trait] + impl McpTransport for HangingValueTransport { + async fn send(&mut self, msg: Vec) -> Result<()> { + self.sent + .lock() + .unwrap() + .push(serde_json::from_slice(&msg)?); + Ok(()) + } + + async fn recv(&mut self) -> Result> { + std::future::pending().await + } + } + + fn test_server_config() -> McpServerConfig { + McpServerConfig { + command: Some("mock".to_string()), + args: Vec::new(), + env: HashMap::new(), + url: None, + connect_timeout: None, + execute_timeout: None, + read_timeout: None, + disabled: false, + enabled: true, + required: false, + enabled_tools: Vec::new(), + disabled_tools: Vec::new(), + } + } + + fn test_connection(transport: Box) -> McpConnection { + McpConnection { + name: "mock".to_string(), + transport, + tools: Vec::new(), + resources: Vec::new(), + resource_templates: Vec::new(), + prompts: Vec::new(), + request_id: AtomicU64::new(1), + state: ConnectionState::Ready, + config: test_server_config(), + cancel_token: tokio_util::sync::CancellationToken::new(), + } + } + + fn json_frame(value: serde_json::Value) -> Vec { + serde_json::to_vec(&value).unwrap() + } + + #[tokio::test] + async fn call_method_skips_notifications_and_unmatched_responses() { + let sent = Arc::new(Mutex::new(Vec::new())); + let transport = ScriptedValueTransport { + sent: Arc::clone(&sent), + responses: VecDeque::from([ + json_frame(serde_json::json!({ + "jsonrpc": "2.0", + "method": "notifications/progress", + "params": {"progress": 0.5} + })), + json_frame(serde_json::json!({ + "jsonrpc": "2.0", + "id": 99, + "result": {"ignored": true} + })), + json_frame(serde_json::json!({ + "jsonrpc": "2.0", + "id": 1, + "result": {"ok": true} + })), + ]), + }; + let mut conn = test_connection(Box::new(transport)); + + let result = conn + .call_method("tools/call", serde_json::json!({"name": "echo"}), 1) + .await + .unwrap(); + + assert_eq!(result, serde_json::json!({"ok": true})); + let sent = sent.lock().unwrap(); + assert_eq!(sent.len(), 1); + assert_eq!(sent[0]["jsonrpc"], "2.0"); + assert_eq!(sent[0]["id"], 1); + assert_eq!(sent[0]["method"], "tools/call"); + } + + #[tokio::test] + async fn call_method_times_out_while_waiting_for_response() { + let sent = Arc::new(Mutex::new(Vec::new())); + let mut conn = test_connection(Box::new(HangingValueTransport { + sent: Arc::clone(&sent), + })); + + let err = conn + .call_method("tools/call", serde_json::json!({"name": "echo"}), 0) + .await + .expect_err("hung receive should time out"); + + assert!( + err.to_string() + .contains("MCP method 'tools/call' on server 'mock' timed out after 0s"), + "unexpected error: {err:#}" + ); + assert_eq!(sent.lock().unwrap().len(), 1); + } + #[tokio::test] async fn test_mcp_pool_empty_config() { let pool = McpPool::new(McpConfig::default()); @@ -2442,12 +2582,13 @@ mod tests { } #[test] - fn parse_sse_json_messages_extracts_message_events() { + fn parse_sse_message_data_extracts_message_events() { let body = "event: message\r\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}\r\n\r\n"; - let messages = parse_sse_json_messages(body).unwrap(); + let messages = parse_sse_message_data(body); assert_eq!(messages.len(), 1); - assert_eq!(messages[0]["id"], 1); - assert!(messages[0].get("result").is_some()); + let value: serde_json::Value = serde_json::from_slice(&messages[0]).unwrap(); + assert_eq!(value["id"], 1); + assert!(value.get("result").is_some()); } #[tokio::test] @@ -2736,11 +2877,11 @@ mod tests { .unwrap(); transport - .send(serde_json::json!({ + .send(json_frame(serde_json::json!({ "jsonrpc": "2.0", "id": 1, "method": "initialize" - })) + }))) .await .unwrap();