mirror of
https://github.com/openai/codex.git
synced 2026-04-30 19:32:04 +03:00
feat(rollout): preserve fork references across replay
Preserve fork-reference replay behavior on the current origin/main base and collapse the branch back to a single commit for easier future restacks.
This commit is contained in:
@@ -19,6 +19,59 @@ use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
fn find_user_input_positions(items: &[RolloutItem]) -> Vec<usize> {
|
||||
let mut pos = Vec::new();
|
||||
for (i, it) in items.iter().enumerate() {
|
||||
if let RolloutItem::ResponseItem(response_item) = it
|
||||
&& let Some(TurnItem::UserMessage(_)) = parse_turn_item(response_item)
|
||||
{
|
||||
pos.push(i);
|
||||
}
|
||||
}
|
||||
pos
|
||||
}
|
||||
|
||||
fn truncate_before_nth_user_message(
|
||||
items: &[RolloutItem],
|
||||
nth_user_message: usize,
|
||||
) -> Vec<RolloutItem> {
|
||||
if nth_user_message == usize::MAX {
|
||||
return items.to_vec();
|
||||
}
|
||||
let user_inputs = find_user_input_positions(items);
|
||||
let Some(cut_idx) = user_inputs.get(nth_user_message).copied() else {
|
||||
return Vec::new();
|
||||
};
|
||||
items[..cut_idx].to_vec()
|
||||
}
|
||||
|
||||
fn read_items_materialized(p: &std::path::Path) -> Vec<RolloutItem> {
|
||||
let text =
|
||||
std::fs::read_to_string(p).unwrap_or_else(|err| panic!("read rollout file {p:?}: {err}"));
|
||||
let mut items: Vec<RolloutItem> = Vec::new();
|
||||
for line in text.lines() {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let v: serde_json::Value =
|
||||
serde_json::from_str(line).unwrap_or_else(|err| panic!("jsonl line parse: {err}"));
|
||||
let rl: RolloutLine =
|
||||
serde_json::from_value(v).unwrap_or_else(|err| panic!("rollout line parse: {err}"));
|
||||
match rl.item {
|
||||
RolloutItem::SessionMeta(_) => {}
|
||||
RolloutItem::ForkReference(reference) => {
|
||||
let parent_items = read_items_materialized(&reference.rollout_path);
|
||||
items.extend(truncate_before_nth_user_message(
|
||||
&parent_items,
|
||||
reference.nth_user_message,
|
||||
));
|
||||
}
|
||||
other => items.push(other),
|
||||
}
|
||||
}
|
||||
items
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn fork_thread_twice_drops_to_first_message() {
|
||||
skip_if_no_network!();
|
||||
@@ -64,40 +117,9 @@ async fn fork_thread_twice_drops_to_first_message() {
|
||||
|
||||
// GetHistory flushes before returning the path; no wait needed.
|
||||
|
||||
// Helper: read rollout items (excluding SessionMeta) from a JSONL path.
|
||||
let read_items = |p: &std::path::Path| -> Vec<RolloutItem> {
|
||||
let text = std::fs::read_to_string(p).expect("read rollout file");
|
||||
let mut items: Vec<RolloutItem> = Vec::new();
|
||||
for line in text.lines() {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let v: serde_json::Value = serde_json::from_str(line).expect("jsonl line");
|
||||
let rl: RolloutLine = serde_json::from_value(v).expect("rollout line");
|
||||
match rl.item {
|
||||
RolloutItem::SessionMeta(_) => {}
|
||||
other => items.push(other),
|
||||
}
|
||||
}
|
||||
items
|
||||
};
|
||||
|
||||
// Compute expected prefixes after each fork by truncating base rollout
|
||||
// strictly before the nth user input (0-based).
|
||||
let base_items = read_items(&base_path);
|
||||
let find_user_input_positions = |items: &[RolloutItem]| -> Vec<usize> {
|
||||
let mut pos = Vec::new();
|
||||
for (i, it) in items.iter().enumerate() {
|
||||
if let RolloutItem::ResponseItem(response_item) = it
|
||||
&& let Some(TurnItem::UserMessage(_)) = parse_turn_item(response_item)
|
||||
{
|
||||
// Consider any user message as an input boundary; recorder stores both EventMsg and ResponseItem.
|
||||
// We specifically look for input items, which are represented as ContentItem::InputText.
|
||||
pos.push(i);
|
||||
}
|
||||
}
|
||||
pos
|
||||
};
|
||||
let base_items = read_items_materialized(&base_path);
|
||||
let user_inputs = find_user_input_positions(&base_items);
|
||||
|
||||
// After cutting at nth user input (n=1 → second user message), cut strictly before that input.
|
||||
@@ -124,9 +146,10 @@ async fn fork_thread_twice_drops_to_first_message() {
|
||||
let fork1_path = codex_fork1.rollout_path().expect("rollout path");
|
||||
|
||||
// GetHistory on fork1 flushed; the file is ready.
|
||||
let fork1_items = read_items(&fork1_path);
|
||||
let fork1_items = read_items_materialized(&fork1_path);
|
||||
assert!(fork1_items.len() > expected_after_first.len());
|
||||
pretty_assertions::assert_eq!(
|
||||
serde_json::to_value(&fork1_items).unwrap(),
|
||||
serde_json::to_value(&fork1_items[..expected_after_first.len()]).unwrap(),
|
||||
serde_json::to_value(&expected_after_first).unwrap()
|
||||
);
|
||||
|
||||
@@ -147,16 +170,68 @@ async fn fork_thread_twice_drops_to_first_message() {
|
||||
|
||||
let fork2_path = codex_fork2.rollout_path().expect("rollout path");
|
||||
// GetHistory on fork2 flushed; the file is ready.
|
||||
let fork1_items = read_items(&fork1_path);
|
||||
let fork1_items = read_items_materialized(&fork1_path);
|
||||
let fork1_user_inputs = find_user_input_positions(&fork1_items);
|
||||
let cut_last_on_fork1 = fork1_user_inputs
|
||||
.get(fork1_user_inputs.len().saturating_sub(1))
|
||||
.copied()
|
||||
.unwrap_or(0);
|
||||
let expected_after_second: Vec<RolloutItem> = fork1_items[..cut_last_on_fork1].to_vec();
|
||||
let fork2_items = read_items(&fork2_path);
|
||||
let fork2_items = read_items_materialized(&fork2_path);
|
||||
assert!(fork2_items.len() > expected_after_second.len());
|
||||
pretty_assertions::assert_eq!(
|
||||
serde_json::to_value(&fork2_items).unwrap(),
|
||||
serde_json::to_value(&fork2_items[..expected_after_second.len()]).unwrap(),
|
||||
serde_json::to_value(&expected_after_second).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
// Forking a thread must report the parent thread in `forked_from_id`
// while assigning the child a fresh, distinct thread id.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn fork_thread_session_configured_preserves_parent_and_history() {
    skip_if_no_network!();

    // Mock model endpoint: a single SSE response stream that immediately
    // completes, expected to be hit exactly once (the "seed" turn below).
    let server = MockServer::start().await;
    let sse = sse(vec![ev_response_created("resp"), ev_completed("resp")]);
    let response = ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
        .set_body_raw(sse, "text/event-stream");

    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .respond_with(response)
        .expect(1)
        .mount(&server)
        .await;

    let mut builder = test_codex();
    let test = builder.build(&server).await.expect("create conversation");
    let codex = test.codex.clone();
    let thread_manager = test.thread_manager.clone();
    let config_for_fork = test.config.clone();
    // The parent's id, as announced in its own SessionConfigured event.
    let parent_thread_id = test.session_configured.session_id;

    // Run one turn so the parent rollout has history to fork from.
    codex
        .submit(Op::UserInput {
            items: vec![UserInput::Text {
                text: "seed".to_string(),
                text_elements: Vec::new(),
            }],
            final_output_json_schema: None,
        })
        .await
        .unwrap();
    let _ = wait_for_event(&codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await;

    let base_path = codex.rollout_path().expect("rollout path");

    // usize::MAX matches the "keep full history" sentinel used by
    // truncate_before_nth_user_message, so the fork carries everything.
    // NOTE(review): the meaning of the `false`/`None` arguments is not
    // visible from this file — confirm against fork_thread's signature.
    let NewThread {
        thread_id: child_thread_id,
        session_configured,
        ..
    } = thread_manager
        .fork_thread(usize::MAX, config_for_fork, base_path, false, None)
        .await
        .expect("fork thread");

    // The child must point back at the parent, and must not reuse its id.
    pretty_assertions::assert_eq!(session_configured.forked_from_id, Some(parent_thread_id));
    assert_ne!(child_thread_id, parent_thread_id);
}
|
||||
|
||||
Reference in New Issue
Block a user