mirror of
https://github.com/openai/codex.git
synced 2026-05-02 12:21:26 +03:00
app-server: Replay pending item requests on thread/resume (#12560)
Replay pending client requests after `thread/resume` and emit resolved notifications when those requests clear so approval/input UI state stays in sync after reconnects and across subscribed clients. Affected RPCs: - `item/commandExecution/requestApproval` - `item/fileChange/requestApproval` - `item/tool/requestUserInput` Motivation: - Resumed clients need to see pending approval/input requests that were already outstanding before the reconnect. - Clients also need an explicit signal when a pending request resolves or is cleared so stale UI can be removed on turn start, completion, or interruption. Implementation notes: - Use pending client requests from `OutgoingMessageSender` in order to replay them after `thread/resume` attaches the connection, using original request ids. - Emit `serverRequest/resolved` when pending requests are answered or cleared by lifecycle cleanup. - Update the app-server protocol schema, generated TypeScript bindings, and README docs for the replay/resolution flow. High-level test plan: - Added automated coverage for replaying pending command execution and file change approval requests on `thread/resume`. - Added automated coverage for resolved notifications in command approval, file change approval, request_user_input, turn start, and turn interrupt flows. - Verified schema/docs updates in the relevant protocol and app-server tests. Manual testing: - Tested reconnect/resume with multiple connections. - Confirmed state stayed in sync between connections.
This commit is contained in:
committed by
GitHub
parent
66b0adb34c
commit
69d7a456bb
@@ -17,6 +17,7 @@ use tokio::sync::oneshot;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::error_code::INTERNAL_ERROR_CODE;
|
||||
use crate::server_request_error::TURN_TRANSITION_PENDING_REQUEST_ERROR_REASON;
|
||||
|
||||
#[cfg(test)]
|
||||
use codex_protocol::account::PlanType;
|
||||
@@ -62,6 +63,7 @@ pub(crate) struct ThreadScopedOutgoingMessageSender {
|
||||
struct PendingCallbackEntry {
|
||||
callback: oneshot::Sender<ClientRequestResult>,
|
||||
thread_id: Option<ThreadId>,
|
||||
request: ServerRequest,
|
||||
}
|
||||
|
||||
impl ThreadScopedOutgoingMessageSender {
|
||||
@@ -80,12 +82,12 @@ impl ThreadScopedOutgoingMessageSender {
|
||||
pub(crate) async fn send_request(
|
||||
&self,
|
||||
payload: ServerRequestPayload,
|
||||
) -> oneshot::Receiver<ClientRequestResult> {
|
||||
) -> (RequestId, oneshot::Receiver<ClientRequestResult>) {
|
||||
self.outgoing
|
||||
.send_request_to_thread_connections(
|
||||
self.thread_id,
|
||||
self.connection_ids.as_slice(),
|
||||
.send_request_to_connections(
|
||||
Some(self.connection_ids.as_slice()),
|
||||
payload,
|
||||
Some(self.thread_id),
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -99,6 +101,20 @@ impl ThreadScopedOutgoingMessageSender {
|
||||
.await;
|
||||
}
|
||||
|
||||
pub(crate) async fn abort_pending_server_requests(&self) {
|
||||
self.outgoing
|
||||
.cancel_requests_for_thread(
|
||||
self.thread_id,
|
||||
Some(JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: "client request resolved because the turn state was changed"
|
||||
.to_string(),
|
||||
data: Some(serde_json::json!({ "reason": TURN_TRANSITION_PENDING_REQUEST_ERROR_REASON })),
|
||||
}),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn send_response<T: Serialize>(
|
||||
&self,
|
||||
request_id: ConnectionRequestId,
|
||||
@@ -129,38 +145,23 @@ impl OutgoingMessageSender {
|
||||
&self,
|
||||
request: ServerRequestPayload,
|
||||
) -> (RequestId, oneshot::Receiver<ClientRequestResult>) {
|
||||
self.send_request_with_id_to_connections(&[], request, None)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn send_request_to_thread_connections(
|
||||
&self,
|
||||
thread_id: ThreadId,
|
||||
connection_ids: &[ConnectionId],
|
||||
request: ServerRequestPayload,
|
||||
) -> oneshot::Receiver<ClientRequestResult> {
|
||||
if connection_ids.is_empty() {
|
||||
let (_tx, rx) = oneshot::channel();
|
||||
return rx;
|
||||
}
|
||||
let (_request_id, receiver) = self
|
||||
.send_request_with_id_to_connections(connection_ids, request, Some(thread_id))
|
||||
.await;
|
||||
receiver
|
||||
self.send_request_to_connections(None, request, None).await
|
||||
}
|
||||
|
||||
fn next_request_id(&self) -> RequestId {
|
||||
RequestId::Integer(self.next_server_request_id.fetch_add(1, Ordering::Relaxed))
|
||||
}
|
||||
|
||||
async fn send_request_with_id_to_connections(
|
||||
async fn send_request_to_connections(
|
||||
&self,
|
||||
connection_ids: &[ConnectionId],
|
||||
connection_ids: Option<&[ConnectionId]>,
|
||||
request: ServerRequestPayload,
|
||||
thread_id: Option<ThreadId>,
|
||||
) -> (RequestId, oneshot::Receiver<ClientRequestResult>) {
|
||||
let id = self.next_request_id();
|
||||
let outgoing_message_id = id.clone();
|
||||
let request = request.request_with_id(outgoing_message_id.clone());
|
||||
|
||||
let (tx_approve, rx_approve) = oneshot::channel();
|
||||
{
|
||||
let mut request_id_to_callback = self.request_id_to_callback.lock().await;
|
||||
@@ -169,36 +170,39 @@ impl OutgoingMessageSender {
|
||||
PendingCallbackEntry {
|
||||
callback: tx_approve,
|
||||
thread_id,
|
||||
request: request.clone(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let outgoing_message =
|
||||
OutgoingMessage::Request(request.request_with_id(outgoing_message_id.clone()));
|
||||
let send_result = if connection_ids.is_empty() {
|
||||
self.sender
|
||||
.send(OutgoingEnvelope::Broadcast {
|
||||
message: outgoing_message,
|
||||
})
|
||||
.await
|
||||
} else {
|
||||
let mut send_error = None;
|
||||
for connection_id in connection_ids {
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingEnvelope::ToConnection {
|
||||
connection_id: *connection_id,
|
||||
message: outgoing_message.clone(),
|
||||
let outgoing_message = OutgoingMessage::Request(request);
|
||||
let send_result = match connection_ids {
|
||||
None => {
|
||||
self.sender
|
||||
.send(OutgoingEnvelope::Broadcast {
|
||||
message: outgoing_message,
|
||||
})
|
||||
.await
|
||||
{
|
||||
send_error = Some(err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
match send_error {
|
||||
Some(err) => Err(err),
|
||||
None => Ok(()),
|
||||
Some(connection_ids) => {
|
||||
let mut send_error = None;
|
||||
for connection_id in connection_ids {
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingEnvelope::ToConnection {
|
||||
connection_id: *connection_id,
|
||||
message: outgoing_message.clone(),
|
||||
})
|
||||
.await
|
||||
{
|
||||
send_error = Some(err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
match send_error {
|
||||
Some(err) => Err(err),
|
||||
None => Ok(()),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -210,11 +214,28 @@ impl OutgoingMessageSender {
|
||||
(outgoing_message_id, rx_approve)
|
||||
}
|
||||
|
||||
pub(crate) async fn replay_requests_to_connection_for_thread(
|
||||
&self,
|
||||
connection_id: ConnectionId,
|
||||
thread_id: ThreadId,
|
||||
) {
|
||||
let requests = self.pending_requests_for_thread(thread_id).await;
|
||||
for request in requests {
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingEnvelope::ToConnection {
|
||||
connection_id,
|
||||
message: OutgoingMessage::Request(request),
|
||||
})
|
||||
.await
|
||||
{
|
||||
warn!("failed to resend request to client: {err:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn notify_client_response(&self, id: RequestId, result: Result) {
|
||||
let entry = {
|
||||
let mut request_id_to_callback = self.request_id_to_callback.lock().await;
|
||||
request_id_to_callback.remove_entry(&id)
|
||||
};
|
||||
let entry = self.take_request_callback(&id).await;
|
||||
|
||||
match entry {
|
||||
Some((id, entry)) => {
|
||||
@@ -229,10 +250,7 @@ impl OutgoingMessageSender {
|
||||
}
|
||||
|
||||
pub(crate) async fn notify_client_error(&self, id: RequestId, error: JSONRPCErrorError) {
|
||||
let entry = {
|
||||
let mut request_id_to_callback = self.request_id_to_callback.lock().await;
|
||||
request_id_to_callback.remove_entry(&id)
|
||||
};
|
||||
let entry = self.take_request_callback(&id).await;
|
||||
|
||||
match entry {
|
||||
Some((id, entry)) => {
|
||||
@@ -248,23 +266,62 @@ impl OutgoingMessageSender {
|
||||
}
|
||||
|
||||
pub(crate) async fn cancel_request(&self, id: &RequestId) -> bool {
|
||||
let entry = {
|
||||
let mut request_id_to_callback = self.request_id_to_callback.lock().await;
|
||||
request_id_to_callback.remove_entry(id)
|
||||
};
|
||||
entry.is_some()
|
||||
self.take_request_callback(id).await.is_some()
|
||||
}
|
||||
|
||||
pub(crate) async fn cancel_requests_for_thread(&self, thread_id: ThreadId) {
|
||||
async fn take_request_callback(
|
||||
&self,
|
||||
id: &RequestId,
|
||||
) -> Option<(RequestId, PendingCallbackEntry)> {
|
||||
let mut request_id_to_callback = self.request_id_to_callback.lock().await;
|
||||
let request_ids = request_id_to_callback
|
||||
request_id_to_callback.remove_entry(id)
|
||||
}
|
||||
|
||||
pub(crate) async fn pending_requests_for_thread(
|
||||
&self,
|
||||
thread_id: ThreadId,
|
||||
) -> Vec<ServerRequest> {
|
||||
let request_id_to_callback = self.request_id_to_callback.lock().await;
|
||||
let mut requests = request_id_to_callback
|
||||
.iter()
|
||||
.filter_map(|(request_id, entry)| {
|
||||
(entry.thread_id == Some(thread_id)).then_some(request_id.clone())
|
||||
.filter_map(|(_, entry)| {
|
||||
(entry.thread_id == Some(thread_id)).then_some(entry.request.clone())
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
for request_id in request_ids {
|
||||
request_id_to_callback.remove(&request_id);
|
||||
requests.sort_by(|left, right| left.id().cmp(right.id()));
|
||||
requests
|
||||
}
|
||||
|
||||
pub(crate) async fn cancel_requests_for_thread(
|
||||
&self,
|
||||
thread_id: ThreadId,
|
||||
error: Option<JSONRPCErrorError>,
|
||||
) {
|
||||
let entries = {
|
||||
let mut request_id_to_callback = self.request_id_to_callback.lock().await;
|
||||
let request_ids = request_id_to_callback
|
||||
.iter()
|
||||
.filter_map(|(request_id, entry)| {
|
||||
(entry.thread_id == Some(thread_id)).then_some(request_id.clone())
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut entries = Vec::with_capacity(request_ids.len());
|
||||
for request_id in request_ids {
|
||||
if let Some(entry) = request_id_to_callback.remove(&request_id) {
|
||||
entries.push(entry);
|
||||
}
|
||||
}
|
||||
entries
|
||||
};
|
||||
|
||||
if let Some(error) = error {
|
||||
for entry in entries {
|
||||
if let Err(err) = entry.callback.send(Err(error.clone())) {
|
||||
let request_id = entry.request.id();
|
||||
warn!("could not notify callback for {request_id:?} due to: {err:?}",);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -441,14 +498,18 @@ mod tests {
|
||||
use codex_app_server_protocol::ApplyPatchApprovalParams;
|
||||
use codex_app_server_protocol::AuthMode;
|
||||
use codex_app_server_protocol::ConfigWarningNotification;
|
||||
use codex_app_server_protocol::DynamicToolCallParams;
|
||||
use codex_app_server_protocol::FileChangeRequestApprovalParams;
|
||||
use codex_app_server_protocol::LoginChatGptCompleteNotification;
|
||||
use codex_app_server_protocol::ModelRerouteReason;
|
||||
use codex_app_server_protocol::ModelReroutedNotification;
|
||||
use codex_app_server_protocol::RateLimitSnapshot;
|
||||
use codex_app_server_protocol::RateLimitWindow;
|
||||
use codex_app_server_protocol::ToolRequestUserInputParams;
|
||||
use codex_protocol::ThreadId;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
use tokio::time::timeout;
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -723,4 +784,121 @@ mod tests {
|
||||
.expect("waiter should receive a callback");
|
||||
assert_eq!(result, Err(error));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_requests_for_thread_returns_thread_requests_in_request_id_order() {
|
||||
let (tx, _rx) = mpsc::channel::<OutgoingEnvelope>(8);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
|
||||
let thread_id = ThreadId::new();
|
||||
let thread_outgoing = ThreadScopedOutgoingMessageSender::new(
|
||||
outgoing.clone(),
|
||||
vec![ConnectionId(1)],
|
||||
thread_id,
|
||||
);
|
||||
|
||||
let (dynamic_tool_request_id, _dynamic_tool_waiter) = thread_outgoing
|
||||
.send_request(ServerRequestPayload::DynamicToolCall(
|
||||
DynamicToolCallParams {
|
||||
thread_id: thread_id.to_string(),
|
||||
turn_id: "turn-1".to_string(),
|
||||
call_id: "call-0".to_string(),
|
||||
tool: "tool".to_string(),
|
||||
arguments: json!({}),
|
||||
},
|
||||
))
|
||||
.await;
|
||||
let (first_request_id, _first_waiter) = thread_outgoing
|
||||
.send_request(ServerRequestPayload::ToolRequestUserInput(
|
||||
ToolRequestUserInputParams {
|
||||
thread_id: thread_id.to_string(),
|
||||
turn_id: "turn-1".to_string(),
|
||||
item_id: "call-1".to_string(),
|
||||
questions: vec![],
|
||||
},
|
||||
))
|
||||
.await;
|
||||
let (second_request_id, _second_waiter) = thread_outgoing
|
||||
.send_request(ServerRequestPayload::FileChangeRequestApproval(
|
||||
FileChangeRequestApprovalParams {
|
||||
thread_id: thread_id.to_string(),
|
||||
turn_id: "turn-1".to_string(),
|
||||
item_id: "call-2".to_string(),
|
||||
reason: None,
|
||||
grant_root: None,
|
||||
},
|
||||
))
|
||||
.await;
|
||||
let pending_requests = outgoing.pending_requests_for_thread(thread_id).await;
|
||||
assert_eq!(
|
||||
pending_requests
|
||||
.iter()
|
||||
.map(ServerRequest::id)
|
||||
.collect::<Vec<_>>(),
|
||||
vec![
|
||||
&dynamic_tool_request_id,
|
||||
&first_request_id,
|
||||
&second_request_id
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cancel_requests_for_thread_cancels_all_thread_requests() {
|
||||
let (tx, _rx) = mpsc::channel::<OutgoingEnvelope>(8);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(tx));
|
||||
let thread_id = ThreadId::new();
|
||||
let thread_outgoing = ThreadScopedOutgoingMessageSender::new(
|
||||
outgoing.clone(),
|
||||
vec![ConnectionId(1)],
|
||||
thread_id,
|
||||
);
|
||||
|
||||
let (_dynamic_tool_request_id, dynamic_tool_waiter) = thread_outgoing
|
||||
.send_request(ServerRequestPayload::DynamicToolCall(
|
||||
DynamicToolCallParams {
|
||||
thread_id: thread_id.to_string(),
|
||||
turn_id: "turn-1".to_string(),
|
||||
call_id: "call-0".to_string(),
|
||||
tool: "tool".to_string(),
|
||||
arguments: json!({}),
|
||||
},
|
||||
))
|
||||
.await;
|
||||
let (_request_id, user_input_waiter) = thread_outgoing
|
||||
.send_request(ServerRequestPayload::ToolRequestUserInput(
|
||||
ToolRequestUserInputParams {
|
||||
thread_id: thread_id.to_string(),
|
||||
turn_id: "turn-1".to_string(),
|
||||
item_id: "call-1".to_string(),
|
||||
questions: vec![],
|
||||
},
|
||||
))
|
||||
.await;
|
||||
let error = JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: "tracked request cancelled".to_string(),
|
||||
data: None,
|
||||
};
|
||||
|
||||
outgoing
|
||||
.cancel_requests_for_thread(thread_id, Some(error.clone()))
|
||||
.await;
|
||||
|
||||
let dynamic_tool_result = timeout(Duration::from_secs(1), dynamic_tool_waiter)
|
||||
.await
|
||||
.expect("dynamic tool waiter should resolve")
|
||||
.expect("dynamic tool waiter should receive a callback");
|
||||
let user_input_result = timeout(Duration::from_secs(1), user_input_waiter)
|
||||
.await
|
||||
.expect("user input waiter should resolve")
|
||||
.expect("user input waiter should receive a callback");
|
||||
assert_eq!(dynamic_tool_result, Err(error.clone()));
|
||||
assert_eq!(user_input_result, Err(error));
|
||||
assert!(
|
||||
outgoing
|
||||
.pending_requests_for_thread(thread_id)
|
||||
.await
|
||||
.is_empty()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user