Retry MCP calls after stale SSE connections
This commit is contained in:
committed by
Hunter Bown
parent
58c57cb798
commit
d26c2128b8
+591
-1
@@ -610,6 +610,7 @@ struct StreamableHttpTransport {
|
||||
#[derive(Debug)]
|
||||
enum StreamableSendError {
|
||||
Incompatible(String),
|
||||
StaleSession(String),
|
||||
Other(anyhow::Error),
|
||||
}
|
||||
|
||||
@@ -922,6 +923,19 @@ impl McpTransport for HttpTransport {
|
||||
);
|
||||
self.switch_to_sse_and_send(msg).await
|
||||
}
|
||||
Err(StreamableSendError::StaleSession(detail)) => {
|
||||
if let HttpTransportMode::Streamable(transport) = &mut self.mode {
|
||||
tracing::debug!(
|
||||
target: "mcp",
|
||||
error = %detail,
|
||||
"MCP Streamable HTTP session expired; clearing cached session ID"
|
||||
);
|
||||
transport.session_id = None;
|
||||
}
|
||||
Err(anyhow::anyhow!(
|
||||
"MCP Streamable HTTP session expired; retry with a new session required ({detail})"
|
||||
))
|
||||
}
|
||||
Err(StreamableSendError::Other(err)) => Err(err),
|
||||
},
|
||||
HttpTransportMode::Sse(transport) => transport.send(msg).await,
|
||||
@@ -992,6 +1006,13 @@ impl StreamableHttpTransport {
|
||||
|
||||
if !status.is_success() {
|
||||
let body_excerpt = bounded_body_excerpt(response, ERROR_BODY_PREVIEW_BYTES).await;
|
||||
if self.session_id.is_some()
|
||||
&& is_streamable_http_stale_session_status(status, &body_excerpt)
|
||||
{
|
||||
return Err(StreamableSendError::StaleSession(format!(
|
||||
"status={status} body={body_excerpt}"
|
||||
)));
|
||||
}
|
||||
if is_streamable_http_incompatible_status(status) {
|
||||
return Err(StreamableSendError::Incompatible(format!(
|
||||
"status={status} body={body_excerpt}"
|
||||
@@ -1058,6 +1079,30 @@ fn is_streamable_http_incompatible_status(status: StatusCode) -> bool {
|
||||
)
|
||||
}
|
||||
|
||||
fn is_streamable_http_stale_session_status(status: StatusCode, body_excerpt: &str) -> bool {
|
||||
if status == StatusCode::NOT_FOUND {
|
||||
return true;
|
||||
}
|
||||
if status != StatusCode::BAD_REQUEST && status != StatusCode::UNAUTHORIZED {
|
||||
return false;
|
||||
}
|
||||
let body = body_excerpt.to_ascii_lowercase();
|
||||
body.contains("session") && (body.contains("expired") || body.contains("invalid"))
|
||||
}
|
||||
|
||||
fn is_mcp_stale_session_body(body: &str) -> bool {
|
||||
let body = body.to_ascii_lowercase();
|
||||
body.contains("session") && (body.contains("expired") || body.contains("invalid"))
|
||||
}
|
||||
|
||||
fn is_mcp_stale_session_error(err: &anyhow::Error) -> bool {
|
||||
let err = format!("{err:#}");
|
||||
err.contains("MCP Streamable HTTP session expired")
|
||||
|| err.contains("MCP session expired")
|
||||
|| err.contains("SSE transport closed")
|
||||
|| is_mcp_stale_session_body(&err)
|
||||
}
|
||||
|
||||
fn parse_sse_message_data(body: &str) -> Vec<Vec<u8>> {
|
||||
let normalized = body.replace("\r\n", "\n");
|
||||
let mut messages = Vec::new();
|
||||
@@ -1148,6 +1193,14 @@ impl McpTransport for SseTransport {
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let body_excerpt = bounded_body_excerpt(response, ERROR_BODY_PREVIEW_BYTES).await;
|
||||
if is_mcp_stale_session_body(&body_excerpt) {
|
||||
anyhow::bail!(
|
||||
"MCP session expired (transport=sse endpoint={} status={}): {}",
|
||||
mask_url_secrets(endpoint),
|
||||
status,
|
||||
body_excerpt
|
||||
);
|
||||
}
|
||||
anyhow::bail!(
|
||||
"MCP SSE POST rejected (transport=sse endpoint={} status={}): {}",
|
||||
mask_url_secrets(endpoint),
|
||||
@@ -1779,6 +1832,11 @@ impl McpConnection {
|
||||
// IDs, but accept numeric echoes for compatibility with older
|
||||
// servers and tests.
|
||||
if response_id_matches(value.get("id"), &expected_id) {
|
||||
if let Some(error) = value.get("error") {
|
||||
if is_mcp_stale_session_body(&error.to_string()) {
|
||||
anyhow::bail!("MCP session expired: {error}");
|
||||
}
|
||||
}
|
||||
return Ok(value);
|
||||
}
|
||||
// Skip notifications (no id) and responses with different ids
|
||||
@@ -2359,7 +2417,26 @@ impl McpPool {
|
||||
anyhow::bail!("MCP tool '{tool_name}' is disabled for server '{server_name}'");
|
||||
}
|
||||
let timeout = conn.config().effective_execute_timeout(&global_timeouts);
|
||||
conn.call_tool(tool_name, arguments, timeout).await
|
||||
match conn.call_tool(tool_name, arguments.clone(), timeout).await {
|
||||
Ok(result) => Ok(result),
|
||||
Err(err) if is_mcp_stale_session_error(&err) => {
|
||||
tracing::debug!(
|
||||
target: "mcp",
|
||||
server = server_name,
|
||||
tool = tool_name,
|
||||
error = %err,
|
||||
"retrying MCP tool call after stale session"
|
||||
);
|
||||
self.connections.remove(server_name);
|
||||
let conn = self.get_or_connect(server_name).await?;
|
||||
if !conn.config().is_tool_enabled(tool_name) {
|
||||
anyhow::bail!("MCP tool '{tool_name}' is disabled for server '{server_name}'");
|
||||
}
|
||||
let timeout = conn.config().effective_execute_timeout(&global_timeouts);
|
||||
conn.call_tool(tool_name, arguments, timeout).await
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get list of configured server names
|
||||
@@ -3500,6 +3577,94 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn json_rpc_session_error_is_marked_stale() {
|
||||
let sent = Arc::new(Mutex::new(Vec::new()));
|
||||
let transport = ScriptedValueTransport {
|
||||
sent: Arc::clone(&sent),
|
||||
responses: VecDeque::from([json_frame(serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"error": {
|
||||
"code": -32001,
|
||||
"message": "MCP session expired"
|
||||
}
|
||||
}))]),
|
||||
};
|
||||
let mut conn = test_connection(Box::new(transport));
|
||||
|
||||
let err = conn
|
||||
.call_tool("search", serde_json::json!({"query": "dephy"}), 1)
|
||||
.await
|
||||
.expect_err("session error should fail");
|
||||
|
||||
assert!(
|
||||
is_mcp_stale_session_error(&err),
|
||||
"JSON-RPC session error should be retryable, got: {err:#}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sse_transport_closed_is_retryable() {
|
||||
let err = anyhow::anyhow!("SSE transport closed");
|
||||
assert!(
|
||||
is_mcp_stale_session_error(&err),
|
||||
"closed SSE stream should force reconnect before retry"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn discover_all_ignores_unsupported_optional_capabilities() {
|
||||
let sent = Arc::new(Mutex::new(Vec::new()));
|
||||
let transport = ScriptedValueTransport {
|
||||
sent: Arc::clone(&sent),
|
||||
responses: VecDeque::from([
|
||||
json_frame(serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"result": {
|
||||
"tools": [
|
||||
{ "name": "search", "inputSchema": {} }
|
||||
]
|
||||
}
|
||||
})),
|
||||
json_frame(serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"error": {
|
||||
"code": -32601,
|
||||
"message": "resources not supported"
|
||||
}
|
||||
})),
|
||||
json_frame(serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"error": {
|
||||
"code": -32601,
|
||||
"message": "resource templates not supported"
|
||||
}
|
||||
})),
|
||||
json_frame(serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 4,
|
||||
"error": {
|
||||
"code": -32601,
|
||||
"message": "prompts not supported"
|
||||
}
|
||||
})),
|
||||
]),
|
||||
};
|
||||
let mut conn = test_connection(Box::new(transport));
|
||||
|
||||
conn.discover_all().await.expect("discover");
|
||||
|
||||
assert_eq!(conn.tools.len(), 1);
|
||||
assert_eq!(conn.tools[0].name, "search");
|
||||
assert!(conn.resources.is_empty());
|
||||
assert!(conn.resource_templates.is_empty());
|
||||
assert!(conn.prompts.is_empty());
|
||||
}
|
||||
|
||||
/// #1244: when an MCP stdio server fails to spawn, the underlying OS
|
||||
/// error (e.g. ENOENT for a missing binary) must reach the user via the
|
||||
/// snapshot.error string. Regression test for `err.to_string()` dropping
|
||||
@@ -4277,6 +4442,431 @@ mod tests {
|
||||
server.abort();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn streamable_http_stale_session_reconnects_and_retries_tool_call() {
|
||||
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let get_count = Arc::new(AtomicUsize::new(0));
|
||||
let stale_seen = Arc::new(AtomicBool::new(false));
|
||||
let success_seen = Arc::new(AtomicBool::new(false));
|
||||
let server_get_count = Arc::clone(&get_count);
|
||||
let server_stale_seen = Arc::clone(&stale_seen);
|
||||
let server_success_seen = Arc::clone(&success_seen);
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
loop {
|
||||
let Ok((mut socket, _)) = listener.accept().await else {
|
||||
break;
|
||||
};
|
||||
let get_count = Arc::clone(&server_get_count);
|
||||
let stale_seen = Arc::clone(&server_stale_seen);
|
||||
let success_seen = Arc::clone(&server_success_seen);
|
||||
tokio::spawn(async move {
|
||||
let mut request = Vec::new();
|
||||
let mut buf = [0; 4096];
|
||||
let header_end = loop {
|
||||
let n = socket.read(&mut buf).await.unwrap();
|
||||
if n == 0 {
|
||||
return;
|
||||
}
|
||||
request.extend_from_slice(&buf[..n]);
|
||||
if let Some(pos) = request.windows(4).position(|w| w == b"\r\n\r\n") {
|
||||
break pos + 4;
|
||||
}
|
||||
};
|
||||
let headers = String::from_utf8_lossy(&request[..header_end]).to_string();
|
||||
let content_length = headers
|
||||
.lines()
|
||||
.find_map(|line| {
|
||||
let (name, value) = line.split_once(':')?;
|
||||
name.eq_ignore_ascii_case("content-length")
|
||||
.then(|| value.trim().parse::<usize>().ok())
|
||||
.flatten()
|
||||
})
|
||||
.unwrap_or(0);
|
||||
while request.len() < header_end + content_length {
|
||||
let n = socket.read(&mut buf).await.unwrap();
|
||||
if n == 0 {
|
||||
return;
|
||||
}
|
||||
request.extend_from_slice(&buf[..n]);
|
||||
}
|
||||
let body = &request[header_end..header_end + content_length];
|
||||
let session_header = headers.lines().find_map(|line| {
|
||||
let (name, value) = line.split_once(':')?;
|
||||
name.eq_ignore_ascii_case("mcp-session-id")
|
||||
.then(|| value.trim().to_string())
|
||||
});
|
||||
|
||||
if headers.starts_with("GET /mcp ") {
|
||||
let count = get_count.fetch_add(1, AtomicOrdering::SeqCst);
|
||||
let session = if count == 0 { "sess-old" } else { "sess-new" };
|
||||
let response = format!(
|
||||
"HTTP/1.1 200 OK\r\nMcp-Session-Id: {session}\r\nContent-Length: 0\r\n\r\n"
|
||||
);
|
||||
socket.write_all(response.as_bytes()).await.unwrap();
|
||||
return;
|
||||
}
|
||||
|
||||
let request_json: serde_json::Value = serde_json::from_slice(body).unwrap();
|
||||
let method = request_json
|
||||
.get("method")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.unwrap_or("");
|
||||
let id = request_json
|
||||
.get("id")
|
||||
.cloned()
|
||||
.unwrap_or_else(|| serde_json::json!("0"));
|
||||
|
||||
if method == "tools/call" && session_header.as_deref() == Some("sess-old") {
|
||||
stale_seen.store(true, AtomicOrdering::SeqCst);
|
||||
socket
|
||||
.write_all(
|
||||
b"HTTP/1.1 404 Not Found\r\nContent-Type: application/json\r\nContent-Length: 27\r\n\r\n{\"error\":\"session expired\"}",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
return;
|
||||
}
|
||||
|
||||
let result = match method {
|
||||
"initialize" => serde_json::json!({
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {}
|
||||
}),
|
||||
"tools/list" => serde_json::json!({
|
||||
"tools": [
|
||||
{ "name": "search", "inputSchema": {} }
|
||||
]
|
||||
}),
|
||||
"resources/list" => serde_json::json!({ "resources": [] }),
|
||||
"resources/templates/list" => {
|
||||
serde_json::json!({ "resourceTemplates": [] })
|
||||
}
|
||||
"prompts/list" => serde_json::json!({ "prompts": [] }),
|
||||
"tools/call" => {
|
||||
assert_eq!(session_header.as_deref(), Some("sess-new"));
|
||||
success_seen.store(true, AtomicOrdering::SeqCst);
|
||||
serde_json::json!({ "content": [{ "type": "text", "text": "ok" }] })
|
||||
}
|
||||
_ => {
|
||||
socket
|
||||
.write_all(b"HTTP/1.1 202 Accepted\r\nContent-Length: 0\r\n\r\n")
|
||||
.await
|
||||
.unwrap();
|
||||
return;
|
||||
}
|
||||
};
|
||||
let response_body = serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": result
|
||||
})
|
||||
.to_string();
|
||||
let response = format!(
|
||||
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
|
||||
response_body.len(),
|
||||
response_body
|
||||
);
|
||||
socket.write_all(response.as_bytes()).await.unwrap();
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
let mut cfg = McpConfig::default();
|
||||
cfg.servers.insert(
|
||||
"dephy".to_string(),
|
||||
McpServerConfig {
|
||||
command: None,
|
||||
args: Vec::new(),
|
||||
env: HashMap::new(),
|
||||
url: Some(format!("http://{addr}/mcp")),
|
||||
transport: None,
|
||||
connect_timeout: Some(2),
|
||||
execute_timeout: Some(2),
|
||||
read_timeout: None,
|
||||
disabled: false,
|
||||
enabled: true,
|
||||
required: false,
|
||||
enabled_tools: Vec::new(),
|
||||
disabled_tools: Vec::new(),
|
||||
headers: HashMap::new(),
|
||||
},
|
||||
);
|
||||
let mut pool = McpPool::new(cfg);
|
||||
|
||||
let result = pool
|
||||
.call_tool("mcp_dephy_search", serde_json::json!({ "query": "dephy" }))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
result,
|
||||
serde_json::json!({ "content": [{ "type": "text", "text": "ok" }] })
|
||||
);
|
||||
assert!(stale_seen.load(AtomicOrdering::SeqCst));
|
||||
assert!(success_seen.load(AtomicOrdering::SeqCst));
|
||||
assert_eq!(get_count.load(AtomicOrdering::SeqCst), 2);
|
||||
|
||||
server.abort();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn legacy_sse_session_expiry_is_marked_stale() {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
let (mut socket, _) = listener.accept().await.unwrap();
|
||||
let mut request = Vec::new();
|
||||
let mut buf = [0; 4096];
|
||||
let header_end = loop {
|
||||
let n = socket.read(&mut buf).await.unwrap();
|
||||
if n == 0 {
|
||||
return;
|
||||
}
|
||||
request.extend_from_slice(&buf[..n]);
|
||||
if let Some(pos) = request.windows(4).position(|w| w == b"\r\n\r\n") {
|
||||
break pos + 4;
|
||||
}
|
||||
};
|
||||
let headers = String::from_utf8_lossy(&request[..header_end]);
|
||||
assert!(headers.starts_with("POST /messages "));
|
||||
socket
|
||||
.write_all(
|
||||
b"HTTP/1.1 400 Bad Request\r\nContent-Type: application/json\r\nContent-Length: 27\r\n\r\n{\"error\":\"session expired\"}",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let (_sender, receiver) = mpsc::unbounded_channel();
|
||||
let mut transport = SseTransport {
|
||||
client: reqwest::Client::new(),
|
||||
base_url: format!("http://{addr}/sse"),
|
||||
headers: HashMap::new(),
|
||||
endpoint_url: Some(format!("http://{addr}/messages")),
|
||||
receiver,
|
||||
pending_messages: VecDeque::new(),
|
||||
};
|
||||
|
||||
let err = transport
|
||||
.send(br#"{"jsonrpc":"2.0","id":1,"method":"tools/call"}"#.to_vec())
|
||||
.await
|
||||
.expect_err("expired SSE session should fail");
|
||||
|
||||
assert!(
|
||||
is_mcp_stale_session_error(&err),
|
||||
"SSE session expiry should be retryable, got: {err:#}"
|
||||
);
|
||||
|
||||
server.abort();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn legacy_sse_closed_stream_reconnects_and_retries_tool_call() {
|
||||
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
async fn read_http_request(socket: &mut TcpStream) -> (String, serde_json::Value) {
|
||||
let mut request = Vec::new();
|
||||
let mut buf = [0; 4096];
|
||||
let header_end = loop {
|
||||
let n = socket.read(&mut buf).await.unwrap();
|
||||
if n == 0 {
|
||||
return (String::new(), serde_json::Value::Null);
|
||||
}
|
||||
request.extend_from_slice(&buf[..n]);
|
||||
if let Some(pos) = request.windows(4).position(|w| w == b"\r\n\r\n") {
|
||||
break pos + 4;
|
||||
}
|
||||
};
|
||||
let headers = String::from_utf8_lossy(&request[..header_end]).to_string();
|
||||
let content_length = headers
|
||||
.lines()
|
||||
.find_map(|line| {
|
||||
let (name, value) = line.split_once(':')?;
|
||||
name.eq_ignore_ascii_case("content-length")
|
||||
.then(|| value.trim().parse::<usize>().ok())
|
||||
.flatten()
|
||||
})
|
||||
.unwrap_or(0);
|
||||
while request.len() < header_end + content_length {
|
||||
let n = socket.read(&mut buf).await.unwrap();
|
||||
if n == 0 {
|
||||
return (headers, serde_json::Value::Null);
|
||||
}
|
||||
request.extend_from_slice(&buf[..n]);
|
||||
}
|
||||
let body = &request[header_end..header_end + content_length];
|
||||
let json = if body.is_empty() {
|
||||
serde_json::Value::Null
|
||||
} else {
|
||||
serde_json::from_slice(body).unwrap()
|
||||
};
|
||||
(headers, json)
|
||||
}
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let active_sse = Arc::new(Mutex::new(None::<mpsc::UnboundedSender<Option<String>>>));
|
||||
let get_count = Arc::new(AtomicUsize::new(0));
|
||||
let tool_call_count = Arc::new(AtomicUsize::new(0));
|
||||
let success_seen = Arc::new(AtomicBool::new(false));
|
||||
let server_active_sse = Arc::clone(&active_sse);
|
||||
let server_get_count = Arc::clone(&get_count);
|
||||
let server_tool_call_count = Arc::clone(&tool_call_count);
|
||||
let server_success_seen = Arc::clone(&success_seen);
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
loop {
|
||||
let Ok((mut socket, _)) = listener.accept().await else {
|
||||
break;
|
||||
};
|
||||
let active_sse = Arc::clone(&server_active_sse);
|
||||
let get_count = Arc::clone(&server_get_count);
|
||||
let tool_call_count = Arc::clone(&server_tool_call_count);
|
||||
let success_seen = Arc::clone(&server_success_seen);
|
||||
tokio::spawn(async move {
|
||||
let (headers, request_json) = read_http_request(&mut socket).await;
|
||||
if headers.starts_with("GET /sse ") {
|
||||
get_count.fetch_add(1, AtomicOrdering::SeqCst);
|
||||
let (tx, mut rx) = mpsc::unbounded_channel::<Option<String>>();
|
||||
*active_sse.lock().unwrap() = Some(tx);
|
||||
socket
|
||||
.write_all(
|
||||
b"HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\n\r\n",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
socket
|
||||
.write_all(b"event: endpoint\ndata: /messages\n\n")
|
||||
.await
|
||||
.unwrap();
|
||||
while let Some(message) = rx.recv().await {
|
||||
let Some(message) = message else {
|
||||
return;
|
||||
};
|
||||
let event = format!("event: message\ndata: {message}\n\n");
|
||||
socket.write_all(event.as_bytes()).await.unwrap();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if !headers.starts_with("POST /messages ") {
|
||||
return;
|
||||
}
|
||||
|
||||
socket
|
||||
.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let method = request_json
|
||||
.get("method")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.unwrap_or("");
|
||||
if method == "notifications/initialized" {
|
||||
return;
|
||||
}
|
||||
|
||||
let id = request_json
|
||||
.get("id")
|
||||
.cloned()
|
||||
.unwrap_or_else(|| serde_json::json!("0"));
|
||||
|
||||
if method == "tools/call" {
|
||||
let count = tool_call_count.fetch_add(1, AtomicOrdering::SeqCst);
|
||||
if count == 0 {
|
||||
if let Some(tx) = active_sse.lock().unwrap().take() {
|
||||
let _ = tx.send(None);
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let result = match method {
|
||||
"initialize" => serde_json::json!({
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {}
|
||||
}),
|
||||
"tools/list" => serde_json::json!({
|
||||
"tools": [
|
||||
{ "name": "search", "inputSchema": {} }
|
||||
]
|
||||
}),
|
||||
"resources/list" => serde_json::json!({ "resources": [] }),
|
||||
"resources/templates/list" => {
|
||||
serde_json::json!({ "resourceTemplates": [] })
|
||||
}
|
||||
"prompts/list" => serde_json::json!({ "prompts": [] }),
|
||||
"tools/call" => {
|
||||
success_seen.store(true, AtomicOrdering::SeqCst);
|
||||
serde_json::json!({ "content": [{ "type": "text", "text": "ok" }] })
|
||||
}
|
||||
other => panic!("unexpected method: {other}"),
|
||||
};
|
||||
let response = serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": result
|
||||
})
|
||||
.to_string();
|
||||
if let Some(tx) = active_sse.lock().unwrap().as_ref() {
|
||||
let _ = tx.send(Some(response));
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
let mut cfg = McpConfig::default();
|
||||
cfg.servers.insert(
|
||||
"dephy".to_string(),
|
||||
McpServerConfig {
|
||||
command: None,
|
||||
args: Vec::new(),
|
||||
env: HashMap::new(),
|
||||
url: Some(format!("http://{addr}/sse")),
|
||||
transport: Some("sse".to_string()),
|
||||
connect_timeout: Some(2),
|
||||
execute_timeout: Some(2),
|
||||
read_timeout: None,
|
||||
disabled: false,
|
||||
enabled: true,
|
||||
required: false,
|
||||
enabled_tools: Vec::new(),
|
||||
disabled_tools: Vec::new(),
|
||||
headers: HashMap::new(),
|
||||
},
|
||||
);
|
||||
let mut pool = McpPool::new(cfg);
|
||||
|
||||
let result = pool
|
||||
.call_tool("mcp_dephy_search", serde_json::json!({ "query": "dephy" }))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
result,
|
||||
serde_json::json!({ "content": [{ "type": "text", "text": "ok" }] })
|
||||
);
|
||||
assert_eq!(tool_call_count.load(AtomicOrdering::SeqCst), 2);
|
||||
assert_eq!(get_count.load(AtomicOrdering::SeqCst), 2);
|
||||
assert!(success_seen.load(AtomicOrdering::SeqCst));
|
||||
|
||||
server.abort();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_id_starts_none() {
|
||||
let transport = StreamableHttpTransport::new(
|
||||
|
||||
Reference in New Issue
Block a user