Turn-state sticky routing per turn (#9332)

- capture the header from SSE/WS handshakes, store it per
ModelClientSession using `Oncelock`, echo it on turn-scoped requests,
and add SSE+WS integration tests for within-turn persistence +
cross-turn reset.

- keep `x-codex-turn-state` sticky within a user turn to maintain
routing continuity for retries/tool follow-ups.
This commit is contained in:
Ahmed Ibrahim
2026-01-16 09:30:11 -08:00
committed by GitHub
parent 4125c825f9
commit ebdd8795e9
11 changed files with 343 additions and 24 deletions

View File

@@ -19,6 +19,7 @@ use codex_protocol::protocol::SessionSource;
use http::HeaderMap;
use serde_json::Value;
use std::sync::Arc;
use std::sync::OnceLock;
use tracing::instrument;
pub struct ResponsesClient<T: HttpTransport, A: AuthProvider> {
@@ -36,6 +37,7 @@ pub struct ResponsesOptions {
pub session_source: Option<SessionSource>,
pub extra_headers: HeaderMap,
pub compression: Compression,
pub turn_state: Option<Arc<OnceLock<String>>>,
}
impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
@@ -58,9 +60,15 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
pub async fn stream_request(
&self,
request: ResponsesRequest,
turn_state: Option<Arc<OnceLock<String>>>,
) -> Result<ResponseStream, ApiError> {
self.stream(request.body, request.headers, request.compression)
.await
self.stream(
request.body,
request.headers,
request.compression,
turn_state,
)
.await
}
#[instrument(level = "trace", skip_all, err)]
@@ -80,6 +88,7 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
session_source,
extra_headers,
compression,
turn_state,
} = options;
let request = ResponsesRequestBuilder::new(model, &prompt.instructions, &prompt.input)
@@ -96,7 +105,7 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
.compression(compression)
.build(self.streaming.provider())?;
self.stream_request(request).await
self.stream_request(request, turn_state).await
}
fn path(&self) -> &'static str {
@@ -111,6 +120,7 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
body: Value,
extra_headers: HeaderMap,
compression: Compression,
turn_state: Option<Arc<OnceLock<String>>>,
) -> Result<ResponseStream, ApiError> {
let compression = match compression {
Compression::None => RequestCompression::None,
@@ -124,6 +134,7 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
extra_headers,
compression,
spawn_response_stream,
turn_state,
)
.await
}