[6/6] Fail exec client operations after disconnect (#18027)

## Summary
- Reject new exec-server client operations once the transport has
disconnected.
- Convert pending RPC calls into closed errors instead of synthetic
server errors.
- Cover pending read and later write behavior after remote executor
disconnect.

## Verification
- `just fmt`
- `cargo check -p codex-exec-server`

## Stack
```text
@  #18027 [6/6] Fail exec client operations after disconnect
│
o  #18212 [5/6] Wire executor-backed MCP stdio
│
o  #18087 [4/6] Abstract MCP stdio server launching
│
o  #18020 [3/6] Add pushed exec process events
│
o  #18086 [2/6] Support piped stdin in exec process API
│
o  #18085 [1/6] Add MCP server environment config
│
o  main
```

---------

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Ahmed Ibrahim
2026-04-20 16:24:06 -07:00
committed by GitHub
parent 0f1c9b8963
commit 9ef1cab6f7
4 changed files with 225 additions and 102 deletions

View File

@@ -2,6 +2,7 @@ use std::collections::BTreeMap;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::OnceLock;
use std::time::Duration;
use arc_swap::ArcSwap;
@@ -140,6 +141,10 @@ struct Inner {
// need serialization so concurrent register/remove operations do not
// overwrite each other's copy-on-write updates.
sessions_write_lock: Mutex<()>,
// Once the transport closes, every executor operation should fail quickly
// with the same canonical message. This client never reconnects, so the
// latch only moves from unset to set once.
disconnected: OnceLock<String>,
session_id: std::sync::RwLock<Option<String>>,
reader_task: tokio::task::JoinHandle<()>,
}
@@ -171,6 +176,8 @@ pub enum ExecServerError {
InitializeTimedOut { timeout: Duration },
#[error("exec-server transport closed")]
Closed,
#[error("{0}")]
Disconnected(String),
#[error("failed to serialize or deserialize exec-server JSON: {0}")]
Json(#[from] serde_json::Error),
#[error("exec-server protocol error: {0}")]
@@ -246,19 +253,11 @@ impl ExecServerClient {
}
pub async fn exec(&self, params: ExecParams) -> Result<ExecResponse, ExecServerError> {
self.inner
.client
.call(EXEC_METHOD, &params)
.await
.map_err(Into::into)
self.call(EXEC_METHOD, &params).await
}
pub async fn read(&self, params: ReadParams) -> Result<ReadResponse, ExecServerError> {
self.inner
.client
.call(EXEC_READ_METHOD, &params)
.await
.map_err(Into::into)
self.call(EXEC_READ_METHOD, &params).await
}
pub async fn write(
@@ -266,107 +265,73 @@ impl ExecServerClient {
process_id: &ProcessId,
chunk: Vec<u8>,
) -> Result<WriteResponse, ExecServerError> {
self.inner
.client
.call(
EXEC_WRITE_METHOD,
&WriteParams {
process_id: process_id.clone(),
chunk: chunk.into(),
},
)
.await
.map_err(Into::into)
self.call(
EXEC_WRITE_METHOD,
&WriteParams {
process_id: process_id.clone(),
chunk: chunk.into(),
},
)
.await
}
pub async fn terminate(
&self,
process_id: &ProcessId,
) -> Result<TerminateResponse, ExecServerError> {
self.inner
.client
.call(
EXEC_TERMINATE_METHOD,
&TerminateParams {
process_id: process_id.clone(),
},
)
.await
.map_err(Into::into)
self.call(
EXEC_TERMINATE_METHOD,
&TerminateParams {
process_id: process_id.clone(),
},
)
.await
}
pub async fn fs_read_file(
&self,
params: FsReadFileParams,
) -> Result<FsReadFileResponse, ExecServerError> {
self.inner
.client
.call(FS_READ_FILE_METHOD, &params)
.await
.map_err(Into::into)
self.call(FS_READ_FILE_METHOD, &params).await
}
pub async fn fs_write_file(
&self,
params: FsWriteFileParams,
) -> Result<FsWriteFileResponse, ExecServerError> {
self.inner
.client
.call(FS_WRITE_FILE_METHOD, &params)
.await
.map_err(Into::into)
self.call(FS_WRITE_FILE_METHOD, &params).await
}
pub async fn fs_create_directory(
&self,
params: FsCreateDirectoryParams,
) -> Result<FsCreateDirectoryResponse, ExecServerError> {
self.inner
.client
.call(FS_CREATE_DIRECTORY_METHOD, &params)
.await
.map_err(Into::into)
self.call(FS_CREATE_DIRECTORY_METHOD, &params).await
}
pub async fn fs_get_metadata(
&self,
params: FsGetMetadataParams,
) -> Result<FsGetMetadataResponse, ExecServerError> {
self.inner
.client
.call(FS_GET_METADATA_METHOD, &params)
.await
.map_err(Into::into)
self.call(FS_GET_METADATA_METHOD, &params).await
}
pub async fn fs_read_directory(
&self,
params: FsReadDirectoryParams,
) -> Result<FsReadDirectoryResponse, ExecServerError> {
self.inner
.client
.call(FS_READ_DIRECTORY_METHOD, &params)
.await
.map_err(Into::into)
self.call(FS_READ_DIRECTORY_METHOD, &params).await
}
pub async fn fs_remove(
&self,
params: FsRemoveParams,
) -> Result<FsRemoveResponse, ExecServerError> {
self.inner
.client
.call(FS_REMOVE_METHOD, &params)
.await
.map_err(Into::into)
self.call(FS_REMOVE_METHOD, &params).await
}
pub async fn fs_copy(&self, params: FsCopyParams) -> Result<FsCopyResponse, ExecServerError> {
self.inner
.client
.call(FS_COPY_METHOD, &params)
.await
.map_err(Into::into)
self.call(FS_COPY_METHOD, &params).await
}
pub(crate) async fn register_session(
@@ -411,18 +376,21 @@ impl ExecServerClient {
&& let Err(err) =
handle_server_notification(&inner, notification).await
{
fail_all_sessions(
let message = record_disconnected(
&inner,
format!("exec-server notification handling failed: {err}"),
)
.await;
);
fail_all_sessions(&inner, message).await;
return;
}
}
RpcClientEvent::Disconnected { reason } => {
if let Some(inner) = weak.upgrade() {
fail_all_sessions(&inner, disconnected_message(reason.as_deref()))
.await;
let message = record_disconnected(
&inner,
disconnected_message(reason.as_deref()),
);
fail_all_sessions(&inner, message).await;
}
return;
}
@@ -434,6 +402,7 @@ impl ExecServerClient {
client: rpc_client,
sessions: ArcSwap::from_pointee(HashMap::new()),
sessions_write_lock: Mutex::new(()),
disconnected: OnceLock::new(),
session_id: std::sync::RwLock::new(None),
reader_task,
}
@@ -451,6 +420,36 @@ impl ExecServerClient {
.await
.map_err(ExecServerError::Json)
}
/// Shared low-level RPC entry point for the client's public operations
/// (exec, read/write, terminate, and the fs_* calls all delegate here).
///
/// Preflights the disconnect latch before allocating a request id, then
/// forwards to the underlying `RpcClient` and folds transport-closed
/// failures into the canonical `Disconnected` error.
async fn call<P, T>(&self, method: &str, params: &P) -> Result<T, ExecServerError>
where
P: serde::Serialize,
T: serde::de::DeserializeOwned,
{
// Reject new work before allocating a JSON-RPC request id. MCP tool
// calls, process writes, and fs operations all pass through here, so
// this is the shared low-level failure path after executor disconnect.
if let Some(error) = self.inner.disconnected_error() {
return Err(error);
}
match self.inner.client.call(method, params).await {
Ok(response) => Ok(response),
Err(error) => {
let error = ExecServerError::from(error);
if is_transport_closed_error(&error) {
// A call can race with disconnect after the preflight
// check. Only the reader task drains sessions so queued
// process notifications stay ordered before disconnect.
let message = disconnected_message(/*reason*/ None);
let message = record_disconnected(&self.inner, message);
Err(ExecServerError::Disconnected(message))
} else {
Err(error)
}
}
}
}
}
impl From<RpcCallError> for ExecServerError {
@@ -630,6 +629,20 @@ impl Session {
}
impl Inner {
fn disconnected_error(&self) -> Option<ExecServerError> {
// Surface the latched disconnect reason, if one was recorded, as a
// typed error for callers to return directly.
self.disconnected
.get()
.map(|message| ExecServerError::Disconnected(message.clone()))
}
/// Latch the first disconnect reason.
///
/// Returns `Some(message)` only for the caller that actually installed
/// the value; every later caller gets `None`. `OnceLock::set` consumes
/// its argument, so on success the stored value is read back rather
/// than paying an unconditional up-front clone (the failure path now
/// clones nothing).
fn set_disconnected(&self, message: String) -> Option<String> {
self.disconnected.set(message).ok()?;
// `set` just succeeded, so the latch holds exactly the value we stored.
self.disconnected.get().cloned()
}
fn get_session(&self, process_id: &ProcessId) -> Option<Arc<SessionState>> {
// Snapshot the current session map and clone the matching handle out.
let snapshot = self.sessions.load();
snapshot.get(process_id).map(Arc::clone)
}
@@ -640,6 +653,12 @@ impl Inner {
session: Arc<SessionState>,
) -> Result<(), ExecServerError> {
let _sessions_write_guard = self.sessions_write_lock.lock().await;
// Do not register a process session that can never receive executor
// notifications. Without this check, remote MCP startup could create a
// dead session and wait for process output that will never arrive.
if let Some(error) = self.disconnected_error() {
return Err(error);
}
let sessions = self.sessions.load();
if sessions.contains_key(process_id) {
return Err(ExecServerError::Protocol(format!(
@@ -680,20 +699,36 @@ fn disconnected_message(reason: Option<&str>) -> String {
}
fn is_transport_closed_error(error: &ExecServerError) -> bool {
matches!(error, ExecServerError::Closed)
|| matches!(
error,
ExecServerError::Server {
code: -32000,
message,
} if message == "JSON-RPC transport closed"
)
matches!(
error,
ExecServerError::Closed | ExecServerError::Disconnected(_)
) || matches!(
error,
ExecServerError::Server {
code: -32000,
message,
} if message == "JSON-RPC transport closed"
)
}
fn record_disconnected(inner: &Arc<Inner>, message: String) -> String {
// The first observer installs the canonical disconnect reason; every
// later caller is handed whatever was latched first. Session draining
// stays with the reader task so it can preserve notification ordering
// before publishing the terminal failure.
match inner.set_disconnected(message.clone()) {
Some(winner) => winner,
None => inner.disconnected.get().cloned().unwrap_or(message),
}
}
async fn fail_all_sessions(inner: &Arc<Inner>, message: String) {
// Drain every live session and publish the terminal failure to each.
// `set_failure` synthesizes a closed read response and emits a pushed
// Failed event, which covers both polling consumers and streaming
// consumers such as executor-backed MCP stdio.
let drained = inner.take_all_sessions().await;
for (_, session) in drained {
session.set_failure(message.clone()).await;
}
}

View File

@@ -195,9 +195,46 @@ fn map_remote_error(error: ExecServerError) -> io::Error {
io::Error::new(io::ErrorKind::InvalidInput, message)
}
ExecServerError::Server { message, .. } => io::Error::other(message),
ExecServerError::Closed => {
ExecServerError::Closed | ExecServerError::Disconnected(_) => {
io::Error::new(io::ErrorKind::BrokenPipe, "exec-server transport closed")
}
_ => io::Error::other(error.to_string()),
}
}
#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use super::*;
#[test]
fn transport_errors_map_to_broken_pipe() {
// Both transport-terminal variants must collapse to the same
// BrokenPipe io error with the canonical message.
let cases = [
ExecServerError::Closed,
ExecServerError::Disconnected("exec-server transport disconnected".to_string()),
];
for case in cases {
let mapped = map_remote_error(case);
assert_eq!(mapped.kind(), io::ErrorKind::BrokenPipe);
assert_eq!(mapped.to_string(), "exec-server transport closed".to_string());
}
}
}

View File

@@ -18,12 +18,23 @@ use serde_json::Value;
use tokio::sync::Mutex;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::watch;
use tokio::task::JoinHandle;
use crate::connection::JsonRpcConnection;
use crate::connection::JsonRpcConnectionEvent;
type PendingRequest = oneshot::Sender<Result<Value, JSONRPCErrorError>>;
/// Failure modes for a single JSON-RPC client call, distinguishing
/// transport loss from decode failures and server-reported errors.
#[derive(Debug)]
pub(crate) enum RpcCallError {
/// The underlying JSON-RPC transport closed before this call completed.
Closed,
/// The response bytes were valid JSON-RPC but not the expected result type.
Json(serde_json::Error),
/// The executor returned a JSON-RPC error response for this call.
Server(JSONRPCErrorError),
}
type PendingRequest = oneshot::Sender<Result<Value, RpcCallError>>;
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
type RequestRoute<S> =
Box<dyn Fn(Arc<S>, JSONRPCRequest) -> BoxFuture<RpcServerOutboundMessage> + Send + Sync>;
@@ -172,6 +183,10 @@ where
pub(crate) struct RpcClient {
write_tx: mpsc::Sender<JSONRPCMessage>,
pending: Arc<Mutex<HashMap<RequestId, PendingRequest>>>,
// Shared transport state from `JsonRpcConnection`. Calls use this to fail
// immediately when the socket closes, even if no JSON-RPC error response
// can be delivered for their request id.
disconnected_rx: watch::Receiver<bool>,
next_request_id: AtomicI64,
transport_tasks: Vec<JoinHandle<()>>,
reader_task: JoinHandle<()>,
@@ -179,8 +194,7 @@ pub(crate) struct RpcClient {
impl RpcClient {
pub(crate) fn new(connection: JsonRpcConnection) -> (Self, mpsc::Receiver<RpcClientEvent>) {
let (write_tx, mut incoming_rx, _disconnected_rx, transport_tasks) =
connection.into_parts();
let (write_tx, mut incoming_rx, disconnected_rx, transport_tasks) = connection.into_parts();
let pending = Arc::new(Mutex::new(HashMap::<RequestId, PendingRequest>::new()));
let (event_tx, event_rx) = mpsc::channel(128);
@@ -218,6 +232,7 @@ impl RpcClient {
Self {
write_tx,
pending,
disconnected_rx,
next_request_id: AtomicI64::new(1),
transport_tasks,
reader_task,
@@ -253,10 +268,16 @@ impl RpcClient {
{
let request_id = RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::SeqCst));
let (response_tx, response_rx) = oneshot::channel();
self.pending
.lock()
.await
.insert(request_id.clone(), response_tx);
{
let mut pending = self.pending.lock().await;
// Registering the pending request and checking disconnect must be
// atomic with the reader's drain_pending path. Otherwise a call
// can sneak in after the drain and wait forever.
if *self.disconnected_rx.borrow() {
return Err(RpcCallError::Closed);
}
pending.insert(request_id.clone(), response_tx);
}
let params = match serde_json::to_value(params) {
Ok(params) => params,
@@ -280,10 +301,17 @@ impl RpcClient {
return Err(RpcCallError::Closed);
}
let result = response_rx.await.map_err(|_| RpcCallError::Closed)?;
// Do not race in-flight requests directly against the transport-close
// watch value. The connection reader receives JSON-RPC messages and
// the terminal disconnect event on one ordered queue, then drains any
// still-pending requests. Awaiting this receiver preserves that order:
// responses already read before EOF still win, and truly pending calls
// are failed once the reader observes the disconnect.
let result: Result<Value, RpcCallError> =
response_rx.await.map_err(|_| RpcCallError::Closed)?;
let response = match result {
Ok(response) => response,
Err(error) => return Err(RpcCallError::Server(error)),
Err(error) => return Err(error),
};
serde_json::from_value(response).map_err(RpcCallError::Json)
}
@@ -304,13 +332,6 @@ impl Drop for RpcClient {
}
}
#[derive(Debug)]
pub(crate) enum RpcCallError {
Closed,
Json(serde_json::Error),
Server(JSONRPCErrorError),
}
pub(crate) fn encode_server_message(
message: RpcServerOutboundMessage,
) -> Result<JSONRPCMessage, serde_json::Error> {
@@ -417,7 +438,7 @@ async fn handle_server_message(
}
JSONRPCMessage::Error(JSONRPCError { id, error }) => {
if let Some(pending) = pending.lock().await.remove(&id) {
let _ = pending.send(Err(error));
let _ = pending.send(Err(RpcCallError::Server(error)));
}
}
JSONRPCMessage::Notification(notification) => {
@@ -445,11 +466,7 @@ async fn drain_pending(pending: &Mutex<HashMap<RequestId, PendingRequest>>) {
.collect::<Vec<_>>()
};
for pending in pending {
let _ = pending.send(Err(JSONRPCErrorError {
code: -32000,
data: None,
message: "JSON-RPC transport closed".to_string(),
}));
let _ = pending.send(Err(RpcCallError::Closed));
}
}