Compare commits

..

1 Commits

Author SHA1 Message Date
Shaqayeq
bd664bb752 Add Python app-server SDK
Co-authored-by: Codex <noreply@openai.com>
2026-03-11 23:00:33 -07:00
26 changed files with 11318 additions and 120 deletions

View File

@@ -1,115 +0,0 @@
use codex_protocol::protocol::RolloutLine;
use schemars::r#gen::SchemaSettings;
use serde_json::Map;
use serde_json::Value;
use std::any::TypeId;
use std::collections::BTreeMap;
use std::collections::HashSet;
use std::io;
use std::path::PathBuf;
use ts_rs::TS;
use ts_rs::TypeVisitor;
const GENERATED_TS_HEADER: &str = "// GENERATED CODE! DO NOT MODIFY BY HAND!\n\n";
const TS_RS_NOTE: &str = "// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.\n";
const JSON_SCHEMA_FILENAME: &str = "rollout-line.schema.json";
const TYPESCRIPT_FILENAME: &str = "rollout-line.schema.ts";
/// Writes the `RolloutLine` JSON Schema and TypeScript bundle into an output directory.
///
/// The directory is the first command-line argument when one is given; otherwise it is
/// `out/rollout-line-schema` under the parent of this crate's manifest directory (or under
/// the manifest directory itself when it has no parent).
///
/// # Errors
///
/// Returns the I/O error from creating the directory, generating the schema, or writing
/// either file.
fn main() -> io::Result<()> {
    let out_dir = match std::env::args_os().nth(1) {
        Some(arg) => PathBuf::from(arg),
        None => {
            let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
            manifest_dir
                .parent()
                .unwrap_or(&manifest_dir)
                .join("out/rollout-line-schema")
        }
    };
    std::fs::create_dir_all(&out_dir)?;

    let json_path = out_dir.join(JSON_SCHEMA_FILENAME);
    std::fs::write(&json_path, rollout_line_schema_json()?)?;

    let typescript_path = out_dir.join(TYPESCRIPT_FILENAME);
    std::fs::write(&typescript_path, generate_typescript_bundle::<RolloutLine>())?;

    // Report both artifacts only after both writes succeeded.
    for written in [&json_path, &typescript_path] {
        println!("Wrote {}", written.display());
    }
    Ok(())
}
fn rollout_line_schema_json() -> io::Result<Vec<u8>> {
let schema = SchemaSettings::draft07()
.into_generator()
.into_root_schema_for::<RolloutLine>();
let value = serde_json::to_value(schema).map_err(io::Error::other)?;
let value = canonicalize_json(&value);
serde_json::to_vec_pretty(&value).map_err(io::Error::other)
}
/// Returns a deep copy of `value` in which every object's entries are inserted in
/// ascending key order.
///
/// Arrays keep their element order (each element is canonicalized recursively); scalars
/// are cloned unchanged.
fn canonicalize_json(value: &Value) -> Value {
    match value {
        Value::Object(map) => {
            let mut entries: Vec<(&String, &Value)> = map.iter().collect();
            entries.sort_by(|left, right| left.0.cmp(right.0));

            let mut sorted = Map::with_capacity(map.len());
            for (key, child) in entries {
                sorted.insert(key.clone(), canonicalize_json(child));
            }
            Value::Object(sorted)
        }
        Value::Array(items) => {
            let canonical_items = items.iter().map(canonicalize_json).collect();
            Value::Array(canonical_items)
        }
        scalar => scalar.clone(),
    }
}
/// Builds a single TypeScript source string containing the declaration of `T` and of every
/// type reached through `collect_typescript_declarations`.
///
/// Declarations are ordered by the `BTreeMap` key they were collected under, separated by
/// blank lines, normalized to `\n` line endings, and prefixed with the generated-code
/// header and the ts-rs note.
fn generate_typescript_bundle<T: TS + 'static + ?Sized>() -> String {
    let mut seen = HashSet::new();
    let mut declarations = BTreeMap::new();
    collect_typescript_declarations::<T>(&mut declarations, &mut seen);

    let ordered: Vec<String> = declarations.into_values().collect();
    let joined = ordered.join("\n\n");
    // Collapse CRLF first so a lone-CR pass cannot double the newline.
    let body = joined.replace("\r\n", "\n").replace('\r', "\n");

    format!("{GENERATED_TS_HEADER}{TS_RS_NOTE}\n{body}\n")
}
/// Records the exported TypeScript declaration for `T`, then walks `T`'s dependencies.
///
/// Types without an output path are skipped, and a type already present in `seen` is not
/// processed again, which also stops recursion on cyclic type graphs. The declaration is
/// stored under `T`'s output path; docs (when present) precede the `export` keyword, and
/// all text is normalized to `\n` line endings.
fn collect_typescript_declarations<T: TS + 'static + ?Sized>(
    declarations: &mut BTreeMap<PathBuf, String>,
    seen: &mut HashSet<TypeId>,
) {
    let output_path = match T::output_path() {
        Some(path) => path,
        None => return,
    };
    let first_visit = seen.insert(TypeId::of::<T>());
    if !first_visit {
        return;
    }

    let mut declaration = T::docs()
        .map(|docs| docs.replace("\r\n", "\n").replace('\r', "\n"))
        .unwrap_or_default();
    declaration.push_str("export ");
    let decl = T::decl().replace("\r\n", "\n").replace('\r', "\n");
    declaration.push_str(&decl);

    let key: PathBuf = output_path.components().collect();
    declarations.insert(key, declaration);

    // Recurse through the visitor so each dependency goes through the same bookkeeping.
    let mut visitor = TypeScriptDeclarationCollector { declarations, seen };
    T::visit_dependencies(&mut visitor);
}
/// `TypeVisitor` adapter that forwards every visited type to
/// `collect_typescript_declarations`, sharing the caller's accumulators.
struct TypeScriptDeclarationCollector<'a> {
    /// Collected declaration text, keyed by output path.
    declarations: &'a mut BTreeMap<PathBuf, String>,
    /// Type ids that have already been collected.
    seen: &'a mut HashSet<TypeId>,
}

impl TypeVisitor for TypeScriptDeclarationCollector<'_> {
    /// Collects the declaration of `T` (and, recursively, of its dependencies).
    fn visit<T: TS + 'static + ?Sized>(&mut self) {
        // Explicit reborrows keep the collector usable for the next visited type.
        let declarations = &mut *self.declarations;
        let seen = &mut *self.seen;
        collect_typescript_declarations::<T>(declarations, seen);
    }
}

View File

@@ -2423,7 +2423,7 @@ pub enum TruncationPolicy {
Tokens(usize),
}
#[derive(Serialize, Deserialize, Clone, JsonSchema, TS)]
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct RolloutLine {
pub timestamp: String,
#[serde(flatten)]

View File

@@ -78,10 +78,6 @@ mcp-server-run *args:
write-config-schema:
cargo run -p codex-core --bin codex-write-config-schema
# Generate RolloutLine schema artifacts.
write-rollout-line-schema *args:
cargo run -p codex-protocol --bin codex-write-rollout-line-schema -- "$@"
# Regenerate vendored app-server protocol schema artifacts.
write-app-server-schema *args:
cargo run -p codex-app-server-protocol --bin write_schema_fixtures -- "$@"

View File

@@ -0,0 +1,9 @@
# Codex CLI Runtime for Python SDK
Platform-specific runtime package consumed by the published `codex-app-server-sdk`.
This package is staged during release so the SDK can pin an exact Codex CLI
version without checking platform binaries into the repo.
`codex-cli-bin` is intentionally wheel-only. Do not build or publish an sdist
for this package.

View File

@@ -0,0 +1,15 @@
from __future__ import annotations
from hatchling.builders.hooks.plugin.interface import BuildHookInterface
class RuntimeBuildHook(BuildHookInterface):
    """Hatch build hook that rejects sdist builds and marks wheels as platform-specific."""

    def initialize(self, version: str, build_data: dict[str, object]) -> None:
        """Reject the ``sdist`` target; otherwise flag the build as non-pure with an inferred tag.

        Args:
            version: Supplied by the hook interface; unused here.
            build_data: Mutable build settings; ``pure_python`` is set to ``False`` and
                ``infer_tag`` to ``True``.

        Raises:
            RuntimeError: If the current build target is ``sdist``.
        """
        del version  # unused; the hook signature requires it
        building_sdist = self.target_name == "sdist"
        if building_sdist:
            raise RuntimeError(
                "codex-cli-bin is wheel-only; build and publish platform wheels only."
            )
        build_data.update(pure_python=False, infer_tag=True)

View File

@@ -0,0 +1,45 @@
[build-system]
requires = ["hatchling>=1.24.0"]
build-backend = "hatchling.build"
[project]
name = "codex-cli-bin"
version = "0.0.0-dev"
description = "Pinned Codex CLI runtime for the Python SDK"
readme = "README.md"
requires-python = ">=3.10"
license = { text = "Apache-2.0" }
authors = [{ name = "OpenAI" }]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
]
[project.urls]
Homepage = "https://github.com/openai/codex"
Repository = "https://github.com/openai/codex"
Issues = "https://github.com/openai/codex/issues"
[tool.hatch.build]
exclude = [
".venv/**",
".pytest_cache/**",
"dist/**",
"build/**",
]
[tool.hatch.build.targets.wheel]
packages = ["src/codex_cli_bin"]
include = ["src/codex_cli_bin/bin/**"]
[tool.hatch.build.targets.wheel.hooks.custom]
[tool.hatch.build.targets.sdist]
[tool.hatch.build.targets.sdist.hooks.custom]

View File

@@ -0,0 +1,19 @@
from __future__ import annotations
import os
from pathlib import Path
PACKAGE_NAME = "codex-cli-bin"
def bundled_codex_path() -> Path:
    """Return the path of the codex binary shipped inside this package.

    The binary is looked up as ``bin/codex`` (``bin/codex.exe`` when ``os.name`` is
    ``"nt"``) next to this module.

    Raises:
        FileNotFoundError: If no regular file exists at that location.
    """
    binary_name = "codex.exe" if os.name == "nt" else "codex"
    candidate = Path(__file__).resolve().parent.joinpath("bin", binary_name)
    if candidate.is_file():
        return candidate
    raise FileNotFoundError(
        f"{PACKAGE_NAME} is installed but missing its packaged codex binary at {candidate}"
    )
__all__ = ["PACKAGE_NAME", "bundled_codex_path"]

95
sdk/python/README.md Normal file
View File

@@ -0,0 +1,95 @@
# Codex App Server Python SDK (Experimental)
Experimental Python SDK for `codex app-server` JSON-RPC v2 over stdio, with a small default surface optimized for real scripts and apps.
The generated wire-model layer is currently sourced from the bundled v2 schema and exposed as Pydantic models with snake_case Python fields that serialize back to the app-server's camelCase wire format.
## Install
```bash
cd sdk/python
python -m pip install -e .
```
Published SDK builds pin an exact `codex-cli-bin` runtime dependency. For local
repo development, pass `AppServerConfig(codex_bin=...)` to point at a local
build explicitly.
## Quickstart
```python
from codex_app_server import Codex, TextInput
with Codex() as codex:
thread = codex.thread_start(model="gpt-5")
result = thread.turn(TextInput("Say hello in one sentence.")).run()
print(result.text)
```
## Docs map
- Golden path tutorial: `docs/getting-started.md`
- API reference (signatures + behavior): `docs/api-reference.md`
- Common decisions and pitfalls: `docs/faq.md`
- Runnable examples index: `examples/README.md`
- Jupyter walkthrough notebook: `notebooks/sdk_walkthrough.ipynb`
## Examples
Start here:
```bash
cd sdk/python
python examples/01_quickstart_constructor/sync.py
python examples/01_quickstart_constructor/async.py
```
## Runtime packaging
The repo no longer checks `codex` binaries into `sdk/python`.
Published SDK builds are pinned to an exact `codex-cli-bin` package version,
and that runtime package carries the platform-specific binary for the target
wheel.
For local repo development, the checked-in `sdk/python-runtime` package is only
a template for staged release artifacts. Editable installs should use an
explicit `codex_bin` override instead.
## Maintainer workflow
```bash
cd sdk/python
python scripts/update_sdk_artifacts.py generate-types
python scripts/update_sdk_artifacts.py \
stage-sdk \
/tmp/codex-python-release/codex-app-server-sdk \
--runtime-version 1.2.3
python scripts/update_sdk_artifacts.py \
stage-runtime \
/tmp/codex-python-release/codex-cli-bin \
/path/to/codex \
--runtime-version 1.2.3
```
This supports the CI release flow:
- run `generate-types` before packaging
- stage `codex-app-server-sdk` once with an exact `codex-cli-bin==...` dependency
- stage `codex-cli-bin` on each supported platform runner with the same pinned runtime version
- build and publish `codex-cli-bin` as platform wheels only; do not publish an sdist
## Compatibility and versioning
- Package: `codex-app-server-sdk`
- Runtime package: `codex-cli-bin`
- Current SDK version in this repo: `0.2.0`
- Python: `>=3.10`
- Target protocol: Codex `app-server` JSON-RPC v2
- Recommendation: keep SDK and `codex` CLI reasonably up to date together
## Notes
- `Codex()` is eager and performs startup + `initialize` in the constructor.
- Use context managers (`with Codex() as codex:`) to ensure shutdown.
- For transient overload, use `codex_app_server.retry.retry_on_overload`.

77
sdk/python/docs/faq.md Normal file
View File

@@ -0,0 +1,77 @@
# FAQ
## Thread vs turn
- A `Thread` is conversation state.
- A `Turn` is one model execution inside that thread.
- Multi-turn chat means multiple turns on the same `Thread`.
## `run()` vs `stream()`
- `Turn.run()` is the easiest path. It consumes events until completion and returns `TurnResult`.
- `Turn.stream()` yields raw notifications (`Notification`) so you can react event-by-event.
Choose `run()` for most apps. Choose `stream()` for progress UIs, custom timeout logic, or custom parsing.
## Sync vs async clients
- `Codex` is the minimal sync SDK and best default.
- `AsyncAppServerClient` wraps the sync transport with `asyncio.to_thread(...)` for async-friendly call sites.
If your app is not already async, stay with `Codex`.
## `thread(...)` vs `thread_resume(...)`
- `codex.thread(thread_id)` only binds a local helper to an existing thread ID.
- `codex.thread_resume(thread_id, ...)` performs a `thread/resume` RPC and can apply overrides (model, instructions, sandbox, etc.).
Use `thread(...)` for simple continuation. Use `thread_resume(...)` when you need explicit resume semantics or override fields.
## Why does the constructor fail?
`Codex()` is eager: it starts transport and calls `initialize` in `__init__`.
Common causes:
- published runtime package (`codex-cli-bin`) is not installed
- local `codex_bin` override points to a missing file
- local auth/session is missing
- incompatible/old app-server
Maintainers stage releases by building the SDK once and the runtime once per
platform with the same pinned runtime version. Publish `codex-cli-bin` as
platform wheels only; do not publish an sdist:
```bash
cd sdk/python
python scripts/update_sdk_artifacts.py generate-types
python scripts/update_sdk_artifacts.py \
stage-sdk \
/tmp/codex-python-release/codex-app-server-sdk \
--runtime-version 1.2.3
python scripts/update_sdk_artifacts.py \
stage-runtime \
/tmp/codex-python-release/codex-cli-bin \
/path/to/codex \
--runtime-version 1.2.3
```
## Why does a turn "hang"?
A turn is complete only when `turn/completed` arrives for that turn ID.
- `run()` waits for this automatically.
- With `stream()`, make sure you keep consuming notifications until completion.
## How do I retry safely?
Use `retry_on_overload(...)` for transient overload failures (`ServerBusyError`).
Do not blindly retry all errors. For `InvalidParamsError` or `MethodNotFoundError`, fix inputs/version compatibility instead.
## Common pitfalls
- Starting a new thread for every prompt when you wanted continuity.
- Forgetting to `close()` (or not using `with Codex() as codex:`).
- Ignoring `TurnResult.status` and `TurnResult.error`.
- Mixing SDK input classes with raw dicts incorrectly in minimal API paths.

View File

@@ -0,0 +1,75 @@
# Getting Started
This is the fastest path from install to a multi-turn thread using the minimal SDK surface.
## 1) Install
From repo root:
```bash
cd sdk/python
python -m pip install -e .
```
Requirements:
- Python `>=3.10`
- installed `codex-cli-bin` runtime package, or an explicit `codex_bin` override
- Local Codex auth/session configured
## 2) Run your first turn
```python
from codex_app_server import Codex, TextInput
with Codex() as codex:
print("Server:", codex.metadata.server_name, codex.metadata.server_version)
thread = codex.thread_start(model="gpt-5")
result = thread.turn(TextInput("Say hello in one sentence.")).run()
print("Thread:", result.thread_id)
print("Turn:", result.turn_id)
print("Status:", result.status)
print("Text:", result.text)
```
What happened:
- `Codex()` started and initialized `codex app-server`.
- `thread_start(...)` created a thread.
- `turn(...).run()` consumed events until `turn/completed` and returned a `TurnResult`.
## 3) Continue the same thread (multi-turn)
```python
from codex_app_server import Codex, TextInput
with Codex() as codex:
thread = codex.thread_start(model="gpt-5")
first = thread.turn(TextInput("Summarize Rust ownership in 2 bullets.")).run()
second = thread.turn(TextInput("Now explain it to a Python developer.")).run()
print("first:", first.text)
print("second:", second.text)
```
## 4) Resume an existing thread
```python
from codex_app_server import Codex, TextInput
THREAD_ID = "thr_123" # replace with a real id
with Codex() as codex:
thread = codex.thread(THREAD_ID)
result = thread.turn(TextInput("Continue where we left off.")).run()
print(result.text)
```
## 5) Next stops
- API surface and signatures: `docs/api-reference.md`
- Common decisions/pitfalls: `docs/faq.md`
- End-to-end runnable examples: `examples/README.md`

62
sdk/python/pyproject.toml Normal file
View File

@@ -0,0 +1,62 @@
[build-system]
requires = ["hatchling>=1.24.0"]
build-backend = "hatchling.build"
[project]
name = "codex-app-server-sdk"
version = "0.2.0"
description = "Python SDK for Codex app-server v2"
readme = "README.md"
requires-python = ">=3.10"
license = { text = "Apache-2.0" }
authors = [{ name = "OpenClaw Assistant" }]
keywords = ["codex", "json-rpc", "sdk", "llm", "app-server"]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Software Development :: Libraries :: Python Modules",
]
dependencies = ["pydantic>=2.12"]
[project.urls]
Homepage = "https://github.com/openai/codex"
Repository = "https://github.com/openai/codex"
Issues = "https://github.com/openai/codex/issues"
[project.optional-dependencies]
dev = ["pytest>=8.0", "datamodel-code-generator==0.31.2", "ruff>=0.11"]
[tool.hatch.build]
exclude = [
".venv/**",
".venv2/**",
".pytest_cache/**",
"dist/**",
"build/**",
]
[tool.hatch.build.targets.wheel]
packages = ["src/codex_app_server"]
include = [
"src/codex_app_server/py.typed",
]
[tool.hatch.build.targets.sdist]
include = [
"src/codex_app_server/**",
"README.md",
"CHANGELOG.md",
"CONTRIBUTING.md",
"RELEASE_CHECKLIST.md",
"pyproject.toml",
]
[tool.pytest.ini_options]
addopts = "-q"
testpaths = ["tests"]

View File

@@ -0,0 +1,998 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import importlib
import json
import platform
import re
import shutil
import stat
import subprocess
import sys
import tempfile
import types
import typing
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Sequence, get_args, get_origin
def repo_root() -> Path:
    """Return the directory three levels above this script's own directory."""
    script_path = Path(__file__).resolve()
    return script_path.parents[3]


def sdk_root() -> Path:
    """Return ``<repo>/sdk/python``."""
    return repo_root().joinpath("sdk", "python")


def python_runtime_root() -> Path:
    """Return ``<repo>/sdk/python-runtime``."""
    return repo_root().joinpath("sdk", "python-runtime")


def schema_bundle_path() -> Path:
    """Return the path of the bundled app-server v2 JSON schema file."""
    return repo_root().joinpath(
        "codex-rs",
        "app-server-protocol",
        "schema",
        "json",
        "codex_app_server_protocol.v2.schemas.json",
    )


def schema_root_dir() -> Path:
    """Return the directory holding the app-server protocol JSON schemas."""
    return repo_root().joinpath("codex-rs", "app-server-protocol", "schema", "json")
def _is_windows() -> bool:
return platform.system().lower().startswith("win")
def runtime_binary_name() -> str:
return "codex.exe" if _is_windows() else "codex"
def staged_runtime_bin_path(root: Path) -> Path:
return root / "src" / "codex_cli_bin" / "bin" / runtime_binary_name()
def run(cmd: list[str], cwd: Path) -> None:
    """Run ``cmd`` in ``cwd``.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero status.
    """
    working_dir = str(cwd)
    subprocess.run(cmd, check=True, cwd=working_dir)


def run_python_module(module: str, args: list[str], cwd: Path) -> None:
    """Run ``python -m module args...`` in ``cwd`` using the current interpreter."""
    command = [sys.executable, "-m", module]
    command.extend(args)
    run(command, cwd)
def current_sdk_version() -> str:
    """Return the ``version`` value declared in the SDK's ``pyproject.toml``.

    The file is decoded as UTF-8 explicitly (TOML documents are UTF-8) rather than with
    the platform's locale encoding.

    Raises:
        RuntimeError: If no ``version = "..."`` line is found.
    """
    pyproject_text = (sdk_root() / "pyproject.toml").read_text(encoding="utf-8")
    match = re.search(
        r'^version = "([^"]+)"$',
        pyproject_text,
        flags=re.MULTILINE,
    )
    if match is None:
        raise RuntimeError("Could not determine Python SDK version from pyproject.toml")
    return match.group(1)
def _copy_package_tree(src: Path, dst: Path) -> None:
if dst.exists():
if dst.is_dir():
shutil.rmtree(dst)
else:
dst.unlink()
shutil.copytree(
src,
dst,
ignore=shutil.ignore_patterns(
".venv",
".venv2",
".pytest_cache",
"__pycache__",
"build",
"dist",
"*.pyc",
),
)
def _rewrite_project_version(pyproject_text: str, version: str) -> str:
updated, count = re.subn(
r'^version = "[^"]+"$',
f'version = "{version}"',
pyproject_text,
count=1,
flags=re.MULTILINE,
)
if count != 1:
raise RuntimeError("Could not rewrite project version in pyproject.toml")
return updated
def _rewrite_sdk_runtime_dependency(pyproject_text: str, runtime_version: str) -> str:
match = re.search(r"^dependencies = \[(.*?)\]$", pyproject_text, flags=re.MULTILINE)
if match is None:
raise RuntimeError(
"Could not find dependencies array in sdk/python/pyproject.toml"
)
raw_items = [item.strip() for item in match.group(1).split(",") if item.strip()]
raw_items = [item for item in raw_items if "codex-cli-bin" not in item]
raw_items.append(f'"codex-cli-bin=={runtime_version}"')
replacement = "dependencies = [\n " + ",\n ".join(raw_items) + ",\n]"
return pyproject_text[: match.start()] + replacement + pyproject_text[match.end() :]
def stage_python_sdk_package(
    staging_dir: Path, sdk_version: str, runtime_version: str
) -> Path:
    """Stage a releasable copy of the SDK package in ``staging_dir``.

    The SDK tree is copied, any ``src/codex_app_server/bin`` directory is removed from
    the copy, and the staged ``pyproject.toml`` gets ``sdk_version`` as its version plus
    an exact ``codex-cli-bin==runtime_version`` dependency.

    ``pyproject.toml`` is read and written as UTF-8 explicitly (TOML documents are UTF-8)
    instead of with the platform's locale encoding.

    Returns:
        ``staging_dir``.
    """
    _copy_package_tree(sdk_root(), staging_dir)
    sdk_bin_dir = staging_dir / "src" / "codex_app_server" / "bin"
    if sdk_bin_dir.exists():
        shutil.rmtree(sdk_bin_dir)
    pyproject_path = staging_dir / "pyproject.toml"
    pyproject_text = pyproject_path.read_text(encoding="utf-8")
    pyproject_text = _rewrite_project_version(pyproject_text, sdk_version)
    pyproject_text = _rewrite_sdk_runtime_dependency(pyproject_text, runtime_version)
    pyproject_path.write_text(pyproject_text, encoding="utf-8")
    return staging_dir
def stage_python_runtime_package(
    staging_dir: Path, runtime_version: str, binary_path: Path
) -> Path:
    """Stage the runtime package in ``staging_dir`` with the codex binary copied in.

    The runtime template tree is copied, the staged ``pyproject.toml`` version is set to
    ``runtime_version``, and ``binary_path`` is copied to the staged binary location. On
    non-Windows platforms the copied binary is made executable for user, group and other.

    ``pyproject.toml`` is read and written as UTF-8 explicitly (TOML documents are UTF-8)
    instead of with the platform's locale encoding.

    Returns:
        ``staging_dir``.
    """
    _copy_package_tree(python_runtime_root(), staging_dir)
    pyproject_path = staging_dir / "pyproject.toml"
    pyproject_path.write_text(
        _rewrite_project_version(
            pyproject_path.read_text(encoding="utf-8"), runtime_version
        ),
        encoding="utf-8",
    )
    out_bin = staged_runtime_bin_path(staging_dir)
    out_bin.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(binary_path, out_bin)
    if not _is_windows():
        out_bin.chmod(
            out_bin.stat().st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH
        )
    return staging_dir
def _flatten_string_enum_one_of(definition: dict[str, Any]) -> bool:
branches = definition.get("oneOf")
if not isinstance(branches, list) or not branches:
return False
enum_values: list[str] = []
for branch in branches:
if not isinstance(branch, dict):
return False
if branch.get("type") != "string":
return False
enum = branch.get("enum")
if not isinstance(enum, list) or len(enum) != 1 or not isinstance(enum[0], str):
return False
extra_keys = set(branch) - {"type", "enum", "description", "title"}
if extra_keys:
return False
enum_values.append(enum[0])
description = definition.get("description")
title = definition.get("title")
definition.clear()
definition["type"] = "string"
definition["enum"] = enum_values
if isinstance(description, str):
definition["description"] = description
if isinstance(title, str):
definition["title"] = title
return True
DISCRIMINATOR_KEYS = ("type", "method", "mode", "state", "status", "role", "reason")
def _to_pascal_case(value: str) -> str:
parts = re.split(r"[^0-9A-Za-z]+", value)
compact = "".join(part[:1].upper() + part[1:] for part in parts if part)
return compact or "Value"
def _string_literal(value: Any) -> str | None:
if not isinstance(value, dict):
return None
const = value.get("const")
if isinstance(const, str):
return const
enum = value.get("enum")
if isinstance(enum, list) and enum and len(enum) == 1 and isinstance(enum[0], str):
return enum[0]
return None
def _enum_literals(value: Any) -> list[str] | None:
if not isinstance(value, dict):
return None
enum = value.get("enum")
if (
not isinstance(enum, list)
or not enum
or not all(isinstance(item, str) for item in enum)
):
return None
return list(enum)
def _literal_from_property(props: dict[str, Any], key: str) -> str | None:
return _string_literal(props.get(key))
def _variant_definition_name(base: str, variant: dict[str, Any]) -> str | None:
    """Derive a stable class-name title for one inline union branch of ``base``.

    datamodel-code-generator invents numbered helper names for inline union branches
    unless they carry a stable, unique title up front, so a title is derived from the
    branch's identifying shape. The first rule that applies wins:

    1. the first discriminator property (in ``DISCRIMINATOR_KEYS`` order) pinning a
       string literal;
    2. a branch with exactly one property: that property's literal, else its name;
    3. a branch whose ``required`` list holds exactly one string;
    4. a string enum: its single value, or ``<base>Value`` when it has several.

    Returns ``None`` when no rule applies.
    """
    props = variant.get("properties")
    if isinstance(props, dict):
        discriminator_literals = (
            _literal_from_property(props, key) for key in DISCRIMINATOR_KEYS
        )
        literal = next(
            (found for found in discriminator_literals if found is not None), None
        )
        if literal is not None:
            # ClientRequest branches are suffixed with plain "Request"; every other
            # base (ServerRequest, *Notification, EventMsg, ...) is appended verbatim.
            suffix = "Request" if base == "ClientRequest" else base
            return f"{_to_pascal_case(literal)}{suffix}"
        if len(props) == 1:
            only_key = next(iter(props))
            only_literal = _string_literal(props[only_key])
            return f"{_to_pascal_case(only_literal or only_key)}{base}"

    required = variant.get("required")
    has_single_required_name = (
        isinstance(required, list)
        and len(required) == 1
        and isinstance(required[0], str)
    )
    if has_single_required_name:
        return f"{_to_pascal_case(required[0])}{base}"

    enum_literals = _enum_literals(variant)
    if enum_literals is None:
        return None
    if len(enum_literals) == 1:
        return f"{_to_pascal_case(enum_literals[0])}{base}"
    return f"{base}Value"
def _variant_collision_key(
    base: str, variant: dict[str, Any], generated_name: str
) -> str:
    """Describe a union branch as a ``|``-joined string for collision error messages.

    The description lists the base, the generated name, and whichever identifying
    features the branch has: discriminator literals, a sole property, a sole required
    field, and string enum values.
    """
    parts = [f"base={base}", f"generated={generated_name}"]

    props = variant.get("properties")
    if isinstance(props, dict):
        keyed_literals = (
            (key, _literal_from_property(props, key)) for key in DISCRIMINATOR_KEYS
        )
        parts.extend(
            f"{key}={literal}" for key, literal in keyed_literals if literal is not None
        )
        if len(props) == 1:
            sole_property = next(iter(props))
            parts.append(f"only_property={sole_property}")

    required = variant.get("required")
    has_single_required_name = (
        isinstance(required, list)
        and len(required) == 1
        and isinstance(required[0], str)
    )
    if has_single_required_name:
        parts.append(f"required_only={required[0]}")

    enum_literals = _enum_literals(variant)
    if enum_literals is not None:
        joined_values = "|".join(enum_literals)
        parts.append(f"enum={joined_values}")

    return "|".join(parts)
def _set_discriminator_titles(props: dict[str, Any], owner: str) -> None:
    """Title untitled discriminator properties as ``<owner><PascalCaseKey>``, in place.

    Only properties named in ``DISCRIMINATOR_KEYS`` that pin a string literal and do not
    already have a ``title`` are changed.
    """
    for key in DISCRIMINATOR_KEYS:
        prop = props.get(key)
        needs_title = (
            isinstance(prop, dict)
            and "title" not in prop
            and _string_literal(prop) is not None
        )
        if needs_title:
            prop["title"] = f"{owner}{_to_pascal_case(key)}"
# Normalize the titles of the branches in one `oneOf`/`anyOf` list, in place.
#
# `variants` is the list of union branches; `base` is the name passed on by the
# caller and is handed to `_variant_definition_name` to derive titles. When `base`
# is falsy no title is generated.
# Raises RuntimeError when a generated title is already used by another branch.
# NOTE(review): indentation is not visible in this view, so the exact nesting of
# the trailing statements is assumed from the data flow — confirm against the file.
def _annotate_variant_list(variants: list[Any], base: str | None) -> None:
# Titles already present on dict branches; used to detect naming collisions.
seen = {
variant["title"]
for variant in variants
if isinstance(variant, dict) and isinstance(variant.get("title"), str)
}
for variant in variants:
# Non-dict branches carry no title and are skipped.
if not isinstance(variant, dict):
continue
variant_name = variant.get("title")
generated_name = _variant_definition_name(base, variant) if base else None
# Retitle when a name could be generated and the current title is missing,
# contains "/", or simply differs from the generated one.
if generated_name is not None and (
not isinstance(variant_name, str)
or "/" in variant_name
or variant_name != generated_name
):
# Titles like `Thread/startedNotification` sanitize poorly in
# Python, and envelope titles like `ErrorNotification` collide
# with their payload model names. Rewrite them before codegen so
# we get `ThreadStartedServerNotification` instead of `...1`.
if generated_name in seen and variant_name != generated_name:
raise RuntimeError(
"Variant title naming collision detected: "
f"{_variant_collision_key(base or '<root>', variant, generated_name)}"
)
variant["title"] = generated_name
seen.add(generated_name)
variant_name = generated_name
# With a string title in hand, title the branch's discriminator properties after it.
if isinstance(variant_name, str):
props = variant.get("properties")
if isinstance(props, dict):
_set_discriminator_titles(props, variant_name)
# Recurse into the branch itself, keeping the same base name.
_annotate_schema(variant, base)
# Recursively walk a JSON-schema value and normalize titles in place.
#
# `value` may be any JSON value: lists are walked element by element, dicts are
# processed as schema nodes, anything else is ignored. `base` is the name handed
# to `_annotate_variant_list` for union branches; entries under `definitions` /
# `$defs` are walked with their own key as the base.
def _annotate_schema(value: Any, base: str | None = None) -> None:
if isinstance(value, list):
for item in value:
_annotate_schema(item, base)
return
if not isinstance(value, dict):
return
# A titled node with properties gets its discriminator properties titled after it.
owner = value.get("title")
props = value.get("properties")
if isinstance(owner, str) and isinstance(props, dict):
_set_discriminator_titles(props, owner)
one_of = value.get("oneOf")
if isinstance(one_of, list):
# Walk nested unions recursively so every inline branch gets the same
# title normalization treatment before we hand the bundle to Python
# codegen.
_annotate_variant_list(one_of, base)
any_of = value.get("anyOf")
if isinstance(any_of, list):
_annotate_variant_list(any_of, base)
# Each named definition becomes the base name for the unions nested inside it.
definitions = value.get("definitions")
if isinstance(definitions, dict):
for name, schema in definitions.items():
_annotate_schema(schema, name if isinstance(name, str) else base)
defs = value.get("$defs")
if isinstance(defs, dict):
for name, schema in defs.items():
_annotate_schema(schema, name if isinstance(name, str) else base)
# Every remaining child is walked too; the four keys handled above are skipped
# so they are not processed a second time.
for key, child in value.items():
if key in {"oneOf", "anyOf", "definitions", "$defs"}:
continue
_annotate_schema(child, base)
def _normalized_schema_bundle_text() -> str:
    """Load the schema bundle, normalize it for codegen, and return it as JSON text.

    Top-level definitions that are ``oneOf`` lists of single-value string enums are
    flattened, then titles are normalized across the whole schema so
    datamodel-code-generator can map branches to stable class names instead of
    anonymous numbered helpers. The result is indented, key-sorted JSON ending in a
    newline.
    """
    schema = json.loads(schema_bundle_path().read_text())

    definitions = schema.get("definitions", {})
    if isinstance(definitions, dict):
        dict_definitions = [
            entry for entry in definitions.values() if isinstance(entry, dict)
        ]
        for entry in dict_definitions:
            _flatten_string_enum_one_of(entry)

    _annotate_schema(schema)
    rendered = json.dumps(schema, indent=2, sort_keys=True)
    return f"{rendered}\n"
def generate_v2_all() -> None:
    """Regenerate ``generated/v2_all.py`` from the normalized schema bundle.

    Any ``generated/v2_all`` package directory is removed first, the normalized bundle
    is written to a temporary directory, and ``datamodel_code_generator`` is run on it
    from the SDK root. Timestamp comments in the output are normalized afterwards.
    """
    generated_dir = sdk_root() / "src" / "codex_app_server" / "generated"
    out_path = generated_dir / "v2_all.py"
    legacy_package_dir = generated_dir / "v2_all"
    if legacy_package_dir.exists():
        shutil.rmtree(legacy_package_dir)
    generated_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): the generator targets Python 3.11 while the SDK's pyproject
    # declares requires-python >=3.10 — confirm the emitted code stays 3.10-compatible.
    codegen_options = [
        "--output-model-type",
        "pydantic_v2.BaseModel",
        "--target-python-version",
        "3.11",
        "--use-standard-collections",
        "--enum-field-as-literal",
        "one",
        "--field-constraints",
        "--use-default-kwarg",
        "--snake-case-field",
        "--allow-population-by-field-name",
        # Once the schema prepass has assigned stable titles, tell the
        # generator to prefer those titles as the emitted class names.
        "--use-title-as-name",
        "--use-annotated",
        "--use-union-operator",
        "--disable-timestamp",
        # Keep the generated file formatted deterministically so the
        # checked-in artifact only changes when the schema does.
        "--formatters",
        "ruff-format",
    ]

    with tempfile.TemporaryDirectory() as td:
        normalized_bundle = Path(td) / schema_bundle_path().name
        normalized_bundle.write_text(_normalized_schema_bundle_text())
        io_options = [
            "--input",
            str(normalized_bundle),
            "--input-file-type",
            "jsonschema",
            "--output",
            str(out_path),
        ]
        run_python_module(
            "datamodel_code_generator",
            [*io_options, *codegen_options],
            cwd=sdk_root(),
        )
    _normalize_generated_timestamps(out_path)
def _notification_specs() -> list[tuple[str, str]]:
    """Return sorted ``(method, params class name)`` pairs for server notifications.

    Pairs come from the ``oneOf`` branches of ``ServerNotification.json``. A branch is
    used only if its ``method`` enum holds exactly one string, its ``params`` is a
    ``#/definitions/`` reference, and the referenced name appears in the generated
    ``v2_all.py`` as a class or an assignment.
    """
    schema_text = (schema_root_dir() / "ServerNotification.json").read_text()
    server_notifications = json.loads(schema_text)
    variants = server_notifications.get("oneOf", [])
    generated_path = sdk_root() / "src" / "codex_app_server" / "generated" / "v2_all.py"
    generated_source = generated_path.read_text()

    specs: list[tuple[str, str]] = []
    for variant in variants:
        properties = variant.get("properties", {})
        method_meta = properties.get("method", {})
        params_meta = properties.get("params", {})

        methods = method_meta.get("enum", [])
        if len(methods) != 1:
            continue
        method = methods[0]
        if not isinstance(method, str):
            continue

        ref = params_meta.get("$ref")
        if not (isinstance(ref, str) and ref.startswith("#/definitions/")):
            continue
        class_name = ref.rsplit("/", 1)[-1]

        emitted = (
            f"class {class_name}(" in generated_source
            or f"{class_name} =" in generated_source
        )
        if not emitted:
            # Skip schema variants that are not emitted into the generated v2 surface.
            continue
        specs.append((method, class_name))

    specs.sort()
    return specs
def generate_notification_registry() -> None:
    """Write ``generated/notification_registry.py``.

    The module imports each notification params class from ``.v2_all`` and defines
    ``NOTIFICATION_MODELS``, which maps each notification method name to its class.
    """
    out = sdk_root().joinpath(
        "src", "codex_app_server", "generated", "notification_registry.py"
    )
    specs = _notification_specs()
    class_names = sorted({class_name for _, class_name in specs})

    header_lines = [
        "# Auto-generated by scripts/update_sdk_artifacts.py",
        "# DO NOT EDIT MANUALLY.",
        "",
        "from __future__ import annotations",
        "",
        "from pydantic import BaseModel",
        "",
    ]
    import_lines = [f"from .v2_all import {class_name}" for class_name in class_names]
    mapping_lines = [f' "{method}": {class_name},' for method, class_name in specs]

    lines = [
        *header_lines,
        *import_lines,
        "",
        "NOTIFICATION_MODELS: dict[str, type[BaseModel]] = {",
        *mapping_lines,
        "}",
        "",
    ]
    out.write_text("\n".join(lines))
def _normalize_generated_timestamps(root: Path) -> None:
timestamp_re = re.compile(r"^#\s+timestamp:\s+.+$", flags=re.MULTILINE)
py_files = [root] if root.is_file() else sorted(root.rglob("*.py"))
for py_file in py_files:
content = py_file.read_text()
normalized = timestamp_re.sub("# timestamp: <normalized>", content)
if normalized != content:
py_file.write_text(normalized)
FIELD_ANNOTATION_OVERRIDES: dict[str, str] = {
# Keep public API typed without falling back to `Any`.
"config": "JsonObject",
"output_schema": "JsonObject",
}
@dataclass(slots=True)
class PublicFieldSpec:
wire_name: str
py_name: str
annotation: str
required: bool
@dataclass(frozen=True)
class CliOps:
generate_types: Callable[[], None]
stage_python_sdk_package: Callable[[Path, str, str], Path]
stage_python_runtime_package: Callable[[Path, str, Path], Path]
current_sdk_version: Callable[[], str]
def _annotation_to_source(annotation: Any) -> str:
origin = get_origin(annotation)
if origin is typing.Annotated:
return _annotation_to_source(get_args(annotation)[0])
if origin in (typing.Union, types.UnionType):
parts: list[str] = []
for arg in get_args(annotation):
rendered = _annotation_to_source(arg)
if rendered not in parts:
parts.append(rendered)
return " | ".join(parts)
if origin is list:
args = get_args(annotation)
item = _annotation_to_source(args[0]) if args else "Any"
return f"list[{item}]"
if origin is dict:
args = get_args(annotation)
key = _annotation_to_source(args[0]) if args else "str"
val = _annotation_to_source(args[1]) if len(args) > 1 else "Any"
return f"dict[{key}, {val}]"
if annotation is Any or annotation is typing.Any:
return "Any"
if annotation is None or annotation is type(None):
return "None"
if isinstance(annotation, type):
if annotation.__module__ == "builtins":
return annotation.__name__
return annotation.__name__
return repr(annotation)
def _camel_to_snake(name: str) -> str:
head = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", head).lower()
def _load_public_fields(
    module_name: str, class_name: str, *, exclude: set[str] | None = None
) -> list[PublicFieldSpec]:
    """Describe the fields of a generated params model as public parameters.

    Args:
        module_name: Importable module that defines the model.
        class_name: Name of the model class; its ``model_fields`` mapping is read.
        exclude: Field names to leave out of the result.

    Returns:
        One spec per remaining field, in ``model_fields`` order. Fields listed
        in ``FIELD_ANNOTATION_OVERRIDES`` use the override annotation, with
        ``" | None"`` appended when the field is optional.
    """
    skipped = exclude or set()
    model_cls = getattr(importlib.import_module(module_name), class_name)

    specs: list[PublicFieldSpec] = []
    for field_name, info in model_cls.model_fields.items():
        if field_name in skipped:
            continue
        is_required = info.is_required()
        rendered = _annotation_to_source(info.annotation)
        forced = FIELD_ANNOTATION_OVERRIDES.get(field_name)
        if forced is not None:
            rendered = forced if is_required else f"{forced} | None"
        spec = PublicFieldSpec(
            wire_name=field_name,
            py_name=field_name,
            annotation=rendered,
            required=is_required,
        )
        specs.append(spec)
    return specs
def _kw_signature_lines(fields: list[PublicFieldSpec]) -> list[str]:
lines: list[str] = []
for field in fields:
default = "" if field.required else " = None"
lines.append(f" {field.py_name}: {field.annotation}{default},")
return lines
def _model_arg_lines(
fields: list[PublicFieldSpec], *, indent: str = " "
) -> list[str]:
return [f"{indent}{field.wire_name}={field.py_name}," for field in fields]
def _replace_generated_block(source: str, block_name: str, body: str) -> str:
start_tag = f" # BEGIN GENERATED: {block_name}"
end_tag = f" # END GENERATED: {block_name}"
pattern = re.compile(rf"(?s){re.escape(start_tag)}\n.*?\n{re.escape(end_tag)}")
replacement = f"{start_tag}\n{body.rstrip()}\n{end_tag}"
updated, count = pattern.subn(replacement, source, count=1)
if count != 1:
raise RuntimeError(f"Could not update generated block: {block_name}")
return updated
def _render_codex_block(
    thread_start_fields: list[PublicFieldSpec],
    thread_list_fields: list[PublicFieldSpec],
    resume_fields: list[PublicFieldSpec],
    fork_fields: list[PublicFieldSpec],
) -> str:
    """Render the synchronous thread-management methods as class-body source.

    Each ``*_fields`` list supplies the keyword parameters of the matching
    method and the arguments forwarded to its params model.

    Returns:
        Newline-joined source for ``thread_start``, ``thread_list``,
        ``thread_resume``, ``thread_fork``, ``thread_archive`` and
        ``thread_unarchive``.
    """
    # Template lines are indented as members of a class: `def` at 4 spaces,
    # statements at 8, continuation arguments at 12. Every statement must be
    # nested deeper than its `def`, otherwise the emitted block is not valid
    # Python.
    lines = [
        "    def thread_start(",
        "        self,",
        "        *,",
        *_kw_signature_lines(thread_start_fields),
        "    ) -> Thread:",
        "        params = ThreadStartParams(",
        *_model_arg_lines(thread_start_fields),
        "        )",
        "        started = self._client.thread_start(params)",
        "        return Thread(self._client, started.thread.id)",
        "",
        "    def thread_list(",
        "        self,",
        "        *,",
        *_kw_signature_lines(thread_list_fields),
        "    ) -> ThreadListResponse:",
        "        params = ThreadListParams(",
        *_model_arg_lines(thread_list_fields),
        "        )",
        "        return self._client.thread_list(params)",
        "",
        "    def thread_resume(",
        "        self,",
        "        thread_id: str,",
        "        *,",
        *_kw_signature_lines(resume_fields),
        "    ) -> Thread:",
        "        params = ThreadResumeParams(",
        "            thread_id=thread_id,",
        *_model_arg_lines(resume_fields),
        "        )",
        "        resumed = self._client.thread_resume(thread_id, params)",
        "        return Thread(self._client, resumed.thread.id)",
        "",
        "    def thread_fork(",
        "        self,",
        "        thread_id: str,",
        "        *,",
        *_kw_signature_lines(fork_fields),
        "    ) -> Thread:",
        "        params = ThreadForkParams(",
        "            thread_id=thread_id,",
        *_model_arg_lines(fork_fields),
        "        )",
        "        forked = self._client.thread_fork(thread_id, params)",
        "        return Thread(self._client, forked.thread.id)",
        "",
        "    def thread_archive(self, thread_id: str) -> ThreadArchiveResponse:",
        "        return self._client.thread_archive(thread_id)",
        "",
        "    def thread_unarchive(self, thread_id: str) -> Thread:",
        "        unarchived = self._client.thread_unarchive(thread_id)",
        "        return Thread(self._client, unarchived.thread.id)",
    ]
    return "\n".join(lines)
def _render_async_codex_block(
    thread_start_fields: list[PublicFieldSpec],
    thread_list_fields: list[PublicFieldSpec],
    resume_fields: list[PublicFieldSpec],
    fork_fields: list[PublicFieldSpec],
) -> str:
    """Render the asynchronous thread-management methods as class-body source.

    Each ``*_fields`` list supplies the keyword parameters of the matching
    method and the arguments forwarded to its params model. Every emitted
    method awaits ``self._ensure_initialized()`` before calling the client.

    Returns:
        Newline-joined source for the async ``thread_start``, ``thread_list``,
        ``thread_resume``, ``thread_fork``, ``thread_archive`` and
        ``thread_unarchive`` methods.
    """
    # Template lines are indented as members of a class: `def` at 4 spaces,
    # statements at 8, continuation arguments at 12. Every statement must be
    # nested deeper than its `def`, otherwise the emitted block is not valid
    # Python.
    lines = [
        "    async def thread_start(",
        "        self,",
        "        *,",
        *_kw_signature_lines(thread_start_fields),
        "    ) -> AsyncThread:",
        "        await self._ensure_initialized()",
        "        params = ThreadStartParams(",
        *_model_arg_lines(thread_start_fields),
        "        )",
        "        started = await self._client.thread_start(params)",
        "        return AsyncThread(self, started.thread.id)",
        "",
        "    async def thread_list(",
        "        self,",
        "        *,",
        *_kw_signature_lines(thread_list_fields),
        "    ) -> ThreadListResponse:",
        "        await self._ensure_initialized()",
        "        params = ThreadListParams(",
        *_model_arg_lines(thread_list_fields),
        "        )",
        "        return await self._client.thread_list(params)",
        "",
        "    async def thread_resume(",
        "        self,",
        "        thread_id: str,",
        "        *,",
        *_kw_signature_lines(resume_fields),
        "    ) -> AsyncThread:",
        "        await self._ensure_initialized()",
        "        params = ThreadResumeParams(",
        "            thread_id=thread_id,",
        *_model_arg_lines(resume_fields),
        "        )",
        "        resumed = await self._client.thread_resume(thread_id, params)",
        "        return AsyncThread(self, resumed.thread.id)",
        "",
        "    async def thread_fork(",
        "        self,",
        "        thread_id: str,",
        "        *,",
        *_kw_signature_lines(fork_fields),
        "    ) -> AsyncThread:",
        "        await self._ensure_initialized()",
        "        params = ThreadForkParams(",
        "            thread_id=thread_id,",
        *_model_arg_lines(fork_fields),
        "        )",
        "        forked = await self._client.thread_fork(thread_id, params)",
        "        return AsyncThread(self, forked.thread.id)",
        "",
        "    async def thread_archive(self, thread_id: str) -> ThreadArchiveResponse:",
        "        await self._ensure_initialized()",
        "        return await self._client.thread_archive(thread_id)",
        "",
        "    async def thread_unarchive(self, thread_id: str) -> AsyncThread:",
        "        await self._ensure_initialized()",
        "        unarchived = await self._client.thread_unarchive(thread_id)",
        "        return AsyncThread(self, unarchived.thread.id)",
    ]
    return "\n".join(lines)
def _render_thread_block(
    turn_fields: list[PublicFieldSpec],
) -> str:
    """Render the synchronous ``turn`` method as class-body source.

    Args:
        turn_fields: Keyword parameters of ``turn`` beyond ``input``; each is
            also forwarded to ``TurnStartParams``.

    Returns:
        Newline-joined source of the method.
    """
    # `def` sits at 4 spaces, statements at 8, continuation arguments at 12.
    # Statements must be nested deeper than the `def`, otherwise the emitted
    # block is not valid Python.
    lines = [
        "    def turn(",
        "        self,",
        "        input: Input,",
        "        *,",
        *_kw_signature_lines(turn_fields),
        "    ) -> Turn:",
        "        wire_input = _to_wire_input(input)",
        "        params = TurnStartParams(",
        "            thread_id=self.id,",
        "            input=wire_input,",
        *_model_arg_lines(turn_fields),
        "        )",
        "        turn = self._client.turn_start(self.id, wire_input, params=params)",
        "        return Turn(self._client, self.id, turn.turn.id)",
    ]
    return "\n".join(lines)
def _render_async_thread_block(
    turn_fields: list[PublicFieldSpec],
) -> str:
    """Render the asynchronous ``turn`` method as class-body source.

    Args:
        turn_fields: Keyword parameters of ``turn`` beyond ``input``; each is
            also forwarded to ``TurnStartParams``.

    Returns:
        Newline-joined source of the method, which awaits
        ``self._codex._ensure_initialized()`` before starting the turn.
    """
    # `def` sits at 4 spaces, statements at 8, continuation arguments at 12.
    # Statements must be nested deeper than the `def`, otherwise the emitted
    # block is not valid Python.
    lines = [
        "    async def turn(",
        "        self,",
        "        input: Input,",
        "        *,",
        *_kw_signature_lines(turn_fields),
        "    ) -> AsyncTurn:",
        "        await self._codex._ensure_initialized()",
        "        wire_input = _to_wire_input(input)",
        "        params = TurnStartParams(",
        "            thread_id=self.id,",
        "            input=wire_input,",
        *_model_arg_lines(turn_fields),
        "        )",
        "        turn = await self._codex._client.turn_start(",
        "            self.id,",
        "            wire_input,",
        "            params=params,",
        "        )",
        "        return AsyncTurn(self._codex, self.id, turn.turn.id)",
    ]
    return "\n".join(lines)
def generate_public_api_flat_methods() -> None:
    """Regenerate the flat keyword-argument methods inside ``public_api.py``.

    Field lists are read from the generated ``v2_all`` params models and the
    four ``*.flat_methods`` marker blocks in the file are rewritten in place.
    Does nothing when ``public_api.py`` does not exist.
    """
    src_dir = sdk_root() / "src"
    public_api_path = src_dir / "codex_app_server" / "public_api.py"
    if not public_api_path.exists():
        # PR2 can run codegen before the ergonomic public API layer is added.
        return

    # The generated package must be importable to introspect its models.
    import_root = str(src_dir)
    if import_root not in sys.path:
        sys.path.insert(0, import_root)

    models_module = "codex_app_server.generated.v2_all"
    thread_start_fields = _load_public_fields(models_module, "ThreadStartParams")
    thread_list_fields = _load_public_fields(models_module, "ThreadListParams")
    thread_resume_fields = _load_public_fields(
        models_module, "ThreadResumeParams", exclude={"thread_id"}
    )
    thread_fork_fields = _load_public_fields(
        models_module, "ThreadForkParams", exclude={"thread_id"}
    )
    turn_start_fields = _load_public_fields(
        models_module, "TurnStartParams", exclude={"thread_id", "input"}
    )

    codex_field_sets = (
        thread_start_fields,
        thread_list_fields,
        thread_resume_fields,
        thread_fork_fields,
    )
    # (marker name, renderer, renderer arguments), applied in this order.
    block_plan = (
        ("Codex.flat_methods", _render_codex_block, codex_field_sets),
        ("AsyncCodex.flat_methods", _render_async_codex_block, codex_field_sets),
        ("Thread.flat_methods", _render_thread_block, (turn_start_fields,)),
        ("AsyncThread.flat_methods", _render_async_thread_block, (turn_start_fields,)),
    )

    text = public_api_path.read_text()
    for marker, render, render_args in block_plan:
        text = _replace_generated_block(text, marker, render(*render_args))
    public_api_path.write_text(text)
def generate_types() -> None:
    """Run every code-generation step, in dependency order."""
    # v2_all is the authoritative generated surface.
    generate_v2_all()
    # The remaining steps follow the v2_all regeneration.
    for follow_up in (generate_notification_registry, generate_public_api_flat_methods):
        follow_up()
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser with its three required subcommands.

    Subcommands: ``generate-types``, ``stage-sdk`` and ``stage-runtime``. The
    chosen one is stored on the namespace as ``command``.
    """
    root = argparse.ArgumentParser(description="Single SDK maintenance entrypoint")
    commands = root.add_subparsers(dest="command", required=True)

    # generate-types takes no arguments.
    commands.add_parser("generate-types", help="Regenerate Python protocol-derived types")

    # stage-sdk <staging_dir> --runtime-version V [--sdk-version V]
    sdk_cmd = commands.add_parser(
        "stage-sdk",
        help="Stage a releasable SDK package pinned to a runtime version",
    )
    sdk_cmd.add_argument(
        "staging_dir", type=Path, help="Output directory for the staged SDK package"
    )
    sdk_cmd.add_argument(
        "--runtime-version",
        required=True,
        help="Pinned codex-cli-bin version for the staged SDK package",
    )
    sdk_cmd.add_argument(
        "--sdk-version",
        help="Version to write into the staged SDK package (defaults to sdk/python current version)",
    )

    # stage-runtime <staging_dir> <runtime_binary> --runtime-version V
    runtime_cmd = commands.add_parser(
        "stage-runtime",
        help="Stage a releasable runtime package for the current platform",
    )
    runtime_cmd.add_argument(
        "staging_dir", type=Path, help="Output directory for the staged runtime package"
    )
    runtime_cmd.add_argument(
        "runtime_binary",
        type=Path,
        help="Path to the codex binary to package for this platform",
    )
    runtime_cmd.add_argument(
        "--runtime-version",
        required=True,
        help="Version to write into the staged runtime package",
    )
    return root
def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse ``argv``; ``None`` is passed through to ``argparse`` unchanged."""
    cli_args = None if argv is None else list(argv)
    return build_parser().parse_args(cli_args)
def default_cli_ops() -> CliOps:
    """Return the ``CliOps`` wired to this module's real implementations."""
    real_ops = {
        "generate_types": generate_types,
        "stage_python_sdk_package": stage_python_sdk_package,
        "stage_python_runtime_package": stage_python_runtime_package,
        "current_sdk_version": current_sdk_version,
    }
    return CliOps(**real_ops)
def run_command(args: argparse.Namespace, ops: CliOps) -> None:
if args.command == "generate-types":
ops.generate_types()
elif args.command == "stage-sdk":
ops.generate_types()
ops.stage_python_sdk_package(
args.staging_dir,
args.sdk_version or ops.current_sdk_version(),
args.runtime_version,
)
elif args.command == "stage-runtime":
ops.stage_python_runtime_package(
args.staging_dir,
args.runtime_version,
args.runtime_binary.resolve(),
)
def main(argv: Sequence[str] | None = None, ops: CliOps | None = None) -> None:
    """CLI entry point: parse arguments, run the command, print ``Done.``.

    Args:
        argv: Argument list to parse instead of the process arguments.
        ops: Operations to dispatch to; the real ones are used when omitted.
    """
    parsed = parse_args(argv)
    # Parsing happens first so the default ops are only built for valid input.
    active_ops = ops or default_cli_ops()
    run_command(parsed, active_ops)
    print("Done.")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,10 @@
from .client import AppServerClient, AppServerConfig
from .errors import AppServerError, JsonRpcError, TransportClosedError
# Names re-exported from `.client` and `.errors` as the package's public API.
__all__ = [
    "AppServerClient",
    "AppServerConfig",
    "AppServerError",
    "JsonRpcError",
    "TransportClosedError",
]

View File

@@ -0,0 +1,540 @@
from __future__ import annotations
import json
import os
import subprocess
import threading
import uuid
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Iterable, Iterator, TypeVar
from pydantic import BaseModel
from .errors import AppServerError, TransportClosedError, map_jsonrpc_error
from .generated.notification_registry import NOTIFICATION_MODELS
from .generated.v2_all import (
AgentMessageDeltaNotification,
ModelListResponse,
ThreadArchiveResponse,
ThreadCompactStartResponse,
ThreadForkParams as V2ThreadForkParams,
ThreadForkResponse,
ThreadListParams as V2ThreadListParams,
ThreadListResponse,
ThreadReadResponse,
ThreadResumeParams as V2ThreadResumeParams,
ThreadResumeResponse,
ThreadSetNameResponse,
ThreadStartParams as V2ThreadStartParams,
ThreadStartResponse,
ThreadUnarchiveResponse,
TurnCompletedNotification,
TurnInterruptResponse,
TurnStartParams as V2TurnStartParams,
TurnStartResponse,
TurnSteerResponse,
)
from .models import (
InitializeResponse,
JsonObject,
JsonValue,
Notification,
UnknownNotification,
)
from .retry import retry_on_overload
# Response model type returned by `AppServerClient.request`.
ModelT = TypeVar("ModelT", bound=BaseModel)
# Callback answering a server-initiated request: (method, params) -> result.
ApprovalHandler = Callable[[str, JsonObject | None], JsonObject]
# Distribution name quoted in the "runtime not found" error message.
RUNTIME_PKG_NAME = "codex-cli-bin"
def _params_dict(
params: (
V2ThreadStartParams
| V2ThreadResumeParams
| V2ThreadListParams
| V2ThreadForkParams
| V2TurnStartParams
| JsonObject
| None
),
) -> JsonObject:
if params is None:
return {}
if hasattr(params, "model_dump"):
dumped = params.model_dump(
by_alias=True,
exclude_none=True,
mode="json",
)
if not isinstance(dumped, dict):
raise TypeError("Expected model_dump() to return dict")
return dumped
if isinstance(params, dict):
return params
raise TypeError(f"Expected generated params model or dict, got {type(params).__name__}")
def _installed_codex_path() -> Path:
    """Return the codex binary path reported by the ``codex_cli_bin`` package.

    Raises:
        FileNotFoundError: If ``codex_cli_bin`` cannot be imported; the
            original ``ImportError`` is chained as the cause.
    """
    try:
        from codex_cli_bin import bundled_codex_path as locate_bundled_codex
    except ImportError as exc:
        hint = (
            "Unable to locate the pinned Codex runtime. Install the published SDK build "
            f"with its {RUNTIME_PKG_NAME} dependency, or set AppServerConfig.codex_bin "
            "explicitly."
        )
        raise FileNotFoundError(hint) from exc
    return locate_bundled_codex()
@dataclass(frozen=True)
class CodexBinResolverOps:
    """Injectable lookups used by ``resolve_codex_bin``."""

    # Supplies the fallback binary path when the config names none.
    installed_codex_path: Callable[[], Path]
    # Reports whether an explicitly configured binary path exists.
    path_exists: Callable[[Path], bool]
def _default_codex_bin_resolver_ops() -> CodexBinResolverOps:
    """Return resolver ops backed by the filesystem and the installed runtime."""

    def _exists(candidate: Path) -> bool:
        return candidate.exists()

    return CodexBinResolverOps(
        installed_codex_path=_installed_codex_path,
        path_exists=_exists,
    )
def resolve_codex_bin(config: "AppServerConfig", ops: CodexBinResolverOps) -> Path:
    """Pick the codex binary to launch.

    An explicit ``config.codex_bin`` wins and must exist according to
    ``ops.path_exists``; otherwise ``ops.installed_codex_path()`` decides.

    Raises:
        FileNotFoundError: If the explicitly configured binary does not exist.
    """
    if config.codex_bin is None:
        return ops.installed_codex_path()

    candidate = Path(config.codex_bin)
    if ops.path_exists(candidate):
        return candidate
    raise FileNotFoundError(
        f"Codex binary not found at {candidate}. Set AppServerConfig.codex_bin "
        "to a valid binary path."
    )
def _resolve_codex_bin(config: "AppServerConfig") -> Path:
    """Resolve the codex binary for ``config`` using the default resolver ops."""
    default_ops = _default_codex_bin_resolver_ops()
    return resolve_codex_bin(config, default_ops)
@dataclass(slots=True)
class AppServerConfig:
codex_bin: str | None = None
launch_args_override: tuple[str, ...] | None = None
config_overrides: tuple[str, ...] = ()
cwd: str | None = None
env: dict[str, str] | None = None
client_name: str = "codex_python_sdk"
client_title: str = "Codex Python SDK"
client_version: str = "0.2.0"
experimental_api: bool = True
class AppServerClient:
    """Synchronous typed JSON-RPC client for `codex app-server` over stdio.

    The client spawns the server as a child process and exchanges one JSON
    message per line over its stdin/stdout. Notifications that arrive while a
    request is waiting for its response are queued and handed out later by
    ``next_notification``. Server-initiated requests are answered through the
    approval handler.
    """

    def __init__(
        self,
        config: AppServerConfig | None = None,
        approval_handler: ApprovalHandler | None = None,
    ) -> None:
        """Create an idle client; the server is launched by ``start()``.

        Args:
            config: Launch and identity settings; defaults are used when omitted.
            approval_handler: Callback answering server-initiated requests.
                Defaults to ``_default_approval_handler``.
        """
        self.config = config or AppServerConfig()
        self._approval_handler = approval_handler or self._default_approval_handler
        self._proc: subprocess.Popen[str] | None = None
        # Serialises writes to the child's stdin.
        self._lock = threading.Lock()
        # Guards `_active_turn_consumer`.
        self._turn_consumer_lock = threading.Lock()
        self._active_turn_consumer: str | None = None
        # Notifications received while waiting for a response.
        self._pending_notifications: deque[Notification] = deque()
        # Most recent stderr lines, kept for error reporting.
        self._stderr_lines: deque[str] = deque(maxlen=400)
        self._stderr_thread: threading.Thread | None = None

    def __enter__(self) -> "AppServerClient":
        """Start the server and return the client."""
        self.start()
        return self

    def __exit__(self, _exc_type, _exc, _tb) -> None:
        """Shut the server down on leaving the ``with`` block."""
        self.close()

    def start(self) -> None:
        """Launch the app-server child process; a no-op if already running."""
        if self._proc is not None:
            return

        if self.config.launch_args_override is not None:
            args = list(self.config.launch_args_override)
        else:
            codex_bin = _resolve_codex_bin(self.config)
            args = [str(codex_bin)]
            for kv in self.config.config_overrides:
                args.extend(["--config", kv])
            args.extend(["app-server", "--listen", "stdio://"])

        env = os.environ.copy()
        if self.config.env:
            env.update(self.config.env)

        self._proc = subprocess.Popen(
            args,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            cwd=self.config.cwd,
            env=env,
            bufsize=1,
        )
        self._start_stderr_drain_thread()

    def close(self) -> None:
        """Stop the child process; a no-op if it is not running.

        Shutdown is best-effort: stdin is closed, the process is terminated and
        given two seconds to exit, then killed if that fails.
        """
        if self._proc is None:
            return
        proc = self._proc
        self._proc = None
        self._active_turn_consumer = None

        if proc.stdin:
            try:
                proc.stdin.close()
            except (OSError, ValueError):
                # Closing flushes buffered data, which fails with a broken pipe
                # when the child has already exited. The process must still be
                # terminated below, so the failure is deliberately ignored.
                pass
        try:
            proc.terminate()
            proc.wait(timeout=2)
        except Exception:
            proc.kill()
            try:
                # Reap the killed child so it does not linger as a zombie.
                proc.wait(timeout=2)
            except Exception:
                # Best-effort shutdown: nothing more can be done here.
                pass

        if self._stderr_thread and self._stderr_thread.is_alive():
            self._stderr_thread.join(timeout=0.5)

    def initialize(self) -> InitializeResponse:
        """Send ``initialize`` followed by the ``initialized`` notification."""
        result = self.request(
            "initialize",
            {
                "clientInfo": {
                    "name": self.config.client_name,
                    "title": self.config.client_title,
                    "version": self.config.client_version,
                },
                "capabilities": {
                    "experimentalApi": self.config.experimental_api,
                },
            },
            response_model=InitializeResponse,
        )
        self.notify("initialized", None)
        return result

    def request(
        self,
        method: str,
        params: JsonObject | None,
        *,
        response_model: type[ModelT],
    ) -> ModelT:
        """Send a request and validate its result into ``response_model``.

        Raises:
            AppServerError: If the result is not a JSON object.
        """
        result = self._request_raw(method, params)
        if not isinstance(result, dict):
            raise AppServerError(f"{method} response must be a JSON object")
        return response_model.model_validate(result)

    def _request_raw(self, method: str, params: JsonObject | None = None) -> JsonValue:
        """Send a request and block until the matching response arrives.

        While waiting, server-initiated requests are answered, notifications
        are queued, and responses carrying a different id are discarded.

        Raises:
            JsonRpcError: (or a subclass) when the response carries an error object.
            AppServerError: When the error member is not a JSON object.
        """
        request_id = str(uuid.uuid4())
        self._write_message({"id": request_id, "method": method, "params": params or {}})

        while True:
            msg = self._read_message()

            if "method" in msg and "id" in msg:
                response = self._handle_server_request(msg)
                self._write_message({"id": msg["id"], "result": response})
                continue

            if "method" in msg and "id" not in msg:
                self._pending_notifications.append(
                    self._coerce_notification(msg["method"], msg.get("params"))
                )
                continue

            if msg.get("id") != request_id:
                continue

            if "error" in msg:
                err = msg["error"]
                if isinstance(err, dict):
                    raise map_jsonrpc_error(
                        int(err.get("code", -32000)),
                        str(err.get("message", "unknown")),
                        err.get("data"),
                    )
                raise AppServerError("Malformed JSON-RPC error response")

            return msg.get("result")

    def notify(self, method: str, params: JsonObject | None = None) -> None:
        """Send a notification (a message without an id)."""
        self._write_message({"method": method, "params": params or {}})

    def next_notification(self) -> Notification:
        """Return the next notification, reading from the server if none is queued.

        Server-initiated requests encountered while reading are answered.
        """
        if self._pending_notifications:
            return self._pending_notifications.popleft()

        while True:
            msg = self._read_message()
            if "method" in msg and "id" in msg:
                response = self._handle_server_request(msg)
                self._write_message({"id": msg["id"], "result": response})
                continue
            if "method" in msg and "id" not in msg:
                return self._coerce_notification(msg["method"], msg.get("params"))

    def acquire_turn_consumer(self, turn_id: str) -> None:
        """Register ``turn_id`` as the single active turn consumer.

        Raises:
            RuntimeError: If another turn consumer is already registered.
        """
        with self._turn_consumer_lock:
            if self._active_turn_consumer is not None:
                raise RuntimeError(
                    "Concurrent turn consumers are not yet supported in the experimental SDK. "
                    f"Client is already streaming turn {self._active_turn_consumer!r}; "
                    f"cannot start turn {turn_id!r} until the active consumer finishes."
                )
            self._active_turn_consumer = turn_id

    def release_turn_consumer(self, turn_id: str) -> None:
        """Clear the active turn consumer if it is ``turn_id``."""
        with self._turn_consumer_lock:
            if self._active_turn_consumer == turn_id:
                self._active_turn_consumer = None

    def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse:
        """Call ``thread/start``."""
        return self.request("thread/start", _params_dict(params), response_model=ThreadStartResponse)

    def thread_resume(
        self,
        thread_id: str,
        params: V2ThreadResumeParams | JsonObject | None = None,
    ) -> ThreadResumeResponse:
        """Call ``thread/resume`` for ``thread_id``."""
        payload = {"threadId": thread_id, **_params_dict(params)}
        return self.request("thread/resume", payload, response_model=ThreadResumeResponse)

    def thread_list(self, params: V2ThreadListParams | JsonObject | None = None) -> ThreadListResponse:
        """Call ``thread/list``."""
        return self.request("thread/list", _params_dict(params), response_model=ThreadListResponse)

    def thread_read(self, thread_id: str, include_turns: bool = False) -> ThreadReadResponse:
        """Call ``thread/read`` for ``thread_id``."""
        return self.request(
            "thread/read",
            {"threadId": thread_id, "includeTurns": include_turns},
            response_model=ThreadReadResponse,
        )

    def thread_fork(
        self,
        thread_id: str,
        params: V2ThreadForkParams | JsonObject | None = None,
    ) -> ThreadForkResponse:
        """Call ``thread/fork`` for ``thread_id``."""
        payload = {"threadId": thread_id, **_params_dict(params)}
        return self.request("thread/fork", payload, response_model=ThreadForkResponse)

    def thread_archive(self, thread_id: str) -> ThreadArchiveResponse:
        """Call ``thread/archive`` for ``thread_id``."""
        return self.request("thread/archive", {"threadId": thread_id}, response_model=ThreadArchiveResponse)

    def thread_unarchive(self, thread_id: str) -> ThreadUnarchiveResponse:
        """Call ``thread/unarchive`` for ``thread_id``."""
        return self.request("thread/unarchive", {"threadId": thread_id}, response_model=ThreadUnarchiveResponse)

    def thread_set_name(self, thread_id: str, name: str) -> ThreadSetNameResponse:
        """Call ``thread/name/set`` for ``thread_id``."""
        return self.request(
            "thread/name/set",
            {"threadId": thread_id, "name": name},
            response_model=ThreadSetNameResponse,
        )

    def thread_compact(self, thread_id: str) -> ThreadCompactStartResponse:
        """Call ``thread/compact/start`` for ``thread_id``."""
        return self.request(
            "thread/compact/start",
            {"threadId": thread_id},
            response_model=ThreadCompactStartResponse,
        )

    def turn_start(
        self,
        thread_id: str,
        input_items: list[JsonObject] | JsonObject | str,
        params: V2TurnStartParams | JsonObject | None = None,
    ) -> TurnStartResponse:
        """Call ``turn/start``.

        ``thread_id`` and the normalised ``input_items`` take precedence over
        same-named keys in ``params``.
        """
        payload = {
            **_params_dict(params),
            "threadId": thread_id,
            "input": self._normalize_input_items(input_items),
        }
        return self.request("turn/start", payload, response_model=TurnStartResponse)

    def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse:
        """Call ``turn/interrupt`` for the given turn."""
        return self.request(
            "turn/interrupt",
            {"threadId": thread_id, "turnId": turn_id},
            response_model=TurnInterruptResponse,
        )

    def turn_steer(
        self,
        thread_id: str,
        expected_turn_id: str,
        input_items: list[JsonObject] | JsonObject | str,
    ) -> TurnSteerResponse:
        """Call ``turn/steer`` with additional input for ``expected_turn_id``."""
        return self.request(
            "turn/steer",
            {
                "threadId": thread_id,
                "expectedTurnId": expected_turn_id,
                "input": self._normalize_input_items(input_items),
            },
            response_model=TurnSteerResponse,
        )

    def model_list(self, include_hidden: bool = False) -> ModelListResponse:
        """Call ``model/list``."""
        return self.request(
            "model/list",
            {"includeHidden": include_hidden},
            response_model=ModelListResponse,
        )

    def request_with_retry_on_overload(
        self,
        method: str,
        params: JsonObject | None,
        *,
        response_model: type[ModelT],
        max_attempts: int = 3,
        initial_delay_s: float = 0.25,
        max_delay_s: float = 2.0,
    ) -> ModelT:
        """Run ``request`` through ``retry_on_overload`` with the given limits."""
        return retry_on_overload(
            lambda: self.request(method, params, response_model=response_model),
            max_attempts=max_attempts,
            initial_delay_s=initial_delay_s,
            max_delay_s=max_delay_s,
        )

    def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
        """Consume notifications until ``turn/completed`` for ``turn_id`` arrives.

        Every other notification read along the way is discarded.
        """
        while True:
            notification = self.next_notification()
            if (
                notification.method == "turn/completed"
                and isinstance(notification.payload, TurnCompletedNotification)
                and notification.payload.turn.id == turn_id
            ):
                return notification.payload

    def stream_until_methods(self, methods: Iterable[str] | str) -> list[Notification]:
        """Collect notifications up to and including the first one in ``methods``."""
        target_methods = {methods} if isinstance(methods, str) else set(methods)
        out: list[Notification] = []
        while True:
            notification = self.next_notification()
            out.append(notification)
            if notification.method in target_methods:
                return out

    def stream_text(
        self,
        thread_id: str,
        text: str,
        params: V2TurnStartParams | JsonObject | None = None,
    ) -> Iterator[AgentMessageDeltaNotification]:
        """Start a turn with ``text`` and yield its agent-message deltas.

        This is a generator: the turn is started when iteration begins, and
        iteration ends when ``turn/completed`` for that turn is received.
        """
        started = self.turn_start(thread_id, text, params=params)
        turn_id = started.turn.id
        while True:
            notification = self.next_notification()
            if (
                notification.method == "item/agentMessage/delta"
                and isinstance(notification.payload, AgentMessageDeltaNotification)
                and notification.payload.turn_id == turn_id
            ):
                yield notification.payload
                continue
            if (
                notification.method == "turn/completed"
                and isinstance(notification.payload, TurnCompletedNotification)
                and notification.payload.turn.id == turn_id
            ):
                break

    def _coerce_notification(self, method: str, params: object) -> Notification:
        """Wrap raw notification params in the typed model registered for ``method``.

        Unregistered methods, and payloads that fail validation, are wrapped in
        ``UnknownNotification`` instead of raising.
        """
        params_dict = params if isinstance(params, dict) else {}

        model = NOTIFICATION_MODELS.get(method)
        if model is None:
            return Notification(method=method, payload=UnknownNotification(params=params_dict))

        try:
            payload = model.model_validate(params_dict)
        except Exception:  # noqa: BLE001
            return Notification(method=method, payload=UnknownNotification(params=params_dict))
        return Notification(method=method, payload=payload)

    def _normalize_input_items(
        self,
        input_items: list[JsonObject] | JsonObject | str,
    ) -> list[JsonObject]:
        """Coerce turn input to a list: text is wrapped, a dict becomes one item."""
        if isinstance(input_items, str):
            return [{"type": "text", "text": input_items}]
        if isinstance(input_items, dict):
            return [input_items]
        return input_items

    def _default_approval_handler(self, method: str, params: JsonObject | None) -> JsonObject:
        """Accept command-execution and file-change approvals; answer ``{}`` otherwise."""
        if method == "item/commandExecution/requestApproval":
            return {"decision": "accept"}
        if method == "item/fileChange/requestApproval":
            return {"decision": "accept"}
        return {}

    def _start_stderr_drain_thread(self) -> None:
        """Start a daemon thread that copies the child's stderr into ``_stderr_lines``."""
        proc = self._proc
        if proc is None or proc.stderr is None:
            return
        # Bind the stream now: `close()` may reset `self._proc` to None before
        # the thread first runs, so the thread must not look it up again.
        stderr = proc.stderr

        def _drain() -> None:
            for line in stderr:
                self._stderr_lines.append(line.rstrip("\n"))

        self._stderr_thread = threading.Thread(target=_drain, daemon=True)
        self._stderr_thread.start()

    def _stderr_tail(self, limit: int = 40) -> str:
        """Return the last ``limit`` captured stderr lines joined by newlines."""
        return "\n".join(list(self._stderr_lines)[-limit:])

    def _handle_server_request(self, msg: dict[str, JsonValue]) -> JsonObject:
        """Answer a server-initiated request through the approval handler.

        Returns ``{}`` when the method is not a string; non-dict params are
        passed to the handler as ``None``.
        """
        method = msg["method"]
        params = msg.get("params")
        if not isinstance(method, str):
            return {}
        return self._approval_handler(
            method,
            params if isinstance(params, dict) else None,
        )

    def _write_message(self, payload: JsonObject) -> None:
        """Write one JSON message as a line to the child's stdin and flush.

        Raises:
            TransportClosedError: If the server is not running.
        """
        # Use one snapshot so a concurrent `close()` cannot turn the attribute
        # into None between the check and the write.
        proc = self._proc
        if proc is None or proc.stdin is None:
            raise TransportClosedError("app-server is not running")
        with self._lock:
            proc.stdin.write(json.dumps(payload) + "\n")
            proc.stdin.flush()

    def _read_message(self) -> dict[str, JsonValue]:
        """Read and decode one JSON message line from the child's stdout.

        Raises:
            TransportClosedError: If the server is not running or closed stdout.
            AppServerError: If the line is not valid JSON or not a JSON object.
        """
        # Same snapshot reasoning as in `_write_message`.
        proc = self._proc
        if proc is None or proc.stdout is None:
            raise TransportClosedError("app-server is not running")

        line = proc.stdout.readline()
        if not line:
            raise TransportClosedError(
                f"app-server closed stdout. stderr_tail={self._stderr_tail()[:2000]}"
            )

        try:
            message = json.loads(line)
        except json.JSONDecodeError as exc:
            raise AppServerError(f"Invalid JSON-RPC line: {line!r}") from exc

        if not isinstance(message, dict):
            raise AppServerError(f"Invalid JSON-RPC payload: {message!r}")
        return message
def default_codex_home() -> str:
    """Return the ``.codex`` directory under the user's home, as a string."""
    codex_home = Path.home().joinpath(".codex")
    return str(codex_home)

View File

@@ -0,0 +1,125 @@
from __future__ import annotations
from typing import Any
class AppServerError(Exception):
    """Base exception for SDK errors.

    Every other exception class in this module derives from it.
    """
class JsonRpcError(AppServerError):
    """Raw JSON-RPC error wrapper from the server.

    Attributes:
        code: Numeric JSON-RPC error code.
        message: Error message from the error object.
        data: Optional ``data`` member of the error object.
    """

    def __init__(self, code: int, message: str, data: Any = None):
        # Keep the structured parts available next to the formatted text.
        self.code = code
        self.message = message
        self.data = data
        super().__init__(f"JSON-RPC error {code}: {message}")
class TransportClosedError(AppServerError):
    """Raised when the app-server transport closes unexpectedly.

    Also raised when a read or write is attempted while the server is not running.
    """
class AppServerRpcError(JsonRpcError):
    """Base typed error for JSON-RPC failures.

    `map_jsonrpc_error` returns it directly for codes in -32099..-32000 that
    show neither an overload marker nor retry-limit text.
    """
class ParseError(AppServerRpcError):
    """JSON-RPC error code -32700, as mapped by `map_jsonrpc_error`."""

    pass
class InvalidRequestError(AppServerRpcError):
    """JSON-RPC error code -32600, as mapped by `map_jsonrpc_error`."""

    pass
class MethodNotFoundError(AppServerRpcError):
    """JSON-RPC error code -32601, as mapped by `map_jsonrpc_error`."""

    pass
class InvalidParamsError(AppServerRpcError):
    """JSON-RPC error code -32602, as mapped by `map_jsonrpc_error`."""

    pass
class InternalRpcError(AppServerRpcError):
    """JSON-RPC error code -32603, as mapped by `map_jsonrpc_error`."""

    pass
class ServerBusyError(AppServerRpcError):
    """Server is overloaded / unavailable and caller should retry.

    `is_retryable_error` reports every instance of this class as retryable.
    """
class RetryLimitExceededError(ServerBusyError):
    """Server exhausted internal retry budget for a retryable operation.

    `map_jsonrpc_error` selects it for codes in -32099..-32000 whose message
    mentions a retry limit or too many failed attempts.
    """
def _contains_retry_limit_text(message: str) -> bool:
lowered = message.lower()
return "retry limit" in lowered or "too many failed attempts" in lowered
def _is_server_overloaded(data: Any) -> bool:
if data is None:
return False
if isinstance(data, str):
return data.lower() == "server_overloaded"
if isinstance(data, dict):
direct = (
data.get("codex_error_info")
or data.get("codexErrorInfo")
or data.get("errorInfo")
)
if isinstance(direct, str) and direct.lower() == "server_overloaded":
return True
if isinstance(direct, dict):
for value in direct.values():
if isinstance(value, str) and value.lower() == "server_overloaded":
return True
for value in data.values():
if _is_server_overloaded(value):
return True
if isinstance(data, list):
return any(_is_server_overloaded(value) for value in data)
return False
def map_jsonrpc_error(code: int, message: str, data: Any = None) -> JsonRpcError:
    """Map a raw JSON-RPC error into a richer SDK exception class.

    The exception is returned, not raised. The five fixed codes
    (-32700, -32600 to -32603) map to their dedicated classes. Codes in
    -32099..-32000 are classified by retry-limit text in ``message`` and by the
    overload marker in ``data``. Any other code yields a plain ``JsonRpcError``.
    """
    fixed_codes = {
        -32700: ParseError,
        -32600: InvalidRequestError,
        -32601: MethodNotFoundError,
        -32602: InvalidParamsError,
        -32603: InternalRpcError,
    }
    error_cls = fixed_codes.get(code)
    if error_cls is None:
        if not -32099 <= code <= -32000:
            error_cls = JsonRpcError
        elif _contains_retry_limit_text(message):
            # Retry-limit text wins whether or not the overload marker is present.
            error_cls = RetryLimitExceededError
        elif _is_server_overloaded(data):
            error_cls = ServerBusyError
        else:
            error_cls = AppServerRpcError
    return error_cls(code, message, data)
def is_retryable_error(exc: BaseException) -> bool:
    """True if the exception is a transient overload-style error.

    That is: any ``ServerBusyError``, or any other ``JsonRpcError`` whose
    ``data`` carries the overload marker.
    """
    if isinstance(exc, ServerBusyError):
        return True
    return isinstance(exc, JsonRpcError) and _is_server_overloaded(exc.data)

View File

@@ -0,0 +1 @@
"""Auto-generated Python types derived from the app-server schemas."""

View File

@@ -0,0 +1,102 @@
# Auto-generated by scripts/update_sdk_artifacts.py
# DO NOT EDIT MANUALLY.
from __future__ import annotations
from pydantic import BaseModel
from .v2_all import AccountLoginCompletedNotification
from .v2_all import AccountRateLimitsUpdatedNotification
from .v2_all import AccountUpdatedNotification
from .v2_all import AgentMessageDeltaNotification
from .v2_all import AppListUpdatedNotification
from .v2_all import CommandExecOutputDeltaNotification
from .v2_all import CommandExecutionOutputDeltaNotification
from .v2_all import ConfigWarningNotification
from .v2_all import ContextCompactedNotification
from .v2_all import DeprecationNoticeNotification
from .v2_all import ErrorNotification
from .v2_all import FileChangeOutputDeltaNotification
from .v2_all import FuzzyFileSearchSessionCompletedNotification
from .v2_all import FuzzyFileSearchSessionUpdatedNotification
from .v2_all import HookCompletedNotification
from .v2_all import HookStartedNotification
from .v2_all import ItemCompletedNotification
from .v2_all import ItemStartedNotification
from .v2_all import McpServerOauthLoginCompletedNotification
from .v2_all import McpToolCallProgressNotification
from .v2_all import ModelReroutedNotification
from .v2_all import PlanDeltaNotification
from .v2_all import ReasoningSummaryPartAddedNotification
from .v2_all import ReasoningSummaryTextDeltaNotification
from .v2_all import ReasoningTextDeltaNotification
from .v2_all import ServerRequestResolvedNotification
from .v2_all import SkillsChangedNotification
from .v2_all import TerminalInteractionNotification
from .v2_all import ThreadArchivedNotification
from .v2_all import ThreadClosedNotification
from .v2_all import ThreadNameUpdatedNotification
from .v2_all import ThreadRealtimeClosedNotification
from .v2_all import ThreadRealtimeErrorNotification
from .v2_all import ThreadRealtimeItemAddedNotification
from .v2_all import ThreadRealtimeOutputAudioDeltaNotification
from .v2_all import ThreadRealtimeStartedNotification
from .v2_all import ThreadStartedNotification
from .v2_all import ThreadStatusChangedNotification
from .v2_all import ThreadTokenUsageUpdatedNotification
from .v2_all import ThreadUnarchivedNotification
from .v2_all import TurnCompletedNotification
from .v2_all import TurnDiffUpdatedNotification
from .v2_all import TurnPlanUpdatedNotification
from .v2_all import TurnStartedNotification
from .v2_all import WindowsSandboxSetupCompletedNotification
from .v2_all import WindowsWorldWritableWarningNotification
NOTIFICATION_MODELS: dict[str, type[BaseModel]] = {
"account/login/completed": AccountLoginCompletedNotification,
"account/rateLimits/updated": AccountRateLimitsUpdatedNotification,
"account/updated": AccountUpdatedNotification,
"app/list/updated": AppListUpdatedNotification,
"command/exec/outputDelta": CommandExecOutputDeltaNotification,
"configWarning": ConfigWarningNotification,
"deprecationNotice": DeprecationNoticeNotification,
"error": ErrorNotification,
"fuzzyFileSearch/sessionCompleted": FuzzyFileSearchSessionCompletedNotification,
"fuzzyFileSearch/sessionUpdated": FuzzyFileSearchSessionUpdatedNotification,
"hook/completed": HookCompletedNotification,
"hook/started": HookStartedNotification,
"item/agentMessage/delta": AgentMessageDeltaNotification,
"item/commandExecution/outputDelta": CommandExecutionOutputDeltaNotification,
"item/commandExecution/terminalInteraction": TerminalInteractionNotification,
"item/completed": ItemCompletedNotification,
"item/fileChange/outputDelta": FileChangeOutputDeltaNotification,
"item/mcpToolCall/progress": McpToolCallProgressNotification,
"item/plan/delta": PlanDeltaNotification,
"item/reasoning/summaryPartAdded": ReasoningSummaryPartAddedNotification,
"item/reasoning/summaryTextDelta": ReasoningSummaryTextDeltaNotification,
"item/reasoning/textDelta": ReasoningTextDeltaNotification,
"item/started": ItemStartedNotification,
"mcpServer/oauthLogin/completed": McpServerOauthLoginCompletedNotification,
"model/rerouted": ModelReroutedNotification,
"serverRequest/resolved": ServerRequestResolvedNotification,
"skills/changed": SkillsChangedNotification,
"thread/archived": ThreadArchivedNotification,
"thread/closed": ThreadClosedNotification,
"thread/compacted": ContextCompactedNotification,
"thread/name/updated": ThreadNameUpdatedNotification,
"thread/realtime/closed": ThreadRealtimeClosedNotification,
"thread/realtime/error": ThreadRealtimeErrorNotification,
"thread/realtime/itemAdded": ThreadRealtimeItemAddedNotification,
"thread/realtime/outputAudio/delta": ThreadRealtimeOutputAudioDeltaNotification,
"thread/realtime/started": ThreadRealtimeStartedNotification,
"thread/started": ThreadStartedNotification,
"thread/status/changed": ThreadStatusChangedNotification,
"thread/tokenUsage/updated": ThreadTokenUsageUpdatedNotification,
"thread/unarchived": ThreadUnarchivedNotification,
"turn/completed": TurnCompletedNotification,
"turn/diff/updated": TurnDiffUpdatedNotification,
"turn/plan/updated": TurnPlanUpdatedNotification,
"turn/started": TurnStartedNotification,
"windows/worldWritableWarning": WindowsWorldWritableWarningNotification,
"windowsSandbox/setupCompleted": WindowsSandboxSetupCompletedNotification,
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,25 @@
"""Stable aliases over full v2 autogenerated models (datamodel-code-generator)."""
from .v2_all.ModelListResponse import ModelListResponse
from .v2_all.ThreadCompactStartResponse import ThreadCompactStartResponse
from .v2_all.ThreadListResponse import ThreadListResponse
from .v2_all.ThreadReadResponse import ThreadReadResponse
from .v2_all.ThreadTokenUsageUpdatedNotification import (
ThreadTokenUsageUpdatedNotification,
)
from .v2_all.TurnCompletedNotification import ThreadItem153 as ThreadItem
from .v2_all.TurnCompletedNotification import (
TurnCompletedNotification as TurnCompletedNotificationPayload,
)
from .v2_all.TurnSteerResponse import TurnSteerResponse
# Public surface of this alias module: every name imported above is re-exported.
# "ThreadItem" and "TurnCompletedNotificationPayload" are import-time renames of
# ThreadItem153 and TurnCompletedNotification respectively.
__all__ = [
    "ModelListResponse",
    "ThreadCompactStartResponse",
    "ThreadListResponse",
    "ThreadReadResponse",
    "ThreadTokenUsageUpdatedNotification",
    "TurnCompletedNotificationPayload",
    "TurnSteerResponse",
    "ThreadItem",
]

View File

@@ -0,0 +1,97 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TypeAlias

from pydantic import BaseModel

from .generated.v2_all import (
    AccountLoginCompletedNotification,
    AccountRateLimitsUpdatedNotification,
    AccountUpdatedNotification,
    AgentMessageDeltaNotification,
    AppListUpdatedNotification,
    CommandExecOutputDeltaNotification,
    CommandExecutionOutputDeltaNotification,
    ConfigWarningNotification,
    ContextCompactedNotification,
    DeprecationNoticeNotification,
    ErrorNotification,
    FileChangeOutputDeltaNotification,
    FuzzyFileSearchSessionCompletedNotification,
    FuzzyFileSearchSessionUpdatedNotification,
    HookCompletedNotification,
    HookStartedNotification,
    ItemCompletedNotification,
    ItemStartedNotification,
    McpServerOauthLoginCompletedNotification,
    McpToolCallProgressNotification,
    ModelReroutedNotification,
    PlanDeltaNotification,
    RawResponseItemCompletedNotification,
    ReasoningSummaryPartAddedNotification,
    ReasoningSummaryTextDeltaNotification,
    ReasoningTextDeltaNotification,
    ServerRequestResolvedNotification,
    SkillsChangedNotification,
    TerminalInteractionNotification,
    ThreadArchivedNotification,
    ThreadClosedNotification,
    ThreadNameUpdatedNotification,
    ThreadRealtimeClosedNotification,
    ThreadRealtimeErrorNotification,
    ThreadRealtimeItemAddedNotification,
    ThreadRealtimeOutputAudioDeltaNotification,
    ThreadRealtimeStartedNotification,
    ThreadStartedNotification,
    ThreadStatusChangedNotification,
    ThreadTokenUsageUpdatedNotification,
    ThreadUnarchivedNotification,
    TurnCompletedNotification,
    TurnDiffUpdatedNotification,
    TurnPlanUpdatedNotification,
    TurnStartedNotification,
    WindowsSandboxSetupCompletedNotification,
    WindowsWorldWritableWarningNotification,
)
# Scalar leaves of a JSON document.
JsonScalar: TypeAlias = str | int | float | bool | None
# Any JSON value. The alias is recursive, so the nested references are spelled
# as string forward references.
JsonValue: TypeAlias = JsonScalar | dict[str, "JsonValue"] | list["JsonValue"]
# A JSON object: string keys mapped to arbitrary JSON values.
JsonObject: TypeAlias = dict[str, JsonValue]
@dataclass(slots=True)
class UnknownNotification:
    """Untyped notification payload that carries the params as a plain JSON object."""

    # The notification params, kept as a JSON object rather than a typed model.
    params: JsonObject
# Every payload type a Notification can carry: each model class registered in
# the generated notification registry (NOTIFICATION_MODELS), plus
# RawResponseItemCompletedNotification and the untyped UnknownNotification
# fallback. Keep this union in step with the registry so that typed payloads
# are never outside the declared type of ``Notification.payload``.
NotificationPayload: TypeAlias = (
    AccountLoginCompletedNotification
    | AccountRateLimitsUpdatedNotification
    | AccountUpdatedNotification
    | AgentMessageDeltaNotification
    | AppListUpdatedNotification
    | CommandExecOutputDeltaNotification
    | CommandExecutionOutputDeltaNotification
    | ConfigWarningNotification
    | ContextCompactedNotification
    | DeprecationNoticeNotification
    | ErrorNotification
    | FileChangeOutputDeltaNotification
    | FuzzyFileSearchSessionCompletedNotification
    | FuzzyFileSearchSessionUpdatedNotification
    | HookCompletedNotification
    | HookStartedNotification
    | ItemCompletedNotification
    | ItemStartedNotification
    | McpServerOauthLoginCompletedNotification
    | McpToolCallProgressNotification
    | ModelReroutedNotification
    | PlanDeltaNotification
    | RawResponseItemCompletedNotification
    | ReasoningSummaryPartAddedNotification
    | ReasoningSummaryTextDeltaNotification
    | ReasoningTextDeltaNotification
    | ServerRequestResolvedNotification
    | SkillsChangedNotification
    | TerminalInteractionNotification
    | ThreadArchivedNotification
    | ThreadClosedNotification
    | ThreadNameUpdatedNotification
    | ThreadRealtimeClosedNotification
    | ThreadRealtimeErrorNotification
    | ThreadRealtimeItemAddedNotification
    | ThreadRealtimeOutputAudioDeltaNotification
    | ThreadRealtimeStartedNotification
    | ThreadStartedNotification
    | ThreadStatusChangedNotification
    | ThreadTokenUsageUpdatedNotification
    | ThreadUnarchivedNotification
    | TurnCompletedNotification
    | TurnDiffUpdatedNotification
    | TurnPlanUpdatedNotification
    | TurnStartedNotification
    | WindowsSandboxSetupCompletedNotification
    | WindowsWorldWritableWarningNotification
    | UnknownNotification
)
@dataclass(slots=True)
class Notification:
    """A notification: its method name paired with the payload for that method."""

    # Method string of the notification, e.g. "thread/tokenUsage/updated".
    method: str
    # Either one of the typed notification models or an UnknownNotification.
    payload: NotificationPayload
class ServerInfo(BaseModel):
    """Optional name/version pair; both fields default to ``None`` when absent."""

    name: str | None = None
    version: str | None = None
class InitializeResponse(BaseModel):
    """Model with an optional ``serverInfo`` block and an optional ``userAgent`` string.

    NOTE(review): presumably the parsed result of the app-server ``initialize``
    request -- confirm against the client code that constructs it.
    """

    # Field names are camelCase as written here (no alias is declared).
    serverInfo: ServerInfo | None = None
    userAgent: str | None = None

View File

View File

@@ -0,0 +1,41 @@
from __future__ import annotations
import random
import time
from typing import Callable, TypeVar
from .errors import is_retryable_error
T = TypeVar("T")


def retry_on_overload(
    op: Callable[[], T],
    *,
    max_attempts: int = 3,
    initial_delay_s: float = 0.25,
    max_delay_s: float = 2.0,
    jitter_ratio: float = 0.2,
) -> T:
    """Call ``op`` and retry transient server-overload failures with backoff.

    Args:
        op: Zero-argument callable to invoke.
        max_attempts: Total number of calls allowed, including the first one.
        initial_delay_s: Base delay before the first retry, in seconds.
        max_delay_s: Upper bound applied to the base delay (before jitter).
        jitter_ratio: Fraction of the current base delay used as the +/- range
            of the random jitter added to each sleep.

    Returns:
        Whatever ``op`` returns on its first successful call.

    Raises:
        ValueError: If ``max_attempts`` is smaller than 1.
        Exception: The error raised by ``op``, re-raised unchanged once the
            attempts are used up or when ``is_retryable_error`` rejects it.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be >= 1")

    backoff = initial_delay_s
    remaining = max_attempts
    while True:
        remaining -= 1
        try:
            return op()
        except Exception as exc:
            # The attempt budget is checked first, so the final failure is
            # re-raised without consulting is_retryable_error.
            if remaining <= 0 or not is_retryable_error(exc):
                raise
            spread = backoff * jitter_ratio
            pause = min(max_delay_s, backoff) + random.uniform(-spread, spread)
            # Jitter can push the pause to zero or below; skip sleeping then.
            if pause > 0:
                time.sleep(pause)
            # Double the base delay for the next round, capped at max_delay_s.
            backoff = min(max_delay_s, backoff * 2)

View File

@@ -0,0 +1,16 @@
from __future__ import annotations
import sys
from pathlib import Path
# Directory one level above the directory containing this file, and its src/ tree.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"

# Move src/ to the front of sys.path: drop one existing entry (if any), then
# re-insert it at index 0.
src_str = str(SRC)
try:
    sys.path.remove(src_str)
except ValueError:
    pass
sys.path.insert(0, src_str)

# Forget any already-imported codex_app_server modules so the next import
# resolves again against the updated sys.path.
for module_name in tuple(sys.modules):
    if module_name.startswith("codex_app_server.") or module_name == "codex_app_server":
        del sys.modules[module_name]

View File

@@ -0,0 +1,411 @@
from __future__ import annotations
import ast
import importlib.util
import json
import sys
import tomllib
from pathlib import Path
import pytest
ROOT = Path(__file__).resolve().parents[1]


def _load_update_script_module():
    """Import ``scripts/update_sdk_artifacts.py`` from its file path and return the module.

    Raises:
        AssertionError: If no import spec (or no loader) can be built for the script.
    """
    location = ROOT / "scripts" / "update_sdk_artifacts.py"
    module_spec = importlib.util.spec_from_file_location("update_sdk_artifacts", location)
    loader = None if module_spec is None else module_spec.loader
    if loader is None:
        raise AssertionError(f"Failed to load script module: {location}")
    loaded = importlib.util.module_from_spec(module_spec)
    # Register the module under its spec name before executing its body.
    sys.modules[module_spec.name] = loaded
    loader.exec_module(loaded)
    return loaded
def test_generation_has_single_maintenance_entrypoint_script() -> None:
    """The scripts directory contains exactly one ``*.py`` file: update_sdk_artifacts.py."""
    script_dir = ROOT / "scripts"
    found = [entry.name for entry in script_dir.glob("*.py")]
    found.sort()
    assert found == ["update_sdk_artifacts.py"]
def test_generate_types_wires_all_generation_steps() -> None:
    """``generate_types`` calls the three generation steps, in order, as bare statements.

    The script is parsed rather than imported; only top-level expression
    statements of ``generate_types`` that call a plain name are considered.
    """
    tree = ast.parse((ROOT / "scripts" / "update_sdk_artifacts.py").read_text())

    generate_types_fn = None
    for stmt in tree.body:
        if isinstance(stmt, ast.FunctionDef) and stmt.name == "generate_types":
            generate_types_fn = stmt
            break
    assert generate_types_fn is not None

    calls: list[str] = [
        stmt.value.func.id
        for stmt in generate_types_fn.body
        if isinstance(stmt, ast.Expr)
        and isinstance(stmt.value, ast.Call)
        and isinstance(stmt.value.func, ast.Name)
    ]
    assert calls == [
        "generate_v2_all",
        "generate_notification_registry",
        "generate_public_api_flat_methods",
    ]
def test_schema_normalization_only_flattens_string_literal_oneofs() -> None:
    """Only the five expected schema definitions are reported as flattened.

    Each dict-valued definition is shallow-copied before being handed to
    ``_flatten_string_enum_one_of``; names with a truthy result are collected
    in schema order.
    """
    script = _load_update_script_module()
    schema_path = (
        ROOT.parent.parent
        / "codex-rs"
        / "app-server-protocol"
        / "schema"
        / "json"
        / "codex_app_server_protocol.v2.schemas.json"
    )
    schema = json.loads(schema_path.read_text())

    flattened = []
    for name, definition in schema["definitions"].items():
        if not isinstance(definition, dict):
            continue
        if script._flatten_string_enum_one_of(definition.copy()):
            flattened.append(name)

    assert flattened == [
        "AuthMode",
        "CommandExecOutputStream",
        "ExperimentalFeatureStage",
        "InputModality",
        "MessagePhase",
    ]
def test_python_codegen_schema_annotation_adds_stable_variant_titles() -> None:
    """``_annotate_schema`` gives ``oneOf`` variants the expected titles.

    Checked on three definitions of the v2 protocol schema: ServerNotification,
    AskForApproval and ReasoningSummary.
    """
    script = _load_update_script_module()
    schema_path = (
        ROOT.parent.parent
        / "codex-rs"
        / "app-server-protocol"
        / "schema"
        / "json"
        / "codex_app_server_protocol.v2.schemas.json"
    )
    schema = json.loads(schema_path.read_text())
    script._annotate_schema(schema)
    definitions = schema["definitions"]

    def variant_titles(definition_name: str) -> list:
        # Titles of every oneOf variant, in schema order.
        return [
            variant.get("title")
            for variant in definitions[definition_name]["oneOf"]
        ]

    server_notification_titles = set()
    for variant in definitions["ServerNotification"]["oneOf"]:
        if isinstance(variant, dict):
            server_notification_titles.add(variant.get("title"))

    for expected in ("ErrorServerNotification", "ThreadStartedServerNotification"):
        assert expected in server_notification_titles
    for rejected in ("ErrorNotification", "Thread/startedNotification"):
        assert rejected not in server_notification_titles

    assert variant_titles("AskForApproval") == [
        "AskForApprovalValue",
        "RejectAskForApproval",
    ]
    assert variant_titles("ReasoningSummary") == [
        "ReasoningSummaryValue",
        "NoneReasoningSummary",
    ]
def test_generate_v2_all_uses_titles_for_generated_names() -> None:
    """The update script's source text mentions each expected code-generator option."""
    source = (ROOT / "scripts" / "update_sdk_artifacts.py").read_text()
    required_snippets = (
        "--use-title-as-name",
        "--use-annotated",
        "--formatters",
        "ruff-format",
    )
    for snippet in required_snippets:
        assert snippet in source
def test_runtime_package_template_has_no_checked_in_binaries() -> None:
    """The runtime package source tree holds only ``__init__.py`` (ignoring ``__pycache__``)."""
    runtime_root = ROOT.parent / "python-runtime" / "src" / "codex_cli_bin"
    file_names = [
        entry.name
        for entry in runtime_root.rglob("*")
        if entry.is_file() and "__pycache__" not in entry.parts
    ]
    file_names.sort()
    assert file_names == ["__init__.py"]
def test_runtime_package_is_wheel_only_and_builds_platform_specific_wheels() -> None:
    """Check the runtime package's hatch configuration and build hook statically.

    The hook file is parsed with ``ast`` rather than imported. The test asserts
    the exact wheel and sdist target tables in ``pyproject.toml``, that the
    hook's ``initialize`` contains an ``if self.target_name == "sdist"``
    statement, and that ``initialize`` assigns ``build_data["pure_python"] =
    False`` and ``build_data["infer_tag"] = True``.
    """
    pyproject = tomllib.loads(
        (ROOT.parent / "python-runtime" / "pyproject.toml").read_text()
    )
    hook_source = (ROOT.parent / "python-runtime" / "hatch_build.py").read_text()
    hook_tree = ast.parse(hook_source)
    # First function named "initialize" anywhere in the hook module; next() has
    # no default here, so a missing function surfaces as StopIteration.
    initialize_fn = next(
        node
        for node in ast.walk(hook_tree)
        if isinstance(node, ast.FunctionDef) and node.name == "initialize"
    )
    # First top-level statement of initialize() of the exact shape
    # `if self.target_name == "sdist":`, or None when there is none. Each
    # isinstance check guards the attribute access that follows it.
    sdist_guard = next(
        (
            node
            for node in initialize_fn.body
            if isinstance(node, ast.If)
            and isinstance(node.test, ast.Compare)
            and isinstance(node.test.left, ast.Attribute)
            and isinstance(node.test.left.value, ast.Name)
            and node.test.left.value.id == "self"
            and node.test.left.attr == "target_name"
            and len(node.test.ops) == 1
            and isinstance(node.test.ops[0], ast.Eq)
            and len(node.test.comparators) == 1
            and isinstance(node.test.comparators[0], ast.Constant)
            and node.test.comparators[0].value == "sdist"
        ),
        None,
    )
    # Map of key -> constant for every top-level `build_data["<key>"] = <constant>`
    # assignment in initialize(); assignments nested in other statements are not seen.
    build_data_assignments = {
        node.targets[0].slice.value: node.value.value
        for node in initialize_fn.body
        if isinstance(node, ast.Assign)
        and len(node.targets) == 1
        and isinstance(node.targets[0], ast.Subscript)
        and isinstance(node.targets[0].value, ast.Name)
        and node.targets[0].value.id == "build_data"
        and isinstance(node.targets[0].slice, ast.Constant)
        and isinstance(node.targets[0].slice.value, str)
        and isinstance(node.value, ast.Constant)
    }
    assert pyproject["tool"]["hatch"]["build"]["targets"]["wheel"] == {
        "packages": ["src/codex_cli_bin"],
        "include": ["src/codex_cli_bin/bin/**"],
        "hooks": {"custom": {}},
    }
    assert pyproject["tool"]["hatch"]["build"]["targets"]["sdist"] == {
        "hooks": {"custom": {}},
    }
    assert sdist_guard is not None
    assert build_data_assignments == {"pure_python": False, "infer_tag": True}
def test_stage_runtime_release_copies_binary_and_sets_version(tmp_path: Path) -> None:
    """Staging the runtime package copies the binary and writes the requested version."""
    script = _load_update_script_module()
    stage_dir = tmp_path / "runtime-stage"
    binary = tmp_path / script.runtime_binary_name()
    binary.write_text("fake codex\n")

    staged = script.stage_python_runtime_package(stage_dir, "1.2.3", binary)

    assert staged == tmp_path / "runtime-stage"
    staged_binary = script.staged_runtime_bin_path(staged)
    assert staged_binary.read_text() == "fake codex\n"
    staged_pyproject = (staged / "pyproject.toml").read_text()
    assert 'version = "1.2.3"' in staged_pyproject
def test_stage_runtime_release_replaces_existing_staging_dir(tmp_path: Path) -> None:
    """A pre-existing staging directory is replaced: its stale file is gone afterwards."""
    script = _load_update_script_module()

    staging_dir = tmp_path / "runtime-stage"
    staging_dir.mkdir(parents=True)
    leftover = staging_dir / "stale.txt"
    leftover.write_text("stale")

    binary = tmp_path / script.runtime_binary_name()
    binary.write_text("fake codex\n")

    staged = script.stage_python_runtime_package(staging_dir, "1.2.3", binary)

    assert staged == staging_dir
    assert not leftover.exists()
    assert script.staged_runtime_bin_path(staged).read_text() == "fake codex\n"
def test_stage_sdk_release_injects_exact_runtime_pin(tmp_path: Path) -> None:
    """The staged SDK pyproject carries the SDK version and an exact ``codex-cli-bin`` pin.

    It also checks that nothing matches ``bin/**`` inside the staged package.
    """
    script = _load_update_script_module()
    stage_dir = tmp_path / "sdk-stage"
    staged = script.stage_python_sdk_package(stage_dir, "0.2.1", "1.2.3")

    pyproject_text = (staged / "pyproject.toml").read_text()
    assert 'version = "0.2.1"' in pyproject_text
    assert '"codex-cli-bin==1.2.3"' in pyproject_text

    package_dir = staged / "src" / "codex_app_server"
    bundled = list(package_dir.glob("bin/**"))
    assert bundled == []
def test_stage_sdk_release_replaces_existing_staging_dir(tmp_path: Path) -> None:
    """Staging the SDK into an existing directory removes what was there before."""
    script = _load_update_script_module()

    staging_dir = tmp_path / "sdk-stage"
    staging_dir.mkdir(parents=True)
    leftover = staging_dir / "stale.txt"
    leftover.write_text("stale")

    staged = script.stage_python_sdk_package(staging_dir, "0.2.1", "1.2.3")

    assert staged == staging_dir
    assert not leftover.exists()
def test_stage_sdk_runs_type_generation_before_staging(tmp_path: Path) -> None:
    """``stage-sdk`` runs type generation first, then SDK staging, and never runtime staging.

    All operations are replaced with fakes through ``CliOps``; the fakes record
    the order in which they are invoked.
    """
    script = _load_update_script_module()
    sdk_stage = tmp_path / "sdk-stage"
    calls: list[str] = []

    def fake_stage_sdk_package(
        _staging_dir: Path, _sdk_version: str, _runtime_version: str
    ) -> Path:
        calls.append("stage_sdk")
        return sdk_stage

    def fake_stage_runtime_package(
        _staging_dir: Path, _runtime_version: str, _runtime_binary: Path
    ) -> Path:
        raise AssertionError("runtime staging should not run for stage-sdk")

    ops = script.CliOps(
        generate_types=lambda: calls.append("generate_types"),
        stage_python_sdk_package=fake_stage_sdk_package,
        stage_python_runtime_package=fake_stage_runtime_package,
        current_sdk_version=lambda: "0.2.0",
    )
    args = script.parse_args(
        ["stage-sdk", str(sdk_stage), "--runtime-version", "1.2.3"]
    )

    script.run_command(args, ops)

    assert calls == ["generate_types", "stage_sdk"]
def test_stage_runtime_stages_binary_without_type_generation(tmp_path: Path) -> None:
    """``stage-runtime`` invokes only runtime staging: no type generation, no SDK staging.

    The fake type generator records its call, so the final equality check
    fails if it is ever invoked.
    """
    script = _load_update_script_module()
    runtime_stage = tmp_path / "runtime-stage"
    binary = tmp_path / script.runtime_binary_name()
    binary.write_text("fake codex\n")
    calls: list[str] = []

    def fake_stage_sdk_package(
        _staging_dir: Path, _sdk_version: str, _runtime_version: str
    ) -> Path:
        raise AssertionError("sdk staging should not run for stage-runtime")

    def fake_stage_runtime_package(
        _staging_dir: Path, _runtime_version: str, _runtime_binary: Path
    ) -> Path:
        calls.append("stage_runtime")
        return runtime_stage

    ops = script.CliOps(
        generate_types=lambda: calls.append("generate_types"),
        stage_python_sdk_package=fake_stage_sdk_package,
        stage_python_runtime_package=fake_stage_runtime_package,
        current_sdk_version=lambda: "0.2.0",
    )
    args = script.parse_args(
        [
            "stage-runtime",
            str(runtime_stage),
            str(binary),
            "--runtime-version",
            "1.2.3",
        ]
    )

    script.run_command(args, ops)

    assert calls == ["stage_runtime"]
def test_default_runtime_is_resolved_from_installed_runtime_package(
    tmp_path: Path,
) -> None:
    """With no ``codex_bin`` configured, the resolver returns the installed runtime's binary."""
    from codex_app_server import client as client_module

    binary_name = "codex.exe" if client_module.os.name == "nt" else "codex"
    fake_binary = tmp_path / binary_name
    fake_binary.write_text("")

    def only_fake_binary_exists(path):
        return path == fake_binary

    ops = client_module.CodexBinResolverOps(
        installed_codex_path=lambda: fake_binary,
        path_exists=only_fake_binary_exists,
    )
    config = client_module.AppServerConfig()

    assert config.codex_bin is None
    resolved = client_module.resolve_codex_bin(config, ops)
    assert resolved == fake_binary
def test_explicit_codex_bin_override_takes_priority(tmp_path: Path) -> None:
    """An explicit ``codex_bin`` is returned without consulting the packaged runtime."""
    from codex_app_server import client as client_module

    binary_name = "custom-codex.exe" if client_module.os.name == "nt" else "custom-codex"
    explicit_binary = tmp_path / binary_name
    explicit_binary.write_text("")

    def packaged_lookup_must_not_run():
        # Any call to the packaged-runtime lookup fails the test.
        raise AssertionError("packaged runtime should not be used")

    ops = client_module.CodexBinResolverOps(
        installed_codex_path=packaged_lookup_must_not_run,
        path_exists=lambda path: path == explicit_binary,
    )
    config = client_module.AppServerConfig(codex_bin=str(explicit_binary))

    resolved = client_module.resolve_codex_bin(config, ops)
    assert resolved == explicit_binary
def test_missing_runtime_package_requires_explicit_codex_bin() -> None:
    """A ``FileNotFoundError`` from the packaged-runtime lookup propagates to the caller."""
    from codex_app_server import client as client_module

    def packaged_lookup_fails():
        raise FileNotFoundError("missing packaged runtime")

    ops = client_module.CodexBinResolverOps(
        installed_codex_path=packaged_lookup_fails,
        path_exists=lambda _path: False,
    )
    default_config = client_module.AppServerConfig()

    with pytest.raises(FileNotFoundError, match="missing packaged runtime"):
        client_module.resolve_codex_bin(default_config, ops)
def test_broken_runtime_package_does_not_fall_back() -> None:
    """The lookup's ``FileNotFoundError`` reaches the caller with its message unchanged."""
    from codex_app_server import client as client_module

    def packaged_lookup_fails():
        raise FileNotFoundError("missing packaged binary")

    ops = client_module.CodexBinResolverOps(
        installed_codex_path=packaged_lookup_fails,
        path_exists=lambda _path: False,
    )
    default_config = client_module.AppServerConfig()

    with pytest.raises(FileNotFoundError) as exc_info:
        client_module.resolve_codex_bin(default_config, ops)

    message = str(exc_info.value)
    assert message == ("missing packaged binary")

View File

@@ -0,0 +1,95 @@
from __future__ import annotations
from pathlib import Path
from typing import Any
from codex_app_server.client import AppServerClient, _params_dict
from codex_app_server.generated.v2_all import ThreadListParams, ThreadTokenUsageUpdatedNotification
from codex_app_server.models import UnknownNotification
ROOT = Path(__file__).resolve().parents[1]


def test_thread_set_name_and_compact_use_current_rpc_methods() -> None:
    """``thread_set_name`` and ``thread_compact`` issue the expected RPC method names.

    ``client.request`` is replaced with a fake that records each call and
    returns ``response_model`` validated from an empty dict.
    """
    recorded: list[tuple[str, dict[str, Any] | None]] = []

    def fake_request(method: str, params, *, response_model):  # type: ignore[no-untyped-def]
        recorded.append((method, params))
        return response_model.model_validate({})

    client = AppServerClient()
    client.request = fake_request  # type: ignore[method-assign]

    client.thread_set_name("thread-1", "sdk-name")
    client.thread_compact("thread-1")

    methods = [method for method, _params in recorded]
    assert methods[0] == "thread/name/set"
    assert methods[1] == "thread/compact/start"
def test_generated_params_models_are_snake_case_and_dump_by_alias() -> None:
    """Generated params use snake_case field names and ``_params_dict`` emits camelCase keys."""
    params = ThreadListParams(search_term="needle", limit=5)
    field_names = ThreadListParams.model_fields
    assert "search_term" in field_names

    expected_wire_form = {"searchTerm": "needle", "limit": 5}
    assert _params_dict(params) == expected_wire_form
def test_generated_v2_bundle_has_single_shared_plan_type_definition() -> None:
    """The generated ``v2_all.py`` source contains exactly one ``class PlanType(`` definition."""
    bundle_path = ROOT / "src" / "codex_app_server" / "generated" / "v2_all.py"
    occurrences = bundle_path.read_text().count("class PlanType(")
    assert occurrences == 1
def test_notifications_are_typed_with_canonical_v2_methods() -> None:
    """A valid ``thread/tokenUsage/updated`` notification is coerced into its typed model."""
    token_breakdown = {
        "cachedInputTokens": 0,
        "inputTokens": 1,
        "outputTokens": 2,
        "reasoningOutputTokens": 0,
        "totalTokens": 3,
    }
    params = {
        "threadId": "thread-1",
        "turnId": "turn-1",
        "tokenUsage": {
            # Separate copies so "last" and "total" stay distinct dict objects.
            "last": dict(token_breakdown),
            "total": dict(token_breakdown),
        },
    }

    client = AppServerClient()
    event = client._coerce_notification("thread/tokenUsage/updated", params)

    assert event.method == "thread/tokenUsage/updated"
    assert isinstance(event.payload, ThreadTokenUsageUpdatedNotification)
    assert event.payload.turn_id == "turn-1"
def test_unknown_notifications_fall_back_to_unknown_payloads() -> None:
    """A notification under the method "unknown/notification" is wrapped, params intact, in ``UnknownNotification``."""
    raw_params = {
        "id": "evt-1",
        "conversationId": "thread-1",
        "msg": {"type": "turn_aborted"},
    }

    client = AppServerClient()
    event = client._coerce_notification("unknown/notification", raw_params)

    assert event.method == "unknown/notification"
    payload = event.payload
    assert isinstance(payload, UnknownNotification)
    assert payload.params["msg"] == {"type": "turn_aborted"}
def test_invalid_notification_payload_falls_back_to_unknown() -> None:
    """A known method with an incomplete payload yields ``UnknownNotification``."""
    method = "thread/tokenUsage/updated"
    incomplete_params = {"threadId": "missing"}

    event = AppServerClient()._coerce_notification(method, incomplete_params)

    assert event.method == "thread/tokenUsage/updated"
    assert isinstance(event.payload, UnknownNotification)

View File

@@ -0,0 +1,52 @@
from __future__ import annotations
import os
import subprocess
import sys
from pathlib import Path
# Directory one level above the directory containing this test file.
ROOT = Path(__file__).resolve().parents[1]
# Paths, relative to ROOT, whose contents are snapshotted before and after
# regeneration by the drift test below.
GENERATED_TARGETS = [
    Path("src/codex_app_server/generated/notification_registry.py"),
    Path("src/codex_app_server/generated/v2_all.py"),
    Path("src/codex_app_server/public_api.py"),
]
def _snapshot_target(root: Path, rel_path: Path) -> dict[str, bytes] | bytes | None:
target = root / rel_path
if not target.exists():
return None
if target.is_file():
return target.read_bytes()
snapshot: dict[str, bytes] = {}
for path in sorted(target.rglob("*")):
if path.is_file() and "__pycache__" not in path.parts:
snapshot[str(path.relative_to(target))] = path.read_bytes()
return snapshot
def _snapshot_targets(root: Path) -> dict[str, dict[str, bytes] | bytes | None]:
    """Snapshot every entry of ``GENERATED_TARGETS`` under ``root``, keyed by its relative path string."""
    snapshots: dict[str, dict[str, bytes] | bytes | None] = {}
    for target in GENERATED_TARGETS:
        snapshots[str(target)] = _snapshot_target(root, target)
    return snapshots
def test_generated_files_are_up_to_date():
    """Regenerating the SDK artifacts must leave every generated target byte-identical.

    The generator is run in a subprocess from ``ROOT`` with the current
    interpreter; the interpreter's own directory is put first on ``PATH``
    (presumably so tools installed next to it are found -- confirm against the
    generator script).
    """
    baseline = _snapshot_targets(ROOT)

    # Regenerate contract artifacts via single maintenance entrypoint.
    interpreter_dir = str(Path(sys.executable).parent)
    child_env = dict(os.environ)
    child_env["PATH"] = os.pathsep.join([interpreter_dir, child_env.get("PATH", "")])
    command = [sys.executable, "scripts/update_sdk_artifacts.py", "generate-types"]
    subprocess.run(command, cwd=ROOT, check=True, env=child_env)

    regenerated = _snapshot_targets(ROOT)
    assert baseline == regenerated, "Generated files drifted after regeneration"