mirror of
https://github.com/openai/codex.git
synced 2026-05-03 04:42:20 +03:00
Extract provider and token modules into codex-auth
Move the foundational provider and token modules into codex-auth while keeping codex-core as the facade. Also move the corresponding unit tests and record the 3-PR migration checkpoints. Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
15
codex-rs/codex-auth/src/error.rs
Normal file
15
codex-rs/codex-auth/src/error.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
#[derive(Debug)]
|
||||
pub struct EnvVarError {
|
||||
pub var: String,
|
||||
pub instructions: Option<String>,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for EnvVarError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Missing environment variable: `{}`.", self.var)?;
|
||||
if let Some(instructions) = &self.instructions {
|
||||
write!(f, " {instructions}")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
19
codex-rs/codex-auth/src/lib.rs
Normal file
19
codex-rs/codex-auth/src/lib.rs
Normal file
@@ -0,0 +1,19 @@
|
||||
pub mod error;
|
||||
pub mod provider;
|
||||
pub mod token_data;
|
||||
|
||||
#[cfg(test)]
|
||||
mod model_provider_info_tests;
|
||||
#[cfg(test)]
|
||||
mod token_data_tests;
|
||||
|
||||
pub use error::EnvVarError;
|
||||
pub use provider::DEFAULT_LMSTUDIO_PORT;
|
||||
pub use provider::DEFAULT_OLLAMA_PORT;
|
||||
pub use provider::LMSTUDIO_OSS_PROVIDER_ID;
|
||||
pub use provider::ModelProviderInfo;
|
||||
pub use provider::OLLAMA_OSS_PROVIDER_ID;
|
||||
pub use provider::OPENAI_PROVIDER_ID;
|
||||
pub use provider::WireApi;
|
||||
pub use provider::built_in_model_providers;
|
||||
pub use provider::create_oss_provider_with_base_url;
|
||||
123
codex-rs/codex-auth/src/model_provider_info_tests.rs
Normal file
123
codex-rs/codex-auth/src/model_provider_info_tests.rs
Normal file
@@ -0,0 +1,123 @@
|
||||
use super::provider::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_ollama_model_provider_toml() {
|
||||
let azure_provider_toml = r#"
|
||||
name = "Ollama"
|
||||
base_url = "http://localhost:11434/v1"
|
||||
"#;
|
||||
let expected_provider = ModelProviderInfo {
|
||||
name: "Ollama".into(),
|
||||
base_url: Some("http://localhost:11434/v1".into()),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
websocket_connect_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
supports_websockets: false,
|
||||
};
|
||||
|
||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||
assert_eq!(expected_provider, provider);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_azure_model_provider_toml() {
|
||||
let azure_provider_toml = r#"
|
||||
name = "Azure"
|
||||
base_url = "https://xxxxx.openai.azure.com/openai"
|
||||
env_key = "AZURE_OPENAI_API_KEY"
|
||||
query_params = { api-version = "2025-04-01-preview" }
|
||||
"#;
|
||||
let expected_provider = ModelProviderInfo {
|
||||
name: "Azure".into(),
|
||||
base_url: Some("https://xxxxx.openai.azure.com/openai".into()),
|
||||
env_key: Some("AZURE_OPENAI_API_KEY".into()),
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: Some(maplit::hashmap! {
|
||||
"api-version".to_string() => "2025-04-01-preview".to_string(),
|
||||
}),
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
websocket_connect_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
supports_websockets: false,
|
||||
};
|
||||
|
||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||
assert_eq!(expected_provider, provider);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_example_model_provider_toml() {
|
||||
let azure_provider_toml = r#"
|
||||
name = "Example"
|
||||
base_url = "https://example.com"
|
||||
env_key = "API_KEY"
|
||||
http_headers = { "X-Example-Header" = "example-value" }
|
||||
env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
|
||||
"#;
|
||||
let expected_provider = ModelProviderInfo {
|
||||
name: "Example".into(),
|
||||
base_url: Some("https://example.com".into()),
|
||||
env_key: Some("API_KEY".into()),
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: Some(maplit::hashmap! {
|
||||
"X-Example-Header".to_string() => "example-value".to_string(),
|
||||
}),
|
||||
env_http_headers: Some(maplit::hashmap! {
|
||||
"X-Example-Env-Header".to_string() => "EXAMPLE_ENV_VAR".to_string(),
|
||||
}),
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
websocket_connect_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
supports_websockets: false,
|
||||
};
|
||||
|
||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||
assert_eq!(expected_provider, provider);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_chat_wire_api_shows_helpful_error() {
|
||||
let provider_toml = r#"
|
||||
name = "OpenAI using Chat Completions"
|
||||
base_url = "https://api.openai.com/v1"
|
||||
env_key = "OPENAI_API_KEY"
|
||||
wire_api = "chat"
|
||||
"#;
|
||||
|
||||
let err = toml::from_str::<ModelProviderInfo>(provider_toml).unwrap_err();
|
||||
assert!(err.to_string().contains(CHAT_WIRE_API_REMOVED_ERROR));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_websocket_connect_timeout() {
|
||||
let provider_toml = r#"
|
||||
name = "OpenAI"
|
||||
base_url = "https://api.openai.com/v1"
|
||||
websocket_connect_timeout_ms = 15000
|
||||
supports_websockets = true
|
||||
"#;
|
||||
|
||||
let provider: ModelProviderInfo = toml::from_str(provider_toml).unwrap();
|
||||
assert_eq!(provider.websocket_connect_timeout_ms, Some(15_000));
|
||||
}
|
||||
286
codex-rs/codex-auth/src/provider.rs
Normal file
286
codex-rs/codex-auth/src/provider.rs
Normal file
@@ -0,0 +1,286 @@
|
||||
use crate::error::EnvVarError;
|
||||
use codex_api::Provider as ApiProvider;
|
||||
use codex_api::provider::RetryConfig as ApiRetryConfig;
|
||||
use codex_app_server_protocol::AuthMode as ApiAuthMode;
|
||||
use http::HeaderMap;
|
||||
use http::header::HeaderName;
|
||||
use http::header::HeaderValue;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::time::Duration;
|
||||
|
||||
const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000;
|
||||
const DEFAULT_STREAM_MAX_RETRIES: u64 = 5;
|
||||
const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
|
||||
pub const DEFAULT_WEBSOCKET_CONNECT_TIMEOUT_MS: u64 = 15_000;
|
||||
const MAX_STREAM_MAX_RETRIES: u64 = 100;
|
||||
const MAX_REQUEST_MAX_RETRIES: u64 = 100;
|
||||
|
||||
const OPENAI_PROVIDER_NAME: &str = "OpenAI";
|
||||
pub const OPENAI_PROVIDER_ID: &str = "openai";
|
||||
pub const CHAT_WIRE_API_REMOVED_ERROR: &str = "`wire_api = \"chat\"` is no longer supported.\nHow to fix: set `wire_api = \"responses\"` in your provider config.\nMore info: https://github.com/openai/codex/discussions/7782";
|
||||
pub const LEGACY_OLLAMA_CHAT_PROVIDER_ID: &str = "ollama-chat";
|
||||
pub const OLLAMA_CHAT_PROVIDER_REMOVED_ERROR: &str = "`ollama-chat` is no longer supported.\nHow to fix: replace `ollama-chat` with `ollama` in `model_provider`, `oss_provider`, or `--local-provider`.\nMore info: https://github.com/openai/codex/discussions/7782";
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, JsonSchema)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum WireApi {
|
||||
#[default]
|
||||
Responses,
|
||||
}
|
||||
|
||||
impl fmt::Display for WireApi {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let value = match self {
|
||||
Self::Responses => "responses",
|
||||
};
|
||||
f.write_str(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for WireApi {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let value = String::deserialize(deserializer)?;
|
||||
match value.as_str() {
|
||||
"responses" => Ok(Self::Responses),
|
||||
"chat" => Err(serde::de::Error::custom(CHAT_WIRE_API_REMOVED_ERROR)),
|
||||
_ => Err(serde::de::Error::unknown_variant(&value, &["responses"])),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, JsonSchema)]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
pub struct ModelProviderInfo {
|
||||
pub name: String,
|
||||
pub base_url: Option<String>,
|
||||
pub env_key: Option<String>,
|
||||
pub env_key_instructions: Option<String>,
|
||||
pub experimental_bearer_token: Option<String>,
|
||||
#[serde(default)]
|
||||
pub wire_api: WireApi,
|
||||
pub query_params: Option<HashMap<String, String>>,
|
||||
pub http_headers: Option<HashMap<String, String>>,
|
||||
pub env_http_headers: Option<HashMap<String, String>>,
|
||||
pub request_max_retries: Option<u64>,
|
||||
pub stream_max_retries: Option<u64>,
|
||||
pub stream_idle_timeout_ms: Option<u64>,
|
||||
pub websocket_connect_timeout_ms: Option<u64>,
|
||||
#[serde(default)]
|
||||
pub requires_openai_auth: bool,
|
||||
#[serde(default)]
|
||||
pub supports_websockets: bool,
|
||||
}
|
||||
|
||||
impl ModelProviderInfo {
|
||||
fn build_header_map(&self) -> HeaderMap {
|
||||
let capacity = self.http_headers.as_ref().map_or(0, HashMap::len)
|
||||
+ self.env_http_headers.as_ref().map_or(0, HashMap::len);
|
||||
let mut headers = HeaderMap::with_capacity(capacity);
|
||||
if let Some(extra) = &self.http_headers {
|
||||
for (k, v) in extra {
|
||||
if let (Ok(name), Ok(value)) = (HeaderName::try_from(k), HeaderValue::try_from(v)) {
|
||||
headers.insert(name, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(env_headers) = &self.env_http_headers {
|
||||
for (header, env_var) in env_headers {
|
||||
if let Ok(val) = std::env::var(env_var)
|
||||
&& !val.trim().is_empty()
|
||||
&& let (Ok(name), Ok(value)) =
|
||||
(HeaderName::try_from(header), HeaderValue::try_from(val))
|
||||
{
|
||||
headers.insert(name, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
headers
|
||||
}
|
||||
|
||||
pub fn to_api_provider(
|
||||
&self,
|
||||
auth_mode: Option<ApiAuthMode>,
|
||||
) -> Result<ApiProvider, EnvVarError> {
|
||||
let default_base_url = if matches!(
|
||||
auth_mode,
|
||||
Some(ApiAuthMode::Chatgpt | ApiAuthMode::ChatgptAuthTokens)
|
||||
) {
|
||||
"https://chatgpt.com/backend-api/codex"
|
||||
} else {
|
||||
"https://api.openai.com/v1"
|
||||
};
|
||||
let base_url = self
|
||||
.base_url
|
||||
.clone()
|
||||
.unwrap_or_else(|| default_base_url.to_string());
|
||||
|
||||
let retry = ApiRetryConfig {
|
||||
max_attempts: self.request_max_retries(),
|
||||
base_delay: Duration::from_millis(200),
|
||||
retry_429: false,
|
||||
retry_5xx: true,
|
||||
retry_transport: true,
|
||||
};
|
||||
|
||||
Ok(ApiProvider {
|
||||
name: self.name.clone(),
|
||||
base_url,
|
||||
query_params: self.query_params.clone(),
|
||||
headers: self.build_header_map(),
|
||||
retry,
|
||||
stream_idle_timeout: self.stream_idle_timeout(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn api_key(&self) -> Result<Option<String>, EnvVarError> {
|
||||
match &self.env_key {
|
||||
Some(env_key) => {
|
||||
let api_key = std::env::var(env_key)
|
||||
.ok()
|
||||
.filter(|v| !v.trim().is_empty())
|
||||
.ok_or_else(|| EnvVarError {
|
||||
var: env_key.clone(),
|
||||
instructions: self.env_key_instructions.clone(),
|
||||
})?;
|
||||
Ok(Some(api_key))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn request_max_retries(&self) -> u64 {
|
||||
self.request_max_retries
|
||||
.unwrap_or(DEFAULT_REQUEST_MAX_RETRIES)
|
||||
.min(MAX_REQUEST_MAX_RETRIES)
|
||||
}
|
||||
|
||||
pub fn stream_max_retries(&self) -> u64 {
|
||||
self.stream_max_retries
|
||||
.unwrap_or(DEFAULT_STREAM_MAX_RETRIES)
|
||||
.min(MAX_STREAM_MAX_RETRIES)
|
||||
}
|
||||
|
||||
pub fn stream_idle_timeout(&self) -> Duration {
|
||||
self.stream_idle_timeout_ms
|
||||
.map(Duration::from_millis)
|
||||
.unwrap_or(Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS))
|
||||
}
|
||||
|
||||
pub fn websocket_connect_timeout(&self) -> Duration {
|
||||
self.websocket_connect_timeout_ms
|
||||
.map(Duration::from_millis)
|
||||
.unwrap_or(Duration::from_millis(DEFAULT_WEBSOCKET_CONNECT_TIMEOUT_MS))
|
||||
}
|
||||
|
||||
pub fn create_openai_provider(base_url: Option<String>) -> ModelProviderInfo {
|
||||
ModelProviderInfo {
|
||||
name: OPENAI_PROVIDER_NAME.into(),
|
||||
base_url,
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: Some(
|
||||
[("version".to_string(), env!("CARGO_PKG_VERSION").to_string())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
env_http_headers: Some(
|
||||
[
|
||||
(
|
||||
"OpenAI-Organization".to_string(),
|
||||
"OPENAI_ORGANIZATION".to_string(),
|
||||
),
|
||||
("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
websocket_connect_timeout_ms: None,
|
||||
requires_openai_auth: true,
|
||||
supports_websockets: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_openai(&self) -> bool {
|
||||
self.name == OPENAI_PROVIDER_NAME
|
||||
}
|
||||
}
|
||||
|
||||
pub const DEFAULT_LMSTUDIO_PORT: u16 = 1234;
|
||||
pub const DEFAULT_OLLAMA_PORT: u16 = 11434;
|
||||
|
||||
pub const LMSTUDIO_OSS_PROVIDER_ID: &str = "lmstudio";
|
||||
pub const OLLAMA_OSS_PROVIDER_ID: &str = "ollama";
|
||||
|
||||
pub fn built_in_model_providers(
|
||||
openai_base_url: Option<String>,
|
||||
) -> HashMap<String, ModelProviderInfo> {
|
||||
use ModelProviderInfo as P;
|
||||
let openai_provider = P::create_openai_provider(openai_base_url);
|
||||
|
||||
[
|
||||
(OPENAI_PROVIDER_ID, openai_provider),
|
||||
(
|
||||
OLLAMA_OSS_PROVIDER_ID,
|
||||
create_oss_provider(DEFAULT_OLLAMA_PORT, WireApi::Responses),
|
||||
),
|
||||
(
|
||||
LMSTUDIO_OSS_PROVIDER_ID,
|
||||
create_oss_provider(DEFAULT_LMSTUDIO_PORT, WireApi::Responses),
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k.to_string(), v))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn create_oss_provider(default_provider_port: u16, wire_api: WireApi) -> ModelProviderInfo {
|
||||
let default_codex_oss_base_url = format!(
|
||||
"http://localhost:{codex_oss_port}/v1",
|
||||
codex_oss_port = std::env::var("CODEX_OSS_PORT")
|
||||
.ok()
|
||||
.filter(|value| !value.trim().is_empty())
|
||||
.and_then(|value| value.parse::<u16>().ok())
|
||||
.unwrap_or(default_provider_port)
|
||||
);
|
||||
|
||||
let codex_oss_base_url = std::env::var("CODEX_OSS_BASE_URL")
|
||||
.ok()
|
||||
.filter(|v| !v.trim().is_empty())
|
||||
.unwrap_or(default_codex_oss_base_url);
|
||||
create_oss_provider_with_base_url(&codex_oss_base_url, wire_api)
|
||||
}
|
||||
|
||||
pub fn create_oss_provider_with_base_url(base_url: &str, wire_api: WireApi) -> ModelProviderInfo {
|
||||
ModelProviderInfo {
|
||||
name: "gpt-oss".into(),
|
||||
base_url: Some(base_url.into()),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
experimental_bearer_token: None,
|
||||
wire_api,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: None,
|
||||
stream_max_retries: None,
|
||||
stream_idle_timeout_ms: None,
|
||||
websocket_connect_timeout_ms: None,
|
||||
requires_openai_auth: false,
|
||||
supports_websockets: false,
|
||||
}
|
||||
}
|
||||
163
codex-rs/codex-auth/src/token_data.rs
Normal file
163
codex-rs/codex-auth/src/token_data.rs
Normal file
@@ -0,0 +1,163 @@
|
||||
use base64::Engine;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Default)]
|
||||
pub struct TokenData {
|
||||
#[serde(
|
||||
deserialize_with = "deserialize_id_token",
|
||||
serialize_with = "serialize_id_token"
|
||||
)]
|
||||
pub id_token: IdTokenInfo,
|
||||
pub access_token: String,
|
||||
pub refresh_token: String,
|
||||
pub account_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
|
||||
pub struct IdTokenInfo {
|
||||
pub email: Option<String>,
|
||||
pub chatgpt_plan_type: Option<PlanType>,
|
||||
pub chatgpt_user_id: Option<String>,
|
||||
pub chatgpt_account_id: Option<String>,
|
||||
pub raw_jwt: String,
|
||||
}
|
||||
|
||||
impl IdTokenInfo {
|
||||
pub fn get_chatgpt_plan_type(&self) -> Option<String> {
|
||||
self.chatgpt_plan_type.as_ref().map(|t| match t {
|
||||
PlanType::Known(plan) => format!("{plan:?}"),
|
||||
PlanType::Unknown(s) => s.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn is_workspace_account(&self) -> bool {
|
||||
matches!(
|
||||
self.chatgpt_plan_type,
|
||||
Some(PlanType::Known(
|
||||
KnownPlan::Team | KnownPlan::Business | KnownPlan::Enterprise | KnownPlan::Edu
|
||||
))
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum PlanType {
|
||||
Known(KnownPlan),
|
||||
Unknown(String),
|
||||
}
|
||||
|
||||
impl PlanType {
|
||||
pub fn from_raw_value(raw: &str) -> Self {
|
||||
match raw.to_ascii_lowercase().as_str() {
|
||||
"free" => Self::Known(KnownPlan::Free),
|
||||
"go" => Self::Known(KnownPlan::Go),
|
||||
"plus" => Self::Known(KnownPlan::Plus),
|
||||
"pro" => Self::Known(KnownPlan::Pro),
|
||||
"team" => Self::Known(KnownPlan::Team),
|
||||
"business" => Self::Known(KnownPlan::Business),
|
||||
"enterprise" => Self::Known(KnownPlan::Enterprise),
|
||||
"education" | "edu" => Self::Known(KnownPlan::Edu),
|
||||
_ => Self::Unknown(raw.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum KnownPlan {
|
||||
Free,
|
||||
Go,
|
||||
Plus,
|
||||
Pro,
|
||||
Team,
|
||||
Business,
|
||||
Enterprise,
|
||||
Edu,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct IdClaims {
|
||||
#[serde(default)]
|
||||
email: Option<String>,
|
||||
#[serde(rename = "https://api.openai.com/profile", default)]
|
||||
profile: Option<ProfileClaims>,
|
||||
#[serde(rename = "https://api.openai.com/auth", default)]
|
||||
auth: Option<AuthClaims>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ProfileClaims {
|
||||
#[serde(default)]
|
||||
email: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AuthClaims {
|
||||
#[serde(default)]
|
||||
chatgpt_plan_type: Option<PlanType>,
|
||||
#[serde(default)]
|
||||
chatgpt_user_id: Option<String>,
|
||||
#[serde(default)]
|
||||
user_id: Option<String>,
|
||||
#[serde(default)]
|
||||
chatgpt_account_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum IdTokenInfoError {
|
||||
#[error("invalid ID token format")]
|
||||
InvalidFormat,
|
||||
#[error(transparent)]
|
||||
Base64(#[from] base64::DecodeError),
|
||||
#[error(transparent)]
|
||||
Json(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
pub fn parse_chatgpt_jwt_claims(jwt: &str) -> Result<IdTokenInfo, IdTokenInfoError> {
|
||||
let mut parts = jwt.split('.');
|
||||
let (_header_b64, payload_b64, _sig_b64) = match (parts.next(), parts.next(), parts.next()) {
|
||||
(Some(h), Some(p), Some(s)) if !h.is_empty() && !p.is_empty() && !s.is_empty() => (h, p, s),
|
||||
_ => return Err(IdTokenInfoError::InvalidFormat),
|
||||
};
|
||||
|
||||
let payload_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload_b64)?;
|
||||
let claims: IdClaims = serde_json::from_slice(&payload_bytes)?;
|
||||
let email = claims
|
||||
.email
|
||||
.or_else(|| claims.profile.and_then(|profile| profile.email));
|
||||
|
||||
match claims.auth {
|
||||
Some(auth) => Ok(IdTokenInfo {
|
||||
email,
|
||||
raw_jwt: jwt.to_string(),
|
||||
chatgpt_plan_type: auth.chatgpt_plan_type,
|
||||
chatgpt_user_id: auth.chatgpt_user_id.or(auth.user_id),
|
||||
chatgpt_account_id: auth.chatgpt_account_id,
|
||||
}),
|
||||
None => Ok(IdTokenInfo {
|
||||
email,
|
||||
raw_jwt: jwt.to_string(),
|
||||
chatgpt_plan_type: None,
|
||||
chatgpt_user_id: None,
|
||||
chatgpt_account_id: None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn deserialize_id_token<'de, D>(deserializer: D) -> Result<IdTokenInfo, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
parse_chatgpt_jwt_claims(&s).map_err(serde::de::Error::custom)
|
||||
}
|
||||
|
||||
fn serialize_id_token<S>(id_token: &IdTokenInfo, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(&id_token.raw_jwt)
|
||||
}
|
||||
110
codex-rs/codex-auth/src/token_data_tests.rs
Normal file
110
codex-rs/codex-auth/src/token_data_tests.rs
Normal file
@@ -0,0 +1,110 @@
|
||||
use base64::Engine;
|
||||
use super::token_data::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde::Serialize;
|
||||
|
||||
#[test]
|
||||
fn id_token_info_parses_email_and_plan() {
|
||||
#[derive(Serialize)]
|
||||
struct Header {
|
||||
alg: &'static str,
|
||||
typ: &'static str,
|
||||
}
|
||||
let header = Header {
|
||||
alg: "none",
|
||||
typ: "JWT",
|
||||
};
|
||||
let payload = serde_json::json!({
|
||||
"email": "user@example.com",
|
||||
"https://api.openai.com/auth": {
|
||||
"chatgpt_plan_type": "pro"
|
||||
}
|
||||
});
|
||||
|
||||
fn b64url_no_pad(bytes: &[u8]) -> String {
|
||||
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
|
||||
}
|
||||
|
||||
let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
|
||||
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
|
||||
let signature_b64 = b64url_no_pad(b"sig");
|
||||
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
|
||||
|
||||
let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse");
|
||||
assert_eq!(info.email.as_deref(), Some("user@example.com"));
|
||||
assert_eq!(info.get_chatgpt_plan_type().as_deref(), Some("Pro"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn id_token_info_parses_go_plan() {
|
||||
#[derive(Serialize)]
|
||||
struct Header {
|
||||
alg: &'static str,
|
||||
typ: &'static str,
|
||||
}
|
||||
let header = Header {
|
||||
alg: "none",
|
||||
typ: "JWT",
|
||||
};
|
||||
let payload = serde_json::json!({
|
||||
"email": "user@example.com",
|
||||
"https://api.openai.com/auth": {
|
||||
"chatgpt_plan_type": "go"
|
||||
}
|
||||
});
|
||||
|
||||
fn b64url_no_pad(bytes: &[u8]) -> String {
|
||||
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
|
||||
}
|
||||
|
||||
let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
|
||||
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
|
||||
let signature_b64 = b64url_no_pad(b"sig");
|
||||
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
|
||||
|
||||
let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse");
|
||||
assert_eq!(info.email.as_deref(), Some("user@example.com"));
|
||||
assert_eq!(info.get_chatgpt_plan_type().as_deref(), Some("Go"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn id_token_info_handles_missing_fields() {
|
||||
#[derive(Serialize)]
|
||||
struct Header {
|
||||
alg: &'static str,
|
||||
typ: &'static str,
|
||||
}
|
||||
let header = Header {
|
||||
alg: "none",
|
||||
typ: "JWT",
|
||||
};
|
||||
let payload = serde_json::json!({ "sub": "123" });
|
||||
|
||||
fn b64url_no_pad(bytes: &[u8]) -> String {
|
||||
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
|
||||
}
|
||||
|
||||
let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
|
||||
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
|
||||
let signature_b64 = b64url_no_pad(b"sig");
|
||||
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
|
||||
|
||||
let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse");
|
||||
assert!(info.email.is_none());
|
||||
assert!(info.get_chatgpt_plan_type().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_account_detection_matches_workspace_plans() {
|
||||
let workspace = IdTokenInfo {
|
||||
chatgpt_plan_type: Some(PlanType::Known(KnownPlan::Business)),
|
||||
..IdTokenInfo::default()
|
||||
};
|
||||
assert_eq!(workspace.is_workspace_account(), true);
|
||||
|
||||
let personal = IdTokenInfo {
|
||||
chatgpt_plan_type: Some(PlanType::Known(KnownPlan::Pro)),
|
||||
..IdTokenInfo::default()
|
||||
};
|
||||
assert_eq!(personal.is_workspace_account(), false);
|
||||
}
|
||||
Reference in New Issue
Block a user