mirror of
https://github.com/openai/codex.git
synced 2026-04-05 15:01:40 +03:00
Compare commits
24 Commits
starr/appl
...
codex/webs
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
32706e4119 | ||
|
|
9bc01bc57b | ||
|
|
5b53648571 | ||
|
|
cb3931b0ad | ||
|
|
2230a7ca20 | ||
|
|
39097ab65d | ||
|
|
3a22e10172 | ||
|
|
c9e706f8b6 | ||
|
|
8a19dbb177 | ||
|
|
6edb865cc6 | ||
|
|
13d828d236 | ||
|
|
93672f05df | ||
|
|
de684bd7de | ||
|
|
044fe18b8a | ||
|
|
e716db62b1 | ||
|
|
0f1f511e3c | ||
|
|
457fabc409 | ||
|
|
f2623f20f3 | ||
|
|
bca7d04529 | ||
|
|
558b357a6c | ||
|
|
db806e3aaa | ||
|
|
5537c4c014 | ||
|
|
a8d299f065 | ||
|
|
2350059789 |
1
.bazelrc
1
.bazelrc
@@ -124,7 +124,6 @@ build:argument-comment-lint --@rules_rust//rust/toolchain/channel=nightly
|
||||
common:ci-windows --config=ci-bazel
|
||||
common:ci-windows --build_metadata=TAG_os=windows
|
||||
common:ci-windows --repo_contents_cache=D:/a/.cache/bazel-repo-contents-cache
|
||||
common:ci-windows --repository_cache=D:/a/.cache/bazel-repo-cache
|
||||
|
||||
# We prefer to run the build actions entirely remotely so we can dial up the concurrency.
|
||||
# We have platform-specific tests, so we want to execute the tests on all platforms using the strongest sandboxing available on each platform.
|
||||
|
||||
31
.github/actions/setup-bazel-ci/action.yml
vendored
31
.github/actions/setup-bazel-ci/action.yml
vendored
@@ -9,9 +9,9 @@ inputs:
|
||||
required: false
|
||||
default: "false"
|
||||
outputs:
|
||||
cache-hit:
|
||||
description: Whether the Bazel repository cache key was restored exactly.
|
||||
value: ${{ steps.cache_bazel_repository_restore.outputs.cache-hit }}
|
||||
repository-cache-path:
|
||||
description: Filesystem path used for the Bazel repository cache.
|
||||
value: ${{ steps.configure_bazel_repository_cache.outputs.repository-cache-path }}
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
@@ -41,17 +41,16 @@ runs:
|
||||
- name: Set up Bazel
|
||||
uses: bazelbuild/setup-bazelisk@v3
|
||||
|
||||
# Restore bazel repository cache so we don't have to redownload all the external dependencies
|
||||
# on every CI run.
|
||||
- name: Restore bazel repository cache
|
||||
id: cache_bazel_repository_restore
|
||||
uses: actions/cache/restore@v5
|
||||
with:
|
||||
path: |
|
||||
~/.cache/bazel-repo-cache
|
||||
key: bazel-cache-${{ inputs.target }}-${{ hashFiles('MODULE.bazel', 'codex-rs/Cargo.lock', 'codex-rs/Cargo.toml') }}
|
||||
restore-keys: |
|
||||
bazel-cache-${{ inputs.target }}
|
||||
- name: Configure Bazel repository cache
|
||||
id: configure_bazel_repository_cache
|
||||
shell: pwsh
|
||||
run: |
|
||||
# Keep the repository cache under HOME on all runners. Windows `D:\a`
|
||||
# cache paths match `.bazelrc`, but `actions/cache/restore` currently
|
||||
# returns HTTP 400 for that path in the Windows clippy job.
|
||||
$repositoryCachePath = Join-Path $HOME '.cache/bazel-repo-cache'
|
||||
"repository-cache-path=$repositoryCachePath" | Out-File -FilePath $env:GITHUB_OUTPUT -Encoding utf8 -Append
|
||||
"BAZEL_REPOSITORY_CACHE=$repositoryCachePath" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
|
||||
|
||||
- name: Configure Bazel output root (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
@@ -65,10 +64,6 @@ runs:
|
||||
$repoContentsCache = Join-Path $env:RUNNER_TEMP "bazel-repo-contents-cache-$env:GITHUB_RUN_ID-$env:GITHUB_JOB"
|
||||
"BAZEL_OUTPUT_USER_ROOT=$bazelOutputUserRoot" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
|
||||
"BAZEL_REPO_CONTENTS_CACHE=$repoContentsCache" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
|
||||
if (-not $hasDDrive) {
|
||||
$repositoryCache = Join-Path $env:USERPROFILE '.cache\bazel-repo-cache'
|
||||
"BAZEL_REPOSITORY_CACHE=$repositoryCache" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
|
||||
}
|
||||
|
||||
- name: Expose MSVC SDK environment (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
|
||||
54
.github/workflows/bazel.yml
vendored
54
.github/workflows/bazel.yml
vendored
@@ -58,6 +58,20 @@ jobs:
|
||||
target: ${{ matrix.target }}
|
||||
install-test-prereqs: "true"
|
||||
|
||||
# Restore the Bazel repository cache explicitly so external dependencies
|
||||
# do not need to be re-downloaded on every CI run. Keep restore failures
|
||||
# non-fatal so transient cache-service errors degrade to a cold build
|
||||
# instead of failing the job.
|
||||
- name: Restore bazel repository cache
|
||||
id: cache_bazel_repository_restore
|
||||
continue-on-error: true
|
||||
uses: actions/cache/restore@v5
|
||||
with:
|
||||
path: ${{ steps.setup_bazel.outputs.repository-cache-path }}
|
||||
key: bazel-cache-${{ matrix.target }}-${{ hashFiles('MODULE.bazel', 'codex-rs/Cargo.lock', 'codex-rs/Cargo.toml') }}
|
||||
restore-keys: |
|
||||
bazel-cache-${{ matrix.target }}
|
||||
|
||||
- name: Check MODULE.bazel.lock is up to date
|
||||
if: matrix.os == 'ubuntu-24.04' && matrix.target == 'x86_64-unknown-linux-gnu'
|
||||
shell: bash
|
||||
@@ -112,12 +126,11 @@ jobs:
|
||||
# Save bazel repository cache explicitly; make non-fatal so cache uploading
|
||||
# never fails the overall job. Only save when key wasn't hit.
|
||||
- name: Save bazel repository cache
|
||||
if: always() && !cancelled() && steps.setup_bazel.outputs.cache-hit != 'true'
|
||||
if: always() && !cancelled() && steps.cache_bazel_repository_restore.outputs.cache-hit != 'true'
|
||||
continue-on-error: true
|
||||
uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5
|
||||
with:
|
||||
path: |
|
||||
~/.cache/bazel-repo-cache
|
||||
path: ${{ steps.setup_bazel.outputs.repository-cache-path }}
|
||||
key: bazel-cache-${{ matrix.target }}-${{ hashFiles('MODULE.bazel', 'codex-rs/Cargo.lock', 'codex-rs/Cargo.toml') }}
|
||||
|
||||
clippy:
|
||||
@@ -148,6 +161,20 @@ jobs:
|
||||
with:
|
||||
target: ${{ matrix.target }}
|
||||
|
||||
# Restore the Bazel repository cache explicitly so external dependencies
|
||||
# do not need to be re-downloaded on every CI run. Keep restore failures
|
||||
# non-fatal so transient cache-service errors degrade to a cold build
|
||||
# instead of failing the job.
|
||||
- name: Restore bazel repository cache
|
||||
id: cache_bazel_repository_restore
|
||||
continue-on-error: true
|
||||
uses: actions/cache/restore@v5
|
||||
with:
|
||||
path: ${{ steps.setup_bazel.outputs.repository-cache-path }}
|
||||
key: bazel-cache-${{ matrix.target }}-${{ hashFiles('MODULE.bazel', 'codex-rs/Cargo.lock', 'codex-rs/Cargo.toml') }}
|
||||
restore-keys: |
|
||||
bazel-cache-${{ matrix.target }}
|
||||
|
||||
- name: Set up Bazel execution logs
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -159,6 +186,18 @@ jobs:
|
||||
BUILDBUDDY_API_KEY: ${{ secrets.BUILDBUDDY_API_KEY }}
|
||||
shell: bash
|
||||
run: |
|
||||
bazel_clippy_args=(
|
||||
--config=clippy
|
||||
--build_metadata=COMMIT_SHA=${GITHUB_SHA}
|
||||
--build_metadata=TAG_job=clippy
|
||||
)
|
||||
if [[ "${RUNNER_OS}" == "Windows" ]]; then
|
||||
# Some explicit targets pulled in through //codex-rs/... are
|
||||
# intentionally incompatible with `//:local_windows`, but the lint
|
||||
# aspect still traverses their compatible Rust deps.
|
||||
bazel_clippy_args+=(--skip_incompatible_explicit_targets)
|
||||
fi
|
||||
|
||||
bazel_target_lines="$(./scripts/list-bazel-clippy-targets.sh)"
|
||||
bazel_targets=()
|
||||
while IFS= read -r target; do
|
||||
@@ -168,9 +207,7 @@ jobs:
|
||||
./.github/scripts/run-bazel-ci.sh \
|
||||
-- \
|
||||
build \
|
||||
--config=clippy \
|
||||
--build_metadata=COMMIT_SHA=${GITHUB_SHA} \
|
||||
--build_metadata=TAG_job=clippy \
|
||||
"${bazel_clippy_args[@]}" \
|
||||
-- \
|
||||
"${bazel_targets[@]}"
|
||||
|
||||
@@ -186,10 +223,9 @@ jobs:
|
||||
# Save bazel repository cache explicitly; make non-fatal so cache uploading
|
||||
# never fails the overall job. Only save when key wasn't hit.
|
||||
- name: Save bazel repository cache
|
||||
if: always() && !cancelled() && steps.setup_bazel.outputs.cache-hit != 'true'
|
||||
if: always() && !cancelled() && steps.cache_bazel_repository_restore.outputs.cache-hit != 'true'
|
||||
continue-on-error: true
|
||||
uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5
|
||||
with:
|
||||
path: |
|
||||
~/.cache/bazel-repo-cache
|
||||
path: ${{ steps.setup_bazel.outputs.repository-cache-path }}
|
||||
key: bazel-cache-${{ matrix.target }}-${{ hashFiles('MODULE.bazel', 'codex-rs/Cargo.lock', 'codex-rs/Cargo.toml') }}
|
||||
|
||||
7
.github/workflows/rust-release.yml
vendored
7
.github/workflows/rust-release.yml
vendored
@@ -584,14 +584,11 @@ jobs:
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6
|
||||
with:
|
||||
node-version: 22
|
||||
# Node 24 bundles npm >= 11.5.1, which trusted publishing requires.
|
||||
node-version: 24
|
||||
registry-url: "https://registry.npmjs.org"
|
||||
scope: "@openai"
|
||||
|
||||
# Trusted publishing requires npm CLI version 11.5.1 or later.
|
||||
- name: Update npm
|
||||
run: npm install -g npm@latest
|
||||
|
||||
- name: Download npm tarballs from release
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
4
codex-rs/Cargo.lock
generated
4
codex-rs/Cargo.lock
generated
@@ -1433,9 +1433,11 @@ dependencies = [
|
||||
"codex-utils-cli",
|
||||
"codex-utils-json-to-toml",
|
||||
"codex-utils-pty",
|
||||
"codex-utils-rustls-provider",
|
||||
"constant_time_eq",
|
||||
"core_test_support",
|
||||
"futures",
|
||||
"gethostname",
|
||||
"hmac",
|
||||
"jsonwebtoken",
|
||||
"opentelemetry",
|
||||
@@ -1458,6 +1460,7 @@ dependencies = [
|
||||
"tracing",
|
||||
"tracing-opentelemetry",
|
||||
"tracing-subscriber",
|
||||
"url",
|
||||
"uuid",
|
||||
"wiremock",
|
||||
]
|
||||
@@ -2552,6 +2555,7 @@ dependencies = [
|
||||
"codex-process-hardening",
|
||||
"ctor 0.6.3",
|
||||
"libc",
|
||||
"pretty_assertions",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
||||
@@ -57,10 +57,12 @@ codex-state = { workspace = true }
|
||||
codex-tools = { workspace = true }
|
||||
codex-utils-absolute-path = { workspace = true }
|
||||
codex-utils-json-to-toml = { workspace = true }
|
||||
codex-utils-rustls-provider = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
clap = { workspace = true, features = ["derive"] }
|
||||
constant_time_eq = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
gethostname = { workspace = true }
|
||||
hmac = { workspace = true }
|
||||
jsonwebtoken = { workspace = true }
|
||||
owo-colors = { workspace = true, features = ["supports-colors"] }
|
||||
@@ -81,6 +83,7 @@ tokio-util = { workspace = true }
|
||||
tokio-tungstenite = { workspace = true }
|
||||
tracing = { workspace = true, features = ["log"] }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter", "fmt", "json"] }
|
||||
url = { workspace = true }
|
||||
uuid = { workspace = true, features = ["serde", "v7"] }
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -25,6 +25,7 @@ Supported transports:
|
||||
|
||||
- stdio (`--listen stdio://`, default): newline-delimited JSON (JSONL)
|
||||
- websocket (`--listen ws://IP:PORT`): one JSON-RPC message per websocket text frame (**experimental / unsupported**)
|
||||
- off (`--listen off`): do not expose a local transport
|
||||
|
||||
When running with `--listen ws://IP:PORT`, the same listener also serves basic HTTP health probes:
|
||||
|
||||
|
||||
@@ -86,6 +86,7 @@ fn transport_name(transport: AppServerTransport) -> &'static str {
|
||||
match transport {
|
||||
AppServerTransport::Stdio => "stdio",
|
||||
AppServerTransport::WebSocket { .. } => "websocket",
|
||||
AppServerTransport::Off => "off",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
17
codex-rs/app-server/src/auth_manager.rs
Normal file
17
codex-rs/app-server/src/auth_manager.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_core::config::Config;
|
||||
use codex_login::AuthManager;
|
||||
|
||||
pub(crate) fn auth_manager_from_config(
|
||||
config: &Config,
|
||||
enable_codex_api_key_env: bool,
|
||||
) -> Arc<AuthManager> {
|
||||
let auth_manager = AuthManager::shared(
|
||||
config.codex_home.clone(),
|
||||
enable_codex_api_key_env,
|
||||
config.cli_auth_credentials_store_mode,
|
||||
);
|
||||
auth_manager.set_forced_chatgpt_workspace_id(config.forced_chatgpt_workspace_id.clone());
|
||||
auth_manager
|
||||
}
|
||||
@@ -50,6 +50,7 @@ use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::auth_manager::auth_manager_from_config;
|
||||
use crate::error_code::INTERNAL_ERROR_CODE;
|
||||
use crate::error_code::INVALID_REQUEST_ERROR_CODE;
|
||||
use crate::error_code::OVERLOADED_ERROR_CODE;
|
||||
@@ -378,6 +379,8 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle {
|
||||
}
|
||||
});
|
||||
|
||||
let auth_manager = auth_manager_from_config(&args.config, args.enable_codex_api_key_env);
|
||||
|
||||
let processor_outgoing = Arc::clone(&outgoing_message_sender);
|
||||
let (processor_tx, mut processor_rx) = mpsc::channel::<ProcessorCommand>(channel_capacity);
|
||||
let mut processor_handle = tokio::spawn(async move {
|
||||
@@ -393,7 +396,7 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle {
|
||||
log_db: None,
|
||||
config_warnings: args.config_warnings,
|
||||
session_source: args.session_source,
|
||||
enable_codex_api_key_env: args.enable_codex_api_key_env,
|
||||
auth_manager,
|
||||
rpc_transport: AppServerRpcTransport::InProcess,
|
||||
});
|
||||
let mut thread_created_rx = processor.thread_created_receiver();
|
||||
|
||||
@@ -7,6 +7,7 @@ use codex_core::config::ConfigBuilder;
|
||||
use codex_core::config_loader::CloudRequirementsLoader;
|
||||
use codex_core::config_loader::ConfigLayerStackOrdering;
|
||||
use codex_core::config_loader::LoaderOverrides;
|
||||
use codex_features::Feature;
|
||||
use codex_utils_cli::CliConfigOverrides;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
@@ -16,6 +17,7 @@ use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
|
||||
use crate::auth_manager::auth_manager_from_config;
|
||||
use crate::message_processor::MessageProcessor;
|
||||
use crate::message_processor::MessageProcessorArgs;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
@@ -28,6 +30,7 @@ use crate::transport::OutboundConnectionState;
|
||||
use crate::transport::TransportEvent;
|
||||
use crate::transport::auth::policy_from_settings;
|
||||
use crate::transport::route_outgoing_envelope;
|
||||
use crate::transport::start_remote_control;
|
||||
use crate::transport::start_stdio_connection;
|
||||
use crate::transport::start_websocket_acceptor;
|
||||
use codex_analytics::AppServerRpcTransport;
|
||||
@@ -42,7 +45,6 @@ use codex_core::config_loader::ConfigLoadError;
|
||||
use codex_core::config_loader::TextRange as CoreTextRange;
|
||||
use codex_exec_server::EnvironmentManager;
|
||||
use codex_feedback::CodexFeedback;
|
||||
use codex_login::AuthManager;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_state::log_db;
|
||||
use tokio::sync::mpsc;
|
||||
@@ -61,6 +63,7 @@ use tracing_subscriber::registry::Registry;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
|
||||
mod app_server_tracing;
|
||||
mod auth_manager;
|
||||
mod bespoke_event_handling;
|
||||
mod codex_message_processor;
|
||||
mod command_exec;
|
||||
@@ -396,11 +399,8 @@ pub async fn run_main_with_transport(
|
||||
}
|
||||
}
|
||||
|
||||
let auth_manager = AuthManager::shared(
|
||||
config.codex_home.clone(),
|
||||
/*enable_codex_api_key_env*/ false,
|
||||
config.cli_auth_credentials_store_mode,
|
||||
);
|
||||
let auth_manager =
|
||||
auth_manager_from_config(&config, /*enable_codex_api_key_env*/ false);
|
||||
cloud_requirements_loader(
|
||||
auth_manager,
|
||||
config.chatgpt_base_url,
|
||||
@@ -502,13 +502,13 @@ pub async fn run_main_with_transport(
|
||||
|
||||
let feedback_layer = feedback.logger_layer();
|
||||
let feedback_metadata_layer = feedback.metadata_layer();
|
||||
let log_db = codex_state::StateRuntime::init(
|
||||
let state_db = codex_state::StateRuntime::init(
|
||||
config.sqlite_home.clone(),
|
||||
config.model_provider_id.clone(),
|
||||
)
|
||||
.await
|
||||
.ok()
|
||||
.map(log_db::start);
|
||||
.ok();
|
||||
let log_db = state_db.clone().map(log_db::start);
|
||||
let log_db_layer = log_db
|
||||
.clone()
|
||||
.map(|layer| layer.with_filter(Targets::new().with_default(Level::TRACE)));
|
||||
@@ -551,6 +551,27 @@ pub async fn run_main_with_transport(
|
||||
.await?;
|
||||
transport_accept_handles.push(accept_handle);
|
||||
}
|
||||
AppServerTransport::Off => {}
|
||||
}
|
||||
|
||||
let auth_manager = auth_manager_from_config(&config, /*enable_codex_api_key_env*/ false);
|
||||
|
||||
if config.features.enabled(Feature::RemoteControl) {
|
||||
let accept_handle = start_remote_control(
|
||||
config.chatgpt_base_url.clone(),
|
||||
state_db.clone(),
|
||||
auth_manager.clone(),
|
||||
transport_event_tx.clone(),
|
||||
transport_shutdown_token.clone(),
|
||||
)
|
||||
.await?;
|
||||
transport_accept_handles.push(accept_handle);
|
||||
}
|
||||
if transport_accept_handles.is_empty() {
|
||||
return Err(std::io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
"no transport configured; use --listen or enable remote control",
|
||||
));
|
||||
}
|
||||
|
||||
let outbound_handle = tokio::spawn(async move {
|
||||
@@ -625,7 +646,7 @@ pub async fn run_main_with_transport(
|
||||
log_db,
|
||||
config_warnings,
|
||||
session_source,
|
||||
enable_codex_api_key_env: false,
|
||||
auth_manager,
|
||||
rpc_transport: analytics_rpc_transport(transport),
|
||||
});
|
||||
let mut thread_created_rx = processor.thread_created_receiver();
|
||||
@@ -853,7 +874,9 @@ pub async fn run_main_with_transport(
|
||||
fn analytics_rpc_transport(transport: AppServerTransport) -> AppServerRpcTransport {
|
||||
match transport {
|
||||
AppServerTransport::Stdio => AppServerRpcTransport::Stdio,
|
||||
AppServerTransport::WebSocket { .. } => AppServerRpcTransport::Websocket,
|
||||
AppServerTransport::WebSocket { .. } | AppServerTransport::Off => {
|
||||
AppServerRpcTransport::Websocket
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ const MANAGED_CONFIG_PATH_ENV_VAR: &str = "CODEX_APP_SERVER_MANAGED_CONFIG_PATH"
|
||||
#[derive(Debug, Parser)]
|
||||
struct AppServerArgs {
|
||||
/// Transport endpoint URL. Supported values: `stdio://` (default),
|
||||
/// `ws://IP:PORT`.
|
||||
/// `ws://IP:PORT`, `off`.
|
||||
#[arg(
|
||||
long = "listen",
|
||||
value_name = "URL",
|
||||
|
||||
@@ -193,7 +193,7 @@ pub(crate) struct MessageProcessorArgs {
|
||||
pub(crate) log_db: Option<LogDbLayer>,
|
||||
pub(crate) config_warnings: Vec<ConfigWarningNotification>,
|
||||
pub(crate) session_source: SessionSource,
|
||||
pub(crate) enable_codex_api_key_env: bool,
|
||||
pub(crate) auth_manager: Arc<AuthManager>,
|
||||
pub(crate) rpc_transport: AppServerRpcTransport,
|
||||
}
|
||||
|
||||
@@ -213,17 +213,12 @@ impl MessageProcessor {
|
||||
log_db,
|
||||
config_warnings,
|
||||
session_source,
|
||||
enable_codex_api_key_env,
|
||||
auth_manager,
|
||||
rpc_transport,
|
||||
} = args;
|
||||
let auth_manager = AuthManager::shared_with_external_auth(
|
||||
config.codex_home.clone(),
|
||||
enable_codex_api_key_env,
|
||||
config.cli_auth_credentials_store_mode,
|
||||
Arc::new(ExternalAuthRefreshBridge {
|
||||
outgoing: outgoing.clone(),
|
||||
}),
|
||||
);
|
||||
auth_manager.set_external_auth(Arc::new(ExternalAuthRefreshBridge {
|
||||
outgoing: outgoing.clone(),
|
||||
}));
|
||||
let thread_manager = Arc::new(ThreadManager::new(
|
||||
config.as_ref(),
|
||||
auth_manager.clone(),
|
||||
@@ -235,7 +230,6 @@ impl MessageProcessor {
|
||||
},
|
||||
environment_manager,
|
||||
));
|
||||
auth_manager.set_forced_chatgpt_workspace_id(config.forced_chatgpt_workspace_id.clone());
|
||||
let analytics_events_client = AnalyticsEventsClient::new(
|
||||
Arc::clone(&auth_manager),
|
||||
config.chatgpt_base_url.trim_end_matches('/').to_string(),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use super::ConnectionSessionState;
|
||||
use super::MessageProcessor;
|
||||
use super::MessageProcessorArgs;
|
||||
use crate::auth_manager::auth_manager_from_config;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::transport::AppServerTransport;
|
||||
@@ -232,6 +233,8 @@ fn build_test_processor(
|
||||
MessageProcessor,
|
||||
mpsc::Receiver<crate::outgoing_message::OutgoingEnvelope>,
|
||||
) {
|
||||
let auth_manager = auth_manager_from_config(&config, /*enable_codex_api_key_env*/ false);
|
||||
|
||||
let (outgoing_tx, outgoing_rx) = mpsc::channel(16);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(outgoing_tx));
|
||||
let processor = MessageProcessor::new(MessageProcessorArgs {
|
||||
@@ -246,7 +249,7 @@ fn build_test_processor(
|
||||
log_db: None,
|
||||
config_warnings: Vec::new(),
|
||||
session_source: SessionSource::VSCode,
|
||||
enable_codex_api_key_env: false,
|
||||
auth_manager,
|
||||
rpc_transport: AppServerRpcTransport::Stdio,
|
||||
});
|
||||
(processor, outgoing_rx)
|
||||
|
||||
@@ -17,6 +17,7 @@ use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -28,9 +29,11 @@ use tracing::warn;
|
||||
/// plenty for an interactive CLI.
|
||||
pub(crate) const CHANNEL_CAPACITY: usize = 128;
|
||||
|
||||
mod remote_control;
|
||||
mod stdio;
|
||||
mod websocket;
|
||||
|
||||
pub(crate) use remote_control::start_remote_control;
|
||||
pub(crate) use stdio::start_stdio_connection;
|
||||
pub(crate) use websocket::start_websocket_acceptor;
|
||||
|
||||
@@ -38,6 +41,7 @@ pub(crate) use websocket::start_websocket_acceptor;
|
||||
pub enum AppServerTransport {
|
||||
Stdio,
|
||||
WebSocket { bind_address: SocketAddr },
|
||||
Off,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
@@ -51,7 +55,7 @@ impl std::fmt::Display for AppServerTransportParseError {
|
||||
match self {
|
||||
AppServerTransportParseError::UnsupportedListenUrl(listen_url) => write!(
|
||||
f,
|
||||
"unsupported --listen URL `{listen_url}`; expected `stdio://` or `ws://IP:PORT`"
|
||||
"unsupported --listen URL `{listen_url}`; expected `stdio://`, `ws://IP:PORT`, or `off`"
|
||||
),
|
||||
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url) => write!(
|
||||
f,
|
||||
@@ -71,6 +75,10 @@ impl AppServerTransport {
|
||||
return Ok(Self::Stdio);
|
||||
}
|
||||
|
||||
if listen_url == "off" {
|
||||
return Ok(Self::Off);
|
||||
}
|
||||
|
||||
if let Some(socket_addr) = listen_url.strip_prefix("ws://") {
|
||||
let bind_address = socket_addr.parse::<SocketAddr>().map_err(|_| {
|
||||
AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url.to_string())
|
||||
@@ -166,6 +174,12 @@ impl OutboundConnectionState {
|
||||
}
|
||||
}
|
||||
|
||||
static CONNECTION_ID_COUNTER: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
fn next_connection_id() -> ConnectionId {
|
||||
ConnectionId(CONNECTION_ID_COUNTER.fetch_add(1, Ordering::Relaxed))
|
||||
}
|
||||
|
||||
async fn forward_incoming_message(
|
||||
transport_event_tx: &mpsc::Sender<TransportEvent>,
|
||||
writer: &mpsc::Sender<QueuedOutgoingMessage>,
|
||||
@@ -378,8 +392,11 @@ pub(crate) async fn route_outgoing_envelope(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::error_code::OVERLOADED_ERROR_CODE;
|
||||
use codex_app_server_protocol::ConfigWarningNotification;
|
||||
use codex_app_server_protocol::JSONRPCNotification;
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use codex_app_server_protocol::JSONRPCResponse;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use codex_app_server_protocol::ServerNotification;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use pretty_assertions::assert_eq;
|
||||
@@ -393,41 +410,10 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_server_transport_parses_stdio_listen_url() {
|
||||
let transport = AppServerTransport::from_listen_url(AppServerTransport::DEFAULT_LISTEN_URL)
|
||||
.expect("stdio listen URL should parse");
|
||||
assert_eq!(transport, AppServerTransport::Stdio);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_server_transport_parses_websocket_listen_url() {
|
||||
let transport = AppServerTransport::from_listen_url("ws://127.0.0.1:1234")
|
||||
.expect("websocket listen URL should parse");
|
||||
fn listen_off_parses_as_off_transport() {
|
||||
assert_eq!(
|
||||
transport,
|
||||
AppServerTransport::WebSocket {
|
||||
bind_address: "127.0.0.1:1234".parse().expect("valid socket address"),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_server_transport_rejects_invalid_websocket_listen_url() {
|
||||
let err = AppServerTransport::from_listen_url("ws://localhost:1234")
|
||||
.expect_err("hostname bind address should be rejected");
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"invalid websocket --listen URL `ws://localhost:1234`; expected `ws://IP:PORT`"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_server_transport_rejects_unsupported_listen_url() {
|
||||
let err = AppServerTransport::from_listen_url("http://127.0.0.1:1234")
|
||||
.expect_err("unsupported scheme should fail");
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"unsupported --listen URL `http://127.0.0.1:1234`; expected `stdio://` or `ws://IP:PORT`"
|
||||
AppServerTransport::from_listen_url("off"),
|
||||
Ok(AppServerTransport::Off)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -437,11 +423,10 @@ mod tests {
|
||||
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
|
||||
let first_message =
|
||||
JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
let first_message = JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
transport_event_tx
|
||||
.send(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
@@ -450,8 +435,8 @@ mod tests {
|
||||
.await
|
||||
.expect("queue should accept first message");
|
||||
|
||||
let request = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest {
|
||||
id: codex_app_server_protocol::RequestId::Integer(7),
|
||||
let request = JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(7),
|
||||
method: "config/read".to_string(),
|
||||
params: Some(json!({ "includeLayers": false })),
|
||||
trace: None,
|
||||
@@ -499,11 +484,10 @@ mod tests {
|
||||
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1);
|
||||
let (writer_tx, _writer_rx) = mpsc::channel(1);
|
||||
|
||||
let first_message =
|
||||
JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
let first_message = JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
});
|
||||
transport_event_tx
|
||||
.send(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
@@ -512,8 +496,8 @@ mod tests {
|
||||
.await
|
||||
.expect("queue should accept first message");
|
||||
|
||||
let response = JSONRPCMessage::Response(codex_app_server_protocol::JSONRPCResponse {
|
||||
id: codex_app_server_protocol::RequestId::Integer(7),
|
||||
let response = JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id: RequestId::Integer(7),
|
||||
result: json!({"ok": true}),
|
||||
});
|
||||
let transport_event_tx_for_enqueue = transport_event_tx.clone();
|
||||
@@ -553,11 +537,10 @@ mod tests {
|
||||
match forwarded_event {
|
||||
TransportEvent::IncomingMessage {
|
||||
connection_id: queued_connection_id,
|
||||
message:
|
||||
JSONRPCMessage::Response(codex_app_server_protocol::JSONRPCResponse { id, result }),
|
||||
message: JSONRPCMessage::Response(JSONRPCResponse { id, result }),
|
||||
} => {
|
||||
assert_eq!(queued_connection_id, connection_id);
|
||||
assert_eq!(id, codex_app_server_protocol::RequestId::Integer(7));
|
||||
assert_eq!(id, RequestId::Integer(7));
|
||||
assert_eq!(result, json!({"ok": true}));
|
||||
}
|
||||
_ => panic!("expected forwarded response message"),
|
||||
@@ -573,12 +556,10 @@ mod tests {
|
||||
transport_event_tx
|
||||
.send(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message: JSONRPCMessage::Notification(
|
||||
codex_app_server_protocol::JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
},
|
||||
),
|
||||
message: JSONRPCMessage::Notification(JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
}),
|
||||
})
|
||||
.await
|
||||
.expect("transport queue should accept first message");
|
||||
@@ -597,15 +578,15 @@ mod tests {
|
||||
.await
|
||||
.expect("writer queue should accept first message");
|
||||
|
||||
let request = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest {
|
||||
id: codex_app_server_protocol::RequestId::Integer(7),
|
||||
let request = JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(7),
|
||||
method: "config/read".to_string(),
|
||||
params: Some(json!({ "includeLayers": false })),
|
||||
trace: None,
|
||||
});
|
||||
|
||||
let enqueue_result = tokio::time::timeout(
|
||||
std::time::Duration::from_millis(100),
|
||||
let enqueue_result = timeout(
|
||||
Duration::from_millis(100),
|
||||
enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request),
|
||||
)
|
||||
.await
|
||||
@@ -781,7 +762,7 @@ mod tests {
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval {
|
||||
request_id: codex_app_server_protocol::RequestId::Integer(1),
|
||||
request_id: RequestId::Integer(1),
|
||||
params: codex_app_server_protocol::CommandExecutionRequestApprovalParams {
|
||||
thread_id: "thr_123".to_string(),
|
||||
turn_id: "turn_123".to_string(),
|
||||
@@ -843,7 +824,7 @@ mod tests {
|
||||
OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval {
|
||||
request_id: codex_app_server_protocol::RequestId::Integer(1),
|
||||
request_id: RequestId::Integer(1),
|
||||
params: codex_app_server_protocol::CommandExecutionRequestApprovalParams {
|
||||
thread_id: "thr_123".to_string(),
|
||||
turn_id: "turn_123".to_string(),
|
||||
|
||||
@@ -0,0 +1,568 @@
|
||||
use super::CHANNEL_CAPACITY;
|
||||
use super::TransportEvent;
|
||||
use super::next_connection_id;
|
||||
use super::protocol::ClientEnvelope;
|
||||
pub use super::protocol::ClientEvent;
|
||||
pub use super::protocol::ClientId;
|
||||
use super::protocol::PongStatus;
|
||||
use super::protocol::ServerEvent;
|
||||
use super::protocol::StreamId;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::QueuedOutgoingMessage;
|
||||
use crate::transport::remote_control::QueuedServerEnvelope;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::watch;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::Instant;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
const REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT: Duration = Duration::from_secs(10 * 60);
|
||||
pub(crate) const REMOTE_CONTROL_IDLE_SWEEP_INTERVAL: Duration = Duration::from_secs(30);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Stopped;
|
||||
|
||||
struct ClientState {
|
||||
connection_id: ConnectionId,
|
||||
disconnect_token: CancellationToken,
|
||||
last_activity_at: Instant,
|
||||
last_inbound_seq_id: Option<u64>,
|
||||
status_tx: watch::Sender<PongStatus>,
|
||||
}
|
||||
|
||||
pub(crate) struct ClientTracker {
|
||||
clients: HashMap<(ClientId, StreamId), ClientState>,
|
||||
legacy_stream_ids: HashMap<ClientId, StreamId>,
|
||||
join_set: JoinSet<(ClientId, StreamId)>,
|
||||
server_event_tx: mpsc::Sender<QueuedServerEnvelope>,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
shutdown_token: CancellationToken,
|
||||
}
|
||||
|
||||
impl ClientTracker {
|
||||
pub(crate) fn new(
|
||||
server_event_tx: mpsc::Sender<QueuedServerEnvelope>,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
shutdown_token: &CancellationToken,
|
||||
) -> Self {
|
||||
Self {
|
||||
clients: HashMap::new(),
|
||||
legacy_stream_ids: HashMap::new(),
|
||||
join_set: JoinSet::new(),
|
||||
server_event_tx,
|
||||
transport_event_tx,
|
||||
shutdown_token: shutdown_token.child_token(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn bookkeep_join_set(&mut self) -> Option<(ClientId, StreamId)> {
|
||||
while let Some(join_result) = self.join_set.join_next().await {
|
||||
let Ok(client_key) = join_result else {
|
||||
continue;
|
||||
};
|
||||
return Some(client_key);
|
||||
}
|
||||
futures::future::pending().await
|
||||
}
|
||||
|
||||
pub(crate) async fn shutdown(&mut self) {
|
||||
self.shutdown_token.cancel();
|
||||
|
||||
while let Some(client_key) = self.clients.keys().next().cloned() {
|
||||
let _ = self.close_client(&client_key).await;
|
||||
}
|
||||
|
||||
self.drain_join_set().await;
|
||||
}
|
||||
|
||||
async fn drain_join_set(&mut self) {
|
||||
while self.join_set.join_next().await.is_some() {}
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_message(
|
||||
&mut self,
|
||||
client_envelope: ClientEnvelope,
|
||||
) -> Result<(), Stopped> {
|
||||
let ClientEnvelope {
|
||||
client_id,
|
||||
event,
|
||||
stream_id,
|
||||
seq_id,
|
||||
cursor: _,
|
||||
} = client_envelope;
|
||||
let is_legacy_stream_id = stream_id.is_none();
|
||||
let is_initialize = matches!(&event, ClientEvent::ClientMessage { message } if remote_control_message_starts_connection(message));
|
||||
let stream_id = match stream_id {
|
||||
Some(stream_id) => stream_id,
|
||||
None if is_initialize => {
|
||||
// TODO(ruslan): delete this fallback once all clients are updated to send stream_id.
|
||||
self.legacy_stream_ids
|
||||
.remove(&client_id)
|
||||
.unwrap_or_else(StreamId::new_random)
|
||||
}
|
||||
None => self
|
||||
.legacy_stream_ids
|
||||
.get(&client_id)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| {
|
||||
if matches!(&event, ClientEvent::Ping) {
|
||||
StreamId::new_random()
|
||||
} else {
|
||||
StreamId(String::new())
|
||||
}
|
||||
}),
|
||||
};
|
||||
if stream_id.0.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let client_key = (client_id.clone(), stream_id.clone());
|
||||
match event {
|
||||
ClientEvent::ClientMessage { message } => {
|
||||
if let Some(seq_id) = seq_id
|
||||
&& let Some(client) = self.clients.get(&client_key)
|
||||
&& client
|
||||
.last_inbound_seq_id
|
||||
.is_some_and(|last_seq_id| last_seq_id >= seq_id)
|
||||
&& !is_initialize
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if is_initialize && self.clients.contains_key(&client_key) {
|
||||
self.close_client(&client_key).await?;
|
||||
}
|
||||
|
||||
if let Some(connection_id) = self.clients.get_mut(&client_key).map(|client| {
|
||||
client.last_activity_at = Instant::now();
|
||||
if let Some(seq_id) = seq_id {
|
||||
client.last_inbound_seq_id = Some(seq_id);
|
||||
}
|
||||
client.connection_id
|
||||
}) {
|
||||
self.send_transport_event(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message,
|
||||
})
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if !is_initialize {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let connection_id = next_connection_id();
|
||||
let (writer_tx, writer_rx) =
|
||||
mpsc::channel::<QueuedOutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let disconnect_token = self.shutdown_token.child_token();
|
||||
self.send_transport_event(TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
writer: writer_tx,
|
||||
disconnect_sender: Some(disconnect_token.clone()),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let (status_tx, status_rx) = watch::channel(PongStatus::Active);
|
||||
self.join_set.spawn(Self::run_client_outbound(
|
||||
client_id.clone(),
|
||||
stream_id.clone(),
|
||||
self.server_event_tx.clone(),
|
||||
writer_rx,
|
||||
status_rx,
|
||||
disconnect_token.clone(),
|
||||
));
|
||||
self.clients.insert(
|
||||
client_key,
|
||||
ClientState {
|
||||
connection_id,
|
||||
disconnect_token,
|
||||
last_activity_at: Instant::now(),
|
||||
last_inbound_seq_id: if is_legacy_stream_id { None } else { seq_id },
|
||||
status_tx,
|
||||
},
|
||||
);
|
||||
if is_legacy_stream_id {
|
||||
self.legacy_stream_ids.insert(client_id.clone(), stream_id);
|
||||
}
|
||||
self.send_transport_event(TransportEvent::IncomingMessage {
|
||||
connection_id,
|
||||
message,
|
||||
})
|
||||
.await
|
||||
}
|
||||
ClientEvent::Ack => Ok(()),
|
||||
ClientEvent::Ping => {
|
||||
if let Some(client) = self.clients.get_mut(&client_key) {
|
||||
client.last_activity_at = Instant::now();
|
||||
let _ = client.status_tx.send(PongStatus::Active);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let server_event_tx = self.server_event_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
let server_envelope = QueuedServerEnvelope {
|
||||
event: ServerEvent::Pong {
|
||||
status: PongStatus::Unknown,
|
||||
},
|
||||
client_id,
|
||||
stream_id,
|
||||
write_complete_tx: None,
|
||||
};
|
||||
let _ = server_event_tx.send(server_envelope).await;
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
ClientEvent::ClientClosed => self.close_client(&client_key).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_client_outbound(
|
||||
client_id: ClientId,
|
||||
stream_id: StreamId,
|
||||
server_event_tx: mpsc::Sender<QueuedServerEnvelope>,
|
||||
mut writer_rx: mpsc::Receiver<QueuedOutgoingMessage>,
|
||||
mut status_rx: watch::Receiver<PongStatus>,
|
||||
disconnect_token: CancellationToken,
|
||||
) -> (ClientId, StreamId) {
|
||||
loop {
|
||||
let (event, write_complete_tx) = tokio::select! {
|
||||
_ = disconnect_token.cancelled() => {
|
||||
break;
|
||||
}
|
||||
queued_message = writer_rx.recv() => {
|
||||
let Some(queued_message) = queued_message else {
|
||||
break;
|
||||
};
|
||||
let event = ServerEvent::ServerMessage {
|
||||
message: Box::new(queued_message.message),
|
||||
};
|
||||
(event, queued_message.write_complete_tx)
|
||||
}
|
||||
changed = status_rx.changed() => {
|
||||
if changed.is_err() {
|
||||
break;
|
||||
}
|
||||
let event = ServerEvent::Pong { status: status_rx.borrow().clone() };
|
||||
(event, None)
|
||||
}
|
||||
};
|
||||
let send_result = tokio::select! {
|
||||
_ = disconnect_token.cancelled() => {
|
||||
break;
|
||||
}
|
||||
send_result = server_event_tx.send(QueuedServerEnvelope {
|
||||
event,
|
||||
client_id: client_id.clone(),
|
||||
stream_id: stream_id.clone(),
|
||||
write_complete_tx,
|
||||
}) => send_result,
|
||||
};
|
||||
if send_result.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
(client_id, stream_id)
|
||||
}
|
||||
|
||||
pub(crate) async fn close_expired_clients(
|
||||
&mut self,
|
||||
) -> Result<Vec<(ClientId, StreamId)>, Stopped> {
|
||||
let now = Instant::now();
|
||||
let expired_client_ids: Vec<(ClientId, StreamId)> = self
|
||||
.clients
|
||||
.iter()
|
||||
.filter_map(|(client_key, client)| {
|
||||
(!remote_control_client_is_alive(client, now)).then_some(client_key.clone())
|
||||
})
|
||||
.collect();
|
||||
for client_key in &expired_client_ids {
|
||||
self.close_client(client_key).await?;
|
||||
}
|
||||
Ok(expired_client_ids)
|
||||
}
|
||||
|
||||
pub(super) async fn close_client(
|
||||
&mut self,
|
||||
client_key: &(ClientId, StreamId),
|
||||
) -> Result<(), Stopped> {
|
||||
let Some(client) = self.clients.remove(client_key) else {
|
||||
return Ok(());
|
||||
};
|
||||
if self
|
||||
.legacy_stream_ids
|
||||
.get(&client_key.0)
|
||||
.is_some_and(|stream_id| stream_id == &client_key.1)
|
||||
{
|
||||
self.legacy_stream_ids.remove(&client_key.0);
|
||||
}
|
||||
client.disconnect_token.cancel();
|
||||
self.send_transport_event(TransportEvent::ConnectionClosed {
|
||||
connection_id: client.connection_id,
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn send_transport_event(&self, event: TransportEvent) -> Result<(), Stopped> {
|
||||
self.transport_event_tx
|
||||
.send(event)
|
||||
.await
|
||||
.map_err(|_| Stopped)
|
||||
}
|
||||
}
|
||||
|
||||
fn remote_control_message_starts_connection(message: &JSONRPCMessage) -> bool {
|
||||
matches!(
|
||||
message,
|
||||
JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { method, .. })
|
||||
if method == "initialize"
|
||||
)
|
||||
}
|
||||
|
||||
fn remote_control_client_is_alive(client: &ClientState, now: Instant) -> bool {
|
||||
now.duration_since(client.last_activity_at) < REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use crate::transport::remote_control::protocol::ClientEnvelope;
|
||||
use crate::transport::remote_control::protocol::ClientEvent;
|
||||
use codex_app_server_protocol::ConfigWarningNotification;
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use codex_app_server_protocol::ServerNotification;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tokio::time::timeout;
|
||||
|
||||
fn initialize_envelope(client_id: &str) -> ClientEnvelope {
|
||||
initialize_envelope_with_stream_id(client_id, /*stream_id*/ None)
|
||||
}
|
||||
|
||||
fn initialize_envelope_with_stream_id(
|
||||
client_id: &str,
|
||||
stream_id: Option<&str>,
|
||||
) -> ClientEnvelope {
|
||||
ClientEnvelope {
|
||||
event: ClientEvent::ClientMessage {
|
||||
message: JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(1),
|
||||
method: "initialize".to_string(),
|
||||
params: Some(json!({
|
||||
"clientInfo": {
|
||||
"name": "remote-test-client",
|
||||
"version": "0.1.0"
|
||||
}
|
||||
})),
|
||||
trace: None,
|
||||
}),
|
||||
},
|
||||
client_id: ClientId(client_id.to_string()),
|
||||
stream_id: stream_id.map(|stream_id| StreamId(stream_id.to_string())),
|
||||
seq_id: Some(0),
|
||||
cursor: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cancelled_outbound_task_emits_connection_closed() {
|
||||
let (server_event_tx, _server_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let shutdown_token = CancellationToken::new();
|
||||
let mut client_tracker =
|
||||
ClientTracker::new(server_event_tx, transport_event_tx, &shutdown_token);
|
||||
|
||||
client_tracker
|
||||
.handle_message(initialize_envelope("client-1"))
|
||||
.await
|
||||
.expect("initialize should open client");
|
||||
|
||||
let (connection_id, disconnect_sender) = match transport_event_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("connection opened should be sent")
|
||||
{
|
||||
TransportEvent::ConnectionOpened {
|
||||
connection_id,
|
||||
disconnect_sender: Some(disconnect_sender),
|
||||
..
|
||||
} => (connection_id, disconnect_sender),
|
||||
other => panic!("expected connection opened, got {other:?}"),
|
||||
};
|
||||
match transport_event_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("initialize should be forwarded")
|
||||
{
|
||||
TransportEvent::IncomingMessage {
|
||||
connection_id: incoming_connection_id,
|
||||
..
|
||||
} => assert_eq!(incoming_connection_id, connection_id),
|
||||
other => panic!("expected incoming initialize, got {other:?}"),
|
||||
}
|
||||
|
||||
disconnect_sender.cancel();
|
||||
let closed_client_id = timeout(Duration::from_secs(1), client_tracker.bookkeep_join_set())
|
||||
.await
|
||||
.expect("bookkeeping should process the closed task")
|
||||
.expect("closed task should return client id");
|
||||
assert_eq!(closed_client_id.0, ClientId("client-1".to_string()));
|
||||
client_tracker
|
||||
.close_client(&closed_client_id)
|
||||
.await
|
||||
.expect("closed client should emit connection closed");
|
||||
|
||||
match transport_event_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("connection closed should be sent")
|
||||
{
|
||||
TransportEvent::ConnectionClosed {
|
||||
connection_id: closed_connection_id,
|
||||
} => assert_eq!(closed_connection_id, connection_id),
|
||||
other => panic!("expected connection closed, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_cancels_blocked_outbound_forwarding() {
|
||||
let (server_event_tx, _server_event_rx) = mpsc::channel(1);
|
||||
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let shutdown_token = CancellationToken::new();
|
||||
let mut client_tracker =
|
||||
ClientTracker::new(server_event_tx.clone(), transport_event_tx, &shutdown_token);
|
||||
|
||||
server_event_tx
|
||||
.send(QueuedServerEnvelope {
|
||||
event: ServerEvent::Pong {
|
||||
status: PongStatus::Unknown,
|
||||
},
|
||||
client_id: ClientId("queued-client".to_string()),
|
||||
stream_id: StreamId("queued-stream".to_string()),
|
||||
write_complete_tx: None,
|
||||
})
|
||||
.await
|
||||
.expect("server event queue should accept prefill");
|
||||
|
||||
client_tracker
|
||||
.handle_message(initialize_envelope("client-1"))
|
||||
.await
|
||||
.expect("initialize should open client");
|
||||
|
||||
let writer = match transport_event_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("connection opened should be sent")
|
||||
{
|
||||
TransportEvent::ConnectionOpened { writer, .. } => writer,
|
||||
other => panic!("expected connection opened, got {other:?}"),
|
||||
};
|
||||
let _ = transport_event_rx
|
||||
.recv()
|
||||
.await
|
||||
.expect("initialize should be forwarded");
|
||||
|
||||
writer
|
||||
.send(QueuedOutgoingMessage::new(
|
||||
OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning(
|
||||
ConfigWarningNotification {
|
||||
summary: "test".to_string(),
|
||||
details: None,
|
||||
path: None,
|
||||
range: None,
|
||||
},
|
||||
)),
|
||||
))
|
||||
.await
|
||||
.expect("writer should accept queued message");
|
||||
|
||||
timeout(Duration::from_secs(1), client_tracker.shutdown())
|
||||
.await
|
||||
.expect("shutdown should not hang on blocked server forwarding");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn initialize_with_new_stream_id_opens_new_connection_for_same_client() {
|
||||
let (server_event_tx, _server_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let shutdown_token = CancellationToken::new();
|
||||
let mut client_tracker =
|
||||
ClientTracker::new(server_event_tx, transport_event_tx, &shutdown_token);
|
||||
|
||||
client_tracker
|
||||
.handle_message(initialize_envelope_with_stream_id(
|
||||
"client-1",
|
||||
Some("stream-1"),
|
||||
))
|
||||
.await
|
||||
.expect("first initialize should open client");
|
||||
let first_connection_id = match transport_event_rx.recv().await.expect("open event") {
|
||||
TransportEvent::ConnectionOpened { connection_id, .. } => connection_id,
|
||||
other => panic!("expected connection opened, got {other:?}"),
|
||||
};
|
||||
let _ = transport_event_rx.recv().await.expect("initialize event");
|
||||
|
||||
client_tracker
|
||||
.handle_message(initialize_envelope_with_stream_id(
|
||||
"client-1",
|
||||
Some("stream-2"),
|
||||
))
|
||||
.await
|
||||
.expect("second initialize should open client");
|
||||
let second_connection_id = match transport_event_rx.recv().await.expect("open event") {
|
||||
TransportEvent::ConnectionOpened { connection_id, .. } => connection_id,
|
||||
other => panic!("expected connection opened, got {other:?}"),
|
||||
};
|
||||
|
||||
assert_ne!(first_connection_id, second_connection_id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn legacy_initialize_without_stream_id_resets_inbound_seq_id() {
|
||||
let (server_event_tx, _server_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (transport_event_tx, mut transport_event_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let shutdown_token = CancellationToken::new();
|
||||
let mut client_tracker =
|
||||
ClientTracker::new(server_event_tx, transport_event_tx, &shutdown_token);
|
||||
|
||||
client_tracker
|
||||
.handle_message(initialize_envelope("client-1"))
|
||||
.await
|
||||
.expect("initialize should open client");
|
||||
let connection_id = match transport_event_rx.recv().await.expect("open event") {
|
||||
TransportEvent::ConnectionOpened { connection_id, .. } => connection_id,
|
||||
other => panic!("expected connection opened, got {other:?}"),
|
||||
};
|
||||
let _ = transport_event_rx.recv().await.expect("initialize event");
|
||||
|
||||
client_tracker
|
||||
.handle_message(ClientEnvelope {
|
||||
event: ClientEvent::ClientMessage {
|
||||
message: JSONRPCMessage::Notification(
|
||||
codex_app_server_protocol::JSONRPCNotification {
|
||||
method: "initialized".to_string(),
|
||||
params: None,
|
||||
},
|
||||
),
|
||||
},
|
||||
client_id: ClientId("client-1".to_string()),
|
||||
stream_id: None,
|
||||
seq_id: Some(0),
|
||||
cursor: None,
|
||||
})
|
||||
.await
|
||||
.expect("legacy followup should be forwarded");
|
||||
|
||||
match transport_event_rx.recv().await.expect("followup event") {
|
||||
TransportEvent::IncomingMessage {
|
||||
connection_id: incoming_connection_id,
|
||||
..
|
||||
} => assert_eq!(incoming_connection_id, connection_id),
|
||||
other => panic!("expected incoming message, got {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
478
codex-rs/app-server/src/transport/remote_control/enroll.rs
Normal file
478
codex-rs/app-server/src/transport/remote_control/enroll.rs
Normal file
@@ -0,0 +1,478 @@
|
||||
use super::protocol::EnrollRemoteServerRequest;
|
||||
use super::protocol::EnrollRemoteServerResponse;
|
||||
use super::protocol::RemoteControlTarget;
|
||||
use axum::http::HeaderMap;
|
||||
use codex_login::default_client::build_reqwest_client;
|
||||
use codex_state::StateRuntime;
|
||||
use gethostname::gethostname;
|
||||
use std::io;
|
||||
use std::io::ErrorKind;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
const REMOTE_CONTROL_ENROLL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
|
||||
const REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES: usize = 4096;
|
||||
|
||||
const REQUEST_ID_HEADER: &str = "x-request-id";
|
||||
const OAI_REQUEST_ID_HEADER: &str = "x-oai-request-id";
|
||||
const CF_RAY_HEADER: &str = "cf-ray";
|
||||
pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id";
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(super) struct RemoteControlEnrollment {
|
||||
pub(super) account_id: Option<String>,
|
||||
pub(super) environment_id: String,
|
||||
pub(super) server_id: String,
|
||||
pub(super) server_name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(super) struct RemoteControlConnectionAuth {
|
||||
pub(super) bearer_token: String,
|
||||
pub(super) account_id: Option<String>,
|
||||
}
|
||||
|
||||
pub(super) async fn load_persisted_remote_control_enrollment(
|
||||
state_db: Option<&StateRuntime>,
|
||||
remote_control_target: &RemoteControlTarget,
|
||||
account_id: Option<&str>,
|
||||
) -> Option<RemoteControlEnrollment> {
|
||||
let Some(state_db) = state_db else {
|
||||
info!(
|
||||
"remote control enrollment cache unavailable because sqlite state db is disabled: websocket_url={}, account_id={:?}",
|
||||
remote_control_target.websocket_url, account_id
|
||||
);
|
||||
return None;
|
||||
};
|
||||
let enrollment = match state_db
|
||||
.get_remote_control_enrollment(&remote_control_target.websocket_url, account_id)
|
||||
.await
|
||||
{
|
||||
Ok(enrollment) => enrollment,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"failed to load persisted remote control enrollment: websocket_url={}, account_id={:?}, err={err}",
|
||||
remote_control_target.websocket_url, account_id
|
||||
);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
match enrollment {
|
||||
Some((server_id, environment_id, server_name)) => {
|
||||
info!(
|
||||
"reusing persisted remote control enrollment: websocket_url={}, account_id={:?}, server_id={}, environment_id={}",
|
||||
remote_control_target.websocket_url, account_id, server_id, environment_id
|
||||
);
|
||||
Some(RemoteControlEnrollment {
|
||||
account_id: account_id.map(&str::to_string),
|
||||
environment_id,
|
||||
server_id,
|
||||
server_name,
|
||||
})
|
||||
}
|
||||
None => {
|
||||
info!(
|
||||
"no persisted remote control enrollment found: websocket_url={}, account_id={:?}",
|
||||
remote_control_target.websocket_url, account_id
|
||||
);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn update_persisted_remote_control_enrollment(
|
||||
state_db: Option<&StateRuntime>,
|
||||
remote_control_target: &RemoteControlTarget,
|
||||
account_id: Option<&str>,
|
||||
enrollment: Option<&RemoteControlEnrollment>,
|
||||
) -> io::Result<()> {
|
||||
let Some(state_db) = state_db else {
|
||||
info!(
|
||||
"skipping remote control enrollment persistence because sqlite state db is disabled: websocket_url={}, account_id={:?}, has_enrollment={}",
|
||||
remote_control_target.websocket_url,
|
||||
account_id,
|
||||
enrollment.is_some()
|
||||
);
|
||||
return Ok(());
|
||||
};
|
||||
if let &Some(enrollment) = &enrollment
|
||||
&& enrollment.account_id.as_deref() != account_id
|
||||
{
|
||||
return Err(io::Error::other(format!(
|
||||
"enrollment account_id does not match expected account_id `{account_id:?}`"
|
||||
)));
|
||||
}
|
||||
|
||||
if let Some(enrollment) = enrollment {
|
||||
state_db
|
||||
.upsert_remote_control_enrollment(
|
||||
&remote_control_target.websocket_url,
|
||||
account_id,
|
||||
&enrollment.server_id,
|
||||
&enrollment.environment_id,
|
||||
&enrollment.server_name,
|
||||
)
|
||||
.await
|
||||
.map_err(io::Error::other)?;
|
||||
info!(
|
||||
"persisted remote control enrollment: websocket_url={}, account_id={:?}, server_id={}, environment_id={}",
|
||||
remote_control_target.websocket_url,
|
||||
account_id,
|
||||
enrollment.server_id,
|
||||
enrollment.environment_id
|
||||
);
|
||||
Ok(())
|
||||
} else {
|
||||
let rows_affected = state_db
|
||||
.delete_remote_control_enrollment(&remote_control_target.websocket_url, account_id)
|
||||
.await
|
||||
.map_err(io::Error::other)?;
|
||||
info!(
|
||||
"cleared persisted remote control enrollment: websocket_url={}, account_id={:?}, rows_affected={rows_affected}",
|
||||
remote_control_target.websocket_url, account_id
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn preview_remote_control_response_body(body: &[u8]) -> String {
|
||||
let body = String::from_utf8_lossy(body);
|
||||
let trimmed = body.trim();
|
||||
if trimmed.is_empty() {
|
||||
return "<empty>".to_string();
|
||||
}
|
||||
if trimmed.len() <= REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES {
|
||||
return trimmed.to_string();
|
||||
}
|
||||
|
||||
let mut cut = REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES;
|
||||
while !trimmed.is_char_boundary(cut) {
|
||||
cut = cut.saturating_sub(1);
|
||||
}
|
||||
let mut truncated = trimmed[..cut].to_string();
|
||||
truncated.push_str("...");
|
||||
truncated
|
||||
}
|
||||
|
||||
pub(crate) fn format_headers(headers: &HeaderMap) -> String {
|
||||
let request_id_str = headers
|
||||
.get(REQUEST_ID_HEADER)
|
||||
.or_else(|| headers.get(OAI_REQUEST_ID_HEADER))
|
||||
.map(|value| value.to_str().unwrap_or("<invalid utf-8>").to_owned())
|
||||
.unwrap_or_else(|| "<none>".to_owned());
|
||||
let cf_ray_str = headers
|
||||
.get(CF_RAY_HEADER)
|
||||
.map(|value| value.to_str().unwrap_or("<invalid utf-8>").to_owned())
|
||||
.unwrap_or_else(|| "<none>".to_owned());
|
||||
format!("request-id: {request_id_str}, cf-ray: {cf_ray_str}")
|
||||
}
|
||||
|
||||
pub(super) async fn enroll_remote_control_server(
|
||||
remote_control_target: &RemoteControlTarget,
|
||||
auth: &RemoteControlConnectionAuth,
|
||||
) -> io::Result<RemoteControlEnrollment> {
|
||||
let enroll_url = &remote_control_target.enroll_url;
|
||||
let server_name = gethostname().to_string_lossy().trim().to_string();
|
||||
let request = EnrollRemoteServerRequest {
|
||||
name: server_name.clone(),
|
||||
os: std::env::consts::OS,
|
||||
arch: std::env::consts::ARCH,
|
||||
app_server_version: env!("CARGO_PKG_VERSION"),
|
||||
};
|
||||
let client = build_reqwest_client();
|
||||
let mut http_request = client
|
||||
.post(enroll_url)
|
||||
.timeout(REMOTE_CONTROL_ENROLL_TIMEOUT)
|
||||
.bearer_auth(&auth.bearer_token)
|
||||
.json(&request);
|
||||
let account_id = auth.account_id.as_deref();
|
||||
if let Some(account_id) = account_id {
|
||||
http_request = http_request.header(REMOTE_CONTROL_ACCOUNT_ID_HEADER, account_id);
|
||||
}
|
||||
|
||||
let response = http_request.send().await.map_err(|err| {
|
||||
io::Error::other(format!(
|
||||
"failed to enroll remote control server at `{enroll_url}`: {err}"
|
||||
))
|
||||
})?;
|
||||
let headers = response.headers().clone();
|
||||
let status = response.status();
|
||||
let body = response.bytes().await.map_err(|err| {
|
||||
io::Error::other(format!(
|
||||
"failed to read remote control enrollment response from `{enroll_url}`: {err}"
|
||||
))
|
||||
})?;
|
||||
let body_preview = preview_remote_control_response_body(&body);
|
||||
if !status.is_success() {
|
||||
let headers_str = format_headers(&headers);
|
||||
let error_kind = if matches!(status.as_u16(), 401 | 403) {
|
||||
ErrorKind::PermissionDenied
|
||||
} else {
|
||||
ErrorKind::Other
|
||||
};
|
||||
return Err(io::Error::new(
|
||||
error_kind,
|
||||
format!(
|
||||
"remote control server enrollment failed at `{enroll_url}`: HTTP {status}, {headers_str}, body: {body_preview}"
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let enrollment = serde_json::from_slice::<EnrollRemoteServerResponse>(&body).map_err(|err| {
|
||||
let headers_str = format_headers(&headers);
|
||||
io::Error::other(format!(
|
||||
"failed to parse remote control enrollment response from `{enroll_url}`: HTTP {status}, {headers_str}, body: {body_preview}, decode error: {err}"
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(RemoteControlEnrollment {
|
||||
account_id: account_id.map(&str::to_string),
|
||||
environment_id: enrollment.environment_id,
|
||||
server_id: enrollment.server_id,
|
||||
server_name,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::transport::remote_control::protocol::normalize_remote_control_url;
|
||||
use codex_state::StateRuntime;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
use tempfile::TempDir;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
|
||||
async fn remote_control_state_runtime(codex_home: &TempDir) -> Arc<StateRuntime> {
|
||||
StateRuntime::init(codex_home.path().to_path_buf(), "test-provider".to_string())
|
||||
.await
|
||||
.expect("state runtime should initialize")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn persisted_remote_control_enrollment_round_trips_by_target_and_account() {
|
||||
let codex_home = TempDir::new().expect("temp dir should create");
|
||||
let state_db = remote_control_state_runtime(&codex_home).await;
|
||||
let first_target = normalize_remote_control_url("https://chatgpt.com/remote/control")
|
||||
.expect("first target should parse");
|
||||
let second_target =
|
||||
normalize_remote_control_url("https://api.chatgpt-staging.com/other/control")
|
||||
.expect("second target should parse");
|
||||
let first_enrollment = RemoteControlEnrollment {
|
||||
account_id: Some("account-a".to_string()),
|
||||
environment_id: "env_first".to_string(),
|
||||
server_id: "srv_e_first".to_string(),
|
||||
server_name: "first-server".to_string(),
|
||||
};
|
||||
let second_enrollment = RemoteControlEnrollment {
|
||||
account_id: Some("account-a".to_string()),
|
||||
environment_id: "env_second".to_string(),
|
||||
server_id: "srv_e_second".to_string(),
|
||||
server_name: "second-server".to_string(),
|
||||
};
|
||||
|
||||
update_persisted_remote_control_enrollment(
|
||||
Some(state_db.as_ref()),
|
||||
&first_target,
|
||||
Some("account-a"),
|
||||
Some(&first_enrollment),
|
||||
)
|
||||
.await
|
||||
.expect("first enrollment should persist");
|
||||
update_persisted_remote_control_enrollment(
|
||||
Some(state_db.as_ref()),
|
||||
&second_target,
|
||||
Some("account-a"),
|
||||
Some(&second_enrollment),
|
||||
)
|
||||
.await
|
||||
.expect("second enrollment should persist");
|
||||
|
||||
assert_eq!(
|
||||
load_persisted_remote_control_enrollment(
|
||||
Some(state_db.as_ref()),
|
||||
&first_target,
|
||||
Some("account-a"),
|
||||
)
|
||||
.await,
|
||||
Some(first_enrollment.clone())
|
||||
);
|
||||
assert_eq!(
|
||||
load_persisted_remote_control_enrollment(
|
||||
Some(state_db.as_ref()),
|
||||
&first_target,
|
||||
Some("account-b"),
|
||||
)
|
||||
.await,
|
||||
None
|
||||
);
|
||||
assert_eq!(
|
||||
load_persisted_remote_control_enrollment(
|
||||
Some(state_db.as_ref()),
|
||||
&second_target,
|
||||
Some("account-a"),
|
||||
)
|
||||
.await,
|
||||
Some(second_enrollment)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn clearing_persisted_remote_control_enrollment_removes_only_matching_entry() {
|
||||
let codex_home = TempDir::new().expect("temp dir should create");
|
||||
let state_db = remote_control_state_runtime(&codex_home).await;
|
||||
let first_target = normalize_remote_control_url("https://chatgpt.com/remote/control")
|
||||
.expect("first target should parse");
|
||||
let second_target =
|
||||
normalize_remote_control_url("https://api.chatgpt-staging.com/other/control")
|
||||
.expect("second target should parse");
|
||||
let first_enrollment = RemoteControlEnrollment {
|
||||
account_id: Some("account-a".to_string()),
|
||||
environment_id: "env_first".to_string(),
|
||||
server_id: "srv_e_first".to_string(),
|
||||
server_name: "first-server".to_string(),
|
||||
};
|
||||
let second_enrollment = RemoteControlEnrollment {
|
||||
account_id: Some("account-a".to_string()),
|
||||
environment_id: "env_second".to_string(),
|
||||
server_id: "srv_e_second".to_string(),
|
||||
server_name: "second-server".to_string(),
|
||||
};
|
||||
|
||||
update_persisted_remote_control_enrollment(
|
||||
Some(state_db.as_ref()),
|
||||
&first_target,
|
||||
Some("account-a"),
|
||||
Some(&first_enrollment),
|
||||
)
|
||||
.await
|
||||
.expect("first enrollment should persist");
|
||||
update_persisted_remote_control_enrollment(
|
||||
Some(state_db.as_ref()),
|
||||
&second_target,
|
||||
Some("account-a"),
|
||||
Some(&second_enrollment),
|
||||
)
|
||||
.await
|
||||
.expect("second enrollment should persist");
|
||||
|
||||
update_persisted_remote_control_enrollment(
|
||||
Some(state_db.as_ref()),
|
||||
&first_target,
|
||||
Some("account-a"),
|
||||
/*enrollment*/ None,
|
||||
)
|
||||
.await
|
||||
.expect("matching enrollment should clear");
|
||||
|
||||
assert_eq!(
|
||||
load_persisted_remote_control_enrollment(
|
||||
Some(state_db.as_ref()),
|
||||
&first_target,
|
||||
Some("account-a"),
|
||||
)
|
||||
.await,
|
||||
None
|
||||
);
|
||||
assert_eq!(
|
||||
load_persisted_remote_control_enrollment(
|
||||
Some(state_db.as_ref()),
|
||||
&second_target,
|
||||
Some("account-a"),
|
||||
)
|
||||
.await,
|
||||
Some(second_enrollment)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enroll_remote_control_server_parse_failure_includes_response_body() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("listener should bind");
|
||||
let remote_control_url = format!(
|
||||
"http://127.0.0.1:{}/backend-api/",
|
||||
listener
|
||||
.local_addr()
|
||||
.expect("listener should have a local addr")
|
||||
.port()
|
||||
);
|
||||
let remote_control_target =
|
||||
normalize_remote_control_url(&remote_control_url).expect("target should parse");
|
||||
let enroll_url = remote_control_target.enroll_url.clone();
|
||||
let response_body = json!({
|
||||
"error": "not enrolled",
|
||||
});
|
||||
let expected_body = response_body.to_string();
|
||||
let server_task = tokio::spawn(async move {
|
||||
let stream = accept_http_request(&listener).await;
|
||||
respond_with_json(stream, response_body).await;
|
||||
});
|
||||
|
||||
let err = enroll_remote_control_server(
|
||||
&remote_control_target,
|
||||
&RemoteControlConnectionAuth {
|
||||
bearer_token: "Access Token".to_string(),
|
||||
account_id: Some("account_id".to_string()),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect_err("invalid response should fail to parse");
|
||||
|
||||
server_task.await.expect("server task should succeed");
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
format!(
|
||||
"failed to parse remote control enrollment response from `{enroll_url}`: HTTP 200 OK, request-id: <none>, cf-ray: <none>, body: {expected_body}, decode error: missing field `server_id` at line 1 column {}",
|
||||
expected_body.len()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
async fn accept_http_request(listener: &TcpListener) -> TcpStream {
|
||||
let (stream, _) = timeout(Duration::from_secs(5), listener.accept())
|
||||
.await
|
||||
.expect("HTTP request should arrive in time")
|
||||
.expect("listener accept should succeed");
|
||||
let mut reader = BufReader::new(stream);
|
||||
|
||||
let mut request_line = String::new();
|
||||
reader
|
||||
.read_line(&mut request_line)
|
||||
.await
|
||||
.expect("request line should read");
|
||||
loop {
|
||||
let mut line = String::new();
|
||||
reader
|
||||
.read_line(&mut line)
|
||||
.await
|
||||
.expect("header line should read");
|
||||
if line == "\r\n" {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
reader.into_inner()
|
||||
}
|
||||
|
||||
async fn respond_with_json(mut stream: TcpStream, body: serde_json::Value) {
|
||||
let body = body.to_string();
|
||||
let response = format!(
|
||||
"HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
|
||||
body.len()
|
||||
);
|
||||
stream
|
||||
.write_all(response.as_bytes())
|
||||
.await
|
||||
.expect("response should write");
|
||||
stream.flush().await.expect("response should flush");
|
||||
}
|
||||
}
|
||||
62
codex-rs/app-server/src/transport/remote_control/mod.rs
Normal file
62
codex-rs/app-server/src/transport/remote_control/mod.rs
Normal file
@@ -0,0 +1,62 @@
|
||||
mod client_tracker;
|
||||
mod enroll;
|
||||
mod protocol;
|
||||
mod websocket;
|
||||
|
||||
use crate::transport::remote_control::websocket::RemoteControlWebsocket;
|
||||
use crate::transport::remote_control::websocket::load_remote_control_auth;
|
||||
|
||||
pub use self::protocol::ClientId;
|
||||
use self::protocol::ServerEvent;
|
||||
use self::protocol::StreamId;
|
||||
use self::protocol::normalize_remote_control_url;
|
||||
use super::CHANNEL_CAPACITY;
|
||||
use super::TransportEvent;
|
||||
use super::next_connection_id;
|
||||
use codex_login::AuthManager;
|
||||
use codex_state::StateRuntime;
|
||||
use std::io;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
pub(super) struct QueuedServerEnvelope {
|
||||
pub(super) event: ServerEvent,
|
||||
pub(super) client_id: ClientId,
|
||||
pub(super) stream_id: StreamId,
|
||||
pub(super) write_complete_tx: Option<oneshot::Sender<()>>,
|
||||
}
|
||||
|
||||
pub(crate) async fn start_remote_control(
|
||||
remote_control_url: String,
|
||||
state_db: Option<Arc<StateRuntime>>,
|
||||
auth_manager: Arc<AuthManager>,
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
shutdown_token: CancellationToken,
|
||||
) -> io::Result<JoinHandle<()>> {
|
||||
let remote_control_target = normalize_remote_control_url(&remote_control_url)?;
|
||||
validate_remote_control_auth(&auth_manager).await?;
|
||||
|
||||
Ok(tokio::spawn(async move {
|
||||
RemoteControlWebsocket::new(
|
||||
remote_control_target,
|
||||
state_db,
|
||||
auth_manager,
|
||||
transport_event_tx,
|
||||
shutdown_token,
|
||||
)
|
||||
.run()
|
||||
.await;
|
||||
}))
|
||||
}
|
||||
|
||||
pub(crate) async fn validate_remote_control_auth(
|
||||
auth_manager: &Arc<AuthManager>,
|
||||
) -> io::Result<()> {
|
||||
load_remote_control_auth(auth_manager).await.map(|_| ())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
252
codex-rs/app-server/src/transport/remote_control/protocol.rs
Normal file
252
codex-rs/app-server/src/transport/remote_control/protocol.rs
Normal file
@@ -0,0 +1,252 @@
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::io;
|
||||
use std::io::ErrorKind;
|
||||
use url::Host;
|
||||
use url::Url;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(super) struct RemoteControlTarget {
|
||||
pub(super) websocket_url: String,
|
||||
pub(super) enroll_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub(super) struct EnrollRemoteServerRequest {
|
||||
pub(super) name: String,
|
||||
pub(super) os: &'static str,
|
||||
pub(super) arch: &'static str,
|
||||
pub(super) app_server_version: &'static str,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(super) struct EnrollRemoteServerResponse {
|
||||
pub(super) server_id: String,
|
||||
pub(super) environment_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct ClientId(pub String);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct StreamId(pub String);
|
||||
|
||||
impl StreamId {
|
||||
pub fn new_random() -> Self {
|
||||
Self(uuid::Uuid::now_v7().to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ClientEvent {
|
||||
ClientMessage {
|
||||
message: JSONRPCMessage,
|
||||
},
|
||||
/// Backend-generated acknowledgement for all server envelopes addressed to
|
||||
/// `client_id` whose envelope `seq_id` is less than or equal to this ack's
|
||||
/// `seq_id`. This cursor is client-scoped, not stream-scoped, so receivers
|
||||
/// must not use `stream_id` to partition acks.
|
||||
Ack,
|
||||
Ping,
|
||||
ClientClosed,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub(crate) struct ClientEnvelope {
|
||||
#[serde(flatten)]
|
||||
pub(crate) event: ClientEvent,
|
||||
#[serde(rename = "client_id")]
|
||||
pub(crate) client_id: ClientId,
|
||||
#[serde(rename = "stream_id", skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) stream_id: Option<StreamId>,
|
||||
/// For `Ack`, this is the backend-generated per-client cursor over
|
||||
/// `ServerEnvelope.seq_id`.
|
||||
#[serde(rename = "seq_id", skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) seq_id: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) cursor: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PongStatus {
|
||||
Active,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ServerEvent {
|
||||
ServerMessage {
|
||||
message: Box<OutgoingMessage>,
|
||||
},
|
||||
#[allow(dead_code)]
|
||||
Ack,
|
||||
Pong {
|
||||
status: PongStatus,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub(crate) struct ServerEnvelope {
|
||||
#[serde(flatten)]
|
||||
pub(crate) event: ServerEvent,
|
||||
#[serde(rename = "client_id")]
|
||||
pub(crate) client_id: ClientId,
|
||||
#[serde(rename = "stream_id")]
|
||||
pub(crate) stream_id: StreamId,
|
||||
#[serde(rename = "seq_id")]
|
||||
pub(crate) seq_id: u64,
|
||||
}
|
||||
|
||||
fn is_allowed_chatgpt_host(host: &Option<Host<&str>>) -> bool {
|
||||
let Some(Host::Domain(host)) = *host else {
|
||||
return false;
|
||||
};
|
||||
host == "chatgpt.com"
|
||||
|| host == "chatgpt-staging.com"
|
||||
|| host.ends_with(".chatgpt.com")
|
||||
|| host.ends_with(".chatgpt-staging.com")
|
||||
}
|
||||
|
||||
fn is_localhost(host: &Option<Host<&str>>) -> bool {
|
||||
match host {
|
||||
Some(Host::Domain("localhost")) => true,
|
||||
Some(Host::Ipv4(ip)) => ip.is_loopback(),
|
||||
Some(Host::Ipv6(ip)) => ip.is_loopback(),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn normalize_remote_control_url(
|
||||
remote_control_url: &str,
|
||||
) -> io::Result<RemoteControlTarget> {
|
||||
let map_url_parse_error = |err: url::ParseError| -> io::Error {
|
||||
io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("invalid remote control URL `{remote_control_url}`: {err}"),
|
||||
)
|
||||
};
|
||||
let map_scheme_error = |_: ()| -> io::Error {
|
||||
io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!(
|
||||
"invalid remote control URL `{remote_control_url}`; expected HTTPS URL for chatgpt.com or chatgpt-staging.com, or HTTP/HTTPS URL for localhost"
|
||||
),
|
||||
)
|
||||
};
|
||||
|
||||
let mut remote_control_url = Url::parse(remote_control_url).map_err(map_url_parse_error)?;
|
||||
if !remote_control_url.path().ends_with('/') {
|
||||
let normalized_path = format!("{}/", remote_control_url.path());
|
||||
remote_control_url.set_path(&normalized_path);
|
||||
}
|
||||
|
||||
let enroll_url = remote_control_url
|
||||
.join("wham/remote/control/server/enroll")
|
||||
.map_err(map_url_parse_error)?;
|
||||
let mut websocket_url = remote_control_url
|
||||
.join("wham/remote/control/server")
|
||||
.map_err(map_url_parse_error)?;
|
||||
let host = enroll_url.host();
|
||||
match enroll_url.scheme() {
|
||||
"https" if is_localhost(&host) || is_allowed_chatgpt_host(&host) => {
|
||||
websocket_url.set_scheme("wss").map_err(map_scheme_error)?;
|
||||
}
|
||||
"http" if is_localhost(&host) => {
|
||||
websocket_url.set_scheme("ws").map_err(map_scheme_error)?;
|
||||
}
|
||||
_ => return Err(map_scheme_error(())),
|
||||
}
|
||||
|
||||
Ok(RemoteControlTarget {
|
||||
websocket_url: websocket_url.to_string(),
|
||||
enroll_url: enroll_url.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn normalize_remote_control_url_accepts_chatgpt_https_urls() {
|
||||
assert_eq!(
|
||||
normalize_remote_control_url("https://chatgpt.com/backend-api")
|
||||
.expect("chatgpt.com URL should normalize"),
|
||||
RemoteControlTarget {
|
||||
websocket_url: "wss://chatgpt.com/backend-api/wham/remote/control/server"
|
||||
.to_string(),
|
||||
enroll_url: "https://chatgpt.com/backend-api/wham/remote/control/server/enroll"
|
||||
.to_string(),
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
normalize_remote_control_url("https://api.chatgpt-staging.com/backend-api")
|
||||
.expect("chatgpt-staging.com subdomain URL should normalize"),
|
||||
RemoteControlTarget {
|
||||
websocket_url:
|
||||
"wss://api.chatgpt-staging.com/backend-api/wham/remote/control/server"
|
||||
.to_string(),
|
||||
enroll_url:
|
||||
"https://api.chatgpt-staging.com/backend-api/wham/remote/control/server/enroll"
|
||||
.to_string(),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_remote_control_url_accepts_localhost_urls() {
|
||||
assert_eq!(
|
||||
normalize_remote_control_url("http://localhost:8080/backend-api")
|
||||
.expect("localhost http URL should normalize"),
|
||||
RemoteControlTarget {
|
||||
websocket_url: "ws://localhost:8080/backend-api/wham/remote/control/server"
|
||||
.to_string(),
|
||||
enroll_url: "http://localhost:8080/backend-api/wham/remote/control/server/enroll"
|
||||
.to_string(),
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
normalize_remote_control_url("https://localhost:8443/backend-api")
|
||||
.expect("localhost https URL should normalize"),
|
||||
RemoteControlTarget {
|
||||
websocket_url: "wss://localhost:8443/backend-api/wham/remote/control/server"
|
||||
.to_string(),
|
||||
enroll_url: "https://localhost:8443/backend-api/wham/remote/control/server/enroll"
|
||||
.to_string(),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalize_remote_control_url_rejects_unsupported_urls() {
|
||||
for remote_control_url in [
|
||||
"http://chatgpt.com/backend-api",
|
||||
"http://example.com/backend-api",
|
||||
"https://example.com/backend-api",
|
||||
"https://chatgpt.com.evil.com/backend-api",
|
||||
"https://evilchatgpt.com/backend-api",
|
||||
"https://foo.localhost/backend-api",
|
||||
] {
|
||||
let err = normalize_remote_control_url(remote_control_url)
|
||||
.expect_err("unsupported URL should be rejected");
|
||||
|
||||
assert_eq!(err.kind(), ErrorKind::InvalidInput);
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
format!(
|
||||
"invalid remote control URL `{remote_control_url}`; expected HTTPS URL for chatgpt.com or chatgpt-staging.com, or HTTP/HTTPS URL for localhost"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
1123
codex-rs/app-server/src/transport/remote_control/tests.rs
Normal file
1123
codex-rs/app-server/src/transport/remote_control/tests.rs
Normal file
File diff suppressed because it is too large
Load Diff
1344
codex-rs/app-server/src/transport/remote_control/websocket.rs
Normal file
1344
codex-rs/app-server/src/transport/remote_control/websocket.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,8 @@
|
||||
use super::CHANNEL_CAPACITY;
|
||||
use super::TransportEvent;
|
||||
use super::forward_incoming_message;
|
||||
use super::next_connection_id;
|
||||
use super::serialize_outgoing_message;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::QueuedOutgoingMessage;
|
||||
use std::io::ErrorKind;
|
||||
use std::io::Result as IoResult;
|
||||
@@ -20,7 +20,7 @@ pub(crate) async fn start_stdio_connection(
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
stdio_handles: &mut Vec<JoinHandle<()>>,
|
||||
) -> IoResult<()> {
|
||||
let connection_id = ConnectionId(0);
|
||||
let connection_id = next_connection_id();
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel::<QueuedOutgoingMessage>(CHANNEL_CAPACITY);
|
||||
let writer_tx_for_reader = writer_tx.clone();
|
||||
transport_event_tx
|
||||
|
||||
@@ -4,6 +4,7 @@ use super::auth::WebsocketAuthPolicy;
|
||||
use super::auth::authorize_upgrade;
|
||||
use super::auth::should_warn_about_unauthenticated_non_loopback_listener;
|
||||
use super::forward_incoming_message;
|
||||
use super::next_connection_id;
|
||||
use super::serialize_outgoing_message;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::QueuedOutgoingMessage;
|
||||
@@ -32,8 +33,6 @@ use owo_colors::Style;
|
||||
use std::io::Result as IoResult;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task::JoinHandle;
|
||||
@@ -75,7 +74,6 @@ fn print_websocket_startup_banner(addr: SocketAddr) {
|
||||
#[derive(Clone)]
|
||||
struct WebSocketListenerState {
|
||||
transport_event_tx: mpsc::Sender<TransportEvent>,
|
||||
connection_counter: Arc<AtomicU64>,
|
||||
auth_policy: Arc<WebsocketAuthPolicy>,
|
||||
}
|
||||
|
||||
@@ -113,7 +111,7 @@ async fn websocket_upgrade_handler(
|
||||
);
|
||||
return (err.status_code(), err.message()).into_response();
|
||||
}
|
||||
let connection_id = ConnectionId(state.connection_counter.fetch_add(1, Ordering::Relaxed));
|
||||
let connection_id = next_connection_id();
|
||||
info!(%peer_addr, "websocket client connected");
|
||||
websocket
|
||||
.on_upgrade(move |stream| async move {
|
||||
@@ -146,7 +144,6 @@ pub(crate) async fn start_websocket_acceptor(
|
||||
.layer(middleware::from_fn(reject_requests_with_origin_header))
|
||||
.with_state(WebSocketListenerState {
|
||||
transport_event_tx,
|
||||
connection_counter: Arc::new(AtomicU64::new(1)),
|
||||
auth_policy: Arc::new(auth_policy),
|
||||
});
|
||||
let server = axum::serve(
|
||||
|
||||
@@ -346,7 +346,7 @@ struct AppServerCommand {
|
||||
subcommand: Option<AppServerSubcommand>,
|
||||
|
||||
/// Transport endpoint URL. Supported values: `stdio://` (default),
|
||||
/// `ws://IP:PORT`.
|
||||
/// `ws://IP:PORT`, `off`.
|
||||
#[arg(
|
||||
long = "listen",
|
||||
value_name = "URL",
|
||||
@@ -1993,6 +1993,12 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_server_listen_off_parses() {
|
||||
let app_server = app_server_from_args(["codex", "app-server", "--listen", "off"].as_ref());
|
||||
assert_eq!(app_server.listen, codex_app_server::AppServerTransport::Off);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_server_listen_invalid_url_fails_to_parse() {
|
||||
let parse_result =
|
||||
|
||||
@@ -437,6 +437,9 @@
|
||||
"realtime_conversation": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"remote_control": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"remote_models": {
|
||||
"type": "boolean"
|
||||
},
|
||||
@@ -2144,6 +2147,9 @@
|
||||
"realtime_conversation": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"remote_control": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"remote_models": {
|
||||
"type": "boolean"
|
||||
},
|
||||
|
||||
@@ -6,11 +6,16 @@ use crate::agent::next_thread_spawn_depth;
|
||||
use crate::agent::role::DEFAULT_ROLE_NAME;
|
||||
use crate::agent::role::apply_role_to_config;
|
||||
use codex_protocol::AgentPath;
|
||||
use codex_protocol::models::DeveloperInstructions;
|
||||
use codex_protocol::protocol::InterAgentCommunication;
|
||||
use codex_protocol::protocol::Op;
|
||||
|
||||
pub(crate) struct Handler;
|
||||
|
||||
pub(crate) const SPAWN_AGENT_DEVELOPER_INSTRUCTIONS: &str = r#"<spawned_agent_context>
|
||||
You are a newly spawned agent in a team of agents collaborating to complete a task. You can spawn sub-agents to handle subtasks, and those sub-agents can spawn their own sub-agents. You are responsible for returning the response to your assigned task in the final channel. When you give your response, the contents of your response in the final channel will be immediately delivered back to your parent agent. The prior conversation history was forked from your parent agent. Treat the next user message as your assigned task, and use the forked history only as background context.
|
||||
</spawned_agent_context>"#;
|
||||
|
||||
impl ToolHandler for Handler {
|
||||
type Output = SpawnAgentResult;
|
||||
|
||||
@@ -78,6 +83,17 @@ impl ToolHandler for Handler {
|
||||
.map_err(FunctionCallError::RespondToModel)?;
|
||||
apply_spawn_agent_runtime_overrides(&mut config, turn.as_ref())?;
|
||||
apply_spawn_agent_overrides(&mut config, child_depth);
|
||||
config.developer_instructions = Some(
|
||||
if let Some(existing_instructions) = config.developer_instructions.take() {
|
||||
DeveloperInstructions::new(existing_instructions)
|
||||
.concat(DeveloperInstructions::new(
|
||||
SPAWN_AGENT_DEVELOPER_INSTRUCTIONS,
|
||||
))
|
||||
.into_text()
|
||||
} else {
|
||||
DeveloperInstructions::new(SPAWN_AGENT_DEVELOPER_INSTRUCTIONS).into_text()
|
||||
},
|
||||
);
|
||||
|
||||
let spawn_source = thread_spawn_source(
|
||||
session.conversation_id,
|
||||
|
||||
@@ -195,23 +195,28 @@ mv tokens.next tokens.txt
|
||||
|
||||
#[cfg(windows)]
|
||||
let (command, args) = {
|
||||
let script_path = tempdir.path().join("print-token.ps1");
|
||||
let script_path = tempdir.path().join("print-token.cmd");
|
||||
std::fs::write(
|
||||
&script_path,
|
||||
r#"$lines = @(Get-Content -Path tokens.txt)
|
||||
if ($lines.Count -eq 0) { exit 1 }
|
||||
Write-Output $lines[0]
|
||||
$lines | Select-Object -Skip 1 | Set-Content -Path tokens.txt
|
||||
r#"@echo off
|
||||
setlocal EnableExtensions DisableDelayedExpansion
|
||||
|
||||
set "first_line="
|
||||
<tokens.txt set /p first_line=
|
||||
if not defined first_line exit /b 1
|
||||
|
||||
echo(%first_line%
|
||||
more +1 tokens.txt > tokens.next
|
||||
move /y tokens.next tokens.txt >nul
|
||||
"#,
|
||||
)?;
|
||||
(
|
||||
"powershell.exe".to_string(),
|
||||
"cmd.exe".to_string(),
|
||||
vec![
|
||||
"-NoProfile".to_string(),
|
||||
"-ExecutionPolicy".to_string(),
|
||||
"Bypass".to_string(),
|
||||
"-File".to_string(),
|
||||
".\\print-token.ps1".to_string(),
|
||||
"/D".to_string(),
|
||||
"/Q".to_string(),
|
||||
"/C".to_string(),
|
||||
".\\print-token.cmd".to_string(),
|
||||
],
|
||||
)
|
||||
};
|
||||
@@ -227,7 +232,8 @@ $lines | Select-Object -Skip 1 | Set-Content -Path tokens.txt
|
||||
ModelProviderAuthInfo {
|
||||
command: self.command.clone(),
|
||||
args: self.args.clone(),
|
||||
timeout_ms: non_zero_u64(/*value*/ 1_000),
|
||||
// Match the provider-auth default to avoid brittle shell-startup timing in CI.
|
||||
timeout_ms: non_zero_u64(/*value*/ 5_000),
|
||||
refresh_interval_ms: 60_000,
|
||||
cwd: match codex_utils_absolute_path::AbsolutePathBuf::try_from(self.tempdir.path()) {
|
||||
Ok(cwd) => cwd,
|
||||
|
||||
@@ -35,6 +35,7 @@ const REQUESTED_MODEL: &str = "gpt-5.1";
|
||||
const REQUESTED_REASONING_EFFORT: ReasoningEffort = ReasoningEffort::Low;
|
||||
const ROLE_MODEL: &str = "gpt-5.1-codex-max";
|
||||
const ROLE_REASONING_EFFORT: ReasoningEffort = ReasoningEffort::High;
|
||||
const SPAWNED_AGENT_DEVELOPER_INSTRUCTIONS: &str = "You are a newly spawned agent in a team of agents collaborating to complete a task. You can spawn sub-agents to handle subtasks, and those sub-agents can spawn their own sub-agents. You are responsible for returning the response to your assigned task in the final channel. When you give your response, the contents of your response in the final channel will be immediately delivered back to your parent agent. The prior conversation history was forked from your parent agent. Treat the next user message as your assigned task, and use the forked history only as background context.";
|
||||
|
||||
fn body_contains(req: &wiremock::Request, text: &str) -> bool {
|
||||
let is_zstd = req
|
||||
@@ -413,6 +414,99 @@ async fn spawn_agent_requested_model_and_reasoning_override_inherited_settings_w
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn spawned_multi_agent_v2_child_receives_xml_tagged_developer_context() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let spawn_args = serde_json::to_string(&json!({
|
||||
"message": CHILD_PROMPT,
|
||||
"task_name": "worker",
|
||||
}))?;
|
||||
mount_sse_once_match(
|
||||
&server,
|
||||
|req: &wiremock::Request| body_contains(req, TURN_1_PROMPT),
|
||||
sse(vec![
|
||||
ev_response_created("resp-turn1-1"),
|
||||
ev_function_call(SPAWN_CALL_ID, "spawn_agent", &spawn_args),
|
||||
ev_completed("resp-turn1-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
let _child_request_log = mount_sse_once_match(
|
||||
&server,
|
||||
|req: &wiremock::Request| {
|
||||
body_contains(req, CHILD_PROMPT) && !body_contains(req, SPAWN_CALL_ID)
|
||||
},
|
||||
sse(vec![
|
||||
ev_response_created("resp-child-1"),
|
||||
ev_completed("resp-child-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
let _turn1_followup = mount_sse_once_match(
|
||||
&server,
|
||||
|req: &wiremock::Request| body_contains(req, SPAWN_CALL_ID),
|
||||
sse(vec![
|
||||
ev_response_created("resp-turn1-2"),
|
||||
ev_assistant_message("msg-turn1-2", "parent done"),
|
||||
ev_completed("resp-turn1-2"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config
|
||||
.features
|
||||
.enable(Feature::Collab)
|
||||
.expect("test config should allow feature update");
|
||||
config
|
||||
.features
|
||||
.enable(Feature::MultiAgentV2)
|
||||
.expect("test config should allow feature update");
|
||||
config.developer_instructions = Some("Parent developer instructions.".to_string());
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
test.submit_turn(TURN_1_PROMPT).await?;
|
||||
|
||||
let deadline = Instant::now() + Duration::from_secs(2);
|
||||
let child_request = loop {
|
||||
if let Some(request) = server
|
||||
.received_requests()
|
||||
.await
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.find(|request| {
|
||||
body_contains(request, CHILD_PROMPT)
|
||||
&& body_contains(request, "<spawned_agent_context>")
|
||||
&& body_contains(request, SPAWNED_AGENT_DEVELOPER_INSTRUCTIONS)
|
||||
&& !body_contains(request, SPAWN_CALL_ID)
|
||||
})
|
||||
{
|
||||
break request;
|
||||
}
|
||||
if Instant::now() >= deadline {
|
||||
anyhow::bail!("timed out waiting for spawned child request with developer context");
|
||||
}
|
||||
sleep(Duration::from_millis(10)).await;
|
||||
};
|
||||
assert!(body_contains(
|
||||
&child_request,
|
||||
"Parent developer instructions."
|
||||
));
|
||||
assert!(body_contains(&child_request, "<spawned_agent_context>"));
|
||||
assert!(body_contains(
|
||||
&child_request,
|
||||
SPAWNED_AGENT_DEVELOPER_INSTRUCTIONS
|
||||
));
|
||||
assert!(body_contains(&child_request, CHILD_PROMPT));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn spawn_agent_role_overrides_requested_model_and_reasoning_settings() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
@@ -176,6 +176,8 @@ pub enum Feature {
|
||||
FastMode,
|
||||
/// Enable experimental realtime voice conversation mode in the TUI.
|
||||
RealtimeConversation,
|
||||
/// Connect app-server to the ChatGPT remote control service.
|
||||
RemoteControl,
|
||||
/// Removed compatibility flag. The TUI now always uses the app-server implementation.
|
||||
TuiAppServer,
|
||||
/// Prevent idle system sleep while a turn is actively running.
|
||||
@@ -825,6 +827,12 @@ pub const FEATURES: &[FeatureSpec] = &[
|
||||
stage: Stage::UnderDevelopment,
|
||||
default_enabled: false,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::RemoteControl,
|
||||
key: "remote_control",
|
||||
stage: Stage::UnderDevelopment,
|
||||
default_enabled: false,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::TuiAppServer,
|
||||
key: "tui_app_server",
|
||||
|
||||
@@ -165,6 +165,12 @@ fn image_detail_original_feature_is_under_development() {
|
||||
assert_eq!(Feature::ImageDetailOriginal.default_enabled(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remote_control_is_under_development() {
|
||||
assert_eq!(Feature::RemoteControl.stage(), Stage::UnderDevelopment);
|
||||
assert_eq!(Feature::RemoteControl.default_enabled(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn collab_is_legacy_alias_for_multi_agent() {
|
||||
assert_eq!(feature_for_key("multi_agent"), Some(Feature::Collab));
|
||||
|
||||
@@ -1404,21 +1404,6 @@ impl AuthManager {
|
||||
))
|
||||
}
|
||||
|
||||
pub fn shared_with_external_auth(
|
||||
codex_home: PathBuf,
|
||||
enable_codex_api_key_env: bool,
|
||||
auth_credentials_store_mode: AuthCredentialsStoreMode,
|
||||
external_auth: Arc<dyn ExternalAuth>,
|
||||
) -> Arc<Self> {
|
||||
let manager = Self::shared(
|
||||
codex_home,
|
||||
enable_codex_api_key_env,
|
||||
auth_credentials_store_mode,
|
||||
);
|
||||
manager.set_external_auth(external_auth);
|
||||
manager
|
||||
}
|
||||
|
||||
pub fn unauthorized_recovery(self: &Arc<Self>) -> UnauthorizedRecovery {
|
||||
UnauthorizedRecovery::new(Arc::clone(self))
|
||||
}
|
||||
|
||||
@@ -26,3 +26,6 @@ serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
tiny_http = { workspace = true }
|
||||
zeroize = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = { workspace = true }
|
||||
|
||||
@@ -35,18 +35,20 @@ curl --fail --silent --show-error "${PROXY_BASE_URL}/shutdown"
|
||||
- Listens on the provided port or an ephemeral port if `--port` is not specified.
|
||||
- Accepts exactly `POST /v1/responses` (no query string). The request body is forwarded to `https://api.openai.com/v1/responses` with `Authorization: Bearer <key>` set. All original request headers (except any incoming `Authorization`) are forwarded upstream, with `Host` overridden to `api.openai.com`. For other requests, it responds with `403`.
|
||||
- Optionally writes a single-line JSON file with server info, currently `{ "port": <u16>, "pid": <u32> }`.
|
||||
- Optionally writes request/response JSON dumps to a directory. Each accepted request gets a pair of files that share a sequence/timestamp prefix, for example `000001-1846179912345-request.json` and `000001-1846179912345-response.json`. Header values are dumped in full except `Authorization` and any header whose name includes `cookie`, which are redacted. Bodies are written as parsed JSON when possible, otherwise as UTF-8 text.
|
||||
- Optional `--http-shutdown` enables `GET /shutdown` to terminate the process with exit code `0`. This allows one user (e.g., `root`) to start the proxy and another unprivileged user on the host to shut it down.
|
||||
|
||||
## CLI
|
||||
|
||||
```
|
||||
codex-responses-api-proxy [--port <PORT>] [--server-info <FILE>] [--http-shutdown] [--upstream-url <URL>]
|
||||
codex-responses-api-proxy [--port <PORT>] [--server-info <FILE>] [--http-shutdown] [--upstream-url <URL>] [--dump-dir <DIR>]
|
||||
```
|
||||
|
||||
- `--port <PORT>`: Port to bind on `127.0.0.1`. If omitted, an ephemeral port is chosen.
|
||||
- `--server-info <FILE>`: If set, the proxy writes a single line of JSON with `{ "port": <PORT>, "pid": <PID> }` once listening.
|
||||
- `--http-shutdown`: If set, enables `GET /shutdown` to exit the process with code `0`.
|
||||
- `--upstream-url <URL>`: Absolute URL to forward requests to. Defaults to `https://api.openai.com/v1/responses`.
|
||||
- `--dump-dir <DIR>`: If set, writes one request JSON file and one response JSON file per accepted proxy call under this directory. Filenames use a shared sequence/timestamp prefix so each pair is easy to correlate.
|
||||
- Authentication is fixed to `Authorization: Bearer <key>` to match the Codex CLI expectations.
|
||||
|
||||
For Azure, for example (ensure your deployment accepts `Authorization: Bearer <key>`):
|
||||
|
||||
360
codex-rs/responses-api-proxy/src/dump.rs
Normal file
360
codex-rs/responses-api-proxy/src/dump.rs
Normal file
@@ -0,0 +1,360 @@
|
||||
use std::fs;
|
||||
use std::io;
|
||||
use std::io::Read;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::SystemTime;
|
||||
use std::time::UNIX_EPOCH;
|
||||
|
||||
use reqwest::header::HeaderMap;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use tiny_http::Header;
|
||||
use tiny_http::Method;
|
||||
|
||||
const AUTHORIZATION_HEADER_NAME: &str = "authorization";
|
||||
const REDACTED_HEADER_VALUE: &str = "[REDACTED]";
|
||||
|
||||
pub(crate) struct ExchangeDumper {
|
||||
dump_dir: PathBuf,
|
||||
next_sequence: AtomicU64,
|
||||
}
|
||||
|
||||
impl ExchangeDumper {
|
||||
pub(crate) fn new(dump_dir: PathBuf) -> io::Result<Self> {
|
||||
fs::create_dir_all(&dump_dir)?;
|
||||
|
||||
Ok(Self {
|
||||
dump_dir,
|
||||
next_sequence: AtomicU64::new(1),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn dump_request(
|
||||
&self,
|
||||
method: &Method,
|
||||
url: &str,
|
||||
headers: &[Header],
|
||||
body: &[u8],
|
||||
) -> io::Result<ExchangeDump> {
|
||||
let sequence = self.next_sequence.fetch_add(1, Ordering::Relaxed);
|
||||
let timestamp_ms = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.map_or(0, |duration| duration.as_millis());
|
||||
let prefix = format!("{sequence:06}-{timestamp_ms}");
|
||||
|
||||
let request_path = self.dump_dir.join(format!("{prefix}-request.json"));
|
||||
let response_path = self.dump_dir.join(format!("{prefix}-response.json"));
|
||||
|
||||
let request_dump = RequestDump {
|
||||
method: method.as_str().to_string(),
|
||||
url: url.to_string(),
|
||||
headers: headers.iter().map(HeaderDump::from).collect(),
|
||||
body: dump_body(body),
|
||||
};
|
||||
|
||||
write_json_dump(&request_path, &request_dump)?;
|
||||
|
||||
Ok(ExchangeDump { response_path })
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct ExchangeDump {
|
||||
response_path: PathBuf,
|
||||
}
|
||||
|
||||
impl ExchangeDump {
|
||||
pub(crate) fn tee_response_body<R: Read>(
|
||||
self,
|
||||
status: u16,
|
||||
headers: &HeaderMap,
|
||||
response_body: R,
|
||||
) -> ResponseBodyDump<R> {
|
||||
ResponseBodyDump {
|
||||
response_body,
|
||||
response_path: self.response_path,
|
||||
status,
|
||||
headers: headers.iter().map(HeaderDump::from).collect(),
|
||||
body: Vec::new(),
|
||||
dump_written: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct ResponseBodyDump<R> {
|
||||
response_body: R,
|
||||
response_path: PathBuf,
|
||||
status: u16,
|
||||
headers: Vec<HeaderDump>,
|
||||
body: Vec<u8>,
|
||||
dump_written: bool,
|
||||
}
|
||||
|
||||
impl<R> ResponseBodyDump<R> {
|
||||
fn write_dump_if_needed(&mut self) {
|
||||
if self.dump_written {
|
||||
return;
|
||||
}
|
||||
|
||||
self.dump_written = true;
|
||||
|
||||
let response_dump = ResponseDump {
|
||||
status: self.status,
|
||||
headers: std::mem::take(&mut self.headers),
|
||||
body: dump_body(&self.body),
|
||||
};
|
||||
|
||||
if let Err(err) = write_json_dump(&self.response_path, &response_dump) {
|
||||
eprintln!(
|
||||
"responses-api-proxy failed to write {}: {err}",
|
||||
self.response_path.display()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Read> Read for ResponseBodyDump<R> {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
let bytes_read = self.response_body.read(buf)?;
|
||||
if bytes_read == 0 {
|
||||
self.write_dump_if_needed();
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
self.body.extend_from_slice(&buf[..bytes_read]);
|
||||
Ok(bytes_read)
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> Drop for ResponseBodyDump<R> {
|
||||
fn drop(&mut self) {
|
||||
self.write_dump_if_needed();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct RequestDump {
|
||||
method: String,
|
||||
url: String,
|
||||
headers: Vec<HeaderDump>,
|
||||
body: Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ResponseDump {
|
||||
status: u16,
|
||||
headers: Vec<HeaderDump>,
|
||||
body: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct HeaderDump {
|
||||
name: String,
|
||||
value: String,
|
||||
}
|
||||
|
||||
impl From<&Header> for HeaderDump {
|
||||
fn from(header: &Header) -> Self {
|
||||
let name = header.field.as_str().to_string();
|
||||
let value = if should_redact_header(&name) {
|
||||
REDACTED_HEADER_VALUE.to_string()
|
||||
} else {
|
||||
header.value.as_str().to_string()
|
||||
};
|
||||
|
||||
Self { name, value }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(&reqwest::header::HeaderName, &reqwest::header::HeaderValue)> for HeaderDump {
|
||||
fn from(header: (&reqwest::header::HeaderName, &reqwest::header::HeaderValue)) -> Self {
|
||||
let name = header.0.as_str();
|
||||
let value = if should_redact_header(name) {
|
||||
REDACTED_HEADER_VALUE.to_string()
|
||||
} else {
|
||||
String::from_utf8_lossy(header.1.as_bytes()).into_owned()
|
||||
};
|
||||
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn should_redact_header(name: &str) -> bool {
|
||||
name.eq_ignore_ascii_case(AUTHORIZATION_HEADER_NAME)
|
||||
|| name.to_ascii_lowercase().contains("cookie")
|
||||
}
|
||||
|
||||
fn dump_body(body: &[u8]) -> Value {
|
||||
serde_json::from_slice(body)
|
||||
.unwrap_or_else(|_| Value::String(String::from_utf8_lossy(body).into_owned()))
|
||||
}
|
||||
|
||||
fn write_json_dump(path: &PathBuf, dump: &impl Serialize) -> io::Result<()> {
|
||||
let mut bytes = serde_json::to_vec_pretty(dump)
|
||||
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
|
||||
bytes.push(b'\n');
|
||||
fs::write(path, bytes)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::fs;
|
||||
use std::io::Cursor;
|
||||
use std::io::Read;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
use reqwest::header::AUTHORIZATION;
|
||||
use reqwest::header::CONTENT_TYPE;
|
||||
use reqwest::header::HeaderMap;
|
||||
use reqwest::header::HeaderValue;
|
||||
use serde_json::json;
|
||||
use tiny_http::Header;
|
||||
use tiny_http::Method;
|
||||
|
||||
use super::ExchangeDumper;
|
||||
|
||||
static NEXT_TEST_DIR: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
#[test]
|
||||
fn dump_request_writes_redacted_headers_and_json_body() {
|
||||
let dump_dir = test_dump_dir();
|
||||
let dumper = ExchangeDumper::new(dump_dir.clone()).expect("create dumper");
|
||||
let headers = vec![
|
||||
Header::from_bytes(&b"Authorization"[..], &b"Bearer secret"[..])
|
||||
.expect("authorization header"),
|
||||
Header::from_bytes(&b"Cookie"[..], &b"user-session=secret"[..]).expect("cookie header"),
|
||||
Header::from_bytes(&b"Content-Type"[..], &b"application/json"[..])
|
||||
.expect("content-type header"),
|
||||
];
|
||||
|
||||
let exchange_dump = dumper
|
||||
.dump_request(
|
||||
&Method::Post,
|
||||
"/v1/responses",
|
||||
&headers,
|
||||
br#"{"model":"gpt-5.4"}"#,
|
||||
)
|
||||
.expect("dump request");
|
||||
|
||||
let request_dump = fs::read_to_string(dump_file_with_suffix(&dump_dir, "-request.json"))
|
||||
.expect("read request dump");
|
||||
|
||||
assert_eq!(
|
||||
serde_json::from_str::<serde_json::Value>(&request_dump).expect("parse request dump"),
|
||||
json!({
|
||||
"method": "POST",
|
||||
"url": "/v1/responses",
|
||||
"headers": [
|
||||
{
|
||||
"name": "Authorization",
|
||||
"value": "[REDACTED]"
|
||||
},
|
||||
{
|
||||
"name": "Cookie",
|
||||
"value": "[REDACTED]"
|
||||
},
|
||||
{
|
||||
"name": "Content-Type",
|
||||
"value": "application/json"
|
||||
}
|
||||
],
|
||||
"body": {
|
||||
"model": "gpt-5.4"
|
||||
}
|
||||
})
|
||||
);
|
||||
assert!(
|
||||
exchange_dump
|
||||
.response_path
|
||||
.file_name()
|
||||
.expect("response dump file name")
|
||||
.to_string_lossy()
|
||||
.ends_with("-response.json")
|
||||
);
|
||||
|
||||
fs::remove_dir_all(dump_dir).expect("remove test dump dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_body_dump_streams_body_and_writes_response_file() {
|
||||
let dump_dir = test_dump_dir();
|
||||
let dumper = ExchangeDumper::new(dump_dir.clone()).expect("create dumper");
|
||||
let exchange_dump = dumper
|
||||
.dump_request(&Method::Post, "/v1/responses", &[], b"{}")
|
||||
.expect("dump request");
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
||||
headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer secret"));
|
||||
headers.insert(
|
||||
"set-cookie",
|
||||
HeaderValue::from_static("user-session=secret"),
|
||||
);
|
||||
|
||||
let mut response_body = String::new();
|
||||
exchange_dump
|
||||
.tee_response_body(
|
||||
/*status*/ 200,
|
||||
&headers,
|
||||
Cursor::new(b"data: hello\n\n".to_vec()),
|
||||
)
|
||||
.read_to_string(&mut response_body)
|
||||
.expect("read response body");
|
||||
|
||||
let response_dump = fs::read_to_string(dump_file_with_suffix(&dump_dir, "-response.json"))
|
||||
.expect("read response dump");
|
||||
|
||||
assert_eq!(response_body, "data: hello\n\n");
|
||||
assert_eq!(
|
||||
serde_json::from_str::<serde_json::Value>(&response_dump).expect("parse response dump"),
|
||||
json!({
|
||||
"status": 200,
|
||||
"headers": [
|
||||
{
|
||||
"name": "content-type",
|
||||
"value": "text/event-stream"
|
||||
},
|
||||
{
|
||||
"name": "authorization",
|
||||
"value": "[REDACTED]"
|
||||
},
|
||||
{
|
||||
"name": "set-cookie",
|
||||
"value": "[REDACTED]"
|
||||
}
|
||||
],
|
||||
"body": "data: hello\n\n"
|
||||
})
|
||||
);
|
||||
|
||||
fs::remove_dir_all(dump_dir).expect("remove test dump dir");
|
||||
}
|
||||
|
||||
fn test_dump_dir() -> std::path::PathBuf {
|
||||
let test_id = NEXT_TEST_DIR.fetch_add(1, Ordering::Relaxed);
|
||||
let dump_dir = std::env::temp_dir().join(format!(
|
||||
"codex-responses-api-proxy-dump-test-{}-{test_id}",
|
||||
std::process::id()
|
||||
));
|
||||
fs::create_dir_all(&dump_dir).expect("create test dump dir");
|
||||
dump_dir
|
||||
}
|
||||
|
||||
fn dump_file_with_suffix(dump_dir: &std::path::Path, suffix: &str) -> std::path::PathBuf {
|
||||
let mut matches = fs::read_dir(dump_dir)
|
||||
.expect("read dump dir")
|
||||
.map(|entry| entry.expect("read dump entry").path())
|
||||
.filter(|path| path.to_string_lossy().ends_with(suffix))
|
||||
.collect::<Vec<_>>();
|
||||
matches.sort();
|
||||
|
||||
assert_eq!(matches.len(), 1);
|
||||
matches.pop().expect("single dump file")
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::fs::File;
|
||||
use std::fs::{self};
|
||||
use std::io::Read;
|
||||
use std::io::Write;
|
||||
use std::net::SocketAddr;
|
||||
use std::net::TcpListener;
|
||||
@@ -27,7 +28,9 @@ use tiny_http::Response;
|
||||
use tiny_http::Server;
|
||||
use tiny_http::StatusCode;
|
||||
|
||||
mod dump;
|
||||
mod read_api_key;
|
||||
use dump::ExchangeDumper;
|
||||
use read_api_key::read_auth_header_from_stdin;
|
||||
|
||||
/// CLI arguments for the proxy.
|
||||
@@ -49,6 +52,10 @@ pub struct Args {
|
||||
/// Absolute URL the proxy should forward requests to (defaults to OpenAI).
|
||||
#[arg(long, default_value = "https://api.openai.com/v1/responses")]
|
||||
pub upstream_url: String,
|
||||
|
||||
/// Directory where request/response dumps should be written as JSON.
|
||||
#[arg(long, value_name = "DIR")]
|
||||
pub dump_dir: Option<PathBuf>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
@@ -79,6 +86,12 @@ pub fn run_main(args: Args) -> Result<()> {
|
||||
upstream_url,
|
||||
host_header,
|
||||
});
|
||||
let dump_dir = args
|
||||
.dump_dir
|
||||
.map(ExchangeDumper::new)
|
||||
.transpose()
|
||||
.context("creating --dump-dir")?
|
||||
.map(Arc::new);
|
||||
|
||||
let (listener, bound_addr) = bind_listener(args.port)?;
|
||||
if let Some(path) = args.server_info.as_ref() {
|
||||
@@ -100,13 +113,20 @@ pub fn run_main(args: Args) -> Result<()> {
|
||||
for request in server.incoming_requests() {
|
||||
let client = client.clone();
|
||||
let forward_config = forward_config.clone();
|
||||
let dump_dir = dump_dir.clone();
|
||||
std::thread::spawn(move || {
|
||||
if http_shutdown && request.method() == &Method::Get && request.url() == "/shutdown" {
|
||||
let _ = request.respond(Response::new_empty(StatusCode(200)));
|
||||
std::process::exit(0);
|
||||
}
|
||||
|
||||
if let Err(e) = forward_request(&client, auth_header, &forward_config, request) {
|
||||
if let Err(e) = forward_request(
|
||||
&client,
|
||||
auth_header,
|
||||
&forward_config,
|
||||
dump_dir.as_deref(),
|
||||
request,
|
||||
) {
|
||||
eprintln!("forwarding error: {e}");
|
||||
}
|
||||
});
|
||||
@@ -144,6 +164,7 @@ fn forward_request(
|
||||
client: &Client,
|
||||
auth_header: &'static str,
|
||||
config: &ForwardConfig,
|
||||
dump_dir: Option<&ExchangeDumper>,
|
||||
mut req: Request,
|
||||
) -> Result<()> {
|
||||
// Only allow POST /v1/responses exactly, no query string.
|
||||
@@ -159,8 +180,18 @@ fn forward_request(
|
||||
|
||||
// Read request body
|
||||
let mut body = Vec::new();
|
||||
let mut reader = req.as_reader();
|
||||
std::io::Read::read_to_end(&mut reader, &mut body)?;
|
||||
let reader = req.as_reader();
|
||||
reader.read_to_end(&mut body)?;
|
||||
|
||||
let exchange_dump = dump_dir.and_then(|dump_dir| {
|
||||
dump_dir
|
||||
.dump_request(&method, &url_path, req.headers(), &body)
|
||||
.map_err(|err| {
|
||||
eprintln!("responses-api-proxy failed to dump request: {err}");
|
||||
err
|
||||
})
|
||||
.ok()
|
||||
});
|
||||
|
||||
// Build headers for upstream, forwarding everything from the incoming
|
||||
// request except Authorization (we replace it below).
|
||||
@@ -224,10 +255,17 @@ fn forward_request(
|
||||
}
|
||||
});
|
||||
|
||||
let response_body: Box<dyn Read + Send> = if let Some(exchange_dump) = exchange_dump {
|
||||
let headers = upstream_resp.headers().clone();
|
||||
Box::new(exchange_dump.tee_response_body(status.as_u16(), &headers, upstream_resp))
|
||||
} else {
|
||||
Box::new(upstream_resp)
|
||||
};
|
||||
|
||||
let response = Response::new(
|
||||
StatusCode(status.as_u16()),
|
||||
response_headers,
|
||||
upstream_resp,
|
||||
response_body,
|
||||
content_length,
|
||||
None,
|
||||
);
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
CREATE TABLE remote_control_enrollments (
|
||||
websocket_url TEXT NOT NULL,
|
||||
account_id TEXT NOT NULL,
|
||||
server_id TEXT NOT NULL,
|
||||
environment_id TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
updated_at INTEGER NOT NULL,
|
||||
PRIMARY KEY (websocket_url, account_id)
|
||||
);
|
||||
@@ -54,6 +54,7 @@ mod agent_jobs;
|
||||
mod backfill;
|
||||
mod logs;
|
||||
mod memories;
|
||||
mod remote_control;
|
||||
#[cfg(test)]
|
||||
mod test_support;
|
||||
mod threads;
|
||||
|
||||
219
codex-rs/state/src/runtime/remote_control.rs
Normal file
219
codex-rs/state/src/runtime/remote_control.rs
Normal file
@@ -0,0 +1,219 @@
|
||||
use super::*;
|
||||
|
||||
const REMOTE_CONTROL_ACCOUNT_ID_NONE: &str = "";
|
||||
|
||||
fn remote_control_account_id_key(account_id: Option<&str>) -> &str {
|
||||
account_id.unwrap_or(REMOTE_CONTROL_ACCOUNT_ID_NONE)
|
||||
}
|
||||
|
||||
impl StateRuntime {
|
||||
pub async fn get_remote_control_enrollment(
|
||||
&self,
|
||||
websocket_url: &str,
|
||||
account_id: Option<&str>,
|
||||
) -> anyhow::Result<Option<(String, String, String)>> {
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
SELECT server_id, environment_id, server_name
|
||||
FROM remote_control_enrollments
|
||||
WHERE websocket_url = ? AND account_id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(websocket_url)
|
||||
.bind(remote_control_account_id_key(account_id))
|
||||
.fetch_optional(self.pool.as_ref())
|
||||
.await?;
|
||||
|
||||
row.map(|row| {
|
||||
Ok((
|
||||
row.try_get("server_id")?,
|
||||
row.try_get("environment_id")?,
|
||||
row.try_get("server_name")?,
|
||||
))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
pub async fn upsert_remote_control_enrollment(
|
||||
&self,
|
||||
websocket_url: &str,
|
||||
account_id: Option<&str>,
|
||||
server_id: &str,
|
||||
environment_id: &str,
|
||||
server_name: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO remote_control_enrollments (
|
||||
websocket_url,
|
||||
account_id,
|
||||
server_id,
|
||||
environment_id,
|
||||
server_name,
|
||||
updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(websocket_url, account_id) DO UPDATE SET
|
||||
server_id = excluded.server_id,
|
||||
environment_id = excluded.environment_id,
|
||||
server_name = excluded.server_name,
|
||||
updated_at = excluded.updated_at
|
||||
"#,
|
||||
)
|
||||
.bind(websocket_url)
|
||||
.bind(remote_control_account_id_key(account_id))
|
||||
.bind(server_id)
|
||||
.bind(environment_id)
|
||||
.bind(server_name)
|
||||
.bind(Utc::now().timestamp())
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete_remote_control_enrollment(
|
||||
&self,
|
||||
websocket_url: &str,
|
||||
account_id: Option<&str>,
|
||||
) -> anyhow::Result<u64> {
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
DELETE FROM remote_control_enrollments
|
||||
WHERE websocket_url = ? AND account_id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(websocket_url)
|
||||
.bind(remote_control_account_id_key(account_id))
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::StateRuntime;
|
||||
use super::test_support::unique_temp_dir;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[tokio::test]
|
||||
async fn remote_control_enrollment_round_trips_by_target_and_account() {
|
||||
let codex_home = unique_temp_dir();
|
||||
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string())
|
||||
.await
|
||||
.expect("initialize runtime");
|
||||
|
||||
runtime
|
||||
.upsert_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
Some("account-a"),
|
||||
"srv_e_first",
|
||||
"env_first",
|
||||
"first-server",
|
||||
)
|
||||
.await
|
||||
.expect("insert first enrollment");
|
||||
runtime
|
||||
.upsert_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
Some("account-b"),
|
||||
"srv_e_second",
|
||||
"env_second",
|
||||
"second-server",
|
||||
)
|
||||
.await
|
||||
.expect("insert second enrollment");
|
||||
|
||||
assert_eq!(
|
||||
runtime
|
||||
.get_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
Some("account-a"),
|
||||
)
|
||||
.await
|
||||
.expect("load first enrollment"),
|
||||
Some((
|
||||
"srv_e_first".to_string(),
|
||||
"env_first".to_string(),
|
||||
"first-server".to_string()
|
||||
))
|
||||
);
|
||||
assert_eq!(
|
||||
runtime
|
||||
.get_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
/*account_id*/ None,
|
||||
)
|
||||
.await
|
||||
.expect("load missing enrollment"),
|
||||
None
|
||||
);
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(codex_home).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn delete_remote_control_enrollment_removes_only_matching_entry() {
|
||||
let codex_home = unique_temp_dir();
|
||||
let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string())
|
||||
.await
|
||||
.expect("initialize runtime");
|
||||
|
||||
runtime
|
||||
.upsert_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
/*account_id*/ None,
|
||||
"srv_e_first",
|
||||
"env_first",
|
||||
"first-server",
|
||||
)
|
||||
.await
|
||||
.expect("insert first enrollment");
|
||||
runtime
|
||||
.upsert_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
Some("account-a"),
|
||||
"srv_e_second",
|
||||
"env_second",
|
||||
"second-server",
|
||||
)
|
||||
.await
|
||||
.expect("insert second enrollment");
|
||||
|
||||
assert_eq!(
|
||||
runtime
|
||||
.delete_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
/*account_id*/ None,
|
||||
)
|
||||
.await
|
||||
.expect("delete first enrollment"),
|
||||
1
|
||||
);
|
||||
assert_eq!(
|
||||
runtime
|
||||
.get_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
/*account_id*/ None,
|
||||
)
|
||||
.await
|
||||
.expect("load deleted enrollment"),
|
||||
None
|
||||
);
|
||||
assert_eq!(
|
||||
runtime
|
||||
.get_remote_control_enrollment(
|
||||
"wss://example.com/backend-api/wham/remote/control/server",
|
||||
Some("account-a"),
|
||||
)
|
||||
.await
|
||||
.expect("load retained enrollment"),
|
||||
Some((
|
||||
"srv_e_second".to_string(),
|
||||
"env_second".to_string(),
|
||||
"second-server".to_string()
|
||||
))
|
||||
);
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(codex_home).await;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user