Compare commits

...

1 Commits

Author SHA1 Message Date
Shaqayeq
083243dca1 Support concurrent Python SDK turns across threads 2026-03-19 16:16:02 -07:00
8 changed files with 584 additions and 189 deletions

View File

@@ -2,7 +2,7 @@
Public surface of `codex_app_server` for app-server v2.
This SDK surface is experimental. The current implementation intentionally allows only one active turn consumer (`Thread.run()`, `TurnHandle.stream()`, or `TurnHandle.run()`) per client instance at a time.
This SDK surface is experimental. The current implementation allows concurrent turn consumers on one client only when they belong to different thread IDs. Each client still supports only one active turn per thread ID at a time.
## Package Entry
@@ -137,8 +137,9 @@ Use `turn(...)` when you need low-level turn control (`stream()`, `steer()`,
Behavior notes:
- `stream()` and `run()` are exclusive per client instance in the current experimental build
- starting a second turn consumer on the same `Codex` instance raises `RuntimeError`
- `stream()` and `run()` may run concurrently on one client when the turns belong to different thread IDs
- starting a second turn on the same thread raises `RuntimeError`; use `steer()` or `interrupt()` on the existing handle instead
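- low-level global notification APIs such as `next_notification()` are incompatible with active turn streaming on the same client: `next_notification()` raises `RuntimeError` while a turn stream is active, so consume events through `stream()`/`run()` instead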
### AsyncTurnHandle
@@ -149,8 +150,9 @@ Behavior notes:
Behavior notes:
- `stream()` and `run()` are exclusive per client instance in the current experimental build
- starting a second turn consumer on the same `AsyncCodex` instance raises `RuntimeError`
- `stream()` and `run()` may run concurrently on one client when the turns belong to different thread IDs
- starting a second turn on the same thread raises `RuntimeError`; use `steer()` or `interrupt()` on the existing handle instead
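- low-level global notification APIs such as `next_notification()` are incompatible with active turn streaming on the same client: `next_notification()` raises `RuntimeError` while a turn stream is active, so consume events through `stream()`/`run()` instead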
## Inputs

View File

@@ -43,7 +43,8 @@ What happened:
- `thread.run("...")` started a turn, consumed events until completion, and returned the final assistant response plus collected items and usage.
- `result.final_response` is `None` when no final-answer or phase-less assistant message item completes for the turn.
- use `thread.turn(...)` when you need a `TurnHandle` for streaming, steering, interrupting, or turn IDs/status
- one client can have only one active turn consumer (`thread.run(...)`, `TurnHandle.stream()`, or `TurnHandle.run()`) at a time in the current experimental build
- one client can run turns concurrently across different thread IDs in the current experimental build
- one thread can have only one active turn at a time on a given client; start a second same-thread turn only after the first completes, or use `steer()` on the existing `TurnHandle`
## 3) Continue the same thread (multi-turn)

View File

@@ -653,11 +653,10 @@ class TurnHandle:
return self._client.turn_interrupt(self.thread_id, self.id)
def stream(self) -> Iterator[Notification]:
# TODO: replace this client-wide experimental guard with per-turn event demux.
self._client.acquire_turn_consumer(self.id)
self._client.acquire_turn_consumer(self.thread_id, self.id)
try:
while True:
event = self._client.next_notification()
event = self._client.next_turn_notification(self.thread_id, self.id)
yield event
if (
event.method == "turn/completed"
@@ -666,7 +665,7 @@ class TurnHandle:
):
break
finally:
self._client.release_turn_consumer(self.id)
self._client.release_turn_consumer(self.thread_id, self.id)
def run(self) -> AppServerTurn:
completed: TurnCompletedNotification | None = None
@@ -704,11 +703,10 @@ class AsyncTurnHandle:
async def stream(self) -> AsyncIterator[Notification]:
await self._codex._ensure_initialized()
# TODO: replace this client-wide experimental guard with per-turn event demux.
self._codex._client.acquire_turn_consumer(self.id)
self._codex._client.acquire_turn_consumer(self.thread_id, self.id)
try:
while True:
event = await self._codex._client.next_notification()
event = await self._codex._client.next_turn_notification(self.thread_id, self.id)
yield event
if (
event.method == "turn/completed"
@@ -717,7 +715,7 @@ class AsyncTurnHandle:
):
break
finally:
self._codex._client.release_turn_consumer(self.id)
self._codex._client.release_turn_consumer(self.thread_id, self.id)
async def run(self) -> AppServerTurn:
completed: TurnCompletedNotification | None = None

View File

@@ -41,8 +41,6 @@ class AsyncAppServerClient:
def __init__(self, config: AppServerConfig | None = None) -> None:
self._sync = AppServerClient(config=config)
# Single stdio transport cannot be read safely from multiple threads.
self._transport_lock = asyncio.Lock()
async def __aenter__(self) -> "AsyncAppServerClient":
await self.start()
@@ -58,8 +56,7 @@ class AsyncAppServerClient:
*args: ParamsT.args,
**kwargs: ParamsT.kwargs,
) -> ReturnT:
async with self._transport_lock:
return await asyncio.to_thread(fn, *args, **kwargs)
return await asyncio.to_thread(fn, *args, **kwargs)
@staticmethod
def _next_from_iterator(
@@ -79,11 +76,11 @@ class AsyncAppServerClient:
async def initialize(self) -> InitializeResponse:
    """Run the wrapped synchronous client's ``initialize`` via ``_call_sync``.

    Returns:
        Whatever the synchronous ``initialize`` call returns.
    """
    sync_initialize = self._sync.initialize
    return await self._call_sync(sync_initialize)
def acquire_turn_consumer(self, turn_id: str) -> None:
self._sync.acquire_turn_consumer(turn_id)
def acquire_turn_consumer(self, thread_id: str, turn_id: str) -> None:
self._sync.acquire_turn_consumer(thread_id, turn_id)
def release_turn_consumer(self, turn_id: str) -> None:
self._sync.release_turn_consumer(turn_id)
def release_turn_consumer(self, thread_id: str, turn_id: str) -> None:
self._sync.release_turn_consumer(thread_id, turn_id)
async def request(
self,
@@ -184,6 +181,9 @@ class AsyncAppServerClient:
async def next_notification(self) -> Notification:
    """Await the next global notification from the wrapped synchronous client.

    The blocking ``next_notification`` call of the sync client is delegated
    to ``_call_sync``; its result (or exception) is passed through unchanged.
    """
    sync_next = self._sync.next_notification
    return await self._call_sync(sync_next)
async def next_turn_notification(self, thread_id: str, turn_id: str) -> Notification:
    """Await the next notification routed to one specific turn.

    Args:
        thread_id: Thread the turn belongs to.
        turn_id: Turn whose notification queue should be read.

    Returns:
        The value produced by the synchronous client's
        ``next_turn_notification(thread_id, turn_id)``.
    """
    sync_next_for_turn = self._sync.next_turn_notification
    return await self._call_sync(sync_next_for_turn, thread_id, turn_id)
async def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
    """Await the synchronous client's ``wait_for_turn_completed`` for ``turn_id``.

    Args:
        turn_id: Identifier of the turn to wait for.

    Returns:
        The value returned by the synchronous ``wait_for_turn_completed``.
    """
    sync_wait = self._sync.wait_for_turn_completed
    return await self._call_sync(sync_wait, turn_id)
@@ -196,13 +196,12 @@ class AsyncAppServerClient:
text: str,
params: V2TurnStartParams | JsonObject | None = None,
) -> AsyncIterator[AgentMessageDeltaNotification]:
async with self._transport_lock:
iterator = self._sync.stream_text(thread_id, text, params)
while True:
has_value, chunk = await asyncio.to_thread(
self._next_from_iterator,
iterator,
)
if not has_value:
break
yield chunk
iterator = self._sync.stream_text(thread_id, text, params)
while True:
has_value, chunk = await asyncio.to_thread(
self._next_from_iterator,
iterator,
)
if not has_value:
break
yield chunk

View File

@@ -6,7 +6,7 @@ import subprocess
import threading
import uuid
from collections import deque
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Iterable, Iterator, TypeVar
@@ -48,6 +48,58 @@ from .retry import retry_on_overload
ModelT = TypeVar("ModelT", bound=BaseModel)
ApprovalHandler = Callable[[str, JsonObject | None], JsonObject]
RUNTIME_PKG_NAME = "codex-cli-bin"
GLOBAL_NOTIFICATION_BACKLOG_LIMIT = 512
@dataclass(slots=True)
class _PendingRequest:
event: threading.Event = field(default_factory=threading.Event)
result: JsonValue | None = None
error: BaseException | None = None
class _BufferedNotificationStream:
    """Thread-safe FIFO of notifications with close/error signalling.

    Producers call :meth:`push`; consumers block in :meth:`pop` until an item
    is available or the stream has been closed. When ``maxlen`` is given the
    backing deque is bounded, so appending beyond the limit discards the
    oldest buffered item.
    """

    def __init__(self, *, maxlen: int | None = None) -> None:
        self._condition = threading.Condition()
        # ``deque(maxlen=None)`` is an unbounded deque, same as ``deque()``.
        self._items: deque[Notification] = deque(maxlen=maxlen)
        self._closed = False
        self._error: BaseException | None = None

    def push(self, notification: Notification) -> None:
        """Append ``notification`` and wake waiters; ignored once closed."""
        with self._condition:
            if not self._closed:
                self._items.append(notification)
                self._condition.notify_all()

    def pop(self) -> Notification:
        """Return the oldest buffered notification, blocking until one exists.

        Items buffered before ``close()`` are still delivered. Once the stream
        is closed and empty, the error passed to ``close()`` is raised, or
        ``TransportClosedError`` when no error was given.
        """
        with self._condition:
            self._condition.wait_for(lambda: bool(self._items) or self._closed)
            if self._items:
                return self._items.popleft()
            if self._error is None:
                raise TransportClosedError("notification stream is closed")
            raise self._error

    def close(self, error: BaseException | None = None) -> None:
        """Mark the stream closed, remember ``error``, and wake all waiters."""
        with self._condition:
            self._closed = True
            self._error = error
            self._condition.notify_all()

    def is_closed(self) -> bool:
        """Return whether ``close()`` has been called."""
        with self._condition:
            return self._closed

    def is_drained(self) -> bool:
        """Return whether the stream is closed and has no buffered items left."""
        with self._condition:
            return self._closed and len(self._items) == 0
def _params_dict(
@@ -144,12 +196,21 @@ class AppServerClient:
self.config = config or AppServerConfig()
self._approval_handler = approval_handler or self._default_approval_handler
self._proc: subprocess.Popen[str] | None = None
self._lock = threading.Lock()
self._turn_consumer_lock = threading.Lock()
self._active_turn_consumer: str | None = None
self._pending_notifications: deque[Notification] = deque()
self._write_lock = threading.Lock()
self._state_lock = threading.Lock()
self._pending_notifications = _BufferedNotificationStream(
maxlen=GLOBAL_NOTIFICATION_BACKLOG_LIMIT
)
self._pending_requests: dict[str, _PendingRequest] = {}
self._turn_streams: dict[tuple[str, str], _BufferedNotificationStream] = {}
self._turn_starting_by_thread_id: set[str] = set()
self._active_turn_by_thread_id: dict[str, str] = {}
self._active_turn_consumers: set[tuple[str, str]] = set()
self._active_turn_stream_count = 0
self._transport_error: BaseException | None = None
self._stderr_lines: deque[str] = deque(maxlen=400)
self._stderr_thread: threading.Thread | None = None
self._reader_thread: threading.Thread | None = None
def __enter__(self) -> "AppServerClient":
self.start()
@@ -161,6 +222,7 @@ class AppServerClient:
def start(self) -> None:
if self._proc is not None:
return
self._reset_transport_state()
if self.config.launch_args_override is not None:
args = list(self.config.launch_args_override)
@@ -187,13 +249,14 @@ class AppServerClient:
)
self._start_stderr_drain_thread()
self._start_reader_thread()
def close(self) -> None:
if self._proc is None:
return
proc = self._proc
self._proc = None
self._active_turn_consumer = None
self._finish_transport(TransportClosedError("app-server closed"))
if proc.stdin:
proc.stdin.close()
@@ -205,6 +268,8 @@ class AppServerClient:
if self._stderr_thread and self._stderr_thread.is_alive():
self._stderr_thread.join(timeout=0.5)
if self._reader_thread and self._reader_thread.is_alive():
self._reader_thread.join(timeout=0.5)
def initialize(self) -> InitializeResponse:
result = self.request(
@@ -238,67 +303,76 @@ class AppServerClient:
def _request_raw(self, method: str, params: JsonObject | None = None) -> JsonValue:
request_id = str(uuid.uuid4())
self._write_message({"id": request_id, "method": method, "params": params or {}})
waiter = _PendingRequest()
with self._state_lock:
if self._transport_error is not None:
raise self._transport_error
self._pending_requests[request_id] = waiter
while True:
msg = self._read_message()
try:
self._write_message({"id": request_id, "method": method, "params": params or {}})
except BaseException:
with self._state_lock:
self._pending_requests.pop(request_id, None)
raise
if "method" in msg and "id" in msg:
response = self._handle_server_request(msg)
self._write_message({"id": msg["id"], "result": response})
continue
if "method" in msg and "id" not in msg:
self._pending_notifications.append(
self._coerce_notification(msg["method"], msg.get("params"))
)
continue
if msg.get("id") != request_id:
continue
if "error" in msg:
err = msg["error"]
if isinstance(err, dict):
raise map_jsonrpc_error(
int(err.get("code", -32000)),
str(err.get("message", "unknown")),
err.get("data"),
)
raise AppServerError("Malformed JSON-RPC error response")
return msg.get("result")
waiter.event.wait()
if waiter.error is not None:
raise waiter.error
return waiter.result
def notify(self, method: str, params: JsonObject | None = None) -> None:
    """Write a message carrying only ``method`` and ``params`` (no ``id``).

    Args:
        method: Method name to put in the message.
        params: Parameters object; a falsy value is sent as ``{}``.
    """
    payload = {"method": method, "params": params if params else {}}
    self._write_message(payload)
def next_notification(self) -> Notification:
if self._pending_notifications:
return self._pending_notifications.popleft()
while True:
msg = self._read_message()
if "method" in msg and "id" in msg:
response = self._handle_server_request(msg)
self._write_message({"id": msg["id"], "result": response})
continue
if "method" in msg and "id" not in msg:
return self._coerce_notification(msg["method"], msg.get("params"))
def acquire_turn_consumer(self, turn_id: str) -> None:
with self._turn_consumer_lock:
if self._active_turn_consumer is not None:
with self._state_lock:
if self._active_turn_stream_count > 0:
raise RuntimeError(
"Concurrent turn consumers are not yet supported in the experimental SDK. "
f"Client is already streaming turn {self._active_turn_consumer!r}; "
f"cannot start turn {turn_id!r} until the active consumer finishes."
"next_notification() is incompatible with active turn streaming on the same "
"client. Consume notifications from TurnHandle.stream()/run() instead."
)
self._active_turn_consumer = turn_id
return self._pending_notifications.pop()
def release_turn_consumer(self, turn_id: str) -> None:
with self._turn_consumer_lock:
if self._active_turn_consumer == turn_id:
self._active_turn_consumer = None
def acquire_turn_consumer(self, thread_id: str, turn_id: str) -> None:
    """Register a stream consumer for the ``(thread_id, turn_id)`` turn.

    Increments the active-stream counter and makes sure a buffered stream
    exists for the turn.

    Raises:
        RuntimeError: If that exact turn is already registered as being
            streamed on this client.
    """
    key = (thread_id, turn_id)
    with self._state_lock:
        already_streaming = key in self._active_turn_consumers
        if already_streaming:
            raise RuntimeError(
                f"Turn {turn_id!r} is already being streamed on thread {thread_id!r}."
            )
        self._active_turn_consumers.add(key)
        self._active_turn_stream_count += 1
        # Keep any stream that already buffered notifications for this turn.
        self._turn_streams.setdefault(key, _BufferedNotificationStream())
# Undo acquire_turn_consumer() for the (thread_id, turn_id) pair: drop the
# consumer registration, decrement the active-stream counter, and discard the
# turn's buffered stream once that stream reports is_drained().
# NOTE(review): indentation is not visible in this view, so it is unclear
# whether the stream cleanup is nested under the membership check below —
# confirm against the real file before relying on either reading.
def release_turn_consumer(self, thread_id: str, turn_id: str) -> None:
turn_key = (thread_id, turn_id)
with self._state_lock:
# Only a registered consumer is removed and counted down.
if turn_key in self._active_turn_consumers:
self._active_turn_consumers.remove(turn_key)
self._active_turn_stream_count -= 1
stream = self._turn_streams.get(turn_key)
# A stream that is not drained is left in the map.
if stream is not None and stream.is_drained():
self._turn_streams.pop(turn_key, None)
def next_turn_notification(self, thread_id: str, turn_id: str) -> Notification:
    """Block until the next notification for ``(thread_id, turn_id)`` arrives.

    Args:
        thread_id: Thread the turn belongs to.
        turn_id: Turn whose buffered stream should be read.

    Returns:
        The oldest buffered notification for that turn.

    Raises:
        BaseException: The recorded transport error, when the transport has
            already finished. Without this check a brand-new stream would be
            created after ``_finish_transport`` cleared the stream map, and
            since nothing would ever push to or close it, the caller would
            block forever.
    """
    turn_key = (thread_id, turn_id)
    with self._state_lock:
        # Checked under the same lock _finish_transport uses to record the
        # error and clear the streams, so the two cannot interleave.
        if self._transport_error is not None:
            raise self._transport_error
        stream = self._turn_streams.get(turn_key)
        if stream is None:
            stream = _BufferedNotificationStream()
            self._turn_streams[turn_key] = stream
    # Wait outside the state lock so dispatching is never blocked by a reader.
    return stream.pop()
def assert_can_start_turn(self, thread_id: str) -> None:
    """Raise if ``thread_id`` cannot accept a new turn on this client.

    Raises:
        RuntimeError: If a turn start is already in flight for the thread, or
            if the thread already has an active turn recorded.
    """
    with self._state_lock:
        start_in_flight = thread_id in self._turn_starting_by_thread_id
        if start_in_flight:
            raise RuntimeError(
                f"Thread {thread_id!r} is already starting a turn on this client."
            )
        active_turn_id = self._active_turn_by_thread_id.get(thread_id)
        if active_turn_id is None:
            return
        raise RuntimeError(
            f"Thread {thread_id!r} already has active turn {active_turn_id!r}. "
            "Use TurnHandle.steer() or TurnHandle.interrupt() instead of starting "
            "another turn on the same thread."
        )
def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse:
    """Issue a ``thread/start`` request and return the validated response.

    Args:
        params: Optional start parameters; converted with ``_params_dict``
            before being sent.

    Returns:
        The response parsed as ``ThreadStartResponse``.
    """
    payload = _params_dict(params)
    return self.request("thread/start", payload, response_model=ThreadStartResponse)
@@ -355,12 +429,19 @@ class AppServerClient:
input_items: list[JsonObject] | JsonObject | str,
params: V2TurnStartParams | JsonObject | None = None,
) -> TurnStartResponse:
self._begin_turn_start(thread_id)
payload = {
**_params_dict(params),
"threadId": thread_id,
"input": self._normalize_input_items(input_items),
}
return self.request("turn/start", payload, response_model=TurnStartResponse)
try:
started = self.request("turn/start", payload, response_model=TurnStartResponse)
except BaseException:
self._cancel_turn_start(thread_id)
raise
self._finish_turn_start(thread_id, started.turn.id)
return started
def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse:
return self.request(
@@ -436,21 +517,25 @@ class AppServerClient:
) -> Iterator[AgentMessageDeltaNotification]:
started = self.turn_start(thread_id, text, params=params)
turn_id = started.turn.id
while True:
notification = self.next_notification()
if (
notification.method == "item/agentMessage/delta"
and isinstance(notification.payload, AgentMessageDeltaNotification)
and notification.payload.turn_id == turn_id
):
yield notification.payload
continue
if (
notification.method == "turn/completed"
and isinstance(notification.payload, TurnCompletedNotification)
and notification.payload.turn.id == turn_id
):
break
self.acquire_turn_consumer(thread_id, turn_id)
try:
while True:
notification = self.next_turn_notification(thread_id, turn_id)
if (
notification.method == "item/agentMessage/delta"
and isinstance(notification.payload, AgentMessageDeltaNotification)
and notification.payload.turn_id == turn_id
):
yield notification.payload
continue
if (
notification.method == "turn/completed"
and isinstance(notification.payload, TurnCompletedNotification)
and notification.payload.turn.id == turn_id
):
break
finally:
self.release_turn_consumer(thread_id, turn_id)
def _coerce_notification(self, method: str, params: object) -> Notification:
params_dict = params if isinstance(params, dict) else {}
@@ -512,7 +597,7 @@ class AppServerClient:
def _write_message(self, payload: JsonObject) -> None:
if self._proc is None or self._proc.stdin is None:
raise TransportClosedError("app-server is not running")
with self._lock:
with self._write_lock:
self._proc.stdin.write(json.dumps(payload) + "\n")
self._proc.stdin.flush()
@@ -535,6 +620,162 @@ class AppServerClient:
raise AppServerError(f"Invalid JSON-RPC payload: {message!r}")
return message
def _reset_transport_state(self) -> None:
    """Replace all per-transport bookkeeping with fresh, empty state.

    Called by ``start()`` before the process is launched, so state from an
    earlier transport (including a recorded transport error) does not leak
    into the new one.
    """
    # Global backlog read by next_notification(); bounded so it cannot grow
    # without limit.
    self._pending_notifications = _BufferedNotificationStream(
        maxlen=GLOBAL_NOTIFICATION_BACKLOG_LIMIT
    )
    # Requests awaiting a response, keyed by request id.
    self._pending_requests = {}
    # Per-turn routing and start/active tracking.
    self._turn_streams = {}
    self._turn_starting_by_thread_id = set()
    self._active_turn_by_thread_id = {}
    # Stream consumer registrations and their count.
    self._active_turn_consumers = set()
    self._active_turn_stream_count = 0
    # No failure recorded for the new transport yet.
    self._transport_error = None
def _start_reader_thread(self) -> None:
    """Start the daemon thread that reads and routes every incoming message.

    Routing per message:
      * ``method`` and ``id`` present: server request, handed to
        ``_start_server_request_worker``.
      * ``method`` only: notification, coerced and dispatched when the method
        is a string, otherwise dropped.
      * anything else: treated as a response to one of our requests.

    Any exception raised while reading or routing ends the loop and is passed
    to ``_finish_transport``.
    """

    def _pump() -> None:
        try:
            while True:
                incoming = self._read_message()
                has_method = "method" in incoming
                has_id = "id" in incoming
                if has_method and has_id:
                    self._start_server_request_worker(incoming)
                elif has_method:
                    method = incoming["method"]
                    if isinstance(method, str):
                        notification = self._coerce_notification(
                            method, incoming.get("params")
                        )
                        self._dispatch_notification(notification)
                else:
                    self._handle_response_message(incoming)
        except BaseException as exc:  # noqa: BLE001
            self._finish_transport(exc)

    self._reader_thread = threading.Thread(target=_pump, daemon=True)
    self._reader_thread.start()
def _start_server_request_worker(self, msg: dict[str, JsonValue]) -> None:
    """Answer a server-initiated request on its own daemon thread.

    The handler runs off the calling thread, so a slow handler does not stall
    the caller. If the handler raises, a JSON-RPC error object (code -32603,
    "Internal error") is sent back instead of silently dropping the request,
    which would leave the server waiting for a reply that never comes.
    Failures while writing the reply remain best-effort and are ignored.

    Args:
        msg: The decoded server request; ``msg["id"]`` is echoed in the reply.
    """

    def _resolve() -> None:
        try:
            reply = {"id": msg["id"], "result": self._handle_server_request(msg)}
        except Exception as exc:  # noqa: BLE001
            reply = {
                "id": msg["id"],
                "error": {
                    "code": -32603,
                    "message": str(exc) or type(exc).__name__,
                },
            }
        try:
            self._write_message(reply)
        except Exception:  # noqa: BLE001
            # Nothing useful can be done from this worker if the write fails.
            return

    threading.Thread(target=_resolve, daemon=True).start()
def _handle_response_message(self, msg: dict[str, JsonValue]) -> None:
    """Deliver a JSON-RPC response to the request that is waiting for it.

    Messages whose ``id`` is not a string, or that match no pending request,
    are ignored. Otherwise the waiter receives either the ``result`` value or
    an exception built from the ``error`` member, and is then woken.

    The waiter has already been removed from the pending map by the time the
    payload is interpreted, so any exception raised while interpreting it
    (for example a non-numeric ``code``) is stored on the waiter and the
    event is always set. Otherwise the requesting thread would block forever.

    Args:
        msg: The decoded response message.
    """
    request_id = msg.get("id")
    if not isinstance(request_id, str):
        return
    with self._state_lock:
        waiter = self._pending_requests.pop(request_id, None)
    if waiter is None:
        return
    try:
        if "error" in msg:
            err = msg["error"]
            if isinstance(err, dict):
                code = int(err.get("code", -32000))
                message = str(err.get("message", "unknown"))
                waiter.error = map_jsonrpc_error(code, message, err.get("data"))
            else:
                waiter.error = AppServerError("Malformed JSON-RPC error response")
        else:
            waiter.result = msg.get("result")
    except Exception as exc:  # noqa: BLE001
        # Surface the parsing failure to the requester instead of stranding it.
        waiter.error = exc
    finally:
        waiter.event.set()
# Route one incoming notification: it is always appended to the bounded global
# backlog, and, when it can be attributed to a (thread_id, turn_id) pair, also
# to that turn's buffered stream. turn/started and turn/completed additionally
# update the per-thread "starting" and "active turn" bookkeeping.
# NOTE(review): indentation is not visible in this view, so the exact nesting of
# `close_stream = True` and of `stream.push(...)` relative to the lock cannot be
# confirmed here — check the real file.
def _dispatch_notification(self, notification: Notification) -> None:
self._pending_notifications.push(notification)
turn_key = self._turn_key_for_notification(notification)
# Notifications without a turn key only go to the global backlog.
if turn_key is None:
return
thread_id, turn_id = turn_key
close_stream = False
with self._state_lock:
# Create the turn's stream on first sight so early events are buffered.
stream = self._turn_streams.setdefault(turn_key, _BufferedNotificationStream())
if notification.method == "turn/started":
self._turn_starting_by_thread_id.discard(thread_id)
self._active_turn_by_thread_id[thread_id] = turn_id
elif notification.method == "turn/completed":
self._turn_starting_by_thread_id.discard(thread_id)
# Only clear the active marker if it still refers to this turn.
if self._active_turn_by_thread_id.get(thread_id) == turn_id:
self._active_turn_by_thread_id.pop(thread_id, None)
close_stream = True
stream.push(notification)
# The completion event is pushed before the stream is closed.
if close_stream:
stream.close()
def _turn_key_for_notification(self, notification: Notification) -> tuple[str, str] | None:
    """Derive the ``(thread_id, turn_id)`` routing key from a notification.

    The turn id is taken from ``payload.turn_id`` when that is a string, and
    otherwise from ``payload.turn.id``. ``payload.thread_id`` must be a string
    in both cases.

    Returns:
        The routing key, or ``None`` when the payload does not identify both
        a thread and a turn.
    """
    payload = notification.payload
    thread_id = getattr(payload, "thread_id", None)
    direct_turn_id = getattr(payload, "turn_id", None)
    has_thread = isinstance(thread_id, str)
    if has_thread and isinstance(direct_turn_id, str):
        return thread_id, direct_turn_id
    nested_turn_id = getattr(getattr(payload, "turn", None), "id", None)
    if has_thread and isinstance(nested_turn_id, str):
        return thread_id, nested_turn_id
    return None
def _begin_turn_start(self, thread_id: str) -> None:
    """Reserve ``thread_id`` for a turn start that is about to be requested.

    Raises:
        RuntimeError: If the thread already has an active turn, or if another
            turn start for the thread is still in flight. The active-turn
            check runs first.
    """
    with self._state_lock:
        active_turn_id = self._active_turn_by_thread_id.get(thread_id)
        if active_turn_id is not None:
            raise RuntimeError(
                f"Thread {thread_id!r} already has active turn {active_turn_id!r}. "
                "Use TurnHandle.steer() or TurnHandle.interrupt() instead of starting "
                "another turn on the same thread."
            )
        starting = self._turn_starting_by_thread_id
        if thread_id in starting:
            raise RuntimeError(
                f"Thread {thread_id!r} is already starting a turn on this client."
            )
        starting.add(thread_id)
def _cancel_turn_start(self, thread_id: str) -> None:
    """Drop the in-flight turn-start reservation for ``thread_id``, if any."""
    with self._state_lock:
        starting = self._turn_starting_by_thread_id
        starting.discard(thread_id)
def _finish_turn_start(self, thread_id: str, turn_id: str) -> None:
    """Turn a start reservation into an active-turn record.

    The reservation for ``thread_id`` is released and a buffered stream is
    ensured for the turn. The turn is recorded as the thread's active turn
    only while that stream is still open; a stream that is already closed
    (which ``_dispatch_notification`` does on ``turn/completed``) is not
    marked active.
    """
    key = (thread_id, turn_id)
    with self._state_lock:
        self._turn_starting_by_thread_id.discard(thread_id)
        stream = self._turn_streams.setdefault(key, _BufferedNotificationStream())
        still_open = not stream.is_closed()
        if still_open:
            self._active_turn_by_thread_id[thread_id] = turn_id
# Record `error` as the reason the transport ended and release everything that
# could still be waiting on it: pending requests are failed with `error`, and
# the global backlog and every per-turn stream are closed with `error`.
# NOTE(review): indentation is not visible in this view, so it is unclear
# whether the wake-up loops run while the state lock is still held — confirm
# against the real file.
def _finish_transport(self, error: BaseException) -> None:
with self._state_lock:
# The first recorded error wins; later calls are no-ops.
if self._transport_error is not None:
return
self._transport_error = error
# Snapshot the waiters and streams before clearing the shared maps.
pending_requests = list(self._pending_requests.values())
self._pending_requests.clear()
turn_streams = list(self._turn_streams.values())
self._turn_streams.clear()
self._turn_starting_by_thread_id.clear()
self._active_turn_by_thread_id.clear()
self._active_turn_consumers.clear()
self._active_turn_stream_count = 0
# Fail every request still waiting on a response.
for waiter in pending_requests:
waiter.error = error
waiter.event.set()
# Closing with an error makes pop() raise it once buffered items are consumed.
self._pending_notifications.close(error)
for stream in turn_streams:
stream.close(error)
def default_codex_home() -> str:
    """Return the path of the ``.codex`` directory under the user's home, as a string."""
    home = Path.home()
    return str(home.joinpath(".codex"))

View File

@@ -6,7 +6,7 @@ import time
from codex_app_server.async_client import AsyncAppServerClient
def test_async_client_serializes_transport_calls() -> None:
def test_async_client_allows_parallel_transport_calls() -> None:
async def scenario() -> int:
client = AsyncAppServerClient()
active = 0
@@ -24,10 +24,10 @@ def test_async_client_serializes_transport_calls() -> None:
await asyncio.gather(client.model_list(), client.model_list())
return max_active
assert asyncio.run(scenario()) == 1
assert asyncio.run(scenario()) == 2
def test_async_stream_text_is_incremental_and_blocks_parallel_calls() -> None:
def test_async_stream_text_is_incremental_and_allows_parallel_calls() -> None:
async def scenario() -> tuple[str, list[str], bool]:
client = AsyncAppServerClient()
@@ -46,19 +46,19 @@ def test_async_stream_text_is_incremental_and_blocks_parallel_calls() -> None:
stream = client.stream_text("thread-1", "hello")
first = await anext(stream)
blocked_before_stream_done = False
completed_before_stream_done = False
competing_call = asyncio.create_task(client.model_list())
await asyncio.sleep(0.01)
blocked_before_stream_done = not competing_call.done()
completed_before_stream_done = competing_call.done()
remaining: list[str] = []
async for item in stream:
remaining.append(item)
await competing_call
return first, remaining, blocked_before_stream_done
return first, remaining, completed_before_stream_done
first, remaining, blocked = asyncio.run(scenario())
first, remaining, completed = asyncio.run(scenario())
assert first == "first"
assert remaining == ["second", "third"]
assert blocked
assert completed

View File

@@ -4,7 +4,12 @@ from pathlib import Path
from typing import Any
from codex_app_server.client import AppServerClient, _params_dict
from codex_app_server.generated.v2_all import ThreadListParams, ThreadTokenUsageUpdatedNotification
from codex_app_server.generated.v2_all import (
AgentMessageDeltaNotification,
ThreadListParams,
ThreadTokenUsageUpdatedNotification,
TurnCompletedNotification,
)
from codex_app_server.models import UnknownNotification
ROOT = Path(__file__).resolve().parents[1]
@@ -93,3 +98,58 @@ def test_invalid_notification_payload_falls_back_to_unknown() -> None:
assert event.method == "thread/tokenUsage/updated"
assert isinstance(event.payload, UnknownNotification)
def test_client_routes_interleaved_turn_notifications_to_matching_turn_queues() -> None:
    """Deltas dispatched for two different turns are read back from their own turn queues."""
    client = AppServerClient()
    thread_one_params = {
        "delta": "first",
        "itemId": "item-1",
        "threadId": "thread-1",
        "turnId": "turn-1",
    }
    thread_two_params = {
        "delta": "second",
        "itemId": "item-2",
        "threadId": "thread-2",
        "turnId": "turn-2",
    }
    thread_one_delta = client._coerce_notification("item/agentMessage/delta", thread_one_params)
    thread_two_delta = client._coerce_notification("item/agentMessage/delta", thread_two_params)

    # Dispatch in order; each turn must still see only its own event.
    for delta in (thread_one_delta, thread_two_delta):
        client._dispatch_notification(delta)  # type: ignore[attr-defined]

    from_turn_one = client.next_turn_notification("thread-1", "turn-1")
    from_turn_two = client.next_turn_notification("thread-2", "turn-2")

    assert isinstance(from_turn_one.payload, AgentMessageDeltaNotification)
    assert from_turn_one.payload.delta == "first"
    assert isinstance(from_turn_two.payload, AgentMessageDeltaNotification)
    assert from_turn_two.payload.delta == "second"
def test_next_notification_still_returns_turn_notifications_without_active_streams() -> None:
    """With no turn stream active, a dispatched turn event is readable via next_notification()."""
    client = AppServerClient()
    completed_params = {
        "threadId": "thread-1",
        "turn": {
            "id": "turn-1",
            "items": [],
            "status": "completed",
        },
    }
    turn_completed = client._coerce_notification("turn/completed", completed_params)

    client._dispatch_notification(turn_completed)  # type: ignore[attr-defined]
    received = client.next_notification()

    assert received.method == "turn/completed"
    assert isinstance(received.payload, TurnCompletedNotification)
    assert received.payload.turn.id == "turn-1"

View File

@@ -133,6 +133,15 @@ def _token_usage_notification(
)
def _turn_notification_source(
    notifications_by_turn: dict[tuple[str, str], deque[Notification]],
):
    """Build a stand-in for ``next_turn_notification`` backed by per-turn deques.

    The returned callable pops from the left of the deque stored under
    ``(thread_id, turn_id)``. An unknown key raises ``KeyError`` and an empty
    deque raises ``IndexError``.
    """

    def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
        queue = notifications_by_turn[thread_id, turn_id]
        return queue.popleft()

    return fake_next_turn_notification
def test_codex_init_failure_closes_client(monkeypatch: pytest.MonkeyPatch) -> None:
closed: list[bool] = []
@@ -226,66 +235,132 @@ def test_async_codex_initializes_only_once_under_concurrency() -> None:
asyncio.run(scenario())
def test_turn_stream_rejects_second_active_consumer() -> None:
def test_turn_stream_allows_different_active_threads() -> None:
client = AppServerClient()
notifications: deque[Notification] = deque(
[
_delta_notification(turn_id="turn-1"),
_completed_notification(turn_id="turn-1"),
]
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
{
("thread-1", "turn-1"): deque(
[
_delta_notification(thread_id="thread-1", turn_id="turn-1"),
_completed_notification(thread_id="thread-1", turn_id="turn-1"),
]
),
("thread-2", "turn-2"): deque(
[
_delta_notification(thread_id="thread-2", turn_id="turn-2"),
_completed_notification(thread_id="thread-2", turn_id="turn-2"),
]
),
}
)
first_stream = TurnHandle(client, "thread-1", "turn-1").stream()
second_stream = TurnHandle(client, "thread-2", "turn-2").stream()
assert next(first_stream).method == "item/agentMessage/delta"
assert next(second_stream).method == "item/agentMessage/delta"
first_stream.close()
second_stream.close()
def test_turn_stream_blocks_next_notification_while_active() -> None:
client = AppServerClient()
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
{
("thread-1", "turn-1"): deque(
[
_delta_notification(thread_id="thread-1", turn_id="turn-1"),
_completed_notification(thread_id="thread-1", turn_id="turn-1"),
]
),
}
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
first_stream = TurnHandle(client, "thread-1", "turn-1").stream()
assert next(first_stream).method == "item/agentMessage/delta"
second_stream = TurnHandle(client, "thread-1", "turn-2").stream()
with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"):
next(second_stream)
with pytest.raises(RuntimeError, match="next_notification\\(\\) is incompatible"):
client.next_notification()
first_stream.close()
def test_async_turn_stream_rejects_second_active_consumer() -> None:
def test_turn_start_rejects_same_thread_overlap_and_allows_after_completion() -> None:
    """A second turn on the same thread is refused until the first turn completes."""
    client = AppServerClient()
    pending_turn_ids = iter(["turn-1", "turn-2"])

    def fake_request(method: str, params, *, response_model):  # type: ignore[no-untyped-def]
        assert method == "turn/start"
        turn_payload = {
            "id": next(pending_turn_ids),
            "items": [],
            "status": TurnStatus.in_progress.value,
        }
        return response_model.model_validate({"turn": turn_payload})

    client.request = fake_request  # type: ignore[method-assign]

    started_first = client.turn_start("thread-1", "first turn")
    assert started_first.turn.id == "turn-1"

    # While turn-1 is active, another start on thread-1 must be rejected.
    with pytest.raises(RuntimeError, match="already has active turn"):
        client.turn_start("thread-1", "second turn")

    first_completed = _completed_notification(thread_id="thread-1", turn_id="turn-1")
    client._dispatch_notification(first_completed)  # type: ignore[attr-defined]

    started_second = client.turn_start("thread-1", "second turn")
    assert started_second.turn.id == "turn-2"
def test_async_turn_stream_allows_different_active_threads() -> None:
async def scenario() -> None:
codex = AsyncCodex()
async def fake_ensure_initialized() -> None:
return None
notifications: deque[Notification] = deque(
[
_delta_notification(turn_id="turn-1"),
_completed_notification(turn_id="turn-1"),
]
)
notifications_by_turn = {
("thread-1", "turn-1"): deque(
[
_delta_notification(thread_id="thread-1", turn_id="turn-1"),
_completed_notification(thread_id="thread-1", turn_id="turn-1"),
]
),
("thread-2", "turn-2"): deque(
[
_delta_notification(thread_id="thread-2", turn_id="turn-2"),
_completed_notification(thread_id="thread-2", turn_id="turn-2"),
]
),
}
async def fake_next_notification() -> Notification:
return notifications.popleft()
async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
return notifications_by_turn[(thread_id, turn_id)].popleft()
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
first_stream = AsyncTurnHandle(codex, "thread-1", "turn-1").stream()
second_stream = AsyncTurnHandle(codex, "thread-2", "turn-2").stream()
assert (await anext(first_stream)).method == "item/agentMessage/delta"
second_stream = AsyncTurnHandle(codex, "thread-1", "turn-2").stream()
with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"):
await anext(second_stream)
assert (await anext(second_stream)).method == "item/agentMessage/delta"
await first_stream.aclose()
await second_stream.aclose()
asyncio.run(scenario())
def test_turn_run_returns_completed_turn_payload() -> None:
client = AppServerClient()
notifications: deque[Notification] = deque(
[
_completed_notification(),
]
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
{("thread-1", "turn-1"): deque([_completed_notification()])}
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
result = TurnHandle(client, "thread-1", "turn-1").run()
@@ -298,14 +373,17 @@ def test_thread_run_accepts_string_input_and_returns_run_result() -> None:
client = AppServerClient()
item_notification = _item_completed_notification(text="Hello.")
usage_notification = _token_usage_notification()
notifications: deque[Notification] = deque(
[
item_notification,
usage_notification,
_completed_notification(),
]
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
{
("thread-1", "turn-1"): deque(
[
item_notification,
usage_notification,
_completed_notification(),
]
),
}
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
seen: dict[str, object] = {}
def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202
@@ -331,14 +409,17 @@ def test_thread_run_uses_last_completed_assistant_message_as_final_response() ->
client = AppServerClient()
first_item_notification = _item_completed_notification(text="First message")
second_item_notification = _item_completed_notification(text="Second message")
notifications: deque[Notification] = deque(
[
first_item_notification,
second_item_notification,
_completed_notification(),
]
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
{
("thread-1", "turn-1"): deque(
[
first_item_notification,
second_item_notification,
_completed_notification(),
]
),
}
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
turn=SimpleNamespace(id="turn-1")
)
@@ -356,14 +437,17 @@ def test_thread_run_preserves_empty_last_assistant_message() -> None:
client = AppServerClient()
first_item_notification = _item_completed_notification(text="First message")
second_item_notification = _item_completed_notification(text="")
notifications: deque[Notification] = deque(
[
first_item_notification,
second_item_notification,
_completed_notification(),
]
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
{
("thread-1", "turn-1"): deque(
[
first_item_notification,
second_item_notification,
_completed_notification(),
]
),
}
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
turn=SimpleNamespace(id="turn-1")
)
@@ -387,14 +471,17 @@ def test_thread_run_prefers_explicit_final_answer_over_later_commentary() -> Non
text="Commentary",
phase=MessagePhase.commentary,
)
notifications: deque[Notification] = deque(
[
final_answer_notification,
commentary_notification,
_completed_notification(),
]
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
{
("thread-1", "turn-1"): deque(
[
final_answer_notification,
commentary_notification,
_completed_notification(),
]
),
}
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
turn=SimpleNamespace(id="turn-1")
)
@@ -414,13 +501,16 @@ def test_thread_run_returns_none_when_only_commentary_messages_complete() -> Non
text="Commentary",
phase=MessagePhase.commentary,
)
notifications: deque[Notification] = deque(
[
commentary_notification,
_completed_notification(),
]
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
{
("thread-1", "turn-1"): deque(
[
commentary_notification,
_completed_notification(),
]
),
}
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
turn=SimpleNamespace(id="turn-1")
)
@@ -433,12 +523,13 @@ def test_thread_run_returns_none_when_only_commentary_messages_complete() -> Non
def test_thread_run_raises_on_failed_turn() -> None:
client = AppServerClient()
notifications: deque[Notification] = deque(
[
_completed_notification(status="failed", error_message="boom"),
]
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
{
("thread-1", "turn-1"): deque(
[_completed_notification(status="failed", error_message="boom")]
),
}
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
turn=SimpleNamespace(id="turn-1")
)
@@ -471,12 +562,13 @@ def test_async_thread_run_accepts_string_input_and_returns_run_result() -> None:
seen["params"] = params
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
async def fake_next_notification() -> Notification:
async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
assert (thread_id, turn_id) == ("thread-1", "turn-1")
return notifications.popleft()
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
result = await AsyncThread(codex, "thread-1").run("hello")
@@ -511,12 +603,13 @@ def test_async_thread_run_uses_last_completed_assistant_message_as_final_respons
async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
async def fake_next_notification() -> Notification:
async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
assert (thread_id, turn_id) == ("thread-1", "turn-1")
return notifications.popleft()
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
result = await AsyncThread(codex, "thread-1").run("hello")
@@ -550,12 +643,13 @@ def test_async_thread_run_returns_none_when_only_commentary_messages_complete()
async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
async def fake_next_notification() -> Notification:
async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
assert (thread_id, turn_id) == ("thread-1", "turn-1")
return notifications.popleft()
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
result = await AsyncThread(codex, "thread-1").run("hello")