diff --git a/codex-rs/exec-server/src/client.rs b/codex-rs/exec-server/src/client.rs index 0302aa792e..4997c1605c 100644 --- a/codex-rs/exec-server/src/client.rs +++ b/codex-rs/exec-server/src/client.rs @@ -310,7 +310,13 @@ impl ExecServerClient { pub(crate) async fn register_session( &self, process_id: &str, - ) -> Result, ExecServerError> { + ) -> Result< + ( + broadcast::Sender, + broadcast::Receiver, + ), + ExecServerError, + > { let (events_tx, events_rx) = broadcast::channel(256); let mut sessions = self.inner.sessions.lock().await; if sessions.contains_key(process_id) { @@ -318,8 +324,8 @@ impl ExecServerClient { "session already registered for process {process_id}" ))); } - sessions.insert(process_id.to_string(), events_tx); - Ok(events_rx) + sessions.insert(process_id.to_string(), events_tx.clone()); + Ok((events_tx, events_rx)) } pub(crate) async fn unregister_session(&self, process_id: &str) { diff --git a/codex-rs/exec-server/src/local_process.rs b/codex-rs/exec-server/src/local_process.rs index 2256107950..a3df684f6a 100644 --- a/codex-rs/exec-server/src/local_process.rs +++ b/codex-rs/exec-server/src/local_process.rs @@ -89,7 +89,8 @@ pub(crate) struct LocalProcess { struct LocalExecProcess { process_id: ProcessId, - events: StdMutex>, + events_tx: broadcast::Sender, + initial_events_rx: StdMutex>>, backend: LocalProcess, } @@ -173,7 +174,14 @@ impl LocalProcess { async fn start_process( &self, params: ExecParams, - ) -> Result<(ExecResponse, broadcast::Receiver), JSONRPCErrorError> { + ) -> Result< + ( + ExecResponse, + broadcast::Sender, + broadcast::Receiver, + ), + JSONRPCErrorError, + > { self.require_initialized_for("exec")?; let process_id = params.process_id.clone(); warn!( @@ -249,7 +257,7 @@ impl LocalProcess { next_seq: 1, exit_code: None, output_notify: Arc::clone(&output_notify), - session_events_tx, + session_events_tx: session_events_tx.clone(), open_streams: 2, closed: false, })), @@ -290,13 +298,17 @@ impl LocalProcess { tty = params.tty, "exec-server started process" ); - Ok((ExecResponse { process_id }, session_events_rx)) + Ok(( + ExecResponse { process_id }, + session_events_tx, + session_events_rx, + )) } pub(crate) async fn exec(&self, params: ExecParams) -> Result { self.start_process(params) .await - .map(|(response, _)| response) + .map(|(response, _, _)| response) } pub(crate) async fn exec_read( @@ -458,13 +470,14 @@ impl LocalProcess { #[async_trait] impl ExecBackend for LocalProcess { async fn start(&self, params: ExecParams) -> Result, ExecServerError> { - let (response, events) = self + let (response, events_tx, events_rx) = self .start_process(params) .await .map_err(map_handler_error)?; Ok(Arc::new(LocalExecProcess { process_id: response.process_id.into(), - events: StdMutex::new(events), + events_tx, + initial_events_rx: StdMutex::new(Some(events_rx)), backend: self.clone(), })) } @@ -477,10 +490,13 @@ impl ExecProcess for LocalExecProcess { } fn subscribe(&self) -> broadcast::Receiver { - self.events + let mut initial_events_rx = self + .initial_events_rx .lock() - .expect("local exec process events mutex should not be poisoned") - .resubscribe() + .unwrap_or_else(std::sync::PoisonError::into_inner); + initial_events_rx + .take() + .unwrap_or_else(|| self.events_tx.subscribe()) } async fn write(&self, chunk: Vec) -> Result<(), ExecServerError> { diff --git a/codex-rs/exec-server/src/remote_process.rs b/codex-rs/exec-server/src/remote_process.rs index 312069fe42..9e27a85350 100644 --- a/codex-rs/exec-server/src/remote_process.rs +++ b/codex-rs/exec-server/src/remote_process.rs @@ -19,7 +19,8 @@ pub(crate) struct RemoteProcess { struct RemoteExecProcess { process_id: ProcessId, - events: StdMutex>, + events_tx: broadcast::Sender, + initial_events_rx: StdMutex>>, backend: RemoteProcess, } @@ -49,7 +50,7 @@ impl RemoteProcess { impl ExecBackend for RemoteProcess { async fn start(&self, params: ExecParams) -> Result, ExecServerError> { let process_id = params.process_id.clone(); - let events = self.client.register_session(&process_id).await?; + let (events_tx, events_rx) = self.client.register_session(&process_id).await?; if let Err(err) = self.client.exec(params).await { self.client.unregister_session(&process_id).await; return Err(err); @@ -57,7 +58,8 @@ impl ExecBackend for RemoteProcess { Ok(Arc::new(RemoteExecProcess { process_id: process_id.into(), - events: StdMutex::new(events), + events_tx, + initial_events_rx: StdMutex::new(Some(events_rx)), backend: self.clone(), })) } @@ -70,10 +72,13 @@ impl ExecProcess for RemoteExecProcess { } fn subscribe(&self) -> broadcast::Receiver { - self.events + let mut initial_events_rx = self + .initial_events_rx .lock() - .expect("remote exec process events mutex should not be poisoned") - .resubscribe() + .unwrap_or_else(std::sync::PoisonError::into_inner); + initial_events_rx + .take() + .unwrap_or_else(|| self.events_tx.subscribe()) } async fn write(&self, chunk: Vec) -> Result<(), ExecServerError> { diff --git a/codex-rs/exec-server/tests/exec_process.rs b/codex-rs/exec-server/tests/exec_process.rs index 4d52a169ae..05e7122943 100644 --- a/codex-rs/exec-server/tests/exec_process.rs +++ b/codex-rs/exec-server/tests/exec_process.rs @@ -6,10 +6,9 @@ use std::sync::Arc; use anyhow::Result; use codex_exec_server::Environment; +use codex_exec_server::ExecBackend; use codex_exec_server::ExecParams; -use codex_exec_server::ExecProcess; use codex_exec_server::ExecSessionEvent; -use codex_exec_server::ExecSessionHandle; use pretty_assertions::assert_eq; use test_case::test_case; use tokio::time::Duration; @@ -19,7 +18,7 @@ use common::exec_server::ExecServerHarness; use common::exec_server::exec_server; struct ProcessContext { - process: Arc, + backend: Arc, _server: Option, } @@ -28,13 +27,13 @@ async fn create_process_context(use_remote: bool) -> Result { let server = exec_server().await?; let environment = Environment::create(Some(server.websocket_url().to_string())).await?; Ok(ProcessContext { - process: environment.get_executor(), + backend: environment.get_exec_backend(), _server: Some(server), }) } else { let environment = Environment::create(None).await?; Ok(ProcessContext { - process: environment.get_executor(), + backend: environment.get_exec_backend(), _server: None, }) } @@ -42,8 +41,8 @@ async fn create_process_context(use_remote: bool) -> Result { async fn assert_exec_process_starts_and_exits(use_remote: bool) -> Result<()> { let context = create_process_context(use_remote).await?; - let mut session = context - .process + let session = context + .backend .start(ExecParams { process_id: "proc-1".to_string(), argv: vec!["true".to_string()], @@ -53,11 +52,12 @@ async fn assert_exec_process_starts_and_exits(use_remote: bool) -> Result<()> { arg0: None, }) .await?; - assert_eq!(session.process_id, "proc-1"); + assert_eq!(session.process_id().as_str(), "proc-1"); + let mut events = session.subscribe(); let mut exit_code = None; loop { - match timeout(Duration::from_secs(2), session.events.recv()).await?? { + match timeout(Duration::from_secs(2), events.recv()).await?? { ExecSessionEvent::Exited { exit_code: code, .. } => exit_code = Some(code), @@ -71,12 +71,13 @@ async fn assert_exec_process_starts_and_exits(use_remote: bool) -> Result<()> { } async fn collect_process_output_from_events( - mut session: ExecSessionHandle, + session: Arc, ) -> Result<(String, i32, bool)> { + let mut events = session.subscribe(); let mut output = String::new(); let mut exit_code = None; loop { - match timeout(Duration::from_secs(2), session.events.recv()).await?? { + match timeout(Duration::from_secs(2), events.recv()).await?? { ExecSessionEvent::Output { chunk, .. } => { output.push_str(&String::from_utf8_lossy(&chunk)); } @@ -95,7 +96,7 @@ async fn assert_exec_process_streams_output(use_remote: bool) -> Result<()> { let context = create_process_context(use_remote).await?; let process_id = "proc-stream".to_string(); let session = context - .process + .backend .start(ExecParams { process_id: process_id.clone(), argv: vec![ @@ -109,7 +110,7 @@ async fn assert_exec_process_streams_output(use_remote: bool) -> Result<()> { arg0: None, }) .await?; - assert_eq!(session.process_id, process_id); + assert_eq!(session.process_id().as_str(), process_id); let (output, exit_code, closed) = collect_process_output_from_events(session).await?; assert_eq!(output, "session output\n"); @@ -122,7 +123,7 @@ async fn assert_exec_process_write_then_read(use_remote: bool) -> Result<()> { let context = create_process_context(use_remote).await?; let process_id = "proc-stdin".to_string(); let session = context - .process + .backend .start(ExecParams { process_id: process_id.clone(), argv: vec![ @@ -136,10 +137,10 @@ async fn assert_exec_process_write_then_read(use_remote: bool) -> Result<()> { arg0: None, }) .await?; - assert_eq!(session.process_id, process_id); + assert_eq!(session.process_id().as_str(), process_id); tokio::time::sleep(Duration::from_millis(200)).await; - session.write_stdin(b"hello\n".to_vec()).await?; + session.write(b"hello\n".to_vec()).await?; let (output, exit_code, closed) = collect_process_output_from_events(session).await?; assert!( @@ -151,6 +152,35 @@ async fn assert_exec_process_write_then_read(use_remote: bool) -> Result<()> { Ok(()) } +async fn assert_exec_process_preserves_queued_events_before_subscribe( + use_remote: bool, +) -> Result<()> { + let context = create_process_context(use_remote).await?; + let session = context + .backend + .start(ExecParams { + process_id: "proc-queued".to_string(), + argv: vec![ + "/bin/sh".to_string(), + "-c".to_string(), + "printf 'queued output\\n'".to_string(), + ], + cwd: std::env::current_dir()?, + env: Default::default(), + tty: false, + arg0: None, + }) + .await?; + + tokio::time::sleep(Duration::from_millis(200)).await; + + let (output, exit_code, closed) = collect_process_output_from_events(session).await?; + assert_eq!(output, "queued output\n"); + assert_eq!(exit_code, 0); + assert!(closed); + Ok(()) +} + #[test_case(false ; "local")] #[test_case(true ; "remote")] #[tokio::test(flavor = "multi_thread", worker_threads = 2)] @@ -171,3 +201,10 @@ async fn exec_process_streams_output(use_remote: bool) -> Result<()> { async fn exec_process_write_then_read(use_remote: bool) -> Result<()> { assert_exec_process_write_then_read(use_remote).await } + +#[test_case(false ; "local")] +#[test_case(true ; "remote")] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn exec_process_preserves_queued_events_before_subscribe(use_remote: bool) -> Result<()> { + assert_exec_process_preserves_queued_events_before_subscribe(use_remote).await +}