mirror of
https://github.com/openai/codex.git
synced 2026-05-05 22:01:37 +03:00
[6/6] Fail exec client operations after disconnect (#18027)
## Summary - Reject new exec-server client operations once the transport has disconnected. - Convert pending RPC calls into closed errors instead of synthetic server errors. - Cover pending read and later write behavior after remote executor disconnect. ## Verification - `just fmt` - `cargo check -p codex-exec-server` ## Stack ```text @ #18027 [6/6] Fail exec client operations after disconnect │ o #18212 [5/6] Wire executor-backed MCP stdio │ o #18087 [4/6] Abstract MCP stdio server launching │ o #18020 [3/6] Add pushed exec process events │ o #18086 [2/6] Support piped stdin in exec process API │ o #18085 [1/6] Add MCP server environment config │ o main ``` --------- Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
@@ -2,6 +2,7 @@ use std::collections::BTreeMap;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
@@ -140,6 +141,10 @@ struct Inner {
|
||||
// need serialization so concurrent register/remove operations do not
|
||||
// overwrite each other's copy-on-write updates.
|
||||
sessions_write_lock: Mutex<()>,
|
||||
// Once the transport closes, every executor operation should fail quickly
|
||||
// with the same canonical message. This client never reconnects, so the
|
||||
// latch only moves from unset to set once.
|
||||
disconnected: OnceLock<String>,
|
||||
session_id: std::sync::RwLock<Option<String>>,
|
||||
reader_task: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
@@ -171,6 +176,8 @@ pub enum ExecServerError {
|
||||
InitializeTimedOut { timeout: Duration },
|
||||
#[error("exec-server transport closed")]
|
||||
Closed,
|
||||
#[error("{0}")]
|
||||
Disconnected(String),
|
||||
#[error("failed to serialize or deserialize exec-server JSON: {0}")]
|
||||
Json(#[from] serde_json::Error),
|
||||
#[error("exec-server protocol error: {0}")]
|
||||
@@ -246,19 +253,11 @@ impl ExecServerClient {
|
||||
}
|
||||
|
||||
pub async fn exec(&self, params: ExecParams) -> Result<ExecResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(EXEC_METHOD, ¶ms)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(EXEC_METHOD, ¶ms).await
|
||||
}
|
||||
|
||||
pub async fn read(&self, params: ReadParams) -> Result<ReadResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(EXEC_READ_METHOD, ¶ms)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(EXEC_READ_METHOD, ¶ms).await
|
||||
}
|
||||
|
||||
pub async fn write(
|
||||
@@ -266,107 +265,73 @@ impl ExecServerClient {
|
||||
process_id: &ProcessId,
|
||||
chunk: Vec<u8>,
|
||||
) -> Result<WriteResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(
|
||||
EXEC_WRITE_METHOD,
|
||||
&WriteParams {
|
||||
process_id: process_id.clone(),
|
||||
chunk: chunk.into(),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(
|
||||
EXEC_WRITE_METHOD,
|
||||
&WriteParams {
|
||||
process_id: process_id.clone(),
|
||||
chunk: chunk.into(),
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn terminate(
|
||||
&self,
|
||||
process_id: &ProcessId,
|
||||
) -> Result<TerminateResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(
|
||||
EXEC_TERMINATE_METHOD,
|
||||
&TerminateParams {
|
||||
process_id: process_id.clone(),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(
|
||||
EXEC_TERMINATE_METHOD,
|
||||
&TerminateParams {
|
||||
process_id: process_id.clone(),
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn fs_read_file(
|
||||
&self,
|
||||
params: FsReadFileParams,
|
||||
) -> Result<FsReadFileResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(FS_READ_FILE_METHOD, ¶ms)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(FS_READ_FILE_METHOD, ¶ms).await
|
||||
}
|
||||
|
||||
pub async fn fs_write_file(
|
||||
&self,
|
||||
params: FsWriteFileParams,
|
||||
) -> Result<FsWriteFileResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(FS_WRITE_FILE_METHOD, ¶ms)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(FS_WRITE_FILE_METHOD, ¶ms).await
|
||||
}
|
||||
|
||||
pub async fn fs_create_directory(
|
||||
&self,
|
||||
params: FsCreateDirectoryParams,
|
||||
) -> Result<FsCreateDirectoryResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(FS_CREATE_DIRECTORY_METHOD, ¶ms)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(FS_CREATE_DIRECTORY_METHOD, ¶ms).await
|
||||
}
|
||||
|
||||
pub async fn fs_get_metadata(
|
||||
&self,
|
||||
params: FsGetMetadataParams,
|
||||
) -> Result<FsGetMetadataResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(FS_GET_METADATA_METHOD, ¶ms)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(FS_GET_METADATA_METHOD, ¶ms).await
|
||||
}
|
||||
|
||||
pub async fn fs_read_directory(
|
||||
&self,
|
||||
params: FsReadDirectoryParams,
|
||||
) -> Result<FsReadDirectoryResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(FS_READ_DIRECTORY_METHOD, ¶ms)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(FS_READ_DIRECTORY_METHOD, ¶ms).await
|
||||
}
|
||||
|
||||
pub async fn fs_remove(
|
||||
&self,
|
||||
params: FsRemoveParams,
|
||||
) -> Result<FsRemoveResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(FS_REMOVE_METHOD, ¶ms)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(FS_REMOVE_METHOD, ¶ms).await
|
||||
}
|
||||
|
||||
pub async fn fs_copy(&self, params: FsCopyParams) -> Result<FsCopyResponse, ExecServerError> {
|
||||
self.inner
|
||||
.client
|
||||
.call(FS_COPY_METHOD, ¶ms)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.call(FS_COPY_METHOD, ¶ms).await
|
||||
}
|
||||
|
||||
pub(crate) async fn register_session(
|
||||
@@ -411,18 +376,21 @@ impl ExecServerClient {
|
||||
&& let Err(err) =
|
||||
handle_server_notification(&inner, notification).await
|
||||
{
|
||||
fail_all_sessions(
|
||||
let message = record_disconnected(
|
||||
&inner,
|
||||
format!("exec-server notification handling failed: {err}"),
|
||||
)
|
||||
.await;
|
||||
);
|
||||
fail_all_sessions(&inner, message).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
RpcClientEvent::Disconnected { reason } => {
|
||||
if let Some(inner) = weak.upgrade() {
|
||||
fail_all_sessions(&inner, disconnected_message(reason.as_deref()))
|
||||
.await;
|
||||
let message = record_disconnected(
|
||||
&inner,
|
||||
disconnected_message(reason.as_deref()),
|
||||
);
|
||||
fail_all_sessions(&inner, message).await;
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -434,6 +402,7 @@ impl ExecServerClient {
|
||||
client: rpc_client,
|
||||
sessions: ArcSwap::from_pointee(HashMap::new()),
|
||||
sessions_write_lock: Mutex::new(()),
|
||||
disconnected: OnceLock::new(),
|
||||
session_id: std::sync::RwLock::new(None),
|
||||
reader_task,
|
||||
}
|
||||
@@ -451,6 +420,36 @@ impl ExecServerClient {
|
||||
.await
|
||||
.map_err(ExecServerError::Json)
|
||||
}
|
||||
|
||||
async fn call<P, T>(&self, method: &str, params: &P) -> Result<T, ExecServerError>
|
||||
where
|
||||
P: serde::Serialize,
|
||||
T: serde::de::DeserializeOwned,
|
||||
{
|
||||
// Reject new work before allocating a JSON-RPC request id. MCP tool
|
||||
// calls, process writes, and fs operations all pass through here, so
|
||||
// this is the shared low-level failure path after executor disconnect.
|
||||
if let Some(error) = self.inner.disconnected_error() {
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
match self.inner.client.call(method, params).await {
|
||||
Ok(response) => Ok(response),
|
||||
Err(error) => {
|
||||
let error = ExecServerError::from(error);
|
||||
if is_transport_closed_error(&error) {
|
||||
// A call can race with disconnect after the preflight
|
||||
// check. Only the reader task drains sessions so queued
|
||||
// process notifications stay ordered before disconnect.
|
||||
let message = disconnected_message(/*reason*/ None);
|
||||
let message = record_disconnected(&self.inner, message);
|
||||
Err(ExecServerError::Disconnected(message))
|
||||
} else {
|
||||
Err(error)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RpcCallError> for ExecServerError {
|
||||
@@ -630,6 +629,20 @@ impl Session {
|
||||
}
|
||||
|
||||
impl Inner {
|
||||
fn disconnected_error(&self) -> Option<ExecServerError> {
|
||||
self.disconnected
|
||||
.get()
|
||||
.cloned()
|
||||
.map(ExecServerError::Disconnected)
|
||||
}
|
||||
|
||||
fn set_disconnected(&self, message: String) -> Option<String> {
|
||||
match self.disconnected.set(message.clone()) {
|
||||
Ok(()) => Some(message),
|
||||
Err(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_session(&self, process_id: &ProcessId) -> Option<Arc<SessionState>> {
|
||||
self.sessions.load().get(process_id).cloned()
|
||||
}
|
||||
@@ -640,6 +653,12 @@ impl Inner {
|
||||
session: Arc<SessionState>,
|
||||
) -> Result<(), ExecServerError> {
|
||||
let _sessions_write_guard = self.sessions_write_lock.lock().await;
|
||||
// Do not register a process session that can never receive executor
|
||||
// notifications. Without this check, remote MCP startup could create a
|
||||
// dead session and wait for process output that will never arrive.
|
||||
if let Some(error) = self.disconnected_error() {
|
||||
return Err(error);
|
||||
}
|
||||
let sessions = self.sessions.load();
|
||||
if sessions.contains_key(process_id) {
|
||||
return Err(ExecServerError::Protocol(format!(
|
||||
@@ -680,20 +699,36 @@ fn disconnected_message(reason: Option<&str>) -> String {
|
||||
}
|
||||
|
||||
fn is_transport_closed_error(error: &ExecServerError) -> bool {
|
||||
matches!(error, ExecServerError::Closed)
|
||||
|| matches!(
|
||||
error,
|
||||
ExecServerError::Server {
|
||||
code: -32000,
|
||||
message,
|
||||
} if message == "JSON-RPC transport closed"
|
||||
)
|
||||
matches!(
|
||||
error,
|
||||
ExecServerError::Closed | ExecServerError::Disconnected(_)
|
||||
) || matches!(
|
||||
error,
|
||||
ExecServerError::Server {
|
||||
code: -32000,
|
||||
message,
|
||||
} if message == "JSON-RPC transport closed"
|
||||
)
|
||||
}
|
||||
|
||||
fn record_disconnected(inner: &Arc<Inner>, message: String) -> String {
|
||||
// The first observer records the canonical disconnect reason. Session
|
||||
// draining stays with the reader task so it can preserve notification
|
||||
// ordering before publishing the terminal failure.
|
||||
if let Some(message) = inner.set_disconnected(message.clone()) {
|
||||
message
|
||||
} else {
|
||||
inner.disconnected.get().cloned().unwrap_or(message)
|
||||
}
|
||||
}
|
||||
|
||||
async fn fail_all_sessions(inner: &Arc<Inner>, message: String) {
|
||||
let sessions = inner.take_all_sessions().await;
|
||||
|
||||
for (_, session) in sessions {
|
||||
// Sessions synthesize a closed read response and emit a pushed Failed
|
||||
// event. That covers both polling consumers and streaming consumers
|
||||
// such as executor-backed MCP stdio.
|
||||
session.set_failure(message.clone()).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,9 +195,46 @@ fn map_remote_error(error: ExecServerError) -> io::Error {
|
||||
io::Error::new(io::ErrorKind::InvalidInput, message)
|
||||
}
|
||||
ExecServerError::Server { message, .. } => io::Error::other(message),
|
||||
ExecServerError::Closed => {
|
||||
ExecServerError::Closed | ExecServerError::Disconnected(_) => {
|
||||
io::Error::new(io::ErrorKind::BrokenPipe, "exec-server transport closed")
|
||||
}
|
||||
_ => io::Error::other(error.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn transport_errors_map_to_broken_pipe() {
|
||||
let errors = [
|
||||
ExecServerError::Closed,
|
||||
ExecServerError::Disconnected("exec-server transport disconnected".to_string()),
|
||||
];
|
||||
|
||||
let mapped_errors = errors
|
||||
.into_iter()
|
||||
.map(|error| {
|
||||
let error = map_remote_error(error);
|
||||
(error.kind(), error.to_string())
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(
|
||||
mapped_errors,
|
||||
vec![
|
||||
(
|
||||
io::ErrorKind::BrokenPipe,
|
||||
"exec-server transport closed".to_string()
|
||||
),
|
||||
(
|
||||
io::ErrorKind::BrokenPipe,
|
||||
"exec-server transport closed".to_string()
|
||||
),
|
||||
]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,12 +18,23 @@ use serde_json::Value;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::sync::watch;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::connection::JsonRpcConnectionEvent;
|
||||
|
||||
type PendingRequest = oneshot::Sender<Result<Value, JSONRPCErrorError>>;
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum RpcCallError {
|
||||
/// The underlying JSON-RPC transport closed before this call completed.
|
||||
Closed,
|
||||
/// The response bytes were valid JSON-RPC but not the expected result type.
|
||||
Json(serde_json::Error),
|
||||
/// The executor returned a JSON-RPC error response for this call.
|
||||
Server(JSONRPCErrorError),
|
||||
}
|
||||
|
||||
type PendingRequest = oneshot::Sender<Result<Value, RpcCallError>>;
|
||||
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
|
||||
type RequestRoute<S> =
|
||||
Box<dyn Fn(Arc<S>, JSONRPCRequest) -> BoxFuture<RpcServerOutboundMessage> + Send + Sync>;
|
||||
@@ -172,6 +183,10 @@ where
|
||||
pub(crate) struct RpcClient {
|
||||
write_tx: mpsc::Sender<JSONRPCMessage>,
|
||||
pending: Arc<Mutex<HashMap<RequestId, PendingRequest>>>,
|
||||
// Shared transport state from `JsonRpcConnection`. Calls use this to fail
|
||||
// immediately when the socket closes, even if no JSON-RPC error response
|
||||
// can be delivered for their request id.
|
||||
disconnected_rx: watch::Receiver<bool>,
|
||||
next_request_id: AtomicI64,
|
||||
transport_tasks: Vec<JoinHandle<()>>,
|
||||
reader_task: JoinHandle<()>,
|
||||
@@ -179,8 +194,7 @@ pub(crate) struct RpcClient {
|
||||
|
||||
impl RpcClient {
|
||||
pub(crate) fn new(connection: JsonRpcConnection) -> (Self, mpsc::Receiver<RpcClientEvent>) {
|
||||
let (write_tx, mut incoming_rx, _disconnected_rx, transport_tasks) =
|
||||
connection.into_parts();
|
||||
let (write_tx, mut incoming_rx, disconnected_rx, transport_tasks) = connection.into_parts();
|
||||
let pending = Arc::new(Mutex::new(HashMap::<RequestId, PendingRequest>::new()));
|
||||
let (event_tx, event_rx) = mpsc::channel(128);
|
||||
|
||||
@@ -218,6 +232,7 @@ impl RpcClient {
|
||||
Self {
|
||||
write_tx,
|
||||
pending,
|
||||
disconnected_rx,
|
||||
next_request_id: AtomicI64::new(1),
|
||||
transport_tasks,
|
||||
reader_task,
|
||||
@@ -253,10 +268,16 @@ impl RpcClient {
|
||||
{
|
||||
let request_id = RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::SeqCst));
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
self.pending
|
||||
.lock()
|
||||
.await
|
||||
.insert(request_id.clone(), response_tx);
|
||||
{
|
||||
let mut pending = self.pending.lock().await;
|
||||
// Registering the pending request and checking disconnect must be
|
||||
// atomic with the reader's drain_pending path. Otherwise a call
|
||||
// can sneak in after the drain and wait forever.
|
||||
if *self.disconnected_rx.borrow() {
|
||||
return Err(RpcCallError::Closed);
|
||||
}
|
||||
pending.insert(request_id.clone(), response_tx);
|
||||
}
|
||||
|
||||
let params = match serde_json::to_value(params) {
|
||||
Ok(params) => params,
|
||||
@@ -280,10 +301,17 @@ impl RpcClient {
|
||||
return Err(RpcCallError::Closed);
|
||||
}
|
||||
|
||||
let result = response_rx.await.map_err(|_| RpcCallError::Closed)?;
|
||||
// Do not race in-flight requests directly against the transport-close
|
||||
// watch value. The connection reader receives JSON-RPC messages and
|
||||
// the terminal disconnect event on one ordered queue, then drains any
|
||||
// still-pending requests. Awaiting this receiver preserves that order:
|
||||
// responses already read before EOF still win, and truly pending calls
|
||||
// are failed once the reader observes the disconnect.
|
||||
let result: Result<Value, RpcCallError> =
|
||||
response_rx.await.map_err(|_| RpcCallError::Closed)?;
|
||||
let response = match result {
|
||||
Ok(response) => response,
|
||||
Err(error) => return Err(RpcCallError::Server(error)),
|
||||
Err(error) => return Err(error),
|
||||
};
|
||||
serde_json::from_value(response).map_err(RpcCallError::Json)
|
||||
}
|
||||
@@ -304,13 +332,6 @@ impl Drop for RpcClient {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum RpcCallError {
|
||||
Closed,
|
||||
Json(serde_json::Error),
|
||||
Server(JSONRPCErrorError),
|
||||
}
|
||||
|
||||
pub(crate) fn encode_server_message(
|
||||
message: RpcServerOutboundMessage,
|
||||
) -> Result<JSONRPCMessage, serde_json::Error> {
|
||||
@@ -417,7 +438,7 @@ async fn handle_server_message(
|
||||
}
|
||||
JSONRPCMessage::Error(JSONRPCError { id, error }) => {
|
||||
if let Some(pending) = pending.lock().await.remove(&id) {
|
||||
let _ = pending.send(Err(error));
|
||||
let _ = pending.send(Err(RpcCallError::Server(error)));
|
||||
}
|
||||
}
|
||||
JSONRPCMessage::Notification(notification) => {
|
||||
@@ -445,11 +466,7 @@ async fn drain_pending(pending: &Mutex<HashMap<RequestId, PendingRequest>>) {
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
for pending in pending {
|
||||
let _ = pending.send(Err(JSONRPCErrorError {
|
||||
code: -32000,
|
||||
data: None,
|
||||
message: "JSON-RPC transport closed".to_string(),
|
||||
}));
|
||||
let _ = pending.send(Err(RpcCallError::Closed));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user