codex-rs/app-server: graceful websocket restart on Ctrl-C (#12517)

## Summary
- add graceful websocket app-server restart on Ctrl-C by draining until
no assistant turns are running
- stop the websocket acceptor and disconnect existing connections once
the drain condition is met
- add a websocket integration test that verifies Ctrl-C waits for an
in-flight turn before exit

## Verification
- `cargo check -p codex-app-server --quiet`
- `cargo test -p codex-app-server --test all
suite::v2::connection_handling_websocket`
- I (maxj) tested remote and local Codex.app

---------

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Max Johnson
2026-02-24 16:27:59 -08:00
committed by GitHub
parent 3d356723c4
commit 5163850025
8 changed files with 493 additions and 42 deletions

View File

@@ -98,6 +98,72 @@ enum OutboundControlEvent {
},
/// Remove state for a closed/disconnected connection.
Closed { connection_id: ConnectionId },
/// Disconnect all connection-oriented clients during graceful restart.
DisconnectAll,
}
#[derive(Default)]
struct ShutdownState {
requested: bool,
forced: bool,
last_logged_running_turn_count: Option<usize>,
}
enum ShutdownAction {
Noop,
Finish,
}
impl ShutdownState {
fn requested(&self) -> bool {
self.requested
}
fn forced(&self) -> bool {
self.forced
}
fn on_ctrl_c(&mut self, connection_count: usize, running_turn_count: usize) {
if self.requested {
self.forced = true;
return;
}
self.requested = true;
self.last_logged_running_turn_count = None;
info!(
"received Ctrl-C; entering graceful restart drain (connections={}, runningAssistantTurns={}, requests still accepted until no assistant turns are running)",
connection_count, running_turn_count,
);
}
fn update(&mut self, running_turn_count: usize, connection_count: usize) -> ShutdownAction {
if !self.requested {
return ShutdownAction::Noop;
}
if self.forced || running_turn_count == 0 {
if self.forced {
info!(
"received second Ctrl-C; forcing restart with {running_turn_count} running assistant turn(s) and {connection_count} connection(s)"
);
} else {
info!(
"Ctrl-C restart: no assistant turns running; stopping acceptor and disconnecting {connection_count} connection(s)"
);
}
return ShutdownAction::Finish;
}
if self.last_logged_running_turn_count != Some(running_turn_count) {
info!(
"Ctrl-C restart: waiting for {running_turn_count} running assistant turn(s) to finish"
);
self.last_logged_running_turn_count = Some(running_turn_count);
}
ShutdownAction::Noop
}
}
fn config_warning_from_error(
@@ -253,19 +319,37 @@ pub async fn run_main_with_transport(
let (outbound_control_tx, mut outbound_control_rx) =
mpsc::channel::<OutboundControlEvent>(CHANNEL_CAPACITY);
enum TransportRuntime {
Stdio,
WebSocket {
accept_handle: JoinHandle<()>,
shutdown_token: CancellationToken,
},
}
let mut stdio_handles = Vec::<JoinHandle<()>>::new();
let mut websocket_accept_handle = None;
match transport {
let transport_runtime = match transport {
AppServerTransport::Stdio => {
start_stdio_connection(transport_event_tx.clone(), &mut stdio_handles).await?;
TransportRuntime::Stdio
}
AppServerTransport::WebSocket { bind_address } => {
websocket_accept_handle =
Some(start_websocket_acceptor(bind_address, transport_event_tx.clone()).await?);
let shutdown_token = CancellationToken::new();
let accept_handle = start_websocket_acceptor(
bind_address,
transport_event_tx.clone(),
shutdown_token.clone(),
)
.await?;
TransportRuntime::WebSocket {
accept_handle,
shutdown_token,
}
}
}
let single_client_mode = matches!(transport, AppServerTransport::Stdio);
};
let single_client_mode = matches!(&transport_runtime, TransportRuntime::Stdio);
let shutdown_when_no_connections = single_client_mode;
let graceful_ctrl_c_restart_enabled = !single_client_mode;
// Parse CLI overrides once and derive the base Config eagerly so later
// components do not need to work with raw TOML values.
@@ -434,6 +518,16 @@ pub async fn run_main_with_transport(
OutboundControlEvent::Closed { connection_id } => {
outbound_connections.remove(&connection_id);
}
OutboundControlEvent::DisconnectAll => {
info!(
"disconnecting {} outbound websocket connection(s) for graceful restart",
outbound_connections.len()
);
for connection_state in outbound_connections.values() {
connection_state.request_disconnect();
}
outbound_connections.clear();
}
}
}
envelope = outgoing_rx.recv() => {
@@ -464,11 +558,46 @@ pub async fn run_main_with_transport(
config_warnings,
});
let mut thread_created_rx = processor.thread_created_receiver();
let mut running_turn_count_rx = processor.subscribe_running_assistant_turn_count();
let mut connections = HashMap::<ConnectionId, ConnectionState>::new();
let websocket_accept_shutdown = match &transport_runtime {
TransportRuntime::WebSocket { shutdown_token, .. } => Some(shutdown_token.clone()),
TransportRuntime::Stdio => None,
};
async move {
let mut listen_for_threads = true;
let mut shutdown_state = ShutdownState::default();
loop {
let running_turn_count = {
let running_turn_count = running_turn_count_rx.borrow();
*running_turn_count
};
if matches!(
shutdown_state.update(running_turn_count, connections.len()),
ShutdownAction::Finish
) {
if let Some(shutdown_token) = &websocket_accept_shutdown {
shutdown_token.cancel();
}
let _ = outbound_control_tx
.send(OutboundControlEvent::DisconnectAll)
.await;
break;
}
tokio::select! {
ctrl_c_result = tokio::signal::ctrl_c(), if graceful_ctrl_c_restart_enabled && !shutdown_state.forced() => {
if let Err(err) = ctrl_c_result {
warn!("failed to listen for Ctrl-C during graceful restart drain: {err}");
}
let running_turn_count = *running_turn_count_rx.borrow();
shutdown_state.on_ctrl_c(connections.len(), running_turn_count);
}
changed = running_turn_count_rx.changed(), if graceful_ctrl_c_restart_enabled && shutdown_state.requested() => {
if changed.is_err() {
warn!("running-turn watcher closed during graceful restart drain");
}
}
event = transport_event_rx.recv() => {
let Some(event) = event else {
break;
@@ -619,8 +748,13 @@ pub async fn run_main_with_transport(
let _ = processor_handle.await;
let _ = outbound_handle.await;
if let Some(handle) = websocket_accept_handle {
handle.abort();
if let TransportRuntime::WebSocket {
accept_handle,
shutdown_token,
} = transport_runtime
{
shutdown_token.cancel();
let _ = accept_handle.await;
}
for handle in stdio_handles {