diff --git a/crates/tools/src/lib.rs b/crates/tools/src/lib.rs index 050b840f..b0ffc55b 100644 --- a/crates/tools/src/lib.rs +++ b/crates/tools/src/lib.rs @@ -8,7 +8,11 @@ use async_trait::async_trait; use codewhale_protocol::{ToolKind, ToolOutput, ToolPayload}; use serde::{Deserialize, Serialize}; use serde_json::Value; -use tokio::sync::Semaphore; +use tokio::sync::{OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock}; + +tokio::task_local! { + static TOOL_EXECUTION_LOCK_HELD: (); +} /// Capabilities that a tool may have or require. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -311,15 +315,36 @@ pub trait ToolHandler: Send + Sync { #[derive(Debug)] pub struct ToolCallRuntime { - /// Serialise non-parallel tool executions. Capacity 1 ensures at most one - /// serial tool runs at a time, and blocks parallel tools while it runs. - serial_semaphore: Arc, + /// Preserve read/write tool execution semantics: parallel-safe tools may + /// overlap, while serial tools run exclusively. + execution_lock: Arc>, } impl Default for ToolCallRuntime { fn default() -> Self { Self { - serial_semaphore: Arc::new(Semaphore::new(1)), + execution_lock: Arc::new(RwLock::new(())), + } + } +} + +#[derive(Debug)] +enum ToolExecutionGuard { + Parallel(#[allow(dead_code)] OwnedRwLockReadGuard<()>), + Serial(#[allow(dead_code)] OwnedRwLockWriteGuard<()>), + Reentrant, +} + +impl ToolCallRuntime { + async fn acquire(&self, supports_parallel: bool) -> ToolExecutionGuard { + if TOOL_EXECUTION_LOCK_HELD.try_with(|_| ()).is_ok() { + return ToolExecutionGuard::Reentrant; + } + + if supports_parallel { + ToolExecutionGuard::Parallel(self.execution_lock.clone().read_owned().await) + } else { + ToolExecutionGuard::Serial(self.execution_lock.clone().write_owned().await) } } } @@ -389,22 +414,17 @@ impl ToolRegistry { source: call.source, }; - if configured.supports_parallel_tool_calls { - // Parallel tools wait for any in-flight serial tool to finish, - // but do not hold the permit so other parallel tools may run concurrently. - drop(self.runtime.serial_semaphore.acquire().await - .map_err(|_| FunctionCallError::Cancelled { name: call.name })?); - self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation) - .await - } else { - // Serial tools hold the semaphore for the full execution duration, - // preventing other serial AND parallel tools from starting. - let _permit = self.runtime.serial_semaphore.acquire().await - .map_err(|_| FunctionCallError::Cancelled { name: call.name })?; - self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation) - .await - // _permit dropped here, releasing the semaphore. - } + let _guard = self + .runtime + .acquire(configured.supports_parallel_tool_calls) + .await; + + TOOL_EXECUTION_LOCK_HELD + .scope( + (), + self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation), + ) + .await } async fn execute_with_timeout( diff --git a/crates/tools/tests/parity_tools.rs b/crates/tools/tests/parity_tools.rs index fb08753b..ef525ba4 100644 --- a/crates/tools/tests/parity_tools.rs +++ b/crates/tools/tests/parity_tools.rs @@ -1,4 +1,5 @@ -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; +use std::time::Duration; use async_trait::async_trait; use codewhale_protocol::{ToolKind, ToolOutput, ToolPayload}; @@ -6,6 +7,7 @@ use codewhale_tools::{ ToolCall, ToolCallSource, ToolHandler, ToolInvocation, ToolRegistry, ToolSpec, }; use serde_json::json; +use tokio::sync::Notify; struct EchoHandler; @@ -33,6 +35,64 @@ impl ToolHandler for EchoHandler { } } +struct BlockingHandler { + started: Arc, + release: Arc, +} + +#[async_trait] +impl ToolHandler for BlockingHandler { + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + async fn handle( + &self, + invocation: ToolInvocation, + ) -> std::result::Result { + self.started.notify_waiters(); + self.release.notified().await; + Ok(ToolOutput::Function { + body: Some(json!({ + "tool": invocation.tool_name, + "call_id": invocation.call_id + })), + success: true, + }) + } +} + +struct ReentrantHandler { + registry: Arc>>, +} + +#[async_trait] +impl ToolHandler for ReentrantHandler { + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + async fn handle( + &self, + _invocation: ToolInvocation, + ) -> std::result::Result { + let registry = self.registry.get().expect("registry initialized").clone(); + registry + .dispatch( + ToolCall { + name: "inner".to_string(), + payload: ToolPayload::Function { + arguments: "{}".to_string(), + }, + source: ToolCallSource::Direct, + raw_tool_call_id: Some("inner-call".to_string()), + }, + true, + ) + .await + } +} + #[tokio::test] async fn dispatches_function_tool_with_parallel_flag() { let mut registry = ToolRegistry::default(); @@ -68,3 +128,149 @@ async fn dispatches_function_tool_with_parallel_flag() { other => panic!("unexpected output: {other:?}"), } } + +#[tokio::test] +async fn serial_tool_waits_for_running_parallel_tool() { + let started = Arc::new(Notify::new()); + let release = Arc::new(Notify::new()); + let mut registry = ToolRegistry::default(); + registry + .register( + ToolSpec { + name: "slow_read".to_string(), + input_schema: json!({"type":"object"}), + output_schema: json!({"type":"object"}), + supports_parallel_tool_calls: true, + timeout_ms: Some(1000), + }, + Arc::new(BlockingHandler { + started: started.clone(), + release: release.clone(), + }), + ) + .expect("register slow read"); + registry + .register( + ToolSpec { + name: "serial".to_string(), + input_schema: json!({"type":"object"}), + output_schema: json!({"type":"object"}), + supports_parallel_tool_calls: false, + timeout_ms: Some(1000), + }, + Arc::new(EchoHandler), + ) + .expect("register serial"); + + let registry = Arc::new(registry); + let started_wait = started.notified(); + let parallel_registry = registry.clone(); + let parallel = tokio::spawn(async move { + parallel_registry + .dispatch( + ToolCall { + name: "slow_read".to_string(), + payload: ToolPayload::Function { + arguments: "{}".to_string(), + }, + source: ToolCallSource::Direct, + raw_tool_call_id: Some("parallel-call".to_string()), + }, + true, + ) + .await + }); + tokio::time::timeout(Duration::from_secs(1), started_wait) + .await + .expect("parallel tool started"); + + let serial_registry = registry.clone(); + let mut serial = tokio::spawn(async move { + serial_registry + .dispatch( + ToolCall { + name: "serial".to_string(), + payload: ToolPayload::Function { + arguments: "{}".to_string(), + }, + source: ToolCallSource::Direct, + raw_tool_call_id: Some("serial-call".to_string()), + }, + true, + ) + .await + }); + + tokio::select! { + _ = &mut serial => panic!("serial tool overlapped a running parallel tool"), + () = tokio::time::sleep(Duration::from_millis(50)) => {} + } + + release.notify_waiters(); + serial + .await + .expect("serial task panicked") + .expect("serial ran"); + parallel + .await + .expect("parallel task panicked") + .expect("parallel ran"); +} + +#[tokio::test] +async fn serial_tool_can_reenter_registry_without_deadlock() { + let registry_cell = Arc::new(OnceLock::new()); + let mut registry = ToolRegistry::default(); + registry + .register( + ToolSpec { + name: "outer".to_string(), + input_schema: json!({"type":"object"}), + output_schema: json!({"type":"object"}), + supports_parallel_tool_calls: false, + timeout_ms: Some(1000), + }, + Arc::new(ReentrantHandler { + registry: registry_cell.clone(), + }), + ) + .expect("register outer"); + registry + .register( + ToolSpec { + name: "inner".to_string(), + input_schema: json!({"type":"object"}), + output_schema: json!({"type":"object"}), + supports_parallel_tool_calls: false, + timeout_ms: Some(1000), + }, + Arc::new(EchoHandler), + ) + .expect("register inner"); + + let registry = Arc::new(registry); + assert!(registry_cell.set(registry.clone()).is_ok()); + + let output = tokio::time::timeout( + Duration::from_secs(1), + registry.dispatch( + ToolCall { + name: "outer".to_string(), + payload: ToolPayload::Function { + arguments: "{}".to_string(), + }, + source: ToolCallSource::Direct, + raw_tool_call_id: Some("outer-call".to_string()), + }, + true, + ), + ) + .await + .expect("outer dispatch timed out") + .expect("outer dispatch failed"); + + match output { + ToolOutput::Function { success, .. } => assert!(success), + other => panic!("unexpected output: {other:?}"), + } +}