refactor: codex app-server ThreadState (#11419)

this is a no-op functionality wise. consolidates thread-specific message
processor / event handling state in ThreadState
This commit is contained in:
Max Johnson
2026-02-11 12:20:54 -08:00
committed by GitHub
parent 42e22f3bde
commit b5339a591d
4 changed files with 258 additions and 187 deletions

View File

@@ -1,14 +1,12 @@
use crate::codex_message_processor::ApiVersion;
use crate::codex_message_processor::PendingInterrupts;
use crate::codex_message_processor::PendingRollbacks;
use crate::codex_message_processor::TurnSummary;
use crate::codex_message_processor::TurnSummaryStore;
use crate::codex_message_processor::read_rollout_items_from_rollout;
use crate::codex_message_processor::read_summary_from_rollout;
use crate::codex_message_processor::summary_to_thread;
use crate::error_code::INTERNAL_ERROR_CODE;
use crate::error_code::INVALID_REQUEST_ERROR_CODE;
use crate::outgoing_message::OutgoingMessageSender;
use crate::thread_state::ThreadState;
use crate::thread_state::TurnSummary;
use codex_app_server_protocol::AccountRateLimitsUpdatedNotification;
use codex_app_server_protocol::AgentMessageDeltaNotification;
use codex_app_server_protocol::ApplyPatchApprovalParams;
@@ -98,6 +96,7 @@ use std::collections::HashMap;
use std::convert::TryFrom;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::oneshot;
use tracing::error;
@@ -109,9 +108,7 @@ pub(crate) async fn apply_bespoke_event_handling(
conversation_id: ThreadId,
conversation: Arc<CodexThread>,
outgoing: Arc<OutgoingMessageSender>,
pending_interrupts: PendingInterrupts,
pending_rollbacks: PendingRollbacks,
turn_summary_store: TurnSummaryStore,
thread_state: Arc<tokio::sync::Mutex<ThreadState>>,
api_version: ApiVersion,
fallback_model_provider: String,
) {
@@ -122,13 +119,7 @@ pub(crate) async fn apply_bespoke_event_handling(
match msg {
EventMsg::TurnStarted(_) => {}
EventMsg::TurnComplete(_ev) => {
handle_turn_complete(
conversation_id,
event_turn_id,
&outgoing,
&turn_summary_store,
)
.await;
handle_turn_complete(conversation_id, event_turn_id, &outgoing, &thread_state).await;
}
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
call_id,
@@ -159,9 +150,11 @@ pub(crate) async fn apply_bespoke_event_handling(
let patch_changes = convert_patch_changes(&changes);
let first_start = {
let mut map = turn_summary_store.lock().await;
let summary = map.entry(conversation_id).or_default();
summary.file_change_started.insert(item_id.clone())
let mut state = thread_state.lock().await;
state
.turn_summary
.file_change_started
.insert(item_id.clone())
};
if first_start {
let item = ThreadItem::FileChange {
@@ -198,7 +191,7 @@ pub(crate) async fn apply_bespoke_event_handling(
rx,
conversation,
outgoing,
turn_summary_store,
thread_state.clone(),
)
.await;
});
@@ -718,7 +711,7 @@ pub(crate) async fn apply_bespoke_event_handling(
return handle_thread_rollback_failed(
conversation_id,
message,
&pending_rollbacks,
&thread_state,
&outgoing,
)
.await;
@@ -729,7 +722,7 @@ pub(crate) async fn apply_bespoke_event_handling(
codex_error_info: ev.codex_error_info.map(V2CodexErrorInfo::from),
additional_details: None,
};
handle_error(conversation_id, turn_error.clone(), &turn_summary_store).await;
handle_error(conversation_id, turn_error.clone(), &thread_state).await;
outgoing
.send_server_notification(ServerNotification::Error(ErrorNotification {
error: turn_error.clone(),
@@ -867,9 +860,11 @@ pub(crate) async fn apply_bespoke_event_handling(
let item_id = patch_begin_event.call_id.clone();
let first_start = {
let mut map = turn_summary_store.lock().await;
let summary = map.entry(conversation_id).or_default();
summary.file_change_started.insert(item_id.clone())
let mut state = thread_state.lock().await;
state
.turn_summary
.file_change_started
.insert(item_id.clone())
};
if first_start {
let item = ThreadItem::FileChange {
@@ -905,7 +900,7 @@ pub(crate) async fn apply_bespoke_event_handling(
status,
event_turn_id.clone(),
outgoing.as_ref(),
&turn_summary_store,
&thread_state,
)
.await;
}
@@ -950,9 +945,8 @@ pub(crate) async fn apply_bespoke_event_handling(
// We need to detect which item type it is so we can emit the right notification.
// We already have state tracking FileChange items on item/started, so let's use that.
let is_file_change = {
let map = turn_summary_store.lock().await;
map.get(&conversation_id)
.is_some_and(|summary| summary.file_change_started.contains(&item_id))
let state = thread_state.lock().await;
state.turn_summary.file_change_started.contains(&item_id)
};
if is_file_change {
let notification = FileChangeOutputDeltaNotification {
@@ -1049,8 +1043,8 @@ pub(crate) async fn apply_bespoke_event_handling(
// If this is a TurnAborted, reply to any pending interrupt requests.
EventMsg::TurnAborted(turn_aborted_event) => {
let pending = {
let mut map = pending_interrupts.lock().await;
map.remove(&conversation_id).unwrap_or_default()
let mut state = thread_state.lock().await;
std::mem::take(&mut state.pending_interrupts)
};
if !pending.is_empty() {
for (rid, ver) in pending {
@@ -1069,18 +1063,12 @@ pub(crate) async fn apply_bespoke_event_handling(
}
}
handle_turn_interrupted(
conversation_id,
event_turn_id,
&outgoing,
&turn_summary_store,
)
.await;
handle_turn_interrupted(conversation_id, event_turn_id, &outgoing, &thread_state).await;
}
EventMsg::ThreadRolledBack(_rollback_event) => {
let pending = {
let mut map = pending_rollbacks.lock().await;
map.remove(&conversation_id)
let mut state = thread_state.lock().await;
state.pending_rollbacks.take()
};
if let Some(request_id) = pending {
@@ -1245,14 +1233,11 @@ async fn complete_file_change_item(
status: PatchApplyStatus,
turn_id: String,
outgoing: &OutgoingMessageSender,
turn_summary_store: &TurnSummaryStore,
thread_state: &Arc<Mutex<ThreadState>>,
) {
{
let mut map = turn_summary_store.lock().await;
if let Some(summary) = map.get_mut(&conversation_id) {
summary.file_change_started.remove(&item_id);
}
}
let mut state = thread_state.lock().await;
state.turn_summary.file_change_started.remove(&item_id);
drop(state);
let item = ThreadItem::FileChange {
id: item_id,
@@ -1324,20 +1309,20 @@ async fn maybe_emit_raw_response_item_completed(
}
async fn find_and_remove_turn_summary(
conversation_id: ThreadId,
turn_summary_store: &TurnSummaryStore,
_conversation_id: ThreadId,
thread_state: &Arc<Mutex<ThreadState>>,
) -> TurnSummary {
let mut map = turn_summary_store.lock().await;
map.remove(&conversation_id).unwrap_or_default()
let mut state = thread_state.lock().await;
std::mem::take(&mut state.turn_summary)
}
async fn handle_turn_complete(
conversation_id: ThreadId,
event_turn_id: String,
outgoing: &OutgoingMessageSender,
turn_summary_store: &TurnSummaryStore,
thread_state: &Arc<Mutex<ThreadState>>,
) {
let turn_summary = find_and_remove_turn_summary(conversation_id, turn_summary_store).await;
let turn_summary = find_and_remove_turn_summary(conversation_id, thread_state).await;
let (status, error) = match turn_summary.last_error {
Some(error) => (TurnStatus::Failed, Some(error)),
@@ -1351,9 +1336,9 @@ async fn handle_turn_interrupted(
conversation_id: ThreadId,
event_turn_id: String,
outgoing: &OutgoingMessageSender,
turn_summary_store: &TurnSummaryStore,
thread_state: &Arc<Mutex<ThreadState>>,
) {
find_and_remove_turn_summary(conversation_id, turn_summary_store).await;
find_and_remove_turn_summary(conversation_id, thread_state).await;
emit_turn_completed_with_status(
conversation_id,
@@ -1366,15 +1351,12 @@ async fn handle_turn_interrupted(
}
async fn handle_thread_rollback_failed(
conversation_id: ThreadId,
_conversation_id: ThreadId,
message: String,
pending_rollbacks: &PendingRollbacks,
thread_state: &Arc<Mutex<ThreadState>>,
outgoing: &OutgoingMessageSender,
) {
let pending_rollback = {
let mut map = pending_rollbacks.lock().await;
map.remove(&conversation_id)
};
let pending_rollback = thread_state.lock().await.pending_rollbacks.take();
if let Some(request_id) = pending_rollback {
outgoing
@@ -1419,12 +1401,12 @@ async fn handle_token_count_event(
}
async fn handle_error(
conversation_id: ThreadId,
_conversation_id: ThreadId,
error: TurnError,
turn_summary_store: &TurnSummaryStore,
thread_state: &Arc<Mutex<ThreadState>>,
) {
let mut map = turn_summary_store.lock().await;
map.entry(conversation_id).or_default().last_error = Some(error);
let mut state = thread_state.lock().await;
state.turn_summary.last_error = Some(error);
}
async fn on_patch_approval_response(
@@ -1652,7 +1634,7 @@ async fn on_file_change_request_approval_response(
receiver: oneshot::Receiver<JsonValue>,
codex: Arc<CodexThread>,
outgoing: Arc<OutgoingMessageSender>,
turn_summary_store: TurnSummaryStore,
thread_state: Arc<Mutex<ThreadState>>,
) {
let response = receiver.await;
let (decision, completion_status) = match response {
@@ -1685,7 +1667,7 @@ async fn on_file_change_request_approval_response(
status,
event_turn_id.clone(),
outgoing.as_ref(),
&turn_summary_store,
&thread_state,
)
.await;
}
@@ -1915,13 +1897,12 @@ mod tests {
use pretty_assertions::assert_eq;
use rmcp::model::Content;
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::sync::mpsc;
fn new_turn_summary_store() -> TurnSummaryStore {
Arc::new(Mutex::new(HashMap::new()))
fn new_thread_state() -> Arc<Mutex<ThreadState>> {
Arc::new(Mutex::new(ThreadState::default()))
}
async fn recv_broadcast_message(
@@ -1999,7 +1980,7 @@ mod tests {
#[tokio::test]
async fn test_handle_error_records_message() -> Result<()> {
let conversation_id = ThreadId::new();
let turn_summary_store = new_turn_summary_store();
let thread_state = new_thread_state();
handle_error(
conversation_id,
@@ -2008,11 +1989,11 @@ mod tests {
codex_error_info: Some(V2CodexErrorInfo::InternalServerError),
additional_details: None,
},
&turn_summary_store,
&thread_state,
)
.await;
let turn_summary = find_and_remove_turn_summary(conversation_id, &turn_summary_store).await;
let turn_summary = find_and_remove_turn_summary(conversation_id, &thread_state).await;
assert_eq!(
turn_summary.last_error,
Some(TurnError {
@@ -2030,13 +2011,13 @@ mod tests {
let event_turn_id = "complete1".to_string();
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
let turn_summary_store = new_turn_summary_store();
let thread_state = new_thread_state();
handle_turn_complete(
conversation_id,
event_turn_id.clone(),
&outgoing,
&turn_summary_store,
&thread_state,
)
.await;
@@ -2057,7 +2038,7 @@ mod tests {
async fn test_handle_turn_interrupted_emits_interrupted_with_error() -> Result<()> {
let conversation_id = ThreadId::new();
let event_turn_id = "interrupt1".to_string();
let turn_summary_store = new_turn_summary_store();
let thread_state = new_thread_state();
handle_error(
conversation_id,
TurnError {
@@ -2065,7 +2046,7 @@ mod tests {
codex_error_info: None,
additional_details: None,
},
&turn_summary_store,
&thread_state,
)
.await;
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
@@ -2075,7 +2056,7 @@ mod tests {
conversation_id,
event_turn_id.clone(),
&outgoing,
&turn_summary_store,
&thread_state,
)
.await;
@@ -2096,7 +2077,7 @@ mod tests {
async fn test_handle_turn_complete_emits_failed_with_error() -> Result<()> {
let conversation_id = ThreadId::new();
let event_turn_id = "complete_err1".to_string();
let turn_summary_store = new_turn_summary_store();
let thread_state = new_thread_state();
handle_error(
conversation_id,
TurnError {
@@ -2104,7 +2085,7 @@ mod tests {
codex_error_info: Some(V2CodexErrorInfo::Other),
additional_details: None,
},
&turn_summary_store,
&thread_state,
)
.await;
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
@@ -2114,7 +2095,7 @@ mod tests {
conversation_id,
event_turn_id.clone(),
&outgoing,
&turn_summary_store,
&thread_state,
)
.await;
@@ -2336,7 +2317,7 @@ mod tests {
// Conversation A will have two turns; Conversation B will have one turn.
let conversation_a = ThreadId::new();
let conversation_b = ThreadId::new();
let turn_summary_store = new_turn_summary_store();
let thread_state = new_thread_state();
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
@@ -2350,16 +2331,10 @@ mod tests {
codex_error_info: Some(V2CodexErrorInfo::BadRequest),
additional_details: None,
},
&turn_summary_store,
)
.await;
handle_turn_complete(
conversation_a,
a_turn1.clone(),
&outgoing,
&turn_summary_store,
&thread_state,
)
.await;
handle_turn_complete(conversation_a, a_turn1.clone(), &outgoing, &thread_state).await;
// Turn 1 on conversation B
let b_turn1 = "b_turn1".to_string();
@@ -2370,26 +2345,14 @@ mod tests {
codex_error_info: None,
additional_details: None,
},
&turn_summary_store,
)
.await;
handle_turn_complete(
conversation_b,
b_turn1.clone(),
&outgoing,
&turn_summary_store,
&thread_state,
)
.await;
handle_turn_complete(conversation_b, b_turn1.clone(), &outgoing, &thread_state).await;
// Turn 2 on conversation A
let a_turn2 = "a_turn2".to_string();
handle_turn_complete(
conversation_a,
a_turn2.clone(),
&outgoing,
&turn_summary_store,
)
.await;
handle_turn_complete(conversation_a, a_turn2.clone(), &outgoing, &thread_state).await;
// Verify: A turn 1
let msg = recv_broadcast_message(&mut rx).await?;