mirror of
https://github.com/openai/codex.git
synced 2026-04-28 02:11:08 +03:00
Add a codex.rate_limits event for websockets (#10324)
When communicating over websockets, we can't rely on headers to deliver rate limit information. This PR adds a `codex.rate_limits` event that the server can pass to the client to inform them about rate limit usage. The client parses this data the same way we parse rate limit headers in HTTP mode. This PR also wires up the etag and reasoning headers for websockets.
This commit is contained in:
@@ -5,6 +5,7 @@ use crate::common::ResponseStream;
|
||||
use crate::common::ResponsesWsRequest;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::rate_limits::parse_rate_limit_event;
|
||||
use crate::sse::responses::ResponsesStreamEvent;
|
||||
use crate::sse::responses::process_responses_event;
|
||||
use crate::telemetry::WebsocketTelemetry;
|
||||
@@ -33,6 +34,7 @@ use url::Url;
|
||||
|
||||
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
|
||||
const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
|
||||
const X_MODELS_ETAG_HEADER: &str = "x-models-etag";
|
||||
const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included";
|
||||
|
||||
pub struct ResponsesWebsocketConnection {
|
||||
@@ -40,6 +42,7 @@ pub struct ResponsesWebsocketConnection {
|
||||
// TODO (pakrym): is this the right place for timeout?
|
||||
idle_timeout: Duration,
|
||||
server_reasoning_included: bool,
|
||||
models_etag: Option<String>,
|
||||
telemetry: Option<Arc<dyn WebsocketTelemetry>>,
|
||||
}
|
||||
|
||||
@@ -48,12 +51,14 @@ impl ResponsesWebsocketConnection {
|
||||
stream: WsStream,
|
||||
idle_timeout: Duration,
|
||||
server_reasoning_included: bool,
|
||||
models_etag: Option<String>,
|
||||
telemetry: Option<Arc<dyn WebsocketTelemetry>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
stream: Arc::new(Mutex::new(Some(stream))),
|
||||
idle_timeout,
|
||||
server_reasoning_included,
|
||||
models_etag,
|
||||
telemetry,
|
||||
}
|
||||
}
|
||||
@@ -71,12 +76,16 @@ impl ResponsesWebsocketConnection {
|
||||
let stream = Arc::clone(&self.stream);
|
||||
let idle_timeout = self.idle_timeout;
|
||||
let server_reasoning_included = self.server_reasoning_included;
|
||||
let models_etag = self.models_etag.clone();
|
||||
let telemetry = self.telemetry.clone();
|
||||
let request_body = serde_json::to_value(&request).map_err(|err| {
|
||||
ApiError::Stream(format!("failed to encode websocket request: {err}"))
|
||||
})?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Some(etag) = models_etag {
|
||||
let _ = tx_event.send(Ok(ResponseEvent::ModelsEtag(etag))).await;
|
||||
}
|
||||
if server_reasoning_included {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::ServerReasoningIncluded(true)))
|
||||
@@ -136,12 +145,13 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
|
||||
headers.extend(extra_headers);
|
||||
add_auth_headers_to_header_map(&self.auth, &mut headers);
|
||||
|
||||
let (stream, server_reasoning_included) =
|
||||
connect_websocket(ws_url, headers, turn_state).await?;
|
||||
let (stream, server_reasoning_included, models_etag) =
|
||||
connect_websocket(ws_url, headers, turn_state.clone()).await?;
|
||||
Ok(ResponsesWebsocketConnection::new(
|
||||
stream,
|
||||
self.provider.stream_idle_timeout,
|
||||
server_reasoning_included,
|
||||
models_etag,
|
||||
telemetry,
|
||||
))
|
||||
}
|
||||
@@ -151,7 +161,7 @@ async fn connect_websocket(
|
||||
url: Url,
|
||||
headers: HeaderMap,
|
||||
turn_state: Option<Arc<OnceLock<String>>>,
|
||||
) -> Result<(WsStream, bool), ApiError> {
|
||||
) -> Result<(WsStream, bool, Option<String>), ApiError> {
|
||||
info!("connecting to websocket: {url}");
|
||||
|
||||
let mut request = url
|
||||
@@ -177,6 +187,11 @@ async fn connect_websocket(
|
||||
};
|
||||
|
||||
let reasoning_included = response.headers().contains_key(X_REASONING_INCLUDED_HEADER);
|
||||
let models_etag = response
|
||||
.headers()
|
||||
.get(X_MODELS_ETAG_HEADER)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(ToString::to_string);
|
||||
if let Some(turn_state) = turn_state
|
||||
&& let Some(header_value) = response
|
||||
.headers()
|
||||
@@ -185,7 +200,7 @@ async fn connect_websocket(
|
||||
{
|
||||
let _ = turn_state.set(header_value.to_string());
|
||||
}
|
||||
Ok((stream, reasoning_included))
|
||||
Ok((stream, reasoning_included, models_etag))
|
||||
}
|
||||
|
||||
fn map_ws_error(err: WsError, url: &Url) -> ApiError {
|
||||
@@ -273,6 +288,12 @@ async fn run_websocket_response_stream(
|
||||
continue;
|
||||
}
|
||||
};
|
||||
if event.kind() == "codex.rate_limits" {
|
||||
if let Some(snapshot) = parse_rate_limit_event(&text) {
|
||||
let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
match process_responses_event(event) {
|
||||
Ok(Some(event)) => {
|
||||
let is_completed = matches!(event, ResponseEvent::Completed { .. });
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use codex_protocol::account::PlanType;
|
||||
use codex_protocol::protocol::CreditsSnapshot;
|
||||
use codex_protocol::protocol::RateLimitSnapshot;
|
||||
use codex_protocol::protocol::RateLimitWindow;
|
||||
use http::HeaderMap;
|
||||
use serde::Deserialize;
|
||||
use std::fmt::Display;
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -41,6 +43,70 @@ pub fn parse_rate_limit(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct RateLimitEventWindow {
|
||||
used_percent: f64,
|
||||
window_minutes: Option<i64>,
|
||||
reset_at: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct RateLimitEventDetails {
|
||||
primary: Option<RateLimitEventWindow>,
|
||||
secondary: Option<RateLimitEventWindow>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct RateLimitEventCredits {
|
||||
has_credits: bool,
|
||||
unlimited: bool,
|
||||
balance: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct RateLimitEvent {
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
plan_type: Option<PlanType>,
|
||||
rate_limits: Option<RateLimitEventDetails>,
|
||||
credits: Option<RateLimitEventCredits>,
|
||||
}
|
||||
|
||||
pub fn parse_rate_limit_event(payload: &str) -> Option<RateLimitSnapshot> {
|
||||
let event: RateLimitEvent = serde_json::from_str(payload).ok()?;
|
||||
if event.kind != "codex.rate_limits" {
|
||||
return None;
|
||||
}
|
||||
let (primary, secondary) = if let Some(details) = event.rate_limits.as_ref() {
|
||||
(
|
||||
map_event_window(details.primary.as_ref()),
|
||||
map_event_window(details.secondary.as_ref()),
|
||||
)
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let credits = event.credits.map(|credits| CreditsSnapshot {
|
||||
has_credits: credits.has_credits,
|
||||
unlimited: credits.unlimited,
|
||||
balance: credits.balance,
|
||||
});
|
||||
Some(RateLimitSnapshot {
|
||||
primary,
|
||||
secondary,
|
||||
credits,
|
||||
plan_type: event.plan_type,
|
||||
})
|
||||
}
|
||||
|
||||
fn map_event_window(window: Option<&RateLimitEventWindow>) -> Option<RateLimitWindow> {
|
||||
let window = window?;
|
||||
Some(RateLimitWindow {
|
||||
used_percent: window.used_percent,
|
||||
window_minutes: window.window_minutes,
|
||||
resets_at: window.reset_at,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parses the `x-codex-promo-message` header into an optional promo message string.
|
||||
pub fn parse_promo_message(headers: &HeaderMap) -> Option<String> {
|
||||
parse_header_str(headers, "x-codex-promo-message")
|
||||
|
||||
@@ -165,6 +165,12 @@ pub struct ResponsesStreamEvent {
|
||||
content_index: Option<i64>,
|
||||
}
|
||||
|
||||
impl ResponsesStreamEvent {
|
||||
pub fn kind(&self) -> &str {
|
||||
&self.kind
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ResponsesEventError {
|
||||
Api(ApiError),
|
||||
|
||||
Reference in New Issue
Block a user