mirror of
https://github.com/openai/codex.git
synced 2026-04-27 09:51:03 +03:00
add logging & add regular pings
This commit is contained in:
@@ -7,6 +7,7 @@ use codex_state::StateRuntime;
|
||||
use gethostname::gethostname;
|
||||
use std::io;
|
||||
use std::io::ErrorKind;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
const REMOTE_CONTROL_ENROLL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
|
||||
@@ -36,26 +37,48 @@ pub(super) async fn load_persisted_remote_control_enrollment(
|
||||
remote_control_target: &RemoteControlTarget,
|
||||
account_id: Option<&str>,
|
||||
) -> Option<RemoteControlEnrollment> {
|
||||
let state_db = state_db?;
|
||||
let Some(state_db) = state_db else {
|
||||
info!(
|
||||
"remote control enrollment cache unavailable because sqlite state db is disabled: websocket_url={}, account_id={:?}",
|
||||
remote_control_target.websocket_url, account_id
|
||||
);
|
||||
return None;
|
||||
};
|
||||
let enrollment = match state_db
|
||||
.get_remote_control_enrollment(&remote_control_target.websocket_url, account_id)
|
||||
.await
|
||||
{
|
||||
Ok(enrollment) => enrollment,
|
||||
Err(err) => {
|
||||
warn!("{err}");
|
||||
warn!(
|
||||
"failed to load persisted remote control enrollment: websocket_url={}, account_id={:?}, err={err}",
|
||||
remote_control_target.websocket_url, account_id
|
||||
);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
enrollment.map(
|
||||
|(server_id, environment_id, server_name)| RemoteControlEnrollment {
|
||||
account_id: account_id.map(&str::to_string),
|
||||
environment_id,
|
||||
server_id,
|
||||
server_name,
|
||||
},
|
||||
)
|
||||
match enrollment {
|
||||
Some((server_id, environment_id, server_name)) => {
|
||||
info!(
|
||||
"reusing persisted remote control enrollment: websocket_url={}, account_id={:?}, server_id={}, environment_id={}",
|
||||
remote_control_target.websocket_url, account_id, server_id, environment_id
|
||||
);
|
||||
Some(RemoteControlEnrollment {
|
||||
account_id: account_id.map(&str::to_string),
|
||||
environment_id,
|
||||
server_id,
|
||||
server_name,
|
||||
})
|
||||
}
|
||||
None => {
|
||||
info!(
|
||||
"no persisted remote control enrollment found: websocket_url={}, account_id={:?}",
|
||||
remote_control_target.websocket_url, account_id
|
||||
);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn update_persisted_remote_control_enrollment(
|
||||
@@ -65,6 +88,12 @@ pub(super) async fn update_persisted_remote_control_enrollment(
|
||||
enrollment: Option<&RemoteControlEnrollment>,
|
||||
) -> io::Result<()> {
|
||||
let Some(state_db) = state_db else {
|
||||
info!(
|
||||
"skipping remote control enrollment persistence because sqlite state db is disabled: websocket_url={}, account_id={:?}, has_enrollment={}",
|
||||
remote_control_target.websocket_url,
|
||||
account_id,
|
||||
enrollment.is_some()
|
||||
);
|
||||
return Ok(());
|
||||
};
|
||||
if let &Some(enrollment) = &enrollment
|
||||
@@ -85,13 +114,25 @@ pub(super) async fn update_persisted_remote_control_enrollment(
|
||||
&enrollment.server_name,
|
||||
)
|
||||
.await
|
||||
.map_err(io::Error::other)
|
||||
.map_err(io::Error::other)?;
|
||||
info!(
|
||||
"persisted remote control enrollment: websocket_url={}, account_id={:?}, server_id={}, environment_id={}",
|
||||
remote_control_target.websocket_url,
|
||||
account_id,
|
||||
enrollment.server_id,
|
||||
enrollment.environment_id
|
||||
);
|
||||
Ok(())
|
||||
} else {
|
||||
state_db
|
||||
let rows_affected = state_db
|
||||
.delete_remote_control_enrollment(&remote_control_target.websocket_url, account_id)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(io::Error::other)
|
||||
.map_err(io::Error::other)?;
|
||||
info!(
|
||||
"cleared persisted remote control enrollment: websocket_url={}, account_id={:?}, rows_affected={rows_affected}",
|
||||
remote_control_target.websocket_url, account_id
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -49,6 +49,10 @@ use tracing::warn;
|
||||
pub(super) const REMOTE_CONTROL_PROTOCOL_VERSION: &str = "2";
|
||||
pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id";
|
||||
const REMOTE_CONTROL_SUBSCRIBE_CURSOR_HEADER: &str = "x-codex-subscribe-cursor";
|
||||
const REMOTE_CONTROL_WEBSOCKET_PING_INTERVAL: std::time::Duration =
|
||||
std::time::Duration::from_secs(10);
|
||||
const REMOTE_CONTROL_WEBSOCKET_PONG_TIMEOUT: std::time::Duration =
|
||||
std::time::Duration::from_secs(60);
|
||||
|
||||
struct BoundedOutboundBuffer {
|
||||
buffer_by_client: HashMap<(ClientId, StreamId), BTreeMap<u64, ServerEnvelope>>,
|
||||
@@ -235,12 +239,14 @@ impl RemoteControlWebsocket {
|
||||
self.server_event_rx.clone(),
|
||||
self.used_rx.clone(),
|
||||
websocket_writer,
|
||||
REMOTE_CONTROL_WEBSOCKET_PING_INTERVAL,
|
||||
shutdown_token.clone(),
|
||||
));
|
||||
join_set.spawn(Self::run_websocket_reader(
|
||||
self.client_tracker.clone(),
|
||||
self.state.clone(),
|
||||
websocket_reader,
|
||||
REMOTE_CONTROL_WEBSOCKET_PONG_TIMEOUT,
|
||||
shutdown_token.clone(),
|
||||
));
|
||||
|
||||
@@ -260,6 +266,7 @@ impl RemoteControlWebsocket {
|
||||
WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
tungstenite::Message,
|
||||
>,
|
||||
ping_interval: std::time::Duration,
|
||||
shutdown_token: CancellationToken,
|
||||
) {
|
||||
let result = Self::run_server_writer_inner(
|
||||
@@ -267,6 +274,7 @@ impl RemoteControlWebsocket {
|
||||
server_event_rx,
|
||||
used_rx,
|
||||
websocket_writer,
|
||||
ping_interval,
|
||||
shutdown_token,
|
||||
)
|
||||
.await;
|
||||
@@ -285,6 +293,7 @@ impl RemoteControlWebsocket {
|
||||
WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
tungstenite::Message,
|
||||
>,
|
||||
ping_interval: std::time::Duration,
|
||||
shutdown_token: CancellationToken,
|
||||
) -> io::Result<()> {
|
||||
for server_envelope in state.lock().await.outbound_buffer.server_envelopes() {
|
||||
@@ -305,15 +314,35 @@ impl RemoteControlWebsocket {
|
||||
};
|
||||
}
|
||||
|
||||
let mut ping_interval =
|
||||
tokio::time::interval_at(tokio::time::Instant::now() + ping_interval, ping_interval);
|
||||
ping_interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
|
||||
|
||||
let mut server_event_rx = server_event_rx.lock().await;
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_token.cancelled() => return Ok(()),
|
||||
_ = used_rx.wait_for(|used| *used < super::CHANNEL_CAPACITY) => {}
|
||||
};
|
||||
let outbound_has_capacity = *used_rx.borrow() < super::CHANNEL_CAPACITY;
|
||||
let queued_server_envelope = tokio::select! {
|
||||
_ = shutdown_token.cancelled() => return Ok(()),
|
||||
recv_result = server_event_rx.recv() => {
|
||||
_ = ping_interval.tick() => {
|
||||
if let Err(err) = websocket_writer
|
||||
.send(tungstenite::Message::Ping(Vec::new().into()))
|
||||
.await
|
||||
{
|
||||
return Err(io::Error::other(err));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
wait_result = used_rx.changed(), if !outbound_has_capacity =>
|
||||
{
|
||||
if wait_result.is_err() {
|
||||
return Err(io::Error::new(
|
||||
ErrorKind::UnexpectedEof,
|
||||
"outbound buffer usage channel closed",
|
||||
));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
recv_result = server_event_rx.recv(), if outbound_has_capacity => {
|
||||
match recv_result {
|
||||
Some(queued_server_envelope) => queued_server_envelope,
|
||||
None => {
|
||||
@@ -364,12 +393,14 @@ impl RemoteControlWebsocket {
|
||||
client_tracker: Arc<Mutex<ClientTracker>>,
|
||||
state: Arc<Mutex<WebsocketState>>,
|
||||
websocket_reader: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
|
||||
pong_timeout: std::time::Duration,
|
||||
shutdown_token: CancellationToken,
|
||||
) {
|
||||
let result = Self::run_websocket_reader_inner(
|
||||
client_tracker,
|
||||
state,
|
||||
websocket_reader,
|
||||
pong_timeout,
|
||||
shutdown_token,
|
||||
)
|
||||
.await;
|
||||
@@ -384,15 +415,24 @@ impl RemoteControlWebsocket {
|
||||
client_tracker: Arc<Mutex<ClientTracker>>,
|
||||
state: Arc<Mutex<WebsocketState>>,
|
||||
mut websocket_reader: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
|
||||
pong_timeout: std::time::Duration,
|
||||
shutdown_token: CancellationToken,
|
||||
) -> io::Result<()> {
|
||||
let mut client_tracker = client_tracker.lock().await;
|
||||
let mut idle_sweep_interval = tokio::time::interval(REMOTE_CONTROL_IDLE_SWEEP_INTERVAL);
|
||||
idle_sweep_interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
|
||||
let pong_deadline = tokio::time::sleep(pong_timeout);
|
||||
tokio::pin!(pong_deadline);
|
||||
|
||||
loop {
|
||||
let incoming_message = tokio::select! {
|
||||
_ = shutdown_token.cancelled() => return Ok(()),
|
||||
_ = &mut pong_deadline => {
|
||||
return Err(io::Error::new(
|
||||
ErrorKind::TimedOut,
|
||||
"remote control websocket pong timeout",
|
||||
));
|
||||
}
|
||||
client_id = client_tracker.bookkeep_join_set() => {
|
||||
let Some(client_id) = client_id else {
|
||||
continue;
|
||||
@@ -432,9 +472,13 @@ impl RemoteControlWebsocket {
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(tungstenite::Message::Ping(_))
|
||||
| Ok(tungstenite::Message::Pong(_))
|
||||
| Ok(tungstenite::Message::Frame(_)) => continue,
|
||||
Ok(tungstenite::Message::Pong(_)) => {
|
||||
pong_deadline
|
||||
.as_mut()
|
||||
.reset(tokio::time::Instant::now() + pong_timeout);
|
||||
continue;
|
||||
}
|
||||
Ok(tungstenite::Message::Ping(_)) | Ok(tungstenite::Message::Frame(_)) => continue,
|
||||
Ok(tungstenite::Message::Binary(_)) => {
|
||||
warn!("dropping unsupported binary remote-control websocket message");
|
||||
continue;
|
||||
@@ -601,11 +645,16 @@ pub(super) async fn connect_remote_control_websocket(
|
||||
ensure_rustls_crypto_provider();
|
||||
|
||||
let auth = load_remote_control_auth(auth_manager).await?;
|
||||
if auth.account_id.as_ref()
|
||||
!= enrollment
|
||||
.as_ref()
|
||||
.and_then(|enrollment| enrollment.account_id.as_ref())
|
||||
{
|
||||
let enrollment_account_id = enrollment
|
||||
.as_ref()
|
||||
.and_then(|enrollment| enrollment.account_id.clone());
|
||||
if auth.account_id.as_deref() != enrollment_account_id.as_deref() {
|
||||
info!(
|
||||
"clearing in-memory remote control enrollment because account id changed: websocket_url={}, previous_account_id={:?}, current_account_id={:?}",
|
||||
remote_control_target.websocket_url,
|
||||
enrollment_account_id.as_deref(),
|
||||
auth.account_id.as_deref()
|
||||
);
|
||||
*enrollment = None;
|
||||
}
|
||||
|
||||
@@ -619,6 +668,12 @@ pub(super) async fn connect_remote_control_websocket(
|
||||
}
|
||||
|
||||
if enrollment.is_none() {
|
||||
info!(
|
||||
"creating new remote control enrollment: websocket_url={}, enroll_url={}, account_id={:?}",
|
||||
remote_control_target.websocket_url,
|
||||
remote_control_target.enroll_url,
|
||||
auth.account_id.as_deref()
|
||||
);
|
||||
let new_enrollment = match enroll_remote_control_server(remote_control_target, &auth).await
|
||||
{
|
||||
Ok(new_enrollment) => new_enrollment,
|
||||
@@ -642,6 +697,13 @@ pub(super) async fn connect_remote_control_websocket(
|
||||
{
|
||||
warn!("failed to persist remote control enrollment in sqlite state db: {err}");
|
||||
}
|
||||
info!(
|
||||
"created new remote control enrollment: websocket_url={}, account_id={:?}, server_id={}, environment_id={}",
|
||||
remote_control_target.websocket_url,
|
||||
new_enrollment.account_id.as_deref(),
|
||||
new_enrollment.server_id,
|
||||
new_enrollment.environment_id
|
||||
);
|
||||
*enrollment = Some(new_enrollment);
|
||||
}
|
||||
|
||||
@@ -660,6 +722,13 @@ pub(super) async fn connect_remote_control_websocket(
|
||||
Err(err) => {
|
||||
match &err {
|
||||
tungstenite::Error::Http(response) if response.status().as_u16() == 404 => {
|
||||
info!(
|
||||
"remote control websocket returned HTTP 404; clearing stale enrollment before re-enrolling: websocket_url={}, account_id={:?}, server_id={}, environment_id={}",
|
||||
remote_control_target.websocket_url,
|
||||
auth.account_id.as_deref(),
|
||||
enrollment_ref.server_id,
|
||||
enrollment_ref.environment_id
|
||||
);
|
||||
if let Err(clear_err) = update_persisted_remote_control_enrollment(
|
||||
state_db,
|
||||
remote_control_target,
|
||||
@@ -753,6 +822,7 @@ mod tests {
|
||||
use codex_login::token_data::TokenData;
|
||||
use codex_login::token_data::parse_chatgpt_jwt_claims;
|
||||
use codex_state::StateRuntime;
|
||||
use futures::StreamExt;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::sync::Arc;
|
||||
use tempfile::TempDir;
|
||||
@@ -764,6 +834,7 @@ mod tests {
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
use tokio_tungstenite::accept_async;
|
||||
|
||||
async fn remote_control_state_runtime(codex_home: &TempDir) -> Arc<StateRuntime> {
|
||||
StateRuntime::init(codex_home.path().to_path_buf(), "test-provider".to_string())
|
||||
@@ -1052,6 +1123,80 @@ mod tests {
|
||||
.expect("websocket task should join");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_server_writer_inner_sends_periodic_ping_frames() {
|
||||
let (client_stream, mut server_stream) = connected_websocket_pair().await;
|
||||
let (websocket_writer, _websocket_reader) = client_stream.split();
|
||||
let (outbound_buffer, used_rx) = BoundedOutboundBuffer::new();
|
||||
let state = Arc::new(Mutex::new(WebsocketState {
|
||||
outbound_buffer,
|
||||
subscribe_cursor: None,
|
||||
next_seq_id: 0,
|
||||
}));
|
||||
let (_server_event_tx, server_event_rx) = mpsc::channel(super::super::CHANNEL_CAPACITY);
|
||||
let server_event_rx = Arc::new(Mutex::new(server_event_rx));
|
||||
let shutdown_token = CancellationToken::new();
|
||||
let writer_task = tokio::spawn(RemoteControlWebsocket::run_server_writer_inner(
|
||||
state,
|
||||
server_event_rx,
|
||||
used_rx,
|
||||
websocket_writer,
|
||||
Duration::from_millis(20),
|
||||
shutdown_token.clone(),
|
||||
));
|
||||
|
||||
let message = timeout(Duration::from_secs(5), server_stream.next())
|
||||
.await
|
||||
.expect("ping frame should arrive in time")
|
||||
.expect("server websocket should stay open")
|
||||
.expect("ping frame should read");
|
||||
assert!(matches!(message, tungstenite::Message::Ping(_)));
|
||||
|
||||
shutdown_token.cancel();
|
||||
writer_task
|
||||
.await
|
||||
.expect("writer task should join")
|
||||
.expect("writer should stop cleanly");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_websocket_reader_inner_times_out_without_pong_frames() {
|
||||
let (client_stream, _server_stream) = connected_websocket_pair().await;
|
||||
let (_websocket_writer, websocket_reader) = client_stream.split();
|
||||
let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new();
|
||||
let state = Arc::new(Mutex::new(WebsocketState {
|
||||
outbound_buffer,
|
||||
subscribe_cursor: None,
|
||||
next_seq_id: 0,
|
||||
}));
|
||||
let (server_event_tx, _server_event_rx) = mpsc::channel(super::super::CHANNEL_CAPACITY);
|
||||
let (transport_event_tx, _transport_event_rx) =
|
||||
mpsc::channel(super::super::CHANNEL_CAPACITY);
|
||||
let shutdown_token = CancellationToken::new();
|
||||
let client_tracker = Arc::new(Mutex::new(ClientTracker::new(
|
||||
server_event_tx,
|
||||
transport_event_tx,
|
||||
&shutdown_token,
|
||||
)));
|
||||
|
||||
let err = timeout(
|
||||
Duration::from_secs(5),
|
||||
RemoteControlWebsocket::run_websocket_reader_inner(
|
||||
client_tracker,
|
||||
state,
|
||||
websocket_reader,
|
||||
Duration::from_millis(100),
|
||||
shutdown_token,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("reader should time out waiting for pong")
|
||||
.expect_err("missing pong should fail the websocket reader");
|
||||
|
||||
assert_eq!(err.kind(), ErrorKind::TimedOut);
|
||||
assert_eq!(err.to_string(), "remote control websocket pong timeout");
|
||||
}
|
||||
|
||||
async fn accept_http_request(listener: &TcpListener) -> (TcpStream, String) {
|
||||
let (stream, _) = timeout(Duration::from_secs(5), listener.accept())
|
||||
.await
|
||||
@@ -1081,6 +1226,34 @@ mod tests {
|
||||
)
|
||||
}
|
||||
|
||||
async fn connected_websocket_pair() -> (
|
||||
WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
WebSocketStream<TcpStream>,
|
||||
) {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("listener should bind");
|
||||
let connect_task = tokio::spawn(connect_async(format!(
|
||||
"ws://{}",
|
||||
listener
|
||||
.local_addr()
|
||||
.expect("listener should have a local addr")
|
||||
)));
|
||||
let (server_stream, _) = listener
|
||||
.accept()
|
||||
.await
|
||||
.expect("server should accept client");
|
||||
let server_stream = accept_async(server_stream)
|
||||
.await
|
||||
.expect("server websocket handshake should succeed");
|
||||
let (client_stream, _) = connect_task
|
||||
.await
|
||||
.expect("client connect task should join")
|
||||
.expect("client websocket handshake should succeed");
|
||||
|
||||
(client_stream, server_stream)
|
||||
}
|
||||
|
||||
async fn respond_with_status_and_headers(
|
||||
mut stream: TcpStream,
|
||||
status: &str,
|
||||
|
||||
Reference in New Issue
Block a user