mirror of
https://github.com/openai/codex.git
synced 2026-05-02 12:21:26 +03:00
Attach WebRTC realtime starts to sideband websocket (#17057)
Summary:
- parse the realtime call Location header and join that call over the direct realtime WebSocket
- keep WebRTC starts alive on the existing realtime conversation path

Validation:
- just fmt
- git diff --check
- cargo check -p codex-api
- cargo check -p codex-core --tests
- local cargo tests not run; relying on PR CI
This commit is contained in:
@@ -12,6 +12,7 @@ use http::HeaderMap;
|
||||
use http::HeaderValue;
|
||||
use http::Method;
|
||||
use http::header::CONTENT_TYPE;
|
||||
use http::header::LOCATION;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::to_string;
|
||||
@@ -26,9 +27,14 @@ pub struct RealtimeCallClient<T: HttpTransport, A: AuthProvider> {
|
||||
session: EndpointSession<T, A>,
|
||||
}
|
||||
|
||||
/// Answer from creating a WebRTC Realtime call.
///
/// `sdp` configures the peer connection. `call_id` is parsed from the response `Location` header
/// and is later used by the server-side sideband WebSocket to join this exact call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RealtimeCallResponse {
    /// SDP answer decoded from the body of the call-creation response.
    pub sdp: String,
    /// Identifier of the created call (an `rtc_`-prefixed path segment of the `Location` header).
    pub call_id: String,
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
@@ -101,8 +107,9 @@ impl<T: HttpTransport, A: AuthProvider> RealtimeCallClient<T, A> {
|
||||
.await?;
|
||||
|
||||
let sdp = decode_sdp_response(resp.body.as_ref())?;
|
||||
let call_id = decode_call_id_from_location(&resp.headers)?;
|
||||
|
||||
Ok(RealtimeCallResponse { sdp })
|
||||
Ok(RealtimeCallResponse { sdp, call_id })
|
||||
}
|
||||
|
||||
pub async fn create_with_session_and_headers(
|
||||
@@ -111,6 +118,9 @@ impl<T: HttpTransport, A: AuthProvider> RealtimeCallClient<T, A> {
|
||||
session_config: RealtimeSessionConfig,
|
||||
extra_headers: HeaderMap,
|
||||
) -> Result<RealtimeCallResponse, ApiError> {
|
||||
// WebRTC can begin inference as soon as the peer connection comes up, so the initial
|
||||
// session payload is sent with call creation. The sideband WebSocket still sends its normal
|
||||
// session.update after it joins.
|
||||
let mut session = realtime_session_json(session_config)?;
|
||||
if let Some(session) = session.as_object_mut() {
|
||||
session.remove("id");
|
||||
@@ -127,7 +137,8 @@ impl<T: HttpTransport, A: AuthProvider> RealtimeCallClient<T, A> {
|
||||
.execute(Method::POST, Self::path(), extra_headers, Some(body))
|
||||
.await?;
|
||||
let sdp = decode_sdp_response(resp.body.as_ref())?;
|
||||
return Ok(RealtimeCallResponse { sdp });
|
||||
let call_id = decode_call_id_from_location(&resp.headers)?;
|
||||
return Ok(RealtimeCallResponse { sdp, call_id });
|
||||
}
|
||||
|
||||
let session = to_string(&session).map_err(|err| ApiError::InvalidRequest {
|
||||
@@ -164,8 +175,9 @@ impl<T: HttpTransport, A: AuthProvider> RealtimeCallClient<T, A> {
|
||||
.await?;
|
||||
|
||||
let sdp = decode_sdp_response(resp.body.as_ref())?;
|
||||
let call_id = decode_call_id_from_location(&resp.headers)?;
|
||||
|
||||
Ok(RealtimeCallResponse { sdp })
|
||||
Ok(RealtimeCallResponse { sdp, call_id })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -182,6 +194,27 @@ fn decode_sdp_response(body: &[u8]) -> Result<String, ApiError> {
|
||||
})
|
||||
}
|
||||
|
||||
fn decode_call_id_from_location(headers: &HeaderMap) -> Result<String, ApiError> {
|
||||
let location = headers
|
||||
.get(LOCATION)
|
||||
.ok_or_else(|| ApiError::Stream("realtime call response missing Location".to_string()))?
|
||||
.to_str()
|
||||
.map_err(|err| ApiError::Stream(format!("invalid realtime call Location: {err}")))?;
|
||||
|
||||
location
|
||||
.split('?')
|
||||
.next()
|
||||
.unwrap_or(location)
|
||||
.rsplit('/')
|
||||
.find(|segment| segment.starts_with("rtc_") && segment.len() > "rtc_".len())
|
||||
.map(str::to_string)
|
||||
.ok_or_else(|| {
|
||||
ApiError::Stream(format!(
|
||||
"realtime call Location does not contain a call id: {location}"
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -201,12 +234,27 @@ mod tests {
|
||||
#[derive(Clone)]
|
||||
struct CapturingTransport {
|
||||
last_request: Arc<Mutex<Option<Request>>>,
|
||||
response_headers: HeaderMap,
|
||||
}
|
||||
|
||||
impl CapturingTransport {
|
||||
fn new() -> Self {
|
||||
Self::with_location("/v1/realtime/calls/rtc_test")
|
||||
}
|
||||
|
||||
fn with_location(location: &str) -> Self {
|
||||
let mut response_headers = HeaderMap::new();
|
||||
response_headers.insert(LOCATION, HeaderValue::from_str(location).unwrap());
|
||||
Self {
|
||||
last_request: Arc::new(Mutex::new(None)),
|
||||
response_headers,
|
||||
}
|
||||
}
|
||||
|
||||
fn without_location() -> Self {
|
||||
Self {
|
||||
last_request: Arc::new(Mutex::new(None)),
|
||||
response_headers: HeaderMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -217,7 +265,7 @@ mod tests {
|
||||
*self.last_request.lock().unwrap() = Some(req);
|
||||
Ok(Response {
|
||||
status: StatusCode::OK,
|
||||
headers: HeaderMap::new(),
|
||||
headers: self.response_headers.clone(),
|
||||
body: Bytes::from_static(b"v=0\r\n"),
|
||||
})
|
||||
}
|
||||
@@ -280,7 +328,8 @@ mod tests {
|
||||
assert_eq!(
|
||||
response,
|
||||
RealtimeCallResponse {
|
||||
sdp: "v=0\r\n".to_string()
|
||||
sdp: "v=0\r\n".to_string(),
|
||||
call_id: "rtc_test".to_string(),
|
||||
}
|
||||
);
|
||||
|
||||
@@ -304,6 +353,41 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn extracts_call_id_from_forwarded_backend_location() {
|
||||
let transport =
|
||||
CapturingTransport::with_location("/v1/realtime/calls/calls/rtc_backend_test");
|
||||
let client = RealtimeCallClient::new(
|
||||
transport.clone(),
|
||||
provider("https://chatgpt.com/backend-api/codex"),
|
||||
DummyAuth,
|
||||
);
|
||||
|
||||
let response = client
|
||||
.create("v=offer\r\n".to_string())
|
||||
.await
|
||||
.expect("request should succeed");
|
||||
|
||||
assert_eq!(
|
||||
response,
|
||||
RealtimeCallResponse {
|
||||
sdp: "v=0\r\n".to_string(),
|
||||
call_id: "rtc_backend_test".to_string(),
|
||||
}
|
||||
);
|
||||
|
||||
let request = transport.last_request.lock().unwrap().clone().unwrap();
|
||||
assert_eq!(request.method, Method::POST);
|
||||
assert_eq!(
|
||||
request.url,
|
||||
"https://chatgpt.com/backend-api/codex/realtime/calls"
|
||||
);
|
||||
assert_eq!(
|
||||
request.body,
|
||||
Some(RequestBody::Raw(Bytes::from_static(b"v=offer\r\n")))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sends_api_session_call_as_multipart_body() {
|
||||
let transport = CapturingTransport::new();
|
||||
@@ -324,7 +408,8 @@ mod tests {
|
||||
assert_eq!(
|
||||
response,
|
||||
RealtimeCallResponse {
|
||||
sdp: "v=0\r\n".to_string()
|
||||
sdp: "v=0\r\n".to_string(),
|
||||
call_id: "rtc_test".to_string(),
|
||||
}
|
||||
);
|
||||
|
||||
@@ -385,7 +470,8 @@ mod tests {
|
||||
assert_eq!(
|
||||
response,
|
||||
RealtimeCallResponse {
|
||||
sdp: "v=0\r\n".to_string()
|
||||
sdp: "v=0\r\n".to_string(),
|
||||
call_id: "rtc_test".to_string(),
|
||||
}
|
||||
);
|
||||
|
||||
@@ -412,4 +498,35 @@ mod tests {
|
||||
))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn errors_when_location_is_missing() {
|
||||
let transport = CapturingTransport::without_location();
|
||||
let client =
|
||||
RealtimeCallClient::new(transport, provider("https://api.openai.com/v1"), DummyAuth);
|
||||
|
||||
let err = client
|
||||
.create("v=offer\r\n".to_string())
|
||||
.await
|
||||
.expect_err("request should require Location");
|
||||
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"stream error: realtime call response missing Location"
|
||||
);
|
||||
}
|
||||
|
||||
/// A `Location` that has no `rtc_` segment is rejected, and the error echoes the header value.
#[test]
fn rejects_location_without_call_id() {
    let location_without_id = HeaderValue::from_static("/v1/realtime/calls");
    let mut response_headers = HeaderMap::new();
    response_headers.insert(LOCATION, location_without_id);

    let err = decode_call_id_from_location(&response_headers)
        .expect_err("Location without rtc_ segment should fail");

    let message = err.to_string();
    assert_eq!(
        message,
        "stream error: realtime call Location does not contain a call id: /v1/realtime/calls"
    );
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ use crate::endpoint::realtime_websocket::protocol::RealtimeTranscriptEntry;
|
||||
use crate::endpoint::realtime_websocket::protocol::parse_realtime_event;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use codex_client::backoff;
|
||||
use codex_client::maybe_build_rustls_client_config_with_custom_ca;
|
||||
use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
|
||||
use futures::SinkExt;
|
||||
@@ -28,6 +29,7 @@ use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::sleep;
|
||||
use tokio_tungstenite::MaybeTlsStream;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::tungstenite::Error as WsError;
|
||||
@@ -37,6 +39,7 @@ use tracing::debug;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::trace;
|
||||
use tracing::warn;
|
||||
use tungstenite::protocol::WebSocketConfig;
|
||||
use url::Url;
|
||||
|
||||
@@ -455,7 +458,6 @@ impl RealtimeWebsocketClient {
|
||||
extra_headers: HeaderMap,
|
||||
default_headers: HeaderMap,
|
||||
) -> Result<RealtimeWebsocketConnection, ApiError> {
|
||||
ensure_rustls_crypto_provider();
|
||||
let ws_url = websocket_url_from_api_url(
|
||||
self.provider.base_url.as_str(),
|
||||
self.provider.query_params.as_ref(),
|
||||
@@ -463,6 +465,78 @@ impl RealtimeWebsocketClient {
|
||||
config.event_parser,
|
||||
config.session_mode,
|
||||
)?;
|
||||
self.connect_realtime_websocket_url(ws_url, config, extra_headers, default_headers)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn connect_webrtc_sideband(
|
||||
&self,
|
||||
config: RealtimeSessionConfig,
|
||||
call_id: &str,
|
||||
extra_headers: HeaderMap,
|
||||
default_headers: HeaderMap,
|
||||
) -> Result<RealtimeWebsocketConnection, ApiError> {
|
||||
// The WebRTC call already exists; this loop only retries joining its sideband control
|
||||
// socket. Once joined, the returned connection is the same reader/writer state that the
|
||||
// ordinary websocket start path uses.
|
||||
for attempt in 0..=self.provider.retry.max_attempts {
|
||||
let result = self
|
||||
.connect_webrtc_sideband_once(
|
||||
config.clone(),
|
||||
call_id,
|
||||
extra_headers.clone(),
|
||||
default_headers.clone(),
|
||||
)
|
||||
.await;
|
||||
match result {
|
||||
Ok(connection) => return Ok(connection),
|
||||
Err(err) if attempt < self.provider.retry.max_attempts => {
|
||||
let delay = backoff(self.provider.retry.base_delay, attempt + 1);
|
||||
warn!(
|
||||
attempt = attempt + 1,
|
||||
call_id,
|
||||
delay_ms = delay.as_millis(),
|
||||
"realtime sideband websocket connect failed; retrying: {err}"
|
||||
);
|
||||
sleep(delay).await;
|
||||
}
|
||||
Err(err) => return Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
Err(ApiError::Stream(
|
||||
"realtime sideband websocket retry loop exhausted".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
async fn connect_webrtc_sideband_once(
|
||||
&self,
|
||||
config: RealtimeSessionConfig,
|
||||
call_id: &str,
|
||||
extra_headers: HeaderMap,
|
||||
default_headers: HeaderMap,
|
||||
) -> Result<RealtimeWebsocketConnection, ApiError> {
|
||||
// Keep the parser/session query shaping from standalone realtime while replacing the model
|
||||
// query with a call_id join onto an existing WebRTC session.
|
||||
let ws_url = websocket_url_from_api_url_for_call(
|
||||
self.provider.base_url.as_str(),
|
||||
self.provider.query_params.as_ref(),
|
||||
config.event_parser,
|
||||
config.session_mode,
|
||||
call_id,
|
||||
)?;
|
||||
self.connect_realtime_websocket_url(ws_url, config, extra_headers, default_headers)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn connect_realtime_websocket_url(
|
||||
&self,
|
||||
ws_url: Url,
|
||||
config: RealtimeSessionConfig,
|
||||
extra_headers: HeaderMap,
|
||||
default_headers: HeaderMap,
|
||||
) -> Result<RealtimeWebsocketConnection, ApiError> {
|
||||
ensure_rustls_crypto_provider();
|
||||
|
||||
let mut request = ws_url
|
||||
.as_str()
|
||||
@@ -596,6 +670,24 @@ fn websocket_url_from_api_url(
|
||||
Ok(url)
|
||||
}
|
||||
|
||||
fn websocket_url_from_api_url_for_call(
|
||||
api_url: &str,
|
||||
query_params: Option<&HashMap<String, String>>,
|
||||
event_parser: RealtimeEventParser,
|
||||
session_mode: RealtimeSessionMode,
|
||||
call_id: &str,
|
||||
) -> Result<Url, ApiError> {
|
||||
let mut url = websocket_url_from_api_url(
|
||||
api_url,
|
||||
query_params,
|
||||
/*model*/ None,
|
||||
event_parser,
|
||||
session_mode,
|
||||
)?;
|
||||
url.query_pairs_mut().append_pair("call_id", call_id);
|
||||
Ok(url)
|
||||
}
|
||||
|
||||
fn normalize_realtime_path(url: &mut Url) {
|
||||
let path = url.path().to_string();
|
||||
if path.is_empty() || path == "/" {
|
||||
@@ -1094,6 +1186,22 @@ mod tests {
|
||||
assert_eq!(url.as_str(), "wss://example.com/v1/realtime");
|
||||
}
|
||||
|
||||
/// The call-join URL targets the realtime endpoint and carries only the `call_id` query pair.
#[test]
fn websocket_url_for_call_id_joins_existing_realtime_session() {
    let api_url = "https://api.openai.com/v1";
    let call_id = "rtc_test";

    let joined = websocket_url_from_api_url_for_call(
        api_url,
        /*query_params*/ None,
        RealtimeEventParser::RealtimeV2,
        RealtimeSessionMode::Conversational,
        call_id,
    )
    .expect("build ws url");

    assert_eq!(
        joined.as_str(),
        "wss://api.openai.com/v1/realtime?call_id=rtc_test"
    );
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn e2e_connect_and_exchange_events_against_mock_ws_server() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
|
||||
|
||||
Reference in New Issue
Block a user