diff --git a/codex-rs/app-server/src/transport/remote_control/client_tracker.rs b/codex-rs/app-server/src/transport/remote_control/client_tracker.rs index ef3af37942..be80ad523a 100644 --- a/codex-rs/app-server/src/transport/remote_control/client_tracker.rs +++ b/codex-rs/app-server/src/transport/remote_control/client_tracker.rs @@ -6,6 +6,7 @@ pub use super::protocol::ClientEvent; pub use super::protocol::ClientId; use super::protocol::PongStatus; use super::protocol::ServerEvent; +use super::protocol::StreamId; use crate::outgoing_message::ConnectionId; use crate::outgoing_message::QueuedOutgoingMessage; use crate::transport::remote_control::QueuedServerEnvelope; @@ -33,8 +34,9 @@ struct ClientState { } pub(crate) struct ClientTracker { - clients: HashMap, - join_set: JoinSet, + clients: HashMap<(ClientId, StreamId), ClientState>, + legacy_stream_ids: HashMap, + join_set: JoinSet<(ClientId, StreamId)>, server_event_tx: mpsc::Sender, transport_event_tx: mpsc::Sender, shutdown_token: CancellationToken, @@ -48,6 +50,7 @@ impl ClientTracker { ) -> Self { Self { clients: HashMap::new(), + legacy_stream_ids: HashMap::new(), join_set: JoinSet::new(), server_event_tx, transport_event_tx, @@ -55,12 +58,12 @@ impl ClientTracker { } } - pub(crate) async fn bookkeep_join_set(&mut self) -> Option { + pub(crate) async fn bookkeep_join_set(&mut self) -> Option<(ClientId, StreamId)> { while let Some(join_result) = self.join_set.join_next().await { - let Ok(client_id) = join_result else { + let Ok(client_key) = join_result else { continue; }; - return Some(client_id); + return Some(client_key); } futures::future::pending().await } @@ -68,8 +71,8 @@ impl ClientTracker { pub(crate) async fn shutdown(&mut self) { self.shutdown_token.cancel(); - while let Some(client_id) = self.clients.keys().next().cloned() { - let _ = self.close_client(&client_id).await; + while let Some(client_key) = self.clients.keys().next().cloned() { + let _ = self.close_client(&client_key).await; } self.drain_join_set().await; @@ -79,6 +82,10 @@ impl ClientTracker { while self.join_set.join_next().await.is_some() {} } + pub(crate) fn legacy_stream_id(&self, client_id: &ClientId) -> Option { + self.legacy_stream_ids.get(client_id).cloned() + } + pub(crate) async fn handle_message( &mut self, client_envelope: ClientEnvelope, @@ -86,14 +93,40 @@ impl ClientTracker { let ClientEnvelope { client_id, event, + stream_id, seq_id, cursor: _, } = client_envelope; + let is_legacy_stream_id = stream_id.is_none(); + let is_initialize = matches!(&event, ClientEvent::ClientMessage { message } if remote_control_message_starts_connection(message)); + let stream_id = match stream_id { + Some(stream_id) => stream_id, + None if is_initialize => { + // TODO(ruslan): delete this fallback once all clients are updated to send stream_id. + self.legacy_stream_ids + .remove(&client_id) + .unwrap_or_else(StreamId::new_random) + } + None => self + .legacy_stream_ids + .get(&client_id) + .cloned() + .unwrap_or_else(|| { + if matches!(&event, ClientEvent::Ping) { + StreamId::new_random() + } else { + StreamId(String::new()) + } + }), + }; + if stream_id.0.is_empty() { + return Ok(()); + } + let client_key = (client_id.clone(), stream_id.clone()); match event { ClientEvent::ClientMessage { message } => { - let is_initialize = remote_control_message_starts_connection(&message); if let Some(seq_id) = seq_id - && let Some(client) = self.clients.get(&client_id) + && let Some(client) = self.clients.get(&client_key) && client .last_inbound_seq_id .is_some_and(|last_seq_id| last_seq_id >= seq_id) @@ -102,24 +135,22 @@ impl ClientTracker { return Ok(()); } - if is_initialize && self.clients.contains_key(&client_id) { - self.close_client(&client_id).await?; + if is_initialize && self.clients.contains_key(&client_key) { + self.close_client(&client_key).await?; } - if let Some(connection_id) = self.clients.get_mut(&client_id).map(|client| { + if let Some(connection_id) = self.clients.get_mut(&client_key).map(|client| { client.last_activity_at = Instant::now(); if let Some(seq_id) = seq_id { client.last_inbound_seq_id = Some(seq_id); } client.connection_id }) { - self.transport_event_tx - .send(TransportEvent::IncomingMessage { - connection_id, - message, - }) - .await - .map_err(|_| Stopped)?; + self.send_transport_event(TransportEvent::IncomingMessage { + connection_id, + message, + }) + .await?; return Ok(()); } @@ -131,33 +162,35 @@ impl ClientTracker { let (writer_tx, writer_rx) = mpsc::channel::(CHANNEL_CAPACITY); let disconnect_token = self.shutdown_token.child_token(); - self.transport_event_tx - .send(TransportEvent::ConnectionOpened { - connection_id, - writer: writer_tx, - disconnect_sender: Some(disconnect_token.clone()), - }) - .await - .map_err(|_| Stopped)?; + self.send_transport_event(TransportEvent::ConnectionOpened { + connection_id, + writer: writer_tx, + disconnect_sender: Some(disconnect_token.clone()), + }) + .await?; let (status_tx, status_rx) = watch::channel(PongStatus::Active); self.join_set.spawn(Self::run_client_outbound( client_id.clone(), + stream_id.clone(), self.server_event_tx.clone(), writer_rx, status_rx, disconnect_token.clone(), )); self.clients.insert( - client_id, + client_key, ClientState { connection_id, disconnect_token, last_activity_at: Instant::now(), - last_inbound_seq_id: seq_id, + last_inbound_seq_id: if is_legacy_stream_id { None } else { seq_id }, status_tx, }, ); + if is_legacy_stream_id { + self.legacy_stream_ids.insert(client_id.clone(), stream_id); + } self.send_transport_event(TransportEvent::IncomingMessage { connection_id, message, @@ -166,7 +199,7 @@ impl ClientTracker { } ClientEvent::Ack => Ok(()), ClientEvent::Ping => { - if let Some(client) = self.clients.get_mut(&client_id) { + if let Some(client) = self.clients.get_mut(&client_key) { client.last_activity_at = Instant::now(); let _ = client.status_tx.send(PongStatus::Active); return Ok(()); @@ -179,23 +212,25 @@ impl ClientTracker { status: PongStatus::Unknown, }, client_id, + stream_id, write_complete_tx: None, }; let _ = server_event_tx.send(server_envelope).await; }); Ok(()) } - ClientEvent::ClientClosed => self.close_client(&client_id).await, + ClientEvent::ClientClosed => self.close_client(&client_key).await, } } async fn run_client_outbound( client_id: ClientId, + stream_id: StreamId, server_event_tx: mpsc::Sender, mut writer_rx: mpsc::Receiver, mut status_rx: watch::Receiver, disconnect_token: CancellationToken, - ) -> ClientId { + ) -> (ClientId, StreamId) { loop { let (event, write_complete_tx) = tokio::select! { _ = disconnect_token.cancelled() => { @@ -225,6 +260,7 @@ impl ClientTracker { send_result = server_event_tx.send(QueuedServerEnvelope { event, client_id: client_id.clone(), + stream_id: stream_id.clone(), write_complete_tx, }) => send_result, }; @@ -232,28 +268,40 @@ impl ClientTracker { break; } } - client_id + (client_id, stream_id) } - pub(crate) async fn close_expired_clients(&mut self) -> Result, Stopped> { + pub(crate) async fn close_expired_clients( + &mut self, + ) -> Result, Stopped> { let now = Instant::now(); - let expired_client_ids: Vec = self + let expired_client_ids: Vec<(ClientId, StreamId)> = self .clients .iter() - .filter_map(|(client_id, client)| { - (!remote_control_client_is_alive(client, now)).then_some(client_id.clone()) + .filter_map(|(client_key, client)| { + (!remote_control_client_is_alive(client, now)).then_some(client_key.clone()) }) .collect(); - for client_id in &expired_client_ids { - self.close_client(client_id).await?; + for client_key in &expired_client_ids { + self.close_client(client_key).await?; } Ok(expired_client_ids) } - pub(super) async fn close_client(&mut self, client_id: &ClientId) -> Result<(), Stopped> { - let Some(client) = self.clients.remove(client_id) else { + pub(super) async fn close_client( + &mut self, + client_key: &(ClientId, StreamId), + ) -> Result<(), Stopped> { + let Some(client) = self.clients.remove(client_key) else { return Ok(()); }; + if self + .legacy_stream_ids + .get(&client_key.0) + .is_some_and(|stream_id| stream_id == &client_key.1) + { + self.legacy_stream_ids.remove(&client_key.0); + } client.disconnect_token.cancel(); self.send_transport_event(TransportEvent::ConnectionClosed { connection_id: client.connection_id, @@ -296,6 +344,13 @@ mod tests { use tokio::time::timeout; fn initialize_envelope(client_id: &str) -> ClientEnvelope { + initialize_envelope_with_stream_id(client_id, None) + } + + fn initialize_envelope_with_stream_id( + client_id: &str, + stream_id: Option<&str>, + ) -> ClientEnvelope { ClientEnvelope { event: ClientEvent::ClientMessage { message: JSONRPCMessage::Request(JSONRPCRequest { @@ -311,6 +366,7 @@ mod tests { }), }, client_id: ClientId(client_id.to_string()), + stream_id: stream_id.map(|stream_id| StreamId(stream_id.to_string())), seq_id: Some(0), cursor: None, } @@ -358,7 +414,7 @@ mod tests { .await .expect("bookkeeping should process the closed task") .expect("closed task should return client id"); - assert_eq!(closed_client_id, ClientId("client-1".to_string())); + assert_eq!(closed_client_id.0, ClientId("client-1".to_string())); client_tracker .close_client(&closed_client_id) .await @@ -390,6 +446,7 @@ mod tests { status: PongStatus::Unknown, }, client_id: ClientId("queued-client".to_string()), + stream_id: StreamId("queued-stream".to_string()), write_complete_tx: None, }) .await @@ -431,4 +488,85 @@ mod tests { .await .expect("shutdown should not hang on blocked server forwarding"); } + + #[tokio::test] + async fn initialize_with_new_stream_id_opens_new_connection_for_same_client() { + let (server_event_tx, _server_event_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (transport_event_tx, mut transport_event_rx) = mpsc::channel(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let mut client_tracker = + ClientTracker::new(server_event_tx, transport_event_tx, &shutdown_token); + + client_tracker + .handle_message(initialize_envelope_with_stream_id( + "client-1", + Some("stream-1"), + )) + .await + .expect("first initialize should open client"); + let first_connection_id = match transport_event_rx.recv().await.expect("open event") { + TransportEvent::ConnectionOpened { connection_id, .. } => connection_id, + other => panic!("expected connection opened, got {other:?}"), + }; + let _ = transport_event_rx.recv().await.expect("initialize event"); + + client_tracker + .handle_message(initialize_envelope_with_stream_id( + "client-1", + Some("stream-2"), + )) + .await + .expect("second initialize should open client"); + let second_connection_id = match transport_event_rx.recv().await.expect("open event") { + TransportEvent::ConnectionOpened { connection_id, .. } => connection_id, + other => panic!("expected connection opened, got {other:?}"), + }; + + assert_ne!(first_connection_id, second_connection_id); + } + + #[tokio::test] + async fn legacy_initialize_without_stream_id_resets_inbound_seq_id() { + let (server_event_tx, _server_event_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (transport_event_tx, mut transport_event_rx) = mpsc::channel(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let mut client_tracker = + ClientTracker::new(server_event_tx, transport_event_tx, &shutdown_token); + + client_tracker + .handle_message(initialize_envelope("client-1")) + .await + .expect("initialize should open client"); + let connection_id = match transport_event_rx.recv().await.expect("open event") { + TransportEvent::ConnectionOpened { connection_id, .. } => connection_id, + other => panic!("expected connection opened, got {other:?}"), + }; + let _ = transport_event_rx.recv().await.expect("initialize event"); + + client_tracker + .handle_message(ClientEnvelope { + event: ClientEvent::ClientMessage { + message: JSONRPCMessage::Notification( + codex_app_server_protocol::JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }, + ), + }, + client_id: ClientId("client-1".to_string()), + stream_id: None, + seq_id: Some(0), + cursor: None, + }) + .await + .expect("legacy followup should be forwarded"); + + match transport_event_rx.recv().await.expect("followup event") { + TransportEvent::IncomingMessage { + connection_id: incoming_connection_id, + .. + } => assert_eq!(incoming_connection_id, connection_id), + other => panic!("expected incoming message, got {other:?}"), + } + } } diff --git a/codex-rs/app-server/src/transport/remote_control/enroll.rs b/codex-rs/app-server/src/transport/remote_control/enroll.rs index 23b190f962..f74721beac 100644 --- a/codex-rs/app-server/src/transport/remote_control/enroll.rs +++ b/codex-rs/app-server/src/transport/remote_control/enroll.rs @@ -20,6 +20,7 @@ pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id"; #[derive(Debug, Clone, PartialEq, Eq)] pub(super) struct RemoteControlEnrollment { pub(super) account_id: Option, + pub(super) environment_id: String, pub(super) server_id: String, pub(super) server_name: String, } @@ -47,11 +48,14 @@ pub(super) async fn load_persisted_remote_control_enrollment( } }; - enrollment.map(|(server_id, server_name)| RemoteControlEnrollment { - account_id: account_id.map(&str::to_string), - server_id, - server_name, - }) + enrollment.map( + |(server_id, environment_id, server_name)| RemoteControlEnrollment { + account_id: account_id.map(&str::to_string), + environment_id, + server_id, + server_name, + }, + ) } pub(super) async fn update_persisted_remote_control_enrollment( @@ -77,6 +81,7 @@ pub(super) async fn update_persisted_remote_control_enrollment( &remote_control_target.websocket_url, account_id, &enrollment.server_id, + &enrollment.environment_id, &enrollment.server_name, ) .await @@ -182,6 +187,7 @@ pub(super) async fn enroll_remote_control_server( Ok(RemoteControlEnrollment { account_id: account_id.map(&str::to_string), + environment_id: enrollment.environment_id, server_id: enrollment.server_id, server_name, }) @@ -221,11 +227,13 @@ mod tests { .expect("second target should parse"); let first_enrollment = RemoteControlEnrollment { account_id: Some("account-a".to_string()), + environment_id: "env_first".to_string(), server_id: "srv_e_first".to_string(), server_name: "first-server".to_string(), }; let second_enrollment = RemoteControlEnrollment { account_id: Some("account-a".to_string()), + environment_id: "env_second".to_string(), server_id: "srv_e_second".to_string(), server_name: "second-server".to_string(), }; @@ -287,11 +295,13 @@ mod tests { .expect("second target should parse"); let first_enrollment = RemoteControlEnrollment { account_id: Some("account-a".to_string()), + environment_id: "env_first".to_string(), server_id: "srv_e_first".to_string(), server_name: "first-server".to_string(), }; let second_enrollment = RemoteControlEnrollment { account_id: Some("account-a".to_string()), + environment_id: "env_second".to_string(), server_id: "srv_e_second".to_string(), server_name: "second-server".to_string(), }; diff --git a/codex-rs/app-server/src/transport/remote_control/mod.rs b/codex-rs/app-server/src/transport/remote_control/mod.rs index e9d91b17e1..4dd5a68769 100644 --- a/codex-rs/app-server/src/transport/remote_control/mod.rs +++ b/codex-rs/app-server/src/transport/remote_control/mod.rs @@ -8,6 +8,7 @@ use crate::transport::remote_control::websocket::load_remote_control_auth; pub use self::protocol::ClientId; use self::protocol::ServerEvent; +use self::protocol::StreamId; use self::protocol::normalize_remote_control_url; use super::CHANNEL_CAPACITY; use super::TransportEvent; @@ -24,6 +25,7 @@ use tokio_util::sync::CancellationToken; pub(super) struct QueuedServerEnvelope { pub(super) event: ServerEvent, pub(super) client_id: ClientId, + pub(super) stream_id: StreamId, pub(super) write_complete_tx: Option>, } diff --git a/codex-rs/app-server/src/transport/remote_control/protocol.rs b/codex-rs/app-server/src/transport/remote_control/protocol.rs index 4e057e565d..0bedf32dc5 100644 --- a/codex-rs/app-server/src/transport/remote_control/protocol.rs +++ b/codex-rs/app-server/src/transport/remote_control/protocol.rs @@ -24,12 +24,23 @@ pub(super) struct EnrollRemoteServerRequest { #[derive(Debug, Deserialize)] pub(super) struct EnrollRemoteServerResponse { pub(super) server_id: String, + pub(super) environment_id: String, } #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(transparent)] pub struct ClientId(pub String); +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(transparent)] +pub struct StreamId(pub String); + +impl StreamId { + pub fn new_random() -> Self { + Self(uuid::Uuid::now_v7().to_string()) + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ClientEvent { @@ -44,13 +55,11 @@ pub enum ClientEvent { pub(crate) struct ClientEnvelope { #[serde(flatten)] pub(crate) event: ClientEvent, - #[serde(rename = "client_id", alias = "clientId")] + #[serde(rename = "client_id")] pub(crate) client_id: ClientId, - #[serde( - rename = "seq_id", - alias = "seqId", - skip_serializing_if = "Option::is_none" - )] + #[serde(rename = "stream_id", skip_serializing_if = "Option::is_none")] + pub(crate) stream_id: Option, + #[serde(rename = "seq_id", skip_serializing_if = "Option::is_none")] pub(crate) seq_id: Option, #[serde(skip_serializing_if = "Option::is_none")] pub(crate) cursor: Option, @@ -81,9 +90,11 @@ pub enum ServerEvent { pub(crate) struct ServerEnvelope { #[serde(flatten)] pub(crate) event: ServerEvent, - #[serde(rename = "client_id", alias = "clientId")] + #[serde(rename = "client_id")] pub(crate) client_id: ClientId, - #[serde(rename = "seq_id", alias = "seqId")] + #[serde(rename = "stream_id")] + pub(crate) stream_id: StreamId, + #[serde(rename = "seq_id")] pub(crate) seq_id: u64, } diff --git a/codex-rs/app-server/src/transport/remote_control/tests.rs b/codex-rs/app-server/src/transport/remote_control/tests.rs index e0556ebfbe..888fb47320 100644 --- a/codex-rs/app-server/src/transport/remote_control/tests.rs +++ b/codex-rs/app-server/src/transport/remote_control/tests.rs @@ -95,7 +95,11 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages() enroll_request.request_line, "POST /backend-api/wham/remote/control/server/enroll HTTP/1.1" ); - respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await; + respond_with_json( + enroll_request.stream, + json!({ "server_id": "srv_e_test", "environment_id": "env_test" }), + ) + .await; let mut websocket = accept_remote_control_connection(&listener).await; let client_id = ClientId("client-1".to_string()); @@ -104,6 +108,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages() ClientEnvelope { event: ClientEvent::Ping, client_id: client_id.clone(), + stream_id: None, seq_id: None, cursor: None, }, @@ -131,6 +136,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages() ), }, client_id: client_id.clone(), + stream_id: None, seq_id: Some(0), cursor: None, }, @@ -161,6 +167,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages() message: initialize_message.clone(), }, client_id: client_id.clone(), + stream_id: None, seq_id: Some(1), cursor: None, }, @@ -207,6 +214,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages() message: followup_message.clone(), }, client_id: client_id.clone(), + stream_id: None, seq_id: Some(2), cursor: None, }, @@ -232,6 +240,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages() ClientEnvelope { event: ClientEvent::Ping, client_id: client_id.clone(), + stream_id: None, seq_id: None, cursor: None, }, @@ -281,6 +290,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages() ClientEnvelope { event: ClientEvent::ClientClosed, client_id: client_id.clone(), + stream_id: None, seq_id: None, cursor: None, }, @@ -304,6 +314,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages() ClientEnvelope { event: ClientEvent::Ping, client_id, + stream_id: None, seq_id: None, cursor: None, }, @@ -348,7 +359,11 @@ async fn remote_control_transport_reconnects_after_disconnect() { enroll_request.request_line, "POST /backend-api/wham/remote/control/server/enroll HTTP/1.1" ); - respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await; + respond_with_json( + enroll_request.stream, + json!({ "server_id": "srv_e_test", "environment_id": "env_test" }), + ) + .await; let mut first_websocket = accept_remote_control_connection(&listener).await; first_websocket .close(None) @@ -374,6 +389,7 @@ async fn remote_control_transport_reconnects_after_disconnect() { }), }, client_id: ClientId("client-2".to_string()), + stream_id: None, seq_id: Some(0), cursor: None, }, @@ -414,7 +430,11 @@ async fn remote_control_transport_clears_outgoing_buffer_when_client_closes() { .expect("remote control should start"); let enroll_request = accept_http_request(&listener).await; - respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await; + respond_with_json( + enroll_request.stream, + json!({ "server_id": "srv_e_test", "environment_id": "env_test" }), + ) + .await; let mut first_websocket = accept_remote_control_connection(&listener).await; let client_id = ClientId("client-1".to_string()); @@ -436,6 +456,7 @@ async fn remote_control_transport_clears_outgoing_buffer_when_client_closes() { message: initialize_message, }, client_id: client_id.clone(), + stream_id: None, seq_id: Some(0), cursor: None, }, @@ -493,6 +514,7 @@ async fn remote_control_transport_clears_outgoing_buffer_when_client_closes() { ClientEnvelope { event: ClientEvent::ClientClosed, client_id: client_id.clone(), + stream_id: None, seq_id: None, cursor: None, }, @@ -519,6 +541,7 @@ async fn remote_control_transport_clears_outgoing_buffer_when_client_closes() { ClientEnvelope { event: ClientEvent::Ping, client_id, + stream_id: None, seq_id: None, cursor: None, }, @@ -582,7 +605,11 @@ async fn remote_control_http_mode_enrolls_before_connecting() { "app_server_version": env!("CARGO_PKG_VERSION"), }) ); - respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await; + respond_with_json( + enroll_request.stream, + json!({ "server_id": "srv_e_test", "environment_id": "env_test" }), + ) + .await; let (handshake_request, mut websocket) = accept_remote_control_backend_connection(&listener).await; @@ -634,6 +661,7 @@ async fn remote_control_http_mode_enrolls_before_connecting() { message: initialize_message.clone(), }, client_id: backend_client_id.clone(), + stream_id: None, seq_id: Some(0), cursor: None, }, @@ -742,6 +770,7 @@ async fn remote_control_http_mode_reuses_persisted_enrollment_before_reenrolling normalize_remote_control_url(&remote_control_url).expect("target should parse"); let persisted_enrollment = RemoteControlEnrollment { account_id: Some("account_id".to_string()), + environment_id: "env_persisted".to_string(), server_id: "srv_e_persisted".to_string(), server_name: "persisted-server".to_string(), }; @@ -803,11 +832,13 @@ async fn remote_control_http_mode_clears_stale_persisted_enrollment_after_404() let expected_server_name = gethostname().to_string_lossy().trim().to_string(); let stale_enrollment = RemoteControlEnrollment { account_id: Some("account_id".to_string()), + environment_id: "env_stale".to_string(), server_id: "srv_e_stale".to_string(), server_name: "stale-server".to_string(), }; let refreshed_enrollment = RemoteControlEnrollment { account_id: Some("account_id".to_string()), + environment_id: "env_refreshed".to_string(), server_id: "srv_e_refreshed".to_string(), server_name: expected_server_name, }; @@ -851,7 +882,10 @@ async fn remote_control_http_mode_clears_stale_persisted_enrollment_after_404() ); respond_with_json( enroll_request.stream, - json!({ "server_id": refreshed_enrollment.server_id }), + json!({ + "server_id": refreshed_enrollment.server_id, + "environment_id": refreshed_enrollment.environment_id, + }), ) .await; @@ -1048,8 +1082,15 @@ async fn read_server_event(websocket: &mut WebSocketStream) -> serde_ .expect("websocket frame should be readable"); match frame { tungstenite::Message::Text(text) => { - return serde_json::from_str(text.as_ref()) - .expect("server event should deserialize"); + let mut event: serde_json::Value = + serde_json::from_str(text.as_ref()).expect("server event should deserialize"); + if let Some(stream_id) = event + .as_object_mut() + .and_then(|event| event.remove("stream_id")) + { + assert!(stream_id.is_string(), "stream_id should be a string"); + } + return event; } tungstenite::Message::Ping(payload) => { websocket diff --git a/codex-rs/app-server/src/transport/remote_control/websocket.rs b/codex-rs/app-server/src/transport/remote_control/websocket.rs index 74416642af..52cc9c14de 100644 --- a/codex-rs/app-server/src/transport/remote_control/websocket.rs +++ b/codex-rs/app-server/src/transport/remote_control/websocket.rs @@ -14,6 +14,7 @@ use super::protocol::ClientEvent; use super::protocol::ClientId; use super::protocol::RemoteControlTarget; use super::protocol::ServerEnvelope; +use super::protocol::StreamId; use axum::http::HeaderValue; use base64::Engine; use codex_core::AuthManager; @@ -50,7 +51,7 @@ pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id"; const REMOTE_CONTROL_SUBSCRIBE_CURSOR_HEADER: &str = "x-codex-subscribe-cursor"; struct BoundedOutboundBuffer { - buffer_by_client: HashMap>, + buffer_by_client: HashMap<(ClientId, StreamId), BTreeMap>, used_tx: watch::Sender, } @@ -66,20 +67,29 @@ impl BoundedOutboundBuffer { fn insert(&mut self, server_envelope: &ServerEnvelope) { self.buffer_by_client - .entry(server_envelope.client_id.clone()) + .entry(( + server_envelope.client_id.clone(), + server_envelope.stream_id.clone(), + )) .or_default() .insert(server_envelope.seq_id, server_envelope.clone()); self.used_tx.send_modify(|used| *used += 1); } - fn remove(&mut self, client_id: &ClientId) { - if let Some(buffer) = self.buffer_by_client.remove(client_id) { + fn remove(&mut self, client_id: &ClientId, stream_id: &StreamId) { + if let Some(buffer) = self + .buffer_by_client + .remove(&(client_id.clone(), stream_id.clone())) + { self.used_tx.send_modify(|used| *used -= buffer.len()); } } - fn ack(&mut self, client_id: &ClientId, acked_seq_id: u64) { - let Some(buffer) = self.buffer_by_client.get_mut(client_id) else { + fn ack(&mut self, client_id: &ClientId, stream_id: &StreamId, acked_seq_id: u64) { + let Some(buffer) = self + .buffer_by_client + .get_mut(&(client_id.clone(), stream_id.clone())) + else { return; }; while let Some(seq_id) = buffer.first_key_value().map(|(seq_id, _)| seq_id) @@ -89,7 +99,8 @@ impl BoundedOutboundBuffer { self.used_tx.send_modify(|used| *used -= 1); } if buffer.is_empty() { - self.buffer_by_client.remove(client_id); + self.buffer_by_client + .remove(&(client_id.clone(), stream_id.clone())); } } @@ -320,6 +331,7 @@ impl RemoteControlWebsocket { event: queued_server_envelope.event, client_id: queued_server_envelope.client_id, seq_id, + stream_id: queued_server_envelope.stream_id, }; state.outbound_buffer.insert(&server_envelope); @@ -391,14 +403,14 @@ impl RemoteControlWebsocket { continue; } _ = idle_sweep_interval.tick() => { - let expired_client_ids = match client_tracker.close_expired_clients().await { - Ok(expired_client_ids) => expired_client_ids, + let expired_client_keys = match client_tracker.close_expired_clients().await { + Ok(expired_client_keys) => expired_client_keys, Err(_) => return Ok(()), }; - if !expired_client_ids.is_empty() { + if !expired_client_keys.is_empty() { let mut state = state.lock().await; - for client_id in expired_client_ids { - state.outbound_buffer.remove(&client_id); + for (client_id, stream_id) in expired_client_keys { + state.outbound_buffer.remove(&client_id, &stream_id); } } continue; @@ -441,22 +453,29 @@ impl RemoteControlWebsocket { } }; + let resolved_stream_id = client_envelope + .stream_id + .clone() + .or_else(|| client_tracker.legacy_stream_id(&client_envelope.client_id)); let mut state = state.lock().await; if let Some(cursor) = client_envelope.cursor.as_deref() { state.subscribe_cursor = Some(cursor.to_string()); } if let ClientEvent::Ack = &client_envelope.event && let Some(acked_seq_id) = client_envelope.seq_id + && let Some(stream_id) = resolved_stream_id.as_ref() { state .outbound_buffer - .ack(&client_envelope.client_id, acked_seq_id); - } - if matches!(&client_envelope.event, ClientEvent::ClientClosed) - || remote_control_message_starts_connection(&client_envelope.event) - { - state.outbound_buffer.remove(&client_envelope.client_id); + .ack(&client_envelope.client_id, stream_id, acked_seq_id); } + if (matches!(&client_envelope.event, ClientEvent::ClientClosed) + || remote_control_message_starts_connection(&client_envelope.event)) + && let Some(stream_id) = resolved_stream_id.as_ref() { + state + .outbound_buffer + .remove(&client_envelope.client_id, stream_id); + } drop(state); if client_tracker @@ -834,6 +853,7 @@ mod tests { let mut auth_recovery = auth_manager.unauthorized_recovery(); let mut enrollment = Some(RemoteControlEnrollment { account_id: Some("account_id".to_string()), + environment_id: "env_test".to_string(), server_id: "srv_e_test".to_string(), server_name: "test-server".to_string(), }); @@ -888,6 +908,7 @@ mod tests { let mut auth_recovery = auth_manager.unauthorized_recovery(); let mut enrollment = Some(RemoteControlEnrollment { account_id: Some("account_id".to_string()), + environment_id: "env_test".to_string(), server_id: "srv_e_test".to_string(), server_name: "test-server".to_string(), }); diff --git a/codex-rs/state/migrations/0023_remote_control_enrollments.sql b/codex-rs/state/migrations/0023_remote_control_enrollments.sql index 9a2081dd8f..247b8d4192 100644 --- a/codex-rs/state/migrations/0023_remote_control_enrollments.sql +++ b/codex-rs/state/migrations/0023_remote_control_enrollments.sql @@ -2,6 +2,7 @@ CREATE TABLE remote_control_enrollments ( websocket_url TEXT NOT NULL, account_id TEXT NOT NULL, server_id TEXT NOT NULL, + environment_id TEXT NOT NULL, server_name TEXT NOT NULL, updated_at INTEGER NOT NULL, PRIMARY KEY (websocket_url, account_id) diff --git a/codex-rs/state/src/runtime/remote_control.rs b/codex-rs/state/src/runtime/remote_control.rs index 307dac8184..4ac1f81872 100644 --- a/codex-rs/state/src/runtime/remote_control.rs +++ b/codex-rs/state/src/runtime/remote_control.rs @@ -11,10 +11,10 @@ impl StateRuntime { &self, websocket_url: &str, account_id: Option<&str>, - ) -> anyhow::Result> { + ) -> anyhow::Result> { let row = sqlx::query( r#" -SELECT server_id, server_name +SELECT server_id, environment_id, server_name FROM remote_control_enrollments WHERE websocket_url = ? AND account_id = ? "#, @@ -24,8 +24,14 @@ WHERE websocket_url = ? AND account_id = ? .fetch_optional(self.pool.as_ref()) .await?; - row.map(|row| Ok((row.try_get("server_id")?, row.try_get("server_name")?))) - .transpose() + row.map(|row| { + Ok(( + row.try_get("server_id")?, + row.try_get("environment_id")?, + row.try_get("server_name")?, + )) + }) + .transpose() } pub async fn upsert_remote_control_enrollment( @@ -33,6 +39,7 @@ WHERE websocket_url = ? AND account_id = ? websocket_url: &str, account_id: Option<&str>, server_id: &str, + environment_id: &str, server_name: &str, ) -> anyhow::Result<()> { sqlx::query( @@ -41,11 +48,13 @@ INSERT INTO remote_control_enrollments ( websocket_url, account_id, server_id, + environment_id, server_name, updated_at -) VALUES (?, ?, ?, ?, ?) +) VALUES (?, ?, ?, ?, ?, ?) ON CONFLICT(websocket_url, account_id) DO UPDATE SET server_id = excluded.server_id, + environment_id = excluded.environment_id, server_name = excluded.server_name, updated_at = excluded.updated_at "#, @@ -53,6 +62,7 @@ ON CONFLICT(websocket_url, account_id) DO UPDATE SET .bind(websocket_url) .bind(remote_control_account_id_key(account_id)) .bind(server_id) + .bind(environment_id) .bind(server_name) .bind(Utc::now().timestamp()) .execute(self.pool.as_ref()) @@ -97,6 +107,7 @@ mod tests { "wss://example.com/backend-api/wham/remote/control/server", Some("account-a"), "srv_e_first", + "env_first", "first-server", ) .await @@ -106,6 +117,7 @@ mod tests { "wss://example.com/backend-api/wham/remote/control/server", Some("account-b"), "srv_e_second", + "env_second", "second-server", ) .await @@ -119,7 +131,11 @@ mod tests { ) .await .expect("load first enrollment"), - Some(("srv_e_first".to_string(), "first-server".to_string())) + Some(( + "srv_e_first".to_string(), + "env_first".to_string(), + "first-server".to_string() + )) ); assert_eq!( runtime @@ -147,6 +163,7 @@ mod tests { "wss://example.com/backend-api/wham/remote/control/server", None, "srv_e_first", + "env_first", "first-server", ) .await @@ -156,6 +173,7 @@ mod tests { "wss://example.com/backend-api/wham/remote/control/server", Some("account-a"), "srv_e_second", + "env_second", "second-server", ) .await @@ -189,7 +207,11 @@ mod tests { ) .await .expect("load retained enrollment"), - Some(("srv_e_second".to_string(), "second-server".to_string())) + Some(( + "srv_e_second".to_string(), + "env_second".to_string(), + "second-server".to_string() + )) ); let _ = tokio::fs::remove_dir_all(codex_home).await;