chore: proper client extraction (#6996)

This commit is contained in:
jif-oai
2025-11-25 18:06:12 +00:00
committed by GitHub
parent 2845e2c006
commit 4502b1b263
45 changed files with 4893 additions and 2748 deletions

View File

@@ -0,0 +1,504 @@
use crate::common::ResponseEvent;
use crate::common::ResponseStream;
use crate::error::ApiError;
use crate::telemetry::SseTelemetry;
use codex_client::StreamResponse;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ReasoningItemContent;
use codex_protocol::models::ResponseItem;
use eventsource_stream::Eventsource;
use futures::Stream;
use futures::StreamExt;
use std::collections::HashMap;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time::Instant;
use tokio::time::timeout;
use tracing::debug;
use tracing::trace;
pub(crate) fn spawn_chat_stream(
stream_response: StreamResponse,
idle_timeout: Duration,
telemetry: Option<std::sync::Arc<dyn SseTelemetry>>,
) -> ResponseStream {
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(1600);
tokio::spawn(async move {
process_chat_sse(stream_response.bytes, tx_event, idle_timeout, telemetry).await;
});
ResponseStream { rx_event }
}
pub async fn process_chat_sse<S>(
stream: S,
tx_event: mpsc::Sender<Result<ResponseEvent, ApiError>>,
idle_timeout: Duration,
telemetry: Option<std::sync::Arc<dyn SseTelemetry>>,
) where
S: Stream<Item = Result<bytes::Bytes, codex_client::TransportError>> + Unpin,
{
let mut stream = stream.eventsource();
#[derive(Default, Debug)]
struct ToolCallState {
name: Option<String>,
arguments: String,
}
let mut tool_calls: HashMap<String, ToolCallState> = HashMap::new();
let mut tool_call_order: Vec<String> = Vec::new();
let mut assistant_item: Option<ResponseItem> = None;
let mut reasoning_item: Option<ResponseItem> = None;
let mut completed_sent = false;
loop {
let start = Instant::now();
let response = timeout(idle_timeout, stream.next()).await;
if let Some(t) = telemetry.as_ref() {
t.on_sse_poll(&response, start.elapsed());
}
let sse = match response {
Ok(Some(Ok(sse))) => sse,
Ok(Some(Err(e))) => {
let _ = tx_event.send(Err(ApiError::Stream(e.to_string()))).await;
return;
}
Ok(None) => {
if let Some(reasoning) = reasoning_item {
let _ = tx_event
.send(Ok(ResponseEvent::OutputItemDone(reasoning)))
.await;
}
if let Some(assistant) = assistant_item {
let _ = tx_event
.send(Ok(ResponseEvent::OutputItemDone(assistant)))
.await;
}
if !completed_sent {
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
token_usage: None,
}))
.await;
}
return;
}
Err(_) => {
let _ = tx_event
.send(Err(ApiError::Stream("idle timeout waiting for SSE".into())))
.await;
return;
}
};
trace!("SSE event: {}", sse.data);
if sse.data.trim().is_empty() {
continue;
}
let value: serde_json::Value = match serde_json::from_str(&sse.data) {
Ok(val) => val,
Err(err) => {
debug!(
"Failed to parse ChatCompletions SSE event: {err}, data: {}",
&sse.data
);
continue;
}
};
let Some(choices) = value.get("choices").and_then(|c| c.as_array()) else {
continue;
};
for choice in choices {
if let Some(delta) = choice.get("delta") {
if let Some(reasoning) = delta.get("reasoning") {
if let Some(text) = reasoning.as_str() {
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string())
.await;
} else if let Some(text) = reasoning.get("text").and_then(|v| v.as_str()) {
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string())
.await;
} else if let Some(text) = reasoning.get("content").and_then(|v| v.as_str()) {
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string())
.await;
}
}
if let Some(content) = delta.get("content") {
if content.is_array() {
for item in content.as_array().unwrap_or(&vec![]) {
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
append_assistant_text(
&tx_event,
&mut assistant_item,
text.to_string(),
)
.await;
}
}
} else if let Some(text) = content.as_str() {
append_assistant_text(&tx_event, &mut assistant_item, text.to_string())
.await;
}
}
if let Some(tool_call_values) = delta.get("tool_calls").and_then(|c| c.as_array()) {
for tool_call in tool_call_values {
let id = tool_call
.get("id")
.and_then(|i| i.as_str())
.map(str::to_string)
.unwrap_or_else(|| format!("tool-call-{}", tool_call_order.len()));
let call_state = tool_calls.entry(id.clone()).or_default();
if !tool_call_order.contains(&id) {
tool_call_order.push(id.clone());
}
if let Some(func) = tool_call.get("function") {
if let Some(fname) = func.get("name").and_then(|n| n.as_str()) {
call_state.name = Some(fname.to_string());
}
if let Some(arguments) = func.get("arguments").and_then(|a| a.as_str())
{
call_state.arguments.push_str(arguments);
}
}
}
}
}
if let Some(message) = choice.get("message")
&& let Some(reasoning) = message.get("reasoning")
{
if let Some(text) = reasoning.as_str() {
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string()).await;
} else if let Some(text) = reasoning.get("text").and_then(|v| v.as_str()) {
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string()).await;
} else if let Some(text) = reasoning.get("content").and_then(|v| v.as_str()) {
append_reasoning_text(&tx_event, &mut reasoning_item, text.to_string()).await;
}
}
let finish_reason = choice.get("finish_reason").and_then(|r| r.as_str());
if finish_reason == Some("stop") {
if let Some(reasoning) = reasoning_item.take() {
let _ = tx_event
.send(Ok(ResponseEvent::OutputItemDone(reasoning)))
.await;
}
if let Some(assistant) = assistant_item.take() {
let _ = tx_event
.send(Ok(ResponseEvent::OutputItemDone(assistant)))
.await;
}
if !completed_sent {
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
token_usage: None,
}))
.await;
completed_sent = true;
}
continue;
}
if finish_reason == Some("length") {
let _ = tx_event.send(Err(ApiError::ContextWindowExceeded)).await;
return;
}
if finish_reason == Some("tool_calls") {
if let Some(reasoning) = reasoning_item.take() {
let _ = tx_event
.send(Ok(ResponseEvent::OutputItemDone(reasoning)))
.await;
}
for call_id in tool_call_order.drain(..) {
let state = tool_calls.remove(&call_id).unwrap_or_default();
let item = ResponseItem::FunctionCall {
id: None,
name: state.name.unwrap_or_default(),
arguments: state.arguments,
call_id: call_id.clone(),
};
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
}
}
}
}
async fn append_assistant_text(
tx_event: &mpsc::Sender<Result<ResponseEvent, ApiError>>,
assistant_item: &mut Option<ResponseItem>,
text: String,
) {
if assistant_item.is_none() {
let item = ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![],
};
*assistant_item = Some(item.clone());
let _ = tx_event
.send(Ok(ResponseEvent::OutputItemAdded(item)))
.await;
}
if let Some(ResponseItem::Message { content, .. }) = assistant_item {
content.push(ContentItem::OutputText { text: text.clone() });
let _ = tx_event
.send(Ok(ResponseEvent::OutputTextDelta(text.clone())))
.await;
}
}
async fn append_reasoning_text(
tx_event: &mpsc::Sender<Result<ResponseEvent, ApiError>>,
reasoning_item: &mut Option<ResponseItem>,
text: String,
) {
if reasoning_item.is_none() {
let item = ResponseItem::Reasoning {
id: String::new(),
summary: Vec::new(),
content: Some(vec![]),
encrypted_content: None,
};
*reasoning_item = Some(item.clone());
let _ = tx_event
.send(Ok(ResponseEvent::OutputItemAdded(item)))
.await;
}
if let Some(ResponseItem::Reasoning {
content: Some(content),
..
}) = reasoning_item
{
let content_index = content.len() as i64;
content.push(ReasoningItemContent::ReasoningText { text: text.clone() });
let _ = tx_event
.send(Ok(ResponseEvent::ReasoningContentDelta {
delta: text.clone(),
content_index,
}))
.await;
}
}
#[cfg(test)]
mod tests {
use super::*;
use assert_matches::assert_matches;
use codex_protocol::models::ResponseItem;
use futures::TryStreamExt;
use serde_json::json;
use tokio::sync::mpsc;
use tokio_util::io::ReaderStream;
fn build_body(events: &[serde_json::Value]) -> String {
let mut body = String::new();
for e in events {
body.push_str(&format!("event: message\ndata: {e}\n\n"));
}
body
}
async fn collect_events(body: &str) -> Vec<ResponseEvent> {
let reader = ReaderStream::new(std::io::Cursor::new(body.to_string()))
.map_err(|err| codex_client::TransportError::Network(err.to_string()));
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent, ApiError>>(16);
tokio::spawn(process_chat_sse(
reader,
tx,
Duration::from_millis(1000),
None,
));
let mut out = Vec::new();
while let Some(ev) = rx.recv().await {
out.push(ev.expect("stream error"));
}
out
}
#[tokio::test]
async fn emits_multiple_tool_calls() {
let delta_a = json!({
"choices": [{
"delta": {
"tool_calls": [{
"id": "call_a",
"function": { "name": "do_a", "arguments": "{\"foo\":1}" }
}]
}
}]
});
let delta_b = json!({
"choices": [{
"delta": {
"tool_calls": [{
"id": "call_b",
"function": { "name": "do_b", "arguments": "{\"bar\":2}" }
}]
}
}]
});
let finish = json!({
"choices": [{
"finish_reason": "tool_calls"
}]
});
let body = build_body(&[delta_a, delta_b, finish]);
let events = collect_events(&body).await;
assert_eq!(events.len(), 3);
assert_matches!(
&events[0],
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. })
if call_id == "call_a" && name == "do_a" && arguments == "{\"foo\":1}"
);
assert_matches!(
&events[1],
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. })
if call_id == "call_b" && name == "do_b" && arguments == "{\"bar\":2}"
);
assert_matches!(events[2], ResponseEvent::Completed { .. });
}
#[tokio::test]
async fn concatenates_tool_call_arguments_across_deltas() {
let delta_name = json!({
"choices": [{
"delta": {
"tool_calls": [{
"id": "call_a",
"function": { "name": "do_a" }
}]
}
}]
});
let delta_args_1 = json!({
"choices": [{
"delta": {
"tool_calls": [{
"id": "call_a",
"function": { "arguments": "{ \"foo\":" }
}]
}
}]
});
let delta_args_2 = json!({
"choices": [{
"delta": {
"tool_calls": [{
"id": "call_a",
"function": { "arguments": "1}" }
}]
}
}]
});
let finish = json!({
"choices": [{
"finish_reason": "tool_calls"
}]
});
let body = build_body(&[delta_name, delta_args_1, delta_args_2, finish]);
let events = collect_events(&body).await;
assert_matches!(
&events[..],
[
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. }),
ResponseEvent::Completed { .. }
] if call_id == "call_a" && name == "do_a" && arguments == "{ \"foo\":1}"
);
}
#[tokio::test]
async fn emits_tool_calls_even_when_content_and_reasoning_present() {
let delta_content_and_tools = json!({
"choices": [{
"delta": {
"content": [{"text": "hi"}],
"reasoning": "because",
"tool_calls": [{
"id": "call_a",
"function": { "name": "do_a", "arguments": "{}" }
}]
}
}]
});
let finish = json!({
"choices": [{
"finish_reason": "tool_calls"
}]
});
let body = build_body(&[delta_content_and_tools, finish]);
let events = collect_events(&body).await;
assert_matches!(
&events[..],
[
ResponseEvent::OutputItemAdded(ResponseItem::Reasoning { .. }),
ResponseEvent::ReasoningContentDelta { .. },
ResponseEvent::OutputItemAdded(ResponseItem::Message { .. }),
ResponseEvent::OutputTextDelta(delta),
ResponseEvent::OutputItemDone(ResponseItem::Reasoning { .. }),
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, .. }),
ResponseEvent::OutputItemDone(ResponseItem::Message { .. }),
ResponseEvent::Completed { .. }
] if delta == "hi" && call_id == "call_a" && name == "do_a"
);
}
#[tokio::test]
async fn drops_partial_tool_calls_on_stop_finish_reason() {
let delta_tool = json!({
"choices": [{
"delta": {
"tool_calls": [{
"id": "call_a",
"function": { "name": "do_a", "arguments": "{}" }
}]
}
}]
});
let finish_stop = json!({
"choices": [{
"finish_reason": "stop"
}]
});
let body = build_body(&[delta_tool, finish_stop]);
let events = collect_events(&body).await;
assert!(!events.iter().any(|ev| {
matches!(
ev,
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { .. })
)
}));
assert_matches!(events.last(), Some(ResponseEvent::Completed { .. }));
}
}

View File

@@ -0,0 +1,6 @@
pub mod chat;
pub mod responses;
pub use responses::process_sse;
pub use responses::spawn_response_stream;
pub use responses::stream_from_fixture;

View File

@@ -0,0 +1,672 @@
use crate::common::ResponseEvent;
use crate::common::ResponseStream;
use crate::error::ApiError;
use crate::rate_limits::parse_rate_limit;
use crate::telemetry::SseTelemetry;
use codex_client::ByteStream;
use codex_client::StreamResponse;
use codex_client::TransportError;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::TokenUsage;
use eventsource_stream::Eventsource;
use futures::StreamExt;
use futures::TryStreamExt;
use serde::Deserialize;
use serde_json::Value;
use std::io::BufRead;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time::Instant;
use tokio::time::timeout;
use tokio_util::io::ReaderStream;
use tracing::debug;
use tracing::trace;
/// Streams SSE events from an on-disk fixture for tests.
pub fn stream_from_fixture(
path: impl AsRef<Path>,
idle_timeout: Duration,
) -> Result<ResponseStream, ApiError> {
let file =
std::fs::File::open(path.as_ref()).map_err(|err| ApiError::Stream(err.to_string()))?;
let mut content = String::new();
for line in std::io::BufReader::new(file).lines() {
let line = line.map_err(|err| ApiError::Stream(err.to_string()))?;
content.push_str(&line);
content.push_str("\n\n");
}
let reader = std::io::Cursor::new(content);
let stream = ReaderStream::new(reader).map_err(|err| TransportError::Network(err.to_string()));
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(1600);
tokio::spawn(process_sse(Box::pin(stream), tx_event, idle_timeout, None));
Ok(ResponseStream { rx_event })
}
pub fn spawn_response_stream(
stream_response: StreamResponse,
idle_timeout: Duration,
telemetry: Option<Arc<dyn SseTelemetry>>,
) -> ResponseStream {
let rate_limits = parse_rate_limit(&stream_response.headers);
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(1600);
tokio::spawn(async move {
if let Some(snapshot) = rate_limits {
let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await;
}
process_sse(stream_response.bytes, tx_event, idle_timeout, telemetry).await;
});
ResponseStream { rx_event }
}
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Error {
r#type: Option<String>,
code: Option<String>,
message: Option<String>,
plan_type: Option<String>,
resets_at: Option<i64>,
}
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct ResponseCompleted {
id: String,
#[serde(default)]
usage: Option<ResponseCompletedUsage>,
}
#[derive(Debug, Deserialize)]
struct ResponseCompletedUsage {
input_tokens: i64,
input_tokens_details: Option<ResponseCompletedInputTokensDetails>,
output_tokens: i64,
output_tokens_details: Option<ResponseCompletedOutputTokensDetails>,
total_tokens: i64,
}
impl From<ResponseCompletedUsage> for TokenUsage {
fn from(val: ResponseCompletedUsage) -> Self {
TokenUsage {
input_tokens: val.input_tokens,
cached_input_tokens: val
.input_tokens_details
.map(|d| d.cached_tokens)
.unwrap_or(0),
output_tokens: val.output_tokens,
reasoning_output_tokens: val
.output_tokens_details
.map(|d| d.reasoning_tokens)
.unwrap_or(0),
total_tokens: val.total_tokens,
}
}
}
#[derive(Debug, Deserialize)]
struct ResponseCompletedInputTokensDetails {
cached_tokens: i64,
}
#[derive(Debug, Deserialize)]
struct ResponseCompletedOutputTokensDetails {
reasoning_tokens: i64,
}
#[derive(Deserialize, Debug)]
struct SseEvent {
#[serde(rename = "type")]
kind: String,
response: Option<Value>,
item: Option<Value>,
delta: Option<String>,
summary_index: Option<i64>,
content_index: Option<i64>,
}
pub async fn process_sse(
stream: ByteStream,
tx_event: mpsc::Sender<Result<ResponseEvent, ApiError>>,
idle_timeout: Duration,
telemetry: Option<Arc<dyn SseTelemetry>>,
) {
let mut stream = stream.eventsource();
let mut response_completed: Option<ResponseCompleted> = None;
let mut response_error: Option<ApiError> = None;
loop {
let start = Instant::now();
let response = timeout(idle_timeout, stream.next()).await;
if let Some(t) = telemetry.as_ref() {
t.on_sse_poll(&response, start.elapsed());
}
let sse = match response {
Ok(Some(Ok(sse))) => sse,
Ok(Some(Err(e))) => {
debug!("SSE Error: {e:#}");
let _ = tx_event.send(Err(ApiError::Stream(e.to_string()))).await;
return;
}
Ok(None) => {
match response_completed.take() {
Some(ResponseCompleted { id, usage }) => {
let event = ResponseEvent::Completed {
response_id: id,
token_usage: usage.map(Into::into),
};
let _ = tx_event.send(Ok(event)).await;
}
None => {
let error = response_error.unwrap_or(ApiError::Stream(
"stream closed before response.completed".into(),
));
let _ = tx_event.send(Err(error)).await;
}
}
return;
}
Err(_) => {
let _ = tx_event
.send(Err(ApiError::Stream("idle timeout waiting for SSE".into())))
.await;
return;
}
};
let raw = sse.data.clone();
trace!("SSE event: {raw}");
let event: SseEvent = match serde_json::from_str(&sse.data) {
Ok(event) => event,
Err(e) => {
debug!("Failed to parse SSE event: {e}, data: {}", &sse.data);
continue;
}
};
match event.kind.as_str() {
"response.output_item.done" => {
let Some(item_val) = event.item else { continue };
let Ok(item) = serde_json::from_value::<ResponseItem>(item_val) else {
debug!("failed to parse ResponseItem from output_item.done");
continue;
};
let event = ResponseEvent::OutputItemDone(item);
if tx_event.send(Ok(event)).await.is_err() {
return;
}
}
"response.output_text.delta" => {
if let Some(delta) = event.delta {
let event = ResponseEvent::OutputTextDelta(delta);
if tx_event.send(Ok(event)).await.is_err() {
return;
}
}
}
"response.reasoning_summary_text.delta" => {
if let (Some(delta), Some(summary_index)) = (event.delta, event.summary_index) {
let event = ResponseEvent::ReasoningSummaryDelta {
delta,
summary_index,
};
if tx_event.send(Ok(event)).await.is_err() {
return;
}
}
}
"response.reasoning_text.delta" => {
if let (Some(delta), Some(content_index)) = (event.delta, event.content_index) {
let event = ResponseEvent::ReasoningContentDelta {
delta,
content_index,
};
if tx_event.send(Ok(event)).await.is_err() {
return;
}
}
}
"response.created" => {
if event.response.is_some() {
let _ = tx_event.send(Ok(ResponseEvent::Created {})).await;
}
}
"response.failed" => {
if let Some(resp_val) = event.response {
response_error =
Some(ApiError::Stream("response.failed event received".into()));
if let Some(error) = resp_val.get("error")
&& let Ok(error) = serde_json::from_value::<Error>(error.clone())
{
if is_context_window_error(&error) {
response_error = Some(ApiError::ContextWindowExceeded);
} else if is_quota_exceeded_error(&error) {
response_error = Some(ApiError::QuotaExceeded);
} else if is_usage_not_included(&error) {
response_error = Some(ApiError::UsageNotIncluded);
} else {
let delay = try_parse_retry_after(&error);
let message = error.message.clone().unwrap_or_default();
response_error = Some(ApiError::Retryable { message, delay });
}
}
}
}
"response.completed" => {
if let Some(resp_val) = event.response {
match serde_json::from_value::<ResponseCompleted>(resp_val) {
Ok(r) => {
response_completed = Some(r);
}
Err(e) => {
let error = format!("failed to parse ResponseCompleted: {e}");
debug!(error);
response_error = Some(ApiError::Stream(error));
continue;
}
};
};
}
"response.output_item.added" => {
let Some(item_val) = event.item else { continue };
let Ok(item) = serde_json::from_value::<ResponseItem>(item_val) else {
debug!("failed to parse ResponseItem from output_item.done");
continue;
};
let event = ResponseEvent::OutputItemAdded(item);
if tx_event.send(Ok(event)).await.is_err() {
return;
}
}
"response.reasoning_summary_part.added" => {
if let Some(summary_index) = event.summary_index {
let event = ResponseEvent::ReasoningSummaryPartAdded { summary_index };
if tx_event.send(Ok(event)).await.is_err() {
return;
}
}
}
_ => {}
}
}
}
fn try_parse_retry_after(err: &Error) -> Option<Duration> {
if err.code.as_deref() != Some("rate_limit_exceeded") {
return None;
}
let re = rate_limit_regex();
if let Some(message) = &err.message
&& let Some(captures) = re.captures(message)
{
let seconds = captures.get(1);
let unit = captures.get(2);
if let (Some(value), Some(unit)) = (seconds, unit) {
let value = value.as_str().parse::<f64>().ok()?;
let unit = unit.as_str().to_ascii_lowercase();
if unit == "s" || unit.starts_with("second") {
return Some(Duration::from_secs_f64(value));
} else if unit == "ms" {
return Some(Duration::from_millis(value as u64));
}
}
}
None
}
fn is_context_window_error(error: &Error) -> bool {
error.code.as_deref() == Some("context_length_exceeded")
}
fn is_quota_exceeded_error(error: &Error) -> bool {
error.code.as_deref() == Some("insufficient_quota")
}
fn is_usage_not_included(error: &Error) -> bool {
error.code.as_deref() == Some("usage_not_included")
}
fn rate_limit_regex() -> &'static regex_lite::Regex {
static RE: std::sync::OnceLock<regex_lite::Regex> = std::sync::OnceLock::new();
#[expect(clippy::unwrap_used)]
RE.get_or_init(|| {
regex_lite::Regex::new(r"(?i)try again in\s*(\d+(?:\.\d+)?)\s*(s|ms|seconds?)").unwrap()
})
}
#[cfg(test)]
mod tests {
use super::*;
use assert_matches::assert_matches;
use codex_protocol::models::ResponseItem;
use pretty_assertions::assert_eq;
use serde_json::json;
use tokio::sync::mpsc;
use tokio_test::io::Builder as IoBuilder;
async fn collect_events(chunks: &[&[u8]]) -> Vec<Result<ResponseEvent, ApiError>> {
let mut builder = IoBuilder::new();
for chunk in chunks {
builder.read(chunk);
}
let reader = builder.build();
let stream =
ReaderStream::new(reader).map_err(|err| TransportError::Network(err.to_string()));
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent, ApiError>>(16);
tokio::spawn(process_sse(Box::pin(stream), tx, idle_timeout(), None));
let mut events = Vec::new();
while let Some(ev) = rx.recv().await {
events.push(ev);
}
events
}
async fn run_sse(events: Vec<serde_json::Value>) -> Vec<ResponseEvent> {
let mut body = String::new();
for e in events {
let kind = e
.get("type")
.and_then(|v| v.as_str())
.expect("fixture event missing type");
if e.as_object().map(|o| o.len() == 1).unwrap_or(false) {
body.push_str(&format!("event: {kind}\n\n"));
} else {
body.push_str(&format!("event: {kind}\ndata: {e}\n\n"));
}
}
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent, ApiError>>(8);
let stream = ReaderStream::new(std::io::Cursor::new(body))
.map_err(|err| TransportError::Network(err.to_string()));
tokio::spawn(process_sse(Box::pin(stream), tx, idle_timeout(), None));
let mut out = Vec::new();
while let Some(ev) = rx.recv().await {
out.push(ev.expect("channel closed"));
}
out
}
fn idle_timeout() -> Duration {
Duration::from_millis(1000)
}
#[tokio::test]
async fn parses_items_and_completed() {
let item1 = json!({
"type": "response.output_item.done",
"item": {
"type": "message",
"role": "assistant",
"content": [{"type": "output_text", "text": "Hello"}]
}
})
.to_string();
let item2 = json!({
"type": "response.output_item.done",
"item": {
"type": "message",
"role": "assistant",
"content": [{"type": "output_text", "text": "World"}]
}
})
.to_string();
let completed = json!({
"type": "response.completed",
"response": { "id": "resp1" }
})
.to_string();
let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n");
let sse2 = format!("event: response.output_item.done\ndata: {item2}\n\n");
let sse3 = format!("event: response.completed\ndata: {completed}\n\n");
let events = collect_events(&[sse1.as_bytes(), sse2.as_bytes(), sse3.as_bytes()]).await;
assert_eq!(events.len(), 3);
assert_matches!(
&events[0],
Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. }))
if role == "assistant"
);
assert_matches!(
&events[1],
Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. }))
if role == "assistant"
);
match &events[2] {
Ok(ResponseEvent::Completed {
response_id,
token_usage,
}) => {
assert_eq!(response_id, "resp1");
assert!(token_usage.is_none());
}
other => panic!("unexpected third event: {other:?}"),
}
}
#[tokio::test]
async fn error_when_missing_completed() {
let item1 = json!({
"type": "response.output_item.done",
"item": {
"type": "message",
"role": "assistant",
"content": [{"type": "output_text", "text": "Hello"}]
}
})
.to_string();
let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n");
let events = collect_events(&[sse1.as_bytes()]).await;
assert_eq!(events.len(), 2);
assert_matches!(events[0], Ok(ResponseEvent::OutputItemDone(_)));
match &events[1] {
Err(ApiError::Stream(msg)) => {
assert_eq!(msg, "stream closed before response.completed")
}
other => panic!("unexpected second event: {other:?}"),
}
}
#[tokio::test]
async fn error_when_error_event() {
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_689bcf18d7f08194bf3440ba62fe05d803fee0cdac429894","object":"response","created_at":1755041560,"status":"failed","background":false,"error":{"code":"rate_limit_exceeded","message":"Rate limit reached for gpt-5.1 in organization org-AAA on tokens per min (TPM): Limit 30000, Used 22999, Requested 12528. Please try again in 11.054s. Visit https://platform.openai.com/account/rate-limits to learn more."}, "usage":null,"user":null,"metadata":{}}}"#;
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
let events = collect_events(&[sse1.as_bytes()]).await;
assert_eq!(events.len(), 1);
match &events[0] {
Err(ApiError::Retryable { message, delay }) => {
assert_eq!(
message,
"Rate limit reached for gpt-5.1 in organization org-AAA on tokens per min (TPM): Limit 30000, Used 22999, Requested 12528. Please try again in 11.054s. Visit https://platform.openai.com/account/rate-limits to learn more."
);
assert_eq!(*delay, Some(Duration::from_secs_f64(11.054)));
}
other => panic!("unexpected second event: {other:?}"),
}
}
#[tokio::test]
async fn context_window_error_is_fatal() {
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_5c66275b97b9baef1ed95550adb3b7ec13b17aafd1d2f11b","object":"response","created_at":1759510079,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again."},"usage":null,"user":null,"metadata":{}}}"#;
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
let events = collect_events(&[sse1.as_bytes()]).await;
assert_eq!(events.len(), 1);
assert_matches!(events[0], Err(ApiError::ContextWindowExceeded));
}
#[tokio::test]
async fn context_window_error_with_newline_is_fatal() {
let raw_error = r#"{"type":"response.failed","sequence_number":4,"response":{"id":"resp_fatal_newline","object":"response","created_at":1759510080,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try\nagain."},"usage":null,"user":null,"metadata":{}}}"#;
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
let events = collect_events(&[sse1.as_bytes()]).await;
assert_eq!(events.len(), 1);
assert_matches!(events[0], Err(ApiError::ContextWindowExceeded));
}
#[tokio::test]
async fn quota_exceeded_error_is_fatal() {
let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_fatal_quota","object":"response","created_at":1759771626,"status":"failed","background":false,"error":{"code":"insufficient_quota","message":"You exceeded your current quota, please check your plan and billing details. For more information on this error, read the docs: https://platform.openai.com/docs/guides/error-codes/api-errors."},"incomplete_details":null}}"#;
let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
let events = collect_events(&[sse1.as_bytes()]).await;
assert_eq!(events.len(), 1);
assert_matches!(events[0], Err(ApiError::QuotaExceeded));
}
#[tokio::test]
async fn table_driven_event_kinds() {
struct TestCase {
name: &'static str,
event: serde_json::Value,
expect_first: fn(&ResponseEvent) -> bool,
expected_len: usize,
}
fn is_created(ev: &ResponseEvent) -> bool {
matches!(ev, ResponseEvent::Created)
}
fn is_output(ev: &ResponseEvent) -> bool {
matches!(ev, ResponseEvent::OutputItemDone(_))
}
fn is_completed(ev: &ResponseEvent) -> bool {
matches!(ev, ResponseEvent::Completed { .. })
}
let completed = json!({
"type": "response.completed",
"response": {
"id": "c",
"usage": {
"input_tokens": 0,
"input_tokens_details": null,
"output_tokens": 0,
"output_tokens_details": null,
"total_tokens": 0
},
"output": []
}
});
let cases = vec![
TestCase {
name: "created",
event: json!({"type": "response.created", "response": {}}),
expect_first: is_created,
expected_len: 2,
},
TestCase {
name: "output_item.done",
event: json!({
"type": "response.output_item.done",
"item": {
"type": "message",
"role": "assistant",
"content": [
{"type": "output_text", "text": "hi"}
]
}
}),
expect_first: is_output,
expected_len: 2,
},
TestCase {
name: "unknown",
event: json!({"type": "response.new_tool_event"}),
expect_first: is_completed,
expected_len: 1,
},
];
for case in cases {
let mut evs = vec![case.event];
evs.push(completed.clone());
let out = run_sse(evs).await;
assert_eq!(out.len(), case.expected_len, "case {}", case.name);
assert!(
(case.expect_first)(&out[0]),
"first event mismatch in case {}",
case.name
);
}
}
#[test]
fn test_try_parse_retry_after() {
let err = Error {
r#type: None,
message: Some("Rate limit reached for gpt-5.1 in organization org- on tokens per min (TPM): Limit 1, Used 1, Requested 19304. Please try again in 28ms. Visit https://platform.openai.com/account/rate-limits to learn more.".to_string()),
code: Some("rate_limit_exceeded".to_string()),
plan_type: None,
resets_at: None,
};
let delay = try_parse_retry_after(&err);
assert_eq!(delay, Some(Duration::from_millis(28)));
}
#[test]
fn test_try_parse_retry_after_no_delay() {
let err = Error {
r#type: None,
message: Some("Rate limit reached for gpt-5.1 in organization <ORG> on tokens per min (TPM): Limit 30000, Used 6899, Requested 24050. Please try again in 1.898s. Visit https://platform.openai.com/account/rate-limits to learn more.".to_string()),
code: Some("rate_limit_exceeded".to_string()),
plan_type: None,
resets_at: None,
};
let delay = try_parse_retry_after(&err);
assert_eq!(delay, Some(Duration::from_secs_f64(1.898)));
}
#[test]
fn test_try_parse_retry_after_azure() {
let err = Error {
r#type: None,
message: Some("Rate limit exceeded. Try again in 35 seconds.".to_string()),
code: Some("rate_limit_exceeded".to_string()),
plan_type: None,
resets_at: None,
};
let delay = try_parse_retry_after(&err);
assert_eq!(delay, Some(Duration::from_secs(35)));
}
}