This commit is contained in:
jif-oai
2025-11-10 17:52:04 +00:00
parent 5b43146ba5
commit 7a5786f49f
16 changed files with 1712 additions and 1355 deletions

View File

@@ -1,27 +1,21 @@
use std::time::Duration;
use async_trait::async_trait;
use bytes::Bytes;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::models::ContentItem;
use codex_protocol::models::FunctionCallOutputContentItem;
use codex_protocol::models::ReasoningItemContent;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::SessionSource;
use eventsource_stream::Eventsource;
use futures::Stream;
use futures::StreamExt;
use futures::TryStreamExt;
use serde_json::Value;
use serde_json::json;
use tokio::sync::mpsc;
use tokio::time::timeout;
use tracing::debug;
use tracing::trace;
use crate::aggregate::ChatAggregationMode;
use crate::api::ApiClient;
use crate::common::apply_subagent_header;
use crate::client::PayloadBuilder;
use crate::common::backoff;
use crate::error::Error;
use crate::error::Result;
@@ -64,7 +58,8 @@ impl ApiClient for ChatCompletionsApiClient {
async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
Self::validate_prompt(prompt)?;
let payload = self.build_payload(prompt)?;
let payload = crate::payload::chat::ChatPayloadBuilder::new(self.config.model.clone())
.build(prompt)?;
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
let mut attempt: i64 = 0;
@@ -73,12 +68,14 @@ impl ApiClient for ChatCompletionsApiClient {
loop {
attempt += 1;
let req_builder = self
.config
.provider
.create_request_builder(&self.config.http_client, &None)
.await
.map(|builder| apply_subagent_header(builder, Some(&self.config.session_source)))?;
let req_builder = crate::client::http::build_request(
&self.config.http_client,
&self.config.provider,
&None,
Some(&self.config.session_source),
&[],
)
.await?;
let res = self
.config
@@ -103,12 +100,12 @@ impl ApiClient for ChatCompletionsApiClient {
let otel = self.config.otel_event_manager.clone();
let mode = self.config.aggregation_mode;
tokio::spawn(process_chat_sse(
tokio::spawn(crate::client::sse::process_sse(
stream,
tx_event.clone(),
idle_timeout,
otel,
mode,
crate::decode::chat::ChatSseDecoder::new(mode),
));
return Ok(ResponseStream { rx_event });
@@ -151,457 +148,6 @@ impl ChatCompletionsApiClient {
}
Ok(())
}
fn build_payload(&self, prompt: &Prompt) -> Result<serde_json::Value> {
let mut messages = Vec::<serde_json::Value>::new();
messages.push(json!({ "role": "system", "content": prompt.instructions }));
let mut reasoning_by_anchor_index: std::collections::HashMap<usize, String> =
std::collections::HashMap::new();
let mut last_emitted_role: Option<&str> = None;
for item in &prompt.input {
match item {
ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()),
ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
last_emitted_role = Some("assistant");
}
ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"),
ResponseItem::Reasoning { .. }
| ResponseItem::Other
| ResponseItem::CustomToolCall { .. }
| ResponseItem::CustomToolCallOutput { .. }
| ResponseItem::WebSearchCall { .. }
| ResponseItem::GhostSnapshot { .. } => {}
}
}
let mut last_user_index: Option<usize> = None;
for (idx, item) in prompt.input.iter().enumerate() {
if let ResponseItem::Message { role, .. } = item
&& role == "user"
{
last_user_index = Some(idx);
}
}
if !matches!(last_emitted_role, Some("user")) {
for (idx, item) in prompt.input.iter().enumerate() {
if let Some(u_idx) = last_user_index
&& idx <= u_idx
{
continue;
}
if let ResponseItem::Reasoning {
content: Some(items),
..
} = item
{
let mut text = String::new();
for entry in items {
match entry {
ReasoningItemContent::ReasoningText { text: segment }
| ReasoningItemContent::Text { text: segment } => {
text.push_str(segment);
}
}
}
if text.trim().is_empty() {
continue;
}
let mut attached = false;
if idx > 0
&& let ResponseItem::Message { role, .. } = &prompt.input[idx - 1]
&& role == "assistant"
{
reasoning_by_anchor_index
.entry(idx - 1)
.and_modify(|val| val.push_str(&text))
.or_insert(text.clone());
attached = true;
}
if !attached && idx + 1 < prompt.input.len() {
match &prompt.input[idx + 1] {
ResponseItem::FunctionCall { .. }
| ResponseItem::LocalShellCall { .. } => {
reasoning_by_anchor_index
.entry(idx + 1)
.and_modify(|val| val.push_str(&text))
.or_insert(text.clone());
}
ResponseItem::Message { role, .. } if role == "assistant" => {
reasoning_by_anchor_index
.entry(idx + 1)
.and_modify(|val| val.push_str(&text))
.or_insert(text.clone());
}
_ => {}
}
}
}
}
}
let mut last_assistant_text: Option<String> = None;
for (idx, item) in prompt.input.iter().enumerate() {
match item {
ResponseItem::Message { role, content, .. } => {
let mut text = String::new();
let mut items: Vec<serde_json::Value> = Vec::new();
let mut saw_image = false;
for c in content {
match c {
ContentItem::InputText { text: t }
| ContentItem::OutputText { text: t } => {
text.push_str(t);
items.push(json!({"type":"text","text": t}));
}
ContentItem::InputImage { image_url } => {
saw_image = true;
items.push(
json!({"type":"image_url","image_url": {"url": image_url}}),
);
}
}
}
if role == "assistant" {
if let Some(prev) = &last_assistant_text
&& prev == &text
{
continue;
}
last_assistant_text = Some(text.clone());
}
let content_value = if role == "assistant" {
json!(text)
} else if saw_image {
json!(items)
} else {
json!(text)
};
let mut message = json!({
"role": role,
"content": content_value,
});
if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
&& let Some(obj) = message.as_object_mut()
{
obj.insert("reasoning".to_string(), json!({"text": reasoning}));
}
messages.push(message);
}
ResponseItem::FunctionCall {
name,
arguments,
call_id,
..
} => {
messages.push(json!({
"role": "assistant",
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": name,
"arguments": arguments,
},
}],
}));
}
ResponseItem::FunctionCallOutput { call_id, output } => {
let content_value = if let Some(items) = &output.content_items {
let mapped: Vec<serde_json::Value> = items
.iter()
.map(|item| match item {
FunctionCallOutputContentItem::InputText { text } => {
json!({"type":"text","text": text})
}
FunctionCallOutputContentItem::InputImage { image_url } => {
json!({"type":"image_url","image_url": {"url": image_url}})
}
})
.collect();
json!(mapped)
} else {
json!(output.content)
};
messages.push(json!({
"role": "tool",
"tool_call_id": call_id,
"content": content_value,
}));
}
ResponseItem::LocalShellCall {
id,
call_id,
action,
..
} => {
let tool_id = call_id
.clone()
.filter(|value| !value.is_empty())
.or_else(|| id.clone())
.unwrap_or_default();
messages.push(json!({
"role": "assistant",
"tool_calls": [{
"id": tool_id,
"type": "function",
"function": {
"name": "shell",
"arguments": serde_json::to_string(action).unwrap_or_default(),
},
}],
}));
}
ResponseItem::CustomToolCall {
call_id,
name,
input,
..
} => {
messages.push(json!({
"role": "assistant",
"tool_calls": [{
"id": call_id.clone(),
"type": "function",
"function": {
"name": name,
"arguments": input,
},
}],
}));
}
ResponseItem::CustomToolCallOutput { call_id, output } => {
messages.push(json!({
"role": "tool",
"tool_call_id": call_id,
"content": output,
}));
}
ResponseItem::WebSearchCall { .. }
| ResponseItem::Reasoning { .. }
| ResponseItem::Other
| ResponseItem::GhostSnapshot { .. } => {}
}
}
let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
let payload = json!({
"model": self.config.model,
"messages": messages,
"stream": true,
"tools": tools_json,
});
trace!("chat completions payload: {}", payload);
Ok(payload)
}
}
/// Lightweight SSE processor for Chat Completions streaming, mapped to ResponseEvent.
async fn process_chat_sse<S>(
stream: S,
tx_event: mpsc::Sender<Result<ResponseEvent>>,
idle_timeout: Duration,
_otel_event_manager: OtelEventManager,
aggregation_mode: ChatAggregationMode,
) where
S: Stream<Item = Result<Bytes>> + Unpin,
{
let mut stream = stream.eventsource();
#[derive(Default)]
struct FunctionCallState {
name: Option<String>,
arguments: String,
call_id: Option<String>,
active: bool,
}
let mut fn_call_state = FunctionCallState::default();
let mut assistant_item: Option<ResponseItem> = None;
let mut reasoning_item: Option<ResponseItem> = None;
loop {
let response = timeout(idle_timeout, stream.next()).await;
let sse = match response {
Ok(Some(Ok(ev))) => ev,
Ok(Some(Err(err))) => {
let _ = tx_event
.send(Err(Error::Stream(err.to_string(), None)))
.await;
return;
}
Ok(None) => {
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
token_usage: None,
}))
.await;
return;
}
Err(_) => {
let _ = tx_event
.send(Err(Error::Stream(
"idle timeout waiting for SSE".into(),
None,
)))
.await;
return;
}
};
if sse.data.trim() == "[DONE]" {
if let Some(item) = assistant_item {
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
if let Some(item) = reasoning_item {
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
token_usage: None,
}))
.await;
return;
}
let Ok(parsed_chunk) = serde_json::from_str::<serde_json::Value>(&sse.data) else {
debug!("failed to parse SSE data into JSON: {}", sse.data);
continue;
};
let choices = parsed_chunk
.get("choices")
.and_then(|choices| choices.as_array())
.cloned()
.unwrap_or_default();
for choice in choices {
if let Some(delta) = choice.get("delta") {
if let Some(content) = delta.get("content").and_then(|c| c.as_array()) {
for piece in content {
if let Some(text) = piece.get("text").and_then(|t| t.as_str()) {
append_assistant_text(&tx_event, &mut assistant_item, text.to_string())
.await;
if matches!(aggregation_mode, ChatAggregationMode::Streaming) {
let _ = tx_event
.send(Ok(ResponseEvent::OutputTextDelta(text.to_string())))
.await;
}
}
}
}
if let Some(tool_calls) = delta.get("tool_calls").and_then(|c| c.as_array()) {
for call in tool_calls {
if let Some(id_val) = call.get("id").and_then(|id| id.as_str()) {
fn_call_state.call_id = Some(id_val.to_string());
}
if let Some(function) = call.get("function") {
if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
fn_call_state.name = Some(name.to_string());
fn_call_state.active = true;
}
if let Some(args) = function.get("arguments").and_then(|a| a.as_str()) {
fn_call_state.arguments.push_str(args);
}
}
}
}
if let Some(reasoning) = delta.get("reasoning_content").and_then(|c| c.as_array()) {
for entry in reasoning {
if let Some(text) = entry.get("text").and_then(|t| t.as_str()) {
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string())
.await;
}
}
}
}
if let Some(finish_reason) = choice.get("finish_reason").and_then(|f| f.as_str())
&& finish_reason == "tool_calls"
&& fn_call_state.active
{
let function_name = fn_call_state.name.take().unwrap_or_default();
let call_id = fn_call_state.call_id.take().unwrap_or_default();
let arguments = fn_call_state.arguments.clone();
fn_call_state = FunctionCallState::default();
let item = ResponseItem::FunctionCall {
id: Some(call_id.clone()),
call_id,
name: function_name,
arguments,
};
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
}
}
}
async fn append_assistant_text(
tx_event: &mpsc::Sender<Result<ResponseEvent>>,
assistant_item: &mut Option<ResponseItem>,
text: String,
) {
if assistant_item.is_none() {
let item = ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![],
};
*assistant_item = Some(item.clone());
let _ = tx_event
.send(Ok(ResponseEvent::OutputItemAdded(item)))
.await;
}
if let Some(ResponseItem::Message { content, .. }) = assistant_item {
content.push(ContentItem::OutputText { text });
}
}
async fn append_reasoning_text(
tx_event: &mpsc::Sender<Result<ResponseEvent>>,
reasoning_item: &mut Option<ResponseItem>,
text: String,
) {
if reasoning_item.is_none() {
let item = ResponseItem::Reasoning {
id: String::new(),
summary: Vec::new(),
content: Some(vec![]),
encrypted_content: None,
};
*reasoning_item = Some(item.clone());
let _ = tx_event
.send(Ok(ResponseEvent::OutputItemAdded(item)))
.await;
}
if let Some(ResponseItem::Reasoning {
content: Some(content),
..
}) = reasoning_item
{
content.push(ReasoningItemContent::ReasoningText { text });
}
}
fn create_tools_json_for_chat_completions_api(
@@ -632,5 +178,3 @@ fn create_tools_json_for_chat_completions_api(
.collect::<Vec<serde_json::Value>>();
Ok(tools_json)
}
// aggregation types and adapters moved to crate::aggregate

View File

@@ -0,0 +1,45 @@
use std::io::BufRead;
use std::path::Path;
use codex_otel::otel_event_manager::OtelEventManager;
use futures::TryStreamExt;
use tokio::sync::mpsc;
use tokio_util::io::ReaderStream;
use crate::error::Error;
use crate::error::Result;
use crate::model_provider::ModelProviderInfo;
use crate::stream::ResponseEvent;
use crate::stream::ResponseStream;
pub async fn stream_from_fixture(
path: impl AsRef<Path>,
provider: ModelProviderInfo,
otel_event_manager: OtelEventManager,
) -> Result<ResponseStream> {
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
let display_path = path.as_ref().display().to_string();
let file = std::fs::File::open(path.as_ref())
.map_err(|err| Error::Other(format!("failed to open fixture {display_path}: {err}")))?;
let lines = std::io::BufReader::new(file).lines();
let mut content = String::new();
for line in lines {
let line = line
.map_err(|err| Error::Other(format!("failed to read fixture {display_path}: {err}")))?;
content.push_str(&line);
content.push('\n');
content.push('\n');
}
let rdr = std::io::Cursor::new(content);
let stream = ReaderStream::new(rdr).map_err(|err| Error::Other(err.to_string()));
tokio::spawn(crate::client::sse::process_sse(
stream,
tx_event,
provider.stream_idle_timeout(),
otel_event_manager,
crate::decode::responses::ResponsesSseDecoder,
));
Ok(ResponseStream { rx_event })
}

View File

@@ -0,0 +1,43 @@
use std::sync::Arc;
use codex_protocol::protocol::SessionSource;
use reqwest::header::HeaderMap;
use crate::auth::AuthContext;
use crate::auth::AuthProvider;
use crate::common::apply_subagent_header;
use crate::error::Result;
use crate::model_provider::ModelProviderInfo;
/// Build a request builder with provider/auth/session headers applied.
pub async fn build_request(
http_client: &reqwest::Client,
provider: &ModelProviderInfo,
auth: &Option<AuthContext>,
session_source: Option<&SessionSource>,
extra_headers: &[(&str, String)],
) -> Result<reqwest::RequestBuilder> {
let mut builder = provider.create_request_builder(http_client, auth).await?;
builder = apply_subagent_header(builder, session_source);
for (name, value) in extra_headers {
builder = builder.header(*name, value);
}
Ok(builder)
}
/// Resolve auth context from an optional provider.
pub async fn resolve_auth(auth_provider: &Option<Arc<dyn AuthProvider>>) -> Option<AuthContext> {
if let Some(p) = auth_provider {
p.auth_context().await
} else {
None
}
}
/// Extract a provider request id, when present, from headers.
pub fn request_id_from_headers(headers: &HeaderMap) -> Option<String> {
headers
.get("cf-ray")
.and_then(|v| v.to_str().ok())
.map(std::string::ToString::to_string)
}

View File

@@ -0,0 +1,39 @@
use async_trait::async_trait;
use codex_otel::otel_event_manager::OtelEventManager;
use tokio::sync::mpsc;
use crate::error::Result;
use crate::prompt::Prompt;
use crate::stream::ResponseEvent;
pub mod fixtures;
pub mod http;
pub mod rate_limits;
pub mod sse;
/// Builds provider-specific JSON payloads from a Prompt.
pub trait PayloadBuilder {
fn build(&self, prompt: &Prompt) -> Result<serde_json::Value>;
}
/// Decodes framed SSE JSON into ResponseEvent(s).
/// Implementations may keep state across frames (e.g., Chat function-call state).
#[async_trait]
pub trait ResponseDecoder {
async fn on_frame(
&mut self,
json: &str,
tx: &mpsc::Sender<Result<ResponseEvent>>,
otel: &OtelEventManager,
) -> Result<()>;
}
/// Optional trait to expose rate limit parsing where needed.
pub trait RateLimitProvider {
fn parse(
&self,
_headers: &reqwest::header::HeaderMap,
) -> Option<codex_protocol::protocol::RateLimitSnapshot> {
None
}
}

View File

@@ -0,0 +1,60 @@
use codex_protocol::protocol::RateLimitSnapshot;
use codex_protocol::protocol::RateLimitWindow;
use reqwest::header::HeaderMap;
pub fn parse_rate_limit_snapshot(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
let primary = parse_rate_limit_window(
headers,
"x-codex-primary-used-percent",
"x-codex-primary-window-minutes",
"x-codex-primary-reset-at",
);
let secondary = parse_rate_limit_window(
headers,
"x-codex-secondary-used-percent",
"x-codex-secondary-window-minutes",
"x-codex-secondary-reset-at",
);
Some(RateLimitSnapshot { primary, secondary })
}
fn parse_rate_limit_window(
headers: &HeaderMap,
used_percent_header: &str,
window_minutes_header: &str,
resets_at_header: &str,
) -> Option<RateLimitWindow> {
let used_percent: Option<f64> = parse_header_f64(headers, used_percent_header);
used_percent.and_then(|used_percent| {
let window_minutes = parse_header_i64(headers, window_minutes_header);
let resets_at = parse_header_i64(headers, resets_at_header);
let has_data = used_percent != 0.0
|| window_minutes.is_some_and(|minutes| minutes != 0)
|| resets_at.is_some();
has_data.then_some(RateLimitWindow {
used_percent,
window_minutes,
resets_at,
})
})
}
fn parse_header_f64(headers: &HeaderMap, name: &str) -> Option<f64> {
parse_header_str(headers, name)?
.parse::<f64>()
.ok()
.filter(|v| v.is_finite())
}
fn parse_header_i64(headers: &HeaderMap, name: &str) -> Option<i64> {
parse_header_str(headers, name)?.parse::<i64>().ok()
}
fn parse_header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> {
headers.get(name)?.to_str().ok()
}

View File

@@ -0,0 +1,83 @@
use std::time::Duration;
use bytes::Bytes;
use codex_otel::otel_event_manager::OtelEventManager;
use futures::Stream;
use futures::StreamExt;
use tokio::sync::mpsc;
use tokio::time::timeout;
use crate::client::ResponseDecoder;
use crate::error::Error;
use crate::error::Result;
use crate::stream::ResponseEvent;
/// Generic SSE framer: turns a Byte stream into framed JSON and delegates to a ResponseDecoder.
#[allow(clippy::too_many_arguments)]
pub async fn process_sse<S, D>(
stream: S,
tx_event: mpsc::Sender<Result<ResponseEvent>>,
max_idle_duration: Duration,
otel_event_manager: OtelEventManager,
mut decoder: D,
) where
S: Stream<Item = Result<Bytes>> + Send + 'static + Unpin,
D: ResponseDecoder + Send,
{
let mut stream = stream;
let mut data_buffer = String::new();
loop {
let result = timeout(max_idle_duration, stream.next()).await;
match result {
Err(_) => {
let _ = tx_event
.send(Err(Error::Stream(
"stream idle timeout fired before Completed event".to_string(),
None,
)))
.await;
return;
}
Ok(Some(Err(err))) => {
let _ = tx_event.send(Err(err)).await;
return;
}
Ok(Some(Ok(chunk))) => {
let chunk_str = match std::str::from_utf8(&chunk) {
Ok(s) => s,
Err(err) => {
let _ = tx_event
.send(Err(Error::Other(format!(
"Invalid UTF-8 in SSE chunk: {err}"
))))
.await;
return;
}
};
for line in chunk_str.lines() {
if let Some(tail) = line.strip_prefix("data:") {
data_buffer.push_str(tail.trim_start());
} else if !line.is_empty() && !data_buffer.is_empty() {
// Continuation of a long data: line split across chunks; append raw.
data_buffer.push_str(line);
}
if line.is_empty() && !data_buffer.is_empty() {
// One full JSON frame ready delegate to decoder
if let Err(err) = decoder
.on_frame(&data_buffer, &tx_event, &otel_event_manager)
.await
{
let _ = tx_event.send(Err(err)).await;
return;
}
data_buffer.clear();
}
}
}
Ok(None) => return,
}
}
}

View File

@@ -0,0 +1,172 @@
use async_trait::async_trait;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ReasoningItemContent;
use codex_protocol::models::ResponseItem;
use tokio::sync::mpsc;
use tracing::debug;
use crate::aggregate::ChatAggregationMode;
use crate::error::Result;
use crate::stream::ResponseEvent;
pub struct ChatSseDecoder {
aggregation_mode: ChatAggregationMode,
fn_call_state: FunctionCallState,
assistant_item: Option<ResponseItem>,
reasoning_item: Option<ResponseItem>,
}
#[derive(Default)]
struct FunctionCallState {
name: Option<String>,
arguments: String,
call_id: Option<String>,
active: bool,
}
impl ChatSseDecoder {
pub fn new(aggregation_mode: ChatAggregationMode) -> Self {
Self {
aggregation_mode,
fn_call_state: FunctionCallState::default(),
assistant_item: None,
reasoning_item: None,
}
}
}
#[async_trait]
impl crate::client::ResponseDecoder for ChatSseDecoder {
async fn on_frame(
&mut self,
json: &str,
tx: &mpsc::Sender<Result<ResponseEvent>>,
_otel: &OtelEventManager,
) -> Result<()> {
// Chat sends a terminal "[DONE]" frame; we ignore it here. Caller should handle end-of-stream.
let Ok(parsed_chunk) = serde_json::from_str::<serde_json::Value>(json) else {
debug!("failed to parse Chat SSE JSON: {}", json);
return Ok(());
};
let choices = parsed_chunk
.get("choices")
.and_then(|choices| choices.as_array())
.cloned()
.unwrap_or_default();
for choice in choices {
if let Some(delta) = choice.get("delta") {
if let Some(content) = delta.get("content").and_then(|c| c.as_array()) {
for piece in content {
if let Some(text) = piece.get("text").and_then(|t| t.as_str()) {
append_assistant_text(tx, &mut self.assistant_item, text.to_string())
.await;
if matches!(self.aggregation_mode, ChatAggregationMode::Streaming) {
let _ = tx
.send(Ok(ResponseEvent::OutputTextDelta(text.to_string())))
.await;
}
}
}
}
if let Some(tool_calls) = delta.get("tool_calls").and_then(|c| c.as_array()) {
for call in tool_calls {
if let Some(id_val) = call.get("id").and_then(|id| id.as_str()) {
self.fn_call_state.call_id = Some(id_val.to_string());
}
if let Some(function) = call.get("function") {
if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
self.fn_call_state.name = Some(name.to_string());
self.fn_call_state.active = true;
}
if let Some(args) = function.get("arguments").and_then(|a| a.as_str()) {
self.fn_call_state.arguments.push_str(args);
}
}
}
}
if let Some(reasoning) = delta.get("reasoning_content").and_then(|c| c.as_array()) {
for entry in reasoning {
if let Some(text) = entry.get("text").and_then(|t| t.as_str()) {
append_reasoning_text(tx, &mut self.reasoning_item, text.to_string())
.await;
}
}
}
}
if let Some(finish_reason) = choice.get("finish_reason").and_then(|f| f.as_str())
&& finish_reason == "tool_calls"
&& self.fn_call_state.active
{
let function_name = self.fn_call_state.name.take().unwrap_or_default();
let call_id = self.fn_call_state.call_id.take().unwrap_or_default();
let arguments = self.fn_call_state.arguments.clone();
self.fn_call_state = FunctionCallState::default();
let item = ResponseItem::FunctionCall {
id: Some(call_id.clone()),
call_id,
name: function_name,
arguments,
};
let _ = tx.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
}
Ok(())
}
}
async fn append_assistant_text(
tx_event: &mpsc::Sender<Result<ResponseEvent>>,
assistant_item: &mut Option<ResponseItem>,
text: String,
) {
if assistant_item.is_none() {
let item = ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![],
};
*assistant_item = Some(item.clone());
let _ = tx_event
.send(Ok(ResponseEvent::OutputItemAdded(item)))
.await;
}
if let Some(ResponseItem::Message { content, .. }) = assistant_item {
content.push(ContentItem::OutputText { text });
}
}
async fn append_reasoning_text(
tx_event: &mpsc::Sender<Result<ResponseEvent>>,
reasoning_item: &mut Option<ResponseItem>,
text: String,
) {
if reasoning_item.is_none() {
let item = ResponseItem::Reasoning {
id: String::new(),
summary: Vec::new(),
content: Some(vec![]),
encrypted_content: None,
};
*reasoning_item = Some(item.clone());
let _ = tx_event
.send(Ok(ResponseEvent::OutputItemAdded(item)))
.await;
}
if let Some(ResponseItem::Reasoning {
content: Some(content),
..
}) = reasoning_item
{
content.push(ReasoningItemContent::ReasoningText { text });
}
}

View File

@@ -0,0 +1,2 @@
pub mod chat;
pub mod responses;

View File

@@ -0,0 +1,509 @@
use async_trait::async_trait;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::TokenUsage;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
use std::time::Duration;
use tokio::sync::mpsc;
use tracing::debug;
use tracing::trace;
use crate::error::Error;
use crate::error::Result;
use crate::stream::ResponseEvent;
#[derive(Debug, Deserialize)]
pub struct ResponseCompleted {
pub id: String,
pub usage: Option<TokenUsage>,
}
#[derive(Debug, Deserialize)]
pub struct StreamResponseCompleted {
pub id: String,
pub usage: Option<TokenUsagePartial>,
}
#[derive(Debug, Deserialize)]
pub struct ErrorResponse {
pub error: ErrorBody,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct ErrorBody {
pub r#type: Option<String>,
pub code: Option<String>,
pub message: Option<String>,
pub plan_type: Option<String>,
pub resets_at: Option<i64>,
}
pub fn is_quota_exceeded_error(error: &ErrorBody) -> bool {
error.code.as_deref() == Some("quota_exceeded")
}
#[derive(Debug, Deserialize)]
pub struct StreamEvent {
pub r#type: String,
pub response: Option<Value>,
pub item: Option<Value>,
pub error: Option<Value>,
#[serde(default)]
pub delta: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct TokenUsagePartial {
#[serde(default)]
pub input_tokens: i64,
#[serde(default)]
pub cached_input_tokens: i64,
#[serde(default)]
pub input_tokens_details: Option<TokenUsageInputDetails>,
#[serde(default)]
pub output_tokens: i64,
#[serde(default)]
pub output_tokens_details: Option<TokenUsageOutputDetails>,
#[serde(default)]
pub reasoning_output_tokens: i64,
#[serde(default)]
pub total_tokens: i64,
}
impl From<TokenUsagePartial> for TokenUsage {
fn from(value: TokenUsagePartial) -> Self {
let cached_input_tokens = if value.cached_input_tokens > 0 {
Some(value.cached_input_tokens)
} else {
value
.input_tokens_details
.and_then(|d| d.cached_tokens)
.filter(|v| *v > 0)
};
let reasoning_output_tokens = if value.reasoning_output_tokens > 0 {
Some(value.reasoning_output_tokens)
} else {
value
.output_tokens_details
.and_then(|d| d.reasoning_tokens)
.filter(|v| *v > 0)
};
Self {
input_tokens: value.input_tokens,
cached_input_tokens: cached_input_tokens.unwrap_or(0),
output_tokens: value.output_tokens,
reasoning_output_tokens: reasoning_output_tokens.unwrap_or(0),
total_tokens: value.total_tokens,
}
}
}
#[derive(Debug, Deserialize)]
pub struct TokenUsageInputDetails {
#[serde(default)]
pub cached_tokens: Option<i64>,
}
#[derive(Debug, Deserialize)]
pub struct TokenUsageOutputDetails {
#[serde(default)]
pub reasoning_tokens: Option<i64>,
}
pub async fn handle_sse_payload(
payload: sse::Payload,
tx_event: &mpsc::Sender<Result<ResponseEvent>>,
otel_event_manager: &OtelEventManager,
) -> Result<()> {
if let Some(responses) = payload.responses {
for ev in responses {
let event = match ev {
sse::Response::Completed(complete) => {
if let Some(usage) = &complete.usage {
otel_event_manager.sse_event_completed(
usage.input_tokens,
usage.output_tokens,
Some(usage.cached_input_tokens),
Some(usage.reasoning_output_tokens),
usage.total_tokens,
);
} else {
otel_event_manager
.see_event_completed_failed(&"missing token usage".to_string());
}
ResponseEvent::Completed {
response_id: complete.id,
token_usage: complete.usage,
}
}
sse::Response::Error(err) => {
let retry_after = err
.retry_after
.map(|secs| Duration::from_secs(if secs < 0 { 0 } else { secs as u64 }));
return Err(Error::Stream(
err.message.unwrap_or_else(|| "fatal error".to_string()),
retry_after,
));
}
};
tx_event.send(Ok(event)).await.ok();
}
}
if let Some(message_delta) = payload.response_message_delta {
let ev = ResponseEvent::OutputTextDelta(message_delta.text.clone());
tx_event.send(Ok(ev)).await.ok();
}
if let Some(_response_content) = payload.response_content {
// Not used currently
}
if let Some(ev) = payload.response_event {
debug!("Unhandled response_event: {ev:?}");
}
if let Some(item) = payload.response_output_item {
match item.r#type {
sse::OutputItem::Created => {
tx_event.send(Ok(ResponseEvent::Created)).await.ok();
otel_event_manager.sse_event_kind("response.output_item.done");
}
}
}
if let Some(done) = payload.response_output_text_delta {
tx_event
.send(Ok(ResponseEvent::OutputTextDelta(done.text)))
.await
.ok();
}
if let Some(completed) = payload.response_output_item_done {
let response_item =
serde_json::from_value::<ResponseItem>(completed.item).map_err(Error::Json)?;
tx_event
.send(Ok(ResponseEvent::OutputItemDone(response_item)))
.await
.ok();
otel_event_manager.sse_event_kind("response.output_item.done");
}
if let Some(reasoning_content_delta) = payload.response_output_reasoning_delta {
tx_event
.send(Ok(ResponseEvent::ReasoningContentDelta(
reasoning_content_delta.text,
)))
.await
.ok();
}
if let Some(reasoning_summary_delta) = payload.response_output_reasoning_summary_delta {
tx_event
.send(Ok(ResponseEvent::ReasoningSummaryDelta(
reasoning_summary_delta.text,
)))
.await
.ok();
}
if let Some(ev) = payload.response_error
&& ev.code.as_deref() == Some("max_response_tokens")
{
let _ = tx_event
.send(Err(Error::Stream(
"context window exceeded".to_string(),
None,
)))
.await;
}
Ok(())
}
#[derive(Debug, Deserialize)]
pub struct TextDelta {
pub delta: String,
}
pub async fn handle_stream_event(
event: StreamEvent,
tx_event: mpsc::Sender<Result<ResponseEvent>>,
_response_completed: &mut Option<ResponseCompleted>,
_response_error: &mut Option<Error>,
otel_event_manager: &OtelEventManager,
) {
trace!("response event: {}", event.r#type);
match event.r#type.as_str() {
"response.created" => {
let _ = tx_event.send(Ok(ResponseEvent::Created)).await;
}
"response.output_text.delta" => {
if let Some(item_val) = event.item {
let resp = serde_json::from_value::<TextDelta>(item_val);
if let Ok(delta) = resp {
let event = ResponseEvent::OutputTextDelta(delta.delta);
let _ = tx_event.send(Ok(event)).await;
}
} else if let Some(delta) = event.delta {
let _ = tx_event
.send(Ok(ResponseEvent::OutputTextDelta(delta)))
.await;
}
}
"response.reasoning_text.delta" => {
if let Some(delta) = event.delta {
let event = ResponseEvent::ReasoningContentDelta(delta);
let _ = tx_event.send(Ok(event)).await;
}
}
"response.reasoning_summary_text.delta" => {
if let Some(delta) = event.delta {
let event = ResponseEvent::ReasoningSummaryDelta(delta);
let _ = tx_event.send(Ok(event)).await;
}
}
"response.output_item.done" => {
if let Some(item_val) = event.item
&& let Ok(item) = serde_json::from_value::<ResponseItem>(item_val)
{
let event = ResponseEvent::OutputItemDone(item);
if tx_event.send(Ok(event)).await.is_err() {}
}
}
"response.failed" => {
if let Some(resp_val) = event.response {
otel_event_manager.sse_event_failed(
Some(&"response.failed".to_string()),
Duration::from_millis(0),
&resp_val,
);
if let Some(err) = resp_val
.get("error")
.cloned()
.and_then(|v| serde_json::from_value::<ErrorBody>(v).ok())
{
let msg = if err.code.as_deref() == Some("context_length_exceeded") {
"context window exceeded".to_string()
} else if err.code.as_deref() == Some("insufficient_quota") {
"quota exceeded".to_string()
} else {
err.message.unwrap_or_else(|| "fatal error".to_string())
};
let _ = tx_event.send(Err(Error::Stream(msg, None))).await;
}
}
}
"response.error" => {
if let Some(err_val) = event.error {
let err_resp = serde_json::from_value::<ErrorResponse>(err_val);
if let Ok(err) = err_resp {
let retry_after = try_parse_retry_after(&err);
let _ = tx_event
.send(Err(Error::Stream(
err.error
.message
.unwrap_or_else(|| "unknown error".to_string()),
retry_after,
)))
.await;
}
}
}
"response.completed" => {
if let Some(resp_val) = event.response
&& let Ok(resp) = serde_json::from_value::<StreamResponseCompleted>(resp_val)
{
let usage = resp.usage.map(TokenUsage::from);
let ev = ResponseEvent::Completed {
response_id: resp.id,
token_usage: usage.clone(),
};
let _ = tx_event.send(Ok(ev)).await;
if let Some(usage) = &usage {
otel_event_manager.sse_event_completed(
usage.input_tokens,
usage.output_tokens,
Some(usage.cached_input_tokens),
Some(usage.reasoning_output_tokens),
usage.total_tokens,
);
} else {
otel_event_manager
.see_event_completed_failed(&"missing token usage".to_string());
}
}
}
"response.output_item.added" => {
if let Some(item_val) = event.item
&& let Ok(item) = serde_json::from_value::<ResponseItem>(item_val)
{
let event = ResponseEvent::OutputItemAdded(item);
if tx_event.send(Ok(event)).await.is_err() {}
}
}
"response.reasoning_summary_part.added" => {
let event = ResponseEvent::ReasoningSummaryPartAdded;
let _ = tx_event.send(Ok(event)).await;
}
_ => {}
}
}
#[derive(Debug, Deserialize)]
pub struct ResponseErrorBody {
pub code: Option<String>,
}
fn try_parse_retry_after(err: &ErrorResponse) -> Option<Duration> {
if err.error.r#type.as_deref() == Some("rate_limit_exceeded") {
let retry_after = serde_json::to_value(&err.error)
.ok()
.and_then(|v| v.get("retry_after").cloned())
.and_then(|v| serde_json::from_value::<ResponseErrorBody>(v).ok())
.and_then(|v| v.code)
.and_then(parse_retry_after);
return retry_after;
}
None
}
fn parse_retry_after(s: String) -> Option<Duration> {
let minutes_pattern = regex_lite::Regex::new(r"^(\d+)m$").ok()?;
if let Some(cap) = minutes_pattern.captures(&s)
&& let Some(m) = cap.get(1).and_then(|m| m.as_str().parse::<u64>().ok())
{
return Some(Duration::from_secs(m * 60));
}
s.parse::<u64>().ok().map(Duration::from_secs)
}
pub mod sse {
use serde::Deserialize;
use serde_json::Value;
#[derive(Debug, Deserialize)]
pub struct Payload {
pub responses: Option<Vec<Response>>,
pub response_content: Option<Value>,
pub response_error: Option<ResponseError>,
pub response_event: Option<String>,
pub response_message_delta: Option<ResponseMessageDelta>,
pub response_output_item: Option<ResponseOutputItem>,
pub response_output_text_delta: Option<ResponseOutputTextDelta>,
pub response_output_item_done: Option<ResponseOutputItemDone>,
pub response_output_reasoning_delta: Option<ResponseOutputReasoningDelta>,
pub response_output_reasoning_summary_delta: Option<ResponseOutputReasoningSummaryDelta>,
}
#[derive(Debug, Deserialize)]
pub enum Response {
#[serde(rename = "response.completed")]
Completed(ResponseCompleted),
#[serde(rename = "response.error")]
Error(ResponseError),
}
#[derive(Debug, Deserialize)]
pub struct ResponseCompleted {
pub id: String,
pub usage: Option<codex_protocol::protocol::TokenUsage>,
}
#[derive(Debug, Deserialize)]
pub struct ResponseError {
pub code: Option<String>,
pub message: Option<String>,
pub retry_after: Option<i64>,
}
#[derive(Debug, Deserialize)]
pub struct ResponseMessageDelta {
pub text: String,
}
#[derive(Debug, Deserialize)]
pub enum OutputItem {
#[serde(rename = "response.output_item.created")]
Created,
}
#[derive(Debug, Deserialize)]
pub struct ResponseOutputItem {
pub r#type: OutputItem,
}
#[derive(Debug, Deserialize)]
pub struct ResponseOutputTextDelta {
pub text: String,
}
#[derive(Debug, Deserialize)]
pub struct ResponseOutputItemDone {
pub item: Value,
}
#[derive(Debug, Deserialize)]
pub struct ResponseOutputReasoningDelta {
pub text: String,
}
#[derive(Debug, Deserialize)]
pub struct ResponseOutputReasoningSummaryDelta {
pub text: String,
}
}
pub struct ResponsesSseDecoder;
impl Default for ResponsesSseDecoder {
fn default() -> Self {
Self
}
}
#[async_trait]
impl crate::client::ResponseDecoder for ResponsesSseDecoder {
async fn on_frame(
&mut self,
json: &str,
tx: &mpsc::Sender<Result<ResponseEvent>>,
otel_event_manager: &OtelEventManager,
) -> Result<()> {
if let Ok(event) = serde_json::from_str::<StreamEvent>(json) {
otel_event_manager.sse_event_kind(&event.r#type);
let mut completed: Option<ResponseCompleted> = None;
let mut error: Option<Error> = None;
handle_stream_event(
event,
tx.clone(),
&mut completed,
&mut error,
otel_event_manager,
)
.await;
return Ok(());
}
otel_event_manager.sse_event_failed(
None,
Duration::from_millis(0),
&format!("Cannot parse SSE JSON: {json}"),
);
match serde_json::from_str::<sse::Payload>(json) {
Ok(payload) => handle_sse_payload(payload, tx, otel_event_manager).await,
Err(err) => {
otel_event_manager.sse_event_failed(
None,
Duration::from_millis(0),
&format!("Cannot parse SSE JSON: {err}"),
);
Err(Error::Other(format!("Cannot parse SSE JSON: {err}")))
}
}
}
}

View File

@@ -2,9 +2,12 @@ pub mod aggregate;
pub mod api;
pub mod auth;
pub mod chat;
pub mod client;
mod common;
pub mod decode;
pub mod error;
pub mod model_provider;
pub mod payload;
pub mod prompt;
pub mod responses;
pub mod routed_client;
@@ -17,6 +20,7 @@ pub use crate::auth::AuthContext;
pub use crate::auth::AuthProvider;
pub use crate::chat::ChatCompletionsApiClient;
pub use crate::chat::ChatCompletionsApiClientConfig;
pub use crate::client::fixtures::stream_from_fixture;
pub use crate::error::Error;
pub use crate::error::Result;
pub use crate::model_provider::BUILT_IN_OSS_MODEL_PROVIDER_ID;
@@ -29,7 +33,6 @@ pub use crate::prompt::Prompt;
pub use crate::prompt::PromptBuilder;
pub use crate::responses::ResponsesApiClient;
pub use crate::responses::ResponsesApiClientConfig;
pub use crate::responses::stream_from_fixture;
pub use crate::routed_client::RoutedApiClient;
pub use crate::routed_client::RoutedApiClientConfig;
pub use crate::stream::EventStream;

View File

@@ -0,0 +1,306 @@
use serde_json::Value;
use serde_json::json;
use std::collections::HashMap;
use crate::client::PayloadBuilder;
use crate::error::Result;
use crate::prompt::Prompt;
use codex_protocol::models::ContentItem;
use codex_protocol::models::FunctionCallOutputContentItem;
use codex_protocol::models::ReasoningItemContent;
use codex_protocol::models::ResponseItem;
pub struct ChatPayloadBuilder {
model: String,
}
impl ChatPayloadBuilder {
pub fn new(model: String) -> Self {
Self { model }
}
}
impl PayloadBuilder for ChatPayloadBuilder {
fn build(&self, prompt: &Prompt) -> Result<Value> {
let mut messages = Vec::<Value>::new();
messages.push(json!({ "role": "system", "content": prompt.instructions }));
let mut reasoning_by_anchor_index: HashMap<usize, String> = HashMap::new();
let mut last_emitted_role: Option<&str> = None;
for item in &prompt.input {
match item {
ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()),
ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
last_emitted_role = Some("assistant");
}
ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"),
ResponseItem::Reasoning { .. }
| ResponseItem::Other
| ResponseItem::CustomToolCall { .. }
| ResponseItem::CustomToolCallOutput { .. }
| ResponseItem::WebSearchCall { .. }
| ResponseItem::GhostSnapshot { .. } => {}
}
}
let mut last_user_index: Option<usize> = None;
for (idx, item) in prompt.input.iter().enumerate() {
if let ResponseItem::Message { role, .. } = item
&& role == "user"
{
last_user_index = Some(idx);
}
}
if !matches!(last_emitted_role, Some("user")) {
for (idx, item) in prompt.input.iter().enumerate() {
if let Some(u_idx) = last_user_index
&& idx <= u_idx
{
continue;
}
if let ResponseItem::Reasoning {
content: Some(items),
..
} = item
{
let mut text = String::new();
for entry in items {
match entry {
ReasoningItemContent::ReasoningText { text: segment }
| ReasoningItemContent::Text { text: segment } => {
text.push_str(segment);
}
}
}
if text.trim().is_empty() {
continue;
}
let mut attached = false;
if idx > 0
&& let ResponseItem::Message { role, .. } = &prompt.input[idx - 1]
&& role == "assistant"
{
reasoning_by_anchor_index
.entry(idx - 1)
.and_modify(|val| val.push_str(&text))
.or_insert(text.clone());
attached = true;
}
if !attached && idx + 1 < prompt.input.len() {
match &prompt.input[idx + 1] {
ResponseItem::FunctionCall { .. }
| ResponseItem::LocalShellCall { .. } => {
reasoning_by_anchor_index
.entry(idx + 1)
.and_modify(|val| val.push_str(&text))
.or_insert(text.clone());
}
ResponseItem::Message { role, .. } if role == "assistant" => {
reasoning_by_anchor_index
.entry(idx + 1)
.and_modify(|val| val.push_str(&text))
.or_insert(text.clone());
}
_ => {}
}
}
}
}
}
let mut last_assistant_text: Option<String> = None;
for (idx, item) in prompt.input.iter().enumerate() {
match item {
ResponseItem::Message { role, content, .. } => {
let mut text = String::new();
let mut items: Vec<Value> = Vec::new();
let mut saw_image = false;
for c in content {
match c {
ContentItem::InputText { text: t }
| ContentItem::OutputText { text: t } => {
text.push_str(t);
items.push(json!({"type":"text","text": t}));
}
ContentItem::InputImage { image_url } => {
saw_image = true;
items.push(
json!({"type":"image_url","image_url": {"url": image_url}}),
);
}
}
}
if role == "assistant" {
if let Some(prev) = &last_assistant_text
&& prev == &text
{
continue;
}
last_assistant_text = Some(text.clone());
}
let content_value = if role == "assistant" {
json!(text)
} else if saw_image {
json!(items)
} else {
json!(text)
};
let mut message = json!({
"role": role,
"content": content_value,
});
if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
&& let Some(obj) = message.as_object_mut()
{
obj.insert("reasoning".to_string(), json!({"text": reasoning}));
}
messages.push(message);
}
ResponseItem::FunctionCall {
name,
arguments,
call_id,
..
} => {
messages.push(json!({
"role": "assistant",
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": name,
"arguments": arguments,
},
}],
}));
}
ResponseItem::FunctionCallOutput { call_id, output } => {
let content_value = if let Some(items) = &output.content_items {
let mapped: Vec<Value> = items
.iter()
.map(|item| match item {
FunctionCallOutputContentItem::InputText { text } => {
json!({"type":"text","text": text})
}
FunctionCallOutputContentItem::InputImage { image_url } => {
json!({"type":"image_url","image_url": {"url": image_url}})
}
})
.collect();
json!(mapped)
} else {
json!(output.content)
};
messages.push(json!({
"role": "tool",
"tool_call_id": call_id,
"content": content_value,
}));
}
ResponseItem::LocalShellCall {
id,
call_id,
action,
..
} => {
let tool_id = call_id
.clone()
.filter(|value| !value.is_empty())
.or_else(|| id.clone())
.unwrap_or_default();
messages.push(json!({
"role": "assistant",
"tool_calls": [{
"id": tool_id,
"type": "function",
"function": {
"name": "shell",
"arguments": serde_json::to_string(action).unwrap_or_default(),
},
}],
}));
}
ResponseItem::CustomToolCall {
call_id,
name,
input,
..
} => {
messages.push(json!({
"role": "assistant",
"tool_calls": [{
"id": call_id.clone(),
"type": "function",
"function": {
"name": name,
"arguments": input,
},
}],
}));
}
ResponseItem::CustomToolCallOutput { call_id, output } => {
messages.push(json!({
"role": "tool",
"tool_call_id": call_id,
"content": output,
}));
}
ResponseItem::WebSearchCall { .. }
| ResponseItem::Reasoning { .. }
| ResponseItem::Other
| ResponseItem::GhostSnapshot { .. } => {}
}
}
let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
let payload = json!({
"model": self.model,
"messages": messages,
"stream": true,
"tools": tools_json,
});
Ok(payload)
}
}
fn create_tools_json_for_chat_completions_api(
tools: &[serde_json::Value],
) -> Result<Vec<serde_json::Value>> {
let tools_json = tools
.iter()
.filter_map(|tool| {
if tool.get("type") != Some(&serde_json::Value::String("function".to_string())) {
return None;
}
let function_value = if let Some(function) = tool.get("function") {
function.clone()
} else if let Some(map) = tool.as_object() {
let mut function = map.clone();
function.remove("type");
Value::Object(function)
} else {
return None;
};
Some(json!({
"type": "function",
"function": function_value,
}))
})
.collect::<Vec<serde_json::Value>>();
Ok(tools_json)
}

View File

@@ -0,0 +1,2 @@
pub mod chat;
pub mod responses;

View File

@@ -0,0 +1,125 @@
use serde_json::Value;
use serde_json::json;
use crate::client::PayloadBuilder;
use crate::error::Result;
use crate::prompt::Prompt;
use codex_protocol::ConversationId;
use codex_protocol::models::ResponseItem;
pub struct ResponsesPayloadBuilder {
model: String,
conversation_id: ConversationId,
azure_workaround: bool,
}
impl ResponsesPayloadBuilder {
pub fn new(model: String, conversation_id: ConversationId, azure_workaround: bool) -> Self {
Self {
model,
conversation_id,
azure_workaround,
}
}
}
impl PayloadBuilder for ResponsesPayloadBuilder {
fn build(&self, prompt: &Prompt) -> Result<Value> {
let azure = self.azure_workaround;
let mut payload = json!({
"model": self.model,
"instructions": prompt.instructions,
"input": prompt.input,
"tools": prompt.tools,
"tool_choice": "auto",
"parallel_tool_calls": prompt.parallel_tool_calls,
"store": azure,
"stream": true,
"prompt_cache_key": prompt
.prompt_cache_key
.clone()
.unwrap_or_else(|| self.conversation_id.to_string()),
});
if let Some(reasoning) = prompt.reasoning.as_ref()
&& let Some(obj) = payload.as_object_mut()
{
obj.insert("reasoning".to_string(), serde_json::to_value(reasoning)?);
}
if let Some(text) = prompt.text_controls.as_ref()
&& let Some(obj) = payload.as_object_mut()
{
obj.insert("text".to_string(), serde_json::to_value(text)?);
}
let include = if prompt.reasoning.is_some() {
vec!["reasoning.encrypted_content".to_string()]
} else {
Vec::new()
};
if let Some(obj) = payload.as_object_mut() {
obj.insert(
"include".to_string(),
Value::Array(include.into_iter().map(Value::String).collect()),
);
}
// Azure Responses requires ids attached to input items
if azure
&& let Some(input_value) = payload.get_mut("input")
&& let Some(array) = input_value.as_array_mut()
{
attach_item_ids_array(array, &prompt.input);
}
Ok(payload)
}
}
fn attach_item_ids_array(json_array: &mut [Value], prompt_input: &[ResponseItem]) {
for (json_item, item) in json_array.iter_mut().zip(prompt_input.iter()) {
let Some(obj) = json_item.as_object_mut() else {
continue;
};
let mut set_id_if_absent = |id: &str| match obj.get("id") {
Some(Value::String(s)) if !s.is_empty() => {}
Some(Value::Null) | None => {
obj.insert("id".to_string(), Value::String(id.to_string()));
}
_ => {}
};
match item {
ResponseItem::Reasoning { id, .. } => set_id_if_absent(id),
ResponseItem::Message { id, .. } => {
if let Some(id) = id.as_ref() {
set_id_if_absent(id);
}
}
ResponseItem::WebSearchCall { id, .. } => {
if let Some(id) = id.as_ref() {
set_id_if_absent(id);
}
}
ResponseItem::FunctionCall { id, .. } => {
if let Some(id) = id.as_ref() {
set_id_if_absent(id);
}
}
ResponseItem::LocalShellCall { id, .. } => {
if let Some(id) = id.as_ref() {
set_id_if_absent(id);
}
}
ResponseItem::CustomToolCall { id, .. } => {
if let Some(id) = id.as_ref() {
set_id_if_absent(id);
}
}
_ => {}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -16,8 +16,8 @@ use crate::ResponsesApiClientConfig;
use crate::Result;
use crate::WireApi;
use crate::auth::AuthProvider;
use crate::client::fixtures::stream_from_fixture;
use crate::model_provider::ModelProviderInfo;
use crate::responses::stream_from_fixture;
/// Dispatches to the appropriate API client implementation based on the provider wire API.
#[derive(Clone)]