mirror of
https://github.com/openai/codex.git
synced 2026-03-19 20:36:30 +03:00
Compare commits
7 Commits
starr/exec
...
fix/rollba
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a0d4c01ad1 | ||
|
|
9a937330fa | ||
|
|
60e7e634bb | ||
|
|
429551dad2 | ||
|
|
2dc1a99b12 | ||
|
|
8bb3043089 | ||
|
|
534680fd05 |
@@ -131,6 +131,7 @@ use crate::config::types::McpServerConfig;
|
||||
use crate::config::types::ShellEnvironmentPolicy;
|
||||
use crate::context_manager::ContextManager;
|
||||
use crate::context_manager::TotalTokenUsageBreakdown;
|
||||
use crate::context_manager::UserTurnBaselineFrame;
|
||||
use crate::environment_context::EnvironmentContext;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result as CodexResult;
|
||||
@@ -809,6 +810,13 @@ pub(crate) struct SessionSettingsUpdate {
|
||||
pub(crate) personality: Option<Personality>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ReconstructedRolloutHistory {
|
||||
items: Vec<ResponseItem>,
|
||||
user_turn_baselines: Vec<UserTurnBaselineFrame>,
|
||||
previous_model: Option<String>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
/// Builds the `x-codex-beta-features` header value for this session.
|
||||
///
|
||||
@@ -1610,23 +1618,22 @@ impl Session {
|
||||
let rollout_items = resumed_history.history;
|
||||
let restored_tool_selection =
|
||||
Self::extract_mcp_tool_selection_from_rollout(&rollout_items);
|
||||
let (previous_regular_turn_context_item, crossed_compaction_after_turn) =
|
||||
Self::last_rollout_regular_turn_context_lookup(&rollout_items);
|
||||
let previous_model =
|
||||
previous_regular_turn_context_item.map(|ctx| ctx.model.clone());
|
||||
let reconstructed = self
|
||||
.reconstruct_history_from_rollout(&turn_context, &rollout_items)
|
||||
.await;
|
||||
let previous_model = reconstructed.previous_model.clone();
|
||||
let reference_context_item = reconstructed
|
||||
.user_turn_baselines
|
||||
.last()
|
||||
.filter(|frame| !frame.invalidated_by_following_compaction)
|
||||
.map(|frame| frame.turn_context_item.clone());
|
||||
let curr = turn_context.model_info.slug.as_str();
|
||||
let reference_context_item = if !crossed_compaction_after_turn {
|
||||
previous_regular_turn_context_item.cloned()
|
||||
} else {
|
||||
// Keep the baseline empty when compaction may have stripped the referenced
|
||||
// context diffs so the first resumed regular turn fully reinjects context.
|
||||
None
|
||||
};
|
||||
{
|
||||
let mut state = self.state.lock().await;
|
||||
state.set_user_turn_baselines(reconstructed.user_turn_baselines.clone());
|
||||
state.set_reference_context_item(reference_context_item);
|
||||
state.set_previous_model(previous_model.clone());
|
||||
}
|
||||
self.set_previous_model(previous_model.clone()).await;
|
||||
|
||||
// If resuming, warn when the last recorded model differs from the current one.
|
||||
if let Some(prev) = previous_model.as_deref().filter(|p| *p != curr) {
|
||||
@@ -1644,11 +1651,8 @@ impl Session {
|
||||
}
|
||||
|
||||
// Always add response items to conversation history
|
||||
let reconstructed_history = self
|
||||
.reconstruct_history_from_rollout(&turn_context, &rollout_items)
|
||||
.await;
|
||||
if !reconstructed_history.is_empty() {
|
||||
self.record_into_history(&reconstructed_history, &turn_context)
|
||||
if !reconstructed.items.is_empty() {
|
||||
self.record_into_history(&reconstructed.items, &turn_context)
|
||||
.await;
|
||||
}
|
||||
|
||||
@@ -1669,18 +1673,19 @@ impl Session {
|
||||
InitialHistory::Forked(rollout_items) => {
|
||||
let restored_tool_selection =
|
||||
Self::extract_mcp_tool_selection_from_rollout(&rollout_items);
|
||||
let (previous_regular_turn_context_item, _) =
|
||||
Self::last_rollout_regular_turn_context_lookup(&rollout_items);
|
||||
let previous_model =
|
||||
previous_regular_turn_context_item.map(|ctx| ctx.model.clone());
|
||||
self.set_previous_model(previous_model).await;
|
||||
|
||||
// Always add response items to conversation history
|
||||
let reconstructed_history = self
|
||||
let reconstructed = self
|
||||
.reconstruct_history_from_rollout(&turn_context, &rollout_items)
|
||||
.await;
|
||||
if !reconstructed_history.is_empty() {
|
||||
self.record_into_history(&reconstructed_history, &turn_context)
|
||||
{
|
||||
let mut state = self.state.lock().await;
|
||||
state.set_user_turn_baselines(reconstructed.user_turn_baselines.clone());
|
||||
state.set_reference_context_item(None);
|
||||
state.set_previous_model(reconstructed.previous_model.clone());
|
||||
}
|
||||
|
||||
// Always add response items to conversation history
|
||||
if !reconstructed.items.is_empty() {
|
||||
self.record_into_history(&reconstructed.items, &turn_context)
|
||||
.await;
|
||||
}
|
||||
|
||||
@@ -1717,150 +1722,6 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `(last_turn_context_item, crossed_compaction_after_turn)` from the
|
||||
/// rollback-adjusted rollout view.
|
||||
///
|
||||
/// This relies on the invariant that only regular turns persist `TurnContextItem`.
|
||||
/// `ThreadRolledBack` markers are applied so resume/fork uses the post-rollback history view.
|
||||
///
|
||||
/// Returns `(None, false)` when no persisted `TurnContextItem` can be found.
|
||||
///
|
||||
/// Older/minimal rollouts may only contain `RolloutItem::TurnContext` entries without turn
|
||||
/// lifecycle events. In that case we fall back to the last `TurnContextItem` (plus whether a
|
||||
/// later `Compacted` item appears in rollout order).
|
||||
// TODO(ccunningham): Simplify this lookup by sharing rollout traversal/rollback application
|
||||
// with `reconstruct_history_from_rollout` so resume/fork baseline hydration does not need a
|
||||
// second bespoke rollout scan.
|
||||
fn last_rollout_regular_turn_context_lookup(
|
||||
rollout_items: &[RolloutItem],
|
||||
) -> (Option<&TurnContextItem>, bool) {
|
||||
// Reverse scan over rollout items. `ThreadRolledBack(num_turns)` is naturally handled by
|
||||
// skipping the next `num_turns` completed turn spans we encounter while walking backward.
|
||||
//
|
||||
// "Active turn" here means: we have seen `TurnComplete`/`TurnAborted` and are currently
|
||||
// scanning backward through that completed turn until its matching `TurnStarted`.
|
||||
let mut turns_to_skip_due_to_rollback = 0usize;
|
||||
let mut saw_surviving_compaction_after_candidate = false;
|
||||
let mut saw_turn_lifecycle_event = false;
|
||||
let mut active_turn_id: Option<&str> = None;
|
||||
let mut active_turn_saw_user_message = false;
|
||||
let mut active_turn_context: Option<&TurnContextItem> = None;
|
||||
let mut active_turn_contains_compaction = false;
|
||||
|
||||
for item in rollout_items.iter().rev() {
|
||||
match item {
|
||||
RolloutItem::EventMsg(EventMsg::ThreadRolledBack(rollback)) => {
|
||||
// Rollbacks count completed turns, not `TurnContextItem`s. We must continue
|
||||
// ignoring all items inside each skipped turn until we reach its
|
||||
// corresponding `TurnStarted`.
|
||||
let num_turns = usize::try_from(rollback.num_turns).unwrap_or(usize::MAX);
|
||||
turns_to_skip_due_to_rollback =
|
||||
turns_to_skip_due_to_rollback.saturating_add(num_turns);
|
||||
}
|
||||
RolloutItem::EventMsg(EventMsg::TurnComplete(event)) => {
|
||||
saw_turn_lifecycle_event = true;
|
||||
// Enter the reverse "turn span" for this completed turn.
|
||||
active_turn_id = Some(event.turn_id.as_str());
|
||||
active_turn_saw_user_message = false;
|
||||
active_turn_context = None;
|
||||
active_turn_contains_compaction = false;
|
||||
}
|
||||
RolloutItem::EventMsg(EventMsg::TurnAborted(event)) => {
|
||||
saw_turn_lifecycle_event = true;
|
||||
// Same reverse-turn handling as `TurnComplete`. Some aborted turns may not
|
||||
// have a turn id; in that case we cannot match `TurnContextItem`s to them.
|
||||
active_turn_id = event.turn_id.as_deref();
|
||||
active_turn_saw_user_message = false;
|
||||
active_turn_context = None;
|
||||
active_turn_contains_compaction = false;
|
||||
}
|
||||
RolloutItem::EventMsg(EventMsg::UserMessage(_)) => {
|
||||
if active_turn_id.is_some() {
|
||||
active_turn_saw_user_message = true;
|
||||
}
|
||||
}
|
||||
RolloutItem::EventMsg(EventMsg::TurnStarted(event)) => {
|
||||
saw_turn_lifecycle_event = true;
|
||||
if active_turn_id == Some(event.turn_id.as_str()) {
|
||||
let active_turn_is_rolled_back =
|
||||
active_turn_saw_user_message && turns_to_skip_due_to_rollback > 0;
|
||||
if active_turn_is_rolled_back {
|
||||
// `ThreadRolledBack(num_turns)` counts user turns, so only consume a
|
||||
// skip once we've confirmed this reverse-scanned turn span contains a
|
||||
// user message. Standalone task turns must not consume rollback skips.
|
||||
turns_to_skip_due_to_rollback -= 1;
|
||||
}
|
||||
if !active_turn_is_rolled_back {
|
||||
if let Some(context_item) = active_turn_context {
|
||||
return (
|
||||
Some(context_item),
|
||||
saw_surviving_compaction_after_candidate,
|
||||
);
|
||||
}
|
||||
// No `TurnContextItem` in this surviving turn; keep scanning older
|
||||
// turns, but remember if this turn compacted so the eventual
|
||||
// candidate reports "compaction happened after it".
|
||||
if active_turn_contains_compaction {
|
||||
saw_surviving_compaction_after_candidate = true;
|
||||
}
|
||||
}
|
||||
active_turn_id = None;
|
||||
active_turn_saw_user_message = false;
|
||||
active_turn_context = None;
|
||||
active_turn_contains_compaction = false;
|
||||
}
|
||||
}
|
||||
RolloutItem::TurnContext(ctx) => {
|
||||
// Capture the latest turn context seen in this reverse-scanned turn span. If
|
||||
// the turn later proves to be rolled back, we discard it when we hit the
|
||||
// matching `TurnStarted`. Older rollouts may have lifecycle events but omit
|
||||
// `TurnContextItem.turn_id`; accept those as belonging to the active turn
|
||||
// span for resume/fork hydration.
|
||||
if let Some(active_id) = active_turn_id
|
||||
&& ctx
|
||||
.turn_id
|
||||
.as_deref()
|
||||
.is_none_or(|turn_id| turn_id == active_id)
|
||||
{
|
||||
// Reverse scan sees the latest `TurnContextItem` for the turn first.
|
||||
active_turn_context.get_or_insert(ctx);
|
||||
}
|
||||
}
|
||||
RolloutItem::Compacted(_) => {
|
||||
if active_turn_id.is_some() {
|
||||
// Compaction inside the currently scanned turn is only "after" the
|
||||
// eventual candidate if this turn has no `TurnContextItem` and we keep
|
||||
// scanning into older turns.
|
||||
active_turn_contains_compaction = true;
|
||||
} else {
|
||||
saw_surviving_compaction_after_candidate = true;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Legacy/minimal rollouts may only persist `TurnContextItem`/`Compacted` without turn
|
||||
// lifecycle events. Fall back to the last `TurnContextItem` in rollout order so
|
||||
// resume/fork can still hydrate `previous_model` and detect compaction-after-baseline.
|
||||
if !saw_turn_lifecycle_event {
|
||||
let mut saw_compaction_after_last_turn_context = false;
|
||||
for item in rollout_items.iter().rev() {
|
||||
match item {
|
||||
RolloutItem::Compacted(_) => {
|
||||
saw_compaction_after_last_turn_context = true;
|
||||
}
|
||||
RolloutItem::TurnContext(ctx) => {
|
||||
return (Some(ctx), saw_compaction_after_last_turn_context);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(None, false)
|
||||
}
|
||||
|
||||
fn last_token_info_from_rollout(rollout_items: &[RolloutItem]) -> Option<TokenUsageInfo> {
|
||||
rollout_items.iter().rev().find_map(|item| match item {
|
||||
RolloutItem::EventMsg(EventMsg::TokenCount(ev)) => ev.info.clone(),
|
||||
@@ -2565,8 +2426,12 @@ impl Session {
|
||||
&self,
|
||||
turn_context: &TurnContext,
|
||||
rollout_items: &[RolloutItem],
|
||||
) -> Vec<ResponseItem> {
|
||||
) -> ReconstructedRolloutHistory {
|
||||
let mut history = ContextManager::new();
|
||||
let mut saw_turn_lifecycle_event = false;
|
||||
let mut active_turn_id: Option<&str> = None;
|
||||
let mut active_turn_saw_user_message = false;
|
||||
let mut active_turn_context: Option<TurnContextItem> = None;
|
||||
for item in rollout_items {
|
||||
match item {
|
||||
RolloutItem::ResponseItem(response_item) => {
|
||||
@@ -2587,14 +2452,96 @@ impl Session {
|
||||
);
|
||||
history.replace(rebuilt);
|
||||
}
|
||||
|
||||
let compaction_is_after_latest_user_turn =
|
||||
active_turn_id.is_none() || !active_turn_saw_user_message;
|
||||
if compaction_is_after_latest_user_turn {
|
||||
history.invalidate_top_user_turn_baseline();
|
||||
}
|
||||
}
|
||||
RolloutItem::EventMsg(EventMsg::ThreadRolledBack(rollback)) => {
|
||||
let baseline_len_before = history.user_turn_baselines_len();
|
||||
history.drop_last_n_user_turns(rollback.num_turns);
|
||||
if history.user_turn_baselines_len() == baseline_len_before {
|
||||
history.truncate_user_turn_baselines_from_end(rollback.num_turns);
|
||||
}
|
||||
}
|
||||
RolloutItem::EventMsg(EventMsg::TurnStarted(event)) => {
|
||||
saw_turn_lifecycle_event = true;
|
||||
active_turn_id = Some(event.turn_id.as_str());
|
||||
active_turn_saw_user_message = false;
|
||||
active_turn_context = None;
|
||||
}
|
||||
RolloutItem::EventMsg(EventMsg::UserMessage(_)) => {
|
||||
if active_turn_id.is_some() {
|
||||
active_turn_saw_user_message = true;
|
||||
}
|
||||
}
|
||||
RolloutItem::TurnContext(ctx) => {
|
||||
if let Some(active_id) = active_turn_id
|
||||
&& ctx
|
||||
.turn_id
|
||||
.as_deref()
|
||||
.is_none_or(|turn_id| turn_id == active_id)
|
||||
{
|
||||
active_turn_context = Some(ctx.clone());
|
||||
}
|
||||
}
|
||||
RolloutItem::EventMsg(EventMsg::TurnComplete(event)) => {
|
||||
saw_turn_lifecycle_event = true;
|
||||
if active_turn_id == Some(event.turn_id.as_str()) {
|
||||
if active_turn_saw_user_message
|
||||
&& let Some(ctx) = active_turn_context.take()
|
||||
{
|
||||
history.record_regular_turn_baseline(ctx);
|
||||
}
|
||||
active_turn_id = None;
|
||||
active_turn_saw_user_message = false;
|
||||
active_turn_context = None;
|
||||
}
|
||||
}
|
||||
RolloutItem::EventMsg(EventMsg::TurnAborted(event)) => {
|
||||
saw_turn_lifecycle_event = true;
|
||||
if active_turn_id == event.turn_id.as_deref() {
|
||||
if active_turn_saw_user_message
|
||||
&& let Some(ctx) = active_turn_context.take()
|
||||
{
|
||||
history.record_regular_turn_baseline(ctx);
|
||||
}
|
||||
active_turn_id = None;
|
||||
active_turn_saw_user_message = false;
|
||||
active_turn_context = None;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
history.raw_items().to_vec()
|
||||
|
||||
if !saw_turn_lifecycle_event {
|
||||
let mut saw_compaction_after_last_turn_context = false;
|
||||
for item in rollout_items.iter().rev() {
|
||||
match item {
|
||||
RolloutItem::Compacted(_) => {
|
||||
saw_compaction_after_last_turn_context = true;
|
||||
}
|
||||
RolloutItem::TurnContext(ctx) => {
|
||||
history.set_user_turn_baselines(vec![UserTurnBaselineFrame {
|
||||
turn_context_item: ctx.clone(),
|
||||
invalidated_by_following_compaction:
|
||||
saw_compaction_after_last_turn_context,
|
||||
}]);
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ReconstructedRolloutHistory {
|
||||
items: history.raw_items().to_vec(),
|
||||
user_turn_baselines: history.user_turn_baselines(),
|
||||
previous_model: history.latest_user_turn_model(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Append ResponseItems to the in-memory conversation history only.
|
||||
@@ -2859,7 +2806,7 @@ impl Session {
|
||||
}
|
||||
|
||||
let mut state = self.state.lock().await;
|
||||
state.set_reference_context_item(Some(turn_context.to_turn_context_item()));
|
||||
state.record_regular_turn_baseline(turn_context.to_turn_context_item());
|
||||
}
|
||||
|
||||
pub(crate) async fn update_token_usage_info(
|
||||
@@ -4103,19 +4050,10 @@ mod handlers {
|
||||
|
||||
let turn_context = sess.new_default_turn_with_sub_id(sub_id).await;
|
||||
|
||||
let mut history = sess.clone_history().await;
|
||||
// TODO(ccunningham): Fix rollback/backtracking baseline handling.
|
||||
// We clear `reference_context_item` here, but should restore the
|
||||
// post-rollback baseline from the surviving history/rollout instead.
|
||||
// Truncating history should also invalidate/recompute `previous_model`
|
||||
// so the next regular turn replays any dropped model-switch
|
||||
// instructions.
|
||||
history.drop_last_n_user_turns(num_turns);
|
||||
|
||||
// Replace with the raw items. We don't want to replace with a normalized
|
||||
// version of the history.
|
||||
sess.replace_history(history.raw_items().to_vec(), None)
|
||||
.await;
|
||||
{
|
||||
let mut state = sess.state.lock().await;
|
||||
state.rollback_user_turns(num_turns);
|
||||
}
|
||||
sess.recompute_token_usage(turn_context.as_ref()).await;
|
||||
|
||||
sess.send_event_raw_flushed(Event {
|
||||
@@ -6589,7 +6527,7 @@ mod tests {
|
||||
.reconstruct_history_from_rollout(reconstruction_turn.as_ref(), &rollout_items)
|
||||
.await;
|
||||
|
||||
assert_eq!(expected, reconstructed);
|
||||
assert_eq!(expected, reconstructed.items);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -6625,7 +6563,7 @@ mod tests {
|
||||
.reconstruct_history_from_rollout(&turn_context, &rollout_items)
|
||||
.await;
|
||||
|
||||
assert_eq!(reconstructed, replacement_history);
|
||||
assert_eq!(reconstructed.items, replacement_history);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -6747,6 +6685,88 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn record_initial_history_resumed_rollback_with_missing_turn_context_ids_keeps_distinct_baselines()
|
||||
{
|
||||
let (session, turn_context) = make_session_and_context().await;
|
||||
let first_turn_id = "turn-1".to_string();
|
||||
let second_turn_id = "turn-2".to_string();
|
||||
|
||||
let mut first_context_item = turn_context.to_turn_context_item();
|
||||
first_context_item.turn_id = None;
|
||||
first_context_item.model = "model-a".to_string();
|
||||
|
||||
let mut second_context_item = turn_context.to_turn_context_item();
|
||||
second_context_item.turn_id = None;
|
||||
second_context_item.model = "model-b".to_string();
|
||||
|
||||
let rollout_items = vec![
|
||||
RolloutItem::EventMsg(EventMsg::TurnStarted(
|
||||
codex_protocol::protocol::TurnStartedEvent {
|
||||
turn_id: first_turn_id.clone(),
|
||||
model_context_window: Some(128_000),
|
||||
collaboration_mode_kind: ModeKind::Default,
|
||||
},
|
||||
)),
|
||||
RolloutItem::EventMsg(EventMsg::UserMessage(
|
||||
codex_protocol::protocol::UserMessageEvent {
|
||||
message: "first".to_string(),
|
||||
images: None,
|
||||
local_images: Vec::new(),
|
||||
text_elements: Vec::new(),
|
||||
},
|
||||
)),
|
||||
RolloutItem::TurnContext(first_context_item.clone()),
|
||||
RolloutItem::EventMsg(EventMsg::TurnComplete(
|
||||
codex_protocol::protocol::TurnCompleteEvent {
|
||||
turn_id: first_turn_id,
|
||||
last_agent_message: None,
|
||||
},
|
||||
)),
|
||||
RolloutItem::EventMsg(EventMsg::TurnStarted(
|
||||
codex_protocol::protocol::TurnStartedEvent {
|
||||
turn_id: second_turn_id.clone(),
|
||||
model_context_window: Some(128_000),
|
||||
collaboration_mode_kind: ModeKind::Default,
|
||||
},
|
||||
)),
|
||||
RolloutItem::EventMsg(EventMsg::UserMessage(
|
||||
codex_protocol::protocol::UserMessageEvent {
|
||||
message: "second".to_string(),
|
||||
images: None,
|
||||
local_images: Vec::new(),
|
||||
text_elements: Vec::new(),
|
||||
},
|
||||
)),
|
||||
RolloutItem::TurnContext(second_context_item),
|
||||
RolloutItem::EventMsg(EventMsg::TurnComplete(
|
||||
codex_protocol::protocol::TurnCompleteEvent {
|
||||
turn_id: second_turn_id,
|
||||
last_agent_message: None,
|
||||
},
|
||||
)),
|
||||
RolloutItem::EventMsg(EventMsg::ThreadRolledBack(
|
||||
codex_protocol::protocol::ThreadRolledBackEvent { num_turns: 1 },
|
||||
)),
|
||||
];
|
||||
|
||||
session
|
||||
.record_initial_history(InitialHistory::Resumed(ResumedHistory {
|
||||
conversation_id: ThreadId::default(),
|
||||
history: rollout_items,
|
||||
rollout_path: PathBuf::from("/tmp/resume.jsonl"),
|
||||
}))
|
||||
.await;
|
||||
|
||||
assert_eq!(session.previous_model().await, Some("model-a".to_string()));
|
||||
assert_eq!(
|
||||
serde_json::to_value(session.reference_context_item().await)
|
||||
.expect("serialize seeded reference context item"),
|
||||
serde_json::to_value(Some(first_context_item))
|
||||
.expect("serialize expected reference context item")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn record_initial_history_resumed_rollback_skips_only_user_turns() {
|
||||
let (session, turn_context) = make_session_and_context().await;
|
||||
@@ -6756,6 +6776,9 @@ mod tests {
|
||||
.clone()
|
||||
.expect("turn context should have turn_id");
|
||||
let standalone_turn_id = "standalone-task-turn".to_string();
|
||||
let mut standalone_turn_context_item = turn_context.to_turn_context_item();
|
||||
standalone_turn_context_item.turn_id = Some(standalone_turn_id.clone());
|
||||
standalone_turn_context_item.model = "standalone-task-model".to_string();
|
||||
let rollout_items = vec![
|
||||
RolloutItem::EventMsg(EventMsg::TurnStarted(
|
||||
codex_protocol::protocol::TurnStartedEvent {
|
||||
@@ -6787,6 +6810,9 @@ mod tests {
|
||||
collaboration_mode_kind: ModeKind::Default,
|
||||
},
|
||||
)),
|
||||
// Older rollouts may contain a task-turn TurnContext; rollback semantics still count
|
||||
// user turns only, and baseline hydration must ignore this standalone task context.
|
||||
RolloutItem::TurnContext(standalone_turn_context_item),
|
||||
RolloutItem::EventMsg(EventMsg::TurnComplete(
|
||||
codex_protocol::protocol::TurnCompleteEvent {
|
||||
turn_id: standalone_turn_id,
|
||||
@@ -7141,9 +7167,102 @@ mod tests {
|
||||
|
||||
let history = sess.clone_history().await;
|
||||
assert_eq!(expected, history.raw_items());
|
||||
assert_eq!(sess.previous_model().await, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn thread_rollback_rehydrates_baseline_from_in_memory_user_turn_stack() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx().await;
|
||||
|
||||
let turn_1 = vec![user_message("turn 1 user")];
|
||||
let turn_2 = vec![user_message("turn 2 user")];
|
||||
sess.record_into_history(&turn_1, tc.as_ref()).await;
|
||||
sess.record_into_history(&turn_2, tc.as_ref()).await;
|
||||
|
||||
let mut turn_1_context = tc.to_turn_context_item();
|
||||
turn_1_context.turn_id = Some("turn-1".to_string());
|
||||
turn_1_context.model = "model-a".to_string();
|
||||
let mut turn_2_context = tc.to_turn_context_item();
|
||||
turn_2_context.turn_id = Some("turn-2".to_string());
|
||||
turn_2_context.model = "model-b".to_string();
|
||||
|
||||
sess.set_previous_model(Some("model-b".to_string())).await;
|
||||
{
|
||||
let mut state = sess.state.lock().await;
|
||||
state.set_user_turn_baselines(vec![
|
||||
UserTurnBaselineFrame {
|
||||
turn_context_item: turn_1_context.clone(),
|
||||
invalidated_by_following_compaction: false,
|
||||
},
|
||||
UserTurnBaselineFrame {
|
||||
turn_context_item: turn_2_context,
|
||||
invalidated_by_following_compaction: false,
|
||||
},
|
||||
]);
|
||||
state.set_reference_context_item(Some(tc.to_turn_context_item()));
|
||||
}
|
||||
|
||||
handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await;
|
||||
|
||||
let rollback_event = wait_for_thread_rolled_back(&rx).await;
|
||||
assert_eq!(rollback_event.num_turns, 1);
|
||||
assert_eq!(sess.previous_model().await, Some("model-a".to_string()));
|
||||
assert_eq!(
|
||||
sess.previous_model().await,
|
||||
Some("previous-regular-model".to_string())
|
||||
serde_json::to_value(sess.reference_context_item().await)
|
||||
.expect("serialize post-rollback reference context item"),
|
||||
serde_json::to_value(Some(turn_1_context))
|
||||
.expect("serialize expected post-rollback reference context item")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn thread_rollback_trims_baselines_when_history_has_no_user_messages() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx().await;
|
||||
|
||||
let assistant_only_history = vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: "compacted summary".to_string(),
|
||||
}],
|
||||
end_turn: None,
|
||||
phase: None,
|
||||
}];
|
||||
sess.record_into_history(&assistant_only_history, tc.as_ref())
|
||||
.await;
|
||||
|
||||
let mut turn_1_context = tc.to_turn_context_item();
|
||||
turn_1_context.turn_id = Some("turn-1".to_string());
|
||||
turn_1_context.model = "model-a".to_string();
|
||||
let mut turn_2_context = tc.to_turn_context_item();
|
||||
turn_2_context.turn_id = Some("turn-2".to_string());
|
||||
turn_2_context.model = "model-b".to_string();
|
||||
|
||||
sess.set_previous_model(Some("model-b".to_string())).await;
|
||||
{
|
||||
let mut state = sess.state.lock().await;
|
||||
state.set_user_turn_baselines(vec![
|
||||
UserTurnBaselineFrame {
|
||||
turn_context_item: turn_1_context.clone(),
|
||||
invalidated_by_following_compaction: false,
|
||||
},
|
||||
UserTurnBaselineFrame {
|
||||
turn_context_item: turn_2_context,
|
||||
invalidated_by_following_compaction: false,
|
||||
},
|
||||
]);
|
||||
}
|
||||
|
||||
handlers::thread_rollback(&sess, "sub-1".to_string(), 1).await;
|
||||
|
||||
let rollback_event = wait_for_thread_rolled_back(&rx).await;
|
||||
assert_eq!(rollback_event.num_turns, 1);
|
||||
assert_eq!(sess.previous_model().await, Some("model-a".to_string()));
|
||||
assert_eq!(
|
||||
serde_json::to_value(sess.reference_context_item().await)
|
||||
.expect("serialize post-rollback reference context item"),
|
||||
serde_json::to_value(Some(turn_1_context))
|
||||
.expect("serialize expected post-rollback reference context item")
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -36,6 +36,18 @@ pub(crate) struct ContextManager {
|
||||
/// When this is `None`, settings diffing treats the next turn as having no
|
||||
/// baseline and emits a full reinjection of context state.
|
||||
reference_context_item: Option<TurnContextItem>,
|
||||
/// Rollback-aware stack of regular user-turn context baselines.
|
||||
///
|
||||
/// We keep this adjacent to model-visible history so rollback can trim user
|
||||
/// turns and recompute the effective reference baseline without re-reading
|
||||
/// rollout from disk.
|
||||
user_turn_baselines: Vec<UserTurnBaselineFrame>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct UserTurnBaselineFrame {
|
||||
pub(crate) turn_context_item: TurnContextItem,
|
||||
pub(crate) invalidated_by_following_compaction: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
@@ -52,6 +64,7 @@ impl ContextManager {
|
||||
items: Vec::new(),
|
||||
token_info: TokenUsageInfo::new_or_append(&None, &None, None),
|
||||
reference_context_item: None,
|
||||
user_turn_baselines: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,6 +84,86 @@ impl ContextManager {
|
||||
self.reference_context_item.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn set_user_turn_baselines(&mut self, baselines: Vec<UserTurnBaselineFrame>) {
|
||||
self.user_turn_baselines = baselines;
|
||||
}
|
||||
|
||||
pub(crate) fn user_turn_baselines(&self) -> Vec<UserTurnBaselineFrame> {
|
||||
self.user_turn_baselines.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn user_turn_baselines_len(&self) -> usize {
|
||||
self.user_turn_baselines.len()
|
||||
}
|
||||
|
||||
pub(crate) fn record_regular_turn_baseline(&mut self, item: TurnContextItem) {
|
||||
if let Some(top) = self.user_turn_baselines.last_mut()
|
||||
&& let (Some(top_turn_id), Some(item_turn_id)) = (
|
||||
top.turn_context_item.turn_id.as_deref(),
|
||||
item.turn_id.as_deref(),
|
||||
)
|
||||
&& top_turn_id == item_turn_id
|
||||
{
|
||||
top.turn_context_item = item.clone();
|
||||
top.invalidated_by_following_compaction = false;
|
||||
self.reference_context_item = Some(item);
|
||||
return;
|
||||
}
|
||||
|
||||
self.user_turn_baselines.push(UserTurnBaselineFrame {
|
||||
turn_context_item: item.clone(),
|
||||
invalidated_by_following_compaction: false,
|
||||
});
|
||||
self.reference_context_item = Some(item);
|
||||
}
|
||||
|
||||
pub(crate) fn invalidate_top_user_turn_baseline(&mut self) {
|
||||
if let Some(top) = self.user_turn_baselines.last_mut() {
|
||||
top.invalidated_by_following_compaction = true;
|
||||
}
|
||||
self.reference_context_item = None;
|
||||
}
|
||||
|
||||
pub(crate) fn truncate_user_turn_baselines_from_end(&mut self, num_turns: u32) {
|
||||
let frames_to_pop = usize::try_from(num_turns).unwrap_or(usize::MAX);
|
||||
let new_len = self.user_turn_baselines.len().saturating_sub(frames_to_pop);
|
||||
self.user_turn_baselines.truncate(new_len);
|
||||
self.reference_context_item = self
|
||||
.user_turn_baselines
|
||||
.last()
|
||||
.filter(|frame| !frame.invalidated_by_following_compaction)
|
||||
.map(|frame| frame.turn_context_item.clone());
|
||||
}
|
||||
|
||||
pub(crate) fn sync_reference_context_after_history_replacement(
|
||||
&mut self,
|
||||
reference_context_item: Option<TurnContextItem>,
|
||||
) {
|
||||
match reference_context_item {
|
||||
Some(item) => {
|
||||
if let Some(top) = self.user_turn_baselines.last_mut()
|
||||
&& top.turn_context_item.turn_id == item.turn_id
|
||||
{
|
||||
top.turn_context_item = item.clone();
|
||||
top.invalidated_by_following_compaction = false;
|
||||
}
|
||||
self.reference_context_item = Some(item);
|
||||
}
|
||||
None => {
|
||||
// Replacement histories that clear the reference baseline (for example
|
||||
// standalone/pre-turn compaction) should also invalidate the top user-turn
|
||||
// baseline so rollback and future turns do not reuse stale diffs.
|
||||
self.invalidate_top_user_turn_baseline();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn latest_user_turn_model(&self) -> Option<String> {
|
||||
self.user_turn_baselines
|
||||
.last()
|
||||
.map(|frame| frame.turn_context_item.model.clone())
|
||||
}
|
||||
|
||||
pub(crate) fn set_token_usage_full(&mut self, context_window: i64) {
|
||||
match &mut self.token_info {
|
||||
Some(info) => info.fill_to_context_window(context_window),
|
||||
@@ -216,6 +309,7 @@ impl ContextManager {
|
||||
let user_positions = user_message_positions(&snapshot);
|
||||
let Some(&first_user_idx) = user_positions.first() else {
|
||||
self.replace(snapshot);
|
||||
self.truncate_user_turn_baselines_from_end(num_turns);
|
||||
return;
|
||||
};
|
||||
|
||||
@@ -227,6 +321,7 @@ impl ContextManager {
|
||||
};
|
||||
|
||||
self.replace(snapshot[..cut_idx].to_vec());
|
||||
self.truncate_user_turn_baselines_from_end(num_turns);
|
||||
}
|
||||
|
||||
pub(crate) fn update_token_info(
|
||||
|
||||
@@ -4,6 +4,7 @@ pub(crate) mod updates;
|
||||
|
||||
pub(crate) use history::ContextManager;
|
||||
pub(crate) use history::TotalTokenUsageBreakdown;
|
||||
pub(crate) use history::UserTurnBaselineFrame;
|
||||
pub(crate) use history::estimate_response_item_model_visible_bytes;
|
||||
pub(crate) use history::is_codex_generated_item;
|
||||
pub(crate) use history::is_user_turn_boundary;
|
||||
|
||||
@@ -6,6 +6,7 @@ use std::collections::HashSet;
|
||||
|
||||
use crate::codex::SessionConfiguration;
|
||||
use crate::context_manager::ContextManager;
|
||||
use crate::context_manager::UserTurnBaselineFrame;
|
||||
use crate::protocol::RateLimitSnapshot;
|
||||
use crate::protocol::TokenUsage;
|
||||
use crate::protocol::TokenUsageInfo;
|
||||
@@ -76,7 +77,7 @@ impl SessionState {
|
||||
) {
|
||||
self.history.replace(items);
|
||||
self.history
|
||||
.set_reference_context_item(reference_context_item);
|
||||
.sync_reference_context_after_history_replacement(reference_context_item);
|
||||
}
|
||||
|
||||
pub(crate) fn set_token_info(&mut self, info: Option<TokenUsageInfo>) {
|
||||
@@ -91,6 +92,19 @@ impl SessionState {
|
||||
self.history.reference_context_item()
|
||||
}
|
||||
|
||||
pub(crate) fn set_user_turn_baselines(&mut self, baselines: Vec<UserTurnBaselineFrame>) {
|
||||
self.history.set_user_turn_baselines(baselines);
|
||||
}
|
||||
|
||||
pub(crate) fn record_regular_turn_baseline(&mut self, item: TurnContextItem) {
|
||||
self.history.record_regular_turn_baseline(item);
|
||||
}
|
||||
|
||||
pub(crate) fn rollback_user_turns(&mut self, num_turns: u32) {
|
||||
self.history.drop_last_n_user_turns(num_turns);
|
||||
self.previous_model = self.history.latest_user_turn_model();
|
||||
}
|
||||
|
||||
// Token/rate limit helpers
|
||||
pub(crate) fn update_token_info_from_usage(
|
||||
&mut self,
|
||||
|
||||
Reference in New Issue
Block a user