feat: memories forgetting (#12900)

Add diff-based memory forgetting
This commit is contained in:
jif-oai
2026-02-26 13:19:57 +00:00
committed by GitHub
parent 81ce645733
commit 382fa338b3
12 changed files with 1335 additions and 39 deletions

View File

@@ -1,4 +1,5 @@
use super::*;
use crate::model::Phase2InputSelection;
use crate::model::Phase2JobClaimOutcome;
use crate::model::Stage1JobClaim;
use crate::model::Stage1JobClaimOutcome;
@@ -6,10 +7,12 @@ use crate::model::Stage1Output;
use crate::model::Stage1OutputRow;
use crate::model::Stage1StartupClaimParams;
use crate::model::ThreadRow;
use crate::model::stage1_output_ref_from_parts;
use chrono::Duration;
use sqlx::Executor;
use sqlx::QueryBuilder;
use sqlx::Sqlite;
use std::collections::HashSet;
const JOB_KIND_MEMORY_STAGE1: &str = "memory_stage1";
const JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL: &str = "memory_consolidate_global";
@@ -257,6 +260,117 @@ LIMIT ?
.collect::<Result<Vec<_>, _>>()
}
/// Returns the current phase-2 input set along with its diff against the
/// last successful phase-2 selection.
///
/// Query behavior:
/// - current selection is the latest `n` non-empty stage-1 outputs ordered
///   by `source_updated_at DESC, thread_id DESC`
/// - previously selected rows are identified by `selected_for_phase2 = 1`
/// - `previous_selected` contains the current persisted rows that belonged
///   to the last successful phase-2 baseline
/// - `retained_thread_ids` records which current rows still match the exact
///   snapshot selected in the last successful phase-2 run, i.e. rows whose
///   `selected_for_phase2` flag is set and whose
///   `selected_for_phase2_source_updated_at` equals their `source_updated_at`
/// - removed rows are previously selected rows that are still present in
///   `stage1_outputs` but fall outside the current top-`n` selection
///
/// Both reads are issued on the same transaction so that the current
/// selection and the previous baseline are taken from one consistent view of
/// `stage1_outputs`; otherwise a write landing between the two queries could
/// yield a diff that mixes two different baselines.
///
/// Returns `Phase2InputSelection::default()` without touching the database
/// when `n == 0`.
///
/// # Errors
///
/// Returns an error if the transaction cannot be started or committed, if
/// either query fails, or if a fetched row cannot be decoded/converted.
pub async fn get_phase2_input_selection(
    &self,
    n: usize,
) -> anyhow::Result<Phase2InputSelection> {
    if n == 0 {
        return Ok(Phase2InputSelection::default());
    }

    // Clamp instead of `as`-casting: a wrapping cast of a very large `usize`
    // would produce a negative LIMIT by accident.
    let limit = i64::try_from(n).unwrap_or(i64::MAX);

    // Fetch both row sets inside one transaction, then release it before
    // doing any in-memory decoding.
    let mut tx = self.pool.begin().await?;

    let current_rows = sqlx::query(
        r#"
SELECT
so.thread_id,
COALESCE(t.rollout_path, '') AS rollout_path,
so.source_updated_at,
so.raw_memory,
so.rollout_summary,
so.rollout_slug,
so.generated_at,
so.selected_for_phase2,
so.selected_for_phase2_source_updated_at,
COALESCE(t.cwd, '') AS cwd
FROM stage1_outputs AS so
LEFT JOIN threads AS t
ON t.id = so.thread_id
WHERE length(trim(so.raw_memory)) > 0 OR length(trim(so.rollout_summary)) > 0
ORDER BY so.source_updated_at DESC, so.thread_id DESC
LIMIT ?
"#,
    )
    .bind(limit)
    .fetch_all(&mut *tx)
    .await?;

    let previous_rows = sqlx::query(
        r#"
SELECT
so.thread_id,
COALESCE(t.rollout_path, '') AS rollout_path,
so.source_updated_at,
so.raw_memory,
so.rollout_summary,
so.rollout_slug
, so.generated_at
, COALESCE(t.cwd, '') AS cwd
FROM stage1_outputs AS so
LEFT JOIN threads AS t
ON t.id = so.thread_id
WHERE so.selected_for_phase2 = 1
ORDER BY so.source_updated_at DESC, so.thread_id DESC
"#,
    )
    .fetch_all(&mut *tx)
    .await?;

    tx.commit().await?;

    let mut current_thread_ids = HashSet::with_capacity(current_rows.len());
    let mut selected = Vec::with_capacity(current_rows.len());
    let mut retained_thread_ids = Vec::new();
    for row in current_rows {
        let thread_id = row.try_get::<String, _>("thread_id")?;
        let source_updated_at = row.try_get::<i64, _>("source_updated_at")?;
        // A row is "retained" only if it was part of the last baseline AND
        // the baseline was taken at this exact `source_updated_at`.
        if row.try_get::<i64, _>("selected_for_phase2")? != 0
            && row.try_get::<Option<i64>, _>("selected_for_phase2_source_updated_at")?
                == Some(source_updated_at)
        {
            retained_thread_ids.push(ThreadId::try_from(thread_id.clone())?);
        }
        selected.push(Stage1Output::try_from(Stage1OutputRow::try_from_row(
            &row,
        )?)?);
        // Insert last so the owned id can be moved instead of cloned.
        current_thread_ids.insert(thread_id);
    }

    let previous_selected = previous_rows
        .iter()
        .map(Stage1OutputRow::try_from_row)
        .map(|row| row.and_then(Stage1Output::try_from))
        .collect::<Result<Vec<_>, _>>()?;

    // Anything in the previous baseline that is not in the current top-`n`
    // selection has been dropped and is reported as removed.
    let mut removed = Vec::new();
    for row in previous_rows {
        let thread_id = row.try_get::<String, _>("thread_id")?;
        if current_thread_ids.contains(thread_id.as_str()) {
            continue;
        }
        removed.push(stage1_output_ref_from_parts(
            thread_id,
            row.try_get("source_updated_at")?,
            row.try_get("rollout_slug")?,
        )?);
    }

    Ok(Phase2InputSelection {
        selected,
        previous_selected,
        retained_thread_ids,
        removed,
    })
}
/// Attempts to claim a stage-1 job for a thread at `source_updated_at`.
///
/// Claim semantics:
@@ -454,6 +568,9 @@ WHERE kind = ? AND job_key = ?
/// - sets `status='done'` and `last_success_watermark = input_watermark`
/// - upserts `stage1_outputs` for the thread, replacing existing output only
/// when `source_updated_at` is newer or equal
/// - preserves any existing `selected_for_phase2` baseline until the next
/// successful phase-2 run rewrites the baseline selection, including the
/// snapshot timestamp chosen during that run
/// - persists optional `rollout_slug` for rollout summary artifact naming
/// - enqueues/advances the global phase-2 job watermark using
/// `source_updated_at`
@@ -806,12 +923,18 @@ WHERE kind = ? AND job_key = ?
/// - sets `status='done'`, clears lease/errors
/// - advances `last_success_watermark` to
/// `max(existing_last_success_watermark, completed_watermark)`
/// - rewrites `selected_for_phase2` so only the exact selected stage-1
/// snapshots remain marked as part of the latest successful phase-2
/// selection, and persists each selected snapshot's
/// `source_updated_at` for future retained-vs-added diffing
pub async fn mark_global_phase2_job_succeeded(
&self,
ownership_token: &str,
completed_watermark: i64,
selected_outputs: &[Stage1Output],
) -> anyhow::Result<bool> {
let now = Utc::now().timestamp();
let mut tx = self.pool.begin().await?;
let rows_affected = sqlx::query(
r#"
UPDATE jobs
@@ -830,11 +953,46 @@ WHERE kind = ? AND job_key = ?
.bind(JOB_KIND_MEMORY_CONSOLIDATE_GLOBAL)
.bind(MEMORY_CONSOLIDATION_JOB_KEY)
.bind(ownership_token)
.execute(self.pool.as_ref())
.execute(&mut *tx)
.await?
.rows_affected();
Ok(rows_affected > 0)
if rows_affected == 0 {
tx.commit().await?;
return Ok(false);
}
sqlx::query(
r#"
UPDATE stage1_outputs
SET
selected_for_phase2 = 0,
selected_for_phase2_source_updated_at = NULL
WHERE selected_for_phase2 != 0 OR selected_for_phase2_source_updated_at IS NOT NULL
"#,
)
.execute(&mut *tx)
.await?;
for output in selected_outputs {
sqlx::query(
r#"
UPDATE stage1_outputs
SET
selected_for_phase2 = 1,
selected_for_phase2_source_updated_at = ?
WHERE thread_id = ? AND source_updated_at = ?
"#,
)
.bind(output.source_updated_at.timestamp())
.bind(output.thread_id.to_string())
.bind(output.source_updated_at.timestamp())
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
Ok(true)
}
/// Marks the owned running global phase-2 job as failed and schedules retry.