# PR #1538: Add tests for chat stream aggregation and tool events
- URL: https://github.com/openai/codex/pull/1538
- Author: aibrahim-oai
- Created: 2025-07-11 19:03:03 UTC
- Updated: 2025-07-21 20:58:18 UTC
- Changes: +323/-12, Files changed: 5, Commits: 12
## Description
### Summary
- unit-test `AggregatedChatStream` to ensure it merges assistant message deltas and forwards other items (the contract is sketched below)
- verify parsing of `function_call_output` and `local_shell_call` SSE events
- ensure the chat request payload encodes tool calls correctly
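
A minimal sketch of that aggregation contract, in plain Rust with a hypothetical `Ev` enum standing in for `ResponseEvent` and the `Stream` plumbing omitted (illustrative only, not the crate's API): text deltas accumulate into a single message that is flushed just before the completion event, while everything else is forwarded immediately.

```rust
#[derive(Debug, PartialEq)]
enum Ev {
    Text(String), // assistant OutputText delta
    Tool(String), // any non-text item, e.g. a function call
    Completed,
}

fn aggregate(events: Vec<Ev>) -> Vec<Ev> {
    let mut out = Vec::new();
    let mut buf = String::new();
    for ev in events {
        match ev {
            // Merge consecutive text deltas into one buffer.
            Ev::Text(t) => buf.push_str(&t),
            // Flush the merged message once, then emit Completed.
            Ev::Completed => {
                if !buf.is_empty() {
                    out.push(Ev::Text(std::mem::take(&mut buf)));
                }
                out.push(Ev::Completed);
            }
            // Forward tool calls and other items untouched.
            other => out.push(other),
        }
    }
    out
}

fn main() {
    let got = aggregate(vec![
        Ev::Text("foo".into()),
        Ev::Tool("shell".into()),
        Ev::Text("bar".into()),
        Ev::Completed,
    ]);
    // Mirrors forwards_non_text_items_without_merging in the diff below:
    // the tool call comes first, then the single merged message.
    assert_eq!(
        got,
        vec![Ev::Tool("shell".into()), Ev::Text("foobar".into()), Ev::Completed]
    );
}
```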
### Testing
- `cargo test -p codex-core --manifest-path codex-rs/Cargo.toml`
- `cargo test --manifest-path codex-rs/Cargo.toml --all --tests` *(fails: Sandbox(LandlockRestrict))*
------
https://chatgpt.com/codex/tasks/task_i_687158d61e748321ba5f1631199bd8a4
## Full Diff
```diff
diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs
index ad7b55952a..8eabcaf342 100644
--- a/codex-rs/core/src/chat_completions.rs
+++ b/codex-rs/core/src/chat_completions.rs
@@ -458,6 +458,9 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
/// // event now contains cumulative text
/// }
/// ```
+ ///
+ /// See [`tests::aggregates_consecutive_message_chunks`] for an example.
+ /// ```
fn aggregate(self) -> AggregatedChatStream<Self> {
AggregatedChatStream {
inner: self,
@@ -468,3 +471,237 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
}
impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
+
+#[cfg(test)]
+mod tests {
+ #![allow(clippy::unwrap_used)]
+
+ use super::*;
+ use crate::models::FunctionCallOutputPayload;
+ use crate::models::LocalShellAction;
+ use crate::models::LocalShellExecAction;
+ use crate::models::LocalShellStatus;
+ use crate::openai_tools::create_tools_json_for_chat_completions_api;
+ use futures::StreamExt;
+ use futures::stream;
+ use serde_json::json;
+
+ /// Helper constructing a minimal assistant text chunk.
+ fn text_chunk(txt: &str) -> ResponseEvent {
+ ResponseEvent::OutputItemDone(ResponseItem::Message {
+ role: "assistant".to_string(),
+ content: vec![ContentItem::OutputText { text: txt.into() }],
+ })
+ }
+
+ #[tokio::test]
+ async fn aggregates_consecutive_message_chunks() {
+ let events = vec![
+ Ok(text_chunk("Hello")),
+ Ok(text_chunk(", world")),
+ Ok(ResponseEvent::Completed {
+ response_id: "r1".to_string(),
+ token_usage: None,
+ }),
+ ];
+
+ let stream = stream::iter(events).aggregate();
+ let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+ let expected = vec![
+ ResponseEvent::OutputItemDone(ResponseItem::Message {
+ role: "assistant".into(),
+ content: vec![ContentItem::OutputText {
+ text: "Hello, world".into(),
+ }],
+ }),
+ ResponseEvent::Completed {
+ response_id: "r1".into(),
+ token_usage: None,
+ },
+ ];
+
+ assert_eq!(
+ collected, expected,
+ "aggregated assistant message + Completed"
+ );
+ }
+
+ #[tokio::test]
+ async fn forwards_non_text_items_without_merging() {
+ let func_call = ResponseItem::FunctionCall {
+ name: "shell".to_string(),
+ arguments: "{}".to_string(),
+ call_id: "call1".to_string(),
+ };
+
+ let events = vec![
+ Ok(text_chunk("foo")),
+ Ok(ResponseEvent::OutputItemDone(func_call.clone())),
+ Ok(text_chunk("bar")),
+ Ok(ResponseEvent::Completed {
+ response_id: "r2".to_string(),
+ token_usage: None,
+ }),
+ ];
+
+ let stream = stream::iter(events).aggregate();
+ let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+ let expected = vec![
+ ResponseEvent::OutputItemDone(func_call.clone()),
+ ResponseEvent::OutputItemDone(ResponseItem::Message {
+ role: "assistant".into(),
+ content: vec![ContentItem::OutputText {
+ text: "foobar".into(),
+ }],
+ }),
+ ResponseEvent::Completed {
+ response_id: "r2".into(),
+ token_usage: None,
+ },
+ ];
+
+ assert_eq!(
+ collected, expected,
+ "non-text items forwarded intact; text merged"
+ );
+ }
+
+ #[tokio::test]
+ async fn formats_tool_calls_in_chat_payload() {
+ use std::sync::Arc;
+ use std::sync::Mutex;
+ use wiremock::Mock;
+ use wiremock::MockServer;
+ use wiremock::Request;
+ use wiremock::Respond;
+ use wiremock::ResponseTemplate;
+ use wiremock::matchers::method;
+ use wiremock::matchers::path;
+
+ struct CaptureResponder(Arc<Mutex<Option<serde_json::Value>>>);
+ impl Respond for CaptureResponder {
+ fn respond(&self, req: &Request) -> ResponseTemplate {
+ let v: serde_json::Value = serde_json::from_slice(&req.body).unwrap();
+ *self.0.lock().unwrap() = Some(v);
+ ResponseTemplate::new(200)
+ .insert_header("content-type", "text/event-stream")
+ .set_body_raw(
+ "event: response.completed\n\
+data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp\",\"output\":[]}}\n\n",
+ "text/event-stream",
+ )
+ }
+ }
+
+ let server = MockServer::start().await;
+ let captured = Arc::new(Mutex::new(None));
+
+ Mock::given(method("POST"))
+ .and(path("/v1/chat/completions"))
+ .respond_with(CaptureResponder(captured.clone()))
+ .expect(1)
+ .mount(&server)
+ .await;
+
+ // Build provider pointing at mock server; no need to mutate global env vars.
+ let provider = ModelProviderInfo {
+ name: "openai".into(),
+ base_url: format!("{}/v1", server.uri()),
+ env_key: Some("PATH".into()),
+ env_key_instructions: None,
+ wire_api: crate::WireApi::Chat,
+ query_params: None,
+ http_headers: None,
+ env_http_headers: None,
+ };
+
+ let mut prompt = Prompt::default();
+ prompt.input.push(ResponseItem::Message {
+ role: "user".into(),
+ content: vec![ContentItem::InputText { text: "hi".into() }],
+ });
+ prompt.input.push(ResponseItem::FunctionCall {
+ name: "shell".into(),
+ arguments: "[]".into(),
+ call_id: "call123".into(),
+ });
+ prompt.input.push(ResponseItem::FunctionCallOutput {
+ call_id: "call123".into(),
+ output: FunctionCallOutputPayload {
+ content: "ok".into(),
+ success: Some(true),
+ },
+ });
+ prompt.input.push(ResponseItem::LocalShellCall {
+ id: Some("ls1".into()),
+ call_id: Some("call456".into()),
+ status: LocalShellStatus::Completed,
+ action: LocalShellAction::Exec(LocalShellExecAction {
+ command: vec!["echo".into(), "hi".into()],
+ timeout_ms: Some(1),
+ working_directory: None,
+ env: None,
+ user: None,
+ }),
+ });
+
+ let client = reqwest::Client::new();
+ let _ = stream_chat_completions(&prompt, "model", &client, &provider)
+ .await
+ .unwrap();
+
+ let body = captured.lock().unwrap().take().unwrap();
+
+ // Build the expected payload exactly as stream_chat_completions() should.
+ let full_instructions = prompt.get_full_instructions("model");
+ let expected_messages = vec![
+ json!({"role":"system","content":full_instructions}),
+ json!({"role":"user","content":"hi"}),
+ json!({
+ "role":"assistant",
+ "content":null,
+ "tool_calls":[{
+ "id":"call123",
+ "type":"function",
+ "function":{
+ "name":"shell",
+ "arguments":"[]"
+ }
+ }]
+ }),
+ json!({
+ "role":"tool",
+ "tool_call_id":"call123",
+ "content":"ok"
+ }),
+ json!({
+ "role":"assistant",
+ "content":null,
+ "tool_calls":[{
+ "id":"ls1",
+ "type":"local_shell_call",
+ "status":"completed",
+ "action":{
+ "type":"exec",
+ "command":["echo","hi"],
+ "timeout_ms":1,
+ "working_directory":null,
+ "env":null,
+ "user":null
+ }
+ }]
+ }),
+ ];
+ let tools_json = create_tools_json_for_chat_completions_api(&prompt, "model").unwrap();
+ let expected_body = json!({
+ "model":"model",
+ "messages": expected_messages,
+ "stream": true,
+ "tools": tools_json,
+ });
+
+ assert_eq!(body, expected_body, "chat payload encoded incorrectly");
+ }
+}
diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index 8ec68d02e8..b2ff284fd0 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -317,7 +317,7 @@ where
// duplicated `output` array embedded in the `response.completed`
// payload. That produced two concrete issues:
// 1. No realtime streaming – the user only saw output after the
- // entire turn had finished, which broke the “typing” UX and
+ // entire turn had finished, which broke the "typing" UX and
// made long-running turns look stalled.
// 2. Duplicate `function_call_output` items – both the
// individual *and* the completed array were forwarded, which
@@ -390,6 +390,7 @@ where
}
/// used in tests to stream from a text SSE file
+#[allow(dead_code)]
async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
let f = std::fs::File::open(path.as_ref())?;
@@ -413,6 +414,8 @@ mod tests {
#![allow(clippy::expect_used, clippy::unwrap_used)]
use super::*;
+ use crate::models::LocalShellAction;
+ use crate::models::LocalShellStatus;
use serde_json::json;
use tokio::sync::mpsc;
use tokio_test::io::Builder as IoBuilder;
@@ -422,6 +425,17 @@ mod tests {
// Helpers
// ────────────────────────────
+ /// Build a tiny SSE string with the provided *raw* event chunks (already formatted as
+ /// `"event: ...\ndata: ..."` lines). Each chunk is separated by a blank line.
+ fn build_sse(chunks: &[&str]) -> String {
+ let mut out = String::new();
+ for c in chunks {
+ out.push_str(c);
+ out.push_str("\n\n");
+ }
+ out
+ }
+
/// Runs the SSE parser on pre-chunked byte slices and returns every event
/// (including any final `Err` from a stream-closure check).
async fn collect_events(chunks: &[&[u8]]) -> Vec<Result<ResponseEvent>> {
@@ -469,6 +483,65 @@ mod tests {
out
}
+ // ────────────────────────────
+ // Tests from `implement-unit-tests-for-event-aggregation-and-tool-calls`
+ // ────────────────────────────
+
+ #[tokio::test]
+ async fn parses_function_and_local_shell_items() {
+ let func = "event: response.output_item.done\n\
+data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"function_call_output\",\"call_id\":\"call1\",\"output\":{\"content\":\"ok\",\"success\":true}}}";
+ let shell = "event: response.output_item.done\n\
+data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"local_shell_call\",\"id\":\"ls1\",\"call_id\":\"call2\",\"status\":\"in_progress\",\"action\":{\"type\":\"exec\",\"command\":[\"echo\",\"hi\"],\"timeout_ms\":123,\"working_directory\":null,\"env\":null,\"user\":null}}}";
+ let done = "event: response.completed\n\
+data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp\",\"output\":[]}}";
+
+ let content = build_sse(&[func, shell, done]);
+
+ let (tx, mut rx) = tokio::sync::mpsc::channel::<Result<ResponseEvent>>(8);
+ let stream = ReaderStream::new(std::io::Cursor::new(content)).map_err(CodexErr::Io);
+ tokio::spawn(super::process_sse(stream, tx));
+
+ // function_call_output
+ match rx.recv().await.unwrap().unwrap() {
+ ResponseEvent::OutputItemDone(ResponseItem::FunctionCallOutput { call_id, output }) => {
+ assert_eq!(call_id, "call1");
+ assert_eq!(output.content, "ok");
+ assert_eq!(output.success, Some(true));
+ }
+ other => panic!("unexpected first event: {other:?}"),
+ }
+
+ // local_shell_call
+ match rx.recv().await.unwrap().unwrap() {
+ ResponseEvent::OutputItemDone(ResponseItem::LocalShellCall {
+ id,
+ call_id,
+ status,
+ action,
+ }) => {
+ assert_eq!(id.as_deref(), Some("ls1"));
+ assert_eq!(call_id.as_deref(), Some("call2"));
+ if !matches!(status, LocalShellStatus::InProgress) {
+ panic!("unexpected status: {status:?}");
+ }
+ match action {
+ LocalShellAction::Exec(act) => {
+ assert_eq!(act.command, vec!["echo".to_string(), "hi".to_string()]);
+ assert_eq!(act.timeout_ms, Some(123));
+ }
+ }
+ }
+ other => panic!("unexpected second event: {other:?}"),
+ }
+
+ // completed
+ assert!(matches!(
+ rx.recv().await.unwrap().unwrap(),
+ ResponseEvent::Completed { response_id, .. } if response_id == "resp"
+ ));
+ }
+
// ────────────────────────────
// Tests from `implement-test-for-responses-api-sse-parser`
// ────────────────────────────
@@ -549,6 +622,7 @@ mod tests {
let events = collect_events(&[sse1.as_bytes()]).await;
+ // We expect the item + a final Err complaining about the missing completed event.
assert_eq!(events.len(), 2);
matches!(events[0], Ok(ResponseEvent::OutputItemDone(_)));
diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs
index 3e3c2e7efa..6b220d4fff 100644
--- a/codex-rs/core/src/client_common.rs
+++ b/codex-rs/core/src/client_common.rs
@@ -49,7 +49,7 @@ impl Prompt {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone, PartialEq)]
pub enum ResponseEvent {
Created,
OutputItemDone(ResponseItem),
diff --git a/codex-rs/core/src/models.rs b/codex-rs/core/src/models.rs
index 6b392fb19d..26babba715 100644
--- a/codex-rs/core/src/models.rs
+++ b/codex-rs/core/src/models.rs
@@ -8,7 +8,7 @@ use serde::ser::Serializer;
use crate::protocol::InputItem;
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseInputItem {
Message {
@@ -25,7 +25,7 @@ pub enum ResponseInputItem {
},
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentItem {
InputText { text: String },
@@ -33,7 +33,7 @@ pub enum ContentItem {
OutputText { text: String },
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseItem {
Message {
@@ -99,7 +99,7 @@ impl From<ResponseInputItem> for ResponseItem {
}
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum LocalShellStatus {
Completed,
@@ -107,13 +107,13 @@ pub enum LocalShellStatus {
Incomplete,
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum LocalShellAction {
Exec(LocalShellExecAction),
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct LocalShellExecAction {
pub command: Vec<String>,
pub timeout_ms: Option<u64>,
@@ -122,7 +122,7 @@ pub struct LocalShellExecAction {
pub user: Option<String>,
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ReasoningItemReasoningSummary {
SummaryText { text: String },
@@ -177,10 +177,10 @@ pub struct ShellToolCallParams {
pub timeout_ms: Option<u64>,
}
-#[derive(Deserialize, Debug, Clone)]
+#[derive(Deserialize, Debug, Clone, PartialEq)]
pub struct FunctionCallOutputPayload {
pub content: String,
- #[expect(dead_code)]
+ #[allow(dead_code)]
pub success: Option<bool>,
}
diff --git a/codex-rs/core/src/protocol.rs b/codex-rs/core/src/protocol.rs
index b233d4f27b..c14b2e190a 100644
--- a/codex-rs/core/src/protocol.rs
+++ b/codex-rs/core/src/protocol.rs
@@ -332,7 +332,7 @@ pub struct TaskCompleteEvent {
pub last_agent_message: Option<String>,
}
-#[derive(Debug, Clone, Deserialize, Serialize, Default)]
+#[derive(Debug, Clone, Deserialize, Serialize, Default, PartialEq)]
pub struct TokenUsage {
pub input_tokens: u64,
pub cached_input_tokens: Option<u64>,
```
## Review Comments
### codex-rs/core/src/chat_completions.rs
- Created: 2025-07-12 19:43:31 UTC | Link: https://github.com/openai/codex/pull/1538#discussion_r2202890339
```diff
@@ -462,3 +465,228 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
}
impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
+
+#[cfg(test)]
+mod tests {
+ #![allow(clippy::unwrap_used)]
+
+ use super::*;
+ use crate::models::FunctionCallOutputPayload;
+ use crate::models::LocalShellAction;
+ use crate::models::LocalShellExecAction;
+ use crate::models::LocalShellStatus;
+ use futures::StreamExt;
+ use futures::stream;
+
+ /// Helper constructing a minimal assistant text chunk.
+ fn text_chunk(txt: &str) -> ResponseEvent {
+ ResponseEvent::OutputItemDone(ResponseItem::Message {
+ role: "assistant".to_string(),
+ content: vec![ContentItem::OutputText { text: txt.into() }],
+ })
+ }
+
+ #[tokio::test]
+ async fn aggregates_consecutive_message_chunks() {
+ let events = vec![
+ Ok(text_chunk("Hello")),
+ Ok(text_chunk(", world")),
+ Ok(ResponseEvent::Completed {
+ response_id: "r1".to_string(),
+ token_usage: None,
+ }),
+ ];
+
+ let stream = stream::iter(events).aggregate();
+ let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+ assert_eq!(collected.len(), 2, "only final message and Completed");
```
> just `assert_eq!()` on all of `collected`?
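
This suggestion was adopted: the PR derives `PartialEq` on `ResponseEvent` and the model types precisely so the tests can build an `expected` vector and compare all of `collected` in a single `assert_eq!`, as `aggregates_consecutive_message_chunks` does in the full diff above.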
- Created: 2025-07-12 19:44:08 UTC | Link: https://github.com/openai/codex/pull/1538#discussion_r2202890464
```diff
@@ -462,3 +465,228 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
+ let stream = stream::iter(events).aggregate();
+ let collected: Vec<_> = stream.map(Result::unwrap).collect().await;
+
+ assert_eq!(collected.len(), 3);
```
> same here
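
Likewise adopted: `forwards_non_text_items_without_merging` in the full diff above now compares the whole `collected` vector against an `expected` vector.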
- Created: 2025-07-12 19:45:40 UTC | Link: https://github.com/openai/codex/pull/1538#discussion_r2202891601
```diff
@@ -462,3 +465,228 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
+ let server = MockServer::start().await;
+ let captured = Arc::new(Mutex::new(None));
+
+ Mock::given(method("POST"))
+ .and(path("/v1/chat/completions"))
+ .respond_with(CaptureResponder(captured.clone()))
+ .expect(1)
+ .mount(&server)
+ .await;
+
+ unsafe {
+ std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
```
> At some point, we should really find another way to thread this through so we can eliminate all these `unsafe` blocks.
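
The merged version above sidesteps this entirely (the provider simply points at the mock server, so no env mutation is needed), but for the broader point, here is one hedged sketch of threading the retry setting through a config value instead of process-global env vars. `RetryPolicy` and its field are purely illustrative, not the crate's actual API:

```rust
// Hypothetical alternative to `unsafe { std::env::set_var(...) }` in tests:
// read the env var once at construction time and pass the policy explicitly.
#[derive(Clone, Copy, Debug)]
struct RetryPolicy {
    max_retries: u32,
}

impl Default for RetryPolicy {
    fn default() -> Self {
        // Consult OPENAI_REQUEST_MAX_RETRIES here, once, rather than at
        // every request site; the fallback value is illustrative.
        let max_retries = std::env::var("OPENAI_REQUEST_MAX_RETRIES")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(4);
        Self { max_retries }
    }
}

fn main() {
    // Production code takes the default; tests construct the policy
    // directly, with no unsafe global mutation.
    let prod = RetryPolicy::default();
    let test = RetryPolicy { max_retries: 0 };
    println!("prod={prod:?} test={test:?}");
}
```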
- Created: 2025-07-12 19:46:18 UTC | Link: https://github.com/openai/codex/pull/1538#discussion_r2202892586
```diff
@@ -462,3 +465,228 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
+ let client = reqwest::Client::new();
+ let _ = stream_chat_completions(&prompt, "model", &client, &provider)
+ .await
+ .unwrap();
+
+ let body = captured.lock().unwrap().take().unwrap();
+ let messages = body.get("messages").unwrap().as_array().unwrap();
```
> `assert_eq!()` for `body`
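
Also addressed in the final revision: the test now constructs an `expected_body` with `serde_json::json!` (model, messages, stream flag, and tools) and finishes with a single `assert_eq!(body, expected_body, ...)`, as shown in the full diff above.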