mirror of
https://github.com/openai/codex.git
synced 2026-05-04 05:11:37 +03:00
Support concurrent Python SDK turns across threads
This commit is contained in:
@@ -6,7 +6,7 @@ import time
|
||||
from codex_app_server.async_client import AsyncAppServerClient
|
||||
|
||||
|
||||
def test_async_client_serializes_transport_calls() -> None:
|
||||
def test_async_client_allows_parallel_transport_calls() -> None:
|
||||
async def scenario() -> int:
|
||||
client = AsyncAppServerClient()
|
||||
active = 0
|
||||
@@ -24,10 +24,10 @@ def test_async_client_serializes_transport_calls() -> None:
|
||||
await asyncio.gather(client.model_list(), client.model_list())
|
||||
return max_active
|
||||
|
||||
assert asyncio.run(scenario()) == 1
|
||||
assert asyncio.run(scenario()) == 2
|
||||
|
||||
|
||||
def test_async_stream_text_is_incremental_and_blocks_parallel_calls() -> None:
|
||||
def test_async_stream_text_is_incremental_and_allows_parallel_calls() -> None:
|
||||
async def scenario() -> tuple[str, list[str], bool]:
|
||||
client = AsyncAppServerClient()
|
||||
|
||||
@@ -46,19 +46,19 @@ def test_async_stream_text_is_incremental_and_blocks_parallel_calls() -> None:
|
||||
stream = client.stream_text("thread-1", "hello")
|
||||
first = await anext(stream)
|
||||
|
||||
blocked_before_stream_done = False
|
||||
completed_before_stream_done = False
|
||||
competing_call = asyncio.create_task(client.model_list())
|
||||
await asyncio.sleep(0.01)
|
||||
blocked_before_stream_done = not competing_call.done()
|
||||
completed_before_stream_done = competing_call.done()
|
||||
|
||||
remaining: list[str] = []
|
||||
async for item in stream:
|
||||
remaining.append(item)
|
||||
|
||||
await competing_call
|
||||
return first, remaining, blocked_before_stream_done
|
||||
return first, remaining, completed_before_stream_done
|
||||
|
||||
first, remaining, blocked = asyncio.run(scenario())
|
||||
first, remaining, completed = asyncio.run(scenario())
|
||||
assert first == "first"
|
||||
assert remaining == ["second", "third"]
|
||||
assert blocked
|
||||
assert completed
|
||||
|
||||
@@ -4,7 +4,12 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from codex_app_server.client import AppServerClient, _params_dict
|
||||
from codex_app_server.generated.v2_all import ThreadListParams, ThreadTokenUsageUpdatedNotification
|
||||
from codex_app_server.generated.v2_all import (
|
||||
AgentMessageDeltaNotification,
|
||||
ThreadListParams,
|
||||
ThreadTokenUsageUpdatedNotification,
|
||||
TurnCompletedNotification,
|
||||
)
|
||||
from codex_app_server.models import UnknownNotification
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
@@ -93,3 +98,58 @@ def test_invalid_notification_payload_falls_back_to_unknown() -> None:
|
||||
|
||||
assert event.method == "thread/tokenUsage/updated"
|
||||
assert isinstance(event.payload, UnknownNotification)
|
||||
|
||||
|
||||
def test_client_routes_interleaved_turn_notifications_to_matching_turn_queues() -> None:
|
||||
client = AppServerClient()
|
||||
first = client._coerce_notification(
|
||||
"item/agentMessage/delta",
|
||||
{
|
||||
"delta": "first",
|
||||
"itemId": "item-1",
|
||||
"threadId": "thread-1",
|
||||
"turnId": "turn-1",
|
||||
},
|
||||
)
|
||||
second = client._coerce_notification(
|
||||
"item/agentMessage/delta",
|
||||
{
|
||||
"delta": "second",
|
||||
"itemId": "item-2",
|
||||
"threadId": "thread-2",
|
||||
"turnId": "turn-2",
|
||||
},
|
||||
)
|
||||
|
||||
client._dispatch_notification(first) # type: ignore[attr-defined]
|
||||
client._dispatch_notification(second) # type: ignore[attr-defined]
|
||||
|
||||
first_turn = client.next_turn_notification("thread-1", "turn-1")
|
||||
second_turn = client.next_turn_notification("thread-2", "turn-2")
|
||||
|
||||
assert isinstance(first_turn.payload, AgentMessageDeltaNotification)
|
||||
assert first_turn.payload.delta == "first"
|
||||
assert isinstance(second_turn.payload, AgentMessageDeltaNotification)
|
||||
assert second_turn.payload.delta == "second"
|
||||
|
||||
|
||||
def test_next_notification_still_returns_turn_notifications_without_active_streams() -> None:
|
||||
client = AppServerClient()
|
||||
completed = client._coerce_notification(
|
||||
"turn/completed",
|
||||
{
|
||||
"threadId": "thread-1",
|
||||
"turn": {
|
||||
"id": "turn-1",
|
||||
"items": [],
|
||||
"status": "completed",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
client._dispatch_notification(completed) # type: ignore[attr-defined]
|
||||
event = client.next_notification()
|
||||
|
||||
assert event.method == "turn/completed"
|
||||
assert isinstance(event.payload, TurnCompletedNotification)
|
||||
assert event.payload.turn.id == "turn-1"
|
||||
|
||||
@@ -133,6 +133,15 @@ def _token_usage_notification(
|
||||
)
|
||||
|
||||
|
||||
def _turn_notification_source(
|
||||
notifications_by_turn: dict[tuple[str, str], deque[Notification]],
|
||||
):
|
||||
def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
|
||||
return notifications_by_turn[(thread_id, turn_id)].popleft()
|
||||
|
||||
return fake_next_turn_notification
|
||||
|
||||
|
||||
def test_codex_init_failure_closes_client(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
closed: list[bool] = []
|
||||
|
||||
@@ -226,66 +235,132 @@ def test_async_codex_initializes_only_once_under_concurrency() -> None:
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_turn_stream_rejects_second_active_consumer() -> None:
|
||||
def test_turn_stream_allows_different_active_threads() -> None:
|
||||
client = AppServerClient()
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
_delta_notification(turn_id="turn-1"),
|
||||
_completed_notification(turn_id="turn-1"),
|
||||
]
|
||||
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||
{
|
||||
("thread-1", "turn-1"): deque(
|
||||
[
|
||||
_delta_notification(thread_id="thread-1", turn_id="turn-1"),
|
||||
_completed_notification(thread_id="thread-1", turn_id="turn-1"),
|
||||
]
|
||||
),
|
||||
("thread-2", "turn-2"): deque(
|
||||
[
|
||||
_delta_notification(thread_id="thread-2", turn_id="turn-2"),
|
||||
_completed_notification(thread_id="thread-2", turn_id="turn-2"),
|
||||
]
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
first_stream = TurnHandle(client, "thread-1", "turn-1").stream()
|
||||
second_stream = TurnHandle(client, "thread-2", "turn-2").stream()
|
||||
assert next(first_stream).method == "item/agentMessage/delta"
|
||||
assert next(second_stream).method == "item/agentMessage/delta"
|
||||
|
||||
first_stream.close()
|
||||
second_stream.close()
|
||||
|
||||
|
||||
def test_turn_stream_blocks_next_notification_while_active() -> None:
|
||||
client = AppServerClient()
|
||||
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||
{
|
||||
("thread-1", "turn-1"): deque(
|
||||
[
|
||||
_delta_notification(thread_id="thread-1", turn_id="turn-1"),
|
||||
_completed_notification(thread_id="thread-1", turn_id="turn-1"),
|
||||
]
|
||||
),
|
||||
}
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
|
||||
first_stream = TurnHandle(client, "thread-1", "turn-1").stream()
|
||||
assert next(first_stream).method == "item/agentMessage/delta"
|
||||
|
||||
second_stream = TurnHandle(client, "thread-1", "turn-2").stream()
|
||||
with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"):
|
||||
next(second_stream)
|
||||
with pytest.raises(RuntimeError, match="next_notification\\(\\) is incompatible"):
|
||||
client.next_notification()
|
||||
|
||||
first_stream.close()
|
||||
|
||||
|
||||
def test_async_turn_stream_rejects_second_active_consumer() -> None:
|
||||
def test_turn_start_rejects_same_thread_overlap_and_allows_after_completion() -> None:
|
||||
client = AppServerClient()
|
||||
turn_ids = iter(["turn-1", "turn-2"])
|
||||
|
||||
def fake_request(method: str, params, *, response_model): # type: ignore[no-untyped-def]
|
||||
assert method == "turn/start"
|
||||
return response_model.model_validate(
|
||||
{
|
||||
"turn": {
|
||||
"id": next(turn_ids),
|
||||
"items": [],
|
||||
"status": TurnStatus.in_progress.value,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
client.request = fake_request # type: ignore[method-assign]
|
||||
|
||||
first = client.turn_start("thread-1", "first turn")
|
||||
assert first.turn.id == "turn-1"
|
||||
|
||||
with pytest.raises(RuntimeError, match="already has active turn"):
|
||||
client.turn_start("thread-1", "second turn")
|
||||
|
||||
client._dispatch_notification( # type: ignore[attr-defined]
|
||||
_completed_notification(thread_id="thread-1", turn_id="turn-1")
|
||||
)
|
||||
|
||||
second = client.turn_start("thread-1", "second turn")
|
||||
assert second.turn.id == "turn-2"
|
||||
|
||||
|
||||
def test_async_turn_stream_allows_different_active_threads() -> None:
|
||||
async def scenario() -> None:
|
||||
codex = AsyncCodex()
|
||||
|
||||
async def fake_ensure_initialized() -> None:
|
||||
return None
|
||||
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
_delta_notification(turn_id="turn-1"),
|
||||
_completed_notification(turn_id="turn-1"),
|
||||
]
|
||||
)
|
||||
notifications_by_turn = {
|
||||
("thread-1", "turn-1"): deque(
|
||||
[
|
||||
_delta_notification(thread_id="thread-1", turn_id="turn-1"),
|
||||
_completed_notification(thread_id="thread-1", turn_id="turn-1"),
|
||||
]
|
||||
),
|
||||
("thread-2", "turn-2"): deque(
|
||||
[
|
||||
_delta_notification(thread_id="thread-2", turn_id="turn-2"),
|
||||
_completed_notification(thread_id="thread-2", turn_id="turn-2"),
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
async def fake_next_notification() -> Notification:
|
||||
return notifications.popleft()
|
||||
async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
|
||||
return notifications_by_turn[(thread_id, turn_id)].popleft()
|
||||
|
||||
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
|
||||
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
|
||||
codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
|
||||
|
||||
first_stream = AsyncTurnHandle(codex, "thread-1", "turn-1").stream()
|
||||
second_stream = AsyncTurnHandle(codex, "thread-2", "turn-2").stream()
|
||||
assert (await anext(first_stream)).method == "item/agentMessage/delta"
|
||||
|
||||
second_stream = AsyncTurnHandle(codex, "thread-1", "turn-2").stream()
|
||||
with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"):
|
||||
await anext(second_stream)
|
||||
assert (await anext(second_stream)).method == "item/agentMessage/delta"
|
||||
|
||||
await first_stream.aclose()
|
||||
await second_stream.aclose()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_turn_run_returns_completed_turn_payload() -> None:
|
||||
client = AppServerClient()
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
_completed_notification(),
|
||||
]
|
||||
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||
{("thread-1", "turn-1"): deque([_completed_notification()])}
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
|
||||
result = TurnHandle(client, "thread-1", "turn-1").run()
|
||||
|
||||
@@ -298,14 +373,17 @@ def test_thread_run_accepts_string_input_and_returns_run_result() -> None:
|
||||
client = AppServerClient()
|
||||
item_notification = _item_completed_notification(text="Hello.")
|
||||
usage_notification = _token_usage_notification()
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
item_notification,
|
||||
usage_notification,
|
||||
_completed_notification(),
|
||||
]
|
||||
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||
{
|
||||
("thread-1", "turn-1"): deque(
|
||||
[
|
||||
item_notification,
|
||||
usage_notification,
|
||||
_completed_notification(),
|
||||
]
|
||||
),
|
||||
}
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
seen: dict[str, object] = {}
|
||||
|
||||
def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202
|
||||
@@ -331,14 +409,17 @@ def test_thread_run_uses_last_completed_assistant_message_as_final_response() ->
|
||||
client = AppServerClient()
|
||||
first_item_notification = _item_completed_notification(text="First message")
|
||||
second_item_notification = _item_completed_notification(text="Second message")
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
first_item_notification,
|
||||
second_item_notification,
|
||||
_completed_notification(),
|
||||
]
|
||||
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||
{
|
||||
("thread-1", "turn-1"): deque(
|
||||
[
|
||||
first_item_notification,
|
||||
second_item_notification,
|
||||
_completed_notification(),
|
||||
]
|
||||
),
|
||||
}
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
||||
turn=SimpleNamespace(id="turn-1")
|
||||
)
|
||||
@@ -356,14 +437,17 @@ def test_thread_run_preserves_empty_last_assistant_message() -> None:
|
||||
client = AppServerClient()
|
||||
first_item_notification = _item_completed_notification(text="First message")
|
||||
second_item_notification = _item_completed_notification(text="")
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
first_item_notification,
|
||||
second_item_notification,
|
||||
_completed_notification(),
|
||||
]
|
||||
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||
{
|
||||
("thread-1", "turn-1"): deque(
|
||||
[
|
||||
first_item_notification,
|
||||
second_item_notification,
|
||||
_completed_notification(),
|
||||
]
|
||||
),
|
||||
}
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
||||
turn=SimpleNamespace(id="turn-1")
|
||||
)
|
||||
@@ -387,14 +471,17 @@ def test_thread_run_prefers_explicit_final_answer_over_later_commentary() -> Non
|
||||
text="Commentary",
|
||||
phase=MessagePhase.commentary,
|
||||
)
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
final_answer_notification,
|
||||
commentary_notification,
|
||||
_completed_notification(),
|
||||
]
|
||||
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||
{
|
||||
("thread-1", "turn-1"): deque(
|
||||
[
|
||||
final_answer_notification,
|
||||
commentary_notification,
|
||||
_completed_notification(),
|
||||
]
|
||||
),
|
||||
}
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
||||
turn=SimpleNamespace(id="turn-1")
|
||||
)
|
||||
@@ -414,13 +501,16 @@ def test_thread_run_returns_none_when_only_commentary_messages_complete() -> Non
|
||||
text="Commentary",
|
||||
phase=MessagePhase.commentary,
|
||||
)
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
commentary_notification,
|
||||
_completed_notification(),
|
||||
]
|
||||
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||
{
|
||||
("thread-1", "turn-1"): deque(
|
||||
[
|
||||
commentary_notification,
|
||||
_completed_notification(),
|
||||
]
|
||||
),
|
||||
}
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
||||
turn=SimpleNamespace(id="turn-1")
|
||||
)
|
||||
@@ -433,12 +523,13 @@ def test_thread_run_returns_none_when_only_commentary_messages_complete() -> Non
|
||||
|
||||
def test_thread_run_raises_on_failed_turn() -> None:
|
||||
client = AppServerClient()
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
_completed_notification(status="failed", error_message="boom"),
|
||||
]
|
||||
client.next_turn_notification = _turn_notification_source( # type: ignore[method-assign]
|
||||
{
|
||||
("thread-1", "turn-1"): deque(
|
||||
[_completed_notification(status="failed", error_message="boom")]
|
||||
),
|
||||
}
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
||||
turn=SimpleNamespace(id="turn-1")
|
||||
)
|
||||
@@ -471,12 +562,13 @@ def test_async_thread_run_accepts_string_input_and_returns_run_result() -> None:
|
||||
seen["params"] = params
|
||||
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
|
||||
|
||||
async def fake_next_notification() -> Notification:
|
||||
async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
|
||||
assert (thread_id, turn_id) == ("thread-1", "turn-1")
|
||||
return notifications.popleft()
|
||||
|
||||
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
|
||||
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
|
||||
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
|
||||
codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
|
||||
|
||||
result = await AsyncThread(codex, "thread-1").run("hello")
|
||||
|
||||
@@ -511,12 +603,13 @@ def test_async_thread_run_uses_last_completed_assistant_message_as_final_respons
|
||||
async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001
|
||||
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
|
||||
|
||||
async def fake_next_notification() -> Notification:
|
||||
async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
|
||||
assert (thread_id, turn_id) == ("thread-1", "turn-1")
|
||||
return notifications.popleft()
|
||||
|
||||
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
|
||||
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
|
||||
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
|
||||
codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
|
||||
|
||||
result = await AsyncThread(codex, "thread-1").run("hello")
|
||||
|
||||
@@ -550,12 +643,13 @@ def test_async_thread_run_returns_none_when_only_commentary_messages_complete()
|
||||
async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001
|
||||
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
|
||||
|
||||
async def fake_next_notification() -> Notification:
|
||||
async def fake_next_turn_notification(thread_id: str, turn_id: str) -> Notification:
|
||||
assert (thread_id, turn_id) == ("thread-1", "turn-1")
|
||||
return notifications.popleft()
|
||||
|
||||
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
|
||||
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
|
||||
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
|
||||
codex._client.next_turn_notification = fake_next_turn_notification # type: ignore[method-assign]
|
||||
|
||||
result = await AsyncThread(codex, "thread-1").run("hello")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user