fix(tools): replace cross-await RwLock with Semaphore to prevent deadlock
Replace `Arc<RwLock<()>>` in ToolCallRuntime with `Arc<Semaphore>` to eliminate the risk of a tool re-entering and deadlocking on the same lock. Parallel tools now acquire then immediately drop the permit, allowing concurrent execution after any in-flight serial tool finishes. Serial tools hold the permit for the full duration. Fixes #2157. Harvested from #1856. Co-authored-by: Fire-dtx <58944505+Fire-dtx@users.noreply.github.com>
This commit is contained in:
+22
-5
@@ -8,7 +8,7 @@ use async_trait::async_trait;
|
||||
use codewhale_protocol::{ToolKind, ToolOutput, ToolPayload};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
/// Capabilities that a tool may have or require.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
@@ -309,9 +309,19 @@ pub trait ToolHandler: Send + Sync {
|
||||
) -> std::result::Result<ToolOutput, FunctionCallError>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
#[derive(Debug)]
|
||||
pub struct ToolCallRuntime {
|
||||
pub parallel_execution: Arc<RwLock<()>>,
|
||||
/// Serialise non-parallel tool executions. Capacity 1 ensures at most one
|
||||
/// serial tool runs at a time, and blocks parallel tools while it runs.
|
||||
serial_semaphore: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
impl Default for ToolCallRuntime {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
serial_semaphore: Arc::new(Semaphore::new(1)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -380,13 +390,20 @@ impl ToolRegistry {
|
||||
};
|
||||
|
||||
if configured.supports_parallel_tool_calls {
|
||||
let _guard = self.runtime.parallel_execution.read().await;
|
||||
// Parallel tools wait for any in-flight serial tool to finish,
|
||||
// but do not hold the permit so other parallel tools may run concurrently.
|
||||
drop(self.runtime.serial_semaphore.acquire().await
|
||||
.map_err(|_| FunctionCallError::Cancelled { name: call.name })?);
|
||||
self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation)
|
||||
.await
|
||||
} else {
|
||||
let _guard = self.runtime.parallel_execution.write().await;
|
||||
// Serial tools hold the semaphore for the full execution duration,
|
||||
// preventing other serial AND parallel tools from starting.
|
||||
let _permit = self.runtime.serial_semaphore.acquire().await
|
||||
.map_err(|_| FunctionCallError::Cancelled { name: call.name })?;
|
||||
self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation)
|
||||
.await
|
||||
// _permit dropped here, releasing the semaphore.
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user