Attempt to reload auth as a step in 401 recovery (#8880)

When a request fails with 401 Unauthorized, first attempt to reload the auth
data from disk, and only then attempt to refresh the token.
This commit is contained in:
pakrym-oai
2026-01-08 15:06:44 -08:00
committed by GitHub
parent be4364bb80
commit 62a73b6d58
3 changed files with 377 additions and 57 deletions

View File

@@ -75,10 +75,6 @@ impl RefreshTokenError {
Self::Transient(_) => None,
}
}
fn other_with_message(message: impl Into<String>) -> Self {
Self::Transient(std::io::Error::other(message.into()))
}
}
impl From<RefreshTokenError> for std::io::Error {
@@ -450,6 +446,7 @@ async fn try_refresh_token(
Ok(refresh_response)
} else {
let body = response.text().await.unwrap_or_default();
tracing::error!("Failed to refresh token: {status}: {body}");
if status == StatusCode::UNAUTHORIZED {
let failed = classify_refresh_token_failure(&body);
Err(RefreshTokenError::Permanent(failed))
@@ -548,6 +545,89 @@ struct CachedAuth {
auth: Option<CodexAuth>,
}
/// The step that `UnauthorizedRecovery::next` will perform on its next call.
enum UnauthorizedRecoveryStep {
/// Try to reload the auth data from storage (initial step).
Reload,
/// Try to refresh the token through the manager.
RefreshToken,
/// Nothing left to try; `has_next` reports `false` in this state.
Done,
}
/// Result of `AuthManager::reload_if_account_id_matches`.
enum ReloadOutcome {
/// The auth loaded from storage matched the expected account id and was
/// installed as the cached auth.
Reloaded,
/// No reload happened: either no expected account id was available or the
/// stored auth belongs to a different (or unknown) account.
Skipped,
}
/// `UnauthorizedRecovery` is a state machine that handles an attempt to refresh
/// the authentication when requests to the API fail with a 401 status code.
///
/// The client calls `next()` every time it encounters a 401 error, once per retry.
///
/// For API key based authentication, we don't do anything and let the error
/// bubble up to the user.
///
/// For ChatGPT based authentication, we:
/// 1. Attempt to reload the auth data from disk. We only reload if the account
///    id matches the one the current process is running as.
/// 2. Attempt to refresh the token using the OAuth token refresh flow.
///
/// If the server still responds with 401 after both steps, we let the error
/// bubble up to the user.
pub struct UnauthorizedRecovery {
// Auth manager used to reload and refresh the credentials.
manager: Arc<AuthManager>,
// The step the next call to `next()` will perform.
step: UnauthorizedRecoveryStep,
// Account id of the cached auth at construction time, if any; a reload is
// only accepted when the stored auth carries this same id.
expected_account_id: Option<String>,
}
impl UnauthorizedRecovery {
    /// Builds a recovery state machine starting at the reload step.
    ///
    /// The account id of the currently cached auth (if any) is captured up
    /// front so that a later reload can be checked against it.
    fn new(manager: Arc<AuthManager>) -> Self {
        let cached = manager.auth_cached();
        let expected_account_id = cached.as_ref().and_then(CodexAuth::get_account_id);
        Self {
            manager,
            step: UnauthorizedRecoveryStep::Reload,
            expected_account_id,
        }
    }

    /// Returns `true` while a recovery step remains.
    ///
    /// Recovery only applies when the cached auth is in ChatGPT mode; for any
    /// other mode (or no cached auth at all) this is always `false`.
    pub fn has_next(&self) -> bool {
        let is_chatgpt = self
            .manager
            .auth_cached()
            .is_some_and(|auth| auth.mode == AuthMode::ChatGPT);
        is_chatgpt && !matches!(self.step, UnauthorizedRecoveryStep::Done)
    }

    /// Performs the next recovery step.
    ///
    /// * In the reload step, the auth is reloaded from storage when the
    ///   account id matches; the token refresh is then left for the following
    ///   call. When the reload is skipped, the token is refreshed right away.
    /// * In the refresh step, the token is refreshed.
    ///
    /// # Errors
    ///
    /// Returns a permanent `RefreshTokenError` when no step is available
    /// (`has_next()` is `false`), and propagates any error from the token
    /// refresh. A failed refresh leaves the current step unchanged.
    pub async fn next(&mut self) -> Result<(), RefreshTokenError> {
        if !self.has_next() {
            let exhausted = RefreshTokenFailedError::new(
                RefreshTokenFailedReason::Other,
                "No more recovery steps available.",
            );
            return Err(RefreshTokenError::Permanent(exhausted));
        }

        // Decide whether this call must hit the refresh endpoint.
        let refresh_now = match self.step {
            UnauthorizedRecoveryStep::Reload => {
                let outcome = self
                    .manager
                    .reload_if_account_id_matches(self.expected_account_id.as_deref());
                matches!(outcome, ReloadOutcome::Skipped)
            }
            UnauthorizedRecoveryStep::RefreshToken => true,
            UnauthorizedRecoveryStep::Done => return Ok(()),
        };

        if refresh_now {
            self.manager.refresh_token().await?;
            self.step = UnauthorizedRecoveryStep::Done;
        } else {
            // The reload succeeded; keep the refresh in reserve for the next 401.
            self.step = UnauthorizedRecoveryStep::RefreshToken;
        }
        Ok(())
    }
}
/// Central manager providing a single source of truth for auth.json derived
/// authentication data. It loads once (or on preference change) and then
/// hands out cloned `CodexAuth` values so the rest of the program has a
@@ -633,20 +713,34 @@ impl AuthManager {
/// Force a reload of the auth information from auth.json. Returns
/// whether the auth value changed.
pub fn reload(&self) -> bool {
let new_auth = load_auth(
&self.codex_home,
self.enable_codex_api_key_env,
self.auth_credentials_store_mode,
)
.ok()
.flatten();
if let Ok(mut guard) = self.inner.write() {
let changed = !AuthManager::auths_equal(&guard.auth, &new_auth);
guard.auth = new_auth;
changed
} else {
false
tracing::info!("Reloading auth");
let new_auth = self.load_auth_from_storage();
self.set_auth(new_auth)
}
/// Reloads the auth from storage, but only if it belongs to the expected account.
///
/// Returns `ReloadOutcome::Skipped` without touching the cached auth when
/// `expected_account_id` is `None`, or when the auth found in storage has a
/// different (or no) account id. Otherwise the stored auth replaces the
/// cached one and `ReloadOutcome::Reloaded` is returned.
fn reload_if_account_id_matches(&self, expected_account_id: Option<&str>) -> ReloadOutcome {
    let Some(expected_account_id) = expected_account_id else {
        tracing::info!("Skipping auth reload because no account id is available.");
        return ReloadOutcome::Skipped;
    };

    let stored_auth = self.load_auth_from_storage();
    let stored_account_id = stored_auth.as_ref().and_then(CodexAuth::get_account_id);

    match stored_account_id.as_deref() {
        Some(found) if found == expected_account_id => {
            tracing::info!("Reloading auth for account {expected_account_id}");
            self.set_auth(stored_auth);
            ReloadOutcome::Reloaded
        }
        other => {
            let found_account_id = other.unwrap_or("unknown");
            tracing::info!(
                "Skipping auth reload due to account id mismatch (expected: {expected_account_id}, found: {found_account_id})"
            );
            ReloadOutcome::Skipped
        }
    }
}
fn auths_equal(a: &Option<CodexAuth>, b: &Option<CodexAuth>) -> bool {
@@ -657,6 +751,27 @@ impl AuthManager {
}
}
/// Loads the auth from storage using this manager's configuration.
///
/// A load error is treated the same as "no auth present": both yield `None`.
fn load_auth_from_storage(&self) -> Option<CodexAuth> {
    let loaded = load_auth(
        &self.codex_home,
        self.enable_codex_api_key_env,
        self.auth_credentials_store_mode,
    );
    match loaded {
        Ok(auth) => auth,
        Err(_) => None,
    }
}
/// Replaces the cached auth with `new_auth`.
///
/// Returns `true` when the new value differs from the previously cached one
/// (per `AuthManager::auths_equal`). If the write lock cannot be acquired,
/// nothing is changed and `false` is returned.
fn set_auth(&self, new_auth: Option<CodexAuth>) -> bool {
    let Ok(mut guard) = self.inner.write() else {
        return false;
    };
    // Compare before overwriting so the log and the return value reflect the
    // transition from the old value to the new one.
    let changed = !AuthManager::auths_equal(&guard.auth, &new_auth);
    tracing::info!("Reloaded auth, changed: {changed}");
    guard.auth = new_auth;
    changed
}
/// Convenience constructor returning an `Arc` wrapper.
pub fn shared(
codex_home: PathBuf,
@@ -670,22 +785,27 @@ impl AuthManager {
))
}
/// Creates a fresh `UnauthorizedRecovery` state machine backed by this
/// manager. The recovery holds its own `Arc` clone of the manager.
pub fn unauthorized_recovery(self: &Arc<Self>) -> UnauthorizedRecovery {
UnauthorizedRecovery::new(Arc::clone(self))
}
/// Attempt to refresh the current auth token (if any). On success, reload
/// the auth state from disk so other components observe refreshed token.
/// If the token refresh fails, returns the error to the caller.
pub async fn refresh_token(&self) -> Result<Option<String>, RefreshTokenError> {
pub async fn refresh_token(&self) -> Result<(), RefreshTokenError> {
tracing::info!("Refreshing token");
let auth = match self.auth_cached() {
Some(auth) => auth,
None => return Ok(None),
None => return Ok(()),
};
tracing::info!("Refreshing token");
let token_data = auth.get_current_token_data().ok_or_else(|| {
RefreshTokenError::Transient(std::io::Error::other("Token data is not available."))
})?;
let access = self.refresh_tokens(&auth, token_data.refresh_token).await?;
self.refresh_tokens(&auth, token_data.refresh_token).await?;
// Reload to pick up persisted changes.
self.reload();
Ok(Some(access))
Ok(())
}
/// Log out by deleting the ondisk auth.json (if present). Returns Ok(true)
@@ -732,10 +852,10 @@ impl AuthManager {
&self,
auth: &CodexAuth,
refresh_token: String,
) -> Result<String, RefreshTokenError> {
) -> Result<(), RefreshTokenError> {
let refresh_response = try_refresh_token(refresh_token, &auth.client).await?;
let updated = update_tokens(
update_tokens(
&auth.storage,
refresh_response.id_token,
refresh_response.access_token,
@@ -744,12 +864,7 @@ impl AuthManager {
.await
.map_err(RefreshTokenError::from)?;
match updated.tokens {
Some(tokens) => Ok(tokens.access_token),
None => Err(RefreshTokenError::other_with_message(
"Token data is not available after refresh.",
)),
}
Ok(())
}
}

View File

@@ -2,6 +2,7 @@ use std::sync::Arc;
use crate::api_bridge::auth_provider_from_auth;
use crate::api_bridge::map_api_error;
use crate::auth::UnauthorizedRecovery;
use codex_api::AggregateStreamExt;
use codex_api::ChatClient as ApiChatClient;
use codex_api::CompactClient as ApiCompactClient;
@@ -20,6 +21,7 @@ use codex_api::error::ApiError;
use codex_api::requests::responses::Compression;
use codex_app_server_protocol::AuthMode;
use codex_otel::OtelManager;
use codex_protocol::ThreadId;
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
use codex_protocol::models::ResponseItem;
@@ -155,7 +157,9 @@ impl ModelClient {
let conversation_id = self.conversation_id.to_string();
let session_source = self.session_source.clone();
let mut refreshed = false;
let mut auth_recovery = auth_manager
.as_ref()
.map(super::auth::AuthManager::unauthorized_recovery);
loop {
let auth = match auth_manager.as_ref() {
Some(manager) => manager.auth().await,
@@ -184,7 +188,7 @@ impl ModelClient {
Err(ApiError::Transport(TransportError::Http { status, .. }))
if status == StatusCode::UNAUTHORIZED =>
{
handle_unauthorized(status, &mut refreshed, &auth_manager, &auth).await?;
handle_unauthorized(status, &mut auth_recovery).await?;
continue;
}
Err(err) => return Err(map_api_error(err)),
@@ -246,7 +250,9 @@ impl ModelClient {
let conversation_id = self.conversation_id.to_string();
let session_source = self.session_source.clone();
let mut refreshed = false;
let mut auth_recovery = auth_manager
.as_ref()
.map(super::auth::AuthManager::unauthorized_recovery);
loop {
let auth = match auth_manager.as_ref() {
Some(manager) => manager.auth().await,
@@ -298,7 +304,7 @@ impl ModelClient {
Err(ApiError::Transport(TransportError::Http { status, .. }))
if status == StatusCode::UNAUTHORIZED =>
{
handle_unauthorized(status, &mut refreshed, &auth_manager, &auth).await?;
handle_unauthorized(status, &mut auth_recovery).await?;
continue;
}
Err(err) => return Err(map_api_error(err)),
@@ -509,29 +515,19 @@ where
/// the mapped `CodexErr` is returned to the caller.
async fn handle_unauthorized(
status: StatusCode,
refreshed: &mut bool,
auth_manager: &Option<Arc<AuthManager>>,
auth: &Option<crate::auth::CodexAuth>,
auth_recovery: &mut Option<UnauthorizedRecovery>,
) -> Result<()> {
if *refreshed {
return Err(map_unauthorized_status(status));
}
if let Some(manager) = auth_manager.as_ref()
&& let Some(auth) = auth.as_ref()
&& auth.mode == AuthMode::ChatGPT
if let Some(recovery) = auth_recovery
&& recovery.has_next()
{
match manager.refresh_token().await {
Ok(_) => {
*refreshed = true;
Ok(())
}
return match recovery.next().await {
Ok(_) => Ok(()),
Err(RefreshTokenError::Permanent(failed)) => Err(CodexErr::RefreshTokenFailed(failed)),
Err(RefreshTokenError::Transient(other)) => Err(CodexErr::Io(other)),
}
} else {
Err(map_unauthorized_status(status))
};
}
Err(map_unauthorized_status(status))
}
fn map_unauthorized_status(status: StatusCode) -> CodexErr {

View File

@@ -16,8 +16,10 @@ use codex_core::token_data::TokenData;
use core_test_support::skip_if_no_network;
use pretty_assertions::assert_eq;
use serde::Serialize;
use serde_json::Value;
use serde_json::json;
use std::ffi::OsString;
use std::sync::Arc;
use tempfile::TempDir;
use wiremock::Mock;
use wiremock::MockServer;
@@ -54,12 +56,10 @@ async fn refresh_token_succeeds_updates_storage() -> Result<()> {
};
ctx.write_auth(&initial_auth)?;
let access = ctx
.auth_manager
ctx.auth_manager
.refresh_token()
.await
.context("refresh should succeed")?;
assert_eq!(access, Some("new-access-token".to_string()));
let refreshed_tokens = TokenData {
access_token: "new-access-token".to_string(),
@@ -294,9 +294,218 @@ async fn refresh_token_returns_transient_error_on_server_failure() -> Result<()>
Ok(())
}
/// Verifies the two-step recovery path: the first `next()` reloads the auth
/// from disk without contacting the token endpoint, and the second `next()`
/// performs the token refresh, after which no recovery step remains.
#[serial_test::serial(auth_refresh)]
#[tokio::test]
async fn unauthorized_recovery_reloads_then_refreshes_tokens() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = MockServer::start().await;
// The refresh endpoint must be hit exactly once over the whole test.
Mock::given(method("POST"))
.and(path("/oauth/token"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"access_token": "recovered-access-token",
"refresh_token": "recovered-refresh-token"
})))
.expect(1)
.mount(&server)
.await;
let ctx = RefreshTokenTestContext::new(&server)?;
let initial_last_refresh = Utc::now() - Duration::days(1);
let initial_tokens = build_tokens(INITIAL_ACCESS_TOKEN, INITIAL_REFRESH_TOKEN);
let initial_auth = AuthDotJson {
openai_api_key: None,
tokens: Some(initial_tokens.clone()),
last_refresh: Some(initial_last_refresh),
};
// NOTE(review): presumably `write_auth` also makes the manager pick up the
// written auth (the cache assertion below relies on it) — confirm.
ctx.write_auth(&initial_auth)?;
// Write different tokens straight to disk so the file and the manager's
// cache diverge.
let disk_tokens = build_tokens("disk-access-token", "disk-refresh-token");
let disk_auth = AuthDotJson {
openai_api_key: None,
tokens: Some(disk_tokens.clone()),
last_refresh: Some(initial_last_refresh),
};
save_auth(
ctx.codex_home.path(),
&disk_auth,
AuthCredentialsStoreMode::File,
)?;
// Sanity check: the cache still holds the initial tokens.
let cached_before = ctx
.auth_manager
.auth_cached()
.expect("auth should be cached");
let cached_before_tokens = cached_before
.get_token_data()
.context("token data should be cached")?;
assert_eq!(cached_before_tokens, initial_tokens);
let mut recovery = ctx.auth_manager.unauthorized_recovery();
assert!(recovery.has_next());
// Step 1: reload from disk.
recovery.next().await?;
let cached_after = ctx
.auth_manager
.auth_cached()
.expect("auth should be cached after reload");
let cached_after_tokens = cached_after
.get_token_data()
.context("token data should reload")?;
assert_eq!(cached_after_tokens, disk_tokens);
// The reload step must not have contacted the token endpoint.
let requests = server.received_requests().await.unwrap_or_default();
assert!(requests.is_empty(), "expected no refresh token requests");
// Step 2: refresh the token.
recovery.next().await?;
let refreshed_tokens = TokenData {
access_token: "recovered-access-token".to_string(),
refresh_token: "recovered-refresh-token".to_string(),
..disk_tokens.clone()
};
// Both the stored auth and the manager's auth must carry the refreshed tokens.
let stored = ctx.load_auth()?;
let tokens = stored.tokens.as_ref().context("tokens should exist")?;
assert_eq!(tokens, &refreshed_tokens);
let cached_auth = ctx
.auth_manager
.auth()
.await
.expect("auth should be cached");
let cached_tokens = cached_auth
.get_token_data()
.context("token data should be cached")?;
assert_eq!(cached_tokens, refreshed_tokens);
// Recovery is exhausted after the refresh.
assert!(!recovery.has_next());
server.verify().await;
Ok(())
}
/// Verifies that when the auth on disk belongs to a different account, the
/// reload is skipped and the first `next()` refreshes the token immediately,
/// using the refresh token of the auth the process started with.
#[serial_test::serial(auth_refresh)]
#[tokio::test]
async fn unauthorized_recovery_skips_reload_on_account_mismatch() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = MockServer::start().await;
// The refresh endpoint must be hit exactly once.
Mock::given(method("POST"))
.and(path("/oauth/token"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"access_token": "recovered-access-token",
"refresh_token": "recovered-refresh-token"
})))
.expect(1)
.mount(&server)
.await;
let ctx = RefreshTokenTestContext::new(&server)?;
let initial_last_refresh = Utc::now() - Duration::days(1);
let initial_tokens = build_tokens(INITIAL_ACCESS_TOKEN, INITIAL_REFRESH_TOKEN);
let initial_auth = AuthDotJson {
openai_api_key: None,
tokens: Some(initial_tokens.clone()),
last_refresh: Some(initial_last_refresh),
};
// NOTE(review): presumably `write_auth` also makes the manager pick up the
// written auth (the cache assertion below relies on it) — confirm.
ctx.write_auth(&initial_auth)?;
// Put tokens for a different account on disk.
let mut disk_tokens = build_tokens("disk-access-token", "disk-refresh-token");
disk_tokens.account_id = Some("other-account".to_string());
// Expected final state: the on-disk token data with only the access and
// refresh tokens replaced by the values returned from the mock endpoint.
let expected_tokens = TokenData {
access_token: "recovered-access-token".to_string(),
refresh_token: "recovered-refresh-token".to_string(),
..disk_tokens.clone()
};
let disk_auth = AuthDotJson {
openai_api_key: None,
tokens: Some(disk_tokens),
last_refresh: Some(initial_last_refresh),
};
save_auth(
ctx.codex_home.path(),
&disk_auth,
AuthCredentialsStoreMode::File,
)?;
// Sanity check: the cache still holds the initial tokens.
let cached_before = ctx
.auth_manager
.auth_cached()
.expect("auth should be cached");
let cached_before_tokens = cached_before
.get_token_data()
.context("token data should be cached")?;
assert_eq!(cached_before_tokens, initial_tokens);
let mut recovery = ctx.auth_manager.unauthorized_recovery();
assert!(recovery.has_next());
// Single step: the reload is skipped, so the refresh happens right away.
recovery.next().await?;
let stored = ctx.load_auth()?;
let tokens = stored.tokens.as_ref().context("tokens should exist")?;
assert_eq!(tokens, &expected_tokens);
// The refresh request must have used the initial (cached) refresh token,
// not the one that was written to disk.
let requests = server.received_requests().await.unwrap_or_default();
let request = requests
.first()
.context("expected a refresh token request")?;
let body: Value =
serde_json::from_slice(&request.body).context("refresh request body should be json")?;
let refresh_token = body
.get("refresh_token")
.and_then(Value::as_str)
.context("refresh_token should be set")?;
assert_eq!(refresh_token, INITIAL_REFRESH_TOKEN);
let cached_after = ctx
.auth_manager
.auth()
.await
.context("auth should remain cached after refresh")?;
let cached_after_tokens = cached_after
.get_token_data()
.context("token data should reflect refreshed tokens")?;
assert_eq!(cached_after_tokens, expected_tokens);
// Recovery is exhausted after the single refresh.
assert!(!recovery.has_next());
server.verify().await;
Ok(())
}
/// Verifies that recovery is unavailable for API-key auth: `has_next()` is
/// `false`, `next()` fails with reason `Other`, and the token endpoint is
/// never contacted.
#[serial_test::serial(auth_refresh)]
#[tokio::test]
async fn unauthorized_recovery_requires_chatgpt_auth() -> Result<()> {
skip_if_no_network!(Ok(()));
// No mock is mounted: any request to the server would be unexpected.
let server = MockServer::start().await;
let ctx = RefreshTokenTestContext::new(&server)?;
// API-key auth with no token data.
let auth = AuthDotJson {
openai_api_key: Some("sk-test".to_string()),
tokens: None,
last_refresh: None,
};
ctx.write_auth(&auth)?;
let mut recovery = ctx.auth_manager.unauthorized_recovery();
assert!(!recovery.has_next());
// Calling `next()` anyway must fail rather than silently succeed.
let err = recovery
.next()
.await
.err()
.context("recovery should fail")?;
assert_eq!(err.failed_reason(), Some(RefreshTokenFailedReason::Other));
let requests = server.received_requests().await.unwrap_or_default();
assert!(requests.is_empty(), "expected no refresh token requests");
Ok(())
}
struct RefreshTokenTestContext {
codex_home: TempDir,
auth_manager: AuthManager,
auth_manager: Arc<AuthManager>,
_env_guard: EnvGuard,
}
@@ -307,7 +516,7 @@ impl RefreshTokenTestContext {
let endpoint = format!("{}/oauth/token", server.uri());
let env_guard = EnvGuard::set(REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR, endpoint);
let auth_manager = AuthManager::new(
let auth_manager = AuthManager::shared(
codex_home.path().to_path_buf(),
false,
AuthCredentialsStoreMode::File,