mirror of
https://github.com/openai/codex.git
synced 2026-05-04 13:21:54 +03:00
refactor: normalize unix module layout for exec-server and shell-escalation (#12556)
## Why Shell execution refactoring in `exec-server` had become split between duplicated code paths, which blocked a clean introduction of the new reusable shell escalation flow. This commit creates a dedicated foundation crate so later shell tooling changes can share one implementation. ## What changed - Added the `codex-shell-escalation` crate and moved the core escalation pieces (`mcp` protocol/socket/session flow, policy glue) that were previously in `exec-server` into it. - Normalized `exec-server` Unix structure under a dedicated `unix` module layout and kept non-Unix builds narrow. - Wired crate/build metadata so `shell-escalation` is a first-class workspace dependency for follow-on integration work. ## Verification - Built and linted the stack at this commit point with `just clippy`. [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/openai/codex/pull/12556). * #12584 * #12583 * __->__ #12556
This commit is contained in:
@@ -1,11 +1,5 @@
|
||||
#[cfg(unix)]
|
||||
mod posix;
|
||||
mod unix;
|
||||
|
||||
#[cfg(unix)]
|
||||
pub use posix::main_execve_wrapper;
|
||||
|
||||
#[cfg(unix)]
|
||||
pub use posix::main_mcp_server;
|
||||
|
||||
#[cfg(unix)]
|
||||
pub use posix::ExecResult;
|
||||
pub use unix::*;
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
use std::io;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::os::fd::FromRawFd as _;
|
||||
use std::os::fd::OwnedFd;
|
||||
|
||||
use anyhow::Context as _;
|
||||
|
||||
use crate::posix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR;
|
||||
use crate::posix::escalate_protocol::EXEC_WRAPPER_ENV_VAR;
|
||||
use crate::posix::escalate_protocol::EscalateAction;
|
||||
use crate::posix::escalate_protocol::EscalateRequest;
|
||||
use crate::posix::escalate_protocol::EscalateResponse;
|
||||
use crate::posix::escalate_protocol::LEGACY_BASH_EXEC_WRAPPER_ENV_VAR;
|
||||
use crate::posix::escalate_protocol::SuperExecMessage;
|
||||
use crate::posix::escalate_protocol::SuperExecResult;
|
||||
use crate::posix::socket::AsyncDatagramSocket;
|
||||
use crate::posix::socket::AsyncSocket;
|
||||
|
||||
fn get_escalate_client() -> anyhow::Result<AsyncDatagramSocket> {
|
||||
// TODO: we should defensively require only calling this once, since AsyncSocket will take ownership of the fd.
|
||||
let client_fd = std::env::var(ESCALATE_SOCKET_ENV_VAR)?.parse::<i32>()?;
|
||||
if client_fd < 0 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"{ESCALATE_SOCKET_ENV_VAR} is not a valid file descriptor: {client_fd}"
|
||||
));
|
||||
}
|
||||
Ok(unsafe { AsyncDatagramSocket::from_raw_fd(client_fd) }?)
|
||||
}
|
||||
|
||||
/// Client side of the escalation flow, run in place of a wrapped exec() call.
///
/// Connects back to the escalation server over the inherited datagram socket,
/// sends an `EscalateRequest` describing the desired exec, and then — per the
/// server's `EscalateAction` — either execs the command locally, forwards it
/// to the server along with the stdio FDs, or prints a denial.
///
/// Returns the exit code this wrapper process should exit with (on the `Run`
/// path, execv replaces the process and this function only returns on error).
pub(crate) async fn run(file: String, argv: Vec<String>) -> anyhow::Result<i32> {
    let handshake_client = get_escalate_client()?;
    // Create a fresh stream socket pair for this session; the server end is
    // handed to the escalation server via SCM_RIGHTS on the datagram socket.
    let (server, client) = AsyncSocket::pair()?;
    const HANDSHAKE_MESSAGE: [u8; 1] = [0];
    handshake_client
        .send_with_fds(&HANDSHAKE_MESSAGE, &[server.into_inner().into()])
        .await
        .context("failed to send handshake datagram")?;
    // Forward our environment minus the wrapper-machinery variables, so the
    // escalated command does not recursively re-enter the wrapper.
    let env = std::env::vars()
        .filter(|(k, _)| {
            !matches!(
                k.as_str(),
                ESCALATE_SOCKET_ENV_VAR | EXEC_WRAPPER_ENV_VAR | LEGACY_BASH_EXEC_WRAPPER_ENV_VAR
            )
        })
        .collect();
    client
        .send(EscalateRequest {
            file: file.clone().into(),
            argv: argv.clone(),
            workdir: std::env::current_dir()?,
            env,
        })
        .await
        .context("failed to send EscalateRequest")?;
    let message = client
        .receive::<EscalateResponse>()
        .await
        .context("failed to receive EscalateResponse")?;
    match message.action {
        EscalateAction::Escalate => {
            // TODO: maybe we should send ALL open FDs (except the escalate client)?
            // NOTE(review): wrapping the raw stdio fds in OwnedFd means they
            // are closed in this process when `fds_to_send` drops at the end
            // of this arm; the wrapper only waits for the exit code after
            // that, so this looks intentional — confirm.
            let fds_to_send = [
                unsafe { OwnedFd::from_raw_fd(io::stdin().as_raw_fd()) },
                unsafe { OwnedFd::from_raw_fd(io::stdout().as_raw_fd()) },
                unsafe { OwnedFd::from_raw_fd(io::stderr().as_raw_fd()) },
            ];

            // TODO: also forward signals over the super-exec socket

            // The message body lists the fd numbers the server should dup2
            // onto; the same fds travel in the control message.
            client
                .send_with_fds(
                    SuperExecMessage {
                        fds: fds_to_send.iter().map(AsRawFd::as_raw_fd).collect(),
                    },
                    &fds_to_send,
                )
                .await
                .context("failed to send SuperExecMessage")?;
            // Block until the server reports the escalated command's exit code.
            let SuperExecResult { exit_code } = client.receive::<SuperExecResult>().await?;
            Ok(exit_code)
        }
        EscalateAction::Run => {
            // We avoid std::process::Command here because we want to be as transparent as
            // possible. std::os::unix::process::CommandExt has .exec() but it does some funky
            // stuff with signal masks and dup2() on its standard FDs, which we don't want.
            use std::ffi::CString;
            let file = CString::new(file).context("NUL in file")?;

            let argv_cstrs: Vec<CString> = argv
                .iter()
                .map(|s| CString::new(s.as_str()).context("NUL in argv"))
                .collect::<Result<Vec<_>, _>>()?;

            // Build the NULL-terminated pointer array execv() requires; the
            // pointers borrow from `argv_cstrs`, which outlives the call.
            let mut argv: Vec<*const libc::c_char> =
                argv_cstrs.iter().map(|s| s.as_ptr()).collect();
            argv.push(std::ptr::null());

            // SAFETY: `file` and every argv entry point at live NUL-terminated
            // CStrings and the array is NULL-terminated. On success execv never
            // returns; reaching last_os_error() means the call failed.
            let err = unsafe {
                libc::execv(file.as_ptr(), argv.as_ptr());
                std::io::Error::last_os_error()
            };

            Err(err.into())
        }
        EscalateAction::Deny { reason } => {
            match reason {
                Some(reason) => eprintln!("Execution denied: {reason}"),
                None => eprintln!("Execution denied"),
            }
            // Denials surface as a generic failure exit code.
            Ok(1)
        }
    }
}
|
||||
@@ -1,54 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::os::fd::RawFd;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
|
||||
/// 'exec-server escalate' reads this to find the inherited FD for the escalate socket.
pub(super) const ESCALATE_SOCKET_ENV_VAR: &str = "CODEX_ESCALATE_SOCKET";

/// Patched shells use this to wrap exec() calls.
pub(super) const EXEC_WRAPPER_ENV_VAR: &str = "EXEC_WRAPPER";

/// Compatibility alias for older patched bash builds.
pub(super) const LEGACY_BASH_EXEC_WRAPPER_ENV_VAR: &str = "BASH_EXEC_WRAPPER";

/// The client sends this to the server to request an exec() call.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
pub(super) struct EscalateRequest {
    /// The absolute path to the executable to run, i.e. the first arg to exec.
    pub(super) file: PathBuf,
    /// The argv, including the program name (argv[0]).
    pub(super) argv: Vec<String>,
    // Working directory the command should run in.
    pub(super) workdir: PathBuf,
    // Full environment for the command (wrapper variables already stripped by the client).
    pub(super) env: HashMap<String, String>,
}

/// The server sends this to the client to respond to an exec() request.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
pub(super) struct EscalateResponse {
    // What the client should do with the requested exec.
    pub(super) action: EscalateAction,
}

/// The server's verdict on an exec() request.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
pub(super) enum EscalateAction {
    /// The command should be run directly by the client.
    Run,
    /// The command should be escalated to the server for execution.
    Escalate,
    /// The command should not be executed.
    Deny { reason: Option<String> },
}

/// The client sends this to the server to forward its open FDs.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub(super) struct SuperExecMessage {
    // Raw fd numbers, positionally matching the fds in the SCM_RIGHTS control message.
    pub(super) fds: Vec<RawFd>,
}

/// The server responds when the exec()'d command has exited.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub(super) struct SuperExecResult {
    // Exit code of the escalated command.
    pub(super) exit_code: i32,
}
|
||||
@@ -1,336 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use path_absolutize::Absolutize as _;
|
||||
|
||||
use codex_core::SandboxState;
|
||||
use codex_core::exec::process_exec_tool_call;
|
||||
use codex_core::sandboxing::SandboxPermissions;
|
||||
use codex_protocol::config_types::WindowsSandboxLevel;
|
||||
use tokio::process::Command;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::posix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR;
|
||||
use crate::posix::escalate_protocol::EXEC_WRAPPER_ENV_VAR;
|
||||
use crate::posix::escalate_protocol::EscalateAction;
|
||||
use crate::posix::escalate_protocol::EscalateRequest;
|
||||
use crate::posix::escalate_protocol::EscalateResponse;
|
||||
use crate::posix::escalate_protocol::LEGACY_BASH_EXEC_WRAPPER_ENV_VAR;
|
||||
use crate::posix::escalate_protocol::SuperExecMessage;
|
||||
use crate::posix::escalate_protocol::SuperExecResult;
|
||||
use crate::posix::escalation_policy::EscalationPolicy;
|
||||
use crate::posix::mcp::ExecParams;
|
||||
use crate::posix::socket::AsyncDatagramSocket;
|
||||
use crate::posix::socket::AsyncSocket;
|
||||
use codex_core::exec::ExecExpiration;
|
||||
|
||||
/// Server side of the shell escalation flow: runs a command under the
/// sandboxed bash while answering escalation requests from wrapped exec()
/// calls inside it.
pub(crate) struct EscalateServer {
    // Path to the (patched) bash used to run the command.
    bash_path: PathBuf,
    // Path to the execve wrapper binary injected via EXEC_WRAPPER_ENV_VAR.
    execve_wrapper: PathBuf,
    // Decides Run/Escalate/Deny for each exec request.
    policy: Arc<dyn EscalationPolicy>,
}
||||
|
||||
impl EscalateServer {
    /// Creates a server with the given bash, wrapper binary, and policy.
    pub fn new<P>(bash_path: PathBuf, execve_wrapper: PathBuf, policy: P) -> Self
    where
        P: EscalationPolicy + Send + Sync + 'static,
    {
        Self {
            bash_path,
            execve_wrapper,
            policy: Arc::new(policy),
        }
    }

    /// Runs `params.command` under bash in the sandbox, servicing escalation
    /// requests on a side channel for the duration of the run.
    ///
    /// # Errors
    /// Propagates socket setup failures and any error from the underlying
    /// exec tool call.
    pub async fn exec(
        &self,
        params: ExecParams,
        cancel_rx: CancellationToken,
        sandbox_state: &SandboxState,
    ) -> anyhow::Result<ExecResult> {
        let (escalate_server, escalate_client) = AsyncDatagramSocket::pair()?;
        let client_socket = escalate_client.into_inner();
        // The client end must survive exec() so the child can inherit it.
        client_socket.set_cloexec(false)?;

        // Service escalation handshakes concurrently with the command run.
        let escalate_task = tokio::spawn(escalate_task(escalate_server, self.policy.clone()));
        let mut env = std::env::vars().collect::<HashMap<String, String>>();
        // Tell the child where the escalate socket is and which wrapper to use.
        env.insert(
            ESCALATE_SOCKET_ENV_VAR.to_string(),
            client_socket.as_raw_fd().to_string(),
        );
        env.insert(
            EXEC_WRAPPER_ENV_VAR.to_string(),
            self.execve_wrapper.to_string_lossy().to_string(),
        );
        // Older patched bash builds read the legacy name; keep both in sync.
        env.insert(
            LEGACY_BASH_EXEC_WRAPPER_ENV_VAR.to_string(),
            self.execve_wrapper.to_string_lossy().to_string(),
        );

        let ExecParams {
            command,
            workdir,
            // NOTE(review): timeout_ms is deliberately unused here — expiration
            // is driven solely by the cancellation token; confirm that is intended.
            timeout_ms: _,
            login,
        } = params;
        let result = process_exec_tool_call(
            codex_core::exec::ExecParams {
                command: vec![
                    self.bash_path.to_string_lossy().to_string(),
                    // Login shell by default; only `login == Some(false)` opts out.
                    if login == Some(false) {
                        "-c".to_string()
                    } else {
                        "-lc".to_string()
                    },
                    command,
                ],
                cwd: PathBuf::from(&workdir),
                expiration: ExecExpiration::Cancellation(cancel_rx),
                env,
                network: None,
                sandbox_permissions: SandboxPermissions::UseDefault,
                windows_sandbox_level: WindowsSandboxLevel::Disabled,
                justification: None,
                arg0: None,
            },
            &sandbox_state.sandbox_policy,
            &sandbox_state.sandbox_cwd,
            &sandbox_state.codex_linux_sandbox_exe,
            sandbox_state.use_linux_sandbox_bwrap,
            None,
        )
        .await?;
        // No more escalation requests can arrive once the command has exited.
        escalate_task.abort();
        let result = ExecResult {
            exit_code: result.exit_code,
            output: result.aggregated_output.text,
            duration: result.duration,
            timed_out: result.timed_out,
        };
        Ok(result)
    }
}
|
||||
|
||||
async fn escalate_task(
|
||||
socket: AsyncDatagramSocket,
|
||||
policy: Arc<dyn EscalationPolicy>,
|
||||
) -> anyhow::Result<()> {
|
||||
loop {
|
||||
let (_, mut fds) = socket.receive_with_fds().await?;
|
||||
if fds.len() != 1 {
|
||||
tracing::error!("expected 1 fd in datagram handshake, got {}", fds.len());
|
||||
continue;
|
||||
}
|
||||
let stream_socket = AsyncSocket::from_fd(fds.remove(0))?;
|
||||
let policy = policy.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = handle_escalate_session_with_policy(stream_socket, policy).await {
|
||||
tracing::error!("escalate session failed: {err:?}");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Outcome of a sandboxed command run, as reported back to the MCP caller.
#[derive(Debug)]
pub(crate) struct ExecResult {
    // Process exit code.
    pub(crate) exit_code: i32,
    // Aggregated stdout/stderr text of the run.
    pub(crate) output: String,
    // Wall-clock duration of the run.
    pub(crate) duration: Duration,
    // Whether the run was cut short by a timeout.
    pub(crate) timed_out: bool,
}
|
||||
|
||||
/// Runs one escalation session over `socket`: receives the exec request,
/// asks `policy` for a verdict, and either tells the client to run it,
/// executes it here (forwarding the client's FDs), or denies it.
///
/// # Errors
/// Fails on protocol violations (bad FD counts, overlapping FDs), socket
/// errors, policy errors, or spawn failures.
async fn handle_escalate_session_with_policy(
    socket: AsyncSocket,
    policy: Arc<dyn EscalationPolicy>,
) -> anyhow::Result<()> {
    let EscalateRequest {
        file,
        argv,
        workdir,
        env,
    } = socket.receive::<EscalateRequest>().await?;
    // Normalize paths before handing them to the policy.
    let file = PathBuf::from(&file).absolutize()?.into_owned();
    let workdir = PathBuf::from(&workdir).absolutize()?.into_owned();
    let action = policy
        .determine_action(file.as_path(), &argv, &workdir)
        .await?;

    tracing::debug!("decided {action:?} for {file:?} {argv:?} {workdir:?}");

    match action {
        EscalateAction::Run => {
            // Client runs the command itself, inside the sandbox.
            socket
                .send(EscalateResponse {
                    action: EscalateAction::Run,
                })
                .await?;
        }
        EscalateAction::Escalate => {
            socket
                .send(EscalateResponse {
                    action: EscalateAction::Escalate,
                })
                .await?;
            // The client follows up with its FDs: the message lists target fd
            // numbers, the control message carries the actual descriptors.
            let (msg, fds) = socket
                .receive_with_fds::<SuperExecMessage>()
                .await
                .context("failed to receive SuperExecMessage")?;
            if fds.len() != msg.fds.len() {
                return Err(anyhow::anyhow!(
                    "mismatched number of fds in SuperExecMessage: {} in the message, {} from the control message",
                    msg.fds.len(),
                    fds.len()
                ));
            }

            // A received fd that equals one of the requested target numbers
            // would be clobbered mid-way through the dup2 loop below.
            if msg
                .fds
                .iter()
                .any(|src_fd| fds.iter().any(|dst_fd| dst_fd.as_raw_fd() == *src_fd))
            {
                return Err(anyhow::anyhow!(
                    "overlapping fds not yet supported in SuperExecMessage"
                ));
            }

            // NOTE(review): `argv[0]` / `&argv[1..]` panic if argv is empty —
            // presumably the client always sends at least the program name;
            // confirm or guard upstream.
            let mut command = Command::new(file);
            command
                .args(&argv[1..])
                .arg0(argv[0].clone())
                .envs(&env)
                .current_dir(&workdir)
                // Stdio is nulled here; the client's real FDs are installed in pre_exec.
                .stdin(Stdio::null())
                .stdout(Stdio::null())
                .stderr(Stdio::null());
            // SAFETY: pre_exec runs post-fork/pre-exec; dup2 is async-signal-safe
            // and the captured fds stay open until exec.
            unsafe {
                command.pre_exec(move || {
                    // Install each forwarded fd at the number the client used.
                    for (dst_fd, src_fd) in msg.fds.iter().zip(&fds) {
                        libc::dup2(src_fd.as_raw_fd(), *dst_fd);
                    }
                    Ok(())
                });
            }
            let mut child = command.spawn()?;
            let exit_status = child.wait().await?;
            socket
                .send(SuperExecResult {
                    // 127 stands in when the child died to a signal (no exit code).
                    exit_code: exit_status.code().unwrap_or(127),
                })
                .await?;
        }
        EscalateAction::Deny { reason } => {
            socket
                .send(EscalateResponse {
                    action: EscalateAction::Deny { reason },
                })
                .await?;
        }
    }
    Ok(())
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::collections::HashMap;
    use std::path::Path;
    use std::path::PathBuf;

    // Policy stub that always returns a fixed, pre-configured action.
    struct DeterministicEscalationPolicy {
        action: EscalateAction,
    }

    #[async_trait::async_trait]
    impl EscalationPolicy for DeterministicEscalationPolicy {
        async fn determine_action(
            &self,
            _file: &Path,
            _argv: &[String],
            _workdir: &Path,
        ) -> Result<EscalateAction, rmcp::ErrorData> {
            Ok(self.action.clone())
        }
    }

    // A `Run` verdict must be relayed verbatim to the client, even when the
    // request payload is large (~10 KB of env) and spans multiple frames.
    #[tokio::test]
    async fn handle_escalate_session_respects_run_in_sandbox_decision() -> anyhow::Result<()> {
        let (server, client) = AsyncSocket::pair()?;
        let server_task = tokio::spawn(handle_escalate_session_with_policy(
            server,
            Arc::new(DeterministicEscalationPolicy {
                action: EscalateAction::Run,
            }),
        ));

        // Pad the request so it exceeds a single small read.
        let mut env = HashMap::new();
        for i in 0..10 {
            let value = "A".repeat(1024);
            env.insert(format!("CODEX_TEST_VAR{i}"), value);
        }

        client
            .send(EscalateRequest {
                file: PathBuf::from("/bin/echo"),
                argv: vec!["echo".to_string()],
                workdir: PathBuf::from("/tmp"),
                env,
            })
            .await?;

        let response = client.receive::<EscalateResponse>().await?;
        assert_eq!(
            EscalateResponse {
                action: EscalateAction::Run,
            },
            response
        );
        // Propagate any error from the session handler itself.
        server_task.await?
    }

    // An `Escalate` verdict must cause the server to actually run the command
    // with the forwarded env and report its real exit code.
    #[tokio::test]
    async fn handle_escalate_session_executes_escalated_command() -> anyhow::Result<()> {
        let (server, client) = AsyncSocket::pair()?;
        let server_task = tokio::spawn(handle_escalate_session_with_policy(
            server,
            Arc::new(DeterministicEscalationPolicy {
                action: EscalateAction::Escalate,
            }),
        ));

        client
            .send(EscalateRequest {
                file: PathBuf::from("/bin/sh"),
                argv: vec![
                    "sh".to_string(),
                    "-c".to_string(),
                    // Exits 42 only if the forwarded env var made it through.
                    r#"if [ "$KEY" = VALUE ]; then exit 42; else exit 1; fi"#.to_string(),
                ],
                workdir: std::env::current_dir()?,
                env: HashMap::from([("KEY".to_string(), "VALUE".to_string())]),
            })
            .await?;

        let response = client.receive::<EscalateResponse>().await?;
        assert_eq!(
            EscalateResponse {
                action: EscalateAction::Escalate,
            },
            response
        );

        // No FDs forwarded: stdio stays nulled on the server side.
        client
            .send_with_fds(SuperExecMessage { fds: Vec::new() }, &[])
            .await?;

        let result = client.receive::<SuperExecResult>().await?;
        assert_eq!(42, result.exit_code);

        server_task.await?
    }
}
|
||||
@@ -1,14 +0,0 @@
|
||||
use std::path::Path;
|
||||
|
||||
use crate::posix::escalate_protocol::EscalateAction;
|
||||
|
||||
/// Decides what action to take in response to an execve request from a client.
#[async_trait::async_trait]
pub(crate) trait EscalationPolicy: Send + Sync {
    /// Returns the verdict for executing `file` with `argv` in `workdir`.
    ///
    /// `file` and `workdir` are absolutized by the caller before this is
    /// invoked; `argv` includes the program name (argv[0]).
    async fn determine_action(
        &self,
        file: &Path,
        argv: &[String],
        workdir: &Path,
    ) -> Result<EscalateAction, rmcp::ErrorData>;
}
|
||||
@@ -1,507 +0,0 @@
|
||||
use libc::c_uint;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use socket2::Domain;
|
||||
use socket2::MaybeUninitSlice;
|
||||
use socket2::MsgHdr;
|
||||
use socket2::MsgHdrMut;
|
||||
use socket2::Socket;
|
||||
use socket2::Type;
|
||||
use std::io::IoSlice;
|
||||
use std::mem::MaybeUninit;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::os::fd::FromRawFd;
|
||||
use std::os::fd::OwnedFd;
|
||||
use std::os::fd::RawFd;
|
||||
use tokio::io::Interest;
|
||||
use tokio::io::unix::AsyncFd;
|
||||
|
||||
// Upper bound on FDs carried in a single SCM_RIGHTS control message.
const MAX_FDS_PER_MESSAGE: usize = 16;
// Stream frames are prefixed with a little-endian u32 payload length.
const LENGTH_PREFIX_SIZE: usize = size_of::<u32>();
// Largest datagram payload this module will receive in one recvmsg.
const MAX_DATAGRAM_SIZE: usize = 8192;
||||
|
||||
/// Converts a slice of MaybeUninit<T> to a slice of T.
///
/// The caller guarantees that every element of `buf` is initialized.
fn assume_init<T>(buf: &[MaybeUninit<T>]) -> &[T] {
    // SAFETY: `MaybeUninit<T>` has the same layout as `T`, and the caller
    // guarantees every element has been initialized.
    unsafe { std::slice::from_raw_parts(buf.as_ptr().cast(), buf.len()) }
}
||||
|
||||
/// Array variant of `assume_init`: the caller guarantees every element of
/// `buf` is initialized.
fn assume_init_slice<T, const N: usize>(buf: &[MaybeUninit<T>; N]) -> &[T; N] {
    // SAFETY: `[MaybeUninit<T>; N]` and `[T; N]` have identical layout; the
    // caller guarantees all N elements are initialized.
    unsafe { &*(buf as *const [MaybeUninit<T>; N] as *const [T; N]) }
}
|
||||
|
||||
/// Vec variant of `assume_init`: reinterprets a fully-initialized
/// `Vec<MaybeUninit<T>>` as `Vec<T>` without copying.
///
/// The caller guarantees that every element of `buf` is initialized.
fn assume_init_vec<T>(mut buf: Vec<MaybeUninit<T>>) -> Vec<T> {
    let (ptr, len, cap) = (buf.as_mut_ptr() as *mut T, buf.len(), buf.capacity());
    // Relinquish ownership of the allocation before rebuilding it below.
    std::mem::forget(buf);
    // SAFETY: ptr/len/cap come from a live Vec whose element type has the
    // same layout as T, and the caller guarantees every element is initialized.
    unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
|
||||
|
||||
/// Returns the control-message buffer size (CMSG_SPACE) needed to carry
/// `count` file descriptors.
fn control_space_for_fds(count: usize) -> usize {
    // SAFETY: CMSG_SPACE is pure arithmetic on its argument; no memory is touched.
    unsafe { libc::CMSG_SPACE((count * size_of::<RawFd>()) as _) as usize }
}
|
||||
|
||||
/// Extracts the FDs from a SCM_RIGHTS control message.
fn extract_fds(control: &[u8]) -> Vec<OwnedFd> {
    let mut fds = Vec::new();
    // Build a minimal msghdr pointing at the control buffer so the libc
    // CMSG_* macros can walk it.
    let mut hdr: libc::msghdr = unsafe { std::mem::zeroed() };
    hdr.msg_control = control.as_ptr() as *mut libc::c_void;
    hdr.msg_controllen = control.len() as _;
    let hdr = hdr; // drop mut

    let mut cmsg = unsafe { libc::CMSG_FIRSTHDR(&hdr) as *const libc::cmsghdr };
    while !cmsg.is_null() {
        let level = unsafe { (*cmsg).cmsg_level };
        let ty = unsafe { (*cmsg).cmsg_type };
        // Only SCM_RIGHTS messages carry fds; skip anything else.
        if level == libc::SOL_SOCKET && ty == libc::SCM_RIGHTS {
            let data_ptr = unsafe { libc::CMSG_DATA(cmsg).cast::<RawFd>() };
            // cmsg_len includes the header; subtract CMSG_LEN(0) to get the
            // data length, then divide by fd size.
            let fd_count: usize = {
                let cmsg_data_len =
                    unsafe { (*cmsg).cmsg_len as usize } - unsafe { libc::CMSG_LEN(0) as usize };
                cmsg_data_len / size_of::<RawFd>()
            };
            for i in 0..fd_count {
                let fd = unsafe { data_ptr.add(i).read() };
                // The kernel gave us ownership of each received fd; OwnedFd
                // ensures it is closed if the caller drops it.
                fds.push(unsafe { OwnedFd::from_raw_fd(fd) });
            }
        }
        cmsg = unsafe { libc::CMSG_NXTHDR(&hdr, cmsg) };
    }
    fds
}
|
||||
|
||||
/// Read a frame from a SOCK_STREAM socket.
///
/// A frame is a message length prefix followed by a payload. FDs may be included in the control
/// message when receiving the frame header.
async fn read_frame(async_socket: &AsyncFd<Socket>) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
    // Header first (which also captures any SCM_RIGHTS fds), then the payload.
    let (message_len, fds) = read_frame_header(async_socket).await?;
    let payload = read_frame_payload(async_socket, message_len).await?;
    Ok((payload, fds))
}
|
||||
|
||||
/// Read the frame header (i.e. length) and any FDs from a SOCK_STREAM socket.
///
/// Returns the payload length and whatever fds arrived with the first read
/// of the header bytes.
async fn read_frame_header(
    async_socket: &AsyncFd<Socket>,
) -> std::io::Result<(usize, Vec<OwnedFd>)> {
    let mut header = [MaybeUninit::<u8>::uninit(); LENGTH_PREFIX_SIZE];
    // Number of header bytes received so far.
    let mut filled = 0;
    let mut control = vec![MaybeUninit::<u8>::uninit(); control_space_for_fds(MAX_FDS_PER_MESSAGE)];
    // Whether we've already done the recvmsg that captures the control message.
    let mut captured_control = false;

    while filled < LENGTH_PREFIX_SIZE {
        let mut guard = async_socket.readable().await?;
        // The first read should come with a control message containing any FDs.
        let result = if !captured_control {
            guard.try_io(|inner| {
                let mut bufs = [MaybeUninitSlice::new(&mut header[filled..])];
                let (read, control_len) = {
                    let mut msg = MsgHdrMut::new()
                        .with_buffers(&mut bufs)
                        .with_control(&mut control);
                    let read = inner.get_ref().recvmsg(&mut msg, 0)?;
                    (read, msg.control_len())
                };
                // Trim the control buffer to what the kernel actually wrote.
                control.truncate(control_len);
                captured_control = true;
                Ok(read)
            })
        } else {
            // Subsequent reads only need the remaining header bytes.
            guard.try_io(|inner| inner.get_ref().recv(&mut header[filled..]))
        };
        let Ok(result) = result else {
            // Would block, try again.
            continue;
        };

        let read = result?;
        if read == 0 {
            // Peer closed mid-header.
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "socket closed while receiving frame header",
            ));
        }

        filled += read;
        assert!(filled <= LENGTH_PREFIX_SIZE);
        if filled == LENGTH_PREFIX_SIZE {
            // All header bytes are initialized now; decode the LE length and
            // pull fds out of the captured control message.
            let len_bytes = assume_init_slice(&header);
            let payload_len = u32::from_le_bytes(*len_bytes) as usize;
            let fds = extract_fds(assume_init(&control));
            return Ok((payload_len, fds));
        }
    }
    unreachable!("header loop always returns")
}
|
||||
|
||||
/// Read `message_len` bytes from a SOCK_STREAM socket.
async fn read_frame_payload(
    async_socket: &AsyncFd<Socket>,
    message_len: usize,
) -> std::io::Result<Vec<u8>> {
    // Zero-length frames need no reads at all.
    if message_len == 0 {
        return Ok(Vec::new());
    }
    let mut payload = vec![MaybeUninit::<u8>::uninit(); message_len];
    // Number of payload bytes received so far.
    let mut filled = 0;
    while filled < message_len {
        let mut guard = async_socket.readable().await?;
        let result = guard.try_io(|inner| inner.get_ref().recv(&mut payload[filled..]));
        let Ok(result) = result else {
            // Would block, try again.
            continue;
        };
        let read = result?;
        if read == 0 {
            // Peer closed before delivering the whole payload.
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "socket closed while receiving frame payload",
            ));
        }
        filled += read;
        assert!(filled <= message_len);
        if filled == message_len {
            // Every byte of `payload` has been written by recv at this point.
            return Ok(assume_init_vec(payload));
        }
    }
    unreachable!("loop exits only after returning payload")
}
|
||||
|
||||
/// Sends `data` (plus `fds` via SCM_RIGHTS) as a single datagram on `socket`.
///
/// Fails with `WriteZero` if the kernel accepted fewer bytes than `data.len()`,
/// since datagrams must go out whole.
fn send_datagram_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> {
    let control = make_control_message(fds)?;
    let payload = [IoSlice::new(data)];
    // Only attach a control message when there are fds to carry.
    let msg = if control.is_empty() {
        MsgHdr::new().with_buffers(&payload)
    } else {
        MsgHdr::new().with_buffers(&payload).with_control(&control)
    };
    let written = socket.sendmsg(&msg, 0)?;
    if written != data.len() {
        return Err(std::io::Error::new(
            std::io::ErrorKind::WriteZero,
            format!(
                "short datagram write: wrote {written} bytes out of {}",
                data.len()
            ),
        ));
    }
    Ok(())
}
|
||||
|
||||
fn encode_length(len: usize) -> std::io::Result<[u8; LENGTH_PREFIX_SIZE]> {
|
||||
let len_u32 = u32::try_from(len).map_err(|_| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidInput,
|
||||
format!("message too large: {len}"),
|
||||
)
|
||||
})?;
|
||||
Ok(len_u32.to_le_bytes())
|
||||
}
|
||||
|
||||
/// Builds an SCM_RIGHTS control-message buffer carrying `fds`.
///
/// Returns an empty buffer for no fds, and `InvalidInput` when more than
/// `MAX_FDS_PER_MESSAGE` are requested.
fn make_control_message(fds: &[OwnedFd]) -> std::io::Result<Vec<u8>> {
    if fds.len() > MAX_FDS_PER_MESSAGE {
        Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("too many fds: {}", fds.len()),
        ))
    } else if fds.is_empty() {
        Ok(Vec::new())
    } else {
        // CMSG_SPACE-sized, zeroed buffer; the header is written in place below.
        let mut control = vec![0u8; control_space_for_fds(fds.len())];
        // SAFETY: `control` is large enough (CMSG_SPACE for fds.len() fds) and
        // the CMSG_* macros are used exactly as cmsg(3) prescribes.
        unsafe {
            let cmsg = control.as_mut_ptr().cast::<libc::cmsghdr>();
            (*cmsg).cmsg_len =
                libc::CMSG_LEN(size_of::<RawFd>() as c_uint * fds.len() as c_uint) as _;
            (*cmsg).cmsg_level = libc::SOL_SOCKET;
            (*cmsg).cmsg_type = libc::SCM_RIGHTS;
            let data_ptr = libc::CMSG_DATA(cmsg).cast::<RawFd>();
            // Copy the raw fd numbers into the cmsg data area; ownership of
            // the fds stays with the caller.
            for (i, fd) in fds.iter().enumerate() {
                data_ptr.add(i).write(fd.as_raw_fd());
            }
        }
        Ok(control)
    }
}
|
||||
|
||||
/// Receives one datagram (payload plus any SCM_RIGHTS fds) from `socket`.
///
/// Payloads larger than `MAX_DATAGRAM_SIZE` are truncated by the kernel;
/// only the received bytes are returned.
fn receive_datagram_bytes(socket: &Socket) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
    let mut buffer = vec![MaybeUninit::<u8>::uninit(); MAX_DATAGRAM_SIZE];
    let mut control = vec![MaybeUninit::<u8>::uninit(); control_space_for_fds(MAX_FDS_PER_MESSAGE)];
    let (read, control_len) = {
        let mut bufs = [MaybeUninitSlice::new(&mut buffer)];
        let mut msg = MsgHdrMut::new()
            .with_buffers(&mut bufs)
            .with_control(&mut control);
        let read = socket.recvmsg(&mut msg, 0)?;
        (read, msg.control_len())
    };
    // Only the first `read` / `control_len` bytes were initialized by the kernel.
    let data = assume_init(&buffer[..read]).to_vec();
    let fds = extract_fds(assume_init(&control[..control_len]));
    Ok((data, fds))
}
|
||||
|
||||
/// Async wrapper over a Unix SOCK_STREAM socket that exchanges length-prefixed
/// JSON frames, optionally carrying file descriptors via SCM_RIGHTS.
pub(crate) struct AsyncSocket {
    // Non-blocking socket registered with tokio's reactor.
    inner: AsyncFd<Socket>,
}
|
||||
|
||||
impl AsyncSocket {
    /// Wraps `socket`, switching it to non-blocking mode and registering it
    /// with tokio's reactor.
    fn new(socket: Socket) -> std::io::Result<AsyncSocket> {
        socket.set_nonblocking(true)?;
        let async_socket = AsyncFd::new(socket)?;
        Ok(AsyncSocket {
            inner: async_socket,
        })
    }

    /// Wraps an already-owned fd (e.g. one received via SCM_RIGHTS).
    pub fn from_fd(fd: OwnedFd) -> std::io::Result<AsyncSocket> {
        AsyncSocket::new(Socket::from(fd))
    }

    /// Creates a connected pair of stream sockets.
    pub fn pair() -> std::io::Result<(AsyncSocket, AsyncSocket)> {
        let (server, client) = Socket::pair(Domain::UNIX, Type::STREAM, None)?;
        Ok((AsyncSocket::new(server)?, AsyncSocket::new(client)?))
    }

    /// Serializes `msg` as JSON and sends it as one length-prefixed frame,
    /// attaching `fds` to the first write of the frame.
    pub async fn send_with_fds<T: Serialize>(
        &self,
        msg: T,
        fds: &[OwnedFd],
    ) -> std::io::Result<()> {
        let payload = serde_json::to_vec(&msg)?;
        let mut frame = Vec::with_capacity(LENGTH_PREFIX_SIZE + payload.len());
        frame.extend_from_slice(&encode_length(payload.len())?);
        frame.extend_from_slice(&payload);
        send_stream_frame(&self.inner, &frame, fds).await
    }

    /// Receives one frame, deserializing the payload as JSON and returning
    /// any fds that arrived with it.
    pub async fn receive_with_fds<T: for<'de> Deserialize<'de>>(
        &self,
    ) -> std::io::Result<(T, Vec<OwnedFd>)> {
        let (payload, fds) = read_frame(&self.inner).await?;
        let message: T = serde_json::from_slice(&payload)?;
        Ok((message, fds))
    }

    /// Sends `msg` as a frame with no attached fds.
    pub async fn send<T>(&self, msg: T) -> std::io::Result<()>
    where
        T: Serialize,
    {
        self.send_with_fds(&msg, &[]).await
    }

    /// Receives a frame, warning (not failing) if unexpected fds arrive;
    /// the fds are dropped, which closes them.
    pub async fn receive<T: for<'de> Deserialize<'de>>(&self) -> std::io::Result<T> {
        let (msg, fds) = self.receive_with_fds().await?;
        if !fds.is_empty() {
            tracing::warn!("unexpected fds in receive: {}", fds.len());
        }
        Ok(msg)
    }

    /// Unwraps back to the raw `Socket`, deregistering it from the reactor.
    pub fn into_inner(self) -> Socket {
        self.inner.into_inner()
    }
}
|
||||
|
||||
/// Writes an entire frame to a non-blocking stream socket, retrying on
/// would-block and attaching `fds` only to the first successful write.
async fn send_stream_frame(
    socket: &AsyncFd<Socket>,
    frame: &[u8],
    fds: &[OwnedFd],
) -> std::io::Result<()> {
    // Bytes of `frame` accepted by the kernel so far.
    let mut written = 0;
    // SCM_RIGHTS must ride on exactly one sendmsg; cleared after first success.
    let mut include_fds = !fds.is_empty();
    while written < frame.len() {
        let mut guard = socket.writable().await?;
        let result = guard.try_io(|inner| {
            send_stream_chunk(inner.get_ref(), &frame[written..], fds, include_fds)
        });
        let bytes_written = match result {
            Ok(bytes_written) => bytes_written?,
            // Spurious readiness; wait for writability again.
            Err(_would_block) => continue,
        };
        if bytes_written == 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::WriteZero,
                "socket closed while sending frame payload",
            ));
        }
        written += bytes_written;
        // The fds went out with the write above; never resend them.
        include_fds = false;
    }
    Ok(())
}
|
||||
|
||||
fn send_stream_chunk(
|
||||
socket: &Socket,
|
||||
frame: &[u8],
|
||||
fds: &[OwnedFd],
|
||||
include_fds: bool,
|
||||
) -> std::io::Result<usize> {
|
||||
let control = if include_fds {
|
||||
make_control_message(fds)?
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
let payload = [IoSlice::new(frame)];
|
||||
let msg = if control.is_empty() {
|
||||
MsgHdr::new().with_buffers(&payload)
|
||||
} else {
|
||||
MsgHdr::new().with_buffers(&payload).with_control(&control)
|
||||
};
|
||||
socket.sendmsg(&msg, 0)
|
||||
}
|
||||
|
||||
/// Async wrapper over a Unix SOCK_DGRAM socket used for the escalation
/// handshake; each datagram may carry file descriptors via SCM_RIGHTS.
pub(crate) struct AsyncDatagramSocket {
    // Non-blocking socket registered with tokio's reactor.
    inner: AsyncFd<Socket>,
}
|
||||
|
||||
impl AsyncDatagramSocket {
    /// Wraps `socket`, switching it to non-blocking mode and registering it
    /// with tokio's reactor.
    fn new(socket: Socket) -> std::io::Result<Self> {
        socket.set_nonblocking(true)?;
        Ok(Self {
            inner: AsyncFd::new(socket)?,
        })
    }

    /// Wraps a raw fd (e.g. one inherited through the environment).
    ///
    /// # Safety
    /// `fd` must be a valid, open datagram socket that no other owner will
    /// close or reuse; this takes ownership of it.
    pub unsafe fn from_raw_fd(fd: RawFd) -> std::io::Result<Self> {
        Self::new(unsafe { Socket::from_raw_fd(fd) })
    }

    /// Creates a connected pair of datagram sockets.
    pub fn pair() -> std::io::Result<(Self, Self)> {
        let (server, client) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?;
        Ok((Self::new(server)?, Self::new(client)?))
    }

    /// Sends one datagram containing `data`, with `fds` attached via SCM_RIGHTS.
    pub async fn send_with_fds(&self, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> {
        self.inner
            .async_io(Interest::WRITABLE, |socket| {
                send_datagram_bytes(socket, data, fds)
            })
            .await
    }

    /// Receives one datagram, returning its payload and any attached fds.
    pub async fn receive_with_fds(&self) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
        self.inner
            .async_io(Interest::READABLE, receive_datagram_bytes)
            .await
    }

    /// Unwraps back to the raw `Socket`, deregistering it from the reactor.
    pub fn into_inner(self) -> Socket {
        self.inner.into_inner()
    }
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::os::fd::AsFd;
|
||||
use std::os::fd::AsRawFd;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
|
||||
struct TestPayload {
|
||||
id: i32,
|
||||
label: String,
|
||||
}
|
||||
|
||||
fn fd_list(count: usize) -> std::io::Result<Vec<OwnedFd>> {
|
||||
let file = NamedTempFile::new()?;
|
||||
let mut fds = Vec::new();
|
||||
for _ in 0..count {
|
||||
fds.push(file.as_fd().try_clone_to_owned()?);
|
||||
}
|
||||
Ok(fds)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn async_socket_round_trips_payload_and_fds() -> std::io::Result<()> {
|
||||
let (server, client) = AsyncSocket::pair()?;
|
||||
let payload = TestPayload {
|
||||
id: 7,
|
||||
label: "round-trip".to_string(),
|
||||
};
|
||||
let send_fds = fd_list(1)?;
|
||||
|
||||
let receive_task =
|
||||
tokio::spawn(async move { server.receive_with_fds::<TestPayload>().await });
|
||||
client.send_with_fds(payload.clone(), &send_fds).await?;
|
||||
drop(send_fds);
|
||||
|
||||
let (received_payload, received_fds) = receive_task.await.unwrap()?;
|
||||
assert_eq!(payload, received_payload);
|
||||
assert_eq!(1, received_fds.len());
|
||||
let fd_status = unsafe { libc::fcntl(received_fds[0].as_raw_fd(), libc::F_GETFD) };
|
||||
assert!(
|
||||
fd_status >= 0,
|
||||
"expected received file descriptor to be valid, but got {fd_status}",
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn async_socket_handles_large_payload() -> std::io::Result<()> {
|
||||
let (server, client) = AsyncSocket::pair()?;
|
||||
let payload = vec![b'A'; 10_000];
|
||||
let receive_task = tokio::spawn(async move { server.receive::<Vec<u8>>().await });
|
||||
client.send(payload.clone()).await?;
|
||||
let received_payload = receive_task.await.unwrap()?;
|
||||
assert_eq!(payload, received_payload);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn async_datagram_sockets_round_trip_messages() -> std::io::Result<()> {
|
||||
let (server, client) = AsyncDatagramSocket::pair()?;
|
||||
let data = b"datagram payload".to_vec();
|
||||
let send_fds = fd_list(1)?;
|
||||
let receive_task = tokio::spawn(async move { server.receive_with_fds().await });
|
||||
|
||||
client.send_with_fds(&data, &send_fds).await?;
|
||||
drop(send_fds);
|
||||
|
||||
let (received_bytes, received_fds) = receive_task.await.unwrap()?;
|
||||
assert_eq!(data, received_bytes);
|
||||
assert_eq!(1, received_fds.len());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn send_datagram_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> {
|
||||
let (socket, _peer) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?;
|
||||
let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?;
|
||||
let err = send_datagram_bytes(&socket, b"hi", &fds).unwrap_err();
|
||||
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn send_stream_chunk_rejects_excessive_fd_counts() -> std::io::Result<()> {
|
||||
let (socket, _peer) = Socket::pair(Domain::UNIX, Type::STREAM, None)?;
|
||||
let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?;
|
||||
let err = send_stream_chunk(&socket, b"hello", &fds, true).unwrap_err();
|
||||
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_length_errors_for_oversized_messages() {
|
||||
let err = encode_length(usize::MAX).unwrap_err();
|
||||
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn receive_fails_when_peer_closes_before_header() {
|
||||
let (server, client) = AsyncSocket::pair().expect("failed to create socket pair");
|
||||
drop(client);
|
||||
let err = server
|
||||
.receive::<serde_json::Value>()
|
||||
.await
|
||||
.expect_err("expected read failure");
|
||||
assert_eq!(std::io::ErrorKind::UnexpectedEof, err.kind());
|
||||
}
|
||||
}
|
||||
@@ -1,211 +0,0 @@
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::Notify;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct Stopwatch {
|
||||
limit: Duration,
|
||||
inner: Arc<Mutex<StopwatchState>>,
|
||||
notify: Arc<Notify>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct StopwatchState {
|
||||
elapsed: Duration,
|
||||
running_since: Option<Instant>,
|
||||
active_pauses: u32,
|
||||
}
|
||||
|
||||
impl Stopwatch {
|
||||
pub(crate) fn new(limit: Duration) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(Mutex::new(StopwatchState {
|
||||
elapsed: Duration::ZERO,
|
||||
running_since: Some(Instant::now()),
|
||||
active_pauses: 0,
|
||||
})),
|
||||
notify: Arc::new(Notify::new()),
|
||||
limit,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn cancellation_token(&self) -> CancellationToken {
|
||||
let limit = self.limit;
|
||||
let token = CancellationToken::new();
|
||||
let cancel = token.clone();
|
||||
let inner = Arc::clone(&self.inner);
|
||||
let notify = Arc::clone(&self.notify);
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (remaining, running) = {
|
||||
let guard = inner.lock().await;
|
||||
let elapsed = guard.elapsed
|
||||
+ guard
|
||||
.running_since
|
||||
.map(|since| since.elapsed())
|
||||
.unwrap_or_default();
|
||||
if elapsed >= limit {
|
||||
break;
|
||||
}
|
||||
(limit - elapsed, guard.running_since.is_some())
|
||||
};
|
||||
|
||||
if !running {
|
||||
notify.notified().await;
|
||||
continue;
|
||||
}
|
||||
|
||||
let sleep = tokio::time::sleep(remaining);
|
||||
tokio::pin!(sleep);
|
||||
tokio::select! {
|
||||
_ = &mut sleep => {
|
||||
break;
|
||||
}
|
||||
_ = notify.notified() => {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
cancel.cancel();
|
||||
});
|
||||
token
|
||||
}
|
||||
|
||||
/// Runs `fut`, pausing the stopwatch while the future is pending. The clock
|
||||
/// resumes automatically when the future completes. Nested/overlapping
|
||||
/// calls are reference-counted so the stopwatch only resumes when every
|
||||
/// pause is lifted.
|
||||
pub(crate) async fn pause_for<F, T>(&self, fut: F) -> T
|
||||
where
|
||||
F: Future<Output = T>,
|
||||
{
|
||||
self.pause().await;
|
||||
let result = fut.await;
|
||||
self.resume().await;
|
||||
result
|
||||
}
|
||||
|
||||
async fn pause(&self) {
|
||||
let mut guard = self.inner.lock().await;
|
||||
guard.active_pauses += 1;
|
||||
if guard.active_pauses == 1
|
||||
&& let Some(since) = guard.running_since.take()
|
||||
{
|
||||
guard.elapsed += since.elapsed();
|
||||
self.notify.notify_waiters();
|
||||
}
|
||||
}
|
||||
|
||||
async fn resume(&self) {
|
||||
let mut guard = self.inner.lock().await;
|
||||
if guard.active_pauses == 0 {
|
||||
return;
|
||||
}
|
||||
guard.active_pauses -= 1;
|
||||
if guard.active_pauses == 0 && guard.running_since.is_none() {
|
||||
guard.running_since = Some(Instant::now());
|
||||
self.notify.notify_waiters();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::Stopwatch;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::Instant;
|
||||
use tokio::time::sleep;
|
||||
use tokio::time::timeout;
|
||||
|
||||
#[tokio::test]
|
||||
async fn cancellation_receiver_fires_after_limit() {
|
||||
let stopwatch = Stopwatch::new(Duration::from_millis(50));
|
||||
let token = stopwatch.cancellation_token();
|
||||
let start = Instant::now();
|
||||
token.cancelled().await;
|
||||
assert!(start.elapsed() >= Duration::from_millis(50));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pause_prevents_timeout_until_resumed() {
|
||||
let stopwatch = Stopwatch::new(Duration::from_millis(50));
|
||||
let token = stopwatch.cancellation_token();
|
||||
|
||||
let pause_handle = tokio::spawn({
|
||||
let stopwatch = stopwatch.clone();
|
||||
async move {
|
||||
stopwatch
|
||||
.pause_for(async {
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
})
|
||||
.await;
|
||||
}
|
||||
});
|
||||
|
||||
assert!(
|
||||
timeout(Duration::from_millis(30), token.cancelled())
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
|
||||
pause_handle.await.expect("pause task should finish");
|
||||
|
||||
token.cancelled().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn overlapping_pauses_only_resume_once() {
|
||||
let stopwatch = Stopwatch::new(Duration::from_millis(50));
|
||||
let token = stopwatch.cancellation_token();
|
||||
|
||||
// First pause.
|
||||
let pause1 = {
|
||||
let stopwatch = stopwatch.clone();
|
||||
tokio::spawn(async move {
|
||||
stopwatch
|
||||
.pause_for(async {
|
||||
sleep(Duration::from_millis(80)).await;
|
||||
})
|
||||
.await;
|
||||
})
|
||||
};
|
||||
|
||||
// Overlapping pause that ends sooner.
|
||||
let pause2 = {
|
||||
let stopwatch = stopwatch.clone();
|
||||
tokio::spawn(async move {
|
||||
stopwatch
|
||||
.pause_for(async {
|
||||
sleep(Duration::from_millis(30)).await;
|
||||
})
|
||||
.await;
|
||||
})
|
||||
};
|
||||
|
||||
// While both pauses are active, the cancellation should not fire.
|
||||
assert!(
|
||||
timeout(Duration::from_millis(40), token.cancelled())
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
|
||||
pause2.await.expect("short pause should complete");
|
||||
|
||||
// Still paused because the long pause is active.
|
||||
assert!(
|
||||
timeout(Duration::from_millis(30), token.cancelled())
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
|
||||
pause1.await.expect("long pause should complete");
|
||||
|
||||
// Now the stopwatch should resume and hit the limit shortly after.
|
||||
token.cancelled().await;
|
||||
}
|
||||
}
|
||||
@@ -67,21 +67,16 @@ use codex_execpolicy::Decision;
|
||||
use codex_execpolicy::Policy;
|
||||
use codex_execpolicy::RuleMatch;
|
||||
use codex_shell_command::is_dangerous_command::command_might_be_dangerous;
|
||||
use codex_shell_escalation as shell_escalation;
|
||||
use rmcp::ErrorData as McpError;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use tracing_subscriber::{self};
|
||||
|
||||
use crate::posix::mcp_escalation_policy::ExecPolicyOutcome;
|
||||
use crate::unix::mcp_escalation_policy::ExecPolicyOutcome;
|
||||
|
||||
mod escalate_client;
|
||||
mod escalate_protocol;
|
||||
mod escalate_server;
|
||||
mod escalation_policy;
|
||||
mod mcp;
|
||||
mod mcp_escalation_policy;
|
||||
mod socket;
|
||||
mod stopwatch;
|
||||
|
||||
pub use mcp::ExecResult;
|
||||
|
||||
@@ -165,7 +160,7 @@ pub async fn main_execve_wrapper() -> anyhow::Result<()> {
|
||||
.init();
|
||||
|
||||
let ExecveWrapperCli { file, argv } = ExecveWrapperCli::parse();
|
||||
let exit_code = escalate_client::run(file, argv).await?;
|
||||
let exit_code = shell_escalation::run(file, argv).await?;
|
||||
std::process::exit(exit_code);
|
||||
}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
@@ -10,6 +9,8 @@ use codex_core::MCP_SANDBOX_STATE_METHOD;
|
||||
use codex_core::SandboxState;
|
||||
use codex_execpolicy::Policy;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use codex_shell_escalation::EscalationPolicyFactory;
|
||||
use codex_shell_escalation::run_escalate_server;
|
||||
use rmcp::ErrorData as McpError;
|
||||
use rmcp::RoleServer;
|
||||
use rmcp::ServerHandler;
|
||||
@@ -19,7 +20,6 @@ use rmcp::handler::server::wrapper::Parameters;
|
||||
use rmcp::model::CustomRequest;
|
||||
use rmcp::model::CustomResult;
|
||||
use rmcp::model::*;
|
||||
use rmcp::schemars;
|
||||
use rmcp::service::RequestContext;
|
||||
use rmcp::service::RunningService;
|
||||
use rmcp::tool;
|
||||
@@ -29,11 +29,7 @@ use rmcp::transport::stdio;
|
||||
use serde_json::json;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::posix::escalate_server::EscalateServer;
|
||||
use crate::posix::escalate_server::{self};
|
||||
use crate::posix::escalation_policy::EscalationPolicy;
|
||||
use crate::posix::mcp_escalation_policy::McpEscalationPolicy;
|
||||
use crate::posix::stopwatch::Stopwatch;
|
||||
use crate::unix::mcp_escalation_policy::McpEscalationPolicy;
|
||||
|
||||
/// Path to our patched bash.
|
||||
const CODEX_BASH_PATH_ENV_VAR: &str = "CODEX_BASH_PATH";
|
||||
@@ -46,19 +42,7 @@ pub(crate) fn get_bash_path() -> Result<PathBuf> {
|
||||
.context(format!("{CODEX_BASH_PATH_ENV_VAR} must be set"))
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct ExecParams {
|
||||
/// The bash string to execute.
|
||||
pub command: String,
|
||||
/// The working directory to execute the command in. Must be an absolute path.
|
||||
pub workdir: String,
|
||||
/// The timeout for the command in milliseconds.
|
||||
pub timeout_ms: Option<u64>,
|
||||
/// Launch Bash with -lc instead of -c: defaults to true.
|
||||
pub login: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
|
||||
#[derive(Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct ExecResult {
|
||||
pub exit_code: i32,
|
||||
pub output: String,
|
||||
@@ -66,8 +50,8 @@ pub struct ExecResult {
|
||||
pub timed_out: bool,
|
||||
}
|
||||
|
||||
impl From<escalate_server::ExecResult> for ExecResult {
|
||||
fn from(result: escalate_server::ExecResult) -> Self {
|
||||
impl From<codex_shell_escalation::ExecResult> for ExecResult {
|
||||
fn from(result: codex_shell_escalation::ExecResult) -> Self {
|
||||
Self {
|
||||
exit_code: result.exit_code,
|
||||
output: result.output,
|
||||
@@ -87,10 +71,27 @@ pub struct ExecTool {
|
||||
sandbox_state: Arc<RwLock<Option<SandboxState>>>,
|
||||
}
|
||||
|
||||
trait EscalationPolicyFactory {
|
||||
type Policy: EscalationPolicy + Send + Sync + 'static;
|
||||
#[derive(Debug, serde::Serialize, serde::Deserialize, rmcp::schemars::JsonSchema)]
|
||||
pub struct ExecParams {
|
||||
/// The bash string to execute.
|
||||
pub command: String,
|
||||
/// The working directory to execute the command in. Must be an absolute path.
|
||||
pub workdir: String,
|
||||
/// The timeout for the command in milliseconds.
|
||||
pub timeout_ms: Option<u64>,
|
||||
/// Launch Bash with -lc instead of -c: defaults to true.
|
||||
pub login: Option<bool>,
|
||||
}
|
||||
|
||||
fn create_policy(&self, policy: Arc<RwLock<Policy>>, stopwatch: Stopwatch) -> Self::Policy;
|
||||
impl From<ExecParams> for codex_shell_escalation::ExecParams {
|
||||
fn from(inner: ExecParams) -> Self {
|
||||
Self {
|
||||
command: inner.command,
|
||||
workdir: inner.workdir,
|
||||
timeout_ms: inner.timeout_ms,
|
||||
login: inner.login,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct McpEscalationPolicyFactory {
|
||||
@@ -101,7 +102,11 @@ struct McpEscalationPolicyFactory {
|
||||
impl EscalationPolicyFactory for McpEscalationPolicyFactory {
|
||||
type Policy = McpEscalationPolicy;
|
||||
|
||||
fn create_policy(&self, policy: Arc<RwLock<Policy>>, stopwatch: Stopwatch) -> Self::Policy {
|
||||
fn create_policy(
|
||||
&self,
|
||||
policy: Arc<RwLock<Policy>>,
|
||||
stopwatch: codex_shell_escalation::Stopwatch,
|
||||
) -> Self::Policy {
|
||||
McpEscalationPolicy::new(
|
||||
policy,
|
||||
self.context.clone(),
|
||||
@@ -153,8 +158,8 @@ impl ExecTool {
|
||||
use_linux_sandbox_bwrap: false,
|
||||
});
|
||||
let result = run_escalate_server(
|
||||
params,
|
||||
sandbox_state,
|
||||
params.into(),
|
||||
&sandbox_state,
|
||||
&self.bash_path,
|
||||
&self.execve_wrapper,
|
||||
self.policy.clone(),
|
||||
@@ -172,48 +177,6 @@ impl ExecTool {
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs the escalate server to execute a shell command with potential
|
||||
/// escalation of execve calls.
|
||||
///
|
||||
/// - `exec_params` defines the shell command to run
|
||||
/// - `sandbox_state` is the sandbox to use to run the shell program
|
||||
/// - `shell_program` is the path to the shell program to run (e.g. /bin/bash)
|
||||
/// - `execve_wrapper` is the path to the execve wrapper binary to use for
|
||||
/// handling execve calls from the shell program. This is likely a symlink to
|
||||
/// Codex using a special name.
|
||||
/// - `policy` is the exec policy to use for deciding whether to allow or deny
|
||||
/// execve calls from the shell program.
|
||||
/// - `escalation_policy_factory` is a factory for creating an
|
||||
/// `EscalationPolicy` to use for deciding whether to allow, deny, or prompt
|
||||
/// the user for execve calls from the shell program. We use a factory here
|
||||
/// because the `EscalationPolicy` may need to capture request-specific
|
||||
/// context (e.g. the MCP request context) that is not available at the time
|
||||
/// we create the `ExecTool`.
|
||||
/// - `effective_timeout` is the timeout to use for running the shell command.
|
||||
/// Implementations are encouraged to excludeany time spent prompting the
|
||||
/// user.
|
||||
async fn run_escalate_server(
|
||||
exec_params: ExecParams,
|
||||
sandbox_state: SandboxState,
|
||||
shell_program: impl AsRef<Path>,
|
||||
execve_wrapper: impl AsRef<Path>,
|
||||
policy: Arc<RwLock<Policy>>,
|
||||
escalation_policy_factory: impl EscalationPolicyFactory,
|
||||
effective_timeout: Duration,
|
||||
) -> anyhow::Result<crate::posix::escalate_server::ExecResult> {
|
||||
let stopwatch = Stopwatch::new(effective_timeout);
|
||||
let cancel_token = stopwatch.cancellation_token();
|
||||
let escalate_server = EscalateServer::new(
|
||||
shell_program.as_ref().to_path_buf(),
|
||||
execve_wrapper.as_ref().to_path_buf(),
|
||||
escalation_policy_factory.create_policy(policy, stopwatch),
|
||||
);
|
||||
|
||||
escalate_server
|
||||
.exec(exec_params, cancel_token, &sandbox_state)
|
||||
.await
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct CodexSandboxStateUpdateMethod;
|
||||
|
||||
@@ -307,7 +270,7 @@ mod tests {
|
||||
/// `timeout_ms` fields are optional.
|
||||
#[test]
|
||||
fn exec_params_json_schema_matches_expected() {
|
||||
let schema = schemars::schema_for!(ExecParams);
|
||||
let schema = rmcp::schemars::schema_for!(ExecParams);
|
||||
let actual = serde_json::to_value(schema).expect("schema should serialize");
|
||||
|
||||
assert_eq!(
|
||||
@@ -2,6 +2,9 @@ use std::path::Path;
|
||||
|
||||
use codex_core::sandboxing::SandboxPermissions;
|
||||
use codex_execpolicy::Policy;
|
||||
use codex_shell_escalation::EscalateAction;
|
||||
use codex_shell_escalation::EscalationPolicy;
|
||||
use codex_shell_escalation::Stopwatch;
|
||||
use rmcp::ErrorData as McpError;
|
||||
use rmcp::RoleServer;
|
||||
use rmcp::model::CreateElicitationRequestParams;
|
||||
@@ -9,10 +12,7 @@ use rmcp::model::CreateElicitationResult;
|
||||
use rmcp::model::ElicitationAction;
|
||||
use rmcp::model::ElicitationSchema;
|
||||
use rmcp::service::RequestContext;
|
||||
|
||||
use crate::posix::escalate_protocol::EscalateAction;
|
||||
use crate::posix::escalation_policy::EscalationPolicy;
|
||||
use crate::posix::stopwatch::Stopwatch;
|
||||
use shlex::try_join;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
@@ -59,7 +59,7 @@ impl McpEscalationPolicy {
|
||||
workdir: &Path,
|
||||
context: RequestContext<RoleServer>,
|
||||
) -> Result<CreateElicitationResult, McpError> {
|
||||
let args = shlex::try_join(argv.iter().skip(1).map(String::as_str)).unwrap_or_default();
|
||||
let args = try_join(argv.iter().skip(1).map(String::as_str)).unwrap_or_default();
|
||||
let command = if args.is_empty() {
|
||||
file.display().to_string()
|
||||
} else {
|
||||
@@ -104,10 +104,10 @@ impl EscalationPolicy for McpEscalationPolicy {
|
||||
file: &Path,
|
||||
argv: &[String],
|
||||
workdir: &Path,
|
||||
) -> Result<EscalateAction, rmcp::ErrorData> {
|
||||
) -> anyhow::Result<EscalateAction> {
|
||||
let policy = self.policy.read().await;
|
||||
let outcome =
|
||||
crate::posix::evaluate_exec_policy(&policy, file, argv, self.preserve_program_paths)?;
|
||||
crate::unix::evaluate_exec_policy(&policy, file, argv, self.preserve_program_paths)?;
|
||||
let action = match outcome {
|
||||
ExecPolicyOutcome::Allow {
|
||||
sandbox_permissions,
|
||||
Reference in New Issue
Block a user