use tokio::process::Command; use crate::types::Hook; use crate::types::HookEvent; use crate::types::HookPayload; use crate::types::HookResponse; #[derive(Default, Clone)] pub struct HooksConfig { pub legacy_notify_argv: Option>, } #[derive(Clone)] pub struct Hooks { after_agent: Vec, after_tool_use: Vec, } impl Default for Hooks { fn default() -> Self { Self::new(HooksConfig::default()) } } // Hooks are arbitrary, user-specified functions that are deterministically // executed after specific events in the Codex lifecycle. impl Hooks { pub fn new(config: HooksConfig) -> Self { let after_agent = config .legacy_notify_argv .filter(|argv| !argv.is_empty() && !argv[0].is_empty()) .map(crate::notify_hook) .into_iter() .collect(); Self { after_agent, after_tool_use: Vec::new(), } } fn hooks_for_event(&self, hook_event: &HookEvent) -> &[Hook] { match hook_event { HookEvent::AfterAgent { .. } => &self.after_agent, HookEvent::AfterToolUse { .. } => &self.after_tool_use, } } pub async fn dispatch(&self, hook_payload: HookPayload) -> Vec { let hooks = self.hooks_for_event(&hook_payload.hook_event); let mut outcomes = Vec::with_capacity(hooks.len()); for hook in hooks { let outcome = hook.execute(&hook_payload).await; let should_abort_operation = outcome.result.should_abort_operation(); outcomes.push(outcome); if should_abort_operation { break; } } outcomes } } pub fn command_from_argv(argv: &[String]) -> Option { let (program, args) = argv.split_first()?; if program.is_empty() { return None; } let mut command = Command::new(program); command.args(args); Some(command) } #[cfg(test)] mod tests { use std::fs; use std::path::PathBuf; use std::process::Stdio; use std::sync::Arc; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; use std::time::Duration; use anyhow::Result; use chrono::TimeZone; use chrono::Utc; use codex_protocol::ThreadId; use pretty_assertions::assert_eq; use serde_json::to_string; use tempfile::tempdir; use tokio::time::timeout; use super::*; use crate::types::HookEventAfterAgent; use crate::types::HookEventAfterToolUse; use crate::types::HookResult; use crate::types::HookToolInput; use crate::types::HookToolKind; const CWD: &str = "/tmp"; const INPUT_MESSAGE: &str = "hello"; fn hook_payload(label: &str) -> HookPayload { HookPayload { session_id: ThreadId::new(), cwd: PathBuf::from(CWD), triggered_at: Utc .with_ymd_and_hms(2025, 1, 1, 0, 0, 0) .single() .expect("valid timestamp"), hook_event: HookEvent::AfterAgent { event: HookEventAfterAgent { thread_id: ThreadId::new(), turn_id: format!("turn-{label}"), input_messages: vec![INPUT_MESSAGE.to_string()], last_assistant_message: Some("hi".to_string()), }, }, } } fn counting_success_hook(calls: &Arc, name: &str) -> Hook { let hook_name = name.to_string(); let calls = Arc::clone(calls); Hook { name: hook_name, func: Arc::new(move |_| { let calls = Arc::clone(&calls); Box::pin(async move { calls.fetch_add(1, Ordering::SeqCst); HookResult::Success }) }), } } fn failing_continue_hook(calls: &Arc, name: &str, message: &str) -> Hook { let hook_name = name.to_string(); let message = message.to_string(); let calls = Arc::clone(calls); Hook { name: hook_name, func: Arc::new(move |_| { let calls = Arc::clone(&calls); let message = message.clone(); Box::pin(async move { calls.fetch_add(1, Ordering::SeqCst); HookResult::FailedContinue(std::io::Error::other(message).into()) }) }), } } fn failing_abort_hook(calls: &Arc, name: &str, message: &str) -> Hook { let hook_name = name.to_string(); let message = message.to_string(); let calls = Arc::clone(calls); Hook { name: hook_name, func: Arc::new(move |_| { let calls = Arc::clone(&calls); let message = message.clone(); Box::pin(async move { calls.fetch_add(1, Ordering::SeqCst); HookResult::FailedAbort(std::io::Error::other(message).into()) }) }), } } fn after_tool_use_payload(label: &str) -> HookPayload { HookPayload { session_id: ThreadId::new(), cwd: PathBuf::from(CWD), triggered_at: Utc .with_ymd_and_hms(2025, 1, 1, 0, 0, 0) .single() .expect("valid timestamp"), hook_event: HookEvent::AfterToolUse { event: HookEventAfterToolUse { turn_id: format!("turn-{label}"), call_id: format!("call-{label}"), tool_name: "apply_patch".to_string(), tool_kind: HookToolKind::Custom, tool_input: HookToolInput::Custom { input: "*** Begin Patch".to_string(), }, executed: true, success: true, duration_ms: 1, mutating: true, sandbox: "none".to_string(), sandbox_policy: "danger-full-access".to_string(), output_preview: "ok".to_string(), }, }, } } #[test] fn command_from_argv_returns_none_for_empty_args() { assert!(command_from_argv(&[]).is_none()); assert!(command_from_argv(&["".to_string()]).is_none()); } #[tokio::test] async fn command_from_argv_builds_command() -> Result<()> { let argv = if cfg!(windows) { vec![ "cmd".to_string(), "/C".to_string(), "echo hello world".to_string(), ] } else { vec!["echo".to_string(), "hello".to_string(), "world".to_string()] }; let mut command = command_from_argv(&argv).ok_or_else(|| anyhow::anyhow!("command"))?; let output = command.stdout(Stdio::piped()).output().await?; let stdout = String::from_utf8_lossy(&output.stdout); let trimmed = stdout.trim_end_matches(['\r', '\n']); assert_eq!(trimmed, "hello world"); Ok(()) } #[test] fn hooks_new_requires_program_name() { assert!(Hooks::new(HooksConfig::default()).after_agent.is_empty()); assert!( Hooks::new(HooksConfig { legacy_notify_argv: Some(vec![]), }) .after_agent .is_empty() ); assert!( Hooks::new(HooksConfig { legacy_notify_argv: Some(vec!["".to_string()]), }) .after_agent .is_empty() ); assert_eq!( Hooks::new(HooksConfig { legacy_notify_argv: Some(vec!["notify-send".to_string()]), }) .after_agent .len(), 1 ); } #[tokio::test] async fn dispatch_executes_hook() { let calls = Arc::new(AtomicUsize::new(0)); let hooks = Hooks { after_agent: vec![counting_success_hook(&calls, "counting")], ..Hooks::default() }; let outcomes = hooks.dispatch(hook_payload("1")).await; assert_eq!(outcomes.len(), 1); assert_eq!(outcomes[0].hook_name, "counting"); assert!(matches!(outcomes[0].result, HookResult::Success)); assert_eq!(calls.load(Ordering::SeqCst), 1); } #[tokio::test] async fn default_hook_is_noop_and_continues() { let payload = hook_payload("d"); let outcome = Hook::default().execute(&payload).await; assert_eq!(outcome.hook_name, "default"); assert!(matches!(outcome.result, HookResult::Success)); } #[tokio::test] async fn dispatch_executes_multiple_hooks_for_same_event() { let calls = Arc::new(AtomicUsize::new(0)); let hooks = Hooks { after_agent: vec![ counting_success_hook(&calls, "counting-1"), counting_success_hook(&calls, "counting-2"), ], ..Hooks::default() }; let outcomes = hooks.dispatch(hook_payload("2")).await; assert_eq!(outcomes.len(), 2); assert_eq!(outcomes[0].hook_name, "counting-1"); assert_eq!(outcomes[1].hook_name, "counting-2"); assert!(matches!(outcomes[0].result, HookResult::Success)); assert!(matches!(outcomes[1].result, HookResult::Success)); assert_eq!(calls.load(Ordering::SeqCst), 2); } #[tokio::test] async fn dispatch_stops_when_hook_requests_abort() { let calls = Arc::new(AtomicUsize::new(0)); let hooks = Hooks { after_agent: vec![ failing_abort_hook(&calls, "abort", "hook failed"), counting_success_hook(&calls, "counting"), ], ..Hooks::default() }; let outcomes = hooks.dispatch(hook_payload("3")).await; assert_eq!(outcomes.len(), 1); assert_eq!(outcomes[0].hook_name, "abort"); assert!(matches!(outcomes[0].result, HookResult::FailedAbort(_))); assert_eq!(calls.load(Ordering::SeqCst), 1); } #[tokio::test] async fn dispatch_executes_after_tool_use_hooks() { let calls = Arc::new(AtomicUsize::new(0)); let hooks = Hooks { after_tool_use: vec![counting_success_hook(&calls, "counting")], ..Hooks::default() }; let outcomes = hooks.dispatch(after_tool_use_payload("p")).await; assert_eq!(outcomes.len(), 1); assert_eq!(outcomes[0].hook_name, "counting"); assert!(matches!(outcomes[0].result, HookResult::Success)); assert_eq!(calls.load(Ordering::SeqCst), 1); } #[tokio::test] async fn dispatch_continues_after_continueable_failure() { let calls = Arc::new(AtomicUsize::new(0)); let hooks = Hooks { after_agent: vec![ failing_continue_hook(&calls, "failing", "hook failed"), counting_success_hook(&calls, "counting"), ], ..Hooks::default() }; let outcomes = hooks.dispatch(hook_payload("err")).await; assert_eq!(outcomes.len(), 2); assert_eq!(outcomes[0].hook_name, "failing"); assert!(matches!(outcomes[0].result, HookResult::FailedContinue(_))); assert_eq!(outcomes[1].hook_name, "counting"); assert!(matches!(outcomes[1].result, HookResult::Success)); assert_eq!(calls.load(Ordering::SeqCst), 2); } #[tokio::test] async fn dispatch_returns_after_tool_use_failure_outcome() { let calls = Arc::new(AtomicUsize::new(0)); let hooks = Hooks { after_tool_use: vec![failing_continue_hook( &calls, "failing", "after_tool_use hook failed", )], ..Hooks::default() }; let outcomes = hooks.dispatch(after_tool_use_payload("err-tool")).await; assert_eq!(outcomes.len(), 1); assert_eq!(outcomes[0].hook_name, "failing"); assert!(matches!(outcomes[0].result, HookResult::FailedContinue(_))); assert_eq!(calls.load(Ordering::SeqCst), 1); } #[cfg(not(windows))] #[tokio::test] async fn hook_executes_program_with_payload_argument_unix() -> Result<()> { let temp_dir = tempdir()?; let payload_path = temp_dir.path().join("payload.json"); let payload_path_arg = payload_path.to_string_lossy().into_owned(); let hook = Hook { name: "write_payload".to_string(), func: Arc::new(move |payload: &HookPayload| { let payload_path_arg = payload_path_arg.clone(); Box::pin(async move { let json = to_string(payload).expect("serialize hook payload"); let mut command = command_from_argv(&[ "/bin/sh".to_string(), "-c".to_string(), "printf '%s' \"$2\" > \"$1\"".to_string(), "sh".to_string(), payload_path_arg, json, ]) .expect("build command"); command.status().await.expect("run hook command"); HookResult::Success }) }), }; let payload = hook_payload("4"); let expected = to_string(&payload)?; let hooks = Hooks { after_agent: vec![hook], ..Hooks::default() }; let outcomes = hooks.dispatch(payload).await; assert_eq!(outcomes.len(), 1); assert!(matches!(outcomes[0].result, HookResult::Success)); let contents = timeout(Duration::from_secs(2), async { loop { if let Ok(contents) = fs::read_to_string(&payload_path) && !contents.is_empty() { return contents; } tokio::time::sleep(Duration::from_millis(10)).await; } }) .await?; assert_eq!(contents, expected); Ok(()) } #[cfg(windows)] #[tokio::test] async fn hook_executes_program_with_payload_argument_windows() -> Result<()> { let temp_dir = tempdir()?; let payload_path = temp_dir.path().join("payload.json"); let payload_path_arg = payload_path.to_string_lossy().into_owned(); let script_path = temp_dir.path().join("write_payload.ps1"); fs::write(&script_path, "[IO.File]::WriteAllText($args[0], $args[1])")?; let script_path_arg = script_path.to_string_lossy().into_owned(); let hook = Hook { name: "write_payload".to_string(), func: Arc::new(move |payload: &HookPayload| { let payload_path_arg = payload_path_arg.clone(); let script_path_arg = script_path_arg.clone(); Box::pin(async move { let json = to_string(payload).expect("serialize hook payload"); let mut command = command_from_argv(&[ "powershell.exe".to_string(), "-NoLogo".to_string(), "-NoProfile".to_string(), "-ExecutionPolicy".to_string(), "Bypass".to_string(), "-File".to_string(), script_path_arg, payload_path_arg, json, ]) .expect("build command"); command.status().await.expect("run hook command"); HookResult::Success }) }), }; let payload = hook_payload("4"); let expected = to_string(&payload)?; let hooks = Hooks { after_agent: vec![hook], ..Hooks::default() }; let outcomes = hooks.dispatch(payload).await; assert_eq!(outcomes.len(), 1); assert!(matches!(outcomes[0].result, HookResult::Success)); let contents = timeout(Duration::from_secs(2), async { loop { if let Ok(contents) = fs::read_to_string(&payload_path) && !contents.is_empty() { return contents; } tokio::time::sleep(Duration::from_millis(10)).await; } }) .await?; assert_eq!(contents, expected); Ok(()) } }