codex-rs/app-server: graceful websocket restart on Ctrl-C (#12517)

## Summary
- add graceful websocket app-server restart on Ctrl-C by draining until
no assistant turns are running
- stop the websocket acceptor and disconnect existing connections once
the drain condition is met
- add a websocket integration test that verifies Ctrl-C waits for an
in-flight turn before exit

## Verification
- `cargo check -p codex-app-server --quiet`
- `cargo test -p codex-app-server --test all
suite::v2::connection_handling_websocket`
- I (maxj) tested remote and local Codex.app

---------

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Max Johnson
2026-02-24 16:27:59 -08:00
committed by GitHub
parent 3d356723c4
commit 5163850025
8 changed files with 493 additions and 42 deletions

View File

@@ -15,11 +15,13 @@ use std::sync::Arc;
use tokio::sync::Mutex;
#[cfg(test)]
use tokio::sync::mpsc;
use tokio::sync::watch;
#[derive(Clone)]
pub(crate) struct ThreadWatchManager {
state: Arc<Mutex<ThreadWatchState>>,
outgoing: Option<Arc<OutgoingMessageSender>>,
running_turn_count_tx: watch::Sender<usize>,
}
pub(crate) struct ThreadWatchActiveGuard {
@@ -71,16 +73,20 @@ impl Default for ThreadWatchManager {
impl ThreadWatchManager {
pub(crate) fn new() -> Self {
let (running_turn_count_tx, _running_turn_count_rx) = watch::channel(0);
Self {
state: Arc::new(Mutex::new(ThreadWatchState::default())),
outgoing: None,
running_turn_count_tx,
}
}
pub(crate) fn new_with_outgoing(outgoing: Arc<OutgoingMessageSender>) -> Self {
let (running_turn_count_tx, _running_turn_count_rx) = watch::channel(0);
Self {
state: Arc::new(Mutex::new(ThreadWatchState::default())),
outgoing: Some(outgoing),
running_turn_count_tx,
}
}
@@ -113,6 +119,21 @@ impl ThreadWatchManager {
.collect()
}
#[cfg(test)]
pub(crate) async fn running_turn_count(&self) -> usize {
self.state
.lock()
.await
.runtime_by_thread_id
.values()
.filter(|runtime| runtime.running)
.count()
}
pub(crate) fn subscribe_running_turn_count(&self) -> watch::Receiver<usize> {
self.running_turn_count_tx.subscribe()
}
pub(crate) async fn note_turn_started(&self, thread_id: &str) {
self.update_runtime_for_thread(thread_id, |runtime| {
runtime.is_loaded = true;
@@ -193,10 +214,17 @@ impl ThreadWatchManager {
where
F: FnOnce(&mut ThreadWatchState) -> Option<ThreadStatusChangedNotification>,
{
let notification = {
let (notification, running_turn_count) = {
let mut state = self.state.lock().await;
mutate(&mut state)
let notification = mutate(&mut state);
let running_turn_count = state
.runtime_by_thread_id
.values()
.filter(|runtime| runtime.running)
.count();
(notification, running_turn_count)
};
let _ = self.running_turn_count_tx.send(running_turn_count);
if let Some(notification) = notification
&& let Some(outgoing) = &self.outgoing
@@ -588,6 +616,32 @@ mod tests {
);
}
#[tokio::test]
async fn has_running_turns_tracks_runtime_running_flag_only() {
let manager = ThreadWatchManager::new();
manager
.upsert_thread(test_thread(
INTERACTIVE_THREAD_ID,
codex_app_server_protocol::SessionSource::Cli,
))
.await;
assert_eq!(manager.running_turn_count().await, 0);
let _permission_guard = manager
.note_permission_requested(INTERACTIVE_THREAD_ID)
.await;
assert_eq!(manager.running_turn_count().await, 0);
manager.note_turn_started(INTERACTIVE_THREAD_ID).await;
assert_eq!(manager.running_turn_count().await, 1);
manager
.note_turn_completed(INTERACTIVE_THREAD_ID, false)
.await;
assert_eq!(manager.running_turn_count().await, 0);
}
#[tokio::test]
async fn status_change_emits_notification() {
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(8);