feat(rollout): preserve fork references across replay

Preserve fork-reference replay behavior on the current origin/main base and collapse the branch back to a single commit for easier future restacks.
This commit is contained in:
Friel
2026-03-14 13:31:40 -07:00
parent 65f631c3d6
commit f30fde6221
20 changed files with 2320 additions and 85 deletions

View File

@@ -19,6 +19,59 @@ use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;
/// Collects, in ascending order, the index of every rollout item that is a
/// `RolloutItem::ResponseItem` whose payload parses (via `parse_turn_item`)
/// to a `TurnItem::UserMessage`.
fn find_user_input_positions(items: &[RolloutItem]) -> Vec<usize> {
    items
        .iter()
        .enumerate()
        .filter_map(|(index, entry)| match entry {
            // Only response items can mark a user-message boundary.
            RolloutItem::ResponseItem(response) => {
                matches!(parse_turn_item(response), Some(TurnItem::UserMessage(_)))
                    .then_some(index)
            }
            _ => None,
        })
        .collect()
}
/// Returns a copy of the prefix of `items` that ends strictly before the
/// `nth_user_message`-th (0-based) user message.
///
/// * `usize::MAX` is treated as "no truncation": the whole slice is copied.
/// * If there is no user message at that index, the result is empty.
fn truncate_before_nth_user_message(
    items: &[RolloutItem],
    nth_user_message: usize,
) -> Vec<RolloutItem> {
    if nth_user_message == usize::MAX {
        return items.to_vec();
    }
    let boundaries = find_user_input_positions(items);
    boundaries
        .get(nth_user_message)
        // Cut strictly before the selected user message.
        .map_or_else(Vec::new, |&boundary| items[..boundary].to_vec())
}
/// Reads the rollout JSONL file at `p` and returns its items with fork
/// references expanded in place.
///
/// Blank lines are ignored and `SessionMeta` items are dropped. A
/// `ForkReference` item is replaced by the referenced rollout's own
/// materialized items, truncated before the referenced user message.
///
/// # Panics
///
/// Panics if the file cannot be read, or if any non-blank line fails to parse
/// as JSON or as a `RolloutLine`.
fn read_items_materialized(p: &std::path::Path) -> Vec<RolloutItem> {
    let contents = match std::fs::read_to_string(p) {
        Ok(contents) => contents,
        Err(err) => panic!("read rollout file {p:?}: {err}"),
    };
    let mut materialized: Vec<RolloutItem> = Vec::new();
    for raw in contents.lines().filter(|raw| !raw.trim().is_empty()) {
        // Two-step parse (text -> Value -> RolloutLine) so that each failure
        // mode keeps its own panic message.
        let json: serde_json::Value = match serde_json::from_str(raw) {
            Ok(json) => json,
            Err(err) => panic!("jsonl line parse: {err}"),
        };
        let parsed: RolloutLine = match serde_json::from_value(json) {
            Ok(parsed) => parsed,
            Err(err) => panic!("rollout line parse: {err}"),
        };
        match parsed.item {
            RolloutItem::SessionMeta(_) => {}
            RolloutItem::ForkReference(reference) => {
                // Recursively materialize the referenced rollout, then keep
                // only the part before the referenced user message.
                let inherited = read_items_materialized(&reference.rollout_path);
                let kept = truncate_before_nth_user_message(&inherited, reference.nth_user_message);
                materialized.extend(kept);
            }
            other => materialized.push(other),
        }
    }
    materialized
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn fork_thread_twice_drops_to_first_message() {
skip_if_no_network!();
@@ -64,40 +117,9 @@ async fn fork_thread_twice_drops_to_first_message() {
// GetHistory flushes before returning the path; no wait needed.
// Helper: read rollout items (excluding SessionMeta) from a JSONL path.
let read_items = |p: &std::path::Path| -> Vec<RolloutItem> {
let text = std::fs::read_to_string(p).expect("read rollout file");
let mut items: Vec<RolloutItem> = Vec::new();
for line in text.lines() {
if line.trim().is_empty() {
continue;
}
let v: serde_json::Value = serde_json::from_str(line).expect("jsonl line");
let rl: RolloutLine = serde_json::from_value(v).expect("rollout line");
match rl.item {
RolloutItem::SessionMeta(_) => {}
other => items.push(other),
}
}
items
};
// Compute expected prefixes after each fork by truncating base rollout
// strictly before the nth user input (0-based).
let base_items = read_items(&base_path);
let find_user_input_positions = |items: &[RolloutItem]| -> Vec<usize> {
let mut pos = Vec::new();
for (i, it) in items.iter().enumerate() {
if let RolloutItem::ResponseItem(response_item) = it
&& let Some(TurnItem::UserMessage(_)) = parse_turn_item(response_item)
{
// Consider any user message as an input boundary; recorder stores both EventMsg and ResponseItem.
// We specifically look for input items, which are represented as ContentItem::InputText.
pos.push(i);
}
}
pos
};
let base_items = read_items_materialized(&base_path);
let user_inputs = find_user_input_positions(&base_items);
// After cutting at nth user input (n=1 → second user message), cut strictly before that input.
@@ -124,9 +146,10 @@ async fn fork_thread_twice_drops_to_first_message() {
let fork1_path = codex_fork1.rollout_path().expect("rollout path");
// GetHistory on fork1 flushed; the file is ready.
let fork1_items = read_items(&fork1_path);
let fork1_items = read_items_materialized(&fork1_path);
assert!(fork1_items.len() > expected_after_first.len());
pretty_assertions::assert_eq!(
serde_json::to_value(&fork1_items).unwrap(),
serde_json::to_value(&fork1_items[..expected_after_first.len()]).unwrap(),
serde_json::to_value(&expected_after_first).unwrap()
);
@@ -147,16 +170,68 @@ async fn fork_thread_twice_drops_to_first_message() {
let fork2_path = codex_fork2.rollout_path().expect("rollout path");
// GetHistory on fork2 flushed; the file is ready.
let fork1_items = read_items(&fork1_path);
let fork1_items = read_items_materialized(&fork1_path);
let fork1_user_inputs = find_user_input_positions(&fork1_items);
let cut_last_on_fork1 = fork1_user_inputs
.get(fork1_user_inputs.len().saturating_sub(1))
.copied()
.unwrap_or(0);
let expected_after_second: Vec<RolloutItem> = fork1_items[..cut_last_on_fork1].to_vec();
let fork2_items = read_items(&fork2_path);
let fork2_items = read_items_materialized(&fork2_path);
assert!(fork2_items.len() > expected_after_second.len());
pretty_assertions::assert_eq!(
serde_json::to_value(&fork2_items).unwrap(),
serde_json::to_value(&fork2_items[..expected_after_second.len()]).unwrap(),
serde_json::to_value(&expected_after_second).unwrap()
);
}
/// Forks a thread from a parent's rollout file and checks that the child's
/// `session_configured.forked_from_id` is the parent's session id and that the
/// child is given a different thread id.
///
/// NOTE(review): despite the name, nothing in this test asserts on the forked
/// history contents; only the parent linkage and id distinctness are checked.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn fork_thread_session_configured_preserves_parent_and_history() {
skip_if_no_network!();
// Mock model endpoint: one SSE stream containing a "response created" event
// followed by a "completed" event for the same response id.
let server = MockServer::start().await;
let sse = sse(vec![ev_response_created("resp"), ev_completed("resp")]);
let response = ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse, "text/event-stream");
// The mock is declared with `.expect(1)`; presumably the single matching
// request is the "seed" turn submitted below — TODO confirm the fork itself
// issues no request.
Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(response)
.expect(1)
.mount(&server)
.await;
// Build the parent conversation and keep the handles needed for forking.
let mut builder = test_codex();
let test = builder.build(&server).await.expect("create conversation");
let codex = test.codex.clone();
let thread_manager = test.thread_manager.clone();
let config_for_fork = test.config.clone();
// The parent's session id is what the fork is expected to report back.
let parent_thread_id = test.session_configured.session_id;
// Run one user turn on the parent and wait for it to complete before forking.
codex
.submit(Op::UserInput {
items: vec![UserInput::Text {
text: "seed".to_string(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
})
.await
.unwrap();
let _ = wait_for_event(&codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await;
let base_path = codex.rollout_path().expect("rollout path");
// Fork from the parent's rollout path. `usize::MAX` is presumably the
// "keep the entire history" value for the nth-user-message cut; the meaning
// of the trailing `false` / `None` arguments is not visible here — TODO confirm.
let NewThread {
thread_id: child_thread_id,
session_configured,
..
} = thread_manager
.fork_thread(usize::MAX, config_for_fork, base_path, false, None)
.await
.expect("fork thread");
// The child must point back at the parent and must not reuse the parent's id.
pretty_assertions::assert_eq!(session_configured.forked_from_id, Some(parent_thread_id));
assert_ne!(child_thread_id, parent_thread_id);
}