mirror of
https://github.com/openai/codex.git
synced 2026-05-02 04:11:39 +03:00
refactor: codex app-server ThreadState (#11419)
this is a no-op functionality wise. consolidates thread-specific message processor / event handling state in ThreadState
This commit is contained in:
@@ -1,14 +1,12 @@
|
||||
use crate::codex_message_processor::ApiVersion;
|
||||
use crate::codex_message_processor::PendingInterrupts;
|
||||
use crate::codex_message_processor::PendingRollbacks;
|
||||
use crate::codex_message_processor::TurnSummary;
|
||||
use crate::codex_message_processor::TurnSummaryStore;
|
||||
use crate::codex_message_processor::read_rollout_items_from_rollout;
|
||||
use crate::codex_message_processor::read_summary_from_rollout;
|
||||
use crate::codex_message_processor::summary_to_thread;
|
||||
use crate::error_code::INTERNAL_ERROR_CODE;
|
||||
use crate::error_code::INVALID_REQUEST_ERROR_CODE;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::thread_state::ThreadState;
|
||||
use crate::thread_state::TurnSummary;
|
||||
use codex_app_server_protocol::AccountRateLimitsUpdatedNotification;
|
||||
use codex_app_server_protocol::AgentMessageDeltaNotification;
|
||||
use codex_app_server_protocol::ApplyPatchApprovalParams;
|
||||
@@ -98,6 +96,7 @@ use std::collections::HashMap;
|
||||
use std::convert::TryFrom;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::error;
|
||||
|
||||
@@ -109,9 +108,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
conversation_id: ThreadId,
|
||||
conversation: Arc<CodexThread>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
pending_interrupts: PendingInterrupts,
|
||||
pending_rollbacks: PendingRollbacks,
|
||||
turn_summary_store: TurnSummaryStore,
|
||||
thread_state: Arc<tokio::sync::Mutex<ThreadState>>,
|
||||
api_version: ApiVersion,
|
||||
fallback_model_provider: String,
|
||||
) {
|
||||
@@ -122,13 +119,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
match msg {
|
||||
EventMsg::TurnStarted(_) => {}
|
||||
EventMsg::TurnComplete(_ev) => {
|
||||
handle_turn_complete(
|
||||
conversation_id,
|
||||
event_turn_id,
|
||||
&outgoing,
|
||||
&turn_summary_store,
|
||||
)
|
||||
.await;
|
||||
handle_turn_complete(conversation_id, event_turn_id, &outgoing, &thread_state).await;
|
||||
}
|
||||
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
|
||||
call_id,
|
||||
@@ -159,9 +150,11 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
let patch_changes = convert_patch_changes(&changes);
|
||||
|
||||
let first_start = {
|
||||
let mut map = turn_summary_store.lock().await;
|
||||
let summary = map.entry(conversation_id).or_default();
|
||||
summary.file_change_started.insert(item_id.clone())
|
||||
let mut state = thread_state.lock().await;
|
||||
state
|
||||
.turn_summary
|
||||
.file_change_started
|
||||
.insert(item_id.clone())
|
||||
};
|
||||
if first_start {
|
||||
let item = ThreadItem::FileChange {
|
||||
@@ -198,7 +191,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
rx,
|
||||
conversation,
|
||||
outgoing,
|
||||
turn_summary_store,
|
||||
thread_state.clone(),
|
||||
)
|
||||
.await;
|
||||
});
|
||||
@@ -718,7 +711,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
return handle_thread_rollback_failed(
|
||||
conversation_id,
|
||||
message,
|
||||
&pending_rollbacks,
|
||||
&thread_state,
|
||||
&outgoing,
|
||||
)
|
||||
.await;
|
||||
@@ -729,7 +722,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
codex_error_info: ev.codex_error_info.map(V2CodexErrorInfo::from),
|
||||
additional_details: None,
|
||||
};
|
||||
handle_error(conversation_id, turn_error.clone(), &turn_summary_store).await;
|
||||
handle_error(conversation_id, turn_error.clone(), &thread_state).await;
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::Error(ErrorNotification {
|
||||
error: turn_error.clone(),
|
||||
@@ -867,9 +860,11 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
let item_id = patch_begin_event.call_id.clone();
|
||||
|
||||
let first_start = {
|
||||
let mut map = turn_summary_store.lock().await;
|
||||
let summary = map.entry(conversation_id).or_default();
|
||||
summary.file_change_started.insert(item_id.clone())
|
||||
let mut state = thread_state.lock().await;
|
||||
state
|
||||
.turn_summary
|
||||
.file_change_started
|
||||
.insert(item_id.clone())
|
||||
};
|
||||
if first_start {
|
||||
let item = ThreadItem::FileChange {
|
||||
@@ -905,7 +900,7 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
status,
|
||||
event_turn_id.clone(),
|
||||
outgoing.as_ref(),
|
||||
&turn_summary_store,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
@@ -950,9 +945,8 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
// We need to detect which item type it is so we can emit the right notification.
|
||||
// We already have state tracking FileChange items on item/started, so let's use that.
|
||||
let is_file_change = {
|
||||
let map = turn_summary_store.lock().await;
|
||||
map.get(&conversation_id)
|
||||
.is_some_and(|summary| summary.file_change_started.contains(&item_id))
|
||||
let state = thread_state.lock().await;
|
||||
state.turn_summary.file_change_started.contains(&item_id)
|
||||
};
|
||||
if is_file_change {
|
||||
let notification = FileChangeOutputDeltaNotification {
|
||||
@@ -1049,8 +1043,8 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
// If this is a TurnAborted, reply to any pending interrupt requests.
|
||||
EventMsg::TurnAborted(turn_aborted_event) => {
|
||||
let pending = {
|
||||
let mut map = pending_interrupts.lock().await;
|
||||
map.remove(&conversation_id).unwrap_or_default()
|
||||
let mut state = thread_state.lock().await;
|
||||
std::mem::take(&mut state.pending_interrupts)
|
||||
};
|
||||
if !pending.is_empty() {
|
||||
for (rid, ver) in pending {
|
||||
@@ -1069,18 +1063,12 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
}
|
||||
}
|
||||
|
||||
handle_turn_interrupted(
|
||||
conversation_id,
|
||||
event_turn_id,
|
||||
&outgoing,
|
||||
&turn_summary_store,
|
||||
)
|
||||
.await;
|
||||
handle_turn_interrupted(conversation_id, event_turn_id, &outgoing, &thread_state).await;
|
||||
}
|
||||
EventMsg::ThreadRolledBack(_rollback_event) => {
|
||||
let pending = {
|
||||
let mut map = pending_rollbacks.lock().await;
|
||||
map.remove(&conversation_id)
|
||||
let mut state = thread_state.lock().await;
|
||||
state.pending_rollbacks.take()
|
||||
};
|
||||
|
||||
if let Some(request_id) = pending {
|
||||
@@ -1245,14 +1233,11 @@ async fn complete_file_change_item(
|
||||
status: PatchApplyStatus,
|
||||
turn_id: String,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
turn_summary_store: &TurnSummaryStore,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
) {
|
||||
{
|
||||
let mut map = turn_summary_store.lock().await;
|
||||
if let Some(summary) = map.get_mut(&conversation_id) {
|
||||
summary.file_change_started.remove(&item_id);
|
||||
}
|
||||
}
|
||||
let mut state = thread_state.lock().await;
|
||||
state.turn_summary.file_change_started.remove(&item_id);
|
||||
drop(state);
|
||||
|
||||
let item = ThreadItem::FileChange {
|
||||
id: item_id,
|
||||
@@ -1324,20 +1309,20 @@ async fn maybe_emit_raw_response_item_completed(
|
||||
}
|
||||
|
||||
async fn find_and_remove_turn_summary(
|
||||
conversation_id: ThreadId,
|
||||
turn_summary_store: &TurnSummaryStore,
|
||||
_conversation_id: ThreadId,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
) -> TurnSummary {
|
||||
let mut map = turn_summary_store.lock().await;
|
||||
map.remove(&conversation_id).unwrap_or_default()
|
||||
let mut state = thread_state.lock().await;
|
||||
std::mem::take(&mut state.turn_summary)
|
||||
}
|
||||
|
||||
async fn handle_turn_complete(
|
||||
conversation_id: ThreadId,
|
||||
event_turn_id: String,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
turn_summary_store: &TurnSummaryStore,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
) {
|
||||
let turn_summary = find_and_remove_turn_summary(conversation_id, turn_summary_store).await;
|
||||
let turn_summary = find_and_remove_turn_summary(conversation_id, thread_state).await;
|
||||
|
||||
let (status, error) = match turn_summary.last_error {
|
||||
Some(error) => (TurnStatus::Failed, Some(error)),
|
||||
@@ -1351,9 +1336,9 @@ async fn handle_turn_interrupted(
|
||||
conversation_id: ThreadId,
|
||||
event_turn_id: String,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
turn_summary_store: &TurnSummaryStore,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
) {
|
||||
find_and_remove_turn_summary(conversation_id, turn_summary_store).await;
|
||||
find_and_remove_turn_summary(conversation_id, thread_state).await;
|
||||
|
||||
emit_turn_completed_with_status(
|
||||
conversation_id,
|
||||
@@ -1366,15 +1351,12 @@ async fn handle_turn_interrupted(
|
||||
}
|
||||
|
||||
async fn handle_thread_rollback_failed(
|
||||
conversation_id: ThreadId,
|
||||
_conversation_id: ThreadId,
|
||||
message: String,
|
||||
pending_rollbacks: &PendingRollbacks,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
outgoing: &OutgoingMessageSender,
|
||||
) {
|
||||
let pending_rollback = {
|
||||
let mut map = pending_rollbacks.lock().await;
|
||||
map.remove(&conversation_id)
|
||||
};
|
||||
let pending_rollback = thread_state.lock().await.pending_rollbacks.take();
|
||||
|
||||
if let Some(request_id) = pending_rollback {
|
||||
outgoing
|
||||
@@ -1419,12 +1401,12 @@ async fn handle_token_count_event(
|
||||
}
|
||||
|
||||
async fn handle_error(
|
||||
conversation_id: ThreadId,
|
||||
_conversation_id: ThreadId,
|
||||
error: TurnError,
|
||||
turn_summary_store: &TurnSummaryStore,
|
||||
thread_state: &Arc<Mutex<ThreadState>>,
|
||||
) {
|
||||
let mut map = turn_summary_store.lock().await;
|
||||
map.entry(conversation_id).or_default().last_error = Some(error);
|
||||
let mut state = thread_state.lock().await;
|
||||
state.turn_summary.last_error = Some(error);
|
||||
}
|
||||
|
||||
async fn on_patch_approval_response(
|
||||
@@ -1652,7 +1634,7 @@ async fn on_file_change_request_approval_response(
|
||||
receiver: oneshot::Receiver<JsonValue>,
|
||||
codex: Arc<CodexThread>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
turn_summary_store: TurnSummaryStore,
|
||||
thread_state: Arc<Mutex<ThreadState>>,
|
||||
) {
|
||||
let response = receiver.await;
|
||||
let (decision, completion_status) = match response {
|
||||
@@ -1685,7 +1667,7 @@ async fn on_file_change_request_approval_response(
|
||||
status,
|
||||
event_turn_id.clone(),
|
||||
outgoing.as_ref(),
|
||||
&turn_summary_store,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
@@ -1915,13 +1897,12 @@ mod tests {
|
||||
use pretty_assertions::assert_eq;
|
||||
use rmcp::model::Content;
|
||||
use serde_json::Value as JsonValue;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
fn new_turn_summary_store() -> TurnSummaryStore {
|
||||
Arc::new(Mutex::new(HashMap::new()))
|
||||
fn new_thread_state() -> Arc<Mutex<ThreadState>> {
|
||||
Arc::new(Mutex::new(ThreadState::default()))
|
||||
}
|
||||
|
||||
async fn recv_broadcast_message(
|
||||
@@ -1999,7 +1980,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_handle_error_records_message() -> Result<()> {
|
||||
let conversation_id = ThreadId::new();
|
||||
let turn_summary_store = new_turn_summary_store();
|
||||
let thread_state = new_thread_state();
|
||||
|
||||
handle_error(
|
||||
conversation_id,
|
||||
@@ -2008,11 +1989,11 @@ mod tests {
|
||||
codex_error_info: Some(V2CodexErrorInfo::InternalServerError),
|
||||
additional_details: None,
|
||||
},
|
||||
&turn_summary_store,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
|
||||
let turn_summary = find_and_remove_turn_summary(conversation_id, &turn_summary_store).await;
|
||||
let turn_summary = find_and_remove_turn_summary(conversation_id, &thread_state).await;
|
||||
assert_eq!(
|
||||
turn_summary.last_error,
|
||||
Some(TurnError {
|
||||
@@ -2030,13 +2011,13 @@ mod tests {
|
||||
let event_turn_id = "complete1".to_string();
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
|
||||
let turn_summary_store = new_turn_summary_store();
|
||||
let thread_state = new_thread_state();
|
||||
|
||||
handle_turn_complete(
|
||||
conversation_id,
|
||||
event_turn_id.clone(),
|
||||
&outgoing,
|
||||
&turn_summary_store,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -2057,7 +2038,7 @@ mod tests {
|
||||
async fn test_handle_turn_interrupted_emits_interrupted_with_error() -> Result<()> {
|
||||
let conversation_id = ThreadId::new();
|
||||
let event_turn_id = "interrupt1".to_string();
|
||||
let turn_summary_store = new_turn_summary_store();
|
||||
let thread_state = new_thread_state();
|
||||
handle_error(
|
||||
conversation_id,
|
||||
TurnError {
|
||||
@@ -2065,7 +2046,7 @@ mod tests {
|
||||
codex_error_info: None,
|
||||
additional_details: None,
|
||||
},
|
||||
&turn_summary_store,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
@@ -2075,7 +2056,7 @@ mod tests {
|
||||
conversation_id,
|
||||
event_turn_id.clone(),
|
||||
&outgoing,
|
||||
&turn_summary_store,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -2096,7 +2077,7 @@ mod tests {
|
||||
async fn test_handle_turn_complete_emits_failed_with_error() -> Result<()> {
|
||||
let conversation_id = ThreadId::new();
|
||||
let event_turn_id = "complete_err1".to_string();
|
||||
let turn_summary_store = new_turn_summary_store();
|
||||
let thread_state = new_thread_state();
|
||||
handle_error(
|
||||
conversation_id,
|
||||
TurnError {
|
||||
@@ -2104,7 +2085,7 @@ mod tests {
|
||||
codex_error_info: Some(V2CodexErrorInfo::Other),
|
||||
additional_details: None,
|
||||
},
|
||||
&turn_summary_store,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
@@ -2114,7 +2095,7 @@ mod tests {
|
||||
conversation_id,
|
||||
event_turn_id.clone(),
|
||||
&outgoing,
|
||||
&turn_summary_store,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -2336,7 +2317,7 @@ mod tests {
|
||||
// Conversation A will have two turns; Conversation B will have one turn.
|
||||
let conversation_a = ThreadId::new();
|
||||
let conversation_b = ThreadId::new();
|
||||
let turn_summary_store = new_turn_summary_store();
|
||||
let thread_state = new_thread_state();
|
||||
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
|
||||
@@ -2350,16 +2331,10 @@ mod tests {
|
||||
codex_error_info: Some(V2CodexErrorInfo::BadRequest),
|
||||
additional_details: None,
|
||||
},
|
||||
&turn_summary_store,
|
||||
)
|
||||
.await;
|
||||
handle_turn_complete(
|
||||
conversation_a,
|
||||
a_turn1.clone(),
|
||||
&outgoing,
|
||||
&turn_summary_store,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
handle_turn_complete(conversation_a, a_turn1.clone(), &outgoing, &thread_state).await;
|
||||
|
||||
// Turn 1 on conversation B
|
||||
let b_turn1 = "b_turn1".to_string();
|
||||
@@ -2370,26 +2345,14 @@ mod tests {
|
||||
codex_error_info: None,
|
||||
additional_details: None,
|
||||
},
|
||||
&turn_summary_store,
|
||||
)
|
||||
.await;
|
||||
handle_turn_complete(
|
||||
conversation_b,
|
||||
b_turn1.clone(),
|
||||
&outgoing,
|
||||
&turn_summary_store,
|
||||
&thread_state,
|
||||
)
|
||||
.await;
|
||||
handle_turn_complete(conversation_b, b_turn1.clone(), &outgoing, &thread_state).await;
|
||||
|
||||
// Turn 2 on conversation A
|
||||
let a_turn2 = "a_turn2".to_string();
|
||||
handle_turn_complete(
|
||||
conversation_a,
|
||||
a_turn2.clone(),
|
||||
&outgoing,
|
||||
&turn_summary_store,
|
||||
)
|
||||
.await;
|
||||
handle_turn_complete(conversation_a, a_turn2.clone(), &outgoing, &thread_state).await;
|
||||
|
||||
// Verify: A turn 1
|
||||
let msg = recv_broadcast_message(&mut rx).await?;
|
||||
|
||||
Reference in New Issue
Block a user