mirror of
https://github.com/openai/codex.git
synced 2026-05-04 13:21:54 +03:00
Add safety check notification and error handling (#19055)
Adds a new app-server notification that fires when a user account has been flagged for potential safety reasons.
This commit is contained in:
@@ -7,6 +7,7 @@ use codex_client::ByteStream;
|
||||
use codex_client::StreamResponse;
|
||||
use codex_client::TransportError;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::ModelVerification;
|
||||
use codex_protocol::protocol::TokenUsage;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures::StreamExt;
|
||||
@@ -27,6 +28,7 @@ use tracing::trace;
|
||||
|
||||
const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included";
|
||||
const OPENAI_MODEL_HEADER: &str = "openai-model";
|
||||
const TRUSTED_ACCESS_FOR_CYBER_VERIFICATION: &str = "trusted_access_for_cyber";
|
||||
|
||||
/// Streams SSE events from an on-disk fixture for tests.
|
||||
pub fn stream_from_fixture(
|
||||
@@ -165,6 +167,7 @@ pub struct ResponsesStreamEvent {
|
||||
#[serde(rename = "type")]
|
||||
pub(crate) kind: String,
|
||||
headers: Option<Value>,
|
||||
metadata: Option<Value>,
|
||||
response: Option<Value>,
|
||||
item: Option<Value>,
|
||||
item_id: Option<String>,
|
||||
@@ -183,8 +186,7 @@ impl ResponsesStreamEvent {
|
||||
///
|
||||
/// Precedence:
|
||||
/// 1. `response.headers` for standard Responses stream events.
|
||||
/// 2. top-level `headers` for websocket metadata events (for example
|
||||
/// `codex.response.metadata`).
|
||||
/// 2. top-level `headers` for websocket metadata events.
|
||||
pub fn response_model(&self) -> Option<String> {
|
||||
let response_headers_model = self
|
||||
.response
|
||||
@@ -200,6 +202,17 @@ impl ResponsesStreamEvent {
|
||||
.and_then(header_openai_model_value_from_json),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn model_verifications(&self) -> Option<Vec<ModelVerification>> {
|
||||
if self.kind() != "response.metadata" {
|
||||
return None;
|
||||
}
|
||||
|
||||
self.metadata
|
||||
.as_ref()
|
||||
.and_then(|metadata| metadata.get("openai_verification_recommendation"))
|
||||
.and_then(model_verifications_from_json_value)
|
||||
}
|
||||
}
|
||||
|
||||
fn header_openai_model_value_from_json(value: &Value) -> Option<String> {
|
||||
@@ -214,6 +227,38 @@ fn header_openai_model_value_from_json(value: &Value) -> Option<String> {
|
||||
})
|
||||
}
|
||||
|
||||
fn model_verifications_from_json_value(value: &Value) -> Option<Vec<ModelVerification>> {
|
||||
let verifications = value
|
||||
.as_array()
|
||||
.map(|items| {
|
||||
let mut verifications = Vec::new();
|
||||
for verification in items
|
||||
.iter()
|
||||
.filter_map(Value::as_str)
|
||||
.filter_map(parse_model_verification)
|
||||
{
|
||||
if !verifications.contains(&verification) {
|
||||
verifications.push(verification);
|
||||
}
|
||||
}
|
||||
verifications
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
if verifications.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(verifications)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_model_verification(value: &str) -> Option<ModelVerification> {
|
||||
match value {
|
||||
TRUSTED_ACCESS_FOR_CYBER_VERIFICATION => Some(ModelVerification::TrustedAccessForCyber),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn json_value_as_string(value: &Value) -> Option<String> {
|
||||
match value {
|
||||
Value::String(value) => Some(value.clone()),
|
||||
@@ -296,6 +341,9 @@ pub fn process_responses_event(
|
||||
response_error = ApiError::QuotaExceeded;
|
||||
} else if is_usage_not_included(&error) {
|
||||
response_error = ApiError::UsageNotIncluded;
|
||||
} else if is_cyber_policy_error(&error) {
|
||||
let message = cyber_policy_message(error.message);
|
||||
response_error = ApiError::CyberPolicy { message };
|
||||
} else if is_invalid_prompt_error(&error) {
|
||||
let message = error
|
||||
.message
|
||||
@@ -414,6 +462,7 @@ pub async fn process_sse(
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let model_verifications = event.model_verifications();
|
||||
|
||||
if let Some(model) = event.response_model()
|
||||
&& last_server_model.as_deref() != Some(model.as_str())
|
||||
@@ -427,6 +476,14 @@ pub async fn process_sse(
|
||||
}
|
||||
last_server_model = Some(model);
|
||||
}
|
||||
if let Some(verifications) = model_verifications
|
||||
&& tx_event
|
||||
.send(Ok(ResponseEvent::ModelVerifications(verifications)))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
match process_responses_event(event) {
|
||||
Ok(Some(event)) => {
|
||||
@@ -488,11 +545,25 @@ fn is_invalid_prompt_error(error: &Error) -> bool {
|
||||
error.code.as_deref() == Some("invalid_prompt")
|
||||
}
|
||||
|
||||
fn is_cyber_policy_error(error: &Error) -> bool {
|
||||
error.code.as_deref() == Some("cyber_policy")
|
||||
}
|
||||
|
||||
fn is_server_overloaded_error(error: &Error) -> bool {
|
||||
error.code.as_deref() == Some("server_is_overloaded")
|
||||
|| error.code.as_deref() == Some("slow_down")
|
||||
}
|
||||
|
||||
fn cyber_policy_fallback_message() -> String {
|
||||
"This request has been flagged for possible cybersecurity risk.".to_string()
|
||||
}
|
||||
|
||||
fn cyber_policy_message(message: Option<String>) -> String {
|
||||
message
|
||||
.filter(|message| !message.trim().is_empty())
|
||||
.unwrap_or_else(cyber_policy_fallback_message)
|
||||
}
|
||||
|
||||
fn rate_limit_regex() -> &'static regex_lite::Regex {
|
||||
static RE: std::sync::OnceLock<regex_lite::Regex> = std::sync::OnceLock::new();
|
||||
#[expect(clippy::unwrap_used)]
|
||||
@@ -841,6 +912,45 @@ mod tests {
|
||||
assert_matches!(events[0], Err(ApiError::QuotaExceeded));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cyber_policy_error_is_fatal() {
|
||||
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_fatal_cyber","object":"response","created_at":1759771626,"status":"failed","background":false,"error":{"code":"cyber_policy","message":"This request was flagged for cyber policy."},"incomplete_details":null}}"#;
|
||||
|
||||
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 1);
|
||||
|
||||
match &events[0] {
|
||||
Err(ApiError::CyberPolicy { message }) => {
|
||||
assert_eq!(message, "This request was flagged for cyber policy.");
|
||||
}
|
||||
other => panic!("unexpected event: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cyber_policy_error_uses_fallback_for_empty_message() {
|
||||
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_fatal_cyber","object":"response","created_at":1759771626,"status":"failed","background":false,"error":{"code":"cyber_policy","message":" "},"incomplete_details":null}}"#;
|
||||
|
||||
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
|
||||
|
||||
let events = collect_events(&[sse1.as_bytes()]).await;
|
||||
|
||||
assert_eq!(events.len(), 1);
|
||||
|
||||
match &events[0] {
|
||||
Err(ApiError::CyberPolicy { message }) => {
|
||||
assert_eq!(
|
||||
message,
|
||||
"This request has been flagged for possible cybersecurity risk."
|
||||
);
|
||||
}
|
||||
other => panic!("unexpected event: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn invalid_prompt_without_type_is_invalid_request() {
|
||||
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_invalid_prompt_no_type","object":"response","created_at":1759771628,"status":"failed","background":false,"error":{"code":"invalid_prompt","message":"Invalid prompt: we've limited access to this content for safety reasons."},"incomplete_details":null}}"#;
|
||||
@@ -975,6 +1085,43 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn spawn_response_stream_ignores_model_verification_header() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
"openai-verification-recommendation",
|
||||
HeaderValue::from_static(TRUSTED_ACCESS_FOR_CYBER_VERIFICATION),
|
||||
);
|
||||
let completed = json!({
|
||||
"type": "response.completed",
|
||||
"response": { "id": "resp-1" }
|
||||
});
|
||||
let sse = format!("event: response.completed\ndata: {completed}\n\n");
|
||||
let bytes = stream::iter(vec![Ok(Bytes::from(sse))]);
|
||||
let stream_response = StreamResponse {
|
||||
status: StatusCode::OK,
|
||||
headers,
|
||||
bytes: Box::pin(bytes),
|
||||
};
|
||||
|
||||
let mut stream = spawn_response_stream(
|
||||
stream_response,
|
||||
idle_timeout(),
|
||||
/*telemetry*/ None,
|
||||
/*turn_state*/ None,
|
||||
);
|
||||
let mut events = Vec::new();
|
||||
while let Some(event) = stream.rx_event.recv().await {
|
||||
events.push(event.expect("expected ok event"));
|
||||
}
|
||||
|
||||
assert!(
|
||||
!events
|
||||
.iter()
|
||||
.any(|event| matches!(event, ResponseEvent::ModelVerifications(_)))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_sse_ignores_response_model_field_in_payload() {
|
||||
let events = run_sse(vec![
|
||||
@@ -1042,10 +1189,44 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_sse_emits_model_verification_field() {
|
||||
let events = run_sse(vec![
|
||||
json!({
|
||||
"type": "response.metadata",
|
||||
"sequence_number": 1,
|
||||
"response_id": "resp-1",
|
||||
"metadata": {
|
||||
"openai_verification_recommendation": [TRUSTED_ACCESS_FOR_CYBER_VERIFICATION]
|
||||
}
|
||||
}),
|
||||
json!({
|
||||
"type": "response.completed",
|
||||
"response": {
|
||||
"id": "resp-1"
|
||||
}
|
||||
}),
|
||||
])
|
||||
.await;
|
||||
|
||||
assert_matches!(
|
||||
&events[0],
|
||||
ResponseEvent::ModelVerifications(verifications)
|
||||
if verifications == &vec![ModelVerification::TrustedAccessForCyber]
|
||||
);
|
||||
assert_matches!(
|
||||
&events[1],
|
||||
ResponseEvent::Completed {
|
||||
response_id,
|
||||
token_usage: None
|
||||
} if response_id == "resp-1"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn responses_stream_event_response_model_reads_top_level_headers() {
|
||||
let ev: ResponsesStreamEvent = serde_json::from_value(json!({
|
||||
"type": "codex.response.metadata",
|
||||
"type": "response.metadata",
|
||||
"headers": {
|
||||
"openai-model": CYBER_RESTRICTED_MODEL_FOR_TESTS,
|
||||
}
|
||||
@@ -1080,6 +1261,53 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn responses_stream_event_model_verification_reads_metadata_field() {
|
||||
let event = json!({
|
||||
"type": "response.metadata",
|
||||
"sequence_number": 1,
|
||||
"response_id": "resp-1",
|
||||
"metadata": {
|
||||
"openai_verification_recommendation": [TRUSTED_ACCESS_FOR_CYBER_VERIFICATION]
|
||||
}
|
||||
});
|
||||
let event: ResponsesStreamEvent =
|
||||
serde_json::from_value(event).expect("expected event to deserialize");
|
||||
|
||||
assert_eq!(
|
||||
event.model_verifications(),
|
||||
Some(vec![ModelVerification::TrustedAccessForCyber])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn responses_stream_event_model_verification_ignores_unknown_field() {
|
||||
let event = json!({
|
||||
"type": "response.metadata",
|
||||
"metadata": {
|
||||
"openai_verification_recommendation": ["unknown"]
|
||||
}
|
||||
});
|
||||
let event: ResponsesStreamEvent =
|
||||
serde_json::from_value(event).expect("expected event to deserialize");
|
||||
|
||||
assert_eq!(event.model_verifications(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn responses_stream_event_model_verification_ignores_non_array_field() {
|
||||
let event = json!({
|
||||
"type": "response.metadata",
|
||||
"metadata": {
|
||||
"openai_verification_recommendation": TRUSTED_ACCESS_FOR_CYBER_VERIFICATION
|
||||
}
|
||||
});
|
||||
let event: ResponsesStreamEvent =
|
||||
serde_json::from_value(event).expect("expected event to deserialize");
|
||||
|
||||
assert_eq!(event.model_verifications(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_try_parse_retry_after() {
|
||||
let err = Error {
|
||||
|
||||
Reference in New Issue
Block a user