Allow hooks to error (#11615)

Allow hooks to return errors. 

We should do this before introducing more hook types, or we'll have to
migrate them all.
This commit is contained in:
gt-oai
2026-02-16 14:11:05 +00:00
committed by GitHub
parent 825a4af42f
commit b3095679ed
6 changed files with 251 additions and 43 deletions

View File

@@ -44,6 +44,7 @@ use async_channel::Sender;
use codex_hooks::HookEvent;
use codex_hooks::HookEventAfterAgent;
use codex_hooks::HookPayload;
use codex_hooks::HookResult;
use codex_hooks::Hooks;
use codex_hooks::HooksConfig;
use codex_network_proxy::NetworkProxy;
@@ -4533,7 +4534,8 @@ pub(crate) async fn run_turn(
if !needs_follow_up {
last_agent_message = sampling_request_last_agent_message;
sess.hooks()
let hook_outcomes = sess
.hooks()
.dispatch(HookPayload {
session_id: sess.conversation_id,
cwd: turn_context.cwd.clone(),
@@ -4548,6 +4550,47 @@ pub(crate) async fn run_turn(
},
})
.await;
let mut abort_message = None;
for hook_outcome in hook_outcomes {
let hook_name = hook_outcome.hook_name;
match hook_outcome.result {
HookResult::Success => {}
HookResult::FailedContinue(error) => {
warn!(
turn_id = %turn_context.sub_id,
hook_name = %hook_name,
error = %error,
"after_agent hook failed; continuing"
);
}
HookResult::FailedAbort(error) => {
let message = format!(
"after_agent hook '{hook_name}' failed and aborted turn completion: {error}"
);
warn!(
turn_id = %turn_context.sub_id,
hook_name = %hook_name,
error = %error,
"after_agent hook failed; aborting operation"
);
if abort_message.is_none() {
abort_message = Some(message);
}
}
}
}
if let Some(message) = abort_message {
sess.send_event(
&turn_context,
EventMsg::Error(ErrorEvent {
message,
codex_error_info: None,
}),
)
.await;
return None;
}
break;
}
continue;

View File

@@ -15,6 +15,7 @@ use async_trait::async_trait;
use codex_hooks::HookEvent;
use codex_hooks::HookEventAfterToolUse;
use codex_hooks::HookPayload;
use codex_hooks::HookResult;
use codex_hooks::HookToolInput;
use codex_hooks::HookToolInputLocalShell;
use codex_hooks::HookToolKind;
@@ -172,7 +173,7 @@ impl ToolRegistry {
Ok((preview, success)) => (preview.clone(), *success),
Err(err) => (err.to_string(), false),
};
dispatch_after_tool_use_hook(AfterToolUseHookDispatch {
let hook_abort_error = dispatch_after_tool_use_hook(AfterToolUseHookDispatch {
invocation: &invocation,
output_preview,
success,
@@ -182,6 +183,10 @@ impl ToolRegistry {
})
.await;
if let Some(err) = hook_abort_error {
return Err(err);
}
match result {
Ok(_) => {
let mut guard = output_cell.lock().await;
@@ -339,12 +344,14 @@ struct AfterToolUseHookDispatch<'a> {
mutating: bool,
}
async fn dispatch_after_tool_use_hook(dispatch: AfterToolUseHookDispatch<'_>) {
async fn dispatch_after_tool_use_hook(
dispatch: AfterToolUseHookDispatch<'_>,
) -> Option<FunctionCallError> {
let AfterToolUseHookDispatch { invocation, .. } = dispatch;
let session = invocation.session.as_ref();
let turn = invocation.turn.as_ref();
let tool_input = HookToolInput::from(&invocation.payload);
session
let hook_outcomes = session
.hooks()
.dispatch(HookPayload {
session_id: session.conversation_id,
@@ -373,4 +380,34 @@ async fn dispatch_after_tool_use_hook(dispatch: AfterToolUseHookDispatch<'_>) {
},
})
.await;
for hook_outcome in hook_outcomes {
let hook_name = hook_outcome.hook_name;
match hook_outcome.result {
HookResult::Success => {}
HookResult::FailedContinue(error) => {
warn!(
call_id = %invocation.call_id,
tool_name = %invocation.tool_name,
hook_name = %hook_name,
error = %error,
"after_tool_use hook failed; continuing"
);
}
HookResult::FailedAbort(error) => {
warn!(
call_id = %invocation.call_id,
tool_name = %invocation.tool_name,
hook_name = %hook_name,
error = %error,
"after_tool_use hook failed; aborting operation"
);
return Some(FunctionCallError::Fatal(format!(
"after_tool_use hook '{hook_name}' failed and aborted operation: {error}"
)));
}
}
}
None
}

View File

@@ -9,8 +9,9 @@ pub use types::Hook;
pub use types::HookEvent;
pub use types::HookEventAfterAgent;
pub use types::HookEventAfterToolUse;
pub use types::HookOutcome;
pub use types::HookPayload;
pub use types::HookResponse;
pub use types::HookResult;
pub use types::HookToolInput;
pub use types::HookToolInputLocalShell;
pub use types::HookToolKind;

View File

@@ -2,8 +2,8 @@ use tokio::process::Command;
use crate::types::Hook;
use crate::types::HookEvent;
use crate::types::HookOutcome;
use crate::types::HookPayload;
use crate::types::HookResponse;
#[derive(Default, Clone)]
pub struct HooksConfig {
@@ -45,14 +45,19 @@ impl Hooks {
}
}
pub async fn dispatch(&self, hook_payload: HookPayload) {
// TODO(gt): support interrupting program execution by returning a result here.
for hook in self.hooks_for_event(&hook_payload.hook_event) {
pub async fn dispatch(&self, hook_payload: HookPayload) -> Vec<HookResponse> {
let hooks = self.hooks_for_event(&hook_payload.hook_event);
let mut outcomes = Vec::with_capacity(hooks.len());
for hook in hooks {
let outcome = hook.execute(&hook_payload).await;
if matches!(outcome, HookOutcome::Stop) {
let should_abort_operation = outcome.result.should_abort_operation();
outcomes.push(outcome);
if should_abort_operation {
break;
}
}
outcomes
}
}
@@ -88,6 +93,7 @@ mod tests {
use super::*;
use crate::types::HookEventAfterAgent;
use crate::types::HookEventAfterToolUse;
use crate::types::HookResult;
use crate::types::HookToolInput;
use crate::types::HookToolKind;
@@ -113,14 +119,50 @@ mod tests {
}
}
fn counting_hook(calls: &Arc<AtomicUsize>, outcome: HookOutcome) -> Hook {
fn counting_success_hook(calls: &Arc<AtomicUsize>, name: &str) -> Hook {
let hook_name = name.to_string();
let calls = Arc::clone(calls);
Hook {
name: hook_name,
func: Arc::new(move |_| {
let calls = Arc::clone(&calls);
Box::pin(async move {
calls.fetch_add(1, Ordering::SeqCst);
outcome
HookResult::Success
})
}),
}
}
fn failing_continue_hook(calls: &Arc<AtomicUsize>, name: &str, message: &str) -> Hook {
let hook_name = name.to_string();
let message = message.to_string();
let calls = Arc::clone(calls);
Hook {
name: hook_name,
func: Arc::new(move |_| {
let calls = Arc::clone(&calls);
let message = message.clone();
Box::pin(async move {
calls.fetch_add(1, Ordering::SeqCst);
HookResult::FailedContinue(std::io::Error::other(message).into())
})
}),
}
}
fn failing_abort_hook(calls: &Arc<AtomicUsize>, name: &str, message: &str) -> Hook {
let hook_name = name.to_string();
let message = message.to_string();
let calls = Arc::clone(calls);
Hook {
name: hook_name,
func: Arc::new(move |_| {
let calls = Arc::clone(&calls);
let message = message.clone();
Box::pin(async move {
calls.fetch_add(1, Ordering::SeqCst);
HookResult::FailedAbort(std::io::Error::other(message).into())
})
}),
}
@@ -212,11 +254,14 @@ mod tests {
async fn dispatch_executes_hook() {
let calls = Arc::new(AtomicUsize::new(0));
let hooks = Hooks {
after_agent: vec![counting_hook(&calls, HookOutcome::Continue)],
after_agent: vec![counting_success_hook(&calls, "counting")],
..Hooks::default()
};
hooks.dispatch(hook_payload("1")).await;
let outcomes = hooks.dispatch(hook_payload("1")).await;
assert_eq!(outcomes.len(), 1);
assert_eq!(outcomes[0].hook_name, "counting");
assert!(matches!(outcomes[0].result, HookResult::Success));
assert_eq!(calls.load(Ordering::SeqCst), 1);
}
@@ -224,7 +269,8 @@ mod tests {
async fn default_hook_is_noop_and_continues() {
let payload = hook_payload("d");
let outcome = Hook::default().execute(&payload).await;
assert_eq!(outcome, HookOutcome::Continue);
assert_eq!(outcome.hook_name, "default");
assert!(matches!(outcome.result, HookResult::Success));
}
#[tokio::test]
@@ -232,28 +278,36 @@ mod tests {
let calls = Arc::new(AtomicUsize::new(0));
let hooks = Hooks {
after_agent: vec![
counting_hook(&calls, HookOutcome::Continue),
counting_hook(&calls, HookOutcome::Continue),
counting_success_hook(&calls, "counting-1"),
counting_success_hook(&calls, "counting-2"),
],
..Hooks::default()
};
hooks.dispatch(hook_payload("2")).await;
let outcomes = hooks.dispatch(hook_payload("2")).await;
assert_eq!(outcomes.len(), 2);
assert_eq!(outcomes[0].hook_name, "counting-1");
assert_eq!(outcomes[1].hook_name, "counting-2");
assert!(matches!(outcomes[0].result, HookResult::Success));
assert!(matches!(outcomes[1].result, HookResult::Success));
assert_eq!(calls.load(Ordering::SeqCst), 2);
}
#[tokio::test]
async fn dispatch_stops_when_hook_returns_stop() {
async fn dispatch_stops_when_hook_requests_abort() {
let calls = Arc::new(AtomicUsize::new(0));
let hooks = Hooks {
after_agent: vec![
counting_hook(&calls, HookOutcome::Stop),
counting_hook(&calls, HookOutcome::Continue),
failing_abort_hook(&calls, "abort", "hook failed"),
counting_success_hook(&calls, "counting"),
],
..Hooks::default()
};
hooks.dispatch(hook_payload("3")).await;
let outcomes = hooks.dispatch(hook_payload("3")).await;
assert_eq!(outcomes.len(), 1);
assert_eq!(outcomes[0].hook_name, "abort");
assert!(matches!(outcomes[0].result, HookResult::FailedAbort(_)));
assert_eq!(calls.load(Ordering::SeqCst), 1);
}
@@ -261,11 +315,53 @@ mod tests {
async fn dispatch_executes_after_tool_use_hooks() {
let calls = Arc::new(AtomicUsize::new(0));
let hooks = Hooks {
after_tool_use: vec![counting_hook(&calls, HookOutcome::Continue)],
after_tool_use: vec![counting_success_hook(&calls, "counting")],
..Hooks::default()
};
hooks.dispatch(after_tool_use_payload("p")).await;
let outcomes = hooks.dispatch(after_tool_use_payload("p")).await;
assert_eq!(outcomes.len(), 1);
assert_eq!(outcomes[0].hook_name, "counting");
assert!(matches!(outcomes[0].result, HookResult::Success));
assert_eq!(calls.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn dispatch_continues_after_continueable_failure() {
let calls = Arc::new(AtomicUsize::new(0));
let hooks = Hooks {
after_agent: vec![
failing_continue_hook(&calls, "failing", "hook failed"),
counting_success_hook(&calls, "counting"),
],
..Hooks::default()
};
let outcomes = hooks.dispatch(hook_payload("err")).await;
assert_eq!(outcomes.len(), 2);
assert_eq!(outcomes[0].hook_name, "failing");
assert!(matches!(outcomes[0].result, HookResult::FailedContinue(_)));
assert_eq!(outcomes[1].hook_name, "counting");
assert!(matches!(outcomes[1].result, HookResult::Success));
assert_eq!(calls.load(Ordering::SeqCst), 2);
}
#[tokio::test]
async fn dispatch_returns_after_tool_use_failure_outcome() {
let calls = Arc::new(AtomicUsize::new(0));
let hooks = Hooks {
after_tool_use: vec![failing_continue_hook(
&calls,
"failing",
"after_tool_use hook failed",
)],
..Hooks::default()
};
let outcomes = hooks.dispatch(after_tool_use_payload("err-tool")).await;
assert_eq!(outcomes.len(), 1);
assert_eq!(outcomes[0].hook_name, "failing");
assert!(matches!(outcomes[0].result, HookResult::FailedContinue(_)));
assert_eq!(calls.load(Ordering::SeqCst), 1);
}
@@ -276,6 +372,7 @@ mod tests {
let payload_path = temp_dir.path().join("payload.json");
let payload_path_arg = payload_path.to_string_lossy().into_owned();
let hook = Hook {
name: "write_payload".to_string(),
func: Arc::new(move |payload: &HookPayload| {
let payload_path_arg = payload_path_arg.clone();
Box::pin(async move {
@@ -290,7 +387,7 @@ mod tests {
])
.expect("build command");
command.status().await.expect("run hook command");
HookOutcome::Continue
HookResult::Success
})
}),
};
@@ -302,7 +399,9 @@ mod tests {
after_agent: vec![hook],
..Hooks::default()
};
hooks.dispatch(payload).await;
let outcomes = hooks.dispatch(payload).await;
assert_eq!(outcomes.len(), 1);
assert!(matches!(outcomes[0].result, HookResult::Success));
let contents = timeout(Duration::from_secs(2), async {
loop {
@@ -330,6 +429,7 @@ mod tests {
fs::write(&script_path, "[IO.File]::WriteAllText($args[0], $args[1])")?;
let script_path_arg = script_path.to_string_lossy().into_owned();
let hook = Hook {
name: "write_payload".to_string(),
func: Arc::new(move |payload: &HookPayload| {
let payload_path_arg = payload_path_arg.clone();
let script_path_arg = script_path_arg.clone();
@@ -348,7 +448,7 @@ mod tests {
])
.expect("build command");
command.status().await.expect("run hook command");
HookOutcome::Continue
HookResult::Success
})
}),
};
@@ -360,7 +460,9 @@ mod tests {
after_agent: vec![hook],
..Hooks::default()
};
hooks.dispatch(payload).await;
let outcomes = hooks.dispatch(payload).await;
assert_eq!(outcomes.len(), 1);
assert!(matches!(outcomes[0].result, HookResult::Success));
let contents = timeout(Duration::from_secs(2), async {
loop {

View File

@@ -10,24 +10,53 @@ use futures::future::BoxFuture;
use serde::Serialize;
use serde::Serializer;
pub type HookFn = Arc<dyn for<'a> Fn(&'a HookPayload) -> BoxFuture<'a, HookOutcome> + Send + Sync>;
pub type HookFn = Arc<dyn for<'a> Fn(&'a HookPayload) -> BoxFuture<'a, HookResult> + Send + Sync>;
#[derive(Debug)]
pub enum HookResult {
/// Success: hook completed successfully.
Success,
/// FailedContinue: hook failed, but other subsequent hooks should still execute and the
/// operation should continue.
FailedContinue(Box<dyn std::error::Error + Send + Sync + 'static>),
/// FailedAbort: hook failed, other subsequent hooks should not execute, and the operation
/// should be aborted.
FailedAbort(Box<dyn std::error::Error + Send + Sync + 'static>),
}
impl HookResult {
pub fn should_abort_operation(&self) -> bool {
matches!(self, Self::FailedAbort(_))
}
}
#[derive(Debug)]
pub struct HookResponse {
pub hook_name: String,
pub result: HookResult,
}
#[derive(Clone)]
pub struct Hook {
pub name: String,
pub func: HookFn,
}
impl Default for Hook {
fn default() -> Self {
Self {
func: Arc::new(|_| Box::pin(async { HookOutcome::Continue })),
name: "default".to_string(),
func: Arc::new(|_| Box::pin(async { HookResult::Success })),
}
}
}
impl Hook {
pub async fn execute(&self, payload: &HookPayload) -> HookOutcome {
(self.func)(payload).await
pub async fn execute(&self, payload: &HookPayload) -> HookResponse {
HookResponse {
hook_name: self.name.clone(),
result: (self.func)(payload).await,
}
}
}
@@ -126,13 +155,6 @@ pub enum HookEvent {
},
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HookOutcome {
Continue,
#[allow(dead_code)]
Stop,
}
#[cfg(test)]
mod tests {
use std::path::PathBuf;

View File

@@ -6,8 +6,8 @@ use serde::Serialize;
use crate::Hook;
use crate::HookEvent;
use crate::HookOutcome;
use crate::HookPayload;
use crate::HookResult;
use crate::command_from_argv;
/// Legacy notify payload appended as the final argv argument for backward compatibility.
@@ -48,12 +48,13 @@ pub fn legacy_notify_json(hook_event: &HookEvent, cwd: &Path) -> Result<String,
pub fn notify_hook(argv: Vec<String>) -> Hook {
let argv = Arc::new(argv);
Hook {
name: "legacy_notify".to_string(),
func: Arc::new(move |payload: &HookPayload| {
let argv = Arc::clone(&argv);
Box::pin(async move {
let mut command = match command_from_argv(&argv) {
Some(command) => command,
None => return HookOutcome::Continue,
None => return HookResult::Success,
};
if let Ok(notify_payload) = legacy_notify_json(&payload.hook_event, &payload.cwd) {
command.arg(notify_payload);
@@ -65,8 +66,10 @@ pub fn notify_hook(argv: Vec<String>) -> Hook {
.stdout(Stdio::null())
.stderr(Stdio::null());
let _ = command.spawn();
HookOutcome::Continue
match command.spawn() {
Ok(_) => HookResult::Success,
Err(err) => HookResult::FailedContinue(err.into()),
}
})
}),
}