Compare commits

...

3 Commits

Author SHA1 Message Date
pakrym-oai
a14a018b2f Fall back to http on websocket request 2026-02-05 14:03:03 -08:00
pakrym-oai
dfa04d5a47 Merge branch 'main' into pakrym/send-beta-header-with-websocket-connects
# Conflicts:
#	codex-rs/core/src/client.rs
2026-02-05 13:34:28 -08:00
pakrym-oai
7e3eb2b392 Send beta header with websocket connects 2026-02-05 00:08:19 -08:00
4 changed files with 146 additions and 38 deletions

View File

@@ -82,6 +82,8 @@ use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;
use crate::tools::spec::create_tools_json_for_responses_api;
/// Header name used to opt in to beta OpenAI API features.
pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
/// `OpenAI-Beta` value that opts the connection into the dated
/// responses-over-WebSocket protocol revision.
pub const OPENAI_BETA_RESPONSES_WEBSOCKETS: &str = "responses_websockets=2026-02-04";
/// Wire name of the `x-codex-turn-state` header.
// NOTE(review): name suggests opaque per-turn state; semantics defined elsewhere.
pub const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
/// Wire name of the `x-codex-turn-metadata` header.
pub const X_CODEX_TURN_METADATA_HEADER: &str = "x-codex-turn-metadata";
pub const X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER: &str =
@@ -153,6 +155,11 @@ pub struct ModelClientSession {
turn_state: Arc<OnceLock<String>>,
}
/// Result of attempting to stream a turn over the WebSocket transport.
enum WebsocketStreamOutcome {
/// The WebSocket attempt succeeded and produced a response stream.
Stream(ResponseStream),
/// The server rejected the connect (426 Upgrade Required); the caller
/// should retry the same turn over the plain HTTP transport.
FallbackToHttp,
}
impl ModelClient {
#[allow(clippy::too_many_arguments)]
/// Creates a new session-scoped `ModelClient`.
@@ -329,6 +336,20 @@ impl ModelClientSession {
.swap(true, Ordering::Relaxed)
}
/// Switch this session to the HTTP transport without surfacing an error.
///
/// Records a fallback OTel counter when `activate_http_fallback` reports an
/// activation, then tears down the cached WebSocket state so the next request
/// goes over HTTP.
fn switch_to_http_fallback_silent(&mut self, otel_manager: &OtelManager) {
let websocket_enabled = self.responses_websocket_enabled();
// Only emit the metric when the fallback actually activates
// (activate_http_fallback uses an atomic swap — see its body).
if self.activate_http_fallback(websocket_enabled) {
otel_manager.counter(
"codex.transport.fallback_to_http",
1,
&[("from_wire_api", "responses_websocket")],
);
}
// Drop the live connection and forget previously-sent items so nothing
// stale is replayed on the HTTP path.
self.connection = None;
self.websocket_last_items.clear();
}
fn responses_websocket_enabled(&self) -> bool {
self.client.state.provider.supports_websockets
&& self.client.state.enable_responses_websockets
@@ -477,14 +498,8 @@ impl ModelClientSession {
};
if needs_new {
let mut headers = options.extra_headers.clone();
headers.extend(build_conversation_headers(options.conversation_id.clone()));
if self.client.state.include_timing_metrics {
headers.insert(
X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER,
HeaderValue::from_static("true"),
);
}
let headers =
build_websocket_connect_headers(options, self.client.state.include_timing_metrics);
let websocket_telemetry = Self::build_websocket_telemetry(otel_manager);
let new_conn: ApiWebSocketConnection =
ApiWebSocketResponsesClient::new(api_provider, api_auth)
@@ -599,7 +614,7 @@ impl ModelClientSession {
effort: Option<ReasoningEffortConfig>,
summary: ReasoningSummaryConfig,
turn_metadata_header: Option<&str>,
) -> Result<ResponseStream> {
) -> Result<WebsocketStreamOutcome> {
let auth_manager = self.client.state.auth_manager.clone();
let api_prompt = Self::build_responses_request(prompt)?;
@@ -639,6 +654,11 @@ impl ModelClientSession {
.await
{
Ok(connection) => connection,
Err(ApiError::Transport(TransportError::Http { status, .. }))
if status == StatusCode::UPGRADE_REQUIRED =>
{
return Ok(WebsocketStreamOutcome::FallbackToHttp);
}
Err(ApiError::Transport(
unauthorized_transport @ TransportError::Http { status, .. },
)) if status == StatusCode::UNAUTHORIZED => {
@@ -654,7 +674,10 @@ impl ModelClientSession {
.map_err(map_api_error)?;
self.websocket_last_items = api_prompt.input.clone();
return Ok(map_response_stream(stream_result, otel_manager.clone()));
return Ok(WebsocketStreamOutcome::Stream(map_response_stream(
stream_result,
otel_manager.clone(),
)));
}
}
@@ -694,30 +717,34 @@ impl ModelClientSession {
let wire_api = self.client.state.provider.wire_api;
match wire_api {
WireApi::Responses => {
let websocket_enabled =
self.responses_websocket_enabled() && !self.disable_websockets();
if websocket_enabled {
self.stream_responses_websocket(
prompt,
model_info,
otel_manager,
effort,
summary,
turn_metadata_header,
)
.await
} else {
self.stream_responses_api(
prompt,
model_info,
otel_manager,
effort,
summary,
turn_metadata_header,
)
.await
if self.responses_websocket_enabled() && !self.disable_websockets() {
match self
.stream_responses_websocket(
prompt,
model_info,
otel_manager,
effort,
summary,
turn_metadata_header,
)
.await?
{
WebsocketStreamOutcome::Stream(stream) => return Ok(stream),
WebsocketStreamOutcome::FallbackToHttp => {
self.switch_to_http_fallback_silent(otel_manager);
}
}
}
self.stream_responses_api(
prompt,
model_info,
otel_manager,
effort,
summary,
turn_metadata_header,
)
.await
}
}
}
@@ -787,6 +814,25 @@ fn build_responses_headers(
headers
}
/// Builds the header map sent with a Responses-API WebSocket connect.
///
/// Starts from the caller-supplied extra headers, adds the conversation
/// headers, always sets the `OpenAI-Beta` websocket opt-in header, and
/// optionally requests timing metrics.
fn build_websocket_connect_headers(
options: &ApiResponsesOptions,
include_timing_metrics: bool,
) -> ApiHeaderMap {
let mut headers = options.extra_headers.clone();
headers.extend(build_conversation_headers(options.conversation_id.clone()));
// Opt in to the responses-websockets beta revision.
headers.insert(
OPENAI_BETA_HEADER,
HeaderValue::from_static(OPENAI_BETA_RESPONSES_WEBSOCKETS),
);
if include_timing_metrics {
headers.insert(
X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER,
HeaderValue::from_static("true"),
);
}
headers
}
fn map_response_stream<S>(api_stream: S, otel_manager: OtelManager) -> ResponseStream
where
S: futures::Stream<Item = std::result::Result<ResponseEvent, ApiError>>

View File

@@ -4039,7 +4039,8 @@ async fn run_sampling_request(
// Use the configured provider-specific stream retry budget.
let max_retries = turn_context.provider.stream_max_retries();
if retries >= max_retries
if retries > 0
&& retries >= max_retries
&& client_session.try_switch_fallback_transport(&turn_context.otel_manager)
{
sess.send_event(

View File

@@ -40,6 +40,8 @@ use tempfile::TempDir;
use tracing_test::traced_test;
const MODEL: &str = "gpt-5.2-codex";
const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
const OPENAI_BETA_RESPONSES_WEBSOCKETS: &str = "responses_websockets=2026-02-04";
struct WebsocketTestHarness {
_codex_home: TempDir,
@@ -74,6 +76,11 @@ async fn responses_websocket_streams_request() {
assert_eq!(body["model"].as_str(), Some(MODEL));
assert_eq!(body["stream"], serde_json::Value::Bool(true));
assert_eq!(body["input"].as_array().map(Vec::len), Some(1));
let handshake = server.single_handshake();
assert_eq!(
handshake.header(OPENAI_BETA_HEADER),
Some(OPENAI_BETA_RESPONSES_WEBSOCKETS.to_string())
);
server.shutdown().await;
}

View File

@@ -9,13 +9,23 @@ use core_test_support::responses::sse;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::test_codex;
use pretty_assertions::assert_eq;
use wiremock::Mock;
use wiremock::ResponseTemplate;
use wiremock::http::Method;
use wiremock::matchers::method;
use wiremock::matchers::path_regex;
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_fallback_switches_to_http_after_retries_exhausted() -> Result<()> {
async fn websocket_fallback_switches_to_http_on_upgrade_required_connect() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = responses::start_mock_server().await;
Mock::given(method("GET"))
.and(path_regex(".*/responses$"))
.respond_with(ResponseTemplate::new(426))
.mount(&server)
.await;
let response_mock = mount_sse_once(
&server,
sse(vec![ev_response_created("resp-1"), ev_completed("resp-1")]),
@@ -28,7 +38,9 @@ async fn websocket_fallback_switches_to_http_after_retries_exhausted() -> Result
config.model_provider.base_url = Some(base_url);
config.model_provider.wire_api = codex_core::WireApi::Responses;
config.features.enable(Feature::ResponsesWebsockets);
config.model_provider.stream_max_retries = Some(0);
// If we don't treat 426 specially, the sampling loop would retry the WebSocket
// handshake before switching to the HTTP transport.
config.model_provider.stream_max_retries = Some(1);
config.model_provider.request_max_retries = Some(0);
}
});
@@ -53,6 +65,48 @@ async fn websocket_fallback_switches_to_http_after_retries_exhausted() -> Result
Ok(())
}
/// No GET (WebSocket handshake) mock is mounted, so websocket attempts fail;
/// after the single stream retry is exhausted the client should fall back to
/// HTTP and complete the turn there — 2 GET attempts, then 1 successful POST.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_fallback_switches_to_http_after_retries_exhausted() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = responses::start_mock_server().await;
// Only an SSE responder for POST /responses is mounted — websocket (GET)
// connects have no matching mock and presumably fail with the server default.
let response_mock = mount_sse_once(
&server,
sse(vec![ev_response_created("resp-1"), ev_completed("resp-1")]),
)
.await;
let mut builder = test_codex().with_config({
let base_url = format!("{}/v1", server.uri());
move |config| {
config.model_provider.base_url = Some(base_url);
config.model_provider.wire_api = codex_core::WireApi::Responses;
config.features.enable(Feature::ResponsesWebsockets);
// One stream retry: initial websocket attempt + one retry, then fallback.
config.model_provider.stream_max_retries = Some(1);
config.model_provider.request_max_retries = Some(0);
}
});
let test = builder.build(&server).await?;
test.submit_turn("hello").await?;
let requests = server.received_requests().await.unwrap_or_default();
// GET /responses requests correspond to websocket handshake attempts.
let websocket_attempts = requests
.iter()
.filter(|req| req.method == Method::GET && req.url.path().ends_with("/responses"))
.count();
// POST /responses requests correspond to plain HTTP streaming requests.
let http_attempts = requests
.iter()
.filter(|req| req.method == Method::POST && req.url.path().ends_with("/responses"))
.count();
assert_eq!(websocket_attempts, 2);
assert_eq!(http_attempts, 1);
assert_eq!(response_mock.requests().len(), 1);
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_fallback_is_sticky_across_turns() -> Result<()> {
skip_if_no_network!(Ok(()));
@@ -73,7 +127,7 @@ async fn websocket_fallback_is_sticky_across_turns() -> Result<()> {
config.model_provider.base_url = Some(base_url);
config.model_provider.wire_api = codex_core::WireApi::Responses;
config.features.enable(Feature::ResponsesWebsockets);
config.model_provider.stream_max_retries = Some(0);
config.model_provider.stream_max_retries = Some(1);
config.model_provider.request_max_retries = Some(0);
}
});
@@ -92,7 +146,7 @@ async fn websocket_fallback_is_sticky_across_turns() -> Result<()> {
.filter(|req| req.method == Method::POST && req.url.path().ends_with("/responses"))
.count();
assert_eq!(websocket_attempts, 1);
assert_eq!(websocket_attempts, 2);
assert_eq!(http_attempts, 2);
assert_eq!(response_mock.requests().len(), 2);