diff --git a/codex-rs/exec-server/src/client.rs b/codex-rs/exec-server/src/client.rs index bd0ec047cd..75d7753df5 100644 --- a/codex-rs/exec-server/src/client.rs +++ b/codex-rs/exec-server/src/client.rs @@ -1,9 +1,6 @@ use std::collections::HashMap; +use std::future::Future; use std::sync::Arc; -#[cfg(test)] -use std::sync::Mutex as StdMutex; -#[cfg(test)] -use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicI64; use std::sync::atomic::Ordering; use std::time::Duration; @@ -26,7 +23,6 @@ use codex_app_server_protocol::JSONRPCError; use codex_app_server_protocol::JSONRPCErrorError; use codex_app_server_protocol::JSONRPCMessage; use codex_app_server_protocol::JSONRPCNotification; -use codex_app_server_protocol::JSONRPCRequest; use codex_app_server_protocol::JSONRPCResponse; use codex_app_server_protocol::RequestId; use serde::Serialize; @@ -67,7 +63,6 @@ use crate::protocol::FS_REMOVE_METHOD; use crate::protocol::FS_WRITE_FILE_METHOD; use crate::protocol::INITIALIZE_METHOD; use crate::protocol::INITIALIZED_METHOD; -use crate::protocol::InitializeParams; use crate::protocol::InitializeResponse; use crate::protocol::ReadParams; use crate::protocol::ReadResponse; @@ -75,10 +70,22 @@ use crate::protocol::TerminateParams; use crate::protocol::TerminateResponse; use crate::protocol::WriteParams; use crate::protocol::WriteResponse; -use crate::server::ExecServerHandler; use crate::server::ExecServerOutboundMessage; use crate::server::ExecServerServerNotification; +mod jsonrpc_backend; +mod local_backend; +#[cfg(test)] +mod process; +use jsonrpc_backend::JsonRpcBackend; +use local_backend::LocalBackend; +#[cfg(test)] +use process::ExecServerOutput; +#[cfg(test)] +use process::ExecServerProcess; +#[cfg(test)] +use process::RemoteProcessStatus; + impl Default for ExecServerClientConnectOptions { fn default() -> Self { Self { @@ -111,182 +118,27 @@ impl RemoteExecServerConnectArgs { } } -#[cfg(test)] -#[derive(Debug, Clone, PartialEq, Eq)] -struct ExecServerOutput { - stream: crate::protocol::ExecOutputStream, - chunk: Vec, -} - -#[cfg(test)] -struct ExecServerProcess { - process_id: String, - output_rx: broadcast::Receiver, - status: Arc, - client: ExecServerClient, -} - -#[cfg(test)] -impl ExecServerProcess { - fn output_receiver(&self) -> broadcast::Receiver { - self.output_rx.resubscribe() - } - - fn has_exited(&self) -> bool { - self.status.has_exited() - } - - fn exit_code(&self) -> Option { - self.status.exit_code() - } - - fn terminate(&self) { - let client = self.client.clone(); - let process_id = self.process_id.clone(); - tokio::spawn(async move { - let _ = client.terminate_session(&process_id).await; - }); - } -} - -#[cfg(test)] -struct RemoteProcessStatus { - exited: AtomicBool, - exit_code: StdMutex>, -} - -#[cfg(test)] -impl RemoteProcessStatus { - fn new() -> Self { - Self { - exited: AtomicBool::new(false), - exit_code: StdMutex::new(None), - } - } - - fn has_exited(&self) -> bool { - self.exited.load(Ordering::SeqCst) - } - - fn exit_code(&self) -> Option { - self.exit_code.lock().ok().and_then(|guard| *guard) - } - - fn mark_exited(&self, exit_code: Option) { - self.exited.store(true, Ordering::SeqCst); - if let Ok(mut guard) = self.exit_code.lock() { - *guard = exit_code; - } - } -} - -enum PendingRequest { - Initialize(oneshot::Sender>), - Exec(oneshot::Sender>), - Read(oneshot::Sender>), - Write(oneshot::Sender>), - Terminate(oneshot::Sender>), - FsReadFile(oneshot::Sender>), - FsWriteFile(oneshot::Sender>), - FsCreateDirectory(oneshot::Sender>), - FsGetMetadata(oneshot::Sender>), - FsReadDirectory(oneshot::Sender>), - FsRemove(oneshot::Sender>), - FsCopy(oneshot::Sender>), -} - -impl PendingRequest { - fn resolve_json(self, result: Value) -> Result<(), ExecServerError> { - match self { - PendingRequest::Initialize(tx) => { - let _ = tx.send(Ok(serde_json::from_value(result)?)); - } - PendingRequest::Exec(tx) => { - let _ = tx.send(Ok(serde_json::from_value(result)?)); - } - PendingRequest::Read(tx) => { - let _ = tx.send(Ok(serde_json::from_value(result)?)); - } - PendingRequest::Write(tx) => { - let _ = tx.send(Ok(serde_json::from_value(result)?)); - } - PendingRequest::Terminate(tx) => { - let _ = tx.send(Ok(serde_json::from_value(result)?)); - } - PendingRequest::FsReadFile(tx) => { - let _ = tx.send(Ok(serde_json::from_value(result)?)); - } - PendingRequest::FsWriteFile(tx) => { - let _ = tx.send(Ok(serde_json::from_value(result)?)); - } - PendingRequest::FsCreateDirectory(tx) => { - let _ = tx.send(Ok(serde_json::from_value(result)?)); - } - PendingRequest::FsGetMetadata(tx) => { - let _ = tx.send(Ok(serde_json::from_value(result)?)); - } - PendingRequest::FsReadDirectory(tx) => { - let _ = tx.send(Ok(serde_json::from_value(result)?)); - } - PendingRequest::FsRemove(tx) => { - let _ = tx.send(Ok(serde_json::from_value(result)?)); - } - PendingRequest::FsCopy(tx) => { - let _ = tx.send(Ok(serde_json::from_value(result)?)); - } - } - Ok(()) - } - - fn resolve_error(self, error: JSONRPCErrorError) { - match self { - PendingRequest::Initialize(tx) => { - let _ = tx.send(Err(error)); - } - PendingRequest::Exec(tx) => { - let _ = tx.send(Err(error)); - } - PendingRequest::Read(tx) => { - let _ = tx.send(Err(error)); - } - PendingRequest::Write(tx) => { - let _ = tx.send(Err(error)); - } - PendingRequest::Terminate(tx) => { - let _ = tx.send(Err(error)); - } - PendingRequest::FsReadFile(tx) => { - let _ = tx.send(Err(error)); - } - PendingRequest::FsWriteFile(tx) => { - let _ = tx.send(Err(error)); - } - PendingRequest::FsCreateDirectory(tx) => { - let _ = tx.send(Err(error)); - } - PendingRequest::FsGetMetadata(tx) => { - let _ = tx.send(Err(error)); - } - PendingRequest::FsReadDirectory(tx) => { - let _ = tx.send(Err(error)); - } - PendingRequest::FsRemove(tx) => { - let _ = tx.send(Err(error)); - } - PendingRequest::FsCopy(tx) => { - let _ = tx.send(Err(error)); - } - } - } -} +type PendingRequest = oneshot::Sender>; enum ClientBackend { - JsonRpc { - write_tx: mpsc::Sender, - }, - InProcess { - handler: Arc>, - }, + JsonRpc(JsonRpcBackend), + InProcess(LocalBackend), +} + +impl ClientBackend { + fn as_local(&self) -> Option<&LocalBackend> { + match self { + ClientBackend::JsonRpc(_) => None, + ClientBackend::InProcess(backend) => Some(backend), + } + } + + fn as_jsonrpc(&self) -> Option<&JsonRpcBackend> { + match self { + ClientBackend::JsonRpc(backend) => Some(backend), + ClientBackend::InProcess(_) => None, + } + } } struct Inner { @@ -300,12 +152,12 @@ struct Inner { impl Drop for Inner { fn drop(&mut self) { - if let ClientBackend::InProcess { handler } = &self.backend + if let Some(backend) = self.backend.as_local() && let Ok(handle) = tokio::runtime::Handle::try_current() { - let handler = Arc::clone(handler); + let backend = backend.clone(); handle.spawn(async move { - handler.lock().await.shutdown().await; + backend.shutdown().await; }); } for task in &self.transport_tasks { @@ -349,7 +201,7 @@ impl ExecServerClient { options: ExecServerClientConnectOptions, ) -> Result { let (outbound_tx, mut outgoing_rx) = mpsc::channel::(256); - let handler = Arc::new(Mutex::new(ExecServerHandler::new(outbound_tx))); + let backend = LocalBackend::new(crate::server::ExecServerHandler::new(outbound_tx)); let inner = Arc::new_cyclic(|weak| { let weak = weak.clone(); @@ -372,7 +224,7 @@ impl ExecServerClient { }); Inner { - backend: ClientBackend::InProcess { handler }, + backend: ClientBackend::InProcess(backend), pending: Mutex::new(HashMap::new()), events_tx: broadcast::channel(256).0, next_request_id: AtomicI64::new(1), @@ -465,7 +317,7 @@ impl ExecServerClient { }); Inner { - backend: ClientBackend::JsonRpc { write_tx }, + backend: ClientBackend::JsonRpc(JsonRpcBackend::new(write_tx)), pending: Mutex::new(HashMap::new()), events_tx: broadcast::channel(256).0, next_request_id: AtomicI64::new(1), @@ -526,11 +378,17 @@ impl ExecServerClient { } pub async fn exec(&self, params: ExecParams) -> Result { - self.request_exec(params).await + self.request_or_local(EXEC_METHOD, params, |backend, params| async move { + backend.exec(params).await + }) + .await } pub async fn read(&self, params: ReadParams) -> Result { - self.request_read(params).await + self.request_or_local(EXEC_READ_METHOD, params, |backend, params| async move { + backend.exec_read(params).await + }) + .await } pub async fn write( @@ -538,53 +396,56 @@ impl ExecServerClient { process_id: &str, chunk: Vec, ) -> Result { - self.write_process(WriteParams { + let params = WriteParams { process_id: process_id.to_string(), chunk: chunk.into(), + }; + self.request_or_local(EXEC_WRITE_METHOD, params, |backend, params| async move { + backend.exec_write(params).await }) .await } pub async fn terminate(&self, process_id: &str) -> Result { - self.terminate_session(process_id).await + let params = TerminateParams { + process_id: process_id.to_string(), + }; + self.request_or_local( + EXEC_TERMINATE_METHOD, + params, + |backend, params| async move { backend.terminate(params).await }, + ) + .await } pub async fn fs_read_file( &self, params: FsReadFileParams, ) -> Result { - if let ClientBackend::InProcess { handler } = &self.inner.backend { - return server_result_to_client(handler.lock().await.fs_read_file(params).await); - } - - self.send_pending_request(FS_READ_FILE_METHOD, ¶ms, PendingRequest::FsReadFile) - .await + self.request_or_local(FS_READ_FILE_METHOD, params, |backend, params| async move { + backend.fs_read_file(params).await + }) + .await } pub async fn fs_write_file( &self, params: FsWriteFileParams, ) -> Result { - if let ClientBackend::InProcess { handler } = &self.inner.backend { - return server_result_to_client(handler.lock().await.fs_write_file(params).await); - } - - self.send_pending_request(FS_WRITE_FILE_METHOD, ¶ms, PendingRequest::FsWriteFile) - .await + self.request_or_local(FS_WRITE_FILE_METHOD, params, |backend, params| async move { + backend.fs_write_file(params).await + }) + .await } pub async fn fs_create_directory( &self, params: FsCreateDirectoryParams, ) -> Result { - if let ClientBackend::InProcess { handler } = &self.inner.backend { - return server_result_to_client(handler.lock().await.fs_create_directory(params).await); - } - - self.send_pending_request( + self.request_or_local( FS_CREATE_DIRECTORY_METHOD, - ¶ms, - PendingRequest::FsCreateDirectory, + params, + |backend, params| async move { backend.fs_create_directory(params).await }, ) .await } @@ -593,14 +454,10 @@ impl ExecServerClient { &self, params: FsGetMetadataParams, ) -> Result { - if let ClientBackend::InProcess { handler } = &self.inner.backend { - return server_result_to_client(handler.lock().await.fs_get_metadata(params).await); - } - - self.send_pending_request( + self.request_or_local( FS_GET_METADATA_METHOD, - ¶ms, - PendingRequest::FsGetMetadata, + params, + |backend, params| async move { backend.fs_get_metadata(params).await }, ) .await } @@ -609,14 +466,10 @@ impl ExecServerClient { &self, params: FsReadDirectoryParams, ) -> Result { - if let ClientBackend::InProcess { handler } = &self.inner.backend { - return server_result_to_client(handler.lock().await.fs_read_directory(params).await); - } - - self.send_pending_request( + self.request_or_local( FS_READ_DIRECTORY_METHOD, - ¶ms, - PendingRequest::FsReadDirectory, + params, + |backend, params| async move { backend.fs_read_directory(params).await }, ) .await } @@ -625,21 +478,17 @@ impl ExecServerClient { &self, params: FsRemoveParams, ) -> Result { - if let ClientBackend::InProcess { handler } = &self.inner.backend { - return server_result_to_client(handler.lock().await.fs_remove(params).await); - } - - self.send_pending_request(FS_REMOVE_METHOD, ¶ms, PendingRequest::FsRemove) - .await + self.request_or_local(FS_REMOVE_METHOD, params, |backend, params| async move { + backend.fs_remove(params).await + }) + .await } pub async fn fs_copy(&self, params: FsCopyParams) -> Result { - if let ClientBackend::InProcess { handler } = &self.inner.backend { - return server_result_to_client(handler.lock().await.fs_copy(params).await); - } - - self.send_pending_request(FS_COPY_METHOD, ¶ms, PendingRequest::FsCopy) - .await + self.request_or_local(FS_COPY_METHOD, params, |backend, params| async move { + backend.fs_copy(params).await + }) + .await } async fn initialize( @@ -651,9 +500,14 @@ impl ExecServerClient { initialize_timeout, } = options; timeout(initialize_timeout, async { - let _: InitializeResponse = self - .request_initialize(InitializeParams { client_name }) - .await?; + if let Some(backend) = self.inner.backend.as_local() { + backend.initialize().await?; + } else { + let params = crate::protocol::InitializeParams { client_name }; + let _: InitializeResponse = self + .send_pending_request(INITIALIZE_METHOD, ¶ms) + .await?; + } self.notify(INITIALIZED_METHOD, &serde_json::json!({})) .await }) @@ -663,85 +517,13 @@ impl ExecServerClient { })? } - async fn request_exec(&self, params: ExecParams) -> Result { - if let ClientBackend::InProcess { handler } = &self.inner.backend { - return server_result_to_client(handler.lock().await.exec(params).await); - } - - self.send_pending_request(EXEC_METHOD, ¶ms, PendingRequest::Exec) - .await - } - - async fn write_process(&self, params: WriteParams) -> Result { - if let ClientBackend::InProcess { handler } = &self.inner.backend { - return server_result_to_client(handler.lock().await.write(params).await); - } - - self.send_pending_request(EXEC_WRITE_METHOD, ¶ms, PendingRequest::Write) - .await - } - - async fn request_read(&self, params: ReadParams) -> Result { - if let ClientBackend::InProcess { handler } = &self.inner.backend { - return server_result_to_client(handler.lock().await.read(params).await); - } - - self.send_pending_request(EXEC_READ_METHOD, ¶ms, PendingRequest::Read) - .await - } - - async fn terminate_session( - &self, - process_id: &str, - ) -> Result { - let params = TerminateParams { - process_id: process_id.to_string(), - }; - if let ClientBackend::InProcess { handler } = &self.inner.backend { - return server_result_to_client(handler.lock().await.terminate(params).await); - } - - self.send_pending_request(EXEC_TERMINATE_METHOD, ¶ms, PendingRequest::Terminate) - .await - } - async fn notify(&self, method: &str, params: &P) -> Result<(), ExecServerError> { match &self.inner.backend { - ClientBackend::JsonRpc { write_tx } => { - let params = serde_json::to_value(params)?; - write_tx - .send(JSONRPCMessage::Notification(JSONRPCNotification { - method: method.to_string(), - params: Some(params), - })) - .await - .map_err(|_| ExecServerError::Closed) - } - ClientBackend::InProcess { handler } => match method { - INITIALIZED_METHOD => handler - .lock() - .await - .initialized() - .map_err(ExecServerError::Protocol), - other => Err(ExecServerError::Protocol(format!( - "unsupported in-process notification method `{other}`" - ))), - }, + ClientBackend::JsonRpc(backend) => backend.notify(method, params).await, + ClientBackend::InProcess(backend) => backend.notify(method).await, } } - async fn request_initialize( - &self, - params: InitializeParams, - ) -> Result { - if let ClientBackend::InProcess { handler } = &self.inner.backend { - return server_result_to_client(handler.lock().await.initialize()); - } - - self.send_pending_request(INITIALIZE_METHOD, ¶ms, PendingRequest::Initialize) - .await - } - fn next_request_id(&self) -> RequestId { RequestId::Integer(self.inner.next_request_id.fetch_add(1, Ordering::SeqCst)) } @@ -750,10 +532,10 @@ impl ExecServerClient { &self, method: &str, params: &P, - build_pending: impl FnOnce(oneshot::Sender>) -> PendingRequest, ) -> Result where P: Serialize, + T: serde::de::DeserializeOwned, { let request_id = self.next_request_id(); let (response_tx, response_rx) = oneshot::channel(); @@ -761,32 +543,56 @@ impl ExecServerClient { .pending .lock() .await - .insert(request_id.clone(), build_pending(response_tx)); - let ClientBackend::JsonRpc { write_tx } = &self.inner.backend else { + .insert(request_id.clone(), response_tx); + let Some(backend) = self.inner.backend.as_jsonrpc() else { unreachable!("in-process requests return before JSON-RPC setup"); }; - let send_result = send_jsonrpc_request(write_tx, request_id.clone(), method, params).await; + let send_result = backend + .send_request(request_id.clone(), method, params) + .await; self.finish_request(request_id, send_result, response_rx) .await } + async fn request_or_local( + &self, + method: &str, + params: P, + call_local: impl FnOnce(LocalBackend, P) -> Fut, + ) -> Result + where + P: Serialize, + T: serde::de::DeserializeOwned, + Fut: Future>, + { + if let Some(backend) = self.inner.backend.as_local() { + return call_local(backend.clone(), params).await; + } + + self.send_pending_request(method, ¶ms).await + } + async fn finish_request( &self, request_id: RequestId, send_result: Result<(), ExecServerError>, - response_rx: oneshot::Receiver>, - ) -> Result { + response_rx: oneshot::Receiver>, + ) -> Result + where + T: serde::de::DeserializeOwned, + { if let Err(err) = send_result { self.inner.pending.lock().await.remove(&request_id); return Err(err); } - receive_typed_response(response_rx).await + let response = receive_json_response(response_rx).await?; + Ok(serde_json::from_value(response)?) } } -async fn receive_typed_response( - response_rx: oneshot::Receiver>, -) -> Result { +async fn receive_json_response( + response_rx: oneshot::Receiver>, +) -> Result { let result = response_rx.await.map_err(|_| ExecServerError::Closed)?; match result { Ok(response) => Ok(response), @@ -807,24 +613,6 @@ fn server_result_to_client(result: Result) -> Result( - write_tx: &mpsc::Sender, - request_id: RequestId, - method: &str, - params: &P, -) -> Result<(), ExecServerError> { - let params = serde_json::to_value(params)?; - write_tx - .send(JSONRPCMessage::Request(JSONRPCRequest { - id: request_id, - method: method.to_string(), - params: Some(params), - trace: None, - })) - .await - .map_err(|_| ExecServerError::Closed) -} - async fn handle_in_process_outbound_message( inner: &Arc, message: ExecServerOutboundMessage, @@ -864,12 +652,12 @@ async fn handle_server_message( match message { JSONRPCMessage::Response(JSONRPCResponse { id, result }) => { if let Some(pending) = inner.pending.lock().await.remove(&id) { - pending.resolve_json(result)?; + let _ = pending.send(Ok(result)); } } JSONRPCMessage::Error(JSONRPCError { id, error }) => { if let Some(pending) = inner.pending.lock().await.remove(&id) { - pending.resolve_error(error); + let _ = pending.send(Err(error)); } } JSONRPCMessage::Notification(notification) => { @@ -917,11 +705,11 @@ async fn handle_transport_shutdown(inner: &Arc) { .collect::>() }; for pending in pending { - pending.resolve_error(JSONRPCErrorError { + let _ = pending.send(Err(JSONRPCErrorError { code: -32000, data: None, message: "exec-server transport closed".to_string(), - }); + })); } } diff --git a/codex-rs/exec-server/src/client/jsonrpc_backend.rs b/codex-rs/exec-server/src/client/jsonrpc_backend.rs new file mode 100644 index 0000000000..9a28daaf6c --- /dev/null +++ b/codex-rs/exec-server/src/client/jsonrpc_backend.rs @@ -0,0 +1,51 @@ +use codex_app_server_protocol::JSONRPCMessage; +use codex_app_server_protocol::JSONRPCNotification; +use codex_app_server_protocol::JSONRPCRequest; +use codex_app_server_protocol::RequestId; +use serde::Serialize; +use tokio::sync::mpsc; + +use super::ExecServerError; + +pub(super) struct JsonRpcBackend { + write_tx: mpsc::Sender, +} + +impl JsonRpcBackend { + pub(super) fn new(write_tx: mpsc::Sender) -> Self { + Self { write_tx } + } + + pub(super) async fn notify( + &self, + method: &str, + params: &P, + ) -> Result<(), ExecServerError> { + let params = serde_json::to_value(params)?; + self.write_tx + .send(JSONRPCMessage::Notification(JSONRPCNotification { + method: method.to_string(), + params: Some(params), + })) + .await + .map_err(|_| ExecServerError::Closed) + } + + pub(super) async fn send_request( + &self, + request_id: RequestId, + method: &str, + params: &P, + ) -> Result<(), ExecServerError> { + let params = serde_json::to_value(params)?; + self.write_tx + .send(JSONRPCMessage::Request(JSONRPCRequest { + id: request_id, + method: method.to_string(), + params: Some(params), + trace: None, + })) + .await + .map_err(|_| ExecServerError::Closed) + } +} diff --git a/codex-rs/exec-server/src/client/local_backend.rs b/codex-rs/exec-server/src/client/local_backend.rs new file mode 100644 index 0000000000..34324fdf50 --- /dev/null +++ b/codex-rs/exec-server/src/client/local_backend.rs @@ -0,0 +1,141 @@ +use std::sync::Arc; + +use codex_app_server_protocol::FsCopyParams; +use codex_app_server_protocol::FsCopyResponse; +use codex_app_server_protocol::FsCreateDirectoryParams; +use codex_app_server_protocol::FsCreateDirectoryResponse; +use codex_app_server_protocol::FsGetMetadataParams; +use codex_app_server_protocol::FsGetMetadataResponse; +use codex_app_server_protocol::FsReadDirectoryParams; +use codex_app_server_protocol::FsReadDirectoryResponse; +use codex_app_server_protocol::FsReadFileParams; +use codex_app_server_protocol::FsReadFileResponse; +use codex_app_server_protocol::FsRemoveParams; +use codex_app_server_protocol::FsRemoveResponse; +use codex_app_server_protocol::FsWriteFileParams; +use codex_app_server_protocol::FsWriteFileResponse; +use tokio::sync::Mutex; + +use crate::protocol::ExecParams; +use crate::protocol::ExecResponse; +use crate::protocol::INITIALIZED_METHOD; +use crate::protocol::InitializeResponse; +use crate::protocol::ReadParams; +use crate::protocol::ReadResponse; +use crate::protocol::TerminateParams; +use crate::protocol::TerminateResponse; +use crate::protocol::WriteParams; +use crate::protocol::WriteResponse; +use crate::server::ExecServerHandler; + +use super::ExecServerError; +use super::server_result_to_client; + +#[derive(Clone)] +pub(super) struct LocalBackend { + handler: Arc>, +} + +impl LocalBackend { + pub(super) fn new(handler: ExecServerHandler) -> Self { + Self { + handler: Arc::new(Mutex::new(handler)), + } + } + + pub(super) async fn shutdown(&self) { + self.handler.lock().await.shutdown().await; + } + + pub(super) async fn initialize(&self) -> Result { + server_result_to_client(self.handler.lock().await.initialize()) + } + + pub(super) async fn notify(&self, method: &str) -> Result<(), ExecServerError> { + match method { + INITIALIZED_METHOD => self + .handler + .lock() + .await + .initialized() + .map_err(ExecServerError::Protocol), + other => Err(ExecServerError::Protocol(format!( + "unsupported in-process notification method `{other}`" + ))), + } + } + + pub(super) async fn exec(&self, params: ExecParams) -> Result { + server_result_to_client(self.handler.lock().await.exec(params).await) + } + + pub(super) async fn exec_read( + &self, + params: ReadParams, + ) -> Result { + server_result_to_client(self.handler.lock().await.exec_read(params).await) + } + + pub(super) async fn exec_write( + &self, + params: WriteParams, + ) -> Result { + server_result_to_client(self.handler.lock().await.exec_write(params).await) + } + + pub(super) async fn terminate( + &self, + params: TerminateParams, + ) -> Result { + server_result_to_client(self.handler.lock().await.terminate(params).await) + } + + pub(super) async fn fs_read_file( + &self, + params: FsReadFileParams, + ) -> Result { + server_result_to_client(self.handler.lock().await.fs_read_file(params).await) + } + + pub(super) async fn fs_write_file( + &self, + params: FsWriteFileParams, + ) -> Result { + server_result_to_client(self.handler.lock().await.fs_write_file(params).await) + } + + pub(super) async fn fs_create_directory( + &self, + params: FsCreateDirectoryParams, + ) -> Result { + server_result_to_client(self.handler.lock().await.fs_create_directory(params).await) + } + + pub(super) async fn fs_get_metadata( + &self, + params: FsGetMetadataParams, + ) -> Result { + server_result_to_client(self.handler.lock().await.fs_get_metadata(params).await) + } + + pub(super) async fn fs_read_directory( + &self, + params: FsReadDirectoryParams, + ) -> Result { + server_result_to_client(self.handler.lock().await.fs_read_directory(params).await) + } + + pub(super) async fn fs_remove( + &self, + params: FsRemoveParams, + ) -> Result { + server_result_to_client(self.handler.lock().await.fs_remove(params).await) + } + + pub(super) async fn fs_copy( + &self, + params: FsCopyParams, + ) -> Result { + server_result_to_client(self.handler.lock().await.fs_copy(params).await) + } +} diff --git a/codex-rs/exec-server/src/client/process.rs b/codex-rs/exec-server/src/client/process.rs new file mode 100644 index 0000000000..b3f77e8aad --- /dev/null +++ b/codex-rs/exec-server/src/client/process.rs @@ -0,0 +1,72 @@ +use std::sync::Arc; +use std::sync::Mutex as StdMutex; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; + +use tokio::sync::broadcast; + +use super::ExecServerClient; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(super) struct ExecServerOutput { + pub(super) stream: crate::protocol::ExecOutputStream, + pub(super) chunk: Vec, +} + +pub(super) struct ExecServerProcess { + pub(super) process_id: String, + pub(super) output_rx: broadcast::Receiver, + pub(super) status: Arc, + pub(super) client: ExecServerClient, +} + +impl ExecServerProcess { + pub(super) fn output_receiver(&self) -> broadcast::Receiver { + self.output_rx.resubscribe() + } + + pub(super) fn has_exited(&self) -> bool { + self.status.has_exited() + } + + pub(super) fn exit_code(&self) -> Option { + self.status.exit_code() + } + + pub(super) fn terminate(&self) { + let client = self.client.clone(); + let process_id = self.process_id.clone(); + tokio::spawn(async move { + let _ = client.terminate(&process_id).await; + }); + } +} + +pub(super) struct RemoteProcessStatus { + exited: AtomicBool, + exit_code: StdMutex>, +} + +impl RemoteProcessStatus { + pub(super) fn new() -> Self { + Self { + exited: AtomicBool::new(false), + exit_code: StdMutex::new(None), + } + } + + pub(super) fn has_exited(&self) -> bool { + self.exited.load(Ordering::SeqCst) + } + + pub(super) fn exit_code(&self) -> Option { + self.exit_code.lock().ok().and_then(|guard| *guard) + } + + pub(super) fn mark_exited(&self, exit_code: Option) { + self.exited.store(true, Ordering::SeqCst); + if let Ok(mut guard) = self.exit_code.lock() { + *guard = exit_code; + } + } +} diff --git a/codex-rs/exec-server/src/client/tests.rs b/codex-rs/exec-server/src/client/tests.rs index 5c7e704ae2..b9035df22c 100644 --- a/codex-rs/exec-server/src/client/tests.rs +++ b/codex-rs/exec-server/src/client/tests.rs @@ -250,12 +250,7 @@ async fn connect_in_process_rejects_writes_to_unknown_processes() { Err(err) => panic!("failed to connect in-process client: {err}"), }; - let result = client - .write_process(crate::protocol::WriteParams { - process_id: "missing".to_string(), - chunk: b"input".to_vec().into(), - }) - .await; + let result = client.write("missing", b"input".to_vec()).await; match result { Err(ExecServerError::Server { code, message }) => { @@ -290,7 +285,7 @@ async fn connect_in_process_terminate_marks_process_exited() { Err(err) => panic!("failed to start in-process child: {err}"), }; - if let Err(err) = client.terminate_session(&process.process_id).await { + if let Err(err) = client.terminate(&process.process_id).await { panic!("failed to terminate in-process child: {err}"); } diff --git a/codex-rs/exec-server/src/server/handler.rs b/codex-rs/exec-server/src/server/handler.rs index f12d4307b4..4a1a2969b5 100644 --- a/codex-rs/exec-server/src/server/handler.rs +++ b/codex-rs/exec-server/src/server/handler.rs @@ -36,7 +36,11 @@ use crate::protocol::ReadResponse; use crate::protocol::TerminateResponse; use crate::protocol::WriteResponse; use crate::server::filesystem::ExecServerFileSystem; +use crate::server::routing::ExecServerClientNotification; +use crate::server::routing::ExecServerInboundMessage; use crate::server::routing::ExecServerOutboundMessage; +use crate::server::routing::ExecServerRequest; +use crate::server::routing::ExecServerResponseMessage; use crate::server::routing::ExecServerServerNotification; use crate::server::routing::internal_error; use crate::server::routing::invalid_params; @@ -297,7 +301,7 @@ impl ExecServerHandler { self.file_system.copy(params).await } - pub(crate) async fn read( + pub(crate) async fn exec_read( &self, params: crate::protocol::ReadParams, ) -> Result { @@ -360,7 +364,7 @@ impl ExecServerHandler { } } - pub(crate) async fn write( + pub(crate) async fn exec_write( &self, params: crate::protocol::WriteParams, ) -> Result { @@ -404,124 +408,89 @@ impl ExecServerHandler { Ok(TerminateResponse { running }) } -} -#[cfg(test)] -impl ExecServerHandler { - async fn handle_message( + pub(crate) async fn handle_message( &mut self, - message: crate::server::routing::ExecServerInboundMessage, + message: ExecServerInboundMessage, ) -> Result<(), String> { match message { - crate::server::routing::ExecServerInboundMessage::Request(request) => { - self.handle_request(request).await + ExecServerInboundMessage::Request(request) => self.handle_request(request).await, + ExecServerInboundMessage::Notification(ExecServerClientNotification::Initialized) => { + self.initialized() } - crate::server::routing::ExecServerInboundMessage::Notification( - crate::server::routing::ExecServerClientNotification::Initialized, - ) => self.initialized(), } } - async fn handle_request( - &mut self, - request: crate::server::routing::ExecServerRequest, - ) -> Result<(), String> { + async fn handle_request(&mut self, request: ExecServerRequest) -> Result<(), String> { let outbound = match request { - crate::server::routing::ExecServerRequest::Initialize { request_id, .. } => { - Self::request_outbound( - request_id, - self.initialize() - .map(crate::server::routing::ExecServerResponseMessage::Initialize), - ) - } - crate::server::routing::ExecServerRequest::Exec { request_id, params } => { - Self::request_outbound( - request_id, - self.exec(params) - .await - .map(crate::server::routing::ExecServerResponseMessage::Exec), - ) - } - crate::server::routing::ExecServerRequest::Read { request_id, params } => { - Self::request_outbound( - request_id, - self.read(params) - .await - .map(crate::server::routing::ExecServerResponseMessage::Read), - ) - } - crate::server::routing::ExecServerRequest::Write { request_id, params } => { - Self::request_outbound( - request_id, - self.write(params) - .await - .map(crate::server::routing::ExecServerResponseMessage::Write), - ) - } - crate::server::routing::ExecServerRequest::Terminate { request_id, params } => { - Self::request_outbound( - request_id, - self.terminate(params) - .await - .map(crate::server::routing::ExecServerResponseMessage::Terminate), - ) - } - crate::server::routing::ExecServerRequest::FsReadFile { request_id, params } => { - Self::request_outbound( - request_id, - self.fs_read_file(params) - .await - .map(crate::server::routing::ExecServerResponseMessage::FsReadFile), - ) - } - crate::server::routing::ExecServerRequest::FsWriteFile { request_id, params } => { - Self::request_outbound( - request_id, - self.fs_write_file(params) - .await - .map(crate::server::routing::ExecServerResponseMessage::FsWriteFile), - ) - } - crate::server::routing::ExecServerRequest::FsCreateDirectory { request_id, params } => { - Self::request_outbound( - request_id, - self.fs_create_directory(params) - .await - .map(crate::server::routing::ExecServerResponseMessage::FsCreateDirectory), - ) - } - crate::server::routing::ExecServerRequest::FsGetMetadata { request_id, params } => { - Self::request_outbound( - request_id, - self.fs_get_metadata(params) - .await - .map(crate::server::routing::ExecServerResponseMessage::FsGetMetadata), - ) - } - crate::server::routing::ExecServerRequest::FsReadDirectory { request_id, params } => { - Self::request_outbound( - request_id, - self.fs_read_directory(params) - .await - .map(crate::server::routing::ExecServerResponseMessage::FsReadDirectory), - ) - } - crate::server::routing::ExecServerRequest::FsRemove { request_id, params } => { - Self::request_outbound( - request_id, - self.fs_remove(params) - .await - .map(crate::server::routing::ExecServerResponseMessage::FsRemove), - ) - } - crate::server::routing::ExecServerRequest::FsCopy { request_id, params } => { - Self::request_outbound( - request_id, - self.fs_copy(params) - .await - .map(crate::server::routing::ExecServerResponseMessage::FsCopy), - ) - } + ExecServerRequest::Initialize { request_id, .. } => Self::request_outbound( + request_id, + self.initialize().map(ExecServerResponseMessage::Initialize), + ), + ExecServerRequest::Exec { request_id, params } => Self::request_outbound( + request_id, + self.exec(params).await.map(ExecServerResponseMessage::Exec), + ), + ExecServerRequest::Read { request_id, params } => Self::request_outbound( + request_id, + self.exec_read(params) + .await + .map(ExecServerResponseMessage::Read), + ), + ExecServerRequest::Write { request_id, params } => Self::request_outbound( + request_id, + self.exec_write(params) + .await + .map(ExecServerResponseMessage::Write), + ), + ExecServerRequest::Terminate { request_id, params } => Self::request_outbound( + request_id, + self.terminate(params) + .await + .map(ExecServerResponseMessage::Terminate), + ), + ExecServerRequest::FsReadFile { request_id, params } => Self::request_outbound( + request_id, + self.fs_read_file(params) + .await + .map(ExecServerResponseMessage::FsReadFile), + ), + ExecServerRequest::FsWriteFile { request_id, params } => Self::request_outbound( + request_id, + self.fs_write_file(params) + .await + .map(ExecServerResponseMessage::FsWriteFile), + ), + ExecServerRequest::FsCreateDirectory { request_id, params } => Self::request_outbound( + request_id, + self.fs_create_directory(params) + .await + .map(ExecServerResponseMessage::FsCreateDirectory), + ), + ExecServerRequest::FsGetMetadata { request_id, params } => Self::request_outbound( + request_id, + self.fs_get_metadata(params) + .await + .map(ExecServerResponseMessage::FsGetMetadata), + ), + ExecServerRequest::FsReadDirectory { request_id, params } => Self::request_outbound( + request_id, + self.fs_read_directory(params) + .await + .map(ExecServerResponseMessage::FsReadDirectory), + ), + ExecServerRequest::FsRemove { request_id, params } => Self::request_outbound( + request_id, + self.fs_remove(params) + .await + .map(ExecServerResponseMessage::FsRemove), + ), + ExecServerRequest::FsCopy { request_id, params } => Self::request_outbound( + request_id, + self.fs_copy(params) + .await + .map(ExecServerResponseMessage::FsCopy), + ), }; self.outbound_tx .send(outbound) @@ -531,19 +500,14 @@ impl ExecServerHandler { fn request_outbound( request_id: codex_app_server_protocol::RequestId, - result: Result< - crate::server::routing::ExecServerResponseMessage, - codex_app_server_protocol::JSONRPCErrorError, - >, - ) -> crate::server::routing::ExecServerOutboundMessage { + result: Result, + ) -> ExecServerOutboundMessage { match result { - Ok(response) => crate::server::routing::ExecServerOutboundMessage::Response { + Ok(response) => ExecServerOutboundMessage::Response { request_id, response, }, - Err(error) => { - crate::server::routing::ExecServerOutboundMessage::Error { request_id, error } - } + Err(error) => ExecServerOutboundMessage::Error { request_id, error }, } } } diff --git a/codex-rs/exec-server/src/server/handler/tests.rs b/codex-rs/exec-server/src/server/handler/tests.rs index 9c2a8b3b91..0a52b7d4f4 100644 --- a/codex-rs/exec-server/src/server/handler/tests.rs +++ b/codex-rs/exec-server/src/server/handler/tests.rs @@ -699,7 +699,7 @@ async fn read_paginates_retained_output_without_skipping_omitted_chunks() { } let first = handler - .read(ReadParams { + .exec_read(ReadParams { process_id: "proc-1".to_string(), after_seq: Some(0), max_bytes: Some(3), @@ -715,7 +715,7 @@ async fn read_paginates_retained_output_without_skipping_omitted_chunks() { assert_eq!(first.next_seq, 2); let second = handler - .read(ReadParams { + .exec_read(ReadParams { process_id: "proc-1".to_string(), after_seq: Some(first.next_seq - 1), max_bytes: Some(3), diff --git a/codex-rs/exec-server/src/server/processor.rs b/codex-rs/exec-server/src/server/processor.rs index 9dd9388b30..079168634d 100644 --- a/codex-rs/exec-server/src/server/processor.rs +++ b/codex-rs/exec-server/src/server/processor.rs @@ -6,11 +6,7 @@ use crate::connection::CHANNEL_CAPACITY; use crate::connection::JsonRpcConnection; use crate::connection::JsonRpcConnectionEvent; use crate::server::handler::ExecServerHandler; -use crate::server::routing::ExecServerClientNotification; -use crate::server::routing::ExecServerInboundMessage; use crate::server::routing::ExecServerOutboundMessage; -use crate::server::routing::ExecServerRequest; -use crate::server::routing::ExecServerResponseMessage; use crate::server::routing::RoutedExecServerMessage; use crate::server::routing::encode_outbound_message; use crate::server::routing::route_jsonrpc_message; @@ -40,8 +36,7 @@ pub(crate) async fn run_connection(connection: JsonRpcConnection) { match event { JsonRpcConnectionEvent::Message(message) => match route_jsonrpc_message(message) { Ok(RoutedExecServerMessage::Inbound(message)) => { - if let Err(err) = dispatch_to_handler(&mut handler, message, &outgoing_tx).await - { + if let Err(err) = handler.handle_message(message).await { warn!("closing exec-server connection after protocol error: {err}"); break; } @@ -70,119 +65,3 @@ pub(crate) async fn run_connection(connection: JsonRpcConnection) { drop(outgoing_tx); let _ = outbound_task.await; } - -async fn dispatch_to_handler( - handler: &mut ExecServerHandler, - message: ExecServerInboundMessage, - outgoing_tx: &mpsc::Sender, -) -> Result<(), String> { - match message { - ExecServerInboundMessage::Request(request) => { - let outbound = match request { - ExecServerRequest::Initialize { request_id, .. } => request_outbound( - request_id, - handler - .initialize() - .map(ExecServerResponseMessage::Initialize), - ), - ExecServerRequest::Exec { request_id, params } => request_outbound( - request_id, - handler - .exec(params) - .await - .map(ExecServerResponseMessage::Exec), - ), - ExecServerRequest::Read { request_id, params } => request_outbound( - request_id, - handler - .read(params) - .await - .map(ExecServerResponseMessage::Read), - ), - ExecServerRequest::Write { request_id, params } => request_outbound( - request_id, - handler - .write(params) - .await - .map(ExecServerResponseMessage::Write), - ), - ExecServerRequest::Terminate { request_id, params } => request_outbound( - request_id, - handler - .terminate(params) - .await - .map(ExecServerResponseMessage::Terminate), - ), - ExecServerRequest::FsReadFile { request_id, params } => request_outbound( - request_id, - handler - .fs_read_file(params) - .await - .map(ExecServerResponseMessage::FsReadFile), - ), - ExecServerRequest::FsWriteFile { request_id, params } => request_outbound( - request_id, - handler - .fs_write_file(params) - .await - .map(ExecServerResponseMessage::FsWriteFile), - ), - ExecServerRequest::FsCreateDirectory { request_id, params } => request_outbound( - request_id, - handler - .fs_create_directory(params) - .await - .map(ExecServerResponseMessage::FsCreateDirectory), - ), - ExecServerRequest::FsGetMetadata { request_id, params } => request_outbound( - request_id, - handler - .fs_get_metadata(params) - .await - .map(ExecServerResponseMessage::FsGetMetadata), - ), - ExecServerRequest::FsReadDirectory { request_id, params } => request_outbound( - request_id, - handler - .fs_read_directory(params) - .await - .map(ExecServerResponseMessage::FsReadDirectory), - ), - ExecServerRequest::FsRemove { request_id, params } => request_outbound( - request_id, - handler - .fs_remove(params) - .await - .map(ExecServerResponseMessage::FsRemove), - ), - ExecServerRequest::FsCopy { request_id, params } => request_outbound( - request_id, - handler - .fs_copy(params) - .await - .map(ExecServerResponseMessage::FsCopy), - ), - }; - outgoing_tx - .send(outbound) - .await - .map_err(|_| "outbound channel closed".to_string()) - } - ExecServerInboundMessage::Notification(ExecServerClientNotification::Initialized) => { - handler.initialized() - } - } -} - -fn request_outbound( - request_id: codex_app_server_protocol::RequestId, - result: Result, -) -> ExecServerOutboundMessage { - match result { - Ok(response) => ExecServerOutboundMessage::Response { - request_id, - response, - }, - Err(error) => ExecServerOutboundMessage::Error { request_id, error }, - } -}