mirror of
https://github.com/openai/codex.git
synced 2026-04-30 03:12:20 +03:00
This PR is part of the effort to move the TUI on top of the app server. In a previous PR, we introduced an in-process app server and moved `exec` on top of it. For the TUI, we want to do the migration in stages. The app server doesn't currently expose all of the functionality required by the TUI, so we're going to need to support a hybrid approach as we make the transition. This PR changes the TUI initialization to instantiate an in-process app server and access its `AuthManager` and `ThreadManager` rather than constructing its own copies. It also adds a placeholder TUI event handler that will eventually translate app server events into TUI events. App server notifications are accepted but ignored for now. It also adds proper shutdown of the app server when the TUI terminates.
559 lines
19 KiB
Rust
559 lines
19 KiB
Rust
use super::ConnectionSessionState;
|
|
use super::MessageProcessor;
|
|
use super::MessageProcessorArgs;
|
|
use crate::outgoing_message::ConnectionId;
|
|
use crate::outgoing_message::OutgoingMessageSender;
|
|
use crate::transport::AppServerTransport;
|
|
use anyhow::Result;
|
|
use app_test_support::create_mock_responses_server_repeating_assistant;
|
|
use app_test_support::write_mock_responses_config_toml;
|
|
use codex_app_server_protocol::ClientInfo;
|
|
use codex_app_server_protocol::ClientRequest;
|
|
use codex_app_server_protocol::InitializeCapabilities;
|
|
use codex_app_server_protocol::InitializeParams;
|
|
use codex_app_server_protocol::InitializeResponse;
|
|
use codex_app_server_protocol::JSONRPCRequest;
|
|
use codex_app_server_protocol::RequestId;
|
|
use codex_app_server_protocol::ThreadStartParams;
|
|
use codex_app_server_protocol::ThreadStartResponse;
|
|
use codex_app_server_protocol::TurnStartParams;
|
|
use codex_app_server_protocol::TurnStartResponse;
|
|
use codex_app_server_protocol::UserInput;
|
|
use codex_arg0::Arg0DispatchPaths;
|
|
use codex_core::config::Config;
|
|
use codex_core::config::ConfigBuilder;
|
|
use codex_core::config_loader::CloudRequirementsLoader;
|
|
use codex_core::config_loader::LoaderOverrides;
|
|
use codex_feedback::CodexFeedback;
|
|
use codex_protocol::protocol::SessionSource;
|
|
use codex_protocol::protocol::W3cTraceContext;
|
|
use opentelemetry::global;
|
|
use opentelemetry::trace::SpanId;
|
|
use opentelemetry::trace::SpanKind;
|
|
use opentelemetry::trace::TraceId;
|
|
use opentelemetry::trace::TracerProvider as _;
|
|
use opentelemetry_sdk::propagation::TraceContextPropagator;
|
|
use opentelemetry_sdk::trace::InMemorySpanExporter;
|
|
use opentelemetry_sdk::trace::SdkTracerProvider;
|
|
use opentelemetry_sdk::trace::SpanData;
|
|
use pretty_assertions::assert_eq;
|
|
use std::collections::BTreeMap;
|
|
use std::path::Path;
|
|
use std::sync::Arc;
|
|
use std::sync::OnceLock;
|
|
use tempfile::TempDir;
|
|
use tokio::sync::mpsc;
|
|
use tracing_subscriber::layer::SubscriberExt;
|
|
use wiremock::MockServer;
|
|
|
|
const TEST_CONNECTION_ID: ConnectionId = ConnectionId(7);
|
|
const CORE_TURN_SANITY_SPAN_NAMES: &[&str] =
|
|
&["submission_dispatch", "session_task.turn", "run_turn"];
|
|
|
|
struct TestTracing {
|
|
exporter: InMemorySpanExporter,
|
|
provider: SdkTracerProvider,
|
|
}
|
|
|
|
struct RemoteTrace {
|
|
trace_id: TraceId,
|
|
parent_span_id: SpanId,
|
|
context: W3cTraceContext,
|
|
}
|
|
|
|
impl RemoteTrace {
|
|
fn new(trace_id: &str, parent_span_id: &str) -> Self {
|
|
let trace_id = TraceId::from_hex(trace_id).expect("trace id");
|
|
let parent_span_id = SpanId::from_hex(parent_span_id).expect("parent span id");
|
|
let context = W3cTraceContext {
|
|
traceparent: Some(format!("00-{trace_id}-{parent_span_id}-01")),
|
|
tracestate: Some("vendor=value".to_string()),
|
|
};
|
|
|
|
Self {
|
|
trace_id,
|
|
parent_span_id,
|
|
context,
|
|
}
|
|
}
|
|
}
|
|
|
|
fn init_test_tracing() -> &'static TestTracing {
|
|
static TEST_TRACING: OnceLock<TestTracing> = OnceLock::new();
|
|
TEST_TRACING.get_or_init(|| {
|
|
let exporter = InMemorySpanExporter::default();
|
|
let provider = SdkTracerProvider::builder()
|
|
.with_simple_exporter(exporter.clone())
|
|
.build();
|
|
let tracer = provider.tracer("codex-app-server-message-processor-tests");
|
|
global::set_text_map_propagator(TraceContextPropagator::new());
|
|
let subscriber =
|
|
tracing_subscriber::registry().with(tracing_opentelemetry::layer().with_tracer(tracer));
|
|
tracing::subscriber::set_global_default(subscriber)
|
|
.expect("global tracing subscriber should only be installed once");
|
|
TestTracing { exporter, provider }
|
|
})
|
|
}
|
|
|
|
fn request_from_client_request(request: ClientRequest) -> JSONRPCRequest {
|
|
serde_json::from_value(serde_json::to_value(request).expect("serialize client request"))
|
|
.expect("client request should convert to JSON-RPC")
|
|
}
|
|
|
|
fn tracing_test_guard() -> &'static tokio::sync::Mutex<()> {
|
|
static GUARD: OnceLock<tokio::sync::Mutex<()>> = OnceLock::new();
|
|
GUARD.get_or_init(|| tokio::sync::Mutex::new(()))
|
|
}
|
|
|
|
struct TracingHarness {
|
|
_server: MockServer,
|
|
_codex_home: TempDir,
|
|
processor: MessageProcessor,
|
|
outgoing_rx: mpsc::Receiver<crate::outgoing_message::OutgoingEnvelope>,
|
|
session: ConnectionSessionState,
|
|
tracing: &'static TestTracing,
|
|
}
|
|
|
|
impl TracingHarness {
|
|
async fn new() -> Result<Self> {
|
|
let server = create_mock_responses_server_repeating_assistant("Done").await;
|
|
let codex_home = TempDir::new()?;
|
|
let config = Arc::new(build_test_config(codex_home.path(), &server.uri()).await?);
|
|
let (processor, outgoing_rx) = build_test_processor(config);
|
|
let tracing = init_test_tracing();
|
|
tracing.exporter.reset();
|
|
tracing::callsite::rebuild_interest_cache();
|
|
let mut harness = Self {
|
|
_server: server,
|
|
_codex_home: codex_home,
|
|
processor,
|
|
outgoing_rx,
|
|
session: ConnectionSessionState::default(),
|
|
tracing,
|
|
};
|
|
|
|
let _: InitializeResponse = harness
|
|
.request(
|
|
ClientRequest::Initialize {
|
|
request_id: RequestId::Integer(1),
|
|
params: InitializeParams {
|
|
client_info: ClientInfo {
|
|
name: "codex-app-server-tests".to_string(),
|
|
title: None,
|
|
version: "0.1.0".to_string(),
|
|
},
|
|
capabilities: Some(InitializeCapabilities {
|
|
experimental_api: true,
|
|
..Default::default()
|
|
}),
|
|
},
|
|
},
|
|
None,
|
|
)
|
|
.await;
|
|
assert!(harness.session.initialized);
|
|
|
|
Ok(harness)
|
|
}
|
|
|
|
fn reset_tracing(&self) {
|
|
self.tracing.exporter.reset();
|
|
}
|
|
|
|
async fn shutdown(self) {
|
|
self.processor.shutdown_threads().await;
|
|
self.processor.drain_background_tasks().await;
|
|
}
|
|
|
|
async fn request<T>(&mut self, request: ClientRequest, trace: Option<W3cTraceContext>) -> T
|
|
where
|
|
T: serde::de::DeserializeOwned,
|
|
{
|
|
let request_id = match request.id() {
|
|
RequestId::Integer(request_id) => *request_id,
|
|
request_id => panic!("expected integer request id in test harness, got {request_id:?}"),
|
|
};
|
|
let mut request = request_from_client_request(request);
|
|
request.trace = trace;
|
|
|
|
self.processor
|
|
.process_request(
|
|
TEST_CONNECTION_ID,
|
|
request,
|
|
AppServerTransport::Stdio,
|
|
&mut self.session,
|
|
)
|
|
.await;
|
|
read_response(&mut self.outgoing_rx, request_id).await
|
|
}
|
|
|
|
async fn start_thread(
|
|
&mut self,
|
|
request_id: i64,
|
|
trace: Option<W3cTraceContext>,
|
|
) -> ThreadStartResponse {
|
|
let response = self
|
|
.request(
|
|
ClientRequest::ThreadStart {
|
|
request_id: RequestId::Integer(request_id),
|
|
params: ThreadStartParams {
|
|
ephemeral: Some(true),
|
|
..ThreadStartParams::default()
|
|
},
|
|
},
|
|
trace,
|
|
)
|
|
.await;
|
|
read_thread_started_notification(&mut self.outgoing_rx).await;
|
|
response
|
|
}
|
|
}
|
|
|
|
async fn build_test_config(codex_home: &Path, server_uri: &str) -> Result<Config> {
|
|
write_mock_responses_config_toml(
|
|
codex_home,
|
|
server_uri,
|
|
&BTreeMap::new(),
|
|
8_192,
|
|
Some(false),
|
|
"mock_provider",
|
|
"compact",
|
|
)?;
|
|
|
|
Ok(ConfigBuilder::default()
|
|
.codex_home(codex_home.to_path_buf())
|
|
.build()
|
|
.await?)
|
|
}
|
|
|
|
fn build_test_processor(
|
|
config: Arc<Config>,
|
|
) -> (
|
|
MessageProcessor,
|
|
mpsc::Receiver<crate::outgoing_message::OutgoingEnvelope>,
|
|
) {
|
|
let (outgoing_tx, outgoing_rx) = mpsc::channel(16);
|
|
let outgoing = Arc::new(OutgoingMessageSender::new(outgoing_tx));
|
|
let processor = MessageProcessor::new(MessageProcessorArgs {
|
|
outgoing,
|
|
arg0_paths: Arg0DispatchPaths::default(),
|
|
config,
|
|
cli_overrides: Vec::new(),
|
|
loader_overrides: LoaderOverrides::default(),
|
|
cloud_requirements: CloudRequirementsLoader::default(),
|
|
auth_manager: None,
|
|
thread_manager: None,
|
|
feedback: CodexFeedback::new(),
|
|
log_db: None,
|
|
config_warnings: Vec::new(),
|
|
session_source: SessionSource::VSCode,
|
|
enable_codex_api_key_env: false,
|
|
});
|
|
(processor, outgoing_rx)
|
|
}
|
|
|
|
fn span_attr<'a>(span: &'a SpanData, key: &str) -> Option<&'a str> {
|
|
span.attributes
|
|
.iter()
|
|
.find(|kv| kv.key.as_str() == key)
|
|
.and_then(|kv| match &kv.value {
|
|
opentelemetry::Value::String(value) => Some(value.as_str()),
|
|
_ => None,
|
|
})
|
|
}
|
|
|
|
fn find_rpc_span_with_trace<'a>(
|
|
spans: &'a [SpanData],
|
|
kind: SpanKind,
|
|
method: &str,
|
|
trace_id: TraceId,
|
|
) -> &'a SpanData {
|
|
spans
|
|
.iter()
|
|
.find(|span| {
|
|
span.span_kind == kind
|
|
&& span_attr(span, "rpc.system") == Some("jsonrpc")
|
|
&& span_attr(span, "rpc.method") == Some(method)
|
|
&& span.span_context.trace_id() == trace_id
|
|
})
|
|
.unwrap_or_else(|| {
|
|
panic!(
|
|
"missing {kind:?} span for rpc.method={method} trace={trace_id}; exported spans:\n{}",
|
|
format_spans(spans)
|
|
)
|
|
})
|
|
}
|
|
|
|
fn find_span_by_name_with_trace<'a>(
|
|
spans: &'a [SpanData],
|
|
name: &str,
|
|
trace_id: TraceId,
|
|
) -> &'a SpanData {
|
|
spans
|
|
.iter()
|
|
.find(|span| span.name.as_ref() == name && span.span_context.trace_id() == trace_id)
|
|
.unwrap_or_else(|| {
|
|
panic!(
|
|
"missing span named {name} for trace={trace_id}; exported spans:\n{}",
|
|
format_spans(spans)
|
|
)
|
|
})
|
|
}
|
|
|
|
fn format_spans(spans: &[SpanData]) -> String {
|
|
spans
|
|
.iter()
|
|
.map(|span| {
|
|
let rpc_method = span_attr(span, "rpc.method").unwrap_or("-");
|
|
format!(
|
|
"name={} span_id={} kind={:?} parent={} trace={} rpc.method={}",
|
|
span.name,
|
|
span.span_context.span_id(),
|
|
span.span_kind,
|
|
span.parent_span_id,
|
|
span.span_context.trace_id(),
|
|
rpc_method
|
|
)
|
|
})
|
|
.collect::<Vec<_>>()
|
|
.join("\n")
|
|
}
|
|
|
|
fn assert_span_descends_from(spans: &[SpanData], child: &SpanData, ancestor: &SpanData) {
|
|
let ancestor_span_id = ancestor.span_context.span_id();
|
|
let mut parent_span_id = child.parent_span_id;
|
|
while parent_span_id != SpanId::INVALID {
|
|
if parent_span_id == ancestor_span_id {
|
|
return;
|
|
}
|
|
let Some(parent_span) = spans
|
|
.iter()
|
|
.find(|span| span.span_context.span_id() == parent_span_id)
|
|
else {
|
|
break;
|
|
};
|
|
parent_span_id = parent_span.parent_span_id;
|
|
}
|
|
|
|
panic!(
|
|
"span {} does not descend from {}; exported spans:\n{}",
|
|
child.name,
|
|
ancestor.name,
|
|
format_spans(spans)
|
|
);
|
|
}
|
|
|
|
async fn read_response<T: serde::de::DeserializeOwned>(
|
|
outgoing_rx: &mut mpsc::Receiver<crate::outgoing_message::OutgoingEnvelope>,
|
|
request_id: i64,
|
|
) -> T {
|
|
loop {
|
|
let envelope = tokio::time::timeout(std::time::Duration::from_secs(5), outgoing_rx.recv())
|
|
.await
|
|
.expect("timed out waiting for response")
|
|
.expect("outgoing channel closed");
|
|
let crate::outgoing_message::OutgoingEnvelope::ToConnection {
|
|
connection_id,
|
|
message,
|
|
} = envelope
|
|
else {
|
|
continue;
|
|
};
|
|
if connection_id != TEST_CONNECTION_ID {
|
|
continue;
|
|
}
|
|
let crate::outgoing_message::OutgoingMessage::Response(response) = message else {
|
|
continue;
|
|
};
|
|
if response.id != RequestId::Integer(request_id) {
|
|
continue;
|
|
}
|
|
return serde_json::from_value(response.result)
|
|
.expect("response payload should deserialize");
|
|
}
|
|
}
|
|
|
|
async fn read_thread_started_notification(
|
|
outgoing_rx: &mut mpsc::Receiver<crate::outgoing_message::OutgoingEnvelope>,
|
|
) {
|
|
loop {
|
|
let envelope = tokio::time::timeout(std::time::Duration::from_secs(5), outgoing_rx.recv())
|
|
.await
|
|
.expect("timed out waiting for thread/started notification")
|
|
.expect("outgoing channel closed");
|
|
match envelope {
|
|
crate::outgoing_message::OutgoingEnvelope::ToConnection {
|
|
connection_id,
|
|
message,
|
|
} => {
|
|
if connection_id != TEST_CONNECTION_ID {
|
|
continue;
|
|
}
|
|
let crate::outgoing_message::OutgoingMessage::AppServerNotification(notification) =
|
|
message
|
|
else {
|
|
continue;
|
|
};
|
|
if matches!(
|
|
notification,
|
|
codex_app_server_protocol::ServerNotification::ThreadStarted(_)
|
|
) {
|
|
return;
|
|
}
|
|
}
|
|
crate::outgoing_message::OutgoingEnvelope::Broadcast { message } => {
|
|
let crate::outgoing_message::OutgoingMessage::AppServerNotification(notification) =
|
|
message
|
|
else {
|
|
continue;
|
|
};
|
|
if matches!(
|
|
notification,
|
|
codex_app_server_protocol::ServerNotification::ThreadStarted(_)
|
|
) {
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn wait_for_exported_spans<F>(tracing: &TestTracing, predicate: F) -> Vec<SpanData>
|
|
where
|
|
F: Fn(&[SpanData]) -> bool,
|
|
{
|
|
let mut last_spans = Vec::new();
|
|
for _ in 0..200 {
|
|
tokio::task::yield_now().await;
|
|
tracing
|
|
.provider
|
|
.force_flush()
|
|
.expect("force flush should succeed");
|
|
let spans = tracing.exporter.get_finished_spans().expect("span export");
|
|
last_spans = spans.clone();
|
|
if predicate(&spans) {
|
|
return spans;
|
|
}
|
|
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
|
}
|
|
|
|
panic!(
|
|
"timed out waiting for expected exported spans:\n{}",
|
|
format_spans(&last_spans)
|
|
);
|
|
}
|
|
|
|
#[tokio::test(flavor = "current_thread")]
|
|
async fn thread_start_jsonrpc_span_exports_server_span_and_parents_children() -> Result<()> {
|
|
let _guard = tracing_test_guard().lock().await;
|
|
let mut harness = TracingHarness::new().await?;
|
|
|
|
let RemoteTrace {
|
|
trace_id: remote_trace_id,
|
|
context: remote_trace,
|
|
..
|
|
} = RemoteTrace::new("00000000000000000000000000000011", "0000000000000022");
|
|
|
|
let _: ThreadStartResponse = harness.start_thread(2, Some(remote_trace)).await;
|
|
let spans = wait_for_exported_spans(harness.tracing, |spans| {
|
|
spans.iter().any(|span| {
|
|
span.span_kind == SpanKind::Server
|
|
&& span_attr(span, "rpc.method") == Some("thread/start")
|
|
&& span.span_context.trace_id() == remote_trace_id
|
|
}) && spans.iter().any(|span| {
|
|
span.name.as_ref() == "thread_spawn" && span.span_context.trace_id() == remote_trace_id
|
|
}) && spans.iter().any(|span| {
|
|
span.name.as_ref() == "session_init" && span.span_context.trace_id() == remote_trace_id
|
|
})
|
|
})
|
|
.await;
|
|
|
|
let server_request_span =
|
|
find_rpc_span_with_trace(&spans, SpanKind::Server, "thread/start", remote_trace_id);
|
|
let thread_spawn_span = find_span_by_name_with_trace(&spans, "thread_spawn", remote_trace_id);
|
|
let session_init_span = find_span_by_name_with_trace(&spans, "session_init", remote_trace_id);
|
|
assert_eq!(server_request_span.name.as_ref(), "thread/start");
|
|
assert_eq!(server_request_span.span_context.trace_id(), remote_trace_id);
|
|
assert_ne!(server_request_span.span_context.span_id(), SpanId::INVALID);
|
|
assert_span_descends_from(&spans, thread_spawn_span, server_request_span);
|
|
assert_span_descends_from(&spans, session_init_span, server_request_span);
|
|
harness.shutdown().await;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[tokio::test(flavor = "current_thread")]
|
|
async fn turn_start_jsonrpc_span_parents_core_turn_spans() -> Result<()> {
|
|
let _guard = tracing_test_guard().lock().await;
|
|
let mut harness = TracingHarness::new().await?;
|
|
let thread_start_response = harness.start_thread(2, None).await;
|
|
let thread_id = thread_start_response.thread.id.clone();
|
|
|
|
harness.reset_tracing();
|
|
|
|
let RemoteTrace {
|
|
trace_id: remote_trace_id,
|
|
parent_span_id: remote_parent_span_id,
|
|
context: remote_trace,
|
|
} = RemoteTrace::new("00000000000000000000000000000077", "0000000000000088");
|
|
let _: TurnStartResponse = harness
|
|
.request(
|
|
ClientRequest::TurnStart {
|
|
request_id: RequestId::Integer(3),
|
|
params: TurnStartParams {
|
|
thread_id,
|
|
input: vec![UserInput::Text {
|
|
text: "hello".to_string(),
|
|
text_elements: Vec::new(),
|
|
}],
|
|
cwd: None,
|
|
approval_policy: None,
|
|
sandbox_policy: None,
|
|
model: None,
|
|
service_tier: None,
|
|
effort: None,
|
|
summary: None,
|
|
personality: None,
|
|
output_schema: None,
|
|
collaboration_mode: None,
|
|
},
|
|
},
|
|
Some(remote_trace),
|
|
)
|
|
.await;
|
|
let spans = wait_for_exported_spans(harness.tracing, |spans| {
|
|
spans.iter().any(|span| {
|
|
span.span_kind == SpanKind::Server
|
|
&& span_attr(span, "rpc.method") == Some("turn/start")
|
|
&& span.span_context.trace_id() == remote_trace_id
|
|
}) && spans.iter().any(|span| {
|
|
CORE_TURN_SANITY_SPAN_NAMES.contains(&span.name.as_ref())
|
|
&& span.span_context.trace_id() == remote_trace_id
|
|
})
|
|
})
|
|
.await;
|
|
|
|
let server_request_span =
|
|
find_rpc_span_with_trace(&spans, SpanKind::Server, "turn/start", remote_trace_id);
|
|
let core_turn_span = spans
|
|
.iter()
|
|
.find(|span| {
|
|
CORE_TURN_SANITY_SPAN_NAMES.contains(&span.name.as_ref())
|
|
&& span.span_context.trace_id() == remote_trace_id
|
|
})
|
|
.unwrap_or_else(|| {
|
|
panic!(
|
|
"missing representative core turn span for trace={remote_trace_id}; exported spans:\n{}",
|
|
format_spans(&spans)
|
|
)
|
|
});
|
|
|
|
assert_eq!(server_request_span.parent_span_id, remote_parent_span_id);
|
|
assert!(server_request_span.parent_span_is_remote);
|
|
assert_eq!(server_request_span.span_context.trace_id(), remote_trace_id);
|
|
assert_span_descends_from(&spans, core_turn_span, server_request_span);
|
|
harness.shutdown().await;
|
|
|
|
Ok(())
|
|
}
|