Compare commits

...

3 Commits

Author SHA1 Message Date
celia-oai
41a653fd81 fix tests 2026-03-20 16:59:57 -07:00
celia-oai
703ed17fb6 changes 2026-03-20 16:29:57 -07:00
Matthew Zeng
dd88ed767b [apps] Use ARC for yolo mode. (#15273)
- [x] Use ARC for yolo mode.
2026-03-20 21:13:20 +00:00
7 changed files with 356 additions and 49 deletions

View File

@@ -883,6 +883,24 @@ mod tests {
account_id: Option<&str>,
access_token: &str,
refresh_token: &str,
) -> serde_json::Value {
chatgpt_auth_json_with_last_refresh(
plan_type,
chatgpt_user_id,
account_id,
access_token,
refresh_token,
"2025-01-01T00:00:00Z",
)
}
fn chatgpt_auth_json_with_last_refresh(
plan_type: &str,
chatgpt_user_id: Option<&str>,
account_id: Option<&str>,
access_token: &str,
refresh_token: &str,
last_refresh: &str,
) -> serde_json::Value {
chatgpt_auth_json_with_mode(
plan_type,
@@ -890,6 +908,7 @@ mod tests {
account_id,
access_token,
refresh_token,
last_refresh,
None,
)
}
@@ -900,6 +919,7 @@ mod tests {
account_id: Option<&str>,
access_token: &str,
refresh_token: &str,
last_refresh: &str,
auth_mode: Option<&str>,
) -> serde_json::Value {
let header = json!({ "alg": "none", "typ": "JWT" });
@@ -925,7 +945,7 @@ mod tests {
"refresh_token": refresh_token,
"account_id": account_id,
},
"last_refresh": "2025-01-01T00:00:00Z",
"last_refresh": last_refresh,
});
if let Some(auth_mode) = auth_mode {
auth_json["auth_mode"] = serde_json::Value::String(auth_mode.to_string());
@@ -1262,24 +1282,43 @@ enabled = false
#[tokio::test]
async fn fetch_cloud_requirements_recovers_after_unauthorized_reload() {
let auth = managed_auth_context(
"business",
Some("user-12345"),
Some("account-12345"),
"stale-access-token",
"test-refresh-token",
);
let auth_home = tempdir().expect("tempdir");
write_auth_json(
auth._home.path(),
chatgpt_auth_json(
auth_home.path(),
chatgpt_auth_json_with_last_refresh(
"business",
Some("user-12345"),
Some("account-12345"),
"stale-access-token",
"test-refresh-token",
// Keep auth "fresh" so the first request hits unauthorized recovery
// instead of AuthManager::auth() proactively reloading from disk.
"3025-01-01T00:00:00Z",
),
)
.expect("write initial auth");
let auth_manager = Arc::new(AuthManager::new(
auth_home.path().to_path_buf(),
false,
AuthCredentialsStoreMode::File,
));
write_auth_json(
auth_home.path(),
chatgpt_auth_json_with_last_refresh(
"business",
Some("user-12345"),
Some("account-12345"),
"fresh-access-token",
"test-refresh-token",
"3025-01-01T00:00:00Z",
),
)
.expect("write refreshed auth");
let auth = ManagedAuthContext {
_home: auth_home,
manager: auth_manager,
};
let fetcher = Arc::new(TokenFetcher {
expected_token: "fresh-access-token".to_string(),
@@ -1314,24 +1353,41 @@ enabled = false
#[tokio::test]
async fn fetch_cloud_requirements_recovers_after_unauthorized_reload_updates_cache_identity() {
let auth = managed_auth_context(
"business",
Some("user-12345"),
Some("account-12345"),
"stale-access-token",
"test-refresh-token",
);
let auth_home = tempdir().expect("tempdir");
write_auth_json(
auth._home.path(),
chatgpt_auth_json(
auth_home.path(),
chatgpt_auth_json_with_last_refresh(
"business",
Some("user-12345"),
Some("account-12345"),
"stale-access-token",
"test-refresh-token",
"3025-01-01T00:00:00Z",
),
)
.expect("write initial auth");
let auth_manager = Arc::new(AuthManager::new(
auth_home.path().to_path_buf(),
false,
AuthCredentialsStoreMode::File,
));
write_auth_json(
auth_home.path(),
chatgpt_auth_json_with_last_refresh(
"business",
Some("user-99999"),
Some("account-12345"),
"fresh-access-token",
"test-refresh-token",
"3025-01-01T00:00:00Z",
),
)
.expect("write refreshed auth");
let auth = ManagedAuthContext {
_home: auth_home,
manager: auth_manager,
};
let fetcher = Arc::new(TokenFetcher {
expected_token: "fresh-access-token".to_string(),
@@ -1432,6 +1488,7 @@ enabled = false
Some("account-12345"),
"test-access-token",
"test-refresh-token",
"2025-01-01T00:00:00Z",
Some("chatgptAuthTokens"),
),
)

View File

@@ -99,6 +99,7 @@ pub(crate) async fn monitor_action(
sess: &Session,
turn_context: &TurnContext,
action: serde_json::Value,
protection_client_callsite: &'static str,
) -> ArcMonitorOutcome {
let auth = match turn_context.auth_manager.as_ref() {
Some(auth_manager) => match auth_manager.auth().await {
@@ -138,7 +139,8 @@ pub(crate) async fn monitor_action(
return ArcMonitorOutcome::Ok;
}
};
let body = build_arc_monitor_request(sess, turn_context, action).await;
let body =
build_arc_monitor_request(sess, turn_context, action, protection_client_callsite).await;
let client = build_reqwest_client();
let mut request = client
.post(&url)
@@ -236,6 +238,7 @@ async fn build_arc_monitor_request(
sess: &Session,
turn_context: &TurnContext,
action: serde_json::Map<String, serde_json::Value>,
protection_client_callsite: &'static str,
) -> ArcMonitorRequest {
let history = sess.clone_history().await;
let mut messages = build_arc_monitor_messages(history.raw_items());
@@ -254,7 +257,7 @@ async fn build_arc_monitor_request(
codex_thread_id: conversation_id.clone(),
codex_turn_id: turn_context.sub_id.clone(),
conversation_id: Some(conversation_id),
protection_client_callsite: None,
protection_client_callsite: Some(protection_client_callsite.to_string()),
},
messages: Some(messages),
input: None,

View File

@@ -178,6 +178,7 @@ async fn build_arc_monitor_request_includes_relevant_history_and_null_policies()
&turn_context,
serde_json::from_value(serde_json::json!({ "tool": "mcp_tool_call" }))
.expect("action should deserialize"),
"normal",
)
.await;
@@ -188,7 +189,7 @@ async fn build_arc_monitor_request_includes_relevant_history_and_null_policies()
codex_thread_id: session.conversation_id.to_string(),
codex_turn_id: turn_context.sub_id.clone(),
conversation_id: Some(session.conversation_id.to_string()),
protection_client_callsite: None,
protection_client_callsite: Some("normal".to_string()),
},
messages: Some(vec![
ArcMonitorChatMessage {
@@ -285,6 +286,7 @@ async fn monitor_action_posts_expected_arc_request() {
"codex_thread_id": session.conversation_id.to_string(),
"codex_turn_id": turn_context.sub_id.clone(),
"conversation_id": session.conversation_id.to_string(),
"protection_client_callsite": "normal",
},
"messages": [{
"role": "user",
@@ -320,6 +322,7 @@ async fn monitor_action_posts_expected_arc_request() {
&session,
&turn_context,
serde_json::json!({ "tool": "mcp_tool_call" }),
"normal",
)
.await;
@@ -377,6 +380,7 @@ async fn monitor_action_uses_env_url_and_token_overrides() {
&session,
&turn_context,
serde_json::json!({ "tool": "mcp_tool_call" }),
"normal",
)
.await;
@@ -428,6 +432,7 @@ async fn monitor_action_rejects_legacy_response_fields() {
&session,
&turn_context,
serde_json::json!({ "tool": "mcp_tool_call" }),
"normal",
)
.await;

View File

@@ -457,6 +457,9 @@ const MCP_TOOL_APPROVAL_TOOL_TITLE_KEY: &str = "tool_title";
const MCP_TOOL_APPROVAL_TOOL_DESCRIPTION_KEY: &str = "tool_description";
const MCP_TOOL_APPROVAL_TOOL_PARAMS_KEY: &str = "tool_params";
const MCP_TOOL_APPROVAL_TOOL_PARAMS_DISPLAY_KEY: &str = "tool_params_display";
// ARC monitor callsite labels for MCP tool calls. `mcp_tool_approval_callsite_mode`
// picks one of these from the approval mode and the turn's full-access state.
// Returned for `AppToolApproval::Prompt`, and for `AppToolApproval::Auto` when the
// turn is not in full-access mode.
const MCP_TOOL_CALL_ARC_MONITOR_CALLSITE_DEFAULT: &str = "mcp_tool_call__default";
// Returned for `AppToolApproval::Approve`.
const MCP_TOOL_CALL_ARC_MONITOR_CALLSITE_ALWAYS_ALLOW: &str = "mcp_tool_call__always_allow";
// Returned for `AppToolApproval::Auto` when `is_full_access_mode` reports true.
const MCP_TOOL_CALL_ARC_MONITOR_CALLSITE_FULL_ACCESS: &str = "mcp_tool_call__full_access";
pub(crate) fn is_mcp_tool_approval_question_id(question_id: &str) -> bool {
question_id
@@ -494,14 +497,22 @@ async fn maybe_request_mcp_tool_approval(
let annotations = metadata.and_then(|metadata| metadata.annotations.as_ref());
let approval_required = annotations.is_some_and(requires_mcp_tool_approval);
let mut monitor_reason = None;
let auto_approved_by_policy = approval_mode == AppToolApproval::Approve
|| (approval_mode == AppToolApproval::Auto && is_full_access_mode(turn_context));
if approval_mode == AppToolApproval::Approve {
if auto_approved_by_policy {
if !approval_required {
return None;
}
match maybe_monitor_auto_approved_mcp_tool_call(sess, turn_context, invocation, metadata)
.await
match maybe_monitor_auto_approved_mcp_tool_call(
sess,
turn_context,
invocation,
metadata,
approval_mode,
)
.await
{
ArcMonitorOutcome::Ok => return None,
ArcMonitorOutcome::AskUser(reason) => {
@@ -515,13 +526,8 @@ async fn maybe_request_mcp_tool_approval(
}
}
if approval_mode == AppToolApproval::Auto {
if is_full_access_mode(turn_context) {
return None;
}
if !approval_required {
return None;
}
if approval_mode == AppToolApproval::Auto && !approval_required {
return None;
}
let session_approval_key = session_mcp_tool_approval_key(invocation, metadata, approval_mode);
@@ -653,9 +659,16 @@ async fn maybe_monitor_auto_approved_mcp_tool_call(
turn_context: &TurnContext,
invocation: &McpInvocation,
metadata: Option<&McpToolApprovalMetadata>,
approval_mode: AppToolApproval,
) -> ArcMonitorOutcome {
let action = prepare_arc_request_action(invocation, metadata);
monitor_action(sess, turn_context, action).await
monitor_action(
sess,
turn_context,
action,
mcp_tool_approval_callsite_mode(approval_mode, turn_context),
)
.await
}
fn prepare_arc_request_action(
@@ -749,6 +762,22 @@ fn is_full_access_mode(turn_context: &TurnContext) -> bool {
)
}
/// Selects the ARC monitor callsite label for an MCP tool call.
///
/// * `AppToolApproval::Approve` always maps to the "always allow" label.
/// * `AppToolApproval::Auto` maps to the "full access" label when
///   `is_full_access_mode(turn_context)` is true, otherwise to the default label.
/// * `AppToolApproval::Prompt` maps to the default label; the turn context is
///   not consulted for it.
fn mcp_tool_approval_callsite_mode(
    approval_mode: AppToolApproval,
    turn_context: &TurnContext,
) -> &'static str {
    match approval_mode {
        AppToolApproval::Approve => MCP_TOOL_CALL_ARC_MONITOR_CALLSITE_ALWAYS_ALLOW,
        // The guard only runs for `Auto`, so the full-access check is skipped
        // for every other mode.
        AppToolApproval::Auto if is_full_access_mode(turn_context) => {
            MCP_TOOL_CALL_ARC_MONITOR_CALLSITE_FULL_ACCESS
        }
        AppToolApproval::Auto | AppToolApproval::Prompt => {
            MCP_TOOL_CALL_ARC_MONITOR_CALLSITE_DEFAULT
        }
    }
}
pub(crate) async fn lookup_mcp_tool_metadata(
sess: &Session,
turn_context: &TurnContext,

View File

@@ -776,6 +776,38 @@ fn approval_elicitation_meta_merges_session_and_always_persist_with_connector_so
);
}
/// Checks the callsite label chosen for each approval mode: first against the
/// context returned by `make_session_and_context`, then for `Auto` after the
/// turn is switched to `AskForApproval::Never` + `SandboxPolicy::DangerFullAccess`.
#[tokio::test]
async fn approval_callsite_mode_distinguishes_default_always_allow_and_full_access() {
    let (_session, mut turn_context) = make_session_and_context().await;

    // Labels expected from the unmodified test context.
    let baseline_expectations = [
        (AppToolApproval::Auto, "mcp_tool_call__default"),
        (AppToolApproval::Prompt, "mcp_tool_call__default"),
        (AppToolApproval::Approve, "mcp_tool_call__always_allow"),
    ];
    for (approval_mode, expected_callsite) in baseline_expectations {
        let actual_callsite = mcp_tool_approval_callsite_mode(approval_mode, &turn_context);
        assert_eq!(actual_callsite, expected_callsite);
    }

    // Reconfigure the turn so that `Auto` is expected to report the
    // full-access label instead of the default one.
    turn_context
        .approval_policy
        .set(AskForApproval::Never)
        .expect("test setup should allow updating approval policy");
    turn_context
        .sandbox_policy
        .set(SandboxPolicy::DangerFullAccess)
        .expect("test setup should allow updating sandbox policy");

    let full_access_callsite =
        mcp_tool_approval_callsite_mode(AppToolApproval::Auto, &turn_context);
    assert_eq!(full_access_callsite, "mcp_tool_call__full_access");
}
#[test]
fn declined_elicitation_response_stays_decline() {
let response = parse_mcp_tool_approval_elicitation_response(
@@ -1035,6 +1067,83 @@ async fn approve_mode_blocks_when_arc_returns_interrupt_for_model() {
);
}
/// Verifies that an `AppToolApproval::Auto` tool call, made while the turn uses
/// `AskForApproval::Never` and `SandboxPolicy::DangerFullAccess`, is sent to the
/// ARC endpoint and reported as blocked when that endpoint answers `steer-model`.
#[tokio::test]
async fn full_access_auto_mode_blocks_when_arc_returns_interrupt_for_model() {
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;
// Mock ARC endpoint: answers every POST to /codex/safety/arc with a
// "steer-model" outcome and requires exactly one such request.
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/codex/safety/arc"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"outcome": "steer-model",
"short_reason": "needs approval",
"rationale": "high-risk action",
"risk_score": 96,
"risk_level": "critical",
"evidence": [{
"message": "dangerous_tool",
"why": "high-risk action",
}],
})))
.expect(1)
.mount(&server)
.await;
// Give the turn a dummy ChatGPT auth manager.
let (session, mut turn_context) = make_session_and_context().await;
turn_context.auth_manager = Some(crate::test_support::auth_manager_from_auth(
crate::CodexAuth::create_dummy_chatgpt_auth_for_testing(),
));
// Never-ask approval plus danger-full-access sandbox: the same pair of
// settings the callsite-mode test uses to obtain the "full access" label.
turn_context
.approval_policy
.set(AskForApproval::Never)
.expect("test setup should allow updating approval policy");
turn_context
.sandbox_policy
.set(SandboxPolicy::DangerFullAccess)
.expect("test setup should allow updating sandbox policy");
// Point the ChatGPT base URL at the mock server so the ARC request lands there.
let mut config = (*turn_context.config).clone();
config.chatgpt_base_url = server.uri();
turn_context.config = Arc::new(config);
let session = Arc::new(session);
let turn_context = Arc::new(turn_context);
let invocation = McpInvocation {
server: CODEX_APPS_MCP_SERVER_NAME.to_string(),
tool: "dangerous_tool".to_string(),
arguments: Some(serde_json::json!({ "id": 1 })),
};
// NOTE(review): these annotation flags are presumably chosen so that
// `requires_mcp_tool_approval` returns true — confirm against the
// `annotations` helper's parameter order.
let metadata = McpToolApprovalMetadata {
annotations: Some(annotations(Some(false), Some(true), Some(true))),
connector_id: Some("calendar".to_string()),
connector_name: Some("Calendar".to_string()),
connector_description: Some("Manage events".to_string()),
tool_title: Some("Dangerous Tool".to_string()),
tool_description: Some("Performs a risky action.".to_string()),
codex_apps_meta: None,
};
let decision = maybe_request_mcp_tool_approval(
&session,
&turn_context,
"call-2",
&invocation,
Some(&metadata),
AppToolApproval::Auto,
)
.await;
// The expected block message ends with the mocked response's "rationale" text.
assert_eq!(
decision,
Some(McpToolApprovalDecision::BlockedBySafetyMonitor(
"Tool call was cancelled because of safety risks: high-risk action".to_string(),
))
);
}
#[tokio::test]
async fn approve_mode_routes_arc_ask_user_to_guardian_when_guardian_reviewer_is_enabled() {
use wiremock::Mock;

View File

@@ -381,6 +381,116 @@ async fn refreshes_token_when_last_refresh_is_stale() -> Result<()> {
Ok(())
}
/// Verifies that calling `auth()` on the context's auth manager returns the
/// newer credentials saved to disk after the initial auth was written, that
/// the auth file is left as saved, and that the mock server saw no requests.
#[serial_test::serial(auth_refresh)]
#[tokio::test]
async fn auth_reloads_disk_auth_when_cached_auth_is_stale() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = MockServer::start().await;
let ctx = RefreshTokenTestContext::new(&server)?;
// Initial auth was last refreshed nine days ago.
// NOTE(review): presumably older than the proactive-refresh threshold —
// confirm against TOKEN_REFRESH_INTERVAL.
let stale_refresh = Utc::now() - Duration::days(9);
let initial_tokens = build_tokens(INITIAL_ACCESS_TOKEN, INITIAL_REFRESH_TOKEN);
let initial_auth = AuthDotJson {
auth_mode: Some(AuthMode::Chatgpt),
openai_api_key: None,
tokens: Some(initial_tokens),
last_refresh: Some(stale_refresh),
};
ctx.write_auth(&initial_auth)?;
// Save different tokens, refreshed one day ago, straight into the codex home
// with the file credential store.
// NOTE(review): this appears intended to change the disk state without
// touching the manager's cached auth — confirm how `write_auth` seeds the cache.
let fresh_refresh = Utc::now() - Duration::days(1);
let disk_tokens = build_tokens("disk-access-token", "disk-refresh-token");
let disk_auth = AuthDotJson {
auth_mode: Some(AuthMode::Chatgpt),
openai_api_key: None,
tokens: Some(disk_tokens.clone()),
last_refresh: Some(fresh_refresh),
};
save_auth(
ctx.codex_home.path(),
&disk_auth,
AuthCredentialsStoreMode::File,
)?;
// `auth()` must hand back the tokens that are now on disk.
let cached_auth = ctx
.auth_manager
.auth()
.await
.context("auth should reload from disk")?;
let cached = cached_auth
.get_token_data()
.context("token data should reload from disk")?;
assert_eq!(cached, disk_tokens);
// The stored auth must still equal what `save_auth` wrote.
let stored = ctx.load_auth()?;
assert_eq!(stored, disk_auth);
// No mock is mounted in this test; the server must not have received any request.
let requests = server.received_requests().await.unwrap_or_default();
assert!(requests.is_empty(), "expected no refresh token requests");
Ok(())
}
/// Verifies that calling `auth()` on the context's auth manager returns the
/// newer credentials saved to disk and never hits the token endpoint, which is
/// mocked to answer `refresh_token_expired` and to expect zero calls.
#[serial_test::serial(auth_refresh)]
#[tokio::test]
async fn auth_reloads_disk_auth_without_calling_expired_refresh_token() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = MockServer::start().await;
// Token endpoint mock: would reply 401 `refresh_token_expired`, but
// `.expect(0)` makes any call to it a verification failure.
Mock::given(method("POST"))
.and(path("/oauth/token"))
.respond_with(ResponseTemplate::new(401).set_body_json(json!({
"error": {
"code": "refresh_token_expired"
}
})))
.expect(0)
.mount(&server)
.await;
let ctx = RefreshTokenTestContext::new(&server)?;
// Initial auth was last refreshed nine days ago.
// NOTE(review): presumably older than the proactive-refresh threshold —
// confirm against TOKEN_REFRESH_INTERVAL.
let stale_refresh = Utc::now() - Duration::days(9);
let initial_tokens = build_tokens(INITIAL_ACCESS_TOKEN, INITIAL_REFRESH_TOKEN);
let initial_auth = AuthDotJson {
auth_mode: Some(AuthMode::Chatgpt),
openai_api_key: None,
tokens: Some(initial_tokens),
last_refresh: Some(stale_refresh),
};
ctx.write_auth(&initial_auth)?;
// Save different tokens, refreshed one day ago, straight into the codex home
// with the file credential store.
// NOTE(review): this appears intended to change the disk state without
// touching the manager's cached auth — confirm how `write_auth` seeds the cache.
let fresh_refresh = Utc::now() - Duration::days(1);
let disk_tokens = build_tokens("disk-access-token", "disk-refresh-token");
let disk_auth = AuthDotJson {
auth_mode: Some(AuthMode::Chatgpt),
openai_api_key: None,
tokens: Some(disk_tokens.clone()),
last_refresh: Some(fresh_refresh),
};
save_auth(
ctx.codex_home.path(),
&disk_auth,
AuthCredentialsStoreMode::File,
)?;
// `auth()` must hand back the tokens that are now on disk.
let cached_auth = ctx
.auth_manager
.auth()
.await
.context("auth should reload from disk")?;
let cached = cached_auth
.get_token_data()
.context("token data should reload from disk")?;
assert_eq!(cached, disk_tokens);
// The stored auth must still equal what `save_auth` wrote.
let stored = ctx.load_auth()?;
assert_eq!(stored, disk_auth);
// Enforce the `.expect(0)` set on the token endpoint mock.
server.verify().await;
Ok(())
}
#[serial_test::serial(auth_refresh)]
#[tokio::test]
async fn refresh_token_returns_permanent_error_for_expired_refresh_token() -> Result<()> {

View File

@@ -1090,10 +1090,13 @@ impl AuthManager {
}
/// Current cached auth (clone). May be `None` if not logged in or load failed.
/// Refreshes cached ChatGPT tokens if they are stale before returning.
/// For stale managed ChatGPT auth, first performs a guarded reload and then
/// refreshes only if the on-disk auth is unchanged.
pub async fn auth(&self) -> Option<CodexAuth> {
let auth = self.auth_cached()?;
if let Err(err) = self.refresh_if_stale(&auth).await {
if Self::is_stale_for_proactive_refresh(&auth)
&& let Err(err) = self.refresh_token().await
{
tracing::error!("Failed to refresh token: {}", err);
return Some(auth);
}
@@ -1320,30 +1323,21 @@ impl AuthManager {
self.auth_cached().as_ref().map(CodexAuth::auth_mode)
}
async fn refresh_if_stale(&self, auth: &CodexAuth) -> Result<bool, RefreshTokenError> {
fn is_stale_for_proactive_refresh(auth: &CodexAuth) -> bool {
let chatgpt_auth = match auth {
CodexAuth::Chatgpt(chatgpt_auth) => chatgpt_auth,
_ => return Ok(false),
_ => return false,
};
let auth_dot_json = match chatgpt_auth.current_auth_json() {
Some(auth_dot_json) => auth_dot_json,
None => return Ok(false),
};
let tokens = match auth_dot_json.tokens {
Some(tokens) => tokens,
None => return Ok(false),
None => return false,
};
let last_refresh = match auth_dot_json.last_refresh {
Some(last_refresh) => last_refresh,
None => return Ok(false),
None => return false,
};
if last_refresh >= Utc::now() - chrono::Duration::days(TOKEN_REFRESH_INTERVAL) {
return Ok(false);
}
self.refresh_and_persist_chatgpt_token(chatgpt_auth, tokens.refresh_token)
.await?;
Ok(true)
last_refresh < Utc::now() - chrono::Duration::days(TOKEN_REFRESH_INTERVAL)
}
async fn refresh_external_auth(