stream init

This commit is contained in:
Ahmed Ibrahim
2025-08-01 17:25:50 -07:00
parent 97ab8fb610
commit 3a456c1fbb
8 changed files with 585 additions and 39 deletions

View File

@@ -17,6 +17,7 @@ use codex_mcp_server::CodexToolCallReplyParam;
use codex_mcp_server::mcp_protocol::ConversationCreateArgs;
use codex_mcp_server::mcp_protocol::ConversationId;
use codex_mcp_server::mcp_protocol::ConversationSendMessageArgs;
use codex_mcp_server::mcp_protocol::ConversationStreamArgs;
use codex_mcp_server::mcp_protocol::ToolCallRequestParams;
use mcp_types::CallToolRequestParams;
@@ -201,6 +202,20 @@ impl McpProcess {
.await
}
pub async fn send_conversation_stream_tool_call(
&mut self,
session_id: &str,
) -> anyhow::Result<i64> {
let params = ToolCallRequestParams::ConversationStream(ConversationStreamArgs {
conversation_id: ConversationId(Uuid::parse_str(session_id)?),
});
self.send_request(
mcp_types::CallToolRequest::METHOD,
Some(serde_json::to_value(params)?),
)
.await
}
pub async fn send_conversation_create_tool_call(
&mut self,
prompt: &str,
@@ -236,6 +251,99 @@ impl McpProcess {
.await
}
/// Create a conversation and return its conversation_id as a string.
pub async fn create_conversation_and_get_id(
&mut self,
prompt: &str,
model: &str,
cwd: &str,
) -> anyhow::Result<String> {
let req_id = self
.send_conversation_create_tool_call(prompt, model, cwd)
.await?;
let resp = self
.read_stream_until_response_message(RequestId::Integer(req_id))
.await?;
let conv_id = resp.result["structuredContent"]["conversation_id"]
.as_str()
.ok_or_else(|| anyhow::format_err!("missing conversation_id"))?
.to_string();
Ok(conv_id)
}
/// Connect stream for a conversation and wait for the initial_state notification.
/// Returns the params of the initial_state notification for further inspection.
pub async fn connect_stream_and_expect_initial_state(
&mut self,
session_id: &str,
) -> anyhow::Result<serde_json::Value> {
let req_id = self.send_conversation_stream_tool_call(session_id).await?;
// Wait for stream() tool-call response first
let _ = self
.read_stream_until_response_message(RequestId::Integer(req_id))
.await?;
// Then the initial_state notification
let note = self
.read_stream_until_notification_method("notifications/initial_state")
.await?;
note.params
.ok_or_else(|| anyhow::format_err!("initial_state must have params"))
}
/// Connect stream and also return the request id for later cancellation.
pub async fn connect_stream_get_req_and_initial_state(
&mut self,
session_id: &str,
) -> anyhow::Result<(i64, serde_json::Value)> {
let req_id = self.send_conversation_stream_tool_call(session_id).await?;
let _ = self
.read_stream_until_response_message(RequestId::Integer(req_id))
.await?;
let note = self
.read_stream_until_notification_method("notifications/initial_state")
.await?;
let params = note
.params
.ok_or_else(|| anyhow::format_err!("initial_state must have params"))?;
Ok((req_id, params))
}
/// Wait for an agent_message with a bounded timeout. Returns Some(params) if received, None on timeout.
pub async fn maybe_wait_for_agent_message(
&mut self,
dur: std::time::Duration,
) -> anyhow::Result<Option<serde_json::Value>> {
match tokio::time::timeout(dur, self.wait_for_agent_message()).await {
Ok(Ok(v)) => Ok(Some(v)),
Ok(Err(e)) => Err(e),
Err(_elapsed) => Ok(None),
}
}
/// Send a user message to a conversation and wait for the OK tool-call response.
pub async fn send_user_message_and_wait_ok(
&mut self,
message: &str,
session_id: &str,
) -> anyhow::Result<()> {
let req_id = self
.send_user_message_tool_call(message, session_id)
.await?;
let _ = self
.read_stream_until_response_message(RequestId::Integer(req_id))
.await?;
Ok(())
}
/// Wait until an agent_message notification arrives; returns its params.
pub async fn wait_for_agent_message(&mut self) -> anyhow::Result<serde_json::Value> {
let note = self
.read_stream_until_notification_method("agent_message")
.await?;
note.params
.ok_or_else(|| anyhow::format_err!("agent_message missing params"))
}
async fn send_request(
&mut self,
method: &str,
@@ -329,6 +437,31 @@ impl McpProcess {
}
}
pub async fn read_stream_until_notification_method(
&mut self,
method: &str,
) -> anyhow::Result<JSONRPCNotification> {
loop {
let message = self.read_jsonrpc_message().await?;
match message {
JSONRPCMessage::Notification(n) => {
if n.method == method {
return Ok(n);
}
}
JSONRPCMessage::Request(_) => {
// ignore
}
JSONRPCMessage::Error(_) => {
anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}");
}
JSONRPCMessage::Response(_) => {
// ignore
}
}
}
}
pub async fn read_stream_until_configured_response_message(
&mut self,
) -> anyhow::Result<String> {