steering and queue tests

This commit is contained in:
Roy Han
2026-03-12 15:40:22 -07:00
parent 8270e0b977
commit 9a7549c9bd

View File

@@ -9,6 +9,8 @@ use codex_protocol::items::AgentMessageContent;
use codex_protocol::items::TurnItem;
use codex_protocol::models::WebSearchAction;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::SandboxPolicy;
use codex_protocol::protocol::ItemCompletedEvent;
use codex_protocol::protocol::ItemStartedEvent;
use codex_protocol::protocol::Op;
@@ -17,6 +19,7 @@ use codex_protocol::user_input::TextElement;
use codex_protocol::user_input::UserInput;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::ev_image_generation_call;
use core_test_support::responses::ev_message_item_added;
use core_test_support::responses::ev_output_text_delta;
@@ -28,6 +31,7 @@ use core_test_support::responses::ev_response_created;
use core_test_support::responses::ev_web_search_call_added_partial;
use core_test_support::responses::ev_web_search_call_done;
use core_test_support::responses::mount_sse_once;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
@@ -36,11 +40,58 @@ use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use core_test_support::wait_for_event_match;
use pretty_assertions::assert_eq;
use serde_json::json;
use serde_json::Value;
use tempfile::tempdir;
use tokio::time::Duration;
use tokio::time::Instant;
use tokio::time::sleep;
fn user_message_item_by_text<'a>(input: &'a [Value], text: &str) -> &'a Value {
input
.iter()
.find(|item| {
if item.get("type").and_then(Value::as_str) != Some("message")
|| item.get("role").and_then(Value::as_str) != Some("user")
{
return false;
}
item.get("content")
.and_then(Value::as_array)
.is_some_and(|content| {
content.iter().any(|entry| {
entry.get("type").and_then(Value::as_str) == Some("input_text")
&& entry.get("text").and_then(Value::as_str) == Some(text)
})
})
})
.unwrap_or_else(|| panic!("submitted user message input item not found for text: {text}"))
}
fn user_message_item_containing<'a>(input: &'a [Value], needle: &str) -> &'a Value {
input
.iter()
.find(|item| {
if item.get("type").and_then(Value::as_str) != Some("message")
|| item.get("role").and_then(Value::as_str) != Some("user")
{
return false;
}
item.get("content")
.and_then(Value::as_array)
.is_some_and(|content| {
content.iter().any(|entry| {
entry.get("type").and_then(Value::as_str) == Some("input_text")
&& entry
.get("text")
.and_then(Value::as_str)
.is_some_and(|text| text.contains(needle))
})
})
})
.unwrap_or_else(|| panic!("submitted user message input item not found containing: {needle}"))
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn user_message_item_is_emitted() -> anyhow::Result<()> {
skip_if_no_network!(Ok(()));
@@ -204,6 +255,261 @@ async fn user_message_type_metadata_is_emitted_when_feature_enabled() -> anyhow:
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn user_message_type_prompt_steering_metadata_is_emitted_when_feature_enabled()
-> anyhow::Result<()> {
skip_if_no_network!(Ok(()));
let server = start_mock_server().await;
let temp = tempdir()?;
let unblock_path = temp.path().join("unblock-steering");
let command = format!(
"while [ ! -f \"{}\" ]; do sleep 0.01; done; echo done",
unblock_path.display()
);
let call_id = "shell-steering-call";
let responses = mount_sse_sequence(
&server,
vec![
sse(vec![
ev_response_created("resp-1"),
ev_function_call(call_id, "shell", &serde_json::to_string(&json!({
"command": ["/bin/sh", "-c", command],
}))?),
ev_completed("resp-1"),
]),
sse(vec![
ev_assistant_message("msg-2", "done"),
ev_completed("resp-2"),
]),
],
)
.await;
let test = test_codex()
.with_model("gpt-5")
.with_config(|config| {
config
.features
.enable(Feature::UserMessageTypeMetadata)
.expect("feature flag should be enabled for this test");
})
.build(&server)
.await?;
let codex = test.codex.clone();
let turn_model = test.session_configured.model.clone();
codex
.submit(Op::UserTurn {
items: vec![UserInput::Text {
text: "start steering flow".into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
cwd: test.cwd_path().to_path_buf(),
approval_policy: AskForApproval::Never,
sandbox_policy: SandboxPolicy::DangerFullAccess,
model: turn_model,
effort: None,
summary: None,
service_tier: None,
collaboration_mode: None,
personality: None,
})
.await?;
let turn_id = wait_for_event_match(&codex, |ev| match ev {
EventMsg::TurnStarted(event) => Some(event.turn_id.clone()),
_ => None,
})
.await;
wait_for_event_match(&codex, |ev| match ev {
EventMsg::ExecCommandBegin(event) if event.call_id == call_id => Some(()),
_ => None,
})
.await;
let steering_text = "steering metadata check";
let steered_turn_id = codex
.steer_input(
vec![UserInput::Text {
text: steering_text.into(),
text_elements: Vec::new(),
}],
Some(turn_id.as_str()),
)
.await
.expect("steer should succeed on active turn");
assert_eq!(steered_turn_id, turn_id);
std::fs::write(&unblock_path, "go")?;
wait_for_event(&codex, |ev| match ev {
EventMsg::TurnComplete(event) => event.turn_id == turn_id,
_ => false,
})
.await;
let deadline = Instant::now() + Duration::from_secs(3);
while responses.requests().len() < 2 {
if Instant::now() >= deadline {
panic!("timed out waiting for second responses request");
}
sleep(Duration::from_millis(10)).await;
}
let requests = responses.requests();
let second_body = requests
.get(1)
.expect("second responses request should be present")
.body_json();
let input = second_body
.get("input")
.and_then(Value::as_array)
.expect("request input array");
let steered_message = user_message_item_by_text(input, steering_text);
assert_eq!(
steered_message
.get("metadata")
.and_then(|metadata| metadata.get("user_message_type"))
.and_then(Value::as_str),
Some("prompt_steering")
);
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn user_message_type_prompt_queued_metadata_is_emitted_when_feature_enabled()
-> anyhow::Result<()> {
skip_if_no_network!(Ok(()));
let server = start_mock_server().await;
let temp = tempdir()?;
let unblock_path = temp.path().join("unblock-queued");
let command = format!(
"while [ ! -f \"{}\" ]; do sleep 0.01; done; echo done",
unblock_path.display()
);
let call_id = "shell-queued-call";
let responses = mount_sse_sequence(
&server,
vec![
sse(vec![
ev_response_created("resp-1"),
ev_function_call(call_id, "shell", &serde_json::to_string(&json!({
"command": ["/bin/sh", "-c", command],
}))?),
ev_completed("resp-1"),
]),
sse(vec![
ev_assistant_message("msg-2", "done"),
ev_completed("resp-2"),
]),
],
)
.await;
let test = test_codex()
.with_model("gpt-5")
.with_config(|config| {
config
.features
.enable(Feature::UserMessageTypeMetadata)
.expect("feature flag should be enabled for this test");
})
.build(&server)
.await?;
let codex = test.codex.clone();
let turn_model = test.session_configured.model.clone();
codex
.submit(Op::UserTurn {
items: vec![UserInput::Text {
text: "start queued flow".into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
cwd: test.cwd_path().to_path_buf(),
approval_policy: AskForApproval::Never,
sandbox_policy: SandboxPolicy::DangerFullAccess,
model: turn_model,
effort: None,
summary: None,
service_tier: None,
collaboration_mode: None,
personality: None,
})
.await?;
let turn_id = wait_for_event_match(&codex, |ev| match ev {
EventMsg::TurnStarted(event) => Some(event.turn_id.clone()),
_ => None,
})
.await;
wait_for_event_match(&codex, |ev| match ev {
EventMsg::ExecCommandBegin(event) if event.call_id == call_id => Some(()),
_ => None,
})
.await;
let queued_text = "queued metadata check";
codex
.submit(Op::RunUserShellCommand {
command: format!("printf '{queued_text}'"),
})
.await?;
wait_for_event_match(&codex, |ev| match ev {
EventMsg::ExecCommandEnd(event)
if event.source == codex_protocol::protocol::ExecCommandSource::UserShell =>
{
Some(())
}
_ => None,
})
.await;
std::fs::write(&unblock_path, "go")?;
wait_for_event(&codex, |ev| match ev {
EventMsg::TurnComplete(event) => event.turn_id == turn_id,
_ => false,
})
.await;
let deadline = Instant::now() + Duration::from_secs(3);
while responses.requests().len() < 2 {
if Instant::now() >= deadline {
panic!("timed out waiting for second responses request");
}
sleep(Duration::from_millis(10)).await;
}
let requests = responses.requests();
let second_body = requests
.get(1)
.expect("second responses request should be present")
.body_json();
let input = second_body
.get("input")
.and_then(Value::as_array)
.expect("request input array");
let queued_message = user_message_item_containing(input, queued_text);
assert_eq!(
queued_message
.get("metadata")
.and_then(|metadata| metadata.get("user_message_type"))
.and_then(Value::as_str),
Some("prompt_queued")
);
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn assistant_message_item_is_emitted() -> anyhow::Result<()> {
skip_if_no_network!(Ok(()));