Compare commits

...

2 Commits

Author SHA1 Message Date
Dylan Hurd
b9635bd082 support websockets 2026-03-18 00:28:30 -07:00
Dylan Hurd
77c14132c9 feat(core) serialize response_item.id 2026-03-18 00:20:21 -07:00
15 changed files with 562 additions and 34 deletions

View File

@@ -70,6 +70,19 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
&self,
request: ResponsesApiRequest,
options: ResponsesOptions,
) -> Result<ResponseStream, ApiError> {
let mut body = serde_json::to_value(&request)
.map_err(|e| ApiError::Stream(format!("failed to encode responses request: {e}")))?;
if request.store && self.session.provider().is_azure_responses_endpoint() {
attach_item_ids(&mut body, &request.input);
}
self.stream_request_with_body(body, options).await
}
pub async fn stream_request_with_body(
&self,
body: Value,
options: ResponsesOptions,
) -> Result<ResponseStream, ApiError> {
let ResponsesOptions {
conversation_id,
@@ -79,12 +92,6 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
turn_state,
} = options;
let mut body = serde_json::to_value(&request)
.map_err(|e| ApiError::Stream(format!("failed to encode responses request: {e}")))?;
if request.store && self.session.provider().is_azure_responses_endpoint() {
attach_item_ids(&mut body, &request.input);
}
let mut headers = extra_headers;
if let Some(ref conv_id) = conversation_id {
insert_header(&mut headers, "x-client-request-id", conv_id);

View File

@@ -215,6 +215,18 @@ impl ResponsesWebsocketConnection {
&self,
request: ResponsesWsRequest,
connection_reused: bool,
) -> Result<ResponseStream, ApiError> {
let request_body = serde_json::to_value(&request).map_err(|err| {
ApiError::Stream(format!("failed to encode websocket request: {err}"))
})?;
self.stream_request_with_body(request_body, connection_reused)
.await
}
pub async fn stream_request_with_body(
&self,
request_body: Value,
connection_reused: bool,
) -> Result<ResponseStream, ApiError> {
let (tx_event, rx_event) =
mpsc::channel::<std::result::Result<ResponseEvent, ApiError>>(1600);
@@ -224,9 +236,6 @@ impl ResponsesWebsocketConnection {
let models_etag = self.models_etag.clone();
let server_model = self.server_model.clone();
let telemetry = self.telemetry.clone();
let request_body = serde_json::to_value(&request).map_err(|err| {
ApiError::Stream(format!("failed to encode websocket request: {err}"))
})?;
let current_span = Span::current();
tokio::spawn(

View File

@@ -437,6 +437,9 @@
"realtime_conversation": {
"type": "boolean"
},
"record_response_item_id": {
"type": "boolean"
},
"remote_models": {
"type": "boolean"
},
@@ -1995,6 +1998,9 @@
"realtime_conversation": {
"type": "boolean"
},
"record_response_item_id": {
"type": "boolean"
},
"remote_models": {
"type": "boolean"
},
@@ -2479,4 +2485,4 @@
},
"title": "ConfigToml",
"type": "object"
}
}

View File

@@ -104,6 +104,9 @@ use crate::response_debug_context::extract_response_debug_context;
use crate::response_debug_context::extract_response_debug_context_from_api_error;
use crate::response_debug_context::telemetry_api_error_message;
use crate::response_debug_context::telemetry_transport_error_message;
use crate::response_item_id_serde::ResponseItemIdSerialization;
use crate::response_item_id_serde::serialize_responses_request_body;
use crate::response_item_id_serde::serialize_responses_ws_request_body;
use crate::tools::spec::create_tools_json_for_responses_api;
use crate::util::FeedbackRequestTags;
use crate::util::emit_feedback_auth_recovery_tags;
@@ -137,6 +140,7 @@ struct ModelClientState {
enable_request_compression: bool,
include_timing_metrics: bool,
beta_features_header: Option<String>,
response_item_ids: ResponseItemIdSerialization,
disable_websockets: AtomicBool,
cached_websocket_session: StdMutex<WebsocketSession>,
}
@@ -178,6 +182,22 @@ pub struct ModelClient {
state: Arc<ModelClientState>,
}
/// Public toggle for whether provider-assigned `ResponseItem` ids are
/// serialized into outgoing request payloads built by the model client.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ModelClientResponseItemIds {
    /// Do not attach ids (default).
    #[default]
    Disabled,
    /// Attach ids to serialized request bodies.
    Enabled,
}

/// Bridges the public toggle into the crate-private serialization switch so
/// the internal enum does not leak through the public API.
impl From<ModelClientResponseItemIds> for ResponseItemIdSerialization {
    fn from(value: ModelClientResponseItemIds) -> Self {
        match value {
            ModelClientResponseItemIds::Disabled => ResponseItemIdSerialization::Disabled,
            ModelClientResponseItemIds::Enabled => ResponseItemIdSerialization::Enabled,
        }
    }
}
/// A turn-scoped streaming session created from a [`ModelClient`].
///
/// The session establishes a Responses WebSocket connection lazily and reuses it across multiple
@@ -257,6 +277,7 @@ impl ModelClient {
enable_request_compression: bool,
include_timing_metrics: bool,
beta_features_header: Option<String>,
response_item_ids: ModelClientResponseItemIds,
) -> Self {
let codex_api_key_env_enabled = auth_manager
.as_ref()
@@ -273,6 +294,7 @@ impl ModelClient {
enable_request_compression,
include_timing_metrics,
beta_features_header,
response_item_ids: response_item_ids.into(),
disable_websockets: AtomicBool::new(false),
cached_websocket_session: StdMutex::new(WebsocketSession::default()),
}),
@@ -1049,7 +1071,18 @@ impl ModelClientSession {
client_setup.api_auth,
)
.with_telemetry(Some(request_telemetry), Some(sse_telemetry));
let stream_result = client.stream_request(request, options).await;
let stream_result = if self.client.state.response_item_ids.is_enabled() {
let request_body =
serialize_responses_request_body(&request, self.client.state.response_item_ids)
.map_err(|err| {
map_api_error(ApiError::Stream(format!(
"failed to encode responses request: {err}"
)))
})?;
client.stream_request_with_body(request_body, options).await
} else {
client.stream_request(request, options).await
};
match stream_result {
Ok(stream) => {
@@ -1170,15 +1203,34 @@ impl ModelClientSession {
let ws_request = self.prepare_websocket_request(ws_payload, &request);
self.websocket_session.last_request = Some(request);
let stream_result = self.websocket_session.connection.as_ref().ok_or_else(|| {
let connection = self.websocket_session.connection.as_ref().ok_or_else(|| {
map_api_error(ApiError::Stream(
"websocket connection is unavailable".to_string(),
))
})?;
let stream_result = stream_result
.stream_request(ws_request, self.websocket_session.connection_reused())
.await
.map_err(map_api_error)?;
let stream_result = if self.client.state.response_item_ids.is_enabled() {
let request_body = serialize_responses_ws_request_body(
&ws_request,
self.client.state.response_item_ids,
)
.map_err(|err| {
map_api_error(ApiError::Stream(format!(
"failed to encode websocket request: {err}"
)))
})?;
connection
.stream_request_with_body(
request_body,
self.websocket_session.connection_reused(),
)
.await
.map_err(map_api_error)?
} else {
connection
.stream_request(ws_request, self.websocket_session.connection_reused())
.await
.map_err(map_api_error)?
};
let (stream, last_request_rx) =
map_response_stream(stream_result, session_telemetry.clone());
self.websocket_session.last_response_rx = Some(last_request_rx);

View File

@@ -1,5 +1,6 @@
use super::AuthRequestTelemetryContext;
use super::ModelClient;
use super::ModelClientResponseItemIds;
use super::PendingUnauthorizedRetry;
use super::UnauthorizedRecoveryExecution;
use codex_otel::SessionTelemetry;
@@ -24,6 +25,7 @@ fn test_model_client(session_source: SessionSource) -> ModelClient {
false,
false,
None,
ModelClientResponseItemIds::Disabled,
)
}

View File

@@ -154,6 +154,7 @@ use uuid::Uuid;
use crate::ModelProviderInfo;
use crate::client::ModelClient;
use crate::client::ModelClientResponseItemIds;
use crate::client::ModelClientSession;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
@@ -1815,6 +1816,11 @@ impl Session {
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
Self::build_model_client_beta_features_header(config.as_ref()),
if config.features.enabled(Feature::RecordResponseItemId) {
ModelClientResponseItemIds::Enabled
} else {
ModelClientResponseItemIds::Disabled
},
),
code_mode_service: crate::tools::code_mode::CodeModeService::new(
config.js_repl_node_path.clone(),

View File

@@ -240,6 +240,7 @@ fn test_model_client_session() -> crate::client::ModelClientSession {
false,
false,
None,
crate::client::ModelClientResponseItemIds::Disabled,
)
.new_session()
}
@@ -2515,6 +2516,7 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) {
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
Session::build_model_client_beta_features_header(config.as_ref()),
crate::client::ModelClientResponseItemIds::Disabled,
),
code_mode_service: crate::tools::code_mode::CodeModeService::new(
config.js_repl_node_path.clone(),
@@ -3309,6 +3311,7 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx(
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
Session::build_model_client_beta_features_header(config.as_ref()),
crate::client::ModelClientResponseItemIds::Disabled,
),
code_mode_service: crate::tools::code_mode::CodeModeService::new(
config.js_repl_node_path.clone(),

View File

@@ -182,6 +182,8 @@ pub enum Feature {
RealtimeConversation,
/// Route interactive startup to the app-server-backed TUI implementation.
TuiAppServer,
/// Persist and resend provider ResponseItem ids in internal rollout/client payloads.
RecordResponseItemId,
/// Prevent idle system sleep while a turn is actively running.
PreventIdleSleep,
/// Legacy rollout flag for Responses API WebSocket transport experiments.
@@ -839,6 +841,12 @@ pub const FEATURES: &[FeatureSpec] = &[
},
default_enabled: false,
},
FeatureSpec {
id: Feature::RecordResponseItemId,
key: "record_response_item_id",
stage: Stage::UnderDevelopment,
default_enabled: false,
},
FeatureSpec {
id: Feature::PreventIdleSleep,
key: "prevent_idle_sleep",

View File

@@ -80,6 +80,7 @@ pub mod token_data;
mod truncate;
mod unified_exec;
pub mod windows_sandbox;
pub use client::ModelClientResponseItemIds;
pub use client::X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER;
pub use model_provider_info::DEFAULT_LMSTUDIO_PORT;
pub use model_provider_info::DEFAULT_OLLAMA_PORT;
@@ -92,6 +93,7 @@ pub use model_provider_info::built_in_model_providers;
pub use model_provider_info::create_oss_provider_with_base_url;
mod event_mapping;
mod response_debug_context;
mod response_item_id_serde;
pub mod review_format;
pub mod review_prompts;
mod seatbelt_permissions;

View File

@@ -0,0 +1,229 @@
use codex_api::ResponsesApiRequest;
use codex_api::common::ResponsesWsRequest;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::CompactedItem;
use codex_protocol::protocol::RolloutItem;
use serde::Serialize;
use serde_json::Value;
/// Crate-internal switch controlling whether provider-assigned response item
/// ids are patched back into serialized JSON payloads.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub(crate) enum ResponseItemIdSerialization {
    /// Leave payloads exactly as serde produces them (default).
    #[default]
    Disabled,
    /// Re-attach `id` fields to serialized items.
    Enabled,
}

impl ResponseItemIdSerialization {
    /// Returns `true` when id serialization is switched on.
    pub(crate) fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }
}
/// Serializes a Responses API HTTP request to JSON, optionally re-attaching
/// response item ids to the serialized `"input"` array
/// (see `attach_response_item_ids`).
pub(crate) fn serialize_responses_request_body(
    request: &ResponsesApiRequest,
    response_item_ids: ResponseItemIdSerialization,
) -> serde_json::Result<Value> {
    serialize_input_payload(request, &request.input, response_item_ids)
}
/// Serializes a websocket Responses request to JSON. When id serialization is
/// enabled, re-attaches item ids to the `"input"` array of the
/// `response.create` payload; the wrapper fields serde produces (e.g. the
/// `"type"` tag) are left untouched.
pub(crate) fn serialize_responses_ws_request_body(
    request: &ResponsesWsRequest,
    response_item_ids: ResponseItemIdSerialization,
) -> serde_json::Result<Value> {
    // Serialize first so the enum's own representation is exactly what serde
    // produces; ids are then patched into the resulting `Value` in place.
    let mut value = serde_json::to_value(request)?;
    if response_item_ids.is_enabled() {
        match request {
            ResponsesWsRequest::ResponseCreate(request) => {
                attach_response_item_ids(value.get_mut("input"), &request.input);
            }
        }
    }
    Ok(value)
}
/// Renders one rollout JSONL line (timestamp plus the flattened
/// `RolloutItem`). When id serialization is enabled, response item ids are
/// re-attached to the serialized payload before the line is stringified.
pub(crate) fn serialize_rollout_line(
    timestamp: String,
    item: &RolloutItem,
    response_item_ids: ResponseItemIdSerialization,
) -> serde_json::Result<String> {
    let mut value = serde_json::to_value(RolloutLineRef { timestamp, item })?;
    if response_item_ids.is_enabled() {
        attach_rollout_item_id(&mut value, item);
    }
    serde_json::to_string(&value)
}
/// Shared helper: serializes `request` to JSON and, when enabled, patches the
/// ids from `input` into the serialized `"input"` array.
///
/// `input` must be the same slice `request` serializes under `"input"`;
/// entries are matched positionally by `attach_response_item_ids`.
fn serialize_input_payload<T: Serialize>(
    request: &T,
    input: &[ResponseItem],
    response_item_ids: ResponseItemIdSerialization,
) -> serde_json::Result<Value> {
    let mut value = serde_json::to_value(request)?;
    if response_item_ids.is_enabled() {
        attach_response_item_ids(value.get_mut("input"), input);
    }
    Ok(value)
}
/// Patches the provider-assigned id into a serialized rollout line.
///
/// For a plain `ResponseItem`, the id (when present and non-empty) is written
/// into the line's `payload` object. For a `Compacted` item, ids are attached
/// to each entry of the serialized `replacement_history`. The remaining
/// rollout variants carry no response item ids and are left untouched.
fn attach_rollout_item_id(value: &mut Value, item: &RolloutItem) {
    match item {
        RolloutItem::ResponseItem(response_item) => {
            let Some(id) = response_item_id(response_item) else {
                return;
            };
            // Skip empty ids, mirroring `attach_response_item_ids`:
            // `response_item_id` yields `Some` for `Reasoning` even when the
            // id string is empty, and writing `"id": ""` into the payload
            // would be noise at best.
            if id.is_empty() {
                return;
            }
            if let Some(payload) = value.get_mut("payload").and_then(Value::as_object_mut) {
                payload.insert("id".to_string(), Value::String(id.to_string()));
            }
        }
        RolloutItem::Compacted(compacted_item) => {
            attach_compacted_replacement_history_ids(value, compacted_item);
        }
        RolloutItem::SessionMeta(_) | RolloutItem::TurnContext(_) | RolloutItem::EventMsg(_) => {}
    }
}
/// Attaches ids to every entry of the serialized `replacement_history` array
/// under a compacted rollout line's `payload` object.
fn attach_compacted_replacement_history_ids(value: &mut Value, compacted_item: &CompactedItem) {
    if let Some(payload) = value.get_mut("payload").and_then(Value::as_object_mut) {
        // `replacement_history` is optional on the item; an absent history is
        // treated as an empty slice so nothing gets patched.
        let history = compacted_item
            .replacement_history
            .as_deref()
            .unwrap_or_default();
        attach_response_item_ids(payload.get_mut("replacement_history"), history);
    }
}
/// Walks the serialized items array in lock-step with the original
/// `ResponseItem`s and inserts each non-empty id into the matching JSON
/// object. A missing value, a non-array value, and non-object entries are all
/// silently ignored.
fn attach_response_item_ids(items_value: Option<&mut Value>, original_items: &[ResponseItem]) {
    let Some(Value::Array(serialized)) = items_value else {
        return;
    };
    // `zip` pairs entries positionally; any length mismatch simply truncates.
    for (slot, original) in serialized.iter_mut().zip(original_items) {
        let id = match response_item_id(original) {
            Some(id) if !id.is_empty() => id,
            _ => continue,
        };
        if let Some(object) = slot.as_object_mut() {
            object.insert("id".to_string(), Value::String(id.to_string()));
        }
    }
}
/// Extracts the provider-assigned id from a `ResponseItem`, if the variant
/// carries one.
///
/// `Reasoning` stores its id as a plain `String`, so it always yields `Some`
/// (possibly an empty string — the array-walking caller filters empty ids).
/// The other id-bearing variants store `Option<String>` and yield `Some` only
/// when the id is set. Output/snapshot/compaction variants carry no provider
/// id and yield `None`.
fn response_item_id(item: &ResponseItem) -> Option<&str> {
    match item {
        ResponseItem::Reasoning { id, .. } => Some(id.as_str()),
        ResponseItem::Message { id: Some(id), .. }
        | ResponseItem::WebSearchCall { id: Some(id), .. }
        | ResponseItem::FunctionCall { id: Some(id), .. }
        | ResponseItem::ToolSearchCall { id: Some(id), .. }
        | ResponseItem::LocalShellCall { id: Some(id), .. }
        | ResponseItem::CustomToolCall { id: Some(id), .. } => Some(id.as_str()),
        // Exhaustive list of everything without a usable id; no catch-all so a
        // new `ResponseItem` variant forces a decision here at compile time.
        ResponseItem::Message { id: None, .. }
        | ResponseItem::WebSearchCall { id: None, .. }
        | ResponseItem::LocalShellCall { id: None, .. }
        | ResponseItem::FunctionCall { id: None, .. }
        | ResponseItem::ToolSearchCall { id: None, .. }
        | ResponseItem::CustomToolCall { id: None, .. }
        | ResponseItem::FunctionCallOutput { .. }
        | ResponseItem::CustomToolCallOutput { .. }
        | ResponseItem::ToolSearchOutput { .. }
        | ResponseItem::ImageGenerationCall { .. }
        | ResponseItem::GhostSnapshot { .. }
        | ResponseItem::Compaction { .. }
        | ResponseItem::Other => None,
    }
}
/// Borrowed serialization shim for one rollout JSONL line: a timestamp plus
/// the flattened fields of the referenced `RolloutItem`.
#[derive(Serialize)]
struct RolloutLineRef<'a> {
    timestamp: String,
    #[serde(flatten)]
    item: &'a RolloutItem,
}
#[cfg(test)]
mod tests {
    use super::ResponseItemIdSerialization;
    use super::serialize_responses_ws_request_body;
    use super::serialize_rollout_line;
    use codex_api::ResponseCreateWsRequest;
    use codex_api::common::ResponsesWsRequest;
    use codex_protocol::models::ContentItem;
    use codex_protocol::models::ResponseItem;
    use codex_protocol::protocol::CompactedItem;
    use codex_protocol::protocol::RolloutItem;
    use pretty_assertions::assert_eq;
    use serde_json::json;

    /// The `response.create` wrapper serde produces (notably the `"type"`
    /// tag) must survive id patching, and the id must land on `input[0]`.
    #[test]
    fn serialize_responses_ws_request_body_preserves_response_create_wrapper() {
        let request = ResponsesWsRequest::ResponseCreate(ResponseCreateWsRequest {
            model: "gpt-test".to_string(),
            instructions: "Be helpful".to_string(),
            previous_response_id: None,
            input: vec![assistant_message(Some("msg_123"))],
            tools: Vec::new(),
            tool_choice: "auto".to_string(),
            parallel_tool_calls: false,
            reasoning: None,
            store: true,
            stream: true,
            include: Vec::new(),
            service_tier: None,
            prompt_cache_key: None,
            text: None,
            generate: None,
            client_metadata: None,
        });
        let value =
            serialize_responses_ws_request_body(&request, ResponseItemIdSerialization::Enabled)
                .expect("websocket request should serialize");
        assert_eq!(value["type"], json!("response.create"));
        assert_eq!(value["input"][0]["id"], json!("msg_123"));
    }

    /// Ids inside a compacted item's `replacement_history` must be
    /// re-attached in the rendered rollout line.
    #[test]
    fn serialize_rollout_line_preserves_ids_in_compacted_replacement_history() {
        let line = serialize_rollout_line(
            "2026-03-18T00:00:00Z".to_string(),
            &RolloutItem::Compacted(CompactedItem {
                message: "compacted".to_string(),
                replacement_history: Some(vec![assistant_message(Some("msg_123"))]),
            }),
            ResponseItemIdSerialization::Enabled,
        )
        .expect("rollout line should serialize");
        let value: serde_json::Value =
            serde_json::from_str(&line).expect("serialized rollout line should be valid json");
        assert_eq!(
            value["payload"]["replacement_history"][0]["id"],
            json!("msg_123")
        );
    }

    /// Builds a minimal assistant `Message` item with the given optional id.
    fn assistant_message(id: Option<&str>) -> ResponseItem {
        ResponseItem::Message {
            id: id.map(str::to_string),
            role: "assistant".to_string(),
            content: vec![ContentItem::OutputText {
                text: "hello".to_string(),
            }],
            end_turn: None,
            phase: None,
        }
    }
}

View File

@@ -41,8 +41,11 @@ use super::policy::EventPersistenceMode;
use super::policy::is_persisted_response_item;
use crate::config::Config;
use crate::default_client::originator;
use crate::features::Feature;
use crate::git_info::collect_git_info;
use crate::path_utils;
use crate::response_item_id_serde::ResponseItemIdSerialization;
use crate::response_item_id_serde::serialize_rollout_line;
use crate::state_db;
use crate::state_db::StateDbHandle;
use crate::truncate::TruncationPolicy;
@@ -463,6 +466,11 @@ impl RolloutRecorder {
state_builder,
config.model_provider_id.clone(),
config.memories.generate_memories,
if config.features.enabled(Feature::RecordResponseItemId) {
ResponseItemIdSerialization::Enabled
} else {
ResponseItemIdSerialization::Disabled
},
));
Ok(Self {
@@ -716,8 +724,12 @@ async fn rollout_writer(
mut state_builder: Option<ThreadMetadataBuilder>,
default_provider: String,
generate_memories: bool,
response_item_ids: ResponseItemIdSerialization,
) -> std::io::Result<()> {
let mut writer = file.map(|file| JsonlWriter { file });
let mut writer = file.map(|file| JsonlWriter {
file,
response_item_ids,
});
let mut buffered_items = Vec::<RolloutItem>::new();
if let Some(builder) = state_builder.as_mut() {
builder.rollout_path = rollout_path.clone();
@@ -775,6 +787,7 @@ async fn rollout_writer(
let file = open_log_file(log_file_info.path.as_path())?;
writer = Some(JsonlWriter {
file: tokio::fs::File::from_std(file),
response_item_ids,
});
if let Some(session_meta) = meta.take() {
@@ -946,13 +959,7 @@ async fn sync_thread_state_after_write(
struct JsonlWriter {
file: tokio::fs::File,
}
#[derive(serde::Serialize)]
struct RolloutLineRef<'a> {
timestamp: String,
#[serde(flatten)]
item: &'a RolloutItem,
response_item_ids: ResponseItemIdSerialization,
}
impl JsonlWriter {
@@ -964,14 +971,7 @@ impl JsonlWriter {
.format(timestamp_format)
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
let line = RolloutLineRef {
timestamp,
item: rollout_item,
};
self.write_line(&line).await
}
async fn write_line(&mut self, item: &impl serde::Serialize) -> std::io::Result<()> {
let mut json = serde_json::to_string(item)?;
let mut json = serialize_rollout_line(timestamp, rollout_item, self.response_item_ids)?;
json.push('\n');
self.file.write_all(json.as_bytes()).await?;
self.file.flush().await?;

View File

@@ -3,6 +3,8 @@ use crate::config::ConfigBuilder;
use crate::features::Feature;
use chrono::TimeZone;
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::AgentMessageEvent;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::EventMsg;
@@ -137,6 +139,60 @@ async fn recorder_materializes_only_after_explicit_persist() -> std::io::Result<
Ok(())
}
/// End-to-end recorder check: with `RecordResponseItemId` enabled, the
/// persisted rollout line for a `ResponseItem` must carry the item's id under
/// `payload.id`.
#[tokio::test]
async fn recorder_serializes_response_item_ids_when_feature_enabled() -> std::io::Result<()> {
    let home = TempDir::new().expect("temp dir");
    let mut config = ConfigBuilder::default()
        .codex_home(home.path().to_path_buf())
        .build()
        .await?;
    config
        .features
        .enable(Feature::RecordResponseItemId)
        .expect("test config should allow feature update");
    let recorder = RolloutRecorder::new(
        &config,
        RolloutRecorderParams::new(
            ThreadId::new(),
            None,
            SessionSource::Exec,
            BaseInstructions::default(),
            Vec::new(),
            EventPersistenceMode::Limited,
        ),
        None,
        None,
    )
    .await?;
    recorder
        .record_items(&[RolloutItem::ResponseItem(ResponseItem::Message {
            id: Some("msg-1".to_string()),
            role: "assistant".to_string(),
            content: vec![ContentItem::OutputText {
                text: "hello".to_string(),
            }],
            end_turn: None,
            phase: None,
        })])
        .await?;
    recorder.persist().await?;
    recorder.flush().await?;
    let text = std::fs::read_to_string(recorder.rollout_path())?;
    // The rollout file may contain other JSONL lines; pick out the
    // response_item line before parsing it.
    let response_line = text
        .lines()
        .find(|line| line.contains("\"type\":\"response_item\""))
        .expect("response item line should be present");
    let response_line: serde_json::Value =
        serde_json::from_str(response_line).expect("response line should be valid json");
    assert_eq!(response_line["payload"]["id"].as_str(), Some("msg-1"));
    recorder.shutdown().await?;
    Ok(())
}
#[tokio::test]
async fn metadata_irrelevant_events_touch_state_db_updated_at() -> std::io::Result<()> {
let home = TempDir::new().expect("temp dir");

View File

@@ -3,6 +3,7 @@ use std::sync::Arc;
use codex_core::CodexAuth;
use codex_core::ModelClient;
use codex_core::ModelClientResponseItemIds;
use codex_core::ModelProviderInfo;
use codex_core::Prompt;
use codex_core::ResponseEvent;
@@ -95,6 +96,7 @@ async fn responses_stream_includes_subagent_header_on_review() {
false,
false,
None,
ModelClientResponseItemIds::Disabled,
);
let mut client_session = client.new_session();
@@ -208,6 +210,7 @@ async fn responses_stream_includes_subagent_header_on_other() {
false,
false,
None,
ModelClientResponseItemIds::Disabled,
);
let mut client_session = client.new_session();
@@ -320,6 +323,7 @@ async fn responses_respects_model_info_overrides_from_config() {
false,
false,
None,
ModelClientResponseItemIds::Disabled,
);
let mut client_session = client.new_session();

View File

@@ -1,5 +1,6 @@
use codex_core::CodexAuth;
use codex_core::ModelClient;
use codex_core::ModelClientResponseItemIds;
use codex_core::ModelProviderInfo;
use codex_core::NewThread;
use codex_core::Prompt;
@@ -1835,6 +1836,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
false,
false,
None,
ModelClientResponseItemIds::Enabled,
);
let mut client_session = client.new_session();
@@ -1943,6 +1945,107 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
);
}
/// With `ModelClientResponseItemIds::Enabled`, the HTTP Responses request
/// body received by the mock server must carry the original item id on
/// `input[0]`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_request_includes_response_item_ids_when_feature_enabled() {
    skip_if_no_network!();
    let server = MockServer::start().await;
    let sse_body = concat!(
        "data: {\"type\":\"response.created\",\"response\":{}}\n\n",
        "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\"}}\n\n",
    );
    let resp_mock = mount_sse_once(&server, sse_body.to_string()).await;
    // Responses-wire provider pointed at the mock server; retries disabled so
    // a failure surfaces immediately, websockets off to force HTTP.
    let provider = ModelProviderInfo {
        name: "openai".into(),
        base_url: Some(format!("{}/v1", server.uri())),
        env_key: None,
        env_key_instructions: None,
        experimental_bearer_token: None,
        wire_api: WireApi::Responses,
        query_params: None,
        http_headers: None,
        env_http_headers: None,
        request_max_retries: Some(0),
        stream_max_retries: Some(0),
        stream_idle_timeout_ms: Some(5_000),
        websocket_connect_timeout_ms: None,
        requires_openai_auth: false,
        supports_websockets: false,
    };
    let codex_home = TempDir::new().unwrap();
    let mut config = load_default_config_for_test(&codex_home).await;
    config.model_provider_id = provider.name.clone();
    config.model_provider = provider.clone();
    let effort = config.model_reasoning_effort;
    let summary = config.model_reasoning_summary;
    let model = codex_core::test_support::get_model_offline(config.model.as_deref());
    config.model = Some(model.clone());
    let config = Arc::new(config);
    let model_info =
        codex_core::test_support::construct_model_info_offline(model.as_str(), &config);
    let conversation_id = ThreadId::new();
    let session_telemetry = SessionTelemetry::new(
        conversation_id,
        model.as_str(),
        model_info.slug.as_str(),
        None,
        Some("test@test.com".to_string()),
        None,
        "test_originator".to_string(),
        false,
        "test".to_string(),
        SessionSource::Exec,
    );
    let client = ModelClient::new(
        None,
        conversation_id,
        provider,
        SessionSource::Exec,
        config.model_verbosity,
        false,
        false,
        None,
        // The toggle under test.
        ModelClientResponseItemIds::Enabled,
    );
    let mut client_session = client.new_session();
    let mut prompt = Prompt::default();
    prompt.input.push(ResponseItem::Message {
        id: Some("message-id".into()),
        role: "assistant".into(),
        content: vec![ContentItem::OutputText {
            text: "message".into(),
        }],
        end_turn: None,
        phase: None,
    });
    let mut stream = client_session
        .stream(
            &prompt,
            &model_info,
            &session_telemetry,
            effort,
            summary.unwrap_or(ReasoningSummary::Auto),
            None,
            None,
        )
        .await
        .expect("responses stream to start");
    // Drain events until completion so the mock has recorded the request.
    while let Some(event) = stream.next().await {
        if let Ok(ResponseEvent::Completed { .. }) = event {
            break;
        }
    }
    let body = resp_mock.single_request().body_json();
    assert_eq!(body["input"][0]["id"].as_str(), Some("message-id"));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn token_count_includes_rate_limits_snapshot() {
skip_if_no_network!();

View File

@@ -1,6 +1,7 @@
#![allow(clippy::expect_used, clippy::unwrap_used)]
use codex_core::CodexAuth;
use codex_core::ModelClient;
use codex_core::ModelClientResponseItemIds;
use codex_core::ModelClientSession;
use codex_core::ModelProviderInfo;
use codex_core::Prompt;
@@ -97,6 +98,29 @@ async fn responses_websocket_streams_request() {
server.shutdown().await;
}
/// With id serialization enabled, the websocket request body received by the
/// test server must include the original item id on `input[0]`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_includes_response_item_ids_when_feature_enabled() {
    skip_if_no_network!();
    let server = start_websocket_server(vec![vec![vec![
        ev_response_created("resp-1"),
        ev_completed("resp-1"),
    ]]])
    .await;
    let harness = websocket_harness_with_response_item_ids(&server).await;
    let mut client_session = harness.client.new_session();
    let prompt = prompt_with_input(vec![assistant_message_item("msg-1", "hello")]);
    stream_until_complete(&mut client_session, &harness, &prompt).await;
    // Inspect the single recorded websocket request body.
    let connection = server.single_connection();
    let body = connection.first().expect("missing request").body_json();
    assert_eq!(body["input"][0]["id"].as_str(), Some("msg-1"));
    server.shutdown().await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_streams_without_feature_flag_when_provider_supports_websockets() {
skip_if_no_network!();
@@ -1551,6 +1575,17 @@ async fn websocket_harness(server: &WebSocketTestServer) -> WebsocketTestHarness
websocket_harness_with_runtime_metrics(server, false).await
}
/// Builds a websocket test harness with response item id serialization
/// enabled (runtime metrics disabled).
async fn websocket_harness_with_response_item_ids(
    server: &WebSocketTestServer,
) -> WebsocketTestHarness {
    websocket_harness_with_provider_options(
        websocket_provider(server),
        /*runtime_metrics_enabled*/ false,
        ModelClientResponseItemIds::Enabled,
    )
    .await
}
async fn websocket_harness_with_runtime_metrics(
server: &WebSocketTestServer,
runtime_metrics_enabled: bool,
@@ -1569,13 +1604,18 @@ async fn websocket_harness_with_options(
server: &WebSocketTestServer,
runtime_metrics_enabled: bool,
) -> WebsocketTestHarness {
websocket_harness_with_provider_options(websocket_provider(server), runtime_metrics_enabled)
.await
websocket_harness_with_provider_options(
websocket_provider(server),
runtime_metrics_enabled,
ModelClientResponseItemIds::Disabled,
)
.await
}
async fn websocket_harness_with_provider_options(
provider: ModelProviderInfo,
runtime_metrics_enabled: bool,
response_item_ids: ModelClientResponseItemIds,
) -> WebsocketTestHarness {
let codex_home = TempDir::new().unwrap();
let mut config = load_default_config_for_test(&codex_home).await;
@@ -1621,6 +1661,7 @@ async fn websocket_harness_with_provider_options(
false,
runtime_metrics_enabled,
None,
response_item_ids,
);
WebsocketTestHarness {