# PR #1547: Add tests for OpenAI helpers and retry logic - URL: https://github.com/openai/codex/pull/1547 - Author: aibrahim-oai - Created: 2025-07-11 21:37:41 UTC - Updated: 2025-07-17 16:38:37 UTC - Changes: +501/-38, Files changed: 11, Commits: 16 ## Description ## Summary - add unit tests for tool JSON helpers - verify message assembly for chat completions - test retry and error handling paths of `ModelClient` ## Testing - `cargo clippy --workspace --all-targets -- -D warnings` - `cargo test --workspace --exclude codex-linux-sandbox` ------ https://chatgpt.com/codex/tasks/task_i_68717e8603a48321b875080ed3b70d63 ## Full Diff ```diff diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs index 816fc80f9b..ab60f20cd4 100644 --- a/codex-rs/core/src/chat_completions.rs +++ b/codex-rs/core/src/chat_completions.rs @@ -21,7 +21,6 @@ use crate::client_common::ResponseEvent; use crate::client_common::ResponseStream; use crate::error::CodexErr; use crate::error::Result; -use crate::flags::OPENAI_REQUEST_MAX_RETRIES; use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS; use crate::models::ContentItem; use crate::models::ResponseItem; @@ -34,6 +33,7 @@ pub(crate) async fn stream_chat_completions( model: &str, client: &reqwest::Client, provider: &ModelProviderInfo, + max_retries: u64, ) -> Result { // Build messages array let mut messages = Vec::::new(); @@ -146,7 +146,7 @@ pub(crate) async fn stream_chat_completions( return Err(CodexErr::UnexpectedStatus(status, body)); } - if attempt > *OPENAI_REQUEST_MAX_RETRIES { + if attempt > max_retries { return Err(CodexErr::RetryLimit(status)); } @@ -162,7 +162,7 @@ pub(crate) async fn stream_chat_completions( tokio::time::sleep(delay).await; } Err(e) => { - if attempt > *OPENAI_REQUEST_MAX_RETRIES { + if attempt > max_retries { return Err(e.into()); } let delay = backoff(attempt); @@ -462,3 +462,134 @@ pub(crate) trait AggregateStreamExt: Stream> + Size } impl AggregateStreamExt for T where T: Stream> + Sized {} +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used)] + use super::*; + use crate::WireApi; + use crate::client_common::Prompt; + use crate::config::Config; + use crate::config::ConfigOverrides; + use crate::config::ConfigToml; + use crate::models::ContentItem; + use crate::models::FunctionCallOutputPayload; + use crate::models::ResponseItem; + use pretty_assertions::assert_eq; + use std::sync::Arc; + use std::sync::Mutex; + use tempfile::TempDir; + use wiremock::Mock; + use wiremock::MockServer; + use wiremock::Request; + use wiremock::Respond; + use wiremock::ResponseTemplate; + use wiremock::matchers::method; + use wiremock::matchers::path; + + struct CaptureResponder { + body: Arc>>, + } + + impl Respond for CaptureResponder { + fn respond(&self, req: &Request) -> ResponseTemplate { + let v: serde_json::Value = serde_json::from_slice(&req.body).unwrap(); + *self.body.lock().unwrap() = Some(v); + ResponseTemplate::new(200).insert_header("content-type", "text/event-stream") + } + } + + /// Validate that `stream_chat_completions` converts our internal `Prompt` into the exact + /// Chat Completions JSON payload expected by OpenAI. We build a prompt containing user + /// assistant turns, a function call and its output, issue the request against a + /// `wiremock::MockServer`, capture the JSON body, and assert that the full `messages` array + /// matches a golden value. The test is a pure unit-test; it is skipped automatically when + /// the sandbox disables networking. 
+ #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn assembles_messages_correctly() { + // Skip when sandbox networking is disabled (e.g. on CI). + if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + return; + } + let server = MockServer::start().await; + let capture = Arc::new(Mutex::new(None)); + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(CaptureResponder { + body: capture.clone(), + }) + .mount(&server) + .await; + + let provider = ModelProviderInfo { + name: "test".into(), + base_url: format!("{}/v1", server.uri()), + env_key: None, + env_key_instructions: None, + wire_api: WireApi::Chat, + query_params: None, + http_headers: None, + env_http_headers: None, + }; + + let codex_home = TempDir::new().unwrap(); + let mut config = Config::load_from_base_config_with_overrides( + ConfigToml::default(), + ConfigOverrides::default(), + codex_home.path().to_path_buf(), + ) + .unwrap(); + config.model_provider = provider.clone(); + config.model = "gpt-4".into(); + + let client = reqwest::Client::new(); + + let prompt = Prompt { + input: vec![ + ResponseItem::Message { + role: "user".into(), + content: vec![ContentItem::InputText { text: "hi".into() }], + }, + ResponseItem::Message { + role: "assistant".into(), + content: vec![ContentItem::OutputText { text: "ok".into() }], + }, + ResponseItem::FunctionCall { + name: "foo".into(), + arguments: "{}".into(), + call_id: "c1".into(), + }, + ResponseItem::FunctionCallOutput { + call_id: "c1".into(), + output: FunctionCallOutputPayload { + content: "out".into(), + success: Some(true), + }, + }, + ], + ..Default::default() + }; + + let _ = stream_chat_completions( + &prompt, + &config.model, + &client, + &provider, + config.openai_request_max_retries, + ) + .await + .unwrap(); + + let body = capture.lock().unwrap().take().unwrap(); + let messages = body.get("messages").unwrap(); + + let expected = serde_json::json!([ + {"role":"system","content":prompt.get_full_instructions(&config.model)}, + {"role":"user","content":"hi"}, + {"role":"assistant","content":"ok"}, + {"role":"assistant", "content": null, "tool_calls":[{"id":"c1","type":"function","function":{"name":"foo","arguments":"{}"}}]}, + {"role":"tool","tool_call_id":"c1","content":"out"} + ]); + + assert_eq!(messages, &expected); + } +} diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 2fa182cf7f..34512a6dd9 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -29,7 +29,6 @@ use crate::config_types::ReasoningSummary as ReasoningSummaryConfig; use crate::error::CodexErr; use crate::error::Result; use crate::flags::CODEX_RS_SSE_FIXTURE; -use crate::flags::OPENAI_REQUEST_MAX_RETRIES; use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS; use crate::model_provider_info::ModelProviderInfo; use crate::model_provider_info::WireApi; @@ -77,6 +76,7 @@ impl ModelClient { &self.config.model, &self.client, &self.provider, + self.config.openai_request_max_retries, ) .await?; @@ -135,6 +135,7 @@ impl ModelClient { ); let mut attempt = 0; + let max_retries = self.config.openai_request_max_retries; loop { attempt += 1; @@ -171,7 +172,7 @@ impl ModelClient { return Err(CodexErr::UnexpectedStatus(status, body)); } - if attempt > *OPENAI_REQUEST_MAX_RETRIES { + if attempt > max_retries { return Err(CodexErr::RetryLimit(status)); } @@ -188,7 +189,7 @@ impl ModelClient { tokio::time::sleep(delay).await; } Err(e) => { - if attempt > *OPENAI_REQUEST_MAX_RETRIES { + if attempt > max_retries 
{ return Err(e.into()); } let delay = backoff(attempt); @@ -315,7 +316,7 @@ where // duplicated `output` array embedded in the `response.completed` // payload. That produced two concrete issues: // 1. No real‑time streaming – the user only saw output after the - // entire turn had finished, which broke the “typing” UX and + // entire turn had finished, which broke the "typing" UX and // made long‑running turns look stalled. // 2. Duplicate `function_call_output` items – both the // individual *and* the completed array were forwarded, which @@ -394,17 +395,76 @@ async fn stream_from_fixture(path: impl AsRef) -> Result { #[cfg(test)] mod tests { - #![allow(clippy::expect_used, clippy::unwrap_used)] + #![allow(clippy::unwrap_used, clippy::print_stdout, clippy::expect_used)] use super::*; + use crate::client_common::Prompt; + use crate::config::Config; + use crate::config::ConfigOverrides; + use crate::config::ConfigToml; + use crate::config_types::ReasoningEffort as ReasoningEffortConfig; + use crate::config_types::ReasoningSummary as ReasoningSummaryConfig; + use futures::StreamExt; + use reqwest::StatusCode; use serde_json::json; + use std::sync::Arc; + use std::sync::Mutex; + use std::time::Duration; + use std::time::Instant; + use tempfile::TempDir; use tokio::sync::mpsc; use tokio_test::io::Builder as IoBuilder; use tokio_util::io::ReaderStream; + use wiremock::Mock; + use wiremock::MockServer; + use wiremock::Request; + use wiremock::Respond; + use wiremock::ResponseTemplate; + use wiremock::matchers::method; + use wiremock::matchers::path; + + // ─────────────────────────── Helpers ─────────────────────────── + + fn default_config(provider: ModelProviderInfo, max_retries: u64) -> Arc { + let codex_home = TempDir::new().unwrap(); + let mut cfg = Config::load_from_base_config_with_overrides( + ConfigToml::default(), + ConfigOverrides::default(), + codex_home.path().to_path_buf(), + ) + .unwrap(); + cfg.model_provider = provider.clone(); + cfg.model = "gpt-test".into(); + cfg.openai_request_max_retries = max_retries; + Arc::new(cfg) + } + + fn create_test_client(server: &MockServer, max_retries: u64) -> ModelClient { + let provider = ModelProviderInfo { + name: "openai".into(), + base_url: format!("{}/v1", server.uri()), + env_key: Some("PATH".into()), + env_key_instructions: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: None, + env_http_headers: None, + }; + let config = default_config(provider.clone(), max_retries); + ModelClient::new( + config, + provider, + ReasoningEffortConfig::None, + ReasoningSummaryConfig::None, + ) + } - // ──────────────────────────── - // Helpers - // ──────────────────────────── + fn sse_completed(id: &str) -> String { + format!( + "event: response.completed\n\ + data: {{\"type\":\"response.completed\",\"response\":{{\"id\":\"{id}\",\"output\":[]}}}}\n\n\n" + ) + } /// Runs the SSE parser on pre-chunked byte slices and returns every event /// (including any final `Err` from a stream-closure check). 
@@ -453,9 +513,172 @@ mod tests { out } - // ──────────────────────────── - // Tests from `implement-test-for-responses-api-sse-parser` - // ──────────────────────────── + // ───────────── Retry / back-off behaviour tests ───────────── + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn retries_once_on_server_error() { + if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + return; + } + let server = MockServer::start().await; + + struct SeqResponder; + impl Respond for SeqResponder { + fn respond(&self, _req: &Request) -> ResponseTemplate { + use std::sync::atomic::AtomicUsize; + use std::sync::atomic::Ordering; + static CALLS: AtomicUsize = AtomicUsize::new(0); + let n = CALLS.fetch_add(1, Ordering::SeqCst); + if n == 0 { + ResponseTemplate::new(500) + } else { + ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(sse_completed("ok"), "text/event-stream") + } + } + } + + Mock::given(method("POST")) + .and(path("/v1/responses")) + .respond_with(SeqResponder) + .expect(2) + .mount(&server) + .await; + + let client = create_test_client(&server, 1); + let prompt = Prompt::default(); + let mut stream = client.stream(&prompt).await.unwrap(); + while let Some(ev) = stream.next().await { + if matches!(ev.unwrap(), ResponseEvent::Completed { .. }) { + break; + } + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn retry_after_header_delay() { + if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + return; + } + let server = MockServer::start().await; + let times = Arc::new(Mutex::new(Vec::new())); + + struct SeqResponder { + times: Arc>>, + } + impl Respond for SeqResponder { + fn respond(&self, _req: &Request) -> ResponseTemplate { + let mut t = self.times.lock().unwrap(); + t.push(Instant::now()); + if t.len() == 1 { + ResponseTemplate::new(429).insert_header("retry-after", "1") + } else { + ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(sse_completed("ok"), "text/event-stream") + } + } + } + + Mock::given(method("POST")) + .and(path("/v1/responses")) + .respond_with(SeqResponder { + times: times.clone(), + }) + .expect(2) + .mount(&server) + .await; + + let client = create_test_client(&server, 1); + let prompt = Prompt::default(); + let mut stream = client.stream(&prompt).await.unwrap(); + while let Some(ev) = stream.next().await { + if matches!(ev.unwrap(), ResponseEvent::Completed { .. 
}) { + break; + } + } + let times = times.lock().unwrap(); + assert_eq!(times.len(), 2); + assert!(times[1] - times[0] >= Duration::from_secs(1)); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn retry_backoff_no_header() { + if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + return; + } + let server = MockServer::start().await; + let times = Arc::new(Mutex::new(Vec::new())); + + struct SeqResponder { + times: Arc>>, + } + impl Respond for SeqResponder { + fn respond(&self, _req: &Request) -> ResponseTemplate { + let mut t = self.times.lock().unwrap(); + t.push(Instant::now()); + if t.len() == 1 { + ResponseTemplate::new(429) + } else { + ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(sse_completed("ok"), "text/event-stream") + } + } + } + + Mock::given(method("POST")) + .and(path("/v1/responses")) + .respond_with(SeqResponder { + times: times.clone(), + }) + .expect(2) + .mount(&server) + .await; + + let client = create_test_client(&server, 1); + let prompt = Prompt::default(); + let mut stream = client.stream(&prompt).await.unwrap(); + while let Some(ev) = stream.next().await { + if matches!(ev.unwrap(), ResponseEvent::Completed { .. }) { + break; + } + } + let times = times.lock().unwrap(); + assert_eq!(times.len(), 2); + assert!(times[1] - times[0] >= Duration::from_millis(100)); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn permanent_error_bubbles_body() { + if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + return; + } + let server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/v1/responses")) + .respond_with(ResponseTemplate::new(400).set_body_string("bad")) + .expect(1) + .mount(&server) + .await; + + let client = create_test_client(&server, 0); + let prompt = Prompt::default(); + match client.stream(&prompt).await { + Ok(_) => panic!("expected error"), + Err(CodexErr::UnexpectedStatus(code, body)) => { + assert_eq!(code, StatusCode::BAD_REQUEST); + assert_eq!(body, "bad"); + } + Err(other) => panic!("unexpected error: {other:?}"), + } + } + + // ─────────────────────────── + // SSE-parser tests + // ─────────────────────────── #[tokio::test] async fn parses_items_and_completed() { @@ -493,17 +716,17 @@ mod tests { assert_eq!(events.len(), 3); - matches!( + assert!(matches!( &events[0], Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. })) if role == "assistant" - ); + )); - matches!( + assert!(matches!( &events[1], Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. })) if role == "assistant" - ); + )); match &events[2] { Ok(ResponseEvent::Completed { @@ -535,7 +758,7 @@ mod tests { assert_eq!(events.len(), 2); - matches!(events[0], Ok(ResponseEvent::OutputItemDone(_))); + assert!(matches!(events[0], Ok(ResponseEvent::OutputItemDone(_)))); match &events[1] { Err(CodexErr::Stream(msg)) => { @@ -545,12 +768,10 @@ mod tests { } } - // ──────────────────────────── - // Table-driven test from `main` - // ──────────────────────────── + // ─────────────────────────── + // Table-driven event-kind test + // ─────────────────────────── - /// Verifies that the adapter produces the right `ResponseEvent` for a - /// variety of incoming `type` values. 
#[tokio::test] async fn table_driven_event_kinds() { struct TestCase { diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index d67e692fc8..51e2c15403 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -10,6 +10,7 @@ use crate::config_types::ShellEnvironmentPolicyToml; use crate::config_types::Tui; use crate::config_types::UriBasedFileOpener; use crate::flags::OPENAI_DEFAULT_MODEL; +use crate::flags::OPENAI_REQUEST_MAX_RETRIES; use crate::model_provider_info::ModelProviderInfo; use crate::model_provider_info::built_in_model_providers; use crate::openai_model_info::get_model_info; @@ -137,6 +138,9 @@ pub struct Config { /// Base URL for requests to ChatGPT (as opposed to the OpenAI API). pub chatgpt_base_url: String, + + /// Max number of retries for a request to the model. + pub openai_request_max_retries: u64, } impl Config { @@ -321,6 +325,9 @@ pub struct ConfigToml { /// Base URL for requests to ChatGPT (as opposed to the OpenAI API). pub chatgpt_base_url: Option, + + /// Max number of retries for a request to the model. + pub openai_request_max_retries: Option, } impl ConfigToml { @@ -353,6 +360,7 @@ pub struct ConfigOverrides { pub model_provider: Option, pub config_profile: Option, pub codex_linux_sandbox_exe: Option, + pub openai_request_max_retries: Option, } impl Config { @@ -374,6 +382,7 @@ impl Config { model_provider, config_profile: config_profile_key, codex_linux_sandbox_exe, + openai_request_max_retries, } = overrides; let config_profile = match config_profile_key.as_ref().or(cfg.profile.as_ref()) { @@ -448,6 +457,12 @@ impl Config { .as_ref() .map(|info| info.max_output_tokens) }); + + // Resolve the max-retry setting (CLI override > config.toml > env flag default). + let resolved_openai_request_max_retries = openai_request_max_retries + .or(cfg.openai_request_max_retries) + .unwrap_or_else(|| *OPENAI_REQUEST_MAX_RETRIES); + let config = Self { model, model_context_window, @@ -494,6 +509,8 @@ impl Config { .chatgpt_base_url .or(cfg.chatgpt_base_url) .unwrap_or("https://chatgpt.com/backend-api/".to_string()), + + openai_request_max_retries: resolved_openai_request_max_retries, }; Ok(config) } @@ -559,6 +576,7 @@ pub fn log_dir(cfg: &Config) -> std::io::Result { mod tests { #![allow(clippy::expect_used, clippy::unwrap_used)] use crate::config_types::HistoryPersistence; + use crate::flags::OPENAI_REQUEST_MAX_RETRIES; use super::*; use pretty_assertions::assert_eq; @@ -800,6 +818,7 @@ disable_response_storage = true model_reasoning_summary: ReasoningSummary::Detailed, model_supports_reasoning_summaries: false, chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(), + openai_request_max_retries: *OPENAI_REQUEST_MAX_RETRIES, }, o3_profile_config ); @@ -846,6 +865,7 @@ disable_response_storage = true model_reasoning_summary: ReasoningSummary::default(), model_supports_reasoning_summaries: false, chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(), + openai_request_max_retries: *OPENAI_REQUEST_MAX_RETRIES, }; assert_eq!(expected_gpt3_profile_config, gpt3_profile_config); @@ -907,6 +927,7 @@ disable_response_storage = true model_reasoning_summary: ReasoningSummary::default(), model_supports_reasoning_summaries: false, chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(), + openai_request_max_retries: *OPENAI_REQUEST_MAX_RETRIES, }; assert_eq!(expected_zdr_profile_config, zdr_profile_config); diff --git a/codex-rs/core/src/openai_tools.rs b/codex-rs/core/src/openai_tools.rs index ef12a629b6..7e3f86fda8 
100644 --- a/codex-rs/core/src/openai_tools.rs +++ b/codex-rs/core/src/openai_tools.rs @@ -155,3 +155,89 @@ fn mcp_tool_to_openai_tool( "type": "function", }) } +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used)] + use super::*; + use crate::client_common::Prompt; + use mcp_types::Tool; + use mcp_types::ToolInputSchema; + + fn dummy_tool() -> (String, Tool) { + ( + "srv.dummy".to_string(), + Tool { + annotations: None, + description: Some("dummy".into()), + input_schema: ToolInputSchema { + properties: None, + required: None, + r#type: "object".to_string(), + }, + name: "dummy".into(), + }, + ) + } + + /// Ensure that the default `shell` tool plus any prompt-supplied extra tool are encoded + /// correctly for the Responses API. We compare against a golden JSON value rather than + /// asserting individual fields so that future refactors will intentionally update the test. + #[test] + fn responses_includes_default_and_extra() { + let mut prompt = Prompt::default(); + let (name, tool) = dummy_tool(); + prompt.extra_tools.insert(name.clone(), tool); + + let tools = create_tools_json_for_responses_api(&prompt, "gpt-4").unwrap(); + + // Verify presence & order: builtin `shell` first, then our extra tool. + assert_eq!( + tools[0].get("name"), + Some(&serde_json::Value::String("shell".into())) + ); + + let dummy = tools + .iter() + .find(|t| t.get("name") == Some(&serde_json::Value::String(name.clone()))) + .unwrap_or_else(|| panic!("dummy tool not found in tools list")); + + // The dummy tool should match what `mcp_tool_to_openai_tool` produces. + let expected_dummy = + mcp_tool_to_openai_tool(name, prompt.extra_tools.remove("srv.dummy").unwrap()); + assert_eq!(dummy, &expected_dummy); + } + + /// When the model name starts with `codex-`, the built-in shell tool should be encoded + /// as `local_shell` rather than `shell`. Verify that the first tool in the JSON list has + /// the adjusted type in that scenario. + #[test] + fn responses_codex_model_uses_local_shell() { + let mut prompt = Prompt::default(); + let (name, tool) = dummy_tool(); + prompt.extra_tools.insert(name, tool); + + let tools = create_tools_json_for_responses_api(&prompt, "codex-model").unwrap(); + assert_eq!(tools[0]["type"], "local_shell"); + } + + /// Chat-Completions API expects the V2 tool schema (`{"type":"function","function":{..}}`). + /// Confirm that every entry is shaped accordingly and the wrapper does not leak the internal + /// `type` field inside the nested `function` object. 
+ #[test] + fn chat_completions_tool_format() { + let mut prompt = Prompt::default(); + let (name, tool) = dummy_tool(); + prompt.extra_tools.insert(name.clone(), tool); + + let tools = create_tools_json_for_chat_completions_api(&prompt, "gpt-4").unwrap(); + assert_eq!(tools.len(), 2); + for t in tools { + assert_eq!( + t.get("type"), + Some(&serde_json::Value::String("function".into())) + ); + let inner = t.get("function").and_then(|v| v.as_object()).unwrap(); + assert!(!inner.contains_key("type")); + } + } +} diff --git a/codex-rs/core/tests/cli_stream.rs b/codex-rs/core/tests/cli_stream.rs index df3fedfd48..105b97817c 100644 --- a/codex-rs/core/tests/cli_stream.rs +++ b/codex-rs/core/tests/cli_stream.rs @@ -66,13 +66,10 @@ async fn chat_mode_stream_cli() { .env("OPENAI_BASE_URL", format!("{}/v1", server.uri())); let output = cmd.output().unwrap(); - println!("Status: {}", output.status); - println!("Stdout:\n{}", String::from_utf8_lossy(&output.stdout)); - println!("Stderr:\n{}", String::from_utf8_lossy(&output.stderr)); assert!(output.status.success()); let stdout = String::from_utf8_lossy(&output.stdout); - assert!(stdout.contains("hi")); - assert_eq!(stdout.matches("hi").count(), 1); + let hi_lines = stdout.lines().filter(|line| line.trim() == "hi").count(); + assert_eq!(hi_lines, 1, "Expected exactly one line with 'hi'"); server.verify().await; } diff --git a/codex-rs/core/tests/live_agent.rs b/codex-rs/core/tests/live_agent.rs index c21f9d0032..25a7542b0d 100644 --- a/codex-rs/core/tests/live_agent.rs +++ b/codex-rs/core/tests/live_agent.rs @@ -55,12 +55,13 @@ async fn spawn_codex() -> Result { // beginning of the test, before we spawn any background tasks that could // observe the environment. unsafe { - std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "2"); std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "2"); } let codex_home = TempDir::new().unwrap(); - let config = load_default_config_for_test(&codex_home); + let mut config = load_default_config_for_test(&codex_home); + // Live tests keep retries low to avoid slow backoffs on flaky networks. + config.openai_request_max_retries = 2; let (agent, _init_id) = Codex::spawn(config, std::sync::Arc::new(Notify::new())).await?; Ok(agent) @@ -79,7 +80,7 @@ async fn live_streaming_and_prev_id_reset() { let codex = spawn_codex().await.unwrap(); - // ---------- Task 1 ---------- + // ---------- Task 1 ---------- codex .submit(Op::UserInput { items: vec![InputItem::Text { @@ -113,7 +114,7 @@ async fn live_streaming_and_prev_id_reset() { "Agent did not stream any AgentMessage before TaskComplete" ); - // ---------- Task 2 (same session) ---------- + // ---------- Task 2 (same session) ---------- codex .submit(Op::UserInput { items: vec![InputItem::Text { diff --git a/codex-rs/core/tests/previous_response_id.rs b/codex-rs/core/tests/previous_response_id.rs index e64271a0ff..25e7f8fc5e 100644 --- a/codex-rs/core/tests/previous_response_id.rs +++ b/codex-rs/core/tests/previous_response_id.rs @@ -91,8 +91,8 @@ async fn keeps_previous_response_id_between_tasks() { // Environment // Update environment – `set_var` is `unsafe` starting with the 2024 // edition so we group the calls into a single `unsafe { … }` block. + // NOTE: per-request retry count is now configured directly on the Config. 
unsafe { - std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0"); std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "0"); } let model_provider = ModelProviderInfo { @@ -113,6 +113,8 @@ async fn keeps_previous_response_id_between_tasks() { let codex_home = TempDir::new().unwrap(); let mut config = load_default_config_for_test(&codex_home); config.model_provider = model_provider; + // No per-request retries so each new user input triggers exactly one HTTP request. + config.openai_request_max_retries = 0; let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new()); let (codex, _init_id) = Codex::spawn(config, ctrl_c.clone()).await.unwrap(); diff --git a/codex-rs/core/tests/stream_no_completed.rs b/codex-rs/core/tests/stream_no_completed.rs index da2736aa77..1d2ee4b08a 100644 --- a/codex-rs/core/tests/stream_no_completed.rs +++ b/codex-rs/core/tests/stream_no_completed.rs @@ -74,12 +74,11 @@ async fn retries_on_early_close() { // // As of Rust 2024 `std::env::set_var` has been made `unsafe` because // mutating the process environment is inherently racy when other threads - // are running. We therefore have to wrap every call in an explicit - // `unsafe` block. These are limited to the test-setup section so the - // scope is very small and clearly delineated. + // are running. We used to tweak the per-request retry counts via the + // `OPENAI_REQUEST_MAX_RETRIES` env var but that caused data races in + // multi-threaded tests. Configure the value directly on the Config instead. unsafe { - std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0"); std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "1"); std::env::set_var("OPENAI_STREAM_IDLE_TIMEOUT_MS", "2000"); } @@ -102,6 +101,8 @@ async fn retries_on_early_close() { let codex_home = TempDir::new().unwrap(); let mut config = load_default_config_for_test(&codex_home); config.model_provider = model_provider; + // Disable per-request retries (we want to exercise stream-level retries). + config.openai_request_max_retries = 0; let (codex, _init_id) = Codex::spawn(config, ctrl_c).await.unwrap(); codex diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index 44dddd4d0f..8869728487 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -104,6 +104,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)), model_provider: None, codex_linux_sandbox_exe, + openai_request_max_retries: None, }; // Parse `-c` overrides. let cli_kv_overrides = match config_overrides.parse_overrides() { diff --git a/codex-rs/mcp-server/src/codex_tool_config.rs b/codex-rs/mcp-server/src/codex_tool_config.rs index 8555524942..44a2c8970f 100644 --- a/codex-rs/mcp-server/src/codex_tool_config.rs +++ b/codex-rs/mcp-server/src/codex_tool_config.rs @@ -142,6 +142,7 @@ impl CodexToolCallParam { sandbox_mode: sandbox.map(Into::into), model_provider: None, codex_linux_sandbox_exe, + openai_request_max_retries: None, }; let cli_overrides = cli_overrides diff --git a/codex-rs/tui/src/lib.rs b/codex-rs/tui/src/lib.rs index 4ca305b35e..ff3f2481ec 100644 --- a/codex-rs/tui/src/lib.rs +++ b/codex-rs/tui/src/lib.rs @@ -75,6 +75,7 @@ pub fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> std::io:: model_provider: None, config_profile: cli.config_profile.clone(), codex_linux_sandbox_exe, + openai_request_max_retries: None, }; // Parse `-c` overrides from the CLI. 
let cli_kv_overrides = match cli.config_overrides.parse_overrides() { ``` ## Review Comments ### codex-rs/core/src/chat_completions.rs - Created: 2025-07-12 18:44:53 UTC | Link: https://github.com/openai/codex/pull/1547#discussion_r2202867728 ```diff @@ -462,3 +462,106 @@ pub(crate) trait AggregateStreamExt: Stream> + Size } impl AggregateStreamExt for T where T: Stream> + Sized {} +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used)] + use super::*; + use crate::WireApi; + use crate::client_common::Prompt; + use crate::config::{Config, ConfigOverrides, ConfigToml}; + use crate::models::{ContentItem, FunctionCallOutputPayload, ResponseItem}; + use pretty_assertions::assert_eq; + use std::sync::{Arc, Mutex}; + use tempfile::TempDir; + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate}; + + struct CaptureResponder { + body: Arc>>, + } + + impl Respond for CaptureResponder { + fn respond(&self, req: &Request) -> ResponseTemplate { + let v: serde_json::Value = serde_json::from_slice(&req.body).unwrap(); + *self.body.lock().unwrap() = Some(v); + ResponseTemplate::new(200).insert_header("content-type", "text/event-stream") + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn assembles_messages_correctly() { + let server = MockServer::start().await; ``` > Does this also need to check `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR`? - Created: 2025-07-12 18:46:04 UTC | Link: https://github.com/openai/codex/pull/1547#discussion_r2202867909 ```diff @@ -462,3 +462,106 @@ pub(crate) trait AggregateStreamExt: Stream> + Size } impl AggregateStreamExt for T where T: Stream> + Sized {} +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used)] + use super::*; + use crate::WireApi; + use crate::client_common::Prompt; + use crate::config::{Config, ConfigOverrides, ConfigToml}; + use crate::models::{ContentItem, FunctionCallOutputPayload, ResponseItem}; + use pretty_assertions::assert_eq; + use std::sync::{Arc, Mutex}; + use tempfile::TempDir; + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate}; + + struct CaptureResponder { + body: Arc>>, + } + + impl Respond for CaptureResponder { + fn respond(&self, req: &Request) -> ResponseTemplate { + let v: serde_json::Value = serde_json::from_slice(&req.body).unwrap(); + *self.body.lock().unwrap() = Some(v); + ResponseTemplate::new(200).insert_header("content-type", "text/event-stream") + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn assembles_messages_correctly() { + let server = MockServer::start().await; + let capture = Arc::new(Mutex::new(None)); + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(CaptureResponder { + body: capture.clone(), + }) + .mount(&server) + .await; + + let provider = ModelProviderInfo { + name: "test".into(), + base_url: format!("{}/v1", server.uri()), + env_key: None, + env_key_instructions: None, + wire_api: WireApi::Chat, + query_params: None, + http_headers: None, + env_http_headers: None, + }; + + let codex_home = TempDir::new().unwrap(); + let mut config = Config::load_from_base_config_with_overrides( + ConfigToml::default(), + ConfigOverrides::default(), + codex_home.path().to_path_buf(), + ) + .unwrap(); + config.model_provider = provider.clone(); + config.model = "gpt-4".into(); + + let client = reqwest::Client::new(); + + let prompt = Prompt { + input: vec![ + ResponseItem::Message { + role: "user".into(), + 
content: vec![ContentItem::InputText { text: "hi".into() }], + }, + ResponseItem::Message { + role: "assistant".into(), + content: vec![ContentItem::OutputText { text: "ok".into() }], + }, + ResponseItem::FunctionCall { + name: "foo".into(), + arguments: "{}".into(), + call_id: "c1".into(), + }, + ResponseItem::FunctionCallOutput { + call_id: "c1".into(), + output: FunctionCallOutputPayload { + content: "out".into(), + success: Some(true), + }, + }, + ], + ..Default::default() + }; + + let _ = stream_chat_completions(&prompt, &config.model, &client, &provider) + .await + .unwrap(); + + let body = capture.lock().unwrap().take().unwrap(); + let messages = body.get("messages").unwrap().as_array().unwrap(); + assert_eq!(messages[1]["role"], "user"); ``` > Can we just do one `assert_eq!` on `messages` in its entirety? Or maybe `&messages[1..]`? - Created: 2025-07-12 18:49:48 UTC | Link: https://github.com/openai/codex/pull/1547#discussion_r2202868898 ```diff @@ -462,3 +462,106 @@ pub(crate) trait AggregateStreamExt: Stream> + Size } impl AggregateStreamExt for T where T: Stream> + Sized {} +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used)] + use super::*; + use crate::WireApi; + use crate::client_common::Prompt; + use crate::config::{Config, ConfigOverrides, ConfigToml}; + use crate::models::{ContentItem, FunctionCallOutputPayload, ResponseItem}; + use pretty_assertions::assert_eq; + use std::sync::{Arc, Mutex}; + use tempfile::TempDir; + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate}; + + struct CaptureResponder { + body: Arc>>, + } + + impl Respond for CaptureResponder { + fn respond(&self, req: &Request) -> ResponseTemplate { + let v: serde_json::Value = serde_json::from_slice(&req.body).unwrap(); + *self.body.lock().unwrap() = Some(v); + ResponseTemplate::new(200).insert_header("content-type", "text/event-stream") + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] ``` > Please add a docstring explaining what is being tested. ### codex-rs/core/src/client.rs - Created: 2025-07-12 18:51:04 UTC | Link: https://github.com/openai/codex/pull/1547#discussion_r2202869188 ```diff @@ -391,3 +391,269 @@ async fn stream_from_fixture(path: impl AsRef) -> Result { tokio::spawn(process_sse(stream, tx_event)); Ok(ResponseStream { rx_event }) } +#[cfg(test)] ``` > Looks like you need `just fmt`. 
- Created: 2025-07-12 18:52:45 UTC | Link: https://github.com/openai/codex/pull/1547#discussion_r2202869627 ```diff @@ -391,3 +391,269 @@ async fn stream_from_fixture(path: impl AsRef) -> Result { tokio::spawn(process_sse(stream, tx_event)); Ok(ResponseStream { rx_event }) } +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used, clippy::print_stdout)] + use super::*; + use crate::client_common::Prompt; + use crate::config::{Config, ConfigOverrides, ConfigToml}; + use futures::StreamExt; + use std::sync::{Arc, Mutex}; + use std::time::{Duration, Instant}; + use tempfile::TempDir; + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate}; + + fn default_config(provider: ModelProviderInfo) -> Arc { + let codex_home = TempDir::new().unwrap(); + let mut cfg = Config::load_from_base_config_with_overrides( + ConfigToml::default(), + ConfigOverrides::default(), + codex_home.path().to_path_buf(), + ) + .unwrap(); + cfg.model_provider = provider.clone(); + cfg.model = "gpt-test".into(); + Arc::new(cfg) ``` > Just FYI, `codex_home` will be deleted when this function exits, but that seems fine in this case. - Created: 2025-07-12 18:55:10 UTC | Link: https://github.com/openai/codex/pull/1547#discussion_r2202870169 ```diff @@ -391,3 +391,269 @@ async fn stream_from_fixture(path: impl AsRef) -> Result { tokio::spawn(process_sse(stream, tx_event)); Ok(ResponseStream { rx_event }) } +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used, clippy::print_stdout)] + use super::*; + use crate::client_common::Prompt; + use crate::config::{Config, ConfigOverrides, ConfigToml}; + use futures::StreamExt; + use std::sync::{Arc, Mutex}; + use std::time::{Duration, Instant}; + use tempfile::TempDir; + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate}; + + fn default_config(provider: ModelProviderInfo) -> Arc { + let codex_home = TempDir::new().unwrap(); + let mut cfg = Config::load_from_base_config_with_overrides( + ConfigToml::default(), + ConfigOverrides::default(), + codex_home.path().to_path_buf(), + ) + .unwrap(); + cfg.model_provider = provider.clone(); + cfg.model = "gpt-test".into(); + Arc::new(cfg) + } + + fn sse_completed(id: &str) -> String { + format!( + "event: response.completed\n\ + data: {{\"type\":\"response.completed\",\"response\":{{\"id\":\"{id}\",\"output\":[]}}}}\n\n\n" + ) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn retries_once_on_server_error() { ``` > I think all of these tests would benefit from docstrings. 
- Created: 2025-07-12 19:02:34 UTC | Link: https://github.com/openai/codex/pull/1547#discussion_r2202872605 ```diff @@ -391,3 +391,269 @@ async fn stream_from_fixture(path: impl AsRef) -> Result { tokio::spawn(process_sse(stream, tx_event)); Ok(ResponseStream { rx_event }) } +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used, clippy::print_stdout)] + use super::*; + use crate::client_common::Prompt; + use crate::config::{Config, ConfigOverrides, ConfigToml}; + use futures::StreamExt; + use std::sync::{Arc, Mutex}; + use std::time::{Duration, Instant}; + use tempfile::TempDir; + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate}; + + fn default_config(provider: ModelProviderInfo) -> Arc { + let codex_home = TempDir::new().unwrap(); + let mut cfg = Config::load_from_base_config_with_overrides( + ConfigToml::default(), + ConfigOverrides::default(), + codex_home.path().to_path_buf(), + ) + .unwrap(); + cfg.model_provider = provider.clone(); + cfg.model = "gpt-test".into(); + Arc::new(cfg) + } + + fn sse_completed(id: &str) -> String { + format!( + "event: response.completed\n\ + data: {{\"type\":\"response.completed\",\"response\":{{\"id\":\"{id}\",\"output\":[]}}}}\n\n\n" + ) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn retries_once_on_server_error() { + if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + println!("Skipping test due to sandbox network restriction"); + return; + } + let server = MockServer::start().await; + struct SeqResponder; + impl Respond for SeqResponder { + fn respond(&self, _req: &Request) -> ResponseTemplate { + use std::sync::atomic::{AtomicUsize, Ordering}; + static CALLS: AtomicUsize = AtomicUsize::new(0); + let n = CALLS.fetch_add(1, Ordering::SeqCst); + if n == 0 { + ResponseTemplate::new(500) + } else { + ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(sse_completed("ok"), "text/event-stream") + } + } + } + Mock::given(method("POST")) + .and(path("/v1/responses")) + .respond_with(SeqResponder) + .expect(2) + .mount(&server) + .await; + + unsafe { std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "1") }; + + let provider = ModelProviderInfo { + name: "openai".into(), + base_url: format!("{}/v1", server.uri()), + env_key: Some("PATH".into()), + env_key_instructions: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: None, + env_http_headers: None, + }; + + let config = default_config(provider.clone()); + let client = ModelClient::new( + config, + provider, + ReasoningEffortConfig::None, + ReasoningSummaryConfig::None, + ); + let prompt = Prompt::default(); + let mut stream = client.stream(&prompt).await.unwrap(); + while let Some(ev) = stream.next().await { + if matches!(ev.unwrap(), ResponseEvent::Completed { .. 
}) { + break; + } + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn retry_after_header_delay() { + if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + println!("Skipping test due to sandbox network restriction"); + return; + } + let server = MockServer::start().await; + let times = Arc::new(Mutex::new(Vec::new())); + struct SeqResponder { + times: Arc>>, + } + impl Respond for SeqResponder { + fn respond(&self, _req: &Request) -> ResponseTemplate { + let mut t = self.times.lock().unwrap(); + t.push(Instant::now()); + if t.len() == 1 { + ResponseTemplate::new(429).insert_header("retry-after", "1") + } else { + ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(sse_completed("ok"), "text/event-stream") + } + } + } + Mock::given(method("POST")) + .and(path("/v1/responses")) + .respond_with(SeqResponder { + times: times.clone(), + }) + .expect(2) + .mount(&server) + .await; + + unsafe { std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "1") }; + + let provider = ModelProviderInfo { + name: "openai".into(), + base_url: format!("{}/v1", server.uri()), + env_key: Some("PATH".into()), + env_key_instructions: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: None, + env_http_headers: None, + }; + let config = default_config(provider.clone()); + let client = ModelClient::new( + config, + provider, + ReasoningEffortConfig::None, + ReasoningSummaryConfig::None, + ); + let prompt = Prompt::default(); + let mut stream = client.stream(&prompt).await.unwrap(); + while let Some(ev) = stream.next().await { + if matches!(ev.unwrap(), ResponseEvent::Completed { .. }) { + break; + } + } + let times = times.lock().unwrap(); + assert!(times.len() == 2); + assert!(times[1] - times[0] >= Duration::from_secs(1)); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn retry_backoff_no_header() { + if std::env::var(crate::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + println!("Skipping test due to sandbox network restriction"); + return; + } + let server = MockServer::start().await; + let times = Arc::new(Mutex::new(Vec::new())); + struct SeqResponder { + times: Arc>>, + } + impl Respond for SeqResponder { + fn respond(&self, _req: &Request) -> ResponseTemplate { + let mut t = self.times.lock().unwrap(); + t.push(Instant::now()); + if t.len() == 1 { + ResponseTemplate::new(429) + } else { + ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(sse_completed("ok"), "text/event-stream") + } + } + } + Mock::given(method("POST")) + .and(path("/v1/responses")) + .respond_with(SeqResponder { + times: times.clone(), + }) + .expect(2) + .mount(&server) + .await; + + unsafe { std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "1") }; + + let provider = ModelProviderInfo { + name: "openai".into(), + base_url: format!("{}/v1", server.uri()), + env_key: Some("PATH".into()), + env_key_instructions: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: None, + env_http_headers: None, + }; + let config = default_config(provider.clone()); + let client = ModelClient::new( + config, + provider, + ReasoningEffortConfig::None, + ReasoningSummaryConfig::None, + ); ``` > Maybe use a helper function to dedupe common logic in tests? 
### codex-rs/core/src/openai_tools.rs - Created: 2025-07-12 19:03:16 UTC | Link: https://github.com/openai/codex/pull/1547#discussion_r2202872735 ```diff @@ -155,3 +155,71 @@ fn mcp_tool_to_openai_tool( "type": "function", }) } +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used)] + use super::*; + use crate::client_common::Prompt; + use mcp_types::{Tool, ToolInputSchema}; + + fn dummy_tool() -> (String, Tool) { + ( + "srv.dummy".to_string(), + Tool { + annotations: None, + description: Some("dummy".into()), + input_schema: ToolInputSchema { + properties: None, + required: None, + r#type: "object".to_string(), + }, + name: "dummy".into(), + }, + ) + } + + #[test] + fn responses_includes_default_and_extra() { + let mut prompt = Prompt::default(); + let (name, tool) = dummy_tool(); + prompt.extra_tools.insert(name.clone(), tool); + + let tools = create_tools_json_for_responses_api(&prompt, "gpt-4").unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0]["type"], "function"); ``` > Just one `assert_eq!` for all of `tools[0]`? - Created: 2025-07-12 19:05:25 UTC | Link: https://github.com/openai/codex/pull/1547#discussion_r2202873208 ```diff @@ -155,3 +155,71 @@ fn mcp_tool_to_openai_tool( "type": "function", }) } +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used)] + use super::*; + use crate::client_common::Prompt; + use mcp_types::{Tool, ToolInputSchema}; + + fn dummy_tool() -> (String, Tool) { + ( + "srv.dummy".to_string(), + Tool { + annotations: None, + description: Some("dummy".into()), + input_schema: ToolInputSchema { + properties: None, + required: None, + r#type: "object".to_string(), + }, + name: "dummy".into(), + }, + ) + } + + #[test] + fn responses_includes_default_and_extra() { + let mut prompt = Prompt::default(); + let (name, tool) = dummy_tool(); + prompt.extra_tools.insert(name.clone(), tool); + + let tools = create_tools_json_for_responses_api(&prompt, "gpt-4").unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0]["type"], "function"); + assert_eq!(tools[0]["name"], "shell"); + assert!( + tools + .iter() + .any(|t| t.get("name") == Some(&name.clone().into())) ``` > Maybe use `find(|t| t.get("name").as_ref() == Some("srv.dummy")` on `tools.iter()` or something like that and then do an `assert_eq!()` on the value returned from `find()`? - Created: 2025-07-12 19:07:30 UTC | Link: https://github.com/openai/codex/pull/1547#discussion_r2202874282 ```diff @@ -155,3 +155,71 @@ fn mcp_tool_to_openai_tool( "type": "function", }) } +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used)] + use super::*; + use crate::client_common::Prompt; + use mcp_types::{Tool, ToolInputSchema}; + + fn dummy_tool() -> (String, Tool) { + ( + "srv.dummy".to_string(), + Tool { + annotations: None, + description: Some("dummy".into()), + input_schema: ToolInputSchema { + properties: None, + required: None, + r#type: "object".to_string(), + }, + name: "dummy".into(), + }, + ) + } + + #[test] + fn responses_includes_default_and_extra() { + let mut prompt = Prompt::default(); + let (name, tool) = dummy_tool(); + prompt.extra_tools.insert(name.clone(), tool); + + let tools = create_tools_json_for_responses_api(&prompt, "gpt-4").unwrap(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0]["type"], "function"); + assert_eq!(tools[0]["name"], "shell"); + assert!( + tools + .iter() + .any(|t| t.get("name") == Some(&name.clone().into())) + ); + } + + #[test] ``` > For both of these tests, can we just assert the entire string/serde_json::Value that we get back? 
> I realize this means that we will have to update this test if we change the default tools, but I think having a test that verifies _everything_ (and effectively documents what we send on the wire) is worth that maintenance cost.
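Editor's note: a hedged sketch of what that whole-value assertion could look like for the Chat Completions tool list. The nested `function` schema below is a guess at the wire shape, not the crate's actual output, and it assumes the builtin `shell` tool sorts first as in the Responses test; a real golden value would be copied from actual output once and then maintained deliberately:

```rust
#[test]
fn chat_tools_golden_value_sketch() {
    let mut prompt = Prompt::default();
    let (name, tool) = dummy_tool();
    prompt.extra_tools.insert(name, tool);

    let tools = create_tools_json_for_chat_completions_api(&prompt, "gpt-4").unwrap();

    // Golden value for the extra tool only (tools[0] is assumed to be the
    // builtin `shell`). The exact fields here are illustrative.
    let expected_dummy = serde_json::json!({
        "type": "function",
        "function": {
            "name": "srv.dummy",
            "description": "dummy",
            "parameters": {"type": "object"}
        }
    });
    assert_eq!(tools[1], expected_dummy);
}
```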