diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 96f2c5d891..76af2d5f1a 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -584,6 +584,7 @@ dependencies = [ "codex-file-search", "codex-protocol", "core_test_support", + "futures", "libc", "mcp-types", "portable-pty", diff --git a/codex-rs/agent/Cargo.toml b/codex-rs/agent/Cargo.toml index 0207540cce..462973d5c7 100644 --- a/codex-rs/agent/Cargo.toml +++ b/codex-rs/agent/Cargo.toml @@ -27,6 +27,7 @@ time = { workspace = true, features = ["formatting", "parsing", "local-offset", tracing = { workspace = true } tree-sitter = { workspace = true } tree-sitter-bash = { workspace = true } +futures = { workspace = true } [dev-dependencies] core_test_support = { workspace = true } diff --git a/codex-rs/core/review_prompt.md b/codex-rs/agent/review_prompt.md similarity index 100% rename from codex-rs/core/review_prompt.md rename to codex-rs/agent/review_prompt.md diff --git a/codex-rs/agent/src/client_common.rs b/codex-rs/agent/src/client_common.rs new file mode 100644 index 0000000000..5e6bf0eca5 --- /dev/null +++ b/codex-rs/agent/src/client_common.rs @@ -0,0 +1,363 @@ +use crate::model_family::ModelFamily; +use crate::tool_schema::OpenAiTool; +use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS; +use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig; +use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig; +use codex_protocol::config_types::Verbosity as VerbosityConfig; +use codex_protocol::models::ResponseItem; +use codex_protocol::protocol::RateLimitSnapshot; +use codex_protocol::protocol::TokenUsage; +use futures::Stream; +use serde::Serialize; +use serde_json::Value; +use std::borrow::Cow; +use std::ops::Deref; +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; +use tokio::sync::mpsc; + +/// Review thread system prompt. Edit `agent/review_prompt.md` to customize. +pub const REVIEW_PROMPT: &str = include_str!("../review_prompt.md"); + +/// API request payload for a single model turn +#[derive(Default, Debug, Clone)] +pub struct Prompt { + /// Conversation context input items. + pub input: Vec, + + /// Tools available to the model, including additional tools sourced from + /// external MCP servers. + pub tools: Vec, + + /// Optional override for the built-in BASE_INSTRUCTIONS. + pub base_instructions_override: Option, + + /// Optional the output schema for the model's response. + pub output_schema: Option, +} + +impl Prompt { + pub fn get_full_instructions<'a>(&'a self, model: &'a ModelFamily) -> Cow<'a, str> { + let base = self + .base_instructions_override + .as_deref() + .unwrap_or(model.base_instructions.deref()); + // When there are no custom instructions, add apply_patch_tool_instructions if: + // - the model needs special instructions (4.1) + // AND + // - there is no apply_patch tool present + let is_apply_patch_tool_present = self.tools.iter().any(|tool| match tool { + OpenAiTool::Function(f) => f.name == "apply_patch", + OpenAiTool::Freeform(f) => f.name == "apply_patch", + _ => false, + }); + if self.base_instructions_override.is_none() + && model.needs_special_apply_patch_instructions + && !is_apply_patch_tool_present + { + Cow::Owned(format!("{base}\n{APPLY_PATCH_TOOL_INSTRUCTIONS}")) + } else { + Cow::Borrowed(base) + } + } + + pub fn get_formatted_input(&self) -> Vec { + self.input.clone() + } +} + +#[derive(Debug)] +pub enum ResponseEvent { + Created, + OutputItemDone(ResponseItem), + Completed { + response_id: String, + token_usage: Option, + }, + OutputTextDelta(String), + ReasoningSummaryDelta(String), + ReasoningContentDelta(String), + ReasoningSummaryPartAdded, + WebSearchCallBegin { + call_id: String, + }, + RateLimits(RateLimitSnapshot), +} + +#[derive(Debug, Serialize)] +pub struct Reasoning { + #[serde(skip_serializing_if = "Option::is_none")] + pub effort: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option, +} + +#[derive(Debug, Serialize, Default, Clone)] +#[serde(rename_all = "snake_case")] +pub enum TextFormatType { + #[default] + JsonSchema, +} + +#[derive(Debug, Serialize, Default, Clone)] +pub struct TextFormat { + pub r#type: TextFormatType, + pub strict: bool, + pub schema: Value, + pub name: String, +} + +/// Controls under the `text` field in the Responses API for GPT-5. +#[derive(Debug, Serialize, Default, Clone)] +pub struct TextControls { + #[serde(skip_serializing_if = "Option::is_none")] + pub verbosity: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub format: Option, +} + +#[derive(Debug, Serialize, Default, Clone)] +#[serde(rename_all = "lowercase")] +pub enum OpenAiVerbosity { + Low, + #[default] + Medium, + High, +} + +impl From for OpenAiVerbosity { + fn from(v: VerbosityConfig) -> Self { + match v { + VerbosityConfig::Low => OpenAiVerbosity::Low, + VerbosityConfig::Medium => OpenAiVerbosity::Medium, + VerbosityConfig::High => OpenAiVerbosity::High, + } + } +} + +/// Request object that is serialized as JSON and POST'ed when using the +/// Responses API. +#[derive(Debug, Serialize)] +pub struct ResponsesApiRequest<'a> { + pub model: &'a str, + pub instructions: &'a str, + // TODO(mbolin): ResponseItem::Other should not be serialized. Currently, + // we code defensively to avoid this case, but perhaps we should use a + // separate enum for serialization. + pub input: &'a Vec, + pub tools: &'a [serde_json::Value], + pub tool_choice: &'static str, + pub parallel_tool_calls: bool, + pub reasoning: Option, + pub store: bool, + pub stream: bool, + pub include: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_cache_key: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, +} + +pub fn create_reasoning_param_for_request( + model_family: &ModelFamily, + effort: Option, + summary: ReasoningSummaryConfig, +) -> Option { + if !model_family.supports_reasoning_summaries { + return None; + } + + Some(Reasoning { + effort, + summary: Some(summary), + }) +} + +pub fn create_text_param_for_request( + verbosity: Option, + output_schema: &Option, +) -> Option { + if verbosity.is_none() && output_schema.is_none() { + return None; + } + + Some(TextControls { + verbosity: verbosity.map(std::convert::Into::into), + format: output_schema.as_ref().map(|schema| TextFormat { + r#type: TextFormatType::JsonSchema, + strict: true, + schema: schema.clone(), + name: "codex_output_schema".to_string(), + }), + }) +} + +pub struct ResponseStream { + pub rx_event: mpsc::Receiver>, +} + +impl Stream for ResponseStream { + type Item = std::result::Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + self.rx_event.poll_recv(cx) + } +} + +#[cfg(test)] +mod tests { + use crate::config_types::ReasoningSummaryFormat; + use crate::tooling::ApplyPatchToolType; + use pretty_assertions::assert_eq; + + use super::*; + + struct InstructionsTestCase { + pub slug: &'static str, + pub expects_apply_patch_instructions: bool, + } + #[test] + fn get_full_instructions_no_user_content() { + let prompt = Prompt::default(); + let base_instructions = "Base instructions".to_string(); + let test_cases = vec![ + InstructionsTestCase { + slug: "needs-apply-patch", + expects_apply_patch_instructions: true, + }, + InstructionsTestCase { + slug: "no-apply-patch", + expects_apply_patch_instructions: false, + }, + ]; + + for test_case in test_cases { + let model_family = ModelFamily { + slug: test_case.slug.to_string(), + family: "test".to_string(), + needs_special_apply_patch_instructions: test_case.expects_apply_patch_instructions, + supports_reasoning_summaries: false, + reasoning_summary_format: ReasoningSummaryFormat::None, + uses_local_shell_tool: false, + apply_patch_tool_type: Some(ApplyPatchToolType::Function), + base_instructions: base_instructions.clone(), + }; + + let expected = if test_case.expects_apply_patch_instructions { + format!( + "{}\n{}", + model_family.base_instructions, APPLY_PATCH_TOOL_INSTRUCTIONS + ) + } else { + model_family.base_instructions.clone() + }; + + let full = prompt.get_full_instructions(&model_family); + assert_eq!(full, expected); + } + } + + #[test] + fn serializes_text_verbosity_when_set() { + let input: Vec = vec![]; + let tools: Vec = vec![]; + let req = ResponsesApiRequest { + model: "gpt-5", + instructions: "i", + input: &input, + tools: &tools, + tool_choice: "auto", + parallel_tool_calls: false, + reasoning: None, + store: false, + stream: true, + include: vec![], + prompt_cache_key: None, + text: Some(TextControls { + verbosity: Some(OpenAiVerbosity::Low), + format: None, + }), + }; + + let v = serde_json::to_value(&req).expect("json"); + assert_eq!( + v.get("text") + .and_then(|t| t.get("verbosity")) + .and_then(|s| s.as_str()), + Some("low") + ); + } + + #[test] + fn serializes_text_schema_with_strict_format() { + let input: Vec = vec![]; + let tools: Vec = vec![]; + let schema = serde_json::json!({ + "type": "object", + "properties": { + "answer": {"type": "string"} + }, + "required": ["answer"], + }); + let text_controls = + create_text_param_for_request(None, &Some(schema.clone())).expect("text controls"); + + let req = ResponsesApiRequest { + model: "gpt-5", + instructions: "i", + input: &input, + tools: &tools, + tool_choice: "auto", + parallel_tool_calls: false, + reasoning: None, + store: false, + stream: true, + include: vec![], + prompt_cache_key: None, + text: Some(text_controls), + }; + + let v = serde_json::to_value(&req).expect("json"); + let text = v.get("text").expect("text field"); + assert!(text.get("verbosity").is_none()); + let format = text.get("format").expect("format field"); + + assert_eq!( + format.get("name"), + Some(&serde_json::Value::String("codex_output_schema".into())) + ); + assert_eq!( + format.get("type"), + Some(&serde_json::Value::String("json_schema".into())) + ); + assert_eq!(format.get("strict"), Some(&serde_json::Value::Bool(true))); + assert_eq!(format.get("schema"), Some(&schema)); + } + + #[test] + fn omits_text_when_not_set() { + let input: Vec = vec![]; + let tools: Vec = vec![]; + let req = ResponsesApiRequest { + model: "gpt-5", + instructions: "i", + input: &input, + tools: &tools, + tool_choice: "auto", + parallel_tool_calls: false, + reasoning: None, + store: false, + stream: true, + include: vec![], + prompt_cache_key: None, + text: None, + }; + + let v = serde_json::to_value(&req).expect("json"); + assert!(v.get("text").is_none()); + } +} diff --git a/codex-rs/agent/src/exec.rs b/codex-rs/agent/src/exec.rs new file mode 100644 index 0000000000..ba624b635f --- /dev/null +++ b/codex-rs/agent/src/exec.rs @@ -0,0 +1,21 @@ +use std::collections::HashMap; +use std::path::PathBuf; +use std::time::Duration; + +const DEFAULT_TIMEOUT_MS: u64 = 10_000; + +#[derive(Clone, Debug)] +pub struct ExecParams { + pub command: Vec, + pub cwd: PathBuf, + pub timeout_ms: Option, + pub env: HashMap, + pub with_escalated_permissions: Option, + pub justification: Option, +} + +impl ExecParams { + pub fn timeout_duration(&self) -> Duration { + Duration::from_millis(self.timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS)) + } +} diff --git a/codex-rs/agent/src/lib.rs b/codex-rs/agent/src/lib.rs index 98b3216fb2..6eab5ea956 100644 --- a/codex-rs/agent/src/lib.rs +++ b/codex-rs/agent/src/lib.rs @@ -1,10 +1,13 @@ pub mod apply_patch; pub mod bash; +pub mod client_common; pub mod command_safety; pub mod config_types; pub mod conversation_history; +pub mod exec; pub mod exec_command; pub mod function_tool; +pub mod model_client; pub mod model_family; pub mod model_provider; pub mod notifications; @@ -18,17 +21,22 @@ pub mod session_services; pub mod session_state; pub mod shell; pub mod token_data; +pub mod tool_schema; pub mod tooling; +pub mod tools_config; pub mod truncate; pub mod turn_diff_tracker; pub mod unified_exec; pub use apply_patch::*; pub use bash::*; +pub use client_common::*; pub use command_safety::*; pub use config_types::*; pub use conversation_history::*; +pub use exec::*; pub use function_tool::*; +pub use model_client::*; pub use model_family::*; pub use model_provider::*; pub use notifications::*; @@ -42,7 +50,9 @@ pub use session_services::*; pub use session_state::*; pub use shell::*; pub use token_data::*; +pub use tool_schema::*; pub use tooling::*; +pub use tools_config::*; pub use truncate::*; pub use turn_diff_tracker::*; pub use unified_exec::*; diff --git a/codex-rs/agent/src/model_client.rs b/codex-rs/agent/src/model_client.rs new file mode 100644 index 0000000000..da307ebd60 --- /dev/null +++ b/codex-rs/agent/src/model_client.rs @@ -0,0 +1,34 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig; +use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig; + +use crate::client_common::Prompt; +use crate::client_common::ResponseStream; +use crate::model_family::ModelFamily; +use crate::model_provider::ModelProviderInfo; +use crate::services::CredentialsProvider; + +#[async_trait] +pub trait ModelClientAdapter: Send + Sync { + type Error: std::error::Error + Send + Sync + 'static; + + fn get_model_context_window(&self) -> Option; + + fn get_auto_compact_token_limit(&self) -> Option; + + fn get_provider(&self) -> ModelProviderInfo; + + fn get_model(&self) -> String; + + fn get_model_family(&self) -> ModelFamily; + + fn get_reasoning_effort(&self) -> Option; + + fn get_reasoning_summary(&self) -> ReasoningSummaryConfig; + + fn get_auth_manager(&self) -> Option>; + + async fn stream(&self, prompt: &Prompt) -> Result, Self::Error>; +} diff --git a/codex-rs/agent/src/runtime/mod.rs b/codex-rs/agent/src/runtime/mod.rs new file mode 100644 index 0000000000..bcfd8b5abe --- /dev/null +++ b/codex-rs/agent/src/runtime/mod.rs @@ -0,0 +1,5 @@ +pub mod session; + +pub use session::Session; +pub use session::TurnContext; +pub use session::ConfigureSession; diff --git a/codex-rs/agent/src/runtime/session.rs b/codex-rs/agent/src/runtime/session.rs new file mode 100644 index 0000000000..c7bce25992 --- /dev/null +++ b/codex-rs/agent/src/runtime/session.rs @@ -0,0 +1,2921 @@ +pub(crate) struct Session { + conversation_id: ConversationId, + tx_event: Sender, + state: Mutex, + pub(crate) active_turn: Mutex>, + services: SessionServices, + agent_config: Arc, + next_internal_sub_id: AtomicU64, +} + +/// The context needed for a single turn of the conversation. +pub(crate) struct TurnContext { + pub(crate) client: Arc>, + /// The session's current working directory. All relative paths provided by + /// the model as well as sandbox policies are resolved against this path + /// instead of `std::env::current_dir()`. + pub(crate) cwd: PathBuf, + pub(crate) base_instructions: Option, + pub(crate) user_instructions: Option, + pub(crate) approval_policy: AskForApproval, + pub(crate) sandbox_policy: SandboxPolicy, + pub(crate) shell_environment_policy: ShellEnvironmentPolicy, + pub(crate) tools_config: ToolsConfig, + pub(crate) is_review_mode: bool, + pub(crate) final_output_json_schema: Option, +} + +impl TurnContext { + fn resolve_path(&self, path: Option) -> PathBuf { + path.as_ref() + .map(PathBuf::from) + .map_or_else(|| self.cwd.clone(), |p| self.cwd.join(p)) + } +} + +impl std::fmt::Debug for TurnContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TurnContext") + .field("cwd", &self.cwd) + .field("base_instructions", &self.base_instructions) + .field("user_instructions", &self.user_instructions) + .field("approval_policy", &self.approval_policy) + .field("sandbox_policy", &self.sandbox_policy) + .field("shell_environment_policy", &self.shell_environment_policy) + .field("tools_config", &self.tools_config) + .field("is_review_mode", &self.is_review_mode) + .field("final_output_json_schema", &self.final_output_json_schema) + .finish() + } +} + +/// Configure the model session. +struct ConfigureSession { + /// Provider identifier ("openai", "openrouter", ...). + provider: ModelProviderInfo, + + /// If not specified, server will use its default model. + model: String, + + model_reasoning_effort: Option, + model_reasoning_summary: ReasoningSummaryConfig, + + /// Model instructions that are appended to the base instructions. + user_instructions: Option, + + /// Base instructions override. + base_instructions: Option, + + /// When to escalate for approval for execution + approval_policy: AskForApproval, + /// How to sandbox commands executed in the system + sandbox_policy: SandboxPolicy, + + /// Working directory that should be treated as the *root* of the + /// session. All relative paths supplied by the model as well as the + /// execution sandbox are resolved against this directory **instead** + /// of the process-wide current working directory. CLI front-ends are + /// expected to expand this to an absolute path before sending the + /// `ConfigureSession` operation so that the business-logic layer can + /// operate deterministically. + cwd: PathBuf, +} + +struct SessionBootstrap { + conversation_id: ConversationId, + rollout_sink: Arc, + rollout_path: PathBuf, + mcp: Arc, + default_shell: crate::shell::Shell, + history_log_id: u64, + history_entry_count: usize, + startup_errors: Vec, +} + +async fn prepare_session_bootstrap( + agent_config: Arc, + configure_session: &ConfigureSession, + initial_history: &InitialHistory, +) -> anyhow::Result { + let (conversation_id, rollout_params) = match initial_history { + InitialHistory::New | InitialHistory::Forked(_) => { + let conversation_id = ConversationId::default(); + ( + conversation_id, + RolloutRecorderParams::new( + conversation_id, + agent_config.cwd.clone(), + configure_session.user_instructions.clone(), + ), + ) + } + InitialHistory::Resumed(resumed_history) => ( + resumed_history.conversation_id, + RolloutRecorderParams::resume(resumed_history.rollout_path.clone()), + ), + }; + + let rollout_config = RolloutConfig { + codex_home: agent_config.codex_home.clone(), + originator: crate::default_client::ORIGINATOR.value.clone(), + cli_version: env!("CARGO_PKG_VERSION").to_string(), + git_info_collector: Some(Arc::new(CoreGitInfoCollector)), + }; + + let rollout_fut = RolloutRecorder::new(&rollout_config, rollout_params); + let mcp_fut = McpConnectionManager::new(agent_config.mcp_servers.clone()); + let default_shell_fut = shell::default_user_shell(); + let history_meta_fut = crate::message_history::history_metadata(&agent_config); + + let (rollout_recorder, mcp_res, default_shell, (history_log_id, history_entry_count)) = + tokio::join!(rollout_fut, mcp_fut, default_shell_fut, history_meta_fut); + + let rollout_recorder = rollout_recorder.map_err(|e| { + error!("failed to initialize rollout recorder: {e:#}"); + anyhow::anyhow!("failed to initialize rollout recorder: {e:#}") + })?; + let rollout_path = rollout_recorder.get_rollout_path(); + let rollout_sink: Arc = Arc::new(rollout_recorder); + + let mut startup_errors = Vec::new(); + + let (mcp, failed_clients) = match mcp_res { + Ok((mgr, failures)) => (Arc::new(mgr) as Arc, failures), + Err(e) => { + let message = format!("Failed to create MCP connection manager: {e:#}"); + error!("{message}"); + startup_errors.push(message); + ( + Arc::new(McpConnectionManager::default()) as Arc, + Default::default(), + ) + } + }; + + if !failed_clients.is_empty() { + for (server_name, err) in failed_clients { + let message = format!("MCP client for `{server_name}` failed to start: {err:#}"); + error!("{message}"); + startup_errors.push(message); + } + } + + Ok(SessionBootstrap { + conversation_id, + rollout_sink, + rollout_path, + mcp, + default_shell, + history_log_id, + history_entry_count, + startup_errors, + }) +} + +impl Session { + async fn new( + configure_session: ConfigureSession, + agent_config: Arc, + tx_event: Sender, + initial_history: InitialHistory, + bootstrap: SessionBootstrap, + services: SessionServices, + turn_context: Arc, + ) -> anyhow::Result> { + let ConfigureSession { + provider, + model, + model_reasoning_effort, + .. + } = configure_session; + debug!("Configuring session: model={model}; provider={provider:?}"); + + let SessionBootstrap { + conversation_id, + rollout_path, + history_log_id, + history_entry_count, + startup_errors, + .. + } = bootstrap; + + // Create the mutable state for the Session. + let state = SessionState::new(); + + let sess = Arc::new(Session { + conversation_id, + tx_event: tx_event.clone(), + state: Mutex::new(state), + active_turn: Mutex::new(None), + services, + agent_config: agent_config.clone(), + next_internal_sub_id: AtomicU64::new(0), + }); + + // Dispatch the SessionConfiguredEvent first and then report any errors. + // If resuming, include converted initial messages in the payload so UIs can render them immediately. + let initial_messages = initial_history.get_event_msgs(); + sess.record_initial_history(&turn_context, initial_history) + .await; + + let events = std::iter::once(Event { + id: INITIAL_SUBMIT_ID.to_owned(), + msg: EventMsg::SessionConfigured(SessionConfiguredEvent { + session_id: conversation_id, + model, + reasoning_effort: model_reasoning_effort, + history_log_id, + history_entry_count, + initial_messages, + rollout_path, + }), + }) + .chain(startup_errors.into_iter().map(|message| Event { + id: INITIAL_SUBMIT_ID.to_owned(), + msg: EventMsg::Error(ErrorEvent { message }), + })); + for event in events { + sess.send_event(event).await; + } + + Ok(sess) + } + + fn next_internal_sub_id(&self) -> String { + let id = self + .next_internal_sub_id + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); + format!("auto-compact-{id}") + } + + async fn record_initial_history( + &self, + turn_context: &TurnContext, + conversation_history: InitialHistory, + ) { + match conversation_history { + InitialHistory::New => { + // Build and record initial items (user instructions + environment context) + let items = self.build_initial_context(turn_context); + self.record_conversation_items(&items).await; + } + InitialHistory::Resumed(_) | InitialHistory::Forked(_) => { + let rollout_items = conversation_history.get_rollout_items(); + let persist = matches!(conversation_history, InitialHistory::Forked(_)); + + // Always add response items to conversation history + let reconstructed_history = + self.reconstruct_history_from_rollout(turn_context, &rollout_items); + if !reconstructed_history.is_empty() { + self.record_into_history(&reconstructed_history).await; + } + + // If persisting, persist all rollout items as-is (recorder filters) + if persist && !rollout_items.is_empty() { + self.persist_rollout_items(&rollout_items).await; + } + } + } + } + + /// Persist the event to rollout and send it to clients. + pub(crate) async fn send_event(&self, event: Event) { + // Persist the event into rollout (recorder filters as needed) + let rollout_items = vec![RolloutItem::EventMsg(event.msg.clone())]; + self.persist_rollout_items(&rollout_items).await; + if let Err(e) = self.tx_event.send(event).await { + error!("failed to send tool call event: {e}"); + } + } + + pub async fn request_command_approval( + &self, + sub_id: String, + call_id: String, + command: Vec, + cwd: PathBuf, + reason: Option, + ) -> ReviewDecision { + // Add the tx_approve callback to the map before sending the request. + let (tx_approve, rx_approve) = oneshot::channel(); + let event_id = sub_id.clone(); + let prev_entry = { + let mut active = self.active_turn.lock().await; + match active.as_mut() { + Some(at) => { + let mut ts = at.turn_state.lock().await; + ts.insert_pending_approval(sub_id, tx_approve) + } + None => None, + } + }; + if prev_entry.is_some() { + warn!("Overwriting existing pending approval for sub_id: {event_id}"); + } + + let event = Event { + id: event_id, + msg: EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent { + call_id, + command, + cwd, + reason, + }), + }; + self.send_event(event).await; + rx_approve.await.unwrap_or_default() + } + + pub async fn request_patch_approval( + &self, + sub_id: String, + call_id: String, + action: &ApplyPatchAction, + reason: Option, + grant_root: Option, + ) -> oneshot::Receiver { + // Add the tx_approve callback to the map before sending the request. + let (tx_approve, rx_approve) = oneshot::channel(); + let event_id = sub_id.clone(); + let prev_entry = { + let mut active = self.active_turn.lock().await; + match active.as_mut() { + Some(at) => { + let mut ts = at.turn_state.lock().await; + ts.insert_pending_approval(sub_id, tx_approve) + } + None => None, + } + }; + if prev_entry.is_some() { + warn!("Overwriting existing pending approval for sub_id: {event_id}"); + } + + let event = Event { + id: event_id, + msg: EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { + call_id, + changes: convert_apply_patch_to_protocol(action), + reason, + grant_root, + }), + }; + self.send_event(event).await; + rx_approve + } + + pub async fn notify_approval(&self, sub_id: &str, decision: ReviewDecision) { + let entry = { + let mut active = self.active_turn.lock().await; + match active.as_mut() { + Some(at) => { + let mut ts = at.turn_state.lock().await; + ts.remove_pending_approval(sub_id) + } + None => None, + } + }; + match entry { + Some(tx_approve) => { + tx_approve.send(decision).ok(); + } + None => { + warn!("No pending approval found for sub_id: {sub_id}"); + } + } + } + + pub async fn add_approved_command(&self, cmd: Vec) { + let mut state = self.state.lock().await; + state.add_approved_command(cmd); + } + + /// Records input items: always append to conversation history and + /// persist these response items to rollout. + async fn record_conversation_items(&self, items: &[ResponseItem]) { + self.record_into_history(items).await; + self.persist_rollout_response_items(items).await; + } + + fn reconstruct_history_from_rollout( + &self, + turn_context: &TurnContext, + rollout_items: &[RolloutItem], + ) -> Vec { + let mut history = ConversationHistory::new(); + for item in rollout_items { + match item { + RolloutItem::ResponseItem(response_item) => { + history.record_items(std::iter::once(response_item)); + } + RolloutItem::Compacted(compacted) => { + let snapshot = history.contents(); + let user_messages = collect_user_messages(&snapshot); + let rebuilt = build_compacted_history( + self.build_initial_context(turn_context), + &user_messages, + &compacted.message, + ); + history.replace(rebuilt); + } + _ => {} + } + } + history.contents() + } + + /// Append ResponseItems to the in-memory conversation history only. + async fn record_into_history(&self, items: &[ResponseItem]) { + let mut state = self.state.lock().await; + state.record_items(items.iter()); + } + + async fn replace_history(&self, items: Vec) { + let mut state = self.state.lock().await; + state.replace_history(items); + } + + async fn persist_rollout_response_items(&self, items: &[ResponseItem]) { + let rollout_items: Vec = items + .iter() + .cloned() + .map(RolloutItem::ResponseItem) + .collect(); + self.persist_rollout_items(&rollout_items).await; + } + + pub(crate) fn build_initial_context(&self, turn_context: &TurnContext) -> Vec { + let mut items = Vec::::with_capacity(2); + if let Some(user_instructions) = turn_context.user_instructions.as_deref() { + items.push(UserInstructions::new(user_instructions.to_string()).into()); + } + items.push(ResponseItem::from(EnvironmentContext::new( + Some(turn_context.cwd.clone()), + Some(turn_context.approval_policy), + Some(turn_context.sandbox_policy.clone()), + Some(self.user_shell().clone()), + ))); + items + } + + async fn persist_rollout_items(&self, items: &[RolloutItem]) { + let recorder = { + let guard = self.services.rollout.lock().await; + guard.clone() + }; + if let Some(rec) = recorder + && let Err(e) = rec.record_items(items).await + { + error!("failed to record rollout items: {e:#}"); + } + } + + pub(crate) async fn history_snapshot(&self) -> Vec { + let state = self.state.lock().await; + state.history_snapshot() + } + + async fn update_token_usage_info( + &self, + sub_id: &str, + turn_context: &TurnContext, + token_usage: Option<&TokenUsage>, + ) { + { + let mut state = self.state.lock().await; + if let Some(token_usage) = token_usage { + state.update_token_info_from_usage( + token_usage, + turn_context.client.get_model_context_window(), + ); + } + } + self.send_token_count_event(sub_id).await; + } + + async fn update_rate_limits(&self, sub_id: &str, new_rate_limits: RateLimitSnapshot) { + { + let mut state = self.state.lock().await; + state.set_rate_limits(new_rate_limits); + } + self.send_token_count_event(sub_id).await; + } + + async fn send_token_count_event(&self, sub_id: &str) { + let (info, rate_limits) = { + let state = self.state.lock().await; + state.token_info_and_rate_limits() + }; + let event = Event { + id: sub_id.to_string(), + msg: EventMsg::TokenCount(TokenCountEvent { info, rate_limits }), + }; + self.send_event(event).await; + } + + /// Record a user input item to conversation history and also persist a + /// corresponding UserMessage EventMsg to rollout. + async fn record_input_and_rollout_usermsg(&self, response_input: &ResponseInputItem) { + let response_item: ResponseItem = response_input.clone().into(); + // Add to conversation history and persist response item to rollout + self.record_conversation_items(std::slice::from_ref(&response_item)) + .await; + + // Derive user message events and persist only UserMessage to rollout + let msgs = + map_response_item_to_event_messages(&response_item, self.show_raw_agent_reasoning()); + let user_msgs: Vec = msgs + .into_iter() + .filter_map(|m| match m { + EventMsg::UserMessage(ev) => Some(RolloutItem::EventMsg(EventMsg::UserMessage(ev))), + _ => None, + }) + .collect(); + if !user_msgs.is_empty() { + self.persist_rollout_items(&user_msgs).await; + } + } + + async fn on_exec_command_begin( + &self, + turn_diff_tracker: &mut TurnDiffTracker, + exec_command_context: ExecCommandContext, + ) { + let ExecCommandContext { + sub_id, + call_id, + command_for_display, + cwd, + apply_patch, + } = exec_command_context; + let msg = match apply_patch { + Some(ApplyPatchCommandContext { + user_explicitly_approved_this_action, + changes, + }) => { + turn_diff_tracker.on_patch_begin(&changes); + + EventMsg::PatchApplyBegin(PatchApplyBeginEvent { + call_id, + auto_approved: !user_explicitly_approved_this_action, + changes, + }) + } + None => EventMsg::ExecCommandBegin(ExecCommandBeginEvent { + call_id, + command: command_for_display.clone(), + cwd, + parsed_cmd: parse_command(&command_for_display) + .into_iter() + .map(Into::into) + .collect(), + }), + }; + let event = Event { + id: sub_id.to_string(), + msg, + }; + self.send_event(event).await; + } + + async fn on_exec_command_end( + &self, + turn_diff_tracker: &mut TurnDiffTracker, + sub_id: &str, + call_id: &str, + output: &ExecToolCallOutput, + is_apply_patch: bool, + ) { + let ExecToolCallOutput { + stdout, + stderr, + aggregated_output, + duration, + exit_code, + timed_out: _, + } = output; + // Send full stdout/stderr to clients; do not truncate. + let stdout = stdout.text.clone(); + let stderr = stderr.text.clone(); + let formatted_output = format_exec_output_str(output); + let aggregated_output: String = aggregated_output.text.clone(); + + let msg = if is_apply_patch { + EventMsg::PatchApplyEnd(PatchApplyEndEvent { + call_id: call_id.to_string(), + stdout, + stderr, + success: *exit_code == 0, + }) + } else { + EventMsg::ExecCommandEnd(ExecCommandEndEvent { + call_id: call_id.to_string(), + stdout, + stderr, + aggregated_output, + exit_code: *exit_code, + duration: *duration, + formatted_output, + }) + }; + + let event = Event { + id: sub_id.to_string(), + msg, + }; + self.send_event(event).await; + + // If this is an apply_patch, after we emit the end patch, emit a second event + // with the full turn diff if there is one. + if is_apply_patch { + let unified_diff = turn_diff_tracker.get_unified_diff(); + if let Ok(Some(unified_diff)) = unified_diff { + let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff }); + let event = Event { + id: sub_id.into(), + msg, + }; + self.send_event(event).await; + } + } + } + /// Runs the exec tool call and emits events for the begin and end of the + /// command even on error. + /// + /// Returns the output of the exec tool call. + async fn run_exec_with_events<'a>( + &self, + turn_diff_tracker: &mut TurnDiffTracker, + begin_ctx: ExecCommandContext, + exec_args: ExecInvokeArgs<'a>, + ) -> crate::error::Result { + let is_apply_patch = begin_ctx.apply_patch.is_some(); + let sub_id = begin_ctx.sub_id.clone(); + let call_id = begin_ctx.call_id.clone(); + + self.on_exec_command_begin(turn_diff_tracker, begin_ctx.clone()) + .await; + + let ExecInvokeArgs { + params, + plan, + sandbox_policy, + sandbox_cwd, + codex_linux_sandbox_exe, + stdout_stream, + } = exec_args; + + let registry = BackendRegistry::new(); + let runtime_ctx = ExecRuntimeContext { + sandbox_policy, + sandbox_cwd, + codex_linux_sandbox_exe, + stdout_stream, + }; + + let result = run_with_plan(params, &plan, ®istry, &runtime_ctx).await; + + let output_stderr; + let borrowed: &ExecToolCallOutput = match &result { + Ok(output) => output, + Err(CodexErr::Sandbox(SandboxErr::Timeout { output })) => output, + Err(e) => { + output_stderr = ExecToolCallOutput { + exit_code: -1, + stdout: StreamOutput::new(String::new()), + stderr: StreamOutput::new(get_error_message_ui(e)), + aggregated_output: StreamOutput::new(get_error_message_ui(e)), + duration: Duration::default(), + timed_out: false, + }; + &output_stderr + } + }; + self.on_exec_command_end( + turn_diff_tracker, + &sub_id, + &call_id, + borrowed, + is_apply_patch, + ) + .await; + + result + } + + /// Helper that emits a BackgroundEvent with the given message. This keeps + /// the call‑sites terse so adding more diagnostics does not clutter the + /// core agent logic. + async fn notify_background_event(&self, sub_id: &str, message: impl Into) { + let event = Event { + id: sub_id.to_string(), + msg: EventMsg::BackgroundEvent(BackgroundEventEvent { + message: message.into(), + }), + }; + self.send_event(event).await; + } + + async fn notify_stream_error(&self, sub_id: &str, message: impl Into) { + let event = Event { + id: sub_id.to_string(), + msg: EventMsg::StreamError(StreamErrorEvent { + message: message.into(), + }), + }; + self.send_event(event).await; + } + + /// Build the full turn input by concatenating the current conversation + /// history with additional items for this turn. + pub async fn turn_input_with_history(&self, extra: Vec) -> Vec { + let history = { + let state = self.state.lock().await; + state.history_snapshot() + }; + [history, extra].concat() + } + + /// Returns the input if there was no task running to inject into + pub async fn inject_input(&self, input: Vec) -> Result<(), Vec> { + let mut active = self.active_turn.lock().await; + match active.as_mut() { + Some(at) => { + let mut ts = at.turn_state.lock().await; + ts.push_pending_input(input.into()); + Ok(()) + } + None => Err(input), + } + } + + pub async fn get_pending_input(&self) -> Vec { + let mut active = self.active_turn.lock().await; + match active.as_mut() { + Some(at) => { + let mut ts = at.turn_state.lock().await; + ts.take_pending_input() + } + None => Vec::with_capacity(0), + } + } + + pub async fn call_tool( + &self, + server: &str, + tool: &str, + arguments: Option, + ) -> anyhow::Result { + self.services.mcp.call_tool(server, tool, arguments).await + self.services.mcp.call_tool(server, tool, arguments).await + } + + pub async fn interrupt_task(self: &Arc) { + info!("interrupt received: abort current task, if any"); + self.abort_all_tasks(TurnAbortReason::Interrupted).await; + } + + fn interrupt_task_sync(&self) { + if let Ok(mut active) = self.active_turn.try_lock() + && let Some(at) = active.as_mut() + { + at.try_clear_pending_sync(); + let tasks = at.drain_tasks(); + *active = None; + for (_sub_id, task) in tasks { + task.handle.abort(); + } + } + } + + pub(crate) fn notifier(&self) -> &dyn Notifier { + self.services.notifier.as_ref() + } + + fn user_shell(&self) -> &shell::Shell { + self.services.sandbox.user_shell() + } + + fn show_raw_agent_reasoning(&self) -> bool { + self.services.show_raw_agent_reasoning + } +} + +impl Drop for Session { + fn drop(&mut self) { + self.interrupt_task_sync(); + } +} + +#[derive(Clone, Debug)] +pub(crate) struct ExecCommandContext { + pub(crate) sub_id: String, + pub(crate) call_id: String, + pub(crate) command_for_display: Vec, + pub(crate) cwd: PathBuf, + pub(crate) apply_patch: Option, +} + +#[derive(Clone, Debug)] +pub(crate) struct ApplyPatchCommandContext { + pub(crate) user_explicitly_approved_this_action: bool, + pub(crate) changes: HashMap, +} + +async fn submission_loop( + sess: Arc, + turn_context: Arc, + agent_config: Arc, + rx_sub: Receiver, +) { + let mut turn_context = turn_context; + // To break out of this loop, send Op::Shutdown. + while let Ok(sub) = rx_sub.recv().await { + debug!(?sub, "Submission"); + match sub.op { + Op::Interrupt => { + sess.interrupt_task().await; + } + Op::OverrideTurnContext { + cwd, + approval_policy, + sandbox_policy, + model, + effort, + summary, + } => { + // Recalculate the persistent turn context with provided overrides. + let prev = Arc::clone(&turn_context); + let provider = prev.client.get_provider(); + + // Effective model + family + let (effective_model, effective_family) = if let Some(ref m) = model { + let fam = find_family_for_model(m) + .unwrap_or_else(|| agent_config.model_family.clone()); + (m.clone(), fam) + } else { + (prev.client.get_model(), prev.client.get_model_family()) + }; + + // Effective reasoning settings + let effective_effort = effort.unwrap_or(prev.client.get_reasoning_effort()); + let effective_summary = summary.unwrap_or(prev.client.get_reasoning_summary()); + + let auth_manager = prev.client.get_auth_manager(); + + // Build updated config for the client + let mut updated_config = (*agent_config).clone(); + updated_config.model = effective_model.clone(); + updated_config.model_family = effective_family.clone(); + if let Some(model_info) = get_model_info(&effective_family) { + updated_config.model_context_window = Some(model_info.context_window); + } + + let client: Arc> = + Arc::new(CoreModelClientAdapter::new(ModelClient::new( + Arc::new(updated_config), + auth_manager, + provider, + effective_effort, + effective_summary, + sess.conversation_id, + ))); + + let new_approval_policy = approval_policy.unwrap_or(prev.approval_policy); + let new_sandbox_policy = sandbox_policy + .clone() + .unwrap_or(prev.sandbox_policy.clone()); + let new_cwd = cwd.clone().unwrap_or_else(|| prev.cwd.clone()); + + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_family: &effective_family, + include_plan_tool: agent_config.include_plan_tool, + include_apply_patch_tool: agent_config.include_apply_patch_tool, + include_web_search_request: agent_config.tools_web_search_request, + use_streamable_shell_tool: agent_config.use_experimental_streamable_shell_tool, + include_view_image_tool: agent_config.include_view_image_tool, + experimental_unified_exec_tool: agent_config.use_experimental_unified_exec_tool, + }); + + let new_turn_context = TurnContext { + client, + tools_config, + user_instructions: prev.user_instructions.clone(), + base_instructions: prev.base_instructions.clone(), + approval_policy: new_approval_policy, + sandbox_policy: new_sandbox_policy.clone(), + shell_environment_policy: prev.shell_environment_policy.clone(), + cwd: new_cwd.clone(), + is_review_mode: false, + final_output_json_schema: None, + }; + + // Install the new persistent context for subsequent tasks/turns. + turn_context = Arc::new(new_turn_context); + + // Optionally persist changes to model / effort + if cwd.is_some() || approval_policy.is_some() || sandbox_policy.is_some() { + sess.record_conversation_items(&[ResponseItem::from(EnvironmentContext::new( + cwd, + approval_policy, + sandbox_policy, + // Shell is not configurable from turn to turn + None, + ))]) + .await; + } + } + Op::UserInput { items } => { + // attempt to inject input into current task + if let Err(items) = sess.inject_input(items).await { + // no current task, spawn a new one + sess.spawn_task(Arc::clone(&turn_context), sub.id, items, RegularTask) + .await; + } + } + Op::UserTurn { + items, + cwd, + approval_policy, + sandbox_policy, + model, + effort, + summary, + final_output_json_schema, + } => { + // attempt to inject input into current task + if let Err(items) = sess.inject_input(items).await { + // Derive a fresh TurnContext for this turn using the provided overrides. + let provider = turn_context.client.get_provider(); + let auth_manager = turn_context.client.get_auth_manager(); + + // Derive a model family for the requested model; fall back to the session's. + let model_family = find_family_for_model(&model) + .unwrap_or_else(|| agent_config.model_family.clone()); + + // Create a per‑turn Config clone with the requested model/family. + let mut per_turn_config = (*agent_config).clone(); + per_turn_config.model = model.clone(); + per_turn_config.model_family = model_family.clone(); + if let Some(model_info) = get_model_info(&model_family) { + per_turn_config.model_context_window = Some(model_info.context_window); + } + + // Build a new client with per‑turn reasoning settings. + // Reuse the same provider and session id; auth defaults to env/API key. + let client: Arc> = + Arc::new(CoreModelClientAdapter::new(ModelClient::new( + Arc::new(per_turn_config), + auth_manager, + provider, + effort, + summary, + sess.conversation_id, + ))); + + let fresh_turn_context = TurnContext { + client, + tools_config: ToolsConfig::new(&ToolsConfigParams { + model_family: &model_family, + include_plan_tool: agent_config.include_plan_tool, + include_apply_patch_tool: agent_config.include_apply_patch_tool, + include_web_search_request: agent_config.tools_web_search_request, + use_streamable_shell_tool: agent_config + .use_experimental_streamable_shell_tool, + include_view_image_tool: agent_config.include_view_image_tool, + experimental_unified_exec_tool: agent_config + .use_experimental_unified_exec_tool, + }), + user_instructions: turn_context.user_instructions.clone(), + base_instructions: turn_context.base_instructions.clone(), + approval_policy, + sandbox_policy, + shell_environment_policy: turn_context.shell_environment_policy.clone(), + cwd, + is_review_mode: false, + final_output_json_schema, + }; + + // if the environment context has changed, record it in the conversation history + let previous_env_context = EnvironmentContext::from(turn_context.as_ref()); + let new_env_context = EnvironmentContext::from(&fresh_turn_context); + if !new_env_context.equals_except_shell(&previous_env_context) { + sess.record_conversation_items(&[ResponseItem::from(new_env_context)]) + .await; + } + + // Install the new persistent context for subsequent tasks/turns. + turn_context = Arc::new(fresh_turn_context); + + // no current task, spawn a new one with the per-turn context + sess.spawn_task(Arc::clone(&turn_context), sub.id, items, RegularTask) + .await; + } + } + Op::ExecApproval { id, decision } => match decision { + ReviewDecision::Abort => { + sess.interrupt_task().await; + } + other => sess.notify_approval(&id, other).await, + }, + Op::PatchApproval { id, decision } => match decision { + ReviewDecision::Abort => { + sess.interrupt_task().await; + } + other => sess.notify_approval(&id, other).await, + }, + Op::AddToHistory { text } => { + let id = sess.conversation_id; + let agent_config_clone = agent_config.clone(); + tokio::spawn(async move { + if let Err(e) = + crate::message_history::append_entry(&text, &id, &agent_config_clone).await + { + warn!("failed to append to message history: {e}"); + } + }); + } + + Op::GetHistoryEntryRequest { offset, log_id } => { + let agent_config_clone = agent_config.clone(); + let sess_clone = sess.clone(); + let sub_id = sub.id.clone(); + + tokio::spawn(async move { + // Run lookup in blocking thread because it does file IO + locking. + let entry_opt = tokio::task::spawn_blocking(move || { + crate::message_history::lookup(log_id, offset, &agent_config_clone) + }) + .await + .unwrap_or(None); + + let event = Event { + id: sub_id, + msg: EventMsg::GetHistoryEntryResponse( + crate::protocol::GetHistoryEntryResponseEvent { + offset, + log_id, + entry: entry_opt.map(|e| { + codex_protocol::message_history::HistoryEntry { + conversation_id: e.session_id, + ts: e.ts, + text: e.text, + } + }), + }, + ), + }; + + sess_clone.send_event(event).await; + }); + } + Op::ListMcpTools => { + let sub_id = sub.id.clone(); + + // This is a cheap lookup from the connection manager's cache. + let tools = sess.services.mcp.list_all_tools(); + let event = Event { + id: sub_id, + msg: EventMsg::McpListToolsResponse( + crate::protocol::McpListToolsResponseEvent { tools }, + ), + }; + sess.send_event(event).await; + } + Op::ListCustomPrompts => { + let sub_id = sub.id.clone(); + + let custom_prompts: Vec = + if let Some(dir) = crate::custom_prompts::default_prompts_dir() { + crate::custom_prompts::discover_prompts_in(&dir).await + } else { + Vec::new() + }; + + let event = Event { + id: sub_id, + msg: EventMsg::ListCustomPromptsResponse(ListCustomPromptsResponseEvent { + custom_prompts, + }), + }; + sess.send_event(event).await; + } + Op::Compact => { + // Attempt to inject input into current task + if let Err(items) = sess + .inject_input(vec![InputItem::Text { + text: compact::SUMMARIZATION_PROMPT.to_string(), + }]) + .await + { + sess.spawn_task(Arc::clone(&turn_context), sub.id, items, CompactTask) + .await; + } + } + Op::Shutdown => { + sess.abort_all_tasks(TurnAbortReason::Interrupted).await; + info!("Shutting down Codex instance"); + + // Gracefully flush and shutdown rollout recorder on session end so tests + // that inspect the rollout file do not race with the background writer. + let recorder_opt = { + let mut guard = sess.services.rollout.lock().await; + guard.take() + }; + if let Some(rec) = recorder_opt + && let Err(e) = rec.shutdown().await + { + warn!("failed to shutdown rollout recorder: {e}"); + let event = Event { + id: sub.id.clone(), + msg: EventMsg::Error(ErrorEvent { + message: "Failed to shutdown rollout recorder".to_string(), + }), + }; + sess.send_event(event).await; + } + + let event = Event { + id: sub.id.clone(), + msg: EventMsg::ShutdownComplete, + }; + sess.send_event(event).await; + break; + } + Op::GetPath => { + let sub_id = sub.id.clone(); + // Flush rollout writes before returning the path so readers observe a consistent file. + let (path, rec_opt) = { + let guard = sess.services.rollout.lock().await; + match guard.as_ref() { + Some(rec) => (rec.get_rollout_path(), Some(rec.clone())), + None => { + error!("rollout recorder not found"); + continue; + } + } + }; + if let Some(rec) = rec_opt + && let Err(e) = rec.flush().await + { + warn!("failed to flush rollout recorder before GetHistory: {e}"); + } + let event = Event { + id: sub_id.clone(), + msg: EventMsg::ConversationPath(ConversationPathResponseEvent { + conversation_id: sess.conversation_id, + path, + }), + }; + sess.send_event(event).await; + } + Op::Review { review_request } => { + spawn_review_thread( + sess.clone(), + sess.agent_config.clone(), + turn_context.clone(), + sub.id, + review_request, + ) + .await; + } + _ => { + // Ignore unknown ops; enum is non_exhaustive to allow extensions. + } + } + } + debug!("Agent loop exited"); +} + +/// Spawn a review thread using the given prompt. +async fn spawn_review_thread( + sess: Arc, + agent_config: Arc, + parent_turn_context: Arc, + sub_id: String, + review_request: ReviewRequest, +) { + let model = agent_config.review_model.clone(); + let review_model_family = find_family_for_model(&model) + .unwrap_or_else(|| parent_turn_context.client.get_model_family()); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_family: &review_model_family, + include_plan_tool: false, + include_apply_patch_tool: agent_config.include_apply_patch_tool, + include_web_search_request: false, + use_streamable_shell_tool: false, + include_view_image_tool: false, + experimental_unified_exec_tool: agent_config.use_experimental_unified_exec_tool, + }); + + let base_instructions = REVIEW_PROMPT.to_string(); + let review_prompt = review_request.prompt.clone(); + let provider = parent_turn_context.client.get_provider(); + let auth_manager = parent_turn_context.client.get_auth_manager(); + let model_family = review_model_family.clone(); + + // Build per‑turn client with the requested model/family. + let mut per_turn_config = (*agent_config).clone(); + per_turn_config.model = model.clone(); + per_turn_config.model_family = model_family.clone(); + per_turn_config.model_reasoning_effort = Some(ReasoningEffortConfig::Low); + per_turn_config.model_reasoning_summary = ReasoningSummaryConfig::Detailed; + if let Some(model_info) = get_model_info(&model_family) { + per_turn_config.model_context_window = Some(model_info.context_window); + } + + let per_turn_config = Arc::new(per_turn_config); + let client: Arc> = + Arc::new(CoreModelClientAdapter::new(ModelClient::new( + per_turn_config.clone(), + auth_manager, + provider, + per_turn_config.model_reasoning_effort, + per_turn_config.model_reasoning_summary, + sess.conversation_id, + ))); + + let review_turn_context = TurnContext { + client, + tools_config, + user_instructions: None, + base_instructions: Some(base_instructions.clone()), + approval_policy: parent_turn_context.approval_policy, + sandbox_policy: parent_turn_context.sandbox_policy.clone(), + shell_environment_policy: parent_turn_context.shell_environment_policy.clone(), + cwd: parent_turn_context.cwd.clone(), + is_review_mode: true, + final_output_json_schema: None, + }; + + // Seed the child task with the review prompt as the initial user message. + let input: Vec = vec![InputItem::Text { + text: format!("{base_instructions}\n\n---\n\nNow, here's your task: {review_prompt}"), + }]; + let tc = Arc::new(review_turn_context); + + // Clone sub_id for the upcoming announcement before moving it into the task. + let sub_id_for_event = sub_id.clone(); + sess.spawn_task(tc.clone(), sub_id, input, ReviewTask).await; + + // Announce entering review mode so UIs can switch modes. + sess.send_event(Event { + id: sub_id_for_event, + msg: EventMsg::EnteredReviewMode(review_request), + }) + .await; +} + +/// Takes a user message as input and runs a loop where, at each turn, the model +/// replies with either: +/// +/// - requested function calls +/// - an assistant message +/// +/// While it is possible for the model to return multiple of these items in a +/// single turn, in practice, we generally one item per turn: +/// +/// - If the model requests a function call, we execute it and send the output +/// back to the model in the next turn. +/// - If the model sends only an assistant message, we record it in the +/// conversation history and consider the task complete. +/// +/// Review mode: when `turn_context.is_review_mode` is true, the turn runs in an +/// isolated in-memory thread without the parent session's prior history or +/// user_instructions. Emits ExitedReviewMode upon final review message. +pub(crate) async fn run_task( + sess: Arc, + turn_context: Arc, + sub_id: String, + input: Vec, +) -> Option { + if input.is_empty() { + return None; + } + let event = Event { + id: sub_id.clone(), + msg: EventMsg::TaskStarted(TaskStartedEvent { + model_context_window: turn_context.client.get_model_context_window(), + }), + }; + sess.send_event(event).await; + + let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input); + // For review threads, keep an isolated in-memory history so the + // model sees a fresh conversation without the parent session's history. + // For normal turns, continue recording to the session history as before. + let is_review_mode = turn_context.is_review_mode; + let mut review_thread_history: Vec = Vec::new(); + if is_review_mode { + // Seed review threads with environment context so the model knows the working directory. + review_thread_history.extend(sess.build_initial_context(turn_context.as_ref())); + review_thread_history.push(initial_input_for_turn.into()); + } else { + sess.record_input_and_rollout_usermsg(&initial_input_for_turn) + .await; + } + + let mut last_agent_message: Option = None; + // Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains + // many turns, from the perspective of the user, it is a single turn. + let mut turn_diff_tracker = TurnDiffTracker::new(); + let mut auto_compact_recently_attempted = false; + + loop { + // Note that pending_input would be something like a message the user + // submitted through the UI while the model was running. Though the UI + // may support this, the model might not. + let pending_input = sess + .get_pending_input() + .await + .into_iter() + .map(ResponseItem::from) + .collect::>(); + + // Construct the input that we will send to the model. + // + // - For review threads, use the isolated in-memory history so the + // model sees a fresh conversation (no parent history/user_instructions). + // + // - For normal turns, use the session's full history. When using the + // chat completions API (or ZDR clients), the model needs the full + // conversation history on each turn. The rollout file, however, should + // only record the new items that originated in this turn so that it + // represents an append-only log without duplicates. + let turn_input: Vec = if is_review_mode { + if !pending_input.is_empty() { + review_thread_history.extend(pending_input); + } + review_thread_history.clone() + } else { + sess.record_conversation_items(&pending_input).await; + sess.turn_input_with_history(pending_input).await + }; + + let turn_input_messages: Vec = turn_input + .iter() + .filter_map(|item| match item { + ResponseItem::Message { content, .. } => Some(content), + _ => None, + }) + .flat_map(|content| { + content.iter().filter_map(|item| match item { + ContentItem::OutputText { text } => Some(text.clone()), + _ => None, + }) + }) + .collect(); + match run_turn( + &sess, + turn_context.as_ref(), + &mut turn_diff_tracker, + sub_id.clone(), + turn_input, + ) + .await + { + Ok(turn_output) => { + let TurnRunResult { + processed_items, + total_token_usage, + } = turn_output; + let limit = turn_context + .client + .get_auto_compact_token_limit() + .unwrap_or(i64::MAX); + let total_usage_tokens = total_token_usage + .as_ref() + .map(TokenUsage::tokens_in_context_window); + let token_limit_reached = total_usage_tokens + .map(|tokens| (tokens as i64) >= limit) + .unwrap_or(false); + let mut items_to_record_in_conversation_history = Vec::::new(); + let mut responses = Vec::::new(); + for processed_response_item in processed_items { + let ProcessedResponseItem { item, response } = processed_response_item; + match (&item, &response) { + (ResponseItem::Message { role, .. }, None) if role == "assistant" => { + // If the model returned a message, we need to record it. + items_to_record_in_conversation_history.push(item); + } + ( + ResponseItem::LocalShellCall { .. }, + Some(ResponseInputItem::FunctionCallOutput { call_id, output }), + ) => { + items_to_record_in_conversation_history.push(item); + items_to_record_in_conversation_history.push( + ResponseItem::FunctionCallOutput { + call_id: call_id.clone(), + output: output.clone(), + }, + ); + } + ( + ResponseItem::FunctionCall { .. }, + Some(ResponseInputItem::FunctionCallOutput { call_id, output }), + ) => { + items_to_record_in_conversation_history.push(item); + items_to_record_in_conversation_history.push( + ResponseItem::FunctionCallOutput { + call_id: call_id.clone(), + output: output.clone(), + }, + ); + } + ( + ResponseItem::CustomToolCall { .. }, + Some(ResponseInputItem::CustomToolCallOutput { call_id, output }), + ) => { + items_to_record_in_conversation_history.push(item); + items_to_record_in_conversation_history.push( + ResponseItem::CustomToolCallOutput { + call_id: call_id.clone(), + output: output.clone(), + }, + ); + } + ( + ResponseItem::FunctionCall { .. }, + Some(ResponseInputItem::McpToolCallOutput { call_id, result }), + ) => { + items_to_record_in_conversation_history.push(item); + let output = match result { + Ok(call_tool_result) => { + convert_call_tool_result_to_function_call_output_payload( + call_tool_result, + ) + } + Err(err) => FunctionCallOutputPayload { + content: err.clone(), + success: Some(false), + }, + }; + items_to_record_in_conversation_history.push( + ResponseItem::FunctionCallOutput { + call_id: call_id.clone(), + output, + }, + ); + } + ( + ResponseItem::Reasoning { + id, + summary, + content, + encrypted_content, + }, + None, + ) => { + items_to_record_in_conversation_history.push(ResponseItem::Reasoning { + id: id.clone(), + summary: summary.clone(), + content: content.clone(), + encrypted_content: encrypted_content.clone(), + }); + } + _ => { + warn!("Unexpected response item: {item:?} with response: {response:?}"); + } + }; + if let Some(response) = response { + responses.push(response); + } + } + + // Only attempt to take the lock if there is something to record. + if !items_to_record_in_conversation_history.is_empty() { + if is_review_mode { + review_thread_history + .extend(items_to_record_in_conversation_history.clone()); + } else { + sess.record_conversation_items(&items_to_record_in_conversation_history) + .await; + } + } + + if token_limit_reached { + if auto_compact_recently_attempted { + let limit_str = limit.to_string(); + let current_tokens = total_usage_tokens + .map(|tokens| tokens.to_string()) + .unwrap_or_else(|| "unknown".to_string()); + let event = Event { + id: sub_id.clone(), + msg: EventMsg::Error(ErrorEvent { + message: format!( + "Conversation is still above the token limit after automatic summarization (limit {limit_str}, current {current_tokens}). Please start a new session or trim your input." + ), + }), + }; + sess.send_event(event).await; + break; + } + auto_compact_recently_attempted = true; + compact::run_inline_auto_compact_task(sess.clone(), turn_context.clone()).await; + continue; + } + + auto_compact_recently_attempted = false; + + if responses.is_empty() { + last_agent_message = get_last_assistant_message_from_turn( + &items_to_record_in_conversation_history, + ); + sess.notifier() + .notify(&UserNotification::AgentTurnComplete { + turn_id: sub_id.clone(), + input_messages: turn_input_messages, + last_assistant_message: last_agent_message.clone(), + }); + break; + } + continue; + } + Err(e) => { + info!("Turn error: {e:#}"); + let event = Event { + id: sub_id.clone(), + msg: EventMsg::Error(ErrorEvent { + message: e.to_string(), + }), + }; + sess.send_event(event).await; + // let the user continue the conversation + break; + } + } + } + + // If this was a review thread and we have a final assistant message, + // try to parse it as a ReviewOutput. + // + // If parsing fails, construct a minimal ReviewOutputEvent using the plain + // text as the overall explanation. Else, just exit review mode with None. + // + // Emits an ExitedReviewMode event with the parsed review output. + if turn_context.is_review_mode { + exit_review_mode( + sess.clone(), + sub_id.clone(), + last_agent_message.as_deref().map(parse_review_output_event), + ) + .await; + } + + last_agent_message +} + +/// Parse the review output; when not valid JSON, build a structured +/// fallback that carries the plain text as the overall explanation. +/// +/// Returns: a ReviewOutputEvent parsed from JSON or a fallback populated from text. +fn parse_review_output_event(text: &str) -> ReviewOutputEvent { + // Try direct parse first + if let Ok(ev) = serde_json::from_str::(text) { + return ev; + } + // If wrapped in markdown fences or extra prose, attempt to extract the first JSON object + if let (Some(start), Some(end)) = (text.find('{'), text.rfind('}')) + && start < end + && let Some(slice) = text.get(start..=end) + && let Ok(ev) = serde_json::from_str::(slice) + { + return ev; + } + // Not JSON – return a structured ReviewOutputEvent that carries + // the plain text as the overall explanation. + ReviewOutputEvent { + overall_explanation: text.to_string(), + ..Default::default() + } +} + +async fn run_turn( + sess: &Session, + turn_context: &TurnContext, + turn_diff_tracker: &mut TurnDiffTracker, + sub_id: String, + input: Vec, +) -> CodexResult { + let tools = get_openai_tools( + &turn_context.tools_config, + Some(sess.services.mcp.list_all_tools()), + ); + + let prompt = Prompt { + input, + tools, + base_instructions_override: turn_context.base_instructions.clone(), + output_schema: turn_context.final_output_json_schema.clone(), + }; + + let mut retries = 0; + loop { + match try_run_turn(sess, turn_context, turn_diff_tracker, &sub_id, &prompt).await { + Ok(output) => return Ok(output), + Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted), + Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)), + Err(CodexErr::UsageLimitReached(e)) => { + let rate_limits = e.rate_limits.clone(); + if let Some(rate_limits) = rate_limits { + if let Some(rate_limits) = rate_limits { + sess.update_rate_limits(&sub_id, rate_limits).await; + } + return Err(CodexErr::UsageLimitReached(e)); + } + Err(CodexErr::UsageNotIncluded) => return Err(CodexErr::UsageNotIncluded), + Err(e) => { + // Use the configured provider-specific stream retry budget. + let max_retries = turn_context.client.get_provider().stream_max_retries(); + if retries < max_retries { + retries += 1; + let delay = match e { + CodexErr::Stream(_, Some(delay)) => delay, + _ => backoff(retries), + }; + warn!( + "stream disconnected - retrying turn ({retries}/{max_retries} in {delay:?})...", + ); + + // Surface retry information to any UI/front‑end so the + // user understands what is happening instead of staring + // at a seemingly frozen screen. + sess.notify_stream_error( + &sub_id, + format!( + "stream error: {e}; retrying {retries}/{max_retries} in {delay:?}…" + ), + ) + .await; + + tokio::time::sleep(delay).await; + } else { + return Err(e); + } + } + } + } +} + +/// When the model is prompted, it returns a stream of events. Some of these +/// events map to a `ResponseItem`. A `ResponseItem` may need to be +/// "handled" such that it produces a `ResponseInputItem` that needs to be +/// sent back to the model on the next turn. +#[derive(Debug)] +struct ProcessedResponseItem { + item: ResponseItem, + response: Option, +} + +#[derive(Debug)] +struct TurnRunResult { + processed_items: Vec, + total_token_usage: Option, +} + +async fn try_run_turn( + sess: &Session, + turn_context: &TurnContext, + turn_diff_tracker: &mut TurnDiffTracker, + sub_id: &str, + prompt: &Prompt, +) -> CodexResult { + // call_ids that are part of this response. + let completed_call_ids = prompt + .input + .iter() + .filter_map(|ri| match ri { + ResponseItem::FunctionCallOutput { call_id, .. } => Some(call_id), + ResponseItem::LocalShellCall { + call_id: Some(call_id), + .. + } => Some(call_id), + ResponseItem::CustomToolCallOutput { call_id, .. } => Some(call_id), + _ => None, + }) + .collect::>(); + + // call_ids that were pending but are not part of this response. + // This usually happens because the user interrupted the model before we responded to one of its tool calls + // and then the user sent a follow-up message. + let missing_calls = { + prompt + .input + .iter() + .filter_map(|ri| match ri { + ResponseItem::FunctionCall { call_id, .. } => Some(call_id), + ResponseItem::LocalShellCall { + call_id: Some(call_id), + .. + } => Some(call_id), + ResponseItem::CustomToolCall { call_id, .. } => Some(call_id), + _ => None, + }) + .filter_map(|call_id| { + if completed_call_ids.contains(&call_id) { + None + } else { + Some(call_id.clone()) + } + }) + .map(|call_id| ResponseItem::CustomToolCallOutput { + call_id, + output: "aborted".to_string(), + }) + .collect::>() + }; + let prompt: Cow = if missing_calls.is_empty() { + Cow::Borrowed(prompt) + } else { + // Add the synthetic aborted missing calls to the beginning of the input to ensure all call ids have responses. + let input = [missing_calls, prompt.input.clone()].concat(); + Cow::Owned(Prompt { + input, + ..prompt.clone() + }) + }; + + let rollout_item = RolloutItem::TurnContext(TurnContextItem { + cwd: turn_context.cwd.clone(), + approval_policy: turn_context.approval_policy, + sandbox_policy: turn_context.sandbox_policy.clone(), + model: turn_context.client.get_model(), + effort: turn_context.client.get_reasoning_effort(), + summary: turn_context.client.get_reasoning_summary(), + }); + sess.persist_rollout_items(&[rollout_item]).await; + let mut stream = turn_context.client.clone().stream(&prompt).await?; + + let mut output = Vec::new(); + + loop { + // Poll the next item from the model stream. We must inspect *both* Ok and Err + // cases so that transient stream failures (e.g., dropped SSE connection before + // `response.completed`) bubble up and trigger the caller's retry logic. + let event = stream.next().await; + let Some(event) = event else { + // Channel closed without yielding a final Completed event or explicit error. + // Treat as a disconnected stream so the caller can retry. + return Err(CodexErr::Stream( + "stream closed before response.completed".into(), + None, + )); + }; + + let event = match event { + Ok(ev) => ev, + Err(e) => { + // Propagate the underlying stream error to the caller (run_turn), which + // will apply the configured `stream_max_retries` policy. + return Err(e); + } + }; + + match event { + ResponseEvent::Created => {} + ResponseEvent::OutputItemDone(item) => { + let response = handle_response_item( + sess, + turn_context, + turn_diff_tracker, + sub_id, + item.clone(), + ) + .await?; + output.push(ProcessedResponseItem { item, response }); + } + ResponseEvent::WebSearchCallBegin { call_id } => { + let _ = sess + .tx_event + .send(Event { + id: sub_id.to_string(), + msg: EventMsg::WebSearchBegin(WebSearchBeginEvent { call_id }), + }) + .await; + } + ResponseEvent::RateLimits(snapshot) => { + // Update internal state with latest rate limits, but defer sending until + // token usage is available to avoid duplicate TokenCount events. + sess.update_rate_limits(sub_id, snapshot).await; + } + ResponseEvent::Completed { + response_id: _, + token_usage, + } => { + sess.update_token_usage_info(sub_id, turn_context, token_usage.as_ref()) + .await; + + let unified_diff = turn_diff_tracker.get_unified_diff(); + if let Ok(Some(unified_diff)) = unified_diff { + let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff }); + let event = Event { + id: sub_id.to_string(), + msg, + }; + sess.send_event(event).await; + } + + let result = TurnRunResult { + processed_items: output, + total_token_usage: token_usage.clone(), + }; + + return Ok(result); + } + ResponseEvent::OutputTextDelta(delta) => { + // In review child threads, suppress assistant text deltas; the + // UI will show a selection popup from the final ReviewOutput. + if !turn_context.is_review_mode { + let event = Event { + id: sub_id.to_string(), + msg: EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }), + }; + sess.send_event(event).await; + } else { + trace!("suppressing OutputTextDelta in review mode"); + } + } + ResponseEvent::ReasoningSummaryDelta(delta) => { + let event = Event { + id: sub_id.to_string(), + msg: EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta }), + }; + sess.send_event(event).await; + } + ResponseEvent::ReasoningSummaryPartAdded => { + let event = Event { + id: sub_id.to_string(), + msg: EventMsg::AgentReasoningSectionBreak(AgentReasoningSectionBreakEvent {}), + }; + sess.send_event(event).await; + } + ResponseEvent::ReasoningContentDelta(delta) => { + if sess.show_raw_agent_reasoning() { + let event = Event { + id: sub_id.to_string(), + msg: EventMsg::AgentReasoningRawContentDelta( + AgentReasoningRawContentDeltaEvent { delta }, + ), + }; + sess.send_event(event).await; + } + } + } + } +} + +async fn handle_response_item( + sess: &Session, + turn_context: &TurnContext, + turn_diff_tracker: &mut TurnDiffTracker, + sub_id: &str, + item: ResponseItem, +) -> CodexResult> { + debug!(?item, "Output item"); + let output = match item { + ResponseItem::FunctionCall { + name, + arguments, + call_id, + .. + } => { + info!("FunctionCall: {name}({arguments})"); + if let Some((server, tool_name)) = sess.services.mcp.parse_tool_name(&name) { + let resp = handle_mcp_tool_call( + sess, + sub_id, + call_id.clone(), + server, + tool_name, + arguments, + ) + .await; + Some(resp) + } else { + let result = handle_function_call( + sess, + turn_context, + turn_diff_tracker, + sub_id.to_string(), + name, + arguments, + call_id.clone(), + ) + .await; + + let output = match result { + Ok(content) => FunctionCallOutputPayload { + content, + success: Some(true), + }, + Err(FunctionCallError::RespondToModel(msg)) => FunctionCallOutputPayload { + content: msg, + success: Some(false), + }, + }; + Some(ResponseInputItem::FunctionCallOutput { call_id, output }) + } + } + ResponseItem::LocalShellCall { + id, + call_id, + status: _, + action, + } => { + let LocalShellAction::Exec(action) = action; + tracing::info!("LocalShellCall: {action:?}"); + let params = ShellToolCallParams { + command: action.command, + workdir: action.working_directory, + timeout_ms: action.timeout_ms, + with_escalated_permissions: None, + justification: None, + }; + let effective_call_id = match (call_id, id) { + (Some(call_id), _) => call_id, + (None, Some(id)) => id, + (None, None) => { + error!("LocalShellCall without call_id or id"); + return Ok(Some(ResponseInputItem::FunctionCallOutput { + call_id: "".to_string(), + output: FunctionCallOutputPayload { + content: "LocalShellCall without call_id or id".to_string(), + success: None, + }, + })); + } + }; + + let exec_params = to_exec_params(params, turn_context); + { + let result = handle_container_exec_with_params( + exec_params, + sess, + turn_context, + turn_diff_tracker, + sub_id.to_string(), + effective_call_id.clone(), + ) + .await; + + let output = match result { + Ok(content) => FunctionCallOutputPayload { + content, + success: Some(true), + }, + Err(FunctionCallError::RespondToModel(msg)) => FunctionCallOutputPayload { + content: msg, + success: Some(false), + }, + }; + Some(ResponseInputItem::FunctionCallOutput { + call_id: effective_call_id, + output, + }) + } + } + ResponseItem::CustomToolCall { + id: _, + call_id, + name, + input, + status: _, + } => { + let result = handle_custom_tool_call( + sess, + turn_context, + turn_diff_tracker, + sub_id.to_string(), + name, + input, + call_id.clone(), + ) + .await; + + let output = match result { + Ok(content) => content, + Err(FunctionCallError::RespondToModel(msg)) => msg, + }; + Some(ResponseInputItem::CustomToolCallOutput { call_id, output }) + } + ResponseItem::FunctionCallOutput { .. } => { + debug!("unexpected FunctionCallOutput from stream"); + None + } + ResponseItem::CustomToolCallOutput { .. } => { + debug!("unexpected CustomToolCallOutput from stream"); + None + } + ResponseItem::Message { .. } + | ResponseItem::Reasoning { .. } + | ResponseItem::WebSearchCall { .. } => { + // In review child threads, suppress assistant message events but + // keep reasoning/web search. + let msgs = match &item { + ResponseItem::Message { .. } if turn_context.is_review_mode => { + trace!("suppressing assistant Message in review mode"); + Vec::new() + } + _ => map_response_item_to_event_messages(&item, sess.show_raw_agent_reasoning()), + }; + for msg in msgs { + let event = Event { + id: sub_id.to_string(), + msg, + }; + sess.send_event(event).await; + } + None + } + ResponseItem::Other => None, + }; + Ok(output) +} + +async fn handle_unified_exec_tool_call( + sess: &Session, + session_id: Option, + arguments: Vec, + timeout_ms: Option, +) -> Result { + let parsed_session_id = if let Some(session_id) = session_id { + match session_id.parse::() { + Ok(parsed) => Some(parsed), + Err(output) => { + return Err(FunctionCallError::RespondToModel(format!( + "invalid session_id: {session_id} due to error {output:?}" + ))); + } + } + } else { + None + }; + + let request = crate::unified_exec::UnifiedExecRequest { + session_id: parsed_session_id, + input_chunks: &arguments, + timeout_ms, + }; + + let value = sess + .services + .sandbox + .handle_unified_exec_request(request) + .await + .map_err(|err| { + FunctionCallError::RespondToModel(format!("unified exec failed: {err:?}")) + })?; + + #[derive(Serialize)] + struct SerializedUnifiedExecResult { + session_id: Option, + output: String, + } + + serde_json::to_string(&SerializedUnifiedExecResult { + session_id: value.session_id.map(|id| id.to_string()), + output: value.output, + }) + .map_err(|err| { + FunctionCallError::RespondToModel(format!( + "failed to serialize unified exec output: {err:?}" + )) + }) +} + +async fn handle_function_call( + sess: &Session, + turn_context: &TurnContext, + turn_diff_tracker: &mut TurnDiffTracker, + sub_id: String, + name: String, + arguments: String, + call_id: String, +) -> Result { + match name.as_str() { + "container.exec" | "shell" => { + let params = parse_container_exec_arguments(arguments, turn_context, &call_id)?; + handle_container_exec_with_params( + params, + sess, + turn_context, + turn_diff_tracker, + sub_id, + call_id, + ) + .await + } + "unified_exec" => { + #[derive(Deserialize)] + struct UnifiedExecArgs { + input: Vec, + #[serde(default)] + session_id: Option, + #[serde(default)] + timeout_ms: Option, + } + + let args: UnifiedExecArgs = serde_json::from_str(&arguments).map_err(|err| { + FunctionCallError::RespondToModel(format!( + "failed to parse function arguments: {err:?}" + )) + })?; + + handle_unified_exec_tool_call(sess, args.session_id, args.input, args.timeout_ms).await + } + "view_image" => { + #[derive(serde::Deserialize)] + struct SeeImageArgs { + path: String, + } + let args: SeeImageArgs = serde_json::from_str(&arguments).map_err(|e| { + FunctionCallError::RespondToModel(format!( + "failed to parse function arguments: {e:?}" + )) + })?; + let abs = turn_context.resolve_path(Some(args.path)); + sess.inject_input(vec![InputItem::LocalImage { path: abs }]) + .await + .map_err(|_| { + FunctionCallError::RespondToModel( + "unable to attach image (no active task)".to_string(), + ) + })?; + + Ok("attached local image path".to_string()) + } + "apply_patch" => { + let args: ApplyPatchToolArgs = serde_json::from_str(&arguments).map_err(|e| { + FunctionCallError::RespondToModel(format!( + "failed to parse function arguments: {e:?}" + )) + })?; + let exec_params = ExecParams { + command: vec!["apply_patch".to_string(), args.input.clone()], + cwd: turn_context.cwd.clone(), + timeout_ms: None, + env: HashMap::new(), + with_escalated_permissions: None, + justification: None, + }; + handle_container_exec_with_params( + exec_params, + sess, + turn_context, + turn_diff_tracker, + sub_id, + call_id, + ) + .await + } + "update_plan" => handle_update_plan(sess, arguments, sub_id, call_id).await, + EXEC_COMMAND_TOOL_NAME => { + // TODO(mbolin): Sandbox check. + let exec_params: ExecCommandParams = serde_json::from_str(&arguments).map_err(|e| { + FunctionCallError::RespondToModel(format!( + "failed to parse function arguments: {e:?}" + )) + })?; + let result = sess + .services + .sandbox + .handle_exec_command_request(exec_params) + .await; + match result { + Ok(output) => Ok(output.to_text_output()), + Err(err) => Err(FunctionCallError::RespondToModel(err)), + } + } + WRITE_STDIN_TOOL_NAME => { + let write_stdin_params = + serde_json::from_str::(&arguments).map_err(|e| { + FunctionCallError::RespondToModel(format!( + "failed to parse function arguments: {e:?}" + )) + })?; + + let result = sess + .services + .sandbox + .handle_write_stdin_request(write_stdin_params) + .await + .map_err(FunctionCallError::RespondToModel)?; + + Ok(result.to_text_output()) + } + _ => Err(FunctionCallError::RespondToModel(format!( + "unsupported call: {name}" + ))), + } +} + +async fn handle_custom_tool_call( + sess: &Session, + turn_context: &TurnContext, + turn_diff_tracker: &mut TurnDiffTracker, + sub_id: String, + name: String, + input: String, + call_id: String, +) -> Result { + info!("CustomToolCall: {name} {input}"); + match name.as_str() { + "apply_patch" => { + let exec_params = ExecParams { + command: vec!["apply_patch".to_string(), input.clone()], + cwd: turn_context.cwd.clone(), + timeout_ms: None, + env: HashMap::new(), + with_escalated_permissions: None, + justification: None, + }; + + handle_container_exec_with_params( + exec_params, + sess, + turn_context, + turn_diff_tracker, + sub_id, + call_id, + ) + .await + } + _ => { + debug!("unexpected CustomToolCall from stream"); + Err(FunctionCallError::RespondToModel(format!( + "unsupported custom tool call: {name}" + ))) + } + } +} + +fn to_exec_params(params: ShellToolCallParams, turn_context: &TurnContext) -> ExecParams { + ExecParams { + command: params.command, + cwd: turn_context.resolve_path(params.workdir.clone()), + timeout_ms: params.timeout_ms, + env: create_env(&turn_context.shell_environment_policy), + with_escalated_permissions: params.with_escalated_permissions, + justification: params.justification, + } +} + +fn parse_container_exec_arguments( + arguments: String, + turn_context: &TurnContext, + _call_id: &str, +) -> Result { + serde_json::from_str::(&arguments) + .map(|p| to_exec_params(p, turn_context)) + .map_err(|e| { + FunctionCallError::RespondToModel(format!("failed to parse function arguments: {e:?}")) + }) +} + +pub struct ExecInvokeArgs<'a> { + pub params: ExecParams, + pub plan: ExecPlan, + pub sandbox_policy: &'a SandboxPolicy, + pub sandbox_cwd: &'a Path, + pub codex_linux_sandbox_exe: &'a Option, + pub stdout_stream: Option, +} + +fn maybe_translate_shell_command( + params: ExecParams, + sess: &Session, + turn_context: &TurnContext, +) -> ExecParams { + let should_translate = matches!(sess.user_shell(), crate::shell::Shell::PowerShell(_)) + || turn_context.shell_environment_policy.use_profile; + + if should_translate + && let Some(command) = sess + .user_shell() + .format_default_shell_invocation(params.command.clone()) + { + return ExecParams { command, ..params }; + } + params +} + +async fn handle_container_exec_with_params( + params: ExecParams, + sess: &Session, + turn_context: &TurnContext, + turn_diff_tracker: &mut TurnDiffTracker, + sub_id: String, + call_id: String, +) -> Result { + if params.with_escalated_permissions.unwrap_or(false) + && !matches!(turn_context.approval_policy, AskForApproval::OnRequest) + { + return Err(FunctionCallError::RespondToModel(format!( + "approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}", + policy = turn_context.approval_policy + ))); + } + + let mut params = params; + + // check if this was a patch, and apply it if so + let apply_patch_exec = match maybe_parse_apply_patch_verified(¶ms.command, ¶ms.cwd) { + MaybeApplyPatchVerified::Body(changes) => { + let apply_patch_context = ApplyPatchContext { + approval_policy: turn_context.approval_policy, + sandbox_policy: &turn_context.sandbox_policy, + cwd: &turn_context.cwd, + }; + + match apply_patch(sess, apply_patch_context, &sub_id, &call_id, changes).await { + InternalApplyPatchInvocation::Output(item) => return item, + InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => { + params = build_exec_params_for_apply_patch(&apply_patch_exec, ¶ms)?; + Some(apply_patch_exec) + } + } + } + MaybeApplyPatchVerified::CorrectnessError(parse_error) => { + // It looks like an invocation of `apply_patch`, but we + // could not resolve it into a patch that would apply + // cleanly. Return to model for resample. + return Err(FunctionCallError::RespondToModel(format!( + "error: {parse_error:#?}" + ))); + } + MaybeApplyPatchVerified::ShellParseError(error) => { + trace!("Failed to parse shell command, {error:?}"); + None + } + MaybeApplyPatchVerified::NotApplyPatch => None, + }; + + let approved_session_commands = { + let state = sess.state.lock().await; + state.approved_commands_ref().clone() + }; + + let prepared = prepare_exec_invocation( + sess, + turn_context.approval_policy, + &turn_context.sandbox_policy, + &turn_context.cwd, + &sub_id, + &call_id, + params, + apply_patch_exec, + approved_session_commands, + ) + .await?; + + let PreparedExec { + params, + plan, + command_for_display, + apply_patch_exec, + } = prepared; + + let exec_command_context = ExecCommandContext { + sub_id: sub_id.clone(), + call_id: call_id.clone(), + command_for_display: command_for_display.clone(), + cwd: params.cwd.clone(), + apply_patch: apply_patch_exec.as_ref().map( + |ApplyPatchExec { + action, + user_explicitly_approved_this_action, + }| ApplyPatchCommandContext { + user_explicitly_approved_this_action: *user_explicitly_approved_this_action, + changes: convert_apply_patch_to_protocol(action), + }, + ), + }; + + let params = maybe_translate_shell_command(params, sess, turn_context); + let plan_for_invocation = plan.clone(); + let output_result = sess + .run_exec_with_events( + turn_diff_tracker, + exec_command_context.clone(), + ExecInvokeArgs { + params: params.clone(), + plan: plan_for_invocation, + sandbox_policy: &turn_context.sandbox_policy, + sandbox_cwd: &turn_context.cwd, + codex_linux_sandbox_exe: sess.services.sandbox.codex_linux_sandbox_exe(), + stdout_stream: if exec_command_context.apply_patch.is_some() { + None + } else { + Some(StdoutStream { + sub_id: sub_id.clone(), + call_id: call_id.clone(), + tx_event: sess.tx_event.clone(), + }) + }, + }, + ) + .await; + + match output_result { + Ok(output) => { + let ExecToolCallOutput { exit_code, .. } = &output; + let content = format_exec_output(&output); + if *exit_code == 0 { + Ok(content) + } else { + Err(FunctionCallError::RespondToModel(content)) + } + } + Err(CodexErr::Sandbox(error)) => { + handle_sandbox_error( + turn_diff_tracker, + params, + exec_command_context, + error, + &plan, + sess, + turn_context, + ) + .await + } + Err(e) => Err(FunctionCallError::RespondToModel(format!( + "execution error: {e:?}" + ))), + } +} + +async fn handle_sandbox_error( + turn_diff_tracker: &mut TurnDiffTracker, + params: ExecParams, + exec_command_context: ExecCommandContext, + error: SandboxErr, + plan: &ExecPlan, + sess: &Session, + turn_context: &TurnContext, +) -> Result { + let call_id = exec_command_context.call_id.clone(); + let sub_id = exec_command_context.sub_id.clone(); + let cwd = exec_command_context.cwd.clone(); + + if let SandboxErr::Timeout { output } = &error { + let content = format_exec_output(output); + return Err(FunctionCallError::RespondToModel(content)); + } + + let ExecPlan::Approved { + sandbox: sandbox_type, + on_failure_escalate, + .. + } = plan + else { + return Err(FunctionCallError::RespondToModel( + "execution failed without an approved plan".to_string(), + )); + }; + + if !on_failure_escalate { + return Err(FunctionCallError::RespondToModel(format!( + "failed in sandbox {sandbox_type:?} with execution error: {error:?}" + ))); + } + + // Note that when `error` is `SandboxErr::Denied`, it could be a false + // positive. That is, it may have exited with a non-zero exit code, not + // because the sandbox denied it, but because that is its expected behavior, + // i.e., a grep command that did not match anything. Ideally we would + // include additional metadata on the command to indicate whether non-zero + // exit codes merit a retry. + + // For now, we categorically ask the user to retry without sandbox and + // emit the raw error as a background event. + sess.notify_background_event(&sub_id, format!("Execution failed: {error}")) + .await; + + let command_for_retry = params.command.clone(); + let decision = sess + .request_command_approval( + sub_id.clone(), + call_id.clone(), + command_for_retry.clone(), + cwd.clone(), + Some("command failed; retry without sandbox?".to_string()), + ) + .await; + + match decision { + ReviewDecision::Approved | ReviewDecision::ApprovedForSession => { + // Persist this command as pre‑approved for the + // remainder of the session so future + // executions skip the sandbox directly. + // TODO(ragona): Isn't this a bug? It always saves the command in an | fork? + sess.add_approved_command(command_for_retry.clone()).await; + // Inform UI we are retrying without sandbox. + sess.notify_background_event(&sub_id, "retrying command without sandbox") + .await; + + // This is an escalated retry; the policy will not be + // examined and the sandbox has been set to `None`. + let retry_output_result = sess + .run_exec_with_events( + turn_diff_tracker, + exec_command_context.clone(), + ExecInvokeArgs { + params, + plan: ExecPlan::approved(SandboxType::None, false, true), + sandbox_policy: &turn_context.sandbox_policy, + sandbox_cwd: &turn_context.cwd, + codex_linux_sandbox_exe: sess.services.sandbox.codex_linux_sandbox_exe(), + stdout_stream: if exec_command_context.apply_patch.is_some() { + None + } else { + Some(StdoutStream { + sub_id: sub_id.clone(), + call_id: call_id.clone(), + tx_event: sess.tx_event.clone(), + }) + }, + }, + ) + .await; + + match retry_output_result { + Ok(retry_output) => { + let ExecToolCallOutput { exit_code, .. } = &retry_output; + let content = format_exec_output(&retry_output); + if *exit_code == 0 { + Ok(content) + } else { + Err(FunctionCallError::RespondToModel(content)) + } + } + Err(e) => Err(FunctionCallError::RespondToModel(format!( + "retry failed: {e}" + ))), + } + } + ReviewDecision::Denied | ReviewDecision::Abort => { + // Fall through to original failure handling. + Err(FunctionCallError::RespondToModel( + "exec command rejected by user".to_string(), + )) + } + } +} + +fn format_exec_output_str(exec_output: &ExecToolCallOutput) -> String { + let ExecToolCallOutput { + aggregated_output, .. + } = exec_output; + + // Head+tail truncation for the model: show the beginning and end with an elision. + // Clients still receive full streams; only this formatted summary is capped. + + let mut s = &aggregated_output.text; + let prefixed_str: String; + + if exec_output.timed_out { + prefixed_str = format!( + "command timed out after {} milliseconds\n", + exec_output.duration.as_millis() + ) + s; + s = &prefixed_str; + } + + let total_lines = s.lines().count(); + if s.len() <= MODEL_FORMAT_MAX_BYTES && total_lines <= MODEL_FORMAT_MAX_LINES { + return s.to_string(); + } + + let lines: Vec<&str> = s.lines().collect(); + let head_take = MODEL_FORMAT_HEAD_LINES.min(lines.len()); + let tail_take = MODEL_FORMAT_TAIL_LINES.min(lines.len().saturating_sub(head_take)); + let omitted = lines.len().saturating_sub(head_take + tail_take); + + // Join head and tail blocks (lines() strips newlines; reinsert them) + let head_block = lines + .iter() + .take(head_take) + .cloned() + .collect::>() + .join("\n"); + let tail_block = if tail_take > 0 { + lines[lines.len() - tail_take..].join("\n") + } else { + String::new() + }; + let marker = format!("\n[... omitted {omitted} of {total_lines} lines ...]\n\n"); + + // Byte budgets for head/tail around the marker + let mut head_budget = MODEL_FORMAT_HEAD_BYTES.min(MODEL_FORMAT_MAX_BYTES); + let tail_budget = MODEL_FORMAT_MAX_BYTES.saturating_sub(head_budget + marker.len()); + if tail_budget == 0 && marker.len() >= MODEL_FORMAT_MAX_BYTES { + // Degenerate case: marker alone exceeds budget; return a clipped marker + return take_bytes_at_char_boundary(&marker, MODEL_FORMAT_MAX_BYTES).to_string(); + } + if tail_budget == 0 { + // Make room for the marker by shrinking head + head_budget = MODEL_FORMAT_MAX_BYTES.saturating_sub(marker.len()); + } + + // Enforce line-count cap by trimming head/tail lines + let head_lines_text = head_block; + let tail_lines_text = tail_block; + // Build final string respecting byte budgets + let head_part = take_bytes_at_char_boundary(&head_lines_text, head_budget); + let mut result = String::with_capacity(MODEL_FORMAT_MAX_BYTES.min(s.len())); + + result.push_str(head_part); + result.push_str(&marker); + + let remaining = MODEL_FORMAT_MAX_BYTES.saturating_sub(result.len()); + let tail_budget_final = remaining; + let tail_part = take_last_bytes_at_char_boundary(&tail_lines_text, tail_budget_final); + result.push_str(tail_part); + + result +} + +// Truncate a &str to a byte budget at a char boundary (prefix) +#[inline] +fn take_bytes_at_char_boundary(s: &str, maxb: usize) -> &str { + if s.len() <= maxb { + return s; + } + let mut last_ok = 0; + for (i, ch) in s.char_indices() { + let nb = i + ch.len_utf8(); + if nb > maxb { + break; + } + last_ok = nb; + } + &s[..last_ok] +} + +// Take a suffix of a &str within a byte budget at a char boundary +#[inline] +fn take_last_bytes_at_char_boundary(s: &str, maxb: usize) -> &str { + if s.len() <= maxb { + return s; + } + let mut start = s.len(); + let mut used = 0usize; + for (i, ch) in s.char_indices().rev() { + let nb = ch.len_utf8(); + if used + nb > maxb { + break; + } + start = i; + used += nb; + if start == 0 { + break; + } + } + &s[start..] +} + +/// Exec output is a pre-serialized JSON payload +fn format_exec_output(exec_output: &ExecToolCallOutput) -> String { + let ExecToolCallOutput { + exit_code, + duration, + .. + } = exec_output; + + #[derive(Serialize)] + struct ExecMetadata { + exit_code: i32, + duration_seconds: f32, + } + + #[derive(Serialize)] + struct ExecOutput<'a> { + output: &'a str, + metadata: ExecMetadata, + } + + // round to 1 decimal place + let duration_seconds = ((duration.as_secs_f32()) * 10.0).round() / 10.0; + + let formatted_output = format_exec_output_str(exec_output); + + let payload = ExecOutput { + output: &formatted_output, + metadata: ExecMetadata { + exit_code: *exit_code, + duration_seconds, + }, + }; + + #[expect(clippy::expect_used)] + serde_json::to_string(&payload).expect("serialize ExecOutput") +} + +pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option { + responses.iter().rev().find_map(|item| { + if let ResponseItem::Message { role, content, .. } = item { + if role == "assistant" { + content.iter().rev().find_map(|ci| { + if let ContentItem::OutputText { text } = ci { + Some(text.clone()) + } else { + None + } + }) + } else { + None + } + } else { + None + } + }) +} +fn convert_call_tool_result_to_function_call_output_payload( + call_tool_result: &CallToolResult, +) -> FunctionCallOutputPayload { + let CallToolResult { + content, + is_error, + structured_content, + } = call_tool_result; + + // In terms of what to send back to the model, we prefer structured_content, + // if available, and fallback to content, otherwise. + let mut is_success = is_error != &Some(true); + let content = if let Some(structured_content) = structured_content + && structured_content != &serde_json::Value::Null + && let Ok(serialized_structured_content) = serde_json::to_string(&structured_content) + { + serialized_structured_content + } else { + match serde_json::to_string(&content) { + Ok(serialized_content) => serialized_content, + Err(err) => { + // If we could not serialize either content or structured_content to + // JSON, flag this as an error. + is_success = false; + err.to_string() + } + } + }; + + FunctionCallOutputPayload { + content, + success: Some(is_success), + } +} + +/// Emits an ExitedReviewMode Event with optional ReviewOutput, +/// and records a developer message with the review output. +pub(crate) async fn exit_review_mode( + session: Arc, + task_sub_id: String, + review_output: Option, +) { + let event = Event { + id: task_sub_id, + msg: EventMsg::ExitedReviewMode(ExitedReviewModeEvent { + review_output: review_output.clone(), + }), + }; + session.send_event(event).await; + + let mut user_message = String::new(); + if let Some(out) = review_output { + let mut findings_str = String::new(); + let text = out.overall_explanation.trim(); + if !text.is_empty() { + findings_str.push_str(text); + } + if !out.findings.is_empty() { + let block = format_review_findings_block(&out.findings, None); + findings_str.push_str(&format!("\n{block}")); + } + user_message.push_str(&format!( + r#" + User initiated a review task. Here's the full review output from reviewer model. User may select one or more comments to resolve. + review + + {findings_str} + + +"#)); + } else { + user_message.push_str(r#" + User initiated a review task, but was interrupted. If user asks about this, tell them to re-initiate a review with `/review` and wait for it to complete. + review + + None. + + +"#); + } + + session + .record_conversation_items(&[ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { text: user_message }], + }]) + .await; +} + +#[cfg(test)] +pub(crate) use tests::make_session_and_context; + +#[async_trait::async_trait] +impl ApprovalCoordinator for Session { + async fn request_patch_approval( + &self, + sub_id: String, + call_id: String, + action: &ApplyPatchAction, + reason: Option, + grant_root: Option, + ) -> ReviewDecision { + let rx = Session::request_patch_approval(self, sub_id, call_id, action, reason, grant_root) + .await; + rx.await.unwrap_or_default() + } + + async fn request_command_approval( + &self, + sub_id: String, + call_id: String, + command: Vec, + cwd: PathBuf, + reason: Option, + ) -> ReviewDecision { + Session::request_command_approval(self, sub_id, call_id, command, cwd, reason).await + } + + async fn add_approved_command(&self, command: Vec) { + Session::add_approved_command(self, command).await; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::ConfigOverrides; + use crate::config::ConfigToml; + use crate::protocol::CompactedItem; + use crate::protocol::InitialHistory; + use crate::protocol::ResumedHistory; + use codex_protocol::models::ContentItem; + use mcp_types::ContentBlock; + use mcp_types::TextContent; + use pretty_assertions::assert_eq; + use serde_json::json; + use std::path::PathBuf; + use std::sync::Arc; + use std::time::Duration as StdDuration; + + fn sample_rollout( + session: &Session, + turn_context: &TurnContext, + ) -> (Vec, Vec) { + let mut rollout_items = Vec::new(); + let mut live_history = ConversationHistory::new(); + + let initial_context = session.build_initial_context(turn_context); + for item in &initial_context { + rollout_items.push(RolloutItem::ResponseItem(item.clone())); + } + live_history.record_items(initial_context.iter()); + + let user1 = ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "first user".to_string(), + }], + }; + live_history.record_items(std::iter::once(&user1)); + rollout_items.push(RolloutItem::ResponseItem(user1.clone())); + + let assistant1 = ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: "assistant reply one".to_string(), + }], + }; + live_history.record_items(std::iter::once(&assistant1)); + rollout_items.push(RolloutItem::ResponseItem(assistant1.clone())); + + let summary1 = "summary one"; + let snapshot1 = live_history.contents(); + let user_messages1 = collect_user_messages(&snapshot1); + let rebuilt1 = build_compacted_history( + session.build_initial_context(turn_context), + &user_messages1, + summary1, + ); + live_history.replace(rebuilt1); + rollout_items.push(RolloutItem::Compacted(CompactedItem { + message: summary1.to_string(), + })); + + let user2 = ResponseItem::Message { + id: None, diff --git a/codex-rs/agent/src/sandbox/mod.rs b/codex-rs/agent/src/sandbox/mod.rs index 037eea61c6..ef9b8af9ab 100644 --- a/codex-rs/agent/src/sandbox/mod.rs +++ b/codex-rs/agent/src/sandbox/mod.rs @@ -1,3 +1,10 @@ +pub mod planner; pub mod types; +pub use planner::CommandPlanRequest; +pub use planner::ExecPlan; +pub use planner::PatchPlanRequest; +pub use planner::plan_apply_patch; +pub use planner::plan_exec; +pub use planner::should_escalate_on_failure; pub use types::SandboxType; diff --git a/codex-rs/agent/src/sandbox/planner.rs b/codex-rs/agent/src/sandbox/planner.rs new file mode 100644 index 0000000000..3331b72777 --- /dev/null +++ b/codex-rs/agent/src/sandbox/planner.rs @@ -0,0 +1,106 @@ +use std::collections::HashSet; +use std::path::Path; + +use codex_apply_patch::ApplyPatchAction; +use codex_protocol::protocol::AskForApproval; +use codex_protocol::protocol::SandboxPolicy; + +use crate::safety::SafetyCheck; +use crate::safety::assess_command_safety; +use crate::safety::assess_patch_safety; + +use super::SandboxType; + +#[derive(Clone, Debug)] +pub enum ExecPlan { + Reject { + reason: String, + }, + AskUser { + reason: Option, + }, + Approved { + sandbox: SandboxType, + on_failure_escalate: bool, + approved_by_user: bool, + }, +} + +impl ExecPlan { + pub fn approved( + sandbox: SandboxType, + on_failure_escalate: bool, + approved_by_user: bool, + ) -> Self { + ExecPlan::Approved { + sandbox, + on_failure_escalate, + approved_by_user, + } + } +} + +pub struct CommandPlanRequest<'a> { + pub command: &'a [String], + pub approval: AskForApproval, + pub policy: &'a SandboxPolicy, + pub approved_session_commands: &'a HashSet>, + pub with_escalated_permissions: bool, + pub justification: Option<&'a String>, +} + +pub struct PatchPlanRequest<'a> { + pub action: &'a ApplyPatchAction, + pub approval: AskForApproval, + pub policy: &'a SandboxPolicy, + pub cwd: &'a Path, + pub user_explicitly_approved: bool, +} + +pub fn plan_exec(req: &CommandPlanRequest<'_>) -> ExecPlan { + let safety = assess_command_safety( + req.command, + req.approval, + req.policy, + req.approved_session_commands, + req.with_escalated_permissions, + ); + + match safety { + SafetyCheck::AutoApprove { sandbox_type } => ExecPlan::approved( + sandbox_type, + should_escalate_on_failure(req.approval, sandbox_type), + false, + ), + SafetyCheck::AskUser => ExecPlan::AskUser { + reason: req.justification.map(ToOwned::to_owned), + }, + SafetyCheck::Reject { reason } => ExecPlan::Reject { reason }, + } +} + +pub fn plan_apply_patch(req: &PatchPlanRequest<'_>) -> ExecPlan { + if req.user_explicitly_approved { + return ExecPlan::approved(SandboxType::None, false, true); + } + + match assess_patch_safety(req.action, req.approval, req.policy, req.cwd) { + SafetyCheck::AutoApprove { sandbox_type } => ExecPlan::approved( + sandbox_type, + should_escalate_on_failure(req.approval, sandbox_type), + false, + ), + SafetyCheck::AskUser => ExecPlan::AskUser { reason: None }, + SafetyCheck::Reject { reason } => ExecPlan::Reject { reason }, + } +} + +pub fn should_escalate_on_failure(approval: AskForApproval, sandbox: SandboxType) -> bool { + matches!( + (approval, sandbox), + ( + AskForApproval::UnlessTrusted | AskForApproval::OnFailure, + SandboxType::MacosSeatbelt | SandboxType::LinuxSeccomp + ) + ) +} diff --git a/codex-rs/agent/src/tool_schema.rs b/codex-rs/agent/src/tool_schema.rs new file mode 100644 index 0000000000..4c8f977830 --- /dev/null +++ b/codex-rs/agent/src/tool_schema.rs @@ -0,0 +1,79 @@ +use serde::Deserialize; +use serde::Serialize; + +#[derive(Debug, Clone, Serialize, PartialEq)] +pub struct ResponsesApiTool { + pub name: String, + pub description: String, + /// TODO: Validation. When strict is set to true, the JSON schema, + /// `required` and `additional_properties` must be present. All fields in + /// `properties` must be present in `required`. + pub strict: bool, + pub parameters: JsonSchema, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct FreeformTool { + pub name: String, + pub description: String, + pub format: FreeformToolFormat, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct FreeformToolFormat { + pub r#type: String, + pub syntax: String, + pub definition: String, +} + +/// Generic JSON-Schema subset needed for our tool definitions +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum JsonSchema { + Boolean { + #[serde(skip_serializing_if = "Option::is_none")] + description: Option, + }, + String { + #[serde(skip_serializing_if = "Option::is_none")] + description: Option, + }, + /// MCP schema allows "number" | "integer" for Number + #[serde(alias = "integer")] + Number { + #[serde(skip_serializing_if = "Option::is_none")] + description: Option, + }, + Array { + items: Box, + + #[serde(skip_serializing_if = "Option::is_none")] + description: Option, + }, + Object { + properties: std::collections::BTreeMap, + #[serde(skip_serializing_if = "Option::is_none")] + required: Option>, + #[serde( + rename = "additionalProperties", + skip_serializing_if = "Option::is_none" + )] + additional_properties: Option, + }, +} + +/// When serialized as JSON, this produces a valid "Tool" in the OpenAI Responses API. +#[derive(Debug, Clone, Serialize, PartialEq)] +#[serde(tag = "type")] +pub enum OpenAiTool { + #[serde(rename = "function")] + Function(ResponsesApiTool), + #[serde(rename = "local_shell")] + LocalShell {}, + // TODO: Understand why we get an error on web_search although the API docs say it's supported. + // https://platform.openai.com/docs/guides/tools-web-search?api-mode=responses#:~:text=%7B%20type%3A%20%22web_search%22%20%7D%2C + #[serde(rename = "web_search")] + WebSearch {}, + #[serde(rename = "custom")] + Freeform(FreeformTool), +} diff --git a/codex-rs/agent/src/tools_config.rs b/codex-rs/agent/src/tools_config.rs new file mode 100644 index 0000000000..20a2b0a501 --- /dev/null +++ b/codex-rs/agent/src/tools_config.rs @@ -0,0 +1,72 @@ +use crate::tooling::ApplyPatchToolType; +use crate::model_family::ModelFamily; + +#[derive(Debug, Clone)] +pub enum ConfigShellToolType { + Default, + Local, + Streamable, +} + +#[derive(Debug, Clone)] +pub struct ToolsConfig { + pub shell_type: ConfigShellToolType, + pub plan_tool: bool, + pub apply_patch_tool_type: Option, + pub web_search_request: bool, + pub include_view_image_tool: bool, + pub experimental_unified_exec_tool: bool, +} + +pub struct ToolsConfigParams<'a> { + pub model_family: &'a ModelFamily, + pub include_plan_tool: bool, + pub include_apply_patch_tool: bool, + pub include_web_search_request: bool, + pub use_streamable_shell_tool: bool, + pub include_view_image_tool: bool, + pub experimental_unified_exec_tool: bool, +} + +impl ToolsConfig { + pub fn new(params: &ToolsConfigParams) -> Self { + let ToolsConfigParams { + model_family, + include_plan_tool, + include_apply_patch_tool, + include_web_search_request, + use_streamable_shell_tool, + include_view_image_tool, + experimental_unified_exec_tool, + } = params; + let shell_type = if *use_streamable_shell_tool { + ConfigShellToolType::Streamable + } else if model_family.uses_local_shell_tool { + ConfigShellToolType::Local + } else { + ConfigShellToolType::Default + }; + + let apply_patch_tool_type = match model_family.apply_patch_tool_type { + Some(ApplyPatchToolType::Freeform) => Some(ApplyPatchToolType::Freeform), + Some(ApplyPatchToolType::Function) => Some(ApplyPatchToolType::Function), + None => { + if *include_apply_patch_tool { + Some(ApplyPatchToolType::Freeform) + } else { + None + } + } + }; + + Self { + shell_type, + plan_tool: *include_plan_tool, + apply_patch_tool_type, + web_search_request: *include_web_search_request, + include_view_image_tool: *include_view_image_tool, + experimental_unified_exec_tool: *experimental_unified_exec_tool, + } + } +} + diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs index a64160f863..a694d4d125 100644 --- a/codex-rs/core/src/chat_completions.rs +++ b/codex-rs/core/src/chat_completions.rs @@ -103,7 +103,7 @@ pub(crate) async fn stream_chat_completions( for c in items { match c { ReasoningItemContent::ReasoningText { text: t } - | ReasoningItemContent::Text { text: t } => text.push_str(t), + | ReasoningItemContent::Text { text: t } => text.push_str(t.as_str()), } } if text.trim().is_empty() { @@ -158,7 +158,7 @@ pub(crate) async fn stream_chat_completions( match c { ContentItem::InputText { text: t } | ContentItem::OutputText { text: t } => { - text.push_str(t); + text.push_str(t.as_str()); } _ => {} } diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs index b695581deb..a74cbe53e3 100644 --- a/codex-rs/core/src/client_common.rs +++ b/codex-rs/core/src/client_common.rs @@ -1,371 +1,5 @@ -use crate::error::Result; -use crate::model_family::ModelFamily; -use crate::openai_tools::OpenAiTool; -use crate::protocol::RateLimitSnapshot; -use crate::protocol::TokenUsage; -use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS; -use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig; -use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig; -use codex_protocol::config_types::Verbosity as VerbosityConfig; -use codex_protocol::models::ResponseItem; -use futures::Stream; -use serde::Serialize; -use serde_json::Value; -use std::borrow::Cow; -use std::ops::Deref; -use std::pin::Pin; -use std::task::Context; -use std::task::Poll; -use tokio::sync::mpsc; +pub use codex_agent::client_common::*; -/// Review thread system prompt. Edit `core/src/review_prompt.md` to customize. -pub const REVIEW_PROMPT: &str = include_str!("../review_prompt.md"); +use crate::error::CodexErr; -/// API request payload for a single model turn -#[derive(Default, Debug, Clone)] -pub struct Prompt { - /// Conversation context input items. - pub input: Vec, - - /// Tools available to the model, including additional tools sourced from - /// external MCP servers. - pub(crate) tools: Vec, - - /// Optional override for the built-in BASE_INSTRUCTIONS. - pub base_instructions_override: Option, - - /// Optional the output schema for the model's response. - pub output_schema: Option, -} - -impl Prompt { - pub(crate) fn get_full_instructions<'a>(&'a self, model: &'a ModelFamily) -> Cow<'a, str> { - let base = self - .base_instructions_override - .as_deref() - .unwrap_or(model.base_instructions.deref()); - // When there are no custom instructions, add apply_patch_tool_instructions if: - // - the model needs special instructions (4.1) - // AND - // - there is no apply_patch tool present - let is_apply_patch_tool_present = self.tools.iter().any(|tool| match tool { - OpenAiTool::Function(f) => f.name == "apply_patch", - OpenAiTool::Freeform(f) => f.name == "apply_patch", - _ => false, - }); - if self.base_instructions_override.is_none() - && model.needs_special_apply_patch_instructions - && !is_apply_patch_tool_present - { - Cow::Owned(format!("{base}\n{APPLY_PATCH_TOOL_INSTRUCTIONS}")) - } else { - Cow::Borrowed(base) - } - } - - pub(crate) fn get_formatted_input(&self) -> Vec { - self.input.clone() - } -} - -#[derive(Debug)] -pub enum ResponseEvent { - Created, - OutputItemDone(ResponseItem), - Completed { - response_id: String, - token_usage: Option, - }, - OutputTextDelta(String), - ReasoningSummaryDelta(String), - ReasoningContentDelta(String), - ReasoningSummaryPartAdded, - WebSearchCallBegin { - call_id: String, - }, - RateLimits(RateLimitSnapshot), -} - -#[derive(Debug, Serialize)] -pub(crate) struct Reasoning { - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) effort: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) summary: Option, -} - -#[derive(Debug, Serialize, Default, Clone)] -#[serde(rename_all = "snake_case")] -pub(crate) enum TextFormatType { - #[default] - JsonSchema, -} - -#[derive(Debug, Serialize, Default, Clone)] -pub(crate) struct TextFormat { - pub(crate) r#type: TextFormatType, - pub(crate) strict: bool, - pub(crate) schema: Value, - pub(crate) name: String, -} - -/// Controls under the `text` field in the Responses API for GPT-5. -#[derive(Debug, Serialize, Default, Clone)] -pub(crate) struct TextControls { - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) verbosity: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) format: Option, -} - -#[derive(Debug, Serialize, Default, Clone)] -#[serde(rename_all = "lowercase")] -pub(crate) enum OpenAiVerbosity { - Low, - #[default] - Medium, - High, -} - -impl From for OpenAiVerbosity { - fn from(v: VerbosityConfig) -> Self { - match v { - VerbosityConfig::Low => OpenAiVerbosity::Low, - VerbosityConfig::Medium => OpenAiVerbosity::Medium, - VerbosityConfig::High => OpenAiVerbosity::High, - } - } -} - -/// Request object that is serialized as JSON and POST'ed when using the -/// Responses API. -#[derive(Debug, Serialize)] -pub(crate) struct ResponsesApiRequest<'a> { - pub(crate) model: &'a str, - pub(crate) instructions: &'a str, - // TODO(mbolin): ResponseItem::Other should not be serialized. Currently, - // we code defensively to avoid this case, but perhaps we should use a - // separate enum for serialization. - pub(crate) input: &'a Vec, - pub(crate) tools: &'a [serde_json::Value], - pub(crate) tool_choice: &'static str, - pub(crate) parallel_tool_calls: bool, - pub(crate) reasoning: Option, - pub(crate) store: bool, - pub(crate) stream: bool, - pub(crate) include: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) prompt_cache_key: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) text: Option, -} - -pub(crate) fn create_reasoning_param_for_request( - model_family: &ModelFamily, - effort: Option, - summary: ReasoningSummaryConfig, -) -> Option { - if !model_family.supports_reasoning_summaries { - return None; - } - - Some(Reasoning { - effort, - summary: Some(summary), - }) -} - -pub(crate) fn create_text_param_for_request( - verbosity: Option, - output_schema: &Option, -) -> Option { - if verbosity.is_none() && output_schema.is_none() { - return None; - } - - Some(TextControls { - verbosity: verbosity.map(std::convert::Into::into), - format: output_schema.as_ref().map(|schema| TextFormat { - r#type: TextFormatType::JsonSchema, - strict: true, - schema: schema.clone(), - name: "codex_output_schema".to_string(), - }), - }) -} - -pub struct ResponseStream { - pub(crate) rx_event: mpsc::Receiver>, -} - -impl Stream for ResponseStream { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.rx_event.poll_recv(cx) - } -} - -#[cfg(test)] -mod tests { - use crate::model_family::find_family_for_model; - use pretty_assertions::assert_eq; - - use super::*; - - struct InstructionsTestCase { - pub slug: &'static str, - pub expects_apply_patch_instructions: bool, - } - #[test] - fn get_full_instructions_no_user_content() { - let prompt = Prompt { - ..Default::default() - }; - let test_cases = vec![ - InstructionsTestCase { - slug: "gpt-3.5", - expects_apply_patch_instructions: true, - }, - InstructionsTestCase { - slug: "gpt-4.1", - expects_apply_patch_instructions: true, - }, - InstructionsTestCase { - slug: "gpt-4o", - expects_apply_patch_instructions: true, - }, - InstructionsTestCase { - slug: "gpt-5", - expects_apply_patch_instructions: true, - }, - InstructionsTestCase { - slug: "codex-mini-latest", - expects_apply_patch_instructions: true, - }, - InstructionsTestCase { - slug: "gpt-oss:120b", - expects_apply_patch_instructions: false, - }, - InstructionsTestCase { - slug: "gpt-5-codex", - expects_apply_patch_instructions: false, - }, - ]; - for test_case in test_cases { - let model_family = find_family_for_model(test_case.slug).expect("known model slug"); - let expected = if test_case.expects_apply_patch_instructions { - format!( - "{}\n{}", - model_family.clone().base_instructions, - APPLY_PATCH_TOOL_INSTRUCTIONS - ) - } else { - model_family.clone().base_instructions - }; - - let full = prompt.get_full_instructions(&model_family); - assert_eq!(full, expected); - } - } - - #[test] - fn serializes_text_verbosity_when_set() { - let input: Vec = vec![]; - let tools: Vec = vec![]; - let req = ResponsesApiRequest { - model: "gpt-5", - instructions: "i", - input: &input, - tools: &tools, - tool_choice: "auto", - parallel_tool_calls: false, - reasoning: None, - store: false, - stream: true, - include: vec![], - prompt_cache_key: None, - text: Some(TextControls { - verbosity: Some(OpenAiVerbosity::Low), - format: None, - }), - }; - - let v = serde_json::to_value(&req).expect("json"); - assert_eq!( - v.get("text") - .and_then(|t| t.get("verbosity")) - .and_then(|s| s.as_str()), - Some("low") - ); - } - - #[test] - fn serializes_text_schema_with_strict_format() { - let input: Vec = vec![]; - let tools: Vec = vec![]; - let schema = serde_json::json!({ - "type": "object", - "properties": { - "answer": {"type": "string"} - }, - "required": ["answer"], - }); - let text_controls = - create_text_param_for_request(None, &Some(schema.clone())).expect("text controls"); - - let req = ResponsesApiRequest { - model: "gpt-5", - instructions: "i", - input: &input, - tools: &tools, - tool_choice: "auto", - parallel_tool_calls: false, - reasoning: None, - store: false, - stream: true, - include: vec![], - prompt_cache_key: None, - text: Some(text_controls), - }; - - let v = serde_json::to_value(&req).expect("json"); - let text = v.get("text").expect("text field"); - assert!(text.get("verbosity").is_none()); - let format = text.get("format").expect("format field"); - - assert_eq!( - format.get("name"), - Some(&serde_json::Value::String("codex_output_schema".into())) - ); - assert_eq!( - format.get("type"), - Some(&serde_json::Value::String("json_schema".into())) - ); - assert_eq!(format.get("strict"), Some(&serde_json::Value::Bool(true))); - assert_eq!(format.get("schema"), Some(&schema)); - } - - #[test] - fn omits_text_when_not_set() { - let input: Vec = vec![]; - let tools: Vec = vec![]; - let req = ResponsesApiRequest { - model: "gpt-5", - instructions: "i", - input: &input, - tools: &tools, - tool_choice: "auto", - parallel_tool_calls: false, - reasoning: None, - store: false, - stream: true, - include: vec![], - prompt_cache_key: None, - text: None, - }; - - let v = serde_json::to_value(&req).expect("json"); - assert!(v.get("text").is_none()); - } -} +pub type ResponseStream = codex_agent::client_common::ResponseStream; diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 188c30a373..ecd645602e 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -71,6 +71,7 @@ use crate::exec_command::WriteStdinParams; use crate::exec_env::create_env; use crate::mcp_connection_manager::McpConnectionManager; use crate::mcp_tool_call::handle_mcp_tool_call; +use crate::model_client_adapter::CoreModelClientAdapter; use crate::model_family::find_family_for_model; use crate::model_provider_info::ModelProviderExt; use crate::openai_model_info::get_model_info; @@ -119,6 +120,7 @@ use crate::sandbox::BackendRegistry; use crate::sandbox::ExecPlan; use crate::sandbox::ExecRuntimeContext; use crate::sandbox::PreparedExec; +use crate::sandbox::build_exec_params_for_apply_patch; use crate::sandbox::prepare_exec_invocation; use crate::sandbox::run_with_plan; use crate::shell; @@ -136,6 +138,7 @@ use codex_agent::apply_patch::ApplyPatchExec; use codex_agent::apply_patch::InternalApplyPatchInvocation; use codex_agent::apply_patch::apply_patch; use codex_agent::apply_patch::convert_apply_patch_to_protocol; +use codex_agent::model_client::ModelClientAdapter; use codex_agent::services::ApprovalCoordinator; use codex_agent::session_services::SessionServices; use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig; @@ -246,14 +249,15 @@ impl Codex { // Construct the model client and initial turn context before handing off to the runtime. let credentials_provider: Option> = Some(auth_manager.clone()); - let client = ModelClient::new( - agent_config.clone(), - credentials_provider, - configure_session.provider.clone(), - configure_session.model_reasoning_effort, - configure_session.model_reasoning_summary, - conversation_id, - ); + let client: Arc> = + Arc::new(CoreModelClientAdapter::new(ModelClient::new( + agent_config.clone(), + credentials_provider, + configure_session.provider.clone(), + configure_session.model_reasoning_effort, + configure_session.model_reasoning_summary, + conversation_id, + ))); let tools_config = ToolsConfig::new(&ToolsConfigParams { model_family: &agent_config.model_family, include_plan_tool: agent_config.include_plan_tool, @@ -386,9 +390,8 @@ pub(crate) struct Session { } /// The context needed for a single turn of the conversation. -#[derive(Debug)] pub(crate) struct TurnContext { - pub(crate) client: ModelClient, + pub(crate) client: Arc>, /// The session's current working directory. All relative paths provided by /// the model as well as sandbox policies are resolved against this path /// instead of `std::env::current_dir()`. @@ -411,6 +414,22 @@ impl TurnContext { } } +impl std::fmt::Debug for TurnContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TurnContext") + .field("cwd", &self.cwd) + .field("base_instructions", &self.base_instructions) + .field("user_instructions", &self.user_instructions) + .field("approval_policy", &self.approval_policy) + .field("sandbox_policy", &self.sandbox_policy) + .field("shell_environment_policy", &self.shell_environment_policy) + .field("tools_config", &self.tools_config) + .field("is_review_mode", &self.is_review_mode) + .field("final_output_json_schema", &self.final_output_json_schema) + .finish() + } +} + /// Configure the model session. struct ConfigureSession { /// Provider identifier ("openai", "openrouter", ...). @@ -1232,14 +1251,15 @@ async fn submission_loop( updated_config.model_context_window = Some(model_info.context_window); } - let client = ModelClient::new( - Arc::new(updated_config), - auth_manager, - provider, - effective_effort, - effective_summary, - sess.conversation_id, - ); + let client: Arc> = + Arc::new(CoreModelClientAdapter::new(ModelClient::new( + Arc::new(updated_config), + auth_manager, + provider, + effective_effort, + effective_summary, + sess.conversation_id, + ))); let new_approval_policy = approval_policy.unwrap_or(prev.approval_policy); let new_sandbox_policy = sandbox_policy @@ -1323,14 +1343,15 @@ async fn submission_loop( // Build a new client with per‑turn reasoning settings. // Reuse the same provider and session id; auth defaults to env/API key. - let client = ModelClient::new( - Arc::new(per_turn_config), - auth_manager, - provider, - effort, - summary, - sess.conversation_id, - ); + let client: Arc> = + Arc::new(CoreModelClientAdapter::new(ModelClient::new( + Arc::new(per_turn_config), + auth_manager, + provider, + effort, + summary, + sess.conversation_id, + ))); let fresh_turn_context = TurnContext { client, @@ -1584,14 +1605,15 @@ async fn spawn_review_thread( } let per_turn_config = Arc::new(per_turn_config); - let client = ModelClient::new( - per_turn_config.clone(), - auth_manager, - provider, - per_turn_config.model_reasoning_effort, - per_turn_config.model_reasoning_summary, - sess.conversation_id, - ); + let client: Arc> = + Arc::new(CoreModelClientAdapter::new(ModelClient::new( + per_turn_config.clone(), + auth_manager, + provider, + per_turn_config.model_reasoning_effort, + per_turn_config.model_reasoning_summary, + sess.conversation_id, + ))); let review_turn_context = TurnContext { client, @@ -2673,6 +2695,8 @@ async fn handle_container_exec_with_params( ))); } + let mut params = params; + // check if this was a patch, and apply it if so let apply_patch_exec = match maybe_parse_apply_patch_verified(¶ms.command, ¶ms.cwd) { MaybeApplyPatchVerified::Body(changes) => { @@ -2685,6 +2709,7 @@ async fn handle_container_exec_with_params( match apply_patch(sess, apply_patch_context, &sub_id, &call_id, changes).await { InternalApplyPatchInvocation::Output(item) => return item, InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => { + params = build_exec_params_for_apply_patch(&apply_patch_exec, ¶ms)?; Some(apply_patch_exec) } } @@ -2711,7 +2736,9 @@ async fn handle_container_exec_with_params( let prepared = prepare_exec_invocation( sess, - turn_context, + turn_context.approval_policy, + &turn_context.sandbox_policy, + &turn_context.cwd, &sub_id, &call_id, params, @@ -3560,14 +3587,15 @@ mod tests { let config = Arc::new(config); let agent_config = Arc::new(AgentConfig::from(config.as_ref())); let conversation_id = ConversationId::default(); - let client = ModelClient::new( - agent_config.clone(), - None, - agent_config.model_provider.clone(), - agent_config.model_reasoning_effort, - agent_config.model_reasoning_summary, - conversation_id, - ); + let client: Arc> = + Arc::new(CoreModelClientAdapter::new(ModelClient::new( + agent_config.clone(), + None, + agent_config.model_provider.clone(), + agent_config.model_reasoning_effort, + agent_config.model_reasoning_summary, + conversation_id, + ))); let tools_config = ToolsConfig::new(&ToolsConfigParams { model_family: &agent_config.model_family, include_plan_tool: agent_config.include_plan_tool, diff --git a/codex-rs/core/src/exec.rs b/codex-rs/core/src/exec.rs index 63200597d4..5554a7f884 100644 --- a/codex-rs/core/src/exec.rs +++ b/codex-rs/core/src/exec.rs @@ -1,7 +1,6 @@ #[cfg(unix)] use std::os::unix::process::ExitStatusExt; -use std::collections::HashMap; use std::io; use std::path::Path; use std::path::PathBuf; @@ -27,10 +26,9 @@ use crate::protocol::SandboxPolicy; use crate::seatbelt::spawn_command_under_seatbelt; use crate::spawn::StdioPolicy; use crate::spawn::spawn_child_async; +pub use codex_agent::exec::ExecParams; pub use codex_agent::sandbox::SandboxType; -const DEFAULT_TIMEOUT_MS: u64 = 10_000; - // Hardcode these since it does not seem worth including the libc crate just // for these. const SIGKILL_CODE: i32 = 9; @@ -46,22 +44,6 @@ const AGGREGATE_BUFFER_INITIAL_CAPACITY: usize = 8 * 1024; // 8 KiB /// Aggregation still collects full output; only the live event stream is capped. pub(crate) const MAX_EXEC_OUTPUT_DELTAS_PER_CALL: usize = 10_000; -#[derive(Clone, Debug)] -pub struct ExecParams { - pub command: Vec, - pub cwd: PathBuf, - pub timeout_ms: Option, - pub env: HashMap, - pub with_escalated_permissions: Option, - pub justification: Option, -} - -impl ExecParams { - pub fn timeout_duration(&self) -> Duration { - Duration::from_millis(self.timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS)) - } -} - #[derive(Clone)] pub struct StdoutStream { pub sub_id: String, diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 4e57a8d1a7..d780b69146 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -37,11 +37,13 @@ pub mod landlock; mod mcp_connection_manager; mod mcp_tool_call; mod message_history; +mod model_client_adapter; mod model_provider_info; pub mod parse_command; mod truncate; mod unified_exec; mod user_instructions; +pub use model_client_adapter::CoreModelClientAdapter; pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID; pub use model_provider_info::built_in_model_providers; pub use model_provider_info::create_oss_provider_with_base_url; diff --git a/codex-rs/core/src/model_client_adapter.rs b/codex-rs/core/src/model_client_adapter.rs new file mode 100644 index 0000000000..ac645f28ea --- /dev/null +++ b/codex-rs/core/src/model_client_adapter.rs @@ -0,0 +1,70 @@ +use std::sync::Arc; + +use async_trait::async_trait; + +use crate::client::ModelClient; +use crate::client_common::Prompt; +use crate::error::CodexErr; +use crate::model_family::ModelFamily; +use codex_agent::model_client::ModelClientAdapter; +use codex_agent::model_provider::ModelProviderInfo; +use codex_agent::services::CredentialsProvider; + +#[derive(Clone)] +pub struct CoreModelClientAdapter { + inner: ModelClient, +} + +impl CoreModelClientAdapter { + pub fn new(inner: ModelClient) -> Self { + Self { inner } + } + + pub fn inner(&self) -> &ModelClient { + &self.inner + } +} + +#[async_trait] +impl ModelClientAdapter for CoreModelClientAdapter { + type Error = CodexErr; + + fn get_model_context_window(&self) -> Option { + self.inner.get_model_context_window() + } + + fn get_auto_compact_token_limit(&self) -> Option { + self.inner.get_auto_compact_token_limit() + } + + fn get_provider(&self) -> ModelProviderInfo { + self.inner.get_provider() + } + + fn get_model(&self) -> String { + self.inner.get_model() + } + + fn get_model_family(&self) -> ModelFamily { + self.inner.get_model_family() + } + + fn get_reasoning_effort(&self) -> Option { + self.inner.get_reasoning_effort() + } + + fn get_reasoning_summary(&self) -> codex_protocol::config_types::ReasoningSummary { + self.inner.get_reasoning_summary() + } + + fn get_auth_manager(&self) -> Option> { + self.inner.get_auth_manager() + } + + async fn stream( + &self, + prompt: &Prompt, + ) -> Result, Self::Error> { + self.inner.stream(prompt).await + } +} diff --git a/codex-rs/core/src/openai_tools.rs b/codex-rs/core/src/openai_tools.rs index 48dca79671..2f23bc8c5d 100644 --- a/codex-rs/core/src/openai_tools.rs +++ b/codex-rs/core/src/openai_tools.rs @@ -10,153 +10,14 @@ use crate::plan_tool::PLAN_TOOL; use crate::tool_apply_patch::ApplyPatchToolType; use crate::tool_apply_patch::create_apply_patch_freeform_tool; use crate::tool_apply_patch::create_apply_patch_json_tool; - -#[derive(Debug, Clone, Serialize, PartialEq)] -pub struct ResponsesApiTool { - pub(crate) name: String, - pub(crate) description: String, - /// TODO: Validation. When strict is set to true, the JSON schema, - /// `required` and `additional_properties` must be present. All fields in - /// `properties` must be present in `required`. - pub(crate) strict: bool, - pub(crate) parameters: JsonSchema, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct FreeformTool { - pub(crate) name: String, - pub(crate) description: String, - pub(crate) format: FreeformToolFormat, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct FreeformToolFormat { - pub(crate) r#type: String, - pub(crate) syntax: String, - pub(crate) definition: String, -} - -/// When serialized as JSON, this produces a valid "Tool" in the OpenAI -/// Responses API. -#[derive(Debug, Clone, Serialize, PartialEq)] -#[serde(tag = "type")] -pub(crate) enum OpenAiTool { - #[serde(rename = "function")] - Function(ResponsesApiTool), - #[serde(rename = "local_shell")] - LocalShell {}, - // TODO: Understand why we get an error on web_search although the API docs say it's supported. - // https://platform.openai.com/docs/guides/tools-web-search?api-mode=responses#:~:text=%7B%20type%3A%20%22web_search%22%20%7D%2C - #[serde(rename = "web_search")] - WebSearch {}, - #[serde(rename = "custom")] - Freeform(FreeformTool), -} - -#[derive(Debug, Clone)] -pub enum ConfigShellToolType { - Default, - Local, - Streamable, -} - -#[derive(Debug, Clone)] -pub(crate) struct ToolsConfig { - pub shell_type: ConfigShellToolType, - pub plan_tool: bool, - pub apply_patch_tool_type: Option, - pub web_search_request: bool, - pub include_view_image_tool: bool, - pub experimental_unified_exec_tool: bool, -} - -pub(crate) struct ToolsConfigParams<'a> { - pub(crate) model_family: &'a ModelFamily, - pub(crate) include_plan_tool: bool, - pub(crate) include_apply_patch_tool: bool, - pub(crate) include_web_search_request: bool, - pub(crate) use_streamable_shell_tool: bool, - pub(crate) include_view_image_tool: bool, - pub(crate) experimental_unified_exec_tool: bool, -} - -impl ToolsConfig { - pub fn new(params: &ToolsConfigParams) -> Self { - let ToolsConfigParams { - model_family, - include_plan_tool, - include_apply_patch_tool, - include_web_search_request, - use_streamable_shell_tool, - include_view_image_tool, - experimental_unified_exec_tool, - } = params; - let shell_type = if *use_streamable_shell_tool { - ConfigShellToolType::Streamable - } else if model_family.uses_local_shell_tool { - ConfigShellToolType::Local - } else { - ConfigShellToolType::Default - }; - - let apply_patch_tool_type = match model_family.apply_patch_tool_type { - Some(ApplyPatchToolType::Freeform) => Some(ApplyPatchToolType::Freeform), - Some(ApplyPatchToolType::Function) => Some(ApplyPatchToolType::Function), - None => { - if *include_apply_patch_tool { - Some(ApplyPatchToolType::Freeform) - } else { - None - } - } - }; - - Self { - shell_type, - plan_tool: *include_plan_tool, - apply_patch_tool_type, - web_search_request: *include_web_search_request, - include_view_image_tool: *include_view_image_tool, - experimental_unified_exec_tool: *experimental_unified_exec_tool, - } - } -} - -/// Generic JSON‑Schema subset needed for our tool definitions -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[serde(tag = "type", rename_all = "lowercase")] -pub(crate) enum JsonSchema { - Boolean { - #[serde(skip_serializing_if = "Option::is_none")] - description: Option, - }, - String { - #[serde(skip_serializing_if = "Option::is_none")] - description: Option, - }, - /// MCP schema allows "number" | "integer" for Number - #[serde(alias = "integer")] - Number { - #[serde(skip_serializing_if = "Option::is_none")] - description: Option, - }, - Array { - items: Box, - - #[serde(skip_serializing_if = "Option::is_none")] - description: Option, - }, - Object { - properties: BTreeMap, - #[serde(skip_serializing_if = "Option::is_none")] - required: Option>, - #[serde( - rename = "additionalProperties", - skip_serializing_if = "Option::is_none" - )] - additional_properties: Option, - }, -} +pub(crate) use codex_agent::tool_schema::FreeformTool; +pub(crate) use codex_agent::tool_schema::FreeformToolFormat; +pub(crate) use codex_agent::tool_schema::JsonSchema; +pub(crate) use codex_agent::tool_schema::OpenAiTool; +pub(crate) use codex_agent::tool_schema::ResponsesApiTool; +pub(crate) use codex_agent::tools_config::ConfigShellToolType; +pub(crate) use codex_agent::tools_config::ToolsConfig; +pub(crate) use codex_agent::tools_config::ToolsConfigParams; fn create_unified_exec_tool() -> OpenAiTool { let mut properties = BTreeMap::new(); diff --git a/codex-rs/core/src/sandbox/mod.rs b/codex-rs/core/src/sandbox/mod.rs index d1553f0d17..9ccf531505 100644 --- a/codex-rs/core/src/sandbox/mod.rs +++ b/codex-rs/core/src/sandbox/mod.rs @@ -2,17 +2,14 @@ mod apply_patch_adapter; mod backend; mod planner; +pub(crate) use apply_patch_adapter::build_exec_params_for_apply_patch; pub use backend::BackendRegistry; pub use backend::DirectBackend; pub use backend::LinuxBackend; pub use backend::SeatbeltBackend; pub use backend::SpawnBackend; -pub use planner::ExecPlan; -pub use planner::ExecRequest; -pub use planner::PatchExecRequest; +pub use codex_agent::sandbox::ExecPlan; pub(crate) use planner::PreparedExec; -pub use planner::plan_apply_patch; -pub use planner::plan_exec; pub(crate) use planner::prepare_exec_invocation; use crate::error::Result; diff --git a/codex-rs/core/src/sandbox/planner.rs b/codex-rs/core/src/sandbox/planner.rs index 29f8dbbaab..e60ebc9b89 100644 --- a/codex-rs/core/src/sandbox/planner.rs +++ b/codex-rs/core/src/sandbox/planner.rs @@ -2,111 +2,20 @@ use std::collections::HashSet; use std::path::Path; use codex_agent::apply_patch::ApplyPatchExec; -use codex_agent::safety::SafetyCheck; -use codex_agent::safety::assess_command_safety; -use codex_agent::safety::assess_patch_safety; +use codex_agent::sandbox::CommandPlanRequest; +use codex_agent::sandbox::ExecPlan; +use codex_agent::sandbox::PatchPlanRequest; use codex_agent::sandbox::SandboxType; +use codex_agent::sandbox::plan_apply_patch; +use codex_agent::sandbox::plan_exec; use codex_agent::services::ApprovalCoordinator; -use codex_apply_patch::ApplyPatchAction; -use super::apply_patch_adapter::build_exec_params_for_apply_patch; -use crate::codex::TurnContext; use crate::exec::ExecParams; use crate::function_tool::FunctionCallError; use crate::protocol::AskForApproval; use crate::protocol::ReviewDecision; use crate::protocol::SandboxPolicy; -#[derive(Clone, Debug)] -pub struct ExecRequest<'a> { - pub params: &'a ExecParams, - pub approval: AskForApproval, - pub policy: &'a SandboxPolicy, - pub approved_session_commands: &'a HashSet>, -} - -#[derive(Clone, Debug)] -pub enum ExecPlan { - Reject { - reason: String, - }, - AskUser { - reason: Option, - }, - Approved { - sandbox: SandboxType, - on_failure_escalate: bool, - approved_by_user: bool, - }, -} - -impl ExecPlan { - pub fn approved( - sandbox: SandboxType, - on_failure_escalate: bool, - approved_by_user: bool, - ) -> Self { - ExecPlan::Approved { - sandbox, - on_failure_escalate, - approved_by_user, - } - } -} - -pub fn plan_exec(req: &ExecRequest<'_>) -> ExecPlan { - let params = req.params; - let with_escalated_permissions = params.with_escalated_permissions.unwrap_or(false); - let safety = assess_command_safety( - ¶ms.command, - req.approval, - req.policy, - req.approved_session_commands, - with_escalated_permissions, - ); - - match safety { - SafetyCheck::AutoApprove { sandbox_type } => ExecPlan::Approved { - sandbox: sandbox_type, - on_failure_escalate: should_escalate_on_failure(req.approval, sandbox_type), - approved_by_user: false, - }, - SafetyCheck::AskUser => ExecPlan::AskUser { - reason: params.justification.clone(), - }, - SafetyCheck::Reject { reason } => ExecPlan::Reject { reason }, - } -} - -#[derive(Clone, Debug)] -pub struct PatchExecRequest<'a> { - pub action: &'a ApplyPatchAction, - pub approval: AskForApproval, - pub policy: &'a SandboxPolicy, - pub cwd: &'a Path, - pub user_explicitly_approved: bool, -} - -pub fn plan_apply_patch(req: &PatchExecRequest<'_>) -> ExecPlan { - if req.user_explicitly_approved { - ExecPlan::Approved { - sandbox: SandboxType::None, - on_failure_escalate: false, - approved_by_user: true, - } - } else { - match assess_patch_safety(req.action, req.approval, req.policy, req.cwd) { - SafetyCheck::AutoApprove { sandbox_type } => ExecPlan::Approved { - sandbox: sandbox_type, - on_failure_escalate: should_escalate_on_failure(req.approval, sandbox_type), - approved_by_user: false, - }, - SafetyCheck::AskUser => ExecPlan::AskUser { reason: None }, - SafetyCheck::Reject { reason } => ExecPlan::Reject { reason }, - } - } -} - #[derive(Debug)] pub(crate) struct PreparedExec { pub(crate) params: ExecParams, @@ -117,28 +26,31 @@ pub(crate) struct PreparedExec { pub(crate) async fn prepare_exec_invocation( approvals: &dyn ApprovalCoordinator, - turn_context: &TurnContext, + approval_policy: AskForApproval, + sandbox_policy: &SandboxPolicy, + cwd: &Path, sub_id: &str, call_id: &str, params: ExecParams, apply_patch_exec: Option, approved_session_commands: HashSet>, ) -> Result { - let mut params = params; + let command_for_display = if let Some(exec) = apply_patch_exec.as_ref() { + vec!["apply_patch".to_string(), exec.action.patch.clone()] + } else { + params.command.clone() + }; - let (plan, command_for_display) = if let Some(exec) = apply_patch_exec.as_ref() { - params = build_exec_params_for_apply_patch(exec, ¶ms)?; - let command_for_display = vec!["apply_patch".to_string(), exec.action.patch.clone()]; - - let plan_req = PatchExecRequest { + let plan = if let Some(exec) = apply_patch_exec.as_ref() { + let plan_req = PatchPlanRequest { action: &exec.action, - approval: turn_context.approval_policy, - policy: &turn_context.sandbox_policy, - cwd: &turn_context.cwd, + approval: approval_policy, + policy: sandbox_policy, + cwd, user_explicitly_approved: exec.user_explicitly_approved_this_action, }; - let plan = match plan_apply_patch(&plan_req) { + match plan_apply_patch(&plan_req) { plan @ ExecPlan::Approved { .. } => plan, ExecPlan::AskUser { .. } => { return Err(FunctionCallError::RespondToModel( @@ -150,20 +62,18 @@ pub(crate) async fn prepare_exec_invocation( "patch rejected: {reason}" ))); } + } + } else { + let plan_req = CommandPlanRequest { + command: ¶ms.command, + approval: approval_policy, + policy: sandbox_policy, + approved_session_commands: &approved_session_commands, + with_escalated_permissions: params.with_escalated_permissions.unwrap_or(false), + justification: params.justification.as_ref(), }; - (plan, command_for_display) - } else { - let command_for_display = params.command.clone(); - - let initial_plan = plan_exec(&ExecRequest { - params: ¶ms, - approval: turn_context.approval_policy, - policy: &turn_context.sandbox_policy, - approved_session_commands: &approved_session_commands, - }); - - let plan = match initial_plan { + match plan_exec(&plan_req) { plan @ ExecPlan::Approved { .. } => plan, ExecPlan::AskUser { reason } => { let decision = approvals @@ -175,6 +85,7 @@ pub(crate) async fn prepare_exec_invocation( reason, ) .await; + match decision { ReviewDecision::Approved => ExecPlan::approved(SandboxType::None, false, true), ReviewDecision::ApprovedForSession => { @@ -193,9 +104,7 @@ pub(crate) async fn prepare_exec_invocation( "exec command rejected: {reason:?}" ))); } - }; - - (plan, command_for_display) + } }; Ok(PreparedExec { @@ -205,13 +114,3 @@ pub(crate) async fn prepare_exec_invocation( apply_patch_exec, }) } - -fn should_escalate_on_failure(approval: AskForApproval, sandbox: SandboxType) -> bool { - matches!( - (approval, sandbox), - ( - AskForApproval::UnlessTrusted | AskForApproval::OnFailure, - SandboxType::MacosSeatbelt | SandboxType::LinuxSeccomp - ) - ) -}