fix(hooks): harden message_submit review cases
This commit is contained in:
+178
-42
@@ -19,6 +19,7 @@ use std::collections::HashMap;
|
||||
use std::io::{Read, Write};
|
||||
use std::path::PathBuf;
|
||||
use std::process::{Command, Stdio};
|
||||
use std::thread::{self, JoinHandle};
|
||||
use std::time::{Duration, Instant};
|
||||
use wait_timeout::ChildExt;
|
||||
|
||||
@@ -428,13 +429,44 @@ pub struct HookResult {
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum MessageSubmitOutcome {
|
||||
/// No hook changed the submitted text.
|
||||
Unchanged,
|
||||
Unchanged { warning: Option<String> },
|
||||
/// One or more hooks replaced the submitted text.
|
||||
Replaced(String),
|
||||
Replaced {
|
||||
text: String,
|
||||
warning: Option<String>,
|
||||
},
|
||||
/// A hook intentionally blocked the submission.
|
||||
Blocked { reason: String },
|
||||
}
|
||||
|
||||
impl MessageSubmitOutcome {
|
||||
pub fn unchanged() -> Self {
|
||||
Self::Unchanged { warning: None }
|
||||
}
|
||||
|
||||
pub fn replaced(text: String) -> Self {
|
||||
Self::Replaced {
|
||||
text,
|
||||
warning: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn with_warning(self, warning: Option<String>) -> Self {
|
||||
match self {
|
||||
Self::Unchanged { .. } => Self::Unchanged { warning },
|
||||
Self::Replaced { text, .. } => Self::Replaced { text, warning },
|
||||
Self::Blocked { reason } => Self::Blocked { reason },
|
||||
}
|
||||
}
|
||||
|
||||
pub fn warning(&self) -> Option<&str> {
|
||||
match self {
|
||||
Self::Unchanged { warning } | Self::Replaced { warning, .. } => warning.as_deref(),
|
||||
Self::Blocked { .. } => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
enum MessageSubmitStdout {
|
||||
Unchanged,
|
||||
@@ -530,15 +562,16 @@ impl HookExecutor {
|
||||
original_text: &str,
|
||||
) -> MessageSubmitOutcome {
|
||||
if !self.config.enabled {
|
||||
return MessageSubmitOutcome::Unchanged;
|
||||
return MessageSubmitOutcome::unchanged();
|
||||
}
|
||||
|
||||
let hooks = self.config.hooks_for_event(HookEvent::MessageSubmit);
|
||||
if hooks.is_empty() {
|
||||
return MessageSubmitOutcome::Unchanged;
|
||||
return MessageSubmitOutcome::unchanged();
|
||||
}
|
||||
|
||||
let mut current_text = original_text.to_string();
|
||||
let mut warning = None;
|
||||
|
||||
for hook in hooks {
|
||||
let hook_context = context.clone().with_message(¤t_text);
|
||||
@@ -578,6 +611,7 @@ impl HookExecutor {
|
||||
);
|
||||
|
||||
if hook.continue_on_error {
|
||||
warning = message_submit_continue_warning(&result).or(warning);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -607,9 +641,9 @@ impl HookExecutor {
|
||||
}
|
||||
|
||||
if current_text == original_text {
|
||||
MessageSubmitOutcome::Unchanged
|
||||
MessageSubmitOutcome::unchanged().with_warning(warning)
|
||||
} else {
|
||||
MessageSubmitOutcome::Replaced(current_text)
|
||||
MessageSubmitOutcome::replaced(current_text).with_warning(warning)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -847,31 +881,29 @@ impl HookExecutor {
|
||||
}
|
||||
};
|
||||
|
||||
if let (Some(bytes), Some(mut stdin)) = (stdin_bytes, child.stdin.take()) {
|
||||
let _ = stdin.write_all(&bytes);
|
||||
let _ = stdin.write_all(b"\n");
|
||||
let _ = stdin.flush();
|
||||
}
|
||||
|
||||
fn read_pipe(mut pipe: impl Read) -> String {
|
||||
let mut buf = String::new();
|
||||
let _ = pipe.read_to_string(&mut buf);
|
||||
buf
|
||||
}
|
||||
let stdout_reader = child.stdout.take().map(spawn_pipe_reader);
|
||||
let stderr_reader = child.stderr.take().map(spawn_pipe_reader);
|
||||
let _stdin_writer = match (stdin_bytes, child.stdin.take()) {
|
||||
(Some(bytes), Some(stdin)) => Some(spawn_stdin_writer(stdin, bytes)),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
match child.wait_timeout(timeout) {
|
||||
Ok(Some(status)) => HookResult {
|
||||
name: hook.name.clone(),
|
||||
success: status.success(),
|
||||
exit_code: status.code(),
|
||||
stdout: child.stdout.take().map(read_pipe).unwrap_or_default(),
|
||||
stderr: child.stderr.take().map(read_pipe).unwrap_or_default(),
|
||||
stdout: join_reader(stdout_reader),
|
||||
stderr: join_reader(stderr_reader),
|
||||
duration: started.elapsed(),
|
||||
error: None,
|
||||
},
|
||||
Ok(None) => {
|
||||
let _ = child.kill();
|
||||
let _ = child.wait();
|
||||
// Do not join pipe threads on timeout: descendant processes can
|
||||
// inherit pipe fds, and waiting for those threads would defeat
|
||||
// the hook timeout we just enforced.
|
||||
HookResult {
|
||||
name: hook.name.clone(),
|
||||
success: false,
|
||||
@@ -882,15 +914,19 @@ impl HookExecutor {
|
||||
error: Some(format!("Hook timed out after {timeout_secs}s")),
|
||||
}
|
||||
}
|
||||
Err(e) => HookResult {
|
||||
name: hook.name.clone(),
|
||||
success: false,
|
||||
exit_code: None,
|
||||
stdout: String::new(),
|
||||
stderr: String::new(),
|
||||
duration: started.elapsed(),
|
||||
error: Some(format!("Failed to wait for hook: {e}")),
|
||||
},
|
||||
Err(e) => {
|
||||
let _ = child.kill();
|
||||
let _ = child.wait();
|
||||
HookResult {
|
||||
name: hook.name.clone(),
|
||||
success: false,
|
||||
exit_code: None,
|
||||
stdout: String::new(),
|
||||
stderr: String::new(),
|
||||
duration: started.elapsed(),
|
||||
error: Some(format!("Failed to wait for hook: {e}")),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -928,6 +964,28 @@ impl HookExecutor {
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_pipe_reader(mut pipe: impl Read + Send + 'static) -> JoinHandle<String> {
|
||||
thread::spawn(move || {
|
||||
let mut buf = String::new();
|
||||
let _ = pipe.read_to_string(&mut buf);
|
||||
buf
|
||||
})
|
||||
}
|
||||
|
||||
fn join_reader(reader: Option<JoinHandle<String>>) -> String {
|
||||
reader
|
||||
.and_then(|handle| handle.join().ok())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn spawn_stdin_writer(mut stdin: std::process::ChildStdin, mut bytes: Vec<u8>) -> JoinHandle<()> {
|
||||
thread::spawn(move || {
|
||||
bytes.push(b'\n');
|
||||
let _ = stdin.write_all(&bytes);
|
||||
let _ = stdin.flush();
|
||||
})
|
||||
}
|
||||
|
||||
fn message_submit_payload(context: &HookContext, text: &str) -> serde_json::Value {
|
||||
json!({
|
||||
"event": HookEvent::MessageSubmit.as_str(),
|
||||
@@ -956,12 +1014,21 @@ fn parse_message_submit_stdout(stdout: &str) -> MessageSubmitStdout {
|
||||
};
|
||||
|
||||
match object.get("text") {
|
||||
Some(serde_json::Value::String(text)) => MessageSubmitStdout::Replaced(text.clone()),
|
||||
Some(serde_json::Value::String(text)) if !text.is_empty() => {
|
||||
MessageSubmitStdout::Replaced(text.clone())
|
||||
}
|
||||
Some(serde_json::Value::String(_)) => {
|
||||
MessageSubmitStdout::Invalid("stdout `text` field must not be empty".to_string())
|
||||
}
|
||||
Some(_) => MessageSubmitStdout::Invalid("stdout `text` field must be a string".to_string()),
|
||||
None => MessageSubmitStdout::Unchanged,
|
||||
}
|
||||
}
|
||||
|
||||
fn message_submit_continue_warning(result: &HookResult) -> Option<String> {
|
||||
result.error.as_deref().and_then(first_non_empty_line)
|
||||
}
|
||||
|
||||
fn message_submit_block_reason(result: &HookResult, fallback: &str) -> String {
|
||||
if let Some(reason) = message_submit_stdout_reason(&result.stdout) {
|
||||
return reason;
|
||||
@@ -995,13 +1062,10 @@ fn first_non_empty_line(text: &str) -> Option<String> {
|
||||
|
||||
fn truncate_hook_message(message: &str) -> String {
|
||||
const MAX_CHARS: usize = 240;
|
||||
let mut out = String::new();
|
||||
for (idx, ch) in message.chars().enumerate() {
|
||||
if idx >= MAX_CHARS {
|
||||
out.push('…');
|
||||
return out;
|
||||
}
|
||||
out.push(ch);
|
||||
let mut chars = message.chars();
|
||||
let mut out: String = chars.by_ref().take(MAX_CHARS).collect();
|
||||
if chars.next().is_some() {
|
||||
out.push('…');
|
||||
}
|
||||
out
|
||||
}
|
||||
@@ -1128,6 +1192,14 @@ NOEQUAL line dropped
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_message_submit_stdout_rejects_empty_text() {
|
||||
assert_eq!(
|
||||
super::parse_message_submit_stdout(r#"{"text":""}"#),
|
||||
MessageSubmitStdout::Invalid("stdout `text` field must not be empty".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_message_submit_stdout_rejects_non_object_json() {
|
||||
assert!(matches!(
|
||||
@@ -1277,6 +1349,35 @@ NOEQUAL line dropped
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
#[test]
|
||||
fn message_submit_stdin_write_does_not_deadlock_when_hook_writes_first() {
|
||||
let dir = tempfile::tempdir().expect("tempdir");
|
||||
let command = write_hook_script(
|
||||
&dir,
|
||||
"write_before_read.sh",
|
||||
r#"#!/bin/sh
|
||||
dd if=/dev/zero bs=1024 count=256 2>/dev/null | tr '\000' x
|
||||
dd if=/dev/zero bs=1024 count=256 2>/dev/null | tr '\000' e >&2
|
||||
payload=$(cat)
|
||||
printf '\ndone:%s\n' "${#payload}"
|
||||
"#,
|
||||
);
|
||||
let hook = Hook::new(HookEvent::MessageSubmit, &command).with_timeout(5);
|
||||
let executor = HookExecutor::new(HooksConfig::default(), dir.path().to_path_buf());
|
||||
let env_vars = HashMap::new();
|
||||
let payload = json!({
|
||||
"event": "message_submit",
|
||||
"text": "x".repeat(256 * 1024),
|
||||
});
|
||||
|
||||
let result = executor.execute_sync_with_stdin(&hook, &env_vars, &payload);
|
||||
|
||||
assert!(result.success, "hook should complete: {result:?}");
|
||||
assert!(result.stdout.contains("done:"), "stdout was drained");
|
||||
assert!(result.stderr.len() >= 256 * 1024, "stderr was drained");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_executor_session_id() {
|
||||
let executor = HookExecutor::new(HooksConfig::default(), PathBuf::from("."));
|
||||
@@ -1337,7 +1438,7 @@ esac
|
||||
|
||||
assert_eq!(
|
||||
executor.execute_message_submit_transform(&submit_context(&dir), "original"),
|
||||
MessageSubmitOutcome::Replaced("first second".to_string())
|
||||
MessageSubmitOutcome::replaced("first second".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1390,7 +1491,7 @@ printf '%s\n' '{"text":"ignored"}'
|
||||
|
||||
assert_eq!(
|
||||
executor.execute_message_submit_transform(&submit_context(&dir), "original"),
|
||||
MessageSubmitOutcome::Unchanged
|
||||
MessageSubmitOutcome::unchanged()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1400,7 +1501,7 @@ printf '%s\n' '{"text":"ignored"}'
|
||||
|
||||
assert_eq!(
|
||||
executor.execute_message_submit_transform(&HookContext::new(), "original"),
|
||||
MessageSubmitOutcome::Unchanged
|
||||
MessageSubmitOutcome::unchanged()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1429,7 +1530,7 @@ printf '%s\n' '{"text":"should not apply"}'
|
||||
|
||||
assert_eq!(
|
||||
executor.execute_message_submit_transform(&submit_context(&dir), "original"),
|
||||
MessageSubmitOutcome::Unchanged
|
||||
MessageSubmitOutcome::unchanged()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1465,7 +1566,42 @@ printf '%s\n' '{"text":"recovered"}'
|
||||
|
||||
assert_eq!(
|
||||
executor.execute_message_submit_transform(&submit_context(&dir), "original"),
|
||||
MessageSubmitOutcome::Replaced("recovered".to_string())
|
||||
MessageSubmitOutcome::replaced("recovered".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
#[test]
|
||||
fn message_submit_timeout_continue_surfaces_warning_and_runs_later_hooks() {
|
||||
let dir = tempfile::tempdir().expect("tempdir");
|
||||
let slow = write_hook_script(
|
||||
&dir,
|
||||
"slow_continue.sh",
|
||||
r#"#!/bin/sh
|
||||
sleep 2
|
||||
"#,
|
||||
);
|
||||
let replacing = write_hook_script(
|
||||
&dir,
|
||||
"replace_after_timeout.sh",
|
||||
r#"#!/bin/sh
|
||||
printf '%s\n' '{"text":"after timeout"}'
|
||||
"#,
|
||||
);
|
||||
let mut slow_hook = Hook::new(HookEvent::MessageSubmit, &slow).with_timeout(1);
|
||||
slow_hook.continue_on_error = true;
|
||||
let config = HooksConfig {
|
||||
enabled: true,
|
||||
hooks: vec![slow_hook, Hook::new(HookEvent::MessageSubmit, &replacing)],
|
||||
working_dir: Some(dir.path().to_path_buf()),
|
||||
..HooksConfig::default()
|
||||
};
|
||||
let executor = HookExecutor::new(config, dir.path().to_path_buf());
|
||||
|
||||
assert_eq!(
|
||||
executor.execute_message_submit_transform(&submit_context(&dir), "original"),
|
||||
MessageSubmitOutcome::replaced("after timeout".to_string())
|
||||
.with_warning(Some("Hook timed out after 1s".to_string()))
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1500,7 +1636,7 @@ printf '%s\n' '{"text":"valid later"}'
|
||||
|
||||
assert_eq!(
|
||||
executor.execute_message_submit_transform(&submit_context(&dir), "original"),
|
||||
MessageSubmitOutcome::Replaced("valid later".to_string())
|
||||
MessageSubmitOutcome::replaced("valid later".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -4511,12 +4511,15 @@ async fn dispatch_user_message(
|
||||
.has_hooks_for_event(crate::hooks::HookEvent::MessageSubmit)
|
||||
{
|
||||
let context = app.base_hook_context().with_message(&message.display);
|
||||
match app
|
||||
let outcome = app
|
||||
.hooks
|
||||
.execute_message_submit_transform(&context, &message.display)
|
||||
{
|
||||
crate::hooks::MessageSubmitOutcome::Unchanged => {}
|
||||
crate::hooks::MessageSubmitOutcome::Replaced(text) => {
|
||||
.execute_message_submit_transform(&context, &message.display);
|
||||
if let Some(warning) = outcome.warning() {
|
||||
app.status_message = Some(warning.to_string());
|
||||
}
|
||||
match outcome {
|
||||
crate::hooks::MessageSubmitOutcome::Unchanged { .. } => {}
|
||||
crate::hooks::MessageSubmitOutcome::Replaced { text, .. } => {
|
||||
message.display = text;
|
||||
}
|
||||
crate::hooks::MessageSubmitOutcome::Blocked { reason } => {
|
||||
|
||||
@@ -2081,13 +2081,20 @@ fn write_message_submit_hook(dir: &TempDir, name: &str, body: &str) -> String {
|
||||
|
||||
#[cfg(not(windows))]
|
||||
fn configure_single_message_submit_hook(app: &mut App, dir: &TempDir, command: String) {
|
||||
configure_message_submit_hooks(app, dir, vec![command]);
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
fn configure_message_submit_hooks(app: &mut App, dir: &TempDir, commands: Vec<String>) {
|
||||
app.hooks = crate::hooks::HookExecutor::new(
|
||||
crate::hooks::HooksConfig {
|
||||
enabled: true,
|
||||
hooks: vec![crate::hooks::Hook::new(
|
||||
crate::hooks::HookEvent::MessageSubmit,
|
||||
&command,
|
||||
)],
|
||||
hooks: commands
|
||||
.iter()
|
||||
.map(|command| {
|
||||
crate::hooks::Hook::new(crate::hooks::HookEvent::MessageSubmit, command)
|
||||
})
|
||||
.collect(),
|
||||
working_dir: Some(dir.path().to_path_buf()),
|
||||
..crate::hooks::HooksConfig::default()
|
||||
},
|
||||
@@ -2095,6 +2102,62 @@ fn configure_single_message_submit_hook(app: &mut App, dir: &TempDir, command: S
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
#[tokio::test]
|
||||
async fn dispatch_user_message_surfaces_continued_message_submit_timeout() {
|
||||
let dir = TempDir::new().expect("tempdir");
|
||||
let slow = write_message_submit_hook(
|
||||
&dir,
|
||||
"slow.sh",
|
||||
r#"#!/bin/sh
|
||||
sleep 2
|
||||
"#,
|
||||
);
|
||||
let replacing = write_message_submit_hook(
|
||||
&dir,
|
||||
"replace.sh",
|
||||
r#"#!/bin/sh
|
||||
printf '%s\n' '{"text":"after timeout"}'
|
||||
"#,
|
||||
);
|
||||
let mut app = create_test_app();
|
||||
app.hooks = crate::hooks::HookExecutor::new(
|
||||
crate::hooks::HooksConfig {
|
||||
enabled: true,
|
||||
hooks: vec![
|
||||
crate::hooks::Hook::new(crate::hooks::HookEvent::MessageSubmit, &slow)
|
||||
.with_timeout(1),
|
||||
crate::hooks::Hook::new(crate::hooks::HookEvent::MessageSubmit, &replacing),
|
||||
],
|
||||
working_dir: Some(dir.path().to_path_buf()),
|
||||
..crate::hooks::HooksConfig::default()
|
||||
},
|
||||
dir.path().to_path_buf(),
|
||||
);
|
||||
let mut engine = crate::core::engine::mock_engine_handle();
|
||||
let config = Config::default();
|
||||
|
||||
dispatch_user_message(
|
||||
&mut app,
|
||||
&config,
|
||||
&engine.handle,
|
||||
QueuedMessage::new("hello".to_string(), None),
|
||||
)
|
||||
.await
|
||||
.expect("dispatch user message");
|
||||
|
||||
assert_eq!(
|
||||
app.status_message.as_deref(),
|
||||
Some("Hook timed out after 1s")
|
||||
);
|
||||
match engine.rx_op.recv().await.expect("send message op") {
|
||||
crate::core::ops::Op::SendMessage { content, .. } => {
|
||||
assert_eq!(content, "after timeout");
|
||||
}
|
||||
other => panic!("expected SendMessage, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
#[tokio::test]
|
||||
async fn dispatch_user_message_uses_transformed_message_submit_text() {
|
||||
|
||||
@@ -405,7 +405,7 @@ The hook receives JSON on stdin:
|
||||
}
|
||||
```
|
||||
|
||||
If the hook exits `0` and prints JSON with a string `text` field,
|
||||
If the hook exits `0` and prints JSON with a non-empty string `text` field,
|
||||
that value replaces the submitted text:
|
||||
|
||||
```json
|
||||
@@ -413,10 +413,13 @@ that value replaces the submitted text:
|
||||
```
|
||||
|
||||
Exit `0` with empty stdout, or stdout JSON without `text`, leaves
|
||||
the current text unchanged. Exit `2` blocks the submission before
|
||||
the turn starts; a `reason` field, stderr, or stdout can provide the
|
||||
status message shown in the TUI. Other non-zero exits follow the
|
||||
hook's `continue_on_error` setting.
|
||||
the current text unchanged. A JSON `text` field must not be empty;
|
||||
`{"text":""}` is treated as invalid stdout and ignored. Exit `2`
|
||||
blocks the submission before the turn starts; a `reason` field,
|
||||
stderr, or stdout can provide the status message shown in the TUI.
|
||||
Other non-zero exits follow the hook's `continue_on_error` setting.
|
||||
Timeouts and spawn failures are also surfaced as transient TUI status
|
||||
messages when `continue_on_error = true` lets submission continue.
|
||||
|
||||
Multiple `message_submit` hooks run in config order, and each hook
|
||||
receives the text produced by the previous hook. Hooks marked
|
||||
|
||||
@@ -186,7 +186,7 @@ command = "~/.codewhale/hooks/pre.sh" # / message_submit / mode_change /
|
||||
</p>
|
||||
<p className="mt-3 text-sm text-ink-soft leading-[1.9]">
|
||||
<code className="inline">message_submit</code> hooks run before a user message is sent to the model. A non-background hook can print
|
||||
<code className="inline">{'{"text":"replacement"}'}</code> on stdout to replace the message, or exit with code <code className="inline">2</code> to block the submission.
|
||||
<code className="inline">{'{"text":"replacement"}'}</code> on stdout to replace the message; <code className="inline">text</code> must be non-empty. Exit with code <code className="inline">2</code> to block the submission.
|
||||
<code className="inline">shell_env</code> keeps its existing <code className="inline">KEY=VALUE</code> stdout contract.
|
||||
</p>
|
||||
</section>
|
||||
@@ -441,7 +441,7 @@ command = "~/.codewhale/hooks/pre.sh" # / message_submit / mode_change /
|
||||
</p>
|
||||
<p className="mt-3 text-sm text-ink-soft leading-relaxed">
|
||||
<code className="inline">message_submit</code> hooks run before a user message is sent to the model. A non-background hook can print
|
||||
<code className="inline">{'{"text":"replacement"}'}</code> on stdout to replace the message, or exit with code <code className="inline">2</code> to block the submission.
|
||||
<code className="inline">{'{"text":"replacement"}'}</code> on stdout to replace the message; <code className="inline">text</code> must be non-empty. Exit with code <code className="inline">2</code> to block the submission.
|
||||
<code className="inline">shell_env</code> keeps its existing <code className="inline">KEY=VALUE</code> stdout contract.
|
||||
</p>
|
||||
</section>
|
||||
|
||||
Reference in New Issue
Block a user