stream init

This commit is contained in:
Ahmed Ibrahim
2025-08-01 17:25:50 -07:00
parent 97ab8fb610
commit 3a456c1fbb
8 changed files with 585 additions and 39 deletions

View File

@@ -17,6 +17,7 @@ use codex_mcp_server::CodexToolCallReplyParam;
use codex_mcp_server::mcp_protocol::ConversationCreateArgs;
use codex_mcp_server::mcp_protocol::ConversationId;
use codex_mcp_server::mcp_protocol::ConversationSendMessageArgs;
use codex_mcp_server::mcp_protocol::ConversationStreamArgs;
use codex_mcp_server::mcp_protocol::ToolCallRequestParams;
use mcp_types::CallToolRequestParams;
@@ -201,6 +202,20 @@ impl McpProcess {
.await
}
pub async fn send_conversation_stream_tool_call(
&mut self,
session_id: &str,
) -> anyhow::Result<i64> {
let params = ToolCallRequestParams::ConversationStream(ConversationStreamArgs {
conversation_id: ConversationId(Uuid::parse_str(session_id)?),
});
self.send_request(
mcp_types::CallToolRequest::METHOD,
Some(serde_json::to_value(params)?),
)
.await
}
pub async fn send_conversation_create_tool_call(
&mut self,
prompt: &str,
@@ -236,6 +251,99 @@ impl McpProcess {
.await
}
/// Create a conversation and return its conversation_id as a string.
pub async fn create_conversation_and_get_id(
&mut self,
prompt: &str,
model: &str,
cwd: &str,
) -> anyhow::Result<String> {
let req_id = self
.send_conversation_create_tool_call(prompt, model, cwd)
.await?;
let resp = self
.read_stream_until_response_message(RequestId::Integer(req_id))
.await?;
let conv_id = resp.result["structuredContent"]["conversation_id"]
.as_str()
.ok_or_else(|| anyhow::format_err!("missing conversation_id"))?
.to_string();
Ok(conv_id)
}
/// Connect stream for a conversation and wait for the initial_state notification.
/// Returns the params of the initial_state notification for further inspection.
pub async fn connect_stream_and_expect_initial_state(
&mut self,
session_id: &str,
) -> anyhow::Result<serde_json::Value> {
let req_id = self.send_conversation_stream_tool_call(session_id).await?;
// Wait for stream() tool-call response first
let _ = self
.read_stream_until_response_message(RequestId::Integer(req_id))
.await?;
// Then the initial_state notification
let note = self
.read_stream_until_notification_method("notifications/initial_state")
.await?;
note.params
.ok_or_else(|| anyhow::format_err!("initial_state must have params"))
}
/// Connect stream and also return the request id for later cancellation.
pub async fn connect_stream_get_req_and_initial_state(
&mut self,
session_id: &str,
) -> anyhow::Result<(i64, serde_json::Value)> {
let req_id = self.send_conversation_stream_tool_call(session_id).await?;
let _ = self
.read_stream_until_response_message(RequestId::Integer(req_id))
.await?;
let note = self
.read_stream_until_notification_method("notifications/initial_state")
.await?;
let params = note
.params
.ok_or_else(|| anyhow::format_err!("initial_state must have params"))?;
Ok((req_id, params))
}
/// Wait for an agent_message with a bounded timeout. Returns Some(params) if received, None on timeout.
pub async fn maybe_wait_for_agent_message(
&mut self,
dur: std::time::Duration,
) -> anyhow::Result<Option<serde_json::Value>> {
match tokio::time::timeout(dur, self.wait_for_agent_message()).await {
Ok(Ok(v)) => Ok(Some(v)),
Ok(Err(e)) => Err(e),
Err(_elapsed) => Ok(None),
}
}
/// Send a user message to a conversation and wait for the OK tool-call response.
pub async fn send_user_message_and_wait_ok(
&mut self,
message: &str,
session_id: &str,
) -> anyhow::Result<()> {
let req_id = self
.send_user_message_tool_call(message, session_id)
.await?;
let _ = self
.read_stream_until_response_message(RequestId::Integer(req_id))
.await?;
Ok(())
}
/// Wait until an agent_message notification arrives; returns its params.
pub async fn wait_for_agent_message(&mut self) -> anyhow::Result<serde_json::Value> {
let note = self
.read_stream_until_notification_method("agent_message")
.await?;
note.params
.ok_or_else(|| anyhow::format_err!("agent_message missing params"))
}
async fn send_request(
&mut self,
method: &str,
@@ -329,6 +437,31 @@ impl McpProcess {
}
}
pub async fn read_stream_until_notification_method(
&mut self,
method: &str,
) -> anyhow::Result<JSONRPCNotification> {
loop {
let message = self.read_jsonrpc_message().await?;
match message {
JSONRPCMessage::Notification(n) => {
if n.method == method {
return Ok(n);
}
}
JSONRPCMessage::Request(_) => {
// ignore
}
JSONRPCMessage::Error(_) => {
anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}");
}
JSONRPCMessage::Response(_) => {
// ignore
}
}
}
}
pub async fn read_stream_until_configured_response_message(
&mut self,
) -> anyhow::Result<String> {

View File

@@ -0,0 +1,228 @@
#![allow(clippy::expect_used, clippy::unwrap_used)]
use std::path::Path;
use mcp_test_support::McpProcess;
use mcp_test_support::create_final_assistant_message_sse_response;
use mcp_test_support::create_mock_chat_completions_server;
use mcp_types::JSONRPCNotification;
use pretty_assertions::assert_eq;
use serde_json::json;
use tempfile::TempDir;
use tokio::time::timeout;
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_connect_then_send_receives_initial_state_and_notifications() {
let responses = vec![
create_final_assistant_message_sse_response("Done").expect("build mock assistant message"),
];
let server = create_mock_chat_completions_server(responses).await;
let codex_home = TempDir::new().expect("create temp dir");
create_config_toml(codex_home.path(), &server.uri()).expect("write config.toml");
let mut mcp = McpProcess::new(codex_home.path())
.await
.expect("spawn mcp process");
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
.await
.expect("init timeout")
.expect("init failed");
// Create conversation
let conv_id = mcp
.create_conversation_and_get_id("", "o3", "/repo")
.await
.expect("create conversation");
// Connect the stream
let params = mcp
.connect_stream_and_expect_initial_state(&conv_id)
.await
.expect("initial_state params");
assert_eq!(
params["_meta"]["conversationId"].as_str(),
Some(conv_id.as_str())
);
assert!(params["initial_state"]["events"].is_array());
assert!(
params["initial_state"]["events"]
.as_array()
.unwrap()
.is_empty()
);
// Send a message and expect a subsequent notification (non-initial_state)
mcp.send_user_message_and_wait_ok("Hello there", &conv_id)
.await
.expect("send message ok");
// Read until we see an event notification (new schema example: agent_message)
let params = mcp.wait_for_agent_message().await.expect("agent message");
assert_eq!(params["msg"]["type"].as_str(), Some("agent_message"));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_send_then_connect_receives_initial_state_with_message() {
let responses = vec![
create_final_assistant_message_sse_response("Done").expect("build mock assistant message"),
];
let server = create_mock_chat_completions_server(responses).await;
let codex_home = TempDir::new().expect("create temp dir");
create_config_toml(codex_home.path(), &server.uri()).expect("write config.toml");
let mut mcp = McpProcess::new(codex_home.path())
.await
.expect("spawn mcp process");
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
.await
.expect("init timeout")
.expect("init failed");
// Create conversation
let conv_id = mcp
.create_conversation_and_get_id("", "o3", "/repo")
.await
.expect("create conversation");
// Send a message BEFORE connecting stream
mcp.send_user_message_and_wait_ok("Hello world", &conv_id)
.await
.expect("send message ok");
// Now connect stream and expect InitialState with the prior message included
let params = mcp
.connect_stream_and_expect_initial_state(&conv_id)
.await
.expect("initial_state params");
let events = params["initial_state"]["events"]
.as_array()
.expect("events array");
if !events.iter().any(|ev| {
ev.get("msg")
.and_then(|m| m.get("type"))
.and_then(|t| t.as_str())
== Some("agent_message")
&& ev
.get("msg")
.and_then(|m| m.get("message"))
.and_then(|t| t.as_str())
== Some("Done")
}) {
// Fallback to live notification if not present in initial state
let note: JSONRPCNotification = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_notification_method("agent_message"),
)
.await
.expect("event note timeout")
.expect("event note err");
let params = note.params.expect("params");
assert_eq!(params["msg"]["type"].as_str(), Some("agent_message"));
assert_eq!(params["msg"]["message"].as_str(), Some("Done"));
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_cancel_stream_then_reconnect_catches_up_initial_state() {
// One response is sufficient for the assertions in this test
let responses = vec![
create_final_assistant_message_sse_response("Done").expect("build mock assistant message"),
];
let server = create_mock_chat_completions_server(responses).await;
let codex_home = TempDir::new().expect("create temp dir");
create_config_toml(codex_home.path(), &server.uri()).expect("write config.toml");
let mut mcp = McpProcess::new(codex_home.path())
.await
.expect("spawn mcp process");
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
.await
.expect("init timeout")
.expect("init failed");
// Create and connect stream A
let conv_id = mcp
.create_conversation_and_get_id("", "o3", "/repo")
.await
.expect("create");
let (stream_a_id, _params) = mcp
.connect_stream_get_req_and_initial_state(&conv_id)
.await
.expect("stream A initial_state");
// Send M1 and ensure we get live agent_message
mcp.send_user_message_and_wait_ok("Hello M1", &conv_id)
.await
.expect("send M1");
let _params = mcp.wait_for_agent_message().await.expect("agent M1");
// Cancel stream A
mcp.send_notification(
"notifications/cancelled",
Some(json!({ "requestId": stream_a_id })),
)
.await
.expect("send cancelled");
// Send M2 while stream is cancelled; we should NOT get agent_message live
mcp.send_user_message_and_wait_ok("Hello M2", &conv_id)
.await
.expect("send M2");
let maybe = mcp
.maybe_wait_for_agent_message(std::time::Duration::from_millis(300))
.await
.expect("maybe wait");
assert!(
maybe.is_none(),
"should not get live agent_message after cancel"
);
// Connect stream B and expect initial_state that includes the response
let params = mcp
.connect_stream_and_expect_initial_state(&conv_id)
.await
.expect("stream B initial_state");
let events = params["initial_state"]["events"]
.as_array()
.expect("events array");
assert!(events.iter().any(|ev| {
ev.get("msg")
.and_then(|m| m.get("type"))
.and_then(|t| t.as_str())
== Some("agent_message")
&& ev
.get("msg")
.and_then(|m| m.get("message"))
.and_then(|t| t.as_str())
== Some("Done")
}));
}
// Helper to create a config.toml pointing at the mock model server.
fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> {
let config_toml = codex_home.join("config.toml");
std::fs::write(
config_toml,
format!(
r#"
model = "mock-model"
approval_policy = "never"
sandbox_mode = "danger-full-access"
model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "chat"
request_max_retries = 0
stream_max_retries = 0
"#
),
)
}