Compare commits

...

1 Commits

Author SHA1 Message Date
Ruslan Nigmatullin
f67382d22f [exec-server] add HTTP health endpoints 2026-05-09 20:38:35 +00:00
10 changed files with 257 additions and 52 deletions

2
codex-rs/Cargo.lock generated
View File

@@ -2741,6 +2741,7 @@ dependencies = [
"anyhow",
"arc-swap",
"async-trait",
"axum",
"base64 0.22.1",
"bytes",
"codex-app-server-protocol",
@@ -2751,6 +2752,7 @@ dependencies = [
"codex-test-binary-support",
"codex-utils-absolute-path",
"codex-utils-pty",
"codex-utils-rustls-provider",
"ctor 0.6.3",
"futures",
"pretty_assertions",

View File

@@ -13,6 +13,7 @@ workspace = true
[dependencies]
arc-swap = { workspace = true }
async-trait = { workspace = true }
axum = { workspace = true, features = ["http1", "tokio", "ws"] }
base64 = { workspace = true }
bytes = { workspace = true }
codex-app-server-protocol = { workspace = true }
@@ -22,6 +23,7 @@ codex-protocol = { workspace = true }
codex-sandboxing = { workspace = true }
codex-utils-absolute-path = { workspace = true }
codex-utils-pty = { workspace = true }
codex-utils-rustls-provider = { workspace = true }
futures = { workspace = true }
reqwest = { workspace = true, features = ["json", "rustls-tls", "stream"] }
serde = { workspace = true, features = ["derive"] }

View File

@@ -7,6 +7,8 @@ use tokio_tungstenite::connect_async;
use tracing::debug;
use tracing::warn;
use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
use crate::ExecServerClient;
use crate::ExecServerError;
use crate::client_api::RemoteExecServerConnectArgs;
@@ -53,6 +55,7 @@ impl ExecServerClient {
pub async fn connect_websocket(
args: RemoteExecServerConnectArgs,
) -> Result<Self, ExecServerError> {
ensure_rustls_crypto_provider();
let websocket_url = args.websocket_url.clone();
let connect_timeout = args.connect_timeout;
let (stream, _) = timeout(connect_timeout, connect_async(websocket_url.as_str()))

View File

@@ -3,6 +3,8 @@ use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::time::Duration;
use axum::extract::ws::Message as AxumWebSocketMessage;
use axum::extract::ws::WebSocket as AxumWebSocket;
use codex_app_server_protocol::JSONRPCMessage;
use futures::SinkExt;
use futures::StreamExt;
@@ -441,6 +443,140 @@ impl JsonRpcConnection {
}
}
pub(crate) fn from_axum_websocket(stream: AxumWebSocket, connection_label: String) -> Self {
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY);
let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
let (disconnected_tx, disconnected_rx) = watch::channel(false);
let (mut websocket_writer, mut websocket_reader) = stream.split();
let reader_label = connection_label.clone();
let incoming_tx_for_reader = incoming_tx.clone();
let disconnected_tx_for_reader = disconnected_tx.clone();
let reader_task = tokio::spawn(async move {
loop {
match websocket_reader.next().await {
Some(Ok(AxumWebSocketMessage::Text(text))) => {
match serde_json::from_str::<JSONRPCMessage>(text.as_ref()) {
Ok(message) => {
if incoming_tx_for_reader
.send(JsonRpcConnectionEvent::Message(message))
.await
.is_err()
{
break;
}
}
Err(err) => {
send_malformed_message(
&incoming_tx_for_reader,
Some(format!(
"failed to parse websocket JSON-RPC message from {reader_label}: {err}"
)),
)
.await;
}
}
}
Some(Ok(AxumWebSocketMessage::Binary(bytes))) => {
match serde_json::from_slice::<JSONRPCMessage>(bytes.as_ref()) {
Ok(message) => {
if incoming_tx_for_reader
.send(JsonRpcConnectionEvent::Message(message))
.await
.is_err()
{
break;
}
}
Err(err) => {
send_malformed_message(
&incoming_tx_for_reader,
Some(format!(
"failed to parse websocket JSON-RPC message from {reader_label}: {err}"
)),
)
.await;
}
}
}
Some(Ok(AxumWebSocketMessage::Close(_))) => {
send_disconnected(
&incoming_tx_for_reader,
&disconnected_tx_for_reader,
/*reason*/ None,
)
.await;
break;
}
Some(Ok(AxumWebSocketMessage::Ping(_)))
| Some(Ok(AxumWebSocketMessage::Pong(_))) => {}
Some(Err(err)) => {
send_disconnected(
&incoming_tx_for_reader,
&disconnected_tx_for_reader,
Some(format!(
"failed to read websocket JSON-RPC message from {reader_label}: {err}"
)),
)
.await;
break;
}
None => {
send_disconnected(
&incoming_tx_for_reader,
&disconnected_tx_for_reader,
/*reason*/ None,
)
.await;
break;
}
}
}
});
let writer_task = tokio::spawn(async move {
while let Some(message) = outgoing_rx.recv().await {
match serialize_jsonrpc_message(&message) {
Ok(encoded) => {
if let Err(err) = websocket_writer
.send(AxumWebSocketMessage::Text(encoded.into()))
.await
{
send_disconnected(
&incoming_tx,
&disconnected_tx,
Some(format!(
"failed to write websocket JSON-RPC message to {connection_label}: {err}"
)),
)
.await;
break;
}
}
Err(err) => {
send_disconnected(
&incoming_tx,
&disconnected_tx,
Some(format!(
"failed to serialize JSON-RPC message for {connection_label}: {err}"
)),
)
.await;
break;
}
}
}
});
Self {
outgoing_tx,
incoming_rx,
disconnected_rx,
task_handles: vec![reader_task, writer_task],
transport: JsonRpcTransport::Plain,
}
}
pub(crate) fn with_child_process(mut self, child_process: Child) -> Self {
self.transport = JsonRpcTransport::from_child_process(child_process);
self

View File

@@ -506,7 +506,7 @@ mod tests {
#[tokio::test]
async fn environment_manager_reports_remote_url() {
let manager = EnvironmentManager::create_for_tests(
Some("ws://127.0.0.1:8765".to_string()),
Some("ws://127.0.0.1:8765/ws".to_string()),
test_runtime_paths(),
)
.await;
@@ -517,7 +517,10 @@ mod tests {
Some(REMOTE_ENVIRONMENT_ID)
);
assert!(environment.is_remote());
assert_eq!(environment.exec_server_url(), Some("ws://127.0.0.1:8765"));
assert_eq!(
environment.exec_server_url(),
Some("ws://127.0.0.1:8765/ws")
);
assert!(Arc::ptr_eq(
&environment,
&manager
@@ -548,7 +551,7 @@ mod tests {
snapshot: EnvironmentProviderSnapshot {
environments: vec![(
REMOTE_ENVIRONMENT_ID.to_string(),
Environment::create_for_tests(Some("ws://127.0.0.1:8765".to_string()))
Environment::create_for_tests(Some("ws://127.0.0.1:8765/ws".to_string()))
.expect("remote environment"),
)],
default: EnvironmentDefault::EnvironmentId(REMOTE_ENVIRONMENT_ID.to_string()),
@@ -620,7 +623,7 @@ mod tests {
snapshot: EnvironmentProviderSnapshot {
environments: vec![(
"devbox".to_string(),
Environment::create_for_tests(Some("ws://127.0.0.1:8765".to_string()))
Environment::create_for_tests(Some("ws://127.0.0.1:8765/ws".to_string()))
.expect("remote environment"),
)],
default: EnvironmentDefault::EnvironmentId("devbox".to_string()),
@@ -645,7 +648,7 @@ mod tests {
snapshot: EnvironmentProviderSnapshot {
environments: vec![(
"devbox".to_string(),
Environment::create_for_tests(Some("ws://127.0.0.1:8765".to_string()))
Environment::create_for_tests(Some("ws://127.0.0.1:8765/ws".to_string()))
.expect("remote environment"),
)],
default: EnvironmentDefault::Disabled,
@@ -672,7 +675,7 @@ mod tests {
snapshot: EnvironmentProviderSnapshot {
environments: vec![(
"devbox".to_string(),
Environment::create_for_tests(Some("ws://127.0.0.1:8765".to_string()))
Environment::create_for_tests(Some("ws://127.0.0.1:8765/ws".to_string()))
.expect("remote environment"),
)],
default: EnvironmentDefault::EnvironmentId("missing".to_string()),
@@ -766,23 +769,29 @@ mod tests {
let manager = EnvironmentManager::disabled_for_tests(test_runtime_paths());
manager
.upsert_environment("executor-a".to_string(), "ws://127.0.0.1:8765".to_string())
.upsert_environment(
"executor-a".to_string(),
"ws://127.0.0.1:8765/ws".to_string(),
)
.expect("remote environment");
let first = manager
.get_environment("executor-a")
.expect("first remote environment");
assert!(first.is_remote());
assert_eq!(first.exec_server_url(), Some("ws://127.0.0.1:8765"));
assert_eq!(first.exec_server_url(), Some("ws://127.0.0.1:8765/ws"));
assert_eq!(manager.default_environment_id(), None);
manager
.upsert_environment("executor-a".to_string(), "ws://127.0.0.1:9876".to_string())
.upsert_environment(
"executor-a".to_string(),
"ws://127.0.0.1:9876/ws".to_string(),
)
.expect("updated remote environment");
let second = manager
.get_environment("executor-a")
.expect("second remote environment");
assert!(second.is_remote());
assert_eq!(second.exec_server_url(), Some("ws://127.0.0.1:9876"));
assert_eq!(second.exec_server_url(), Some("ws://127.0.0.1:9876/ws"));
assert!(!Arc::ptr_eq(&first, &second));
}

View File

@@ -162,7 +162,7 @@ mod tests {
#[tokio::test]
async fn default_provider_adds_remote_environment_for_websocket_url() {
let provider = DefaultEnvironmentProvider::new(Some("ws://127.0.0.1:8765".to_string()));
let provider = DefaultEnvironmentProvider::new(Some("ws://127.0.0.1:8765/ws".to_string()));
let snapshot = provider.snapshot().await.expect("environments");
let EnvironmentProviderSnapshot {
environments,
@@ -177,7 +177,7 @@ mod tests {
assert!(remote_environment.is_remote());
assert_eq!(
remote_environment.exec_server_url(),
Some("ws://127.0.0.1:8765")
Some("ws://127.0.0.1:8765/ws")
);
assert_eq!(
default,
@@ -187,13 +187,14 @@ mod tests {
#[tokio::test]
async fn default_provider_normalizes_exec_server_url() {
let provider = DefaultEnvironmentProvider::new(Some(" ws://127.0.0.1:8765 ".to_string()));
let provider =
DefaultEnvironmentProvider::new(Some(" ws://127.0.0.1:8765/ws ".to_string()));
let snapshot = provider.snapshot().await.expect("environments");
let environments: HashMap<_, _> = snapshot.environments.into_iter().collect();
assert_eq!(
environments[REMOTE_ENVIRONMENT_ID].exec_server_url(),
Some("ws://127.0.0.1:8765")
Some("ws://127.0.0.1:8765/ws")
);
}
}

View File

@@ -333,7 +333,7 @@ mod tests {
environments: vec![
EnvironmentToml {
id: "devbox".to_string(),
url: Some(" ws://127.0.0.1:8765 ".to_string()),
url: Some(" ws://127.0.0.1:8765/ws ".to_string()),
..Default::default()
},
EnvironmentToml {
@@ -370,7 +370,7 @@ mod tests {
assert!(!environments.contains_key(LOCAL_ENVIRONMENT_ID));
assert_eq!(
environments["devbox"].exec_server_url(),
Some("ws://127.0.0.1:8765")
Some("ws://127.0.0.1:8765/ws")
);
assert!(environments["ssh-dev"].is_remote());
assert_eq!(environments["ssh-dev"].exec_server_url(), None);
@@ -411,7 +411,7 @@ mod tests {
(
EnvironmentToml {
id: "local".to_string(),
url: Some("ws://127.0.0.1:8765".to_string()),
url: Some("ws://127.0.0.1:8765/ws".to_string()),
..Default::default()
},
"environment id `local` is reserved",
@@ -419,7 +419,7 @@ mod tests {
(
EnvironmentToml {
id: " devbox ".to_string(),
url: Some("ws://127.0.0.1:8765".to_string()),
url: Some("ws://127.0.0.1:8765/ws".to_string()),
..Default::default()
},
"environment id ` devbox ` must not contain surrounding whitespace",
@@ -427,7 +427,7 @@ mod tests {
(
EnvironmentToml {
id: "dev box".to_string(),
url: Some("ws://127.0.0.1:8765".to_string()),
url: Some("ws://127.0.0.1:8765/ws".to_string()),
..Default::default()
},
"environment id `dev box` must contain only ASCII letters, numbers, '-' or '_'",
@@ -443,7 +443,7 @@ mod tests {
(
EnvironmentToml {
id: "devbox".to_string(),
url: Some("ws://127.0.0.1:8765".to_string()),
url: Some("ws://127.0.0.1:8765/ws".to_string()),
program: Some("codex".to_string()),
..Default::default()
},
@@ -528,7 +528,7 @@ mod tests {
environments: vec![
EnvironmentToml {
id: "devbox".to_string(),
url: Some("ws://127.0.0.1:8765".to_string()),
url: Some("ws://127.0.0.1:8765/ws".to_string()),
connect_timeout_sec: Some(Duration::from_secs(12)),
initialize_timeout_sec: Some(Duration::from_secs(34)),
..Default::default()
@@ -546,7 +546,7 @@ mod tests {
assert_eq!(
provider.environments[0].1,
ExecServerTransportParams::WebSocketUrl {
websocket_url: "ws://127.0.0.1:8765".to_string(),
websocket_url: "ws://127.0.0.1:8765/ws".to_string(),
connect_timeout: Duration::from_secs(12),
initialize_timeout: Duration::from_secs(34),
}
@@ -591,7 +591,7 @@ mod tests {
environments: vec![
EnvironmentToml {
id: "devbox".to_string(),
url: Some("ws://127.0.0.1:8765".to_string()),
url: Some("ws://127.0.0.1:8765/ws".to_string()),
..Default::default()
},
EnvironmentToml {
@@ -616,7 +616,7 @@ mod tests {
default: None,
environments: vec![EnvironmentToml {
id: id.clone(),
url: Some("ws://127.0.0.1:8765".to_string()),
url: Some("ws://127.0.0.1:8765/ws".to_string()),
..Default::default()
}],
})
@@ -655,7 +655,7 @@ default = "ssh-dev"
[[environments]]
id = "devbox"
url = "ws://127.0.0.1:4512"
url = "ws://127.0.0.1:4512/ws"
connect_timeout_sec = 12.0
initialize_timeout_sec = 34.0
@@ -678,7 +678,7 @@ CODEX_LOG = "debug"
environments.environments[0],
EnvironmentToml {
id: "devbox".to_string(),
url: Some("ws://127.0.0.1:4512".to_string()),
url: Some("ws://127.0.0.1:4512/ws".to_string()),
connect_timeout_sec: Some(Duration::from_secs(12)),
initialize_timeout_sec: Some(Duration::from_secs(34)),
..Default::default()
@@ -712,7 +712,7 @@ CODEX_LOG = "debug"
r#"
[[environments]]
id = "devbox"
url = "ws://127.0.0.1:4512"
url = "ws://127.0.0.1:4512/ws"
unknown = true
"#,
"unknown field `unknown`",

View File

@@ -7,6 +7,8 @@ use tokio::time::sleep;
use tokio_tungstenite::connect_async;
use tracing::warn;
use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
use crate::ExecServerError;
use crate::ExecServerRuntimePaths;
use crate::connection::JsonRpcConnection;
@@ -133,6 +135,7 @@ pub async fn run_remote_executor(
config: RemoteExecutorConfig,
runtime_paths: ExecServerRuntimePaths,
) -> Result<(), ExecServerError> {
ensure_rustls_crypto_provider();
let client = ExecutorRegistryClient::new(config.base_url.clone(), config.bearer_token.clone())?;
let processor = ConnectionProcessor::new(runtime_paths);
let mut backoff = Duration::from_secs(1);

View File

@@ -1,11 +1,18 @@
use axum::Router;
use axum::extract::ConnectInfo;
use axum::extract::State;
use axum::extract::ws::WebSocketUpgrade;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::routing::any;
use axum::routing::get;
use std::io::Write as _;
use std::net::SocketAddr;
use tokio::io;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::net::TcpListener;
use tokio_tungstenite::accept_async;
use tracing::warn;
use tracing::info;
use crate::ExecServerRuntimePaths;
use crate::connection::JsonRpcConnection;
@@ -109,31 +116,49 @@ async fn run_websocket_listener(
let listener = TcpListener::bind(bind_address).await?;
let local_addr = listener.local_addr()?;
let processor = ConnectionProcessor::new(runtime_paths);
tracing::info!("codex-exec-server listening on ws://{local_addr}");
println!("ws://{local_addr}");
info!("codex-exec-server listening on ws://{local_addr}/ws");
println!("ws://{local_addr}/ws");
std::io::stdout().flush()?;
loop {
let (stream, peer_addr) = listener.accept().await?;
let processor = processor.clone();
tokio::spawn(async move {
match accept_async(stream).await {
Ok(websocket) => {
processor
.run_connection(JsonRpcConnection::from_websocket(
websocket,
format!("exec-server websocket {peer_addr}"),
))
.await;
}
Err(err) => {
warn!(
"failed to accept exec-server websocket connection from {peer_addr}: {err}"
);
}
}
});
}
let router = Router::new()
.route("/", get(health_check_handler))
.route("/readyz", get(health_check_handler))
.route("/healthz", get(health_check_handler))
.route("/ws", any(websocket_upgrade_handler))
.fallback(any(websocket_upgrade_handler))
.with_state(ExecServerWebSocketState { processor });
axum::serve(
listener,
router.into_make_service_with_connect_info::<SocketAddr>(),
)
.await?;
Ok(())
}
#[derive(Clone)]
struct ExecServerWebSocketState {
processor: ConnectionProcessor,
}
async fn health_check_handler() -> StatusCode {
StatusCode::OK
}
async fn websocket_upgrade_handler(
websocket: WebSocketUpgrade,
ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
State(state): State<ExecServerWebSocketState>,
) -> impl IntoResponse {
info!(%peer_addr, "exec-server websocket client connected");
websocket.on_upgrade(move |stream| async move {
state
.processor
.run_connection(JsonRpcConnection::from_axum_websocket(
stream,
format!("exec-server websocket {peer_addr}"),
))
.await;
})
}
#[cfg(test)]

View File

@@ -0,0 +1,24 @@
#![cfg(unix)]
mod common;
use common::exec_server::exec_server;
use pretty_assertions::assert_eq;
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn exec_server_serves_health_checks_alongside_websocket_endpoint() -> anyhow::Result<()> {
let mut server = exec_server().await?;
let http_base_url = server
.websocket_url()
.strip_prefix("ws://")
.and_then(|url| url.strip_suffix("/ws"))
.expect("websocket URL should use ws://.../ws");
for path in ["/", "/readyz", "/healthz"] {
let response = reqwest::get(format!("http://{http_base_url}{path}")).await?;
assert_eq!(response.status(), reqwest::StatusCode::OK);
}
server.shutdown().await?;
Ok(())
}