mirror of
https://github.com/openai/codex.git
synced 2026-04-28 18:32:04 +03:00
## Why

`argument-comment-lint` was green in CI even though the repo still had many uncommented literal arguments. The main gap was target coverage: the repo wrapper did not force Cargo to inspect test-only call sites, so examples like the `latest_session_lookup_params(true, ...)` tests in `codex-rs/tui_app_server/src/lib.rs` never entered the blocking CI path. This change cleans up the existing backlog, makes the default repo lint path cover all Cargo targets, and starts rolling that stricter CI enforcement out on the platform where it is currently validated.

## What changed

- mechanically fixed existing `argument-comment-lint` violations across the `codex-rs` workspace, including tests, examples, and benches
- updated `tools/argument-comment-lint/run-prebuilt-linter.sh` and `tools/argument-comment-lint/run.sh` so non-`--fix` runs default to `--all-targets` unless the caller explicitly narrows the target set
- fixed both wrappers so forwarded cargo arguments after `--` are preserved with a single separator
- documented the new default behavior in `tools/argument-comment-lint/README.md`
- updated `rust-ci` so the macOS lint lane keeps the plain wrapper invocation and therefore enforces `--all-targets`, while Linux and Windows temporarily pass `-- --lib --bins`

That temporary CI split keeps the stricter all-targets check where it is already cleaned up, while leaving room to finish the remaining Linux- and Windows-specific target-gated cleanup before enabling `--all-targets` on those runners. The Linux and Windows failures on the intermediate revision were caused by the wrapper forwarding bug, not by additional lint findings in those lanes.

## Validation

- `bash -n tools/argument-comment-lint/run.sh`
- `bash -n tools/argument-comment-lint/run-prebuilt-linter.sh`
- shell-level wrapper forwarding check for `-- --lib --bins`
- shell-level wrapper forwarding check for `-- --tests`
- `just argument-comment-lint`
- `cargo test` in `tools/argument-comment-lint`
- `cargo test -p codex-terminal-detection`

## Follow-up

- Clean up remaining Linux-only target-gated callsites, then switch the Linux lint lane back to the plain wrapper invocation.
- Clean up remaining Windows-only target-gated callsites, then switch the Windows lint lane back to the plain wrapper invocation.
1168 lines
40 KiB
Rust
1168 lines
40 KiB
Rust
use std::collections::HashMap;
|
|
use std::fmt;
|
|
use std::sync::Arc;
|
|
use std::sync::atomic::AtomicI64;
|
|
use std::sync::atomic::Ordering;
|
|
|
|
use codex_app_server_protocol::JSONRPCErrorError;
|
|
use codex_app_server_protocol::RequestId;
|
|
use codex_app_server_protocol::Result;
|
|
use codex_app_server_protocol::ServerNotification;
|
|
use codex_app_server_protocol::ServerRequest;
|
|
use codex_app_server_protocol::ServerRequestPayload;
|
|
use codex_otel::span_w3c_trace_context;
|
|
use codex_protocol::ThreadId;
|
|
use codex_protocol::protocol::W3cTraceContext;
|
|
use serde::Serialize;
|
|
use tokio::sync::Mutex;
|
|
use tokio::sync::mpsc;
|
|
use tokio::sync::oneshot;
|
|
use tracing::Instrument;
|
|
use tracing::Span;
|
|
use tracing::warn;
|
|
|
|
use crate::error_code::INTERNAL_ERROR_CODE;
|
|
use crate::server_request_error::TURN_TRANSITION_PENDING_REQUEST_ERROR_REASON;
|
|
|
|
#[cfg(test)]
|
|
use codex_protocol::account::PlanType;
|
|
|
|
/// Outcome of a server→client request: the client's successful `Result`
/// payload, or the JSON-RPC error the client responded with.
pub(crate) type ClientRequestResult = std::result::Result<Result, JSONRPCErrorError>;
|
|
|
|
/// Stable identifier for a transport connection.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub(crate) struct ConnectionId(pub(crate) u64);

impl fmt::Display for ConnectionId {
    /// Render the id as its bare numeric value (e.g. `42`), delegating
    /// directly to `u64`'s `Display` implementation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
|
|
|
|
/// Stable identifier for a client request scoped to a transport connection.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub(crate) struct ConnectionRequestId {
    /// The connection the request arrived on.
    pub(crate) connection_id: ConnectionId,
    /// The JSON-RPC request id within that connection.
    pub(crate) request_id: RequestId,
}
|
|
|
|
/// Trace data we keep for an incoming request until we send its final
/// response or error.
#[derive(Clone)]
pub(crate) struct RequestContext {
    /// The (connection, JSON-RPC id) pair this context belongs to.
    request_id: ConnectionRequestId,
    /// Tracing span covering the request's lifetime.
    span: Span,
    /// Trace context supplied by the caller; used as a fallback when the
    /// span itself carries no W3C trace context (see `request_trace`).
    parent_trace: Option<W3cTraceContext>,
}
|
|
|
|
impl RequestContext {
|
|
pub(crate) fn new(
|
|
request_id: ConnectionRequestId,
|
|
span: Span,
|
|
parent_trace: Option<W3cTraceContext>,
|
|
) -> Self {
|
|
Self {
|
|
request_id,
|
|
span,
|
|
parent_trace,
|
|
}
|
|
}
|
|
|
|
pub(crate) fn request_trace(&self) -> Option<W3cTraceContext> {
|
|
span_w3c_trace_context(&self.span).or_else(|| self.parent_trace.clone())
|
|
}
|
|
|
|
pub(crate) fn span(&self) -> Span {
|
|
self.span.clone()
|
|
}
|
|
|
|
fn record_turn_id(&self, turn_id: &str) {
|
|
self.span.record("turn.id", turn_id);
|
|
}
|
|
}
|
|
|
|
/// A message queued for delivery, addressed either to a single connection
/// or to every connected client.
#[derive(Debug)]
pub(crate) enum OutgoingEnvelope {
    /// Deliver `message` to exactly one connection. When `write_complete_tx`
    /// is present, the transport signals it once the write has finished.
    ToConnection {
        connection_id: ConnectionId,
        message: OutgoingMessage,
        write_complete_tx: Option<oneshot::Sender<()>>,
    },
    /// Deliver `message` to all connections.
    Broadcast {
        message: OutgoingMessage,
    },
}
|
|
|
|
/// An outgoing message paired with an optional write-completion signal,
/// as held in a per-connection outgoing queue.
#[derive(Debug)]
pub(crate) struct QueuedOutgoingMessage {
    /// The message to write to the client.
    pub(crate) message: OutgoingMessage,
    /// When present, signalled after the message has been written.
    pub(crate) write_complete_tx: Option<oneshot::Sender<()>>,
}
|
|
|
|
impl QueuedOutgoingMessage {
|
|
pub(crate) fn new(message: OutgoingMessage) -> Self {
|
|
Self {
|
|
message,
|
|
write_complete_tx: None,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Sends messages to the client and manages request callbacks.
pub(crate) struct OutgoingMessageSender {
    /// Monotonically increasing counter used to mint ids for server→client
    /// requests.
    next_server_request_id: AtomicI64,
    /// Channel to the transport layer that performs the actual writes.
    sender: mpsc::Sender<OutgoingEnvelope>,
    /// Server→client requests still awaiting a reply, keyed by the request id
    /// we minted for them.
    request_id_to_callback: Mutex<HashMap<RequestId, PendingCallbackEntry>>,
    /// Incoming requests that are still waiting on a final response or error.
    /// We keep them here because this is where responses, errors, and
    /// disconnect cleanup all get handled.
    request_contexts: Mutex<HashMap<ConnectionRequestId, RequestContext>>,
}
|
|
|
|
/// A view over [`OutgoingMessageSender`] bound to one thread and a fixed set
/// of connections: requests/notifications it sends are targeted at those
/// connections and tagged with the thread id.
#[derive(Clone)]
pub(crate) struct ThreadScopedOutgoingMessageSender {
    /// The underlying sender shared with the rest of the server.
    outgoing: Arc<OutgoingMessageSender>,
    /// Connections this thread's traffic is addressed to.
    connection_ids: Arc<Vec<ConnectionId>>,
    /// The thread all requests sent through this handle belong to.
    thread_id: ThreadId,
}
|
|
|
|
/// Bookkeeping for a server→client request that is awaiting the client's reply.
struct PendingCallbackEntry {
    /// Channel that delivers the client's response or error to the waiter.
    callback: oneshot::Sender<ClientRequestResult>,
    /// Thread the request is scoped to, if any; used by the per-thread
    /// replay and cancellation paths.
    thread_id: Option<ThreadId>,
    /// The request as sent, kept so `replay_requests_to_connection_for_thread`
    /// can re-send it.
    request: ServerRequest,
}
|
|
|
|
impl ThreadScopedOutgoingMessageSender {
    /// Create a sender whose traffic targets `connection_ids` and is tagged
    /// with `thread_id`.
    pub(crate) fn new(
        outgoing: Arc<OutgoingMessageSender>,
        connection_ids: Vec<ConnectionId>,
        thread_id: ThreadId,
    ) -> Self {
        Self {
            outgoing,
            connection_ids: Arc::new(connection_ids),
            thread_id,
        }
    }

    /// Send a server→client request to this thread's connections; returns the
    /// minted request id plus a receiver for the client's eventual reply.
    pub(crate) async fn send_request(
        &self,
        payload: ServerRequestPayload,
    ) -> (RequestId, oneshot::Receiver<ClientRequestResult>) {
        self.outgoing
            .send_request_to_connections(
                Some(self.connection_ids.as_slice()),
                payload,
                Some(self.thread_id),
            )
            .await
    }

    /// Send `notification` only to this thread's connections. A no-op when no
    /// connections are attached (an empty slice would otherwise be treated as
    /// a broadcast by the underlying sender).
    pub(crate) async fn send_server_notification(&self, notification: ServerNotification) {
        if self.connection_ids.is_empty() {
            return;
        }
        self.outgoing
            .send_server_notification_to_connections(self.connection_ids.as_slice(), notification)
            .await;
    }

    /// Broadcast `notification` to every connection, not just this thread's.
    pub(crate) async fn send_global_server_notification(&self, notification: ServerNotification) {
        self.outgoing.send_server_notification(notification).await;
    }

    /// Resolve all of this thread's pending server→client requests with an
    /// internal error indicating the turn state changed underneath them.
    pub(crate) async fn abort_pending_server_requests(&self) {
        self.outgoing
            .cancel_requests_for_thread(
                self.thread_id,
                Some(JSONRPCErrorError {
                    code: INTERNAL_ERROR_CODE,
                    message: "client request resolved because the turn state was changed"
                        .to_string(),
                    data: Some(serde_json::json!({ "reason": TURN_TRANSITION_PENDING_REQUEST_ERROR_REASON })),
                }),
            )
            .await
    }

    /// Forward a response for an incoming client request to the underlying
    /// sender, which routes it back to the originating connection.
    pub(crate) async fn send_response<T: Serialize>(
        &self,
        request_id: ConnectionRequestId,
        response: T,
    ) {
        self.outgoing.send_response(request_id, response).await;
    }

    /// Forward an error for an incoming client request to the underlying
    /// sender, which routes it back to the originating connection.
    pub(crate) async fn send_error(
        &self,
        request_id: ConnectionRequestId,
        error: JSONRPCErrorError,
    ) {
        self.outgoing.send_error(request_id, error).await;
    }
}
|
|
|
|
impl OutgoingMessageSender {
    /// Create a sender that writes envelopes into `sender`, with no pending
    /// callbacks or request contexts and request ids starting at 0.
    pub(crate) fn new(sender: mpsc::Sender<OutgoingEnvelope>) -> Self {
        Self {
            next_server_request_id: AtomicI64::new(0),
            sender,
            request_id_to_callback: Mutex::new(HashMap::new()),
            request_contexts: Mutex::new(HashMap::new()),
        }
    }

    /// Track trace context for an incoming request until its final response
    /// or error is sent. Replacing an existing entry for the same id means a
    /// prior request was never resolved, so we log it.
    pub(crate) async fn register_request_context(&self, request_context: RequestContext) {
        let mut request_contexts = self.request_contexts.lock().await;
        if request_contexts
            .insert(request_context.request_id.clone(), request_context)
            .is_some()
        {
            warn!("replaced unresolved request context");
        }
    }

    /// Drop all request contexts belonging to a connection that went away;
    /// those requests can no longer receive a response.
    pub(crate) async fn connection_closed(&self, connection_id: ConnectionId) {
        let mut request_contexts = self.request_contexts.lock().await;
        request_contexts.retain(|request_id, _| request_id.connection_id != connection_id);
    }

    /// Look up the W3C trace context for a registered incoming request, if
    /// one is still tracked and carries trace data.
    pub(crate) async fn request_trace_context(
        &self,
        request_id: &ConnectionRequestId,
    ) -> Option<W3cTraceContext> {
        let request_contexts = self.request_contexts.lock().await;
        request_contexts
            .get(request_id)
            .and_then(RequestContext::request_trace)
    }

    /// Record the turn id on the span of a tracked incoming request.
    /// Silently a no-op when the request is no longer tracked.
    pub(crate) async fn record_request_turn_id(
        &self,
        request_id: &ConnectionRequestId,
        turn_id: &str,
    ) {
        let request_contexts = self.request_contexts.lock().await;
        if let Some(request_context) = request_contexts.get(request_id) {
            request_context.record_turn_id(turn_id);
        }
    }

    /// Remove and return the tracked context for an incoming request; called
    /// when the final response or error is about to be sent.
    async fn take_request_context(
        &self,
        request_id: &ConnectionRequestId,
    ) -> Option<RequestContext> {
        let mut request_contexts = self.request_contexts.lock().await;
        request_contexts.remove(request_id)
    }

    /// Number of incoming requests still awaiting a final response (test-only).
    #[cfg(test)]
    async fn request_context_count(&self) -> usize {
        self.request_contexts.lock().await.len()
    }

    /// Broadcast a server→client request to all connections, with no thread
    /// scoping. See `send_request_to_connections` for the returned pair.
    pub(crate) async fn send_request(
        &self,
        request: ServerRequestPayload,
    ) -> (RequestId, oneshot::Receiver<ClientRequestResult>) {
        self.send_request_to_connections(
            /*connection_ids*/ None, request, /*thread_id*/ None,
        )
        .await
    }

    /// Mint the next server→client request id. Relaxed ordering is enough:
    /// we only need uniqueness, not ordering with other memory operations.
    fn next_request_id(&self) -> RequestId {
        RequestId::Integer(self.next_server_request_id.fetch_add(1, Ordering::Relaxed))
    }

    /// Send a server→client request, either broadcast (`connection_ids` =
    /// `None`) or targeted at the given connections. Registers the pending
    /// callback *before* sending so a fast client reply cannot race it, and
    /// unregisters it again if the send fails. Returns the minted request id
    /// and a receiver for the client's eventual reply; the receiver yields an
    /// error if the callback is dropped (e.g. after a failed send or a
    /// cancellation without an error payload).
    async fn send_request_to_connections(
        &self,
        connection_ids: Option<&[ConnectionId]>,
        request: ServerRequestPayload,
        thread_id: Option<ThreadId>,
    ) -> (RequestId, oneshot::Receiver<ClientRequestResult>) {
        let id = self.next_request_id();
        let outgoing_message_id = id.clone();
        let request = request.request_with_id(outgoing_message_id.clone());

        let (tx_approve, rx_approve) = oneshot::channel();
        {
            // Scope the lock so it is released before awaiting the channel send.
            let mut request_id_to_callback = self.request_id_to_callback.lock().await;
            request_id_to_callback.insert(
                id,
                PendingCallbackEntry {
                    callback: tx_approve,
                    thread_id,
                    request: request.clone(),
                },
            );
        }

        let outgoing_message = OutgoingMessage::Request(request);
        let send_result = match connection_ids {
            None => {
                self.sender
                    .send(OutgoingEnvelope::Broadcast {
                        message: outgoing_message,
                    })
                    .await
            }
            Some(connection_ids) => {
                // Stop at the first failing connection; remaining targets are
                // skipped and the whole send is reported as failed.
                let mut send_error = None;
                for connection_id in connection_ids {
                    if let Err(err) = self
                        .sender
                        .send(OutgoingEnvelope::ToConnection {
                            connection_id: *connection_id,
                            message: outgoing_message.clone(),
                            write_complete_tx: None,
                        })
                        .await
                    {
                        send_error = Some(err);
                        break;
                    }
                }
                match send_error {
                    Some(err) => Err(err),
                    None => Ok(()),
                }
            }
        };

        if let Err(err) = send_result {
            warn!("failed to send request {outgoing_message_id:?} to client: {err:?}");
            // Roll back the pending callback; dropping it closes rx_approve.
            let mut request_id_to_callback = self.request_id_to_callback.lock().await;
            request_id_to_callback.remove(&outgoing_message_id);
        }
        (outgoing_message_id, rx_approve)
    }

    /// Re-send every still-pending request for `thread_id` to a single
    /// connection, in request-id order. Failures are logged per request and
    /// do not stop the replay of the remaining ones.
    pub(crate) async fn replay_requests_to_connection_for_thread(
        &self,
        connection_id: ConnectionId,
        thread_id: ThreadId,
    ) {
        let requests = self.pending_requests_for_thread(thread_id).await;
        for request in requests {
            if let Err(err) = self
                .sender
                .send(OutgoingEnvelope::ToConnection {
                    connection_id,
                    message: OutgoingMessage::Request(request),
                    write_complete_tx: None,
                })
                .await
            {
                warn!("failed to resend request to client: {err:?}");
            }
        }
    }

    /// Deliver the client's successful reply for request `id` to whoever is
    /// awaiting it. Logs when no callback is registered (unknown or already
    /// resolved id) or when the waiter has gone away.
    pub(crate) async fn notify_client_response(&self, id: RequestId, result: Result) {
        let entry = self.take_request_callback(&id).await;

        match entry {
            Some((id, entry)) => {
                if let Err(err) = entry.callback.send(Ok(result)) {
                    warn!("could not notify callback for {id:?} due to: {err:?}");
                }
            }
            None => {
                warn!("could not find callback for {id:?}");
            }
        }
    }

    /// Deliver the client's error reply for request `id` to whoever is
    /// awaiting it. Logs when no callback is registered or when the waiter
    /// has gone away.
    pub(crate) async fn notify_client_error(&self, id: RequestId, error: JSONRPCErrorError) {
        let entry = self.take_request_callback(&id).await;

        match entry {
            Some((id, entry)) => {
                warn!("client responded with error for {id:?}: {error:?}");
                if let Err(err) = entry.callback.send(Err(error)) {
                    warn!("could not notify callback for {id:?} due to: {err:?}");
                }
            }
            None => {
                warn!("could not find callback for {id:?}");
            }
        }
    }

    /// Drop the pending callback for `id`, closing the waiter's receiver.
    /// Returns whether a callback was actually registered.
    pub(crate) async fn cancel_request(&self, id: &RequestId) -> bool {
        self.take_request_callback(id).await.is_some()
    }

    /// Cancel every pending server→client request. With `Some(error)` each
    /// waiter receives that error; with `None` the callbacks are simply
    /// dropped, which closes the waiters' receivers.
    pub(crate) async fn cancel_all_requests(&self, error: Option<JSONRPCErrorError>) {
        // Drain under the lock, then notify after releasing it.
        let entries = {
            let mut request_id_to_callback = self.request_id_to_callback.lock().await;
            request_id_to_callback
                .drain()
                .map(|(_, entry)| entry)
                .collect::<Vec<_>>()
        };

        if let Some(error) = error {
            for entry in entries {
                if let Err(err) = entry.callback.send(Err(error.clone())) {
                    let request_id = entry.request.id();
                    warn!("could not notify callback for {request_id:?} due to: {err:?}");
                }
            }
        }
    }

    /// Remove and return the pending callback entry for `id`, if any.
    async fn take_request_callback(
        &self,
        id: &RequestId,
    ) -> Option<(RequestId, PendingCallbackEntry)> {
        let mut request_id_to_callback = self.request_id_to_callback.lock().await;
        request_id_to_callback.remove_entry(id)
    }

    /// Snapshot of all pending server→client requests scoped to `thread_id`,
    /// sorted by request id so replay order matches send order.
    pub(crate) async fn pending_requests_for_thread(
        &self,
        thread_id: ThreadId,
    ) -> Vec<ServerRequest> {
        let request_id_to_callback = self.request_id_to_callback.lock().await;
        let mut requests = request_id_to_callback
            .iter()
            .filter_map(|(_, entry)| {
                (entry.thread_id == Some(thread_id)).then_some(entry.request.clone())
            })
            .collect::<Vec<_>>();
        requests.sort_by(|left, right| left.id().cmp(right.id()));
        requests
    }

    /// Cancel every pending request scoped to `thread_id`. With `Some(error)`
    /// each waiter receives that error; with `None` the callbacks are just
    /// dropped, closing the waiters' receivers.
    pub(crate) async fn cancel_requests_for_thread(
        &self,
        thread_id: ThreadId,
        error: Option<JSONRPCErrorError>,
    ) {
        // Collect the matching ids first, then remove them, all under one
        // lock acquisition; notifications happen after the lock is released.
        let entries = {
            let mut request_id_to_callback = self.request_id_to_callback.lock().await;
            let request_ids = request_id_to_callback
                .iter()
                .filter_map(|(request_id, entry)| {
                    (entry.thread_id == Some(thread_id)).then_some(request_id.clone())
                })
                .collect::<Vec<_>>();

            let mut entries = Vec::with_capacity(request_ids.len());
            for request_id in request_ids {
                if let Some(entry) = request_id_to_callback.remove(&request_id) {
                    entries.push(entry);
                }
            }
            entries
        };

        if let Some(error) = error {
            for entry in entries {
                if let Err(err) = entry.callback.send(Err(error.clone())) {
                    let request_id = entry.request.id();
                    warn!("could not notify callback for {request_id:?} due to: {err:?}",);
                }
            }
        }
    }

    /// Send the final response for an incoming client request back to its
    /// originating connection, clearing the tracked request context. If the
    /// response cannot be serialized to JSON, an internal error is sent to
    /// the client instead.
    pub(crate) async fn send_response<T: Serialize>(
        &self,
        request_id: ConnectionRequestId,
        response: T,
    ) {
        let request_context = self.take_request_context(&request_id).await;
        match serde_json::to_value(response) {
            Ok(result) => {
                let outgoing_message = OutgoingMessage::Response(OutgoingResponse {
                    id: request_id.request_id.clone(),
                    result,
                });
                self.send_outgoing_message_to_connection(
                    request_context,
                    request_id.connection_id,
                    outgoing_message,
                    "response",
                )
                .await;
            }
            Err(err) => {
                self.send_error_inner(
                    request_context,
                    request_id,
                    JSONRPCErrorError {
                        code: INTERNAL_ERROR_CODE,
                        message: format!("failed to serialize response: {err}"),
                        data: None,
                    },
                )
                .await;
            }
        }
    }

    /// Broadcast `notification` to every connection (an empty target slice
    /// means broadcast in `send_server_notification_to_connections`).
    pub(crate) async fn send_server_notification(&self, notification: ServerNotification) {
        self.send_server_notification_to_connections(&[], notification)
            .await;
    }

    /// Send `notification` to each listed connection, or broadcast it when
    /// `connection_ids` is empty. Per-connection send failures are logged
    /// and do not stop delivery to the remaining connections.
    pub(crate) async fn send_server_notification_to_connections(
        &self,
        connection_ids: &[ConnectionId],
        notification: ServerNotification,
    ) {
        tracing::trace!(
            targeted_connections = connection_ids.len(),
            "app-server event: {notification}"
        );
        let outgoing_message = OutgoingMessage::AppServerNotification(notification);
        if connection_ids.is_empty() {
            if let Err(err) = self
                .sender
                .send(OutgoingEnvelope::Broadcast {
                    message: outgoing_message,
                })
                .await
            {
                warn!("failed to send server notification to client: {err:?}");
            }
            return;
        }
        for connection_id in connection_ids {
            if let Err(err) = self
                .sender
                .send(OutgoingEnvelope::ToConnection {
                    connection_id: *connection_id,
                    message: outgoing_message.clone(),
                    write_complete_tx: None,
                })
                .await
            {
                warn!("failed to send server notification to client: {err:?}");
            }
        }
    }

    /// Send `notification` to one connection and wait until the transport
    /// signals that the write completed (or the signal channel is dropped,
    /// e.g. on send failure — either way this returns rather than hanging).
    pub(crate) async fn send_server_notification_to_connection_and_wait(
        &self,
        connection_id: ConnectionId,
        notification: ServerNotification,
    ) {
        tracing::trace!("app-server event: {notification}");
        let outgoing_message = OutgoingMessage::AppServerNotification(notification);
        let (write_complete_tx, write_complete_rx) = oneshot::channel();
        if let Err(err) = self
            .sender
            .send(OutgoingEnvelope::ToConnection {
                connection_id,
                message: outgoing_message,
                write_complete_tx: Some(write_complete_tx),
            })
            .await
        {
            warn!("failed to send server notification to client: {err:?}");
        }
        let _ = write_complete_rx.await;
    }

    /// Send a final error for an incoming client request back to its
    /// originating connection, clearing the tracked request context.
    pub(crate) async fn send_error(
        &self,
        request_id: ConnectionRequestId,
        error: JSONRPCErrorError,
    ) {
        let request_context = self.take_request_context(&request_id).await;
        self.send_error_inner(request_context, request_id, error)
            .await;
    }

    /// Build the error envelope and hand it to the shared targeted-send path.
    /// `request_context` is whatever was already removed from tracking (may
    /// be `None` for untracked requests).
    async fn send_error_inner(
        &self,
        request_context: Option<RequestContext>,
        request_id: ConnectionRequestId,
        error: JSONRPCErrorError,
    ) {
        let outgoing_message = OutgoingMessage::Error(OutgoingError {
            id: request_id.request_id,
            error,
        });
        self.send_outgoing_message_to_connection(
            request_context,
            request_id.connection_id,
            outgoing_message,
            "error",
        )
        .await;
    }

    /// Send one message to one connection, instrumenting the send with the
    /// request's span when a context is available. `message_kind` is only
    /// used for the failure log line.
    async fn send_outgoing_message_to_connection(
        &self,
        request_context: Option<RequestContext>,
        connection_id: ConnectionId,
        message: OutgoingMessage,
        message_kind: &'static str,
    ) {
        let send_fut = self.sender.send(OutgoingEnvelope::ToConnection {
            connection_id,
            message,
            write_complete_tx: None,
        });
        let send_result = if let Some(request_context) = request_context {
            send_fut.instrument(request_context.span()).await
        } else {
            send_fut.await
        };

        if let Err(err) = send_result {
            warn!("failed to send {message_kind} to client: {err:?}");
        }
    }
}
|
|
|
|
/// Outgoing message from the server to the client.
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
pub(crate) enum OutgoingMessage {
    /// A server-initiated request expecting a client reply.
    Request(ServerRequest),
    /// AppServerNotification is specific to the case where this is run as an
    /// "app server" as opposed to an MCP server.
    AppServerNotification(ServerNotification),
    /// Successful reply to an incoming client request.
    Response(OutgoingResponse),
    /// Error reply to an incoming client request.
    Error(OutgoingError),
}
|
|
|
|
/// JSON-RPC success response payload sent back to the client.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub(crate) struct OutgoingResponse {
    /// Id of the client request being answered.
    pub id: RequestId,
    /// Serialized result payload.
    pub result: Result,
}
|
|
|
|
/// JSON-RPC error response payload sent back to the client.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub(crate) struct OutgoingError {
    /// The error being reported.
    pub error: JSONRPCErrorError,
    /// Id of the client request being answered.
    pub id: RequestId,
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use std::time::Duration;
|
|
|
|
use codex_app_server_protocol::AccountLoginCompletedNotification;
|
|
use codex_app_server_protocol::AccountRateLimitsUpdatedNotification;
|
|
use codex_app_server_protocol::AccountUpdatedNotification;
|
|
use codex_app_server_protocol::ApplyPatchApprovalParams;
|
|
use codex_app_server_protocol::AuthMode;
|
|
use codex_app_server_protocol::ConfigWarningNotification;
|
|
use codex_app_server_protocol::DynamicToolCallParams;
|
|
use codex_app_server_protocol::FileChangeRequestApprovalParams;
|
|
use codex_app_server_protocol::ModelRerouteReason;
|
|
use codex_app_server_protocol::ModelReroutedNotification;
|
|
use codex_app_server_protocol::RateLimitSnapshot;
|
|
use codex_app_server_protocol::RateLimitWindow;
|
|
use codex_app_server_protocol::ToolRequestUserInputParams;
|
|
use codex_protocol::ThreadId;
|
|
use pretty_assertions::assert_eq;
|
|
use serde_json::json;
|
|
use std::sync::Arc;
|
|
use tokio::time::timeout;
|
|
use uuid::Uuid;
|
|
|
|
use super::*;
|
|
|
|
// NOTE(review): this test is byte-for-byte identical (apart from the expect
// messages) to `verify_account_login_completed_notification_serialization`
// below — consider consolidating or pointing this one at a different
// notification variant.
#[test]
fn verify_server_notification_serialization() {
    let notification =
        ServerNotification::AccountLoginCompleted(AccountLoginCompletedNotification {
            login_id: Some(Uuid::nil().to_string()),
            success: true,
            error: None,
        });

    // The `method` string comes from the strum-derived variant name mapping.
    let jsonrpc_notification = OutgoingMessage::AppServerNotification(notification);
    assert_eq!(
        json!({
            "method": "account/login/completed",
            "params": {
                "loginId": Uuid::nil().to_string(),
                "success": true,
                "error": null,
            },
        }),
        serde_json::to_value(jsonrpc_notification)
            .expect("ensure the strum macros serialize the method field correctly"),
        "ensure the strum macros serialize the method field correctly"
    );
}
|
|
|
|
// Verifies the wire shape of `account/login/completed`: camelCase params and
// an explicit `null` for the absent error.
#[test]
fn verify_account_login_completed_notification_serialization() {
    let notification =
        ServerNotification::AccountLoginCompleted(AccountLoginCompletedNotification {
            login_id: Some(Uuid::nil().to_string()),
            success: true,
            error: None,
        });

    let jsonrpc_notification = OutgoingMessage::AppServerNotification(notification);
    assert_eq!(
        json!({
            "method": "account/login/completed",
            "params": {
                "loginId": Uuid::nil().to_string(),
                "success": true,
                "error": null,
            },
        }),
        serde_json::to_value(jsonrpc_notification)
            .expect("ensure the notification serializes correctly"),
        "ensure the notification serializes correctly"
    );
}
|
|
|
|
// Verifies the wire shape of `account/rateLimits/updated`, including nested
// rate-limit windows, `null` for absent optional fields, and the camelCase
// `planType` enum rendering.
#[test]
fn verify_account_rate_limits_notification_serialization() {
    let notification =
        ServerNotification::AccountRateLimitsUpdated(AccountRateLimitsUpdatedNotification {
            rate_limits: RateLimitSnapshot {
                limit_id: Some("codex".to_string()),
                limit_name: None,
                primary: Some(RateLimitWindow {
                    used_percent: 25,
                    window_duration_mins: Some(15),
                    resets_at: Some(123),
                }),
                secondary: None,
                credits: None,
                plan_type: Some(PlanType::Plus),
            },
        });

    let jsonrpc_notification = OutgoingMessage::AppServerNotification(notification);
    assert_eq!(
        json!({
            "method": "account/rateLimits/updated",
            "params": {
                "rateLimits": {
                    "limitId": "codex",
                    "limitName": null,
                    "primary": {
                        "usedPercent": 25,
                        "windowDurationMins": 15,
                        "resetsAt": 123
                    },
                    "secondary": null,
                    "credits": null,
                    "planType": "plus"
                }
            },
        }),
        serde_json::to_value(jsonrpc_notification)
            .expect("ensure the notification serializes correctly"),
        "ensure the notification serializes correctly"
    );
}
|
|
|
|
// Verifies the wire shape of `account/updated`: the lowercase `apikey` auth
// mode rendering and `null` for the absent plan type.
#[test]
fn verify_account_updated_notification_serialization() {
    let notification = ServerNotification::AccountUpdated(AccountUpdatedNotification {
        auth_mode: Some(AuthMode::ApiKey),
        plan_type: None,
    });

    let jsonrpc_notification = OutgoingMessage::AppServerNotification(notification);
    assert_eq!(
        json!({
            "method": "account/updated",
            "params": {
                "authMode": "apikey",
                "planType": null
            },
        }),
        serde_json::to_value(jsonrpc_notification)
            .expect("ensure the notification serializes correctly"),
        "ensure the notification serializes correctly"
    );
}
|
|
|
|
// Verifies the wire shape of `configWarning`. Note the expected params omit
// `path` and `range` entirely (rather than rendering them as `null`), so the
// absent optional fields must be skipped during serialization.
#[test]
fn verify_config_warning_notification_serialization() {
    let notification = ServerNotification::ConfigWarning(ConfigWarningNotification {
        summary: "Config error: using defaults".to_string(),
        details: Some("error loading config: bad config".to_string()),
        path: None,
        range: None,
    });

    let jsonrpc_notification = OutgoingMessage::AppServerNotification(notification);
    assert_eq!(
        json!( {
            "method": "configWarning",
            "params": {
                "summary": "Config error: using defaults",
                "details": "error loading config: bad config",
            },
        }),
        serde_json::to_value(jsonrpc_notification)
            .expect("ensure the notification serializes correctly"),
        "ensure the notification serializes correctly"
    );
}
|
|
|
|
// Verifies the wire shape of `model/rerouted`: camelCase params and the
// camelCase rendering of the reroute-reason enum variant.
#[test]
fn verify_model_rerouted_notification_serialization() {
    let notification = ServerNotification::ModelRerouted(ModelReroutedNotification {
        thread_id: "thread-1".to_string(),
        turn_id: "turn-1".to_string(),
        from_model: "gpt-5.3-codex".to_string(),
        to_model: "gpt-5.2".to_string(),
        reason: ModelRerouteReason::HighRiskCyberActivity,
    });

    let jsonrpc_notification = OutgoingMessage::AppServerNotification(notification);
    assert_eq!(
        json!({
            "method": "model/rerouted",
            "params": {
                "threadId": "thread-1",
                "turnId": "turn-1",
                "fromModel": "gpt-5.3-codex",
                "toModel": "gpt-5.2",
                "reason": "highRiskCyberActivity",
            },
        }),
        serde_json::to_value(jsonrpc_notification)
            .expect("ensure the notification serializes correctly"),
        "ensure the notification serializes correctly"
    );
}
|
|
|
|
// `send_response` must produce a `ToConnection` envelope addressed to the
// request's originating connection, carrying the original request id and the
// serialized result.
#[tokio::test]
async fn send_response_routes_to_target_connection() {
    let (tx, mut rx) = mpsc::channel::<OutgoingEnvelope>(4);
    let outgoing = OutgoingMessageSender::new(tx);
    let request_id = ConnectionRequestId {
        connection_id: ConnectionId(42),
        request_id: RequestId::Integer(7),
    };

    outgoing
        .send_response(request_id.clone(), json!({ "ok": true }))
        .await;

    // The envelope should already be queued; the timeout guards against a hang.
    let envelope = timeout(Duration::from_secs(1), rx.recv())
        .await
        .expect("should receive envelope before timeout")
        .expect("channel should contain one message");

    match envelope {
        OutgoingEnvelope::ToConnection {
            connection_id,
            message,
            ..
        } => {
            assert_eq!(connection_id, ConnectionId(42));
            let OutgoingMessage::Response(response) = message else {
                panic!("expected response message");
            };
            assert_eq!(response.id, request_id.request_id);
            assert_eq!(response.result, json!({ "ok": true }));
        }
        other => panic!("expected targeted response envelope, got: {other:?}"),
    }
}
|
|
|
|
// Sending the final response must remove the request's tracked context so the
// bookkeeping map does not leak resolved requests.
#[tokio::test]
async fn send_response_clears_registered_request_context() {
    let (tx, _rx) = mpsc::channel::<OutgoingEnvelope>(4);
    let outgoing = OutgoingMessageSender::new(tx);
    let request_id = ConnectionRequestId {
        connection_id: ConnectionId(42),
        request_id: RequestId::Integer(7),
    };

    outgoing
        .register_request_context(RequestContext::new(
            request_id.clone(),
            tracing::info_span!("app_server.request", rpc.method = "thread/start"),
            /*parent_trace*/ None,
        ))
        .await;
    assert_eq!(outgoing.request_context_count().await, 1);

    outgoing
        .send_response(request_id, json!({ "ok": true }))
        .await;

    assert_eq!(outgoing.request_context_count().await, 0);
}
|
|
|
|
// `send_error` must produce a `ToConnection` envelope addressed to the
// request's originating connection, carrying the original request id and the
// given error payload.
#[tokio::test]
async fn send_error_routes_to_target_connection() {
    let (tx, mut rx) = mpsc::channel::<OutgoingEnvelope>(4);
    let outgoing = OutgoingMessageSender::new(tx);
    let request_id = ConnectionRequestId {
        connection_id: ConnectionId(9),
        request_id: RequestId::Integer(3),
    };
    let error = JSONRPCErrorError {
        code: INTERNAL_ERROR_CODE,
        message: "boom".to_string(),
        data: None,
    };

    outgoing.send_error(request_id.clone(), error.clone()).await;

    let envelope = timeout(Duration::from_secs(1), rx.recv())
        .await
        .expect("should receive envelope before timeout")
        .expect("channel should contain one message");

    match envelope {
        OutgoingEnvelope::ToConnection {
            connection_id,
            message,
            ..
        } => {
            assert_eq!(connection_id, ConnectionId(9));
            let OutgoingMessage::Error(outgoing_error) = message else {
                panic!("expected error message");
            };
            assert_eq!(outgoing_error.id, RequestId::Integer(3));
            assert_eq!(outgoing_error.error, error);
        }
        other => panic!("expected targeted error envelope, got: {other:?}"),
    }
}
|
|
|
|
// The `_and_wait` variant must attach a write-completion sender to the
// envelope and block until it is signalled: the spawned send task only
// finishes after this test fires the completion signal.
#[tokio::test]
async fn send_server_notification_to_connection_and_wait_tracks_write_completion() {
    let (tx, mut rx) = mpsc::channel::<OutgoingEnvelope>(4);
    let outgoing = OutgoingMessageSender::new(tx);
    // Run the send on a separate task so this test can play the transport's
    // role and signal write completion from the outside.
    let send_task = tokio::spawn(async move {
        outgoing
            .send_server_notification_to_connection_and_wait(
                ConnectionId(42),
                ServerNotification::ModelRerouted(ModelReroutedNotification {
                    thread_id: "thread-1".to_string(),
                    turn_id: "turn-1".to_string(),
                    from_model: "gpt-5.3-codex".to_string(),
                    to_model: "gpt-5.2".to_string(),
                    reason: ModelRerouteReason::HighRiskCyberActivity,
                }),
            )
            .await
    });

    let envelope = timeout(Duration::from_secs(1), rx.recv())
        .await
        .expect("should receive envelope before timeout")
        .expect("channel should contain one message");
    let OutgoingEnvelope::ToConnection {
        connection_id,
        message,
        write_complete_tx,
    } = envelope
    else {
        panic!("expected targeted server notification envelope");
    };
    assert_eq!(connection_id, ConnectionId(42));
    assert!(matches!(message, OutgoingMessage::AppServerNotification(_)));
    // Simulate the transport finishing the write.
    write_complete_tx
        .expect("write completion sender should be attached")
        .send(())
        .expect("receiver should still be waiting");

    timeout(Duration::from_secs(1), send_task)
        .await
        .expect("send task should finish after write completion is signaled")
        .expect("send task should not panic");
}
|
|
|
|
// `connection_closed` must drop only the request contexts belonging to the
// closed connection; contexts for other connections stay tracked.
#[tokio::test]
async fn connection_closed_clears_registered_request_contexts() {
    let (tx, _rx) = mpsc::channel::<OutgoingEnvelope>(4);
    let outgoing = OutgoingMessageSender::new(tx);
    let closed_connection_request = ConnectionRequestId {
        connection_id: ConnectionId(9),
        request_id: RequestId::Integer(3),
    };
    let open_connection_request = ConnectionRequestId {
        connection_id: ConnectionId(10),
        request_id: RequestId::Integer(4),
    };

    outgoing
        .register_request_context(RequestContext::new(
            closed_connection_request,
            tracing::info_span!("app_server.request", rpc.method = "turn/interrupt"),
            /*parent_trace*/ None,
        ))
        .await;
    outgoing
        .register_request_context(RequestContext::new(
            open_connection_request,
            tracing::info_span!("app_server.request", rpc.method = "turn/start"),
            /*parent_trace*/ None,
        ))
        .await;
    assert_eq!(outgoing.request_context_count().await, 2);

    outgoing.connection_closed(ConnectionId(9)).await;

    // Only the context for connection 10 should remain.
    assert_eq!(outgoing.request_context_count().await, 1);
}
|
|
|
|
// A client error reply delivered via `notify_client_error` must reach the
// task awaiting the request's result channel as `Err(error)`.
#[tokio::test]
async fn notify_client_error_forwards_error_to_waiter() {
    let (tx, _rx) = mpsc::channel::<OutgoingEnvelope>(4);
    let outgoing = OutgoingMessageSender::new(tx);

    let (request_id, wait_for_result) = outgoing
        .send_request(ServerRequestPayload::ApplyPatchApproval(
            ApplyPatchApprovalParams {
                conversation_id: ThreadId::new(),
                call_id: "call-id".to_string(),
                file_changes: HashMap::new(),
                reason: None,
                grant_root: None,
            },
        ))
        .await;

    let error = JSONRPCErrorError {
        code: INTERNAL_ERROR_CODE,
        message: "refresh failed".to_string(),
        data: None,
    };

    outgoing
        .notify_client_error(request_id, error.clone())
        .await;

    let result = timeout(Duration::from_secs(1), wait_for_result)
        .await
        .expect("wait should not time out")
        .expect("waiter should receive a callback");
    assert_eq!(result, Err(error));
}
|
|
|
|
#[tokio::test]
|
|
async fn pending_requests_for_thread_returns_thread_requests_in_request_id_order() {
|
|
let (tx, _rx) = mpsc::channel::<OutgoingEnvelope>(8);
|
|
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
|
|
let thread_id = ThreadId::new();
|
|
let thread_outgoing = ThreadScopedOutgoingMessageSender::new(
|
|
outgoing.clone(),
|
|
vec![ConnectionId(1)],
|
|
thread_id,
|
|
);
|
|
|
|
let (dynamic_tool_request_id, _dynamic_tool_waiter) = thread_outgoing
|
|
.send_request(ServerRequestPayload::DynamicToolCall(
|
|
DynamicToolCallParams {
|
|
thread_id: thread_id.to_string(),
|
|
turn_id: "turn-1".to_string(),
|
|
call_id: "call-0".to_string(),
|
|
tool: "tool".to_string(),
|
|
arguments: json!({}),
|
|
},
|
|
))
|
|
.await;
|
|
let (first_request_id, _first_waiter) = thread_outgoing
|
|
.send_request(ServerRequestPayload::ToolRequestUserInput(
|
|
ToolRequestUserInputParams {
|
|
thread_id: thread_id.to_string(),
|
|
turn_id: "turn-1".to_string(),
|
|
item_id: "call-1".to_string(),
|
|
questions: vec![],
|
|
},
|
|
))
|
|
.await;
|
|
let (second_request_id, _second_waiter) = thread_outgoing
|
|
.send_request(ServerRequestPayload::FileChangeRequestApproval(
|
|
FileChangeRequestApprovalParams {
|
|
thread_id: thread_id.to_string(),
|
|
turn_id: "turn-1".to_string(),
|
|
item_id: "call-2".to_string(),
|
|
reason: None,
|
|
grant_root: None,
|
|
},
|
|
))
|
|
.await;
|
|
let pending_requests = outgoing.pending_requests_for_thread(thread_id).await;
|
|
assert_eq!(
|
|
pending_requests
|
|
.iter()
|
|
.map(ServerRequest::id)
|
|
.collect::<Vec<_>>(),
|
|
vec![
|
|
&dynamic_tool_request_id,
|
|
&first_request_id,
|
|
&second_request_id
|
|
]
|
|
);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn cancel_requests_for_thread_cancels_all_thread_requests() {
|
|
let (tx, _rx) = mpsc::channel::<OutgoingEnvelope>(8);
|
|
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
|
|
let thread_id = ThreadId::new();
|
|
let thread_outgoing = ThreadScopedOutgoingMessageSender::new(
|
|
outgoing.clone(),
|
|
vec![ConnectionId(1)],
|
|
thread_id,
|
|
);
|
|
|
|
let (_dynamic_tool_request_id, dynamic_tool_waiter) = thread_outgoing
|
|
.send_request(ServerRequestPayload::DynamicToolCall(
|
|
DynamicToolCallParams {
|
|
thread_id: thread_id.to_string(),
|
|
turn_id: "turn-1".to_string(),
|
|
call_id: "call-0".to_string(),
|
|
tool: "tool".to_string(),
|
|
arguments: json!({}),
|
|
},
|
|
))
|
|
.await;
|
|
let (_request_id, user_input_waiter) = thread_outgoing
|
|
.send_request(ServerRequestPayload::ToolRequestUserInput(
|
|
ToolRequestUserInputParams {
|
|
thread_id: thread_id.to_string(),
|
|
turn_id: "turn-1".to_string(),
|
|
item_id: "call-1".to_string(),
|
|
questions: vec![],
|
|
},
|
|
))
|
|
.await;
|
|
let error = JSONRPCErrorError {
|
|
code: INTERNAL_ERROR_CODE,
|
|
message: "tracked request cancelled".to_string(),
|
|
data: None,
|
|
};
|
|
|
|
outgoing
|
|
.cancel_requests_for_thread(thread_id, Some(error.clone()))
|
|
.await;
|
|
|
|
let dynamic_tool_result = timeout(Duration::from_secs(1), dynamic_tool_waiter)
|
|
.await
|
|
.expect("dynamic tool waiter should resolve")
|
|
.expect("dynamic tool waiter should receive a callback");
|
|
let user_input_result = timeout(Duration::from_secs(1), user_input_waiter)
|
|
.await
|
|
.expect("user input waiter should resolve")
|
|
.expect("user input waiter should receive a callback");
|
|
assert_eq!(dynamic_tool_result, Err(error.clone()));
|
|
assert_eq!(user_input_result, Err(error));
|
|
assert!(
|
|
outgoing
|
|
.pending_requests_for_thread(thread_id)
|
|
.await
|
|
.is_empty()
|
|
);
|
|
}
|
|
}