feat: prune old memories in DB (#13734)

To keep the state DB size in check.
jif-oai
2026-03-06 14:10:49 +00:00
committed by GitHub
parent b6d43ec8eb
commit 8ad768eb76
4 changed files with 280 additions and 0 deletions


@@ -57,6 +57,8 @@ mod phase_one {
pub(super) const JOB_RETRY_DELAY_SECONDS: i64 = 3_600;
/// Maximum number of threads to scan.
pub(super) const THREAD_SCAN_LIMIT: usize = 5_000;
/// Size of the batches when pruning old thread memories.
pub(super) const PRUNE_BATCH_SIZE: usize = 200;
}
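A fixed batch size bounds the work done by each `DELETE`, so a startup prune can never stall on one huge transaction. The commit runs a single batch per startup; if a backlog larger than `PRUNE_BATCH_SIZE` ever needed draining in one pass, a caller could loop until a short batch comes back. A minimal sketch of that idea, assuming the `StateRuntime` receiver seen in the tests below (the loop itself is not part of this commit):

```rust
// Hypothetical driver, not part of this commit: drain all stale rows in
// bounded batches instead of pruning a single batch per startup.
async fn prune_all(db: &StateRuntime, max_unused_days: i64) -> anyhow::Result<usize> {
    let mut total = 0;
    loop {
        // Method added in this diff; deletes at most PRUNE_BATCH_SIZE rows.
        let pruned = db
            .prune_stage1_outputs_for_retention(max_unused_days, PRUNE_BATCH_SIZE)
            .await?;
        total += pruned;
        // A batch shorter than the limit means nothing stale remains.
        if pruned < PRUNE_BATCH_SIZE {
            return Ok(total);
        }
    }
}
```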
/// Phase 2 (aka `Consolidation`).


@@ -7,6 +7,7 @@ use crate::config::types::MemoriesConfig;
use crate::error::CodexErr;
use crate::memories::metrics;
use crate::memories::phase_one;
use crate::memories::phase_one::PRUNE_BATCH_SIZE;
use crate::memories::prompts::build_stage_one_input_message;
use crate::rollout::INTERACTIVE_SESSION_SOURCES;
use crate::rollout::policy::should_persist_response_item_for_memories;
@@ -120,6 +121,30 @@ pub(in crate::memories) async fn run(session: &Arc<Session>, config: &Config) {
);
}
/// Prune old, unused ("dead") raw memories.
pub(in crate::memories) async fn prune(session: &Arc<Session>, config: &Config) {
if let Some(db) = session.services.state_db.as_deref() {
let max_unused_days = config.memories.max_unused_days;
match db
.prune_stage1_outputs_for_retention(max_unused_days, PRUNE_BATCH_SIZE)
.await
{
Ok(pruned) => {
if pruned > 0 {
info!(
"memory startup pruned {pruned} stale stage-1 output row(s) older than {max_unused_days} days"
);
}
}
Err(err) => {
warn!(
"state db prune_stage1_outputs_for_retention failed during memories startup: {err}"
);
}
}
}
}
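`max_unused_days` is read from `MemoriesConfig` in `crate::config::types`, which this diff does not show. From the call site it presumably carries at least an integer day count; an assumed sketch of the relevant field, for orientation only:

```rust
// Assumed shape, not the real definition: crate::config::types::MemoriesConfig
// is outside this diff. Only the field used by prune() above is sketched.
pub struct MemoriesConfig {
    /// Unselected raw memories whose recency is older than this many days
    /// become eligible for pruning.
    pub max_unused_days: i64,
}
```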
/// JSON schema used to constrain phase-1 model output.
pub fn output_schema() -> Value {
json!({


@@ -34,6 +34,8 @@ pub(crate) fn start_memories_startup_task(
return;
};
// Prune stale memories to keep the DB size in check.
phase1::prune(&session, &config).await;
// Run phase 1.
phase1::run(&session, &config).await;
// Run phase 2.


@@ -292,6 +292,49 @@ LIMIT ?
.collect::<Result<Vec<_>, _>>()
}
/// Prunes stale stage-1 outputs while preserving the latest phase-2
/// baseline and stage-1 job watermarks.
///
/// Query behavior:
/// - considers only rows with `selected_for_phase2 = 0`
/// - computes recency as `COALESCE(last_usage, source_updated_at)`
/// - removes rows whose recency is older than the `max_unused_days` cutoff
/// - prunes at most `limit` rows ordered from stalest to newest
pub async fn prune_stage1_outputs_for_retention(
&self,
max_unused_days: i64,
limit: usize,
) -> anyhow::Result<usize> {
if limit == 0 {
return Ok(0);
}
let cutoff = (Utc::now() - Duration::days(max_unused_days.max(0))).timestamp();
let rows_affected = sqlx::query(
r#"
DELETE FROM stage1_outputs
WHERE thread_id IN (
SELECT thread_id
FROM stage1_outputs
WHERE selected_for_phase2 = 0
AND COALESCE(last_usage, source_updated_at) < ?
ORDER BY
COALESCE(last_usage, source_updated_at) ASC,
source_updated_at ASC,
thread_id ASC
LIMIT ?
)
"#,
)
.bind(cutoff)
.bind(limit as i64)
.execute(self.pool.as_ref())
.await?
.rows_affected();
Ok(rows_affected as usize)
}
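Two details of the statement are worth noting. SQLite only supports a bare `DELETE ... ORDER BY ... LIMIT` when compiled with `SQLITE_ENABLE_UPDATE_DELETE_LIMIT`, so routing the limit through the `IN (SELECT ... LIMIT ?)` subquery is the portable form. And the `source_updated_at ASC, thread_id ASC` tie-breakers make each batch deterministic when many rows share the same recency, so successive prunes drain the backlog in a stable order.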
/// Returns the current phase-2 input set along with its diff against the
/// last successful phase-2 selection.
///
@@ -3875,6 +3918,214 @@ VALUES (?, ?, ?, ?, ?)
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn prune_stage1_outputs_for_retention_prunes_stale_unselected_rows_only() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
let stale_unused =
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("stale unused");
let stale_used = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("stale used");
let stale_selected =
ThreadId::from_string(&Uuid::new_v4().to_string()).expect("stale selected");
let fresh_used = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("fresh used");
for (thread_id, workspace) in [
(stale_unused, "workspace-stale-unused"),
(stale_used, "workspace-stale-used"),
(stale_selected, "workspace-stale-selected"),
(fresh_used, "workspace-fresh-used"),
] {
runtime
.upsert_thread(&test_thread_metadata(
&codex_home,
thread_id,
codex_home.join(workspace),
))
.await
.expect("upsert thread");
}
let now = Utc::now().timestamp();
for (thread_id, source_updated_at, summary) in [
(
stale_unused,
now - Duration::days(60).num_seconds(),
"stale-unused",
),
(
stale_used,
now - Duration::days(50).num_seconds(),
"stale-used",
),
(
stale_selected,
now - Duration::days(45).num_seconds(),
"stale-selected",
),
(
fresh_used,
now - Duration::days(10).num_seconds(),
"fresh-used",
),
] {
let claim = runtime
.try_claim_stage1_job(thread_id, owner, source_updated_at, 3600, 64)
.await
.expect("claim stage1");
let ownership_token = match claim {
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stage1 claim outcome: {other:?}"),
};
assert!(
runtime
.mark_stage1_job_succeeded(
thread_id,
ownership_token.as_str(),
source_updated_at,
&format!("raw-{summary}"),
summary,
None,
)
.await
.expect("mark stage1 success"),
"stage1 success should persist output"
);
}
sqlx::query(
"UPDATE stage1_outputs SET usage_count = ?, last_usage = ? WHERE thread_id = ?",
)
.bind(3_i64)
.bind(now - Duration::days(40).num_seconds())
.bind(stale_used.to_string())
.execute(runtime.pool.as_ref())
.await
.expect("set stale used metadata");
sqlx::query(
"UPDATE stage1_outputs SET selected_for_phase2 = 1, selected_for_phase2_source_updated_at = source_updated_at WHERE thread_id = ?",
)
.bind(stale_selected.to_string())
.execute(runtime.pool.as_ref())
.await
.expect("mark selected for phase2");
sqlx::query(
"UPDATE stage1_outputs SET usage_count = ?, last_usage = ? WHERE thread_id = ?",
)
.bind(8_i64)
.bind(now - Duration::days(2).num_seconds())
.bind(fresh_used.to_string())
.execute(runtime.pool.as_ref())
.await
.expect("set fresh used metadata");
let before_jobs_count =
sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM jobs WHERE kind = 'memory_stage1'")
.fetch_one(runtime.pool.as_ref())
.await
.expect("count stage1 jobs before prune");
let pruned = runtime
.prune_stage1_outputs_for_retention(30, 100)
.await
.expect("prune stage1 outputs");
assert_eq!(pruned, 2);
let remaining = sqlx::query_scalar::<_, String>(
"SELECT thread_id FROM stage1_outputs ORDER BY thread_id",
)
.fetch_all(runtime.pool.as_ref())
.await
.expect("load remaining stage1 outputs");
let mut expected_remaining = vec![fresh_used.to_string(), stale_selected.to_string()];
expected_remaining.sort();
assert_eq!(remaining, expected_remaining);
let after_jobs_count =
sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM jobs WHERE kind = 'memory_stage1'")
.fetch_one(runtime.pool.as_ref())
.await
.expect("count stage1 jobs after prune");
assert_eq!(after_jobs_count, before_jobs_count);
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn prune_stage1_outputs_for_retention_respects_batch_limit() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
let thread_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread a");
let thread_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread b");
let thread_c = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread c");
for (thread_id, workspace) in [
(thread_a, "workspace-a"),
(thread_b, "workspace-b"),
(thread_c, "workspace-c"),
] {
runtime
.upsert_thread(&test_thread_metadata(
&codex_home,
thread_id,
codex_home.join(workspace),
))
.await
.expect("upsert thread");
}
let now = Utc::now().timestamp();
for (thread_id, source_updated_at, summary) in [
(thread_a, now - Duration::days(60).num_seconds(), "stale-a"),
(thread_b, now - Duration::days(50).num_seconds(), "stale-b"),
(thread_c, now - Duration::days(40).num_seconds(), "stale-c"),
] {
let claim = runtime
.try_claim_stage1_job(thread_id, owner, source_updated_at, 3600, 64)
.await
.expect("claim stage1");
let ownership_token = match claim {
Stage1JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stage1 claim outcome: {other:?}"),
};
assert!(
runtime
.mark_stage1_job_succeeded(
thread_id,
ownership_token.as_str(),
source_updated_at,
&format!("raw-{summary}"),
summary,
None,
)
.await
.expect("mark stage1 success"),
"stage1 success should persist output"
);
}
let pruned = runtime
.prune_stage1_outputs_for_retention(30, 2)
.await
.expect("prune stage1 outputs with limit");
assert_eq!(pruned, 2);
let remaining_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM stage1_outputs")
.fetch_one(runtime.pool.as_ref())
.await
.expect("count remaining stage1 outputs");
assert_eq!(remaining_count, 1);
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn mark_stage1_job_succeeded_enqueues_global_consolidation() {
let codex_home = unique_temp_dir();