Compare commits

...

1 Commits

Author SHA1 Message Date
Michael Bolin
8ccb31a65e refactor: move shell and snapshot code out of codex-core 2026-04-02 01:41:31 -07:00
39 changed files with 1516 additions and 1442 deletions

29
codex-rs/Cargo.lock generated
View File

@@ -1865,8 +1865,10 @@ dependencies = [
"codex-rollout",
"codex-sandboxing",
"codex-secrets",
"codex-shell",
"codex-shell-command",
"codex-shell-escalation",
"codex-shell-snapshot",
"codex-state",
"codex-terminal-detection",
"codex-tools",
@@ -2568,6 +2570,15 @@ dependencies = [
"tracing",
]
[[package]]
name = "codex-shell"
version = "0.0.0"
dependencies = [
"libc",
"serde",
"which 8.0.0",
]
[[package]]
name = "codex-shell-command"
version = "0.0.0"
@@ -2575,6 +2586,7 @@ dependencies = [
"anyhow",
"base64 0.22.1",
"codex-protocol",
"codex-shell",
"codex-utils-absolute-path",
"once_cell",
"pretty_assertions",
@@ -2609,6 +2621,22 @@ dependencies = [
"tracing-subscriber",
]
[[package]]
name = "codex-shell-snapshot"
version = "0.0.0"
dependencies = [
"anyhow",
"codex-otel",
"codex-protocol",
"codex-shell",
"codex-utils-pty",
"libc",
"pretty_assertions",
"tempfile",
"tokio",
"tracing",
]
[[package]]
name = "codex-skills"
version = "0.0.0"
@@ -2668,6 +2696,7 @@ dependencies = [
"codex-code-mode",
"codex-features",
"codex-protocol",
"codex-shell",
"codex-utils-absolute-path",
"codex-utils-pty",
"pretty_assertions",

View File

@@ -24,6 +24,8 @@ members = [
"config",
"shell-command",
"shell-escalation",
"shell",
"shell-snapshot",
"skills",
"core",
"core-skills",
@@ -149,6 +151,8 @@ codex-sandboxing = { path = "sandboxing" }
codex-secrets = { path = "secrets" }
codex-shell-command = { path = "shell-command" }
codex-shell-escalation = { path = "shell-escalation" }
codex-shell = { path = "shell" }
codex-shell-snapshot = { path = "shell-snapshot" }
codex-skills = { path = "skills" }
codex-state = { path = "state" }
codex-stdio-to-uds = { path = "stdio-to-uds" }

View File

@@ -51,6 +51,7 @@ codex-protocol = { workspace = true }
codex-rollout = { workspace = true }
codex-rmcp-client = { workspace = true }
codex-sandboxing = { workspace = true }
codex-shell = { workspace = true }
codex-state = { workspace = true }
codex-terminal-detection = { workspace = true }
codex-tools = { workspace = true }
@@ -64,6 +65,7 @@ codex-utils-plugins = { workspace = true }
codex-utils-pty = { workspace = true }
codex-utils-readiness = { workspace = true }
codex-secrets = { workspace = true }
codex-shell-snapshot = { workspace = true }
codex-utils-string = { workspace = true }
codex-utils-stream-parser = { workspace = true }
codex-utils-template = { workspace = true }

View File

@@ -956,7 +956,7 @@ impl AgentControl {
};
let parent_thread = state.get_thread(*parent_thread_id).await.ok()?;
parent_thread.codex.session.user_shell().shell_snapshot()
parent_thread.codex.session.shell_snapshot()
}
async fn inherited_exec_policy_for_source(

View File

@@ -290,6 +290,7 @@ use crate::rollout::policy::EventPersistenceMode;
use crate::session_startup_prewarm::SessionStartupPrewarmHandle;
use crate::shell;
use crate::shell_snapshot::ShellSnapshot;
use crate::shell_snapshot::spawn_stale_snapshot_cleanup;
use crate::skills_watcher::SkillsWatcher;
use crate::skills_watcher::SkillsWatcherEvent;
use crate::state::ActiveTurn;
@@ -1400,7 +1401,7 @@ impl Session {
windows_sandbox_level: session_configuration.windows_sandbox_level,
})
.with_unified_exec_shell_mode_for_session(
crate::tools::spec::tool_user_shell_type(user_shell),
user_shell.shell_type,
shell_zsh_path,
main_execve_wrapper_exe,
)
@@ -1729,7 +1730,7 @@ impl Session {
);
let use_zsh_fork_shell = config.features.enabled(Feature::ShellZshFork);
let mut default_shell = if let Some(user_shell_override) =
let default_shell = if let Some(user_shell_override) =
session_configuration.user_shell_override.clone()
{
user_shell_override
@@ -1750,25 +1751,24 @@ impl Session {
shell::default_user_shell()
};
// Create the mutable state for the Session.
let shell_snapshot_tx = if config.features.enabled(Feature::ShellSnapshot) {
if let Some(snapshot) = session_configuration.inherited_shell_snapshot.clone() {
let (tx, rx) = watch::channel(Some(snapshot));
default_shell.shell_snapshot = rx;
tx
let (shell_snapshot_tx, shell_snapshot_rx) =
if config.features.enabled(Feature::ShellSnapshot) {
if let Some(snapshot) = session_configuration.inherited_shell_snapshot.clone() {
watch::channel(Some(snapshot))
} else {
let (shell_snapshot_tx, shell_snapshot_rx) = ShellSnapshot::start_snapshotting(
config.codex_home.clone(),
conversation_id,
session_configuration.cwd.to_path_buf(),
default_shell.clone(),
session_telemetry.clone(),
);
spawn_stale_snapshot_cleanup(config.codex_home.clone(), conversation_id);
(shell_snapshot_tx, shell_snapshot_rx)
}
} else {
ShellSnapshot::start_snapshotting(
config.codex_home.clone(),
conversation_id,
session_configuration.cwd.to_path_buf(),
&mut default_shell,
session_telemetry.clone(),
)
}
} else {
let (tx, rx) = watch::channel(None);
default_shell.shell_snapshot = rx;
tx
};
watch::channel(None)
};
let thread_name =
match session_index::find_thread_name_by_id(&config.codex_home, &conversation_id)
.instrument(info_span!(
@@ -1885,6 +1885,7 @@ impl Session {
rollout: Mutex::new(rollout_recorder),
user_shell: Arc::new(default_shell),
shell_snapshot_tx,
shell_snapshot_rx,
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
exec_policy,
auth_manager: Arc::clone(&auth_manager),
@@ -2333,6 +2334,7 @@ impl Session {
self.services.shell_snapshot_tx.clone(),
self.services.session_telemetry.clone(),
);
spawn_stale_snapshot_cleanup(codex_home.to_path_buf(), self.conversation_id);
}
pub(crate) async fn update_settings(
@@ -4225,6 +4227,10 @@ impl Session {
Arc::clone(&self.services.user_shell)
}
pub(crate) fn shell_snapshot(&self) -> Option<Arc<ShellSnapshot>> {
self.services.shell_snapshot_rx.borrow().clone()
}
pub(crate) async fn current_rollout_path(&self) -> Option<PathBuf> {
let recorder = {
let guard = self.services.rollout.lock().await;
@@ -5472,7 +5478,7 @@ async fn spawn_review_thread(
windows_sandbox_level: parent_turn_context.windows_sandbox_level,
})
.with_unified_exec_shell_mode_for_session(
crate::tools::spec::tool_user_shell_type(sess.services.user_shell.as_ref()),
sess.services.user_shell.shell_type,
sess.services.shell_zsh_path.as_ref(),
sess.services.main_execve_wrapper_exe.as_ref(),
)

View File

@@ -2654,6 +2654,7 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) {
);
let skills_watcher = Arc::new(SkillsWatcher::noop());
let (shell_snapshot_tx, shell_snapshot_rx) = watch::channel(None);
let services = SessionServices {
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::new_uninitialized(
&config.permissions.approval_policy,
@@ -2675,7 +2676,8 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) {
}),
rollout: Mutex::new(None),
user_shell: Arc::new(default_user_shell()),
shell_snapshot_tx: watch::channel(None).0,
shell_snapshot_tx,
shell_snapshot_rx,
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
exec_policy,
auth_manager: auth_manager.clone(),
@@ -3491,6 +3493,7 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx(
);
let skills_watcher = Arc::new(SkillsWatcher::noop());
let (shell_snapshot_tx, shell_snapshot_rx) = watch::channel(None);
let services = SessionServices {
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::new_uninitialized(
&config.permissions.approval_policy,
@@ -3512,7 +3515,8 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx(
}),
rollout: Mutex::new(None),
user_shell: Arc::new(default_user_shell()),
shell_snapshot_tx: watch::channel(None).0,
shell_snapshot_tx,
shell_snapshot_rx,
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
exec_policy,
auth_manager: Arc::clone(&auth_manager),

View File

@@ -8,7 +8,6 @@ fn fake_shell() -> Shell {
Shell {
shell_type: ShellType::Bash,
shell_path: PathBuf::from("/bin/bash"),
shell_snapshot: crate::shell::empty_shell_snapshot_receiver(),
}
}
@@ -222,7 +221,6 @@ fn equals_except_shell_ignores_shell() {
Shell {
shell_type: ShellType::Bash,
shell_path: "/bin/bash".into(),
shell_snapshot: crate::shell::empty_shell_snapshot_receiver(),
},
/*current_date*/ None,
/*timezone*/ None,
@@ -234,7 +232,6 @@ fn equals_except_shell_ignores_shell() {
Shell {
shell_type: ShellType::Zsh,
shell_path: "/bin/zsh".into(),
shell_snapshot: crate::shell::empty_shell_snapshot_receiver(),
},
/*current_date*/ None,
/*timezone*/ None,

View File

@@ -72,7 +72,6 @@ mod sandbox_tags;
pub mod sandboxing;
mod session_prefix;
mod session_startup_prewarm;
mod shell_detect;
pub mod skills;
pub(crate) use skills::SkillError;
pub(crate) use skills::SkillInjections;

View File

@@ -1,385 +1,6 @@
use crate::shell_detect::detect_shell_type;
use crate::shell_snapshot::ShellSnapshot;
use serde::Deserialize;
use serde::Serialize;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::watch;
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub enum ShellType {
Zsh,
Bash,
PowerShell,
Sh,
Cmd,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Shell {
pub(crate) shell_type: ShellType,
pub(crate) shell_path: PathBuf,
#[serde(
skip_serializing,
skip_deserializing,
default = "empty_shell_snapshot_receiver"
)]
pub(crate) shell_snapshot: watch::Receiver<Option<Arc<ShellSnapshot>>>,
}
impl Shell {
pub fn name(&self) -> &'static str {
match self.shell_type {
ShellType::Zsh => "zsh",
ShellType::Bash => "bash",
ShellType::PowerShell => "powershell",
ShellType::Sh => "sh",
ShellType::Cmd => "cmd",
}
}
/// Takes a string of shell and returns the full list of command args to
/// use with `exec()` to run the shell command.
pub fn derive_exec_args(&self, command: &str, use_login_shell: bool) -> Vec<String> {
match self.shell_type {
ShellType::Zsh | ShellType::Bash | ShellType::Sh => {
let arg = if use_login_shell { "-lc" } else { "-c" };
vec![
self.shell_path.to_string_lossy().to_string(),
arg.to_string(),
command.to_string(),
]
}
ShellType::PowerShell => {
let mut args = vec![self.shell_path.to_string_lossy().to_string()];
if !use_login_shell {
args.push("-NoProfile".to_string());
}
args.push("-Command".to_string());
args.push(command.to_string());
args
}
ShellType::Cmd => {
let mut args = vec![self.shell_path.to_string_lossy().to_string()];
args.push("/c".to_string());
args.push(command.to_string());
args
}
}
}
/// Return the shell snapshot if existing.
pub fn shell_snapshot(&self) -> Option<Arc<ShellSnapshot>> {
self.shell_snapshot.borrow().clone()
}
}
pub(crate) fn empty_shell_snapshot_receiver() -> watch::Receiver<Option<Arc<ShellSnapshot>>> {
let (_tx, rx) = watch::channel(None);
rx
}
impl PartialEq for Shell {
fn eq(&self, other: &Self) -> bool {
self.shell_type == other.shell_type && self.shell_path == other.shell_path
}
}
impl Eq for Shell {}
#[cfg(unix)]
fn get_user_shell_path() -> Option<PathBuf> {
let uid = unsafe { libc::getuid() };
use std::ffi::CStr;
use std::mem::MaybeUninit;
use std::ptr;
let mut passwd = MaybeUninit::<libc::passwd>::uninit();
// We cannot use getpwuid here: it returns pointers into libc-managed
// storage, which is not safe to read concurrently on all targets (the musl
// static build used by the CLI can segfault when parallel callers race on
// that buffer). getpwuid_r keeps the passwd data in caller-owned memory.
let suggested_buffer_len = unsafe { libc::sysconf(libc::_SC_GETPW_R_SIZE_MAX) };
let buffer_len = usize::try_from(suggested_buffer_len)
.ok()
.filter(|len| *len > 0)
.unwrap_or(1024);
let mut buffer = vec![0; buffer_len];
loop {
let mut result = ptr::null_mut();
let status = unsafe {
libc::getpwuid_r(
uid,
passwd.as_mut_ptr(),
buffer.as_mut_ptr().cast(),
buffer.len(),
&mut result,
)
};
if status == 0 {
if result.is_null() {
return None;
}
let passwd = unsafe { passwd.assume_init_ref() };
if passwd.pw_shell.is_null() {
return None;
}
let shell_path = unsafe { CStr::from_ptr(passwd.pw_shell) }
.to_string_lossy()
.into_owned();
return Some(PathBuf::from(shell_path));
}
if status != libc::ERANGE {
return None;
}
// Retry with a larger buffer until libc can materialize the passwd entry.
let new_len = buffer.len().checked_mul(2)?;
if new_len > 1024 * 1024 {
return None;
}
buffer.resize(new_len, 0);
}
}
#[cfg(not(unix))]
fn get_user_shell_path() -> Option<PathBuf> {
None
}
fn file_exists(path: &PathBuf) -> Option<PathBuf> {
if std::fs::metadata(path).is_ok_and(|metadata| metadata.is_file()) {
Some(PathBuf::from(path))
} else {
None
}
}
fn get_shell_path(
shell_type: ShellType,
provided_path: Option<&PathBuf>,
binary_name: &str,
fallback_paths: Vec<&str>,
) -> Option<PathBuf> {
// If exact provided path exists, use it
if provided_path.and_then(file_exists).is_some() {
return provided_path.cloned();
}
// Check if the shell we are trying to load is user's default shell
// if just use it
let default_shell_path = get_user_shell_path();
if let Some(default_shell_path) = default_shell_path
&& detect_shell_type(&default_shell_path) == Some(shell_type)
&& file_exists(&default_shell_path).is_some()
{
return Some(default_shell_path);
}
if let Ok(path) = which::which(binary_name) {
return Some(path);
}
for path in fallback_paths {
//check exists
if let Some(path) = file_exists(&PathBuf::from(path)) {
return Some(path);
}
}
None
}
fn get_zsh_shell(path: Option<&PathBuf>) -> Option<Shell> {
let shell_path = get_shell_path(ShellType::Zsh, path, "zsh", vec!["/bin/zsh"]);
shell_path.map(|shell_path| Shell {
shell_type: ShellType::Zsh,
shell_path,
shell_snapshot: empty_shell_snapshot_receiver(),
})
}
fn get_bash_shell(path: Option<&PathBuf>) -> Option<Shell> {
let shell_path = get_shell_path(ShellType::Bash, path, "bash", vec!["/bin/bash"]);
shell_path.map(|shell_path| Shell {
shell_type: ShellType::Bash,
shell_path,
shell_snapshot: empty_shell_snapshot_receiver(),
})
}
fn get_sh_shell(path: Option<&PathBuf>) -> Option<Shell> {
let shell_path = get_shell_path(ShellType::Sh, path, "sh", vec!["/bin/sh"]);
shell_path.map(|shell_path| Shell {
shell_type: ShellType::Sh,
shell_path,
shell_snapshot: empty_shell_snapshot_receiver(),
})
}
fn get_powershell_shell(path: Option<&PathBuf>) -> Option<Shell> {
let shell_path = get_shell_path(
ShellType::PowerShell,
path,
"pwsh",
vec!["/usr/local/bin/pwsh"],
)
.or_else(|| get_shell_path(ShellType::PowerShell, path, "powershell", vec![]));
shell_path.map(|shell_path| Shell {
shell_type: ShellType::PowerShell,
shell_path,
shell_snapshot: empty_shell_snapshot_receiver(),
})
}
fn get_cmd_shell(path: Option<&PathBuf>) -> Option<Shell> {
let shell_path = get_shell_path(ShellType::Cmd, path, "cmd", vec![]);
shell_path.map(|shell_path| Shell {
shell_type: ShellType::Cmd,
shell_path,
shell_snapshot: empty_shell_snapshot_receiver(),
})
}
fn ultimate_fallback_shell() -> Shell {
if cfg!(windows) {
Shell {
shell_type: ShellType::Cmd,
shell_path: PathBuf::from("cmd.exe"),
shell_snapshot: empty_shell_snapshot_receiver(),
}
} else {
Shell {
shell_type: ShellType::Sh,
shell_path: PathBuf::from("/bin/sh"),
shell_snapshot: empty_shell_snapshot_receiver(),
}
}
}
pub fn get_shell_by_model_provided_path(shell_path: &PathBuf) -> Shell {
detect_shell_type(shell_path)
.and_then(|shell_type| get_shell(shell_type, Some(shell_path)))
.unwrap_or(ultimate_fallback_shell())
}
pub fn get_shell(shell_type: ShellType, path: Option<&PathBuf>) -> Option<Shell> {
match shell_type {
ShellType::Zsh => get_zsh_shell(path),
ShellType::Bash => get_bash_shell(path),
ShellType::PowerShell => get_powershell_shell(path),
ShellType::Sh => get_sh_shell(path),
ShellType::Cmd => get_cmd_shell(path),
}
}
pub fn default_user_shell() -> Shell {
default_user_shell_from_path(get_user_shell_path())
}
fn default_user_shell_from_path(user_shell_path: Option<PathBuf>) -> Shell {
if cfg!(windows) {
get_shell(ShellType::PowerShell, /*path*/ None).unwrap_or(ultimate_fallback_shell())
} else {
let user_default_shell = user_shell_path
.and_then(|shell| detect_shell_type(&shell))
.and_then(|shell_type| get_shell(shell_type, /*path*/ None));
let shell_with_fallback = if cfg!(target_os = "macos") {
user_default_shell
.or_else(|| get_shell(ShellType::Zsh, /*path*/ None))
.or_else(|| get_shell(ShellType::Bash, /*path*/ None))
} else {
user_default_shell
.or_else(|| get_shell(ShellType::Bash, /*path*/ None))
.or_else(|| get_shell(ShellType::Zsh, /*path*/ None))
};
shell_with_fallback.unwrap_or(ultimate_fallback_shell())
}
}
#[cfg(test)]
mod detect_shell_type_tests {
use super::*;
#[test]
fn test_detect_shell_type() {
assert_eq!(
detect_shell_type(&PathBuf::from("zsh")),
Some(ShellType::Zsh)
);
assert_eq!(
detect_shell_type(&PathBuf::from("bash")),
Some(ShellType::Bash)
);
assert_eq!(
detect_shell_type(&PathBuf::from("pwsh")),
Some(ShellType::PowerShell)
);
assert_eq!(
detect_shell_type(&PathBuf::from("powershell")),
Some(ShellType::PowerShell)
);
assert_eq!(detect_shell_type(&PathBuf::from("fish")), None);
assert_eq!(detect_shell_type(&PathBuf::from("other")), None);
assert_eq!(
detect_shell_type(&PathBuf::from("/bin/zsh")),
Some(ShellType::Zsh)
);
assert_eq!(
detect_shell_type(&PathBuf::from("/bin/bash")),
Some(ShellType::Bash)
);
assert_eq!(
detect_shell_type(&PathBuf::from("powershell.exe")),
Some(ShellType::PowerShell)
);
assert_eq!(
detect_shell_type(&PathBuf::from(if cfg!(windows) {
"C:\\windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe"
} else {
"/usr/local/bin/pwsh"
})),
Some(ShellType::PowerShell)
);
assert_eq!(
detect_shell_type(&PathBuf::from("pwsh.exe")),
Some(ShellType::PowerShell)
);
assert_eq!(
detect_shell_type(&PathBuf::from("/usr/local/bin/pwsh")),
Some(ShellType::PowerShell)
);
assert_eq!(
detect_shell_type(&PathBuf::from("/bin/sh")),
Some(ShellType::Sh)
);
assert_eq!(detect_shell_type(&PathBuf::from("sh")), Some(ShellType::Sh));
assert_eq!(
detect_shell_type(&PathBuf::from("cmd")),
Some(ShellType::Cmd)
);
assert_eq!(
detect_shell_type(&PathBuf::from("cmd.exe")),
Some(ShellType::Cmd)
);
}
}
#[cfg(test)]
#[cfg(unix)]
#[path = "shell_tests.rs"]
mod tests;
pub use codex_shell::Shell;
pub use codex_shell::ShellType;
pub use codex_shell::default_user_shell;
pub use codex_shell::detect_shell_type;
pub use codex_shell::get_shell;
pub use codex_shell::get_shell_by_model_provided_path;

View File

@@ -1,24 +0,0 @@
use crate::shell::ShellType;
use std::path::Path;
use std::path::PathBuf;
pub(crate) fn detect_shell_type(shell_path: &PathBuf) -> Option<ShellType> {
match shell_path.as_os_str().to_str() {
Some("zsh") => Some(ShellType::Zsh),
Some("sh") => Some(ShellType::Sh),
Some("cmd") => Some(ShellType::Cmd),
Some("bash") => Some(ShellType::Bash),
Some("pwsh") => Some(ShellType::PowerShell),
Some("powershell") => Some(ShellType::PowerShell),
_ => {
let shell_name = shell_path.file_stem();
if let Some(shell_name) = shell_name {
let shell_name_path = Path::new(shell_name);
if shell_name_path != Path::new(shell_path) {
return detect_shell_type(&shell_name_path.to_path_buf());
}
}
None
}
}
}

View File

@@ -1,490 +1,19 @@
use std::io::ErrorKind;
use std::path::Path;
use std::path::PathBuf;
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
use std::time::SystemTime;
use crate::rollout::list::find_thread_path_by_id_str;
use crate::shell::Shell;
use crate::shell::ShellType;
use crate::shell::get_shell;
use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use anyhow::bail;
use codex_otel::SessionTelemetry;
use codex_protocol::ThreadId;
pub use codex_shell_snapshot::SNAPSHOT_DIR;
pub use codex_shell_snapshot::SNAPSHOT_RETENTION;
pub use codex_shell_snapshot::ShellSnapshot;
pub use codex_shell_snapshot::ShellSnapshotReceiver;
pub use codex_shell_snapshot::ShellSnapshotSender;
use codex_shell_snapshot::remove_snapshot_file;
pub use codex_shell_snapshot::snapshot_session_id_from_file_name;
use tokio::fs;
use tokio::process::Command;
use tokio::sync::watch;
use tokio::time::timeout;
use tracing::Instrument;
use tracing::info_span;
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ShellSnapshot {
pub path: PathBuf,
pub cwd: PathBuf,
}
const SNAPSHOT_TIMEOUT: Duration = Duration::from_secs(10);
const SNAPSHOT_RETENTION: Duration = Duration::from_secs(60 * 60 * 24 * 3); // 3 days retention.
const SNAPSHOT_DIR: &str = "shell_snapshots";
const EXCLUDED_EXPORT_VARS: &[&str] = &["PWD", "OLDPWD"];
impl ShellSnapshot {
pub fn start_snapshotting(
codex_home: PathBuf,
session_id: ThreadId,
session_cwd: PathBuf,
shell: &mut Shell,
session_telemetry: SessionTelemetry,
) -> watch::Sender<Option<Arc<ShellSnapshot>>> {
let (shell_snapshot_tx, shell_snapshot_rx) = watch::channel(None);
shell.shell_snapshot = shell_snapshot_rx;
Self::spawn_snapshot_task(
codex_home,
session_id,
session_cwd,
shell.clone(),
shell_snapshot_tx.clone(),
session_telemetry,
);
shell_snapshot_tx
}
pub fn refresh_snapshot(
codex_home: PathBuf,
session_id: ThreadId,
session_cwd: PathBuf,
shell: Shell,
shell_snapshot_tx: watch::Sender<Option<Arc<ShellSnapshot>>>,
session_telemetry: SessionTelemetry,
) {
Self::spawn_snapshot_task(
codex_home,
session_id,
session_cwd,
shell,
shell_snapshot_tx,
session_telemetry,
);
}
fn spawn_snapshot_task(
codex_home: PathBuf,
session_id: ThreadId,
session_cwd: PathBuf,
snapshot_shell: Shell,
shell_snapshot_tx: watch::Sender<Option<Arc<ShellSnapshot>>>,
session_telemetry: SessionTelemetry,
) {
let snapshot_span = info_span!("shell_snapshot", thread_id = %session_id);
tokio::spawn(
async move {
let timer = session_telemetry.start_timer("codex.shell_snapshot.duration_ms", &[]);
let snapshot = ShellSnapshot::try_new(
&codex_home,
session_id,
session_cwd.as_path(),
&snapshot_shell,
)
.await
.map(Arc::new);
let success = snapshot.is_ok();
let success_tag = if success { "true" } else { "false" };
let _ = timer.map(|timer| timer.record(&[("success", success_tag)]));
let mut counter_tags = vec![("success", success_tag)];
if let Some(failure_reason) = snapshot.as_ref().err() {
counter_tags.push(("failure_reason", *failure_reason));
}
session_telemetry.counter("codex.shell_snapshot", /*inc*/ 1, &counter_tags);
let _ = shell_snapshot_tx.send(snapshot.ok());
}
.instrument(snapshot_span),
);
}
async fn try_new(
codex_home: &Path,
session_id: ThreadId,
session_cwd: &Path,
shell: &Shell,
) -> std::result::Result<Self, &'static str> {
// File to store the snapshot
let extension = match shell.shell_type {
ShellType::PowerShell => "ps1",
_ => "sh",
};
let nonce = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map(|duration| duration.as_nanos())
.unwrap_or(0);
let path = codex_home
.join(SNAPSHOT_DIR)
.join(format!("{session_id}.{nonce}.{extension}"));
let temp_path = codex_home
.join(SNAPSHOT_DIR)
.join(format!("{session_id}.tmp-{nonce}"));
// Clean the (unlikely) leaked snapshot files.
let codex_home = codex_home.to_path_buf();
let cleanup_session_id = session_id;
tokio::spawn(async move {
if let Err(err) = cleanup_stale_snapshots(&codex_home, cleanup_session_id).await {
tracing::warn!("Failed to clean up shell snapshots: {err:?}");
}
});
// Make the new snapshot.
let temp_path =
match write_shell_snapshot(shell.shell_type.clone(), &temp_path, session_cwd).await {
Ok(path) => {
tracing::info!("Shell snapshot successfully created: {}", path.display());
path
}
Err(err) => {
tracing::warn!(
"Failed to create shell snapshot for {}: {err:?}",
shell.name()
);
return Err("write_failed");
}
};
let temp_snapshot = Self {
path: temp_path.clone(),
cwd: session_cwd.to_path_buf(),
};
if let Err(err) = validate_snapshot(shell, &temp_snapshot.path, session_cwd).await {
tracing::error!("Shell snapshot validation failed: {err:?}");
remove_snapshot_file(&temp_snapshot.path).await;
return Err("validation_failed");
}
if let Err(err) = fs::rename(&temp_snapshot.path, &path).await {
tracing::warn!("Failed to finalize shell snapshot: {err:?}");
remove_snapshot_file(&temp_snapshot.path).await;
return Err("write_failed");
}
Ok(Self {
path,
cwd: session_cwd.to_path_buf(),
})
}
}
impl Drop for ShellSnapshot {
fn drop(&mut self) {
if let Err(err) = std::fs::remove_file(&self.path) {
tracing::warn!(
"Failed to delete shell snapshot at {:?}: {err:?}",
self.path
);
}
}
}
async fn write_shell_snapshot(
shell_type: ShellType,
output_path: &Path,
cwd: &Path,
) -> Result<PathBuf> {
if shell_type == ShellType::PowerShell || shell_type == ShellType::Cmd {
bail!("Shell snapshot not supported yet for {shell_type:?}");
}
let shell = get_shell(shell_type.clone(), /*path*/ None)
.with_context(|| format!("No available shell for {shell_type:?}"))?;
let raw_snapshot = capture_snapshot(&shell, cwd).await?;
let snapshot = strip_snapshot_preamble(&raw_snapshot)?;
if let Some(parent) = output_path.parent() {
let parent_display = parent.display();
fs::create_dir_all(parent)
.await
.with_context(|| format!("Failed to create snapshot parent {parent_display}"))?;
}
let snapshot_path = output_path.display();
fs::write(output_path, snapshot)
.await
.with_context(|| format!("Failed to write snapshot to {snapshot_path}"))?;
Ok(output_path.to_path_buf())
}
async fn capture_snapshot(shell: &Shell, cwd: &Path) -> Result<String> {
let shell_type = shell.shell_type.clone();
match shell_type {
ShellType::Zsh => run_shell_script(shell, &zsh_snapshot_script(), cwd).await,
ShellType::Bash => run_shell_script(shell, &bash_snapshot_script(), cwd).await,
ShellType::Sh => run_shell_script(shell, &sh_snapshot_script(), cwd).await,
ShellType::PowerShell => run_shell_script(shell, powershell_snapshot_script(), cwd).await,
ShellType::Cmd => bail!("Shell snapshotting is not yet supported for {shell_type:?}"),
}
}
fn strip_snapshot_preamble(snapshot: &str) -> Result<String> {
let marker = "# Snapshot file";
let Some(start) = snapshot.find(marker) else {
bail!("Snapshot output missing marker {marker}");
};
Ok(snapshot[start..].to_string())
}
async fn validate_snapshot(shell: &Shell, snapshot_path: &Path, cwd: &Path) -> Result<()> {
let snapshot_path_display = snapshot_path.display();
let script = format!("set -e; . \"{snapshot_path_display}\"");
run_script_with_timeout(
shell,
&script,
SNAPSHOT_TIMEOUT,
/*use_login_shell*/ false,
cwd,
)
.await
.map(|_| ())
}
async fn run_shell_script(shell: &Shell, script: &str, cwd: &Path) -> Result<String> {
run_script_with_timeout(
shell,
script,
SNAPSHOT_TIMEOUT,
/*use_login_shell*/ true,
cwd,
)
.await
}
async fn run_script_with_timeout(
shell: &Shell,
script: &str,
snapshot_timeout: Duration,
use_login_shell: bool,
cwd: &Path,
) -> Result<String> {
let args = shell.derive_exec_args(script, use_login_shell);
let shell_name = shell.name();
// Handler is kept as guard to control the drop. The `mut` pattern is required because .args()
// returns a ref of handler.
let mut handler = Command::new(&args[0]);
handler.args(&args[1..]);
handler.stdin(Stdio::null());
handler.current_dir(cwd);
#[cfg(unix)]
unsafe {
handler.pre_exec(|| {
codex_utils_pty::process_group::detach_from_tty()?;
Ok(())
});
}
handler.kill_on_drop(true);
let output = timeout(snapshot_timeout, handler.output())
.await
.map_err(|_| anyhow!("Snapshot command timed out for {shell_name}"))?
.with_context(|| format!("Failed to execute {shell_name}"))?;
if !output.status.success() {
let status = output.status;
let stderr = String::from_utf8_lossy(&output.stderr);
bail!("Snapshot command exited with status {status}: {stderr}");
}
Ok(String::from_utf8_lossy(&output.stdout).into_owned())
}
fn excluded_exports_regex() -> String {
EXCLUDED_EXPORT_VARS.join("|")
}
fn zsh_snapshot_script() -> String {
let excluded = excluded_exports_regex();
let script = r##"if [[ -n "$ZDOTDIR" ]]; then
rc="$ZDOTDIR/.zshrc"
else
rc="$HOME/.zshrc"
fi
[[ -r "$rc" ]] && . "$rc"
print '# Snapshot file'
print '# Unset all aliases to avoid conflicts with functions'
print 'unalias -a 2>/dev/null || true'
print '# Functions'
functions
print ''
setopt_count=$(setopt | wc -l | tr -d ' ')
print "# setopts $setopt_count"
setopt | sed 's/^/setopt /'
print ''
alias_count=$(alias -L | wc -l | tr -d ' ')
print "# aliases $alias_count"
alias -L
print ''
export_lines=$(export -p | awk '
/^(export|declare -x|typeset -x) / {
line=$0
name=line
sub(/^(export|declare -x|typeset -x) /, "", name)
sub(/=.*/, "", name)
if (name ~ /^(EXCLUDED_EXPORTS)$/) {
next
}
if (name ~ /^[A-Za-z_][A-Za-z0-9_]*$/) {
print line
}
}')
export_count=$(printf '%s\n' "$export_lines" | sed '/^$/d' | wc -l | tr -d ' ')
print "# exports $export_count"
if [[ -n "$export_lines" ]]; then
print -r -- "$export_lines"
fi
"##;
script.replace("EXCLUDED_EXPORTS", &excluded)
}
fn bash_snapshot_script() -> String {
let excluded = excluded_exports_regex();
let script = r##"if [ -z "$BASH_ENV" ] && [ -r "$HOME/.bashrc" ]; then
. "$HOME/.bashrc"
fi
echo '# Snapshot file'
echo '# Unset all aliases to avoid conflicts with functions'
unalias -a 2>/dev/null || true
echo '# Functions'
declare -f
echo ''
bash_opts=$(set -o | awk '$2=="on"{print $1}')
bash_opt_count=$(printf '%s\n' "$bash_opts" | sed '/^$/d' | wc -l | tr -d ' ')
echo "# setopts $bash_opt_count"
if [ -n "$bash_opts" ]; then
printf 'set -o %s\n' $bash_opts
fi
echo ''
alias_count=$(alias -p | wc -l | tr -d ' ')
echo "# aliases $alias_count"
alias -p
echo ''
export_lines=$(
while IFS= read -r name; do
if [[ "$name" =~ ^(EXCLUDED_EXPORTS)$ ]]; then
continue
fi
if [[ ! "$name" =~ ^[A-Za-z_][A-Za-z0-9_]*$ ]]; then
continue
fi
declare -xp "$name" 2>/dev/null || true
done < <(compgen -e)
)
export_count=$(printf '%s\n' "$export_lines" | sed '/^$/d' | wc -l | tr -d ' ')
echo "# exports $export_count"
if [ -n "$export_lines" ]; then
printf '%s\n' "$export_lines"
fi
"##;
script.replace("EXCLUDED_EXPORTS", &excluded)
}
fn sh_snapshot_script() -> String {
let excluded = excluded_exports_regex();
let script = r##"if [ -n "$ENV" ] && [ -r "$ENV" ]; then
. "$ENV"
fi
echo '# Snapshot file'
echo '# Unset all aliases to avoid conflicts with functions'
unalias -a 2>/dev/null || true
echo '# Functions'
if command -v typeset >/dev/null 2>&1; then
typeset -f
elif command -v declare >/dev/null 2>&1; then
declare -f
fi
echo ''
if set -o >/dev/null 2>&1; then
sh_opts=$(set -o | awk '$2=="on"{print $1}')
sh_opt_count=$(printf '%s\n' "$sh_opts" | sed '/^$/d' | wc -l | tr -d ' ')
echo "# setopts $sh_opt_count"
if [ -n "$sh_opts" ]; then
printf 'set -o %s\n' $sh_opts
fi
else
echo '# setopts 0'
fi
echo ''
if alias >/dev/null 2>&1; then
alias_count=$(alias | wc -l | tr -d ' ')
echo "# aliases $alias_count"
alias
echo ''
else
echo '# aliases 0'
fi
if export -p >/dev/null 2>&1; then
export_lines=$(export -p | awk '
/^(export|declare -x|typeset -x) / {
line=$0
name=line
sub(/^(export|declare -x|typeset -x) /, "", name)
sub(/=.*/, "", name)
if (name ~ /^(EXCLUDED_EXPORTS)$/) {
next
}
if (name ~ /^[A-Za-z_][A-Za-z0-9_]*$/) {
print line
}
}')
export_count=$(printf '%s\n' "$export_lines" | sed '/^$/d' | wc -l | tr -d ' ')
echo "# exports $export_count"
if [ -n "$export_lines" ]; then
printf '%s\n' "$export_lines"
fi
else
export_count=$(env | sort | awk -F= '$1 ~ /^[A-Za-z_][A-Za-z0-9_]*$/ { count++ } END { print count }')
echo "# exports $export_count"
env | sort | while IFS='=' read -r key value; do
case "$key" in
""|[0-9]*|*[!A-Za-z0-9_]*|EXCLUDED_EXPORTS) continue ;;
esac
escaped=$(printf "%s" "$value" | sed "s/'/'\"'\"'/g")
printf "export %s='%s'\n" "$key" "$escaped"
done
fi
"##;
script.replace("EXCLUDED_EXPORTS", &excluded)
}
fn powershell_snapshot_script() -> &'static str {
r##"$ErrorActionPreference = 'Stop'
Write-Output '# Snapshot file'
Write-Output '# Unset all aliases to avoid conflicts with functions'
Write-Output 'Remove-Item Alias:* -ErrorAction SilentlyContinue'
Write-Output '# Functions'
Get-ChildItem Function: | ForEach-Object {
"function {0} {{`n{1}`n}}" -f $_.Name, $_.Definition
}
Write-Output ''
$aliases = Get-Alias
Write-Output ("# aliases " + $aliases.Count)
$aliases | ForEach-Object {
"Set-Alias -Name {0} -Value {1}" -f $_.Name, $_.Definition
}
Write-Output ''
$envVars = Get-ChildItem Env:
Write-Output ("# exports " + $envVars.Count)
$envVars | ForEach-Object {
$escaped = $_.Value -replace "'", "''"
"`$env:{0}='{1}'" -f $_.Name, $escaped
}
"##
}
/// Removes shell snapshots that either lack a matching session rollout file or
/// whose rollouts have not been updated within the retention window.
@@ -547,22 +76,12 @@ pub async fn cleanup_stale_snapshots(codex_home: &Path, active_session_id: Threa
Ok(())
}
async fn remove_snapshot_file(path: &Path) {
if let Err(err) = fs::remove_file(path).await {
tracing::warn!("Failed to delete shell snapshot at {:?}: {err:?}", path);
}
}
fn snapshot_session_id_from_file_name(file_name: &str) -> Option<&str> {
let (stem, extension) = file_name.rsplit_once('.')?;
match extension {
"sh" | "ps1" => Some(
stem.split_once('.')
.map_or(stem, |(session_id, _generation)| session_id),
),
_ if extension.starts_with("tmp-") => Some(stem),
_ => None,
}
pub(crate) fn spawn_stale_snapshot_cleanup(codex_home: PathBuf, active_session_id: ThreadId) {
tokio::spawn(async move {
if let Err(err) = cleanup_stale_snapshots(&codex_home, active_session_id).await {
tracing::warn!("Failed to clean up shell snapshots: {err:?}");
}
});
}
#[cfg(test)]

View File

@@ -1,101 +1,26 @@
use super::*;
use anyhow::Result;
use pretty_assertions::assert_eq;
#[cfg(unix)]
use std::os::unix::ffi::OsStrExt;
use std::path::Path;
use std::path::PathBuf;
#[cfg(unix)]
use std::process::Command;
#[cfg(target_os = "linux")]
use std::process::Command as StdCommand;
use std::time::Duration;
#[cfg(unix)]
use std::time::SystemTime;
use tempfile::tempdir;
#[cfg(unix)]
struct BlockingStdinPipe {
original: i32,
write_end: i32,
}
#[cfg(unix)]
impl BlockingStdinPipe {
fn install() -> Result<Self> {
let mut fds = [0i32; 2];
if unsafe { libc::pipe(fds.as_mut_ptr()) } == -1 {
return Err(std::io::Error::last_os_error()).context("create stdin pipe");
}
let original = unsafe { libc::dup(libc::STDIN_FILENO) };
if original == -1 {
let err = std::io::Error::last_os_error();
unsafe {
libc::close(fds[0]);
libc::close(fds[1]);
}
return Err(err).context("dup stdin");
}
if unsafe { libc::dup2(fds[0], libc::STDIN_FILENO) } == -1 {
let err = std::io::Error::last_os_error();
unsafe {
libc::close(fds[0]);
libc::close(fds[1]);
libc::close(original);
}
return Err(err).context("replace stdin");
}
unsafe {
libc::close(fds[0]);
}
Ok(Self {
original,
write_end: fds[1],
})
}
}
#[cfg(unix)]
impl Drop for BlockingStdinPipe {
fn drop(&mut self) {
unsafe {
libc::dup2(self.original, libc::STDIN_FILENO);
libc::close(self.original);
libc::close(self.write_end);
}
}
}
#[cfg(not(target_os = "windows"))]
fn assert_posix_snapshot_sections(snapshot: &str) {
assert!(snapshot.contains("# Snapshot file"));
assert!(snapshot.contains("aliases "));
assert!(snapshot.contains("exports "));
assert!(
snapshot.contains("PATH"),
"snapshot should capture a PATH export"
);
assert!(snapshot.contains("setopts "));
}
async fn get_snapshot(shell_type: ShellType) -> Result<String> {
let dir = tempdir()?;
let path = dir.path().join("snapshot.sh");
write_shell_snapshot(shell_type, &path, dir.path()).await?;
let content = fs::read_to_string(&path).await?;
Ok(content)
}
#[test]
fn strip_snapshot_preamble_removes_leading_output() {
let snapshot = "noise\n# Snapshot file\nexport PATH=/bin\n";
let cleaned = strip_snapshot_preamble(snapshot).expect("snapshot marker exists");
assert_eq!(cleaned, "# Snapshot file\nexport PATH=/bin\n");
}
#[test]
fn strip_snapshot_preamble_requires_marker() {
let result = strip_snapshot_preamble("missing header");
assert!(result.is_err());
async fn write_rollout_stub(codex_home: &Path, session_id: ThreadId) -> Result<PathBuf> {
let dir = codex_home
.join("sessions")
.join("2025")
.join("01")
.join("01");
fs::create_dir_all(&dir).await?;
let path = dir.join(format!("rollout-2025-01-01T00-00-00-{session_id}.jsonl"));
fs::write(&path, "").await?;
Ok(path)
}
#[test]
@@ -120,286 +45,6 @@ fn snapshot_file_name_parser_supports_legacy_and_suffixed_names() {
);
}
#[cfg(unix)]
#[test]
fn bash_snapshot_filters_invalid_exports() -> Result<()> {
let output = Command::new("/bin/bash")
.arg("-c")
.arg(bash_snapshot_script())
.env("BASH_ENV", "/dev/null")
.env("VALID_NAME", "ok")
.env("PWD", "/tmp/stale")
.env("NEXTEST_BIN_EXE_codex-write-config-schema", "/path/to/bin")
.env("BAD-NAME", "broken")
.output()?;
assert!(output.status.success());
let stdout = String::from_utf8_lossy(&output.stdout);
assert!(stdout.contains("VALID_NAME"));
assert!(!stdout.contains("PWD=/tmp/stale"));
assert!(!stdout.contains("NEXTEST_BIN_EXE_codex-write-config-schema"));
assert!(!stdout.contains("BAD-NAME"));
Ok(())
}
#[cfg(unix)]
#[test]
fn bash_snapshot_preserves_multiline_exports() -> Result<()> {
let multiline_cert = "-----BEGIN CERTIFICATE-----\nabc\n-----END CERTIFICATE-----";
let output = Command::new("/bin/bash")
.arg("-c")
.arg(bash_snapshot_script())
.env("BASH_ENV", "/dev/null")
.env("MULTILINE_CERT", multiline_cert)
.output()?;
assert!(output.status.success());
let stdout = String::from_utf8_lossy(&output.stdout);
assert!(
stdout.contains("MULTILINE_CERT=") || stdout.contains("MULTILINE_CERT"),
"snapshot should include the multiline export name"
);
let dir = tempdir()?;
let snapshot_path = dir.path().join("snapshot.sh");
std::fs::write(&snapshot_path, stdout.as_bytes())?;
let validate = Command::new("/bin/bash")
.arg("-c")
.arg("set -e; . \"$1\"")
.arg("bash")
.arg(&snapshot_path)
.env("BASH_ENV", "/dev/null")
.output()?;
assert!(
validate.status.success(),
"snapshot validation failed: {}",
String::from_utf8_lossy(&validate.stderr)
);
Ok(())
}
#[cfg(unix)]
#[tokio::test]
async fn try_new_creates_and_deletes_snapshot_file() -> Result<()> {
let dir = tempdir()?;
let shell = Shell {
shell_type: ShellType::Bash,
shell_path: PathBuf::from("/bin/bash"),
shell_snapshot: crate::shell::empty_shell_snapshot_receiver(),
};
let snapshot = ShellSnapshot::try_new(dir.path(), ThreadId::new(), dir.path(), &shell)
.await
.expect("snapshot should be created");
let path = snapshot.path.clone();
assert!(path.exists());
assert_eq!(snapshot.cwd, dir.path().to_path_buf());
drop(snapshot);
assert!(!path.exists());
Ok(())
}
#[cfg(unix)]
#[tokio::test]
async fn try_new_uses_distinct_generation_paths() -> Result<()> {
let dir = tempdir()?;
let session_id = ThreadId::new();
let shell = Shell {
shell_type: ShellType::Bash,
shell_path: PathBuf::from("/bin/bash"),
shell_snapshot: crate::shell::empty_shell_snapshot_receiver(),
};
let initial_snapshot = ShellSnapshot::try_new(dir.path(), session_id, dir.path(), &shell)
.await
.expect("initial snapshot should be created");
let refreshed_snapshot = ShellSnapshot::try_new(dir.path(), session_id, dir.path(), &shell)
.await
.expect("refreshed snapshot should be created");
let initial_path = initial_snapshot.path.clone();
let refreshed_path = refreshed_snapshot.path.clone();
assert_ne!(initial_path, refreshed_path);
assert_eq!(initial_path.exists(), true);
assert_eq!(refreshed_path.exists(), true);
drop(initial_snapshot);
assert_eq!(initial_path.exists(), false);
assert_eq!(refreshed_path.exists(), true);
drop(refreshed_snapshot);
assert_eq!(refreshed_path.exists(), false);
Ok(())
}
#[cfg(unix)]
#[tokio::test]
async fn snapshot_shell_does_not_inherit_stdin() -> Result<()> {
let _stdin_guard = BlockingStdinPipe::install()?;
let dir = tempdir()?;
let home = dir.path();
let read_status_path = home.join("stdin-read-status");
let read_status_display = read_status_path.display();
// Persist the startup `read` exit status so the test can assert whether
// bash saw EOF on stdin after the snapshot process exits.
let bashrc = format!("read -t 1 -r ignored\nprintf '%s' \"$?\" > \"{read_status_display}\"\n");
fs::write(home.join(".bashrc"), bashrc).await?;
let shell = Shell {
shell_type: ShellType::Bash,
shell_path: PathBuf::from("/bin/bash"),
shell_snapshot: crate::shell::empty_shell_snapshot_receiver(),
};
let home_display = home.display();
let script = format!(
"HOME=\"{home_display}\"; export HOME; {}",
bash_snapshot_script()
);
let output = run_script_with_timeout(
&shell,
&script,
Duration::from_secs(2),
/*use_login_shell*/ true,
home,
)
.await
.context("run snapshot command")?;
let read_status = fs::read_to_string(&read_status_path)
.await
.context("read stdin probe status")?;
assert_eq!(
read_status, "1",
"expected shell startup read to see EOF on stdin; status={read_status:?}"
);
assert!(
output.contains("# Snapshot file"),
"expected snapshot marker in output; output={output:?}"
);
Ok(())
}
#[cfg(target_os = "linux")]
#[tokio::test]
async fn timed_out_snapshot_shell_is_terminated() -> Result<()> {
use std::process::Stdio;
use tokio::time::Duration as TokioDuration;
use tokio::time::Instant;
use tokio::time::sleep;
let dir = tempdir()?;
let pid_path = dir.path().join("pid");
let script = format!("echo $$ > \"{}\"; sleep 30", pid_path.display());
let shell = Shell {
shell_type: ShellType::Sh,
shell_path: PathBuf::from("/bin/sh"),
shell_snapshot: crate::shell::empty_shell_snapshot_receiver(),
};
let err = run_script_with_timeout(
&shell,
&script,
Duration::from_secs(1),
/*use_login_shell*/ true,
dir.path(),
)
.await
.expect_err("snapshot shell should time out");
assert!(
err.to_string().contains("timed out"),
"expected timeout error, got {err:?}"
);
let pid = fs::read_to_string(&pid_path)
.await
.expect("snapshot shell writes its pid before timing out")
.trim()
.parse::<i32>()?;
let deadline = Instant::now() + TokioDuration::from_secs(1);
loop {
let kill_status = StdCommand::new("kill")
.arg("-0")
.arg(pid.to_string())
.stderr(Stdio::null())
.stdout(Stdio::null())
.status()?;
if !kill_status.success() {
break;
}
if Instant::now() >= deadline {
panic!("timed out snapshot shell is still alive after grace period");
}
sleep(TokioDuration::from_millis(50)).await;
}
Ok(())
}
#[cfg(target_os = "macos")]
#[tokio::test]
async fn macos_zsh_snapshot_includes_sections() -> Result<()> {
let snapshot = get_snapshot(ShellType::Zsh).await?;
assert_posix_snapshot_sections(&snapshot);
Ok(())
}
#[cfg(target_os = "linux")]
#[tokio::test]
async fn linux_bash_snapshot_includes_sections() -> Result<()> {
let snapshot = get_snapshot(ShellType::Bash).await?;
assert_posix_snapshot_sections(&snapshot);
Ok(())
}
#[cfg(target_os = "linux")]
#[tokio::test]
async fn linux_sh_snapshot_includes_sections() -> Result<()> {
let snapshot = get_snapshot(ShellType::Sh).await?;
assert_posix_snapshot_sections(&snapshot);
Ok(())
}
#[cfg(target_os = "windows")]
#[ignore]
#[tokio::test]
async fn windows_powershell_snapshot_includes_sections() -> Result<()> {
let snapshot = get_snapshot(ShellType::PowerShell).await?;
assert!(snapshot.contains("# Snapshot file"));
assert!(snapshot.contains("aliases "));
assert!(snapshot.contains("exports "));
Ok(())
}
async fn write_rollout_stub(codex_home: &Path, session_id: ThreadId) -> Result<PathBuf> {
let dir = codex_home
.join("sessions")
.join("2025")
.join("01")
.join("01");
fs::create_dir_all(&dir).await?;
let path = dir.join(format!("rollout-2025-01-01T00-00-00-{session_id}.jsonl"));
fs::write(&path, "").await?;
Ok(path)
}
#[tokio::test]
async fn cleanup_stale_snapshots_removes_orphans_and_keeps_live() -> Result<()> {
let dir = tempdir()?;
@@ -476,7 +121,7 @@ fn set_file_mtime(path: &Path, age: Duration) -> Result<()> {
.saturating_sub(age.as_secs());
let tv_sec = now
.try_into()
.map_err(|_| anyhow!("Snapshot mtime is out of range for libc::timespec"))?;
.map_err(|_| anyhow::anyhow!("Snapshot mtime is out of range for libc::timespec"))?;
let ts = libc::timespec { tv_sec, tv_nsec: 0 };
let times = [ts, ts];
let c_path = std::ffi::CString::new(path.as_os_str().as_bytes())?;

View File

@@ -1,5 +1,3 @@
use std::sync::Arc;
use crate::RolloutRecorder;
use crate::SkillsManager;
use crate::agent::AgentControl;
@@ -9,6 +7,8 @@ use crate::exec_policy::ExecPolicyManager;
use crate::mcp::McpManager;
use crate::models_manager::manager::ModelsManager;
use crate::plugins::PluginsManager;
use crate::shell_snapshot::ShellSnapshotReceiver;
use crate::shell_snapshot::ShellSnapshotSender;
use crate::skills_watcher::SkillsWatcher;
use crate::tools::code_mode::CodeModeService;
use crate::tools::network_approval::NetworkApprovalService;
@@ -22,9 +22,9 @@ use codex_mcp::mcp_connection_manager::McpConnectionManager;
use codex_otel::SessionTelemetry;
use codex_rollout::state_db::StateDbHandle;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;
pub(crate) struct SessionServices {
@@ -39,7 +39,8 @@ pub(crate) struct SessionServices {
pub(crate) hooks: Hooks,
pub(crate) rollout: Mutex<Option<RolloutRecorder>>,
pub(crate) user_shell: Arc<crate::shell::Shell>,
pub(crate) shell_snapshot_tx: watch::Sender<Option<Arc<crate::shell_snapshot::ShellSnapshot>>>,
pub(crate) shell_snapshot_tx: ShellSnapshotSender,
pub(crate) shell_snapshot_rx: ShellSnapshotReceiver,
pub(crate) show_raw_agent_reasoning: bool,
pub(crate) exec_policy: Arc<ExecPolicyManager>,
pub(crate) auth_manager: Arc<AuthManager>,

View File

@@ -124,10 +124,12 @@ pub(crate) async fn execute_user_shell_command(
// We do not source rc files or otherwise reformat the script.
let use_login_shell = true;
let session_shell = session.user_shell();
let session_shell_snapshot = session.shell_snapshot();
let display_command = session_shell.derive_exec_args(&command, use_login_shell);
let exec_command = maybe_wrap_shell_lc_with_snapshot(
&display_command,
session_shell.as_ref(),
session_shell_snapshot.as_deref(),
turn_context.cwd.as_path(),
&turn_context.shell_environment_policy.r#set,
);

View File

@@ -9,7 +9,6 @@ use crate::exec_env::create_env;
use crate::sandboxing::SandboxPermissions;
use crate::shell::Shell;
use crate::shell::ShellType;
use crate::shell_snapshot::ShellSnapshot;
use crate::tools::context::FunctionToolOutput;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolPayload;
@@ -22,7 +21,6 @@ use codex_shell_command::powershell::try_find_powershell_executable_blocking;
use codex_shell_command::powershell::try_find_pwsh_executable_blocking;
use serde_json::json;
use tokio::sync::Mutex;
use tokio::sync::watch;
/// The logic for is_known_safe_command() has heuristics for known shells,
/// so we must ensure the commands generated by [ShellCommandHandler] can be
@@ -32,14 +30,12 @@ fn commands_generated_by_shell_command_handler_can_be_matched_by_is_known_safe_c
let bash_shell = Shell {
shell_type: ShellType::Bash,
shell_path: PathBuf::from("/bin/bash"),
shell_snapshot: crate::shell::empty_shell_snapshot_receiver(),
};
assert_safe(&bash_shell, "ls -la");
let zsh_shell = Shell {
shell_type: ShellType::Zsh,
shell_path: PathBuf::from("/bin/zsh"),
shell_snapshot: crate::shell::empty_shell_snapshot_receiver(),
};
assert_safe(&zsh_shell, "ls -la");
@@ -47,7 +43,6 @@ fn commands_generated_by_shell_command_handler_can_be_matched_by_is_known_safe_c
let powershell = Shell {
shell_type: ShellType::PowerShell,
shell_path: path.to_path_buf(),
shell_snapshot: crate::shell::empty_shell_snapshot_receiver(),
};
assert_safe(&powershell, "ls -Name");
}
@@ -56,7 +51,6 @@ fn commands_generated_by_shell_command_handler_can_be_matched_by_is_known_safe_c
let pwsh = Shell {
shell_type: ShellType::PowerShell,
shell_path: path.to_path_buf(),
shell_snapshot: crate::shell::empty_shell_snapshot_receiver(),
};
assert_safe(&pwsh, "ls -Name");
}
@@ -124,14 +118,9 @@ async fn shell_command_handler_to_exec_params_uses_session_shell_and_turn_contex
#[test]
fn shell_command_handler_respects_explicit_login_flag() {
let (_tx, shell_snapshot) = watch::channel(Some(Arc::new(ShellSnapshot {
path: PathBuf::from("/tmp/snapshot.sh"),
cwd: PathBuf::from("/tmp"),
})));
let shell = Shell {
shell_type: ShellType::Bash,
shell_path: PathBuf::from("/bin/bash"),
shell_snapshot,
};
let login_command = ShellCommandHandler::base_command(

View File

@@ -392,11 +392,10 @@ pub(crate) fn get_command(
match shell_mode {
UnifiedExecShellMode::Direct => {
let model_shell = args.shell.as_ref().map(|shell_str| {
let mut shell = get_shell_by_model_provided_path(&PathBuf::from(shell_str));
shell.shell_snapshot = crate::shell::empty_shell_snapshot_receiver();
shell
});
let model_shell = args
.shell
.as_ref()
.map(|shell_str| get_shell_by_model_provided_path(&PathBuf::from(shell_str)));
let shell = model_shell.as_ref().unwrap_or(session_shell.as_ref());
Ok(shell.derive_exec_args(&args.cmd, use_login_shell))
}

View File

@@ -6,6 +6,7 @@ small and focused and reuses the orchestrator for approvals + sandbox + retry.
*/
use crate::path_utils;
use crate::shell::Shell;
use crate::shell_snapshot::ShellSnapshot;
use crate::tools::sandboxing::ToolError;
use codex_protocol::models::PermissionProfile;
use codex_sandboxing::SandboxCommand;
@@ -38,7 +39,7 @@ pub(crate) fn build_sandbox_command(
/// POSIX-only helper: for commands produced by `Shell::derive_exec_args`
/// for Bash/Zsh/sh of the form `[shell_path, "-lc", "<script>"]`, and
/// when a snapshot is configured on the session shell, rewrite the argv
/// when a snapshot is configured for the session shell, rewrite the argv
/// to a single non-login shell that sources the snapshot before running
/// the original script:
///
@@ -51,6 +52,7 @@ pub(crate) fn build_sandbox_command(
pub(crate) fn maybe_wrap_shell_lc_with_snapshot(
command: &[String],
session_shell: &Shell,
session_shell_snapshot: Option<&ShellSnapshot>,
cwd: &Path,
explicit_env_overrides: &HashMap<String, String>,
) -> Vec<String> {
@@ -58,7 +60,7 @@ pub(crate) fn maybe_wrap_shell_lc_with_snapshot(
return command.to_vec();
}
let Some(snapshot) = session_shell.shell_snapshot() else {
let Some(snapshot) = session_shell_snapshot else {
return command.to_vec();
};

View File

@@ -4,25 +4,24 @@ use crate::shell_snapshot::ShellSnapshot;
use pretty_assertions::assert_eq;
use std::path::PathBuf;
use std::process::Command;
use std::sync::Arc;
use tempfile::tempdir;
use tokio::sync::watch;
fn shell_with_snapshot(
shell_type: ShellType,
shell_path: &str,
snapshot_path: PathBuf,
snapshot_cwd: PathBuf,
) -> Shell {
let (_tx, shell_snapshot) = watch::channel(Some(Arc::new(ShellSnapshot {
path: snapshot_path,
cwd: snapshot_cwd,
})));
Shell {
shell_type,
shell_path: PathBuf::from(shell_path),
shell_snapshot,
}
) -> (Shell, ShellSnapshot) {
(
Shell {
shell_type,
shell_path: PathBuf::from(shell_path),
},
ShellSnapshot {
path: snapshot_path,
cwd: snapshot_cwd,
},
)
}
#[test]
@@ -30,7 +29,7 @@ fn maybe_wrap_shell_lc_with_snapshot_bootstraps_in_user_shell() {
let dir = tempdir().expect("create temp dir");
let snapshot_path = dir.path().join("snapshot.sh");
std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot");
let session_shell = shell_with_snapshot(
let (session_shell, session_shell_snapshot) = shell_with_snapshot(
ShellType::Zsh,
"/bin/zsh",
snapshot_path,
@@ -42,8 +41,13 @@ fn maybe_wrap_shell_lc_with_snapshot_bootstraps_in_user_shell() {
"echo hello".to_string(),
];
let rewritten =
maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, dir.path(), &HashMap::new());
let rewritten = maybe_wrap_shell_lc_with_snapshot(
&command,
&session_shell,
Some(&session_shell_snapshot),
dir.path(),
&HashMap::new(),
);
assert_eq!(rewritten[0], "/bin/zsh");
assert_eq!(rewritten[1], "-c");
@@ -56,7 +60,7 @@ fn maybe_wrap_shell_lc_with_snapshot_escapes_single_quotes() {
let dir = tempdir().expect("create temp dir");
let snapshot_path = dir.path().join("snapshot.sh");
std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot");
let session_shell = shell_with_snapshot(
let (session_shell, session_shell_snapshot) = shell_with_snapshot(
ShellType::Zsh,
"/bin/zsh",
snapshot_path,
@@ -68,8 +72,13 @@ fn maybe_wrap_shell_lc_with_snapshot_escapes_single_quotes() {
"echo 'hello'".to_string(),
];
let rewritten =
maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, dir.path(), &HashMap::new());
let rewritten = maybe_wrap_shell_lc_with_snapshot(
&command,
&session_shell,
Some(&session_shell_snapshot),
dir.path(),
&HashMap::new(),
);
assert!(rewritten[2].contains(r#"exec '/bin/bash' -c 'echo '"'"'hello'"'"''"#));
}
@@ -79,7 +88,7 @@ fn maybe_wrap_shell_lc_with_snapshot_uses_bash_bootstrap_shell() {
let dir = tempdir().expect("create temp dir");
let snapshot_path = dir.path().join("snapshot.sh");
std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot");
let session_shell = shell_with_snapshot(
let (session_shell, session_shell_snapshot) = shell_with_snapshot(
ShellType::Bash,
"/bin/bash",
snapshot_path,
@@ -91,8 +100,13 @@ fn maybe_wrap_shell_lc_with_snapshot_uses_bash_bootstrap_shell() {
"echo hello".to_string(),
];
let rewritten =
maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, dir.path(), &HashMap::new());
let rewritten = maybe_wrap_shell_lc_with_snapshot(
&command,
&session_shell,
Some(&session_shell_snapshot),
dir.path(),
&HashMap::new(),
);
assert_eq!(rewritten[0], "/bin/bash");
assert_eq!(rewritten[1], "-c");
@@ -105,7 +119,7 @@ fn maybe_wrap_shell_lc_with_snapshot_uses_sh_bootstrap_shell() {
let dir = tempdir().expect("create temp dir");
let snapshot_path = dir.path().join("snapshot.sh");
std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot");
let session_shell = shell_with_snapshot(
let (session_shell, session_shell_snapshot) = shell_with_snapshot(
ShellType::Sh,
"/bin/sh",
snapshot_path,
@@ -117,8 +131,13 @@ fn maybe_wrap_shell_lc_with_snapshot_uses_sh_bootstrap_shell() {
"echo hello".to_string(),
];
let rewritten =
maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, dir.path(), &HashMap::new());
let rewritten = maybe_wrap_shell_lc_with_snapshot(
&command,
&session_shell,
Some(&session_shell_snapshot),
dir.path(),
&HashMap::new(),
);
assert_eq!(rewritten[0], "/bin/sh");
assert_eq!(rewritten[1], "-c");
@@ -131,7 +150,7 @@ fn maybe_wrap_shell_lc_with_snapshot_preserves_trailing_args() {
let dir = tempdir().expect("create temp dir");
let snapshot_path = dir.path().join("snapshot.sh");
std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot");
let session_shell = shell_with_snapshot(
let (session_shell, session_shell_snapshot) = shell_with_snapshot(
ShellType::Zsh,
"/bin/zsh",
snapshot_path,
@@ -145,8 +164,13 @@ fn maybe_wrap_shell_lc_with_snapshot_preserves_trailing_args() {
"arg1".to_string(),
];
let rewritten =
maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, dir.path(), &HashMap::new());
let rewritten = maybe_wrap_shell_lc_with_snapshot(
&command,
&session_shell,
Some(&session_shell_snapshot),
dir.path(),
&HashMap::new(),
);
assert!(
rewritten[2]
@@ -163,7 +187,7 @@ fn maybe_wrap_shell_lc_with_snapshot_skips_when_cwd_mismatch() {
let command_cwd = dir.path().join("worktree-b");
std::fs::create_dir_all(&snapshot_cwd).expect("create snapshot cwd");
std::fs::create_dir_all(&command_cwd).expect("create command cwd");
let session_shell =
let (session_shell, session_shell_snapshot) =
shell_with_snapshot(ShellType::Zsh, "/bin/zsh", snapshot_path, snapshot_cwd);
let command = vec![
"/bin/bash".to_string(),
@@ -171,8 +195,13 @@ fn maybe_wrap_shell_lc_with_snapshot_skips_when_cwd_mismatch() {
"echo hello".to_string(),
];
let rewritten =
maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, &command_cwd, &HashMap::new());
let rewritten = maybe_wrap_shell_lc_with_snapshot(
&command,
&session_shell,
Some(&session_shell_snapshot),
&command_cwd,
&HashMap::new(),
);
assert_eq!(rewritten, command);
}
@@ -182,7 +211,7 @@ fn maybe_wrap_shell_lc_with_snapshot_accepts_dot_alias_cwd() {
let dir = tempdir().expect("create temp dir");
let snapshot_path = dir.path().join("snapshot.sh");
std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot");
let session_shell = shell_with_snapshot(
let (session_shell, session_shell_snapshot) = shell_with_snapshot(
ShellType::Zsh,
"/bin/zsh",
snapshot_path,
@@ -195,8 +224,13 @@ fn maybe_wrap_shell_lc_with_snapshot_accepts_dot_alias_cwd() {
];
let command_cwd = dir.path().join(".");
let rewritten =
maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, &command_cwd, &HashMap::new());
let rewritten = maybe_wrap_shell_lc_with_snapshot(
&command,
&session_shell,
Some(&session_shell_snapshot),
&command_cwd,
&HashMap::new(),
);
assert_eq!(rewritten[0], "/bin/zsh");
assert_eq!(rewritten[1], "-c");
@@ -213,7 +247,7 @@ fn maybe_wrap_shell_lc_with_snapshot_restores_explicit_override_precedence() {
"# Snapshot file\nexport TEST_ENV_SNAPSHOT=global\nexport SNAPSHOT_ONLY=from_snapshot\n",
)
.expect("write snapshot");
let session_shell = shell_with_snapshot(
let (session_shell, session_shell_snapshot) = shell_with_snapshot(
ShellType::Bash,
"/bin/bash",
snapshot_path,
@@ -229,6 +263,7 @@ fn maybe_wrap_shell_lc_with_snapshot_restores_explicit_override_precedence() {
let rewritten = maybe_wrap_shell_lc_with_snapshot(
&command,
&session_shell,
Some(&session_shell_snapshot),
dir.path(),
&explicit_env_overrides,
);
@@ -254,7 +289,7 @@ fn maybe_wrap_shell_lc_with_snapshot_keeps_snapshot_path_without_override() {
"# Snapshot file\nexport PATH='/snapshot/bin'\n",
)
.expect("write snapshot");
let session_shell = shell_with_snapshot(
let (session_shell, session_shell_snapshot) = shell_with_snapshot(
ShellType::Bash,
"/bin/bash",
snapshot_path,
@@ -265,8 +300,13 @@ fn maybe_wrap_shell_lc_with_snapshot_keeps_snapshot_path_without_override() {
"-lc".to_string(),
"printf '%s' \"$PATH\"".to_string(),
];
let rewritten =
maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, dir.path(), &HashMap::new());
let rewritten = maybe_wrap_shell_lc_with_snapshot(
&command,
&session_shell,
Some(&session_shell_snapshot),
dir.path(),
&HashMap::new(),
);
let output = Command::new(&rewritten[0])
.args(&rewritten[1..])
.output()
@@ -285,7 +325,7 @@ fn maybe_wrap_shell_lc_with_snapshot_applies_explicit_path_override() {
"# Snapshot file\nexport PATH='/snapshot/bin'\n",
)
.expect("write snapshot");
let session_shell = shell_with_snapshot(
let (session_shell, session_shell_snapshot) = shell_with_snapshot(
ShellType::Bash,
"/bin/bash",
snapshot_path,
@@ -300,6 +340,7 @@ fn maybe_wrap_shell_lc_with_snapshot_applies_explicit_path_override() {
let rewritten = maybe_wrap_shell_lc_with_snapshot(
&command,
&session_shell,
Some(&session_shell_snapshot),
dir.path(),
&explicit_env_overrides,
);
@@ -322,7 +363,7 @@ fn maybe_wrap_shell_lc_with_snapshot_does_not_embed_override_values_in_argv() {
"# Snapshot file\nexport OPENAI_API_KEY='snapshot-value'\n",
)
.expect("write snapshot");
let session_shell = shell_with_snapshot(
let (session_shell, session_shell_snapshot) = shell_with_snapshot(
ShellType::Bash,
"/bin/bash",
snapshot_path,
@@ -340,6 +381,7 @@ fn maybe_wrap_shell_lc_with_snapshot_does_not_embed_override_values_in_argv() {
let rewritten = maybe_wrap_shell_lc_with_snapshot(
&command,
&session_shell,
Some(&session_shell_snapshot),
dir.path(),
&explicit_env_overrides,
);
@@ -366,7 +408,7 @@ fn maybe_wrap_shell_lc_with_snapshot_preserves_unset_override_variables() {
"# Snapshot file\nexport CODEX_TEST_UNSET_OVERRIDE='snapshot-value'\n",
)
.expect("write snapshot");
let session_shell = shell_with_snapshot(
let (session_shell, session_shell_snapshot) = shell_with_snapshot(
ShellType::Bash,
"/bin/bash",
snapshot_path,
@@ -384,6 +426,7 @@ fn maybe_wrap_shell_lc_with_snapshot_preserves_unset_override_variables() {
let rewritten = maybe_wrap_shell_lc_with_snapshot(
&command,
&session_shell,
Some(&session_shell_snapshot),
dir.path(),
&explicit_env_overrides,
);

View File

@@ -221,9 +221,11 @@ impl ToolRuntime<ShellRequest, ExecToolCallOutput> for ShellRuntime {
ctx: &ToolCtx,
) -> Result<ExecToolCallOutput, ToolError> {
let session_shell = ctx.session.user_shell();
let session_shell_snapshot = ctx.session.shell_snapshot();
let command = maybe_wrap_shell_lc_with_snapshot(
&req.command,
session_shell.as_ref(),
session_shell_snapshot.as_deref(),
&req.cwd,
&req.explicit_env_overrides,
);

View File

@@ -202,9 +202,11 @@ impl<'a> ToolRuntime<UnifiedExecRequest, UnifiedExecProcess> for UnifiedExecRunt
) -> Result<UnifiedExecProcess, ToolError> {
let base_command = &req.command;
let session_shell = ctx.session.user_shell();
let session_shell_snapshot = ctx.session.shell_snapshot();
let command = maybe_wrap_shell_lc_with_snapshot(
base_command,
session_shell.as_ref(),
session_shell_snapshot.as_deref(),
&req.cwd,
&req.explicit_env_overrides,
);

View File

@@ -1,5 +1,3 @@
use crate::shell::Shell;
use crate::shell::ShellType;
use crate::tools::handlers::agent_jobs::BatchJobHandler;
use crate::tools::handlers::multi_agents_common::DEFAULT_WAIT_TIMEOUT_MS;
use crate::tools::handlers::multi_agents_common::MAX_WAIT_TIMEOUT_MS;
@@ -12,23 +10,12 @@ use codex_tools::DiscoverableTool;
use codex_tools::ToolHandlerKind;
use codex_tools::ToolRegistryPlanAppTool;
use codex_tools::ToolRegistryPlanParams;
use codex_tools::ToolUserShellType;
use codex_tools::ToolsConfig;
use codex_tools::WaitAgentTimeoutOptions;
use codex_tools::build_tool_registry_plan;
use std::collections::HashMap;
use std::sync::Arc;
pub(crate) fn tool_user_shell_type(user_shell: &Shell) -> ToolUserShellType {
match user_shell.shell_type {
ShellType::Zsh => ToolUserShellType::Zsh,
ShellType::Bash => ToolUserShellType::Bash,
ShellType::PowerShell => ToolUserShellType::PowerShell,
ShellType::Sh => ToolUserShellType::Sh,
ShellType::Cmd => ToolUserShellType::Cmd,
}
}
pub(crate) fn build_specs_with_discoverable_tools(
config: &ToolsConfig,
mcp_tools: Option<HashMap<String, rmcp::model::Tool>>,

View File

@@ -556,7 +556,6 @@ fn shell_zsh_fork_prefers_shell_command_over_unified_exec() {
let user_shell = Shell {
shell_type: ShellType::Zsh,
shell_path: PathBuf::from("/bin/zsh"),
shell_snapshot: crate::shell::empty_shell_snapshot_receiver(),
};
assert_eq!(tools_config.shell_type, ConfigShellToolType::ShellCommand);
@@ -571,7 +570,7 @@ fn shell_zsh_fork_prefers_shell_command_over_unified_exec() {
assert_eq!(
tools_config
.with_unified_exec_shell_mode_for_session(
tool_user_shell_type(&user_shell),
user_shell.shell_type,
Some(&PathBuf::from(if cfg!(windows) {
r"C:\opt\codex\zsh"
} else {

View File

@@ -10,6 +10,7 @@ workspace = true
[dependencies]
base64 = { workspace = true }
codex-protocol = { workspace = true }
codex-shell = { workspace = true }
codex-utils-absolute-path = { workspace = true }
once_cell = { workspace = true }
regex = { workspace = true }

View File

@@ -5,8 +5,8 @@ use tree_sitter::Parser;
use tree_sitter::Tree;
use tree_sitter_bash::LANGUAGE as BASH;
use crate::shell_detect::ShellType;
use crate::shell_detect::detect_shell_type;
use codex_shell::ShellType;
use codex_shell::detect_shell_type;
/// Parse the provided bash source using tree-sitter-bash, returning a Tree on
/// success or None if parsing failed.

View File

@@ -1,7 +1,5 @@
//! Command parsing and safety utilities shared across Codex crates.
mod shell_detect;
pub mod bash;
pub mod command_safety;
pub mod parse_command;

View File

@@ -2,8 +2,8 @@ use std::path::PathBuf;
use codex_utils_absolute_path::AbsolutePathBuf;
use crate::shell_detect::ShellType;
use crate::shell_detect::detect_shell_type;
use codex_shell::ShellType;
use codex_shell::detect_shell_type;
const POWERSHELL_FLAGS: &[&str] = &["-nologo", "-noprofile", "-command", "-c"];

View File

@@ -1,32 +0,0 @@
use std::path::Path;
use std::path::PathBuf;
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub(crate) enum ShellType {
Zsh,
Bash,
PowerShell,
Sh,
Cmd,
}
pub(crate) fn detect_shell_type(shell_path: &PathBuf) -> Option<ShellType> {
match shell_path.as_os_str().to_str() {
Some("zsh") => Some(ShellType::Zsh),
Some("sh") => Some(ShellType::Sh),
Some("cmd") => Some(ShellType::Cmd),
Some("bash") => Some(ShellType::Bash),
Some("pwsh") => Some(ShellType::PowerShell),
Some("powershell") => Some(ShellType::PowerShell),
_ => {
let shell_name = shell_path.file_stem();
if let Some(shell_name) = shell_name {
let shell_name_path = Path::new(shell_name);
if shell_name_path != Path::new(shell_path) {
return detect_shell_type(&shell_name_path.to_path_buf());
}
}
None
}
}
}

View File

@@ -0,0 +1,7 @@
load("//:defs.bzl", "codex_rust_crate")
codex_rust_crate(
name = "shell-snapshot",
crate_name = "codex_shell_snapshot",
test_tags = ["no-sandbox"],
)

View File

@@ -0,0 +1,29 @@
[package]
edition.workspace = true
license.workspace = true
name = "codex-shell-snapshot"
version.workspace = true
[lints]
workspace = true
[dependencies]
anyhow = { workspace = true }
codex-otel = { workspace = true }
codex-protocol = { workspace = true }
codex-shell = { workspace = true }
codex-utils-pty = { workspace = true }
tokio = { workspace = true, features = [
"fs",
"macros",
"process",
"rt",
"sync",
"time",
] }
tracing = { workspace = true, features = ["log"] }
[dev-dependencies]
libc = { workspace = true }
pretty_assertions = { workspace = true }
tempfile = { workspace = true }

View File

@@ -0,0 +1,496 @@
use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use anyhow::bail;
use codex_otel::SessionTelemetry;
use codex_protocol::ThreadId;
use codex_shell::Shell;
use codex_shell::ShellType;
use std::path::Path;
use std::path::PathBuf;
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
use std::time::SystemTime;
use tokio::fs;
use tokio::process::Command;
use tokio::sync::watch;
use tokio::time::timeout;
use tracing::Instrument;
use tracing::info_span;
const SNAPSHOT_TIMEOUT: Duration = Duration::from_secs(10);
pub const SNAPSHOT_RETENTION: Duration = Duration::from_secs(60 * 60 * 24 * 3);
pub const SNAPSHOT_DIR: &str = "shell_snapshots";
const EXCLUDED_EXPORT_VARS: &[&str] = &["PWD", "OLDPWD"];
pub type ShellSnapshotSender = watch::Sender<Option<Arc<ShellSnapshot>>>;
pub type ShellSnapshotReceiver = watch::Receiver<Option<Arc<ShellSnapshot>>>;
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ShellSnapshot {
pub path: PathBuf,
pub cwd: PathBuf,
}
impl ShellSnapshot {
pub fn start_snapshotting(
codex_home: PathBuf,
session_id: ThreadId,
session_cwd: PathBuf,
shell: Shell,
session_telemetry: SessionTelemetry,
) -> (ShellSnapshotSender, ShellSnapshotReceiver) {
let (shell_snapshot_tx, shell_snapshot_rx) = watch::channel(None);
Self::spawn_snapshot_task(
codex_home,
session_id,
session_cwd,
shell,
shell_snapshot_tx.clone(),
session_telemetry,
);
(shell_snapshot_tx, shell_snapshot_rx)
}
pub fn refresh_snapshot(
codex_home: PathBuf,
session_id: ThreadId,
session_cwd: PathBuf,
shell: Shell,
shell_snapshot_tx: ShellSnapshotSender,
session_telemetry: SessionTelemetry,
) {
Self::spawn_snapshot_task(
codex_home,
session_id,
session_cwd,
shell,
shell_snapshot_tx,
session_telemetry,
);
}
fn spawn_snapshot_task(
codex_home: PathBuf,
session_id: ThreadId,
session_cwd: PathBuf,
snapshot_shell: Shell,
shell_snapshot_tx: ShellSnapshotSender,
session_telemetry: SessionTelemetry,
) {
let snapshot_span = info_span!("shell_snapshot", thread_id = %session_id);
tokio::spawn(
async move {
let timer = session_telemetry.start_timer("codex.shell_snapshot.duration_ms", &[]);
let snapshot = ShellSnapshot::try_new(
&codex_home,
session_id,
session_cwd.as_path(),
&snapshot_shell,
)
.await
.map(Arc::new);
let success = snapshot.is_ok();
let success_tag = if success { "true" } else { "false" };
let _ = timer.map(|timer| timer.record(&[("success", success_tag)]));
let mut counter_tags = vec![("success", success_tag)];
if let Some(failure_reason) = snapshot.as_ref().err() {
counter_tags.push(("failure_reason", *failure_reason));
}
session_telemetry.counter("codex.shell_snapshot", /*inc*/ 1, &counter_tags);
let _ = shell_snapshot_tx.send(snapshot.ok());
}
.instrument(snapshot_span),
);
}
async fn try_new(
codex_home: &Path,
session_id: ThreadId,
session_cwd: &Path,
shell: &Shell,
) -> std::result::Result<Self, &'static str> {
let extension = match shell.shell_type {
ShellType::PowerShell => "ps1",
ShellType::Zsh | ShellType::Bash | ShellType::Sh | ShellType::Cmd => "sh",
};
let nonce = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map(|duration| duration.as_nanos())
.unwrap_or(0);
let path = codex_home
.join(SNAPSHOT_DIR)
.join(format!("{session_id}.{nonce}.{extension}"));
let temp_path = codex_home
.join(SNAPSHOT_DIR)
.join(format!("{session_id}.tmp-{nonce}"));
let temp_path = match write_shell_snapshot(shell.shell_type, &temp_path, session_cwd).await
{
Ok(path) => {
tracing::info!("Shell snapshot successfully created: {}", path.display());
path
}
Err(err) => {
tracing::warn!(
"Failed to create shell snapshot for {}: {err:?}",
shell.name()
);
return Err("write_failed");
}
};
let temp_snapshot = Self {
path: temp_path.clone(),
cwd: session_cwd.to_path_buf(),
};
if let Err(err) = validate_snapshot(shell, &temp_snapshot.path, session_cwd).await {
tracing::error!("Shell snapshot validation failed: {err:?}");
remove_snapshot_file(&temp_snapshot.path).await;
return Err("validation_failed");
}
if let Err(err) = fs::rename(&temp_snapshot.path, &path).await {
tracing::warn!("Failed to finalize shell snapshot: {err:?}");
remove_snapshot_file(&temp_snapshot.path).await;
return Err("write_failed");
}
Ok(Self {
path,
cwd: session_cwd.to_path_buf(),
})
}
}
impl Drop for ShellSnapshot {
fn drop(&mut self) {
if let Err(err) = std::fs::remove_file(&self.path) {
tracing::warn!(
"Failed to delete shell snapshot at {:?}: {err:?}",
self.path
);
}
}
}
async fn write_shell_snapshot(
shell_type: ShellType,
output_path: &Path,
cwd: &Path,
) -> Result<PathBuf> {
if shell_type == ShellType::PowerShell || shell_type == ShellType::Cmd {
bail!("Shell snapshot not supported yet for {shell_type:?}");
}
let shell = codex_shell::get_shell(shell_type, /*path*/ None)
.with_context(|| format!("No available shell for {shell_type:?}"))?;
let raw_snapshot = capture_snapshot(&shell, cwd).await?;
let snapshot = strip_snapshot_preamble(&raw_snapshot)?;
if let Some(parent) = output_path.parent() {
let parent_display = parent.display();
fs::create_dir_all(parent)
.await
.with_context(|| format!("Failed to create snapshot parent {parent_display}"))?;
}
let snapshot_path = output_path.display();
fs::write(output_path, snapshot)
.await
.with_context(|| format!("Failed to write snapshot to {snapshot_path}"))?;
Ok(output_path.to_path_buf())
}
async fn capture_snapshot(shell: &Shell, cwd: &Path) -> Result<String> {
match shell.shell_type {
ShellType::Zsh => run_shell_script(shell, &zsh_snapshot_script(), cwd).await,
ShellType::Bash => run_shell_script(shell, &bash_snapshot_script(), cwd).await,
ShellType::Sh => run_shell_script(shell, &sh_snapshot_script(), cwd).await,
ShellType::PowerShell => run_shell_script(shell, powershell_snapshot_script(), cwd).await,
ShellType::Cmd => bail!(
"Shell snapshotting is not yet supported for {:?}",
shell.shell_type
),
}
}
fn strip_snapshot_preamble(snapshot: &str) -> Result<String> {
let marker = "# Snapshot file";
let Some(start) = snapshot.find(marker) else {
bail!("Snapshot output missing marker {marker}");
};
Ok(snapshot[start..].to_string())
}
async fn validate_snapshot(shell: &Shell, snapshot_path: &Path, cwd: &Path) -> Result<()> {
let snapshot_path_display = snapshot_path.display();
let script = format!("set -e; . \"{snapshot_path_display}\"");
run_script_with_timeout(
shell,
&script,
SNAPSHOT_TIMEOUT,
/*use_login_shell*/ false,
cwd,
)
.await
.map(|_| ())
}
async fn run_shell_script(shell: &Shell, script: &str, cwd: &Path) -> Result<String> {
run_script_with_timeout(
shell,
script,
SNAPSHOT_TIMEOUT,
/*use_login_shell*/ true,
cwd,
)
.await
}
async fn run_script_with_timeout(
shell: &Shell,
script: &str,
snapshot_timeout: Duration,
use_login_shell: bool,
cwd: &Path,
) -> Result<String> {
let args = shell.derive_exec_args(script, use_login_shell);
let shell_name = shell.name();
let mut handler = Command::new(&args[0]);
handler.args(&args[1..]);
handler.stdin(Stdio::null());
handler.current_dir(cwd);
#[cfg(unix)]
unsafe {
handler.pre_exec(|| {
codex_utils_pty::process_group::detach_from_tty()?;
Ok(())
});
}
handler.kill_on_drop(true);
let output = timeout(snapshot_timeout, handler.output())
.await
.map_err(|_| anyhow!("Snapshot command timed out for {shell_name}"))?
.with_context(|| format!("Failed to execute {shell_name}"))?;
if !output.status.success() {
let status = output.status;
let stderr = String::from_utf8_lossy(&output.stderr);
bail!("Snapshot command exited with status {status}: {stderr}");
}
Ok(String::from_utf8_lossy(&output.stdout).into_owned())
}
fn excluded_exports_regex() -> String {
EXCLUDED_EXPORT_VARS.join("|")
}
fn zsh_snapshot_script() -> String {
let excluded = excluded_exports_regex();
let script = r##"if [[ -n "$ZDOTDIR" ]]; then
rc="$ZDOTDIR/.zshrc"
else
rc="$HOME/.zshrc"
fi
[[ -r "$rc" ]] && . "$rc"
print '# Snapshot file'
print '# Unset all aliases to avoid conflicts with functions'
print 'unalias -a 2>/dev/null || true'
print '# Functions'
functions
print ''
setopt_count=$(setopt | wc -l | tr -d ' ')
print "# setopts $setopt_count"
setopt | sed 's/^/setopt /'
print ''
alias_count=$(alias -L | wc -l | tr -d ' ')
print "# aliases $alias_count"
alias -L
print ''
export_lines=$(export -p | awk '
/^(export|declare -x|typeset -x) / {
line=$0
name=line
sub(/^(export|declare -x|typeset -x) /, "", name)
sub(/=.*/, "", name)
if (name ~ /^(EXCLUDED_EXPORTS)$/) {
next
}
if (name ~ /^[A-Za-z_][A-Za-z0-9_]*$/) {
print line
}
}')
export_count=$(printf '%s\n' "$export_lines" | sed '/^$/d' | wc -l | tr -d ' ')
print "# exports $export_count"
if [[ -n "$export_lines" ]]; then
print -r -- "$export_lines"
fi
"##;
script.replace("EXCLUDED_EXPORTS", &excluded)
}
fn bash_snapshot_script() -> String {
let excluded = excluded_exports_regex();
let script = r##"if [ -z "$BASH_ENV" ] && [ -r "$HOME/.bashrc" ]; then
. "$HOME/.bashrc"
fi
echo '# Snapshot file'
echo '# Unset all aliases to avoid conflicts with functions'
unalias -a 2>/dev/null || true
echo '# Functions'
declare -f
echo ''
bash_opts=$(set -o | awk '$2=="on"{print $1}')
bash_opt_count=$(printf '%s\n' "$bash_opts" | sed '/^$/d' | wc -l | tr -d ' ')
echo "# setopts $bash_opt_count"
if [ -n "$bash_opts" ]; then
printf 'set -o %s\n' $bash_opts
fi
echo ''
alias_count=$(alias -p | wc -l | tr -d ' ')
echo "# aliases $alias_count"
alias -p
echo ''
export_lines=$(
while IFS= read -r name; do
if [[ "$name" =~ ^(EXCLUDED_EXPORTS)$ ]]; then
continue
fi
if [[ ! "$name" =~ ^[A-Za-z_][A-Za-z0-9_]*$ ]]; then
continue
fi
declare -xp "$name" 2>/dev/null || true
done < <(compgen -e)
)
export_count=$(printf '%s\n' "$export_lines" | sed '/^$/d' | wc -l | tr -d ' ')
echo "# exports $export_count"
if [ -n "$export_lines" ]; then
printf '%s\n' "$export_lines"
fi
"##;
script.replace("EXCLUDED_EXPORTS", &excluded)
}
fn sh_snapshot_script() -> String {
let excluded = excluded_exports_regex();
let script = r##"if [ -n "$ENV" ] && [ -r "$ENV" ]; then
. "$ENV"
fi
echo '# Snapshot file'
echo '# Unset all aliases to avoid conflicts with functions'
unalias -a 2>/dev/null || true
echo '# Functions'
if command -v typeset >/dev/null 2>&1; then
typeset -f
elif command -v declare >/dev/null 2>&1; then
declare -f
fi
echo ''
if set -o >/dev/null 2>&1; then
sh_opts=$(set -o | awk '$2=="on"{print $1}')
sh_opt_count=$(printf '%s\n' "$sh_opts" | sed '/^$/d' | wc -l | tr -d ' ')
echo "# setopts $sh_opt_count"
if [ -n "$sh_opts" ]; then
printf 'set -o %s\n' $sh_opts
fi
else
echo '# setopts 0'
fi
echo ''
if alias >/dev/null 2>&1; then
alias_count=$(alias | wc -l | tr -d ' ')
echo "# aliases $alias_count"
alias
echo ''
else
echo '# aliases 0'
fi
if export -p >/dev/null 2>&1; then
export_lines=$(export -p | awk '
/^(export|declare -x|typeset -x) / {
line=$0
name=line
sub(/^(export|declare -x|typeset -x) /, "", name)
sub(/=.*/, "", name)
if (name ~ /^(EXCLUDED_EXPORTS)$/) {
next
}
if (name ~ /^[A-Za-z_][A-Za-z0-9_]*$/) {
print line
}
}')
export_count=$(printf '%s\n' "$export_lines" | sed '/^$/d' | wc -l | tr -d ' ')
echo "# exports $export_count"
if [ -n "$export_lines" ]; then
printf '%s\n' "$export_lines"
fi
else
export_count=$(env | sort | awk -F= '$1 ~ /^[A-Za-z_][A-Za-z0-9_]*$/ { count++ } END { print count }')
echo "# exports $export_count"
env | sort | while IFS='=' read -r key value; do
case "$key" in
""|[0-9]*|*[!A-Za-z0-9_]*|EXCLUDED_EXPORTS) continue ;;
esac
escaped=$(printf "%s" "$value" | sed "s/'/'\"'\"'/g")
printf "export %s='%s'\n" "$key" "$escaped"
done
fi
"##;
script.replace("EXCLUDED_EXPORTS", &excluded)
}
fn powershell_snapshot_script() -> &'static str {
r##"$ErrorActionPreference = 'Stop'
Write-Output '# Snapshot file'
Write-Output '# Unset all aliases to avoid conflicts with functions'
Write-Output 'Remove-Item Alias:* -ErrorAction SilentlyContinue'
Write-Output '# Functions'
Get-ChildItem Function: | ForEach-Object {
"function {0} {{`n{1}`n}}" -f $_.Name, $_.Definition
}
Write-Output ''
$aliases = Get-Alias
Write-Output ("# aliases " + $aliases.Count)
$aliases | ForEach-Object {
"Set-Alias -Name {0} -Value {1}" -f $_.Name, $_.Definition
}
Write-Output ''
$envVars = Get-ChildItem Env:
Write-Output ("# exports " + $envVars.Count)
$envVars | ForEach-Object {
$escaped = $_.Value -replace "'", "''"
"`$env:{0}='{1}'" -f $_.Name, $escaped
}
"##
}
pub async fn remove_snapshot_file(path: &Path) {
if let Err(err) = fs::remove_file(path).await {
tracing::warn!("Failed to delete shell snapshot at {:?}: {err:?}", path);
}
}
pub fn snapshot_session_id_from_file_name(file_name: &str) -> Option<&str> {
let (stem, extension) = file_name.rsplit_once('.')?;
match extension {
"sh" | "ps1" => Some(
stem.split_once('.')
.map_or(stem, |(session_id, _generation)| session_id),
),
_ if extension.starts_with("tmp-") => Some(stem),
_ => None,
}
}
#[cfg(test)]
#[path = "shell_snapshot_tests.rs"]
mod tests;

View File

@@ -0,0 +1,378 @@
use super::*;
use pretty_assertions::assert_eq;
#[cfg(unix)]
use std::process::Command as StdCommand;
use tempfile::tempdir;
#[cfg(unix)]
struct BlockingStdinPipe {
original: i32,
write_end: i32,
}
#[cfg(unix)]
impl BlockingStdinPipe {
fn install() -> Result<Self> {
let mut fds = [0i32; 2];
if unsafe { libc::pipe(fds.as_mut_ptr()) } == -1 {
return Err(std::io::Error::last_os_error()).context("create stdin pipe");
}
let original = unsafe { libc::dup(libc::STDIN_FILENO) };
if original == -1 {
let err = std::io::Error::last_os_error();
unsafe {
libc::close(fds[0]);
libc::close(fds[1]);
}
return Err(err).context("dup stdin");
}
if unsafe { libc::dup2(fds[0], libc::STDIN_FILENO) } == -1 {
let err = std::io::Error::last_os_error();
unsafe {
libc::close(fds[0]);
libc::close(fds[1]);
libc::close(original);
}
return Err(err).context("replace stdin");
}
unsafe {
libc::close(fds[0]);
}
Ok(Self {
original,
write_end: fds[1],
})
}
}
#[cfg(unix)]
impl Drop for BlockingStdinPipe {
fn drop(&mut self) {
unsafe {
libc::dup2(self.original, libc::STDIN_FILENO);
libc::close(self.original);
libc::close(self.write_end);
}
}
}
#[cfg(not(target_os = "windows"))]
fn assert_posix_snapshot_sections(snapshot: &str) {
assert!(snapshot.contains("# Snapshot file"));
assert!(snapshot.contains("aliases "));
assert!(snapshot.contains("exports "));
assert!(
snapshot.contains("PATH"),
"snapshot should capture a PATH export"
);
assert!(snapshot.contains("setopts "));
}
async fn get_snapshot(shell_type: ShellType) -> Result<String> {
let dir = tempdir()?;
let path = dir.path().join("snapshot.sh");
write_shell_snapshot(shell_type, &path, dir.path()).await?;
let content = fs::read_to_string(&path).await?;
Ok(content)
}
#[test]
fn strip_snapshot_preamble_removes_leading_output() {
let snapshot = "noise\n# Snapshot file\nexport PATH=/bin\n";
let cleaned = strip_snapshot_preamble(snapshot).expect("snapshot marker exists");
assert_eq!(cleaned, "# Snapshot file\nexport PATH=/bin\n");
}
#[test]
fn strip_snapshot_preamble_requires_marker() {
let result = strip_snapshot_preamble("missing header");
assert!(result.is_err());
}
#[test]
fn snapshot_file_name_parser_supports_legacy_and_suffixed_names() {
let session_id = "019cf82b-6a62-7700-bbbd-46909794ef89";
assert_eq!(
snapshot_session_id_from_file_name(&format!("{session_id}.sh")),
Some(session_id)
);
assert_eq!(
snapshot_session_id_from_file_name(&format!("{session_id}.123.sh")),
Some(session_id)
);
assert_eq!(
snapshot_session_id_from_file_name(&format!("{session_id}.tmp-123")),
Some(session_id)
);
assert_eq!(
snapshot_session_id_from_file_name("not-a-snapshot.txt"),
None
);
}
#[cfg(unix)]
#[test]
fn bash_snapshot_filters_invalid_exports() -> Result<()> {
let output = StdCommand::new("/bin/bash")
.arg("-c")
.arg(bash_snapshot_script())
.env("BASH_ENV", "/dev/null")
.env("VALID_NAME", "ok")
.env("PWD", "/tmp/stale")
.env("NEXTEST_BIN_EXE_codex-write-config-schema", "/path/to/bin")
.env("BAD-NAME", "broken")
.output()?;
assert!(output.status.success());
let stdout = String::from_utf8_lossy(&output.stdout);
assert!(stdout.contains("VALID_NAME"));
assert!(!stdout.contains("PWD=/tmp/stale"));
assert!(!stdout.contains("NEXTEST_BIN_EXE_codex-write-config-schema"));
assert!(!stdout.contains("BAD-NAME"));
Ok(())
}
#[cfg(unix)]
#[test]
fn bash_snapshot_preserves_multiline_exports() -> Result<()> {
let multiline_cert = "-----BEGIN CERTIFICATE-----\nabc\n-----END CERTIFICATE-----";
let output = StdCommand::new("/bin/bash")
.arg("-c")
.arg(bash_snapshot_script())
.env("BASH_ENV", "/dev/null")
.env("MULTILINE_CERT", multiline_cert)
.output()?;
assert!(output.status.success());
let stdout = String::from_utf8_lossy(&output.stdout);
assert!(
stdout.contains("MULTILINE_CERT=") || stdout.contains("MULTILINE_CERT"),
"snapshot should include the multiline export name"
);
let dir = tempdir()?;
let snapshot_path = dir.path().join("snapshot.sh");
std::fs::write(&snapshot_path, stdout.as_bytes())?;
let validate = StdCommand::new("/bin/bash")
.arg("-c")
.arg("set -e; . \"$1\"")
.arg("bash")
.arg(&snapshot_path)
.env("BASH_ENV", "/dev/null")
.output()?;
assert!(
validate.status.success(),
"snapshot validation failed: {}",
String::from_utf8_lossy(&validate.stderr)
);
Ok(())
}
#[cfg(unix)]
#[tokio::test]
async fn try_new_creates_and_deletes_snapshot_file() -> Result<()> {
let dir = tempdir()?;
let shell = Shell {
shell_type: ShellType::Bash,
shell_path: PathBuf::from("/bin/bash"),
};
let snapshot = ShellSnapshot::try_new(dir.path(), ThreadId::new(), dir.path(), &shell)
.await
.expect("snapshot should be created");
let path = snapshot.path.clone();
assert!(path.exists());
assert_eq!(snapshot.cwd, dir.path().to_path_buf());
drop(snapshot);
assert!(!path.exists());
Ok(())
}
#[cfg(unix)]
#[tokio::test]
async fn try_new_uses_distinct_generation_paths() -> Result<()> {
let dir = tempdir()?;
let session_id = ThreadId::new();
let shell = Shell {
shell_type: ShellType::Bash,
shell_path: PathBuf::from("/bin/bash"),
};
let initial_snapshot = ShellSnapshot::try_new(dir.path(), session_id, dir.path(), &shell)
.await
.expect("initial snapshot should be created");
let refreshed_snapshot = ShellSnapshot::try_new(dir.path(), session_id, dir.path(), &shell)
.await
.expect("refreshed snapshot should be created");
let initial_path = initial_snapshot.path.clone();
let refreshed_path = refreshed_snapshot.path.clone();
assert_ne!(initial_path, refreshed_path);
assert_eq!(initial_path.exists(), true);
assert_eq!(refreshed_path.exists(), true);
drop(initial_snapshot);
assert_eq!(initial_path.exists(), false);
assert_eq!(refreshed_path.exists(), true);
drop(refreshed_snapshot);
assert_eq!(refreshed_path.exists(), false);
Ok(())
}
#[cfg(unix)]
#[tokio::test]
async fn snapshot_shell_does_not_inherit_stdin() -> Result<()> {
let _stdin_guard = BlockingStdinPipe::install()?;
let dir = tempdir()?;
let home = dir.path();
let read_status_path = home.join("stdin-read-status");
let read_status_display = read_status_path.display();
let bashrc = format!("read -t 1 -r ignored\nprintf '%s' \"$?\" > \"{read_status_display}\"\n");
fs::write(home.join(".bashrc"), bashrc).await?;
let shell = Shell {
shell_type: ShellType::Bash,
shell_path: PathBuf::from("/bin/bash"),
};
let home_display = home.display();
let script = format!(
"HOME=\"{home_display}\"; export HOME; {}",
bash_snapshot_script()
);
let output = run_script_with_timeout(
&shell,
&script,
Duration::from_secs(2),
/*use_login_shell*/ true,
home,
)
.await
.context("run snapshot command")?;
let read_status = fs::read_to_string(&read_status_path)
.await
.context("read stdin probe status")?;
assert_eq!(
read_status, "1",
"expected shell startup read to see EOF on stdin; status={read_status:?}"
);
assert!(
output.contains("# Snapshot file"),
"expected snapshot marker in output; output={output:?}"
);
Ok(())
}
#[cfg(target_os = "linux")]
#[tokio::test]
async fn timed_out_snapshot_shell_is_terminated() -> Result<()> {
use std::process::Stdio;
use tokio::time::Duration as TokioDuration;
use tokio::time::Instant;
use tokio::time::sleep;
let dir = tempdir()?;
let pid_path = dir.path().join("pid");
let script = format!("echo $$ > \"{}\"; sleep 30", pid_path.display());
let shell = Shell {
shell_type: ShellType::Sh,
shell_path: PathBuf::from("/bin/sh"),
};
let err = run_script_with_timeout(
&shell,
&script,
Duration::from_secs(1),
/*use_login_shell*/ true,
dir.path(),
)
.await
.expect_err("snapshot shell should time out");
assert!(
err.to_string().contains("timed out"),
"expected timeout error, got {err:?}"
);
let pid = fs::read_to_string(&pid_path)
.await
.expect("snapshot shell writes its pid before timing out")
.trim()
.parse::<i32>()?;
let deadline = Instant::now() + TokioDuration::from_secs(1);
loop {
let kill_status = StdCommand::new("kill")
.arg("-0")
.arg(pid.to_string())
.stderr(Stdio::null())
.stdout(Stdio::null())
.status()?;
if !kill_status.success() {
break;
}
if Instant::now() >= deadline {
panic!("timed out snapshot shell is still alive after grace period");
}
sleep(TokioDuration::from_millis(50)).await;
}
Ok(())
}
#[cfg(target_os = "macos")]
#[tokio::test]
async fn macos_zsh_snapshot_includes_sections() -> Result<()> {
let snapshot = get_snapshot(ShellType::Zsh).await?;
assert_posix_snapshot_sections(&snapshot);
Ok(())
}
#[cfg(target_os = "linux")]
#[tokio::test]
async fn linux_bash_snapshot_includes_sections() -> Result<()> {
let snapshot = get_snapshot(ShellType::Bash).await?;
assert_posix_snapshot_sections(&snapshot);
Ok(())
}
#[cfg(target_os = "linux")]
#[tokio::test]
async fn linux_sh_snapshot_includes_sections() -> Result<()> {
let snapshot = get_snapshot(ShellType::Sh).await?;
assert_posix_snapshot_sections(&snapshot);
Ok(())
}
#[cfg(target_os = "windows")]
#[ignore]
#[tokio::test]
async fn windows_powershell_snapshot_includes_sections() -> Result<()> {
let snapshot = get_snapshot(ShellType::PowerShell).await?;
assert!(snapshot.contains("# Snapshot file"));
assert!(snapshot.contains("aliases "));
assert!(snapshot.contains("exports "));
Ok(())
}

View File

@@ -0,0 +1,7 @@
load("//:defs.bzl", "codex_rust_crate")
codex_rust_crate(
name = "shell",
crate_name = "codex_shell",
test_tags = ["no-sandbox"],
)

15
codex-rs/shell/Cargo.toml Normal file
View File

@@ -0,0 +1,15 @@
[package]
edition.workspace = true
license.workspace = true
name = "codex-shell"
version.workspace = true
[lints]
workspace = true
[dependencies]
serde = { workspace = true, features = ["derive"] }
which = { workspace = true }
[target.'cfg(unix)'.dependencies]
libc = { workspace = true }

356
codex-rs/shell/src/lib.rs Normal file
View File

@@ -0,0 +1,356 @@
use serde::Deserialize;
use serde::Serialize;
use std::path::Path;
use std::path::PathBuf;
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum ShellType {
Zsh,
Bash,
PowerShell,
Sh,
Cmd,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Shell {
pub shell_type: ShellType,
pub shell_path: PathBuf,
}
impl Shell {
pub fn name(&self) -> &'static str {
match self.shell_type {
ShellType::Zsh => "zsh",
ShellType::Bash => "bash",
ShellType::PowerShell => "powershell",
ShellType::Sh => "sh",
ShellType::Cmd => "cmd",
}
}
/// Takes a string of shell and returns the full list of command args to
/// use with `exec()` to run the shell command.
pub fn derive_exec_args(&self, command: &str, use_login_shell: bool) -> Vec<String> {
match self.shell_type {
ShellType::Zsh | ShellType::Bash | ShellType::Sh => {
let arg = if use_login_shell { "-lc" } else { "-c" };
vec![
self.shell_path.to_string_lossy().to_string(),
arg.to_string(),
command.to_string(),
]
}
ShellType::PowerShell => {
let mut args = vec![self.shell_path.to_string_lossy().to_string()];
if !use_login_shell {
args.push("-NoProfile".to_string());
}
args.push("-Command".to_string());
args.push(command.to_string());
args
}
ShellType::Cmd => {
let mut args = vec![self.shell_path.to_string_lossy().to_string()];
args.push("/c".to_string());
args.push(command.to_string());
args
}
}
}
}
pub fn detect_shell_type(shell_path: &Path) -> Option<ShellType> {
match shell_path.as_os_str().to_str() {
Some("zsh") => Some(ShellType::Zsh),
Some("sh") => Some(ShellType::Sh),
Some("cmd") => Some(ShellType::Cmd),
Some("bash") => Some(ShellType::Bash),
Some("pwsh") => Some(ShellType::PowerShell),
Some("powershell") => Some(ShellType::PowerShell),
_ => {
let shell_name = shell_path.file_stem()?;
let shell_name_path = Path::new(shell_name);
if shell_name_path == shell_path {
return None;
}
detect_shell_type(shell_name_path)
}
}
}
pub fn get_shell_by_model_provided_path(shell_path: &Path) -> Shell {
detect_shell_type(shell_path)
.and_then(|shell_type| get_shell(shell_type, Some(shell_path)))
.unwrap_or_else(ultimate_fallback_shell)
}
pub fn get_shell(shell_type: ShellType, path: Option<&Path>) -> Option<Shell> {
match shell_type {
ShellType::Zsh => get_zsh_shell(path),
ShellType::Bash => get_bash_shell(path),
ShellType::PowerShell => get_powershell_shell(path),
ShellType::Sh => get_sh_shell(path),
ShellType::Cmd => get_cmd_shell(path),
}
}
pub fn default_user_shell() -> Shell {
default_user_shell_from_path(get_user_shell_path())
}
#[cfg(unix)]
fn get_user_shell_path() -> Option<PathBuf> {
let uid = unsafe { libc::getuid() };
use std::ffi::CStr;
use std::mem::MaybeUninit;
use std::ptr;
let mut passwd = MaybeUninit::<libc::passwd>::uninit();
// We cannot use getpwuid here: it returns pointers into libc-managed
// storage, which is not safe to read concurrently on all targets (the musl
// static build used by the CLI can segfault when parallel callers race on
// that buffer). getpwuid_r keeps the passwd data in caller-owned memory.
let suggested_buffer_len = unsafe { libc::sysconf(libc::_SC_GETPW_R_SIZE_MAX) };
let buffer_len = usize::try_from(suggested_buffer_len)
.ok()
.filter(|len| *len > 0)
.unwrap_or(1024);
let mut buffer = vec![0; buffer_len];
loop {
let mut result = ptr::null_mut();
let status = unsafe {
libc::getpwuid_r(
uid,
passwd.as_mut_ptr(),
buffer.as_mut_ptr().cast(),
buffer.len(),
&mut result,
)
};
if status == 0 {
if result.is_null() {
return None;
}
let passwd = unsafe { passwd.assume_init_ref() };
if passwd.pw_shell.is_null() {
return None;
}
let shell_path = unsafe { CStr::from_ptr(passwd.pw_shell) }
.to_string_lossy()
.into_owned();
return Some(PathBuf::from(shell_path));
}
if status != libc::ERANGE {
return None;
}
// Retry with a larger buffer until libc can materialize the passwd entry.
let new_len = buffer.len().checked_mul(2)?;
if new_len > 1024 * 1024 {
return None;
}
buffer.resize(new_len, 0);
}
}
#[cfg(not(unix))]
fn get_user_shell_path() -> Option<PathBuf> {
None
}
fn default_user_shell_from_path(user_shell_path: Option<PathBuf>) -> Shell {
if cfg!(windows) {
get_shell(ShellType::PowerShell, /*path*/ None).unwrap_or_else(ultimate_fallback_shell)
} else {
let user_default_shell = user_shell_path
.and_then(|shell| detect_shell_type(&shell))
.and_then(|shell_type| get_shell(shell_type, /*path*/ None));
let shell_with_fallback = if cfg!(target_os = "macos") {
user_default_shell
.or_else(|| get_shell(ShellType::Zsh, /*path*/ None))
.or_else(|| get_shell(ShellType::Bash, /*path*/ None))
} else {
user_default_shell
.or_else(|| get_shell(ShellType::Bash, /*path*/ None))
.or_else(|| get_shell(ShellType::Zsh, /*path*/ None))
};
shell_with_fallback.unwrap_or_else(ultimate_fallback_shell)
}
}
fn file_exists(path: &Path) -> Option<PathBuf> {
if std::fs::metadata(path).is_ok_and(|metadata| metadata.is_file()) {
Some(path.to_path_buf())
} else {
None
}
}
fn get_shell_path(
shell_type: ShellType,
provided_path: Option<&Path>,
binary_name: &str,
fallback_paths: &[&str],
) -> Option<PathBuf> {
// If exact provided path exists, use it.
if provided_path.and_then(file_exists).is_some() {
return provided_path.map(Path::to_path_buf);
}
// Check whether the shell we are trying to load is the user's default
// shell and prefer that exact path when available.
let default_shell_path = get_user_shell_path();
if let Some(default_shell_path) = default_shell_path
&& detect_shell_type(&default_shell_path) == Some(shell_type)
&& file_exists(&default_shell_path).is_some()
{
return Some(default_shell_path);
}
if let Ok(path) = which::which(binary_name) {
return Some(path);
}
for path in fallback_paths {
if let Some(path) = file_exists(Path::new(path)) {
return Some(path);
}
}
None
}
fn get_zsh_shell(path: Option<&Path>) -> Option<Shell> {
let shell_path = get_shell_path(ShellType::Zsh, path, "zsh", &["/bin/zsh"]);
shell_path.map(|shell_path| Shell {
shell_type: ShellType::Zsh,
shell_path,
})
}
fn get_bash_shell(path: Option<&Path>) -> Option<Shell> {
let shell_path = get_shell_path(ShellType::Bash, path, "bash", &["/bin/bash"]);
shell_path.map(|shell_path| Shell {
shell_type: ShellType::Bash,
shell_path,
})
}
fn get_sh_shell(path: Option<&Path>) -> Option<Shell> {
let shell_path = get_shell_path(ShellType::Sh, path, "sh", &["/bin/sh"]);
shell_path.map(|shell_path| Shell {
shell_type: ShellType::Sh,
shell_path,
})
}
fn get_powershell_shell(path: Option<&Path>) -> Option<Shell> {
let shell_path = get_shell_path(
ShellType::PowerShell,
path,
"pwsh",
&["/usr/local/bin/pwsh"],
)
.or_else(|| get_shell_path(ShellType::PowerShell, path, "powershell", &[]));
shell_path.map(|shell_path| Shell {
shell_type: ShellType::PowerShell,
shell_path,
})
}
fn get_cmd_shell(path: Option<&Path>) -> Option<Shell> {
let shell_path = get_shell_path(ShellType::Cmd, path, "cmd", &[]);
shell_path.map(|shell_path| Shell {
shell_type: ShellType::Cmd,
shell_path,
})
}
fn ultimate_fallback_shell() -> Shell {
if cfg!(windows) {
Shell {
shell_type: ShellType::Cmd,
shell_path: PathBuf::from("cmd.exe"),
}
} else {
Shell {
shell_type: ShellType::Sh,
shell_path: PathBuf::from("/bin/sh"),
}
}
}
#[cfg(test)]
mod detect_shell_type_tests {
use super::*;
#[test]
fn test_detect_shell_type() {
assert_eq!(detect_shell_type(Path::new("zsh")), Some(ShellType::Zsh));
assert_eq!(detect_shell_type(Path::new("bash")), Some(ShellType::Bash));
assert_eq!(
detect_shell_type(Path::new("pwsh")),
Some(ShellType::PowerShell)
);
assert_eq!(
detect_shell_type(Path::new("powershell")),
Some(ShellType::PowerShell)
);
assert_eq!(detect_shell_type(Path::new("fish")), None);
assert_eq!(detect_shell_type(Path::new("other")), None);
assert_eq!(
detect_shell_type(Path::new("/bin/zsh")),
Some(ShellType::Zsh)
);
assert_eq!(
detect_shell_type(Path::new("/bin/bash")),
Some(ShellType::Bash)
);
assert_eq!(
detect_shell_type(Path::new("powershell.exe")),
Some(ShellType::PowerShell)
);
assert_eq!(
detect_shell_type(Path::new(if cfg!(windows) {
"C:\\windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe"
} else {
"/usr/local/bin/pwsh"
})),
Some(ShellType::PowerShell)
);
assert_eq!(
detect_shell_type(Path::new("pwsh.exe")),
Some(ShellType::PowerShell)
);
assert_eq!(
detect_shell_type(Path::new("/usr/local/bin/pwsh")),
Some(ShellType::PowerShell)
);
assert_eq!(detect_shell_type(Path::new("/bin/sh")), Some(ShellType::Sh));
assert_eq!(detect_shell_type(Path::new("sh")), Some(ShellType::Sh));
assert_eq!(detect_shell_type(Path::new("cmd")), Some(ShellType::Cmd));
assert_eq!(
detect_shell_type(Path::new("cmd.exe")),
Some(ShellType::Cmd)
);
}
}
#[cfg(test)]
#[path = "shell_tests.rs"]
mod tests;

View File

@@ -1,4 +1,5 @@
use super::*;
use std::path::Path;
use std::path::PathBuf;
use std::process::Command;
@@ -9,7 +10,7 @@ fn detects_zsh() {
let shell_path = zsh_shell.shell_path;
assert_eq!(shell_path, std::path::Path::new("/bin/zsh"));
assert_eq!(shell_path, Path::new("/bin/zsh"));
}
#[test]
@@ -19,7 +20,7 @@ fn fish_fallback_to_zsh() {
let shell_path = zsh_shell.shell_path;
assert_eq!(shell_path, std::path::Path::new("/bin/zsh"));
assert_eq!(shell_path, Path::new("/bin/zsh"));
}
#[test]
@@ -106,7 +107,6 @@ fn derive_exec_args() {
let test_bash_shell = Shell {
shell_type: ShellType::Bash,
shell_path: PathBuf::from("/bin/bash"),
shell_snapshot: empty_shell_snapshot_receiver(),
};
assert_eq!(
test_bash_shell.derive_exec_args("echo hello", /*use_login_shell*/ false),
@@ -120,7 +120,6 @@ fn derive_exec_args() {
let test_zsh_shell = Shell {
shell_type: ShellType::Zsh,
shell_path: PathBuf::from("/bin/zsh"),
shell_snapshot: empty_shell_snapshot_receiver(),
};
assert_eq!(
test_zsh_shell.derive_exec_args("echo hello", /*use_login_shell*/ false),
@@ -134,7 +133,6 @@ fn derive_exec_args() {
let test_powershell_shell = Shell {
shell_type: ShellType::PowerShell,
shell_path: PathBuf::from("pwsh.exe"),
shell_snapshot: empty_shell_snapshot_receiver(),
};
assert_eq!(
test_powershell_shell.derive_exec_args("echo hello", /*use_login_shell*/ false),
@@ -146,8 +144,9 @@ fn derive_exec_args() {
);
}
#[tokio::test]
async fn test_current_shell_detects_zsh() {
#[test]
#[cfg(unix)]
fn test_current_shell_detects_zsh() {
let shell = Command::new("sh")
.arg("-c")
.arg("echo $SHELL")
@@ -161,14 +160,13 @@ async fn test_current_shell_detects_zsh() {
Shell {
shell_type: ShellType::Zsh,
shell_path: PathBuf::from(shell_path),
shell_snapshot: empty_shell_snapshot_receiver(),
}
);
}
}
#[tokio::test]
async fn detects_powershell_as_default() {
#[test]
fn detects_powershell_as_default() {
if !cfg!(windows) {
return;
}

View File

@@ -12,6 +12,7 @@ codex-app-server-protocol = { workspace = true }
codex-code-mode = { workspace = true }
codex-features = { workspace = true }
codex-protocol = { workspace = true }
codex-shell = { workspace = true }
codex-utils-absolute-path = { workspace = true }
codex-utils-pty = { workspace = true }
rmcp = { workspace = true, default-features = false, features = [

View File

@@ -86,7 +86,6 @@ pub use responses_api::mcp_tool_to_deferred_responses_api_tool;
pub use responses_api::mcp_tool_to_responses_api_tool;
pub use responses_api::tool_definition_to_responses_api_tool;
pub use tool_config::ShellCommandBackendConfig;
pub use tool_config::ToolUserShellType;
pub use tool_config::ToolsConfig;
pub use tool_config::ToolsConfigParams;
pub use tool_config::UnifiedExecShellMode;

View File

@@ -13,6 +13,7 @@ use codex_protocol::openai_models::WebSearchToolType;
use codex_protocol::protocol::SandboxPolicy;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use codex_shell::ShellType;
use codex_utils_absolute_path::AbsolutePathBuf;
use std::path::PathBuf;
@@ -22,15 +23,6 @@ pub enum ShellCommandBackendConfig {
ZshFork,
}
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ToolUserShellType {
Zsh,
Bash,
PowerShell,
Sh,
Cmd,
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum UnifiedExecShellMode {
Direct,
@@ -46,13 +38,13 @@ pub struct ZshForkConfig {
impl UnifiedExecShellMode {
pub fn for_session(
shell_command_backend: ShellCommandBackendConfig,
user_shell_type: ToolUserShellType,
user_shell_type: ShellType,
shell_zsh_path: Option<&PathBuf>,
main_execve_wrapper_exe: Option<&PathBuf>,
) -> Self {
if cfg!(unix)
&& shell_command_backend == ShellCommandBackendConfig::ZshFork
&& matches!(user_shell_type, ToolUserShellType::Zsh)
&& matches!(user_shell_type, ShellType::Zsh)
&& let (Some(shell_zsh_path), Some(main_execve_wrapper_exe)) =
(shell_zsh_path, main_execve_wrapper_exe)
&& let (Ok(shell_zsh_path), Ok(main_execve_wrapper_exe)) = (
@@ -246,7 +238,7 @@ impl ToolsConfig {
pub fn with_unified_exec_shell_mode_for_session(
mut self,
user_shell_type: ToolUserShellType,
user_shell_type: ShellType,
shell_zsh_path: Option<&PathBuf>,
main_execve_wrapper_exe: Option<&PathBuf>,
) -> Self {

View File

@@ -9,6 +9,7 @@ use codex_protocol::openai_models::ModelInfo;
use codex_protocol::protocol::SandboxPolicy;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use codex_shell::ShellType;
use codex_utils_absolute_path::AbsolutePathBuf;
use pretty_assertions::assert_eq;
use serde_json::json;
@@ -103,7 +104,7 @@ fn shell_zsh_fork_prefers_shell_command_over_unified_exec() {
assert_eq!(
tools_config
.with_unified_exec_shell_mode_for_session(
ToolUserShellType::Zsh,
ShellType::Zsh,
Some(&PathBuf::from(if cfg!(windows) {
r"C:\opt\codex\zsh"
} else {