mirror of
https://github.com/openai/codex.git
synced 2026-04-28 02:11:08 +03:00
app-server: Split transport module
This commit is contained in:
@@ -1,8 +1,5 @@
|
||||
pub(crate) mod auth;
|
||||
|
||||
use self::auth::WebsocketAuthPolicy;
|
||||
use self::auth::authorize_upgrade;
|
||||
use self::auth::should_warn_about_unauthenticated_non_loopback_listener;
|
||||
use crate::error_code::OVERLOADED_ERROR_CODE;
|
||||
use crate::message_processor::ConnectionSessionState;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
@@ -10,53 +7,20 @@ use crate::outgoing_message::OutgoingEnvelope;
|
||||
use crate::outgoing_message::OutgoingError;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use crate::outgoing_message::QueuedOutgoingMessage;
|
||||
use axum::Router;
|
||||
use axum::body::Body;
|
||||
use axum::extract::ConnectInfo;
|
||||
use axum::extract::State;
|
||||
use axum::extract::ws::Message as WebSocketMessage;
|
||||
use axum::extract::ws::WebSocket;
|
||||
use axum::extract::ws::WebSocketUpgrade;
|
||||
use axum::http::HeaderMap;
|
||||
use axum::http::Request;
|
||||
use axum::http::StatusCode;
|
||||
use axum::http::header::ORIGIN;
|
||||
use axum::middleware;
|
||||
use axum::middleware::Next;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::response::Response;
|
||||
use axum::routing::any;
|
||||
use axum::routing::get;
|
||||
use codex_app_server_protocol::JSONRPCErrorError;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_app_server_protocol::ServerRequest;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use owo_colors::OwoColorize;
|
||||
use owo_colors::Stream;
|
||||
use owo_colors::Style;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::io::ErrorKind;
|
||||
use std::io::Result as IoResult;
|
||||
use std::net::SocketAddr;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::io::{self};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::debug;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
/// Size of the bounded channels used to communicate between tasks. The value
|
||||
@@ -64,85 +28,11 @@ use tracing::warn;
|
||||
/// plenty for an interactive CLI.
|
||||
pub(crate) const CHANNEL_CAPACITY: usize = 128;
|
||||
|
||||
fn colorize(text: &str, style: Style) -> String {
|
||||
text.if_supports_color(Stream::Stderr, |value| value.style(style))
|
||||
.to_string()
|
||||
}
|
||||
mod stdio;
|
||||
mod websocket;
|
||||
|
||||
#[allow(clippy::print_stderr)]
|
||||
fn print_websocket_startup_banner(addr: SocketAddr) {
|
||||
let title = colorize("codex app-server (WebSockets)", Style::new().bold().cyan());
|
||||
let listening_label = colorize("listening on:", Style::new().dimmed());
|
||||
let listen_url = colorize(&format!("ws://{addr}"), Style::new().green());
|
||||
let ready_label = colorize("readyz:", Style::new().dimmed());
|
||||
let ready_url = colorize(&format!("http://{addr}/readyz"), Style::new().green());
|
||||
let health_label = colorize("healthz:", Style::new().dimmed());
|
||||
let health_url = colorize(&format!("http://{addr}/healthz"), Style::new().green());
|
||||
let note_label = colorize("note:", Style::new().dimmed());
|
||||
eprintln!("{title}");
|
||||
eprintln!(" {listening_label} {listen_url}");
|
||||
eprintln!(" {ready_label} {ready_url}");
|
||||
eprintln!(" {health_label} {health_url}");
|
||||
if addr.ip().is_loopback() {
|
||||
eprintln!(
|
||||
" {note_label} binds localhost only (use SSH port-forwarding for remote access)"
|
||||
);
|
||||
} else {
|
||||
eprintln!(
|
||||
" {note_label} websocket auth is opt-in in this build; configure `--ws-auth ...` before real remote use"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct WebSocketListenerState {
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
connection_counter: Arc<AtomicU64>,
|
||||
auth_policy: Arc<WebsocketAuthPolicy>,
|
||||
}
|
||||
|
||||
async fn health_check_handler() -> StatusCode {
|
||||
StatusCode::OK
|
||||
}
|
||||
|
||||
async fn reject_requests_with_origin_header(
|
||||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
if request.headers().contains_key(ORIGIN) {
|
||||
warn!(
|
||||
method = %request.method(),
|
||||
uri = %request.uri(),
|
||||
"rejecting websocket listener request with Origin header"
|
||||
);
|
||||
Err(StatusCode::FORBIDDEN)
|
||||
} else {
|
||||
Ok(next.run(request).await)
|
||||
}
|
||||
}
|
||||
|
||||
async fn websocket_upgrade_handler(
|
||||
websocket: WebSocketUpgrade,
|
||||
ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
|
||||
State(state): State<WebSocketListenerState>,
|
||||
headers: HeaderMap,
|
||||
) -> impl IntoResponse {
|
||||
if let Err(err) = authorize_upgrade(&headers, state.auth_policy.as_ref()) {
|
||||
warn!(
|
||||
%peer_addr,
|
||||
message = err.message(),
|
||||
"rejecting websocket client during upgrade"
|
||||
);
|
||||
return (err.status_code(), err.message()).into_response();
|
||||
}
|
||||
let connection_id = ConnectionId(state.connection_counter.fetch_add(1, Ordering::Relaxed));
|
||||
info!(%peer_addr, "websocket client connected");
|
||||
websocket
|
||||
.on_upgrade(move |stream| async move {
|
||||
run_websocket_connection(connection_id, stream, state.transport_event_tx).await;
|
||||
})
|
||||
.into_response()
|
||||
}
|
||||
pub(crate) use stdio::start_stdio_connection;
|
||||
pub(crate) use websocket::start_websocket_acceptor;
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
pub enum AppServerTransport {
|
||||
@@ -276,262 +166,6 @@ impl OutboundConnectionState {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn start_stdio_connection(
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
stdio_handles: &mut Vec<JoinHandle<()>>,
|
||||
) -> IoResult<()> {
|
||||
let connection_id = ConnectionId(0);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel::<QueuedOutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let writer_tx_for_reader = writer_tx.clone();
|
||||
transport_event_tx
|
||||
.send(TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
writer: writer_tx,
|
||||
disconnect_sender: None,
|
||||
})
|
||||
.await
|
||||
.map_err(|_| std::io::Error::new(ErrorKind::BrokenPipe, "processor unavailable"))?;
|
||||
|
||||
let transport_event_tx_for_reader = transport_event_tx.clone();
|
||||
stdio_handles.push(tokio::spawn(async move {
|
||||
let stdin = io::stdin();
|
||||
let reader = BufReader::new(stdin);
|
||||
let mut lines = reader.lines();
|
||||
|
||||
loop {
|
||||
match lines.next_line().await {
|
||||
Ok(Some(line)) => {
|
||||
if !forward_incoming_message(
|
||||
&transport_event_tx_for_reader,
|
||||
&writer_tx_for_reader,
|
||||
connection_id,
|
||||
&line,
|
||||
)
|
||||
.await
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(err) => {
|
||||
error!("Failed reading stdin: {err}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = transport_event_tx_for_reader
|
||||
.send(TransportEvent::ConnectionClosed { connection_id })
|
||||
.await;
|
||||
debug!("stdin reader finished (EOF)");
|
||||
}));
|
||||
|
||||
stdio_handles.push(tokio::spawn(async move {
|
||||
let mut stdout = io::stdout();
|
||||
while let Some(queued_message) = writer_rx.recv().await {
|
||||
let Some(mut json) = serialize_outgoing_message(queued_message.message) else {
|
||||
continue;
|
||||
};
|
||||
json.push('\n');
|
||||
if let Err(err) = stdout.write_all(json.as_bytes()).await {
|
||||
error!("Failed to write to stdout: {err}");
|
||||
break;
|
||||
}
|
||||
if let Some(write_complete_tx) = queued_message.write_complete_tx {
|
||||
let _ = write_complete_tx.send(());
|
||||
}
|
||||
}
|
||||
info!("stdout writer exited (channel closed)");
|
||||
}));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn start_websocket_acceptor(
|
||||
bind_address: SocketAddr,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
shutdown_token: CancellationToken,
|
||||
auth_policy: WebsocketAuthPolicy,
|
||||
) -> IoResult<JoinHandle<()>> {
|
||||
if should_warn_about_unauthenticated_non_loopback_listener(bind_address, &auth_policy) {
|
||||
warn!(
|
||||
%bind_address,
|
||||
"starting non-loopback websocket listener without auth; websocket auth is opt-in for now and will become the default in a future release"
|
||||
);
|
||||
}
|
||||
let listener = TcpListener::bind(bind_address).await?;
|
||||
let local_addr = listener.local_addr()?;
|
||||
print_websocket_startup_banner(local_addr);
|
||||
info!("app-server websocket listening on ws://{local_addr}");
|
||||
|
||||
let router = Router::new()
|
||||
.route("/readyz", get(health_check_handler))
|
||||
.route("/healthz", get(health_check_handler))
|
||||
.fallback(any(websocket_upgrade_handler))
|
||||
.layer(middleware::from_fn(reject_requests_with_origin_header))
|
||||
.with_state(WebSocketListenerState {
|
||||
transport_event_tx,
|
||||
connection_counter: Arc::new(AtomicU64::new(1)),
|
||||
auth_policy: Arc::new(auth_policy),
|
||||
});
|
||||
let server = axum::serve(
|
||||
listener,
|
||||
router.into_make_service_with_connect_info::<SocketAddr>(),
|
||||
)
|
||||
.with_graceful_shutdown(async move {
|
||||
shutdown_token.cancelled().await;
|
||||
});
|
||||
Ok(tokio::spawn(async move {
|
||||
if let Err(err) = server.await {
|
||||
error!("websocket acceptor failed: {err}");
|
||||
}
|
||||
info!("websocket acceptor shutting down");
|
||||
}))
|
||||
}
|
||||
|
||||
async fn run_websocket_connection(
|
||||
connection_id: ConnectionId,
|
||||
websocket_stream: WebSocket,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
) {
|
||||
let (writer_tx, writer_rx) = mpsc::channel::<QueuedOutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let writer_tx_for_reader = writer_tx.clone();
|
||||
let disconnect_token = CancellationToken::new();
|
||||
if transport_event_tx
|
||||
.send(TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
writer: writer_tx,
|
||||
disconnect_sender: Some(disconnect_token.clone()),
|
||||
})
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
let (websocket_writer, websocket_reader) = websocket_stream.split();
|
||||
let (writer_control_tx, writer_control_rx) =
|
||||
mpsc::channel::<WebSocketMessage>(CHANNEL_CAPACITY);
|
||||
let mut outbound_task = tokio::spawn(run_websocket_outbound_loop(
|
||||
websocket_writer,
|
||||
writer_rx,
|
||||
writer_control_rx,
|
||||
disconnect_token.clone(),
|
||||
));
|
||||
let mut inbound_task = tokio::spawn(run_websocket_inbound_loop(
|
||||
websocket_reader,
|
||||
transport_event_tx.clone(),
|
||||
writer_tx_for_reader,
|
||||
writer_control_tx,
|
||||
connection_id,
|
||||
disconnect_token.clone(),
|
||||
));
|
||||
|
||||
tokio::select! {
|
||||
_ = &mut outbound_task => {
|
||||
disconnect_token.cancel();
|
||||
inbound_task.abort();
|
||||
}
|
||||
_ = &mut inbound_task => {
|
||||
disconnect_token.cancel();
|
||||
outbound_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
let _ = transport_event_tx
|
||||
.send(TransportEvent::ConnectionClosed { connection_id })
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn run_websocket_outbound_loop(
|
||||
mut websocket_writer: futures::stream::SplitSink<WebSocket, WebSocketMessage>,
|
||||
mut writer_rx: mpsc::Receiver<QueuedOutgoingMessage>,
|
||||
mut writer_control_rx: mpsc::Receiver<WebSocketMessage>,
|
||||
disconnect_token: CancellationToken,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = disconnect_token.cancelled() => {
|
||||
break;
|
||||
}
|
||||
message = writer_control_rx.recv() => {
|
||||
let Some(message) = message else {
|
||||
break;
|
||||
};
|
||||
if websocket_writer.send(message).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
queued_message = writer_rx.recv() => {
|
||||
let Some(queued_message) = queued_message else {
|
||||
break;
|
||||
};
|
||||
let Some(json) = serialize_outgoing_message(queued_message.message) else {
|
||||
continue;
|
||||
};
|
||||
if websocket_writer.send(WebSocketMessage::Text(json.into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
if let Some(write_complete_tx) = queued_message.write_complete_tx {
|
||||
let _ = write_complete_tx.send(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_websocket_inbound_loop(
|
||||
mut websocket_reader: futures::stream::SplitStream<WebSocket>,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
writer_tx_for_reader: mpsc::Sender<QueuedOutgoingMessage>,
|
||||
writer_control_tx: mpsc::Sender<WebSocketMessage>,
|
||||
connection_id: ConnectionId,
|
||||
disconnect_token: CancellationToken,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = disconnect_token.cancelled() => {
|
||||
break;
|
||||
}
|
||||
incoming_message = websocket_reader.next() => {
|
||||
match incoming_message {
|
||||
Some(Ok(WebSocketMessage::Text(text))) => {
|
||||
if !forward_incoming_message(
|
||||
&transport_event_tx,
|
||||
&writer_tx_for_reader,
|
||||
connection_id,
|
||||
text.as_ref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Some(Ok(WebSocketMessage::Ping(payload))) => {
|
||||
match writer_control_tx.try_send(WebSocketMessage::Pong(payload)) {
|
||||
Ok(()) => {}
|
||||
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => break,
|
||||
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
|
||||
warn!("websocket control queue full while replying to ping; closing connection");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(WebSocketMessage::Pong(_))) => {}
|
||||
Some(Ok(WebSocketMessage::Close(_))) | None => break,
|
||||
Some(Ok(WebSocketMessage::Binary(_))) => {
|
||||
warn!("dropping unsupported binary websocket message");
|
||||
}
|
||||
Some(Err(err)) => {
|
||||
warn!("websocket receive error: {err}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn forward_incoming_message(
|
||||
transport_event_tx: &mpsc::Sender<TransportEvent>,
|
||||
writer: &mpsc::Sender<QueuedOutgoingMessage>,
|
||||
88
codex-rs/app-server/src/transport/stdio.rs
Normal file
88
codex-rs/app-server/src/transport/stdio.rs
Normal file
@@ -0,0 +1,88 @@
|
||||
use super::CHANNEL_CAPACITY;
|
||||
use super::TransportEvent;
|
||||
use super::forward_incoming_message;
|
||||
use super::serialize_outgoing_message;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::QueuedOutgoingMessage;
|
||||
use std::io::ErrorKind;
|
||||
use std::io::Result as IoResult;
|
||||
use tokio::io;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::debug;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
|
||||
pub(crate) async fn start_stdio_connection(
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
stdio_handles: &mut Vec<JoinHandle<()>>,
|
||||
) -> IoResult<()> {
|
||||
let connection_id = ConnectionId(0);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel::<QueuedOutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let writer_tx_for_reader = writer_tx.clone();
|
||||
transport_event_tx
|
||||
.send(TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
writer: writer_tx,
|
||||
disconnect_sender: None,
|
||||
})
|
||||
.await
|
||||
.map_err(|_| std::io::Error::new(ErrorKind::BrokenPipe, "processor unavailable"))?;
|
||||
|
||||
let transport_event_tx_for_reader = transport_event_tx.clone();
|
||||
stdio_handles.push(tokio::spawn(async move {
|
||||
let stdin = io::stdin();
|
||||
let reader = BufReader::new(stdin);
|
||||
let mut lines = reader.lines();
|
||||
|
||||
loop {
|
||||
match lines.next_line().await {
|
||||
Ok(Some(line)) => {
|
||||
if !forward_incoming_message(
|
||||
&transport_event_tx_for_reader,
|
||||
&writer_tx_for_reader,
|
||||
connection_id,
|
||||
&line,
|
||||
)
|
||||
.await
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(err) => {
|
||||
error!("Failed reading stdin: {err}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = transport_event_tx_for_reader
|
||||
.send(TransportEvent::ConnectionClosed { connection_id })
|
||||
.await;
|
||||
debug!("stdin reader finished (EOF)");
|
||||
}));
|
||||
|
||||
stdio_handles.push(tokio::spawn(async move {
|
||||
let mut stdout = io::stdout();
|
||||
while let Some(queued_message) = writer_rx.recv().await {
|
||||
let Some(mut json) = serialize_outgoing_message(queued_message.message) else {
|
||||
continue;
|
||||
};
|
||||
json.push('\n');
|
||||
if let Err(err) = stdout.write_all(json.as_bytes()).await {
|
||||
error!("Failed to write to stdout: {err}");
|
||||
break;
|
||||
}
|
||||
if let Some(write_complete_tx) = queued_message.write_complete_tx {
|
||||
let _ = write_complete_tx.send(());
|
||||
}
|
||||
}
|
||||
info!("stdout writer exited (channel closed)");
|
||||
}));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
308
codex-rs/app-server/src/transport/websocket.rs
Normal file
308
codex-rs/app-server/src/transport/websocket.rs
Normal file
@@ -0,0 +1,308 @@
|
||||
use super::CHANNEL_CAPACITY;
|
||||
use super::TransportEvent;
|
||||
use super::auth::WebsocketAuthPolicy;
|
||||
use super::auth::authorize_upgrade;
|
||||
use super::auth::should_warn_about_unauthenticated_non_loopback_listener;
|
||||
use super::forward_incoming_message;
|
||||
use super::serialize_outgoing_message;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::QueuedOutgoingMessage;
|
||||
use axum::Router;
|
||||
use axum::body::Body;
|
||||
use axum::extract::ConnectInfo;
|
||||
use axum::extract::State;
|
||||
use axum::extract::ws::Message as WebSocketMessage;
|
||||
use axum::extract::ws::WebSocket;
|
||||
use axum::extract::ws::WebSocketUpgrade;
|
||||
use axum::http::HeaderMap;
|
||||
use axum::http::Request;
|
||||
use axum::http::StatusCode;
|
||||
use axum::http::header::ORIGIN;
|
||||
use axum::middleware;
|
||||
use axum::middleware::Next;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::response::Response;
|
||||
use axum::routing::any;
|
||||
use axum::routing::get;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use owo_colors::OwoColorize;
|
||||
use owo_colors::Stream;
|
||||
use owo_colors::Style;
|
||||
use std::io::Result as IoResult;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
fn colorize(text: &str, style: Style) -> String {
|
||||
text.if_supports_color(Stream::Stderr, |value| value.style(style))
|
||||
.to_string()
|
||||
}
|
||||
|
||||
#[allow(clippy::print_stderr)]
|
||||
fn print_websocket_startup_banner(addr: SocketAddr) {
|
||||
let title = colorize("codex app-server (WebSockets)", Style::new().bold().cyan());
|
||||
let listening_label = colorize("listening on:", Style::new().dimmed());
|
||||
let listen_url = colorize(&format!("ws://{addr}"), Style::new().green());
|
||||
let ready_label = colorize("readyz:", Style::new().dimmed());
|
||||
let ready_url = colorize(&format!("http://{addr}/readyz"), Style::new().green());
|
||||
let health_label = colorize("healthz:", Style::new().dimmed());
|
||||
let health_url = colorize(&format!("http://{addr}/healthz"), Style::new().green());
|
||||
let note_label = colorize("note:", Style::new().dimmed());
|
||||
eprintln!("{title}");
|
||||
eprintln!(" {listening_label} {listen_url}");
|
||||
eprintln!(" {ready_label} {ready_url}");
|
||||
eprintln!(" {health_label} {health_url}");
|
||||
if addr.ip().is_loopback() {
|
||||
eprintln!(
|
||||
" {note_label} binds localhost only (use SSH port-forwarding for remote access)"
|
||||
);
|
||||
} else {
|
||||
eprintln!(
|
||||
" {note_label} websocket auth is opt-in in this build; configure `--ws-auth ...` before real remote use"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct WebSocketListenerState {
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
connection_counter: Arc<AtomicU64>,
|
||||
auth_policy: Arc<WebsocketAuthPolicy>,
|
||||
}
|
||||
|
||||
async fn health_check_handler() -> StatusCode {
|
||||
StatusCode::OK
|
||||
}
|
||||
|
||||
async fn reject_requests_with_origin_header(
|
||||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
if request.headers().contains_key(ORIGIN) {
|
||||
warn!(
|
||||
method = %request.method(),
|
||||
uri = %request.uri(),
|
||||
"rejecting websocket listener request with Origin header"
|
||||
);
|
||||
Err(StatusCode::FORBIDDEN)
|
||||
} else {
|
||||
Ok(next.run(request).await)
|
||||
}
|
||||
}
|
||||
|
||||
async fn websocket_upgrade_handler(
|
||||
websocket: WebSocketUpgrade,
|
||||
ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
|
||||
State(state): State<WebSocketListenerState>,
|
||||
headers: HeaderMap,
|
||||
) -> impl IntoResponse {
|
||||
if let Err(err) = authorize_upgrade(&headers, state.auth_policy.as_ref()) {
|
||||
warn!(
|
||||
%peer_addr,
|
||||
message = err.message(),
|
||||
"rejecting websocket client during upgrade"
|
||||
);
|
||||
return (err.status_code(), err.message()).into_response();
|
||||
}
|
||||
let connection_id = ConnectionId(state.connection_counter.fetch_add(1, Ordering::Relaxed));
|
||||
info!(%peer_addr, "websocket client connected");
|
||||
websocket
|
||||
.on_upgrade(move |stream| async move {
|
||||
run_websocket_connection(connection_id, stream, state.transport_event_tx).await;
|
||||
})
|
||||
.into_response()
|
||||
}
|
||||
|
||||
pub(crate) async fn start_websocket_acceptor(
|
||||
bind_address: SocketAddr,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
shutdown_token: CancellationToken,
|
||||
auth_policy: WebsocketAuthPolicy,
|
||||
) -> IoResult<JoinHandle<()>> {
|
||||
if should_warn_about_unauthenticated_non_loopback_listener(bind_address, &auth_policy) {
|
||||
warn!(
|
||||
%bind_address,
|
||||
"starting non-loopback websocket listener without auth; websocket auth is opt-in for now and will become the default in a future release"
|
||||
);
|
||||
}
|
||||
let listener = TcpListener::bind(bind_address).await?;
|
||||
let local_addr = listener.local_addr()?;
|
||||
print_websocket_startup_banner(local_addr);
|
||||
info!("app-server websocket listening on ws://{local_addr}");
|
||||
|
||||
let router = Router::new()
|
||||
.route("/readyz", get(health_check_handler))
|
||||
.route("/healthz", get(health_check_handler))
|
||||
.fallback(any(websocket_upgrade_handler))
|
||||
.layer(middleware::from_fn(reject_requests_with_origin_header))
|
||||
.with_state(WebSocketListenerState {
|
||||
transport_event_tx,
|
||||
connection_counter: Arc::new(AtomicU64::new(1)),
|
||||
auth_policy: Arc::new(auth_policy),
|
||||
});
|
||||
let server = axum::serve(
|
||||
listener,
|
||||
router.into_make_service_with_connect_info::<SocketAddr>(),
|
||||
)
|
||||
.with_graceful_shutdown(async move {
|
||||
shutdown_token.cancelled().await;
|
||||
});
|
||||
Ok(tokio::spawn(async move {
|
||||
if let Err(err) = server.await {
|
||||
error!("websocket acceptor failed: {err}");
|
||||
}
|
||||
info!("websocket acceptor shutting down");
|
||||
}))
|
||||
}
|
||||
|
||||
async fn run_websocket_connection(
|
||||
connection_id: ConnectionId,
|
||||
websocket_stream: WebSocket,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
) {
|
||||
let (writer_tx, writer_rx) = mpsc::channel::<QueuedOutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let writer_tx_for_reader = writer_tx.clone();
|
||||
let disconnect_token = CancellationToken::new();
|
||||
if transport_event_tx
|
||||
.send(TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
writer: writer_tx,
|
||||
disconnect_sender: Some(disconnect_token.clone()),
|
||||
})
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
let (websocket_writer, websocket_reader) = websocket_stream.split();
|
||||
let (writer_control_tx, writer_control_rx) =
|
||||
mpsc::channel::<WebSocketMessage>(CHANNEL_CAPACITY);
|
||||
let mut outbound_task = tokio::spawn(run_websocket_outbound_loop(
|
||||
websocket_writer,
|
||||
writer_rx,
|
||||
writer_control_rx,
|
||||
disconnect_token.clone(),
|
||||
));
|
||||
let mut inbound_task = tokio::spawn(run_websocket_inbound_loop(
|
||||
websocket_reader,
|
||||
transport_event_tx.clone(),
|
||||
writer_tx_for_reader,
|
||||
writer_control_tx,
|
||||
connection_id,
|
||||
disconnect_token.clone(),
|
||||
));
|
||||
|
||||
tokio::select! {
|
||||
_ = &mut outbound_task => {
|
||||
disconnect_token.cancel();
|
||||
inbound_task.abort();
|
||||
}
|
||||
_ = &mut inbound_task => {
|
||||
disconnect_token.cancel();
|
||||
outbound_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
let _ = transport_event_tx
|
||||
.send(TransportEvent::ConnectionClosed { connection_id })
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn run_websocket_outbound_loop(
|
||||
mut websocket_writer: futures::stream::SplitSink<WebSocket, WebSocketMessage>,
|
||||
mut writer_rx: mpsc::Receiver<QueuedOutgoingMessage>,
|
||||
mut writer_control_rx: mpsc::Receiver<WebSocketMessage>,
|
||||
disconnect_token: CancellationToken,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = disconnect_token.cancelled() => {
|
||||
break;
|
||||
}
|
||||
message = writer_control_rx.recv() => {
|
||||
let Some(message) = message else {
|
||||
break;
|
||||
};
|
||||
if websocket_writer.send(message).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
queued_message = writer_rx.recv() => {
|
||||
let Some(queued_message) = queued_message else {
|
||||
break;
|
||||
};
|
||||
let Some(json) = serialize_outgoing_message(queued_message.message) else {
|
||||
continue;
|
||||
};
|
||||
if websocket_writer.send(WebSocketMessage::Text(json.into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
if let Some(write_complete_tx) = queued_message.write_complete_tx {
|
||||
let _ = write_complete_tx.send(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_websocket_inbound_loop(
|
||||
mut websocket_reader: futures::stream::SplitStream<WebSocket>,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
writer_tx_for_reader: mpsc::Sender<QueuedOutgoingMessage>,
|
||||
writer_control_tx: mpsc::Sender<WebSocketMessage>,
|
||||
connection_id: ConnectionId,
|
||||
disconnect_token: CancellationToken,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = disconnect_token.cancelled() => {
|
||||
break;
|
||||
}
|
||||
incoming_message = websocket_reader.next() => {
|
||||
match incoming_message {
|
||||
Some(Ok(WebSocketMessage::Text(text))) => {
|
||||
if !forward_incoming_message(
|
||||
&transport_event_tx,
|
||||
&writer_tx_for_reader,
|
||||
connection_id,
|
||||
text.as_ref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Some(Ok(WebSocketMessage::Ping(payload))) => {
|
||||
match writer_control_tx.try_send(WebSocketMessage::Pong(payload)) {
|
||||
Ok(()) => {}
|
||||
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => break,
|
||||
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
|
||||
warn!("websocket control queue full while replying to ping; closing connection");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(WebSocketMessage::Pong(_))) => {}
|
||||
Some(Ok(WebSocketMessage::Close(_))) | None => break,
|
||||
Some(Ok(WebSocketMessage::Binary(_))) => {
|
||||
warn!("dropping unsupported binary websocket message");
|
||||
}
|
||||
Some(Err(err)) => {
|
||||
warn!("websocket receive error: {err}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user