mirror of
https://github.com/openai/codex.git
synced 2026-05-05 22:01:37 +03:00
Add fork snapshot modes (#15239)
## Summary - add `ForkSnapshotMode` to `ThreadManager::fork_thread` so callers can request either a committed snapshot or an interrupted snapshot - share the model-visible `<turn_aborted>` history marker between the live interrupt path and interrupted forks - update the small set of direct fork callsites to pass `ForkSnapshotMode::Committed` Note: this enables /btw to work similarly as Esc to interrupt (hopefully somewhat in distribution) --------- Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
committed by
GitHub
parent
84fb180eeb
commit
f547b79bd0
@@ -24,6 +24,9 @@ use crate::rollout::RolloutRecorder;
|
||||
use crate::rollout::truncation;
|
||||
use crate::shell_snapshot::ShellSnapshot;
|
||||
use crate::skills::SkillsManager;
|
||||
use crate::tasks::interrupted_turn_history_marker;
|
||||
use codex_app_server_protocol::ThreadHistoryBuilder;
|
||||
use codex_app_server_protocol::TurnStatus;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::config_types::CollaborationModeMask;
|
||||
#[cfg(test)]
|
||||
@@ -34,6 +37,8 @@ use codex_protocol::protocol::McpServerRefreshConfig;
|
||||
use codex_protocol::protocol::Op;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::TurnAbortReason;
|
||||
use codex_protocol::protocol::TurnAbortedEvent;
|
||||
use codex_protocol::protocol::W3cTraceContext;
|
||||
use futures::StreamExt;
|
||||
use futures::stream::FuturesUnordered;
|
||||
@@ -126,6 +131,45 @@ pub struct NewThread {
|
||||
pub session_configured: SessionConfiguredEvent,
|
||||
}
|
||||
|
||||
// TODO(ccunningham): Add an explicit non-interrupting live-turn snapshot once
|
||||
// core can represent sampling boundaries directly instead of relying on
|
||||
// whichever items happened to be persisted mid-turn.
|
||||
//
|
||||
// Two likely future variants:
|
||||
// - `TruncateToLastSamplingBoundary` for callers that want a coherent fork from
|
||||
// the last stable model boundary without synthesizing an interrupt.
|
||||
// - `WaitUntilNextSamplingBoundary` (or similar) for callers that prefer to
|
||||
// fork after the next sampling boundary rather than interrupting immediately.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ForkSnapshot {
|
||||
/// Fork a committed prefix ending strictly before the nth user message.
|
||||
///
|
||||
/// When `n` is within range, this cuts before that 0-based user-message
|
||||
/// boundary. When `n` is out of range and the source thread is currently
|
||||
/// mid-turn, this instead cuts before the active turn's opening boundary
|
||||
/// so the fork drops the unfinished turn suffix. When `n` is out of range
|
||||
/// and the source thread is already at a turn boundary, this returns the
|
||||
/// full committed history unchanged.
|
||||
TruncateBeforeNthUserMessage(usize),
|
||||
|
||||
/// Fork the current persisted history as if the source thread had been
|
||||
/// interrupted now.
|
||||
///
|
||||
/// If the persisted snapshot ends mid-turn, this appends the same
|
||||
/// `<turn_aborted>` marker produced by a real interrupt. If the snapshot is
|
||||
/// already at a turn boundary, this returns the current persisted history
|
||||
/// unchanged.
|
||||
Interrupted,
|
||||
}
|
||||
|
||||
/// Preserve legacy `fork_thread(usize, ...)` callsites by mapping them to the
|
||||
/// existing truncate-before-nth-user-message snapshot mode.
|
||||
impl From<usize> for ForkSnapshot {
|
||||
fn from(value: usize) -> Self {
|
||||
Self::TruncateBeforeNthUserMessage(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, PartialEq, Eq)]
|
||||
pub struct ThreadShutdownReport {
|
||||
pub completed: Vec<ThreadId>,
|
||||
@@ -541,20 +585,41 @@ impl ThreadManager {
|
||||
report
|
||||
}
|
||||
|
||||
/// Fork an existing thread by taking messages up to the given position (not including
|
||||
/// the message at the given position) and starting a new thread with identical
|
||||
/// configuration (unless overridden by the caller's `config`). The new thread will have
|
||||
/// a fresh id. Pass `usize::MAX` to keep the full rollout history.
|
||||
pub async fn fork_thread(
|
||||
/// Fork an existing thread by snapshotting rollout history according to
|
||||
/// `snapshot` and starting a new thread with identical configuration
|
||||
/// (unless overridden by the caller's `config`). The new thread will have
|
||||
/// a fresh id.
|
||||
pub async fn fork_thread<S>(
|
||||
&self,
|
||||
nth_user_message: usize,
|
||||
snapshot: S,
|
||||
config: Config,
|
||||
path: PathBuf,
|
||||
persist_extended_history: bool,
|
||||
parent_trace: Option<W3cTraceContext>,
|
||||
) -> CodexResult<NewThread> {
|
||||
) -> CodexResult<NewThread>
|
||||
where
|
||||
S: Into<ForkSnapshot>,
|
||||
{
|
||||
let snapshot = snapshot.into();
|
||||
let history = RolloutRecorder::get_rollout_history(&path).await?;
|
||||
let history = truncate_before_nth_user_message(history, nth_user_message);
|
||||
let snapshot_state = snapshot_turn_state(&history);
|
||||
let history = match snapshot {
|
||||
ForkSnapshot::TruncateBeforeNthUserMessage(nth_user_message) => {
|
||||
truncate_before_nth_user_message(history, nth_user_message, &snapshot_state)
|
||||
}
|
||||
ForkSnapshot::Interrupted => {
|
||||
let history = match history {
|
||||
InitialHistory::New => InitialHistory::New,
|
||||
InitialHistory::Forked(history) => InitialHistory::Forked(history),
|
||||
InitialHistory::Resumed(resumed) => InitialHistory::Forked(resumed.history),
|
||||
};
|
||||
if snapshot_state.ends_mid_turn {
|
||||
append_interrupted_boundary(history, snapshot_state.active_turn_id)
|
||||
} else {
|
||||
history
|
||||
}
|
||||
}
|
||||
};
|
||||
Box::pin(self.state.spawn_thread(
|
||||
config,
|
||||
history,
|
||||
@@ -838,11 +903,31 @@ impl ThreadManagerState {
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a prefix of `items` obtained by cutting strictly before the nth user message
|
||||
/// (0-based) and all items that follow it.
|
||||
fn truncate_before_nth_user_message(history: InitialHistory, n: usize) -> InitialHistory {
|
||||
/// Return a fork snapshot cut strictly before the nth user message (0-based).
|
||||
///
|
||||
/// Out-of-range values keep the full committed history at a turn boundary, but
|
||||
/// when the source thread is currently mid-turn they fall back to cutting
|
||||
/// before the active turn's opening boundary so the fork omits the unfinished
|
||||
/// suffix entirely.
|
||||
fn truncate_before_nth_user_message(
|
||||
history: InitialHistory,
|
||||
n: usize,
|
||||
snapshot_state: &SnapshotTurnState,
|
||||
) -> InitialHistory {
|
||||
let items: Vec<RolloutItem> = history.get_rollout_items();
|
||||
let rolled = truncation::truncate_rollout_before_nth_user_message_from_start(&items, n);
|
||||
let user_positions = truncation::user_message_positions_in_rollout(&items);
|
||||
let rolled = if snapshot_state.ends_mid_turn && n >= user_positions.len() {
|
||||
if let Some(cut_idx) = snapshot_state
|
||||
.active_turn_start_index
|
||||
.or_else(|| user_positions.last().copied())
|
||||
{
|
||||
items[..cut_idx].to_vec()
|
||||
} else {
|
||||
items
|
||||
}
|
||||
} else {
|
||||
truncation::truncate_rollout_before_nth_user_message_from_start(&items, n)
|
||||
};
|
||||
|
||||
if rolled.is_empty() {
|
||||
InitialHistory::New
|
||||
@@ -851,6 +936,95 @@ fn truncate_before_nth_user_message(history: InitialHistory, n: usize) -> Initia
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq)]
|
||||
struct SnapshotTurnState {
|
||||
ends_mid_turn: bool,
|
||||
active_turn_id: Option<String>,
|
||||
active_turn_start_index: Option<usize>,
|
||||
}
|
||||
|
||||
fn snapshot_turn_state(history: &InitialHistory) -> SnapshotTurnState {
|
||||
let rollout_items = history.get_rollout_items();
|
||||
let mut builder = ThreadHistoryBuilder::new();
|
||||
for item in &rollout_items {
|
||||
builder.handle_rollout_item(item);
|
||||
}
|
||||
let active_turn_id = builder.active_turn_id_if_explicit();
|
||||
if builder.has_active_turn() && active_turn_id.is_some() {
|
||||
let active_turn_snapshot = builder.active_turn_snapshot();
|
||||
if active_turn_snapshot
|
||||
.as_ref()
|
||||
.is_some_and(|turn| turn.status != TurnStatus::InProgress)
|
||||
{
|
||||
return SnapshotTurnState {
|
||||
ends_mid_turn: false,
|
||||
active_turn_id: None,
|
||||
active_turn_start_index: None,
|
||||
};
|
||||
}
|
||||
|
||||
return SnapshotTurnState {
|
||||
ends_mid_turn: true,
|
||||
active_turn_id,
|
||||
active_turn_start_index: builder.active_turn_start_index(),
|
||||
};
|
||||
}
|
||||
|
||||
let Some(last_user_position) = truncation::user_message_positions_in_rollout(&rollout_items)
|
||||
.last()
|
||||
.copied()
|
||||
else {
|
||||
return SnapshotTurnState {
|
||||
ends_mid_turn: false,
|
||||
active_turn_id: None,
|
||||
active_turn_start_index: None,
|
||||
};
|
||||
};
|
||||
|
||||
// Synthetic fork/resume histories can contain user/assistant response items
|
||||
// without explicit turn lifecycle events. If the persisted snapshot has no
|
||||
// terminating boundary after its last user message, treat it as mid-turn.
|
||||
SnapshotTurnState {
|
||||
ends_mid_turn: !rollout_items[last_user_position + 1..].iter().any(|item| {
|
||||
matches!(
|
||||
item,
|
||||
RolloutItem::EventMsg(EventMsg::TurnComplete(_) | EventMsg::TurnAborted(_))
|
||||
)
|
||||
}),
|
||||
active_turn_id: None,
|
||||
active_turn_start_index: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Append the same persisted interrupt boundary used by the live interrupt path
|
||||
/// to an existing fork snapshot after the source thread has been confirmed to
|
||||
/// be mid-turn.
|
||||
fn append_interrupted_boundary(history: InitialHistory, turn_id: Option<String>) -> InitialHistory {
|
||||
let aborted_event = RolloutItem::EventMsg(EventMsg::TurnAborted(TurnAbortedEvent {
|
||||
turn_id,
|
||||
reason: TurnAbortReason::Interrupted,
|
||||
}));
|
||||
|
||||
match history {
|
||||
InitialHistory::New => InitialHistory::Forked(vec![
|
||||
RolloutItem::ResponseItem(interrupted_turn_history_marker()),
|
||||
aborted_event,
|
||||
]),
|
||||
InitialHistory::Forked(mut history) => {
|
||||
history.push(RolloutItem::ResponseItem(interrupted_turn_history_marker()));
|
||||
history.push(aborted_event);
|
||||
InitialHistory::Forked(history)
|
||||
}
|
||||
InitialHistory::Resumed(mut resumed) => {
|
||||
resumed
|
||||
.history
|
||||
.push(RolloutItem::ResponseItem(interrupted_turn_history_marker()));
|
||||
resumed.history.push(aborted_event);
|
||||
InitialHistory::Forked(resumed.history)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "thread_manager_tests.rs"]
|
||||
mod tests;
|
||||
|
||||
Reference in New Issue
Block a user