mirror of
https://github.com/openai/codex.git
synced 2026-04-27 09:51:03 +03:00
Allow hooks to error (#11615)
Allow hooks to return errors. We should do this before introducing more hook types, or we'll have to migrate them all.
This commit is contained in:
@@ -44,6 +44,7 @@ use async_channel::Sender;
|
||||
use codex_hooks::HookEvent;
|
||||
use codex_hooks::HookEventAfterAgent;
|
||||
use codex_hooks::HookPayload;
|
||||
use codex_hooks::HookResult;
|
||||
use codex_hooks::Hooks;
|
||||
use codex_hooks::HooksConfig;
|
||||
use codex_network_proxy::NetworkProxy;
|
||||
@@ -4533,7 +4534,8 @@ pub(crate) async fn run_turn(
|
||||
|
||||
if !needs_follow_up {
|
||||
last_agent_message = sampling_request_last_agent_message;
|
||||
sess.hooks()
|
||||
let hook_outcomes = sess
|
||||
.hooks()
|
||||
.dispatch(HookPayload {
|
||||
session_id: sess.conversation_id,
|
||||
cwd: turn_context.cwd.clone(),
|
||||
@@ -4548,6 +4550,47 @@ pub(crate) async fn run_turn(
|
||||
},
|
||||
})
|
||||
.await;
|
||||
|
||||
let mut abort_message = None;
|
||||
for hook_outcome in hook_outcomes {
|
||||
let hook_name = hook_outcome.hook_name;
|
||||
match hook_outcome.result {
|
||||
HookResult::Success => {}
|
||||
HookResult::FailedContinue(error) => {
|
||||
warn!(
|
||||
turn_id = %turn_context.sub_id,
|
||||
hook_name = %hook_name,
|
||||
error = %error,
|
||||
"after_agent hook failed; continuing"
|
||||
);
|
||||
}
|
||||
HookResult::FailedAbort(error) => {
|
||||
let message = format!(
|
||||
"after_agent hook '{hook_name}' failed and aborted turn completion: {error}"
|
||||
);
|
||||
warn!(
|
||||
turn_id = %turn_context.sub_id,
|
||||
hook_name = %hook_name,
|
||||
error = %error,
|
||||
"after_agent hook failed; aborting operation"
|
||||
);
|
||||
if abort_message.is_none() {
|
||||
abort_message = Some(message);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(message) = abort_message {
|
||||
sess.send_event(
|
||||
&turn_context,
|
||||
EventMsg::Error(ErrorEvent {
|
||||
message,
|
||||
codex_error_info: None,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
return None;
|
||||
}
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
|
||||
@@ -15,6 +15,7 @@ use async_trait::async_trait;
|
||||
use codex_hooks::HookEvent;
|
||||
use codex_hooks::HookEventAfterToolUse;
|
||||
use codex_hooks::HookPayload;
|
||||
use codex_hooks::HookResult;
|
||||
use codex_hooks::HookToolInput;
|
||||
use codex_hooks::HookToolInputLocalShell;
|
||||
use codex_hooks::HookToolKind;
|
||||
@@ -172,7 +173,7 @@ impl ToolRegistry {
|
||||
Ok((preview, success)) => (preview.clone(), *success),
|
||||
Err(err) => (err.to_string(), false),
|
||||
};
|
||||
dispatch_after_tool_use_hook(AfterToolUseHookDispatch {
|
||||
let hook_abort_error = dispatch_after_tool_use_hook(AfterToolUseHookDispatch {
|
||||
invocation: &invocation,
|
||||
output_preview,
|
||||
success,
|
||||
@@ -182,6 +183,10 @@ impl ToolRegistry {
|
||||
})
|
||||
.await;
|
||||
|
||||
if let Some(err) = hook_abort_error {
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
match result {
|
||||
Ok(_) => {
|
||||
let mut guard = output_cell.lock().await;
|
||||
@@ -339,12 +344,14 @@ struct AfterToolUseHookDispatch<'a> {
|
||||
mutating: bool,
|
||||
}
|
||||
|
||||
async fn dispatch_after_tool_use_hook(dispatch: AfterToolUseHookDispatch<'_>) {
|
||||
async fn dispatch_after_tool_use_hook(
|
||||
dispatch: AfterToolUseHookDispatch<'_>,
|
||||
) -> Option<FunctionCallError> {
|
||||
let AfterToolUseHookDispatch { invocation, .. } = dispatch;
|
||||
let session = invocation.session.as_ref();
|
||||
let turn = invocation.turn.as_ref();
|
||||
let tool_input = HookToolInput::from(&invocation.payload);
|
||||
session
|
||||
let hook_outcomes = session
|
||||
.hooks()
|
||||
.dispatch(HookPayload {
|
||||
session_id: session.conversation_id,
|
||||
@@ -373,4 +380,34 @@ async fn dispatch_after_tool_use_hook(dispatch: AfterToolUseHookDispatch<'_>) {
|
||||
},
|
||||
})
|
||||
.await;
|
||||
|
||||
for hook_outcome in hook_outcomes {
|
||||
let hook_name = hook_outcome.hook_name;
|
||||
match hook_outcome.result {
|
||||
HookResult::Success => {}
|
||||
HookResult::FailedContinue(error) => {
|
||||
warn!(
|
||||
call_id = %invocation.call_id,
|
||||
tool_name = %invocation.tool_name,
|
||||
hook_name = %hook_name,
|
||||
error = %error,
|
||||
"after_tool_use hook failed; continuing"
|
||||
);
|
||||
}
|
||||
HookResult::FailedAbort(error) => {
|
||||
warn!(
|
||||
call_id = %invocation.call_id,
|
||||
tool_name = %invocation.tool_name,
|
||||
hook_name = %hook_name,
|
||||
error = %error,
|
||||
"after_tool_use hook failed; aborting operation"
|
||||
);
|
||||
return Some(FunctionCallError::Fatal(format!(
|
||||
"after_tool_use hook '{hook_name}' failed and aborted operation: {error}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
@@ -9,8 +9,9 @@ pub use types::Hook;
|
||||
pub use types::HookEvent;
|
||||
pub use types::HookEventAfterAgent;
|
||||
pub use types::HookEventAfterToolUse;
|
||||
pub use types::HookOutcome;
|
||||
pub use types::HookPayload;
|
||||
pub use types::HookResponse;
|
||||
pub use types::HookResult;
|
||||
pub use types::HookToolInput;
|
||||
pub use types::HookToolInputLocalShell;
|
||||
pub use types::HookToolKind;
|
||||
|
||||
@@ -2,8 +2,8 @@ use tokio::process::Command;
|
||||
|
||||
use crate::types::Hook;
|
||||
use crate::types::HookEvent;
|
||||
use crate::types::HookOutcome;
|
||||
use crate::types::HookPayload;
|
||||
use crate::types::HookResponse;
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
pub struct HooksConfig {
|
||||
@@ -45,14 +45,19 @@ impl Hooks {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn dispatch(&self, hook_payload: HookPayload) {
|
||||
// TODO(gt): support interrupting program execution by returning a result here.
|
||||
for hook in self.hooks_for_event(&hook_payload.hook_event) {
|
||||
pub async fn dispatch(&self, hook_payload: HookPayload) -> Vec<HookResponse> {
|
||||
let hooks = self.hooks_for_event(&hook_payload.hook_event);
|
||||
let mut outcomes = Vec::with_capacity(hooks.len());
|
||||
for hook in hooks {
|
||||
let outcome = hook.execute(&hook_payload).await;
|
||||
if matches!(outcome, HookOutcome::Stop) {
|
||||
let should_abort_operation = outcome.result.should_abort_operation();
|
||||
outcomes.push(outcome);
|
||||
if should_abort_operation {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
outcomes
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,6 +93,7 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::types::HookEventAfterAgent;
|
||||
use crate::types::HookEventAfterToolUse;
|
||||
use crate::types::HookResult;
|
||||
use crate::types::HookToolInput;
|
||||
use crate::types::HookToolKind;
|
||||
|
||||
@@ -113,14 +119,50 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn counting_hook(calls: &Arc<AtomicUsize>, outcome: HookOutcome) -> Hook {
|
||||
fn counting_success_hook(calls: &Arc<AtomicUsize>, name: &str) -> Hook {
|
||||
let hook_name = name.to_string();
|
||||
let calls = Arc::clone(calls);
|
||||
Hook {
|
||||
name: hook_name,
|
||||
func: Arc::new(move |_| {
|
||||
let calls = Arc::clone(&calls);
|
||||
Box::pin(async move {
|
||||
calls.fetch_add(1, Ordering::SeqCst);
|
||||
outcome
|
||||
HookResult::Success
|
||||
})
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn failing_continue_hook(calls: &Arc<AtomicUsize>, name: &str, message: &str) -> Hook {
|
||||
let hook_name = name.to_string();
|
||||
let message = message.to_string();
|
||||
let calls = Arc::clone(calls);
|
||||
Hook {
|
||||
name: hook_name,
|
||||
func: Arc::new(move |_| {
|
||||
let calls = Arc::clone(&calls);
|
||||
let message = message.clone();
|
||||
Box::pin(async move {
|
||||
calls.fetch_add(1, Ordering::SeqCst);
|
||||
HookResult::FailedContinue(std::io::Error::other(message).into())
|
||||
})
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn failing_abort_hook(calls: &Arc<AtomicUsize>, name: &str, message: &str) -> Hook {
|
||||
let hook_name = name.to_string();
|
||||
let message = message.to_string();
|
||||
let calls = Arc::clone(calls);
|
||||
Hook {
|
||||
name: hook_name,
|
||||
func: Arc::new(move |_| {
|
||||
let calls = Arc::clone(&calls);
|
||||
let message = message.clone();
|
||||
Box::pin(async move {
|
||||
calls.fetch_add(1, Ordering::SeqCst);
|
||||
HookResult::FailedAbort(std::io::Error::other(message).into())
|
||||
})
|
||||
}),
|
||||
}
|
||||
@@ -212,11 +254,14 @@ mod tests {
|
||||
async fn dispatch_executes_hook() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let hooks = Hooks {
|
||||
after_agent: vec![counting_hook(&calls, HookOutcome::Continue)],
|
||||
after_agent: vec![counting_success_hook(&calls, "counting")],
|
||||
..Hooks::default()
|
||||
};
|
||||
|
||||
hooks.dispatch(hook_payload("1")).await;
|
||||
let outcomes = hooks.dispatch(hook_payload("1")).await;
|
||||
assert_eq!(outcomes.len(), 1);
|
||||
assert_eq!(outcomes[0].hook_name, "counting");
|
||||
assert!(matches!(outcomes[0].result, HookResult::Success));
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
@@ -224,7 +269,8 @@ mod tests {
|
||||
async fn default_hook_is_noop_and_continues() {
|
||||
let payload = hook_payload("d");
|
||||
let outcome = Hook::default().execute(&payload).await;
|
||||
assert_eq!(outcome, HookOutcome::Continue);
|
||||
assert_eq!(outcome.hook_name, "default");
|
||||
assert!(matches!(outcome.result, HookResult::Success));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -232,28 +278,36 @@ mod tests {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let hooks = Hooks {
|
||||
after_agent: vec![
|
||||
counting_hook(&calls, HookOutcome::Continue),
|
||||
counting_hook(&calls, HookOutcome::Continue),
|
||||
counting_success_hook(&calls, "counting-1"),
|
||||
counting_success_hook(&calls, "counting-2"),
|
||||
],
|
||||
..Hooks::default()
|
||||
};
|
||||
|
||||
hooks.dispatch(hook_payload("2")).await;
|
||||
let outcomes = hooks.dispatch(hook_payload("2")).await;
|
||||
assert_eq!(outcomes.len(), 2);
|
||||
assert_eq!(outcomes[0].hook_name, "counting-1");
|
||||
assert_eq!(outcomes[1].hook_name, "counting-2");
|
||||
assert!(matches!(outcomes[0].result, HookResult::Success));
|
||||
assert!(matches!(outcomes[1].result, HookResult::Success));
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatch_stops_when_hook_returns_stop() {
|
||||
async fn dispatch_stops_when_hook_requests_abort() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let hooks = Hooks {
|
||||
after_agent: vec![
|
||||
counting_hook(&calls, HookOutcome::Stop),
|
||||
counting_hook(&calls, HookOutcome::Continue),
|
||||
failing_abort_hook(&calls, "abort", "hook failed"),
|
||||
counting_success_hook(&calls, "counting"),
|
||||
],
|
||||
..Hooks::default()
|
||||
};
|
||||
|
||||
hooks.dispatch(hook_payload("3")).await;
|
||||
let outcomes = hooks.dispatch(hook_payload("3")).await;
|
||||
assert_eq!(outcomes.len(), 1);
|
||||
assert_eq!(outcomes[0].hook_name, "abort");
|
||||
assert!(matches!(outcomes[0].result, HookResult::FailedAbort(_)));
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
@@ -261,11 +315,53 @@ mod tests {
|
||||
async fn dispatch_executes_after_tool_use_hooks() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let hooks = Hooks {
|
||||
after_tool_use: vec![counting_hook(&calls, HookOutcome::Continue)],
|
||||
after_tool_use: vec![counting_success_hook(&calls, "counting")],
|
||||
..Hooks::default()
|
||||
};
|
||||
|
||||
hooks.dispatch(after_tool_use_payload("p")).await;
|
||||
let outcomes = hooks.dispatch(after_tool_use_payload("p")).await;
|
||||
assert_eq!(outcomes.len(), 1);
|
||||
assert_eq!(outcomes[0].hook_name, "counting");
|
||||
assert!(matches!(outcomes[0].result, HookResult::Success));
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatch_continues_after_continueable_failure() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let hooks = Hooks {
|
||||
after_agent: vec![
|
||||
failing_continue_hook(&calls, "failing", "hook failed"),
|
||||
counting_success_hook(&calls, "counting"),
|
||||
],
|
||||
..Hooks::default()
|
||||
};
|
||||
|
||||
let outcomes = hooks.dispatch(hook_payload("err")).await;
|
||||
assert_eq!(outcomes.len(), 2);
|
||||
assert_eq!(outcomes[0].hook_name, "failing");
|
||||
assert!(matches!(outcomes[0].result, HookResult::FailedContinue(_)));
|
||||
assert_eq!(outcomes[1].hook_name, "counting");
|
||||
assert!(matches!(outcomes[1].result, HookResult::Success));
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatch_returns_after_tool_use_failure_outcome() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let hooks = Hooks {
|
||||
after_tool_use: vec![failing_continue_hook(
|
||||
&calls,
|
||||
"failing",
|
||||
"after_tool_use hook failed",
|
||||
)],
|
||||
..Hooks::default()
|
||||
};
|
||||
|
||||
let outcomes = hooks.dispatch(after_tool_use_payload("err-tool")).await;
|
||||
assert_eq!(outcomes.len(), 1);
|
||||
assert_eq!(outcomes[0].hook_name, "failing");
|
||||
assert!(matches!(outcomes[0].result, HookResult::FailedContinue(_)));
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
@@ -276,6 +372,7 @@ mod tests {
|
||||
let payload_path = temp_dir.path().join("payload.json");
|
||||
let payload_path_arg = payload_path.to_string_lossy().into_owned();
|
||||
let hook = Hook {
|
||||
name: "write_payload".to_string(),
|
||||
func: Arc::new(move |payload: &HookPayload| {
|
||||
let payload_path_arg = payload_path_arg.clone();
|
||||
Box::pin(async move {
|
||||
@@ -290,7 +387,7 @@ mod tests {
|
||||
])
|
||||
.expect("build command");
|
||||
command.status().await.expect("run hook command");
|
||||
HookOutcome::Continue
|
||||
HookResult::Success
|
||||
})
|
||||
}),
|
||||
};
|
||||
@@ -302,7 +399,9 @@ mod tests {
|
||||
after_agent: vec![hook],
|
||||
..Hooks::default()
|
||||
};
|
||||
hooks.dispatch(payload).await;
|
||||
let outcomes = hooks.dispatch(payload).await;
|
||||
assert_eq!(outcomes.len(), 1);
|
||||
assert!(matches!(outcomes[0].result, HookResult::Success));
|
||||
|
||||
let contents = timeout(Duration::from_secs(2), async {
|
||||
loop {
|
||||
@@ -330,6 +429,7 @@ mod tests {
|
||||
fs::write(&script_path, "[IO.File]::WriteAllText($args[0], $args[1])")?;
|
||||
let script_path_arg = script_path.to_string_lossy().into_owned();
|
||||
let hook = Hook {
|
||||
name: "write_payload".to_string(),
|
||||
func: Arc::new(move |payload: &HookPayload| {
|
||||
let payload_path_arg = payload_path_arg.clone();
|
||||
let script_path_arg = script_path_arg.clone();
|
||||
@@ -348,7 +448,7 @@ mod tests {
|
||||
])
|
||||
.expect("build command");
|
||||
command.status().await.expect("run hook command");
|
||||
HookOutcome::Continue
|
||||
HookResult::Success
|
||||
})
|
||||
}),
|
||||
};
|
||||
@@ -360,7 +460,9 @@ mod tests {
|
||||
after_agent: vec![hook],
|
||||
..Hooks::default()
|
||||
};
|
||||
hooks.dispatch(payload).await;
|
||||
let outcomes = hooks.dispatch(payload).await;
|
||||
assert_eq!(outcomes.len(), 1);
|
||||
assert!(matches!(outcomes[0].result, HookResult::Success));
|
||||
|
||||
let contents = timeout(Duration::from_secs(2), async {
|
||||
loop {
|
||||
|
||||
@@ -10,24 +10,53 @@ use futures::future::BoxFuture;
|
||||
use serde::Serialize;
|
||||
use serde::Serializer;
|
||||
|
||||
pub type HookFn = Arc<dyn for<'a> Fn(&'a HookPayload) -> BoxFuture<'a, HookOutcome> + Send + Sync>;
|
||||
pub type HookFn = Arc<dyn for<'a> Fn(&'a HookPayload) -> BoxFuture<'a, HookResult> + Send + Sync>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum HookResult {
|
||||
/// Success: hook completed successfully.
|
||||
Success,
|
||||
/// FailedContinue: hook failed, but other subsequent hooks should still execute and the
|
||||
/// operation should continue.
|
||||
FailedContinue(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
/// FailedAbort: hook failed, other subsequent hooks should not execute, and the operation
|
||||
/// should be aborted.
|
||||
FailedAbort(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
}
|
||||
|
||||
impl HookResult {
|
||||
pub fn should_abort_operation(&self) -> bool {
|
||||
matches!(self, Self::FailedAbort(_))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct HookResponse {
|
||||
pub hook_name: String,
|
||||
pub result: HookResult,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Hook {
|
||||
pub name: String,
|
||||
pub func: HookFn,
|
||||
}
|
||||
|
||||
impl Default for Hook {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
func: Arc::new(|_| Box::pin(async { HookOutcome::Continue })),
|
||||
name: "default".to_string(),
|
||||
func: Arc::new(|_| Box::pin(async { HookResult::Success })),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Hook {
|
||||
pub async fn execute(&self, payload: &HookPayload) -> HookOutcome {
|
||||
(self.func)(payload).await
|
||||
pub async fn execute(&self, payload: &HookPayload) -> HookResponse {
|
||||
HookResponse {
|
||||
hook_name: self.name.clone(),
|
||||
result: (self.func)(payload).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -126,13 +155,6 @@ pub enum HookEvent {
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum HookOutcome {
|
||||
Continue,
|
||||
#[allow(dead_code)]
|
||||
Stop,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
@@ -6,8 +6,8 @@ use serde::Serialize;
|
||||
|
||||
use crate::Hook;
|
||||
use crate::HookEvent;
|
||||
use crate::HookOutcome;
|
||||
use crate::HookPayload;
|
||||
use crate::HookResult;
|
||||
use crate::command_from_argv;
|
||||
|
||||
/// Legacy notify payload appended as the final argv argument for backward compatibility.
|
||||
@@ -48,12 +48,13 @@ pub fn legacy_notify_json(hook_event: &HookEvent, cwd: &Path) -> Result<String,
|
||||
pub fn notify_hook(argv: Vec<String>) -> Hook {
|
||||
let argv = Arc::new(argv);
|
||||
Hook {
|
||||
name: "legacy_notify".to_string(),
|
||||
func: Arc::new(move |payload: &HookPayload| {
|
||||
let argv = Arc::clone(&argv);
|
||||
Box::pin(async move {
|
||||
let mut command = match command_from_argv(&argv) {
|
||||
Some(command) => command,
|
||||
None => return HookOutcome::Continue,
|
||||
None => return HookResult::Success,
|
||||
};
|
||||
if let Ok(notify_payload) = legacy_notify_json(&payload.hook_event, &payload.cwd) {
|
||||
command.arg(notify_payload);
|
||||
@@ -65,8 +66,10 @@ pub fn notify_hook(argv: Vec<String>) -> Hook {
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null());
|
||||
|
||||
let _ = command.spawn();
|
||||
HookOutcome::Continue
|
||||
match command.spawn() {
|
||||
Ok(_) => HookResult::Success,
|
||||
Err(err) => HookResult::FailedContinue(err.into()),
|
||||
}
|
||||
})
|
||||
}),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user