refactor(tools): replace Semaphore with RwLock for parallel-safe tool execution
- Use OwnedRwLockReadGuard for parallel-safe tools and OwnedRwLockWriteGuard for serial tools
- Add the TOOL_EXECUTION_LOCK_HELD task-local for reentrancy detection
- Add a BlockingHandler test harness and parallel-vs-serial concurrency tests
This commit is contained in:
+41
-21
@@ -8,7 +8,11 @@ use async_trait::async_trait;
|
||||
use codewhale_protocol::{ToolKind, ToolOutput, ToolPayload};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::Semaphore;
|
||||
use tokio::sync::{OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock};
|
||||
|
||||
// Task-local marker used for reentrancy detection. `ToolCallRuntime::acquire`
// probes it with `try_with`; when a value is set for the current task,
// `acquire` returns `ToolExecutionGuard::Reentrant` instead of taking the
// execution lock. The dispatch path sets it via `.scope((), ...)` around the
// handler execution.
tokio::task_local! {
|
||||
static TOOL_EXECUTION_LOCK_HELD: ();
|
||||
}
|
||||
|
||||
/// Capabilities that a tool may have or require.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
@@ -311,15 +315,36 @@ pub trait ToolHandler: Send + Sync {
|
||||
|
||||
/// Shared state that coordinates when tool calls may execute relative to
/// each other.
// NOTE(review): this is a rendered diff without +/- markers. Per the commit
// title, `serial_semaphore` is the removed field and `execution_lock` is its
// replacement; confirm against the actual file before relying on both.
#[derive(Debug)]
|
||||
pub struct ToolCallRuntime {
|
||||
/// Serialise non-parallel tool executions. Capacity 1 ensures at most one
|
||||
/// serial tool runs at a time, and blocks parallel tools while it runs.
|
||||
serial_semaphore: Arc<Semaphore>,
|
||||
/// Preserve read/write tool execution semantics: parallel-safe tools may
|
||||
/// overlap, while serial tools run exclusively.
|
||||
execution_lock: Arc<RwLock<()>>,
|
||||
}
|
||||
|
||||
// Builds a runtime with freshly allocated, uncontended synchronisation state.
// NOTE(review): the rendered diff shows both initialisers; per the commit
// title the `serial_semaphore` line is the removed one and `execution_lock`
// is its replacement - confirm against the actual file.
impl Default for ToolCallRuntime {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
// One permit: the semaphore-based scheme admitted a single holder at a time.
serial_semaphore: Arc::new(Semaphore::new(1)),
|
||||
// The lock guards no data (`()`); only its read/write exclusion is used.
execution_lock: Arc::new(RwLock::new(())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Value returned by `ToolCallRuntime::acquire`. The caller keeps it alive
/// for the duration of a tool execution; dropping it releases whatever lock
/// guard it carries.
#[derive(Debug)]
|
||||
enum ToolExecutionGuard {
|
||||
/// Owned read guard on the execution lock, taken for calls whose
/// `supports_parallel` flag is true.
// `dead_code` is allowed because the guard is never read; it is stored only
// so that it is released when this value is dropped.
Parallel(#[allow(dead_code)] OwnedRwLockReadGuard<()>),
|
||||
/// Owned write guard on the execution lock, taken for calls whose
/// `supports_parallel` flag is false.
// Same reason for `dead_code` as above: held for its drop effect only.
Serial(#[allow(dead_code)] OwnedRwLockWriteGuard<()>),
|
||||
/// No lock was taken because `TOOL_EXECUTION_LOCK_HELD` was already set
/// for the current task.
Reentrant,
|
||||
}
|
||||
|
||||
impl ToolCallRuntime {
|
||||
async fn acquire(&self, supports_parallel: bool) -> ToolExecutionGuard {
|
||||
if TOOL_EXECUTION_LOCK_HELD.try_with(|_| ()).is_ok() {
|
||||
return ToolExecutionGuard::Reentrant;
|
||||
}
|
||||
|
||||
if supports_parallel {
|
||||
ToolExecutionGuard::Parallel(self.execution_lock.clone().read_owned().await)
|
||||
} else {
|
||||
ToolExecutionGuard::Serial(self.execution_lock.clone().write_owned().await)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -389,22 +414,17 @@ impl ToolRegistry {
|
||||
source: call.source,
|
||||
};
|
||||
|
||||
if configured.supports_parallel_tool_calls {
|
||||
// Parallel tools wait for any in-flight serial tool to finish,
|
||||
// but do not hold the permit so other parallel tools may run concurrently.
|
||||
drop(self.runtime.serial_semaphore.acquire().await
|
||||
.map_err(|_| FunctionCallError::Cancelled { name: call.name })?);
|
||||
self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation)
|
||||
.await
|
||||
} else {
|
||||
// Serial tools hold the semaphore for the full execution duration,
|
||||
// preventing other serial AND parallel tools from starting.
|
||||
let _permit = self.runtime.serial_semaphore.acquire().await
|
||||
.map_err(|_| FunctionCallError::Cancelled { name: call.name })?;
|
||||
self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation)
|
||||
.await
|
||||
// _permit dropped here, releasing the semaphore.
|
||||
}
|
||||
let _guard = self
|
||||
.runtime
|
||||
.acquire(configured.supports_parallel_tool_calls)
|
||||
.await;
|
||||
|
||||
TOOL_EXECUTION_LOCK_HELD
|
||||
.scope(
|
||||
(),
|
||||
self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn execute_with_timeout(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use codewhale_protocol::{ToolKind, ToolOutput, ToolPayload};
|
||||
@@ -6,6 +7,7 @@ use codewhale_tools::{
|
||||
ToolCall, ToolCallSource, ToolHandler, ToolInvocation, ToolRegistry, ToolSpec,
|
||||
};
|
||||
use serde_json::json;
|
||||
use tokio::sync::Notify;
|
||||
|
||||
struct EchoHandler;
|
||||
|
||||
@@ -33,6 +35,64 @@ impl ToolHandler for EchoHandler {
|
||||
}
|
||||
}
|
||||
|
||||
struct BlockingHandler {
|
||||
started: Arc<Notify>,
|
||||
release: Arc<Notify>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolHandler for BlockingHandler {
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation,
|
||||
) -> std::result::Result<ToolOutput, codewhale_tools::FunctionCallError> {
|
||||
self.started.notify_waiters();
|
||||
self.release.notified().await;
|
||||
Ok(ToolOutput::Function {
|
||||
body: Some(json!({
|
||||
"tool": invocation.tool_name,
|
||||
"call_id": invocation.call_id
|
||||
})),
|
||||
success: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct ReentrantHandler {
|
||||
registry: Arc<OnceLock<Arc<ToolRegistry>>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolHandler for ReentrantHandler {
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
_invocation: ToolInvocation,
|
||||
) -> std::result::Result<ToolOutput, codewhale_tools::FunctionCallError> {
|
||||
let registry = self.registry.get().expect("registry initialized").clone();
|
||||
registry
|
||||
.dispatch(
|
||||
ToolCall {
|
||||
name: "inner".to_string(),
|
||||
payload: ToolPayload::Function {
|
||||
arguments: "{}".to_string(),
|
||||
},
|
||||
source: ToolCallSource::Direct,
|
||||
raw_tool_call_id: Some("inner-call".to_string()),
|
||||
},
|
||||
true,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatches_function_tool_with_parallel_flag() {
|
||||
let mut registry = ToolRegistry::default();
|
||||
@@ -68,3 +128,149 @@ async fn dispatches_function_tool_with_parallel_flag() {
|
||||
other => panic!("unexpected output: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn serial_tool_waits_for_running_parallel_tool() {
|
||||
let started = Arc::new(Notify::new());
|
||||
let release = Arc::new(Notify::new());
|
||||
let mut registry = ToolRegistry::default();
|
||||
registry
|
||||
.register(
|
||||
ToolSpec {
|
||||
name: "slow_read".to_string(),
|
||||
input_schema: json!({"type":"object"}),
|
||||
output_schema: json!({"type":"object"}),
|
||||
supports_parallel_tool_calls: true,
|
||||
timeout_ms: Some(1000),
|
||||
},
|
||||
Arc::new(BlockingHandler {
|
||||
started: started.clone(),
|
||||
release: release.clone(),
|
||||
}),
|
||||
)
|
||||
.expect("register slow read");
|
||||
registry
|
||||
.register(
|
||||
ToolSpec {
|
||||
name: "serial".to_string(),
|
||||
input_schema: json!({"type":"object"}),
|
||||
output_schema: json!({"type":"object"}),
|
||||
supports_parallel_tool_calls: false,
|
||||
timeout_ms: Some(1000),
|
||||
},
|
||||
Arc::new(EchoHandler),
|
||||
)
|
||||
.expect("register serial");
|
||||
|
||||
let registry = Arc::new(registry);
|
||||
let started_wait = started.notified();
|
||||
let parallel_registry = registry.clone();
|
||||
let parallel = tokio::spawn(async move {
|
||||
parallel_registry
|
||||
.dispatch(
|
||||
ToolCall {
|
||||
name: "slow_read".to_string(),
|
||||
payload: ToolPayload::Function {
|
||||
arguments: "{}".to_string(),
|
||||
},
|
||||
source: ToolCallSource::Direct,
|
||||
raw_tool_call_id: Some("parallel-call".to_string()),
|
||||
},
|
||||
true,
|
||||
)
|
||||
.await
|
||||
});
|
||||
tokio::time::timeout(Duration::from_secs(1), started_wait)
|
||||
.await
|
||||
.expect("parallel tool started");
|
||||
|
||||
let serial_registry = registry.clone();
|
||||
let mut serial = tokio::spawn(async move {
|
||||
serial_registry
|
||||
.dispatch(
|
||||
ToolCall {
|
||||
name: "serial".to_string(),
|
||||
payload: ToolPayload::Function {
|
||||
arguments: "{}".to_string(),
|
||||
},
|
||||
source: ToolCallSource::Direct,
|
||||
raw_tool_call_id: Some("serial-call".to_string()),
|
||||
},
|
||||
true,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
_ = &mut serial => panic!("serial tool overlapped a running parallel tool"),
|
||||
() = tokio::time::sleep(Duration::from_millis(50)) => {}
|
||||
}
|
||||
|
||||
release.notify_waiters();
|
||||
serial
|
||||
.await
|
||||
.expect("serial task panicked")
|
||||
.expect("serial ran");
|
||||
parallel
|
||||
.await
|
||||
.expect("parallel task panicked")
|
||||
.expect("parallel ran");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn serial_tool_can_reenter_registry_without_deadlock() {
|
||||
let registry_cell = Arc::new(OnceLock::new());
|
||||
let mut registry = ToolRegistry::default();
|
||||
registry
|
||||
.register(
|
||||
ToolSpec {
|
||||
name: "outer".to_string(),
|
||||
input_schema: json!({"type":"object"}),
|
||||
output_schema: json!({"type":"object"}),
|
||||
supports_parallel_tool_calls: false,
|
||||
timeout_ms: Some(1000),
|
||||
},
|
||||
Arc::new(ReentrantHandler {
|
||||
registry: registry_cell.clone(),
|
||||
}),
|
||||
)
|
||||
.expect("register outer");
|
||||
registry
|
||||
.register(
|
||||
ToolSpec {
|
||||
name: "inner".to_string(),
|
||||
input_schema: json!({"type":"object"}),
|
||||
output_schema: json!({"type":"object"}),
|
||||
supports_parallel_tool_calls: false,
|
||||
timeout_ms: Some(1000),
|
||||
},
|
||||
Arc::new(EchoHandler),
|
||||
)
|
||||
.expect("register inner");
|
||||
|
||||
let registry = Arc::new(registry);
|
||||
assert!(registry_cell.set(registry.clone()).is_ok());
|
||||
|
||||
let output = tokio::time::timeout(
|
||||
Duration::from_secs(1),
|
||||
registry.dispatch(
|
||||
ToolCall {
|
||||
name: "outer".to_string(),
|
||||
payload: ToolPayload::Function {
|
||||
arguments: "{}".to_string(),
|
||||
},
|
||||
source: ToolCallSource::Direct,
|
||||
raw_tool_call_id: Some("outer-call".to_string()),
|
||||
},
|
||||
true,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("outer dispatch timed out")
|
||||
.expect("outer dispatch failed");
|
||||
|
||||
match output {
|
||||
ToolOutput::Function { success, .. } => assert!(success),
|
||||
other => panic!("unexpected output: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user