mirror of
https://github.com/openai/codex.git
synced 2026-05-04 13:21:54 +03:00
codex-rs/app-server: graceful websocket restart on Ctrl-C (#12517)
## Summary - add graceful websocket app-server restart on Ctrl-C by draining until no assistant turns are running - stop the websocket acceptor and disconnect existing connections once the drain condition is met - add a websocket integration test that verifies Ctrl-C waits for an in-flight turn before exit ## Verification - `cargo check -p codex-app-server --quiet` - `cargo test -p codex-app-server --test all suite::v2::connection_handling_websocket` - I (maxj) tested remote and local Codex.app --------- Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
@@ -15,11 +15,13 @@ use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
#[cfg(test)]
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::watch;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ThreadWatchManager {
|
||||
state: Arc<Mutex<ThreadWatchState>>,
|
||||
outgoing: Option<Arc<OutgoingMessageSender>>,
|
||||
running_turn_count_tx: watch::Sender<usize>,
|
||||
}
|
||||
|
||||
pub(crate) struct ThreadWatchActiveGuard {
|
||||
@@ -71,16 +73,20 @@ impl Default for ThreadWatchManager {
|
||||
|
||||
impl ThreadWatchManager {
|
||||
pub(crate) fn new() -> Self {
|
||||
let (running_turn_count_tx, _running_turn_count_rx) = watch::channel(0);
|
||||
Self {
|
||||
state: Arc::new(Mutex::new(ThreadWatchState::default())),
|
||||
outgoing: None,
|
||||
running_turn_count_tx,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_with_outgoing(outgoing: Arc<OutgoingMessageSender>) -> Self {
|
||||
let (running_turn_count_tx, _running_turn_count_rx) = watch::channel(0);
|
||||
Self {
|
||||
state: Arc::new(Mutex::new(ThreadWatchState::default())),
|
||||
outgoing: Some(outgoing),
|
||||
running_turn_count_tx,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,6 +119,21 @@ impl ThreadWatchManager {
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) async fn running_turn_count(&self) -> usize {
|
||||
self.state
|
||||
.lock()
|
||||
.await
|
||||
.runtime_by_thread_id
|
||||
.values()
|
||||
.filter(|runtime| runtime.running)
|
||||
.count()
|
||||
}
|
||||
|
||||
pub(crate) fn subscribe_running_turn_count(&self) -> watch::Receiver<usize> {
|
||||
self.running_turn_count_tx.subscribe()
|
||||
}
|
||||
|
||||
pub(crate) async fn note_turn_started(&self, thread_id: &str) {
|
||||
self.update_runtime_for_thread(thread_id, |runtime| {
|
||||
runtime.is_loaded = true;
|
||||
@@ -193,10 +214,17 @@ impl ThreadWatchManager {
|
||||
where
|
||||
F: FnOnce(&mut ThreadWatchState) -> Option<ThreadStatusChangedNotification>,
|
||||
{
|
||||
let notification = {
|
||||
let (notification, running_turn_count) = {
|
||||
let mut state = self.state.lock().await;
|
||||
mutate(&mut state)
|
||||
let notification = mutate(&mut state);
|
||||
let running_turn_count = state
|
||||
.runtime_by_thread_id
|
||||
.values()
|
||||
.filter(|runtime| runtime.running)
|
||||
.count();
|
||||
(notification, running_turn_count)
|
||||
};
|
||||
let _ = self.running_turn_count_tx.send(running_turn_count);
|
||||
|
||||
if let Some(notification) = notification
|
||||
&& let Some(outgoing) = &self.outgoing
|
||||
@@ -588,6 +616,32 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn has_running_turns_tracks_runtime_running_flag_only() {
|
||||
let manager = ThreadWatchManager::new();
|
||||
manager
|
||||
.upsert_thread(test_thread(
|
||||
INTERACTIVE_THREAD_ID,
|
||||
codex_app_server_protocol::SessionSource::Cli,
|
||||
))
|
||||
.await;
|
||||
|
||||
assert_eq!(manager.running_turn_count().await, 0);
|
||||
|
||||
let _permission_guard = manager
|
||||
.note_permission_requested(INTERACTIVE_THREAD_ID)
|
||||
.await;
|
||||
assert_eq!(manager.running_turn_count().await, 0);
|
||||
|
||||
manager.note_turn_started(INTERACTIVE_THREAD_ID).await;
|
||||
assert_eq!(manager.running_turn_count().await, 1);
|
||||
|
||||
manager
|
||||
.note_turn_completed(INTERACTIVE_THREAD_ID, false)
|
||||
.await;
|
||||
assert_eq!(manager.running_turn_count().await, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn status_change_emits_notification() {
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(8);
|
||||
|
||||
Reference in New Issue
Block a user