Add codex-network-proxy crate

This commit is contained in:
viyatb-oai
2025-12-21 12:15:59 -08:00
parent 25ecd0c2e4
commit f65edf9c91
17 changed files with 3471 additions and 749 deletions

1830
codex-rs/Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -26,6 +26,7 @@ members = [
"login",
"mcp-server",
"mcp-types",
"network-proxy",
"ollama",
"process-hardening",
"protocol",
@@ -83,6 +84,7 @@ codex-lmstudio = { path = "lmstudio" }
codex-login = { path = "login" }
codex-mcp-server = { path = "mcp-server" }
codex-ollama = { path = "ollama" }
codex-network-proxy = { path = "network-proxy" }
codex-otel = { path = "otel" }
codex-process-hardening = { path = "process-hardening" }
codex-protocol = { path = "protocol" }
@@ -135,6 +137,7 @@ env_logger = "0.11.5"
escargot = "0.5"
eventsource-stream = "0.2.3"
futures = { version = "0.3", default-features = false }
globset = "0.4"
http = "1.3.1"
icu_decimal = "2.1"
icu_locale_core = "2.1"

View File

@@ -1346,6 +1346,12 @@ pub fn find_codex_home() -> std::io::Result<PathBuf> {
Ok(p)
}
/// Returns the default path to the Codex config file (`config.toml`).
pub fn default_config_path() -> std::io::Result<PathBuf> {
let codex_home = find_codex_home()?;
Ok(codex_home.join(CONFIG_TOML_FILE))
}
/// Returns the path to the folder where Codex logs are stored. Does not verify
/// that the directory exists.
pub fn log_dir(cfg: &Config) -> std::io::Result<PathBuf> {

View File

@@ -0,0 +1,43 @@
[package]
name = "codex-network-proxy"
edition = "2024"
version = { workspace = true }
[[bin]]
name = "codex-network-proxy"
path = "src/main.rs"
[lib]
name = "codex_network_proxy"
path = "src/lib.rs"
[lints]
workspace = true
[features]
default = ["mitm"]
mitm = [
"tokio-rustls",
"rustls",
"rustls-native-certs",
"rustls-pemfile",
"rcgen",
]
[dependencies]
anyhow = { workspace = true }
clap = { workspace = true, features = ["derive"] }
codex-core = { workspace = true }
globset = { workspace = true }
hyper = { version = "0.14", features = ["full"] }
rcgen = { version = "0.13", features = ["pem", "x509-parser"], optional = true }
rustls = { version = "0.21", optional = true }
rustls-native-certs = { version = "0.6", optional = true }
rustls-pemfile = { version = "1", optional = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tokio = { workspace = true, features = ["full"] }
tokio-rustls = { version = "0.24", optional = true }
toml = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["fmt"] }

View File

@@ -0,0 +1,108 @@
use crate::config::NetworkMode;
use crate::responses::json_response;
use crate::responses::text_response;
use crate::state::AppState;
use anyhow::Result;
use hyper::Body;
use hyper::Method;
use hyper::Request;
use hyper::Response;
use hyper::Server;
use hyper::StatusCode;
use hyper::body::to_bytes;
use hyper::service::make_service_fn;
use hyper::service::service_fn;
use serde::Deserialize;
use serde_json::json;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use tracing::error;
use tracing::info;
pub async fn run_admin_api(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
let make_svc = make_service_fn(move |_conn: &hyper::server::conn::AddrStream| {
let state = state.clone();
async move {
Ok::<_, Infallible>(service_fn(move |req| {
handle_admin_request(req, state.clone())
}))
}
});
let server = Server::bind(&addr).serve(make_svc);
info!(addr = %addr, "admin API listening");
server.await?;
Ok(())
}
async fn handle_admin_request(
req: Request<Body>,
state: Arc<AppState>,
) -> Result<Response<Body>, Infallible> {
let method = req.method().clone();
let path = req.uri().path().to_string();
let response = match (method, path.as_str()) {
(Method::GET, "/health") => Response::new(Body::from("ok")),
(Method::GET, "/config") => match state.current_cfg().await {
Ok(cfg) => json_response(&cfg),
Err(err) => {
error!(error = %err, "failed to load config");
text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
}
},
(Method::GET, "/patterns") => match state.current_patterns().await {
Ok((allow, deny)) => json_response(&json!({"allowed": allow, "denied": deny})),
Err(err) => {
error!(error = %err, "failed to load patterns");
text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
}
},
(Method::GET, "/blocked") => match state.drain_blocked().await {
Ok(blocked) => json_response(&json!({ "blocked": blocked })),
Err(err) => {
error!(error = %err, "failed to read blocked queue");
text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
}
},
(Method::POST, "/mode") => {
let body = match to_bytes(req.into_body()).await {
Ok(bytes) => bytes,
Err(err) => {
error!(error = %err, "failed to read mode body");
return Ok(text_response(StatusCode::BAD_REQUEST, "invalid body"));
}
};
if body.is_empty() {
return Ok(text_response(StatusCode::BAD_REQUEST, "missing body"));
}
let update: ModeUpdate = match serde_json::from_slice(&body) {
Ok(update) => update,
Err(err) => {
error!(error = %err, "failed to parse mode update");
return Ok(text_response(StatusCode::BAD_REQUEST, "invalid json"));
}
};
match state.set_network_mode(update.mode).await {
Ok(()) => json_response(&json!({"status": "ok", "mode": update.mode})),
Err(err) => {
error!(error = %err, "mode update failed");
text_response(StatusCode::INTERNAL_SERVER_ERROR, "mode update failed")
}
}
}
(Method::POST, "/reload") => match state.force_reload().await {
Ok(()) => json_response(&json!({"status": "reloaded"})),
Err(err) => {
error!(error = %err, "reload failed");
text_response(StatusCode::INTERNAL_SERVER_ERROR, "reload failed")
}
},
_ => text_response(StatusCode::NOT_FOUND, "not found"),
};
Ok(response)
}
#[derive(Deserialize)]
struct ModeUpdate {
mode: NetworkMode,
}

View File

@@ -0,0 +1,203 @@
use anyhow::Context;
use anyhow::Result;
use codex_core::config::default_config_path;
use serde::Deserialize;
use serde::Serialize;
use std::net::IpAddr;
use std::net::SocketAddr;
use std::path::PathBuf;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
#[serde(default)]
pub network_proxy: NetworkProxyConfig,
}
impl Default for Config {
fn default() -> Self {
Self {
network_proxy: NetworkProxyConfig::default(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkProxyConfig {
#[serde(default)]
pub enabled: bool,
#[serde(default = "default_proxy_url")]
pub proxy_url: String,
#[serde(default = "default_admin_url")]
pub admin_url: String,
#[serde(default)]
pub mode: NetworkMode,
#[serde(default)]
pub policy: NetworkPolicy,
#[serde(default)]
pub mitm: MitmConfig,
}
impl Default for NetworkProxyConfig {
fn default() -> Self {
Self {
enabled: false,
proxy_url: default_proxy_url(),
admin_url: default_admin_url(),
mode: NetworkMode::default(),
policy: NetworkPolicy::default(),
mitm: MitmConfig::default(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkPolicy {
#[serde(default, rename = "allowed_domains", alias = "allowedDomains")]
pub allowed_domains: Vec<String>,
#[serde(default, rename = "denied_domains", alias = "deniedDomains")]
pub denied_domains: Vec<String>,
#[serde(default, rename = "allow_unix_sockets", alias = "allowUnixSockets")]
pub allow_unix_sockets: Vec<String>,
#[serde(default, rename = "allow_local_binding", alias = "allowLocalBinding")]
pub allow_local_binding: bool,
}
impl Default for NetworkPolicy {
fn default() -> Self {
Self {
allowed_domains: Vec::new(),
denied_domains: Vec::new(),
allow_unix_sockets: Vec::new(),
allow_local_binding: false,
}
}
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum NetworkMode {
Limited,
Full,
}
impl Default for NetworkMode {
fn default() -> Self {
NetworkMode::Full
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MitmConfig {
#[serde(default)]
pub enabled: bool,
#[serde(default)]
pub inspect: bool,
#[serde(default = "default_mitm_max_body_bytes")]
pub max_body_bytes: usize,
#[serde(default = "default_ca_cert_path")]
pub ca_cert_path: PathBuf,
#[serde(default = "default_ca_key_path")]
pub ca_key_path: PathBuf,
}
impl Default for MitmConfig {
fn default() -> Self {
Self {
enabled: false,
inspect: false,
max_body_bytes: default_mitm_max_body_bytes(),
ca_cert_path: default_ca_cert_path(),
ca_key_path: default_ca_key_path(),
}
}
}
fn default_proxy_url() -> String {
"http://127.0.0.1:3128".to_string()
}
fn default_admin_url() -> String {
"http://127.0.0.1:8080".to_string()
}
fn default_ca_cert_path() -> PathBuf {
PathBuf::from("network_proxy/mitm/ca.pem")
}
fn default_ca_key_path() -> PathBuf {
PathBuf::from("network_proxy/mitm/ca.key")
}
fn default_mitm_max_body_bytes() -> usize {
4096
}
pub struct RuntimeConfig {
pub http_addr: SocketAddr,
pub socks_addr: SocketAddr,
pub admin_addr: SocketAddr,
}
pub fn default_codex_config_path() -> Result<PathBuf> {
default_config_path().context("failed to resolve Codex config path")
}
pub fn resolve_runtime(cfg: &Config) -> RuntimeConfig {
let http_addr = resolve_addr(&cfg.network_proxy.proxy_url, 3128);
let admin_addr = resolve_addr(&cfg.network_proxy.admin_url, 8080);
let socks_addr = SocketAddr::from(([127, 0, 0, 1], 8081));
RuntimeConfig {
http_addr,
socks_addr,
admin_addr,
}
}
fn resolve_addr(url: &str, default_port: u16) -> SocketAddr {
let (host, port) = parse_host_port(url, default_port);
let host = if host.eq_ignore_ascii_case("localhost") {
"127.0.0.1"
} else {
host
};
match host.parse::<IpAddr>() {
Ok(ip) => SocketAddr::new(ip, port),
Err(_) => SocketAddr::from(([127, 0, 0, 1], port)),
}
}
fn parse_host_port(url: &str, default_port: u16) -> (&str, u16) {
let trimmed = url.trim();
if trimmed.is_empty() {
return ("127.0.0.1", default_port);
}
let without_scheme = trimmed
.split_once("://")
.map(|(_, rest)| rest)
.unwrap_or(trimmed);
let host_port = without_scheme.split('/').next().unwrap_or(without_scheme);
let host_port = host_port
.rsplit_once('@')
.map(|(_, rest)| rest)
.unwrap_or(host_port);
if host_port.starts_with('[') {
if let Some(end) = host_port.find(']') {
let host = &host_port[1..end];
let port = host_port[end + 1..]
.strip_prefix(':')
.and_then(|port| port.parse::<u16>().ok())
.unwrap_or(default_port);
return (host, port);
}
}
if let Some((host, port)) = host_port.rsplit_once(':') {
if let Ok(port) = port.parse::<u16>() {
return (host, port);
}
}
(host_port, default_port)
}

View File

@@ -0,0 +1,446 @@
use crate::config::NetworkMode;
use crate::mitm;
use crate::policy::normalize_host;
use crate::responses::blocked_text;
use crate::responses::json_blocked;
use crate::responses::text_response;
use crate::state::AppState;
use crate::state::BlockedRequest;
use anyhow::Result;
use hyper::Body;
use hyper::Method;
use hyper::Request;
use hyper::Response;
use hyper::Server;
use hyper::StatusCode;
use hyper::Uri;
use hyper::body::to_bytes;
use hyper::header::HOST;
use hyper::header::HeaderName;
use hyper::service::make_service_fn;
use hyper::service::service_fn;
use std::collections::HashSet;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::copy_bidirectional;
use tokio::net::TcpStream;
use tracing::error;
use tracing::info;
use tracing::warn;
pub async fn run_http_proxy(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
let make_svc = make_service_fn(move |conn: &hyper::server::conn::AddrStream| {
let state = state.clone();
let client_addr = conn.remote_addr();
async move {
Ok::<_, Infallible>(service_fn(move |req| {
handle_proxy_request(req, state.clone(), client_addr)
}))
}
});
let server = Server::bind(&addr).serve(make_svc);
info!(addr = %addr, "HTTP proxy listening");
server.await?;
Ok(())
}
async fn handle_proxy_request(
req: Request<Body>,
state: Arc<AppState>,
client_addr: SocketAddr,
) -> Result<Response<Body>, Infallible> {
let response = if req.method() == Method::CONNECT {
handle_connect(req, state, client_addr).await
} else {
handle_http_forward(req, state, client_addr).await
};
Ok(response)
}
async fn handle_connect(
req: Request<Body>,
state: Arc<AppState>,
client_addr: SocketAddr,
) -> Response<Body> {
let authority = match req.uri().authority() {
Some(auth) => auth.as_str().to_string(),
None => return text_response(StatusCode::BAD_REQUEST, "missing authority"),
};
let (authority_host, target_port) = split_authority(&authority);
let host = normalize_host(&authority_host);
if host.is_empty() {
return text_response(StatusCode::BAD_REQUEST, "invalid host");
}
match state.host_blocked(&host).await {
Ok((true, reason)) => {
let _ = state
.record_blocked(BlockedRequest::new(
host.clone(),
reason.clone(),
Some(client_addr.to_string()),
Some("CONNECT".to_string()),
None,
"http-connect".to_string(),
))
.await;
warn!(client = %client_addr, host = %host, reason = %reason, "CONNECT blocked");
return blocked_text(&reason);
}
Ok((false, _)) => {
info!(client = %client_addr, host = %host, "CONNECT allowed");
}
Err(err) => {
error!(error = %err, "failed to evaluate host");
return text_response(StatusCode::INTERNAL_SERVER_ERROR, "error");
}
}
let mode = match state.network_mode().await {
Ok(mode) => mode,
Err(err) => {
error!(error = %err, "failed to read network mode");
return text_response(StatusCode::INTERNAL_SERVER_ERROR, "error");
}
};
let mitm_state = match state.mitm_state().await {
Ok(state) => state,
Err(err) => {
error!(error = %err, "failed to load MITM state");
return text_response(StatusCode::INTERNAL_SERVER_ERROR, "error");
}
};
if mode == NetworkMode::Limited && mitm_state.is_none() {
let _ = state
.record_blocked(BlockedRequest::new(
host.clone(),
"mitm_required".to_string(),
Some(client_addr.to_string()),
Some("CONNECT".to_string()),
Some(NetworkMode::Limited),
"http-connect".to_string(),
))
.await;
warn!(
client = %client_addr,
host = %host,
mode = "limited",
allowed_methods = "GET, HEAD, OPTIONS",
"CONNECT blocked; MITM required for read-only HTTPS in limited mode"
);
return blocked_text("mitm_required");
}
let on_upgrade = hyper::upgrade::on(req);
tokio::spawn(async move {
match on_upgrade.await {
Ok(upgraded) => {
if let Some(mitm_state) = mitm_state {
info!(client = %client_addr, host = %host, mode = ?mode, "CONNECT MITM enabled");
if let Err(err) =
mitm::mitm_tunnel(upgraded, &host, target_port, mode, mitm_state).await
{
warn!(error = %err, "MITM tunnel error");
}
return;
}
let mut upgraded = upgraded;
match TcpStream::connect(&authority).await {
Ok(mut server_stream) => {
if let Err(err) =
copy_bidirectional(&mut upgraded, &mut server_stream).await
{
warn!(error = %err, "tunnel error");
}
}
Err(err) => {
warn!(error = %err, "failed to connect to upstream");
}
}
}
Err(err) => warn!(error = %err, "upgrade failed"),
}
});
Response::builder()
.status(StatusCode::OK)
.body(Body::empty())
.unwrap_or_else(|_| Response::new(Body::empty()))
}
async fn handle_http_forward(
req: Request<Body>,
state: Arc<AppState>,
client_addr: SocketAddr,
) -> Response<Body> {
let (parts, body) = req.into_parts();
let method_allowed = match state.method_allowed(&parts.method).await {
Ok(allowed) => allowed,
Err(err) => {
error!(error = %err, "failed to evaluate method policy");
return text_response(StatusCode::INTERNAL_SERVER_ERROR, "error");
}
};
let unix_socket = parts
.headers
.get("x-unix-socket")
.and_then(|v| v.to_str().ok())
.map(|v| v.to_string());
if let Some(socket_path) = unix_socket {
if !method_allowed {
warn!(
client = %client_addr,
method = %parts.method,
mode = "limited",
allowed_methods = "GET, HEAD, OPTIONS",
"unix socket blocked by method policy"
);
return json_blocked("unix-socket", "method_not_allowed");
}
if !cfg!(target_os = "macos") {
warn!(path = %socket_path, "unix socket proxy unsupported on this platform");
return text_response(StatusCode::NOT_IMPLEMENTED, "unix sockets unsupported");
}
match state.is_unix_socket_allowed(&socket_path).await {
Ok(true) => {
info!(client = %client_addr, path = %socket_path, "unix socket allowed");
match proxy_via_unix_socket(Request::from_parts(parts, body), &socket_path).await {
Ok(resp) => return resp,
Err(err) => {
warn!(error = %err, "unix socket proxy failed");
return text_response(StatusCode::BAD_GATEWAY, "unix socket proxy failed");
}
}
}
Ok(false) => {
warn!(client = %client_addr, path = %socket_path, "unix socket blocked");
return json_blocked("unix-socket", "not_allowed");
}
Err(err) => {
warn!(error = %err, "unix socket check failed");
return text_response(StatusCode::INTERNAL_SERVER_ERROR, "error");
}
}
}
let host_header = parts
.headers
.get(HOST)
.and_then(|v| v.to_str().ok())
.map(|v| v.to_string())
.or_else(|| parts.uri.authority().map(|a| a.as_str().to_string()));
let authority = match host_header {
Some(h) => h,
None => return text_response(StatusCode::BAD_REQUEST, "missing host"),
};
let authority = authority.trim().to_string();
let host = normalize_host(&authority);
if host.is_empty() {
return text_response(StatusCode::BAD_REQUEST, "invalid host");
}
match state.host_blocked(&host).await {
Ok((true, reason)) => {
let _ = state
.record_blocked(BlockedRequest::new(
host.clone(),
reason.clone(),
Some(client_addr.to_string()),
Some(parts.method.to_string()),
None,
"http".to_string(),
))
.await;
warn!(client = %client_addr, host = %host, reason = %reason, "request blocked");
return json_blocked(&host, &reason);
}
Ok((false, _)) => {}
Err(err) => {
error!(error = %err, "failed to evaluate host");
return text_response(StatusCode::INTERNAL_SERVER_ERROR, "error");
}
}
if !method_allowed {
let _ = state
.record_blocked(BlockedRequest::new(
host.clone(),
"method_not_allowed".to_string(),
Some(client_addr.to_string()),
Some(parts.method.to_string()),
Some(NetworkMode::Limited),
"http".to_string(),
))
.await;
warn!(
client = %client_addr,
host = %host,
method = %parts.method,
mode = "limited",
allowed_methods = "GET, HEAD, OPTIONS",
"request blocked by method policy"
);
return json_blocked(&host, "method_not_allowed");
}
info!(
client = %client_addr,
host = %host,
method = %parts.method,
"request allowed"
);
let uri = match build_forward_uri(&authority, &parts.uri) {
Ok(uri) => uri,
Err(err) => {
warn!(error = %err, "failed to build upstream uri");
return text_response(StatusCode::BAD_REQUEST, "invalid uri");
}
};
let body_bytes = match to_bytes(body).await {
Ok(bytes) => bytes,
Err(err) => {
warn!(error = %err, "failed to read body");
return text_response(StatusCode::BAD_GATEWAY, "failed to read body");
}
};
let mut builder = Request::builder()
.method(parts.method)
.uri(uri)
.version(parts.version);
let hop_headers = hop_by_hop_headers();
for (name, value) in parts.headers.iter() {
let name_str = name.as_str().to_ascii_lowercase();
if hop_headers.contains(name_str.as_str())
|| name == &HeaderName::from_static("x-unix-socket")
{
continue;
}
builder = builder.header(name, value);
}
let forwarded_req = match builder.body(Body::from(body_bytes)) {
Ok(req) => req,
Err(err) => {
warn!(error = %err, "failed to build request");
return text_response(StatusCode::BAD_GATEWAY, "invalid request");
}
};
match state.client.request(forwarded_req).await {
Ok(resp) => filter_response(resp),
Err(err) => {
warn!(error = %err, "upstream request failed");
text_response(StatusCode::BAD_GATEWAY, "upstream failure")
}
}
}
fn build_forward_uri(authority: &str, uri: &Uri) -> Result<Uri> {
let path = path_and_query(uri);
let target = format!("http://{authority}{path}");
Ok(target.parse()?)
}
fn filter_response(resp: Response<Body>) -> Response<Body> {
let mut builder = Response::builder().status(resp.status());
let hop_headers = hop_by_hop_headers();
for (name, value) in resp.headers().iter() {
if hop_headers.contains(name.as_str().to_ascii_lowercase().as_str()) {
continue;
}
builder = builder.header(name, value);
}
builder
.body(resp.into_body())
.unwrap_or_else(|_| Response::new(Body::from("proxy error")))
}
fn path_and_query(uri: &Uri) -> String {
uri.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("/")
.to_string()
}
fn hop_by_hop_headers() -> HashSet<&'static str> {
[
"connection",
"proxy-connection",
"keep-alive",
"proxy-authenticate",
"proxy-authorization",
"te",
"trailer",
"transfer-encoding",
"upgrade",
]
.into_iter()
.collect()
}
fn split_authority(authority: &str) -> (String, u16) {
if let Some(host) = authority.strip_prefix('[') {
if let Some(end) = host.find(']') {
let hostname = host[..end].to_string();
let port = host[end + 1..]
.strip_prefix(':')
.and_then(|p| p.parse::<u16>().ok())
.unwrap_or(443);
return (hostname, port);
}
}
let mut parts = authority.splitn(2, ':');
let host = parts.next().unwrap_or("").to_string();
let port = parts
.next()
.and_then(|p| p.parse::<u16>().ok())
.unwrap_or(443);
(host, port)
}
async fn proxy_via_unix_socket(req: Request<Body>, socket_path: &str) -> Result<Response<Body>> {
#[cfg(target_os = "macos")]
{
use hyper::client::conn::Builder as ConnBuilder;
use tokio::net::UnixStream;
let path = path_and_query(req.uri());
let (parts, body) = req.into_parts();
let body_bytes = to_bytes(body).await?;
let mut builder = Request::builder()
.method(parts.method)
.uri(path)
.version(parts.version);
let hop_headers = hop_by_hop_headers();
for (name, value) in parts.headers.iter() {
let name_str = name.as_str().to_ascii_lowercase();
if hop_headers.contains(name_str.as_str())
|| name == &HeaderName::from_static("x-unix-socket")
{
continue;
}
builder = builder.header(name, value);
}
let req = builder.body(Body::from(body_bytes))?;
let stream = UnixStream::connect(socket_path).await?;
let (mut sender, conn) = ConnBuilder::new().handshake(stream).await?;
tokio::spawn(async move {
if let Err(err) = conn.await {
warn!(error = %err, "unix socket connection error");
}
});
Ok(sender.send_request(req).await?)
}
#[cfg(not(target_os = "macos"))]
{
let _ = req;
let _ = socket_path;
Err(anyhow::anyhow!("unix sockets not supported"))
}
}

View File

@@ -0,0 +1,17 @@
use anyhow::Context;
use anyhow::Result;
use codex_core::config::find_codex_home;
use std::fs;
pub fn run_init() -> Result<()> {
let codex_home = find_codex_home().context("failed to resolve CODEX_HOME")?;
let root = codex_home.join("network_proxy");
let mitm_dir = root.join("mitm");
fs::create_dir_all(&root).with_context(|| format!("failed to create {}", root.display()))?;
fs::create_dir_all(&mitm_dir)
.with_context(|| format!("failed to create {}", mitm_dir.display()))?;
println!("ensured {}", mitm_dir.display());
Ok(())
}

View File

@@ -0,0 +1,58 @@
mod admin;
mod config;
mod http_proxy;
mod init;
mod mitm;
mod policy;
mod responses;
mod socks5;
mod state;
use crate::state::AppState;
use anyhow::Result;
use clap::Parser;
use clap::Subcommand;
use std::net::SocketAddr;
use std::sync::Arc;
use tracing::warn;
#[derive(Debug, Clone, Parser)]
#[command(name = "codex-network-proxy", about = "Codex network sandbox proxy")]
pub struct Args {
#[command(subcommand)]
pub command: Option<Command>,
}
#[derive(Debug, Clone, Subcommand)]
pub enum Command {
/// Initialize the Codex network proxy directories (e.g. MITM cert paths).
Init,
}
pub async fn run_main(args: Args) -> Result<()> {
tracing_subscriber::fmt::init();
if let Some(Command::Init) = args.command {
init::run_init()?;
return Ok(());
}
if cfg!(not(target_os = "macos")) {
warn!("allowUnixSockets is macOS-only; requests will be rejected on this platform");
}
let cfg_path = config::default_codex_config_path()?;
let state = Arc::new(AppState::new(cfg_path).await?);
let runtime = config::resolve_runtime(&state.current_cfg().await?);
let http_addr: SocketAddr = runtime.http_addr;
let socks_addr: SocketAddr = runtime.socks_addr;
let admin_addr: SocketAddr = runtime.admin_addr;
let http_task = http_proxy::run_http_proxy(state.clone(), http_addr);
let socks_task = socks5::run_socks5(state.clone(), socks_addr);
let admin_task = admin::run_admin_api(state.clone(), admin_addr);
tokio::try_join!(http_task, socks_task, admin_task)?;
Ok(())
}

View File

@@ -0,0 +1,8 @@
use anyhow::Result;
use clap::Parser;
use codex_network_proxy::Args;
#[tokio::main]
async fn main() -> Result<()> {
codex_network_proxy::run_main(Args::parse()).await
}

View File

@@ -0,0 +1,665 @@
#[cfg(feature = "mitm")]
mod imp {
use crate::config::MitmConfig;
use crate::config::NetworkMode;
use crate::policy::method_allowed;
use crate::policy::normalize_host;
use crate::responses::text_response;
use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use hyper::Body;
use hyper::Method;
use hyper::Request;
use hyper::Response;
use hyper::StatusCode;
use hyper::Uri;
use hyper::Version;
use hyper::body::HttpBody;
use hyper::header::HOST;
use hyper::server::conn::Http;
use hyper::service::service_fn;
use rcgen::BasicConstraints;
use rcgen::Certificate;
use rcgen::CertificateParams;
use rcgen::DistinguishedName;
use rcgen::DnType;
use rcgen::ExtendedKeyUsagePurpose;
use rcgen::IsCa;
use rcgen::KeyPair;
use rcgen::KeyUsagePurpose;
use rcgen::SanType;
use rustls::Certificate as RustlsCertificate;
use rustls::ClientConfig;
use rustls::PrivateKey;
use rustls::RootCertStore;
use rustls::ServerConfig;
use std::collections::HashSet;
use std::convert::Infallible;
use std::fs;
use std::io::Cursor;
use std::net::IpAddr;
use std::path::Path;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio_rustls::TlsAcceptor;
use tokio_rustls::TlsConnector;
use tracing::info;
use tracing::warn;
#[derive(Clone, Copy, Debug)]
enum MitmProtocol {
Http1,
Http2,
}
struct MitmTarget {
host: String,
port: u16,
}
impl MitmTarget {
fn authority(&self) -> String {
if self.port == 443 {
self.host.clone()
} else {
format!("{}:{}", self.host, self.port)
}
}
}
struct RequestLogContext {
host: String,
method: Method,
path: String,
}
struct ResponseLogContext {
host: String,
method: Method,
path: String,
status: StatusCode,
}
pub struct MitmState {
ca_key: KeyPair,
ca_cert: Certificate,
client_config: Arc<ClientConfig>,
inspect: bool,
max_body_bytes: usize,
}
impl MitmState {
pub fn new(cfg: &MitmConfig) -> Result<Self> {
let (ca_cert_pem, ca_key_pem) = load_or_create_ca(cfg)?;
let ca_key = KeyPair::from_pem(&ca_key_pem).context("failed to parse CA key")?;
let ca_params = CertificateParams::from_ca_cert_pem(&ca_cert_pem)
.context("failed to parse CA cert")?;
let ca_cert = ca_params
.self_signed(&ca_key)
.context("failed to reconstruct CA cert")?;
let client_config = build_client_config()?;
Ok(Self {
ca_key,
ca_cert,
client_config,
inspect: cfg.inspect,
max_body_bytes: cfg.max_body_bytes,
})
}
pub fn server_config_for_host(&self, host: &str) -> Result<Arc<ServerConfig>> {
let (certs, key) = issue_host_certificate(host, &self.ca_cert, &self.ca_key)?;
let mut config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, key)
.context("failed to build server TLS config")?;
config.alpn_protocols = vec![b"http/1.1".to_vec()];
Ok(Arc::new(config))
}
pub fn client_config(&self) -> Arc<ClientConfig> {
Arc::clone(&self.client_config)
}
pub fn inspect_enabled(&self) -> bool {
self.inspect
}
pub fn max_body_bytes(&self) -> usize {
self.max_body_bytes
}
}
pub async fn mitm_tunnel(
stream: hyper::upgrade::Upgraded,
host: &str,
port: u16,
mode: NetworkMode,
state: Arc<MitmState>,
) -> Result<()> {
let server_config = state.server_config_for_host(host)?;
let acceptor = TlsAcceptor::from(server_config);
let tls_stream = acceptor
.accept(stream)
.await
.context("client TLS handshake failed")?;
let protocol = match tls_stream.get_ref().1.alpn_protocol() {
Some(proto) if proto == b"h2" => MitmProtocol::Http2,
_ => MitmProtocol::Http1,
};
info!(
host = %host,
port = port,
protocol = ?protocol,
mode = ?mode,
inspect = state.inspect_enabled(),
max_body_bytes = state.max_body_bytes(),
"MITM TLS established"
);
let target = Arc::new(MitmTarget {
host: host.to_string(),
port,
});
let service = {
let state = state.clone();
let target = target.clone();
service_fn(move |req| handle_mitm_request(req, target.clone(), mode, state.clone()))
};
let mut http = Http::new();
match protocol {
MitmProtocol::Http2 => {
http.http2_only(true);
}
MitmProtocol::Http1 => {
http.http1_only(true);
}
}
http.serve_connection(tls_stream, service)
.await
.context("MITM HTTP handling failed")?;
Ok(())
}
async fn handle_mitm_request(
req: Request<Body>,
target: Arc<MitmTarget>,
mode: NetworkMode,
state: Arc<MitmState>,
) -> Result<Response<Body>, Infallible> {
let response = match forward_request(req, target.as_ref(), mode, state.as_ref()).await {
Ok(resp) => resp,
Err(err) => {
warn!(error = %err, host = %target.host, "MITM upstream request failed");
text_response(StatusCode::BAD_GATEWAY, "mitm upstream error")
}
};
Ok(response)
}
async fn forward_request(
req: Request<Body>,
target: &MitmTarget,
mode: NetworkMode,
state: &MitmState,
) -> Result<Response<Body>> {
if req.method() == Method::CONNECT {
return Ok(text_response(
StatusCode::METHOD_NOT_ALLOWED,
"CONNECT not supported inside MITM",
));
}
let (parts, body) = req.into_parts();
let request_version = parts.version;
let method = parts.method.clone();
let inspect = state.inspect_enabled();
let max_body_bytes = state.max_body_bytes();
if let Some(request_host) = extract_request_host(&parts) {
let normalized = normalize_host(&request_host);
if !normalized.is_empty() && normalized != target.host {
warn!(
target = %target.host,
request_host = %normalized,
"MITM host mismatch"
);
return Ok(text_response(StatusCode::BAD_REQUEST, "host mismatch"));
}
}
let path = path_and_query(&parts.uri);
let uri = build_origin_form_uri(&path)?;
let authority = target.authority();
if !method_allowed(mode, &method) {
warn!(
host = %authority,
method = %method,
path = %path,
mode = ?mode,
allowed_methods = "GET, HEAD, OPTIONS",
"MITM blocked by method policy"
);
return Ok(text_response(StatusCode::FORBIDDEN, "method not allowed"));
}
let mut builder = Request::builder()
.method(method.clone())
.uri(uri)
.version(Version::HTTP_11);
let hop_headers = hop_by_hop_headers();
for (name, value) in parts.headers.iter() {
let name_str = name.as_str().to_ascii_lowercase();
if hop_headers.contains(name_str.as_str()) || name == &HOST {
continue;
}
builder = builder.header(name, value);
}
builder = builder.header(HOST, authority.as_str());
let body = if inspect {
let (tx, out_body) = Body::channel();
let ctx = RequestLogContext {
host: authority.clone(),
method: method.clone(),
path: path.clone(),
};
tokio::spawn(async move {
stream_body(body, tx, max_body_bytes, ctx).await;
});
out_body
} else {
body
};
let upstream_req = builder
.body(body)
.context("failed to build upstream request")?;
let upstream_resp = send_upstream_request(upstream_req, target, state).await?;
respond_with_inspection(
upstream_resp,
request_version,
inspect,
max_body_bytes,
&method,
&path,
&authority,
)
.await
}
async fn send_upstream_request(
req: Request<Body>,
target: &MitmTarget,
state: &MitmState,
) -> Result<Response<Body>> {
let upstream = TcpStream::connect((target.host.as_str(), target.port))
.await
.context("failed to connect to upstream")?;
let server_name = match target.host.parse::<IpAddr>() {
Ok(ip) => rustls::ServerName::IpAddress(ip),
Err(_) => rustls::ServerName::try_from(target.host.as_str())
.map_err(|_| anyhow!("invalid server name"))?,
};
let connector = TlsConnector::from(state.client_config());
let tls_stream = connector
.connect(server_name, upstream)
.await
.context("upstream TLS handshake failed")?;
let (mut sender, conn) = hyper::client::conn::Builder::new()
.handshake(tls_stream)
.await
.context("upstream HTTP handshake failed")?;
tokio::spawn(async move {
if let Err(err) = conn.await {
warn!(error = %err, "MITM upstream connection error");
}
});
let resp = sender
.send_request(req)
.await
.context("upstream request failed")?;
Ok(resp)
}
async fn respond_with_inspection(
resp: Response<Body>,
request_version: Version,
inspect: bool,
max_body_bytes: usize,
method: &Method,
path: &str,
authority: &str,
) -> Result<Response<Body>> {
let (parts, body) = resp.into_parts();
let mut builder = Response::builder()
.status(parts.status)
.version(request_version);
let hop_headers = hop_by_hop_headers();
for (name, value) in parts.headers.iter() {
if hop_headers.contains(name.as_str().to_ascii_lowercase().as_str()) {
continue;
}
builder = builder.header(name, value);
}
let body = if inspect {
let (tx, out_body) = Body::channel();
let ctx = ResponseLogContext {
host: authority.to_string(),
method: method.clone(),
path: path.to_string(),
status: parts.status,
};
tokio::spawn(async move {
stream_body(body, tx, max_body_bytes, ctx).await;
});
out_body
} else {
body
};
Ok(builder
.body(body)
.unwrap_or_else(|_| Response::new(Body::from("proxy error"))))
}
async fn stream_body<T>(
mut body: Body,
mut tx: hyper::body::Sender,
max_body_bytes: usize,
ctx: T,
) where
T: BodyLoggable,
{
let mut len: usize = 0;
let mut truncated = false;
while let Some(chunk) = body.data().await {
match chunk {
Ok(bytes) => {
len = len.saturating_add(bytes.len());
if len > max_body_bytes {
truncated = true;
}
if tx.send_data(bytes).await.is_err() {
break;
}
}
Err(err) => {
warn!(error = %err, "MITM body stream error");
break;
}
}
}
if let Ok(Some(trailers)) = body.trailers().await {
let _ = tx.send_trailers(trailers).await;
}
ctx.log(len, truncated);
}
trait BodyLoggable {
fn log(self, len: usize, truncated: bool);
}
impl BodyLoggable for RequestLogContext {
fn log(self, len: usize, truncated: bool) {
info!(
host = %self.host,
method = %self.method,
path = %self.path,
body_len = len,
truncated = truncated,
"MITM inspected request body"
);
}
}
impl BodyLoggable for ResponseLogContext {
fn log(self, len: usize, truncated: bool) {
info!(
host = %self.host,
method = %self.method,
path = %self.path,
status = %self.status,
body_len = len,
truncated = truncated,
"MITM inspected response body"
);
}
}
fn extract_request_host(parts: &hyper::http::request::Parts) -> Option<String> {
parts
.headers
.get(HOST)
.and_then(|v| v.to_str().ok())
.map(|v| v.to_string())
.or_else(|| parts.uri.authority().map(|a| a.as_str().to_string()))
}
fn path_and_query(uri: &Uri) -> String {
uri.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("/")
.to_string()
}
fn build_origin_form_uri(path: &str) -> Result<Uri> {
path.parse().context("invalid request path")
}
fn hop_by_hop_headers() -> HashSet<&'static str> {
[
"connection",
"proxy-connection",
"keep-alive",
"proxy-authenticate",
"proxy-authorization",
"te",
"trailer",
"transfer-encoding",
"upgrade",
]
.into_iter()
.collect()
}
fn build_client_config() -> Result<Arc<ClientConfig>> {
let mut roots = RootCertStore::empty();
let certs = rustls_native_certs::load_native_certs()
.map_err(|err| anyhow!("failed to load native certs: {err}"))?;
for cert in certs {
if roots.add(&RustlsCertificate(cert.0)).is_err() {
warn!("skipping invalid root cert");
}
}
if roots.is_empty() {
return Err(anyhow!("no root certificates available"));
}
let mut config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots)
.with_no_client_auth();
config.alpn_protocols = vec![b"http/1.1".to_vec()];
Ok(Arc::new(config))
}
fn issue_host_certificate(
host: &str,
ca_cert: &Certificate,
ca_key: &KeyPair,
) -> Result<(Vec<RustlsCertificate>, PrivateKey)> {
let mut params = if let Ok(ip) = host.parse::<IpAddr>() {
let mut params = CertificateParams::new(Vec::new())
.map_err(|err| anyhow!("failed to create cert params: {err}"))?;
params.subject_alt_names.push(SanType::IpAddress(ip));
params
} else {
CertificateParams::new(vec![host.to_string()])
.map_err(|err| anyhow!("failed to create cert params: {err}"))?
};
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
params.key_usages = vec![
KeyUsagePurpose::DigitalSignature,
KeyUsagePurpose::KeyEncipherment,
];
let key_pair = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256)
.map_err(|err| anyhow!("failed to generate host key pair: {err}"))?;
let cert = params
.signed_by(&key_pair, ca_cert, ca_key)
.map_err(|err| anyhow!("failed to sign host cert: {err}"))?;
let cert_pem = cert.pem();
let key_pem = key_pair.serialize_pem();
let certs = certs_from_pem(&cert_pem)?;
let key = private_key_from_pem(&key_pem)?;
Ok((certs, key))
}
fn load_or_create_ca(cfg: &MitmConfig) -> Result<(String, String)> {
let cert_path = &cfg.ca_cert_path;
let key_path = &cfg.ca_key_path;
if cert_path.exists() || key_path.exists() {
if !cert_path.exists() || !key_path.exists() {
return Err(anyhow!("both ca_cert_path and ca_key_path must exist"));
}
let cert_pem = fs::read_to_string(cert_path)
.with_context(|| format!("failed to read CA cert {}", cert_path.display()))?;
let key_pem = fs::read_to_string(key_path)
.with_context(|| format!("failed to read CA key {}", key_path.display()))?;
return Ok((cert_pem, key_pem));
}
if let Some(parent) = cert_path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("failed to create {}", parent.display()))?;
}
if let Some(parent) = key_path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("failed to create {}", parent.display()))?;
}
let (cert_pem, key_pem) = generate_ca()?;
write_private_file(cert_path, cert_pem.as_bytes(), 0o644)?;
write_private_file(key_path, key_pem.as_bytes(), 0o600)?;
info!(
cert_path = %cert_path.display(),
key_path = %key_path.display(),
"generated MITM CA"
);
Ok((cert_pem, key_pem))
}
fn generate_ca() -> Result<(String, String)> {
let mut params = CertificateParams::default();
params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
params.key_usages = vec![
KeyUsagePurpose::KeyCertSign,
KeyUsagePurpose::DigitalSignature,
KeyUsagePurpose::KeyEncipherment,
];
let mut dn = DistinguishedName::new();
dn.push(DnType::CommonName, "network_proxy MITM CA");
params.distinguished_name = dn;
let key_pair = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256)
.map_err(|err| anyhow!("failed to generate CA key pair: {err}"))?;
let cert = params
.self_signed(&key_pair)
.map_err(|err| anyhow!("failed to generate CA cert: {err}"))?;
let cert_pem = cert.pem();
let key_pem = key_pair.serialize_pem();
Ok((cert_pem, key_pem))
}
fn certs_from_pem(pem: &str) -> Result<Vec<RustlsCertificate>> {
let mut reader = Cursor::new(pem);
let certs = rustls_pemfile::certs(&mut reader).context("failed to parse cert PEM")?;
if certs.is_empty() {
return Err(anyhow!("no certificates found"));
}
Ok(certs.into_iter().map(RustlsCertificate).collect())
}
fn private_key_from_pem(pem: &str) -> Result<PrivateKey> {
let mut reader = Cursor::new(pem);
let mut keys =
rustls_pemfile::pkcs8_private_keys(&mut reader).context("failed to parse pkcs8 key")?;
if let Some(key) = keys.pop() {
return Ok(PrivateKey(key));
}
let mut reader = Cursor::new(pem);
let mut keys =
rustls_pemfile::rsa_private_keys(&mut reader).context("failed to parse rsa key")?;
if let Some(key) = keys.pop() {
return Ok(PrivateKey(key));
}
Err(anyhow!("no private key found"))
}
fn write_private_file(path: &Path, contents: &[u8], mode: u32) -> Result<()> {
fs::write(path, contents).with_context(|| format!("failed to write {}", path.display()))?;
set_permissions(path, mode)?;
Ok(())
}
#[cfg(unix)]
fn set_permissions(path: &Path, mode: u32) -> Result<()> {
use std::os::unix::fs::PermissionsExt;
fs::set_permissions(path, fs::Permissions::from_mode(mode))
.with_context(|| format!("failed to set permissions on {}", path.display()))?;
Ok(())
}
#[cfg(not(unix))]
fn set_permissions(_path: &Path, _mode: u32) -> Result<()> {
Ok(())
}
}
#[cfg(not(feature = "mitm"))]
mod imp {
use crate::config::MitmConfig;
use crate::config::NetworkMode;
use anyhow::Result;
use anyhow::anyhow;
use hyper::upgrade::Upgraded;
use std::sync::Arc;
#[derive(Debug)]
pub struct MitmState;
#[allow(dead_code)]
impl MitmState {
pub fn new(_cfg: &MitmConfig) -> Result<Self> {
Err(anyhow!("MITM feature disabled at build time"))
}
pub fn inspect_enabled(&self) -> bool {
false
}
pub fn max_body_bytes(&self) -> usize {
0
}
}
pub async fn mitm_tunnel(
_stream: Upgraded,
_host: &str,
_port: u16,
_mode: NetworkMode,
_state: Arc<MitmState>,
) -> Result<()> {
Err(anyhow!("MITM feature disabled at build time"))
}
}
pub use imp::*;

View File

@@ -0,0 +1,31 @@
use crate::config::NetworkMode;
use hyper::Method;
use std::net::IpAddr;
pub fn method_allowed(mode: NetworkMode, method: &Method) -> bool {
match mode {
NetworkMode::Full => true,
NetworkMode::Limited => matches!(method, &Method::GET | &Method::HEAD | &Method::OPTIONS),
}
}
pub fn is_loopback_host(host: &str) -> bool {
let host = host.to_ascii_lowercase();
if host == "localhost" || host == "localhost." {
return true;
}
if let Ok(ip) = host.parse::<IpAddr>() {
return ip.is_loopback();
}
false
}
pub fn normalize_host(host: &str) -> String {
let host = host.trim();
if host.starts_with('[') {
if let Some(end) = host.find(']') {
return host[1..end].to_ascii_lowercase();
}
}
host.split(':').next().unwrap_or("").to_ascii_lowercase()
}

View File

@@ -0,0 +1,65 @@
use hyper::Body;
use hyper::Response;
use hyper::StatusCode;
use serde::Serialize;
use serde_json::json;
pub fn json_blocked(host: &str, reason: &str) -> Response<Body> {
let body = Body::from(json!({"status":"blocked","host":host,"reason":reason}).to_string());
Response::builder()
.status(StatusCode::FORBIDDEN)
.header("content-type", "application/json")
.header("x-proxy-error", blocked_header_value(reason))
.body(body)
.unwrap_or_else(|_| Response::new(Body::from("blocked")))
}
pub fn blocked_text(reason: &str) -> Response<Body> {
Response::builder()
.status(StatusCode::FORBIDDEN)
.header("content-type", "text/plain")
.header("x-proxy-error", blocked_header_value(reason))
.body(Body::from(blocked_message(reason).to_string()))
.unwrap_or_else(|_| Response::new(Body::from("blocked")))
}
pub fn text_response(status: StatusCode, body: &str) -> Response<Body> {
Response::builder()
.status(status)
.header("content-type", "text/plain")
.body(Body::from(body.to_string()))
.unwrap_or_else(|_| Response::new(Body::from(body.to_string())))
}
pub fn json_response<T: Serialize>(value: &T) -> Response<Body> {
let body = match serde_json::to_string(value) {
Ok(body) => body,
Err(_) => "{}".to_string(),
};
Response::builder()
.status(StatusCode::OK)
.header("content-type", "application/json")
.body(Body::from(body))
.unwrap_or_else(|_| Response::new(Body::from("{}")))
}
fn blocked_header_value(reason: &str) -> &'static str {
match reason {
"not_allowed" | "not_allowed_local" => "blocked-by-allowlist",
"denied" => "blocked-by-denylist",
"method_not_allowed" => "blocked-by-method-policy",
"mitm_required" => "blocked-by-mitm-required",
_ => "blocked-by-policy",
}
}
fn blocked_message(reason: &str) -> &'static str {
match reason {
"not_allowed" => "Codex blocked this request: domain not in allowlist.",
"not_allowed_local" => "Codex blocked this request: local addresses not allowed.",
"denied" => "Codex blocked this request: domain denied by policy.",
"method_not_allowed" => "Codex blocked this request: method not allowed in limited mode.",
"mitm_required" => "Codex blocked this request: MITM required for limited HTTPS.",
_ => "Codex blocked this request by network policy.",
}
}

View File

@@ -0,0 +1,192 @@
use crate::config::NetworkMode;
use crate::policy::normalize_host;
use crate::state::AppState;
use crate::state::BlockedRequest;
use anyhow::Result;
use anyhow::anyhow;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::copy_bidirectional;
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tracing::error;
use tracing::info;
use tracing::warn;
pub async fn run_socks5(state: Arc<AppState>, addr: SocketAddr) -> Result<()> {
let listener = TcpListener::bind(addr).await?;
info!(addr = %addr, "SOCKS5 proxy listening");
match state.network_mode().await {
Ok(NetworkMode::Limited) => {
info!(
mode = "limited",
"SOCKS5 is blocked in limited mode; set mode=\"full\" to allow SOCKS5"
);
}
Ok(NetworkMode::Full) => {}
Err(err) => {
warn!(error = %err, "failed to read network mode");
}
}
loop {
let (stream, peer_addr) = listener.accept().await?;
let state = state.clone();
tokio::spawn(async move {
if let Err(err) = handle_socks5_client(stream, peer_addr, state).await {
warn!(error = %err, "SOCKS5 session ended with error");
}
});
}
}
async fn handle_socks5_client(
mut stream: TcpStream,
peer_addr: SocketAddr,
state: Arc<AppState>,
) -> Result<()> {
let mut header = [0u8; 2];
stream.read_exact(&mut header).await?;
if header[0] != 0x05 {
return Err(anyhow!("invalid SOCKS version"));
}
let nmethods = header[1] as usize;
let mut methods = vec![0u8; nmethods];
stream.read_exact(&mut methods).await?;
stream.write_all(&[0x05, 0x00]).await?;
let mut req_header = [0u8; 4];
stream.read_exact(&mut req_header).await?;
if req_header[0] != 0x05 {
return Err(anyhow!("invalid SOCKS request version"));
}
let cmd = req_header[1];
if cmd != 0x01 {
stream
.write_all(&[0x05, 0x07, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
.await?;
return Err(anyhow!("unsupported SOCKS command"));
}
let atyp = req_header[3];
let host = match atyp {
0x01 => {
let mut addr = [0u8; 4];
stream.read_exact(&mut addr).await?;
format!("{}.{}.{}.{}", addr[0], addr[1], addr[2], addr[3])
}
0x03 => {
let mut len_buf = [0u8; 1];
stream.read_exact(&mut len_buf).await?;
let len = len_buf[0] as usize;
let mut domain = vec![0u8; len];
stream.read_exact(&mut domain).await?;
String::from_utf8_lossy(&domain).to_string()
}
0x04 => {
stream
.write_all(&[0x05, 0x08, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
.await?;
return Err(anyhow!("ipv6 not supported"));
}
_ => {
stream
.write_all(&[0x05, 0x08, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
.await?;
return Err(anyhow!("unknown address type"));
}
};
let mut port_buf = [0u8; 2];
stream.read_exact(&mut port_buf).await?;
let port = u16::from_be_bytes(port_buf);
let normalized_host = normalize_host(&host);
match state.network_mode().await {
Ok(NetworkMode::Limited) => {
let _ = state
.record_blocked(BlockedRequest::new(
normalized_host.clone(),
"method_not_allowed".to_string(),
Some(peer_addr.to_string()),
None,
Some(NetworkMode::Limited),
"socks5".to_string(),
))
.await;
warn!(
client = %peer_addr,
host = %normalized_host,
mode = "limited",
allowed_methods = "GET, HEAD, OPTIONS",
"SOCKS blocked by method policy"
);
stream
.write_all(&[0x05, 0x02, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
.await?;
return Ok(());
}
Ok(NetworkMode::Full) => {}
Err(err) => {
error!(error = %err, "failed to evaluate method policy");
stream
.write_all(&[0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
.await?;
return Ok(());
}
}
match state.host_blocked(&normalized_host).await {
Ok((true, reason)) => {
let _ = state
.record_blocked(BlockedRequest::new(
normalized_host.clone(),
reason.clone(),
Some(peer_addr.to_string()),
None,
None,
"socks5".to_string(),
))
.await;
warn!(client = %peer_addr, host = %normalized_host, reason = %reason, "SOCKS blocked");
stream
.write_all(&[0x05, 0x02, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
.await?;
return Ok(());
}
Ok((false, _)) => {
info!(
client = %peer_addr,
host = %normalized_host,
port = port,
"SOCKS allowed"
);
}
Err(err) => {
error!(error = %err, "failed to evaluate host");
stream
.write_all(&[0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
.await?;
return Ok(());
}
}
let target = format!("{host}:{port}");
let mut upstream = match TcpStream::connect(&target).await {
Ok(stream) => stream,
Err(err) => {
warn!(error = %err, "SOCKS connect failed");
stream
.write_all(&[0x05, 0x04, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
.await?;
return Ok(());
}
};
stream
.write_all(&[0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
.await?;
let _ = copy_bidirectional(&mut stream, &mut upstream).await;
Ok(())
}

View File

@@ -0,0 +1,350 @@
use crate::config::Config;
use crate::config::MitmConfig;
use crate::config::NetworkMode;
use crate::mitm::MitmState;
use crate::policy::is_loopback_host;
use crate::policy::method_allowed;
use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use globset::GlobBuilder;
use globset::GlobSet;
use globset::GlobSetBuilder;
use hyper::Client;
use hyper::Method;
use hyper::client::HttpConnector;
use serde::Serialize;
use std::collections::HashSet;
use std::collections::VecDeque;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::SystemTime;
use std::time::UNIX_EPOCH;
use tokio::sync::RwLock;
use tracing::info;
use tracing::warn;
const MAX_BLOCKED_EVENTS: usize = 200;
#[derive(Clone, Debug, Serialize)]
pub struct BlockedRequest {
pub host: String,
pub reason: String,
pub client: Option<String>,
pub method: Option<String>,
pub mode: Option<NetworkMode>,
pub protocol: String,
pub timestamp: i64,
}
impl BlockedRequest {
pub fn new(
host: String,
reason: String,
client: Option<String>,
method: Option<String>,
mode: Option<NetworkMode>,
protocol: String,
) -> Self {
Self {
host,
reason,
client,
method,
mode,
protocol,
timestamp: unix_timestamp(),
}
}
}
#[derive(Clone)]
struct ConfigState {
cfg: Config,
mtime: Option<SystemTime>,
allow_set: GlobSet,
deny_set: GlobSet,
mitm: Option<Arc<MitmState>>,
cfg_path: PathBuf,
blocked: VecDeque<BlockedRequest>,
}
#[derive(Clone)]
pub struct AppState {
pub(crate) client: Client<HttpConnector>,
state: Arc<RwLock<ConfigState>>,
}
impl AppState {
pub async fn new(cfg_path: PathBuf) -> Result<Self> {
let cfg_state = build_config_state(cfg_path)?;
let client = Client::new();
Ok(Self {
client,
state: Arc::new(RwLock::new(cfg_state)),
})
}
pub async fn current_cfg(&self) -> Result<Config> {
self.reload_if_needed().await?;
let guard = self.state.read().await;
Ok(guard.cfg.clone())
}
pub async fn current_patterns(&self) -> Result<(Vec<String>, Vec<String>)> {
self.reload_if_needed().await?;
let guard = self.state.read().await;
Ok((
guard.cfg.network_proxy.policy.allowed_domains.clone(),
guard.cfg.network_proxy.policy.denied_domains.clone(),
))
}
pub async fn force_reload(&self) -> Result<()> {
let mut guard = self.state.write().await;
let previous_cfg = guard.cfg.clone();
let blocked = guard.blocked.clone();
let cfg_path = guard.cfg_path.clone();
match build_config_state(cfg_path.clone()) {
Ok(mut new_state) => {
log_policy_changes(&previous_cfg, &new_state.cfg);
new_state.blocked = blocked;
*guard = new_state;
info!(path = %cfg_path.display(), "reloaded config");
Ok(())
}
Err(err) => {
warn!(error = %err, path = %cfg_path.display(), "failed to reload config; keeping previous config");
Err(err)
}
}
}
pub async fn host_blocked(&self, host: &str) -> Result<(bool, String)> {
self.reload_if_needed().await?;
let guard = self.state.read().await;
if guard.deny_set.is_match(host) {
return Ok((true, "denied".to_string()));
}
let is_loopback = is_loopback_host(host);
if is_loopback
&& !guard.cfg.network_proxy.policy.allow_local_binding
&& !guard.allow_set.is_match(host)
{
return Ok((true, "not_allowed_local".to_string()));
}
if guard.cfg.network_proxy.policy.allowed_domains.is_empty()
|| !guard.allow_set.is_match(host)
{
return Ok((true, "not_allowed".to_string()));
}
Ok((false, String::new()))
}
pub async fn record_blocked(&self, entry: BlockedRequest) -> Result<()> {
self.reload_if_needed().await?;
let mut guard = self.state.write().await;
guard.blocked.push_back(entry);
while guard.blocked.len() > MAX_BLOCKED_EVENTS {
guard.blocked.pop_front();
}
Ok(())
}
pub async fn drain_blocked(&self) -> Result<Vec<BlockedRequest>> {
self.reload_if_needed().await?;
let mut guard = self.state.write().await;
let blocked = std::mem::take(&mut guard.blocked);
Ok(blocked.into_iter().collect())
}
pub async fn is_unix_socket_allowed(&self, path: &str) -> Result<bool> {
self.reload_if_needed().await?;
let guard = self.state.read().await;
Ok(guard
.cfg
.network_proxy
.policy
.allow_unix_sockets
.iter()
.any(|p| p == path))
}
pub async fn method_allowed(&self, method: &Method) -> Result<bool> {
self.reload_if_needed().await?;
let guard = self.state.read().await;
Ok(method_allowed(guard.cfg.network_proxy.mode, method))
}
pub async fn network_mode(&self) -> Result<NetworkMode> {
self.reload_if_needed().await?;
let guard = self.state.read().await;
Ok(guard.cfg.network_proxy.mode)
}
pub async fn set_network_mode(&self, mode: NetworkMode) -> Result<()> {
self.reload_if_needed().await?;
let mut guard = self.state.write().await;
guard.cfg.network_proxy.mode = mode;
info!(mode = ?mode, "updated network mode");
Ok(())
}
pub async fn mitm_state(&self) -> Result<Option<Arc<MitmState>>> {
self.reload_if_needed().await?;
let guard = self.state.read().await;
Ok(guard.mitm.clone())
}
async fn reload_if_needed(&self) -> Result<()> {
let needs_reload = {
let guard = self.state.read().await;
if !guard.cfg_path.exists() {
true
} else {
let metadata = std::fs::metadata(&guard.cfg_path).ok();
match (metadata.and_then(|m| m.modified().ok()), guard.mtime) {
(Some(new_mtime), Some(old_mtime)) => new_mtime > old_mtime,
(Some(_), None) => true,
_ => false,
}
}
};
if !needs_reload {
return Ok(());
}
self.force_reload().await
}
}
fn build_config_state(cfg_path: PathBuf) -> Result<ConfigState> {
let mut cfg = if cfg_path.exists() {
load_config_from_path(&cfg_path).with_context(|| {
format!(
"failed to load config from {}",
cfg_path.as_path().display()
)
})?
} else {
Config::default()
};
resolve_mitm_paths(&mut cfg, &cfg_path);
let mtime = cfg_path.metadata().and_then(|m| m.modified()).ok();
let deny_set = compile_globset(&cfg.network_proxy.policy.denied_domains)?;
let allow_set = compile_globset(&cfg.network_proxy.policy.allowed_domains)?;
let mitm = if cfg.network_proxy.mitm.enabled {
build_mitm_state(&cfg.network_proxy.mitm)?
} else {
None
};
Ok(ConfigState {
cfg,
mtime,
allow_set,
deny_set,
mitm,
cfg_path,
blocked: VecDeque::new(),
})
}
fn resolve_mitm_paths(cfg: &mut Config, cfg_path: &Path) {
let base = cfg_path.parent().unwrap_or_else(|| Path::new("."));
if cfg.network_proxy.mitm.ca_cert_path.is_relative() {
cfg.network_proxy.mitm.ca_cert_path = base.join(&cfg.network_proxy.mitm.ca_cert_path);
}
if cfg.network_proxy.mitm.ca_key_path.is_relative() {
cfg.network_proxy.mitm.ca_key_path = base.join(&cfg.network_proxy.mitm.ca_key_path);
}
}
fn build_mitm_state(_cfg: &MitmConfig) -> Result<Option<Arc<MitmState>>> {
#[cfg(feature = "mitm")]
{
return Ok(Some(Arc::new(MitmState::new(_cfg)?)));
}
#[cfg(not(feature = "mitm"))]
{
warn!("MITM enabled in config but binary built without mitm feature");
Ok(None)
}
}
fn compile_globset(patterns: &[String]) -> Result<GlobSet> {
let mut builder = GlobSetBuilder::new();
let mut seen = HashSet::new();
for pattern in patterns {
let mut expanded = Vec::with_capacity(2);
expanded.push(pattern.as_str());
if let Some(apex) = pattern.strip_prefix("*.") {
expanded.push(apex);
}
for candidate in expanded {
if !seen.insert(candidate.to_string()) {
continue;
}
let glob = GlobBuilder::new(candidate)
.case_insensitive(true)
.build()
.with_context(|| format!("invalid glob pattern: {candidate}"))?;
builder.add(glob);
}
}
Ok(builder.build()?)
}
fn log_policy_changes(previous: &Config, next: &Config) {
log_domain_list_changes(
"allowlist",
&previous.network_proxy.policy.allowed_domains,
&next.network_proxy.policy.allowed_domains,
);
log_domain_list_changes(
"denylist",
&previous.network_proxy.policy.denied_domains,
&next.network_proxy.policy.denied_domains,
);
}
fn log_domain_list_changes(list_name: &str, previous: &[String], next: &[String]) {
let previous_set: HashSet<String> = previous
.iter()
.map(|entry| entry.to_ascii_lowercase())
.collect();
let next_set: HashSet<String> = next
.iter()
.map(|entry| entry.to_ascii_lowercase())
.collect();
let mut seen_next = HashSet::new();
for entry in next {
let key = entry.to_ascii_lowercase();
if seen_next.insert(key.clone()) && !previous_set.contains(&key) {
info!(list = list_name, entry = %entry, "config entry added");
}
}
let mut seen_previous = HashSet::new();
for entry in previous {
let key = entry.to_ascii_lowercase();
if seen_previous.insert(key.clone()) && !next_set.contains(&key) {
info!(list = list_name, entry = %entry, "config entry removed");
}
}
}
fn unix_timestamp() -> i64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|duration| duration.as_secs() as i64)
.unwrap_or(0)
}
fn load_config_from_path(path: &Path) -> Result<Config> {
let raw = std::fs::read_to_string(path)
.with_context(|| format!("unable to read config file {}", path.display()))?;
toml::from_str(&raw).map_err(|err| anyhow!("unable to parse config: {err}"))
}

View File

@@ -0,0 +1,102 @@
# Codex Network Proxy Design
This document describes the Codex network proxy that runs outside the sandbox and enforces an allow-only network policy for sandboxed subprocesses. The proxy is a single binary with HTTP proxying, SOCKS5, and an admin API. Codex owns the policy state in `~/.codex/config.toml`; the proxy reads that configuration and applies it at the network edge.
## Goals
1. Enforce allow-only network access with denylist precedence.
2. Support wildcard domain patterns, including apex match for `*.domain.tld`.
3. Allow two modes: **limited** (read-only) and **full** (all methods).
4. Provide optional **MITM** to enforce read-only on HTTPS.
5. Allow hot-reloaded configuration via admin API.
6. Provide clear audit logging of allow/deny decisions and policy changes.
7. Enable a single binary with HTTP proxy, SOCKS5 proxy, and admin API.
## Non-Goals
- Enterprise policy distribution or centralized multi-tenant orchestration.
- Deep packet inspection beyond the supported HTTP/HTTPS interception modes.
- Perfect protocol coverage for all network traffic types.
## Architecture
```mermaid
flowchart LR
subgraph Sandbox["Codex (sandboxed)"]
Tools["commands / tools<br/>curl, git, python"]
SocksClients["SOCKS clients"]
end
subgraph Proxy["codex-network-proxy (host process)"]
HttpProxy["HTTP Proxy :3128<br/>CONNECT tunnel<br/>MITM (optional)"]
SocksProxy["SOCKS5 Proxy :8081"]
Admin["Admin API :8080<br/>/health /config /blocked<br/>/reload /mode"]
end
Config["~/.codex/config.toml<br/>[network_proxy.*]"]
Tools -->|HTTP| HttpProxy
SocksClients -->|SOCKS5| SocksProxy
Admin -->|reads + reloads| Config
```
## Configuration Model
The proxy reads `~/.codex/config.toml`:
- `[network_proxy]` for endpoints, mode, and toggles.
- `[network_proxy.policy]` for `allowed_domains` / `denied_domains` (and, on macOS, optional local IPC allowances).
- `[network_proxy.mitm]` for MITM CA paths and inspection settings.
Codex is the source of truth. Approval actions update the config and trigger a proxy reload.
## Enforcement Model
- **Allow/deny precedence:** denylist wins; allowlist is required for access.
- **Limited mode:** only GET/HEAD/OPTIONS are permitted. HTTPS requires MITM to enforce method constraints; otherwise CONNECT is blocked with a clear reason.
- **Full mode:** all methods allowed; CONNECT tunneling is permitted without MITM.
## macOS Sandbox Integration (Seatbelt)
On macOS, Codex uses Seatbelt (`sandbox-exec`) for OS-level enforcement.
Key points:
- **Per-domain gating happens in the proxy**, not in Seatbelt: Seatbelt network rules are intentionally limited to loopback proxy ports (e.g. `localhost:3128` / `localhost:8081`) so all outbound traffic is forced through the proxy, which then applies the allow/deny policy and prompts.
- **Local IPC is deny-by-default** when proxy-restricted network access is active. Some tools rely on Unix domain sockets (e.g. the SSH agent). These are blocked unless explicitly allowed via:
- `network_proxy.policy.allow_unix_sockets` (absolute socket paths, `$SSH_AUTH_SOCK`, or the `ssh-agent` preset), and/or
- `network_proxy.policy.allow_local_binding` (if you need to bind/listen on localhost ports).
When approvals are enabled, Codex can preflight commands that appear to require the SSH agent and prompt to allow the SSH agent socket before running.
## Logging and Auditability
The proxy logs:
- Allow/deny decisions (host, client, reason).
- Policy updates (allowlist/denylist adds/removes).
- Mode changes and config reloads.
- MITM lifecycle events (CA generated, TLS established).
## Decision to Make: Preflight Strictness
Codex performs a preflight check before running some commands. Preflight currently scans CLI args for URLs on known network tools (curl, git, etc.) and shell `-c` snippets.
We need to decide how strict preflight should be:
Option A: **Heuristic preflight (current)**
- Pros: catches obvious `curl https://...` style commands early.
- Cons: misses dynamic URLs inside scripts; can still overprompt on shell snippets.
Option B: **Strict preflight**
- Only preflight when a URL argument is present in the command.
- For everything else, rely on the proxy `/blocked` prompt at connect time.
- Pros: fewer false positives, clearer user experience.
- Cons: fewer early prompts; approvals shift to runtime events.
Decision: **TBD**. We should choose a configuration flag (`network_proxy.preflight_mode = "heuristic" | "strict"`) and default based on observed UX.
## Open Items
- Finalize preflight strictness and expose a config toggle if needed.
- Confirm documentation for MITM trust steps and CA injection into sandboxed commands.

View File

@@ -0,0 +1,93 @@
# Codex Network Proxy Quickstart (Local)
This is a compact guide to build and validate the Codex network proxy locally.
## Build
From the Codex repo:
```bash
cd /Users/viyatb/code/codex/codex-rs
cargo build -p codex-network-proxy
```
For MITM support:
```bash
cargo build -p codex-network-proxy --features mitm
```
## Configure
Add this to `~/.codex/config.toml`:
```toml
[network_proxy]
enabled = true
proxy_url = "http://127.0.0.1:3128"
admin_url = "http://127.0.0.1:8080"
mode = "limited" # or "full"
poll_interval_ms = 1000
[network_proxy.policy]
allowed_domains = ["example.com", "*.github.com"]
denied_domains = ["metadata.google.internal", "169.254.*"]
# macOS only: allow specific local IPC when proxy-restricted.
allow_local_binding = false
# Example: allow SSH agent socket for git/ssh.
allow_unix_sockets = ["$SSH_AUTH_SOCK"]
[network_proxy.mitm]
enabled = false
```
## Run the proxy
```bash
cd /Users/viyatb/code/codex/codex-rs
cargo run -p codex-network-proxy -- proxy
```
With MITM:
```bash
cargo run -p codex-network-proxy --features mitm -- proxy
```
## Test with curl
HTTP/HTTPS via proxy:
```bash
export HTTP_PROXY="http://127.0.0.1:3128"
export HTTPS_PROXY="http://127.0.0.1:3128"
curl -sS https://example.com
```
Limited mode + HTTPS requires MITM. If MITM is on, trust the generated CA:
```bash
security add-trusted-cert -d -r trustRoot \
-k ~/Library/Keychains/login.keychain-db \
~/.codex/network_proxy/mitm/ca.pem
```
Or pass the CA directly:
```bash
curl --cacert ~/.codex/network_proxy/mitm/ca.pem -sS https://example.com
```
## Admin endpoints
Reload config after edits:
```bash
curl -fsS -X POST http://127.0.0.1:8080/reload
```
Switch modes:
```bash
curl -fsS -X POST http://127.0.0.1:8080/mode -d '{"mode":"full"}'
```