Compare commits

..

3 Commits

Author SHA1 Message Date
aibrahim-oai
e2343e710b Fix SSE parser test clippy issues 2025-07-11 13:19:32 -07:00
aibrahim-oai
64897d9083 Format SSE parser test 2025-07-11 12:09:42 -07:00
aibrahim-oai
e274cd04de test: ensure chat completions SSE parser merges chunks 2025-07-11 12:01:27 -07:00
4 changed files with 76 additions and 149 deletions

View File

@@ -462,3 +462,66 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
}
impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
#[cfg(test)]
mod tests {
use super::*;
use bytes::Bytes;
use futures::stream;
use tokio::sync::mpsc;
#[tokio::test]
async fn merges_function_call_chunks_and_completes() {
let chunks = vec![
"data: {\"choices\":[{\"delta\":{\"content\":\"Hello \"}}]}\n\n",
"data: {\"choices\":[{\"delta\":{\"content\":\"world\"}}]}\n\n",
"data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call1\",\"type\":\"function\",\"function\":{\"name\":\"foo\"}}]}}]}\n\n",
"data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"{\\\"a\\\": \"}}]}}]}\n\n",
"data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"1}\"}}]},\"finish_reason\":\"tool_calls\"}]}\n\n",
];
let byte_stream = stream::iter(chunks.into_iter().map(|s| Ok(Bytes::from(s))));
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent>>(8);
tokio::spawn(process_chat_sse(byte_stream, tx));
let mut events = Vec::new();
while let Some(ev) = rx.recv().await {
match ev {
Ok(event) => events.push(event),
Err(e) => panic!("stream error: {e}"),
}
}
assert_eq!(events.len(), 4);
let mut text = String::new();
for (i, event) in events.iter().take(2).enumerate() {
match event {
ResponseEvent::OutputItemDone(ResponseItem::Message { role, content })
if role == "assistant" =>
{
if let Some(ContentItem::OutputText { text: t }) = content.first() {
text.push_str(t);
}
}
other => panic!("unexpected event {i}: {other:?}"),
}
}
assert_eq!(text, "Hello world");
match &events[2] {
ResponseEvent::OutputItemDone(ResponseItem::FunctionCall {
name,
arguments,
call_id,
}) => {
assert_eq!(name, "foo");
assert_eq!(call_id, "call1");
assert_eq!(arguments, "{\"a\": 1}");
}
other => panic!("unexpected third event: {other:?}"),
}
assert!(matches!(events[3], ResponseEvent::Completed { .. }));
}
}

View File

@@ -391,85 +391,3 @@ async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
tokio::spawn(process_sse(stream, tx_event));
Ok(ResponseStream { rx_event })
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
fn make_stream(events: Vec<serde_json::Value>) -> impl Stream<Item = Result<Bytes>> {
use futures::TryStreamExt as _;
let mut buf = String::new();
for ev in &events {
let kind = ev["type"].as_str().unwrap();
buf.push_str("event: ");
buf.push_str(kind);
buf.push('\n');
buf.push_str("data: ");
buf.push_str(&ev.to_string());
buf.push_str("\n\n");
}
ReaderStream::new(std::io::Cursor::new(buf))
.map_ok(Bytes::from)
.map_err(CodexErr::Io)
}
#[tokio::test]
async fn table_driven_event_kinds() {
struct Case {
event: serde_json::Value,
expect: Option<&'static str>,
}
let cases = vec![
Case {
event: json!({ "type": "response.created", "response": {} }),
expect: Some("Created"),
},
Case {
event: json!({
"type": "response.output_item.done",
"item": { "type": "message", "role": "assistant", "content": [] }
}),
expect: Some("OutputItemDone"),
},
Case {
event: json!({ "type": "response.in_progress" }),
expect: None,
},
Case {
event: json!({ "type": "unknown.event" }),
expect: None,
},
];
for case in cases {
let completed =
json!({ "type": "response.completed", "response": { "id": "test", "output": [] } });
let stream = make_stream(vec![case.event.clone(), completed]);
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent>>(8);
tokio::spawn(process_sse(stream, tx));
let mut got = Vec::new();
while let Some(ev) = rx.recv().await {
got.push(ev.unwrap());
}
assert!(
matches!(got.last(), Some(ResponseEvent::Completed { response_id, .. }) if response_id == "test")
);
let non_completed = &got[..got.len() - 1];
match case.expect {
Some("Created") => {
assert!(matches!(non_completed.get(0), Some(ResponseEvent::Created)))
}
Some("OutputItemDone") => assert!(matches!(
non_completed.get(0),
Some(ResponseEvent::OutputItemDone(_))
)),
None => assert!(non_completed.is_empty()),
_ => unreachable!(),
}
}
}
}

View File

@@ -1,14 +0,0 @@
{
"first": [
{
"type": "response.output_item.done",
"item": { "type": "message", "role": "assistant", "content": [] }
}
],
"second": [
{
"type": "response.completed",
"response": { "id": "resp_ok", "output": [] }
}
]
}

View File

@@ -21,48 +21,16 @@ use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;
/// Load the SSE event sequences for the test from a JSON fixture.
///
/// The fixture lives in `tests/fixtures/stream_retry.json` and defines two
/// arrays of events: `first` and `second`. Each entry represents the JSON body
/// of a single SSE event. When the Responses API evolves with new fields or
/// event kinds simply update that JSON file or add a new one and reference it
/// here.
fn load_fixture() -> (String, String) {
use serde_json::Value;
use std::fs;
use std::path::PathBuf;
fn sse_incomplete() -> String {
// Only a single line; missing the completed event.
"event: response.output_item.done\n\n".to_string()
}
fn events_to_sse(events: &[Value]) -> String {
let mut out = String::new();
for ev in events {
let kind = ev
.get("type")
.and_then(Value::as_str)
.expect("event missing type");
out.push_str("event: ");
out.push_str(kind);
out.push('\n');
out.push_str("data: ");
out.push_str(&ev.to_string());
out.push_str("\n\n");
}
out
}
let path: PathBuf = [
env!("CARGO_MANIFEST_DIR"),
"tests",
"fixtures",
"stream_retry.json",
]
.iter()
.collect();
let raw = fs::read_to_string(path).expect("fixture missing");
let v: Value = serde_json::from_str(&raw).expect("invalid fixture JSON");
let first = events_to_sse(v.get("first").and_then(Value::as_array).unwrap());
let second = events_to_sse(v.get("second").and_then(Value::as_array).unwrap());
(first, second)
fn sse_completed(id: &str) -> String {
format!(
"event: response.completed\n\
data: {{\"type\":\"response.completed\",\"response\":{{\"id\":\"{id}\",\"output\":[]}}}}\n\n\n"
)
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
@@ -77,13 +45,8 @@ async fn retries_on_early_close() {
}
let server = MockServer::start().await;
// Convert the JSON fixture into SSE event strings for the two mock calls.
let (sse_first, sse_second) = load_fixture();
struct SeqResponder {
first: String,
second: String,
}
struct SeqResponder;
impl Respond for SeqResponder {
fn respond(&self, _: &Request) -> ResponseTemplate {
use std::sync::atomic::AtomicUsize;
@@ -93,21 +56,18 @@ async fn retries_on_early_close() {
if n == 0 {
ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(self.first.clone(), "text/event-stream")
.set_body_raw(sse_incomplete(), "text/event-stream")
} else {
ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(self.second.clone(), "text/event-stream")
.set_body_raw(sse_completed("resp_ok"), "text/event-stream")
}
}
}
Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(SeqResponder {
first: sse_first,
second: sse_second,
})
.respond_with(SeqResponder {})
.expect(2)
.mount(&server)
.await;