Add safety check notification and error handling (#19055)

Adds a new app-server notification that fires when a user account has
been flagged for potential safety reasons.
This commit is contained in:
Eric Traut
2026-04-22 22:24:12 -07:00
committed by GitHub
parent 02170996e6
commit bbff4ee61a
61 changed files with 1414 additions and 15 deletions

View File

@@ -32,6 +32,7 @@ pub fn map_api_error(err: ApiError) -> CodexErr {
identity_error_code: None,
}),
ApiError::InvalidRequest { message } => CodexErr::InvalidRequest(message),
ApiError::CyberPolicy { message } => CodexErr::CyberPolicy { message },
ApiError::Transport(transport) => match transport {
TransportError::Http {
status,
@@ -55,7 +56,19 @@ pub fn map_api_error(err: ApiError) -> CodexErr {
}
if status == http::StatusCode::BAD_REQUEST {
if body_text
if let Ok(parsed) = serde_json::from_str::<Value>(&body_text)
&& let Some(error) = parsed.get("error")
&& error.get("code").and_then(Value::as_str)
== Some(CYBER_POLICY_ERROR_CODE)
{
let message = error
.get("message")
.and_then(Value::as_str)
.filter(|message| !message.trim().is_empty())
.map(str::to_string)
.unwrap_or_else(|| CYBER_POLICY_FALLBACK_MESSAGE.to_string());
CodexErr::CyberPolicy { message }
} else if body_text
.contains("The image data you provided does not represent a valid image")
{
CodexErr::InvalidImageRequest()
@@ -125,6 +138,9 @@ const OAI_REQUEST_ID_HEADER: &str = "x-oai-request-id";
const CF_RAY_HEADER: &str = "cf-ray";
const X_OPENAI_AUTHORIZATION_ERROR_HEADER: &str = "x-openai-authorization-error";
const X_ERROR_JSON_HEADER: &str = "x-error-json";
const CYBER_POLICY_ERROR_CODE: &str = "cyber_policy";
const CYBER_POLICY_FALLBACK_MESSAGE: &str =
"This request has been flagged for possible cybersecurity risk.";
#[cfg(test)]
#[path = "api_bridge_tests.rs"]

View File

@@ -26,6 +26,104 @@ fn map_api_error_maps_server_overloaded_from_503_body() {
assert!(matches!(err, CodexErr::ServerOverloaded));
}
#[test]
fn map_api_error_maps_cyber_policy_from_400_body() {
let body = serde_json::json!({
"error": {
"message": "This request has been flagged for potentially high-risk cyber activity.",
"type": "invalid_request",
"param": null,
"code": "cyber_policy"
}
})
.to_string();
let err = map_api_error(ApiError::Transport(TransportError::Http {
status: http::StatusCode::BAD_REQUEST,
url: Some("http://example.com/v1/responses".to_string()),
headers: None,
body: Some(body),
}));
let CodexErr::CyberPolicy { message } = err else {
panic!("expected CodexErr::CyberPolicy, got {err:?}");
};
assert_eq!(
message,
"This request has been flagged for potentially high-risk cyber activity."
);
}
#[test]
fn map_api_error_maps_wrapped_websocket_cyber_policy_from_400_body() {
let body = serde_json::json!({
"type": "error",
"status": 400,
"error": {
"message": "This websocket request was flagged.",
"type": "invalid_request",
"code": "cyber_policy"
}
})
.to_string();
let err = map_api_error(ApiError::Transport(TransportError::Http {
status: http::StatusCode::BAD_REQUEST,
url: Some("ws://example.com/v1/responses".to_string()),
headers: None,
body: Some(body),
}));
let CodexErr::CyberPolicy { message } = err else {
panic!("expected CodexErr::CyberPolicy, got {err:?}");
};
assert_eq!(message, "This websocket request was flagged.");
}
#[test]
fn map_api_error_uses_cyber_policy_fallback_for_missing_message() {
let body = serde_json::json!({
"error": {
"code": "cyber_policy"
}
})
.to_string();
let err = map_api_error(ApiError::Transport(TransportError::Http {
status: http::StatusCode::BAD_REQUEST,
url: Some("http://example.com/v1/responses".to_string()),
headers: None,
body: Some(body),
}));
let CodexErr::CyberPolicy { message } = err else {
panic!("expected CodexErr::CyberPolicy, got {err:?}");
};
assert_eq!(
message,
"This request has been flagged for possible cybersecurity risk."
);
}
#[test]
fn map_api_error_keeps_unknown_400_errors_generic() {
let body = serde_json::json!({
"error": {
"message": "Some other bad request.",
"code": "some_other_policy"
}
})
.to_string();
let err = map_api_error(ApiError::Transport(TransportError::Http {
status: http::StatusCode::BAD_REQUEST,
url: Some("http://example.com/v1/responses".to_string()),
headers: None,
body: Some(body.clone()),
}));
let CodexErr::InvalidRequest(message) = err else {
panic!("expected CodexErr::InvalidRequest, got {err:?}");
};
assert_eq!(message, body);
}
#[test]
fn map_api_error_maps_usage_limit_limit_name_header() {
let mut headers = HeaderMap::new();

View File

@@ -3,6 +3,7 @@ use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
use codex_protocol::config_types::Verbosity as VerbosityConfig;
use codex_protocol::models::ResponseItem;
use codex_protocol::openai_models::ReasoningEffort as ReasoningEffortConfig;
use codex_protocol::protocol::ModelVerification;
use codex_protocol::protocol::RateLimitSnapshot;
use codex_protocol::protocol::TokenUsage;
use codex_protocol::protocol::W3cTraceContext;
@@ -71,6 +72,8 @@ pub enum ResponseEvent {
/// Emitted when the server includes `OpenAI-Model` on the stream response.
/// This can differ from the requested model when backend safety routing applies.
ServerModel(String),
/// Emitted when the server recommends additional account verification.
ModelVerifications(Vec<ModelVerification>),
/// Emitted when `X-Reasoning-Included: true` is present on the response,
/// meaning the server already accounted for past reasoning tokens and the
/// client should not re-estimate them.

View File

@@ -608,6 +608,7 @@ async fn run_websocket_response_stream(
continue;
}
};
let model_verifications = event.model_verifications();
if event.kind() == "codex.rate_limits" {
if let Some(snapshot) = parse_rate_limit_event(&text) {
let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await;
@@ -622,6 +623,16 @@ async fn run_websocket_response_stream(
.await;
last_server_model = Some(model);
}
if let Some(verifications) = model_verifications
&& tx_event
.send(Ok(ResponseEvent::ModelVerifications(verifications)))
.await
.is_err()
{
return Err(ApiError::Stream(
"response event consumer dropped".to_string(),
));
}
match process_responses_event(event) {
Ok(Some(event)) => {
let is_completed = matches!(event, ResponseEvent::Completed { .. });

View File

@@ -27,6 +27,8 @@ pub enum ApiError {
RateLimit(String),
#[error("invalid request: {message}")]
InvalidRequest { message: String },
#[error("cyber policy: {message}")]
CyberPolicy { message: String },
#[error("server overloaded")]
ServerOverloaded,
}

View File

@@ -7,6 +7,7 @@ use codex_client::ByteStream;
use codex_client::StreamResponse;
use codex_client::TransportError;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::ModelVerification;
use codex_protocol::protocol::TokenUsage;
use eventsource_stream::Eventsource;
use futures::StreamExt;
@@ -27,6 +28,7 @@ use tracing::trace;
const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included";
const OPENAI_MODEL_HEADER: &str = "openai-model";
const TRUSTED_ACCESS_FOR_CYBER_VERIFICATION: &str = "trusted_access_for_cyber";
/// Streams SSE events from an on-disk fixture for tests.
pub fn stream_from_fixture(
@@ -165,6 +167,7 @@ pub struct ResponsesStreamEvent {
#[serde(rename = "type")]
pub(crate) kind: String,
headers: Option<Value>,
metadata: Option<Value>,
response: Option<Value>,
item: Option<Value>,
item_id: Option<String>,
@@ -183,8 +186,7 @@ impl ResponsesStreamEvent {
///
/// Precedence:
/// 1. `response.headers` for standard Responses stream events.
/// 2. top-level `headers` for websocket metadata events (for example
/// `codex.response.metadata`).
/// 2. top-level `headers` for websocket metadata events.
pub fn response_model(&self) -> Option<String> {
let response_headers_model = self
.response
@@ -200,6 +202,17 @@ impl ResponsesStreamEvent {
.and_then(header_openai_model_value_from_json),
}
}
pub(crate) fn model_verifications(&self) -> Option<Vec<ModelVerification>> {
if self.kind() != "response.metadata" {
return None;
}
self.metadata
.as_ref()
.and_then(|metadata| metadata.get("openai_verification_recommendation"))
.and_then(model_verifications_from_json_value)
}
}
fn header_openai_model_value_from_json(value: &Value) -> Option<String> {
@@ -214,6 +227,38 @@ fn header_openai_model_value_from_json(value: &Value) -> Option<String> {
})
}
fn model_verifications_from_json_value(value: &Value) -> Option<Vec<ModelVerification>> {
let verifications = value
.as_array()
.map(|items| {
let mut verifications = Vec::new();
for verification in items
.iter()
.filter_map(Value::as_str)
.filter_map(parse_model_verification)
{
if !verifications.contains(&verification) {
verifications.push(verification);
}
}
verifications
})
.unwrap_or_default();
if verifications.is_empty() {
None
} else {
Some(verifications)
}
}
fn parse_model_verification(value: &str) -> Option<ModelVerification> {
match value {
TRUSTED_ACCESS_FOR_CYBER_VERIFICATION => Some(ModelVerification::TrustedAccessForCyber),
_ => None,
}
}
fn json_value_as_string(value: &Value) -> Option<String> {
match value {
Value::String(value) => Some(value.clone()),
@@ -296,6 +341,9 @@ pub fn process_responses_event(
response_error = ApiError::QuotaExceeded;
} else if is_usage_not_included(&error) {
response_error = ApiError::UsageNotIncluded;
} else if is_cyber_policy_error(&error) {
let message = cyber_policy_message(error.message);
response_error = ApiError::CyberPolicy { message };
} else if is_invalid_prompt_error(&error) {
let message = error
.message
@@ -414,6 +462,7 @@ pub async fn process_sse(
continue;
}
};
let model_verifications = event.model_verifications();
if let Some(model) = event.response_model()
&& last_server_model.as_deref() != Some(model.as_str())
@@ -427,6 +476,14 @@ pub async fn process_sse(
}
last_server_model = Some(model);
}
if let Some(verifications) = model_verifications
&& tx_event
.send(Ok(ResponseEvent::ModelVerifications(verifications)))
.await
.is_err()
{
return;
}
match process_responses_event(event) {
Ok(Some(event)) => {
@@ -488,11 +545,25 @@ fn is_invalid_prompt_error(error: &Error) -> bool {
error.code.as_deref() == Some("invalid_prompt")
}
fn is_cyber_policy_error(error: &Error) -> bool {
error.code.as_deref() == Some("cyber_policy")
}
fn is_server_overloaded_error(error: &Error) -> bool {
error.code.as_deref() == Some("server_is_overloaded")
|| error.code.as_deref() == Some("slow_down")
}
fn cyber_policy_fallback_message() -> String {
"This request has been flagged for possible cybersecurity risk.".to_string()
}
fn cyber_policy_message(message: Option<String>) -> String {
message
.filter(|message| !message.trim().is_empty())
.unwrap_or_else(cyber_policy_fallback_message)
}
fn rate_limit_regex() -> &'static regex_lite::Regex {
static RE: std::sync::OnceLock<regex_lite::Regex> = std::sync::OnceLock::new();
#[expect(clippy::unwrap_used)]
@@ -841,6 +912,45 @@ mod tests {
assert_matches!(events[0], Err(ApiError::QuotaExceeded));
}
#[tokio::test]
async fn cyber_policy_error_is_fatal() {
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_fatal_cyber","object":"response","created_at":1759771626,"status":"failed","background":false,"error":{"code":"cyber_policy","message":"This request was flagged for cyber policy."},"incomplete_details":null}}"#;
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
let events = collect_events(&[sse1.as_bytes()]).await;
assert_eq!(events.len(), 1);
match &events[0] {
Err(ApiError::CyberPolicy { message }) => {
assert_eq!(message, "This request was flagged for cyber policy.");
}
other => panic!("unexpected event: {other:?}"),
}
}
#[tokio::test]
async fn cyber_policy_error_uses_fallback_for_empty_message() {
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_fatal_cyber","object":"response","created_at":1759771626,"status":"failed","background":false,"error":{"code":"cyber_policy","message":" "},"incomplete_details":null}}"#;
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
let events = collect_events(&[sse1.as_bytes()]).await;
assert_eq!(events.len(), 1);
match &events[0] {
Err(ApiError::CyberPolicy { message }) => {
assert_eq!(
message,
"This request has been flagged for possible cybersecurity risk."
);
}
other => panic!("unexpected event: {other:?}"),
}
}
#[tokio::test]
async fn invalid_prompt_without_type_is_invalid_request() {
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_invalid_prompt_no_type","object":"response","created_at":1759771628,"status":"failed","background":false,"error":{"code":"invalid_prompt","message":"Invalid prompt: we've limited access to this content for safety reasons."},"incomplete_details":null}}"#;
@@ -975,6 +1085,43 @@ mod tests {
}
}
#[tokio::test]
async fn spawn_response_stream_ignores_model_verification_header() {
let mut headers = HeaderMap::new();
headers.insert(
"openai-verification-recommendation",
HeaderValue::from_static(TRUSTED_ACCESS_FOR_CYBER_VERIFICATION),
);
let completed = json!({
"type": "response.completed",
"response": { "id": "resp-1" }
});
let sse = format!("event: response.completed\ndata: {completed}\n\n");
let bytes = stream::iter(vec![Ok(Bytes::from(sse))]);
let stream_response = StreamResponse {
status: StatusCode::OK,
headers,
bytes: Box::pin(bytes),
};
let mut stream = spawn_response_stream(
stream_response,
idle_timeout(),
/*telemetry*/ None,
/*turn_state*/ None,
);
let mut events = Vec::new();
while let Some(event) = stream.rx_event.recv().await {
events.push(event.expect("expected ok event"));
}
assert!(
!events
.iter()
.any(|event| matches!(event, ResponseEvent::ModelVerifications(_)))
);
}
#[tokio::test]
async fn process_sse_ignores_response_model_field_in_payload() {
let events = run_sse(vec![
@@ -1042,10 +1189,44 @@ mod tests {
);
}
#[tokio::test]
async fn process_sse_emits_model_verification_field() {
let events = run_sse(vec![
json!({
"type": "response.metadata",
"sequence_number": 1,
"response_id": "resp-1",
"metadata": {
"openai_verification_recommendation": [TRUSTED_ACCESS_FOR_CYBER_VERIFICATION]
}
}),
json!({
"type": "response.completed",
"response": {
"id": "resp-1"
}
}),
])
.await;
assert_matches!(
&events[0],
ResponseEvent::ModelVerifications(verifications)
if verifications == &vec![ModelVerification::TrustedAccessForCyber]
);
assert_matches!(
&events[1],
ResponseEvent::Completed {
response_id,
token_usage: None
} if response_id == "resp-1"
);
}
#[test]
fn responses_stream_event_response_model_reads_top_level_headers() {
let ev: ResponsesStreamEvent = serde_json::from_value(json!({
"type": "codex.response.metadata",
"type": "response.metadata",
"headers": {
"openai-model": CYBER_RESTRICTED_MODEL_FOR_TESTS,
}
@@ -1080,6 +1261,53 @@ mod tests {
);
}
#[test]
fn responses_stream_event_model_verification_reads_metadata_field() {
let event = json!({
"type": "response.metadata",
"sequence_number": 1,
"response_id": "resp-1",
"metadata": {
"openai_verification_recommendation": [TRUSTED_ACCESS_FOR_CYBER_VERIFICATION]
}
});
let event: ResponsesStreamEvent =
serde_json::from_value(event).expect("expected event to deserialize");
assert_eq!(
event.model_verifications(),
Some(vec![ModelVerification::TrustedAccessForCyber])
);
}
#[test]
fn responses_stream_event_model_verification_ignores_unknown_field() {
let event = json!({
"type": "response.metadata",
"metadata": {
"openai_verification_recommendation": ["unknown"]
}
});
let event: ResponsesStreamEvent =
serde_json::from_value(event).expect("expected event to deserialize");
assert_eq!(event.model_verifications(), None);
}
#[test]
fn responses_stream_event_model_verification_ignores_non_array_field() {
let event = json!({
"type": "response.metadata",
"metadata": {
"openai_verification_recommendation": TRUSTED_ACCESS_FOR_CYBER_VERIFICATION
}
});
let event: ResponsesStreamEvent =
serde_json::from_value(event).expect("expected event to deserialize");
assert_eq!(event.model_verifications(), None);
}
#[test]
fn test_try_parse_retry_after() {
let err = Error {