Compare commits

...

1 Commit

Author SHA1 Message Date
Ee Durbin
5880fd7913 Core: pause pending steers after usage limits
Co-authored-by: Codex <noreply@openai.com>
2026-05-11 16:51:33 -07:00
9 changed files with 426 additions and 25 deletions

View File

@@ -289,6 +289,7 @@ use crate::shell;
use crate::shell_snapshot::ShellSnapshot;
use crate::state::ActiveTurn;
use crate::state::MailboxDeliveryPhase;
use crate::state::PendingInputItem;
use crate::state::PendingRequestPermissions;
use crate::state::SessionServices;
use crate::state::SessionState;
@@ -3075,7 +3076,7 @@ impl Session {
}
let mut turn_state = active_turn.turn_state.lock().await;
turn_state.push_pending_input(input.into());
turn_state.push_pending_steer_input(input.into());
turn_state.accept_mailbox_delivery_for_current_turn();
Ok(active_turn_id.clone())
}
@@ -3125,6 +3126,27 @@ impl Session {
.set_mailbox_delivery_phase(MailboxDeliveryPhase::CurrentTurn);
}
/// Records on the turn state belonging to `sub_id` that a usage limit was hit.
///
/// Does nothing when `turn_state_for_sub_id` yields no state for `sub_id`.
pub(crate) async fn mark_usage_limit_reached(&self, sub_id: &str) {
    if let Some(state) = self.turn_state_for_sub_id(sub_id).await {
        let mut guard = state.lock().await;
        guard.mark_usage_limit_reached();
    }
}
/// Reports whether the currently active turn has been flagged as usage-limited.
///
/// Returns `false` when there is no active turn.
pub(crate) async fn usage_limit_reached_for_active_turn(&self) -> bool {
    // Clone the shared handle in a single statement so the active-turn guard is
    // dropped before the turn-state lock is awaited below.
    let shared_state = self
        .active_turn
        .lock()
        .await
        .as_ref()
        .map(|current| Arc::clone(&current.turn_state));
    match shared_state {
        Some(state) => state.lock().await.usage_limit_reached(),
        None => false,
    }
}
pub(crate) async fn record_memory_citation_for_turn(&self, sub_id: &str) {
let turn_state = self.turn_state_for_sub_id(sub_id).await;
let Some(turn_state) = turn_state else {
@@ -3166,30 +3188,41 @@ impl Session {
clippy::await_holding_invalid_type,
reason = "active turn checks and turn state updates must remain atomic"
)]
pub async fn prepend_pending_input(&self, input: Vec<ResponseInputItem>) -> Result<(), ()> {
pub(crate) async fn prepend_pending_input_items(
&self,
input: Vec<PendingInputItem>,
) -> Result<(), ()> {
let mut active = self.active_turn.lock().await;
match active.as_mut() {
Some(at) => {
let mut ts = at.turn_state.lock().await;
ts.prepend_pending_input(input);
ts.prepend_pending_input_items(input);
Ok(())
}
None => Err(()),
}
}
pub async fn get_pending_input(&self) -> Vec<ResponseInputItem> {
self.get_pending_input_items()
.await
.into_iter()
.map(PendingInputItem::into_response_input_item)
.collect()
}
#[expect(
clippy::await_holding_invalid_type,
reason = "active turn checks and turn state updates must remain atomic"
)]
pub async fn get_pending_input(&self) -> Vec<ResponseInputItem> {
pub(crate) async fn get_pending_input_items(&self) -> Vec<PendingInputItem> {
let (pending_input, accepts_mailbox_delivery) = {
let mut active = self.active_turn.lock().await;
match active.as_mut() {
Some(at) => {
let mut ts = at.turn_state.lock().await;
(
ts.take_pending_input(),
ts.take_pending_input_items(),
ts.accepts_mailbox_delivery_for_current_turn(),
)
}
@@ -3204,7 +3237,7 @@ impl Session {
mailbox_rx
.drain()
.into_iter()
.map(|mail| mail.to_response_input_item())
.map(|mail| PendingInputItem::injected(mail.to_response_input_item()))
.collect::<Vec<_>>()
};
if pending_input.is_empty() {

View File

@@ -4,6 +4,7 @@ use crate::config::ConfigBuilder;
use crate::config::test_config;
use crate::context::ContextualUserFragment;
use crate::context::TurnAborted;
use crate::context::UserShellCommand;
use crate::exec::ExecCapturePolicy;
use crate::function_tool::FunctionCallError;
use crate::shell::default_user_shell;
@@ -59,6 +60,7 @@ use crate::goals::GoalRuntimeEvent;
use crate::goals::SetGoalRequest;
use crate::rollout::recorder::RolloutRecorder;
use crate::state::ActiveTurn;
use crate::state::PendingInputItem;
use crate::state::TaskKind;
use crate::tasks::SessionTask;
use crate::tasks::SessionTaskContext;
@@ -7388,6 +7390,132 @@ async fn task_finish_emits_turn_item_lifecycle_for_leftover_pending_user_input()
));
}
/// Checks that a steer queued on the active turn is NOT written to history when the
/// turn finishes after the usage limit was marked, and that `TurnComplete` is still
/// the next event emitted for that turn.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn task_finish_discards_leftover_pending_steer_after_usage_limit() {
let (sess, tc, rx) = make_session_and_context_with_rx().await;
let input = vec![UserInput::Text {
text: "hello".to_string(),
text_elements: Vec::new(),
}];
// Spawn a task so there is an active turn for the steer below to attach to.
sess.spawn_task(
Arc::clone(&tc),
input,
NeverEndingTask {
kind: TaskKind::Regular,
listen_to_cancellation_token: false,
},
)
.await;
// Drain events that are already queued so the assertion at the end only sees
// events produced after this point.
while rx.try_recv().is_ok() {}
// Queue a steer on the active turn; it stays pending because nothing drains it here.
sess.steer_input(
vec![UserInput::Text {
text: "late pending input".to_string(),
text_elements: Vec::new(),
}],
Some(&tc.sub_id),
/*responsesapi_client_metadata*/ None,
)
.await
.expect("steer active turn");
// Flag the turn as usage-limited before finishing it.
sess.mark_usage_limit_reached(&tc.sub_id).await;
sess.on_task_finished(Arc::clone(&tc), /*last_agent_message*/ None)
.await;
let history = sess.clone_history().await;
// The user message the steer would have produced had it been recorded.
let unexpected = ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "late pending input".to_string(),
}],
phase: None,
};
assert!(
!history.raw_items().iter().any(|item| item == &unexpected),
"expected pending input to be discarded after usage-limit completion"
);
// The first event after the drain must arrive within two seconds and be the
// `TurnComplete` for this turn.
let first = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
.await
.expect("expected turn complete event")
.expect("channel open");
assert!(matches!(
first.msg,
EventMsg::TurnComplete(TurnCompleteEvent {
turn_id,
last_agent_message: None,
time_to_first_token_ms: None,
..
}) if turn_id == tc.sub_id
));
}
/// Checks that items queued through `inject_response_items` are still written to
/// history when the turn finishes after the usage limit was marked.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn task_finish_preserves_injected_pending_context_after_usage_limit() {
let (sess, tc, rx) = make_session_and_context_with_rx().await;
let input = vec![UserInput::Text {
text: "hello".to_string(),
text_elements: Vec::new(),
}];
// Spawn a task so there is an active turn to inject into.
sess.spawn_task(
Arc::clone(&tc),
input,
NeverEndingTask {
kind: TaskKind::Regular,
listen_to_cancellation_token: false,
},
)
.await;
// Drain events that are already queued.
while rx.try_recv().is_ok() {}
// Convert a user shell command fragment into a `ResponseItem::Message`, then
// re-wrap its role/content/phase as a `ResponseInputItem::Message` so it can be injected.
let shell_output = match ContextualUserFragment::into(UserShellCommand::new(
"echo hi",
/*exit_code*/ 0,
std::time::Duration::from_secs(1),
"hi",
)) {
ResponseItem::Message {
role,
content,
phase,
..
} => ResponseInputItem::Message {
role,
content,
phase,
},
other => panic!("expected user shell output message, got {other:?}"),
};
// A second injected item of a different kind (custom tool call output).
let code_mode_notification = ResponseInputItem::CustomToolCallOutput {
call_id: "call-1".to_string(),
name: Some("code_mode".to_string()),
output: FunctionCallOutputPayload::from_text("cell output".to_string()),
};
sess.inject_response_items(vec![shell_output.clone(), code_mode_notification.clone()])
.await
.expect("inject pending context into active turn");
// Flag the turn as usage-limited before finishing it.
sess.mark_usage_limit_reached(&tc.sub_id).await;
sess.on_task_finished(Arc::clone(&tc), /*last_agent_message*/ None)
.await;
// Both injected items must be present in the recorded history.
let history = sess.clone_history().await;
assert!(
history
.raw_items()
.contains(&ResponseItem::from(shell_output))
);
assert!(
history
.raw_items()
.contains(&ResponseItem::from(code_mode_notification))
);
}
#[tokio::test]
async fn steer_input_requires_active_turn() {
let (sess, _tc, _rx) = make_session_and_context_with_rx().await;
@@ -7565,8 +7693,15 @@ async fn prepend_pending_input_keeps_older_tail_ahead_of_newer_input() {
.await
.expect("inject initial pending input into active turn");
let drained = sess.get_pending_input().await;
assert_eq!(drained, vec![blocked, later.clone()]);
let drained = sess.get_pending_input_items().await;
assert_eq!(
drained
.iter()
.cloned()
.map(PendingInputItem::into_response_input_item)
.collect::<Vec<_>>(),
vec![blocked, later.clone()]
);
sess.inject_response_items(vec![newer.clone()])
.await
@@ -7574,7 +7709,7 @@ async fn prepend_pending_input_keeps_older_tail_ahead_of_newer_input() {
let mut drained_iter = drained.into_iter();
let _blocked = drained_iter.next().expect("blocked prompt should exist");
sess.prepend_pending_input(drained_iter.collect())
sess.prepend_pending_input_items(drained_iter.collect())
.await
.expect("requeue later pending input at the front of the queue");
@@ -7652,7 +7787,13 @@ async fn abort_empty_active_turn_preserves_pending_input() {
assert!(sess.active_turn.lock().await.is_none());
assert_eq!(
turn_state.lock().await.take_pending_input(),
turn_state
.lock()
.await
.take_pending_input_items()
.into_iter()
.map(PendingInputItem::into_response_input_item)
.collect::<Vec<_>>(),
vec![pending_item]
);
}

View File

@@ -390,7 +390,7 @@ pub(crate) async fn run_turn(
// submitted through the UI while the model was running. Though the UI
// may support this, the model might not.
let pending_input = if can_drain_pending_input {
sess.get_pending_input().await
sess.get_pending_input_items().await
} else {
Vec::new()
};
@@ -402,7 +402,13 @@ pub(crate) async fn run_turn(
if !pending_input.is_empty() {
let mut pending_input_iter = pending_input.into_iter();
while let Some(pending_input_item) = pending_input_iter.next() {
match inspect_pending_input(&sess, &turn_context, pending_input_item).await {
match inspect_pending_input(
&sess,
&turn_context,
pending_input_item.into_response_input_item(),
)
.await
{
PendingInputHookDisposition::Accepted(pending_input) => {
accepted_pending_input.push(*pending_input);
}
@@ -411,7 +417,9 @@ pub(crate) async fn run_turn(
} => {
let remaining_pending_input = pending_input_iter.collect::<Vec<_>>();
if !remaining_pending_input.is_empty() {
let _ = sess.prepend_pending_input(remaining_pending_input).await;
let _ = sess
.prepend_pending_input_items(remaining_pending_input)
.await;
requeued_pending_input = true;
}
blocked_pending_input_contexts = additional_contexts;
@@ -1076,12 +1084,18 @@ async fn run_sampling_request(
sess.set_total_tokens_full(&turn_context).await;
return Err(CodexErr::ContextWindowExceeded);
}
Err(CodexErr::UsageLimitReached(e)) => {
let rate_limits = e.rate_limits.clone();
if let Some(rate_limits) = rate_limits {
Err(
err @ (CodexErr::UsageLimitReached(_)
| CodexErr::QuotaExceeded
| CodexErr::UsageNotIncluded),
) => {
if let CodexErr::UsageLimitReached(e) = &err
&& let Some(rate_limits) = e.rate_limits.clone()
{
sess.update_rate_limits(&turn_context, *rate_limits).await;
}
return Err(CodexErr::UsageLimitReached(e));
sess.mark_usage_limit_reached(&turn_context.sub_id).await;
return Err(err);
}
Err(err) => err,
};

View File

@@ -6,6 +6,7 @@ pub(crate) use service::SessionServices;
pub(crate) use session::SessionState;
pub(crate) use turn::ActiveTurn;
pub(crate) use turn::MailboxDeliveryPhase;
pub(crate) use turn::PendingInputItem;
pub(crate) use turn::PendingRequestPermissions;
pub(crate) use turn::RunningTask;
pub(crate) use turn::TaskKind;

View File

@@ -25,6 +25,32 @@ use codex_protocol::models::AdditionalPermissionProfile;
use codex_protocol::protocol::ReviewDecision;
use codex_protocol::protocol::TokenUsage;
/// A queued input item together with a tag recording how it was queued.
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum PendingInputItem {
    /// Input created through [`PendingInputItem::turn_steer`].
    TurnSteer(ResponseInputItem),
    /// Input created through [`PendingInputItem::injected`].
    Injected(ResponseInputItem),
}

impl PendingInputItem {
    /// Tags `input` as a turn steer.
    pub(crate) fn turn_steer(input: ResponseInputItem) -> Self {
        PendingInputItem::TurnSteer(input)
    }

    /// Tags `input` as injected.
    pub(crate) fn injected(input: ResponseInputItem) -> Self {
        PendingInputItem::Injected(input)
    }

    /// Returns `true` only for the `TurnSteer` variant.
    pub(crate) fn is_turn_steer(&self) -> bool {
        // Spelled out per variant so adding a variant forces a decision here.
        match self {
            PendingInputItem::TurnSteer(_) => true,
            PendingInputItem::Injected(_) => false,
        }
    }

    /// Consumes the item and returns the wrapped input, dropping the tag.
    pub(crate) fn into_response_input_item(self) -> ResponseInputItem {
        let (PendingInputItem::TurnSteer(inner) | PendingInputItem::Injected(inner)) = self;
        inner
    }
}
/// Metadata about the currently running turn.
pub(crate) struct ActiveTurn {
pub(crate) tasks: IndexMap<String, RunningTask>,
@@ -113,7 +139,8 @@ pub(crate) struct TurnState {
pending_user_input: HashMap<String, oneshot::Sender<RequestUserInputResponse>>,
pending_elicitations: HashMap<(String, RequestId), oneshot::Sender<ElicitationResponse>>,
pending_dynamic_tools: HashMap<String, oneshot::Sender<DynamicToolResponse>>,
pending_input: Vec<ResponseInputItem>,
pending_input: Vec<PendingInputItem>,
usage_limit_reached: bool,
mailbox_delivery_phase: MailboxDeliveryPhase,
granted_permissions: Option<AdditionalPermissionProfile>,
strict_auto_review_enabled: bool,
@@ -151,6 +178,7 @@ impl TurnState {
self.pending_elicitations.clear();
self.pending_dynamic_tools.clear();
self.pending_input.clear();
self.usage_limit_reached = false;
}
pub(crate) fn insert_pending_request_permissions(
@@ -219,10 +247,14 @@ impl TurnState {
}
pub(crate) fn push_pending_input(&mut self, input: ResponseInputItem) {
self.pending_input.push(input);
self.pending_input.push(PendingInputItem::injected(input));
}
pub(crate) fn prepend_pending_input(&mut self, mut input: Vec<ResponseInputItem>) {
/// Appends `input` to the back of the pending input list, tagged as a turn steer.
pub(crate) fn push_pending_steer_input(&mut self, input: ResponseInputItem) {
self.pending_input.push(PendingInputItem::turn_steer(input));
}
pub(crate) fn prepend_pending_input_items(&mut self, mut input: Vec<PendingInputItem>) {
if input.is_empty() {
return;
}
@@ -231,7 +263,7 @@ impl TurnState {
self.pending_input = input;
}
pub(crate) fn take_pending_input(&mut self) -> Vec<ResponseInputItem> {
pub(crate) fn take_pending_input_items(&mut self) -> Vec<PendingInputItem> {
if self.pending_input.is_empty() {
Vec::with_capacity(0)
} else {
@@ -245,6 +277,14 @@ impl TurnState {
!self.pending_input.is_empty()
}
/// Sets the `usage_limit_reached` flag on this turn state.
pub(crate) fn mark_usage_limit_reached(&mut self) {
self.usage_limit_reached = true;
}
/// Returns the current value of the `usage_limit_reached` flag.
pub(crate) fn usage_limit_reached(&self) -> bool {
self.usage_limit_reached
}
pub(crate) fn accept_mailbox_delivery_for_current_turn(&mut self) {
self.set_mailbox_delivery_phase(MailboxDeliveryPhase::CurrentTurn);
}

View File

@@ -29,6 +29,7 @@ use crate::hook_runtime::record_pending_input;
use crate::session::session::Session;
use crate::session::turn_context::TurnContext;
use crate::state::ActiveTurn;
use crate::state::PendingInputItem;
use crate::state::RunningTask;
use crate::state::TaskKind;
use codex_analytics::TurnTokenUsageFact;
@@ -40,7 +41,6 @@ use codex_otel::TURN_MEMORY_METRIC;
use codex_otel::TURN_NETWORK_PROXY_METRIC;
use codex_otel::TURN_TOKEN_USAGE_METRIC;
use codex_otel::TURN_TOOL_CALL_METRIC;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::RolloutItem;
@@ -562,12 +562,13 @@ impl Session {
.turn_metadata_state
.cancel_git_enrichment_task();
let mut pending_input = Vec::<ResponseInputItem>::new();
let mut pending_input = Vec::<PendingInputItem>::new();
let mut should_clear_active_turn = false;
let mut token_usage_at_turn_start = None;
let mut turn_had_memory_citation = false;
let mut turn_tool_calls = 0_u64;
let mut records_turn_token_usage_on_span = false;
let mut usage_limit_reached = false;
let turn_state = {
let mut active = self.active_turn.lock().await;
if let Some(at) = active.as_mut()
@@ -587,14 +588,24 @@ impl Session {
};
if let Some(turn_state) = turn_state.as_ref() {
let mut ts = turn_state.lock().await;
pending_input = ts.take_pending_input();
pending_input = ts.take_pending_input_items();
turn_had_memory_citation = ts.has_memory_citation;
turn_tool_calls = ts.tool_calls;
token_usage_at_turn_start = Some(ts.token_usage_at_turn_start.clone());
usage_limit_reached = ts.usage_limit_reached();
}
if !pending_input.is_empty() {
for pending_input_item in pending_input {
match inspect_pending_input(self, &turn_context, pending_input_item).await {
if usage_limit_reached && pending_input_item.is_turn_steer() {
continue;
}
match inspect_pending_input(
self,
&turn_context,
pending_input_item.into_response_input_item(),
)
.await
{
PendingInputHookDisposition::Accepted(pending_input) => {
record_pending_input(self, &turn_context, *pending_input).await;
}

View File

@@ -78,6 +78,9 @@ impl SessionTask for RegularTask {
)
.instrument(run_turn_span.clone())
.await;
if sess.usage_limit_reached_for_active_turn().await {
return last_agent_message;
}
if !sess.has_pending_input().await {
return last_agent_message;
}

View File

@@ -2694,6 +2694,79 @@ async fn usage_limit_error_emits_rate_limit_event() -> anyhow::Result<()> {
Ok(())
}
/// Checks that when the first request fails with a 429 `usage_limit_reached` response,
/// a second user input submitted while that request was outstanding does not cause
/// another request to be sent: the mock server must have seen exactly one request.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn usage_limit_error_does_not_auto_send_pending_steer() -> anyhow::Result<()> {
// NOTE(review): presumably returns `Ok(())` early when network access is unavailable — confirm.
skip_if_no_network!(Ok(()));
let server = MockServer::start().await;
// 429 response carrying rate-limit headers and a `usage_limit_reached` error body.
// The 100 ms delay leaves a window in which the second input can be submitted.
let response = ResponseTemplate::new(429)
.insert_header("x-codex-primary-used-percent", "100.0")
.insert_header("x-codex-secondary-used-percent", "87.5")
.insert_header("x-codex-primary-over-secondary-limit-percent", "95.0")
.insert_header("x-codex-primary-window-minutes", "15")
.insert_header("x-codex-secondary-window-minutes", "60")
.set_delay(std::time::Duration::from_millis(100))
.set_body_json(json!({
"error": {
"type": "usage_limit_reached",
"message": "limit reached",
"resets_at": 1704067242,
"plan_type": "pro"
}
}));
// The mock itself also expects exactly one POST to /v1/responses.
Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(response)
.expect(1)
.mount(&server)
.await;
let mut builder = test_codex();
let codex_fixture = builder.build(&server).await?;
let codex = codex_fixture.codex.clone();
// First input starts the turn.
codex
.submit(Op::UserInput {
environments: None,
items: vec![UserInput::Text {
text: "hello".into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
responsesapi_client_metadata: None,
})
.await?;
wait_for_event(&codex, |msg| matches!(msg, EventMsg::TurnStarted(_))).await;
// Second input is submitted after the turn has started.
// NOTE(review): presumably this is queued as a steer on the active turn — confirm.
codex
.submit(Op::UserInput {
environments: None,
items: vec![UserInput::Text {
text: "steer while blocked".into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
responsesapi_client_metadata: None,
})
.await?;
wait_for_event(&codex, |msg| matches!(msg, EventMsg::Error(_))).await;
wait_for_event(&codex, |msg| matches!(msg, EventMsg::TurnComplete(_))).await;
// Give any unwanted follow-up request time to reach the server before counting.
tokio::time::sleep(std::time::Duration::from_millis(150)).await;
assert_eq!(
server
.received_requests()
.await
.expect("mock server should not fail")
.len(),
1
);
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn context_window_error_sets_total_tokens_to_model_window() -> anyhow::Result<()> {
skip_if_no_network!(Ok(()));

View File

@@ -3,8 +3,10 @@ use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::Op;
use codex_protocol::user_input::UserInput;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::mount_response_once;
use core_test_support::responses::mount_sse_once;
use core_test_support::responses::sse;
use core_test_support::responses::sse_response;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::test_codex;
@@ -74,3 +76,86 @@ async fn quota_exceeded_emits_single_error_event() -> Result<()> {
Ok(())
}
/// Runs the shared pending-steer check with an `insufficient_quota` failure.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn quota_exceeded_does_not_auto_send_pending_steer() -> Result<()> {
assert_usage_limit_like_failure_does_not_auto_send_pending_steer(
"insufficient_quota",
"You exceeded your current quota.",
)
.await
}
/// Runs the shared pending-steer check with a `usage_not_included` failure.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn usage_not_included_does_not_auto_send_pending_steer() -> Result<()> {
assert_usage_limit_like_failure_does_not_auto_send_pending_steer(
"usage_not_included",
"Usage is not included with this plan.",
)
.await
}
/// Shared body for the pending-steer tests in this file.
///
/// Mounts a single SSE response whose `response.failed` event carries the given error
/// `code` and `message`, submits one user input, submits a second one after the turn
/// has started, and asserts that the mock server received exactly one request.
async fn assert_usage_limit_like_failure_does_not_auto_send_pending_steer(
code: &str,
message: &str,
) -> Result<()> {
// NOTE(review): presumably returns `Ok(())` early when network access is unavailable — confirm.
skip_if_no_network!(Ok(()));
let server = start_mock_server().await;
let mut builder = test_codex();
// Stream: `response.created` followed by `response.failed` with the supplied error.
// The 100 ms delay leaves a window in which the second input can be submitted.
mount_response_once(
&server,
sse_response(sse(vec![
ev_response_created("resp-1"),
json!({
"type": "response.failed",
"response": {
"id": "resp-1",
"error": {
"code": code,
"message": message,
}
}
}),
]))
.set_delay(std::time::Duration::from_millis(100)),
)
.await;
let test = builder.build(&server).await?;
// First input starts the turn.
test.codex
.submit(Op::UserInput {
environments: None,
items: vec![UserInput::Text {
text: "hello".into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
responsesapi_client_metadata: None,
})
.await?;
wait_for_event(&test.codex, |msg| matches!(msg, EventMsg::TurnStarted(_))).await;
// Second input is submitted after the turn has started.
// NOTE(review): presumably this is queued as a steer on the active turn — confirm.
test.codex
.submit(Op::UserInput {
environments: None,
items: vec![UserInput::Text {
text: "steer while blocked".into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
responsesapi_client_metadata: None,
})
.await?;
wait_for_event(&test.codex, |msg| matches!(msg, EventMsg::Error(_))).await;
wait_for_event(&test.codex, |msg| matches!(msg, EventMsg::TurnComplete(_))).await;
// Give any unwanted follow-up request time to reach the server before counting.
tokio::time::sleep(std::time::Duration::from_millis(150)).await;
// A missing request log counts as zero requests here, which also fails the assertion.
assert_eq!(
server.received_requests().await.unwrap_or_default().len(),
1
);
Ok(())
}