diff --git a/codex-rs/codex-api/src/endpoint/mod.rs b/codex-rs/codex-api/src/endpoint/mod.rs index 981643904e..277fda5bfa 100644 --- a/codex-rs/codex-api/src/endpoint/mod.rs +++ b/codex-rs/codex-api/src/endpoint/mod.rs @@ -6,3 +6,4 @@ pub mod realtime_websocket; pub mod responses; pub mod responses_websocket; mod session; +mod websocket_pump; diff --git a/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs b/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs index e9a297de07..21faa849e7 100644 --- a/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs +++ b/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs @@ -7,21 +7,17 @@ use crate::endpoint::realtime_websocket::protocol::RealtimeSessionConfig; use crate::endpoint::realtime_websocket::protocol::SessionCreateSession; use crate::endpoint::realtime_websocket::protocol::SessionUpdateSession; use crate::endpoint::realtime_websocket::protocol::parse_realtime_event; +use crate::endpoint::websocket_pump::WebsocketMessage; +use crate::endpoint::websocket_pump::WebsocketPump; use crate::error::ApiError; use crate::provider::Provider; use codex_utils_rustls_provider::ensure_rustls_crypto_provider; -use futures::SinkExt; -use futures::StreamExt; use http::HeaderMap; use std::sync::Arc; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; -use tokio::net::TcpStream; use tokio::sync::Mutex; use tokio::sync::mpsc; -use tokio::sync::oneshot; -use tokio_tungstenite::MaybeTlsStream; -use tokio_tungstenite::WebSocketStream; use tokio_tungstenite::tungstenite::Error as WsError; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::tungstenite::client::IntoClientRequest; @@ -30,123 +26,6 @@ use tracing::trace; use tungstenite::protocol::WebSocketConfig; use url::Url; -struct WsStream { - tx_command: mpsc::Sender, - pump_task: tokio::task::JoinHandle<()>, -} - -enum WsCommand { - Send { - message: Message, - tx_result: oneshot::Sender>, - }, - Close { - tx_result: oneshot::Sender>, - }, -} - -impl WsStream { - fn new( - inner: WebSocketStream>, - ) -> (Self, mpsc::UnboundedReceiver>) { - let (tx_command, mut rx_command) = mpsc::channel::(32); - let (tx_message, rx_message) = mpsc::unbounded_channel::>(); - - let pump_task = tokio::spawn(async move { - let mut inner = inner; - loop { - tokio::select! { - command = rx_command.recv() => { - let Some(command) = command else { - break; - }; - match command { - WsCommand::Send { message, tx_result } => { - let result = inner.send(message).await; - let should_break = result.is_err(); - let _ = tx_result.send(result); - if should_break { - break; - } - } - WsCommand::Close { tx_result } => { - let result = inner.close(None).await; - let _ = tx_result.send(result); - break; - } - } - } - message = inner.next() => { - let Some(message) = message else { - break; - }; - match message { - Ok(Message::Ping(payload)) => { - if let Err(err) = inner.send(Message::Pong(payload)).await { - let _ = tx_message.send(Err(err)); - break; - } - } - Ok(Message::Pong(_)) => {} - Ok(message @ (Message::Text(_) - | Message::Binary(_) - | Message::Close(_) - | Message::Frame(_))) => { - let is_close = matches!(message, Message::Close(_)); - if tx_message.send(Ok(message)).is_err() { - break; - } - if is_close { - break; - } - } - Err(err) => { - let _ = tx_message.send(Err(err)); - break; - } - } - } - } - } - }); - - ( - Self { - tx_command, - pump_task, - }, - rx_message, - ) - } - - async fn request( - &self, - make_command: impl FnOnce(oneshot::Sender>) -> WsCommand, - ) -> Result<(), WsError> { - let (tx_result, rx_result) = oneshot::channel(); - if self.tx_command.send(make_command(tx_result)).await.is_err() { - return Err(WsError::ConnectionClosed); - } - rx_result.await.unwrap_or(Err(WsError::ConnectionClosed)) - } - - async fn send(&self, message: Message) -> Result<(), WsError> { - self.request(|tx_result| WsCommand::Send { message, tx_result }) - .await - } - - async fn close(&self) -> Result<(), WsError> { - self.request(|tx_result| WsCommand::Close { tx_result }) - .await - } -} - -impl Drop for WsStream { - fn drop(&mut self) { - self.pump_task.abort(); - } -} - pub struct RealtimeWebsocketConnection { writer: RealtimeWebsocketWriter, events: RealtimeWebsocketEvents, @@ -154,13 +33,13 @@ pub struct RealtimeWebsocketConnection { #[derive(Clone)] pub struct RealtimeWebsocketWriter { - stream: Arc, + stream: Arc, is_closed: Arc, } #[derive(Clone)] pub struct RealtimeWebsocketEvents { - rx_message: Arc>>>, + rx_message: Arc>>, is_closed: Arc, } @@ -209,10 +88,7 @@ impl RealtimeWebsocketConnection { self.events.clone() } - fn new( - stream: WsStream, - rx_message: mpsc::UnboundedReceiver>, - ) -> Self { + fn new(stream: WebsocketPump, rx_message: mpsc::UnboundedReceiver) -> Self { let stream = Arc::new(stream); let is_closed = Arc::new(AtomicBool::new(false)); Self { @@ -389,7 +265,7 @@ impl RealtimeWebsocketClient { ApiError::Stream(format!("failed to connect realtime websocket: {err}")) })?; - let (stream, rx_message) = WsStream::new(stream); + let (stream, rx_message) = WebsocketPump::new(stream); let connection = RealtimeWebsocketConnection::new(stream, rx_message); connection .send_session_create(config.prompt, config.session_id) @@ -445,6 +321,8 @@ fn websocket_url_from_api_url(api_url: &str) -> Result { #[cfg(test)] mod tests { use super::*; + use futures::SinkExt; + use futures::StreamExt; use http::HeaderValue; use pretty_assertions::assert_eq; use serde_json::Value; diff --git a/codex-rs/codex-api/src/endpoint/realtime_websocket/mod.rs b/codex-rs/codex-api/src/endpoint/realtime_websocket/mod.rs index 7b10c6abdb..469fea8dcf 100644 --- a/codex-rs/codex-api/src/endpoint/realtime_websocket/mod.rs +++ b/codex-rs/codex-api/src/endpoint/realtime_websocket/mod.rs @@ -6,6 +6,5 @@ pub use methods::RealtimeWebsocketConnection; pub use methods::RealtimeWebsocketEvents; pub use methods::RealtimeWebsocketWriter; pub use protocol::RealtimeAudioFrame; -pub use protocol::RealtimeConnectionState; pub use protocol::RealtimeEvent; pub use protocol::RealtimeSessionConfig; diff --git a/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol.rs b/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol.rs index f14df8bbb3..8b2b9b4d2a 100644 --- a/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol.rs +++ b/codex-rs/codex-api/src/endpoint/realtime_websocket/protocol.rs @@ -19,15 +19,8 @@ pub struct RealtimeAudioFrame { pub samples_per_channel: Option, } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum RealtimeConnectionState { - Connected, - Disconnected, -} - #[derive(Debug, Clone, PartialEq, Eq)] pub enum RealtimeEvent { - State(RealtimeConnectionState), SessionCreated { session_id: String }, SessionUpdated { backend_prompt: Option }, AudioOut(RealtimeAudioFrame), @@ -86,84 +79,68 @@ pub(super) struct ConversationItemContent { pub(super) text: String, } -#[derive(Debug, Deserialize)] -#[serde(tag = "type")] -enum RealtimeInboundMessage { - #[serde(rename = "session.created")] - SessionCreated { - session_id: Option, - session: Option, - }, - #[serde(rename = "session.updated")] - SessionUpdated { - session: Option, - }, - #[serde(rename = "response.output_audio.delta")] - OutputAudioDelta { - delta: Option, - data: Option, - sample_rate: Option, - num_channels: Option, - samples_per_channel: Option, - }, - #[serde(rename = "conversation.item.added")] - ConversationItemAdded { item: Option }, - #[serde(rename = "error")] - Error { - error: Option, - message: Option, - }, -} - -#[derive(Debug, Deserialize)] -struct RealtimeInboundSession { - id: Option, - backend_prompt: Option, -} - pub(super) fn parse_realtime_event(payload: &str) -> Option { - let parsed: RealtimeInboundMessage = match serde_json::from_str(payload) { - Ok(msg) => msg, + let parsed: Value = match serde_json::from_str(payload) { + Ok(value) => value, Err(err) => { debug!("failed to parse realtime event: {err}, data: {payload}"); return None; } }; - match parsed { - RealtimeInboundMessage::SessionCreated { - session_id, - session, - } => { - let session_id = session.and_then(|s| s.id).or(session_id); - session_id.map(|id| RealtimeEvent::SessionCreated { session_id: id }) + let event_type = parsed.get("type")?.as_str()?; + match event_type { + "session.created" => { + let session_id = parsed + .pointer("/session/id") + .and_then(Value::as_str) + .or_else(|| parsed.get("session_id").and_then(Value::as_str))?; + Some(RealtimeEvent::SessionCreated { + session_id: session_id.to_string(), + }) } - RealtimeInboundMessage::SessionUpdated { session } => Some(RealtimeEvent::SessionUpdated { - backend_prompt: session.and_then(|s| s.backend_prompt), + "session.updated" => Some(RealtimeEvent::SessionUpdated { + backend_prompt: parsed + .pointer("/session/backend_prompt") + .and_then(Value::as_str) + .map(ToString::to_string), }), - RealtimeInboundMessage::OutputAudioDelta { - delta, - data, - sample_rate, - num_channels, - samples_per_channel, - } => { - let data = delta.or(data)?; - let sample_rate = sample_rate?; - let num_channels = num_channels?; + "response.output_audio.delta" => { + let data = parsed + .get("delta") + .and_then(Value::as_str) + .or_else(|| parsed.get("data").and_then(Value::as_str))?; + let sample_rate = parsed + .get("sample_rate") + .and_then(Value::as_u64) + .and_then(|value| u32::try_from(value).ok())?; + let num_channels = parsed + .get("num_channels") + .and_then(Value::as_u64) + .and_then(|value| u16::try_from(value).ok())?; + let samples_per_channel = parsed + .get("samples_per_channel") + .and_then(Value::as_u64) + .and_then(|value| u32::try_from(value).ok()); Some(RealtimeEvent::AudioOut(RealtimeAudioFrame { - data, + data: data.to_string(), sample_rate, num_channels, samples_per_channel, })) } - RealtimeInboundMessage::ConversationItemAdded { item } => { - item.map(RealtimeEvent::ConversationItemAdded) - } - RealtimeInboundMessage::Error { error, message } => { - let message = message.or_else(|| error.map(|e| e.to_string()))?; + "conversation.item.added" => parsed + .get("item") + .cloned() + .map(RealtimeEvent::ConversationItemAdded), + "error" => { + let message = parsed + .get("message") + .and_then(Value::as_str) + .map(ToString::to_string) + .or_else(|| parsed.get("error").map(ToString::to_string))?; Some(RealtimeEvent::Error(message)) } + _ => None, } } diff --git a/codex-rs/codex-api/src/endpoint/responses_websocket.rs b/codex-rs/codex-api/src/endpoint/responses_websocket.rs index aa559e9836..419e2ef621 100644 --- a/codex-rs/codex-api/src/endpoint/responses_websocket.rs +++ b/codex-rs/codex-api/src/endpoint/responses_websocket.rs @@ -3,6 +3,8 @@ use crate::auth::add_auth_headers_to_header_map; use crate::common::ResponseEvent; use crate::common::ResponseStream; use crate::common::ResponsesWsRequest; +use crate::endpoint::websocket_pump::WebsocketMessage; +use crate::endpoint::websocket_pump::WebsocketPump; use crate::error::ApiError; use crate::provider::Provider; use crate::rate_limits::parse_rate_limit_event; @@ -11,8 +13,6 @@ use crate::sse::responses::process_responses_event; use crate::telemetry::WebsocketTelemetry; use codex_client::TransportError; use codex_utils_rustls_provider::ensure_rustls_crypto_provider; -use futures::SinkExt; -use futures::StreamExt; use http::HeaderMap; use http::HeaderName; use http::HeaderValue; @@ -23,10 +23,8 @@ use serde_json::map::Map as JsonMap; use std::sync::Arc; use std::sync::OnceLock; use std::time::Duration; -use tokio::net::TcpStream; use tokio::sync::Mutex; use tokio::sync::mpsc; -use tokio::sync::oneshot; use tokio::time::Instant; use tokio_tungstenite::MaybeTlsStream; use tokio_tungstenite::WebSocketStream; @@ -43,110 +41,22 @@ use tungstenite::protocol::WebSocketConfig; use url::Url; struct WsStream { - tx_command: mpsc::Sender, - rx_message: mpsc::UnboundedReceiver>, - pump_task: tokio::task::JoinHandle<()>, -} - -enum WsCommand { - Send { - message: Message, - tx_result: oneshot::Sender>, - }, - Close { - tx_result: oneshot::Sender>, - }, + pump: WebsocketPump, + rx_message: mpsc::UnboundedReceiver, } impl WsStream { - fn new(inner: WebSocketStream>) -> Self { - let (tx_command, mut rx_command) = mpsc::channel::(32); - let (tx_message, rx_message) = mpsc::unbounded_channel::>(); - - let pump_task = tokio::spawn(async move { - let mut inner = inner; - loop { - tokio::select! { - command = rx_command.recv() => { - let Some(command) = command else { - break; - }; - match command { - WsCommand::Send { message, tx_result } => { - let result = inner.send(message).await; - let should_break = result.is_err(); - let _ = tx_result.send(result); - if should_break { - break; - } - } - WsCommand::Close { tx_result } => { - let result = inner.close(None).await; - let _ = tx_result.send(result); - break; - } - } - } - message = inner.next() => { - let Some(message) = message else { - break; - }; - match message { - Ok(Message::Ping(payload)) => { - if let Err(err) = inner.send(Message::Pong(payload)).await { - let _ = tx_message.send(Err(err)); - break; - } - } - Ok(Message::Pong(_)) => {} - Ok(message @ (Message::Text(_) - | Message::Binary(_) - | Message::Close(_) - | Message::Frame(_))) => { - let is_close = matches!(message, Message::Close(_)); - if tx_message.send(Ok(message)).is_err() { - break; - } - if is_close { - break; - } - } - Err(err) => { - let _ = tx_message.send(Err(err)); - break; - } - } - } - } - } - }); - - Self { - tx_command, - rx_message, - pump_task, - } - } - - async fn request( - &self, - make_command: impl FnOnce(oneshot::Sender>) -> WsCommand, - ) -> Result<(), WsError> { - let (tx_result, rx_result) = oneshot::channel(); - if self.tx_command.send(make_command(tx_result)).await.is_err() { - return Err(WsError::ConnectionClosed); - } - rx_result.await.unwrap_or(Err(WsError::ConnectionClosed)) + fn new(inner: WebSocketStream>) -> Self { + let (pump, rx_message) = WebsocketPump::new(inner); + Self { pump, rx_message } } async fn send(&self, message: Message) -> Result<(), WsError> { - self.request(|tx_result| WsCommand::Send { message, tx_result }) - .await + self.pump.send(message).await } async fn close(&self) -> Result<(), WsError> { - self.request(|tx_result| WsCommand::Close { tx_result }) - .await + self.pump.close().await } async fn next(&mut self) -> Option> { @@ -154,12 +64,6 @@ impl WsStream { } } -impl Drop for WsStream { - fn drop(&mut self) { - self.pump_task.abort(); - } -} - const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state"; const X_MODELS_ETAG_HEADER: &str = "x-models-etag"; const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included"; diff --git a/codex-rs/codex-api/src/endpoint/websocket_pump.rs b/codex-rs/codex-api/src/endpoint/websocket_pump.rs new file mode 100644 index 0000000000..d7523bbeb0 --- /dev/null +++ b/codex-rs/codex-api/src/endpoint/websocket_pump.rs @@ -0,0 +1,128 @@ +use futures::SinkExt; +use futures::StreamExt; +use tokio::net::TcpStream; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio_tungstenite::MaybeTlsStream; +use tokio_tungstenite::WebSocketStream; +use tokio_tungstenite::tungstenite::Error as WsError; +use tokio_tungstenite::tungstenite::Message; + +pub(crate) type WebsocketMessage = Result; + +pub(crate) struct WebsocketPump { + tx_command: mpsc::Sender, + pump_task: tokio::task::JoinHandle<()>, +} + +enum WsCommand { + Send { + message: Message, + tx_result: oneshot::Sender>, + }, + Close { + tx_result: oneshot::Sender>, + }, +} + +impl WebsocketPump { + pub(crate) fn new( + inner: WebSocketStream>, + ) -> (Self, mpsc::UnboundedReceiver) { + let (tx_command, mut rx_command) = mpsc::channel::(32); + let (tx_message, rx_message) = mpsc::unbounded_channel::(); + + let pump_task = tokio::spawn(async move { + let mut inner = inner; + loop { + tokio::select! { + command = rx_command.recv() => { + let Some(command) = command else { + break; + }; + match command { + WsCommand::Send { message, tx_result } => { + let result = inner.send(message).await; + let should_break = result.is_err(); + let _ = tx_result.send(result); + if should_break { + break; + } + } + WsCommand::Close { tx_result } => { + let result = inner.close(None).await; + let _ = tx_result.send(result); + break; + } + } + } + message = inner.next() => { + let Some(message) = message else { + break; + }; + match message { + Ok(Message::Ping(payload)) => { + if let Err(err) = inner.send(Message::Pong(payload)).await { + let _ = tx_message.send(Err(err)); + break; + } + } + Ok(Message::Pong(_)) => {} + Ok(message @ (Message::Text(_) + | Message::Binary(_) + | Message::Close(_) + | Message::Frame(_))) => { + let is_close = matches!(message, Message::Close(_)); + if tx_message.send(Ok(message)).is_err() { + break; + } + if is_close { + break; + } + } + Err(err) => { + let _ = tx_message.send(Err(err)); + break; + } + } + } + } + } + }); + + ( + Self { + tx_command, + pump_task, + }, + rx_message, + ) + } + + pub(crate) async fn send(&self, message: Message) -> Result<(), WsError> { + self.request(|tx_result| WsCommand::Send { message, tx_result }) + .await + } + + pub(crate) async fn close(&self) -> Result<(), WsError> { + self.request(|tx_result| WsCommand::Close { tx_result }) + .await + } + + async fn request( + &self, + make_command: impl FnOnce(oneshot::Sender>) -> WsCommand, + ) -> Result<(), WsError> { + let (tx_result, rx_result) = oneshot::channel(); + if self.tx_command.send(make_command(tx_result)).await.is_err() { + return Err(WsError::ConnectionClosed); + } + rx_result.await.unwrap_or(Err(WsError::ConnectionClosed)) + } +} + +impl Drop for WebsocketPump { + fn drop(&mut self) { + self.pump_task.abort(); + } +} diff --git a/codex-rs/codex-api/src/lib.rs b/codex-rs/codex-api/src/lib.rs index 99d2c406ab..ff8953c034 100644 --- a/codex-rs/codex-api/src/lib.rs +++ b/codex-rs/codex-api/src/lib.rs @@ -30,7 +30,6 @@ pub use crate::endpoint::compact::CompactClient; pub use crate::endpoint::memories::MemoriesClient; pub use crate::endpoint::models::ModelsClient; pub use crate::endpoint::realtime_websocket::RealtimeAudioFrame; -pub use crate::endpoint::realtime_websocket::RealtimeConnectionState; pub use crate::endpoint::realtime_websocket::RealtimeEvent; pub use crate::endpoint::realtime_websocket::RealtimeSessionConfig; pub use crate::endpoint::realtime_websocket::RealtimeWebsocketClient; diff --git a/codex-rs/codex-api/tests/common/mod.rs b/codex-rs/codex-api/tests/common/mod.rs new file mode 100644 index 0000000000..773a258e81 --- /dev/null +++ b/codex-rs/codex-api/tests/common/mod.rs @@ -0,0 +1 @@ +pub mod ws_harness; diff --git a/codex-rs/codex-api/tests/common/ws_harness.rs b/codex-rs/codex-api/tests/common/ws_harness.rs new file mode 100644 index 0000000000..ed124ced12 --- /dev/null +++ b/codex-rs/codex-api/tests/common/ws_harness.rs @@ -0,0 +1,43 @@ +use std::collections::HashMap; +use std::future::Future; +use std::time::Duration; + +use codex_api::Provider; +use codex_api::provider::RetryConfig; +use http::HeaderMap; +use tokio::net::TcpListener; +use tokio::task::JoinHandle; +use tokio_tungstenite::WebSocketStream; +use tokio_tungstenite::accept_async; + +pub(crate) async fn spawn_ws_server(handler: F) -> (String, JoinHandle<()>) +where + F: FnOnce(WebSocketStream) -> Fut + Send + 'static, + Fut: Future + Send + 'static, +{ + let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind"); + let addr = listener.local_addr().expect("local addr"); + let server = tokio::spawn(async move { + let (stream, _) = listener.accept().await.expect("accept"); + let ws = accept_async(stream).await.expect("accept ws"); + handler(ws).await; + }); + (format!("ws://{addr}"), server) +} + +pub(crate) fn test_provider() -> Provider { + Provider { + name: "test".to_string(), + base_url: "http://localhost".to_string(), + query_params: Some(HashMap::new()), + headers: HeaderMap::new(), + retry: RetryConfig { + max_attempts: 1, + base_delay: Duration::from_millis(1), + retry_429: false, + retry_5xx: false, + retry_transport: false, + }, + stream_idle_timeout: Duration::from_secs(5), + } +} diff --git a/codex-rs/codex-api/tests/realtime_websocket_e2e.rs b/codex-rs/codex-api/tests/realtime_websocket_e2e.rs index b9d252d3a8..e5e477bfbb 100644 --- a/codex-rs/codex-api/tests/realtime_websocket_e2e.rs +++ b/codex-rs/codex-api/tests/realtime_websocket_e2e.rs @@ -1,30 +1,22 @@ -use std::collections::HashMap; +mod common; + use std::time::Duration; use codex_api::RealtimeAudioFrame; use codex_api::RealtimeEvent; use codex_api::RealtimeSessionConfig; use codex_api::RealtimeWebsocketClient; -use codex_api::provider::Provider; -use codex_api::provider::RetryConfig; +use common::ws_harness; use futures::SinkExt; use futures::StreamExt; use http::HeaderMap; use serde_json::Value; use serde_json::json; -use tokio::net::TcpListener; -use tokio_tungstenite::accept_async; use tokio_tungstenite::tungstenite::Message; #[tokio::test] async fn realtime_ws_e2e_session_create_and_event_flow() { - let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind"); - let addr = listener.local_addr().expect("local addr"); - - let server = tokio::spawn(async move { - let (stream, _) = listener.accept().await.expect("accept"); - let mut ws = accept_async(stream).await.expect("accept ws"); - + let (api_url, server) = ws_harness::spawn_ws_server(|mut ws| async move { let first = ws .next() .await @@ -76,27 +68,14 @@ async fn realtime_ws_e2e_session_create_and_event_flow() { )) .await .expect("send audio out"); - }); + }) + .await; - let provider = Provider { - name: "test".to_string(), - base_url: "http://localhost".to_string(), - query_params: Some(HashMap::new()), - headers: HeaderMap::new(), - retry: RetryConfig { - max_attempts: 1, - base_delay: Duration::from_millis(1), - retry_429: false, - retry_5xx: false, - retry_transport: false, - }, - stream_idle_timeout: Duration::from_secs(5), - }; - let client = RealtimeWebsocketClient::new(provider); + let client = RealtimeWebsocketClient::new(ws_harness::test_provider()); let connection = client .connect( RealtimeSessionConfig { - api_url: format!("ws://{addr}"), + api_url, prompt: "backend prompt".to_string(), session_id: Some("conv_123".to_string()), }, @@ -149,13 +128,7 @@ async fn realtime_ws_e2e_session_create_and_event_flow() { #[tokio::test] async fn realtime_ws_e2e_send_while_next_event_waits() { - let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind"); - let addr = listener.local_addr().expect("local addr"); - - let server = tokio::spawn(async move { - let (stream, _) = listener.accept().await.expect("accept"); - let mut ws = accept_async(stream).await.expect("accept ws"); - + let (api_url, server) = ws_harness::spawn_ws_server(|mut ws| async move { let first = ws .next() .await @@ -186,27 +159,14 @@ async fn realtime_ws_e2e_send_while_next_event_waits() { )) .await .expect("send session.created"); - }); + }) + .await; - let provider = Provider { - name: "test".to_string(), - base_url: "http://localhost".to_string(), - query_params: Some(HashMap::new()), - headers: HeaderMap::new(), - retry: RetryConfig { - max_attempts: 1, - base_delay: Duration::from_millis(1), - retry_429: false, - retry_5xx: false, - retry_transport: false, - }, - stream_idle_timeout: Duration::from_secs(5), - }; - let client = RealtimeWebsocketClient::new(provider); + let client = RealtimeWebsocketClient::new(ws_harness::test_provider()); let connection = client .connect( RealtimeSessionConfig { - api_url: format!("ws://{addr}"), + api_url, prompt: "backend prompt".to_string(), session_id: Some("conv_123".to_string()), }, @@ -249,13 +209,7 @@ async fn realtime_ws_e2e_send_while_next_event_waits() { #[tokio::test] async fn realtime_ws_e2e_disconnected_emitted_once() { - let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind"); - let addr = listener.local_addr().expect("local addr"); - - let server = tokio::spawn(async move { - let (stream, _) = listener.accept().await.expect("accept"); - let mut ws = accept_async(stream).await.expect("accept ws"); - + let (api_url, server) = ws_harness::spawn_ws_server(|mut ws| async move { let first = ws .next() .await @@ -267,27 +221,14 @@ async fn realtime_ws_e2e_disconnected_emitted_once() { assert_eq!(first_json["type"], "session.create"); ws.send(Message::Close(None)).await.expect("send close"); - }); + }) + .await; - let provider = Provider { - name: "test".to_string(), - base_url: "http://localhost".to_string(), - query_params: Some(HashMap::new()), - headers: HeaderMap::new(), - retry: RetryConfig { - max_attempts: 1, - base_delay: Duration::from_millis(1), - retry_429: false, - retry_5xx: false, - retry_transport: false, - }, - stream_idle_timeout: Duration::from_secs(5), - }; - let client = RealtimeWebsocketClient::new(provider); + let client = RealtimeWebsocketClient::new(ws_harness::test_provider()); let connection = client .connect( RealtimeSessionConfig { - api_url: format!("ws://{addr}"), + api_url, prompt: "backend prompt".to_string(), session_id: Some("conv_123".to_string()), }, @@ -308,13 +249,7 @@ async fn realtime_ws_e2e_disconnected_emitted_once() { #[tokio::test] async fn realtime_ws_e2e_ignores_unknown_text_events() { - let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind"); - let addr = listener.local_addr().expect("local addr"); - - let server = tokio::spawn(async move { - let (stream, _) = listener.accept().await.expect("accept"); - let mut ws = accept_async(stream).await.expect("accept ws"); - + let (api_url, server) = ws_harness::spawn_ws_server(|mut ws| async move { let first = ws .next() .await @@ -346,27 +281,14 @@ async fn realtime_ws_e2e_ignores_unknown_text_events() { )) .await .expect("send session.created"); - }); + }) + .await; - let provider = Provider { - name: "test".to_string(), - base_url: "http://localhost".to_string(), - query_params: Some(HashMap::new()), - headers: HeaderMap::new(), - retry: RetryConfig { - max_attempts: 1, - base_delay: Duration::from_millis(1), - retry_429: false, - retry_5xx: false, - retry_transport: false, - }, - stream_idle_timeout: Duration::from_secs(5), - }; - let client = RealtimeWebsocketClient::new(provider); + let client = RealtimeWebsocketClient::new(ws_harness::test_provider()); let connection = client .connect( RealtimeSessionConfig { - api_url: format!("ws://{addr}"), + api_url, prompt: "backend prompt".to_string(), session_id: Some("conv_123".to_string()), },