fix(hooks): harden message_submit review cases

This commit is contained in:
ningjingkun
2026-05-29 11:40:20 +08:00
committed by Hunter B
parent 467e2cbfff
commit 4146ec617e
5 changed files with 263 additions and 58 deletions
+178 -42
View File
@@ -19,6 +19,7 @@ use std::collections::HashMap;
use std::io::{Read, Write};
use std::path::PathBuf;
use std::process::{Command, Stdio};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
use wait_timeout::ChildExt;
@@ -428,13 +429,44 @@ pub struct HookResult {
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSubmitOutcome {
/// No hook changed the submitted text.
Unchanged,
Unchanged { warning: Option<String> },
/// One or more hooks replaced the submitted text.
Replaced(String),
Replaced {
text: String,
warning: Option<String>,
},
/// A hook intentionally blocked the submission.
Blocked { reason: String },
}
impl MessageSubmitOutcome {
pub fn unchanged() -> Self {
Self::Unchanged { warning: None }
}
pub fn replaced(text: String) -> Self {
Self::Replaced {
text,
warning: None,
}
}
fn with_warning(self, warning: Option<String>) -> Self {
match self {
Self::Unchanged { .. } => Self::Unchanged { warning },
Self::Replaced { text, .. } => Self::Replaced { text, warning },
Self::Blocked { reason } => Self::Blocked { reason },
}
}
pub fn warning(&self) -> Option<&str> {
match self {
Self::Unchanged { warning } | Self::Replaced { warning, .. } => warning.as_deref(),
Self::Blocked { .. } => None,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
enum MessageSubmitStdout {
Unchanged,
@@ -530,15 +562,16 @@ impl HookExecutor {
original_text: &str,
) -> MessageSubmitOutcome {
if !self.config.enabled {
return MessageSubmitOutcome::Unchanged;
return MessageSubmitOutcome::unchanged();
}
let hooks = self.config.hooks_for_event(HookEvent::MessageSubmit);
if hooks.is_empty() {
return MessageSubmitOutcome::Unchanged;
return MessageSubmitOutcome::unchanged();
}
let mut current_text = original_text.to_string();
let mut warning = None;
for hook in hooks {
let hook_context = context.clone().with_message(&current_text);
@@ -578,6 +611,7 @@ impl HookExecutor {
);
if hook.continue_on_error {
warning = message_submit_continue_warning(&result).or(warning);
continue;
}
@@ -607,9 +641,9 @@ impl HookExecutor {
}
if current_text == original_text {
MessageSubmitOutcome::Unchanged
MessageSubmitOutcome::unchanged().with_warning(warning)
} else {
MessageSubmitOutcome::Replaced(current_text)
MessageSubmitOutcome::replaced(current_text).with_warning(warning)
}
}
@@ -847,31 +881,29 @@ impl HookExecutor {
}
};
if let (Some(bytes), Some(mut stdin)) = (stdin_bytes, child.stdin.take()) {
let _ = stdin.write_all(&bytes);
let _ = stdin.write_all(b"\n");
let _ = stdin.flush();
}
fn read_pipe(mut pipe: impl Read) -> String {
let mut buf = String::new();
let _ = pipe.read_to_string(&mut buf);
buf
}
let stdout_reader = child.stdout.take().map(spawn_pipe_reader);
let stderr_reader = child.stderr.take().map(spawn_pipe_reader);
let _stdin_writer = match (stdin_bytes, child.stdin.take()) {
(Some(bytes), Some(stdin)) => Some(spawn_stdin_writer(stdin, bytes)),
_ => None,
};
match child.wait_timeout(timeout) {
Ok(Some(status)) => HookResult {
name: hook.name.clone(),
success: status.success(),
exit_code: status.code(),
stdout: child.stdout.take().map(read_pipe).unwrap_or_default(),
stderr: child.stderr.take().map(read_pipe).unwrap_or_default(),
stdout: join_reader(stdout_reader),
stderr: join_reader(stderr_reader),
duration: started.elapsed(),
error: None,
},
Ok(None) => {
let _ = child.kill();
let _ = child.wait();
// Do not join pipe threads on timeout: descendant processes can
// inherit pipe fds, and waiting for those threads would defeat
// the hook timeout we just enforced.
HookResult {
name: hook.name.clone(),
success: false,
@@ -882,15 +914,19 @@ impl HookExecutor {
error: Some(format!("Hook timed out after {timeout_secs}s")),
}
}
Err(e) => HookResult {
name: hook.name.clone(),
success: false,
exit_code: None,
stdout: String::new(),
stderr: String::new(),
duration: started.elapsed(),
error: Some(format!("Failed to wait for hook: {e}")),
},
Err(e) => {
let _ = child.kill();
let _ = child.wait();
HookResult {
name: hook.name.clone(),
success: false,
exit_code: None,
stdout: String::new(),
stderr: String::new(),
duration: started.elapsed(),
error: Some(format!("Failed to wait for hook: {e}")),
}
}
}
}
@@ -928,6 +964,28 @@ impl HookExecutor {
}
}
fn spawn_pipe_reader(mut pipe: impl Read + Send + 'static) -> JoinHandle<String> {
thread::spawn(move || {
let mut buf = String::new();
let _ = pipe.read_to_string(&mut buf);
buf
})
}
fn join_reader(reader: Option<JoinHandle<String>>) -> String {
reader
.and_then(|handle| handle.join().ok())
.unwrap_or_default()
}
fn spawn_stdin_writer(mut stdin: std::process::ChildStdin, mut bytes: Vec<u8>) -> JoinHandle<()> {
thread::spawn(move || {
bytes.push(b'\n');
let _ = stdin.write_all(&bytes);
let _ = stdin.flush();
})
}
fn message_submit_payload(context: &HookContext, text: &str) -> serde_json::Value {
json!({
"event": HookEvent::MessageSubmit.as_str(),
@@ -956,12 +1014,21 @@ fn parse_message_submit_stdout(stdout: &str) -> MessageSubmitStdout {
};
match object.get("text") {
Some(serde_json::Value::String(text)) => MessageSubmitStdout::Replaced(text.clone()),
Some(serde_json::Value::String(text)) if !text.is_empty() => {
MessageSubmitStdout::Replaced(text.clone())
}
Some(serde_json::Value::String(_)) => {
MessageSubmitStdout::Invalid("stdout `text` field must not be empty".to_string())
}
Some(_) => MessageSubmitStdout::Invalid("stdout `text` field must be a string".to_string()),
None => MessageSubmitStdout::Unchanged,
}
}
fn message_submit_continue_warning(result: &HookResult) -> Option<String> {
result.error.as_deref().and_then(first_non_empty_line)
}
fn message_submit_block_reason(result: &HookResult, fallback: &str) -> String {
if let Some(reason) = message_submit_stdout_reason(&result.stdout) {
return reason;
@@ -995,13 +1062,10 @@ fn first_non_empty_line(text: &str) -> Option<String> {
fn truncate_hook_message(message: &str) -> String {
const MAX_CHARS: usize = 240;
let mut out = String::new();
for (idx, ch) in message.chars().enumerate() {
if idx >= MAX_CHARS {
out.push('…');
return out;
}
out.push(ch);
let mut chars = message.chars();
let mut out: String = chars.by_ref().take(MAX_CHARS).collect();
if chars.next().is_some() {
out.push('…');
}
out
}
@@ -1128,6 +1192,14 @@ NOEQUAL line dropped
));
}
#[test]
fn parse_message_submit_stdout_rejects_empty_text() {
assert_eq!(
super::parse_message_submit_stdout(r#"{"text":""}"#),
MessageSubmitStdout::Invalid("stdout `text` field must not be empty".to_string())
);
}
#[test]
fn parse_message_submit_stdout_rejects_non_object_json() {
assert!(matches!(
@@ -1277,6 +1349,35 @@ NOEQUAL line dropped
);
}
#[cfg(not(windows))]
#[test]
fn message_submit_stdin_write_does_not_deadlock_when_hook_writes_first() {
let dir = tempfile::tempdir().expect("tempdir");
let command = write_hook_script(
&dir,
"write_before_read.sh",
r#"#!/bin/sh
dd if=/dev/zero bs=1024 count=256 2>/dev/null | tr '\000' x
dd if=/dev/zero bs=1024 count=256 2>/dev/null | tr '\000' e >&2
payload=$(cat)
printf '\ndone:%s\n' "${#payload}"
"#,
);
let hook = Hook::new(HookEvent::MessageSubmit, &command).with_timeout(5);
let executor = HookExecutor::new(HooksConfig::default(), dir.path().to_path_buf());
let env_vars = HashMap::new();
let payload = json!({
"event": "message_submit",
"text": "x".repeat(256 * 1024),
});
let result = executor.execute_sync_with_stdin(&hook, &env_vars, &payload);
assert!(result.success, "hook should complete: {result:?}");
assert!(result.stdout.contains("done:"), "stdout was drained");
assert!(result.stderr.len() >= 256 * 1024, "stderr was drained");
}
#[test]
fn test_executor_session_id() {
let executor = HookExecutor::new(HooksConfig::default(), PathBuf::from("."));
@@ -1337,7 +1438,7 @@ esac
assert_eq!(
executor.execute_message_submit_transform(&submit_context(&dir), "original"),
MessageSubmitOutcome::Replaced("first second".to_string())
MessageSubmitOutcome::replaced("first second".to_string())
);
}
@@ -1390,7 +1491,7 @@ printf '%s\n' '{"text":"ignored"}'
assert_eq!(
executor.execute_message_submit_transform(&submit_context(&dir), "original"),
MessageSubmitOutcome::Unchanged
MessageSubmitOutcome::unchanged()
);
}
@@ -1400,7 +1501,7 @@ printf '%s\n' '{"text":"ignored"}'
assert_eq!(
executor.execute_message_submit_transform(&HookContext::new(), "original"),
MessageSubmitOutcome::Unchanged
MessageSubmitOutcome::unchanged()
);
}
@@ -1429,7 +1530,7 @@ printf '%s\n' '{"text":"should not apply"}'
assert_eq!(
executor.execute_message_submit_transform(&submit_context(&dir), "original"),
MessageSubmitOutcome::Unchanged
MessageSubmitOutcome::unchanged()
);
}
@@ -1465,7 +1566,42 @@ printf '%s\n' '{"text":"recovered"}'
assert_eq!(
executor.execute_message_submit_transform(&submit_context(&dir), "original"),
MessageSubmitOutcome::Replaced("recovered".to_string())
MessageSubmitOutcome::replaced("recovered".to_string())
);
}
#[cfg(not(windows))]
#[test]
fn message_submit_timeout_continue_surfaces_warning_and_runs_later_hooks() {
let dir = tempfile::tempdir().expect("tempdir");
let slow = write_hook_script(
&dir,
"slow_continue.sh",
r#"#!/bin/sh
sleep 2
"#,
);
let replacing = write_hook_script(
&dir,
"replace_after_timeout.sh",
r#"#!/bin/sh
printf '%s\n' '{"text":"after timeout"}'
"#,
);
let mut slow_hook = Hook::new(HookEvent::MessageSubmit, &slow).with_timeout(1);
slow_hook.continue_on_error = true;
let config = HooksConfig {
enabled: true,
hooks: vec![slow_hook, Hook::new(HookEvent::MessageSubmit, &replacing)],
working_dir: Some(dir.path().to_path_buf()),
..HooksConfig::default()
};
let executor = HookExecutor::new(config, dir.path().to_path_buf());
assert_eq!(
executor.execute_message_submit_transform(&submit_context(&dir), "original"),
MessageSubmitOutcome::replaced("after timeout".to_string())
.with_warning(Some("Hook timed out after 1s".to_string()))
);
}
@@ -1500,7 +1636,7 @@ printf '%s\n' '{"text":"valid later"}'
assert_eq!(
executor.execute_message_submit_transform(&submit_context(&dir), "original"),
MessageSubmitOutcome::Replaced("valid later".to_string())
MessageSubmitOutcome::replaced("valid later".to_string())
);
}
+8 -5
View File
@@ -4511,12 +4511,15 @@ async fn dispatch_user_message(
.has_hooks_for_event(crate::hooks::HookEvent::MessageSubmit)
{
let context = app.base_hook_context().with_message(&message.display);
match app
let outcome = app
.hooks
.execute_message_submit_transform(&context, &message.display)
{
crate::hooks::MessageSubmitOutcome::Unchanged => {}
crate::hooks::MessageSubmitOutcome::Replaced(text) => {
.execute_message_submit_transform(&context, &message.display);
if let Some(warning) = outcome.warning() {
app.status_message = Some(warning.to_string());
}
match outcome {
crate::hooks::MessageSubmitOutcome::Unchanged { .. } => {}
crate::hooks::MessageSubmitOutcome::Replaced { text, .. } => {
message.display = text;
}
crate::hooks::MessageSubmitOutcome::Blocked { reason } => {
+67 -4
View File
@@ -2081,13 +2081,20 @@ fn write_message_submit_hook(dir: &TempDir, name: &str, body: &str) -> String {
#[cfg(not(windows))]
fn configure_single_message_submit_hook(app: &mut App, dir: &TempDir, command: String) {
configure_message_submit_hooks(app, dir, vec![command]);
}
#[cfg(not(windows))]
fn configure_message_submit_hooks(app: &mut App, dir: &TempDir, commands: Vec<String>) {
app.hooks = crate::hooks::HookExecutor::new(
crate::hooks::HooksConfig {
enabled: true,
hooks: vec![crate::hooks::Hook::new(
crate::hooks::HookEvent::MessageSubmit,
&command,
)],
hooks: commands
.iter()
.map(|command| {
crate::hooks::Hook::new(crate::hooks::HookEvent::MessageSubmit, command)
})
.collect(),
working_dir: Some(dir.path().to_path_buf()),
..crate::hooks::HooksConfig::default()
},
@@ -2095,6 +2102,62 @@ fn configure_single_message_submit_hook(app: &mut App, dir: &TempDir, command: S
);
}
#[cfg(not(windows))]
#[tokio::test]
async fn dispatch_user_message_surfaces_continued_message_submit_timeout() {
let dir = TempDir::new().expect("tempdir");
let slow = write_message_submit_hook(
&dir,
"slow.sh",
r#"#!/bin/sh
sleep 2
"#,
);
let replacing = write_message_submit_hook(
&dir,
"replace.sh",
r#"#!/bin/sh
printf '%s\n' '{"text":"after timeout"}'
"#,
);
let mut app = create_test_app();
app.hooks = crate::hooks::HookExecutor::new(
crate::hooks::HooksConfig {
enabled: true,
hooks: vec![
crate::hooks::Hook::new(crate::hooks::HookEvent::MessageSubmit, &slow)
.with_timeout(1),
crate::hooks::Hook::new(crate::hooks::HookEvent::MessageSubmit, &replacing),
],
working_dir: Some(dir.path().to_path_buf()),
..crate::hooks::HooksConfig::default()
},
dir.path().to_path_buf(),
);
let mut engine = crate::core::engine::mock_engine_handle();
let config = Config::default();
dispatch_user_message(
&mut app,
&config,
&engine.handle,
QueuedMessage::new("hello".to_string(), None),
)
.await
.expect("dispatch user message");
assert_eq!(
app.status_message.as_deref(),
Some("Hook timed out after 1s")
);
match engine.rx_op.recv().await.expect("send message op") {
crate::core::ops::Op::SendMessage { content, .. } => {
assert_eq!(content, "after timeout");
}
other => panic!("expected SendMessage, got {other:?}"),
}
}
#[cfg(not(windows))]
#[tokio::test]
async fn dispatch_user_message_uses_transformed_message_submit_text() {
+8 -5
View File
@@ -405,7 +405,7 @@ The hook receives JSON on stdin:
}
```
If the hook exits `0` and prints JSON with a string `text` field,
If the hook exits `0` and prints JSON with a non-empty string `text` field,
that value replaces the submitted text:
```json
@@ -413,10 +413,13 @@ that value replaces the submitted text:
```
Exit `0` with empty stdout, or stdout JSON without `text`, leaves
the current text unchanged. Exit `2` blocks the submission before
the turn starts; a `reason` field, stderr, or stdout can provide the
status message shown in the TUI. Other non-zero exits follow the
hook's `continue_on_error` setting.
the current text unchanged. A JSON `text` field must not be empty;
`{"text":""}` is treated as invalid stdout and ignored. Exit `2`
blocks the submission before the turn starts; a `reason` field,
stderr, or stdout can provide the status message shown in the TUI.
Other non-zero exits follow the hook's `continue_on_error` setting.
Timeouts and spawn failures are also surfaced as transient TUI status
messages when `continue_on_error = true` lets submission continue.
Multiple `message_submit` hooks run in config order, and each hook
receives the text produced by the previous hook. Hooks marked
+2 -2
View File
@@ -186,7 +186,7 @@ command = "~/.codewhale/hooks/pre.sh" # / message_submit / mode_change /
</p>
<p className="mt-3 text-sm text-ink-soft leading-[1.9]">
<code className="inline">message_submit</code> hooks run before a user message is sent to the model. A non-background hook can print
<code className="inline">{'{"text":"replacement"}'}</code> on stdout to replace the message, or exit with code <code className="inline">2</code> to block the submission.
<code className="inline">{'{"text":"replacement"}'}</code> on stdout to replace the message; <code className="inline">text</code> must be non-empty. Exit with code <code className="inline">2</code> to block the submission.
<code className="inline">shell_env</code> keeps its existing <code className="inline">KEY=VALUE</code> stdout contract.
</p>
</section>
@@ -441,7 +441,7 @@ command = "~/.codewhale/hooks/pre.sh" # / message_submit / mode_change /
</p>
<p className="mt-3 text-sm text-ink-soft leading-relaxed">
<code className="inline">message_submit</code> hooks run before a user message is sent to the model. A non-background hook can print
<code className="inline">{'{"text":"replacement"}'}</code> on stdout to replace the message, or exit with code <code className="inline">2</code> to block the submission.
<code className="inline">{'{"text":"replacement"}'}</code> on stdout to replace the message; <code className="inline">text</code> must be non-empty. Exit with code <code className="inline">2</code> to block the submission.
<code className="inline">shell_env</code> keeps its existing <code className="inline">KEY=VALUE</code> stdout contract.
</p>
</section>