Add remote control enrollment logging and periodic websocket pings

This commit is contained in:
Ruslan Nigmatullin
2026-04-03 22:54:06 -07:00
parent 2230a7ca20
commit cb3931b0ad
2 changed files with 241 additions and 27 deletions

View File

@@ -7,6 +7,7 @@ use codex_state::StateRuntime;
use gethostname::gethostname;
use std::io;
use std::io::ErrorKind;
use tracing::info;
use tracing::warn;
const REMOTE_CONTROL_ENROLL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
@@ -36,26 +37,48 @@ pub(super) async fn load_persisted_remote_control_enrollment(
remote_control_target: &RemoteControlTarget,
account_id: Option<&str>,
) -> Option<RemoteControlEnrollment> {
let state_db = state_db?;
let Some(state_db) = state_db else {
info!(
"remote control enrollment cache unavailable because sqlite state db is disabled: websocket_url={}, account_id={:?}",
remote_control_target.websocket_url, account_id
);
return None;
};
let enrollment = match state_db
.get_remote_control_enrollment(&remote_control_target.websocket_url, account_id)
.await
{
Ok(enrollment) => enrollment,
Err(err) => {
warn!("{err}");
warn!(
"failed to load persisted remote control enrollment: websocket_url={}, account_id={:?}, err={err}",
remote_control_target.websocket_url, account_id
);
return None;
}
};
enrollment.map(
|(server_id, environment_id, server_name)| RemoteControlEnrollment {
account_id: account_id.map(&str::to_string),
environment_id,
server_id,
server_name,
},
)
match enrollment {
Some((server_id, environment_id, server_name)) => {
info!(
"reusing persisted remote control enrollment: websocket_url={}, account_id={:?}, server_id={}, environment_id={}",
remote_control_target.websocket_url, account_id, server_id, environment_id
);
Some(RemoteControlEnrollment {
account_id: account_id.map(&str::to_string),
environment_id,
server_id,
server_name,
})
}
None => {
info!(
"no persisted remote control enrollment found: websocket_url={}, account_id={:?}",
remote_control_target.websocket_url, account_id
);
None
}
}
}
pub(super) async fn update_persisted_remote_control_enrollment(
@@ -65,6 +88,12 @@ pub(super) async fn update_persisted_remote_control_enrollment(
enrollment: Option<&RemoteControlEnrollment>,
) -> io::Result<()> {
let Some(state_db) = state_db else {
info!(
"skipping remote control enrollment persistence because sqlite state db is disabled: websocket_url={}, account_id={:?}, has_enrollment={}",
remote_control_target.websocket_url,
account_id,
enrollment.is_some()
);
return Ok(());
};
if let &Some(enrollment) = &enrollment
@@ -85,13 +114,25 @@ pub(super) async fn update_persisted_remote_control_enrollment(
&enrollment.server_name,
)
.await
.map_err(io::Error::other)
.map_err(io::Error::other)?;
info!(
"persisted remote control enrollment: websocket_url={}, account_id={:?}, server_id={}, environment_id={}",
remote_control_target.websocket_url,
account_id,
enrollment.server_id,
enrollment.environment_id
);
Ok(())
} else {
state_db
let rows_affected = state_db
.delete_remote_control_enrollment(&remote_control_target.websocket_url, account_id)
.await
.map(|_| ())
.map_err(io::Error::other)
.map_err(io::Error::other)?;
info!(
"cleared persisted remote control enrollment: websocket_url={}, account_id={:?}, rows_affected={rows_affected}",
remote_control_target.websocket_url, account_id
);
Ok(())
}
}

View File

@@ -49,6 +49,10 @@ use tracing::warn;
pub(super) const REMOTE_CONTROL_PROTOCOL_VERSION: &str = "2";
pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id";
const REMOTE_CONTROL_SUBSCRIBE_CURSOR_HEADER: &str = "x-codex-subscribe-cursor";
/// Period between `Ping` frames sent by the websocket writer task on the remote control
/// websocket.
const REMOTE_CONTROL_WEBSOCKET_PING_INTERVAL: std::time::Duration =
std::time::Duration::from_secs(10);
/// How long the websocket reader task waits for a `Pong` frame before it fails with
/// `ErrorKind::TimedOut`; the deadline is pushed out again each time a `Pong` is received.
const REMOTE_CONTROL_WEBSOCKET_PONG_TIMEOUT: std::time::Duration =
std::time::Duration::from_secs(60);
struct BoundedOutboundBuffer {
buffer_by_client: HashMap<(ClientId, StreamId), BTreeMap<u64, ServerEnvelope>>,
@@ -235,12 +239,14 @@ impl RemoteControlWebsocket {
self.server_event_rx.clone(),
self.used_rx.clone(),
websocket_writer,
REMOTE_CONTROL_WEBSOCKET_PING_INTERVAL,
shutdown_token.clone(),
));
join_set.spawn(Self::run_websocket_reader(
self.client_tracker.clone(),
self.state.clone(),
websocket_reader,
REMOTE_CONTROL_WEBSOCKET_PONG_TIMEOUT,
shutdown_token.clone(),
));
@@ -260,6 +266,7 @@ impl RemoteControlWebsocket {
WebSocketStream<MaybeTlsStream<TcpStream>>,
tungstenite::Message,
>,
ping_interval: std::time::Duration,
shutdown_token: CancellationToken,
) {
let result = Self::run_server_writer_inner(
@@ -267,6 +274,7 @@ impl RemoteControlWebsocket {
server_event_rx,
used_rx,
websocket_writer,
ping_interval,
shutdown_token,
)
.await;
@@ -285,6 +293,7 @@ impl RemoteControlWebsocket {
WebSocketStream<MaybeTlsStream<TcpStream>>,
tungstenite::Message,
>,
ping_interval: std::time::Duration,
shutdown_token: CancellationToken,
) -> io::Result<()> {
for server_envelope in state.lock().await.outbound_buffer.server_envelopes() {
@@ -305,15 +314,35 @@ impl RemoteControlWebsocket {
};
}
let mut ping_interval =
tokio::time::interval_at(tokio::time::Instant::now() + ping_interval, ping_interval);
ping_interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
let mut server_event_rx = server_event_rx.lock().await;
loop {
tokio::select! {
_ = shutdown_token.cancelled() => return Ok(()),
_ = used_rx.wait_for(|used| *used < super::CHANNEL_CAPACITY) => {}
};
let outbound_has_capacity = *used_rx.borrow() < super::CHANNEL_CAPACITY;
let queued_server_envelope = tokio::select! {
_ = shutdown_token.cancelled() => return Ok(()),
recv_result = server_event_rx.recv() => {
_ = ping_interval.tick() => {
if let Err(err) = websocket_writer
.send(tungstenite::Message::Ping(Vec::new().into()))
.await
{
return Err(io::Error::other(err));
}
continue;
}
wait_result = used_rx.changed(), if !outbound_has_capacity =>
{
if wait_result.is_err() {
return Err(io::Error::new(
ErrorKind::UnexpectedEof,
"outbound buffer usage channel closed",
));
}
continue;
}
recv_result = server_event_rx.recv(), if outbound_has_capacity => {
match recv_result {
Some(queued_server_envelope) => queued_server_envelope,
None => {
@@ -364,12 +393,14 @@ impl RemoteControlWebsocket {
client_tracker: Arc<Mutex<ClientTracker>>,
state: Arc<Mutex<WebsocketState>>,
websocket_reader: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
pong_timeout: std::time::Duration,
shutdown_token: CancellationToken,
) {
let result = Self::run_websocket_reader_inner(
client_tracker,
state,
websocket_reader,
pong_timeout,
shutdown_token,
)
.await;
@@ -384,15 +415,24 @@ impl RemoteControlWebsocket {
client_tracker: Arc<Mutex<ClientTracker>>,
state: Arc<Mutex<WebsocketState>>,
mut websocket_reader: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
pong_timeout: std::time::Duration,
shutdown_token: CancellationToken,
) -> io::Result<()> {
let mut client_tracker = client_tracker.lock().await;
let mut idle_sweep_interval = tokio::time::interval(REMOTE_CONTROL_IDLE_SWEEP_INTERVAL);
idle_sweep_interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
let pong_deadline = tokio::time::sleep(pong_timeout);
tokio::pin!(pong_deadline);
loop {
let incoming_message = tokio::select! {
_ = shutdown_token.cancelled() => return Ok(()),
_ = &mut pong_deadline => {
return Err(io::Error::new(
ErrorKind::TimedOut,
"remote control websocket pong timeout",
));
}
client_id = client_tracker.bookkeep_join_set() => {
let Some(client_id) = client_id else {
continue;
@@ -432,9 +472,13 @@ impl RemoteControlWebsocket {
}
}
}
Ok(tungstenite::Message::Ping(_))
| Ok(tungstenite::Message::Pong(_))
| Ok(tungstenite::Message::Frame(_)) => continue,
Ok(tungstenite::Message::Pong(_)) => {
pong_deadline
.as_mut()
.reset(tokio::time::Instant::now() + pong_timeout);
continue;
}
Ok(tungstenite::Message::Ping(_)) | Ok(tungstenite::Message::Frame(_)) => continue,
Ok(tungstenite::Message::Binary(_)) => {
warn!("dropping unsupported binary remote-control websocket message");
continue;
@@ -601,11 +645,16 @@ pub(super) async fn connect_remote_control_websocket(
ensure_rustls_crypto_provider();
let auth = load_remote_control_auth(auth_manager).await?;
if auth.account_id.as_ref()
!= enrollment
.as_ref()
.and_then(|enrollment| enrollment.account_id.as_ref())
{
let enrollment_account_id = enrollment
.as_ref()
.and_then(|enrollment| enrollment.account_id.clone());
if auth.account_id.as_deref() != enrollment_account_id.as_deref() {
info!(
"clearing in-memory remote control enrollment because account id changed: websocket_url={}, previous_account_id={:?}, current_account_id={:?}",
remote_control_target.websocket_url,
enrollment_account_id.as_deref(),
auth.account_id.as_deref()
);
*enrollment = None;
}
@@ -619,6 +668,12 @@ pub(super) async fn connect_remote_control_websocket(
}
if enrollment.is_none() {
info!(
"creating new remote control enrollment: websocket_url={}, enroll_url={}, account_id={:?}",
remote_control_target.websocket_url,
remote_control_target.enroll_url,
auth.account_id.as_deref()
);
let new_enrollment = match enroll_remote_control_server(remote_control_target, &auth).await
{
Ok(new_enrollment) => new_enrollment,
@@ -642,6 +697,13 @@ pub(super) async fn connect_remote_control_websocket(
{
warn!("failed to persist remote control enrollment in sqlite state db: {err}");
}
info!(
"created new remote control enrollment: websocket_url={}, account_id={:?}, server_id={}, environment_id={}",
remote_control_target.websocket_url,
new_enrollment.account_id.as_deref(),
new_enrollment.server_id,
new_enrollment.environment_id
);
*enrollment = Some(new_enrollment);
}
@@ -660,6 +722,13 @@ pub(super) async fn connect_remote_control_websocket(
Err(err) => {
match &err {
tungstenite::Error::Http(response) if response.status().as_u16() == 404 => {
info!(
"remote control websocket returned HTTP 404; clearing stale enrollment before re-enrolling: websocket_url={}, account_id={:?}, server_id={}, environment_id={}",
remote_control_target.websocket_url,
auth.account_id.as_deref(),
enrollment_ref.server_id,
enrollment_ref.environment_id
);
if let Err(clear_err) = update_persisted_remote_control_enrollment(
state_db,
remote_control_target,
@@ -753,6 +822,7 @@ mod tests {
use codex_login::token_data::TokenData;
use codex_login::token_data::parse_chatgpt_jwt_claims;
use codex_state::StateRuntime;
use futures::StreamExt;
use pretty_assertions::assert_eq;
use std::sync::Arc;
use tempfile::TempDir;
@@ -764,6 +834,7 @@ mod tests {
use tokio::sync::mpsc;
use tokio::time::Duration;
use tokio::time::timeout;
use tokio_tungstenite::accept_async;
async fn remote_control_state_runtime(codex_home: &TempDir) -> Arc<StateRuntime> {
StateRuntime::init(codex_home.path().to_path_buf(), "test-provider".to_string())
@@ -1052,6 +1123,80 @@ mod tests {
.expect("websocket task should join");
}
/// Verifies that the writer loop emits a websocket `Ping` frame once its ping interval
/// elapses, even though no server event is ever queued in this test.
#[tokio::test]
async fn run_server_writer_inner_sends_periodic_ping_frames() {
    let ping_every = Duration::from_millis(20);
    let receive_deadline = Duration::from_secs(5);

    let (client_side, mut server_side) = connected_websocket_pair().await;
    // Keep the read half bound for the whole test so it is not dropped early.
    let (writer_half, _reader_half) = client_side.split();

    let (outbound_buffer, used_rx) = BoundedOutboundBuffer::new();
    let websocket_state = WebsocketState {
        outbound_buffer,
        subscribe_cursor: None,
        next_seq_id: 0,
    };
    // The sender is kept alive (but unused) so the receiver never reports a closed channel.
    let (_event_tx, event_rx) = mpsc::channel(super::super::CHANNEL_CAPACITY);

    let cancel = CancellationToken::new();
    let writer_cancel = cancel.clone();
    let writer_handle = tokio::spawn(RemoteControlWebsocket::run_server_writer_inner(
        Arc::new(Mutex::new(websocket_state)),
        Arc::new(Mutex::new(event_rx)),
        used_rx,
        writer_half,
        ping_every,
        writer_cancel,
    ));

    // The first frame observed by the server side must be the periodic ping.
    let message = timeout(receive_deadline, server_side.next())
        .await
        .expect("ping frame should arrive in time")
        .expect("server websocket should stay open")
        .expect("ping frame should read");
    assert!(matches!(message, tungstenite::Message::Ping(_)));

    // Cancelling the token must let the writer return `Ok(())`.
    cancel.cancel();
    let writer_outcome = writer_handle.await.expect("writer task should join");
    writer_outcome.expect("writer should stop cleanly");
}
/// Verifies that the reader loop fails with `ErrorKind::TimedOut` when no `Pong` frame
/// arrives before the configured pong timeout expires.
#[tokio::test]
async fn run_websocket_reader_inner_times_out_without_pong_frames() {
    let pong_wait = Duration::from_millis(100);
    let overall_deadline = Duration::from_secs(5);

    // The server side and the write half stay bound (but idle) so the connection remains
    // open and the only way out of the reader loop is the pong deadline.
    let (client_side, _server_side) = connected_websocket_pair().await;
    let (_writer_half, reader_half) = client_side.split();

    let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new();
    let websocket_state = Arc::new(Mutex::new(WebsocketState {
        outbound_buffer,
        subscribe_cursor: None,
        next_seq_id: 0,
    }));

    let (server_event_tx, _server_event_rx) = mpsc::channel(super::super::CHANNEL_CAPACITY);
    let (transport_event_tx, _transport_event_rx) =
        mpsc::channel(super::super::CHANNEL_CAPACITY);
    let cancel = CancellationToken::new();
    let tracker = ClientTracker::new(server_event_tx, transport_event_tx, &cancel);
    let tracker = Arc::new(Mutex::new(tracker));

    let reader = RemoteControlWebsocket::run_websocket_reader_inner(
        tracker,
        websocket_state,
        reader_half,
        pong_wait,
        cancel,
    );
    let reader_outcome = timeout(overall_deadline, reader)
        .await
        .expect("reader should time out waiting for pong");
    let failure = reader_outcome.expect_err("missing pong should fail the websocket reader");

    assert_eq!(failure.kind(), ErrorKind::TimedOut);
    assert_eq!(failure.to_string(), "remote control websocket pong timeout");
}
async fn accept_http_request(listener: &TcpListener) -> (TcpStream, String) {
let (stream, _) = timeout(Duration::from_secs(5), listener.accept())
.await
@@ -1081,6 +1226,34 @@ mod tests {
)
}
async fn connected_websocket_pair() -> (
WebSocketStream<MaybeTlsStream<TcpStream>>,
WebSocketStream<TcpStream>,
) {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let connect_task = tokio::spawn(connect_async(format!(
"ws://{}",
listener
.local_addr()
.expect("listener should have a local addr")
)));
let (server_stream, _) = listener
.accept()
.await
.expect("server should accept client");
let server_stream = accept_async(server_stream)
.await
.expect("server websocket handshake should succeed");
let (client_stream, _) = connect_task
.await
.expect("client connect task should join")
.expect("client websocket handshake should succeed");
(client_stream, server_stream)
}
async fn respond_with_status_and_headers(
mut stream: TcpStream,
status: &str,