diff --git a/codex-rs/mcp-server/src/conversation_loop.rs b/codex-rs/mcp-server/src/conversation_loop.rs index 7af304c89e..3ff5359c19 100644 --- a/codex-rs/mcp-server/src/conversation_loop.rs +++ b/codex-rs/mcp-server/src/conversation_loop.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::collections::HashSet; use std::path::PathBuf; use std::sync::Arc; @@ -12,7 +11,7 @@ use codex_core::protocol::ExecApprovalRequestEvent; use codex_core::protocol::FileChange; use mcp_types::RequestId; use tokio::sync::Mutex; -use tokio::sync::watch::Receiver as WatchReceiver; +// no streaming watch channel; streaming is toggled via set_streaming on the struct use tracing::error; use uuid::Uuid; @@ -25,6 +24,19 @@ use crate::mcp_protocol::NotificationMeta; use crate::outgoing_message::OutgoingMessageSender; use crate::patch_approval::handle_patch_approval_request; +/// A single source of truth for an active conversation. +/// Owns the Codex session and all per-conversation state. +pub(crate) struct Conversation { + codex: Arc, + session_id: Uuid, + outgoing: Arc, + request_id: RequestId, + running: bool, + streaming_enabled: bool, + buffered_events: Vec, + pending_elicitations: Vec, +} + /// Deferred elicitation requests to be sent after InitialState when /// streaming is enabled. Preserves original event order (FIFO). enum PendingElicitation { @@ -43,15 +55,6 @@ enum PendingElicitation { }, } -/// Immutable context shared across helper functions to avoid long -/// argument lists. -struct LoopCtx { - outgoing: Arc, - codex: Arc, - request_id: RequestId, - request_id_str: String, -} - /// Snapshot of a patch approval request used to defer elicitation. struct PatchReq { call_id: String, @@ -61,384 +64,324 @@ struct PatchReq { event_id: String, } -/// Conversation event loop bridging Codex events to MCP notifications. -/// -/// Semantics: -/// - Always buffers all Codex events to include in an InitialState snapshot when -/// streaming turns on. -/// - Streams notifications live when `streaming_enabled` is true. -/// - Defers exec/patch approval elicitations until streaming turns on so -/// the client first receives InitialState, then the corresponding requests. -pub async fn run_conversation_loop( - codex: Arc, - outgoing: Arc, - request_id: RequestId, - mut stream_rx: WatchReceiver, - session_id: Uuid, - running_session_ids: Arc>>, -) { - let request_id_str = match &request_id { - RequestId::String(s) => s.clone(), - RequestId::Integer(n) => n.to_string(), - }; +impl Conversation { + pub(crate) fn new( + codex: Arc, + outgoing: Arc, + request_id: RequestId, + session_id: Uuid, + ) -> Arc> { + let conv = Arc::new(Mutex::new(Self { + codex, + session_id, + outgoing, + request_id, + running: false, + streaming_enabled: false, + buffered_events: Vec::new(), + pending_elicitations: Vec::new(), + })); + // Detach a background loop tied to this Conversation + Conversation::spawn_loop(conv.clone()); + conv + } - // Buffer all events to include in InitialState when streaming is enabled - // TODO: this should be expanded to load sessions from the disk. - let mut buffered_events: Vec = Vec::new(); - let mut streaming_enabled = *stream_rx.borrow(); - - let mut pending_elicitations: Vec = Vec::new(); - - let ctx = LoopCtx { - outgoing: outgoing.clone(), - codex: codex.clone(), - request_id: request_id.clone(), - request_id_str: request_id_str.clone(), - }; - - loop { - tokio::select! { - res = codex.next_event() => { - handle_next_event_arm( - res, - streaming_enabled, - &mut buffered_events, - &mut pending_elicitations, - &ctx, - &running_session_ids, - &session_id, - ).await; - }, - changed = stream_rx.changed() => { - handle_stream_rx_arm( - changed, - &mut stream_rx, - &mut streaming_enabled, - &session_id, - &buffered_events, - &mut pending_elicitations, - &ctx, - ).await; - } + pub(crate) async fn set_streaming(&mut self, enabled: bool) { + if enabled && !self.streaming_enabled { + self.streaming_enabled = true; + self.emit_initial_state().await; + self.drain_pending_elicitations().await; + } else if !enabled && self.streaming_enabled { + self.streaming_enabled = false; } } -} -/// Handles the `codex.next_event()` select arm. -async fn handle_next_event_arm( - res: Result, - streaming_enabled: bool, - buffered_events: &mut Vec, - pending_elicitations: &mut Vec, - ctx: &LoopCtx, - running_session_ids: &Arc>>, - session_id: &Uuid, -) where - E: std::fmt::Display, -{ - match res { - Ok(event) => { - buffered_events.push(CodexEventNotificationParams { - meta: None, - msg: event.msg.clone(), - }); - stream_event_if_enabled(streaming_enabled, ctx, &event.msg).await; + fn spawn_loop(this: Arc>) { + tokio::spawn(async move { + loop { + // We clone codex to avoid holding the lock while awaiting next_event + let codex = { this.lock().await.codex.clone() }; + let res = codex.next_event().await; + let mut guard = this.lock().await; + guard.handle_next_event(res).await; + } + }); + } - match event.msg { - EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent { - command, - cwd, - call_id, - reason: _, - }) => { - process_exec_request( - streaming_enabled, - pending_elicitations, + pub(crate) fn codex(&self) -> Arc { + self.codex.clone() + } + + pub(crate) async fn try_submit_user_input( + &mut self, + request_id: RequestId, + items: Vec, + ) -> Result<(), String> { + if self.running { + return Err("Session is already running".to_string()); + } + // Optimistically mark running to avoid races between quick successive submits + self.running = true; + let request_id_string = match &request_id { + RequestId::String(s) => s.clone(), + RequestId::Integer(i) => i.to_string(), + }; + let submit_res = self + .codex + .submit_with_id(codex_core::protocol::Submission { + id: request_id_string, + op: codex_core::protocol::Op::UserInput { items }, + }) + .await; + if let Err(e) = submit_res { + // Revert running on error + self.running = false; + return Err(format!("Failed to submit user input: {e}")); + } + Ok(()) + } + + async fn handle_next_event(&mut self, res: Result) + where + E: std::fmt::Display, + { + match res { + Ok(event) => { + self.buffered_events.push(CodexEventNotificationParams { + meta: None, + msg: event.msg.clone(), + }); + self.stream_event_if_enabled(&event.msg).await; + + match event.msg { + EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent { command, cwd, call_id, - event.id.clone(), - ctx, - ) - .await; - } - EventMsg::Error(_) => { - error!("Codex runtime error"); - handle_task_clear(running_session_ids, session_id).await; - } - EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { - call_id, - reason, - grant_root, - changes, - }) => { - process_patch_request( - streaming_enabled, - pending_elicitations, - PatchReq { + reason: _, + }) => { + self.process_exec_request(command, cwd, call_id, event.id.clone()) + .await; + } + EventMsg::Error(_) => { + error!("Codex runtime error"); + self.handle_task_clear().await; + } + EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { + call_id, + reason, + grant_root, + changes, + }) => { + self.process_patch_request(PatchReq { call_id, reason, grant_root, changes, event_id: event.id.clone(), + }) + .await; + } + EventMsg::TaskComplete(_) => { + self.handle_task_clear().await; + } + EventMsg::TaskStarted => { + self.handle_task_started().await; + } + EventMsg::SessionConfigured(_) => { + tracing::error!("unexpected SessionConfigured event"); + } + EventMsg::AgentMessageDelta(_) => {} + EventMsg::AgentReasoningDelta(_) => {} + EventMsg::AgentMessage(AgentMessageEvent { .. }) => {} + EventMsg::TokenCount(_) + | EventMsg::AgentReasoning(_) + | EventMsg::McpToolCallBegin(_) + | EventMsg::McpToolCallEnd(_) + | EventMsg::ExecCommandBegin(_) + | EventMsg::ExecCommandEnd(_) + | EventMsg::BackgroundEvent(_) + | EventMsg::ExecCommandOutputDelta(_) + | EventMsg::PatchApplyBegin(_) + | EventMsg::PatchApplyEnd(_) + | EventMsg::GetHistoryEntryResponse(_) + | EventMsg::PlanUpdate(_) + | EventMsg::TurnDiff(_) + | EventMsg::ShutdownComplete => { + self.handle_task_clear().await; + } + } + } + Err(e) => { + error!("Codex runtime error: {e}"); + self.handle_task_clear().await; + } + } + } + + // streaming toggling handled by set_streaming() + + async fn emit_initial_state(&self) { + let params = InitialStateNotificationParams { + meta: Some(NotificationMeta { + conversation_id: Some(ConversationId(self.session_id)), + request_id: None, + }), + initial_state: InitialStatePayload { + events: self.buffered_events.clone(), + }, + }; + if let Ok(params_val) = serde_json::to_value(¶ms) { + self.outgoing + .send_custom_notification("notifications/initial_state", params_val) + .await; + } else { + error!("Failed to serialize InitialState params"); + } + } + + async fn drain_pending_elicitations(&mut self) { + for item in self.pending_elicitations.drain(..) { + match item { + PendingElicitation::Exec { + command, + cwd, + event_id, + call_id, + } => { + handle_exec_approval_request( + command, + cwd, + self.outgoing.clone(), + self.codex.clone(), + self.request_id.clone(), + match &self.request_id { + RequestId::String(s) => s.clone(), + RequestId::Integer(n) => n.to_string(), }, - ctx, + event_id, + call_id, ) .await; } - EventMsg::TaskComplete(_) => { - handle_task_clear(running_session_ids, session_id).await; - } - EventMsg::TaskStarted => { - handle_task_started(running_session_ids, session_id).await; - } - EventMsg::SessionConfigured(_) => { - tracing::error!("unexpected SessionConfigured event"); - } - EventMsg::AgentMessageDelta(_) => {} - EventMsg::AgentReasoningDelta(_) => {} - EventMsg::AgentMessage(AgentMessageEvent { .. }) => {} - EventMsg::TokenCount(_) - | EventMsg::AgentReasoning(_) - | EventMsg::McpToolCallBegin(_) - | EventMsg::McpToolCallEnd(_) - | EventMsg::ExecCommandBegin(_) - | EventMsg::ExecCommandEnd(_) - | EventMsg::BackgroundEvent(_) - | EventMsg::ExecCommandOutputDelta(_) - | EventMsg::PatchApplyBegin(_) - | EventMsg::PatchApplyEnd(_) - | EventMsg::GetHistoryEntryResponse(_) - | EventMsg::PlanUpdate(_) - | EventMsg::ShutdownComplete => { - handle_task_clear(running_session_ids, session_id).await; - } - } - } - Err(e) => { - error!("Codex runtime error: {e}"); - handle_task_clear(running_session_ids, session_id).await; - } - } -} - -/// Handles the `stream_rx.changed()` select arm. -async fn handle_stream_rx_arm( - changed: Result<(), tokio::sync::watch::error::RecvError>, - stream_rx: &mut WatchReceiver, - streaming_enabled: &mut bool, - session_id: &Uuid, - buffered_events: &[CodexEventNotificationParams], - pending_elicitations: &mut Vec, - ctx: &LoopCtx, -) { - if changed.is_ok() { - let now = *stream_rx.borrow(); - handle_stream_change( - now, - streaming_enabled, - *session_id, - buffered_events, - pending_elicitations, - ctx, - ) - .await; - } else { - error!("stream_rx change error; streaming control channel closed"); - } -} - -/// Handles a streaming state change. -/// -/// When enabling streaming: -/// 1) emits InitialState with all buffered events -/// 2) drains and sends any deferred elicitations -async fn handle_stream_change( - now: bool, - streaming_enabled: &mut bool, - session_id: Uuid, - buffered_events: &[CodexEventNotificationParams], - pending: &mut Vec, - ctx: &LoopCtx, -) { - if now && !*streaming_enabled { - *streaming_enabled = true; - emit_initial_state(ctx, session_id, buffered_events).await; - drain_pending_elicitations(pending, ctx).await; - } else if !now && *streaming_enabled { - *streaming_enabled = false; - } -} - -/// Emits the InitialState snapshot to the client. -async fn emit_initial_state( - ctx: &LoopCtx, - session_id: Uuid, - buffered_events: &[CodexEventNotificationParams], -) { - let params = InitialStateNotificationParams { - meta: Some(NotificationMeta { - conversation_id: Some(ConversationId(session_id)), - request_id: None, - }), - initial_state: InitialStatePayload { - events: buffered_events.to_vec(), - }, - }; - if let Ok(params_val) = serde_json::to_value(¶ms) { - ctx.outgoing - .send_custom_notification("notifications/initial_state", params_val) - .await; - } else { - error!("Failed to serialize InitialState params"); - } -} - -/// Sends any deferred exec/patch elicitations in FIFO order. -async fn drain_pending_elicitations(pending: &mut Vec, ctx: &LoopCtx) { - for item in pending.drain(..) { - match item { - PendingElicitation::Exec { - command, - cwd, - event_id, - call_id, - } => { - handle_exec_approval_request( - command, - cwd, - ctx.outgoing.clone(), - ctx.codex.clone(), - ctx.request_id.clone(), - ctx.request_id_str.clone(), - event_id, - call_id, - ) - .await; - } - PendingElicitation::PatchReq { - call_id, - reason, - grant_root, - changes, - event_id, - } => { - handle_patch_approval_request( + PendingElicitation::PatchReq { call_id, reason, grant_root, changes, - ctx.outgoing.clone(), - ctx.codex.clone(), - ctx.request_id.clone(), - ctx.request_id_str.clone(), event_id, - ) - .await; + } => { + handle_patch_approval_request( + call_id, + reason, + grant_root, + changes, + self.outgoing.clone(), + self.codex.clone(), + self.request_id.clone(), + match &self.request_id { + RequestId::String(s) => s.clone(), + RequestId::Integer(n) => n.to_string(), + }, + event_id, + ) + .await; + } } } } -} -/// Handles an exec approval request. If streaming is disabled, defers the -/// elicitation until after InitialState; otherwise elicits immediately. -async fn process_exec_request( - streaming_enabled: bool, - pending: &mut Vec, - command: Vec, - cwd: PathBuf, - call_id: String, - event_id: String, - ctx: &LoopCtx, -) { - if streaming_enabled { - handle_exec_approval_request( - command, - cwd, - ctx.outgoing.clone(), - ctx.codex.clone(), - ctx.request_id.clone(), - ctx.request_id_str.clone(), - event_id, - call_id, - ) - .await; - } else { - pending.push(PendingElicitation::Exec { - command, - cwd, - event_id, - call_id, - }); - } -} - -/// Handles a patch approval request. If streaming is disabled, defers the -/// elicitation until after InitialState; otherwise elicits immediately. -async fn process_patch_request( - streaming_enabled: bool, - pending: &mut Vec, - req: PatchReq, - ctx: &LoopCtx, -) { - let PatchReq { - call_id, - reason, - grant_root, - changes, - event_id, - } = req; - if streaming_enabled { - handle_patch_approval_request( - call_id, - reason, - grant_root, - changes, - ctx.outgoing.clone(), - ctx.codex.clone(), - ctx.request_id.clone(), - ctx.request_id_str.clone(), - event_id, - ) - .await; - } else { - pending.push(PendingElicitation::PatchReq { - call_id, - reason, - grant_root, - changes, - event_id, - }); - } -} - -/// Streams a single Codex event as an MCP notification if streaming is enabled. -async fn stream_event_if_enabled(streaming_enabled: bool, ctx: &LoopCtx, msg: &EventMsg) { - if !streaming_enabled { - return; - } - let method = msg.to_string(); - let params = CodexEventNotificationParams { - meta: None, - msg: msg.clone(), - }; - if let Ok(params_val) = serde_json::to_value(¶ms) { - ctx.outgoing - .send_custom_notification(&method, params_val) + async fn process_exec_request( + &mut self, + command: Vec, + cwd: PathBuf, + call_id: String, + event_id: String, + ) { + if self.streaming_enabled { + handle_exec_approval_request( + command, + cwd, + self.outgoing.clone(), + self.codex.clone(), + self.request_id.clone(), + match &self.request_id { + RequestId::String(s) => s.clone(), + RequestId::Integer(n) => n.to_string(), + }, + event_id, + call_id, + ) .await; - } else { - error!("Failed to serialize event params"); + } else { + self.pending_elicitations.push(PendingElicitation::Exec { + command, + cwd, + event_id, + call_id, + }); + } + } + + async fn process_patch_request(&mut self, req: PatchReq) { + let PatchReq { + call_id, + reason, + grant_root, + changes, + event_id, + } = req; + if self.streaming_enabled { + handle_patch_approval_request( + call_id, + reason, + grant_root, + changes, + self.outgoing.clone(), + self.codex.clone(), + self.request_id.clone(), + match &self.request_id { + RequestId::String(s) => s.clone(), + RequestId::Integer(n) => n.to_string(), + }, + event_id, + ) + .await; + } else { + self.pending_elicitations + .push(PendingElicitation::PatchReq { + call_id, + reason, + grant_root, + changes, + event_id, + }); + } + } + + async fn stream_event_if_enabled(&self, msg: &EventMsg) { + if !self.streaming_enabled { + return; + } + let method = msg.to_string(); + let params = CodexEventNotificationParams { + meta: None, + msg: msg.clone(), + }; + if let Ok(params_val) = serde_json::to_value(¶ms) { + self.outgoing + .send_custom_notification(&method, params_val) + .await; + } else { + error!("Failed to serialize event params"); + } + } + + async fn handle_task_started(&mut self) { + self.running = true; + } + + async fn handle_task_clear(&mut self) { + self.running = false; } } - -/// Inserts the session id into the shared running set when a task starts. -async fn handle_task_started(running_session_ids: &Arc>>, session_id: &Uuid) { - let mut running_session_ids = running_session_ids.lock().await; - running_session_ids.insert(*session_id); -} - -/// Removes the session id from the shared running set for any terminal condition. -async fn handle_task_clear(running_session_ids: &Arc>>, session_id: &Uuid) { - let mut running_session_ids = running_session_ids.lock().await; - running_session_ids.remove(session_id); -} diff --git a/codex-rs/mcp-server/src/message_processor.rs b/codex-rs/mcp-server/src/message_processor.rs index 306ddb4245..dd47651341 100644 --- a/codex-rs/mcp-server/src/message_processor.rs +++ b/codex-rs/mcp-server/src/message_processor.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::collections::HashSet; use std::path::PathBuf; use std::sync::Arc; @@ -37,7 +36,6 @@ use mcp_types::ServerNotification; use mcp_types::TextContent; use serde_json::json; use tokio::sync::Mutex; -use tokio::sync::watch; use tokio::task; use uuid::Uuid; @@ -46,10 +44,8 @@ pub(crate) struct MessageProcessor { initialized: bool, codex_linux_sandbox_exe: Option, session_map: Arc>>>, + conversation_map: Arc>>>>, running_requests_id_to_codex_uuid: Arc>>, - running_session_ids: Arc>>, - /// Per-session streaming state signal (true when client connected via ConversationStream) - streaming_session_senders: Arc>>>, /// Track request IDs to the original ToolCallRequestParams for cancellation handling tool_request_map: Arc>>, } @@ -66,31 +62,22 @@ impl MessageProcessor { initialized: false, codex_linux_sandbox_exe, session_map: Arc::new(Mutex::new(HashMap::new())), + conversation_map: Arc::new(Mutex::new(HashMap::new())), running_requests_id_to_codex_uuid: Arc::new(Mutex::new(HashMap::new())), - running_session_ids: Arc::new(Mutex::new(HashSet::new())), - streaming_session_senders: Arc::new(Mutex::new(HashMap::new())), tool_request_map: Arc::new(Mutex::new(HashMap::new())), } } - pub(crate) fn session_map(&self) -> Arc>>> { - self.session_map.clone() + pub(crate) fn conversation_map( + &self, + ) -> Arc>>>> { + self.conversation_map.clone() } pub(crate) fn outgoing(&self) -> Arc { self.outgoing.clone() } - pub(crate) fn running_session_ids(&self) -> Arc>> { - self.running_session_ids.clone() - } - - pub(crate) fn streaming_session_senders( - &self, - ) -> Arc>>> { - self.streaming_session_senders.clone() - } - pub(crate) async fn process_request(&mut self, request: JSONRPCRequest) { // Hold on to the ID so we can respond. let request_id = request.id.clone(); @@ -644,9 +631,9 @@ impl MessageProcessor { let session_id = args.conversation_id.0; let codex_arc = { - let sessions_guard = self.session_map.lock().await; + let sessions_guard = self.conversation_map.lock().await; match sessions_guard.get(&session_id) { - Some(codex) => Arc::clone(codex), + Some(conv) => conv.lock().await.codex().clone(), None => { tracing::warn!( "Cancel send_message: session not found for session_id: {session_id}" diff --git a/codex-rs/mcp-server/src/tool_handlers/create_conversation.rs b/codex-rs/mcp-server/src/tool_handlers/create_conversation.rs index d7d35a8760..8cbbd20f07 100644 --- a/codex-rs/mcp-server/src/tool_handlers/create_conversation.rs +++ b/codex-rs/mcp-server/src/tool_handlers/create_conversation.rs @@ -1,19 +1,14 @@ -use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; -use codex_core::Codex; use codex_core::codex_wrapper::init_codex; use codex_core::config::Config as CodexConfig; use codex_core::config::ConfigOverrides; use codex_core::protocol::EventMsg; use codex_core::protocol::SessionConfiguredEvent; use mcp_types::RequestId; -use tokio::sync::Mutex; -use tokio::sync::watch; -use uuid::Uuid; -use crate::conversation_loop::run_conversation_loop; +use crate::conversation_loop::Conversation; use crate::json_to_toml::json_to_toml; use crate::mcp_protocol::ConversationCreateArgs; use crate::mcp_protocol::ConversationCreateResult; @@ -122,41 +117,17 @@ pub(crate) async fn handle_create_conversation( let session_id = codex_conversation.session_id; let codex_arc = Arc::new(codex_conversation.codex); - // Store session for future calls - insert_session( - session_id, - codex_arc.clone(), - message_processor.session_map(), - ) - .await; - - // Create per-session streaming control channel (initially disabled) - let (stream_tx, stream_rx) = watch::channel(false); - { - let senders = message_processor.streaming_session_senders(); - let mut guard = senders.lock().await; - guard.insert(session_id, stream_tx); - } - // Run the conversation loop in the background so this request can return immediately. + // Construct conversation and start its loop, store it, then reply with id and model let outgoing = message_processor.outgoing(); - let spawn_id = id.clone(); - let running_session_ids = message_processor.running_session_ids(); - tokio::spawn(async move { - run_conversation_loop( - codex_arc.clone(), - outgoing, - spawn_id, - stream_rx, - session_id, - running_session_ids, - ) - .await; - }); - - // Reply with the new conversation id and effective model + let conversation = Conversation::new(codex_arc.clone(), outgoing, id.clone(), session_id); + let conv_map = message_processor.conversation_map(); + { + let mut guard = conv_map.lock().await; + guard.insert(session_id, conversation); + } message_processor .send_response_with_optional_error( - id, + id.clone(), Some(ToolCallResponseResult::ConversationCreate( ConversationCreateResult::Ok { conversation_id: ConversationId(session_id), @@ -167,12 +138,3 @@ pub(crate) async fn handle_create_conversation( ) .await; } - -async fn insert_session( - session_id: Uuid, - codex: Arc, - session_map: Arc>>>, -) { - let mut guard = session_map.lock().await; - guard.insert(session_id, codex); -} diff --git a/codex-rs/mcp-server/src/tool_handlers/send_message.rs b/codex-rs/mcp-server/src/tool_handlers/send_message.rs index 0fc76d02e7..7d3c838e42 100644 --- a/codex-rs/mcp-server/src/tool_handlers/send_message.rs +++ b/codex-rs/mcp-server/src/tool_handlers/send_message.rs @@ -1,13 +1,11 @@ use std::collections::HashMap; use std::sync::Arc; -use codex_core::Codex; -use codex_core::protocol::Op; -use codex_core::protocol::Submission; use mcp_types::RequestId; use tokio::sync::Mutex; use uuid::Uuid; +use crate::conversation_loop::Conversation; use crate::mcp_protocol::ConversationSendMessageArgs; use crate::mcp_protocol::ConversationSendMessageResult; use crate::mcp_protocol::ToolCallResponseResult; @@ -41,7 +39,8 @@ pub(crate) async fn handle_send_message( } let session_id = conversation_id.0; - let Some(codex) = get_session(session_id, message_processor.session_map()).await else { + let Some(conversation) = get_session(session_id, message_processor.conversation_map()).await + else { message_processor .send_response_with_optional_error( id, @@ -56,47 +55,17 @@ pub(crate) async fn handle_send_message( return; }; - let running = { - let running_session_ids = message_processor.running_session_ids(); - let running_session_ids = running_session_ids.lock().await; - running_session_ids.contains(&session_id) + let res = { + let mut guard = conversation.lock().await; + guard.try_submit_user_input(id.clone(), items).await }; - if running { + if let Err(e) = res { message_processor .send_response_with_optional_error( id, Some(ToolCallResponseResult::ConversationSendMessage( - ConversationSendMessageResult::Error { - message: "Session is already running".to_string(), - }, - )), - Some(true), - ) - .await; - return; - } - - let request_id_string = match &id { - RequestId::String(s) => s.clone(), - RequestId::Integer(i) => i.to_string(), - }; - - let submit_res = codex - .submit_with_id(Submission { - id: request_id_string, - op: Op::UserInput { items }, - }) - .await; - - if let Err(e) = submit_res { - message_processor - .send_response_with_optional_error( - id, - Some(ToolCallResponseResult::ConversationSendMessage( - ConversationSendMessageResult::Error { - message: format!("Failed to submit user input: {e}"), - }, + ConversationSendMessageResult::Error { message: e }, )), Some(true), ) @@ -117,8 +86,8 @@ pub(crate) async fn handle_send_message( pub(crate) async fn get_session( session_id: Uuid, - session_map: Arc>>>, -) -> Option> { - let guard = session_map.lock().await; + conversation_map: Arc>>>>, +) -> Option>> { + let guard = conversation_map.lock().await; guard.get(&session_id).cloned() } diff --git a/codex-rs/mcp-server/src/tool_handlers/stream_conversation.rs b/codex-rs/mcp-server/src/tool_handlers/stream_conversation.rs index d890ad9f07..860a90ddd7 100644 --- a/codex-rs/mcp-server/src/tool_handlers/stream_conversation.rs +++ b/codex-rs/mcp-server/src/tool_handlers/stream_conversation.rs @@ -17,12 +17,10 @@ pub(crate) async fn handle_stream_conversation( let session_id = conversation_id.0; - // Ensure the session exists - let session_exists = get_session(session_id, message_processor.session_map()) - .await - .is_some(); + // Ensure the session exists and enable streaming + let conv = get_session(session_id, message_processor.conversation_map()).await; - if !session_exists { + if conv.is_none() { // Return an error with no result payload per MCP error pattern message_processor .send_response_with_optional_error(id, None, Some(true)) @@ -30,20 +28,8 @@ pub(crate) async fn handle_stream_conversation( return; } - // Toggle streaming to enabled via the per-session watch channel - let senders_map = message_processor.streaming_session_senders(); - let tx = { - let guard = senders_map.lock().await; - guard.get(&session_id).cloned() - }; - if let Some(tx) = tx { - let _ = tx.send(true); - } else { - // No channel found for the session; treat as error - message_processor - .send_response_with_optional_error(id, None, Some(true)) - .await; - return; + if let Some(conv) = conv { + conv.lock().await.set_streaming(true).await; } // Acknowledge the stream request @@ -64,12 +50,7 @@ pub(crate) async fn handle_cancel( args: &ConversationStreamArgs, ) { let session_id = args.conversation_id.0; - let sender_opt: Option> = { - let senders = message_processor.streaming_session_senders(); - let guard = senders.lock().await; - guard.get(&session_id).cloned() - }; - if let Some(tx) = sender_opt { - let _ = tx.send(false); + if let Some(conv) = get_session(session_id, message_processor.conversation_map()).await { + conv.lock().await.set_streaming(false).await; } } diff --git a/codex-rs/mcp-server/tests/send_message.rs b/codex-rs/mcp-server/tests/send_message.rs index 0115bf7b2a..c309a02062 100644 --- a/codex-rs/mcp-server/tests/send_message.rs +++ b/codex-rs/mcp-server/tests/send_message.rs @@ -3,7 +3,6 @@ use std::thread::sleep; use std::time::Duration; -use codex_mcp_server::CodexToolCallParam; use mcp_test_support::McpProcess; use mcp_test_support::create_config_toml; use mcp_test_support::create_final_assistant_message_sse_response; @@ -20,11 +19,9 @@ const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_send_message_success() { - // Spin up a mock completions server that immediately ends the Codex turn. - // Two Codex turns hit the mock model (session start + send-user-message). Provide two SSE responses. + // Spin up a mock completions server that ends the Codex turn for the send-user-message call. let responses = vec![ create_final_assistant_message_sse_response("Done").expect("build mock assistant message"), - create_final_assistant_message_sse_response("Done").expect("build mock assistant message"), ]; let server = create_mock_chat_completions_server(responses).await; @@ -41,29 +38,11 @@ async fn test_send_message_success() { .expect("init timed out") .expect("init failed"); - // Kick off a Codex session so we have a valid session_id. - let codex_request_id = mcp_process - .send_codex_tool_call(CodexToolCallParam { - prompt: "Start a session".to_string(), - ..Default::default() - }) - .await - .expect("send codex tool call"); - - // Wait for the session_configured event to get the session_id. + // Create a conversation using the tool and get its conversation_id let session_id = mcp_process - .read_stream_until_configured_response_message() + .create_conversation_and_get_id("", "mock-model", "/repo") .await - .expect("read session_configured"); - - // The original codex call will finish quickly given our mock; consume its response. - timeout( - DEFAULT_READ_TIMEOUT, - mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)), - ) - .await - .expect("codex response timeout") - .expect("codex response error"); + .expect("create conversation"); // Now exercise the send-user-message tool. let send_msg_request_id = mcp_process