Compare commits

...

3 Commits

Author SHA1 Message Date
Celia Chen
17fbd32a39 Merge branch 'main' into pakrym/refactor-modelprovider-api 2026-04-15 12:24:28 -07:00
Celia Chen
25f7f2fff6 Merge branch 'main' into pakrym/refactor-modelprovider-api 2026-04-15 11:12:24 -07:00
pakrym-oai
441147f203 Wrap model providers behind runtime trait 2026-04-14 18:46:04 -07:00
11 changed files with 103 additions and 34 deletions

View File

@@ -100,6 +100,7 @@ use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::client_common::ResponseStream;
use crate::flags::CODEX_RS_SSE_FIXTURE;
use crate::model_provider::ModelProvider;
use crate::util::emit_feedback_auth_recovery_tags;
use codex_api::CoreAuthProvider;
use codex_api::map_api_error;
@@ -147,7 +148,7 @@ struct ModelClientState {
conversation_id: ThreadId,
window_generation: AtomicU64,
installation_id: String,
provider: ModelProviderInfo,
provider: Arc<dyn ModelProvider>,
auth_env_telemetry: AuthEnvTelemetry,
session_source: SessionSource,
model_verbosity: Option<VerbosityConfig>,
@@ -303,6 +304,7 @@ impl ModelClient {
.as_ref()
.is_some_and(|manager| manager.codex_api_key_env_enabled());
let auth_env_telemetry = collect_auth_env_telemetry(&provider, codex_api_key_env_enabled);
let provider = <dyn ModelProvider>::new(provider, auth_manager.clone());
Self {
state: Arc::new(ModelClientState {
auth_manager,
@@ -636,7 +638,7 @@ impl ModelClient {
///
/// WebSocket use is controlled by provider capability and session-scoped fallback state.
pub fn responses_websocket_enabled(&self) -> bool {
if !self.state.provider.supports_websockets
if !self.state.provider.info().supports_websockets
|| self.state.disable_websockets.load(Ordering::Relaxed)
|| (*CODEX_RS_SSE_FIXTURE).is_some()
{
@@ -651,15 +653,16 @@ impl ModelClient {
/// This centralizes setup used by both prewarm and normal request paths so they stay in
/// lockstep when auth/provider resolution changes.
async fn current_client_setup(&self) -> Result<CurrentClientSetup> {
let auth = match self.state.auth_manager.as_ref() {
let auth = match self.state.provider.auth_manager() {
Some(manager) => manager.auth().await,
None => None,
};
let api_provider = self
.state
.provider
.info()
.to_api_provider(auth.as_ref().map(CodexAuth::auth_mode))?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.state.provider)?;
let api_auth = auth_provider_from_auth(auth.clone(), self.state.provider.info())?;
Ok(CurrentClientSetup {
auth,
api_provider,
@@ -689,7 +692,7 @@ impl ModelClient {
request_route_telemetry,
self.state.auth_env_telemetry.clone(),
);
let websocket_connect_timeout = self.state.provider.websocket_connect_timeout();
let websocket_connect_timeout = self.state.provider.info().websocket_connect_timeout();
let start = Instant::now();
let result = match tokio::time::timeout(
websocket_connect_timeout,
@@ -1033,8 +1036,8 @@ impl ModelClientSession {
level = "info",
skip_all,
fields(
provider = %self.client.state.provider.name,
wire_api = %self.client.state.provider.wire_api,
provider = %self.client.state.provider.info().name,
wire_api = %self.client.state.provider.info().wire_api,
transport = "responses_websocket",
api.path = "responses",
turn.has_metadata_header = params.turn_metadata_header.is_some()
@@ -1105,7 +1108,7 @@ impl ModelClientSession {
fn responses_request_compression(&self, auth: Option<&CodexAuth>) -> Compression {
if self.client.state.enable_request_compression
&& auth.is_some_and(CodexAuth::is_chatgpt_auth)
&& self.client.state.provider.is_openai()
&& self.client.state.provider.info().is_openai()
{
Compression::Zstd
} else {
@@ -1124,7 +1127,7 @@ impl ModelClientSession {
skip_all,
fields(
model = %model_info.slug,
wire_api = %self.client.state.provider.wire_api,
wire_api = %self.client.state.provider.info().wire_api,
transport = "responses_http",
http.method = "POST",
api.path = "responses",
@@ -1145,7 +1148,7 @@ impl ModelClientSession {
warn!(path, "Streaming from fixture");
let stream = codex_api::stream_from_fixture(
path,
self.client.state.provider.stream_idle_timeout(),
self.client.state.provider.info().stream_idle_timeout(),
)
.map_err(map_api_error)?;
let (stream, _last_request_rx) = map_response_stream(stream, session_telemetry.clone());
@@ -1221,7 +1224,7 @@ impl ModelClientSession {
skip_all,
fields(
model = %model_info.slug,
wire_api = %self.client.state.provider.wire_api,
wire_api = %self.client.state.provider.info().wire_api,
transport = "responses_websocket",
api.path = "responses",
turn.has_metadata_header = turn_metadata_header.is_some(),
@@ -1432,7 +1435,7 @@ impl ModelClientSession {
service_tier: Option<ServiceTier>,
turn_metadata_header: Option<&str>,
) -> Result<ResponseStream> {
let wire_api = self.client.state.provider.wire_api;
let wire_api = self.client.state.provider.info().wire_api;
match wire_api {
WireApi::Responses => {
if self.client.responses_websocket_enabled() {

View File

@@ -77,6 +77,7 @@ use codex_login::AuthManager;
use codex_login::CodexAuth;
use codex_login::auth_env_telemetry::collect_auth_env_telemetry;
use codex_login::default_client::originator;
use codex_login::provider_auth::auth_manager_for_provider;
use codex_mcp::McpConnectionManager;
use codex_mcp::ToolInfo;
use codex_mcp::codex_apps_tools_cache_key;
@@ -193,6 +194,7 @@ use crate::config::resolve_web_search_mode_for_turn;
use crate::context_manager::ContextManager;
use crate::context_manager::TotalTokenUsageBreakdown;
use crate::environment_context::EnvironmentContext;
use crate::model_provider::ModelProvider;
use crate::thread_rollout_truncation::initial_history_has_prior_user_turns;
use codex_config::CONFIG_TOML_FILE;
use codex_config::types::McpServerConfig;
@@ -889,7 +891,7 @@ pub(crate) struct TurnContext {
pub(crate) auth_manager: Option<Arc<AuthManager>>,
pub(crate) model_info: ModelInfo,
pub(crate) session_telemetry: SessionTelemetry,
pub(crate) provider: ModelProviderInfo,
pub(crate) provider: Arc<dyn ModelProvider>,
pub(crate) reasoning_effort: Option<ReasoningEffortConfig>,
pub(crate) reasoning_summary: ReasoningSummaryConfig,
pub(crate) session_source: SessionSource,
@@ -1592,8 +1594,9 @@ impl Session {
let session_source = session_configuration.session_source.clone();
let image_generation_tool_auth_allowed =
image_generation_tool_auth_allowed(auth_manager.as_deref());
let auth_manager_for_context = auth_manager;
let provider_for_context = provider;
let auth_manager_for_context = auth_manager.clone();
let provider_auth_manager = auth_manager_for_provider(auth_manager, &provider);
let provider_for_context = <dyn ModelProvider>::new(provider, provider_auth_manager);
let session_telemetry_for_context = session_telemetry;
let tools_config = ToolsConfig::new(&ToolsConfigParams {
model_info: &model_info,
@@ -5978,7 +5981,7 @@ async fn spawn_review_thread(
));
let review_prompt = resolved.prompt.clone();
let provider = parent_turn_context.provider.clone();
let provider = Arc::clone(&parent_turn_context.provider);
let auth_manager = parent_turn_context.auth_manager.clone();
let model_info = review_model_info.clone();
@@ -6803,7 +6806,7 @@ async fn run_auto_compact(
reason: CompactionReason,
phase: CompactionPhase,
) -> CodexResult<()> {
if should_use_remote_compact_task(&turn_context.provider) {
if should_use_remote_compact_task(turn_context.provider.as_ref()) {
run_inline_remote_auto_compact_task(
Arc::clone(sess),
Arc::clone(turn_context),
@@ -7084,7 +7087,7 @@ async fn run_sampling_request(
}
// Use the configured provider-specific stream retry budget.
let max_retries = turn_context.provider.stream_max_retries();
let max_retries = turn_context.provider.info().stream_max_retries();
if retries >= max_retries
&& client_session.try_switch_fallback_transport(
&turn_context.session_telemetry,

View File

@@ -7,6 +7,7 @@ use crate::exec::ExecCapturePolicy;
use crate::exec::ExecParams;
use crate::exec_policy::ExecPolicyManager;
use crate::guardian::GUARDIAN_REVIEWER_NAME;
use crate::model_provider::ModelProvider;
use crate::sandboxing::SandboxPermissions;
use crate::tools::context::FunctionToolOutput;
use crate::turn_diff_tracker::TurnDiffTracker;
@@ -101,7 +102,10 @@ async fn guardian_allows_shell_additional_permissions_requests_past_policy_valid
));
session.services.models_manager = models_manager;
turn_context_raw.config = Arc::clone(&config);
turn_context_raw.provider = config.model_provider.clone();
turn_context_raw.provider = <dyn ModelProvider>::new(
config.model_provider.clone(),
turn_context_raw.auth_manager.clone(),
);
let session = Arc::new(session);
let turn_context = Arc::new(turn_context_raw);
let expiration_ms: u64 = if cfg!(windows) { 2_500 } else { 1_000 };

View File

@@ -19,7 +19,6 @@ use codex_analytics::CompactionStrategy;
use codex_analytics::CompactionTrigger;
use codex_analytics::now_unix_seconds;
use codex_features::Feature;
use codex_model_provider_info::ModelProviderInfo;
use codex_protocol::error::CodexErr;
use codex_protocol::error::Result as CodexResult;
use codex_protocol::items::ContextCompactionItem;
@@ -38,6 +37,8 @@ use codex_utils_output_truncation::truncate_text;
use futures::prelude::*;
use tracing::error;
use crate::model_provider::ModelProvider;
pub const SUMMARIZATION_PROMPT: &str = include_str!("../templates/compact/prompt.md");
pub const SUMMARY_PREFIX: &str = include_str!("../templates/compact/summary_prefix.md");
const COMPACT_USER_MESSAGE_MAX_TOKENS: usize = 20_000;
@@ -57,8 +58,8 @@ pub(crate) enum InitialContextInjection {
DoNotInject,
}
pub(crate) fn should_use_remote_compact_task(provider: &ModelProviderInfo) -> bool {
provider.is_openai()
pub(crate) fn should_use_remote_compact_task(provider: &dyn ModelProvider) -> bool {
provider.info().is_openai()
}
pub(crate) async fn run_inline_auto_compact_task(
@@ -166,7 +167,7 @@ async fn run_compact_task_inner_impl(
let mut truncated_count = 0usize;
let max_retries = turn_context.provider.stream_max_retries();
let max_retries = turn_context.provider.info().stream_max_retries();
let mut retries = 0;
let mut client_session = sess.services.model_client.new_session();
// Reuse one client session so turn-scoped state (sticky routing, websocket incremental

View File

@@ -14,6 +14,7 @@ use crate::config_loader::NetworkDomainPermissionToml;
use crate::config_loader::NetworkDomainPermissionsToml;
use crate::config_loader::RequirementSource;
use crate::config_loader::Sourced;
use crate::model_provider::ModelProvider;
use crate::test_support;
use codex_config::config_toml::ConfigToml;
use codex_network_proxy::NetworkProxyConfig;
@@ -82,7 +83,8 @@ async fn guardian_test_session_and_turn_with_base_url(
));
session.services.models_manager = models_manager;
turn.config = Arc::clone(&config);
turn.provider = config.model_provider.clone();
turn.provider =
<dyn ModelProvider>::new(config.model_provider.clone(), turn.auth_manager.clone());
turn.user_instructions = None;
(Arc::new(session), Arc::new(turn))
@@ -888,7 +890,8 @@ async fn guardian_review_request_layout_matches_model_visible_request_snapshot()
));
session.services.models_manager = models_manager;
turn.config = Arc::clone(&config);
turn.provider = config.model_provider.clone();
turn.provider =
<dyn ModelProvider>::new(config.model_provider.clone(), turn.auth_manager.clone());
let session = Arc::new(session);
let turn = Arc::new(turn);
seed_guardian_parent_history(&session, &turn).await;
@@ -1260,7 +1263,8 @@ async fn guardian_review_surfaces_responses_api_errors_in_rejection_reason() ->
.models_manager = models_manager;
let turn_mut = Arc::get_mut(&mut turn).expect("turn should be uniquely owned");
turn_mut.config = Arc::clone(&config);
turn_mut.provider = config.model_provider.clone();
turn_mut.provider =
<dyn ModelProvider>::new(config.model_provider.clone(), turn_mut.auth_manager.clone());
turn_mut.user_instructions = None;
seed_guardian_parent_history(&session, &turn).await;

View File

@@ -61,6 +61,7 @@ mod mcp_tool_call;
mod memories;
pub(crate) mod mention_syntax;
pub(crate) mod message_history;
mod model_provider;
pub(crate) mod utils;
pub use mention_syntax::PLUGIN_TEXT_MENTION_SIGIL;
pub use mention_syntax::TOOL_MENTION_SIGIL;

View File

@@ -2,6 +2,7 @@ use super::*;
use crate::codex::make_session_and_context;
use crate::codex::make_session_and_context_with_rx;
use crate::config::ConfigBuilder;
use crate::model_provider::ModelProvider;
use crate::state::ActiveTurn;
use codex_config::CONFIG_TOML_FILE;
use codex_config::config_toml::ConfigToml;
@@ -1320,7 +1321,10 @@ async fn guardian_mode_skips_auto_when_annotations_do_not_require_approval() {
));
session.services.models_manager = models_manager;
turn_context.config = Arc::clone(&config);
turn_context.provider = config.model_provider.clone();
turn_context.provider = <dyn ModelProvider>::new(
config.model_provider.clone(),
turn_context.auth_manager.clone(),
);
let session = Arc::new(session);
let turn_context = Arc::new(turn_context);
@@ -1395,7 +1399,10 @@ async fn guardian_mode_mcp_denial_returns_rationale_message() {
));
session.services.models_manager = models_manager;
turn_context.config = Arc::clone(&config);
turn_context.provider = config.model_provider.clone();
turn_context.provider = <dyn ModelProvider>::new(
config.model_provider.clone(),
turn_context.auth_manager.clone(),
);
let session = Arc::new(session);
let turn_context = Arc::new(turn_context);
@@ -1843,7 +1850,10 @@ async fn approve_mode_routes_arc_ask_user_to_guardian_when_guardian_reviewer_is_
));
session.services.models_manager = models_manager;
turn_context.config = Arc::clone(&config);
turn_context.provider = config.model_provider.clone();
turn_context.provider = <dyn ModelProvider>::new(
config.model_provider.clone(),
turn_context.auth_manager.clone(),
);
let session = Arc::new(session);
let turn_context = Arc::new(turn_context);

View File

@@ -0,0 +1,42 @@
use std::fmt;
use std::sync::Arc;
use codex_login::AuthManager;
use codex_model_provider_info::ModelProviderInfo;
/// Runtime provider abstraction used by turn execution.
///
/// Implementations own provider-specific behavior for a model backend. The
/// `ModelProviderInfo` returned by `info` is the serialized/configured provider
/// metadata used by the generic OpenAI-compatible implementation.
pub(crate) trait ModelProvider: fmt::Debug + Send + Sync {
fn info(&self) -> &ModelProviderInfo;
fn auth_manager(&self) -> Option<&AuthManager>;
}
impl dyn ModelProvider {
pub(crate) fn new(
info: ModelProviderInfo,
auth_manager: Option<Arc<AuthManager>>,
) -> Arc<Self> {
Arc::new(GenericModelProvider { info, auth_manager })
}
}
/// Generic OpenAI-compatible model provider backed by a `ModelProviderInfo`.
#[derive(Clone, Debug)]
struct GenericModelProvider {
info: ModelProviderInfo,
auth_manager: Option<Arc<AuthManager>>,
}
impl ModelProvider for GenericModelProvider {
fn info(&self) -> &ModelProviderInfo {
&self.info
}
fn auth_manager(&self) -> Option<&AuthManager> {
self.auth_manager.as_deref()
}
}

View File

@@ -27,7 +27,7 @@ impl SessionTask for CompactTask {
_cancellation_token: CancellationToken,
) -> Option<String> {
let session = session.clone_session();
let _ = if crate::compact::should_use_remote_compact_task(&ctx.provider) {
let _ = if crate::compact::should_use_remote_compact_task(ctx.provider.as_ref()) {
session.services.session_telemetry.counter(
"codex.task.compact",
/*inc*/ 1,

View File

@@ -224,7 +224,7 @@ fn build_agent_shared_config(turn: &TurnContext) -> Result<Config, FunctionCallE
let base_config = turn.config.clone();
let mut config = (*base_config).clone();
config.model = Some(turn.model_info.slug.clone());
config.model_provider = turn.provider.clone();
config.model_provider = turn.provider.info().clone();
config.model_reasoning_effort = turn
.reasoning_effort
.or(turn.model_info.default_reasoning_level);

View File

@@ -5,6 +5,7 @@ use crate::codex::make_session_and_context;
use crate::config::AgentRoleConfig;
use crate::config::DEFAULT_AGENT_MAX_DEPTH;
use crate::function_tool::FunctionCallError;
use crate::model_provider::ModelProvider;
use crate::session_prefix::format_subagent_notification_message;
use crate::state::TaskKind;
use crate::tasks::SessionTask;
@@ -365,7 +366,7 @@ async fn spawn_agent_uses_explorer_role_and_preserves_approval_policy() {
turn.approval_policy
.set(AskForApproval::OnRequest)
.expect("approval policy should be set");
turn.provider = provider;
turn.provider = <dyn ModelProvider>::new(provider, turn.auth_manager.clone());
turn.config = Arc::new(config);
let invocation = invocation(
@@ -3505,7 +3506,7 @@ async fn build_agent_spawn_config_uses_turn_context_values() {
let mut expected = (*turn.config).clone();
expected.base_instructions = Some(base_instructions.text);
expected.model = Some(turn.model_info.slug.clone());
expected.model_provider = turn.provider.clone();
expected.model_provider = turn.provider.info().clone();
expected.model_reasoning_effort = turn.reasoning_effort;
expected.model_reasoning_summary = Some(turn.reasoning_summary);
expected.developer_instructions = turn.developer_instructions.clone();
@@ -3559,7 +3560,7 @@ async fn build_agent_resume_config_clears_base_instructions() {
let mut expected = (*turn.config).clone();
expected.base_instructions = None;
expected.model = Some(turn.model_info.slug.clone());
expected.model_provider = turn.provider.clone();
expected.model_provider = turn.provider.info().clone();
expected.model_reasoning_effort = turn.reasoning_effort;
expected.model_reasoning_summary = Some(turn.reasoning_summary);
expected.developer_instructions = turn.developer_instructions.clone();