mirror of
https://github.com/openai/codex.git
synced 2026-03-21 21:36:31 +03:00
Compare commits
3 Commits
windows_ke
...
dev/cc/fix
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
41a653fd81 | ||
|
|
703ed17fb6 | ||
|
|
dd88ed767b |
@@ -883,6 +883,24 @@ mod tests {
|
||||
account_id: Option<&str>,
|
||||
access_token: &str,
|
||||
refresh_token: &str,
|
||||
) -> serde_json::Value {
|
||||
chatgpt_auth_json_with_last_refresh(
|
||||
plan_type,
|
||||
chatgpt_user_id,
|
||||
account_id,
|
||||
access_token,
|
||||
refresh_token,
|
||||
"2025-01-01T00:00:00Z",
|
||||
)
|
||||
}
|
||||
|
||||
fn chatgpt_auth_json_with_last_refresh(
|
||||
plan_type: &str,
|
||||
chatgpt_user_id: Option<&str>,
|
||||
account_id: Option<&str>,
|
||||
access_token: &str,
|
||||
refresh_token: &str,
|
||||
last_refresh: &str,
|
||||
) -> serde_json::Value {
|
||||
chatgpt_auth_json_with_mode(
|
||||
plan_type,
|
||||
@@ -890,6 +908,7 @@ mod tests {
|
||||
account_id,
|
||||
access_token,
|
||||
refresh_token,
|
||||
last_refresh,
|
||||
None,
|
||||
)
|
||||
}
|
||||
@@ -900,6 +919,7 @@ mod tests {
|
||||
account_id: Option<&str>,
|
||||
access_token: &str,
|
||||
refresh_token: &str,
|
||||
last_refresh: &str,
|
||||
auth_mode: Option<&str>,
|
||||
) -> serde_json::Value {
|
||||
let header = json!({ "alg": "none", "typ": "JWT" });
|
||||
@@ -925,7 +945,7 @@ mod tests {
|
||||
"refresh_token": refresh_token,
|
||||
"account_id": account_id,
|
||||
},
|
||||
"last_refresh": "2025-01-01T00:00:00Z",
|
||||
"last_refresh": last_refresh,
|
||||
});
|
||||
if let Some(auth_mode) = auth_mode {
|
||||
auth_json["auth_mode"] = serde_json::Value::String(auth_mode.to_string());
|
||||
@@ -1262,24 +1282,43 @@ enabled = false
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_cloud_requirements_recovers_after_unauthorized_reload() {
|
||||
let auth = managed_auth_context(
|
||||
"business",
|
||||
Some("user-12345"),
|
||||
Some("account-12345"),
|
||||
"stale-access-token",
|
||||
"test-refresh-token",
|
||||
);
|
||||
let auth_home = tempdir().expect("tempdir");
|
||||
write_auth_json(
|
||||
auth._home.path(),
|
||||
chatgpt_auth_json(
|
||||
auth_home.path(),
|
||||
chatgpt_auth_json_with_last_refresh(
|
||||
"business",
|
||||
Some("user-12345"),
|
||||
Some("account-12345"),
|
||||
"stale-access-token",
|
||||
"test-refresh-token",
|
||||
// Keep auth "fresh" so the first request hits unauthorized recovery
|
||||
// instead of AuthManager::auth() proactively reloading from disk.
|
||||
"3025-01-01T00:00:00Z",
|
||||
),
|
||||
)
|
||||
.expect("write initial auth");
|
||||
let auth_manager = Arc::new(AuthManager::new(
|
||||
auth_home.path().to_path_buf(),
|
||||
false,
|
||||
AuthCredentialsStoreMode::File,
|
||||
));
|
||||
|
||||
write_auth_json(
|
||||
auth_home.path(),
|
||||
chatgpt_auth_json_with_last_refresh(
|
||||
"business",
|
||||
Some("user-12345"),
|
||||
Some("account-12345"),
|
||||
"fresh-access-token",
|
||||
"test-refresh-token",
|
||||
"3025-01-01T00:00:00Z",
|
||||
),
|
||||
)
|
||||
.expect("write refreshed auth");
|
||||
let auth = ManagedAuthContext {
|
||||
_home: auth_home,
|
||||
manager: auth_manager,
|
||||
};
|
||||
|
||||
let fetcher = Arc::new(TokenFetcher {
|
||||
expected_token: "fresh-access-token".to_string(),
|
||||
@@ -1314,24 +1353,41 @@ enabled = false
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetch_cloud_requirements_recovers_after_unauthorized_reload_updates_cache_identity() {
|
||||
let auth = managed_auth_context(
|
||||
"business",
|
||||
Some("user-12345"),
|
||||
Some("account-12345"),
|
||||
"stale-access-token",
|
||||
"test-refresh-token",
|
||||
);
|
||||
let auth_home = tempdir().expect("tempdir");
|
||||
write_auth_json(
|
||||
auth._home.path(),
|
||||
chatgpt_auth_json(
|
||||
auth_home.path(),
|
||||
chatgpt_auth_json_with_last_refresh(
|
||||
"business",
|
||||
Some("user-12345"),
|
||||
Some("account-12345"),
|
||||
"stale-access-token",
|
||||
"test-refresh-token",
|
||||
"3025-01-01T00:00:00Z",
|
||||
),
|
||||
)
|
||||
.expect("write initial auth");
|
||||
let auth_manager = Arc::new(AuthManager::new(
|
||||
auth_home.path().to_path_buf(),
|
||||
false,
|
||||
AuthCredentialsStoreMode::File,
|
||||
));
|
||||
|
||||
write_auth_json(
|
||||
auth_home.path(),
|
||||
chatgpt_auth_json_with_last_refresh(
|
||||
"business",
|
||||
Some("user-99999"),
|
||||
Some("account-12345"),
|
||||
"fresh-access-token",
|
||||
"test-refresh-token",
|
||||
"3025-01-01T00:00:00Z",
|
||||
),
|
||||
)
|
||||
.expect("write refreshed auth");
|
||||
let auth = ManagedAuthContext {
|
||||
_home: auth_home,
|
||||
manager: auth_manager,
|
||||
};
|
||||
|
||||
let fetcher = Arc::new(TokenFetcher {
|
||||
expected_token: "fresh-access-token".to_string(),
|
||||
@@ -1432,6 +1488,7 @@ enabled = false
|
||||
Some("account-12345"),
|
||||
"test-access-token",
|
||||
"test-refresh-token",
|
||||
"2025-01-01T00:00:00Z",
|
||||
Some("chatgptAuthTokens"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -99,6 +99,7 @@ pub(crate) async fn monitor_action(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
action: serde_json::Value,
|
||||
protection_client_callsite: &'static str,
|
||||
) -> ArcMonitorOutcome {
|
||||
let auth = match turn_context.auth_manager.as_ref() {
|
||||
Some(auth_manager) => match auth_manager.auth().await {
|
||||
@@ -138,7 +139,8 @@ pub(crate) async fn monitor_action(
|
||||
return ArcMonitorOutcome::Ok;
|
||||
}
|
||||
};
|
||||
let body = build_arc_monitor_request(sess, turn_context, action).await;
|
||||
let body =
|
||||
build_arc_monitor_request(sess, turn_context, action, protection_client_callsite).await;
|
||||
let client = build_reqwest_client();
|
||||
let mut request = client
|
||||
.post(&url)
|
||||
@@ -236,6 +238,7 @@ async fn build_arc_monitor_request(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
action: serde_json::Map<String, serde_json::Value>,
|
||||
protection_client_callsite: &'static str,
|
||||
) -> ArcMonitorRequest {
|
||||
let history = sess.clone_history().await;
|
||||
let mut messages = build_arc_monitor_messages(history.raw_items());
|
||||
@@ -254,7 +257,7 @@ async fn build_arc_monitor_request(
|
||||
codex_thread_id: conversation_id.clone(),
|
||||
codex_turn_id: turn_context.sub_id.clone(),
|
||||
conversation_id: Some(conversation_id),
|
||||
protection_client_callsite: None,
|
||||
protection_client_callsite: Some(protection_client_callsite.to_string()),
|
||||
},
|
||||
messages: Some(messages),
|
||||
input: None,
|
||||
|
||||
@@ -178,6 +178,7 @@ async fn build_arc_monitor_request_includes_relevant_history_and_null_policies()
|
||||
&turn_context,
|
||||
serde_json::from_value(serde_json::json!({ "tool": "mcp_tool_call" }))
|
||||
.expect("action should deserialize"),
|
||||
"normal",
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -188,7 +189,7 @@ async fn build_arc_monitor_request_includes_relevant_history_and_null_policies()
|
||||
codex_thread_id: session.conversation_id.to_string(),
|
||||
codex_turn_id: turn_context.sub_id.clone(),
|
||||
conversation_id: Some(session.conversation_id.to_string()),
|
||||
protection_client_callsite: None,
|
||||
protection_client_callsite: Some("normal".to_string()),
|
||||
},
|
||||
messages: Some(vec![
|
||||
ArcMonitorChatMessage {
|
||||
@@ -285,6 +286,7 @@ async fn monitor_action_posts_expected_arc_request() {
|
||||
"codex_thread_id": session.conversation_id.to_string(),
|
||||
"codex_turn_id": turn_context.sub_id.clone(),
|
||||
"conversation_id": session.conversation_id.to_string(),
|
||||
"protection_client_callsite": "normal",
|
||||
},
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
@@ -320,6 +322,7 @@ async fn monitor_action_posts_expected_arc_request() {
|
||||
&session,
|
||||
&turn_context,
|
||||
serde_json::json!({ "tool": "mcp_tool_call" }),
|
||||
"normal",
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -377,6 +380,7 @@ async fn monitor_action_uses_env_url_and_token_overrides() {
|
||||
&session,
|
||||
&turn_context,
|
||||
serde_json::json!({ "tool": "mcp_tool_call" }),
|
||||
"normal",
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -428,6 +432,7 @@ async fn monitor_action_rejects_legacy_response_fields() {
|
||||
&session,
|
||||
&turn_context,
|
||||
serde_json::json!({ "tool": "mcp_tool_call" }),
|
||||
"normal",
|
||||
)
|
||||
.await;
|
||||
|
||||
|
||||
@@ -457,6 +457,9 @@ const MCP_TOOL_APPROVAL_TOOL_TITLE_KEY: &str = "tool_title";
|
||||
const MCP_TOOL_APPROVAL_TOOL_DESCRIPTION_KEY: &str = "tool_description";
|
||||
const MCP_TOOL_APPROVAL_TOOL_PARAMS_KEY: &str = "tool_params";
|
||||
const MCP_TOOL_APPROVAL_TOOL_PARAMS_DISPLAY_KEY: &str = "tool_params_display";
|
||||
const MCP_TOOL_CALL_ARC_MONITOR_CALLSITE_DEFAULT: &str = "mcp_tool_call__default";
|
||||
const MCP_TOOL_CALL_ARC_MONITOR_CALLSITE_ALWAYS_ALLOW: &str = "mcp_tool_call__always_allow";
|
||||
const MCP_TOOL_CALL_ARC_MONITOR_CALLSITE_FULL_ACCESS: &str = "mcp_tool_call__full_access";
|
||||
|
||||
pub(crate) fn is_mcp_tool_approval_question_id(question_id: &str) -> bool {
|
||||
question_id
|
||||
@@ -494,14 +497,22 @@ async fn maybe_request_mcp_tool_approval(
|
||||
let annotations = metadata.and_then(|metadata| metadata.annotations.as_ref());
|
||||
let approval_required = annotations.is_some_and(requires_mcp_tool_approval);
|
||||
let mut monitor_reason = None;
|
||||
let auto_approved_by_policy = approval_mode == AppToolApproval::Approve
|
||||
|| (approval_mode == AppToolApproval::Auto && is_full_access_mode(turn_context));
|
||||
|
||||
if approval_mode == AppToolApproval::Approve {
|
||||
if auto_approved_by_policy {
|
||||
if !approval_required {
|
||||
return None;
|
||||
}
|
||||
|
||||
match maybe_monitor_auto_approved_mcp_tool_call(sess, turn_context, invocation, metadata)
|
||||
.await
|
||||
match maybe_monitor_auto_approved_mcp_tool_call(
|
||||
sess,
|
||||
turn_context,
|
||||
invocation,
|
||||
metadata,
|
||||
approval_mode,
|
||||
)
|
||||
.await
|
||||
{
|
||||
ArcMonitorOutcome::Ok => return None,
|
||||
ArcMonitorOutcome::AskUser(reason) => {
|
||||
@@ -515,13 +526,8 @@ async fn maybe_request_mcp_tool_approval(
|
||||
}
|
||||
}
|
||||
|
||||
if approval_mode == AppToolApproval::Auto {
|
||||
if is_full_access_mode(turn_context) {
|
||||
return None;
|
||||
}
|
||||
if !approval_required {
|
||||
return None;
|
||||
}
|
||||
if approval_mode == AppToolApproval::Auto && !approval_required {
|
||||
return None;
|
||||
}
|
||||
|
||||
let session_approval_key = session_mcp_tool_approval_key(invocation, metadata, approval_mode);
|
||||
@@ -653,9 +659,16 @@ async fn maybe_monitor_auto_approved_mcp_tool_call(
|
||||
turn_context: &TurnContext,
|
||||
invocation: &McpInvocation,
|
||||
metadata: Option<&McpToolApprovalMetadata>,
|
||||
approval_mode: AppToolApproval,
|
||||
) -> ArcMonitorOutcome {
|
||||
let action = prepare_arc_request_action(invocation, metadata);
|
||||
monitor_action(sess, turn_context, action).await
|
||||
monitor_action(
|
||||
sess,
|
||||
turn_context,
|
||||
action,
|
||||
mcp_tool_approval_callsite_mode(approval_mode, turn_context),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
fn prepare_arc_request_action(
|
||||
@@ -749,6 +762,22 @@ fn is_full_access_mode(turn_context: &TurnContext) -> bool {
|
||||
)
|
||||
}
|
||||
|
||||
fn mcp_tool_approval_callsite_mode(
|
||||
approval_mode: AppToolApproval,
|
||||
turn_context: &TurnContext,
|
||||
) -> &'static str {
|
||||
match approval_mode {
|
||||
AppToolApproval::Approve => MCP_TOOL_CALL_ARC_MONITOR_CALLSITE_ALWAYS_ALLOW,
|
||||
AppToolApproval::Auto | AppToolApproval::Prompt => {
|
||||
if approval_mode == AppToolApproval::Auto && is_full_access_mode(turn_context) {
|
||||
MCP_TOOL_CALL_ARC_MONITOR_CALLSITE_FULL_ACCESS
|
||||
} else {
|
||||
MCP_TOOL_CALL_ARC_MONITOR_CALLSITE_DEFAULT
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn lookup_mcp_tool_metadata(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
|
||||
@@ -776,6 +776,38 @@ fn approval_elicitation_meta_merges_session_and_always_persist_with_connector_so
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn approval_callsite_mode_distinguishes_default_always_allow_and_full_access() {
|
||||
let (_session, mut turn_context) = make_session_and_context().await;
|
||||
|
||||
assert_eq!(
|
||||
mcp_tool_approval_callsite_mode(AppToolApproval::Auto, &turn_context),
|
||||
"mcp_tool_call__default"
|
||||
);
|
||||
assert_eq!(
|
||||
mcp_tool_approval_callsite_mode(AppToolApproval::Prompt, &turn_context),
|
||||
"mcp_tool_call__default"
|
||||
);
|
||||
assert_eq!(
|
||||
mcp_tool_approval_callsite_mode(AppToolApproval::Approve, &turn_context),
|
||||
"mcp_tool_call__always_allow"
|
||||
);
|
||||
|
||||
turn_context
|
||||
.approval_policy
|
||||
.set(AskForApproval::Never)
|
||||
.expect("test setup should allow updating approval policy");
|
||||
turn_context
|
||||
.sandbox_policy
|
||||
.set(SandboxPolicy::DangerFullAccess)
|
||||
.expect("test setup should allow updating sandbox policy");
|
||||
|
||||
assert_eq!(
|
||||
mcp_tool_approval_callsite_mode(AppToolApproval::Auto, &turn_context),
|
||||
"mcp_tool_call__full_access"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn declined_elicitation_response_stays_decline() {
|
||||
let response = parse_mcp_tool_approval_elicitation_response(
|
||||
@@ -1035,6 +1067,83 @@ async fn approve_mode_blocks_when_arc_returns_interrupt_for_model() {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn full_access_auto_mode_blocks_when_arc_returns_interrupt_for_model() {
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
let server = MockServer::start().await;
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/codex/safety/arc"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
|
||||
"outcome": "steer-model",
|
||||
"short_reason": "needs approval",
|
||||
"rationale": "high-risk action",
|
||||
"risk_score": 96,
|
||||
"risk_level": "critical",
|
||||
"evidence": [{
|
||||
"message": "dangerous_tool",
|
||||
"why": "high-risk action",
|
||||
}],
|
||||
})))
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let (session, mut turn_context) = make_session_and_context().await;
|
||||
turn_context.auth_manager = Some(crate::test_support::auth_manager_from_auth(
|
||||
crate::CodexAuth::create_dummy_chatgpt_auth_for_testing(),
|
||||
));
|
||||
turn_context
|
||||
.approval_policy
|
||||
.set(AskForApproval::Never)
|
||||
.expect("test setup should allow updating approval policy");
|
||||
turn_context
|
||||
.sandbox_policy
|
||||
.set(SandboxPolicy::DangerFullAccess)
|
||||
.expect("test setup should allow updating sandbox policy");
|
||||
let mut config = (*turn_context.config).clone();
|
||||
config.chatgpt_base_url = server.uri();
|
||||
turn_context.config = Arc::new(config);
|
||||
|
||||
let session = Arc::new(session);
|
||||
let turn_context = Arc::new(turn_context);
|
||||
let invocation = McpInvocation {
|
||||
server: CODEX_APPS_MCP_SERVER_NAME.to_string(),
|
||||
tool: "dangerous_tool".to_string(),
|
||||
arguments: Some(serde_json::json!({ "id": 1 })),
|
||||
};
|
||||
let metadata = McpToolApprovalMetadata {
|
||||
annotations: Some(annotations(Some(false), Some(true), Some(true))),
|
||||
connector_id: Some("calendar".to_string()),
|
||||
connector_name: Some("Calendar".to_string()),
|
||||
connector_description: Some("Manage events".to_string()),
|
||||
tool_title: Some("Dangerous Tool".to_string()),
|
||||
tool_description: Some("Performs a risky action.".to_string()),
|
||||
codex_apps_meta: None,
|
||||
};
|
||||
|
||||
let decision = maybe_request_mcp_tool_approval(
|
||||
&session,
|
||||
&turn_context,
|
||||
"call-2",
|
||||
&invocation,
|
||||
Some(&metadata),
|
||||
AppToolApproval::Auto,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
decision,
|
||||
Some(McpToolApprovalDecision::BlockedBySafetyMonitor(
|
||||
"Tool call was cancelled because of safety risks: high-risk action".to_string(),
|
||||
))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn approve_mode_routes_arc_ask_user_to_guardian_when_guardian_reviewer_is_enabled() {
|
||||
use wiremock::Mock;
|
||||
|
||||
@@ -381,6 +381,116 @@ async fn refreshes_token_when_last_refresh_is_stale() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[serial_test::serial(auth_refresh)]
|
||||
#[tokio::test]
|
||||
async fn auth_reloads_disk_auth_when_cached_auth_is_stale() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
let ctx = RefreshTokenTestContext::new(&server)?;
|
||||
let stale_refresh = Utc::now() - Duration::days(9);
|
||||
let initial_tokens = build_tokens(INITIAL_ACCESS_TOKEN, INITIAL_REFRESH_TOKEN);
|
||||
let initial_auth = AuthDotJson {
|
||||
auth_mode: Some(AuthMode::Chatgpt),
|
||||
openai_api_key: None,
|
||||
tokens: Some(initial_tokens),
|
||||
last_refresh: Some(stale_refresh),
|
||||
};
|
||||
ctx.write_auth(&initial_auth)?;
|
||||
|
||||
let fresh_refresh = Utc::now() - Duration::days(1);
|
||||
let disk_tokens = build_tokens("disk-access-token", "disk-refresh-token");
|
||||
let disk_auth = AuthDotJson {
|
||||
auth_mode: Some(AuthMode::Chatgpt),
|
||||
openai_api_key: None,
|
||||
tokens: Some(disk_tokens.clone()),
|
||||
last_refresh: Some(fresh_refresh),
|
||||
};
|
||||
save_auth(
|
||||
ctx.codex_home.path(),
|
||||
&disk_auth,
|
||||
AuthCredentialsStoreMode::File,
|
||||
)?;
|
||||
|
||||
let cached_auth = ctx
|
||||
.auth_manager
|
||||
.auth()
|
||||
.await
|
||||
.context("auth should reload from disk")?;
|
||||
let cached = cached_auth
|
||||
.get_token_data()
|
||||
.context("token data should reload from disk")?;
|
||||
assert_eq!(cached, disk_tokens);
|
||||
|
||||
let stored = ctx.load_auth()?;
|
||||
assert_eq!(stored, disk_auth);
|
||||
|
||||
let requests = server.received_requests().await.unwrap_or_default();
|
||||
assert!(requests.is_empty(), "expected no refresh token requests");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[serial_test::serial(auth_refresh)]
|
||||
#[tokio::test]
|
||||
async fn auth_reloads_disk_auth_without_calling_expired_refresh_token() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = MockServer::start().await;
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/oauth/token"))
|
||||
.respond_with(ResponseTemplate::new(401).set_body_json(json!({
|
||||
"error": {
|
||||
"code": "refresh_token_expired"
|
||||
}
|
||||
})))
|
||||
.expect(0)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let ctx = RefreshTokenTestContext::new(&server)?;
|
||||
let stale_refresh = Utc::now() - Duration::days(9);
|
||||
let initial_tokens = build_tokens(INITIAL_ACCESS_TOKEN, INITIAL_REFRESH_TOKEN);
|
||||
let initial_auth = AuthDotJson {
|
||||
auth_mode: Some(AuthMode::Chatgpt),
|
||||
openai_api_key: None,
|
||||
tokens: Some(initial_tokens),
|
||||
last_refresh: Some(stale_refresh),
|
||||
};
|
||||
ctx.write_auth(&initial_auth)?;
|
||||
|
||||
let fresh_refresh = Utc::now() - Duration::days(1);
|
||||
let disk_tokens = build_tokens("disk-access-token", "disk-refresh-token");
|
||||
let disk_auth = AuthDotJson {
|
||||
auth_mode: Some(AuthMode::Chatgpt),
|
||||
openai_api_key: None,
|
||||
tokens: Some(disk_tokens.clone()),
|
||||
last_refresh: Some(fresh_refresh),
|
||||
};
|
||||
save_auth(
|
||||
ctx.codex_home.path(),
|
||||
&disk_auth,
|
||||
AuthCredentialsStoreMode::File,
|
||||
)?;
|
||||
|
||||
let cached_auth = ctx
|
||||
.auth_manager
|
||||
.auth()
|
||||
.await
|
||||
.context("auth should reload from disk")?;
|
||||
let cached = cached_auth
|
||||
.get_token_data()
|
||||
.context("token data should reload from disk")?;
|
||||
assert_eq!(cached, disk_tokens);
|
||||
|
||||
let stored = ctx.load_auth()?;
|
||||
assert_eq!(stored, disk_auth);
|
||||
|
||||
server.verify().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[serial_test::serial(auth_refresh)]
|
||||
#[tokio::test]
|
||||
async fn refresh_token_returns_permanent_error_for_expired_refresh_token() -> Result<()> {
|
||||
|
||||
@@ -1090,10 +1090,13 @@ impl AuthManager {
|
||||
}
|
||||
|
||||
/// Current cached auth (clone). May be `None` if not logged in or load failed.
|
||||
/// Refreshes cached ChatGPT tokens if they are stale before returning.
|
||||
/// For stale managed ChatGPT auth, first performs a guarded reload and then
|
||||
/// refreshes only if the on-disk auth is unchanged.
|
||||
pub async fn auth(&self) -> Option<CodexAuth> {
|
||||
let auth = self.auth_cached()?;
|
||||
if let Err(err) = self.refresh_if_stale(&auth).await {
|
||||
if Self::is_stale_for_proactive_refresh(&auth)
|
||||
&& let Err(err) = self.refresh_token().await
|
||||
{
|
||||
tracing::error!("Failed to refresh token: {}", err);
|
||||
return Some(auth);
|
||||
}
|
||||
@@ -1320,30 +1323,21 @@ impl AuthManager {
|
||||
self.auth_cached().as_ref().map(CodexAuth::auth_mode)
|
||||
}
|
||||
|
||||
async fn refresh_if_stale(&self, auth: &CodexAuth) -> Result<bool, RefreshTokenError> {
|
||||
fn is_stale_for_proactive_refresh(auth: &CodexAuth) -> bool {
|
||||
let chatgpt_auth = match auth {
|
||||
CodexAuth::Chatgpt(chatgpt_auth) => chatgpt_auth,
|
||||
_ => return Ok(false),
|
||||
_ => return false,
|
||||
};
|
||||
|
||||
let auth_dot_json = match chatgpt_auth.current_auth_json() {
|
||||
Some(auth_dot_json) => auth_dot_json,
|
||||
None => return Ok(false),
|
||||
};
|
||||
let tokens = match auth_dot_json.tokens {
|
||||
Some(tokens) => tokens,
|
||||
None => return Ok(false),
|
||||
None => return false,
|
||||
};
|
||||
let last_refresh = match auth_dot_json.last_refresh {
|
||||
Some(last_refresh) => last_refresh,
|
||||
None => return Ok(false),
|
||||
None => return false,
|
||||
};
|
||||
if last_refresh >= Utc::now() - chrono::Duration::days(TOKEN_REFRESH_INTERVAL) {
|
||||
return Ok(false);
|
||||
}
|
||||
self.refresh_and_persist_chatgpt_token(chatgpt_auth, tokens.refresh_token)
|
||||
.await?;
|
||||
Ok(true)
|
||||
last_refresh < Utc::now() - chrono::Duration::days(TOKEN_REFRESH_INTERVAL)
|
||||
}
|
||||
|
||||
async fn refresh_external_auth(
|
||||
|
||||
Reference in New Issue
Block a user