refactor(mcp): centralize json-rpc framing

This commit is contained in:
Hunter Bown
2026-05-09 12:27:20 -05:00
parent b78c2f8483
commit dcf7b66ad8
+208 -67
View File
@@ -274,8 +274,8 @@ pub enum ConnectionState {
#[async_trait::async_trait]
pub trait McpTransport: Send + Sync {
async fn send(&mut self, msg: serde_json::Value) -> Result<()>;
async fn recv(&mut self) -> Result<serde_json::Value>;
async fn send(&mut self, msg: Vec<u8>) -> Result<()>;
async fn recv(&mut self) -> Result<Vec<u8>>;
/// Graceful shutdown — stdio transports send SIGTERM to the child and
/// give it a brief window to exit before tokio's `kill_on_drop` fires
@@ -323,14 +323,14 @@ fn send_sigterm(child: &Child) -> bool {
#[async_trait::async_trait]
impl McpTransport for StdioTransport {
async fn send(&mut self, msg: serde_json::Value) -> Result<()> {
let line = serde_json::to_string(&msg)? + "\n";
self.stdin.write_all(line.as_bytes()).await?;
async fn send(&mut self, mut msg: Vec<u8>) -> Result<()> {
msg.push(b'\n');
self.stdin.write_all(&msg).await?;
self.stdin.flush().await?;
Ok(())
}
async fn recv(&mut self) -> Result<serde_json::Value> {
async fn recv(&mut self) -> Result<Vec<u8>> {
let mut line = String::new();
loop {
line.clear();
@@ -344,9 +344,7 @@ impl McpTransport for StdioTransport {
continue;
}
if let Ok(value) = serde_json::from_str::<serde_json::Value>(trimmed) {
return Ok(value);
}
return Ok(trimmed.as_bytes().to_vec());
}
}
@@ -374,8 +372,13 @@ pub struct SseTransport {
client: reqwest::Client,
base_url: String,
endpoint_url: Option<String>,
receiver: tokio::sync::mpsc::UnboundedReceiver<serde_json::Value>,
pending_messages: VecDeque<serde_json::Value>,
receiver: tokio::sync::mpsc::UnboundedReceiver<SseInbound>,
pending_messages: VecDeque<Vec<u8>>,
}
enum SseInbound {
Endpoint(String),
Message(Vec<u8>),
}
struct HttpTransport {
@@ -394,7 +397,7 @@ enum HttpTransportMode {
struct StreamableHttpTransport {
client: reqwest::Client,
url: String,
pending_messages: VecDeque<serde_json::Value>,
pending_messages: VecDeque<Vec<u8>>,
}
enum StreamableSendError {
@@ -461,7 +464,7 @@ impl SseTransport {
async fn run_sse_loop(
client: reqwest::Client,
url: String,
tx: tokio::sync::mpsc::UnboundedSender<serde_json::Value>,
tx: tokio::sync::mpsc::UnboundedSender<SseInbound>,
cancel_token: tokio_util::sync::CancellationToken,
) -> Result<()> {
let response = client.get(&url).send().await.with_context(|| {
@@ -523,14 +526,11 @@ impl SseTransport {
match event_type {
"endpoint" => {
// Special internal message to set endpoint
let _ = tx.send(serde_json::json!({
"__internal_sse_endpoint__": data
}));
let _ = tx.send(SseInbound::Endpoint(data));
}
"message" => {
if let Ok(val) = serde_json::from_str::<serde_json::Value>(&data) {
let _ = tx.send(val);
if !data.trim().is_empty() {
let _ = tx.send(SseInbound::Message(data.into_bytes()));
}
}
_ => {}
@@ -564,21 +564,19 @@ impl SseTransport {
}
};
if self.store_endpoint_from_internal_message(&msg)? {
return Ok(());
match msg {
SseInbound::Endpoint(endpoint) => {
self.store_endpoint(&endpoint)?;
return Ok(());
}
SseInbound::Message(msg) => self.pending_messages.push_back(msg),
}
self.pending_messages.push_back(msg);
}
}
fn store_endpoint_from_internal_message(&mut self, msg: &serde_json::Value) -> Result<bool> {
let Some(endpoint) = msg.get("__internal_sse_endpoint__") else {
return Ok(false);
};
let url_str = endpoint.as_str().context("Invalid endpoint format")?;
self.endpoint_url = Some(Self::resolve_endpoint_url(&self.base_url, url_str)?);
Ok(true)
fn store_endpoint(&mut self, endpoint: &str) -> Result<()> {
self.endpoint_url = Some(Self::resolve_endpoint_url(&self.base_url, endpoint)?);
Ok(())
}
fn resolve_endpoint_url(base_url: &str, endpoint_url: &str) -> Result<String> {
@@ -610,7 +608,7 @@ impl HttpTransport {
}
}
async fn switch_to_sse_and_send(&mut self, msg: serde_json::Value) -> Result<()> {
async fn switch_to_sse_and_send(&mut self, msg: Vec<u8>) -> Result<()> {
let mut sse = SseTransport::connect(
self.client.clone(),
self.base_url.clone(),
@@ -626,7 +624,7 @@ impl HttpTransport {
#[async_trait::async_trait]
impl McpTransport for HttpTransport {
async fn send(&mut self, msg: serde_json::Value) -> Result<()> {
async fn send(&mut self, msg: Vec<u8>) -> Result<()> {
match &mut self.mode {
HttpTransportMode::Streamable(transport) => match transport.send(msg.clone()).await {
Ok(()) => Ok(()),
@@ -643,7 +641,7 @@ impl McpTransport for HttpTransport {
}
}
async fn recv(&mut self) -> Result<serde_json::Value> {
async fn recv(&mut self) -> Result<Vec<u8>> {
match &mut self.mode {
HttpTransportMode::Streamable(transport) => transport.recv().await,
HttpTransportMode::Sse(transport) => transport.recv().await,
@@ -666,15 +664,13 @@ impl StreamableHttpTransport {
}
}
async fn send(
&mut self,
msg: serde_json::Value,
) -> std::result::Result<(), StreamableSendError> {
async fn send(&mut self, msg: Vec<u8>) -> std::result::Result<(), StreamableSendError> {
let response = self
.client
.post(&self.url)
.header(ACCEPT, "application/json, text/event-stream")
.json(&msg)
.header(CONTENT_TYPE, "application/json")
.body(msg)
.send()
.await
.map_err(|err| StreamableSendError::Other(err.into()))?;
@@ -712,7 +708,7 @@ impl StreamableHttpTransport {
.map_err(StreamableSendError::Other)
}
async fn recv(&mut self) -> Result<serde_json::Value> {
async fn recv(&mut self) -> Result<Vec<u8>> {
self.pending_messages
.pop_front()
.context("MCP Streamable HTTP response queue is empty")
@@ -730,14 +726,13 @@ impl StreamableHttpTransport {
|| body.trim_start().starts_with("data:");
if is_event_stream {
for msg in parse_sse_json_messages(body)? {
for msg in parse_sse_message_data(body) {
self.pending_messages.push_back(msg);
}
return Ok(());
}
self.pending_messages
.push_back(serde_json::from_str(body).context("Invalid MCP Streamable HTTP JSON")?);
self.pending_messages.push_back(body.as_bytes().to_vec());
Ok(())
}
}
@@ -753,7 +748,7 @@ fn is_streamable_http_incompatible_status(status: StatusCode) -> bool {
)
}
fn parse_sse_json_messages(body: &str) -> Result<Vec<serde_json::Value>> {
fn parse_sse_message_data(body: &str) -> Vec<Vec<u8>> {
let normalized = body.replace("\r\n", "\n");
let mut messages = Vec::new();
@@ -776,13 +771,10 @@ fn parse_sse_json_messages(body: &str) -> Result<Vec<serde_json::Value>> {
continue;
}
messages.push(
serde_json::from_str(data.trim())
.with_context(|| format!("Invalid MCP SSE message data: {}", data.trim()))?,
);
messages.push(data.trim().as_bytes().to_vec());
}
Ok(messages)
messages
}
fn sse_field_value<'a>(line: &'a str, field: &str) -> Option<&'a str> {
@@ -792,29 +784,36 @@ fn sse_field_value<'a>(line: &'a str, field: &str) -> Option<&'a str> {
#[async_trait::async_trait]
impl McpTransport for SseTransport {
async fn send(&mut self, msg: serde_json::Value) -> Result<()> {
async fn send(&mut self, msg: Vec<u8>) -> Result<()> {
let endpoint = self
.endpoint_url
.as_ref()
.context("SSE endpoint not yet discovered")?;
let response = self.client.post(endpoint).json(&msg).send().await?;
let response = self
.client
.post(endpoint)
.header(CONTENT_TYPE, "application/json")
.body(msg)
.send()
.await?;
if !response.status().is_success() {
anyhow::bail!("Failed to send message via SSE POST: {}", response.status());
}
Ok(())
}
async fn recv(&mut self) -> Result<serde_json::Value> {
async fn recv(&mut self) -> Result<Vec<u8>> {
loop {
let msg = if let Some(msg) = self.pending_messages.pop_front() {
msg
} else {
self.receiver.recv().await.context("SSE transport closed")?
};
if self.store_endpoint_from_internal_message(&msg)? {
continue;
if let Some(msg) = self.pending_messages.pop_front() {
return Ok(msg);
}
match self.receiver.recv().await.context("SSE transport closed")? {
SseInbound::Endpoint(endpoint) => {
self.store_endpoint(&endpoint)?;
}
SseInbound::Message(msg) => return Ok(msg),
}
return Ok(msg);
}
}
}
@@ -1299,14 +1298,18 @@ impl McpConnection {
}
async fn send(&mut self, msg: serde_json::Value) -> Result<()> {
self.transport.send(msg).await
let bytes = serde_json::to_vec(&msg).context("Failed to serialize MCP JSON-RPC message")?;
self.transport.send(bytes).await
}
async fn recv(&mut self, expected_id: u64) -> Result<serde_json::Value> {
loop {
let value = self.transport.recv().await.inspect_err(|_e| {
let bytes = self.transport.recv().await.inspect_err(|_e| {
self.state = ConnectionState::Disconnected;
})?;
let value: serde_json::Value = serde_json::from_slice(&bytes).with_context(|| {
format!("Invalid MCP JSON-RPC message from server '{}'", self.name)
})?;
// Check if this is a response with the expected id
if value.get("id").and_then(serde_json::Value::as_u64) == Some(expected_id) {
@@ -2214,6 +2217,8 @@ pub fn format_tool_result(result: &serde_json::Value) -> String {
#[cfg(test)]
mod tests {
use super::*;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
#[test]
fn test_mcp_config_defaults() {
@@ -2393,6 +2398,141 @@ mod tests {
assert!(formatted.contains("[image content]"));
}
struct ScriptedValueTransport {
sent: Arc<Mutex<Vec<serde_json::Value>>>,
responses: VecDeque<Vec<u8>>,
}
#[async_trait::async_trait]
impl McpTransport for ScriptedValueTransport {
async fn send(&mut self, msg: Vec<u8>) -> Result<()> {
self.sent
.lock()
.unwrap()
.push(serde_json::from_slice(&msg)?);
Ok(())
}
async fn recv(&mut self) -> Result<Vec<u8>> {
self.responses
.pop_front()
.context("scripted transport exhausted")
}
}
struct HangingValueTransport {
sent: Arc<Mutex<Vec<serde_json::Value>>>,
}
#[async_trait::async_trait]
impl McpTransport for HangingValueTransport {
async fn send(&mut self, msg: Vec<u8>) -> Result<()> {
self.sent
.lock()
.unwrap()
.push(serde_json::from_slice(&msg)?);
Ok(())
}
async fn recv(&mut self) -> Result<Vec<u8>> {
std::future::pending().await
}
}
fn test_server_config() -> McpServerConfig {
McpServerConfig {
command: Some("mock".to_string()),
args: Vec::new(),
env: HashMap::new(),
url: None,
connect_timeout: None,
execute_timeout: None,
read_timeout: None,
disabled: false,
enabled: true,
required: false,
enabled_tools: Vec::new(),
disabled_tools: Vec::new(),
}
}
fn test_connection(transport: Box<dyn McpTransport>) -> McpConnection {
McpConnection {
name: "mock".to_string(),
transport,
tools: Vec::new(),
resources: Vec::new(),
resource_templates: Vec::new(),
prompts: Vec::new(),
request_id: AtomicU64::new(1),
state: ConnectionState::Ready,
config: test_server_config(),
cancel_token: tokio_util::sync::CancellationToken::new(),
}
}
fn json_frame(value: serde_json::Value) -> Vec<u8> {
serde_json::to_vec(&value).unwrap()
}
#[tokio::test]
async fn call_method_skips_notifications_and_unmatched_responses() {
let sent = Arc::new(Mutex::new(Vec::new()));
let transport = ScriptedValueTransport {
sent: Arc::clone(&sent),
responses: VecDeque::from([
json_frame(serde_json::json!({
"jsonrpc": "2.0",
"method": "notifications/progress",
"params": {"progress": 0.5}
})),
json_frame(serde_json::json!({
"jsonrpc": "2.0",
"id": 99,
"result": {"ignored": true}
})),
json_frame(serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"result": {"ok": true}
})),
]),
};
let mut conn = test_connection(Box::new(transport));
let result = conn
.call_method("tools/call", serde_json::json!({"name": "echo"}), 1)
.await
.unwrap();
assert_eq!(result, serde_json::json!({"ok": true}));
let sent = sent.lock().unwrap();
assert_eq!(sent.len(), 1);
assert_eq!(sent[0]["jsonrpc"], "2.0");
assert_eq!(sent[0]["id"], 1);
assert_eq!(sent[0]["method"], "tools/call");
}
#[tokio::test]
async fn call_method_times_out_while_waiting_for_response() {
let sent = Arc::new(Mutex::new(Vec::new()));
let mut conn = test_connection(Box::new(HangingValueTransport {
sent: Arc::clone(&sent),
}));
let err = conn
.call_method("tools/call", serde_json::json!({"name": "echo"}), 0)
.await
.expect_err("hung receive should time out");
assert!(
err.to_string()
.contains("MCP method 'tools/call' on server 'mock' timed out after 0s"),
"unexpected error: {err:#}"
);
assert_eq!(sent.lock().unwrap().len(), 1);
}
#[tokio::test]
async fn test_mcp_pool_empty_config() {
let pool = McpPool::new(McpConfig::default());
@@ -2442,12 +2582,13 @@ mod tests {
}
#[test]
fn parse_sse_json_messages_extracts_message_events() {
fn parse_sse_message_data_extracts_message_events() {
let body = "event: message\r\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}\r\n\r\n";
let messages = parse_sse_json_messages(body).unwrap();
let messages = parse_sse_message_data(body);
assert_eq!(messages.len(), 1);
assert_eq!(messages[0]["id"], 1);
assert!(messages[0].get("result").is_some());
let value: serde_json::Value = serde_json::from_slice(&messages[0]).unwrap();
assert_eq!(value["id"], 1);
assert!(value.get("result").is_some());
}
#[tokio::test]
@@ -2736,11 +2877,11 @@ mod tests {
.unwrap();
transport
.send(serde_json::json!({
.send(json_frame(serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize"
}))
})))
.await
.unwrap();