mirror of
https://github.com/openai/codex.git
synced 2026-05-01 03:42:05 +03:00
feat: sub-agent injection (#12152)
This PR adds parent-thread sub-agent completion notifications and change the prompt of the model to prevent if from being confused
This commit is contained in:
@@ -114,6 +114,7 @@ mod skills;
|
||||
mod sqlite_state;
|
||||
mod stream_error_allows_next_turn;
|
||||
mod stream_no_completed;
|
||||
mod subagent_notifications;
|
||||
mod text_encoding_fix;
|
||||
mod tool_harness;
|
||||
mod tool_parallelism;
|
||||
|
||||
196
codex-rs/core/tests/suite/subagent_notifications.rs
Normal file
196
codex-rs/core/tests/suite/subagent_notifications.rs
Normal file
@@ -0,0 +1,196 @@
|
||||
use anyhow::Result;
|
||||
use codex_core::features::Feature;
|
||||
use core_test_support::responses::ResponsesRequest;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::mount_response_once_match;
|
||||
use core_test_support::responses::mount_sse_once_match;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::sse_response;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::TestCodex;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use serde_json::json;
|
||||
use std::time::Duration;
|
||||
use tokio::time::Instant;
|
||||
use tokio::time::sleep;
|
||||
use wiremock::MockServer;
|
||||
|
||||
const SPAWN_CALL_ID: &str = "spawn-call-1";
|
||||
const TURN_1_PROMPT: &str = "spawn a child and continue";
|
||||
const TURN_2_NO_WAIT_PROMPT: &str = "follow up without wait";
|
||||
const CHILD_PROMPT: &str = "child: do work";
|
||||
|
||||
fn body_contains(req: &wiremock::Request, text: &str) -> bool {
|
||||
let is_zstd = req
|
||||
.headers
|
||||
.get("content-encoding")
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.is_some_and(|value| {
|
||||
value
|
||||
.split(',')
|
||||
.any(|entry| entry.trim().eq_ignore_ascii_case("zstd"))
|
||||
});
|
||||
let bytes = if is_zstd {
|
||||
zstd::stream::decode_all(std::io::Cursor::new(&req.body)).ok()
|
||||
} else {
|
||||
Some(req.body.clone())
|
||||
};
|
||||
bytes
|
||||
.and_then(|body| String::from_utf8(body).ok())
|
||||
.is_some_and(|body| body.contains(text))
|
||||
}
|
||||
|
||||
fn has_subagent_notification(req: &ResponsesRequest) -> bool {
|
||||
req.message_input_texts("user")
|
||||
.iter()
|
||||
.any(|text| text.contains("<subagent_notification>"))
|
||||
}
|
||||
|
||||
async fn wait_for_spawned_thread_id(test: &TestCodex) -> Result<String> {
|
||||
let deadline = Instant::now() + Duration::from_secs(2);
|
||||
loop {
|
||||
let ids = test.thread_manager.list_thread_ids().await;
|
||||
if let Some(spawned_id) = ids
|
||||
.iter()
|
||||
.find(|id| **id != test.session_configured.session_id)
|
||||
{
|
||||
return Ok(spawned_id.to_string());
|
||||
}
|
||||
if Instant::now() >= deadline {
|
||||
anyhow::bail!("timed out waiting for spawned thread id");
|
||||
}
|
||||
sleep(Duration::from_millis(10)).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_for_requests(
|
||||
mock: &core_test_support::responses::ResponseMock,
|
||||
) -> Result<Vec<ResponsesRequest>> {
|
||||
let deadline = Instant::now() + Duration::from_secs(2);
|
||||
loop {
|
||||
let requests = mock.requests();
|
||||
if !requests.is_empty() {
|
||||
return Ok(requests);
|
||||
}
|
||||
if Instant::now() >= deadline {
|
||||
anyhow::bail!("expected at least 1 request, got {}", requests.len());
|
||||
}
|
||||
sleep(Duration::from_millis(10)).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn setup_turn_one_with_spawned_child(
|
||||
server: &MockServer,
|
||||
child_response_delay: Option<Duration>,
|
||||
) -> Result<(TestCodex, String)> {
|
||||
let spawn_args = serde_json::to_string(&json!({
|
||||
"message": CHILD_PROMPT,
|
||||
}))?;
|
||||
|
||||
mount_sse_once_match(
|
||||
server,
|
||||
|req: &wiremock::Request| body_contains(req, TURN_1_PROMPT),
|
||||
sse(vec![
|
||||
ev_response_created("resp-turn1-1"),
|
||||
ev_function_call(SPAWN_CALL_ID, "spawn_agent", &spawn_args),
|
||||
ev_completed("resp-turn1-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
let child_sse = sse(vec![
|
||||
ev_response_created("resp-child-1"),
|
||||
ev_assistant_message("msg-child-1", "child done"),
|
||||
ev_completed("resp-child-1"),
|
||||
]);
|
||||
let child_request_log = if let Some(delay) = child_response_delay {
|
||||
mount_response_once_match(
|
||||
server,
|
||||
|req: &wiremock::Request| {
|
||||
body_contains(req, CHILD_PROMPT) && !body_contains(req, SPAWN_CALL_ID)
|
||||
},
|
||||
sse_response(child_sse).set_delay(delay),
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
mount_sse_once_match(
|
||||
server,
|
||||
|req: &wiremock::Request| {
|
||||
body_contains(req, CHILD_PROMPT) && !body_contains(req, SPAWN_CALL_ID)
|
||||
},
|
||||
child_sse,
|
||||
)
|
||||
.await
|
||||
};
|
||||
|
||||
let _turn1_followup = mount_sse_once_match(
|
||||
server,
|
||||
|req: &wiremock::Request| body_contains(req, SPAWN_CALL_ID),
|
||||
sse(vec![
|
||||
ev_response_created("resp-turn1-2"),
|
||||
ev_assistant_message("msg-turn1-2", "parent done"),
|
||||
ev_completed("resp-turn1-2"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.features.enable(Feature::Collab);
|
||||
});
|
||||
let test = builder.build(server).await?;
|
||||
test.submit_turn(TURN_1_PROMPT).await?;
|
||||
if child_response_delay.is_none() {
|
||||
let _ = wait_for_requests(&child_request_log).await?;
|
||||
let rollout_path = test
|
||||
.codex
|
||||
.rollout_path()
|
||||
.ok_or_else(|| anyhow::anyhow!("expected parent rollout path"))?;
|
||||
let deadline = Instant::now() + Duration::from_secs(6);
|
||||
loop {
|
||||
let has_notification = tokio::fs::read_to_string(&rollout_path)
|
||||
.await
|
||||
.is_ok_and(|rollout| rollout.contains("<subagent_notification>"));
|
||||
if has_notification {
|
||||
break;
|
||||
}
|
||||
if Instant::now() >= deadline {
|
||||
anyhow::bail!(
|
||||
"timed out waiting for parent rollout to include subagent notification"
|
||||
);
|
||||
}
|
||||
sleep(Duration::from_millis(10)).await;
|
||||
}
|
||||
}
|
||||
let spawned_id = wait_for_spawned_thread_id(&test).await?;
|
||||
|
||||
Ok((test, spawned_id))
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn subagent_notification_is_included_without_wait() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let (test, _spawned_id) = setup_turn_one_with_spawned_child(&server, None).await?;
|
||||
|
||||
let turn2 = mount_sse_once_match(
|
||||
&server,
|
||||
|req: &wiremock::Request| body_contains(req, TURN_2_NO_WAIT_PROMPT),
|
||||
sse(vec![
|
||||
ev_response_created("resp-turn2-1"),
|
||||
ev_assistant_message("msg-turn2-1", "no wait path"),
|
||||
ev_completed("resp-turn2-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
test.submit_turn(TURN_2_NO_WAIT_PROMPT).await?;
|
||||
|
||||
let turn2_requests = wait_for_requests(&turn2).await?;
|
||||
assert!(turn2_requests.iter().any(has_subagent_notification));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user