Turn-state sticky routing per turn (#9332)

- capture the header from SSE/WS handshakes, store it per
ModelClientSession using `Oncelock`, echo it on turn-scoped requests,
and add SSE+WS integration tests for within-turn persistence +
cross-turn reset.

- keep `x-codex-turn-state` sticky within a user turn to maintain
routing continuity for retries/tool follow-ups.
This commit is contained in:
Ahmed Ibrahim
2026-01-16 09:30:11 -08:00
committed by GitHub
parent 4125c825f9
commit ebdd8795e9
11 changed files with 343 additions and 24 deletions

View File

@@ -12,6 +12,8 @@ use serde_json::Value;
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::tungstenite::handshake::server::Request;
use tokio_tungstenite::tungstenite::handshake::server::Response;
use wiremock::BodyPrintLimit;
use wiremock::Match;
use wiremock::Mock;
@@ -19,6 +21,8 @@ use wiremock::MockBuilder;
use wiremock::MockServer;
use wiremock::Respond;
use wiremock::ResponseTemplate;
use wiremock::http::HeaderName;
use wiremock::http::HeaderValue;
use wiremock::matchers::method;
use wiremock::matchers::path_regex;
@@ -216,9 +220,30 @@ impl WebSocketRequest {
}
}
#[derive(Debug, Clone)]
pub struct WebSocketHandshake {
headers: Vec<(String, String)>,
}
impl WebSocketHandshake {
pub fn header(&self, name: &str) -> Option<String> {
self.headers
.iter()
.find(|(header, _)| header.eq_ignore_ascii_case(name))
.map(|(_, value)| value.clone())
}
}
#[derive(Debug, Clone)]
pub struct WebSocketConnectionConfig {
pub requests: Vec<Vec<Value>>,
pub response_headers: Vec<(String, String)>,
}
pub struct WebSocketTestServer {
uri: String,
connections: Arc<Mutex<Vec<Vec<WebSocketRequest>>>>,
handshakes: Arc<Mutex<Vec<WebSocketHandshake>>>,
shutdown: oneshot::Sender<()>,
task: tokio::task::JoinHandle<()>,
}
@@ -240,6 +265,18 @@ impl WebSocketTestServer {
connections.first().cloned().unwrap_or_default()
}
pub fn handshakes(&self) -> Vec<WebSocketHandshake> {
self.handshakes.lock().unwrap().clone()
}
pub fn single_handshake(&self) -> WebSocketHandshake {
let handshakes = self.handshakes.lock().unwrap();
if handshakes.len() != 1 {
panic!("expected 1 handshake, got {}", handshakes.len());
}
handshakes.first().cloned().unwrap()
}
pub async fn shutdown(self) {
let _ = self.shutdown.send(());
let _ = self.task.await;
@@ -786,13 +823,28 @@ pub async fn start_mock_server() -> MockServer {
/// request message, the server records the payload and streams the matching
/// events as WebSocket text frames before moving to the next request.
pub async fn start_websocket_server(connections: Vec<Vec<Vec<Value>>>) -> WebSocketTestServer {
let connections = connections
.into_iter()
.map(|requests| WebSocketConnectionConfig {
requests,
response_headers: Vec::new(),
})
.collect();
start_websocket_server_with_headers(connections).await
}
pub async fn start_websocket_server_with_headers(
connections: Vec<WebSocketConnectionConfig>,
) -> WebSocketTestServer {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("bind websocket server");
let addr = listener.local_addr().expect("websocket server address");
let uri = format!("ws://{addr}");
let connections_log = Arc::new(Mutex::new(Vec::new()));
let handshakes_log = Arc::new(Mutex::new(Vec::new()));
let requests = Arc::clone(&connections_log);
let handshakes = Arc::clone(&handshakes_log);
let connections = Arc::new(Mutex::new(VecDeque::from(connections)));
let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
@@ -806,27 +858,57 @@ pub async fn start_websocket_server(connections: Vec<Vec<Vec<Value>>>) -> WebSoc
Ok(value) => value,
Err(_) => return,
};
let mut ws_stream = match tokio_tungstenite::accept_async(stream).await {
Ok(ws) => ws,
Err(_) => continue,
};
let connection_requests = {
let connection = {
let mut pending = connections.lock().unwrap();
pending.pop_front()
};
let Some(connection_requests) = connection_requests else {
let _ = ws_stream.close(None).await;
let Some(connection) = connection else {
continue;
};
let response_headers = connection.response_headers.clone();
let handshake_log = Arc::clone(&handshakes);
let callback = move |req: &Request, mut response: Response| {
let headers = req
.headers()
.iter()
.filter_map(|(name, value)| {
value
.to_str()
.ok()
.map(|value| (name.as_str().to_string(), value.to_string()))
})
.collect();
handshake_log
.lock()
.unwrap()
.push(WebSocketHandshake { headers });
let headers_mut = response.headers_mut();
for (name, value) in &response_headers {
if let (Ok(name), Ok(value)) = (
HeaderName::from_bytes(name.as_bytes()),
HeaderValue::from_str(value),
) {
headers_mut.insert(name, value);
}
}
Ok(response)
};
let mut ws_stream = match tokio_tungstenite::accept_hdr_async(stream, callback).await {
Ok(ws) => ws,
Err(_) => continue,
};
let connection_index = {
let mut log = requests.lock().unwrap();
log.push(Vec::new());
log.len() - 1
};
for request_events in connection_requests {
for request_events in connection.requests {
let Some(Ok(message)) = ws_stream.next().await else {
break;
};
@@ -858,6 +940,7 @@ pub async fn start_websocket_server(connections: Vec<Vec<Vec<Value>>>) -> WebSoc
WebSocketTestServer {
uri,
connections: connections_log,
handshakes: handshakes_log,
shutdown: shutdown_tx,
task,
}
@@ -942,6 +1025,45 @@ pub async fn mount_sse_sequence(server: &MockServer, bodies: Vec<String>) -> Res
response_mock
}
/// Mounts a sequence of responses for each POST to `/v1/responses`.
/// Panics if more requests are received than responses provided.
pub async fn mount_response_sequence(
server: &MockServer,
responses: Vec<ResponseTemplate>,
) -> ResponseMock {
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
struct SeqResponder {
num_calls: AtomicUsize,
responses: Vec<ResponseTemplate>,
}
impl Respond for SeqResponder {
fn respond(&self, _: &wiremock::Request) -> ResponseTemplate {
let call_num = self.num_calls.fetch_add(1, Ordering::SeqCst);
self.responses
.get(call_num)
.unwrap_or_else(|| panic!("no response for {call_num}"))
.clone()
}
}
let num_calls = responses.len();
let responder = SeqResponder {
num_calls: AtomicUsize::new(0),
responses,
};
let (mock, response_mock) = base_mock();
mock.respond_with(responder)
.up_to_n_times(num_calls as u64)
.expect(num_calls as u64)
.mount(server)
.await;
response_mock
}
/// Validate invariants on the request body sent to `/v1/responses`.
///
/// - No `function_call_output`/`custom_tool_call_output` with missing/empty `call_id`.

View File

@@ -0,0 +1,122 @@
#![allow(clippy::expect_used, clippy::unwrap_used)]
use anyhow::Result;
use core_test_support::responses::WebSocketConnectionConfig;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_done;
use core_test_support::responses::ev_reasoning_item;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::ev_shell_command_call;
use core_test_support::responses::mount_response_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::sse_response;
use core_test_support::responses::start_mock_server;
use core_test_support::responses::start_websocket_server_with_headers;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::test_codex;
use pretty_assertions::assert_eq;
const TURN_STATE_HEADER: &str = "x-codex-turn-state";
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_turn_state_persists_within_turn_and_resets_after() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = start_mock_server().await;
let call_id = "shell-turn-state";
let first_response = sse(vec![
ev_response_created("resp-1"),
ev_reasoning_item("rsn-1", &["thinking"], &[]),
ev_shell_command_call(call_id, "echo turn-state"),
ev_completed("resp-1"),
]);
let second_response = sse(vec![
ev_response_created("resp-2"),
ev_assistant_message("msg-1", "done"),
ev_completed("resp-2"),
]);
let third_response = sse(vec![
ev_response_created("resp-3"),
ev_assistant_message("msg-2", "done"),
ev_completed("resp-3"),
]);
// First response sets turn_state; follow-up request in the same turn should echo it.
let responses = vec![
sse_response(first_response).insert_header(TURN_STATE_HEADER, "ts-1"),
sse_response(second_response),
sse_response(third_response),
];
let request_log = mount_response_sequence(&server, responses).await;
let test = test_codex().build(&server).await?;
test.submit_turn("run a shell command").await?;
test.submit_turn("second turn").await?;
let requests = request_log.requests();
assert_eq!(requests.len(), 3);
// Initial turn request has no header; follow-up has it; next turn clears it.
assert_eq!(requests[0].header(TURN_STATE_HEADER), None);
assert_eq!(
requests[1].header(TURN_STATE_HEADER),
Some("ts-1".to_string())
);
assert_eq!(requests[2].header(TURN_STATE_HEADER), None);
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_turn_state_persists_within_turn_and_resets_after() -> Result<()> {
skip_if_no_network!(Ok(()));
let call_id = "ws-shell-turn-state";
// First connection delivers turn_state; second (same turn) must send it; third (new turn) must not.
let server = start_websocket_server_with_headers(vec![
WebSocketConnectionConfig {
requests: vec![vec![
ev_response_created("resp-1"),
ev_reasoning_item("rsn-1", &["thinking"], &[]),
ev_shell_command_call(call_id, "echo websocket"),
ev_done(),
]],
response_headers: vec![(TURN_STATE_HEADER.to_string(), "ts-1".to_string())],
},
WebSocketConnectionConfig {
requests: vec![vec![
ev_response_created("resp-2"),
ev_assistant_message("msg-1", "done"),
ev_completed("resp-2"),
]],
response_headers: Vec::new(),
},
WebSocketConnectionConfig {
requests: vec![vec![
ev_response_created("resp-3"),
ev_assistant_message("msg-2", "done"),
ev_completed("resp-3"),
]],
response_headers: Vec::new(),
},
])
.await;
let mut builder = test_codex();
let test = builder.build_with_websocket_server(&server).await?;
test.submit_turn("run the echo command").await?;
test.submit_turn("second turn").await?;
let handshakes = server.handshakes();
assert_eq!(handshakes.len(), 3);
assert_eq!(handshakes[0].header(TURN_STATE_HEADER), None);
assert_eq!(
handshakes[1].header(TURN_STATE_HEADER),
Some("ts-1".to_string())
);
assert_eq!(handshakes[2].header(TURN_STATE_HEADER), None);
server.shutdown().await;
Ok(())
}