mirror of
https://github.com/openai/codex.git
synced 2026-05-03 21:01:55 +03:00
Add safety check notification and error handling (#19055)
Adds a new app-server notification that fires when a user account has been flagged for potential safety reasons.
This commit is contained in:
@@ -32,6 +32,7 @@ pub fn map_api_error(err: ApiError) -> CodexErr {
|
||||
identity_error_code: None,
|
||||
}),
|
||||
ApiError::InvalidRequest { message } => CodexErr::InvalidRequest(message),
|
||||
ApiError::CyberPolicy { message } => CodexErr::CyberPolicy { message },
|
||||
ApiError::Transport(transport) => match transport {
|
||||
TransportError::Http {
|
||||
status,
|
||||
@@ -55,7 +56,19 @@ pub fn map_api_error(err: ApiError) -> CodexErr {
|
||||
}
|
||||
|
||||
if status == http::StatusCode::BAD_REQUEST {
|
||||
if body_text
|
||||
if let Ok(parsed) = serde_json::from_str::<Value>(&body_text)
|
||||
&& let Some(error) = parsed.get("error")
|
||||
&& error.get("code").and_then(Value::as_str)
|
||||
== Some(CYBER_POLICY_ERROR_CODE)
|
||||
{
|
||||
let message = error
|
||||
.get("message")
|
||||
.and_then(Value::as_str)
|
||||
.filter(|message| !message.trim().is_empty())
|
||||
.map(str::to_string)
|
||||
.unwrap_or_else(|| CYBER_POLICY_FALLBACK_MESSAGE.to_string());
|
||||
CodexErr::CyberPolicy { message }
|
||||
} else if body_text
|
||||
.contains("The image data you provided does not represent a valid image")
|
||||
{
|
||||
CodexErr::InvalidImageRequest()
|
||||
@@ -125,6 +138,9 @@ const OAI_REQUEST_ID_HEADER: &str = "x-oai-request-id";
|
||||
const CF_RAY_HEADER: &str = "cf-ray";
|
||||
const X_OPENAI_AUTHORIZATION_ERROR_HEADER: &str = "x-openai-authorization-error";
|
||||
const X_ERROR_JSON_HEADER: &str = "x-error-json";
|
||||
const CYBER_POLICY_ERROR_CODE: &str = "cyber_policy";
|
||||
const CYBER_POLICY_FALLBACK_MESSAGE: &str =
|
||||
"This request has been flagged for possible cybersecurity risk.";
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "api_bridge_tests.rs"]
|
||||
|
||||
@@ -26,6 +26,104 @@ fn map_api_error_maps_server_overloaded_from_503_body() {
|
||||
assert!(matches!(err, CodexErr::ServerOverloaded));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn map_api_error_maps_cyber_policy_from_400_body() {
|
||||
let body = serde_json::json!({
|
||||
"error": {
|
||||
"message": "This request has been flagged for potentially high-risk cyber activity.",
|
||||
"type": "invalid_request",
|
||||
"param": null,
|
||||
"code": "cyber_policy"
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
let err = map_api_error(ApiError::Transport(TransportError::Http {
|
||||
status: http::StatusCode::BAD_REQUEST,
|
||||
url: Some("http://example.com/v1/responses".to_string()),
|
||||
headers: None,
|
||||
body: Some(body),
|
||||
}));
|
||||
|
||||
let CodexErr::CyberPolicy { message } = err else {
|
||||
panic!("expected CodexErr::CyberPolicy, got {err:?}");
|
||||
};
|
||||
assert_eq!(
|
||||
message,
|
||||
"This request has been flagged for potentially high-risk cyber activity."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn map_api_error_maps_wrapped_websocket_cyber_policy_from_400_body() {
|
||||
let body = serde_json::json!({
|
||||
"type": "error",
|
||||
"status": 400,
|
||||
"error": {
|
||||
"message": "This websocket request was flagged.",
|
||||
"type": "invalid_request",
|
||||
"code": "cyber_policy"
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
let err = map_api_error(ApiError::Transport(TransportError::Http {
|
||||
status: http::StatusCode::BAD_REQUEST,
|
||||
url: Some("ws://example.com/v1/responses".to_string()),
|
||||
headers: None,
|
||||
body: Some(body),
|
||||
}));
|
||||
|
||||
let CodexErr::CyberPolicy { message } = err else {
|
||||
panic!("expected CodexErr::CyberPolicy, got {err:?}");
|
||||
};
|
||||
assert_eq!(message, "This websocket request was flagged.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn map_api_error_uses_cyber_policy_fallback_for_missing_message() {
|
||||
let body = serde_json::json!({
|
||||
"error": {
|
||||
"code": "cyber_policy"
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
let err = map_api_error(ApiError::Transport(TransportError::Http {
|
||||
status: http::StatusCode::BAD_REQUEST,
|
||||
url: Some("http://example.com/v1/responses".to_string()),
|
||||
headers: None,
|
||||
body: Some(body),
|
||||
}));
|
||||
|
||||
let CodexErr::CyberPolicy { message } = err else {
|
||||
panic!("expected CodexErr::CyberPolicy, got {err:?}");
|
||||
};
|
||||
assert_eq!(
|
||||
message,
|
||||
"This request has been flagged for possible cybersecurity risk."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn map_api_error_keeps_unknown_400_errors_generic() {
|
||||
let body = serde_json::json!({
|
||||
"error": {
|
||||
"message": "Some other bad request.",
|
||||
"code": "some_other_policy"
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
let err = map_api_error(ApiError::Transport(TransportError::Http {
|
||||
status: http::StatusCode::BAD_REQUEST,
|
||||
url: Some("http://example.com/v1/responses".to_string()),
|
||||
headers: None,
|
||||
body: Some(body.clone()),
|
||||
}));
|
||||
|
||||
let CodexErr::InvalidRequest(message) = err else {
|
||||
panic!("expected CodexErr::InvalidRequest, got {err:?}");
|
||||
};
|
||||
assert_eq!(message, body);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn map_api_error_maps_usage_limit_limit_name_header() {
|
||||
let mut headers = HeaderMap::new();
|
||||
|
||||
@@ -3,6 +3,7 @@ use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
||||
use codex_protocol::config_types::Verbosity as VerbosityConfig;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::openai_models::ReasoningEffort as ReasoningEffortConfig;
|
||||
use codex_protocol::protocol::ModelVerification;
|
||||
use codex_protocol::protocol::RateLimitSnapshot;
|
||||
use codex_protocol::protocol::TokenUsage;
|
||||
use codex_protocol::protocol::W3cTraceContext;
|
||||
@@ -71,6 +72,8 @@ pub enum ResponseEvent {
|
||||
/// Emitted when the server includes `OpenAI-Model` on the stream response.
|
||||
/// This can differ from the requested model when backend safety routing applies.
|
||||
ServerModel(String),
|
||||
/// Emitted when the server recommends additional account verification.
|
||||
ModelVerifications(Vec<ModelVerification>),
|
||||
/// Emitted when `X-Reasoning-Included: true` is present on the response,
|
||||
/// meaning the server already accounted for past reasoning tokens and the
|
||||
/// client should not re-estimate them.
|
||||
|
||||
@@ -608,6 +608,7 @@ async fn run_websocket_response_stream(
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let model_verifications = event.model_verifications();
|
||||
if event.kind() == "codex.rate_limits" {
|
||||
if let Some(snapshot) = parse_rate_limit_event(&text) {
|
||||
let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await;
|
||||
@@ -622,6 +623,16 @@ async fn run_websocket_response_stream(
|
||||
.await;
|
||||
last_server_model = Some(model);
|
||||
}
|
||||
if let Some(verifications) = model_verifications
|
||||
&& tx_event
|
||||
.send(Ok(ResponseEvent::ModelVerifications(verifications)))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return Err(ApiError::Stream(
|
||||
"response event consumer dropped".to_string(),
|
||||
));
|
||||
}
|
||||
match process_responses_event(event) {
|
||||
Ok(Some(event)) => {
|
||||
let is_completed = matches!(event, ResponseEvent::Completed { .. });
|
||||
|
||||
@@ -27,6 +27,8 @@ pub enum ApiError {
|
||||
RateLimit(String),
|
||||
#[error("invalid request: {message}")]
|
||||
InvalidRequest { message: String },
|
||||
#[error("cyber policy: {message}")]
|
||||
CyberPolicy { message: String },
|
||||
#[error("server overloaded")]
|
||||
ServerOverloaded,
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ use codex_client::ByteStream;
|
||||
use codex_client::StreamResponse;
|
||||
use codex_client::TransportError;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::ModelVerification;
|
||||
use codex_protocol::protocol::TokenUsage;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures::StreamExt;
|
||||
@@ -27,6 +28,7 @@ use tracing::trace;
|
||||
|
||||
const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included";
|
||||
const OPENAI_MODEL_HEADER: &str = "openai-model";
|
||||
const TRUSTED_ACCESS_FOR_CYBER_VERIFICATION: &str = "trusted_access_for_cyber";
|
||||
|
||||
/// Streams SSE events from an on-disk fixture for tests.
|
||||
pub fn stream_from_fixture(
|
||||
@@ -165,6 +167,7 @@ pub struct ResponsesStreamEvent {
|
||||
#[serde(rename = "type")]
|
||||
pub(crate) kind: String,
|
||||
headers: Option<Value>,
|
||||
metadata: Option<Value>,
|
||||
response: Option<Value>,
|
||||
item: Option<Value>,
|
||||
item_id: Option<String>,
|
||||
@@ -183,8 +186,7 @@ impl ResponsesStreamEvent {
|
||||
///
|
||||
/// Precedence:
|
||||
/// 1. `response.headers` for standard Responses stream events.
|
||||
/// 2. top-level `headers` for websocket metadata events (for example
|
||||
/// `codex.response.metadata`).
|
||||
/// 2. top-level `headers` for websocket metadata events.
|
||||
pub fn response_model(&self) -> Option<String> {
|
||||
let response_headers_model = self
|
||||
.response
|
||||
@@ -200,6 +202,17 @@ impl ResponsesStreamEvent {
|
||||
.and_then(header_openai_model_value_from_json),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn model_verifications(&self) -> Option<Vec<ModelVerification>> {
|
||||
if self.kind() != "response.metadata" {
|
||||
return None;
|
||||
}
|
||||
|
||||
self.metadata
|
||||
.as_ref()
|
||||
.and_then(|metadata| metadata.get("openai_verification_recommendation"))
|
||||
.and_then(model_verifications_from_json_value)
|
||||
}
|
||||
}
|
||||
|
||||
fn header_openai_model_value_from_json(value: &Value) -> Option<String> {
|
||||
@@ -214,6 +227,38 @@ fn header_openai_model_value_from_json(value: &Value) -> Option<String> {
|
||||
})
|
||||
}
|
||||
|
||||
fn model_verifications_from_json_value(value: &Value) -> Option<Vec<ModelVerification>> {
|
||||
let verifications = value
|
||||
.as_array()
|
||||
.map(|items| {
|
||||
let mut verifications = Vec::new();
|
||||
for verification in items
|
||||
.iter()
|
||||
.filter_map(Value::as_str)
|
||||
.filter_map(parse_model_verification)
|
||||
{
|
||||
if !verifications.contains(&verification) {
|
||||
verifications.push(verification);
|
||||
}
|
||||
}
|
||||
verifications
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
if verifications.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(verifications)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_model_verification(value: &str) -> Option<ModelVerification> {
|
||||
match value {
|
||||
TRUSTED_ACCESS_FOR_CYBER_VERIFICATION => Some(ModelVerification::TrustedAccessForCyber),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn json_value_as_string(value: &Value) -> Option<String> {
|
||||
match value {
|
||||
Value::String(value) => Some(value.clone()),
|
||||
@@ -296,6 +341,9 @@ pub fn process_responses_event(
|
||||
response_error = ApiError::QuotaExceeded;
|
||||
} else if is_usage_not_included(&error) {
|
||||
response_error = ApiError::UsageNotIncluded;
|
||||
} else if is_cyber_policy_error(&error) {
|
||||
let message = cyber_policy_message(error.message);
|
||||
response_error = ApiError::CyberPolicy { message };
|
||||
} else if is_invalid_prompt_error(&error) {
|
||||
let message = error
|
||||
.message
|
||||
@@ -414,6 +462,7 @@ pub async fn process_sse(
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let model_verifications = event.model_verifications();
|
||||
|
||||
if let Some(model) = event.response_model()
|
||||
&& last_server_model.as_deref() != Some(model.as_str())
|
||||
@@ -427,6 +476,14 @@ pub async fn process_sse(
|
||||
}
|
||||
last_server_model = Some(model);
|
||||
}
|
||||
if let Some(verifications) = model_verifications
|
||||
&& tx_event
|
||||
.send(Ok(ResponseEvent::ModelVerifications(verifications)))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
match process_responses_event(event) {
|
||||
Ok(Some(event)) => {
|
||||
@@ -488,11 +545,25 @@ fn is_invalid_prompt_error(error: &Error) -> bool {
|
||||
error.code.as_deref() == Some("invalid_prompt")
|
||||
}
|
||||
|
||||
fn is_cyber_policy_error(error: &Error) -> bool {
|
||||
error.code.as_deref() == Some("cyber_policy")
|
||||
}
|
||||
|
||||
fn is_server_overloaded_error(error: &Error) -> bool {
|
||||
error.code.as_deref() == Some("server_is_overloaded")
|
||||
|| error.code.as_deref() == Some("slow_down")
|
||||
}
|
||||
|
||||
fn cyber_policy_fallback_message() -> String {
|
||||
"This request has been flagged for possible cybersecurity risk.".to_string()
|
||||
}
|
||||
|
||||
fn cyber_policy_message(message: Option<String>) -> String {
|
||||
message
|
||||
.filter(|message| !message.trim().is_empty())
|
||||
.unwrap_or_else(cyber_policy_fallback_message)
|
||||
}
|
||||
|
||||
fn rate_limit_regex() -> &'static regex_lite::Regex {
|
||||
static RE: std::sync::OnceLock<regex_lite::Regex> = std::sync::OnceLock::new();
|
||||
#[expect(clippy::unwrap_used)]
|
||||
@@ -841,6 +912,45 @@ mod tests {
|
||||
assert_matches!(events[0], Err(ApiError::QuotaExceeded));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cyber_policy_error_is_fatal() {
|
||||
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_fatal_cyber","object":"response","created_at":1759771626,"status":"failed","background":false,"error":{"code":"cyber_policy","message":"This request was flagged for cyber policy."},"incomplete_details":null}}"#;
|
||||
|
||||
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 1);
|
||||
|
||||
match &events[0] {
|
||||
Err(ApiError::CyberPolicy { message }) => {
|
||||
assert_eq!(message, "This request was flagged for cyber policy.");
|
||||
}
|
||||
other => panic!("unexpected event: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cyber_policy_error_uses_fallback_for_empty_message() {
|
||||
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_fatal_cyber","object":"response","created_at":1759771626,"status":"failed","background":false,"error":{"code":"cyber_policy","message":" "},"incomplete_details":null}}"#;
|
||||
|
||||
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 1);
|
||||
|
||||
match &events[0] {
|
||||
Err(ApiError::CyberPolicy { message }) => {
|
||||
assert_eq!(
|
||||
message,
|
||||
"This request has been flagged for possible cybersecurity risk."
|
||||
);
|
||||
}
|
||||
other => panic!("unexpected event: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn invalid_prompt_without_type_is_invalid_request() {
|
||||
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_invalid_prompt_no_type","object":"response","created_at":1759771628,"status":"failed","background":false,"error":{"code":"invalid_prompt","message":"Invalid prompt: we've limited access to this content for safety reasons."},"incomplete_details":null}}"#;
|
||||
@@ -975,6 +1085,43 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn spawn_response_stream_ignores_model_verification_header() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
"openai-verification-recommendation",
|
||||
HeaderValue::from_static(TRUSTED_ACCESS_FOR_CYBER_VERIFICATION),
|
||||
);
|
||||
let completed = json!({
|
||||
"type": "response.completed",
|
||||
"response": { "id": "resp-1" }
|
||||
});
|
||||
let sse = format!("event: response.completed\ndata: {completed}\n\n");
|
||||
let bytes = stream::iter(vec![Ok(Bytes::from(sse))]);
|
||||
let stream_response = StreamResponse {
|
||||
status: StatusCode::OK,
|
||||
headers,
|
||||
bytes: Box::pin(bytes),
|
||||
};
|
||||
|
||||
let mut stream = spawn_response_stream(
|
||||
stream_response,
|
||||
idle_timeout(),
|
||||
/*telemetry*/ None,
|
||||
/*turn_state*/ None,
|
||||
);
|
||||
let mut events = Vec::new();
|
||||
while let Some(event) = stream.rx_event.recv().await {
|
||||
events.push(event.expect("expected ok event"));
|
||||
}
|
||||
|
||||
assert!(
|
||||
!events
|
||||
.iter()
|
||||
.any(|event| matches!(event, ResponseEvent::ModelVerifications(_)))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_sse_ignores_response_model_field_in_payload() {
|
||||
let events = run_sse(vec![
|
||||
@@ -1042,10 +1189,44 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_sse_emits_model_verification_field() {
|
||||
let events = run_sse(vec![
|
||||
json!({
|
||||
"type": "response.metadata",
|
||||
"sequence_number": 1,
|
||||
"response_id": "resp-1",
|
||||
"metadata": {
|
||||
"openai_verification_recommendation": [TRUSTED_ACCESS_FOR_CYBER_VERIFICATION]
|
||||
}
|
||||
}),
|
||||
json!({
|
||||
"type": "response.completed",
|
||||
"response": {
|
||||
"id": "resp-1"
|
||||
}
|
||||
}),
|
||||
])
|
||||
.await;
|
||||
|
||||
assert_matches!(
|
||||
&events[0],
|
||||
ResponseEvent::ModelVerifications(verifications)
|
||||
if verifications == &vec![ModelVerification::TrustedAccessForCyber]
|
||||
);
|
||||
assert_matches!(
|
||||
&events[1],
|
||||
ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage: None
|
||||
} if response_id == "resp-1"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn responses_stream_event_response_model_reads_top_level_headers() {
|
||||
let ev: ResponsesStreamEvent = serde_json::from_value(json!({
|
||||
"type": "codex.response.metadata",
|
||||
"type": "response.metadata",
|
||||
"headers": {
|
||||
"openai-model": CYBER_RESTRICTED_MODEL_FOR_TESTS,
|
||||
}
|
||||
@@ -1080,6 +1261,53 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn responses_stream_event_model_verification_reads_metadata_field() {
|
||||
let event = json!({
|
||||
"type": "response.metadata",
|
||||
"sequence_number": 1,
|
||||
"response_id": "resp-1",
|
||||
"metadata": {
|
||||
"openai_verification_recommendation": [TRUSTED_ACCESS_FOR_CYBER_VERIFICATION]
|
||||
}
|
||||
});
|
||||
let event: ResponsesStreamEvent =
|
||||
serde_json::from_value(event).expect("expected event to deserialize");
|
||||
|
||||
assert_eq!(
|
||||
event.model_verifications(),
|
||||
Some(vec![ModelVerification::TrustedAccessForCyber])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn responses_stream_event_model_verification_ignores_unknown_field() {
|
||||
let event = json!({
|
||||
"type": "response.metadata",
|
||||
"metadata": {
|
||||
"openai_verification_recommendation": ["unknown"]
|
||||
}
|
||||
});
|
||||
let event: ResponsesStreamEvent =
|
||||
serde_json::from_value(event).expect("expected event to deserialize");
|
||||
|
||||
assert_eq!(event.model_verifications(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn responses_stream_event_model_verification_ignores_non_array_field() {
|
||||
let event = json!({
|
||||
"type": "response.metadata",
|
||||
"metadata": {
|
||||
"openai_verification_recommendation": TRUSTED_ACCESS_FOR_CYBER_VERIFICATION
|
||||
}
|
||||
});
|
||||
let event: ResponsesStreamEvent =
|
||||
serde_json::from_value(event).expect("expected event to deserialize");
|
||||
|
||||
assert_eq!(event.model_verifications(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_try_parse_retry_after() {
|
||||
let err = Error {
|
||||
|
||||
Reference in New Issue
Block a user