Compare commits

...

1 Commits

Author SHA1 Message Date
Rasmus Rygaard
cf698a53f1 ping pong pump 2026-02-11 10:18:14 -08:00

View File

@@ -26,6 +26,7 @@ use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio::time::Instant;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::WebSocketStream;
@@ -45,9 +46,11 @@ type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
const X_MODELS_ETAG_HEADER: &str = "x-models-etag";
const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included";
const IDLE_CONTROL_FRAME_POLL_INTERVAL: Duration = Duration::from_millis(250);
pub struct ResponsesWebsocketConnection {
stream: Arc<Mutex<Option<WsStream>>>,
idle_control_task: Arc<Mutex<Option<JoinHandle<()>>>>,
// TODO (pakrym): is this the right place for timeout?
idle_timeout: Duration,
server_reasoning_included: bool,
@@ -63,8 +66,13 @@ impl ResponsesWebsocketConnection {
models_etag: Option<String>,
telemetry: Option<Arc<dyn WebsocketTelemetry>>,
) -> Self {
let stream = Arc::new(Mutex::new(Some(stream)));
let idle_control_task = Arc::new(Mutex::new(Some(tokio::spawn(run_idle_control_frames(
Arc::clone(&stream),
)))));
Self {
stream: Arc::new(Mutex::new(Some(stream))),
stream,
idle_control_task,
idle_timeout,
server_reasoning_included,
models_etag,
@@ -80,9 +88,12 @@ impl ResponsesWebsocketConnection {
&self,
request: ResponsesWsRequest,
) -> Result<ResponseStream, ApiError> {
stop_idle_control_task(&self.idle_control_task).await;
let (tx_event, rx_event) =
mpsc::channel::<std::result::Result<ResponseEvent, ApiError>>(1600);
let stream = Arc::clone(&self.stream);
let idle_control_task = Arc::clone(&self.idle_control_task);
let idle_timeout = self.idle_timeout;
let server_reasoning_included = self.server_reasoning_included;
let models_etag = self.models_etag.clone();
@@ -122,6 +133,8 @@ impl ResponsesWebsocketConnection {
let _ = ws_stream.close(None).await;
*guard = None;
let _ = tx_event.send(Err(err)).await;
} else {
spawn_idle_control_task_if_open(Arc::clone(&stream), idle_control_task).await;
}
});
@@ -129,6 +142,74 @@ impl ResponsesWebsocketConnection {
}
}
async fn stop_idle_control_task(idle_control_task: &Arc<Mutex<Option<JoinHandle<()>>>>) {
if let Some(handle) = idle_control_task.lock().await.take() {
handle.abort();
let _ = handle.await;
}
}
async fn spawn_idle_control_task_if_open(
stream: Arc<Mutex<Option<WsStream>>>,
idle_control_task: Arc<Mutex<Option<JoinHandle<()>>>>,
) {
let mut task_guard = idle_control_task.lock().await;
if task_guard.is_some() {
return;
}
*task_guard = Some(tokio::spawn(run_idle_control_frames(stream)));
}
async fn run_idle_control_frames(stream: Arc<Mutex<Option<WsStream>>>) {
loop {
let mut guard = stream.lock().await;
let Some(ws_stream) = guard.as_mut() else {
return;
};
let message =
match tokio::time::timeout(IDLE_CONTROL_FRAME_POLL_INTERVAL, ws_stream.next()).await {
Ok(Some(Ok(message))) => message,
Ok(Some(Err(err))) => {
debug!("idle websocket read failed: {err}");
let _ = ws_stream.close(None).await;
*guard = None;
return;
}
Ok(None) => {
*guard = None;
return;
}
Err(_) => {
continue;
}
};
match message {
Message::Ping(payload) => {
if ws_stream.send(Message::Pong(payload)).await.is_err() {
let _ = ws_stream.close(None).await;
*guard = None;
return;
}
}
Message::Pong(_) => {}
Message::Close(_) => {
let _ = ws_stream.close(None).await;
*guard = None;
return;
}
Message::Text(text) => {
debug!("dropping unexpected idle websocket event: {text}");
}
Message::Binary(_) => {
debug!("dropping unexpected idle websocket binary event");
}
_ => {}
}
}
}
pub struct ResponsesWebsocketClient<A: AuthProvider> {
provider: Provider,
auth: A,