Files
codex/prs/bolinfest/PR-1861.md
2025-09-02 15:17:45 -07:00

14 KiB
Raw Blame History

PR #1861: Prefer env var auth over default codex auth

Description

Summary

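  • Prioritize provider-specific API keys (read from the provider's `env_key`) over the default Codex auth when building requests
  • Rename `ModelProviderInfo::requires_auth` to `requires_openai_auth` and update all call sites and tests accordingly
  • Add a test to ensure provider env var auth overrides the default auth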

Testing

  • just fmt
  • just fix (fails: `let` expressions in this position are unstable)
  • cargo test --all-features (fails: `let` expressions in this position are unstable)

https://chatgpt.com/codex/tasks/task_i_68926a104f7483208f2c8fd36763e0e3

Full Diff

diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index 9748cde7cb..ed05fb5db0 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -623,7 +623,7 @@ mod tests {
             request_max_retries: Some(0),
             stream_max_retries: Some(0),
             stream_idle_timeout_ms: Some(1000),
-            requires_auth: false,
+            requires_openai_auth: false,
         };
 
         let events = collect_events(
@@ -683,7 +683,7 @@ mod tests {
             request_max_retries: Some(0),
             stream_max_retries: Some(0),
             stream_idle_timeout_ms: Some(1000),
-            requires_auth: false,
+            requires_openai_auth: false,
         };
 
         let events = collect_events(&[sse1.as_bytes()], provider).await;
@@ -786,7 +786,7 @@ mod tests {
                 request_max_retries: Some(0),
                 stream_max_retries: Some(0),
                 stream_idle_timeout_ms: Some(1000),
-                requires_auth: false,
+                requires_openai_auth: false,
             };
 
             let out = run_sse(evs, provider).await;
diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs
index f48cc9340b..63a2e5949f 100644
--- a/codex-rs/core/src/config.rs
+++ b/codex-rs/core/src/config.rs
@@ -842,7 +842,7 @@ disable_response_storage = true
             request_max_retries: Some(4),
             stream_max_retries: Some(10),
             stream_idle_timeout_ms: Some(300_000),
-            requires_auth: false,
+            requires_openai_auth: false,
         };
         let model_provider_map = {
             let mut model_provider_map = built_in_model_providers();
diff --git a/codex-rs/core/src/model_provider_info.rs b/codex-rs/core/src/model_provider_info.rs
index db369df3b7..a980211199 100644
--- a/codex-rs/core/src/model_provider_info.rs
+++ b/codex-rs/core/src/model_provider_info.rs
@@ -9,7 +9,6 @@ use codex_login::AuthMode;
 use codex_login::CodexAuth;
 use serde::Deserialize;
 use serde::Serialize;
-use std::borrow::Cow;
 use std::collections::HashMap;
 use std::env::VarError;
 use std::time::Duration;
@@ -79,7 +78,7 @@ pub struct ModelProviderInfo {
 
     /// Whether this provider requires some form of standard authentication (API key, ChatGPT token).
     #[serde(default)]
-    pub requires_auth: bool,
+    pub requires_openai_auth: bool,
 }
 
 impl ModelProviderInfo {
@@ -87,26 +86,32 @@ impl ModelProviderInfo {
     /// reqwest Client applying:
     ///   • provider-specific headers (static + env based)
     ///   • Bearer auth header when an API key is available.
+    ///   • Auth token for OAuth.
     ///
-    /// When `require_api_key` is true and the provider declares an `env_key`
-    /// but the variable is missing/empty, returns an [`Err`] identical to the
+    /// If the provider declares an `env_key` but the variable is missing/empty, returns an [`Err`] identical to the
     /// one produced by [`ModelProviderInfo::api_key`].
     pub async fn create_request_builder<'a>(
         &'a self,
         client: &'a reqwest::Client,
         auth: &Option<CodexAuth>,
     ) -> crate::error::Result<reqwest::RequestBuilder> {
-        let auth: Cow<'_, Option<CodexAuth>> = if auth.is_some() {
-            Cow::Borrowed(auth)
-        } else {
-            Cow::Owned(self.get_fallback_auth()?)
+        let effective_auth = match self.api_key() {
+            Ok(Some(key)) => Some(CodexAuth::from_api_key(key)),
+            Ok(None) => auth.clone(),
+            Err(err) => {
+                if auth.is_some() {
+                    auth.clone()
+                } else {
+                    return Err(err);
+                }
+            }
         };
 
-        let url = self.get_full_url(&auth);
+        let url = self.get_full_url(&effective_auth);
 
         let mut builder = client.post(url);
 
-        if let Some(auth) = auth.as_ref() {
+        if let Some(auth) = effective_auth.as_ref() {
             builder = builder.bearer_auth(auth.get_token().await?);
         }
 
@@ -216,14 +221,6 @@ impl ModelProviderInfo {
             .map(Duration::from_millis)
             .unwrap_or(Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS))
     }
-
-    fn get_fallback_auth(&self) -> crate::error::Result<Option<CodexAuth>> {
-        let api_key = self.api_key()?;
-        if let Some(api_key) = api_key {
-            return Ok(Some(CodexAuth::from_api_key(api_key)));
-        }
-        Ok(None)
-    }
 }
 
 const DEFAULT_OLLAMA_PORT: u32 = 11434;
@@ -275,7 +272,7 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
                 request_max_retries: None,
                 stream_max_retries: None,
                 stream_idle_timeout_ms: None,
-                requires_auth: true,
+                requires_openai_auth: true,
             },
         ),
         (BUILT_IN_OSS_MODEL_PROVIDER_ID, create_oss_provider()),
@@ -319,7 +316,7 @@ pub fn create_oss_provider_with_base_url(base_url: &str) -> ModelProviderInfo {
         request_max_retries: None,
         stream_max_retries: None,
         stream_idle_timeout_ms: None,
-        requires_auth: false,
+        requires_openai_auth: false,
     }
 }
 
@@ -347,7 +344,7 @@ base_url = "http://localhost:11434/v1"
             request_max_retries: None,
             stream_max_retries: None,
             stream_idle_timeout_ms: None,
-            requires_auth: false,
+            requires_openai_auth: false,
         };
 
         let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
@@ -376,7 +373,7 @@ query_params = { api-version = "2025-04-01-preview" }
             request_max_retries: None,
             stream_max_retries: None,
             stream_idle_timeout_ms: None,
-            requires_auth: false,
+            requires_openai_auth: false,
         };
 
         let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
@@ -408,7 +405,7 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
             request_max_retries: None,
             stream_max_retries: None,
             stream_idle_timeout_ms: None,
-            requires_auth: false,
+            requires_openai_auth: false,
         };
 
         let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
diff --git a/codex-rs/core/tests/client.rs b/codex-rs/core/tests/client.rs
index 00f91a879e..60eb922474 100644
--- a/codex-rs/core/tests/client.rs
+++ b/codex-rs/core/tests/client.rs
@@ -458,7 +458,7 @@ async fn azure_overrides_assign_properties_used_for_responses_url() {
         request_max_retries: None,
         stream_max_retries: None,
         stream_idle_timeout_ms: None,
-        requires_auth: false,
+        requires_openai_auth: false,
     };
 
     // Init session
@@ -481,6 +481,86 @@ async fn azure_overrides_assign_properties_used_for_responses_url() {
     wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
 }
 
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn env_var_overrides_loaded_auth() {
+    #![allow(clippy::unwrap_used)]
+
+    let existing_env_var_with_random_value = if cfg!(windows) { "USERNAME" } else { "USER" };
+
+    // Mock server
+    let server = MockServer::start().await;
+
+    // First request  must NOT include `previous_response_id`.
+    let first = ResponseTemplate::new(200)
+        .insert_header("content-type", "text/event-stream")
+        .set_body_raw(sse_completed("resp1"), "text/event-stream");
+
+    // Expect POST to /openai/responses with api-version query param
+    Mock::given(method("POST"))
+        .and(path("/openai/responses"))
+        .and(query_param("api-version", "2025-04-01-preview"))
+        .and(header_regex("Custom-Header", "Value"))
+        .and(header_regex(
+            "Authorization",
+            format!(
+                "Bearer {}",
+                std::env::var(existing_env_var_with_random_value).unwrap()
+            )
+            .as_str(),
+        ))
+        .respond_with(first)
+        .expect(1)
+        .mount(&server)
+        .await;
+
+    let provider = ModelProviderInfo {
+        name: "custom".to_string(),
+        base_url: Some(format!("{}/openai", server.uri())),
+        // Reuse the existing environment variable to avoid using unsafe code
+        env_key: Some(existing_env_var_with_random_value.to_string()),
+        query_params: Some(std::collections::HashMap::from([(
+            "api-version".to_string(),
+            "2025-04-01-preview".to_string(),
+        )])),
+        env_key_instructions: None,
+        wire_api: WireApi::Responses,
+        http_headers: Some(std::collections::HashMap::from([(
+            "Custom-Header".to_string(),
+            "Value".to_string(),
+        )])),
+        env_http_headers: None,
+        request_max_retries: None,
+        stream_max_retries: None,
+        stream_idle_timeout_ms: None,
+        requires_openai_auth: false,
+    };
+
+    // Init session
+    let codex_home = TempDir::new().unwrap();
+    let mut config = load_default_config_for_test(&codex_home);
+    config.model_provider = provider;
+
+    let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
+    let CodexSpawnOk { codex, .. } = Codex::spawn(
+        config,
+        Some(auth_from_token("Default Access Token".to_string())),
+        ctrl_c.clone(),
+    )
+    .await
+    .unwrap();
+
+    codex
+        .submit(Op::UserInput {
+            items: vec![InputItem::Text {
+                text: "hello".into(),
+            }],
+        })
+        .await
+        .unwrap();
+
+    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
+}
+
 fn auth_from_token(id_token: String) -> CodexAuth {
     CodexAuth::new(
         None,
diff --git a/codex-rs/core/tests/stream_no_completed.rs b/codex-rs/core/tests/stream_no_completed.rs
index 3e30d93709..8a4216b129 100644
--- a/codex-rs/core/tests/stream_no_completed.rs
+++ b/codex-rs/core/tests/stream_no_completed.rs
@@ -90,7 +90,7 @@ async fn retries_on_early_close() {
         request_max_retries: Some(0),
         stream_max_retries: Some(1),
         stream_idle_timeout_ms: Some(2000),
-        requires_auth: false,
+        requires_openai_auth: false,
     };
 
     let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
diff --git a/codex-rs/tui/src/lib.rs b/codex-rs/tui/src/lib.rs
index 50535e5967..0228a56859 100644
--- a/codex-rs/tui/src/lib.rs
+++ b/codex-rs/tui/src/lib.rs
@@ -287,7 +287,7 @@ fn restore() {
 
 #[allow(clippy::unwrap_used)]
 fn should_show_login_screen(config: &Config) -> bool {
-    if config.model_provider.requires_auth {
+    if config.model_provider.requires_openai_auth {
         // Reading the OpenAI API key is an async operation because it may need
         // to refresh the token. Block on it.
         let codex_home = config.codex_home.clone();

Review Comments

codex-rs/core/src/model_provider_info.rs

@@ -96,17 +95,23 @@ impl ModelProviderInfo {
         client: &'a reqwest::Client,
         auth: &Option<CodexAuth>,
     ) -> crate::error::Result<reqwest::RequestBuilder> {
-        let auth: Cow<'_, Option<CodexAuth>> = if auth.is_some() {
-            Cow::Borrowed(auth)
-        } else {
-            Cow::Owned(self.get_fallback_auth()?)
+        let effective_auth = match self.api_key() {

Is the docstring for this function out of date?

codex-rs/core/tests/client.rs

@@ -460,6 +460,86 @@ async fn azure_overrides_assign_properties_used_for_responses_url() {
     wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
 }
 
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn env_var_overrides_loaded_auth() {
+    #![allow(clippy::unwrap_used)]
+
+    let existing_env_var_with_random_value = if cfg!(windows) { "USERNAME" } else { "USER" };
+
+    // Mock server
+    let server = MockServer::start().await;
+
+    // First request  must NOT include `previous_response_id`.
+    let first = ResponseTemplate::new(200)
+        .insert_header("content-type", "text/event-stream")
+        .set_body_raw(sse_completed("resp1"), "text/event-stream");
+
+    // Expect POST to /openai/responses with api-version query param
+    Mock::given(method("POST"))
+        .and(path("/openai/responses"))
+        .and(query_param("api-version", "2025-04-01-preview"))
+        .and(header_regex("Custom-Header", "Value"))
+        .and(header_regex(
+            "Authorization",
+            format!(
+                "Bearer {}",
+                std::env::var(existing_env_var_with_random_value).unwrap()
+            )
+            .as_str(),
+        ))
+        .respond_with(first)
+        .expect(1)
+        .mount(&server)
+        .await;
+
+    let provider = ModelProviderInfo {
+        name: "custom".to_string(),
+        base_url: Some(format!("{}/openai", server.uri())),
+        // Reuse the existing environment variable to avoid using unsafe code
+        env_key: Some(existing_env_var_with_random_value.to_string()),
+        query_params: Some(std::collections::HashMap::from([(
+            "api-version".to_string(),
+            "2025-04-01-preview".to_string(),
+        )])),
+        env_key_instructions: None,
+        wire_api: WireApi::Responses,
+        http_headers: Some(std::collections::HashMap::from([(
+            "Custom-Header".to_string(),
+            "Value".to_string(),
+        )])),
+        env_http_headers: None,
+        request_max_retries: None,
+        stream_max_retries: None,
+        stream_idle_timeout_ms: None,
+        requires_auth: false,

Should `requires_auth` be `true` here?

Moreover, is this docstring still accurate? (See `codex-rs/core/src/model_provider_info.rs`, lines 80–82, at commit `f6c8d1117c`.)