refactor(tools): replace Semaphore with RwLock for parallel-safe tool execution

- Use OwnedRwLockReadGuard for parallel-safe tools, OwnedRwLockWriteGuard for serial
- Add TOOL_EXECUTION_LOCK_HELD task-local for reentrancy detection
- Add BlockingHandler test harness and parallel-vs-serial concurrency tests
This commit is contained in:
Hunter Bown
2026-05-26 16:39:39 -05:00
parent 60c1b6619c
commit 74878dcd30
2 changed files with 248 additions and 22 deletions
+41 -21
View File
@@ -8,7 +8,11 @@ use async_trait::async_trait;
use codewhale_protocol::{ToolKind, ToolOutput, ToolPayload};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::Semaphore;
use tokio::sync::{OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock};
// Task-local marker meaning "this task is already running inside a tool
// execution". It carries no data (`()`); only whether it is set matters.
// `ToolCallRuntime::acquire` probes it with `try_with`, and the dispatch path
// sets it with `scope` around `execute_with_timeout`.
// NOTE(review): task-local values are not inherited by tasks created with
// `tokio::spawn`, so a handler that re-dispatches from a spawned task would
// not be detected as reentrant — confirm that case cannot occur.
tokio::task_local! {
static TOOL_EXECUTION_LOCK_HELD: ();
}
/// Capabilities that a tool may have or require.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
@@ -311,15 +315,36 @@ pub trait ToolHandler: Send + Sync {
/// Synchronisation state used to order tool executions relative to each other.
#[derive(Debug)]
pub struct ToolCallRuntime {
/// Serialise non-parallel tool executions. Capacity 1 ensures at most one
/// serial tool runs at a time, and blocks parallel tools while it runs.
serial_semaphore: Arc<Semaphore>,
/// Preserve read/write tool execution semantics: parallel-safe tools may
/// overlap, while serial tools run exclusively.
///
/// Parallel-safe tools take the shared (read) side of this lock and serial
/// tools take the exclusive (write) side; the `()` payload is unused.
execution_lock: Arc<RwLock<()>>,
}
impl Default for ToolCallRuntime {
/// Builds a runtime with newly created synchronisation primitives: a
/// one-permit semaphore and an `RwLock` wrapping `()`.
fn default() -> Self {
Self {
serial_semaphore: Arc::new(Semaphore::new(1)),
execution_lock: Arc::new(RwLock::new(())),
}
}
}
/// Token returned by `ToolCallRuntime::acquire`; the caller keeps it alive for
/// the duration of a tool execution.
///
/// The wrapped lock guards are never read, hence `#[allow(dead_code)]`: they
/// are stored only so that dropping this value releases the lock.
#[derive(Debug)]
enum ToolExecutionGuard {
/// Shared access, obtained with `read_owned` for parallel-safe tools.
Parallel(#[allow(dead_code)] OwnedRwLockReadGuard<()>),
/// Exclusive access, obtained with `write_owned` for serial tools.
Serial(#[allow(dead_code)] OwnedRwLockWriteGuard<()>),
/// No lock was taken because `TOOL_EXECUTION_LOCK_HELD` was already set
/// for the current task.
Reentrant,
}
impl ToolCallRuntime {
async fn acquire(&self, supports_parallel: bool) -> ToolExecutionGuard {
if TOOL_EXECUTION_LOCK_HELD.try_with(|_| ()).is_ok() {
return ToolExecutionGuard::Reentrant;
}
if supports_parallel {
ToolExecutionGuard::Parallel(self.execution_lock.clone().read_owned().await)
} else {
ToolExecutionGuard::Serial(self.execution_lock.clone().write_owned().await)
}
}
}
@@ -389,22 +414,17 @@ impl ToolRegistry {
source: call.source,
};
if configured.supports_parallel_tool_calls {
// Parallel tools wait for any in-flight serial tool to finish,
// but do not hold the permit so other parallel tools may run concurrently.
drop(self.runtime.serial_semaphore.acquire().await
.map_err(|_| FunctionCallError::Cancelled { name: call.name })?);
self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation)
.await
} else {
// Serial tools hold the semaphore for the full execution duration,
// preventing other serial AND parallel tools from starting.
let _permit = self.runtime.serial_semaphore.acquire().await
.map_err(|_| FunctionCallError::Cancelled { name: call.name })?;
self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation)
.await
// _permit dropped here, releasing the semaphore.
}
let _guard = self
.runtime
.acquire(configured.supports_parallel_tool_calls)
.await;
TOOL_EXECUTION_LOCK_HELD
.scope(
(),
self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation),
)
.await
}
async fn execute_with_timeout(
+207 -1
View File
@@ -1,4 +1,5 @@
use std::sync::Arc;
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use async_trait::async_trait;
use codewhale_protocol::{ToolKind, ToolOutput, ToolPayload};
@@ -6,6 +7,7 @@ use codewhale_tools::{
ToolCall, ToolCallSource, ToolHandler, ToolInvocation, ToolRegistry, ToolSpec,
};
use serde_json::json;
use tokio::sync::Notify;
struct EchoHandler;
@@ -33,6 +35,64 @@ impl ToolHandler for EchoHandler {
}
}
/// Test handler whose `handle` announces that it has started and then blocks
/// until the test releases it.
struct BlockingHandler {
/// Signalled with `notify_waiters` as soon as `handle` begins.
started: Arc<Notify>,
/// `handle` awaits a notification on this before producing its output.
release: Arc<Notify>,
}
#[async_trait]
impl ToolHandler for BlockingHandler {
    /// Reports this handler as a function tool.
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    /// Signals `started`, waits for `release`, then returns a successful
    /// output echoing the invocation's tool name and call id.
    async fn handle(
        &self,
        invocation: ToolInvocation,
    ) -> std::result::Result<ToolOutput, codewhale_tools::FunctionCallError> {
        // `notify_waiters` stores no permit, so an observer must have created
        // its `notified()` future before this line runs to see the signal.
        self.started.notify_waiters();

        // Park here until the test lets the call finish.
        let released = self.release.notified();
        released.await;

        let body = json!({
            "tool": invocation.tool_name,
            "call_id": invocation.call_id
        });
        Ok(ToolOutput::Function {
            body: Some(body),
            success: true,
        })
    }
}
/// Test handler whose `handle` dispatches the `"inner"` tool through the
/// registry stored in `registry`.
struct ReentrantHandler {
/// Late-bound registry handle, filled in by the test after the registry is
/// built; `handle` panics with "registry initialized" if it is still unset.
registry: Arc<OnceLock<Arc<ToolRegistry>>>,
}
#[async_trait]
impl ToolHandler for ReentrantHandler {
    /// Reports this handler as a function tool.
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    /// Ignores its own invocation and instead dispatches the `"inner"` tool
    /// through the shared registry, returning that nested call's result.
    ///
    /// Panics if the registry cell has not been initialised yet.
    async fn handle(
        &self,
        _invocation: ToolInvocation,
    ) -> std::result::Result<ToolOutput, codewhale_tools::FunctionCallError> {
        let registry = Arc::clone(self.registry.get().expect("registry initialized"));

        let inner_call = ToolCall {
            name: "inner".to_string(),
            payload: ToolPayload::Function {
                arguments: "{}".to_string(),
            },
            source: ToolCallSource::Direct,
            raw_tool_call_id: Some("inner-call".to_string()),
        };

        // Second argument is passed as `true`, exactly as the tests do.
        registry.dispatch(inner_call, true).await
    }
}
#[tokio::test]
async fn dispatches_function_tool_with_parallel_flag() {
let mut registry = ToolRegistry::default();
@@ -68,3 +128,149 @@ async fn dispatches_function_tool_with_parallel_flag() {
other => panic!("unexpected output: {other:?}"),
}
}
#[tokio::test]
async fn serial_tool_waits_for_running_parallel_tool() {
let started = Arc::new(Notify::new());
let release = Arc::new(Notify::new());
let mut registry = ToolRegistry::default();
registry
.register(
ToolSpec {
name: "slow_read".to_string(),
input_schema: json!({"type":"object"}),
output_schema: json!({"type":"object"}),
supports_parallel_tool_calls: true,
timeout_ms: Some(1000),
},
Arc::new(BlockingHandler {
started: started.clone(),
release: release.clone(),
}),
)
.expect("register slow read");
registry
.register(
ToolSpec {
name: "serial".to_string(),
input_schema: json!({"type":"object"}),
output_schema: json!({"type":"object"}),
supports_parallel_tool_calls: false,
timeout_ms: Some(1000),
},
Arc::new(EchoHandler),
)
.expect("register serial");
let registry = Arc::new(registry);
let started_wait = started.notified();
let parallel_registry = registry.clone();
let parallel = tokio::spawn(async move {
parallel_registry
.dispatch(
ToolCall {
name: "slow_read".to_string(),
payload: ToolPayload::Function {
arguments: "{}".to_string(),
},
source: ToolCallSource::Direct,
raw_tool_call_id: Some("parallel-call".to_string()),
},
true,
)
.await
});
tokio::time::timeout(Duration::from_secs(1), started_wait)
.await
.expect("parallel tool started");
let serial_registry = registry.clone();
let mut serial = tokio::spawn(async move {
serial_registry
.dispatch(
ToolCall {
name: "serial".to_string(),
payload: ToolPayload::Function {
arguments: "{}".to_string(),
},
source: ToolCallSource::Direct,
raw_tool_call_id: Some("serial-call".to_string()),
},
true,
)
.await
});
tokio::select! {
_ = &mut serial => panic!("serial tool overlapped a running parallel tool"),
() = tokio::time::sleep(Duration::from_millis(50)) => {}
}
release.notify_waiters();
serial
.await
.expect("serial task panicked")
.expect("serial ran");
parallel
.await
.expect("parallel task panicked")
.expect("parallel ran");
}
#[tokio::test]
async fn serial_tool_can_reenter_registry_without_deadlock() {
let registry_cell = Arc::new(OnceLock::new());
let mut registry = ToolRegistry::default();
registry
.register(
ToolSpec {
name: "outer".to_string(),
input_schema: json!({"type":"object"}),
output_schema: json!({"type":"object"}),
supports_parallel_tool_calls: false,
timeout_ms: Some(1000),
},
Arc::new(ReentrantHandler {
registry: registry_cell.clone(),
}),
)
.expect("register outer");
registry
.register(
ToolSpec {
name: "inner".to_string(),
input_schema: json!({"type":"object"}),
output_schema: json!({"type":"object"}),
supports_parallel_tool_calls: false,
timeout_ms: Some(1000),
},
Arc::new(EchoHandler),
)
.expect("register inner");
let registry = Arc::new(registry);
assert!(registry_cell.set(registry.clone()).is_ok());
let output = tokio::time::timeout(
Duration::from_secs(1),
registry.dispatch(
ToolCall {
name: "outer".to_string(),
payload: ToolPayload::Function {
arguments: "{}".to_string(),
},
source: ToolCallSource::Direct,
raw_tool_call_id: Some("outer-call".to_string()),
},
true,
),
)
.await
.expect("outer dispatch timed out")
.expect("outer dispatch failed");
match output {
ToolOutput::Function { success, .. } => assert!(success),
other => panic!("unexpected output: {other:?}"),
}
}