diff --git a/codex-rs/code-mode/src/lib.rs b/codex-rs/code-mode/src/lib.rs index 880e84ef4a..c2b5c46ff1 100644 --- a/codex-rs/code-mode/src/lib.rs +++ b/codex-rs/code-mode/src/lib.rs @@ -26,6 +26,7 @@ pub use runtime::WaitRequest; pub use service::CodeModeService; pub use service::CodeModeTurnHost; pub use service::CodeModeTurnWorker; +pub use service::ExecuteUntilDoneRequest; pub const PUBLIC_TOOL_NAME: &str = "exec"; pub const WAIT_TOOL_NAME: &str = "wait"; diff --git a/codex-rs/code-mode/src/service.rs b/codex-rs/code-mode/src/service.rs index 23ca7a7460..2fd7d8ae90 100644 --- a/codex-rs/code-mode/src/service.rs +++ b/codex-rs/code-mode/src/service.rs @@ -23,6 +23,10 @@ use crate::runtime::TurnMessage; use crate::runtime::WaitRequest; use crate::runtime::spawn_runtime; +mod execute_until_done; + +pub use execute_until_done::ExecuteUntilDoneRequest; + #[async_trait] pub trait CodeModeTurnHost: Send + Sync { async fn invoke_tool( @@ -53,6 +57,11 @@ pub struct CodeModeService { inner: Arc, } +struct StartedExecution { + control_tx: mpsc::UnboundedSender, + response_rx: oneshot::Receiver, +} + impl CodeModeService { pub fn new() -> Self { let (turn_message_tx, turn_message_rx) = mpsc::unbounded_channel(); @@ -76,7 +85,8 @@ impl CodeModeService { *self.inner.stored_values.lock().await = values; } - pub async fn execute(&self, request: ExecuteRequest) -> Result { + async fn start_execution(&self, request: ExecuteRequest) -> Result { + let mut sessions = self.inner.sessions.lock().await; let cell_id = self .inner .next_cell_id @@ -87,18 +97,19 @@ impl CodeModeService { let (control_tx, control_rx) = mpsc::unbounded_channel(); let (response_tx, response_rx) = oneshot::channel(); - self.inner.sessions.lock().await.insert( + sessions.insert( cell_id.clone(), SessionHandle { control_tx: control_tx.clone(), runtime_tx: runtime_tx.clone(), }, ); + drop(sessions); tokio::spawn(run_session_control( Arc::clone(&self.inner), SessionControlContext { - cell_id: cell_id.clone(), + cell_id, runtime_tx, runtime_terminate_handle, }, @@ -108,7 +119,16 @@ impl CodeModeService { request.yield_time_ms.unwrap_or(DEFAULT_EXEC_YIELD_TIME_MS), )); - response_rx + Ok(StartedExecution { + control_tx, + response_rx, + }) + } + + pub async fn execute(&self, request: ExecuteRequest) -> Result { + self.start_execution(request) + .await? + .response_rx .await .map_err(|_| "exec runtime ended unexpectedly".to_string()) } @@ -477,6 +497,7 @@ mod tests { use tokio::sync::oneshot; use super::CodeModeService; + use super::ExecuteUntilDoneRequest; use super::Inner; use super::RuntimeCommand; use super::RuntimeResponse; @@ -832,6 +853,149 @@ image({ ); } + #[tokio::test] + async fn execute_until_done_returns_quick_result() { + let service = CodeModeService::new(); + + let response = service + .execute_until_done(ExecuteUntilDoneRequest { + execute: ExecuteRequest { + source: r#"text("done");"#.to_string(), + yield_time_ms: Some(60_000), + ..execute_request("") + }, + poll_yield_time_ms: 1, + terminate_on_drop: true, + }) + .await + .unwrap(); + + assert_eq!( + response, + RuntimeResponse::Result { + cell_id: "1".to_string(), + content_items: vec![FunctionCallOutputContentItem::InputText { + text: "done".to_string(), + }], + stored_values: HashMap::new(), + error_text: None, + } + ); + } + + #[tokio::test] + async fn execute_until_done_drains_timer_yields() { + let service = CodeModeService::new(); + + let response = service + .execute_until_done(ExecuteUntilDoneRequest { + execute: ExecuteRequest { + source: r#" +await new Promise(resolve => setTimeout(resolve, 20)); +text("done"); +"# + .to_string(), + yield_time_ms: Some(1), + ..execute_request("") + }, + poll_yield_time_ms: 1_000, + terminate_on_drop: true, + }) + .await + .unwrap(); + + assert_eq!( + response, + RuntimeResponse::Result { + cell_id: "1".to_string(), + content_items: vec![FunctionCallOutputContentItem::InputText { + text: "done".to_string(), + }], + stored_values: HashMap::new(), + error_text: None, + } + ); + } + + #[tokio::test] + async fn execute_until_done_drains_explicit_yield_control() { + let service = CodeModeService::new(); + + let response = service + .execute_until_done(ExecuteUntilDoneRequest { + execute: ExecuteRequest { + source: r#"text("before"); yield_control(); text("after");"#.to_string(), + yield_time_ms: Some(60_000), + ..execute_request("") + }, + poll_yield_time_ms: 60_000, + terminate_on_drop: true, + }) + .await + .unwrap(); + + assert_eq!( + response, + RuntimeResponse::Result { + cell_id: "1".to_string(), + content_items: vec![ + FunctionCallOutputContentItem::InputText { + text: "before".to_string(), + }, + FunctionCallOutputContentItem::InputText { + text: "after".to_string(), + }, + ], + stored_values: HashMap::new(), + error_text: None, + } + ); + } + + #[tokio::test] + async fn execute_until_done_terminates_running_cell_when_cancelled() { + let service = Arc::new(CodeModeService::new()); + let run_service = Arc::clone(&service); + let handle = tokio::spawn(async move { + run_service + .execute_until_done(ExecuteUntilDoneRequest { + execute: ExecuteRequest { + source: "await new Promise(() => {})".to_string(), + yield_time_ms: Some(60_000), + ..execute_request("") + }, + poll_yield_time_ms: 60_000, + terminate_on_drop: true, + }) + .await + }); + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if service.inner.sessions.lock().await.len() == 1 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .unwrap(); + + handle.abort(); + let _ = handle.await; + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if service.inner.sessions.lock().await.is_empty() { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .unwrap(); + } + #[tokio::test] async fn terminate_waits_for_runtime_shutdown_before_responding() { let inner = test_inner(); diff --git a/codex-rs/code-mode/src/service/execute_until_done.rs b/codex-rs/code-mode/src/service/execute_until_done.rs new file mode 100644 index 0000000000..69e6f4c6d3 --- /dev/null +++ b/codex-rs/code-mode/src/service/execute_until_done.rs @@ -0,0 +1,106 @@ +use tokio::sync::mpsc; +use tokio::sync::oneshot; + +use crate::runtime::ExecuteRequest; +use crate::runtime::RuntimeResponse; +use crate::runtime::WaitRequest; + +use super::CodeModeService; +use super::SessionControlCommand; +use super::StartedExecution; + +#[derive(Clone, Debug)] +pub struct ExecuteUntilDoneRequest { + pub execute: ExecuteRequest, + pub poll_yield_time_ms: u64, + pub terminate_on_drop: bool, +} + +impl CodeModeService { + pub async fn execute_until_done( + &self, + request: ExecuteUntilDoneRequest, + ) -> Result { + let StartedExecution { + control_tx, + response_rx, + } = self.start_execution(request.execute).await?; + let mut drop_guard = RunningCellDropGuard { + control_tx, + is_armed: true, + terminate_on_drop: request.terminate_on_drop, + }; + let mut accumulated_content_items = Vec::new(); + let mut response = response_rx + .await + .map_err(|_| "exec runtime ended unexpectedly".to_string())?; + + loop { + match response { + RuntimeResponse::Yielded { + cell_id, + content_items, + } => { + accumulated_content_items.extend(content_items); + response = self + .wait(WaitRequest { + cell_id, + yield_time_ms: request.poll_yield_time_ms, + terminate: false, + }) + .await?; + } + RuntimeResponse::Terminated { + cell_id, + content_items, + } => { + accumulated_content_items.extend(content_items); + drop_guard.disarm(); + return Ok(RuntimeResponse::Terminated { + cell_id, + content_items: accumulated_content_items, + }); + } + RuntimeResponse::Result { + cell_id, + content_items, + stored_values, + error_text, + } => { + accumulated_content_items.extend(content_items); + drop_guard.disarm(); + return Ok(RuntimeResponse::Result { + cell_id, + content_items: accumulated_content_items, + stored_values, + error_text, + }); + } + } + } + } +} + +struct RunningCellDropGuard { + control_tx: mpsc::UnboundedSender, + is_armed: bool, + terminate_on_drop: bool, +} + +impl RunningCellDropGuard { + fn disarm(&mut self) { + self.is_armed = false; + } +} + +impl Drop for RunningCellDropGuard { + fn drop(&mut self) { + if !self.terminate_on_drop || !self.is_armed { + return; + } + let (response_tx, _response_rx) = oneshot::channel(); + let _ = self + .control_tx + .send(SessionControlCommand::Terminate { response_tx }); + } +}