Files
codex/codex-rs/state/src/runtime.rs
Celia Chen fb2df99cf1 [feat] persist thread_dynamic_tools in db (#10252)
Persist thread_dynamic_tools in SQLite and read from it first, falling back
to rollout files when no entry is found. Dynamic tools are persisted to both
SQLite and rollout files.

Verified that new sessions are written to the DB correctly and that old
sessions are backfilled correctly at startup:
```
celia@com-92114 codex-rs % sqlite3 ~/.codex/state.sqlite \      "select thread_id, position,name,description,input_schema from thread_dynamic_tools;"
019c0cad-ec0d-74b2-a787-e8b33a349117|0|geo_lookup|lookup a city|{"properties":{"city":{"type":"string"}},"required":["city"],"type":"object"}
....
019c10ca-aa4b-7620-ae40-c0919fbd7ea7|0|geo_lookup|lookup a city|{"properties":{"city":{"type":"string"}},"required":["city"],"type":"object"}
```
2026-02-03 00:06:44 +00:00

703 lines
22 KiB
Rust

use crate::DB_ERROR_METRIC;
use crate::LogEntry;
use crate::LogQuery;
use crate::LogRow;
use crate::SortKey;
use crate::ThreadMetadata;
use crate::ThreadMetadataBuilder;
use crate::ThreadsPage;
use crate::apply_rollout_item;
use crate::migrations::MIGRATOR;
use crate::model::ThreadRow;
use crate::model::anchor_from_item;
use crate::model::datetime_to_epoch_seconds;
use crate::paths::file_modified_time_utc;
use chrono::DateTime;
use chrono::Utc;
use codex_otel::OtelManager;
use codex_protocol::ThreadId;
use codex_protocol::dynamic_tools::DynamicToolSpec;
use codex_protocol::protocol::RolloutItem;
use log::LevelFilter;
use serde_json::Value;
use sqlx::ConnectOptions;
use sqlx::QueryBuilder;
use sqlx::Row;
use sqlx::Sqlite;
use sqlx::SqlitePool;
use sqlx::sqlite::SqliteConnectOptions;
use sqlx::sqlite::SqliteJournalMode;
use sqlx::sqlite::SqlitePoolOptions;
use sqlx::sqlite::SqliteSynchronous;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tracing::warn;
/// File name of the SQLite state database, placed directly inside the Codex home directory.
pub const STATE_DB_FILENAME: &str = "state.sqlite";

/// Counter name reporting the outcome (`open_error`, `opened`, `created`) of opening the state DB.
const METRIC_DB_INIT: &str = "codex.db.init";
#[derive(Clone)]
pub struct StateRuntime {
codex_home: PathBuf,
default_provider: String,
pool: Arc<sqlx::SqlitePool>,
}
impl StateRuntime {
    /// Initialize the state runtime using the provided Codex home and default provider.
    ///
    /// Creates `codex_home` if needed, then opens (and migrates) the SQLite database at
    /// `codex_home/state.sqlite`. When `otel` is provided, the `codex.db.init` counter is
    /// incremented with status `open_error` on failure, `opened` on success, and additionally
    /// `created` when the database file did not exist before this call.
    ///
    /// # Errors
    ///
    /// Returns an error if the directory cannot be created or the database cannot be opened
    /// or migrated.
    pub async fn init(
        codex_home: PathBuf,
        default_provider: String,
        otel: Option<OtelManager>,
    ) -> anyhow::Result<Arc<Self>> {
        tokio::fs::create_dir_all(&codex_home).await?;
        let state_path = codex_home.join(STATE_DB_FILENAME);
        // Must be checked before opening: `open_sqlite` creates the file when it is missing.
        let existed = tokio::fs::try_exists(&state_path).await.unwrap_or(false);
        let pool = match open_sqlite(&state_path).await {
            Ok(db) => Arc::new(db),
            Err(err) => {
                warn!("failed to open state db at {}: {err}", state_path.display());
                if let Some(otel) = otel.as_ref() {
                    otel.counter(METRIC_DB_INIT, 1, &[("status", "open_error")]);
                }
                return Err(err);
            }
        };
        if let Some(otel) = otel.as_ref() {
            otel.counter(METRIC_DB_INIT, 1, &[("status", "opened")]);
        }
        let runtime = Arc::new(Self {
            pool,
            codex_home,
            default_provider,
        });
        if !existed && let Some(otel) = otel.as_ref() {
            otel.counter(METRIC_DB_INIT, 1, &[("status", "created")]);
        }
        Ok(runtime)
    }

    /// Return the configured Codex home directory for this runtime.
    pub fn codex_home(&self) -> &Path {
        self.codex_home.as_path()
    }

    /// Load thread metadata by id using the underlying database.
    ///
    /// Returns `Ok(None)` when no `threads` row has this id.
    pub async fn get_thread(&self, id: ThreadId) -> anyhow::Result<Option<crate::ThreadMetadata>> {
        let row = sqlx::query(
            r#"
SELECT
id,
rollout_path,
created_at,
updated_at,
source,
model_provider,
cwd,
title,
sandbox_policy,
approval_mode,
tokens_used,
has_user_event,
archived_at,
git_sha,
git_branch,
git_origin_url
FROM threads
WHERE id = ?
"#,
        )
        .bind(id.to_string())
        .fetch_optional(self.pool.as_ref())
        .await?;
        row.map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
            .transpose()
    }

    /// Get dynamic tools for a thread, if present.
    ///
    /// Tools are returned in ascending `position` order, i.e. the order they were persisted.
    /// Returns `Ok(None)` when no tools are stored for the thread, so callers can fall back to
    /// another source.
    ///
    /// # Errors
    ///
    /// Fails on database errors or when a stored `input_schema` is not valid JSON.
    pub async fn get_dynamic_tools(
        &self,
        thread_id: ThreadId,
    ) -> anyhow::Result<Option<Vec<DynamicToolSpec>>> {
        let rows = sqlx::query(
            r#"
SELECT name, description, input_schema
FROM thread_dynamic_tools
WHERE thread_id = ?
ORDER BY position ASC
"#,
        )
        .bind(thread_id.to_string())
        .fetch_all(self.pool.as_ref())
        .await?;
        if rows.is_empty() {
            return Ok(None);
        }
        let tools = rows
            .iter()
            .map(|row| -> anyhow::Result<DynamicToolSpec> {
                // The schema is stored as serialized JSON text.
                let input_schema: String = row.try_get("input_schema")?;
                let input_schema = serde_json::from_str::<Value>(input_schema.as_str())?;
                Ok(DynamicToolSpec {
                    name: row.try_get("name")?,
                    description: row.try_get("description")?,
                    input_schema,
                })
            })
            .collect::<anyhow::Result<Vec<_>>>()?;
        Ok(Some(tools))
    }

    /// Find a rollout path by thread id using the underlying database.
    ///
    /// `archived_only` restricts the match to archived (`Some(true)`) or unarchived
    /// (`Some(false)`) threads; `None` matches either. Returns `Ok(None)` when no row matches
    /// or the stored path is NULL.
    ///
    /// # Errors
    ///
    /// Database errors, including a `rollout_path` value that cannot be decoded as text, are
    /// propagated instead of being reported as "not found".
    pub async fn find_rollout_path_by_id(
        &self,
        id: ThreadId,
        archived_only: Option<bool>,
    ) -> anyhow::Result<Option<PathBuf>> {
        let mut builder =
            QueryBuilder::<Sqlite>::new("SELECT rollout_path FROM threads WHERE id = ");
        builder.push_bind(id.to_string());
        match archived_only {
            Some(true) => {
                builder.push(" AND archived = 1");
            }
            Some(false) => {
                builder.push(" AND archived = 0");
            }
            None => {}
        }
        let row = builder.build().fetch_optional(self.pool.as_ref()).await?;
        // Decode as `Option<String>` so a NULL path still means "not found", while genuine
        // decode failures surface as errors rather than being silently dropped.
        let rollout_path = row
            .map(|row| row.try_get::<Option<String>, _>("rollout_path"))
            .transpose()?
            .flatten();
        Ok(rollout_path.map(PathBuf::from))
    }

    /// List threads using the underlying database.
    ///
    /// Returns at most `page_size` threads that have a user event, ordered by `sort_key`
    /// descending with ties broken by descending id. `anchor` resumes after a previous page.
    /// `allowed_sources` / `model_providers` restrict the result when non-empty. The returned
    /// `next_anchor` is set only when more rows exist. `num_scanned_rows` counts every fetched
    /// row, including the one look-ahead row.
    pub async fn list_threads(
        &self,
        page_size: usize,
        anchor: Option<&crate::Anchor>,
        sort_key: crate::SortKey,
        allowed_sources: &[String],
        model_providers: Option<&[String]>,
        archived_only: bool,
    ) -> anyhow::Result<crate::ThreadsPage> {
        // Fetch one extra row so we can tell whether another page exists.
        let limit = page_size.saturating_add(1);
        let mut builder = QueryBuilder::<Sqlite>::new(
            r#"
SELECT
id,
rollout_path,
created_at,
updated_at,
source,
model_provider,
cwd,
title,
sandbox_policy,
approval_mode,
tokens_used,
has_user_event,
archived_at,
git_sha,
git_branch,
git_origin_url
FROM threads
"#,
        );
        push_thread_filters(
            &mut builder,
            archived_only,
            allowed_sources,
            model_providers,
            anchor,
            sort_key,
        );
        push_thread_order_and_limit(&mut builder, sort_key, limit);
        let rows = builder.build().fetch_all(self.pool.as_ref()).await?;
        let mut items = rows
            .into_iter()
            .map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
            .collect::<Result<Vec<_>, _>>()?;
        let num_scanned_rows = items.len();
        let next_anchor = if items.len() > page_size {
            // Drop the look-ahead row; the next page starts after the last kept item.
            items.pop();
            items
                .last()
                .and_then(|item| anchor_from_item(item, sort_key))
        } else {
            None
        };
        Ok(ThreadsPage {
            items,
            next_anchor,
            num_scanned_rows,
        })
    }

    /// Insert one log entry into the logs table.
    pub async fn insert_log(&self, entry: &LogEntry) -> anyhow::Result<()> {
        self.insert_logs(std::slice::from_ref(entry)).await
    }

    /// Insert a batch of log entries into the logs table with a single multi-row `INSERT`.
    ///
    /// An empty batch is a no-op.
    // NOTE(review): each row binds 9 parameters; very large batches could exceed SQLite's
    // bound-parameter limit. Callers presumably keep batches small — confirm.
    pub async fn insert_logs(&self, entries: &[LogEntry]) -> anyhow::Result<()> {
        if entries.is_empty() {
            return Ok(());
        }
        let mut builder = QueryBuilder::<Sqlite>::new(
            "INSERT INTO logs (ts, ts_nanos, level, target, message, thread_id, module_path, file, line) ",
        );
        builder.push_values(entries, |mut row, entry| {
            row.push_bind(entry.ts)
                .push_bind(entry.ts_nanos)
                .push_bind(&entry.level)
                .push_bind(&entry.target)
                .push_bind(&entry.message)
                .push_bind(&entry.thread_id)
                .push_bind(&entry.module_path)
                .push_bind(&entry.file)
                .push_bind(entry.line);
        });
        builder.build().execute(self.pool.as_ref()).await?;
        Ok(())
    }

    /// Delete log rows whose `ts` is strictly less than `cutoff_ts`.
    ///
    /// Returns the number of deleted rows.
    pub(crate) async fn delete_logs_before(&self, cutoff_ts: i64) -> anyhow::Result<u64> {
        let result = sqlx::query("DELETE FROM logs WHERE ts < ?")
            .bind(cutoff_ts)
            .execute(self.pool.as_ref())
            .await?;
        Ok(result.rows_affected())
    }

    /// Query logs with optional filters.
    ///
    /// Rows are ordered by `id`, descending when `query.descending` is set and ascending
    /// otherwise, and are truncated to `query.limit` when given.
    pub async fn query_logs(&self, query: &LogQuery) -> anyhow::Result<Vec<LogRow>> {
        let mut builder = QueryBuilder::<Sqlite>::new(
            "SELECT id, ts, ts_nanos, level, target, message, thread_id, file, line FROM logs WHERE 1 = 1",
        );
        push_log_filters(&mut builder, query);
        if query.descending {
            builder.push(" ORDER BY id DESC");
        } else {
            builder.push(" ORDER BY id ASC");
        }
        if let Some(limit) = query.limit {
            // Clamp rather than `as`-cast, which could wrap a huge unsigned limit to a negative.
            builder
                .push(" LIMIT ")
                .push_bind(i64::try_from(limit).unwrap_or(i64::MAX));
        }
        let rows = builder
            .build_query_as::<LogRow>()
            .fetch_all(self.pool.as_ref())
            .await?;
        Ok(rows)
    }

    /// Return the max log id matching optional filters, or `0` when no row matches.
    pub async fn max_log_id(&self, query: &LogQuery) -> anyhow::Result<i64> {
        let mut builder =
            QueryBuilder::<Sqlite>::new("SELECT MAX(id) AS max_id FROM logs WHERE 1 = 1");
        push_log_filters(&mut builder, query);
        let row = builder.build().fetch_one(self.pool.as_ref()).await?;
        // `MAX` over zero rows yields NULL.
        let max_id: Option<i64> = row.try_get("max_id")?;
        Ok(max_id.unwrap_or(0))
    }

    /// List thread ids using the underlying database (no rollout scanning).
    ///
    /// Uses the same filters and ordering as [`Self::list_threads`] but returns at most `limit`
    /// ids, with no look-ahead row.
    pub async fn list_thread_ids(
        &self,
        limit: usize,
        anchor: Option<&crate::Anchor>,
        sort_key: crate::SortKey,
        allowed_sources: &[String],
        model_providers: Option<&[String]>,
        archived_only: bool,
    ) -> anyhow::Result<Vec<ThreadId>> {
        let mut builder = QueryBuilder::<Sqlite>::new("SELECT id FROM threads");
        push_thread_filters(
            &mut builder,
            archived_only,
            allowed_sources,
            model_providers,
            anchor,
            sort_key,
        );
        push_thread_order_and_limit(&mut builder, sort_key, limit);
        let rows = builder.build().fetch_all(self.pool.as_ref()).await?;
        rows.into_iter()
            .map(|row| {
                let id: String = row.try_get("id")?;
                Ok(ThreadId::try_from(id)?)
            })
            .collect()
    }

    /// Insert or replace thread metadata directly.
    ///
    /// On an id conflict, every column is overwritten. `archived` is derived from whether
    /// `archived_at` is set. Timestamps are converted with `datetime_to_epoch_seconds`.
    pub async fn upsert_thread(&self, metadata: &crate::ThreadMetadata) -> anyhow::Result<()> {
        sqlx::query(
            r#"
INSERT INTO threads (
id,
rollout_path,
created_at,
updated_at,
source,
model_provider,
cwd,
title,
sandbox_policy,
approval_mode,
tokens_used,
has_user_event,
archived,
archived_at,
git_sha,
git_branch,
git_origin_url
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
rollout_path = excluded.rollout_path,
created_at = excluded.created_at,
updated_at = excluded.updated_at,
source = excluded.source,
model_provider = excluded.model_provider,
cwd = excluded.cwd,
title = excluded.title,
sandbox_policy = excluded.sandbox_policy,
approval_mode = excluded.approval_mode,
tokens_used = excluded.tokens_used,
has_user_event = excluded.has_user_event,
archived = excluded.archived,
archived_at = excluded.archived_at,
git_sha = excluded.git_sha,
git_branch = excluded.git_branch,
git_origin_url = excluded.git_origin_url
"#,
        )
        .bind(metadata.id.to_string())
        .bind(metadata.rollout_path.display().to_string())
        .bind(datetime_to_epoch_seconds(metadata.created_at))
        .bind(datetime_to_epoch_seconds(metadata.updated_at))
        .bind(metadata.source.as_str())
        .bind(metadata.model_provider.as_str())
        .bind(metadata.cwd.display().to_string())
        .bind(metadata.title.as_str())
        .bind(metadata.sandbox_policy.as_str())
        .bind(metadata.approval_mode.as_str())
        .bind(metadata.tokens_used)
        .bind(metadata.has_user_event)
        .bind(metadata.archived_at.is_some())
        .bind(metadata.archived_at.map(datetime_to_epoch_seconds))
        .bind(metadata.git_sha.as_deref())
        .bind(metadata.git_branch.as_deref())
        .bind(metadata.git_origin_url.as_deref())
        .execute(self.pool.as_ref())
        .await?;
        Ok(())
    }

    /// Persist dynamic tools for a thread if none have been stored yet.
    ///
    /// Dynamic tools are defined at thread start and should not change afterward.
    /// This only writes the first time we see tools for a given thread. `None` or an empty
    /// slice is a no-op. Each tool is stored with its slice index as `position`, so
    /// [`Self::get_dynamic_tools`] returns them in the same order. The existence check and
    /// the inserts run inside one transaction.
    ///
    /// The thread's `threads` row must already exist, because `thread_dynamic_tools.thread_id`
    /// references `threads.id`.
    pub async fn persist_dynamic_tools(
        &self,
        thread_id: ThreadId,
        tools: Option<&[DynamicToolSpec]>,
    ) -> anyhow::Result<()> {
        let Some(tools) = tools.filter(|tools| !tools.is_empty()) else {
            return Ok(());
        };
        let mut tx = self.pool.begin().await?;
        let thread_id = thread_id.to_string();
        let existing: Option<i64> =
            sqlx::query_scalar("SELECT 1 FROM thread_dynamic_tools WHERE thread_id = ? LIMIT 1")
                .bind(thread_id.as_str())
                .fetch_optional(&mut *tx)
                .await?;
        if existing.is_some() {
            // Tools were already stored for this thread; never overwrite them.
            tx.commit().await?;
            return Ok(());
        }
        for (idx, tool) in tools.iter().enumerate() {
            let position = i64::try_from(idx).unwrap_or(i64::MAX);
            let input_schema = serde_json::to_string(&tool.input_schema)?;
            sqlx::query(
                r#"
INSERT INTO thread_dynamic_tools (
thread_id,
position,
name,
description,
input_schema
) VALUES (?, ?, ?, ?, ?)
"#,
            )
            .bind(thread_id.as_str())
            .bind(position)
            .bind(tool.name.as_str())
            .bind(tool.description.as_str())
            .bind(input_schema)
            .execute(&mut *tx)
            .await?;
        }
        tx.commit().await?;
        Ok(())
    }

    /// Apply rollout items incrementally using the underlying database.
    ///
    /// Loads the thread's existing metadata, or builds fresh metadata from `builder`, and folds
    /// every item into it. It then refreshes `updated_at` from the rollout file's modification
    /// time when available and upserts the row. Finally it persists the dynamic tools from the
    /// first `SessionMeta` item, if any. An empty `items` slice is a no-op.
    ///
    /// # Errors
    ///
    /// Propagates database errors. Failures of the upsert or the dynamic-tool write also
    /// increment `DB_ERROR_METRIC` (tagged with the failing stage) when `otel` is provided.
    pub async fn apply_rollout_items(
        &self,
        builder: &ThreadMetadataBuilder,
        items: &[RolloutItem],
        otel: Option<&OtelManager>,
    ) -> anyhow::Result<()> {
        if items.is_empty() {
            return Ok(());
        }
        let mut metadata = self
            .get_thread(builder.id)
            .await?
            .unwrap_or_else(|| builder.build(&self.default_provider));
        metadata.rollout_path = builder.rollout_path.clone();
        for item in items {
            apply_rollout_item(&mut metadata, item, &self.default_provider);
        }
        if let Some(updated_at) = file_modified_time_utc(builder.rollout_path.as_path()).await {
            metadata.updated_at = updated_at;
        }
        // Keep the thread upsert before dynamic tools to satisfy the foreign key constraint:
        // thread_dynamic_tools.thread_id -> threads.id.
        if let Err(err) = self.upsert_thread(&metadata).await {
            if let Some(otel) = otel {
                otel.counter(DB_ERROR_METRIC, 1, &[("stage", "apply_rollout_items")]);
            }
            return Err(err);
        }
        let dynamic_tools = extract_dynamic_tools(items);
        if let Some(dynamic_tools) = dynamic_tools
            && let Err(err) = self
                .persist_dynamic_tools(builder.id, dynamic_tools.as_deref())
                .await
        {
            if let Some(otel) = otel {
                otel.counter(DB_ERROR_METRIC, 1, &[("stage", "persist_dynamic_tools")]);
            }
            return Err(err);
        }
        Ok(())
    }

    /// Mark a thread as archived using the underlying database.
    ///
    /// Sets `archived_at`, records the (possibly moved) `rollout_path`, and refreshes
    /// `updated_at` from that file's modification time when available. This is a no-op when
    /// the thread has no row.
    pub async fn mark_archived(
        &self,
        thread_id: ThreadId,
        rollout_path: &Path,
        archived_at: DateTime<Utc>,
    ) -> anyhow::Result<()> {
        self.apply_archive_state(thread_id, rollout_path, Some(archived_at), "archive")
            .await
    }

    /// Mark a thread as unarchived using the underlying database.
    ///
    /// Clears `archived_at`, records the (possibly moved) `rollout_path`, and refreshes
    /// `updated_at` from that file's modification time when available. This is a no-op when
    /// the thread has no row.
    pub async fn mark_unarchived(
        &self,
        thread_id: ThreadId,
        rollout_path: &Path,
    ) -> anyhow::Result<()> {
        self.apply_archive_state(thread_id, rollout_path, None, "unarchive")
            .await
    }

    /// Shared body of [`Self::mark_archived`] and [`Self::mark_unarchived`].
    ///
    /// `action` is used only in the id-mismatch warning (`"archive"` / `"unarchive"`).
    async fn apply_archive_state(
        &self,
        thread_id: ThreadId,
        rollout_path: &Path,
        archived_at: Option<DateTime<Utc>>,
        action: &str,
    ) -> anyhow::Result<()> {
        let Some(mut metadata) = self.get_thread(thread_id).await? else {
            return Ok(());
        };
        metadata.archived_at = archived_at;
        metadata.rollout_path = rollout_path.to_path_buf();
        if let Some(updated_at) = file_modified_time_utc(rollout_path).await {
            metadata.updated_at = updated_at;
        }
        if metadata.id != thread_id {
            warn!(
                "thread id mismatch during {action}: expected {thread_id}, got {}",
                metadata.id
            );
        }
        self.upsert_thread(&metadata).await
    }
}
/// Append the optional filters described by `query` to a `logs` query.
///
/// Every filter is appended as ` AND ...`, so the builder must already contain a `WHERE`
/// clause; the callers in this file start from `WHERE 1 = 1`. The thread filter matches any
/// id in `query.thread_ids`. When `query.include_threadless` is set, it also matches rows
/// whose `thread_id` is NULL.
fn push_log_filters<'a>(builder: &mut QueryBuilder<'a, Sqlite>, query: &'a LogQuery) {
    if let Some(level) = &query.level_upper {
        builder.push(" AND UPPER(level) = ").push_bind(level.as_str());
    }
    if let Some(from_ts) = query.from_ts {
        builder.push(" AND ts >= ").push_bind(from_ts);
    }
    if let Some(to_ts) = query.to_ts {
        builder.push(" AND ts <= ").push_bind(to_ts);
    }
    push_like_filters(builder, "module_path", &query.module_like);
    push_like_filters(builder, "file", &query.file_like);

    if query.include_threadless || !query.thread_ids.is_empty() {
        builder.push(" AND (");
        // `separated` inserts " OR " before every clause except the first.
        let mut thread_clauses = builder.separated(" OR ");
        for thread_id in &query.thread_ids {
            thread_clauses.push("thread_id = ");
            thread_clauses.push_bind_unseparated(thread_id.as_str());
        }
        if query.include_threadless {
            thread_clauses.push("thread_id IS NULL");
        }
        thread_clauses.push_unseparated(")");
    }

    if let Some(after_id) = query.after_id {
        builder.push(" AND id > ").push_bind(after_id);
    }
}
/// Append ` AND (<column> LIKE '%' || ? || '%' OR ...)`, a substring match of `column`
/// against any of `filters`. Does nothing when `filters` is empty.
///
/// Filter values are bound as parameters. `%` and `_` inside them are not escaped, so they
/// still behave as LIKE wildcards.
fn push_like_filters<'a>(
    builder: &mut QueryBuilder<'a, Sqlite>,
    column: &str,
    filters: &'a [String],
) {
    if !filters.is_empty() {
        builder.push(" AND (");
        // `separated` inserts " OR " before every alternative except the first.
        let mut alternatives = builder.separated(" OR ");
        for filter in filters {
            alternatives.push(column);
            alternatives.push_unseparated(" LIKE '%' || ");
            alternatives.push_bind_unseparated(filter.as_str());
            alternatives.push_unseparated(" || '%'");
        }
        alternatives.push_unseparated(")");
    }
}
/// Return the dynamic tools recorded by the first `SessionMeta` item in `items`.
///
/// The outer `Option` is `None` when `items` has no `SessionMeta` item. Otherwise it wraps a
/// clone of that session's `dynamic_tools` field, which may itself be `None`.
fn extract_dynamic_tools(items: &[RolloutItem]) -> Option<Option<Vec<DynamicToolSpec>>> {
    for item in items {
        match item {
            RolloutItem::SessionMeta(meta_line) => {
                return Some(meta_line.meta.dynamic_tools.clone());
            }
            RolloutItem::ResponseItem(_)
            | RolloutItem::Compacted(_)
            | RolloutItem::TurnContext(_)
            | RolloutItem::EventMsg(_) => {}
        }
    }
    None
}
/// Open a connection pool for the SQLite database at `path` and run pending migrations.
///
/// The file is created if missing. Connections use WAL journaling, `synchronous = NORMAL`,
/// a 5-second busy timeout, and statement logging turned off. The pool holds at most 5
/// connections.
async fn open_sqlite(path: &Path) -> anyhow::Result<SqlitePool> {
    let connect_options = SqliteConnectOptions::new()
        .filename(path)
        .create_if_missing(true)
        .journal_mode(SqliteJournalMode::Wal)
        .synchronous(SqliteSynchronous::Normal)
        .busy_timeout(Duration::from_secs(5))
        .log_statements(LevelFilter::Off);
    let pool_options = SqlitePoolOptions::new().max_connections(5);
    let pool = pool_options.connect_with(connect_options).await?;
    MIGRATOR.run(&pool).await?;
    Ok(pool)
}
/// Append the `WHERE` clause shared by thread listing queries.
///
/// Always restricts to rows whose `archived` flag matches `archived_only` and that have
/// `has_user_event = 1`. It optionally restricts `source` and `model_provider` to the given
/// lists; an empty list disables that restriction. When `anchor` is given, it keeps only rows
/// strictly after the anchor under `ORDER BY <column> DESC, id DESC`.
fn push_thread_filters<'a>(
    builder: &mut QueryBuilder<'a, Sqlite>,
    archived_only: bool,
    allowed_sources: &'a [String],
    model_providers: Option<&'a [String]>,
    anchor: Option<&crate::Anchor>,
    sort_key: SortKey,
) {
    // Appends ` AND <column> IN (?, ?, ...)` binding every value; skipped for an empty list.
    fn push_in_list<'q>(
        builder: &mut QueryBuilder<'q, Sqlite>,
        column: &str,
        values: &'q [String],
    ) {
        if values.is_empty() {
            return;
        }
        builder.push(" AND ").push(column).push(" IN (");
        let mut list = builder.separated(", ");
        for value in values {
            list.push_bind(value);
        }
        list.push_unseparated(")");
    }

    builder.push(" WHERE 1 = 1");
    builder.push(if archived_only {
        " AND archived = 1"
    } else {
        " AND archived = 0"
    });
    builder.push(" AND has_user_event = 1");
    push_in_list(builder, "source", allowed_sources);
    if let Some(providers) = model_providers {
        push_in_list(builder, "model_provider", providers);
    }

    let Some(anchor) = anchor else {
        return;
    };
    let column = match sort_key {
        SortKey::CreatedAt => "created_at",
        SortKey::UpdatedAt => "updated_at",
    };
    let anchor_ts = datetime_to_epoch_seconds(anchor.ts);
    // Keyset pagination: strictly older timestamp, or same timestamp with a smaller id.
    builder
        .push(" AND (")
        .push(column)
        .push(" < ")
        .push_bind(anchor_ts)
        .push(" OR (")
        .push(column)
        .push(" = ")
        .push_bind(anchor_ts)
        .push(" AND id < ")
        .push_bind(anchor.id.to_string())
        .push("))");
}
/// Append `ORDER BY <sort column> DESC, id DESC LIMIT ?` to a thread query.
///
/// The `id DESC` tiebreaker matches the keyset condition built by `push_thread_filters`, so
/// pagination stays stable when several threads share a timestamp. A `limit` larger than
/// `i64::MAX` is clamped to `i64::MAX` instead of wrapping.
fn push_thread_order_and_limit(
    builder: &mut QueryBuilder<'_, Sqlite>,
    sort_key: SortKey,
    limit: usize,
) {
    let order_column = match sort_key {
        SortKey::CreatedAt => "created_at",
        SortKey::UpdatedAt => "updated_at",
    };
    // SQLite integers are i64; an `as` cast would silently wrap a huge usize to a negative.
    let limit = i64::try_from(limit).unwrap_or(i64::MAX);
    builder.push(" ORDER BY ");
    builder.push(order_column);
    builder.push(" DESC, id DESC");
    builder.push(" LIMIT ");
    builder.push_bind(limit);
}