mirror of
https://github.com/openai/codex.git
synced 2026-04-28 02:11:08 +03:00
feat: merge remote models instead of destructing (#7997)
- merge remote models instead of destructing - make config values have more precedent over remote values
This commit is contained in:
@@ -6,6 +6,7 @@ use codex_protocol::openai_models::ModelInfo;
|
||||
use codex_protocol::openai_models::ModelPreset;
|
||||
use codex_protocol::openai_models::ModelsResponse;
|
||||
use http::HeaderMap;
|
||||
use std::collections::HashSet;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
@@ -35,7 +36,7 @@ const CODEX_AUTO_BALANCED_MODEL: &str = "codex-auto-balanced";
|
||||
#[derive(Debug)]
|
||||
pub struct ModelsManager {
|
||||
// todo(aibrahim) merge available_models and model family creation into one struct
|
||||
available_models: RwLock<Vec<ModelPreset>>,
|
||||
local_models: Vec<ModelPreset>,
|
||||
remote_models: RwLock<Vec<ModelInfo>>,
|
||||
auth_manager: Arc<AuthManager>,
|
||||
etag: RwLock<Option<String>>,
|
||||
@@ -49,7 +50,7 @@ impl ModelsManager {
|
||||
pub fn new(auth_manager: Arc<AuthManager>) -> Self {
|
||||
let codex_home = auth_manager.codex_home().to_path_buf();
|
||||
Self {
|
||||
available_models: RwLock::new(builtin_model_presets(auth_manager.get_auth_mode())),
|
||||
local_models: builtin_model_presets(auth_manager.get_auth_mode()),
|
||||
remote_models: RwLock::new(Vec::new()),
|
||||
auth_manager,
|
||||
etag: RwLock::new(None),
|
||||
@@ -64,7 +65,7 @@ impl ModelsManager {
|
||||
pub fn with_provider(auth_manager: Arc<AuthManager>, provider: ModelProviderInfo) -> Self {
|
||||
let codex_home = auth_manager.codex_home().to_path_buf();
|
||||
Self {
|
||||
available_models: RwLock::new(builtin_model_presets(auth_manager.get_auth_mode())),
|
||||
local_models: builtin_model_presets(auth_manager.get_auth_mode()),
|
||||
remote_models: RwLock::new(Vec::new()),
|
||||
auth_manager,
|
||||
etag: RwLock::new(None),
|
||||
@@ -107,13 +108,13 @@ impl ModelsManager {
|
||||
if let Err(err) = self.refresh_available_models(config).await {
|
||||
error!("failed to refresh available models: {err}");
|
||||
}
|
||||
self.available_models.read().await.clone()
|
||||
let remote_models = self.remote_models.read().await.clone();
|
||||
self.build_available_models(remote_models)
|
||||
}
|
||||
|
||||
pub fn try_list_models(&self) -> Result<Vec<ModelPreset>, TryLockError> {
|
||||
self.available_models
|
||||
.try_read()
|
||||
.map(|models| models.clone())
|
||||
let remote_models = self.remote_models.try_read()?.clone();
|
||||
Ok(self.build_available_models(remote_models))
|
||||
}
|
||||
|
||||
fn find_family_for_model(slug: &str) -> ModelFamily {
|
||||
@@ -123,8 +124,8 @@ impl ModelsManager {
|
||||
/// Look up the requested model family while applying remote metadata overrides.
|
||||
pub async fn construct_model_family(&self, model: &str, config: &Config) -> ModelFamily {
|
||||
Self::find_family_for_model(model)
|
||||
.with_config_overrides(config)
|
||||
.with_remote_overrides(self.remote_models.read().await.clone())
|
||||
.with_config_overrides(config)
|
||||
}
|
||||
|
||||
pub async fn get_model(&self, model: &Option<String>, config: &Config) -> String {
|
||||
@@ -136,11 +137,10 @@ impl ModelsManager {
|
||||
}
|
||||
// if codex-auto-balanced exists & signed in with chatgpt mode, return it, otherwise return the default model
|
||||
let auth_mode = self.auth_manager.get_auth_mode();
|
||||
let remote_models = self.remote_models.read().await.clone();
|
||||
if auth_mode == Some(AuthMode::ChatGPT)
|
||||
&& self
|
||||
.available_models
|
||||
.read()
|
||||
.await
|
||||
.build_available_models(remote_models)
|
||||
.iter()
|
||||
.any(|m| m.model == CODEX_AUTO_BALANCED_MODEL)
|
||||
{
|
||||
@@ -163,7 +163,6 @@ impl ModelsManager {
|
||||
/// Replace the cached remote models and rebuild the derived presets list.
|
||||
async fn apply_remote_models(&self, models: Vec<ModelInfo>) {
|
||||
*self.remote_models.write().await = models;
|
||||
self.build_available_models().await;
|
||||
}
|
||||
|
||||
/// Attempt to satisfy the refresh from the cache when it matches the provider and TTL.
|
||||
@@ -203,22 +202,55 @@ impl ModelsManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert remote model metadata into picker-ready presets, marking defaults.
|
||||
async fn build_available_models(&self) {
|
||||
let mut available_models = self.remote_models.read().await.clone();
|
||||
available_models.sort_by(|a, b| a.priority.cmp(&b.priority));
|
||||
let mut model_presets: Vec<ModelPreset> = available_models
|
||||
.into_iter()
|
||||
.map(Into::into)
|
||||
.filter(|preset: &ModelPreset| preset.show_in_picker)
|
||||
.collect();
|
||||
if let Some(default) = model_presets.first_mut() {
|
||||
/// Merge remote model metadata into picker-ready presets, preserving existing entries.
|
||||
fn build_available_models(&self, mut remote_models: Vec<ModelInfo>) -> Vec<ModelPreset> {
|
||||
remote_models.sort_by(|a, b| a.priority.cmp(&b.priority));
|
||||
|
||||
let remote_presets: Vec<ModelPreset> = remote_models.into_iter().map(Into::into).collect();
|
||||
let existing_presets = self.local_models.clone();
|
||||
let mut merged_presets = Self::merge_presets(remote_presets, existing_presets);
|
||||
merged_presets = Self::filter_visible_models(merged_presets);
|
||||
|
||||
let has_default = merged_presets.iter().any(|preset| preset.is_default);
|
||||
if let Some(default) = merged_presets.first_mut()
|
||||
&& !has_default
|
||||
{
|
||||
default.is_default = true;
|
||||
}
|
||||
{
|
||||
let mut available_models_guard = self.available_models.write().await;
|
||||
*available_models_guard = model_presets;
|
||||
|
||||
merged_presets
|
||||
}
|
||||
|
||||
fn filter_visible_models(models: Vec<ModelPreset>) -> Vec<ModelPreset> {
|
||||
models
|
||||
.into_iter()
|
||||
.filter(|model| model.show_in_picker)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn merge_presets(
|
||||
remote_presets: Vec<ModelPreset>,
|
||||
existing_presets: Vec<ModelPreset>,
|
||||
) -> Vec<ModelPreset> {
|
||||
if remote_presets.is_empty() {
|
||||
return existing_presets;
|
||||
}
|
||||
|
||||
let remote_slugs: HashSet<&str> = remote_presets
|
||||
.iter()
|
||||
.map(|preset| preset.model.as_str())
|
||||
.collect();
|
||||
|
||||
let mut merged_presets = remote_presets.clone();
|
||||
for mut preset in existing_presets {
|
||||
if remote_slugs.contains(preset.model.as_str()) {
|
||||
continue;
|
||||
}
|
||||
preset.is_default = false;
|
||||
merged_presets.push(preset);
|
||||
}
|
||||
|
||||
merged_presets
|
||||
}
|
||||
|
||||
fn cache_path(&self) -> PathBuf {
|
||||
@@ -261,11 +293,21 @@ mod tests {
|
||||
use crate::model_provider_info::WireApi;
|
||||
use codex_protocol::openai_models::ModelsResponse;
|
||||
use core_test_support::responses::mount_models_once;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tempfile::tempdir;
|
||||
use wiremock::MockServer;
|
||||
|
||||
fn remote_model(slug: &str, display: &str, priority: i32) -> ModelInfo {
|
||||
remote_model_with_visibility(slug, display, priority, "list")
|
||||
}
|
||||
|
||||
fn remote_model_with_visibility(
|
||||
slug: &str,
|
||||
display: &str,
|
||||
priority: i32,
|
||||
visibility: &str,
|
||||
) -> ModelInfo {
|
||||
serde_json::from_value(json!({
|
||||
"slug": slug,
|
||||
"display_name": display,
|
||||
@@ -273,7 +315,7 @@ mod tests {
|
||||
"default_reasoning_level": "medium",
|
||||
"supported_reasoning_levels": [{"effort": "low", "description": "low"}, {"effort": "medium", "description": "medium"}],
|
||||
"shell_type": "shell_command",
|
||||
"visibility": "list",
|
||||
"visibility": visibility,
|
||||
"minimal_client_version": [0, 1, 0],
|
||||
"supported_in_api": true,
|
||||
"priority": priority,
|
||||
@@ -347,14 +389,23 @@ mod tests {
|
||||
assert_eq!(cached_remote, remote_models);
|
||||
|
||||
let available = manager.list_models(&config).await;
|
||||
assert_eq!(available.len(), 2);
|
||||
assert_eq!(available[0].model, "priority-high");
|
||||
let high_idx = available
|
||||
.iter()
|
||||
.position(|model| model.model == "priority-high")
|
||||
.expect("priority-high should be listed");
|
||||
let low_idx = available
|
||||
.iter()
|
||||
.position(|model| model.model == "priority-low")
|
||||
.expect("priority-low should be listed");
|
||||
assert!(
|
||||
available[0].is_default,
|
||||
high_idx < low_idx,
|
||||
"higher priority should be listed before lower priority"
|
||||
);
|
||||
assert!(
|
||||
available[high_idx].is_default,
|
||||
"highest priority should be default"
|
||||
);
|
||||
assert_eq!(available[1].model, "priority-low");
|
||||
assert!(!available[1].is_default);
|
||||
assert!(!available[low_idx].is_default);
|
||||
assert_eq!(
|
||||
models_mock.requests().len(),
|
||||
1,
|
||||
@@ -493,4 +544,94 @@ mod tests {
|
||||
"stale cache refresh should fetch /models once"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn refresh_available_models_drops_removed_remote_models() {
|
||||
let server = MockServer::start().await;
|
||||
let initial_models = vec![remote_model("remote-old", "Remote Old", 1)];
|
||||
let initial_mock = mount_models_once(
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: initial_models,
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
let codex_home = tempdir().expect("temp dir");
|
||||
let mut config = Config::load_from_base_config_with_overrides(
|
||||
ConfigToml::default(),
|
||||
ConfigOverrides::default(),
|
||||
codex_home.path().to_path_buf(),
|
||||
)
|
||||
.expect("load default test config");
|
||||
config.features.enable(Feature::RemoteModels);
|
||||
let auth_manager =
|
||||
AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key"));
|
||||
let provider = provider_for(server.uri());
|
||||
let mut manager = ModelsManager::with_provider(auth_manager, provider);
|
||||
manager.cache_ttl = Duration::ZERO;
|
||||
|
||||
manager
|
||||
.refresh_available_models(&config)
|
||||
.await
|
||||
.expect("initial refresh succeeds");
|
||||
|
||||
server.reset().await;
|
||||
let refreshed_models = vec![remote_model("remote-new", "Remote New", 1)];
|
||||
let refreshed_mock = mount_models_once(
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: refreshed_models,
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
manager
|
||||
.refresh_available_models(&config)
|
||||
.await
|
||||
.expect("second refresh succeeds");
|
||||
|
||||
let available = manager
|
||||
.try_list_models()
|
||||
.expect("models should be available");
|
||||
assert!(
|
||||
available.iter().any(|preset| preset.model == "remote-new"),
|
||||
"new remote model should be listed"
|
||||
);
|
||||
assert!(
|
||||
!available.iter().any(|preset| preset.model == "remote-old"),
|
||||
"removed remote model should not be listed"
|
||||
);
|
||||
assert_eq!(
|
||||
initial_mock.requests().len(),
|
||||
1,
|
||||
"initial refresh should only hit /models once"
|
||||
);
|
||||
assert_eq!(
|
||||
refreshed_mock.requests().len(),
|
||||
1,
|
||||
"second refresh should only hit /models once"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_available_models_picks_default_after_hiding_hidden_models() {
|
||||
let auth_manager =
|
||||
AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key"));
|
||||
let provider = provider_for("http://example.test".to_string());
|
||||
let mut manager = ModelsManager::with_provider(auth_manager, provider);
|
||||
manager.local_models = Vec::new();
|
||||
|
||||
let hidden_model = remote_model_with_visibility("hidden", "Hidden", 0, "hide");
|
||||
let visible_model = remote_model_with_visibility("visible", "Visible", 1, "list");
|
||||
|
||||
let mut expected = ModelPreset::from(visible_model.clone());
|
||||
expected.is_default = true;
|
||||
|
||||
let available = manager.build_available_models(vec![hidden_model, visible_model]);
|
||||
|
||||
assert_eq!(available, vec![expected]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,6 +41,7 @@ use core_test_support::skip_if_no_network;
|
||||
use core_test_support::skip_if_sandbox;
|
||||
use core_test_support::wait_for_event;
|
||||
use core_test_support::wait_for_event_match;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::Duration;
|
||||
@@ -298,6 +299,108 @@ async fn remote_models_apply_remote_base_instructions() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn remote_models_preserve_builtin_presets() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
skip_if_sandbox!(Ok(()));
|
||||
|
||||
let server = MockServer::start().await;
|
||||
let remote_model = test_remote_model("remote-alpha", ModelVisibility::List, 0);
|
||||
let models_mock = mount_models_once(
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: vec![remote_model.clone()],
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
let codex_home = TempDir::new()?;
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.features.enable(Feature::RemoteModels);
|
||||
|
||||
let auth = CodexAuth::from_api_key("dummy");
|
||||
let provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
..built_in_model_providers()["openai"].clone()
|
||||
};
|
||||
let manager = ModelsManager::with_provider(
|
||||
codex_core::auth::AuthManager::from_auth_for_testing(auth),
|
||||
provider,
|
||||
);
|
||||
|
||||
manager
|
||||
.refresh_available_models(&config)
|
||||
.await
|
||||
.expect("refresh succeeds");
|
||||
|
||||
let available = manager.list_models(&config).await;
|
||||
let remote = available
|
||||
.iter()
|
||||
.find(|model| model.model == "remote-alpha")
|
||||
.expect("remote model should be listed");
|
||||
let mut expected_remote: ModelPreset = remote_model.into();
|
||||
expected_remote.is_default = true;
|
||||
assert_eq!(*remote, expected_remote);
|
||||
assert!(
|
||||
available
|
||||
.iter()
|
||||
.any(|model| model.model == "gpt-5.1-codex-max"),
|
||||
"builtin presets should remain available after refresh"
|
||||
);
|
||||
assert_eq!(
|
||||
models_mock.requests().len(),
|
||||
1,
|
||||
"expected a single /models request"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn remote_models_hide_picker_only_models() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
skip_if_sandbox!(Ok(()));
|
||||
|
||||
let server = MockServer::start().await;
|
||||
let remote_model = test_remote_model("codex-auto-balanced", ModelVisibility::Hide, 0);
|
||||
mount_models_once(
|
||||
&server,
|
||||
ModelsResponse {
|
||||
models: vec![remote_model],
|
||||
etag: String::new(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
let codex_home = TempDir::new()?;
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.features.enable(Feature::RemoteModels);
|
||||
|
||||
let auth = CodexAuth::create_dummy_chatgpt_auth_for_testing();
|
||||
let provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
..built_in_model_providers()["openai"].clone()
|
||||
};
|
||||
let manager = ModelsManager::with_provider(
|
||||
codex_core::auth::AuthManager::from_auth_for_testing(auth),
|
||||
provider,
|
||||
);
|
||||
|
||||
let selected = manager.get_model(&None, &config).await;
|
||||
assert_eq!(selected, "gpt-5.1-codex-max");
|
||||
|
||||
let available = manager.list_models(&config).await;
|
||||
assert!(
|
||||
available
|
||||
.iter()
|
||||
.all(|model| model.model != "codex-auto-balanced"),
|
||||
"hidden models should not appear in the picker list"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn wait_for_model_available(
|
||||
manager: &Arc<ModelsManager>,
|
||||
slug: &str,
|
||||
@@ -362,3 +465,32 @@ where
|
||||
conversation_manager,
|
||||
})
|
||||
}
|
||||
|
||||
fn test_remote_model(slug: &str, visibility: ModelVisibility, priority: i32) -> ModelInfo {
|
||||
ModelInfo {
|
||||
slug: slug.to_string(),
|
||||
display_name: format!("{slug} display"),
|
||||
description: Some(format!("{slug} description")),
|
||||
default_reasoning_level: ReasoningEffort::Medium,
|
||||
supported_reasoning_levels: vec![ReasoningEffortPreset {
|
||||
effort: ReasoningEffort::Medium,
|
||||
description: ReasoningEffort::Medium.to_string(),
|
||||
}],
|
||||
shell_type: ConfigShellToolType::ShellCommand,
|
||||
visibility,
|
||||
minimal_client_version: ClientVersion(0, 1, 0),
|
||||
supported_in_api: true,
|
||||
priority,
|
||||
upgrade: None,
|
||||
base_instructions: None,
|
||||
supports_reasoning_summaries: false,
|
||||
support_verbosity: false,
|
||||
default_verbosity: None,
|
||||
apply_patch_tool_type: None,
|
||||
truncation_policy: TruncationPolicyConfig::bytes(10_000),
|
||||
supports_parallel_tool_calls: false,
|
||||
context_window: None,
|
||||
reasoning_summary_format: ReasoningSummaryFormat::None,
|
||||
experimental_supported_tools: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user