mirror of
https://github.com/openai/codex.git
synced 2026-05-01 20:02:05 +03:00
Reapply "Add app-server transport layer with websocket support" (#11370)
Reapply "Add app-server transport layer with websocket support" with
additional fixes from https://github.com/openai/codex/pull/11313/changes
to avoid deadlocking.
This reverts commit 47356ff83c.
## Summary
To avoid deadlocking when queues are full, we maintain separate tokio
tasks dedicated to incoming vs. outgoing event handling:
- split the app-server main loop into two tasks in
`run_main_with_transport`
- inbound handling (`transport_event_rx`)
- outbound handling (`outgoing_rx` + `thread_created_rx`)
- separate incoming and outgoing websocket tasks
## Validation
Integration tests, plus thorough end-to-end testing in the Codex app with
more than 10 concurrent requests.
<img width="1365" height="979" alt="Screenshot 2026-02-10 at 2 54 22 PM"
src="https://github.com/user-attachments/assets/47ca2c13-f322-4e5c-bedd-25859cbdc45f"
/>
---------
Co-authored-by: jif-oai <jif@openai.com>
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
@@ -20,35 +19,44 @@ use crate::error_code::INTERNAL_ERROR_CODE;
|
||||
#[cfg(test)]
|
||||
use codex_protocol::account::PlanType;
|
||||
|
||||
/// Stable identifier for a transport connection.
///
/// Newtype over `u64` that derives `Copy`, `Eq` and `Hash`.
|
||||
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
|
||||
pub(crate) struct ConnectionId(pub(crate) u64);
|
||||
|
||||
/// Stable identifier for a client request scoped to a transport connection.
///
/// Pairs the JSON-RPC style request id with the connection it belongs to, so
/// that `send_response` / `send_error` can address the reply to that
/// connection.
|
||||
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
|
||||
pub(crate) struct ConnectionRequestId {
|
||||
/// Connection used as the target of the reply envelope.
pub(crate) connection_id: ConnectionId,
|
||||
/// Request id echoed back as the `id` of the response or error message.
pub(crate) request_id: RequestId,
|
||||
}
|
||||
|
||||
/// An [`OutgoingMessage`] together with its routing target.
///
/// This is the item type carried by the `OutgoingMessageSender` channel.
/// Responses and errors are wrapped in [`OutgoingEnvelope::ToConnection`];
/// server requests and notifications are wrapped in
/// [`OutgoingEnvelope::Broadcast`].
#[derive(Debug, Clone)]
|
||||
pub(crate) enum OutgoingEnvelope {
|
||||
/// Message addressed to the single connection named by `connection_id`.
ToConnection {
|
||||
connection_id: ConnectionId,
|
||||
message: OutgoingMessage,
|
||||
},
|
||||
/// Message with no specific target connection.
/// NOTE(review): presumably delivered to every open connection by the
/// transport layer; confirm against the consumer of this channel.
Broadcast {
|
||||
message: OutgoingMessage,
|
||||
},
|
||||
}
|
||||
|
||||
/// Sends messages to the client and manages request callbacks.
|
||||
pub(crate) struct OutgoingMessageSender {
|
||||
next_request_id: AtomicI64,
|
||||
sender: mpsc::Sender<OutgoingMessage>,
|
||||
next_server_request_id: AtomicI64,
|
||||
sender: mpsc::Sender<OutgoingEnvelope>,
|
||||
request_id_to_callback: Mutex<HashMap<RequestId, oneshot::Sender<Result>>>,
|
||||
opted_out_notification_methods: Mutex<HashSet<String>>,
|
||||
}
|
||||
|
||||
impl OutgoingMessageSender {
|
||||
pub(crate) fn new(sender: mpsc::Sender<OutgoingMessage>) -> Self {
|
||||
pub(crate) fn new(sender: mpsc::Sender<OutgoingEnvelope>) -> Self {
|
||||
Self {
|
||||
next_request_id: AtomicI64::new(0),
|
||||
next_server_request_id: AtomicI64::new(0),
|
||||
sender,
|
||||
request_id_to_callback: Mutex::new(HashMap::new()),
|
||||
opted_out_notification_methods: Mutex::new(HashSet::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Replaces the set of opted-out notification methods with `methods`.
///
/// Any previously stored methods are discarded; duplicates in `methods`
/// collapse because the backing store is a `HashSet`.
pub(crate) async fn set_opted_out_notification_methods(&self, methods: Vec<String>) {
|
||||
let mut opted_out = self.opted_out_notification_methods.lock().await;
|
||||
// Clear first so the call replaces the set rather than merging into it.
opted_out.clear();
|
||||
opted_out.extend(methods);
|
||||
}
|
||||
|
||||
/// Returns `true` when `method` is currently in the opted-out notification
/// set, i.e. a notification with that method name should not be sent.
async fn should_skip_notification(&self, method: &str) -> bool {
|
||||
let opted_out = self.opted_out_notification_methods.lock().await;
|
||||
opted_out.contains(method)
|
||||
}
|
||||
|
||||
pub(crate) async fn send_request(
|
||||
&self,
|
||||
request: ServerRequestPayload,
|
||||
@@ -61,7 +69,7 @@ impl OutgoingMessageSender {
|
||||
&self,
|
||||
request: ServerRequestPayload,
|
||||
) -> (RequestId, oneshot::Receiver<Result>) {
|
||||
let id = RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::Relaxed));
|
||||
let id = RequestId::Integer(self.next_server_request_id.fetch_add(1, Ordering::Relaxed));
|
||||
let outgoing_message_id = id.clone();
|
||||
let (tx_approve, rx_approve) = oneshot::channel();
|
||||
{
|
||||
@@ -71,7 +79,13 @@ impl OutgoingMessageSender {
|
||||
|
||||
let outgoing_message =
|
||||
OutgoingMessage::Request(request.request_with_id(outgoing_message_id.clone()));
|
||||
if let Err(err) = self.sender.send(outgoing_message).await {
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingEnvelope::Broadcast {
|
||||
message: outgoing_message,
|
||||
})
|
||||
.await
|
||||
{
|
||||
warn!("failed to send request {outgoing_message_id:?} to client: {err:?}");
|
||||
let mut request_id_to_callback = self.request_id_to_callback.lock().await;
|
||||
request_id_to_callback.remove(&outgoing_message_id);
|
||||
@@ -121,17 +135,31 @@ impl OutgoingMessageSender {
|
||||
entry.is_some()
|
||||
}
|
||||
|
||||
pub(crate) async fn send_response<T: Serialize>(&self, id: RequestId, response: T) {
|
||||
pub(crate) async fn send_response<T: Serialize>(
|
||||
&self,
|
||||
request_id: ConnectionRequestId,
|
||||
response: T,
|
||||
) {
|
||||
match serde_json::to_value(response) {
|
||||
Ok(result) => {
|
||||
let outgoing_message = OutgoingMessage::Response(OutgoingResponse { id, result });
|
||||
if let Err(err) = self.sender.send(outgoing_message).await {
|
||||
let outgoing_message = OutgoingMessage::Response(OutgoingResponse {
|
||||
id: request_id.request_id,
|
||||
result,
|
||||
});
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingEnvelope::ToConnection {
|
||||
connection_id: request_id.connection_id,
|
||||
message: outgoing_message,
|
||||
})
|
||||
.await
|
||||
{
|
||||
warn!("failed to send response to client: {err:?}");
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
self.send_error(
|
||||
id,
|
||||
request_id,
|
||||
JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: format!("failed to serialize response: {err}"),
|
||||
@@ -144,13 +172,11 @@ impl OutgoingMessageSender {
|
||||
}
|
||||
|
||||
pub(crate) async fn send_server_notification(&self, notification: ServerNotification) {
|
||||
let method = notification.to_string();
|
||||
if self.should_skip_notification(&method).await {
|
||||
return;
|
||||
}
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingMessage::AppServerNotification(notification))
|
||||
.send(OutgoingEnvelope::Broadcast {
|
||||
message: OutgoingMessage::AppServerNotification(notification),
|
||||
})
|
||||
.await
|
||||
{
|
||||
warn!("failed to send server notification to client: {err:?}");
|
||||
@@ -160,21 +186,35 @@ impl OutgoingMessageSender {
|
||||
/// All notifications should be migrated to [`ServerNotification`] and
|
||||
/// [`OutgoingMessage::Notification`] should be removed.
|
||||
pub(crate) async fn send_notification(&self, notification: OutgoingNotification) {
|
||||
if self
|
||||
.should_skip_notification(notification.method.as_str())
|
||||
let outgoing_message = OutgoingMessage::Notification(notification);
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingEnvelope::Broadcast {
|
||||
message: outgoing_message,
|
||||
})
|
||||
.await
|
||||
{
|
||||
return;
|
||||
}
|
||||
let outgoing_message = OutgoingMessage::Notification(notification);
|
||||
if let Err(err) = self.sender.send(outgoing_message).await {
|
||||
warn!("failed to send notification to client: {err:?}");
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn send_error(&self, id: RequestId, error: JSONRPCErrorError) {
|
||||
let outgoing_message = OutgoingMessage::Error(OutgoingError { id, error });
|
||||
if let Err(err) = self.sender.send(outgoing_message).await {
|
||||
pub(crate) async fn send_error(
|
||||
&self,
|
||||
request_id: ConnectionRequestId,
|
||||
error: JSONRPCErrorError,
|
||||
) {
|
||||
let outgoing_message = OutgoingMessage::Error(OutgoingError {
|
||||
id: request_id.request_id,
|
||||
error,
|
||||
});
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingEnvelope::ToConnection {
|
||||
connection_id: request_id.connection_id,
|
||||
message: outgoing_message,
|
||||
})
|
||||
.await
|
||||
{
|
||||
warn!("failed to send error to client: {err:?}");
|
||||
}
|
||||
}
|
||||
@@ -214,6 +254,8 @@ pub(crate) struct OutgoingError {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_app_server_protocol::AccountLoginCompletedNotification;
|
||||
use codex_app_server_protocol::AccountRateLimitsUpdatedNotification;
|
||||
use codex_app_server_protocol::AccountUpdatedNotification;
|
||||
@@ -224,6 +266,7 @@ mod tests {
|
||||
use codex_app_server_protocol::RateLimitWindow;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tokio::time::timeout;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::*;
|
||||
@@ -364,4 +407,75 @@ mod tests {
|
||||
"ensure the notification serializes correctly"
|
||||
);
|
||||
}
|
||||
|
||||
/// Verifies that `send_response` wraps the response in
/// `OutgoingEnvelope::ToConnection` addressed to the request's connection,
/// and that the response carries the original request id and the serialized
/// result.
#[tokio::test]
|
||||
async fn send_response_routes_to_target_connection() {
|
||||
// Bounded channel (capacity 4) standing in for the transport's outgoing queue.
let (tx, mut rx) = mpsc::channel::<OutgoingEnvelope>(4);
|
||||
let outgoing = OutgoingMessageSender::new(tx);
|
||||
let request_id = ConnectionRequestId {
|
||||
connection_id: ConnectionId(42),
|
||||
request_id: RequestId::Integer(7),
|
||||
};
|
||||
|
||||
outgoing
|
||||
.send_response(request_id.clone(), json!({ "ok": true }))
|
||||
.await;
|
||||
|
||||
// Bound the wait to one second so a missing envelope fails the test
// instead of hanging it.
let envelope = timeout(Duration::from_secs(1), rx.recv())
|
||||
.await
|
||||
.expect("should receive envelope before timeout")
|
||||
.expect("channel should contain one message");
|
||||
|
||||
match envelope {
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message,
|
||||
} => {
|
||||
// The envelope must target the connection the request came with.
assert_eq!(connection_id, ConnectionId(42));
|
||||
let OutgoingMessage::Response(response) = message else {
|
||||
panic!("expected response message");
|
||||
};
|
||||
assert_eq!(response.id, request_id.request_id);
|
||||
assert_eq!(response.result, json!({ "ok": true }));
|
||||
}
|
||||
// Any other envelope variant means the response was not targeted.
other => panic!("expected targeted response envelope, got: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Verifies that `send_error` wraps the error in
/// `OutgoingEnvelope::ToConnection` addressed to the request's connection,
/// and that the error message carries the original request id and the
/// unmodified error payload.
#[tokio::test]
|
||||
async fn send_error_routes_to_target_connection() {
|
||||
// Bounded channel (capacity 4) standing in for the transport's outgoing queue.
let (tx, mut rx) = mpsc::channel::<OutgoingEnvelope>(4);
|
||||
let outgoing = OutgoingMessageSender::new(tx);
|
||||
let request_id = ConnectionRequestId {
|
||||
connection_id: ConnectionId(9),
|
||||
request_id: RequestId::Integer(3),
|
||||
};
|
||||
let error = JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: "boom".to_string(),
|
||||
data: None,
|
||||
};
|
||||
|
||||
// Clone the error so the original can be compared against what was sent.
outgoing.send_error(request_id.clone(), error.clone()).await;
|
||||
|
||||
// Bound the wait to one second so a missing envelope fails the test
// instead of hanging it.
let envelope = timeout(Duration::from_secs(1), rx.recv())
|
||||
.await
|
||||
.expect("should receive envelope before timeout")
|
||||
.expect("channel should contain one message");
|
||||
|
||||
match envelope {
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message,
|
||||
} => {
|
||||
// The envelope must target the connection the request came with.
assert_eq!(connection_id, ConnectionId(9));
|
||||
let OutgoingMessage::Error(outgoing_error) = message else {
|
||||
panic!("expected error message");
|
||||
};
|
||||
assert_eq!(outgoing_error.id, RequestId::Integer(3));
|
||||
assert_eq!(outgoing_error.error, error);
|
||||
}
|
||||
// Any other envelope variant means the error was not targeted.
other => panic!("expected targeted error envelope, got: {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user