mirror of
https://github.com/openai/codex.git
synced 2026-04-03 14:01:37 +03:00
Compare commits
3 Commits
mstar/remo
...
pr16267
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c1c63e8f83 | ||
|
|
4127419694 | ||
|
|
b397919da1 |
@@ -144,11 +144,11 @@ impl ExternalAuthRefresher for ExternalAuthRefreshBridge {
|
||||
let response: ChatgptAuthTokensRefreshResponse =
|
||||
serde_json::from_value(result).map_err(std::io::Error::other)?;
|
||||
|
||||
Ok(ExternalAuthTokens {
|
||||
access_token: response.access_token,
|
||||
chatgpt_account_id: response.chatgpt_account_id,
|
||||
chatgpt_plan_type: response.chatgpt_plan_type,
|
||||
})
|
||||
Ok(ExternalAuthTokens::chatgpt(
|
||||
response.access_token,
|
||||
response.chatgpt_account_id,
|
||||
response.chatgpt_plan_type,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -816,10 +816,63 @@
|
||||
},
|
||||
"type": "object"
|
||||
},
|
||||
"ModelProviderAuthInfo": {
|
||||
"additionalProperties": false,
|
||||
"description": "Configuration for obtaining a provider bearer token from a command.",
|
||||
"properties": {
|
||||
"args": {
|
||||
"default": [],
|
||||
"description": "Command arguments.",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"type": "array"
|
||||
},
|
||||
"command": {
|
||||
"description": "Command to execute. Bare names are resolved via `PATH`; paths are resolved against `cwd`.",
|
||||
"type": "string"
|
||||
},
|
||||
"cwd": {
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/AbsolutePathBuf"
|
||||
}
|
||||
],
|
||||
"default": "/Users/mbolin/code/codex/codex-rs",
|
||||
"description": "Working directory used when running the token command."
|
||||
},
|
||||
"refresh_interval_ms": {
|
||||
"default": 300000,
|
||||
"description": "Maximum age for the cached token before rerunning the command.",
|
||||
"format": "uint64",
|
||||
"minimum": 1.0,
|
||||
"type": "integer"
|
||||
},
|
||||
"timeout_ms": {
|
||||
"default": 5000,
|
||||
"description": "Maximum time to wait for the token command to exit successfully.",
|
||||
"format": "uint64",
|
||||
"minimum": 1.0,
|
||||
"type": "integer"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"command"
|
||||
],
|
||||
"type": "object"
|
||||
},
|
||||
"ModelProviderInfo": {
|
||||
"additionalProperties": false,
|
||||
"description": "Serializable representation of a provider definition.",
|
||||
"properties": {
|
||||
"auth": {
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/ModelProviderAuthInfo"
|
||||
}
|
||||
],
|
||||
"description": "Command-backed bearer-token configuration for this provider."
|
||||
},
|
||||
"base_url": {
|
||||
"description": "Base URL for the provider's OpenAI-compatible API.",
|
||||
"type": "string"
|
||||
|
||||
@@ -64,6 +64,7 @@ mod tests {
|
||||
env_key: Some("sk-should-not-leak".to_string()),
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: None,
|
||||
wire_api: crate::model_provider_info::WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
|
||||
@@ -104,6 +104,7 @@ use crate::error::Result;
|
||||
use crate::flags::CODEX_RS_SSE_FIXTURE;
|
||||
use crate::model_provider_info::ModelProviderInfo;
|
||||
use crate::model_provider_info::WireApi;
|
||||
use crate::provider_auth::scoped_auth_manager_for_provider;
|
||||
use crate::response_debug_context::extract_response_debug_context;
|
||||
use crate::response_debug_context::extract_response_debug_context_from_api_error;
|
||||
use crate::response_debug_context::telemetry_api_error_message;
|
||||
@@ -261,6 +262,7 @@ impl ModelClient {
|
||||
include_timing_metrics: bool,
|
||||
beta_features_header: Option<String>,
|
||||
) -> Self {
|
||||
let auth_manager = scoped_auth_manager_for_provider(auth_manager, &provider);
|
||||
let codex_api_key_env_enabled = auth_manager
|
||||
.as_ref()
|
||||
.is_some_and(|manager| manager.codex_api_key_env_enabled());
|
||||
@@ -294,6 +296,10 @@ impl ModelClient {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn auth_manager(&self) -> Option<Arc<AuthManager>> {
|
||||
self.state.auth_manager.clone()
|
||||
}
|
||||
|
||||
fn take_cached_websocket_session(&self) -> WebsocketSession {
|
||||
let mut cached_websocket_session = self
|
||||
.state
|
||||
|
||||
@@ -243,6 +243,26 @@ web_search = false
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_provider_auth_with_env_key() {
|
||||
let err = toml::from_str::<ConfigToml>(
|
||||
r#"
|
||||
[model_providers.corp]
|
||||
name = "Corp"
|
||||
env_key = "CORP_TOKEN"
|
||||
|
||||
[model_providers.corp.auth]
|
||||
command = "print-token"
|
||||
"#,
|
||||
)
|
||||
.unwrap_err();
|
||||
|
||||
assert!(
|
||||
err.to_string()
|
||||
.contains("model_providers.corp: provider auth cannot be combined with env_key")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_toml_deserializes_model_availability_nux() {
|
||||
let toml = r#"
|
||||
@@ -4315,6 +4335,7 @@ model_verbosity = "high"
|
||||
wire_api: crate::WireApi::Responses,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: None,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
|
||||
@@ -1837,6 +1837,18 @@ Built-in providers cannot be overridden. Rename your custom provider (for exampl
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_model_providers(
|
||||
model_providers: &HashMap<String, ModelProviderInfo>,
|
||||
) -> Result<(), String> {
|
||||
validate_reserved_model_provider_ids(model_providers)?;
|
||||
for (key, provider) in model_providers {
|
||||
provider
|
||||
.validate()
|
||||
.map_err(|message| format!("model_providers.{key}: {message}"))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn deserialize_model_providers<'de, D>(
|
||||
deserializer: D,
|
||||
) -> Result<HashMap<String, ModelProviderInfo>, D::Error>
|
||||
@@ -1844,7 +1856,7 @@ where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let model_providers = HashMap::<String, ModelProviderInfo>::deserialize(deserializer)?;
|
||||
validate_reserved_model_provider_ids(&model_providers).map_err(serde::de::Error::custom)?;
|
||||
validate_model_providers(&model_providers).map_err(serde::de::Error::custom)?;
|
||||
Ok(model_providers)
|
||||
}
|
||||
|
||||
@@ -1969,7 +1981,7 @@ impl Config {
|
||||
codex_home: PathBuf,
|
||||
config_layer_stack: ConfigLayerStack,
|
||||
) -> std::io::Result<Self> {
|
||||
validate_reserved_model_provider_ids(&cfg.model_providers)
|
||||
validate_model_providers(&cfg.model_providers)
|
||||
.map_err(|message| std::io::Error::new(std::io::ErrorKind::InvalidInput, message))?;
|
||||
// Ensure that every field of ConfigRequirements is applied to the final
|
||||
// Config.
|
||||
|
||||
@@ -65,6 +65,7 @@ pub mod utils;
|
||||
pub use utils::path_utils;
|
||||
pub mod personality_migration;
|
||||
pub mod plugins;
|
||||
mod provider_auth;
|
||||
pub(crate) mod mentions {
|
||||
pub(crate) use crate::plugins::build_connector_slug_counts;
|
||||
pub(crate) use crate::plugins::build_skill_name_counts;
|
||||
@@ -107,6 +108,7 @@ pub use client::X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER;
|
||||
pub use model_provider_info::DEFAULT_LMSTUDIO_PORT;
|
||||
pub use model_provider_info::DEFAULT_OLLAMA_PORT;
|
||||
pub use model_provider_info::LMSTUDIO_OSS_PROVIDER_ID;
|
||||
pub use model_provider_info::ModelProviderAuthInfo;
|
||||
pub use model_provider_info::ModelProviderInfo;
|
||||
pub use model_provider_info::OLLAMA_OSS_PROVIDER_ID;
|
||||
pub use model_provider_info::OPENAI_PROVIDER_ID;
|
||||
|
||||
@@ -9,6 +9,7 @@ use crate::auth::AuthMode;
|
||||
use crate::error::EnvVarError;
|
||||
use codex_api::Provider as ApiProvider;
|
||||
use codex_api::provider::RetryConfig as ApiRetryConfig;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use http::HeaderMap;
|
||||
use http::header::HeaderName;
|
||||
use http::header::HeaderValue;
|
||||
@@ -17,8 +18,11 @@ use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::num::NonZeroU64;
|
||||
use std::time::Duration;
|
||||
|
||||
const DEFAULT_PROVIDER_AUTH_TIMEOUT_MS: u64 = 5_000;
|
||||
const DEFAULT_PROVIDER_AUTH_REFRESH_INTERVAL_MS: u64 = 300_000;
|
||||
const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000;
|
||||
const DEFAULT_STREAM_MAX_RETRIES: u64 = 5;
|
||||
const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
|
||||
@@ -66,6 +70,73 @@ impl<'de> Deserialize<'de> for WireApi {
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for obtaining a provider bearer token from a command.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, JsonSchema)]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
pub struct ModelProviderAuthInfo {
|
||||
/// Command to execute. Bare names are resolved via `PATH`; paths are resolved against `cwd`.
|
||||
pub command: String,
|
||||
|
||||
/// Command arguments.
|
||||
#[serde(default)]
|
||||
pub args: Vec<String>,
|
||||
|
||||
/// Maximum time to wait for the token command to exit successfully.
|
||||
#[serde(default = "default_provider_auth_timeout_ms")]
|
||||
pub timeout_ms: NonZeroU64,
|
||||
|
||||
/// Maximum age for the cached token before rerunning the command.
|
||||
#[serde(default = "default_provider_auth_refresh_interval_ms")]
|
||||
pub refresh_interval_ms: NonZeroU64,
|
||||
|
||||
/// Working directory used when running the token command.
|
||||
#[serde(default = "default_provider_auth_cwd")]
|
||||
pub cwd: AbsolutePathBuf,
|
||||
}
|
||||
|
||||
impl ModelProviderAuthInfo {
|
||||
pub(crate) fn timeout(&self) -> Duration {
|
||||
Duration::from_millis(self.timeout_ms.get())
|
||||
}
|
||||
|
||||
pub(crate) fn refresh_interval(&self) -> Duration {
|
||||
Duration::from_millis(self.refresh_interval_ms.get())
|
||||
}
|
||||
}
|
||||
|
||||
fn default_provider_auth_timeout_ms() -> NonZeroU64 {
|
||||
non_zero_u64(
|
||||
DEFAULT_PROVIDER_AUTH_TIMEOUT_MS,
|
||||
"model_providers.<id>.auth.timeout_ms",
|
||||
)
|
||||
}
|
||||
|
||||
fn default_provider_auth_refresh_interval_ms() -> NonZeroU64 {
|
||||
non_zero_u64(
|
||||
DEFAULT_PROVIDER_AUTH_REFRESH_INTERVAL_MS,
|
||||
"model_providers.<id>.auth.refresh_interval_ms",
|
||||
)
|
||||
}
|
||||
|
||||
fn non_zero_u64(value: u64, field_name: &str) -> NonZeroU64 {
|
||||
match NonZeroU64::new(value) {
|
||||
Some(value) => value,
|
||||
None => panic!("{field_name} must be non-zero"),
|
||||
}
|
||||
}
|
||||
|
||||
fn default_provider_auth_cwd() -> AbsolutePathBuf {
|
||||
let deserializer = serde::de::value::StrDeserializer::<serde::de::value::Error>::new(".");
|
||||
if let Ok(cwd) = AbsolutePathBuf::deserialize(deserializer) {
|
||||
return cwd;
|
||||
}
|
||||
|
||||
match AbsolutePathBuf::current_dir() {
|
||||
Ok(cwd) => cwd,
|
||||
Err(err) => panic!("provider auth cwd must resolve: {err}"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Serializable representation of a provider definition.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, JsonSchema)]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
@@ -86,6 +157,9 @@ pub struct ModelProviderInfo {
|
||||
/// this may be necessary when using this programmatically.
|
||||
pub experimental_bearer_token: Option<String>,
|
||||
|
||||
/// Command-backed bearer-token configuration for this provider.
|
||||
pub auth: Option<ModelProviderAuthInfo>,
|
||||
|
||||
/// Which wire protocol this provider expects.
|
||||
#[serde(default)]
|
||||
pub wire_api: WireApi,
|
||||
@@ -130,6 +204,36 @@ pub struct ModelProviderInfo {
|
||||
}
|
||||
|
||||
impl ModelProviderInfo {
|
||||
pub(crate) fn validate(&self) -> std::result::Result<(), String> {
|
||||
let Some(auth) = self.auth.as_ref() else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
if auth.command.trim().is_empty() {
|
||||
return Err("provider auth.command must not be empty".to_string());
|
||||
}
|
||||
|
||||
let mut conflicts = Vec::new();
|
||||
if self.env_key.is_some() {
|
||||
conflicts.push("env_key");
|
||||
}
|
||||
if self.experimental_bearer_token.is_some() {
|
||||
conflicts.push("experimental_bearer_token");
|
||||
}
|
||||
if self.requires_openai_auth {
|
||||
conflicts.push("requires_openai_auth");
|
||||
}
|
||||
|
||||
if conflicts.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!(
|
||||
"provider auth cannot be combined with {}",
|
||||
conflicts.join(", ")
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn build_header_map(&self) -> crate::error::Result<HeaderMap> {
|
||||
let capacity = self.http_headers.as_ref().map_or(0, HashMap::len)
|
||||
+ self.env_http_headers.as_ref().map_or(0, HashMap::len);
|
||||
@@ -246,6 +350,7 @@ impl ModelProviderInfo {
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: Some(
|
||||
@@ -277,6 +382,10 @@ impl ModelProviderInfo {
|
||||
pub fn is_openai(&self) -> bool {
|
||||
self.name == OPENAI_PROVIDER_NAME
|
||||
}
|
||||
|
||||
pub(crate) fn has_command_auth(&self) -> bool {
|
||||
self.auth.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
pub const DEFAULT_LMSTUDIO_PORT: u16 = 1234;
|
||||
@@ -338,6 +447,7 @@ pub fn create_oss_provider_with_base_url(base_url: &str, wire_api: WireApi) -> M
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: None,
|
||||
wire_api,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
use super::*;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use codex_utils_absolute_path::AbsolutePathBufGuard;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_ollama_model_provider_toml() {
|
||||
@@ -13,6 +16,7 @@ base_url = "http://localhost:11434/v1"
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
@@ -43,6 +47,7 @@ query_params = { api-version = "2025-04-01-preview" }
|
||||
env_key: Some("AZURE_OPENAI_API_KEY".into()),
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: Some(maplit::hashmap! {
|
||||
"api-version".to_string() => "2025-04-01-preview".to_string(),
|
||||
@@ -76,6 +81,7 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
|
||||
env_key: Some("API_KEY".into()),
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: Some(maplit::hashmap! {
|
||||
@@ -121,3 +127,29 @@ supports_websockets = true
|
||||
let provider: ModelProviderInfo = toml::from_str(provider_toml).unwrap();
|
||||
assert_eq!(provider.websocket_connect_timeout_ms, Some(15_000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_provider_auth_config_defaults() {
|
||||
let base_dir = tempdir().unwrap();
|
||||
let provider_toml = r#"
|
||||
name = "Corp"
|
||||
|
||||
[auth]
|
||||
command = "./scripts/print-token"
|
||||
args = ["--format=text"]
|
||||
"#;
|
||||
|
||||
let _guard = AbsolutePathBufGuard::new(base_dir.path());
|
||||
let provider: ModelProviderInfo = toml::from_str(provider_toml).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
provider.auth,
|
||||
Some(ModelProviderAuthInfo {
|
||||
command: "./scripts/print-token".to_string(),
|
||||
args: vec!["--format=text".to_string()],
|
||||
timeout_ms: default_provider_auth_timeout_ms(),
|
||||
refresh_interval_ms: default_provider_auth_refresh_interval_ms(),
|
||||
cwd: AbsolutePathBuf::resolve_path_against_base(".", base_dir.path()).unwrap(),
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ use crate::api_bridge::map_api_error;
|
||||
use crate::auth::AuthManager;
|
||||
use crate::auth::AuthMode;
|
||||
use crate::auth::CodexAuth;
|
||||
use crate::auth::RefreshTokenError;
|
||||
use crate::auth::UnauthorizedRecovery;
|
||||
use crate::auth_env_telemetry::AuthEnvTelemetry;
|
||||
use crate::auth_env_telemetry::collect_auth_env_telemetry;
|
||||
use crate::config::Config;
|
||||
@@ -14,6 +16,7 @@ use crate::model_provider_info::ModelProviderInfo;
|
||||
use crate::models_manager::collaboration_mode_presets::CollaborationModesConfig;
|
||||
use crate::models_manager::collaboration_mode_presets::builtin_collaboration_mode_presets;
|
||||
use crate::models_manager::model_info;
|
||||
use crate::provider_auth::scoped_auth_manager;
|
||||
use crate::response_debug_context::extract_response_debug_context;
|
||||
use crate::response_debug_context::telemetry_transport_error_message;
|
||||
use crate::util::FeedbackRequestTags;
|
||||
@@ -212,6 +215,7 @@ impl ModelsManager {
|
||||
collaboration_modes_config: CollaborationModesConfig,
|
||||
provider: ModelProviderInfo,
|
||||
) -> Self {
|
||||
let auth_manager = scoped_auth_manager(auth_manager, &provider);
|
||||
let cache_path = codex_home.join(MODEL_CACHE_FILE);
|
||||
let cache_manager = ModelsCacheManager::new(cache_path, DEFAULT_MODEL_CACHE_TTL);
|
||||
let catalog_mode = if model_catalog.is_some() {
|
||||
@@ -396,7 +400,9 @@ impl ModelsManager {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if self.auth_manager.auth_mode() != Some(AuthMode::Chatgpt) {
|
||||
if self.auth_manager.auth_mode() != Some(AuthMode::Chatgpt)
|
||||
&& !self.provider.has_command_auth()
|
||||
{
|
||||
if matches!(
|
||||
refresh_strategy,
|
||||
RefreshStrategy::Offline | RefreshStrategy::OnlineIfUncached
|
||||
@@ -431,39 +437,72 @@ impl ModelsManager {
|
||||
async fn fetch_and_update_models(&self) -> CoreResult<()> {
|
||||
let _timer =
|
||||
codex_otel::start_global_timer("codex.remote_models.fetch_update.duration_ms", &[]);
|
||||
let auth = self.auth_manager.auth().await;
|
||||
let auth_mode = auth.as_ref().map(CodexAuth::auth_mode);
|
||||
let api_provider = self.provider.to_api_provider(auth_mode)?;
|
||||
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider)?;
|
||||
let auth_env = collect_auth_env_telemetry(
|
||||
&self.provider,
|
||||
self.auth_manager.codex_api_key_env_enabled(),
|
||||
);
|
||||
let transport = ReqwestTransport::new(build_reqwest_client());
|
||||
let request_telemetry: Arc<dyn RequestTelemetry> = Arc::new(ModelsRequestTelemetry {
|
||||
auth_mode: auth_mode.map(|mode| TelemetryAuthMode::from(mode).to_string()),
|
||||
auth_header_attached: api_auth.auth_header_attached(),
|
||||
auth_header_name: api_auth.auth_header_name(),
|
||||
auth_env,
|
||||
});
|
||||
let client = ModelsClient::new(transport, api_provider, api_auth)
|
||||
.with_telemetry(Some(request_telemetry));
|
||||
|
||||
let client_version = crate::models_manager::client_version_to_whole();
|
||||
let (models, etag) = timeout(
|
||||
MODELS_REFRESH_TIMEOUT,
|
||||
client.list_models(&client_version, HeaderMap::new()),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| CodexErr::Timeout)?
|
||||
.map_err(map_api_error)?;
|
||||
let mut auth_recovery = self.auth_manager.unauthorized_recovery();
|
||||
|
||||
self.apply_remote_models(models.clone()).await;
|
||||
*self.etag.write().await = etag.clone();
|
||||
self.cache_manager
|
||||
.persist_cache(&models, etag, client_version)
|
||||
.await;
|
||||
Ok(())
|
||||
loop {
|
||||
let auth = self.auth_manager.auth().await;
|
||||
let auth_mode = auth.as_ref().map(CodexAuth::auth_mode);
|
||||
let api_provider = self.provider.to_api_provider(auth_mode)?;
|
||||
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider)?;
|
||||
let auth_env = collect_auth_env_telemetry(
|
||||
&self.provider,
|
||||
self.auth_manager.codex_api_key_env_enabled(),
|
||||
);
|
||||
let transport = ReqwestTransport::new(build_reqwest_client());
|
||||
let request_telemetry: Arc<dyn RequestTelemetry> = Arc::new(ModelsRequestTelemetry {
|
||||
auth_mode: auth_mode.map(|mode| TelemetryAuthMode::from(mode).to_string()),
|
||||
auth_header_attached: api_auth.auth_header_attached(),
|
||||
auth_header_name: api_auth.auth_header_name(),
|
||||
auth_env,
|
||||
});
|
||||
let client = ModelsClient::new(transport, api_provider, api_auth)
|
||||
.with_telemetry(Some(request_telemetry));
|
||||
|
||||
let result = timeout(
|
||||
MODELS_REFRESH_TIMEOUT,
|
||||
client.list_models(&client_version, HeaderMap::new()),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| CodexErr::Timeout)?
|
||||
.map_err(map_api_error);
|
||||
|
||||
match result {
|
||||
Ok((models, etag)) => {
|
||||
self.apply_remote_models(models.clone()).await;
|
||||
*self.etag.write().await = etag.clone();
|
||||
self.cache_manager
|
||||
.persist_cache(&models, etag, client_version)
|
||||
.await;
|
||||
return Ok(());
|
||||
}
|
||||
Err(
|
||||
err @ CodexErr::UnexpectedStatus(crate::error::UnexpectedResponseError {
|
||||
status,
|
||||
..
|
||||
}),
|
||||
) if status == http::StatusCode::UNAUTHORIZED => {
|
||||
if !Self::recover_after_unauthorized(&mut auth_recovery).await? {
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
Err(err) => return Err(err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn recover_after_unauthorized(
|
||||
auth_recovery: &mut UnauthorizedRecovery,
|
||||
) -> CoreResult<bool> {
|
||||
if !auth_recovery.has_next() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
match auth_recovery.next().await {
|
||||
Ok(_) => Ok(true),
|
||||
Err(RefreshTokenError::Permanent(failed)) => Err(CodexErr::RefreshTokenFailed(failed)),
|
||||
Err(RefreshTokenError::Transient(other)) => Err(CodexErr::Io(other)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_etag(&self) -> Option<String> {
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use super::*;
|
||||
use crate::AuthManager;
|
||||
use crate::CodexAuth;
|
||||
use crate::auth::AuthCredentialsStoreMode;
|
||||
use crate::config::ConfigBuilder;
|
||||
use crate::model_provider_info::ModelProviderAuthInfo;
|
||||
use crate::model_provider_info::WireApi;
|
||||
use base64::Engine as _;
|
||||
use chrono::Utc;
|
||||
@@ -13,8 +15,10 @@ use http::StatusCode;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use std::collections::BTreeMap;
|
||||
use std::num::NonZeroU64;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use tempfile::TempDir;
|
||||
use tempfile::tempdir;
|
||||
use tracing::Event;
|
||||
use tracing::Subscriber;
|
||||
@@ -24,7 +28,12 @@ use tracing_subscriber::layer::Context;
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::registry::LookupSpan;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::header_regex;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
fn remote_model(slug: &str, display: &str, priority: i32) -> ModelInfo {
|
||||
remote_model_with_visibility(slug, display, priority, "list")
|
||||
@@ -79,6 +88,7 @@ fn provider_for(base_url: String) -> ModelProviderInfo {
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
@@ -92,6 +102,90 @@ fn provider_for(base_url: String) -> ModelProviderInfo {
|
||||
}
|
||||
}
|
||||
|
||||
struct ProviderAuthScript {
|
||||
tempdir: TempDir,
|
||||
command: String,
|
||||
args: Vec<String>,
|
||||
}
|
||||
|
||||
impl ProviderAuthScript {
|
||||
fn new(tokens: &[&str]) -> std::io::Result<Self> {
|
||||
let tempdir = tempfile::tempdir()?;
|
||||
let tokens_file = tempdir.path().join("tokens.txt");
|
||||
std::fs::write(&tokens_file, format!("{}\n", tokens.join("\n")))?;
|
||||
|
||||
#[cfg(unix)]
|
||||
let (command, args) = {
|
||||
let script_path = tempdir.path().join("print-token.sh");
|
||||
std::fs::write(
|
||||
&script_path,
|
||||
r#"#!/bin/sh
|
||||
first_line=$(sed -n '1p' tokens.txt)
|
||||
printf '%s\n' "$first_line"
|
||||
tail -n +2 tokens.txt > tokens.next
|
||||
mv tokens.next tokens.txt
|
||||
"#,
|
||||
)?;
|
||||
let mut permissions = std::fs::metadata(&script_path)?.permissions();
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
permissions.set_mode(0o755);
|
||||
}
|
||||
std::fs::set_permissions(&script_path, permissions)?;
|
||||
("./print-token.sh".to_string(), Vec::new())
|
||||
};
|
||||
|
||||
#[cfg(windows)]
|
||||
let (command, args) = {
|
||||
let script_path = tempdir.path().join("print-token.ps1");
|
||||
std::fs::write(
|
||||
&script_path,
|
||||
r#"$lines = Get-Content -Path tokens.txt
|
||||
if ($lines.Count -eq 0) { exit 1 }
|
||||
Write-Output $lines[0]
|
||||
$lines | Select-Object -Skip 1 | Set-Content -Path tokens.txt
|
||||
"#,
|
||||
)?;
|
||||
(
|
||||
"powershell".to_string(),
|
||||
vec![
|
||||
"-NoProfile".to_string(),
|
||||
"-ExecutionPolicy".to_string(),
|
||||
"Bypass".to_string(),
|
||||
"-File".to_string(),
|
||||
".\\print-token.ps1".to_string(),
|
||||
],
|
||||
)
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
tempdir,
|
||||
command,
|
||||
args,
|
||||
})
|
||||
}
|
||||
|
||||
fn auth_config(&self) -> ModelProviderAuthInfo {
|
||||
ModelProviderAuthInfo {
|
||||
command: self.command.clone(),
|
||||
args: self.args.clone(),
|
||||
timeout_ms: non_zero_u64(/*value*/ 1_000),
|
||||
refresh_interval_ms: non_zero_u64(/*value*/ 60_000),
|
||||
cwd: match codex_utils_absolute_path::AbsolutePathBuf::try_from(self.tempdir.path()) {
|
||||
Ok(cwd) => cwd,
|
||||
Err(err) => panic!("tempdir should be absolute: {err}"),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn non_zero_u64(value: u64) -> NonZeroU64 {
|
||||
match NonZeroU64::new(value) {
|
||||
Some(value) => value,
|
||||
None => panic!("expected non-zero value: {value}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct TagCollectorVisitor {
|
||||
tags: BTreeMap<String, String>,
|
||||
@@ -310,6 +404,57 @@ async fn refresh_available_models_sorts_by_priority() {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn refresh_available_models_refreshes_provider_auth_after_401() {
|
||||
let server = MockServer::start().await;
|
||||
let auth_script = ProviderAuthScript::new(&["first-token", "second-token"]).unwrap();
|
||||
let remote_models = vec![remote_model(
|
||||
"provider-model",
|
||||
"Provider",
|
||||
/*priority*/ 0,
|
||||
)];
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/models"))
|
||||
.and(header_regex("Authorization", "Bearer first-token"))
|
||||
.respond_with(ResponseTemplate::new(401).set_body_string("unauthorized"))
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/models"))
|
||||
.and(header_regex("Authorization", "Bearer second-token"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "application/json")
|
||||
.set_body_json(ModelsResponse {
|
||||
models: remote_models.clone(),
|
||||
}),
|
||||
)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let codex_home = tempdir().expect("temp dir");
|
||||
let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("unused"));
|
||||
let provider = ModelProviderInfo {
|
||||
auth: Some(auth_script.auth_config()),
|
||||
..provider_for(server.uri())
|
||||
};
|
||||
let manager = ModelsManager::with_provider_for_tests(
|
||||
codex_home.path().to_path_buf(),
|
||||
auth_manager,
|
||||
provider,
|
||||
);
|
||||
|
||||
manager
|
||||
.refresh_available_models(RefreshStrategy::Online)
|
||||
.await
|
||||
.expect("refresh succeeds");
|
||||
|
||||
assert_models_contain(&manager.get_remote_models().await, &remote_models);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn refresh_available_models_uses_cache_when_fresh() {
|
||||
let server = MockServer::start().await;
|
||||
|
||||
191
codex-rs/core/src/provider_auth.rs
Normal file
191
codex-rs/core/src/provider_auth.rs
Normal file
@@ -0,0 +1,191 @@
|
||||
use std::fmt;
|
||||
use std::io;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::AuthManager;
|
||||
use crate::auth::ExternalAuthRefreshContext;
|
||||
use crate::auth::ExternalAuthRefresher;
|
||||
use crate::auth::ExternalAuthTokens;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result;
|
||||
use crate::model_provider_info::ModelProviderAuthInfo;
|
||||
use crate::model_provider_info::ModelProviderInfo;
|
||||
|
||||
pub(crate) fn scoped_auth_manager_for_provider(
|
||||
auth_manager: Option<Arc<AuthManager>>,
|
||||
provider: &ModelProviderInfo,
|
||||
) -> Option<Arc<AuthManager>> {
|
||||
auth_manager.map(|auth_manager| scoped_auth_manager(auth_manager, provider))
|
||||
}
|
||||
|
||||
pub(crate) fn scoped_auth_manager(
|
||||
auth_manager: Arc<AuthManager>,
|
||||
provider: &ModelProviderInfo,
|
||||
) -> Arc<AuthManager> {
|
||||
match provider.auth.clone() {
|
||||
Some(config) => {
|
||||
auth_manager.with_external_bearer_refresher(Arc::new(ProviderAuthResolver::new(config)))
|
||||
}
|
||||
None => auth_manager,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ProviderAuthResolver {
|
||||
state: Arc<ProviderAuthState>,
|
||||
}
|
||||
|
||||
impl ProviderAuthResolver {
|
||||
fn new(config: ModelProviderAuthInfo) -> Self {
|
||||
Self {
|
||||
state: Arc::new(ProviderAuthState::new(config)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ExternalAuthRefresher for ProviderAuthResolver {
|
||||
async fn resolve(&self) -> io::Result<Option<ExternalAuthTokens>> {
|
||||
let mut cached = self.state.cached_token.lock().await;
|
||||
if let Some(cached_token) = cached.as_ref()
|
||||
&& cached_token.fetched_at.elapsed() < self.state.config.refresh_interval()
|
||||
{
|
||||
return Ok(Some(cached_token.tokens.clone()));
|
||||
}
|
||||
|
||||
let tokens = run_provider_auth_command(&self.state.config)
|
||||
.await
|
||||
.map_err(codex_err_to_io)?;
|
||||
*cached = Some(CachedProviderToken {
|
||||
tokens: tokens.clone(),
|
||||
fetched_at: Instant::now(),
|
||||
});
|
||||
Ok(Some(tokens))
|
||||
}
|
||||
|
||||
async fn refresh(
|
||||
&self,
|
||||
_context: ExternalAuthRefreshContext,
|
||||
) -> io::Result<ExternalAuthTokens> {
|
||||
let tokens = run_provider_auth_command(&self.state.config)
|
||||
.await
|
||||
.map_err(codex_err_to_io)?;
|
||||
let mut cached = self.state.cached_token.lock().await;
|
||||
*cached = Some(CachedProviderToken {
|
||||
tokens: tokens.clone(),
|
||||
fetched_at: Instant::now(),
|
||||
});
|
||||
Ok(tokens)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for ProviderAuthResolver {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("ProviderAuthResolver")
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
struct ProviderAuthState {
|
||||
config: ModelProviderAuthInfo,
|
||||
cached_token: Mutex<Option<CachedProviderToken>>,
|
||||
}
|
||||
|
||||
impl ProviderAuthState {
|
||||
fn new(config: ModelProviderAuthInfo) -> Self {
|
||||
Self {
|
||||
config,
|
||||
cached_token: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct CachedProviderToken {
|
||||
tokens: ExternalAuthTokens,
|
||||
fetched_at: Instant,
|
||||
}
|
||||
|
||||
async fn run_provider_auth_command(config: &ModelProviderAuthInfo) -> Result<ExternalAuthTokens> {
|
||||
let program = resolve_provider_auth_program(&config.command, &config.cwd)?;
|
||||
let mut command = Command::new(&program);
|
||||
command
|
||||
.args(&config.args)
|
||||
.current_dir(config.cwd.as_path())
|
||||
.stdin(Stdio::null())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.kill_on_drop(true);
|
||||
|
||||
let output = tokio::time::timeout(config.timeout(), command.output())
|
||||
.await
|
||||
.map_err(|_| {
|
||||
CodexErr::InvalidRequest(format!(
|
||||
"provider auth command `{}` timed out after {} ms",
|
||||
config.command,
|
||||
config.timeout_ms.get()
|
||||
))
|
||||
})?
|
||||
.map_err(|err| {
|
||||
CodexErr::InvalidRequest(format!(
|
||||
"provider auth command `{}` failed to start: {err}",
|
||||
config.command
|
||||
))
|
||||
})?;
|
||||
|
||||
if !output.status.success() {
|
||||
let status = output.status;
|
||||
let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
|
||||
let stderr_suffix = if stderr.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!(": {stderr}")
|
||||
};
|
||||
return Err(CodexErr::InvalidRequest(format!(
|
||||
"provider auth command `{}` exited with status {status}{stderr_suffix}",
|
||||
config.command
|
||||
)));
|
||||
}
|
||||
|
||||
let stdout = String::from_utf8(output.stdout).map_err(|_| {
|
||||
CodexErr::InvalidRequest(format!(
|
||||
"provider auth command `{}` wrote non-UTF-8 data to stdout",
|
||||
config.command
|
||||
))
|
||||
})?;
|
||||
let token = stdout.trim().to_string();
|
||||
if token.is_empty() {
|
||||
return Err(CodexErr::InvalidRequest(format!(
|
||||
"provider auth command `{}` produced an empty token",
|
||||
config.command
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(ExternalAuthTokens::access_token_only(token))
|
||||
}
|
||||
|
||||
fn resolve_provider_auth_program(command: &str, cwd: &Path) -> Result<PathBuf> {
|
||||
let path = Path::new(command);
|
||||
if path.is_absolute() || path.components().count() > 1 {
|
||||
return Ok(
|
||||
codex_utils_absolute_path::AbsolutePathBuf::resolve_path_against_base(path, cwd)?
|
||||
.into_path_buf(),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(PathBuf::from(command))
|
||||
}
|
||||
|
||||
fn codex_err_to_io(error: CodexErr) -> io::Error {
|
||||
io::Error::other(error.to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "provider_auth_tests.rs"]
|
||||
mod tests;
|
||||
139
codex-rs/core/src/provider_auth_tests.rs
Normal file
139
codex-rs/core/src/provider_auth_tests.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::num::NonZeroU64;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn caches_command_output_until_refreshed() {
|
||||
let script = ProviderAuthScript::new(&["first-token", "second-token"]).unwrap();
|
||||
let source = ProviderAuthResolver::new(script.auth_config());
|
||||
|
||||
let first = source
|
||||
.resolve()
|
||||
.await
|
||||
.unwrap()
|
||||
.map(|tokens| tokens.access_token);
|
||||
let second = source
|
||||
.resolve()
|
||||
.await
|
||||
.unwrap()
|
||||
.map(|tokens| tokens.access_token);
|
||||
let refreshed = source
|
||||
.refresh(ExternalAuthRefreshContext {
|
||||
reason: crate::auth::ExternalAuthRefreshReason::Unauthorized,
|
||||
previous_account_id: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let after_refresh = source
|
||||
.resolve()
|
||||
.await
|
||||
.unwrap()
|
||||
.map(|tokens| tokens.access_token);
|
||||
|
||||
assert_eq!(first.as_deref(), Some("first-token"));
|
||||
assert_eq!(second.as_deref(), Some("first-token"));
|
||||
assert_eq!(refreshed.access_token, "second-token");
|
||||
assert_eq!(after_refresh.as_deref(), Some("second-token"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn refresh_returns_bearer_only_external_auth_tokens() {
|
||||
let script = ProviderAuthScript::new(&["first-token"]).unwrap();
|
||||
let source = ProviderAuthResolver::new(script.auth_config());
|
||||
|
||||
let tokens = source
|
||||
.refresh(ExternalAuthRefreshContext {
|
||||
reason: crate::auth::ExternalAuthRefreshReason::Unauthorized,
|
||||
previous_account_id: Some("ignored".to_string()),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(tokens.access_token, "first-token");
|
||||
assert_eq!(tokens.chatgpt_metadata, None);
|
||||
}
|
||||
|
||||
struct ProviderAuthScript {
|
||||
tempdir: TempDir,
|
||||
command: String,
|
||||
args: Vec<String>,
|
||||
}
|
||||
|
||||
impl ProviderAuthScript {
|
||||
fn new(tokens: &[&str]) -> std::io::Result<Self> {
|
||||
let tempdir = tempfile::tempdir()?;
|
||||
let token_file = tempdir.path().join("tokens.txt");
|
||||
std::fs::write(&token_file, format!("{}\n", tokens.join("\n")))?;
|
||||
|
||||
#[cfg(unix)]
|
||||
let (command, args) = {
|
||||
let script_path = tempdir.path().join("print-token.sh");
|
||||
std::fs::write(
|
||||
&script_path,
|
||||
r#"#!/bin/sh
|
||||
first_line=$(sed -n '1p' tokens.txt)
|
||||
printf '%s\n' "$first_line"
|
||||
tail -n +2 tokens.txt > tokens.next
|
||||
mv tokens.next tokens.txt
|
||||
"#,
|
||||
)?;
|
||||
let mut permissions = std::fs::metadata(&script_path)?.permissions();
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
permissions.set_mode(0o755);
|
||||
}
|
||||
std::fs::set_permissions(&script_path, permissions)?;
|
||||
("./print-token.sh".to_string(), Vec::new())
|
||||
};
|
||||
|
||||
#[cfg(windows)]
|
||||
let (command, args) = {
|
||||
let script_path = tempdir.path().join("print-token.ps1");
|
||||
std::fs::write(
|
||||
&script_path,
|
||||
r#"$lines = Get-Content -Path tokens.txt
|
||||
if ($lines.Count -eq 0) { exit 1 }
|
||||
Write-Output $lines[0]
|
||||
$lines | Select-Object -Skip 1 | Set-Content -Path tokens.txt
|
||||
"#,
|
||||
)?;
|
||||
(
|
||||
"powershell".to_string(),
|
||||
vec![
|
||||
"-NoProfile".to_string(),
|
||||
"-ExecutionPolicy".to_string(),
|
||||
"Bypass".to_string(),
|
||||
"-File".to_string(),
|
||||
".\\print-token.ps1".to_string(),
|
||||
],
|
||||
)
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
tempdir,
|
||||
command,
|
||||
args,
|
||||
})
|
||||
}
|
||||
|
||||
fn auth_config(&self) -> ModelProviderAuthInfo {
|
||||
ModelProviderAuthInfo {
|
||||
command: self.command.clone(),
|
||||
args: self.args.clone(),
|
||||
timeout_ms: non_zero_u64(/*value*/ 1_000),
|
||||
refresh_interval_ms: non_zero_u64(/*value*/ 60_000),
|
||||
cwd: match codex_utils_absolute_path::AbsolutePathBuf::try_from(self.tempdir.path()) {
|
||||
Ok(cwd) => cwd,
|
||||
Err(err) => panic!("tempdir should be absolute: {err}"),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn non_zero_u64(value: u64) -> NonZeroU64 {
|
||||
match NonZeroU64::new(value) {
|
||||
Some(value) => value,
|
||||
None => panic!("expected non-zero value: {value}"),
|
||||
}
|
||||
}
|
||||
@@ -452,7 +452,12 @@ async fn prepare_realtime_start(
|
||||
params: ConversationStartParams,
|
||||
) -> CodexResult<PreparedRealtimeConversationStart> {
|
||||
let provider = sess.provider().await;
|
||||
let auth = sess.services.auth_manager.auth().await;
|
||||
let auth_manager = sess
|
||||
.services
|
||||
.model_client
|
||||
.auth_manager()
|
||||
.unwrap_or_else(|| Arc::clone(&sess.services.auth_manager));
|
||||
let auth = auth_manager.auth().await;
|
||||
let realtime_api_key = realtime_api_key(auth.as_ref(), &provider)?;
|
||||
let mut api_provider = provider.to_api_provider(Some(crate::auth::AuthMode::ApiKey))?;
|
||||
let config = sess.get_config().await;
|
||||
|
||||
@@ -46,6 +46,7 @@ async fn responses_stream_includes_subagent_header_on_review() {
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
@@ -158,6 +159,7 @@ async fn responses_stream_includes_subagent_header_on_other() {
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
@@ -265,6 +267,7 @@ async fn responses_respects_model_info_overrides_from_config() {
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use codex_core::AuthManager;
|
||||
use codex_core::CodexAuth;
|
||||
use codex_core::ModelClient;
|
||||
use codex_core::ModelProviderAuthInfo;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::NewThread;
|
||||
use codex_core::Prompt;
|
||||
@@ -64,6 +66,7 @@ use futures::StreamExt;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use std::io::Write;
|
||||
use std::num::NonZeroU64;
|
||||
use std::sync::Arc;
|
||||
use tempfile::TempDir;
|
||||
use uuid::Uuid;
|
||||
@@ -143,6 +146,90 @@ fn write_auth_json(
|
||||
fake_jwt
|
||||
}
|
||||
|
||||
struct ProviderAuthCommandFixture {
|
||||
tempdir: TempDir,
|
||||
command: String,
|
||||
args: Vec<String>,
|
||||
}
|
||||
|
||||
impl ProviderAuthCommandFixture {
|
||||
fn new(tokens: &[&str]) -> std::io::Result<Self> {
|
||||
let tempdir = tempfile::tempdir()?;
|
||||
let tokens_file = tempdir.path().join("tokens.txt");
|
||||
std::fs::write(&tokens_file, format!("{}\n", tokens.join("\n")))?;
|
||||
|
||||
#[cfg(unix)]
|
||||
let (command, args) = {
|
||||
let script_path = tempdir.path().join("print-token.sh");
|
||||
std::fs::write(
|
||||
&script_path,
|
||||
r#"#!/bin/sh
|
||||
first_line=$(sed -n '1p' tokens.txt)
|
||||
printf '%s\n' "$first_line"
|
||||
tail -n +2 tokens.txt > tokens.next
|
||||
mv tokens.next tokens.txt
|
||||
"#,
|
||||
)?;
|
||||
let mut permissions = std::fs::metadata(&script_path)?.permissions();
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
permissions.set_mode(0o755);
|
||||
}
|
||||
std::fs::set_permissions(&script_path, permissions)?;
|
||||
("./print-token.sh".to_string(), Vec::new())
|
||||
};
|
||||
|
||||
#[cfg(windows)]
|
||||
let (command, args) = {
|
||||
let script_path = tempdir.path().join("print-token.ps1");
|
||||
std::fs::write(
|
||||
&script_path,
|
||||
r#"$lines = Get-Content -Path tokens.txt
|
||||
if ($lines.Count -eq 0) { exit 1 }
|
||||
Write-Output $lines[0]
|
||||
$lines | Select-Object -Skip 1 | Set-Content -Path tokens.txt
|
||||
"#,
|
||||
)?;
|
||||
(
|
||||
"powershell".to_string(),
|
||||
vec![
|
||||
"-NoProfile".to_string(),
|
||||
"-ExecutionPolicy".to_string(),
|
||||
"Bypass".to_string(),
|
||||
"-File".to_string(),
|
||||
".\\print-token.ps1".to_string(),
|
||||
],
|
||||
)
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
tempdir,
|
||||
command,
|
||||
args,
|
||||
})
|
||||
}
|
||||
|
||||
fn auth(&self) -> ModelProviderAuthInfo {
|
||||
ModelProviderAuthInfo {
|
||||
command: self.command.clone(),
|
||||
args: self.args.clone(),
|
||||
timeout_ms: non_zero_u64(/*value*/ 1_000),
|
||||
refresh_interval_ms: non_zero_u64(/*value*/ 60_000),
|
||||
cwd: match codex_utils_absolute_path::AbsolutePathBuf::try_from(self.tempdir.path()) {
|
||||
Ok(cwd) => cwd,
|
||||
Err(err) => panic!("tempdir should be absolute: {err}"),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn non_zero_u64(value: u64) -> NonZeroU64 {
|
||||
match NonZeroU64::new(value) {
|
||||
Some(value) => value,
|
||||
None => panic!("expected non-zero value: {value}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn resume_includes_initial_messages_and_sends_prior_items() {
|
||||
skip_if_no_network!();
|
||||
@@ -659,6 +746,227 @@ async fn includes_conversation_id_and_model_headers_in_request() {
|
||||
assert_eq!(request_authorization, "Bearer Test API Key");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn provider_auth_command_supplies_bearer_token() {
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = MockServer::start().await;
|
||||
let resp_mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![ev_response_created("resp1"), ev_completed("resp1")]),
|
||||
)
|
||||
.await;
|
||||
let auth_fixture = ProviderAuthCommandFixture::new(&["command-token"]).unwrap();
|
||||
|
||||
let provider = ModelProviderInfo {
|
||||
name: "corp".into(),
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: Some(auth_fixture.auth()),
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(5_000),
|
||||
websocket_connect_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
supports_websockets: false,
|
||||
};
|
||||
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home).await;
|
||||
config.model_provider_id = provider.name.clone();
|
||||
config.model_provider = provider.clone();
|
||||
let effort = config.model_reasoning_effort;
|
||||
let summary = config.model_reasoning_summary;
|
||||
let model = codex_core::test_support::get_model_offline(config.model.as_deref());
|
||||
config.model = Some(model.clone());
|
||||
let config = Arc::new(config);
|
||||
let model_info =
|
||||
codex_core::test_support::construct_model_info_offline(model.as_str(), &config);
|
||||
let conversation_id = ThreadId::new();
|
||||
let session_telemetry = SessionTelemetry::new(
|
||||
conversation_id,
|
||||
model.as_str(),
|
||||
model_info.slug.as_str(),
|
||||
/*account_id*/ None,
|
||||
Some("test@test.com".to_string()),
|
||||
/*auth_mode*/ None,
|
||||
"test_originator".to_string(),
|
||||
/*log_user_prompts*/ false,
|
||||
"test".to_string(),
|
||||
SessionSource::Exec,
|
||||
);
|
||||
let client = ModelClient::new(
|
||||
Some(AuthManager::from_auth_for_testing(CodexAuth::from_api_key(
|
||||
"unused-api-key",
|
||||
))),
|
||||
conversation_id,
|
||||
provider,
|
||||
SessionSource::Exec,
|
||||
config.model_verbosity,
|
||||
/*enable_request_compression*/ false,
|
||||
/*include_timing_metrics*/ false,
|
||||
/*beta_features_header*/ None,
|
||||
);
|
||||
let mut client_session = client.new_session();
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input.push(ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hello".to_string(),
|
||||
}],
|
||||
end_turn: None,
|
||||
phase: None,
|
||||
});
|
||||
|
||||
let mut stream = client_session
|
||||
.stream(
|
||||
&prompt,
|
||||
&model_info,
|
||||
&session_telemetry,
|
||||
effort,
|
||||
summary.unwrap_or(ReasoningSummary::Auto),
|
||||
/*service_tier*/ None,
|
||||
/*turn_metadata_header*/ None,
|
||||
)
|
||||
.await
|
||||
.expect("responses stream to start");
|
||||
|
||||
while let Some(event) = stream.next().await {
|
||||
if let Ok(ResponseEvent::Completed { .. }) = event {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let request = resp_mock.single_request();
|
||||
let request_authorization = request
|
||||
.header("authorization")
|
||||
.expect("authorization header");
|
||||
assert_eq!(request_authorization, "Bearer command-token");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn provider_auth_command_refreshes_after_401() {
|
||||
skip_if_no_network!();
|
||||
|
||||
let server = MockServer::start().await;
|
||||
let auth_fixture = ProviderAuthCommandFixture::new(&["first-token", "second-token"]).unwrap();
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(header_regex("Authorization", "Bearer first-token"))
|
||||
.respond_with(ResponseTemplate::new(401).set_body_string("unauthorized"))
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.and(header_regex("Authorization", "Bearer second-token"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(
|
||||
sse(vec![ev_response_created("resp1"), ev_completed("resp1")]),
|
||||
"text/event-stream",
|
||||
),
|
||||
)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let provider = ModelProviderInfo {
|
||||
name: "corp".into(),
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: Some(auth_fixture.auth()),
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: Some(5_000),
|
||||
websocket_connect_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
supports_websockets: false,
|
||||
};
|
||||
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home).await;
|
||||
config.model_provider_id = provider.name.clone();
|
||||
config.model_provider = provider.clone();
|
||||
let effort = config.model_reasoning_effort;
|
||||
let summary = config.model_reasoning_summary;
|
||||
let model = codex_core::test_support::get_model_offline(config.model.as_deref());
|
||||
config.model = Some(model.clone());
|
||||
let config = Arc::new(config);
|
||||
let model_info =
|
||||
codex_core::test_support::construct_model_info_offline(model.as_str(), &config);
|
||||
let conversation_id = ThreadId::new();
|
||||
let session_telemetry = SessionTelemetry::new(
|
||||
conversation_id,
|
||||
model.as_str(),
|
||||
model_info.slug.as_str(),
|
||||
/*account_id*/ None,
|
||||
Some("test@test.com".to_string()),
|
||||
/*auth_mode*/ None,
|
||||
"test_originator".to_string(),
|
||||
/*log_user_prompts*/ false,
|
||||
"test".to_string(),
|
||||
SessionSource::Exec,
|
||||
);
|
||||
let client = ModelClient::new(
|
||||
Some(AuthManager::from_auth_for_testing(CodexAuth::from_api_key(
|
||||
"unused-api-key",
|
||||
))),
|
||||
conversation_id,
|
||||
provider,
|
||||
SessionSource::Exec,
|
||||
config.model_verbosity,
|
||||
/*enable_request_compression*/ false,
|
||||
/*include_timing_metrics*/ false,
|
||||
/*beta_features_header*/ None,
|
||||
);
|
||||
let mut client_session = client.new_session();
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input.push(ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hello".to_string(),
|
||||
}],
|
||||
end_turn: None,
|
||||
phase: None,
|
||||
});
|
||||
|
||||
let mut stream = client_session
|
||||
.stream(
|
||||
&prompt,
|
||||
&model_info,
|
||||
&session_telemetry,
|
||||
effort,
|
||||
summary.unwrap_or(ReasoningSummary::Auto),
|
||||
/*service_tier*/ None,
|
||||
/*turn_metadata_header*/ None,
|
||||
)
|
||||
.await
|
||||
.expect("responses stream to start");
|
||||
|
||||
while let Some(event) = stream.next().await {
|
||||
if let Ok(ResponseEvent::Completed { .. }) = event {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn includes_base_instructions_override_in_request() {
|
||||
skip_if_no_network!();
|
||||
@@ -1796,6 +2104,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
@@ -2396,6 +2705,7 @@ async fn azure_overrides_assign_properties_used_for_responses_url() {
|
||||
// Reuse the existing environment variable to avoid using unsafe code
|
||||
env_key: Some(existing_env_var_with_random_value.to_string()),
|
||||
experimental_bearer_token: None,
|
||||
auth: None,
|
||||
query_params: Some(std::collections::HashMap::from([(
|
||||
"api-version".to_string(),
|
||||
"2025-04-01-preview".to_string(),
|
||||
@@ -2486,6 +2796,7 @@ async fn env_var_overrides_loaded_auth() {
|
||||
)])),
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: None,
|
||||
wire_api: WireApi::Responses,
|
||||
http_headers: Some(std::collections::HashMap::from([(
|
||||
"Custom-Header".to_string(),
|
||||
|
||||
@@ -1674,6 +1674,7 @@ fn websocket_provider_with_connect_timeout(
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
|
||||
@@ -69,6 +69,7 @@ async fn continue_after_stream_error() {
|
||||
env_key: Some("PATH".into()),
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
|
||||
@@ -53,6 +53,7 @@ async fn retries_on_early_close() {
|
||||
env_key: Some("PATH".into()),
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
auth: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
|
||||
@@ -4,6 +4,7 @@ use crate::auth::storage::get_auth_file;
|
||||
use crate::token_data::IdTokenInfo;
|
||||
use crate::token_data::KnownPlan as InternalKnownPlan;
|
||||
use crate::token_data::PlanType as InternalPlanType;
|
||||
use async_trait::async_trait;
|
||||
use codex_protocol::account::PlanType as AccountPlanType;
|
||||
|
||||
use base64::Engine;
|
||||
@@ -12,6 +13,7 @@ use pretty_assertions::assert_eq;
|
||||
use serde::Serialize;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[tokio::test]
|
||||
@@ -252,6 +254,100 @@ fn refresh_failure_is_scoped_to_the_matching_auth_snapshot() {
|
||||
assert_eq!(manager.refresh_failure_for_auth(&updated_auth), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn external_auth_tokens_without_chatgpt_metadata_cannot_seed_chatgpt_auth() {
|
||||
let err = AuthDotJson::from_external_tokens(&ExternalAuthTokens::access_token_only(
|
||||
"test-access-token",
|
||||
))
|
||||
.expect_err("bearer-only external auth should not seed ChatGPT auth");
|
||||
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"external auth tokens are missing ChatGPT metadata"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_manager_with_external_bearer_refresher_returns_provider_token_only_for_derived_manager()
|
||||
{
|
||||
let base_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("base-token"));
|
||||
let derived_manager =
|
||||
base_manager.with_external_bearer_refresher(Arc::new(StaticExternalAuthRefresher::new(
|
||||
Some(ExternalAuthTokens::access_token_only("provider-token")),
|
||||
ExternalAuthTokens::access_token_only("refreshed-provider-token"),
|
||||
)));
|
||||
|
||||
assert_eq!(
|
||||
base_manager
|
||||
.auth()
|
||||
.await
|
||||
.and_then(|auth| auth.api_key().map(str::to_string)),
|
||||
Some("base-token".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
derived_manager
|
||||
.auth()
|
||||
.await
|
||||
.and_then(|auth| auth.api_key().map(str::to_string)),
|
||||
Some("provider-token".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unauthorized_recovery_uses_external_refresh_for_bearer_manager() {
|
||||
let base_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("base-token"));
|
||||
let refresher = Arc::new(StaticExternalAuthRefresher::new(
|
||||
Some(ExternalAuthTokens::access_token_only("provider-token")),
|
||||
ExternalAuthTokens::access_token_only("refreshed-provider-token"),
|
||||
));
|
||||
let derived_manager = base_manager.with_external_bearer_refresher(refresher.clone());
|
||||
let mut recovery = derived_manager.unauthorized_recovery();
|
||||
|
||||
assert!(recovery.has_next());
|
||||
assert_eq!(recovery.mode_name(), "external");
|
||||
assert_eq!(recovery.step_name(), "external_refresh");
|
||||
|
||||
let result = recovery
|
||||
.next()
|
||||
.await
|
||||
.expect("external refresh should succeed");
|
||||
|
||||
assert_eq!(result.auth_state_changed(), Some(true));
|
||||
assert_eq!(*refresher.refresh_calls.lock().unwrap(), 1);
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct StaticExternalAuthRefresher {
|
||||
resolved: Option<ExternalAuthTokens>,
|
||||
refreshed: ExternalAuthTokens,
|
||||
refresh_calls: Mutex<usize>,
|
||||
}
|
||||
|
||||
impl StaticExternalAuthRefresher {
|
||||
fn new(resolved: Option<ExternalAuthTokens>, refreshed: ExternalAuthTokens) -> Self {
|
||||
Self {
|
||||
resolved,
|
||||
refreshed,
|
||||
refresh_calls: Mutex::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ExternalAuthRefresher for StaticExternalAuthRefresher {
|
||||
async fn resolve(&self) -> std::io::Result<Option<ExternalAuthTokens>> {
|
||||
Ok(self.resolved.clone())
|
||||
}
|
||||
|
||||
async fn refresh(
|
||||
&self,
|
||||
_context: ExternalAuthRefreshContext,
|
||||
) -> std::io::Result<ExternalAuthTokens> {
|
||||
*self.refresh_calls.lock().unwrap() += 1;
|
||||
Ok(self.refreshed.clone())
|
||||
}
|
||||
}
|
||||
|
||||
struct AuthFileParams {
|
||||
openai_api_key: Option<String>,
|
||||
chatgpt_plan_type: Option<String>,
|
||||
|
||||
@@ -90,11 +90,43 @@ pub enum RefreshTokenError {
|
||||
Transient(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct ExternalAuthChatgptMetadata {
|
||||
pub account_id: String,
|
||||
pub plan_type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct ExternalAuthTokens {
|
||||
pub access_token: String,
|
||||
pub chatgpt_account_id: String,
|
||||
pub chatgpt_plan_type: Option<String>,
|
||||
pub chatgpt_metadata: Option<ExternalAuthChatgptMetadata>,
|
||||
}
|
||||
|
||||
impl ExternalAuthTokens {
|
||||
pub fn access_token_only(access_token: impl Into<String>) -> Self {
|
||||
Self {
|
||||
access_token: access_token.into(),
|
||||
chatgpt_metadata: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn chatgpt(
|
||||
access_token: impl Into<String>,
|
||||
chatgpt_account_id: impl Into<String>,
|
||||
chatgpt_plan_type: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
access_token: access_token.into(),
|
||||
chatgpt_metadata: Some(ExternalAuthChatgptMetadata {
|
||||
account_id: chatgpt_account_id.into(),
|
||||
plan_type: chatgpt_plan_type,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn chatgpt_metadata(&self) -> Option<&ExternalAuthChatgptMetadata> {
|
||||
self.chatgpt_metadata.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
@@ -110,6 +142,10 @@ pub struct ExternalAuthRefreshContext {
|
||||
|
||||
#[async_trait]
|
||||
pub trait ExternalAuthRefresher: Send + Sync {
|
||||
async fn resolve(&self) -> std::io::Result<Option<ExternalAuthTokens>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn refresh(
|
||||
&self,
|
||||
context: ExternalAuthRefreshContext,
|
||||
@@ -736,11 +772,16 @@ fn refresh_token_endpoint() -> String {
|
||||
|
||||
impl AuthDotJson {
|
||||
fn from_external_tokens(external: &ExternalAuthTokens) -> std::io::Result<Self> {
|
||||
let Some(chatgpt_metadata) = external.chatgpt_metadata() else {
|
||||
return Err(std::io::Error::other(
|
||||
"external auth tokens are missing ChatGPT metadata",
|
||||
));
|
||||
};
|
||||
let mut token_info =
|
||||
parse_chatgpt_jwt_claims(&external.access_token).map_err(std::io::Error::other)?;
|
||||
token_info.chatgpt_account_id = Some(external.chatgpt_account_id.clone());
|
||||
token_info.chatgpt_plan_type = external
|
||||
.chatgpt_plan_type
|
||||
token_info.chatgpt_account_id = Some(chatgpt_metadata.account_id.clone());
|
||||
token_info.chatgpt_plan_type = chatgpt_metadata
|
||||
.plan_type
|
||||
.as_deref()
|
||||
.map(InternalPlanType::from_raw_value)
|
||||
.or(token_info.chatgpt_plan_type)
|
||||
@@ -749,7 +790,7 @@ impl AuthDotJson {
|
||||
id_token: token_info,
|
||||
access_token: external.access_token.clone(),
|
||||
refresh_token: String::new(),
|
||||
account_id: Some(external.chatgpt_account_id.clone()),
|
||||
account_id: Some(chatgpt_metadata.account_id.clone()),
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
@@ -765,11 +806,11 @@ impl AuthDotJson {
|
||||
chatgpt_account_id: &str,
|
||||
chatgpt_plan_type: Option<&str>,
|
||||
) -> std::io::Result<Self> {
|
||||
let external = ExternalAuthTokens {
|
||||
access_token: access_token.to_string(),
|
||||
chatgpt_account_id: chatgpt_account_id.to_string(),
|
||||
chatgpt_plan_type: chatgpt_plan_type.map(str::to_string),
|
||||
};
|
||||
let external = ExternalAuthTokens::chatgpt(
|
||||
access_token,
|
||||
chatgpt_account_id,
|
||||
chatgpt_plan_type.map(str::to_string),
|
||||
);
|
||||
Self::from_external_tokens(&external)
|
||||
}
|
||||
|
||||
@@ -799,8 +840,6 @@ impl AuthDotJson {
|
||||
#[derive(Clone)]
|
||||
struct CachedAuth {
|
||||
auth: Option<CodexAuth>,
|
||||
/// Callback used to refresh external auth by asking the parent app for new tokens.
|
||||
external_refresher: Option<Arc<dyn ExternalAuthRefresher>>,
|
||||
/// Permanent refresh failure cached for the current auth snapshot so
|
||||
/// later refresh attempts for the same credentials fail fast without network.
|
||||
permanent_refresh_failure: Option<AuthScopedRefreshFailure>,
|
||||
@@ -812,6 +851,27 @@ struct AuthScopedRefreshFailure {
|
||||
error: RefreshTokenFailedError,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
enum ExternalAuthKind {
|
||||
Bearer,
|
||||
Chatgpt,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ExternalAuthHandle {
|
||||
kind: ExternalAuthKind,
|
||||
refresher: Arc<dyn ExternalAuthRefresher>,
|
||||
}
|
||||
|
||||
impl Debug for ExternalAuthHandle {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ExternalAuthHandle")
|
||||
.field("kind", &self.kind)
|
||||
.field("refresher", &"present")
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for CachedAuth {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("CachedAuth")
|
||||
@@ -819,10 +879,6 @@ impl Debug for CachedAuth {
|
||||
"auth_mode",
|
||||
&self.auth.as_ref().map(CodexAuth::api_auth_mode),
|
||||
)
|
||||
.field(
|
||||
"external_refresher",
|
||||
&self.external_refresher.as_ref().map(|_| "present"),
|
||||
)
|
||||
.field(
|
||||
"permanent_refresh_failure",
|
||||
&self
|
||||
@@ -866,9 +922,14 @@ enum UnauthorizedRecoveryMode {
|
||||
// 2. Attempt to refresh the token using OAuth token refresh flow.
|
||||
// If after both steps the server still responds with 401 we let the error bubble to the user.
|
||||
//
|
||||
// For external ChatGPT auth tokens (chatgptAuthTokens), UnauthorizedRecovery does not touch disk or refresh
|
||||
// tokens locally. Instead it calls the ExternalAuthRefresher (account/chatgptAuthTokens/refresh) to ask the
|
||||
// parent app for new tokens, stores them in the ephemeral auth store, and retries once.
|
||||
// For external auth sources, UnauthorizedRecovery delegates to the configured
|
||||
// ExternalAuthRefresher and retries once.
|
||||
//
|
||||
// - External ChatGPT auth tokens (`chatgptAuthTokens`) are refreshed by asking
|
||||
// the parent app for new tokens, persisting them in the ephemeral auth
|
||||
// store, and reloading the cached auth snapshot.
|
||||
// - External bearer auth sources resolve bearer-only tokens for custom model
|
||||
// providers and refresh them without touching disk.
|
||||
pub struct UnauthorizedRecovery {
|
||||
manager: Arc<AuthManager>,
|
||||
step: UnauthorizedRecoveryStep,
|
||||
@@ -891,9 +952,10 @@ impl UnauthorizedRecovery {
|
||||
fn new(manager: Arc<AuthManager>) -> Self {
|
||||
let cached_auth = manager.auth_cached();
|
||||
let expected_account_id = cached_auth.as_ref().and_then(CodexAuth::get_account_id);
|
||||
let mode = if cached_auth
|
||||
.as_ref()
|
||||
.is_some_and(CodexAuth::is_external_chatgpt_tokens)
|
||||
let mode = if manager.external_auth_kind() == Some(ExternalAuthKind::Bearer)
|
||||
|| cached_auth
|
||||
.as_ref()
|
||||
.is_some_and(CodexAuth::is_external_chatgpt_tokens)
|
||||
{
|
||||
UnauthorizedRecoveryMode::External
|
||||
} else {
|
||||
@@ -912,6 +974,10 @@ impl UnauthorizedRecovery {
|
||||
}
|
||||
|
||||
pub fn has_next(&self) -> bool {
|
||||
if self.manager.external_auth_kind() == Some(ExternalAuthKind::Bearer) {
|
||||
return !matches!(self.step, UnauthorizedRecoveryStep::Done);
|
||||
}
|
||||
|
||||
if !self
|
||||
.manager
|
||||
.auth_cached()
|
||||
@@ -931,6 +997,16 @@ impl UnauthorizedRecovery {
|
||||
}
|
||||
|
||||
pub fn unavailable_reason(&self) -> &'static str {
|
||||
if self.manager.external_auth_kind() == Some(ExternalAuthKind::Bearer) {
|
||||
return if matches!(self.step, UnauthorizedRecoveryStep::Done) {
|
||||
"recovery_exhausted"
|
||||
} else if self.manager.has_external_auth_refresher() {
|
||||
"ready"
|
||||
} else {
|
||||
"no_external_refresher"
|
||||
};
|
||||
}
|
||||
|
||||
if !self
|
||||
.manager
|
||||
.auth_cached()
|
||||
@@ -1039,11 +1115,12 @@ impl UnauthorizedRecovery {
|
||||
#[derive(Debug)]
|
||||
pub struct AuthManager {
|
||||
codex_home: PathBuf,
|
||||
inner: RwLock<CachedAuth>,
|
||||
inner: Arc<RwLock<CachedAuth>>,
|
||||
enable_codex_api_key_env: bool,
|
||||
auth_credentials_store_mode: AuthCredentialsStoreMode,
|
||||
forced_chatgpt_workspace_id: RwLock<Option<String>>,
|
||||
refresh_lock: AsyncMutex<()>,
|
||||
forced_chatgpt_workspace_id: Arc<RwLock<Option<String>>>,
|
||||
refresh_lock: Arc<AsyncMutex<()>>,
|
||||
external_auth: RwLock<Option<ExternalAuthHandle>>,
|
||||
}
|
||||
|
||||
impl AuthManager {
|
||||
@@ -1065,15 +1142,15 @@ impl AuthManager {
|
||||
.flatten();
|
||||
Self {
|
||||
codex_home,
|
||||
inner: RwLock::new(CachedAuth {
|
||||
inner: Arc::new(RwLock::new(CachedAuth {
|
||||
auth: managed_auth,
|
||||
external_refresher: None,
|
||||
permanent_refresh_failure: None,
|
||||
}),
|
||||
})),
|
||||
enable_codex_api_key_env,
|
||||
auth_credentials_store_mode,
|
||||
forced_chatgpt_workspace_id: RwLock::new(None),
|
||||
refresh_lock: AsyncMutex::new(()),
|
||||
forced_chatgpt_workspace_id: Arc::new(RwLock::new(None)),
|
||||
refresh_lock: Arc::new(AsyncMutex::new(())),
|
||||
external_auth: RwLock::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1081,17 +1158,17 @@ impl AuthManager {
|
||||
pub fn from_auth_for_testing(auth: CodexAuth) -> Arc<Self> {
|
||||
let cached = CachedAuth {
|
||||
auth: Some(auth),
|
||||
external_refresher: None,
|
||||
permanent_refresh_failure: None,
|
||||
};
|
||||
|
||||
Arc::new(Self {
|
||||
codex_home: PathBuf::from("non-existent"),
|
||||
inner: RwLock::new(cached),
|
||||
inner: Arc::new(RwLock::new(cached)),
|
||||
enable_codex_api_key_env: false,
|
||||
auth_credentials_store_mode: AuthCredentialsStoreMode::File,
|
||||
forced_chatgpt_workspace_id: RwLock::new(None),
|
||||
refresh_lock: AsyncMutex::new(()),
|
||||
forced_chatgpt_workspace_id: Arc::new(RwLock::new(None)),
|
||||
refresh_lock: Arc::new(AsyncMutex::new(())),
|
||||
external_auth: RwLock::new(None),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1099,16 +1176,16 @@ impl AuthManager {
|
||||
pub fn from_auth_for_testing_with_home(auth: CodexAuth, codex_home: PathBuf) -> Arc<Self> {
|
||||
let cached = CachedAuth {
|
||||
auth: Some(auth),
|
||||
external_refresher: None,
|
||||
permanent_refresh_failure: None,
|
||||
};
|
||||
Arc::new(Self {
|
||||
codex_home,
|
||||
inner: RwLock::new(cached),
|
||||
inner: Arc::new(RwLock::new(cached)),
|
||||
enable_codex_api_key_env: false,
|
||||
auth_credentials_store_mode: AuthCredentialsStoreMode::File,
|
||||
forced_chatgpt_workspace_id: RwLock::new(None),
|
||||
refresh_lock: AsyncMutex::new(()),
|
||||
forced_chatgpt_workspace_id: Arc::new(RwLock::new(None)),
|
||||
refresh_lock: Arc::new(AsyncMutex::new(())),
|
||||
external_auth: RwLock::new(None),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1131,6 +1208,10 @@ impl AuthManager {
|
||||
/// For stale managed ChatGPT auth, first performs a guarded reload and then
|
||||
/// refreshes only if the on-disk auth is unchanged.
|
||||
pub async fn auth(&self) -> Option<CodexAuth> {
|
||||
if let Some(auth) = self.resolve_external_api_key_auth().await {
|
||||
return Some(auth);
|
||||
}
|
||||
|
||||
let auth = self.auth_cached()?;
|
||||
if Self::is_stale_for_proactive_refresh(&auth)
|
||||
&& let Err(err) = self.refresh_token().await
|
||||
@@ -1251,17 +1332,38 @@ impl AuthManager {
|
||||
}
|
||||
|
||||
pub fn set_external_auth_refresher(&self, refresher: Arc<dyn ExternalAuthRefresher>) {
|
||||
if let Ok(mut guard) = self.inner.write() {
|
||||
guard.external_refresher = Some(refresher);
|
||||
if let Ok(mut guard) = self.external_auth.write() {
|
||||
*guard = Some(ExternalAuthHandle {
|
||||
kind: ExternalAuthKind::Chatgpt,
|
||||
refresher,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clear_external_auth_refresher(&self) {
|
||||
if let Ok(mut guard) = self.inner.write() {
|
||||
guard.external_refresher = None;
|
||||
if let Ok(mut guard) = self.external_auth.write() {
|
||||
*guard = None;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_external_bearer_refresher(
|
||||
self: &Arc<Self>,
|
||||
refresher: Arc<dyn ExternalAuthRefresher>,
|
||||
) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
codex_home: self.codex_home.clone(),
|
||||
inner: Arc::clone(&self.inner),
|
||||
enable_codex_api_key_env: self.enable_codex_api_key_env,
|
||||
auth_credentials_store_mode: self.auth_credentials_store_mode,
|
||||
forced_chatgpt_workspace_id: Arc::clone(&self.forced_chatgpt_workspace_id),
|
||||
refresh_lock: Arc::clone(&self.refresh_lock),
|
||||
external_auth: RwLock::new(Some(ExternalAuthHandle {
|
||||
kind: ExternalAuthKind::Bearer,
|
||||
refresher,
|
||||
})),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn set_forced_chatgpt_workspace_id(&self, workspace_id: Option<String>) {
|
||||
if let Ok(mut guard) = self.forced_chatgpt_workspace_id.write() {
|
||||
*guard = workspace_id;
|
||||
@@ -1276,13 +1378,17 @@ impl AuthManager {
|
||||
}
|
||||
|
||||
pub fn has_external_auth_refresher(&self) -> bool {
|
||||
self.inner
|
||||
self.external_auth
|
||||
.read()
|
||||
.ok()
|
||||
.map(|guard| guard.external_refresher.is_some())
|
||||
.map(|guard| guard.is_some())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
pub fn has_external_bearer_refresher(&self) -> bool {
|
||||
self.external_auth_kind() == Some(ExternalAuthKind::Bearer)
|
||||
}
|
||||
|
||||
pub fn is_external_auth_active(&self) -> bool {
|
||||
self.auth_cached()
|
||||
.as_ref()
|
||||
@@ -1310,6 +1416,35 @@ impl AuthManager {
|
||||
UnauthorizedRecovery::new(Arc::clone(self))
|
||||
}
|
||||
|
||||
fn external_auth_handle(&self) -> Option<ExternalAuthHandle> {
|
||||
self.external_auth
|
||||
.read()
|
||||
.ok()
|
||||
.and_then(|guard| guard.clone())
|
||||
}
|
||||
|
||||
fn external_auth_kind(&self) -> Option<ExternalAuthKind> {
|
||||
self.external_auth_handle().map(|handle| handle.kind)
|
||||
}
|
||||
|
||||
async fn resolve_external_api_key_auth(&self) -> Option<CodexAuth> {
|
||||
let Some(handle) = self.external_auth_handle() else {
|
||||
return None;
|
||||
};
|
||||
if handle.kind != ExternalAuthKind::Bearer {
|
||||
return None;
|
||||
}
|
||||
|
||||
match handle.refresher.resolve().await {
|
||||
Ok(Some(tokens)) => Some(CodexAuth::from_api_key(&tokens.access_token)),
|
||||
Ok(None) => None,
|
||||
Err(err) => {
|
||||
tracing::error!("Failed to resolve external bearer auth: {err}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempt to refresh the token by first performing a guarded reload. Auth
|
||||
/// is reloaded from storage only when the account id matches the currently
|
||||
/// cached account id. If the persisted token differs from the cached token, we
|
||||
@@ -1432,16 +1567,7 @@ impl AuthManager {
|
||||
reason: ExternalAuthRefreshReason,
|
||||
) -> Result<(), RefreshTokenError> {
|
||||
let forced_chatgpt_workspace_id = self.forced_chatgpt_workspace_id();
|
||||
let refresher = match self.inner.read() {
|
||||
Ok(guard) => guard.external_refresher.clone(),
|
||||
Err(_) => {
|
||||
return Err(RefreshTokenError::Transient(std::io::Error::other(
|
||||
"failed to read external auth state",
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
let Some(refresher) = refresher else {
|
||||
let Some(handle) = self.external_auth_handle() else {
|
||||
return Err(RefreshTokenError::Transient(std::io::Error::other(
|
||||
"external auth refresher is not configured",
|
||||
)));
|
||||
@@ -1456,14 +1582,22 @@ impl AuthManager {
|
||||
previous_account_id,
|
||||
};
|
||||
|
||||
let refreshed = refresher.refresh(context).await?;
|
||||
let refreshed = handle.refresher.refresh(context).await?;
|
||||
if handle.kind == ExternalAuthKind::Bearer {
|
||||
return Ok(());
|
||||
}
|
||||
let Some(chatgpt_metadata) = refreshed.chatgpt_metadata() else {
|
||||
return Err(RefreshTokenError::Transient(std::io::Error::other(
|
||||
"external auth refresh did not return ChatGPT metadata",
|
||||
)));
|
||||
};
|
||||
if let Some(expected_workspace_id) = forced_chatgpt_workspace_id.as_deref()
|
||||
&& refreshed.chatgpt_account_id != expected_workspace_id
|
||||
&& chatgpt_metadata.account_id != expected_workspace_id
|
||||
{
|
||||
return Err(RefreshTokenError::Transient(std::io::Error::other(
|
||||
format!(
|
||||
"external auth refresh returned workspace {:?}, expected {expected_workspace_id:?}",
|
||||
refreshed.chatgpt_account_id,
|
||||
chatgpt_metadata.account_id,
|
||||
),
|
||||
)));
|
||||
}
|
||||
|
||||
@@ -22,6 +22,11 @@ pub use auth::AuthManager;
|
||||
pub use auth::CLIENT_ID;
|
||||
pub use auth::CODEX_API_KEY_ENV_VAR;
|
||||
pub use auth::CodexAuth;
|
||||
pub use auth::ExternalAuthChatgptMetadata;
|
||||
pub use auth::ExternalAuthRefreshContext;
|
||||
pub use auth::ExternalAuthRefreshReason;
|
||||
pub use auth::ExternalAuthRefresher;
|
||||
pub use auth::ExternalAuthTokens;
|
||||
pub use auth::OPENAI_API_KEY_ENV_VAR;
|
||||
pub use auth::REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR;
|
||||
pub use auth::RefreshTokenError;
|
||||
|
||||
Reference in New Issue
Block a user