Compare commits

..

3 Commits

Author SHA1 Message Date
Michael Bolin
1ccd7af0b3 feat: support dotenv (including ~/.codex/.env) 2025-07-22 10:45:00 -07:00
Michael Bolin
dfa9a44202 feat: support dotenv 2025-07-22 10:19:42 -07:00
Michael Bolin
d5809ef6ef chore: install an extension for TOML syntax highlighting in the devcontainer 2025-07-22 10:11:07 -07:00
49 changed files with 544 additions and 1033 deletions

View File

@@ -8,8 +8,8 @@
"@actions/github": "^6.0.1",
},
"devDependencies": {
"@types/bun": "^1.2.19",
"@types/node": "^24.1.0",
"@types/bun": "^1.2.18",
"@types/node": "^24.0.13",
"prettier": "^3.6.2",
"typescript": "^5.8.3",
},
@@ -48,15 +48,15 @@
"@octokit/types": ["@octokit/types@13.10.0", "", { "dependencies": { "@octokit/openapi-types": "^24.2.0" } }, "sha512-ifLaO34EbbPj0Xgro4G5lP5asESjwHracYJvVaPIyXMuiuXLlhic3S47cBdTb+jfODkTE5YtGCLt3Ay3+J97sA=="],
"@types/bun": ["@types/bun@1.2.19", "", { "dependencies": { "bun-types": "1.2.19" } }, "sha512-d9ZCmrH3CJ2uYKXQIUuZ/pUnTqIvLDS0SK7pFmbx8ma+ziH/FRMoAq5bYpRG7y+w1gl+HgyNZbtqgMq4W4e2Lg=="],
"@types/bun": ["@types/bun@1.2.18", "", { "dependencies": { "bun-types": "1.2.18" } }, "sha512-Xf6RaWVheyemaThV0kUfaAUvCNokFr+bH8Jxp+tTZfx7dAPA8z9ePnP9S9+Vspzuxxx9JRAXhnyccRj3GyCMdQ=="],
"@types/node": ["@types/node@24.1.0", "", { "dependencies": { "undici-types": "~7.8.0" } }, "sha512-ut5FthK5moxFKH2T1CUOC6ctR67rQRvvHdFLCD2Ql6KXmMuCrjsSsRI9UsLCm9M18BMwClv4pn327UvB7eeO1w=="],
"@types/node": ["@types/node@24.0.13", "", { "dependencies": { "undici-types": "~7.8.0" } }, "sha512-Qm9OYVOFHFYg3wJoTSrz80hoec5Lia/dPp84do3X7dZvLikQvM1YpmvTBEdIr/e+U8HTkFjLHLnl78K/qjf+jQ=="],
"@types/react": ["@types/react@19.1.8", "", { "dependencies": { "csstype": "^3.0.2" } }, "sha512-AwAfQ2Wa5bCx9WP8nZL2uMZWod7J7/JSplxbTmBQ5ms6QpqNYm672H0Vu9ZVKVngQ+ii4R/byguVEUZQyeg44g=="],
"before-after-hook": ["before-after-hook@2.2.3", "", {}, "sha512-NzUnlZexiaH/46WDhANlyR2bXRopNg4F/zuSA3OpZnllCUgRaOF2znDioDWrmbNVsuZk6l9pMquQB38cfBZwkQ=="],
"bun-types": ["bun-types@1.2.19", "", { "dependencies": { "@types/node": "*" }, "peerDependencies": { "@types/react": "^19" } }, "sha512-uAOTaZSPuYsWIXRpj7o56Let0g/wjihKCkeRqUBhlLVM/Bt+Fj9xTo+LhC1OV1XDaGkz4hNC80et5xgy+9KTHQ=="],
"bun-types": ["bun-types@1.2.18", "", { "dependencies": { "@types/node": "*" }, "peerDependencies": { "@types/react": "^19" } }, "sha512-04+Eha5NP7Z0A9YgDAzMk5PHR16ZuLVa83b26kH5+cp1qZW4F6FmAURngE7INf4tKOvCE69vYvDEwoNl1tGiWw=="],
"csstype": ["csstype@3.1.3", "", {}, "sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw=="],
@@ -82,8 +82,6 @@
"@octokit/plugin-rest-endpoint-methods/@octokit/types": ["@octokit/types@12.6.0", "", { "dependencies": { "@octokit/openapi-types": "^20.0.0" } }, "sha512-1rhSOfRa6H9w4YwK0yrf5faDaDTb+yLyBUKOCV4xtCDB5VmIPqd/v9yr9o6SAzOAlRxMiRiCic6JVM1/kunVkw=="],
"bun-types/@types/node": ["@types/node@24.0.13", "", { "dependencies": { "undici-types": "~7.8.0" } }, "sha512-Qm9OYVOFHFYg3wJoTSrz80hoec5Lia/dPp84do3X7dZvLikQvM1YpmvTBEdIr/e+U8HTkFjLHLnl78K/qjf+jQ=="],
"@octokit/plugin-paginate-rest/@octokit/types/@octokit/openapi-types": ["@octokit/openapi-types@20.0.0", "", {}, "sha512-EtqRBEjp1dL/15V7WiX5LJMIxxkdiGJnabzYx5Apx4FkQIFgAfKumXeYAqqJCj1s+BMX4cPFIFC4OLCR6stlnA=="],
"@octokit/plugin-rest-endpoint-methods/@octokit/types/@octokit/openapi-types": ["@octokit/openapi-types@20.0.0", "", {}, "sha512-EtqRBEjp1dL/15V7WiX5LJMIxxkdiGJnabzYx5Apx4FkQIFgAfKumXeYAqqJCj1s+BMX4cPFIFC4OLCR6stlnA=="],

View File

@@ -13,8 +13,8 @@
"@actions/github": "^6.0.1"
},
"devDependencies": {
"@types/bun": "^1.2.19",
"@types/node": "^24.1.0",
"@types/bun": "^1.2.18",
"@types/node": "^24.0.13",
"prettier": "^3.6.2",
"typescript": "^5.8.3"
}

View File

@@ -93,7 +93,7 @@ jobs:
sudo apt install -y musl-tools pkg-config
- name: Cargo build
run: cargo build --target ${{ matrix.target }} --release --bin codex --bin codex-exec --bin codex-linux-sandbox
run: cargo build --target ${{ matrix.target }} --release --all-targets --all-features
- name: Stage artifacts
shell: bash

View File

@@ -15,7 +15,7 @@ This is the home of the **Codex CLI**, which is a coding agent from OpenAI that
<!-- Begin ToC -->
- [Experimental technology disclaimer](#experimental-technology-disclaimer)
- [Quick start](#quickstart)
- [Quickstart](#quickstart)
- [OpenAI API Users](#openai-api-users)
- [OpenAI Plus/Pro Users](#openai-pluspro-users)
- [Why Codex?](#why-codex)

29
codex-rs/Cargo.lock generated
View File

@@ -648,6 +648,7 @@ version = "0.0.0"
dependencies = [
"clap",
"codex-core",
"dotenvy",
"serde",
"toml 0.9.1",
]
@@ -676,13 +677,13 @@ dependencies = [
"openssl-sys",
"predicates",
"pretty_assertions",
"rand 0.9.2",
"rand 0.9.1",
"reqwest",
"seccompiler",
"serde",
"serde_json",
"sha1",
"strum_macros 0.27.2",
"strum_macros 0.27.1",
"tempfile",
"thiserror 2.0.12",
"time",
@@ -756,9 +757,7 @@ version = "0.0.0"
dependencies = [
"anyhow",
"clap",
"codex-common",
"codex-core",
"dotenvy",
"landlock",
"libc",
"seccompiler",
@@ -796,6 +795,7 @@ version = "0.0.0"
dependencies = [
"anyhow",
"assert_cmd",
"codex-common",
"codex-core",
"codex-linux-sandbox",
"mcp-types",
@@ -840,8 +840,8 @@ dependencies = [
"regex-lite",
"serde_json",
"shlex",
"strum 0.27.2",
"strum_macros 0.27.2",
"strum 0.27.1",
"strum_macros 0.27.1",
"tokio",
"tracing",
"tracing-appender",
@@ -3279,9 +3279,9 @@ dependencies = [
[[package]]
name = "rand"
version = "0.9.2"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1"
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
dependencies = [
"rand_chacha 0.9.0",
"rand_core 0.9.3",
@@ -4263,9 +4263,9 @@ dependencies = [
[[package]]
name = "strum"
version = "0.27.2"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf"
checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32"
[[package]]
name = "strum_macros"
@@ -4282,13 +4282,14 @@ dependencies = [
[[package]]
name = "strum_macros"
version = "0.27.2"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7"
checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8"
dependencies = [
"heck",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.104",
]
@@ -4855,9 +4856,9 @@ dependencies = [
[[package]]
name = "tree-sitter"
version = "0.25.8"
version = "0.25.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d7b8994f367f16e6fa14b5aebbcb350de5d7cbea82dc5b00ae997dd71680dd2"
checksum = "a7cf18d43cbf0bfca51f657132cc616a5097edc4424d538bae6fa60142eaf9f0"
dependencies = [
"cc",
"regex",

View File

@@ -14,7 +14,7 @@ workspace = true
anyhow = "1"
similar = "2.7.0"
thiserror = "2.0.12"
tree-sitter = "0.25.8"
tree-sitter = "0.25.3"
tree-sitter-bash = "0.25.0"
[dev-dependencies]

View File

@@ -1,5 +1,3 @@
use std::path::PathBuf;
use clap::Parser;
use codex_common::CliConfigOverrides;
use codex_core::config::Config;
@@ -19,10 +17,7 @@ pub struct ApplyCommand {
#[clap(flatten)]
pub config_overrides: CliConfigOverrides,
}
pub async fn run_apply_command(
apply_cli: ApplyCommand,
cwd: Option<PathBuf>,
) -> anyhow::Result<()> {
pub async fn run_apply_command(apply_cli: ApplyCommand) -> anyhow::Result<()> {
let config = Config::load_with_cli_overrides(
apply_cli
.config_overrides
@@ -34,13 +29,10 @@ pub async fn run_apply_command(
init_chatgpt_token_from_auth(&config.codex_home).await?;
let task_response = get_task(&config, apply_cli.task_id).await?;
apply_diff_from_task(task_response, cwd).await
apply_diff_from_task(task_response).await
}
pub async fn apply_diff_from_task(
task_response: GetTaskResponse,
cwd: Option<PathBuf>,
) -> anyhow::Result<()> {
pub async fn apply_diff_from_task(task_response: GetTaskResponse) -> anyhow::Result<()> {
let diff_turn = match task_response.current_diff_task_turn {
Some(turn) => turn,
None => anyhow::bail!("No diff turn found"),
@@ -50,17 +42,13 @@ pub async fn apply_diff_from_task(
_ => None,
});
match output_diff {
Some(output_diff) => apply_diff(&output_diff.diff, cwd).await,
Some(output_diff) => apply_diff(&output_diff.diff).await,
None => anyhow::bail!("No PR output item found"),
}
}
async fn apply_diff(diff: &str, cwd: Option<PathBuf>) -> anyhow::Result<()> {
let mut cmd = tokio::process::Command::new("git");
if let Some(cwd) = cwd {
cmd.current_dir(cwd);
}
let toplevel_output = cmd
async fn apply_diff(diff: &str) -> anyhow::Result<()> {
let toplevel_output = tokio::process::Command::new("git")
.args(vec!["rev-parse", "--show-toplevel"])
.output()
.await?;

View File

@@ -78,7 +78,17 @@ async fn test_apply_command_creates_fibonacci_file() {
.await
.expect("Failed to load fixture");
apply_diff_from_task(task_response, Some(repo_path.to_path_buf()))
let original_dir = std::env::current_dir().expect("Failed to get current dir");
std::env::set_current_dir(repo_path).expect("Failed to change directory");
struct DirGuard(std::path::PathBuf);
impl Drop for DirGuard {
fn drop(&mut self) {
let _ = std::env::set_current_dir(&self.0);
}
}
let _guard = DirGuard(original_dir);
apply_diff_from_task(task_response)
.await
.expect("Failed to apply diff from task");
@@ -163,7 +173,7 @@ console.log(fib(10));
.await
.expect("Failed to load fixture");
let apply_result = apply_diff_from_task(task_response, Some(repo_path.to_path_buf())).await;
let apply_result = apply_diff_from_task(task_response).await;
assert!(
apply_result.is_err(),

View File

@@ -9,6 +9,7 @@ use codex_cli::SeatbeltCommand;
use codex_cli::login::run_login_with_chatgpt;
use codex_cli::proto;
use codex_common::CliConfigOverrides;
use codex_common::load_dotenv;
use codex_exec::Cli as ExecCli;
use codex_tui::Cli as TuiCli;
use std::path::PathBuf;
@@ -99,6 +100,8 @@ fn main() -> anyhow::Result<()> {
}
async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
load_dotenv();
let cli = MultitoolCli::parse();
match cli.subcommand {
@@ -145,7 +148,7 @@ async fn cli_main(codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()
},
Some(Subcommand::Apply(mut apply_cli)) => {
prepend_config_flags(&mut apply_cli.config_overrides, cli.config_overrides);
run_apply_command(apply_cli, None).await?;
run_apply_command(apply_cli).await?;
}
}

View File

@@ -9,11 +9,12 @@ workspace = true
[dependencies]
clap = { version = "4", features = ["derive", "wrap_help"], optional = true }
codex-core = { path = "../core" }
dotenvy = { version = "0.15.7", optional = true }
toml = { version = "0.9", optional = true }
serde = { version = "1", optional = true }
[features]
# Separate feature so that `clap` is not a mandatory dependency.
cli = ["clap", "toml", "serde"]
cli = ["clap", "dotenvy", "toml", "serde"]
elapsed = []
sandbox_summary = []

View File

@@ -0,0 +1,7 @@
/// Load env vars from ~/.codex/.env and `$(pwd)/.env`.
pub fn load_dotenv() {
if let Ok(codex_home) = codex_core::config::find_codex_home() {
dotenvy::from_path(codex_home.join(".env")).ok();
}
dotenvy::dotenv().ok();
}

View File

@@ -1,6 +1,12 @@
#[cfg(feature = "cli")]
mod approval_mode_cli_arg;
#[cfg(feature = "cli")]
mod dotenv;
#[cfg(feature = "cli")]
pub use dotenv::load_dotenv;
#[cfg(feature = "elapsed")]
pub mod elapsed;

View File

@@ -30,7 +30,7 @@ reqwest = { version = "0.12", features = ["json", "stream"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
sha1 = "0.10.6"
strum_macros = "0.27.2"
strum_macros = "0.27.1"
thiserror = "2.0.12"
time = { version = "0.3", features = ["formatting", "local-offset", "macros"] }
tokio = { version = "1", features = [
@@ -43,7 +43,7 @@ tokio = { version = "1", features = [
tokio-util = "0.7.14"
toml = "0.9.1"
tracing = { version = "0.1.41", features = ["log"] }
tree-sitter = "0.25.8"
tree-sitter = "0.25.3"
tree-sitter-bash = "0.25.0"
uuid = { version = "1", features = ["serde", "v4"] }
wildmatch = "2.4.0"

View File

@@ -41,7 +41,7 @@ pub(crate) async fn stream_chat_completions(
for item in &prompt.input {
match item {
ResponseItem::Message { role, content, .. } => {
ResponseItem::Message { role, content } => {
let mut text = String::new();
for c in content {
match c {
@@ -58,7 +58,6 @@ pub(crate) async fn stream_chat_completions(
name,
arguments,
call_id,
..
} => {
messages.push(json!({
"role": "assistant",
@@ -260,7 +259,6 @@ async fn process_chat_sse<S>(
content: vec![ContentItem::OutputText {
text: content.to_string(),
}],
id: None,
};
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
@@ -302,7 +300,6 @@ async fn process_chat_sse<S>(
"tool_calls" if fn_call_state.active => {
// Build the FunctionCall response item.
let item = ResponseItem::FunctionCall {
id: None,
name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()),
arguments: fn_call_state.arguments.clone(),
call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new),
@@ -405,7 +402,6 @@ where
}))) => {
if !this.cumulative.is_empty() {
let aggregated_item = crate::models::ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![crate::models::ContentItem::OutputText {
text: std::mem::take(&mut this.cumulative),

View File

@@ -117,15 +117,6 @@ impl ModelClient {
let full_instructions = prompt.get_full_instructions(&self.config.model);
let tools_json = create_tools_json_for_responses_api(prompt, &self.config.model)?;
let reasoning = create_reasoning_param_for_request(&self.config, self.effort, self.summary);
// Request encrypted COT if we are not storing responses,
// otherwise reasoning items will be referenced by ID
let include = if !prompt.store && reasoning.is_some() {
vec!["reasoning.encrypted_content".to_string()]
} else {
vec![]
};
let payload = ResponsesApiRequest {
model: &self.config.model,
instructions: &full_instructions,
@@ -134,10 +125,10 @@ impl ModelClient {
tool_choice: "auto",
parallel_tool_calls: false,
reasoning,
previous_response_id: prompt.prev_id.clone(),
store: prompt.store,
// TODO: make this configurable
stream: true,
include,
};
trace!(

View File

@@ -22,6 +22,8 @@ const BASE_INSTRUCTIONS: &str = include_str!("../prompt.md");
pub struct Prompt {
/// Conversation context input items.
pub input: Vec<ResponseItem>,
/// Optional previous response ID (when storage is enabled).
pub prev_id: Option<String>,
/// Optional instructions from the user to amend to the built-in agent
/// instructions.
pub user_instructions: Option<String>,
@@ -32,18 +34,11 @@ pub struct Prompt {
/// the "fully qualified" tool name (i.e., prefixed with the server name),
/// which should be reported to the model in place of Tool::name.
pub extra_tools: HashMap<String, mcp_types::Tool>,
/// Optional override for the built-in BASE_INSTRUCTIONS.
pub base_instructions_override: Option<String>,
}
impl Prompt {
pub(crate) fn get_full_instructions(&self, model: &str) -> Cow<'_, str> {
let base = self
.base_instructions_override
.as_deref()
.unwrap_or(BASE_INSTRUCTIONS);
let mut sections: Vec<&str> = vec![base];
let mut sections: Vec<&str> = vec![BASE_INSTRUCTIONS];
if let Some(ref user) = self.user_instructions {
sections.push(user);
}
@@ -131,10 +126,11 @@ pub(crate) struct ResponsesApiRequest<'a> {
pub(crate) tool_choice: &'static str,
pub(crate) parallel_tool_calls: bool,
pub(crate) reasoning: Option<Reasoning>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) previous_response_id: Option<String>,
/// true when using the Responses API.
pub(crate) store: bool,
pub(crate) stream: bool,
pub(crate) include: Vec<String>,
}
use crate::config::Config;

View File

@@ -34,6 +34,7 @@ use tracing::trace;
use tracing::warn;
use uuid::Uuid;
use crate::WireApi;
use crate::client::ModelClient;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
@@ -107,15 +108,13 @@ impl Codex {
let (tx_sub, rx_sub) = async_channel::bounded(64);
let (tx_event, rx_event) = async_channel::bounded(1600);
let user_instructions = get_user_instructions(&config).await;
let instructions = get_user_instructions(&config).await;
let configure_session = Op::ConfigureSession {
provider: config.model_provider.clone(),
model: config.model.clone(),
model_reasoning_effort: config.model_reasoning_effort,
model_reasoning_summary: config.model_reasoning_summary,
user_instructions,
base_instructions: config.base_instructions.clone(),
instructions,
approval_policy: config.approval_policy,
sandbox_policy: config.sandbox_policy.clone(),
disable_response_storage: config.disable_response_storage,
@@ -184,13 +183,11 @@ pub(crate) struct Session {
/// the model as well as sandbox policies are resolved against this path
/// instead of `std::env::current_dir()`.
cwd: PathBuf,
base_instructions: Option<String>,
user_instructions: Option<String>,
instructions: Option<String>,
approval_policy: AskForApproval,
sandbox_policy: SandboxPolicy,
shell_environment_policy: ShellEnvironmentPolicy,
writable_roots: Mutex<Vec<PathBuf>>,
disable_response_storage: bool,
/// Manager for external MCP servers/tools.
mcp_connection_manager: McpConnectionManager,
@@ -219,9 +216,13 @@ impl Session {
struct State {
approved_commands: HashSet<Vec<String>>,
current_task: Option<AgentTask>,
/// Call IDs that have been sent from the Responses API but have not been sent back yet.
/// You CANNOT send a Responses API follow-up message unless you have sent back the output for all pending calls or else it will 400.
pending_call_ids: HashSet<String>,
previous_response_id: Option<String>,
pending_approvals: HashMap<String, oneshot::Sender<ReviewDecision>>,
pending_input: Vec<ResponseInputItem>,
history: ConversationHistory,
zdr_transcript: Option<ConversationHistory>,
}
impl Session {
@@ -253,7 +254,6 @@ impl Session {
pub async fn request_command_approval(
&self,
sub_id: String,
call_id: String,
command: Vec<String>,
cwd: PathBuf,
reason: Option<String>,
@@ -262,7 +262,6 @@ impl Session {
let event = Event {
id: sub_id.clone(),
msg: EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
call_id,
command,
cwd,
reason,
@@ -279,7 +278,6 @@ impl Session {
pub async fn request_patch_approval(
&self,
sub_id: String,
call_id: String,
action: &ApplyPatchAction,
reason: Option<String>,
grant_root: Option<PathBuf>,
@@ -288,7 +286,6 @@ impl Session {
let event = Event {
id: sub_id.clone(),
msg: EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
call_id,
changes: convert_apply_patch_to_protocol(action),
reason,
grant_root,
@@ -320,11 +317,18 @@ impl Session {
debug!("Recording items for conversation: {items:?}");
self.record_state_snapshot(items).await;
self.state.lock().unwrap().history.record_items(items);
if let Some(transcript) = self.state.lock().unwrap().zdr_transcript.as_mut() {
transcript.record_items(items);
}
}
async fn record_state_snapshot(&self, items: &[ResponseItem]) {
let snapshot = { crate::rollout::SessionStateSnapshot {} };
let snapshot = {
let state = self.state.lock().unwrap();
crate::rollout::SessionStateSnapshot {
previous_response_id: state.previous_response_id.clone(),
}
};
let recorder = {
let guard = self.rollout.lock().unwrap();
@@ -426,6 +430,8 @@ impl Session {
pub fn abort(&self) {
info!("Aborting existing session");
let mut state = self.state.lock().unwrap();
// Don't clear pending_call_ids because we need to keep track of them to ensure we don't 400 on the next turn.
// We will generate a synthetic aborted response for each pending call id.
state.pending_approvals.clear();
state.pending_input.clear();
if let Some(task) = state.current_task.take() {
@@ -470,10 +476,15 @@ impl Drop for Session {
}
impl State {
pub fn partial_clone(&self) -> Self {
pub fn partial_clone(&self, retain_zdr_transcript: bool) -> Self {
Self {
approved_commands: self.approved_commands.clone(),
history: self.history.clone(),
previous_response_id: self.previous_response_id.clone(),
zdr_transcript: if retain_zdr_transcript {
self.zdr_transcript.clone()
} else {
None
},
..Default::default()
}
}
@@ -566,8 +577,7 @@ async fn submission_loop(
model,
model_reasoning_effort,
model_reasoning_summary,
user_instructions,
base_instructions,
instructions,
approval_policy,
sandbox_policy,
disable_response_storage,
@@ -592,11 +602,13 @@ async fn submission_loop(
}
// Optionally resume an existing rollout.
let mut restored_items: Option<Vec<ResponseItem>> = None;
let mut restored_prev_id: Option<String> = None;
let rollout_recorder: Option<RolloutRecorder> =
if let Some(path) = resume_path.as_ref() {
match RolloutRecorder::resume(path).await {
Ok((rec, saved)) => {
session_id = saved.session_id;
restored_prev_id = saved.state.previous_response_id;
if !saved.items.is_empty() {
restored_items = Some(saved.items);
}
@@ -613,17 +625,15 @@ async fn submission_loop(
let rollout_recorder = match rollout_recorder {
Some(rec) => Some(rec),
None => {
match RolloutRecorder::new(&config, session_id, user_instructions.clone())
.await
{
Ok(r) => Some(r),
Err(e) => {
warn!("failed to initialise rollout recorder: {e}");
None
}
None => match RolloutRecorder::new(&config, session_id, instructions.clone())
.await
{
Ok(r) => Some(r),
Err(e) => {
warn!("failed to initialise rollout recorder: {e}");
None
}
}
},
};
let client = ModelClient::new(
@@ -635,13 +645,22 @@ async fn submission_loop(
);
// abort any current running session and clone its state
let retain_zdr_transcript =
record_conversation_history(disable_response_storage, provider.wire_api);
let state = match sess.take() {
Some(sess) => {
sess.abort();
sess.state.lock().unwrap().partial_clone()
sess.state
.lock()
.unwrap()
.partial_clone(retain_zdr_transcript)
}
None => State {
history: ConversationHistory::new(),
zdr_transcript: if retain_zdr_transcript {
Some(ConversationHistory::new())
} else {
None
},
..Default::default()
},
};
@@ -680,8 +699,7 @@ async fn submission_loop(
client,
tx_event: tx_event.clone(),
ctrl_c: Arc::clone(&ctrl_c),
user_instructions,
base_instructions,
instructions,
approval_policy,
sandbox_policy,
shell_environment_policy: config.shell_environment_policy.clone(),
@@ -692,14 +710,18 @@ async fn submission_loop(
state: Mutex::new(state),
rollout: Mutex::new(rollout_recorder),
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
disable_response_storage,
}));
// Patch restored state into the newly created session.
if let Some(sess_arc) = &sess {
if restored_items.is_some() {
if restored_prev_id.is_some() || restored_items.is_some() {
let mut st = sess_arc.state.lock().unwrap();
st.history.record_items(restored_items.unwrap().iter());
st.previous_response_id = restored_prev_id;
if let (Some(hist), Some(items)) =
(st.zdr_transcript.as_mut(), restored_items.as_ref())
{
hist.record_items(items.iter());
}
}
}
@@ -812,37 +834,6 @@ async fn submission_loop(
}
});
}
Op::Shutdown => {
info!("Shutting down Codex instance");
// Gracefully flush and shutdown rollout recorder on session end so tests
// that inspect the rollout file do not race with the background writer.
if let Some(sess_arc) = sess {
let recorder_opt = sess_arc.rollout.lock().unwrap().take();
if let Some(rec) = recorder_opt {
if let Err(e) = rec.shutdown().await {
warn!("failed to shutdown rollout recorder: {e}");
let event = Event {
id: sub.id.clone(),
msg: EventMsg::Error(ErrorEvent {
message: "Failed to shutdown rollout recorder".to_string(),
}),
};
if let Err(e) = tx_event.send(event).await {
warn!("failed to send error message: {e:?}");
}
}
}
}
let event = Event {
id: sub.id.clone(),
msg: EventMsg::ShutdownComplete,
};
if let Err(e) = tx_event.send(event).await {
warn!("failed to send Shutdown event: {e}");
}
break;
}
}
}
debug!("Agent loop exited");
@@ -877,8 +868,14 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
sess.record_conversation_items(&[initial_input_for_turn.clone().into()])
.await;
let mut input_for_next_turn: Vec<ResponseInputItem> = vec![initial_input_for_turn];
let last_agent_message: Option<String>;
loop {
let mut net_new_turn_input = input_for_next_turn
.drain(..)
.map(ResponseItem::from)
.collect::<Vec<_>>();
// Note that pending_input would be something like a message the user
// submitted through the UI while the model was running. Though the UI
// may support this, the model might not.
@@ -895,7 +892,29 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
// only record the new items that originated in this turn so that it
// represents an append-only log without duplicates.
let turn_input: Vec<ResponseItem> =
[sess.state.lock().unwrap().history.contents(), pending_input].concat();
if let Some(transcript) = sess.state.lock().unwrap().zdr_transcript.as_mut() {
// If we are using Chat/ZDR, we need to send the transcript with
// every turn. By induction, `transcript` already contains:
// - The `input` that kicked off this task.
// - Each `ResponseItem` that was recorded in the previous turn.
// - Each response to a `ResponseItem` (in practice, the only
// response type we seem to have is `FunctionCallOutput`).
//
// The only thing the `transcript` does not contain is the
// `pending_input` that was injected while the model was
// running. We need to add that to the conversation history
// so that the model can see it in the next turn.
[transcript.contents(), pending_input].concat()
} else {
// In practice, net_new_turn_input should contain only:
// - User messages
// - Outputs for function calls requested by the model
net_new_turn_input.extend(pending_input);
// Responses API path we can just send the new items and
// record the same.
net_new_turn_input
};
let turn_input_messages: Vec<String> = turn_input
.iter()
@@ -971,19 +990,8 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
},
);
}
(
ResponseItem::Reasoning {
id,
summary,
encrypted_content,
},
None,
) => {
items_to_record_in_conversation_history.push(ResponseItem::Reasoning {
id: id.clone(),
summary: summary.clone(),
encrypted_content: encrypted_content.clone(),
});
(ResponseItem::Reasoning { .. }, None) => {
// Omit from conversation history.
}
_ => {
warn!("Unexpected response item: {item:?} with response: {response:?}");
@@ -1012,6 +1020,8 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
});
break;
}
input_for_next_turn = responses;
}
Err(e) => {
info!("Turn error: {e:#}");
@@ -1039,13 +1049,27 @@ async fn run_turn(
sub_id: String,
input: Vec<ResponseItem>,
) -> CodexResult<Vec<ProcessedResponseItem>> {
// Decide whether to use server-side storage (previous_response_id) or disable it
let (prev_id, store) = {
let state = sess.state.lock().unwrap();
let store = state.zdr_transcript.is_none();
let prev_id = if store {
state.previous_response_id.clone()
} else {
// When using ZDR, the Responses API may send previous_response_id
// back, but trying to use it results in a 400.
None
};
(prev_id, store)
};
let extra_tools = sess.mcp_connection_manager.list_all_tools();
let prompt = Prompt {
input,
user_instructions: sess.user_instructions.clone(),
store: !sess.disable_response_storage,
prev_id,
user_instructions: sess.instructions.clone(),
store,
extra_tools,
base_instructions_override: sess.base_instructions.clone(),
};
let mut retries = 0;
@@ -1117,17 +1141,11 @@ async fn try_run_turn(
// This usually happens because the user interrupted the model before we responded to one of its tool calls
// and then the user sent a follow-up message.
let missing_calls = {
prompt
.input
sess.state
.lock()
.unwrap()
.pending_call_ids
.iter()
.filter_map(|ri| match ri {
ResponseItem::FunctionCall { call_id, .. } => Some(call_id),
ResponseItem::LocalShellCall {
call_id: Some(call_id),
..
} => Some(call_id),
_ => None,
})
.filter_map(|call_id| {
if completed_call_ids.contains(&call_id) {
None
@@ -1181,14 +1199,31 @@ async fn try_run_turn(
};
match event {
ResponseEvent::Created => {}
ResponseEvent::Created => {
let mut state = sess.state.lock().unwrap();
// We successfully created a new response and ensured that all pending calls were included so we can clear the pending call ids.
state.pending_call_ids.clear();
}
ResponseEvent::OutputItemDone(item) => {
let call_id = match &item {
ResponseItem::LocalShellCall {
call_id: Some(call_id),
..
} => Some(call_id),
ResponseItem::FunctionCall { call_id, .. } => Some(call_id),
_ => None,
};
if let Some(call_id) = call_id {
// We just got a new call id so we need to make sure to respond to it in the next turn.
let mut state = sess.state.lock().unwrap();
state.pending_call_ids.insert(call_id.clone());
}
let response = handle_response_item(sess, sub_id, item.clone()).await?;
output.push(ProcessedResponseItem { item, response });
}
ResponseEvent::Completed {
response_id: _,
response_id,
token_usage,
} => {
if let Some(token_usage) = token_usage {
@@ -1201,6 +1236,8 @@ async fn try_run_turn(
.ok();
}
let mut state = sess.state.lock().unwrap();
state.previous_response_id = Some(response_id);
return Ok(output);
}
ResponseEvent::OutputTextDelta(delta) => {
@@ -1240,7 +1277,7 @@ async fn handle_response_item(
}
None
}
ResponseItem::Reasoning { summary, .. } => {
ResponseItem::Reasoning { id: _, summary } => {
for item in summary {
let text = match item {
ReasoningItemReasoningSummary::SummaryText { text } => text,
@@ -1257,7 +1294,6 @@ async fn handle_response_item(
name,
arguments,
call_id,
..
} => {
info!("FunctionCall: {arguments}");
Some(handle_function_call(sess, sub_id.to_string(), name, arguments, call_id).await)
@@ -1428,7 +1464,6 @@ async fn handle_container_exec_with_params(
let rx_approve = sess
.request_command_approval(
sub_id.clone(),
call_id.clone(),
params.command.clone(),
params.cwd.clone(),
None,
@@ -1556,7 +1591,6 @@ async fn handle_sandbox_error(
let rx_approve = sess
.request_command_approval(
sub_id.clone(),
call_id.clone(),
params.command.clone(),
params.cwd.clone(),
Some("command failed; retry without sandbox?".to_string()),
@@ -1574,7 +1608,9 @@ async fn handle_sandbox_error(
sess.notify_background_event(&sub_id, "retrying command without sandbox")
.await;
sess.notify_exec_command_begin(&sub_id, &call_id, &params)
// Emit a fresh Begin event so progress bars reset.
let retry_call_id = format!("{call_id}-retry");
sess.notify_exec_command_begin(&sub_id, &retry_call_id, &params)
.await;
// This is an escalated retry; the policy will not be
@@ -1597,8 +1633,14 @@ async fn handle_sandbox_error(
duration,
} = retry_output;
sess.notify_exec_command_end(&sub_id, &call_id, &stdout, &stderr, exit_code)
.await;
sess.notify_exec_command_end(
&sub_id,
&retry_call_id,
&stdout,
&stderr,
exit_code,
)
.await;
let is_success = exit_code == 0;
let content = format_exec_output(
@@ -1662,7 +1704,7 @@ async fn apply_patch(
// Compute a readable summary of path changes to include in the
// approval request so the user can make an informed decision.
let rx_approve = sess
.request_patch_approval(sub_id.clone(), call_id.clone(), &action, None, None)
.request_patch_approval(sub_id.clone(), &action, None, None)
.await;
match rx_approve.await.unwrap_or_default() {
ReviewDecision::Approved | ReviewDecision::ApprovedForSession => false,
@@ -1700,13 +1742,7 @@ async fn apply_patch(
));
let rx = sess
.request_patch_approval(
sub_id.clone(),
call_id.clone(),
&action,
reason.clone(),
Some(root.clone()),
)
.request_patch_approval(sub_id.clone(), &action, reason.clone(), Some(root.clone()))
.await;
if !matches!(
@@ -1790,7 +1826,6 @@ async fn apply_patch(
let rx = sess
.request_patch_approval(
sub_id.clone(),
call_id.clone(),
&action,
reason.clone(),
Some(root.clone()),
@@ -2049,7 +2084,7 @@ fn format_exec_output(output: &str, exit_code: i32, duration: Duration) -> Strin
fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<String> {
responses.iter().rev().find_map(|item| {
if let ResponseItem::Message { role, content, .. } = item {
if let ResponseItem::Message { role, content } = item {
if role == "assistant" {
content.iter().rev().find_map(|ci| {
if let ContentItem::OutputText { text } = ci {
@@ -2066,3 +2101,15 @@ fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<St
}
})
}
/// See [`ConversationHistory`] for details.
fn record_conversation_history(disable_response_storage: bool, wire_api: WireApi) -> bool {
if disable_response_storage {
return true;
}
match wire_api {
WireApi::Responses => false,
WireApi::Chat => true,
}
}

View File

@@ -63,10 +63,7 @@ pub struct Config {
pub disable_response_storage: bool,
/// User-provided instructions from instructions.md.
pub user_instructions: Option<String>,
/// Base instructions override.
pub base_instructions: Option<String>,
pub instructions: Option<String>,
/// Optional external notifier command. When set, Codex will spawn this
/// program after each completed *turn* (i.e. when the agent finishes
@@ -330,9 +327,6 @@ pub struct ConfigToml {
/// Experimental rollout resume path (absolute path to .jsonl; undocumented).
pub experimental_resume: Option<PathBuf>,
/// Experimental path to a file whose contents replace the built-in BASE_INSTRUCTIONS.
pub experimental_instructions_file: Option<PathBuf>,
}
impl ConfigToml {
@@ -365,7 +359,6 @@ pub struct ConfigOverrides {
pub model_provider: Option<String>,
pub config_profile: Option<String>,
pub codex_linux_sandbox_exe: Option<PathBuf>,
pub base_instructions: Option<String>,
}
impl Config {
@@ -376,7 +369,7 @@ impl Config {
overrides: ConfigOverrides,
codex_home: PathBuf,
) -> std::io::Result<Self> {
let user_instructions = Self::load_instructions(Some(&codex_home));
let instructions = Self::load_instructions(Some(&codex_home));
// Destructure ConfigOverrides fully to ensure all overrides are applied.
let ConfigOverrides {
@@ -387,7 +380,6 @@ impl Config {
model_provider,
config_profile: config_profile_key,
codex_linux_sandbox_exe,
base_instructions,
} = overrides;
let config_profile = match config_profile_key.as_ref().or(cfg.profile.as_ref()) {
@@ -465,10 +457,6 @@ impl Config {
let experimental_resume = cfg.experimental_resume;
let base_instructions = base_instructions.or(Self::get_base_instructions(
cfg.experimental_instructions_file.as_ref(),
));
let config = Self {
model,
model_context_window,
@@ -487,8 +475,7 @@ impl Config {
.or(cfg.disable_response_storage)
.unwrap_or(false),
notify: cfg.notify,
user_instructions,
base_instructions,
instructions,
mcp_servers: cfg.mcp_servers,
model_providers,
project_doc_max_bytes: cfg.project_doc_max_bytes.unwrap_or(PROJECT_DOC_MAX_BYTES),
@@ -538,15 +525,6 @@ impl Config {
}
})
}
fn get_base_instructions(path: Option<&PathBuf>) -> Option<String> {
let path = path.as_ref()?;
std::fs::read_to_string(path)
.ok()
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
}
}
fn default_model() -> String {
@@ -823,7 +801,7 @@ disable_response_storage = true
sandbox_policy: SandboxPolicy::new_read_only_policy(),
shell_environment_policy: ShellEnvironmentPolicy::default(),
disable_response_storage: false,
user_instructions: None,
instructions: None,
notify: None,
cwd: fixture.cwd(),
mcp_servers: HashMap::new(),
@@ -840,7 +818,6 @@ disable_response_storage = true
model_supports_reasoning_summaries: false,
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
experimental_resume: None,
base_instructions: None,
},
o3_profile_config
);
@@ -871,7 +848,7 @@ disable_response_storage = true
sandbox_policy: SandboxPolicy::new_read_only_policy(),
shell_environment_policy: ShellEnvironmentPolicy::default(),
disable_response_storage: false,
user_instructions: None,
instructions: None,
notify: None,
cwd: fixture.cwd(),
mcp_servers: HashMap::new(),
@@ -888,7 +865,6 @@ disable_response_storage = true
model_supports_reasoning_summaries: false,
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
experimental_resume: None,
base_instructions: None,
};
assert_eq!(expected_gpt3_profile_config, gpt3_profile_config);
@@ -934,7 +910,7 @@ disable_response_storage = true
sandbox_policy: SandboxPolicy::new_read_only_policy(),
shell_environment_policy: ShellEnvironmentPolicy::default(),
disable_response_storage: true,
user_instructions: None,
instructions: None,
notify: None,
cwd: fixture.cwd(),
mcp_servers: HashMap::new(),
@@ -951,7 +927,6 @@ disable_response_storage = true
model_supports_reasoning_summaries: false,
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
experimental_resume: None,
base_instructions: None,
};
assert_eq!(expected_zdr_profile_config, zdr_profile_config);

View File

@@ -1,7 +1,12 @@
use crate::models::ResponseItem;
/// Transcript of conversation history
#[derive(Debug, Clone, Default)]
/// Transcript of conversation history that is needed:
/// - for ZDR clients for which previous_response_id is not available, so we
/// must include the transcript with every API call. This must include each
/// `function_call` and its corresponding `function_call_output`.
/// - for clients using the "chat completions" API as opposed to the
/// "responses" API.
#[derive(Debug, Clone)]
pub(crate) struct ConversationHistory {
/// The oldest items are at the beginning of the vector.
items: Vec<ResponseItem>,
@@ -39,8 +44,7 @@ fn is_api_message(message: &ResponseItem) -> bool {
ResponseItem::Message { role, .. } => role.as_str() != "system",
ResponseItem::FunctionCallOutput { .. }
| ResponseItem::FunctionCall { .. }
| ResponseItem::LocalShellCall { .. }
| ResponseItem::Reasoning { .. } => true,
ResponseItem::Other => false,
| ResponseItem::LocalShellCall { .. } => true,
ResponseItem::Reasoning { .. } | ResponseItem::Other => false,
}
}

View File

@@ -3,7 +3,6 @@ use std::collections::HashMap;
use base64::Engine;
use mcp_types::CallToolResult;
use serde::Deserialize;
use serde::Deserializer;
use serde::Serialize;
use serde::ser::Serializer;
@@ -38,14 +37,12 @@ pub enum ContentItem {
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseItem {
Message {
id: Option<String>,
role: String,
content: Vec<ContentItem>,
},
Reasoning {
id: String,
summary: Vec<ReasoningItemReasoningSummary>,
encrypted_content: Option<String>,
},
LocalShellCall {
/// Set when using the chat completions API.
@@ -56,7 +53,6 @@ pub enum ResponseItem {
action: LocalShellAction,
},
FunctionCall {
id: Option<String>,
name: String,
// The Responses API returns the function call arguments as a *string* that contains
// JSON, not as an alreadyparsed object. We keep it as a raw string here and let
@@ -82,11 +78,7 @@ pub enum ResponseItem {
impl From<ResponseInputItem> for ResponseItem {
fn from(item: ResponseInputItem) -> Self {
match item {
ResponseInputItem::Message { role, content } => Self::Message {
role,
content,
id: None,
},
ResponseInputItem::Message { role, content } => Self::Message { role, content },
ResponseInputItem::FunctionCallOutput { call_id, output } => {
Self::FunctionCallOutput { call_id, output }
}
@@ -185,7 +177,7 @@ pub struct ShellToolCallParams {
pub timeout_ms: Option<u64>,
}
#[derive(Debug, Clone)]
#[derive(Deserialize, Debug, Clone)]
pub struct FunctionCallOutputPayload {
pub content: String,
#[expect(dead_code)]
@@ -213,19 +205,6 @@ impl Serialize for FunctionCallOutputPayload {
}
}
impl<'de> Deserialize<'de> for FunctionCallOutputPayload {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
Ok(FunctionCallOutputPayload {
content: s,
success: None,
})
}
}
// Implement Display so callers can treat the payload like a plain string when logging or doing
// trivial substring checks in tests (existing tests call `.contains()` on the output). Display
// returns the raw `content` field.

View File

@@ -27,16 +27,16 @@ const PROJECT_DOC_SEPARATOR: &str = "\n\n--- project-doc ---\n\n";
/// string of instructions.
pub(crate) async fn get_user_instructions(config: &Config) -> Option<String> {
match find_project_doc(config).await {
Ok(Some(project_doc)) => match &config.user_instructions {
Ok(Some(project_doc)) => match &config.instructions {
Some(original_instructions) => Some(format!(
"{original_instructions}{PROJECT_DOC_SEPARATOR}{project_doc}"
)),
None => Some(project_doc),
},
Ok(None) => config.user_instructions.clone(),
Ok(None) => config.instructions.clone(),
Err(e) => {
error!("error trying to find project doc: {e:#}");
config.user_instructions.clone()
config.instructions.clone()
}
}
}
@@ -159,7 +159,7 @@ mod tests {
config.cwd = root.path().to_path_buf();
config.project_doc_max_bytes = limit;
config.user_instructions = instructions.map(ToOwned::to_owned);
config.instructions = instructions.map(ToOwned::to_owned);
config
}

View File

@@ -44,12 +44,8 @@ pub enum Op {
model_reasoning_effort: ReasoningEffortConfig,
model_reasoning_summary: ReasoningSummaryConfig,
/// Model instructions that are appended to the base instructions.
user_instructions: Option<String>,
/// Base instructions override.
base_instructions: Option<String>,
/// Model instructions
instructions: Option<String>,
/// When to escalate for approval for execution
approval_policy: AskForApproval,
/// How to sandbox commands executed in the system
@@ -116,9 +112,6 @@ pub enum Op {
/// Request a single history entry identified by `log_id` + `offset`.
GetHistoryEntryRequest { offset: usize, log_id: u64 },
/// Request to shut down codex instance.
Shutdown,
}
/// Determines the conditions under which the user is consulted to approve
@@ -329,9 +322,6 @@ pub enum EventMsg {
/// Response to GetHistoryEntryRequest.
GetHistoryEntryResponse(GetHistoryEntryResponseEvent),
/// Notification that the agent is shutting down.
ShutdownComplete,
}
// Individual event payload types matching each `EventMsg` variant.
@@ -428,8 +418,6 @@ pub struct ExecCommandEndEvent {
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ExecApprovalRequestEvent {
/// Identifier for the associated exec call, if available.
pub call_id: String,
/// The command to be executed.
pub command: Vec<String>,
/// The command's working directory.
@@ -441,8 +429,6 @@ pub struct ExecApprovalRequestEvent {
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ApplyPatchApprovalRequestEvent {
/// Responses API call id for the associated patch apply call, if available.
pub call_id: String,
pub changes: HashMap<PathBuf, FileChange>,
/// Optional explanatory reason (e.g. request for extra write access).
#[serde(skip_serializing_if = "Option::is_none")]

View File

@@ -14,9 +14,7 @@ use time::macros::format_description;
use tokio::io::AsyncWriteExt;
use tokio::sync::mpsc::Sender;
use tokio::sync::mpsc::{self};
use tokio::sync::oneshot;
use tracing::info;
use tracing::warn;
use uuid::Uuid;
use crate::config::Config;
@@ -32,7 +30,9 @@ pub struct SessionMeta {
}
#[derive(Serialize, Deserialize, Default, Clone)]
pub struct SessionStateSnapshot {}
pub struct SessionStateSnapshot {
pub previous_response_id: Option<String>,
}
#[derive(Serialize, Deserialize, Default, Clone)]
pub struct SavedSession {
@@ -58,10 +58,10 @@ pub(crate) struct RolloutRecorder {
tx: Sender<RolloutCmd>,
}
#[derive(Clone)]
enum RolloutCmd {
AddItems(Vec<ResponseItem>),
UpdateState(SessionStateSnapshot),
Shutdown { ack: oneshot::Sender<()> },
}
impl RolloutRecorder {
@@ -119,9 +119,8 @@ impl RolloutRecorder {
ResponseItem::Message { .. }
| ResponseItem::LocalShellCall { .. }
| ResponseItem::FunctionCall { .. }
| ResponseItem::FunctionCallOutput { .. }
| ResponseItem::Reasoning { .. } => filtered.push(item.clone()),
ResponseItem::Other => {
| ResponseItem::FunctionCallOutput { .. } => filtered.push(item.clone()),
ResponseItem::Reasoning { .. } | ResponseItem::Other => {
// These should never be serialized.
continue;
}
@@ -173,17 +172,13 @@ impl RolloutRecorder {
}
continue;
}
match serde_json::from_value::<ResponseItem>(v.clone()) {
Ok(item) => match item {
if let Ok(item) = serde_json::from_value::<ResponseItem>(v.clone()) {
match item {
ResponseItem::Message { .. }
| ResponseItem::LocalShellCall { .. }
| ResponseItem::FunctionCall { .. }
| ResponseItem::FunctionCallOutput { .. }
| ResponseItem::Reasoning { .. } => items.push(item),
ResponseItem::Other => {}
},
Err(e) => {
warn!("failed to parse item: {v:?}, error: {e}");
| ResponseItem::FunctionCallOutput { .. } => items.push(item),
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
}
}
}
@@ -205,21 +200,6 @@ impl RolloutRecorder {
info!("Resumed rollout successfully from {path:?}");
Ok((Self { tx }, saved))
}
pub async fn shutdown(&self) -> std::io::Result<()> {
let (tx_done, rx_done) = oneshot::channel();
match self.tx.send(RolloutCmd::Shutdown { ack: tx_done }).await {
Ok(_) => rx_done
.await
.map_err(|e| IoError::other(format!("failed waiting for rollout shutdown: {e}"))),
Err(e) => {
warn!("failed to send rollout shutdown command: {e}");
Err(IoError::other(format!(
"failed to send rollout shutdown command: {e}"
)))
}
}
}
}
struct LogFileInfo {
@@ -287,14 +267,13 @@ async fn rollout_writer(
ResponseItem::Message { .. }
| ResponseItem::LocalShellCall { .. }
| ResponseItem::FunctionCall { .. }
| ResponseItem::FunctionCallOutput { .. }
| ResponseItem::Reasoning { .. } => {
| ResponseItem::FunctionCallOutput { .. } => {
if let Ok(json) = serde_json::to_string(&item) {
let _ = file.write_all(json.as_bytes()).await;
let _ = file.write_all(b"\n").await;
}
}
ResponseItem::Other => {}
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
}
}
let _ = file.flush().await;
@@ -315,9 +294,6 @@ async fn rollout_writer(
let _ = file.flush().await;
}
}
RolloutCmd::Shutdown { ack } => {
let _ = ack.send(());
}
}
}
}

View File

@@ -1,3 +1,5 @@
use std::time::Duration;
use codex_core::Codex;
use codex_core::ModelProviderInfo;
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
@@ -9,6 +11,7 @@ mod test_support;
use tempfile::TempDir;
use test_support::load_default_config_for_test;
use test_support::load_sse_fixture_with_id;
use tokio::time::timeout;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
@@ -83,15 +86,21 @@ async fn includes_session_id_and_model_headers_in_request() {
.await
.unwrap();
let EventMsg::SessionConfigured(SessionConfiguredEvent { session_id, .. }) =
test_support::wait_for_event(&codex, |ev| matches!(ev, EventMsg::SessionConfigured(_)))
let mut current_session_id = None;
// Wait for TaskComplete
loop {
let ev = timeout(Duration::from_secs(1), codex.next_event())
.await
else {
unreachable!()
};
.unwrap()
.unwrap();
let current_session_id = Some(session_id.to_string());
test_support::wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
if let EventMsg::SessionConfigured(SessionConfiguredEvent { session_id, .. }) = ev.msg {
current_session_id = Some(session_id.to_string());
}
if matches!(ev.msg, EventMsg::TaskComplete(_)) {
break;
}
}
// get request from the server
let request = &server.received_requests().await.unwrap()[0];
@@ -99,76 +108,6 @@ async fn includes_session_id_and_model_headers_in_request() {
let originator = request.headers.get("originator").unwrap();
assert!(current_session_id.is_some());
assert_eq!(
request_body.to_str().unwrap(),
current_session_id.as_ref().unwrap()
);
assert_eq!(request_body.to_str().unwrap(), &current_session_id.unwrap());
assert_eq!(originator.to_str().unwrap(), "codex_cli_rs");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn includes_base_instructions_override_in_request() {
#![allow(clippy::unwrap_used)]
// Mock server
let server = MockServer::start().await;
// First request must NOT include `previous_response_id`.
let first = ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse_completed("resp1"), "text/event-stream");
Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(first)
.expect(1)
.mount(&server)
.await;
let model_provider = ModelProviderInfo {
name: "openai".into(),
base_url: format!("{}/v1", server.uri()),
// Environment variable that should exist in the test environment.
// ModelClient will return an error if the environment variable for the
// provider is not set.
env_key: Some("PATH".into()),
env_key_instructions: None,
wire_api: codex_core::WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: Some(0),
stream_max_retries: Some(0),
stream_idle_timeout_ms: None,
};
let codex_home = TempDir::new().unwrap();
let mut config = load_default_config_for_test(&codex_home);
config.base_instructions = Some("test instructions".to_string());
config.model_provider = model_provider;
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
let (codex, ..) = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
codex
.submit(Op::UserInput {
items: vec![InputItem::Text {
text: "hello".into(),
}],
})
.await
.unwrap();
test_support::wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
let request = &server.received_requests().await.unwrap()[0];
let request_body = request.body_json::<serde_json::Value>().unwrap();
assert!(
request_body["instructions"]
.as_str()
.unwrap()
.contains("test instructions")
);
}

View File

@@ -0,0 +1,165 @@
use std::time::Duration;
use codex_core::Codex;
use codex_core::ModelProviderInfo;
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
use codex_core::protocol::ErrorEvent;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
mod test_support;
use serde_json::Value;
use tempfile::TempDir;
use test_support::load_default_config_for_test;
use test_support::load_sse_fixture_with_id;
use tokio::time::timeout;
use wiremock::Match;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::Request;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;
/// Matcher asserting that JSON body has NO `previous_response_id` field.
struct NoPrevId;
impl Match for NoPrevId {
fn matches(&self, req: &Request) -> bool {
serde_json::from_slice::<Value>(&req.body)
.map(|v| v.get("previous_response_id").is_none())
.unwrap_or(false)
}
}
/// Matcher asserting that JSON body HAS a `previous_response_id` field.
struct HasPrevId;
impl Match for HasPrevId {
fn matches(&self, req: &Request) -> bool {
serde_json::from_slice::<Value>(&req.body)
.map(|v| v.get("previous_response_id").is_some())
.unwrap_or(false)
}
}
/// Build minimal SSE stream with completed marker using the JSON fixture.
fn sse_completed(id: &str) -> String {
load_sse_fixture_with_id("tests/fixtures/completed_template.json", id)
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn keeps_previous_response_id_between_tasks() {
#![allow(clippy::unwrap_used)]
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
println!(
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
);
return;
}
// Mock server
let server = MockServer::start().await;
// First request must NOT include `previous_response_id`.
let first = ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse_completed("resp1"), "text/event-stream");
Mock::given(method("POST"))
.and(path("/v1/responses"))
.and(NoPrevId)
.respond_with(first)
.expect(1)
.mount(&server)
.await;
// Second request MUST include `previous_response_id`.
let second = ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse_completed("resp2"), "text/event-stream");
Mock::given(method("POST"))
.and(path("/v1/responses"))
.and(HasPrevId)
.respond_with(second)
.expect(1)
.mount(&server)
.await;
// Configure retry behavior explicitly to avoid mutating process-wide
// environment variables.
let model_provider = ModelProviderInfo {
name: "openai".into(),
base_url: format!("{}/v1", server.uri()),
// Environment variable that should exist in the test environment.
// ModelClient will return an error if the environment variable for the
// provider is not set.
env_key: Some("PATH".into()),
env_key_instructions: None,
wire_api: codex_core::WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
// disable retries so we don't get duplicate calls in this test
request_max_retries: Some(0),
stream_max_retries: Some(0),
stream_idle_timeout_ms: None,
};
// Init session
let codex_home = TempDir::new().unwrap();
let mut config = load_default_config_for_test(&codex_home);
config.model_provider = model_provider;
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
let (codex, _init_id, _session_id) = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
// Task 1 triggers first request (no previous_response_id)
codex
.submit(Op::UserInput {
items: vec![InputItem::Text {
text: "hello".into(),
}],
})
.await
.unwrap();
// Wait for TaskComplete
loop {
let ev = timeout(Duration::from_secs(1), codex.next_event())
.await
.unwrap()
.unwrap();
if matches!(ev.msg, EventMsg::TaskComplete(_)) {
break;
}
}
// Task 2 should include `previous_response_id` (triggers second request)
codex
.submit(Op::UserInput {
items: vec![InputItem::Text {
text: "again".into(),
}],
})
.await
.unwrap();
// Wait for TaskComplete or error
loop {
let ev = timeout(Duration::from_secs(1), codex.next_event())
.await
.unwrap()
.unwrap();
match ev.msg {
EventMsg::TaskComplete(_) => break,
EventMsg::Error(ErrorEvent { message }) => {
panic!("unexpected error: {message}")
}
_ => {
// Ignore other events.
}
}
}
}

View File

@@ -76,24 +76,3 @@ pub fn load_sse_fixture_with_id(path: impl AsRef<std::path::Path>, id: &str) ->
})
.collect()
}
#[allow(dead_code)]
pub async fn wait_for_event<F>(
codex: &codex_core::Codex,
mut predicate: F,
) -> codex_core::protocol::EventMsg
where
F: FnMut(&codex_core::protocol::EventMsg) -> bool,
{
use tokio::time::Duration;
use tokio::time::timeout;
loop {
let ev = timeout(Duration::from_secs(1), codex.next_event())
.await
.expect("timeout waiting for event")
.expect("stream ended unexpectedly");
if predicate(&ev.msg) {
return ev.msg;
}
}
}

View File

@@ -1,23 +1,15 @@
use std::path::Path;
use codex_common::summarize_sandbox_policy;
use codex_core::WireApi;
use codex_core::config::Config;
use codex_core::model_supports_reasoning_summaries;
use codex_core::protocol::Event;
pub(crate) enum CodexStatus {
Running,
InitiateShutdown,
Shutdown,
}
pub(crate) trait EventProcessor {
/// Print summary of effective configuration and user prompt.
fn print_config_summary(&mut self, config: &Config, prompt: &str);
/// Handle a single event emitted by the agent.
fn process_event(&mut self, event: Event) -> CodexStatus;
fn process_event(&mut self, event: Event);
}
pub(crate) fn create_config_summary_entries(config: &Config) -> Vec<(&'static str, String)> {
@@ -43,28 +35,3 @@ pub(crate) fn create_config_summary_entries(config: &Config) -> Vec<(&'static st
entries
}
pub(crate) fn handle_last_message(
last_agent_message: Option<&str>,
last_message_path: Option<&Path>,
) {
match (last_message_path, last_agent_message) {
(Some(path), Some(msg)) => write_last_message_file(msg, Some(path)),
(Some(path), None) => {
write_last_message_file("", Some(path));
eprintln!(
"Warning: no last agent message; wrote empty content to {}",
path.display()
);
}
(None, _) => eprintln!("Warning: no file to write last message to."),
}
}
fn write_last_message_file(contents: &str, last_message_path: Option<&Path>) {
if let Some(path) = last_message_path {
if let Err(e) = std::fs::write(path, contents) {
eprintln!("Failed to write last message file {path:?}: {e}");
}
}
}

View File

@@ -15,20 +15,16 @@ use codex_core::protocol::McpToolCallEndEvent;
use codex_core::protocol::PatchApplyBeginEvent;
use codex_core::protocol::PatchApplyEndEvent;
use codex_core::protocol::SessionConfiguredEvent;
use codex_core::protocol::TaskCompleteEvent;
use codex_core::protocol::TokenUsage;
use owo_colors::OwoColorize;
use owo_colors::Style;
use shlex::try_join;
use std::collections::HashMap;
use std::io::Write;
use std::path::PathBuf;
use std::time::Instant;
use crate::event_processor::CodexStatus;
use crate::event_processor::EventProcessor;
use crate::event_processor::create_config_summary_entries;
use crate::event_processor::handle_last_message;
/// This should be configurable. When used in CI, users may not want to impose
/// a limit so they can see the full transcript.
@@ -58,15 +54,10 @@ pub(crate) struct EventProcessorWithHumanOutput {
show_agent_reasoning: bool,
answer_started: bool,
reasoning_started: bool,
last_message_path: Option<PathBuf>,
}
impl EventProcessorWithHumanOutput {
pub(crate) fn create_with_ansi(
with_ansi: bool,
config: &Config,
last_message_path: Option<PathBuf>,
) -> Self {
pub(crate) fn create_with_ansi(with_ansi: bool, config: &Config) -> Self {
let call_id_to_command = HashMap::new();
let call_id_to_patch = HashMap::new();
let call_id_to_tool_call = HashMap::new();
@@ -86,7 +77,6 @@ impl EventProcessorWithHumanOutput {
show_agent_reasoning: !config.hide_agent_reasoning,
answer_started: false,
reasoning_started: false,
last_message_path,
}
} else {
Self {
@@ -103,7 +93,6 @@ impl EventProcessorWithHumanOutput {
show_agent_reasoning: !config.hide_agent_reasoning,
answer_started: false,
reasoning_started: false,
last_message_path,
}
}
}
@@ -169,7 +158,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
);
}
fn process_event(&mut self, event: Event) -> CodexStatus {
fn process_event(&mut self, event: Event) {
let Event { id: _, msg } = event;
match msg {
EventMsg::Error(ErrorEvent { message }) => {
@@ -179,16 +168,9 @@ impl EventProcessor for EventProcessorWithHumanOutput {
EventMsg::BackgroundEvent(BackgroundEventEvent { message }) => {
ts_println!(self, "{}", message.style(self.dimmed));
}
EventMsg::TaskStarted => {
EventMsg::TaskStarted | EventMsg::TaskComplete(_) => {
// Ignore.
}
EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => {
handle_last_message(
last_agent_message.as_deref(),
self.last_message_path.as_deref(),
);
return CodexStatus::InitiateShutdown;
}
EventMsg::TokenCount(TokenUsage { total_tokens, .. }) => {
ts_println!(self, "tokens used: {total_tokens}");
}
@@ -203,7 +185,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
}
EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta }) => {
if !self.show_agent_reasoning {
return CodexStatus::Running;
return;
}
if !self.reasoning_started {
ts_println!(
@@ -516,9 +498,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
EventMsg::GetHistoryEntryResponse(_) => {
// Currently ignored in exec output.
}
EventMsg::ShutdownComplete => return CodexStatus::Shutdown,
}
CodexStatus::Running
}
}

View File

@@ -1,24 +1,18 @@
use std::collections::HashMap;
use std::path::PathBuf;
use codex_core::config::Config;
use codex_core::protocol::Event;
use codex_core::protocol::EventMsg;
use codex_core::protocol::TaskCompleteEvent;
use serde_json::json;
use crate::event_processor::CodexStatus;
use crate::event_processor::EventProcessor;
use crate::event_processor::create_config_summary_entries;
use crate::event_processor::handle_last_message;
pub(crate) struct EventProcessorWithJsonOutput {
last_message_path: Option<PathBuf>,
}
pub(crate) struct EventProcessorWithJsonOutput;
impl EventProcessorWithJsonOutput {
pub fn new(last_message_path: Option<PathBuf>) -> Self {
Self { last_message_path }
pub fn new() -> Self {
Self {}
}
}
@@ -39,25 +33,15 @@ impl EventProcessor for EventProcessorWithJsonOutput {
println!("{prompt_json}");
}
fn process_event(&mut self, event: Event) -> CodexStatus {
fn process_event(&mut self, event: Event) {
match event.msg {
EventMsg::AgentMessageDelta(_) | EventMsg::AgentReasoningDelta(_) => {
// Suppress streaming events in JSON mode.
CodexStatus::Running
}
EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => {
handle_last_message(
last_agent_message.as_deref(),
self.last_message_path.as_deref(),
);
CodexStatus::InitiateShutdown
}
EventMsg::ShutdownComplete => CodexStatus::Shutdown,
_ => {
if let Ok(line) = serde_json::to_string(&event) {
println!("{line}");
}
CodexStatus::Running
}
}
}

View File

@@ -5,10 +5,12 @@ mod event_processor_with_json_output;
use std::io::IsTerminal;
use std::io::Read;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
pub use cli::Cli;
use codex_common::load_dotenv;
use codex_core::codex_wrapper;
use codex_core::config::Config;
use codex_core::config::ConfigOverrides;
@@ -27,10 +29,11 @@ use tracing::error;
use tracing::info;
use tracing_subscriber::EnvFilter;
use crate::event_processor::CodexStatus;
use crate::event_processor::EventProcessor;
pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> anyhow::Result<()> {
load_dotenv();
let Cli {
images,
model,
@@ -110,7 +113,6 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)),
model_provider: None,
codex_linux_sandbox_exe,
base_instructions: None,
};
// Parse `-c` overrides.
let cli_kv_overrides = match config_overrides.parse_overrides() {
@@ -123,12 +125,11 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
let config = Config::load_with_cli_overrides(cli_kv_overrides, overrides)?;
let mut event_processor: Box<dyn EventProcessor> = if json_mode {
Box::new(EventProcessorWithJsonOutput::new(last_message_file.clone()))
Box::new(EventProcessorWithJsonOutput::new())
} else {
Box::new(EventProcessorWithHumanOutput::create_with_ansi(
stdout_with_ansi,
&config,
last_message_file.clone(),
))
};
@@ -225,17 +226,40 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
// Run the loop until the task is complete.
while let Some(event) = rx.recv().await {
let shutdown: CodexStatus = event_processor.process_event(event);
match shutdown {
CodexStatus::Running => continue,
CodexStatus::InitiateShutdown => {
codex.submit(Op::Shutdown).await?;
}
CodexStatus::Shutdown => {
break;
let (is_last_event, last_assistant_message) = match &event.msg {
EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => {
(true, last_agent_message.clone())
}
_ => (false, None),
};
event_processor.process_event(event);
if is_last_event {
handle_last_message(last_assistant_message, last_message_file.as_deref())?;
break;
}
}
Ok(())
}
fn handle_last_message(
last_agent_message: Option<String>,
last_message_file: Option<&Path>,
) -> std::io::Result<()> {
match (last_agent_message, last_message_file) {
(Some(last_agent_message), Some(last_message_file)) => {
// Last message and a file to write to.
std::fs::write(last_message_file, last_agent_message)?;
}
(None, Some(last_message_file)) => {
eprintln!(
"Warning: No last message to write to file: {}",
last_message_file.to_string_lossy()
);
}
(_, None) => {
// No last message and no file to write to.
}
}
Ok(())
}

View File

@@ -17,9 +17,7 @@ workspace = true
[dependencies]
anyhow = "1"
clap = { version = "4", features = ["derive"] }
codex-common = { path = "../common", features = ["cli"] }
codex-core = { path = "../core" }
dotenvy = "0.15.7"
tokio = { version = "1", features = ["rt-multi-thread"] }
[dev-dependencies]

View File

@@ -43,10 +43,6 @@ where
crate::run_main();
}
// This modifies the environment, which is not thread-safe, so do this
// before creating any threads/the Tokio runtime.
load_dotenv();
// Regular invocation create a Tokio runtime and execute the provided
// async entry-point.
let runtime = tokio::runtime::Runtime::new()?;
@@ -65,11 +61,3 @@ where
pub fn run_main() -> ! {
panic!("codex-linux-sandbox is only supported on Linux");
}
/// Load env vars from ~/.codex/.env and `$(pwd)/.env`.
fn load_dotenv() {
if let Ok(codex_home) = codex_core::config::find_codex_home() {
dotenvy::from_path(codex_home.join(".env")).ok();
}
dotenvy::dotenv().ok();
}

View File

@@ -16,6 +16,7 @@ workspace = true
[dependencies]
anyhow = "1"
codex-common = { path = "../common", features = ["cli"] }
codex-core = { path = "../core" }
codex-linux-sandbox = { path = "../linux-sandbox" }
mcp-types = { path = "../mcp-types" }

View File

@@ -14,7 +14,7 @@ use std::path::PathBuf;
use crate::json_to_toml::json_to_toml;
/// Client-supplied configuration for a `codex` tool-call.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "kebab-case")]
pub struct CodexToolCallParam {
/// The *initial user prompt* to start the Codex conversation.
@@ -46,10 +46,6 @@ pub struct CodexToolCallParam {
/// CODEX_HOME/config.toml.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub config: Option<HashMap<String, serde_json::Value>>,
/// The set of instructions to use instead of the default ones.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub base_instructions: Option<String>,
}
/// Custom enum mirroring [`AskForApproval`], but has an extra dependency on
@@ -139,7 +135,6 @@ impl CodexToolCallParam {
approval_policy,
sandbox,
config: cli_overrides,
base_instructions,
} = self;
// Build the `ConfigOverrides` recognised by codex-core.
@@ -151,7 +146,6 @@ impl CodexToolCallParam {
sandbox_mode: sandbox.map(Into::into),
model_provider: None,
codex_linux_sandbox_exe,
base_instructions,
};
let cli_overrides = cli_overrides
@@ -168,7 +162,7 @@ impl CodexToolCallParam {
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct CodexToolCallReplyParam {
pub(crate) struct CodexToolCallReplyParam {
/// The *session id* for this conversation.
pub session_id: String,
@@ -274,10 +268,6 @@ mod tests {
"description": "The *initial user prompt* to start the Codex conversation.",
"type": "string"
},
"base-instructions": {
"description": "The set of instructions to use instead of the default ones.",
"type": "string"
},
},
"required": [
"prompt"

View File

@@ -20,7 +20,6 @@ use mcp_types::CallToolResult;
use mcp_types::ContentBlock;
use mcp_types::RequestId;
use mcp_types::TextContent;
use serde_json::json;
use tokio::sync::Mutex;
use uuid::Uuid;
@@ -40,7 +39,6 @@ pub async fn run_codex_tool_session(
config: CodexConfig,
outgoing: Arc<OutgoingMessageSender>,
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
) {
let (codex, first_event, _ctrl_c, session_id) = match init_codex(config).await {
Ok(res) => res,
@@ -75,10 +73,7 @@ pub async fn run_codex_tool_session(
RequestId::String(s) => s.clone(),
RequestId::Integer(n) => n.to_string(),
};
running_requests_id_to_codex_uuid
.lock()
.await
.insert(id.clone(), session_id);
let submission = Submission {
id: sub_id.clone(),
op: Op::UserInput {
@@ -90,12 +85,9 @@ pub async fn run_codex_tool_session(
if let Err(e) = codex.submit_with_id(submission).await {
tracing::error!("Failed to submit initial prompt: {e}");
// unregister the id so we don't keep it in the map
running_requests_id_to_codex_uuid.lock().await.remove(&id);
return;
}
run_codex_tool_session_inner(codex, outgoing, id, running_requests_id_to_codex_uuid).await;
run_codex_tool_session_inner(codex, outgoing, id).await;
}
pub async fn run_codex_tool_session_reply(
@@ -103,13 +95,7 @@ pub async fn run_codex_tool_session_reply(
outgoing: Arc<OutgoingMessageSender>,
request_id: RequestId,
prompt: String,
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
session_id: Uuid,
) {
running_requests_id_to_codex_uuid
.lock()
.await
.insert(request_id.clone(), session_id);
if let Err(e) = codex
.submit(Op::UserInput {
items: vec![InputItem::Text { text: prompt }],
@@ -117,28 +103,15 @@ pub async fn run_codex_tool_session_reply(
.await
{
tracing::error!("Failed to submit user input: {e}");
// unregister the id so we don't keep it in the map
running_requests_id_to_codex_uuid
.lock()
.await
.remove(&request_id);
return;
}
run_codex_tool_session_inner(
codex,
outgoing,
request_id,
running_requests_id_to_codex_uuid,
)
.await;
run_codex_tool_session_inner(codex, outgoing, request_id).await;
}
async fn run_codex_tool_session_inner(
codex: Arc<Codex>,
outgoing: Arc<OutgoingMessageSender>,
request_id: RequestId,
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
) {
let request_id_str = match &request_id {
RequestId::String(s) => s.clone(),
@@ -156,7 +129,6 @@ async fn run_codex_tool_session_inner(
EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
command,
cwd,
call_id,
reason: _,
}) => {
handle_exec_approval_request(
@@ -167,27 +139,16 @@ async fn run_codex_tool_session_inner(
request_id.clone(),
request_id_str.clone(),
event.id.clone(),
call_id,
)
.await;
continue;
}
EventMsg::Error(err_event) => {
// Return a response to conclude the tool call when the Codex session reports an error (e.g., interruption).
let result = json!({
"error": err_event.message,
});
outgoing.send_response(request_id.clone(), result).await;
break;
}
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
call_id,
reason,
grant_root,
changes,
}) => {
handle_patch_approval_request(
call_id,
reason,
grant_root,
changes,
@@ -217,11 +178,6 @@ async fn run_codex_tool_session_inner(
outgoing
.send_response(request_id.clone(), result.into())
.await;
// unregister the id so we don't keep it in the map
running_requests_id_to_codex_uuid
.lock()
.await
.remove(&request_id);
break;
}
EventMsg::SessionConfigured(_) => {
@@ -236,7 +192,8 @@ async fn run_codex_tool_session_inner(
EventMsg::AgentMessage(AgentMessageEvent { .. }) => {
// TODO: think how we want to support this in the MCP
}
EventMsg::TaskStarted
EventMsg::Error(_)
| EventMsg::TaskStarted
| EventMsg::TokenCount(_)
| EventMsg::AgentReasoning(_)
| EventMsg::McpToolCallBegin(_)
@@ -246,8 +203,7 @@ async fn run_codex_tool_session_inner(
| EventMsg::BackgroundEvent(_)
| EventMsg::PatchApplyBegin(_)
| EventMsg::PatchApplyEnd(_)
| EventMsg::GetHistoryEntryResponse(_)
| EventMsg::ShutdownComplete => {
| EventMsg::GetHistoryEntryResponse(_) => {
// For now, we do not do anything extra for these
// events. Note that
// send(codex_event_to_notification(&event)) above has

View File

@@ -32,7 +32,6 @@ pub struct ExecApprovalElicitRequestParams {
pub codex_elicitation: String,
pub codex_mcp_tool_call_id: String,
pub codex_event_id: String,
pub codex_call_id: String,
pub codex_command: Vec<String>,
pub codex_cwd: PathBuf,
}
@@ -46,7 +45,6 @@ pub struct ExecApprovalResponse {
pub decision: ReviewDecision,
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_exec_approval_request(
command: Vec<String>,
cwd: PathBuf,
@@ -55,7 +53,6 @@ pub(crate) async fn handle_exec_approval_request(
request_id: RequestId,
tool_call_id: String,
event_id: String,
call_id: String,
) {
let escaped_command =
shlex::try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" "));
@@ -74,7 +71,6 @@ pub(crate) async fn handle_exec_approval_request(
codex_elicitation: "exec-approval".to_string(),
codex_mcp_tool_call_id: tool_call_id.clone(),
codex_event_id: event_id.clone(),
codex_call_id: call_id,
codex_command: command,
codex_cwd: cwd,
};

View File

@@ -4,6 +4,7 @@
use std::io::Result as IoResult;
use std::path::PathBuf;
use codex_common::load_dotenv;
use mcp_types::JSONRPCMessage;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncWriteExt;
@@ -27,7 +28,6 @@ use crate::outgoing_message::OutgoingMessage;
use crate::outgoing_message::OutgoingMessageSender;
pub use crate::codex_tool_config::CodexToolCallParam;
pub use crate::codex_tool_config::CodexToolCallReplyParam;
pub use crate::exec_approval::ExecApprovalElicitRequestParams;
pub use crate::exec_approval::ExecApprovalResponse;
pub use crate::patch_approval::PatchApprovalElicitRequestParams;
@@ -39,6 +39,8 @@ pub use crate::patch_approval::PatchApprovalResponse;
const CHANNEL_CAPACITY: usize = 128;
pub async fn run_main(codex_linux_sandbox_exe: Option<PathBuf>) -> IoResult<()> {
load_dotenv();
// Install a simple subscriber so `tracing` output is visible. Users can
// control the log level with `RUST_LOG`.
tracing_subscriber::fmt()
@@ -82,7 +84,7 @@ pub async fn run_main(codex_linux_sandbox_exe: Option<PathBuf>) -> IoResult<()>
match msg {
JSONRPCMessage::Request(r) => processor.process_request(r).await,
JSONRPCMessage::Response(r) => processor.process_response(r).await,
JSONRPCMessage::Notification(n) => processor.process_notification(n).await,
JSONRPCMessage::Notification(n) => processor.process_notification(n),
JSONRPCMessage::Error(e) => processor.process_error(e),
}
}

View File

@@ -10,7 +10,6 @@ use crate::outgoing_message::OutgoingMessageSender;
use codex_core::Codex;
use codex_core::config::Config as CodexConfig;
use codex_core::protocol::Submission;
use mcp_types::CallToolRequestParams;
use mcp_types::CallToolResult;
use mcp_types::ClientRequest;
@@ -36,7 +35,6 @@ pub(crate) struct MessageProcessor {
initialized: bool,
codex_linux_sandbox_exe: Option<PathBuf>,
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
}
impl MessageProcessor {
@@ -51,7 +49,6 @@ impl MessageProcessor {
initialized: false,
codex_linux_sandbox_exe,
session_map: Arc::new(Mutex::new(HashMap::new())),
running_requests_id_to_codex_uuid: Arc::new(Mutex::new(HashMap::new())),
}
}
@@ -119,7 +116,7 @@ impl MessageProcessor {
}
/// Handle a fire-and-forget JSON-RPC notification.
pub(crate) async fn process_notification(&mut self, notification: JSONRPCNotification) {
pub(crate) fn process_notification(&mut self, notification: JSONRPCNotification) {
let server_notification = match ServerNotification::try_from(notification) {
Ok(n) => n,
Err(e) => {
@@ -132,7 +129,7 @@ impl MessageProcessor {
// handler so additional logic can be implemented incrementally.
match server_notification {
ServerNotification::CancelledNotification(params) => {
self.handle_cancelled_notification(params).await;
self.handle_cancelled_notification(params);
}
ServerNotification::ProgressNotification(params) => {
self.handle_progress_notification(params);
@@ -382,7 +379,6 @@ impl MessageProcessor {
// Clone outgoing and session map to move into async task.
let outgoing = self.outgoing.clone();
let session_map = self.session_map.clone();
let running_requests_id_to_codex_uuid = self.running_requests_id_to_codex_uuid.clone();
// Spawn an async task to handle the Codex session so that we do not
// block the synchronous message-processing loop.
@@ -394,7 +390,6 @@ impl MessageProcessor {
config,
outgoing,
session_map,
running_requests_id_to_codex_uuid,
)
.await;
});
@@ -469,12 +464,13 @@ impl MessageProcessor {
// Clone outgoing and session map to move into async task.
let outgoing = self.outgoing.clone();
let running_requests_id_to_codex_uuid = self.running_requests_id_to_codex_uuid.clone();
let codex = {
// Spawn an async task to handle the Codex session so that we do not
// block the synchronous message-processing loop.
task::spawn(async move {
let session_map = session_map_mutex.lock().await;
match session_map.get(&session_id).cloned() {
Some(c) => c,
let codex = match session_map.get(&session_id) {
Some(codex) => codex,
None => {
tracing::warn!("Session not found for session_id: {session_id}");
let result = CallToolResult {
@@ -486,32 +482,21 @@ impl MessageProcessor {
is_error: Some(true),
structured_content: None,
};
// unwrap_or_default is fine here because we know the result is valid JSON
outgoing
.send_response(request_id, serde_json::to_value(result).unwrap_or_default())
.await;
return;
}
}
};
};
// Spawn the long-running reply handler.
tokio::spawn({
let codex = codex.clone();
let outgoing = outgoing.clone();
let prompt = prompt.clone();
let running_requests_id_to_codex_uuid = running_requests_id_to_codex_uuid.clone();
async move {
crate::codex_tool_runner::run_codex_tool_session_reply(
codex,
outgoing,
request_id,
prompt,
running_requests_id_to_codex_uuid,
session_id,
)
.await;
}
crate::codex_tool_runner::run_codex_tool_session_reply(
codex.clone(),
outgoing,
request_id,
prompt.clone(),
)
.await;
});
}
@@ -533,58 +518,11 @@ impl MessageProcessor {
// Notification handlers
// ---------------------------------------------------------------------
async fn handle_cancelled_notification(
fn handle_cancelled_notification(
&self,
params: <mcp_types::CancelledNotification as mcp_types::ModelContextProtocolNotification>::Params,
) {
let request_id = params.request_id;
// Create a stable string form early for logging and submission id.
let request_id_string = match &request_id {
RequestId::String(s) => s.clone(),
RequestId::Integer(i) => i.to_string(),
};
// Obtain the session_id while holding the first lock, then release.
let session_id = {
let map_guard = self.running_requests_id_to_codex_uuid.lock().await;
match map_guard.get(&request_id) {
Some(id) => *id, // Uuid is Copy
None => {
tracing::warn!("Session not found for request_id: {}", request_id_string);
return;
}
}
};
tracing::info!("session_id: {session_id}");
// Obtain the Codex Arc while holding the session_map lock, then release.
let codex_arc = {
let sessions_guard = self.session_map.lock().await;
match sessions_guard.get(&session_id) {
Some(codex) => Arc::clone(codex),
None => {
tracing::warn!("Session not found for session_id: {session_id}");
return;
}
}
};
// Submit interrupt to Codex.
let err = codex_arc
.submit_with_id(Submission {
id: request_id_string,
op: codex_core::protocol::Op::Interrupt,
})
.await;
if let Err(e) = err {
tracing::error!("Failed to submit interrupt to Codex: {e}");
return;
}
// unregister the id so we don't keep it in the map
self.running_requests_id_to_codex_uuid
.lock()
.await
.remove(&request_id);
tracing::info!("notifications/cancelled -> params: {:?}", params);
}
fn handle_progress_notification(

View File

@@ -27,7 +27,6 @@ pub struct PatchApprovalElicitRequestParams {
pub codex_elicitation: String,
pub codex_mcp_tool_call_id: String,
pub codex_event_id: String,
pub codex_call_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub codex_reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
@@ -42,7 +41,6 @@ pub struct PatchApprovalResponse {
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_patch_approval_request(
call_id: String,
reason: Option<String>,
grant_root: Option<PathBuf>,
changes: HashMap<PathBuf, FileChange>,
@@ -68,7 +66,6 @@ pub(crate) async fn handle_patch_approval_request(
codex_elicitation: "patch-approval".to_string(),
codex_mcp_tool_call_id: tool_call_id.clone(),
codex_event_id: event_id.clone(),
codex_call_id: call_id,
codex_reason: reason,
codex_grant_root: grant_root,
codex_changes: changes,

View File

@@ -12,7 +12,6 @@ use tokio::process::ChildStdout;
use anyhow::Context;
use assert_cmd::prelude::*;
use codex_mcp_server::CodexToolCallParam;
use codex_mcp_server::CodexToolCallReplyParam;
use mcp_types::CallToolRequestParams;
use mcp_types::ClientCapabilities;
use mcp_types::Implementation;
@@ -142,29 +141,19 @@ impl McpProcess {
/// correlating notifications.
pub async fn send_codex_tool_call(
&mut self,
params: CodexToolCallParam,
) -> anyhow::Result<i64> {
let codex_tool_call_params = CallToolRequestParams {
name: "codex".to_string(),
arguments: Some(serde_json::to_value(params)?),
};
self.send_request(
mcp_types::CallToolRequest::METHOD,
Some(serde_json::to_value(codex_tool_call_params)?),
)
.await
}
pub async fn send_codex_reply_tool_call(
&mut self,
session_id: &str,
cwd: Option<String>,
prompt: &str,
) -> anyhow::Result<i64> {
let codex_tool_call_params = CallToolRequestParams {
name: "codex-reply".to_string(),
arguments: Some(serde_json::to_value(CodexToolCallReplyParam {
name: "codex".to_string(),
arguments: Some(serde_json::to_value(CodexToolCallParam {
cwd,
prompt: prompt.to_string(),
session_id: session_id.to_string(),
model: None,
profile: None,
approval_policy: None,
sandbox: None,
config: None,
})?),
};
self.send_request(
@@ -191,8 +180,6 @@ impl McpProcess {
Ok(request_id)
}
// allow dead code
#[allow(dead_code)]
pub async fn send_response(
&mut self,
id: RequestId,
@@ -220,8 +207,7 @@ impl McpProcess {
let message = serde_json::from_str::<JSONRPCMessage>(&line)?;
Ok(message)
}
// allow dead code
#[allow(dead_code)]
pub async fn read_stream_until_request_message(&mut self) -> anyhow::Result<JSONRPCRequest> {
loop {
let message = self.read_jsonrpc_message().await?;
@@ -244,8 +230,6 @@ impl McpProcess {
}
}
// allow dead code
#[allow(dead_code)]
pub async fn read_stream_until_response_message(
&mut self,
request_id: RequestId,
@@ -272,58 +256,4 @@ impl McpProcess {
}
}
}
pub async fn read_stream_until_configured_response_message(
&mut self,
) -> anyhow::Result<String> {
loop {
let message = self.read_jsonrpc_message().await?;
eprint!("message: {message:?}");
match message {
JSONRPCMessage::Notification(notification) => {
if notification.method == "codex/event" {
if let Some(params) = notification.params {
if let Some(msg) = params.get("msg") {
if let Some(msg_type) = msg.get("type") {
if msg_type == "session_configured" {
if let Some(session_id) = msg.get("session_id") {
return Ok(session_id
.to_string()
.trim_matches('"')
.to_string());
}
}
}
}
}
}
}
JSONRPCMessage::Request(_) => {
anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}");
}
JSONRPCMessage::Error(_) => {
anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}");
}
JSONRPCMessage::Response(_) => {
anyhow::bail!("unexpected JSONRPCMessage::Response: {message:?}");
}
}
}
}
// allow dead code
#[allow(dead_code)]
pub async fn send_notification(
&mut self,
method: &str,
params: Option<serde_json::Value>,
) -> anyhow::Result<()> {
self.send_jsonrpc_message(JSONRPCMessage::Notification(JSONRPCNotification {
jsonrpc: JSONRPC_VERSION.into(),
method: method.to_string(),
params,
}))
.await
}
}

View File

@@ -4,8 +4,6 @@ mod responses;
pub use mcp_process::McpProcess;
pub use mock_model_server::create_mock_chat_completions_server;
#[allow(unused_imports)]
pub use responses::create_apply_patch_sse_response;
#[allow(unused_imports)]
pub use responses::create_final_assistant_message_sse_response;
pub use responses::create_shell_sse_response;

View File

@@ -39,8 +39,6 @@ pub fn create_shell_sse_response(
Ok(sse)
}
// allow dead code
#[allow(dead_code)]
pub fn create_final_assistant_message_sse_response(message: &str) -> anyhow::Result<String> {
let assistant_message = json!({
"choices": [
@@ -60,8 +58,6 @@ pub fn create_final_assistant_message_sse_response(message: &str) -> anyhow::Res
Ok(sse)
}
// allow dead code
#[allow(dead_code)]
pub fn create_apply_patch_sse_response(
patch_content: &str,
call_id: &str,

View File

@@ -8,7 +8,6 @@ use std::path::PathBuf;
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
use codex_core::protocol::FileChange;
use codex_core::protocol::ReviewDecision;
use codex_mcp_server::CodexToolCallParam;
use codex_mcp_server::ExecApprovalElicitRequestParams;
use codex_mcp_server::ExecApprovalResponse;
use codex_mcp_server::PatchApprovalElicitRequestParams;
@@ -77,10 +76,7 @@ async fn shell_command_approval_triggers_elicitation() -> anyhow::Result<()> {
// In turn, it should reply with a tool call, which the MCP should forward
// as an elicitation.
let codex_request_id = mcp_process
.send_codex_tool_call(CodexToolCallParam {
prompt: "run `git init`".to_string(),
..Default::default()
})
.send_codex_tool_call(None, "run `git init`")
.await?;
let elicitation_request = timeout(
DEFAULT_READ_TIMEOUT,
@@ -171,7 +167,6 @@ fn create_expected_elicitation_request(
codex_event_id,
codex_command: command,
codex_cwd: workdir.to_path_buf(),
codex_call_id: "call1234".to_string(),
})?),
})
}
@@ -214,11 +209,10 @@ async fn patch_approval_triggers_elicitation() -> anyhow::Result<()> {
// Send a "codex" tool request that will trigger the apply_patch command
let codex_request_id = mcp_process
.send_codex_tool_call(CodexToolCallParam {
cwd: Some(cwd.path().to_string_lossy().to_string()),
prompt: "please modify the test file".to_string(),
..Default::default()
})
.send_codex_tool_call(
Some(cwd.path().to_string_lossy().to_string()),
"please modify the test file",
)
.await?;
let elicitation_request = timeout(
DEFAULT_READ_TIMEOUT,
@@ -285,75 +279,6 @@ async fn patch_approval_triggers_elicitation() -> anyhow::Result<()> {
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_codex_tool_passes_base_instructions() {
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
println!(
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
);
return;
}
// Apparently `#[tokio::test]` must return `()`, so we create a helper
// function that returns `Result` so we can use `?` in favor of `unwrap`.
if let Err(err) = codex_tool_passes_base_instructions().await {
panic!("failure: {err}");
}
}
async fn codex_tool_passes_base_instructions() -> anyhow::Result<()> {
#![allow(clippy::unwrap_used)]
let server =
create_mock_chat_completions_server(vec![create_final_assistant_message_sse_response(
"Enjoy!",
)?])
.await;
// Run `codex mcp` with a specific config.toml.
let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), &server.uri())?;
let mut mcp_process = McpProcess::new(codex_home.path()).await?;
timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??;
// Send a "codex" tool request, which should hit the completions endpoint.
let codex_request_id = mcp_process
.send_codex_tool_call(CodexToolCallParam {
prompt: "How are you?".to_string(),
base_instructions: Some("You are a helpful assistant.".to_string()),
..Default::default()
})
.await?;
let codex_response = timeout(
DEFAULT_READ_TIMEOUT,
mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)),
)
.await??;
assert_eq!(
JSONRPCResponse {
jsonrpc: JSONRPC_VERSION.into(),
id: RequestId::Integer(codex_request_id),
result: json!({
"content": [
{
"text": "Enjoy!",
"type": "text"
}
]
}),
},
codex_response
);
let requests = server.received_requests().await.unwrap();
let request = requests[0].body_json::<serde_json::Value>().unwrap();
let instructions = request["messages"][0]["content"].as_str().unwrap();
assert!(instructions.starts_with("You are a helpful assistant."));
Ok(())
}
fn create_expected_patch_approval_elicitation_request(
elicitation_request_id: RequestId,
changes: HashMap<PathBuf, FileChange>,
@@ -385,7 +310,6 @@ fn create_expected_patch_approval_elicitation_request(
codex_reason: reason,
codex_grant_root: grant_root,
codex_changes: changes,
codex_call_id: "call1234".to_string(),
})?),
})
}

View File

@@ -1,176 +0,0 @@
#![cfg(unix)]
mod common;
use std::path::Path;
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
use codex_mcp_server::CodexToolCallParam;
use mcp_types::JSONRPCResponse;
use mcp_types::RequestId;
use serde_json::json;
use tempfile::TempDir;
use tokio::time::timeout;
use crate::common::McpProcess;
use crate::common::create_mock_chat_completions_server;
use crate::common::create_shell_sse_response;
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_shell_command_interruption() {
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
println!(
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
);
return;
}
if let Err(err) = shell_command_interruption().await {
panic!("failure: {err}");
}
}
async fn shell_command_interruption() -> anyhow::Result<()> {
// Use a cross-platform blocking command. On Windows plain `sleep` is not guaranteed to exist
// (MSYS/GNU coreutils may be absent) and the failure causes the tool call to finish immediately,
// which triggers a second model request before the test sends the explicit follow-up. That
// prematurely consumes the second mocked SSE response and leads to a third POST (panic: no response for 2).
// Powershell Start-Sleep is always available on Windows runners. On Unix we keep using `sleep`.
#[cfg(target_os = "windows")]
let shell_command = vec![
"powershell".to_string(),
"-Command".to_string(),
"Start-Sleep -Seconds 60".to_string(),
];
#[cfg(not(target_os = "windows"))]
let shell_command = vec!["sleep".to_string(), "60".to_string()];
let workdir_for_shell_function_call = TempDir::new()?;
// Create mock server with a single SSE response: the long sleep command
let server = create_mock_chat_completions_server(vec![
create_shell_sse_response(
shell_command.clone(),
Some(workdir_for_shell_function_call.path()),
Some(60_000), // 60 seconds timeout in ms
"call_sleep",
)?,
create_shell_sse_response(
shell_command.clone(),
Some(workdir_for_shell_function_call.path()),
Some(60_000), // 60 seconds timeout in ms
"call_sleep",
)?,
])
.await;
// Create Codex configuration
let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), server.uri())?;
let mut mcp_process = McpProcess::new(codex_home.path()).await?;
timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??;
// Send codex tool call that triggers "sleep 60"
let codex_request_id = mcp_process
.send_codex_tool_call(CodexToolCallParam {
cwd: None,
prompt: "First Run: run `sleep 60`".to_string(),
model: None,
profile: None,
approval_policy: None,
sandbox: None,
config: None,
base_instructions: None,
})
.await?;
let session_id = mcp_process
.read_stream_until_configured_response_message()
.await?;
// Give the command a moment to start
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
// Send interrupt notification
mcp_process
.send_notification(
"notifications/cancelled",
Some(json!({ "requestId": codex_request_id })),
)
.await?;
// Expect Codex to return an error or interruption response
let codex_response: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)),
)
.await??;
assert!(
codex_response
.result
.as_object()
.map(|o| o.contains_key("error"))
.unwrap_or(false),
"Expected an interruption or error result, got: {codex_response:?}"
);
let codex_reply_request_id = mcp_process
.send_codex_reply_tool_call(&session_id, "Second Run: run `sleep 60`")
.await?;
// Give the command a moment to start
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
// Send interrupt notification
mcp_process
.send_notification(
"notifications/cancelled",
Some(json!({ "requestId": codex_reply_request_id })),
)
.await?;
// Expect Codex to return an error or interruption response
let codex_response: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp_process.read_stream_until_response_message(RequestId::Integer(codex_reply_request_id)),
)
.await??;
assert!(
codex_response
.result
.as_object()
.map(|o| o.contains_key("error"))
.unwrap_or(false),
"Expected an interruption or error result, got: {codex_response:?}"
);
Ok(())
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
fn create_config_toml(codex_home: &Path, server_uri: String) -> std::io::Result<()> {
let config_toml = codex_home.join("config.toml");
std::fs::write(
config_toml,
format!(
r#"
model = "mock-model"
approval_policy = "never"
sandbox_mode = "danger-full-access"
model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "chat"
request_max_retries = 0
stream_max_retries = 0
"#
),
)
}

View File

@@ -42,8 +42,8 @@ ratatui-image = "8.0.0"
regex-lite = "0.1"
serde_json = { version = "1", features = ["preserve_order"] }
shlex = "1.3.0"
strum = "0.27.2"
strum_macros = "0.27.2"
strum = "0.27.1"
strum_macros = "0.27.1"
tokio = { version = "1", features = [
"io-std",
"macros",

View File

@@ -223,7 +223,9 @@ impl App<'_> {
} => {
match &mut self.app_state {
AppState::Chat { widget } => {
widget.on_ctrl_c();
if widget.on_ctrl_c() {
self.app_event_tx.send(AppEvent::ExitRequest);
}
}
AppState::Login { .. } | AppState::GitWarning { .. } => {
// No-op.

View File

@@ -314,7 +314,6 @@ impl ChatWidget<'_> {
self.bottom_pane.set_task_running(false);
}
EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
call_id: _,
command,
cwd,
reason,
@@ -328,7 +327,6 @@ impl ChatWidget<'_> {
self.bottom_pane.push_approval_request(request);
}
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
call_id: _,
changes,
reason,
grant_root,
@@ -364,7 +362,7 @@ impl ChatWidget<'_> {
cwd: _,
}) => {
self.conversation_history
.reset_or_add_active_exec_command(call_id, command);
.add_active_exec_command(call_id, command);
self.request_redraw();
}
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
@@ -419,9 +417,6 @@ impl ChatWidget<'_> {
self.bottom_pane
.on_history_entry_response(log_id, offset, entry.map(|e| e.text));
}
EventMsg::ShutdownComplete => {
self.app_event_tx.send(AppEvent::ExitRequest);
}
event => {
self.conversation_history
.add_background_event(format!("{event:?}"));
@@ -474,7 +469,6 @@ impl ChatWidget<'_> {
self.reasoning_buffer.clear();
false
} else if self.bottom_pane.ctrl_c_quit_hint_visible() {
self.submit_op(Op::Shutdown);
true
} else {
self.bottom_pane.show_ctrl_c_quit_hint();

View File

@@ -235,30 +235,6 @@ impl ConversationHistoryWidget {
self.add_to_history(HistoryCell::new_active_exec_command(call_id, command));
}
/// If an ActiveExecCommand with the same call_id already exists, replace
/// it with a fresh one (resetting start time and view). Otherwise, add a new entry.
pub fn reset_or_add_active_exec_command(&mut self, call_id: String, command: Vec<String>) {
// Find the most recent matching ActiveExecCommand.
let maybe_idx = self.entries.iter().rposition(|entry| {
if let HistoryCell::ActiveExecCommand { call_id: id, .. } = &entry.cell {
id == &call_id
} else {
false
}
});
if let Some(idx) = maybe_idx {
let width = self.cached_width.get();
self.entries[idx].cell = HistoryCell::new_active_exec_command(call_id.clone(), command);
if width > 0 {
let height = self.entries[idx].cell.height(width);
self.entries[idx].line_count.set(height);
}
} else {
self.add_active_exec_command(call_id, command);
}
}
pub fn add_active_mcp_tool_call(
&mut self,
call_id: String,

View File

@@ -75,7 +75,6 @@ pub fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> std::io::
model_provider: None,
config_profile: cli.config_profile.clone(),
codex_linux_sandbox_exe,
base_instructions: None,
};
// Parse `-c` overrides from the CLI.
let cli_kv_overrides = match cli.config_overrides.parse_overrides() {