mirror of
https://github.com/openai/codex.git
synced 2026-05-05 22:01:37 +03:00
438 lines
15 KiB
Rust
438 lines
15 KiB
Rust
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
use std::time::Instant;
|
|
|
|
use crate::client_common::tools::ToolSpec;
|
|
use crate::features::Feature;
|
|
use crate::function_tool::FunctionCallError;
|
|
use crate::memories::usage::emit_metric_for_tool_read;
|
|
use crate::protocol::SandboxPolicy;
|
|
use crate::sandbox_tags::sandbox_tag;
|
|
use crate::tools::context::ToolInvocation;
|
|
use crate::tools::context::ToolOutput;
|
|
use crate::tools::context::ToolPayload;
|
|
use async_trait::async_trait;
|
|
use codex_hooks::HookEvent;
|
|
use codex_hooks::HookEventAfterToolUse;
|
|
use codex_hooks::HookPayload;
|
|
use codex_hooks::HookResult;
|
|
use codex_hooks::HookToolInput;
|
|
use codex_hooks::HookToolInputLocalShell;
|
|
use codex_hooks::HookToolKind;
|
|
use codex_protocol::models::ResponseInputItem;
|
|
use codex_utils_readiness::Readiness;
|
|
use tracing::warn;
|
|
|
|
/// Category of tool a [`ToolHandler`] serves.
///
/// `ToolHandler::matches_kind` compares this against the variant of the
/// incoming `ToolPayload` to reject mismatched invocations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolKind {
    /// Handler that accepts `ToolPayload::Function` payloads.
    Function,
    /// Handler that accepts `ToolPayload::Mcp` payloads.
    Mcp,
}
|
|
|
|
#[async_trait]
|
|
pub trait ToolHandler: Send + Sync {
|
|
fn kind(&self) -> ToolKind;
|
|
|
|
fn matches_kind(&self, payload: &ToolPayload) -> bool {
|
|
matches!(
|
|
(self.kind(), payload),
|
|
(ToolKind::Function, ToolPayload::Function { .. })
|
|
| (ToolKind::Mcp, ToolPayload::Mcp { .. })
|
|
)
|
|
}
|
|
|
|
/// Returns `true` if the [ToolInvocation] *might* mutate the environment of the
|
|
/// user (through file system, OS operations, ...).
|
|
/// This function must remains defensive and return `true` if a doubt exist on the
|
|
/// exact effect of a ToolInvocation.
|
|
async fn is_mutating(&self, _invocation: &ToolInvocation) -> bool {
|
|
false
|
|
}
|
|
|
|
/// Perform the actual [ToolInvocation] and returns a [ToolOutput] containing
|
|
/// the final output to return to the model.
|
|
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError>;
|
|
}
|
|
|
|
pub struct ToolRegistry {
|
|
handlers: HashMap<String, Arc<dyn ToolHandler>>,
|
|
}
|
|
|
|
impl ToolRegistry {
|
|
pub fn new(handlers: HashMap<String, Arc<dyn ToolHandler>>) -> Self {
|
|
Self { handlers }
|
|
}
|
|
|
|
pub fn handler(&self, name: &str) -> Option<Arc<dyn ToolHandler>> {
|
|
self.handlers.get(name).map(Arc::clone)
|
|
}
|
|
|
|
// TODO(jif) for dynamic tools.
|
|
// pub fn register(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {
|
|
// let name = name.into();
|
|
// if self.handlers.insert(name.clone(), handler).is_some() {
|
|
// warn!("overwriting handler for tool {name}");
|
|
// }
|
|
// }
|
|
|
|
pub async fn dispatch(
|
|
&self,
|
|
invocation: ToolInvocation,
|
|
) -> Result<ResponseInputItem, FunctionCallError> {
|
|
let tool_name = invocation.tool_name.clone();
|
|
let call_id_owned = invocation.call_id.clone();
|
|
let otel = invocation.turn.otel_manager.clone();
|
|
let payload_for_response = invocation.payload.clone();
|
|
let log_payload = payload_for_response.log_payload();
|
|
let metric_tags = [
|
|
(
|
|
"sandbox",
|
|
sandbox_tag(
|
|
&invocation.turn.sandbox_policy,
|
|
invocation.turn.windows_sandbox_level,
|
|
invocation
|
|
.turn
|
|
.features
|
|
.enabled(Feature::UseLinuxSandboxBwrap),
|
|
),
|
|
),
|
|
(
|
|
"sandbox_policy",
|
|
sandbox_policy_tag(&invocation.turn.sandbox_policy),
|
|
),
|
|
];
|
|
let (mcp_server, mcp_server_origin) = match &invocation.payload {
|
|
ToolPayload::Mcp { server, .. } => {
|
|
let manager = invocation
|
|
.session
|
|
.services
|
|
.mcp_connection_manager
|
|
.read()
|
|
.await;
|
|
let origin = manager.server_origin(server).map(str::to_owned);
|
|
(Some(server.clone()), origin)
|
|
}
|
|
_ => (None, None),
|
|
};
|
|
let mcp_server_ref = mcp_server.as_deref();
|
|
let mcp_server_origin_ref = mcp_server_origin.as_deref();
|
|
|
|
let handler = match self.handler(tool_name.as_ref()) {
|
|
Some(handler) => handler,
|
|
None => {
|
|
let message =
|
|
unsupported_tool_call_message(&invocation.payload, tool_name.as_ref());
|
|
otel.tool_result_with_tags(
|
|
tool_name.as_ref(),
|
|
&call_id_owned,
|
|
log_payload.as_ref(),
|
|
Duration::ZERO,
|
|
false,
|
|
&message,
|
|
&metric_tags,
|
|
mcp_server_ref,
|
|
mcp_server_origin_ref,
|
|
);
|
|
return Err(FunctionCallError::RespondToModel(message));
|
|
}
|
|
};
|
|
|
|
if !handler.matches_kind(&invocation.payload) {
|
|
let message = format!("tool {tool_name} invoked with incompatible payload");
|
|
otel.tool_result_with_tags(
|
|
tool_name.as_ref(),
|
|
&call_id_owned,
|
|
log_payload.as_ref(),
|
|
Duration::ZERO,
|
|
false,
|
|
&message,
|
|
&metric_tags,
|
|
mcp_server_ref,
|
|
mcp_server_origin_ref,
|
|
);
|
|
return Err(FunctionCallError::Fatal(message));
|
|
}
|
|
|
|
let is_mutating = handler.is_mutating(&invocation).await;
|
|
let output_cell = tokio::sync::Mutex::new(None);
|
|
let invocation_for_tool = invocation.clone();
|
|
|
|
let started = Instant::now();
|
|
let result = otel
|
|
.log_tool_result_with_tags(
|
|
tool_name.as_ref(),
|
|
&call_id_owned,
|
|
log_payload.as_ref(),
|
|
&metric_tags,
|
|
mcp_server_ref,
|
|
mcp_server_origin_ref,
|
|
|| {
|
|
let handler = handler.clone();
|
|
let output_cell = &output_cell;
|
|
async move {
|
|
if is_mutating {
|
|
tracing::trace!("waiting for tool gate");
|
|
invocation_for_tool.turn.tool_call_gate.wait_ready().await;
|
|
tracing::trace!("tool gate released");
|
|
}
|
|
match handler.handle(invocation_for_tool).await {
|
|
Ok(output) => {
|
|
let preview = output.log_preview();
|
|
let success = output.success_for_logging();
|
|
let mut guard = output_cell.lock().await;
|
|
*guard = Some(output);
|
|
Ok((preview, success))
|
|
}
|
|
Err(err) => Err(err),
|
|
}
|
|
}
|
|
},
|
|
)
|
|
.await;
|
|
let duration = started.elapsed();
|
|
let (output_preview, success) = match &result {
|
|
Ok((preview, success)) => (preview.clone(), *success),
|
|
Err(err) => (err.to_string(), false),
|
|
};
|
|
emit_metric_for_tool_read(&invocation, success).await;
|
|
let hook_abort_error = dispatch_after_tool_use_hook(AfterToolUseHookDispatch {
|
|
invocation: &invocation,
|
|
output_preview,
|
|
success,
|
|
executed: true,
|
|
duration,
|
|
mutating: is_mutating,
|
|
})
|
|
.await;
|
|
|
|
if let Some(err) = hook_abort_error {
|
|
return Err(err);
|
|
}
|
|
|
|
match result {
|
|
Ok(_) => {
|
|
let mut guard = output_cell.lock().await;
|
|
let output = guard.take().ok_or_else(|| {
|
|
FunctionCallError::Fatal("tool produced no output".to_string())
|
|
})?;
|
|
Ok(output.into_response(&call_id_owned, &payload_for_response))
|
|
}
|
|
Err(err) => Err(err),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct ConfiguredToolSpec {
|
|
pub spec: ToolSpec,
|
|
pub supports_parallel_tool_calls: bool,
|
|
}
|
|
|
|
impl ConfiguredToolSpec {
|
|
pub fn new(spec: ToolSpec, supports_parallel_tool_calls: bool) -> Self {
|
|
Self {
|
|
spec,
|
|
supports_parallel_tool_calls,
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct ToolRegistryBuilder {
|
|
handlers: HashMap<String, Arc<dyn ToolHandler>>,
|
|
specs: Vec<ConfiguredToolSpec>,
|
|
}
|
|
|
|
impl ToolRegistryBuilder {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
handlers: HashMap::new(),
|
|
specs: Vec::new(),
|
|
}
|
|
}
|
|
|
|
pub fn push_spec(&mut self, spec: ToolSpec) {
|
|
self.push_spec_with_parallel_support(spec, false);
|
|
}
|
|
|
|
pub fn push_spec_with_parallel_support(
|
|
&mut self,
|
|
spec: ToolSpec,
|
|
supports_parallel_tool_calls: bool,
|
|
) {
|
|
self.specs
|
|
.push(ConfiguredToolSpec::new(spec, supports_parallel_tool_calls));
|
|
}
|
|
|
|
pub fn register_handler(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {
|
|
let name = name.into();
|
|
if self
|
|
.handlers
|
|
.insert(name.clone(), handler.clone())
|
|
.is_some()
|
|
{
|
|
warn!("overwriting handler for tool {name}");
|
|
}
|
|
}
|
|
|
|
// TODO(jif) for dynamic tools.
|
|
// pub fn register_many<I>(&mut self, names: I, handler: Arc<dyn ToolHandler>)
|
|
// where
|
|
// I: IntoIterator,
|
|
// I::Item: Into<String>,
|
|
// {
|
|
// for name in names {
|
|
// let name = name.into();
|
|
// if self
|
|
// .handlers
|
|
// .insert(name.clone(), handler.clone())
|
|
// .is_some()
|
|
// {
|
|
// warn!("overwriting handler for tool {name}");
|
|
// }
|
|
// }
|
|
// }
|
|
|
|
pub fn build(self) -> (Vec<ConfiguredToolSpec>, ToolRegistry) {
|
|
let registry = ToolRegistry::new(self.handlers);
|
|
(self.specs, registry)
|
|
}
|
|
}
|
|
|
|
fn unsupported_tool_call_message(payload: &ToolPayload, tool_name: &str) -> String {
|
|
match payload {
|
|
ToolPayload::Custom { .. } => format!("unsupported custom tool call: {tool_name}"),
|
|
_ => format!("unsupported call: {tool_name}"),
|
|
}
|
|
}
|
|
|
|
fn sandbox_policy_tag(policy: &SandboxPolicy) -> &'static str {
|
|
match policy {
|
|
SandboxPolicy::ReadOnly { .. } => "read-only",
|
|
SandboxPolicy::WorkspaceWrite { .. } => "workspace-write",
|
|
SandboxPolicy::Custom { .. } => "custom",
|
|
SandboxPolicy::DangerFullAccess => "danger-full-access",
|
|
SandboxPolicy::ExternalSandbox { .. } => "external-sandbox",
|
|
}
|
|
}
|
|
|
|
// Hooks use a separate wire-facing input type so hook payload JSON stays stable
|
|
// and decoupled from core's internal tool runtime representation.
|
|
impl From<&ToolPayload> for HookToolInput {
|
|
fn from(payload: &ToolPayload) -> Self {
|
|
match payload {
|
|
ToolPayload::Function { arguments } => HookToolInput::Function {
|
|
arguments: arguments.clone(),
|
|
},
|
|
ToolPayload::Custom { input } => HookToolInput::Custom {
|
|
input: input.clone(),
|
|
},
|
|
ToolPayload::LocalShell { params } => HookToolInput::LocalShell {
|
|
params: HookToolInputLocalShell {
|
|
command: params.command.clone(),
|
|
workdir: params.workdir.clone(),
|
|
timeout_ms: params.timeout_ms,
|
|
sandbox_permissions: params.sandbox_permissions,
|
|
prefix_rule: params.prefix_rule.clone(),
|
|
justification: params.justification.clone(),
|
|
},
|
|
},
|
|
ToolPayload::Mcp {
|
|
server,
|
|
tool,
|
|
raw_arguments,
|
|
} => HookToolInput::Mcp {
|
|
server: server.clone(),
|
|
tool: tool.clone(),
|
|
arguments: raw_arguments.clone(),
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
fn hook_tool_kind(tool_input: &HookToolInput) -> HookToolKind {
|
|
match tool_input {
|
|
HookToolInput::Function { .. } => HookToolKind::Function,
|
|
HookToolInput::Custom { .. } => HookToolKind::Custom,
|
|
HookToolInput::LocalShell { .. } => HookToolKind::LocalShell,
|
|
HookToolInput::Mcp { .. } => HookToolKind::Mcp,
|
|
}
|
|
}
|
|
|
|
struct AfterToolUseHookDispatch<'a> {
|
|
invocation: &'a ToolInvocation,
|
|
output_preview: String,
|
|
success: bool,
|
|
executed: bool,
|
|
duration: Duration,
|
|
mutating: bool,
|
|
}
|
|
|
|
async fn dispatch_after_tool_use_hook(
|
|
dispatch: AfterToolUseHookDispatch<'_>,
|
|
) -> Option<FunctionCallError> {
|
|
let AfterToolUseHookDispatch { invocation, .. } = dispatch;
|
|
let session = invocation.session.as_ref();
|
|
let turn = invocation.turn.as_ref();
|
|
let tool_input = HookToolInput::from(&invocation.payload);
|
|
let hook_outcomes = session
|
|
.hooks()
|
|
.dispatch(HookPayload {
|
|
session_id: session.conversation_id,
|
|
cwd: turn.cwd.clone(),
|
|
triggered_at: chrono::Utc::now(),
|
|
hook_event: HookEvent::AfterToolUse {
|
|
event: HookEventAfterToolUse {
|
|
turn_id: turn.sub_id.clone(),
|
|
call_id: invocation.call_id.clone(),
|
|
tool_name: invocation.tool_name.clone(),
|
|
tool_kind: hook_tool_kind(&tool_input),
|
|
tool_input,
|
|
executed: dispatch.executed,
|
|
success: dispatch.success,
|
|
duration_ms: u64::try_from(dispatch.duration.as_millis()).unwrap_or(u64::MAX),
|
|
mutating: dispatch.mutating,
|
|
sandbox: sandbox_tag(
|
|
&turn.sandbox_policy,
|
|
turn.windows_sandbox_level,
|
|
turn.features.enabled(Feature::UseLinuxSandboxBwrap),
|
|
)
|
|
.to_string(),
|
|
sandbox_policy: sandbox_policy_tag(&turn.sandbox_policy).to_string(),
|
|
output_preview: dispatch.output_preview.clone(),
|
|
},
|
|
},
|
|
})
|
|
.await;
|
|
|
|
for hook_outcome in hook_outcomes {
|
|
let hook_name = hook_outcome.hook_name;
|
|
match hook_outcome.result {
|
|
HookResult::Success => {}
|
|
HookResult::FailedContinue(error) => {
|
|
warn!(
|
|
call_id = %invocation.call_id,
|
|
tool_name = %invocation.tool_name,
|
|
hook_name = %hook_name,
|
|
error = %error,
|
|
"after_tool_use hook failed; continuing"
|
|
);
|
|
}
|
|
HookResult::FailedAbort(error) => {
|
|
warn!(
|
|
call_id = %invocation.call_id,
|
|
tool_name = %invocation.tool_name,
|
|
hook_name = %hook_name,
|
|
error = %error,
|
|
"after_tool_use hook failed; aborting operation"
|
|
);
|
|
return Some(FunctionCallError::Fatal(format!(
|
|
"after_tool_use hook '{hook_name}' failed and aborted operation: {error}"
|
|
)));
|
|
}
|
|
}
|
|
}
|
|
|
|
None
|
|
}
|