mirror of
https://github.com/openai/codex.git
synced 2026-05-01 03:42:05 +03:00
exec-server (#6630)
This commit is contained in:
102
codex-rs/exec-server/src/posix/escalate_client.rs
Normal file
102
codex-rs/exec-server/src/posix/escalate_client.rs
Normal file
@@ -0,0 +1,102 @@
|
||||
use std::io;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::os::fd::FromRawFd as _;
|
||||
use std::os::fd::OwnedFd;
|
||||
|
||||
use anyhow::Context as _;
|
||||
|
||||
use crate::posix::escalate_protocol::BASH_EXEC_WRAPPER_ENV_VAR;
|
||||
use crate::posix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR;
|
||||
use crate::posix::escalate_protocol::EscalateAction;
|
||||
use crate::posix::escalate_protocol::EscalateRequest;
|
||||
use crate::posix::escalate_protocol::EscalateResponse;
|
||||
use crate::posix::escalate_protocol::SuperExecMessage;
|
||||
use crate::posix::escalate_protocol::SuperExecResult;
|
||||
use crate::posix::socket::AsyncDatagramSocket;
|
||||
use crate::posix::socket::AsyncSocket;
|
||||
|
||||
fn get_escalate_client() -> anyhow::Result<AsyncDatagramSocket> {
|
||||
// TODO: we should defensively require only calling this once, since AsyncSocket will take ownership of the fd.
|
||||
let client_fd = std::env::var(ESCALATE_SOCKET_ENV_VAR)?.parse::<i32>()?;
|
||||
if client_fd < 0 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"{ESCALATE_SOCKET_ENV_VAR} is not a valid file descriptor: {client_fd}"
|
||||
));
|
||||
}
|
||||
Ok(unsafe { AsyncDatagramSocket::from_raw_fd(client_fd) }?)
|
||||
}
|
||||
|
||||
pub(crate) async fn run(file: String, argv: Vec<String>) -> anyhow::Result<i32> {
|
||||
let handshake_client = get_escalate_client()?;
|
||||
let (server, client) = AsyncSocket::pair()?;
|
||||
const HANDSHAKE_MESSAGE: [u8; 1] = [0];
|
||||
handshake_client
|
||||
.send_with_fds(&HANDSHAKE_MESSAGE, &[server.into_inner().into()])
|
||||
.await
|
||||
.context("failed to send handshake datagram")?;
|
||||
let env = std::env::vars()
|
||||
.filter(|(k, _)| {
|
||||
!matches!(
|
||||
k.as_str(),
|
||||
ESCALATE_SOCKET_ENV_VAR | BASH_EXEC_WRAPPER_ENV_VAR
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
client
|
||||
.send(EscalateRequest {
|
||||
file: file.clone().into(),
|
||||
argv: argv.clone(),
|
||||
workdir: std::env::current_dir()?,
|
||||
env,
|
||||
})
|
||||
.await
|
||||
.context("failed to send EscalateRequest")?;
|
||||
let message = client.receive::<EscalateResponse>().await?;
|
||||
match message.action {
|
||||
EscalateAction::Escalate => {
|
||||
// TODO: maybe we should send ALL open FDs (except the escalate client)?
|
||||
let fds_to_send = [
|
||||
unsafe { OwnedFd::from_raw_fd(io::stdin().as_raw_fd()) },
|
||||
unsafe { OwnedFd::from_raw_fd(io::stdout().as_raw_fd()) },
|
||||
unsafe { OwnedFd::from_raw_fd(io::stderr().as_raw_fd()) },
|
||||
];
|
||||
|
||||
// TODO: also forward signals over the super-exec socket
|
||||
|
||||
client
|
||||
.send_with_fds(
|
||||
SuperExecMessage {
|
||||
fds: fds_to_send.iter().map(AsRawFd::as_raw_fd).collect(),
|
||||
},
|
||||
&fds_to_send,
|
||||
)
|
||||
.await
|
||||
.context("failed to send SuperExecMessage")?;
|
||||
let SuperExecResult { exit_code } = client.receive::<SuperExecResult>().await?;
|
||||
Ok(exit_code)
|
||||
}
|
||||
EscalateAction::Run => {
|
||||
// We avoid std::process::Command here because we want to be as transparent as
|
||||
// possible. std::os::unix::process::CommandExt has .exec() but it does some funky
|
||||
// stuff with signal masks and dup2() on its standard FDs, which we don't want.
|
||||
use std::ffi::CString;
|
||||
let file = CString::new(file).context("NUL in file")?;
|
||||
|
||||
let argv_cstrs: Vec<CString> = argv
|
||||
.iter()
|
||||
.map(|s| CString::new(s.as_str()).context("NUL in argv"))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let mut argv: Vec<*const libc::c_char> =
|
||||
argv_cstrs.iter().map(|s| s.as_ptr()).collect();
|
||||
argv.push(std::ptr::null());
|
||||
|
||||
let err = unsafe {
|
||||
libc::execv(file.as_ptr(), argv.as_ptr());
|
||||
std::io::Error::last_os_error()
|
||||
};
|
||||
|
||||
Err(err.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
49
codex-rs/exec-server/src/posix/escalate_protocol.rs
Normal file
49
codex-rs/exec-server/src/posix/escalate_protocol.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
use std::collections::HashMap;
|
||||
use std::os::fd::RawFd;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
|
||||
/// 'exec-server escalate' reads this to find the inherited FD for the escalate socket.
|
||||
pub(super) const ESCALATE_SOCKET_ENV_VAR: &str = "CODEX_ESCALATE_SOCKET";
|
||||
|
||||
/// The patched bash uses this to wrap exec() calls.
|
||||
pub(super) const BASH_EXEC_WRAPPER_ENV_VAR: &str = "BASH_EXEC_WRAPPER";
|
||||
|
||||
/// The client sends this to the server to request an exec() call.
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
|
||||
pub(super) struct EscalateRequest {
|
||||
/// The absolute path to the executable to run, i.e. the first arg to exec.
|
||||
pub(super) file: PathBuf,
|
||||
/// The argv, including the program name (argv[0]).
|
||||
pub(super) argv: Vec<String>,
|
||||
pub(super) workdir: PathBuf,
|
||||
pub(super) env: HashMap<String, String>,
|
||||
}
|
||||
|
||||
/// The server sends this to the client to respond to an exec() request.
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
|
||||
pub(super) struct EscalateResponse {
|
||||
pub(super) action: EscalateAction,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
|
||||
pub(super) enum EscalateAction {
|
||||
/// The command should be run directly by the client.
|
||||
Run,
|
||||
/// The command should be escalated to the server for execution.
|
||||
Escalate,
|
||||
}
|
||||
|
||||
/// The client sends this to the server to forward its open FDs.
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
pub(super) struct SuperExecMessage {
|
||||
pub(super) fds: Vec<RawFd>,
|
||||
}
|
||||
|
||||
/// The server responds when the exec()'d command has exited.
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
pub(super) struct SuperExecResult {
|
||||
pub(super) exit_code: i32,
|
||||
}
|
||||
274
codex-rs/exec-server/src/posix/escalate_server.rs
Normal file
274
codex-rs/exec-server/src/posix/escalate_server.rs
Normal file
@@ -0,0 +1,274 @@
|
||||
use std::collections::HashMap;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use path_absolutize::Absolutize as _;
|
||||
|
||||
use codex_core::exec::SandboxType;
|
||||
use codex_core::exec::process_exec_tool_call;
|
||||
use codex_core::get_platform_sandbox;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use tokio::process::Command;
|
||||
|
||||
use crate::posix::escalate_protocol::BASH_EXEC_WRAPPER_ENV_VAR;
|
||||
use crate::posix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR;
|
||||
use crate::posix::escalate_protocol::EscalateAction;
|
||||
use crate::posix::escalate_protocol::EscalateRequest;
|
||||
use crate::posix::escalate_protocol::EscalateResponse;
|
||||
use crate::posix::escalate_protocol::SuperExecMessage;
|
||||
use crate::posix::escalate_protocol::SuperExecResult;
|
||||
use crate::posix::socket::AsyncDatagramSocket;
|
||||
use crate::posix::socket::AsyncSocket;
|
||||
|
||||
/// This is the policy which decides how to handle an exec() call.
|
||||
///
|
||||
/// `file` is the absolute, canonical path to the executable to run, i.e. the first arg to exec.
|
||||
/// `argv` is the argv, including the program name (`argv[0]`).
|
||||
/// `workdir` is the absolute, canonical path to the working directory in which to execute the
|
||||
/// command.
|
||||
pub(crate) type ExecPolicy = fn(file: &Path, argv: &[String], workdir: &Path) -> EscalateAction;
|
||||
|
||||
pub(crate) struct EscalateServer {
|
||||
bash_path: PathBuf,
|
||||
policy: ExecPolicy,
|
||||
}
|
||||
|
||||
impl EscalateServer {
|
||||
pub fn new(bash_path: PathBuf, policy: ExecPolicy) -> Self {
|
||||
Self { bash_path, policy }
|
||||
}
|
||||
|
||||
pub async fn exec(
|
||||
&self,
|
||||
command: String,
|
||||
env: HashMap<String, String>,
|
||||
workdir: PathBuf,
|
||||
timeout_ms: Option<u64>,
|
||||
) -> anyhow::Result<ExecResult> {
|
||||
let (escalate_server, escalate_client) = AsyncDatagramSocket::pair()?;
|
||||
let client_socket = escalate_client.into_inner();
|
||||
client_socket.set_cloexec(false)?;
|
||||
|
||||
let escalate_task = tokio::spawn(escalate_task(escalate_server, self.policy));
|
||||
let mut env = env.clone();
|
||||
env.insert(
|
||||
ESCALATE_SOCKET_ENV_VAR.to_string(),
|
||||
client_socket.as_raw_fd().to_string(),
|
||||
);
|
||||
env.insert(
|
||||
BASH_EXEC_WRAPPER_ENV_VAR.to_string(),
|
||||
format!("{} escalate", std::env::current_exe()?.to_string_lossy()),
|
||||
);
|
||||
let result = process_exec_tool_call(
|
||||
codex_core::exec::ExecParams {
|
||||
command: vec![
|
||||
self.bash_path.to_string_lossy().to_string(),
|
||||
"-c".to_string(),
|
||||
command,
|
||||
],
|
||||
cwd: PathBuf::from(&workdir),
|
||||
timeout_ms,
|
||||
env,
|
||||
with_escalated_permissions: None,
|
||||
justification: None,
|
||||
arg0: None,
|
||||
},
|
||||
get_platform_sandbox().unwrap_or(SandboxType::None),
|
||||
// TODO: use the sandbox policy and cwd from the calling client
|
||||
&SandboxPolicy::ReadOnly,
|
||||
&PathBuf::from("/__NONEXISTENT__"), // This is ignored for ReadOnly
|
||||
&None,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
escalate_task.abort();
|
||||
let result = ExecResult {
|
||||
exit_code: result.exit_code,
|
||||
output: result.aggregated_output.text,
|
||||
duration: result.duration,
|
||||
timed_out: result.timed_out,
|
||||
};
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
async fn escalate_task(socket: AsyncDatagramSocket, policy: ExecPolicy) -> anyhow::Result<()> {
|
||||
loop {
|
||||
let (_, mut fds) = socket.receive_with_fds().await?;
|
||||
if fds.len() != 1 {
|
||||
tracing::error!("expected 1 fd in datagram handshake, got {}", fds.len());
|
||||
continue;
|
||||
}
|
||||
let stream_socket = AsyncSocket::from_fd(fds.remove(0))?;
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = handle_escalate_session_with_policy(stream_socket, policy).await {
|
||||
tracing::error!("escalate session failed: {err:?}");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct ExecResult {
|
||||
pub(crate) exit_code: i32,
|
||||
pub(crate) output: String,
|
||||
pub(crate) duration: Duration,
|
||||
pub(crate) timed_out: bool,
|
||||
}
|
||||
|
||||
async fn handle_escalate_session_with_policy(
|
||||
socket: AsyncSocket,
|
||||
policy: ExecPolicy,
|
||||
) -> anyhow::Result<()> {
|
||||
let EscalateRequest {
|
||||
file,
|
||||
argv,
|
||||
workdir,
|
||||
env,
|
||||
} = socket.receive::<EscalateRequest>().await?;
|
||||
let file = PathBuf::from(&file).absolutize()?.into_owned();
|
||||
let workdir = PathBuf::from(&workdir).absolutize()?.into_owned();
|
||||
let action = policy(file.as_path(), &argv, &workdir);
|
||||
tracing::debug!("decided {action:?} for {file:?} {argv:?} {workdir:?}");
|
||||
match action {
|
||||
EscalateAction::Run => {
|
||||
socket
|
||||
.send(EscalateResponse {
|
||||
action: EscalateAction::Run,
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
EscalateAction::Escalate => {
|
||||
socket
|
||||
.send(EscalateResponse {
|
||||
action: EscalateAction::Escalate,
|
||||
})
|
||||
.await?;
|
||||
let (msg, fds) = socket
|
||||
.receive_with_fds::<SuperExecMessage>()
|
||||
.await
|
||||
.context("failed to receive SuperExecMessage")?;
|
||||
if fds.len() != msg.fds.len() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"mismatched number of fds in SuperExecMessage: {} in the message, {} from the control message",
|
||||
msg.fds.len(),
|
||||
fds.len()
|
||||
));
|
||||
}
|
||||
|
||||
if msg
|
||||
.fds
|
||||
.iter()
|
||||
.any(|src_fd| fds.iter().any(|dst_fd| dst_fd.as_raw_fd() == *src_fd))
|
||||
{
|
||||
return Err(anyhow::anyhow!(
|
||||
"overlapping fds not yet supported in SuperExecMessage"
|
||||
));
|
||||
}
|
||||
|
||||
let mut command = Command::new(file);
|
||||
command
|
||||
.args(&argv[1..])
|
||||
.arg0(argv[0].clone())
|
||||
.envs(&env)
|
||||
.current_dir(&workdir)
|
||||
.stdin(Stdio::null())
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null());
|
||||
unsafe {
|
||||
command.pre_exec(move || {
|
||||
for (dst_fd, src_fd) in msg.fds.iter().zip(&fds) {
|
||||
libc::dup2(src_fd.as_raw_fd(), *dst_fd);
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
let mut child = command.spawn()?;
|
||||
let exit_status = child.wait().await?;
|
||||
socket
|
||||
.send(SuperExecResult {
|
||||
exit_code: exit_status.code().unwrap_or(127),
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_escalate_session_respects_run_in_sandbox_decision() -> anyhow::Result<()> {
|
||||
let (server, client) = AsyncSocket::pair()?;
|
||||
let server_task = tokio::spawn(handle_escalate_session_with_policy(
|
||||
server,
|
||||
|_file, _argv, _workdir| EscalateAction::Run,
|
||||
));
|
||||
|
||||
client
|
||||
.send(EscalateRequest {
|
||||
file: PathBuf::from("/bin/echo"),
|
||||
argv: vec!["echo".to_string()],
|
||||
workdir: PathBuf::from("/tmp"),
|
||||
env: HashMap::new(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let response = client.receive::<EscalateResponse>().await?;
|
||||
assert_eq!(
|
||||
EscalateResponse {
|
||||
action: EscalateAction::Run,
|
||||
},
|
||||
response
|
||||
);
|
||||
server_task.await?
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_escalate_session_executes_escalated_command() -> anyhow::Result<()> {
|
||||
let (server, client) = AsyncSocket::pair()?;
|
||||
let server_task = tokio::spawn(handle_escalate_session_with_policy(
|
||||
server,
|
||||
|_file, _argv, _workdir| EscalateAction::Escalate,
|
||||
));
|
||||
|
||||
client
|
||||
.send(EscalateRequest {
|
||||
file: PathBuf::from("/bin/sh"),
|
||||
argv: vec![
|
||||
"sh".to_string(),
|
||||
"-c".to_string(),
|
||||
r#"if [ "$KEY" = VALUE ]; then exit 42; else exit 1; fi"#.to_string(),
|
||||
],
|
||||
workdir: std::env::current_dir()?,
|
||||
env: HashMap::from([("KEY".to_string(), "VALUE".to_string())]),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let response = client.receive::<EscalateResponse>().await?;
|
||||
assert_eq!(
|
||||
EscalateResponse {
|
||||
action: EscalateAction::Escalate,
|
||||
},
|
||||
response
|
||||
);
|
||||
|
||||
client
|
||||
.send_with_fds(SuperExecMessage { fds: Vec::new() }, &[])
|
||||
.await?;
|
||||
|
||||
let result = client.receive::<SuperExecResult>().await?;
|
||||
assert_eq!(42, result.exit_code);
|
||||
|
||||
server_task.await?
|
||||
}
|
||||
}
|
||||
154
codex-rs/exec-server/src/posix/mcp.rs
Normal file
154
codex-rs/exec-server/src/posix/mcp.rs
Normal file
@@ -0,0 +1,154 @@
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use anyhow::Result;
|
||||
use rmcp::ErrorData as McpError;
|
||||
use rmcp::RoleServer;
|
||||
use rmcp::ServerHandler;
|
||||
use rmcp::ServiceExt;
|
||||
use rmcp::handler::server::router::tool::ToolRouter;
|
||||
use rmcp::handler::server::wrapper::Parameters;
|
||||
use rmcp::model::*;
|
||||
use rmcp::schemars;
|
||||
use rmcp::service::RequestContext;
|
||||
use rmcp::service::RunningService;
|
||||
use rmcp::tool;
|
||||
use rmcp::tool_handler;
|
||||
use rmcp::tool_router;
|
||||
use rmcp::transport::stdio;
|
||||
|
||||
use crate::posix::escalate_server;
|
||||
use crate::posix::escalate_server::EscalateServer;
|
||||
use crate::posix::escalate_server::ExecPolicy;
|
||||
|
||||
/// Path to our patched bash.
|
||||
const CODEX_BASH_PATH_ENV_VAR: &str = "CODEX_BASH_PATH";
|
||||
|
||||
pub(crate) fn get_bash_path() -> Result<PathBuf> {
|
||||
std::env::var(CODEX_BASH_PATH_ENV_VAR)
|
||||
.map(PathBuf::from)
|
||||
.context(format!("{CODEX_BASH_PATH_ENV_VAR} must be set"))
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct ExecParams {
|
||||
/// The bash string to execute.
|
||||
pub command: String,
|
||||
/// The working directory to execute the command in. Must be an absolute path.
|
||||
pub workdir: String,
|
||||
/// The timeout for the command in milliseconds.
|
||||
pub timeout_ms: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize, schemars::JsonSchema)]
|
||||
pub struct ExecResult {
|
||||
pub exit_code: i32,
|
||||
pub output: String,
|
||||
pub duration: Duration,
|
||||
pub timed_out: bool,
|
||||
}
|
||||
|
||||
impl From<escalate_server::ExecResult> for ExecResult {
|
||||
fn from(result: escalate_server::ExecResult) -> Self {
|
||||
Self {
|
||||
exit_code: result.exit_code,
|
||||
output: result.output,
|
||||
duration: result.duration,
|
||||
timed_out: result.timed_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ExecTool {
|
||||
tool_router: ToolRouter<ExecTool>,
|
||||
bash_path: PathBuf,
|
||||
policy: ExecPolicy,
|
||||
}
|
||||
|
||||
#[tool_router]
|
||||
impl ExecTool {
|
||||
pub fn new(bash_path: PathBuf, policy: ExecPolicy) -> Self {
|
||||
Self {
|
||||
tool_router: Self::tool_router(),
|
||||
bash_path,
|
||||
policy,
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs a shell command and returns its output. You MUST provide the workdir as an absolute path.
|
||||
#[tool]
|
||||
async fn shell(
|
||||
&self,
|
||||
_context: RequestContext<RoleServer>,
|
||||
Parameters(params): Parameters<ExecParams>,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let escalate_server = EscalateServer::new(self.bash_path.clone(), self.policy);
|
||||
let result = escalate_server
|
||||
.exec(
|
||||
params.command,
|
||||
// TODO: use ShellEnvironmentPolicy
|
||||
std::env::vars().collect(),
|
||||
PathBuf::from(¶ms.workdir),
|
||||
params.timeout_ms,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(e.to_string(), None))?;
|
||||
Ok(CallToolResult::success(vec![Content::json(
|
||||
ExecResult::from(result),
|
||||
)?]))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
async fn prompt(
|
||||
&self,
|
||||
command: String,
|
||||
workdir: String,
|
||||
context: RequestContext<RoleServer>,
|
||||
) -> Result<CreateElicitationResult, McpError> {
|
||||
context
|
||||
.peer
|
||||
.create_elicitation(CreateElicitationRequestParam {
|
||||
message: format!("Allow Codex to run `{command:?}` in `{workdir:?}`?"),
|
||||
#[allow(clippy::expect_used)]
|
||||
requested_schema: ElicitationSchema::builder()
|
||||
.property("dummy", PrimitiveSchema::String(StringSchema::new()))
|
||||
.build()
|
||||
.expect("failed to build elicitation schema"),
|
||||
})
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(e.to_string(), None))
|
||||
}
|
||||
}
|
||||
|
||||
#[tool_handler]
|
||||
impl ServerHandler for ExecTool {
|
||||
fn get_info(&self) -> ServerInfo {
|
||||
ServerInfo {
|
||||
protocol_version: ProtocolVersion::V_2025_06_18,
|
||||
capabilities: ServerCapabilities::builder().enable_tools().build(),
|
||||
server_info: Implementation::from_build_env(),
|
||||
instructions: Some(
|
||||
"This server provides a tool to execute shell commands and return their output."
|
||||
.to_string(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
async fn initialize(
|
||||
&self,
|
||||
_request: InitializeRequestParam,
|
||||
_context: RequestContext<RoleServer>,
|
||||
) -> Result<InitializeResult, McpError> {
|
||||
Ok(self.get_info())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn serve(
|
||||
bash_path: PathBuf,
|
||||
policy: ExecPolicy,
|
||||
) -> Result<RunningService<RoleServer, ExecTool>, rmcp::service::ServerInitializeError> {
|
||||
let tool = ExecTool::new(bash_path, policy);
|
||||
tool.serve(stdio()).await
|
||||
}
|
||||
486
codex-rs/exec-server/src/posix/socket.rs
Normal file
486
codex-rs/exec-server/src/posix/socket.rs
Normal file
@@ -0,0 +1,486 @@
|
||||
use libc::c_uint;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use socket2::Domain;
|
||||
use socket2::MaybeUninitSlice;
|
||||
use socket2::MsgHdr;
|
||||
use socket2::MsgHdrMut;
|
||||
use socket2::Socket;
|
||||
use socket2::Type;
|
||||
use std::io::IoSlice;
|
||||
use std::mem::MaybeUninit;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::os::fd::FromRawFd;
|
||||
use std::os::fd::OwnedFd;
|
||||
use std::os::fd::RawFd;
|
||||
use tokio::io::Interest;
|
||||
use tokio::io::unix::AsyncFd;
|
||||
|
||||
const MAX_FDS_PER_MESSAGE: usize = 16;
|
||||
const LENGTH_PREFIX_SIZE: usize = size_of::<u32>();
|
||||
const MAX_DATAGRAM_SIZE: usize = 8192;
|
||||
|
||||
/// Converts a slice of MaybeUninit<T> to a slice of T.
|
||||
///
|
||||
/// The caller guarantees that every element of `buf` is initialized.
|
||||
fn assume_init<T>(buf: &[MaybeUninit<T>]) -> &[T] {
|
||||
unsafe { std::slice::from_raw_parts(buf.as_ptr().cast(), buf.len()) }
|
||||
}
|
||||
|
||||
fn assume_init_slice<T, const N: usize>(buf: &[MaybeUninit<T>; N]) -> &[T; N] {
|
||||
unsafe { &*(buf as *const [MaybeUninit<T>; N] as *const [T; N]) }
|
||||
}
|
||||
|
||||
fn assume_init_vec<T>(mut buf: Vec<MaybeUninit<T>>) -> Vec<T> {
|
||||
unsafe {
|
||||
let ptr = buf.as_mut_ptr() as *mut T;
|
||||
let len = buf.len();
|
||||
let cap = buf.capacity();
|
||||
std::mem::forget(buf);
|
||||
Vec::from_raw_parts(ptr, len, cap)
|
||||
}
|
||||
}
|
||||
|
||||
fn control_space_for_fds(count: usize) -> usize {
|
||||
unsafe { libc::CMSG_SPACE((count * size_of::<RawFd>()) as _) as usize }
|
||||
}
|
||||
|
||||
/// Extracts the FDs from a SCM_RIGHTS control message.
|
||||
fn extract_fds(control: &[u8]) -> Vec<OwnedFd> {
|
||||
let mut fds = Vec::new();
|
||||
let mut hdr: libc::msghdr = unsafe { std::mem::zeroed() };
|
||||
hdr.msg_control = control.as_ptr() as *mut libc::c_void;
|
||||
hdr.msg_controllen = control.len() as _;
|
||||
let hdr = hdr; // drop mut
|
||||
|
||||
let mut cmsg = unsafe { libc::CMSG_FIRSTHDR(&hdr) as *const libc::cmsghdr };
|
||||
while !cmsg.is_null() {
|
||||
let level = unsafe { (*cmsg).cmsg_level };
|
||||
let ty = unsafe { (*cmsg).cmsg_type };
|
||||
if level == libc::SOL_SOCKET && ty == libc::SCM_RIGHTS {
|
||||
let data_ptr = unsafe { libc::CMSG_DATA(cmsg).cast::<RawFd>() };
|
||||
let fd_count: usize = {
|
||||
let cmsg_data_len =
|
||||
unsafe { (*cmsg).cmsg_len as usize } - unsafe { libc::CMSG_LEN(0) as usize };
|
||||
cmsg_data_len / size_of::<RawFd>()
|
||||
};
|
||||
for i in 0..fd_count {
|
||||
let fd = unsafe { data_ptr.add(i).read() };
|
||||
fds.push(unsafe { OwnedFd::from_raw_fd(fd) });
|
||||
}
|
||||
}
|
||||
cmsg = unsafe { libc::CMSG_NXTHDR(&hdr, cmsg) };
|
||||
}
|
||||
fds
|
||||
}
|
||||
|
||||
/// Read a frame from a SOCK_STREAM socket.
|
||||
///
|
||||
/// A frame is a message length prefix followed by a payload. FDs may be included in the control
|
||||
/// message when receiving the frame header.
|
||||
async fn read_frame(async_socket: &AsyncFd<Socket>) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
|
||||
let (message_len, fds) = read_frame_header(async_socket).await?;
|
||||
let payload = read_frame_payload(async_socket, message_len).await?;
|
||||
Ok((payload, fds))
|
||||
}
|
||||
|
||||
/// Read the frame header (i.e. length) and any FDs from a SOCK_STREAM socket.
|
||||
async fn read_frame_header(
|
||||
async_socket: &AsyncFd<Socket>,
|
||||
) -> std::io::Result<(usize, Vec<OwnedFd>)> {
|
||||
let mut header = [MaybeUninit::<u8>::uninit(); LENGTH_PREFIX_SIZE];
|
||||
let mut filled = 0;
|
||||
let mut control = vec![MaybeUninit::<u8>::uninit(); control_space_for_fds(MAX_FDS_PER_MESSAGE)];
|
||||
let mut captured_control = false;
|
||||
|
||||
while filled < LENGTH_PREFIX_SIZE {
|
||||
let mut guard = async_socket.readable().await?;
|
||||
// The first read should come with a control message containing any FDs.
|
||||
let result = if !captured_control {
|
||||
guard.try_io(|inner| {
|
||||
let mut bufs = [MaybeUninitSlice::new(&mut header[filled..])];
|
||||
let (read, control_len) = {
|
||||
let mut msg = MsgHdrMut::new()
|
||||
.with_buffers(&mut bufs)
|
||||
.with_control(&mut control);
|
||||
let read = inner.get_ref().recvmsg(&mut msg, 0)?;
|
||||
(read, msg.control_len())
|
||||
};
|
||||
control.truncate(control_len);
|
||||
captured_control = true;
|
||||
Ok(read)
|
||||
})
|
||||
} else {
|
||||
guard.try_io(|inner| inner.get_ref().recv(&mut header[filled..]))
|
||||
};
|
||||
let Ok(result) = result else {
|
||||
// Would block, try again.
|
||||
continue;
|
||||
};
|
||||
|
||||
let read = result?;
|
||||
if read == 0 {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
"socket closed while receiving frame header",
|
||||
));
|
||||
}
|
||||
|
||||
filled += read;
|
||||
assert!(filled <= LENGTH_PREFIX_SIZE);
|
||||
if filled == LENGTH_PREFIX_SIZE {
|
||||
let len_bytes = assume_init_slice(&header);
|
||||
let payload_len = u32::from_le_bytes(*len_bytes) as usize;
|
||||
let fds = extract_fds(assume_init(&control));
|
||||
return Ok((payload_len, fds));
|
||||
}
|
||||
}
|
||||
unreachable!("header loop always returns")
|
||||
}
|
||||
|
||||
/// Read `message_len` bytes from a SOCK_STREAM socket.
|
||||
async fn read_frame_payload(
|
||||
async_socket: &AsyncFd<Socket>,
|
||||
message_len: usize,
|
||||
) -> std::io::Result<Vec<u8>> {
|
||||
if message_len == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut payload = vec![MaybeUninit::<u8>::uninit(); message_len];
|
||||
let mut filled = 0;
|
||||
while filled < message_len {
|
||||
let mut guard = async_socket.readable().await?;
|
||||
let result = guard.try_io(|inner| inner.get_ref().recv(&mut payload[filled..]));
|
||||
let Ok(result) = result else {
|
||||
// Would block, try again.
|
||||
continue;
|
||||
};
|
||||
let read = result?;
|
||||
if read == 0 {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
"socket closed while receiving frame payload",
|
||||
));
|
||||
}
|
||||
filled += read;
|
||||
assert!(filled <= message_len);
|
||||
if filled == message_len {
|
||||
return Ok(assume_init_vec(payload));
|
||||
}
|
||||
}
|
||||
unreachable!("loop exits only after returning payload")
|
||||
}
|
||||
|
||||
fn send_message_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> {
|
||||
if fds.len() > MAX_FDS_PER_MESSAGE {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidInput,
|
||||
format!("too many fds: {}", fds.len()),
|
||||
));
|
||||
}
|
||||
let mut frame = Vec::with_capacity(LENGTH_PREFIX_SIZE + data.len());
|
||||
frame.extend_from_slice(&encode_length(data.len())?);
|
||||
frame.extend_from_slice(data);
|
||||
|
||||
let mut control = vec![0u8; control_space_for_fds(fds.len())];
|
||||
unsafe {
|
||||
let cmsg = control.as_mut_ptr().cast::<libc::cmsghdr>();
|
||||
(*cmsg).cmsg_len = libc::CMSG_LEN(size_of::<RawFd>() as c_uint * fds.len() as c_uint) as _;
|
||||
(*cmsg).cmsg_level = libc::SOL_SOCKET;
|
||||
(*cmsg).cmsg_type = libc::SCM_RIGHTS;
|
||||
let data_ptr = libc::CMSG_DATA(cmsg).cast::<RawFd>();
|
||||
for (i, fd) in fds.iter().enumerate() {
|
||||
data_ptr.add(i).write(fd.as_raw_fd());
|
||||
}
|
||||
}
|
||||
|
||||
let payload = [IoSlice::new(&frame)];
|
||||
let msg = MsgHdr::new().with_buffers(&payload).with_control(&control);
|
||||
let mut sent = socket.sendmsg(&msg, 0)?;
|
||||
while sent < frame.len() {
|
||||
let bytes = socket.send(&frame[sent..])?;
|
||||
if bytes == 0 {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::WriteZero,
|
||||
"socket closed while sending frame payload",
|
||||
));
|
||||
}
|
||||
sent += bytes;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn encode_length(len: usize) -> std::io::Result<[u8; LENGTH_PREFIX_SIZE]> {
|
||||
let len_u32 = u32::try_from(len).map_err(|_| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidInput,
|
||||
format!("message too large: {len}"),
|
||||
)
|
||||
})?;
|
||||
Ok(len_u32.to_le_bytes())
|
||||
}
|
||||
|
||||
pub(crate) fn send_json_message<T: Serialize>(
|
||||
socket: &Socket,
|
||||
msg: T,
|
||||
fds: &[OwnedFd],
|
||||
) -> std::io::Result<()> {
|
||||
let data = serde_json::to_vec(&msg)?;
|
||||
send_message_bytes(socket, &data, fds)
|
||||
}
|
||||
|
||||
fn send_datagram_bytes(socket: &Socket, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> {
|
||||
if fds.len() > MAX_FDS_PER_MESSAGE {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidInput,
|
||||
format!("too many fds: {}", fds.len()),
|
||||
));
|
||||
}
|
||||
let mut control = vec![0u8; control_space_for_fds(fds.len())];
|
||||
if !fds.is_empty() {
|
||||
unsafe {
|
||||
let cmsg = control.as_mut_ptr().cast::<libc::cmsghdr>();
|
||||
(*cmsg).cmsg_len =
|
||||
libc::CMSG_LEN(size_of::<RawFd>() as c_uint * fds.len() as c_uint) as _;
|
||||
(*cmsg).cmsg_level = libc::SOL_SOCKET;
|
||||
(*cmsg).cmsg_type = libc::SCM_RIGHTS;
|
||||
let data_ptr = libc::CMSG_DATA(cmsg).cast::<RawFd>();
|
||||
for (i, fd) in fds.iter().enumerate() {
|
||||
data_ptr.add(i).write(fd.as_raw_fd());
|
||||
}
|
||||
}
|
||||
}
|
||||
let payload = [IoSlice::new(data)];
|
||||
let msg = MsgHdr::new().with_buffers(&payload).with_control(&control);
|
||||
let written = socket.sendmsg(&msg, 0)?;
|
||||
if written != data.len() {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::WriteZero,
|
||||
format!(
|
||||
"short datagram write: wrote {written} bytes out of {}",
|
||||
data.len()
|
||||
),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn receive_datagram_bytes(socket: &Socket) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
|
||||
let mut buffer = vec![MaybeUninit::<u8>::uninit(); MAX_DATAGRAM_SIZE];
|
||||
let mut control = vec![MaybeUninit::<u8>::uninit(); control_space_for_fds(MAX_FDS_PER_MESSAGE)];
|
||||
let (read, control_len) = {
|
||||
let mut bufs = [MaybeUninitSlice::new(&mut buffer)];
|
||||
let mut msg = MsgHdrMut::new()
|
||||
.with_buffers(&mut bufs)
|
||||
.with_control(&mut control);
|
||||
let read = socket.recvmsg(&mut msg, 0)?;
|
||||
(read, msg.control_len())
|
||||
};
|
||||
let data = assume_init(&buffer[..read]).to_vec();
|
||||
let fds = extract_fds(assume_init(&control[..control_len]));
|
||||
Ok((data, fds))
|
||||
}
|
||||
|
||||
pub(crate) struct AsyncSocket {
|
||||
inner: AsyncFd<Socket>,
|
||||
}
|
||||
|
||||
impl AsyncSocket {
|
||||
fn new(socket: Socket) -> std::io::Result<AsyncSocket> {
|
||||
socket.set_nonblocking(true)?;
|
||||
let async_socket = AsyncFd::new(socket)?;
|
||||
Ok(AsyncSocket {
|
||||
inner: async_socket,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_fd(fd: OwnedFd) -> std::io::Result<AsyncSocket> {
|
||||
AsyncSocket::new(Socket::from(fd))
|
||||
}
|
||||
|
||||
pub fn pair() -> std::io::Result<(AsyncSocket, AsyncSocket)> {
|
||||
let (server, client) = Socket::pair(Domain::UNIX, Type::STREAM, None)?;
|
||||
Ok((AsyncSocket::new(server)?, AsyncSocket::new(client)?))
|
||||
}
|
||||
|
||||
pub async fn send_with_fds<T: Serialize>(
|
||||
&self,
|
||||
msg: T,
|
||||
fds: &[OwnedFd],
|
||||
) -> std::io::Result<()> {
|
||||
self.inner
|
||||
.async_io(Interest::WRITABLE, |socket| {
|
||||
send_json_message(socket, &msg, fds)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn receive_with_fds<T: for<'de> Deserialize<'de>>(
|
||||
&self,
|
||||
) -> std::io::Result<(T, Vec<OwnedFd>)> {
|
||||
let (payload, fds) = read_frame(&self.inner).await?;
|
||||
let message: T = serde_json::from_slice(&payload)?;
|
||||
Ok((message, fds))
|
||||
}
|
||||
|
||||
pub async fn send<T>(&self, msg: T) -> std::io::Result<()>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
self.send_with_fds(&msg, &[]).await
|
||||
}
|
||||
|
||||
pub async fn receive<T: for<'de> Deserialize<'de>>(&self) -> std::io::Result<T> {
|
||||
let (msg, fds) = self.receive_with_fds().await?;
|
||||
if !fds.is_empty() {
|
||||
tracing::warn!("unexpected fds in receive: {}", fds.len());
|
||||
}
|
||||
Ok(msg)
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> Socket {
|
||||
self.inner.into_inner()
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct AsyncDatagramSocket {
|
||||
inner: AsyncFd<Socket>,
|
||||
}
|
||||
|
||||
impl AsyncDatagramSocket {
|
||||
fn new(socket: Socket) -> std::io::Result<Self> {
|
||||
socket.set_nonblocking(true)?;
|
||||
Ok(Self {
|
||||
inner: AsyncFd::new(socket)?,
|
||||
})
|
||||
}
|
||||
|
||||
pub unsafe fn from_raw_fd(fd: RawFd) -> std::io::Result<Self> {
|
||||
Self::new(unsafe { Socket::from_raw_fd(fd) })
|
||||
}
|
||||
|
||||
pub fn pair() -> std::io::Result<(Self, Self)> {
|
||||
let (server, client) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?;
|
||||
Ok((Self::new(server)?, Self::new(client)?))
|
||||
}
|
||||
|
||||
pub async fn send_with_fds(&self, data: &[u8], fds: &[OwnedFd]) -> std::io::Result<()> {
|
||||
self.inner
|
||||
.async_io(Interest::WRITABLE, |socket| {
|
||||
send_datagram_bytes(socket, data, fds)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn receive_with_fds(&self) -> std::io::Result<(Vec<u8>, Vec<OwnedFd>)> {
|
||||
self.inner
|
||||
.async_io(Interest::READABLE, receive_datagram_bytes)
|
||||
.await
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> Socket {
|
||||
self.inner.into_inner()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::os::fd::AsFd;
|
||||
use std::os::fd::AsRawFd;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
|
||||
struct TestPayload {
|
||||
id: i32,
|
||||
label: String,
|
||||
}
|
||||
|
||||
fn fd_list(count: usize) -> std::io::Result<Vec<OwnedFd>> {
|
||||
let file = NamedTempFile::new()?;
|
||||
let mut fds = Vec::new();
|
||||
for _ in 0..count {
|
||||
fds.push(file.as_fd().try_clone_to_owned()?);
|
||||
}
|
||||
Ok(fds)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn async_socket_round_trips_payload_and_fds() -> std::io::Result<()> {
|
||||
let (server, client) = AsyncSocket::pair()?;
|
||||
let payload = TestPayload {
|
||||
id: 7,
|
||||
label: "round-trip".to_string(),
|
||||
};
|
||||
let send_fds = fd_list(1)?;
|
||||
|
||||
let receive_task =
|
||||
tokio::spawn(async move { server.receive_with_fds::<TestPayload>().await });
|
||||
client.send_with_fds(payload.clone(), &send_fds).await?;
|
||||
drop(send_fds);
|
||||
|
||||
let (received_payload, received_fds) = receive_task.await.unwrap()?;
|
||||
assert_eq!(payload, received_payload);
|
||||
assert_eq!(1, received_fds.len());
|
||||
let fd_status = unsafe { libc::fcntl(received_fds[0].as_raw_fd(), libc::F_GETFD) };
|
||||
assert!(
|
||||
fd_status >= 0,
|
||||
"expected received file descriptor to be valid, but got {fd_status}",
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn async_datagram_sockets_round_trip_messages() -> std::io::Result<()> {
|
||||
let (server, client) = AsyncDatagramSocket::pair()?;
|
||||
let data = b"datagram payload".to_vec();
|
||||
let send_fds = fd_list(1)?;
|
||||
let receive_task = tokio::spawn(async move { server.receive_with_fds().await });
|
||||
|
||||
client.send_with_fds(&data, &send_fds).await?;
|
||||
drop(send_fds);
|
||||
|
||||
let (received_bytes, received_fds) = receive_task.await.unwrap()?;
|
||||
assert_eq!(data, received_bytes);
|
||||
assert_eq!(1, received_fds.len());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn send_message_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> {
|
||||
let (socket, _peer) = Socket::pair(Domain::UNIX, Type::STREAM, None)?;
|
||||
let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?;
|
||||
let err = send_message_bytes(&socket, b"hello", &fds).unwrap_err();
|
||||
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn send_datagram_bytes_rejects_excessive_fd_counts() -> std::io::Result<()> {
|
||||
let (socket, _peer) = Socket::pair(Domain::UNIX, Type::DGRAM, None)?;
|
||||
let fds = fd_list(MAX_FDS_PER_MESSAGE + 1)?;
|
||||
let err = send_datagram_bytes(&socket, b"hi", &fds).unwrap_err();
|
||||
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_length_errors_for_oversized_messages() {
|
||||
let err = encode_length(usize::MAX).unwrap_err();
|
||||
assert_eq!(std::io::ErrorKind::InvalidInput, err.kind());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn receive_fails_when_peer_closes_before_header() {
|
||||
let (server, client) = AsyncSocket::pair().expect("failed to create socket pair");
|
||||
drop(client);
|
||||
let err = server
|
||||
.receive::<serde_json::Value>()
|
||||
.await
|
||||
.expect_err("expected read failure");
|
||||
assert_eq!(std::io::ErrorKind::UnexpectedEof, err.kind());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user