feat(app-server): thread/rollback API (#8454)

Add `thread/rollback` to app-server to support IDEs undo-ing the last N
turns of a thread.

For context, an IDE partner will be supporting an "undo" capability
where the IDE (the app-server client) will be responsible for reverting
the local changes made during the last turn. To support this well, we
also need a way to drop the last turn (or more generally, the last N
turns) from the agent's context. This is what `thread/rollback` does.

**Core idea**: A Thread rollback is represented as a persisted event
message (EventMsg::ThreadRollback) in the rollout JSONL file, not by
rewriting history. On resume, both the model's context (core replay) and
the UI turn list (app-server v2's thread history builder) apply these
markers so the pruned history is consistent across live conversations
and `thread/resume`.

Implementation notes:
- Rollback only affects agent context and appends to the rollout file;
clients are responsible for reverting files on disk.
- If a thread rollback is currently in progress, subsequent
`thread/rollback` calls are rejected.
- Because we use `CodexConversation::submit` and codex core tracks
active turns, returning an error on concurrent rollbacks is communicated
via an `EventMsg::Error` with a new variant
`CodexErrorInfo::ThreadRollbackFailed`. app-server watches for that and
sends the BAD_REQUEST RPC response.

Tests cover thread rollbacks in both core and app-server, including when
`num_turns` > existing turns (which clears all turns).

**Note**: this explicitly does **not** behave like `/undo` which we just
removed from the CLI, which does the opposite of what `thread/rollback`
does. `/undo` reverts local changes via ghost commits/snapshots and does
not modify the agent's context / conversation history.
This commit is contained in:
Owen Lin
2026-01-06 13:23:48 -08:00
committed by GitHub
parent 188f79afee
commit 8b7ec31ba7
21 changed files with 1187 additions and 30 deletions

View File

@@ -91,6 +91,7 @@ use codex_app_server_protocol::ThreadListParams;
use codex_app_server_protocol::ThreadListResponse;
use codex_app_server_protocol::ThreadResumeParams;
use codex_app_server_protocol::ThreadResumeResponse;
use codex_app_server_protocol::ThreadRollbackParams;
use codex_app_server_protocol::ThreadStartParams;
use codex_app_server_protocol::ThreadStartResponse;
use codex_app_server_protocol::ThreadStartedNotification;
@@ -178,6 +179,8 @@ use uuid::Uuid;
type PendingInterruptQueue = Vec<(RequestId, ApiVersion)>;
pub(crate) type PendingInterrupts = Arc<Mutex<HashMap<ConversationId, PendingInterruptQueue>>>;
pub(crate) type PendingRollbacks = Arc<Mutex<HashMap<ConversationId, RequestId>>>;
/// Per-conversation accumulation of the latest states e.g. error message while a turn runs.
#[derive(Default, Clone)]
pub(crate) struct TurnSummary {
@@ -220,6 +223,8 @@ pub(crate) struct CodexMessageProcessor {
active_login: Arc<Mutex<Option<ActiveLogin>>>,
// Queue of pending interrupt requests per conversation. We reply when TurnAborted arrives.
pending_interrupts: PendingInterrupts,
// Queue of pending rollback requests per conversation. We reply when ThreadRollback arrives.
pending_rollbacks: PendingRollbacks,
turn_summary_store: TurnSummaryStore,
pending_fuzzy_searches: Arc<Mutex<HashMap<String, Arc<AtomicBool>>>>,
feedback: CodexFeedback,
@@ -275,6 +280,7 @@ impl CodexMessageProcessor {
conversation_listeners: HashMap::new(),
active_login: Arc::new(Mutex::new(None)),
pending_interrupts: Arc::new(Mutex::new(HashMap::new())),
pending_rollbacks: Arc::new(Mutex::new(HashMap::new())),
turn_summary_store: Arc::new(Mutex::new(HashMap::new())),
pending_fuzzy_searches: Arc::new(Mutex::new(HashMap::new())),
feedback,
@@ -365,6 +371,9 @@ impl CodexMessageProcessor {
ClientRequest::ThreadArchive { request_id, params } => {
self.thread_archive(request_id, params).await;
}
ClientRequest::ThreadRollback { request_id, params } => {
self.thread_rollback(request_id, params).await;
}
ClientRequest::ThreadList { request_id, params } => {
self.thread_list(request_id, params).await;
}
@@ -1506,6 +1515,52 @@ impl CodexMessageProcessor {
}
}
async fn thread_rollback(&mut self, request_id: RequestId, params: ThreadRollbackParams) {
let ThreadRollbackParams {
thread_id,
num_turns,
} = params;
if num_turns == 0 {
self.send_invalid_request_error(request_id, "numTurns must be >= 1".to_string())
.await;
return;
}
let (conversation_id, conversation) =
match self.conversation_from_thread_id(&thread_id).await {
Ok(v) => v,
Err(error) => {
self.outgoing.send_error(request_id, error).await;
return;
}
};
{
let mut map = self.pending_rollbacks.lock().await;
if map.contains_key(&conversation_id) {
self.send_invalid_request_error(
request_id,
"rollback already in progress for this thread".to_string(),
)
.await;
return;
}
map.insert(conversation_id, request_id.clone());
}
if let Err(err) = conversation.submit(Op::ThreadRollback { num_turns }).await {
// No ThreadRollback event will arrive if an error occurs.
// Clean up and reply immediately.
let mut map = self.pending_rollbacks.lock().await;
map.remove(&conversation_id);
self.send_internal_error(request_id, format!("failed to start rollback: {err}"))
.await;
}
}
async fn thread_list(&self, request_id: RequestId, params: ThreadListParams) {
let ThreadListParams {
cursor,
@@ -3095,8 +3150,10 @@ impl CodexMessageProcessor {
let outgoing_for_task = self.outgoing.clone();
let pending_interrupts = self.pending_interrupts.clone();
let pending_rollbacks = self.pending_rollbacks.clone();
let turn_summary_store = self.turn_summary_store.clone();
let api_version_for_task = api_version;
let fallback_model_provider = self.config.model_provider_id.clone();
tokio::spawn(async move {
loop {
tokio::select! {
@@ -3152,8 +3209,10 @@ impl CodexMessageProcessor {
conversation.clone(),
outgoing_for_task.clone(),
pending_interrupts.clone(),
pending_rollbacks.clone(),
turn_summary_store.clone(),
api_version_for_task,
fallback_model_provider.clone(),
)
.await;
}
@@ -3354,7 +3413,7 @@ async fn derive_config_from_params(
Config::load_with_cli_overrides_and_harness_overrides(cli_overrides, overrides).await
}
async fn read_summary_from_rollout(
pub(crate) async fn read_summary_from_rollout(
path: &Path,
fallback_provider: &str,
) -> std::io::Result<ConversationSummary> {
@@ -3413,6 +3472,24 @@ async fn read_summary_from_rollout(
})
}
pub(crate) async fn read_event_msgs_from_rollout(
path: &Path,
) -> std::io::Result<Vec<codex_protocol::protocol::EventMsg>> {
let items = match RolloutRecorder::get_rollout_history(path).await? {
InitialHistory::New => Vec::new(),
InitialHistory::Forked(items) => items,
InitialHistory::Resumed(resumed) => resumed.history,
};
Ok(items
.into_iter()
.filter_map(|item| match item {
RolloutItem::EventMsg(event) => Some(event),
_ => None,
})
.collect())
}
fn extract_conversation_summary(
path: PathBuf,
head: &[serde_json::Value],
@@ -3474,7 +3551,7 @@ fn parse_datetime(timestamp: Option<&str>) -> Option<DateTime<Utc>> {
})
}
fn summary_to_thread(summary: ConversationSummary) -> Thread {
pub(crate) fn summary_to_thread(summary: ConversationSummary) -> Thread {
let ConversationSummary {
conversation_id,
path,