diff --git a/codex-rs/app-server/src/bespoke_event_handling.rs b/codex-rs/app-server/src/bespoke_event_handling.rs index a4c424664c..8af3f87615 100644 --- a/codex-rs/app-server/src/bespoke_event_handling.rs +++ b/codex-rs/app-server/src/bespoke_event_handling.rs @@ -222,6 +222,7 @@ pub(crate) async fn apply_bespoke_event_handling( EventMsg::TurnComplete(turn_complete_event) => { // All per-thread requests are bound to a turn, so abort them. outgoing.abort_pending_server_requests().await; + respond_to_pending_interrupts(&thread_state, &outgoing, /*abort_reason*/ None).await; let turn_failed = thread_state.lock().await.turn_summary.last_error.is_some(); thread_watch_manager .note_turn_completed(&conversation_id.to_string(), turn_failed) @@ -1846,26 +1847,12 @@ pub(crate) async fn apply_bespoke_event_handling( EventMsg::TurnAborted(turn_aborted_event) => { // All per-thread requests are bound to a turn, so abort them. outgoing.abort_pending_server_requests().await; - let pending = { - let mut state = thread_state.lock().await; - std::mem::take(&mut state.pending_interrupts) - }; - if !pending.is_empty() { - for (rid, ver) in pending { - match ver { - ApiVersion::V1 => { - let response = InterruptConversationResponse { - abort_reason: turn_aborted_event.reason.clone(), - }; - outgoing.send_response(rid, response).await; - } - ApiVersion::V2 => { - let response = TurnInterruptResponse {}; - outgoing.send_response(rid, response).await; - } - } - } - } + respond_to_pending_interrupts( + &thread_state, + &outgoing, + Some(turn_aborted_event.reason.clone()), + ) + .await; thread_watch_manager .note_turn_interrupted(&conversation_id.to_string()) @@ -2342,6 +2329,33 @@ async fn handle_thread_rollback_failed( } } +async fn respond_to_pending_interrupts( + thread_state: &Arc>, + outgoing: &ThreadScopedOutgoingMessageSender, + abort_reason: Option, +) { + let pending = { + let mut state = thread_state.lock().await; + std::mem::take(&mut state.pending_interrupts) + }; + + for (rid, ver) in pending { + match ver { + ApiVersion::V1 => { + let Some(abort_reason) = abort_reason.clone() else { + debug_assert!(false, "v1 interrupts only resolve from TurnAborted"); + continue; + }; + let response = InterruptConversationResponse { abort_reason }; + outgoing.send_response(rid, response).await; + } + ApiVersion::V2 => { + outgoing.send_response(rid, TurnInterruptResponse {}).await; + } + } + } +} + async fn handle_token_count_event( conversation_id: ThreadId, turn_id: String, @@ -4192,17 +4206,19 @@ mod tests { let thread_state = new_thread_state(); { let mut state = thread_state.lock().await; - state.track_current_turn_event(&EventMsg::TurnStarted( - codex_protocol::protocol::TurnStartedEvent { + state.track_current_turn_event( + &event_turn_id, + &EventMsg::TurnStarted(codex_protocol::protocol::TurnStartedEvent { turn_id: event_turn_id.clone(), started_at: Some(42), model_context_window: None, collaboration_mode_kind: Default::default(), - }, - )); - state.track_current_turn_event(&EventMsg::TurnComplete(turn_complete_event( + }), + ); + state.track_current_turn_event( &event_turn_id, - ))); + &EventMsg::TurnComplete(turn_complete_event(&event_turn_id)), + ); } handle_turn_complete( diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index e81ee547f5..b679cc34bb 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -7790,11 +7790,6 @@ impl CodexMessageProcessor { async fn turn_interrupt(&self, request_id: ConnectionRequestId, params: TurnInterruptParams) { let TurnInterruptParams { thread_id, turn_id } = params; let is_startup_interrupt = turn_id.is_empty(); - if !is_startup_interrupt { - self.outgoing - .record_request_turn_id(&request_id, &turn_id) - .await; - } let (thread_uuid, thread) = match self.load_thread(&thread_id).await { Ok(v) => v, @@ -7808,10 +7803,40 @@ impl CodexMessageProcessor { // interrupts do not have a turn and are acknowledged after submission. if !is_startup_interrupt { let thread_state = self.thread_state_manager.thread_state(thread_uuid).await; - let mut thread_state = thread_state.lock().await; - thread_state - .pending_interrupts - .push((request_id.clone(), ApiVersion::V2)); + let is_running = matches!(thread.agent_status().await, AgentStatus::Running); + let interrupt_outcome = { + let mut thread_state = thread_state.lock().await; + if let Some(active_turn) = thread_state.active_turn_snapshot() { + if active_turn.id != turn_id { + Err(format!( + "expected active turn id {turn_id} but found {}", + active_turn.id + )) + } else { + thread_state + .pending_interrupts + .push((request_id.clone(), ApiVersion::V2)); + Ok(()) + } + } else if thread_state.last_terminal_turn_id.as_deref() == Some(turn_id.as_str()) { + Err("no active turn to interrupt".to_string()) + } else if is_running { + thread_state + .pending_interrupts + .push((request_id.clone(), ApiVersion::V2)); + Ok(()) + } else { + Err("no active turn to interrupt".to_string()) + } + }; + if let Err(message) = interrupt_outcome { + self.send_invalid_request_error(request_id, message).await; + return; + } + + self.outgoing + .record_request_turn_id(&request_id, &turn_id) + .await; } // Submit the interrupt. Turn interrupts respond upon TurnAborted; startup @@ -8074,7 +8099,7 @@ impl CodexMessageProcessor { // opt-in stays synchronized with the conversation. let raw_events_enabled = { let mut thread_state = thread_state.lock().await; - thread_state.track_current_turn_event(&event.msg); + thread_state.track_current_turn_event(&event.id, &event.msg); thread_state.experimental_raw_events }; let subscribed_connection_ids = thread_state_manager @@ -11237,14 +11262,15 @@ mod tests { let state = manager.thread_state(thread_id).await; let mut state = state.lock().await; state.cancel_tx = Some(cancel_tx); - state.track_current_turn_event(&EventMsg::TurnStarted( - codex_protocol::protocol::TurnStartedEvent { + state.track_current_turn_event( + "turn-1", + &EventMsg::TurnStarted(codex_protocol::protocol::TurnStartedEvent { turn_id: "turn-1".to_string(), started_at: None, model_context_window: None, collaboration_mode_kind: Default::default(), - }, - )); + }), + ); } manager.remove_thread_state(thread_id).await; diff --git a/codex-rs/app-server/src/thread_state.rs b/codex-rs/app-server/src/thread_state.rs index 323aba19d7..77b6defabb 100644 --- a/codex-rs/app-server/src/thread_state.rs +++ b/codex-rs/app-server/src/thread_state.rs @@ -60,6 +60,7 @@ pub(crate) struct ThreadState { pub(crate) pending_interrupts: PendingInterruptQueue, pub(crate) pending_rollbacks: Option, pub(crate) turn_summary: TurnSummary, + pub(crate) last_terminal_turn_id: Option, pub(crate) cancel_tx: Option>, pub(crate) experimental_raw_events: bool, pub(crate) listener_generation: u64, @@ -114,7 +115,7 @@ impl ThreadState { self.current_turn_history.active_turn_snapshot() } - pub(crate) fn track_current_turn_event(&mut self, event: &EventMsg) { + pub(crate) fn track_current_turn_event(&mut self, event_turn_id: &str, event: &EventMsg) { if let EventMsg::TurnStarted(payload) = event { self.turn_summary.started_at = payload.started_at; } @@ -122,6 +123,7 @@ impl ThreadState { if matches!(event, EventMsg::TurnAborted(_) | EventMsg::TurnComplete(_)) && !self.current_turn_history.has_active_turn() { + self.last_terminal_turn_id = Some(event_turn_id.to_string()); self.current_turn_history.reset(); } } diff --git a/codex-rs/app-server/tests/suite/v2/turn_interrupt.rs b/codex-rs/app-server/tests/suite/v2/turn_interrupt.rs index f8eaf799da..aedc54e016 100644 --- a/codex-rs/app-server/tests/suite/v2/turn_interrupt.rs +++ b/codex-rs/app-server/tests/suite/v2/turn_interrupt.rs @@ -2,10 +2,12 @@ use anyhow::Result; use app_test_support::McpProcess; +use app_test_support::create_final_assistant_message_sse_response; use app_test_support::create_mock_responses_server_sequence; use app_test_support::create_mock_responses_server_sequence_unchecked; use app_test_support::create_shell_command_sse_response; use app_test_support::to_response; +use codex_app_server_protocol::JSONRPCError; use codex_app_server_protocol::JSONRPCNotification; use codex_app_server_protocol::JSONRPCResponse; use codex_app_server_protocol::RequestId; @@ -24,6 +26,7 @@ use tempfile::TempDir; use tokio::time::timeout; const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); +const INVALID_REQUEST_ERROR_CODE: i64 = -32600; #[tokio::test] async fn turn_interrupt_aborts_running_turn() -> Result<()> { @@ -125,6 +128,82 @@ async fn turn_interrupt_aborts_running_turn() -> Result<()> { Ok(()) } +#[tokio::test] +async fn turn_interrupt_rejects_completed_turn() -> Result<()> { + let tmp = TempDir::new()?; + let codex_home = tmp.path().join("codex_home"); + std::fs::create_dir(&codex_home)?; + + let server = create_mock_responses_server_sequence_unchecked(vec![ + create_final_assistant_message_sse_response("done")?, + ]) + .await; + create_config_toml(&codex_home, &server.uri(), "never", "workspace-write")?; + + let mut mcp = McpProcess::new(&codex_home).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??; + + let thread_req = mcp + .send_thread_start_request(ThreadStartParams { + model: Some("mock-model".to_string()), + ..Default::default() + }) + .await?; + let thread_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(thread_req)), + ) + .await??; + let ThreadStartResponse { thread, .. } = to_response::(thread_resp)?; + + let turn_req = mcp + .send_turn_start_request(TurnStartParams { + thread_id: thread.id.clone(), + input: vec![V2UserInput::Text { + text: "say done".to_string(), + text_elements: Vec::new(), + }], + ..Default::default() + }) + .await?; + let turn_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(turn_req)), + ) + .await??; + let TurnStartResponse { turn } = to_response::(turn_resp)?; + + let completed_notif: JSONRPCNotification = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_notification_message("turn/completed"), + ) + .await??; + let completed: TurnCompletedNotification = serde_json::from_value( + completed_notif + .params + .expect("turn/completed params must be present"), + )?; + assert_eq!(completed.thread_id, thread.id); + assert_eq!(completed.turn.id, turn.id); + assert_eq!(completed.turn.status, TurnStatus::Completed); + + let interrupt_id = mcp + .send_turn_interrupt_request(TurnInterruptParams { + thread_id: thread.id, + turn_id: turn.id, + }) + .await?; + + let interrupt_err: JSONRPCError = timeout( + std::time::Duration::from_millis(500), + mcp.read_stream_until_error_message(RequestId::Integer(interrupt_id)), + ) + .await??; + assert_eq!(interrupt_err.error.code, INVALID_REQUEST_ERROR_CODE); + + Ok(()) +} + #[tokio::test] async fn turn_interrupt_resolves_pending_command_approval_request() -> Result<()> { #[cfg(target_os = "windows")]