mirror of
https://github.com/openai/codex.git
synced 2026-04-28 02:11:08 +03:00
Translate websocket errors (#10937)
When getting errors over a websocket connection, translate the error into our regular API error format
This commit is contained in:
@@ -13,7 +13,12 @@ use codex_client::TransportError;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use http::HeaderMap;
|
||||
use http::HeaderName;
|
||||
use http::HeaderValue;
|
||||
use http::StatusCode;
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::map::Map as JsonMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
@@ -252,6 +257,83 @@ fn map_ws_error(err: WsError, url: &Url) -> ApiError {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct WrappedWebsocketErrorEvent {
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
#[serde(alias = "status_code")]
|
||||
status: Option<u16>,
|
||||
#[serde(default)]
|
||||
error: Option<Value>,
|
||||
#[serde(default)]
|
||||
headers: Option<JsonMap<String, Value>>,
|
||||
}
|
||||
|
||||
fn parse_wrapped_websocket_error_event(payload: &str) -> Option<WrappedWebsocketErrorEvent> {
|
||||
let event: WrappedWebsocketErrorEvent = serde_json::from_str(payload).ok()?;
|
||||
if event.kind != "error" {
|
||||
return None;
|
||||
}
|
||||
Some(event)
|
||||
}
|
||||
|
||||
fn map_wrapped_websocket_error_event(event: WrappedWebsocketErrorEvent) -> Option<ApiError> {
|
||||
let WrappedWebsocketErrorEvent {
|
||||
status,
|
||||
error,
|
||||
headers,
|
||||
..
|
||||
} = event;
|
||||
|
||||
let status = StatusCode::from_u16(status?).ok()?;
|
||||
if status.is_success() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let body = error.map(|error| {
|
||||
serde_json::to_string_pretty(&serde_json::json!({
|
||||
"error": error
|
||||
}))
|
||||
.unwrap_or_else(|_| {
|
||||
serde_json::json!({
|
||||
"error": error
|
||||
})
|
||||
.to_string()
|
||||
})
|
||||
});
|
||||
|
||||
Some(ApiError::Transport(TransportError::Http {
|
||||
status,
|
||||
url: None,
|
||||
headers: headers.map(json_headers_to_http_headers),
|
||||
body,
|
||||
}))
|
||||
}
|
||||
|
||||
fn json_headers_to_http_headers(headers: JsonMap<String, Value>) -> HeaderMap {
|
||||
let mut mapped = HeaderMap::new();
|
||||
for (name, value) in headers {
|
||||
let Ok(header_name) = HeaderName::from_bytes(name.as_bytes()) else {
|
||||
continue;
|
||||
};
|
||||
let Some(header_value) = json_header_value(value) else {
|
||||
continue;
|
||||
};
|
||||
mapped.insert(header_name, header_value);
|
||||
}
|
||||
mapped
|
||||
}
|
||||
|
||||
fn json_header_value(value: Value) -> Option<HeaderValue> {
|
||||
let value = match value {
|
||||
Value::String(value) => value,
|
||||
Value::Number(value) => value.to_string(),
|
||||
Value::Bool(value) => value.to_string(),
|
||||
_ => return None,
|
||||
};
|
||||
HeaderValue::from_str(&value).ok()
|
||||
}
|
||||
|
||||
async fn run_websocket_response_stream(
|
||||
ws_stream: &mut WsStream,
|
||||
tx_event: mpsc::Sender<std::result::Result<ResponseEvent, ApiError>>,
|
||||
@@ -306,6 +388,12 @@ async fn run_websocket_response_stream(
|
||||
match message {
|
||||
Message::Text(text) => {
|
||||
trace!("websocket event: {text}");
|
||||
if let Some(wrapped_error) = parse_wrapped_websocket_error_event(&text)
|
||||
&& let Some(error) = map_wrapped_websocket_error_event(wrapped_error)
|
||||
{
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
let event = match serde_json::from_str::<ResponsesStreamEvent>(&text) {
|
||||
Ok(event) => event,
|
||||
Err(err) => {
|
||||
@@ -357,10 +445,124 @@ async fn run_websocket_response_stream(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn websocket_config_enables_permessage_deflate() {
|
||||
let config = websocket_config();
|
||||
assert!(config.extensions.permessage_deflate.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_wrapped_websocket_error_event_maps_to_transport_http() {
|
||||
let payload = json!({
|
||||
"type": "error",
|
||||
"status": 429,
|
||||
"error": {
|
||||
"type": "usage_limit_reached",
|
||||
"message": "The usage limit has been reached",
|
||||
"plan_type": "pro",
|
||||
"resets_at": 1738888888
|
||||
},
|
||||
"headers": {
|
||||
"x-codex-primary-used-percent": "100.0",
|
||||
"x-codex-primary-window-minutes": 15
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let wrapped_error = parse_wrapped_websocket_error_event(&payload)
|
||||
.expect("expected websocket error payload to be parsed");
|
||||
let api_error = map_wrapped_websocket_error_event(wrapped_error)
|
||||
.expect("expected websocket error payload to map to ApiError");
|
||||
|
||||
let ApiError::Transport(TransportError::Http {
|
||||
status,
|
||||
headers,
|
||||
body,
|
||||
..
|
||||
}) = api_error
|
||||
else {
|
||||
panic!("expected ApiError::Transport(Http)");
|
||||
};
|
||||
|
||||
assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
|
||||
let headers = headers.expect("expected headers");
|
||||
assert_eq!(
|
||||
headers
|
||||
.get("x-codex-primary-used-percent")
|
||||
.and_then(|value| value.to_str().ok()),
|
||||
Some("100.0")
|
||||
);
|
||||
assert_eq!(
|
||||
headers
|
||||
.get("x-codex-primary-window-minutes")
|
||||
.and_then(|value| value.to_str().ok()),
|
||||
Some("15")
|
||||
);
|
||||
let body = body.expect("expected body");
|
||||
assert!(body.contains("usage_limit_reached"));
|
||||
assert!(body.contains("The usage limit has been reached"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_wrapped_websocket_error_event_ignores_non_error_payloads() {
|
||||
let payload = json!({
|
||||
"type": "response.created",
|
||||
"response": {
|
||||
"id": "resp-1"
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let wrapped_error = parse_wrapped_websocket_error_event(&payload);
|
||||
assert!(wrapped_error.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_wrapped_websocket_error_event_with_status_maps_invalid_request() {
|
||||
let payload = json!({
|
||||
"type": "error",
|
||||
"status": 400,
|
||||
"error": {
|
||||
"type": "invalid_request_error",
|
||||
"message": "Model does not support image inputs"
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let wrapped_error = parse_wrapped_websocket_error_event(&payload)
|
||||
.expect("expected websocket error payload to be parsed");
|
||||
let api_error = map_wrapped_websocket_error_event(wrapped_error)
|
||||
.expect("expected websocket error payload to map to ApiError");
|
||||
let ApiError::Transport(TransportError::Http { status, body, .. }) = api_error else {
|
||||
panic!("expected ApiError::Transport(Http)");
|
||||
};
|
||||
assert_eq!(status, StatusCode::BAD_REQUEST);
|
||||
let body = body.expect("expected body");
|
||||
assert!(body.contains("invalid_request_error"));
|
||||
assert!(body.contains("Model does not support image inputs"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_wrapped_websocket_error_event_without_status_is_not_mapped() {
|
||||
let payload = json!({
|
||||
"type": "error",
|
||||
"error": {
|
||||
"type": "usage_limit_reached",
|
||||
"message": "The usage limit has been reached"
|
||||
},
|
||||
"headers": {
|
||||
"x-codex-primary-used-percent": "100.0",
|
||||
"x-codex-primary-window-minutes": 15
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let wrapped_error = parse_wrapped_websocket_error_event(&payload)
|
||||
.expect("expected websocket error payload to be parsed");
|
||||
let api_error = map_wrapped_websocket_error_event(wrapped_error);
|
||||
assert!(api_error.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user