diff --git a/codex-rs/core/src/tools/handlers/apply_patch.rs b/codex-rs/core/src/tools/handlers/apply_patch.rs index 1e82b9cf10..2109f1d2c8 100644 --- a/codex-rs/core/src/tools/handlers/apply_patch.rs +++ b/codex-rs/core/src/tools/handlers/apply_patch.rs @@ -42,6 +42,10 @@ impl ToolHandler for ApplyPatchHandler { ) } + fn is_mutating(&self, _invocation: &ToolInvocation) -> bool { + true + } + async fn handle(&self, invocation: ToolInvocation) -> Result { let ToolInvocation { session, diff --git a/codex-rs/core/src/tools/handlers/shell.rs b/codex-rs/core/src/tools/handlers/shell.rs index 81915fc129..3a5115f6ee 100644 --- a/codex-rs/core/src/tools/handlers/shell.rs +++ b/codex-rs/core/src/tools/handlers/shell.rs @@ -10,6 +10,7 @@ use crate::codex::TurnContext; use crate::exec::ExecParams; use crate::exec_env::create_env; use crate::function_tool::FunctionCallError; +use crate::is_safe_command::is_known_safe_command; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolOutput; use crate::tools::context::ToolPayload; @@ -77,6 +78,18 @@ impl ToolHandler for ShellHandler { ) } + fn is_mutating(&self, invocation: &ToolInvocation) -> bool { + match &invocation.payload { + ToolPayload::Function { arguments } => { + serde_json::from_str::(arguments) + .map(|params| !is_known_safe_command(¶ms.command)) + .unwrap_or(true) + } + ToolPayload::LocalShell { params } => !is_known_safe_command(¶ms.command), + _ => true, // unknown payloads => assume mutating + } + } + async fn handle(&self, invocation: ToolInvocation) -> Result { let ToolInvocation { session, diff --git a/codex-rs/core/src/tools/handlers/unified_exec.rs b/codex-rs/core/src/tools/handlers/unified_exec.rs index 6673cf9ba8..fb94f23d0d 100644 --- a/codex-rs/core/src/tools/handlers/unified_exec.rs +++ b/codex-rs/core/src/tools/handlers/unified_exec.rs @@ -1,9 +1,7 @@ use std::path::PathBuf; -use async_trait::async_trait; -use serde::Deserialize; - use crate::function_tool::FunctionCallError; +use crate::is_safe_command::is_known_safe_command; use crate::protocol::EventMsg; use crate::protocol::ExecCommandOutputDeltaEvent; use crate::protocol::ExecOutputStream; @@ -20,6 +18,8 @@ use crate::unified_exec::UnifiedExecContext; use crate::unified_exec::UnifiedExecResponse; use crate::unified_exec::UnifiedExecSessionManager; use crate::unified_exec::WriteStdinRequest; +use async_trait::async_trait; +use serde::Deserialize; pub struct UnifiedExecHandler; @@ -74,6 +74,19 @@ impl ToolHandler for UnifiedExecHandler { ) } + fn is_mutating(&self, invocation: &ToolInvocation) -> bool { + let (ToolPayload::Function { arguments } | ToolPayload::UnifiedExec { arguments }) = + &invocation.payload + else { + return true; + }; + + let Ok(params) = serde_json::from_str::(arguments) else { + return true; + }; + !is_known_safe_command(&["bash".to_string(), "-lc".to_string(), params.cmd]) + } + async fn handle(&self, invocation: ToolInvocation) -> Result { let ToolInvocation { session, diff --git a/codex-rs/core/src/tools/parallel.rs b/codex-rs/core/src/tools/parallel.rs index 5388ab9dac..56a4547526 100644 --- a/codex-rs/core/src/tools/parallel.rs +++ b/codex-rs/core/src/tools/parallel.rs @@ -16,7 +16,6 @@ use crate::tools::router::ToolCall; use crate::tools::router::ToolRouter; use codex_protocol::models::FunctionCallOutputPayload; use codex_protocol::models::ResponseInputItem; -use codex_utils_readiness::Readiness; pub(crate) struct ToolCallRuntime { router: Arc, @@ -55,7 +54,6 @@ impl ToolCallRuntime { let tracker = Arc::clone(&self.tracker); let lock = Arc::clone(&self.parallel_execution); let started = Instant::now(); - let readiness = self.turn_context.tool_call_gate.clone(); let handle: AbortOnDropHandle> = AbortOnDropHandle::new(tokio::spawn(async move { @@ -65,9 +63,6 @@ impl ToolCallRuntime { Ok(Self::aborted_response(&call, secs)) }, res = async { - tracing::trace!("waiting for tool gate"); - readiness.wait_ready().await; - tracing::trace!("tool gate released"); let _guard = if supports_parallel { Either::Left(lock.read().await) } else { diff --git a/codex-rs/core/src/tools/registry.rs b/codex-rs/core/src/tools/registry.rs index 8769259794..f35ff06315 100644 --- a/codex-rs/core/src/tools/registry.rs +++ b/codex-rs/core/src/tools/registry.rs @@ -2,15 +2,15 @@ use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; -use async_trait::async_trait; -use codex_protocol::models::ResponseInputItem; -use tracing::warn; - use crate::client_common::tools::ToolSpec; use crate::function_tool::FunctionCallError; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolOutput; use crate::tools::context::ToolPayload; +use async_trait::async_trait; +use codex_protocol::models::ResponseInputItem; +use codex_utils_readiness::Readiness; +use tracing::warn; #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum ToolKind { @@ -30,6 +30,10 @@ pub trait ToolHandler: Send + Sync { ) } + fn is_mutating(&self, _invocation: &ToolInvocation) -> bool { + false + } + async fn handle(&self, invocation: ToolInvocation) -> Result; } @@ -106,6 +110,11 @@ impl ToolRegistry { let output_cell = &output_cell; let invocation = invocation; async move { + if handler.is_mutating(&invocation) { + tracing::trace!("waiting for tool gate"); + invocation.turn.tool_call_gate.wait_ready().await; + tracing::trace!("tool gate released"); + } match handler.handle(invocation).await { Ok(output) => { let preview = output.log_preview();