refactor(mcp): centralize json-rpc framing
This commit is contained in:
+208
-67
@@ -274,8 +274,8 @@ pub enum ConnectionState {
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait McpTransport: Send + Sync {
|
||||
async fn send(&mut self, msg: serde_json::Value) -> Result<()>;
|
||||
async fn recv(&mut self) -> Result<serde_json::Value>;
|
||||
async fn send(&mut self, msg: Vec<u8>) -> Result<()>;
|
||||
async fn recv(&mut self) -> Result<Vec<u8>>;
|
||||
|
||||
/// Graceful shutdown — stdio transports send SIGTERM to the child and
|
||||
/// give it a brief window to exit before tokio's `kill_on_drop` fires
|
||||
@@ -323,14 +323,14 @@ fn send_sigterm(child: &Child) -> bool {
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl McpTransport for StdioTransport {
|
||||
async fn send(&mut self, msg: serde_json::Value) -> Result<()> {
|
||||
let line = serde_json::to_string(&msg)? + "\n";
|
||||
self.stdin.write_all(line.as_bytes()).await?;
|
||||
async fn send(&mut self, mut msg: Vec<u8>) -> Result<()> {
|
||||
msg.push(b'\n');
|
||||
self.stdin.write_all(&msg).await?;
|
||||
self.stdin.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recv(&mut self) -> Result<serde_json::Value> {
|
||||
async fn recv(&mut self) -> Result<Vec<u8>> {
|
||||
let mut line = String::new();
|
||||
loop {
|
||||
line.clear();
|
||||
@@ -344,9 +344,7 @@ impl McpTransport for StdioTransport {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(trimmed) {
|
||||
return Ok(value);
|
||||
}
|
||||
return Ok(trimmed.as_bytes().to_vec());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -374,8 +372,13 @@ pub struct SseTransport {
|
||||
client: reqwest::Client,
|
||||
base_url: String,
|
||||
endpoint_url: Option<String>,
|
||||
receiver: tokio::sync::mpsc::UnboundedReceiver<serde_json::Value>,
|
||||
pending_messages: VecDeque<serde_json::Value>,
|
||||
receiver: tokio::sync::mpsc::UnboundedReceiver<SseInbound>,
|
||||
pending_messages: VecDeque<Vec<u8>>,
|
||||
}
|
||||
|
||||
enum SseInbound {
|
||||
Endpoint(String),
|
||||
Message(Vec<u8>),
|
||||
}
|
||||
|
||||
struct HttpTransport {
|
||||
@@ -394,7 +397,7 @@ enum HttpTransportMode {
|
||||
struct StreamableHttpTransport {
|
||||
client: reqwest::Client,
|
||||
url: String,
|
||||
pending_messages: VecDeque<serde_json::Value>,
|
||||
pending_messages: VecDeque<Vec<u8>>,
|
||||
}
|
||||
|
||||
enum StreamableSendError {
|
||||
@@ -461,7 +464,7 @@ impl SseTransport {
|
||||
async fn run_sse_loop(
|
||||
client: reqwest::Client,
|
||||
url: String,
|
||||
tx: tokio::sync::mpsc::UnboundedSender<serde_json::Value>,
|
||||
tx: tokio::sync::mpsc::UnboundedSender<SseInbound>,
|
||||
cancel_token: tokio_util::sync::CancellationToken,
|
||||
) -> Result<()> {
|
||||
let response = client.get(&url).send().await.with_context(|| {
|
||||
@@ -523,14 +526,11 @@ impl SseTransport {
|
||||
|
||||
match event_type {
|
||||
"endpoint" => {
|
||||
// Special internal message to set endpoint
|
||||
let _ = tx.send(serde_json::json!({
|
||||
"__internal_sse_endpoint__": data
|
||||
}));
|
||||
let _ = tx.send(SseInbound::Endpoint(data));
|
||||
}
|
||||
"message" => {
|
||||
if let Ok(val) = serde_json::from_str::<serde_json::Value>(&data) {
|
||||
let _ = tx.send(val);
|
||||
if !data.trim().is_empty() {
|
||||
let _ = tx.send(SseInbound::Message(data.into_bytes()));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
@@ -564,21 +564,19 @@ impl SseTransport {
|
||||
}
|
||||
};
|
||||
|
||||
if self.store_endpoint_from_internal_message(&msg)? {
|
||||
return Ok(());
|
||||
match msg {
|
||||
SseInbound::Endpoint(endpoint) => {
|
||||
self.store_endpoint(&endpoint)?;
|
||||
return Ok(());
|
||||
}
|
||||
SseInbound::Message(msg) => self.pending_messages.push_back(msg),
|
||||
}
|
||||
|
||||
self.pending_messages.push_back(msg);
|
||||
}
|
||||
}
|
||||
|
||||
fn store_endpoint_from_internal_message(&mut self, msg: &serde_json::Value) -> Result<bool> {
|
||||
let Some(endpoint) = msg.get("__internal_sse_endpoint__") else {
|
||||
return Ok(false);
|
||||
};
|
||||
let url_str = endpoint.as_str().context("Invalid endpoint format")?;
|
||||
self.endpoint_url = Some(Self::resolve_endpoint_url(&self.base_url, url_str)?);
|
||||
Ok(true)
|
||||
fn store_endpoint(&mut self, endpoint: &str) -> Result<()> {
|
||||
self.endpoint_url = Some(Self::resolve_endpoint_url(&self.base_url, endpoint)?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn resolve_endpoint_url(base_url: &str, endpoint_url: &str) -> Result<String> {
|
||||
@@ -610,7 +608,7 @@ impl HttpTransport {
|
||||
}
|
||||
}
|
||||
|
||||
async fn switch_to_sse_and_send(&mut self, msg: serde_json::Value) -> Result<()> {
|
||||
async fn switch_to_sse_and_send(&mut self, msg: Vec<u8>) -> Result<()> {
|
||||
let mut sse = SseTransport::connect(
|
||||
self.client.clone(),
|
||||
self.base_url.clone(),
|
||||
@@ -626,7 +624,7 @@ impl HttpTransport {
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl McpTransport for HttpTransport {
|
||||
async fn send(&mut self, msg: serde_json::Value) -> Result<()> {
|
||||
async fn send(&mut self, msg: Vec<u8>) -> Result<()> {
|
||||
match &mut self.mode {
|
||||
HttpTransportMode::Streamable(transport) => match transport.send(msg.clone()).await {
|
||||
Ok(()) => Ok(()),
|
||||
@@ -643,7 +641,7 @@ impl McpTransport for HttpTransport {
|
||||
}
|
||||
}
|
||||
|
||||
async fn recv(&mut self) -> Result<serde_json::Value> {
|
||||
async fn recv(&mut self) -> Result<Vec<u8>> {
|
||||
match &mut self.mode {
|
||||
HttpTransportMode::Streamable(transport) => transport.recv().await,
|
||||
HttpTransportMode::Sse(transport) => transport.recv().await,
|
||||
@@ -666,15 +664,13 @@ impl StreamableHttpTransport {
|
||||
}
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&mut self,
|
||||
msg: serde_json::Value,
|
||||
) -> std::result::Result<(), StreamableSendError> {
|
||||
async fn send(&mut self, msg: Vec<u8>) -> std::result::Result<(), StreamableSendError> {
|
||||
let response = self
|
||||
.client
|
||||
.post(&self.url)
|
||||
.header(ACCEPT, "application/json, text/event-stream")
|
||||
.json(&msg)
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(msg)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| StreamableSendError::Other(err.into()))?;
|
||||
@@ -712,7 +708,7 @@ impl StreamableHttpTransport {
|
||||
.map_err(StreamableSendError::Other)
|
||||
}
|
||||
|
||||
async fn recv(&mut self) -> Result<serde_json::Value> {
|
||||
async fn recv(&mut self) -> Result<Vec<u8>> {
|
||||
self.pending_messages
|
||||
.pop_front()
|
||||
.context("MCP Streamable HTTP response queue is empty")
|
||||
@@ -730,14 +726,13 @@ impl StreamableHttpTransport {
|
||||
|| body.trim_start().starts_with("data:");
|
||||
|
||||
if is_event_stream {
|
||||
for msg in parse_sse_json_messages(body)? {
|
||||
for msg in parse_sse_message_data(body) {
|
||||
self.pending_messages.push_back(msg);
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.pending_messages
|
||||
.push_back(serde_json::from_str(body).context("Invalid MCP Streamable HTTP JSON")?);
|
||||
self.pending_messages.push_back(body.as_bytes().to_vec());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -753,7 +748,7 @@ fn is_streamable_http_incompatible_status(status: StatusCode) -> bool {
|
||||
)
|
||||
}
|
||||
|
||||
fn parse_sse_json_messages(body: &str) -> Result<Vec<serde_json::Value>> {
|
||||
fn parse_sse_message_data(body: &str) -> Vec<Vec<u8>> {
|
||||
let normalized = body.replace("\r\n", "\n");
|
||||
let mut messages = Vec::new();
|
||||
|
||||
@@ -776,13 +771,10 @@ fn parse_sse_json_messages(body: &str) -> Result<Vec<serde_json::Value>> {
|
||||
continue;
|
||||
}
|
||||
|
||||
messages.push(
|
||||
serde_json::from_str(data.trim())
|
||||
.with_context(|| format!("Invalid MCP SSE message data: {}", data.trim()))?,
|
||||
);
|
||||
messages.push(data.trim().as_bytes().to_vec());
|
||||
}
|
||||
|
||||
Ok(messages)
|
||||
messages
|
||||
}
|
||||
|
||||
fn sse_field_value<'a>(line: &'a str, field: &str) -> Option<&'a str> {
|
||||
@@ -792,29 +784,36 @@ fn sse_field_value<'a>(line: &'a str, field: &str) -> Option<&'a str> {
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl McpTransport for SseTransport {
|
||||
async fn send(&mut self, msg: serde_json::Value) -> Result<()> {
|
||||
async fn send(&mut self, msg: Vec<u8>) -> Result<()> {
|
||||
let endpoint = self
|
||||
.endpoint_url
|
||||
.as_ref()
|
||||
.context("SSE endpoint not yet discovered")?;
|
||||
let response = self.client.post(endpoint).json(&msg).send().await?;
|
||||
let response = self
|
||||
.client
|
||||
.post(endpoint)
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(msg)
|
||||
.send()
|
||||
.await?;
|
||||
if !response.status().is_success() {
|
||||
anyhow::bail!("Failed to send message via SSE POST: {}", response.status());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recv(&mut self) -> Result<serde_json::Value> {
|
||||
async fn recv(&mut self) -> Result<Vec<u8>> {
|
||||
loop {
|
||||
let msg = if let Some(msg) = self.pending_messages.pop_front() {
|
||||
msg
|
||||
} else {
|
||||
self.receiver.recv().await.context("SSE transport closed")?
|
||||
};
|
||||
if self.store_endpoint_from_internal_message(&msg)? {
|
||||
continue;
|
||||
if let Some(msg) = self.pending_messages.pop_front() {
|
||||
return Ok(msg);
|
||||
}
|
||||
|
||||
match self.receiver.recv().await.context("SSE transport closed")? {
|
||||
SseInbound::Endpoint(endpoint) => {
|
||||
self.store_endpoint(&endpoint)?;
|
||||
}
|
||||
SseInbound::Message(msg) => return Ok(msg),
|
||||
}
|
||||
return Ok(msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1299,14 +1298,18 @@ impl McpConnection {
|
||||
}
|
||||
|
||||
async fn send(&mut self, msg: serde_json::Value) -> Result<()> {
|
||||
self.transport.send(msg).await
|
||||
let bytes = serde_json::to_vec(&msg).context("Failed to serialize MCP JSON-RPC message")?;
|
||||
self.transport.send(bytes).await
|
||||
}
|
||||
|
||||
async fn recv(&mut self, expected_id: u64) -> Result<serde_json::Value> {
|
||||
loop {
|
||||
let value = self.transport.recv().await.inspect_err(|_e| {
|
||||
let bytes = self.transport.recv().await.inspect_err(|_e| {
|
||||
self.state = ConnectionState::Disconnected;
|
||||
})?;
|
||||
let value: serde_json::Value = serde_json::from_slice(&bytes).with_context(|| {
|
||||
format!("Invalid MCP JSON-RPC message from server '{}'", self.name)
|
||||
})?;
|
||||
|
||||
// Check if this is a response with the expected id
|
||||
if value.get("id").and_then(serde_json::Value::as_u64) == Some(expected_id) {
|
||||
@@ -2214,6 +2217,8 @@ pub fn format_tool_result(result: &serde_json::Value) -> String {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
#[test]
|
||||
fn test_mcp_config_defaults() {
|
||||
@@ -2393,6 +2398,141 @@ mod tests {
|
||||
assert!(formatted.contains("[image content]"));
|
||||
}
|
||||
|
||||
struct ScriptedValueTransport {
|
||||
sent: Arc<Mutex<Vec<serde_json::Value>>>,
|
||||
responses: VecDeque<Vec<u8>>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl McpTransport for ScriptedValueTransport {
|
||||
async fn send(&mut self, msg: Vec<u8>) -> Result<()> {
|
||||
self.sent
|
||||
.lock()
|
||||
.unwrap()
|
||||
.push(serde_json::from_slice(&msg)?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recv(&mut self) -> Result<Vec<u8>> {
|
||||
self.responses
|
||||
.pop_front()
|
||||
.context("scripted transport exhausted")
|
||||
}
|
||||
}
|
||||
|
||||
struct HangingValueTransport {
|
||||
sent: Arc<Mutex<Vec<serde_json::Value>>>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl McpTransport for HangingValueTransport {
|
||||
async fn send(&mut self, msg: Vec<u8>) -> Result<()> {
|
||||
self.sent
|
||||
.lock()
|
||||
.unwrap()
|
||||
.push(serde_json::from_slice(&msg)?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recv(&mut self) -> Result<Vec<u8>> {
|
||||
std::future::pending().await
|
||||
}
|
||||
}
|
||||
|
||||
fn test_server_config() -> McpServerConfig {
|
||||
McpServerConfig {
|
||||
command: Some("mock".to_string()),
|
||||
args: Vec::new(),
|
||||
env: HashMap::new(),
|
||||
url: None,
|
||||
connect_timeout: None,
|
||||
execute_timeout: None,
|
||||
read_timeout: None,
|
||||
disabled: false,
|
||||
enabled: true,
|
||||
required: false,
|
||||
enabled_tools: Vec::new(),
|
||||
disabled_tools: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn test_connection(transport: Box<dyn McpTransport>) -> McpConnection {
|
||||
McpConnection {
|
||||
name: "mock".to_string(),
|
||||
transport,
|
||||
tools: Vec::new(),
|
||||
resources: Vec::new(),
|
||||
resource_templates: Vec::new(),
|
||||
prompts: Vec::new(),
|
||||
request_id: AtomicU64::new(1),
|
||||
state: ConnectionState::Ready,
|
||||
config: test_server_config(),
|
||||
cancel_token: tokio_util::sync::CancellationToken::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn json_frame(value: serde_json::Value) -> Vec<u8> {
|
||||
serde_json::to_vec(&value).unwrap()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn call_method_skips_notifications_and_unmatched_responses() {
|
||||
let sent = Arc::new(Mutex::new(Vec::new()));
|
||||
let transport = ScriptedValueTransport {
|
||||
sent: Arc::clone(&sent),
|
||||
responses: VecDeque::from([
|
||||
json_frame(serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/progress",
|
||||
"params": {"progress": 0.5}
|
||||
})),
|
||||
json_frame(serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 99,
|
||||
"result": {"ignored": true}
|
||||
})),
|
||||
json_frame(serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"result": {"ok": true}
|
||||
})),
|
||||
]),
|
||||
};
|
||||
let mut conn = test_connection(Box::new(transport));
|
||||
|
||||
let result = conn
|
||||
.call_method("tools/call", serde_json::json!({"name": "echo"}), 1)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, serde_json::json!({"ok": true}));
|
||||
let sent = sent.lock().unwrap();
|
||||
assert_eq!(sent.len(), 1);
|
||||
assert_eq!(sent[0]["jsonrpc"], "2.0");
|
||||
assert_eq!(sent[0]["id"], 1);
|
||||
assert_eq!(sent[0]["method"], "tools/call");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn call_method_times_out_while_waiting_for_response() {
|
||||
let sent = Arc::new(Mutex::new(Vec::new()));
|
||||
let mut conn = test_connection(Box::new(HangingValueTransport {
|
||||
sent: Arc::clone(&sent),
|
||||
}));
|
||||
|
||||
let err = conn
|
||||
.call_method("tools/call", serde_json::json!({"name": "echo"}), 0)
|
||||
.await
|
||||
.expect_err("hung receive should time out");
|
||||
|
||||
assert!(
|
||||
err.to_string()
|
||||
.contains("MCP method 'tools/call' on server 'mock' timed out after 0s"),
|
||||
"unexpected error: {err:#}"
|
||||
);
|
||||
assert_eq!(sent.lock().unwrap().len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mcp_pool_empty_config() {
|
||||
let pool = McpPool::new(McpConfig::default());
|
||||
@@ -2442,12 +2582,13 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_sse_json_messages_extracts_message_events() {
|
||||
fn parse_sse_message_data_extracts_message_events() {
|
||||
let body = "event: message\r\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}\r\n\r\n";
|
||||
let messages = parse_sse_json_messages(body).unwrap();
|
||||
let messages = parse_sse_message_data(body);
|
||||
assert_eq!(messages.len(), 1);
|
||||
assert_eq!(messages[0]["id"], 1);
|
||||
assert!(messages[0].get("result").is_some());
|
||||
let value: serde_json::Value = serde_json::from_slice(&messages[0]).unwrap();
|
||||
assert_eq!(value["id"], 1);
|
||||
assert!(value.get("result").is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -2736,11 +2877,11 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
transport
|
||||
.send(serde_json::json!({
|
||||
.send(json_frame(serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize"
|
||||
}))
|
||||
})))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user