Reuse connection between turns (#12294)

Add a pool of one to the model client to reuse connections across turns.
This commit is contained in:
pakrym-oai
2026-02-20 10:09:46 -08:00
committed by GitHub
parent 035c4c30bb
commit 86803ca9bf
3 changed files with 69 additions and 1 deletions

View File

@@ -175,6 +175,19 @@ pub struct ResponsesWebsocketConnection {
telemetry: Option<Arc<dyn WebsocketTelemetry>>,
}
impl std::fmt::Debug for ResponsesWebsocketConnection {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ResponsesWebsocketConnection")
.field("stream", &"<ws-stream>")
.field("idle_timeout", &self.idle_timeout)
.field("server_reasoning_included", &self.server_reasoning_included)
.field("models_etag", &self.models_etag)
.field("server_model", &self.server_model)
.field("telemetry", &self.telemetry.as_ref().map(|_| "<telemetry>"))
.finish()
}
}
impl ResponsesWebsocketConnection {
fn new(
stream: WsStream,

View File

@@ -28,6 +28,7 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::OnceLock;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
@@ -123,6 +124,7 @@ struct ModelClientState {
include_timing_metrics: bool,
beta_features_header: Option<String>,
disable_websockets: AtomicBool,
cached_websocket_connection: StdMutex<Option<ApiWebSocketConnection>>,
}
/// Resolved API client setup for a single request attempt.
@@ -228,6 +230,7 @@ impl ModelClient {
include_timing_metrics,
beta_features_header,
disable_websockets: AtomicBool::new(false),
cached_websocket_connection: StdMutex::new(None),
}),
}
}
@@ -239,13 +242,29 @@ impl ModelClient {
pub fn new_session(&self) -> ModelClientSession {
ModelClientSession {
client: self.clone(),
connection: None,
connection: self.take_cached_websocket_connection(),
websocket_last_request: None,
websocket_last_response_rx: None,
turn_state: Arc::new(OnceLock::new()),
}
}
fn take_cached_websocket_connection(&self) -> Option<ApiWebSocketConnection> {
self.state
.cached_websocket_connection
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.take()
}
fn store_cached_websocket_connection(&self, connection: ApiWebSocketConnection) {
*self
.state
.cached_websocket_connection
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner) = Some(connection);
}
/// Compacts the current conversation history using the Compact endpoint.
///
/// This is a unary call (no streaming) that returns a new list of
@@ -452,6 +471,14 @@ impl ModelClient {
}
}
impl Drop for ModelClientSession {
fn drop(&mut self) {
if let Some(connection) = self.connection.take() {
self.client.store_cached_websocket_connection(connection);
}
}
}
impl ModelClientSession {
fn activate_http_fallback(&self, websocket_enabled: bool) -> bool {
websocket_enabled

View File

@@ -118,6 +118,34 @@ async fn responses_websocket_preconnect_reuses_connection() {
server.shutdown().await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_reuses_connection_after_session_drop() {
skip_if_no_network!();
let server = start_websocket_server(vec![vec![
vec![ev_response_created("resp-1"), ev_completed("resp-1")],
vec![ev_response_created("resp-2"), ev_completed("resp-2")],
]])
.await;
let harness = websocket_harness(&server).await;
let prompt_one = prompt_with_input(vec![message_item("hello")]);
let prompt_two = prompt_with_input(vec![message_item("again")]);
{
let mut client_session = harness.client.new_session();
stream_until_complete(&mut client_session, &harness, &prompt_one).await;
}
let mut client_session = harness.client.new_session();
stream_until_complete(&mut client_session, &harness, &prompt_two).await;
assert_eq!(server.handshakes().len(), 1);
assert_eq!(server.single_connection().len(), 2);
server.shutdown().await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_preconnect_is_reused_even_with_header_changes() {
skip_if_no_network!();