diff --git a/codex-rs/codex-api/src/endpoint/responses_websocket.rs b/codex-rs/codex-api/src/endpoint/responses_websocket.rs index 175507f27c..6ebf5ab65e 100644 --- a/codex-rs/codex-api/src/endpoint/responses_websocket.rs +++ b/codex-rs/codex-api/src/endpoint/responses_websocket.rs @@ -260,6 +260,7 @@ impl ResponsesWebsocketClient { pub async fn connect( &self, extra_headers: HeaderMap, + default_headers: HeaderMap, turn_state: Option>>, telemetry: Option>, ) -> Result { @@ -268,8 +269,8 @@ impl ResponsesWebsocketClient { .websocket_url_for_path("responses") .map_err(|err| ApiError::Stream(format!("failed to build websocket URL: {err}")))?; - let mut headers = self.provider.headers.clone(); - headers.extend(extra_headers); + let mut headers = + merge_request_headers(&self.provider.headers, extra_headers, default_headers); add_auth_headers_to_header_map(&self.auth, &mut headers); let (stream, server_reasoning_included, models_etag) = @@ -284,6 +285,21 @@ impl ResponsesWebsocketClient { } } +fn merge_request_headers( + provider_headers: &HeaderMap, + extra_headers: HeaderMap, + default_headers: HeaderMap, +) -> HeaderMap { + let mut headers = provider_headers.clone(); + headers.extend(extra_headers); + for (name, value) in &default_headers { + if let http::header::Entry::Vacant(entry) = headers.entry(name) { + entry.insert(value.clone()); + } + } + headers +} + async fn connect_websocket( url: Url, headers: HeaderMap, @@ -673,4 +689,37 @@ mod tests { let api_error = map_wrapped_websocket_error_event(wrapped_error); assert!(api_error.is_none()); } + + #[test] + fn merge_request_headers_matches_http_precedence() { + let mut provider_headers = HeaderMap::new(); + provider_headers.insert( + "originator", + HeaderValue::from_static("provider-originator"), + ); + provider_headers.insert("x-priority", HeaderValue::from_static("provider")); + + let mut extra_headers = HeaderMap::new(); + extra_headers.insert("x-priority", HeaderValue::from_static("extra")); + + let mut default_headers = HeaderMap::new(); + default_headers.insert("originator", HeaderValue::from_static("default-originator")); + default_headers.insert("x-priority", HeaderValue::from_static("default")); + default_headers.insert("x-default-only", HeaderValue::from_static("default-only")); + + let merged = merge_request_headers(&provider_headers, extra_headers, default_headers); + + assert_eq!( + merged.get("originator"), + Some(&HeaderValue::from_static("provider-originator")) + ); + assert_eq!( + merged.get("x-priority"), + Some(&HeaderValue::from_static("extra")) + ); + assert_eq!( + merged.get("x-default-only"), + Some(&HeaderValue::from_static("default-only")) + ); + } } diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index e06131503c..a871e82030 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -398,7 +398,12 @@ impl ModelClient { let headers = self.build_websocket_headers(turn_state.as_ref(), turn_metadata_header); let websocket_telemetry = ModelClientSession::build_websocket_telemetry(otel_manager); ApiWebSocketResponsesClient::new(api_provider, api_auth) - .connect(headers, turn_state, Some(websocket_telemetry)) + .connect( + headers, + crate::default_client::default_headers(), + turn_state, + Some(websocket_telemetry), + ) .await } diff --git a/codex-rs/core/src/default_client.rs b/codex-rs/core/src/default_client.rs index 94ecd8fcec..f4957883ae 100644 --- a/codex-rs/core/src/default_client.rs +++ b/codex-rs/core/src/default_client.rs @@ -179,6 +179,20 @@ pub fn create_client() -> CodexHttpClient { } pub fn build_reqwest_client() -> reqwest::Client { + let ua = get_codex_user_agent(); + + let mut builder = reqwest::Client::builder() + // Set UA via dedicated helper to avoid header validation pitfalls + .user_agent(ua) + .default_headers(default_headers()); + if is_sandboxed() { + builder = builder.no_proxy(); + } + + builder.build().unwrap_or_else(|_| reqwest::Client::new()) +} + +pub fn default_headers() -> HeaderMap { let mut headers = HeaderMap::new(); headers.insert("originator", originator().header_value); if let Ok(guard) = REQUIREMENTS_RESIDENCY.read() @@ -190,17 +204,7 @@ pub fn build_reqwest_client() -> reqwest::Client { }; headers.insert(RESIDENCY_HEADER_NAME, value); } - let ua = get_codex_user_agent(); - - let mut builder = reqwest::Client::builder() - // Set UA via dedicated helper to avoid header validation pitfalls - .user_agent(ua) - .default_headers(headers); - if is_sandboxed() { - builder = builder.no_proxy(); - } - - builder.build().unwrap_or_else(|_| reqwest::Client::new()) + headers } fn is_sandboxed() -> bool {