mirror of
https://github.com/openai/codex.git
synced 2026-04-28 02:11:08 +03:00
chore: nuke chat/completions API (#10157)
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
Typed clients for Codex/OpenAI APIs built on top of the generic transport in `codex-client`.
|
||||
|
||||
- Hosts the request/response models and prompt helpers for Responses, Chat Completions, and Compact APIs.
|
||||
- Hosts the request/response models and prompt helpers for Responses and Compact APIs.
|
||||
- Owns provider configuration (base URLs, headers, query params), auth header injection, retry tuning, and stream idle settings.
|
||||
- Parses SSE streams into `ResponseEvent`/`ResponseStream`, including rate-limit snapshots and API-specific error mapping.
|
||||
- Serves as the wire-level layer consumed by `codex-core`; higher layers handle auth refresh and business logic.
|
||||
@@ -11,7 +11,7 @@ Typed clients for Codex/OpenAI APIs built on top of the generic transport in `co
|
||||
|
||||
The public interface of this crate is intentionally small and uniform:
|
||||
|
||||
- **Prompted endpoints (Chat + Responses)**
|
||||
- **Prompted endpoints (Responses)**
|
||||
- Input: a single `Prompt` plus endpoint-specific options.
|
||||
- `Prompt` (re-exported as `codex_api::Prompt`) carries:
|
||||
- `instructions: String` – the fully-resolved system prompt for this turn.
|
||||
|
||||
@@ -13,7 +13,7 @@ use std::task::Context;
|
||||
use std::task::Poll;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// Canonical prompt input for Chat and Responses endpoints.
|
||||
/// Canonical prompt input for Responses endpoints.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Prompt {
|
||||
/// Fully-resolved system instructions for this turn.
|
||||
|
||||
@@ -1,111 +1,21 @@
|
||||
use crate::ChatRequest;
|
||||
use crate::auth::AuthProvider;
|
||||
use crate::common::Prompt as ApiPrompt;
|
||||
use crate::common::ResponseEvent;
|
||||
use crate::common::ResponseStream;
|
||||
use crate::endpoint::streaming::StreamingClient;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::provider::WireApi;
|
||||
use crate::sse::chat::spawn_chat_stream;
|
||||
use crate::telemetry::SseTelemetry;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::RequestCompression;
|
||||
use codex_client::RequestTelemetry;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ReasoningItemContent;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use futures::Stream;
|
||||
use http::HeaderMap;
|
||||
use serde_json::Value;
|
||||
use std::collections::VecDeque;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
|
||||
pub struct ChatClient<T: HttpTransport, A: AuthProvider> {
|
||||
streaming: StreamingClient<T, A>,
|
||||
}
|
||||
|
||||
impl<T: HttpTransport, A: AuthProvider> ChatClient<T, A> {
|
||||
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
|
||||
Self {
|
||||
streaming: StreamingClient::new(transport, provider, auth),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_telemetry(
|
||||
self,
|
||||
request: Option<Arc<dyn RequestTelemetry>>,
|
||||
sse: Option<Arc<dyn SseTelemetry>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
streaming: self.streaming.with_telemetry(request, sse),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream_request(&self, request: ChatRequest) -> Result<ResponseStream, ApiError> {
|
||||
self.stream(request.body, request.headers).await
|
||||
}
|
||||
|
||||
pub async fn stream_prompt(
|
||||
&self,
|
||||
model: &str,
|
||||
prompt: &ApiPrompt,
|
||||
conversation_id: Option<String>,
|
||||
session_source: Option<SessionSource>,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
use crate::requests::ChatRequestBuilder;
|
||||
|
||||
let request =
|
||||
ChatRequestBuilder::new(model, &prompt.instructions, &prompt.input, &prompt.tools)
|
||||
.conversation_id(conversation_id)
|
||||
.session_source(session_source)
|
||||
.build(self.streaming.provider())?;
|
||||
|
||||
self.stream_request(request).await
|
||||
}
|
||||
|
||||
fn path(&self) -> &'static str {
|
||||
match self.streaming.provider().wire {
|
||||
WireApi::Chat => "chat/completions",
|
||||
_ => "responses",
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream(
|
||||
&self,
|
||||
body: Value,
|
||||
extra_headers: HeaderMap,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
self.streaming
|
||||
.stream(
|
||||
self.path(),
|
||||
body,
|
||||
extra_headers,
|
||||
RequestCompression::None,
|
||||
spawn_chat_stream,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
/// Selects how [`AggregatedStream`] treats incremental delta events.
///
/// `Debug` is derived so the mode can be logged and used in `assert_eq!`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum AggregateMode {
    /// Accumulate deltas internally without forwarding them to the consumer.
    AggregatedOnly,
    /// Forward text and reasoning-content deltas to the consumer as they arrive, while
    /// still accumulating them.
    Streaming,
}
|
||||
|
||||
/// Stream adapter that merges token deltas into a single assistant message per turn.
|
||||
pub struct AggregatedStream {
|
||||
inner: ResponseStream,
|
||||
cumulative: String,
|
||||
cumulative_reasoning: String,
|
||||
pending: VecDeque<ResponseEvent>,
|
||||
mode: AggregateMode,
|
||||
}
|
||||
|
||||
impl Stream for AggregatedStream {
|
||||
@@ -122,7 +32,7 @@ impl Stream for AggregatedStream {
|
||||
match Pin::new(&mut this.inner).poll_next(cx) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(None) => return Poll::Ready(None),
|
||||
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
|
||||
Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
|
||||
Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
|
||||
let is_assistant_message = matches!(
|
||||
&item,
|
||||
@@ -130,29 +40,16 @@ impl Stream for AggregatedStream {
|
||||
);
|
||||
|
||||
if is_assistant_message {
|
||||
match this.mode {
|
||||
AggregateMode::AggregatedOnly => {
|
||||
if this.cumulative.is_empty()
|
||||
&& let ResponseItem::Message { content, .. } = &item
|
||||
&& let Some(text) = content.iter().find_map(|c| match c {
|
||||
ContentItem::OutputText { text } => Some(text),
|
||||
_ => None,
|
||||
})
|
||||
{
|
||||
this.cumulative.push_str(text);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
AggregateMode::Streaming => {
|
||||
if this.cumulative.is_empty() {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
|
||||
item,
|
||||
))));
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if this.cumulative.is_empty()
|
||||
&& let ResponseItem::Message { content, .. } = &item
|
||||
&& let Some(text) = content.iter().find_map(|c| match c {
|
||||
ContentItem::OutputText { text } => Some(text),
|
||||
_ => None,
|
||||
})
|
||||
{
|
||||
this.cumulative.push_str(text);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
|
||||
@@ -216,35 +113,20 @@ impl Stream for AggregatedStream {
|
||||
token_usage,
|
||||
})));
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::Created))) => {
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::Created))) => continue,
|
||||
Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => {
|
||||
this.cumulative.push_str(&delta);
|
||||
if matches!(this.mode, AggregateMode::Streaming) {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta))));
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta {
|
||||
delta,
|
||||
content_index,
|
||||
content_index: _,
|
||||
}))) => {
|
||||
this.cumulative_reasoning.push_str(&delta);
|
||||
if matches!(this.mode, AggregateMode::Streaming) {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta {
|
||||
delta,
|
||||
content_index,
|
||||
})));
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta { .. }))) => continue,
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded { .. }))) => {
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta { .. }))) => continue,
|
||||
Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded { .. }))) => continue,
|
||||
Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => {
|
||||
return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item))));
|
||||
}
|
||||
@@ -255,28 +137,21 @@ impl Stream for AggregatedStream {
|
||||
|
||||
pub trait AggregateStreamExt {
|
||||
fn aggregate(self) -> AggregatedStream;
|
||||
|
||||
fn streaming_mode(self) -> ResponseStream;
|
||||
}
|
||||
|
||||
impl AggregateStreamExt for ResponseStream {
|
||||
fn aggregate(self) -> AggregatedStream {
|
||||
AggregatedStream::new(self, AggregateMode::AggregatedOnly)
|
||||
}
|
||||
|
||||
fn streaming_mode(self) -> ResponseStream {
|
||||
self
|
||||
AggregatedStream::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl AggregatedStream {
|
||||
fn new(inner: ResponseStream, mode: AggregateMode) -> Self {
|
||||
fn new(inner: ResponseStream) -> Self {
|
||||
AggregatedStream {
|
||||
inner,
|
||||
cumulative: String::new(),
|
||||
cumulative_reasoning: String::new(),
|
||||
pending: VecDeque::new(),
|
||||
mode,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -36,12 +36,9 @@ impl<T: HttpTransport, A: AuthProvider> CompactClient<T, A> {
|
||||
self
|
||||
}
|
||||
|
||||
fn path(&self) -> Result<&'static str, ApiError> {
|
||||
fn path(&self) -> &'static str {
|
||||
match self.provider.wire {
|
||||
WireApi::Compact | WireApi::Responses => Ok("responses/compact"),
|
||||
WireApi::Chat => Err(ApiError::Stream(
|
||||
"compact endpoint requires responses wire api".to_string(),
|
||||
)),
|
||||
WireApi::Compact | WireApi::Responses => "responses/compact",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,7 +47,7 @@ impl<T: HttpTransport, A: AuthProvider> CompactClient<T, A> {
|
||||
body: serde_json::Value,
|
||||
extra_headers: HeaderMap,
|
||||
) -> Result<Vec<ResponseItem>, ApiError> {
|
||||
let path = self.path()?;
|
||||
let path = self.path();
|
||||
let builder = || {
|
||||
let mut req = self.provider.build_request(Method::POST, path);
|
||||
req.headers.extend(extra_headers.clone());
|
||||
@@ -139,24 +136,14 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn errors_when_wire_is_chat() {
|
||||
let client = CompactClient::new(DummyTransport, provider(WireApi::Chat), DummyAuth);
|
||||
let input = CompactionInput {
|
||||
model: "gpt-test",
|
||||
input: &[],
|
||||
instructions: "inst",
|
||||
};
|
||||
let err = client
|
||||
.compact_input(&input, HeaderMap::new())
|
||||
.await
|
||||
.expect_err("expected wire mismatch to fail");
|
||||
#[test]
|
||||
fn path_is_responses_compact_for_supported_wire_apis() {
|
||||
let responses_client =
|
||||
CompactClient::new(DummyTransport, provider(WireApi::Responses), DummyAuth);
|
||||
assert_eq!(responses_client.path(), "responses/compact");
|
||||
|
||||
match err {
|
||||
ApiError::Stream(msg) => {
|
||||
assert_eq!(msg, "compact endpoint requires responses wire api");
|
||||
}
|
||||
other => panic!("unexpected error: {other:?}"),
|
||||
}
|
||||
let compact_client =
|
||||
CompactClient::new(DummyTransport, provider(WireApi::Compact), DummyAuth);
|
||||
assert_eq!(compact_client.path(), "responses/compact");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
pub mod chat;
|
||||
pub mod aggregate;
|
||||
pub mod compact;
|
||||
pub mod models;
|
||||
pub mod responses;
|
||||
|
||||
@@ -111,7 +111,6 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
|
||||
fn path(&self) -> &'static str {
|
||||
match self.streaming.provider().wire {
|
||||
WireApi::Responses | WireApi::Compact => "responses",
|
||||
WireApi::Chat => "chat/completions",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -22,8 +22,7 @@ pub use crate::common::ResponseEvent;
|
||||
pub use crate::common::ResponseStream;
|
||||
pub use crate::common::ResponsesApiRequest;
|
||||
pub use crate::common::create_text_param_for_request;
|
||||
pub use crate::endpoint::chat::AggregateStreamExt;
|
||||
pub use crate::endpoint::chat::ChatClient;
|
||||
pub use crate::endpoint::aggregate::AggregateStreamExt;
|
||||
pub use crate::endpoint::compact::CompactClient;
|
||||
pub use crate::endpoint::models::ModelsClient;
|
||||
pub use crate::endpoint::responses::ResponsesClient;
|
||||
@@ -34,8 +33,6 @@ pub use crate::error::ApiError;
|
||||
pub use crate::provider::Provider;
|
||||
pub use crate::provider::WireApi;
|
||||
pub use crate::provider::is_azure_responses_wire_base_url;
|
||||
pub use crate::requests::ChatRequest;
|
||||
pub use crate::requests::ChatRequestBuilder;
|
||||
pub use crate::requests::ResponsesRequest;
|
||||
pub use crate::requests::ResponsesRequestBuilder;
|
||||
pub use crate::sse::stream_from_fixture;
|
||||
|
||||
@@ -12,7 +12,6 @@ use url::Url;
|
||||
/// Wire protocol a provider speaks; endpoint clients use it to pick their request path.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WireApi {
    /// Responses API.
    Responses,
    /// Chat Completions API.
    Chat,
    /// Compact API.
    Compact,
}
|
||||
|
||||
@@ -182,7 +181,7 @@ mod tests {
|
||||
}
|
||||
|
||||
assert!(!is_azure_responses_wire_base_url(
|
||||
WireApi::Chat,
|
||||
WireApi::Compact,
|
||||
"Azure",
|
||||
Some("https://foo.openai.azure.com/openai")
|
||||
));
|
||||
|
||||
@@ -1,494 +0,0 @@
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::requests::headers::build_conversation_headers;
|
||||
use crate::requests::headers::insert_header;
|
||||
use crate::requests::headers::subagent_header;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::FunctionCallOutputContentItem;
|
||||
use codex_protocol::models::ReasoningItemContent;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use http::HeaderMap;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Assembled request body plus headers for Chat Completions streaming calls.
|
||||
pub struct ChatRequest {
|
||||
pub body: Value,
|
||||
pub headers: HeaderMap,
|
||||
}
|
||||
|
||||
pub struct ChatRequestBuilder<'a> {
|
||||
model: &'a str,
|
||||
instructions: &'a str,
|
||||
input: &'a [ResponseItem],
|
||||
tools: &'a [Value],
|
||||
conversation_id: Option<String>,
|
||||
session_source: Option<SessionSource>,
|
||||
}
|
||||
|
||||
impl<'a> ChatRequestBuilder<'a> {
|
||||
pub fn new(
|
||||
model: &'a str,
|
||||
instructions: &'a str,
|
||||
input: &'a [ResponseItem],
|
||||
tools: &'a [Value],
|
||||
) -> Self {
|
||||
Self {
|
||||
model,
|
||||
instructions,
|
||||
input,
|
||||
tools,
|
||||
conversation_id: None,
|
||||
session_source: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn conversation_id(mut self, id: Option<String>) -> Self {
|
||||
self.conversation_id = id;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn session_source(mut self, source: Option<SessionSource>) -> Self {
|
||||
self.session_source = source;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self, _provider: &Provider) -> Result<ChatRequest, ApiError> {
|
||||
let mut messages = Vec::<Value>::new();
|
||||
messages.push(json!({"role": "system", "content": self.instructions}));
|
||||
|
||||
let input = self.input;
|
||||
let mut reasoning_by_anchor_index: HashMap<usize, String> = HashMap::new();
|
||||
let mut last_emitted_role: Option<&str> = None;
|
||||
for item in input {
|
||||
match item {
|
||||
ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()),
|
||||
ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
|
||||
last_emitted_role = Some("assistant")
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"),
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
|
||||
ResponseItem::CustomToolCall { .. } => {}
|
||||
ResponseItem::CustomToolCallOutput { .. } => {}
|
||||
ResponseItem::WebSearchCall { .. } => {}
|
||||
ResponseItem::GhostSnapshot { .. } => {}
|
||||
ResponseItem::Compaction { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
let mut last_user_index: Option<usize> = None;
|
||||
for (idx, item) in input.iter().enumerate() {
|
||||
if let ResponseItem::Message { role, .. } = item
|
||||
&& role == "user"
|
||||
{
|
||||
last_user_index = Some(idx);
|
||||
}
|
||||
}
|
||||
|
||||
if !matches!(last_emitted_role, Some("user")) {
|
||||
for (idx, item) in input.iter().enumerate() {
|
||||
if let Some(u_idx) = last_user_index
|
||||
&& idx <= u_idx
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if let ResponseItem::Reasoning {
|
||||
content: Some(items),
|
||||
..
|
||||
} = item
|
||||
{
|
||||
let mut text = String::new();
|
||||
for entry in items {
|
||||
match entry {
|
||||
ReasoningItemContent::ReasoningText { text: segment }
|
||||
| ReasoningItemContent::Text { text: segment } => {
|
||||
text.push_str(segment)
|
||||
}
|
||||
}
|
||||
}
|
||||
if text.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut attached = false;
|
||||
if idx > 0
|
||||
&& let ResponseItem::Message { role, .. } = &input[idx - 1]
|
||||
&& role == "assistant"
|
||||
{
|
||||
reasoning_by_anchor_index
|
||||
.entry(idx - 1)
|
||||
.and_modify(|v| v.push_str(&text))
|
||||
.or_insert(text.clone());
|
||||
attached = true;
|
||||
}
|
||||
|
||||
if !attached && idx + 1 < input.len() {
|
||||
match &input[idx + 1] {
|
||||
ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::LocalShellCall { .. } => {
|
||||
reasoning_by_anchor_index
|
||||
.entry(idx + 1)
|
||||
.and_modify(|v| v.push_str(&text))
|
||||
.or_insert(text.clone());
|
||||
}
|
||||
ResponseItem::Message { role, .. } if role == "assistant" => {
|
||||
reasoning_by_anchor_index
|
||||
.entry(idx + 1)
|
||||
.and_modify(|v| v.push_str(&text))
|
||||
.or_insert(text.clone());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut last_assistant_text: Option<String> = None;
|
||||
|
||||
for (idx, item) in input.iter().enumerate() {
|
||||
match item {
|
||||
ResponseItem::Message { role, content, .. } => {
|
||||
let mut text = String::new();
|
||||
let mut items: Vec<Value> = Vec::new();
|
||||
let mut saw_image = false;
|
||||
|
||||
for c in content {
|
||||
match c {
|
||||
ContentItem::InputText { text: t }
|
||||
| ContentItem::OutputText { text: t } => {
|
||||
text.push_str(t);
|
||||
items.push(json!({"type":"text","text": t}));
|
||||
}
|
||||
ContentItem::InputImage { image_url } => {
|
||||
saw_image = true;
|
||||
items.push(
|
||||
json!({"type":"image_url","image_url": {"url": image_url}}),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if role == "assistant" {
|
||||
if let Some(prev) = &last_assistant_text
|
||||
&& prev == &text
|
||||
{
|
||||
continue;
|
||||
}
|
||||
last_assistant_text = Some(text.clone());
|
||||
}
|
||||
|
||||
let content_value = if role == "assistant" {
|
||||
json!(text)
|
||||
} else if saw_image {
|
||||
json!(items)
|
||||
} else {
|
||||
json!(text)
|
||||
};
|
||||
|
||||
let mut msg = json!({"role": role, "content": content_value});
|
||||
if role == "assistant"
|
||||
&& let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
|
||||
&& let Some(obj) = msg.as_object_mut()
|
||||
{
|
||||
obj.insert("reasoning".to_string(), json!(reasoning));
|
||||
}
|
||||
messages.push(msg);
|
||||
}
|
||||
ResponseItem::FunctionCall {
|
||||
name,
|
||||
arguments,
|
||||
call_id,
|
||||
..
|
||||
} => {
|
||||
let reasoning = reasoning_by_anchor_index.get(&idx).map(String::as_str);
|
||||
let tool_call = json!({
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"arguments": arguments,
|
||||
}
|
||||
});
|
||||
push_tool_call_message(&mut messages, tool_call, reasoning);
|
||||
}
|
||||
ResponseItem::LocalShellCall {
|
||||
id,
|
||||
call_id: _,
|
||||
status,
|
||||
action,
|
||||
} => {
|
||||
let reasoning = reasoning_by_anchor_index.get(&idx).map(String::as_str);
|
||||
let tool_call = json!({
|
||||
"id": id.clone().unwrap_or_default(),
|
||||
"type": "local_shell_call",
|
||||
"status": status,
|
||||
"action": action,
|
||||
});
|
||||
push_tool_call_message(&mut messages, tool_call, reasoning);
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { call_id, output } => {
|
||||
let content_value = if let Some(items) = &output.content_items {
|
||||
let mapped: Vec<Value> = items
|
||||
.iter()
|
||||
.map(|it| match it {
|
||||
FunctionCallOutputContentItem::InputText { text } => {
|
||||
json!({"type":"text","text": text})
|
||||
}
|
||||
FunctionCallOutputContentItem::InputImage { image_url } => {
|
||||
json!({"type":"image_url","image_url": {"url": image_url}})
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
json!(mapped)
|
||||
} else {
|
||||
json!(output.content)
|
||||
};
|
||||
|
||||
messages.push(json!({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": content_value,
|
||||
}));
|
||||
}
|
||||
ResponseItem::CustomToolCall {
|
||||
id,
|
||||
call_id: _,
|
||||
name,
|
||||
input,
|
||||
status: _,
|
||||
} => {
|
||||
let tool_call = json!({
|
||||
"id": id,
|
||||
"type": "custom",
|
||||
"custom": {
|
||||
"name": name,
|
||||
"input": input,
|
||||
}
|
||||
});
|
||||
let reasoning = reasoning_by_anchor_index.get(&idx).map(String::as_str);
|
||||
push_tool_call_message(&mut messages, tool_call, reasoning);
|
||||
}
|
||||
ResponseItem::CustomToolCallOutput { call_id, output } => {
|
||||
messages.push(json!({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": output,
|
||||
}));
|
||||
}
|
||||
ResponseItem::GhostSnapshot { .. } => {
|
||||
continue;
|
||||
}
|
||||
ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::WebSearchCall { .. }
|
||||
| ResponseItem::Other
|
||||
| ResponseItem::Compaction { .. } => {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let payload = json!({
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"stream": true,
|
||||
"tools": self.tools,
|
||||
});
|
||||
|
||||
let mut headers = build_conversation_headers(self.conversation_id);
|
||||
if let Some(subagent) = subagent_header(&self.session_source) {
|
||||
insert_header(&mut headers, "x-openai-subagent", &subagent);
|
||||
}
|
||||
|
||||
Ok(ChatRequest {
|
||||
body: payload,
|
||||
headers,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn push_tool_call_message(messages: &mut Vec<Value>, tool_call: Value, reasoning: Option<&str>) {
|
||||
// Chat Completions requires that tool calls are grouped into a single assistant message
|
||||
// (with `tool_calls: [...]`) followed by tool role responses.
|
||||
if let Some(Value::Object(obj)) = messages.last_mut()
|
||||
&& obj.get("role").and_then(Value::as_str) == Some("assistant")
|
||||
&& obj.get("content").is_some_and(Value::is_null)
|
||||
&& let Some(tool_calls) = obj.get_mut("tool_calls").and_then(Value::as_array_mut)
|
||||
{
|
||||
tool_calls.push(tool_call);
|
||||
if let Some(reasoning) = reasoning {
|
||||
if let Some(Value::String(existing)) = obj.get_mut("reasoning") {
|
||||
if !existing.is_empty() {
|
||||
existing.push('\n');
|
||||
}
|
||||
existing.push_str(reasoning);
|
||||
} else {
|
||||
obj.insert(
|
||||
"reasoning".to_string(),
|
||||
Value::String(reasoning.to_string()),
|
||||
);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
let mut msg = json!({
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [tool_call],
|
||||
});
|
||||
if let Some(reasoning) = reasoning
|
||||
&& let Some(obj) = msg.as_object_mut()
|
||||
{
|
||||
obj.insert("reasoning".to_string(), json!(reasoning));
|
||||
}
|
||||
messages.push(msg);
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::RetryConfig;
    use crate::provider::WireApi;
    use codex_protocol::models::FunctionCallOutputPayload;
    use codex_protocol::protocol::SessionSource;
    use codex_protocol::protocol::SubAgentSource;
    use http::HeaderValue;
    use pretty_assertions::assert_eq;
    use std::time::Duration;

    /// Chat-wire provider fixture with a single attempt and a short idle timeout.
    fn provider() -> Provider {
        let retry = RetryConfig {
            max_attempts: 1,
            base_delay: Duration::from_millis(10),
            retry_429: false,
            retry_5xx: true,
            retry_transport: true,
        };
        Provider {
            name: "openai".to_string(),
            base_url: "https://api.openai.com/v1".to_string(),
            query_params: None,
            wire: WireApi::Chat,
            headers: HeaderMap::new(),
            retry,
            stream_idle_timeout: Duration::from_secs(1),
        }
    }

    /// User message containing a single text item.
    fn user_message(text: &str) -> ResponseItem {
        ResponseItem::Message {
            id: None,
            role: "user".to_string(),
            content: vec![ContentItem::InputText {
                text: text.to_string(),
            }],
            end_turn: None,
            phase: None,
        }
    }

    /// `read_file` function call for `path`, identified by `call_id`.
    fn read_file_call(call_id: &str, path: &str) -> ResponseItem {
        ResponseItem::FunctionCall {
            id: None,
            name: "read_file".to_string(),
            arguments: format!(r#"{{"path":"{path}"}}"#),
            call_id: call_id.to_string(),
        }
    }

    /// Function call output with plain-text `content` for `call_id`.
    fn call_output(call_id: &str, content: &str) -> ResponseItem {
        ResponseItem::FunctionCallOutput {
            call_id: call_id.to_string(),
            output: FunctionCallOutputPayload {
                content: content.to_string(),
                ..Default::default()
            },
        }
    }

    #[test]
    fn attaches_conversation_and_subagent_headers() {
        let prompt_input = vec![user_message("hi")];
        let req = ChatRequestBuilder::new("gpt-test", "inst", &prompt_input, &[])
            .conversation_id(Some("conv-1".into()))
            .session_source(Some(SessionSource::SubAgent(SubAgentSource::Review)))
            .build(&provider())
            .expect("request");

        assert_eq!(
            req.headers.get("session_id"),
            Some(&HeaderValue::from_static("conv-1"))
        );
        assert_eq!(
            req.headers.get("x-openai-subagent"),
            Some(&HeaderValue::from_static("review"))
        );
    }

    #[test]
    fn groups_consecutive_tool_calls_into_a_single_assistant_message() {
        let prompt_input = vec![
            user_message("read these"),
            read_file_call("call-a", "a.txt"),
            read_file_call("call-b", "b.txt"),
            read_file_call("call-c", "c.txt"),
            call_output("call-a", "A"),
            call_output("call-b", "B"),
            call_output("call-c", "C"),
        ];

        let req = ChatRequestBuilder::new("gpt-test", "inst", &prompt_input, &[])
            .build(&provider())
            .expect("request");

        let messages = req
            .body
            .get("messages")
            .and_then(|v| v.as_array())
            .expect("messages array");
        // system + user + assistant(tool_calls=[...]) + 3 tool outputs
        assert_eq!(messages.len(), 6);

        assert_eq!(messages[0]["role"], "system");
        assert_eq!(messages[1]["role"], "user");

        let tool_calls_msg = &messages[2];
        assert_eq!(tool_calls_msg["role"], "assistant");
        assert_eq!(tool_calls_msg["content"], serde_json::Value::Null);
        let tool_calls = tool_calls_msg["tool_calls"]
            .as_array()
            .expect("tool_calls array");
        assert_eq!(tool_calls.len(), 3);

        let expected_ids = ["call-a", "call-b", "call-c"];
        for (call, expected_id) in tool_calls.iter().zip(expected_ids) {
            assert_eq!(call["id"], expected_id);
        }

        // Tool outputs follow the grouped assistant message in call order.
        for (offset, expected_id) in expected_ids.into_iter().enumerate() {
            let tool_msg = &messages[3 + offset];
            assert_eq!(tool_msg["role"], "tool");
            assert_eq!(tool_msg["tool_call_id"], expected_id);
        }
    }
}
|
||||
@@ -1,8 +1,5 @@
|
||||
pub mod chat;
|
||||
pub(crate) mod headers;
|
||||
pub mod responses;
|
||||
|
||||
pub use chat::ChatRequest;
|
||||
pub use chat::ChatRequestBuilder;
|
||||
pub use responses::ResponsesRequest;
|
||||
pub use responses::ResponsesRequestBuilder;
|
||||
|
||||
@@ -1,717 +0,0 @@
|
||||
use crate::common::ResponseEvent;
|
||||
use crate::common::ResponseStream;
|
||||
use crate::error::ApiError;
|
||||
use crate::telemetry::SseTelemetry;
|
||||
use codex_client::StreamResponse;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ReasoningItemContent;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures::Stream;
|
||||
use futures::StreamExt;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::Instant;
|
||||
use tokio::time::timeout;
|
||||
use tracing::debug;
|
||||
use tracing::trace;
|
||||
|
||||
pub(crate) fn spawn_chat_stream(
|
||||
stream_response: StreamResponse,
|
||||
idle_timeout: Duration,
|
||||
telemetry: Option<Arc<dyn SseTelemetry>>,
|
||||
_turn_state: Option<Arc<OnceLock<String>>>,
|
||||
) -> ResponseStream {
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(1600);
|
||||
tokio::spawn(async move {
|
||||
process_chat_sse(stream_response.bytes, tx_event, idle_timeout, telemetry).await;
|
||||
});
|
||||
ResponseStream { rx_event }
|
||||
}
|
||||
|
||||
/// Processes Server-Sent Events from the legacy Chat Completions streaming API.
|
||||
///
|
||||
/// The upstream protocol terminates a streaming response with a final sentinel event
|
||||
/// (`data: [DONE]`). Historically, some of our test stubs have emitted `data: DONE`
|
||||
/// (without brackets) instead.
|
||||
///
|
||||
/// `eventsource_stream` delivers these sentinels as regular events rather than signaling
|
||||
/// end-of-stream. If we try to parse them as JSON, we log and skip them, then keep
|
||||
/// polling for more events.
|
||||
///
|
||||
/// On servers that keep the HTTP connection open after emitting the sentinel (notably
|
||||
/// wiremock on Windows), skipping the sentinel means we never emit `ResponseEvent::Completed`.
|
||||
/// Higher-level workflows/tests that wait for completion before issuing subsequent model
|
||||
/// calls will then stall, which shows up as "expected N requests, got 1" verification
|
||||
/// failures in the mock server.
|
||||
pub async fn process_chat_sse<S>(
|
||||
stream: S,
|
||||
tx_event: mpsc::Sender<Result<ResponseEvent, ApiError>>,
|
||||
idle_timeout: Duration,
|
||||
telemetry: Option<std::sync::Arc<dyn SseTelemetry>>,
|
||||
) where
|
||||
S: Stream<Item = Result<bytes::Bytes, codex_client::TransportError>> + Unpin,
|
||||
{
|
||||
let mut stream = stream.eventsource();
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
struct ToolCallState {
|
||||
id: Option<String>,
|
||||
name: Option<String>,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
let mut tool_calls: HashMap<usize, ToolCallState> = HashMap::new();
|
||||
let mut tool_call_order: Vec<usize> = Vec::new();
|
||||
let mut tool_call_order_seen: HashSet<usize> = HashSet::new();
|
||||
let mut tool_call_index_by_id: HashMap<String, usize> = HashMap::new();
|
||||
let mut next_tool_call_index = 0usize;
|
||||
let mut last_tool_call_index: Option<usize> = None;
|
||||
let mut assistant_item: Option<ResponseItem> = None;
|
||||
let mut reasoning_item: Option<ResponseItem> = None;
|
||||
let mut completed_sent = false;
|
||||
|
||||
async fn flush_and_complete(
|
||||
tx_event: &mpsc::Sender<Result<ResponseEvent, ApiError>>,
|
||||
reasoning_item: &mut Option<ResponseItem>,
|
||||
assistant_item: &mut Option<ResponseItem>,
|
||||
) {
|
||||
if let Some(reasoning) = reasoning_item.take() {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemDone(reasoning)))
|
||||
.await;
|
||||
}
|
||||
|
||||
if let Some(assistant) = assistant_item.take() {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemDone(assistant)))
|
||||
.await;
|
||||
}
|
||||
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::Completed {
|
||||
response_id: String::new(),
|
||||
token_usage: None,
|
||||
}))
|
||||
.await;
|
||||
}
|
||||
|
||||
loop {
|
||||
let start = Instant::now();
|
||||
let response = timeout(idle_timeout, stream.next()).await;
|
||||
if let Some(t) = telemetry.as_ref() {
|
||||
t.on_sse_poll(&response, start.elapsed());
|
||||
}
|
||||
let sse = match response {
|
||||
Ok(Some(Ok(sse))) => sse,
|
||||
Ok(Some(Err(e))) => {
|
||||
let _ = tx_event.send(Err(ApiError::Stream(e.to_string()))).await;
|
||||
return;
|
||||
}
|
||||
Ok(None) => {
|
||||
if !completed_sent {
|
||||
flush_and_complete(&tx_event, &mut reasoning_item, &mut assistant_item).await;
|
||||
}
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
let _ = tx_event
|
||||
.send(Err(ApiError::Stream("idle timeout waiting for SSE".into())))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
trace!("SSE event: {}", sse.data);
|
||||
|
||||
let data = sse.data.trim();
|
||||
|
||||
if data.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if data == "[DONE]" || data == "DONE" {
|
||||
if !completed_sent {
|
||||
flush_and_complete(&tx_event, &mut reasoning_item, &mut assistant_item).await;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
let value: serde_json::Value = match serde_json::from_str(data) {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
debug!(
|
||||
"Failed to parse ChatCompletions SSE event: {err}, data: {}",
|
||||
data
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let Some(choices) = value.get("choices").and_then(|c| c.as_array()) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
for choice in choices {
|
||||
if let Some(delta) = choice.get("delta") {
|
||||
if let Some(reasoning) = delta.get("reasoning") {
|
||||
if let Some(text) = reasoning.as_str() {
|
||||
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string())
|
||||
.await;
|
||||
} else if let Some(text) = reasoning.get("text").and_then(|v| v.as_str()) {
|
||||
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string())
|
||||
.await;
|
||||
} else if let Some(text) = reasoning.get("content").and_then(|v| v.as_str()) {
|
||||
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string())
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(content) = delta.get("content") {
|
||||
if content.is_array() {
|
||||
for item in content.as_array().unwrap_or(&vec![]) {
|
||||
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
|
||||
append_assistant_text(
|
||||
&tx_event,
|
||||
&mut assistant_item,
|
||||
text.to_string(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
} else if let Some(text) = content.as_str() {
|
||||
append_assistant_text(&tx_event, &mut assistant_item, text.to_string())
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(tool_call_values) = delta.get("tool_calls").and_then(|c| c.as_array()) {
|
||||
for tool_call in tool_call_values {
|
||||
let mut index = tool_call
|
||||
.get("index")
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.map(|i| i as usize);
|
||||
|
||||
let mut call_id_for_lookup = None;
|
||||
if let Some(call_id) = tool_call.get("id").and_then(|i| i.as_str()) {
|
||||
call_id_for_lookup = Some(call_id.to_string());
|
||||
if let Some(existing) = tool_call_index_by_id.get(call_id) {
|
||||
index = Some(*existing);
|
||||
}
|
||||
}
|
||||
|
||||
if index.is_none() && call_id_for_lookup.is_none() {
|
||||
index = last_tool_call_index;
|
||||
}
|
||||
|
||||
let index = index.unwrap_or_else(|| {
|
||||
while tool_calls.contains_key(&next_tool_call_index) {
|
||||
next_tool_call_index += 1;
|
||||
}
|
||||
let idx = next_tool_call_index;
|
||||
next_tool_call_index += 1;
|
||||
idx
|
||||
});
|
||||
|
||||
let call_state = tool_calls.entry(index).or_default();
|
||||
if tool_call_order_seen.insert(index) {
|
||||
tool_call_order.push(index);
|
||||
}
|
||||
|
||||
if let Some(id) = tool_call.get("id").and_then(|i| i.as_str()) {
|
||||
call_state.id.get_or_insert_with(|| id.to_string());
|
||||
tool_call_index_by_id.entry(id.to_string()).or_insert(index);
|
||||
}
|
||||
|
||||
if let Some(func) = tool_call.get("function") {
|
||||
if let Some(fname) = func.get("name").and_then(|n| n.as_str())
|
||||
&& !fname.is_empty()
|
||||
{
|
||||
call_state.name.get_or_insert_with(|| fname.to_string());
|
||||
}
|
||||
if let Some(arguments) = func.get("arguments").and_then(|a| a.as_str())
|
||||
{
|
||||
call_state.arguments.push_str(arguments);
|
||||
}
|
||||
}
|
||||
|
||||
last_tool_call_index = Some(index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(message) = choice.get("message")
|
||||
&& let Some(reasoning) = message.get("reasoning")
|
||||
{
|
||||
if let Some(text) = reasoning.as_str() {
|
||||
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string()).await;
|
||||
} else if let Some(text) = reasoning.get("text").and_then(|v| v.as_str()) {
|
||||
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string()).await;
|
||||
} else if let Some(text) = reasoning.get("content").and_then(|v| v.as_str()) {
|
||||
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string()).await;
|
||||
}
|
||||
}
|
||||
|
||||
let finish_reason = choice.get("finish_reason").and_then(|r| r.as_str());
|
||||
if finish_reason == Some("stop") {
|
||||
if let Some(reasoning) = reasoning_item.take() {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemDone(reasoning)))
|
||||
.await;
|
||||
}
|
||||
|
||||
if let Some(assistant) = assistant_item.take() {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemDone(assistant)))
|
||||
.await;
|
||||
}
|
||||
if !completed_sent {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::Completed {
|
||||
response_id: String::new(),
|
||||
token_usage: None,
|
||||
}))
|
||||
.await;
|
||||
completed_sent = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if finish_reason == Some("length") {
|
||||
let _ = tx_event.send(Err(ApiError::ContextWindowExceeded)).await;
|
||||
return;
|
||||
}
|
||||
|
||||
if finish_reason == Some("tool_calls") {
|
||||
if let Some(reasoning) = reasoning_item.take() {
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemDone(reasoning)))
|
||||
.await;
|
||||
}
|
||||
|
||||
for index in tool_call_order.drain(..) {
|
||||
let Some(state) = tool_calls.remove(&index) else {
|
||||
continue;
|
||||
};
|
||||
tool_call_order_seen.remove(&index);
|
||||
let ToolCallState {
|
||||
id,
|
||||
name,
|
||||
arguments,
|
||||
} = state;
|
||||
let Some(name) = name else {
|
||||
debug!("Skipping tool call at index {index} because name is missing");
|
||||
continue;
|
||||
};
|
||||
let item = ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name,
|
||||
arguments,
|
||||
call_id: id.unwrap_or_else(|| format!("tool-call-{index}")),
|
||||
};
|
||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn append_assistant_text(
|
||||
tx_event: &mpsc::Sender<Result<ResponseEvent, ApiError>>,
|
||||
assistant_item: &mut Option<ResponseItem>,
|
||||
text: String,
|
||||
) {
|
||||
if assistant_item.is_none() {
|
||||
let item = ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![],
|
||||
end_turn: None,
|
||||
phase: None,
|
||||
};
|
||||
*assistant_item = Some(item.clone());
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemAdded(item)))
|
||||
.await;
|
||||
}
|
||||
|
||||
if let Some(ResponseItem::Message { content, .. }) = assistant_item {
|
||||
content.push(ContentItem::OutputText { text: text.clone() });
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputTextDelta(text.clone())))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn append_reasoning_text(
|
||||
tx_event: &mpsc::Sender<Result<ResponseEvent, ApiError>>,
|
||||
reasoning_item: &mut Option<ResponseItem>,
|
||||
text: String,
|
||||
) {
|
||||
if reasoning_item.is_none() {
|
||||
let item = ResponseItem::Reasoning {
|
||||
id: String::new(),
|
||||
summary: Vec::new(),
|
||||
content: Some(vec![]),
|
||||
encrypted_content: None,
|
||||
};
|
||||
*reasoning_item = Some(item.clone());
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemAdded(item)))
|
||||
.await;
|
||||
}
|
||||
|
||||
if let Some(ResponseItem::Reasoning {
|
||||
content: Some(content),
|
||||
..
|
||||
}) = reasoning_item
|
||||
{
|
||||
let content_index = content.len() as i64;
|
||||
content.push(ReasoningItemContent::ReasoningText { text: text.clone() });
|
||||
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::ReasoningContentDelta {
|
||||
delta: text.clone(),
|
||||
content_index,
|
||||
}))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use codex_protocol::models::ResponseItem;
    use futures::TryStreamExt;
    use serde_json::json;
    use tokio::sync::mpsc;
    use tokio_util::io::ReaderStream;

    /// Renders each JSON value as one `message` SSE event and concatenates them.
    fn build_body(events: &[serde_json::Value]) -> String {
        events
            .iter()
            .map(|e| format!("event: message\ndata: {e}\n\n"))
            .collect()
    }

    /// Builds a single-choice chunk whose delta carries exactly one tool-call fragment.
    fn tool_call_delta(tool_call: serde_json::Value) -> serde_json::Value {
        json!({
            "choices": [{
                "delta": {
                    "tool_calls": [tool_call]
                }
            }]
        })
    }

    /// Builds a single-choice chunk that only carries the given `finish_reason`.
    fn finish_chunk(reason: &str) -> serde_json::Value {
        json!({
            "choices": [{
                "finish_reason": reason
            }]
        })
    }

    /// Feeds `body` through `process_chat_sse` and returns every event it emitted, panicking
    /// if the stream reports an error.
    async fn collect_events(body: &str) -> Vec<ResponseEvent> {
        let reader = ReaderStream::new(std::io::Cursor::new(body.to_string()))
            .map_err(|err| codex_client::TransportError::Network(err.to_string()));
        let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent, ApiError>>(16);
        tokio::spawn(process_chat_sse(
            reader,
            tx,
            Duration::from_millis(1000),
            None,
        ));

        let mut out = Vec::new();
        while let Some(ev) = rx.recv().await {
            out.push(ev.expect("stream error"));
        }
        out
    }

    /// Regression test: the stream should complete when we see a `[DONE]` sentinel.
    ///
    /// This is important for tests/mocks that don't immediately close the underlying
    /// connection after emitting the sentinel.
    #[tokio::test]
    async fn completes_on_done_sentinel_without_json() {
        let events = collect_events("event: message\ndata: [DONE]\n\n").await;
        assert_matches!(&events[..], [ResponseEvent::Completed { .. }]);
    }

    #[tokio::test]
    async fn concatenates_tool_call_arguments_across_deltas() {
        let delta_name = tool_call_delta(json!({
            "id": "call_a",
            "index": 0,
            "function": { "name": "do_a" }
        }));
        let delta_args_1 = tool_call_delta(json!({
            "index": 0,
            "function": { "arguments": "{ \"foo\":" }
        }));
        let delta_args_2 = tool_call_delta(json!({
            "index": 0,
            "function": { "arguments": "1}" }
        }));

        let body = build_body(&[
            delta_name,
            delta_args_1,
            delta_args_2,
            finish_chunk("tool_calls"),
        ]);
        let events = collect_events(&body).await;
        assert_matches!(
            &events[..],
            [
                ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. }),
                ResponseEvent::Completed { .. }
            ] if call_id == "call_a" && name == "do_a" && arguments == "{ \"foo\":1}"
        );
    }

    #[tokio::test]
    async fn emits_multiple_tool_calls() {
        let delta_a = tool_call_delta(json!({
            "id": "call_a",
            "function": { "name": "do_a", "arguments": "{\"foo\":1}" }
        }));
        let delta_b = tool_call_delta(json!({
            "id": "call_b",
            "function": { "name": "do_b", "arguments": "{\"bar\":2}" }
        }));

        let body = build_body(&[delta_a, delta_b, finish_chunk("tool_calls")]);
        let events = collect_events(&body).await;
        assert_matches!(
            &events[..],
            [
                ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id: call_a, name: name_a, arguments: args_a, .. }),
                ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id: call_b, name: name_b, arguments: args_b, .. }),
                ResponseEvent::Completed { .. }
            ] if call_a == "call_a" && name_a == "do_a" && args_a == "{\"foo\":1}" && call_b == "call_b" && name_b == "do_b" && args_b == "{\"bar\":2}"
        );
    }

    #[tokio::test]
    async fn emits_tool_calls_for_multiple_choices() {
        // Both choices reuse index 0 and finish within the same chunk.
        let payload = json!({
            "choices": [
                {
                    "delta": {
                        "tool_calls": [{
                            "id": "call_a",
                            "index": 0,
                            "function": { "name": "do_a", "arguments": "{}" }
                        }]
                    },
                    "finish_reason": "tool_calls"
                },
                {
                    "delta": {
                        "tool_calls": [{
                            "id": "call_b",
                            "index": 0,
                            "function": { "name": "do_b", "arguments": "{}" }
                        }]
                    },
                    "finish_reason": "tool_calls"
                }
            ]
        });

        let body = build_body(&[payload]);
        let events = collect_events(&body).await;
        assert_matches!(
            &events[..],
            [
                ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id: call_a, name: name_a, arguments: args_a, .. }),
                ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id: call_b, name: name_b, arguments: args_b, .. }),
                ResponseEvent::Completed { .. }
            ] if call_a == "call_a" && name_a == "do_a" && args_a == "{}" && call_b == "call_b" && name_b == "do_b" && args_b == "{}"
        );
    }

    #[tokio::test]
    async fn merges_tool_calls_by_index_when_id_missing_on_subsequent_deltas() {
        let delta_with_id = tool_call_delta(json!({
            "index": 0,
            "id": "call_a",
            "function": { "name": "do_a", "arguments": "{ \"foo\":" }
        }));
        let delta_without_id = tool_call_delta(json!({
            "index": 0,
            "function": { "arguments": "1}" }
        }));

        let body = build_body(&[delta_with_id, delta_without_id, finish_chunk("tool_calls")]);
        let events = collect_events(&body).await;
        assert_matches!(
            &events[..],
            [
                ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. }),
                ResponseEvent::Completed { .. }
            ] if call_id == "call_a" && name == "do_a" && arguments == "{ \"foo\":1}"
        );
    }

    #[tokio::test]
    async fn preserves_tool_call_name_when_empty_deltas_arrive() {
        let delta_with_name = tool_call_delta(json!({
            "id": "call_a",
            "function": { "name": "do_a" }
        }));
        let delta_with_empty_name = tool_call_delta(json!({
            "id": "call_a",
            "function": { "name": "", "arguments": "{}" }
        }));

        let body = build_body(&[
            delta_with_name,
            delta_with_empty_name,
            finish_chunk("tool_calls"),
        ]);
        let events = collect_events(&body).await;
        assert_matches!(
            &events[..],
            [
                ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { name, arguments, .. }),
                ResponseEvent::Completed { .. }
            ] if name == "do_a" && arguments == "{}"
        );
    }

    #[tokio::test]
    async fn emits_tool_calls_even_when_content_and_reasoning_present() {
        let delta_content_and_tools = json!({
            "choices": [{
                "delta": {
                    "content": [{"text": "hi"}],
                    "reasoning": "because",
                    "tool_calls": [{
                        "id": "call_a",
                        "function": { "name": "do_a", "arguments": "{}" }
                    }]
                }
            }]
        });

        let body = build_body(&[delta_content_and_tools, finish_chunk("tool_calls")]);
        let events = collect_events(&body).await;

        assert_matches!(
            &events[..],
            [
                ResponseEvent::OutputItemAdded(ResponseItem::Reasoning { .. }),
                ResponseEvent::ReasoningContentDelta { .. },
                ResponseEvent::OutputItemAdded(ResponseItem::Message { .. }),
                ResponseEvent::OutputTextDelta(delta),
                ResponseEvent::OutputItemDone(ResponseItem::Reasoning { .. }),
                ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, .. }),
                ResponseEvent::OutputItemDone(ResponseItem::Message { .. }),
                ResponseEvent::Completed { .. }
            ] if delta == "hi" && call_id == "call_a" && name == "do_a"
        );
    }

    #[tokio::test]
    async fn drops_partial_tool_calls_on_stop_finish_reason() {
        let delta_tool = tool_call_delta(json!({
            "id": "call_a",
            "function": { "name": "do_a", "arguments": "{}" }
        }));

        let body = build_body(&[delta_tool, finish_chunk("stop")]);
        let events = collect_events(&body).await;

        assert!(!events.iter().any(|ev| {
            matches!(
                ev,
                ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { .. })
            )
        }));
        assert_matches!(events.last(), Some(ResponseEvent::Completed { .. }));
    }
}
|
||||
@@ -1,4 +1,3 @@
|
||||
pub mod chat;
|
||||
pub mod responses;
|
||||
|
||||
pub use responses::process_sse;
|
||||
|
||||
@@ -6,7 +6,6 @@ use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use codex_api::AuthProvider;
|
||||
use codex_api::ChatClient;
|
||||
use codex_api::Provider;
|
||||
use codex_api::ResponsesClient;
|
||||
use codex_api::ResponsesOptions;
|
||||
@@ -195,34 +194,6 @@ data: {"id":"resp-1","output":[{"type":"message","role":"assistant","content":[{
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_client_uses_chat_completions_path_for_chat_wire() -> Result<()> {
|
||||
let state = RecordingState::default();
|
||||
let transport = RecordingTransport::new(state.clone());
|
||||
let client = ChatClient::new(transport, provider("openai", WireApi::Chat), NoAuth);
|
||||
|
||||
let body = serde_json::json!({ "echo": true });
|
||||
let _stream = client.stream(body, HeaderMap::new()).await?;
|
||||
|
||||
let requests = state.take_stream_requests();
|
||||
assert_path_ends_with(&requests, "/chat/completions");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_client_uses_responses_path_for_responses_wire() -> Result<()> {
|
||||
let state = RecordingState::default();
|
||||
let transport = RecordingTransport::new(state.clone());
|
||||
let client = ChatClient::new(transport, provider("openai", WireApi::Responses), NoAuth);
|
||||
|
||||
let body = serde_json::json!({ "echo": true });
|
||||
let _stream = client.stream(body, HeaderMap::new()).await?;
|
||||
|
||||
let requests = state.take_stream_requests();
|
||||
assert_path_ends_with(&requests, "/responses");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn responses_client_uses_responses_path_for_responses_wire() -> Result<()> {
|
||||
let state = RecordingState::default();
|
||||
@@ -240,10 +211,10 @@ async fn responses_client_uses_responses_path_for_responses_wire() -> Result<()>
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn responses_client_uses_chat_path_for_chat_wire() -> Result<()> {
|
||||
async fn responses_client_uses_responses_path_for_compact_wire() -> Result<()> {
|
||||
let state = RecordingState::default();
|
||||
let transport = RecordingTransport::new(state.clone());
|
||||
let client = ResponsesClient::new(transport, provider("openai", WireApi::Chat), NoAuth);
|
||||
let client = ResponsesClient::new(transport, provider("openai", WireApi::Compact), NoAuth);
|
||||
|
||||
let body = serde_json::json!({ "echo": true });
|
||||
let _stream = client
|
||||
@@ -251,7 +222,7 @@ async fn responses_client_uses_chat_path_for_chat_wire() -> Result<()> {
|
||||
.await?;
|
||||
|
||||
let requests = state.take_stream_requests();
|
||||
assert_path_ends_with(&requests, "/chat/completions");
|
||||
assert_path_ends_with(&requests, "/responses");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user