Compare commits

...

1 Commits

Author SHA1 Message Date
pakrym-oai
9525434812 Scope function_call_output assertions to specific mocks 2025-10-09 11:11:06 -07:00

View File

@@ -1,7 +1,5 @@
#![cfg(not(target_os = "windows"))]
use std::collections::HashMap;
use anyhow::Result;
use codex_core::protocol::AskForApproval;
use codex_core::protocol::EventMsg;
@@ -9,11 +7,12 @@ use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::SandboxPolicy;
use codex_protocol::config_types::ReasoningSummary;
use core_test_support::responses::ResponseMock;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::mount_sse_once;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
@@ -31,24 +30,21 @@ fn extract_output_text(item: &Value) -> Option<&str> {
})
}
fn collect_tool_outputs(bodies: &[Value]) -> Result<HashMap<String, Value>> {
let mut outputs = HashMap::new();
for body in bodies {
if let Some(items) = body.get("input").and_then(Value::as_array) {
for item in items {
if item.get("type").and_then(Value::as_str) != Some("function_call_output") {
continue;
}
if let Some(call_id) = item.get("call_id").and_then(Value::as_str) {
let content = extract_output_text(item)
.ok_or_else(|| anyhow::anyhow!("missing tool output content"))?;
let parsed: Value = serde_json::from_str(content)?;
outputs.insert(call_id.to_string(), parsed);
}
}
}
}
Ok(outputs)
fn function_call_output_json(mock: &ResponseMock, call_id: &str) -> Result<Value> {
let request = mock
.requests()
.into_iter()
.find(|request| {
request.input().iter().any(|item| {
item.get("type").and_then(Value::as_str) == Some("function_call_output")
&& item.get("call_id").and_then(Value::as_str) == Some(call_id)
})
})
.ok_or_else(|| anyhow::anyhow!("missing {call_id} function_call_output"))?;
let item = request.function_call_output(call_id);
let content = extract_output_text(&item)
.ok_or_else(|| anyhow::anyhow!("missing tool output content for {call_id}"))?;
Ok(serde_json::from_str(content)?)
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
@@ -81,7 +77,8 @@ async fn unified_exec_reuses_session_via_stdin() -> Result<()> {
"timeout_ms": 500,
});
let responses = vec![
let _first_mock = mount_sse_once(
&server,
sse(vec![
ev_response_created("resp-1"),
ev_function_call(
@@ -91,6 +88,10 @@ async fn unified_exec_reuses_session_via_stdin() -> Result<()> {
),
ev_completed("resp-1"),
]),
)
.await;
let _second_mock = mount_sse_once(
&server,
sse(vec![
ev_response_created("resp-2"),
ev_function_call(
@@ -100,12 +101,16 @@ async fn unified_exec_reuses_session_via_stdin() -> Result<()> {
),
ev_completed("resp-2"),
]),
)
.await;
let final_mock = mount_sse_once(
&server,
sse(vec![
ev_assistant_message("msg-1", "all done"),
ev_completed("resp-3"),
]),
];
mount_sse_sequence(&server, responses).await;
)
.await;
let session_model = session_configured.model.clone();
@@ -126,19 +131,7 @@ async fn unified_exec_reuses_session_via_stdin() -> Result<()> {
wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;
let requests = server.received_requests().await.expect("recorded requests");
assert!(!requests.is_empty(), "expected at least one POST request");
let bodies = requests
.iter()
.map(|req| req.body_json::<Value>().expect("request json"))
.collect::<Vec<_>>();
let outputs = collect_tool_outputs(&bodies)?;
let start_output = outputs
.get(first_call_id)
.expect("missing first unified_exec output");
let start_output = function_call_output_json(&final_mock, first_call_id)?;
let session_id = start_output["session_id"].as_str().unwrap_or_default();
assert!(
!session_id.is_empty(),
@@ -151,9 +144,7 @@ async fn unified_exec_reuses_session_via_stdin() -> Result<()> {
.is_empty()
);
let reuse_output = outputs
.get(second_call_id)
.expect("missing reused unified_exec output");
let reuse_output = function_call_output_json(&final_mock, second_call_id)?;
assert_eq!(
reuse_output["session_id"].as_str().unwrap_or_default(),
session_id
@@ -216,7 +207,8 @@ PY
"timeout_ms": 800,
});
let responses = vec![
let _first_mock = mount_sse_once(
&server,
sse(vec![
ev_response_created("resp-1"),
ev_function_call(
@@ -226,6 +218,10 @@ PY
),
ev_completed("resp-1"),
]),
)
.await;
let _second_mock = mount_sse_once(
&server,
sse(vec![
ev_response_created("resp-2"),
ev_function_call(
@@ -235,12 +231,16 @@ PY
),
ev_completed("resp-2"),
]),
)
.await;
let final_mock = mount_sse_once(
&server,
sse(vec![
ev_assistant_message("msg-1", "lag handled"),
ev_completed("resp-3"),
]),
];
mount_sse_sequence(&server, responses).await;
)
.await;
let session_model = session_configured.model.clone();
@@ -261,28 +261,14 @@ PY
wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;
let requests = server.received_requests().await.expect("recorded requests");
assert!(!requests.is_empty(), "expected at least one POST request");
let bodies = requests
.iter()
.map(|req| req.body_json::<Value>().expect("request json"))
.collect::<Vec<_>>();
let outputs = collect_tool_outputs(&bodies)?;
let start_output = outputs
.get(first_call_id)
.expect("missing initial unified_exec output");
let start_output = function_call_output_json(&final_mock, first_call_id)?;
let session_id = start_output["session_id"].as_str().unwrap_or_default();
assert!(
!session_id.is_empty(),
"expected session id from initial unified_exec response"
);
let poll_output = outputs
.get(second_call_id)
.expect("missing poll unified_exec output");
let poll_output = function_call_output_json(&final_mock, second_call_id)?;
let poll_text = poll_output["output"].as_str().unwrap_or_default();
assert!(
poll_text.contains("TAIL-MARKER"),
@@ -322,7 +308,8 @@ async fn unified_exec_timeout_and_followup_poll() -> Result<()> {
"timeout_ms": 800,
});
let responses = vec![
let _first_mock = mount_sse_once(
&server,
sse(vec![
ev_response_created("resp-1"),
ev_function_call(
@@ -332,6 +319,10 @@ async fn unified_exec_timeout_and_followup_poll() -> Result<()> {
),
ev_completed("resp-1"),
]),
)
.await;
let _second_mock = mount_sse_once(
&server,
sse(vec![
ev_response_created("resp-2"),
ev_function_call(
@@ -341,12 +332,16 @@ async fn unified_exec_timeout_and_followup_poll() -> Result<()> {
),
ev_completed("resp-2"),
]),
)
.await;
let final_mock = mount_sse_once(
&server,
sse(vec![
ev_assistant_message("msg-1", "done"),
ev_completed("resp-3"),
]),
];
mount_sse_sequence(&server, responses).await;
)
.await;
let session_model = session_configured.model.clone();
@@ -372,17 +367,7 @@ async fn unified_exec_timeout_and_followup_poll() -> Result<()> {
}
}
let requests = server.received_requests().await.expect("recorded requests");
assert!(!requests.is_empty(), "expected at least one POST request");
let bodies = requests
.iter()
.map(|req| req.body_json::<Value>().expect("request json"))
.collect::<Vec<_>>();
let outputs = collect_tool_outputs(&bodies)?;
let first_output = outputs.get(first_call_id).expect("missing timeout output");
let first_output = function_call_output_json(&final_mock, first_call_id)?;
assert_eq!(first_output["session_id"], "0");
assert!(
first_output["output"]
@@ -391,7 +376,7 @@ async fn unified_exec_timeout_and_followup_poll() -> Result<()> {
.is_empty()
);
let poll_output = outputs.get(second_call_id).expect("missing poll output");
let poll_output = function_call_output_json(&final_mock, second_call_id)?;
let output_text = poll_output["output"].as_str().unwrap_or_default();
assert!(
output_text.contains("ready"),