mirror of
https://github.com/openai/codex.git
synced 2026-05-06 14:21:08 +03:00
Compare commits
4 Commits
pr21028
...
kmeelu/rea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c8b00ac21a | ||
|
|
df4a90740f | ||
|
|
73d6392f4f | ||
|
|
9e905528bb |
1
codex-rs/Cargo.lock
generated
1
codex-rs/Cargo.lock
generated
@@ -2220,6 +2220,7 @@ dependencies = [
|
||||
"opentelemetry_sdk",
|
||||
"pretty_assertions",
|
||||
"rand 0.9.3",
|
||||
"rcgen",
|
||||
"reqwest",
|
||||
"rustls",
|
||||
"rustls-native-certs",
|
||||
|
||||
@@ -322,6 +322,10 @@ quick-xml = "0.38.4"
|
||||
rand = "0.9"
|
||||
ratatui = "0.29.0"
|
||||
ratatui-macros = "0.6.0"
|
||||
rcgen = { version = "0.14.7", default-features = false, features = [
|
||||
"aws_lc_rs",
|
||||
"pem",
|
||||
] }
|
||||
regex = "1.12.3"
|
||||
regex-lite = "0.1.8"
|
||||
reqwest = { version = "0.12", features = ["cookies"] }
|
||||
|
||||
@@ -12,7 +12,7 @@ futures = { workspace = true }
|
||||
http = { workspace = true }
|
||||
opentelemetry = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
reqwest = { workspace = true, features = ["json", "stream"] }
|
||||
reqwest = { workspace = true, features = ["json", "rustls-tls-native-roots", "stream"] }
|
||||
rustls = { workspace = true }
|
||||
rustls-native-certs = { workspace = true }
|
||||
rustls-pki-types = { workspace = true }
|
||||
@@ -32,5 +32,6 @@ workspace = true
|
||||
codex-utils-cargo-bin = { workspace = true }
|
||||
opentelemetry_sdk = { workspace = true }
|
||||
pretty_assertions = { workspace = true }
|
||||
rcgen = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
|
||||
@@ -8,22 +8,93 @@
|
||||
//! - env precedence is respected,
|
||||
//! - multi-cert PEM bundles load,
|
||||
//! - error messages guide users when CA files are invalid.
|
||||
//! - optional HTTPS probes can complete a request through the constructed client.
|
||||
//!
|
||||
//! The detailed explanation of what "hermetic" means here lives in `codex_client::custom_ca`.
|
||||
//! This binary exists so the tests can exercise
|
||||
//! [`codex_client::build_reqwest_client_for_subprocess_tests`] in a separate process without
|
||||
//! duplicating client-construction logic.
|
||||
|
||||
use std::env;
|
||||
use std::process;
|
||||
use std::time::Duration;
|
||||
|
||||
const PROBE_TLS13_ENV: &str = "CODEX_CUSTOM_CA_PROBE_TLS13";
|
||||
const PROBE_PROXY_ENV: &str = "CODEX_CUSTOM_CA_PROBE_PROXY";
|
||||
const PROBE_URL_ENV: &str = "CODEX_CUSTOM_CA_PROBE_URL";
|
||||
|
||||
fn main() {
|
||||
match codex_client::build_reqwest_client_for_subprocess_tests(reqwest::Client::builder()) {
|
||||
Ok(_) => {
|
||||
println!("ok");
|
||||
let runtime = match tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
{
|
||||
Ok(runtime) => runtime,
|
||||
Err(error) => {
|
||||
eprintln!("failed to create probe runtime: {error}");
|
||||
process::exit(1);
|
||||
}
|
||||
};
|
||||
|
||||
match runtime.block_on(run_probe()) {
|
||||
Ok(()) => println!("ok"),
|
||||
Err(error) => {
|
||||
eprintln!("{error}");
|
||||
process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_probe() -> Result<(), String> {
|
||||
let proxy_url = env::var(PROBE_PROXY_ENV).ok();
|
||||
let target_url = env::var(PROBE_URL_ENV).ok();
|
||||
let mut builder = reqwest::Client::builder();
|
||||
if target_url.is_some() {
|
||||
builder = builder.timeout(Duration::from_secs(5));
|
||||
}
|
||||
if env::var_os(PROBE_TLS13_ENV).is_some() {
|
||||
builder = builder.min_tls_version(reqwest::tls::Version::TLS_1_3);
|
||||
}
|
||||
|
||||
let client = build_probe_client(builder, proxy_url.as_deref())?;
|
||||
if let Some(url) = target_url {
|
||||
post_probe_request(&client, &url).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_probe_client(
|
||||
builder: reqwest::ClientBuilder,
|
||||
proxy_url: Option<&str>,
|
||||
) -> Result<reqwest::Client, String> {
|
||||
if let Some(proxy_url) = proxy_url {
|
||||
let proxy = reqwest::Proxy::https(proxy_url)
|
||||
.map_err(|error| format!("failed to configure probe proxy {proxy_url}: {error}"))?;
|
||||
return codex_client::build_reqwest_client_with_custom_ca(builder.proxy(proxy))
|
||||
.map_err(|error| error.to_string());
|
||||
}
|
||||
|
||||
codex_client::build_reqwest_client_for_subprocess_tests(builder)
|
||||
.map_err(|error| error.to_string())
|
||||
}
|
||||
|
||||
async fn post_probe_request(client: &reqwest::Client, url: &str) -> Result<(), String> {
|
||||
let response = client
|
||||
.post(url)
|
||||
.header("Content-Type", "application/x-www-form-urlencoded")
|
||||
.body("grant_type=authorization_code&code=test")
|
||||
.send()
|
||||
.await
|
||||
.map_err(|error| format!("probe request failed: {error:?}"))?;
|
||||
let status = response.status();
|
||||
let body = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(|error| format!("failed to read probe response body: {error}"))?;
|
||||
if !status.is_success() {
|
||||
return Err(format!("probe request returned {status}: {body}"));
|
||||
}
|
||||
if body != "ok" {
|
||||
return Err(format!("probe response body mismatch: {body}"));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -14,10 +14,9 @@
|
||||
//! `TRUSTED CERTIFICATE` labels and bundles that also contain CRLs
|
||||
//! - return user-facing errors that explain how to fix misconfigured CA files
|
||||
//!
|
||||
//! It does not validate certificate chains or perform a handshake in tests. Its contract is
|
||||
//! narrower: produce a transport configuration whose root store contains every parseable
|
||||
//! certificate block from the configured PEM bundle, or fail early with a precise error before
|
||||
//! the caller starts network traffic.
|
||||
//! Its production contract is narrow: produce a transport configuration whose root store contains
|
||||
//! every parseable certificate block from the configured PEM bundle, or fail early with a precise
|
||||
//! error before the caller starts network traffic.
|
||||
//!
|
||||
//! In this module's test setup, a hermetic test is one whose result depends only on the CA file
|
||||
//! and environment variables that the test chose for itself. That matters here because the normal
|
||||
@@ -36,7 +35,8 @@
|
||||
//! - unit tests in this module cover env-selection logic without constructing a real client
|
||||
//! - subprocess integration tests under `tests/` cover real client construction through
|
||||
//! [`build_reqwest_client_for_subprocess_tests`], which disables reqwest proxy autodetection so
|
||||
//! the tests can observe custom-CA success and failure directly
|
||||
//! the tests can observe custom-CA success and failure directly, including one TLS handshake
|
||||
//! through a local HTTPS server
|
||||
//! - those subprocess tests also scrub inherited CA environment variables before launch so their
|
||||
//! result depends only on the test fixtures and env vars set by the test itself
|
||||
|
||||
@@ -266,12 +266,21 @@ fn maybe_build_rustls_client_config_with_env(
|
||||
/// This exists so tests can exercise precedence behavior deterministically without mutating the
|
||||
/// real process environment. It selects the CA bundle, delegates file parsing to
|
||||
/// [`ConfiguredCaBundle::load_certificates`], preserves the caller's chosen `reqwest` builder
|
||||
/// configuration, and finally registers each parsed certificate with that builder.
|
||||
/// configuration, forces rustls when a custom CA is configured, and finally registers each parsed
|
||||
/// certificate with that builder.
|
||||
fn build_reqwest_client_with_env(
|
||||
env_source: &dyn EnvSource,
|
||||
mut builder: reqwest::ClientBuilder,
|
||||
) -> Result<reqwest::Client, BuildCustomCaTransportError> {
|
||||
if let Some(bundle) = env_source.configured_ca_bundle() {
|
||||
ensure_rustls_crypto_provider();
|
||||
info!(
|
||||
source_env = bundle.source_env,
|
||||
ca_path = %bundle.path.display(),
|
||||
"building HTTP client with rustls backend for custom CA bundle"
|
||||
);
|
||||
builder = builder.use_rustls_tls();
|
||||
|
||||
let certificates = bundle.load_certificates()?;
|
||||
|
||||
for (idx, cert) in certificates.iter().enumerate() {
|
||||
|
||||
@@ -4,24 +4,83 @@
|
||||
//! `build_reqwest_client_for_subprocess_tests` instead of calling the helper in-process. The
|
||||
//! detailed explanation of what "hermetic" means here lives in `codex_client::custom_ca`; these
|
||||
//! tests add the process-level half of that contract by scrubbing inherited CA environment
|
||||
//! variables before each subprocess launch. They still stop at client construction: the
|
||||
//! assertions here cover CA file selection, PEM parsing, and user-facing errors, not a full TLS
|
||||
//! handshake.
|
||||
//! variables before each subprocess launch. Most assertions here cover CA file selection, PEM
|
||||
//! parsing, and user-facing errors. The HTTPS probes go further and perform real POSTs against
|
||||
//! locally generated certificates, including through a TLS-intercepting CONNECT proxy.
|
||||
|
||||
use codex_utils_cargo_bin::cargo_bin;
|
||||
use rcgen::BasicConstraints;
|
||||
use rcgen::CertificateParams;
|
||||
use rcgen::CertifiedIssuer;
|
||||
use rcgen::DistinguishedName;
|
||||
use rcgen::DnType;
|
||||
use rcgen::ExtendedKeyUsagePurpose;
|
||||
use rcgen::IsCa;
|
||||
use rcgen::KeyPair;
|
||||
use rcgen::KeyUsagePurpose;
|
||||
use rcgen::PKCS_ECDSA_P256_SHA256;
|
||||
use rustls_pki_types::CertificateDer;
|
||||
use rustls_pki_types::PrivateKeyDer;
|
||||
use std::fs;
|
||||
use std::io;
|
||||
use std::io::Read;
|
||||
use std::io::Write;
|
||||
use std::net::TcpListener;
|
||||
use std::net::TcpStream;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Command;
|
||||
use std::sync::Arc;
|
||||
use std::sync::mpsc;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
use tempfile::TempDir;
|
||||
|
||||
const CODEX_CA_CERT_ENV: &str = "CODEX_CA_CERTIFICATE";
|
||||
const PROBE_PROXY_ENV: &str = "CODEX_CUSTOM_CA_PROBE_PROXY";
|
||||
const PROBE_TLS13_ENV: &str = "CODEX_CUSTOM_CA_PROBE_TLS13";
|
||||
const PROBE_URL_ENV: &str = "CODEX_CUSTOM_CA_PROBE_URL";
|
||||
const SSL_CERT_FILE_ENV: &str = "SSL_CERT_FILE";
|
||||
const PROXY_ENV_VARS: &[&str] = &[
|
||||
"HTTP_PROXY",
|
||||
"http_proxy",
|
||||
"HTTPS_PROXY",
|
||||
"https_proxy",
|
||||
"ALL_PROXY",
|
||||
"all_proxy",
|
||||
"NO_PROXY",
|
||||
"no_proxy",
|
||||
];
|
||||
|
||||
const TEST_CERT_1: &str = include_str!("fixtures/test-ca.pem");
|
||||
const TEST_CERT_2: &str = include_str!("fixtures/test-intermediate.pem");
|
||||
const TRUSTED_TEST_CERT: &str = include_str!("fixtures/test-ca-trusted.pem");
|
||||
|
||||
fn write_cert_file(temp_dir: &TempDir, name: &str, contents: &str) -> std::path::PathBuf {
|
||||
struct Tls13Material {
|
||||
ca_cert_pem: String,
|
||||
server_cert: CertificateDer<'static>,
|
||||
server_key: PrivateKeyDer<'static>,
|
||||
}
|
||||
|
||||
struct Tls13TestServer {
|
||||
ca_cert_pem: String,
|
||||
request_rx: mpsc::Receiver<Result<String, String>>,
|
||||
url: String,
|
||||
}
|
||||
|
||||
struct PlainHttpOrigin {
|
||||
request_rx: mpsc::Receiver<Result<String, String>>,
|
||||
url: String,
|
||||
}
|
||||
|
||||
struct TlsInterceptingProxy {
|
||||
ca_cert_pem: String,
|
||||
request_rx: mpsc::Receiver<Result<String, String>>,
|
||||
url: String,
|
||||
}
|
||||
|
||||
fn write_cert_file(temp_dir: &TempDir, name: &str, contents: &str) -> PathBuf {
|
||||
let path = temp_dir.path().join(name);
|
||||
fs::write(&path, contents).unwrap_or_else(|error| {
|
||||
panic!("write cert fixture failed for {}: {error}", path.display())
|
||||
@@ -29,7 +88,7 @@ fn write_cert_file(temp_dir: &TempDir, name: &str, contents: &str) -> std::path:
|
||||
path
|
||||
}
|
||||
|
||||
fn run_probe(envs: &[(&str, &Path)]) -> std::process::Output {
|
||||
fn probe_command() -> Command {
|
||||
let mut cmd = Command::new(
|
||||
cargo_bin("custom_ca_probe")
|
||||
.unwrap_or_else(|error| panic!("failed to locate custom_ca_probe: {error}")),
|
||||
@@ -37,7 +96,18 @@ fn run_probe(envs: &[(&str, &Path)]) -> std::process::Output {
|
||||
// `Command` inherits the parent environment by default, so scrub CA-related variables first or
|
||||
// these tests can accidentally pass/fail based on the developer shell or CI runner.
|
||||
cmd.env_remove(CODEX_CA_CERT_ENV);
|
||||
cmd.env_remove(PROBE_PROXY_ENV);
|
||||
cmd.env_remove(PROBE_TLS13_ENV);
|
||||
cmd.env_remove(PROBE_URL_ENV);
|
||||
cmd.env_remove(SSL_CERT_FILE_ENV);
|
||||
for env_var in PROXY_ENV_VARS {
|
||||
cmd.env_remove(env_var);
|
||||
}
|
||||
cmd
|
||||
}
|
||||
|
||||
fn run_probe(envs: &[(&str, &Path)]) -> std::process::Output {
|
||||
let mut cmd = probe_command();
|
||||
for (key, value) in envs {
|
||||
cmd.env(key, value);
|
||||
}
|
||||
@@ -45,6 +115,286 @@ fn run_probe(envs: &[(&str, &Path)]) -> std::process::Output {
|
||||
.unwrap_or_else(|error| panic!("failed to run custom_ca_probe: {error}"))
|
||||
}
|
||||
|
||||
fn run_probe_posting_to_tls13_server(envs: &[(&str, &Path)], url: &str) -> std::process::Output {
|
||||
let mut cmd = probe_command();
|
||||
for (key, value) in envs {
|
||||
cmd.env(key, value);
|
||||
}
|
||||
cmd.env(PROBE_TLS13_ENV, "1");
|
||||
cmd.env(PROBE_URL_ENV, url);
|
||||
cmd.output()
|
||||
.unwrap_or_else(|error| panic!("failed to run custom_ca_probe: {error}"))
|
||||
}
|
||||
|
||||
fn run_probe_posting_through_tls_intercepting_proxy(
|
||||
envs: &[(&str, &Path)],
|
||||
url: &str,
|
||||
proxy_url: &str,
|
||||
) -> std::process::Output {
|
||||
let mut cmd = probe_command();
|
||||
for (key, value) in envs {
|
||||
cmd.env(key, value);
|
||||
}
|
||||
cmd.env(PROBE_PROXY_ENV, proxy_url);
|
||||
cmd.env(PROBE_TLS13_ENV, "1");
|
||||
cmd.env(PROBE_URL_ENV, url);
|
||||
cmd.output()
|
||||
.unwrap_or_else(|error| panic!("failed to run custom_ca_probe: {error}"))
|
||||
}
|
||||
|
||||
fn spawn_tls13_test_server() -> Tls13TestServer {
|
||||
codex_utils_rustls_provider::ensure_rustls_crypto_provider();
|
||||
let material = generate_tls13_material();
|
||||
let listener = TcpListener::bind(("127.0.0.1", 0))
|
||||
.unwrap_or_else(|error| panic!("bind TLS test server: {error}"));
|
||||
listener
|
||||
.set_nonblocking(true)
|
||||
.unwrap_or_else(|error| panic!("set TLS test server nonblocking: {error}"));
|
||||
let port = listener
|
||||
.local_addr()
|
||||
.unwrap_or_else(|error| panic!("TLS test server addr: {error}"))
|
||||
.port();
|
||||
let config = Arc::new(
|
||||
rustls::ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(vec![material.server_cert], material.server_key)
|
||||
.unwrap_or_else(|error| panic!("TLS 1.3 server config: {error}")),
|
||||
);
|
||||
let (request_tx, request_rx) = mpsc::channel();
|
||||
|
||||
thread::spawn(move || {
|
||||
let result = accept_tls13_request(listener, config);
|
||||
let _ = request_tx.send(result.map_err(|error| error.to_string()));
|
||||
});
|
||||
|
||||
Tls13TestServer {
|
||||
ca_cert_pem: material.ca_cert_pem,
|
||||
request_rx,
|
||||
url: format!("https://127.0.0.1:{port}/oauth/token"),
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_plain_http_origin() -> PlainHttpOrigin {
|
||||
let listener = TcpListener::bind(("127.0.0.1", 0))
|
||||
.unwrap_or_else(|error| panic!("bind plain HTTP origin: {error}"));
|
||||
listener
|
||||
.set_nonblocking(true)
|
||||
.unwrap_or_else(|error| panic!("set plain HTTP origin nonblocking: {error}"));
|
||||
let port = listener
|
||||
.local_addr()
|
||||
.unwrap_or_else(|error| panic!("plain HTTP origin addr: {error}"))
|
||||
.port();
|
||||
let (request_tx, request_rx) = mpsc::channel();
|
||||
|
||||
thread::spawn(move || {
|
||||
let result = accept_plain_http_origin_request(listener);
|
||||
let _ = request_tx.send(result.map_err(|error| error.to_string()));
|
||||
});
|
||||
|
||||
PlainHttpOrigin {
|
||||
request_rx,
|
||||
url: format!("https://127.0.0.1:{port}/oauth/token"),
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_tls_intercepting_proxy() -> TlsInterceptingProxy {
|
||||
codex_utils_rustls_provider::ensure_rustls_crypto_provider();
|
||||
let material = generate_tls13_material();
|
||||
let listener = TcpListener::bind(("127.0.0.1", 0))
|
||||
.unwrap_or_else(|error| panic!("bind TLS intercepting proxy: {error}"));
|
||||
listener
|
||||
.set_nonblocking(true)
|
||||
.unwrap_or_else(|error| panic!("set TLS intercepting proxy nonblocking: {error}"));
|
||||
let port = listener
|
||||
.local_addr()
|
||||
.unwrap_or_else(|error| panic!("TLS intercepting proxy addr: {error}"))
|
||||
.port();
|
||||
let config = Arc::new(
|
||||
rustls::ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(vec![material.server_cert], material.server_key)
|
||||
.unwrap_or_else(|error| panic!("TLS intercepting proxy config: {error}")),
|
||||
);
|
||||
let (request_tx, request_rx) = mpsc::channel();
|
||||
|
||||
thread::spawn(move || {
|
||||
let result = accept_tls_intercepting_proxy_request(listener, config);
|
||||
let _ = request_tx.send(result.map_err(|error| error.to_string()));
|
||||
});
|
||||
|
||||
TlsInterceptingProxy {
|
||||
ca_cert_pem: material.ca_cert_pem,
|
||||
request_rx,
|
||||
url: format!("http://127.0.0.1:{port}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_tls13_material() -> Tls13Material {
|
||||
let mut ca_params = CertificateParams::default();
|
||||
ca_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
|
||||
ca_params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];
|
||||
let mut ca_distinguished_name = DistinguishedName::new();
|
||||
ca_distinguished_name.push(DnType::CommonName, "codex test CA");
|
||||
ca_params.distinguished_name = ca_distinguished_name;
|
||||
let ca_key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256)
|
||||
.unwrap_or_else(|error| panic!("generate test CA key pair: {error}"));
|
||||
let ca = CertifiedIssuer::self_signed(ca_params, ca_key_pair)
|
||||
.unwrap_or_else(|error| panic!("generate test CA certificate: {error}"));
|
||||
|
||||
let mut server_params =
|
||||
CertificateParams::new(vec!["localhost".to_string(), "127.0.0.1".to_string()])
|
||||
.unwrap_or_else(|error| panic!("create test server certificate params: {error}"));
|
||||
server_params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
|
||||
server_params.key_usages = vec![
|
||||
KeyUsagePurpose::DigitalSignature,
|
||||
KeyUsagePurpose::KeyEncipherment,
|
||||
];
|
||||
let server_key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256)
|
||||
.unwrap_or_else(|error| panic!("generate test server key pair: {error}"));
|
||||
let server_cert = server_params
|
||||
.signed_by(&server_key_pair, &ca)
|
||||
.unwrap_or_else(|error| panic!("generate test server certificate: {error}"));
|
||||
|
||||
Tls13Material {
|
||||
ca_cert_pem: ca.pem(),
|
||||
server_cert: server_cert.der().clone(),
|
||||
server_key: PrivateKeyDer::from(server_key_pair),
|
||||
}
|
||||
}
|
||||
|
||||
fn accept_plain_http_origin_request(listener: TcpListener) -> io::Result<String> {
|
||||
let mut stream = accept_with_timeout(listener, Duration::from_secs(5))?;
|
||||
stream.set_nonblocking(false)?;
|
||||
stream.set_read_timeout(Some(Duration::from_secs(5)))?;
|
||||
stream.set_write_timeout(Some(Duration::from_secs(5)))?;
|
||||
|
||||
let request = read_http_message(&mut stream)?;
|
||||
stream.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\nok")?;
|
||||
stream.flush()?;
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
fn accept_tls13_request(
|
||||
listener: TcpListener,
|
||||
config: Arc<rustls::ServerConfig>,
|
||||
) -> io::Result<String> {
|
||||
let stream = accept_with_timeout(listener, Duration::from_secs(5))?;
|
||||
stream.set_nonblocking(false)?;
|
||||
stream.set_read_timeout(Some(Duration::from_secs(5)))?;
|
||||
stream.set_write_timeout(Some(Duration::from_secs(5)))?;
|
||||
|
||||
let connection = rustls::ServerConnection::new(config).map_err(io::Error::other)?;
|
||||
let mut tls = rustls::StreamOwned::new(connection, stream);
|
||||
let request = read_http_message(&mut tls)?;
|
||||
tls.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\nok")?;
|
||||
tls.flush()?;
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
fn accept_tls_intercepting_proxy_request(
|
||||
listener: TcpListener,
|
||||
config: Arc<rustls::ServerConfig>,
|
||||
) -> io::Result<String> {
|
||||
let mut stream = accept_with_timeout(listener, Duration::from_secs(5))?;
|
||||
stream.set_nonblocking(false)?;
|
||||
stream.set_read_timeout(Some(Duration::from_secs(5)))?;
|
||||
stream.set_write_timeout(Some(Duration::from_secs(5)))?;
|
||||
|
||||
let connect_request = read_http_message(&mut stream)?;
|
||||
let origin_authority = connect_authority_from_request(&connect_request)?;
|
||||
stream.write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")?;
|
||||
stream.flush()?;
|
||||
|
||||
let connection = rustls::ServerConnection::new(config).map_err(io::Error::other)?;
|
||||
let mut tls = rustls::StreamOwned::new(connection, stream);
|
||||
let request = read_http_message(&mut tls)?;
|
||||
|
||||
let mut origin = TcpStream::connect(origin_authority.as_str())?;
|
||||
origin.set_read_timeout(Some(Duration::from_secs(5)))?;
|
||||
origin.set_write_timeout(Some(Duration::from_secs(5)))?;
|
||||
origin.write_all(request.as_bytes())?;
|
||||
origin.flush()?;
|
||||
let response = read_http_message(&mut origin)?;
|
||||
|
||||
tls.write_all(response.as_bytes())?;
|
||||
tls.flush()?;
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
fn connect_authority_from_request(request: &str) -> io::Result<String> {
|
||||
let request_line = request
|
||||
.lines()
|
||||
.next()
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "empty CONNECT request"))?;
|
||||
let mut parts = request_line.split_whitespace();
|
||||
match (parts.next(), parts.next(), parts.next()) {
|
||||
(Some("CONNECT"), Some(authority), Some(_version)) => Ok(authority.to_string()),
|
||||
_ => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("invalid CONNECT request line: {request_line}"),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn accept_with_timeout(listener: TcpListener, timeout: Duration) -> io::Result<TcpStream> {
|
||||
let deadline = Instant::now() + timeout;
|
||||
loop {
|
||||
match listener.accept() {
|
||||
Ok((stream, _)) => return Ok(stream),
|
||||
Err(error) if error.kind() == io::ErrorKind::WouldBlock => {
|
||||
if Instant::now() >= deadline {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
"timed out waiting for TLS test client",
|
||||
));
|
||||
}
|
||||
thread::sleep(Duration::from_millis(10));
|
||||
}
|
||||
Err(error) => return Err(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn read_http_message(stream: &mut impl Read) -> io::Result<String> {
|
||||
let mut buffer = Vec::new();
|
||||
let mut chunk = [0; 1024];
|
||||
loop {
|
||||
let bytes_read = stream.read(&mut chunk)?;
|
||||
if bytes_read == 0 {
|
||||
break;
|
||||
}
|
||||
buffer.extend_from_slice(&chunk[..bytes_read]);
|
||||
if let Some(header_end) = buffer.windows(4).position(|window| window == b"\r\n\r\n") {
|
||||
let body_start = header_end + 4;
|
||||
let headers = String::from_utf8_lossy(&buffer[..body_start]);
|
||||
let content_length = headers
|
||||
.lines()
|
||||
.filter_map(|line| line.split_once(':'))
|
||||
.find_map(|(name, value)| {
|
||||
name.eq_ignore_ascii_case("content-length")
|
||||
.then(|| value.trim().parse::<usize>().ok())
|
||||
.flatten()
|
||||
})
|
||||
.unwrap_or(0);
|
||||
if buffer.len() >= body_start + content_length {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(String::from_utf8_lossy(&buffer).into_owned())
|
||||
}
|
||||
|
||||
fn assert_token_exchange_request(request: &str) {
|
||||
assert!(
|
||||
request.starts_with("POST /oauth/token HTTP/1.1"),
|
||||
"unexpected request:\n{request}"
|
||||
);
|
||||
assert!(
|
||||
request.contains("grant_type=authorization_code&code=test"),
|
||||
"unexpected request body:\n{request}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn uses_codex_ca_cert_env() {
|
||||
let temp_dir = TempDir::new().expect("tempdir");
|
||||
@@ -90,6 +440,59 @@ fn handles_multi_certificate_bundle() {
|
||||
assert!(output.status.success());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn posts_to_tls13_server_using_custom_ca_bundle() {
|
||||
let temp_dir = TempDir::new().expect("tempdir");
|
||||
let server = spawn_tls13_test_server();
|
||||
let cert_path = write_cert_file(&temp_dir, "tls-ca.pem", &server.ca_cert_pem);
|
||||
|
||||
let output =
|
||||
run_probe_posting_to_tls13_server(&[(CODEX_CA_CERT_ENV, cert_path.as_path())], &server.url);
|
||||
let server_result = server.request_rx.recv_timeout(Duration::from_secs(5));
|
||||
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"custom_ca_probe failed\nstdout:\n{}\nstderr:\n{}\nserver:\n{server_result:?}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
let request = server_result
|
||||
.expect("TLS test server should report a request")
|
||||
.expect("TLS test server should accept the probe request");
|
||||
assert_token_exchange_request(&request);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn posts_to_token_origin_through_tls_intercepting_proxy_with_custom_ca_bundle() {
|
||||
let temp_dir = TempDir::new().expect("tempdir");
|
||||
let origin = spawn_plain_http_origin();
|
||||
let proxy = spawn_tls_intercepting_proxy();
|
||||
let cert_path = write_cert_file(&temp_dir, "proxy-ca.pem", &proxy.ca_cert_pem);
|
||||
|
||||
let output = run_probe_posting_through_tls_intercepting_proxy(
|
||||
&[(CODEX_CA_CERT_ENV, cert_path.as_path())],
|
||||
&origin.url,
|
||||
&proxy.url,
|
||||
);
|
||||
let proxy_result = proxy.request_rx.recv_timeout(Duration::from_secs(5));
|
||||
let origin_result = origin.request_rx.recv_timeout(Duration::from_secs(5));
|
||||
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"custom_ca_probe failed\nstdout:\n{}\nstderr:\n{}\nproxy:\n{proxy_result:?}\norigin:\n{origin_result:?}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
let proxy_request = proxy_result
|
||||
.expect("TLS intercepting proxy should report a request")
|
||||
.expect("TLS intercepting proxy should accept the probe request");
|
||||
let origin_request = origin_result
|
||||
.expect("plain HTTP origin should report a request")
|
||||
.expect("plain HTTP origin should accept the forwarded request");
|
||||
assert_token_exchange_request(&proxy_request);
|
||||
assert_token_exchange_request(&origin_request);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_empty_pem_file_with_hint() {
|
||||
let temp_dir = TempDir::new().expect("tempdir");
|
||||
|
||||
@@ -52,6 +52,7 @@ use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::debug;
|
||||
@@ -196,6 +197,12 @@ struct RealtimeInputTask {
|
||||
event_parser: RealtimeEventParser,
|
||||
}
|
||||
|
||||
struct RealtimeInputChannels {
|
||||
user_text_rx: Receiver<String>,
|
||||
handoff_output_rx: Receiver<HandoffOutput>,
|
||||
audio_rx: Receiver<RealtimeAudioFrame>,
|
||||
}
|
||||
|
||||
impl RealtimeHandoffState {
|
||||
fn new(output_tx: Sender<HandoffOutput>, session_kind: RealtimeSessionKind) -> Self {
|
||||
Self {
|
||||
@@ -212,7 +219,6 @@ struct ConversationState {
|
||||
audio_tx: Sender<RealtimeAudioFrame>,
|
||||
user_text_tx: Sender<String>,
|
||||
session_kind: RealtimeSessionKind,
|
||||
writer: RealtimeWebsocketWriter,
|
||||
handoff: RealtimeHandoffState,
|
||||
input_task: JoinHandle<()>,
|
||||
fanout_task: Option<JoinHandle<()>>,
|
||||
@@ -271,6 +277,7 @@ impl RealtimeConversationManager {
|
||||
}
|
||||
|
||||
async fn start_inner(&self, start: RealtimeStart) -> CodexResult<RealtimeStartOutput> {
|
||||
let startup_started_at = Instant::now();
|
||||
let RealtimeStart {
|
||||
api_provider,
|
||||
extra_headers,
|
||||
@@ -284,39 +291,6 @@ impl RealtimeConversationManager {
|
||||
RealtimeEventParser::RealtimeV2 => RealtimeSessionKind::V2,
|
||||
};
|
||||
|
||||
let client = RealtimeWebsocketClient::new(api_provider);
|
||||
let (connection, sdp) = if let Some(sdp) = sdp {
|
||||
let call = model_client
|
||||
.create_realtime_call_with_headers(
|
||||
sdp,
|
||||
session_config.clone(),
|
||||
extra_headers.unwrap_or_default(),
|
||||
)
|
||||
.await?;
|
||||
let connection = client
|
||||
.connect_webrtc_sideband(
|
||||
session_config,
|
||||
&call.call_id,
|
||||
call.sideband_headers,
|
||||
default_headers(),
|
||||
)
|
||||
.await
|
||||
.map_err(map_api_error)?;
|
||||
(connection, Some(call.sdp))
|
||||
} else {
|
||||
let connection = client
|
||||
.connect(
|
||||
session_config,
|
||||
extra_headers.unwrap_or_default(),
|
||||
default_headers(),
|
||||
)
|
||||
.await
|
||||
.map_err(map_api_error)?;
|
||||
(connection, None)
|
||||
};
|
||||
|
||||
let writer = connection.writer();
|
||||
let events = connection.events();
|
||||
let (audio_tx, audio_rx) =
|
||||
async_channel::bounded::<RealtimeAudioFrame>(AUDIO_IN_QUEUE_CAPACITY);
|
||||
let (user_text_tx, user_text_rx) =
|
||||
@@ -328,24 +302,85 @@ impl RealtimeConversationManager {
|
||||
|
||||
let realtime_active = Arc::new(AtomicBool::new(true));
|
||||
let handoff = RealtimeHandoffState::new(handoff_output_tx, session_kind);
|
||||
let task = spawn_realtime_input_task(RealtimeInputTask {
|
||||
writer: writer.clone(),
|
||||
events,
|
||||
let input_channels = RealtimeInputChannels {
|
||||
user_text_rx,
|
||||
handoff_output_rx,
|
||||
audio_rx,
|
||||
events_tx,
|
||||
handoff_state: handoff.clone(),
|
||||
session_kind,
|
||||
event_parser,
|
||||
});
|
||||
};
|
||||
|
||||
let client = RealtimeWebsocketClient::new(api_provider);
|
||||
let (task, sdp) = if let Some(sdp) = sdp {
|
||||
info!(transport = "webrtc", "creating realtime call");
|
||||
let call_started_at = Instant::now();
|
||||
let call = model_client
|
||||
.create_realtime_call_with_headers(
|
||||
sdp,
|
||||
session_config.clone(),
|
||||
extra_headers.unwrap_or_default(),
|
||||
)
|
||||
.await?;
|
||||
info!(
|
||||
transport = "webrtc",
|
||||
call_id = %call.call_id,
|
||||
elapsed_ms = call_started_at.elapsed().as_millis() as u64,
|
||||
total_elapsed_ms = startup_started_at.elapsed().as_millis() as u64,
|
||||
"realtime call created; sdp answer ready"
|
||||
);
|
||||
let task = spawn_webrtc_sideband_input_task(RealtimeWebrtcSidebandInputTask {
|
||||
client,
|
||||
session_config,
|
||||
call_id: call.call_id,
|
||||
sideband_headers: call.sideband_headers,
|
||||
input_channels,
|
||||
events_tx,
|
||||
handoff_state: handoff.clone(),
|
||||
session_kind,
|
||||
event_parser,
|
||||
realtime_active: Arc::clone(&realtime_active),
|
||||
startup_started_at,
|
||||
});
|
||||
info!(
|
||||
transport = "webrtc",
|
||||
total_elapsed_ms = startup_started_at.elapsed().as_millis() as u64,
|
||||
"spawned realtime sideband connection task"
|
||||
);
|
||||
(task, Some(call.sdp))
|
||||
} else {
|
||||
info!(transport = "websocket", "connecting realtime websocket");
|
||||
let connect_started_at = Instant::now();
|
||||
let connection = client
|
||||
.connect(
|
||||
session_config,
|
||||
extra_headers.unwrap_or_default(),
|
||||
default_headers(),
|
||||
)
|
||||
.await
|
||||
.map_err(map_api_error)?;
|
||||
info!(
|
||||
transport = "websocket",
|
||||
elapsed_ms = connect_started_at.elapsed().as_millis() as u64,
|
||||
total_elapsed_ms = startup_started_at.elapsed().as_millis() as u64,
|
||||
"connected realtime websocket"
|
||||
);
|
||||
let task = spawn_realtime_input_task(RealtimeInputTask {
|
||||
writer: connection.writer(),
|
||||
events: connection.events(),
|
||||
user_text_rx: input_channels.user_text_rx,
|
||||
handoff_output_rx: input_channels.handoff_output_rx,
|
||||
audio_rx: input_channels.audio_rx,
|
||||
events_tx,
|
||||
handoff_state: handoff.clone(),
|
||||
session_kind,
|
||||
event_parser,
|
||||
});
|
||||
(task, None)
|
||||
};
|
||||
|
||||
let mut guard = self.state.lock().await;
|
||||
*guard = Some(ConversationState {
|
||||
audio_tx,
|
||||
user_text_tx,
|
||||
session_kind,
|
||||
writer,
|
||||
handoff,
|
||||
input_task: task,
|
||||
fanout_task: None,
|
||||
@@ -805,6 +840,7 @@ async fn handle_start_inner(
|
||||
msg: EventMsg::RealtimeConversationSdp(RealtimeConversationSdpEvent { sdp }),
|
||||
})
|
||||
.await;
|
||||
info!("sent realtime sdp answer to client");
|
||||
}
|
||||
|
||||
let sess_clone = Arc::clone(sess);
|
||||
@@ -1004,6 +1040,100 @@ pub(crate) async fn handle_close(sess: &Arc<Session>, sub_id: String) {
|
||||
}
|
||||
|
||||
fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> {
|
||||
tokio::spawn(run_realtime_input_task(input))
|
||||
}
|
||||
|
||||
struct RealtimeWebrtcSidebandInputTask {
|
||||
client: RealtimeWebsocketClient,
|
||||
session_config: RealtimeSessionConfig,
|
||||
call_id: String,
|
||||
sideband_headers: HeaderMap,
|
||||
input_channels: RealtimeInputChannels,
|
||||
events_tx: Sender<RealtimeEvent>,
|
||||
handoff_state: RealtimeHandoffState,
|
||||
session_kind: RealtimeSessionKind,
|
||||
event_parser: RealtimeEventParser,
|
||||
realtime_active: Arc<AtomicBool>,
|
||||
startup_started_at: Instant,
|
||||
}
|
||||
|
||||
fn spawn_webrtc_sideband_input_task(input: RealtimeWebrtcSidebandInputTask) -> JoinHandle<()> {
|
||||
let RealtimeWebrtcSidebandInputTask {
|
||||
client,
|
||||
session_config,
|
||||
call_id,
|
||||
sideband_headers,
|
||||
input_channels,
|
||||
events_tx,
|
||||
handoff_state,
|
||||
session_kind,
|
||||
event_parser,
|
||||
realtime_active,
|
||||
startup_started_at,
|
||||
} = input;
|
||||
|
||||
tokio::spawn(async move {
|
||||
if !realtime_active.load(Ordering::Relaxed) {
|
||||
return;
|
||||
}
|
||||
|
||||
info!(%call_id, "connecting realtime sideband websocket");
|
||||
let sideband_started_at = Instant::now();
|
||||
let connection = match client
|
||||
.connect_webrtc_sideband(
|
||||
session_config,
|
||||
&call_id,
|
||||
sideband_headers,
|
||||
default_headers(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(connection) => {
|
||||
info!(
|
||||
%call_id,
|
||||
elapsed_ms = sideband_started_at.elapsed().as_millis() as u64,
|
||||
total_elapsed_ms = startup_started_at.elapsed().as_millis() as u64,
|
||||
"connected realtime sideband websocket"
|
||||
);
|
||||
connection
|
||||
}
|
||||
Err(err) => {
|
||||
if realtime_active.load(Ordering::Relaxed) {
|
||||
let mapped_error = map_api_error(err);
|
||||
warn!(
|
||||
%call_id,
|
||||
elapsed_ms = sideband_started_at.elapsed().as_millis() as u64,
|
||||
total_elapsed_ms = startup_started_at.elapsed().as_millis() as u64,
|
||||
"failed to connect realtime sideband: {mapped_error}"
|
||||
);
|
||||
let _ = events_tx
|
||||
.send(RealtimeEvent::Error(mapped_error.to_string()))
|
||||
.await;
|
||||
}
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if !realtime_active.load(Ordering::Relaxed) {
|
||||
return;
|
||||
}
|
||||
|
||||
run_realtime_input_task(RealtimeInputTask {
|
||||
writer: connection.writer(),
|
||||
events: connection.events(),
|
||||
user_text_rx: input_channels.user_text_rx,
|
||||
handoff_output_rx: input_channels.handoff_output_rx,
|
||||
audio_rx: input_channels.audio_rx,
|
||||
events_tx,
|
||||
handoff_state,
|
||||
session_kind,
|
||||
event_parser,
|
||||
})
|
||||
.await;
|
||||
})
|
||||
}
|
||||
|
||||
async fn run_realtime_input_task(input: RealtimeInputTask) {
|
||||
let RealtimeInputTask {
|
||||
writer,
|
||||
events,
|
||||
@@ -1016,57 +1146,55 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> {
|
||||
event_parser,
|
||||
} = input;
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut output_audio_state: Option<OutputAudioState> = None;
|
||||
let mut response_create_queue = RealtimeResponseCreateQueue::default();
|
||||
let mut output_audio_state: Option<OutputAudioState> = None;
|
||||
let mut response_create_queue = RealtimeResponseCreateQueue::default();
|
||||
|
||||
loop {
|
||||
let result = tokio::select! {
|
||||
// Text typed by the user that should be sent into realtime.
|
||||
user_text = user_text_rx.recv() => {
|
||||
handle_user_text_input(
|
||||
user_text,
|
||||
&writer,
|
||||
&events_tx,
|
||||
)
|
||||
.await
|
||||
}
|
||||
// Background agent progress or final output that should be sent back to realtime.
|
||||
background_agent_output = handoff_output_rx.recv() => {
|
||||
handle_handoff_output(
|
||||
background_agent_output,
|
||||
&writer,
|
||||
&events_tx,
|
||||
&handoff_state,
|
||||
event_parser,
|
||||
&mut response_create_queue,
|
||||
)
|
||||
.await
|
||||
}
|
||||
// Events received from the realtime server.
|
||||
realtime_event = events.next_event() => {
|
||||
handle_realtime_server_event(
|
||||
realtime_event,
|
||||
&writer,
|
||||
&events_tx,
|
||||
&handoff_state,
|
||||
session_kind,
|
||||
&mut output_audio_state,
|
||||
&mut response_create_queue,
|
||||
)
|
||||
loop {
|
||||
let result = tokio::select! {
|
||||
// Text typed by the user that should be sent into realtime.
|
||||
user_text = user_text_rx.recv() => {
|
||||
handle_user_text_input(
|
||||
user_text,
|
||||
&writer,
|
||||
&events_tx,
|
||||
)
|
||||
.await
|
||||
}
|
||||
// Audio frames captured from the user microphone.
|
||||
user_audio_frame = audio_rx.recv() => {
|
||||
handle_user_audio_input(user_audio_frame, &writer, &events_tx)
|
||||
.await
|
||||
}
|
||||
};
|
||||
if result.is_err() {
|
||||
break;
|
||||
}
|
||||
// Background agent progress or final output that should be sent back to realtime.
|
||||
background_agent_output = handoff_output_rx.recv() => {
|
||||
handle_handoff_output(
|
||||
background_agent_output,
|
||||
&writer,
|
||||
&events_tx,
|
||||
&handoff_state,
|
||||
event_parser,
|
||||
&mut response_create_queue,
|
||||
)
|
||||
.await
|
||||
}
|
||||
// Events received from the realtime server.
|
||||
realtime_event = events.next_event() => {
|
||||
handle_realtime_server_event(
|
||||
realtime_event,
|
||||
&writer,
|
||||
&events_tx,
|
||||
&handoff_state,
|
||||
session_kind,
|
||||
&mut output_audio_state,
|
||||
&mut response_create_queue,
|
||||
)
|
||||
.await
|
||||
}
|
||||
// Audio frames captured from the user microphone.
|
||||
user_audio_frame = audio_rx.recv() => {
|
||||
handle_user_audio_input(user_audio_frame, &writer, &events_tx)
|
||||
.await
|
||||
}
|
||||
};
|
||||
if result.is_err() {
|
||||
break;
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_user_text_input(
|
||||
|
||||
@@ -48,6 +48,7 @@ use std::process::Command;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::timeout;
|
||||
use wiremock::Match;
|
||||
@@ -456,6 +457,7 @@ async fn conversation_webrtc_start_posts_generated_session() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let sideband_accept_delay = Duration::from_millis(1000);
|
||||
let capture = RealtimeCallRequestCapture::new();
|
||||
Mock::given(method("POST"))
|
||||
.and(path_regex(".*/realtime/calls$"))
|
||||
@@ -468,12 +470,15 @@ async fn conversation_webrtc_start_posts_generated_session() -> Result<()> {
|
||||
.mount(&server)
|
||||
.await;
|
||||
let realtime_server = start_websocket_server_with_headers(vec![WebSocketConnectionConfig {
|
||||
requests: vec![vec![json!({
|
||||
"type": "session.updated",
|
||||
"session": { "id": "sess_webrtc", "instructions": "backend prompt" }
|
||||
})]],
|
||||
requests: vec![
|
||||
vec![json!({
|
||||
"type": "session.updated",
|
||||
"session": { "id": "sess_webrtc", "instructions": "backend prompt" }
|
||||
})],
|
||||
vec![],
|
||||
],
|
||||
response_headers: Vec::new(),
|
||||
accept_delay: None,
|
||||
accept_delay: Some(sideband_accept_delay),
|
||||
close_after_requests: false,
|
||||
}])
|
||||
.await;
|
||||
@@ -488,6 +493,7 @@ async fn conversation_webrtc_start_posts_generated_session() -> Result<()> {
|
||||
});
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let start = Instant::now();
|
||||
test.codex
|
||||
.submit(Op::RealtimeConversationStart(ConversationStartParams {
|
||||
output_modality: RealtimeOutputModality::Audio,
|
||||
@@ -509,7 +515,19 @@ async fn conversation_webrtc_start_posts_generated_session() -> Result<()> {
|
||||
})
|
||||
.await
|
||||
.unwrap_or_else(|err: ErrorEvent| panic!("conversation call create failed: {err:?}"));
|
||||
let sdp_elapsed = start.elapsed();
|
||||
assert_eq!(created.sdp, "v=answer\r\n");
|
||||
assert!(
|
||||
sdp_elapsed < sideband_accept_delay,
|
||||
"SDP answer should arrive before sideband accept delay; elapsed={sdp_elapsed:?}, delay={sideband_accept_delay:?}"
|
||||
);
|
||||
assert!(realtime_server.handshakes().is_empty());
|
||||
|
||||
test.codex
|
||||
.submit(Op::RealtimeConversationText(ConversationTextParams {
|
||||
text: "queued before sideband".to_string(),
|
||||
}))
|
||||
.await?;
|
||||
|
||||
let session_updated = wait_for_event_match(&test.codex, |msg| match msg {
|
||||
EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent {
|
||||
@@ -578,6 +596,13 @@ async fn conversation_webrtc_start_posts_generated_session() -> Result<()> {
|
||||
.context("session.update should include instructions")?
|
||||
.contains("startup context")
|
||||
);
|
||||
let queued_text = realtime_server
|
||||
.wait_for_request(/*connection_index*/ 0, /*request_index*/ 1)
|
||||
.await;
|
||||
assert_eq!(
|
||||
websocket_request_text(&queued_text).as_deref(),
|
||||
Some("queued before sideband")
|
||||
);
|
||||
let handshake = realtime_server.single_handshake();
|
||||
assert_eq!(
|
||||
handshake.uri(),
|
||||
|
||||
Reference in New Issue
Block a user