mirror of
https://github.com/openai/codex.git
synced 2026-05-06 06:12:59 +03:00
feat: sub-agent injection (#12152)
This PR adds parent-thread sub-agent completion notifications and change the prompt of the model to prevent if from being confused
This commit is contained in:
@@ -1,11 +1,14 @@
|
||||
use crate::agent::AgentStatus;
|
||||
use crate::agent::guards::Guards;
|
||||
use crate::agent::status::is_final;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result as CodexResult;
|
||||
use crate::session_prefix::format_subagent_notification_message;
|
||||
use crate::thread_manager::ThreadManagerState;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::protocol::Op;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::SubAgentSource;
|
||||
use codex_protocol::protocol::TokenUsage;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use std::path::PathBuf;
|
||||
@@ -46,6 +49,7 @@ impl AgentControl {
|
||||
) -> CodexResult<ThreadId> {
|
||||
let state = self.upgrade()?;
|
||||
let reservation = self.state.reserve_spawn_slot(config.agent_max_threads)?;
|
||||
let notification_source = session_source.clone();
|
||||
|
||||
// The same `AgentControl` is sent to spawn the thread.
|
||||
let new_thread = match session_source {
|
||||
@@ -64,6 +68,7 @@ impl AgentControl {
|
||||
state.notify_thread_created(new_thread.thread_id);
|
||||
|
||||
self.send_input(new_thread.thread_id, items).await?;
|
||||
self.maybe_start_completion_watcher(new_thread.thread_id, notification_source);
|
||||
|
||||
Ok(new_thread.thread_id)
|
||||
}
|
||||
@@ -77,6 +82,7 @@ impl AgentControl {
|
||||
) -> CodexResult<ThreadId> {
|
||||
let state = self.upgrade()?;
|
||||
let reservation = self.state.reserve_spawn_slot(config.agent_max_threads)?;
|
||||
let notification_source = session_source.clone();
|
||||
|
||||
let resumed_thread = state
|
||||
.resume_thread_from_rollout_with_source(
|
||||
@@ -90,6 +96,7 @@ impl AgentControl {
|
||||
// Resumed threads are re-registered in-memory and need the same listener
|
||||
// attachment path as freshly spawned threads.
|
||||
state.notify_thread_created(resumed_thread.thread_id);
|
||||
self.maybe_start_completion_watcher(resumed_thread.thread_id, Some(notification_source));
|
||||
|
||||
Ok(resumed_thread.thread_id)
|
||||
}
|
||||
@@ -164,13 +171,60 @@ impl AgentControl {
|
||||
thread.total_token_usage().await
|
||||
}
|
||||
|
||||
/// Starts a detached watcher for sub-agents spawned from another thread.
|
||||
///
|
||||
/// This is only enabled for `SubAgentSource::ThreadSpawn`, where a parent thread exists and
|
||||
/// can receive completion notifications.
|
||||
fn maybe_start_completion_watcher(
|
||||
&self,
|
||||
child_thread_id: ThreadId,
|
||||
session_source: Option<SessionSource>,
|
||||
) {
|
||||
let Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
|
||||
parent_thread_id, ..
|
||||
})) = session_source
|
||||
else {
|
||||
return;
|
||||
};
|
||||
let control = self.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut status_rx = match control.subscribe_status(child_thread_id).await {
|
||||
Ok(rx) => rx,
|
||||
Err(_) => return,
|
||||
};
|
||||
let mut status = status_rx.borrow().clone();
|
||||
while !is_final(&status) {
|
||||
if status_rx.changed().await.is_err() {
|
||||
status = control.get_status(child_thread_id).await;
|
||||
break;
|
||||
}
|
||||
status = status_rx.borrow().clone();
|
||||
}
|
||||
if !is_final(&status) {
|
||||
return;
|
||||
}
|
||||
|
||||
let Ok(state) = control.upgrade() else {
|
||||
return;
|
||||
};
|
||||
let Ok(parent_thread) = state.get_thread(parent_thread_id).await else {
|
||||
return;
|
||||
};
|
||||
parent_thread
|
||||
.inject_user_message_without_turn(format_subagent_notification_message(
|
||||
&child_thread_id.to_string(),
|
||||
&status,
|
||||
))
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
fn upgrade(&self) -> CodexResult<Arc<ThreadManagerState>> {
|
||||
self.manager
|
||||
.upgrade()
|
||||
.ok_or_else(|| CodexErr::UnsupportedOperation("thread manager dropped".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -180,16 +234,24 @@ mod tests {
|
||||
use crate::agent::agent_status_from_event;
|
||||
use crate::config::Config;
|
||||
use crate::config::ConfigBuilder;
|
||||
use crate::session_prefix::SUBAGENT_NOTIFICATION_OPEN_TAG;
|
||||
use assert_matches::assert_matches;
|
||||
use codex_protocol::config_types::ModeKind;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::ErrorEvent;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::SubAgentSource;
|
||||
use codex_protocol::protocol::TurnAbortReason;
|
||||
use codex_protocol::protocol::TurnAbortedEvent;
|
||||
use codex_protocol::protocol::TurnCompleteEvent;
|
||||
use codex_protocol::protocol::TurnStartedEvent;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
use tokio::time::timeout;
|
||||
use toml::Value as TomlValue;
|
||||
|
||||
async fn test_config_with_cli_overrides(
|
||||
@@ -250,6 +312,42 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn has_subagent_notification(history_items: &[ResponseItem]) -> bool {
|
||||
history_items.iter().any(|item| {
|
||||
let ResponseItem::Message { role, content, .. } = item else {
|
||||
return false;
|
||||
};
|
||||
if role != "user" {
|
||||
return false;
|
||||
}
|
||||
content.iter().any(|content_item| match content_item {
|
||||
ContentItem::InputText { text } | ContentItem::OutputText { text } => {
|
||||
text.contains(SUBAGENT_NOTIFICATION_OPEN_TAG)
|
||||
}
|
||||
ContentItem::InputImage { .. } => false,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
async fn wait_for_subagent_notification(parent_thread: &Arc<CodexThread>) -> bool {
|
||||
let wait = async {
|
||||
loop {
|
||||
let history_items = parent_thread
|
||||
.codex
|
||||
.session
|
||||
.clone_history()
|
||||
.await
|
||||
.raw_items()
|
||||
.to_vec();
|
||||
if has_subagent_notification(&history_items) {
|
||||
return true;
|
||||
}
|
||||
sleep(Duration::from_millis(25)).await;
|
||||
}
|
||||
};
|
||||
timeout(Duration::from_secs(2), wait).await.is_ok()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_input_errors_when_manager_dropped() {
|
||||
let control = AgentControl::default();
|
||||
@@ -683,4 +781,35 @@ mod tests {
|
||||
.await
|
||||
.expect("shutdown resumed thread");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn spawn_child_completion_notifies_parent_history() {
|
||||
let harness = AgentControlHarness::new().await;
|
||||
let (parent_thread_id, parent_thread) = harness.start_thread().await;
|
||||
|
||||
let child_thread_id = harness
|
||||
.control
|
||||
.spawn_agent(
|
||||
harness.config.clone(),
|
||||
text_input("hello child"),
|
||||
Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn {
|
||||
parent_thread_id,
|
||||
depth: 1,
|
||||
})),
|
||||
)
|
||||
.await
|
||||
.expect("child spawn should succeed");
|
||||
|
||||
let child_thread = harness
|
||||
.manager
|
||||
.get_thread(child_thread_id)
|
||||
.await
|
||||
.expect("child thread should exist");
|
||||
let _ = child_thread
|
||||
.submit(Op::Shutdown {})
|
||||
.await
|
||||
.expect("child shutdown should submit");
|
||||
|
||||
assert_eq!(wait_for_subagent_notification(&parent_thread).await, true);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user