From 1bac24f8276b5bbea94ac045ae9b583564b62c80 Mon Sep 17 00:00:00 2001 From: jif-oai Date: Mon, 10 Nov 2025 12:01:02 +0000 Subject: [PATCH] V1 --- codex-rs/Cargo.lock | 24 + codex-rs/Cargo.toml | 2 + codex-rs/api-client/Cargo.toml | 29 + codex-rs/api-client/src/api.rs | 13 + codex-rs/api-client/src/auth.rs | 17 + codex-rs/api-client/src/chat.rs | 629 ++++++++ codex-rs/api-client/src/error.rs | 38 + codex-rs/api-client/src/lib.rs | 37 + codex-rs/api-client/src/model_provider.rs | 343 +++++ codex-rs/api-client/src/prompt.rs | 46 + codex-rs/api-client/src/responses.rs | 819 ++++++++++ codex-rs/api-client/src/stream.rs | 83 + codex-rs/core/Cargo.toml | 1 + codex-rs/core/src/chat_completions.rs | 967 ------------ codex-rs/core/src/client.rs | 1692 +++++---------------- codex-rs/core/src/client_common.rs | 212 +-- codex-rs/core/src/codex.rs | 22 +- codex-rs/core/src/codex/compact.rs | 4 +- codex-rs/core/src/config/mod.rs | 4 +- codex-rs/core/src/default_client.rs | 8 + codex-rs/core/src/lib.rs | 12 +- codex-rs/core/src/model_provider_info.rs | 532 ------- 22 files changed, 2476 insertions(+), 3058 deletions(-) create mode 100644 codex-rs/api-client/Cargo.toml create mode 100644 codex-rs/api-client/src/api.rs create mode 100644 codex-rs/api-client/src/auth.rs create mode 100644 codex-rs/api-client/src/chat.rs create mode 100644 codex-rs/api-client/src/error.rs create mode 100644 codex-rs/api-client/src/lib.rs create mode 100644 codex-rs/api-client/src/model_provider.rs create mode 100644 codex-rs/api-client/src/prompt.rs create mode 100644 codex-rs/api-client/src/responses.rs create mode 100644 codex-rs/api-client/src/stream.rs delete mode 100644 codex-rs/core/src/chat_completions.rs delete mode 100644 codex-rs/core/src/model_provider_info.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index a72c885868..ea74bcd3bb 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -829,6 +829,29 @@ dependencies = [ "tracing", ] +[[package]] +name = "codex-api-client" +version = "0.0.0" +dependencies = [ + "async-trait", + "bytes", + "codex-app-server-protocol", + "codex-otel", + "codex-protocol", + "eventsource-stream", + "futures", + "maplit", + "regex-lite", + "reqwest", + "serde", + "serde_json", + "thiserror 2.0.17", + "tokio", + "tokio-util", + "toml", + "tracing", +] + [[package]] name = "codex-app-server" version = "0.0.0" @@ -1062,6 +1085,7 @@ dependencies = [ "base64", "bytes", "chrono", + "codex-api-client", "codex-app-server-protocol", "codex-apply-patch", "codex-async-utils", diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index d732151c85..7a24019e18 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -1,5 +1,6 @@ [workspace] members = [ + "api-client", "backend-client", "ansi-escape", "async-utils", @@ -54,6 +55,7 @@ edition = "2024" # Internal app_test_support = { path = "app-server/tests/common" } codex-ansi-escape = { path = "ansi-escape" } +codex-api-client = { path = "api-client" } codex-app-server = { path = "app-server" } codex-app-server-protocol = { path = "app-server-protocol" } codex-apply-patch = { path = "apply-patch" } diff --git a/codex-rs/api-client/Cargo.toml b/codex-rs/api-client/Cargo.toml new file mode 100644 index 0000000000..d0ba8d5cf2 --- /dev/null +++ b/codex-rs/api-client/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "codex-api-client" +version.workspace = true +edition.workspace = true + +[dependencies] +async-trait = { workspace = true } +bytes = { workspace = true } +codex-app-server-protocol = { workspace = true } +codex-otel = { 
workspace = true }
+codex-protocol = { path = "../protocol" }
+eventsource-stream = { workspace = true }
+futures = { workspace = true, default-features = false, features = ["std"] }
+maplit = "1.0.2"
+regex-lite = { workspace = true }
+reqwest = { workspace = true, features = ["json", "stream"] }
+serde = { workspace = true, features = ["derive"] }
+serde_json = { workspace = true }
+thiserror = { workspace = true }
+tokio = { workspace = true, features = ["sync", "time", "rt", "rt-multi-thread", "macros", "io-util"] }
+tokio-util = { workspace = true }
+tracing = { workspace = true }
+
+[dev-dependencies]
+maplit = "1.0.2"
+toml = { workspace = true }
+
+[lints]
+workspace = true
diff --git a/codex-rs/api-client/src/api.rs b/codex-rs/api-client/src/api.rs
new file mode 100644
index 0000000000..517f15f3ea
--- /dev/null
+++ b/codex-rs/api-client/src/api.rs
@@ -0,0 +1,13 @@
+use async_trait::async_trait;
+
+use crate::error::Result;
+use crate::prompt::Prompt;
+use crate::stream::ResponseStream;
+
+#[async_trait]
+pub trait ApiClient: Sized {
+    type Config;
+
+    async fn new(config: Self::Config) -> Result<Self>;
+    async fn stream(&self, prompt: Prompt) -> Result<ResponseStream>;
+}
diff --git a/codex-rs/api-client/src/auth.rs b/codex-rs/api-client/src/auth.rs
new file mode 100644
index 0000000000..8f35c3a7a6
--- /dev/null
+++ b/codex-rs/api-client/src/auth.rs
@@ -0,0 +1,17 @@
+use async_trait::async_trait;
+use codex_app_server_protocol::AuthMode;
+use serde::Deserialize;
+use serde::Serialize;
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct AuthContext {
+    pub mode: AuthMode,
+    pub bearer_token: Option<String>,
+    pub account_id: Option<String>,
+}
+
+#[async_trait]
+pub trait AuthProvider: Send + Sync {
+    async fn auth_context(&self) -> Option<AuthContext>;
+    async fn refresh_token(&self) -> std::result::Result<Option<String>, String>;
+}
diff --git a/codex-rs/api-client/src/chat.rs b/codex-rs/api-client/src/chat.rs
new file mode 100644
index 0000000000..fe648903aa
--- /dev/null
+++ b/codex-rs/api-client/src/chat.rs
@@ -0,0 +1,629 @@
+use std::time::Duration;
+
+use async_trait::async_trait;
+use bytes::Bytes;
+use codex_otel::otel_event_manager::OtelEventManager;
+use codex_protocol::models::ContentItem;
+use codex_protocol::models::FunctionCallOutputContentItem;
+use codex_protocol::models::ReasoningItemContent;
+use codex_protocol::models::ResponseItem;
+use codex_protocol::protocol::SessionSource;
+use eventsource_stream::Eventsource;
+use futures::Stream;
+use futures::StreamExt;
+use futures::TryStreamExt;
+use serde_json::Value;
+use serde_json::json;
+use tokio::sync::mpsc;
+use tokio::time::timeout;
+use tracing::debug;
+use tracing::trace;
+
+use crate::aggregate::ChatAggregationMode;
+use crate::api::ApiClient;
+use crate::common::apply_subagent_header;
+use crate::common::backoff;
+use crate::error::Error;
+use crate::model_provider::ModelProviderInfo;
+use crate::prompt::Prompt;
+use crate::stream::ResponseEvent;
+use crate::stream::ResponseStream;
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+#[derive(Clone)]
+pub struct ChatCompletionsApiClientConfig {
+    pub http_client: reqwest::Client,
+    pub provider: ModelProviderInfo,
+    pub model: String,
+    pub otel_event_manager: OtelEventManager,
+    pub session_source: SessionSource,
+    pub aggregation_mode: ChatAggregationMode,
+}
+
+#[derive(Clone)]
+pub struct ChatCompletionsApiClient {
+    config: ChatCompletionsApiClientConfig,
+}
+
+#[async_trait]
+impl ApiClient for ChatCompletionsApiClient {
+    type Config = ChatCompletionsApiClientConfig;
+
+    async fn new(config: Self::Config) -> 
Result { + Ok(Self { config }) + } + + async fn stream(&self, prompt: Prompt) -> Result { + Self::validate_prompt(&prompt)?; + + let payload = self.build_payload(&prompt)?; + let (tx_event, rx_event) = mpsc::channel::>(1600); + + let mut attempt: i64 = 0; + let max_retries = self.config.provider.request_max_retries(); + + loop { + attempt += 1; + + let req_builder = self + .config + .provider + .create_request_builder(&self.config.http_client, &None) + .await + .map(|builder| apply_subagent_header(builder, Some(&self.config.session_source)))?; + + let res = self + .config + .otel_event_manager + .log_request(attempt as u64, || { + req_builder + .header(reqwest::header::ACCEPT, "text/event-stream") + .json(&payload) + .send() + }) + .await; + + match res { + Ok(resp) if resp.status().is_success() => { + let stream = resp + .bytes_stream() + .map_err(|err| Error::ResponseStreamFailed { + source: err, + request_id: None, + }); + let idle_timeout = self.config.provider.stream_idle_timeout(); + let otel = self.config.otel_event_manager.clone(); + let mode = self.config.aggregation_mode; + + tokio::spawn(process_chat_sse( + stream, + tx_event.clone(), + idle_timeout, + otel, + mode, + )); + + return Ok(ResponseStream { rx_event }); + } + Ok(resp) => { + if attempt >= max_retries { + let status = resp.status(); + let body = resp + .text() + .await + .unwrap_or_else(|_| "".to_string()); + return Err(Error::UnexpectedStatus { status, body }); + } + + let retry_after = resp + .headers() + .get(reqwest::header::RETRY_AFTER) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()) + .map(|secs| Duration::from_secs(if secs < 0 { 0 } else { secs as u64 })); + tokio::time::sleep(retry_after.unwrap_or_else(|| backoff(attempt))).await; + } + Err(error) => { + if attempt >= max_retries { + return Err(Error::Http(error)); + } + tokio::time::sleep(backoff(attempt)).await; + } + } + } + } +} + +impl ChatCompletionsApiClient { + fn validate_prompt(prompt: &Prompt) -> Result<()> { + if prompt.output_schema.is_some() { + return Err(Error::UnsupportedOperation( + "output_schema is not supported for Chat Completions API".to_string(), + )); + } + Ok(()) + } + + fn build_payload(&self, prompt: &Prompt) -> Result { + let mut messages = Vec::::new(); + messages.push(json!({ "role": "system", "content": prompt.instructions })); + + let mut reasoning_by_anchor_index: std::collections::HashMap = + std::collections::HashMap::new(); + + let mut last_emitted_role: Option<&str> = None; + for item in &prompt.input { + match item { + ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()), + ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => { + last_emitted_role = Some("assistant"); + } + ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"), + ResponseItem::Reasoning { .. } + | ResponseItem::Other + | ResponseItem::CustomToolCall { .. } + | ResponseItem::CustomToolCallOutput { .. } + | ResponseItem::WebSearchCall { .. } + | ResponseItem::GhostSnapshot { .. } => {} + } + } + + let mut last_user_index: Option = None; + for (idx, item) in prompt.input.iter().enumerate() { + if let ResponseItem::Message { role, .. } = item + && role == "user" + { + last_user_index = Some(idx); + } + } + + if !matches!(last_emitted_role, Some("user")) { + for (idx, item) in prompt.input.iter().enumerate() { + if let Some(u_idx) = last_user_index + && idx <= u_idx + { + continue; + } + + if let ResponseItem::Reasoning { + content: Some(items), + .. 
+ } = item + { + let mut text = String::new(); + for entry in items { + match entry { + ReasoningItemContent::ReasoningText { text: segment } + | ReasoningItemContent::Text { text: segment } => { + text.push_str(segment); + } + } + } + if text.trim().is_empty() { + continue; + } + + let mut attached = false; + if idx > 0 + && let ResponseItem::Message { role, .. } = &prompt.input[idx - 1] + && role == "assistant" + { + reasoning_by_anchor_index + .entry(idx - 1) + .and_modify(|val| val.push_str(&text)) + .or_insert(text.clone()); + attached = true; + } + + if !attached && idx + 1 < prompt.input.len() { + match &prompt.input[idx + 1] { + ResponseItem::FunctionCall { .. } + | ResponseItem::LocalShellCall { .. } => { + reasoning_by_anchor_index + .entry(idx + 1) + .and_modify(|val| val.push_str(&text)) + .or_insert(text.clone()); + } + ResponseItem::Message { role, .. } if role == "assistant" => { + reasoning_by_anchor_index + .entry(idx + 1) + .and_modify(|val| val.push_str(&text)) + .or_insert(text.clone()); + } + _ => {} + } + } + } + } + } + + let mut last_assistant_text: Option = None; + + for (idx, item) in prompt.input.iter().enumerate() { + match item { + ResponseItem::Message { role, content, .. } => { + let mut text = String::new(); + let mut items: Vec = Vec::new(); + let mut saw_image = false; + + for c in content { + match c { + ContentItem::InputText { text: t } + | ContentItem::OutputText { text: t } => { + text.push_str(t); + items.push(json!({"type":"text","text": t})); + } + ContentItem::InputImage { image_url } => { + saw_image = true; + items.push( + json!({"type":"image_url","image_url": {"url": image_url}}), + ); + } + } + } + + if role == "assistant" { + if let Some(prev) = &last_assistant_text + && prev == &text + { + continue; + } + last_assistant_text = Some(text.clone()); + } + + let content_value = if role == "assistant" { + json!(text) + } else if saw_image { + json!(items) + } else { + json!(text) + }; + + let mut message = json!({ + "role": role, + "content": content_value, + }); + + if let Some(reasoning) = reasoning_by_anchor_index.get(&idx) + && let Some(obj) = message.as_object_mut() + { + obj.insert("reasoning".to_string(), json!({"text": reasoning})); + } + + messages.push(message); + } + ResponseItem::FunctionCall { + name, + arguments, + call_id, + .. + } => { + messages.push(json!({ + "role": "assistant", + "tool_calls": [{ + "id": call_id, + "type": "function", + "function": { + "name": name, + "arguments": arguments, + }, + }], + })); + } + ResponseItem::FunctionCallOutput { call_id, output } => { + let content_value = if let Some(items) = &output.content_items { + let mapped: Vec = items + .iter() + .map(|item| match item { + FunctionCallOutputContentItem::InputText { text } => { + json!({"type":"text","text": text}) + } + FunctionCallOutputContentItem::InputImage { image_url } => { + json!({"type":"image_url","image_url": {"url": image_url}}) + } + }) + .collect(); + json!(mapped) + } else { + json!(output.content) + }; + + messages.push(json!({ + "role": "tool", + "tool_call_id": call_id, + "content": content_value, + })); + } + ResponseItem::LocalShellCall { + id, + call_id, + action, + .. 
+ } => { + let tool_id = call_id + .clone() + .filter(|value| !value.is_empty()) + .or_else(|| id.clone()) + .unwrap_or_default(); + messages.push(json!({ + "role": "assistant", + "tool_calls": [{ + "id": tool_id, + "type": "function", + "function": { + "name": "shell", + "arguments": serde_json::to_string(action).unwrap_or_default(), + }, + }], + })); + } + ResponseItem::CustomToolCall { + call_id, + name, + input, + .. + } => { + messages.push(json!({ + "role": "assistant", + "tool_calls": [{ + "id": call_id.clone(), + "type": "function", + "function": { + "name": name, + "arguments": input, + }, + }], + })); + } + ResponseItem::CustomToolCallOutput { call_id, output } => { + messages.push(json!({ + "role": "tool", + "tool_call_id": call_id, + "content": output, + })); + } + ResponseItem::WebSearchCall { .. } + | ResponseItem::Reasoning { .. } + | ResponseItem::Other + | ResponseItem::GhostSnapshot { .. } => {} + } + } + + let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?; + let payload = json!({ + "model": self.config.model, + "messages": messages, + "stream": true, + "tools": tools_json, + }); + + trace!("chat completions payload: {}", payload); + Ok(payload) + } +} + +/// Lightweight SSE processor for Chat Completions streaming, mapped to ResponseEvent. +async fn process_chat_sse( + stream: S, + tx_event: mpsc::Sender>, + idle_timeout: Duration, + _otel_event_manager: OtelEventManager, + aggregation_mode: ChatAggregationMode, +) where + S: Stream> + Unpin, +{ + let mut stream = stream.eventsource(); + + #[derive(Default)] + struct FunctionCallState { + name: Option, + arguments: String, + call_id: Option, + active: bool, + } + + let mut fn_call_state = FunctionCallState::default(); + let mut assistant_item: Option = None; + let mut reasoning_item: Option = None; + + loop { + let response = timeout(idle_timeout, stream.next()).await; + let sse = match response { + Ok(Some(Ok(ev))) => ev, + Ok(Some(Err(err))) => { + let _ = tx_event + .send(Err(Error::Stream(err.to_string(), None))) + .await; + return; + } + Ok(None) => { + let _ = tx_event + .send(Ok(ResponseEvent::Completed { + response_id: String::new(), + token_usage: None, + })) + .await; + return; + } + Err(_) => { + let _ = tx_event + .send(Err(Error::Stream( + "idle timeout waiting for SSE".into(), + None, + ))) + .await; + return; + } + }; + + if sse.data.trim() == "[DONE]" { + if let Some(item) = assistant_item { + let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; + } + if let Some(item) = reasoning_item { + let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; + } + let _ = tx_event + .send(Ok(ResponseEvent::Completed { + response_id: String::new(), + token_usage: None, + })) + .await; + return; + } + + let Ok(parsed_chunk) = serde_json::from_str::(&sse.data) else { + debug!("failed to parse SSE data into JSON: {}", sse.data); + continue; + }; + + let choices = parsed_chunk + .get("choices") + .and_then(|choices| choices.as_array()) + .cloned() + .unwrap_or_default(); + + for choice in choices { + if let Some(delta) = choice.get("delta") { + if let Some(content) = delta.get("content").and_then(|c| c.as_array()) { + for piece in content { + if let Some(text) = piece.get("text").and_then(|t| t.as_str()) { + append_assistant_text(&tx_event, &mut assistant_item, text.to_string()) + .await; + if matches!(aggregation_mode, ChatAggregationMode::Streaming) { + let _ = tx_event + .send(Ok(ResponseEvent::OutputTextDelta(text.to_string()))) + .await; + } + } + } + } + + if 
let Some(tool_calls) = delta.get("tool_calls").and_then(|c| c.as_array()) { + for call in tool_calls { + if let Some(id_val) = call.get("id").and_then(|id| id.as_str()) { + fn_call_state.call_id = Some(id_val.to_string()); + } + if let Some(function) = call.get("function") { + if let Some(name) = function.get("name").and_then(|n| n.as_str()) { + fn_call_state.name = Some(name.to_string()); + fn_call_state.active = true; + } + if let Some(args) = function.get("arguments").and_then(|a| a.as_str()) { + fn_call_state.arguments.push_str(args); + } + } + } + } + + if let Some(reasoning) = delta.get("reasoning_content").and_then(|c| c.as_array()) { + for entry in reasoning { + if let Some(text) = entry.get("text").and_then(|t| t.as_str()) { + append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string()) + .await; + } + } + } + } + + if let Some(finish_reason) = choice.get("finish_reason").and_then(|f| f.as_str()) + && finish_reason == "tool_calls" + && fn_call_state.active + { + let function_name = fn_call_state.name.take().unwrap_or_default(); + let call_id = fn_call_state.call_id.take().unwrap_or_default(); + let arguments = fn_call_state.arguments.clone(); + fn_call_state = FunctionCallState::default(); + + let item = ResponseItem::FunctionCall { + id: Some(call_id.clone()), + call_id, + name: function_name, + arguments, + }; + let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; + } + } + } +} + +async fn append_assistant_text( + tx_event: &mpsc::Sender>, + assistant_item: &mut Option, + text: String, +) { + if assistant_item.is_none() { + let item = ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![], + }; + *assistant_item = Some(item.clone()); + let _ = tx_event + .send(Ok(ResponseEvent::OutputItemAdded(item))) + .await; + } + + if let Some(ResponseItem::Message { content, .. }) = assistant_item { + content.push(ContentItem::OutputText { text }); + } +} + +async fn append_reasoning_text( + tx_event: &mpsc::Sender>, + reasoning_item: &mut Option, + text: String, +) { + if reasoning_item.is_none() { + let item = ResponseItem::Reasoning { + id: String::new(), + summary: Vec::new(), + content: Some(vec![]), + encrypted_content: None, + }; + *reasoning_item = Some(item.clone()); + let _ = tx_event + .send(Ok(ResponseEvent::OutputItemAdded(item))) + .await; + } + + if let Some(ResponseItem::Reasoning { + content: Some(content), + .. 
+    }) = reasoning_item
+    {
+        content.push(ReasoningItemContent::ReasoningText { text });
+    }
+}
+
+fn create_tools_json_for_chat_completions_api(
+    tools: &[serde_json::Value],
+) -> Result<Vec<Value>> {
+    let tools_json = tools
+        .iter()
+        .filter_map(|tool| {
+            if tool.get("type") != Some(&serde_json::Value::String("function".to_string())) {
+                return None;
+            }
+
+            let function_value = if let Some(function) = tool.get("function") {
+                function.clone()
+            } else if let Some(map) = tool.as_object() {
+                let mut function = map.clone();
+                function.remove("type");
+                Value::Object(function)
+            } else {
+                return None;
+            };
+
+            Some(json!({
+                "type": "function",
+                "function": function_value,
+            }))
+        })
+        .collect::<Vec<Value>>();
+    Ok(tools_json)
+}
+
+// aggregation types and adapters moved to crate::aggregate
diff --git a/codex-rs/api-client/src/error.rs b/codex-rs/api-client/src/error.rs
new file mode 100644
index 0000000000..b449e7b3e5
--- /dev/null
+++ b/codex-rs/api-client/src/error.rs
@@ -0,0 +1,38 @@
+use reqwest::StatusCode;
+use thiserror::Error;
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+#[derive(Error, Debug)]
+pub enum Error {
+    #[error("{0}")]
+    UnsupportedOperation(String),
+    #[error(transparent)]
+    Http(#[from] reqwest::Error),
+    #[error("response stream failed: {source}")]
+    ResponseStreamFailed {
+        #[source]
+        source: reqwest::Error,
+        request_id: Option<String>,
+    },
+    #[error("stream error: {0}")]
+    Stream(String, Option<std::time::Duration>),
+    #[error("unexpected status {status}: {body}")]
+    UnexpectedStatus { status: StatusCode, body: String },
+    #[error("retry limit reached {status:?} request_id={request_id:?}")]
+    RetryLimit {
+        status: Option<StatusCode>,
+        request_id: Option<String>,
+    },
+    #[error("missing env var {var}: {instructions:?}")]
+    MissingEnvVar {
+        var: String,
+        instructions: Option<String>,
+    },
+    #[error("auth error: {0}")]
+    Auth(String),
+    #[error(transparent)]
+    Json(#[from] serde_json::Error),
+    #[error("{0}")]
+    Other(String),
+}
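Note: to make the new error surface concrete, here is a minimal sketch of how a caller driving the client manually might map `Error` variants to a retry decision. The `retry_delay` helper and the one-second pause are illustrative only; the retry loops inside `chat.rs` and `responses.rs` already do this internally.

```rust
use std::time::Duration;

use codex_api_client::Error;

/// Classify an api-client `Error` into "retry after this delay" or "fatal".
fn retry_delay(err: &Error) -> Option<Duration> {
    match err {
        // The SSE processors surface a server-suggested delay when one was
        // parsed out of the error payload.
        Error::Stream(_, Some(delay)) => Some(*delay),
        // Transport-level failures are worth another attempt after a pause.
        Error::Http(_) | Error::ResponseStreamFailed { .. } => Some(Duration::from_secs(1)),
        // Everything else (bad status, retry limit, auth, env, JSON) is fatal.
        _ => None,
    }
}
```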
diff --git a/codex-rs/api-client/src/lib.rs b/codex-rs/api-client/src/lib.rs
new file mode 100644
index 0000000000..f65ef15467
--- /dev/null
+++ b/codex-rs/api-client/src/lib.rs
@@ -0,0 +1,37 @@
+pub mod aggregate;
+pub mod api;
+pub mod auth;
+pub mod chat;
+mod common;
+pub mod error;
+pub mod model_provider;
+pub mod prompt;
+pub mod responses;
+pub mod stream;
+
+pub use crate::aggregate::AggregateStreamExt;
+pub use crate::aggregate::ChatAggregationMode;
+pub use crate::api::ApiClient;
+pub use crate::auth::AuthContext;
+pub use crate::auth::AuthProvider;
+pub use crate::chat::ChatCompletionsApiClient;
+pub use crate::chat::ChatCompletionsApiClientConfig;
+pub use crate::error::Error;
+pub use crate::error::Result;
+pub use crate::model_provider::BUILT_IN_OSS_MODEL_PROVIDER_ID;
+pub use crate::model_provider::ModelProviderInfo;
+pub use crate::model_provider::WireApi;
+pub use crate::model_provider::built_in_model_providers;
+pub use crate::model_provider::create_oss_provider;
+pub use crate::model_provider::create_oss_provider_with_base_url;
+pub use crate::prompt::Prompt;
+pub use crate::responses::ResponsesApiClient;
+pub use crate::responses::ResponsesApiClientConfig;
+pub use crate::responses::stream_from_fixture;
+pub use crate::stream::EventStream;
+pub use crate::stream::Reasoning;
+pub use crate::stream::ResponseEvent;
+pub use crate::stream::ResponseStream;
+pub use crate::stream::TextControls;
+pub use crate::stream::TextFormat;
+pub use crate::stream::TextFormatType;
diff --git a/codex-rs/api-client/src/model_provider.rs b/codex-rs/api-client/src/model_provider.rs
new file mode 100644
index 0000000000..5b1a9138ba
--- /dev/null
+++ b/codex-rs/api-client/src/model_provider.rs
@@ -0,0 +1,343 @@
+//! Registry of model providers supported by Codex.
+//!
+//! Providers can be defined in two places:
+//!   1. Built-in defaults compiled into the binary so Codex works out-of-the-box.
+//!   2. User-defined entries inside `~/.codex/config.toml` under the `model_providers`
+//!      key. These override or extend the defaults at runtime.
+
+use std::collections::HashMap;
+use std::env::VarError;
+use std::time::Duration;
+
+use codex_app_server_protocol::AuthMode;
+use serde::Deserialize;
+use serde::Serialize;
+
+use crate::auth::AuthContext;
+use crate::error::Error;
+use crate::error::Result;
+
+const DEFAULT_STREAM_IDLE_TIMEOUT_MS: i64 = 300_000;
+const DEFAULT_STREAM_MAX_RETRIES: i64 = 5;
+const DEFAULT_REQUEST_MAX_RETRIES: i64 = 4;
+/// Hard cap for user-configured `stream_max_retries`.
+const MAX_STREAM_MAX_RETRIES: i64 = 100;
+/// Hard cap for user-configured `request_max_retries`.
+const MAX_REQUEST_MAX_RETRIES: i64 = 100;
+const DEFAULT_OLLAMA_PORT: i32 = 11434;
+
+/// Wire protocol that the provider speaks. Most third-party services only
+/// implement the classic OpenAI Chat Completions JSON schema, whereas OpenAI
+/// itself (and a handful of others) additionally expose the more modern
+/// Responses API. The two protocols use different request/response shapes
+/// and cannot be auto-detected at runtime, therefore each provider entry
+/// must declare which one it expects.
+#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum WireApi {
+    /// The Responses API exposed by OpenAI at `/v1/responses`.
+    Responses,
+    /// Regular Chat Completions compatible with `/v1/chat/completions`.
+    #[default]
+    Chat,
+}
+
+/// Serializable representation of a provider definition.
+#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
+pub struct ModelProviderInfo {
+    /// Friendly display name.
+    pub name: String,
+    /// Base URL for the provider's OpenAI-compatible API.
+    pub base_url: Option<String>,
+    /// Environment variable that stores the user's API key for this provider.
+    pub env_key: Option<String>,
+    /// Optional instructions to help the user get a valid value for the
+    /// variable and set it.
+    pub env_key_instructions: Option<String>,
+    /// Value to use with `Authorization: Bearer <token>` header. Use of this
+    /// config is discouraged in favor of `env_key` for security reasons, but
+    /// this may be necessary when using this programmatically.
+    pub experimental_bearer_token: Option<String>,
+    /// Which wire protocol this provider expects.
+    #[serde(default)]
+    pub wire_api: WireApi,
+    /// Optional query parameters to append to the base URL.
+    pub query_params: Option<HashMap<String, String>>,
+    /// Additional HTTP headers to include in requests to this provider where
+    /// the (key, value) pairs are the header name and value.
+    pub http_headers: Option<HashMap<String, String>>,
+    /// Optional HTTP headers to include in requests to this provider where the
+    /// (key, value) pairs are the header name and environment variable whose
+    /// value should be used. If the environment variable is not set, or the
+    /// value is empty, the header will not be included in the request.
+    pub env_http_headers: Option<HashMap<String, String>>,
+    /// Maximum number of times to retry a failed HTTP request to this provider.
+    pub request_max_retries: Option<i64>,
+    /// Number of times to retry reconnecting a dropped streaming response before failing.
+    pub stream_max_retries: Option<i64>,
+    /// Idle timeout (in milliseconds) to wait for activity on a streaming response before treating
+    /// the connection as lost.
+    pub stream_idle_timeout_ms: Option<i64>,
+    /// Does this provider require an OpenAI API Key or ChatGPT login token? If true,
+    /// the user is presented with a login screen on first run, and login preference and token/key
+    /// are stored in auth.json. If false (which is the default), the login screen is skipped,
+    /// and the API key (if needed) comes from the `env_key` environment variable.
+    #[serde(default)]
+    pub requires_openai_auth: bool,
+}
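Note: since `ModelProviderInfo` derives `Deserialize` and the module docs point at `~/.codex/config.toml`, a user-defined `[model_providers.*]` entry maps onto this struct directly. The `toml` dev-dependency above suggests exactly this kind of round-trip test; the provider values below are hypothetical, not part of the patch.

```rust
use codex_api_client::ModelProviderInfo;
use codex_api_client::WireApi;

/// Sketch of deserializing a user-defined provider entry.
#[test]
fn deserializes_a_custom_provider() {
    let toml_src = r#"
        name = "Example"
        base_url = "https://api.example.com/v1"
        env_key = "EXAMPLE_API_KEY"
        wire_api = "chat"
        request_max_retries = 4
        stream_idle_timeout_ms = 300000
    "#;
    let provider: ModelProviderInfo = toml::from_str(toml_src).expect("valid provider");
    assert_eq!(provider.wire_api, WireApi::Chat);
    // Omitted fields fall back to their serde defaults.
    assert!(!provider.requires_openai_auth);
}
```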
+impl ModelProviderInfo {
+    /// Construct a `POST` request builder for the given URL using the provided
+    /// [`reqwest::Client`] applying:
+    ///   - provider-specific headers (static and environment based)
+    ///   - Bearer auth header when an API key is available
+    ///   - Auth token for OAuth
+    ///
+    /// If the provider declares an `env_key` but the variable is missing or empty, this returns an
+    /// error identical to the one produced by [`ModelProviderInfo::api_key`].
+    pub async fn create_request_builder(
+        &self,
+        client: &reqwest::Client,
+        auth: &Option<AuthContext>,
+    ) -> Result<reqwest::RequestBuilder> {
+        let effective_auth = if let Some(secret_key) = &self.experimental_bearer_token {
+            Some(AuthContext {
+                mode: AuthMode::ApiKey,
+                bearer_token: Some(secret_key.clone()),
+                account_id: None,
+            })
+        } else {
+            match self.api_key()? {
+                Some(key) => Some(AuthContext {
+                    mode: AuthMode::ApiKey,
+                    bearer_token: Some(key),
+                    account_id: None,
+                }),
+                None => auth.clone(),
+            }
+        };
+
+        let url = self.get_full_url(effective_auth.as_ref());
+        let mut builder = client.post(url);
+
+        if let Some(context) = effective_auth.as_ref()
+            && let Some(token) = context.bearer_token.as_ref()
+        {
+            builder = builder.bearer_auth(token);
+        }
+
+        Ok(self.apply_http_headers(builder))
+    }
+
+    fn get_query_string(&self) -> String {
+        self.query_params
+            .as_ref()
+            .map_or_else(String::new, |params| {
+                let full_params = params
+                    .iter()
+                    .map(|(k, v)| format!("{k}={v}"))
+                    .collect::<Vec<_>>()
+                    .join("&");
+                format!("?{full_params}")
+            })
+    }
+
+    pub fn get_full_url(&self, auth: Option<&AuthContext>) -> String {
+        let default_base_url = if matches!(
+            auth,
+            Some(AuthContext {
+                mode: AuthMode::ChatGPT,
+                ..
+            })
+        ) {
+            "https://chatgpt.com/backend-api/codex"
+        } else {
+            "https://api.openai.com/v1"
+        };
+        let query_string = self.get_query_string();
+        let base_url = self
+            .base_url
+            .clone()
+            .unwrap_or_else(|| default_base_url.to_string());
+
+        match self.wire_api {
+            WireApi::Responses => format!("{base_url}/responses{query_string}"),
+            WireApi::Chat => format!("{base_url}/chat/completions{query_string}"),
+        }
+    }
+
+    pub fn is_azure_responses_endpoint(&self) -> bool {
+        if self.wire_api != WireApi::Responses {
+            return false;
+        }
+
+        if self.name.eq_ignore_ascii_case("azure") {
+            return true;
+        }
+
+        self.base_url
+            .as_ref()
+            .map(|base| matches_azure_responses_base_url(base))
+            .unwrap_or(false)
+    }
+
+    /// Apply provider-specific HTTP headers (both static and environment-based) onto an existing
+    /// [`reqwest::RequestBuilder`] and return the updated builder.
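Note: a minimal `AuthProvider` implementation (e.g. for tests) shows what `create_request_builder` falls back to when no `experimental_bearer_token` or `env_key` is configured. `StaticKey` is illustrative and not part of the patch; real callers would read auth.json or refresh OAuth tokens here.

```rust
use async_trait::async_trait;
use codex_api_client::AuthContext;
use codex_api_client::AuthProvider;
use codex_app_server_protocol::AuthMode;

/// An `AuthProvider` backed by a fixed API key.
struct StaticKey(String);

#[async_trait]
impl AuthProvider for StaticKey {
    async fn auth_context(&self) -> Option<AuthContext> {
        Some(AuthContext {
            mode: AuthMode::ApiKey,
            bearer_token: Some(self.0.clone()),
            account_id: None,
        })
    }

    async fn refresh_token(&self) -> std::result::Result<Option<String>, String> {
        // A static key cannot be refreshed.
        Ok(None)
    }
}
```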
+    fn apply_http_headers(&self, mut builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
+        if let Some(extra) = &self.http_headers {
+            for (k, v) in extra {
+                builder = builder.header(k, v);
+            }
+        }
+
+        if let Some(env_headers) = &self.env_http_headers {
+            for (header, env_var) in env_headers {
+                if let Ok(val) = std::env::var(env_var)
+                    && !val.trim().is_empty()
+                {
+                    builder = builder.header(header, val);
+                }
+            }
+        }
+
+        builder
+    }
+
+    pub fn api_key(&self) -> Result<Option<String>> {
+        Ok(match self.env_key.as_ref() {
+            Some(env_key) => match std::env::var(env_key) {
+                Ok(value) if !value.trim().is_empty() => Some(value),
+                Ok(_missing) => None,
+                Err(VarError::NotPresent) => {
+                    let instructions = self.env_key_instructions.clone();
+                    return Err(Error::MissingEnvVar {
+                        var: env_key.to_string(),
+                        instructions,
+                    });
+                }
+                Err(VarError::NotUnicode(_)) => {
+                    return Err(Error::MissingEnvVar {
+                        var: env_key.to_string(),
+                        instructions: None,
+                    });
+                }
+            },
+            None => None,
+        })
+    }
+
+    pub fn stream_max_retries(&self) -> i64 {
+        let value = self
+            .stream_max_retries
+            .unwrap_or(DEFAULT_STREAM_MAX_RETRIES)
+            .min(MAX_STREAM_MAX_RETRIES);
+        value.max(0)
+    }
+
+    pub fn request_max_retries(&self) -> i64 {
+        let value = self
+            .request_max_retries
+            .unwrap_or(DEFAULT_REQUEST_MAX_RETRIES)
+            .min(MAX_REQUEST_MAX_RETRIES);
+        value.max(0)
+    }
+
+    pub fn stream_idle_timeout(&self) -> Duration {
+        let ms = self
+            .stream_idle_timeout_ms
+            .unwrap_or(DEFAULT_STREAM_IDLE_TIMEOUT_MS);
+        let clamped = if ms < 0 { 0 } else { ms as u64 };
+        Duration::from_millis(clamped)
+    }
+}
+
+fn matches_azure_responses_base_url(base: &str) -> bool {
+    base.starts_with("https://") && base.ends_with(".openai.azure.com/openai/responses")
+}
+
+pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "openai/compatible";
+pub const OPENAI_MODEL_PROVIDER_ID: &str = "openai";
+pub const ANTHROPIC_MODEL_PROVIDER_ID: &str = "anthropic";
+
+/// Returns the baked-in list of providers. These can be overridden by a `[model_providers]`
+/// entry inside `~/.codex/config.toml`.
+pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
+    let mut providers = HashMap::new();
+
+    providers.insert(
+        OPENAI_MODEL_PROVIDER_ID.to_string(),
+        ModelProviderInfo {
+            name: "OpenAI".to_string(),
+            base_url: None,
+            env_key: Some("OPENAI_API_KEY".to_string()),
+            env_key_instructions: Some("Log in to OpenAI and create a new API key at https://platform.openai.com/api-keys. Then paste it here.".to_string()),
+            experimental_bearer_token: None,
+            wire_api: WireApi::Responses,
+            query_params: None,
+            http_headers: None,
+            env_http_headers: None,
+            request_max_retries: None,
+            stream_max_retries: None,
+            stream_idle_timeout_ms: None,
+            requires_openai_auth: true,
+        },
+    );
+
+    providers.insert(
+        ANTHROPIC_MODEL_PROVIDER_ID.to_string(),
+        ModelProviderInfo {
+            name: "Anthropic".to_string(),
+            base_url: Some("https://api.anthropic.com/v1/messages".to_string()),
+            env_key: Some("ANTHROPIC_API_KEY".to_string()),
+            env_key_instructions: Some("Create a new API key at https://console.anthropic.com/settings/keys and paste it here.".to_string()),
+            experimental_bearer_token: None,
+            wire_api: WireApi::Chat,
+            query_params: None,
+            http_headers: Some(maplit::hashmap! {
+                "anthropic-version".to_string() => "2023-06-01".to_string(),
+            }),
+            env_http_headers: None,
+            request_max_retries: None,
+            stream_max_retries: None,
+            stream_idle_timeout_ms: None,
+            requires_openai_auth: false,
+        },
+    );
+
+    providers.insert(
+        BUILT_IN_OSS_MODEL_PROVIDER_ID.to_string(),
+        create_oss_provider_with_base_url("http://localhost:11434"),
+    );
+
+    providers
+}
+
+pub fn create_oss_provider_with_base_url(url: &str) -> ModelProviderInfo {
+    let http_headers = maplit::hashmap! {
+        "x-oss-provider".to_string() => "ollama".to_string(),
+    };
+    ModelProviderInfo {
+        name: "Self-hosted OpenAI-compatible (OSS)".to_string(),
+        base_url: Some(url.to_string()),
+        env_key: Some("CODEX_OSS_PROVIDER_API_KEY".to_string()),
+        env_key_instructions: Some(
+            "Set CODEX_OSS_PROVIDER_API_KEY to authenticate with this provider.".to_string(),
+        ),
+        experimental_bearer_token: None,
+        wire_api: WireApi::Chat,
+        query_params: None,
+        http_headers: Some(http_headers),
+        env_http_headers: None,
+        request_max_retries: None,
+        stream_max_retries: None,
+        stream_idle_timeout_ms: None,
+        requires_openai_auth: false,
+    }
+}
+
+/// Convenience helper to construct a default `openai/compatible` provider pointing at localhost.
+pub fn create_oss_provider() -> ModelProviderInfo {
+    create_oss_provider_with_base_url(&format!("http://localhost:{DEFAULT_OLLAMA_PORT}"))
+}
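Note: a quick sanity-check sketch (not in the patch) for the registry and accessors above. The expected URLs follow `get_full_url`'s wire-API suffixing, and the retry setters exercise the `[0, 100]` clamping in `request_max_retries`/`stream_max_retries`.

```rust
use codex_api_client::built_in_model_providers;
use codex_api_client::create_oss_provider;

#[test]
fn full_url_and_retry_clamping() {
    let providers = built_in_model_providers();
    // Responses providers post to `<base>/responses`; with no base_url and
    // no ChatGPT auth context this falls back to the api.openai.com default.
    assert_eq!(
        providers["openai"].get_full_url(None),
        "https://api.openai.com/v1/responses"
    );
    // Chat providers post to `<base>/chat/completions`.
    let mut oss = create_oss_provider();
    assert_eq!(
        oss.get_full_url(None),
        "http://localhost:11434/chat/completions"
    );
    // User-supplied retry counts are clamped to the [0, 100] range.
    oss.request_max_retries = Some(1_000);
    assert_eq!(oss.request_max_retries(), 100);
    oss.stream_max_retries = Some(-5);
    assert_eq!(oss.stream_max_retries(), 0);
}
```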
diff --git a/codex-rs/api-client/src/prompt.rs b/codex-rs/api-client/src/prompt.rs
new file mode 100644
index 0000000000..e5c0840257
--- /dev/null
+++ b/codex-rs/api-client/src/prompt.rs
@@ -0,0 +1,46 @@
+use codex_protocol::models::ResponseItem;
+use codex_protocol::protocol::SessionSource;
+use serde_json::Value;
+
+use crate::Reasoning;
+use crate::TextControls;
+
+#[derive(Debug, Clone, Default)]
+pub struct Prompt {
+    pub instructions: String,
+    pub input: Vec<ResponseItem>,
+    pub tools: Vec<Value>,
+    pub parallel_tool_calls: bool,
+    pub output_schema: Option<Value>,
+    pub reasoning: Option<Reasoning>,
+    pub text_controls: Option<TextControls>,
+    pub prompt_cache_key: Option<String>,
+    pub session_source: Option<SessionSource>,
+}
+
+impl Prompt {
+    #[allow(clippy::too_many_arguments)]
+    pub fn new(
+        instructions: String,
+        input: Vec<ResponseItem>,
+        tools: Vec<Value>,
+        parallel_tool_calls: bool,
+        output_schema: Option<Value>,
+        reasoning: Option<Reasoning>,
+        text_controls: Option<TextControls>,
+        prompt_cache_key: Option<String>,
+        session_source: Option<SessionSource>,
+    ) -> Self {
+        Self {
+            instructions,
+            input,
+            tools,
+            parallel_tool_calls,
+            output_schema,
+            reasoning,
+            text_controls,
+            prompt_cache_key,
+            session_source,
+        }
+    }
+}
diff --git a/codex-rs/api-client/src/responses.rs b/codex-rs/api-client/src/responses.rs
new file mode 100644
index 0000000000..12185b2858
--- /dev/null
+++ b/codex-rs/api-client/src/responses.rs
@@ -0,0 +1,819 @@
+use std::io::BufRead;
+use std::path::Path;
+use std::sync::Arc;
+use std::sync::OnceLock;
+use std::time::Duration;
+
+use async_trait::async_trait;
+use bytes::Bytes;
+use codex_app_server_protocol::AuthMode;
+use codex_otel::otel_event_manager::OtelEventManager;
+use codex_protocol::ConversationId;
+use codex_protocol::models::ResponseItem;
+use codex_protocol::protocol::RateLimitSnapshot;
+use codex_protocol::protocol::TokenUsage;
+use futures::Stream;
+use futures::StreamExt;
+use futures::TryStreamExt;
+use regex_lite::Regex;
+use reqwest::StatusCode;
+use reqwest::header::HeaderMap;
+use serde::Deserialize;
+use serde_json::Value;
+use serde_json::json;
+use tokio::sync::mpsc;
+use tokio::time::timeout;
+use tokio_util::io::ReaderStream;
+use tracing::debug;
+use tracing::trace;
+
+use crate::api::ApiClient;
+use 
crate::auth::AuthProvider; +use crate::common::apply_subagent_header; +use crate::common::backoff; +use crate::error::Error; +use crate::model_provider::ModelProviderInfo; +use crate::prompt::Prompt; +use crate::stream::ResponseEvent; +use crate::stream::ResponseStream; + +type Result = std::result::Result; + +#[derive(Clone)] +pub struct ResponsesApiClientConfig { + pub http_client: reqwest::Client, + pub provider: ModelProviderInfo, + pub model: String, + pub conversation_id: ConversationId, + pub auth_provider: Option>, + pub otel_event_manager: OtelEventManager, +} + +#[derive(Clone)] +pub struct ResponsesApiClient { + config: ResponsesApiClientConfig, +} + +#[async_trait] +impl ApiClient for ResponsesApiClient { + type Config = ResponsesApiClientConfig; + + async fn new(config: Self::Config) -> Result { + Ok(Self { config }) + } + + async fn stream(&self, prompt: Prompt) -> Result { + if self.config.provider.wire_api != crate::model_provider::WireApi::Responses { + return Err(Error::UnsupportedOperation( + "ResponsesApiClient requires a Responses provider".to_string(), + )); + } + + let mut payload_json = self.build_payload(&prompt)?; + + if self.config.provider.is_azure_responses_endpoint() + && let Some(input_value) = payload_json.get_mut("input") + && let Some(array) = input_value.as_array_mut() + { + attach_item_ids_array(array, &prompt.input); + } + + let max_attempts = self.config.provider.request_max_retries(); + for attempt in 0..=max_attempts { + match self + .attempt_stream_responses(attempt, &prompt, &payload_json) + .await + { + Ok(stream) => return Ok(stream), + Err(StreamAttemptError::Fatal(err)) => return Err(err), + Err(retryable) => { + if attempt == max_attempts { + return Err(retryable.into_error()); + } + + tokio::time::sleep(retryable.delay(attempt)).await; + } + } + } + + unreachable!("attempt_stream_responses should always return"); + } +} + +impl ResponsesApiClient { + fn build_payload(&self, prompt: &Prompt) -> Result { + let azure_workaround = self.config.provider.is_azure_responses_endpoint(); + + let mut payload = json!({ + "model": self.config.model, + "instructions": prompt.instructions, + "input": prompt.input, + "tools": prompt.tools, + "tool_choice": "auto", + "parallel_tool_calls": prompt.parallel_tool_calls, + "store": azure_workaround, + "stream": true, + "prompt_cache_key": prompt + .prompt_cache_key + .clone() + .unwrap_or_else(|| self.config.conversation_id.to_string()), + }); + + if let Some(reasoning) = prompt.reasoning.as_ref() + && let Some(obj) = payload.as_object_mut() + { + obj.insert("reasoning".to_string(), serde_json::to_value(reasoning)?); + } + + if let Some(text) = prompt.text_controls.as_ref() + && let Some(obj) = payload.as_object_mut() + { + obj.insert("text".to_string(), serde_json::to_value(text)?); + } + + let include = if prompt.reasoning.is_some() { + vec!["reasoning.encrypted_content".to_string()] + } else { + Vec::new() + }; + if let Some(obj) = payload.as_object_mut() { + obj.insert( + "include".to_string(), + Value::Array(include.into_iter().map(Value::String).collect()), + ); + } + + Ok(payload) + } + + async fn attempt_stream_responses( + &self, + attempt: i64, + prompt: &Prompt, + payload_json: &Value, + ) -> std::result::Result { + let auth = if let Some(provider) = &self.config.auth_provider { + provider.auth_context().await + } else { + None + }; + + trace!( + "POST to {}: {:?}", + self.config.provider.get_full_url(auth.as_ref()), + serde_json::to_string(payload_json) + .unwrap_or_else(|_| "".to_string()) + ); + + 
let mut req_builder = self + .config + .provider + .create_request_builder(&self.config.http_client, &auth) + .await + .map_err(StreamAttemptError::Fatal)?; + req_builder = apply_subagent_header(req_builder, prompt.session_source.as_ref()); + + req_builder = req_builder + .header("conversation_id", self.config.conversation_id.to_string()) + .header("session_id", self.config.conversation_id.to_string()) + .header(reqwest::header::ACCEPT, "text/event-stream") + .json(payload_json); + + if let Some(auth_ctx) = auth.as_ref() + && auth_ctx.mode == AuthMode::ChatGPT + && let Some(account_id) = auth_ctx.account_id.clone() + { + req_builder = req_builder.header("chatgpt-account-id", account_id); + } + + let res = self + .config + .otel_event_manager + .log_request(attempt as u64, || req_builder.send()) + .await; + + let mut request_id = None; + if let Ok(resp) = &res { + request_id = resp + .headers() + .get("cf-ray") + .and_then(|v| v.to_str().ok()) + .map(std::string::ToString::to_string); + } + + match res { + Ok(resp) if resp.status().is_success() => { + let (tx_event, rx_event) = mpsc::channel::>(1600); + + if let Some(snapshot) = parse_rate_limit_snapshot(resp.headers()) + && tx_event + .send(Ok(ResponseEvent::RateLimits(snapshot))) + .await + .is_err() + { + debug!("receiver dropped rate limit snapshot event"); + } + + let stream = resp + .bytes_stream() + .map_err(move |err| Error::ResponseStreamFailed { + source: err, + request_id: request_id.clone(), + }); + let idle_timeout = self.config.provider.stream_idle_timeout(); + let otel = self.config.otel_event_manager.clone(); + + tokio::spawn(process_sse(stream, tx_event, idle_timeout, otel)); + + Ok(ResponseStream { rx_event }) + } + Ok(resp) => Err(handle_error_response(resp, request_id, &self.config).await), + Err(err) => Err(StreamAttemptError::RetryableTransportError(Error::Http( + err, + ))), + } + } +} + +async fn handle_error_response( + resp: reqwest::Response, + request_id: Option, + _config: &ResponsesApiClientConfig, +) -> StreamAttemptError { + let status = resp.status(); + let retry_after_secs = resp + .headers() + .get(reqwest::header::RETRY_AFTER) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()); + let retry_after = retry_after_secs.map(|secs| { + let clamped = if secs < 0 { 0 } else { secs as u64 }; + Duration::from_secs(clamped) + }); + + if !(status == StatusCode::TOO_MANY_REQUESTS + || status == StatusCode::UNAUTHORIZED + || status.is_server_error()) + { + let body = resp.text().await.unwrap_or_default(); + return StreamAttemptError::Fatal(Error::UnexpectedStatus { status, body }); + } + + if status == StatusCode::TOO_MANY_REQUESTS { + let body = resp.json::().await.ok(); + if let Some(ErrorResponse { error }) = body { + if error.r#type.as_deref() == Some("usage_limit_reached") { + return StreamAttemptError::Fatal(Error::Stream( + "usage limit reached".to_string(), + None, + )); + } else if error.r#type.as_deref() == Some("usage_not_included") { + return StreamAttemptError::Fatal(Error::Stream( + "usage not included".to_string(), + None, + )); + } else if is_quota_exceeded_error(&error) { + return StreamAttemptError::Fatal(Error::Stream( + "quota exceeded".to_string(), + None, + )); + } + } + } + + StreamAttemptError::RetryableHttpError { + status, + retry_after, + request_id, + } +} + +#[allow(clippy::too_many_arguments)] +async fn process_sse( + stream: S, + tx_event: mpsc::Sender>, + max_idle_duration: Duration, + otel_event_manager: OtelEventManager, +) where + S: Stream> + Send + 'static + Unpin, +{ 
+ let mut stream = stream; + let mut response_completed: Option = None; + let mut response_error: Option = None; + + loop { + let result = timeout(max_idle_duration, stream.next()).await; + match result { + Err(_) => { + if let Some(completed) = response_completed.take() { + let _ = emit_response_completed( + tx_event.clone(), + completed, + response_error.take(), + &otel_event_manager, + ) + .await; + return; + } + + let _ = tx_event + .send(Err(Error::Stream( + "stream idle timeout fired before Completed event".to_string(), + None, + ))) + .await; + return; + } + Ok(Some(Err(err))) => { + let _ = tx_event.send(Err(err)).await; + return; + } + Ok(Some(Ok(chunk))) => { + if let Err(err) = process_sse_chunk(chunk, &tx_event).await { + let _ = tx_event.send(Err(err)).await; + return; + } + } + Ok(None) => { + if let Some(completed) = response_completed.take() { + let _ = emit_response_completed( + tx_event.clone(), + completed, + response_error.take(), + &otel_event_manager, + ) + .await; + } + return; + } + } + } +} + +async fn emit_response_completed( + tx_event: mpsc::Sender>, + completed: ResponseCompleted, + response_error: Option, + _otel_event_manager: &OtelEventManager, +) -> Result<()> { + if let Some(err) = response_error { + tx_event.send(Err(err)).await.ok(); + return Ok(()); + } + + let event = ResponseEvent::Completed { + response_id: completed.id, + token_usage: completed.usage, + }; + tx_event.send(Ok(event)).await.ok(); + + Ok(()) +} + +fn parse_rate_limit_snapshot(_headers: &HeaderMap) -> Option { + None +} + +async fn process_sse_chunk( + chunk: Bytes, + tx_event: &mpsc::Sender>, +) -> Result<()> { + let chunk_str = std::str::from_utf8(&chunk) + .map_err(|err| Error::Other(format!("Invalid UTF-8 in SSE chunk: {err}")))?; + trace!("responses api chunk ({chunk_str:?})"); + + let mut data_buffer = String::new(); + for line in chunk_str.lines() { + if let Some(tail) = line.strip_prefix("data:") { + data_buffer.push_str(tail.trim_start()); + } + + if line.is_empty() { + let payload: sse::Payload = serde_json::from_str(&data_buffer) + .map_err(|err| Error::Other(format!("Cannot parse SSE JSON: {err}")))?; + handle_sse_payload(payload, tx_event).await?; + data_buffer.clear(); + } + } + + Ok(()) +} + +async fn handle_sse_payload( + payload: sse::Payload, + tx_event: &mpsc::Sender>, +) -> Result<()> { + if let Some(responses) = payload.responses { + for ev in responses { + let event = match ev { + sse::Response::Completed(complete) => ResponseEvent::Completed { + response_id: complete.id, + token_usage: complete.usage, + }, + sse::Response::Error(err) => { + let retry_after = err + .retry_after + .map(|secs| Duration::from_secs(if secs < 0 { 0 } else { secs as u64 })); + return Err(Error::Stream( + err.message.unwrap_or_else(|| "fatal error".to_string()), + retry_after, + )); + } + }; + tx_event.send(Ok(event)).await.ok(); + } + } + + if let Some(message_delta) = payload.response_message_delta { + let ev = ResponseEvent::OutputTextDelta(message_delta.text.clone()); + tx_event.send(Ok(ev)).await.ok(); + } + + if let Some(_rate_limits) = payload.rate_limits { + // Rate limit snapshots are not emitted for this protocol shape in this build. 
+ } + + if let Some(_response_content) = payload.response_content { + // Not used currently + } + + if let Some(ev) = payload.response_event { + debug!("Unhandled response_event: {ev:?}"); + } + + if let Some(item) = payload.response_output_item { + match item.r#type { + sse::OutputItem::Created => { + tx_event.send(Ok(ResponseEvent::Created)).await.ok(); + } + } + } + + if let Some(done) = payload.response_output_text_delta { + tx_event + .send(Ok(ResponseEvent::OutputTextDelta(done.text))) + .await + .ok(); + } + + if let Some(completed) = payload.response_output_item_done { + let response_item = + serde_json::from_value::(completed.item).map_err(Error::Json)?; + tx_event + .send(Ok(ResponseEvent::OutputItemDone(response_item))) + .await + .ok(); + } + + if let Some(reasoning_content_delta) = payload.response_output_reasoning_delta { + tx_event + .send(Ok(ResponseEvent::ReasoningContentDelta( + reasoning_content_delta.text, + ))) + .await + .ok(); + } + + if let Some(reasoning_summary_delta) = payload.response_output_reasoning_summary_delta { + tx_event + .send(Ok(ResponseEvent::ReasoningSummaryDelta( + reasoning_summary_delta.text, + ))) + .await + .ok(); + } + + if let Some(ev) = payload.response_error + && ev.code.as_deref() == Some("max_response_tokens") + { + let _ = tx_event + .send(Err(Error::Stream( + "context window exceeded".to_string(), + None, + ))) + .await; + } + + Ok(()) +} + +#[derive(Debug, Deserialize)] +struct ResponseCompleted { + id: String, + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct ErrorResponse { + error: ErrorBody, +} +#[derive(Debug, Deserialize)] +struct ErrorBody { + r#type: Option, + code: Option, + message: Option, + plan_type: Option, + resets_at: Option, +} + +fn is_quota_exceeded_error(error: &ErrorBody) -> bool { + error.code.as_deref() == Some("quota_exceeded") +} + +enum StreamAttemptError { + RetryableHttpError { + status: StatusCode, + retry_after: Option, + request_id: Option, + }, + RetryableTransportError(Error), + Fatal(Error), +} + +impl StreamAttemptError { + fn delay(&self, attempt: i64) -> Duration { + match self { + StreamAttemptError::RetryableHttpError { + retry_after: Some(retry_after), + .. + } => *retry_after, + StreamAttemptError::RetryableHttpError { + retry_after: None, .. + } + | StreamAttemptError::RetryableTransportError(..) => backoff(attempt), + StreamAttemptError::Fatal(..) => Duration::from_millis(0), + } + } + + fn into_error(self) -> Error { + match self { + StreamAttemptError::RetryableHttpError { + status, request_id, .. 
+ } => Error::RetryLimit { + status: Some(status), + request_id, + }, + StreamAttemptError::RetryableTransportError(err) | StreamAttemptError::Fatal(err) => { + err + } + } + } +} + +// backoff moved to crate::common + +fn rate_limit_regex() -> Option<&'static Regex> { + static RE: OnceLock> = OnceLock::new(); + + RE.get_or_init(|| Regex::new(r"Please try again in (\d+(?:\.\d+)?)(s|ms)").ok()) + .as_ref() +} + +fn try_parse_retry_after(err: &ErrorResponse) -> Option { + if err.error.code.as_deref() != Some("rate_limit_exceeded") { + return None; + } + + if let Some(re) = rate_limit_regex() + && let Some(message) = &err.error.message + && let Some(captures) = re.captures(message) + { + let seconds = captures.get(1); + let unit = captures.get(2); + + if let (Some(value), Some(unit)) = (seconds, unit) { + let value = value.as_str().parse::().ok()?; + let unit = unit.as_str(); + + if unit == "s" { + return Some(Duration::from_secs_f64(value)); + } else if unit == "ms" { + return Some(Duration::from_millis(value as u64)); + } + } + } + None +} + +fn is_context_window_error(error: &ErrorResponse) -> bool { + error.error.code.as_deref() == Some("context_length_exceeded") +} + +/// used in tests to stream from a text SSE file +pub async fn stream_from_fixture( + path: impl AsRef, + provider: ModelProviderInfo, + otel_event_manager: OtelEventManager, +) -> Result { + let (tx_event, rx_event) = mpsc::channel::>(1600); + let display_path = path.as_ref().display().to_string(); + let file = std::fs::File::open(path.as_ref()) + .map_err(|err| Error::Other(format!("failed to open fixture {display_path}: {err}")))?; + let lines = std::io::BufReader::new(file).lines(); + + let mut content = String::new(); + for line in lines { + let line = line + .map_err(|err| Error::Other(format!("failed to read fixture {display_path}: {err}")))?; + content.push_str(&line); + content.push('\n'); + content.push('\n'); + } + + let rdr = std::io::Cursor::new(content); + let stream = ReaderStream::new(rdr).map_err(|err| Error::Other(err.to_string())); + tokio::spawn(process_sse( + stream, + tx_event, + provider.stream_idle_timeout(), + otel_event_manager, + )); + Ok(ResponseStream { rx_event }) +} + +fn attach_item_ids_array(_json_array: &mut Vec, _prompt_input: &[ResponseItem]) { + // no-op for current protocol version +} + +#[derive(Debug, Deserialize)] +struct StreamEvent { + r#type: String, + response: Option, + item: Option, + error: Option, +} + +#[derive(Debug, Deserialize)] +struct StreamResponsePayload { + event: StreamEvent, +} + +async fn handle_stream_event( + event: StreamEvent, + tx_event: mpsc::Sender>, + response_completed: &mut Option, + response_error: &mut Option, +) { + trace!("response event: {}", event.r#type); + match event.r#type.as_str() { + "response.output_text.delta" => { + if let Some(item_val) = event.item { + let resp = serde_json::from_value::(item_val); + if let Ok(delta) = resp { + let event = ResponseEvent::OutputTextDelta(delta.delta); + if tx_event.send(Ok(event)).await.is_err() {} + } + } + } + "response.error" => { + if let Some(err_val) = event.error { + let err_resp = serde_json::from_value::(err_val); + match err_resp { + Ok(err) => { + let retry_after = try_parse_retry_after(&err); + *response_error = Some(Error::Stream( + err.error + .message + .unwrap_or_else(|| "unknown error".to_string()), + retry_after, + )); + } + Err(err) => { + let _ = tx_event + .send(Err(Error::Stream( + format!("failed to parse ErrorResponse: {err}"), + None, + ))) + .await; + } + } + } + } + 
"response.completed" => { + if let Some(resp_val) = event.response { + match serde_json::from_value::(resp_val) { + Ok(resp) => { + *response_completed = Some(resp); + } + Err(err) => { + let _ = tx_event + .send(Err(Error::Stream( + format!("failed to parse ResponseCompleted: {err}"), + None, + ))) + .await; + } + }; + }; + } + "response.output_item.added" => { + if let Some(item_val) = event.item + && let Ok(item) = serde_json::from_value::(item_val) + { + let event = ResponseEvent::OutputItemAdded(item); + if tx_event.send(Ok(event)).await.is_err() {} + } + } + "response.reasoning_summary_part.added" => { + let event = ResponseEvent::ReasoningSummaryPartAdded; + let _ = tx_event.send(Ok(event)).await; + } + _ => {} + } +} + +#[derive(Debug, Deserialize)] +struct TextDelta { + role: String, + delta: String, +} + +mod sse { + use serde::Deserialize; + use serde_json::Value; + + #[derive(Debug, Deserialize)] + pub struct Payload { + pub responses: Option>, + pub response_content: Option, + pub response_error: Option, + pub response_event: Option, + pub response_message_delta: Option, + pub response_output_item: Option, + pub response_output_text_delta: Option, + pub response_output_item_done: Option, + pub response_output_reasoning_delta: Option, + pub response_output_reasoning_summary_delta: Option, + pub rate_limits: Option>, + } + + #[derive(Debug, Deserialize)] + pub enum Response { + #[serde(rename = "response.completed")] + Completed(ResponseCompleted), + #[serde(rename = "response.error")] + Error(ResponseError), + } + + #[derive(Debug, Deserialize)] + pub struct ResponseCompleted { + pub id: String, + pub usage: Option, + } + + #[derive(Debug, Deserialize)] + pub struct ResponseError { + pub code: Option, + pub message: Option, + pub retry_after: Option, + } + + #[derive(Debug, Deserialize)] + pub struct ResponseMessageDelta { + pub text: String, + pub role: String, + pub appended_content: Vec, + } + + #[derive(Debug, Deserialize)] + pub enum OutputItem { + #[serde(rename = "response.output_item.created")] + Created, + } + + #[derive(Debug, Deserialize)] + pub struct ResponseOutputItem { + pub r#type: OutputItem, + pub item: Value, + } + + #[derive(Debug, Deserialize)] + pub struct ResponseOutputTextDelta { + pub text: String, + } + + #[derive(Debug, Deserialize)] + pub struct ResponseOutputItemDone { + pub item: Value, + } + + #[derive(Debug, Deserialize)] + pub struct ResponseOutputReasoningDelta { + pub content: Vec, + pub text: String, + } + + #[derive(Debug, Deserialize)] + pub struct ResponseOutputReasoningSummaryDelta { + pub summary: Vec, + pub text: String, + } + + #[derive(Debug, Deserialize)] + pub struct RateLimit { + pub window: String, + pub remaining_tokens: i64, + pub limit: i64, + pub reset_seconds: i64, + } +} diff --git a/codex-rs/api-client/src/stream.rs b/codex-rs/api-client/src/stream.rs new file mode 100644 index 0000000000..ac76b282fa --- /dev/null +++ b/codex-rs/api-client/src/stream.rs @@ -0,0 +1,83 @@ +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; + +use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig; +use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig; +use codex_protocol::models::ResponseItem; +use codex_protocol::protocol::RateLimitSnapshot; +use codex_protocol::protocol::TokenUsage; +use futures::Stream; +use serde::Serialize; +use serde_json::Value; +use tokio::sync::mpsc; + +use crate::error::Result; + +#[derive(Debug, Serialize, Clone)] +pub struct Reasoning { + 
#[serde(skip_serializing_if = "Option::is_none")]
+    pub effort: Option<ReasoningEffortConfig>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub summary: Option<ReasoningSummaryConfig>,
+}
+
+#[derive(Debug, Serialize, Default, Clone)]
+#[serde(rename_all = "snake_case")]
+pub enum TextFormatType {
+    #[default]
+    JsonSchema,
+}
+
+#[derive(Debug, Serialize, Default, Clone)]
+pub struct TextFormat {
+    pub r#type: TextFormatType,
+    pub strict: bool,
+    pub schema: Value,
+    pub name: String,
+}
+
+#[derive(Debug, Serialize, Default, Clone)]
+pub struct TextControls {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub verbosity: Option<codex_protocol::config_types::Verbosity>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub format: Option<TextFormat>,
+}
+
+#[derive(Debug)]
+pub enum ResponseEvent {
+    Created,
+    OutputItemDone(ResponseItem),
+    OutputItemAdded(ResponseItem),
+    Completed {
+        response_id: String,
+        token_usage: Option<TokenUsage>,
+    },
+    OutputTextDelta(String),
+    ReasoningSummaryDelta(String),
+    ReasoningContentDelta(String),
+    ReasoningSummaryPartAdded,
+    RateLimits(RateLimitSnapshot),
+}
+
+#[derive(Debug)]
+pub struct EventStream<T> {
+    pub(crate) rx_event: mpsc::Receiver<T>,
+}
+
+impl<T> EventStream<T> {
+    pub fn from_receiver(rx_event: mpsc::Receiver<T>) -> Self {
+        Self { rx_event }
+    }
+}
+
+impl<T> Stream for EventStream<T> {
+    type Item = T;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        self.rx_event.poll_recv(cx)
+    }
+}
+
+pub type ResponseStream = EventStream<Result<ResponseEvent>>;
diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml
index 17a7a16609..b34c59a710 100644
--- a/codex-rs/core/Cargo.toml
+++ b/codex-rs/core/Cargo.toml
@@ -22,6 +22,7 @@ chrono = { workspace = true, features = ["serde"] }
 codex-app-server-protocol = { workspace = true }
 codex-apply-patch = { workspace = true }
 codex-async-utils = { workspace = true }
+codex-api-client = { workspace = true }
 codex-file-search = { workspace = true }
 codex-git = { workspace = true }
 codex-keyring-store = { workspace = true }
diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs
deleted file mode 100644
index abb27d9b55..0000000000
--- a/codex-rs/core/src/chat_completions.rs
+++ /dev/null
@@ -1,967 +0,0 @@
-use std::time::Duration;
-
-use crate::ModelProviderInfo;
-use crate::client_common::Prompt;
-use crate::client_common::ResponseEvent;
-use crate::client_common::ResponseStream;
-use crate::default_client::CodexHttpClient;
-use crate::error::CodexErr;
-use crate::error::ConnectionFailedError;
-use crate::error::ResponseStreamFailed;
-use crate::error::Result;
-use crate::error::RetryLimitReachedError;
-use crate::error::UnexpectedResponseError;
-use crate::model_family::ModelFamily;
-use crate::tools::spec::create_tools_json_for_chat_completions_api;
-use crate::util::backoff;
-use bytes::Bytes;
-use codex_otel::otel_event_manager::OtelEventManager;
-use codex_protocol::models::ContentItem;
-use codex_protocol::models::FunctionCallOutputContentItem;
-use codex_protocol::models::ReasoningItemContent;
-use codex_protocol::models::ResponseItem;
-use codex_protocol::protocol::SessionSource;
-use codex_protocol::protocol::SubAgentSource;
-use eventsource_stream::Eventsource;
-use futures::Stream;
-use futures::StreamExt;
-use futures::TryStreamExt;
-use reqwest::StatusCode;
-use serde_json::json;
-use std::pin::Pin;
-use std::task::Context;
-use std::task::Poll;
-use tokio::sync::mpsc;
-use tokio::time::timeout;
-use tracing::debug;
-use tracing::trace;
-
-/// Implementation for the classic Chat Completions API. 
-pub(crate) async fn stream_chat_completions( - prompt: &Prompt, - model_family: &ModelFamily, - client: &CodexHttpClient, - provider: &ModelProviderInfo, - otel_event_manager: &OtelEventManager, - session_source: &SessionSource, -) -> Result { - if prompt.output_schema.is_some() { - return Err(CodexErr::UnsupportedOperation( - "output_schema is not supported for Chat Completions API".to_string(), - )); - } - - // Build messages array - let mut messages = Vec::::new(); - - let full_instructions = prompt.get_full_instructions(model_family); - messages.push(json!({"role": "system", "content": full_instructions})); - - let input = prompt.get_formatted_input(); - - // Pre-scan: map Reasoning blocks to the adjacent assistant anchor after the last user. - // - If the last emitted message is a user message, drop all reasoning. - // - Otherwise, for each Reasoning item after the last user message, attach it - // to the immediate previous assistant message (stop turns) or the immediate - // next assistant anchor (tool-call turns: function/local shell call, or assistant message). - let mut reasoning_by_anchor_index: std::collections::HashMap = - std::collections::HashMap::new(); - - // Determine the last role that would be emitted to Chat Completions. - let mut last_emitted_role: Option<&str> = None; - for item in &input { - match item { - ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()), - ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => { - last_emitted_role = Some("assistant") - } - ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"), - ResponseItem::Reasoning { .. } | ResponseItem::Other => {} - ResponseItem::CustomToolCall { .. } => {} - ResponseItem::CustomToolCallOutput { .. } => {} - ResponseItem::WebSearchCall { .. } => {} - ResponseItem::GhostSnapshot { .. } => {} - } - } - - // Find the last user message index in the input. - let mut last_user_index: Option = None; - for (idx, item) in input.iter().enumerate() { - if let ResponseItem::Message { role, .. } = item - && role == "user" - { - last_user_index = Some(idx); - } - } - - // Attach reasoning only if the conversation does not end with a user message. - if !matches!(last_emitted_role, Some("user")) { - for (idx, item) in input.iter().enumerate() { - // Only consider reasoning that appears after the last user message. - if let Some(u_idx) = last_user_index - && idx <= u_idx - { - continue; - } - - if let ResponseItem::Reasoning { - content: Some(items), - .. - } = item - { - let mut text = String::new(); - for entry in items { - match entry { - ReasoningItemContent::ReasoningText { text: segment } - | ReasoningItemContent::Text { text: segment } => text.push_str(segment), - } - } - if text.trim().is_empty() { - continue; - } - - // Prefer immediate previous assistant message (stop turns) - let mut attached = false; - if idx > 0 - && let ResponseItem::Message { role, .. } = &input[idx - 1] - && role == "assistant" - { - reasoning_by_anchor_index - .entry(idx - 1) - .and_modify(|v| v.push_str(&text)) - .or_insert(text.clone()); - attached = true; - } - - // Otherwise, attach to immediate next assistant anchor (tool-calls or assistant message) - if !attached && idx + 1 < input.len() { - match &input[idx + 1] { - ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => { - reasoning_by_anchor_index - .entry(idx + 1) - .and_modify(|v| v.push_str(&text)) - .or_insert(text.clone()); - } - ResponseItem::Message { role, .. 
} if role == "assistant" => { - reasoning_by_anchor_index - .entry(idx + 1) - .and_modify(|v| v.push_str(&text)) - .or_insert(text.clone()); - } - _ => {} - } - } - } - } - } - - // Track last assistant text we emitted to avoid duplicate assistant messages - // in the outbound Chat Completions payload (can happen if a final - // aggregated assistant message was recorded alongside an earlier partial). - let mut last_assistant_text: Option = None; - - for (idx, item) in input.iter().enumerate() { - match item { - ResponseItem::Message { role, content, .. } => { - // Build content either as a plain string (typical for assistant text) - // or as an array of content items when images are present (user/tool multimodal). - let mut text = String::new(); - let mut items: Vec = Vec::new(); - let mut saw_image = false; - - for c in content { - match c { - ContentItem::InputText { text: t } - | ContentItem::OutputText { text: t } => { - text.push_str(t); - items.push(json!({"type":"text","text": t})); - } - ContentItem::InputImage { image_url } => { - saw_image = true; - items.push(json!({"type":"image_url","image_url": {"url": image_url}})); - } - } - } - - // Skip exact-duplicate assistant messages. - if role == "assistant" { - if let Some(prev) = &last_assistant_text - && prev == &text - { - continue; - } - last_assistant_text = Some(text.clone()); - } - - // For assistant messages, always send a plain string for compatibility. - // For user messages, if an image is present, send an array of content items. - let content_value = if role == "assistant" { - json!(text) - } else if saw_image { - json!(items) - } else { - json!(text) - }; - - let mut msg = json!({"role": role, "content": content_value}); - if role == "assistant" - && let Some(reasoning) = reasoning_by_anchor_index.get(&idx) - && let Some(obj) = msg.as_object_mut() - { - obj.insert("reasoning".to_string(), json!(reasoning)); - } - messages.push(msg); - } - ResponseItem::FunctionCall { - name, - arguments, - call_id, - .. - } => { - let mut msg = json!({ - "role": "assistant", - "content": null, - "tool_calls": [{ - "id": call_id, - "type": "function", - "function": { - "name": name, - "arguments": arguments, - } - }] - }); - if let Some(reasoning) = reasoning_by_anchor_index.get(&idx) - && let Some(obj) = msg.as_object_mut() - { - obj.insert("reasoning".to_string(), json!(reasoning)); - } - messages.push(msg); - } - ResponseItem::LocalShellCall { - id, - call_id: _, - status, - action, - } => { - // Confirm with API team. - let mut msg = json!({ - "role": "assistant", - "content": null, - "tool_calls": [{ - "id": id.clone().unwrap_or_else(|| "".to_string()), - "type": "local_shell_call", - "status": status, - "action": action, - }] - }); - if let Some(reasoning) = reasoning_by_anchor_index.get(&idx) - && let Some(obj) = msg.as_object_mut() - { - obj.insert("reasoning".to_string(), json!(reasoning)); - } - messages.push(msg); - } - ResponseItem::FunctionCallOutput { call_id, output } => { - // Prefer structured content items when available (e.g., images) - // otherwise fall back to the legacy plain-string content. 
- let content_value = if let Some(items) = &output.content_items { - let mapped: Vec = items - .iter() - .map(|it| match it { - FunctionCallOutputContentItem::InputText { text } => { - json!({"type":"text","text": text}) - } - FunctionCallOutputContentItem::InputImage { image_url } => { - json!({"type":"image_url","image_url": {"url": image_url}}) - } - }) - .collect(); - json!(mapped) - } else { - json!(output.content) - }; - - messages.push(json!({ - "role": "tool", - "tool_call_id": call_id, - "content": content_value, - })); - } - ResponseItem::CustomToolCall { - id, - call_id: _, - name, - input, - status: _, - } => { - messages.push(json!({ - "role": "assistant", - "content": null, - "tool_calls": [{ - "id": id, - "type": "custom", - "custom": { - "name": name, - "input": input, - } - }] - })); - } - ResponseItem::CustomToolCallOutput { call_id, output } => { - messages.push(json!({ - "role": "tool", - "tool_call_id": call_id, - "content": output, - })); - } - ResponseItem::GhostSnapshot { .. } => { - // Ghost snapshots annotate history but are not sent to the model. - continue; - } - ResponseItem::Reasoning { .. } - | ResponseItem::WebSearchCall { .. } - | ResponseItem::Other => { - // Omit these items from the conversation history. - continue; - } - } - } - - let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?; - let payload = json!({ - "model": model_family.slug, - "messages": messages, - "stream": true, - "tools": tools_json, - }); - - debug!( - "POST to {}: {}", - provider.get_full_url(&None), - serde_json::to_string_pretty(&payload).unwrap_or_default() - ); - - let mut attempt = 0; - let max_retries = provider.request_max_retries(); - loop { - attempt += 1; - - let mut req_builder = provider.create_request_builder(client, &None).await?; - - // Include subagent header only for subagent sessions. 
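-        // E.g. a subagent session would add a header along the lines of
-        //   x-openai-subagent: review
-        // (label illustrative; the real value comes from the serde encoding of
-        // `SubAgentSource` below), while top-level sessions send no header.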
- if let SessionSource::SubAgent(sub) = session_source.clone() { - let subagent = if let SubAgentSource::Other(label) = sub { - label - } else { - serde_json::to_value(&sub) - .ok() - .and_then(|v| v.as_str().map(std::string::ToString::to_string)) - .unwrap_or_else(|| "other".to_string()) - }; - req_builder = req_builder.header("x-openai-subagent", subagent); - } - - let res = otel_event_manager - .log_request(attempt, || { - req_builder - .header(reqwest::header::ACCEPT, "text/event-stream") - .json(&payload) - .send() - }) - .await; - - match res { - Ok(resp) if resp.status().is_success() => { - let (tx_event, rx_event) = mpsc::channel::>(1600); - let stream = resp.bytes_stream().map_err(|e| { - CodexErr::ResponseStreamFailed(ResponseStreamFailed { - source: e, - request_id: None, - }) - }); - tokio::spawn(process_chat_sse( - stream, - tx_event, - provider.stream_idle_timeout(), - otel_event_manager.clone(), - )); - return Ok(ResponseStream { rx_event }); - } - Ok(res) => { - let status = res.status(); - if !(status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()) { - let body = (res.text().await).unwrap_or_default(); - return Err(CodexErr::UnexpectedStatus(UnexpectedResponseError { - status, - body, - request_id: None, - })); - } - - if attempt > max_retries { - return Err(CodexErr::RetryLimit(RetryLimitReachedError { - status, - request_id: None, - })); - } - - let retry_after_secs = res - .headers() - .get(reqwest::header::RETRY_AFTER) - .and_then(|v| v.to_str().ok()) - .and_then(|s| s.parse::().ok()); - - let delay = retry_after_secs - .map(|s| Duration::from_millis(s * 1_000)) - .unwrap_or_else(|| backoff(attempt)); - tokio::time::sleep(delay).await; - } - Err(e) => { - if attempt > max_retries { - return Err(CodexErr::ConnectionFailed(ConnectionFailedError { - source: e, - })); - } - let delay = backoff(attempt); - tokio::time::sleep(delay).await; - } - } - } -} - -async fn append_assistant_text( - tx_event: &mpsc::Sender>, - assistant_item: &mut Option, - text: String, -) { - if assistant_item.is_none() { - let item = ResponseItem::Message { - id: None, - role: "assistant".to_string(), - content: vec![], - }; - *assistant_item = Some(item.clone()); - let _ = tx_event - .send(Ok(ResponseEvent::OutputItemAdded(item))) - .await; - } - - if let Some(ResponseItem::Message { content, .. }) = assistant_item { - content.push(ContentItem::OutputText { text: text.clone() }); - let _ = tx_event - .send(Ok(ResponseEvent::OutputTextDelta(text.clone()))) - .await; - } -} - -async fn append_reasoning_text( - tx_event: &mpsc::Sender>, - reasoning_item: &mut Option, - text: String, -) { - if reasoning_item.is_none() { - let item = ResponseItem::Reasoning { - id: String::new(), - summary: Vec::new(), - content: Some(vec![]), - encrypted_content: None, - }; - *reasoning_item = Some(item.clone()); - let _ = tx_event - .send(Ok(ResponseEvent::OutputItemAdded(item))) - .await; - } - - if let Some(ResponseItem::Reasoning { - content: Some(content), - .. - }) = reasoning_item - { - content.push(ReasoningItemContent::ReasoningText { text: text.clone() }); - - let _ = tx_event - .send(Ok(ResponseEvent::ReasoningContentDelta(text.clone()))) - .await; - } -} -/// Lightweight SSE processor for the Chat Completions streaming format. The -/// output is mapped onto Codex's internal [`ResponseEvent`] so that the rest -/// of the pipeline can stay agnostic of the underlying wire format. 
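-///
-/// A typical inbound sequence, for reference (illustrative; real chunks carry
-/// more fields):
-///
-/// ```ignore
-/// data: {"choices":[{"delta":{"content":"Hel"},"finish_reason":null}]}
-/// data: {"choices":[{"delta":{"content":"lo"},"finish_reason":"stop"}]}
-/// data: [DONE]
-/// ```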
-async fn process_chat_sse( - stream: S, - tx_event: mpsc::Sender>, - idle_timeout: Duration, - otel_event_manager: OtelEventManager, -) where - S: Stream> + Unpin, -{ - let mut stream = stream.eventsource(); - - // State to accumulate a function call across streaming chunks. - // OpenAI may split the `arguments` string over multiple `delta` events - // until the chunk whose `finish_reason` is `tool_calls` is emitted. We - // keep collecting the pieces here and forward a single - // `ResponseItem::FunctionCall` once the call is complete. - #[derive(Default)] - struct FunctionCallState { - name: Option, - arguments: String, - call_id: Option, - active: bool, - } - - let mut fn_call_state = FunctionCallState::default(); - let mut assistant_item: Option = None; - let mut reasoning_item: Option = None; - - loop { - let start = std::time::Instant::now(); - let response = timeout(idle_timeout, stream.next()).await; - let duration = start.elapsed(); - otel_event_manager.log_sse_event(&response, duration); - - let sse = match response { - Ok(Some(Ok(ev))) => ev, - Ok(Some(Err(e))) => { - let _ = tx_event - .send(Err(CodexErr::Stream(e.to_string(), None))) - .await; - return; - } - Ok(None) => { - // Stream closed gracefully – emit Completed with dummy id. - let _ = tx_event - .send(Ok(ResponseEvent::Completed { - response_id: String::new(), - token_usage: None, - })) - .await; - return; - } - Err(_) => { - let _ = tx_event - .send(Err(CodexErr::Stream( - "idle timeout waiting for SSE".into(), - None, - ))) - .await; - return; - } - }; - - // OpenAI Chat streaming sends a literal string "[DONE]" when finished. - if sse.data.trim() == "[DONE]" { - // Emit any finalized items before closing so downstream consumers receive - // terminal events for both assistant content and raw reasoning. - if let Some(item) = assistant_item { - let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; - } - - if let Some(item) = reasoning_item { - let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; - } - - let _ = tx_event - .send(Ok(ResponseEvent::Completed { - response_id: String::new(), - token_usage: None, - })) - .await; - return; - } - - // Parse JSON chunk - let chunk: serde_json::Value = match serde_json::from_str(&sse.data) { - Ok(v) => v, - Err(_) => continue, - }; - trace!("chat_completions received SSE chunk: {chunk:?}"); - - let choice_opt = chunk.get("choices").and_then(|c| c.get(0)); - - if let Some(choice) = choice_opt { - // Handle assistant content tokens as streaming deltas. - if let Some(content) = choice - .get("delta") - .and_then(|d| d.get("content")) - .and_then(|c| c.as_str()) - && !content.is_empty() - { - append_assistant_text(&tx_event, &mut assistant_item, content.to_string()).await; - } - - // Forward any reasoning/thinking deltas if present. - // Some providers stream `reasoning` as a plain string while others - // nest the text under an object (e.g. `{ "reasoning": { "text": "…" } }`). 
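-            // Shapes handled below, illustratively:
-            //   {"delta":{"reasoning":"thinking…"}}
-            //   {"delta":{"reasoning":{"text":"thinking…"}}}
-            //   {"delta":{"reasoning":{"content":"thinking…"}}}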
- if let Some(reasoning_val) = choice.get("delta").and_then(|d| d.get("reasoning")) { - let mut maybe_text = reasoning_val - .as_str() - .map(str::to_string) - .filter(|s| !s.is_empty()); - - if maybe_text.is_none() && reasoning_val.is_object() { - if let Some(s) = reasoning_val - .get("text") - .and_then(|t| t.as_str()) - .filter(|s| !s.is_empty()) - { - maybe_text = Some(s.to_string()); - } else if let Some(s) = reasoning_val - .get("content") - .and_then(|t| t.as_str()) - .filter(|s| !s.is_empty()) - { - maybe_text = Some(s.to_string()); - } - } - - if let Some(reasoning) = maybe_text { - // Accumulate so we can emit a terminal Reasoning item at the end. - append_reasoning_text(&tx_event, &mut reasoning_item, reasoning).await; - } - } - - // Some providers only include reasoning on the final message object. - if let Some(message_reasoning) = choice.get("message").and_then(|m| m.get("reasoning")) - { - // Accept either a plain string or an object with { text | content } - if let Some(s) = message_reasoning.as_str() { - if !s.is_empty() { - append_reasoning_text(&tx_event, &mut reasoning_item, s.to_string()).await; - } - } else if let Some(obj) = message_reasoning.as_object() - && let Some(s) = obj - .get("text") - .and_then(|v| v.as_str()) - .or_else(|| obj.get("content").and_then(|v| v.as_str())) - && !s.is_empty() - { - append_reasoning_text(&tx_event, &mut reasoning_item, s.to_string()).await; - } - } - - // Handle streaming function / tool calls. - if let Some(tool_calls) = choice - .get("delta") - .and_then(|d| d.get("tool_calls")) - .and_then(|tc| tc.as_array()) - && let Some(tool_call) = tool_calls.first() - { - // Mark that we have an active function call in progress. - fn_call_state.active = true; - - // Extract call_id if present. - if let Some(id) = tool_call.get("id").and_then(|v| v.as_str()) { - fn_call_state.call_id.get_or_insert_with(|| id.to_string()); - } - - // Extract function details if present. - if let Some(function) = tool_call.get("function") { - if let Some(name) = function.get("name").and_then(|n| n.as_str()) { - fn_call_state.name.get_or_insert_with(|| name.to_string()); - } - - if let Some(args_fragment) = function.get("arguments").and_then(|a| a.as_str()) - { - fn_call_state.arguments.push_str(args_fragment); - } - } - } - - // Emit end-of-turn when finish_reason signals completion. - if let Some(finish_reason) = choice.get("finish_reason").and_then(|v| v.as_str()) { - match finish_reason { - "tool_calls" if fn_call_state.active => { - // First, flush the terminal raw reasoning so UIs can finalize - // the reasoning stream before any exec/tool events begin. - if let Some(item) = reasoning_item.take() { - let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; - } - - // Then emit the FunctionCall response item. - let item = ResponseItem::FunctionCall { - id: None, - name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()), - arguments: fn_call_state.arguments.clone(), - call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new), - }; - - let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; - } - "stop" => { - // Regular turn without tool-call. Emit the final assistant message - // as a single OutputItemDone so non-delta consumers see the result. - if let Some(item) = assistant_item.take() { - let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; - } - // Also emit a terminal Reasoning item so UIs can finalize raw reasoning. 
- if let Some(item) = reasoning_item.take() { - let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; - } - } - _ => {} - } - - // Emit Completed regardless of reason so the agent can advance. - let _ = tx_event - .send(Ok(ResponseEvent::Completed { - response_id: String::new(), - token_usage: None, - })) - .await; - - // Prepare for potential next turn (should not happen in same stream). - // fn_call_state = FunctionCallState::default(); - - return; // End processing for this SSE stream. - } - } - } -} - -/// Optional client-side aggregation helper -/// -/// Stream adapter that merges the incremental `OutputItemDone` chunks coming from -/// [`process_chat_sse`] into a *running* assistant message, **suppressing the -/// per-token deltas**. The stream stays silent while the model is thinking -/// and only emits two events per turn: -/// -/// 1. `ResponseEvent::OutputItemDone` with the *complete* assistant message -/// (fully concatenated). -/// 2. The original `ResponseEvent::Completed` right after it. -/// -/// This mirrors the behaviour the TypeScript CLI exposes to its higher layers. -/// -/// The adapter is intentionally *lossless*: callers who do **not** opt in via -/// [`AggregateStreamExt::aggregate()`] keep receiving the original unmodified -/// events. -#[derive(Copy, Clone, Eq, PartialEq)] -enum AggregateMode { - AggregatedOnly, - Streaming, -} -pub(crate) struct AggregatedChatStream { - inner: S, - cumulative: String, - cumulative_reasoning: String, - pending: std::collections::VecDeque, - mode: AggregateMode, -} - -impl Stream for AggregatedChatStream -where - S: Stream> + Unpin, -{ - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - - // First, flush any buffered events from the previous call. - if let Some(ev) = this.pending.pop_front() { - return Poll::Ready(Some(Ok(ev))); - } - - loop { - match Pin::new(&mut this.inner).poll_next(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(None) => return Poll::Ready(None), - Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), - Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => { - // If this is an incremental assistant message chunk, accumulate but - // do NOT emit yet. Forward any other item (e.g. FunctionCall) right - // away so downstream consumers see it. - - let is_assistant_message = matches!( - &item, - codex_protocol::models::ResponseItem::Message { role, .. } if role == "assistant" - ); - - if is_assistant_message { - match this.mode { - AggregateMode::AggregatedOnly => { - // Only use the final assistant message if we have not - // seen any deltas; otherwise, deltas already built the - // cumulative text and this would duplicate it. - if this.cumulative.is_empty() - && let codex_protocol::models::ResponseItem::Message { - content, - .. - } = &item - && let Some(text) = content.iter().find_map(|c| match c { - codex_protocol::models::ContentItem::OutputText { - text, - } => Some(text), - _ => None, - }) - { - this.cumulative.push_str(text); - } - // Swallow assistant message here; emit on Completed. - continue; - } - AggregateMode::Streaming => { - // In streaming mode, if we have not seen any deltas, forward - // the final assistant message directly. If deltas were seen, - // suppress the final message to avoid duplication. 
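-                                // E.g. once deltas "Hel" and "lo" have been forwarded,
-                                // the terminal "Hello" message is swallowed here; with
-                                // no deltas it is the only copy and passes through
-                                // (restating the invariant, not adding logic).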
- if this.cumulative.is_empty() { - return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone( - item, - )))); - } else { - continue; - } - } - } - } - - // Not an assistant message – forward immediately. - return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))); - } - Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => { - return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))); - } - Poll::Ready(Some(Ok(ResponseEvent::Completed { - response_id, - token_usage, - }))) => { - // Build any aggregated items in the correct order: Reasoning first, then Message. - let mut emitted_any = false; - - if !this.cumulative_reasoning.is_empty() - && matches!(this.mode, AggregateMode::AggregatedOnly) - { - let aggregated_reasoning = - codex_protocol::models::ResponseItem::Reasoning { - id: String::new(), - summary: Vec::new(), - content: Some(vec![ - codex_protocol::models::ReasoningItemContent::ReasoningText { - text: std::mem::take(&mut this.cumulative_reasoning), - }, - ]), - encrypted_content: None, - }; - this.pending - .push_back(ResponseEvent::OutputItemDone(aggregated_reasoning)); - emitted_any = true; - } - - // Always emit the final aggregated assistant message when any - // content deltas have been observed. In AggregatedOnly mode this - // is the sole assistant output; in Streaming mode this finalizes - // the streamed deltas into a terminal OutputItemDone so callers - // can persist/render the message once per turn. - if !this.cumulative.is_empty() { - let aggregated_message = codex_protocol::models::ResponseItem::Message { - id: None, - role: "assistant".to_string(), - content: vec![codex_protocol::models::ContentItem::OutputText { - text: std::mem::take(&mut this.cumulative), - }], - }; - this.pending - .push_back(ResponseEvent::OutputItemDone(aggregated_message)); - emitted_any = true; - } - - // Always emit Completed last when anything was aggregated. - if emitted_any { - this.pending.push_back(ResponseEvent::Completed { - response_id: response_id.clone(), - token_usage: token_usage.clone(), - }); - // Return the first pending event now. - if let Some(ev) = this.pending.pop_front() { - return Poll::Ready(Some(Ok(ev))); - } - } - - // Nothing aggregated – forward Completed directly. - return Poll::Ready(Some(Ok(ResponseEvent::Completed { - response_id, - token_usage, - }))); - } - Poll::Ready(Some(Ok(ResponseEvent::Created))) => { - // These events are exclusive to the Responses API and - // will never appear in a Chat Completions stream. - continue; - } - Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => { - // Always accumulate deltas so we can emit a final OutputItemDone at Completed. - this.cumulative.push_str(&delta); - if matches!(this.mode, AggregateMode::Streaming) { - // In streaming mode, also forward the delta immediately. - return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))); - } else { - continue; - } - } - Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta)))) => { - // Always accumulate reasoning deltas so we can emit a final Reasoning item at Completed. - this.cumulative_reasoning.push_str(&delta); - if matches!(this.mode, AggregateMode::Streaming) { - // In streaming mode, also forward the delta immediately. 
- return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta)))); - } else { - continue; - } - } - Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta(_)))) => { - continue; - } - Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded))) => { - continue; - } - Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => { - return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))); - } - } - } - } -} - -/// Extension trait that activates aggregation on any stream of [`ResponseEvent`]. -pub(crate) trait AggregateStreamExt: Stream> + Sized { - /// Returns a new stream that emits **only** the final assistant message - /// per turn instead of every incremental delta. The produced - /// `ResponseEvent` sequence for a typical text turn looks like: - /// - /// ```ignore - /// OutputItemDone() - /// Completed - /// ``` - /// - /// No other `OutputItemDone` events will be seen by the caller. - /// - /// Usage: - /// - /// ```ignore - /// let agg_stream = client.stream(&prompt).await?.aggregate(); - /// while let Some(event) = agg_stream.next().await { - /// // event now contains cumulative text - /// } - /// ``` - fn aggregate(self) -> AggregatedChatStream { - AggregatedChatStream::new(self, AggregateMode::AggregatedOnly) - } -} - -impl AggregateStreamExt for T where T: Stream> + Sized {} - -impl AggregatedChatStream { - fn new(inner: S, mode: AggregateMode) -> Self { - AggregatedChatStream { - inner, - cumulative: String::new(), - cumulative_reasoning: String::new(), - pending: std::collections::VecDeque::new(), - mode, - } - } - - pub(crate) fn streaming_mode(inner: S) -> Self { - Self::new(inner, AggregateMode::Streaming) - } -} diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 3a0bcb9b54..cd73f4a903 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -1,96 +1,123 @@ -use std::io::BufRead; -use std::path::Path; +use std::fmt; use std::sync::Arc; -use std::sync::OnceLock; -use std::time::Duration; -use bytes::Bytes; -use chrono::DateTime; -use chrono::Utc; -use codex_app_server_protocol::AuthMode; +use async_trait::async_trait; +use codex_api_client::AggregateStreamExt; +use codex_api_client::ApiClient; +use codex_api_client::AuthContext; +use codex_api_client::AuthProvider; +use codex_api_client::ChatAggregationMode; +use codex_api_client::ChatCompletionsApiClient; +use codex_api_client::ChatCompletionsApiClientConfig; +use codex_api_client::ModelProviderInfo; +use codex_api_client::ResponsesApiClient; +use codex_api_client::ResponsesApiClientConfig; +use codex_api_client::Result as ApiClientResult; +use codex_api_client::WireApi; +use codex_api_client::stream_from_fixture; use codex_otel::otel_event_manager::OtelEventManager; use codex_protocol::ConversationId; use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig; use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig; -use codex_protocol::models::ResponseItem; use codex_protocol::protocol::SessionSource; -use eventsource_stream::Eventsource; -use futures::prelude::*; -use regex_lite::Regex; +use futures::StreamExt; +use futures::stream::BoxStream; use reqwest::StatusCode; -use reqwest::header::HeaderMap; -use serde::Deserialize; -use serde::Serialize; -use serde_json::Value; +use tokio::sync::OnceCell; use tokio::sync::mpsc; -use tokio::time::timeout; -use tokio_util::io::ReaderStream; -use tracing::debug; -use tracing::trace; use tracing::warn; use crate::AuthManager; -use crate::auth::CodexAuth; -use 
crate::auth::RefreshTokenError;
-use crate::chat_completions::AggregateStreamExt;
-use crate::chat_completions::stream_chat_completions;
 use crate::client_common::Prompt;
 use crate::client_common::ResponseEvent;
 use crate::client_common::ResponseStream;
-use crate::client_common::ResponsesApiRequest;
 use crate::client_common::create_reasoning_param_for_request;
 use crate::client_common::create_text_param_for_request;
 use crate::config::Config;
-use crate::default_client::CodexHttpClient;
 use crate::default_client::create_client;
 use crate::error::CodexErr;
 use crate::error::ConnectionFailedError;
+use crate::error::EnvVarError;
 use crate::error::ResponseStreamFailed;
 use crate::error::Result;
 use crate::error::RetryLimitReachedError;
 use crate::error::UnexpectedResponseError;
-use crate::error::UsageLimitReachedError;
 use crate::flags::CODEX_RS_SSE_FIXTURE;
 use crate::model_family::ModelFamily;
-use crate::model_provider_info::ModelProviderInfo;
-use crate::model_provider_info::WireApi;
 use crate::openai_model_info::get_model_info;
-use crate::protocol::RateLimitSnapshot;
-use crate::protocol::RateLimitWindow;
-use crate::protocol::TokenUsage;
-use crate::token_data::PlanType;
+use crate::tools::spec::create_tools_json_for_chat_completions_api;
 use crate::tools::spec::create_tools_json_for_responses_api;
-use crate::util::backoff;
 
-#[derive(Debug, Deserialize)]
-struct ErrorResponse {
-    error: Error,
-}
-
-#[derive(Debug, Deserialize)]
-struct Error {
-    r#type: Option<String>,
-    code: Option<String>,
-    message: Option<String>,
-
-    // Optional fields available on "usage_limit_reached" and "usage_not_included" errors
-    plan_type: Option<PlanType>,
-    resets_at: Option<i64>,
-}
-
-#[derive(Debug, Clone)]
+#[derive(Clone)]
 pub struct ModelClient {
     config: Arc<Config>,
     auth_manager: Option<Arc<AuthManager>>,
     otel_event_manager: OtelEventManager,
-    client: CodexHttpClient,
     provider: ModelProviderInfo,
+    backend: Arc<OnceCell<ModelBackend>>,
     conversation_id: ConversationId,
     effort: Option<ReasoningEffortConfig>,
     summary: ReasoningSummaryConfig,
     session_source: SessionSource,
 }
 
+impl fmt::Debug for ModelClient {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("ModelClient")
+            .field("provider", &self.provider.name)
+            .field("model", &self.config.model)
+            .field("conversation_id", &self.conversation_id)
+            .field("backend_initialized", &self.backend.get().is_some())
+            .finish()
+    }
+}
+
+type ApiClientStream = BoxStream<'static, ApiClientResult<codex_api_client::ResponseEvent>>;
+
+enum ModelBackend {
+    Responses(ResponsesBackend),
+    Chat(ChatBackend),
+}
+
+impl ModelBackend {
+    async fn stream(&self, prompt: codex_api_client::Prompt) -> ApiClientResult<ApiClientStream> {
+        match self {
+            ModelBackend::Responses(backend) => backend.stream(prompt).await,
+            ModelBackend::Chat(backend) => backend.stream(prompt).await,
+        }
+    }
+}
+
+struct ResponsesBackend {
+    client: ResponsesApiClient,
+}
+
+impl ResponsesBackend {
+    async fn stream(&self, prompt: codex_api_client::Prompt) -> ApiClientResult<ApiClientStream> {
+        self.client
+            .stream(prompt)
+            .await
+            .map(futures::StreamExt::boxed)
+    }
+}
+
+struct ChatBackend {
+    client: ChatCompletionsApiClient,
+    show_reasoning: bool,
+}
+
+impl ChatBackend {
+    async fn stream(&self, prompt: codex_api_client::Prompt) -> ApiClientResult<ApiClientStream> {
+        let stream = self.client.stream(prompt).await?;
+        let stream = if self.show_reasoning {
+            stream.streaming_mode().boxed()
+        } else {
+            stream.aggregate().boxed()
+        };
+        Ok(stream)
+    }
+}
+
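+// Dispatch sketch (hedged pseudocode; the real flow is `ModelClient::stream`
+// below): the backend is built once and cached in the `OnceCell`, after which
+// every turn reduces to roughly
+//
+//   let backend = self.backend.get_or_try_init(|| self.build_backend()).await?;
+//   let events = backend.stream(api_prompt).await?; // boxed event stream
+//
+// with the Chat variant additionally choosing streaming vs aggregated
+// delivery from `show_raw_agent_reasoning`.
+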
 #[allow(clippy::too_many_arguments)]
 impl ModelClient {
     pub fn new(
@@ -103,14 +130,14 @@ impl ModelClient {
         conversation_id: ConversationId,
         session_source: SessionSource,
     ) -> Self {
-        let client = create_client();
+        let backend = Arc::new(OnceCell::new());
         Self {
             config,
             auth_manager,
             otel_event_manager,
-            client,
             provider,
+            backend,
             conversation_id,
             effort,
             summary,
@@ -123,7 +150,7 @@ impl ModelClient {
         self.config
            .model_context_window
            .or_else(|| get_model_info(&self.config.model_family).map(|info| info.context_window))
-            .map(|w| w.saturating_mul(pct) / 100)
+            .map(|wid| wid.saturating_mul(pct) / 100)
     }
 
     pub fn get_auto_compact_token_limit(&self) -> Option<i64> {
@@ -141,79 +168,47 @@ impl ModelClient {
     }
 
     pub async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
-        match self.provider.wire_api {
-            WireApi::Responses => self.stream_responses(prompt).await,
-            WireApi::Chat => {
-                // Create the raw streaming connection first.
-                let response_stream = stream_chat_completions(
-                    prompt,
-                    &self.config.model_family,
-                    &self.client,
-                    &self.provider,
-                    &self.otel_event_manager,
-                    &self.session_source,
-                )
-                .await?;
-
-                // Wrap it with the aggregation adapter so callers see *only*
-                // the final assistant message per turn (matching the
-                // behaviour of the Responses API).
-                let mut aggregated = if self.config.show_raw_agent_reasoning {
-                    crate::chat_completions::AggregatedChatStream::streaming_mode(response_stream)
-                } else {
-                    response_stream.aggregate()
-                };
-
-                // Bridge the aggregated stream back into a standard
-                // `ResponseStream` by forwarding events through a channel.
-                let (tx, rx) = mpsc::channel::<Result<ResponseEvent>>(16);
-
-                tokio::spawn(async move {
-                    use futures::StreamExt;
-                    while let Some(ev) = aggregated.next().await {
-                        // Exit early if receiver hung up.
-                        if tx.send(ev).await.is_err() {
-                            break;
-                        }
-                    }
-                });
-
-                Ok(ResponseStream { rx_event: rx })
-            }
+        let api_prompt = self.build_api_prompt(prompt)?;
+        if self.provider.wire_api == WireApi::Responses
+            && let Some(path) = &*CODEX_RS_SSE_FIXTURE
+        {
+            warn!(path, "Streaming from fixture");
+            let stream =
+                stream_from_fixture(path, self.provider.clone(), self.otel_event_manager.clone())
+                    .await
+                    .map_err(map_api_error)?
+                    .boxed();
+            return Ok(wrap_stream(stream));
         }
+
+        let backend = self
+            .backend
+            .get_or_try_init(|| async { self.build_backend().await })
+            .await
+            .map_err(map_api_error)?;
+
+        let api_stream = backend.stream(api_prompt).await.map_err(map_api_error)?;
+
+        Ok(wrap_stream(api_stream))
     }
 
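+    // `wrap_stream` is assumed to bridge the boxed api-client stream back into
+    // core's `ResponseStream` the same way the old Chat path did: forward each
+    // event over an mpsc channel and hand the receiver to `ResponseStream`
+    // (sketch only; the helper's definition sits outside this hunk):
+    //
+    //   let (tx, rx) = mpsc::channel(16);
+    //   tokio::spawn(async move {
+    //       while let Some(ev) = stream.next().await {
+    //           if tx.send(map_event(ev)).await.is_err() { break; }
+    //       }
+    //   });
+    //   ResponseStream { rx_event: rx }
+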
-    /// Implementation for the OpenAI *Responses* experimental API.
-    async fn stream_responses(&self, prompt: &Prompt) -> Result<ResponseStream> {
-        if let Some(path) = &*CODEX_RS_SSE_FIXTURE {
-            // short circuit for tests
-            warn!(path, "Streaming from fixture");
-            return stream_from_fixture(
-                path,
-                self.provider.clone(),
-                self.otel_event_manager.clone(),
-            )
-            .await;
-        }
+    fn build_api_prompt(&self, prompt: &Prompt) -> Result<codex_api_client::Prompt> {
+        let instructions = prompt
+            .get_full_instructions(&self.config.model_family)
+            .into_owned();
+        let input = prompt.get_formatted_input();
 
-        let auth_manager = self.auth_manager.clone();
+        let tools = match self.provider.wire_api {
+            WireApi::Responses => create_tools_json_for_responses_api(&prompt.tools)?,
+            WireApi::Chat => create_tools_json_for_chat_completions_api(&prompt.tools)?,
+        };
 
-        let full_instructions = prompt.get_full_instructions(&self.config.model_family);
-        let tools_json = create_tools_json_for_responses_api(&prompt.tools)?;
         let reasoning = create_reasoning_param_for_request(
             &self.config.model_family,
             self.effort,
             self.summary,
         );
 
-        let include: Vec<String> = if reasoning.is_some() {
-            vec!["reasoning.encrypted_content".to_string()]
-        } else {
-            vec![]
-        };
-
-        let input_with_instructions = prompt.get_formatted_input();
-
         let verbosity = if self.config.model_family.support_verbosity {
             self.config.model_verbosity
         } else {
@@ -226,243 +221,68 @@ impl ModelClient {
             None
         };
 
-        // Only include `text.verbosity` for GPT-5 family models
-        let text = create_text_param_for_request(verbosity, &prompt.output_schema);
+        let text_controls = create_text_param_for_request(verbosity, &prompt.output_schema);
 
-        // In general, we want to explicitly send `store: false` when using the Responses API,
-        // but in practice, the Azure Responses API rejects `store: false`:
-        //
-        // - If store = false and id is sent an error is thrown that ID is not found
-        // - If store = false and id is not sent an error is thrown that ID is required
-        //
-        // For Azure, we send `store: true` and preserve reasoning item IDs.
-        let azure_workaround = self.provider.is_azure_responses_endpoint();
-
-        let payload = ResponsesApiRequest {
-            model: &self.config.model,
-            instructions: &full_instructions,
-            input: &input_with_instructions,
-            tools: &tools_json,
-            tool_choice: "auto",
+        Ok(codex_api_client::Prompt {
+            instructions,
+            input,
+            tools,
             parallel_tool_calls: prompt.parallel_tool_calls,
+            output_schema: prompt.output_schema.clone(),
             reasoning,
-            store: azure_workaround,
-            stream: true,
-            include,
+            text_controls,
             prompt_cache_key: Some(self.conversation_id.to_string()),
-            text,
-        };
-
-        let mut payload_json = serde_json::to_value(&payload)?;
-        if azure_workaround {
-            attach_item_ids(&mut payload_json, &input_with_instructions);
-        }
-
-        let max_attempts = self.provider.request_max_retries();
-        for attempt in 0..=max_attempts {
-            match self
-                .attempt_stream_responses(attempt, &payload_json, &auth_manager)
-                .await
-            {
-                Ok(stream) => {
-                    return Ok(stream);
-                }
-                Err(StreamAttemptError::Fatal(e)) => {
-                    return Err(e);
-                }
-                Err(retryable_attempt_error) => {
-                    if attempt == max_attempts {
-                        return Err(retryable_attempt_error.into_error());
-                    }
-
-                    tokio::time::sleep(retryable_attempt_error.delay(attempt)).await;
-                }
-            }
-        }
-
-        unreachable!("stream_responses_attempt should always return");
+            session_source: Some(self.session_source.clone()),
+        })
     }
 
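+    // Net effect (illustrative): both wire APIs now funnel into one
+    // `codex_api_client::Prompt { instructions, input, tools, reasoning,
+    // text_controls, prompt_cache_key, session_source, .. }` value; only the
+    // `tools` serialization above differs between Responses and Chat.
+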
-    /// Single attempt to start a streaming Responses API call.
-    async fn attempt_stream_responses(
-        &self,
-        attempt: u64,
-        payload_json: &Value,
-        auth_manager: &Option<Arc<AuthManager>>,
-    ) -> std::result::Result<ResponseStream, StreamAttemptError> {
-        // Always fetch the latest auth in case a prior attempt refreshed the token.
-        let auth = auth_manager.as_ref().and_then(|m| m.auth());
+    async fn build_backend(&self) -> ApiClientResult<ModelBackend> {
+        match self.provider.wire_api {
+            WireApi::Responses => self.build_responses_backend().await,
+            WireApi::Chat => self.build_chat_backend().await,
+        }
+    }
 
-        trace!(
-            "POST to {}: {:?}",
-            self.provider.get_full_url(&auth),
-            serde_json::to_string(payload_json)
-                .unwrap_or("".to_string())
-        );
+    async fn build_responses_backend(&self) -> ApiClientResult<ModelBackend> {
+        let auth_provider = self.auth_manager.as_ref().map(|manager| {
+            Arc::new(AuthManagerProvider::new(Arc::clone(manager))) as Arc<dyn AuthProvider>
+        });
 
-        let mut req_builder = self
-            .provider
-            .create_request_builder(&self.client, &auth)
-            .await
-            .map_err(StreamAttemptError::Fatal)?;
+        let http_client = create_client().clone_inner();
+        let config = ResponsesApiClientConfig {
+            http_client,
+            provider: self.provider.clone(),
+            model: self.config.model.clone(),
+            conversation_id: self.conversation_id,
+            auth_provider,
+            otel_event_manager: self.otel_event_manager.clone(),
+        };
 
-        // Include subagent header only for subagent sessions.
-        if let SessionSource::SubAgent(sub) = &self.session_source {
-            let subagent = if let crate::protocol::SubAgentSource::Other(label) = sub {
-                label.clone()
+        let client = ResponsesApiClient::new(config).await?;
+        Ok(ModelBackend::Responses(ResponsesBackend { client }))
+    }
+
+    async fn build_chat_backend(&self) -> ApiClientResult<ModelBackend> {
+        let show_reasoning = self.config.show_raw_agent_reasoning;
+        let http_client = create_client().clone_inner();
+        let config = ChatCompletionsApiClientConfig {
+            http_client,
+            provider: self.provider.clone(),
+            model: self.config.model.clone(),
+            otel_event_manager: self.otel_event_manager.clone(),
+            session_source: self.session_source.clone(),
+            aggregation_mode: if show_reasoning {
+                ChatAggregationMode::Streaming
             } else {
-                serde_json::to_value(sub)
-                    .ok()
-                    .and_then(|v| v.as_str().map(std::string::ToString::to_string))
-                    .unwrap_or_else(|| "other".to_string())
-            };
-            req_builder = req_builder.header("x-openai-subagent", subagent);
-        }
+                ChatAggregationMode::AggregatedOnly
+            },
+        };
 
-        req_builder = req_builder
-            // Send session_id for compatibility.
- .header("conversation_id", self.conversation_id.to_string()) - .header("session_id", self.conversation_id.to_string()) - .header(reqwest::header::ACCEPT, "text/event-stream") - .json(payload_json); - - if let Some(auth) = auth.as_ref() - && auth.mode == AuthMode::ChatGPT - && let Some(account_id) = auth.get_account_id() - { - req_builder = req_builder.header("chatgpt-account-id", account_id); - } - - let res = self - .otel_event_manager - .log_request(attempt, || req_builder.send()) - .await; - - let mut request_id = None; - if let Ok(resp) = &res { - request_id = resp - .headers() - .get("cf-ray") - .map(|v| v.to_str().unwrap_or_default().to_string()); - } - - match res { - Ok(resp) if resp.status().is_success() => { - let (tx_event, rx_event) = mpsc::channel::>(1600); - - if let Some(snapshot) = parse_rate_limit_snapshot(resp.headers()) - && tx_event - .send(Ok(ResponseEvent::RateLimits(snapshot))) - .await - .is_err() - { - debug!("receiver dropped rate limit snapshot event"); - } - - // spawn task to process SSE - let stream = resp.bytes_stream().map_err(move |e| { - CodexErr::ResponseStreamFailed(ResponseStreamFailed { - source: e, - request_id: request_id.clone(), - }) - }); - tokio::spawn(process_sse( - stream, - tx_event, - self.provider.stream_idle_timeout(), - self.otel_event_manager.clone(), - )); - - Ok(ResponseStream { rx_event }) - } - Ok(res) => { - let status = res.status(); - - // Pull out Retry‑After header if present. - let retry_after_secs = res - .headers() - .get(reqwest::header::RETRY_AFTER) - .and_then(|v| v.to_str().ok()) - .and_then(|s| s.parse::().ok()); - let retry_after = retry_after_secs.map(|s| Duration::from_millis(s * 1_000)); - - if status == StatusCode::UNAUTHORIZED - && let Some(manager) = auth_manager.as_ref() - && let Some(auth) = auth.as_ref() - && auth.mode == AuthMode::ChatGPT - && let Err(err) = manager.refresh_token().await - { - let stream_error = match err { - RefreshTokenError::Permanent(failed) => { - StreamAttemptError::Fatal(CodexErr::RefreshTokenFailed(failed)) - } - RefreshTokenError::Transient(other) => { - StreamAttemptError::RetryableTransportError(CodexErr::Io(other)) - } - }; - return Err(stream_error); - } - - // The OpenAI Responses endpoint returns structured JSON bodies even for 4xx/5xx - // errors. When we bubble early with only the HTTP status the caller sees an opaque - // "unexpected status 400 Bad Request" which makes debugging nearly impossible. - // Instead, read (and include) the response text so higher layers and users see the - // exact error message (e.g. "Unknown parameter: 'input[0].metadata'"). The body is - // small and this branch only runs on error paths so the extra allocation is - // negligible. - if !(status == StatusCode::TOO_MANY_REQUESTS - || status == StatusCode::UNAUTHORIZED - || status.is_server_error()) - { - // Surface the error body to callers. Use `unwrap_or_default` per Clippy. - let body = res.text().await.unwrap_or_default(); - return Err(StreamAttemptError::Fatal(CodexErr::UnexpectedStatus( - UnexpectedResponseError { - status, - body, - request_id: None, - }, - ))); - } - - if status == StatusCode::TOO_MANY_REQUESTS { - let rate_limit_snapshot = parse_rate_limit_snapshot(res.headers()); - let body = res.json::().await.ok(); - if let Some(ErrorResponse { error }) = body { - if error.r#type.as_deref() == Some("usage_limit_reached") { - // Prefer the plan_type provided in the error message if present - // because it's more up to date than the one encoded in the auth - // token. 
- let plan_type = error - .plan_type - .or_else(|| auth.as_ref().and_then(CodexAuth::get_plan_type)); - let resets_at = error - .resets_at - .and_then(|seconds| DateTime::::from_timestamp(seconds, 0)); - let codex_err = CodexErr::UsageLimitReached(UsageLimitReachedError { - plan_type, - resets_at, - rate_limits: rate_limit_snapshot, - }); - return Err(StreamAttemptError::Fatal(codex_err)); - } else if error.r#type.as_deref() == Some("usage_not_included") { - return Err(StreamAttemptError::Fatal(CodexErr::UsageNotIncluded)); - } else if is_quota_exceeded_error(&error) { - return Err(StreamAttemptError::Fatal(CodexErr::QuotaExceeded)); - } - } - } - - Err(StreamAttemptError::RetryableHttpError { - status, - retry_after, - request_id, - }) - } - Err(e) => Err(StreamAttemptError::RetryableTransportError( - CodexErr::ConnectionFailed(ConnectionFailedError { source: e }), - )), - } + let client = ChatCompletionsApiClient::new(config).await?; + Ok(ModelBackend::Chat(ChatBackend { + client, + show_reasoning, + })) } pub fn get_provider(&self) -> ModelProviderInfo { @@ -477,22 +297,18 @@ impl ModelClient { self.session_source.clone() } - /// Returns the currently configured model slug. pub fn get_model(&self) -> String { self.config.model.clone() } - /// Returns the currently configured model family. pub fn get_model_family(&self) -> ModelFamily { self.config.model_family.clone() } - /// Returns the current reasoning effort setting. pub fn get_reasoning_effort(&self) -> Option { self.effort } - /// Returns the current reasoning summary setting. pub fn get_reasoning_summary(&self) -> ReasoningSummaryConfig { self.summary } @@ -502,1022 +318,204 @@ impl ModelClient { } } -enum StreamAttemptError { - RetryableHttpError { - status: StatusCode, - retry_after: Option, - request_id: Option, - }, - RetryableTransportError(CodexErr), - Fatal(CodexErr), +struct AuthManagerProvider { + manager: Arc, } -impl StreamAttemptError { - /// attempt is 0-based. - fn delay(&self, attempt: u64) -> Duration { - // backoff() uses 1-based attempts. - let backoff_attempt = attempt + 1; - match self { - Self::RetryableHttpError { retry_after, .. } => { - retry_after.unwrap_or_else(|| backoff(backoff_attempt)) - } - Self::RetryableTransportError { .. } => backoff(backoff_attempt), - Self::Fatal(_) => { - // Should not be called on Fatal errors. - Duration::from_secs(0) - } - } - } - - fn into_error(self) -> CodexErr { - match self { - Self::RetryableHttpError { - status, request_id, .. 
- } => { - if status == StatusCode::INTERNAL_SERVER_ERROR { - CodexErr::InternalServerError - } else { - CodexErr::RetryLimit(RetryLimitReachedError { status, request_id }) - } - } - Self::RetryableTransportError(error) => error, - Self::Fatal(error) => error, - } +impl AuthManagerProvider { + fn new(manager: Arc) -> Self { + Self { manager } } } -#[derive(Debug, Deserialize, Serialize)] -struct SseEvent { - #[serde(rename = "type")] - kind: String, - response: Option, - item: Option, - delta: Option, -} - -#[derive(Debug, Deserialize)] -struct ResponseCompleted { - id: String, - usage: Option, -} - -#[derive(Debug, Deserialize)] -struct ResponseCompletedUsage { - input_tokens: i64, - input_tokens_details: Option, - output_tokens: i64, - output_tokens_details: Option, - total_tokens: i64, -} - -impl From for TokenUsage { - fn from(val: ResponseCompletedUsage) -> Self { - TokenUsage { - input_tokens: val.input_tokens, - cached_input_tokens: val - .input_tokens_details - .map(|d| d.cached_tokens) - .unwrap_or(0), - output_tokens: val.output_tokens, - reasoning_output_tokens: val - .output_tokens_details - .map(|d| d.reasoning_tokens) - .unwrap_or(0), - total_tokens: val.total_tokens, - } - } -} - -#[derive(Debug, Deserialize)] -struct ResponseCompletedInputTokensDetails { - cached_tokens: i64, -} - -#[derive(Debug, Deserialize)] -struct ResponseCompletedOutputTokensDetails { - reasoning_tokens: i64, -} - -fn attach_item_ids(payload_json: &mut Value, original_items: &[ResponseItem]) { - let Some(input_value) = payload_json.get_mut("input") else { - return; - }; - let serde_json::Value::Array(items) = input_value else { - return; - }; - - for (value, item) in items.iter_mut().zip(original_items.iter()) { - if let ResponseItem::Reasoning { id, .. } - | ResponseItem::Message { id: Some(id), .. } - | ResponseItem::WebSearchCall { id: Some(id), .. } - | ResponseItem::FunctionCall { id: Some(id), .. } - | ResponseItem::LocalShellCall { id: Some(id), .. } - | ResponseItem::CustomToolCall { id: Some(id), .. 
} = item - { - if id.is_empty() { - continue; +#[async_trait] +impl AuthProvider for AuthManagerProvider { + async fn auth_context(&self) -> Option { + let auth = self.manager.auth()?; + let mode = auth.mode; + let account_id = auth.get_account_id(); + let bearer_token = match auth.get_token().await { + Ok(token) if !token.is_empty() => Some(token), + Ok(_) => None, + Err(err) => { + warn!("failed to resolve auth token: {err}"); + None } + }; - if let Some(obj) = value.as_object_mut() { - obj.insert("id".to_string(), Value::String(id.clone())); - } - } - } -} - -fn parse_rate_limit_snapshot(headers: &HeaderMap) -> Option { - let primary = parse_rate_limit_window( - headers, - "x-codex-primary-used-percent", - "x-codex-primary-window-minutes", - "x-codex-primary-reset-at", - ); - - let secondary = parse_rate_limit_window( - headers, - "x-codex-secondary-used-percent", - "x-codex-secondary-window-minutes", - "x-codex-secondary-reset-at", - ); - - Some(RateLimitSnapshot { primary, secondary }) -} - -fn parse_rate_limit_window( - headers: &HeaderMap, - used_percent_header: &str, - window_minutes_header: &str, - resets_at_header: &str, -) -> Option { - let used_percent: Option = parse_header_f64(headers, used_percent_header); - - used_percent.and_then(|used_percent| { - let window_minutes = parse_header_i64(headers, window_minutes_header); - let resets_at = parse_header_i64(headers, resets_at_header); - - let has_data = used_percent != 0.0 - || window_minutes.is_some_and(|minutes| minutes != 0) - || resets_at.is_some(); - - has_data.then_some(RateLimitWindow { - used_percent, - window_minutes, - resets_at, + Some(AuthContext { + mode, + bearer_token, + account_id, }) - }) -} + } -fn parse_header_f64(headers: &HeaderMap, name: &str) -> Option { - parse_header_str(headers, name)? - .parse::() - .ok() - .filter(|v| v.is_finite()) -} - -fn parse_header_i64(headers: &HeaderMap, name: &str) -> Option { - parse_header_str(headers, name)?.parse::().ok() -} - -fn parse_header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> { - headers.get(name)?.to_str().ok() -} - -async fn process_sse( - stream: S, - tx_event: mpsc::Sender>, - idle_timeout: Duration, - otel_event_manager: OtelEventManager, -) where - S: Stream> + Unpin, -{ - let mut stream = stream.eventsource(); - - // If the stream stays completely silent for an extended period treat it as disconnected. - // The response id returned from the "complete" message. 
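-    // For reference, the terminal event parsed below looks roughly like
-    // (illustrative, fields elided):
-    //   data: {"type":"response.completed","response":{"id":"resp_...","usage":{...}}}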
- let mut response_completed: Option = None; - let mut response_error: Option = None; - - loop { - let start = std::time::Instant::now(); - let response = timeout(idle_timeout, stream.next()).await; - let duration = start.elapsed(); - otel_event_manager.log_sse_event(&response, duration); - - let sse = match response { - Ok(Some(Ok(sse))) => sse, - Ok(Some(Err(e))) => { - debug!("SSE Error: {e:#}"); - let event = CodexErr::Stream(e.to_string(), None); - let _ = tx_event.send(Err(event)).await; - return; - } - Ok(None) => { - match response_completed { - Some(ResponseCompleted { - id: response_id, - usage, - }) => { - if let Some(token_usage) = &usage { - otel_event_manager.sse_event_completed( - token_usage.input_tokens, - token_usage.output_tokens, - token_usage - .input_tokens_details - .as_ref() - .map(|d| d.cached_tokens), - token_usage - .output_tokens_details - .as_ref() - .map(|d| d.reasoning_tokens), - token_usage.total_tokens, - ); - } - let event = ResponseEvent::Completed { - response_id, - token_usage: usage.map(Into::into), - }; - let _ = tx_event.send(Ok(event)).await; - } - None => { - let error = response_error.unwrap_or(CodexErr::Stream( - "stream closed before response.completed".into(), - None, - )); - otel_event_manager.see_event_completed_failed(&error); - - let _ = tx_event.send(Err(error)).await; - } - } - return; - } - Err(_) => { - let _ = tx_event - .send(Err(CodexErr::Stream( - "idle timeout waiting for SSE".into(), - None, - ))) - .await; - return; - } - }; - - let raw = sse.data.clone(); - trace!("SSE event: {}", raw); - - let event: SseEvent = match serde_json::from_str(&sse.data) { - Ok(event) => event, - Err(e) => { - debug!("Failed to parse SSE event: {e}, data: {}", &sse.data); - continue; - } - }; - - match event.kind.as_str() { - // Individual output item finalised. Forward immediately so the - // rest of the agent can stream assistant text/functions *live* - // instead of waiting for the final `response.completed` envelope. - // - // IMPORTANT: We used to ignore these events and forward the - // duplicated `output` array embedded in the `response.completed` - // payload. That produced two concrete issues: - // 1. No real‑time streaming – the user only saw output after the - // entire turn had finished, which broke the "typing" UX and - // made long‑running turns look stalled. - // 2. Duplicate `function_call_output` items – both the - // individual *and* the completed array were forwarded, which - // confused the backend and triggered 400 - // "previous_response_not_found" errors because the duplicated - // IDs did not match the incremental turn chain. - // - // The fix is to forward the incremental events *as they come* and - // drop the duplicated list inside `response.completed`. 
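-            // Typical per-turn ordering, for orientation (illustrative):
-            //   response.created
-            //   response.output_item.added
-            //   response.output_text.delta  (repeated)
-            //   response.output_item.done
-            //   response.completed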
- "response.output_item.done" => { - let Some(item_val) = event.item else { continue }; - let Ok(item) = serde_json::from_value::(item_val) else { - debug!("failed to parse ResponseItem from output_item.done"); - continue; - }; - - let event = ResponseEvent::OutputItemDone(item); - if tx_event.send(Ok(event)).await.is_err() { - return; - } - } - "response.output_text.delta" => { - if let Some(delta) = event.delta { - let event = ResponseEvent::OutputTextDelta(delta); - if tx_event.send(Ok(event)).await.is_err() { - return; - } - } - } - "response.reasoning_summary_text.delta" => { - if let Some(delta) = event.delta { - let event = ResponseEvent::ReasoningSummaryDelta(delta); - if tx_event.send(Ok(event)).await.is_err() { - return; - } - } - } - "response.reasoning_text.delta" => { - if let Some(delta) = event.delta { - let event = ResponseEvent::ReasoningContentDelta(delta); - if tx_event.send(Ok(event)).await.is_err() { - return; - } - } - } - "response.created" => { - if event.response.is_some() { - let _ = tx_event.send(Ok(ResponseEvent::Created {})).await; - } - } - "response.failed" => { - if let Some(resp_val) = event.response { - response_error = Some(CodexErr::Stream( - "response.failed event received".to_string(), - None, - )); - - let error = resp_val.get("error"); - - if let Some(error) = error { - match serde_json::from_value::(error.clone()) { - Ok(error) => { - if is_context_window_error(&error) { - response_error = Some(CodexErr::ContextWindowExceeded); - } else if is_quota_exceeded_error(&error) { - response_error = Some(CodexErr::QuotaExceeded); - } else { - let delay = try_parse_retry_after(&error); - let message = error.message.clone().unwrap_or_default(); - response_error = Some(CodexErr::Stream(message, delay)); - } - } - Err(e) => { - let error = format!("failed to parse ErrorResponse: {e}"); - debug!(error); - response_error = Some(CodexErr::Stream(error, None)) - } - } - } - } - } - // Final response completed – includes array of output items & id - "response.completed" => { - if let Some(resp_val) = event.response { - match serde_json::from_value::(resp_val) { - Ok(r) => { - response_completed = Some(r); - } - Err(e) => { - let error = format!("failed to parse ResponseCompleted: {e}"); - debug!(error); - response_error = Some(CodexErr::Stream(error, None)); - continue; - } - }; - }; - } - "response.content_part.done" - | "response.function_call_arguments.delta" - | "response.custom_tool_call_input.delta" - | "response.custom_tool_call_input.done" // also emitted as response.output_item.done - | "response.in_progress" - | "response.output_text.done" => {} - "response.output_item.added" => { - let Some(item_val) = event.item else { continue }; - let Ok(item) = serde_json::from_value::(item_val) else { - debug!("failed to parse ResponseItem from output_item.done"); - continue; - }; - - let event = ResponseEvent::OutputItemAdded(item); - if tx_event.send(Ok(event)).await.is_err() { - return; - } - } - "response.reasoning_summary_part.added" => { - // Boundary between reasoning summary sections (e.g., titles). 
- let event = ResponseEvent::ReasoningSummaryPartAdded; - if tx_event.send(Ok(event)).await.is_err() { - return; - } - } - "response.reasoning_summary_text.done" => {} - _ => {} - } + async fn refresh_token(&self) -> std::result::Result, String> { + self.manager + .refresh_token() + .await + .map_err(|err| err.to_string()) } } -/// used in tests to stream from a text SSE file -async fn stream_from_fixture( - path: impl AsRef, - provider: ModelProviderInfo, - otel_event_manager: OtelEventManager, -) -> Result { - let (tx_event, rx_event) = mpsc::channel::>(1600); - let f = std::fs::File::open(path.as_ref())?; - let lines = std::io::BufReader::new(f).lines(); - - // insert \n\n after each line for proper SSE parsing - let mut content = String::new(); - for line in lines { - content.push_str(&line?); - content.push_str("\n\n"); - } - - let rdr = std::io::Cursor::new(content); - let stream = ReaderStream::new(rdr).map_err(CodexErr::Io); - tokio::spawn(process_sse( - stream, - tx_event, - provider.stream_idle_timeout(), - otel_event_manager, - )); - Ok(ResponseStream { rx_event }) -} - -fn rate_limit_regex() -> &'static Regex { - static RE: OnceLock = OnceLock::new(); - - // Match both OpenAI-style messages like "Please try again in 1.898s" - // and Azure OpenAI-style messages like "Try again in 35 seconds". - #[expect(clippy::unwrap_used)] - RE.get_or_init(|| Regex::new(r"(?i)try again in\s*(\d+(?:\.\d+)?)\s*(s|ms|seconds?)").unwrap()) -} - -fn try_parse_retry_after(err: &Error) -> Option { - if err.code != Some("rate_limit_exceeded".to_string()) { - return None; - } - - // parse retry hints like "try again in 1.898s" or - // "Try again in 35 seconds" using regex - let re = rate_limit_regex(); - if let Some(message) = &err.message - && let Some(captures) = re.captures(message) - { - let seconds = captures.get(1); - let unit = captures.get(2); - - if let (Some(value), Some(unit)) = (seconds, unit) { - let value = value.as_str().parse::().ok()?; - let unit = unit.as_str().to_ascii_lowercase(); - - if unit == "s" || unit.starts_with("second") { - return Some(Duration::from_secs_f64(value)); - } else if unit == "ms" { - return Some(Duration::from_millis(value as u64)); - } - } - } - None -} - -fn is_context_window_error(error: &Error) -> bool { - error.code.as_deref() == Some("context_length_exceeded") -} - -fn is_quota_exceeded_error(error: &Error) -> bool { - error.code.as_deref() == Some("insufficient_quota") -} - -#[cfg(test)] -mod tests { - use super::*; - use assert_matches::assert_matches; - use serde_json::json; - use tokio::sync::mpsc; - use tokio_test::io::Builder as IoBuilder; - use tokio_util::io::ReaderStream; - - // ──────────────────────────── - // Helpers - // ──────────────────────────── - - /// Runs the SSE parser on pre-chunked byte slices and returns every event - /// (including any final `Err` from a stream-closure check). 
- async fn collect_events( - chunks: &[&[u8]], - provider: ModelProviderInfo, - otel_event_manager: OtelEventManager, - ) -> Vec> { - let mut builder = IoBuilder::new(); - for chunk in chunks { - builder.read(chunk); - } - - let reader = builder.build(); - let stream = ReaderStream::new(reader).map_err(CodexErr::Io); - let (tx, mut rx) = mpsc::channel::>(16); - tokio::spawn(process_sse( - stream, - tx, - provider.stream_idle_timeout(), - otel_event_manager, - )); - - let mut events = Vec::new(); - while let Some(ev) = rx.recv().await { - events.push(ev); - } - events - } - - /// Builds an in-memory SSE stream from JSON fixtures and returns only the - /// successfully parsed events (panics on internal channel errors). - async fn run_sse( - events: Vec, - provider: ModelProviderInfo, - otel_event_manager: OtelEventManager, - ) -> Vec { - let mut body = String::new(); - for e in events { - let kind = e - .get("type") - .and_then(|v| v.as_str()) - .expect("fixture event missing type"); - if e.as_object().map(|o| o.len() == 1).unwrap_or(false) { - body.push_str(&format!("event: {kind}\n\n")); - } else { - body.push_str(&format!("event: {kind}\ndata: {e}\n\n")); - } - } - - let (tx, mut rx) = mpsc::channel::>(8); - let stream = ReaderStream::new(std::io::Cursor::new(body)).map_err(CodexErr::Io); - tokio::spawn(process_sse( - stream, - tx, - provider.stream_idle_timeout(), - otel_event_manager, - )); - - let mut out = Vec::new(); - while let Some(ev) = rx.recv().await { - out.push(ev.expect("channel closed")); - } - out - } - - fn otel_event_manager() -> OtelEventManager { - OtelEventManager::new( - ConversationId::new(), - "test", - "test", - None, - Some("test@test.com".to_string()), - Some(AuthMode::ChatGPT), - false, - "test".to_string(), - ) - } - - // ──────────────────────────── - // Tests from `implement-test-for-responses-api-sse-parser` - // ──────────────────────────── - - #[tokio::test] - async fn parses_items_and_completed() { - let item1 = json!({ - "type": "response.output_item.done", - "item": { - "type": "message", - "role": "assistant", - "content": [{"type": "output_text", "text": "Hello"}] - } - }) - .to_string(); - - let item2 = json!({ - "type": "response.output_item.done", - "item": { - "type": "message", - "role": "assistant", - "content": [{"type": "output_text", "text": "World"}] - } - }) - .to_string(); - - let completed = json!({ - "type": "response.completed", - "response": { "id": "resp1" } - }) - .to_string(); - - let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n"); - let sse2 = format!("event: response.output_item.done\ndata: {item2}\n\n"); - let sse3 = format!("event: response.completed\ndata: {completed}\n\n"); - - let provider = ModelProviderInfo { - name: "test".to_string(), - base_url: Some("https://test.com".to_string()), - env_key: Some("TEST_API_KEY".to_string()), - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: Some(0), - stream_max_retries: Some(0), - stream_idle_timeout_ms: Some(1000), - requires_openai_auth: false, - }; - - let otel_event_manager = otel_event_manager(); - - let events = collect_events( - &[sse1.as_bytes(), sse2.as_bytes(), sse3.as_bytes()], - provider, - otel_event_manager, - ) - .await; - - assert_eq!(events.len(), 3); - - matches!( - &events[0], - Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. 
})) - if role == "assistant" - ); - - matches!( - &events[1], - Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. })) - if role == "assistant" - ); - - match &events[2] { - Ok(ResponseEvent::Completed { - response_id, - token_usage, - }) => { - assert_eq!(response_id, "resp1"); - assert!(token_usage.is_none()); - } - other => panic!("unexpected third event: {other:?}"), - } - } - - #[tokio::test] - async fn error_when_missing_completed() { - let item1 = json!({ - "type": "response.output_item.done", - "item": { - "type": "message", - "role": "assistant", - "content": [{"type": "output_text", "text": "Hello"}] - } - }) - .to_string(); - - let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n"); - let provider = ModelProviderInfo { - name: "test".to_string(), - base_url: Some("https://test.com".to_string()), - env_key: Some("TEST_API_KEY".to_string()), - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: Some(0), - stream_max_retries: Some(0), - stream_idle_timeout_ms: Some(1000), - requires_openai_auth: false, - }; - - let otel_event_manager = otel_event_manager(); - - let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await; - - assert_eq!(events.len(), 2); - - matches!(events[0], Ok(ResponseEvent::OutputItemDone(_))); - - match &events[1] { - Err(CodexErr::Stream(msg, _)) => { - assert_eq!(msg, "stream closed before response.completed") - } - other => panic!("unexpected second event: {other:?}"), - } - } - - #[tokio::test] - async fn error_when_error_event() { - let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_689bcf18d7f08194bf3440ba62fe05d803fee0cdac429894","object":"response","created_at":1755041560,"status":"failed","background":false,"error":{"code":"rate_limit_exceeded","message":"Rate limit reached for gpt-5 in organization org-AAA on tokens per min (TPM): Limit 30000, Used 22999, Requested 12528. Please try again in 11.054s. Visit https://platform.openai.com/account/rate-limits to learn more."}, "usage":null,"user":null,"metadata":{}}}"#; - - let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n"); - let provider = ModelProviderInfo { - name: "test".to_string(), - base_url: Some("https://test.com".to_string()), - env_key: Some("TEST_API_KEY".to_string()), - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: Some(0), - stream_max_retries: Some(0), - stream_idle_timeout_ms: Some(1000), - requires_openai_auth: false, - }; - - let otel_event_manager = otel_event_manager(); - - let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await; - - assert_eq!(events.len(), 1); - - match &events[0] { - Err(CodexErr::Stream(msg, delay)) => { - assert_eq!( - msg, - "Rate limit reached for gpt-5 in organization org-AAA on tokens per min (TPM): Limit 30000, Used 22999, Requested 12528. Please try again in 11.054s. Visit https://platform.openai.com/account/rate-limits to learn more." 
- ); - assert_eq!(*delay, Some(Duration::from_secs_f64(11.054))); - } - other => panic!("unexpected second event: {other:?}"), - } - } - - #[tokio::test] - async fn context_window_error_is_fatal() { - let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_5c66275b97b9baef1ed95550adb3b7ec13b17aafd1d2f11b","object":"response","created_at":1759510079,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again."},"usage":null,"user":null,"metadata":{}}}"#; - - let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n"); - let provider = ModelProviderInfo { - name: "test".to_string(), - base_url: Some("https://test.com".to_string()), - env_key: Some("TEST_API_KEY".to_string()), - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: Some(0), - stream_max_retries: Some(0), - stream_idle_timeout_ms: Some(1000), - requires_openai_auth: false, - }; - - let otel_event_manager = otel_event_manager(); - - let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await; - - assert_eq!(events.len(), 1); - - match &events[0] { - Err(err @ CodexErr::ContextWindowExceeded) => { - assert_eq!(err.to_string(), CodexErr::ContextWindowExceeded.to_string()); - } - other => panic!("unexpected context window event: {other:?}"), - } - } - - #[tokio::test] - async fn context_window_error_with_newline_is_fatal() { - let raw_error = r#"{"type":"response.failed","sequence_number":4,"response":{"id":"resp_fatal_newline","object":"response","created_at":1759510080,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try\nagain."},"usage":null,"user":null,"metadata":{}}}"#; - - let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n"); - let provider = ModelProviderInfo { - name: "test".to_string(), - base_url: Some("https://test.com".to_string()), - env_key: Some("TEST_API_KEY".to_string()), - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: Some(0), - stream_max_retries: Some(0), - stream_idle_timeout_ms: Some(1000), - requires_openai_auth: false, - }; - - let otel_event_manager = otel_event_manager(); - - let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await; - - assert_eq!(events.len(), 1); - - match &events[0] { - Err(err @ CodexErr::ContextWindowExceeded) => { - assert_eq!(err.to_string(), CodexErr::ContextWindowExceeded.to_string()); - } - other => panic!("unexpected context window event: {other:?}"), - } - } - - #[tokio::test] - async fn quota_exceeded_error_is_fatal() { - let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_fatal_quota","object":"response","created_at":1759771626,"status":"failed","background":false,"error":{"code":"insufficient_quota","message":"You exceeded your current quota, please check your plan and billing details. 
For more information on this error, read the docs: https://platform.openai.com/docs/guides/error-codes/api-errors."},"incomplete_details":null}}"#; - - let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n"); - let provider = ModelProviderInfo { - name: "test".to_string(), - base_url: Some("https://test.com".to_string()), - env_key: Some("TEST_API_KEY".to_string()), - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: Some(0), - stream_max_retries: Some(0), - stream_idle_timeout_ms: Some(1000), - requires_openai_auth: false, - }; - - let otel_event_manager = otel_event_manager(); - - let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await; - - assert_eq!(events.len(), 1); - - match &events[0] { - Err(err @ CodexErr::QuotaExceeded) => { - assert_eq!(err.to_string(), CodexErr::QuotaExceeded.to_string()); - } - other => panic!("unexpected quota exceeded event: {other:?}"), - } - } - - // ──────────────────────────── - // Table-driven test from `main` - // ──────────────────────────── - - /// Verifies that the adapter produces the right `ResponseEvent` for a - /// variety of incoming `type` values. - #[tokio::test] - async fn table_driven_event_kinds() { - struct TestCase { - name: &'static str, - event: serde_json::Value, - expect_first: fn(&ResponseEvent) -> bool, - expected_len: usize, - } - - fn is_created(ev: &ResponseEvent) -> bool { - matches!(ev, ResponseEvent::Created) - } - fn is_output(ev: &ResponseEvent) -> bool { - matches!(ev, ResponseEvent::OutputItemDone(_)) - } - fn is_completed(ev: &ResponseEvent) -> bool { - matches!(ev, ResponseEvent::Completed { .. }) - } - - let completed = json!({ - "type": "response.completed", - "response": { - "id": "c", - "usage": { - "input_tokens": 0, - "input_tokens_details": null, - "output_tokens": 0, - "output_tokens_details": null, - "total_tokens": 0 - }, - "output": [] - } - }); - - let cases = vec![ - TestCase { - name: "created", - event: json!({"type": "response.created", "response": {}}), - expect_first: is_created, - expected_len: 2, - }, - TestCase { - name: "output_item.done", - event: json!({ - "type": "response.output_item.done", - "item": { - "type": "message", - "role": "assistant", - "content": [ - {"type": "output_text", "text": "hi"} - ] - } - }), - expect_first: is_output, - expected_len: 2, - }, - TestCase { - name: "unknown", - event: json!({"type": "response.new_tool_event"}), - expect_first: is_completed, - expected_len: 1, - }, - ]; - - for case in cases { - let mut evs = vec![case.event]; - evs.push(completed.clone()); - - let provider = ModelProviderInfo { - name: "test".to_string(), - base_url: Some("https://test.com".to_string()), - env_key: Some("TEST_API_KEY".to_string()), - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: Some(0), - stream_max_retries: Some(0), - stream_idle_timeout_ms: Some(1000), - requires_openai_auth: false, +fn wrap_stream(stream: ApiClientStream) -> ResponseStream { + let (tx, rx) = mpsc::channel::>(1600); + + tokio::spawn(async move { + let mut stream = stream; + while let Some(item) = stream.next().await { + let mapped = match item { + Ok(event) => Ok(event), + Err(err) => Err(map_api_error(err)), }; - let otel_event_manager = otel_event_manager(); - - let out = run_sse(evs, 
provider, otel_event_manager).await; - assert_eq!(out.len(), case.expected_len, "case {}", case.name); - assert!( - (case.expect_first)(&out[0]), - "first event mismatch in case {}", - case.name - ); + if tx.send(mapped).await.is_err() { + break; + } } - } + }); - #[test] - fn test_try_parse_retry_after() { - let err = Error { - r#type: None, - message: Some("Rate limit reached for gpt-5 in organization org- on tokens per min (TPM): Limit 1, Used 1, Requested 19304. Please try again in 28ms. Visit https://platform.openai.com/account/rate-limits to learn more.".to_string()), - code: Some("rate_limit_exceeded".to_string()), - plan_type: None, - resets_at: None - }; + codex_api_client::EventStream::from_receiver(rx) +} - let delay = try_parse_retry_after(&err); - assert_eq!(delay, Some(Duration::from_millis(28))); - } - - #[test] - fn test_try_parse_retry_after_no_delay() { - let err = Error { - r#type: None, - message: Some("Rate limit reached for gpt-5 in organization on tokens per min (TPM): Limit 30000, Used 6899, Requested 24050. Please try again in 1.898s. Visit https://platform.openai.com/account/rate-limits to learn more.".to_string()), - code: Some("rate_limit_exceeded".to_string()), - plan_type: None, - resets_at: None - }; - let delay = try_parse_retry_after(&err); - assert_eq!(delay, Some(Duration::from_secs_f64(1.898))); - } - - #[test] - fn test_try_parse_retry_after_azure() { - let err = Error { - r#type: None, - message: Some("Rate limit exceeded. Try again in 35 seconds.".to_string()), - code: Some("rate_limit_exceeded".to_string()), - plan_type: None, - resets_at: None, - }; - let delay = try_parse_retry_after(&err); - assert_eq!(delay, Some(Duration::from_secs(35))); - } - - #[test] - fn error_response_deserializes_schema_known_plan_type_and_serializes_back() { - use crate::token_data::KnownPlan; - use crate::token_data::PlanType; - - let json = - r#"{"error":{"type":"usage_limit_reached","plan_type":"pro","resets_at":1704067200}}"#; - let resp: ErrorResponse = serde_json::from_str(json).expect("should deserialize schema"); - - assert_matches!(resp.error.plan_type, Some(PlanType::Known(KnownPlan::Pro))); - - let plan_json = serde_json::to_string(&resp.error.plan_type).expect("serialize plan_type"); - assert_eq!(plan_json, "\"pro\""); - } - - #[test] - fn error_response_deserializes_schema_unknown_plan_type_and_serializes_back() { - use crate::token_data::PlanType; - - let json = - r#"{"error":{"type":"usage_limit_reached","plan_type":"vip","resets_at":1704067260}}"#; - let resp: ErrorResponse = serde_json::from_str(json).expect("should deserialize schema"); - - assert_matches!(resp.error.plan_type, Some(PlanType::Unknown(ref s)) if s == "vip"); - - let plan_json = serde_json::to_string(&resp.error.plan_type).expect("serialize plan_type"); - assert_eq!(plan_json, "\"vip\""); +fn map_api_error(err: codex_api_client::Error) -> CodexErr { + match err { + codex_api_client::Error::UnsupportedOperation(msg) => CodexErr::UnsupportedOperation(msg), + codex_api_client::Error::Http(source) => { + CodexErr::ConnectionFailed(ConnectionFailedError { source }) + } + codex_api_client::Error::ResponseStreamFailed { source, request_id } => { + CodexErr::ResponseStreamFailed(ResponseStreamFailed { source, request_id }) + } + codex_api_client::Error::Stream(message, delay) => CodexErr::Stream(message, delay), + codex_api_client::Error::UnexpectedStatus { status, body } => { + CodexErr::UnexpectedStatus(UnexpectedResponseError { + status, + body, + request_id: None, + }) + } + 
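+        // `status` is optional on the api-client side; the arm below falls
+        // back to 500 so RetryLimitReachedError always carries a concrete
+        // StatusCode.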
+        codex_api_client::Error::RetryLimit { status, request_id } => {
+            CodexErr::RetryLimit(RetryLimitReachedError {
+                status: status.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
+                request_id,
+            })
+        }
+        codex_api_client::Error::MissingEnvVar { var, instructions } => {
+            CodexErr::EnvVar(EnvVarError { var, instructions })
+        }
+        codex_api_client::Error::Auth(message) => CodexErr::Fatal(message),
+        codex_api_client::Error::Json(err) => CodexErr::Json(err),
+        codex_api_client::Error::Other(message) => CodexErr::Fatal(message),
+    }
+}
+
+/// Stream using the codex-api-client directly from a `TurnContext` without `ModelClient` indirection.
+pub async fn stream_for_turn(
+    ctx: &crate::codex::TurnContext,
+    prompt: &Prompt,
+) -> Result<ResponseStream> {
+    let instructions = prompt
+        .get_full_instructions(&ctx.client.get_model_family())
+        .into_owned();
+    let input = prompt.get_formatted_input();
+
+    let tools = match ctx.client.get_provider().wire_api {
+        WireApi::Responses => create_tools_json_for_responses_api(&prompt.tools)?,
+        WireApi::Chat => create_tools_json_for_chat_completions_api(&prompt.tools)?,
+    };
+
+    let reasoning = create_reasoning_param_for_request(
+        &ctx.client.get_model_family(),
+        ctx.client.get_reasoning_effort(),
+        ctx.client.get_reasoning_summary(),
+    );
+
+    let verbosity = if ctx.client.get_model_family().support_verbosity {
+        ctx.client.config().model_verbosity
+    } else {
+        if ctx.client.config().model_verbosity.is_some() {
+            warn!(
+                "model_verbosity is set but ignored as the model does not support verbosity: {}",
+                ctx.client.get_model_family().family
+            );
+        }
+        None
+    };
+
+    let text_controls = create_text_param_for_request(verbosity, &prompt.output_schema);
+
+    let api_prompt = codex_api_client::Prompt {
+        instructions,
+        input,
+        tools,
+        parallel_tool_calls: prompt.parallel_tool_calls,
+        output_schema: prompt.output_schema.clone(),
+        reasoning,
+        text_controls,
+        prompt_cache_key: Some(ctx.client.conversation_id.to_string()),
+        session_source: Some(ctx.client.get_session_source()),
+    };
+
+    if ctx.client.get_provider().wire_api == WireApi::Responses
+        && let Some(path) = &*CODEX_RS_SSE_FIXTURE
+    {
+        warn!(path, "Streaming from fixture");
+        let stream = stream_from_fixture(
+            path,
+            ctx.client.get_provider(),
+            ctx.client.get_otel_event_manager(),
+        )
+        .await
+        .map_err(map_api_error)?
+        .boxed();
+        return Ok(wrap_stream(stream));
+    }
+
+    let http_client = create_client().clone_inner();
+    let api_stream = match ctx.client.get_provider().wire_api {
+        WireApi::Responses => {
+            let auth_provider = ctx.client.get_auth_manager().as_ref().map(|m| {
+                Arc::new(AuthManagerProvider::new(Arc::clone(m))) as Arc<dyn AuthProvider>
+            });
+            let cfg = ResponsesApiClientConfig {
+                http_client,
+                provider: ctx.client.get_provider(),
+                model: ctx.client.get_model(),
+                conversation_id: ctx.client.conversation_id,
+                auth_provider,
+                otel_event_manager: ctx.client.get_otel_event_manager(),
+            };
+            let client = ResponsesApiClient::new(cfg).await.map_err(map_api_error)?;
+            client
+                .stream(api_prompt)
+                .await
+                .map_err(map_api_error)?
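+                // Boxed so the Responses and Chat arms share one stream type.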
+ .boxed() + } + WireApi::Chat => { + let cfg = ChatCompletionsApiClientConfig { + http_client, + provider: ctx.client.get_provider(), + model: ctx.client.get_model(), + otel_event_manager: ctx.client.get_otel_event_manager(), + session_source: ctx.client.get_session_source(), + aggregation_mode: if ctx.client.config().show_raw_agent_reasoning { + ChatAggregationMode::Streaming + } else { + ChatAggregationMode::AggregatedOnly + }, + }; + let client = ChatCompletionsApiClient::new(cfg) + .await + .map_err(map_api_error)?; + client + .stream(api_prompt) + .await + .map_err(map_api_error)? + .boxed() + } + }; + + Ok(wrap_stream(api_stream)) +} diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs index 2ac02f5f66..8eb94026b4 100644 --- a/codex-rs/core/src/client_common.rs +++ b/codex-rs/core/src/client_common.rs @@ -1,24 +1,21 @@ use crate::client_common::tools::ToolSpec; use crate::error::Result; use crate::model_family::ModelFamily; -use crate::protocol::RateLimitSnapshot; -use crate::protocol::TokenUsage; +use codex_api_client::Reasoning; +pub use codex_api_client::ResponseEvent; +use codex_api_client::TextControls; +use codex_api_client::TextFormat; +use codex_api_client::TextFormatType; use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS; use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig; use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig; use codex_protocol::config_types::Verbosity as VerbosityConfig; use codex_protocol::models::ResponseItem; -use futures::Stream; use serde::Deserialize; -use serde::Serialize; use serde_json::Value; use std::borrow::Cow; use std::collections::HashSet; use std::ops::Deref; -use std::pin::Pin; -use std::task::Context; -use std::task::Poll; -use tokio::sync::mpsc; /// Review thread system prompt. Edit `core/src/review_prompt.md` to customize. pub const REVIEW_PROMPT: &str = include_str!("../review_prompt.md"); @@ -193,95 +190,7 @@ fn strip_total_output_header(output: &str) -> Option<&str> { Some(remainder) } -#[derive(Debug)] -pub enum ResponseEvent { - Created, - OutputItemDone(ResponseItem), - OutputItemAdded(ResponseItem), - Completed { - response_id: String, - token_usage: Option, - }, - OutputTextDelta(String), - ReasoningSummaryDelta(String), - ReasoningContentDelta(String), - ReasoningSummaryPartAdded, - RateLimits(RateLimitSnapshot), -} - -#[derive(Debug, Serialize)] -pub(crate) struct Reasoning { - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) effort: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) summary: Option, -} - -#[derive(Debug, Serialize, Default, Clone)] -#[serde(rename_all = "snake_case")] -pub(crate) enum TextFormatType { - #[default] - JsonSchema, -} - -#[derive(Debug, Serialize, Default, Clone)] -pub(crate) struct TextFormat { - pub(crate) r#type: TextFormatType, - pub(crate) strict: bool, - pub(crate) schema: Value, - pub(crate) name: String, -} - -/// Controls under the `text` field in the Responses API for GPT-5. 
-#[derive(Debug, Serialize, Default, Clone)]
-pub(crate) struct TextControls {
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub(crate) verbosity: Option<OpenAiVerbosity>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub(crate) format: Option<TextFormat>,
-}
-
-#[derive(Debug, Serialize, Default, Clone)]
-#[serde(rename_all = "lowercase")]
-pub(crate) enum OpenAiVerbosity {
-    Low,
-    #[default]
-    Medium,
-    High,
-}
-
-impl From<VerbosityConfig> for OpenAiVerbosity {
-    fn from(v: VerbosityConfig) -> Self {
-        match v {
-            VerbosityConfig::Low => OpenAiVerbosity::Low,
-            VerbosityConfig::Medium => OpenAiVerbosity::Medium,
-            VerbosityConfig::High => OpenAiVerbosity::High,
-        }
-    }
-}
-
-/// Request object that is serialized as JSON and POST'ed when using the
-/// Responses API.
-#[derive(Debug, Serialize)]
-pub(crate) struct ResponsesApiRequest<'a> {
-    pub(crate) model: &'a str,
-    pub(crate) instructions: &'a str,
-    // TODO(mbolin): ResponseItem::Other should not be serialized. Currently,
-    // we code defensively to avoid this case, but perhaps we should use a
-    // separate enum for serialization.
-    pub(crate) input: &'a Vec<ResponseItem>,
-    pub(crate) tools: &'a [serde_json::Value],
-    pub(crate) tool_choice: &'static str,
-    pub(crate) parallel_tool_calls: bool,
-    pub(crate) reasoning: Option<Reasoning>,
-    pub(crate) store: bool,
-    pub(crate) stream: bool,
-    pub(crate) include: Vec<String>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub(crate) prompt_cache_key: Option<String>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub(crate) text: Option<TextControls>,
-}
+pub type ResponseStream = codex_api_client::EventStream<Result<ResponseEvent>>;
 
 pub(crate) mod tools {
     use crate::tools::spec::JsonSchema;
@@ -366,7 +275,11 @@ pub(crate) fn create_text_param_for_request(
     }
 
     Some(TextControls {
-        verbosity: verbosity.map(std::convert::Into::into),
+        verbosity: verbosity.map(|v| match v {
+            VerbosityConfig::Low => "low".to_string(),
+            VerbosityConfig::Medium => "medium".to_string(),
+            VerbosityConfig::High => "high".to_string(),
+        }),
         format: output_schema.as_ref().map(|schema| TextFormat {
             r#type: TextFormatType::JsonSchema,
             strict: true,
@@ -376,18 +289,6 @@
     })
 }
 
-pub struct ResponseStream {
-    pub(crate) rx_event: mpsc::Receiver<Result<ResponseEvent>>,
-}
-
-impl Stream for ResponseStream {
-    type Item = Result<ResponseEvent>;
-
-    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        self.rx_event.poll_recv(cx)
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use crate::model_family::find_family_for_model;
@@ -453,39 +354,14 @@ mod tests {
 
     #[test]
     fn serializes_text_verbosity_when_set() {
-        let input: Vec<ResponseItem> = vec![];
-        let tools: Vec<serde_json::Value> = vec![];
-        let req = ResponsesApiRequest {
-            model: "gpt-5",
-            instructions: "i",
-            input: &input,
-            tools: &tools,
-            tool_choice: "auto",
-            parallel_tool_calls: true,
-            reasoning: None,
-            store: false,
-            stream: true,
-            include: vec![],
-            prompt_cache_key: None,
-            text: Some(TextControls {
-                verbosity: Some(OpenAiVerbosity::Low),
-                format: None,
-            }),
-        };
-
-        let v = serde_json::to_value(&req).expect("json");
-        assert_eq!(
-            v.get("text")
-                .and_then(|t| t.get("verbosity"))
-                .and_then(|s| s.as_str()),
-            Some("low")
-        );
+        let controls =
+            create_text_param_for_request(Some(VerbosityConfig::Low), &None).expect("controls");
+        assert_eq!(controls.verbosity.as_deref(), Some("low"));
+        assert!(controls.format.is_none());
     }
 
     #[test]
     fn serializes_text_schema_with_strict_format() {
-        let input: Vec<ResponseItem> = vec![];
-        let tools: Vec<serde_json::Value> = vec![];
         let schema = serde_json::json!({
             "type": "object",
             "properties": {
            },
            "required": 
["answer"], }); - let text_controls = + let controls = create_text_param_for_request(None, &Some(schema.clone())).expect("text controls"); - - let req = ResponsesApiRequest { - model: "gpt-5", - instructions: "i", - input: &input, - tools: &tools, - tool_choice: "auto", - parallel_tool_calls: true, - reasoning: None, - store: false, - stream: true, - include: vec![], - prompt_cache_key: None, - text: Some(text_controls), - }; - - let v = serde_json::to_value(&req).expect("json"); - let text = v.get("text").expect("text field"); - assert!(text.get("verbosity").is_none()); - let format = text.get("format").expect("format field"); - - assert_eq!( - format.get("name"), - Some(&serde_json::Value::String("codex_output_schema".into())) - ); - assert_eq!( - format.get("type"), - Some(&serde_json::Value::String("json_schema".into())) - ); - assert_eq!(format.get("strict"), Some(&serde_json::Value::Bool(true))); - assert_eq!(format.get("schema"), Some(&schema)); + assert!(controls.verbosity.is_none()); + let format = controls.format.expect("format"); + assert_eq!(format.name, "codex_output_schema"); + assert!(format.strict); + assert_eq!(format.schema, schema); } #[test] fn omits_text_when_not_set() { - let input: Vec = vec![]; - let tools: Vec = vec![]; - let req = ResponsesApiRequest { - model: "gpt-5", - instructions: "i", - input: &input, - tools: &tools, - tool_choice: "auto", - parallel_tool_calls: true, - reasoning: None, - store: false, - stream: true, - include: vec![], - prompt_cache_key: None, - text: None, - }; - - let v = serde_json::to_value(&req).expect("json"); - assert!(v.get("text").is_none()); + assert!(create_text_param_for_request(None, &None).is_none()); } } diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index fc6b44c621..00b13637a4 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -52,7 +52,6 @@ use tracing::info; use tracing::warn; use crate::ModelProviderInfo; -use crate::client::ModelClient; use crate::client_common::Prompt; use crate::client_common::ResponseEvent; use crate::config::Config; @@ -294,6 +293,8 @@ impl TurnContext { } } +// Model-specific helpers live on ModelClient; TurnContext remains lean. 
+ #[allow(dead_code)] #[derive(Clone)] pub(crate) struct SessionConfiguration { @@ -403,6 +404,11 @@ impl Session { session_configuration.model.as_str(), ); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_family: &model_family, + features: &config.features, + }); + let client = ModelClient::new( Arc::new(per_turn_config), auth_manager, @@ -414,11 +420,6 @@ impl Session { session_configuration.session_source.clone(), ); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_family: &model_family, - features: &config.features, - }); - TurnContext { sub_id, client, @@ -1674,6 +1675,7 @@ async fn spawn_review_thread( ); let per_turn_config = Arc::new(per_turn_config); + let client = ModelClient::new( per_turn_config.clone(), auth_manager, @@ -1936,7 +1938,7 @@ async fn run_turn( retries += 1; let delay = match e { CodexErr::Stream(_, Some(delay)) => delay, - _ => backoff(retries), + _ => backoff(retries.max(0) as u64), }; warn!( "stream disconnected - retrying turn ({retries}/{max_retries} in {delay:?})...", @@ -1995,10 +1997,7 @@ async fn try_run_turn( }); sess.persist_rollout_items(&[rollout_item]).await; - let mut stream = turn_context - .client - .clone() - .stream(prompt) + let mut stream = crate::client::stream_for_turn(&turn_context, prompt) .or_cancel(&cancellation_token) .await??; @@ -3144,3 +3143,4 @@ mod tests { ); } } +use crate::ModelClient; diff --git a/codex-rs/core/src/codex/compact.rs b/codex-rs/core/src/codex/compact.rs index a3ffe6a3f0..5e84b0baa3 100644 --- a/codex-rs/core/src/codex/compact.rs +++ b/codex-rs/core/src/codex/compact.rs @@ -120,7 +120,7 @@ async fn run_compact_task_inner( Err(e) => { if retries < max_retries { retries += 1; - let delay = backoff(retries); + let delay = backoff(retries.max(0) as u64); sess.notify_stream_error( turn_context.as_ref(), format!("Reconnecting... 
{retries}/{max_retries}"), @@ -266,7 +266,7 @@ async fn drain_to_completed( turn_context: &TurnContext, prompt: &Prompt, ) -> CodexResult<()> { - let mut stream = turn_context.client.clone().stream(prompt).await?; + let mut stream = crate::client::stream_for_turn(turn_context, prompt).await?; loop { let maybe_event = stream.next().await; let Some(event) = maybe_event else { diff --git a/codex-rs/core/src/config/mod.rs b/codex-rs/core/src/config/mod.rs index 0dc9d12667..7cd07f6c3e 100644 --- a/codex-rs/core/src/config/mod.rs +++ b/codex-rs/core/src/config/mod.rs @@ -1,4 +1,6 @@ +use crate::ModelProviderInfo; use crate::auth::AuthCredentialsStoreMode; +use crate::built_in_model_providers; use crate::config::types::DEFAULT_OTEL_ENVIRONMENT; use crate::config::types::History; use crate::config::types::McpServerConfig; @@ -25,8 +27,6 @@ use crate::git_info::resolve_root_git_project_for_trust; use crate::model_family::ModelFamily; use crate::model_family::derive_default_model_family; use crate::model_family::find_family_for_model; -use crate::model_provider_info::ModelProviderInfo; -use crate::model_provider_info::built_in_model_providers; use crate::openai_model_info::get_model_info; use crate::project_doc::DEFAULT_PROJECT_DOC_FILENAME; use crate::project_doc::LOCAL_PROJECT_DOC_FILENAME; diff --git a/codex-rs/core/src/default_client.rs b/codex-rs/core/src/default_client.rs index 8e4635460c..b29f06e29c 100644 --- a/codex-rs/core/src/default_client.rs +++ b/codex-rs/core/src/default_client.rs @@ -41,6 +41,14 @@ impl CodexHttpClient { Self { inner } } + pub fn inner(&self) -> &reqwest::Client { + &self.inner + } + + pub fn clone_inner(&self) -> reqwest::Client { + self.inner.clone() + } + pub fn get(&self, url: U) -> CodexRequestBuilder where U: IntoUrl, diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 93a317700f..c96d2bc1ac 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -8,7 +8,6 @@ mod apply_patch; pub mod auth; pub mod bash; -mod chat_completions; mod client; mod client_common; pub mod codex; @@ -32,7 +31,6 @@ pub mod mcp; mod mcp_connection_manager; mod mcp_tool_call; mod message_history; -mod model_provider_info; pub mod parse_command; mod response_processing; pub mod sandboxing; @@ -40,11 +38,11 @@ pub mod token_data; mod truncate; mod unified_exec; mod user_instructions; -pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID; -pub use model_provider_info::ModelProviderInfo; -pub use model_provider_info::WireApi; -pub use model_provider_info::built_in_model_providers; -pub use model_provider_info::create_oss_provider_with_base_url; +pub use codex_api_client::BUILT_IN_OSS_MODEL_PROVIDER_ID; +pub use codex_api_client::ModelProviderInfo; +pub use codex_api_client::WireApi; +pub use codex_api_client::built_in_model_providers; +pub use codex_api_client::create_oss_provider_with_base_url; mod conversation_manager; mod event_mapping; pub mod review_format; diff --git a/codex-rs/core/src/model_provider_info.rs b/codex-rs/core/src/model_provider_info.rs deleted file mode 100644 index 8dc252aa7c..0000000000 --- a/codex-rs/core/src/model_provider_info.rs +++ /dev/null @@ -1,532 +0,0 @@ -//! Registry of model providers supported by Codex. -//! -//! Providers can be defined in two places: -//! 1. Built-in defaults compiled into the binary so Codex works out-of-the-box. -//! 2. User-defined entries inside `~/.codex/config.toml` under the `model_providers` -//! key. These override or extend the defaults at runtime. 
- -use crate::CodexAuth; -use crate::default_client::CodexHttpClient; -use crate::default_client::CodexRequestBuilder; -use codex_app_server_protocol::AuthMode; -use serde::Deserialize; -use serde::Serialize; -use std::collections::HashMap; -use std::env::VarError; -use std::time::Duration; - -use crate::error::EnvVarError; -const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000; -const DEFAULT_STREAM_MAX_RETRIES: u64 = 5; -const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4; -/// Hard cap for user-configured `stream_max_retries`. -const MAX_STREAM_MAX_RETRIES: u64 = 100; -/// Hard cap for user-configured `request_max_retries`. -const MAX_REQUEST_MAX_RETRIES: u64 = 100; - -/// Wire protocol that the provider speaks. Most third-party services only -/// implement the classic OpenAI Chat Completions JSON schema, whereas OpenAI -/// itself (and a handful of others) additionally expose the more modern -/// *Responses* API. The two protocols use different request/response shapes -/// and *cannot* be auto-detected at runtime, therefore each provider entry -/// must declare which one it expects. -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "lowercase")] -pub enum WireApi { - /// The Responses API exposed by OpenAI at `/v1/responses`. - Responses, - - /// Regular Chat Completions compatible with `/v1/chat/completions`. - #[default] - Chat, -} - -/// Serializable representation of a provider definition. -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] -pub struct ModelProviderInfo { - /// Friendly display name. - pub name: String, - /// Base URL for the provider's OpenAI-compatible API. - pub base_url: Option, - /// Environment variable that stores the user's API key for this provider. - pub env_key: Option, - - /// Optional instructions to help the user get a valid value for the - /// variable and set it. - pub env_key_instructions: Option, - - /// Value to use with `Authorization: Bearer ` header. Use of this - /// config is discouraged in favor of `env_key` for security reasons, but - /// this may be necessary when using this programmatically. - pub experimental_bearer_token: Option, - - /// Which wire protocol this provider expects. - #[serde(default)] - pub wire_api: WireApi, - - /// Optional query parameters to append to the base URL. - pub query_params: Option>, - - /// Additional HTTP headers to include in requests to this provider where - /// the (key, value) pairs are the header name and value. - pub http_headers: Option>, - - /// Optional HTTP headers to include in requests to this provider where the - /// (key, value) pairs are the header name and _environment variable_ whose - /// value should be used. If the environment variable is not set, or the - /// value is empty, the header will not be included in the request. - pub env_http_headers: Option>, - - /// Maximum number of times to retry a failed HTTP request to this provider. - pub request_max_retries: Option, - - /// Number of times to retry reconnecting a dropped streaming response before failing. - pub stream_max_retries: Option, - - /// Idle timeout (in milliseconds) to wait for activity on a streaming response before treating - /// the connection as lost. - pub stream_idle_timeout_ms: Option, - - /// Does this provider require an OpenAI API Key or ChatGPT login token? If true, - /// user is presented with login screen on first run, and login preference and token/key - /// are stored in auth.json. 
If false (which is the default), login screen is skipped, - /// and API key (if needed) comes from the "env_key" environment variable. - #[serde(default)] - pub requires_openai_auth: bool, -} - -impl ModelProviderInfo { - /// Construct a `POST` RequestBuilder for the given URL using the provided - /// [`CodexHttpClient`] applying: - /// • provider-specific headers (static + env based) - /// • Bearer auth header when an API key is available. - /// • Auth token for OAuth. - /// - /// If the provider declares an `env_key` but the variable is missing/empty, returns an [`Err`] identical to the - /// one produced by [`ModelProviderInfo::api_key`]. - pub async fn create_request_builder<'a>( - &'a self, - client: &'a CodexHttpClient, - auth: &Option, - ) -> crate::error::Result { - let effective_auth = if let Some(secret_key) = &self.experimental_bearer_token { - Some(CodexAuth::from_api_key(secret_key)) - } else { - match self.api_key() { - Ok(Some(key)) => Some(CodexAuth::from_api_key(&key)), - Ok(None) => auth.clone(), - Err(err) => { - if auth.is_some() { - auth.clone() - } else { - return Err(err); - } - } - } - }; - - let url = self.get_full_url(&effective_auth); - - let mut builder = client.post(url); - - if let Some(auth) = effective_auth.as_ref() { - builder = builder.bearer_auth(auth.get_token().await?); - } - - Ok(self.apply_http_headers(builder)) - } - - fn get_query_string(&self) -> String { - self.query_params - .as_ref() - .map_or_else(String::new, |params| { - let full_params = params - .iter() - .map(|(k, v)| format!("{k}={v}")) - .collect::>() - .join("&"); - format!("?{full_params}") - }) - } - - pub(crate) fn get_full_url(&self, auth: &Option) -> String { - let default_base_url = if matches!( - auth, - Some(CodexAuth { - mode: AuthMode::ChatGPT, - .. - }) - ) { - "https://chatgpt.com/backend-api/codex" - } else { - "https://api.openai.com/v1" - }; - let query_string = self.get_query_string(); - let base_url = self - .base_url - .clone() - .unwrap_or(default_base_url.to_string()); - - match self.wire_api { - WireApi::Responses => format!("{base_url}/responses{query_string}"), - WireApi::Chat => format!("{base_url}/chat/completions{query_string}"), - } - } - - pub(crate) fn is_azure_responses_endpoint(&self) -> bool { - if self.wire_api != WireApi::Responses { - return false; - } - - if self.name.eq_ignore_ascii_case("azure") { - return true; - } - - self.base_url - .as_ref() - .map(|base| matches_azure_responses_base_url(base)) - .unwrap_or(false) - } - - /// Apply provider-specific HTTP headers (both static and environment-based) - /// onto an existing [`CodexRequestBuilder`] and return the updated - /// builder. - fn apply_http_headers(&self, mut builder: CodexRequestBuilder) -> CodexRequestBuilder { - if let Some(extra) = &self.http_headers { - for (k, v) in extra { - builder = builder.header(k, v); - } - } - - if let Some(env_headers) = &self.env_http_headers { - for (header, env_var) in env_headers { - if let Ok(val) = std::env::var(env_var) - && !val.trim().is_empty() - { - builder = builder.header(header, val); - } - } - } - builder - } - - /// If `env_key` is Some, returns the API key for this provider if present - /// (and non-empty) in the environment. If `env_key` is required but - /// cannot be found, returns an error. 
-    pub fn api_key(&self) -> crate::error::Result<Option<String>> {
-        match &self.env_key {
-            Some(env_key) => {
-                let env_value = std::env::var(env_key);
-                env_value
-                    .and_then(|v| {
-                        if v.trim().is_empty() {
-                            Err(VarError::NotPresent)
-                        } else {
-                            Ok(Some(v))
-                        }
-                    })
-                    .map_err(|_| {
-                        crate::error::CodexErr::EnvVar(EnvVarError {
-                            var: env_key.clone(),
-                            instructions: self.env_key_instructions.clone(),
-                        })
-                    })
-            }
-            None => Ok(None),
-        }
-    }
-
-    /// Effective maximum number of request retries for this provider.
-    pub fn request_max_retries(&self) -> u64 {
-        self.request_max_retries
-            .unwrap_or(DEFAULT_REQUEST_MAX_RETRIES)
-            .min(MAX_REQUEST_MAX_RETRIES)
-    }
-
-    /// Effective maximum number of stream reconnection attempts for this provider.
-    pub fn stream_max_retries(&self) -> u64 {
-        self.stream_max_retries
-            .unwrap_or(DEFAULT_STREAM_MAX_RETRIES)
-            .min(MAX_STREAM_MAX_RETRIES)
-    }
-
-    /// Effective idle timeout for streaming responses.
-    pub fn stream_idle_timeout(&self) -> Duration {
-        self.stream_idle_timeout_ms
-            .map(Duration::from_millis)
-            .unwrap_or(Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS))
-    }
-}
-
-const DEFAULT_OLLAMA_PORT: u32 = 11434;
-
-pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "oss";
-
-/// Built-in default provider list.
-pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
-    use ModelProviderInfo as P;
-
-    // We do not want to be in the business of adjudicating which third-party
-    // providers are bundled with Codex CLI, so we only include the OpenAI and
-    // open source ("oss") providers by default. Users are encouraged to add to
-    // `model_providers` in config.toml to add their own providers.
-    [
-        (
-            "openai",
-            P {
-                name: "OpenAI".into(),
-                // Allow users to override the default OpenAI endpoint by
-                // exporting `OPENAI_BASE_URL`. This is useful when pointing
-                // Codex at a proxy, mock server, or Azure-style deployment
-                // without requiring a full TOML override for the built-in
-                // OpenAI provider.
-                base_url: std::env::var("OPENAI_BASE_URL")
-                    .ok()
-                    .filter(|v| !v.trim().is_empty()),
-                env_key: None,
-                env_key_instructions: None,
-                experimental_bearer_token: None,
-                wire_api: WireApi::Responses,
-                query_params: None,
-                http_headers: Some(
-                    [("version".to_string(), env!("CARGO_PKG_VERSION").to_string())]
-                        .into_iter()
-                        .collect(),
-                ),
-                env_http_headers: Some(
-                    [
-                        (
-                            "OpenAI-Organization".to_string(),
-                            "OPENAI_ORGANIZATION".to_string(),
-                        ),
-                        ("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()),
-                    ]
-                    .into_iter()
-                    .collect(),
-                ),
-                // Use global defaults for retry/timeout unless overridden in config.toml.
-                request_max_retries: None,
-                stream_max_retries: None,
-                stream_idle_timeout_ms: None,
-                requires_openai_auth: true,
-            },
-        ),
-        (BUILT_IN_OSS_MODEL_PROVIDER_ID, create_oss_provider()),
-    ]
-    .into_iter()
-    .map(|(k, v)| (k.to_string(), v))
-    .collect()
-}
-
-pub fn create_oss_provider() -> ModelProviderInfo {
-    // These CODEX_OSS_ environment variables are experimental: we may
-    // switch to reading values from config.toml instead.
- let codex_oss_base_url = match std::env::var("CODEX_OSS_BASE_URL") - .ok() - .filter(|v| !v.trim().is_empty()) - { - Some(url) => url, - None => format!( - "http://localhost:{port}/v1", - port = std::env::var("CODEX_OSS_PORT") - .ok() - .filter(|v| !v.trim().is_empty()) - .and_then(|v| v.parse::().ok()) - .unwrap_or(DEFAULT_OLLAMA_PORT) - ), - }; - - create_oss_provider_with_base_url(&codex_oss_base_url) -} - -pub fn create_oss_provider_with_base_url(base_url: &str) -> ModelProviderInfo { - ModelProviderInfo { - name: "gpt-oss".into(), - base_url: Some(base_url.into()), - env_key: None, - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Chat, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: None, - stream_max_retries: None, - stream_idle_timeout_ms: None, - requires_openai_auth: false, - } -} - -fn matches_azure_responses_base_url(base_url: &str) -> bool { - let base = base_url.to_ascii_lowercase(); - const AZURE_MARKERS: [&str; 5] = [ - "openai.azure.", - "cognitiveservices.azure.", - "aoai.azure.", - "azure-api.", - "azurefd.", - ]; - AZURE_MARKERS.iter().any(|marker| base.contains(marker)) -} - -#[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - - #[test] - fn test_deserialize_ollama_model_provider_toml() { - let azure_provider_toml = r#" -name = "Ollama" -base_url = "http://localhost:11434/v1" - "#; - let expected_provider = ModelProviderInfo { - name: "Ollama".into(), - base_url: Some("http://localhost:11434/v1".into()), - env_key: None, - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Chat, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: None, - stream_max_retries: None, - stream_idle_timeout_ms: None, - requires_openai_auth: false, - }; - - let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap(); - assert_eq!(expected_provider, provider); - } - - #[test] - fn test_deserialize_azure_model_provider_toml() { - let azure_provider_toml = r#" -name = "Azure" -base_url = "https://xxxxx.openai.azure.com/openai" -env_key = "AZURE_OPENAI_API_KEY" -query_params = { api-version = "2025-04-01-preview" } - "#; - let expected_provider = ModelProviderInfo { - name: "Azure".into(), - base_url: Some("https://xxxxx.openai.azure.com/openai".into()), - env_key: Some("AZURE_OPENAI_API_KEY".into()), - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Chat, - query_params: Some(maplit::hashmap! { - "api-version".to_string() => "2025-04-01-preview".to_string(), - }), - http_headers: None, - env_http_headers: None, - request_max_retries: None, - stream_max_retries: None, - stream_idle_timeout_ms: None, - requires_openai_auth: false, - }; - - let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap(); - assert_eq!(expected_provider, provider); - } - - #[test] - fn test_deserialize_example_model_provider_toml() { - let azure_provider_toml = r#" -name = "Example" -base_url = "https://example.com" -env_key = "API_KEY" -http_headers = { "X-Example-Header" = "example-value" } -env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" } - "#; - let expected_provider = ModelProviderInfo { - name: "Example".into(), - base_url: Some("https://example.com".into()), - env_key: Some("API_KEY".into()), - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Chat, - query_params: None, - http_headers: Some(maplit::hashmap! 
{ - "X-Example-Header".to_string() => "example-value".to_string(), - }), - env_http_headers: Some(maplit::hashmap! { - "X-Example-Env-Header".to_string() => "EXAMPLE_ENV_VAR".to_string(), - }), - request_max_retries: None, - stream_max_retries: None, - stream_idle_timeout_ms: None, - requires_openai_auth: false, - }; - - let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap(); - assert_eq!(expected_provider, provider); - } - - #[test] - fn detects_azure_responses_base_urls() { - fn provider_for(base_url: &str) -> ModelProviderInfo { - ModelProviderInfo { - name: "test".into(), - base_url: Some(base_url.into()), - env_key: None, - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: None, - stream_max_retries: None, - stream_idle_timeout_ms: None, - requires_openai_auth: false, - } - } - - let positive_cases = [ - "https://foo.openai.azure.com/openai", - "https://foo.openai.azure.us/openai/deployments/bar", - "https://foo.cognitiveservices.azure.cn/openai", - "https://foo.aoai.azure.com/openai", - "https://foo.openai.azure-api.net/openai", - "https://foo.z01.azurefd.net/", - ]; - for base_url in positive_cases { - let provider = provider_for(base_url); - assert!( - provider.is_azure_responses_endpoint(), - "expected {base_url} to be detected as Azure" - ); - } - - let named_provider = ModelProviderInfo { - name: "Azure".into(), - base_url: Some("https://example.com".into()), - env_key: None, - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: None, - stream_max_retries: None, - stream_idle_timeout_ms: None, - requires_openai_auth: false, - }; - assert!(named_provider.is_azure_responses_endpoint()); - - let negative_cases = [ - "https://api.openai.com/v1", - "https://example.com/openai", - "https://myproxy.azurewebsites.net/openai", - ]; - for base_url in negative_cases { - let provider = provider_for(base_url); - assert!( - !provider.is_azure_responses_endpoint(), - "expected {base_url} not to be detected as Azure" - ); - } - } -}