Add fork snapshot modes (#15239)

## Summary
- add `ForkSnapshotMode` to `ThreadManager::fork_thread` so callers can
request either a committed snapshot or an interrupted snapshot
- share the model-visible `<turn_aborted>` history marker between the
live interrupt path and interrupted forks
- update the small set of direct fork callsites to pass
`ForkSnapshotMode::Committed`

Note: this enables /btw to work similarly as Esc to interrupt (hopefully
somewhat in distribution)

---------

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Charley Cunningham
2026-03-23 19:05:42 -07:00
committed by GitHub
parent 84fb180eeb
commit f547b79bd0
15 changed files with 823 additions and 52 deletions

View File

@@ -24,6 +24,9 @@ use crate::rollout::RolloutRecorder;
use crate::rollout::truncation;
use crate::shell_snapshot::ShellSnapshot;
use crate::skills::SkillsManager;
use crate::tasks::interrupted_turn_history_marker;
use codex_app_server_protocol::ThreadHistoryBuilder;
use codex_app_server_protocol::TurnStatus;
use codex_protocol::ThreadId;
use codex_protocol::config_types::CollaborationModeMask;
#[cfg(test)]
@@ -34,6 +37,8 @@ use codex_protocol::protocol::McpServerRefreshConfig;
use codex_protocol::protocol::Op;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::TurnAbortReason;
use codex_protocol::protocol::TurnAbortedEvent;
use codex_protocol::protocol::W3cTraceContext;
use futures::StreamExt;
use futures::stream::FuturesUnordered;
@@ -126,6 +131,45 @@ pub struct NewThread {
pub session_configured: SessionConfiguredEvent,
}
// TODO(ccunningham): Add an explicit non-interrupting live-turn snapshot once
// core can represent sampling boundaries directly instead of relying on
// whichever items happened to be persisted mid-turn.
//
// Two likely future variants:
// - `TruncateToLastSamplingBoundary` for callers that want a coherent fork from
// the last stable model boundary without synthesizing an interrupt.
// - `WaitUntilNextSamplingBoundary` (or similar) for callers that prefer to
// fork after the next sampling boundary rather than interrupting immediately.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ForkSnapshot {
/// Fork a committed prefix ending strictly before the nth user message.
///
/// When `n` is within range, this cuts before that 0-based user-message
/// boundary. When `n` is out of range and the source thread is currently
/// mid-turn, this instead cuts before the active turn's opening boundary
/// so the fork drops the unfinished turn suffix. When `n` is out of range
/// and the source thread is already at a turn boundary, this returns the
/// full committed history unchanged.
TruncateBeforeNthUserMessage(usize),
/// Fork the current persisted history as if the source thread had been
/// interrupted now.
///
/// If the persisted snapshot ends mid-turn, this appends the same
/// `<turn_aborted>` marker produced by a real interrupt. If the snapshot is
/// already at a turn boundary, this returns the current persisted history
/// unchanged.
Interrupted,
}
/// Preserve legacy `fork_thread(usize, ...)` callsites by mapping them to the
/// existing truncate-before-nth-user-message snapshot mode.
impl From<usize> for ForkSnapshot {
fn from(value: usize) -> Self {
Self::TruncateBeforeNthUserMessage(value)
}
}
#[derive(Debug, Default, PartialEq, Eq)]
pub struct ThreadShutdownReport {
pub completed: Vec<ThreadId>,
@@ -541,20 +585,41 @@ impl ThreadManager {
report
}
/// Fork an existing thread by taking messages up to the given position (not including
/// the message at the given position) and starting a new thread with identical
/// configuration (unless overridden by the caller's `config`). The new thread will have
/// a fresh id. Pass `usize::MAX` to keep the full rollout history.
pub async fn fork_thread(
/// Fork an existing thread by snapshotting rollout history according to
/// `snapshot` and starting a new thread with identical configuration
/// (unless overridden by the caller's `config`). The new thread will have
/// a fresh id.
pub async fn fork_thread<S>(
&self,
nth_user_message: usize,
snapshot: S,
config: Config,
path: PathBuf,
persist_extended_history: bool,
parent_trace: Option<W3cTraceContext>,
) -> CodexResult<NewThread> {
) -> CodexResult<NewThread>
where
S: Into<ForkSnapshot>,
{
let snapshot = snapshot.into();
let history = RolloutRecorder::get_rollout_history(&path).await?;
let history = truncate_before_nth_user_message(history, nth_user_message);
let snapshot_state = snapshot_turn_state(&history);
let history = match snapshot {
ForkSnapshot::TruncateBeforeNthUserMessage(nth_user_message) => {
truncate_before_nth_user_message(history, nth_user_message, &snapshot_state)
}
ForkSnapshot::Interrupted => {
let history = match history {
InitialHistory::New => InitialHistory::New,
InitialHistory::Forked(history) => InitialHistory::Forked(history),
InitialHistory::Resumed(resumed) => InitialHistory::Forked(resumed.history),
};
if snapshot_state.ends_mid_turn {
append_interrupted_boundary(history, snapshot_state.active_turn_id)
} else {
history
}
}
};
Box::pin(self.state.spawn_thread(
config,
history,
@@ -838,11 +903,31 @@ impl ThreadManagerState {
}
}
/// Return a prefix of `items` obtained by cutting strictly before the nth user message
/// (0-based) and all items that follow it.
fn truncate_before_nth_user_message(history: InitialHistory, n: usize) -> InitialHistory {
/// Return a fork snapshot cut strictly before the nth user message (0-based).
///
/// Out-of-range values keep the full committed history at a turn boundary, but
/// when the source thread is currently mid-turn they fall back to cutting
/// before the active turn's opening boundary so the fork omits the unfinished
/// suffix entirely.
fn truncate_before_nth_user_message(
history: InitialHistory,
n: usize,
snapshot_state: &SnapshotTurnState,
) -> InitialHistory {
let items: Vec<RolloutItem> = history.get_rollout_items();
let rolled = truncation::truncate_rollout_before_nth_user_message_from_start(&items, n);
let user_positions = truncation::user_message_positions_in_rollout(&items);
let rolled = if snapshot_state.ends_mid_turn && n >= user_positions.len() {
if let Some(cut_idx) = snapshot_state
.active_turn_start_index
.or_else(|| user_positions.last().copied())
{
items[..cut_idx].to_vec()
} else {
items
}
} else {
truncation::truncate_rollout_before_nth_user_message_from_start(&items, n)
};
if rolled.is_empty() {
InitialHistory::New
@@ -851,6 +936,95 @@ fn truncate_before_nth_user_message(history: InitialHistory, n: usize) -> Initia
}
}
#[derive(Debug, Eq, PartialEq)]
struct SnapshotTurnState {
ends_mid_turn: bool,
active_turn_id: Option<String>,
active_turn_start_index: Option<usize>,
}
fn snapshot_turn_state(history: &InitialHistory) -> SnapshotTurnState {
let rollout_items = history.get_rollout_items();
let mut builder = ThreadHistoryBuilder::new();
for item in &rollout_items {
builder.handle_rollout_item(item);
}
let active_turn_id = builder.active_turn_id_if_explicit();
if builder.has_active_turn() && active_turn_id.is_some() {
let active_turn_snapshot = builder.active_turn_snapshot();
if active_turn_snapshot
.as_ref()
.is_some_and(|turn| turn.status != TurnStatus::InProgress)
{
return SnapshotTurnState {
ends_mid_turn: false,
active_turn_id: None,
active_turn_start_index: None,
};
}
return SnapshotTurnState {
ends_mid_turn: true,
active_turn_id,
active_turn_start_index: builder.active_turn_start_index(),
};
}
let Some(last_user_position) = truncation::user_message_positions_in_rollout(&rollout_items)
.last()
.copied()
else {
return SnapshotTurnState {
ends_mid_turn: false,
active_turn_id: None,
active_turn_start_index: None,
};
};
// Synthetic fork/resume histories can contain user/assistant response items
// without explicit turn lifecycle events. If the persisted snapshot has no
// terminating boundary after its last user message, treat it as mid-turn.
SnapshotTurnState {
ends_mid_turn: !rollout_items[last_user_position + 1..].iter().any(|item| {
matches!(
item,
RolloutItem::EventMsg(EventMsg::TurnComplete(_) | EventMsg::TurnAborted(_))
)
}),
active_turn_id: None,
active_turn_start_index: None,
}
}
/// Append the same persisted interrupt boundary used by the live interrupt path
/// to an existing fork snapshot after the source thread has been confirmed to
/// be mid-turn.
fn append_interrupted_boundary(history: InitialHistory, turn_id: Option<String>) -> InitialHistory {
let aborted_event = RolloutItem::EventMsg(EventMsg::TurnAborted(TurnAbortedEvent {
turn_id,
reason: TurnAbortReason::Interrupted,
}));
match history {
InitialHistory::New => InitialHistory::Forked(vec![
RolloutItem::ResponseItem(interrupted_turn_history_marker()),
aborted_event,
]),
InitialHistory::Forked(mut history) => {
history.push(RolloutItem::ResponseItem(interrupted_turn_history_marker()));
history.push(aborted_event);
InitialHistory::Forked(history)
}
InitialHistory::Resumed(mut resumed) => {
resumed
.history
.push(RolloutItem::ResponseItem(interrupted_turn_history_marker()));
resumed.history.push(aborted_event);
InitialHistory::Forked(resumed.history)
}
}
}
#[cfg(test)]
#[path = "thread_manager_tests.rs"]
mod tests;