Extract provider and token modules into codex-auth

Move the foundational provider and token modules into codex-auth while keeping codex-core as the facade. Also move the corresponding unit tests and record the 3-PR migration checkpoints.

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Ahmed Ibrahim
2026-03-18 12:21:10 -07:00
parent 334164a6f7
commit cf801bad4d
18 changed files with 724 additions and 560 deletions

View File

@@ -0,0 +1,15 @@
#[derive(Debug)]
pub struct EnvVarError {
pub var: String,
pub instructions: Option<String>,
}
impl std::fmt::Display for EnvVarError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Missing environment variable: `{}`.", self.var)?;
if let Some(instructions) = &self.instructions {
write!(f, " {instructions}")?;
}
Ok(())
}
}

View File

@@ -0,0 +1,19 @@
pub mod error;
pub mod provider;
pub mod token_data;
#[cfg(test)]
mod model_provider_info_tests;
#[cfg(test)]
mod token_data_tests;
pub use error::EnvVarError;
pub use provider::DEFAULT_LMSTUDIO_PORT;
pub use provider::DEFAULT_OLLAMA_PORT;
pub use provider::LMSTUDIO_OSS_PROVIDER_ID;
pub use provider::ModelProviderInfo;
pub use provider::OLLAMA_OSS_PROVIDER_ID;
pub use provider::OPENAI_PROVIDER_ID;
pub use provider::WireApi;
pub use provider::built_in_model_providers;
pub use provider::create_oss_provider_with_base_url;

View File

@@ -0,0 +1,123 @@
use super::provider::*;
use pretty_assertions::assert_eq;
#[test]
fn test_deserialize_ollama_model_provider_toml() {
let azure_provider_toml = r#"
name = "Ollama"
base_url = "http://localhost:11434/v1"
"#;
let expected_provider = ModelProviderInfo {
name: "Ollama".into(),
base_url: Some("http://localhost:11434/v1".into()),
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
websocket_connect_timeout_ms: None,
requires_openai_auth: false,
supports_websockets: false,
};
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
assert_eq!(expected_provider, provider);
}
#[test]
fn test_deserialize_azure_model_provider_toml() {
let azure_provider_toml = r#"
name = "Azure"
base_url = "https://xxxxx.openai.azure.com/openai"
env_key = "AZURE_OPENAI_API_KEY"
query_params = { api-version = "2025-04-01-preview" }
"#;
let expected_provider = ModelProviderInfo {
name: "Azure".into(),
base_url: Some("https://xxxxx.openai.azure.com/openai".into()),
env_key: Some("AZURE_OPENAI_API_KEY".into()),
env_key_instructions: None,
experimental_bearer_token: None,
wire_api: WireApi::Responses,
query_params: Some(maplit::hashmap! {
"api-version".to_string() => "2025-04-01-preview".to_string(),
}),
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
websocket_connect_timeout_ms: None,
requires_openai_auth: false,
supports_websockets: false,
};
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
assert_eq!(expected_provider, provider);
}
#[test]
fn test_deserialize_example_model_provider_toml() {
let azure_provider_toml = r#"
name = "Example"
base_url = "https://example.com"
env_key = "API_KEY"
http_headers = { "X-Example-Header" = "example-value" }
env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
"#;
let expected_provider = ModelProviderInfo {
name: "Example".into(),
base_url: Some("https://example.com".into()),
env_key: Some("API_KEY".into()),
env_key_instructions: None,
experimental_bearer_token: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: Some(maplit::hashmap! {
"X-Example-Header".to_string() => "example-value".to_string(),
}),
env_http_headers: Some(maplit::hashmap! {
"X-Example-Env-Header".to_string() => "EXAMPLE_ENV_VAR".to_string(),
}),
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
websocket_connect_timeout_ms: None,
requires_openai_auth: false,
supports_websockets: false,
};
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
assert_eq!(expected_provider, provider);
}
#[test]
fn test_deserialize_chat_wire_api_shows_helpful_error() {
let provider_toml = r#"
name = "OpenAI using Chat Completions"
base_url = "https://api.openai.com/v1"
env_key = "OPENAI_API_KEY"
wire_api = "chat"
"#;
let err = toml::from_str::<ModelProviderInfo>(provider_toml).unwrap_err();
assert!(err.to_string().contains(CHAT_WIRE_API_REMOVED_ERROR));
}
#[test]
fn test_deserialize_websocket_connect_timeout() {
let provider_toml = r#"
name = "OpenAI"
base_url = "https://api.openai.com/v1"
websocket_connect_timeout_ms = 15000
supports_websockets = true
"#;
let provider: ModelProviderInfo = toml::from_str(provider_toml).unwrap();
assert_eq!(provider.websocket_connect_timeout_ms, Some(15_000));
}

View File

@@ -0,0 +1,286 @@
use crate::error::EnvVarError;
use codex_api::Provider as ApiProvider;
use codex_api::provider::RetryConfig as ApiRetryConfig;
use codex_app_server_protocol::AuthMode as ApiAuthMode;
use http::HeaderMap;
use http::header::HeaderName;
use http::header::HeaderValue;
use schemars::JsonSchema;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::fmt;
use std::time::Duration;
const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000;
const DEFAULT_STREAM_MAX_RETRIES: u64 = 5;
const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
pub const DEFAULT_WEBSOCKET_CONNECT_TIMEOUT_MS: u64 = 15_000;
const MAX_STREAM_MAX_RETRIES: u64 = 100;
const MAX_REQUEST_MAX_RETRIES: u64 = 100;
const OPENAI_PROVIDER_NAME: &str = "OpenAI";
pub const OPENAI_PROVIDER_ID: &str = "openai";
pub const CHAT_WIRE_API_REMOVED_ERROR: &str = "`wire_api = \"chat\"` is no longer supported.\nHow to fix: set `wire_api = \"responses\"` in your provider config.\nMore info: https://github.com/openai/codex/discussions/7782";
pub const LEGACY_OLLAMA_CHAT_PROVIDER_ID: &str = "ollama-chat";
pub const OLLAMA_CHAT_PROVIDER_REMOVED_ERROR: &str = "`ollama-chat` is no longer supported.\nHow to fix: replace `ollama-chat` with `ollama` in `model_provider`, `oss_provider`, or `--local-provider`.\nMore info: https://github.com/openai/codex/discussions/7782";
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum WireApi {
#[default]
Responses,
}
impl fmt::Display for WireApi {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let value = match self {
Self::Responses => "responses",
};
f.write_str(value)
}
}
impl<'de> Deserialize<'de> for WireApi {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = String::deserialize(deserializer)?;
match value.as_str() {
"responses" => Ok(Self::Responses),
"chat" => Err(serde::de::Error::custom(CHAT_WIRE_API_REMOVED_ERROR)),
_ => Err(serde::de::Error::unknown_variant(&value, &["responses"])),
}
}
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, JsonSchema)]
#[schemars(deny_unknown_fields)]
pub struct ModelProviderInfo {
pub name: String,
pub base_url: Option<String>,
pub env_key: Option<String>,
pub env_key_instructions: Option<String>,
pub experimental_bearer_token: Option<String>,
#[serde(default)]
pub wire_api: WireApi,
pub query_params: Option<HashMap<String, String>>,
pub http_headers: Option<HashMap<String, String>>,
pub env_http_headers: Option<HashMap<String, String>>,
pub request_max_retries: Option<u64>,
pub stream_max_retries: Option<u64>,
pub stream_idle_timeout_ms: Option<u64>,
pub websocket_connect_timeout_ms: Option<u64>,
#[serde(default)]
pub requires_openai_auth: bool,
#[serde(default)]
pub supports_websockets: bool,
}
impl ModelProviderInfo {
fn build_header_map(&self) -> HeaderMap {
let capacity = self.http_headers.as_ref().map_or(0, HashMap::len)
+ self.env_http_headers.as_ref().map_or(0, HashMap::len);
let mut headers = HeaderMap::with_capacity(capacity);
if let Some(extra) = &self.http_headers {
for (k, v) in extra {
if let (Ok(name), Ok(value)) = (HeaderName::try_from(k), HeaderValue::try_from(v)) {
headers.insert(name, value);
}
}
}
if let Some(env_headers) = &self.env_http_headers {
for (header, env_var) in env_headers {
if let Ok(val) = std::env::var(env_var)
&& !val.trim().is_empty()
&& let (Ok(name), Ok(value)) =
(HeaderName::try_from(header), HeaderValue::try_from(val))
{
headers.insert(name, value);
}
}
}
headers
}
pub fn to_api_provider(
&self,
auth_mode: Option<ApiAuthMode>,
) -> Result<ApiProvider, EnvVarError> {
let default_base_url = if matches!(
auth_mode,
Some(ApiAuthMode::Chatgpt | ApiAuthMode::ChatgptAuthTokens)
) {
"https://chatgpt.com/backend-api/codex"
} else {
"https://api.openai.com/v1"
};
let base_url = self
.base_url
.clone()
.unwrap_or_else(|| default_base_url.to_string());
let retry = ApiRetryConfig {
max_attempts: self.request_max_retries(),
base_delay: Duration::from_millis(200),
retry_429: false,
retry_5xx: true,
retry_transport: true,
};
Ok(ApiProvider {
name: self.name.clone(),
base_url,
query_params: self.query_params.clone(),
headers: self.build_header_map(),
retry,
stream_idle_timeout: self.stream_idle_timeout(),
})
}
pub fn api_key(&self) -> Result<Option<String>, EnvVarError> {
match &self.env_key {
Some(env_key) => {
let api_key = std::env::var(env_key)
.ok()
.filter(|v| !v.trim().is_empty())
.ok_or_else(|| EnvVarError {
var: env_key.clone(),
instructions: self.env_key_instructions.clone(),
})?;
Ok(Some(api_key))
}
None => Ok(None),
}
}
pub fn request_max_retries(&self) -> u64 {
self.request_max_retries
.unwrap_or(DEFAULT_REQUEST_MAX_RETRIES)
.min(MAX_REQUEST_MAX_RETRIES)
}
pub fn stream_max_retries(&self) -> u64 {
self.stream_max_retries
.unwrap_or(DEFAULT_STREAM_MAX_RETRIES)
.min(MAX_STREAM_MAX_RETRIES)
}
pub fn stream_idle_timeout(&self) -> Duration {
self.stream_idle_timeout_ms
.map(Duration::from_millis)
.unwrap_or(Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS))
}
pub fn websocket_connect_timeout(&self) -> Duration {
self.websocket_connect_timeout_ms
.map(Duration::from_millis)
.unwrap_or(Duration::from_millis(DEFAULT_WEBSOCKET_CONNECT_TIMEOUT_MS))
}
pub fn create_openai_provider(base_url: Option<String>) -> ModelProviderInfo {
ModelProviderInfo {
name: OPENAI_PROVIDER_NAME.into(),
base_url,
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: Some(
[("version".to_string(), env!("CARGO_PKG_VERSION").to_string())]
.into_iter()
.collect(),
),
env_http_headers: Some(
[
(
"OpenAI-Organization".to_string(),
"OPENAI_ORGANIZATION".to_string(),
),
("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()),
]
.into_iter()
.collect(),
),
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
websocket_connect_timeout_ms: None,
requires_openai_auth: true,
supports_websockets: true,
}
}
pub fn is_openai(&self) -> bool {
self.name == OPENAI_PROVIDER_NAME
}
}
pub const DEFAULT_LMSTUDIO_PORT: u16 = 1234;
pub const DEFAULT_OLLAMA_PORT: u16 = 11434;
pub const LMSTUDIO_OSS_PROVIDER_ID: &str = "lmstudio";
pub const OLLAMA_OSS_PROVIDER_ID: &str = "ollama";
pub fn built_in_model_providers(
openai_base_url: Option<String>,
) -> HashMap<String, ModelProviderInfo> {
use ModelProviderInfo as P;
let openai_provider = P::create_openai_provider(openai_base_url);
[
(OPENAI_PROVIDER_ID, openai_provider),
(
OLLAMA_OSS_PROVIDER_ID,
create_oss_provider(DEFAULT_OLLAMA_PORT, WireApi::Responses),
),
(
LMSTUDIO_OSS_PROVIDER_ID,
create_oss_provider(DEFAULT_LMSTUDIO_PORT, WireApi::Responses),
),
]
.into_iter()
.map(|(k, v)| (k.to_string(), v))
.collect()
}
pub fn create_oss_provider(default_provider_port: u16, wire_api: WireApi) -> ModelProviderInfo {
let default_codex_oss_base_url = format!(
"http://localhost:{codex_oss_port}/v1",
codex_oss_port = std::env::var("CODEX_OSS_PORT")
.ok()
.filter(|value| !value.trim().is_empty())
.and_then(|value| value.parse::<u16>().ok())
.unwrap_or(default_provider_port)
);
let codex_oss_base_url = std::env::var("CODEX_OSS_BASE_URL")
.ok()
.filter(|v| !v.trim().is_empty())
.unwrap_or(default_codex_oss_base_url);
create_oss_provider_with_base_url(&codex_oss_base_url, wire_api)
}
pub fn create_oss_provider_with_base_url(base_url: &str, wire_api: WireApi) -> ModelProviderInfo {
ModelProviderInfo {
name: "gpt-oss".into(),
base_url: Some(base_url.into()),
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
wire_api,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
websocket_connect_timeout_ms: None,
requires_openai_auth: false,
supports_websockets: false,
}
}

View File

@@ -0,0 +1,163 @@
use base64::Engine;
use serde::Deserialize;
use serde::Serialize;
use thiserror::Error;
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Default)]
pub struct TokenData {
#[serde(
deserialize_with = "deserialize_id_token",
serialize_with = "serialize_id_token"
)]
pub id_token: IdTokenInfo,
pub access_token: String,
pub refresh_token: String,
pub account_id: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct IdTokenInfo {
pub email: Option<String>,
pub chatgpt_plan_type: Option<PlanType>,
pub chatgpt_user_id: Option<String>,
pub chatgpt_account_id: Option<String>,
pub raw_jwt: String,
}
impl IdTokenInfo {
pub fn get_chatgpt_plan_type(&self) -> Option<String> {
self.chatgpt_plan_type.as_ref().map(|t| match t {
PlanType::Known(plan) => format!("{plan:?}"),
PlanType::Unknown(s) => s.clone(),
})
}
pub fn is_workspace_account(&self) -> bool {
matches!(
self.chatgpt_plan_type,
Some(PlanType::Known(
KnownPlan::Team | KnownPlan::Business | KnownPlan::Enterprise | KnownPlan::Edu
))
)
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PlanType {
Known(KnownPlan),
Unknown(String),
}
impl PlanType {
pub fn from_raw_value(raw: &str) -> Self {
match raw.to_ascii_lowercase().as_str() {
"free" => Self::Known(KnownPlan::Free),
"go" => Self::Known(KnownPlan::Go),
"plus" => Self::Known(KnownPlan::Plus),
"pro" => Self::Known(KnownPlan::Pro),
"team" => Self::Known(KnownPlan::Team),
"business" => Self::Known(KnownPlan::Business),
"enterprise" => Self::Known(KnownPlan::Enterprise),
"education" | "edu" => Self::Known(KnownPlan::Edu),
_ => Self::Unknown(raw.to_string()),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum KnownPlan {
Free,
Go,
Plus,
Pro,
Team,
Business,
Enterprise,
Edu,
}
#[derive(Deserialize)]
struct IdClaims {
#[serde(default)]
email: Option<String>,
#[serde(rename = "https://api.openai.com/profile", default)]
profile: Option<ProfileClaims>,
#[serde(rename = "https://api.openai.com/auth", default)]
auth: Option<AuthClaims>,
}
#[derive(Deserialize)]
struct ProfileClaims {
#[serde(default)]
email: Option<String>,
}
#[derive(Deserialize)]
struct AuthClaims {
#[serde(default)]
chatgpt_plan_type: Option<PlanType>,
#[serde(default)]
chatgpt_user_id: Option<String>,
#[serde(default)]
user_id: Option<String>,
#[serde(default)]
chatgpt_account_id: Option<String>,
}
#[derive(Debug, Error)]
pub enum IdTokenInfoError {
#[error("invalid ID token format")]
InvalidFormat,
#[error(transparent)]
Base64(#[from] base64::DecodeError),
#[error(transparent)]
Json(#[from] serde_json::Error),
}
pub fn parse_chatgpt_jwt_claims(jwt: &str) -> Result<IdTokenInfo, IdTokenInfoError> {
let mut parts = jwt.split('.');
let (_header_b64, payload_b64, _sig_b64) = match (parts.next(), parts.next(), parts.next()) {
(Some(h), Some(p), Some(s)) if !h.is_empty() && !p.is_empty() && !s.is_empty() => (h, p, s),
_ => return Err(IdTokenInfoError::InvalidFormat),
};
let payload_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload_b64)?;
let claims: IdClaims = serde_json::from_slice(&payload_bytes)?;
let email = claims
.email
.or_else(|| claims.profile.and_then(|profile| profile.email));
match claims.auth {
Some(auth) => Ok(IdTokenInfo {
email,
raw_jwt: jwt.to_string(),
chatgpt_plan_type: auth.chatgpt_plan_type,
chatgpt_user_id: auth.chatgpt_user_id.or(auth.user_id),
chatgpt_account_id: auth.chatgpt_account_id,
}),
None => Ok(IdTokenInfo {
email,
raw_jwt: jwt.to_string(),
chatgpt_plan_type: None,
chatgpt_user_id: None,
chatgpt_account_id: None,
}),
}
}
fn deserialize_id_token<'de, D>(deserializer: D) -> Result<IdTokenInfo, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
parse_chatgpt_jwt_claims(&s).map_err(serde::de::Error::custom)
}
fn serialize_id_token<S>(id_token: &IdTokenInfo, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&id_token.raw_jwt)
}

View File

@@ -0,0 +1,110 @@
use base64::Engine;
use super::token_data::*;
use pretty_assertions::assert_eq;
use serde::Serialize;
#[test]
fn id_token_info_parses_email_and_plan() {
#[derive(Serialize)]
struct Header {
alg: &'static str,
typ: &'static str,
}
let header = Header {
alg: "none",
typ: "JWT",
};
let payload = serde_json::json!({
"email": "user@example.com",
"https://api.openai.com/auth": {
"chatgpt_plan_type": "pro"
}
});
fn b64url_no_pad(bytes: &[u8]) -> String {
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
}
let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
let signature_b64 = b64url_no_pad(b"sig");
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse");
assert_eq!(info.email.as_deref(), Some("user@example.com"));
assert_eq!(info.get_chatgpt_plan_type().as_deref(), Some("Pro"));
}
#[test]
fn id_token_info_parses_go_plan() {
#[derive(Serialize)]
struct Header {
alg: &'static str,
typ: &'static str,
}
let header = Header {
alg: "none",
typ: "JWT",
};
let payload = serde_json::json!({
"email": "user@example.com",
"https://api.openai.com/auth": {
"chatgpt_plan_type": "go"
}
});
fn b64url_no_pad(bytes: &[u8]) -> String {
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
}
let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
let signature_b64 = b64url_no_pad(b"sig");
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse");
assert_eq!(info.email.as_deref(), Some("user@example.com"));
assert_eq!(info.get_chatgpt_plan_type().as_deref(), Some("Go"));
}
#[test]
fn id_token_info_handles_missing_fields() {
#[derive(Serialize)]
struct Header {
alg: &'static str,
typ: &'static str,
}
let header = Header {
alg: "none",
typ: "JWT",
};
let payload = serde_json::json!({ "sub": "123" });
fn b64url_no_pad(bytes: &[u8]) -> String {
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
}
let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
let signature_b64 = b64url_no_pad(b"sig");
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse");
assert!(info.email.is_none());
assert!(info.get_chatgpt_plan_type().is_none());
}
#[test]
fn workspace_account_detection_matches_workspace_plans() {
let workspace = IdTokenInfo {
chatgpt_plan_type: Some(PlanType::Known(KnownPlan::Business)),
..IdTokenInfo::default()
};
assert_eq!(workspace.is_workspace_account(), true);
let personal = IdTokenInfo {
chatgpt_plan_type: Some(PlanType::Known(KnownPlan::Pro)),
..IdTokenInfo::default()
};
assert_eq!(personal.is_workspace_account(), false);
}