diff --git a/crates/tools/src/lib.rs b/crates/tools/src/lib.rs index a7179410..050b840f 100644 --- a/crates/tools/src/lib.rs +++ b/crates/tools/src/lib.rs @@ -8,7 +8,7 @@ use async_trait::async_trait; use codewhale_protocol::{ToolKind, ToolOutput, ToolPayload}; use serde::{Deserialize, Serialize}; use serde_json::Value; -use tokio::sync::RwLock; +use tokio::sync::Semaphore; /// Capabilities that a tool may have or require. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -309,9 +309,19 @@ pub trait ToolHandler: Send + Sync { ) -> std::result::Result; } -#[derive(Debug, Default)] +#[derive(Debug)] pub struct ToolCallRuntime { - pub parallel_execution: Arc>, + /// Serialise non-parallel tool executions. Capacity 1 ensures at most one + /// serial tool runs at a time, and blocks parallel tools while it runs. + serial_semaphore: Arc, +} + +impl Default for ToolCallRuntime { + fn default() -> Self { + Self { + serial_semaphore: Arc::new(Semaphore::new(1)), + } + } } #[derive(Default)] @@ -380,13 +390,20 @@ impl ToolRegistry { }; if configured.supports_parallel_tool_calls { - let _guard = self.runtime.parallel_execution.read().await; + // Parallel tools wait for any in-flight serial tool to finish, + // but do not hold the permit so other parallel tools may run concurrently. + drop(self.runtime.serial_semaphore.acquire().await + .map_err(|_| FunctionCallError::Cancelled { name: call.name })?); self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation) .await } else { - let _guard = self.runtime.parallel_execution.write().await; + // Serial tools hold the semaphore for the full execution duration, + // preventing other serial AND parallel tools from starting. + let _permit = self.runtime.serial_semaphore.acquire().await + .map_err(|_| FunctionCallError::Cancelled { name: call.name })?; self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation) .await + // _permit dropped here, releasing the semaphore. } }