Attach WebRTC realtime starts to sideband websocket (#17057)

Summary:
- parse the realtime call Location header and join that call over the
direct realtime WebSocket
- keep WebRTC starts alive on the existing realtime conversation path

Validation:
- just fmt
- git diff --check
- cargo check -p codex-api
- cargo check -p codex-core --tests
- local cargo tests not run; relying on PR CI
This commit is contained in:
Ahmed Ibrahim
2026-04-08 15:25:42 -07:00
committed by GitHub
parent 19bd018300
commit 794a0240f9
7 changed files with 534 additions and 89 deletions

View File

@@ -12,6 +12,7 @@ use http::HeaderMap;
use http::HeaderValue;
use http::Method;
use http::header::CONTENT_TYPE;
use http::header::LOCATION;
use serde::Serialize;
use serde_json::Value;
use serde_json::to_string;
@@ -26,9 +27,14 @@ pub struct RealtimeCallClient<T: HttpTransport, A: AuthProvider> {
session: EndpointSession<T, A>,
}
/// Answer from creating a WebRTC Realtime call.
///
/// `sdp` configures the peer connection. `call_id` is parsed from the response `Location` header
/// and is later used by the server-side sideband WebSocket to join this exact call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RealtimeCallResponse {
pub sdp: String,
pub call_id: String,
}
#[derive(Serialize)]
@@ -101,8 +107,9 @@ impl<T: HttpTransport, A: AuthProvider> RealtimeCallClient<T, A> {
.await?;
let sdp = decode_sdp_response(resp.body.as_ref())?;
let call_id = decode_call_id_from_location(&resp.headers)?;
Ok(RealtimeCallResponse { sdp })
Ok(RealtimeCallResponse { sdp, call_id })
}
pub async fn create_with_session_and_headers(
@@ -111,6 +118,9 @@ impl<T: HttpTransport, A: AuthProvider> RealtimeCallClient<T, A> {
session_config: RealtimeSessionConfig,
extra_headers: HeaderMap,
) -> Result<RealtimeCallResponse, ApiError> {
// WebRTC can begin inference as soon as the peer connection comes up, so the initial
// session payload is sent with call creation. The sideband WebSocket still sends its normal
// session.update after it joins.
let mut session = realtime_session_json(session_config)?;
if let Some(session) = session.as_object_mut() {
session.remove("id");
@@ -127,7 +137,8 @@ impl<T: HttpTransport, A: AuthProvider> RealtimeCallClient<T, A> {
.execute(Method::POST, Self::path(), extra_headers, Some(body))
.await?;
let sdp = decode_sdp_response(resp.body.as_ref())?;
return Ok(RealtimeCallResponse { sdp });
let call_id = decode_call_id_from_location(&resp.headers)?;
return Ok(RealtimeCallResponse { sdp, call_id });
}
let session = to_string(&session).map_err(|err| ApiError::InvalidRequest {
@@ -164,8 +175,9 @@ impl<T: HttpTransport, A: AuthProvider> RealtimeCallClient<T, A> {
.await?;
let sdp = decode_sdp_response(resp.body.as_ref())?;
let call_id = decode_call_id_from_location(&resp.headers)?;
Ok(RealtimeCallResponse { sdp })
Ok(RealtimeCallResponse { sdp, call_id })
}
}
@@ -182,6 +194,27 @@ fn decode_sdp_response(body: &[u8]) -> Result<String, ApiError> {
})
}
fn decode_call_id_from_location(headers: &HeaderMap) -> Result<String, ApiError> {
let location = headers
.get(LOCATION)
.ok_or_else(|| ApiError::Stream("realtime call response missing Location".to_string()))?
.to_str()
.map_err(|err| ApiError::Stream(format!("invalid realtime call Location: {err}")))?;
location
.split('?')
.next()
.unwrap_or(location)
.rsplit('/')
.find(|segment| segment.starts_with("rtc_") && segment.len() > "rtc_".len())
.map(str::to_string)
.ok_or_else(|| {
ApiError::Stream(format!(
"realtime call Location does not contain a call id: {location}"
))
})
}
#[cfg(test)]
mod tests {
use super::*;
@@ -201,12 +234,27 @@ mod tests {
#[derive(Clone)]
struct CapturingTransport {
last_request: Arc<Mutex<Option<Request>>>,
response_headers: HeaderMap,
}
impl CapturingTransport {
fn new() -> Self {
Self::with_location("/v1/realtime/calls/rtc_test")
}
fn with_location(location: &str) -> Self {
let mut response_headers = HeaderMap::new();
response_headers.insert(LOCATION, HeaderValue::from_str(location).unwrap());
Self {
last_request: Arc::new(Mutex::new(None)),
response_headers,
}
}
fn without_location() -> Self {
Self {
last_request: Arc::new(Mutex::new(None)),
response_headers: HeaderMap::new(),
}
}
}
@@ -217,7 +265,7 @@ mod tests {
*self.last_request.lock().unwrap() = Some(req);
Ok(Response {
status: StatusCode::OK,
headers: HeaderMap::new(),
headers: self.response_headers.clone(),
body: Bytes::from_static(b"v=0\r\n"),
})
}
@@ -280,7 +328,8 @@ mod tests {
assert_eq!(
response,
RealtimeCallResponse {
sdp: "v=0\r\n".to_string()
sdp: "v=0\r\n".to_string(),
call_id: "rtc_test".to_string(),
}
);
@@ -304,6 +353,41 @@ mod tests {
);
}
#[tokio::test]
async fn extracts_call_id_from_forwarded_backend_location() {
let transport =
CapturingTransport::with_location("/v1/realtime/calls/calls/rtc_backend_test");
let client = RealtimeCallClient::new(
transport.clone(),
provider("https://chatgpt.com/backend-api/codex"),
DummyAuth,
);
let response = client
.create("v=offer\r\n".to_string())
.await
.expect("request should succeed");
assert_eq!(
response,
RealtimeCallResponse {
sdp: "v=0\r\n".to_string(),
call_id: "rtc_backend_test".to_string(),
}
);
let request = transport.last_request.lock().unwrap().clone().unwrap();
assert_eq!(request.method, Method::POST);
assert_eq!(
request.url,
"https://chatgpt.com/backend-api/codex/realtime/calls"
);
assert_eq!(
request.body,
Some(RequestBody::Raw(Bytes::from_static(b"v=offer\r\n")))
);
}
#[tokio::test]
async fn sends_api_session_call_as_multipart_body() {
let transport = CapturingTransport::new();
@@ -324,7 +408,8 @@ mod tests {
assert_eq!(
response,
RealtimeCallResponse {
sdp: "v=0\r\n".to_string()
sdp: "v=0\r\n".to_string(),
call_id: "rtc_test".to_string(),
}
);
@@ -385,7 +470,8 @@ mod tests {
assert_eq!(
response,
RealtimeCallResponse {
sdp: "v=0\r\n".to_string()
sdp: "v=0\r\n".to_string(),
call_id: "rtc_test".to_string(),
}
);
@@ -412,4 +498,35 @@ mod tests {
))
);
}
#[tokio::test]
async fn errors_when_location_is_missing() {
let transport = CapturingTransport::without_location();
let client =
RealtimeCallClient::new(transport, provider("https://api.openai.com/v1"), DummyAuth);
let err = client
.create("v=offer\r\n".to_string())
.await
.expect_err("request should require Location");
assert_eq!(
err.to_string(),
"stream error: realtime call response missing Location"
);
}
#[test]
fn rejects_location_without_call_id() {
let mut headers = HeaderMap::new();
headers.insert(LOCATION, HeaderValue::from_static("/v1/realtime/calls"));
let err = decode_call_id_from_location(&headers)
.expect_err("Location without rtc_ segment should fail");
assert_eq!(
err.to_string(),
"stream error: realtime call Location does not contain a call id: /v1/realtime/calls"
);
}
}

View File

@@ -14,6 +14,7 @@ use crate::endpoint::realtime_websocket::protocol::RealtimeTranscriptEntry;
use crate::endpoint::realtime_websocket::protocol::parse_realtime_event;
use crate::error::ApiError;
use crate::provider::Provider;
use codex_client::backoff;
use codex_client::maybe_build_rustls_client_config_with_custom_ca;
use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
use futures::SinkExt;
@@ -28,6 +29,7 @@ use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::time::sleep;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite::Error as WsError;
@@ -37,6 +39,7 @@ use tracing::debug;
use tracing::error;
use tracing::info;
use tracing::trace;
use tracing::warn;
use tungstenite::protocol::WebSocketConfig;
use url::Url;
@@ -455,7 +458,6 @@ impl RealtimeWebsocketClient {
extra_headers: HeaderMap,
default_headers: HeaderMap,
) -> Result<RealtimeWebsocketConnection, ApiError> {
ensure_rustls_crypto_provider();
let ws_url = websocket_url_from_api_url(
self.provider.base_url.as_str(),
self.provider.query_params.as_ref(),
@@ -463,6 +465,78 @@ impl RealtimeWebsocketClient {
config.event_parser,
config.session_mode,
)?;
self.connect_realtime_websocket_url(ws_url, config, extra_headers, default_headers)
.await
}
pub async fn connect_webrtc_sideband(
&self,
config: RealtimeSessionConfig,
call_id: &str,
extra_headers: HeaderMap,
default_headers: HeaderMap,
) -> Result<RealtimeWebsocketConnection, ApiError> {
// The WebRTC call already exists; this loop only retries joining its sideband control
// socket. Once joined, the returned connection is the same reader/writer state that the
// ordinary websocket start path uses.
for attempt in 0..=self.provider.retry.max_attempts {
let result = self
.connect_webrtc_sideband_once(
config.clone(),
call_id,
extra_headers.clone(),
default_headers.clone(),
)
.await;
match result {
Ok(connection) => return Ok(connection),
Err(err) if attempt < self.provider.retry.max_attempts => {
let delay = backoff(self.provider.retry.base_delay, attempt + 1);
warn!(
attempt = attempt + 1,
call_id,
delay_ms = delay.as_millis(),
"realtime sideband websocket connect failed; retrying: {err}"
);
sleep(delay).await;
}
Err(err) => return Err(err),
}
}
Err(ApiError::Stream(
"realtime sideband websocket retry loop exhausted".to_string(),
))
}
async fn connect_webrtc_sideband_once(
&self,
config: RealtimeSessionConfig,
call_id: &str,
extra_headers: HeaderMap,
default_headers: HeaderMap,
) -> Result<RealtimeWebsocketConnection, ApiError> {
// Keep the parser/session query shaping from standalone realtime while replacing the model
// query with a call_id join onto an existing WebRTC session.
let ws_url = websocket_url_from_api_url_for_call(
self.provider.base_url.as_str(),
self.provider.query_params.as_ref(),
config.event_parser,
config.session_mode,
call_id,
)?;
self.connect_realtime_websocket_url(ws_url, config, extra_headers, default_headers)
.await
}
async fn connect_realtime_websocket_url(
&self,
ws_url: Url,
config: RealtimeSessionConfig,
extra_headers: HeaderMap,
default_headers: HeaderMap,
) -> Result<RealtimeWebsocketConnection, ApiError> {
ensure_rustls_crypto_provider();
let mut request = ws_url
.as_str()
@@ -596,6 +670,24 @@ fn websocket_url_from_api_url(
Ok(url)
}
fn websocket_url_from_api_url_for_call(
api_url: &str,
query_params: Option<&HashMap<String, String>>,
event_parser: RealtimeEventParser,
session_mode: RealtimeSessionMode,
call_id: &str,
) -> Result<Url, ApiError> {
let mut url = websocket_url_from_api_url(
api_url,
query_params,
/*model*/ None,
event_parser,
session_mode,
)?;
url.query_pairs_mut().append_pair("call_id", call_id);
Ok(url)
}
fn normalize_realtime_path(url: &mut Url) {
let path = url.path().to_string();
if path.is_empty() || path == "/" {
@@ -1094,6 +1186,22 @@ mod tests {
assert_eq!(url.as_str(), "wss://example.com/v1/realtime");
}
#[test]
fn websocket_url_for_call_id_joins_existing_realtime_session() {
let url = websocket_url_from_api_url_for_call(
"https://api.openai.com/v1",
/*query_params*/ None,
RealtimeEventParser::RealtimeV2,
RealtimeSessionMode::Conversational,
"rtc_test",
)
.expect("build ws url");
assert_eq!(
url.as_str(),
"wss://api.openai.com/v1/realtime?call_id=rtc_test"
);
}
#[tokio::test]
async fn e2e_connect_and_exchange_events_against_mock_ws_server() {
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");