Compare commits

...

3 Commits

Author SHA1 Message Date
Michael Bolin
c1c63e8f83 core: support dynamic auth tokens for model providers 2026-03-30 16:55:43 -07:00
Michael Bolin
4127419694 auth: let AuthManager own external bearer auth 2026-03-30 16:27:55 -07:00
Michael Bolin
b397919da1 auth: generalize external auth tokens for bearer-only sources 2026-03-30 16:20:45 -07:00
22 changed files with 1406 additions and 98 deletions

View File

@@ -144,11 +144,11 @@ impl ExternalAuthRefresher for ExternalAuthRefreshBridge {
let response: ChatgptAuthTokensRefreshResponse =
serde_json::from_value(result).map_err(std::io::Error::other)?;
Ok(ExternalAuthTokens {
access_token: response.access_token,
chatgpt_account_id: response.chatgpt_account_id,
chatgpt_plan_type: response.chatgpt_plan_type,
})
Ok(ExternalAuthTokens::chatgpt(
response.access_token,
response.chatgpt_account_id,
response.chatgpt_plan_type,
))
}
}

View File

@@ -816,10 +816,63 @@
},
"type": "object"
},
"ModelProviderAuthInfo": {
"additionalProperties": false,
"description": "Configuration for obtaining a provider bearer token from a command.",
"properties": {
"args": {
"default": [],
"description": "Command arguments.",
"items": {
"type": "string"
},
"type": "array"
},
"command": {
"description": "Command to execute. Bare names are resolved via `PATH`; paths are resolved against `cwd`.",
"type": "string"
},
"cwd": {
"allOf": [
{
"$ref": "#/definitions/AbsolutePathBuf"
}
],
"default": "/Users/mbolin/code/codex/codex-rs",
"description": "Working directory used when running the token command."
},
"refresh_interval_ms": {
"default": 300000,
"description": "Maximum age for the cached token before rerunning the command.",
"format": "uint64",
"minimum": 1.0,
"type": "integer"
},
"timeout_ms": {
"default": 5000,
"description": "Maximum time to wait for the token command to exit successfully.",
"format": "uint64",
"minimum": 1.0,
"type": "integer"
}
},
"required": [
"command"
],
"type": "object"
},
"ModelProviderInfo": {
"additionalProperties": false,
"description": "Serializable representation of a provider definition.",
"properties": {
"auth": {
"allOf": [
{
"$ref": "#/definitions/ModelProviderAuthInfo"
}
],
"description": "Command-backed bearer-token configuration for this provider."
},
"base_url": {
"description": "Base URL for the provider's OpenAI-compatible API.",
"type": "string"

View File

@@ -64,6 +64,7 @@ mod tests {
env_key: Some("sk-should-not-leak".to_string()),
env_key_instructions: None,
experimental_bearer_token: None,
auth: None,
wire_api: crate::model_provider_info::WireApi::Responses,
query_params: None,
http_headers: None,

View File

@@ -104,6 +104,7 @@ use crate::error::Result;
use crate::flags::CODEX_RS_SSE_FIXTURE;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;
use crate::provider_auth::scoped_auth_manager_for_provider;
use crate::response_debug_context::extract_response_debug_context;
use crate::response_debug_context::extract_response_debug_context_from_api_error;
use crate::response_debug_context::telemetry_api_error_message;
@@ -261,6 +262,7 @@ impl ModelClient {
include_timing_metrics: bool,
beta_features_header: Option<String>,
) -> Self {
let auth_manager = scoped_auth_manager_for_provider(auth_manager, &provider);
let codex_api_key_env_enabled = auth_manager
.as_ref()
.is_some_and(|manager| manager.codex_api_key_env_enabled());
@@ -294,6 +296,10 @@ impl ModelClient {
}
}
pub(crate) fn auth_manager(&self) -> Option<Arc<AuthManager>> {
self.state.auth_manager.clone()
}
fn take_cached_websocket_session(&self) -> WebsocketSession {
let mut cached_websocket_session = self
.state

View File

@@ -243,6 +243,26 @@ web_search = false
);
}
#[test]
fn rejects_provider_auth_with_env_key() {
let err = toml::from_str::<ConfigToml>(
r#"
[model_providers.corp]
name = "Corp"
env_key = "CORP_TOKEN"
[model_providers.corp.auth]
command = "print-token"
"#,
)
.unwrap_err();
assert!(
err.to_string()
.contains("model_providers.corp: provider auth cannot be combined with env_key")
);
}
#[test]
fn config_toml_deserializes_model_availability_nux() {
let toml = r#"
@@ -4315,6 +4335,7 @@ model_verbosity = "high"
wire_api: crate::WireApi::Responses,
env_key_instructions: None,
experimental_bearer_token: None,
auth: None,
query_params: None,
http_headers: None,
env_http_headers: None,

View File

@@ -1837,6 +1837,18 @@ Built-in providers cannot be overridden. Rename your custom provider (for exampl
}
}
fn validate_model_providers(
model_providers: &HashMap<String, ModelProviderInfo>,
) -> Result<(), String> {
validate_reserved_model_provider_ids(model_providers)?;
for (key, provider) in model_providers {
provider
.validate()
.map_err(|message| format!("model_providers.{key}: {message}"))?;
}
Ok(())
}
fn deserialize_model_providers<'de, D>(
deserializer: D,
) -> Result<HashMap<String, ModelProviderInfo>, D::Error>
@@ -1844,7 +1856,7 @@ where
D: serde::Deserializer<'de>,
{
let model_providers = HashMap::<String, ModelProviderInfo>::deserialize(deserializer)?;
validate_reserved_model_provider_ids(&model_providers).map_err(serde::de::Error::custom)?;
validate_model_providers(&model_providers).map_err(serde::de::Error::custom)?;
Ok(model_providers)
}
@@ -1969,7 +1981,7 @@ impl Config {
codex_home: PathBuf,
config_layer_stack: ConfigLayerStack,
) -> std::io::Result<Self> {
validate_reserved_model_provider_ids(&cfg.model_providers)
validate_model_providers(&cfg.model_providers)
.map_err(|message| std::io::Error::new(std::io::ErrorKind::InvalidInput, message))?;
// Ensure that every field of ConfigRequirements is applied to the final
// Config.

View File

@@ -65,6 +65,7 @@ pub mod utils;
pub use utils::path_utils;
pub mod personality_migration;
pub mod plugins;
mod provider_auth;
pub(crate) mod mentions {
pub(crate) use crate::plugins::build_connector_slug_counts;
pub(crate) use crate::plugins::build_skill_name_counts;
@@ -107,6 +108,7 @@ pub use client::X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER;
pub use model_provider_info::DEFAULT_LMSTUDIO_PORT;
pub use model_provider_info::DEFAULT_OLLAMA_PORT;
pub use model_provider_info::LMSTUDIO_OSS_PROVIDER_ID;
pub use model_provider_info::ModelProviderAuthInfo;
pub use model_provider_info::ModelProviderInfo;
pub use model_provider_info::OLLAMA_OSS_PROVIDER_ID;
pub use model_provider_info::OPENAI_PROVIDER_ID;

View File

@@ -9,6 +9,7 @@ use crate::auth::AuthMode;
use crate::error::EnvVarError;
use codex_api::Provider as ApiProvider;
use codex_api::provider::RetryConfig as ApiRetryConfig;
use codex_utils_absolute_path::AbsolutePathBuf;
use http::HeaderMap;
use http::header::HeaderName;
use http::header::HeaderValue;
@@ -17,8 +18,11 @@ use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::fmt;
use std::num::NonZeroU64;
use std::time::Duration;
const DEFAULT_PROVIDER_AUTH_TIMEOUT_MS: u64 = 5_000;
const DEFAULT_PROVIDER_AUTH_REFRESH_INTERVAL_MS: u64 = 300_000;
const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000;
const DEFAULT_STREAM_MAX_RETRIES: u64 = 5;
const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
@@ -66,6 +70,73 @@ impl<'de> Deserialize<'de> for WireApi {
}
}
/// Configuration for obtaining a provider bearer token from a command.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, JsonSchema)]
#[schemars(deny_unknown_fields)]
pub struct ModelProviderAuthInfo {
/// Command to execute. Bare names are resolved via `PATH`; paths are resolved against `cwd`.
pub command: String,
/// Command arguments.
#[serde(default)]
pub args: Vec<String>,
/// Maximum time to wait for the token command to exit successfully.
#[serde(default = "default_provider_auth_timeout_ms")]
pub timeout_ms: NonZeroU64,
/// Maximum age for the cached token before rerunning the command.
#[serde(default = "default_provider_auth_refresh_interval_ms")]
pub refresh_interval_ms: NonZeroU64,
/// Working directory used when running the token command.
#[serde(default = "default_provider_auth_cwd")]
pub cwd: AbsolutePathBuf,
}
impl ModelProviderAuthInfo {
pub(crate) fn timeout(&self) -> Duration {
Duration::from_millis(self.timeout_ms.get())
}
pub(crate) fn refresh_interval(&self) -> Duration {
Duration::from_millis(self.refresh_interval_ms.get())
}
}
fn default_provider_auth_timeout_ms() -> NonZeroU64 {
non_zero_u64(
DEFAULT_PROVIDER_AUTH_TIMEOUT_MS,
"model_providers.<id>.auth.timeout_ms",
)
}
fn default_provider_auth_refresh_interval_ms() -> NonZeroU64 {
non_zero_u64(
DEFAULT_PROVIDER_AUTH_REFRESH_INTERVAL_MS,
"model_providers.<id>.auth.refresh_interval_ms",
)
}
fn non_zero_u64(value: u64, field_name: &str) -> NonZeroU64 {
match NonZeroU64::new(value) {
Some(value) => value,
None => panic!("{field_name} must be non-zero"),
}
}
fn default_provider_auth_cwd() -> AbsolutePathBuf {
let deserializer = serde::de::value::StrDeserializer::<serde::de::value::Error>::new(".");
if let Ok(cwd) = AbsolutePathBuf::deserialize(deserializer) {
return cwd;
}
match AbsolutePathBuf::current_dir() {
Ok(cwd) => cwd,
Err(err) => panic!("provider auth cwd must resolve: {err}"),
}
}
/// Serializable representation of a provider definition.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, JsonSchema)]
#[schemars(deny_unknown_fields)]
@@ -86,6 +157,9 @@ pub struct ModelProviderInfo {
/// this may be necessary when using this programmatically.
pub experimental_bearer_token: Option<String>,
/// Command-backed bearer-token configuration for this provider.
pub auth: Option<ModelProviderAuthInfo>,
/// Which wire protocol this provider expects.
#[serde(default)]
pub wire_api: WireApi,
@@ -130,6 +204,36 @@ pub struct ModelProviderInfo {
}
impl ModelProviderInfo {
pub(crate) fn validate(&self) -> std::result::Result<(), String> {
let Some(auth) = self.auth.as_ref() else {
return Ok(());
};
if auth.command.trim().is_empty() {
return Err("provider auth.command must not be empty".to_string());
}
let mut conflicts = Vec::new();
if self.env_key.is_some() {
conflicts.push("env_key");
}
if self.experimental_bearer_token.is_some() {
conflicts.push("experimental_bearer_token");
}
if self.requires_openai_auth {
conflicts.push("requires_openai_auth");
}
if conflicts.is_empty() {
Ok(())
} else {
Err(format!(
"provider auth cannot be combined with {}",
conflicts.join(", ")
))
}
}
fn build_header_map(&self) -> crate::error::Result<HeaderMap> {
let capacity = self.http_headers.as_ref().map_or(0, HashMap::len)
+ self.env_http_headers.as_ref().map_or(0, HashMap::len);
@@ -246,6 +350,7 @@ impl ModelProviderInfo {
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
auth: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: Some(
@@ -277,6 +382,10 @@ impl ModelProviderInfo {
pub fn is_openai(&self) -> bool {
self.name == OPENAI_PROVIDER_NAME
}
pub(crate) fn has_command_auth(&self) -> bool {
self.auth.is_some()
}
}
pub const DEFAULT_LMSTUDIO_PORT: u16 = 1234;
@@ -338,6 +447,7 @@ pub fn create_oss_provider_with_base_url(base_url: &str, wire_api: WireApi) -> M
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
auth: None,
wire_api,
query_params: None,
http_headers: None,

View File

@@ -1,5 +1,8 @@
use super::*;
use codex_utils_absolute_path::AbsolutePathBuf;
use codex_utils_absolute_path::AbsolutePathBufGuard;
use pretty_assertions::assert_eq;
use tempfile::tempdir;
#[test]
fn test_deserialize_ollama_model_provider_toml() {
@@ -13,6 +16,7 @@ base_url = "http://localhost:11434/v1"
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
auth: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
@@ -43,6 +47,7 @@ query_params = { api-version = "2025-04-01-preview" }
env_key: Some("AZURE_OPENAI_API_KEY".into()),
env_key_instructions: None,
experimental_bearer_token: None,
auth: None,
wire_api: WireApi::Responses,
query_params: Some(maplit::hashmap! {
"api-version".to_string() => "2025-04-01-preview".to_string(),
@@ -76,6 +81,7 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
env_key: Some("API_KEY".into()),
env_key_instructions: None,
experimental_bearer_token: None,
auth: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: Some(maplit::hashmap! {
@@ -121,3 +127,29 @@ supports_websockets = true
let provider: ModelProviderInfo = toml::from_str(provider_toml).unwrap();
assert_eq!(provider.websocket_connect_timeout_ms, Some(15_000));
}
#[test]
fn test_deserialize_provider_auth_config_defaults() {
let base_dir = tempdir().unwrap();
let provider_toml = r#"
name = "Corp"
[auth]
command = "./scripts/print-token"
args = ["--format=text"]
"#;
let _guard = AbsolutePathBufGuard::new(base_dir.path());
let provider: ModelProviderInfo = toml::from_str(provider_toml).unwrap();
assert_eq!(
provider.auth,
Some(ModelProviderAuthInfo {
command: "./scripts/print-token".to_string(),
args: vec!["--format=text".to_string()],
timeout_ms: default_provider_auth_timeout_ms(),
refresh_interval_ms: default_provider_auth_refresh_interval_ms(),
cwd: AbsolutePathBuf::resolve_path_against_base(".", base_dir.path()).unwrap(),
})
);
}

View File

@@ -4,6 +4,8 @@ use crate::api_bridge::map_api_error;
use crate::auth::AuthManager;
use crate::auth::AuthMode;
use crate::auth::CodexAuth;
use crate::auth::RefreshTokenError;
use crate::auth::UnauthorizedRecovery;
use crate::auth_env_telemetry::AuthEnvTelemetry;
use crate::auth_env_telemetry::collect_auth_env_telemetry;
use crate::config::Config;
@@ -14,6 +16,7 @@ use crate::model_provider_info::ModelProviderInfo;
use crate::models_manager::collaboration_mode_presets::CollaborationModesConfig;
use crate::models_manager::collaboration_mode_presets::builtin_collaboration_mode_presets;
use crate::models_manager::model_info;
use crate::provider_auth::scoped_auth_manager;
use crate::response_debug_context::extract_response_debug_context;
use crate::response_debug_context::telemetry_transport_error_message;
use crate::util::FeedbackRequestTags;
@@ -212,6 +215,7 @@ impl ModelsManager {
collaboration_modes_config: CollaborationModesConfig,
provider: ModelProviderInfo,
) -> Self {
let auth_manager = scoped_auth_manager(auth_manager, &provider);
let cache_path = codex_home.join(MODEL_CACHE_FILE);
let cache_manager = ModelsCacheManager::new(cache_path, DEFAULT_MODEL_CACHE_TTL);
let catalog_mode = if model_catalog.is_some() {
@@ -396,7 +400,9 @@ impl ModelsManager {
return Ok(());
}
if self.auth_manager.auth_mode() != Some(AuthMode::Chatgpt) {
if self.auth_manager.auth_mode() != Some(AuthMode::Chatgpt)
&& !self.provider.has_command_auth()
{
if matches!(
refresh_strategy,
RefreshStrategy::Offline | RefreshStrategy::OnlineIfUncached
@@ -431,39 +437,72 @@ impl ModelsManager {
async fn fetch_and_update_models(&self) -> CoreResult<()> {
let _timer =
codex_otel::start_global_timer("codex.remote_models.fetch_update.duration_ms", &[]);
let auth = self.auth_manager.auth().await;
let auth_mode = auth.as_ref().map(CodexAuth::auth_mode);
let api_provider = self.provider.to_api_provider(auth_mode)?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider)?;
let auth_env = collect_auth_env_telemetry(
&self.provider,
self.auth_manager.codex_api_key_env_enabled(),
);
let transport = ReqwestTransport::new(build_reqwest_client());
let request_telemetry: Arc<dyn RequestTelemetry> = Arc::new(ModelsRequestTelemetry {
auth_mode: auth_mode.map(|mode| TelemetryAuthMode::from(mode).to_string()),
auth_header_attached: api_auth.auth_header_attached(),
auth_header_name: api_auth.auth_header_name(),
auth_env,
});
let client = ModelsClient::new(transport, api_provider, api_auth)
.with_telemetry(Some(request_telemetry));
let client_version = crate::models_manager::client_version_to_whole();
let (models, etag) = timeout(
MODELS_REFRESH_TIMEOUT,
client.list_models(&client_version, HeaderMap::new()),
)
.await
.map_err(|_| CodexErr::Timeout)?
.map_err(map_api_error)?;
let mut auth_recovery = self.auth_manager.unauthorized_recovery();
self.apply_remote_models(models.clone()).await;
*self.etag.write().await = etag.clone();
self.cache_manager
.persist_cache(&models, etag, client_version)
.await;
Ok(())
loop {
let auth = self.auth_manager.auth().await;
let auth_mode = auth.as_ref().map(CodexAuth::auth_mode);
let api_provider = self.provider.to_api_provider(auth_mode)?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider)?;
let auth_env = collect_auth_env_telemetry(
&self.provider,
self.auth_manager.codex_api_key_env_enabled(),
);
let transport = ReqwestTransport::new(build_reqwest_client());
let request_telemetry: Arc<dyn RequestTelemetry> = Arc::new(ModelsRequestTelemetry {
auth_mode: auth_mode.map(|mode| TelemetryAuthMode::from(mode).to_string()),
auth_header_attached: api_auth.auth_header_attached(),
auth_header_name: api_auth.auth_header_name(),
auth_env,
});
let client = ModelsClient::new(transport, api_provider, api_auth)
.with_telemetry(Some(request_telemetry));
let result = timeout(
MODELS_REFRESH_TIMEOUT,
client.list_models(&client_version, HeaderMap::new()),
)
.await
.map_err(|_| CodexErr::Timeout)?
.map_err(map_api_error);
match result {
Ok((models, etag)) => {
self.apply_remote_models(models.clone()).await;
*self.etag.write().await = etag.clone();
self.cache_manager
.persist_cache(&models, etag, client_version)
.await;
return Ok(());
}
Err(
err @ CodexErr::UnexpectedStatus(crate::error::UnexpectedResponseError {
status,
..
}),
) if status == http::StatusCode::UNAUTHORIZED => {
if !Self::recover_after_unauthorized(&mut auth_recovery).await? {
return Err(err);
}
}
Err(err) => return Err(err),
}
}
}
async fn recover_after_unauthorized(
auth_recovery: &mut UnauthorizedRecovery,
) -> CoreResult<bool> {
if !auth_recovery.has_next() {
return Ok(false);
}
match auth_recovery.next().await {
Ok(_) => Ok(true),
Err(RefreshTokenError::Permanent(failed)) => Err(CodexErr::RefreshTokenFailed(failed)),
Err(RefreshTokenError::Transient(other)) => Err(CodexErr::Io(other)),
}
}
async fn get_etag(&self) -> Option<String> {

View File

@@ -1,7 +1,9 @@
use super::*;
use crate::AuthManager;
use crate::CodexAuth;
use crate::auth::AuthCredentialsStoreMode;
use crate::config::ConfigBuilder;
use crate::model_provider_info::ModelProviderAuthInfo;
use crate::model_provider_info::WireApi;
use base64::Engine as _;
use chrono::Utc;
@@ -13,8 +15,10 @@ use http::StatusCode;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::collections::BTreeMap;
use std::num::NonZeroU64;
use std::sync::Arc;
use std::sync::Mutex;
use tempfile::TempDir;
use tempfile::tempdir;
use tracing::Event;
use tracing::Subscriber;
@@ -24,7 +28,12 @@ use tracing_subscriber::layer::Context;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::registry::LookupSpan;
use tracing_subscriber::util::SubscriberInitExt;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::header_regex;
use wiremock::matchers::method;
use wiremock::matchers::path;
fn remote_model(slug: &str, display: &str, priority: i32) -> ModelInfo {
remote_model_with_visibility(slug, display, priority, "list")
@@ -79,6 +88,7 @@ fn provider_for(base_url: String) -> ModelProviderInfo {
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
auth: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
@@ -92,6 +102,90 @@ fn provider_for(base_url: String) -> ModelProviderInfo {
}
}
struct ProviderAuthScript {
tempdir: TempDir,
command: String,
args: Vec<String>,
}
impl ProviderAuthScript {
fn new(tokens: &[&str]) -> std::io::Result<Self> {
let tempdir = tempfile::tempdir()?;
let tokens_file = tempdir.path().join("tokens.txt");
std::fs::write(&tokens_file, format!("{}\n", tokens.join("\n")))?;
#[cfg(unix)]
let (command, args) = {
let script_path = tempdir.path().join("print-token.sh");
std::fs::write(
&script_path,
r#"#!/bin/sh
first_line=$(sed -n '1p' tokens.txt)
printf '%s\n' "$first_line"
tail -n +2 tokens.txt > tokens.next
mv tokens.next tokens.txt
"#,
)?;
let mut permissions = std::fs::metadata(&script_path)?.permissions();
{
use std::os::unix::fs::PermissionsExt;
permissions.set_mode(0o755);
}
std::fs::set_permissions(&script_path, permissions)?;
("./print-token.sh".to_string(), Vec::new())
};
#[cfg(windows)]
let (command, args) = {
let script_path = tempdir.path().join("print-token.ps1");
std::fs::write(
&script_path,
r#"$lines = Get-Content -Path tokens.txt
if ($lines.Count -eq 0) { exit 1 }
Write-Output $lines[0]
$lines | Select-Object -Skip 1 | Set-Content -Path tokens.txt
"#,
)?;
(
"powershell".to_string(),
vec![
"-NoProfile".to_string(),
"-ExecutionPolicy".to_string(),
"Bypass".to_string(),
"-File".to_string(),
".\\print-token.ps1".to_string(),
],
)
};
Ok(Self {
tempdir,
command,
args,
})
}
fn auth_config(&self) -> ModelProviderAuthInfo {
ModelProviderAuthInfo {
command: self.command.clone(),
args: self.args.clone(),
timeout_ms: non_zero_u64(/*value*/ 1_000),
refresh_interval_ms: non_zero_u64(/*value*/ 60_000),
cwd: match codex_utils_absolute_path::AbsolutePathBuf::try_from(self.tempdir.path()) {
Ok(cwd) => cwd,
Err(err) => panic!("tempdir should be absolute: {err}"),
},
}
}
}
fn non_zero_u64(value: u64) -> NonZeroU64 {
match NonZeroU64::new(value) {
Some(value) => value,
None => panic!("expected non-zero value: {value}"),
}
}
#[derive(Default)]
struct TagCollectorVisitor {
tags: BTreeMap<String, String>,
@@ -310,6 +404,57 @@ async fn refresh_available_models_sorts_by_priority() {
);
}
#[tokio::test]
async fn refresh_available_models_refreshes_provider_auth_after_401() {
let server = MockServer::start().await;
let auth_script = ProviderAuthScript::new(&["first-token", "second-token"]).unwrap();
let remote_models = vec![remote_model(
"provider-model",
"Provider",
/*priority*/ 0,
)];
Mock::given(method("GET"))
.and(path("/models"))
.and(header_regex("Authorization", "Bearer first-token"))
.respond_with(ResponseTemplate::new(401).set_body_string("unauthorized"))
.expect(1)
.mount(&server)
.await;
Mock::given(method("GET"))
.and(path("/models"))
.and(header_regex("Authorization", "Bearer second-token"))
.respond_with(
ResponseTemplate::new(200)
.insert_header("content-type", "application/json")
.set_body_json(ModelsResponse {
models: remote_models.clone(),
}),
)
.expect(1)
.mount(&server)
.await;
let codex_home = tempdir().expect("temp dir");
let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("unused"));
let provider = ModelProviderInfo {
auth: Some(auth_script.auth_config()),
..provider_for(server.uri())
};
let manager = ModelsManager::with_provider_for_tests(
codex_home.path().to_path_buf(),
auth_manager,
provider,
);
manager
.refresh_available_models(RefreshStrategy::Online)
.await
.expect("refresh succeeds");
assert_models_contain(&manager.get_remote_models().await, &remote_models);
}
#[tokio::test]
async fn refresh_available_models_uses_cache_when_fresh() {
let server = MockServer::start().await;

View File

@@ -0,0 +1,191 @@
use std::fmt;
use std::io;
use std::path::Path;
use std::path::PathBuf;
use std::process::Stdio;
use std::sync::Arc;
use std::time::Instant;
use tokio::process::Command;
use tokio::sync::Mutex;
use crate::AuthManager;
use crate::auth::ExternalAuthRefreshContext;
use crate::auth::ExternalAuthRefresher;
use crate::auth::ExternalAuthTokens;
use crate::error::CodexErr;
use crate::error::Result;
use crate::model_provider_info::ModelProviderAuthInfo;
use crate::model_provider_info::ModelProviderInfo;
pub(crate) fn scoped_auth_manager_for_provider(
auth_manager: Option<Arc<AuthManager>>,
provider: &ModelProviderInfo,
) -> Option<Arc<AuthManager>> {
auth_manager.map(|auth_manager| scoped_auth_manager(auth_manager, provider))
}
pub(crate) fn scoped_auth_manager(
auth_manager: Arc<AuthManager>,
provider: &ModelProviderInfo,
) -> Arc<AuthManager> {
match provider.auth.clone() {
Some(config) => {
auth_manager.with_external_bearer_refresher(Arc::new(ProviderAuthResolver::new(config)))
}
None => auth_manager,
}
}
#[derive(Clone)]
pub(crate) struct ProviderAuthResolver {
state: Arc<ProviderAuthState>,
}
impl ProviderAuthResolver {
fn new(config: ModelProviderAuthInfo) -> Self {
Self {
state: Arc::new(ProviderAuthState::new(config)),
}
}
}
#[async_trait::async_trait]
impl ExternalAuthRefresher for ProviderAuthResolver {
async fn resolve(&self) -> io::Result<Option<ExternalAuthTokens>> {
let mut cached = self.state.cached_token.lock().await;
if let Some(cached_token) = cached.as_ref()
&& cached_token.fetched_at.elapsed() < self.state.config.refresh_interval()
{
return Ok(Some(cached_token.tokens.clone()));
}
let tokens = run_provider_auth_command(&self.state.config)
.await
.map_err(codex_err_to_io)?;
*cached = Some(CachedProviderToken {
tokens: tokens.clone(),
fetched_at: Instant::now(),
});
Ok(Some(tokens))
}
async fn refresh(
&self,
_context: ExternalAuthRefreshContext,
) -> io::Result<ExternalAuthTokens> {
let tokens = run_provider_auth_command(&self.state.config)
.await
.map_err(codex_err_to_io)?;
let mut cached = self.state.cached_token.lock().await;
*cached = Some(CachedProviderToken {
tokens: tokens.clone(),
fetched_at: Instant::now(),
});
Ok(tokens)
}
}
impl fmt::Debug for ProviderAuthResolver {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ProviderAuthResolver")
.finish_non_exhaustive()
}
}
struct ProviderAuthState {
config: ModelProviderAuthInfo,
cached_token: Mutex<Option<CachedProviderToken>>,
}
impl ProviderAuthState {
fn new(config: ModelProviderAuthInfo) -> Self {
Self {
config,
cached_token: Mutex::new(None),
}
}
}
struct CachedProviderToken {
tokens: ExternalAuthTokens,
fetched_at: Instant,
}
async fn run_provider_auth_command(config: &ModelProviderAuthInfo) -> Result<ExternalAuthTokens> {
let program = resolve_provider_auth_program(&config.command, &config.cwd)?;
let mut command = Command::new(&program);
command
.args(&config.args)
.current_dir(config.cwd.as_path())
.stdin(Stdio::null())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.kill_on_drop(true);
let output = tokio::time::timeout(config.timeout(), command.output())
.await
.map_err(|_| {
CodexErr::InvalidRequest(format!(
"provider auth command `{}` timed out after {} ms",
config.command,
config.timeout_ms.get()
))
})?
.map_err(|err| {
CodexErr::InvalidRequest(format!(
"provider auth command `{}` failed to start: {err}",
config.command
))
})?;
if !output.status.success() {
let status = output.status;
let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
let stderr_suffix = if stderr.is_empty() {
String::new()
} else {
format!(": {stderr}")
};
return Err(CodexErr::InvalidRequest(format!(
"provider auth command `{}` exited with status {status}{stderr_suffix}",
config.command
)));
}
let stdout = String::from_utf8(output.stdout).map_err(|_| {
CodexErr::InvalidRequest(format!(
"provider auth command `{}` wrote non-UTF-8 data to stdout",
config.command
))
})?;
let token = stdout.trim().to_string();
if token.is_empty() {
return Err(CodexErr::InvalidRequest(format!(
"provider auth command `{}` produced an empty token",
config.command
)));
}
Ok(ExternalAuthTokens::access_token_only(token))
}
fn resolve_provider_auth_program(command: &str, cwd: &Path) -> Result<PathBuf> {
let path = Path::new(command);
if path.is_absolute() || path.components().count() > 1 {
return Ok(
codex_utils_absolute_path::AbsolutePathBuf::resolve_path_against_base(path, cwd)?
.into_path_buf(),
);
}
Ok(PathBuf::from(command))
}
fn codex_err_to_io(error: CodexErr) -> io::Error {
io::Error::other(error.to_string())
}
#[cfg(test)]
#[path = "provider_auth_tests.rs"]
mod tests;

View File

@@ -0,0 +1,139 @@
use super::*;
use pretty_assertions::assert_eq;
use std::num::NonZeroU64;
use tempfile::TempDir;
#[tokio::test]
async fn caches_command_output_until_refreshed() {
let script = ProviderAuthScript::new(&["first-token", "second-token"]).unwrap();
let source = ProviderAuthResolver::new(script.auth_config());
let first = source
.resolve()
.await
.unwrap()
.map(|tokens| tokens.access_token);
let second = source
.resolve()
.await
.unwrap()
.map(|tokens| tokens.access_token);
let refreshed = source
.refresh(ExternalAuthRefreshContext {
reason: crate::auth::ExternalAuthRefreshReason::Unauthorized,
previous_account_id: None,
})
.await
.unwrap();
let after_refresh = source
.resolve()
.await
.unwrap()
.map(|tokens| tokens.access_token);
assert_eq!(first.as_deref(), Some("first-token"));
assert_eq!(second.as_deref(), Some("first-token"));
assert_eq!(refreshed.access_token, "second-token");
assert_eq!(after_refresh.as_deref(), Some("second-token"));
}
#[tokio::test]
async fn refresh_returns_bearer_only_external_auth_tokens() {
let script = ProviderAuthScript::new(&["first-token"]).unwrap();
let source = ProviderAuthResolver::new(script.auth_config());
let tokens = source
.refresh(ExternalAuthRefreshContext {
reason: crate::auth::ExternalAuthRefreshReason::Unauthorized,
previous_account_id: Some("ignored".to_string()),
})
.await
.unwrap();
assert_eq!(tokens.access_token, "first-token");
assert_eq!(tokens.chatgpt_metadata, None);
}
struct ProviderAuthScript {
tempdir: TempDir,
command: String,
args: Vec<String>,
}
impl ProviderAuthScript {
fn new(tokens: &[&str]) -> std::io::Result<Self> {
let tempdir = tempfile::tempdir()?;
let token_file = tempdir.path().join("tokens.txt");
std::fs::write(&token_file, format!("{}\n", tokens.join("\n")))?;
#[cfg(unix)]
let (command, args) = {
let script_path = tempdir.path().join("print-token.sh");
std::fs::write(
&script_path,
r#"#!/bin/sh
first_line=$(sed -n '1p' tokens.txt)
printf '%s\n' "$first_line"
tail -n +2 tokens.txt > tokens.next
mv tokens.next tokens.txt
"#,
)?;
let mut permissions = std::fs::metadata(&script_path)?.permissions();
{
use std::os::unix::fs::PermissionsExt;
permissions.set_mode(0o755);
}
std::fs::set_permissions(&script_path, permissions)?;
("./print-token.sh".to_string(), Vec::new())
};
#[cfg(windows)]
let (command, args) = {
let script_path = tempdir.path().join("print-token.ps1");
std::fs::write(
&script_path,
r#"$lines = Get-Content -Path tokens.txt
if ($lines.Count -eq 0) { exit 1 }
Write-Output $lines[0]
$lines | Select-Object -Skip 1 | Set-Content -Path tokens.txt
"#,
)?;
(
"powershell".to_string(),
vec![
"-NoProfile".to_string(),
"-ExecutionPolicy".to_string(),
"Bypass".to_string(),
"-File".to_string(),
".\\print-token.ps1".to_string(),
],
)
};
Ok(Self {
tempdir,
command,
args,
})
}
fn auth_config(&self) -> ModelProviderAuthInfo {
ModelProviderAuthInfo {
command: self.command.clone(),
args: self.args.clone(),
timeout_ms: non_zero_u64(/*value*/ 1_000),
refresh_interval_ms: non_zero_u64(/*value*/ 60_000),
cwd: match codex_utils_absolute_path::AbsolutePathBuf::try_from(self.tempdir.path()) {
Ok(cwd) => cwd,
Err(err) => panic!("tempdir should be absolute: {err}"),
},
}
}
}
fn non_zero_u64(value: u64) -> NonZeroU64 {
match NonZeroU64::new(value) {
Some(value) => value,
None => panic!("expected non-zero value: {value}"),
}
}

View File

@@ -452,7 +452,12 @@ async fn prepare_realtime_start(
params: ConversationStartParams,
) -> CodexResult<PreparedRealtimeConversationStart> {
let provider = sess.provider().await;
let auth = sess.services.auth_manager.auth().await;
let auth_manager = sess
.services
.model_client
.auth_manager()
.unwrap_or_else(|| Arc::clone(&sess.services.auth_manager));
let auth = auth_manager.auth().await;
let realtime_api_key = realtime_api_key(auth.as_ref(), &provider)?;
let mut api_provider = provider.to_api_provider(Some(crate::auth::AuthMode::ApiKey))?;
let config = sess.get_config().await;

View File

@@ -46,6 +46,7 @@ async fn responses_stream_includes_subagent_header_on_review() {
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
auth: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
@@ -158,6 +159,7 @@ async fn responses_stream_includes_subagent_header_on_other() {
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
auth: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
@@ -265,6 +267,7 @@ async fn responses_respects_model_info_overrides_from_config() {
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
auth: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,

View File

@@ -1,5 +1,7 @@
use codex_core::AuthManager;
use codex_core::CodexAuth;
use codex_core::ModelClient;
use codex_core::ModelProviderAuthInfo;
use codex_core::ModelProviderInfo;
use codex_core::NewThread;
use codex_core::Prompt;
@@ -64,6 +66,7 @@ use futures::StreamExt;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::io::Write;
use std::num::NonZeroU64;
use std::sync::Arc;
use tempfile::TempDir;
use uuid::Uuid;
@@ -143,6 +146,90 @@ fn write_auth_json(
fake_jwt
}
struct ProviderAuthCommandFixture {
tempdir: TempDir,
command: String,
args: Vec<String>,
}
impl ProviderAuthCommandFixture {
fn new(tokens: &[&str]) -> std::io::Result<Self> {
let tempdir = tempfile::tempdir()?;
let tokens_file = tempdir.path().join("tokens.txt");
std::fs::write(&tokens_file, format!("{}\n", tokens.join("\n")))?;
#[cfg(unix)]
let (command, args) = {
let script_path = tempdir.path().join("print-token.sh");
std::fs::write(
&script_path,
r#"#!/bin/sh
first_line=$(sed -n '1p' tokens.txt)
printf '%s\n' "$first_line"
tail -n +2 tokens.txt > tokens.next
mv tokens.next tokens.txt
"#,
)?;
let mut permissions = std::fs::metadata(&script_path)?.permissions();
{
use std::os::unix::fs::PermissionsExt;
permissions.set_mode(0o755);
}
std::fs::set_permissions(&script_path, permissions)?;
("./print-token.sh".to_string(), Vec::new())
};
#[cfg(windows)]
let (command, args) = {
let script_path = tempdir.path().join("print-token.ps1");
std::fs::write(
&script_path,
r#"$lines = Get-Content -Path tokens.txt
if ($lines.Count -eq 0) { exit 1 }
Write-Output $lines[0]
$lines | Select-Object -Skip 1 | Set-Content -Path tokens.txt
"#,
)?;
(
"powershell".to_string(),
vec![
"-NoProfile".to_string(),
"-ExecutionPolicy".to_string(),
"Bypass".to_string(),
"-File".to_string(),
".\\print-token.ps1".to_string(),
],
)
};
Ok(Self {
tempdir,
command,
args,
})
}
fn auth(&self) -> ModelProviderAuthInfo {
ModelProviderAuthInfo {
command: self.command.clone(),
args: self.args.clone(),
timeout_ms: non_zero_u64(/*value*/ 1_000),
refresh_interval_ms: non_zero_u64(/*value*/ 60_000),
cwd: match codex_utils_absolute_path::AbsolutePathBuf::try_from(self.tempdir.path()) {
Ok(cwd) => cwd,
Err(err) => panic!("tempdir should be absolute: {err}"),
},
}
}
}
fn non_zero_u64(value: u64) -> NonZeroU64 {
match NonZeroU64::new(value) {
Some(value) => value,
None => panic!("expected non-zero value: {value}"),
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn resume_includes_initial_messages_and_sends_prior_items() {
skip_if_no_network!();
@@ -659,6 +746,227 @@ async fn includes_conversation_id_and_model_headers_in_request() {
assert_eq!(request_authorization, "Bearer Test API Key");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn provider_auth_command_supplies_bearer_token() {
skip_if_no_network!();
let server = MockServer::start().await;
let resp_mock = mount_sse_once(
&server,
sse(vec![ev_response_created("resp1"), ev_completed("resp1")]),
)
.await;
let auth_fixture = ProviderAuthCommandFixture::new(&["command-token"]).unwrap();
let provider = ModelProviderInfo {
name: "corp".into(),
base_url: Some(format!("{}/v1", server.uri())),
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
auth: Some(auth_fixture.auth()),
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: Some(0),
stream_max_retries: Some(0),
stream_idle_timeout_ms: Some(5_000),
websocket_connect_timeout_ms: None,
requires_openai_auth: false,
supports_websockets: false,
};
let codex_home = TempDir::new().unwrap();
let mut config = load_default_config_for_test(&codex_home).await;
config.model_provider_id = provider.name.clone();
config.model_provider = provider.clone();
let effort = config.model_reasoning_effort;
let summary = config.model_reasoning_summary;
let model = codex_core::test_support::get_model_offline(config.model.as_deref());
config.model = Some(model.clone());
let config = Arc::new(config);
let model_info =
codex_core::test_support::construct_model_info_offline(model.as_str(), &config);
let conversation_id = ThreadId::new();
let session_telemetry = SessionTelemetry::new(
conversation_id,
model.as_str(),
model_info.slug.as_str(),
/*account_id*/ None,
Some("test@test.com".to_string()),
/*auth_mode*/ None,
"test_originator".to_string(),
/*log_user_prompts*/ false,
"test".to_string(),
SessionSource::Exec,
);
let client = ModelClient::new(
Some(AuthManager::from_auth_for_testing(CodexAuth::from_api_key(
"unused-api-key",
))),
conversation_id,
provider,
SessionSource::Exec,
config.model_verbosity,
/*enable_request_compression*/ false,
/*include_timing_metrics*/ false,
/*beta_features_header*/ None,
);
let mut client_session = client.new_session();
let mut prompt = Prompt::default();
prompt.input.push(ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "hello".to_string(),
}],
end_turn: None,
phase: None,
});
let mut stream = client_session
.stream(
&prompt,
&model_info,
&session_telemetry,
effort,
summary.unwrap_or(ReasoningSummary::Auto),
/*service_tier*/ None,
/*turn_metadata_header*/ None,
)
.await
.expect("responses stream to start");
while let Some(event) = stream.next().await {
if let Ok(ResponseEvent::Completed { .. }) = event {
break;
}
}
let request = resp_mock.single_request();
let request_authorization = request
.header("authorization")
.expect("authorization header");
assert_eq!(request_authorization, "Bearer command-token");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn provider_auth_command_refreshes_after_401() {
skip_if_no_network!();
let server = MockServer::start().await;
let auth_fixture = ProviderAuthCommandFixture::new(&["first-token", "second-token"]).unwrap();
Mock::given(method("POST"))
.and(path("/v1/responses"))
.and(header_regex("Authorization", "Bearer first-token"))
.respond_with(ResponseTemplate::new(401).set_body_string("unauthorized"))
.expect(1)
.mount(&server)
.await;
Mock::given(method("POST"))
.and(path("/v1/responses"))
.and(header_regex("Authorization", "Bearer second-token"))
.respond_with(
ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(
sse(vec![ev_response_created("resp1"), ev_completed("resp1")]),
"text/event-stream",
),
)
.expect(1)
.mount(&server)
.await;
let provider = ModelProviderInfo {
name: "corp".into(),
base_url: Some(format!("{}/v1", server.uri())),
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
auth: Some(auth_fixture.auth()),
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: Some(0),
stream_max_retries: Some(0),
stream_idle_timeout_ms: Some(5_000),
websocket_connect_timeout_ms: None,
requires_openai_auth: false,
supports_websockets: false,
};
let codex_home = TempDir::new().unwrap();
let mut config = load_default_config_for_test(&codex_home).await;
config.model_provider_id = provider.name.clone();
config.model_provider = provider.clone();
let effort = config.model_reasoning_effort;
let summary = config.model_reasoning_summary;
let model = codex_core::test_support::get_model_offline(config.model.as_deref());
config.model = Some(model.clone());
let config = Arc::new(config);
let model_info =
codex_core::test_support::construct_model_info_offline(model.as_str(), &config);
let conversation_id = ThreadId::new();
let session_telemetry = SessionTelemetry::new(
conversation_id,
model.as_str(),
model_info.slug.as_str(),
/*account_id*/ None,
Some("test@test.com".to_string()),
/*auth_mode*/ None,
"test_originator".to_string(),
/*log_user_prompts*/ false,
"test".to_string(),
SessionSource::Exec,
);
let client = ModelClient::new(
Some(AuthManager::from_auth_for_testing(CodexAuth::from_api_key(
"unused-api-key",
))),
conversation_id,
provider,
SessionSource::Exec,
config.model_verbosity,
/*enable_request_compression*/ false,
/*include_timing_metrics*/ false,
/*beta_features_header*/ None,
);
let mut client_session = client.new_session();
let mut prompt = Prompt::default();
prompt.input.push(ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "hello".to_string(),
}],
end_turn: None,
phase: None,
});
let mut stream = client_session
.stream(
&prompt,
&model_info,
&session_telemetry,
effort,
summary.unwrap_or(ReasoningSummary::Auto),
/*service_tier*/ None,
/*turn_metadata_header*/ None,
)
.await
.expect("responses stream to start");
while let Some(event) = stream.next().await {
if let Ok(ResponseEvent::Completed { .. }) = event {
break;
}
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn includes_base_instructions_override_in_request() {
skip_if_no_network!();
@@ -1796,6 +2104,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
auth: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
@@ -2396,6 +2705,7 @@ async fn azure_overrides_assign_properties_used_for_responses_url() {
// Reuse the existing environment variable to avoid using unsafe code
env_key: Some(existing_env_var_with_random_value.to_string()),
experimental_bearer_token: None,
auth: None,
query_params: Some(std::collections::HashMap::from([(
"api-version".to_string(),
"2025-04-01-preview".to_string(),
@@ -2486,6 +2796,7 @@ async fn env_var_overrides_loaded_auth() {
)])),
env_key_instructions: None,
experimental_bearer_token: None,
auth: None,
wire_api: WireApi::Responses,
http_headers: Some(std::collections::HashMap::from([(
"Custom-Header".to_string(),

View File

@@ -1674,6 +1674,7 @@ fn websocket_provider_with_connect_timeout(
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
auth: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,

View File

@@ -69,6 +69,7 @@ async fn continue_after_stream_error() {
env_key: Some("PATH".into()),
env_key_instructions: None,
experimental_bearer_token: None,
auth: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,

View File

@@ -53,6 +53,7 @@ async fn retries_on_early_close() {
env_key: Some("PATH".into()),
env_key_instructions: None,
experimental_bearer_token: None,
auth: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,

View File

@@ -4,6 +4,7 @@ use crate::auth::storage::get_auth_file;
use crate::token_data::IdTokenInfo;
use crate::token_data::KnownPlan as InternalKnownPlan;
use crate::token_data::PlanType as InternalPlanType;
use async_trait::async_trait;
use codex_protocol::account::PlanType as AccountPlanType;
use base64::Engine;
@@ -12,6 +13,7 @@ use pretty_assertions::assert_eq;
use serde::Serialize;
use serde_json::json;
use std::sync::Arc;
use std::sync::Mutex;
use tempfile::tempdir;
#[tokio::test]
@@ -252,6 +254,100 @@ fn refresh_failure_is_scoped_to_the_matching_auth_snapshot() {
assert_eq!(manager.refresh_failure_for_auth(&updated_auth), None);
}
#[test]
fn external_auth_tokens_without_chatgpt_metadata_cannot_seed_chatgpt_auth() {
let err = AuthDotJson::from_external_tokens(&ExternalAuthTokens::access_token_only(
"test-access-token",
))
.expect_err("bearer-only external auth should not seed ChatGPT auth");
assert_eq!(
err.to_string(),
"external auth tokens are missing ChatGPT metadata"
);
}
#[tokio::test]
async fn auth_manager_with_external_bearer_refresher_returns_provider_token_only_for_derived_manager()
{
let base_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("base-token"));
let derived_manager =
base_manager.with_external_bearer_refresher(Arc::new(StaticExternalAuthRefresher::new(
Some(ExternalAuthTokens::access_token_only("provider-token")),
ExternalAuthTokens::access_token_only("refreshed-provider-token"),
)));
assert_eq!(
base_manager
.auth()
.await
.and_then(|auth| auth.api_key().map(str::to_string)),
Some("base-token".to_string())
);
assert_eq!(
derived_manager
.auth()
.await
.and_then(|auth| auth.api_key().map(str::to_string)),
Some("provider-token".to_string())
);
}
#[tokio::test]
async fn unauthorized_recovery_uses_external_refresh_for_bearer_manager() {
let base_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("base-token"));
let refresher = Arc::new(StaticExternalAuthRefresher::new(
Some(ExternalAuthTokens::access_token_only("provider-token")),
ExternalAuthTokens::access_token_only("refreshed-provider-token"),
));
let derived_manager = base_manager.with_external_bearer_refresher(refresher.clone());
let mut recovery = derived_manager.unauthorized_recovery();
assert!(recovery.has_next());
assert_eq!(recovery.mode_name(), "external");
assert_eq!(recovery.step_name(), "external_refresh");
let result = recovery
.next()
.await
.expect("external refresh should succeed");
assert_eq!(result.auth_state_changed(), Some(true));
assert_eq!(*refresher.refresh_calls.lock().unwrap(), 1);
}
#[derive(Debug)]
struct StaticExternalAuthRefresher {
resolved: Option<ExternalAuthTokens>,
refreshed: ExternalAuthTokens,
refresh_calls: Mutex<usize>,
}
impl StaticExternalAuthRefresher {
fn new(resolved: Option<ExternalAuthTokens>, refreshed: ExternalAuthTokens) -> Self {
Self {
resolved,
refreshed,
refresh_calls: Mutex::new(0),
}
}
}
#[async_trait]
impl ExternalAuthRefresher for StaticExternalAuthRefresher {
async fn resolve(&self) -> std::io::Result<Option<ExternalAuthTokens>> {
Ok(self.resolved.clone())
}
async fn refresh(
&self,
_context: ExternalAuthRefreshContext,
) -> std::io::Result<ExternalAuthTokens> {
*self.refresh_calls.lock().unwrap() += 1;
Ok(self.refreshed.clone())
}
}
struct AuthFileParams {
openai_api_key: Option<String>,
chatgpt_plan_type: Option<String>,

View File

@@ -90,11 +90,43 @@ pub enum RefreshTokenError {
Transient(#[from] std::io::Error),
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ExternalAuthChatgptMetadata {
pub account_id: String,
pub plan_type: Option<String>,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ExternalAuthTokens {
pub access_token: String,
pub chatgpt_account_id: String,
pub chatgpt_plan_type: Option<String>,
pub chatgpt_metadata: Option<ExternalAuthChatgptMetadata>,
}
impl ExternalAuthTokens {
pub fn access_token_only(access_token: impl Into<String>) -> Self {
Self {
access_token: access_token.into(),
chatgpt_metadata: None,
}
}
pub fn chatgpt(
access_token: impl Into<String>,
chatgpt_account_id: impl Into<String>,
chatgpt_plan_type: Option<String>,
) -> Self {
Self {
access_token: access_token.into(),
chatgpt_metadata: Some(ExternalAuthChatgptMetadata {
account_id: chatgpt_account_id.into(),
plan_type: chatgpt_plan_type,
}),
}
}
pub fn chatgpt_metadata(&self) -> Option<&ExternalAuthChatgptMetadata> {
self.chatgpt_metadata.as_ref()
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -110,6 +142,10 @@ pub struct ExternalAuthRefreshContext {
#[async_trait]
pub trait ExternalAuthRefresher: Send + Sync {
async fn resolve(&self) -> std::io::Result<Option<ExternalAuthTokens>> {
Ok(None)
}
async fn refresh(
&self,
context: ExternalAuthRefreshContext,
@@ -736,11 +772,16 @@ fn refresh_token_endpoint() -> String {
impl AuthDotJson {
fn from_external_tokens(external: &ExternalAuthTokens) -> std::io::Result<Self> {
let Some(chatgpt_metadata) = external.chatgpt_metadata() else {
return Err(std::io::Error::other(
"external auth tokens are missing ChatGPT metadata",
));
};
let mut token_info =
parse_chatgpt_jwt_claims(&external.access_token).map_err(std::io::Error::other)?;
token_info.chatgpt_account_id = Some(external.chatgpt_account_id.clone());
token_info.chatgpt_plan_type = external
.chatgpt_plan_type
token_info.chatgpt_account_id = Some(chatgpt_metadata.account_id.clone());
token_info.chatgpt_plan_type = chatgpt_metadata
.plan_type
.as_deref()
.map(InternalPlanType::from_raw_value)
.or(token_info.chatgpt_plan_type)
@@ -749,7 +790,7 @@ impl AuthDotJson {
id_token: token_info,
access_token: external.access_token.clone(),
refresh_token: String::new(),
account_id: Some(external.chatgpt_account_id.clone()),
account_id: Some(chatgpt_metadata.account_id.clone()),
};
Ok(Self {
@@ -765,11 +806,11 @@ impl AuthDotJson {
chatgpt_account_id: &str,
chatgpt_plan_type: Option<&str>,
) -> std::io::Result<Self> {
let external = ExternalAuthTokens {
access_token: access_token.to_string(),
chatgpt_account_id: chatgpt_account_id.to_string(),
chatgpt_plan_type: chatgpt_plan_type.map(str::to_string),
};
let external = ExternalAuthTokens::chatgpt(
access_token,
chatgpt_account_id,
chatgpt_plan_type.map(str::to_string),
);
Self::from_external_tokens(&external)
}
@@ -799,8 +840,6 @@ impl AuthDotJson {
#[derive(Clone)]
struct CachedAuth {
auth: Option<CodexAuth>,
/// Callback used to refresh external auth by asking the parent app for new tokens.
external_refresher: Option<Arc<dyn ExternalAuthRefresher>>,
/// Permanent refresh failure cached for the current auth snapshot so
/// later refresh attempts for the same credentials fail fast without network.
permanent_refresh_failure: Option<AuthScopedRefreshFailure>,
@@ -812,6 +851,27 @@ struct AuthScopedRefreshFailure {
error: RefreshTokenFailedError,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ExternalAuthKind {
Bearer,
Chatgpt,
}
#[derive(Clone)]
struct ExternalAuthHandle {
kind: ExternalAuthKind,
refresher: Arc<dyn ExternalAuthRefresher>,
}
impl Debug for ExternalAuthHandle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ExternalAuthHandle")
.field("kind", &self.kind)
.field("refresher", &"present")
.finish()
}
}
impl Debug for CachedAuth {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CachedAuth")
@@ -819,10 +879,6 @@ impl Debug for CachedAuth {
"auth_mode",
&self.auth.as_ref().map(CodexAuth::api_auth_mode),
)
.field(
"external_refresher",
&self.external_refresher.as_ref().map(|_| "present"),
)
.field(
"permanent_refresh_failure",
&self
@@ -866,9 +922,14 @@ enum UnauthorizedRecoveryMode {
// 2. Attempt to refresh the token using OAuth token refresh flow.
// If after both steps the server still responds with 401 we let the error bubble to the user.
//
// For external ChatGPT auth tokens (chatgptAuthTokens), UnauthorizedRecovery does not touch disk or refresh
// tokens locally. Instead it calls the ExternalAuthRefresher (account/chatgptAuthTokens/refresh) to ask the
// parent app for new tokens, stores them in the ephemeral auth store, and retries once.
// For external auth sources, UnauthorizedRecovery delegates to the configured
// ExternalAuthRefresher and retries once.
//
// - External ChatGPT auth tokens (`chatgptAuthTokens`) are refreshed by asking
// the parent app for new tokens, persisting them in the ephemeral auth
// store, and reloading the cached auth snapshot.
// - External bearer auth sources resolve bearer-only tokens for custom model
// providers and refresh them without touching disk.
pub struct UnauthorizedRecovery {
manager: Arc<AuthManager>,
step: UnauthorizedRecoveryStep,
@@ -891,9 +952,10 @@ impl UnauthorizedRecovery {
fn new(manager: Arc<AuthManager>) -> Self {
let cached_auth = manager.auth_cached();
let expected_account_id = cached_auth.as_ref().and_then(CodexAuth::get_account_id);
let mode = if cached_auth
.as_ref()
.is_some_and(CodexAuth::is_external_chatgpt_tokens)
let mode = if manager.external_auth_kind() == Some(ExternalAuthKind::Bearer)
|| cached_auth
.as_ref()
.is_some_and(CodexAuth::is_external_chatgpt_tokens)
{
UnauthorizedRecoveryMode::External
} else {
@@ -912,6 +974,10 @@ impl UnauthorizedRecovery {
}
pub fn has_next(&self) -> bool {
if self.manager.external_auth_kind() == Some(ExternalAuthKind::Bearer) {
return !matches!(self.step, UnauthorizedRecoveryStep::Done);
}
if !self
.manager
.auth_cached()
@@ -931,6 +997,16 @@ impl UnauthorizedRecovery {
}
pub fn unavailable_reason(&self) -> &'static str {
if self.manager.external_auth_kind() == Some(ExternalAuthKind::Bearer) {
return if matches!(self.step, UnauthorizedRecoveryStep::Done) {
"recovery_exhausted"
} else if self.manager.has_external_auth_refresher() {
"ready"
} else {
"no_external_refresher"
};
}
if !self
.manager
.auth_cached()
@@ -1039,11 +1115,12 @@ impl UnauthorizedRecovery {
#[derive(Debug)]
pub struct AuthManager {
codex_home: PathBuf,
inner: RwLock<CachedAuth>,
inner: Arc<RwLock<CachedAuth>>,
enable_codex_api_key_env: bool,
auth_credentials_store_mode: AuthCredentialsStoreMode,
forced_chatgpt_workspace_id: RwLock<Option<String>>,
refresh_lock: AsyncMutex<()>,
forced_chatgpt_workspace_id: Arc<RwLock<Option<String>>>,
refresh_lock: Arc<AsyncMutex<()>>,
external_auth: RwLock<Option<ExternalAuthHandle>>,
}
impl AuthManager {
@@ -1065,15 +1142,15 @@ impl AuthManager {
.flatten();
Self {
codex_home,
inner: RwLock::new(CachedAuth {
inner: Arc::new(RwLock::new(CachedAuth {
auth: managed_auth,
external_refresher: None,
permanent_refresh_failure: None,
}),
})),
enable_codex_api_key_env,
auth_credentials_store_mode,
forced_chatgpt_workspace_id: RwLock::new(None),
refresh_lock: AsyncMutex::new(()),
forced_chatgpt_workspace_id: Arc::new(RwLock::new(None)),
refresh_lock: Arc::new(AsyncMutex::new(())),
external_auth: RwLock::new(None),
}
}
@@ -1081,17 +1158,17 @@ impl AuthManager {
pub fn from_auth_for_testing(auth: CodexAuth) -> Arc<Self> {
let cached = CachedAuth {
auth: Some(auth),
external_refresher: None,
permanent_refresh_failure: None,
};
Arc::new(Self {
codex_home: PathBuf::from("non-existent"),
inner: RwLock::new(cached),
inner: Arc::new(RwLock::new(cached)),
enable_codex_api_key_env: false,
auth_credentials_store_mode: AuthCredentialsStoreMode::File,
forced_chatgpt_workspace_id: RwLock::new(None),
refresh_lock: AsyncMutex::new(()),
forced_chatgpt_workspace_id: Arc::new(RwLock::new(None)),
refresh_lock: Arc::new(AsyncMutex::new(())),
external_auth: RwLock::new(None),
})
}
@@ -1099,16 +1176,16 @@ impl AuthManager {
pub fn from_auth_for_testing_with_home(auth: CodexAuth, codex_home: PathBuf) -> Arc<Self> {
let cached = CachedAuth {
auth: Some(auth),
external_refresher: None,
permanent_refresh_failure: None,
};
Arc::new(Self {
codex_home,
inner: RwLock::new(cached),
inner: Arc::new(RwLock::new(cached)),
enable_codex_api_key_env: false,
auth_credentials_store_mode: AuthCredentialsStoreMode::File,
forced_chatgpt_workspace_id: RwLock::new(None),
refresh_lock: AsyncMutex::new(()),
forced_chatgpt_workspace_id: Arc::new(RwLock::new(None)),
refresh_lock: Arc::new(AsyncMutex::new(())),
external_auth: RwLock::new(None),
})
}
@@ -1131,6 +1208,10 @@ impl AuthManager {
/// For stale managed ChatGPT auth, first performs a guarded reload and then
/// refreshes only if the on-disk auth is unchanged.
pub async fn auth(&self) -> Option<CodexAuth> {
if let Some(auth) = self.resolve_external_api_key_auth().await {
return Some(auth);
}
let auth = self.auth_cached()?;
if Self::is_stale_for_proactive_refresh(&auth)
&& let Err(err) = self.refresh_token().await
@@ -1251,17 +1332,38 @@ impl AuthManager {
}
pub fn set_external_auth_refresher(&self, refresher: Arc<dyn ExternalAuthRefresher>) {
if let Ok(mut guard) = self.inner.write() {
guard.external_refresher = Some(refresher);
if let Ok(mut guard) = self.external_auth.write() {
*guard = Some(ExternalAuthHandle {
kind: ExternalAuthKind::Chatgpt,
refresher,
});
}
}
pub fn clear_external_auth_refresher(&self) {
if let Ok(mut guard) = self.inner.write() {
guard.external_refresher = None;
if let Ok(mut guard) = self.external_auth.write() {
*guard = None;
}
}
pub fn with_external_bearer_refresher(
self: &Arc<Self>,
refresher: Arc<dyn ExternalAuthRefresher>,
) -> Arc<Self> {
Arc::new(Self {
codex_home: self.codex_home.clone(),
inner: Arc::clone(&self.inner),
enable_codex_api_key_env: self.enable_codex_api_key_env,
auth_credentials_store_mode: self.auth_credentials_store_mode,
forced_chatgpt_workspace_id: Arc::clone(&self.forced_chatgpt_workspace_id),
refresh_lock: Arc::clone(&self.refresh_lock),
external_auth: RwLock::new(Some(ExternalAuthHandle {
kind: ExternalAuthKind::Bearer,
refresher,
})),
})
}
pub fn set_forced_chatgpt_workspace_id(&self, workspace_id: Option<String>) {
if let Ok(mut guard) = self.forced_chatgpt_workspace_id.write() {
*guard = workspace_id;
@@ -1276,13 +1378,17 @@ impl AuthManager {
}
pub fn has_external_auth_refresher(&self) -> bool {
self.inner
self.external_auth
.read()
.ok()
.map(|guard| guard.external_refresher.is_some())
.map(|guard| guard.is_some())
.unwrap_or(false)
}
pub fn has_external_bearer_refresher(&self) -> bool {
self.external_auth_kind() == Some(ExternalAuthKind::Bearer)
}
pub fn is_external_auth_active(&self) -> bool {
self.auth_cached()
.as_ref()
@@ -1310,6 +1416,35 @@ impl AuthManager {
UnauthorizedRecovery::new(Arc::clone(self))
}
fn external_auth_handle(&self) -> Option<ExternalAuthHandle> {
self.external_auth
.read()
.ok()
.and_then(|guard| guard.clone())
}
fn external_auth_kind(&self) -> Option<ExternalAuthKind> {
self.external_auth_handle().map(|handle| handle.kind)
}
async fn resolve_external_api_key_auth(&self) -> Option<CodexAuth> {
let Some(handle) = self.external_auth_handle() else {
return None;
};
if handle.kind != ExternalAuthKind::Bearer {
return None;
}
match handle.refresher.resolve().await {
Ok(Some(tokens)) => Some(CodexAuth::from_api_key(&tokens.access_token)),
Ok(None) => None,
Err(err) => {
tracing::error!("Failed to resolve external bearer auth: {err}");
None
}
}
}
/// Attempt to refresh the token by first performing a guarded reload. Auth
/// is reloaded from storage only when the account id matches the currently
/// cached account id. If the persisted token differs from the cached token, we
@@ -1432,16 +1567,7 @@ impl AuthManager {
reason: ExternalAuthRefreshReason,
) -> Result<(), RefreshTokenError> {
let forced_chatgpt_workspace_id = self.forced_chatgpt_workspace_id();
let refresher = match self.inner.read() {
Ok(guard) => guard.external_refresher.clone(),
Err(_) => {
return Err(RefreshTokenError::Transient(std::io::Error::other(
"failed to read external auth state",
)));
}
};
let Some(refresher) = refresher else {
let Some(handle) = self.external_auth_handle() else {
return Err(RefreshTokenError::Transient(std::io::Error::other(
"external auth refresher is not configured",
)));
@@ -1456,14 +1582,22 @@ impl AuthManager {
previous_account_id,
};
let refreshed = refresher.refresh(context).await?;
let refreshed = handle.refresher.refresh(context).await?;
if handle.kind == ExternalAuthKind::Bearer {
return Ok(());
}
let Some(chatgpt_metadata) = refreshed.chatgpt_metadata() else {
return Err(RefreshTokenError::Transient(std::io::Error::other(
"external auth refresh did not return ChatGPT metadata",
)));
};
if let Some(expected_workspace_id) = forced_chatgpt_workspace_id.as_deref()
&& refreshed.chatgpt_account_id != expected_workspace_id
&& chatgpt_metadata.account_id != expected_workspace_id
{
return Err(RefreshTokenError::Transient(std::io::Error::other(
format!(
"external auth refresh returned workspace {:?}, expected {expected_workspace_id:?}",
refreshed.chatgpt_account_id,
chatgpt_metadata.account_id,
),
)));
}

View File

@@ -22,6 +22,11 @@ pub use auth::AuthManager;
pub use auth::CLIENT_ID;
pub use auth::CODEX_API_KEY_ENV_VAR;
pub use auth::CodexAuth;
pub use auth::ExternalAuthChatgptMetadata;
pub use auth::ExternalAuthRefreshContext;
pub use auth::ExternalAuthRefreshReason;
pub use auth::ExternalAuthRefresher;
pub use auth::ExternalAuthTokens;
pub use auth::OPENAI_API_KEY_ENV_VAR;
pub use auth::REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR;
pub use auth::RefreshTokenError;