Reuse connection between turns (#12294)

Add a pool of one to the model client to reuse connections across turns.
This commit is contained in:
pakrym-oai
2026-02-20 10:09:46 -08:00
committed by GitHub
parent 035c4c30bb
commit 86803ca9bf
3 changed files with 69 additions and 1 deletions

View File

@@ -175,6 +175,19 @@ pub struct ResponsesWebsocketConnection {
telemetry: Option<Arc<dyn WebsocketTelemetry>>, telemetry: Option<Arc<dyn WebsocketTelemetry>>,
} }
/// Hand-written `Debug` output for the websocket connection.
///
/// The stream is rendered as a fixed placeholder string and the telemetry
/// hook is rendered only as present/absent; the remaining fields are
/// formatted through their own `Debug` implementations.
impl std::fmt::Debug for ResponsesWebsocketConnection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Some("<telemetry>")` when a telemetry sink is attached, `None` otherwise.
        let telemetry_marker = self.telemetry.as_ref().map(|_| "<telemetry>");

        let mut builder = f.debug_struct("ResponsesWebsocketConnection");
        builder.field("stream", &"<ws-stream>");
        builder.field("idle_timeout", &self.idle_timeout);
        builder.field("server_reasoning_included", &self.server_reasoning_included);
        builder.field("models_etag", &self.models_etag);
        builder.field("server_model", &self.server_model);
        builder.field("telemetry", &telemetry_marker);
        builder.finish()
    }
}
impl ResponsesWebsocketConnection { impl ResponsesWebsocketConnection {
fn new( fn new(
stream: WsStream, stream: WsStream,

View File

@@ -28,6 +28,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::OnceLock; use std::sync::OnceLock;
use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
@@ -123,6 +124,7 @@ struct ModelClientState {
include_timing_metrics: bool, include_timing_metrics: bool,
beta_features_header: Option<String>, beta_features_header: Option<String>,
disable_websockets: AtomicBool, disable_websockets: AtomicBool,
cached_websocket_connection: StdMutex<Option<ApiWebSocketConnection>>,
} }
/// Resolved API client setup for a single request attempt. /// Resolved API client setup for a single request attempt.
@@ -228,6 +230,7 @@ impl ModelClient {
include_timing_metrics, include_timing_metrics,
beta_features_header, beta_features_header,
disable_websockets: AtomicBool::new(false), disable_websockets: AtomicBool::new(false),
cached_websocket_connection: StdMutex::new(None),
}), }),
} }
} }
@@ -239,13 +242,29 @@ impl ModelClient {
pub fn new_session(&self) -> ModelClientSession { pub fn new_session(&self) -> ModelClientSession {
ModelClientSession { ModelClientSession {
client: self.clone(), client: self.clone(),
connection: None, connection: self.take_cached_websocket_connection(),
websocket_last_request: None, websocket_last_request: None,
websocket_last_response_rx: None, websocket_last_response_rx: None,
turn_state: Arc::new(OnceLock::new()), turn_state: Arc::new(OnceLock::new()),
} }
} }
/// Removes and returns the cached websocket connection, leaving the slot empty.
///
/// Returns `None` when nothing is cached. A poisoned mutex is recovered
/// rather than propagated, so this never panics on lock acquisition.
fn take_cached_websocket_connection(&self) -> Option<ApiWebSocketConnection> {
    let mut slot = match self.state.cached_websocket_connection.lock() {
        Ok(guard) => guard,
        // A panic elsewhere while holding the lock does not invalidate the
        // slot itself, so keep using the inner value.
        Err(poisoned) => poisoned.into_inner(),
    };
    slot.take()
}
/// Places `connection` in the single-entry cache, replacing (and dropping)
/// whatever was cached before.
///
/// A poisoned mutex is recovered rather than propagated.
fn store_cached_websocket_connection(&self, connection: ApiWebSocketConnection) {
    let mut slot = self
        .state
        .cached_websocket_connection
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    // Overwrites any previous entry; the cache holds at most one connection.
    *slot = Some(connection);
}
/// Compacts the current conversation history using the Compact endpoint. /// Compacts the current conversation history using the Compact endpoint.
/// ///
/// This is a unary call (no streaming) that returns a new list of /// This is a unary call (no streaming) that returns a new list of
@@ -452,6 +471,14 @@ impl ModelClient {
} }
} }
/// On drop, a session that still owns a websocket connection hands it back
/// to its client so that a later session can pick it up.
impl Drop for ModelClientSession {
    fn drop(&mut self) {
        // Nothing to return if the session never held (or already gave up)
        // a connection.
        let Some(connection) = self.connection.take() else {
            return;
        };
        // NOTE(review): the connection is cached unconditionally here —
        // confirm that a session dropped mid-stream cannot hand back a
        // connection with an unfinished response.
        self.client.store_cached_websocket_connection(connection);
    }
}
impl ModelClientSession { impl ModelClientSession {
fn activate_http_fallback(&self, websocket_enabled: bool) -> bool { fn activate_http_fallback(&self, websocket_enabled: bool) -> bool {
websocket_enabled websocket_enabled

View File

@@ -118,6 +118,34 @@ async fn responses_websocket_preconnect_reuses_connection() {
server.shutdown().await; server.shutdown().await;
} }
/// Verifies that a second session created after the first one is dropped
/// streams over the same websocket connection instead of opening a new one.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_reuses_connection_after_session_drop() {
skip_if_no_network!();
// One outer entry containing two event lists — presumably one scripted
// connection that answers two requests; verify against the test helper.
let server = start_websocket_server(vec![vec![
vec![ev_response_created("resp-1"), ev_completed("resp-1")],
vec![ev_response_created("resp-2"), ev_completed("resp-2")],
]])
.await;
let harness = websocket_harness(&server).await;
let prompt_one = prompt_with_input(vec![message_item("hello")]);
let prompt_two = prompt_with_input(vec![message_item("again")]);
// The inner scope ends the first session's lifetime before the second
// session is created, so its drop runs in between the two turns.
{
let mut client_session = harness.client.new_session();
stream_until_complete(&mut client_session, &harness, &prompt_one).await;
}
// A fresh session from the same client issues the second prompt.
let mut client_session = harness.client.new_session();
stream_until_complete(&mut client_session, &harness, &prompt_two).await;
// Exactly one handshake was recorded across both sessions...
assert_eq!(server.handshakes().len(), 1);
// ...and the single connection recorded two entries (presumably one per
// request — confirm against `single_connection`).
assert_eq!(server.single_connection().len(), 2);
server.shutdown().await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_preconnect_is_reused_even_with_header_changes() { async fn responses_websocket_preconnect_is_reused_even_with_header_changes() {
skip_if_no_network!(); skip_if_no_network!();