mirror of
https://github.com/openai/codex.git
synced 2026-05-01 20:02:05 +03:00
cr
This commit is contained in:
@@ -6,6 +6,7 @@ pub use super::protocol::ClientEvent;
|
||||
pub use super::protocol::ClientId;
|
||||
use super::protocol::PongStatus;
|
||||
use super::protocol::ServerEvent;
|
||||
use super::protocol::StreamId;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::QueuedOutgoingMessage;
|
||||
use crate::transport::remote_control::QueuedServerEnvelope;
|
||||
@@ -33,8 +34,9 @@ struct ClientState {
|
||||
}
|
||||
|
||||
pub(crate) struct ClientTracker {
|
||||
clients: HashMap<ClientId, ClientState>,
|
||||
join_set: JoinSet<ClientId>,
|
||||
clients: HashMap<(ClientId, StreamId), ClientState>,
|
||||
legacy_stream_ids: HashMap<ClientId, StreamId>,
|
||||
join_set: JoinSet<(ClientId, StreamId)>,
|
||||
server_event_tx: mpsc::Sender<QueuedServerEnvelope>,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
shutdown_token: CancellationToken,
|
||||
@@ -48,6 +50,7 @@ impl ClientTracker {
|
||||
) -> Self {
|
||||
Self {
|
||||
clients: HashMap::new(),
|
||||
legacy_stream_ids: HashMap::new(),
|
||||
join_set: JoinSet::new(),
|
||||
server_event_tx,
|
||||
transport_event_tx,
|
||||
@@ -55,12 +58,12 @@ impl ClientTracker {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn bookkeep_join_set(&mut self) -> Option<ClientId> {
|
||||
pub(crate) async fn bookkeep_join_set(&mut self) -> Option<(ClientId, StreamId)> {
|
||||
while let Some(join_result) = self.join_set.join_next().await {
|
||||
let Ok(client_id) = join_result else {
|
||||
let Ok(client_key) = join_result else {
|
||||
continue;
|
||||
};
|
||||
return Some(client_id);
|
||||
return Some(client_key);
|
||||
}
|
||||
futures::future::pending().await
|
||||
}
|
||||
@@ -68,8 +71,8 @@ impl ClientTracker {
|
||||
pub(crate) async fn shutdown(&mut self) {
|
||||
self.shutdown_token.cancel();
|
||||
|
||||
while let Some(client_id) = self.clients.keys().next().cloned() {
|
||||
let _ = self.close_client(&client_id).await;
|
||||
while let Some(client_key) = self.clients.keys().next().cloned() {
|
||||
let _ = self.close_client(&client_key).await;
|
||||
}
|
||||
|
||||
self.drain_join_set().await;
|
||||
@@ -79,6 +82,10 @@ impl ClientTracker {
|
||||
while self.join_set.join_next().await.is_some() {}
|
||||
}
|
||||
|
||||
pub(crate) fn legacy_stream_id(&self, client_id: &ClientId) -> Option<StreamId> {
|
||||
self.legacy_stream_ids.get(client_id).cloned()
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_message(
|
||||
&mut self,
|
||||
client_envelope: ClientEnvelope,
|
||||
@@ -86,14 +93,40 @@ impl ClientTracker {
|
||||
let ClientEnvelope {
|
||||
client_id,
|
||||
event,
|
||||
stream_id,
|
||||
seq_id,
|
||||
cursor: _,
|
||||
} = client_envelope;
|
||||
let is_legacy_stream_id = stream_id.is_none();
|
||||
let is_initialize = matches!(&event, ClientEvent::ClientMessage { message } if remote_control_message_starts_connection(message));
|
||||
let stream_id = match stream_id {
|
||||
Some(stream_id) => stream_id,
|
||||
None if is_initialize => {
|
||||
// TODO(ruslan): delete this fallback once all clients are updated to send stream_id.
|
||||
self.legacy_stream_ids
|
||||
.remove(&client_id)
|
||||
.unwrap_or_else(StreamId::new_random)
|
||||
}
|
||||
None => self
|
||||
.legacy_stream_ids
|
||||
.get(&client_id)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| {
|
||||
if matches!(&event, ClientEvent::Ping) {
|
||||
StreamId::new_random()
|
||||
} else {
|
||||
StreamId(String::new())
|
||||
}
|
||||
}),
|
||||
};
|
||||
if stream_id.0.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let client_key = (client_id.clone(), stream_id.clone());
|
||||
match event {
|
||||
ClientEvent::ClientMessage { message } => {
|
||||
let is_initialize = remote_control_message_starts_connection(&message);
|
||||
if let Some(seq_id) = seq_id
|
||||
&& let Some(client) = self.clients.get(&client_id)
|
||||
&& let Some(client) = self.clients.get(&client_key)
|
||||
&& client
|
||||
.last_inbound_seq_id
|
||||
.is_some_and(|last_seq_id| last_seq_id >= seq_id)
|
||||
@@ -102,24 +135,22 @@ impl ClientTracker {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if is_initialize && self.clients.contains_key(&client_id) {
|
||||
self.close_client(&client_id).await?;
|
||||
if is_initialize && self.clients.contains_key(&client_key) {
|
||||
self.close_client(&client_key).await?;
|
||||
}
|
||||
|
||||
if let Some(connection_id) = self.clients.get_mut(&client_id).map(|client| {
|
||||
if let Some(connection_id) = self.clients.get_mut(&client_key).map(|client| {
|
||||
client.last_activity_at = Instant::now();
|
||||
if let Some(seq_id) = seq_id {
|
||||
client.last_inbound_seq_id = Some(seq_id);
|
||||
}
|
||||
client.connection_id
|
||||
}) {
|
||||
self.transport_event_tx
|
||||
.send(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message,
|
||||
})
|
||||
.await
|
||||
.map_err(|_| Stopped)?;
|
||||
self.send_transport_event(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message,
|
||||
})
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
@@ -131,33 +162,35 @@ impl ClientTracker {
|
||||
let (writer_tx, writer_rx) =
|
||||
mpsc::channel::<QueuedOutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let disconnect_token = self.shutdown_token.child_token();
|
||||
self.transport_event_tx
|
||||
.send(TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
writer: writer_tx,
|
||||
disconnect_sender: Some(disconnect_token.clone()),
|
||||
})
|
||||
.await
|
||||
.map_err(|_| Stopped)?;
|
||||
self.send_transport_event(TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
writer: writer_tx,
|
||||
disconnect_sender: Some(disconnect_token.clone()),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let (status_tx, status_rx) = watch::channel(PongStatus::Active);
|
||||
self.join_set.spawn(Self::run_client_outbound(
|
||||
client_id.clone(),
|
||||
stream_id.clone(),
|
||||
self.server_event_tx.clone(),
|
||||
writer_rx,
|
||||
status_rx,
|
||||
disconnect_token.clone(),
|
||||
));
|
||||
self.clients.insert(
|
||||
client_id,
|
||||
client_key,
|
||||
ClientState {
|
||||
connection_id,
|
||||
disconnect_token,
|
||||
last_activity_at: Instant::now(),
|
||||
last_inbound_seq_id: seq_id,
|
||||
last_inbound_seq_id: if is_legacy_stream_id { None } else { seq_id },
|
||||
status_tx,
|
||||
},
|
||||
);
|
||||
if is_legacy_stream_id {
|
||||
self.legacy_stream_ids.insert(client_id.clone(), stream_id);
|
||||
}
|
||||
self.send_transport_event(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message,
|
||||
@@ -166,7 +199,7 @@ impl ClientTracker {
|
||||
}
|
||||
ClientEvent::Ack => Ok(()),
|
||||
ClientEvent::Ping => {
|
||||
if let Some(client) = self.clients.get_mut(&client_id) {
|
||||
if let Some(client) = self.clients.get_mut(&client_key) {
|
||||
client.last_activity_at = Instant::now();
|
||||
let _ = client.status_tx.send(PongStatus::Active);
|
||||
return Ok(());
|
||||
@@ -179,23 +212,25 @@ impl ClientTracker {
|
||||
status: PongStatus::Unknown,
|
||||
},
|
||||
client_id,
|
||||
stream_id,
|
||||
write_complete_tx: None,
|
||||
};
|
||||
let _ = server_event_tx.send(server_envelope).await;
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
ClientEvent::ClientClosed => self.close_client(&client_id).await,
|
||||
ClientEvent::ClientClosed => self.close_client(&client_key).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_client_outbound(
|
||||
client_id: ClientId,
|
||||
stream_id: StreamId,
|
||||
server_event_tx: mpsc::Sender<QueuedServerEnvelope>,
|
||||
mut writer_rx: mpsc::Receiver<QueuedOutgoingMessage>,
|
||||
mut status_rx: watch::Receiver<PongStatus>,
|
||||
disconnect_token: CancellationToken,
|
||||
) -> ClientId {
|
||||
) -> (ClientId, StreamId) {
|
||||
loop {
|
||||
let (event, write_complete_tx) = tokio::select! {
|
||||
_ = disconnect_token.cancelled() => {
|
||||
@@ -225,6 +260,7 @@ impl ClientTracker {
|
||||
send_result = server_event_tx.send(QueuedServerEnvelope {
|
||||
event,
|
||||
client_id: client_id.clone(),
|
||||
stream_id: stream_id.clone(),
|
||||
write_complete_tx,
|
||||
}) => send_result,
|
||||
};
|
||||
@@ -232,28 +268,40 @@ impl ClientTracker {
|
||||
break;
|
||||
}
|
||||
}
|
||||
client_id
|
||||
(client_id, stream_id)
|
||||
}
|
||||
|
||||
pub(crate) async fn close_expired_clients(&mut self) -> Result<Vec<ClientId>, Stopped> {
|
||||
pub(crate) async fn close_expired_clients(
|
||||
&mut self,
|
||||
) -> Result<Vec<(ClientId, StreamId)>, Stopped> {
|
||||
let now = Instant::now();
|
||||
let expired_client_ids: Vec<ClientId> = self
|
||||
let expired_client_ids: Vec<(ClientId, StreamId)> = self
|
||||
.clients
|
||||
.iter()
|
||||
.filter_map(|(client_id, client)| {
|
||||
(!remote_control_client_is_alive(client, now)).then_some(client_id.clone())
|
||||
.filter_map(|(client_key, client)| {
|
||||
(!remote_control_client_is_alive(client, now)).then_some(client_key.clone())
|
||||
})
|
||||
.collect();
|
||||
for client_id in &expired_client_ids {
|
||||
self.close_client(client_id).await?;
|
||||
for client_key in &expired_client_ids {
|
||||
self.close_client(client_key).await?;
|
||||
}
|
||||
Ok(expired_client_ids)
|
||||
}
|
||||
|
||||
pub(super) async fn close_client(&mut self, client_id: &ClientId) -> Result<(), Stopped> {
|
||||
let Some(client) = self.clients.remove(client_id) else {
|
||||
pub(super) async fn close_client(
|
||||
&mut self,
|
||||
client_key: &(ClientId, StreamId),
|
||||
) -> Result<(), Stopped> {
|
||||
let Some(client) = self.clients.remove(client_key) else {
|
||||
return Ok(());
|
||||
};
|
||||
if self
|
||||
.legacy_stream_ids
|
||||
.get(&client_key.0)
|
||||
.is_some_and(|stream_id| stream_id == &client_key.1)
|
||||
{
|
||||
self.legacy_stream_ids.remove(&client_key.0);
|
||||
}
|
||||
client.disconnect_token.cancel();
|
||||
self.send_transport_event(TransportEvent::ConnectionClosed {
|
||||
connection_id: client.connection_id,
|
||||
@@ -296,6 +344,13 @@ mod tests {
|
||||
use tokio::time::timeout;
|
||||
|
||||
fn initialize_envelope(client_id: &str) -> ClientEnvelope {
|
||||
initialize_envelope_with_stream_id(client_id, None)
|
||||
}
|
||||
|
||||
fn initialize_envelope_with_stream_id(
|
||||
client_id: &str,
|
||||
stream_id: Option<&str>,
|
||||
) -> ClientEnvelope {
|
||||
ClientEnvelope {
|
||||
event: ClientEvent::ClientMessage {
|
||||
message: JSONRPCMessage::Request(JSONRPCRequest {
|
||||
@@ -311,6 +366,7 @@ mod tests {
|
||||
}),
|
||||
},
|
||||
client_id: ClientId(client_id.to_string()),
|
||||
stream_id: stream_id.map(|stream_id| StreamId(stream_id.to_string())),
|
||||
seq_id: Some(0),
|
||||
cursor: None,
|
||||
}
|
||||
@@ -358,7 +414,7 @@ mod tests {
|
||||
.await
|
||||
.expect("bookkeeping should process the closed task")
|
||||
.expect("closed task should return client id");
|
||||
assert_eq!(closed_client_id, ClientId("client-1".to_string()));
|
||||
assert_eq!(closed_client_id.0, ClientId("client-1".to_string()));
|
||||
client_tracker
|
||||
.close_client(&closed_client_id)
|
||||
.await
|
||||
@@ -390,6 +446,7 @@ mod tests {
|
||||
status: PongStatus::Unknown,
|
||||
},
|
||||
client_id: ClientId("queued-client".to_string()),
|
||||
stream_id: StreamId("queued-stream".to_string()),
|
||||
write_complete_tx: None,
|
||||
})
|
||||
.await
|
||||
@@ -431,4 +488,85 @@ mod tests {
|
||||
.await
|
||||
.expect("shutdown should not hang on blocked server forwarding");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn initialize_with_new_stream_id_opens_new_connection_for_same_client() {
|
||||
let (server_event_tx, _server_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let shutdown_token = CancellationToken::new();
|
||||
let mut client_tracker =
|
||||
ClientTracker::new(server_event_tx, transport_event_tx, &shutdown_token);
|
||||
|
||||
client_tracker
|
||||
.handle_message(initialize_envelope_with_stream_id(
|
||||
"client-1",
|
||||
Some("stream-1"),
|
||||
))
|
||||
.await
|
||||
.expect("first initialize should open client");
|
||||
let first_connection_id = match transport_event_rx.recv().await.expect("open event") {
|
||||
TransportEvent::ConnectionOpened { connection_id, .. } => connection_id,
|
||||
other => panic!("expected connection opened, got {other:?}"),
|
||||
};
|
||||
let _ = transport_event_rx.recv().await.expect("initialize event");
|
||||
|
||||
client_tracker
|
||||
.handle_message(initialize_envelope_with_stream_id(
|
||||
"client-1",
|
||||
Some("stream-2"),
|
||||
))
|
||||
.await
|
||||
.expect("second initialize should open client");
|
||||
let second_connection_id = match transport_event_rx.recv().await.expect("open event") {
|
||||
TransportEvent::ConnectionOpened { connection_id, .. } => connection_id,
|
||||
other => panic!("expected connection opened, got {other:?}"),
|
||||
};
|
||||
|
||||
assert_ne!(first_connection_id, second_connection_id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn legacy_initialize_without_stream_id_resets_inbound_seq_id() {
|
||||
let (server_event_tx, _server_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let shutdown_token = CancellationToken::new();
|
||||
let mut client_tracker =
|
||||
ClientTracker::new(server_event_tx, transport_event_tx, &shutdown_token);
|
||||
|
||||
client_tracker
|
||||
.handle_message(initialize_envelope("client-1"))
|
||||
.await
|
||||
.expect("initialize should open client");
|
||||
let connection_id = match transport_event_rx.recv().await.expect("open event") {
|
||||
TransportEvent::ConnectionOpened { connection_id, .. } => connection_id,
|
||||
other => panic!("expected connection opened, got {other:?}"),
|
||||
};
|
||||
let _ = transport_event_rx.recv().await.expect("initialize event");
|
||||
|
||||
client_tracker
|
||||
.handle_message(ClientEnvelope {
|
||||
event: ClientEvent::ClientMessage {
|
||||
message: JSONRPCMessage::Notification(
|
||||
codex_app_server_protocol::JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
},
|
||||
),
|
||||
},
|
||||
client_id: ClientId("client-1".to_string()),
|
||||
stream_id: None,
|
||||
seq_id: Some(0),
|
||||
cursor: None,
|
||||
})
|
||||
.await
|
||||
.expect("legacy followup should be forwarded");
|
||||
|
||||
match transport_event_rx.recv().await.expect("followup event") {
|
||||
TransportEvent::IncomingMessage {
|
||||
connection_id: incoming_connection_id,
|
||||
..
|
||||
} => assert_eq!(incoming_connection_id, connection_id),
|
||||
other => panic!("expected incoming message, got {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id";
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(super) struct RemoteControlEnrollment {
|
||||
pub(super) account_id: Option<String>,
|
||||
pub(super) environment_id: String,
|
||||
pub(super) server_id: String,
|
||||
pub(super) server_name: String,
|
||||
}
|
||||
@@ -47,11 +48,14 @@ pub(super) async fn load_persisted_remote_control_enrollment(
|
||||
}
|
||||
};
|
||||
|
||||
enrollment.map(|(server_id, server_name)| RemoteControlEnrollment {
|
||||
account_id: account_id.map(&str::to_string),
|
||||
server_id,
|
||||
server_name,
|
||||
})
|
||||
enrollment.map(
|
||||
|(server_id, environment_id, server_name)| RemoteControlEnrollment {
|
||||
account_id: account_id.map(&str::to_string),
|
||||
environment_id,
|
||||
server_id,
|
||||
server_name,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub(super) async fn update_persisted_remote_control_enrollment(
|
||||
@@ -77,6 +81,7 @@ pub(super) async fn update_persisted_remote_control_enrollment(
|
||||
&remote_control_target.websocket_url,
|
||||
account_id,
|
||||
&enrollment.server_id,
|
||||
&enrollment.environment_id,
|
||||
&enrollment.server_name,
|
||||
)
|
||||
.await
|
||||
@@ -182,6 +187,7 @@ pub(super) async fn enroll_remote_control_server(
|
||||
|
||||
Ok(RemoteControlEnrollment {
|
||||
account_id: account_id.map(&str::to_string),
|
||||
environment_id: enrollment.environment_id,
|
||||
server_id: enrollment.server_id,
|
||||
server_name,
|
||||
})
|
||||
@@ -221,11 +227,13 @@ mod tests {
|
||||
.expect("second target should parse");
|
||||
let first_enrollment = RemoteControlEnrollment {
|
||||
account_id: Some("account-a".to_string()),
|
||||
environment_id: "env_first".to_string(),
|
||||
server_id: "srv_e_first".to_string(),
|
||||
server_name: "first-server".to_string(),
|
||||
};
|
||||
let second_enrollment = RemoteControlEnrollment {
|
||||
account_id: Some("account-a".to_string()),
|
||||
environment_id: "env_second".to_string(),
|
||||
server_id: "srv_e_second".to_string(),
|
||||
server_name: "second-server".to_string(),
|
||||
};
|
||||
@@ -287,11 +295,13 @@ mod tests {
|
||||
.expect("second target should parse");
|
||||
let first_enrollment = RemoteControlEnrollment {
|
||||
account_id: Some("account-a".to_string()),
|
||||
environment_id: "env_first".to_string(),
|
||||
server_id: "srv_e_first".to_string(),
|
||||
server_name: "first-server".to_string(),
|
||||
};
|
||||
let second_enrollment = RemoteControlEnrollment {
|
||||
account_id: Some("account-a".to_string()),
|
||||
environment_id: "env_second".to_string(),
|
||||
server_id: "srv_e_second".to_string(),
|
||||
server_name: "second-server".to_string(),
|
||||
};
|
||||
|
||||
@@ -8,6 +8,7 @@ use crate::transport::remote_control::websocket::load_remote_control_auth;
|
||||
|
||||
pub use self::protocol::ClientId;
|
||||
use self::protocol::ServerEvent;
|
||||
use self::protocol::StreamId;
|
||||
use self::protocol::normalize_remote_control_url;
|
||||
use super::CHANNEL_CAPACITY;
|
||||
use super::TransportEvent;
|
||||
@@ -24,6 +25,7 @@ use tokio_util::sync::CancellationToken;
|
||||
pub(super) struct QueuedServerEnvelope {
|
||||
pub(super) event: ServerEvent,
|
||||
pub(super) client_id: ClientId,
|
||||
pub(super) stream_id: StreamId,
|
||||
pub(super) write_complete_tx: Option<oneshot::Sender<()>>,
|
||||
}
|
||||
|
||||
|
||||
@@ -24,12 +24,23 @@ pub(super) struct EnrollRemoteServerRequest {
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(super) struct EnrollRemoteServerResponse {
|
||||
pub(super) server_id: String,
|
||||
pub(super) environment_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct ClientId(pub String);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct StreamId(pub String);
|
||||
|
||||
impl StreamId {
|
||||
pub fn new_random() -> Self {
|
||||
Self(uuid::Uuid::now_v7().to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ClientEvent {
|
||||
@@ -44,13 +55,11 @@ pub enum ClientEvent {
|
||||
pub(crate) struct ClientEnvelope {
|
||||
#[serde(flatten)]
|
||||
pub(crate) event: ClientEvent,
|
||||
#[serde(rename = "client_id", alias = "clientId")]
|
||||
#[serde(rename = "client_id")]
|
||||
pub(crate) client_id: ClientId,
|
||||
#[serde(
|
||||
rename = "seq_id",
|
||||
alias = "seqId",
|
||||
skip_serializing_if = "Option::is_none"
|
||||
)]
|
||||
#[serde(rename = "stream_id", skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) stream_id: Option<StreamId>,
|
||||
#[serde(rename = "seq_id", skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) seq_id: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) cursor: Option<String>,
|
||||
@@ -81,9 +90,11 @@ pub enum ServerEvent {
|
||||
pub(crate) struct ServerEnvelope {
|
||||
#[serde(flatten)]
|
||||
pub(crate) event: ServerEvent,
|
||||
#[serde(rename = "client_id", alias = "clientId")]
|
||||
#[serde(rename = "client_id")]
|
||||
pub(crate) client_id: ClientId,
|
||||
#[serde(rename = "seq_id", alias = "seqId")]
|
||||
#[serde(rename = "stream_id")]
|
||||
pub(crate) stream_id: StreamId,
|
||||
#[serde(rename = "seq_id")]
|
||||
pub(crate) seq_id: u64,
|
||||
}
|
||||
|
||||
|
||||
@@ -95,7 +95,11 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages()
|
||||
enroll_request.request_line,
|
||||
"POST /backend-api/wham/remote/control/server/enroll HTTP/1.1"
|
||||
);
|
||||
respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await;
|
||||
respond_with_json(
|
||||
enroll_request.stream,
|
||||
json!({ "server_id": "srv_e_test", "environment_id": "env_test" }),
|
||||
)
|
||||
.await;
|
||||
let mut websocket = accept_remote_control_connection(&listener).await;
|
||||
|
||||
let client_id = ClientId("client-1".to_string());
|
||||
@@ -104,6 +108,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages()
|
||||
ClientEnvelope {
|
||||
event: ClientEvent::Ping,
|
||||
client_id: client_id.clone(),
|
||||
stream_id: None,
|
||||
seq_id: None,
|
||||
cursor: None,
|
||||
},
|
||||
@@ -131,6 +136,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages()
|
||||
),
|
||||
},
|
||||
client_id: client_id.clone(),
|
||||
stream_id: None,
|
||||
seq_id: Some(0),
|
||||
cursor: None,
|
||||
},
|
||||
@@ -161,6 +167,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages()
|
||||
message: initialize_message.clone(),
|
||||
},
|
||||
client_id: client_id.clone(),
|
||||
stream_id: None,
|
||||
seq_id: Some(1),
|
||||
cursor: None,
|
||||
},
|
||||
@@ -207,6 +214,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages()
|
||||
message: followup_message.clone(),
|
||||
},
|
||||
client_id: client_id.clone(),
|
||||
stream_id: None,
|
||||
seq_id: Some(2),
|
||||
cursor: None,
|
||||
},
|
||||
@@ -232,6 +240,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages()
|
||||
ClientEnvelope {
|
||||
event: ClientEvent::Ping,
|
||||
client_id: client_id.clone(),
|
||||
stream_id: None,
|
||||
seq_id: None,
|
||||
cursor: None,
|
||||
},
|
||||
@@ -281,6 +290,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages()
|
||||
ClientEnvelope {
|
||||
event: ClientEvent::ClientClosed,
|
||||
client_id: client_id.clone(),
|
||||
stream_id: None,
|
||||
seq_id: None,
|
||||
cursor: None,
|
||||
},
|
||||
@@ -304,6 +314,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages()
|
||||
ClientEnvelope {
|
||||
event: ClientEvent::Ping,
|
||||
client_id,
|
||||
stream_id: None,
|
||||
seq_id: None,
|
||||
cursor: None,
|
||||
},
|
||||
@@ -348,7 +359,11 @@ async fn remote_control_transport_reconnects_after_disconnect() {
|
||||
enroll_request.request_line,
|
||||
"POST /backend-api/wham/remote/control/server/enroll HTTP/1.1"
|
||||
);
|
||||
respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await;
|
||||
respond_with_json(
|
||||
enroll_request.stream,
|
||||
json!({ "server_id": "srv_e_test", "environment_id": "env_test" }),
|
||||
)
|
||||
.await;
|
||||
let mut first_websocket = accept_remote_control_connection(&listener).await;
|
||||
first_websocket
|
||||
.close(None)
|
||||
@@ -374,6 +389,7 @@ async fn remote_control_transport_reconnects_after_disconnect() {
|
||||
}),
|
||||
},
|
||||
client_id: ClientId("client-2".to_string()),
|
||||
stream_id: None,
|
||||
seq_id: Some(0),
|
||||
cursor: None,
|
||||
},
|
||||
@@ -414,7 +430,11 @@ async fn remote_control_transport_clears_outgoing_buffer_when_client_closes() {
|
||||
.expect("remote control should start");
|
||||
|
||||
let enroll_request = accept_http_request(&listener).await;
|
||||
respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await;
|
||||
respond_with_json(
|
||||
enroll_request.stream,
|
||||
json!({ "server_id": "srv_e_test", "environment_id": "env_test" }),
|
||||
)
|
||||
.await;
|
||||
let mut first_websocket = accept_remote_control_connection(&listener).await;
|
||||
|
||||
let client_id = ClientId("client-1".to_string());
|
||||
@@ -436,6 +456,7 @@ async fn remote_control_transport_clears_outgoing_buffer_when_client_closes() {
|
||||
message: initialize_message,
|
||||
},
|
||||
client_id: client_id.clone(),
|
||||
stream_id: None,
|
||||
seq_id: Some(0),
|
||||
cursor: None,
|
||||
},
|
||||
@@ -493,6 +514,7 @@ async fn remote_control_transport_clears_outgoing_buffer_when_client_closes() {
|
||||
ClientEnvelope {
|
||||
event: ClientEvent::ClientClosed,
|
||||
client_id: client_id.clone(),
|
||||
stream_id: None,
|
||||
seq_id: None,
|
||||
cursor: None,
|
||||
},
|
||||
@@ -519,6 +541,7 @@ async fn remote_control_transport_clears_outgoing_buffer_when_client_closes() {
|
||||
ClientEnvelope {
|
||||
event: ClientEvent::Ping,
|
||||
client_id,
|
||||
stream_id: None,
|
||||
seq_id: None,
|
||||
cursor: None,
|
||||
},
|
||||
@@ -582,7 +605,11 @@ async fn remote_control_http_mode_enrolls_before_connecting() {
|
||||
"app_server_version": env!("CARGO_PKG_VERSION"),
|
||||
})
|
||||
);
|
||||
respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await;
|
||||
respond_with_json(
|
||||
enroll_request.stream,
|
||||
json!({ "server_id": "srv_e_test", "environment_id": "env_test" }),
|
||||
)
|
||||
.await;
|
||||
|
||||
let (handshake_request, mut websocket) =
|
||||
accept_remote_control_backend_connection(&listener).await;
|
||||
@@ -634,6 +661,7 @@ async fn remote_control_http_mode_enrolls_before_connecting() {
|
||||
message: initialize_message.clone(),
|
||||
},
|
||||
client_id: backend_client_id.clone(),
|
||||
stream_id: None,
|
||||
seq_id: Some(0),
|
||||
cursor: None,
|
||||
},
|
||||
@@ -742,6 +770,7 @@ async fn remote_control_http_mode_reuses_persisted_enrollment_before_reenrolling
|
||||
normalize_remote_control_url(&remote_control_url).expect("target should parse");
|
||||
let persisted_enrollment = RemoteControlEnrollment {
|
||||
account_id: Some("account_id".to_string()),
|
||||
environment_id: "env_persisted".to_string(),
|
||||
server_id: "srv_e_persisted".to_string(),
|
||||
server_name: "persisted-server".to_string(),
|
||||
};
|
||||
@@ -803,11 +832,13 @@ async fn remote_control_http_mode_clears_stale_persisted_enrollment_after_404()
|
||||
let expected_server_name = gethostname().to_string_lossy().trim().to_string();
|
||||
let stale_enrollment = RemoteControlEnrollment {
|
||||
account_id: Some("account_id".to_string()),
|
||||
environment_id: "env_stale".to_string(),
|
||||
server_id: "srv_e_stale".to_string(),
|
||||
server_name: "stale-server".to_string(),
|
||||
};
|
||||
let refreshed_enrollment = RemoteControlEnrollment {
|
||||
account_id: Some("account_id".to_string()),
|
||||
environment_id: "env_refreshed".to_string(),
|
||||
server_id: "srv_e_refreshed".to_string(),
|
||||
server_name: expected_server_name,
|
||||
};
|
||||
@@ -851,7 +882,10 @@ async fn remote_control_http_mode_clears_stale_persisted_enrollment_after_404()
|
||||
);
|
||||
respond_with_json(
|
||||
enroll_request.stream,
|
||||
json!({ "server_id": refreshed_enrollment.server_id }),
|
||||
json!({
|
||||
"server_id": refreshed_enrollment.server_id,
|
||||
"environment_id": refreshed_enrollment.environment_id,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -1048,8 +1082,15 @@ async fn read_server_event(websocket: &mut WebSocketStream<TcpStream>) -> serde_
|
||||
.expect("websocket frame should be readable");
|
||||
match frame {
|
||||
tungstenite::Message::Text(text) => {
|
||||
return serde_json::from_str(text.as_ref())
|
||||
.expect("server event should deserialize");
|
||||
let mut event: serde_json::Value =
|
||||
serde_json::from_str(text.as_ref()).expect("server event should deserialize");
|
||||
if let Some(stream_id) = event
|
||||
.as_object_mut()
|
||||
.and_then(|event| event.remove("stream_id"))
|
||||
{
|
||||
assert!(stream_id.is_string(), "stream_id should be a string");
|
||||
}
|
||||
return event;
|
||||
}
|
||||
tungstenite::Message::Ping(payload) => {
|
||||
websocket
|
||||
|
||||
@@ -14,6 +14,7 @@ use super::protocol::ClientEvent;
|
||||
use super::protocol::ClientId;
|
||||
use super::protocol::RemoteControlTarget;
|
||||
use super::protocol::ServerEnvelope;
|
||||
use super::protocol::StreamId;
|
||||
use axum::http::HeaderValue;
|
||||
use base64::Engine;
|
||||
use codex_core::AuthManager;
|
||||
@@ -50,7 +51,7 @@ pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id";
|
||||
const REMOTE_CONTROL_SUBSCRIBE_CURSOR_HEADER: &str = "x-codex-subscribe-cursor";
|
||||
|
||||
struct BoundedOutboundBuffer {
|
||||
buffer_by_client: HashMap<ClientId, BTreeMap<u64, ServerEnvelope>>,
|
||||
buffer_by_client: HashMap<(ClientId, StreamId), BTreeMap<u64, ServerEnvelope>>,
|
||||
used_tx: watch::Sender<usize>,
|
||||
}
|
||||
|
||||
@@ -66,20 +67,29 @@ impl BoundedOutboundBuffer {
|
||||
|
||||
fn insert(&mut self, server_envelope: &ServerEnvelope) {
|
||||
self.buffer_by_client
|
||||
.entry(server_envelope.client_id.clone())
|
||||
.entry((
|
||||
server_envelope.client_id.clone(),
|
||||
server_envelope.stream_id.clone(),
|
||||
))
|
||||
.or_default()
|
||||
.insert(server_envelope.seq_id, server_envelope.clone());
|
||||
self.used_tx.send_modify(|used| *used += 1);
|
||||
}
|
||||
|
||||
fn remove(&mut self, client_id: &ClientId) {
|
||||
if let Some(buffer) = self.buffer_by_client.remove(client_id) {
|
||||
fn remove(&mut self, client_id: &ClientId, stream_id: &StreamId) {
|
||||
if let Some(buffer) = self
|
||||
.buffer_by_client
|
||||
.remove(&(client_id.clone(), stream_id.clone()))
|
||||
{
|
||||
self.used_tx.send_modify(|used| *used -= buffer.len());
|
||||
}
|
||||
}
|
||||
|
||||
fn ack(&mut self, client_id: &ClientId, acked_seq_id: u64) {
|
||||
let Some(buffer) = self.buffer_by_client.get_mut(client_id) else {
|
||||
fn ack(&mut self, client_id: &ClientId, stream_id: &StreamId, acked_seq_id: u64) {
|
||||
let Some(buffer) = self
|
||||
.buffer_by_client
|
||||
.get_mut(&(client_id.clone(), stream_id.clone()))
|
||||
else {
|
||||
return;
|
||||
};
|
||||
while let Some(seq_id) = buffer.first_key_value().map(|(seq_id, _)| seq_id)
|
||||
@@ -89,7 +99,8 @@ impl BoundedOutboundBuffer {
|
||||
self.used_tx.send_modify(|used| *used -= 1);
|
||||
}
|
||||
if buffer.is_empty() {
|
||||
self.buffer_by_client.remove(client_id);
|
||||
self.buffer_by_client
|
||||
.remove(&(client_id.clone(), stream_id.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -320,6 +331,7 @@ impl RemoteControlWebsocket {
|
||||
event: queued_server_envelope.event,
|
||||
client_id: queued_server_envelope.client_id,
|
||||
seq_id,
|
||||
stream_id: queued_server_envelope.stream_id,
|
||||
};
|
||||
state.outbound_buffer.insert(&server_envelope);
|
||||
|
||||
@@ -391,14 +403,14 @@ impl RemoteControlWebsocket {
|
||||
continue;
|
||||
}
|
||||
_ = idle_sweep_interval.tick() => {
|
||||
let expired_client_ids = match client_tracker.close_expired_clients().await {
|
||||
Ok(expired_client_ids) => expired_client_ids,
|
||||
let expired_client_keys = match client_tracker.close_expired_clients().await {
|
||||
Ok(expired_client_keys) => expired_client_keys,
|
||||
Err(_) => return Ok(()),
|
||||
};
|
||||
if !expired_client_ids.is_empty() {
|
||||
if !expired_client_keys.is_empty() {
|
||||
let mut state = state.lock().await;
|
||||
for client_id in expired_client_ids {
|
||||
state.outbound_buffer.remove(&client_id);
|
||||
for (client_id, stream_id) in expired_client_keys {
|
||||
state.outbound_buffer.remove(&client_id, &stream_id);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
@@ -441,22 +453,29 @@ impl RemoteControlWebsocket {
|
||||
}
|
||||
};
|
||||
|
||||
let resolved_stream_id = client_envelope
|
||||
.stream_id
|
||||
.clone()
|
||||
.or_else(|| client_tracker.legacy_stream_id(&client_envelope.client_id));
|
||||
let mut state = state.lock().await;
|
||||
if let Some(cursor) = client_envelope.cursor.as_deref() {
|
||||
state.subscribe_cursor = Some(cursor.to_string());
|
||||
}
|
||||
if let ClientEvent::Ack = &client_envelope.event
|
||||
&& let Some(acked_seq_id) = client_envelope.seq_id
|
||||
&& let Some(stream_id) = resolved_stream_id.as_ref()
|
||||
{
|
||||
state
|
||||
.outbound_buffer
|
||||
.ack(&client_envelope.client_id, acked_seq_id);
|
||||
}
|
||||
if matches!(&client_envelope.event, ClientEvent::ClientClosed)
|
||||
|| remote_control_message_starts_connection(&client_envelope.event)
|
||||
{
|
||||
state.outbound_buffer.remove(&client_envelope.client_id);
|
||||
.ack(&client_envelope.client_id, stream_id, acked_seq_id);
|
||||
}
|
||||
if (matches!(&client_envelope.event, ClientEvent::ClientClosed)
|
||||
|| remote_control_message_starts_connection(&client_envelope.event))
|
||||
&& let Some(stream_id) = resolved_stream_id.as_ref() {
|
||||
state
|
||||
.outbound_buffer
|
||||
.remove(&client_envelope.client_id, stream_id);
|
||||
}
|
||||
drop(state);
|
||||
|
||||
if client_tracker
|
||||
@@ -834,6 +853,7 @@ mod tests {
|
||||
let mut auth_recovery = auth_manager.unauthorized_recovery();
|
||||
let mut enrollment = Some(RemoteControlEnrollment {
|
||||
account_id: Some("account_id".to_string()),
|
||||
environment_id: "env_test".to_string(),
|
||||
server_id: "srv_e_test".to_string(),
|
||||
server_name: "test-server".to_string(),
|
||||
});
|
||||
@@ -888,6 +908,7 @@ mod tests {
|
||||
let mut auth_recovery = auth_manager.unauthorized_recovery();
|
||||
let mut enrollment = Some(RemoteControlEnrollment {
|
||||
account_id: Some("account_id".to_string()),
|
||||
environment_id: "env_test".to_string(),
|
||||
server_id: "srv_e_test".to_string(),
|
||||
server_name: "test-server".to_string(),
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user