mirror of
https://github.com/openai/codex.git
synced 2026-04-28 02:11:08 +03:00
Compare full request for websockets incrementality (#11343)
Tools can dynamically change mid-turn now. We need to be more thorough about reusing incremental connections.
This commit is contained in:
@@ -80,7 +80,7 @@ pub enum ResponseEvent {
|
|||||||
ModelsEtag(String),
|
ModelsEtag(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Clone)]
|
#[derive(Debug, Serialize, Clone, PartialEq)]
|
||||||
pub struct Reasoning {
|
pub struct Reasoning {
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub effort: Option<ReasoningEffortConfig>,
|
pub effort: Option<ReasoningEffortConfig>,
|
||||||
@@ -88,14 +88,14 @@ pub struct Reasoning {
|
|||||||
pub summary: Option<ReasoningSummaryConfig>,
|
pub summary: Option<ReasoningSummaryConfig>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Default, Clone)]
|
#[derive(Debug, Serialize, Default, Clone, PartialEq)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
pub enum TextFormatType {
|
pub enum TextFormatType {
|
||||||
#[default]
|
#[default]
|
||||||
JsonSchema,
|
JsonSchema,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Default, Clone)]
|
#[derive(Debug, Serialize, Default, Clone, PartialEq)]
|
||||||
pub struct TextFormat {
|
pub struct TextFormat {
|
||||||
/// Format type used by the OpenAI text controls.
|
/// Format type used by the OpenAI text controls.
|
||||||
pub r#type: TextFormatType,
|
pub r#type: TextFormatType,
|
||||||
@@ -109,7 +109,7 @@ pub struct TextFormat {
|
|||||||
|
|
||||||
/// Controls the `text` field for the Responses API, combining verbosity and
|
/// Controls the `text` field for the Responses API, combining verbosity and
|
||||||
/// optional JSON schema output formatting.
|
/// optional JSON schema output formatting.
|
||||||
#[derive(Debug, Serialize, Default, Clone)]
|
#[derive(Debug, Serialize, Default, Clone, PartialEq)]
|
||||||
pub struct TextControls {
|
pub struct TextControls {
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub verbosity: Option<OpenAiVerbosity>,
|
pub verbosity: Option<OpenAiVerbosity>,
|
||||||
@@ -117,7 +117,7 @@ pub struct TextControls {
|
|||||||
pub format: Option<TextFormat>,
|
pub format: Option<TextFormat>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Default, Clone)]
|
#[derive(Debug, Serialize, Default, Clone, PartialEq)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
pub enum OpenAiVerbosity {
|
pub enum OpenAiVerbosity {
|
||||||
Low,
|
Low,
|
||||||
@@ -136,7 +136,7 @@ impl From<VerbosityConfig> for OpenAiVerbosity {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Clone)]
|
#[derive(Debug, Serialize, Clone, PartialEq)]
|
||||||
pub struct ResponsesApiRequest {
|
pub struct ResponsesApiRequest {
|
||||||
pub model: String,
|
pub model: String,
|
||||||
pub instructions: String,
|
pub instructions: String,
|
||||||
|
|||||||
@@ -155,8 +155,8 @@ pub struct ModelClient {
|
|||||||
/// The session establishes a Responses WebSocket connection lazily and reuses it across multiple
|
/// The session establishes a Responses WebSocket connection lazily and reuses it across multiple
|
||||||
/// requests within the turn. It also caches per-turn state:
|
/// requests within the turn. It also caches per-turn state:
|
||||||
///
|
///
|
||||||
/// - The last request's input items, so subsequent calls can use `response.append` when the input
|
/// - The last full request, so subsequent calls can use `response.append` only when the current
|
||||||
/// is an incremental extension of the previous request.
|
/// request is an incremental extension of the previous one.
|
||||||
/// - The `x-codex-turn-state` sticky-routing token, which must be replayed for all requests within
|
/// - The `x-codex-turn-state` sticky-routing token, which must be replayed for all requests within
|
||||||
/// the same turn.
|
/// the same turn.
|
||||||
///
|
///
|
||||||
@@ -166,7 +166,7 @@ pub struct ModelClient {
|
|||||||
pub struct ModelClientSession {
|
pub struct ModelClientSession {
|
||||||
client: ModelClient,
|
client: ModelClient,
|
||||||
connection: Option<ApiWebSocketConnection>,
|
connection: Option<ApiWebSocketConnection>,
|
||||||
websocket_last_items: Vec<ResponseItem>,
|
websocket_last_request: Option<ResponsesApiRequest>,
|
||||||
websocket_last_response_id: Option<String>,
|
websocket_last_response_id: Option<String>,
|
||||||
websocket_last_response_id_rx: Option<oneshot::Receiver<String>>,
|
websocket_last_response_id_rx: Option<oneshot::Receiver<String>>,
|
||||||
/// Turn state for sticky routing.
|
/// Turn state for sticky routing.
|
||||||
@@ -230,7 +230,7 @@ impl ModelClient {
|
|||||||
ModelClientSession {
|
ModelClientSession {
|
||||||
client: self.clone(),
|
client: self.clone(),
|
||||||
connection: None,
|
connection: None,
|
||||||
websocket_last_items: Vec::new(),
|
websocket_last_request: None,
|
||||||
websocket_last_response_id: None,
|
websocket_last_response_id: None,
|
||||||
websocket_last_response_id_rx: None,
|
websocket_last_response_id_rx: None,
|
||||||
turn_state: Arc::new(OnceLock::new()),
|
turn_state: Arc::new(OnceLock::new()),
|
||||||
@@ -530,16 +530,25 @@ impl ModelClientSession {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_incremental_items(&self, input_items: &[ResponseItem]) -> Option<Vec<ResponseItem>> {
|
fn get_incremental_items(&self, request: &ResponsesApiRequest) -> Option<Vec<ResponseItem>> {
|
||||||
// Checks whether the current request input is an incremental append to the previous request.
|
// Checks whether the current request is an incremental append to the previous request.
|
||||||
// If items in the new request contain all the items from the previous request we build
|
// We only append when non-input request fields are unchanged and `input` is a strict
|
||||||
// a response.append request otherwise we start with a fresh response.create request.
|
// extension of the previous input.
|
||||||
let previous_len = self.websocket_last_items.len();
|
let previous_request = self.websocket_last_request.as_ref()?;
|
||||||
let can_append = previous_len > 0
|
let mut previous_without_input = previous_request.clone();
|
||||||
&& input_items.starts_with(&self.websocket_last_items)
|
previous_without_input.input.clear();
|
||||||
&& previous_len < input_items.len();
|
let mut request_without_input = request.clone();
|
||||||
if can_append {
|
request_without_input.input.clear();
|
||||||
Some(input_items[previous_len..].to_vec())
|
if previous_without_input != request_without_input {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let previous_len = previous_request.input.len();
|
||||||
|
if previous_len > 0
|
||||||
|
&& request.input.starts_with(&previous_request.input)
|
||||||
|
&& previous_len < request.input.len()
|
||||||
|
{
|
||||||
|
Some(request.input[previous_len..].to_vec())
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
@@ -571,10 +580,10 @@ impl ModelClientSession {
|
|||||||
fn prepare_websocket_request(
|
fn prepare_websocket_request(
|
||||||
&mut self,
|
&mut self,
|
||||||
payload: ResponseCreateWsRequest,
|
payload: ResponseCreateWsRequest,
|
||||||
) -> (ResponsesWsRequest, Vec<ResponseItem>) {
|
request: &ResponsesApiRequest,
|
||||||
let full_input = payload.input.clone();
|
) -> ResponsesWsRequest {
|
||||||
let responses_websockets_v2_enabled = self.client.responses_websockets_v2_enabled();
|
let responses_websockets_v2_enabled = self.client.responses_websockets_v2_enabled();
|
||||||
let incremental_items = self.get_incremental_items(&full_input);
|
let incremental_items = self.get_incremental_items(request);
|
||||||
if let Some(append_items) = incremental_items {
|
if let Some(append_items) = incremental_items {
|
||||||
if responses_websockets_v2_enabled
|
if responses_websockets_v2_enabled
|
||||||
&& let Some(previous_response_id) = self.websocket_previous_response_id()
|
&& let Some(previous_response_id) = self.websocket_previous_response_id()
|
||||||
@@ -584,20 +593,17 @@ impl ModelClientSession {
|
|||||||
input: append_items,
|
input: append_items,
|
||||||
..payload
|
..payload
|
||||||
};
|
};
|
||||||
return (ResponsesWsRequest::ResponseCreate(payload), full_input);
|
return ResponsesWsRequest::ResponseCreate(payload);
|
||||||
}
|
}
|
||||||
|
|
||||||
if !responses_websockets_v2_enabled {
|
if !responses_websockets_v2_enabled {
|
||||||
return (
|
return ResponsesWsRequest::ResponseAppend(ResponseAppendWsRequest {
|
||||||
ResponsesWsRequest::ResponseAppend(ResponseAppendWsRequest {
|
input: append_items,
|
||||||
input: append_items,
|
});
|
||||||
}),
|
|
||||||
full_input,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
(ResponsesWsRequest::ResponseCreate(payload), full_input)
|
ResponsesWsRequest::ResponseCreate(payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Opportunistically warms a websocket for this turn-scoped client session.
|
/// Opportunistically warms a websocket for this turn-scoped client session.
|
||||||
@@ -650,7 +656,7 @@ impl ModelClientSession {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if needs_new {
|
if needs_new {
|
||||||
self.websocket_last_items.clear();
|
self.websocket_last_request = None;
|
||||||
self.websocket_last_response_id = None;
|
self.websocket_last_response_id = None;
|
||||||
self.websocket_last_response_id_rx = None;
|
self.websocket_last_response_id_rx = None;
|
||||||
let turn_state = options
|
let turn_state = options
|
||||||
@@ -806,7 +812,7 @@ impl ModelClientSession {
|
|||||||
Err(err) => return Err(map_api_error(err)),
|
Err(err) => return Err(map_api_error(err)),
|
||||||
}
|
}
|
||||||
|
|
||||||
let (request, request_input) = self.prepare_websocket_request(ws_payload);
|
let ws_request = self.prepare_websocket_request(ws_payload, &request);
|
||||||
|
|
||||||
let stream_result = self
|
let stream_result = self
|
||||||
.connection
|
.connection
|
||||||
@@ -816,10 +822,10 @@ impl ModelClientSession {
|
|||||||
"websocket connection is unavailable".to_string(),
|
"websocket connection is unavailable".to_string(),
|
||||||
))
|
))
|
||||||
})?
|
})?
|
||||||
.stream_request(request)
|
.stream_request(ws_request)
|
||||||
.await
|
.await
|
||||||
.map_err(map_api_error)?;
|
.map_err(map_api_error)?;
|
||||||
self.websocket_last_items = request_input;
|
self.websocket_last_request = Some(request);
|
||||||
let (last_response_id_sender, last_response_id_receiver) = oneshot::channel();
|
let (last_response_id_sender, last_response_id_receiver) = oneshot::channel();
|
||||||
self.websocket_last_response_id_rx = Some(last_response_id_receiver);
|
self.websocket_last_response_id_rx = Some(last_response_id_receiver);
|
||||||
let mut last_response_id_sender = Some(last_response_id_sender);
|
let mut last_response_id_sender = Some(last_response_id_sender);
|
||||||
@@ -928,7 +934,7 @@ impl ModelClientSession {
|
|||||||
);
|
);
|
||||||
|
|
||||||
self.connection = None;
|
self.connection = None;
|
||||||
self.websocket_last_items.clear();
|
self.websocket_last_request = None;
|
||||||
}
|
}
|
||||||
activated
|
activated
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ use codex_otel::metrics::MetricsConfig;
|
|||||||
use codex_protocol::ThreadId;
|
use codex_protocol::ThreadId;
|
||||||
use codex_protocol::account::PlanType;
|
use codex_protocol::account::PlanType;
|
||||||
use codex_protocol::config_types::ReasoningSummary;
|
use codex_protocol::config_types::ReasoningSummary;
|
||||||
|
use codex_protocol::models::BaseInstructions;
|
||||||
use codex_protocol::openai_models::ModelInfo;
|
use codex_protocol::openai_models::ModelInfo;
|
||||||
use codex_protocol::openai_models::ReasoningEffort as ReasoningEffortConfig;
|
use codex_protocol::openai_models::ReasoningEffort as ReasoningEffortConfig;
|
||||||
use codex_protocol::user_input::UserInput;
|
use codex_protocol::user_input::UserInput;
|
||||||
@@ -603,6 +604,42 @@ async fn responses_websocket_creates_on_non_prefix() {
|
|||||||
server.shutdown().await;
|
server.shutdown().await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn responses_websocket_creates_when_non_input_request_fields_change() {
|
||||||
|
skip_if_no_network!();
|
||||||
|
|
||||||
|
let server = start_websocket_server(vec![vec![
|
||||||
|
vec![ev_response_created("resp-1"), ev_completed("resp-1")],
|
||||||
|
vec![ev_response_created("resp-2"), ev_completed("resp-2")],
|
||||||
|
]])
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let harness = websocket_harness(&server).await;
|
||||||
|
let mut client_session = harness.client.new_session();
|
||||||
|
let prompt_one =
|
||||||
|
prompt_with_input_and_instructions(vec![message_item("hello")], "base instructions one");
|
||||||
|
let prompt_two = prompt_with_input_and_instructions(
|
||||||
|
vec![message_item("hello"), message_item("second")],
|
||||||
|
"base instructions two",
|
||||||
|
);
|
||||||
|
|
||||||
|
stream_until_complete(&mut client_session, &harness, &prompt_one).await;
|
||||||
|
stream_until_complete(&mut client_session, &harness, &prompt_two).await;
|
||||||
|
|
||||||
|
let connection = server.single_connection();
|
||||||
|
assert_eq!(connection.len(), 2);
|
||||||
|
let second = connection.get(1).expect("missing request").body_json();
|
||||||
|
|
||||||
|
assert_eq!(second["type"].as_str(), Some("response.create"));
|
||||||
|
assert_eq!(second.get("previous_response_id"), None);
|
||||||
|
assert_eq!(
|
||||||
|
second["input"],
|
||||||
|
serde_json::to_value(&prompt_two.input).expect("serialize full input")
|
||||||
|
);
|
||||||
|
|
||||||
|
server.shutdown().await;
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
async fn responses_websocket_v2_creates_with_previous_response_id_on_prefix() {
|
async fn responses_websocket_v2_creates_with_previous_response_id_on_prefix() {
|
||||||
skip_if_no_network!();
|
skip_if_no_network!();
|
||||||
@@ -637,6 +674,43 @@ async fn responses_websocket_v2_creates_with_previous_response_id_on_prefix() {
|
|||||||
server.shutdown().await;
|
server.shutdown().await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn responses_websocket_v2_creates_without_previous_response_id_when_non_input_fields_change()
|
||||||
|
{
|
||||||
|
skip_if_no_network!();
|
||||||
|
|
||||||
|
let server = start_websocket_server(vec![vec![
|
||||||
|
vec![ev_response_created("resp-1"), ev_completed("resp-1")],
|
||||||
|
vec![ev_response_created("resp-2"), ev_completed("resp-2")],
|
||||||
|
]])
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let harness = websocket_harness_with_v2(&server, true).await;
|
||||||
|
let mut session = harness.client.new_session();
|
||||||
|
let prompt_one =
|
||||||
|
prompt_with_input_and_instructions(vec![message_item("hello")], "base instructions one");
|
||||||
|
let prompt_two = prompt_with_input_and_instructions(
|
||||||
|
vec![message_item("hello"), message_item("second")],
|
||||||
|
"base instructions two",
|
||||||
|
);
|
||||||
|
|
||||||
|
stream_until_complete(&mut session, &harness, &prompt_one).await;
|
||||||
|
stream_until_complete(&mut session, &harness, &prompt_two).await;
|
||||||
|
|
||||||
|
let connection = server.single_connection();
|
||||||
|
assert_eq!(connection.len(), 2);
|
||||||
|
let second = connection.get(1).expect("missing request").body_json();
|
||||||
|
|
||||||
|
assert_eq!(second["type"].as_str(), Some("response.create"));
|
||||||
|
assert_eq!(second.get("previous_response_id"), None);
|
||||||
|
assert_eq!(
|
||||||
|
second["input"],
|
||||||
|
serde_json::to_value(&prompt_two.input).expect("serialize full input")
|
||||||
|
);
|
||||||
|
|
||||||
|
server.shutdown().await;
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
async fn responses_websocket_v2_after_error_uses_full_create_without_previous_response_id() {
|
async fn responses_websocket_v2_after_error_uses_full_create_without_previous_response_id() {
|
||||||
skip_if_no_network!();
|
skip_if_no_network!();
|
||||||
@@ -778,6 +852,14 @@ fn prompt_with_input(input: Vec<ResponseItem>) -> Prompt {
|
|||||||
prompt
|
prompt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn prompt_with_input_and_instructions(input: Vec<ResponseItem>, instructions: &str) -> Prompt {
|
||||||
|
let mut prompt = prompt_with_input(input);
|
||||||
|
prompt.base_instructions = BaseInstructions {
|
||||||
|
text: instructions.to_string(),
|
||||||
|
};
|
||||||
|
prompt
|
||||||
|
}
|
||||||
|
|
||||||
fn websocket_provider(server: &WebSocketTestServer) -> ModelProviderInfo {
|
fn websocket_provider(server: &WebSocketTestServer) -> ModelProviderInfo {
|
||||||
ModelProviderInfo {
|
ModelProviderInfo {
|
||||||
name: "mock-ws".into(),
|
name: "mock-ws".into(),
|
||||||
|
|||||||
Reference in New Issue
Block a user