diff --git a/codex-rs/api-client/src/client/rate_limits.rs b/codex-rs/api-client/src/client/rate_limits.rs index d4ecec2223..d0c329e2e4 100644 --- a/codex-rs/api-client/src/client/rate_limits.rs +++ b/codex-rs/api-client/src/client/rate_limits.rs @@ -30,21 +30,14 @@ fn parse_rate_limit_window( window_minutes_header: &str, resets_at_header: &str, ) -> Option { - let used_percent: Option = parse_header_f64(headers, used_percent_header); + let used_percent: f64 = parse_header_f64(headers, used_percent_header)?; + let window_minutes = parse_header_i64(headers, window_minutes_header); + let resets_at = parse_header_i64(headers, resets_at_header); - used_percent.and_then(|used_percent| { - let window_minutes = parse_header_i64(headers, window_minutes_header); - let resets_at = parse_header_i64(headers, resets_at_header); - - let has_data = used_percent != 0.0 - || window_minutes.is_some_and(|minutes| minutes != 0) - || resets_at.is_some(); - - has_data.then_some(RateLimitWindow { - used_percent, - window_minutes, - resets_at, - }) + Some(RateLimitWindow { + used_percent, + window_minutes, + resets_at, }) } @@ -62,3 +55,26 @@ fn parse_header_i64(headers: &HeaderMap, name: &str) -> Option { fn parse_header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> { headers.get(name)?.to_str().ok() } + +#[cfg(test)] +mod tests { + use super::*; + use reqwest::header::HeaderValue; + + #[test] + fn snapshot_includes_zero_percent_values() { + let mut headers = HeaderMap::new(); + headers.insert( + "x-codex-primary-used-percent", + HeaderValue::from_static("0.0"), + ); + let snapshot = parse_rate_limit_snapshot(&headers).expect("snapshot should exist"); + assert_eq!(snapshot.primary.unwrap().used_percent, 0.0); + } + + #[test] + fn missing_headers_return_none() { + let headers = HeaderMap::new(); + assert!(parse_rate_limit_snapshot(&headers).is_none()); + } +} diff --git a/codex-rs/api-client/src/error.rs b/codex-rs/api-client/src/error.rs index d740430899..133b9e095b 100644 --- a/codex-rs/api-client/src/error.rs +++ b/codex-rs/api-client/src/error.rs @@ -24,6 +24,10 @@ pub enum Error { resets_at: Option, rate_limits: Option, }, + #[error("usage not included")] + UsageNotIncluded, + #[error("quota exceeded")] + QuotaExceeded, #[error("unexpected status {status}: {body}")] UnexpectedStatus { status: StatusCode, body: String }, #[error("retry limit reached {status:?} request_id={request_id:?}")] diff --git a/codex-rs/api-client/src/responses.rs b/codex-rs/api-client/src/responses.rs index 9659303960..8cba39f86f 100644 --- a/codex-rs/api-client/src/responses.rs +++ b/codex-rs/api-client/src/responses.rs @@ -3,7 +3,9 @@ use std::sync::Arc; use codex_app_server_protocol::AuthMode; use codex_otel::otel_event_manager::OtelEventManager; use codex_protocol::ConversationId; +use codex_protocol::protocol::RateLimitSnapshot; use futures::TryStreamExt; +use reqwest::StatusCode; use serde_json::Value; use tracing::debug; use tracing::trace; @@ -50,23 +52,6 @@ impl ResponsesApiClient { )); } - let auth = crate::client::http::resolve_auth(&self.config.auth_provider).await; - - trace!( - "POST to {}: {:?}", - self.config.provider.get_full_url( - auth.as_ref() - .map(|a| codex_provider_config::AuthContext { - mode: a.mode, - bearer_token: a.bearer_token.clone(), - account_id: a.account_id.clone(), - }) - .as_ref() - ), - serde_json::to_string(payload_json) - .unwrap_or_else(|_| "".to_string()) - ); - let mut owned_headers: Vec<(String, String)> = vec![ ( "conversation_id".to_string(), @@ -78,60 +63,255 @@ impl ResponsesApiClient { ), ]; owned_headers.extend(self.config.extra_headers.iter().cloned()); - let extra_headers = crate::client::http::header_pairs(&owned_headers); - let mut req_builder = crate::client::http::build_request( - &self.config.http_client, - &self.config.provider, - &auth, - &extra_headers, - ) - .await?; - req_builder = req_builder - .header(reqwest::header::ACCEPT, "text/event-stream") - .json(payload_json); + let mut refreshed_auth = false; + loop { + let auth = crate::client::http::resolve_auth(&self.config.auth_provider).await; - if let Some(auth_ctx) = auth.as_ref() - && auth_ctx.mode == AuthMode::ChatGPT - && let Some(account_id) = auth_ctx.account_id.clone() - { - req_builder = req_builder.header("chatgpt-account-id", account_id); - } + trace!( + "POST to {}: {:?}", + self.config.provider.get_full_url( + auth.as_ref() + .map(|a| codex_provider_config::AuthContext { + mode: a.mode, + bearer_token: a.bearer_token.clone(), + account_id: a.account_id.clone(), + }) + .as_ref() + ), + serde_json::to_string(payload_json) + .unwrap_or_else(|_| "".to_string()) + ); - let res = self - .config - .otel_event_manager - .log_request(0, || req_builder.send()) - .await - .map_err(|source| Error::ResponseStreamFailed { - source, - request_id: None, - })?; + let extra_headers = crate::client::http::header_pairs(&owned_headers); + let mut req_builder = crate::client::http::build_request( + &self.config.http_client, + &self.config.provider, + &auth, + &extra_headers, + ) + .await?; - let snapshot = crate::client::rate_limits::parse_rate_limit_snapshot(res.headers()); + req_builder = req_builder + .header(reqwest::header::ACCEPT, "text/event-stream") + .json(payload_json); - let stream = res - .bytes_stream() - .map_err(|err| Error::ResponseStreamFailed { - source: err, - request_id: None, - }); + if let Some(auth_ctx) = auth.as_ref() + && auth_ctx.mode == AuthMode::ChatGPT + && let Some(account_id) = auth_ctx.account_id.clone() + { + req_builder = req_builder.header("chatgpt-account-id", account_id); + } - let (tx_event, rx_event) = crate::client::sse::spawn_wire_stream( - stream, - &self.config.provider, - self.config.otel_event_manager.clone(), - crate::decode_wire::responses::WireResponsesSseDecoder, - ); - if let Some(snapshot) = snapshot - && tx_event - .send(Ok(crate::stream::WireEvent::RateLimits(snapshot))) + let res = self + .config + .otel_event_manager + .log_request(0, || req_builder.send()) .await - .is_err() - { - debug!("receiver dropped rate limit snapshot event"); - } + .map_err(|source| Error::ResponseStreamFailed { + source, + request_id: None, + })?; - Ok(rx_event) + let status = res.status(); + let snapshot = crate::client::rate_limits::parse_rate_limit_snapshot(res.headers()); + + if !status.is_success() { + if status == StatusCode::UNAUTHORIZED + && !refreshed_auth + && self.config.auth_provider.is_some() + && let Some(provider) = &self.config.auth_provider { + provider.refresh_token().await?; + refreshed_auth = true; + continue; + } + + let body = res + .text() + .await + .unwrap_or_else(|err| format!("")); + return Err(map_error_response(status, &body, snapshot)); + } + + let stream = res + .bytes_stream() + .map_err(|err| Error::ResponseStreamFailed { + source: err, + request_id: None, + }); + + let (tx_event, rx_event) = crate::client::sse::spawn_wire_stream( + stream, + &self.config.provider, + self.config.otel_event_manager.clone(), + crate::decode_wire::responses::WireResponsesSseDecoder, + ); + if let Some(snapshot) = snapshot + && tx_event + .send(Ok(crate::stream::WireEvent::RateLimits(snapshot))) + .await + .is_err() + { + debug!("receiver dropped rate limit snapshot event"); + } + + return Ok(rx_event); + } + } +} + +fn map_error_response( + status: StatusCode, + body: &str, + rate_limits: Option, +) -> Error { + if let Ok(value) = serde_json::from_str::(body) + && let Some(error) = value.get("error") { + let error_code = error + .get("type") + .or_else(|| error.get("code")) + .and_then(|value| value.as_str()) + .map(str::to_lowercase); + if let Some(code) = error_code.as_deref() { + match code { + "usage_limit_reached" => { + let plan_type = extract_string_field( + error, + &[ + &["plan_type"], + &["metadata", "plan_type"], + &["details", "plan_type"], + ], + ); + let resets_at = extract_i64_field( + error, + &[ + &["resets_at"], + &["metadata", "resets_at"], + &["details", "resets_at"], + ], + ); + return Error::UsageLimitReached { + plan_type, + resets_at, + rate_limits, + }; + } + "usage_not_included" => { + return Error::UsageNotIncluded; + } + "quota_exceeded" | "insufficient_quota" => { + return Error::QuotaExceeded; + } + _ => {} + } + } + + if let Some(message) = error.get("message").and_then(|v| v.as_str()) + && !message.is_empty() { + return Error::Stream(message.to_string(), None); + } + } + + Error::UnexpectedStatus { + status, + body: body.to_string(), + } +} + +fn extract_string_field(value: &Value, paths: &[&[&str]]) -> Option { + paths + .iter() + .filter_map(|path| nested_value(value, path)) + .find_map(|candidate| candidate.as_str().map(std::string::ToString::to_string)) +} + +fn extract_i64_field(value: &Value, paths: &[&[&str]]) -> Option { + paths + .iter() + .filter_map(|path| nested_value(value, path)) + .find_map(|candidate| match candidate { + Value::Number(num) => num.as_i64(), + Value::String(text) => text.parse::().ok(), + _ => None, + }) +} + +fn nested_value<'a>(value: &'a Value, path: &[&str]) -> Option<&'a Value> { + let mut current = value; + for segment in path { + current = current.get(segment)?; + } + Some(current) +} + +#[cfg(test)] +mod tests { + use super::*; + use codex_protocol::protocol::RateLimitWindow; + use serde_json::json; + + fn snapshot() -> RateLimitSnapshot { + RateLimitSnapshot { + primary: Some(RateLimitWindow { + used_percent: 40.0, + window_minutes: Some(15), + resets_at: Some(1_704_067_200), + }), + secondary: None, + } + } + + #[test] + fn usage_limit_error_includes_metadata() { + let body = json!({ + "error": { + "type": "usage_limit_reached", + "message": "limit", + "plan_type": "pro", + "resets_at": 1704, + } + }) + .to_string(); + + let err = map_error_response(StatusCode::TOO_MANY_REQUESTS, &body, Some(snapshot())); + match err { + Error::UsageLimitReached { + plan_type, + resets_at, + rate_limits, + } => { + assert_eq!(plan_type.as_deref(), Some("pro")); + assert_eq!(resets_at, Some(1704)); + assert!(rate_limits.is_some()); + } + other => panic!("unexpected error: {other:?}"), + } + } + + #[test] + fn usage_not_included_maps_to_specific_variant() { + let body = json!({ + "error": { + "code": "usage_not_included", + "message": "upgrade", + } + }) + .to_string(); + + let err = map_error_response(StatusCode::PAYMENT_REQUIRED, &body, None); + assert!(matches!(err, Error::UsageNotIncluded)); + } + + #[test] + fn unexpected_status_falls_back_to_generic_error() { + let err = map_error_response(StatusCode::BAD_REQUEST, "oops", None); + match err { + Error::UnexpectedStatus { status, body } => { + assert_eq!(status, StatusCode::BAD_REQUEST); + assert_eq!(body, "oops"); + } + other => panic!("expected UnexpectedStatus, got {other:?}"), + } } } diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index d7578baf0f..0bb1174137 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -339,6 +339,8 @@ fn map_api_error(err: codex_api_client::Error) -> CodexErr { rate_limits, }) } + codex_api_client::Error::UsageNotIncluded => CodexErr::UsageNotIncluded, + codex_api_client::Error::QuotaExceeded => CodexErr::QuotaExceeded, codex_api_client::Error::UnexpectedStatus { status, body } => { CodexErr::UnexpectedStatus(UnexpectedResponseError { status, diff --git a/codex-rs/core/src/wire_payload.rs b/codex-rs/core/src/wire_payload.rs index 565a0afd15..96b658703f 100644 --- a/codex-rs/core/src/wire_payload.rs +++ b/codex-rs/core/src/wire_payload.rs @@ -340,3 +340,72 @@ pub fn build_chat_payload(prompt: &Prompt, model: &str, instructions: String) -> .unwrap_or_else(|_| Vec::::new()); json!({ "model": model, "messages": messages, "stream": true, "tools": tools_json }) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::client_common::Prompt; + use codex_protocol::ConversationId; + use codex_protocol::models::ContentItem; + use codex_protocol::models::ReasoningItemReasoningSummary; + use codex_protocol::models::ResponseItem; + + fn prompt_with_items(items: Vec) -> Prompt { + Prompt { + input: items, + tools: Vec::new(), + parallel_tool_calls: false, + base_instructions_override: None, + output_schema: None, + } + } + + #[test] + fn azure_payload_includes_existing_item_ids() { + let prompt = prompt_with_items(vec![ + ResponseItem::Message { + id: Some("msg-1".to_string()), + role: "assistant".to_string(), + content: vec![ContentItem::InputText { + text: "hello".to_string(), + }], + }, + ResponseItem::Reasoning { + id: "reason-1".to_string(), + summary: vec![ReasoningItemReasoningSummary::SummaryText { + text: "thinking".to_string(), + }], + content: None, + encrypted_content: None, + }, + ]); + + let payload = build_responses_payload( + &prompt, + "gpt-5", + ConversationId::new(), + true, + None, + None, + "instructions".to_string(), + ); + + assert_eq!(payload.get("store").and_then(Value::as_bool), Some(true)); + let input = payload + .get("input") + .and_then(Value::as_array) + .expect("input array present"); + let ids: Vec<_> = input + .iter() + .map(|item| { + item.get("id") + .and_then(Value::as_str) + .map(std::string::ToString::to_string) + }) + .collect(); + assert_eq!( + ids, + vec![Some("msg-1".to_string()), Some("reason-1".to_string())] + ); + } +}