feat: memories forgetting (#12900)

Add diff-based memory forgetting
This commit is contained in:
jif-oai
2026-02-26 13:19:57 +00:00
committed by GitHub
parent 81ce645733
commit 382fa338b3
12 changed files with 1335 additions and 39 deletions

View File

@@ -70,11 +70,31 @@ What it does:
If there is input, it then:
- spawns an internal consolidation sub-agent
- builds the Phase 2 prompt with a diff of the current Phase 1 input
selection versus the last successful Phase 2 selection (`added`,
`retained`, `removed`)
- runs it with no approvals, no network, and local write access only
- disables collab for that agent (to prevent recursive delegation)
- watches the agent status and heartbeats the global job lease while it runs
- marks the phase-2 job success/failure in the state DB when the agent finishes
Selection diff behavior:
- successful Phase 2 runs mark the exact stage-1 snapshots they consumed with
`selected_for_phase2 = 1` and persist the matching
`selected_for_phase2_source_updated_at`
- Phase 1 upserts preserve the previous `selected_for_phase2` baseline until
the next successful Phase 2 run rewrites it
- the next Phase 2 run compares the current top-N stage-1 inputs against that
prior snapshot selection to label inputs as `added` or `retained`; a
refreshed thread stays `added` until Phase 2 successfully selects its newer
snapshot
- rows that were previously selected but still exist outside the current top-N
selection are surfaced as `removed`
- before the agent starts, local `rollout_summaries/` and `raw_memories.md`
keep the union of the current selection and the previous successful
selection, so removed-thread evidence stays available during forgetting
Watermark behavior:
- The global phase-2 job claim includes an input watermark representing the latest input timestamp known when the job was claimed.

View File

@@ -8,6 +8,7 @@ use crate::memories::metrics;
use crate::memories::phase_two;
use crate::memories::prompts::build_consolidation_prompt;
use crate::memories::storage::rebuild_raw_memories_file_from_memories;
use crate::memories::storage::rollout_summary_file_stem;
use crate::memories::storage::sync_rollout_summaries_from_memories;
use codex_config::Constrained;
use codex_protocol::ThreadId;
@@ -17,8 +18,10 @@ use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use codex_protocol::protocol::TokenUsage;
use codex_protocol::user_input::UserInput;
use codex_state::Stage1Output;
use codex_state::StateRuntime;
use codex_utils_absolute_path::AbsolutePathBuf;
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
@@ -73,21 +76,24 @@ pub(super) async fn run(session: &Arc<Session>, config: Arc<Config>) {
};
// 3. Query the memories
let raw_memories = match db.list_stage1_outputs_for_global(max_raw_memories).await {
Ok(memories) => memories,
let selection = match db.get_phase2_input_selection(max_raw_memories).await {
Ok(selection) => selection,
Err(err) => {
tracing::error!("failed to list stage1 outputs from global: {}", err);
job::failed(session, db, &claim, "failed_load_stage1_outputs").await;
return;
}
};
let raw_memories = selection.selected.to_vec();
let artifact_memories = artifact_memories_for_phase2(&selection);
let new_watermark = get_watermark(claim.watermark, &raw_memories);
// 4. Update the file system by syncing the raw memories with the one extracted from DB at
// step 3
// [`rollout_summaries/`]
if let Err(err) =
sync_rollout_summaries_from_memories(&root, &raw_memories, max_raw_memories).await
sync_rollout_summaries_from_memories(&root, &artifact_memories, artifact_memories.len())
.await
{
tracing::error!("failed syncing local memory artifacts for global consolidation: {err}");
job::failed(session, db, &claim, "failed_sync_artifacts").await;
@@ -95,7 +101,8 @@ pub(super) async fn run(session: &Arc<Session>, config: Arc<Config>) {
}
// [`raw_memories.md`]
if let Err(err) =
rebuild_raw_memories_file_from_memories(&root, &raw_memories, max_raw_memories).await
rebuild_raw_memories_file_from_memories(&root, &artifact_memories, artifact_memories.len())
.await
{
tracing::error!("failed syncing local memory artifacts for global consolidation: {err}");
job::failed(session, db, &claim, "failed_rebuild_raw_memories").await;
@@ -103,12 +110,20 @@ pub(super) async fn run(session: &Arc<Session>, config: Arc<Config>) {
}
if raw_memories.is_empty() {
// We check only after sync of the file system.
job::succeed(session, db, &claim, new_watermark, "succeeded_no_input").await;
job::succeed(
session,
db,
&claim,
new_watermark,
&[],
"succeeded_no_input",
)
.await;
return;
}
// 5. Spawn the agent
let prompt = agent::get_prompt(config);
let prompt = agent::get_prompt(config, &selection);
let source = SessionSource::SubAgent(SubAgentSource::MemoryConsolidation);
let thread_id = match session
.services
@@ -129,6 +144,7 @@ pub(super) async fn run(session: &Arc<Session>, config: Arc<Config>) {
session,
claim,
new_watermark,
raw_memories.clone(),
thread_id,
phase_two_e2e_timer,
);
@@ -140,6 +156,22 @@ pub(super) async fn run(session: &Arc<Session>, config: Arc<Config>) {
emit_metrics(session, counters);
}
fn artifact_memories_for_phase2(
selection: &codex_state::Phase2InputSelection,
) -> Vec<Stage1Output> {
let mut seen = HashSet::new();
let mut memories = selection.selected.clone();
for memory in &selection.selected {
seen.insert(rollout_summary_file_stem(memory));
}
for memory in &selection.previous_selected {
if seen.insert(rollout_summary_file_stem(memory)) {
memories.push(memory.clone());
}
}
memories
}
mod job {
use super::*;
@@ -205,6 +237,7 @@ mod job {
db: &StateRuntime,
claim: &Claim,
completion_watermark: i64,
selected_outputs: &[codex_state::Stage1Output],
reason: &'static str,
) {
session.services.otel_manager.counter(
@@ -213,7 +246,7 @@ mod job {
&[("status", reason)],
);
let _ = db
.mark_global_phase2_job_succeeded(&claim.token, completion_watermark)
.mark_global_phase2_job_succeeded(&claim.token, completion_watermark, selected_outputs)
.await;
}
}
@@ -266,9 +299,12 @@ mod agent {
Some(agent_config)
}
pub(super) fn get_prompt(config: Arc<Config>) -> Vec<UserInput> {
pub(super) fn get_prompt(
config: Arc<Config>,
selection: &codex_state::Phase2InputSelection,
) -> Vec<UserInput> {
let root = memory_root(&config.codex_home);
let prompt = build_consolidation_prompt(&root);
let prompt = build_consolidation_prompt(&root, selection);
vec![UserInput::Text {
text: prompt,
text_elements: vec![],
@@ -280,6 +316,7 @@ mod agent {
session: &Arc<Session>,
claim: Claim,
new_watermark: i64,
selected_outputs: Vec<codex_state::Stage1Output>,
thread_id: ThreadId,
phase_two_e2e_timer: Option<codex_otel::Timer>,
) {
@@ -316,7 +353,15 @@ mod agent {
if let Some(token_usage) = agent_control.get_total_token_usage(thread_id).await {
emit_token_usage_metrics(&session, &token_usage);
}
job::succeed(&session, &db, &claim, new_watermark, "succeeded").await;
job::succeed(
&session,
&db,
&claim,
new_watermark,
&selected_outputs,
"succeeded",
)
.await;
} else {
job::failed(&session, &db, &claim, "failed_agent").await;
}

View File

@@ -1,9 +1,13 @@
use crate::memories::memory_root;
use crate::memories::phase_one;
use crate::memories::storage::rollout_summary_file_stem_from_parts;
use crate::truncate::TruncationPolicy;
use crate::truncate::truncate_text;
use askama::Template;
use codex_protocol::openai_models::ModelInfo;
use codex_state::Phase2InputSelection;
use codex_state::Stage1Output;
use codex_state::Stage1OutputRef;
use std::path::Path;
use tokio::fs;
use tracing::warn;
@@ -12,6 +16,7 @@ use tracing::warn;
#[template(path = "memories/consolidation.md", escape = "none")]
struct ConsolidationPromptTemplate<'a> {
memory_root: &'a str,
phase2_input_selection: &'a str,
}
#[derive(Template)]
@@ -30,17 +35,91 @@ struct MemoryToolDeveloperInstructionsTemplate<'a> {
}
/// Builds the consolidation subagent prompt for a specific memory root.
pub(super) fn build_consolidation_prompt(memory_root: &Path) -> String {
pub(super) fn build_consolidation_prompt(
memory_root: &Path,
selection: &Phase2InputSelection,
) -> String {
let memory_root = memory_root.display().to_string();
let phase2_input_selection = render_phase2_input_selection(selection);
let template = ConsolidationPromptTemplate {
memory_root: &memory_root,
phase2_input_selection: &phase2_input_selection,
};
template.render().unwrap_or_else(|err| {
warn!("failed to render memories consolidation prompt template: {err}");
format!("## Memory Phase 2 (Consolidation)\nConsolidate Codex memories in: {memory_root}")
format!(
"## Memory Phase 2 (Consolidation)\nConsolidate Codex memories in: {memory_root}\n\n{phase2_input_selection}"
)
})
}
/// Renders the Phase 2 input-selection summary block for the consolidation
/// prompt: aggregate counts, one line per currently selected input (labeled
/// `[retained]` or `[added]`), and one line per input removed since the last
/// successful Phase 2 run.
fn render_phase2_input_selection(selection: &Phase2InputSelection) -> String {
    // Count `retained` with the exact predicate used to label each line below,
    // so the summary counts can never disagree with the per-line
    // `[retained]`/`[added]` labels (previously `retained` was
    // `retained_thread_ids.len()`, which over-counts — and makes `added`
    // under-count via `saturating_sub` — if `retained_thread_ids` carries an
    // id that is not in the current `selected` set).
    let retained = selection
        .selected
        .iter()
        .filter(|item| selection.retained_thread_ids.contains(&item.thread_id))
        .count();
    let added = selection.selected.len() - retained;
    let selected = if selection.selected.is_empty() {
        "- none".to_string()
    } else {
        selection
            .selected
            .iter()
            .map(|item| {
                render_selected_input_line(
                    item,
                    selection.retained_thread_ids.contains(&item.thread_id),
                )
            })
            .collect::<Vec<_>>()
            .join("\n")
    };
    let removed = if selection.removed.is_empty() {
        "- none".to_string()
    } else {
        selection
            .removed
            .iter()
            .map(render_removed_input_line)
            .collect::<Vec<_>>()
            .join("\n")
    };
    format!(
        "- selected inputs this run: {}\n- newly added since the last successful Phase 2 run: {added}\n- retained from the last successful Phase 2 run: {retained}\n- removed from the last successful Phase 2 run: {}\n\nCurrent selected Phase 1 inputs:\n{selected}\n\nRemoved from the last successful Phase 2 selection:\n{removed}\n",
        selection.selected.len(),
        selection.removed.len(),
    )
}
/// Formats one currently-selected Phase 1 input as a prompt bullet, tagged
/// `[retained]` when the thread survived from the previous successful Phase 2
/// selection and `[added]` otherwise.
fn render_selected_input_line(item: &Stage1Output, retained: bool) -> String {
    let status = match retained {
        true => "retained",
        false => "added",
    };
    // Stem must match the file name written into `rollout_summaries/`.
    let stem = rollout_summary_file_stem_from_parts(
        item.thread_id,
        item.source_updated_at,
        item.rollout_slug.as_deref(),
    );
    format!(
        "- [{status}] thread_id={}, rollout_summary_file=rollout_summaries/{stem}.md",
        item.thread_id
    )
}
/// Formats one input dropped since the last successful Phase 2 run as a
/// prompt bullet (no status tag — the surrounding section already says
/// "removed").
fn render_removed_input_line(item: &Stage1OutputRef) -> String {
    // Stem must match the file name written into `rollout_summaries/`.
    let stem = rollout_summary_file_stem_from_parts(
        item.thread_id,
        item.source_updated_at,
        item.rollout_slug.as_deref(),
    );
    format!(
        "- thread_id={}, rollout_summary_file=rollout_summaries/{stem}.md",
        item.thread_id
    )
}
/// Builds the stage-1 user message containing rollout metadata and content.
///
/// Large rollout payloads are truncated to 70% of the active model's effective