feat: introducing a network sandbox proxy (#8442)

This add a new crate, `codex-network-proxy`, a local network proxy
service used by Codex to enforce fine-grained network policy (domain
allow/deny) and to surface blocked network events for interactive
approvals.

- New crate: `codex-rs/network-proxy/` (`codex-network-proxy` binary +
library)
- Core capabilities:
  - HTTP proxy support (including CONNECT tunneling)
  - SOCKS5 proxy support (in the later PR)
- policy evaluation (allowed/denied domain lists; denylist wins;
wildcard support)
  - small admin API for polling/reload/mode changes
- optional MITM support for HTTPS CONNECT to enforce “limited mode”
method restrictions (later PR)

Will follow up integration with codex in subsequent PRs.

## Testing

- `cd codex-rs && cargo build -p codex-network-proxy`
- `cd codex-rs && cargo run -p codex-network-proxy -- proxy`
This commit is contained in:
viyatb-oai
2026-01-23 20:47:09 -05:00
committed by GitHub
parent 69cfc73dc6
commit 77222492f9
22 changed files with 4904 additions and 21 deletions

View File

@@ -0,0 +1,160 @@
use crate::config::NetworkMode;
use crate::responses::json_response;
use crate::responses::text_response;
use crate::state::NetworkProxyState;
use anyhow::Context;
use anyhow::Result;
use rama_core::rt::Executor;
use rama_core::service::service_fn;
use rama_http::Body;
use rama_http::Request;
use rama_http::Response;
use rama_http::StatusCode;
use rama_http_backend::server::HttpServer;
use rama_tcp::server::TcpListener;
use serde::Deserialize;
use serde::Serialize;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use tracing::error;
use tracing::info;
pub async fn run_admin_api(state: Arc<NetworkProxyState>, addr: SocketAddr) -> Result<()> {
// Debug-only admin API (health/config/patterns/blocked + mode/reload). Policy is config-driven
// and constraint-enforced; this endpoint should not become a second policy/approval plane.
let listener = TcpListener::build()
.bind(addr)
.await
// See `http_proxy.rs` for details on why we wrap `BoxError` before converting to anyhow.
.map_err(rama_core::error::OpaqueError::from)
.map_err(anyhow::Error::from)
.with_context(|| format!("bind admin API: {addr}"))?;
let server_state = state.clone();
let server = HttpServer::auto(Executor::new()).service(service_fn(move |req| {
let state = server_state.clone();
async move { handle_admin_request(state, req).await }
}));
info!("admin API listening on {addr}");
listener.serve(server).await;
Ok(())
}
async fn handle_admin_request(
state: Arc<NetworkProxyState>,
req: Request,
) -> Result<Response, Infallible> {
const MODE_BODY_LIMIT: usize = 8 * 1024;
let method = req.method().clone();
let path = req.uri().path().to_string();
let response = match (method.as_str(), path.as_str()) {
("GET", "/health") => Response::new(Body::from("ok")),
("GET", "/config") => match state.current_cfg().await {
Ok(cfg) => json_response(&cfg),
Err(err) => {
error!("failed to load config: {err}");
text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
}
},
("GET", "/patterns") => match state.current_patterns().await {
Ok((allow, deny)) => json_response(&PatternsResponse {
allowed: allow,
denied: deny,
}),
Err(err) => {
error!("failed to load patterns: {err}");
text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
}
},
("GET", "/blocked") => match state.drain_blocked().await {
Ok(blocked) => json_response(&BlockedResponse { blocked }),
Err(err) => {
error!("failed to read blocked queue: {err}");
text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
}
},
("POST", "/mode") => {
let mut body = req.into_body();
let mut buf: Vec<u8> = Vec::new();
loop {
let chunk = match body.chunk().await {
Ok(chunk) => chunk,
Err(err) => {
error!("failed to read mode body: {err}");
return Ok(text_response(StatusCode::BAD_REQUEST, "invalid body"));
}
};
let Some(chunk) = chunk else {
break;
};
if buf.len().saturating_add(chunk.len()) > MODE_BODY_LIMIT {
return Ok(text_response(
StatusCode::PAYLOAD_TOO_LARGE,
"body too large",
));
}
buf.extend_from_slice(&chunk);
}
if buf.is_empty() {
return Ok(text_response(StatusCode::BAD_REQUEST, "missing body"));
}
let update: ModeUpdate = match serde_json::from_slice(&buf) {
Ok(update) => update,
Err(err) => {
error!("failed to parse mode update: {err}");
return Ok(text_response(StatusCode::BAD_REQUEST, "invalid json"));
}
};
match state.set_network_mode(update.mode).await {
Ok(()) => json_response(&ModeUpdateResponse {
status: "ok",
mode: update.mode,
}),
Err(err) => {
error!("mode update failed: {err}");
text_response(StatusCode::INTERNAL_SERVER_ERROR, "mode update failed")
}
}
}
("POST", "/reload") => match state.force_reload().await {
Ok(()) => json_response(&ReloadResponse { status: "reloaded" }),
Err(err) => {
error!("reload failed: {err}");
text_response(StatusCode::INTERNAL_SERVER_ERROR, "reload failed")
}
},
_ => text_response(StatusCode::NOT_FOUND, "not found"),
};
Ok(response)
}
#[derive(Deserialize)]
struct ModeUpdate {
mode: NetworkMode,
}
#[derive(Debug, Serialize)]
struct PatternsResponse {
allowed: Vec<String>,
denied: Vec<String>,
}
#[derive(Debug, Serialize)]
struct BlockedResponse<T> {
blocked: T,
}
#[derive(Debug, Serialize)]
struct ModeUpdateResponse {
status: &'static str,
mode: NetworkMode,
}
#[derive(Debug, Serialize)]
struct ReloadResponse {
status: &'static str,
}

View File

@@ -0,0 +1,433 @@
use anyhow::Context;
use anyhow::Result;
use anyhow::bail;
use serde::Deserialize;
use serde::Serialize;
use std::net::IpAddr;
use std::net::SocketAddr;
use tracing::warn;
use url::Url;
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct NetworkProxyConfig {
#[serde(default)]
pub network_proxy: NetworkProxySettings,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkProxySettings {
#[serde(default)]
pub enabled: bool,
#[serde(default = "default_proxy_url")]
pub proxy_url: String,
#[serde(default = "default_admin_url")]
pub admin_url: String,
#[serde(default)]
pub allow_upstream_proxy: bool,
#[serde(default)]
pub dangerously_allow_non_loopback_proxy: bool,
#[serde(default)]
pub dangerously_allow_non_loopback_admin: bool,
#[serde(default)]
pub mode: NetworkMode,
#[serde(default)]
pub policy: NetworkPolicy,
}
impl Default for NetworkProxySettings {
fn default() -> Self {
Self {
enabled: false,
proxy_url: default_proxy_url(),
admin_url: default_admin_url(),
allow_upstream_proxy: false,
dangerously_allow_non_loopback_proxy: false,
dangerously_allow_non_loopback_admin: false,
mode: NetworkMode::default(),
policy: NetworkPolicy::default(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct NetworkPolicy {
#[serde(default)]
pub allowed_domains: Vec<String>,
#[serde(default)]
pub denied_domains: Vec<String>,
#[serde(default)]
pub allow_unix_sockets: Vec<String>,
#[serde(default)]
pub allow_local_binding: bool,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum NetworkMode {
/// Limited (read-only) access: only GET/HEAD/OPTIONS are allowed for HTTP. HTTPS CONNECT is
/// blocked unless MITM is enabled so the proxy can enforce method policy on inner requests.
Limited,
/// Full network access: all HTTP methods are allowed, and HTTPS CONNECTs are tunneled without
/// MITM interception.
#[default]
Full,
}
impl NetworkMode {
pub fn allows_method(self, method: &str) -> bool {
match self {
Self::Full => true,
Self::Limited => matches!(method, "GET" | "HEAD" | "OPTIONS"),
}
}
}
fn default_proxy_url() -> String {
"http://127.0.0.1:3128".to_string()
}
fn default_admin_url() -> String {
"http://127.0.0.1:8080".to_string()
}
/// Clamp non-loopback bind addresses to loopback unless explicitly allowed.
fn clamp_non_loopback(addr: SocketAddr, allow_non_loopback: bool, name: &str) -> SocketAddr {
if addr.ip().is_loopback() {
return addr;
}
if allow_non_loopback {
warn!("DANGEROUS: {name} listening on non-loopback address {addr}");
return addr;
}
warn!(
"{name} requested non-loopback bind ({addr}); clamping to 127.0.0.1:{port} (set dangerously_allow_non_loopback_proxy or dangerously_allow_non_loopback_admin to override)",
port = addr.port()
);
SocketAddr::from(([127, 0, 0, 1], addr.port()))
}
pub(crate) fn clamp_bind_addrs(
http_addr: SocketAddr,
admin_addr: SocketAddr,
cfg: &NetworkProxySettings,
) -> (SocketAddr, SocketAddr) {
let http_addr = clamp_non_loopback(
http_addr,
cfg.dangerously_allow_non_loopback_proxy,
"HTTP proxy",
);
let admin_addr = clamp_non_loopback(
admin_addr,
cfg.dangerously_allow_non_loopback_admin,
"admin API",
);
if cfg.policy.allow_unix_sockets.is_empty() {
return (http_addr, admin_addr);
}
// `x-unix-socket` is intentionally a local escape hatch. If the proxy (or admin API) is
// reachable from outside the machine, it can become a remote bridge into local daemons
// (e.g. docker.sock). To avoid footguns, enforce loopback binding whenever unix sockets
// are enabled.
if cfg.dangerously_allow_non_loopback_proxy && !http_addr.ip().is_loopback() {
warn!(
"unix socket proxying is enabled; ignoring dangerously_allow_non_loopback_proxy and clamping HTTP proxy to loopback"
);
}
if cfg.dangerously_allow_non_loopback_admin && !admin_addr.ip().is_loopback() {
warn!(
"unix socket proxying is enabled; ignoring dangerously_allow_non_loopback_admin and clamping admin API to loopback"
);
}
(
SocketAddr::from(([127, 0, 0, 1], http_addr.port())),
SocketAddr::from(([127, 0, 0, 1], admin_addr.port())),
)
}
pub struct RuntimeConfig {
pub http_addr: SocketAddr,
pub admin_addr: SocketAddr,
}
pub fn resolve_runtime(cfg: &NetworkProxyConfig) -> Result<RuntimeConfig> {
let http_addr = resolve_addr(&cfg.network_proxy.proxy_url, 3128).with_context(|| {
format!(
"invalid network_proxy.proxy_url: {}",
cfg.network_proxy.proxy_url
)
})?;
let admin_addr = resolve_addr(&cfg.network_proxy.admin_url, 8080).with_context(|| {
format!(
"invalid network_proxy.admin_url: {}",
cfg.network_proxy.admin_url
)
})?;
let (http_addr, admin_addr) = clamp_bind_addrs(http_addr, admin_addr, &cfg.network_proxy);
Ok(RuntimeConfig {
http_addr,
admin_addr,
})
}
fn resolve_addr(url: &str, default_port: u16) -> Result<SocketAddr> {
let addr_parts = parse_host_port(url, default_port)?;
let host = if addr_parts.host.eq_ignore_ascii_case("localhost") {
"127.0.0.1".to_string()
} else {
addr_parts.host
};
match host.parse::<IpAddr>() {
Ok(ip) => Ok(SocketAddr::new(ip, addr_parts.port)),
Err(_) => Ok(SocketAddr::from(([127, 0, 0, 1], addr_parts.port))),
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct SocketAddressParts {
host: String,
port: u16,
}
fn parse_host_port(url: &str, default_port: u16) -> Result<SocketAddressParts> {
let trimmed = url.trim();
if trimmed.is_empty() {
bail!("missing host in network proxy address: {url}");
}
// Avoid treating unbracketed IPv6 literals like "2001:db8::1" as scheme-prefixed URLs.
if matches!(trimmed.parse::<IpAddr>(), Ok(IpAddr::V6(_))) && !trimmed.starts_with('[') {
return Ok(SocketAddressParts {
host: trimmed.to_string(),
port: default_port,
});
}
// Prefer the standard URL parser when the input is URL-like. Prefix a scheme when absent so
// we still accept loose host:port inputs.
let candidate = if trimmed.contains("://") {
trimmed.to_string()
} else {
format!("http://{trimmed}")
};
if let Ok(parsed) = Url::parse(&candidate)
&& let Some(host) = parsed.host_str()
{
let host = host.trim_matches(|c| c == '[' || c == ']');
if host.is_empty() {
bail!("missing host in network proxy address: {url}");
}
return Ok(SocketAddressParts {
host: host.to_string(),
port: parsed.port().unwrap_or(default_port),
});
}
parse_host_port_fallback(trimmed, default_port)
}
fn parse_host_port_fallback(input: &str, default_port: u16) -> Result<SocketAddressParts> {
let without_scheme = input
.split_once("://")
.map(|(_, rest)| rest)
.unwrap_or(input);
let host_port = without_scheme.split('/').next().unwrap_or(without_scheme);
let host_port = host_port
.rsplit_once('@')
.map(|(_, rest)| rest)
.unwrap_or(host_port);
if host_port.starts_with('[')
&& let Some(end) = host_port.find(']')
{
let host = &host_port[1..end];
let port = host_port[end + 1..]
.strip_prefix(':')
.and_then(|port| port.parse::<u16>().ok())
.unwrap_or(default_port);
if host.is_empty() {
bail!("missing host in network proxy address: {input}");
}
return Ok(SocketAddressParts {
host: host.to_string(),
port,
});
}
// Only treat `host:port` as such when there's a single `:`. This avoids
// accidentally interpreting unbracketed IPv6 addresses as `host:port`.
if host_port.bytes().filter(|b| *b == b':').count() == 1
&& let Some((host, port)) = host_port.rsplit_once(':')
&& let Ok(port) = port.parse::<u16>()
{
if host.is_empty() {
bail!("missing host in network proxy address: {input}");
}
return Ok(SocketAddressParts {
host: host.to_string(),
port,
});
}
if host_port.is_empty() {
bail!("missing host in network proxy address: {input}");
}
Ok(SocketAddressParts {
host: host_port.to_string(),
port: default_port,
})
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn parse_host_port_defaults_for_empty_string() {
assert!(parse_host_port("", 1234).is_err());
}
#[test]
fn parse_host_port_defaults_for_whitespace() {
assert!(parse_host_port(" ", 5555).is_err());
}
#[test]
fn parse_host_port_parses_host_port_without_scheme() {
assert_eq!(
parse_host_port("127.0.0.1:8080", 3128).unwrap(),
SocketAddressParts {
host: "127.0.0.1".to_string(),
port: 8080,
}
);
}
#[test]
fn parse_host_port_parses_host_port_with_scheme_and_path() {
assert_eq!(
parse_host_port("http://example.com:8080/some/path", 3128).unwrap(),
SocketAddressParts {
host: "example.com".to_string(),
port: 8080,
}
);
}
#[test]
fn parse_host_port_strips_userinfo() {
assert_eq!(
parse_host_port("http://user:pass@host.example:5555", 3128).unwrap(),
SocketAddressParts {
host: "host.example".to_string(),
port: 5555,
}
);
}
#[test]
fn parse_host_port_parses_ipv6_with_brackets() {
assert_eq!(
parse_host_port("http://[::1]:9999", 3128).unwrap(),
SocketAddressParts {
host: "::1".to_string(),
port: 9999,
}
);
}
#[test]
fn parse_host_port_does_not_treat_unbracketed_ipv6_as_host_port() {
assert_eq!(
parse_host_port("2001:db8::1", 3128).unwrap(),
SocketAddressParts {
host: "2001:db8::1".to_string(),
port: 3128,
}
);
}
#[test]
fn parse_host_port_falls_back_to_default_port_when_port_is_invalid() {
assert_eq!(
parse_host_port("example.com:notaport", 3128).unwrap(),
SocketAddressParts {
host: "example.com:notaport".to_string(),
port: 3128,
}
);
}
#[test]
fn resolve_addr_maps_localhost_to_loopback() {
assert_eq!(
resolve_addr("localhost", 3128).unwrap(),
"127.0.0.1:3128".parse::<SocketAddr>().unwrap()
);
}
#[test]
fn resolve_addr_parses_ip_literals() {
assert_eq!(
resolve_addr("1.2.3.4", 80).unwrap(),
"1.2.3.4:80".parse::<SocketAddr>().unwrap()
);
}
#[test]
fn resolve_addr_parses_ipv6_literals() {
assert_eq!(
resolve_addr("http://[::1]:8080", 3128).unwrap(),
"[::1]:8080".parse::<SocketAddr>().unwrap()
);
}
#[test]
fn resolve_addr_falls_back_to_loopback_for_hostnames() {
assert_eq!(
resolve_addr("http://example.com:5555", 3128).unwrap(),
"127.0.0.1:5555".parse::<SocketAddr>().unwrap()
);
}
#[test]
fn clamp_bind_addrs_allows_non_loopback_when_enabled() {
let cfg = NetworkProxySettings {
dangerously_allow_non_loopback_proxy: true,
dangerously_allow_non_loopback_admin: true,
..Default::default()
};
let http_addr = "0.0.0.0:3128".parse::<SocketAddr>().unwrap();
let admin_addr = "0.0.0.0:8080".parse::<SocketAddr>().unwrap();
let (http_addr, admin_addr) = clamp_bind_addrs(http_addr, admin_addr, &cfg);
assert_eq!(http_addr, "0.0.0.0:3128".parse::<SocketAddr>().unwrap());
assert_eq!(admin_addr, "0.0.0.0:8080".parse::<SocketAddr>().unwrap());
}
#[test]
fn clamp_bind_addrs_forces_loopback_when_unix_sockets_enabled() {
let cfg = NetworkProxySettings {
dangerously_allow_non_loopback_proxy: true,
dangerously_allow_non_loopback_admin: true,
policy: NetworkPolicy {
allow_unix_sockets: vec!["/tmp/docker.sock".to_string()],
..Default::default()
},
..Default::default()
};
let http_addr = "0.0.0.0:3128".parse::<SocketAddr>().unwrap();
let admin_addr = "0.0.0.0:8080".parse::<SocketAddr>().unwrap();
let (http_addr, admin_addr) = clamp_bind_addrs(http_addr, admin_addr, &cfg);
assert_eq!(http_addr, "127.0.0.1:3128".parse::<SocketAddr>().unwrap());
assert_eq!(admin_addr, "127.0.0.1:8080".parse::<SocketAddr>().unwrap());
}
}

View File

@@ -0,0 +1,636 @@
use crate::config::NetworkMode;
use crate::network_policy::NetworkDecision;
use crate::network_policy::NetworkPolicyDecider;
use crate::network_policy::NetworkPolicyRequest;
use crate::network_policy::NetworkProtocol;
use crate::network_policy::evaluate_host_policy;
use crate::policy::normalize_host;
use crate::reasons::REASON_METHOD_NOT_ALLOWED;
use crate::reasons::REASON_NOT_ALLOWED;
use crate::reasons::REASON_PROXY_DISABLED;
use crate::responses::blocked_header_value;
use crate::responses::json_response;
use crate::runtime::unix_socket_permissions_supported;
use crate::state::BlockedRequest;
use crate::state::NetworkProxyState;
use crate::upstream::UpstreamClient;
use crate::upstream::proxy_for_connect;
use anyhow::Context as _;
use anyhow::Result;
use rama_core::Layer;
use rama_core::Service;
use rama_core::error::BoxError;
use rama_core::error::ErrorExt as _;
use rama_core::error::OpaqueError;
use rama_core::extensions::ExtensionsMut;
use rama_core::extensions::ExtensionsRef;
use rama_core::layer::AddInputExtensionLayer;
use rama_core::rt::Executor;
use rama_core::service::service_fn;
use rama_http::Body;
use rama_http::HeaderValue;
use rama_http::Request;
use rama_http::Response;
use rama_http::StatusCode;
use rama_http::layer::remove_header::RemoveRequestHeaderLayer;
use rama_http::layer::remove_header::RemoveResponseHeaderLayer;
use rama_http::matcher::MethodMatcher;
use rama_http_backend::client::proxy::layer::HttpProxyConnector;
use rama_http_backend::server::HttpServer;
use rama_http_backend::server::layer::upgrade::UpgradeLayer;
use rama_http_backend::server::layer::upgrade::Upgraded;
use rama_net::Protocol;
use rama_net::address::ProxyAddress;
use rama_net::client::ConnectorService;
use rama_net::client::EstablishedClientConnection;
use rama_net::http::RequestContext;
use rama_net::proxy::ProxyRequest;
use rama_net::proxy::ProxyTarget;
use rama_net::proxy::StreamForwardService;
use rama_net::stream::SocketInfo;
use rama_tcp::client::Request as TcpRequest;
use rama_tcp::client::service::TcpConnector;
use rama_tcp::server::TcpListener;
use rama_tls_boring::client::TlsConnectorDataBuilder;
use rama_tls_boring::client::TlsConnectorLayer;
use serde::Serialize;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use tracing::error;
use tracing::info;
use tracing::warn;
pub async fn run_http_proxy(
state: Arc<NetworkProxyState>,
addr: SocketAddr,
policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
) -> Result<()> {
let listener = TcpListener::build()
.bind(addr)
.await
// Rama's `BoxError` is a `Box<dyn Error + Send + Sync>` without an explicit `'static`
// lifetime bound, which means it doesn't satisfy `anyhow::Context`'s `StdError` constraint.
// Wrap it in Rama's `OpaqueError` so we can preserve the original error as a source and
// still use `anyhow` for chaining.
.map_err(rama_core::error::OpaqueError::from)
.map_err(anyhow::Error::from)
.with_context(|| format!("bind HTTP proxy: {addr}"))?;
let http_service = HttpServer::auto(Executor::new()).service(
(
UpgradeLayer::new(
MethodMatcher::CONNECT,
service_fn({
let policy_decider = policy_decider.clone();
move |req| http_connect_accept(policy_decider.clone(), req)
}),
service_fn(http_connect_proxy),
),
RemoveResponseHeaderLayer::hop_by_hop(),
RemoveRequestHeaderLayer::hop_by_hop(),
)
.into_layer(service_fn({
let policy_decider = policy_decider.clone();
move |req| http_plain_proxy(policy_decider.clone(), req)
})),
);
info!("HTTP proxy listening on {addr}");
listener
.serve(AddInputExtensionLayer::new(state).into_layer(http_service))
.await;
Ok(())
}
async fn http_connect_accept(
policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
mut req: Request,
) -> Result<(Response, Request), Response> {
let app_state = req
.extensions()
.get::<Arc<NetworkProxyState>>()
.cloned()
.ok_or_else(|| text_response(StatusCode::INTERNAL_SERVER_ERROR, "missing state"))?;
let authority = match RequestContext::try_from(&req).map(|ctx| ctx.host_with_port()) {
Ok(authority) => authority,
Err(err) => {
warn!("CONNECT missing authority: {err}");
return Err(text_response(StatusCode::BAD_REQUEST, "missing authority"));
}
};
let host = normalize_host(&authority.host.to_string());
if host.is_empty() {
return Err(text_response(StatusCode::BAD_REQUEST, "invalid host"));
}
let client = client_addr(&req);
let enabled = app_state
.enabled()
.await
.map_err(|err| internal_error("failed to read enabled state", err))?;
if !enabled {
let client = client.as_deref().unwrap_or_default();
warn!("CONNECT blocked; proxy disabled (client={client}, host={host})");
return Err(proxy_disabled_response(
&app_state,
host,
client_addr(&req),
Some("CONNECT".to_string()),
"http-connect",
)
.await);
}
let request = NetworkPolicyRequest::new(
NetworkProtocol::HttpsConnect,
host.clone(),
authority.port,
client.clone(),
Some("CONNECT".to_string()),
None,
None,
);
match evaluate_host_policy(&app_state, policy_decider.as_ref(), &request).await {
Ok(NetworkDecision::Deny { reason }) => {
let _ = app_state
.record_blocked(BlockedRequest::new(
host.clone(),
reason.clone(),
client.clone(),
Some("CONNECT".to_string()),
None,
"http-connect".to_string(),
))
.await;
let client = client.as_deref().unwrap_or_default();
warn!("CONNECT blocked (client={client}, host={host}, reason={reason})");
return Err(blocked_text(&reason));
}
Ok(NetworkDecision::Allow) => {
let client = client.as_deref().unwrap_or_default();
info!("CONNECT allowed (client={client}, host={host})");
}
Err(err) => {
error!("failed to evaluate host for CONNECT {host}: {err}");
return Err(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
}
}
let mode = app_state
.network_mode()
.await
.map_err(|err| internal_error("failed to read network mode", err))?;
if mode == NetworkMode::Limited {
let _ = app_state
.record_blocked(BlockedRequest::new(
host.clone(),
REASON_METHOD_NOT_ALLOWED.to_string(),
client.clone(),
Some("CONNECT".to_string()),
Some(NetworkMode::Limited),
"http-connect".to_string(),
))
.await;
let client = client.as_deref().unwrap_or_default();
warn!("CONNECT blocked by method policy (client={client}, host={host}, mode=limited)");
return Err(blocked_text(REASON_METHOD_NOT_ALLOWED));
}
req.extensions_mut().insert(ProxyTarget(authority));
req.extensions_mut().insert(mode);
Ok((
Response::builder()
.status(StatusCode::OK)
.body(Body::empty())
.unwrap_or_else(|_| Response::new(Body::empty())),
req,
))
}
async fn http_connect_proxy(upgraded: Upgraded) -> Result<(), Infallible> {
if upgraded.extensions().get::<ProxyTarget>().is_none() {
warn!("CONNECT missing proxy target");
return Ok(());
}
let allow_upstream_proxy = match upgraded
.extensions()
.get::<Arc<NetworkProxyState>>()
.cloned()
{
Some(state) => match state.allow_upstream_proxy().await {
Ok(allowed) => allowed,
Err(err) => {
error!("failed to read upstream proxy setting: {err}");
false
}
},
None => {
error!("missing app state");
false
}
};
let proxy = if allow_upstream_proxy {
proxy_for_connect()
} else {
None
};
if let Err(err) = forward_connect_tunnel(upgraded, proxy).await {
warn!("tunnel error: {err}");
}
Ok(())
}
async fn forward_connect_tunnel(
upgraded: Upgraded,
proxy: Option<ProxyAddress>,
) -> Result<(), BoxError> {
let authority = upgraded
.extensions()
.get::<ProxyTarget>()
.map(|target| target.0.clone())
.ok_or_else(|| OpaqueError::from_display("missing forward authority").into_boxed())?;
let mut extensions = upgraded.extensions().clone();
if let Some(proxy) = proxy {
extensions.insert(proxy);
}
let req = TcpRequest::new_with_extensions(authority.clone(), extensions)
.with_protocol(Protocol::HTTPS);
let proxy_connector = HttpProxyConnector::optional(TcpConnector::new());
let tls_config = TlsConnectorDataBuilder::new_http_auto().into_shared_builder();
let connector = TlsConnectorLayer::tunnel(None)
.with_connector_data(tls_config)
.into_layer(proxy_connector);
let EstablishedClientConnection { conn: target, .. } =
connector.connect(req).await.map_err(|err| {
OpaqueError::from_boxed(err)
.with_context(|| format!("establish CONNECT tunnel to {authority}"))
.into_boxed()
})?;
let proxy_req = ProxyRequest {
source: upgraded,
target,
};
StreamForwardService::default()
.serve(proxy_req)
.await
.map_err(|err| {
OpaqueError::from_boxed(err.into())
.with_context(|| format!("forward CONNECT tunnel to {authority}"))
.into_boxed()
})
}
async fn http_plain_proxy(
policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
req: Request,
) -> Result<Response, Infallible> {
let app_state = match req.extensions().get::<Arc<NetworkProxyState>>().cloned() {
Some(state) => state,
None => {
error!("missing app state");
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
}
};
let client = client_addr(&req);
let method_allowed = match app_state
.method_allowed(req.method().as_str())
.await
.map_err(|err| internal_error("failed to evaluate method policy", err))
{
Ok(allowed) => allowed,
Err(resp) => return Ok(resp),
};
// `x-unix-socket` is an escape hatch for talking to local daemons. We keep it tightly scoped:
// macOS-only + explicit allowlist, to avoid turning the proxy into a general local capability
// escalation mechanism.
if let Some(unix_socket_header) = req.headers().get("x-unix-socket") {
let socket_path = match unix_socket_header.to_str() {
Ok(value) => value.to_string(),
Err(_) => {
warn!("invalid x-unix-socket header value (non-UTF8)");
return Ok(text_response(
StatusCode::BAD_REQUEST,
"invalid x-unix-socket header",
));
}
};
let enabled = match app_state
.enabled()
.await
.map_err(|err| internal_error("failed to read enabled state", err))
{
Ok(enabled) => enabled,
Err(resp) => return Ok(resp),
};
if !enabled {
let client = client.as_deref().unwrap_or_default();
warn!("unix socket blocked; proxy disabled (client={client}, path={socket_path})");
return Ok(proxy_disabled_response(
&app_state,
socket_path,
client_addr(&req),
Some(req.method().as_str().to_string()),
"unix-socket",
)
.await);
}
if !method_allowed {
let client = client.as_deref().unwrap_or_default();
let method = req.method();
warn!(
"unix socket blocked by method policy (client={client}, method={method}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
);
return Ok(json_blocked("unix-socket", REASON_METHOD_NOT_ALLOWED));
}
if !unix_socket_permissions_supported() {
warn!("unix socket proxy unsupported on this platform (path={socket_path})");
return Ok(text_response(
StatusCode::NOT_IMPLEMENTED,
"unix sockets unsupported",
));
}
return match app_state.is_unix_socket_allowed(&socket_path).await {
Ok(true) => {
let client = client.as_deref().unwrap_or_default();
info!("unix socket allowed (client={client}, path={socket_path})");
match proxy_via_unix_socket(req, &socket_path).await {
Ok(resp) => Ok(resp),
Err(err) => {
warn!("unix socket proxy failed: {err}");
Ok(text_response(
StatusCode::BAD_GATEWAY,
"unix socket proxy failed",
))
}
}
}
Ok(false) => {
let client = client.as_deref().unwrap_or_default();
warn!("unix socket blocked (client={client}, path={socket_path})");
Ok(json_blocked("unix-socket", REASON_NOT_ALLOWED))
}
Err(err) => {
warn!("unix socket check failed: {err}");
Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"))
}
};
}
let authority = match RequestContext::try_from(&req).map(|ctx| ctx.host_with_port()) {
Ok(authority) => authority,
Err(err) => {
warn!("missing host: {err}");
return Ok(text_response(StatusCode::BAD_REQUEST, "missing host"));
}
};
let host = normalize_host(&authority.host.to_string());
let port = authority.port;
let enabled = match app_state
.enabled()
.await
.map_err(|err| internal_error("failed to read enabled state", err))
{
Ok(enabled) => enabled,
Err(resp) => return Ok(resp),
};
if !enabled {
let client = client.as_deref().unwrap_or_default();
let method = req.method();
warn!("request blocked; proxy disabled (client={client}, host={host}, method={method})");
return Ok(proxy_disabled_response(
&app_state,
host,
client_addr(&req),
Some(req.method().as_str().to_string()),
"http",
)
.await);
}
let request = NetworkPolicyRequest::new(
NetworkProtocol::Http,
host.clone(),
port,
client.clone(),
Some(req.method().as_str().to_string()),
None,
None,
);
match evaluate_host_policy(&app_state, policy_decider.as_ref(), &request).await {
Ok(NetworkDecision::Deny { reason }) => {
let _ = app_state
.record_blocked(BlockedRequest::new(
host.clone(),
reason.clone(),
client.clone(),
Some(req.method().as_str().to_string()),
None,
"http".to_string(),
))
.await;
let client = client.as_deref().unwrap_or_default();
warn!("request blocked (client={client}, host={host}, reason={reason})");
return Ok(json_blocked(&host, &reason));
}
Ok(NetworkDecision::Allow) => {}
Err(err) => {
error!("failed to evaluate host for {host}: {err}");
return Ok(text_response(StatusCode::INTERNAL_SERVER_ERROR, "error"));
}
}
if !method_allowed {
let _ = app_state
.record_blocked(BlockedRequest::new(
host.clone(),
REASON_METHOD_NOT_ALLOWED.to_string(),
client.clone(),
Some(req.method().as_str().to_string()),
Some(NetworkMode::Limited),
"http".to_string(),
))
.await;
let client = client.as_deref().unwrap_or_default();
let method = req.method();
warn!(
"request blocked by method policy (client={client}, host={host}, method={method}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)"
);
return Ok(json_blocked(&host, REASON_METHOD_NOT_ALLOWED));
}
let client = client.as_deref().unwrap_or_default();
let method = req.method();
info!("request allowed (client={client}, host={host}, method={method})");
let allow_upstream_proxy = match app_state
.allow_upstream_proxy()
.await
.map_err(|err| internal_error("failed to read upstream proxy config", err))
{
Ok(allow) => allow,
Err(resp) => return Ok(resp),
};
let client = if allow_upstream_proxy {
UpstreamClient::from_env_proxy()
} else {
UpstreamClient::direct()
};
match client.serve(req).await {
Ok(resp) => Ok(resp),
Err(err) => {
warn!("upstream request failed: {err}");
Ok(text_response(StatusCode::BAD_GATEWAY, "upstream failure"))
}
}
}
async fn proxy_via_unix_socket(req: Request, socket_path: &str) -> Result<Response> {
#[cfg(target_os = "macos")]
{
let client = UpstreamClient::unix_socket(socket_path);
let (mut parts, body) = req.into_parts();
let path = parts
.uri
.path_and_query()
.map(rama_http::uri::PathAndQuery::as_str)
.unwrap_or("/");
parts.uri = path
.parse()
.with_context(|| format!("invalid unix socket request path: {path}"))?;
parts.headers.remove("x-unix-socket");
let req = Request::from_parts(parts, body);
client.serve(req).await.map_err(anyhow::Error::from)
}
#[cfg(not(target_os = "macos"))]
{
let _ = req;
let _ = socket_path;
Err(anyhow::anyhow!("unix sockets not supported"))
}
}
fn client_addr<T: ExtensionsRef>(input: &T) -> Option<String> {
input
.extensions()
.get::<SocketInfo>()
.map(|info| info.peer_addr().to_string())
}
fn json_blocked(host: &str, reason: &str) -> Response {
let response = BlockedResponse {
status: "blocked",
host,
reason,
};
let mut resp = json_response(&response);
*resp.status_mut() = StatusCode::FORBIDDEN;
resp.headers_mut().insert(
"x-proxy-error",
HeaderValue::from_static(blocked_header_value(reason)),
);
resp
}
fn blocked_text(reason: &str) -> Response {
crate::responses::blocked_text_response(reason)
}
async fn proxy_disabled_response(
app_state: &NetworkProxyState,
host: String,
client: Option<String>,
method: Option<String>,
protocol: &str,
) -> Response {
let _ = app_state
.record_blocked(BlockedRequest::new(
host,
REASON_PROXY_DISABLED.to_string(),
client,
method,
None,
protocol.to_string(),
))
.await;
text_response(StatusCode::SERVICE_UNAVAILABLE, "proxy disabled")
}
fn internal_error(context: &str, err: impl std::fmt::Display) -> Response {
error!("{context}: {err}");
text_response(StatusCode::INTERNAL_SERVER_ERROR, "error")
}
fn text_response(status: StatusCode, body: &str) -> Response {
Response::builder()
.status(status)
.header("content-type", "text/plain")
.body(Body::from(body.to_string()))
.unwrap_or_else(|_| Response::new(Body::from(body.to_string())))
}
#[derive(Serialize)]
struct BlockedResponse<'a> {
status: &'static str,
host: &'a str,
reason: &'a str,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::NetworkMode;
use crate::config::NetworkPolicy;
use crate::runtime::network_proxy_state_for_policy;
use pretty_assertions::assert_eq;
use rama_http::Method;
use rama_http::Request;
use std::sync::Arc;
#[tokio::test]
async fn http_connect_accept_blocks_in_limited_mode() {
let policy = NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
..Default::default()
};
let state = Arc::new(network_proxy_state_for_policy(policy));
state.set_network_mode(NetworkMode::Limited).await.unwrap();
let mut req = Request::builder()
.method(Method::CONNECT)
.uri("https://example.com:443")
.header("host", "example.com:443")
.body(Body::empty())
.unwrap();
req.extensions_mut().insert(state);
let response = http_connect_accept(None, req).await.unwrap_err();
assert_eq!(response.status(), StatusCode::FORBIDDEN);
assert_eq!(
response.headers().get("x-proxy-error").unwrap(),
"blocked-by-method-policy"
);
}
}

View File

@@ -0,0 +1,29 @@
#![deny(clippy::print_stdout, clippy::print_stderr)]
mod admin;
mod config;
mod http_proxy;
mod network_policy;
mod policy;
mod proxy;
mod reasons;
mod responses;
mod runtime;
mod state;
mod upstream;
use anyhow::Result;
pub use network_policy::NetworkDecision;
pub use network_policy::NetworkPolicyDecider;
pub use network_policy::NetworkPolicyRequest;
pub use network_policy::NetworkProtocol;
pub use proxy::Args;
pub use proxy::NetworkProxy;
pub use proxy::NetworkProxyBuilder;
pub use proxy::NetworkProxyHandle;
pub async fn run_main(args: Args) -> Result<()> {
let _ = args;
let proxy = NetworkProxy::builder().build().await?;
proxy.run().await?.wait().await
}

View File

@@ -0,0 +1,14 @@
use anyhow::Result;
use clap::Parser;
use codex_network_proxy::Args;
use codex_network_proxy::NetworkProxy;
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt::init();
let args = Args::parse();
let _ = args;
let proxy = NetworkProxy::builder().build().await?;
proxy.run().await?.wait().await
}

View File

@@ -0,0 +1,234 @@
use crate::reasons::REASON_POLICY_DENIED;
use crate::runtime::HostBlockDecision;
use crate::runtime::HostBlockReason;
use crate::state::NetworkProxyState;
use anyhow::Result;
use async_trait::async_trait;
use std::future::Future;
use std::sync::Arc;
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NetworkProtocol {
Http,
HttpsConnect,
Socks5Tcp,
Socks5Udp,
}
#[derive(Clone, Debug)]
pub struct NetworkPolicyRequest {
pub protocol: NetworkProtocol,
pub host: String,
pub port: u16,
pub client_addr: Option<String>,
pub method: Option<String>,
pub command: Option<String>,
pub exec_policy_hint: Option<String>,
}
impl NetworkPolicyRequest {
pub fn new(
protocol: NetworkProtocol,
host: String,
port: u16,
client_addr: Option<String>,
method: Option<String>,
command: Option<String>,
exec_policy_hint: Option<String>,
) -> Self {
Self {
protocol,
host,
port,
client_addr,
method,
command,
exec_policy_hint,
}
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum NetworkDecision {
Allow,
Deny { reason: String },
}
impl NetworkDecision {
pub fn deny(reason: impl Into<String>) -> Self {
let reason = reason.into();
let reason = if reason.is_empty() {
REASON_POLICY_DENIED.to_string()
} else {
reason
};
Self::Deny { reason }
}
}
/// Decide whether a network request should be allowed.
///
/// If `command` or `exec_policy_hint` is provided, callers can map exec-policy
/// approvals to network access (e.g., allow all requests for commands matching
/// approved prefixes like `curl *`).
#[async_trait]
pub trait NetworkPolicyDecider: Send + Sync + 'static {
async fn decide(&self, req: NetworkPolicyRequest) -> NetworkDecision;
}
#[async_trait]
impl<D: NetworkPolicyDecider + ?Sized> NetworkPolicyDecider for Arc<D> {
async fn decide(&self, req: NetworkPolicyRequest) -> NetworkDecision {
(**self).decide(req).await
}
}
#[async_trait]
impl<F, Fut> NetworkPolicyDecider for F
where
F: Fn(NetworkPolicyRequest) -> Fut + Send + Sync + 'static,
Fut: Future<Output = NetworkDecision> + Send,
{
async fn decide(&self, req: NetworkPolicyRequest) -> NetworkDecision {
(self)(req).await
}
}
pub(crate) async fn evaluate_host_policy(
state: &NetworkProxyState,
decider: Option<&Arc<dyn NetworkPolicyDecider>>,
request: &NetworkPolicyRequest,
) -> Result<NetworkDecision> {
match state.host_blocked(&request.host, request.port).await? {
HostBlockDecision::Allowed => Ok(NetworkDecision::Allow),
HostBlockDecision::Blocked(HostBlockReason::NotAllowed) => {
if let Some(decider) = decider {
Ok(decider.decide(request.clone()).await)
} else {
Ok(NetworkDecision::deny(HostBlockReason::NotAllowed.as_str()))
}
}
HostBlockDecision::Blocked(reason) => Ok(NetworkDecision::deny(reason.as_str())),
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::NetworkPolicy;
use crate::reasons::REASON_DENIED;
use crate::reasons::REASON_NOT_ALLOWED_LOCAL;
use crate::state::network_proxy_state_for_policy;
use pretty_assertions::assert_eq;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
#[tokio::test]
async fn evaluate_host_policy_invokes_decider_for_not_allowed() {
let state = network_proxy_state_for_policy(NetworkPolicy::default());
let calls = Arc::new(AtomicUsize::new(0));
let decider: Arc<dyn NetworkPolicyDecider> = Arc::new({
let calls = calls.clone();
move |_req| {
calls.fetch_add(1, Ordering::SeqCst);
// The default policy denies all; the decider is consulted for not_allowed
// requests and can override that decision.
async { NetworkDecision::Allow }
}
});
let request = NetworkPolicyRequest::new(
NetworkProtocol::Http,
"example.com".to_string(),
80,
None,
Some("GET".to_string()),
None,
None,
);
let decision = evaluate_host_policy(&state, Some(&decider), &request)
.await
.unwrap();
assert_eq!(decision, NetworkDecision::Allow);
assert_eq!(calls.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn evaluate_host_policy_skips_decider_for_denied() {
let state = network_proxy_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
denied_domains: vec!["blocked.com".to_string()],
..NetworkPolicy::default()
});
let calls = Arc::new(AtomicUsize::new(0));
let decider: Arc<dyn NetworkPolicyDecider> = Arc::new({
let calls = calls.clone();
move |_req| {
calls.fetch_add(1, Ordering::SeqCst);
async { NetworkDecision::Allow }
}
});
let request = NetworkPolicyRequest::new(
NetworkProtocol::Http,
"blocked.com".to_string(),
80,
None,
Some("GET".to_string()),
None,
None,
);
let decision = evaluate_host_policy(&state, Some(&decider), &request)
.await
.unwrap();
assert_eq!(
decision,
NetworkDecision::Deny {
reason: REASON_DENIED.to_string()
}
);
assert_eq!(calls.load(Ordering::SeqCst), 0);
}
#[tokio::test]
async fn evaluate_host_policy_skips_decider_for_not_allowed_local() {
let state = network_proxy_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
allow_local_binding: false,
..NetworkPolicy::default()
});
let calls = Arc::new(AtomicUsize::new(0));
let decider: Arc<dyn NetworkPolicyDecider> = Arc::new({
let calls = calls.clone();
move |_req| {
calls.fetch_add(1, Ordering::SeqCst);
async { NetworkDecision::Allow }
}
});
let request = NetworkPolicyRequest::new(
NetworkProtocol::Http,
"127.0.0.1".to_string(),
80,
None,
Some("GET".to_string()),
None,
None,
);
let decision = evaluate_host_policy(&state, Some(&decider), &request)
.await
.unwrap();
assert_eq!(
decision,
NetworkDecision::Deny {
reason: REASON_NOT_ALLOWED_LOCAL.to_string()
}
);
assert_eq!(calls.load(Ordering::SeqCst), 0);
}
}

View File

@@ -0,0 +1,435 @@
#[cfg(test)]
use crate::config::NetworkMode;
use anyhow::Context;
use anyhow::Result;
use anyhow::ensure;
use globset::GlobBuilder;
use globset::GlobSet;
use globset::GlobSetBuilder;
use std::collections::HashSet;
use std::net::IpAddr;
use std::net::Ipv4Addr;
use std::net::Ipv6Addr;
use url::Host as UrlHost;
/// A normalized host string for policy evaluation.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Host(String);
impl Host {
pub fn parse(input: &str) -> Result<Self> {
let normalized = normalize_host(input);
ensure!(!normalized.is_empty(), "host is empty");
Ok(Self(normalized))
}
pub fn as_str(&self) -> &str {
&self.0
}
}
/// Returns true if the host is a loopback hostname or IP literal.
pub fn is_loopback_host(host: &Host) -> bool {
let host = host.as_str();
let host = host.split_once('%').map(|(ip, _)| ip).unwrap_or(host);
if host == "localhost" {
return true;
}
if let Ok(ip) = host.parse::<IpAddr>() {
return ip.is_loopback();
}
false
}
pub fn is_non_public_ip(ip: IpAddr) -> bool {
match ip {
IpAddr::V4(ip) => is_non_public_ipv4(ip),
IpAddr::V6(ip) => is_non_public_ipv6(ip),
}
}
fn is_non_public_ipv4(ip: Ipv4Addr) -> bool {
// Use the standard library classification helpers where possible; they encode the intent more
// clearly than hand-rolled range checks. Some non-public ranges (e.g., CGNAT and TEST-NET
// blocks) are not covered by stable stdlib helpers yet, so we fall back to CIDR checks.
ip.is_loopback()
|| ip.is_private()
|| ip.is_link_local()
|| ip.is_unspecified()
|| ip.is_multicast()
|| ip.is_broadcast()
|| ipv4_in_cidr(ip, [0, 0, 0, 0], 8) // "this network" (RFC 1122)
|| ipv4_in_cidr(ip, [100, 64, 0, 0], 10) // CGNAT (RFC 6598)
|| ipv4_in_cidr(ip, [192, 0, 0, 0], 24) // IETF Protocol Assignments (RFC 6890)
|| ipv4_in_cidr(ip, [192, 0, 2, 0], 24) // TEST-NET-1 (RFC 5737)
|| ipv4_in_cidr(ip, [198, 18, 0, 0], 15) // Benchmarking (RFC 2544)
|| ipv4_in_cidr(ip, [198, 51, 100, 0], 24) // TEST-NET-2 (RFC 5737)
|| ipv4_in_cidr(ip, [203, 0, 113, 0], 24) // TEST-NET-3 (RFC 5737)
|| ipv4_in_cidr(ip, [240, 0, 0, 0], 4) // Reserved (RFC 6890)
}
fn ipv4_in_cidr(ip: Ipv4Addr, base: [u8; 4], prefix: u8) -> bool {
let ip = u32::from(ip);
let base = u32::from(Ipv4Addr::from(base));
let mask = if prefix == 0 {
0
} else {
u32::MAX << (32 - prefix)
};
(ip & mask) == (base & mask)
}
fn is_non_public_ipv6(ip: Ipv6Addr) -> bool {
if let Some(v4) = ip.to_ipv4() {
return is_non_public_ipv4(v4) || ip.is_loopback();
}
// Treat anything that isn't globally routable as "local" for SSRF prevention. In particular:
// - `::1` loopback
// - `fc00::/7` unique-local (RFC 4193)
// - `fe80::/10` link-local
// - `::` unspecified
// - multicast ranges
ip.is_loopback()
|| ip.is_unspecified()
|| ip.is_multicast()
|| ip.is_unique_local()
|| ip.is_unicast_link_local()
}
/// Normalize host fragments for policy matching (trim whitespace, strip ports/brackets, lowercase).
pub fn normalize_host(host: &str) -> String {
let host = host.trim();
if host.starts_with('[')
&& let Some(end) = host.find(']')
{
return normalize_dns_host(&host[1..end]);
}
// The proxy stack should typically hand us a host without a port, but be
// defensive and strip `:port` when there is exactly one `:`.
if host.bytes().filter(|b| *b == b':').count() == 1 {
let host = host.split(':').next().unwrap_or_default();
return normalize_dns_host(host);
}
// Avoid mangling unbracketed IPv6 literals, but strip trailing dots so fully qualified domain
// names are treated the same as their dotless variants.
normalize_dns_host(host)
}
fn normalize_dns_host(host: &str) -> String {
let host = host.to_ascii_lowercase();
host.trim_end_matches('.').to_string()
}
fn normalize_pattern(pattern: &str) -> String {
let pattern = pattern.trim();
if pattern == "*" {
return "*".to_string();
}
let (prefix, remainder) = if let Some(domain) = pattern.strip_prefix("**.") {
("**.", domain)
} else if let Some(domain) = pattern.strip_prefix("*.") {
("*.", domain)
} else {
("", pattern)
};
let remainder = normalize_host(remainder);
if prefix.is_empty() {
remainder
} else {
format!("{prefix}{remainder}")
}
}
pub(crate) fn compile_globset(patterns: &[String]) -> Result<GlobSet> {
let mut builder = GlobSetBuilder::new();
let mut seen = HashSet::new();
for pattern in patterns {
let pattern = normalize_pattern(pattern);
// Supported domain patterns:
// - "example.com": match the exact host
// - "*.example.com": match any subdomain (not the apex)
// - "**.example.com": match the apex and any subdomain
// - "*": match any host
for candidate in expand_domain_pattern(&pattern) {
if !seen.insert(candidate.clone()) {
continue;
}
let glob = GlobBuilder::new(&candidate)
.case_insensitive(true)
.build()
.with_context(|| format!("invalid glob pattern: {candidate}"))?;
builder.add(glob);
}
}
Ok(builder.build()?)
}
#[derive(Debug, Clone)]
pub(crate) enum DomainPattern {
Any,
ApexAndSubdomains(String),
SubdomainsOnly(String),
Exact(String),
}
impl DomainPattern {
/// Parse a policy pattern for constraint comparisons.
///
/// Validation of glob syntax happens when building the globset; here we only
/// decode the wildcard prefixes to keep constraint checks lightweight.
pub(crate) fn parse(input: &str) -> Self {
let input = input.trim();
if input.is_empty() {
return Self::Exact(String::new());
}
if input == "*" {
Self::Any
} else if let Some(domain) = input.strip_prefix("**.") {
Self::parse_domain(domain, Self::ApexAndSubdomains)
} else if let Some(domain) = input.strip_prefix("*.") {
Self::parse_domain(domain, Self::SubdomainsOnly)
} else {
Self::Exact(input.to_string())
}
}
/// Parse a policy pattern for constraint comparisons, validating domain parts with `url`.
pub(crate) fn parse_for_constraints(input: &str) -> Self {
let input = input.trim();
if input.is_empty() {
return Self::Exact(String::new());
}
if input == "*" {
return Self::Any;
}
if let Some(domain) = input.strip_prefix("**.") {
return Self::ApexAndSubdomains(parse_domain_for_constraints(domain));
}
if let Some(domain) = input.strip_prefix("*.") {
return Self::SubdomainsOnly(parse_domain_for_constraints(domain));
}
Self::Exact(parse_domain_for_constraints(input))
}
fn parse_domain(domain: &str, build: impl FnOnce(String) -> Self) -> Self {
let domain = domain.trim();
if domain.is_empty() {
return Self::Exact(String::new());
}
build(domain.to_string())
}
pub(crate) fn allows(&self, candidate: &DomainPattern) -> bool {
match self {
DomainPattern::Any => true,
DomainPattern::Exact(domain) => match candidate {
DomainPattern::Exact(candidate) => domain_eq(candidate, domain),
_ => false,
},
DomainPattern::SubdomainsOnly(domain) => match candidate {
DomainPattern::Any => false,
DomainPattern::Exact(candidate) => is_strict_subdomain(candidate, domain),
DomainPattern::SubdomainsOnly(candidate) => {
is_subdomain_or_equal(candidate, domain)
}
DomainPattern::ApexAndSubdomains(candidate) => {
is_strict_subdomain(candidate, domain)
}
},
DomainPattern::ApexAndSubdomains(domain) => match candidate {
DomainPattern::Any => false,
DomainPattern::Exact(candidate) => is_subdomain_or_equal(candidate, domain),
DomainPattern::SubdomainsOnly(candidate) => {
is_subdomain_or_equal(candidate, domain)
}
DomainPattern::ApexAndSubdomains(candidate) => {
is_subdomain_or_equal(candidate, domain)
}
},
}
}
}
fn parse_domain_for_constraints(domain: &str) -> String {
let domain = domain.trim().trim_end_matches('.');
if domain.is_empty() {
return String::new();
}
let host = if domain.starts_with('[') && domain.ends_with(']') {
&domain[1..domain.len().saturating_sub(1)]
} else {
domain
};
if host.contains('*') || host.contains('?') || host.contains('%') {
return domain.to_string();
}
match UrlHost::parse(host) {
Ok(host) => host.to_string(),
Err(_) => String::new(),
}
}
fn expand_domain_pattern(pattern: &str) -> Vec<String> {
match DomainPattern::parse(pattern) {
DomainPattern::Any => vec![pattern.to_string()],
DomainPattern::Exact(domain) => vec![domain],
DomainPattern::SubdomainsOnly(domain) => {
vec![format!("?*.{domain}")]
}
DomainPattern::ApexAndSubdomains(domain) => {
vec![domain.clone(), format!("?*.{domain}")]
}
}
}
fn normalize_domain(domain: &str) -> String {
domain.trim_end_matches('.').to_ascii_lowercase()
}
fn domain_eq(left: &str, right: &str) -> bool {
normalize_domain(left) == normalize_domain(right)
}
fn is_subdomain_or_equal(child: &str, parent: &str) -> bool {
let child = normalize_domain(child);
let parent = normalize_domain(parent);
if child == parent {
return true;
}
child.ends_with(&format!(".{parent}"))
}
fn is_strict_subdomain(child: &str, parent: &str) -> bool {
let child = normalize_domain(child);
let parent = normalize_domain(parent);
child != parent && child.ends_with(&format!(".{parent}"))
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn method_allowed_full_allows_everything() {
assert!(NetworkMode::Full.allows_method("GET"));
assert!(NetworkMode::Full.allows_method("POST"));
assert!(NetworkMode::Full.allows_method("CONNECT"));
}
#[test]
fn method_allowed_limited_allows_only_safe_methods() {
assert!(NetworkMode::Limited.allows_method("GET"));
assert!(NetworkMode::Limited.allows_method("HEAD"));
assert!(NetworkMode::Limited.allows_method("OPTIONS"));
assert!(!NetworkMode::Limited.allows_method("POST"));
assert!(!NetworkMode::Limited.allows_method("CONNECT"));
}
#[test]
fn compile_globset_normalizes_trailing_dots() {
let set = compile_globset(&["Example.COM.".to_string()]).unwrap();
assert_eq!(true, set.is_match("example.com"));
assert_eq!(false, set.is_match("api.example.com"));
}
#[test]
fn compile_globset_normalizes_wildcards() {
let set = compile_globset(&["*.Example.COM.".to_string()]).unwrap();
assert_eq!(true, set.is_match("api.example.com"));
assert_eq!(false, set.is_match("example.com"));
}
#[test]
fn compile_globset_normalizes_apex_and_subdomains() {
let set = compile_globset(&["**.Example.COM.".to_string()]).unwrap();
assert_eq!(true, set.is_match("example.com"));
assert_eq!(true, set.is_match("api.example.com"));
}
#[test]
fn compile_globset_normalizes_bracketed_ipv6_literals() {
let set = compile_globset(&["[::1]".to_string()]).unwrap();
assert_eq!(true, set.is_match("::1"));
}
#[test]
fn is_loopback_host_handles_localhost_variants() {
assert!(is_loopback_host(&Host::parse("localhost").unwrap()));
assert!(is_loopback_host(&Host::parse("localhost.").unwrap()));
assert!(is_loopback_host(&Host::parse("LOCALHOST").unwrap()));
assert!(!is_loopback_host(&Host::parse("notlocalhost").unwrap()));
}
#[test]
fn is_loopback_host_handles_ip_literals() {
assert!(is_loopback_host(&Host::parse("127.0.0.1").unwrap()));
assert!(is_loopback_host(&Host::parse("::1").unwrap()));
assert!(!is_loopback_host(&Host::parse("1.2.3.4").unwrap()));
}
#[test]
fn is_non_public_ip_rejects_private_and_loopback_ranges() {
assert!(is_non_public_ip("127.0.0.1".parse().unwrap()));
assert!(is_non_public_ip("10.0.0.1".parse().unwrap()));
assert!(is_non_public_ip("192.168.0.1".parse().unwrap()));
assert!(is_non_public_ip("100.64.0.1".parse().unwrap()));
assert!(is_non_public_ip("192.0.0.1".parse().unwrap()));
assert!(is_non_public_ip("192.0.2.1".parse().unwrap()));
assert!(is_non_public_ip("198.18.0.1".parse().unwrap()));
assert!(is_non_public_ip("198.51.100.1".parse().unwrap()));
assert!(is_non_public_ip("203.0.113.1".parse().unwrap()));
assert!(is_non_public_ip("240.0.0.1".parse().unwrap()));
assert!(is_non_public_ip("0.1.2.3".parse().unwrap()));
assert!(!is_non_public_ip("8.8.8.8".parse().unwrap()));
assert!(is_non_public_ip("::ffff:127.0.0.1".parse().unwrap()));
assert!(is_non_public_ip("::ffff:10.0.0.1".parse().unwrap()));
assert!(!is_non_public_ip("::ffff:8.8.8.8".parse().unwrap()));
assert!(is_non_public_ip("::1".parse().unwrap()));
assert!(is_non_public_ip("fe80::1".parse().unwrap()));
assert!(is_non_public_ip("fc00::1".parse().unwrap()));
}
#[test]
fn normalize_host_lowercases_and_trims() {
assert_eq!(normalize_host(" ExAmPlE.CoM "), "example.com");
}
#[test]
fn normalize_host_strips_port_for_host_port() {
assert_eq!(normalize_host("example.com:1234"), "example.com");
}
#[test]
fn normalize_host_preserves_unbracketed_ipv6() {
assert_eq!(normalize_host("2001:db8::1"), "2001:db8::1");
}
#[test]
fn normalize_host_strips_trailing_dot() {
assert_eq!(normalize_host("example.com."), "example.com");
assert_eq!(normalize_host("ExAmPlE.CoM."), "example.com");
}
#[test]
fn normalize_host_strips_trailing_dot_with_port() {
assert_eq!(normalize_host("example.com.:443"), "example.com");
}
#[test]
fn normalize_host_strips_brackets_for_ipv6() {
assert_eq!(normalize_host("[::1]"), "::1");
assert_eq!(normalize_host("[::1]:443"), "::1");
}
}

View File

@@ -0,0 +1,176 @@
use crate::admin;
use crate::config;
use crate::http_proxy;
use crate::network_policy::NetworkPolicyDecider;
use crate::runtime::unix_socket_permissions_supported;
use crate::state::NetworkProxyState;
use anyhow::Context;
use anyhow::Result;
use clap::Parser;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::task::JoinHandle;
use tracing::warn;
#[derive(Debug, Clone, Parser)]
#[command(name = "codex-network-proxy", about = "Codex network sandbox proxy")]
pub struct Args {}
#[derive(Clone, Default)]
pub struct NetworkProxyBuilder {
state: Option<Arc<NetworkProxyState>>,
http_addr: Option<SocketAddr>,
admin_addr: Option<SocketAddr>,
policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
}
impl NetworkProxyBuilder {
pub fn state(mut self, state: Arc<NetworkProxyState>) -> Self {
self.state = Some(state);
self
}
pub fn http_addr(mut self, addr: SocketAddr) -> Self {
self.http_addr = Some(addr);
self
}
pub fn admin_addr(mut self, addr: SocketAddr) -> Self {
self.admin_addr = Some(addr);
self
}
pub fn policy_decider<D>(mut self, decider: D) -> Self
where
D: NetworkPolicyDecider,
{
self.policy_decider = Some(Arc::new(decider));
self
}
pub fn policy_decider_arc(mut self, decider: Arc<dyn NetworkPolicyDecider>) -> Self {
self.policy_decider = Some(decider);
self
}
pub async fn build(self) -> Result<NetworkProxy> {
let state = match self.state {
Some(state) => state,
None => Arc::new(NetworkProxyState::new().await?),
};
let current_cfg = state.current_cfg().await?;
let runtime = config::resolve_runtime(&current_cfg)?;
// Reapply bind clamping for caller overrides so unix-socket proxying stays loopback-only.
let (http_addr, admin_addr) = config::clamp_bind_addrs(
self.http_addr.unwrap_or(runtime.http_addr),
self.admin_addr.unwrap_or(runtime.admin_addr),
&current_cfg.network_proxy,
);
Ok(NetworkProxy {
state,
http_addr,
admin_addr,
policy_decider: self.policy_decider,
})
}
}
#[derive(Clone)]
pub struct NetworkProxy {
state: Arc<NetworkProxyState>,
http_addr: SocketAddr,
admin_addr: SocketAddr,
policy_decider: Option<Arc<dyn NetworkPolicyDecider>>,
}
impl NetworkProxy {
pub fn builder() -> NetworkProxyBuilder {
NetworkProxyBuilder::default()
}
pub async fn run(&self) -> Result<NetworkProxyHandle> {
let current_cfg = self.state.current_cfg().await?;
if !current_cfg.network_proxy.enabled {
warn!("network_proxy.enabled is false; skipping proxy listeners");
return Ok(NetworkProxyHandle::noop());
}
if !unix_socket_permissions_supported() {
warn!("allowUnixSockets is macOS-only; requests will be rejected on this platform");
}
let http_task = tokio::spawn(http_proxy::run_http_proxy(
self.state.clone(),
self.http_addr,
self.policy_decider.clone(),
));
let admin_task = tokio::spawn(admin::run_admin_api(self.state.clone(), self.admin_addr));
Ok(NetworkProxyHandle {
http_task: Some(http_task),
admin_task: Some(admin_task),
completed: false,
})
}
}
pub struct NetworkProxyHandle {
http_task: Option<JoinHandle<Result<()>>>,
admin_task: Option<JoinHandle<Result<()>>>,
completed: bool,
}
impl NetworkProxyHandle {
fn noop() -> Self {
Self {
http_task: Some(tokio::spawn(async { Ok(()) })),
admin_task: Some(tokio::spawn(async { Ok(()) })),
completed: true,
}
}
pub async fn wait(mut self) -> Result<()> {
let http_task = self.http_task.take().context("missing http proxy task")?;
let admin_task = self.admin_task.take().context("missing admin proxy task")?;
let http_result = http_task.await;
let admin_result = admin_task.await;
self.completed = true;
http_result??;
admin_result??;
Ok(())
}
pub async fn shutdown(mut self) -> Result<()> {
abort_tasks(self.http_task.take(), self.admin_task.take()).await;
self.completed = true;
Ok(())
}
}
async fn abort_tasks(
http_task: Option<JoinHandle<Result<()>>>,
admin_task: Option<JoinHandle<Result<()>>>,
) {
if let Some(http_task) = http_task {
http_task.abort();
let _ = http_task.await;
}
if let Some(admin_task) = admin_task {
admin_task.abort();
let _ = admin_task.await;
}
}
impl Drop for NetworkProxyHandle {
fn drop(&mut self) {
if self.completed {
return;
}
let http_task = self.http_task.take();
let admin_task = self.admin_task.take();
tokio::spawn(async move {
abort_tasks(http_task, admin_task).await;
});
}
}

View File

@@ -0,0 +1,6 @@
pub(crate) const REASON_DENIED: &str = "denied";
pub(crate) const REASON_METHOD_NOT_ALLOWED: &str = "method_not_allowed";
pub(crate) const REASON_NOT_ALLOWED: &str = "not_allowed";
pub(crate) const REASON_NOT_ALLOWED_LOCAL: &str = "not_allowed_local";
pub(crate) const REASON_POLICY_DENIED: &str = "policy_denied";
pub(crate) const REASON_PROXY_DISABLED: &str = "proxy_disabled";

View File

@@ -0,0 +1,67 @@
use crate::reasons::REASON_DENIED;
use crate::reasons::REASON_METHOD_NOT_ALLOWED;
use crate::reasons::REASON_NOT_ALLOWED;
use crate::reasons::REASON_NOT_ALLOWED_LOCAL;
use rama_http::Body;
use rama_http::Response;
use rama_http::StatusCode;
use serde::Serialize;
use tracing::error;
pub fn text_response(status: StatusCode, body: &str) -> Response {
Response::builder()
.status(status)
.header("content-type", "text/plain")
.body(Body::from(body.to_string()))
.unwrap_or_else(|_| Response::new(Body::from(body.to_string())))
}
pub fn json_response<T: Serialize>(value: &T) -> Response {
let body = match serde_json::to_string(value) {
Ok(body) => body,
Err(err) => {
error!("failed to serialize JSON response: {err}");
"{}".to_string()
}
};
Response::builder()
.status(StatusCode::OK)
.header("content-type", "application/json")
.body(Body::from(body))
.unwrap_or_else(|err| {
error!("failed to build JSON response: {err}");
Response::new(Body::from("{}"))
})
}
pub fn blocked_header_value(reason: &str) -> &'static str {
match reason {
REASON_NOT_ALLOWED | REASON_NOT_ALLOWED_LOCAL => "blocked-by-allowlist",
REASON_DENIED => "blocked-by-denylist",
REASON_METHOD_NOT_ALLOWED => "blocked-by-method-policy",
_ => "blocked-by-policy",
}
}
pub fn blocked_message(reason: &str) -> &'static str {
match reason {
REASON_NOT_ALLOWED => "Codex blocked this request: domain not in allowlist.",
REASON_NOT_ALLOWED_LOCAL => {
"Codex blocked this request: local/private addresses not allowed."
}
REASON_DENIED => "Codex blocked this request: domain denied by policy.",
REASON_METHOD_NOT_ALLOWED => {
"Codex blocked this request: method not allowed in limited mode."
}
_ => "Codex blocked this request by network policy.",
}
}
pub fn blocked_text_response(reason: &str) -> Response {
Response::builder()
.status(StatusCode::FORBIDDEN)
.header("content-type", "text/plain")
.header("x-proxy-error", blocked_header_value(reason))
.body(Body::from(blocked_message(reason)))
.unwrap_or_else(|_| Response::new(Body::from("blocked")))
}

View File

@@ -0,0 +1,984 @@
use crate::config::NetworkMode;
use crate::config::NetworkProxyConfig;
use crate::policy::Host;
use crate::policy::is_loopback_host;
use crate::policy::is_non_public_ip;
use crate::policy::normalize_host;
use crate::reasons::REASON_DENIED;
use crate::reasons::REASON_NOT_ALLOWED;
use crate::reasons::REASON_NOT_ALLOWED_LOCAL;
use crate::state::NetworkProxyConstraints;
use crate::state::build_config_state;
use crate::state::validate_policy_against_constraints;
use anyhow::Context;
use anyhow::Result;
use codex_utils_absolute_path::AbsolutePathBuf;
use globset::GlobSet;
use serde::Serialize;
use std::collections::HashSet;
use std::collections::VecDeque;
use std::net::IpAddr;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use std::time::SystemTime;
use time::OffsetDateTime;
use tokio::net::lookup_host;
use tokio::sync::RwLock;
use tokio::time::timeout;
use tracing::info;
use tracing::warn;
const MAX_BLOCKED_EVENTS: usize = 200;
const DNS_LOOKUP_TIMEOUT: Duration = Duration::from_secs(2);
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HostBlockReason {
Denied,
NotAllowed,
NotAllowedLocal,
}
impl HostBlockReason {
pub const fn as_str(self) -> &'static str {
match self {
Self::Denied => REASON_DENIED,
Self::NotAllowed => REASON_NOT_ALLOWED,
Self::NotAllowedLocal => REASON_NOT_ALLOWED_LOCAL,
}
}
}
impl std::fmt::Display for HostBlockReason {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HostBlockDecision {
Allowed,
Blocked(HostBlockReason),
}
#[derive(Clone, Debug, Serialize)]
pub struct BlockedRequest {
pub host: String,
pub reason: String,
pub client: Option<String>,
pub method: Option<String>,
pub mode: Option<NetworkMode>,
pub protocol: String,
pub timestamp: i64,
}
impl BlockedRequest {
pub fn new(
host: String,
reason: String,
client: Option<String>,
method: Option<String>,
mode: Option<NetworkMode>,
protocol: String,
) -> Self {
Self {
host,
reason,
client,
method,
mode,
protocol,
timestamp: unix_timestamp(),
}
}
}
#[derive(Clone)]
pub(crate) struct ConfigState {
pub(crate) config: NetworkProxyConfig,
pub(crate) allow_set: GlobSet,
pub(crate) deny_set: GlobSet,
pub(crate) constraints: NetworkProxyConstraints,
pub(crate) layer_mtimes: Vec<LayerMtime>,
pub(crate) cfg_path: PathBuf,
pub(crate) blocked: VecDeque<BlockedRequest>,
}
#[derive(Clone)]
pub(crate) struct LayerMtime {
pub(crate) path: PathBuf,
pub(crate) mtime: Option<SystemTime>,
}
impl LayerMtime {
pub(crate) fn new(path: PathBuf) -> Self {
let mtime = path.metadata().and_then(|m| m.modified()).ok();
Self { path, mtime }
}
}
#[derive(Clone)]
pub struct NetworkProxyState {
state: Arc<RwLock<ConfigState>>,
}
impl std::fmt::Debug for NetworkProxyState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Avoid logging internal state (config contents, derived globsets, etc.) which can be noisy
// and may contain sensitive paths.
f.debug_struct("NetworkProxyState").finish_non_exhaustive()
}
}
impl NetworkProxyState {
pub async fn new() -> Result<Self> {
let cfg_state = build_config_state().await?;
Ok(Self {
state: Arc::new(RwLock::new(cfg_state)),
})
}
pub async fn current_cfg(&self) -> Result<NetworkProxyConfig> {
// Callers treat `NetworkProxyState` as a live view of policy. We reload-on-demand so edits to
// `config.toml` (including Codex-managed writes) take effect without a restart.
self.reload_if_needed().await?;
let guard = self.state.read().await;
Ok(guard.config.clone())
}
pub async fn current_patterns(&self) -> Result<(Vec<String>, Vec<String>)> {
self.reload_if_needed().await?;
let guard = self.state.read().await;
Ok((
guard.config.network_proxy.policy.allowed_domains.clone(),
guard.config.network_proxy.policy.denied_domains.clone(),
))
}
pub async fn enabled(&self) -> Result<bool> {
self.reload_if_needed().await?;
let guard = self.state.read().await;
Ok(guard.config.network_proxy.enabled)
}
pub async fn force_reload(&self) -> Result<()> {
let (previous_cfg, cfg_path) = {
let guard = self.state.read().await;
(guard.config.clone(), guard.cfg_path.clone())
};
match build_config_state().await {
Ok(mut new_state) => {
// Policy changes are operationally sensitive; logging diffs makes changes traceable
// without needing to dump full config blobs (which can include unrelated settings).
log_policy_changes(&previous_cfg, &new_state.config);
let mut guard = self.state.write().await;
new_state.blocked = guard.blocked.clone();
*guard = new_state;
let path = guard.cfg_path.display();
info!("reloaded config from {path}");
Ok(())
}
Err(err) => {
let path = cfg_path.display();
warn!("failed to reload config from {path}: {err}; keeping previous config");
Err(err)
}
}
}
pub async fn host_blocked(&self, host: &str, port: u16) -> Result<HostBlockDecision> {
self.reload_if_needed().await?;
let host = match Host::parse(host) {
Ok(host) => host,
Err(_) => return Ok(HostBlockDecision::Blocked(HostBlockReason::NotAllowed)),
};
let (deny_set, allow_set, allow_local_binding, allowed_domains_empty, allowed_domains) = {
let guard = self.state.read().await;
(
guard.deny_set.clone(),
guard.allow_set.clone(),
guard.config.network_proxy.policy.allow_local_binding,
guard.config.network_proxy.policy.allowed_domains.is_empty(),
guard.config.network_proxy.policy.allowed_domains.clone(),
)
};
let host_str = host.as_str();
// Decision order matters:
// 1) explicit deny always wins
// 2) local/private networking is opt-in (defense-in-depth)
// 3) allowlist is enforced when configured
if deny_set.is_match(host_str) {
return Ok(HostBlockDecision::Blocked(HostBlockReason::Denied));
}
let is_allowlisted = allow_set.is_match(host_str);
if !allow_local_binding {
// If the intent is "prevent access to local/internal networks", we must not rely solely
// on string checks like `localhost` / `127.0.0.1`. Attackers can use DNS rebinding or
// public suffix services that map hostnames onto private IPs.
//
// We therefore do a best-effort DNS + IP classification check before allowing the
// request. Explicit local/loopback literals are allowed only when explicitly
// allowlisted; hostnames that resolve to local/private IPs are blocked even if
// allowlisted.
let local_literal = {
let host_no_scope = host_str
.split_once('%')
.map(|(ip, _)| ip)
.unwrap_or(host_str);
if is_loopback_host(&host) {
true
} else if let Ok(ip) = host_no_scope.parse::<IpAddr>() {
is_non_public_ip(ip)
} else {
false
}
};
if local_literal {
if !is_explicit_local_allowlisted(&allowed_domains, &host) {
return Ok(HostBlockDecision::Blocked(HostBlockReason::NotAllowedLocal));
}
} else if host_resolves_to_non_public_ip(host_str, port).await {
return Ok(HostBlockDecision::Blocked(HostBlockReason::NotAllowedLocal));
}
}
if allowed_domains_empty || !is_allowlisted {
Ok(HostBlockDecision::Blocked(HostBlockReason::NotAllowed))
} else {
Ok(HostBlockDecision::Allowed)
}
}
pub async fn record_blocked(&self, entry: BlockedRequest) -> Result<()> {
self.reload_if_needed().await?;
let mut guard = self.state.write().await;
guard.blocked.push_back(entry);
while guard.blocked.len() > MAX_BLOCKED_EVENTS {
guard.blocked.pop_front();
}
Ok(())
}
/// Drain and return the buffered blocked-request entries in FIFO order.
pub async fn drain_blocked(&self) -> Result<Vec<BlockedRequest>> {
self.reload_if_needed().await?;
let blocked = {
let mut guard = self.state.write().await;
std::mem::take(&mut guard.blocked)
};
Ok(blocked.into_iter().collect())
}
pub async fn is_unix_socket_allowed(&self, path: &str) -> Result<bool> {
self.reload_if_needed().await?;
if !unix_socket_permissions_supported() {
return Ok(false);
}
// We only support absolute unix socket paths (a relative path would be ambiguous with
// respect to the proxy process's CWD and can lead to confusing allowlist behavior).
let requested_path = Path::new(path);
if !requested_path.is_absolute() {
return Ok(false);
}
let guard = self.state.read().await;
// Normalize the path while keeping the absolute-path requirement explicit.
let requested_abs = match AbsolutePathBuf::from_absolute_path(requested_path) {
Ok(path) => path,
Err(_) => return Ok(false),
};
let requested_canonical = std::fs::canonicalize(requested_abs.as_path()).ok();
for allowed in &guard.config.network_proxy.policy.allow_unix_sockets {
if allowed == path {
return Ok(true);
}
// Best-effort canonicalization to reduce surprises with symlinks.
// If canonicalization fails (e.g., socket not created yet), fall back to raw comparison.
let Some(requested_canonical) = &requested_canonical else {
continue;
};
if let Ok(allowed_canonical) = std::fs::canonicalize(allowed)
&& &allowed_canonical == requested_canonical
{
return Ok(true);
}
}
Ok(false)
}
pub async fn method_allowed(&self, method: &str) -> Result<bool> {
self.reload_if_needed().await?;
let guard = self.state.read().await;
Ok(guard.config.network_proxy.mode.allows_method(method))
}
pub async fn allow_upstream_proxy(&self) -> Result<bool> {
self.reload_if_needed().await?;
let guard = self.state.read().await;
Ok(guard.config.network_proxy.allow_upstream_proxy)
}
pub async fn network_mode(&self) -> Result<NetworkMode> {
self.reload_if_needed().await?;
let guard = self.state.read().await;
Ok(guard.config.network_proxy.mode)
}
pub async fn set_network_mode(&self, mode: NetworkMode) -> Result<()> {
loop {
self.reload_if_needed().await?;
let (candidate, constraints) = {
let guard = self.state.read().await;
let mut candidate = guard.config.clone();
candidate.network_proxy.mode = mode;
(candidate, guard.constraints.clone())
};
validate_policy_against_constraints(&candidate, &constraints)
.context("network_proxy.mode constrained by managed config")?;
let mut guard = self.state.write().await;
if guard.constraints != constraints {
drop(guard);
continue;
}
guard.config.network_proxy.mode = mode;
info!("updated network mode to {mode:?}");
return Ok(());
}
}
async fn reload_if_needed(&self) -> Result<()> {
let needs_reload = {
let guard = self.state.read().await;
guard.layer_mtimes.iter().any(|layer| {
let metadata = std::fs::metadata(&layer.path).ok();
match (metadata.and_then(|m| m.modified().ok()), layer.mtime) {
(Some(new_mtime), Some(old_mtime)) => new_mtime > old_mtime,
(Some(_), None) => true,
(None, Some(_)) => true,
(None, None) => false,
}
})
};
if !needs_reload {
return Ok(());
}
self.force_reload().await
}
}
pub(crate) fn unix_socket_permissions_supported() -> bool {
cfg!(target_os = "macos")
}
async fn host_resolves_to_non_public_ip(host: &str, port: u16) -> bool {
if let Ok(ip) = host.parse::<IpAddr>() {
return is_non_public_ip(ip);
}
// If DNS lookup fails, default to "not local/private" rather than blocking. In practice, the
// subsequent connect attempt will fail anyway, and blocking on transient resolver issues would
// make the proxy fragile. The allowlist/denylist remains the primary control plane.
let addrs = match timeout(DNS_LOOKUP_TIMEOUT, lookup_host((host, port))).await {
Ok(Ok(addrs)) => addrs,
Ok(Err(_)) | Err(_) => return false,
};
for addr in addrs {
if is_non_public_ip(addr.ip()) {
return true;
}
}
false
}
fn log_policy_changes(previous: &NetworkProxyConfig, next: &NetworkProxyConfig) {
log_domain_list_changes(
"allowlist",
&previous.network_proxy.policy.allowed_domains,
&next.network_proxy.policy.allowed_domains,
);
log_domain_list_changes(
"denylist",
&previous.network_proxy.policy.denied_domains,
&next.network_proxy.policy.denied_domains,
);
}
fn log_domain_list_changes(list_name: &str, previous: &[String], next: &[String]) {
let previous_set: HashSet<String> = previous
.iter()
.map(|entry| entry.to_ascii_lowercase())
.collect();
let next_set: HashSet<String> = next
.iter()
.map(|entry| entry.to_ascii_lowercase())
.collect();
let added = next_set
.difference(&previous_set)
.cloned()
.collect::<HashSet<_>>();
let removed = previous_set
.difference(&next_set)
.cloned()
.collect::<HashSet<_>>();
let mut seen_next = HashSet::new();
for entry in next {
let key = entry.to_ascii_lowercase();
if seen_next.insert(key.clone()) && added.contains(&key) {
info!("config entry added to {list_name}: {entry}");
}
}
let mut seen_previous = HashSet::new();
for entry in previous {
let key = entry.to_ascii_lowercase();
if seen_previous.insert(key.clone()) && removed.contains(&key) {
info!("config entry removed from {list_name}: {entry}");
}
}
}
fn is_explicit_local_allowlisted(allowed_domains: &[String], host: &Host) -> bool {
let normalized_host = host.as_str();
allowed_domains.iter().any(|pattern| {
let pattern = pattern.trim();
if pattern == "*" || pattern.starts_with("*.") || pattern.starts_with("**.") {
return false;
}
if pattern.contains('*') || pattern.contains('?') {
return false;
}
normalize_host(pattern) == normalized_host
})
}
fn unix_timestamp() -> i64 {
OffsetDateTime::now_utc().unix_timestamp()
}
#[cfg(test)]
pub(crate) fn network_proxy_state_for_policy(
policy: crate::config::NetworkPolicy,
) -> NetworkProxyState {
let config = NetworkProxyConfig {
network_proxy: crate::config::NetworkProxySettings {
enabled: true,
mode: NetworkMode::Full,
policy,
..crate::config::NetworkProxySettings::default()
},
};
let allow_set =
crate::policy::compile_globset(&config.network_proxy.policy.allowed_domains).unwrap();
let deny_set =
crate::policy::compile_globset(&config.network_proxy.policy.denied_domains).unwrap();
let state = ConfigState {
config,
allow_set,
deny_set,
constraints: NetworkProxyConstraints::default(),
layer_mtimes: Vec::new(),
cfg_path: PathBuf::from("/nonexistent/config.toml"),
blocked: VecDeque::new(),
};
NetworkProxyState {
state: Arc::new(RwLock::new(state)),
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::NetworkPolicy;
use crate::config::NetworkProxyConfig;
use crate::config::NetworkProxySettings;
use crate::policy::compile_globset;
use crate::state::NetworkProxyConstraints;
use crate::state::validate_policy_against_constraints;
use pretty_assertions::assert_eq;
#[tokio::test]
async fn host_blocked_denied_wins_over_allowed() {
let state = network_proxy_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
denied_domains: vec!["example.com".to_string()],
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("example.com", 80).await.unwrap(),
HostBlockDecision::Blocked(HostBlockReason::Denied)
);
}
#[tokio::test]
async fn host_blocked_requires_allowlist_match() {
let state = network_proxy_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("example.com", 80).await.unwrap(),
HostBlockDecision::Allowed
);
assert_eq!(
// Use a public IP literal to avoid relying on ambient DNS behavior (some networks
// resolve unknown hostnames to private IPs, which would trigger `not_allowed_local`).
state.host_blocked("8.8.8.8", 80).await.unwrap(),
HostBlockDecision::Blocked(HostBlockReason::NotAllowed)
);
}
#[tokio::test]
async fn host_blocked_subdomain_wildcards_exclude_apex() {
let state = network_proxy_state_for_policy(NetworkPolicy {
allowed_domains: vec!["*.openai.com".to_string()],
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("api.openai.com", 80).await.unwrap(),
HostBlockDecision::Allowed
);
assert_eq!(
state.host_blocked("openai.com", 80).await.unwrap(),
HostBlockDecision::Blocked(HostBlockReason::NotAllowed)
);
}
#[tokio::test]
async fn host_blocked_rejects_loopback_when_local_binding_disabled() {
let state = network_proxy_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
allow_local_binding: false,
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("127.0.0.1", 80).await.unwrap(),
HostBlockDecision::Blocked(HostBlockReason::NotAllowedLocal)
);
assert_eq!(
state.host_blocked("localhost", 80).await.unwrap(),
HostBlockDecision::Blocked(HostBlockReason::NotAllowedLocal)
);
}
#[tokio::test]
async fn host_blocked_rejects_loopback_when_allowlist_is_wildcard() {
let state = network_proxy_state_for_policy(NetworkPolicy {
allowed_domains: vec!["*".to_string()],
allow_local_binding: false,
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("127.0.0.1", 80).await.unwrap(),
HostBlockDecision::Blocked(HostBlockReason::NotAllowedLocal)
);
}
#[tokio::test]
async fn host_blocked_rejects_private_ip_literal_when_allowlist_is_wildcard() {
let state = network_proxy_state_for_policy(NetworkPolicy {
allowed_domains: vec!["*".to_string()],
allow_local_binding: false,
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("10.0.0.1", 80).await.unwrap(),
HostBlockDecision::Blocked(HostBlockReason::NotAllowedLocal)
);
}
#[tokio::test]
async fn host_blocked_allows_loopback_when_explicitly_allowlisted_and_local_binding_disabled() {
let state = network_proxy_state_for_policy(NetworkPolicy {
allowed_domains: vec!["localhost".to_string()],
allow_local_binding: false,
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("localhost", 80).await.unwrap(),
HostBlockDecision::Allowed
);
}
#[tokio::test]
async fn host_blocked_allows_private_ip_literal_when_explicitly_allowlisted() {
let state = network_proxy_state_for_policy(NetworkPolicy {
allowed_domains: vec!["10.0.0.1".to_string()],
allow_local_binding: false,
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("10.0.0.1", 80).await.unwrap(),
HostBlockDecision::Allowed
);
}
#[tokio::test]
async fn host_blocked_rejects_scoped_ipv6_literal_when_not_allowlisted() {
let state = network_proxy_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
allow_local_binding: false,
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("fe80::1%lo0", 80).await.unwrap(),
HostBlockDecision::Blocked(HostBlockReason::NotAllowedLocal)
);
}
#[tokio::test]
async fn host_blocked_allows_scoped_ipv6_literal_when_explicitly_allowlisted() {
let state = network_proxy_state_for_policy(NetworkPolicy {
allowed_domains: vec!["fe80::1%lo0".to_string()],
allow_local_binding: false,
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("fe80::1%lo0", 80).await.unwrap(),
HostBlockDecision::Allowed
);
}
#[tokio::test]
async fn host_blocked_rejects_private_ip_literals_when_local_binding_disabled() {
let state = network_proxy_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
allow_local_binding: false,
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("10.0.0.1", 80).await.unwrap(),
HostBlockDecision::Blocked(HostBlockReason::NotAllowedLocal)
);
}
#[tokio::test]
async fn host_blocked_rejects_loopback_when_allowlist_empty() {
let state = network_proxy_state_for_policy(NetworkPolicy {
allowed_domains: vec![],
allow_local_binding: false,
..NetworkPolicy::default()
});
assert_eq!(
state.host_blocked("127.0.0.1", 80).await.unwrap(),
HostBlockDecision::Blocked(HostBlockReason::NotAllowedLocal)
);
}
#[test]
fn validate_policy_against_constraints_disallows_widening_allowed_domains() {
let constraints = NetworkProxyConstraints {
allowed_domains: Some(vec!["example.com".to_string()]),
..NetworkProxyConstraints::default()
};
let config = NetworkProxyConfig {
network_proxy: NetworkProxySettings {
enabled: true,
policy: NetworkPolicy {
allowed_domains: vec!["example.com".to_string(), "evil.com".to_string()],
..NetworkPolicy::default()
},
..NetworkProxySettings::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}
#[test]
fn validate_policy_against_constraints_disallows_widening_mode() {
let constraints = NetworkProxyConstraints {
mode: Some(NetworkMode::Limited),
..NetworkProxyConstraints::default()
};
let config = NetworkProxyConfig {
network_proxy: NetworkProxySettings {
enabled: true,
mode: NetworkMode::Full,
..NetworkProxySettings::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}
#[test]
fn validate_policy_against_constraints_allows_narrowing_wildcard_allowlist() {
let constraints = NetworkProxyConstraints {
allowed_domains: Some(vec!["*.example.com".to_string()]),
..NetworkProxyConstraints::default()
};
let config = NetworkProxyConfig {
network_proxy: NetworkProxySettings {
enabled: true,
policy: NetworkPolicy {
allowed_domains: vec!["api.example.com".to_string()],
..NetworkPolicy::default()
},
..NetworkProxySettings::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_ok());
}
#[test]
fn validate_policy_against_constraints_rejects_widening_wildcard_allowlist() {
let constraints = NetworkProxyConstraints {
allowed_domains: Some(vec!["*.example.com".to_string()]),
..NetworkProxyConstraints::default()
};
let config = NetworkProxyConfig {
network_proxy: NetworkProxySettings {
enabled: true,
policy: NetworkPolicy {
allowed_domains: vec!["**.example.com".to_string()],
..NetworkPolicy::default()
},
..NetworkProxySettings::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}
#[test]
fn validate_policy_against_constraints_requires_managed_denied_domains_entries() {
let constraints = NetworkProxyConstraints {
denied_domains: Some(vec!["evil.com".to_string()]),
..NetworkProxyConstraints::default()
};
let config = NetworkProxyConfig {
network_proxy: NetworkProxySettings {
enabled: true,
policy: NetworkPolicy {
denied_domains: vec![],
..NetworkPolicy::default()
},
..NetworkProxySettings::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}
#[test]
fn validate_policy_against_constraints_disallows_enabling_when_managed_disabled() {
let constraints = NetworkProxyConstraints {
enabled: Some(false),
..NetworkProxyConstraints::default()
};
let config = NetworkProxyConfig {
network_proxy: NetworkProxySettings {
enabled: true,
..NetworkProxySettings::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}
#[test]
fn validate_policy_against_constraints_disallows_allow_local_binding_when_managed_disabled() {
let constraints = NetworkProxyConstraints {
allow_local_binding: Some(false),
..NetworkProxyConstraints::default()
};
let config = NetworkProxyConfig {
network_proxy: NetworkProxySettings {
enabled: true,
policy: NetworkPolicy {
allow_local_binding: true,
..NetworkPolicy::default()
},
..NetworkProxySettings::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}
#[test]
fn validate_policy_against_constraints_disallows_non_loopback_admin_without_managed_opt_in() {
let constraints = NetworkProxyConstraints {
dangerously_allow_non_loopback_admin: Some(false),
..NetworkProxyConstraints::default()
};
let config = NetworkProxyConfig {
network_proxy: NetworkProxySettings {
enabled: true,
dangerously_allow_non_loopback_admin: true,
..NetworkProxySettings::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_err());
}
#[test]
fn validate_policy_against_constraints_allows_non_loopback_admin_with_managed_opt_in() {
let constraints = NetworkProxyConstraints {
dangerously_allow_non_loopback_admin: Some(true),
..NetworkProxyConstraints::default()
};
let config = NetworkProxyConfig {
network_proxy: NetworkProxySettings {
enabled: true,
dangerously_allow_non_loopback_admin: true,
..NetworkProxySettings::default()
},
};
assert!(validate_policy_against_constraints(&config, &constraints).is_ok());
}
#[test]
fn compile_globset_is_case_insensitive() {
let patterns = vec!["ExAmPle.CoM".to_string()];
let set = compile_globset(&patterns).unwrap();
assert!(set.is_match("example.com"));
assert!(set.is_match("EXAMPLE.COM"));
}
#[test]
fn compile_globset_excludes_apex_for_subdomain_patterns() {
let patterns = vec!["*.openai.com".to_string()];
let set = compile_globset(&patterns).unwrap();
assert!(set.is_match("api.openai.com"));
assert!(!set.is_match("openai.com"));
assert!(!set.is_match("evilopenai.com"));
}
#[test]
fn compile_globset_includes_apex_for_double_wildcard_patterns() {
let patterns = vec!["**.openai.com".to_string()];
let set = compile_globset(&patterns).unwrap();
assert!(set.is_match("openai.com"));
assert!(set.is_match("api.openai.com"));
assert!(!set.is_match("evilopenai.com"));
}
#[test]
fn compile_globset_matches_all_with_star() {
let patterns = vec!["*".to_string()];
let set = compile_globset(&patterns).unwrap();
assert!(set.is_match("openai.com"));
assert!(set.is_match("api.openai.com"));
}
#[test]
fn compile_globset_dedupes_patterns_without_changing_behavior() {
let patterns = vec!["example.com".to_string(), "example.com".to_string()];
let set = compile_globset(&patterns).unwrap();
assert!(set.is_match("example.com"));
assert!(set.is_match("EXAMPLE.COM"));
assert!(!set.is_match("not-example.com"));
}
#[test]
fn compile_globset_rejects_invalid_patterns() {
let patterns = vec!["[".to_string()];
assert!(compile_globset(&patterns).is_err());
}
#[cfg(target_os = "macos")]
#[tokio::test]
async fn unix_socket_allowlist_is_respected_on_macos() {
let socket_path = "/tmp/example.sock".to_string();
let state = network_proxy_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
allow_unix_sockets: vec![socket_path.clone()],
..NetworkPolicy::default()
});
assert!(state.is_unix_socket_allowed(&socket_path).await.unwrap());
assert!(
!state
.is_unix_socket_allowed("/tmp/not-allowed.sock")
.await
.unwrap()
);
}
#[cfg(target_os = "macos")]
#[tokio::test]
async fn unix_socket_allowlist_resolves_symlinks() {
use std::os::unix::fs::symlink;
use tempfile::tempdir;
let temp_dir = tempdir().unwrap();
let dir = temp_dir.path();
let real = dir.join("real.sock");
let link = dir.join("link.sock");
// The allowlist mechanism is path-based; for test purposes we don't need an actual unix
// domain socket. Any filesystem entry works for canonicalization.
std::fs::write(&real, b"not a socket").unwrap();
symlink(&real, &link).unwrap();
let real_s = real.to_str().unwrap().to_string();
let link_s = link.to_str().unwrap().to_string();
let state = network_proxy_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
allow_unix_sockets: vec![real_s],
..NetworkPolicy::default()
});
assert!(state.is_unix_socket_allowed(&link_s).await.unwrap());
}
#[cfg(not(target_os = "macos"))]
#[tokio::test]
async fn unix_socket_allowlist_is_rejected_on_non_macos() {
let socket_path = "/tmp/example.sock".to_string();
let state = network_proxy_state_for_policy(NetworkPolicy {
allowed_domains: vec!["example.com".to_string()],
allow_unix_sockets: vec![socket_path.clone()],
..NetworkPolicy::default()
});
assert!(!state.is_unix_socket_allowed(&socket_path).await.unwrap());
}
}

View File

@@ -0,0 +1,419 @@
use crate::config::NetworkMode;
use crate::config::NetworkProxyConfig;
use crate::policy::DomainPattern;
use crate::policy::compile_globset;
use crate::runtime::ConfigState;
use crate::runtime::LayerMtime;
use anyhow::Context;
use anyhow::Result;
use codex_app_server_protocol::ConfigLayerSource;
use codex_core::config::CONFIG_TOML_FILE;
use codex_core::config::Constrained;
use codex_core::config::ConstraintError;
use codex_core::config::find_codex_home;
use codex_core::config_loader::ConfigLayerStack;
use codex_core::config_loader::ConfigLayerStackOrdering;
use codex_core::config_loader::LoaderOverrides;
use codex_core::config_loader::RequirementSource;
use codex_core::config_loader::load_config_layers_state;
use serde::Deserialize;
use std::collections::HashSet;
pub use crate::runtime::BlockedRequest;
pub use crate::runtime::NetworkProxyState;
#[cfg(test)]
pub(crate) use crate::runtime::network_proxy_state_for_policy;
pub(crate) async fn build_config_state() -> Result<ConfigState> {
// Load config through `codex-core` so we inherit the same layer ordering and semantics as the
// rest of Codex (system/managed layers, user layers, session flags, etc.).
let codex_home = find_codex_home().context("failed to resolve CODEX_HOME")?;
let cli_overrides = Vec::new();
let overrides = LoaderOverrides::default();
let config_layer_stack = load_config_layers_state(&codex_home, None, &cli_overrides, overrides)
.await
.context("failed to load Codex config")?;
let cfg_path = codex_home.join(CONFIG_TOML_FILE);
// Deserialize from the merged effective config, rather than parsing config.toml ourselves.
// This avoids a second parser/merger implementation (and the drift that comes with it).
let merged_toml = config_layer_stack.effective_config();
let config: NetworkProxyConfig = merged_toml
.try_into()
.context("failed to deserialize network proxy config")?;
// Security boundary: user-controlled layers must not be able to widen restrictions set by
// trusted/managed layers (e.g., MDM). Enforce this before building runtime state.
let constraints = enforce_trusted_constraints(&config_layer_stack, &config)?;
let layer_mtimes = collect_layer_mtimes(&config_layer_stack);
let deny_set = compile_globset(&config.network_proxy.policy.denied_domains)?;
let allow_set = compile_globset(&config.network_proxy.policy.allowed_domains)?;
Ok(ConfigState {
config,
allow_set,
deny_set,
constraints,
layer_mtimes,
cfg_path,
blocked: std::collections::VecDeque::new(),
})
}
fn collect_layer_mtimes(stack: &ConfigLayerStack) -> Vec<LayerMtime> {
stack
.get_layers(ConfigLayerStackOrdering::LowestPrecedenceFirst, false)
.iter()
.filter_map(|layer| {
let path = match &layer.name {
ConfigLayerSource::System { file } => Some(file.as_path().to_path_buf()),
ConfigLayerSource::User { file } => Some(file.as_path().to_path_buf()),
ConfigLayerSource::Project { dot_codex_folder } => dot_codex_folder
.join(CONFIG_TOML_FILE)
.ok()
.map(|p| p.as_path().to_path_buf()),
ConfigLayerSource::LegacyManagedConfigTomlFromFile { file } => {
Some(file.as_path().to_path_buf())
}
_ => None,
};
path.map(LayerMtime::new)
})
.collect()
}
#[derive(Debug, Default, Deserialize)]
struct PartialConfig {
#[serde(default)]
network_proxy: PartialNetworkProxyConfig,
}
#[derive(Debug, Default, Deserialize)]
struct PartialNetworkProxyConfig {
enabled: Option<bool>,
mode: Option<NetworkMode>,
allow_upstream_proxy: Option<bool>,
dangerously_allow_non_loopback_proxy: Option<bool>,
dangerously_allow_non_loopback_admin: Option<bool>,
#[serde(default)]
policy: PartialNetworkPolicy,
}
#[derive(Debug, Default, Deserialize)]
struct PartialNetworkPolicy {
#[serde(default)]
allowed_domains: Option<Vec<String>>,
#[serde(default)]
denied_domains: Option<Vec<String>>,
#[serde(default)]
allow_unix_sockets: Option<Vec<String>>,
#[serde(default)]
allow_local_binding: Option<bool>,
}
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub(crate) struct NetworkProxyConstraints {
pub(crate) enabled: Option<bool>,
pub(crate) mode: Option<NetworkMode>,
pub(crate) allow_upstream_proxy: Option<bool>,
pub(crate) dangerously_allow_non_loopback_proxy: Option<bool>,
pub(crate) dangerously_allow_non_loopback_admin: Option<bool>,
pub(crate) allowed_domains: Option<Vec<String>>,
pub(crate) denied_domains: Option<Vec<String>>,
pub(crate) allow_unix_sockets: Option<Vec<String>>,
pub(crate) allow_local_binding: Option<bool>,
}
fn enforce_trusted_constraints(
layers: &codex_core::config_loader::ConfigLayerStack,
config: &NetworkProxyConfig,
) -> Result<NetworkProxyConstraints> {
let constraints = network_proxy_constraints_from_trusted_layers(layers)?;
validate_policy_against_constraints(config, &constraints)
.context("network proxy constraints")?;
Ok(constraints)
}
fn network_proxy_constraints_from_trusted_layers(
layers: &codex_core::config_loader::ConfigLayerStack,
) -> Result<NetworkProxyConstraints> {
let mut constraints = NetworkProxyConstraints::default();
for layer in layers.get_layers(
codex_core::config_loader::ConfigLayerStackOrdering::LowestPrecedenceFirst,
false,
) {
// Only trusted layers contribute constraints. User-controlled layers can narrow policy but
// must never widen beyond what managed config allows.
if is_user_controlled_layer(&layer.name) {
continue;
}
let partial: PartialConfig = layer
.config
.clone()
.try_into()
.context("failed to deserialize trusted config layer")?;
if let Some(enabled) = partial.network_proxy.enabled {
constraints.enabled = Some(enabled);
}
if let Some(mode) = partial.network_proxy.mode {
constraints.mode = Some(mode);
}
if let Some(allow_upstream_proxy) = partial.network_proxy.allow_upstream_proxy {
constraints.allow_upstream_proxy = Some(allow_upstream_proxy);
}
if let Some(dangerously_allow_non_loopback_proxy) =
partial.network_proxy.dangerously_allow_non_loopback_proxy
{
constraints.dangerously_allow_non_loopback_proxy =
Some(dangerously_allow_non_loopback_proxy);
}
if let Some(dangerously_allow_non_loopback_admin) =
partial.network_proxy.dangerously_allow_non_loopback_admin
{
constraints.dangerously_allow_non_loopback_admin =
Some(dangerously_allow_non_loopback_admin);
}
if let Some(allowed_domains) = partial.network_proxy.policy.allowed_domains {
constraints.allowed_domains = Some(allowed_domains);
}
if let Some(denied_domains) = partial.network_proxy.policy.denied_domains {
constraints.denied_domains = Some(denied_domains);
}
if let Some(allow_unix_sockets) = partial.network_proxy.policy.allow_unix_sockets {
constraints.allow_unix_sockets = Some(allow_unix_sockets);
}
if let Some(allow_local_binding) = partial.network_proxy.policy.allow_local_binding {
constraints.allow_local_binding = Some(allow_local_binding);
}
}
Ok(constraints)
}
fn is_user_controlled_layer(layer: &ConfigLayerSource) -> bool {
matches!(
layer,
ConfigLayerSource::User { .. }
| ConfigLayerSource::Project { .. }
| ConfigLayerSource::SessionFlags
)
}
pub(crate) fn validate_policy_against_constraints(
config: &NetworkProxyConfig,
constraints: &NetworkProxyConstraints,
) -> std::result::Result<(), ConstraintError> {
fn invalid_value(
field_name: &'static str,
candidate: impl Into<String>,
allowed: impl Into<String>,
) -> ConstraintError {
ConstraintError::InvalidValue {
field_name,
candidate: candidate.into(),
allowed: allowed.into(),
requirement_source: RequirementSource::Unknown,
}
}
let enabled = config.network_proxy.enabled;
if let Some(max_enabled) = constraints.enabled {
let _ = Constrained::new(enabled, move |candidate| {
if *candidate && !max_enabled {
Err(invalid_value(
"network_proxy.enabled",
"true",
"false (disabled by managed config)",
))
} else {
Ok(())
}
})?;
}
if let Some(max_mode) = constraints.mode {
let _ = Constrained::new(config.network_proxy.mode, move |candidate| {
if network_mode_rank(*candidate) > network_mode_rank(max_mode) {
Err(invalid_value(
"network_proxy.mode",
format!("{candidate:?}"),
format!("{max_mode:?} or more restrictive"),
))
} else {
Ok(())
}
})?;
}
let allow_upstream_proxy = constraints.allow_upstream_proxy;
let _ = Constrained::new(
config.network_proxy.allow_upstream_proxy,
move |candidate| match allow_upstream_proxy {
Some(true) | None => Ok(()),
Some(false) => {
if *candidate {
Err(invalid_value(
"network_proxy.allow_upstream_proxy",
"true",
"false (disabled by managed config)",
))
} else {
Ok(())
}
}
},
)?;
let allow_non_loopback_admin = constraints.dangerously_allow_non_loopback_admin;
let _ = Constrained::new(
config.network_proxy.dangerously_allow_non_loopback_admin,
move |candidate| match allow_non_loopback_admin {
Some(true) | None => Ok(()),
Some(false) => {
if *candidate {
Err(invalid_value(
"network_proxy.dangerously_allow_non_loopback_admin",
"true",
"false (disabled by managed config)",
))
} else {
Ok(())
}
}
},
)?;
let allow_non_loopback_proxy = constraints.dangerously_allow_non_loopback_proxy;
let _ = Constrained::new(
config.network_proxy.dangerously_allow_non_loopback_proxy,
move |candidate| match allow_non_loopback_proxy {
Some(true) | None => Ok(()),
Some(false) => {
if *candidate {
Err(invalid_value(
"network_proxy.dangerously_allow_non_loopback_proxy",
"true",
"false (disabled by managed config)",
))
} else {
Ok(())
}
}
},
)?;
if let Some(allow_local_binding) = constraints.allow_local_binding {
let _ = Constrained::new(
config.network_proxy.policy.allow_local_binding,
move |candidate| {
if *candidate && !allow_local_binding {
Err(invalid_value(
"network_proxy.policy.allow_local_binding",
"true",
"false (disabled by managed config)",
))
} else {
Ok(())
}
},
)?;
}
if let Some(allowed_domains) = &constraints.allowed_domains {
let managed_patterns: Vec<DomainPattern> = allowed_domains
.iter()
.map(|entry| DomainPattern::parse_for_constraints(entry))
.collect();
let _ = Constrained::new(
config.network_proxy.policy.allowed_domains.clone(),
move |candidate| {
let mut invalid = Vec::new();
for entry in candidate {
let candidate_pattern = DomainPattern::parse_for_constraints(entry);
if !managed_patterns
.iter()
.any(|managed| managed.allows(&candidate_pattern))
{
invalid.push(entry.clone());
}
}
if invalid.is_empty() {
Ok(())
} else {
Err(invalid_value(
"network_proxy.policy.allowed_domains",
format!("{invalid:?}"),
"subset of managed allowed_domains",
))
}
},
)?;
}
if let Some(denied_domains) = &constraints.denied_domains {
let required_set: HashSet<String> = denied_domains
.iter()
.map(|s| s.to_ascii_lowercase())
.collect();
let _ = Constrained::new(
config.network_proxy.policy.denied_domains.clone(),
move |candidate| {
let candidate_set: HashSet<String> =
candidate.iter().map(|s| s.to_ascii_lowercase()).collect();
let missing: Vec<String> = required_set
.iter()
.filter(|entry| !candidate_set.contains(*entry))
.cloned()
.collect();
if missing.is_empty() {
Ok(())
} else {
Err(invalid_value(
"network_proxy.policy.denied_domains",
"missing managed denied_domains entries",
format!("{missing:?}"),
))
}
},
)?;
}
if let Some(allow_unix_sockets) = &constraints.allow_unix_sockets {
let allowed_set: HashSet<String> = allow_unix_sockets
.iter()
.map(|s| s.to_ascii_lowercase())
.collect();
let _ = Constrained::new(
config.network_proxy.policy.allow_unix_sockets.clone(),
move |candidate| {
let mut invalid = Vec::new();
for entry in candidate {
if !allowed_set.contains(&entry.to_ascii_lowercase()) {
invalid.push(entry.clone());
}
}
if invalid.is_empty() {
Ok(())
} else {
Err(invalid_value(
"network_proxy.policy.allow_unix_sockets",
format!("{invalid:?}"),
"subset of managed allow_unix_sockets",
))
}
},
)?;
}
Ok(())
}
fn network_mode_rank(mode: NetworkMode) -> u8 {
match mode {
NetworkMode::Limited => 0,
NetworkMode::Full => 1,
}
}

View File

@@ -0,0 +1,188 @@
use rama_core::Layer;
use rama_core::Service;
use rama_core::error::BoxError;
use rama_core::error::ErrorContext as _;
use rama_core::error::OpaqueError;
use rama_core::extensions::ExtensionsMut;
use rama_core::extensions::ExtensionsRef;
use rama_core::service::BoxService;
use rama_http::Body;
use rama_http::Request;
use rama_http::Response;
use rama_http::layer::version_adapter::RequestVersionAdapter;
use rama_http_backend::client::HttpClientService;
use rama_http_backend::client::HttpConnector;
use rama_http_backend::client::proxy::layer::HttpProxyConnectorLayer;
use rama_net::address::ProxyAddress;
use rama_net::client::EstablishedClientConnection;
use rama_net::http::RequestContext;
use rama_tcp::client::service::TcpConnector;
use rama_tls_boring::client::TlsConnectorDataBuilder;
use rama_tls_boring::client::TlsConnectorLayer;
use tracing::warn;
#[cfg(target_os = "macos")]
use rama_unix::client::UnixConnector;
#[derive(Clone, Default)]
struct ProxyConfig {
http: Option<ProxyAddress>,
https: Option<ProxyAddress>,
all: Option<ProxyAddress>,
}
impl ProxyConfig {
fn from_env() -> Self {
let http = read_proxy_env(&["HTTP_PROXY", "http_proxy"]);
let https = read_proxy_env(&["HTTPS_PROXY", "https_proxy"]);
let all = read_proxy_env(&["ALL_PROXY", "all_proxy"]);
Self { http, https, all }
}
fn proxy_for_request(&self, req: &Request) -> Option<ProxyAddress> {
let is_secure = RequestContext::try_from(req)
.map(|ctx| ctx.protocol.is_secure())
.unwrap_or(false);
self.proxy_for_protocol(is_secure)
}
fn proxy_for_protocol(&self, is_secure: bool) -> Option<ProxyAddress> {
if is_secure {
self.https
.clone()
.or_else(|| self.http.clone())
.or_else(|| self.all.clone())
} else {
self.http.clone().or_else(|| self.all.clone())
}
}
}
fn read_proxy_env(keys: &[&str]) -> Option<ProxyAddress> {
for key in keys {
let Ok(value) = std::env::var(key) else {
continue;
};
let value = value.trim();
if value.is_empty() {
continue;
}
match ProxyAddress::try_from(value) {
Ok(proxy) => {
if proxy
.protocol
.as_ref()
.map(rama_net::Protocol::is_http)
.unwrap_or(true)
{
return Some(proxy);
}
warn!("ignoring {key}: non-http proxy protocol");
}
Err(err) => {
warn!("ignoring {key}: invalid proxy address ({err})");
}
}
}
None
}
pub(crate) fn proxy_for_connect() -> Option<ProxyAddress> {
ProxyConfig::from_env().proxy_for_protocol(true)
}
#[derive(Clone)]
pub(crate) struct UpstreamClient {
connector: BoxService<
Request<Body>,
EstablishedClientConnection<HttpClientService<Body>, Request<Body>>,
BoxError,
>,
proxy_config: ProxyConfig,
}
impl UpstreamClient {
pub(crate) fn direct() -> Self {
Self::new(ProxyConfig::default())
}
pub(crate) fn from_env_proxy() -> Self {
Self::new(ProxyConfig::from_env())
}
#[cfg(target_os = "macos")]
pub(crate) fn unix_socket(path: &str) -> Self {
let connector = build_unix_connector(path);
Self {
connector,
proxy_config: ProxyConfig::default(),
}
}
fn new(proxy_config: ProxyConfig) -> Self {
let connector = build_http_connector();
Self {
connector,
proxy_config,
}
}
}
impl Service<Request<Body>> for UpstreamClient {
type Output = Response;
type Error = OpaqueError;
async fn serve(&self, mut req: Request<Body>) -> Result<Self::Output, Self::Error> {
if let Some(proxy) = self.proxy_config.proxy_for_request(&req) {
req.extensions_mut().insert(proxy);
}
let uri = req.uri().clone();
let EstablishedClientConnection {
input: mut req,
conn: http_connection,
} = self
.connector
.serve(req)
.await
.map_err(OpaqueError::from_boxed)?;
req.extensions_mut()
.extend(http_connection.extensions().clone());
http_connection
.serve(req)
.await
.map_err(OpaqueError::from_boxed)
.with_context(|| format!("http request failure for uri: {uri}"))
}
}
fn build_http_connector() -> BoxService<
Request<Body>,
EstablishedClientConnection<HttpClientService<Body>, Request<Body>>,
BoxError,
> {
let transport = TcpConnector::default();
let proxy = HttpProxyConnectorLayer::optional().into_layer(transport);
let tls_config = TlsConnectorDataBuilder::new_http_auto().into_shared_builder();
let tls = TlsConnectorLayer::auto()
.with_connector_data(tls_config)
.into_layer(proxy);
let tls = RequestVersionAdapter::new(tls);
let connector = HttpConnector::new(tls);
connector.boxed()
}
#[cfg(target_os = "macos")]
fn build_unix_connector(
path: &str,
) -> BoxService<
Request<Body>,
EstablishedClientConnection<HttpClientService<Body>, Request<Body>>,
BoxError,
> {
let transport = UnixConnector::fixed(path);
let connector = HttpConnector::new(transport);
connector.boxed()
}