mirror of
https://github.com/openai/codex.git
synced 2026-05-10 00:01:16 +03:00
Compare commits
16 Commits
latest-alp
...
codex/pyth
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8b8e868140 | ||
|
|
2654cc299e | ||
|
|
242ca6d8fd | ||
|
|
b7635f4d77 | ||
|
|
c24694bdb0 | ||
|
|
6e10973c78 | ||
|
|
becbd2a127 | ||
|
|
11e31d7d38 | ||
|
|
1d0023776f | ||
|
|
9b54951688 | ||
|
|
3a3e1b477c | ||
|
|
356c6797b8 | ||
|
|
bd14ac4758 | ||
|
|
d764740e6f | ||
|
|
29e1c96f72 | ||
|
|
ebe75bb683 |
33
.github/workflows/sdk.yml
vendored
33
.github/workflows/sdk.yml
vendored
@@ -6,6 +6,39 @@ on:
|
||||
pull_request: {}
|
||||
|
||||
jobs:
|
||||
python-sdk:
|
||||
runs-on:
|
||||
group: codex-runners
|
||||
labels: codex-linux-x64
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
persist-credentials: false
|
||||
|
||||
- name: Test Python SDK
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Run inside Alpine so dependency resolution exercises the pinned
|
||||
# runtime wheel on the same Linux wheel family that CI installs.
|
||||
docker run --rm \
|
||||
--user "$(id -u):$(id -g)" \
|
||||
-e HOME=/tmp/codex-python-sdk-home \
|
||||
-e UV_LINK_MODE=copy \
|
||||
-v "${GITHUB_WORKSPACE}:${GITHUB_WORKSPACE}" \
|
||||
-w "${GITHUB_WORKSPACE}/sdk/python" \
|
||||
python:3.12-alpine \
|
||||
sh -euxc '
|
||||
python -m venv /tmp/uv
|
||||
/tmp/uv/bin/python -m pip install uv==0.11.3
|
||||
/tmp/uv/bin/uv sync --extra dev --frozen
|
||||
/tmp/uv/bin/uv run --extra dev pytest
|
||||
'
|
||||
|
||||
sdks:
|
||||
runs-on:
|
||||
group: codex-runners
|
||||
|
||||
@@ -114,7 +114,7 @@ members = [
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "0.131.0-alpha.4"
|
||||
version = "0.0.0"
|
||||
# Track the edition for all workspace crates in one place. Individual
|
||||
# crates can still override this value, but keeping it here means new
|
||||
# crates created with `cargo new -w ...` automatically inherit the 2024
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
|
||||
Experimental Python SDK for `codex app-server` JSON-RPC v2 over stdio, with a small default surface optimized for real scripts and apps.
|
||||
|
||||
The generated wire-model layer is currently sourced from the bundled v2 schema and exposed as Pydantic models with snake_case Python fields that serialize back to the app-server’s camelCase wire format.
|
||||
The generated wire-model layer is sourced from the pinned `openai-codex-cli-bin`
|
||||
runtime package and exposed as Pydantic models with snake_case Python fields
|
||||
that serialize back to the app-server’s camelCase wire format.
|
||||
|
||||
## Install
|
||||
|
||||
@@ -68,6 +70,7 @@ notebook bootstrap the pinned runtime package automatically.
|
||||
|
||||
```bash
|
||||
cd sdk/python
|
||||
uv sync
|
||||
python scripts/update_sdk_artifacts.py generate-types
|
||||
python scripts/update_sdk_artifacts.py \
|
||||
stage-sdk \
|
||||
|
||||
@@ -27,16 +27,22 @@ class RuntimeSetupError(RuntimeError):
|
||||
|
||||
|
||||
def pinned_runtime_version() -> str:
|
||||
source_version = _source_tree_project_version()
|
||||
if source_version is not None:
|
||||
return _normalized_package_version(source_version)
|
||||
"""Return the exact runtime version pinned by the SDK package dependency."""
|
||||
source_pin = _source_tree_runtime_dependency_version()
|
||||
if source_pin is not None:
|
||||
return _normalized_package_version(source_pin)
|
||||
|
||||
try:
|
||||
return _normalized_package_version(importlib.metadata.version(SDK_PACKAGE_NAME))
|
||||
installed_pin = _installed_sdk_runtime_dependency_version()
|
||||
except importlib.metadata.PackageNotFoundError as exc:
|
||||
raise RuntimeSetupError(
|
||||
f"Unable to resolve {SDK_PACKAGE_NAME} version for runtime pinning."
|
||||
f"Unable to resolve {SDK_PACKAGE_NAME} metadata for runtime pinning."
|
||||
) from exc
|
||||
if installed_pin is None:
|
||||
raise RuntimeSetupError(
|
||||
f"Unable to resolve {PACKAGE_NAME} dependency pin from {SDK_PACKAGE_NAME}."
|
||||
)
|
||||
return _normalized_package_version(installed_pin)
|
||||
|
||||
|
||||
def ensure_runtime_package_installed(
|
||||
@@ -399,20 +405,33 @@ def _release_tag(version: str) -> str:
|
||||
return f"rust-v{_codex_release_version(version)}"
|
||||
|
||||
|
||||
def _source_tree_project_version() -> str | None:
|
||||
def _source_tree_runtime_dependency_version() -> str | None:
|
||||
"""Read the runtime dependency pin when the SDK is running from a checkout."""
|
||||
pyproject_path = Path(__file__).resolve().parent / "pyproject.toml"
|
||||
if not pyproject_path.exists():
|
||||
return None
|
||||
|
||||
match = re.search(
|
||||
r'(?m)^version = "([^"]+)"$',
|
||||
pyproject_path.read_text(encoding="utf-8"),
|
||||
)
|
||||
match = re.search(_runtime_dependency_pin_pattern(), pyproject_path.read_text())
|
||||
if match is None:
|
||||
return None
|
||||
return match.group(1)
|
||||
|
||||
|
||||
def _installed_sdk_runtime_dependency_version() -> str | None:
|
||||
"""Read the runtime dependency pin from installed package metadata."""
|
||||
requirements = importlib.metadata.requires(SDK_PACKAGE_NAME) or []
|
||||
for requirement in requirements:
|
||||
match = re.search(_runtime_dependency_pin_pattern(), requirement)
|
||||
if match is not None:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
|
||||
def _runtime_dependency_pin_pattern() -> str:
|
||||
"""Match the exact runtime dependency pin in TOML and wheel metadata."""
|
||||
return rf'{re.escape(PACKAGE_NAME)}\s*==\s*"?([^",;\s]+)"?'
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PACKAGE_NAME",
|
||||
"SDK_PACKAGE_NAME",
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
Public surface of `codex_app_server` for app-server v2.
|
||||
|
||||
This SDK surface is experimental. The current implementation intentionally allows only one active turn consumer (`Thread.run()`, `TurnHandle.stream()`, or `TurnHandle.run()`) per client instance at a time.
|
||||
This SDK surface is experimental. Turn streams are routed by turn ID so one client can consume multiple active turns concurrently.
|
||||
|
||||
## Package Entry
|
||||
|
||||
@@ -137,8 +137,8 @@ Use `turn(...)` when you need low-level turn control (`stream()`, `steer()`,
|
||||
|
||||
Behavior notes:
|
||||
|
||||
- `stream()` and `run()` are exclusive per client instance in the current experimental build
|
||||
- starting a second turn consumer on the same `Codex` instance raises `RuntimeError`
|
||||
- `stream()` and `run()` consume only notifications for their own turn ID
|
||||
- one `Codex` instance can stream multiple active turns concurrently
|
||||
|
||||
### AsyncTurnHandle
|
||||
|
||||
@@ -149,8 +149,8 @@ Behavior notes:
|
||||
|
||||
Behavior notes:
|
||||
|
||||
- `stream()` and `run()` are exclusive per client instance in the current experimental build
|
||||
- starting a second turn consumer on the same `AsyncCodex` instance raises `RuntimeError`
|
||||
- `stream()` and `run()` consume only notifications for their own turn ID
|
||||
- one `AsyncCodex` instance can stream multiple active turns concurrently
|
||||
|
||||
## Inputs
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ What happened:
|
||||
- `thread.run("...")` started a turn, consumed events until completion, and returned the final assistant response plus collected items and usage.
|
||||
- `result.final_response` is `None` when no final-answer or phase-less assistant message item completes for the turn.
|
||||
- use `thread.turn(...)` when you need a `TurnHandle` for streaming, steering, interrupting, or turn IDs/status
|
||||
- one client can have only one active turn consumer (`thread.run(...)`, `TurnHandle.stream()`, or `TurnHandle.run()`) at a time in the current experimental build
|
||||
- one client can consume multiple active turns concurrently; turn streams are routed by turn ID
|
||||
|
||||
## 3) Continue the same thread (multi-turn)
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ will download the matching GitHub release artifact, stage a temporary local
|
||||
`openai-codex-cli-bin` package, install it into your active interpreter, and clean up
|
||||
the temporary files afterward.
|
||||
|
||||
The pinned runtime version comes from the SDK package version.
|
||||
The pinned runtime version comes from the SDK package dependency.
|
||||
|
||||
## Run examples
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "openai-codex-app-server-sdk"
|
||||
version = "0.116.0a1"
|
||||
version = "0.131.0a4"
|
||||
description = "Python SDK for Codex app-server v2"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
@@ -22,7 +22,7 @@ classifiers = [
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
]
|
||||
dependencies = ["pydantic>=2.12"]
|
||||
dependencies = ["pydantic>=2.12", "openai-codex-cli-bin==0.131.0a4"]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/openai/codex"
|
||||
@@ -63,8 +63,10 @@ testpaths = ["tests"]
|
||||
|
||||
[tool.uv]
|
||||
exclude-newer = "7 days"
|
||||
exclude-newer-package = { openai-codex-cli-bin = "2026-05-10T00:00:00Z" }
|
||||
index-strategy = "first-index"
|
||||
|
||||
[tool.uv.pip]
|
||||
exclude-newer = "7 days"
|
||||
exclude-newer-package = { openai-codex-cli-bin = "2026-05-10T00:00:00Z" }
|
||||
index-strategy = "first-index"
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
import json
|
||||
import platform
|
||||
import re
|
||||
@@ -33,19 +34,14 @@ def python_runtime_root() -> Path:
|
||||
return repo_root() / "sdk" / "python-runtime"
|
||||
|
||||
|
||||
def schema_bundle_path() -> Path:
|
||||
return (
|
||||
repo_root()
|
||||
/ "codex-rs"
|
||||
/ "app-server-protocol"
|
||||
/ "schema"
|
||||
/ "json"
|
||||
/ "codex_app_server_protocol.v2.schemas.json"
|
||||
)
|
||||
def sdk_pyproject_path() -> Path:
|
||||
"""Return the SDK pyproject file that owns package pins and versions."""
|
||||
return sdk_root() / "pyproject.toml"
|
||||
|
||||
|
||||
def schema_root_dir() -> Path:
|
||||
return repo_root() / "codex-rs" / "app-server-protocol" / "schema" / "json"
|
||||
def schema_bundle_path(schema_dir: Path) -> Path:
|
||||
"""Return the aggregate v2 schema bundle emitted by the runtime binary."""
|
||||
return schema_dir / "codex_app_server_protocol.v2.schemas.json"
|
||||
|
||||
|
||||
def _is_windows() -> bool:
|
||||
@@ -61,6 +57,7 @@ def staged_runtime_bin_path(root: Path) -> Path:
|
||||
|
||||
|
||||
def staged_runtime_resource_path(root: Path, resource: Path) -> Path:
|
||||
"""Stage runtime helper binaries beside the main bundled Codex binary."""
|
||||
# Runtime wheels include the whole bin/ directory, so helper executables
|
||||
# should be staged beside the main Codex binary instead of changing the
|
||||
# package template for each platform.
|
||||
@@ -78,7 +75,7 @@ def run_python_module(module: str, args: list[str], cwd: Path) -> None:
|
||||
def current_sdk_version() -> str:
|
||||
match = re.search(
|
||||
r'^version = "([^"]+)"$',
|
||||
(sdk_root() / "pyproject.toml").read_text(),
|
||||
sdk_pyproject_path().read_text(),
|
||||
flags=re.MULTILINE,
|
||||
)
|
||||
if match is None:
|
||||
@@ -86,6 +83,59 @@ def current_sdk_version() -> str:
|
||||
return match.group(1)
|
||||
|
||||
|
||||
def pinned_runtime_version() -> str:
|
||||
"""Read the exact runtime package pin used for schema generation."""
|
||||
pyproject_text = sdk_pyproject_path().read_text()
|
||||
match = re.search(r"(?ms)^dependencies = \[(.*?)\]$", pyproject_text)
|
||||
if match is None:
|
||||
raise RuntimeError(
|
||||
"Could not find dependencies array in sdk/python/pyproject.toml"
|
||||
)
|
||||
|
||||
pins = re.findall(
|
||||
rf'"{re.escape(RUNTIME_DISTRIBUTION_NAME)}==([^"]+)"',
|
||||
match.group(1),
|
||||
)
|
||||
if len(pins) != 1:
|
||||
raise RuntimeError(
|
||||
f"Expected exactly one {RUNTIME_DISTRIBUTION_NAME} dependency pin "
|
||||
"in sdk/python/pyproject.toml"
|
||||
)
|
||||
return normalize_codex_version(pins[0])
|
||||
|
||||
|
||||
def pinned_runtime_codex_path() -> Path:
|
||||
"""Return the bundled Codex binary from the installed pinned runtime wheel."""
|
||||
expected_version = pinned_runtime_version()
|
||||
try:
|
||||
installed_version = importlib.metadata.version(RUNTIME_DISTRIBUTION_NAME)
|
||||
except importlib.metadata.PackageNotFoundError as exc:
|
||||
raise RuntimeError(
|
||||
f"Install {RUNTIME_DISTRIBUTION_NAME}=={expected_version} before "
|
||||
"generating Python SDK types."
|
||||
) from exc
|
||||
|
||||
normalized_installed_version = normalize_codex_version(installed_version)
|
||||
if normalized_installed_version != expected_version:
|
||||
raise RuntimeError(
|
||||
f"Expected {RUNTIME_DISTRIBUTION_NAME}=={expected_version}, "
|
||||
f"but found {installed_version}."
|
||||
)
|
||||
|
||||
try:
|
||||
from codex_cli_bin import bundled_codex_path
|
||||
except ImportError as exc:
|
||||
raise RuntimeError(
|
||||
f"Installed {RUNTIME_DISTRIBUTION_NAME} package does not expose "
|
||||
"bundled_codex_path."
|
||||
) from exc
|
||||
|
||||
codex_path = bundled_codex_path()
|
||||
if not codex_path.exists():
|
||||
raise RuntimeError(f"Pinned Codex runtime binary not found at {codex_path}.")
|
||||
return codex_path
|
||||
|
||||
|
||||
def normalize_codex_version(version: str) -> str:
|
||||
normalized = version.strip()
|
||||
if normalized.startswith("rust-v"):
|
||||
@@ -487,8 +537,28 @@ def _annotate_schema(value: Any, base: str | None = None) -> None:
|
||||
_annotate_schema(child, base)
|
||||
|
||||
|
||||
def _normalized_schema_bundle_text() -> str:
|
||||
schema = json.loads(schema_bundle_path().read_text())
|
||||
def generate_schema_from_pinned_runtime(schema_dir: Path) -> Path:
|
||||
"""Generate app-server schemas by invoking the installed pinned runtime binary."""
|
||||
codex_path = pinned_runtime_codex_path()
|
||||
if schema_dir.exists():
|
||||
shutil.rmtree(schema_dir)
|
||||
schema_dir.mkdir(parents=True)
|
||||
run(
|
||||
[
|
||||
str(codex_path),
|
||||
"app-server",
|
||||
"generate-json-schema",
|
||||
"--out",
|
||||
str(schema_dir),
|
||||
],
|
||||
cwd=sdk_root(),
|
||||
)
|
||||
return schema_dir
|
||||
|
||||
|
||||
def _normalized_schema_bundle_text(schema_dir: Path) -> str:
|
||||
"""Normalize the schema bundle before feeding it to the Python type generator."""
|
||||
schema = json.loads(schema_bundle_path(schema_dir).read_text())
|
||||
definitions = schema.get("definitions", {})
|
||||
if isinstance(definitions, dict):
|
||||
for definition in definitions.values():
|
||||
@@ -500,7 +570,8 @@ def _normalized_schema_bundle_text() -> str:
|
||||
return json.dumps(schema, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
|
||||
def generate_v2_all() -> None:
|
||||
def generate_v2_all(schema_dir: Path) -> None:
|
||||
"""Regenerate the Pydantic v2 protocol model module from runtime schemas."""
|
||||
out_path = sdk_root() / "src" / "codex_app_server" / "generated" / "v2_all.py"
|
||||
out_dir = out_path.parent
|
||||
old_package_dir = out_dir / "v2_all"
|
||||
@@ -508,8 +579,8 @@ def generate_v2_all() -> None:
|
||||
shutil.rmtree(old_package_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
normalized_bundle = Path(td) / schema_bundle_path().name
|
||||
normalized_bundle.write_text(_normalized_schema_bundle_text())
|
||||
normalized_bundle = Path(td) / schema_bundle_path(schema_dir).name
|
||||
normalized_bundle.write_text(_normalized_schema_bundle_text(schema_dir))
|
||||
run_python_module(
|
||||
"datamodel_code_generator",
|
||||
[
|
||||
@@ -546,9 +617,10 @@ def generate_v2_all() -> None:
|
||||
_normalize_generated_timestamps(out_path)
|
||||
|
||||
|
||||
def _notification_specs() -> list[tuple[str, str]]:
|
||||
def _notification_specs(schema_dir: Path) -> list[tuple[str, str]]:
|
||||
"""Map each server notification method to its generated payload model class."""
|
||||
server_notifications = json.loads(
|
||||
(schema_root_dir() / "ServerNotification.json").read_text()
|
||||
(schema_dir / "ServerNotification.json").read_text()
|
||||
)
|
||||
one_of = server_notifications.get("oneOf", [])
|
||||
generated_source = (
|
||||
@@ -585,7 +657,48 @@ def _notification_specs() -> list[tuple[str, str]]:
|
||||
return specs
|
||||
|
||||
|
||||
def generate_notification_registry() -> None:
|
||||
def _notification_turn_id_specs(
|
||||
schema_dir: Path,
|
||||
specs: list[tuple[str, str]],
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Classify notification payloads by where their turn id is carried."""
|
||||
server_notifications = json.loads(
|
||||
(schema_dir / "ServerNotification.json").read_text()
|
||||
)
|
||||
definitions = server_notifications.get("definitions", {})
|
||||
if not isinstance(definitions, dict):
|
||||
return ([], [])
|
||||
|
||||
direct: list[str] = []
|
||||
nested: list[str] = []
|
||||
for _, class_name in specs:
|
||||
definition = definitions.get(class_name)
|
||||
if not isinstance(definition, dict):
|
||||
continue
|
||||
props = definition.get("properties", {})
|
||||
if not isinstance(props, dict):
|
||||
continue
|
||||
if "turnId" in props:
|
||||
direct.append(class_name)
|
||||
continue
|
||||
turn = props.get("turn")
|
||||
if isinstance(turn, dict) and turn.get("$ref") == "#/definitions/Turn":
|
||||
nested.append(class_name)
|
||||
|
||||
return (sorted(set(direct)), sorted(set(nested)))
|
||||
|
||||
|
||||
def _type_tuple_source(class_names: list[str]) -> str:
|
||||
"""Render a generated tuple literal for notification payload classes."""
|
||||
if not class_names:
|
||||
return "()"
|
||||
if len(class_names) == 1:
|
||||
return f"({class_names[0]},)"
|
||||
return "(\n" + "".join(f" {class_name},\n" for class_name in class_names) + ")"
|
||||
|
||||
|
||||
def generate_notification_registry(schema_dir: Path) -> None:
|
||||
"""Regenerate notification dispatch metadata from the runtime notification schema."""
|
||||
out = (
|
||||
sdk_root()
|
||||
/ "src"
|
||||
@@ -593,8 +706,12 @@ def generate_notification_registry() -> None:
|
||||
/ "generated"
|
||||
/ "notification_registry.py"
|
||||
)
|
||||
specs = _notification_specs()
|
||||
specs = _notification_specs(schema_dir)
|
||||
class_names = sorted({class_name for _, class_name in specs})
|
||||
direct_turn_id_types, nested_turn_types = _notification_turn_id_specs(
|
||||
schema_dir,
|
||||
specs,
|
||||
)
|
||||
|
||||
lines = [
|
||||
"# Auto-generated by scripts/update_sdk_artifacts.py",
|
||||
@@ -616,7 +733,27 @@ def generate_notification_registry() -> None:
|
||||
)
|
||||
for method, class_name in specs:
|
||||
lines.append(f' "{method}": {class_name},')
|
||||
lines.extend(["}", ""])
|
||||
lines.extend(
|
||||
[
|
||||
"}",
|
||||
"",
|
||||
"DIRECT_TURN_ID_NOTIFICATION_TYPES: tuple[type[BaseModel], ...] = "
|
||||
f"{_type_tuple_source(direct_turn_id_types)}",
|
||||
"",
|
||||
"NESTED_TURN_NOTIFICATION_TYPES: tuple[type[BaseModel], ...] = "
|
||||
f"{_type_tuple_source(nested_turn_types)}",
|
||||
"",
|
||||
"",
|
||||
"def notification_turn_id(payload: BaseModel) -> str | None:",
|
||||
' """Return the turn id carried by generated notification payload metadata."""',
|
||||
" if isinstance(payload, DIRECT_TURN_ID_NOTIFICATION_TYPES):",
|
||||
" return payload.turn_id if isinstance(payload.turn_id, str) else None",
|
||||
" if isinstance(payload, NESTED_TURN_NOTIFICATION_TYPES):",
|
||||
" return payload.turn.id",
|
||||
" return None",
|
||||
"",
|
||||
]
|
||||
)
|
||||
|
||||
out.write_text("\n".join(lines))
|
||||
|
||||
@@ -695,8 +832,12 @@ def _camel_to_snake(name: str) -> str:
|
||||
def _load_public_fields(
|
||||
module_name: str, class_name: str, *, exclude: set[str] | None = None
|
||||
) -> list[PublicFieldSpec]:
|
||||
"""Load generated model fields used to render the ergonomic public methods."""
|
||||
exclude = exclude or set()
|
||||
module = importlib.import_module(module_name)
|
||||
if module_name == "codex_app_server.generated.v2_all":
|
||||
module = _load_generated_v2_all_module()
|
||||
else:
|
||||
module = importlib.import_module(module_name)
|
||||
model = getattr(module, class_name)
|
||||
fields: list[PublicFieldSpec] = []
|
||||
for name, field in model.model_fields.items():
|
||||
@@ -718,6 +859,20 @@ def _load_public_fields(
|
||||
return fields
|
||||
|
||||
|
||||
def _load_generated_v2_all_module() -> types.ModuleType:
|
||||
"""Import the freshly generated v2_all module without importing package init."""
|
||||
module_name = "_codex_app_server_generated_v2_all_for_artifacts"
|
||||
sys.modules.pop(module_name, None)
|
||||
module_path = sdk_root() / "src" / "codex_app_server" / "generated" / "v2_all.py"
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
if spec is None or spec.loader is None:
|
||||
raise RuntimeError(f"Failed to load generated module from {module_path}")
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def _kw_signature_lines(fields: list[PublicFieldSpec]) -> list[str]:
|
||||
lines: list[str] = []
|
||||
for field in fields:
|
||||
@@ -927,6 +1082,7 @@ def _render_async_thread_block(
|
||||
|
||||
|
||||
def generate_public_api_flat_methods() -> None:
|
||||
"""Regenerate the public convenience methods from generated protocol models."""
|
||||
src_dir = sdk_root() / "src"
|
||||
public_api_path = src_dir / "codex_app_server" / "api.py"
|
||||
if not public_api_path.exists():
|
||||
@@ -992,13 +1148,22 @@ def generate_public_api_flat_methods() -> None:
|
||||
_render_async_thread_block(turn_start_fields),
|
||||
)
|
||||
public_api_path.write_text(source)
|
||||
run_python_module("ruff", ["format", str(public_api_path)], cwd=sdk_root())
|
||||
|
||||
|
||||
def generate_types_from_schema_dir(schema_dir: Path) -> None:
|
||||
"""Regenerate every SDK artifact derived from an existing schema directory."""
|
||||
# v2_all is the authoritative generated surface.
|
||||
generate_v2_all(schema_dir)
|
||||
generate_notification_registry(schema_dir)
|
||||
generate_public_api_flat_methods()
|
||||
|
||||
|
||||
def generate_types() -> None:
|
||||
# v2_all is the authoritative generated surface.
|
||||
generate_v2_all()
|
||||
generate_notification_registry()
|
||||
generate_public_api_flat_methods()
|
||||
"""Generate schemas from the pinned runtime and then refresh SDK artifacts."""
|
||||
with tempfile.TemporaryDirectory(prefix="codex-python-schema-") as td:
|
||||
schema_dir = generate_schema_from_pinned_runtime(Path(td) / "schema")
|
||||
generate_types_from_schema_dir(schema_dir)
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
|
||||
@@ -22,12 +22,12 @@ from .generated.v2_all import (
|
||||
ReasoningSummary,
|
||||
SandboxMode,
|
||||
SandboxPolicy,
|
||||
ServiceTier,
|
||||
ThreadItem,
|
||||
ThreadForkParams,
|
||||
ThreadListParams,
|
||||
ThreadResumeParams,
|
||||
ThreadSortKey,
|
||||
ThreadSource,
|
||||
ThreadSourceKind,
|
||||
ThreadStartParams,
|
||||
ThreadTokenUsageUpdatedNotification,
|
||||
@@ -86,11 +86,11 @@ __all__ = [
|
||||
"ReasoningSummary",
|
||||
"SandboxMode",
|
||||
"SandboxPolicy",
|
||||
"ServiceTier",
|
||||
"ThreadStartParams",
|
||||
"ThreadResumeParams",
|
||||
"ThreadListParams",
|
||||
"ThreadSortKey",
|
||||
"ThreadSource",
|
||||
"ThreadSourceKind",
|
||||
"ThreadForkParams",
|
||||
"TurnStatus",
|
||||
|
||||
160
sdk/python/src/codex_app_server/_message_router.py
Normal file
160
sdk/python/src/codex_app_server/_message_router.py
Normal file
@@ -0,0 +1,160 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import queue
|
||||
import threading
|
||||
from collections import deque
|
||||
|
||||
from .errors import AppServerError, map_jsonrpc_error
|
||||
from .generated.notification_registry import notification_turn_id
|
||||
from .models import JsonValue, Notification, UnknownNotification
|
||||
|
||||
ResponseQueueItem = JsonValue | BaseException
|
||||
NotificationQueueItem = Notification | BaseException
|
||||
|
||||
|
||||
class MessageRouter:
|
||||
"""Route reader-thread messages to the SDK operation waiting for them.
|
||||
|
||||
The app-server stdio transport is a single ordered stream, so only the
|
||||
reader thread should consume stdout. This router keeps the rest of the SDK
|
||||
from competing for that stream by giving each in-flight JSON-RPC request
|
||||
and active turn stream its own queue.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Create empty response, turn, and global notification queues."""
|
||||
self._lock = threading.Lock()
|
||||
self._response_waiters: dict[str, queue.Queue[ResponseQueueItem]] = {}
|
||||
self._turn_notifications: dict[str, queue.Queue[NotificationQueueItem]] = {}
|
||||
self._pending_turn_notifications: dict[str, deque[Notification]] = {}
|
||||
self._global_notifications: queue.Queue[NotificationQueueItem] = queue.Queue()
|
||||
|
||||
def create_response_waiter(self, request_id: str) -> queue.Queue[ResponseQueueItem]:
|
||||
"""Register a one-shot queue for a JSON-RPC response id."""
|
||||
|
||||
waiter: queue.Queue[ResponseQueueItem] = queue.Queue(maxsize=1)
|
||||
with self._lock:
|
||||
self._response_waiters[request_id] = waiter
|
||||
return waiter
|
||||
|
||||
def discard_response_waiter(self, request_id: str) -> None:
|
||||
"""Remove a response waiter when the request could not be written."""
|
||||
|
||||
with self._lock:
|
||||
self._response_waiters.pop(request_id, None)
|
||||
|
||||
def next_global_notification(self) -> Notification:
|
||||
"""Block until the next notification that is not scoped to a turn."""
|
||||
|
||||
item = self._global_notifications.get()
|
||||
if isinstance(item, BaseException):
|
||||
raise item
|
||||
return item
|
||||
|
||||
def register_turn(self, turn_id: str) -> None:
|
||||
"""Register a queue for a turn stream and replay early events."""
|
||||
|
||||
turn_queue: queue.Queue[NotificationQueueItem] = queue.Queue()
|
||||
with self._lock:
|
||||
if turn_id in self._turn_notifications:
|
||||
return
|
||||
# A turn can emit events immediately after turn/start, before the
|
||||
# caller receives the TurnHandle and starts streaming.
|
||||
pending = self._pending_turn_notifications.pop(turn_id, deque())
|
||||
self._turn_notifications[turn_id] = turn_queue
|
||||
for notification in pending:
|
||||
turn_queue.put(notification)
|
||||
|
||||
def unregister_turn(self, turn_id: str) -> None:
|
||||
"""Stop routing future turn events to the stream queue."""
|
||||
|
||||
with self._lock:
|
||||
self._turn_notifications.pop(turn_id, None)
|
||||
|
||||
def next_turn_notification(self, turn_id: str) -> Notification:
|
||||
"""Block until the next notification for a registered turn."""
|
||||
|
||||
with self._lock:
|
||||
turn_queue = self._turn_notifications.get(turn_id)
|
||||
if turn_queue is None:
|
||||
raise RuntimeError(f"turn {turn_id!r} is not registered for streaming")
|
||||
item = turn_queue.get()
|
||||
if isinstance(item, BaseException):
|
||||
raise item
|
||||
return item
|
||||
|
||||
def route_response(self, msg: dict[str, JsonValue]) -> None:
|
||||
"""Deliver a JSON-RPC response or error to its request waiter."""
|
||||
|
||||
request_id = msg.get("id")
|
||||
with self._lock:
|
||||
waiter = self._response_waiters.pop(str(request_id), None)
|
||||
if waiter is None:
|
||||
return
|
||||
|
||||
if "error" in msg:
|
||||
err = msg["error"]
|
||||
if isinstance(err, dict):
|
||||
waiter.put(
|
||||
map_jsonrpc_error(
|
||||
int(err.get("code", -32000)),
|
||||
str(err.get("message", "unknown")),
|
||||
err.get("data"),
|
||||
)
|
||||
)
|
||||
else:
|
||||
waiter.put(AppServerError("Malformed JSON-RPC error response"))
|
||||
return
|
||||
|
||||
waiter.put(msg.get("result"))
|
||||
|
||||
def route_notification(self, notification: Notification) -> None:
|
||||
"""Deliver a notification to a turn queue or the global queue."""
|
||||
|
||||
turn_id = self._notification_turn_id(notification)
|
||||
if turn_id is None:
|
||||
self._global_notifications.put(notification)
|
||||
return
|
||||
|
||||
with self._lock:
|
||||
turn_queue = self._turn_notifications.get(turn_id)
|
||||
if turn_queue is None:
|
||||
if notification.method == "turn/completed":
|
||||
self._pending_turn_notifications.pop(turn_id, None)
|
||||
return
|
||||
self._pending_turn_notifications.setdefault(turn_id, deque()).append(
|
||||
notification
|
||||
)
|
||||
return
|
||||
turn_queue.put(notification)
|
||||
|
||||
def fail_all(self, exc: BaseException) -> None:
|
||||
"""Wake every blocked waiter when the reader thread exits."""
|
||||
|
||||
with self._lock:
|
||||
response_waiters = list(self._response_waiters.values())
|
||||
self._response_waiters.clear()
|
||||
turn_queues = list(self._turn_notifications.values())
|
||||
self._pending_turn_notifications.clear()
|
||||
# Put the same transport failure into every queue so no SDK call blocks
|
||||
# forever waiting for a response that cannot arrive.
|
||||
for waiter in response_waiters:
|
||||
waiter.put(exc)
|
||||
for turn_queue in turn_queues:
|
||||
turn_queue.put(exc)
|
||||
self._global_notifications.put(exc)
|
||||
|
||||
def _notification_turn_id(self, notification: Notification) -> str | None:
|
||||
"""Extract routing ids from known generated payloads or raw unknown payloads."""
|
||||
payload = notification.payload
|
||||
if isinstance(payload, UnknownNotification):
|
||||
raw_turn_id = payload.params.get("turnId")
|
||||
if isinstance(raw_turn_id, str):
|
||||
return raw_turn_id
|
||||
raw_turn = payload.params.get("turn")
|
||||
if isinstance(raw_turn, dict):
|
||||
raw_nested_turn_id = raw_turn.get("id")
|
||||
if isinstance(raw_nested_turn_id, str):
|
||||
return raw_nested_turn_id
|
||||
return None
|
||||
return notification_turn_id(payload)
|
||||
@@ -15,7 +15,6 @@ from .generated.v2_all import (
|
||||
ReasoningSummary,
|
||||
SandboxMode,
|
||||
SandboxPolicy,
|
||||
ServiceTier,
|
||||
SortDirection,
|
||||
ThreadArchiveResponse,
|
||||
ThreadCompactStartResponse,
|
||||
@@ -27,6 +26,7 @@ from .generated.v2_all import (
|
||||
ThreadResumeParams,
|
||||
ThreadSetNameResponse,
|
||||
ThreadSortKey,
|
||||
ThreadSource,
|
||||
ThreadSourceKind,
|
||||
ThreadStartSource,
|
||||
ThreadStartParams,
|
||||
@@ -38,14 +38,14 @@ from .generated.v2_all import (
|
||||
)
|
||||
from .models import InitializeResponse, JsonObject, Notification, ServerInfo
|
||||
from ._inputs import (
|
||||
ImageInput,
|
||||
ImageInput as ImageInput,
|
||||
Input,
|
||||
InputItem,
|
||||
LocalImageInput,
|
||||
MentionInput,
|
||||
InputItem as InputItem,
|
||||
LocalImageInput as LocalImageInput,
|
||||
MentionInput as MentionInput,
|
||||
RunInput,
|
||||
SkillInput,
|
||||
TextInput,
|
||||
SkillInput as SkillInput,
|
||||
TextInput as TextInput,
|
||||
_normalize_run_input,
|
||||
_to_wire_input,
|
||||
)
|
||||
@@ -152,8 +152,9 @@ class Codex:
|
||||
personality: Personality | None = None,
|
||||
sandbox: SandboxMode | None = None,
|
||||
service_name: str | None = None,
|
||||
service_tier: ServiceTier | None = None,
|
||||
service_tier: str | None = None,
|
||||
session_start_source: ThreadStartSource | None = None,
|
||||
thread_source: ThreadSource | None = None,
|
||||
) -> Thread:
|
||||
params = ThreadStartParams(
|
||||
approval_policy=approval_policy,
|
||||
@@ -170,6 +171,7 @@ class Codex:
|
||||
service_name=service_name,
|
||||
service_tier=service_tier,
|
||||
session_start_source=session_start_source,
|
||||
thread_source=thread_source,
|
||||
)
|
||||
started = self._client.thread_start(params)
|
||||
return Thread(self._client, started.thread.id)
|
||||
@@ -216,7 +218,7 @@ class Codex:
|
||||
model_provider: str | None = None,
|
||||
personality: Personality | None = None,
|
||||
sandbox: SandboxMode | None = None,
|
||||
service_tier: ServiceTier | None = None,
|
||||
service_tier: str | None = None,
|
||||
) -> Thread:
|
||||
params = ThreadResumeParams(
|
||||
thread_id=thread_id,
|
||||
@@ -249,7 +251,8 @@ class Codex:
|
||||
model: str | None = None,
|
||||
model_provider: str | None = None,
|
||||
sandbox: SandboxMode | None = None,
|
||||
service_tier: ServiceTier | None = None,
|
||||
service_tier: str | None = None,
|
||||
thread_source: ThreadSource | None = None,
|
||||
) -> Thread:
|
||||
params = ThreadForkParams(
|
||||
thread_id=thread_id,
|
||||
@@ -264,6 +267,7 @@ class Codex:
|
||||
model_provider=model_provider,
|
||||
sandbox=sandbox,
|
||||
service_tier=service_tier,
|
||||
thread_source=thread_source,
|
||||
)
|
||||
forked = self._client.thread_fork(thread_id, params)
|
||||
return Thread(self._client, forked.thread.id)
|
||||
@@ -274,6 +278,7 @@ class Codex:
|
||||
def thread_unarchive(self, thread_id: str) -> Thread:
|
||||
unarchived = self._client.thread_unarchive(thread_id)
|
||||
return Thread(self._client, unarchived.thread.id)
|
||||
|
||||
# END GENERATED: Codex.flat_methods
|
||||
|
||||
def models(self, *, include_hidden: bool = False) -> ModelListResponse:
|
||||
@@ -348,8 +353,9 @@ class AsyncCodex:
|
||||
personality: Personality | None = None,
|
||||
sandbox: SandboxMode | None = None,
|
||||
service_name: str | None = None,
|
||||
service_tier: ServiceTier | None = None,
|
||||
service_tier: str | None = None,
|
||||
session_start_source: ThreadStartSource | None = None,
|
||||
thread_source: ThreadSource | None = None,
|
||||
) -> AsyncThread:
|
||||
await self._ensure_initialized()
|
||||
params = ThreadStartParams(
|
||||
@@ -367,6 +373,7 @@ class AsyncCodex:
|
||||
service_name=service_name,
|
||||
service_tier=service_tier,
|
||||
session_start_source=session_start_source,
|
||||
thread_source=thread_source,
|
||||
)
|
||||
started = await self._client.thread_start(params)
|
||||
return AsyncThread(self, started.thread.id)
|
||||
@@ -414,7 +421,7 @@ class AsyncCodex:
|
||||
model_provider: str | None = None,
|
||||
personality: Personality | None = None,
|
||||
sandbox: SandboxMode | None = None,
|
||||
service_tier: ServiceTier | None = None,
|
||||
service_tier: str | None = None,
|
||||
) -> AsyncThread:
|
||||
await self._ensure_initialized()
|
||||
params = ThreadResumeParams(
|
||||
@@ -448,7 +455,8 @@ class AsyncCodex:
|
||||
model: str | None = None,
|
||||
model_provider: str | None = None,
|
||||
sandbox: SandboxMode | None = None,
|
||||
service_tier: ServiceTier | None = None,
|
||||
service_tier: str | None = None,
|
||||
thread_source: ThreadSource | None = None,
|
||||
) -> AsyncThread:
|
||||
await self._ensure_initialized()
|
||||
params = ThreadForkParams(
|
||||
@@ -464,6 +472,7 @@ class AsyncCodex:
|
||||
model_provider=model_provider,
|
||||
sandbox=sandbox,
|
||||
service_tier=service_tier,
|
||||
thread_source=thread_source,
|
||||
)
|
||||
forked = await self._client.thread_fork(thread_id, params)
|
||||
return AsyncThread(self, forked.thread.id)
|
||||
@@ -476,6 +485,7 @@ class AsyncCodex:
|
||||
await self._ensure_initialized()
|
||||
unarchived = await self._client.thread_unarchive(thread_id)
|
||||
return AsyncThread(self, unarchived.thread.id)
|
||||
|
||||
# END GENERATED: AsyncCodex.flat_methods
|
||||
|
||||
async def models(self, *, include_hidden: bool = False) -> ModelListResponse:
|
||||
@@ -500,7 +510,7 @@ class Thread:
|
||||
output_schema: JsonObject | None = None,
|
||||
personality: Personality | None = None,
|
||||
sandbox_policy: SandboxPolicy | None = None,
|
||||
service_tier: ServiceTier | None = None,
|
||||
service_tier: str | None = None,
|
||||
summary: ReasoningSummary | None = None,
|
||||
) -> RunResult:
|
||||
turn = self.turn(
|
||||
@@ -535,7 +545,7 @@ class Thread:
|
||||
output_schema: JsonObject | None = None,
|
||||
personality: Personality | None = None,
|
||||
sandbox_policy: SandboxPolicy | None = None,
|
||||
service_tier: ServiceTier | None = None,
|
||||
service_tier: str | None = None,
|
||||
summary: ReasoningSummary | None = None,
|
||||
) -> TurnHandle:
|
||||
wire_input = _to_wire_input(input)
|
||||
@@ -555,6 +565,7 @@ class Thread:
|
||||
)
|
||||
turn = self._client.turn_start(self.id, wire_input, params=params)
|
||||
return TurnHandle(self._client, self.id, turn.turn.id)
|
||||
|
||||
# END GENERATED: Thread.flat_methods
|
||||
|
||||
def read(self, *, include_turns: bool = False) -> ThreadReadResponse:
|
||||
@@ -584,7 +595,7 @@ class AsyncThread:
|
||||
output_schema: JsonObject | None = None,
|
||||
personality: Personality | None = None,
|
||||
sandbox_policy: SandboxPolicy | None = None,
|
||||
service_tier: ServiceTier | None = None,
|
||||
service_tier: str | None = None,
|
||||
summary: ReasoningSummary | None = None,
|
||||
) -> RunResult:
|
||||
turn = await self.turn(
|
||||
@@ -619,7 +630,7 @@ class AsyncThread:
|
||||
output_schema: JsonObject | None = None,
|
||||
personality: Personality | None = None,
|
||||
sandbox_policy: SandboxPolicy | None = None,
|
||||
service_tier: ServiceTier | None = None,
|
||||
service_tier: str | None = None,
|
||||
summary: ReasoningSummary | None = None,
|
||||
) -> AsyncTurnHandle:
|
||||
await self._codex._ensure_initialized()
|
||||
@@ -644,6 +655,7 @@ class AsyncThread:
|
||||
params=params,
|
||||
)
|
||||
return AsyncTurnHandle(self._codex, self.id, turn.turn.id)
|
||||
|
||||
# END GENERATED: AsyncThread.flat_methods
|
||||
|
||||
async def read(self, *, include_turns: bool = False) -> ThreadReadResponse:
|
||||
@@ -674,11 +686,11 @@ class TurnHandle:
|
||||
return self._client.turn_interrupt(self.thread_id, self.id)
|
||||
|
||||
def stream(self) -> Iterator[Notification]:
|
||||
# TODO: replace this client-wide experimental guard with per-turn event demux.
|
||||
self._client.acquire_turn_consumer(self.id)
|
||||
"""Yield only notifications routed to this turn handle."""
|
||||
self._client.register_turn_notifications(self.id)
|
||||
try:
|
||||
while True:
|
||||
event = self._client.next_notification()
|
||||
event = self._client.next_turn_notification(self.id)
|
||||
yield event
|
||||
if (
|
||||
event.method == "turn/completed"
|
||||
@@ -687,7 +699,7 @@ class TurnHandle:
|
||||
):
|
||||
break
|
||||
finally:
|
||||
self._client.release_turn_consumer(self.id)
|
||||
self._client.unregister_turn_notifications(self.id)
|
||||
|
||||
def run(self) -> AppServerTurn:
|
||||
completed: TurnCompletedNotification | None = None
|
||||
@@ -727,12 +739,12 @@ class AsyncTurnHandle:
|
||||
return await self._codex._client.turn_interrupt(self.thread_id, self.id)
|
||||
|
||||
async def stream(self) -> AsyncIterator[Notification]:
|
||||
"""Yield only notifications routed to this async turn handle."""
|
||||
await self._codex._ensure_initialized()
|
||||
# TODO: replace this client-wide experimental guard with per-turn event demux.
|
||||
self._codex._client.acquire_turn_consumer(self.id)
|
||||
self._codex._client.register_turn_notifications(self.id)
|
||||
try:
|
||||
while True:
|
||||
event = await self._codex._client.next_notification()
|
||||
event = await self._codex._client.next_turn_notification(self.id)
|
||||
yield event
|
||||
if (
|
||||
event.method == "turn/completed"
|
||||
@@ -741,7 +753,7 @@ class AsyncTurnHandle:
|
||||
):
|
||||
break
|
||||
finally:
|
||||
self._codex._client.release_turn_consumer(self.id)
|
||||
self._codex._client.unregister_turn_notifications(self.id)
|
||||
|
||||
async def run(self) -> AppServerTurn:
|
||||
completed: TurnCompletedNotification | None = None
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Iterator
|
||||
from typing import AsyncIterator, Callable, Iterable, ParamSpec, TypeVar
|
||||
from typing import AsyncIterator, Callable, ParamSpec, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -40,15 +40,16 @@ class AsyncAppServerClient:
|
||||
"""Async wrapper around AppServerClient using thread offloading."""
|
||||
|
||||
def __init__(self, config: AppServerConfig | None = None) -> None:
|
||||
"""Create the wrapped sync client that owns the transport process."""
|
||||
self._sync = AppServerClient(config=config)
|
||||
# Single stdio transport cannot be read safely from multiple threads.
|
||||
self._transport_lock = asyncio.Lock()
|
||||
|
||||
async def __aenter__(self) -> "AsyncAppServerClient":
|
||||
"""Start the app-server process when entering an async context."""
|
||||
await self.start()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, _exc_type, _exc, _tb) -> None:
|
||||
"""Close the app-server process when leaving an async context."""
|
||||
await self.close()
|
||||
|
||||
async def _call_sync(
|
||||
@@ -58,32 +59,38 @@ class AsyncAppServerClient:
|
||||
*args: ParamsT.args,
|
||||
**kwargs: ParamsT.kwargs,
|
||||
) -> ReturnT:
|
||||
async with self._transport_lock:
|
||||
return await asyncio.to_thread(fn, *args, **kwargs)
|
||||
"""Run a blocking sync-client operation without blocking the event loop."""
|
||||
return await asyncio.to_thread(fn, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _next_from_iterator(
|
||||
iterator: Iterator[AgentMessageDeltaNotification],
|
||||
) -> tuple[bool, AgentMessageDeltaNotification | None]:
|
||||
"""Convert StopIteration into a value that can cross asyncio.to_thread."""
|
||||
try:
|
||||
return True, next(iterator)
|
||||
except StopIteration:
|
||||
return False, None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the wrapped sync client in a worker thread."""
|
||||
await self._call_sync(self._sync.start)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the wrapped sync client in a worker thread."""
|
||||
await self._call_sync(self._sync.close)
|
||||
|
||||
async def initialize(self) -> InitializeResponse:
|
||||
"""Initialize the app-server session."""
|
||||
return await self._call_sync(self._sync.initialize)
|
||||
|
||||
def acquire_turn_consumer(self, turn_id: str) -> None:
|
||||
self._sync.acquire_turn_consumer(turn_id)
|
||||
def register_turn_notifications(self, turn_id: str) -> None:
|
||||
"""Register a turn notification queue on the wrapped sync client."""
|
||||
self._sync.register_turn_notifications(turn_id)
|
||||
|
||||
def release_turn_consumer(self, turn_id: str) -> None:
|
||||
self._sync.release_turn_consumer(turn_id)
|
||||
def unregister_turn_notifications(self, turn_id: str) -> None:
|
||||
"""Unregister a turn notification queue on the wrapped sync client."""
|
||||
self._sync.unregister_turn_notifications(turn_id)
|
||||
|
||||
async def request(
|
||||
self,
|
||||
@@ -92,6 +99,7 @@ class AsyncAppServerClient:
|
||||
*,
|
||||
response_model: type[ModelT],
|
||||
) -> ModelT:
|
||||
"""Send a typed JSON-RPC request through the wrapped sync client."""
|
||||
return await self._call_sync(
|
||||
self._sync.request,
|
||||
method,
|
||||
@@ -99,7 +107,10 @@ class AsyncAppServerClient:
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
async def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse:
|
||||
async def thread_start(
|
||||
self, params: V2ThreadStartParams | JsonObject | None = None
|
||||
) -> ThreadStartResponse:
|
||||
"""Start a thread using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.thread_start, params)
|
||||
|
||||
async def thread_resume(
|
||||
@@ -107,12 +118,19 @@ class AsyncAppServerClient:
|
||||
thread_id: str,
|
||||
params: V2ThreadResumeParams | JsonObject | None = None,
|
||||
) -> ThreadResumeResponse:
|
||||
"""Resume a thread using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.thread_resume, thread_id, params)
|
||||
|
||||
async def thread_list(self, params: V2ThreadListParams | JsonObject | None = None) -> ThreadListResponse:
|
||||
async def thread_list(
|
||||
self, params: V2ThreadListParams | JsonObject | None = None
|
||||
) -> ThreadListResponse:
|
||||
"""List threads using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.thread_list, params)
|
||||
|
||||
async def thread_read(self, thread_id: str, include_turns: bool = False) -> ThreadReadResponse:
|
||||
async def thread_read(
|
||||
self, thread_id: str, include_turns: bool = False
|
||||
) -> ThreadReadResponse:
|
||||
"""Read a thread using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.thread_read, thread_id, include_turns)
|
||||
|
||||
async def thread_fork(
|
||||
@@ -120,18 +138,23 @@ class AsyncAppServerClient:
|
||||
thread_id: str,
|
||||
params: V2ThreadForkParams | JsonObject | None = None,
|
||||
) -> ThreadForkResponse:
|
||||
"""Fork a thread using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.thread_fork, thread_id, params)
|
||||
|
||||
async def thread_archive(self, thread_id: str) -> ThreadArchiveResponse:
|
||||
"""Archive a thread using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.thread_archive, thread_id)
|
||||
|
||||
async def thread_unarchive(self, thread_id: str) -> ThreadUnarchiveResponse:
|
||||
"""Unarchive a thread using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.thread_unarchive, thread_id)
|
||||
|
||||
async def thread_set_name(self, thread_id: str, name: str) -> ThreadSetNameResponse:
|
||||
"""Rename a thread using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.thread_set_name, thread_id, name)
|
||||
|
||||
async def thread_compact(self, thread_id: str) -> ThreadCompactStartResponse:
|
||||
"""Start thread compaction using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.thread_compact, thread_id)
|
||||
|
||||
async def turn_start(
|
||||
@@ -140,9 +163,15 @@ class AsyncAppServerClient:
|
||||
input_items: list[JsonObject] | JsonObject | str,
|
||||
params: V2TurnStartParams | JsonObject | None = None,
|
||||
) -> TurnStartResponse:
|
||||
return await self._call_sync(self._sync.turn_start, thread_id, input_items, params)
|
||||
"""Start a turn using the wrapped sync client."""
|
||||
return await self._call_sync(
|
||||
self._sync.turn_start, thread_id, input_items, params
|
||||
)
|
||||
|
||||
async def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse:
|
||||
async def turn_interrupt(
|
||||
self, thread_id: str, turn_id: str
|
||||
) -> TurnInterruptResponse:
|
||||
"""Interrupt a turn using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.turn_interrupt, thread_id, turn_id)
|
||||
|
||||
async def turn_steer(
|
||||
@@ -151,6 +180,7 @@ class AsyncAppServerClient:
|
||||
expected_turn_id: str,
|
||||
input_items: list[JsonObject] | JsonObject | str,
|
||||
) -> TurnSteerResponse:
|
||||
"""Send steering input to a turn using the wrapped sync client."""
|
||||
return await self._call_sync(
|
||||
self._sync.turn_steer,
|
||||
thread_id,
|
||||
@@ -159,6 +189,7 @@ class AsyncAppServerClient:
|
||||
)
|
||||
|
||||
async def model_list(self, include_hidden: bool = False) -> ModelListResponse:
|
||||
"""List models using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.model_list, include_hidden)
|
||||
|
||||
async def request_with_retry_on_overload(
|
||||
@@ -171,6 +202,7 @@ class AsyncAppServerClient:
|
||||
initial_delay_s: float = 0.25,
|
||||
max_delay_s: float = 2.0,
|
||||
) -> ModelT:
|
||||
"""Send a typed request with the sync client's overload retry policy."""
|
||||
return await self._call_sync(
|
||||
self._sync.request_with_retry_on_overload,
|
||||
method,
|
||||
@@ -182,13 +214,16 @@ class AsyncAppServerClient:
|
||||
)
|
||||
|
||||
async def next_notification(self) -> Notification:
|
||||
"""Wait for the next global notification without blocking the event loop."""
|
||||
return await self._call_sync(self._sync.next_notification)
|
||||
|
||||
async def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
|
||||
return await self._call_sync(self._sync.wait_for_turn_completed, turn_id)
|
||||
async def next_turn_notification(self, turn_id: str) -> Notification:
|
||||
"""Wait for the next notification routed to one turn."""
|
||||
return await self._call_sync(self._sync.next_turn_notification, turn_id)
|
||||
|
||||
async def stream_until_methods(self, methods: Iterable[str] | str) -> list[Notification]:
|
||||
return await self._call_sync(self._sync.stream_until_methods, methods)
|
||||
async def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
|
||||
"""Wait for the completion notification routed to one turn."""
|
||||
return await self._call_sync(self._sync.wait_for_turn_completed, turn_id)
|
||||
|
||||
async def stream_text(
|
||||
self,
|
||||
@@ -196,13 +231,13 @@ class AsyncAppServerClient:
|
||||
text: str,
|
||||
params: V2TurnStartParams | JsonObject | None = None,
|
||||
) -> AsyncIterator[AgentMessageDeltaNotification]:
|
||||
async with self._transport_lock:
|
||||
iterator = self._sync.stream_text(thread_id, text, params)
|
||||
while True:
|
||||
has_value, chunk = await asyncio.to_thread(
|
||||
self._next_from_iterator,
|
||||
iterator,
|
||||
)
|
||||
if not has_value:
|
||||
break
|
||||
yield chunk
|
||||
"""Stream text deltas from one turn without monopolizing the event loop."""
|
||||
iterator = self._sync.stream_text(thread_id, text, params)
|
||||
while True:
|
||||
has_value, chunk = await asyncio.to_thread(
|
||||
self._next_from_iterator,
|
||||
iterator,
|
||||
)
|
||||
if not has_value:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
@@ -8,11 +8,11 @@ import uuid
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterable, Iterator, TypeVar
|
||||
from typing import Callable, Iterator, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .errors import AppServerError, TransportClosedError, map_jsonrpc_error
|
||||
from .errors import AppServerError, TransportClosedError
|
||||
from .generated.notification_registry import NOTIFICATION_MODELS
|
||||
from .generated.v2_all import (
|
||||
AgentMessageDeltaNotification,
|
||||
@@ -43,6 +43,7 @@ from .models import (
|
||||
Notification,
|
||||
UnknownNotification,
|
||||
)
|
||||
from ._message_router import MessageRouter
|
||||
from .retry import retry_on_overload
|
||||
from ._version import __version__ as SDK_VERSION
|
||||
|
||||
@@ -75,7 +76,9 @@ def _params_dict(
|
||||
return dumped
|
||||
if isinstance(params, dict):
|
||||
return params
|
||||
raise TypeError(f"Expected generated params model or dict, got {type(params).__name__}")
|
||||
raise TypeError(
|
||||
f"Expected generated params model or dict, got {type(params).__name__}"
|
||||
)
|
||||
|
||||
|
||||
def _installed_codex_path() -> Path:
|
||||
@@ -146,11 +149,10 @@ class AppServerClient:
|
||||
self._approval_handler = approval_handler or self._default_approval_handler
|
||||
self._proc: subprocess.Popen[str] | None = None
|
||||
self._lock = threading.Lock()
|
||||
self._turn_consumer_lock = threading.Lock()
|
||||
self._active_turn_consumer: str | None = None
|
||||
self._pending_notifications: deque[Notification] = deque()
|
||||
self._router = MessageRouter()
|
||||
self._stderr_lines: deque[str] = deque(maxlen=400)
|
||||
self._stderr_thread: threading.Thread | None = None
|
||||
self._reader_thread: threading.Thread | None = None
|
||||
|
||||
def __enter__(self) -> "AppServerClient":
|
||||
self.start()
|
||||
@@ -189,13 +191,13 @@ class AppServerClient:
|
||||
)
|
||||
|
||||
self._start_stderr_drain_thread()
|
||||
self._start_reader_thread()
|
||||
|
||||
def close(self) -> None:
|
||||
if self._proc is None:
|
||||
return
|
||||
proc = self._proc
|
||||
self._proc = None
|
||||
self._active_turn_consumer = None
|
||||
|
||||
if proc.stdin:
|
||||
proc.stdin.close()
|
||||
@@ -207,6 +209,8 @@ class AppServerClient:
|
||||
|
||||
if self._stderr_thread and self._stderr_thread.is_alive():
|
||||
self._stderr_thread.join(timeout=0.5)
|
||||
if self._reader_thread and self._reader_thread.is_alive():
|
||||
self._reader_thread.join(timeout=0.5)
|
||||
|
||||
def initialize(self) -> InitializeResponse:
|
||||
result = self.request(
|
||||
@@ -239,71 +243,49 @@ class AppServerClient:
|
||||
return response_model.model_validate(result)
|
||||
|
||||
def _request_raw(self, method: str, params: JsonObject | None = None) -> JsonValue:
|
||||
"""Send a JSON-RPC request and wait for the reader thread to route its response."""
|
||||
request_id = str(uuid.uuid4())
|
||||
self._write_message({"id": request_id, "method": method, "params": params or {}})
|
||||
waiter = self._router.create_response_waiter(request_id)
|
||||
|
||||
while True:
|
||||
msg = self._read_message()
|
||||
try:
|
||||
self._write_message(
|
||||
{"id": request_id, "method": method, "params": params or {}}
|
||||
)
|
||||
except BaseException:
|
||||
self._router.discard_response_waiter(request_id)
|
||||
raise
|
||||
|
||||
if "method" in msg and "id" in msg:
|
||||
response = self._handle_server_request(msg)
|
||||
self._write_message({"id": msg["id"], "result": response})
|
||||
continue
|
||||
|
||||
if "method" in msg and "id" not in msg:
|
||||
self._pending_notifications.append(
|
||||
self._coerce_notification(msg["method"], msg.get("params"))
|
||||
)
|
||||
continue
|
||||
|
||||
if msg.get("id") != request_id:
|
||||
continue
|
||||
|
||||
if "error" in msg:
|
||||
err = msg["error"]
|
||||
if isinstance(err, dict):
|
||||
raise map_jsonrpc_error(
|
||||
int(err.get("code", -32000)),
|
||||
str(err.get("message", "unknown")),
|
||||
err.get("data"),
|
||||
)
|
||||
raise AppServerError("Malformed JSON-RPC error response")
|
||||
|
||||
return msg.get("result")
|
||||
item = waiter.get()
|
||||
if isinstance(item, BaseException):
|
||||
raise item
|
||||
return item
|
||||
|
||||
def notify(self, method: str, params: JsonObject | None = None) -> None:
|
||||
"""Send a JSON-RPC notification without waiting for a response."""
|
||||
self._write_message({"method": method, "params": params or {}})
|
||||
|
||||
def next_notification(self) -> Notification:
|
||||
if self._pending_notifications:
|
||||
return self._pending_notifications.popleft()
|
||||
"""Return the next notification that is not scoped to an active turn."""
|
||||
return self._router.next_global_notification()
|
||||
|
||||
while True:
|
||||
msg = self._read_message()
|
||||
if "method" in msg and "id" in msg:
|
||||
response = self._handle_server_request(msg)
|
||||
self._write_message({"id": msg["id"], "result": response})
|
||||
continue
|
||||
if "method" in msg and "id" not in msg:
|
||||
return self._coerce_notification(msg["method"], msg.get("params"))
|
||||
def register_turn_notifications(self, turn_id: str) -> None:
|
||||
"""Start routing notifications for one turn into its dedicated queue."""
|
||||
self._router.register_turn(turn_id)
|
||||
|
||||
def acquire_turn_consumer(self, turn_id: str) -> None:
|
||||
with self._turn_consumer_lock:
|
||||
if self._active_turn_consumer is not None:
|
||||
raise RuntimeError(
|
||||
"Concurrent turn consumers are not yet supported in the experimental SDK. "
|
||||
f"Client is already streaming turn {self._active_turn_consumer!r}; "
|
||||
f"cannot start turn {turn_id!r} until the active consumer finishes."
|
||||
)
|
||||
self._active_turn_consumer = turn_id
|
||||
def unregister_turn_notifications(self, turn_id: str) -> None:
|
||||
"""Stop routing notifications for one turn into its dedicated queue."""
|
||||
self._router.unregister_turn(turn_id)
|
||||
|
||||
def release_turn_consumer(self, turn_id: str) -> None:
|
||||
with self._turn_consumer_lock:
|
||||
if self._active_turn_consumer == turn_id:
|
||||
self._active_turn_consumer = None
|
||||
def next_turn_notification(self, turn_id: str) -> Notification:
|
||||
"""Return the next routed notification for the requested turn id."""
|
||||
return self._router.next_turn_notification(turn_id)
|
||||
|
||||
def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse:
|
||||
return self.request("thread/start", _params_dict(params), response_model=ThreadStartResponse)
|
||||
def thread_start(
|
||||
self, params: V2ThreadStartParams | JsonObject | None = None
|
||||
) -> ThreadStartResponse:
|
||||
return self.request(
|
||||
"thread/start", _params_dict(params), response_model=ThreadStartResponse
|
||||
)
|
||||
|
||||
def thread_resume(
|
||||
self,
|
||||
@@ -311,12 +293,20 @@ class AppServerClient:
|
||||
params: V2ThreadResumeParams | JsonObject | None = None,
|
||||
) -> ThreadResumeResponse:
|
||||
payload = {"threadId": thread_id, **_params_dict(params)}
|
||||
return self.request("thread/resume", payload, response_model=ThreadResumeResponse)
|
||||
return self.request(
|
||||
"thread/resume", payload, response_model=ThreadResumeResponse
|
||||
)
|
||||
|
||||
def thread_list(self, params: V2ThreadListParams | JsonObject | None = None) -> ThreadListResponse:
|
||||
return self.request("thread/list", _params_dict(params), response_model=ThreadListResponse)
|
||||
def thread_list(
|
||||
self, params: V2ThreadListParams | JsonObject | None = None
|
||||
) -> ThreadListResponse:
|
||||
return self.request(
|
||||
"thread/list", _params_dict(params), response_model=ThreadListResponse
|
||||
)
|
||||
|
||||
def thread_read(self, thread_id: str, include_turns: bool = False) -> ThreadReadResponse:
|
||||
def thread_read(
|
||||
self, thread_id: str, include_turns: bool = False
|
||||
) -> ThreadReadResponse:
|
||||
return self.request(
|
||||
"thread/read",
|
||||
{"threadId": thread_id, "includeTurns": include_turns},
|
||||
@@ -332,10 +322,18 @@ class AppServerClient:
|
||||
return self.request("thread/fork", payload, response_model=ThreadForkResponse)
|
||||
|
||||
def thread_archive(self, thread_id: str) -> ThreadArchiveResponse:
|
||||
return self.request("thread/archive", {"threadId": thread_id}, response_model=ThreadArchiveResponse)
|
||||
return self.request(
|
||||
"thread/archive",
|
||||
{"threadId": thread_id},
|
||||
response_model=ThreadArchiveResponse,
|
||||
)
|
||||
|
||||
def thread_unarchive(self, thread_id: str) -> ThreadUnarchiveResponse:
|
||||
return self.request("thread/unarchive", {"threadId": thread_id}, response_model=ThreadUnarchiveResponse)
|
||||
return self.request(
|
||||
"thread/unarchive",
|
||||
{"threadId": thread_id},
|
||||
response_model=ThreadUnarchiveResponse,
|
||||
)
|
||||
|
||||
def thread_set_name(self, thread_id: str, name: str) -> ThreadSetNameResponse:
|
||||
return self.request(
|
||||
@@ -357,12 +355,15 @@ class AppServerClient:
|
||||
input_items: list[JsonObject] | JsonObject | str,
|
||||
params: V2TurnStartParams | JsonObject | None = None,
|
||||
) -> TurnStartResponse:
|
||||
"""Start a turn and register its notification queue as early as possible."""
|
||||
payload = {
|
||||
**_params_dict(params),
|
||||
"threadId": thread_id,
|
||||
"input": self._normalize_input_items(input_items),
|
||||
}
|
||||
return self.request("turn/start", payload, response_model=TurnStartResponse)
|
||||
started = self.request("turn/start", payload, response_model=TurnStartResponse)
|
||||
self.register_turn_notifications(started.turn.id)
|
||||
return started
|
||||
|
||||
def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse:
|
||||
return self.request(
|
||||
@@ -412,23 +413,19 @@ class AppServerClient:
|
||||
)
|
||||
|
||||
def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
|
||||
while True:
|
||||
notification = self.next_notification()
|
||||
if (
|
||||
notification.method == "turn/completed"
|
||||
and isinstance(notification.payload, TurnCompletedNotification)
|
||||
and notification.payload.turn.id == turn_id
|
||||
):
|
||||
return notification.payload
|
||||
|
||||
def stream_until_methods(self, methods: Iterable[str] | str) -> list[Notification]:
|
||||
target_methods = {methods} if isinstance(methods, str) else set(methods)
|
||||
out: list[Notification] = []
|
||||
while True:
|
||||
notification = self.next_notification()
|
||||
out.append(notification)
|
||||
if notification.method in target_methods:
|
||||
return out
|
||||
"""Block on the routed turn stream until the matching completion arrives."""
|
||||
self.register_turn_notifications(turn_id)
|
||||
try:
|
||||
while True:
|
||||
notification = self.next_turn_notification(turn_id)
|
||||
if (
|
||||
notification.method == "turn/completed"
|
||||
and isinstance(notification.payload, TurnCompletedNotification)
|
||||
and notification.payload.turn.id == turn_id
|
||||
):
|
||||
return notification.payload
|
||||
finally:
|
||||
self.unregister_turn_notifications(turn_id)
|
||||
|
||||
def stream_text(
|
||||
self,
|
||||
@@ -436,35 +433,44 @@ class AppServerClient:
|
||||
text: str,
|
||||
params: V2TurnStartParams | JsonObject | None = None,
|
||||
) -> Iterator[AgentMessageDeltaNotification]:
|
||||
"""Start a text turn and yield only its agent-message delta payloads."""
|
||||
started = self.turn_start(thread_id, text, params=params)
|
||||
turn_id = started.turn.id
|
||||
while True:
|
||||
notification = self.next_notification()
|
||||
if (
|
||||
notification.method == "item/agentMessage/delta"
|
||||
and isinstance(notification.payload, AgentMessageDeltaNotification)
|
||||
and notification.payload.turn_id == turn_id
|
||||
):
|
||||
yield notification.payload
|
||||
continue
|
||||
if (
|
||||
notification.method == "turn/completed"
|
||||
and isinstance(notification.payload, TurnCompletedNotification)
|
||||
and notification.payload.turn.id == turn_id
|
||||
):
|
||||
break
|
||||
self.register_turn_notifications(turn_id)
|
||||
try:
|
||||
while True:
|
||||
notification = self.next_turn_notification(turn_id)
|
||||
if (
|
||||
notification.method == "item/agentMessage/delta"
|
||||
and isinstance(notification.payload, AgentMessageDeltaNotification)
|
||||
and notification.payload.turn_id == turn_id
|
||||
):
|
||||
yield notification.payload
|
||||
continue
|
||||
if (
|
||||
notification.method == "turn/completed"
|
||||
and isinstance(notification.payload, TurnCompletedNotification)
|
||||
and notification.payload.turn.id == turn_id
|
||||
):
|
||||
break
|
||||
finally:
|
||||
self.unregister_turn_notifications(turn_id)
|
||||
|
||||
def _coerce_notification(self, method: str, params: object) -> Notification:
|
||||
params_dict = params if isinstance(params, dict) else {}
|
||||
|
||||
model = NOTIFICATION_MODELS.get(method)
|
||||
if model is None:
|
||||
return Notification(method=method, payload=UnknownNotification(params=params_dict))
|
||||
return Notification(
|
||||
method=method, payload=UnknownNotification(params=params_dict)
|
||||
)
|
||||
|
||||
try:
|
||||
payload = model.model_validate(params_dict)
|
||||
except Exception: # noqa: BLE001
|
||||
return Notification(method=method, payload=UnknownNotification(params=params_dict))
|
||||
return Notification(
|
||||
method=method, payload=UnknownNotification(params=params_dict)
|
||||
)
|
||||
return Notification(method=method, payload=payload)
|
||||
|
||||
def _normalize_input_items(
|
||||
@@ -477,7 +483,10 @@ class AppServerClient:
|
||||
return [input_items]
|
||||
return input_items
|
||||
|
||||
def _default_approval_handler(self, method: str, params: JsonObject | None) -> JsonObject:
|
||||
def _default_approval_handler(
|
||||
self, method: str, params: JsonObject | None
|
||||
) -> JsonObject:
|
||||
"""Accept approval requests when the caller did not provide a handler."""
|
||||
if method == "item/commandExecution/requestApproval":
|
||||
return {"decision": "accept"}
|
||||
if method == "item/fileChange/requestApproval":
|
||||
@@ -498,6 +507,34 @@ class AppServerClient:
|
||||
self._stderr_thread = threading.Thread(target=_drain, daemon=True)
|
||||
self._stderr_thread.start()
|
||||
|
||||
def _start_reader_thread(self) -> None:
|
||||
"""Start the sole stdout reader that fans messages into router queues."""
|
||||
if self._proc is None or self._proc.stdout is None:
|
||||
return
|
||||
|
||||
self._reader_thread = threading.Thread(target=self._reader_loop, daemon=True)
|
||||
self._reader_thread.start()
|
||||
|
||||
def _reader_loop(self) -> None:
|
||||
"""Continuously classify transport messages into requests, responses, and events."""
|
||||
try:
|
||||
while True:
|
||||
msg = self._read_message()
|
||||
if "method" in msg and "id" in msg:
|
||||
response = self._handle_server_request(msg)
|
||||
self._write_message({"id": msg["id"], "result": response})
|
||||
continue
|
||||
if "method" in msg and "id" not in msg:
|
||||
method = msg["method"]
|
||||
if isinstance(method, str):
|
||||
self._router.route_notification(
|
||||
self._coerce_notification(method, msg.get("params"))
|
||||
)
|
||||
continue
|
||||
self._router.route_response(msg)
|
||||
except BaseException as exc:
|
||||
self._router.fail_all(exc)
|
||||
|
||||
def _stderr_tail(self, limit: int = 40) -> str:
|
||||
return "\n".join(list(self._stderr_lines)[-limit:])
|
||||
|
||||
|
||||
@@ -35,6 +35,8 @@ from .v2_all import McpToolCallProgressNotification
|
||||
from .v2_all import ModelReroutedNotification
|
||||
from .v2_all import ModelVerificationNotification
|
||||
from .v2_all import PlanDeltaNotification
|
||||
from .v2_all import ProcessExitedNotification
|
||||
from .v2_all import ProcessOutputDeltaNotification
|
||||
from .v2_all import ReasoningSummaryPartAddedNotification
|
||||
from .v2_all import ReasoningSummaryTextDeltaNotification
|
||||
from .v2_all import ReasoningTextDeltaNotification
|
||||
@@ -101,6 +103,8 @@ NOTIFICATION_MODELS: dict[str, type[BaseModel]] = {
|
||||
"mcpServer/startupStatus/updated": McpServerStatusUpdatedNotification,
|
||||
"model/rerouted": ModelReroutedNotification,
|
||||
"model/verification": ModelVerificationNotification,
|
||||
"process/exited": ProcessExitedNotification,
|
||||
"process/outputDelta": ProcessOutputDeltaNotification,
|
||||
"remoteControl/status/changed": RemoteControlStatusChangedNotification,
|
||||
"serverRequest/resolved": ServerRequestResolvedNotification,
|
||||
"skills/changed": SkillsChangedNotification,
|
||||
@@ -130,3 +134,44 @@ NOTIFICATION_MODELS: dict[str, type[BaseModel]] = {
|
||||
"windows/worldWritableWarning": WindowsWorldWritableWarningNotification,
|
||||
"windowsSandbox/setupCompleted": WindowsSandboxSetupCompletedNotification,
|
||||
}
|
||||
|
||||
DIRECT_TURN_ID_NOTIFICATION_TYPES: tuple[type[BaseModel], ...] = (
|
||||
AgentMessageDeltaNotification,
|
||||
CommandExecutionOutputDeltaNotification,
|
||||
ContextCompactedNotification,
|
||||
ErrorNotification,
|
||||
FileChangeOutputDeltaNotification,
|
||||
FileChangePatchUpdatedNotification,
|
||||
HookCompletedNotification,
|
||||
HookStartedNotification,
|
||||
ItemCompletedNotification,
|
||||
ItemGuardianApprovalReviewCompletedNotification,
|
||||
ItemGuardianApprovalReviewStartedNotification,
|
||||
ItemStartedNotification,
|
||||
McpToolCallProgressNotification,
|
||||
ModelReroutedNotification,
|
||||
ModelVerificationNotification,
|
||||
PlanDeltaNotification,
|
||||
ReasoningSummaryPartAddedNotification,
|
||||
ReasoningSummaryTextDeltaNotification,
|
||||
ReasoningTextDeltaNotification,
|
||||
TerminalInteractionNotification,
|
||||
ThreadGoalUpdatedNotification,
|
||||
ThreadTokenUsageUpdatedNotification,
|
||||
TurnDiffUpdatedNotification,
|
||||
TurnPlanUpdatedNotification,
|
||||
)
|
||||
|
||||
NESTED_TURN_NOTIFICATION_TYPES: tuple[type[BaseModel], ...] = (
|
||||
TurnCompletedNotification,
|
||||
TurnStartedNotification,
|
||||
)
|
||||
|
||||
|
||||
def notification_turn_id(payload: BaseModel) -> str | None:
|
||||
"""Return the turn id carried by generated notification payload metadata."""
|
||||
if isinstance(payload, DIRECT_TURN_ID_NOTIFICATION_TYPES):
|
||||
return payload.turn_id if isinstance(payload.turn_id, str) else None
|
||||
if isinstance(payload, NESTED_TURN_NOTIFICATION_TYPES):
|
||||
return payload.turn.id
|
||||
return None
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -16,6 +16,7 @@ ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
|
||||
def _load_update_script_module():
|
||||
"""Load the maintenance script as a module so tests exercise real helpers."""
|
||||
script_path = ROOT / "scripts" / "update_sdk_artifacts.py"
|
||||
spec = importlib.util.spec_from_file_location("update_sdk_artifacts", script_path)
|
||||
if spec is None or spec.loader is None:
|
||||
@@ -27,6 +28,7 @@ def _load_update_script_module():
|
||||
|
||||
|
||||
def _load_runtime_setup_module():
|
||||
"""Load runtime setup without importing the SDK package under test."""
|
||||
runtime_setup_path = ROOT / "_runtime_setup.py"
|
||||
spec = importlib.util.spec_from_file_location("_runtime_setup", runtime_setup_path)
|
||||
if spec is None or spec.loader is None:
|
||||
@@ -40,11 +42,13 @@ def _load_runtime_setup_module():
|
||||
|
||||
|
||||
def test_generation_has_single_maintenance_entrypoint_script() -> None:
|
||||
"""Keep artifact workflows routed through one script instead of side entrypoints."""
|
||||
scripts = sorted(p.name for p in (ROOT / "scripts").glob("*.py"))
|
||||
assert scripts == ["update_sdk_artifacts.py"]
|
||||
|
||||
|
||||
def test_generate_types_wires_all_generation_steps() -> None:
|
||||
"""The type generation command should refresh every schema-derived artifact."""
|
||||
source = (ROOT / "scripts" / "update_sdk_artifacts.py").read_text()
|
||||
tree = ast.parse(source)
|
||||
|
||||
@@ -52,7 +56,8 @@ def test_generate_types_wires_all_generation_steps() -> None:
|
||||
(
|
||||
node
|
||||
for node in tree.body
|
||||
if isinstance(node, ast.FunctionDef) and node.name == "generate_types"
|
||||
if isinstance(node, ast.FunctionDef)
|
||||
and node.name == "generate_types_from_schema_dir"
|
||||
),
|
||||
None,
|
||||
)
|
||||
@@ -72,19 +77,19 @@ def test_generate_types_wires_all_generation_steps() -> None:
|
||||
]
|
||||
|
||||
|
||||
def test_schema_normalization_only_flattens_string_literal_oneofs() -> None:
|
||||
def _load_runtime_schema_bundle(tmp_path: Path) -> dict:
|
||||
"""Ask the pinned runtime package for a real schema bundle used by tests."""
|
||||
script = _load_update_script_module()
|
||||
schema = json.loads(
|
||||
(
|
||||
ROOT.parent.parent
|
||||
/ "codex-rs"
|
||||
/ "app-server-protocol"
|
||||
/ "schema"
|
||||
/ "json"
|
||||
/ "codex_app_server_protocol.v2.schemas.json"
|
||||
).read_text()
|
||||
)
|
||||
schema_dir = script.generate_schema_from_pinned_runtime(tmp_path / "schema")
|
||||
return json.loads(script.schema_bundle_path(schema_dir).read_text())
|
||||
|
||||
|
||||
def test_schema_normalization_only_flattens_string_literal_oneofs(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Schema normalization should only flatten the enum-shaped oneOf variants."""
|
||||
script = _load_update_script_module()
|
||||
schema = _load_runtime_schema_bundle(tmp_path)
|
||||
definitions = schema["definitions"]
|
||||
flattened = [
|
||||
name
|
||||
@@ -94,27 +99,23 @@ def test_schema_normalization_only_flattens_string_literal_oneofs() -> None:
|
||||
]
|
||||
|
||||
assert flattened == [
|
||||
"AuthMode",
|
||||
"CommandExecOutputStream",
|
||||
"ExperimentalFeatureStage",
|
||||
"InputModality",
|
||||
"MessagePhase",
|
||||
"TurnItemsView",
|
||||
"PluginAvailability",
|
||||
"AuthMode",
|
||||
"InputModality",
|
||||
"ExperimentalFeatureStage",
|
||||
"CommandExecOutputStream",
|
||||
"ProcessOutputStream",
|
||||
]
|
||||
|
||||
|
||||
def test_python_codegen_schema_annotation_adds_stable_variant_titles() -> None:
|
||||
def test_python_codegen_schema_annotation_adds_stable_variant_titles(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Schema annotations should give generated protocol classes stable names."""
|
||||
script = _load_update_script_module()
|
||||
schema = json.loads(
|
||||
(
|
||||
ROOT.parent.parent
|
||||
/ "codex-rs"
|
||||
/ "app-server-protocol"
|
||||
/ "schema"
|
||||
/ "json"
|
||||
/ "codex_app_server_protocol.v2.schemas.json"
|
||||
).read_text()
|
||||
)
|
||||
|
||||
schema = _load_runtime_schema_bundle(tmp_path)
|
||||
script._annotate_schema(schema)
|
||||
definitions = schema["definitions"]
|
||||
|
||||
@@ -163,8 +164,9 @@ def test_runtime_package_template_has_no_checked_in_binaries() -> None:
|
||||
|
||||
|
||||
def test_examples_readme_points_to_runtime_version_source_of_truth() -> None:
|
||||
"""Document that examples should point at the dependency pin, not release lore."""
|
||||
readme = (ROOT / "examples" / "README.md").read_text()
|
||||
assert "The pinned runtime version comes from the SDK package version." in readme
|
||||
assert "The pinned runtime version comes from the SDK package dependency." in readme
|
||||
|
||||
|
||||
def test_runtime_distribution_name_is_consistent() -> None:
|
||||
@@ -185,6 +187,25 @@ def test_runtime_distribution_name_is_consistent() -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_source_sdk_package_pins_published_runtime() -> None:
|
||||
"""The source package metadata should pin the runtime wheel that ships schemas."""
|
||||
script = _load_update_script_module()
|
||||
pyproject = tomllib.loads((ROOT / "pyproject.toml").read_text())
|
||||
|
||||
assert {
|
||||
"sdk_version": pyproject["project"]["version"],
|
||||
"runtime_pin": script.pinned_runtime_version(),
|
||||
"dependencies": pyproject["project"]["dependencies"],
|
||||
} == {
|
||||
"sdk_version": "0.131.0a4",
|
||||
"runtime_pin": "0.131.0a4",
|
||||
"dependencies": [
|
||||
"pydantic>=2.12",
|
||||
"openai-codex-cli-bin==0.131.0a4",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_release_metadata_retries_without_invalid_auth(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -212,11 +233,16 @@ def test_release_metadata_retries_without_invalid_auth(
|
||||
|
||||
|
||||
def test_runtime_setup_uses_pep440_package_version_and_codex_release_tags() -> None:
|
||||
"""The SDK uses PEP 440 package pins and converts only when fetching releases."""
|
||||
runtime_setup = _load_runtime_setup_module()
|
||||
pyproject = tomllib.loads((ROOT / "pyproject.toml").read_text())
|
||||
|
||||
assert runtime_setup.PACKAGE_NAME == "openai-codex-cli-bin"
|
||||
assert runtime_setup.pinned_runtime_version() == pyproject["project"]["version"]
|
||||
assert (
|
||||
f"{runtime_setup.PACKAGE_NAME}=={pyproject['project']['version']}"
|
||||
in pyproject["project"]["dependencies"]
|
||||
)
|
||||
assert (
|
||||
runtime_setup._normalized_package_version("rust-v0.116.0-alpha.1")
|
||||
== "0.116.0a1"
|
||||
@@ -352,6 +378,7 @@ def test_stage_runtime_release_can_pin_wheel_platform_tag(tmp_path: Path) -> Non
|
||||
|
||||
|
||||
def test_stage_runtime_release_copies_resource_binaries(tmp_path: Path) -> None:
|
||||
"""Runtime staging should copy every helper binary into the wheel bin dir."""
|
||||
script = _load_update_script_module()
|
||||
fake_binary = tmp_path / script.runtime_binary_name()
|
||||
helper = tmp_path / "helper"
|
||||
@@ -382,6 +409,7 @@ def test_stage_runtime_release_copies_resource_binaries(tmp_path: Path) -> None:
|
||||
def test_runtime_resource_binaries_are_included_by_wheel_config(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""The runtime wheel config should include helper binaries beside Codex."""
|
||||
script = _load_update_script_module()
|
||||
fake_binary = tmp_path / script.runtime_binary_name()
|
||||
helper = tmp_path / "helper"
|
||||
@@ -398,9 +426,7 @@ def test_runtime_resource_binaries_are_included_by_wheel_config(
|
||||
pyproject = tomllib.loads((staged / "pyproject.toml").read_text())
|
||||
assert {
|
||||
"include": pyproject["tool"]["hatch"]["build"]["targets"]["wheel"]["include"],
|
||||
"helper": (
|
||||
staged / "src" / "codex_cli_bin" / "bin" / "helper"
|
||||
).read_text(),
|
||||
"helper": (staged / "src" / "codex_cli_bin" / "bin" / "helper").read_text(),
|
||||
} == {
|
||||
"include": ["src/codex_cli_bin/bin/**"],
|
||||
"helper": "fake helper\n",
|
||||
|
||||
@@ -2,17 +2,26 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
|
||||
from codex_app_server.async_client import AsyncAppServerClient
|
||||
from codex_app_server.generated.v2_all import (
|
||||
AgentMessageDeltaNotification,
|
||||
TurnCompletedNotification,
|
||||
)
|
||||
from codex_app_server.models import Notification, UnknownNotification
|
||||
|
||||
|
||||
def test_async_client_serializes_transport_calls() -> None:
|
||||
def test_async_client_allows_concurrent_transport_calls() -> None:
|
||||
"""Async wrappers should offload sync calls so concurrent awaits can overlap."""
|
||||
async def scenario() -> int:
|
||||
"""Run two blocking sync calls and report peak overlap."""
|
||||
client = AsyncAppServerClient()
|
||||
active = 0
|
||||
max_active = 0
|
||||
|
||||
def fake_model_list(include_hidden: bool = False) -> bool:
|
||||
"""Simulate a blocking sync transport call."""
|
||||
nonlocal active, max_active
|
||||
active += 1
|
||||
max_active = max(max_active, active)
|
||||
@@ -24,20 +33,24 @@ def test_async_client_serializes_transport_calls() -> None:
|
||||
await asyncio.gather(client.model_list(), client.model_list())
|
||||
return max_active
|
||||
|
||||
assert asyncio.run(scenario()) == 1
|
||||
assert asyncio.run(scenario()) == 2
|
||||
|
||||
|
||||
def test_async_stream_text_is_incremental_and_blocks_parallel_calls() -> None:
|
||||
def test_async_stream_text_is_incremental_without_blocking_parallel_calls() -> None:
|
||||
"""Async text streaming should yield incrementally without blocking other calls."""
|
||||
async def scenario() -> tuple[str, list[str], bool]:
|
||||
"""Start a stream, then prove another async client call can finish."""
|
||||
client = AsyncAppServerClient()
|
||||
|
||||
def fake_stream_text(thread_id: str, text: str, params=None): # type: ignore[no-untyped-def]
|
||||
"""Yield one item before sleeping so the async wrapper can interleave."""
|
||||
yield "first"
|
||||
time.sleep(0.03)
|
||||
yield "second"
|
||||
yield "third"
|
||||
|
||||
def fake_model_list(include_hidden: bool = False) -> str:
|
||||
"""Return immediately to prove the event loop was not monopolized."""
|
||||
return "done"
|
||||
|
||||
client._sync.stream_text = fake_stream_text # type: ignore[method-assign]
|
||||
@@ -46,19 +59,167 @@ def test_async_stream_text_is_incremental_and_blocks_parallel_calls() -> None:
|
||||
stream = client.stream_text("thread-1", "hello")
|
||||
first = await anext(stream)
|
||||
|
||||
blocked_before_stream_done = False
|
||||
competing_call = asyncio.create_task(client.model_list())
|
||||
await asyncio.sleep(0.01)
|
||||
blocked_before_stream_done = not competing_call.done()
|
||||
competing_call_done_before_stream_done = competing_call.done()
|
||||
|
||||
remaining: list[str] = []
|
||||
async for item in stream:
|
||||
remaining.append(item)
|
||||
|
||||
await competing_call
|
||||
return first, remaining, blocked_before_stream_done
|
||||
return first, remaining, competing_call_done_before_stream_done
|
||||
|
||||
first, remaining, blocked = asyncio.run(scenario())
|
||||
first, remaining, was_unblocked = asyncio.run(scenario())
|
||||
assert first == "first"
|
||||
assert remaining == ["second", "third"]
|
||||
assert blocked
|
||||
assert was_unblocked
|
||||
|
||||
|
||||
def test_async_client_turn_notification_methods_delegate_to_sync_client() -> None:
|
||||
"""Async turn routing methods should preserve sync-client registration semantics."""
|
||||
async def scenario() -> tuple[list[tuple[str, str]], Notification, str]:
|
||||
"""Record the sync-client calls made by async turn notification wrappers."""
|
||||
client = AsyncAppServerClient()
|
||||
event = Notification(
|
||||
method="unknown/direct",
|
||||
payload=UnknownNotification(params={"turnId": "turn-1"}),
|
||||
)
|
||||
completed = TurnCompletedNotification.model_validate(
|
||||
{
|
||||
"threadId": "thread-1",
|
||||
"turn": {"id": "turn-1", "items": [], "status": "completed"},
|
||||
}
|
||||
)
|
||||
calls: list[tuple[str, str]] = []
|
||||
|
||||
def fake_register(turn_id: str) -> None:
|
||||
"""Record turn registration through the wrapped sync client."""
|
||||
calls.append(("register", turn_id))
|
||||
|
||||
def fake_unregister(turn_id: str) -> None:
|
||||
"""Record turn unregistration through the wrapped sync client."""
|
||||
calls.append(("unregister", turn_id))
|
||||
|
||||
def fake_next(turn_id: str) -> Notification:
|
||||
"""Return one routed notification through the wrapped sync client."""
|
||||
calls.append(("next", turn_id))
|
||||
return event
|
||||
|
||||
def fake_wait(turn_id: str) -> TurnCompletedNotification:
|
||||
"""Return one completion through the wrapped sync client."""
|
||||
calls.append(("wait", turn_id))
|
||||
return completed
|
||||
|
||||
client._sync.register_turn_notifications = fake_register # type: ignore[method-assign]
|
||||
client._sync.unregister_turn_notifications = fake_unregister # type: ignore[method-assign]
|
||||
client._sync.next_turn_notification = fake_next # type: ignore[method-assign]
|
||||
client._sync.wait_for_turn_completed = fake_wait # type: ignore[method-assign]
|
||||
|
||||
client.register_turn_notifications("turn-1")
|
||||
next_event = await client.next_turn_notification("turn-1")
|
||||
completed_event = await client.wait_for_turn_completed("turn-1")
|
||||
client.unregister_turn_notifications("turn-1")
|
||||
|
||||
return calls, next_event, completed_event.turn.id
|
||||
|
||||
calls, next_event, completed_turn_id = asyncio.run(scenario())
|
||||
|
||||
assert (
|
||||
calls,
|
||||
next_event,
|
||||
completed_turn_id,
|
||||
) == (
|
||||
[
|
||||
("register", "turn-1"),
|
||||
("next", "turn-1"),
|
||||
("wait", "turn-1"),
|
||||
("unregister", "turn-1"),
|
||||
],
|
||||
Notification(
|
||||
method="unknown/direct",
|
||||
payload=UnknownNotification(params={"turnId": "turn-1"}),
|
||||
),
|
||||
"turn-1",
|
||||
)
|
||||
|
||||
|
||||
def test_async_stream_text_uses_sync_turn_routing() -> None:
|
||||
"""Async text streaming should consume the same per-turn routing path as sync."""
|
||||
async def scenario() -> tuple[list[tuple[str, str]], list[str]]:
|
||||
"""Record routing calls while streaming two deltas and one completion."""
|
||||
client = AsyncAppServerClient()
|
||||
notifications = [
|
||||
Notification(
|
||||
method="item/agentMessage/delta",
|
||||
payload=AgentMessageDeltaNotification.model_validate(
|
||||
{
|
||||
"delta": "first",
|
||||
"itemId": "item-1",
|
||||
"threadId": "thread-1",
|
||||
"turnId": "turn-1",
|
||||
}
|
||||
),
|
||||
),
|
||||
Notification(
|
||||
method="item/agentMessage/delta",
|
||||
payload=AgentMessageDeltaNotification.model_validate(
|
||||
{
|
||||
"delta": "second",
|
||||
"itemId": "item-2",
|
||||
"threadId": "thread-1",
|
||||
"turnId": "turn-1",
|
||||
}
|
||||
),
|
||||
),
|
||||
Notification(
|
||||
method="turn/completed",
|
||||
payload=TurnCompletedNotification.model_validate(
|
||||
{
|
||||
"threadId": "thread-1",
|
||||
"turn": {"id": "turn-1", "items": [], "status": "completed"},
|
||||
}
|
||||
),
|
||||
),
|
||||
]
|
||||
calls: list[tuple[str, str]] = []
|
||||
|
||||
def fake_turn_start(thread_id: str, text: str, *, params=None): # type: ignore[no-untyped-def]
|
||||
"""Return a started turn id while recording the request thread."""
|
||||
calls.append(("turn_start", thread_id))
|
||||
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
|
||||
|
||||
def fake_register(turn_id: str) -> None:
|
||||
"""Record stream registration for the started turn."""
|
||||
calls.append(("register", turn_id))
|
||||
|
||||
def fake_next(turn_id: str) -> Notification:
|
||||
"""Return the next queued turn notification."""
|
||||
calls.append(("next", turn_id))
|
||||
return notifications.pop(0)
|
||||
|
||||
def fake_unregister(turn_id: str) -> None:
|
||||
"""Record stream cleanup for the started turn."""
|
||||
calls.append(("unregister", turn_id))
|
||||
|
||||
client._sync.turn_start = fake_turn_start # type: ignore[method-assign]
|
||||
client._sync.register_turn_notifications = fake_register # type: ignore[method-assign]
|
||||
client._sync.next_turn_notification = fake_next # type: ignore[method-assign]
|
||||
client._sync.unregister_turn_notifications = fake_unregister # type: ignore[method-assign]
|
||||
|
||||
chunks = [chunk async for chunk in client.stream_text("thread-1", "hello")]
|
||||
return calls, [chunk.delta for chunk in chunks]
|
||||
|
||||
calls, deltas = asyncio.run(scenario())
|
||||
|
||||
assert (calls, deltas) == (
|
||||
[
|
||||
("turn_start", "thread-1"),
|
||||
("register", "turn-1"),
|
||||
("next", "turn-1"),
|
||||
("next", "turn-1"),
|
||||
("next", "turn-1"),
|
||||
("unregister", "turn-1"),
|
||||
],
|
||||
["first", "second"],
|
||||
)
|
||||
|
||||
@@ -4,13 +4,17 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from codex_app_server.client import AppServerClient, _params_dict
|
||||
from codex_app_server.generated.notification_registry import notification_turn_id
|
||||
from codex_app_server.generated.v2_all import (
|
||||
AgentMessageDeltaNotification,
|
||||
ApprovalsReviewer,
|
||||
ThreadListParams,
|
||||
ThreadResumeResponse,
|
||||
ThreadTokenUsageUpdatedNotification,
|
||||
TurnCompletedNotification,
|
||||
WarningNotification,
|
||||
)
|
||||
from codex_app_server.models import UnknownNotification
|
||||
from codex_app_server.models import Notification, UnknownNotification
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
@@ -46,6 +50,7 @@ def test_generated_v2_bundle_has_single_shared_plan_type_definition() -> None:
|
||||
|
||||
|
||||
def test_thread_resume_response_accepts_auto_review_reviewer() -> None:
|
||||
"""Generated response models should keep accepting the auto review enum value."""
|
||||
response = ThreadResumeResponse.model_validate(
|
||||
{
|
||||
"approvalPolicy": "on-request",
|
||||
@@ -62,6 +67,8 @@ def test_thread_resume_response_accepts_auto_review_reviewer() -> None:
|
||||
"id": "thread-1",
|
||||
"modelProvider": "openai",
|
||||
"preview": "",
|
||||
# The pinned runtime schema requires the session id on threads.
|
||||
"sessionId": "session-1",
|
||||
"source": "cli",
|
||||
"status": {"type": "idle"},
|
||||
"turns": [],
|
||||
@@ -128,3 +135,227 @@ def test_invalid_notification_payload_falls_back_to_unknown() -> None:
|
||||
|
||||
assert event.method == "thread/tokenUsage/updated"
|
||||
assert isinstance(event.payload, UnknownNotification)
|
||||
|
||||
|
||||
def test_generated_notification_turn_id_handles_known_payload_shapes() -> None:
|
||||
"""Generated routing metadata should cover direct, nested, and unscoped payloads."""
|
||||
direct = AgentMessageDeltaNotification.model_validate(
|
||||
{
|
||||
"delta": "hello",
|
||||
"itemId": "item-1",
|
||||
"threadId": "thread-1",
|
||||
"turnId": "turn-1",
|
||||
}
|
||||
)
|
||||
nested = TurnCompletedNotification.model_validate(
|
||||
{
|
||||
"threadId": "thread-1",
|
||||
"turn": {"id": "turn-2", "items": [], "status": "completed"},
|
||||
}
|
||||
)
|
||||
unscoped = WarningNotification(message="heads up")
|
||||
|
||||
assert [
|
||||
notification_turn_id(direct),
|
||||
notification_turn_id(nested),
|
||||
notification_turn_id(unscoped),
|
||||
] == ["turn-1", "turn-2", None]
|
||||
|
||||
|
||||
def test_turn_notification_router_demuxes_registered_turns() -> None:
|
||||
"""The router should deliver out-of-order turn events to the matching queues."""
|
||||
client = AppServerClient()
|
||||
client.register_turn_notifications("turn-1")
|
||||
client.register_turn_notifications("turn-2")
|
||||
|
||||
client._router.route_notification(
|
||||
client._coerce_notification(
|
||||
"item/agentMessage/delta",
|
||||
{
|
||||
"delta": "two",
|
||||
"itemId": "item-2",
|
||||
"threadId": "thread-1",
|
||||
"turnId": "turn-2",
|
||||
},
|
||||
)
|
||||
)
|
||||
client._router.route_notification(
|
||||
client._coerce_notification(
|
||||
"item/agentMessage/delta",
|
||||
{
|
||||
"delta": "one",
|
||||
"itemId": "item-1",
|
||||
"threadId": "thread-1",
|
||||
"turnId": "turn-1",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
first = client.next_turn_notification("turn-1")
|
||||
second = client.next_turn_notification("turn-2")
|
||||
|
||||
assert isinstance(first.payload, AgentMessageDeltaNotification)
|
||||
assert isinstance(second.payload, AgentMessageDeltaNotification)
|
||||
assert [
|
||||
(first.method, first.payload.delta),
|
||||
(second.method, second.payload.delta),
|
||||
] == [
|
||||
("item/agentMessage/delta", "one"),
|
||||
("item/agentMessage/delta", "two"),
|
||||
]
|
||||
|
||||
|
||||
def test_client_reader_routes_interleaved_turn_notifications_by_turn_id() -> None:
|
||||
"""Reader-loop routing should preserve order within each interleaved turn stream."""
|
||||
client = AppServerClient()
|
||||
client.register_turn_notifications("turn-1")
|
||||
client.register_turn_notifications("turn-2")
|
||||
|
||||
messages: list[dict[str, object]] = [
|
||||
{
|
||||
"method": "item/agentMessage/delta",
|
||||
"params": {
|
||||
"delta": "one-a",
|
||||
"itemId": "item-1",
|
||||
"threadId": "thread-1",
|
||||
"turnId": "turn-1",
|
||||
},
|
||||
},
|
||||
{
|
||||
"method": "item/agentMessage/delta",
|
||||
"params": {
|
||||
"delta": "two-a",
|
||||
"itemId": "item-2",
|
||||
"threadId": "thread-1",
|
||||
"turnId": "turn-2",
|
||||
},
|
||||
},
|
||||
{
|
||||
"method": "item/agentMessage/delta",
|
||||
"params": {
|
||||
"delta": "one-b",
|
||||
"itemId": "item-3",
|
||||
"threadId": "thread-1",
|
||||
"turnId": "turn-1",
|
||||
},
|
||||
},
|
||||
{
|
||||
"method": "item/agentMessage/delta",
|
||||
"params": {
|
||||
"delta": "two-b",
|
||||
"itemId": "item-4",
|
||||
"threadId": "thread-1",
|
||||
"turnId": "turn-2",
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def fake_read_message() -> dict[str, object]:
|
||||
"""Feed the reader loop a realistic interleaved stdout sequence."""
|
||||
if messages:
|
||||
return messages.pop(0)
|
||||
raise EOFError
|
||||
|
||||
client._read_message = fake_read_message # type: ignore[method-assign]
|
||||
client._reader_loop()
|
||||
|
||||
first_turn_events = [
|
||||
client.next_turn_notification("turn-1"),
|
||||
client.next_turn_notification("turn-1"),
|
||||
]
|
||||
second_turn_events = [
|
||||
client.next_turn_notification("turn-2"),
|
||||
client.next_turn_notification("turn-2"),
|
||||
]
|
||||
|
||||
first_turn_deltas = [
|
||||
event.payload.delta
|
||||
for event in first_turn_events
|
||||
if isinstance(event.payload, AgentMessageDeltaNotification)
|
||||
]
|
||||
second_turn_deltas = [
|
||||
event.payload.delta
|
||||
for event in second_turn_events
|
||||
if isinstance(event.payload, AgentMessageDeltaNotification)
|
||||
]
|
||||
assert (first_turn_deltas, second_turn_deltas) == (
|
||||
["one-a", "one-b"],
|
||||
["two-a", "two-b"],
|
||||
)
|
||||
|
||||
|
||||
def test_turn_notification_router_buffers_events_before_registration() -> None:
|
||||
"""Early turn events should be replayed once their TurnHandle registers."""
|
||||
client = AppServerClient()
|
||||
client._router.route_notification(
|
||||
client._coerce_notification(
|
||||
"item/agentMessage/delta",
|
||||
{
|
||||
"delta": "early",
|
||||
"itemId": "item-1",
|
||||
"threadId": "thread-1",
|
||||
"turnId": "turn-1",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
client.register_turn_notifications("turn-1")
|
||||
event = client.next_turn_notification("turn-1")
|
||||
|
||||
assert isinstance(event.payload, AgentMessageDeltaNotification)
|
||||
assert (event.method, event.payload.delta) == (
|
||||
"item/agentMessage/delta",
|
||||
"early",
|
||||
)
|
||||
|
||||
|
||||
def test_turn_notification_router_clears_unregistered_turn_when_completed() -> None:
|
||||
"""A completed unregistered turn should not leave a pending queue behind."""
|
||||
client = AppServerClient()
|
||||
client._router.route_notification(
|
||||
client._coerce_notification(
|
||||
"item/agentMessage/delta",
|
||||
{
|
||||
"delta": "early",
|
||||
"itemId": "item-1",
|
||||
"threadId": "thread-1",
|
||||
"turnId": "turn-1",
|
||||
},
|
||||
)
|
||||
)
|
||||
client._router.route_notification(
|
||||
client._coerce_notification(
|
||||
"turn/completed",
|
||||
{
|
||||
"threadId": "thread-1",
|
||||
"turn": {"id": "turn-1", "items": [], "status": "completed"},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert client._router._pending_turn_notifications == {}
|
||||
|
||||
|
||||
def test_turn_notification_router_routes_unknown_turn_notifications() -> None:
|
||||
"""Unknown notifications should still route when their raw params carry a turn id."""
|
||||
client = AppServerClient()
|
||||
client.register_turn_notifications("turn-1")
|
||||
client.register_turn_notifications("turn-2")
|
||||
|
||||
client._router.route_notification(
|
||||
Notification(
|
||||
method="unknown/direct",
|
||||
payload=UnknownNotification(params={"turnId": "turn-1"}),
|
||||
)
|
||||
)
|
||||
client._router.route_notification(
|
||||
Notification(
|
||||
method="unknown/nested",
|
||||
payload=UnknownNotification(params={"turn": {"id": "turn-2"}}),
|
||||
)
|
||||
)
|
||||
|
||||
first = client.next_turn_notification("turn-1")
|
||||
second = client.next_turn_notification("turn-2")
|
||||
|
||||
assert [first.method, second.method] == ["unknown/direct", "unknown/nested"]
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.metadata
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
@@ -14,6 +15,7 @@ GENERATED_TARGETS = [
|
||||
|
||||
|
||||
def _snapshot_target(root: Path, rel_path: Path) -> dict[str, bytes] | bytes | None:
|
||||
"""Capture one generated artifact so regeneration drift is easy to compare."""
|
||||
target = root / rel_path
|
||||
if not target.exists():
|
||||
return None
|
||||
@@ -28,16 +30,22 @@ def _snapshot_target(root: Path, rel_path: Path) -> dict[str, bytes] | bytes | N
|
||||
|
||||
|
||||
def _snapshot_targets(root: Path) -> dict[str, dict[str, bytes] | bytes | None]:
|
||||
"""Capture all checked-in generated artifacts before and after regeneration."""
|
||||
return {
|
||||
str(rel_path): _snapshot_target(root, rel_path) for rel_path in GENERATED_TARGETS
|
||||
str(rel_path): _snapshot_target(root, rel_path)
|
||||
for rel_path in GENERATED_TARGETS
|
||||
}
|
||||
|
||||
|
||||
def test_generated_files_are_up_to_date():
|
||||
"""Regenerating from the pinned runtime package should leave artifacts unchanged."""
|
||||
before = _snapshot_targets(ROOT)
|
||||
|
||||
# Regenerate contract artifacts via single maintenance entrypoint.
|
||||
# Regenerate contract artifacts via the pinned runtime package, not a local
|
||||
# app-server binary from the checkout or CI environment.
|
||||
assert importlib.metadata.version("openai-codex-cli-bin") == "0.131.0a4"
|
||||
env = os.environ.copy()
|
||||
env.pop("CODEX_EXEC_PATH", None)
|
||||
python_bin = str(Path(sys.executable).parent)
|
||||
env["PATH"] = f"{python_bin}{os.pathsep}{env.get('PATH', '')}"
|
||||
|
||||
|
||||
@@ -82,6 +82,7 @@ def _item_completed_notification(
|
||||
text: str = "final text",
|
||||
phase: MessagePhase | None = None,
|
||||
) -> Notification:
|
||||
"""Build a realistic completed-item notification accepted by generated models."""
|
||||
item: dict[str, object] = {
|
||||
"id": "item-1",
|
||||
"text": text,
|
||||
@@ -93,6 +94,8 @@ def _item_completed_notification(
|
||||
method="item/completed",
|
||||
payload=ItemCompletedNotification.model_validate(
|
||||
{
|
||||
# The pinned runtime schema requires completion timestamps.
|
||||
"completedAtMs": 1,
|
||||
"item": item,
|
||||
"threadId": thread_id,
|
||||
"turnId": turn_id,
|
||||
@@ -226,54 +229,79 @@ def test_async_codex_initializes_only_once_under_concurrency() -> None:
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_turn_stream_rejects_second_active_consumer() -> None:
|
||||
def test_turn_streams_can_consume_multiple_turns_on_one_client() -> None:
|
||||
"""Two sync TurnHandle streams should advance independently on one client."""
|
||||
client = AppServerClient()
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
_delta_notification(turn_id="turn-1"),
|
||||
_completed_notification(turn_id="turn-1"),
|
||||
]
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
notifications: dict[str, deque[Notification]] = {
|
||||
"turn-1": deque(
|
||||
[
|
||||
_delta_notification(turn_id="turn-1", text="one"),
|
||||
_completed_notification(turn_id="turn-1"),
|
||||
]
|
||||
),
|
||||
"turn-2": deque(
|
||||
[
|
||||
_delta_notification(turn_id="turn-2", text="two"),
|
||||
_completed_notification(turn_id="turn-2"),
|
||||
]
|
||||
),
|
||||
}
|
||||
client.next_turn_notification = lambda turn_id: notifications[turn_id].popleft() # type: ignore[method-assign]
|
||||
|
||||
first_stream = TurnHandle(client, "thread-1", "turn-1").stream()
|
||||
assert next(first_stream).method == "item/agentMessage/delta"
|
||||
|
||||
second_stream = TurnHandle(client, "thread-1", "turn-2").stream()
|
||||
with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"):
|
||||
next(second_stream)
|
||||
assert next(second_stream).method == "item/agentMessage/delta"
|
||||
assert next(first_stream).method == "turn/completed"
|
||||
assert next(second_stream).method == "turn/completed"
|
||||
|
||||
first_stream.close()
|
||||
second_stream.close()
|
||||
|
||||
|
||||
def test_async_turn_stream_rejects_second_active_consumer() -> None:
|
||||
def test_async_turn_streams_can_consume_multiple_turns_on_one_client() -> None:
|
||||
"""Two async TurnHandle streams should advance independently on one client."""
|
||||
async def scenario() -> None:
|
||||
"""Interleave two async streams backed by separate per-turn queues."""
|
||||
codex = AsyncCodex()
|
||||
|
||||
async def fake_ensure_initialized() -> None:
|
||||
"""Avoid starting a real app-server process for this stream test."""
|
||||
return None
|
||||
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
_delta_notification(turn_id="turn-1"),
|
||||
_completed_notification(turn_id="turn-1"),
|
||||
]
|
||||
)
|
||||
notifications: dict[str, deque[Notification]] = {
|
||||
"turn-1": deque(
|
||||
[
|
||||
_delta_notification(turn_id="turn-1", text="one"),
|
||||
_completed_notification(turn_id="turn-1"),
|
||||
]
|
||||
),
|
||||
"turn-2": deque(
|
||||
[
|
||||
_delta_notification(turn_id="turn-2", text="two"),
|
||||
_completed_notification(turn_id="turn-2"),
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
async def fake_next_notification() -> Notification:
|
||||
return notifications.popleft()
|
||||
async def fake_next_notification(turn_id: str) -> Notification:
|
||||
"""Return the next notification from the requested per-turn queue."""
|
||||
return notifications[turn_id].popleft()
|
||||
|
||||
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
|
||||
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
|
||||
codex._client.next_turn_notification = fake_next_notification # type: ignore[method-assign]
|
||||
|
||||
first_stream = AsyncTurnHandle(codex, "thread-1", "turn-1").stream()
|
||||
assert (await anext(first_stream)).method == "item/agentMessage/delta"
|
||||
|
||||
second_stream = AsyncTurnHandle(codex, "thread-1", "turn-2").stream()
|
||||
with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"):
|
||||
await anext(second_stream)
|
||||
assert (await anext(second_stream)).method == "item/agentMessage/delta"
|
||||
assert (await anext(first_stream)).method == "turn/completed"
|
||||
assert (await anext(second_stream)).method == "turn/completed"
|
||||
|
||||
await first_stream.aclose()
|
||||
await second_stream.aclose()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
@@ -285,7 +313,7 @@ def test_turn_run_returns_completed_turn_payload() -> None:
|
||||
_completed_notification(),
|
||||
]
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign]
|
||||
|
||||
result = TurnHandle(client, "thread-1", "turn-1").run()
|
||||
|
||||
@@ -305,7 +333,7 @@ def test_thread_run_accepts_string_input_and_returns_run_result() -> None:
|
||||
_completed_notification(),
|
||||
]
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign]
|
||||
seen: dict[str, object] = {}
|
||||
|
||||
def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202
|
||||
@@ -338,7 +366,7 @@ def test_thread_run_uses_last_completed_assistant_message_as_final_response() ->
|
||||
_completed_notification(),
|
||||
]
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign]
|
||||
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
||||
turn=SimpleNamespace(id="turn-1")
|
||||
)
|
||||
@@ -363,7 +391,7 @@ def test_thread_run_preserves_empty_last_assistant_message() -> None:
|
||||
_completed_notification(),
|
||||
]
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign]
|
||||
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
||||
turn=SimpleNamespace(id="turn-1")
|
||||
)
|
||||
@@ -394,7 +422,7 @@ def test_thread_run_prefers_explicit_final_answer_over_later_commentary() -> Non
|
||||
_completed_notification(),
|
||||
]
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign]
|
||||
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
||||
turn=SimpleNamespace(id="turn-1")
|
||||
)
|
||||
@@ -420,7 +448,7 @@ def test_thread_run_returns_none_when_only_commentary_messages_complete() -> Non
|
||||
_completed_notification(),
|
||||
]
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign]
|
||||
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
||||
turn=SimpleNamespace(id="turn-1")
|
||||
)
|
||||
@@ -438,7 +466,7 @@ def test_thread_run_raises_on_failed_turn() -> None:
|
||||
_completed_notification(status="failed", error_message="boom"),
|
||||
]
|
||||
)
|
||||
client.next_notification = notifications.popleft # type: ignore[method-assign]
|
||||
client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign]
|
||||
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
||||
turn=SimpleNamespace(id="turn-1")
|
||||
)
|
||||
@@ -447,11 +475,60 @@ def test_thread_run_raises_on_failed_turn() -> None:
|
||||
Thread(client, "thread-1").run("hello")
|
||||
|
||||
|
||||
def test_stream_text_registers_and_consumes_turn_notifications() -> None:
|
||||
"""stream_text should register, consume, and unregister one turn queue."""
|
||||
client = AppServerClient()
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
_delta_notification(text="first"),
|
||||
_delta_notification(text="second"),
|
||||
_completed_notification(),
|
||||
]
|
||||
)
|
||||
calls: list[tuple[str, str]] = []
|
||||
client.turn_start = lambda thread_id, input_items, *, params=None: SimpleNamespace( # noqa: ARG005,E731
|
||||
turn=SimpleNamespace(id="turn-1")
|
||||
)
|
||||
|
||||
def fake_register(turn_id: str) -> None:
|
||||
"""Record registration for the turn created by stream_text."""
|
||||
calls.append(("register", turn_id))
|
||||
|
||||
def fake_next(turn_id: str) -> Notification:
|
||||
"""Return the next queued notification for stream_text."""
|
||||
calls.append(("next", turn_id))
|
||||
return notifications.popleft()
|
||||
|
||||
def fake_unregister(turn_id: str) -> None:
|
||||
"""Record cleanup for the turn created by stream_text."""
|
||||
calls.append(("unregister", turn_id))
|
||||
|
||||
client.register_turn_notifications = fake_register # type: ignore[method-assign]
|
||||
client.next_turn_notification = fake_next # type: ignore[method-assign]
|
||||
client.unregister_turn_notifications = fake_unregister # type: ignore[method-assign]
|
||||
|
||||
chunks = list(client.stream_text("thread-1", "hello"))
|
||||
|
||||
assert ([chunk.delta for chunk in chunks], calls) == (
|
||||
["first", "second"],
|
||||
[
|
||||
("register", "turn-1"),
|
||||
("next", "turn-1"),
|
||||
("next", "turn-1"),
|
||||
("next", "turn-1"),
|
||||
("unregister", "turn-1"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_async_thread_run_accepts_string_input_and_returns_run_result() -> None:
|
||||
"""Async Thread.run should normalize string input and collect routed results."""
|
||||
async def scenario() -> None:
|
||||
"""Feed item, usage, and completion events through the async turn stream."""
|
||||
codex = AsyncCodex()
|
||||
|
||||
async def fake_ensure_initialized() -> None:
|
||||
"""Avoid starting a real app-server process for this run test."""
|
||||
return None
|
||||
|
||||
item_notification = _item_completed_notification(text="Hello async.")
|
||||
@@ -466,17 +543,19 @@ def test_async_thread_run_accepts_string_input_and_returns_run_result() -> None:
|
||||
seen: dict[str, object] = {}
|
||||
|
||||
async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202
|
||||
"""Capture normalized input and return a synthetic turn id."""
|
||||
seen["thread_id"] = thread_id
|
||||
seen["wire_input"] = wire_input
|
||||
seen["params"] = params
|
||||
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
|
||||
|
||||
async def fake_next_notification() -> Notification:
|
||||
async def fake_next_notification(_turn_id: str) -> Notification:
|
||||
"""Return the next queued notification for the synthetic turn."""
|
||||
return notifications.popleft()
|
||||
|
||||
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
|
||||
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
|
||||
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
|
||||
codex._client.next_turn_notification = fake_next_notification # type: ignore[method-assign]
|
||||
|
||||
result = await AsyncThread(codex, "thread-1").run("hello")
|
||||
|
||||
@@ -491,15 +570,24 @@ def test_async_thread_run_accepts_string_input_and_returns_run_result() -> None:
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_async_thread_run_uses_last_completed_assistant_message_as_final_response() -> None:
|
||||
def test_async_thread_run_uses_last_completed_assistant_message_as_final_response() -> (
|
||||
None
|
||||
):
|
||||
"""Async run should use the last final assistant message as the response text."""
|
||||
async def scenario() -> None:
|
||||
"""Feed two completed agent messages through the async per-turn stream."""
|
||||
codex = AsyncCodex()
|
||||
|
||||
async def fake_ensure_initialized() -> None:
|
||||
"""Avoid starting a real app-server process for this run test."""
|
||||
return None
|
||||
|
||||
first_item_notification = _item_completed_notification(text="First async message")
|
||||
second_item_notification = _item_completed_notification(text="Second async message")
|
||||
first_item_notification = _item_completed_notification(
|
||||
text="First async message"
|
||||
)
|
||||
second_item_notification = _item_completed_notification(
|
||||
text="Second async message"
|
||||
)
|
||||
notifications: deque[Notification] = deque(
|
||||
[
|
||||
first_item_notification,
|
||||
@@ -509,14 +597,16 @@ def test_async_thread_run_uses_last_completed_assistant_message_as_final_respons
|
||||
)
|
||||
|
||||
async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001
|
||||
"""Return a synthetic turn id after AsyncThread.run builds input."""
|
||||
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
|
||||
|
||||
async def fake_next_notification() -> Notification:
|
||||
async def fake_next_notification(_turn_id: str) -> Notification:
|
||||
"""Return the next queued notification for that synthetic turn."""
|
||||
return notifications.popleft()
|
||||
|
||||
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
|
||||
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
|
||||
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
|
||||
codex._client.next_turn_notification = fake_next_notification # type: ignore[method-assign]
|
||||
|
||||
result = await AsyncThread(codex, "thread-1").run("hello")
|
||||
|
||||
@@ -530,10 +620,13 @@ def test_async_thread_run_uses_last_completed_assistant_message_as_final_respons
|
||||
|
||||
|
||||
def test_async_thread_run_returns_none_when_only_commentary_messages_complete() -> None:
|
||||
"""Async Thread.run should ignore commentary-only messages for final text."""
|
||||
async def scenario() -> None:
|
||||
"""Feed a commentary item and completion through the async turn stream."""
|
||||
codex = AsyncCodex()
|
||||
|
||||
async def fake_ensure_initialized() -> None:
|
||||
"""Avoid starting a real app-server process for this run test."""
|
||||
return None
|
||||
|
||||
commentary_notification = _item_completed_notification(
|
||||
@@ -548,14 +641,16 @@ def test_async_thread_run_returns_none_when_only_commentary_messages_complete()
|
||||
)
|
||||
|
||||
async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001
|
||||
"""Return a synthetic turn id for commentary-only output."""
|
||||
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
|
||||
|
||||
async def fake_next_notification() -> Notification:
|
||||
async def fake_next_notification(_turn_id: str) -> Notification:
|
||||
"""Return the next queued commentary/completion notification."""
|
||||
return notifications.popleft()
|
||||
|
||||
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
|
||||
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
|
||||
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
|
||||
codex._client.next_turn_notification = fake_next_notification # type: ignore[method-assign]
|
||||
|
||||
result = await AsyncThread(codex, "thread-1").run("hello")
|
||||
|
||||
|
||||
@@ -54,6 +54,7 @@ def test_package_includes_py_typed_marker() -> None:
|
||||
|
||||
|
||||
def test_generated_public_signatures_are_snake_case_and_typed() -> None:
|
||||
"""Generated convenience methods should expose typed Pythonic keyword names."""
|
||||
expected = {
|
||||
Codex.thread_start: [
|
||||
"approval_policy",
|
||||
@@ -70,6 +71,7 @@ def test_generated_public_signatures_are_snake_case_and_typed() -> None:
|
||||
"service_name",
|
||||
"service_tier",
|
||||
"session_start_source",
|
||||
"thread_source",
|
||||
],
|
||||
Codex.thread_list: [
|
||||
"archived",
|
||||
@@ -108,6 +110,7 @@ def test_generated_public_signatures_are_snake_case_and_typed() -> None:
|
||||
"model_provider",
|
||||
"sandbox",
|
||||
"service_tier",
|
||||
"thread_source",
|
||||
],
|
||||
Thread.turn: [
|
||||
"approval_policy",
|
||||
@@ -148,6 +151,7 @@ def test_generated_public_signatures_are_snake_case_and_typed() -> None:
|
||||
"service_name",
|
||||
"service_tier",
|
||||
"session_start_source",
|
||||
"thread_source",
|
||||
],
|
||||
AsyncCodex.thread_list: [
|
||||
"archived",
|
||||
@@ -186,6 +190,7 @@ def test_generated_public_signatures_are_snake_case_and_typed() -> None:
|
||||
"model_provider",
|
||||
"sandbox",
|
||||
"service_tier",
|
||||
"thread_source",
|
||||
],
|
||||
AsyncThread.turn: [
|
||||
"approval_policy",
|
||||
|
||||
22
sdk/python/uv.lock
generated
22
sdk/python/uv.lock
generated
@@ -3,9 +3,12 @@ revision = 3
|
||||
requires-python = ">=3.10"
|
||||
|
||||
[options]
|
||||
exclude-newer = "2026-04-20T18:19:27.620299Z"
|
||||
exclude-newer = "2026-05-02T06:28:46.47929Z"
|
||||
exclude-newer-span = "P7D"
|
||||
|
||||
[options.exclude-newer-package]
|
||||
openai-codex-cli-bin = "2026-05-10T00:00:00Z"
|
||||
|
||||
[[package]]
|
||||
name = "annotated-types"
|
||||
version = "0.7.0"
|
||||
@@ -279,9 +282,10 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "openai-codex-app-server-sdk"
|
||||
version = "0.116.0a1"
|
||||
version = "0.131.0a4"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "openai-codex-cli-bin" },
|
||||
{ name = "pydantic" },
|
||||
]
|
||||
|
||||
@@ -295,12 +299,26 @@ dev = [
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "datamodel-code-generator", marker = "extra == 'dev'", specifier = "==0.31.2" },
|
||||
{ name = "openai-codex-cli-bin", specifier = "==0.131.0a4" },
|
||||
{ name = "pydantic", specifier = ">=2.12" },
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0" },
|
||||
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.11" },
|
||||
]
|
||||
provides-extras = ["dev"]
|
||||
|
||||
[[package]]
|
||||
name = "openai-codex-cli-bin"
|
||||
version = "0.131.0a4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/6b/9f/f9fc4bb1b2b7a20d4d65143ebb4c4dcd2301a718183b539ecb5b1c0ac3ec/openai_codex_cli_bin-0.131.0a4-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:db0f3cb7dda310641ac04fbaf3f128693a3817ab83ae59b67a3c9c74bd53f8b8", size = 88367585, upload-time = "2026-05-09T06:14:09.453Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/dc/39/eb95ed0e8156669e895a192dec760be07dabe891c3c6340f7c6487b9a976/openai_codex_cli_bin-0.131.0a4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6cae5af6edca7f6d3f0bcbbd93cfc8a6dc3e33fb5955af21ae492b6d5d0dcb72", size = 79245567, upload-time = "2026-05-09T06:14:13.581Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0c/92/ade176fa78d746d5ff7a6e371d64740c0d95ab299b0dd58a5404b89b3915/openai_codex_cli_bin-0.131.0a4-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:5728f9887baf62d7e72f4f242093b3ff81e26c81d80d346fe1eef7eda6838aa8", size = 77758628, upload-time = "2026-05-09T06:14:18.374Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/28/e6/bfe6c65f8e3e5499f71b24c3b6e8d07e4d426543d25e429b9b141b544e5f/openai_codex_cli_bin-0.131.0a4-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:d7a47fd3667fbcc216593839c202deffa056e9b3d46c6933e72594d461f4fea0", size = 84535509, upload-time = "2026-05-09T06:14:22.851Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bd/b7/53dc094a691ab6f2ca079e8e865b122843809ac4fad51cac4d59021e599d/openai_codex_cli_bin-0.131.0a4-py3-none-win_amd64.whl", hash = "sha256:c61bcf029672494c4c7fdc8567dbaa659a48bb75641d91c2ade27c1e46803434", size = 88185543, upload-time = "2026-05-09T06:14:27.282Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/82/99/e0852ffcf9b4d2794fef83e0c3a267b3c773a776f136e9f7ce19f0c8df42/openai_codex_cli_bin-0.131.0a4-py3-none-win_arm64.whl", hash = "sha256:bbde750186861f102e346ac066f4e9608f515f7b71b16a6e8b7ef1ddc02a97a5", size = 81196380, upload-time = "2026-05-09T06:14:32.103Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
version = "26.1"
|
||||
|
||||
Reference in New Issue
Block a user