Strip unsupported images from prompt history to guard against model switch (#11349)

- Make `ContextManager::for_prompt` modality-aware and strip input_image
content when the active model is text-only.
- Add a test for a multi-modal -> text-only model switch
This commit is contained in:
Ahmed Ibrahim
2026-02-10 11:58:00 -08:00
committed by GitHub
parent 82f93a13b2
commit 5e01450963
7 changed files with 403 additions and 29 deletions

View File

@@ -4006,7 +4006,11 @@ pub(crate) async fn run_turn(
}
// Construct the input that we will send to the model.
let sampling_request_input: Vec<ResponseItem> = { sess.clone_history().await.for_prompt() };
let sampling_request_input: Vec<ResponseItem> = {
sess.clone_history()
.await
.for_prompt(&turn_context.model_info.input_modalities)
};
let sampling_request_input_messages = sampling_request_input
.iter()
@@ -6936,7 +6940,9 @@ mod tests {
rollout_items.push(RolloutItem::ResponseItem(assistant1.clone()));
let summary1 = "summary one";
let snapshot1 = live_history.clone().for_prompt();
let snapshot1 = live_history
.clone()
.for_prompt(&reconstruction_turn.model_info.input_modalities);
let user_messages1 = collect_user_messages(&snapshot1);
let rebuilt1 =
compact::build_compacted_history(initial_context.clone(), &user_messages1, summary1);
@@ -6977,7 +6983,9 @@ mod tests {
rollout_items.push(RolloutItem::ResponseItem(assistant2.clone()));
let summary2 = "summary two";
let snapshot2 = live_history.clone().for_prompt();
let snapshot2 = live_history
.clone()
.for_prompt(&reconstruction_turn.model_info.input_modalities);
let user_messages2 = collect_user_messages(&snapshot2);
let rebuilt2 =
compact::build_compacted_history(initial_context.clone(), &user_messages2, summary2);
@@ -7017,7 +7025,10 @@ mod tests {
);
rollout_items.push(RolloutItem::ResponseItem(assistant3));
(rollout_items, live_history.for_prompt())
(
rollout_items,
live_history.for_prompt(&reconstruction_turn.model_info.input_modalities),
)
}
#[tokio::test]

View File

@@ -112,7 +112,9 @@ async fn run_compact_task_inner(
loop {
// Clone is required because of the loop
let turn_input = history.clone().for_prompt();
let turn_input = history
.clone()
.for_prompt(&turn_context.model_info.input_modalities);
let turn_input_len = turn_input.len();
let prompt = Prompt {
input: turn_input,

View File

@@ -87,7 +87,7 @@ async fn run_remote_compact_task_inner_impl(
.collect();
let prompt = Prompt {
input: history.for_prompt(),
input: history.for_prompt(&turn_context.model_info.input_modalities),
tools: vec![],
parallel_tool_calls: false,
base_instructions,

View File

@@ -15,6 +15,7 @@ use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::FunctionCallOutputContentItem;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseItem;
use codex_protocol::openai_models::InputModality;
use codex_protocol::protocol::TokenUsage;
use codex_protocol::protocol::TokenUsageInfo;
use std::ops::Deref;
@@ -79,9 +80,11 @@ impl ContextManager {
}
/// Returns the history prepared for sending to the model. This applies a proper
/// normalization and drop un-suited items.
pub(crate) fn for_prompt(mut self) -> Vec<ResponseItem> {
self.normalize_history();
/// normalization and drops un-suited items. When `input_modalities` does not
/// include `InputModality::Image`, images are stripped from messages and tool
/// outputs.
pub(crate) fn for_prompt(mut self, input_modalities: &[InputModality]) -> Vec<ResponseItem> {
self.normalize_history(input_modalities);
self.items
.retain(|item| !matches!(item, ResponseItem::GhostSnapshot { .. }));
self.items
@@ -309,12 +312,16 @@ impl ContextManager {
/// This function enforces a couple of invariants on the in-memory history:
/// 1. every call (function/custom) has a corresponding output entry
/// 2. every output has a corresponding call entry
fn normalize_history(&mut self) {
/// 3. when images are unsupported, image content is stripped from messages and tool outputs
fn normalize_history(&mut self, input_modalities: &[InputModality]) {
// all function/tool calls must have a corresponding output
normalize::ensure_call_outputs_present(&mut self.items);
// all outputs must have a corresponding function/tool call
normalize::remove_orphan_outputs(&mut self.items);
// strip images when model does not support them
normalize::strip_images_when_unsupported(input_modalities, &mut self.items);
}
fn process_item(&self, item: &ResponseItem, policy: TruncationPolicy) -> ResponseItem {

View File

@@ -12,6 +12,8 @@ use codex_protocol::models::LocalShellExecAction;
use codex_protocol::models::LocalShellStatus;
use codex_protocol::models::ReasoningItemContent;
use codex_protocol::models::ReasoningItemReasoningSummary;
use codex_protocol::openai_models::InputModality;
use codex_protocol::openai_models::default_input_modalities;
use pretty_assertions::assert_eq;
use regex_lite::Regex;
@@ -240,13 +242,122 @@ fn total_token_usage_includes_all_items_after_last_model_generated_item() {
);
}
#[test]
fn for_prompt_strips_images_when_model_does_not_support_images() {
let items = vec![
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![
ContentItem::InputText {
text: "look at this".to_string(),
},
ContentItem::InputImage {
image_url: "https://example.com/img.png".to_string(),
},
ContentItem::InputText {
text: "caption".to_string(),
},
],
end_turn: None,
phase: None,
},
ResponseItem::FunctionCall {
id: None,
name: "view_image".to_string(),
arguments: "{}".to_string(),
call_id: "call-1".to_string(),
},
ResponseItem::FunctionCallOutput {
call_id: "call-1".to_string(),
output: FunctionCallOutputPayload::from_content_items(vec![
FunctionCallOutputContentItem::InputText {
text: "image result".to_string(),
},
FunctionCallOutputContentItem::InputImage {
image_url: "https://example.com/result.png".to_string(),
},
]),
},
];
let history = create_history_with_items(items);
let text_only_modalities = vec![InputModality::Text];
let stripped = history.for_prompt(&text_only_modalities);
let expected = vec![
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![
ContentItem::InputText {
text: "look at this".to_string(),
},
ContentItem::InputText {
text: "image content omitted because you do not support image input"
.to_string(),
},
ContentItem::InputText {
text: "caption".to_string(),
},
],
end_turn: None,
phase: None,
},
ResponseItem::FunctionCall {
id: None,
name: "view_image".to_string(),
arguments: "{}".to_string(),
call_id: "call-1".to_string(),
},
ResponseItem::FunctionCallOutput {
call_id: "call-1".to_string(),
output: FunctionCallOutputPayload::from_content_items(vec![
FunctionCallOutputContentItem::InputText {
text: "image result".to_string(),
},
FunctionCallOutputContentItem::InputText {
text: "image content omitted because you do not support image input"
.to_string(),
},
]),
},
];
assert_eq!(stripped, expected);
// With image support, images are preserved
let modalities = default_input_modalities();
let with_images = create_history_with_items(vec![ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![
ContentItem::InputText {
text: "look".to_string(),
},
ContentItem::InputImage {
image_url: "https://example.com/img.png".to_string(),
},
],
end_turn: None,
phase: None,
}]);
let preserved = with_images.for_prompt(&modalities);
assert_eq!(preserved.len(), 1);
if let ResponseItem::Message { content, .. } = &preserved[0] {
assert_eq!(content.len(), 2);
assert!(matches!(content[1], ContentItem::InputImage { .. }));
} else {
panic!("expected Message");
}
}
#[test]
fn get_history_for_prompt_drops_ghost_commits() {
let items = vec![ResponseItem::GhostSnapshot {
ghost_commit: GhostCommit::new("ghost-1".to_string(), None, Vec::new(), Vec::new()),
}];
let history = create_history_with_items(items);
let filtered = history.for_prompt();
let modalities = default_input_modalities();
let filtered = history.for_prompt(&modalities);
assert_eq!(filtered, vec![]);
}
@@ -422,10 +533,11 @@ fn drop_last_n_user_turns_preserves_prefix() {
assistant_msg("a2"),
];
let modalities = default_input_modalities();
let mut history = create_history_with_items(items);
history.drop_last_n_user_turns(1);
assert_eq!(
history.for_prompt(),
history.for_prompt(&modalities),
vec![
assistant_msg("session prefix item"),
user_msg("u1"),
@@ -442,7 +554,7 @@ fn drop_last_n_user_turns_preserves_prefix() {
]);
history.drop_last_n_user_turns(99);
assert_eq!(
history.for_prompt(),
history.for_prompt(&modalities),
vec![assistant_msg("session prefix item")]
);
}
@@ -465,6 +577,7 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() {
assistant_msg("turn 2 assistant"),
];
let modalities = default_input_modalities();
let mut history = create_history_with_items(items);
history.drop_last_n_user_turns(1);
@@ -482,7 +595,10 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() {
assistant_msg("turn 1 assistant"),
];
assert_eq!(history.for_prompt(), expected_prefix_and_first_turn);
assert_eq!(
history.for_prompt(&modalities),
expected_prefix_and_first_turn
);
let expected_prefix_only = vec![
user_input_text_msg("<environment_context>ctx</environment_context>"),
@@ -512,7 +628,7 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() {
assistant_msg("turn 2 assistant"),
]);
history.drop_last_n_user_turns(2);
assert_eq!(history.for_prompt(), expected_prefix_only);
assert_eq!(history.for_prompt(&modalities), expected_prefix_only);
let mut history = create_history_with_items(vec![
user_input_text_msg("<environment_context>ctx</environment_context>"),
@@ -530,7 +646,7 @@ fn drop_last_n_user_turns_ignores_session_prefix_user_messages() {
assistant_msg("turn 2 assistant"),
]);
history.drop_last_n_user_turns(3);
assert_eq!(history.for_prompt(), expected_prefix_only);
assert_eq!(history.for_prompt(&modalities), expected_prefix_only);
}
#[test]
@@ -574,8 +690,9 @@ fn normalization_retains_local_shell_outputs() {
},
];
let modalities = default_input_modalities();
let history = create_history_with_items(items.clone());
let normalized = history.for_prompt();
let normalized = history.for_prompt(&modalities);
assert_eq!(normalized, items);
}
@@ -777,7 +894,7 @@ fn normalize_adds_missing_output_for_function_call() {
}];
let mut h = create_history_with_items(items);
h.normalize_history();
h.normalize_history(&default_input_modalities());
assert_eq!(
h.raw_items(),
@@ -808,7 +925,7 @@ fn normalize_adds_missing_output_for_custom_tool_call() {
}];
let mut h = create_history_with_items(items);
h.normalize_history();
h.normalize_history(&default_input_modalities());
assert_eq!(
h.raw_items(),
@@ -845,7 +962,7 @@ fn normalize_adds_missing_output_for_local_shell_call_with_id() {
}];
let mut h = create_history_with_items(items);
h.normalize_history();
h.normalize_history(&default_input_modalities());
assert_eq!(
h.raw_items(),
@@ -879,7 +996,7 @@ fn normalize_removes_orphan_function_call_output() {
}];
let mut h = create_history_with_items(items);
h.normalize_history();
h.normalize_history(&default_input_modalities());
assert_eq!(h.raw_items(), vec![]);
}
@@ -893,7 +1010,7 @@ fn normalize_removes_orphan_custom_tool_call_output() {
}];
let mut h = create_history_with_items(items);
h.normalize_history();
h.normalize_history(&default_input_modalities());
assert_eq!(h.raw_items(), vec![]);
}
@@ -938,7 +1055,7 @@ fn normalize_mixed_inserts_and_removals() {
];
let mut h = create_history_with_items(items);
h.normalize_history();
h.normalize_history(&default_input_modalities());
assert_eq!(
h.raw_items(),
@@ -993,7 +1110,7 @@ fn normalize_adds_missing_output_for_function_call_inserts_output() {
call_id: "call-x".to_string(),
}];
let mut h = create_history_with_items(items);
h.normalize_history();
h.normalize_history(&default_input_modalities());
assert_eq!(
h.raw_items(),
vec![
@@ -1023,7 +1140,7 @@ fn normalize_adds_missing_output_for_custom_tool_call_panics_in_debug() {
input: "{}".to_string(),
}];
let mut h = create_history_with_items(items);
h.normalize_history();
h.normalize_history(&default_input_modalities());
}
#[cfg(debug_assertions)]
@@ -1043,7 +1160,7 @@ fn normalize_adds_missing_output_for_local_shell_call_with_id_panics_in_debug()
}),
}];
let mut h = create_history_with_items(items);
h.normalize_history();
h.normalize_history(&default_input_modalities());
}
#[cfg(debug_assertions)]
@@ -1055,7 +1172,7 @@ fn normalize_removes_orphan_function_call_output_panics_in_debug() {
output: FunctionCallOutputPayload::from_text("ok".to_string()),
}];
let mut h = create_history_with_items(items);
h.normalize_history();
h.normalize_history(&default_input_modalities());
}
#[cfg(debug_assertions)]
@@ -1067,7 +1184,7 @@ fn normalize_removes_orphan_custom_tool_call_output_panics_in_debug() {
output: "ok".to_string(),
}];
let mut h = create_history_with_items(items);
h.normalize_history();
h.normalize_history(&default_input_modalities());
}
#[cfg(debug_assertions)]
@@ -1106,5 +1223,5 @@ fn normalize_mixed_inserts_and_removals_panics_in_debug() {
},
];
let mut h = create_history_with_items(items);
h.normalize_history();
h.normalize_history(&default_input_modalities());
}

View File

@@ -1,12 +1,18 @@
use std::collections::HashSet;
use codex_protocol::models::ContentItem;
use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::FunctionCallOutputContentItem;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseItem;
use codex_protocol::openai_models::InputModality;
use crate::util::error_or_panic;
use tracing::info;
const IMAGE_CONTENT_OMITTED_PLACEHOLDER: &str =
"image content omitted because you do not support image input";
pub(crate) fn ensure_call_outputs_present(items: &mut Vec<ResponseItem>) {
// Collect synthetic outputs to insert immediately after their calls.
// Store the insertion position (index of call) alongside the item so
@@ -211,3 +217,53 @@ where
items.remove(pos);
}
}
/// Strip image content from messages and tool outputs when the model does not support images.
/// When `input_modalities` contains `InputModality::Image`, no stripping is performed.
pub(crate) fn strip_images_when_unsupported(
input_modalities: &[InputModality],
items: &mut [ResponseItem],
) {
let supports_images = input_modalities.contains(&InputModality::Image);
if supports_images {
return;
}
for item in items.iter_mut() {
match item {
ResponseItem::Message { content, .. } => {
let mut normalized_content = Vec::with_capacity(content.len());
for content_item in content.iter() {
match content_item {
ContentItem::InputImage { .. } => {
normalized_content.push(ContentItem::InputText {
text: IMAGE_CONTENT_OMITTED_PLACEHOLDER.to_string(),
});
}
_ => normalized_content.push(content_item.clone()),
}
}
*content = normalized_content;
}
ResponseItem::FunctionCallOutput { output, .. } => {
if let Some(content_items) = output.content_items_mut() {
let mut normalized_content_items = Vec::with_capacity(content_items.len());
for content_item in content_items.iter() {
match content_item {
FunctionCallOutputContentItem::InputImage { .. } => {
normalized_content_items.push(
FunctionCallOutputContentItem::InputText {
text: IMAGE_CONTENT_OMITTED_PLACEHOLDER.to_string(),
},
);
}
_ => normalized_content_items.push(content_item.clone()),
}
}
*content_items = normalized_content_items;
}
}
_ => {}
}
}
}

View File

@@ -1,12 +1,24 @@
use anyhow::Result;
use codex_core::CodexAuth;
use codex_core::config::types::Personality;
use codex_core::features::Feature;
use codex_core::models_manager::manager::RefreshStrategy;
use codex_core::protocol::AskForApproval;
use codex_core::protocol::EventMsg;
use codex_core::protocol::Op;
use codex_core::protocol::SandboxPolicy;
use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::openai_models::ConfigShellToolType;
use codex_protocol::openai_models::InputModality;
use codex_protocol::openai_models::ModelInfo;
use codex_protocol::openai_models::ModelVisibility;
use codex_protocol::openai_models::ModelsResponse;
use codex_protocol::openai_models::ReasoningEffort;
use codex_protocol::openai_models::ReasoningEffortPreset;
use codex_protocol::openai_models::TruncationPolicyConfig;
use codex_protocol::openai_models::default_input_modalities;
use codex_protocol::user_input::UserInput;
use core_test_support::responses::mount_models_once;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse_completed;
use core_test_support::responses::start_mock_server;
@@ -14,6 +26,8 @@ use core_test_support::skip_if_no_network;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use pretty_assertions::assert_eq;
use serde_json::Value;
use wiremock::MockServer;
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn model_change_appends_model_instructions_developer_message() -> Result<()> {
@@ -190,3 +204,170 @@ async fn model_and_personality_change_only_appends_model_instructions() -> Resul
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn model_change_from_image_to_text_strips_prior_image_content() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = MockServer::start().await;
let image_model_slug = "test-image-model";
let text_model_slug = "test-text-only-model";
let image_model = ModelInfo {
slug: image_model_slug.to_string(),
display_name: "Test Image Model".to_string(),
description: Some("supports image input".to_string()),
default_reasoning_level: Some(ReasoningEffort::Medium),
supported_reasoning_levels: vec![ReasoningEffortPreset {
effort: ReasoningEffort::Medium,
description: ReasoningEffort::Medium.to_string(),
}],
shell_type: ConfigShellToolType::ShellCommand,
visibility: ModelVisibility::List,
supported_in_api: true,
input_modalities: default_input_modalities(),
priority: 1,
upgrade: None,
base_instructions: "base instructions".to_string(),
model_messages: None,
supports_reasoning_summaries: false,
support_verbosity: false,
default_verbosity: None,
apply_patch_tool_type: None,
truncation_policy: TruncationPolicyConfig::bytes(10_000),
supports_parallel_tool_calls: false,
context_window: Some(272_000),
auto_compact_token_limit: None,
effective_context_window_percent: 95,
experimental_supported_tools: Vec::new(),
};
let mut text_model = image_model.clone();
text_model.slug = text_model_slug.to_string();
text_model.display_name = "Test Text Model".to_string();
text_model.description = Some("text only".to_string());
text_model.input_modalities = vec![InputModality::Text];
mount_models_once(
&server,
ModelsResponse {
models: vec![image_model, text_model],
},
)
.await;
let responses = mount_sse_sequence(
&server,
vec![sse_completed("resp-1"), sse_completed("resp-2")],
)
.await;
let mut builder = test_codex()
.with_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing())
.with_config(move |config| {
config.features.enable(Feature::RemoteModels);
config.model = Some(image_model_slug.to_string());
});
let test = builder.build(&server).await?;
let models_manager = test.thread_manager.get_models_manager();
let _ = models_manager
.list_models(&test.config, RefreshStrategy::OnlineIfUncached)
.await;
let image_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGNgYAAAAAMAASsJTYQAAAAASUVORK5CYII="
.to_string();
test.codex
.submit(Op::UserTurn {
items: vec![
UserInput::Image {
image_url: image_url.clone(),
},
UserInput::Text {
text: "first turn".to_string(),
text_elements: Vec::new(),
},
],
final_output_json_schema: None,
cwd: test.cwd_path().to_path_buf(),
approval_policy: AskForApproval::Never,
sandbox_policy: SandboxPolicy::ReadOnly,
model: image_model_slug.to_string(),
effort: test.config.model_reasoning_effort,
summary: ReasoningSummary::Auto,
collaboration_mode: None,
personality: None,
})
.await?;
wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await;
test.codex
.submit(Op::UserTurn {
items: vec![UserInput::Text {
text: "second turn".to_string(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
cwd: test.cwd_path().to_path_buf(),
approval_policy: AskForApproval::Never,
sandbox_policy: SandboxPolicy::ReadOnly,
model: text_model_slug.to_string(),
effort: test.config.model_reasoning_effort,
summary: ReasoningSummary::Auto,
collaboration_mode: None,
personality: None,
})
.await?;
wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await;
let requests = responses.requests();
assert_eq!(requests.len(), 2, "expected two model requests");
let first_request = requests.first().expect("expected first request");
let first_has_input_image = first_request.inputs_of_type("message").iter().any(|item| {
item.get("content")
.and_then(Value::as_array)
.is_some_and(|content| {
content
.iter()
.any(|span| span.get("type").and_then(Value::as_str) == Some("input_image"))
})
});
assert!(
first_has_input_image,
"first request should include the uploaded image"
);
let second_request = requests.last().expect("expected second request");
let second_has_input_image = second_request.inputs_of_type("message").iter().any(|item| {
item.get("content")
.and_then(Value::as_array)
.is_some_and(|content| {
content
.iter()
.any(|span| span.get("type").and_then(Value::as_str) == Some("input_image"))
})
});
assert!(
!second_has_input_image,
"second request should strip unsupported image content"
);
let second_user_texts = second_request.message_input_texts("user");
assert!(
second_user_texts
.iter()
.any(|text| text == "image content omitted because you do not support image input"),
"second request should include the image-omitted placeholder text"
);
assert!(
second_user_texts
.iter()
.any(|text| text == &codex_protocol::models::image_open_tag_text()),
"second request should preserve the image open tag text"
);
assert!(
second_user_texts
.iter()
.any(|text| text == &codex_protocol::models::image_close_tag_text()),
"second request should preserve the image close tag text"
);
Ok(())
}