feat: memories forgetting (#12900)

Add diff based memory forgetting
This commit is contained in:
jif-oai
2026-02-26 13:19:57 +00:00
committed by GitHub
parent 81ce645733
commit 382fa338b3
12 changed files with 1335 additions and 39 deletions

View File

@@ -2773,7 +2773,11 @@ WHERE kind = 'memory_stage1'
assert_eq!(phase2_input_watermark, 100);
assert!(
runtime
.mark_global_phase2_job_succeeded(phase2_token.as_str(), phase2_input_watermark)
.mark_global_phase2_job_succeeded(
phase2_token.as_str(),
phase2_input_watermark,
&[],
)
.await
.expect("mark initial phase2 succeeded"),
"initial phase2 success should clear global dirty state"
@@ -2819,7 +2823,11 @@ WHERE kind = 'memory_stage1'
assert_eq!(phase2_input_watermark, 101);
assert!(
runtime
.mark_global_phase2_job_succeeded(phase2_token.as_str(), phase2_input_watermark)
.mark_global_phase2_job_succeeded(
phase2_token.as_str(),
phase2_input_watermark,
&[],
)
.await
.expect("mark phase2 succeeded after no-output delete")
);
@@ -2936,7 +2944,7 @@ WHERE kind = 'memory_stage1'
};
assert!(
runtime
.mark_global_phase2_job_succeeded(ownership_token.as_str(), input_watermark)
.mark_global_phase2_job_succeeded(ownership_token.as_str(), input_watermark, &[],)
.await
.expect("mark phase2 succeeded"),
"phase2 success should finalize for current token"
@@ -3124,6 +3132,646 @@ VALUES (?, ?, ?, ?, ?)
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn get_phase2_input_selection_reports_added_retained_and_removed_rows() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let thread_id_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let thread_id_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let thread_id_c = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
for (thread_id, workspace) in [
(thread_id_a, "workspace-a"),
(thread_id_b, "workspace-b"),
(thread_id_c, "workspace-c"),
] {
runtime
.upsert_thread(&test_thread_metadata(
&codex_home,
thread_id,
codex_home.join(workspace),
))
.await
.expect("upsert thread");
}
for (thread_id, updated_at, slug) in [
(thread_id_a, 100, Some("rollout-a")),
(thread_id_b, 101, Some("rollout-b")),
(thread_id_c, 102, Some("rollout-c")),
] {
let claim = runtime
.try_claim_stage1_job(thread_id, owner, updated_at, 3600, 64)
.await
.expect("claim stage1");
let ownership_token = match claim {
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stage1 claim outcome: {other:?}"),
};
assert!(
runtime
.mark_stage1_job_succeeded(
thread_id,
ownership_token.as_str(),
updated_at,
&format!("raw-{updated_at}"),
&format!("summary-{updated_at}"),
slug,
)
.await
.expect("mark stage1 succeeded"),
"stage1 success should persist output"
);
}
let claim = runtime
.try_claim_global_phase2_job(owner, 3600)
.await
.expect("claim phase2");
let (ownership_token, input_watermark) = match claim {
Phase2JobClaimOutcome::Claimed {
ownership_token,
input_watermark,
} => (ownership_token, input_watermark),
other => panic!("unexpected phase2 claim outcome: {other:?}"),
};
assert_eq!(input_watermark, 102);
let selected_outputs = runtime
.list_stage1_outputs_for_global(10)
.await
.expect("list stage1 outputs for global")
.into_iter()
.filter(|output| output.thread_id == thread_id_c || output.thread_id == thread_id_a)
.collect::<Vec<_>>();
assert!(
runtime
.mark_global_phase2_job_succeeded(
ownership_token.as_str(),
input_watermark,
&selected_outputs,
)
.await
.expect("mark phase2 success with selection"),
"phase2 success should persist selected rows"
);
let selection = runtime
.get_phase2_input_selection(2)
.await
.expect("load phase2 input selection");
assert_eq!(selection.selected.len(), 2);
assert_eq!(selection.previous_selected.len(), 2);
assert_eq!(selection.selected[0].thread_id, thread_id_c);
assert_eq!(
selection.selected[0].rollout_path,
codex_home.join(format!("rollout-{thread_id_c}.jsonl"))
);
assert_eq!(selection.selected[1].thread_id, thread_id_b);
assert_eq!(selection.retained_thread_ids, vec![thread_id_c]);
assert_eq!(selection.removed.len(), 1);
assert_eq!(selection.removed[0].thread_id, thread_id_a);
assert_eq!(
selection.removed[0].rollout_slug.as_deref(),
Some("rollout-a")
);
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn get_phase2_input_selection_treats_regenerated_selected_rows_as_added() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
runtime
.upsert_thread(&test_thread_metadata(
&codex_home,
thread_id,
codex_home.join("workspace"),
))
.await
.expect("upsert thread");
let first_claim = runtime
.try_claim_stage1_job(thread_id, owner, 100, 3600, 64)
.await
.expect("claim initial stage1");
let first_token = match first_claim {
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stage1 claim outcome: {other:?}"),
};
assert!(
runtime
.mark_stage1_job_succeeded(
thread_id,
first_token.as_str(),
100,
"raw-100",
"summary-100",
Some("rollout-100"),
)
.await
.expect("mark initial stage1 success"),
"initial stage1 success should persist output"
);
let phase2_claim = runtime
.try_claim_global_phase2_job(owner, 3600)
.await
.expect("claim phase2");
let (phase2_token, input_watermark) = match phase2_claim {
Phase2JobClaimOutcome::Claimed {
ownership_token,
input_watermark,
} => (ownership_token, input_watermark),
other => panic!("unexpected phase2 claim outcome: {other:?}"),
};
let selected_outputs = runtime
.list_stage1_outputs_for_global(1)
.await
.expect("list selected outputs");
assert!(
runtime
.mark_global_phase2_job_succeeded(
phase2_token.as_str(),
input_watermark,
&selected_outputs,
)
.await
.expect("mark phase2 success"),
"phase2 success should persist selected rows"
);
let refreshed_claim = runtime
.try_claim_stage1_job(thread_id, owner, 101, 3600, 64)
.await
.expect("claim refreshed stage1");
let refreshed_token = match refreshed_claim {
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stage1 claim outcome: {other:?}"),
};
assert!(
runtime
.mark_stage1_job_succeeded(
thread_id,
refreshed_token.as_str(),
101,
"raw-101",
"summary-101",
Some("rollout-101"),
)
.await
.expect("mark refreshed stage1 success"),
"refreshed stage1 success should persist output"
);
let selection = runtime
.get_phase2_input_selection(1)
.await
.expect("load phase2 input selection");
assert_eq!(selection.selected.len(), 1);
assert_eq!(selection.previous_selected.len(), 1);
assert_eq!(selection.selected[0].thread_id, thread_id);
assert_eq!(selection.selected[0].source_updated_at.timestamp(), 101);
assert!(selection.retained_thread_ids.is_empty());
assert!(selection.removed.is_empty());
let (selected_for_phase2, selected_for_phase2_source_updated_at) =
sqlx::query_as::<_, (i64, Option<i64>)>(
"SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?",
)
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.await
.expect("load selected_for_phase2");
assert_eq!(selected_for_phase2, 1);
assert_eq!(selected_for_phase2_source_updated_at, Some(100));
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn get_phase2_input_selection_reports_regenerated_previous_selection_as_removed() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let thread_id_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread a");
let thread_id_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread b");
let thread_id_c = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread c");
let thread_id_d = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread d");
let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
for (thread_id, workspace) in [
(thread_id_a, "workspace-a"),
(thread_id_b, "workspace-b"),
(thread_id_c, "workspace-c"),
(thread_id_d, "workspace-d"),
] {
runtime
.upsert_thread(&test_thread_metadata(
&codex_home,
thread_id,
codex_home.join(workspace),
))
.await
.expect("upsert thread");
}
for (thread_id, updated_at, slug) in [
(thread_id_a, 100, Some("rollout-a-100")),
(thread_id_b, 101, Some("rollout-b-101")),
(thread_id_c, 99, Some("rollout-c-99")),
(thread_id_d, 98, Some("rollout-d-98")),
] {
let claim = runtime
.try_claim_stage1_job(thread_id, owner, updated_at, 3600, 64)
.await
.expect("claim initial stage1");
let ownership_token = match claim {
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stage1 claim outcome: {other:?}"),
};
assert!(
runtime
.mark_stage1_job_succeeded(
thread_id,
ownership_token.as_str(),
updated_at,
&format!("raw-{updated_at}"),
&format!("summary-{updated_at}"),
slug,
)
.await
.expect("mark stage1 succeeded"),
"stage1 success should persist output"
);
}
let phase2_claim = runtime
.try_claim_global_phase2_job(owner, 3600)
.await
.expect("claim phase2");
let (phase2_token, input_watermark) = match phase2_claim {
Phase2JobClaimOutcome::Claimed {
ownership_token,
input_watermark,
} => (ownership_token, input_watermark),
other => panic!("unexpected phase2 claim outcome: {other:?}"),
};
let selected_outputs = runtime
.list_stage1_outputs_for_global(2)
.await
.expect("list selected outputs");
assert_eq!(
selected_outputs
.iter()
.map(|output| output.thread_id)
.collect::<Vec<_>>(),
vec![thread_id_b, thread_id_a]
);
assert!(
runtime
.mark_global_phase2_job_succeeded(
phase2_token.as_str(),
input_watermark,
&selected_outputs,
)
.await
.expect("mark phase2 success"),
"phase2 success should persist selected rows"
);
for (thread_id, updated_at, slug) in [
(thread_id_a, 102, Some("rollout-a-102")),
(thread_id_c, 103, Some("rollout-c-103")),
(thread_id_d, 104, Some("rollout-d-104")),
] {
let claim = runtime
.try_claim_stage1_job(thread_id, owner, updated_at, 3600, 64)
.await
.expect("claim refreshed stage1");
let ownership_token = match claim {
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stage1 claim outcome: {other:?}"),
};
assert!(
runtime
.mark_stage1_job_succeeded(
thread_id,
ownership_token.as_str(),
updated_at,
&format!("raw-{updated_at}"),
&format!("summary-{updated_at}"),
slug,
)
.await
.expect("mark refreshed stage1 success"),
"refreshed stage1 success should persist output"
);
}
let selection = runtime
.get_phase2_input_selection(2)
.await
.expect("load phase2 input selection");
assert_eq!(
selection
.selected
.iter()
.map(|output| output.thread_id)
.collect::<Vec<_>>(),
vec![thread_id_d, thread_id_c]
);
assert_eq!(
selection
.previous_selected
.iter()
.map(|output| output.thread_id)
.collect::<Vec<_>>(),
vec![thread_id_a, thread_id_b]
);
assert!(selection.retained_thread_ids.is_empty());
assert_eq!(
selection
.removed
.iter()
.map(|output| (output.thread_id, output.source_updated_at.timestamp()))
.collect::<Vec<_>>(),
vec![(thread_id_a, 102), (thread_id_b, 101)]
);
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn mark_global_phase2_job_succeeded_updates_selected_snapshot_timestamp() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
runtime
.upsert_thread(&test_thread_metadata(
&codex_home,
thread_id,
codex_home.join("workspace"),
))
.await
.expect("upsert thread");
let initial_claim = runtime
.try_claim_stage1_job(thread_id, owner, 100, 3600, 64)
.await
.expect("claim initial stage1");
let initial_token = match initial_claim {
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stage1 claim outcome: {other:?}"),
};
assert!(
runtime
.mark_stage1_job_succeeded(
thread_id,
initial_token.as_str(),
100,
"raw-100",
"summary-100",
Some("rollout-100"),
)
.await
.expect("mark initial stage1 success"),
"initial stage1 success should persist output"
);
let first_phase2_claim = runtime
.try_claim_global_phase2_job(owner, 3600)
.await
.expect("claim first phase2");
let (first_phase2_token, first_input_watermark) = match first_phase2_claim {
Phase2JobClaimOutcome::Claimed {
ownership_token,
input_watermark,
} => (ownership_token, input_watermark),
other => panic!("unexpected first phase2 claim outcome: {other:?}"),
};
let first_selected_outputs = runtime
.list_stage1_outputs_for_global(1)
.await
.expect("list first selected outputs");
assert!(
runtime
.mark_global_phase2_job_succeeded(
first_phase2_token.as_str(),
first_input_watermark,
&first_selected_outputs,
)
.await
.expect("mark first phase2 success"),
"first phase2 success should persist selected rows"
);
let refreshed_claim = runtime
.try_claim_stage1_job(thread_id, owner, 101, 3600, 64)
.await
.expect("claim refreshed stage1");
let refreshed_token = match refreshed_claim {
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected refreshed stage1 claim outcome: {other:?}"),
};
assert!(
runtime
.mark_stage1_job_succeeded(
thread_id,
refreshed_token.as_str(),
101,
"raw-101",
"summary-101",
Some("rollout-101"),
)
.await
.expect("mark refreshed stage1 success"),
"refreshed stage1 success should persist output"
);
let second_phase2_claim = runtime
.try_claim_global_phase2_job(owner, 3600)
.await
.expect("claim second phase2");
let (second_phase2_token, second_input_watermark) = match second_phase2_claim {
Phase2JobClaimOutcome::Claimed {
ownership_token,
input_watermark,
} => (ownership_token, input_watermark),
other => panic!("unexpected second phase2 claim outcome: {other:?}"),
};
let second_selected_outputs = runtime
.list_stage1_outputs_for_global(1)
.await
.expect("list second selected outputs");
assert_eq!(
second_selected_outputs[0].source_updated_at.timestamp(),
101
);
assert!(
runtime
.mark_global_phase2_job_succeeded(
second_phase2_token.as_str(),
second_input_watermark,
&second_selected_outputs,
)
.await
.expect("mark second phase2 success"),
"second phase2 success should persist selected rows"
);
let selection = runtime
.get_phase2_input_selection(1)
.await
.expect("load phase2 input selection after refresh");
assert_eq!(selection.retained_thread_ids, vec![thread_id]);
let (selected_for_phase2, selected_for_phase2_source_updated_at) =
sqlx::query_as::<_, (i64, Option<i64>)>(
"SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?",
)
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.await
.expect("load selected snapshot after phase2");
assert_eq!(selected_for_phase2, 1);
assert_eq!(selected_for_phase2_source_updated_at, Some(101));
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn mark_global_phase2_job_succeeded_only_marks_exact_selected_snapshots() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let thread_id = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
runtime
.upsert_thread(&test_thread_metadata(
&codex_home,
thread_id,
codex_home.join("workspace"),
))
.await
.expect("upsert thread");
let initial_claim = runtime
.try_claim_stage1_job(thread_id, owner, 100, 3600, 64)
.await
.expect("claim initial stage1");
let initial_token = match initial_claim {
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stage1 claim outcome: {other:?}"),
};
assert!(
runtime
.mark_stage1_job_succeeded(
thread_id,
initial_token.as_str(),
100,
"raw-100",
"summary-100",
Some("rollout-100"),
)
.await
.expect("mark initial stage1 success"),
"initial stage1 success should persist output"
);
let phase2_claim = runtime
.try_claim_global_phase2_job(owner, 3600)
.await
.expect("claim phase2");
let (phase2_token, input_watermark) = match phase2_claim {
Phase2JobClaimOutcome::Claimed {
ownership_token,
input_watermark,
} => (ownership_token, input_watermark),
other => panic!("unexpected phase2 claim outcome: {other:?}"),
};
let selected_outputs = runtime
.list_stage1_outputs_for_global(1)
.await
.expect("list selected outputs");
assert_eq!(selected_outputs[0].source_updated_at.timestamp(), 100);
let refreshed_claim = runtime
.try_claim_stage1_job(thread_id, owner, 101, 3600, 64)
.await
.expect("claim refreshed stage1");
let refreshed_token = match refreshed_claim {
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stage1 claim outcome: {other:?}"),
};
assert!(
runtime
.mark_stage1_job_succeeded(
thread_id,
refreshed_token.as_str(),
101,
"raw-101",
"summary-101",
Some("rollout-101"),
)
.await
.expect("mark refreshed stage1 success"),
"refreshed stage1 success should persist output"
);
assert!(
runtime
.mark_global_phase2_job_succeeded(
phase2_token.as_str(),
input_watermark,
&selected_outputs,
)
.await
.expect("mark phase2 success"),
"phase2 success should still complete"
);
let (selected_for_phase2, selected_for_phase2_source_updated_at) =
sqlx::query_as::<_, (i64, Option<i64>)>(
"SELECT selected_for_phase2, selected_for_phase2_source_updated_at FROM stage1_outputs WHERE thread_id = ?",
)
.bind(thread_id.to_string())
.fetch_one(runtime.pool.as_ref())
.await
.expect("load selected_for_phase2");
assert_eq!(selected_for_phase2, 0);
assert_eq!(selected_for_phase2_source_updated_at, None);
let selection = runtime
.get_phase2_input_selection(1)
.await
.expect("load phase2 input selection");
assert_eq!(selection.selected.len(), 1);
assert_eq!(selection.selected[0].source_updated_at.timestamp(), 101);
assert!(selection.retained_thread_ids.is_empty());
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn record_stage1_output_usage_updates_usage_metadata() {
let codex_home = unique_temp_dir();
@@ -3395,7 +4043,7 @@ VALUES (?, ?, ?, ?, ?)
assert_eq!(
runtime
.mark_global_phase2_job_succeeded(token_a.as_str(), 300)
.mark_global_phase2_job_succeeded(token_a.as_str(), 300, &[])
.await
.expect("mark stale owner success result"),
false,
@@ -3403,7 +4051,7 @@ VALUES (?, ?, ?, ?, ?)
);
assert!(
runtime
.mark_global_phase2_job_succeeded(token_b.as_str(), 300)
.mark_global_phase2_job_succeeded(token_b.as_str(), 300, &[])
.await
.expect("mark takeover owner success"),
"takeover owner should finalize consolidation"
@@ -3440,7 +4088,7 @@ VALUES (?, ?, ?, ?, ?)
};
assert!(
runtime
.mark_global_phase2_job_succeeded(token_a.as_str(), 500)
.mark_global_phase2_job_succeeded(token_a.as_str(), 500, &[])
.await
.expect("mark initial phase2 success"),
"initial phase2 success should finalize"