fix(app-server): make thread/start non-blocking (#13033)

Stop `thread/start` from blocking other app-server requests.

Before this change, `thread/start ran` inline on the request loop, so
slow startup paths like MCP auth checks could hold up unrelated requests
on the same connection, including `thread/loaded/list`. This moves
`thread/start` into a background task.

While doing so, it revealed an issue where we were doing nested locking
(and there were some race conditions possible that could introduce a
"phantom listener"). This PR also refactors the listener/subscription
bookkeeping - listener/subscription state is now centralized in
`ThreadStateManager` instead of being split across multiple lock
domains. That makes late auto-attach on `thread/start` race-safe and
avoids reintroducing disconnected clients as phantom subscribers.
This commit is contained in:
Owen Lin
2026-02-27 17:40:08 -08:00
committed by GitHub
parent 6604608bad
commit 8fa792868c
3 changed files with 592 additions and 268 deletions

View File

@@ -60,7 +60,6 @@ pub(crate) struct ThreadState {
listener_command_tx: Option<mpsc::UnboundedSender<ThreadListenerCommand>>,
current_turn_history: ThreadHistoryBuilder,
listener_thread: Option<Weak<CodexThread>>,
subscribed_connections: HashSet<ConnectionId>,
}
impl ThreadState {
@@ -95,18 +94,6 @@ impl ThreadState {
self.listener_thread = None;
}
pub(crate) fn add_connection(&mut self, connection_id: ConnectionId) {
self.subscribed_connections.insert(connection_id);
}
pub(crate) fn remove_connection(&mut self, connection_id: ConnectionId) {
self.subscribed_connections.remove(&connection_id);
}
pub(crate) fn subscribed_connection_ids(&self) -> Vec<ConnectionId> {
self.subscribed_connections.iter().copied().collect()
}
pub(crate) fn set_experimental_raw_events(&mut self, enabled: bool) {
self.experimental_raw_events = enabled;
}
@@ -135,55 +122,112 @@ struct SubscriptionState {
connection_id: ConnectionId,
}
struct ThreadEntry {
state: Arc<Mutex<ThreadState>>,
connection_ids: HashSet<ConnectionId>,
}
impl Default for ThreadEntry {
fn default() -> Self {
Self {
state: Arc::new(Mutex::new(ThreadState::default())),
connection_ids: HashSet::new(),
}
}
}
#[derive(Default)]
pub(crate) struct ThreadStateManager {
thread_states: HashMap<ThreadId, Arc<Mutex<ThreadState>>>,
struct ThreadStateManagerInner {
live_connections: HashSet<ConnectionId>,
threads: HashMap<ThreadId, ThreadEntry>,
subscription_state_by_id: HashMap<Uuid, SubscriptionState>,
thread_ids_by_connection: HashMap<ConnectionId, HashSet<ThreadId>>,
}
#[derive(Clone, Default)]
pub(crate) struct ThreadStateManager {
state: Arc<Mutex<ThreadStateManagerInner>>,
}
impl ThreadStateManager {
pub(crate) fn new() -> Self {
Self::default()
}
pub(crate) fn thread_state(&mut self, thread_id: ThreadId) -> Arc<Mutex<ThreadState>> {
self.thread_states
.entry(thread_id)
.or_insert_with(|| Arc::new(Mutex::new(ThreadState::default())))
.clone()
pub(crate) async fn connection_initialized(&self, connection_id: ConnectionId) {
self.state
.lock()
.await
.live_connections
.insert(connection_id);
}
pub(crate) async fn remove_listener(&mut self, subscription_id: Uuid) -> Option<ThreadId> {
let subscription_state = self.subscription_state_by_id.remove(&subscription_id)?;
pub(crate) async fn subscribed_connection_ids(&self, thread_id: ThreadId) -> Vec<ConnectionId> {
let state = self.state.lock().await;
state
.threads
.get(&thread_id)
.map(|thread_entry| thread_entry.connection_ids.iter().copied().collect())
.unwrap_or_default()
}
pub(crate) async fn thread_state(&self, thread_id: ThreadId) -> Arc<Mutex<ThreadState>> {
let mut state = self.state.lock().await;
state.threads.entry(thread_id).or_default().state.clone()
}
pub(crate) async fn remove_listener(&self, subscription_id: Uuid) -> Option<ThreadId> {
let (subscription_state, connection_still_subscribed_to_thread, thread_state) = {
let mut state = self.state.lock().await;
let subscription_state = state.subscription_state_by_id.remove(&subscription_id)?;
let thread_id = subscription_state.thread_id;
let connection_still_subscribed_to_thread = state
.subscription_state_by_id
.values()
.any(|subscription_state_entry| {
subscription_state_entry.thread_id == thread_id
&& subscription_state_entry.connection_id
== subscription_state.connection_id
});
if !connection_still_subscribed_to_thread {
let mut remove_connection_entry = false;
if let Some(thread_ids) = state
.thread_ids_by_connection
.get_mut(&subscription_state.connection_id)
{
thread_ids.remove(&thread_id);
remove_connection_entry = thread_ids.is_empty();
}
if remove_connection_entry {
state
.thread_ids_by_connection
.remove(&subscription_state.connection_id);
}
if let Some(thread_entry) = state.threads.get_mut(&thread_id) {
thread_entry
.connection_ids
.remove(&subscription_state.connection_id);
}
}
let thread_state = state.threads.get(&thread_id).map(|thread_entry| {
(
thread_entry.connection_ids.is_empty(),
thread_entry.state.clone(),
)
});
(
subscription_state,
connection_still_subscribed_to_thread,
thread_state,
)
};
let thread_id = subscription_state.thread_id;
let connection_still_subscribed_to_thread =
self.subscription_state_by_id.values().any(|state| {
state.thread_id == thread_id
&& state.connection_id == subscription_state.connection_id
});
if !connection_still_subscribed_to_thread {
let mut remove_connection_entry = false;
if let Some(thread_ids) = self
.thread_ids_by_connection
.get_mut(&subscription_state.connection_id)
{
thread_ids.remove(&thread_id);
remove_connection_entry = thread_ids.is_empty();
}
if remove_connection_entry {
self.thread_ids_by_connection
.remove(&subscription_state.connection_id);
}
}
if let Some(thread_state) = self.thread_states.get(&thread_id) {
let mut thread_state = thread_state.lock().await;
if !connection_still_subscribed_to_thread {
thread_state.remove_connection(subscription_state.connection_id);
}
if thread_state.subscribed_connection_ids().is_empty() {
if let Some((no_subscribers, thread_state)) = thread_state {
let thread_state = thread_state.lock().await;
if !connection_still_subscribed_to_thread && no_subscribers {
tracing::debug!(
thread_id = %thread_id,
subscription_id = %subscription_id,
@@ -196,8 +240,24 @@ impl ThreadStateManager {
Some(thread_id)
}
pub(crate) async fn remove_thread_state(&mut self, thread_id: ThreadId) {
if let Some(thread_state) = self.thread_states.remove(&thread_id) {
pub(crate) async fn remove_thread_state(&self, thread_id: ThreadId) {
let thread_state = {
let mut state = self.state.lock().await;
let thread_state = state
.threads
.remove(&thread_id)
.map(|thread_entry| thread_entry.state);
state
.subscription_state_by_id
.retain(|_, state| state.thread_id != thread_id);
state.thread_ids_by_connection.retain(|_, thread_ids| {
thread_ids.remove(&thread_id);
!thread_ids.is_empty()
});
thread_state
};
if let Some(thread_state) = thread_state {
let mut thread_state = thread_state.lock().await;
tracing::debug!(
thread_id = %thread_id,
@@ -208,142 +268,189 @@ impl ThreadStateManager {
);
thread_state.clear_listener();
}
self.subscription_state_by_id
.retain(|_, state| state.thread_id != thread_id);
self.thread_ids_by_connection.retain(|_, thread_ids| {
thread_ids.remove(&thread_id);
!thread_ids.is_empty()
});
}
pub(crate) async fn unsubscribe_connection_from_thread(
&mut self,
&self,
thread_id: ThreadId,
connection_id: ConnectionId,
) -> bool {
let Some(thread_state) = self.thread_states.get(&thread_id) else {
return false;
{
let mut state = self.state.lock().await;
if !state.threads.contains_key(&thread_id) {
return false;
}
if !state
.thread_ids_by_connection
.get(&connection_id)
.is_some_and(|thread_ids| thread_ids.contains(&thread_id))
{
return false;
}
if let Some(thread_ids) = state.thread_ids_by_connection.get_mut(&connection_id) {
thread_ids.remove(&thread_id);
if thread_ids.is_empty() {
state.thread_ids_by_connection.remove(&connection_id);
}
}
if let Some(thread_entry) = state.threads.get_mut(&thread_id) {
thread_entry.connection_ids.remove(&connection_id);
}
state
.subscription_state_by_id
.retain(|_, subscription_state| {
!(subscription_state.thread_id == thread_id
&& subscription_state.connection_id == connection_id)
});
};
if !self
.thread_ids_by_connection
.get(&connection_id)
.is_some_and(|thread_ids| thread_ids.contains(&thread_id))
{
return false;
}
if let Some(thread_ids) = self.thread_ids_by_connection.get_mut(&connection_id) {
thread_ids.remove(&thread_id);
if thread_ids.is_empty() {
self.thread_ids_by_connection.remove(&connection_id);
}
}
self.subscription_state_by_id.retain(|_, state| {
!(state.thread_id == thread_id && state.connection_id == connection_id)
});
let mut thread_state = thread_state.lock().await;
thread_state.remove_connection(connection_id);
true
}
pub(crate) async fn has_subscribers(&self, thread_id: ThreadId) -> bool {
let Some(thread_state) = self.thread_states.get(&thread_id) else {
return false;
};
!thread_state
self.state
.lock()
.await
.subscribed_connection_ids()
.is_empty()
.threads
.get(&thread_id)
.is_some_and(|thread_entry| !thread_entry.connection_ids.is_empty())
}
pub(crate) async fn set_listener(
&mut self,
&self,
subscription_id: Uuid,
thread_id: ThreadId,
connection_id: ConnectionId,
experimental_raw_events: bool,
) -> Arc<Mutex<ThreadState>> {
self.subscription_state_by_id.insert(
subscription_id,
SubscriptionState {
thread_id,
connection_id,
},
);
self.thread_ids_by_connection
.entry(connection_id)
.or_default()
.insert(thread_id);
let thread_state = self.thread_state(thread_id);
let thread_state = {
let mut state = self.state.lock().await;
state.subscription_state_by_id.insert(
subscription_id,
SubscriptionState {
thread_id,
connection_id,
},
);
state
.thread_ids_by_connection
.entry(connection_id)
.or_default()
.insert(thread_id);
let thread_entry = state.threads.entry(thread_id).or_default();
thread_entry.connection_ids.insert(connection_id);
thread_entry.state.clone()
};
{
let mut thread_state_guard = thread_state.lock().await;
thread_state_guard.add_connection(connection_id);
thread_state_guard.set_experimental_raw_events(experimental_raw_events);
}
thread_state
}
pub(crate) async fn ensure_connection_subscribed(
&mut self,
pub(crate) async fn try_ensure_connection_subscribed(
&self,
thread_id: ThreadId,
connection_id: ConnectionId,
experimental_raw_events: bool,
) -> Arc<Mutex<ThreadState>> {
self.thread_ids_by_connection
.entry(connection_id)
.or_default()
.insert(thread_id);
let thread_state = self.thread_state(thread_id);
) -> Option<Arc<Mutex<ThreadState>>> {
let thread_state = {
let mut state = self.state.lock().await;
if !state.live_connections.contains(&connection_id) {
return None;
}
state
.thread_ids_by_connection
.entry(connection_id)
.or_default()
.insert(thread_id);
let thread_entry = state.threads.entry(thread_id).or_default();
thread_entry.connection_ids.insert(connection_id);
thread_entry.state.clone()
};
{
let mut thread_state_guard = thread_state.lock().await;
thread_state_guard.add_connection(connection_id);
if experimental_raw_events {
thread_state_guard.set_experimental_raw_events(true);
}
}
thread_state
Some(thread_state)
}
pub(crate) async fn remove_connection(&mut self, connection_id: ConnectionId) {
let thread_ids = self
.thread_ids_by_connection
.remove(&connection_id)
.unwrap_or_default();
self.subscription_state_by_id
.retain(|_, state| state.connection_id != connection_id);
if thread_ids.is_empty() {
for thread_state in self.thread_states.values() {
let mut thread_state = thread_state.lock().await;
thread_state.remove_connection(connection_id);
if thread_state.subscribed_connection_ids().is_empty() {
tracing::debug!(
connection_id = ?connection_id,
listener_generation = thread_state.listener_generation,
"retaining thread listener after connection disconnect left zero subscribers"
);
}
}
return;
pub(crate) async fn try_add_connection_to_thread(
&self,
thread_id: ThreadId,
connection_id: ConnectionId,
) -> bool {
let mut state = self.state.lock().await;
if !state.live_connections.contains(&connection_id) {
return false;
}
state
.thread_ids_by_connection
.entry(connection_id)
.or_default()
.insert(thread_id);
state
.threads
.entry(thread_id)
.or_default()
.connection_ids
.insert(connection_id);
true
}
for thread_id in thread_ids {
if let Some(thread_state) = self.thread_states.get(&thread_id) {
let mut thread_state = thread_state.lock().await;
thread_state.remove_connection(connection_id);
if thread_state.subscribed_connection_ids().is_empty() {
tracing::debug!(
thread_id = %thread_id,
connection_id = ?connection_id,
listener_generation = thread_state.listener_generation,
"retaining thread listener after connection disconnect left zero subscribers"
);
pub(crate) async fn remove_connection(&self, connection_id: ConnectionId) {
let thread_states = {
let mut state = self.state.lock().await;
state.live_connections.remove(&connection_id);
let thread_ids = state
.thread_ids_by_connection
.remove(&connection_id)
.unwrap_or_default();
state
.subscription_state_by_id
.retain(|_, state| state.connection_id != connection_id);
for thread_id in &thread_ids {
if let Some(thread_entry) = state.threads.get_mut(thread_id) {
thread_entry.connection_ids.remove(&connection_id);
}
}
thread_ids
.into_iter()
.map(|thread_id| {
(
thread_id,
state
.threads
.get(&thread_id)
.is_none_or(|thread_entry| thread_entry.connection_ids.is_empty()),
state
.threads
.get(&thread_id)
.map(|thread_entry| thread_entry.state.clone()),
)
})
.collect::<Vec<_>>()
};
for (thread_id, no_subscribers, thread_state) in thread_states {
if !no_subscribers {
continue;
}
let Some(thread_state) = thread_state else {
continue;
};
let listener_generation = thread_state.lock().await.listener_generation;
tracing::debug!(
thread_id = %thread_id,
connection_id = ?connection_id,
listener_generation,
"retaining thread listener after connection disconnect left zero subscribers"
);
}
}
}