Compare commits

...

2 Commits

Author SHA1 Message Date
Gabriel Cohen
41ad46665a codex: fix CI failure on PR #13777
Collapse the interrupt resend guard in multi_agents.rs so clippy accepts the branch without changing behavior.

Co-authored-by: Codex <noreply@openai.com>
2026-03-06 10:49:48 -08:00
Gabriel Cohen
a2848e1689 Fix subagent notifications after interruptive resend
Co-authored-by: Codex <noreply@openai.com>
2026-03-06 10:41:07 -08:00
3 changed files with 840 additions and 68 deletions

View File

@@ -22,13 +22,76 @@ use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use codex_protocol::protocol::TokenUsage;
use codex_protocol::user_input::UserInput;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Weak;
use tokio::sync::Mutex;
use tokio::sync::watch;
const AGENT_NAMES: &str = include_str!("agent_names.txt");
const FORKED_SPAWN_AGENT_OUTPUT_MESSAGE: &str = "You are the newly spawned agent. The prior conversation history was forked from your parent agent. Treat the next user message as your new task, and use the forked history only as background context.";
/// Lifecycle of a follow-up message that replaces an interrupted child turn.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
enum InterruptFollowUpPhase {
    /// No interruptive resend is in flight for this agent.
    #[default]
    None,
    /// An interrupt was issued but the replacement input has not been submitted yet.
    Pending,
    /// The replacement input was submitted successfully.
    Committed,
    /// The resend was abandoned; a transient `Interrupted` status may be reported as-is.
    Cancelled,
}
/// Snapshot published on the per-agent watch channel: which resend attempt
/// (`generation`) is current and what phase it is in.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
struct InterruptFollowUpState {
    // Monotonically increasing counter; bumped for each new resend attempt so
    // stale handles cannot clobber a newer attempt's phase.
    generation: u64,
    // Phase of the attempt identified by `generation`.
    phase: InterruptFollowUpPhase,
}
/// Handle for one interrupt-resend attempt, returned by
/// `AgentControl::begin_interrupt_follow_up`. Call `commit` once the
/// replacement input was submitted, or `cancel` to abandon the attempt;
/// dropping the handle without doing either behaves like `cancel`.
#[derive(Debug)]
pub(crate) struct InterruptFollowUpHandle {
    // Channel shared with watchers of this agent's follow-up state.
    sender: watch::Sender<InterruptFollowUpState>,
    // Generation this handle owns; writes are skipped if a newer one took over.
    generation: u64,
    // Set once `commit`/`cancel` ran, so `Drop` does not cancel a second time.
    finalized: bool,
}
impl InterruptFollowUpHandle {
    /// Wrap an existing watch channel for a single redirect generation.
    fn new(sender: watch::Sender<InterruptFollowUpState>, generation: u64) -> Self {
        Self {
            sender,
            generation,
            finalized: false,
        }
    }

    /// Mark the redirected follow-up as successfully submitted.
    pub(crate) fn commit(&mut self) {
        self.finalize(InterruptFollowUpPhase::Committed);
    }

    /// Mark the redirected follow-up as abandoned.
    pub(crate) fn cancel(&mut self) {
        self.finalize(InterruptFollowUpPhase::Cancelled);
    }

    // Record the terminal phase and remember that the handle is done so the
    // `Drop` impl does not cancel again.
    fn finalize(&mut self, phase: InterruptFollowUpPhase) {
        self.set_phase(phase);
        self.finalized = true;
    }

    // Publish `phase` unless a newer generation has already taken over the
    // channel; stale handles must never clobber a fresh resend attempt.
    fn set_phase(&self, phase: InterruptFollowUpPhase) {
        let observed = *self.sender.borrow();
        if observed.generation == self.generation {
            self.sender.send_replace(InterruptFollowUpState {
                generation: self.generation,
                phase,
            });
        }
    }
}
impl Drop for InterruptFollowUpHandle {
    fn drop(&mut self) {
        // A handle that was neither committed nor cancelled counts as a
        // cancellation, so watchers are not left suppressing statuses forever.
        if self.finalized {
            return;
        }
        self.set_phase(InterruptFollowUpPhase::Cancelled);
    }
}
#[derive(Clone, Debug, Default)]
pub(crate) struct SpawnAgentOptions {
pub(crate) fork_parent_spawn_call_id: Option<String>,
@@ -72,6 +135,10 @@ pub(crate) struct AgentControl {
/// `ThreadManagerState -> CodexThread -> Session -> SessionServices -> ThreadManagerState`.
manager: Weak<ThreadManagerState>,
state: Arc<Guards>,
// Redirected child work is coordinated through this shared watch map so the send_input,
// wait, and parent-notification paths all agree on whether an "Interrupted" status is still
// transient or should be reported.
interrupt_follow_up_state: Arc<Mutex<HashMap<ThreadId, watch::Sender<InterruptFollowUpState>>>>,
}
impl AgentControl {
@@ -322,6 +389,37 @@ impl AgentControl {
state.send_op(agent_id, Op::Interrupt).await
}
    /// Start a new interrupt-resend generation for `agent_id` and mark it pending.
    ///
    /// Watchers suppress the child's transient `Interrupted` status while the
    /// attempt is pending or committed; `commit` the returned handle once the
    /// replacement input was submitted, or `cancel`/drop it to stop suppressing.
    pub(crate) async fn begin_interrupt_follow_up(
        &self,
        agent_id: ThreadId,
    ) -> InterruptFollowUpHandle {
        let sender = self.interrupt_follow_up_sender(agent_id).await;
        // Bumping the generation invalidates any stale handle for this agent.
        let generation = sender.borrow().generation + 1;
        sender.send_replace(InterruptFollowUpState {
            generation,
            phase: InterruptFollowUpPhase::Pending,
        });
        InterruptFollowUpHandle::new(sender, generation)
    }
#[cfg(test)]
pub(crate) async fn mark_interrupt_follow_up_pending(&self, agent_id: ThreadId) {
self.force_interrupt_follow_up_phase(agent_id, InterruptFollowUpPhase::Pending)
.await;
}
#[cfg(test)]
pub(crate) async fn mark_interrupt_follow_up_committed(&self, agent_id: ThreadId) {
self.force_interrupt_follow_up_phase(agent_id, InterruptFollowUpPhase::Committed)
.await;
}
#[cfg(test)]
pub(crate) async fn mark_interrupt_follow_up_cancelled(&self, agent_id: ThreadId) {
self.force_interrupt_follow_up_phase(agent_id, InterruptFollowUpPhase::Cancelled)
.await;
}
/// Submit a shutdown request to an existing agent thread.
pub(crate) async fn shutdown_agent(&self, agent_id: ThreadId) -> CodexResult<String> {
let state = self.upgrade()?;
@@ -343,6 +441,33 @@ impl AgentControl {
thread.agent_status().await
}
pub(crate) async fn get_reportable_status(&self, agent_id: ThreadId) -> AgentStatus {
let status = self.get_status(agent_id).await;
if self
.is_status_reportable_immediately(agent_id, &status)
.await
{
status
} else {
AgentStatus::Running
}
}
pub(crate) async fn is_status_reportable_immediately(
&self,
agent_id: ThreadId,
status: &AgentStatus,
) -> bool {
if !is_interrupted_status(status) {
return true;
}
let follow_up_state = self.interrupt_follow_up_state(agent_id).await;
!matches!(
follow_up_state.phase,
InterruptFollowUpPhase::Pending | InterruptFollowUpPhase::Committed
)
}
pub(crate) async fn get_agent_nickname_and_role(
&self,
agent_id: ThreadId,
@@ -432,16 +557,8 @@ impl AgentControl {
let control = self.clone();
tokio::spawn(async move {
let status = match control.subscribe_status(child_thread_id).await {
Ok(mut status_rx) => {
let mut status = status_rx.borrow().clone();
while !is_final(&status) {
if status_rx.changed().await.is_err() {
status = control.get_status(child_thread_id).await;
break;
}
status = status_rx.borrow().clone();
}
status
Ok(status_rx) => {
wait_for_reportable_status(&control, child_thread_id, status_rx).await
}
Err(_) => control.get_status(child_thread_id).await,
};
@@ -486,6 +603,123 @@ impl AgentControl {
parent_thread.codex.session.user_shell().shell_snapshot()
}
}
/// Wait until `child_thread_id` reaches a status that is safe to report.
///
/// Like waiting for any final status, except that a transient `Interrupted`
/// status is suppressed while an interruptive resend is pending or committed;
/// in that case we keep waiting for either a later status transition or for
/// the follow-up to be cancelled.
pub(crate) async fn wait_for_reportable_status(
    control: &AgentControl,
    child_thread_id: ThreadId,
    mut status_rx: watch::Receiver<AgentStatus>,
) -> AgentStatus {
    let mut follow_up_rx = control.interrupt_follow_up_receiver(child_thread_id).await;
    let mut status = status_rx.borrow().clone();
    // Generation whose `Interrupted` status we are currently suppressing, if any.
    let mut blocked_generation = None;
    loop {
        if !is_final(&status) {
            // The child moved past the interrupted turn, so the suppressed
            // generation is resolved: reset its follow-up state.
            if let Some(generation) = blocked_generation.take() {
                control
                    .clear_interrupt_follow_up_state(child_thread_id, generation)
                    .await;
            }
            if status_rx.changed().await.is_err() {
                // Status channel closed; fall back to a direct status query.
                return control.get_status(child_thread_id).await;
            }
            status = status_rx.borrow().clone();
            continue;
        }
        if !is_interrupted_status(&status) {
            // Any other final status is reportable as-is.
            if let Some(generation) = blocked_generation.take() {
                control
                    .clear_interrupt_follow_up_state(child_thread_id, generation)
                    .await;
            }
            return status;
        }
        let follow_up_state = *follow_up_rx.borrow();
        match follow_up_state.phase {
            InterruptFollowUpPhase::None | InterruptFollowUpPhase::Cancelled => return status,
            InterruptFollowUpPhase::Pending | InterruptFollowUpPhase::Committed => {
                // Once an interruptive resend is in flight, keep suppressing the old turn's
                // `Interrupted` result until either the redirect is cancelled or the child emits
                // any later status transition for that same generation.
                blocked_generation = Some(follow_up_state.generation);
                tokio::select! {
                    changed = status_rx.changed() => {
                        if changed.is_err() {
                            return control.get_status(child_thread_id).await;
                        }
                        status = status_rx.borrow().clone();
                    }
                    changed = follow_up_rx.changed() => {
                        if changed.is_err() {
                            // Follow-up channel gone; report what we have.
                            return status;
                        }
                    }
                }
            }
        }
    }
}
impl AgentControl {
    /// Test-only: force the follow-up channel for `agent_id` into `phase`.
    /// Bumps the generation when starting a new `Pending` attempt (or when no
    /// attempt existed yet) so it mimics `begin_interrupt_follow_up`.
    #[cfg(test)]
    async fn force_interrupt_follow_up_phase(
        &self,
        agent_id: ThreadId,
        phase: InterruptFollowUpPhase,
    ) {
        let sender = self.interrupt_follow_up_sender(agent_id).await;
        let current = *sender.borrow();
        let generation = match phase {
            InterruptFollowUpPhase::Pending => current.generation + 1,
            _ if current.phase == InterruptFollowUpPhase::None => current.generation + 1,
            _ => current.generation,
        };
        sender.send_replace(InterruptFollowUpState { generation, phase });
    }
    /// Get (or lazily create) the watch sender backing `agent_id`'s follow-up state.
    async fn interrupt_follow_up_sender(
        &self,
        agent_id: ThreadId,
    ) -> watch::Sender<InterruptFollowUpState> {
        let mut state = self.interrupt_follow_up_state.lock().await;
        state
            .entry(agent_id)
            .or_insert_with(|| {
                // Keeping only the sender is fine: receivers are created on
                // demand via `subscribe`.
                let (sender, _receiver) = watch::channel(InterruptFollowUpState::default());
                sender
            })
            .clone()
    }
    /// Snapshot of the current follow-up state for `agent_id`.
    async fn interrupt_follow_up_state(&self, agent_id: ThreadId) -> InterruptFollowUpState {
        let sender = self.interrupt_follow_up_sender(agent_id).await;
        *sender.borrow()
    }
    /// Fresh receiver observing `agent_id`'s follow-up state changes.
    async fn interrupt_follow_up_receiver(
        &self,
        agent_id: ThreadId,
    ) -> watch::Receiver<InterruptFollowUpState> {
        self.interrupt_follow_up_sender(agent_id).await.subscribe()
    }
    /// Reset the follow-up phase back to `None`, but only while `generation`
    /// is still current — a newer resend attempt must not be clobbered.
    async fn clear_interrupt_follow_up_state(&self, agent_id: ThreadId, generation: u64) {
        let sender = self.interrupt_follow_up_sender(agent_id).await;
        let current = *sender.borrow();
        if current.generation != generation {
            return;
        }
        sender.send_replace(InterruptFollowUpState {
            generation,
            phase: InterruptFollowUpPhase::None,
        });
    }
}
// True for the specific transient error the child reports after an interrupt.
fn is_interrupted_status(status: &AgentStatus) -> bool {
    match status {
        AgentStatus::Errored(message) => message == "Interrupted",
        _ => false,
    }
}
#[cfg(test)]
mod tests {
use super::*;
@@ -499,6 +733,11 @@ mod tests {
use crate::config_loader::LoaderOverrides;
use crate::contextual_user_message::SUBAGENT_NOTIFICATION_OPEN_TAG;
use crate::features::Feature;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolPayload;
use crate::tools::handlers::multi_agents::MultiAgentHandler;
use crate::tools::registry::ToolHandler;
use crate::turn_diff_tracker::TurnDiffTracker;
use assert_matches::assert_matches;
use codex_protocol::config_types::ModeKind;
use codex_protocol::models::ContentItem;
@@ -512,7 +751,10 @@ mod tests {
use codex_protocol::protocol::TurnCompleteEvent;
use codex_protocol::protocol::TurnStartedEvent;
use pretty_assertions::assert_eq;
use serde_json::json;
use tempfile::TempDir;
use tokio::sync::Mutex;
use tokio::task::yield_now;
use tokio::time::Duration;
use tokio::time::sleep;
use tokio::time::timeout;
@@ -548,6 +790,24 @@ mod tests {
}]
}
    /// Build a `ToolInvocation` carrying a JSON function-call payload, with a
    /// fixed call id and a fresh diff tracker, for driving handlers in tests.
    fn function_invocation(
        session: Arc<crate::codex::Session>,
        turn: Arc<crate::codex::TurnContext>,
        tool_name: &str,
        arguments: serde_json::Value,
    ) -> ToolInvocation {
        ToolInvocation {
            session,
            turn,
            tracker: Arc::new(Mutex::new(TurnDiffTracker::default())),
            call_id: "call-1".to_string(),
            tool_name: tool_name.to_string(),
            payload: ToolPayload::Function {
                arguments: arguments.to_string(),
            },
        }
    }
struct AgentControlHarness {
_home: TempDir,
config: Config,
@@ -1363,6 +1623,365 @@ mod tests {
);
}
    /// While an interrupt follow-up is pending, the completion watcher must not
    /// notify the parent about the interrupted turn; once the follow-up is
    /// committed and a later turn completes, only that completion is reported.
    #[tokio::test]
    async fn completion_watcher_prefers_later_completion_after_interrupt() {
        let harness = AgentControlHarness::new().await;
        let (parent_thread_id, parent_thread) = harness.start_thread().await;
        let (child_thread_id, child_thread) = harness.start_thread().await;
        harness.control.maybe_start_completion_watcher(
            child_thread_id,
            Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
                parent_thread_id,
                depth: 1,
                agent_nickname: None,
                agent_role: Some("explorer".to_string()),
            })),
        );
        // Simulate an interruptive resend being in flight for the child.
        harness
            .control
            .mark_interrupt_follow_up_pending(child_thread_id)
            .await;
        // Child turn that gets interrupted.
        let interrupted_turn = child_thread.codex.session.new_default_turn().await;
        child_thread
            .codex
            .session
            .send_event(
                interrupted_turn.as_ref(),
                EventMsg::TurnStarted(TurnStartedEvent {
                    turn_id: interrupted_turn.sub_id.clone(),
                    model_context_window: None,
                    collaboration_mode_kind: ModeKind::Default,
                }),
            )
            .await;
        child_thread
            .codex
            .session
            .send_event(
                interrupted_turn.as_ref(),
                EventMsg::TurnAborted(TurnAbortedEvent {
                    turn_id: Some(interrupted_turn.sub_id.clone()),
                    reason: TurnAbortReason::Interrupted,
                }),
            )
            .await;
        yield_now().await;
        // The interrupted turn must not have produced a parent notification yet.
        let history_items = parent_thread
            .codex
            .session
            .clone_history()
            .await
            .raw_items()
            .to_vec();
        assert_eq!(has_subagent_notification(&history_items), false);
        // The replacement input was submitted; a later turn now completes.
        harness
            .control
            .mark_interrupt_follow_up_committed(child_thread_id)
            .await;
        let completed_turn = child_thread.codex.session.new_default_turn().await;
        child_thread
            .codex
            .session
            .send_event(
                completed_turn.as_ref(),
                EventMsg::TurnStarted(TurnStartedEvent {
                    turn_id: completed_turn.sub_id.clone(),
                    model_context_window: None,
                    collaboration_mode_kind: ModeKind::Default,
                }),
            )
            .await;
        child_thread
            .codex
            .session
            .send_event(
                completed_turn.as_ref(),
                EventMsg::TurnComplete(TurnCompleteEvent {
                    turn_id: completed_turn.sub_id.clone(),
                    last_agent_message: Some("done".to_string()),
                }),
            )
            .await;
        assert_eq!(wait_for_subagent_notification(&parent_thread).await, true);
        // Only the completion should appear in the parent's history.
        let history_items = parent_thread
            .codex
            .session
            .clone_history()
            .await
            .raw_items()
            .to_vec();
        assert_eq!(
            history_contains_text(&history_items, "\"completed\":\"done\""),
            true
        );
        assert_eq!(
            history_contains_text(&history_items, "\"Interrupted\""),
            false
        );
    }
    /// End-to-end through the `send_input` tool with `interrupt: true`: the
    /// parent should ultimately be notified with the redirected turn's
    /// completion, never with the intermediate `Interrupted` status.
    #[tokio::test]
    async fn send_input_interruptive_resend_notifies_parent_with_completion() {
        let harness = AgentControlHarness::new().await;
        let (parent_thread_id, parent_thread) = harness.start_thread().await;
        let (child_thread_id, child_thread) = harness.start_thread().await;
        harness.control.maybe_start_completion_watcher(
            child_thread_id,
            Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
                parent_thread_id,
                depth: 1,
                agent_nickname: None,
                agent_role: Some("explorer".to_string()),
            })),
        );
        // Drive the real tool handler rather than poking the control directly.
        let invocation = function_invocation(
            Arc::clone(&parent_thread.codex.session),
            parent_thread.codex.session.new_default_turn().await,
            "send_input",
            json!({
                "id": child_thread_id.to_string(),
                "message": "wrap up",
                "interrupt": true
            }),
        );
        MultiAgentHandler
            .handle(invocation)
            .await
            .expect("send_input should succeed");
        // Old turn is aborted with the interrupt...
        let interrupted_turn = child_thread.codex.session.new_default_turn().await;
        child_thread
            .codex
            .session
            .send_event(
                interrupted_turn.as_ref(),
                EventMsg::TurnStarted(TurnStartedEvent {
                    turn_id: interrupted_turn.sub_id.clone(),
                    model_context_window: None,
                    collaboration_mode_kind: ModeKind::Default,
                }),
            )
            .await;
        child_thread
            .codex
            .session
            .send_event(
                interrupted_turn.as_ref(),
                EventMsg::TurnAborted(TurnAbortedEvent {
                    turn_id: Some(interrupted_turn.sub_id.clone()),
                    reason: TurnAbortReason::Interrupted,
                }),
            )
            .await;
        // ...and the redirected turn completes.
        let completed_turn = child_thread.codex.session.new_default_turn().await;
        child_thread
            .codex
            .session
            .send_event(
                completed_turn.as_ref(),
                EventMsg::TurnStarted(TurnStartedEvent {
                    turn_id: completed_turn.sub_id.clone(),
                    model_context_window: None,
                    collaboration_mode_kind: ModeKind::Default,
                }),
            )
            .await;
        child_thread
            .codex
            .session
            .send_event(
                completed_turn.as_ref(),
                EventMsg::TurnComplete(TurnCompleteEvent {
                    turn_id: completed_turn.sub_id.clone(),
                    last_agent_message: Some("done".to_string()),
                }),
            )
            .await;
        assert_eq!(wait_for_subagent_notification(&parent_thread).await, true);
        let history_items = parent_thread
            .codex
            .session
            .clone_history()
            .await
            .raw_items()
            .to_vec();
        assert_eq!(
            history_contains_text(&history_items, "\"completed\":\"done\""),
            true
        );
        assert_eq!(
            history_contains_text(&history_items, "\"Interrupted\""),
            false
        );
    }
    /// If the follow-up is cancelled (the replacement input never got
    /// submitted), the suppressed `Interrupted` status becomes reportable and
    /// the parent is notified with it.
    #[tokio::test]
    async fn completion_watcher_reports_interrupted_when_follow_up_is_cancelled() {
        let harness = AgentControlHarness::new().await;
        let (parent_thread_id, parent_thread) = harness.start_thread().await;
        let (child_thread_id, child_thread) = harness.start_thread().await;
        harness.control.maybe_start_completion_watcher(
            child_thread_id,
            Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
                parent_thread_id,
                depth: 1,
                agent_nickname: None,
                agent_role: Some("explorer".to_string()),
            })),
        );
        harness
            .control
            .mark_interrupt_follow_up_pending(child_thread_id)
            .await;
        // Child turn gets interrupted while the follow-up is still pending.
        let interrupted_turn = child_thread.codex.session.new_default_turn().await;
        child_thread
            .codex
            .session
            .send_event(
                interrupted_turn.as_ref(),
                EventMsg::TurnStarted(TurnStartedEvent {
                    turn_id: interrupted_turn.sub_id.clone(),
                    model_context_window: None,
                    collaboration_mode_kind: ModeKind::Default,
                }),
            )
            .await;
        child_thread
            .codex
            .session
            .send_event(
                interrupted_turn.as_ref(),
                EventMsg::TurnAborted(TurnAbortedEvent {
                    turn_id: Some(interrupted_turn.sub_id.clone()),
                    reason: TurnAbortReason::Interrupted,
                }),
            )
            .await;
        // Cancelling the follow-up releases the suppressed status.
        harness
            .control
            .mark_interrupt_follow_up_cancelled(child_thread_id)
            .await;
        assert_eq!(wait_for_subagent_notification(&parent_thread).await, true);
        let history_items = parent_thread
            .codex
            .session
            .clone_history()
            .await
            .raw_items()
            .to_vec();
        assert_eq!(
            history_contains_text(&history_items, "\"Interrupted\""),
            true
        );
    }
    /// Even if the follow-up is never committed or cancelled (cleanup lost),
    /// a later completed turn must still reach the parent — suppression only
    /// holds back the transient `Interrupted` status, not real completions.
    #[tokio::test]
    async fn completion_watcher_reports_completion_even_if_follow_up_cleanup_is_lost() {
        let harness = AgentControlHarness::new().await;
        let (parent_thread_id, parent_thread) = harness.start_thread().await;
        let (child_thread_id, child_thread) = harness.start_thread().await;
        harness.control.maybe_start_completion_watcher(
            child_thread_id,
            Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
                parent_thread_id,
                depth: 1,
                agent_nickname: None,
                agent_role: Some("explorer".to_string()),
            })),
        );
        // Follow-up is marked pending and intentionally never finalized.
        harness
            .control
            .mark_interrupt_follow_up_pending(child_thread_id)
            .await;
        let interrupted_turn = child_thread.codex.session.new_default_turn().await;
        child_thread
            .codex
            .session
            .send_event(
                interrupted_turn.as_ref(),
                EventMsg::TurnStarted(TurnStartedEvent {
                    turn_id: interrupted_turn.sub_id.clone(),
                    model_context_window: None,
                    collaboration_mode_kind: ModeKind::Default,
                }),
            )
            .await;
        child_thread
            .codex
            .session
            .send_event(
                interrupted_turn.as_ref(),
                EventMsg::TurnAborted(TurnAbortedEvent {
                    turn_id: Some(interrupted_turn.sub_id.clone()),
                    reason: TurnAbortReason::Interrupted,
                }),
            )
            .await;
        yield_now().await;
        // A later turn completes without the follow-up ever being committed.
        let completed_turn = child_thread.codex.session.new_default_turn().await;
        child_thread
            .codex
            .session
            .send_event(
                completed_turn.as_ref(),
                EventMsg::TurnStarted(TurnStartedEvent {
                    turn_id: completed_turn.sub_id.clone(),
                    model_context_window: None,
                    collaboration_mode_kind: ModeKind::Default,
                }),
            )
            .await;
        child_thread
            .codex
            .session
            .send_event(
                completed_turn.as_ref(),
                EventMsg::TurnComplete(TurnCompleteEvent {
                    turn_id: completed_turn.sub_id.clone(),
                    last_agent_message: Some("done".to_string()),
                }),
            )
            .await;
        assert_eq!(wait_for_subagent_notification(&parent_thread).await, true);
        let history_items = parent_thread
            .codex
            .session
            .clone_history()
            .await
            .raw_items()
            .to_vec();
        assert_eq!(
            history_contains_text(&history_items, "\"completed\":\"done\""),
            true
        );
        assert_eq!(
            history_contains_text(&history_items, "\"Interrupted\""),
            false
        );
    }
#[tokio::test]
async fn spawn_thread_subagent_gets_random_nickname_in_session_source() {
let harness = AgentControlHarness::new().await;

View File

@@ -123,6 +123,9 @@ pub struct NewThread {
/// them in memory.
pub struct ThreadManager {
state: Arc<ThreadManagerState>,
// AgentControl needs to be shared across clones so parent/child coordination state is visible
// to every code path that talks to subagents.
agent_control: AgentControl,
_test_codex_home_guard: Option<TempCodexHomeGuard>,
}
@@ -159,25 +162,27 @@ impl ThreadManager {
Arc::clone(&plugins_manager),
));
let file_watcher = build_file_watcher(codex_home.clone(), Arc::clone(&skills_manager));
let state = Arc::new(ThreadManagerState {
threads: Arc::new(RwLock::new(HashMap::new())),
thread_created_tx,
models_manager: Arc::new(ModelsManager::new(
codex_home,
auth_manager.clone(),
model_catalog,
collaboration_modes_config,
)),
skills_manager,
plugins_manager,
mcp_manager,
file_watcher,
auth_manager,
session_source,
ops_log: should_use_test_thread_manager_behavior()
.then(|| Arc::new(std::sync::Mutex::new(Vec::new()))),
});
Self {
state: Arc::new(ThreadManagerState {
threads: Arc::new(RwLock::new(HashMap::new())),
thread_created_tx,
models_manager: Arc::new(ModelsManager::new(
codex_home,
auth_manager.clone(),
model_catalog,
collaboration_modes_config,
)),
skills_manager,
plugins_manager,
mcp_manager,
file_watcher,
auth_manager,
session_source,
ops_log: should_use_test_thread_manager_behavior()
.then(|| Arc::new(std::sync::Mutex::new(Vec::new()))),
}),
agent_control: AgentControl::new(Arc::downgrade(&state)),
state,
_test_codex_home_guard: None,
}
}
@@ -218,24 +223,26 @@ impl ThreadManager {
Arc::clone(&plugins_manager),
));
let file_watcher = build_file_watcher(codex_home.clone(), Arc::clone(&skills_manager));
let state = Arc::new(ThreadManagerState {
threads: Arc::new(RwLock::new(HashMap::new())),
thread_created_tx,
models_manager: Arc::new(ModelsManager::with_provider_for_tests(
codex_home,
auth_manager.clone(),
provider,
)),
skills_manager,
plugins_manager,
mcp_manager,
file_watcher,
auth_manager,
session_source: SessionSource::Exec,
ops_log: should_use_test_thread_manager_behavior()
.then(|| Arc::new(std::sync::Mutex::new(Vec::new()))),
});
Self {
state: Arc::new(ThreadManagerState {
threads: Arc::new(RwLock::new(HashMap::new())),
thread_created_tx,
models_manager: Arc::new(ModelsManager::with_provider_for_tests(
codex_home,
auth_manager.clone(),
provider,
)),
skills_manager,
plugins_manager,
mcp_manager,
file_watcher,
auth_manager,
session_source: SessionSource::Exec,
ops_log: should_use_test_thread_manager_behavior()
.then(|| Arc::new(std::sync::Mutex::new(Vec::new()))),
}),
agent_control: AgentControl::new(Arc::downgrade(&state)),
state,
_test_codex_home_guard: None,
}
}
@@ -423,7 +430,7 @@ impl ThreadManager {
}
pub(crate) fn agent_control(&self) -> AgentControl {
AgentControl::new(Arc::downgrade(&self.state))
self.agent_control.clone()
}
#[cfg(test)]

View File

@@ -266,13 +266,31 @@ mod send_input {
.get_agent_nickname_and_role(receiver_thread_id)
.await
.unwrap_or((None, None));
if args.interrupt {
session
// Interruptive resend is modeled as "cancel the old turn, then queue replacement input".
// The follow-up handle suppresses the transient Interrupted status until the replacement
// either starts making progress or fails to submit.
let mut interrupt_follow_up = if args.interrupt {
Some(
session
.services
.agent_control
.begin_interrupt_follow_up(receiver_thread_id)
.await,
)
} else {
None
};
if args.interrupt
&& let Err(err) = session
.services
.agent_control
.interrupt_agent(receiver_thread_id)
.await
.map_err(|err| collab_agent_error(receiver_thread_id, err))?;
{
if let Some(follow_up) = interrupt_follow_up.as_mut() {
follow_up.cancel();
}
return Err(collab_agent_error(receiver_thread_id, err));
}
session
.send_event(
@@ -292,10 +310,16 @@ mod send_input {
.send_input(receiver_thread_id, input_items)
.await
.map_err(|err| collab_agent_error(receiver_thread_id, err));
if let Some(follow_up) = interrupt_follow_up.as_mut() {
match &result {
Ok(_) => follow_up.commit(),
Err(_) => follow_up.cancel(),
}
}
let status = session
.services
.agent_control
.get_status(receiver_thread_id)
.get_reportable_status(receiver_thread_id)
.await;
session
.send_event(
@@ -466,6 +490,7 @@ mod resume_agent {
pub(crate) mod wait {
use super::*;
use crate::agent::control::wait_for_reportable_status;
use crate::agent::status::is_final;
use futures::FutureExt;
use futures::StreamExt;
@@ -554,10 +579,17 @@ pub(crate) mod wait {
match session.services.agent_control.subscribe_status(*id).await {
Ok(rx) => {
let status = rx.borrow().clone();
if is_final(&status) {
if is_final(&status)
&& session
.services
.agent_control
.is_status_reportable_immediately(*id, &status)
.await
{
initial_final_statuses.push((*id, status));
} else {
status_rxs.push((*id, rx));
}
status_rxs.push((*id, rx));
}
Err(CodexErr::ThreadNotFound(_)) => {
initial_final_statuses.push((*id, AgentStatus::NotFound));
@@ -654,23 +686,11 @@ pub(crate) mod wait {
async fn wait_for_final_status(
session: Arc<Session>,
thread_id: ThreadId,
mut status_rx: Receiver<AgentStatus>,
status_rx: Receiver<AgentStatus>,
) -> Option<(ThreadId, AgentStatus)> {
let mut status = status_rx.borrow().clone();
if is_final(&status) {
return Some((thread_id, status));
}
loop {
if status_rx.changed().await.is_err() {
let latest = session.services.agent_control.get_status(thread_id).await;
return is_final(&latest).then_some((thread_id, latest));
}
status = status_rx.borrow().clone();
if is_final(&status) {
return Some((thread_id, status));
}
}
let status =
wait_for_reportable_status(&session.services.agent_control, thread_id, status_rx).await;
is_final(&status).then_some((thread_id, status))
}
}
@@ -995,10 +1015,16 @@ mod tests {
use crate::protocol::SubAgentSource;
use crate::turn_diff_tracker::TurnDiffTracker;
use codex_protocol::ThreadId;
use codex_protocol::config_types::ModeKind;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::InitialHistory;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::TurnAbortReason;
use codex_protocol::protocol::TurnAbortedEvent;
use codex_protocol::protocol::TurnCompleteEvent;
use codex_protocol::protocol::TurnStartedEvent;
use pretty_assertions::assert_eq;
use serde::Deserialize;
use serde_json::json;
@@ -1007,6 +1033,7 @@ mod tests {
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::task::yield_now;
use tokio::time::timeout;
fn invocation(
@@ -1986,6 +2013,125 @@ mod tests {
assert_eq!(success, None);
}
    /// The `wait` tool must not resolve on a transient `Interrupted` status
    /// while a follow-up is pending; after the follow-up is committed and a
    /// later turn completes, `wait` reports that completion.
    #[tokio::test]
    async fn wait_prefers_later_completion_after_interruptive_resend() {
        let (mut session, turn) = make_session_and_context().await;
        let manager = thread_manager();
        // Share the manager's AgentControl so the wait path sees the same
        // follow-up state that the test manipulates below.
        session.services.agent_control = manager.agent_control();
        let config = turn.config.as_ref().clone();
        let thread = manager.start_thread(config).await.expect("start thread");
        let agent_id = thread.thread_id;
        let invocation = invocation(
            Arc::new(session),
            Arc::new(turn),
            "wait",
            function_payload(json!({
                "ids": [agent_id.to_string()],
                "timeout_ms": MIN_WAIT_TIMEOUT_MS
            })),
        );
        let mut wait_task = tokio::spawn(async move { MultiAgentHandler.handle(invocation).await });
        yield_now().await;
        manager
            .agent_control()
            .mark_interrupt_follow_up_pending(agent_id)
            .await;
        // Child turn is interrupted while the follow-up is pending.
        let interrupted_turn = thread.thread.codex.session.new_default_turn().await;
        thread
            .thread
            .codex
            .session
            .send_event(
                interrupted_turn.as_ref(),
                EventMsg::TurnStarted(TurnStartedEvent {
                    turn_id: interrupted_turn.sub_id.clone(),
                    model_context_window: None,
                    collaboration_mode_kind: ModeKind::Default,
                }),
            )
            .await;
        thread
            .thread
            .codex
            .session
            .send_event(
                interrupted_turn.as_ref(),
                EventMsg::TurnAborted(TurnAbortedEvent {
                    turn_id: Some(interrupted_turn.sub_id.clone()),
                    reason: TurnAbortReason::Interrupted,
                }),
            )
            .await;
        // The wait must still be blocked on the suppressed status.
        let early = timeout(Duration::from_millis(100), &mut wait_task).await;
        assert!(
            early.is_err(),
            "wait should not resolve on the interrupted turn while a follow-up is still pending"
        );
        manager
            .agent_control()
            .mark_interrupt_follow_up_committed(agent_id)
            .await;
        // The redirected turn completes; wait should now resolve with it.
        let completed_turn = thread.thread.codex.session.new_default_turn().await;
        thread
            .thread
            .codex
            .session
            .send_event(
                completed_turn.as_ref(),
                EventMsg::TurnStarted(TurnStartedEvent {
                    turn_id: completed_turn.sub_id.clone(),
                    model_context_window: None,
                    collaboration_mode_kind: ModeKind::Default,
                }),
            )
            .await;
        thread
            .thread
            .codex
            .session
            .send_event(
                completed_turn.as_ref(),
                EventMsg::TurnComplete(TurnCompleteEvent {
                    turn_id: completed_turn.sub_id.clone(),
                    last_agent_message: Some("done".to_string()),
                }),
            )
            .await;
        let output = timeout(Duration::from_secs(1), wait_task)
            .await
            .expect("wait should complete after the redirected turn finishes")
            .expect("wait task should join")
            .expect("wait should succeed");
        let ToolOutput::Function {
            body: FunctionCallOutputBody::Text(content),
            success,
            ..
        } = output
        else {
            panic!("expected function output");
        };
        let result: wait::WaitResult =
            serde_json::from_str(&content).expect("wait result should be json");
        assert_eq!(
            result,
            wait::WaitResult {
                status: HashMap::from([(
                    agent_id,
                    AgentStatus::Completed(Some("done".to_string()))
                )]),
                timed_out: false
            }
        );
        assert_eq!(success, None);
    }
#[tokio::test]
async fn close_agent_submits_shutdown_and_returns_status() {
let (mut session, turn) = make_session_and_context().await;