feat(runtime): bridge user-input events and API to external GUI clients
Add SSE event forwarding for UserInputRequired, REST endpoint for submitting user input responses, timeout protection for await_user_input, and fix interrupt_turn to clear active_turn immediately.
This commit is contained in:
@@ -5,10 +5,14 @@
|
||||
//! or whenever a tool requests live user input (`await_user_input`). Channels
|
||||
//! and engine state stay private to the parent module.
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::core::events::Event;
|
||||
use crate::tools::spec::ToolError;
|
||||
use crate::tools::user_input::{UserInputRequest, UserInputResponse};
|
||||
|
||||
const USER_INPUT_TIMEOUT: Duration = Duration::from_secs(300);
|
||||
|
||||
use super::Engine;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -123,22 +127,43 @@ impl Engine {
|
||||
format!("Request cancelled while awaiting user input{suffix}"),
|
||||
));
|
||||
}
|
||||
decision = self.rx_user_input.recv() => {
|
||||
let Some(decision) = decision else {
|
||||
return Err(ToolError::execution_failed(
|
||||
"User input channel closed".to_string(),
|
||||
));
|
||||
};
|
||||
match decision {
|
||||
UserInputDecision::Submitted { id, response } if id == tool_id => {
|
||||
return Ok(response);
|
||||
result = tokio::time::timeout(USER_INPUT_TIMEOUT, self.rx_user_input.recv()) => {
|
||||
match result {
|
||||
Ok(Some(decision)) => {
|
||||
match decision {
|
||||
UserInputDecision::Submitted { id, response } if id == tool_id => {
|
||||
return Ok(response);
|
||||
}
|
||||
UserInputDecision::Cancelled { id } if id == tool_id => {
|
||||
return Err(ToolError::execution_failed(
|
||||
"User input cancelled".to_string(),
|
||||
));
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
UserInputDecision::Cancelled { id } if id == tool_id => {
|
||||
Ok(None) => {
|
||||
return Err(ToolError::execution_failed(
|
||||
"User input cancelled".to_string(),
|
||||
"User input channel closed".to_string(),
|
||||
));
|
||||
}
|
||||
Err(_) => {
|
||||
let _ = self
|
||||
.tx_event
|
||||
.send(Event::Status {
|
||||
message: format!(
|
||||
"User input timed out after {}s",
|
||||
USER_INPUT_TIMEOUT.as_secs()
|
||||
),
|
||||
})
|
||||
.await;
|
||||
return Err(ToolError::execution_failed(
|
||||
format!(
|
||||
"User input timed out after {}s",
|
||||
USER_INPUT_TIMEOUT.as_secs()
|
||||
),
|
||||
));
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -296,6 +296,25 @@ struct DecideApprovalResponse {
|
||||
delivered: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SubmitUserInputBody {
|
||||
answers: Vec<UserInputAnswerBody>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct UserInputAnswerBody {
|
||||
id: String,
|
||||
label: String,
|
||||
value: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct SubmitUserInputResponse {
|
||||
ok: bool,
|
||||
input_id: String,
|
||||
delivered: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct RuntimeInfoResponse {
|
||||
bind_host: String,
|
||||
@@ -500,6 +519,7 @@ pub fn build_router(state: RuntimeApiState) -> Router {
|
||||
.route("/v1/threads/{id}/compact", post(compact_thread))
|
||||
.route("/v1/threads/{id}/events", get(stream_thread_events))
|
||||
.route("/v1/approvals/{approval_id}", post(decide_approval))
|
||||
.route("/v1/user-input/{thread_id}/{input_id}", post(submit_user_input))
|
||||
.route("/v1/tasks", get(list_tasks).post(create_task))
|
||||
.route("/v1/tasks/{id}", get(get_task))
|
||||
.route("/v1/tasks/{id}/cancel", post(cancel_task))
|
||||
@@ -1015,6 +1035,34 @@ async fn decide_approval(
|
||||
}))
|
||||
}
|
||||
|
||||
async fn submit_user_input(
|
||||
State(state): State<RuntimeApiState>,
|
||||
Path((thread_id, input_id)): Path<(String, String)>,
|
||||
Json(req): Json<SubmitUserInputBody>,
|
||||
) -> Result<Json<SubmitUserInputResponse>, ApiError> {
|
||||
use crate::tools::user_input::{UserInputAnswer, UserInputResponse};
|
||||
let answers: Vec<UserInputAnswer> = req
|
||||
.answers
|
||||
.into_iter()
|
||||
.map(|a| UserInputAnswer {
|
||||
id: a.id,
|
||||
label: a.label,
|
||||
value: a.value,
|
||||
})
|
||||
.collect();
|
||||
let response = UserInputResponse { answers };
|
||||
let delivered = state
|
||||
.runtime_threads
|
||||
.submit_user_input(&thread_id, &input_id, response)
|
||||
.await
|
||||
.map_err(|e| ApiError::internal(format!("Failed to submit user input: {e}")))?;
|
||||
Ok(Json(SubmitUserInputResponse {
|
||||
ok: true,
|
||||
input_id,
|
||||
delivered,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn runtime_info(State(state): State<RuntimeApiState>) -> Json<RuntimeInfoResponse> {
|
||||
Json(RuntimeInfoResponse {
|
||||
bind_host: state.bind_host.clone(),
|
||||
|
||||
@@ -833,6 +833,30 @@ impl RuntimeThreadManager {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn submit_user_input(
|
||||
&self,
|
||||
thread_id: &str,
|
||||
input_id: &str,
|
||||
response: crate::tools::user_input::UserInputResponse,
|
||||
) -> Result<bool> {
|
||||
let active = self.active.lock().await;
|
||||
let Some(state) = active.engines.get(thread_id) else {
|
||||
bail!("thread '{thread_id}' not loaded");
|
||||
};
|
||||
state.engine.submit_user_input(input_id, response).await?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub async fn cancel_user_input(&self, thread_id: &str, input_id: &str) -> Result<bool> {
|
||||
let active = self.active.lock().await;
|
||||
let Some(state) = active.engines.get(thread_id) else {
|
||||
bail!("thread '{thread_id}' not loaded");
|
||||
};
|
||||
state.engine.cancel_user_input(input_id).await?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn pending_approvals_count(&self) -> usize {
|
||||
self.pending_approvals
|
||||
@@ -1704,7 +1728,41 @@ impl RuntimeThreadManager {
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.store.load_turn(turn_id)
|
||||
let ended_at = Utc::now();
|
||||
let mut turn = self.store.load_turn(turn_id)?;
|
||||
turn.status = RuntimeTurnStatus::Interrupted;
|
||||
turn.ended_at = Some(ended_at);
|
||||
turn.duration_ms = turn.started_at.map(|start| duration_ms(start, ended_at));
|
||||
self.store.save_turn(&turn)?;
|
||||
|
||||
let mut thread = self.get_thread(thread_id).await?;
|
||||
thread.latest_turn_id = Some(turn_id.to_string());
|
||||
thread.updated_at = Utc::now();
|
||||
self.store.save_thread(&thread)?;
|
||||
|
||||
self.emit_event(
|
||||
thread_id,
|
||||
Some(turn_id),
|
||||
None,
|
||||
"turn.completed",
|
||||
json!({ "turn": turn.clone() }),
|
||||
)
|
||||
.await?;
|
||||
|
||||
{
|
||||
let mut active = self.active.lock().await;
|
||||
if let Some(state) = active.engines.get_mut(thread_id)
|
||||
&& state
|
||||
.active_turn
|
||||
.as_ref()
|
||||
.is_some_and(|t| t.turn_id == turn_id)
|
||||
{
|
||||
state.active_turn = None;
|
||||
}
|
||||
touch_lru(&mut active.lru, thread_id);
|
||||
}
|
||||
|
||||
Ok(turn)
|
||||
}
|
||||
|
||||
pub async fn steer_turn(
|
||||
@@ -2791,6 +2849,19 @@ impl RuntimeThreadManager {
|
||||
}
|
||||
}
|
||||
}
|
||||
EngineEvent::UserInputRequired { id, request } => {
|
||||
self.emit_event(
|
||||
&thread_id,
|
||||
Some(&turn_id),
|
||||
None,
|
||||
"user_input.required",
|
||||
json!({
|
||||
"id": id,
|
||||
"request": request,
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
EngineEvent::Status { message } => {
|
||||
let item = TurnItemRecord {
|
||||
schema_version: CURRENT_RUNTIME_SCHEMA_VERSION,
|
||||
|
||||
Reference in New Issue
Block a user