diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 8aa93c9687..11b3fa6e81 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -2009,14 +2009,17 @@ version = "0.0.0" dependencies = [ "anyhow", "base64 0.22.1", + "clap", "codex-app-server-protocol", "codex-utils-cargo-bin", "codex-utils-pty", + "futures", "pretty_assertions", "serde", "serde_json", "thiserror 2.0.18", "tokio", + "tokio-tungstenite", "tracing", ] diff --git a/codex-rs/exec-server/Cargo.toml b/codex-rs/exec-server/Cargo.toml index 7b1ec13cef..1b47760975 100644 --- a/codex-rs/exec-server/Cargo.toml +++ b/codex-rs/exec-server/Cargo.toml @@ -13,8 +13,10 @@ workspace = true [dependencies] base64 = { workspace = true } +clap = { workspace = true, features = ["derive"] } codex-app-server-protocol = { workspace = true } codex-utils-pty = { workspace = true } +futures = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } thiserror = { workspace = true } @@ -22,11 +24,13 @@ tokio = { workspace = true, features = [ "io-std", "io-util", "macros", + "net", "process", "rt-multi-thread", "sync", "time", ] } +tokio-tungstenite = { workspace = true } tracing = { workspace = true } [dev-dependencies] diff --git a/codex-rs/exec-server/README.md b/codex-rs/exec-server/README.md index 8aeb3a73ec..46961f5b42 100644 --- a/codex-rs/exec-server/README.md +++ b/codex-rs/exec-server/README.md @@ -1,24 +1,41 @@ # codex-exec-server -`codex-exec-server` is a small standalone stdio JSON-RPC server for spawning -and controlling subprocesses through `codex-utils-pty`. +`codex-exec-server` is a small standalone JSON-RPC server for spawning and +controlling subprocesses through `codex-utils-pty`. It currently provides: - a standalone binary: `codex-exec-server` +- a transport-agnostic server runtime with stdio and websocket entrypoints - a Rust client: `ExecServerClient` +- a separate local launch helper: `spawn_local_exec_server` - a small protocol module with shared request/response types This crate is intentionally narrow. It is not wired into the main Codex CLI or unified-exec in this PR; it is only the standalone transport layer. +The internal shape is intentionally closer to `app-server` than the first cut: + +- transport adapters are separate from the per-connection request processor +- the client only speaks the protocol; it does not spawn a server subprocess +- local child-process launch is handled by a separate helper/factory layer + +That split is meant to leave reusable seams if exec-server and app-server later +share transport or JSON-RPC connection utilities. + ## Transport -The server speaks newline-delimited JSON-RPC 2.0 over stdio. +The server speaks the same JSON-RPC message shapes over multiple transports. -- `stdin`: one JSON-RPC message per line -- `stdout`: one JSON-RPC message per line -- `stderr`: reserved for logs / process errors +The standalone binary supports: + +- `stdio://` (default) +- `ws://IP:PORT` + +Wire framing: + +- stdio: one newline-delimited JSON-RPC message per line on stdin/stdout +- websocket: one JSON-RPC message per websocket text frame Like the app-server transport, messages on the wire omit the `"jsonrpc":"2.0"` field and use the shared `codex-app-server-protocol` envelope types. @@ -41,11 +58,11 @@ Each connection follows this sequence: 5. Read streaming notifications from `command/exec/outputDelta` and `command/exec/exited`. -If the server receives any notification other than `initialized`, it replies -with an error using request id `-1`. +If the client sends exec methods before completing the `initialize` / +`initialized` handshake, the server rejects them. -If the stdio connection closes, the server terminates any remaining managed -processes before exiting. +If a connection closes, the server terminates any remaining managed processes +for that connection. ## API @@ -72,10 +89,10 @@ Response: ### `initialized` Handshake acknowledgement notification sent by the client after a successful -`initialize` response. +`initialize` response. Exec methods are rejected until this arrives. -Params are currently ignored. Sending any other notification method is treated -as an invalid request. +Params are currently ignored. Sending any other client notification method is a +protocol error. ### `command/exec` @@ -242,13 +259,43 @@ Typical error cases: The crate exports: - `ExecServerClient` +- `ExecServerClientConnectOptions` +- `RemoteExecServerConnectArgs` - `ExecServerLaunchCommand` - `ExecServerProcess` +- `SpawnedExecServer` - `ExecServerError` +- `ExecServerTransport` +- `spawn_local_exec_server(...)` - protocol structs such as `ExecParams`, `ExecResponse`, `WriteParams`, `TerminateParams`, `ExecOutputDeltaNotification`, and `ExecExitedNotification` -- `run_main()` for embedding the stdio server in a binary +- `run_main()` and `run_main_with_transport(...)` + +### Binary + +Run over stdio: + +```text +codex-exec-server +``` + +Run as a websocket server: + +```text +codex-exec-server --listen ws://127.0.0.1:8080 +``` + +### Client + +Connect the client to an existing server transport: + +- `ExecServerClient::connect_stdio(...)` +- `ExecServerClient::connect_websocket(...)` + +Spawning a local child process is deliberately separate: + +- `spawn_local_exec_server(...)` ## Example session diff --git a/codex-rs/exec-server/src/bin/codex-exec-server.rs b/codex-rs/exec-server/src/bin/codex-exec-server.rs index 399167c1a9..16df84d9b6 100644 --- a/codex-rs/exec-server/src/bin/codex-exec-server.rs +++ b/codex-rs/exec-server/src/bin/codex-exec-server.rs @@ -1,6 +1,22 @@ +use clap::Parser; +use codex_exec_server::ExecServerTransport; + +#[derive(Debug, Parser)] +struct ExecServerArgs { + /// Transport endpoint URL. Supported values: `stdio://` (default), + /// `ws://IP:PORT`. + #[arg( + long = "listen", + value_name = "URL", + default_value = ExecServerTransport::DEFAULT_LISTEN_URL + )] + listen: ExecServerTransport, +} + #[tokio::main] async fn main() { - if let Err(err) = codex_exec_server::run_main().await { + let args = ExecServerArgs::parse(); + if let Err(err) = codex_exec_server::run_main_with_transport(args.listen).await { eprintln!("{err}"); std::process::exit(1); } diff --git a/codex-rs/exec-server/src/client.rs b/codex-rs/exec-server/src/client.rs index 68a8de9b79..d68eaf3277 100644 --- a/codex-rs/exec-server/src/client.rs +++ b/codex-rs/exec-server/src/client.rs @@ -1,6 +1,4 @@ use std::collections::HashMap; -use std::path::PathBuf; -use std::process::Stdio; use std::sync::Arc; use std::sync::Mutex as StdMutex; use std::sync::atomic::AtomicBool; @@ -17,19 +15,19 @@ use codex_app_server_protocol::RequestId; use serde::Serialize; use serde::de::DeserializeOwned; use serde_json::Value; -use tokio::io::AsyncBufReadExt; -use tokio::io::AsyncWriteExt; -use tokio::io::BufReader; -use tokio::process::Child; -use tokio::process::Command; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; use tokio::sync::Mutex; use tokio::sync::broadcast; use tokio::sync::mpsc; use tokio::sync::oneshot; use tokio::task::JoinHandle; +use tokio_tungstenite::connect_async; use tracing::debug; use tracing::warn; +use crate::connection::JsonRpcConnection; +use crate::connection::JsonRpcConnectionEvent; use crate::protocol::EXEC_EXITED_METHOD; use crate::protocol::EXEC_METHOD; use crate::protocol::EXEC_OUTPUT_DELTA_METHOD; @@ -49,9 +47,30 @@ use crate::protocol::WriteParams; use crate::protocol::WriteResponse; #[derive(Debug, Clone, PartialEq, Eq)] -pub struct ExecServerLaunchCommand { - pub program: PathBuf, - pub args: Vec, +pub struct ExecServerClientConnectOptions { + pub client_name: String, +} + +impl Default for ExecServerClientConnectOptions { + fn default() -> Self { + Self { + client_name: "codex-core".to_string(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RemoteExecServerConnectArgs { + pub websocket_url: String, + pub client_name: String, +} + +impl From for ExecServerClientConnectOptions { + fn from(value: RemoteExecServerConnectArgs) -> Self { + Self { + client_name: value.client_name, + } + } } pub struct ExecServerProcess { @@ -143,24 +162,16 @@ struct RegisteredProcess { } struct Inner { - child: StdMutex>, - write_tx: mpsc::UnboundedSender, + write_tx: mpsc::Sender, pending: Mutex>>>, processes: Mutex>, next_request_id: AtomicI64, reader_task: JoinHandle<()>, - writer_task: JoinHandle<()>, } impl Drop for Inner { fn drop(&mut self) { self.reader_task.abort(); - self.writer_task.abort(); - if let Ok(mut child_guard) = self.child.lock() - && let Some(child) = child_guard.as_mut() - { - let _ = child.start_kill(); - } } } @@ -173,6 +184,12 @@ pub struct ExecServerClient { pub enum ExecServerError { #[error("failed to spawn exec-server: {0}")] Spawn(#[source] std::io::Error), + #[error("failed to connect to exec-server websocket `{url}`: {source}")] + WebSocketConnect { + url: String, + #[source] + source: tokio_tungstenite::tungstenite::Error, + }, #[error("exec-server transport closed")] Closed, #[error("failed to serialize or deserialize exec-server JSON: {0}")] @@ -184,102 +201,90 @@ pub enum ExecServerError { } impl ExecServerClient { - pub async fn spawn(command: ExecServerLaunchCommand) -> Result { - let mut child = Command::new(&command.program); - child.args(&command.args); - child.stdin(Stdio::piped()); - child.stdout(Stdio::piped()); - child.stderr(Stdio::inherit()); - child.kill_on_drop(true); + pub async fn connect_stdio( + stdin: W, + stdout: R, + options: ExecServerClientConnectOptions, + ) -> Result + where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, + { + Self::connect( + JsonRpcConnection::from_stdio(stdout, stdin, "exec-server stdio".to_string()), + options, + ) + .await + } - let mut child = child.spawn().map_err(ExecServerError::Spawn)?; - let stdin = child.stdin.take().ok_or_else(|| { - ExecServerError::Protocol("exec-server stdin was not captured".to_string()) - })?; - let stdout = child.stdout.take().ok_or_else(|| { - ExecServerError::Protocol("exec-server stdout was not captured".to_string()) - })?; + pub async fn connect_websocket( + args: RemoteExecServerConnectArgs, + ) -> Result { + let websocket_url = args.websocket_url.clone(); + let (stream, _) = connect_async(websocket_url.as_str()) + .await + .map_err(|source| ExecServerError::WebSocketConnect { + url: websocket_url.clone(), + source, + })?; - let (write_tx, mut write_rx) = mpsc::unbounded_channel::(); - let writer_task = tokio::spawn(async move { - let mut stdin = stdin; - while let Some(message) = write_rx.recv().await { - let encoded = match serde_json::to_vec(&message) { - Ok(encoded) => encoded, - Err(err) => { - warn!("failed to encode exec-server message: {err}"); - break; - } - }; - if stdin.write_all(&encoded).await.is_err() { - break; - } - if stdin.write_all(b"\n").await.is_err() { - break; - } - if stdin.flush().await.is_err() { - break; - } - } - }); + Self::connect( + JsonRpcConnection::from_websocket( + stream, + format!("exec-server websocket {websocket_url}"), + ), + args.into(), + ) + .await + } - let pending = Mutex::new(HashMap::< - RequestId, - oneshot::Sender>, - >::new()); - let processes = Mutex::new(HashMap::::new()); - let inner = Arc::new_cyclic(move |weak| { + async fn connect( + connection: JsonRpcConnection, + options: ExecServerClientConnectOptions, + ) -> Result { + let (write_tx, mut incoming_rx) = connection.into_parts(); + let inner = Arc::new_cyclic(|weak| { let weak = weak.clone(); let reader_task = tokio::spawn(async move { - let mut lines = BufReader::new(stdout).lines(); - loop { - let Some(inner) = weak.upgrade() else { - break; - }; - let next_line = lines.next_line().await; - match next_line { - Ok(Some(line)) => { - if line.trim().is_empty() { - continue; - } - match serde_json::from_str::(&line) { - Ok(message) => { - if let Err(err) = handle_server_message(&inner, message).await { - warn!("failed to handle exec-server message: {err}"); - break; - } - } - Err(err) => { - warn!("failed to parse exec-server message: {err}"); - break; - } + while let Some(event) = incoming_rx.recv().await { + match event { + JsonRpcConnectionEvent::Message(message) => { + if let Some(inner) = weak.upgrade() + && let Err(err) = handle_server_message(&inner, message).await + { + warn!("exec-server client closing after protocol error: {err}"); + handle_transport_shutdown(&inner).await; + return; } } - Ok(None) => break, - Err(err) => { - warn!("failed to read exec-server stdout: {err}"); - break; + JsonRpcConnectionEvent::Disconnected { reason } => { + if let Some(reason) = reason { + warn!("exec-server client transport disconnected: {reason}"); + } + if let Some(inner) = weak.upgrade() { + handle_transport_shutdown(&inner).await; + } + return; } } } + if let Some(inner) = weak.upgrade() { handle_transport_shutdown(&inner).await; } }); Inner { - child: StdMutex::new(Some(child)), write_tx, - pending, - processes, + pending: Mutex::new(HashMap::new()), + processes: Mutex::new(HashMap::new()), next_request_id: AtomicI64::new(1), reader_task, - writer_task, } }); let client = Self { inner }; - client.initialize().await?; + client.initialize(options).await?; Ok(client) } @@ -321,6 +326,29 @@ impl ExecServerClient { } }; + if !response.running { + status.mark_exited(response.exit_code); + } + + if let Some(stdout) = response.stdout { + let _ = self + .inner + .processes + .lock() + .await + .get(&process_id) + .map(|process| process.output_tx.send(stdout.into_inner())); + } + if let Some(stderr) = response.stderr { + let _ = self + .inner + .processes + .lock() + .await + .get(&process_id) + .map(|process| process.output_tx.send(stderr.into_inner())); + } + if let Some(exit_code) = response.exit_code { status.mark_exited(Some(exit_code)); } @@ -334,12 +362,15 @@ impl ExecServerClient { }) } - async fn initialize(&self) -> Result<(), ExecServerError> { + async fn initialize( + &self, + options: ExecServerClientConnectOptions, + ) -> Result<(), ExecServerError> { let _: InitializeResponse = self .request( INITIALIZE_METHOD, &InitializeParams { - client_name: "codex-core".to_string(), + client_name: options.client_name, }, ) .await?; @@ -372,6 +403,7 @@ impl ExecServerClient { method: method.to_string(), params: Some(params), })) + .await .map_err(|_| ExecServerError::Closed) } @@ -397,7 +429,7 @@ impl ExecServerClient { trace: None, }); - if self.inner.write_tx.send(message).is_err() { + if self.inner.write_tx.send(message).await.is_err() { self.inner.pending.lock().await.remove(&request_id); return Err(ExecServerError::Closed); } @@ -433,7 +465,7 @@ async fn handle_server_message( } JSONRPCMessage::Request(request) => { return Err(ExecServerError::Protocol(format!( - "unexpected exec-server request from child: {}", + "unexpected exec-server request from remote server: {}", request.method ))); } diff --git a/codex-rs/exec-server/src/connection.rs b/codex-rs/exec-server/src/connection.rs new file mode 100644 index 0000000000..f9a1ec669f --- /dev/null +++ b/codex-rs/exec-server/src/connection.rs @@ -0,0 +1,262 @@ +use codex_app_server_protocol::JSONRPCMessage; +use futures::SinkExt; +use futures::StreamExt; +use tokio::io::AsyncBufReadExt; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; +use tokio::io::AsyncWriteExt; +use tokio::io::BufReader; +use tokio::io::BufWriter; +use tokio::sync::mpsc; +use tokio_tungstenite::WebSocketStream; +use tokio_tungstenite::tungstenite::Message; + +pub(crate) const CHANNEL_CAPACITY: usize = 128; + +#[derive(Debug)] +pub(crate) enum JsonRpcConnectionEvent { + Message(JSONRPCMessage), + Disconnected { reason: Option }, +} + +pub(crate) struct JsonRpcConnection { + outgoing_tx: mpsc::Sender, + incoming_rx: mpsc::Receiver, +} + +impl JsonRpcConnection { + pub(crate) fn from_stdio(reader: R, writer: W, connection_label: String) -> Self + where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, + { + let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); + + let reader_label = connection_label.clone(); + let incoming_tx_for_reader = incoming_tx.clone(); + tokio::spawn(async move { + let mut lines = BufReader::new(reader).lines(); + loop { + match lines.next_line().await { + Ok(Some(line)) => { + if line.trim().is_empty() { + continue; + } + match serde_json::from_str::(&line) { + Ok(message) => { + if incoming_tx_for_reader + .send(JsonRpcConnectionEvent::Message(message)) + .await + .is_err() + { + break; + } + } + Err(err) => { + send_disconnected( + &incoming_tx_for_reader, + Some(format!( + "failed to parse JSON-RPC message from {reader_label}: {err}" + )), + ) + .await; + break; + } + } + } + Ok(None) => { + send_disconnected(&incoming_tx_for_reader, None).await; + break; + } + Err(err) => { + send_disconnected( + &incoming_tx_for_reader, + Some(format!( + "failed to read JSON-RPC message from {reader_label}: {err}" + )), + ) + .await; + break; + } + } + } + }); + + tokio::spawn(async move { + let mut writer = BufWriter::new(writer); + while let Some(message) = outgoing_rx.recv().await { + if let Err(err) = write_jsonrpc_line_message(&mut writer, &message).await { + send_disconnected( + &incoming_tx, + Some(format!( + "failed to write JSON-RPC message to {connection_label}: {err}" + )), + ) + .await; + break; + } + } + }); + + Self { + outgoing_tx, + incoming_rx, + } + } + + pub(crate) fn from_websocket(stream: WebSocketStream, connection_label: String) -> Self + where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (mut websocket_writer, mut websocket_reader) = stream.split(); + + let reader_label = connection_label.clone(); + let incoming_tx_for_reader = incoming_tx.clone(); + tokio::spawn(async move { + loop { + match websocket_reader.next().await { + Some(Ok(Message::Text(text))) => { + match serde_json::from_str::(text.as_ref()) { + Ok(message) => { + if incoming_tx_for_reader + .send(JsonRpcConnectionEvent::Message(message)) + .await + .is_err() + { + break; + } + } + Err(err) => { + send_disconnected( + &incoming_tx_for_reader, + Some(format!( + "failed to parse websocket JSON-RPC message from {reader_label}: {err}" + )), + ) + .await; + break; + } + } + } + Some(Ok(Message::Binary(bytes))) => { + match serde_json::from_slice::(bytes.as_ref()) { + Ok(message) => { + if incoming_tx_for_reader + .send(JsonRpcConnectionEvent::Message(message)) + .await + .is_err() + { + break; + } + } + Err(err) => { + send_disconnected( + &incoming_tx_for_reader, + Some(format!( + "failed to parse websocket JSON-RPC message from {reader_label}: {err}" + )), + ) + .await; + break; + } + } + } + Some(Ok(Message::Close(_))) => { + send_disconnected(&incoming_tx_for_reader, None).await; + break; + } + Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) => {} + Some(Ok(_)) => {} + Some(Err(err)) => { + send_disconnected( + &incoming_tx_for_reader, + Some(format!( + "failed to read websocket JSON-RPC message from {reader_label}: {err}" + )), + ) + .await; + break; + } + None => { + send_disconnected(&incoming_tx_for_reader, None).await; + break; + } + } + } + }); + + tokio::spawn(async move { + while let Some(message) = outgoing_rx.recv().await { + match serialize_jsonrpc_message(&message) { + Ok(encoded) => { + if let Err(err) = websocket_writer.send(Message::Text(encoded.into())).await + { + send_disconnected( + &incoming_tx, + Some(format!( + "failed to write websocket JSON-RPC message to {connection_label}: {err}" + )), + ) + .await; + break; + } + } + Err(err) => { + send_disconnected( + &incoming_tx, + Some(format!( + "failed to serialize JSON-RPC message for {connection_label}: {err}" + )), + ) + .await; + break; + } + } + } + }); + + Self { + outgoing_tx, + incoming_rx, + } + } + + pub(crate) fn into_parts( + self, + ) -> ( + mpsc::Sender, + mpsc::Receiver, + ) { + (self.outgoing_tx, self.incoming_rx) + } +} + +async fn send_disconnected( + incoming_tx: &mpsc::Sender, + reason: Option, +) { + let _ = incoming_tx + .send(JsonRpcConnectionEvent::Disconnected { reason }) + .await; +} + +async fn write_jsonrpc_line_message( + writer: &mut BufWriter, + message: &JSONRPCMessage, +) -> std::io::Result<()> +where + W: AsyncWrite + Unpin, +{ + let encoded = + serialize_jsonrpc_message(message).map_err(|err| std::io::Error::other(err.to_string()))?; + writer.write_all(encoded.as_bytes()).await?; + writer.write_all(b"\n").await?; + writer.flush().await +} + +fn serialize_jsonrpc_message(message: &JSONRPCMessage) -> Result { + serde_json::to_string(message) +} diff --git a/codex-rs/exec-server/src/lib.rs b/codex-rs/exec-server/src/lib.rs index 5e915ce5ca..4c975aa5e2 100644 --- a/codex-rs/exec-server/src/lib.rs +++ b/codex-rs/exec-server/src/lib.rs @@ -1,11 +1,17 @@ mod client; +mod connection; +mod local; mod protocol; mod server; pub use client::ExecServerClient; +pub use client::ExecServerClientConnectOptions; pub use client::ExecServerError; -pub use client::ExecServerLaunchCommand; pub use client::ExecServerProcess; +pub use client::RemoteExecServerConnectArgs; +pub use local::ExecServerLaunchCommand; +pub use local::SpawnedExecServer; +pub use local::spawn_local_exec_server; pub use protocol::ExecExitedNotification; pub use protocol::ExecOutputDeltaNotification; pub use protocol::ExecOutputStream; @@ -17,4 +23,7 @@ pub use protocol::TerminateParams; pub use protocol::TerminateResponse; pub use protocol::WriteParams; pub use protocol::WriteResponse; +pub use server::ExecServerTransport; +pub use server::ExecServerTransportParseError; pub use server::run_main; +pub use server::run_main_with_transport; diff --git a/codex-rs/exec-server/src/local.rs b/codex-rs/exec-server/src/local.rs new file mode 100644 index 0000000000..25d8cd3f3f --- /dev/null +++ b/codex-rs/exec-server/src/local.rs @@ -0,0 +1,70 @@ +use std::path::PathBuf; +use std::process::Stdio; +use std::sync::Mutex as StdMutex; + +use tokio::process::Child; +use tokio::process::Command; + +use crate::client::ExecServerClient; +use crate::client::ExecServerClientConnectOptions; +use crate::client::ExecServerError; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ExecServerLaunchCommand { + pub program: PathBuf, + pub args: Vec, +} + +pub struct SpawnedExecServer { + client: ExecServerClient, + child: StdMutex>, +} + +impl SpawnedExecServer { + pub fn client(&self) -> &ExecServerClient { + &self.client + } +} + +impl Drop for SpawnedExecServer { + fn drop(&mut self) { + if let Ok(mut child_guard) = self.child.lock() + && let Some(child) = child_guard.as_mut() + { + let _ = child.start_kill(); + } + } +} + +pub async fn spawn_local_exec_server( + command: ExecServerLaunchCommand, + options: ExecServerClientConnectOptions, +) -> Result { + let mut child = Command::new(&command.program); + child.args(&command.args); + child.stdin(Stdio::piped()); + child.stdout(Stdio::piped()); + child.stderr(Stdio::inherit()); + child.kill_on_drop(true); + + let mut child = child.spawn().map_err(ExecServerError::Spawn)?; + let stdin = child.stdin.take().ok_or_else(|| { + ExecServerError::Protocol("exec-server stdin was not captured".to_string()) + })?; + let stdout = child.stdout.take().ok_or_else(|| { + ExecServerError::Protocol("exec-server stdout was not captured".to_string()) + })?; + + let client = match ExecServerClient::connect_stdio(stdin, stdout, options).await { + Ok(client) => client, + Err(err) => { + let _ = child.start_kill(); + return Err(err); + } + }; + + Ok(SpawnedExecServer { + client, + child: StdMutex::new(Some(child)), + }) +} diff --git a/codex-rs/exec-server/src/server.rs b/codex-rs/exec-server/src/server.rs index 56d2206b0c..2a7ee08d7d 100644 --- a/codex-rs/exec-server/src/server.rs +++ b/codex-rs/exec-server/src/server.rs @@ -1,420 +1,15 @@ -use std::collections::HashMap; -use std::collections::VecDeque; -use std::sync::Arc; -use std::sync::Mutex as StdMutex; +mod processor; +mod transport; -use codex_app_server_protocol::JSONRPCError; -use codex_app_server_protocol::JSONRPCErrorError; -use codex_app_server_protocol::JSONRPCMessage; -use codex_app_server_protocol::JSONRPCNotification; -use codex_app_server_protocol::JSONRPCRequest; -use codex_app_server_protocol::JSONRPCResponse; -use codex_app_server_protocol::RequestId; -use codex_utils_pty::ExecCommandSession; -use codex_utils_pty::TerminalSize; -use serde::Serialize; -use tokio::io::AsyncBufReadExt; -use tokio::io::AsyncWriteExt; -use tokio::io::BufReader; -use tokio::io::BufWriter; -use tokio::sync::Mutex; - -use crate::protocol::EXEC_EXITED_METHOD; -use crate::protocol::EXEC_METHOD; -use crate::protocol::EXEC_OUTPUT_DELTA_METHOD; -use crate::protocol::EXEC_TERMINATE_METHOD; -use crate::protocol::EXEC_WRITE_METHOD; -use crate::protocol::ExecExitedNotification; -use crate::protocol::ExecOutputDeltaNotification; -use crate::protocol::ExecOutputStream; -use crate::protocol::ExecParams; -use crate::protocol::ExecResponse; -use crate::protocol::INITIALIZE_METHOD; -use crate::protocol::INITIALIZED_METHOD; -use crate::protocol::InitializeResponse; -use crate::protocol::PROTOCOL_VERSION; -use crate::protocol::TerminateParams; -use crate::protocol::TerminateResponse; -use crate::protocol::WriteParams; -use crate::protocol::WriteResponse; - -struct RunningProcess { - session: ExecCommandSession, - tty: bool, - stdout_buffer: Arc>, - stderr_buffer: Arc>, -} - -#[derive(Debug)] -struct BoundedBytesBuffer { - max_bytes: usize, - bytes: VecDeque, -} - -impl BoundedBytesBuffer { - fn new(max_bytes: usize) -> Self { - Self { - max_bytes, - bytes: VecDeque::with_capacity(max_bytes.min(8192)), - } - } - - fn push_chunk(&mut self, chunk: &[u8]) { - if self.max_bytes == 0 { - return; - } - for byte in chunk { - self.bytes.push_back(*byte); - if self.bytes.len() > self.max_bytes { - self.bytes.pop_front(); - } - } - } - - fn snapshot(&self) -> Vec { - self.bytes.iter().copied().collect() - } -} +pub use transport::ExecServerTransport; +pub use transport::ExecServerTransportParseError; pub async fn run_main() -> Result<(), Box> { - let writer = Arc::new(Mutex::new(BufWriter::new(tokio::io::stdout()))); - let processes = Arc::new(Mutex::new(HashMap::::new())); - let mut lines = BufReader::new(tokio::io::stdin()).lines(); - - while let Some(line) = lines.next_line().await? { - if line.trim().is_empty() { - continue; - } - - let message = serde_json::from_str::(&line)?; - if let JSONRPCMessage::Request(request) = message { - handle_request(request, &writer, &processes).await; - continue; - } - - if let JSONRPCMessage::Notification(notification) = message { - if notification.method != INITIALIZED_METHOD { - send_error( - &writer, - RequestId::Integer(-1), - invalid_request(format!( - "unexpected notification method: {}", - notification.method - )), - ) - .await; - } - continue; - } - } - - let remaining = { - let mut processes = processes.lock().await; - processes - .drain() - .map(|(_, process)| process) - .collect::>() - }; - for process in remaining { - process.session.terminate(); - } - - Ok(()) + run_main_with_transport(ExecServerTransport::Stdio).await } -async fn handle_request( - request: JSONRPCRequest, - writer: &Arc>>, - processes: &Arc>>, -) { - let response = match request.method.as_str() { - INITIALIZE_METHOD => serde_json::to_value(InitializeResponse { - protocol_version: PROTOCOL_VERSION.to_string(), - }) - .map_err(|err| internal_error(err.to_string())), - EXEC_METHOD => handle_exec_request(request.params, writer, processes).await, - EXEC_WRITE_METHOD => handle_write_request(request.params, processes).await, - EXEC_TERMINATE_METHOD => handle_terminate_request(request.params, processes).await, - other => Err(invalid_request(format!("unknown method: {other}"))), - }; - - match response { - Ok(result) => { - send_response( - writer, - JSONRPCResponse { - id: request.id, - result, - }, - ) - .await; - } - Err(err) => { - send_error(writer, request.id, err).await; - } - } -} - -async fn handle_exec_request( - params: Option, - writer: &Arc>>, - processes: &Arc>>, -) -> Result { - let params: ExecParams = serde_json::from_value(params.unwrap_or(serde_json::Value::Null)) - .map_err(|err| invalid_params(err.to_string()))?; - - let (program, args) = params - .argv - .split_first() - .ok_or_else(|| invalid_params("argv must not be empty".to_string()))?; - - let spawned = if params.tty { - codex_utils_pty::spawn_pty_process( - program, - args, - params.cwd.as_path(), - ¶ms.env, - ¶ms.arg0, - TerminalSize::default(), - ) - .await - } else { - codex_utils_pty::spawn_pipe_process_no_stdin( - program, - args, - params.cwd.as_path(), - ¶ms.env, - ¶ms.arg0, - ) - .await - } - .map_err(|err| internal_error(err.to_string()))?; - - let stdout_buffer = Arc::new(StdMutex::new(BoundedBytesBuffer::new( - params.output_bytes_cap, - ))); - let stderr_buffer = Arc::new(StdMutex::new(BoundedBytesBuffer::new( - params.output_bytes_cap, - ))); - - let process_id = params.process_id.clone(); - { - let mut process_map = processes.lock().await; - if process_map.contains_key(&process_id) { - spawned.session.terminate(); - return Err(invalid_request(format!( - "process {} already exists", - params.process_id - ))); - } - process_map.insert( - process_id.clone(), - RunningProcess { - session: spawned.session, - tty: params.tty, - stdout_buffer: Arc::clone(&stdout_buffer), - stderr_buffer: Arc::clone(&stderr_buffer), - }, - ); - } - - tokio::spawn(stream_output( - process_id.clone(), - ExecOutputStream::Stdout, - spawned.stdout_rx, - Arc::clone(writer), - Arc::clone(&stdout_buffer), - )); - tokio::spawn(stream_output( - process_id.clone(), - ExecOutputStream::Stderr, - spawned.stderr_rx, - Arc::clone(writer), - Arc::clone(&stderr_buffer), - )); - tokio::spawn(watch_exit( - process_id.clone(), - spawned.exit_rx, - Arc::clone(writer), - Arc::clone(processes), - )); - - serde_json::to_value(ExecResponse { - process_id, - running: true, - exit_code: None, - stdout: None, - stderr: None, - }) - .map_err(|err| internal_error(err.to_string())) -} - -async fn handle_write_request( - params: Option, - processes: &Arc>>, -) -> Result { - let params: WriteParams = serde_json::from_value(params.unwrap_or(serde_json::Value::Null)) - .map_err(|err| invalid_params(err.to_string()))?; - - let writer_tx = { - let process_map = processes.lock().await; - let process = process_map - .get(¶ms.process_id) - .ok_or_else(|| invalid_request(format!("unknown process id {}", params.process_id)))?; - if !process.tty { - return Err(invalid_request(format!( - "stdin is closed for process {}", - params.process_id - ))); - } - process.session.writer_sender() - }; - - writer_tx - .send(params.chunk.into_inner()) - .await - .map_err(|_| internal_error("failed to write to process stdin".to_string()))?; - - serde_json::to_value(WriteResponse { accepted: true }) - .map_err(|err| internal_error(err.to_string())) -} - -async fn handle_terminate_request( - params: Option, - processes: &Arc>>, -) -> Result { - let params: TerminateParams = serde_json::from_value(params.unwrap_or(serde_json::Value::Null)) - .map_err(|err| invalid_params(err.to_string()))?; - - let process = { - let mut process_map = processes.lock().await; - process_map.remove(¶ms.process_id) - }; - - if let Some(process) = process { - process.session.terminate(); - serde_json::to_value(TerminateResponse { running: true }) - .map_err(|err| internal_error(err.to_string())) - } else { - serde_json::to_value(TerminateResponse { running: false }) - .map_err(|err| internal_error(err.to_string())) - } -} - -async fn stream_output( - process_id: String, - stream: ExecOutputStream, - mut receiver: tokio::sync::mpsc::Receiver>, - writer: Arc>>, - buffer: Arc>, -) { - while let Some(chunk) = receiver.recv().await { - if let Ok(mut guard) = buffer.lock() { - guard.push_chunk(&chunk); - } - let notification = ExecOutputDeltaNotification { - process_id: process_id.clone(), - stream, - chunk: chunk.into(), - }; - if send_notification(&writer, EXEC_OUTPUT_DELTA_METHOD, ¬ification) - .await - .is_err() - { - break; - } - } -} - -async fn watch_exit( - process_id: String, - exit_rx: tokio::sync::oneshot::Receiver, - writer: Arc>>, - processes: Arc>>, -) { - let exit_code = exit_rx.await.unwrap_or(-1); - let removed = { - let mut processes = processes.lock().await; - processes.remove(&process_id) - }; - if let Some(process) = removed { - let _ = process.stdout_buffer.lock().map(|buffer| buffer.snapshot()); - let _ = process.stderr_buffer.lock().map(|buffer| buffer.snapshot()); - } - let _ = send_notification( - &writer, - EXEC_EXITED_METHOD, - &ExecExitedNotification { - process_id, - exit_code, - }, - ) - .await; -} - -async fn send_response( - writer: &Arc>>, - response: JSONRPCResponse, -) { - let _ = send_message(writer, JSONRPCMessage::Response(response)).await; -} - -async fn send_error( - writer: &Arc>>, - id: RequestId, - error: JSONRPCErrorError, -) { - let _ = send_message(writer, JSONRPCMessage::Error(JSONRPCError { error, id })).await; -} - -async fn send_notification( - writer: &Arc>>, - method: &str, - params: &T, -) -> Result<(), serde_json::Error> { - send_message( - writer, - JSONRPCMessage::Notification(JSONRPCNotification { - method: method.to_string(), - params: Some(serde_json::to_value(params)?), - }), - ) - .await - .map_err(serde_json::Error::io) -} - -async fn send_message( - writer: &Arc>>, - message: JSONRPCMessage, -) -> std::io::Result<()> { - let encoded = - serde_json::to_vec(&message).map_err(|err| std::io::Error::other(err.to_string()))?; - let mut writer = writer.lock().await; - writer.write_all(&encoded).await?; - writer.write_all(b"\n").await?; - writer.flush().await -} - -fn invalid_request(message: String) -> JSONRPCErrorError { - JSONRPCErrorError { - code: -32600, - data: None, - message, - } -} - -fn invalid_params(message: String) -> JSONRPCErrorError { - JSONRPCErrorError { - code: -32602, - data: None, - message, - } -} - -fn internal_error(message: String) -> JSONRPCErrorError { - JSONRPCErrorError { - code: -32603, - data: None, - message, - } +pub async fn run_main_with_transport( + transport: ExecServerTransport, +) -> Result<(), Box> { + transport::run_transport(transport).await } diff --git a/codex-rs/exec-server/src/server/processor.rs b/codex-rs/exec-server/src/server/processor.rs new file mode 100644 index 0000000000..d0f9252e3a --- /dev/null +++ b/codex-rs/exec-server/src/server/processor.rs @@ -0,0 +1,468 @@ +use std::collections::HashMap; +use std::collections::VecDeque; +use std::sync::Arc; +use std::sync::Mutex as StdMutex; + +use codex_app_server_protocol::JSONRPCError; +use codex_app_server_protocol::JSONRPCErrorError; +use codex_app_server_protocol::JSONRPCMessage; +use codex_app_server_protocol::JSONRPCNotification; +use codex_app_server_protocol::JSONRPCRequest; +use codex_app_server_protocol::JSONRPCResponse; +use codex_app_server_protocol::RequestId; +use codex_utils_pty::ExecCommandSession; +use codex_utils_pty::TerminalSize; +use tokio::sync::Mutex; +use tokio::sync::mpsc; +use tracing::debug; +use tracing::warn; + +use crate::connection::JsonRpcConnection; +use crate::connection::JsonRpcConnectionEvent; +use crate::protocol::EXEC_EXITED_METHOD; +use crate::protocol::EXEC_METHOD; +use crate::protocol::EXEC_OUTPUT_DELTA_METHOD; +use crate::protocol::EXEC_TERMINATE_METHOD; +use crate::protocol::EXEC_WRITE_METHOD; +use crate::protocol::ExecExitedNotification; +use crate::protocol::ExecOutputDeltaNotification; +use crate::protocol::ExecOutputStream; +use crate::protocol::ExecParams; +use crate::protocol::ExecResponse; +use crate::protocol::INITIALIZE_METHOD; +use crate::protocol::INITIALIZED_METHOD; +use crate::protocol::InitializeResponse; +use crate::protocol::PROTOCOL_VERSION; +use crate::protocol::TerminateParams; +use crate::protocol::TerminateResponse; +use crate::protocol::WriteParams; +use crate::protocol::WriteResponse; + +struct RunningProcess { + session: ExecCommandSession, + tty: bool, + stdout_buffer: Arc>, + stderr_buffer: Arc>, +} + +#[derive(Debug)] +struct BoundedBytesBuffer { + max_bytes: usize, + bytes: VecDeque, +} + +impl BoundedBytesBuffer { + fn new(max_bytes: usize) -> Self { + Self { + max_bytes, + bytes: VecDeque::with_capacity(max_bytes.min(8192)), + } + } + + fn push_chunk(&mut self, chunk: &[u8]) { + if self.max_bytes == 0 { + return; + } + for byte in chunk { + self.bytes.push_back(*byte); + if self.bytes.len() > self.max_bytes { + self.bytes.pop_front(); + } + } + } + + fn snapshot(&self) -> Vec { + self.bytes.iter().copied().collect() + } +} + +pub(crate) async fn run_connection(connection: JsonRpcConnection) { + let (outgoing_tx, mut incoming_rx) = connection.into_parts(); + let mut processor = ExecServerConnectionProcessor::new(outgoing_tx); + + while let Some(event) = incoming_rx.recv().await { + match event { + JsonRpcConnectionEvent::Message(message) => { + if let Err(err) = processor.handle_message(message).await { + warn!("closing exec-server connection after protocol error: {err}"); + break; + } + } + JsonRpcConnectionEvent::Disconnected { reason } => { + if let Some(reason) = reason { + debug!("exec-server connection disconnected: {reason}"); + } + break; + } + } + } + + processor.shutdown().await; +} + +struct ExecServerConnectionProcessor { + outgoing_tx: mpsc::Sender, + processes: Arc>>, + initialize_requested: bool, + initialized: bool, +} + +impl ExecServerConnectionProcessor { + fn new(outgoing_tx: mpsc::Sender) -> Self { + Self { + outgoing_tx, + processes: Arc::new(Mutex::new(HashMap::new())), + initialize_requested: false, + initialized: false, + } + } + + async fn shutdown(&self) { + let remaining = { + let mut processes = self.processes.lock().await; + processes + .drain() + .map(|(_, process)| process) + .collect::>() + }; + for process in remaining { + process.session.terminate(); + } + } + + async fn handle_message(&mut self, message: JSONRPCMessage) -> Result<(), String> { + match message { + JSONRPCMessage::Request(request) => self.handle_request(request).await, + JSONRPCMessage::Notification(notification) => self.handle_notification(notification), + JSONRPCMessage::Response(response) => Err(format!( + "unexpected client response for request id {:?}", + response.id + )), + JSONRPCMessage::Error(error) => Err(format!( + "unexpected client error for request id {:?}", + error.id + )), + } + } + + async fn handle_request(&mut self, request: JSONRPCRequest) -> Result<(), String> { + let response = match request.method.as_str() { + INITIALIZE_METHOD => self.handle_initialize_request(), + EXEC_METHOD => match self.require_initialized() { + Ok(()) => self.handle_exec_request(request.params).await, + Err(err) => Err(err), + }, + EXEC_WRITE_METHOD => match self.require_initialized() { + Ok(()) => self.handle_write_request(request.params).await, + Err(err) => Err(err), + }, + EXEC_TERMINATE_METHOD => match self.require_initialized() { + Ok(()) => self.handle_terminate_request(request.params).await, + Err(err) => Err(err), + }, + other => Err(invalid_request(format!("unknown method: {other}"))), + }; + + match response { + Ok(result) => { + self.send_response(JSONRPCResponse { + id: request.id, + result, + }) + .await; + } + Err(error) => { + self.send_error(request.id, error).await; + } + } + + Ok(()) + } + + fn handle_notification(&mut self, notification: JSONRPCNotification) -> Result<(), String> { + match notification.method.as_str() { + INITIALIZED_METHOD => { + if !self.initialize_requested { + return Err("received `initialized` notification before `initialize`".into()); + } + self.initialized = true; + Ok(()) + } + other => Err(format!("unexpected notification method: {other}")), + } + } + + fn handle_initialize_request(&mut self) -> Result { + if self.initialize_requested { + return Err(invalid_request( + "initialize may only be sent once per connection".to_string(), + )); + } + self.initialize_requested = true; + json_value(InitializeResponse { + protocol_version: PROTOCOL_VERSION.to_string(), + }) + } + + fn require_initialized(&self) -> Result<(), JSONRPCErrorError> { + if !self.initialize_requested { + return Err(invalid_request( + "client must call initialize before using exec methods".to_string(), + )); + } + if !self.initialized { + return Err(invalid_request( + "client must send initialized before using exec methods".to_string(), + )); + } + Ok(()) + } + + async fn handle_exec_request( + &self, + params: Option, + ) -> Result { + let params: ExecParams = serde_json::from_value(params.unwrap_or(serde_json::Value::Null)) + .map_err(|err| invalid_params(err.to_string()))?; + + let (program, args) = params + .argv + .split_first() + .ok_or_else(|| invalid_params("argv must not be empty".to_string()))?; + + let spawned = if params.tty { + codex_utils_pty::spawn_pty_process( + program, + args, + params.cwd.as_path(), + ¶ms.env, + ¶ms.arg0, + TerminalSize::default(), + ) + .await + } else { + codex_utils_pty::spawn_pipe_process_no_stdin( + program, + args, + params.cwd.as_path(), + ¶ms.env, + ¶ms.arg0, + ) + .await + } + .map_err(|err| internal_error(err.to_string()))?; + + let stdout_buffer = Arc::new(StdMutex::new(BoundedBytesBuffer::new( + params.output_bytes_cap, + ))); + let stderr_buffer = Arc::new(StdMutex::new(BoundedBytesBuffer::new( + params.output_bytes_cap, + ))); + + let process_id = params.process_id.clone(); + { + let mut process_map = self.processes.lock().await; + if process_map.contains_key(&process_id) { + spawned.session.terminate(); + return Err(invalid_request(format!( + "process {process_id} already exists" + ))); + } + process_map.insert( + process_id.clone(), + RunningProcess { + session: spawned.session, + tty: params.tty, + stdout_buffer: Arc::clone(&stdout_buffer), + stderr_buffer: Arc::clone(&stderr_buffer), + }, + ); + } + + tokio::spawn(stream_output( + process_id.clone(), + ExecOutputStream::Stdout, + spawned.stdout_rx, + self.outgoing_tx.clone(), + Arc::clone(&stdout_buffer), + )); + tokio::spawn(stream_output( + process_id.clone(), + ExecOutputStream::Stderr, + spawned.stderr_rx, + self.outgoing_tx.clone(), + Arc::clone(&stderr_buffer), + )); + tokio::spawn(watch_exit( + process_id.clone(), + spawned.exit_rx, + self.outgoing_tx.clone(), + Arc::clone(&self.processes), + )); + + json_value(ExecResponse { + process_id, + running: true, + exit_code: None, + stdout: None, + stderr: None, + }) + } + + async fn handle_write_request( + &self, + params: Option, + ) -> Result { + let params: WriteParams = serde_json::from_value(params.unwrap_or(serde_json::Value::Null)) + .map_err(|err| invalid_params(err.to_string()))?; + + let writer_tx = { + let process_map = self.processes.lock().await; + let process = process_map.get(¶ms.process_id).ok_or_else(|| { + invalid_request(format!("unknown process id {}", params.process_id)) + })?; + if !process.tty { + return Err(invalid_request(format!( + "stdin is closed for process {}", + params.process_id + ))); + } + process.session.writer_sender() + }; + + writer_tx + .send(params.chunk.into_inner()) + .await + .map_err(|_| internal_error("failed to write to process stdin".to_string()))?; + + json_value(WriteResponse { accepted: true }) + } + + async fn handle_terminate_request( + &self, + params: Option, + ) -> Result { + let params: TerminateParams = + serde_json::from_value(params.unwrap_or(serde_json::Value::Null)) + .map_err(|err| invalid_params(err.to_string()))?; + + let process = { + let mut process_map = self.processes.lock().await; + process_map.remove(¶ms.process_id) + }; + + if let Some(process) = process { + process.session.terminate(); + json_value(TerminateResponse { running: true }) + } else { + json_value(TerminateResponse { running: false }) + } + } + + async fn send_response(&self, response: JSONRPCResponse) { + let _ = self + .outgoing_tx + .send(JSONRPCMessage::Response(response)) + .await; + } + + async fn send_error(&self, id: RequestId, error: JSONRPCErrorError) { + let _ = self + .outgoing_tx + .send(JSONRPCMessage::Error(JSONRPCError { error, id })) + .await; + } +} + +async fn stream_output( + process_id: String, + stream: ExecOutputStream, + mut receiver: tokio::sync::mpsc::Receiver>, + outgoing_tx: mpsc::Sender, + buffer: Arc>, +) { + while let Some(chunk) = receiver.recv().await { + if let Ok(mut guard) = buffer.lock() { + guard.push_chunk(&chunk); + } + let notification = ExecOutputDeltaNotification { + process_id: process_id.clone(), + stream, + chunk: chunk.into(), + }; + if send_notification(&outgoing_tx, EXEC_OUTPUT_DELTA_METHOD, ¬ification) + .await + .is_err() + { + break; + } + } +} + +async fn watch_exit( + process_id: String, + exit_rx: tokio::sync::oneshot::Receiver, + outgoing_tx: mpsc::Sender, + processes: Arc>>, +) { + let exit_code = exit_rx.await.unwrap_or(-1); + let removed = { + let mut processes = processes.lock().await; + processes.remove(&process_id) + }; + if let Some(process) = removed { + let _ = process.stdout_buffer.lock().map(|buffer| buffer.snapshot()); + let _ = process.stderr_buffer.lock().map(|buffer| buffer.snapshot()); + } + let _ = send_notification( + &outgoing_tx, + EXEC_EXITED_METHOD, + &ExecExitedNotification { + process_id, + exit_code, + }, + ) + .await; +} + +async fn send_notification( + outgoing_tx: &mpsc::Sender, + method: &str, + params: &T, +) -> Result<(), serde_json::Error> { + outgoing_tx + .send(JSONRPCMessage::Notification(JSONRPCNotification { + method: method.to_string(), + params: Some(serde_json::to_value(params)?), + })) + .await + .map_err(|_| serde_json::Error::io(std::io::Error::other("connection closed"))) +} + +fn json_value(value: T) -> Result { + serde_json::to_value(value).map_err(|err| internal_error(err.to_string())) +} + +fn invalid_request(message: String) -> JSONRPCErrorError { + JSONRPCErrorError { + code: -32600, + data: None, + message, + } +} + +fn invalid_params(message: String) -> JSONRPCErrorError { + JSONRPCErrorError { + code: -32602, + data: None, + message, + } +} + +fn internal_error(message: String) -> JSONRPCErrorError { + JSONRPCErrorError { + code: -32603, + data: None, + message, + } +} diff --git a/codex-rs/exec-server/src/server/transport.rs b/codex-rs/exec-server/src/server/transport.rs new file mode 100644 index 0000000000..b653c0b79b --- /dev/null +++ b/codex-rs/exec-server/src/server/transport.rs @@ -0,0 +1,166 @@ +use std::net::SocketAddr; +use std::str::FromStr; + +use tokio::net::TcpListener; +use tokio_tungstenite::accept_async; +use tracing::warn; + +use crate::connection::JsonRpcConnection; +use crate::server::processor::run_connection; + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum ExecServerTransport { + Stdio, + WebSocket { bind_address: SocketAddr }, +} + +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum ExecServerTransportParseError { + UnsupportedListenUrl(String), + InvalidWebSocketListenUrl(String), +} + +impl std::fmt::Display for ExecServerTransportParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ExecServerTransportParseError::UnsupportedListenUrl(listen_url) => write!( + f, + "unsupported --listen URL `{listen_url}`; expected `stdio://` or `ws://IP:PORT`" + ), + ExecServerTransportParseError::InvalidWebSocketListenUrl(listen_url) => write!( + f, + "invalid websocket --listen URL `{listen_url}`; expected `ws://IP:PORT`" + ), + } + } +} + +impl std::error::Error for ExecServerTransportParseError {} + +impl ExecServerTransport { + pub const DEFAULT_LISTEN_URL: &str = "stdio://"; + + pub fn from_listen_url(listen_url: &str) -> Result { + if listen_url == Self::DEFAULT_LISTEN_URL { + return Ok(Self::Stdio); + } + + if let Some(socket_addr) = listen_url.strip_prefix("ws://") { + let bind_address = socket_addr.parse::().map_err(|_| { + ExecServerTransportParseError::InvalidWebSocketListenUrl(listen_url.to_string()) + })?; + return Ok(Self::WebSocket { bind_address }); + } + + Err(ExecServerTransportParseError::UnsupportedListenUrl( + listen_url.to_string(), + )) + } +} + +impl FromStr for ExecServerTransport { + type Err = ExecServerTransportParseError; + + fn from_str(s: &str) -> Result { + Self::from_listen_url(s) + } +} + +pub(crate) async fn run_transport( + transport: ExecServerTransport, +) -> Result<(), Box> { + match transport { + ExecServerTransport::Stdio => { + run_connection(JsonRpcConnection::from_stdio( + tokio::io::stdin(), + tokio::io::stdout(), + "exec-server stdio".to_string(), + )) + .await; + Ok(()) + } + ExecServerTransport::WebSocket { bind_address } => { + run_websocket_listener(bind_address).await + } + } +} + +async fn run_websocket_listener( + bind_address: SocketAddr, +) -> Result<(), Box> { + let listener = TcpListener::bind(bind_address).await?; + let local_addr = listener.local_addr()?; + print_websocket_startup_banner(local_addr); + + loop { + let (stream, peer_addr) = listener.accept().await?; + tokio::spawn(async move { + match accept_async(stream).await { + Ok(websocket) => { + run_connection(JsonRpcConnection::from_websocket( + websocket, + format!("exec-server websocket {peer_addr}"), + )) + .await; + } + Err(err) => { + warn!( + "failed to accept exec-server websocket connection from {peer_addr}: {err}" + ); + } + } + }); + } +} + +#[allow(clippy::print_stderr)] +fn print_websocket_startup_banner(addr: SocketAddr) { + eprintln!("codex-exec-server listening on ws://{addr}"); +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::ExecServerTransport; + + #[test] + fn exec_server_transport_parses_stdio_listen_url() { + let transport = + ExecServerTransport::from_listen_url(ExecServerTransport::DEFAULT_LISTEN_URL) + .expect("stdio listen URL should parse"); + assert_eq!(transport, ExecServerTransport::Stdio); + } + + #[test] + fn exec_server_transport_parses_websocket_listen_url() { + let transport = ExecServerTransport::from_listen_url("ws://127.0.0.1:1234") + .expect("websocket listen URL should parse"); + assert_eq!( + transport, + ExecServerTransport::WebSocket { + bind_address: "127.0.0.1:1234".parse().expect("valid socket address"), + } + ); + } + + #[test] + fn exec_server_transport_rejects_invalid_websocket_listen_url() { + let err = ExecServerTransport::from_listen_url("ws://localhost:1234") + .expect_err("hostname bind address should be rejected"); + assert_eq!( + err.to_string(), + "invalid websocket --listen URL `ws://localhost:1234`; expected `ws://IP:PORT`" + ); + } + + #[test] + fn exec_server_transport_rejects_unsupported_listen_url() { + let err = ExecServerTransport::from_listen_url("http://127.0.0.1:1234") + .expect_err("unsupported scheme should fail"); + assert_eq!( + err.to_string(), + "unsupported --listen URL `http://127.0.0.1:1234`; expected `stdio://` or `ws://IP:PORT`" + ); + } +} diff --git a/codex-rs/exec-server/tests/stdio_smoke.rs b/codex-rs/exec-server/tests/stdio_smoke.rs index 79f264327c..0fd464ace0 100644 --- a/codex-rs/exec-server/tests/stdio_smoke.rs +++ b/codex-rs/exec-server/tests/stdio_smoke.rs @@ -3,6 +3,7 @@ use std::process::Stdio; use std::time::Duration; +use anyhow::Context; use codex_app_server_protocol::JSONRPCMessage; use codex_app_server_protocol::JSONRPCNotification; use codex_app_server_protocol::JSONRPCRequest; @@ -10,9 +11,12 @@ use codex_app_server_protocol::JSONRPCResponse; use codex_app_server_protocol::RequestId; use codex_exec_server::ExecParams; use codex_exec_server::ExecServerClient; +use codex_exec_server::ExecServerClientConnectOptions; use codex_exec_server::ExecServerLaunchCommand; use codex_exec_server::InitializeParams; use codex_exec_server::InitializeResponse; +use codex_exec_server::RemoteExecServerConnectArgs; +use codex_exec_server::spawn_local_exec_server; use codex_utils_cargo_bin::cargo_bin; use pretty_assertions::assert_eq; use tokio::io::AsyncBufReadExt; @@ -76,13 +80,19 @@ async fn exec_server_client_streams_output_and_accepts_writes() -> anyhow::Resul env.insert("PATH".to_string(), path.to_string_lossy().into_owned()); } - let client = ExecServerClient::spawn(ExecServerLaunchCommand { - program: cargo_bin("codex-exec-server")?, - args: Vec::new(), - }) + let server = spawn_local_exec_server( + ExecServerLaunchCommand { + program: cargo_bin("codex-exec-server")?, + args: Vec::new(), + }, + ExecServerClientConnectOptions { + client_name: "exec-server-test".to_string(), + }, + ) .await?; - let process = client + let process = server + .client() .start_process(ExecParams { process_id: "2001".to_string(), argv: vec![ @@ -124,6 +134,86 @@ async fn exec_server_client_streams_output_and_accepts_writes() -> anyhow::Resul Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn exec_server_client_connects_over_websocket() -> anyhow::Result<()> { + let mut env = std::collections::HashMap::new(); + if let Some(path) = std::env::var_os("PATH") { + env.insert("PATH".to_string(), path.to_string_lossy().into_owned()); + } + + let binary = cargo_bin("codex-exec-server")?; + let mut child = Command::new(binary); + child.args(["--listen", "ws://127.0.0.1:0"]); + child.stdin(Stdio::null()); + child.stdout(Stdio::null()); + child.stderr(Stdio::piped()); + let mut child = child.spawn()?; + let stderr = child.stderr.take().expect("stderr"); + let mut stderr_lines = BufReader::new(stderr).lines(); + let websocket_url = read_websocket_url(&mut stderr_lines).await?; + + let client = ExecServerClient::connect_websocket(RemoteExecServerConnectArgs { + websocket_url, + client_name: "exec-server-test".to_string(), + }) + .await?; + + let process = client + .start_process(ExecParams { + process_id: "2002".to_string(), + argv: vec![ + "bash".to_string(), + "-lc".to_string(), + "printf 'ready\\n'; while IFS= read -r line; do printf 'echo:%s\\n' \"$line\"; done" + .to_string(), + ], + cwd: std::env::current_dir()?, + env, + tty: true, + output_bytes_cap: 4096, + arg0: None, + }) + .await?; + + let mut output = process.output_receiver(); + assert!( + recv_until_contains(&mut output, "ready") + .await? + .contains("ready"), + "expected initial ready output" + ); + + process + .writer_sender() + .send(b"hello\n".to_vec()) + .await + .expect("write should succeed"); + + assert!( + recv_until_contains(&mut output, "echo:hello") + .await? + .contains("echo:hello"), + "expected echoed output" + ); + + process.terminate(); + child.start_kill()?; + Ok(()) +} + +async fn read_websocket_url(lines: &mut tokio::io::Lines>) -> anyhow::Result +where + R: tokio::io::AsyncRead + Unpin, +{ + let line = timeout(Duration::from_secs(5), lines.next_line()).await??; + let line = line.context("missing websocket startup banner")?; + let websocket_url = line + .split_whitespace() + .find(|part| part.starts_with("ws://")) + .context("missing websocket URL in startup banner")?; + Ok(websocket_url.to_string()) +} + async fn recv_until_contains( output: &mut broadcast::Receiver>, needle: &str,