feat: phase 2 consolidation (#11306)

Adds the consolidation (phase 2) stage for memories.

Replaces the per-cwd memory consolidation lock with per-scope phase-2 jobs: a job is claimed with an ownership token, kept alive by heartbeats, and finalized as succeeded or failed, so concurrent sessions no longer race on a single lock.
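For orientation, the sketch below shows how a background worker might drive the new API end to end: list dirty scopes, claim one, run the consolidation work, then finalize. It is a minimal sketch, assuming this crate's StateRuntime, ThreadId, DirtyMemoryScope, and Phase2JobClaimOutcome are in scope; the worker function, the consolidate_scope callback, the batch size, and the lease length are illustrative assumptions, not part of this commit. Only the StateRuntime calls mirror the diff below; heartbeating is elided here and sketched separately after heartbeat_phase2_job.

// Minimal sketch (not part of this commit): one pass of a phase-2 worker.
// Assumes StateRuntime, ThreadId, DirtyMemoryScope, and Phase2JobClaimOutcome are in scope.
async fn run_phase2_pass(
    runtime: &StateRuntime,
    session_id: ThreadId,
    consolidate_scope: impl Fn(&str, &str) -> anyhow::Result<()>,
) -> anyhow::Result<()> {
    // Assumed values: a batch of 8 scopes and a 5-minute lease.
    const LEASE_SECONDS: i64 = 300;
    for scope in runtime.list_dirty_memory_scopes(8).await? {
        match runtime
            .try_claim_phase2_job(
                &scope.scope_kind,
                &scope.scope_key,
                session_id,
                LEASE_SECONDS,
            )
            .await?
        {
            Phase2JobClaimOutcome::Claimed { ownership_token } => {
                // Heartbeating elided; see the heartbeat sketch later in the diff.
                match consolidate_scope(&scope.scope_kind, &scope.scope_key) {
                    Ok(()) => {
                        runtime
                            .mark_phase2_job_succeeded(
                                &scope.scope_kind,
                                &scope.scope_key,
                                &ownership_token,
                            )
                            .await?;
                    }
                    Err(err) => {
                        // Leaves the scope dirty so a later pass can retry it.
                        runtime
                            .mark_phase2_job_failed(
                                &scope.scope_kind,
                                &scope.scope_key,
                                &ownership_token,
                                &err.to_string(),
                            )
                            .await?;
                    }
                }
            }
            // Someone else holds a fresh claim, or the scope is no longer dirty.
            Phase2JobClaimOutcome::SkippedRunning | Phase2JobClaimOutcome::SkippedNotDirty => {}
        }
    }
    Ok(())
}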
Author: jif-oai
Date: 2026-02-10 14:31:16 +00:00 (committed by GitHub)
parent d735df1f50
commit e57892b211
5 changed files with 907 additions and 252 deletions


@@ -60,6 +60,19 @@ pub enum Phase1JobClaimOutcome {
SkippedRunning,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DirtyMemoryScope {
pub scope_kind: String,
pub scope_key: String,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Phase2JobClaimOutcome {
Claimed { ownership_token: String },
SkippedNotDirty,
SkippedRunning,
}
impl StateRuntime {
/// Initialize the state runtime using the provided Codex home and default provider.
///
@@ -1040,62 +1053,342 @@ ON CONFLICT(scope_kind, scope_key) DO UPDATE SET
Ok(())
}
-/// Try to acquire or renew the per-cwd memory consolidation lock.
-///
-/// Returns `true` when the lock is acquired/renewed for `working_thread_id`.
-/// Returns `false` when another owner holds a non-expired lease.
-pub async fn try_acquire_memory_consolidation_lock(
/// List scopes that currently require phase-2 consolidation.
pub async fn list_dirty_memory_scopes(
&self,
-cwd: &Path,
-working_thread_id: ThreadId,
-lease_seconds: i64,
-) -> anyhow::Result<bool> {
-let now = Utc::now().timestamp();
-let stale_cutoff = now.saturating_sub(lease_seconds.max(0));
-let result = sqlx::query(
limit: usize,
) -> anyhow::Result<Vec<DirtyMemoryScope>> {
if limit == 0 {
return Ok(Vec::new());
}
let rows = sqlx::query(
r#"
-INSERT INTO memory_consolidation_locks (
-cwd,
-working_thread_id,
-updated_at
-) VALUES (?, ?, ?)
-ON CONFLICT(cwd) DO UPDATE SET
-working_thread_id = excluded.working_thread_id,
-updated_at = excluded.updated_at
-WHERE memory_consolidation_locks.working_thread_id = excluded.working_thread_id
-OR memory_consolidation_locks.updated_at <= ?
SELECT scope_kind, scope_key
FROM memory_scope_dirty
WHERE dirty = 1
ORDER BY updated_at DESC, scope_kind ASC, scope_key ASC
LIMIT ?
"#,
)
-.bind(cwd.display().to_string())
-.bind(working_thread_id.to_string())
-.bind(now)
-.bind(stale_cutoff)
-.execute(self.pool.as_ref())
.bind(limit as i64)
.fetch_all(self.pool.as_ref())
.await?;
-Ok(result.rows_affected() > 0)
rows.into_iter()
.map(|row| {
Ok(DirtyMemoryScope {
scope_kind: row.try_get("scope_kind")?,
scope_key: row.try_get("scope_key")?,
})
})
.collect()
}
-/// Release the per-cwd memory consolidation lock if held by `working_thread_id`.
-///
-/// Returns `true` when a lock row was removed.
-pub async fn release_memory_consolidation_lock(
/// Try to claim a phase-2 consolidation job for `(scope_kind, scope_key)`.
pub async fn try_claim_phase2_job(
&self,
-cwd: &Path,
-working_thread_id: ThreadId,
scope_kind: &str,
scope_key: &str,
owner_session_id: ThreadId,
lease_seconds: i64,
) -> anyhow::Result<Phase2JobClaimOutcome> {
const CAS_RETRY_LIMIT: usize = 3;
for _ in 0..CAS_RETRY_LIMIT {
let now = Utc::now().timestamp();
let stale_cutoff = now.saturating_sub(lease_seconds.max(0));
let ownership_token = Uuid::new_v4().to_string();
let owner_session_id = owner_session_id.to_string();
let mut tx = self.pool.begin().await?;
let dirty_row = sqlx::query(
r#"
SELECT dirty
FROM memory_scope_dirty
WHERE scope_kind = ? AND scope_key = ?
"#,
)
.bind(scope_kind)
.bind(scope_key)
.fetch_optional(&mut *tx)
.await?;
let Some(dirty_row) = dirty_row else {
tx.commit().await?;
return Ok(Phase2JobClaimOutcome::SkippedNotDirty);
};
let dirty: bool = dirty_row.try_get("dirty")?;
if !dirty {
tx.commit().await?;
return Ok(Phase2JobClaimOutcome::SkippedNotDirty);
}
let existing = sqlx::query(
r#"
SELECT status, last_heartbeat_at, attempt
FROM memory_phase2_jobs
WHERE scope_kind = ? AND scope_key = ?
"#,
)
.bind(scope_kind)
.bind(scope_key)
.fetch_optional(&mut *tx)
.await?;
let Some(existing) = existing else {
sqlx::query(
r#"
INSERT INTO memory_phase2_jobs (
scope_kind,
scope_key,
status,
owner_session_id,
agent_thread_id,
started_at,
last_heartbeat_at,
finished_at,
attempt,
failure_reason,
ownership_token
) VALUES (?, ?, 'running', ?, NULL, ?, ?, NULL, 1, NULL, ?)
"#,
)
.bind(scope_kind)
.bind(scope_key)
.bind(owner_session_id.as_str())
.bind(now)
.bind(now)
.bind(ownership_token.as_str())
.execute(&mut *tx)
.await?;
tx.commit().await?;
return Ok(Phase2JobClaimOutcome::Claimed { ownership_token });
};
let status: String = existing.try_get("status")?;
let existing_last_heartbeat_at: Option<i64> = existing.try_get("last_heartbeat_at")?;
let existing_attempt: i64 = existing.try_get("attempt")?;
if status == "running"
&& existing_last_heartbeat_at
.is_some_and(|last_heartbeat_at| last_heartbeat_at > stale_cutoff)
{
tx.commit().await?;
return Ok(Phase2JobClaimOutcome::SkippedRunning);
}
let new_attempt = existing_attempt.saturating_add(1);
let rows_affected = if let Some(existing_last_heartbeat_at) = existing_last_heartbeat_at
{
sqlx::query(
r#"
UPDATE memory_phase2_jobs
SET
status = 'running',
owner_session_id = ?,
agent_thread_id = NULL,
started_at = ?,
last_heartbeat_at = ?,
finished_at = NULL,
attempt = ?,
failure_reason = NULL,
ownership_token = ?
WHERE scope_kind = ? AND scope_key = ?
AND status = ? AND attempt = ? AND last_heartbeat_at = ?
"#,
)
.bind(owner_session_id.as_str())
.bind(now)
.bind(now)
.bind(new_attempt)
.bind(ownership_token.as_str())
.bind(scope_kind)
.bind(scope_key)
.bind(status.as_str())
.bind(existing_attempt)
.bind(existing_last_heartbeat_at)
.execute(&mut *tx)
.await?
.rows_affected()
} else {
sqlx::query(
r#"
UPDATE memory_phase2_jobs
SET
status = 'running',
owner_session_id = ?,
agent_thread_id = NULL,
started_at = ?,
last_heartbeat_at = ?,
finished_at = NULL,
attempt = ?,
failure_reason = NULL,
ownership_token = ?
WHERE scope_kind = ? AND scope_key = ?
AND status = ? AND attempt = ? AND last_heartbeat_at IS NULL
"#,
)
.bind(owner_session_id.as_str())
.bind(now)
.bind(now)
.bind(new_attempt)
.bind(ownership_token.as_str())
.bind(scope_kind)
.bind(scope_key)
.bind(status.as_str())
.bind(existing_attempt)
.execute(&mut *tx)
.await?
.rows_affected()
};
if rows_affected == 0 {
tx.rollback().await?;
continue;
}
tx.commit().await?;
return Ok(Phase2JobClaimOutcome::Claimed { ownership_token });
}
Ok(Phase2JobClaimOutcome::SkippedRunning)
}
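The claim path above reads memory_scope_dirty and upserts into memory_phase2_jobs, but the migration that creates those tables is not part of this excerpt. The snippet below is an assumed sketch of their shapes, inferred only from the columns the queries reference; the actual migration may differ in types, nullability, and indexes.

// Assumed table shapes (inferred from the queries above, not the real migration).
const ASSUMED_PHASE2_TABLES: &str = r#"
CREATE TABLE IF NOT EXISTS memory_scope_dirty (
    scope_kind TEXT NOT NULL,
    scope_key  TEXT NOT NULL,
    dirty      INTEGER NOT NULL,
    updated_at INTEGER NOT NULL,
    PRIMARY KEY (scope_kind, scope_key)
);
CREATE TABLE IF NOT EXISTS memory_phase2_jobs (
    scope_kind        TEXT NOT NULL,
    scope_key         TEXT NOT NULL,
    status            TEXT NOT NULL,
    owner_session_id  TEXT,
    agent_thread_id   TEXT,
    started_at        INTEGER,
    last_heartbeat_at INTEGER,
    finished_at       INTEGER,
    attempt           INTEGER NOT NULL,
    failure_reason    TEXT,
    ownership_token   TEXT,
    PRIMARY KEY (scope_kind, scope_key)
);
"#;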
/// Persist the spawned phase-2 agent id for an owned running job.
pub async fn set_phase2_job_agent_thread_id(
&self,
scope_kind: &str,
scope_key: &str,
ownership_token: &str,
agent_thread_id: ThreadId,
) -> anyhow::Result<bool> {
-let result = sqlx::query(
let now = Utc::now().timestamp();
let rows_affected = sqlx::query(
r#"
-DELETE FROM memory_consolidation_locks
-WHERE cwd = ? AND working_thread_id = ?
UPDATE memory_phase2_jobs
SET
agent_thread_id = ?,
last_heartbeat_at = ?
WHERE scope_kind = ? AND scope_key = ?
AND status = 'running' AND ownership_token = ?
"#,
)
-.bind(cwd.display().to_string())
-.bind(working_thread_id.to_string())
.bind(agent_thread_id.to_string())
.bind(now)
.bind(scope_kind)
.bind(scope_key)
.bind(ownership_token)
.execute(self.pool.as_ref())
.await?
.rows_affected();
Ok(rows_affected > 0)
}
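A claimant that spawns a dedicated agent thread is expected to record that thread on the job row so later inspection of the job can find it. A small sketch of the intended call order; the scope values, the lease, and the caller-supplied agent_thread_id are illustrative assumptions, and only the StateRuntime calls mirror this diff.

// Sketch: claim a job, then attach the agent thread that will do the work.
async fn claim_and_record_agent(
    runtime: &StateRuntime,
    session_id: ThreadId,
    agent_thread_id: ThreadId,
) -> anyhow::Result<()> {
    if let Phase2JobClaimOutcome::Claimed { ownership_token } = runtime
        .try_claim_phase2_job("cwd", "scope", session_id, 300)
        .await?
    {
        // Returns false when the token no longer owns a running job
        // (for example after a stale-lease takeover by another session).
        let recorded = runtime
            .set_phase2_job_agent_thread_id("cwd", "scope", &ownership_token, agent_thread_id)
            .await?;
        anyhow::ensure!(recorded, "lost phase-2 ownership before recording the agent");
    }
    Ok(())
}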
/// Refresh heartbeat timestamp for an owned running phase-2 job.
pub async fn heartbeat_phase2_job(
&self,
scope_kind: &str,
scope_key: &str,
ownership_token: &str,
) -> anyhow::Result<bool> {
let now = Utc::now().timestamp();
let rows_affected = sqlx::query(
r#"
UPDATE memory_phase2_jobs
SET last_heartbeat_at = ?
WHERE scope_kind = ? AND scope_key = ?
AND status = 'running' AND ownership_token = ?
"#,
)
.bind(now)
.bind(scope_kind)
.bind(scope_key)
.bind(ownership_token)
.execute(self.pool.as_ref())
.await?
.rows_affected();
Ok(rows_affected > 0)
}
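A running claim stays protected only while last_heartbeat_at is newer than the lease cutoff, so the owner should heartbeat at an interval well below lease_seconds. A minimal sketch, assuming a 30-second tick and a oneshot channel as the stop signal; both are assumptions, not part of this commit.

// Sketch: keep the claimed job's lease fresh until the work signals completion.
async fn heartbeat_until_done(
    runtime: &StateRuntime,
    scope_kind: &str,
    scope_key: &str,
    ownership_token: &str,
    mut done: tokio::sync::oneshot::Receiver<()>,
) -> anyhow::Result<()> {
    // Assumed cadence: much shorter than the lease passed to try_claim_phase2_job.
    let mut ticker = tokio::time::interval(std::time::Duration::from_secs(30));
    loop {
        tokio::select! {
            _ = &mut done => return Ok(()),
            _ = ticker.tick() => {
                // `false` means this token no longer owns a running job; stop quietly.
                if !runtime
                    .heartbeat_phase2_job(scope_kind, scope_key, ownership_token)
                    .await?
                {
                    return Ok(());
                }
            }
        }
    }
}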
/// Finalize a claimed phase-2 job as succeeded and clear dirty state.
pub async fn mark_phase2_job_succeeded(
&self,
scope_kind: &str,
scope_key: &str,
ownership_token: &str,
) -> anyhow::Result<bool> {
let now = Utc::now().timestamp();
let mut tx = self.pool.begin().await?;
let rows_affected = sqlx::query(
r#"
UPDATE memory_phase2_jobs
SET
status = 'succeeded',
finished_at = ?,
failure_reason = NULL
WHERE scope_kind = ? AND scope_key = ?
AND status = 'running' AND ownership_token = ?
"#,
)
.bind(now)
.bind(scope_kind)
.bind(scope_key)
.bind(ownership_token)
.execute(&mut *tx)
.await?
.rows_affected();
if rows_affected == 0 {
tx.commit().await?;
return Ok(false);
}
sqlx::query(
r#"
UPDATE memory_scope_dirty
SET dirty = 0, updated_at = ?
WHERE scope_kind = ? AND scope_key = ?
"#,
)
.bind(now)
.bind(scope_kind)
.bind(scope_key)
.execute(&mut *tx)
.await?;
-Ok(result.rows_affected() > 0)
tx.commit().await?;
Ok(true)
}
/// Finalize a claimed phase-2 job as failed, leaving dirty scope set.
pub async fn mark_phase2_job_failed(
&self,
scope_kind: &str,
scope_key: &str,
ownership_token: &str,
failure_reason: &str,
) -> anyhow::Result<bool> {
let now = Utc::now().timestamp();
let rows_affected = sqlx::query(
r#"
UPDATE memory_phase2_jobs
SET
status = 'failed',
finished_at = ?,
failure_reason = ?
WHERE scope_kind = ? AND scope_key = ?
AND status = 'running' AND ownership_token = ?
"#,
)
.bind(now)
.bind(failure_reason)
.bind(scope_kind)
.bind(scope_key)
.bind(ownership_token)
.execute(self.pool.as_ref())
.await?
.rows_affected();
Ok(rows_affected > 0)
}
/// Persist dynamic tools for a thread if none have been stored yet.
@@ -1478,6 +1771,7 @@ fn push_thread_order_and_limit(
#[cfg(test)]
mod tests {
use super::Phase1JobClaimOutcome;
use super::Phase2JobClaimOutcome;
use super::STATE_DB_FILENAME;
use super::STATE_DB_VERSION;
use super::StateRuntime;
@@ -1845,90 +2139,259 @@ mod tests {
}
#[tokio::test]
-async fn memory_consolidation_lock_enforces_owner_and_release() {
async fn phase2_job_claim_requires_dirty_scope() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
-let cwd = codex_home.join("workspace");
-let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
-let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
let claim_without_dirty = runtime
.try_claim_phase2_job("cwd", "scope", owner, 3600)
.await
.expect("claim without dirty");
assert_eq!(claim_without_dirty, Phase2JobClaimOutcome::SkippedNotDirty);
-assert!(
-runtime
-.try_acquire_memory_consolidation_lock(cwd.as_path(), owner_a, 600)
-.await
-.expect("acquire for owner_a"),
-"owner_a should acquire lock"
runtime
.mark_memory_scope_dirty("cwd", "scope", false)
.await
.expect("mark dirty false");
let claim_with_false_dirty = runtime
.try_claim_phase2_job("cwd", "scope", owner, 3600)
.await
.expect("claim with false dirty");
assert_eq!(
claim_with_false_dirty,
Phase2JobClaimOutcome::SkippedNotDirty
);
runtime
.mark_memory_scope_dirty("cwd", "scope", true)
.await
.expect("mark dirty true");
let claim_with_dirty = runtime
.try_claim_phase2_job("cwd", "scope", owner, 3600)
.await
.expect("claim with dirty");
-assert!(
-!runtime
-.try_acquire_memory_consolidation_lock(cwd.as_path(), owner_b, 600)
-.await
-.expect("acquire for owner_b should fail"),
-"owner_b should not steal active lock"
-);
-assert!(
-runtime
-.try_acquire_memory_consolidation_lock(cwd.as_path(), owner_a, 600)
-.await
-.expect("owner_a should renew lock"),
-"owner_a should renew lock"
-);
-assert!(
-!runtime
-.release_memory_consolidation_lock(cwd.as_path(), owner_b)
-.await
-.expect("owner_b release should be no-op"),
-"non-owner release should not remove lock"
-);
-assert!(
-runtime
-.release_memory_consolidation_lock(cwd.as_path(), owner_a)
-.await
-.expect("owner_a release"),
-"owner_a should release lock"
-);
assert!(
-runtime
-.try_acquire_memory_consolidation_lock(cwd.as_path(), owner_b, 600)
-.await
-.expect("owner_b acquire after release"),
-"owner_b should acquire released lock"
matches!(claim_with_dirty, Phase2JobClaimOutcome::Claimed { .. }),
"dirty scope should be claimable"
);
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
-async fn memory_consolidation_lock_can_be_stolen_when_lease_expired() {
async fn phase2_running_job_skips_fresh_claims_and_allows_stale_steal() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
-let cwd = codex_home.join("workspace");
-let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
-let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
runtime
.mark_memory_scope_dirty("cwd", "scope", true)
.await
.expect("mark dirty true");
let claim_a = runtime
.try_claim_phase2_job("cwd", "scope", owner_a, 3600)
.await
.expect("claim owner_a");
let owner_a_token = match claim_a {
Phase2JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected claim outcome: {other:?}"),
};
let fresh_claim_b = runtime
.try_claim_phase2_job("cwd", "scope", owner_b, 3600)
.await
.expect("fresh claim owner_b");
assert_eq!(fresh_claim_b, Phase2JobClaimOutcome::SkippedRunning);
assert!(
runtime
-.try_acquire_memory_consolidation_lock(cwd.as_path(), owner_a, 600)
.heartbeat_phase2_job("cwd", "scope", owner_a_token.as_str())
.await
.expect("owner_a acquire")
.expect("owner_a heartbeat"),
"current owner should heartbeat"
);
assert!(
!runtime
.heartbeat_phase2_job("cwd", "scope", "wrong-token")
.await
.expect("wrong token heartbeat"),
"wrong token should not heartbeat"
);
let stale_claim_b = runtime
.try_claim_phase2_job("cwd", "scope", owner_b, 0)
.await
.expect("stale claim owner_b");
let owner_b_token = match stale_claim_b {
Phase2JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected stale claim outcome: {other:?}"),
};
assert!(
!runtime
.heartbeat_phase2_job("cwd", "scope", owner_a_token.as_str())
.await
.expect("stale owner heartbeat"),
"stale owner should lose heartbeat ownership"
);
assert!(
runtime
-.try_acquire_memory_consolidation_lock(cwd.as_path(), owner_b, 0)
.heartbeat_phase2_job("cwd", "scope", owner_b_token.as_str())
.await
.expect("owner_b steal with expired lease"),
"owner_b should steal lock when lease cutoff marks previous lock stale"
.expect("new owner heartbeat"),
"new owner should heartbeat"
);
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn phase2_success_requires_owner_and_clears_dirty_scope() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let owner = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
runtime
.mark_memory_scope_dirty("cwd", "scope", true)
.await
.expect("mark dirty true");
let claim = runtime
.try_claim_phase2_job("cwd", "scope", owner, 3600)
.await
.expect("claim");
let ownership_token = match claim {
Phase2JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected claim outcome: {other:?}"),
};
assert!(
!runtime
.mark_phase2_job_succeeded("cwd", "scope", "wrong-token")
.await
.expect("wrong token success should fail"),
"wrong token should not finalize phase2 job"
);
let dirty_after_wrong_token = sqlx::query(
"SELECT dirty FROM memory_scope_dirty WHERE scope_kind = ? AND scope_key = ?",
)
.bind("cwd")
.bind("scope")
.fetch_one(runtime.pool.as_ref())
.await
.expect("fetch dirty after wrong token")
.try_get::<bool, _>("dirty")
.expect("dirty value");
assert!(dirty_after_wrong_token, "dirty scope should remain dirty");
assert!(
runtime
.mark_phase2_job_succeeded("cwd", "scope", ownership_token.as_str())
.await
.expect("owner success should pass"),
"owner token should finalize phase2 job"
);
let dirty_after_success = sqlx::query(
"SELECT dirty FROM memory_scope_dirty WHERE scope_kind = ? AND scope_key = ?",
)
.bind("cwd")
.bind("scope")
.fetch_one(runtime.pool.as_ref())
.await
.expect("fetch dirty after success")
.try_get::<bool, _>("dirty")
.expect("dirty value");
assert!(
!dirty_after_success,
"successful phase2 finalization should clear dirty scope"
);
let dirty_scopes = runtime
.list_dirty_memory_scopes(10)
.await
.expect("list dirty scopes");
assert_eq!(dirty_scopes, Vec::new());
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
#[tokio::test]
async fn phase2_failure_keeps_scope_dirty_and_allows_retry() {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
.await
.expect("initialize runtime");
let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("owner id");
runtime
.mark_memory_scope_dirty("cwd", "scope", true)
.await
.expect("mark dirty true");
let claim_a = runtime
.try_claim_phase2_job("cwd", "scope", owner_a, 3600)
.await
.expect("claim owner_a");
let owner_a_token = match claim_a {
Phase2JobClaimOutcome::Claimed { ownership_token } => ownership_token,
other => panic!("unexpected claim outcome: {other:?}"),
};
assert!(
runtime
.mark_phase2_job_failed(
"cwd",
"scope",
owner_a_token.as_str(),
"consolidation failed",
)
.await
.expect("mark phase2 failed"),
"owner token should fail phase2 job"
);
let dirty_scopes = runtime
.list_dirty_memory_scopes(10)
.await
.expect("list dirty scopes");
assert_eq!(
dirty_scopes,
vec![super::DirtyMemoryScope {
scope_kind: "cwd".to_string(),
scope_key: "scope".to_string(),
}]
);
let claim_b = runtime
.try_claim_phase2_job("cwd", "scope", owner_b, 3600)
.await
.expect("claim owner_b");
assert!(
matches!(claim_b, Phase2JobClaimOutcome::Claimed { .. }),
"failed jobs should be retryable while dirty"
);
let attempt = sqlx::query(
"SELECT attempt FROM memory_phase2_jobs WHERE scope_kind = ? AND scope_key = ?",
)
.bind("cwd")
.bind("scope")
.fetch_one(runtime.pool.as_ref())
.await
.expect("fetch attempt")
.try_get::<i64, _>("attempt")
.expect("attempt value");
assert_eq!(attempt, 2);
let _ = tokio::fs::remove_dir_all(codex_home).await;
}
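The new tests cover claiming, heartbeats, success, and failure; the limit handling of list_dirty_memory_scopes (including the limit == 0 early return) is not exercised above. A hedged sketch of such a test, reusing the same helpers:

#[tokio::test]
async fn list_dirty_memory_scopes_respects_limit() {
    let codex_home = unique_temp_dir();
    let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
        .await
        .expect("initialize runtime");
    runtime
        .mark_memory_scope_dirty("cwd", "scope-a", true)
        .await
        .expect("mark scope-a dirty");
    runtime
        .mark_memory_scope_dirty("cwd", "scope-b", true)
        .await
        .expect("mark scope-b dirty");
    assert_eq!(
        runtime
            .list_dirty_memory_scopes(0)
            .await
            .expect("limit zero"),
        Vec::new(),
        "limit of zero should short-circuit to an empty list"
    );
    assert_eq!(
        runtime
            .list_dirty_memory_scopes(1)
            .await
            .expect("limit one")
            .len(),
        1,
        "limit should cap the number of returned scopes"
    );
    let _ = tokio::fs::remove_dir_all(codex_home).await;
}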
#[tokio::test]
async fn phase1_job_claim_and_success_require_current_owner_token() {
let codex_home = unique_temp_dir();