mirror of
https://github.com/openai/codex.git
synced 2026-03-22 22:06:29 +03:00
Compare commits
1 Commits
main
...
dev/shaqay
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
083243dca1 |
@@ -2,7 +2,7 @@
|
||||
|
||||
Public surface of `codex_app_server` for app-server v2.
|
||||
|
||||
This SDK surface is experimental. The current implementation intentionally allows only one active turn consumer (`Thread.run()`, `TurnHandle.stream()`, or `TurnHandle.run()`) per client instance at a time.
|
||||
This SDK surface is experimental. The current implementation allows concurrent turn consumers on one client only when they belong to different thread IDs. Each client still supports only one active turn per thread ID at a time.
|
||||
|
||||
## Package Entry
|
||||
|
||||
@@ -137,8 +137,9 @@ Use `turn(...)` when you need low-level turn control (`stream()`, `steer()`,
|
||||
|
||||
Behavior notes:
|
||||
|
||||
- `stream()` and `run()` are exclusive per client instance in the current experimental build
|
||||
- starting a second turn consumer on the same `Codex` instance raises `RuntimeError`
|
||||
- `stream()` and `run()` may run concurrently on one client when the turns belong to different thread IDs
|
||||
- starting a second turn on the same thread raises `RuntimeError`; use `steer()` or `interrupt()` on the existing handle instead
|
||||
- low-level global notification APIs such as `next_notification()` are incompatible with active turn streaming on the same client
|
||||
|
||||
### AsyncTurnHandle
|
||||
|
||||
@@ -149,8 +150,9 @@ Behavior notes:
|
||||
|
||||
Behavior notes:
|
||||
|
||||
- `stream()` and `run()` are exclusive per client instance in the current experimental build
|
||||
- starting a second turn consumer on the same `AsyncCodex` instance raises `RuntimeError`
|
||||
- `stream()` and `run()` may run concurrently on one client when the turns belong to different thread IDs
|
||||
- starting a second turn on the same thread raises `RuntimeError`; use `steer()` or `interrupt()` on the existing handle instead
|
||||
- low-level global notification APIs such as `next_notification()` are incompatible with active turn streaming on the same client
|
||||
|
||||
## Inputs
|
||||
|
||||
|
||||
@@ -43,7 +43,8 @@ What happened:
|
||||
- `thread.run("...")` started a turn, consumed events until completion, and returned the final assistant response plus collected items and usage.
|
||||
- `result.final_response` is `None` when no final-answer or phase-less assistant message item completes for the turn.
|
||||
- use `thread.turn(...)` when you need a `TurnHandle` for streaming, steering, interrupting, or turn IDs/status
|
||||
- one client can have only one active turn consumer (`thread.run(...)`, `TurnHandle.stream()`, or `TurnHandle.run()`) at a time in the current experimental build
|
||||
- one client can run turns concurrently across different thread IDs in the current experimental build
|
||||
- one thread can have only one active turn at a time on a given client; start a second same-thread turn only after the first completes, or use `steer()` on the existing `TurnHandle`
|
||||
|
||||
## 3) Continue the same thread (multi-turn)
|
||||
|
||||
|
||||
@@ -653,11 +653,10 @@ class TurnHandle:
|
||||
return self._client.turn_interrupt(self.thread_id, self.id)
|
||||
|
||||
def stream(self) -> Iterator[Notification]:
|
||||
# TODO: replace this client-wide experimental guard with per-turn event demux.
|
||||
self._client.acquire_turn_consumer(self.id)
|
||||
self._client.acquire_turn_consumer(self.thread_id, self.id)
|
||||
try:
|
||||
while True:
|
||||
event = self._client.next_notification()
|
||||
event = self._client.next_turn_notification(self.thread_id, self.id)
|
||||
yield event
|
||||
if (
|
||||
event.method == "turn/completed"
|
||||
@@ -666,7 +665,7 @@ class TurnHandle:
|
||||
):
|
||||
break
|
||||
finally:
|
||||
self._client.release_turn_consumer(self.id)
|
||||
self._client.release_turn_consumer(self.thread_id, self.id)
|
||||
|
||||
def run(self) -> AppServerTurn:
|
||||
completed: TurnCompletedNotification | None = None
|
||||
@@ -704,11 +703,10 @@ class AsyncTurnHandle:
|
||||
|
||||
async def stream(self) -> AsyncIterator[Notification]:
|
||||
await self._codex._ensure_initialized()
|
||||
# TODO: replace this client-wide experimental guard with per-turn event demux.
|
||||
self._codex._client.acquire_turn_consumer(self.id)
|
||||
self._codex._client.acquire_turn_consumer(self.thread_id, self.id)
|
||||
try:
|
||||
while True:
|
||||
event = await self._codex._client.next_notification()
|
||||
event = await self._codex._client.next_turn_notification(self.thread_id, self.id)
|
||||
yield event
|
||||
if (
|
||||
event.method == "turn/completed"
|
||||
@@ -717,7 +715,7 @@ class AsyncTurnHandle:
|
||||
):
|
||||
break
|
||||
finally:
|
||||
self._codex._client.release_turn_consumer(self.id)
|
||||
self._codex._client.release_turn_consumer(self.thread_id, self.id)
|
||||
|
||||
async def run(self) -> AppServerTurn:
|
||||
completed: TurnCompletedNotification | None = None
|
||||
|
||||
@@ -41,8 +41,6 @@ class AsyncAppServerClient:
|
||||
|
||||
def __init__(self, config: AppServerConfig | None = None) -> None:
|
||||
self._sync = AppServerClient(config=config)
|
||||
# Single stdio transport cannot be read safely from multiple threads.
|
||||
self._transport_lock = asyncio.Lock()
|
||||
|
||||
async def __aenter__(self) -> "AsyncAppServerClient":
|
||||
await self.start()
|
||||
@@ -58,8 +56,7 @@ class AsyncAppServerClient:
|
||||
*args: ParamsT.args,
|
||||
**kwargs: ParamsT.kwargs,
|
||||
) -> ReturnT:
|
||||
async with self._transport_lock:
|
||||
return await asyncio.to_thread(fn, *args, **kwargs)
|
||||
return await asyncio.to_thread(fn, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _next_from_iterator(
|
||||
@@ -79,11 +76,11 @@ class AsyncAppServerClient:
|
||||
async def initialize(self) -> InitializeResponse:
|
||||
return await self._call_sync(self._sync.initialize)
|
||||
|
||||
def acquire_turn_consumer(self, turn_id: str) -> None:
|
||||
self._sync.acquire_turn_consumer(turn_id)
|
||||
def acquire_turn_consumer(self, thread_id: str, turn_id: str) -> None:
|
||||
self._sync.acquire_turn_consumer(thread_id, turn_id)
|
||||
|
||||
def release_turn_consumer(self, turn_id: str) -> None:
|
||||
self._sync.release_turn_consumer(turn_id)
|
||||
def release_turn_consumer(self, thread_id: str, turn_id: str) -> None:
|
||||
self._sync.release_turn_consumer(thread_id, turn_id)
|
||||
|
||||
async def request(
|
||||
self,
|
||||
@@ -184,6 +181,9 @@ class AsyncAppServerClient:
|
||||
async def next_notification(self) -> Notification:
|
||||
return await self._call_sync(self._sync.next_notification)
|
||||
|
||||
async def next_turn_notification(self, thread_id: str, turn_id: str) -> Notification:
|
||||
return await self._call_sync(self._sync.next_turn_notification, thread_id, turn_id)
|
||||
|
||||
async def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
|
||||
return await self._call_sync(self._sync.wait_for_turn_completed, turn_id)
|
||||
|
||||
@@ -196,13 +196,12 @@ class AsyncAppServerClient:
|
||||
text: str,
|
||||
params: V2TurnStartParams | JsonObject | None = None,
|
||||
) -> AsyncIterator[AgentMessageDeltaNotification]:
|
||||
async with self._transport_lock:
|
||||
iterator = self._sync.stream_text(thread_id, text, params)
|
||||
while True:
|
||||
has_value, chunk = await asyncio.to_thread(
|
||||
self._next_from_iterator,
|
||||
iterator,
|
||||
)
|
||||
if not has_value:
|
||||
break
|
||||
yield chunk
|
||||
iterator = self._sync.stream_text(thread_id, text, params)
|
||||
while True:
|
||||
has_value, chunk = await asyncio.to_thread(
|
||||
self._next_from_iterator,
|
||||
iterator,
|
||||
)
|
||||
if not has_value:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
@@ -6,7 +6,7 @@ import subprocess
|
||||
import threading
|
||||
import uuid
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterable, Iterator, TypeVar
|
||||
|
||||
@@ -48,6 +48,58 @@ from .retry import retry_on_overload
|
||||
ModelT = TypeVar("ModelT", bound=BaseModel)
|
||||
ApprovalHandler = Callable[[str, JsonObject | None], JsonObject]
|
||||
RUNTIME_PKG_NAME = "codex-cli-bin"
|
||||
GLOBAL_NOTIFICATION_BACKLOG_LIMIT = 512
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _PendingRequest:
|
||||
event: threading.Event = field(default_factory=threading.Event)
|
||||
result: JsonValue | None = None
|
||||
error: BaseException | None = None
|
||||
|
||||
|
||||
class _BufferedNotificationStream:
|
||||
def __init__(self, *, maxlen: int | None = None) -> None:
|
||||
self._condition = threading.Condition()
|
||||
self._items: deque[Notification] = (
|
||||
deque(maxlen=maxlen) if maxlen is not None else deque()
|
||||
)
|
||||
self._closed = False
|
||||
self._error: BaseException | None = None
|
||||
|
||||
def push(self, notification: Notification) -> None:
|
||||
with self._condition:
|
||||
if self._closed:
|
||||
return
|
||||
self._items.append(notification)
|
||||
self._condition.notify_all()
|
||||
|
||||
def pop(self) -> Notification:
|
||||
with self._condition:
|
||||
while not self._items and not self._closed:
|
||||
self._condition.wait()
|
||||
|
||||
if self._items:
|
||||
return self._items.popleft()
|
||||
|
||||
if self._error is not None:
|
||||
raise self._error
|
||||
|
||||
raise TransportClosedError("notification stream is closed")
|
||||
|
||||
def close(self, error: BaseException | None = None) -> None:
|
||||
with self._condition:
|
||||
self._closed = True
|
||||
self._error = error
|
||||
self._condition.notify_all()
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
with self._condition:
|
||||
return self._closed
|
||||
|
||||
def is_drained(self) -> bool:
|
||||
with self._condition:
|
||||
return self._closed and not self._items
|
||||
|
||||
|
||||
def _params_dict(
|
||||
@@ -144,12 +196,21 @@ class AppServerClient:
|
||||
self.config = config or AppServerConfig()
|
||||
self._approval_handler = approval_handler or self._default_approval_handler
|
||||
self._proc: subprocess.Popen[str] | None = None
|
||||
self._lock = threading.Lock()
|
||||
self._turn_consumer_lock = threading.Lock()
|
||||
self._active_turn_consumer: str | None = None
|
||||
self._pending_notifications: deque[Notification] = deque()
|
||||
self._write_lock = threading.Lock()
|
||||
self._state_lock = threading.Lock()
|
||||
self._pending_notifications = _BufferedNotificationStream(
|
||||
maxlen=GLOBAL_NOTIFICATION_BACKLOG_LIMIT
|
||||
)
|
||||
self._pending_requests: dict[str, _PendingRequest] = {}
|
||||
self._turn_streams: dict[tuple[str, str], _BufferedNotificationStream] = {}
|
||||
self._turn_starting_by_thread_id: set[str] = set()
|
||||
self._active_turn_by_thread_id: dict[str, str] = {}
|
||||
self._active_turn_consumers: set[tuple[str, str]] = set()
|
||||
self._active_turn_stream_count = 0
|
||||
self._transport_error: BaseException | None = None
|
||||
self._stderr_lines: deque[str] = deque(maxlen=400)
|
||||
self._stderr_thread: threading.Thread | None = None
|
||||
self._reader_thread: threading.Thread | None = None
|
||||
|
||||
def __enter__(self) -> "AppServerClient":
|
||||
self.start()
|
||||
@@ -161,6 +222,7 @@ class AppServerClient:
|
||||
def start(self) -> None:
|
||||
if self._proc is not None:
|
||||
return
|
||||
self._reset_transport_state()
|
||||
|
||||
if self.config.launch_args_override is not None:
|
||||
args = list(self.config.launch_args_override)
|
||||
@@ -187,13 +249,14 @@ class AppServerClient:
|
||||
)
|
||||
|
||||
self._start_stderr_drain_thread()
|
||||
self._start_reader_thread()
|
||||
|
||||
def close(self) -> None:
|
||||
if self._proc is None:
|
||||
return
|
||||
proc = self._proc
|
||||
self._proc = None
|
||||
self._active_turn_consumer = None
|
||||
self._finish_transport(TransportClosedError("app-server closed"))
|
||||
|
||||
if proc.stdin:
|
||||
proc.stdin.close()
|
||||
@@ -205,6 +268,8 @@ class AppServerClient:
|
||||
|
||||
if self._stderr_thread and self._stderr_thread.is_alive():
|
||||
self._stderr_thread.join(timeout=0.5)
|
||||
if self._reader_thread and self._reader_thread.is_alive():
|
||||
self._reader_thread.join(timeout=0.5)
|
||||
|
||||
def initialize(self) -> InitializeResponse:
|
||||
result = self.request(
|
||||
@@ -238,67 +303,76 @@ class AppServerClient:
|
||||
|
||||
def _request_raw(self, method: str, params: JsonObject | None = None) -> JsonValue:
|
||||
request_id = str(uuid.uuid4())
|
||||
self._write_message({"id": request_id, "method": method, "params": params or {}})
|
||||
waiter = _PendingRequest()
|
||||
with self._state_lock:
|
||||
if self._transport_error is not None:
|
||||
raise self._transport_error
|
||||
self._pending_requests[request_id] = waiter
|
||||
|
||||
while True:
|
||||
msg = self._read_message()
|
||||
try:
|
||||
self._write_message({"id": request_id, "method": method, "params": params or {}})
|
||||
except BaseException:
|
||||
with self._state_lock:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
raise
|
||||
|
||||
if "method" in msg and "id" in msg:
|
||||
response = self._handle_server_request(msg)
|
||||
self._write_message({"id": msg["id"], "result": response})
|
||||
continue
|
||||
|
||||
if "method" in msg and "id" not in msg:
|
||||
self._pending_notifications.append(
|
||||
self._coerce_notification(msg["method"], msg.get("params"))
|
||||
)
|
||||
continue
|
||||
|
||||
if msg.get("id") != request_id:
|
||||
continue
|
||||
|
||||
if "error" in msg:
|
||||
err = msg["error"]
|
||||
if isinstance(err, dict):
|
||||
raise map_jsonrpc_error(
|
||||
int(err.get("code", -32000)),
|
||||
str(err.get("message", "unknown")),
|
||||
err.get("data"),
|
||||
)
|
||||
raise AppServerError("Malformed JSON-RPC error response")
|
||||
|
||||
return msg.get("result")
|
||||
waiter.event.wait()
|
||||
if waiter.error is not None:
|
||||
raise waiter.error
|
||||
return waiter.result
|
||||
|
||||
def notify(self, method: str, params: JsonObject | None = None) -> None:
|
||||
self._write_message({"method": method, "params": params or {}})
|
||||
|
||||
def next_notification(self) -> Notification:
|
||||
if self._pending_notifications:
|
||||
return self._pending_notifications.popleft()
|
||||
|
||||
while True:
|
||||
msg = self._read_message()
|
||||
if "method" in msg and "id" in msg:
|
||||
response = self._handle_server_request(msg)
|
||||
self._write_message({"id": msg["id"], "result": response})
|
||||
continue
|
||||
if "method" in msg and "id" not in msg:
|
||||
return self._coerce_notification(msg["method"], msg.get("params"))
|
||||
|
||||
def acquire_turn_consumer(self, turn_id: str) -> None:
|
||||
with self._turn_consumer_lock:
|
||||
if self._active_turn_consumer is not None:
|
||||
with self._state_lock:
|
||||
if self._active_turn_stream_count > 0:
|
||||
raise RuntimeError(
|
||||
"Concurrent turn consumers are not yet supported in the experimental SDK. "
|
||||
f"Client is already streaming turn {self._active_turn_consumer!r}; "
|
||||
f"cannot start turn {turn_id!r} until the active consumer finishes."
|
||||
"next_notification() is incompatible with active turn streaming on the same "
|
||||
"client. Consume notifications from TurnHandle.stream()/run() instead."
|
||||
)
|
||||
self._active_turn_consumer = turn_id
|
||||
return self._pending_notifications.pop()
|
||||
|
||||
def release_turn_consumer(self, turn_id: str) -> None:
|
||||
with self._turn_consumer_lock:
|
||||
if self._active_turn_consumer == turn_id:
|
||||
self._active_turn_consumer = None
|
||||
def acquire_turn_consumer(self, thread_id: str, turn_id: str) -> None:
|
||||
turn_key = (thread_id, turn_id)
|
||||
with self._state_lock:
|
||||
if turn_key in self._active_turn_consumers:
|
||||
raise RuntimeError(
|
||||
f"Turn {turn_id!r} is already being streamed on thread {thread_id!r}."
|
||||
)
|
||||
self._active_turn_consumers.add(turn_key)
|
||||
self._active_turn_stream_count += 1
|
||||
self._turn_streams.setdefault(turn_key, _BufferedNotificationStream())
|
||||
|
||||
def release_turn_consumer(self, thread_id: str, turn_id: str) -> None:
|
||||
turn_key = (thread_id, turn_id)
|
||||
with self._state_lock:
|
||||
if turn_key in self._active_turn_consumers:
|
||||
self._active_turn_consumers.remove(turn_key)
|
||||
self._active_turn_stream_count -= 1
|
||||
stream = self._turn_streams.get(turn_key)
|
||||
if stream is not None and stream.is_drained():
|
||||
self._turn_streams.pop(turn_key, None)
|
||||
|
||||
def next_turn_notification(self, thread_id: str, turn_id: str) -> Notification:
|
||||
turn_key = (thread_id, turn_id)
|
||||
with self._state_lock:
|
||||
stream = self._turn_streams.setdefault(turn_key, _BufferedNotificationStream())
|
||||
return stream.pop()
|
||||
|
||||
def assert_can_start_turn(self, thread_id: str) -> None:
|
||||
with self._state_lock:
|
||||
if thread_id in self._turn_starting_by_thread_id:
|
||||
raise RuntimeError(
|
||||
f"Thread {thread_id!r} is already starting a turn on this client."
|
||||
)
|
||||
active_turn_id = self._active_turn_by_thread_id.get(thread_id)
|
||||
if active_turn_id is not None:
|
||||
raise RuntimeError(
|
||||
f"Thread {thread_id!r} already has active turn {active_turn_id!r}. "
|
||||
"Use TurnHandle.steer() or TurnHandle.interrupt() instead of starting "
|
||||
"another turn on the same thread."
|
||||
)
|
||||
|
||||
def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse:
|
||||
return self.request("thread/start", _params_dict(params), response_model=ThreadStartResponse)
|
||||
@@ -355,12 +429,19 @@ class AppServerClient:
|
||||
input_items: list[JsonObject] | JsonObject | str,
|
||||
params: V2TurnStartParams | JsonObject | None = None,
|
||||
) -> TurnStartResponse:
|
||||
self._begin_turn_start(thread_id)
|
||||
payload = {
|
||||
**_params_dict(params),
|
||||
"threadId": thread_id,
|
||||
"input": self._normalize_input_items(input_items),
|
||||
}
|
||||
return self.request("turn/start", payload, response_model=TurnStartResponse)
|
||||
try:
|
||||
started = self.request("turn/start", payload, response_model=TurnStartResponse)
|
||||
except BaseException:
|
||||
self._cancel_turn_start(thread_id)
|
||||
raise
|
||||
self._finish_turn_start(thread_id, started.turn.id)
|
||||
return started
|
||||
|
||||
def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse:
|
||||
return self.request(
|
||||
@@ -436,21 +517,25 @@ class AppServerClient:
|
||||
) -> Iterator[AgentMessageDeltaNotification]:
|
||||
started = self.turn_start(thread_id, text, params=params)
|
||||
turn_id = started.turn.id
|
||||
while True:
|
||||
notification = self.next_notification()
|
||||
if (
|
||||
notification.method == "item/agentMessage/delta"
|
||||
and isinstance(notification.payload, AgentMessageDeltaNotification)
|
||||
and notification.payload.turn_id == turn_id
|
||||
):
|
||||
yield notification.payload
|
||||
continue
|
||||
if (
|
||||
notification.method == "turn/completed"
|
||||
and isinstance(notification.payload, TurnCompletedNotification)
|
||||
and notification.payload.turn.id == turn_id
|
||||
):
|
||||
break
|
||||
self.acquire_turn_consumer(thread_id, turn_id)
|
||||
try:
|
||||
while True:
|
||||
notification = self.next_turn_notification(thread_id, turn_id)
|
||||
if (
|
||||
notification.method == "item/agentMessage/delta"
|
||||
and isinstance(notification.payload, AgentMessageDeltaNotification)
|
||||
and notification.payload.turn_id == turn_id
|
||||
):
|
||||
yield notification.payload
|
||||
continue
|
||||
if (
|
||||
notification.method == "turn/completed"
|
||||
and isinstance(notification.payload, TurnCompletedNotification)
|
||||
and notification.payload.turn.id == turn_id
|
||||
):
|
||||
break
|
||||
finally:
|
||||
self.release_turn_consumer(thread_id, turn_id)
|
||||
|
||||
def _coerce_notification(self, method: str, params: object) -> Notification:
|
||||
params_dict = params if isinstance(params, dict) else {}
|
||||
@@ -512,7 +597,7 @@ class AppServerClient:
|
||||
def _write_message(self, payload: JsonObject) -> None:
|
||||
if self._proc is None or self._proc.stdin is None:
|
||||
raise TransportClosedError("app-server is not running")
|
||||
with self._lock:
|
||||
with self._write_lock:
|
||||
self._proc.stdin.write(json.dumps(payload) + "\n")
|
||||
self._proc.stdin.flush()
|
||||
|
||||
@@ -535,6 +620,162 @@ class AppServerClient:
|
||||
raise AppServerError(f"Invalid JSON-RPC payload: {message!r}")
|
||||
return message
|
||||
|
||||
def _reset_transport_state(self) -> None:
|
||||
self._pending_notifications = _BufferedNotificationStream(
|
||||
maxlen=GLOBAL_NOTIFICATION_BACKLOG_LIMIT
|
||||
)
|
||||
self._pending_requests = {}
|
||||
self._turn_streams = {}
|
||||
self._turn_starting_by_thread_id = set()
|
||||
self._active_turn_by_thread_id = {}
|
||||
self._active_turn_consumers = set()
|
||||
self._active_turn_stream_count = 0
|
||||
self._transport_error = None
|
||||
|
||||
def _start_reader_thread(self) -> None:
|
||||
def _reader() -> None:
|
||||
try:
|
||||
while True:
|
||||
msg = self._read_message()
|
||||
if "method" in msg and "id" in msg:
|
||||
self._start_server_request_worker(msg)
|
||||
continue
|
||||
if "method" in msg and "id" not in msg:
|
||||
method = msg["method"]
|
||||
if isinstance(method, str):
|
||||
self._dispatch_notification(
|
||||
self._coerce_notification(method, msg.get("params"))
|
||||
)
|
||||
continue
|
||||
self._handle_response_message(msg)
|
||||
except BaseException as exc: # noqa: BLE001
|
||||
self._finish_transport(exc)
|
||||
|
||||
self._reader_thread = threading.Thread(target=_reader, daemon=True)
|
||||
self._reader_thread.start()
|
||||
|
||||
def _start_server_request_worker(self, msg: dict[str, JsonValue]) -> None:
|
||||
def _resolve() -> None:
|
||||
try:
|
||||
response = self._handle_server_request(msg)
|
||||
self._write_message({"id": msg["id"], "result": response})
|
||||
except BaseException:
|
||||
return
|
||||
|
||||
threading.Thread(target=_resolve, daemon=True).start()
|
||||
|
||||
def _handle_response_message(self, msg: dict[str, JsonValue]) -> None:
|
||||
request_id = msg.get("id")
|
||||
if not isinstance(request_id, str):
|
||||
return
|
||||
|
||||
with self._state_lock:
|
||||
waiter = self._pending_requests.pop(request_id, None)
|
||||
|
||||
if waiter is None:
|
||||
return
|
||||
|
||||
if "error" in msg:
|
||||
err = msg["error"]
|
||||
if isinstance(err, dict):
|
||||
waiter.error = map_jsonrpc_error(
|
||||
int(err.get("code", -32000)),
|
||||
str(err.get("message", "unknown")),
|
||||
err.get("data"),
|
||||
)
|
||||
else:
|
||||
waiter.error = AppServerError("Malformed JSON-RPC error response")
|
||||
else:
|
||||
waiter.result = msg.get("result")
|
||||
waiter.event.set()
|
||||
|
||||
def _dispatch_notification(self, notification: Notification) -> None:
|
||||
self._pending_notifications.push(notification)
|
||||
|
||||
turn_key = self._turn_key_for_notification(notification)
|
||||
if turn_key is None:
|
||||
return
|
||||
|
||||
thread_id, turn_id = turn_key
|
||||
close_stream = False
|
||||
with self._state_lock:
|
||||
stream = self._turn_streams.setdefault(turn_key, _BufferedNotificationStream())
|
||||
if notification.method == "turn/started":
|
||||
self._turn_starting_by_thread_id.discard(thread_id)
|
||||
self._active_turn_by_thread_id[thread_id] = turn_id
|
||||
elif notification.method == "turn/completed":
|
||||
self._turn_starting_by_thread_id.discard(thread_id)
|
||||
if self._active_turn_by_thread_id.get(thread_id) == turn_id:
|
||||
self._active_turn_by_thread_id.pop(thread_id, None)
|
||||
close_stream = True
|
||||
|
||||
stream.push(notification)
|
||||
if close_stream:
|
||||
stream.close()
|
||||
|
||||
def _turn_key_for_notification(self, notification: Notification) -> tuple[str, str] | None:
|
||||
payload = notification.payload
|
||||
thread_id = getattr(payload, "thread_id", None)
|
||||
turn_id = getattr(payload, "turn_id", None)
|
||||
if isinstance(thread_id, str) and isinstance(turn_id, str):
|
||||
return thread_id, turn_id
|
||||
|
||||
turn = getattr(payload, "turn", None)
|
||||
nested_turn_id = getattr(turn, "id", None)
|
||||
if isinstance(thread_id, str) and isinstance(nested_turn_id, str):
|
||||
return thread_id, nested_turn_id
|
||||
|
||||
return None
|
||||
|
||||
def _begin_turn_start(self, thread_id: str) -> None:
|
||||
with self._state_lock:
|
||||
active_turn_id = self._active_turn_by_thread_id.get(thread_id)
|
||||
if active_turn_id is not None:
|
||||
raise RuntimeError(
|
||||
f"Thread {thread_id!r} already has active turn {active_turn_id!r}. "
|
||||
"Use TurnHandle.steer() or TurnHandle.interrupt() instead of starting "
|
||||
"another turn on the same thread."
|
||||
)
|
||||
if thread_id in self._turn_starting_by_thread_id:
|
||||
raise RuntimeError(
|
||||
f"Thread {thread_id!r} is already starting a turn on this client."
|
||||
)
|
||||
self._turn_starting_by_thread_id.add(thread_id)
|
||||
|
||||
def _cancel_turn_start(self, thread_id: str) -> None:
|
||||
with self._state_lock:
|
||||
self._turn_starting_by_thread_id.discard(thread_id)
|
||||
|
||||
def _finish_turn_start(self, thread_id: str, turn_id: str) -> None:
|
||||
turn_key = (thread_id, turn_id)
|
||||
with self._state_lock:
|
||||
self._turn_starting_by_thread_id.discard(thread_id)
|
||||
stream = self._turn_streams.setdefault(turn_key, _BufferedNotificationStream())
|
||||
if not stream.is_closed():
|
||||
self._active_turn_by_thread_id[thread_id] = turn_id
|
||||
|
||||
def _finish_transport(self, error: BaseException) -> None:
|
||||
with self._state_lock:
|
||||
if self._transport_error is not None:
|
||||
return
|
||||
self._transport_error = error
|
||||
pending_requests = list(self._pending_requests.values())
|
||||
self._pending_requests.clear()
|
||||
turn_streams = list(self._turn_streams.values())
|
||||
self._turn_streams.clear()
|
||||
self._turn_starting_by_thread_id.clear()
|
||||
self._active_turn_by_thread_id.clear()
|
||||
self._active_turn_consumers.clear()
|
||||
self._active_turn_stream_count = 0
|
||||
|
||||
for waiter in pending_requests:
|
||||
waiter.error = error
|
||||
waiter.event.set()
|
||||
|
||||
self._pending_notifications.close(error)
|
||||
for stream in turn_streams:
|
||||
stream.close(error)
|
||||
|
||||
|
||||
def default_codex_home() -> str:
|
||||
return str(Path.home() / ".codex")
|
||||
|
||||
@@ -6,7 +6,7 @@ import time
|
||||
from codex_app_server.async_client import AsyncAppServerClient
|
||||
|
||||
|
||||
def test_async_client_serializes_transport_calls() -> None:
|
||||
def test_async_client_allows_parallel_transport_calls() -> None:
|
||||
async def scenario() -> int:
|
||||
client = AsyncAppServerClient()
|
||||
active = 0
|
||||
@@ -24,10 +24,10 @@ def test_async_client_serializes_transport_calls() -> None:
|
||||
await asyncio.gather(client.model_list(), client.model_list())
|
||||
return max_active
|
||||
|
||||
assert asyncio.run(scenario()) == 1
|
||||
assert asyncio.run(scenario()) == 2
|
||||
|
||||
|
||||
def test_async_stream_text_is_incremental_and_blocks_parallel_calls() -> None:
|
||||
def test_async_stream_text_is_incremental_and_allows_parallel_calls() -> None:
|
||||
async def scenario() -> tuple[str, list[str], bool]:
|
||||
client = AsyncAppServerClient()
|
||||
|
||||
@@ -46,19 +46,19 @@ def test_async_stream_text_is_incremental_and_blocks_parallel_calls() -> None:
|
||||
stream = client.stream_text("thread-1", "hello")
|
||||
first = await anext(stream)
|
||||
|
||||
blocked_before_stream_done = False
|
||||
completed_before_stream_done = False
|
||||
competing_call = asyncio.create_task(client.model_list())
|
||||
await asyncio.sleep(0.01)
|
||||
blocked_before_stream_done = not competing_call.done()
|
||||
completed_before_stream_done = competing_call.done()
|
||||
|
||||
remaining: list[str] = []
|
||||
async for item in stream:
|
||||
remaining.append(item)
|
||||
|
||||
await competing_call
|
||||
return first, remaining, blocked_before_stream_done
|
||||
return first, remaining, completed_before_stream_done
|
||||
|
||||
first, remaining, blocked = asyncio.run(scenario())
|
||||
first, remaining, completed = asyncio.run(scenario())
|
||||
assert first == "first"
|
||||
assert remaining == ["second", "third"]
|
||||
assert blocked
|
||||
assert completed
|
||||
|
||||
@@ -4,7 +4,12 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from codex_app_server.client import AppServerClient, _params_dict
|
||||
from codex_app_server.generated.v2_all import ThreadListParams, ThreadTokenUsageUpdatedNotification
|
||||
from codex_app_server.generated.v2_all import (
|
||||
AgentMessageDeltaNotification,
|
||||
ThreadListParams,
|
||||
ThreadTokenUsageUpdatedNotification,
|
||||
TurnCompletedNotification,
|
||||
)
|
||||
from codex_app_server.models import UnknownNotification
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
@@ -93,3 +98,58 @@ def test_invalid_notification_payload_falls_back_to_unknown() -> None:
|
||||
|
||||
assert event.method == "thread/tokenUsage/updated"
|
||||
assert isinstance(event.payload, UnknownNotification)
|
||||
|
||||
|
||||
def test_client_routes_interleaved_turn_notifications_to_matching_turn_queues() -> None:
|
||||
client = AppServerClient()
|
||||
first = client._coerce_notification(
|
||||
"item/agentMessage/delta",
|
||||
{
|
||||
"delta": "first",
|
||||
"itemId": "item-1",
|
||||
"threadId": "thread-1",
|
||||
"turnId": "turn-1",
|
||||
},
|
||||
)
|
||||
second = client._coerce_notification(
|
||||
"item/agentMessage/delta",
|
||||
{
|
||||
"delta": "second",
|
||||
"itemId": "item-2",
|
||||
"threadId": "thread-2",
|
||||
"turnId": "turn-2",
|
||||
},
|
||||
)
|
||||
|
||||
client._dispatch_notification(first) # type: ignore[attr-defined]
|
||||
client._dispatch_notification(second) # type: ignore[attr-defined]
|
||||
|
||||
first_turn = client.next_turn_notification("thread-1", "turn-1")
|
||||
second_turn = client.next_turn_notification("thread-2", "turn-2")
|
||||
|
||||
assert isinstance(first_turn.payload, AgentMessageDeltaNotification)
|
||||
assert first_turn.payload.delta == "first"
|
||||
assert isinstance(second_turn.payload, AgentMessageDeltaNotification)
|
||||
assert second_turn.payload.delta == "second"
|
||||
|
||||
|
||||
def test_next_notification_still_returns_turn_notifications_without_active_streams() -> None:
|
||||
client = AppServerClient()
|
||||
completed = client._coerce_notification(
|
||||
"turn/completed",
|
||||
{
|
||||
"threadId": "thread-1",
|
||||
"turn": {
|
||||
"id": "turn-1",
|
||||
"items": [],
|
||||
"status": "completed",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
client._dispatch_notification(completed) # type: ignore[attr-defined]
|
||||
event = client.next_notification()
|
||||
|
||||
assert event.method == "turn/completed"
|
||||
assert isinstance(event.payload, TurnCompletedNotification)
|
||||
assert event.payload.turn.id == "turn-1"
|
||||
|
||||
@@ -133,6 +133,15 @@ def _token_usage_notification(
|
||||
)
|
||||
|
||||
|
||||
def _turn_notification_source(
|
||||
notifications_by_turn: dict[tuple[str, str], deque[Notification]],
|
||||
):
|
||||
def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
|
||||
return notifications_by_turn[(thread_id, turn_id)].popleft()
|
||||
|
||||
return fake_next_turn_notification
|
||||
|
||||
|
||||
def test_codex_init_failure_closes_client(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
closed: list[bool] = []
|
||||
|
||||
@@ -226,66 +235,132 @@ def test_async_codex_initializes_only_once_under_concurrency() -> None:
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_turn_stream_rejects_second_active_consumer() -> None:
|
||||
def test_turn_stream_allows_different_active_threads() -> None:
|
||||
client = AppServerClient()
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
_delta_notification(turn_id="turn-1"),
|
||||
_completed_notification(turn_id="turn-1"),
|
||||
]
|
||||
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||
{
|
||||
("thread-1", "turn-1"): deque(
|
||||
[
|
||||
_delta_notification(thread_id="thread-1", turn_id="turn-1"),
|
||||
_completed_notification(thread_id="thread-1", turn_id="turn-1"),
|
||||
]
|
||||
),
|
||||
("thread-2", "turn-2"): deque(
|
||||
[
|
||||
_delta_notification(thread_id="thread-2", turn_id="turn-2"),
|
||||
_completed_notification(thread_id="thread-2", turn_id="turn-2"),
|
||||
]
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
first_stream = TurnHandle(client, "thread-1", "turn-1").stream()
|
||||
second_stream = TurnHandle(client, "thread-2", "turn-2").stream()
|
||||
assert next(first_stream).method == "item/agentMessage/delta"
|
||||
assert next(second_stream).method == "item/agentMessage/delta"
|
||||
|
||||
first_stream.close()
|
||||
second_stream.close()
|
||||
|
||||
|
||||
def test_turn_stream_blocks_next_notification_while_active() -> None:
|
||||
client = AppServerClient()
|
||||
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||
{
|
||||
("thread-1", "turn-1"): deque(
|
||||
[
|
||||
_delta_notification(thread_id="thread-1", turn_id="turn-1"),
|
||||
_completed_notification(thread_id="thread-1", turn_id="turn-1"),
|
||||
]
|
||||
),
|
||||
}
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
|
||||
first_stream = TurnHandle(client, "thread-1", "turn-1").stream()
|
||||
assert next(first_stream).method == "item/agentMessage/delta"
|
||||
|
||||
second_stream = TurnHandle(client, "thread-1", "turn-2").stream()
|
||||
with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"):
|
||||
next(second_stream)
|
||||
with pytest.raises(RuntimeError, match="next_notification\\(\\) is incompatible"):
|
||||
client.next_notification()
|
||||
|
||||
first_stream.close()
|
||||
|
||||
|
||||
def test_async_turn_stream_rejects_second_active_consumer() -> None:
|
||||
def test_turn_start_rejects_same_thread_overlap_and_allows_after_completion() -> None:
|
||||
client = AppServerClient()
|
||||
turn_ids = iter(["turn-1", "turn-2"])
|
||||
|
||||
def fake_request(method: str, params, *, response_model): # type: ignore[no-untyped-def]
|
||||
assert method == "turn/start"
|
||||
return response_model.model_validate(
|
||||
{
|
||||
"turn": {
|
||||
"id": next(turn_ids),
|
||||
"items": [],
|
||||
"status": TurnStatus.in_progress.value,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
client.request = fake_request # type: ignore[method-assign]
|
||||
|
||||
first = client.turn_start("thread-1", "first turn")
|
||||
assert first.turn.id == "turn-1"
|
||||
|
||||
with pytest.raises(RuntimeError, match="already has active turn"):
|
||||
client.turn_start("thread-1", "second turn")
|
||||
|
||||
client._dispatch_notification( # type: ignore[attr-defined]
|
||||
_completed_notification(thread_id="thread-1", turn_id="turn-1")
|
||||
)
|
||||
|
||||
second = client.turn_start("thread-1", "second turn")
|
||||
assert second.turn.id == "turn-2"
|
||||
|
||||
|
||||
def test_async_turn_stream_allows_different_active_threads() -> None:
|
||||
async def scenario() -> None:
|
||||
codex = AsyncCodex()
|
||||
|
||||
async def fake_ensure_initialized() -> None:
|
||||
return None
|
||||
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
_delta_notification(turn_id="turn-1"),
|
||||
_completed_notification(turn_id="turn-1"),
|
||||
]
|
||||
)
|
||||
notifications_by_turn = {
|
||||
("thread-1", "turn-1"): deque(
|
||||
[
|
||||
_delta_notification(thread_id="thread-1", turn_id="turn-1"),
|
||||
_completed_notification(thread_id="thread-1", turn_id="turn-1"),
|
||||
]
|
||||
),
|
||||
("thread-2", "turn-2"): deque(
|
||||
[
|
||||
_delta_notification(thread_id="thread-2", turn_id="turn-2"),
|
||||
_completed_notification(thread_id="thread-2", turn_id="turn-2"),
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
async def fake_next_notification() -> Notification:
|
||||
return notifications.popleft()
|
||||
async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
|
||||
return notifications_by_turn[(thread_id, turn_id)].popleft()
|
||||
|
||||
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
|
||||
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
|
||||
codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
|
||||
|
||||
first_stream = AsyncTurnHandle(codex, "thread-1", "turn-1").stream()
|
||||
second_stream = AsyncTurnHandle(codex, "thread-2", "turn-2").stream()
|
||||
assert (await anext(first_stream)).method == "item/agentMessage/delta"
|
||||
|
||||
second_stream = AsyncTurnHandle(codex, "thread-1", "turn-2").stream()
|
||||
with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"):
|
||||
await anext(second_stream)
|
||||
assert (await anext(second_stream)).method == "item/agentMessage/delta"
|
||||
|
||||
await first_stream.aclose()
|
||||
await second_stream.aclose()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_turn_run_returns_completed_turn_payload() -> None:
|
||||
client = AppServerClient()
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
_completed_notification(),
|
||||
]
|
||||
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||
{("thread-1", "turn-1"): deque([_completed_notification()])}
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
|
||||
result = TurnHandle(client, "thread-1", "turn-1").run()
|
||||
|
||||
@@ -298,14 +373,17 @@ def test_thread_run_accepts_string_input_and_returns_run_result() -> None:
|
||||
client = AppServerClient()
|
||||
item_notification = _item_completed_notification(text="Hello.")
|
||||
usage_notification = _token_usage_notification()
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
item_notification,
|
||||
usage_notification,
|
||||
_completed_notification(),
|
||||
]
|
||||
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||
{
|
||||
("thread-1", "turn-1"): deque(
|
||||
[
|
||||
item_notification,
|
||||
usage_notification,
|
||||
_completed_notification(),
|
||||
]
|
||||
),
|
||||
}
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
seen: dict[str, object] = {}
|
||||
|
||||
def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202
|
||||
@@ -331,14 +409,17 @@ def test_thread_run_uses_last_completed_assistant_message_as_final_response() ->
|
||||
client = AppServerClient()
|
||||
first_item_notification = _item_completed_notification(text="First message")
|
||||
second_item_notification = _item_completed_notification(text="Second message")
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
first_item_notification,
|
||||
second_item_notification,
|
||||
_completed_notification(),
|
||||
]
|
||||
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||
{
|
||||
("thread-1", "turn-1"): deque(
|
||||
[
|
||||
first_item_notification,
|
||||
second_item_notification,
|
||||
_completed_notification(),
|
||||
]
|
||||
),
|
||||
}
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
||||
turn=SimpleNamespace(id="turn-1")
|
||||
)
|
||||
@@ -356,14 +437,17 @@ def test_thread_run_preserves_empty_last_assistant_message() -> None:
|
||||
client = AppServerClient()
|
||||
first_item_notification = _item_completed_notification(text="First message")
|
||||
second_item_notification = _item_completed_notification(text="")
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
first_item_notification,
|
||||
second_item_notification,
|
||||
_completed_notification(),
|
||||
]
|
||||
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||
{
|
||||
("thread-1", "turn-1"): deque(
|
||||
[
|
||||
first_item_notification,
|
||||
second_item_notification,
|
||||
_completed_notification(),
|
||||
]
|
||||
),
|
||||
}
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
||||
turn=SimpleNamespace(id="turn-1")
|
||||
)
|
||||
@@ -387,14 +471,17 @@ def test_thread_run_prefers_explicit_final_answer_over_later_commentary() -> Non
|
||||
text="Commentary",
|
||||
phase=MessagePhase.commentary,
|
||||
)
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
final_answer_notification,
|
||||
commentary_notification,
|
||||
_completed_notification(),
|
||||
]
|
||||
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||
{
|
||||
("thread-1", "turn-1"): deque(
|
||||
[
|
||||
final_answer_notification,
|
||||
commentary_notification,
|
||||
_completed_notification(),
|
||||
]
|
||||
),
|
||||
}
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
||||
turn=SimpleNamespace(id="turn-1")
|
||||
)
|
||||
@@ -414,13 +501,16 @@ def test_thread_run_returns_none_when_only_commentary_messages_complete() -> Non
|
||||
text="Commentary",
|
||||
phase=MessagePhase.commentary,
|
||||
)
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
commentary_notification,
|
||||
_completed_notification(),
|
||||
]
|
||||
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||
{
|
||||
("thread-1", "turn-1"): deque(
|
||||
[
|
||||
commentary_notification,
|
||||
_completed_notification(),
|
||||
]
|
||||
),
|
||||
}
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
||||
turn=SimpleNamespace(id="turn-1")
|
||||
)
|
||||
@@ -433,12 +523,13 @@ def test_thread_run_returns_none_when_only_commentary_messages_complete() -> Non
|
||||
|
||||
def test_thread_run_raises_on_failed_turn() -> None:
|
||||
client = AppServerClient()
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
_completed_notification(status="failed", error_message="boom"),
|
||||
]
|
||||
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||
{
|
||||
("thread-1", "turn-1"): deque(
|
||||
[_completed_notification(status="failed", error_message="boom")]
|
||||
),
|
||||
}
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
||||
turn=SimpleNamespace(id="turn-1")
|
||||
)
|
||||
@@ -471,12 +562,13 @@ def test_async_thread_run_accepts_string_input_and_returns_run_result() -> None:
|
||||
seen["params"] = params
|
||||
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
|
||||
|
||||
async def fake_next_notification() -> Notification:
|
||||
async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
|
||||
assert (thread_id, turn_id) == ("thread-1", "turn-1")
|
||||
return notifications.popleft()
|
||||
|
||||
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
|
||||
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
|
||||
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
|
||||
codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
|
||||
|
||||
result = await AsyncThread(codex, "thread-1").run("hello")
|
||||
|
||||
@@ -511,12 +603,13 @@ def test_async_thread_run_uses_last_completed_assistant_message_as_final_respons
|
||||
async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001
|
||||
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
|
||||
|
||||
async def fake_next_notification() -> Notification:
|
||||
async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
|
||||
assert (thread_id, turn_id) == ("thread-1", "turn-1")
|
||||
return notifications.popleft()
|
||||
|
||||
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
|
||||
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
|
||||
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
|
||||
codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
|
||||
|
||||
result = await AsyncThread(codex, "thread-1").run("hello")
|
||||
|
||||
@@ -550,12 +643,13 @@ def test_async_thread_run_returns_none_when_only_commentary_messages_complete()
|
||||
async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001
|
||||
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
|
||||
|
||||
async def fake_next_notification() -> Notification:
|
||||
async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
|
||||
assert (thread_id, turn_id) == ("thread-1", "turn-1")
|
||||
return notifications.popleft()
|
||||
|
||||
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
|
||||
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
|
||||
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
|
||||
codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
|
||||
|
||||
result = await AsyncThread(codex, "thread-1").run("hello")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user