Restore rate limit and usage error handling

This commit is contained in:
jif-oai
2025-11-13 10:02:17 +00:00
parent 001ed59f5c
commit eaa68c3ae7
5 changed files with 349 additions and 78 deletions

View File

@@ -30,21 +30,14 @@ fn parse_rate_limit_window(
window_minutes_header: &str,
resets_at_header: &str,
) -> Option<RateLimitWindow> {
let used_percent: Option<f64> = parse_header_f64(headers, used_percent_header);
let used_percent: f64 = parse_header_f64(headers, used_percent_header)?;
let window_minutes = parse_header_i64(headers, window_minutes_header);
let resets_at = parse_header_i64(headers, resets_at_header);
used_percent.and_then(|used_percent| {
let window_minutes = parse_header_i64(headers, window_minutes_header);
let resets_at = parse_header_i64(headers, resets_at_header);
let has_data = used_percent != 0.0
|| window_minutes.is_some_and(|minutes| minutes != 0)
|| resets_at.is_some();
has_data.then_some(RateLimitWindow {
used_percent,
window_minutes,
resets_at,
})
Some(RateLimitWindow {
used_percent,
window_minutes,
resets_at,
})
}
@@ -62,3 +55,26 @@ fn parse_header_i64(headers: &HeaderMap, name: &str) -> Option<i64> {
fn parse_header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> {
headers.get(name)?.to_str().ok()
}
#[cfg(test)]
mod tests {
use super::*;
use reqwest::header::HeaderValue;
#[test]
fn snapshot_includes_zero_percent_values() {
let mut headers = HeaderMap::new();
headers.insert(
"x-codex-primary-used-percent",
HeaderValue::from_static("0.0"),
);
let snapshot = parse_rate_limit_snapshot(&headers).expect("snapshot should exist");
assert_eq!(snapshot.primary.unwrap().used_percent, 0.0);
}
#[test]
fn missing_headers_return_none() {
let headers = HeaderMap::new();
assert!(parse_rate_limit_snapshot(&headers).is_none());
}
}

View File

@@ -24,6 +24,10 @@ pub enum Error {
resets_at: Option<i64>,
rate_limits: Option<RateLimitSnapshot>,
},
#[error("usage not included")]
UsageNotIncluded,
#[error("quota exceeded")]
QuotaExceeded,
#[error("unexpected status {status}: {body}")]
UnexpectedStatus { status: StatusCode, body: String },
#[error("retry limit reached {status:?} request_id={request_id:?}")]

View File

@@ -3,7 +3,9 @@ use std::sync::Arc;
use codex_app_server_protocol::AuthMode;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::ConversationId;
use codex_protocol::protocol::RateLimitSnapshot;
use futures::TryStreamExt;
use reqwest::StatusCode;
use serde_json::Value;
use tracing::debug;
use tracing::trace;
@@ -50,23 +52,6 @@ impl ResponsesApiClient {
));
}
let auth = crate::client::http::resolve_auth(&self.config.auth_provider).await;
trace!(
"POST to {}: {:?}",
self.config.provider.get_full_url(
auth.as_ref()
.map(|a| codex_provider_config::AuthContext {
mode: a.mode,
bearer_token: a.bearer_token.clone(),
account_id: a.account_id.clone(),
})
.as_ref()
),
serde_json::to_string(payload_json)
.unwrap_or_else(|_| "<unable to serialize payload>".to_string())
);
let mut owned_headers: Vec<(String, String)> = vec![
(
"conversation_id".to_string(),
@@ -78,60 +63,255 @@ impl ResponsesApiClient {
),
];
owned_headers.extend(self.config.extra_headers.iter().cloned());
let extra_headers = crate::client::http::header_pairs(&owned_headers);
let mut req_builder = crate::client::http::build_request(
&self.config.http_client,
&self.config.provider,
&auth,
&extra_headers,
)
.await?;
req_builder = req_builder
.header(reqwest::header::ACCEPT, "text/event-stream")
.json(payload_json);
let mut refreshed_auth = false;
loop {
let auth = crate::client::http::resolve_auth(&self.config.auth_provider).await;
if let Some(auth_ctx) = auth.as_ref()
&& auth_ctx.mode == AuthMode::ChatGPT
&& let Some(account_id) = auth_ctx.account_id.clone()
{
req_builder = req_builder.header("chatgpt-account-id", account_id);
}
trace!(
"POST to {}: {:?}",
self.config.provider.get_full_url(
auth.as_ref()
.map(|a| codex_provider_config::AuthContext {
mode: a.mode,
bearer_token: a.bearer_token.clone(),
account_id: a.account_id.clone(),
})
.as_ref()
),
serde_json::to_string(payload_json)
.unwrap_or_else(|_| "<unable to serialize payload>".to_string())
);
let res = self
.config
.otel_event_manager
.log_request(0, || req_builder.send())
.await
.map_err(|source| Error::ResponseStreamFailed {
source,
request_id: None,
})?;
let extra_headers = crate::client::http::header_pairs(&owned_headers);
let mut req_builder = crate::client::http::build_request(
&self.config.http_client,
&self.config.provider,
&auth,
&extra_headers,
)
.await?;
let snapshot = crate::client::rate_limits::parse_rate_limit_snapshot(res.headers());
req_builder = req_builder
.header(reqwest::header::ACCEPT, "text/event-stream")
.json(payload_json);
let stream = res
.bytes_stream()
.map_err(|err| Error::ResponseStreamFailed {
source: err,
request_id: None,
});
if let Some(auth_ctx) = auth.as_ref()
&& auth_ctx.mode == AuthMode::ChatGPT
&& let Some(account_id) = auth_ctx.account_id.clone()
{
req_builder = req_builder.header("chatgpt-account-id", account_id);
}
let (tx_event, rx_event) = crate::client::sse::spawn_wire_stream(
stream,
&self.config.provider,
self.config.otel_event_manager.clone(),
crate::decode_wire::responses::WireResponsesSseDecoder,
);
if let Some(snapshot) = snapshot
&& tx_event
.send(Ok(crate::stream::WireEvent::RateLimits(snapshot)))
let res = self
.config
.otel_event_manager
.log_request(0, || req_builder.send())
.await
.is_err()
{
debug!("receiver dropped rate limit snapshot event");
}
.map_err(|source| Error::ResponseStreamFailed {
source,
request_id: None,
})?;
Ok(rx_event)
let status = res.status();
let snapshot = crate::client::rate_limits::parse_rate_limit_snapshot(res.headers());
if !status.is_success() {
if status == StatusCode::UNAUTHORIZED
&& !refreshed_auth
&& self.config.auth_provider.is_some()
&& let Some(provider) = &self.config.auth_provider {
provider.refresh_token().await?;
refreshed_auth = true;
continue;
}
let body = res
.text()
.await
.unwrap_or_else(|err| format!("<failed to read body: {err}>"));
return Err(map_error_response(status, &body, snapshot));
}
let stream = res
.bytes_stream()
.map_err(|err| Error::ResponseStreamFailed {
source: err,
request_id: None,
});
let (tx_event, rx_event) = crate::client::sse::spawn_wire_stream(
stream,
&self.config.provider,
self.config.otel_event_manager.clone(),
crate::decode_wire::responses::WireResponsesSseDecoder,
);
if let Some(snapshot) = snapshot
&& tx_event
.send(Ok(crate::stream::WireEvent::RateLimits(snapshot)))
.await
.is_err()
{
debug!("receiver dropped rate limit snapshot event");
}
return Ok(rx_event);
}
}
}
fn map_error_response(
status: StatusCode,
body: &str,
rate_limits: Option<RateLimitSnapshot>,
) -> Error {
if let Ok(value) = serde_json::from_str::<Value>(body)
&& let Some(error) = value.get("error") {
let error_code = error
.get("type")
.or_else(|| error.get("code"))
.and_then(|value| value.as_str())
.map(str::to_lowercase);
if let Some(code) = error_code.as_deref() {
match code {
"usage_limit_reached" => {
let plan_type = extract_string_field(
error,
&[
&["plan_type"],
&["metadata", "plan_type"],
&["details", "plan_type"],
],
);
let resets_at = extract_i64_field(
error,
&[
&["resets_at"],
&["metadata", "resets_at"],
&["details", "resets_at"],
],
);
return Error::UsageLimitReached {
plan_type,
resets_at,
rate_limits,
};
}
"usage_not_included" => {
return Error::UsageNotIncluded;
}
"quota_exceeded" | "insufficient_quota" => {
return Error::QuotaExceeded;
}
_ => {}
}
}
if let Some(message) = error.get("message").and_then(|v| v.as_str())
&& !message.is_empty() {
return Error::Stream(message.to_string(), None);
}
}
Error::UnexpectedStatus {
status,
body: body.to_string(),
}
}
fn extract_string_field(value: &Value, paths: &[&[&str]]) -> Option<String> {
paths
.iter()
.filter_map(|path| nested_value(value, path))
.find_map(|candidate| candidate.as_str().map(std::string::ToString::to_string))
}
fn extract_i64_field(value: &Value, paths: &[&[&str]]) -> Option<i64> {
paths
.iter()
.filter_map(|path| nested_value(value, path))
.find_map(|candidate| match candidate {
Value::Number(num) => num.as_i64(),
Value::String(text) => text.parse::<i64>().ok(),
_ => None,
})
}
fn nested_value<'a>(value: &'a Value, path: &[&str]) -> Option<&'a Value> {
let mut current = value;
for segment in path {
current = current.get(segment)?;
}
Some(current)
}
#[cfg(test)]
mod tests {
use super::*;
use codex_protocol::protocol::RateLimitWindow;
use serde_json::json;
fn snapshot() -> RateLimitSnapshot {
RateLimitSnapshot {
primary: Some(RateLimitWindow {
used_percent: 40.0,
window_minutes: Some(15),
resets_at: Some(1_704_067_200),
}),
secondary: None,
}
}
#[test]
fn usage_limit_error_includes_metadata() {
let body = json!({
"error": {
"type": "usage_limit_reached",
"message": "limit",
"plan_type": "pro",
"resets_at": 1704,
}
})
.to_string();
let err = map_error_response(StatusCode::TOO_MANY_REQUESTS, &body, Some(snapshot()));
match err {
Error::UsageLimitReached {
plan_type,
resets_at,
rate_limits,
} => {
assert_eq!(plan_type.as_deref(), Some("pro"));
assert_eq!(resets_at, Some(1704));
assert!(rate_limits.is_some());
}
other => panic!("unexpected error: {other:?}"),
}
}
#[test]
fn usage_not_included_maps_to_specific_variant() {
let body = json!({
"error": {
"code": "usage_not_included",
"message": "upgrade",
}
})
.to_string();
let err = map_error_response(StatusCode::PAYMENT_REQUIRED, &body, None);
assert!(matches!(err, Error::UsageNotIncluded));
}
#[test]
fn unexpected_status_falls_back_to_generic_error() {
let err = map_error_response(StatusCode::BAD_REQUEST, "oops", None);
match err {
Error::UnexpectedStatus { status, body } => {
assert_eq!(status, StatusCode::BAD_REQUEST);
assert_eq!(body, "oops");
}
other => panic!("expected UnexpectedStatus, got {other:?}"),
}
}
}