Mirror of https://github.com/openai/codex.git (synced 2026-05-04 05:11:37 +03:00)
## Problem

The first user turn can pay websocket handshake latency even when a session has already started. We want to reduce that initial delay while preserving turn semantics and avoiding any prompt send during startup.

Reviewer feedback also called out duplicated connect/setup paths and unnecessary preconnect state complexity.

## Mental model

`ModelClient` owns session-scoped transport state. During session startup, it can opportunistically warm one websocket handshake slot. A turn-scoped `ModelClientSession` adopts that slot once if available, restores captured sticky turn-state, and otherwise opens a websocket through the same shared connect path.

If startup preconnect is still in flight, first turn setup awaits that task and treats it as the first connection attempt for the turn.

Preconnect is handshake-only. The first `response.create` is still sent only when a turn starts.

## Non-goals

This change does not make preconnect required for correctness and does not change prompt/turn payload semantics. It also does not expand fallback behavior beyond clearing preconnect state when fallback activates.

## Tradeoffs

The implementation prioritizes simpler ownership and shared connection code over header-match gating for reuse. The single-slot cache keeps lifecycle straightforward but only benefits the immediate next turn.

Awaiting in-flight preconnect has the same app-level connect-timeout semantics as existing websocket connect behavior (no new timeout class introduced by this PR).

## Architecture

`core/src/client.rs`:

- Added session-level preconnect lifecycle state (`Idle` / `InFlight` / `Ready`) carrying one warmed websocket plus optional captured turn-state.
- Added `pre_establish_connection()` startup warmup and `preconnect()` handshake-only setup.
- Deduped auth/provider resolution into `current_client_setup()` and websocket handshake wiring into `connect_websocket()` / `build_websocket_headers()`.
- Updated turn websocket path to adopt preconnect first, await in-flight preconnect when present, then create a new websocket only when needed.
- Ensured fallback activation clears warmed preconnect state.
- Added documentation for lifecycle, ownership, sticky-routing invariants, and timeout semantics.

`core/src/codex.rs`:

- Session startup invokes `model_client.pre_establish_connection(...)`.
- Turn metadata resolution uses the shared timeout helper.

`core/src/turn_metadata.rs`:

- Centralized shared timeout helper used by both turn-time metadata resolution and startup preconnect metadata building.

`core/tests/common/responses.rs` + websocket test suites:

- Added deterministic handshake waiting helper (`wait_for_handshakes`) with bounded polling.
- Added startup preconnect and in-flight preconnect reuse coverage.
- Fallback expectations now assert exactly two websocket attempts in covered scenarios (startup preconnect + turn attempt before fallback sticks).

## Observability

Preconnect remains best-effort and non-fatal. Existing websocket/fallback telemetry remains in place, and debug logs now make preconnect-await behavior and preconnect failures easier to reason about.

## Tests

Validated with:

1. `just fmt`
2. `cargo test -p codex-core websocket_preconnect -- --nocapture`
3. `cargo test -p codex-core websocket_fallback -- --nocapture`
4. `cargo test -p codex-core websocket_first_turn_waits_for_inflight_preconnect -- --nocapture`
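As a rough illustration of the single-slot lifecycle described under "Mental model" and "Architecture", the sketch below models `Idle` / `InFlight` / `Ready` as a plain synchronous state machine. It is not the code in `core/src/client.rs`: `WarmSocket`, `Preconnect`, and the methods `start`, `finish`, `adopt`, and `clear` are hypothetical names chosen for this example, and the real implementation awaits an async preconnect task instead of inspecting state synchronously.

```rust
use std::mem;

/// Hypothetical stand-in for a warmed websocket (assumption, not the real type).
#[derive(Debug, PartialEq)]
struct WarmSocket {
    /// Sticky turn-state captured during the handshake, if any.
    turn_state: Option<String>,
}

/// Session-level preconnect lifecycle: at most one warmed socket is held.
#[derive(Debug, PartialEq)]
enum Preconnect {
    Idle,
    InFlight,
    Ready(WarmSocket),
}

impl Preconnect {
    /// Begin a startup warmup; returns false if one is already pending or ready.
    fn start(&mut self) -> bool {
        if *self == Preconnect::Idle {
            *self = Preconnect::InFlight;
            true
        } else {
            false
        }
    }

    /// Record the handshake outcome. A failure is non-fatal and resets to `Idle`.
    fn finish(&mut self, result: Option<WarmSocket>) {
        if *self == Preconnect::InFlight {
            *self = match result {
                Some(socket) => Preconnect::Ready(socket),
                None => Preconnect::Idle,
            };
        }
    }

    /// Adopt the warmed socket at most once; the slot is `Idle` afterwards.
    fn adopt(&mut self) -> Option<WarmSocket> {
        match mem::replace(self, Preconnect::Idle) {
            Preconnect::Ready(socket) => Some(socket),
            other => {
                // Nothing ready: restore the previous `Idle` / `InFlight` state.
                *self = other;
                None
            }
        }
    }

    /// Fallback activation drops any warmed state.
    fn clear(&mut self) {
        *self = Preconnect::Idle;
    }
}
```

In this sketch a first turn would call `adopt()` and, on `None`, fall through to the shared connect path; fallback activation would call `clear()`. Because the slot holds a single socket, only the turn immediately after warmup can benefit, which matches the tradeoff stated above.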
1230 lines
38 KiB
Rust
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;

use anyhow::Result;
use base64::Engine;
use codex_protocol::openai_models::ModelsResponse;
use futures::SinkExt;
use futures::StreamExt;
use serde_json::Value;
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::tungstenite::handshake::server::Request;
use tokio_tungstenite::tungstenite::handshake::server::Response;
use wiremock::BodyPrintLimit;
use wiremock::Match;
use wiremock::Mock;
use wiremock::MockBuilder;
use wiremock::MockServer;
use wiremock::Respond;
use wiremock::ResponseTemplate;
use wiremock::http::HeaderName;
use wiremock::http::HeaderValue;
use wiremock::matchers::method;
use wiremock::matchers::path_regex;

use crate::test_codex::ApplyPatchModelOutput;

#[derive(Debug, Clone)]
pub struct ResponseMock {
    requests: Arc<Mutex<Vec<ResponsesRequest>>>,
}

impl ResponseMock {
    fn new() -> Self {
        Self {
            requests: Arc::new(Mutex::new(Vec::new())),
        }
    }

    pub fn single_request(&self) -> ResponsesRequest {
        let requests = self.requests.lock().unwrap();
        if requests.len() != 1 {
            panic!("expected 1 request, got {}", requests.len());
        }
        requests.first().unwrap().clone()
    }

    pub fn requests(&self) -> Vec<ResponsesRequest> {
        self.requests.lock().unwrap().clone()
    }

    pub fn last_request(&self) -> Option<ResponsesRequest> {
        self.requests.lock().unwrap().last().cloned()
    }

    /// Returns true if any captured request contains a `function_call` with the
    /// provided `call_id`.
    pub fn saw_function_call(&self, call_id: &str) -> bool {
        self.requests()
            .iter()
            .any(|req| req.has_function_call(call_id))
    }

    /// Returns the `output` string for a matching `function_call_output` with
    /// the provided `call_id`, searching across all captured requests.
    pub fn function_call_output_text(&self, call_id: &str) -> Option<String> {
        self.requests()
            .iter()
            .find_map(|req| req.function_call_output_text(call_id))
    }
}

#[derive(Debug, Clone)]
pub struct ResponsesRequest(wiremock::Request);

fn is_zstd_encoding(value: &str) -> bool {
    value
        .split(',')
        .any(|entry| entry.trim().eq_ignore_ascii_case("zstd"))
}

fn decode_body_bytes(body: &[u8], content_encoding: Option<&str>) -> Vec<u8> {
    if content_encoding.is_some_and(is_zstd_encoding) {
        zstd::stream::decode_all(std::io::Cursor::new(body)).unwrap_or_else(|err| {
            panic!("failed to decode zstd request body: {err}");
        })
    } else {
        body.to_vec()
    }
}

impl ResponsesRequest {
    pub fn body_json(&self) -> Value {
        let body = decode_body_bytes(
            &self.0.body,
            self.0
                .headers
                .get("content-encoding")
                .and_then(|value| value.to_str().ok()),
        );
        serde_json::from_slice(&body).unwrap()
    }

    pub fn body_bytes(&self) -> Vec<u8> {
        self.0.body.clone()
    }

    pub fn instructions_text(&self) -> String {
        self.body_json()["instructions"]
            .as_str()
            .unwrap()
            .to_string()
    }

    /// Returns all `input_text` spans from `message` inputs for the provided role.
    pub fn message_input_texts(&self, role: &str) -> Vec<String> {
        self.inputs_of_type("message")
            .into_iter()
            .filter(|item| item.get("role").and_then(Value::as_str) == Some(role))
            .filter_map(|item| item.get("content").and_then(Value::as_array).cloned())
            .flatten()
            .filter(|span| span.get("type").and_then(Value::as_str) == Some("input_text"))
            .filter_map(|span| span.get("text").and_then(Value::as_str).map(str::to_owned))
            .collect()
    }

    pub fn input(&self) -> Vec<Value> {
        self.body_json()["input"]
            .as_array()
            .expect("input array not found in request")
            .clone()
    }

    pub fn inputs_of_type(&self, ty: &str) -> Vec<Value> {
        self.input()
            .iter()
            .filter(|item| item.get("type").and_then(Value::as_str) == Some(ty))
            .cloned()
            .collect()
    }

    pub fn function_call_output(&self, call_id: &str) -> Value {
        self.call_output(call_id, "function_call_output")
    }

    pub fn custom_tool_call_output(&self, call_id: &str) -> Value {
        self.call_output(call_id, "custom_tool_call_output")
    }

    pub fn call_output(&self, call_id: &str, call_type: &str) -> Value {
        self.input()
            .iter()
            .find(|item| {
                item.get("type").unwrap() == call_type && item.get("call_id").unwrap() == call_id
            })
            .cloned()
            .unwrap_or_else(|| panic!("function call output {call_id} item not found in request"))
    }

    /// Returns true if this request's `input` contains a `function_call` with
    /// the specified `call_id`.
    pub fn has_function_call(&self, call_id: &str) -> bool {
        self.input().iter().any(|item| {
            item.get("type").and_then(Value::as_str) == Some("function_call")
                && item.get("call_id").and_then(Value::as_str) == Some(call_id)
        })
    }

    /// If present, returns the `output` string of the `function_call_output`
    /// entry matching `call_id` in this request's `input`.
    pub fn function_call_output_text(&self, call_id: &str) -> Option<String> {
        let binding = self.input();
        let item = binding.iter().find(|item| {
            item.get("type").and_then(Value::as_str) == Some("function_call_output")
                && item.get("call_id").and_then(Value::as_str) == Some(call_id)
        })?;
        item.get("output")
            .and_then(Value::as_str)
            .map(str::to_string)
    }

    pub fn function_call_output_content_and_success(
        &self,
        call_id: &str,
    ) -> Option<(Option<String>, Option<bool>)> {
        self.call_output_content_and_success(call_id, "function_call_output")
    }

    pub fn custom_tool_call_output_content_and_success(
        &self,
        call_id: &str,
    ) -> Option<(Option<String>, Option<bool>)> {
        self.call_output_content_and_success(call_id, "custom_tool_call_output")
    }

    fn call_output_content_and_success(
        &self,
        call_id: &str,
        call_type: &str,
    ) -> Option<(Option<String>, Option<bool>)> {
        let output = self
            .call_output(call_id, call_type)
            .get("output")
            .cloned()
            .unwrap_or(Value::Null);
        match output {
            Value::String(text) => Some((Some(text), None)),
            Value::Object(obj) => Some((
                obj.get("content")
                    .and_then(Value::as_str)
                    .map(str::to_string),
                obj.get("success").and_then(Value::as_bool),
            )),
            _ => Some((None, None)),
        }
    }

    pub fn header(&self, name: &str) -> Option<String> {
        self.0
            .headers
            .get(name)
            .and_then(|v| v.to_str().ok())
            .map(str::to_string)
    }

    pub fn path(&self) -> String {
        self.0.url.path().to_string()
    }

    pub fn query_param(&self, name: &str) -> Option<String> {
        self.0
            .url
            .query_pairs()
            .find(|(k, _)| k == name)
            .map(|(_, v)| v.to_string())
    }
}

#[derive(Debug, Clone)]
pub struct WebSocketRequest {
    body: Value,
}

impl WebSocketRequest {
    pub fn body_json(&self) -> Value {
        self.body.clone()
    }
}

#[derive(Debug, Clone)]
pub struct WebSocketHandshake {
    headers: Vec<(String, String)>,
}

impl WebSocketHandshake {
    pub fn header(&self, name: &str) -> Option<String> {
        self.headers
            .iter()
            .find(|(header, _)| header.eq_ignore_ascii_case(name))
            .map(|(_, value)| value.clone())
    }
}

#[derive(Debug, Clone)]
pub struct WebSocketConnectionConfig {
    pub requests: Vec<Vec<Value>>,
    pub response_headers: Vec<(String, String)>,
    /// Optional delay inserted before accepting the websocket handshake.
    ///
    /// Tests use this to force startup preconnect into an in-flight state so first-turn adoption
    /// paths can be exercised deterministically.
    pub accept_delay: Option<Duration>,
}

pub struct WebSocketTestServer {
    uri: String,
    connections: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
    handshakes: Arc<Mutex<Vec<WebSocketHandshake>>>,
    shutdown: oneshot::Sender<()>,
    task: tokio::task::JoinHandle<()>,
}

impl WebSocketTestServer {
    pub fn uri(&self) -> &str {
        &self.uri
    }

    pub fn connections(&self) -> Vec<Vec<WebSocketRequest>> {
        self.connections.lock().unwrap().clone()
    }

    pub fn single_connection(&self) -> Vec<WebSocketRequest> {
        let connections = self.connections.lock().unwrap();
        if connections.len() != 1 {
            panic!("expected 1 connection, got {}", connections.len());
        }
        connections.first().cloned().unwrap_or_default()
    }

    pub fn handshakes(&self) -> Vec<WebSocketHandshake> {
        self.handshakes.lock().unwrap().clone()
    }

    /// Waits until at least `expected` websocket handshakes have been observed or timeout elapses.
    ///
    /// Uses a short bounded polling interval so tests can deterministically wait for background
    /// preconnect activity without busy-spinning.
    pub async fn wait_for_handshakes(&self, expected: usize, timeout: Duration) -> bool {
        if self.handshakes.lock().unwrap().len() >= expected {
            return true;
        }

        let deadline = tokio::time::Instant::now() + timeout;
        let poll_interval = Duration::from_millis(10);
        loop {
            if self.handshakes.lock().unwrap().len() >= expected {
                return true;
            }
            let now = tokio::time::Instant::now();
            if now >= deadline {
                return false;
            }
            let sleep_for = std::cmp::min(poll_interval, deadline.saturating_duration_since(now));
            tokio::time::sleep(sleep_for).await;
        }
    }

    pub fn single_handshake(&self) -> WebSocketHandshake {
        let handshakes = self.handshakes.lock().unwrap();
        if handshakes.len() != 1 {
            panic!("expected 1 handshake, got {}", handshakes.len());
        }
        handshakes.first().cloned().unwrap()
    }

    pub async fn shutdown(self) {
        let _ = self.shutdown.send(());
        let _ = self.task.await;
    }
}

#[derive(Debug, Clone)]
pub struct ModelsMock {
    requests: Arc<Mutex<Vec<wiremock::Request>>>,
}

impl ModelsMock {
    fn new() -> Self {
        Self {
            requests: Arc::new(Mutex::new(Vec::new())),
        }
    }

    pub fn requests(&self) -> Vec<wiremock::Request> {
        self.requests.lock().unwrap().clone()
    }

    pub fn single_request_path(&self) -> String {
        let requests = self.requests.lock().unwrap();
        if requests.len() != 1 {
            panic!("expected 1 request, got {}", requests.len());
        }
        requests.first().unwrap().url.path().to_string()
    }
}

impl Match for ModelsMock {
    fn matches(&self, request: &wiremock::Request) -> bool {
        self.requests.lock().unwrap().push(request.clone());
        true
    }
}

impl Match for ResponseMock {
    fn matches(&self, request: &wiremock::Request) -> bool {
        self.requests
            .lock()
            .unwrap()
            .push(ResponsesRequest(request.clone()));

        // Enforce invariant checks on every request body captured by the mock.
        // Panic on orphan tool outputs or calls to catch regressions early.
        validate_request_body_invariants(request);
        true
    }
}

/// Build an SSE stream body from a list of JSON events.
pub fn sse(events: Vec<Value>) -> String {
    use std::fmt::Write as _;
    let mut out = String::new();
    for ev in events {
        let kind = ev.get("type").and_then(|v| v.as_str()).unwrap();
        writeln!(&mut out, "event: {kind}").unwrap();
        if !ev.as_object().map(|o| o.len() == 1).unwrap_or(false) {
            write!(&mut out, "data: {ev}\n\n").unwrap();
        } else {
            out.push('\n');
        }
    }
    out
}

pub fn sse_completed(id: &str) -> String {
    sse(vec![ev_response_created(id), ev_completed(id)])
}

/// Convenience: SSE event for a completed response with a specific id.
pub fn ev_completed(id: &str) -> Value {
    serde_json::json!({
        "type": "response.completed",
        "response": {
            "id": id,
            "usage": {"input_tokens":0,"input_tokens_details":null,"output_tokens":0,"output_tokens_details":null,"total_tokens":0}
        }
    })
}

pub fn ev_done() -> Value {
    serde_json::json!({
        "type": "response.done",
        "response": {
            "usage": {"input_tokens":0,"input_tokens_details":null,"output_tokens":0,"output_tokens_details":null,"total_tokens":0}
        }
    })
}

/// Convenience: SSE event for a created response with a specific id.
pub fn ev_response_created(id: &str) -> Value {
    serde_json::json!({
        "type": "response.created",
        "response": {
            "id": id,
        }
    })
}

pub fn ev_completed_with_tokens(id: &str, total_tokens: i64) -> Value {
    serde_json::json!({
        "type": "response.completed",
        "response": {
            "id": id,
            "usage": {
                "input_tokens": total_tokens,
                "input_tokens_details": null,
                "output_tokens": 0,
                "output_tokens_details": null,
                "total_tokens": total_tokens
            }
        }
    })
}

/// Convenience: SSE event for a single assistant message output item.
pub fn ev_assistant_message(id: &str, text: &str) -> Value {
    serde_json::json!({
        "type": "response.output_item.done",
        "item": {
            "type": "message",
            "role": "assistant",
            "id": id,
            "content": [{"type": "output_text", "text": text}]
        }
    })
}

pub fn ev_message_item_added(id: &str, text: &str) -> Value {
    serde_json::json!({
        "type": "response.output_item.added",
        "item": {
            "type": "message",
            "role": "assistant",
            "id": id,
            "content": [{"type": "output_text", "text": text}]
        }
    })
}

pub fn ev_output_text_delta(delta: &str) -> Value {
    serde_json::json!({
        "type": "response.output_text.delta",
        "delta": delta,
    })
}

pub fn ev_reasoning_item(id: &str, summary: &[&str], raw_content: &[&str]) -> Value {
    let summary_entries: Vec<Value> = summary
        .iter()
        .map(|text| serde_json::json!({"type": "summary_text", "text": text}))
        .collect();

    let overhead = "b".repeat(550);
    let raw_content_joined = raw_content.join("");
    let encrypted_content =
        base64::engine::general_purpose::STANDARD.encode(overhead + raw_content_joined.as_str());

    let mut event = serde_json::json!({
        "type": "response.output_item.done",
        "item": {
            "type": "reasoning",
            "id": id,
            "summary": summary_entries,
            "encrypted_content": encrypted_content,
        }
    });

    if !raw_content.is_empty() {
        let content_entries: Vec<Value> = raw_content
            .iter()
            .map(|text| serde_json::json!({"type": "reasoning_text", "text": text}))
            .collect();
        event["item"]["content"] = Value::Array(content_entries);
    }

    event
}

pub fn ev_reasoning_item_added(id: &str, summary: &[&str]) -> Value {
    let summary_entries: Vec<Value> = summary
        .iter()
        .map(|text| serde_json::json!({"type": "summary_text", "text": text}))
        .collect();

    serde_json::json!({
        "type": "response.output_item.added",
        "item": {
            "type": "reasoning",
            "id": id,
            "summary": summary_entries,
        }
    })
}

pub fn ev_reasoning_summary_text_delta(delta: &str) -> Value {
    serde_json::json!({
        "type": "response.reasoning_summary_text.delta",
        "delta": delta,
        "summary_index": 0,
    })
}

pub fn ev_reasoning_text_delta(delta: &str) -> Value {
    serde_json::json!({
        "type": "response.reasoning_text.delta",
        "delta": delta,
        "content_index": 0,
    })
}

pub fn ev_web_search_call_added_partial(id: &str, status: &str) -> Value {
    serde_json::json!({
        "type": "response.output_item.added",
        "item": {
            "type": "web_search_call",
            "id": id,
            "status": status
        }
    })
}

pub fn ev_web_search_call_done(id: &str, status: &str, query: &str) -> Value {
    serde_json::json!({
        "type": "response.output_item.done",
        "item": {
            "type": "web_search_call",
            "id": id,
            "status": status,
            "action": {"type": "search", "query": query}
        }
    })
}

pub fn ev_function_call(call_id: &str, name: &str, arguments: &str) -> Value {
    serde_json::json!({
        "type": "response.output_item.done",
        "item": {
            "type": "function_call",
            "call_id": call_id,
            "name": name,
            "arguments": arguments
        }
    })
}

pub fn ev_custom_tool_call(call_id: &str, name: &str, input: &str) -> Value {
    serde_json::json!({
        "type": "response.output_item.done",
        "item": {
            "type": "custom_tool_call",
            "call_id": call_id,
            "name": name,
            "input": input
        }
    })
}

pub fn ev_local_shell_call(call_id: &str, status: &str, command: Vec<&str>) -> Value {
    serde_json::json!({
        "type": "response.output_item.done",
        "item": {
            "type": "local_shell_call",
            "call_id": call_id,
            "status": status,
            "action": {
                "type": "exec",
                "command": command,
            }
        }
    })
}

pub fn ev_apply_patch_call(
    call_id: &str,
    patch: &str,
    output_type: ApplyPatchModelOutput,
) -> Value {
    match output_type {
        ApplyPatchModelOutput::Freeform => ev_apply_patch_custom_tool_call(call_id, patch),
        ApplyPatchModelOutput::Function => ev_apply_patch_function_call(call_id, patch),
        ApplyPatchModelOutput::Shell => ev_apply_patch_shell_call(call_id, patch),
        ApplyPatchModelOutput::ShellViaHeredoc => {
            ev_apply_patch_shell_call_via_heredoc(call_id, patch)
        }
        ApplyPatchModelOutput::ShellCommandViaHeredoc => {
            ev_apply_patch_shell_command_call_via_heredoc(call_id, patch)
        }
    }
}

/// Convenience: SSE event for an `apply_patch` custom tool call with raw patch
/// text. This mirrors the payload produced by the Responses API when the model
/// invokes `apply_patch` directly (before we convert it to a function call).
pub fn ev_apply_patch_custom_tool_call(call_id: &str, patch: &str) -> Value {
    serde_json::json!({
        "type": "response.output_item.done",
        "item": {
            "type": "custom_tool_call",
            "name": "apply_patch",
            "input": patch,
            "call_id": call_id
        }
    })
}

/// Convenience: SSE event for an `apply_patch` function call. The Responses API
/// wraps the patch content in a JSON string under the `input` key; we recreate
/// the same structure so downstream code exercises the full parsing path.
pub fn ev_apply_patch_function_call(call_id: &str, patch: &str) -> Value {
    let arguments = serde_json::json!({ "input": patch });
    let arguments = serde_json::to_string(&arguments).expect("serialize apply_patch arguments");

    serde_json::json!({
        "type": "response.output_item.done",
        "item": {
            "type": "function_call",
            "name": "apply_patch",
            "arguments": arguments,
            "call_id": call_id
        }
    })
}

pub fn ev_shell_command_call(call_id: &str, command: &str) -> Value {
    let args = serde_json::json!({ "command": command });
    ev_shell_command_call_with_args(call_id, &args)
}

pub fn ev_shell_command_call_with_args(call_id: &str, args: &serde_json::Value) -> Value {
    let arguments = serde_json::to_string(args).expect("serialize shell command arguments");
    ev_function_call(call_id, "shell_command", &arguments)
}

pub fn ev_apply_patch_shell_call(call_id: &str, patch: &str) -> Value {
    let args = serde_json::json!({ "command": ["apply_patch", patch] });
    let arguments = serde_json::to_string(&args).expect("serialize apply_patch arguments");

    ev_function_call(call_id, "shell", &arguments)
}

pub fn ev_apply_patch_shell_call_via_heredoc(call_id: &str, patch: &str) -> Value {
    let script = format!("apply_patch <<'EOF'\n{patch}\nEOF\n");
    let args = serde_json::json!({ "command": ["bash", "-lc", script] });
    let arguments = serde_json::to_string(&args).expect("serialize apply_patch arguments");

    ev_function_call(call_id, "shell", &arguments)
}

pub fn ev_apply_patch_shell_command_call_via_heredoc(call_id: &str, patch: &str) -> Value {
    let args = serde_json::json!({ "command": format!("apply_patch <<'EOF'\n{patch}\nEOF\n") });
    let arguments = serde_json::to_string(&args).expect("serialize apply_patch arguments");

    ev_function_call(call_id, "shell_command", &arguments)
}

pub fn sse_failed(id: &str, code: &str, message: &str) -> String {
    sse(vec![serde_json::json!({
        "type": "response.failed",
        "response": {
            "id": id,
            "error": {"code": code, "message": message}
        }
    })])
}

pub fn sse_response(body: String) -> ResponseTemplate {
    ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
        .set_body_raw(body, "text/event-stream")
}

pub async fn mount_response_once(server: &MockServer, response: ResponseTemplate) -> ResponseMock {
    let (mock, response_mock) = base_mock();
    mock.respond_with(response)
        .up_to_n_times(1)
        .mount(server)
        .await;
    response_mock
}

pub async fn mount_response_once_match<M>(
    server: &MockServer,
    matcher: M,
    response: ResponseTemplate,
) -> ResponseMock
where
    M: wiremock::Match + Send + Sync + 'static,
{
    let (mock, response_mock) = base_mock();
    mock.and(matcher)
        .respond_with(response)
        .up_to_n_times(1)
        .mount(server)
        .await;
    response_mock
}

fn base_mock() -> (MockBuilder, ResponseMock) {
    let response_mock = ResponseMock::new();
    let mock = Mock::given(method("POST"))
        .and(path_regex(".*/responses$"))
        .and(response_mock.clone());
    (mock, response_mock)
}

fn compact_mock() -> (MockBuilder, ResponseMock) {
    let response_mock = ResponseMock::new();
    let mock = Mock::given(method("POST"))
        .and(path_regex(".*/responses/compact$"))
        .and(response_mock.clone());
    (mock, response_mock)
}

fn models_mock() -> (MockBuilder, ModelsMock) {
    let models_mock = ModelsMock::new();
    let mock = Mock::given(method("GET"))
        .and(path_regex(".*/models$"))
        .and(models_mock.clone());
    (mock, models_mock)
}

pub async fn mount_sse_once_match<M>(server: &MockServer, matcher: M, body: String) -> ResponseMock
where
    M: wiremock::Match + Send + Sync + 'static,
{
    let (mock, response_mock) = base_mock();
    mock.and(matcher)
        .respond_with(sse_response(body))
        .up_to_n_times(1)
        .mount(server)
        .await;
    response_mock
}

pub async fn mount_sse_once(server: &MockServer, body: String) -> ResponseMock {
    let (mock, response_mock) = base_mock();
    mock.respond_with(sse_response(body))
        .up_to_n_times(1)
        .mount(server)
        .await;
    response_mock
}

pub async fn mount_compact_json_once_match<M>(
    server: &MockServer,
    matcher: M,
    body: serde_json::Value,
) -> ResponseMock
where
    M: wiremock::Match + Send + Sync + 'static,
{
    let (mock, response_mock) = compact_mock();
    mock.and(matcher)
        .respond_with(
            ResponseTemplate::new(200)
                .insert_header("content-type", "application/json")
                .set_body_json(body.clone()),
        )
        .up_to_n_times(1)
        .mount(server)
        .await;
    response_mock
}

pub async fn mount_compact_json_once(server: &MockServer, body: serde_json::Value) -> ResponseMock {
    let (mock, response_mock) = compact_mock();
    mock.respond_with(
        ResponseTemplate::new(200)
            .insert_header("content-type", "application/json")
            .set_body_json(body.clone()),
    )
    .up_to_n_times(1)
    .mount(server)
    .await;
    response_mock
}

pub async fn mount_models_once(server: &MockServer, body: ModelsResponse) -> ModelsMock {
    let (mock, models_mock) = models_mock();
    mock.respond_with(
        ResponseTemplate::new(200)
            .insert_header("content-type", "application/json")
            .set_body_json(body.clone()),
    )
    .up_to_n_times(1)
    .mount(server)
    .await;
    models_mock
}

pub async fn mount_models_once_with_delay(
    server: &MockServer,
    body: ModelsResponse,
    delay: Duration,
) -> ModelsMock {
    let (mock, models_mock) = models_mock();
    mock.respond_with(
        ResponseTemplate::new(200)
            .insert_header("content-type", "application/json")
            .set_body_json(body.clone())
            .set_delay(delay),
    )
    .up_to_n_times(1)
    .mount(server)
    .await;
    models_mock
}

pub async fn mount_models_once_with_etag(
    server: &MockServer,
    body: ModelsResponse,
    etag: &str,
) -> ModelsMock {
    let (mock, models_mock) = models_mock();
    mock.respond_with(
        ResponseTemplate::new(200)
            .insert_header("content-type", "application/json")
            // ModelsClient reads the ETag header, not a JSON field.
            .insert_header("ETag", etag)
            .set_body_json(body.clone()),
    )
    .up_to_n_times(1)
    .mount(server)
    .await;
    models_mock
}

pub async fn start_mock_server() -> MockServer {
    let server = MockServer::builder()
        .body_print_limit(BodyPrintLimit::Limited(80_000))
        .start()
        .await;

    // Provide a default `/models` response so tests remain hermetic when the client queries it.
    let _ = mount_models_once(&server, ModelsResponse { models: Vec::new() }).await;

    server
}

/// Starts a lightweight WebSocket server for `/v1/responses` tests.
///
/// Each connection consumes a queue of request/event sequences. For each
/// request message, the server records the payload and streams the matching
/// events as WebSocket text frames before moving to the next request.
pub async fn start_websocket_server(connections: Vec<Vec<Vec<Value>>>) -> WebSocketTestServer {
    let connections = connections
        .into_iter()
        .map(|requests| WebSocketConnectionConfig {
            requests,
            response_headers: Vec::new(),
            accept_delay: None,
        })
        .collect();
    start_websocket_server_with_headers(connections).await
}

pub async fn start_websocket_server_with_headers(
|
|
connections: Vec<WebSocketConnectionConfig>,
|
|
) -> WebSocketTestServer {
|
|
let listener = TcpListener::bind("127.0.0.1:0")
|
|
.await
|
|
.expect("bind websocket server");
|
|
let addr = listener.local_addr().expect("websocket server address");
|
|
let uri = format!("ws://{addr}");
|
|
let connections_log = Arc::new(Mutex::new(Vec::new()));
|
|
let handshakes_log = Arc::new(Mutex::new(Vec::new()));
|
|
let requests = Arc::clone(&connections_log);
|
|
let handshakes = Arc::clone(&handshakes_log);
|
|
let connections = Arc::new(Mutex::new(VecDeque::from(connections)));
|
|
let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
|
|
|
|
let task = tokio::spawn(async move {
|
|
loop {
|
|
let accept_res = tokio::select! {
|
|
_ = &mut shutdown_rx => return,
|
|
accept_res = listener.accept() => accept_res,
|
|
};
|
|
let (stream, _) = match accept_res {
|
|
Ok(value) => value,
|
|
Err(_) => return,
|
|
};
|
|
let connection = {
|
|
let mut pending = connections.lock().unwrap();
|
|
pending.pop_front()
|
|
};
|
|
|
|
let Some(connection) = connection else {
|
|
continue;
|
|
};
|
|
|
|
if let Some(delay) = connection.accept_delay {
|
|
tokio::time::sleep(delay).await;
|
|
}
|
|
|
|
let response_headers = connection.response_headers.clone();
|
|
let handshake_log = Arc::clone(&handshakes);
|
|
let callback = move |req: &Request, mut response: Response| {
|
|
let headers = req
|
|
.headers()
|
|
.iter()
|
|
.filter_map(|(name, value)| {
|
|
value
|
|
.to_str()
|
|
.ok()
|
|
.map(|value| (name.as_str().to_string(), value.to_string()))
|
|
})
|
|
.collect();
|
|
handshake_log
|
|
.lock()
|
|
.unwrap()
|
|
.push(WebSocketHandshake { headers });
|
|
|
|
let headers_mut = response.headers_mut();
|
|
for (name, value) in &response_headers {
|
|
if let (Ok(name), Ok(value)) = (
|
|
HeaderName::from_bytes(name.as_bytes()),
|
|
HeaderValue::from_str(value),
|
|
) {
|
|
headers_mut.insert(name, value);
|
|
}
|
|
}
|
|
|
|
Ok(response)
|
|
};
|
|
|
|
            let mut ws_stream = match tokio_tungstenite::accept_hdr_async(stream, callback).await {
                Ok(ws) => ws,
                Err(_) => continue,
            };

            // Reserve a log slot for this connection, even if it never sends a request.
            let connection_index = {
                let mut log = requests.lock().unwrap();
                log.push(Vec::new());
                log.len() - 1
            };
            // Each scripted request waits for one client message, logs its
            // body if it parses as JSON, then replies with the scripted events.
            for request_events in connection.requests {
                let Some(Ok(message)) = ws_stream.next().await else {
                    break;
                };
                if let Some(body) = parse_ws_request_body(message) {
                    let mut log = requests.lock().unwrap();
                    if let Some(connection_log) = log.get_mut(connection_index) {
                        connection_log.push(WebSocketRequest { body });
                    }
                }

                for event in &request_events {
                    let Ok(payload) = serde_json::to_string(event) else {
                        continue;
                    };
                    if ws_stream.send(Message::Text(payload.into())).await.is_err() {
                        break;
                    }
                }
            }

            let _ = ws_stream.close(None).await;

            // Exit once every scripted connection has been served.
            if connections.lock().unwrap().is_empty() {
                return;
            }
        }
    });

    WebSocketTestServer {
        uri,
        connections: connections_log,
        handshakes: handshakes_log,
        shutdown: shutdown_tx,
        task,
    }
}

fn parse_ws_request_body(message: Message) -> Option<Value> {
    match message {
        Message::Text(text) => serde_json::from_str(&text).ok(),
        Message::Binary(bytes) => serde_json::from_slice(&bytes).ok(),
        _ => None,
    }
}

#[derive(Clone)]
pub struct FunctionCallResponseMocks {
    pub function_call: ResponseMock,
    pub completion: ResponseMock,
}

pub async fn mount_function_call_agent_response(
    server: &MockServer,
    call_id: &str,
    arguments: &str,
    tool_name: &str,
) -> FunctionCallResponseMocks {
    let first_response = sse(vec![
        ev_response_created("resp-1"),
        ev_function_call(call_id, tool_name, arguments),
        ev_completed("resp-1"),
    ]);
    let function_call = mount_sse_once(server, first_response).await;

    let second_response = sse(vec![
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    let completion = mount_sse_once(server, second_response).await;

    FunctionCallResponseMocks {
        function_call,
        completion,
    }
}

/// Mounts a sequence of SSE response bodies and serves them in order for each
/// POST to `/v1/responses`. Panics if more requests are received than bodies
/// provided. Also asserts the exact number of expected calls.
pub async fn mount_sse_sequence(server: &MockServer, bodies: Vec<String>) -> ResponseMock {
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    struct SeqResponder {
        num_calls: AtomicUsize,
        responses: Vec<String>,
    }

    impl Respond for SeqResponder {
        fn respond(&self, _: &wiremock::Request) -> ResponseTemplate {
            let call_num = self.num_calls.fetch_add(1, Ordering::SeqCst);
            match self.responses.get(call_num) {
                Some(body) => ResponseTemplate::new(200)
                    .insert_header("content-type", "text/event-stream")
                    .set_body_string(body.clone()),
                None => panic!("no response for {call_num}"),
            }
        }
    }

    let num_calls = bodies.len();
    let responder = SeqResponder {
        num_calls: AtomicUsize::new(0),
        responses: bodies,
    };

    let (mock, response_mock) = base_mock();
    mock.respond_with(responder)
        .up_to_n_times(num_calls as u64)
        .expect(num_calls as u64)
        .mount(server)
        .await;

    response_mock
}

/// Mounts a sequence of responses and serves them in order for each POST to
/// `/v1/responses`. Panics if more requests are received than responses
/// provided. Also asserts the exact number of expected calls.
pub async fn mount_response_sequence(
    server: &MockServer,
    responses: Vec<ResponseTemplate>,
) -> ResponseMock {
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    struct SeqResponder {
        num_calls: AtomicUsize,
        responses: Vec<ResponseTemplate>,
    }

    impl Respond for SeqResponder {
        fn respond(&self, _: &wiremock::Request) -> ResponseTemplate {
            let call_num = self.num_calls.fetch_add(1, Ordering::SeqCst);
            self.responses
                .get(call_num)
                .unwrap_or_else(|| panic!("no response for {call_num}"))
                .clone()
        }
    }

    let num_calls = responses.len();
    let responder = SeqResponder {
        num_calls: AtomicUsize::new(0),
        responses,
    };

    let (mock, response_mock) = base_mock();
    mock.respond_with(responder)
        .up_to_n_times(num_calls as u64)
        .expect(num_calls as u64)
        .mount(server)
        .await;
    response_mock
}

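// Illustrative sketch, not part of the helpers in this file: the counter
// pattern used by the sequence responders, reduced to the standard library.
// `SequencedBodies` and `next_body` are hypothetical names introduced only
// for this example. Each call claims the next index with an atomic
// `fetch_add`, so every call gets a distinct body; once the list is
// exhausted this sketch returns `None`, where the real responders panic.
struct SequencedBodies {
    next_index: std::sync::atomic::AtomicUsize,
    bodies: Vec<String>,
}

impl SequencedBodies {
    fn new(bodies: Vec<String>) -> Self {
        Self {
            next_index: std::sync::atomic::AtomicUsize::new(0),
            bodies,
        }
    }

    // Returns the body for the next call, or `None` when no bodies remain.
    fn next_body(&self) -> Option<&str> {
        let index = self
            .next_index
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        self.bodies.get(index).map(String::as_str)
    }
}
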
/// Validate invariants on the request body sent to `/v1/responses`.
///
/// - No `function_call_output`/`custom_tool_call_output` with missing/empty `call_id`.
/// - Every `function_call_output` must match a prior `function_call` or
///   `local_shell_call` with the same `call_id` in the same `input`.
/// - Every `custom_tool_call_output` must match a prior `custom_tool_call`.
/// - Additionally, enforce symmetry: every `function_call`/`custom_tool_call`
///   in the `input` must have a matching output entry.
fn validate_request_body_invariants(request: &wiremock::Request) {
    // Only validate POSTs to a `/responses` path; skip everything else
    // (e.g., GET /models).
    if request.method != "POST" || !request.url.path().ends_with("/responses") {
        return;
    }
    let body_bytes = decode_body_bytes(
        &request.body,
        request
            .headers
            .get("content-encoding")
            .and_then(|value| value.to_str().ok()),
    );
    // Bodies that are not valid JSON are not validated.
    let Ok(body): Result<Value, _> = serde_json::from_slice(&body_bytes) else {
        return;
    };
    let Some(items) = body.get("input").and_then(Value::as_array) else {
        panic!("input array not found in request");
    };

    use std::collections::HashSet;

    fn get_call_id(item: &Value) -> Option<&str> {
        item.get("call_id")
            .and_then(Value::as_str)
            .filter(|id| !id.is_empty())
    }

    fn gather_ids(items: &[Value], kind: &str) -> HashSet<String> {
        items
            .iter()
            .filter(|item| item.get("type").and_then(Value::as_str) == Some(kind))
            .filter_map(get_call_id)
            .map(str::to_string)
            .collect()
    }

    fn gather_output_ids(items: &[Value], kind: &str, missing_msg: &str) -> HashSet<String> {
        items
            .iter()
            .filter(|item| item.get("type").and_then(Value::as_str) == Some(kind))
            .map(|item| {
                let Some(id) = get_call_id(item) else {
                    panic!("{missing_msg}");
                };
                id.to_string()
            })
            .collect()
    }

    let function_calls = gather_ids(items, "function_call");
    let custom_tool_calls = gather_ids(items, "custom_tool_call");
    let local_shell_calls = gather_ids(items, "local_shell_call");
    let function_call_outputs = gather_output_ids(
        items,
        "function_call_output",
        "orphan function_call_output with empty call_id should be dropped",
    );
    let custom_tool_call_outputs = gather_output_ids(
        items,
        "custom_tool_call_output",
        "orphan custom_tool_call_output with empty call_id should be dropped",
    );

    for cid in &function_call_outputs {
        assert!(
            function_calls.contains(cid) || local_shell_calls.contains(cid),
            "function_call_output without matching call in input: {cid}",
        );
    }
    for cid in &custom_tool_call_outputs {
        assert!(
            custom_tool_calls.contains(cid),
            "custom_tool_call_output without matching call in input: {cid}",
        );
    }

    for cid in &function_calls {
        assert!(
            function_call_outputs.contains(cid),
            "Function call output is missing for call id: {cid}",
        );
    }
    for cid in &custom_tool_calls {
        assert!(
            custom_tool_call_outputs.contains(cid),
            "Custom tool call output is missing for call id: {cid}",
        );
    }
}
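
// Illustrative sketch, not used by `validate_request_body_invariants`: the
// symmetric pairing rule it enforces, reduced to `(type, call_id)` string
// pairs so it runs without `serde_json`. `unmatched_call_ids` and `ids_of`
// are hypothetical names introduced only for this example. The sketch covers
// just the `function_call` / `function_call_output` pairing (it ignores
// `local_shell_call` and custom tool calls) and returns the offending ids,
// sorted, instead of panicking.
fn unmatched_call_ids(items: &[(&str, &str)]) -> Vec<String> {
    use std::collections::BTreeSet;

    // Collects the call ids of every item whose type equals `kind`.
    fn ids_of<'a>(items: &[(&str, &'a str)], kind: &str) -> BTreeSet<&'a str> {
        items
            .iter()
            .filter(|(item_kind, _)| *item_kind == kind)
            .map(|(_, id)| *id)
            .collect()
    }

    let calls = ids_of(items, "function_call");
    let outputs = ids_of(items, "function_call_output");
    // An id present on only one side is either a call without an output or
    // an output without a call.
    calls
        .symmetric_difference(&outputs)
        .map(|id| (*id).to_string())
        .collect()
}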