diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index f276868484..1bed79c3ca 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -21,7 +21,7 @@ "settings": { "terminal.integrated.defaultProfile.linux": "bash" }, - "extensions": ["rust-lang.rust-analyzer"] + "extensions": ["rust-lang.rust-analyzer", "tamasfe.even-better-toml"] } } } diff --git a/.github/actions/codex/bun.lock b/.github/actions/codex/bun.lock index e7382ff7a2..8b546a5ac6 100644 --- a/.github/actions/codex/bun.lock +++ b/.github/actions/codex/bun.lock @@ -8,8 +8,8 @@ "@actions/github": "^6.0.1", }, "devDependencies": { - "@types/bun": "^1.2.18", - "@types/node": "^24.0.13", + "@types/bun": "^1.2.19", + "@types/node": "^24.1.0", "prettier": "^3.6.2", "typescript": "^5.8.3", }, @@ -48,15 +48,15 @@ "@octokit/types": ["@octokit/types@13.10.0", "", { "dependencies": { "@octokit/openapi-types": "^24.2.0" } }, "sha512-ifLaO34EbbPj0Xgro4G5lP5asESjwHracYJvVaPIyXMuiuXLlhic3S47cBdTb+jfODkTE5YtGCLt3Ay3+J97sA=="], - "@types/bun": ["@types/bun@1.2.18", "", { "dependencies": { "bun-types": "1.2.18" } }, "sha512-Xf6RaWVheyemaThV0kUfaAUvCNokFr+bH8Jxp+tTZfx7dAPA8z9ePnP9S9+Vspzuxxx9JRAXhnyccRj3GyCMdQ=="], + "@types/bun": ["@types/bun@1.2.19", "", { "dependencies": { "bun-types": "1.2.19" } }, "sha512-d9ZCmrH3CJ2uYKXQIUuZ/pUnTqIvLDS0SK7pFmbx8ma+ziH/FRMoAq5bYpRG7y+w1gl+HgyNZbtqgMq4W4e2Lg=="], - "@types/node": ["@types/node@24.0.13", "", { "dependencies": { "undici-types": "~7.8.0" } }, "sha512-Qm9OYVOFHFYg3wJoTSrz80hoec5Lia/dPp84do3X7dZvLikQvM1YpmvTBEdIr/e+U8HTkFjLHLnl78K/qjf+jQ=="], + "@types/node": ["@types/node@24.1.0", "", { "dependencies": { "undici-types": "~7.8.0" } }, "sha512-ut5FthK5moxFKH2T1CUOC6ctR67rQRvvHdFLCD2Ql6KXmMuCrjsSsRI9UsLCm9M18BMwClv4pn327UvB7eeO1w=="], "@types/react": ["@types/react@19.1.8", "", { "dependencies": { "csstype": "^3.0.2" } }, "sha512-AwAfQ2Wa5bCx9WP8nZL2uMZWod7J7/JSplxbTmBQ5ms6QpqNYm672H0Vu9ZVKVngQ+ii4R/byguVEUZQyeg44g=="], "before-after-hook": ["before-after-hook@2.2.3", "", {}, "sha512-NzUnlZexiaH/46WDhANlyR2bXRopNg4F/zuSA3OpZnllCUgRaOF2znDioDWrmbNVsuZk6l9pMquQB38cfBZwkQ=="], - "bun-types": ["bun-types@1.2.18", "", { "dependencies": { "@types/node": "*" }, "peerDependencies": { "@types/react": "^19" } }, "sha512-04+Eha5NP7Z0A9YgDAzMk5PHR16ZuLVa83b26kH5+cp1qZW4F6FmAURngE7INf4tKOvCE69vYvDEwoNl1tGiWw=="], + "bun-types": ["bun-types@1.2.19", "", { "dependencies": { "@types/node": "*" }, "peerDependencies": { "@types/react": "^19" } }, "sha512-uAOTaZSPuYsWIXRpj7o56Let0g/wjihKCkeRqUBhlLVM/Bt+Fj9xTo+LhC1OV1XDaGkz4hNC80et5xgy+9KTHQ=="], "csstype": ["csstype@3.1.3", "", {}, "sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw=="], @@ -82,6 +82,8 @@ "@octokit/plugin-rest-endpoint-methods/@octokit/types": ["@octokit/types@12.6.0", "", { "dependencies": { "@octokit/openapi-types": "^20.0.0" } }, "sha512-1rhSOfRa6H9w4YwK0yrf5faDaDTb+yLyBUKOCV4xtCDB5VmIPqd/v9yr9o6SAzOAlRxMiRiCic6JVM1/kunVkw=="], + "bun-types/@types/node": ["@types/node@24.0.13", "", { "dependencies": { "undici-types": "~7.8.0" } }, "sha512-Qm9OYVOFHFYg3wJoTSrz80hoec5Lia/dPp84do3X7dZvLikQvM1YpmvTBEdIr/e+U8HTkFjLHLnl78K/qjf+jQ=="], + "@octokit/plugin-paginate-rest/@octokit/types/@octokit/openapi-types": ["@octokit/openapi-types@20.0.0", "", {}, "sha512-EtqRBEjp1dL/15V7WiX5LJMIxxkdiGJnabzYx5Apx4FkQIFgAfKumXeYAqqJCj1s+BMX4cPFIFC4OLCR6stlnA=="], "@octokit/plugin-rest-endpoint-methods/@octokit/types/@octokit/openapi-types": ["@octokit/openapi-types@20.0.0", "", {}, "sha512-EtqRBEjp1dL/15V7WiX5LJMIxxkdiGJnabzYx5Apx4FkQIFgAfKumXeYAqqJCj1s+BMX4cPFIFC4OLCR6stlnA=="], diff --git a/.github/actions/codex/package.json b/.github/actions/codex/package.json index 53260f4d58..21817b8a59 100644 --- a/.github/actions/codex/package.json +++ b/.github/actions/codex/package.json @@ -13,8 +13,8 @@ "@actions/github": "^6.0.1" }, "devDependencies": { - "@types/bun": "^1.2.18", - "@types/node": "^24.0.13", + "@types/bun": "^1.2.19", + "@types/node": "^24.1.0", "prettier": "^3.6.2", "typescript": "^5.8.3" } diff --git a/.github/codex/labels/codex-rust-review.md b/.github/codex/labels/codex-rust-review.md new file mode 100644 index 0000000000..2c2893a1fe --- /dev/null +++ b/.github/codex/labels/codex-rust-review.md @@ -0,0 +1,23 @@ +Review this PR and respond with a very concise final message, formatted in Markdown. + +There should be a summary of the changes (1-2 sentences) and a few bullet points if necessary. + +Then provide the **review** (1-2 sentences plus bullet points, friendly tone). + +Things to look out for when doing the review: + +- **Make sure the pull request body explains the motivation behind the change.** If the author has failed to do this, call it out, and if you think you can deduce the motivation behind the change, propose copy. +- Ideally, the PR body also contains a small summary of the change. For small changes, the PR title may be sufficient. +- Each PR should ideally do one conceptual thing. For example, if a PR does a refactoring as well as introducing a new feature, push back and suggest the refactoring be done in a separate PR. This makes things easier for the reviewer, as refactoring changes can often be far-reaching, yet quick to review. +- If the nature of the change seems to have a visual component (which is often the case for changes to `codex-rs/tui`), recommend including a screenshot or video to demonstrate the change, if appropriate. +- Rust files should generally be organized such that the public parts of the API appear near the top of the file and helper functions go below. This is analagous to the "inverted pyramid" structure that is favored in journalism. +- Encourage the use of small enums or the newtype pattern in Rust if it helps readability without adding significant cognitive load or lines of code. +- Be wary of large files and offer suggestions for how to break things into more reasonably-sized files. +- When modifying a `Cargo.toml` file, make sure that dependency lists stay alphabetically sorted. Also consider whether a new dependency is added to the appropriate place (e.g., `[dependencies]` versus `[dev-dependencies]`) +- If you see opportunities for the changes in a diff to use more idiomatic Rust, please make specific recommendations. For example, favor the use of expressions over `return`. +- When introducing new code, be on the lookout for code that duplicates existing code. When found, propose a way to refactor the existing code such that it should be reused. +- Each create in the Cargo workspace in `codex-rs` has a specific purpose: make a note if you believe new code is not introduced in the correct crate. +- When possible, try to keep the `core` crate as small as possible. Non-core but shared logic is often a good candidate for `codex-rs/common`. +- References to existing GitHub issues and PRs are encouraged, where appropriate, though you likely do not have network access, so may not be able to help here. + +{CODEX_ACTION_GITHUB_EVENT_PATH} contains the JSON that triggered this GitHub workflow. It contains the `base` and `head` refs that define this PR. Both refs are available locally. diff --git a/.github/workflows/codex.yml b/.github/workflows/codex.yml index a0ac5b9740..18fe74cc85 100644 --- a/.github/workflows/codex.yml +++ b/.github/workflows/codex.yml @@ -20,7 +20,7 @@ jobs: (github.event_name == 'issues' && ( (github.event.action == 'labeled' && (github.event.label.name == 'codex-attempt' || github.event.label.name == 'codex-triage')) )) || - (github.event_name == 'pull_request' && github.event.action == 'labeled' && github.event.label.name == 'codex-review') + (github.event_name == 'pull_request' && github.event.action == 'labeled' && (github.event.label.name == 'codex-review' || github.event.label.name == 'codex-rust-review')) runs-on: ubuntu-latest permissions: contents: write # can push or create branches diff --git a/.github/workflows/rust-release.yml b/.github/workflows/rust-release.yml index 7b765bed17..3f1c084d91 100644 --- a/.github/workflows/rust-release.yml +++ b/.github/workflows/rust-release.yml @@ -93,7 +93,7 @@ jobs: sudo apt install -y musl-tools pkg-config - name: Cargo build - run: cargo build --target ${{ matrix.target }} --release --all-targets --all-features + run: cargo build --target ${{ matrix.target }} --release --bin codex --bin codex-exec --bin codex-linux-sandbox - name: Stage artifacts shell: bash diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000000..dd5dac527d --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,5 @@ +{ + "recommendations": [ + "tamasfe.even-better-toml", + ] +} diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000000..618207f301 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,18 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "Cargo launch", + "cargo": { + "cwd": "${workspaceFolder}/codex-rs", + "args": [ + "build", + "--bin=codex-tui" + ] + }, + "args": [] + } + ] +} diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000000..1712f5989b --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,16 @@ +{ + "rust-analyzer.checkOnSave": true, + "rust-analyzer.check.command": "clippy", + "rust-analyzer.check.extraArgs": ["--all-features", "--tests"], + "rust-analyzer.rustfmt.extraArgs": ["--config", "imports_granularity=Item"], + "[rust]": { + "editor.defaultFormatter": "rust-lang.rust-analyzer", + "editor.formatOnSave": true, + }, + "[toml]": { + "editor.defaultFormatter": "tamasfe.even-better-toml", + "editor.formatOnSave": true, + }, + "evenBetterToml.formatter.reorderArrays": true, + "evenBetterToml.formatter.reorderKeys": true, +} diff --git a/AGENTS.md b/AGENTS.md index 1348e57824..27af48ae60 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -3,3 +3,7 @@ In the codex-rs folder where the rust code lives: - Never add or modify any code related to `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR`. You operate in a sandbox where `CODEX_SANDBOX_NETWORK_DISABLED=1` will be set whenever you use the `shell` tool. Any existing code that uses `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR` was authored with this fact in mind. It is often used to early exit out of tests that the author knew you would not be able to run given your sandbox limitations. + +Before creating a pull request with changes to `codex-rs`, run `just fmt` (in `codex-rs` directory) to format the code and `just fix` (in `codex-rs` directory) to fix any linter issues in the code, ensure the test suite passes by running `cargo test --all-features` in the `codex-rs` directory. + +When making individual changes prefer running tests on individual files or projects first. diff --git a/NOTICE b/NOTICE index ad09ca421e..2805899d56 100644 --- a/NOTICE +++ b/NOTICE @@ -1,2 +1,6 @@ OpenAI Codex Copyright 2025 OpenAI + +This project includes code derived from [Ratatui](https://github.com/ratatui/ratatui), licensed under the MIT license. +Copyright (c) 2016-2022 Florian Dehau +Copyright (c) 2023-2025 The Ratatui Developers diff --git a/README.md b/README.md index 60e44298a3..c7f6a1d595 100644 --- a/README.md +++ b/README.md @@ -95,6 +95,12 @@ codex login If you complete the process successfully, you should have a `~/.codex/auth.json` file that contains the credentials that Codex will use. +To verify whether you are currently logged in, run: + +``` +codex login status +``` + If you encounter problems with the login flow, please comment on .
diff --git a/codex-cli/bin/codex.js b/codex-cli/bin/codex.js index 54b99078e4..ae1fb9593c 100755 --- a/codex-cli/bin/codex.js +++ b/codex-cli/bin/codex.js @@ -15,7 +15,6 @@ * current platform / architecture, an error is thrown. */ -import { spawnSync } from "child_process"; import fs from "fs"; import path from "path"; import { fileURLToPath, pathToFileURL } from "url"; @@ -35,7 +34,7 @@ const wantsNative = fs.existsSync(path.join(__dirname, "use-native")) || : false); // Try native binary if requested. -if (wantsNative) { +if (wantsNative && process.platform !== 'win32') { const { platform, arch } = process; let targetTriple = null; @@ -74,22 +73,76 @@ if (wantsNative) { } const binaryPath = path.join(__dirname, "..", "bin", `codex-${targetTriple}`); - const result = spawnSync(binaryPath, process.argv.slice(2), { + + // Use an asynchronous spawn instead of spawnSync so that Node is able to + // respond to signals (e.g. Ctrl-C / SIGINT) while the native binary is + // executing. This allows us to forward those signals to the child process + // and guarantees that when either the child terminates or the parent + // receives a fatal signal, both processes exit in a predictable manner. + const { spawn } = await import("child_process"); + + const child = spawn(binaryPath, process.argv.slice(2), { stdio: "inherit", }); - const exitCode = typeof result.status === "number" ? result.status : 1; - process.exit(exitCode); -} + child.on("error", (err) => { + // Typically triggered when the binary is missing or not executable. + // Re-throwing here will terminate the parent with a non-zero exit code + // while still printing a helpful stack trace. + // eslint-disable-next-line no-console + console.error(err); + process.exit(1); + }); -// Fallback: execute the original JavaScript CLI. + // Forward common termination signals to the child so that it shuts down + // gracefully. In the handler we temporarily disable the default behavior of + // exiting immediately; once the child has been signaled we simply wait for + // its exit event which will in turn terminate the parent (see below). + const forwardSignal = (signal) => { + if (child.killed) { + return; + } + try { + child.kill(signal); + } catch { + /* ignore */ + } + }; -// Resolve the path to the compiled CLI bundle -const cliPath = path.resolve(__dirname, "../dist/cli.js"); -const cliUrl = pathToFileURL(cliPath).href; + ["SIGINT", "SIGTERM", "SIGHUP"].forEach((sig) => { + process.on(sig, () => forwardSignal(sig)); + }); -// Load and execute the CLI -(async () => { + // When the child exits, mirror its termination reason in the parent so that + // shell scripts and other tooling observe the correct exit status. + // Wrap the lifetime of the child process in a Promise so that we can await + // its termination in a structured way. The Promise resolves with an object + // describing how the child exited: either via exit code or due to a signal. + const childResult = await new Promise((resolve) => { + child.on("exit", (code, signal) => { + if (signal) { + resolve({ type: "signal", signal }); + } else { + resolve({ type: "code", exitCode: code ?? 1 }); + } + }); + }); + + if (childResult.type === "signal") { + // Re-emit the same signal so that the parent terminates with the expected + // semantics (this also sets the correct exit code of 128 + n). + process.kill(process.pid, childResult.signal); + } else { + process.exit(childResult.exitCode); + } +} else { + // Fallback: execute the original JavaScript CLI. + + // Resolve the path to the compiled CLI bundle + const cliPath = path.resolve(__dirname, "../dist/cli.js"); + const cliUrl = pathToFileURL(cliPath).href; + + // Load and execute the CLI try { await import(cliUrl); } catch (err) { @@ -97,4 +150,4 @@ const cliUrl = pathToFileURL(cliPath).href; console.error(err); process.exit(1); } -})(); +} diff --git a/codex-cli/scripts/README.md b/codex-cli/scripts/README.md new file mode 100644 index 0000000000..21e4f3e883 --- /dev/null +++ b/codex-cli/scripts/README.md @@ -0,0 +1,9 @@ +# npm releases + +Run the following: + +To build the 0.2.x or later version of the npm module, which runs the Rust version of the CLI, build it as follows: + +```bash +./codex-cli/scripts/stage_rust_release.py --release-version 0.6.0 +``` diff --git a/codex-cli/scripts/stage_release.sh b/codex-cli/scripts/stage_release.sh index 29b9f76783..cd32ade6f9 100755 --- a/codex-cli/scripts/stage_release.sh +++ b/codex-cli/scripts/stage_release.sh @@ -4,10 +4,7 @@ # ----------------------------------------------------------------------------- # Stages an npm release for @openai/codex. # -# The script used to accept a single optional positional argument that indicated -# the temporary directory in which to stage the package. We now support a -# flag-based interface so that we can extend the command with further options -# without breaking the call-site contract. +# Usage: # # --tmp : Use instead of a freshly created temp directory. # --native : Bundle the pre-built Rust CLI binaries for Linux alongside @@ -141,7 +138,8 @@ popd >/dev/null echo "Staged version $VERSION for release in $TMPDIR" if [[ "$INCLUDE_NATIVE" -eq 1 ]]; then - echo "Test Rust:" + echo "Verify the CLI:" + echo " node ${TMPDIR}/bin/codex.js --version" echo " node ${TMPDIR}/bin/codex.js --help" else echo "Test Node:" diff --git a/codex-cli/src/approvals.ts b/codex-cli/src/approvals.ts index e626da7fa5..35b8c0ae16 100644 --- a/codex-cli/src/approvals.ts +++ b/codex-cli/src/approvals.ts @@ -370,11 +370,26 @@ export function isSafeCommand( reason: "View file with line numbers", group: "Reading files", }; - case "rg": + case "rg": { + // Certain ripgrep options execute external commands or invoke other + // processes, so we must reject them. + const isUnsafe = command.some( + (arg: string) => + UNSAFE_OPTIONS_FOR_RIPGREP_WITHOUT_ARGS.has(arg) || + [...UNSAFE_OPTIONS_FOR_RIPGREP_WITH_ARGS].some( + (opt) => arg === opt || arg.startsWith(`${opt}=`), + ), + ); + + if (isUnsafe) { + break; + } + return { reason: "Ripgrep search", group: "Searching", }; + } case "find": { // Certain options to `find` allow executing arbitrary processes, so we // cannot auto-approve them. @@ -495,6 +510,22 @@ const UNSAFE_OPTIONS_FOR_FIND_COMMAND: ReadonlySet = new Set([ "-fprintf", ]); +// Ripgrep options that are considered unsafe because they may execute +// arbitrary commands or spawn auxiliary processes. +const UNSAFE_OPTIONS_FOR_RIPGREP_WITH_ARGS: ReadonlySet = new Set([ + // Executes an arbitrary command for each matching file. + "--pre", + // Allows custom hostname command which could leak environment details. + "--hostname-bin", +]); + +const UNSAFE_OPTIONS_FOR_RIPGREP_WITHOUT_ARGS: ReadonlySet = new Set([ + // Enables searching inside archives which triggers external decompression + // utilities – reject out of an abundance of caution. + "--search-zip", + "-z", +]); + // ---------------- Helper utilities for complex shell expressions ----------------- // A conservative allow-list of bash operators that do not, on their own, cause diff --git a/codex-cli/tests/approvals.test.ts b/codex-cli/tests/approvals.test.ts index c592c39525..645ab44ce9 100644 --- a/codex-cli/tests/approvals.test.ts +++ b/codex-cli/tests/approvals.test.ts @@ -44,6 +44,14 @@ describe("canAutoApprove()", () => { group: "Navigating", runInSandbox: false, }); + + // Ripgrep safe invocation. + expect(check(["rg", "TODO"])).toEqual({ + type: "auto-approve", + reason: "Ripgrep search", + group: "Searching", + runInSandbox: false, + }); }); test("simple safe commands within a `bash -lc` call", () => { @@ -67,6 +75,24 @@ describe("canAutoApprove()", () => { }); }); + test("ripgrep unsafe flags", () => { + // Flags that do not take arguments + expect(check(["rg", "--search-zip", "TODO"])).toEqual({ type: "ask-user" }); + expect(check(["rg", "-z", "TODO"])).toEqual({ type: "ask-user" }); + + // Flags that take arguments (provided separately) + expect(check(["rg", "--pre", "cat", "TODO"])).toEqual({ type: "ask-user" }); + expect(check(["rg", "--hostname-bin", "hostname", "TODO"])).toEqual({ + type: "ask-user", + }); + + // Flags that take arguments in = form + expect(check(["rg", "--pre=cat", "TODO"])).toEqual({ type: "ask-user" }); + expect(check(["rg", "--hostname-bin=hostname", "TODO"])).toEqual({ + type: "ask-user", + }); + }); + test("bash -lc commands with unsafe redirects", () => { expect(check(["bash", "-lc", "echo hello > file.txt"])).toEqual({ type: "ask-user", diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 3de3e78198..120050c227 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -250,6 +250,28 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", +] + [[package]] name = "async-trait" version = "0.1.88" @@ -377,6 +399,15 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6099cdc01846bc367c4e7dd630dc5966dccf36b652fae7a74e17b640411a91b2" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bstr" version = "1.12.0" @@ -432,18 +463,18 @@ checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" [[package]] name = "castaway" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0abae9be0aaf9ea96a3b1b8b1b55c602ca751eba1b1500220cea4ecbafe7c0d5" +checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a" dependencies = [ "rustversion", ] [[package]] name = "cc" -version = "1.2.29" +version = "1.2.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c1599538de2394445747c8cf7935946e3cc27e9625f889d979bfb2aaf569362" +checksum = "deec109607ca693028562ed836a5f1c4b8bd77755c4e132fc5ce11b0b6211ae7" dependencies = [ "jobserver", "libc", @@ -539,9 +570,9 @@ checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" [[package]] name = "clipboard-win" -version = "5.4.0" +version = "5.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15efe7a882b08f34e38556b14f2fb3daa98769d06c7f0c1b076dfd0d983bc892" +checksum = "bde03770d3df201d4fb868f2c9c59e66a3e4e2bd06692a0fe701e7103c7e84d4" dependencies = [ "error-code", ] @@ -574,6 +605,18 @@ dependencies = [ "tree-sitter-bash", ] +[[package]] +name = "codex-arg0" +version = "0.0.0" +dependencies = [ + "anyhow", + "codex-apply-patch", + "codex-core", + "codex-linux-sandbox", + "dotenvy", + "tokio", +] + [[package]] name = "codex-chatgpt" version = "0.0.0" @@ -597,11 +640,11 @@ dependencies = [ "anyhow", "clap", "clap_complete", + "codex-arg0", "codex-chatgpt", "codex-common", "codex-core", "codex-exec", - "codex-linux-sandbox", "codex-login", "codex-mcp-server", "codex-tui", @@ -618,7 +661,7 @@ dependencies = [ "clap", "codex-core", "serde", - "toml 0.9.1", + "toml 0.9.2", ] [[package]] @@ -630,36 +673,45 @@ dependencies = [ "async-channel", "base64 0.22.1", "bytes", + "chrono", "codex-apply-patch", + "codex-login", "codex-mcp-client", + "core_test_support", "dirs", "env-flags", "eventsource-stream", "fs2", "futures", "landlock", + "libc", "maplit", "mcp-types", "mime_guess", "openssl-sys", "predicates", "pretty_assertions", - "rand 0.9.1", + "rand 0.9.2", "reqwest", "seccompiler", "serde", "serde_json", - "strum_macros 0.27.1", + "sha1", + "shlex", + "strum_macros 0.27.2", "tempfile", "thiserror 2.0.12", "time", "tokio", + "tokio-test", "tokio-util", - "toml 0.9.1", + "toml 0.9.2", "tracing", "tree-sitter", "tree-sitter-bash", "uuid", + "walkdir", + "whoami", "wildmatch", "wiremock", ] @@ -669,14 +721,17 @@ name = "codex-exec" version = "0.0.0" dependencies = [ "anyhow", + "assert_cmd", "chrono", "clap", + "codex-arg0", "codex-common", "codex-core", - "codex-linux-sandbox", "owo-colors", + "predicates", "serde_json", "shlex", + "tempfile", "tokio", "tracing", "tracing-subscriber", @@ -721,6 +776,7 @@ version = "0.0.0" dependencies = [ "anyhow", "clap", + "codex-common", "codex-core", "landlock", "libc", @@ -758,17 +814,25 @@ name = "codex-mcp-server" version = "0.0.0" dependencies = [ "anyhow", + "assert_cmd", + "codex-arg0", "codex-core", - "codex-linux-sandbox", "mcp-types", + "mcp_test_support", "pretty_assertions", "schemars 0.8.22", "serde", "serde_json", + "shlex", + "strum_macros 0.27.2", + "tempfile", "tokio", - "toml 0.9.1", + "tokio-test", + "toml 0.9.2", "tracing", "tracing-subscriber", + "uuid", + "wiremock", ] [[package]] @@ -779,10 +843,10 @@ dependencies = [ "base64 0.22.1", "clap", "codex-ansi-escape", + "codex-arg0", "codex-common", "codex-core", "codex-file-search", - "codex-linux-sandbox", "codex-login", "color-eyre", "crossterm", @@ -797,8 +861,8 @@ dependencies = [ "regex-lite", "serde_json", "shlex", - "strum 0.27.1", - "strum_macros 0.27.1", + "strum 0.27.2", + "strum_macros 0.27.2", "tokio", "tracing", "tracing-appender", @@ -807,6 +871,7 @@ dependencies = [ "tui-markdown", "tui-textarea", "unicode-segmentation", + "unicode-width 0.1.14", "uuid", ] @@ -910,10 +975,29 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] -name = "crc32fast" -version = "1.4.2" +name = "core_test_support" +version = "0.0.0" +dependencies = [ + "codex-core", + "serde_json", + "tempfile", + "tokio", +] + +[[package]] +name = "cpufeatures" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" dependencies = [ "cfg-if", ] @@ -983,6 +1067,16 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "ctor" version = "0.1.26" @@ -1133,6 +1227,16 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6184e33543162437515c2e2b48714794e37845ec9851711914eec9d308f6ebe8" +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "dirs" version = "6.0.0" @@ -1202,6 +1306,12 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" +[[package]] +name = "dotenvy" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" + [[package]] name = "dupe" version = "0.9.1" @@ -1434,7 +1544,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78" dependencies = [ "cfg-if", - "rustix 1.0.7", + "rustix 1.0.8", "windows-sys 0.59.0", ] @@ -1622,6 +1732,16 @@ dependencies = [ "byteorder", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getopts" version = "0.2.23" @@ -1873,9 +1993,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.15" +version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f66d5bd4c6f02bf0542fad85d626775bab9258cf795a4256dcaf3161114d1df" +checksum = "8d9b05277c7e8da2c93a568989bb6207bef0112e8d17df7a6eda4a3cf143bc5e" dependencies = [ "base64 0.22.1", "bytes", @@ -1889,7 +2009,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2", + "socket2 0.6.0", "system-configuration", "tokio", "tower-service", @@ -2142,9 +2262,9 @@ dependencies = [ [[package]] name = "instability" -version = "0.3.7" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf9fed6d91cfb734e7476a06bde8300a1b94e217e1b523b6f0cd1a01998c71d" +checksum = "435d80800b936787d62688c927b6490e887c7ef5ff9ce922c6c6050fca75eb9a" dependencies = [ "darling", "indoc", @@ -2175,9 +2295,9 @@ dependencies = [ [[package]] name = "io-uring" -version = "0.7.8" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b86e202f00093dcba4275d4636b93ef9dd75d025ae560d2521b45ea28ab49013" +checksum = "d93587f37623a1a17d94ef2bc9ada592f5465fe7732084ab7beefabe5c77c0c4" dependencies = [ "bitflags 2.9.1", "cfg-if", @@ -2381,9 +2501,9 @@ dependencies = [ [[package]] name = "libredox" -version = "0.1.4" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1580801010e535496706ba011c15f8532df6b42297d2e471fec38ceadd8c0638" +checksum = "4488594b9328dee448adb906d8b126d9b7deb7cf5c22161ee591610bb1be83c0" dependencies = [ "bitflags 2.9.1", "libc", @@ -2516,6 +2636,22 @@ dependencies = [ "serde_json", ] +[[package]] +name = "mcp_test_support" +version = "0.0.0" +dependencies = [ + "anyhow", + "assert_cmd", + "codex-mcp-server", + "mcp-types", + "pretty_assertions", + "serde_json", + "shlex", + "tempfile", + "tokio", + "wiremock", +] + [[package]] name = "memchr" version = "2.7.5" @@ -3191,9 +3327,9 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.1" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.3", @@ -3240,8 +3376,7 @@ dependencies = [ [[package]] name = "ratatui" version = "0.29.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eabd94c2f37801c20583fc49dd5cd6b0ba68c716787c2dd6ed18571e1e63117b" +source = "git+https://github.com/nornagon/ratatui?branch=nornagon-v0.29.0-patch#bca287ddc5d38fe088c79e2eda22422b96226f2e" dependencies = [ "bitflags 2.9.1", "cassowary", @@ -3346,9 +3481,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.13" +version = "0.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d04b7d0ee6b4a0207a0a7adb104d23ecb0b47d6beae7152d0fa34b692b29fd6" +checksum = "7e8af0dde094006011e6a740d4879319439489813bd0bcdc7d821beaeeff48ec" dependencies = [ "bitflags 2.9.1", ] @@ -3496,9 +3631,9 @@ dependencies = [ [[package]] name = "rgb" -version = "0.8.51" +version = "0.8.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a457e416a0f90d246a4c3288bd7a25b2304ca727f253f95be383dd17af56be8f" +checksum = "0c6a884d2998352bb4daf0183589aec883f16a6da1f4dde84d8e2e9a5409a1ce" [[package]] name = "ring" @@ -3574,22 +3709,22 @@ dependencies = [ [[package]] name = "rustix" -version = "1.0.7" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" +checksum = "11181fbabf243db407ef8df94a6ce0b2f9a733bd8be4ad02b4eda9602296cac8" dependencies = [ "bitflags 2.9.1", "errno", "libc", "linux-raw-sys 0.9.4", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] name = "rustls" -version = "0.23.28" +version = "0.23.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7160e3e10bf4535308537f3c4e1641468cd0e485175d6163087c0393c7d46643" +checksum = "2491382039b29b9b11ff08b76ff6c97cf287671dbb74f0be44bda389fffe9bd1" dependencies = [ "once_cell", "rustls-pki-types", @@ -3609,9 +3744,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.103.3" +version = "0.103.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4a72fe2bcf7a6ac6fd7d0b9e5cb68aeb7d4c0a0271730218b3e92d43b4eb435" +checksum = "0a17884ae0c1b773f1ccd2bd4a8c72f16da897310a98b0e84bf349ad5ead92fc" dependencies = [ "ring", "rustls-pki-types", @@ -3837,9 +3972,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.140" +version = "1.0.141" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +checksum = "30b9eff21ebe718216c6ec64e1d9ac57087aad11efc64e32002bce4a0d4c03d3" dependencies = [ "indexmap 2.10.0", "itoa", @@ -3921,6 +4056,17 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -4021,6 +4167,16 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "socket2" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "233504af464074f9d066d7b5416c5f9b894a5862a6506e306f7b816cdd6f1807" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -4164,9 +4320,9 @@ dependencies = [ [[package]] name = "strum" -version = "0.27.1" +version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32" +checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" [[package]] name = "strum_macros" @@ -4183,14 +4339,13 @@ dependencies = [ [[package]] name = "strum_macros" -version = "0.27.1" +version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8" +checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7" dependencies = [ "heck", "proc-macro2", "quote", - "rustversion", "syn 2.0.104", ] @@ -4313,7 +4468,7 @@ dependencies = [ "fastrand", "getrandom 0.3.3", "once_cell", - "rustix 1.0.7", + "rustix 1.0.8", "windows-sys 0.59.0", ] @@ -4334,7 +4489,7 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "45c6481c4829e4cc63825e62c49186a34538b7b2750b73b266581ffb612fb5ed" dependencies = [ - "rustix 1.0.7", + "rustix 1.0.8", "windows-sys 0.59.0", ] @@ -4480,7 +4635,7 @@ dependencies = [ "pin-project-lite", "signal-hook-registry", "slab", - "socket2", + "socket2 0.5.10", "tokio-macros", "windows-sys 0.52.0", ] @@ -4516,6 +4671,30 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-test" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2468baabc3311435b55dd935f702f42cd1b8abb7e754fb7dfb16bd36aa88f9f7" +dependencies = [ + "async-stream", + "bytes", + "futures-core", + "tokio", + "tokio-stream", +] + [[package]] name = "tokio-util" version = "0.7.15" @@ -4543,9 +4722,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.9.1" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0207d6ed1852c2a124c1fbec61621acb8330d2bf969a5d0643131e9affd985a5" +checksum = "ed0aee96c12fa71097902e0bb061a5e1ebd766a6636bb605ba401c45c1650eac" dependencies = [ "indexmap 2.10.0", "serde", @@ -4589,18 +4768,18 @@ dependencies = [ [[package]] name = "toml_parser" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5c1c469eda89749d2230d8156a5969a69ffe0d6d01200581cdc6110674d293e" +checksum = "97200572db069e74c512a14117b296ba0a80a30123fbbb5aa1f4a348f639ca30" dependencies = [ "winnow", ] [[package]] name = "toml_writer" -version = "1.0.0" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b679217f2848de74cabd3e8fc5e6d66f40b7da40f8e1954d92054d9010690fd5" +checksum = "fcc842091f2def52017664b53082ecbbeb5c7731092bad69d2c63050401dfd64" [[package]] name = "tower" @@ -4733,9 +4912,9 @@ dependencies = [ [[package]] name = "tree-sitter" -version = "0.25.6" +version = "0.25.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7cf18d43cbf0bfca51f657132cc616a5097edc4424d538bae6fa60142eaf9f0" +checksum = "6d7b8994f367f16e6fa14b5aebbcb350de5d7cbea82dc5b00ae997dd71680dd2" dependencies = [ "cc", "regex", @@ -4804,6 +4983,12 @@ dependencies = [ "unicode-width 0.2.0", ] +[[package]] +name = "typenum" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" + [[package]] name = "unicase" version = "2.8.1" @@ -4971,6 +5156,12 @@ dependencies = [ "wit-bindgen-rt", ] +[[package]] +name = "wasite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" + [[package]] name = "wasm-bindgen" version = "0.2.100" @@ -5071,6 +5262,17 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a751b3277700db47d3e574514de2eced5e54dc8a5436a3bf7a0b248b2cee16f3" +[[package]] +name = "whoami" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6994d13118ab492c3c80c1f81928718159254c53c472bf9ce36f8dae4add02a7" +dependencies = [ + "redox_syscall", + "wasite", + "web-sys", +] + [[package]] name = "wildmatch" version = "2.4.0" @@ -5399,9 +5601,9 @@ checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" [[package]] name = "winnow" -version = "0.7.11" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74c7b26e3480b707944fc872477815d29a8e429d2f93a1ce000f5fa84a15cbcd" +checksum = "f3edebf492c8125044983378ecb5766203ad3b4c2f7a922bd7dd207f6d443e95" dependencies = [ "memchr", ] diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index eba43e548b..0f8085c7e5 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -1,8 +1,8 @@ [workspace] -resolver = "2" members = [ "ansi-escape", "apply-patch", + "arg0", "cli", "common", "core", @@ -16,6 +16,7 @@ members = [ "mcp-types", "tui", ] +resolver = "2" [workspace.package] version = "0.0.0" @@ -40,3 +41,7 @@ strip = "symbols" # See https://github.com/openai/codex/issues/1411 for details. codegen-units = 1 + +[patch.crates-io] +# ratatui = { path = "../../ratatui" } +ratatui = { git = "https://github.com/nornagon/ratatui", branch = "nornagon-v0.29.0-patch" } diff --git a/codex-rs/ansi-escape/Cargo.toml b/codex-rs/ansi-escape/Cargo.toml index 9092c77c9c..ada675380d 100644 --- a/codex-rs/ansi-escape/Cargo.toml +++ b/codex-rs/ansi-escape/Cargo.toml @@ -1,7 +1,7 @@ [package] +edition = "2024" name = "codex-ansi-escape" version = { workspace = true } -edition = "2024" [lib] name = "codex_ansi_escape" @@ -10,7 +10,7 @@ path = "src/lib.rs" [dependencies] ansi-to-tui = "7.0.0" ratatui = { version = "0.29.0", features = [ - "unstable-widget-ref", "unstable-rendered-line-info", + "unstable-widget-ref", ] } tracing = { version = "0.1.41", features = ["log"] } diff --git a/codex-rs/apply-patch/Cargo.toml b/codex-rs/apply-patch/Cargo.toml index 7848e6e47f..622f53ce71 100644 --- a/codex-rs/apply-patch/Cargo.toml +++ b/codex-rs/apply-patch/Cargo.toml @@ -1,7 +1,7 @@ [package] +edition = "2024" name = "codex-apply-patch" version = { workspace = true } -edition = "2024" [lib] name = "codex_apply_patch" @@ -14,7 +14,7 @@ workspace = true anyhow = "1" similar = "2.7.0" thiserror = "2.0.12" -tree-sitter = "0.25.3" +tree-sitter = "0.25.8" tree-sitter-bash = "0.25.0" [dev-dependencies] diff --git a/codex-rs/apply-patch/src/lib.rs b/codex-rs/apply-patch/src/lib.rs index c81241d0da..8d42be9a92 100644 --- a/codex-rs/apply-patch/src/lib.rs +++ b/codex-rs/apply-patch/src/lib.rs @@ -58,16 +58,24 @@ impl PartialEq for IoError { #[derive(Debug, PartialEq)] pub enum MaybeApplyPatch { - Body(Vec), + Body(ApplyPatchArgs), ShellParseError(ExtractHeredocError), PatchParseError(ParseError), NotApplyPatch, } +/// Both the raw PATCH argument to `apply_patch` as well as the PATCH argument +/// parsed into hunks. +#[derive(Debug, PartialEq)] +pub struct ApplyPatchArgs { + pub patch: String, + pub hunks: Vec, +} + pub fn maybe_parse_apply_patch(argv: &[String]) -> MaybeApplyPatch { match argv { [cmd, body] if cmd == "apply_patch" => match parse_patch(body) { - Ok(hunks) => MaybeApplyPatch::Body(hunks), + Ok(source) => MaybeApplyPatch::Body(source), Err(e) => MaybeApplyPatch::PatchParseError(e), }, [bash, flag, script] @@ -77,7 +85,7 @@ pub fn maybe_parse_apply_patch(argv: &[String]) -> MaybeApplyPatch { { match extract_heredoc_body_from_apply_patch_command(script) { Ok(body) => match parse_patch(&body) { - Ok(hunks) => MaybeApplyPatch::Body(hunks), + Ok(source) => MaybeApplyPatch::Body(source), Err(e) => MaybeApplyPatch::PatchParseError(e), }, Err(e) => MaybeApplyPatch::ShellParseError(e), @@ -116,11 +124,19 @@ pub enum MaybeApplyPatchVerified { NotApplyPatch, } -#[derive(Debug, PartialEq)] /// ApplyPatchAction is the result of parsing an `apply_patch` command. By /// construction, all paths should be absolute paths. +#[derive(Debug, PartialEq)] pub struct ApplyPatchAction { changes: HashMap, + + /// The raw patch argument that can be used with `apply_patch` as an exec + /// call. i.e., if the original arg was parsed in "lenient" mode with a + /// heredoc, this should be the value without the heredoc wrapper. + pub patch: String, + + /// The working directory that was used to resolve relative paths in the patch. + pub cwd: PathBuf, } impl ApplyPatchAction { @@ -140,8 +156,28 @@ impl ApplyPatchAction { panic!("path must be absolute"); } + #[allow(clippy::expect_used)] + let filename = path + .file_name() + .expect("path should not be empty") + .to_string_lossy(); + let patch = format!( + r#"*** Begin Patch +*** Update File: {filename} +@@ ++ {content} +*** End Patch"#, + ); let changes = HashMap::from([(path.to_path_buf(), ApplyPatchFileChange::Add { content })]); - Self { changes } + #[allow(clippy::expect_used)] + Self { + changes, + cwd: path + .parent() + .expect("path should have parent") + .to_path_buf(), + patch, + } } } @@ -149,7 +185,7 @@ impl ApplyPatchAction { /// patch. pub fn maybe_parse_apply_patch_verified(argv: &[String], cwd: &Path) -> MaybeApplyPatchVerified { match maybe_parse_apply_patch(argv) { - MaybeApplyPatch::Body(hunks) => { + MaybeApplyPatch::Body(ApplyPatchArgs { patch, hunks }) => { let mut changes = HashMap::new(); for hunk in hunks { let path = hunk.resolve_path(cwd); @@ -183,7 +219,11 @@ pub fn maybe_parse_apply_patch_verified(argv: &[String], cwd: &Path) -> MaybeApp } } } - MaybeApplyPatchVerified::Body(ApplyPatchAction { changes }) + MaybeApplyPatchVerified::Body(ApplyPatchAction { + changes, + patch, + cwd: cwd.to_path_buf(), + }) } MaybeApplyPatch::ShellParseError(e) => MaybeApplyPatchVerified::ShellParseError(e), MaybeApplyPatch::PatchParseError(e) => MaybeApplyPatchVerified::CorrectnessError(e.into()), @@ -264,7 +304,7 @@ pub fn apply_patch( stderr: &mut impl std::io::Write, ) -> Result<(), ApplyPatchError> { let hunks = match parse_patch(patch) { - Ok(hunks) => hunks, + Ok(source) => source.hunks, Err(e) => { match &e { InvalidPatchError(message) => { @@ -652,7 +692,7 @@ mod tests { ]); match maybe_parse_apply_patch(&args) { - MaybeApplyPatch::Body(hunks) => { + MaybeApplyPatch::Body(ApplyPatchArgs { hunks, patch: _ }) => { assert_eq!( hunks, vec![Hunk::AddFile { @@ -679,7 +719,7 @@ PATCH"#, ]); match maybe_parse_apply_patch(&args) { - MaybeApplyPatch::Body(hunks) => { + MaybeApplyPatch::Body(ApplyPatchArgs { hunks, patch: _ }) => { assert_eq!( hunks, vec![Hunk::AddFile { @@ -954,7 +994,7 @@ PATCH"#, )); let patch = parse_patch(&patch).unwrap(); - let update_file_chunks = match patch.as_slice() { + let update_file_chunks = match patch.hunks.as_slice() { [Hunk::UpdateFile { chunks, .. }] => chunks, _ => panic!("Expected a single UpdateFile hunk"), }; @@ -992,7 +1032,7 @@ PATCH"#, )); let patch = parse_patch(&patch).unwrap(); - let chunks = match patch.as_slice() { + let chunks = match patch.hunks.as_slice() { [Hunk::UpdateFile { chunks, .. }] => chunks, _ => panic!("Expected a single UpdateFile hunk"), }; @@ -1029,7 +1069,7 @@ PATCH"#, )); let patch = parse_patch(&patch).unwrap(); - let chunks = match patch.as_slice() { + let chunks = match patch.hunks.as_slice() { [Hunk::UpdateFile { chunks, .. }] => chunks, _ => panic!("Expected a single UpdateFile hunk"), }; @@ -1064,7 +1104,7 @@ PATCH"#, )); let patch = parse_patch(&patch).unwrap(); - let chunks = match patch.as_slice() { + let chunks = match patch.hunks.as_slice() { [Hunk::UpdateFile { chunks, .. }] => chunks, _ => panic!("Expected a single UpdateFile hunk"), }; @@ -1110,7 +1150,7 @@ PATCH"#, // Extract chunks then build the unified diff. let parsed = parse_patch(&patch).unwrap(); - let chunks = match parsed.as_slice() { + let chunks = match parsed.hunks.as_slice() { [Hunk::UpdateFile { chunks, .. }] => chunks, _ => panic!("Expected a single UpdateFile hunk"), }; @@ -1193,6 +1233,8 @@ g new_content: "updated session directory content\n".to_string(), }, )]), + patch: argv[1].clone(), + cwd: session_dir.path().to_path_buf(), }) ); } diff --git a/codex-rs/apply-patch/src/parser.rs b/codex-rs/apply-patch/src/parser.rs index d07691a49d..44c5b14619 100644 --- a/codex-rs/apply-patch/src/parser.rs +++ b/codex-rs/apply-patch/src/parser.rs @@ -22,6 +22,7 @@ //! //! The parser below is a little more lenient than the explicit spec and allows for //! leading/trailing whitespace around patch markers. +use crate::ApplyPatchArgs; use std::path::Path; use std::path::PathBuf; @@ -102,7 +103,7 @@ pub struct UpdateFileChunk { pub is_end_of_file: bool, } -pub fn parse_patch(patch: &str) -> Result, ParseError> { +pub fn parse_patch(patch: &str) -> Result { let mode = if PARSE_IN_STRICT_MODE { ParseMode::Strict } else { @@ -150,7 +151,7 @@ enum ParseMode { Lenient, } -fn parse_patch_text(patch: &str, mode: ParseMode) -> Result, ParseError> { +fn parse_patch_text(patch: &str, mode: ParseMode) -> Result { let lines: Vec<&str> = patch.trim().lines().collect(); let lines: &[&str] = match check_patch_boundaries_strict(&lines) { Ok(()) => &lines, @@ -173,7 +174,8 @@ fn parse_patch_text(patch: &str, mode: ParseMode) -> Result, ParseErro line_number += hunk_lines; remaining_lines = &remaining_lines[hunk_lines..] } - Ok(hunks) + let patch = lines.join("\n"); + Ok(ApplyPatchArgs { hunks, patch }) } /// Checks the start and end lines of the patch text for `apply_patch`, @@ -425,6 +427,7 @@ fn parse_update_file_chunk( } #[test] +#[allow(clippy::unwrap_used)] fn test_parse_patch() { assert_eq!( parse_patch_text("bad", ParseMode::Strict), @@ -455,8 +458,10 @@ fn test_parse_patch() { "*** Begin Patch\n\ *** End Patch", ParseMode::Strict - ), - Ok(Vec::new()) + ) + .unwrap() + .hunks, + Vec::new() ); assert_eq!( parse_patch_text( @@ -472,8 +477,10 @@ fn test_parse_patch() { + return 123\n\ *** End Patch", ParseMode::Strict - ), - Ok(vec![ + ) + .unwrap() + .hunks, + vec![ AddFile { path: PathBuf::from("path/add.py"), contents: "abc\ndef\n".to_string() @@ -491,7 +498,7 @@ fn test_parse_patch() { is_end_of_file: false }] } - ]) + ] ); // Update hunk followed by another hunk (Add File). assert_eq!( @@ -504,8 +511,10 @@ fn test_parse_patch() { +content\n\ *** End Patch", ParseMode::Strict - ), - Ok(vec![ + ) + .unwrap() + .hunks, + vec![ UpdateFile { path: PathBuf::from("file.py"), move_path: None, @@ -520,7 +529,7 @@ fn test_parse_patch() { path: PathBuf::from("other.py"), contents: "content\n".to_string() } - ]) + ] ); // Update hunk without an explicit @@ header for the first chunk should parse. @@ -533,8 +542,10 @@ fn test_parse_patch() { +bar *** End Patch"#, ParseMode::Strict - ), - Ok(vec![UpdateFile { + ) + .unwrap() + .hunks, + vec![UpdateFile { path: PathBuf::from("file2.py"), move_path: None, chunks: vec![UpdateFileChunk { @@ -543,7 +554,7 @@ fn test_parse_patch() { new_lines: vec!["import foo".to_string(), "bar".to_string()], is_end_of_file: false, }], - }]) + }] ); } @@ -574,7 +585,10 @@ fn test_parse_patch_lenient() { ); assert_eq!( parse_patch_text(&patch_text_in_heredoc, ParseMode::Lenient), - Ok(expected_patch.clone()) + Ok(ApplyPatchArgs { + hunks: expected_patch.clone(), + patch: patch_text.to_string() + }) ); let patch_text_in_single_quoted_heredoc = format!("<<'EOF'\n{patch_text}\nEOF\n"); @@ -584,7 +598,10 @@ fn test_parse_patch_lenient() { ); assert_eq!( parse_patch_text(&patch_text_in_single_quoted_heredoc, ParseMode::Lenient), - Ok(expected_patch.clone()) + Ok(ApplyPatchArgs { + hunks: expected_patch.clone(), + patch: patch_text.to_string() + }) ); let patch_text_in_double_quoted_heredoc = format!("<<\"EOF\"\n{patch_text}\nEOF\n"); @@ -594,7 +611,10 @@ fn test_parse_patch_lenient() { ); assert_eq!( parse_patch_text(&patch_text_in_double_quoted_heredoc, ParseMode::Lenient), - Ok(expected_patch.clone()) + Ok(ApplyPatchArgs { + hunks: expected_patch.clone(), + patch: patch_text.to_string() + }) ); let patch_text_in_mismatched_quotes_heredoc = format!("<<\"EOF'\n{patch_text}\nEOF\n"); diff --git a/codex-rs/arg0/Cargo.toml b/codex-rs/arg0/Cargo.toml new file mode 100644 index 0000000000..d668ffeff9 --- /dev/null +++ b/codex-rs/arg0/Cargo.toml @@ -0,0 +1,19 @@ +[package] +edition = "2024" +name = "codex-arg0" +version = { workspace = true } + +[lib] +name = "codex_arg0" +path = "src/lib.rs" + +[lints] +workspace = true + +[dependencies] +anyhow = "1" +codex-apply-patch = { path = "../apply-patch" } +codex-core = { path = "../core" } +codex-linux-sandbox = { path = "../linux-sandbox" } +dotenvy = "0.15.7" +tokio = { version = "1", features = ["rt-multi-thread"] } diff --git a/codex-rs/arg0/src/lib.rs b/codex-rs/arg0/src/lib.rs new file mode 100644 index 0000000000..c097ebc11c --- /dev/null +++ b/codex-rs/arg0/src/lib.rs @@ -0,0 +1,91 @@ +use std::future::Future; +use std::path::Path; +use std::path::PathBuf; + +use codex_core::CODEX_APPLY_PATCH_ARG1; + +/// While we want to deploy the Codex CLI as a single executable for simplicity, +/// we also want to expose some of its functionality as distinct CLIs, so we use +/// the "arg0 trick" to determine which CLI to dispatch. This effectively allows +/// us to simulate deploying multiple executables as a single binary on Mac and +/// Linux (but not Windows). +/// +/// When the current executable is invoked through the hard-link or alias named +/// `codex-linux-sandbox` we *directly* execute +/// [`codex_linux_sandbox::run_main`] (which never returns). Otherwise we: +/// +/// 1. Use [`dotenvy::from_path`] and [`dotenvy::dotenv`] to modify the +/// environment before creating any threads. +/// 2. Construct a Tokio multi-thread runtime. +/// 3. Derive the path to the current executable (so children can re-invoke the +/// sandbox) when running on Linux. +/// 4. Execute the provided async `main_fn` inside that runtime, forwarding any +/// error. Note that `main_fn` receives `codex_linux_sandbox_exe: +/// Option`, as an argument, which is generally needed as part of +/// constructing [`codex_core::config::Config`]. +/// +/// This function should be used to wrap any `main()` function in binary crates +/// in this workspace that depends on these helper CLIs. +pub fn arg0_dispatch_or_else(main_fn: F) -> anyhow::Result<()> +where + F: FnOnce(Option) -> Fut, + Fut: Future>, +{ + // Determine if we were invoked via the special alias. + let mut args = std::env::args_os(); + let argv0 = args.next().unwrap_or_default(); + let exe_name = Path::new(&argv0) + .file_name() + .and_then(|s| s.to_str()) + .unwrap_or(""); + + if exe_name == "codex-linux-sandbox" { + // Safety: [`run_main`] never returns. + codex_linux_sandbox::run_main(); + } + + let argv1 = args.next().unwrap_or_default(); + if argv1 == CODEX_APPLY_PATCH_ARG1 { + let patch_arg = args.next().and_then(|s| s.to_str().map(|s| s.to_owned())); + let exit_code = match patch_arg { + Some(patch_arg) => { + let mut stdout = std::io::stdout(); + let mut stderr = std::io::stderr(); + match codex_apply_patch::apply_patch(&patch_arg, &mut stdout, &mut stderr) { + Ok(()) => 0, + Err(_) => 1, + } + } + None => { + eprintln!("Error: {CODEX_APPLY_PATCH_ARG1} requires a UTF-8 PATCH argument."); + 1 + } + }; + std::process::exit(exit_code); + } + + // This modifies the environment, which is not thread-safe, so do this + // before creating any threads/the Tokio runtime. + load_dotenv(); + + // Regular invocation – create a Tokio runtime and execute the provided + // async entry-point. + let runtime = tokio::runtime::Runtime::new()?; + runtime.block_on(async move { + let codex_linux_sandbox_exe: Option = if cfg!(target_os = "linux") { + std::env::current_exe().ok() + } else { + None + }; + + main_fn(codex_linux_sandbox_exe).await + }) +} + +/// Load env vars from ~/.codex/.env and `$(pwd)/.env`. +fn load_dotenv() { + if let Ok(codex_home) = codex_core::config::find_codex_home() { + dotenvy::from_path(codex_home.join(".env")).ok(); + } + dotenvy::dotenv().ok(); +} diff --git a/codex-rs/chatgpt/Cargo.toml b/codex-rs/chatgpt/Cargo.toml index e07543f4e8..903dc14b51 100644 --- a/codex-rs/chatgpt/Cargo.toml +++ b/codex-rs/chatgpt/Cargo.toml @@ -1,7 +1,7 @@ [package] +edition = "2024" name = "codex-chatgpt" version = { workspace = true } -edition = "2024" [lints] workspace = true @@ -9,12 +9,12 @@ workspace = true [dependencies] anyhow = "1" clap = { version = "4", features = ["derive"] } -serde = { version = "1", features = ["derive"] } -serde_json = "1" codex-common = { path = "../common", features = ["cli"] } codex-core = { path = "../core" } codex-login = { path = "../login" } reqwest = { version = "0.12", features = ["json", "stream"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" tokio = { version = "1", features = ["full"] } [dev-dependencies] diff --git a/codex-rs/chatgpt/src/apply_command.rs b/codex-rs/chatgpt/src/apply_command.rs index 4209d958e1..52ab205a0c 100644 --- a/codex-rs/chatgpt/src/apply_command.rs +++ b/codex-rs/chatgpt/src/apply_command.rs @@ -1,3 +1,5 @@ +use std::path::PathBuf; + use clap::Parser; use codex_common::CliConfigOverrides; use codex_core::config::Config; @@ -17,7 +19,10 @@ pub struct ApplyCommand { #[clap(flatten)] pub config_overrides: CliConfigOverrides, } -pub async fn run_apply_command(apply_cli: ApplyCommand) -> anyhow::Result<()> { +pub async fn run_apply_command( + apply_cli: ApplyCommand, + cwd: Option, +) -> anyhow::Result<()> { let config = Config::load_with_cli_overrides( apply_cli .config_overrides @@ -29,10 +34,13 @@ pub async fn run_apply_command(apply_cli: ApplyCommand) -> anyhow::Result<()> { init_chatgpt_token_from_auth(&config.codex_home).await?; let task_response = get_task(&config, apply_cli.task_id).await?; - apply_diff_from_task(task_response).await + apply_diff_from_task(task_response, cwd).await } -pub async fn apply_diff_from_task(task_response: GetTaskResponse) -> anyhow::Result<()> { +pub async fn apply_diff_from_task( + task_response: GetTaskResponse, + cwd: Option, +) -> anyhow::Result<()> { let diff_turn = match task_response.current_diff_task_turn { Some(turn) => turn, None => anyhow::bail!("No diff turn found"), @@ -42,13 +50,17 @@ pub async fn apply_diff_from_task(task_response: GetTaskResponse) -> anyhow::Res _ => None, }); match output_diff { - Some(output_diff) => apply_diff(&output_diff.diff).await, + Some(output_diff) => apply_diff(&output_diff.diff, cwd).await, None => anyhow::bail!("No PR output item found"), } } -async fn apply_diff(diff: &str) -> anyhow::Result<()> { - let toplevel_output = tokio::process::Command::new("git") +async fn apply_diff(diff: &str, cwd: Option) -> anyhow::Result<()> { + let mut cmd = tokio::process::Command::new("git"); + if let Some(cwd) = cwd { + cmd.current_dir(cwd); + } + let toplevel_output = cmd .args(vec!["rev-parse", "--show-toplevel"]) .output() .await?; diff --git a/codex-rs/chatgpt/src/chatgpt_client.rs b/codex-rs/chatgpt/src/chatgpt_client.rs index 4c4cb4c4c3..907783bb81 100644 --- a/codex-rs/chatgpt/src/chatgpt_client.rs +++ b/codex-rs/chatgpt/src/chatgpt_client.rs @@ -21,10 +21,14 @@ pub(crate) async fn chatgpt_get_request( let token = get_chatgpt_token_data().ok_or_else(|| anyhow::anyhow!("ChatGPT token not available"))?; + let account_id = token.account_id.ok_or_else(|| { + anyhow::anyhow!("ChatGPT account ID not available, please re-run `codex login`") + }); + let response = client .get(&url) .bearer_auth(&token.access_token) - .header("chatgpt-account-id", &token.account_id) + .header("chatgpt-account-id", account_id?) .header("Content-Type", "application/json") .header("User-Agent", "codex-cli") .send() diff --git a/codex-rs/chatgpt/src/chatgpt_token.rs b/codex-rs/chatgpt/src/chatgpt_token.rs index adf9a6ba96..55ebc22a08 100644 --- a/codex-rs/chatgpt/src/chatgpt_token.rs +++ b/codex-rs/chatgpt/src/chatgpt_token.rs @@ -18,7 +18,10 @@ pub fn set_chatgpt_token_data(value: TokenData) { /// Initialize the ChatGPT token from auth.json file pub async fn init_chatgpt_token_from_auth(codex_home: &Path) -> std::io::Result<()> { - let auth_json = codex_login::try_read_auth_json(codex_home).await?; - set_chatgpt_token_data(auth_json.tokens.clone()); + let auth = codex_login::load_auth(codex_home)?; + if let Some(auth) = auth { + let token_data = auth.get_token_data().await?; + set_chatgpt_token_data(token_data); + } Ok(()) } diff --git a/codex-rs/chatgpt/tests/apply_command_e2e.rs b/codex-rs/chatgpt/tests/apply_command_e2e.rs index e395e4f155..45c33bedb4 100644 --- a/codex-rs/chatgpt/tests/apply_command_e2e.rs +++ b/codex-rs/chatgpt/tests/apply_command_e2e.rs @@ -78,17 +78,7 @@ async fn test_apply_command_creates_fibonacci_file() { .await .expect("Failed to load fixture"); - let original_dir = std::env::current_dir().expect("Failed to get current dir"); - std::env::set_current_dir(repo_path).expect("Failed to change directory"); - struct DirGuard(std::path::PathBuf); - impl Drop for DirGuard { - fn drop(&mut self) { - let _ = std::env::set_current_dir(&self.0); - } - } - let _guard = DirGuard(original_dir); - - apply_diff_from_task(task_response) + apply_diff_from_task(task_response, Some(repo_path.to_path_buf())) .await .expect("Failed to apply diff from task"); @@ -173,7 +163,7 @@ console.log(fib(10)); .await .expect("Failed to load fixture"); - let apply_result = apply_diff_from_task(task_response).await; + let apply_result = apply_diff_from_task(task_response, Some(repo_path.to_path_buf())).await; assert!( apply_result.is_err(), diff --git a/codex-rs/cli/Cargo.toml b/codex-rs/cli/Cargo.toml index 943788157b..0f370691cf 100644 --- a/codex-rs/cli/Cargo.toml +++ b/codex-rs/cli/Cargo.toml @@ -1,7 +1,7 @@ [package] +edition = "2024" name = "codex-cli" version = { workspace = true } -edition = "2024" [[bin]] name = "codex" @@ -18,12 +18,12 @@ workspace = true anyhow = "1" clap = { version = "4", features = ["derive"] } clap_complete = "4" +codex-arg0 = { path = "../arg0" } codex-chatgpt = { path = "../chatgpt" } -codex-core = { path = "../core" } codex-common = { path = "../common", features = ["cli"] } +codex-core = { path = "../core" } codex-exec = { path = "../exec" } codex-login = { path = "../login" } -codex-linux-sandbox = { path = "../linux-sandbox" } codex-mcp-server = { path = "../mcp-server" } codex-tui = { path = "../tui" } serde_json = "1" diff --git a/codex-rs/cli/src/login.rs b/codex-rs/cli/src/login.rs index af3fb667f6..390c310030 100644 --- a/codex-rs/cli/src/login.rs +++ b/codex-rs/cli/src/login.rs @@ -1,25 +1,12 @@ use codex_common::CliConfigOverrides; use codex_core::config::Config; use codex_core::config::ConfigOverrides; +use codex_login::AuthMode; +use codex_login::load_auth; use codex_login::login_with_chatgpt; pub async fn run_login_with_chatgpt(cli_config_overrides: CliConfigOverrides) -> ! { - let cli_overrides = match cli_config_overrides.parse_overrides() { - Ok(v) => v, - Err(e) => { - eprintln!("Error parsing -c overrides: {e}"); - std::process::exit(1); - } - }; - - let config_overrides = ConfigOverrides::default(); - let config = match Config::load_with_cli_overrides(cli_overrides, config_overrides) { - Ok(config) => config, - Err(e) => { - eprintln!("Error loading configuration: {e}"); - std::process::exit(1); - } - }; + let config = load_config_or_exit(cli_config_overrides); let capture_output = false; match login_with_chatgpt(&config.codex_home, capture_output).await { @@ -33,3 +20,77 @@ pub async fn run_login_with_chatgpt(cli_config_overrides: CliConfigOverrides) -> } } } + +pub async fn run_login_status(cli_config_overrides: CliConfigOverrides) -> ! { + let config = load_config_or_exit(cli_config_overrides); + + match load_auth(&config.codex_home) { + Ok(Some(auth)) => match auth.mode { + AuthMode::ApiKey => { + if let Some(api_key) = auth.api_key.as_deref() { + eprintln!("Logged in using an API key - {}", safe_format_key(api_key)); + } else { + eprintln!("Logged in using an API key"); + } + std::process::exit(0); + } + AuthMode::ChatGPT => { + eprintln!("Logged in using ChatGPT"); + std::process::exit(0); + } + }, + Ok(None) => { + eprintln!("Not logged in"); + std::process::exit(1); + } + Err(e) => { + eprintln!("Error checking login status: {e}"); + std::process::exit(1); + } + } +} + +fn load_config_or_exit(cli_config_overrides: CliConfigOverrides) -> Config { + let cli_overrides = match cli_config_overrides.parse_overrides() { + Ok(v) => v, + Err(e) => { + eprintln!("Error parsing -c overrides: {e}"); + std::process::exit(1); + } + }; + + let config_overrides = ConfigOverrides::default(); + match Config::load_with_cli_overrides(cli_overrides, config_overrides) { + Ok(config) => config, + Err(e) => { + eprintln!("Error loading configuration: {e}"); + std::process::exit(1); + } + } +} + +fn safe_format_key(key: &str) -> String { + if key.len() <= 13 { + return "***".to_string(); + } + let prefix = &key[..8]; + let suffix = &key[key.len() - 5..]; + format!("{prefix}***{suffix}") +} + +#[cfg(test)] +mod tests { + use super::safe_format_key; + + #[test] + fn formats_long_key() { + let key = "sk-proj-1234567890ABCDE"; + assert_eq!(safe_format_key(key), "sk-proj-***ABCDE"); + } + + #[test] + fn short_key_returns_stars() { + let key = "sk-proj-12345"; + assert_eq!(safe_format_key(key), "***"); + } +} diff --git a/codex-rs/cli/src/main.rs b/codex-rs/cli/src/main.rs index 7e23782d75..c5fd69f9cd 100644 --- a/codex-rs/cli/src/main.rs +++ b/codex-rs/cli/src/main.rs @@ -2,10 +2,12 @@ use clap::CommandFactory; use clap::Parser; use clap_complete::Shell; use clap_complete::generate; +use codex_arg0::arg0_dispatch_or_else; use codex_chatgpt::apply_command::ApplyCommand; use codex_chatgpt::apply_command::run_apply_command; use codex_cli::LandlockCommand; use codex_cli::SeatbeltCommand; +use codex_cli::login::run_login_status; use codex_cli::login::run_login_with_chatgpt; use codex_cli::proto; use codex_common::CliConfigOverrides; @@ -42,7 +44,7 @@ enum Subcommand { #[clap(visible_alias = "e")] Exec(ExecCli), - /// Login with ChatGPT. + /// Manage login. Login(LoginCommand), /// Experimental: run Codex as an MCP server. @@ -89,10 +91,19 @@ enum DebugCommand { struct LoginCommand { #[clap(skip)] config_overrides: CliConfigOverrides, + + #[command(subcommand)] + action: Option, +} + +#[derive(Debug, clap::Subcommand)] +enum LoginSubcommand { + /// Show login status. + Status, } fn main() -> anyhow::Result<()> { - codex_linux_sandbox::run_with_sandbox(|codex_linux_sandbox_exe| async move { + arg0_dispatch_or_else(|codex_linux_sandbox_exe| async move { cli_main(codex_linux_sandbox_exe).await?; Ok(()) }) @@ -105,7 +116,8 @@ async fn cli_main(codex_linux_sandbox_exe: Option) -> anyhow::Result<() None => { let mut tui_cli = cli.interactive; prepend_config_flags(&mut tui_cli.config_overrides, cli.config_overrides); - codex_tui::run_main(tui_cli, codex_linux_sandbox_exe)?; + let usage = codex_tui::run_main(tui_cli, codex_linux_sandbox_exe).await?; + println!("{}", codex_core::protocol::FinalOutput::from(usage)); } Some(Subcommand::Exec(mut exec_cli)) => { prepend_config_flags(&mut exec_cli.config_overrides, cli.config_overrides); @@ -116,7 +128,14 @@ async fn cli_main(codex_linux_sandbox_exe: Option) -> anyhow::Result<() } Some(Subcommand::Login(mut login_cli)) => { prepend_config_flags(&mut login_cli.config_overrides, cli.config_overrides); - run_login_with_chatgpt(login_cli.config_overrides).await; + match login_cli.action { + Some(LoginSubcommand::Status) => { + run_login_status(login_cli.config_overrides).await; + } + None => { + run_login_with_chatgpt(login_cli.config_overrides).await; + } + } } Some(Subcommand::Proto(mut proto_cli)) => { prepend_config_flags(&mut proto_cli.config_overrides, cli.config_overrides); @@ -145,7 +164,7 @@ async fn cli_main(codex_linux_sandbox_exe: Option) -> anyhow::Result<() }, Some(Subcommand::Apply(mut apply_cli)) => { prepend_config_flags(&mut apply_cli.config_overrides, cli.config_overrides); - run_apply_command(apply_cli).await?; + run_apply_command(apply_cli, None).await?; } } diff --git a/codex-rs/cli/src/proto.rs b/codex-rs/cli/src/proto.rs index 148699552a..291e1680f1 100644 --- a/codex-rs/cli/src/proto.rs +++ b/codex-rs/cli/src/proto.rs @@ -4,10 +4,12 @@ use std::sync::Arc; use clap::Parser; use codex_common::CliConfigOverrides; use codex_core::Codex; +use codex_core::CodexSpawnOk; use codex_core::config::Config; use codex_core::config::ConfigOverrides; use codex_core::protocol::Submission; use codex_core::util::notify_on_sigint; +use codex_login::load_auth; use tokio::io::AsyncBufReadExt; use tokio::io::BufReader; use tracing::error; @@ -34,8 +36,9 @@ pub async fn run_main(opts: ProtoCli) -> anyhow::Result<()> { .map_err(anyhow::Error::msg)?; let config = Config::load_with_cli_overrides(overrides_vec, ConfigOverrides::default())?; + let auth = load_auth(&config.codex_home)?; let ctrl_c = notify_on_sigint(); - let (codex, _init_id) = Codex::spawn(config, ctrl_c.clone()).await?; + let CodexSpawnOk { codex, .. } = Codex::spawn(config, auth, ctrl_c.clone()).await?; let codex = Arc::new(codex); // Task that reads JSON lines from stdin and forwards to Submission Queue diff --git a/codex-rs/common/Cargo.toml b/codex-rs/common/Cargo.toml index 3b843181cf..1723098b8a 100644 --- a/codex-rs/common/Cargo.toml +++ b/codex-rs/common/Cargo.toml @@ -1,7 +1,7 @@ [package] +edition = "2024" name = "codex-common" version = { workspace = true } -edition = "2024" [lints] workspace = true @@ -9,11 +9,11 @@ workspace = true [dependencies] clap = { version = "4", features = ["derive", "wrap_help"], optional = true } codex-core = { path = "../core" } -toml = { version = "0.9", optional = true } serde = { version = "1", optional = true } +toml = { version = "0.9", optional = true } [features] # Separate feature so that `clap` is not a mandatory dependency. -cli = ["clap", "toml", "serde"] +cli = ["clap", "serde", "toml"] elapsed = [] sandbox_summary = [] diff --git a/codex-rs/common/src/config_override.rs b/codex-rs/common/src/config_override.rs index 610195d6d1..c9b18edc7c 100644 --- a/codex-rs/common/src/config_override.rs +++ b/codex-rs/common/src/config_override.rs @@ -64,7 +64,11 @@ impl CliConfigOverrides { // `-c model=o3` without the quotes. let value: Value = match parse_toml_value(value_str) { Ok(v) => v, - Err(_) => Value::String(value_str.to_string()), + Err(_) => { + // Strip leading/trailing quotes if present + let trimmed = value_str.trim().trim_matches(|c| c == '"' || c == '\''); + Value::String(trimmed.to_string()) + } }; Ok((key.to_string(), value)) diff --git a/codex-rs/config.md b/codex-rs/config.md index 5399513c85..7f184d6269 100644 --- a/codex-rs/config.md +++ b/codex-rs/config.md @@ -92,6 +92,35 @@ http_headers = { "X-Example-Header" = "example-value" } env_http_headers = { "X-Example-Features": "EXAMPLE_FEATURES" } ``` +### Per-provider network tuning + +The following optional settings control retry behaviour and streaming idle timeouts **per model provider**. They must be specified inside the corresponding `[model_providers.]` block in `config.toml`. (Older releases accepted top‑level keys; those are now ignored.) + +Example: + +```toml +[model_providers.openai] +name = "OpenAI" +base_url = "https://api.openai.com/v1" +env_key = "OPENAI_API_KEY" +# network tuning overrides (all optional; falls back to built‑in defaults) +request_max_retries = 4 # retry failed HTTP requests +stream_max_retries = 10 # retry dropped SSE streams +stream_idle_timeout_ms = 300000 # 5m idle timeout +``` + +#### request_max_retries + +How many times Codex will retry a failed HTTP request to the model provider. Defaults to `4`. + +#### stream_max_retries + +Number of times Codex will attempt to reconnect when a streaming response is interrupted. Defaults to `10`. + +#### stream_idle_timeout_ms + +How long Codex will wait for activity on a streaming response before treating the connection as lost. Defaults to `300_000` (5 minutes). + ## model_provider Identifies which provider to use from the `model_providers` map. Defaults to `"openai"`. You can override the `base_url` for the built-in `openai` provider via the `OPENAI_BASE_URL` environment variable. @@ -444,7 +473,7 @@ Currently, `"vscode"` is the default, though Codex does not verify VS Code is in ## hide_agent_reasoning -Codex intermittently emits "reasoning" events that show the model’s internal "thinking" before it produces a final answer. Some users may find these events distracting, especially in CI logs or minimal terminal output. +Codex intermittently emits "reasoning" events that show the model's internal "thinking" before it produces a final answer. Some users may find these events distracting, especially in CI logs or minimal terminal output. Setting `hide_agent_reasoning` to `true` suppresses these events in **both** the TUI as well as the headless `exec` sub-command: @@ -482,14 +511,5 @@ Options that are specific to the TUI. ```toml [tui] -# This will make it so that Codex does not try to process mouse events, which -# means your Terminal's native drag-to-text to text selection and copy/paste -# should work. The tradeoff is that Codex will not receive any mouse events, so -# it will not be possible to use the mouse to scroll conversation history. -# -# Note that most terminals support holding down a modifier key when using the -# mouse to support text selection. For example, even if Codex mouse capture is -# enabled (i.e., this is set to `false`), you can still hold down alt while -# dragging the mouse to select text. -disable_mouse_capture = true # defaults to `false` +# More to come here ``` diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml index 22636102c9..ecc904cd5e 100644 --- a/codex-rs/core/Cargo.toml +++ b/codex-rs/core/Cargo.toml @@ -1,7 +1,7 @@ [package] +edition = "2024" name = "codex-core" version = { workspace = true } -edition = "2024" [lib] name = "codex_core" @@ -15,20 +15,25 @@ anyhow = "1" async-channel = "2.3.1" base64 = "0.22" bytes = "1.10.1" +chrono = { version = "0.4", features = ["serde"] } codex-apply-patch = { path = "../apply-patch" } +codex-login = { path = "../login" } codex-mcp-client = { path = "../mcp-client" } dirs = "6" env-flags = "0.1.1" eventsource-stream = "0.2.3" fs2 = "0.4.3" futures = "0.3" +libc = "0.2.174" mcp-types = { path = "../mcp-types" } mime_guess = "2.0" rand = "0.9" reqwest = { version = "0.12", features = ["json", "stream"] } serde = { version = "1", features = ["derive"] } serde_json = "1" -strum_macros = "0.27.1" +sha1 = "0.10.6" +shlex = "1.3.0" +strum_macros = "0.27.2" thiserror = "2.0.12" time = { version = "0.3", features = ["formatting", "local-offset", "macros"] } tokio = { version = "1", features = [ @@ -39,13 +44,15 @@ tokio = { version = "1", features = [ "signal", ] } tokio-util = "0.7.14" -toml = "0.9.1" +toml = "0.9.2" tracing = { version = "0.1.41", features = ["log"] } -tree-sitter = "0.25.3" +tree-sitter = "0.25.8" tree-sitter-bash = "0.25.0" uuid = { version = "1", features = ["serde", "v4"] } +whoami = "1.6.0" wildmatch = "2.4.0" + [target.'cfg(target_os = "linux")'.dependencies] landlock = "0.4.1" seccompiler = "0.5.0" @@ -60,8 +67,11 @@ openssl-sys = { version = "*", features = ["vendored"] } [dev-dependencies] assert_cmd = "2" +core_test_support = { path = "tests/common" } maplit = "1.0.2" predicates = "3" pretty_assertions = "1.4.1" tempfile = "3" +tokio-test = "0.4" +walkdir = "2.5.0" wiremock = "0.6" diff --git a/codex-rs/core/README.md b/codex-rs/core/README.md index 9b3e59c8af..9a4c255abe 100644 --- a/codex-rs/core/README.md +++ b/codex-rs/core/README.md @@ -2,9 +2,18 @@ This crate implements the business logic for Codex. It is designed to be used by the various Codex UIs written in Rust. -Though for non-Rust UIs, we are also working to define a _protocol_ for talking to Codex. See: +## Dependencies -- [Specification](../docs/protocol_v1.md) -- [Rust types](./src/protocol.rs) +Note that `codex-core` makes some assumptions about certain helper utilities being available in the environment. Currently, this -You can use the `proto` subcommand using the executable in the [`cli` crate](../cli) to speak the protocol using newline-delimited-JSON over stdin/stdout. +### macOS + +Expects `/usr/bin/sandbox-exec` to be present. + +### Linux + +Expects the binary containing `codex-core` to run the equivalent of `codex debug landlock` when `arg0` is `codex-linux-sandbox`. See the `codex-arg0` crate for details. + +### All Platforms + +Expects the binary containing `codex-core` to simulate the virtual `apply_patch` CLI when `arg1` is `--codex-run-as-apply-patch`. See the `codex-arg0` crate for details. diff --git a/codex-rs/core/src/apply_patch.rs b/codex-rs/core/src/apply_patch.rs new file mode 100644 index 0000000000..f116c790ab --- /dev/null +++ b/codex-rs/core/src/apply_patch.rs @@ -0,0 +1,434 @@ +use crate::codex::Session; +use crate::models::FunctionCallOutputPayload; +use crate::models::ResponseInputItem; +use crate::protocol::Event; +use crate::protocol::EventMsg; +use crate::protocol::FileChange; +use crate::protocol::PatchApplyBeginEvent; +use crate::protocol::PatchApplyEndEvent; +use crate::protocol::ReviewDecision; +use crate::safety::SafetyCheck; +use crate::safety::assess_patch_safety; +use anyhow::Context; +use codex_apply_patch::AffectedPaths; +use codex_apply_patch::ApplyPatchAction; +use codex_apply_patch::ApplyPatchFileChange; +use codex_apply_patch::print_summary; +use std::collections::HashMap; +use std::path::Path; +use std::path::PathBuf; + +pub const CODEX_APPLY_PATCH_ARG1: &str = "--codex-run-as-apply-patch"; + +pub(crate) enum InternalApplyPatchInvocation { + /// The `apply_patch` call was handled programmatically, without any sort + /// of sandbox, because the user explicitly approved it. This is the + /// result to use with the `shell` function call that contained `apply_patch`. + Output(ResponseInputItem), + + /// The `apply_patch` call was auto-approved, which means that, on the + /// surface, it appears to be safe, but it should be run in a sandbox if the + /// user has configured one because a path being written could be a hard + /// link to a file outside the writable folders, so only the sandbox can + /// faithfully prevent the write in that case. + DelegateToExec(ApplyPatchAction), +} + +impl From for InternalApplyPatchInvocation { + fn from(item: ResponseInputItem) -> Self { + InternalApplyPatchInvocation::Output(item) + } +} + +pub(crate) async fn apply_patch( + sess: &Session, + sub_id: &str, + call_id: &str, + action: ApplyPatchAction, +) -> InternalApplyPatchInvocation { + let writable_roots_snapshot = { + #[allow(clippy::unwrap_used)] + let guard = sess.writable_roots.lock().unwrap(); + guard.clone() + }; + + let auto_approved = match assess_patch_safety( + &action, + sess.approval_policy, + &writable_roots_snapshot, + &sess.cwd, + ) { + SafetyCheck::AutoApprove { .. } => { + return InternalApplyPatchInvocation::DelegateToExec(action); + } + SafetyCheck::AskUser => { + // Compute a readable summary of path changes to include in the + // approval request so the user can make an informed decision. + let rx_approve = sess + .request_patch_approval(sub_id.to_owned(), call_id.to_owned(), &action, None, None) + .await; + match rx_approve.await.unwrap_or_default() { + ReviewDecision::Approved | ReviewDecision::ApprovedForSession => false, + ReviewDecision::Denied | ReviewDecision::Abort => { + return ResponseInputItem::FunctionCallOutput { + call_id: call_id.to_owned(), + output: FunctionCallOutputPayload { + content: "patch rejected by user".to_string(), + success: Some(false), + }, + } + .into(); + } + } + } + SafetyCheck::Reject { reason } => { + return ResponseInputItem::FunctionCallOutput { + call_id: call_id.to_owned(), + output: FunctionCallOutputPayload { + content: format!("patch rejected: {reason}"), + success: Some(false), + }, + } + .into(); + } + }; + + // Verify write permissions before touching the filesystem. + let writable_snapshot = { + #[allow(clippy::unwrap_used)] + sess.writable_roots.lock().unwrap().clone() + }; + + if let Some(offending) = first_offending_path(&action, &writable_snapshot, &sess.cwd) { + let root = offending.parent().unwrap_or(&offending).to_path_buf(); + + let reason = Some(format!( + "grant write access to {} for this session", + root.display() + )); + + let rx = sess + .request_patch_approval( + sub_id.to_owned(), + call_id.to_owned(), + &action, + reason.clone(), + Some(root.clone()), + ) + .await; + + if !matches!( + rx.await.unwrap_or_default(), + ReviewDecision::Approved | ReviewDecision::ApprovedForSession + ) { + return ResponseInputItem::FunctionCallOutput { + call_id: call_id.to_owned(), + output: FunctionCallOutputPayload { + content: "patch rejected by user".to_string(), + success: Some(false), + }, + } + .into(); + } + + // user approved, extend writable roots for this session + #[allow(clippy::unwrap_used)] + sess.writable_roots.lock().unwrap().push(root); + } + + let _ = sess + .tx_event + .send(Event { + id: sub_id.to_owned(), + msg: EventMsg::PatchApplyBegin(PatchApplyBeginEvent { + call_id: call_id.to_owned(), + auto_approved, + changes: convert_apply_patch_to_protocol(&action), + }), + }) + .await; + + let mut stdout = Vec::new(); + let mut stderr = Vec::new(); + // Enforce writable roots. If a write is blocked, collect offending root + // and prompt the user to extend permissions. + let mut result = apply_changes_from_apply_patch_and_report(&action, &mut stdout, &mut stderr); + + if let Err(err) = &result { + if err.kind() == std::io::ErrorKind::PermissionDenied { + // Determine first offending path. + let offending_opt = action + .changes() + .iter() + .flat_map(|(path, change)| match change { + ApplyPatchFileChange::Add { .. } => vec![path.as_ref()], + ApplyPatchFileChange::Delete => vec![path.as_ref()], + ApplyPatchFileChange::Update { + move_path: Some(move_path), + .. + } => { + vec![path.as_ref(), move_path.as_ref()] + } + ApplyPatchFileChange::Update { + move_path: None, .. + } => vec![path.as_ref()], + }) + .find_map(|path: &Path| { + // ApplyPatchAction promises to guarantee absolute paths. + if !path.is_absolute() { + panic!("apply_patch invariant failed: path is not absolute: {path:?}"); + } + + let writable = { + #[allow(clippy::unwrap_used)] + let roots = sess.writable_roots.lock().unwrap(); + roots.iter().any(|root| path.starts_with(root)) + }; + if writable { + None + } else { + Some(path.to_path_buf()) + } + }); + + if let Some(offending) = offending_opt { + let root = offending.parent().unwrap_or(&offending).to_path_buf(); + + let reason = Some(format!( + "grant write access to {} for this session", + root.display() + )); + let rx = sess + .request_patch_approval( + sub_id.to_owned(), + call_id.to_owned(), + &action, + reason.clone(), + Some(root.clone()), + ) + .await; + if matches!( + rx.await.unwrap_or_default(), + ReviewDecision::Approved | ReviewDecision::ApprovedForSession + ) { + // Extend writable roots. + #[allow(clippy::unwrap_used)] + sess.writable_roots.lock().unwrap().push(root); + stdout.clear(); + stderr.clear(); + result = apply_changes_from_apply_patch_and_report( + &action, + &mut stdout, + &mut stderr, + ); + } + } + } + } + + // Emit PatchApplyEnd event. + let success_flag = result.is_ok(); + let _ = sess + .tx_event + .send(Event { + id: sub_id.to_owned(), + msg: EventMsg::PatchApplyEnd(PatchApplyEndEvent { + call_id: call_id.to_owned(), + stdout: String::from_utf8_lossy(&stdout).to_string(), + stderr: String::from_utf8_lossy(&stderr).to_string(), + success: success_flag, + }), + }) + .await; + + let item = match result { + Ok(_) => ResponseInputItem::FunctionCallOutput { + call_id: call_id.to_owned(), + output: FunctionCallOutputPayload { + content: String::from_utf8_lossy(&stdout).to_string(), + success: None, + }, + }, + Err(e) => ResponseInputItem::FunctionCallOutput { + call_id: call_id.to_owned(), + output: FunctionCallOutputPayload { + content: format!("error: {e:#}, stderr: {}", String::from_utf8_lossy(&stderr)), + success: Some(false), + }, + }, + }; + InternalApplyPatchInvocation::Output(item) +} + +/// Return the first path in `hunks` that is NOT under any of the +/// `writable_roots` (after normalising). If all paths are acceptable, +/// returns None. +fn first_offending_path( + action: &ApplyPatchAction, + writable_roots: &[PathBuf], + cwd: &Path, +) -> Option { + let changes = action.changes(); + for (path, change) in changes { + let candidate = match change { + ApplyPatchFileChange::Add { .. } => path, + ApplyPatchFileChange::Delete => path, + ApplyPatchFileChange::Update { move_path, .. } => move_path.as_ref().unwrap_or(path), + }; + + let abs = if candidate.is_absolute() { + candidate.clone() + } else { + cwd.join(candidate) + }; + + let mut allowed = false; + for root in writable_roots { + let root_abs = if root.is_absolute() { + root.clone() + } else { + cwd.join(root) + }; + if abs.starts_with(&root_abs) { + allowed = true; + break; + } + } + + if !allowed { + return Some(candidate.clone()); + } + } + None +} + +pub(crate) fn convert_apply_patch_to_protocol( + action: &ApplyPatchAction, +) -> HashMap { + let changes = action.changes(); + let mut result = HashMap::with_capacity(changes.len()); + for (path, change) in changes { + let protocol_change = match change { + ApplyPatchFileChange::Add { content } => FileChange::Add { + content: content.clone(), + }, + ApplyPatchFileChange::Delete => FileChange::Delete, + ApplyPatchFileChange::Update { + unified_diff, + move_path, + new_content: _new_content, + } => FileChange::Update { + unified_diff: unified_diff.clone(), + move_path: move_path.clone(), + }, + }; + result.insert(path.clone(), protocol_change); + } + result +} + +fn apply_changes_from_apply_patch_and_report( + action: &ApplyPatchAction, + stdout: &mut impl std::io::Write, + stderr: &mut impl std::io::Write, +) -> std::io::Result<()> { + match apply_changes_from_apply_patch(action) { + Ok(affected_paths) => { + print_summary(&affected_paths, stdout)?; + } + Err(err) => { + writeln!(stderr, "{err:?}")?; + } + } + + Ok(()) +} + +fn apply_changes_from_apply_patch(action: &ApplyPatchAction) -> anyhow::Result { + let mut added: Vec = Vec::new(); + let mut modified: Vec = Vec::new(); + let mut deleted: Vec = Vec::new(); + + let changes = action.changes(); + for (path, change) in changes { + match change { + ApplyPatchFileChange::Add { content } => { + if let Some(parent) = path.parent() { + if !parent.as_os_str().is_empty() { + std::fs::create_dir_all(parent).with_context(|| { + format!("Failed to create parent directories for {}", path.display()) + })?; + } + } + std::fs::write(path, content) + .with_context(|| format!("Failed to write file {}", path.display()))?; + added.push(path.clone()); + } + ApplyPatchFileChange::Delete => { + std::fs::remove_file(path) + .with_context(|| format!("Failed to delete file {}", path.display()))?; + deleted.push(path.clone()); + } + ApplyPatchFileChange::Update { + unified_diff: _unified_diff, + move_path, + new_content, + } => { + if let Some(move_path) = move_path { + if let Some(parent) = move_path.parent() { + if !parent.as_os_str().is_empty() { + std::fs::create_dir_all(parent).with_context(|| { + format!( + "Failed to create parent directories for {}", + move_path.display() + ) + })?; + } + } + + std::fs::rename(path, move_path) + .with_context(|| format!("Failed to rename file {}", path.display()))?; + std::fs::write(move_path, new_content)?; + modified.push(move_path.clone()); + deleted.push(path.clone()); + } else { + std::fs::write(path, new_content)?; + modified.push(path.clone()); + } + } + } + } + + Ok(AffectedPaths { + added, + modified, + deleted, + }) +} + +pub(crate) fn get_writable_roots(cwd: &Path) -> Vec { + let mut writable_roots = Vec::new(); + if cfg!(target_os = "macos") { + // On macOS, $TMPDIR is private to the user. + writable_roots.push(std::env::temp_dir()); + + // Allow pyenv to update its shims directory. Without this, any tool + // that happens to be managed by `pyenv` will fail with an error like: + // + // pyenv: cannot rehash: $HOME/.pyenv/shims isn't writable + // + // which is emitted every time `pyenv` tries to run `rehash` (for + // example, after installing a new Python package that drops an entry + // point). Although the sandbox is intentionally read‑only by default, + // writing to the user's local `pyenv` directory is safe because it + // is already user‑writable and scoped to the current user account. + if let Ok(home_dir) = std::env::var("HOME") { + let pyenv_dir = PathBuf::from(home_dir).join(".pyenv"); + writable_roots.push(pyenv_dir); + } + } + + writable_roots.push(cwd.to_path_buf()); + + writable_roots +} diff --git a/codex-rs/core/src/bash.rs b/codex-rs/core/src/bash.rs new file mode 100644 index 0000000000..b9cd444356 --- /dev/null +++ b/codex-rs/core/src/bash.rs @@ -0,0 +1,219 @@ +use tree_sitter::Parser; +use tree_sitter::Tree; +use tree_sitter_bash::LANGUAGE as BASH; + +/// Parse the provided bash source using tree-sitter-bash, returning a Tree on +/// success or None if parsing failed. +pub fn try_parse_bash(bash_lc_arg: &str) -> Option { + let lang = BASH.into(); + let mut parser = Parser::new(); + #[expect(clippy::expect_used)] + parser.set_language(&lang).expect("load bash grammar"); + let old_tree: Option<&Tree> = None; + parser.parse(bash_lc_arg, old_tree) +} + +/// Parse a script which may contain multiple simple commands joined only by +/// the safe logical/pipe/sequencing operators: `&&`, `||`, `;`, `|`. +/// +/// Returns `Some(Vec)` if every command is a plain word‑only +/// command and the parse tree does not contain disallowed constructs +/// (parentheses, redirections, substitutions, control flow, etc.). Otherwise +/// returns `None`. +pub fn try_parse_word_only_commands_sequence(tree: &Tree, src: &str) -> Option>> { + if tree.root_node().has_error() { + return None; + } + + // List of allowed (named) node kinds for a "word only commands sequence". + // If we encounter a named node that is not in this list we reject. + const ALLOWED_KINDS: &[&str] = &[ + // top level containers + "program", + "list", + "pipeline", + // commands & words + "command", + "command_name", + "word", + "string", + "string_content", + "raw_string", + "number", + ]; + // Allow only safe punctuation / operator tokens; anything else causes reject. + const ALLOWED_PUNCT_TOKENS: &[&str] = &["&&", "||", ";", "|", "\"", "'"]; + + let root = tree.root_node(); + let mut cursor = root.walk(); + let mut stack = vec![root]; + let mut command_nodes = Vec::new(); + while let Some(node) = stack.pop() { + let kind = node.kind(); + if node.is_named() { + if !ALLOWED_KINDS.contains(&kind) { + return None; + } + if kind == "command" { + command_nodes.push(node); + } + } else { + // Reject any punctuation / operator tokens that are not explicitly allowed. + if kind.chars().any(|c| "&;|".contains(c)) && !ALLOWED_PUNCT_TOKENS.contains(&kind) { + return None; + } + if !(ALLOWED_PUNCT_TOKENS.contains(&kind) || kind.trim().is_empty()) { + // If it's a quote token or operator it's allowed above; we also allow whitespace tokens. + // Any other punctuation like parentheses, braces, redirects, backticks, etc are rejected. + return None; + } + } + for child in node.children(&mut cursor) { + stack.push(child); + } + } + + let mut commands = Vec::new(); + for node in command_nodes { + if let Some(words) = parse_plain_command_from_node(node, src) { + commands.push(words); + } else { + return None; + } + } + Some(commands) +} + +fn parse_plain_command_from_node(cmd: tree_sitter::Node, src: &str) -> Option> { + if cmd.kind() != "command" { + return None; + } + let mut words = Vec::new(); + let mut cursor = cmd.walk(); + for child in cmd.named_children(&mut cursor) { + match child.kind() { + "command_name" => { + let word_node = child.named_child(0)?; + if word_node.kind() != "word" { + return None; + } + words.push(word_node.utf8_text(src.as_bytes()).ok()?.to_owned()); + } + "word" | "number" => { + words.push(child.utf8_text(src.as_bytes()).ok()?.to_owned()); + } + "string" => { + if child.child_count() == 3 + && child.child(0)?.kind() == "\"" + && child.child(1)?.kind() == "string_content" + && child.child(2)?.kind() == "\"" + { + words.push(child.child(1)?.utf8_text(src.as_bytes()).ok()?.to_owned()); + } else { + return None; + } + } + "raw_string" => { + let raw_string = child.utf8_text(src.as_bytes()).ok()?; + let stripped = raw_string + .strip_prefix('\'') + .and_then(|s| s.strip_suffix('\'')); + if let Some(s) = stripped { + words.push(s.to_owned()); + } else { + return None; + } + } + _ => return None, + } + } + Some(words) +} + +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used)] + use super::*; + + fn parse_seq(src: &str) -> Option>> { + let tree = try_parse_bash(src)?; + try_parse_word_only_commands_sequence(&tree, src) + } + + #[test] + fn accepts_single_simple_command() { + let cmds = parse_seq("ls -1").unwrap(); + assert_eq!(cmds, vec![vec!["ls".to_string(), "-1".to_string()]]); + } + + #[test] + fn accepts_multiple_commands_with_allowed_operators() { + let src = "ls && pwd; echo 'hi there' | wc -l"; + let cmds = parse_seq(src).unwrap(); + let expected: Vec> = vec![ + vec!["wc".to_string(), "-l".to_string()], + vec!["echo".to_string(), "hi there".to_string()], + vec!["pwd".to_string()], + vec!["ls".to_string()], + ]; + assert_eq!(cmds, expected); + } + + #[test] + fn extracts_double_and_single_quoted_strings() { + let cmds = parse_seq("echo \"hello world\"").unwrap(); + assert_eq!( + cmds, + vec![vec!["echo".to_string(), "hello world".to_string()]] + ); + + let cmds2 = parse_seq("echo 'hi there'").unwrap(); + assert_eq!( + cmds2, + vec![vec!["echo".to_string(), "hi there".to_string()]] + ); + } + + #[test] + fn accepts_numbers_as_words() { + let cmds = parse_seq("echo 123 456").unwrap(); + assert_eq!( + cmds, + vec![vec![ + "echo".to_string(), + "123".to_string(), + "456".to_string() + ]] + ); + } + + #[test] + fn rejects_parentheses_and_subshells() { + assert!(parse_seq("(ls)").is_none()); + assert!(parse_seq("ls || (pwd && echo hi)").is_none()); + } + + #[test] + fn rejects_redirections_and_unsupported_operators() { + assert!(parse_seq("ls > out.txt").is_none()); + assert!(parse_seq("echo hi & echo bye").is_none()); + } + + #[test] + fn rejects_command_and_process_substitutions_and_expansions() { + assert!(parse_seq("echo $(pwd)").is_none()); + assert!(parse_seq("echo `pwd`").is_none()); + assert!(parse_seq("echo $HOME").is_none()); + assert!(parse_seq("echo \"hi $USER\"").is_none()); + } + + #[test] + fn rejects_variable_assignment_prefix() { + assert!(parse_seq("FOO=bar ls").is_none()); + } + + #[test] + fn rejects_trailing_operator_parse_error() { + assert!(parse_seq("ls &&").is_none()); + } +} diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs index 816fc80f9b..5ede774b1c 100644 --- a/codex-rs/core/src/chat_completions.rs +++ b/codex-rs/core/src/chat_completions.rs @@ -21,8 +21,6 @@ use crate::client_common::ResponseEvent; use crate::client_common::ResponseStream; use crate::error::CodexErr; use crate::error::Result; -use crate::flags::OPENAI_REQUEST_MAX_RETRIES; -use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS; use crate::models::ContentItem; use crate::models::ResponseItem; use crate::openai_tools::create_tools_json_for_chat_completions_api; @@ -32,6 +30,7 @@ use crate::util::backoff; pub(crate) async fn stream_chat_completions( prompt: &Prompt, model: &str, + include_plan_tool: bool, client: &reqwest::Client, provider: &ModelProviderInfo, ) -> Result { @@ -41,9 +40,13 @@ pub(crate) async fn stream_chat_completions( let full_instructions = prompt.get_full_instructions(model); messages.push(json!({"role": "system", "content": full_instructions})); + if let Some(instr) = &prompt.user_instructions { + messages.push(json!({"role": "user", "content": instr})); + } + for item in &prompt.input { match item { - ResponseItem::Message { role, content } => { + ResponseItem::Message { role, content, .. } => { let mut text = String::new(); for c in content { match c { @@ -60,6 +63,7 @@ pub(crate) async fn stream_chat_completions( name, arguments, call_id, + .. } => { messages.push(json!({ "role": "assistant", @@ -106,7 +110,7 @@ pub(crate) async fn stream_chat_completions( } } - let tools_json = create_tools_json_for_chat_completions_api(prompt, model)?; + let tools_json = create_tools_json_for_chat_completions_api(prompt, model, include_plan_tool)?; let payload = json!({ "model": model, "messages": messages, @@ -121,6 +125,7 @@ pub(crate) async fn stream_chat_completions( ); let mut attempt = 0; + let max_retries = provider.request_max_retries(); loop { attempt += 1; @@ -134,9 +139,13 @@ pub(crate) async fn stream_chat_completions( match res { Ok(resp) if resp.status().is_success() => { - let (tx_event, rx_event) = mpsc::channel::>(16); + let (tx_event, rx_event) = mpsc::channel::>(1600); let stream = resp.bytes_stream().map_err(CodexErr::Reqwest); - tokio::spawn(process_chat_sse(stream, tx_event)); + tokio::spawn(process_chat_sse( + stream, + tx_event, + provider.stream_idle_timeout(), + )); return Ok(ResponseStream { rx_event }); } Ok(res) => { @@ -146,7 +155,7 @@ pub(crate) async fn stream_chat_completions( return Err(CodexErr::UnexpectedStatus(status, body)); } - if attempt > *OPENAI_REQUEST_MAX_RETRIES { + if attempt > max_retries { return Err(CodexErr::RetryLimit(status)); } @@ -162,7 +171,7 @@ pub(crate) async fn stream_chat_completions( tokio::time::sleep(delay).await; } Err(e) => { - if attempt > *OPENAI_REQUEST_MAX_RETRIES { + if attempt > max_retries { return Err(e.into()); } let delay = backoff(attempt); @@ -175,14 +184,15 @@ pub(crate) async fn stream_chat_completions( /// Lightweight SSE processor for the Chat Completions streaming format. The /// output is mapped onto Codex's internal [`ResponseEvent`] so that the rest /// of the pipeline can stay agnostic of the underlying wire format. -async fn process_chat_sse(stream: S, tx_event: mpsc::Sender>) -where +async fn process_chat_sse( + stream: S, + tx_event: mpsc::Sender>, + idle_timeout: Duration, +) where S: Stream> + Unpin, { let mut stream = stream.eventsource(); - let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS; - // State to accumulate a function call across streaming chunks. // OpenAI may split the `arguments` string over multiple `delta` events // until the chunk whose `finish_reason` is `tool_calls` is emitted. We @@ -255,6 +265,7 @@ where content: vec![ContentItem::OutputText { text: content.to_string(), }], + id: None, }; let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; @@ -296,6 +307,7 @@ where "tool_calls" if fn_call_state.active => { // Build the FunctionCall response item. let item = ResponseItem::FunctionCall { + id: None, name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()), arguments: fn_call_state.arguments.clone(), call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new), @@ -398,6 +410,7 @@ where }))) => { if !this.cumulative.is_empty() { let aggregated_item = crate::models::ResponseItem::Message { + id: None, role: "assistant".to_string(), content: vec![crate::models::ContentItem::OutputText { text: std::mem::take(&mut this.cumulative), @@ -426,6 +439,12 @@ where // will never appear in a Chat Completions stream. continue; } + Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(_)))) + | Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta(_)))) => { + // Deltas are ignored here since aggregation waits for the + // final OutputItemDone. + continue; + } } } } diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 1b8e4c959d..4e0e62c0f7 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -3,6 +3,8 @@ use std::path::Path; use std::time::Duration; use bytes::Bytes; +use codex_login::AuthMode; +use codex_login::CodexAuth; use eventsource_stream::Eventsource; use futures::prelude::*; use reqwest::StatusCode; @@ -15,6 +17,7 @@ use tokio_util::io::ReaderStream; use tracing::debug; use tracing::trace; use tracing::warn; +use uuid::Uuid; use crate::chat_completions::AggregateStreamExt; use crate::chat_completions::stream_chat_completions; @@ -27,12 +30,12 @@ use crate::config::Config; use crate::config_types::ReasoningEffort as ReasoningEffortConfig; use crate::config_types::ReasoningSummary as ReasoningSummaryConfig; use crate::error::CodexErr; +use crate::error::EnvVarError; use crate::error::Result; use crate::flags::CODEX_RS_SSE_FIXTURE; -use crate::flags::OPENAI_REQUEST_MAX_RETRIES; -use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS; use crate::model_provider_info::ModelProviderInfo; use crate::model_provider_info::WireApi; +use crate::models::ContentItem; use crate::models::ResponseItem; use crate::openai_tools::create_tools_json_for_responses_api; use crate::protocol::TokenUsage; @@ -42,8 +45,10 @@ use std::sync::Arc; #[derive(Clone)] pub struct ModelClient { config: Arc, + auth: Option, client: reqwest::Client, provider: ModelProviderInfo, + session_id: Uuid, effort: ReasoningEffortConfig, summary: ReasoningSummaryConfig, } @@ -51,14 +56,18 @@ pub struct ModelClient { impl ModelClient { pub fn new( config: Arc, + auth: Option, provider: ModelProviderInfo, effort: ReasoningEffortConfig, summary: ReasoningSummaryConfig, + session_id: Uuid, ) -> Self { Self { config, + auth, client: reqwest::Client::new(), provider, + session_id, effort, summary, } @@ -75,6 +84,7 @@ impl ModelClient { let response_stream = stream_chat_completions( prompt, &self.config.model, + self.config.include_plan_tool, &self.client, &self.provider, ) @@ -109,23 +119,65 @@ impl ModelClient { if let Some(path) = &*CODEX_RS_SSE_FIXTURE { // short circuit for tests warn!(path, "Streaming from fixture"); - return stream_from_fixture(path).await; + return stream_from_fixture(path, self.provider.clone()).await; } + let auth = self.auth.as_ref().ok_or_else(|| { + CodexErr::EnvVar(EnvVarError { + var: "OPENAI_API_KEY".to_string(), + instructions: Some("Create an API key (https://platform.openai.com) and export it as an environment variable.".to_string()), + }) + })?; + + let store = prompt.store && auth.mode != AuthMode::ChatGPT; + + let base_url = match self.provider.base_url.clone() { + Some(url) => url, + None => match auth.mode { + AuthMode::ChatGPT => "https://chatgpt.com/backend-api/codex".to_string(), + AuthMode::ApiKey => "https://api.openai.com/v1".to_string(), + }, + }; + + let token = auth.get_token().await?; + let full_instructions = prompt.get_full_instructions(&self.config.model); - let tools_json = create_tools_json_for_responses_api(prompt, &self.config.model)?; + let tools_json = create_tools_json_for_responses_api( + prompt, + &self.config.model, + self.config.include_plan_tool, + )?; let reasoning = create_reasoning_param_for_request(&self.config, self.effort, self.summary); + + // Request encrypted COT if we are not storing responses, + // otherwise reasoning items will be referenced by ID + let include: Vec = if !store && reasoning.is_some() { + vec!["reasoning.encrypted_content".to_string()] + } else { + vec![] + }; + + let mut input_with_instructions = Vec::with_capacity(prompt.input.len() + 1); + if let Some(ui) = &prompt.user_instructions { + input_with_instructions.push(ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { text: ui.clone() }], + }); + } + input_with_instructions.extend(prompt.input.clone()); + let payload = ResponsesApiRequest { model: &self.config.model, instructions: &full_instructions, - input: &prompt.input, + input: &input_with_instructions, tools: &tools_json, tool_choice: "auto", parallel_tool_calls: false, reasoning, - previous_response_id: prompt.prev_id.clone(), - store: prompt.store, + store, stream: true, + include, }; trace!( @@ -135,24 +187,45 @@ impl ModelClient { ); let mut attempt = 0; + let max_retries = self.provider.request_max_retries(); + loop { attempt += 1; let req_builder = self - .provider - .create_request_builder(&self.client)? + .client + .post(format!("{base_url}/responses")) .header("OpenAI-Beta", "responses=experimental") + .header("session_id", self.session_id.to_string()) + .bearer_auth(&token) .header(reqwest::header::ACCEPT, "text/event-stream") .json(&payload); + let req_builder = self.provider.apply_http_headers(req_builder); + let res = req_builder.send().await; + if let Ok(resp) = &res { + trace!( + "Response status: {}, request-id: {}", + resp.status(), + resp.headers() + .get("x-request-id") + .map(|v| v.to_str().unwrap_or_default()) + .unwrap_or_default() + ); + } + match res { Ok(resp) if resp.status().is_success() => { - let (tx_event, rx_event) = mpsc::channel::>(16); + let (tx_event, rx_event) = mpsc::channel::>(1600); // spawn task to process SSE let stream = resp.bytes_stream().map_err(CodexErr::Reqwest); - tokio::spawn(process_sse(stream, tx_event)); + tokio::spawn(process_sse( + stream, + tx_event, + self.provider.stream_idle_timeout(), + )); return Ok(ResponseStream { rx_event }); } @@ -171,7 +244,7 @@ impl ModelClient { return Err(CodexErr::UnexpectedStatus(status, body)); } - if attempt > *OPENAI_REQUEST_MAX_RETRIES { + if attempt > max_retries { return Err(CodexErr::RetryLimit(status)); } @@ -188,7 +261,7 @@ impl ModelClient { tokio::time::sleep(delay).await; } Err(e) => { - if attempt > *OPENAI_REQUEST_MAX_RETRIES { + if attempt > max_retries { return Err(e.into()); } let delay = backoff(attempt); @@ -197,6 +270,10 @@ impl ModelClient { } } } + + pub fn get_provider(&self) -> ModelProviderInfo { + self.provider.clone() + } } #[derive(Debug, Deserialize, Serialize)] @@ -205,6 +282,7 @@ struct SseEvent { kind: String, response: Option, item: Option, + delta: Option, } #[derive(Debug, Deserialize)] @@ -247,14 +325,16 @@ struct ResponseCompletedOutputTokensDetails { reasoning_tokens: u64, } -async fn process_sse(stream: S, tx_event: mpsc::Sender>) -where +async fn process_sse( + stream: S, + tx_event: mpsc::Sender>, + idle_timeout: Duration, +) where S: Stream> + Unpin, { let mut stream = stream.eventsource(); // If the stream stays completely silent for an extended period treat it as disconnected. - let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS; // The response id returned from the "complete" message. let mut response_completed: Option = None; @@ -315,7 +395,7 @@ where // duplicated `output` array embedded in the `response.completed` // payload. That produced two concrete issues: // 1. No real‑time streaming – the user only saw output after the - // entire turn had finished, which broke the “typing” UX and + // entire turn had finished, which broke the "typing" UX and // made long‑running turns look stalled. // 2. Duplicate `function_call_output` items – both the // individual *and* the completed array were forwarded, which @@ -337,11 +417,40 @@ where return; } } + "response.output_text.delta" => { + if let Some(delta) = event.delta { + let event = ResponseEvent::OutputTextDelta(delta); + if tx_event.send(Ok(event)).await.is_err() { + return; + } + } + } + "response.reasoning_summary_text.delta" => { + if let Some(delta) = event.delta { + let event = ResponseEvent::ReasoningSummaryDelta(delta); + if tx_event.send(Ok(event)).await.is_err() { + return; + } + } + } "response.created" => { if event.response.is_some() { let _ = tx_event.send(Ok(ResponseEvent::Created {})).await; } } + "response.failed" => { + if let Some(resp_val) = event.response { + let error = resp_val + .get("error") + .and_then(|v| v.get("message")) + .and_then(|v| v.as_str()) + .unwrap_or("response.failed event received"); + + let _ = tx_event + .send(Err(CodexErr::Stream(error.to_string()))) + .await; + } + } // Final response completed – includes array of output items & id "response.completed" => { if let Some(resp_val) = event.response { @@ -360,10 +469,8 @@ where | "response.function_call_arguments.delta" | "response.in_progress" | "response.output_item.added" - | "response.output_text.delta" | "response.output_text.done" | "response.reasoning_summary_part.added" - | "response.reasoning_summary_text.delta" | "response.reasoning_summary_text.done" => { // Currently, we ignore these events, but we handle them // separately to skip the logging message in the `other` case. @@ -374,8 +481,11 @@ where } /// used in tests to stream from a text SSE file -async fn stream_from_fixture(path: impl AsRef) -> Result { - let (tx_event, rx_event) = mpsc::channel::>(16); +async fn stream_from_fixture( + path: impl AsRef, + provider: ModelProviderInfo, +) -> Result { + let (tx_event, rx_event) = mpsc::channel::>(1600); let f = std::fs::File::open(path.as_ref())?; let lines = std::io::BufReader::new(f).lines(); @@ -388,17 +498,57 @@ async fn stream_from_fixture(path: impl AsRef) -> Result { let rdr = std::io::Cursor::new(content); let stream = ReaderStream::new(rdr).map_err(CodexErr::Io); - tokio::spawn(process_sse(stream, tx_event)); + tokio::spawn(process_sse( + stream, + tx_event, + provider.stream_idle_timeout(), + )); Ok(ResponseStream { rx_event }) } #[cfg(test)] mod tests { #![allow(clippy::expect_used, clippy::unwrap_used)] + use super::*; use serde_json::json; + use tokio::sync::mpsc; + use tokio_test::io::Builder as IoBuilder; + use tokio_util::io::ReaderStream; - async fn run_sse(events: Vec) -> Vec { + // ──────────────────────────── + // Helpers + // ──────────────────────────── + + /// Runs the SSE parser on pre-chunked byte slices and returns every event + /// (including any final `Err` from a stream-closure check). + async fn collect_events( + chunks: &[&[u8]], + provider: ModelProviderInfo, + ) -> Vec> { + let mut builder = IoBuilder::new(); + for chunk in chunks { + builder.read(chunk); + } + + let reader = builder.build(); + let stream = ReaderStream::new(reader).map_err(CodexErr::Io); + let (tx, mut rx) = mpsc::channel::>(16); + tokio::spawn(process_sse(stream, tx, provider.stream_idle_timeout())); + + let mut events = Vec::new(); + while let Some(ev) = rx.recv().await { + events.push(ev); + } + events + } + + /// Builds an in-memory SSE stream from JSON fixtures and returns only the + /// successfully parsed events (panics on internal channel errors). + async fn run_sse( + events: Vec, + provider: ModelProviderInfo, + ) -> Vec { let mut body = String::new(); for e in events { let kind = e @@ -411,9 +561,11 @@ mod tests { body.push_str(&format!("event: {kind}\ndata: {e}\n\n")); } } + let (tx, mut rx) = mpsc::channel::>(8); let stream = ReaderStream::new(std::io::Cursor::new(body)).map_err(CodexErr::Io); - tokio::spawn(process_sse(stream, tx)); + tokio::spawn(process_sse(stream, tx, provider.stream_idle_timeout())); + let mut out = Vec::new(); while let Some(ev) = rx.recv().await { out.push(ev.expect("channel closed")); @@ -421,14 +573,137 @@ mod tests { out } - /// Verifies that the SSE adapter emits the expected [`ResponseEvent`] for - /// a variety of `type` values from the Responses API. The test is written - /// table-driven style to keep additions for new event kinds trivial. - /// - /// Each `Case` supplies an input event, a predicate that must match the - /// *first* `ResponseEvent` produced by the adapter, and the total number - /// of events expected after appending a synthetic `response.completed` - /// marker that terminates the stream. + // ──────────────────────────── + // Tests from `implement-test-for-responses-api-sse-parser` + // ──────────────────────────── + + #[tokio::test] + async fn parses_items_and_completed() { + let item1 = json!({ + "type": "response.output_item.done", + "item": { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "Hello"}] + } + }) + .to_string(); + + let item2 = json!({ + "type": "response.output_item.done", + "item": { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "World"}] + } + }) + .to_string(); + + let completed = json!({ + "type": "response.completed", + "response": { "id": "resp1" } + }) + .to_string(); + + let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n"); + let sse2 = format!("event: response.output_item.done\ndata: {item2}\n\n"); + let sse3 = format!("event: response.completed\ndata: {completed}\n\n"); + + let provider = ModelProviderInfo { + name: "test".to_string(), + base_url: Some("https://test.com".to_string()), + env_key: Some("TEST_API_KEY".to_string()), + env_key_instructions: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: None, + env_http_headers: None, + request_max_retries: Some(0), + stream_max_retries: Some(0), + stream_idle_timeout_ms: Some(1000), + requires_auth: false, + }; + + let events = collect_events( + &[sse1.as_bytes(), sse2.as_bytes(), sse3.as_bytes()], + provider, + ) + .await; + + assert_eq!(events.len(), 3); + + matches!( + &events[0], + Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. })) + if role == "assistant" + ); + + matches!( + &events[1], + Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. })) + if role == "assistant" + ); + + match &events[2] { + Ok(ResponseEvent::Completed { + response_id, + token_usage, + }) => { + assert_eq!(response_id, "resp1"); + assert!(token_usage.is_none()); + } + other => panic!("unexpected third event: {other:?}"), + } + } + + #[tokio::test] + async fn error_when_missing_completed() { + let item1 = json!({ + "type": "response.output_item.done", + "item": { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "Hello"}] + } + }) + .to_string(); + + let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n"); + let provider = ModelProviderInfo { + name: "test".to_string(), + base_url: Some("https://test.com".to_string()), + env_key: Some("TEST_API_KEY".to_string()), + env_key_instructions: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: None, + env_http_headers: None, + request_max_retries: Some(0), + stream_max_retries: Some(0), + stream_idle_timeout_ms: Some(1000), + requires_auth: false, + }; + + let events = collect_events(&[sse1.as_bytes()], provider).await; + + assert_eq!(events.len(), 2); + + matches!(events[0], Ok(ResponseEvent::OutputItemDone(_))); + + match &events[1] { + Err(CodexErr::Stream(msg)) => { + assert_eq!(msg, "stream closed before response.completed") + } + other => panic!("unexpected second event: {other:?}"), + } + } + + // ──────────────────────────── + // Table-driven test from `main` + // ──────────────────────────── + + /// Verifies that the adapter produces the right `ResponseEvent` for a + /// variety of incoming `type` values. #[tokio::test] async fn table_driven_event_kinds() { struct TestCase { @@ -441,11 +716,9 @@ mod tests { fn is_created(ev: &ResponseEvent) -> bool { matches!(ev, ResponseEvent::Created) } - fn is_output(ev: &ResponseEvent) -> bool { matches!(ev, ResponseEvent::OutputItemDone(_)) } - fn is_completed(ev: &ResponseEvent) -> bool { matches!(ev, ResponseEvent::Completed { .. }) } @@ -498,9 +771,29 @@ mod tests { for case in cases { let mut evs = vec![case.event]; evs.push(completed.clone()); - let out = run_sse(evs).await; + + let provider = ModelProviderInfo { + name: "test".to_string(), + base_url: Some("https://test.com".to_string()), + env_key: Some("TEST_API_KEY".to_string()), + env_key_instructions: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: None, + env_http_headers: None, + request_max_retries: Some(0), + stream_max_retries: Some(0), + stream_idle_timeout_ms: Some(1000), + requires_auth: false, + }; + + let out = run_sse(evs, provider).await; assert_eq!(out.len(), case.expected_len, "case {}", case.name); - assert!((case.expect_first)(&out[0]), "case {}", case.name); + assert!( + (case.expect_first)(&out[0]), + "first event mismatch in case {}", + case.name + ); } } } diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs index f9a816a7a9..157f35872a 100644 --- a/codex-rs/core/src/client_common.rs +++ b/codex-rs/core/src/client_common.rs @@ -22,8 +22,6 @@ const BASE_INSTRUCTIONS: &str = include_str!("../prompt.md"); pub struct Prompt { /// Conversation context input items. pub input: Vec, - /// Optional previous response ID (when storage is enabled). - pub prev_id: Option, /// Optional instructions from the user to amend to the built-in agent /// instructions. pub user_instructions: Option, @@ -34,14 +32,18 @@ pub struct Prompt { /// the "fully qualified" tool name (i.e., prefixed with the server name), /// which should be reported to the model in place of Tool::name. pub extra_tools: HashMap, + + /// Optional override for the built-in BASE_INSTRUCTIONS. + pub base_instructions_override: Option, } impl Prompt { pub(crate) fn get_full_instructions(&self, model: &str) -> Cow<'_, str> { - let mut sections: Vec<&str> = vec![BASE_INSTRUCTIONS]; - if let Some(ref user) = self.user_instructions { - sections.push(user); - } + let base = self + .base_instructions_override + .as_deref() + .unwrap_or(BASE_INSTRUCTIONS); + let mut sections: Vec<&str> = vec![base]; if model.starts_with("gpt-4.1") { sections.push(APPLY_PATCH_TOOL_INSTRUCTIONS); } @@ -57,6 +59,8 @@ pub enum ResponseEvent { response_id: String, token_usage: Option, }, + OutputTextDelta(String), + ReasoningSummaryDelta(String), } #[derive(Debug, Serialize)] @@ -124,11 +128,10 @@ pub(crate) struct ResponsesApiRequest<'a> { pub(crate) tool_choice: &'static str, pub(crate) parallel_tool_calls: bool, pub(crate) reasoning: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) previous_response_id: Option, /// true when using the Responses API. pub(crate) store: bool, pub(crate) stream: bool, + pub(crate) include: Vec, } use crate::config::Config; @@ -182,3 +185,19 @@ impl Stream for ResponseStream { self.rx_event.poll_recv(cx) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn get_full_instructions_no_user_content() { + let prompt = Prompt { + user_instructions: Some("custom instruction".to_string()), + ..Default::default() + }; + let expected = format!("{BASE_INSTRUCTIONS}\n{APPLY_PATCH_TOOL_INSTRUCTIONS}"); + let full = prompt.get_full_instructions("gpt-4.1"); + assert_eq!(full, expected); + } +} diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index d39f239b88..710cb832e9 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -11,15 +11,12 @@ use std::sync::Mutex; use std::sync::atomic::AtomicU64; use std::time::Duration; -use anyhow::Context; use async_channel::Receiver; use async_channel::Sender; -use codex_apply_patch::AffectedPaths; use codex_apply_patch::ApplyPatchAction; -use codex_apply_patch::ApplyPatchFileChange; use codex_apply_patch::MaybeApplyPatchVerified; use codex_apply_patch::maybe_parse_apply_patch_verified; -use codex_apply_patch::print_summary; +use codex_login::CodexAuth; use futures::prelude::*; use mcp_types::CallToolResult; use serde::Serialize; @@ -34,7 +31,11 @@ use tracing::trace; use tracing::warn; use uuid::Uuid; -use crate::WireApi; +use crate::apply_patch::CODEX_APPLY_PATCH_ARG1; +use crate::apply_patch::InternalApplyPatchInvocation; +use crate::apply_patch::convert_apply_patch_to_protocol; +use crate::apply_patch::get_writable_roots; +use crate::apply_patch::{self}; use crate::client::ModelClient; use crate::client_common::Prompt; use crate::client_common::ResponseEvent; @@ -49,9 +50,7 @@ use crate::exec::ExecToolCallOutput; use crate::exec::SandboxType; use crate::exec::process_exec_tool_call; use crate::exec_env::create_env; -use crate::flags::OPENAI_STREAM_MAX_RETRIES; use crate::mcp_connection_manager::McpConnectionManager; -use crate::mcp_connection_manager::try_parse_fully_qualified_tool_name; use crate::mcp_tool_call::handle_mcp_tool_call; use crate::models::ContentItem; use crate::models::FunctionCallOutputPayload; @@ -61,9 +60,12 @@ use crate::models::ReasoningItemReasoningSummary; use crate::models::ResponseInputItem; use crate::models::ResponseItem; use crate::models::ShellToolCallParams; +use crate::plan_tool::handle_update_plan; use crate::project_doc::get_user_instructions; +use crate::protocol::AgentMessageDeltaEvent; use crate::protocol::AgentMessageEvent; use crate::protocol::AgentReasoningContentEvent; +use crate::protocol::AgentReasoningDeltaEvent; use crate::protocol::AgentReasoningEvent; use crate::protocol::ApplyPatchApprovalRequestEvent; use crate::protocol::AskForApproval; @@ -74,11 +76,8 @@ use crate::protocol::EventMsg; use crate::protocol::ExecApprovalRequestEvent; use crate::protocol::ExecCommandBeginEvent; use crate::protocol::ExecCommandEndEvent; -use crate::protocol::FileChange; use crate::protocol::InputItem; use crate::protocol::Op; -use crate::protocol::PatchApplyBeginEvent; -use crate::protocol::PatchApplyEndEvent; use crate::protocol::ReviewDecision; use crate::protocol::SandboxPolicy; use crate::protocol::SessionConfiguredEvent; @@ -87,7 +86,8 @@ use crate::protocol::TaskCompleteEvent; use crate::rollout::RolloutRecorder; use crate::safety::SafetyCheck; use crate::safety::assess_command_safety; -use crate::safety::assess_patch_safety; +use crate::safety::assess_safety_for_untrusted_command; +use crate::shell; use crate::user_notification::UserNotification; use crate::util::backoff; @@ -99,30 +99,52 @@ pub struct Codex { rx_event: Receiver, } -impl Codex { - /// Spawn a new [`Codex`] and initialize the session. Returns the instance - /// of `Codex` and the ID of the `SessionInitialized` event that was - /// submitted to start the session. - pub async fn spawn(config: Config, ctrl_c: Arc) -> CodexResult<(Codex, String)> { - let (tx_sub, rx_sub) = async_channel::bounded(64); - let (tx_event, rx_event) = async_channel::bounded(64); +/// Wrapper returned by [`Codex::spawn`] containing the spawned [`Codex`], +/// the submission id for the initial `ConfigureSession` request and the +/// unique session id. +pub struct CodexSpawnOk { + pub codex: Codex, + pub init_id: String, + pub session_id: Uuid, +} + +impl Codex { + /// Spawn a new [`Codex`] and initialize the session. + pub async fn spawn( + config: Config, + auth: Option, + ctrl_c: Arc, + ) -> CodexResult { + // experimental resume path (undocumented) + let resume_path = config.experimental_resume.clone(); + info!("resume_path: {resume_path:?}"); + let (tx_sub, rx_sub) = async_channel::bounded(64); + let (tx_event, rx_event) = async_channel::bounded(1600); + + let user_instructions = get_user_instructions(&config).await; - let instructions = get_user_instructions(&config).await; let configure_session = Op::ConfigureSession { provider: config.model_provider.clone(), model: config.model.clone(), model_reasoning_effort: config.model_reasoning_effort, model_reasoning_summary: config.model_reasoning_summary, - instructions, + user_instructions, + base_instructions: config.base_instructions.clone(), approval_policy: config.approval_policy, sandbox_policy: config.sandbox_policy.clone(), disable_response_storage: config.disable_response_storage, notify: config.notify.clone(), cwd: config.cwd.clone(), + resume_path: resume_path.clone(), }; let config = Arc::new(config); - tokio::spawn(submission_loop(config, rx_sub, tx_event, ctrl_c)); + + // Generate a unique ID for the lifetime of this Codex session. + let session_id = Uuid::new_v4(); + tokio::spawn(submission_loop( + session_id, config, auth, rx_sub, tx_event, ctrl_c, + )); let codex = Codex { next_id: AtomicU64::new(0), tx_sub, @@ -130,7 +152,11 @@ impl Codex { }; let init_id = codex.submit(configure_session).await?; - Ok((codex, init_id)) + Ok(CodexSpawnOk { + codex, + init_id, + session_id, + }) } /// Submit the `op` wrapped in a `Submission` with a unique ID. @@ -169,18 +195,20 @@ impl Codex { /// A session has at most 1 running task at a time, and can be interrupted by user input. pub(crate) struct Session { client: ModelClient, - tx_event: Sender, + pub(crate) tx_event: Sender, ctrl_c: Arc, /// The session's current working directory. All relative paths provided by /// the model as well as sandbox policies are resolved against this path /// instead of `std::env::current_dir()`. - cwd: PathBuf, - instructions: Option, - approval_policy: AskForApproval, + pub(crate) cwd: PathBuf, + base_instructions: Option, + user_instructions: Option, + pub(crate) approval_policy: AskForApproval, sandbox_policy: SandboxPolicy, shell_environment_policy: ShellEnvironmentPolicy, - writable_roots: Mutex>, + pub(crate) writable_roots: Mutex>, + disable_response_storage: bool, /// Manager for external MCP servers/tools. mcp_connection_manager: McpConnectionManager, @@ -194,6 +222,7 @@ pub(crate) struct Session { rollout: Mutex>, state: Mutex, codex_linux_sandbox_exe: Option, + user_shell: shell::Shell, } impl Session { @@ -209,13 +238,9 @@ impl Session { struct State { approved_commands: HashSet>, current_task: Option, - /// Call IDs that have been sent from the Responses API but have not been sent back yet. - /// You CANNOT send a Responses API follow-up message unless you have sent back the output for all pending calls or else it will 400. - pending_call_ids: HashSet, - previous_response_id: Option, pending_approvals: HashMap>, pending_input: Vec, - zdr_transcript: Option, + history: ConversationHistory, } impl Session { @@ -247,6 +272,7 @@ impl Session { pub async fn request_command_approval( &self, sub_id: String, + call_id: String, command: Vec, cwd: PathBuf, reason: Option, @@ -255,6 +281,7 @@ impl Session { let event = Event { id: sub_id.clone(), msg: EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent { + call_id, command, cwd, reason, @@ -271,6 +298,7 @@ impl Session { pub async fn request_patch_approval( &self, sub_id: String, + call_id: String, action: &ApplyPatchAction, reason: Option, grant_root: Option, @@ -279,6 +307,7 @@ impl Session { let event = Event { id: sub_id.clone(), msg: EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { + call_id, changes: convert_apply_patch_to_protocol(action), reason, grant_root, @@ -308,37 +337,42 @@ impl Session { /// transcript, if enabled. async fn record_conversation_items(&self, items: &[ResponseItem]) { debug!("Recording items for conversation: {items:?}"); - self.record_rollout_items(items).await; + self.record_state_snapshot(items).await; - if let Some(transcript) = self.state.lock().unwrap().zdr_transcript.as_mut() { - transcript.record_items(items); - } + self.state.lock().unwrap().history.record_items(items); } - /// Append the given items to the session's rollout transcript (if enabled) - /// and persist them to disk. - async fn record_rollout_items(&self, items: &[ResponseItem]) { - // Clone the recorder outside of the mutex so we don't hold the lock - // across an await point (MutexGuard is not Send). + async fn record_state_snapshot(&self, items: &[ResponseItem]) { + let snapshot = { crate::rollout::SessionStateSnapshot {} }; + let recorder = { let guard = self.rollout.lock().unwrap(); guard.as_ref().cloned() }; if let Some(rec) = recorder { + if let Err(e) = rec.record_state(snapshot).await { + error!("failed to record rollout state: {e:#}"); + } if let Err(e) = rec.record_items(items).await { error!("failed to record rollout items: {e:#}"); } } } - async fn notify_exec_command_begin(&self, sub_id: &str, call_id: &str, params: &ExecParams) { + async fn notify_exec_command_begin( + &self, + sub_id: &str, + call_id: &str, + command_for_display: Vec, + command_cwd: &Path, + ) { let event = Event { id: sub_id.to_string(), msg: EventMsg::ExecCommandBegin(ExecCommandBeginEvent { call_id: call_id.to_string(), - command: params.command.clone(), - cwd: params.cwd.clone(), + command: command_for_display, + cwd: command_cwd.to_path_buf(), }), }; let _ = self.tx_event.send(event).await; @@ -417,8 +451,6 @@ impl Session { pub fn abort(&self) { info!("Aborting existing session"); let mut state = self.state.lock().unwrap(); - // Don't clear pending_call_ids because we need to keep track of them to ensure we don't 400 on the next turn. - // We will generate a synthetic aborted response for each pending call id. state.pending_approvals.clear(); state.pending_input.clear(); if let Some(task) = state.current_task.take() { @@ -463,15 +495,10 @@ impl Drop for Session { } impl State { - pub fn partial_clone(&self, retain_zdr_transcript: bool) -> Self { + pub fn partial_clone(&self) -> Self { Self { approved_commands: self.approved_commands.clone(), - previous_response_id: self.previous_response_id.clone(), - zdr_transcript: if retain_zdr_transcript { - self.zdr_transcript.clone() - } else { - None - }, + history: self.history.clone(), ..Default::default() } } @@ -513,14 +540,13 @@ impl AgentTask { } async fn submission_loop( + mut session_id: Uuid, config: Arc, + auth: Option, rx_sub: Receiver, tx_event: Sender, ctrl_c: Arc, ) { - // Generate a unique ID for the lifetime of this Codex session. - let session_id = Uuid::new_v4(); - let mut sess: Option> = None; // shorthand - send an event when there is no active session let send_no_session_event = |sub_id: String| async { @@ -566,14 +592,18 @@ async fn submission_loop( model, model_reasoning_effort, model_reasoning_summary, - instructions, + user_instructions, + base_instructions, approval_policy, sandbox_policy, disable_response_storage, notify, cwd, + resume_path, } => { - info!("Configuring session: model={model}; provider={provider:?}"); + info!( + "Configuring session: model={model}; provider={provider:?}; resume={resume_path:?}" + ); if !cwd.is_absolute() { let message = format!("cwd is not absolute: {cwd:?}"); error!(message); @@ -586,31 +616,59 @@ async fn submission_loop( } return; } + // Optionally resume an existing rollout. + let mut restored_items: Option> = None; + let rollout_recorder: Option = + if let Some(path) = resume_path.as_ref() { + match RolloutRecorder::resume(path, cwd.clone()).await { + Ok((rec, saved)) => { + session_id = saved.session_id; + if !saved.items.is_empty() { + restored_items = Some(saved.items); + } + Some(rec) + } + Err(e) => { + warn!("failed to resume rollout from {path:?}: {e}"); + None + } + } + } else { + None + }; + + let rollout_recorder = match rollout_recorder { + Some(rec) => Some(rec), + None => { + match RolloutRecorder::new(&config, session_id, user_instructions.clone()) + .await + { + Ok(r) => Some(r), + Err(e) => { + warn!("failed to initialise rollout recorder: {e}"); + None + } + } + } + }; let client = ModelClient::new( config.clone(), + auth.clone(), provider.clone(), model_reasoning_effort, model_reasoning_summary, + session_id, ); // abort any current running session and clone its state - let retain_zdr_transcript = - record_conversation_history(disable_response_storage, provider.wire_api); let state = match sess.take() { Some(sess) => { sess.abort(); - sess.state - .lock() - .unwrap() - .partial_clone(retain_zdr_transcript) + sess.state.lock().unwrap().partial_clone() } None => State { - zdr_transcript: if retain_zdr_transcript { - Some(ConversationHistory::new()) - } else { - None - }, + history: ConversationHistory::new(), ..Default::default() }, }; @@ -645,26 +703,13 @@ async fn submission_loop( }); } } - - // Attempt to create a RolloutRecorder *before* moving the - // `instructions` value into the Session struct. - // TODO: if ConfigureSession is sent twice, we will create an - // overlapping rollout file. Consider passing RolloutRecorder - // from above. - let rollout_recorder = - match RolloutRecorder::new(&config, session_id, instructions.clone()).await { - Ok(r) => Some(r), - Err(e) => { - warn!("failed to initialise rollout recorder: {e}"); - None - } - }; - + let default_shell = shell::default_user_shell().await; sess = Some(Arc::new(Session { client, tx_event: tx_event.clone(), ctrl_c: Arc::clone(&ctrl_c), - instructions, + user_instructions, + base_instructions, approval_policy, sandbox_policy, shell_environment_policy: config.shell_environment_policy.clone(), @@ -675,8 +720,18 @@ async fn submission_loop( state: Mutex::new(state), rollout: Mutex::new(rollout_recorder), codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(), + disable_response_storage, + user_shell: default_shell, })); + // Patch restored state into the newly created session. + if let Some(sess_arc) = &sess { + if restored_items.is_some() { + let mut st = sess_arc.state.lock().unwrap(); + st.history.record_items(restored_items.unwrap().iter()); + } + } + // Gather history metadata for SessionConfiguredEvent. let (history_log_id, history_entry_count) = crate::message_history::history_metadata(&config).await; @@ -745,6 +800,8 @@ async fn submission_loop( } } Op::AddToHistory { text } => { + // TODO: What should we do if we got AddToHistory before ConfigureSession? + // currently, if ConfigureSession has resume path, this history will be ignored let id = session_id; let config = config.clone(); tokio::spawn(async move { @@ -784,6 +841,37 @@ async fn submission_loop( } }); } + Op::Shutdown => { + info!("Shutting down Codex instance"); + + // Gracefully flush and shutdown rollout recorder on session end so tests + // that inspect the rollout file do not race with the background writer. + if let Some(sess_arc) = sess { + let recorder_opt = sess_arc.rollout.lock().unwrap().take(); + if let Some(rec) = recorder_opt { + if let Err(e) = rec.shutdown().await { + warn!("failed to shutdown rollout recorder: {e}"); + let event = Event { + id: sub.id.clone(), + msg: EventMsg::Error(ErrorEvent { + message: "Failed to shutdown rollout recorder".to_string(), + }), + }; + if let Err(e) = tx_event.send(event).await { + warn!("failed to send error message: {e:?}"); + } + } + } + } + let event = Event { + id: sub.id.clone(), + msg: EventMsg::ShutdownComplete, + }; + if let Err(e) = tx_event.send(event).await { + warn!("failed to send Shutdown event: {e}"); + } + break; + } } } debug!("Agent loop exited"); @@ -818,14 +906,8 @@ async fn run_task(sess: Arc, sub_id: String, input: Vec) { sess.record_conversation_items(&[initial_input_for_turn.clone().into()]) .await; - let mut input_for_next_turn: Vec = vec![initial_input_for_turn]; let last_agent_message: Option; loop { - let mut net_new_turn_input = input_for_next_turn - .drain(..) - .map(ResponseItem::from) - .collect::>(); - // Note that pending_input would be something like a message the user // submitted through the UI while the model was running. Though the UI // may support this, the model might not. @@ -842,29 +924,7 @@ async fn run_task(sess: Arc, sub_id: String, input: Vec) { // only record the new items that originated in this turn so that it // represents an append-only log without duplicates. let turn_input: Vec = - if let Some(transcript) = sess.state.lock().unwrap().zdr_transcript.as_mut() { - // If we are using Chat/ZDR, we need to send the transcript with - // every turn. By induction, `transcript` already contains: - // - The `input` that kicked off this task. - // - Each `ResponseItem` that was recorded in the previous turn. - // - Each response to a `ResponseItem` (in practice, the only - // response type we seem to have is `FunctionCallOutput`). - // - // The only thing the `transcript` does not contain is the - // `pending_input` that was injected while the model was - // running. We need to add that to the conversation history - // so that the model can see it in the next turn. - [transcript.contents(), pending_input].concat() - } else { - // In practice, net_new_turn_input should contain only: - // - User messages - // - Outputs for function calls requested by the model - net_new_turn_input.extend(pending_input); - - // Responses API path – we can just send the new items and - // record the same. - net_new_turn_input - }; + [sess.state.lock().unwrap().history.contents(), pending_input].concat(); let turn_input_messages: Vec = turn_input .iter() @@ -920,15 +980,17 @@ async fn run_task(sess: Arc, sub_id: String, input: Vec) { ) => { items_to_record_in_conversation_history.push(item); let (content, success): (String, Option) = match result { - Ok(CallToolResult { content, is_error }) => { - match serde_json::to_string(content) { - Ok(content) => (content, *is_error), - Err(e) => { - warn!("Failed to serialize MCP tool call output: {e}"); - (e.to_string(), Some(true)) - } + Ok(CallToolResult { + content, + is_error, + structured_content: _, + }) => match serde_json::to_string(content) { + Ok(content) => (content, *is_error), + Err(e) => { + warn!("Failed to serialize MCP tool call output: {e}"); + (e.to_string(), Some(true)) } - } + }, Err(e) => (e.clone(), Some(true)), }; items_to_record_in_conversation_history.push( @@ -938,8 +1000,19 @@ async fn run_task(sess: Arc, sub_id: String, input: Vec) { }, ); } - (ResponseItem::Reasoning { .. }, None) => { - // Omit from conversation history. + ( + ResponseItem::Reasoning { + id, + summary, + encrypted_content, + }, + None, + ) => { + items_to_record_in_conversation_history.push(ResponseItem::Reasoning { + id: id.clone(), + summary: summary.clone(), + encrypted_content: encrypted_content.clone(), + }); } _ => { warn!("Unexpected response item: {item:?} with response: {response:?}"); @@ -968,8 +1041,6 @@ async fn run_task(sess: Arc, sub_id: String, input: Vec) { }); break; } - - input_for_next_turn = responses; } Err(e) => { info!("Turn error: {e:#}"); @@ -997,27 +1068,13 @@ async fn run_turn( sub_id: String, input: Vec, ) -> CodexResult> { - // Decide whether to use server-side storage (previous_response_id) or disable it - let (prev_id, store) = { - let state = sess.state.lock().unwrap(); - let store = state.zdr_transcript.is_none(); - let prev_id = if store { - state.previous_response_id.clone() - } else { - // When using ZDR, the Responses API may send previous_response_id - // back, but trying to use it results in a 400. - None - }; - (prev_id, store) - }; - let extra_tools = sess.mcp_connection_manager.list_all_tools(); let prompt = Prompt { input, - prev_id, - user_instructions: sess.instructions.clone(), - store, + user_instructions: sess.user_instructions.clone(), + store: !sess.disable_response_storage, extra_tools, + base_instructions_override: sess.base_instructions.clone(), }; let mut retries = 0; @@ -1027,12 +1084,13 @@ async fn run_turn( Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted), Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)), Err(e) => { - if retries < *OPENAI_STREAM_MAX_RETRIES { + // Use the configured provider-specific stream retry budget. + let max_retries = sess.client.get_provider().stream_max_retries(); + if retries < max_retries { retries += 1; let delay = backoff(retries); warn!( - "stream disconnected - retrying turn ({retries}/{} in {delay:?})...", - *OPENAI_STREAM_MAX_RETRIES + "stream disconnected - retrying turn ({retries}/{max_retries} in {delay:?})...", ); // Surface retry information to any UI/front‑end so the @@ -1041,8 +1099,7 @@ async fn run_turn( sess.notify_background_event( &sub_id, format!( - "stream error: {e}; retrying {retries}/{} in {:?}…", - *OPENAI_STREAM_MAX_RETRIES, delay + "stream error: {e}; retrying {retries}/{max_retries} in {delay:?}…" ), ) .await; @@ -1089,11 +1146,17 @@ async fn try_run_turn( // This usually happens because the user interrupted the model before we responded to one of its tool calls // and then the user sent a follow-up message. let missing_calls = { - sess.state - .lock() - .unwrap() - .pending_call_ids + prompt + .input .iter() + .filter_map(|ri| match ri { + ResponseItem::FunctionCall { call_id, .. } => Some(call_id), + ResponseItem::LocalShellCall { + call_id: Some(call_id), + .. + } => Some(call_id), + _ => None, + }) .filter_map(|call_id| { if completed_call_ids.contains(&call_id) { None @@ -1122,45 +1185,38 @@ async fn try_run_turn( }; let mut stream = sess.client.clone().stream(&prompt).await?; - - // Buffer all the incoming messages from the stream first, then execute them. - // If we execute a function call in the middle of handling the stream, it can time out. - let mut input = Vec::new(); - while let Some(event) = stream.next().await { - let event = event?; - info!("Received stream event: {event:?}"); - input.push(event); - } - let mut output = Vec::new(); - for event in input { - info!("Processing event: {event:?}"); - match event { - ResponseEvent::Created => { - let mut state = sess.state.lock().unwrap(); - // We successfully created a new response and ensured that all pending calls were included so we can clear the pending call ids. - state.pending_call_ids.clear(); + loop { + // Poll the next item from the model stream. We must inspect *both* Ok and Err + // cases so that transient stream failures (e.g., dropped SSE connection before + // `response.completed`) bubble up and trigger the caller's retry logic. + let event = stream.next().await; + let Some(event) = event else { + // Channel closed without yielding a final Completed event or explicit error. + // Treat as a disconnected stream so the caller can retry. + return Err(CodexErr::Stream( + "stream closed before response.completed".into(), + )); + }; + + let event = match event { + Ok(ev) => ev, + Err(e) => { + // Propagate the underlying stream error to the caller (run_turn), which + // will apply the configured `stream_max_retries` policy. + return Err(e); } + }; + + match event { + ResponseEvent::Created => {} ResponseEvent::OutputItemDone(item) => { - let call_id = match &item { - ResponseItem::LocalShellCall { - call_id: Some(call_id), - .. - } => Some(call_id), - ResponseItem::FunctionCall { call_id, .. } => Some(call_id), - _ => None, - }; - if let Some(call_id) = call_id { - // We just got a new call id so we need to make sure to respond to it in the next turn. - let mut state = sess.state.lock().unwrap(); - state.pending_call_ids.insert(call_id.clone()); - } let response = handle_response_item(sess, sub_id, item.clone()).await?; output.push(ProcessedResponseItem { item, response }); } ResponseEvent::Completed { - response_id, + response_id: _, token_usage, } => { if let Some(token_usage) = token_usage { @@ -1173,13 +1229,24 @@ async fn try_run_turn( .ok(); } - let mut state = sess.state.lock().unwrap(); - state.previous_response_id = Some(response_id); - break; + return Ok(output); + } + ResponseEvent::OutputTextDelta(delta) => { + let event = Event { + id: sub_id.to_string(), + msg: EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }), + }; + sess.tx_event.send(event).await.ok(); + } + ResponseEvent::ReasoningSummaryDelta(delta) => { + let event = Event { + id: sub_id.to_string(), + msg: EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta }), + }; + sess.tx_event.send(event).await.ok(); } } } - Ok(output) } async fn handle_response_item( @@ -1234,6 +1301,7 @@ async fn handle_response_item( name, arguments, call_id, + .. } => { info!("FunctionCall: {arguments}"); Some(handle_function_call(sess, sub_id.to_string(), name, arguments, call_id).await) @@ -1298,13 +1366,14 @@ async fn handle_function_call( let params = match parse_container_exec_arguments(arguments, sess, &call_id) { Ok(params) => params, Err(output) => { - return output; + return *output; } }; handle_container_exec_with_params(params, sess, sub_id, call_id).await } + "update_plan" => handle_update_plan(sess, arguments, sub_id, call_id).await, _ => { - match try_parse_fully_qualified_tool_name(&name) { + match sess.mcp_connection_manager.parse_tool_name(&name) { Some((server, tool_name)) => { // TODO(mbolin): Determine appropriate timeout for tool call. let timeout = None; @@ -1341,7 +1410,7 @@ fn parse_container_exec_arguments( arguments: String, sess: &Session, call_id: &str, -) -> Result { +) -> Result> { // parse command match serde_json::from_str::(&arguments) { Ok(shell_tool_call_params) => Ok(to_exec_params(shell_tool_call_params, sess)), @@ -1354,11 +1423,23 @@ fn parse_container_exec_arguments( success: None, }, }; - Err(output) + Err(Box::new(output)) } } } +fn maybe_run_with_user_profile(params: ExecParams, sess: &Session) -> ExecParams { + if sess.shell_environment_policy.use_profile { + let command = sess + .user_shell + .format_default_shell_invocation(params.command.clone()); + if let Some(command) = command { + return ExecParams { command, ..params }; + } + } + params +} + async fn handle_container_exec_with_params( params: ExecParams, sess: &Session, @@ -1366,44 +1447,84 @@ async fn handle_container_exec_with_params( call_id: String, ) -> ResponseInputItem { // check if this was a patch, and apply it if so - match maybe_parse_apply_patch_verified(¶ms.command, ¶ms.cwd) { - MaybeApplyPatchVerified::Body(changes) => { - return apply_patch(sess, sub_id, call_id, changes).await; - } - MaybeApplyPatchVerified::CorrectnessError(parse_error) => { - // It looks like an invocation of `apply_patch`, but we - // could not resolve it into a patch that would apply - // cleanly. Return to model for resample. - return ResponseInputItem::FunctionCallOutput { - call_id, - output: FunctionCallOutputPayload { - content: format!("error: {parse_error:#}"), - success: None, - }, - }; - } - MaybeApplyPatchVerified::ShellParseError(error) => { - trace!("Failed to parse shell command, {error:?}"); - } - MaybeApplyPatchVerified::NotApplyPatch => (), - } + let apply_patch_action_for_exec = + match maybe_parse_apply_patch_verified(¶ms.command, ¶ms.cwd) { + MaybeApplyPatchVerified::Body(changes) => { + match apply_patch::apply_patch(sess, &sub_id, &call_id, changes).await { + InternalApplyPatchInvocation::Output(item) => return item, + InternalApplyPatchInvocation::DelegateToExec(action) => Some(action), + } + } + MaybeApplyPatchVerified::CorrectnessError(parse_error) => { + // It looks like an invocation of `apply_patch`, but we + // could not resolve it into a patch that would apply + // cleanly. Return to model for resample. + return ResponseInputItem::FunctionCallOutput { + call_id, + output: FunctionCallOutputPayload { + content: format!("error: {parse_error:#}"), + success: None, + }, + }; + } + MaybeApplyPatchVerified::ShellParseError(error) => { + trace!("Failed to parse shell command, {error:?}"); + None + } + MaybeApplyPatchVerified::NotApplyPatch => None, + }; - // safety checks - let safety = { - let state = sess.state.lock().unwrap(); - assess_command_safety( - ¶ms.command, - sess.approval_policy, - &sess.sandbox_policy, - &state.approved_commands, - ) + let (params, safety, command_for_display) = match apply_patch_action_for_exec { + Some(ApplyPatchAction { patch, cwd, .. }) => { + let path_to_codex = std::env::current_exe() + .ok() + .map(|p| p.to_string_lossy().to_string()); + let Some(path_to_codex) = path_to_codex else { + return ResponseInputItem::FunctionCallOutput { + call_id, + output: FunctionCallOutputPayload { + content: "failed to determine path to codex executable".to_string(), + success: None, + }, + }; + }; + + let params = ExecParams { + command: vec![ + path_to_codex, + CODEX_APPLY_PATCH_ARG1.to_string(), + patch.clone(), + ], + cwd, + timeout_ms: params.timeout_ms, + env: HashMap::new(), + }; + let safety = + assess_safety_for_untrusted_command(sess.approval_policy, &sess.sandbox_policy); + (params, safety, vec!["apply_patch".to_string(), patch]) + } + None => { + let safety = { + let state = sess.state.lock().unwrap(); + assess_command_safety( + ¶ms.command, + sess.approval_policy, + &sess.sandbox_policy, + &state.approved_commands, + ) + }; + let command_for_display = params.command.clone(); + (params, safety, command_for_display) + } }; + let sandbox_type = match safety { SafetyCheck::AutoApprove { sandbox_type } => sandbox_type, SafetyCheck::AskUser => { let rx_approve = sess .request_command_approval( sub_id.clone(), + call_id.clone(), params.command.clone(), params.cwd.clone(), None, @@ -1441,9 +1562,10 @@ async fn handle_container_exec_with_params( } }; - sess.notify_exec_command_begin(&sub_id, &call_id, ¶ms) + sess.notify_exec_command_begin(&sub_id, &call_id, command_for_display.clone(), ¶ms.cwd) .await; + let params = maybe_run_with_user_profile(params, sess); let output_result = process_exec_tool_call( params.clone(), sandbox_type, @@ -1481,7 +1603,16 @@ async fn handle_container_exec_with_params( } } Err(CodexErr::Sandbox(error)) => { - handle_sandbox_error(error, sandbox_type, params, sess, sub_id, call_id).await + handle_sandbox_error( + error, + sandbox_type, + params, + command_for_display, + sess, + sub_id, + call_id, + ) + .await } Err(e) => { // Handle non-sandbox errors @@ -1500,6 +1631,7 @@ async fn handle_sandbox_error( error: SandboxErr, sandbox_type: SandboxType, params: ExecParams, + command_for_display: Vec, sess: &Session, sub_id: String, call_id: String, @@ -1531,6 +1663,7 @@ async fn handle_sandbox_error( let rx_approve = sess .request_command_approval( sub_id.clone(), + call_id.clone(), params.command.clone(), params.cwd.clone(), Some("command failed; retry without sandbox?".to_string()), @@ -1548,9 +1681,7 @@ async fn handle_sandbox_error( sess.notify_background_event(&sub_id, "retrying command without sandbox") .await; - // Emit a fresh Begin event so progress bars reset. - let retry_call_id = format!("{call_id}-retry"); - sess.notify_exec_command_begin(&sub_id, &retry_call_id, ¶ms) + sess.notify_exec_command_begin(&sub_id, &call_id, command_for_display, ¶ms.cwd) .await; // This is an escalated retry; the policy will not be @@ -1573,14 +1704,8 @@ async fn handle_sandbox_error( duration, } = retry_output; - sess.notify_exec_command_end( - &sub_id, - &retry_call_id, - &stdout, - &stderr, - exit_code, - ) - .await; + sess.notify_exec_command_end(&sub_id, &call_id, &stdout, &stderr, exit_code) + .await; let is_success = exit_code == 0; let content = format_exec_output( @@ -1622,377 +1747,6 @@ async fn handle_sandbox_error( } } -async fn apply_patch( - sess: &Session, - sub_id: String, - call_id: String, - action: ApplyPatchAction, -) -> ResponseInputItem { - let writable_roots_snapshot = { - let guard = sess.writable_roots.lock().unwrap(); - guard.clone() - }; - - let auto_approved = match assess_patch_safety( - &action, - sess.approval_policy, - &writable_roots_snapshot, - &sess.cwd, - ) { - SafetyCheck::AutoApprove { .. } => true, - SafetyCheck::AskUser => { - // Compute a readable summary of path changes to include in the - // approval request so the user can make an informed decision. - let rx_approve = sess - .request_patch_approval(sub_id.clone(), &action, None, None) - .await; - match rx_approve.await.unwrap_or_default() { - ReviewDecision::Approved | ReviewDecision::ApprovedForSession => false, - ReviewDecision::Denied | ReviewDecision::Abort => { - return ResponseInputItem::FunctionCallOutput { - call_id, - output: FunctionCallOutputPayload { - content: "patch rejected by user".to_string(), - success: Some(false), - }, - }; - } - } - } - SafetyCheck::Reject { reason } => { - return ResponseInputItem::FunctionCallOutput { - call_id, - output: FunctionCallOutputPayload { - content: format!("patch rejected: {reason}"), - success: Some(false), - }, - }; - } - }; - - // Verify write permissions before touching the filesystem. - let writable_snapshot = { sess.writable_roots.lock().unwrap().clone() }; - - if let Some(offending) = first_offending_path(&action, &writable_snapshot, &sess.cwd) { - let root = offending.parent().unwrap_or(&offending).to_path_buf(); - - let reason = Some(format!( - "grant write access to {} for this session", - root.display() - )); - - let rx = sess - .request_patch_approval(sub_id.clone(), &action, reason.clone(), Some(root.clone())) - .await; - - if !matches!( - rx.await.unwrap_or_default(), - ReviewDecision::Approved | ReviewDecision::ApprovedForSession - ) { - return ResponseInputItem::FunctionCallOutput { - call_id, - output: FunctionCallOutputPayload { - content: "patch rejected by user".to_string(), - success: Some(false), - }, - }; - } - - // user approved, extend writable roots for this session - sess.writable_roots.lock().unwrap().push(root); - } - - let _ = sess - .tx_event - .send(Event { - id: sub_id.clone(), - msg: EventMsg::PatchApplyBegin(PatchApplyBeginEvent { - call_id: call_id.clone(), - auto_approved, - changes: convert_apply_patch_to_protocol(&action), - }), - }) - .await; - - let mut stdout = Vec::new(); - let mut stderr = Vec::new(); - // Enforce writable roots. If a write is blocked, collect offending root - // and prompt the user to extend permissions. - let mut result = apply_changes_from_apply_patch_and_report(&action, &mut stdout, &mut stderr); - - if let Err(err) = &result { - if err.kind() == std::io::ErrorKind::PermissionDenied { - // Determine first offending path. - let offending_opt = action - .changes() - .iter() - .flat_map(|(path, change)| match change { - ApplyPatchFileChange::Add { .. } => vec![path.as_ref()], - ApplyPatchFileChange::Delete => vec![path.as_ref()], - ApplyPatchFileChange::Update { - move_path: Some(move_path), - .. - } => { - vec![path.as_ref(), move_path.as_ref()] - } - ApplyPatchFileChange::Update { - move_path: None, .. - } => vec![path.as_ref()], - }) - .find_map(|path: &Path| { - // ApplyPatchAction promises to guarantee absolute paths. - if !path.is_absolute() { - panic!("apply_patch invariant failed: path is not absolute: {path:?}"); - } - - let writable = { - let roots = sess.writable_roots.lock().unwrap(); - roots.iter().any(|root| path.starts_with(root)) - }; - if writable { - None - } else { - Some(path.to_path_buf()) - } - }); - - if let Some(offending) = offending_opt { - let root = offending.parent().unwrap_or(&offending).to_path_buf(); - - let reason = Some(format!( - "grant write access to {} for this session", - root.display() - )); - let rx = sess - .request_patch_approval( - sub_id.clone(), - &action, - reason.clone(), - Some(root.clone()), - ) - .await; - if matches!( - rx.await.unwrap_or_default(), - ReviewDecision::Approved | ReviewDecision::ApprovedForSession - ) { - // Extend writable roots. - sess.writable_roots.lock().unwrap().push(root); - stdout.clear(); - stderr.clear(); - result = apply_changes_from_apply_patch_and_report( - &action, - &mut stdout, - &mut stderr, - ); - } - } - } - } - - // Emit PatchApplyEnd event. - let success_flag = result.is_ok(); - let _ = sess - .tx_event - .send(Event { - id: sub_id.clone(), - msg: EventMsg::PatchApplyEnd(PatchApplyEndEvent { - call_id: call_id.clone(), - stdout: String::from_utf8_lossy(&stdout).to_string(), - stderr: String::from_utf8_lossy(&stderr).to_string(), - success: success_flag, - }), - }) - .await; - - match result { - Ok(_) => ResponseInputItem::FunctionCallOutput { - call_id, - output: FunctionCallOutputPayload { - content: String::from_utf8_lossy(&stdout).to_string(), - success: None, - }, - }, - Err(e) => ResponseInputItem::FunctionCallOutput { - call_id, - output: FunctionCallOutputPayload { - content: format!("error: {e:#}, stderr: {}", String::from_utf8_lossy(&stderr)), - success: Some(false), - }, - }, - } -} - -/// Return the first path in `hunks` that is NOT under any of the -/// `writable_roots` (after normalising). If all paths are acceptable, -/// returns None. -fn first_offending_path( - action: &ApplyPatchAction, - writable_roots: &[PathBuf], - cwd: &Path, -) -> Option { - let changes = action.changes(); - for (path, change) in changes { - let candidate = match change { - ApplyPatchFileChange::Add { .. } => path, - ApplyPatchFileChange::Delete => path, - ApplyPatchFileChange::Update { move_path, .. } => move_path.as_ref().unwrap_or(path), - }; - - let abs = if candidate.is_absolute() { - candidate.clone() - } else { - cwd.join(candidate) - }; - - let mut allowed = false; - for root in writable_roots { - let root_abs = if root.is_absolute() { - root.clone() - } else { - cwd.join(root) - }; - if abs.starts_with(&root_abs) { - allowed = true; - break; - } - } - - if !allowed { - return Some(candidate.clone()); - } - } - None -} - -fn convert_apply_patch_to_protocol(action: &ApplyPatchAction) -> HashMap { - let changes = action.changes(); - let mut result = HashMap::with_capacity(changes.len()); - for (path, change) in changes { - let protocol_change = match change { - ApplyPatchFileChange::Add { content } => FileChange::Add { - content: content.clone(), - }, - ApplyPatchFileChange::Delete => FileChange::Delete, - ApplyPatchFileChange::Update { - unified_diff, - move_path, - new_content: _new_content, - } => FileChange::Update { - unified_diff: unified_diff.clone(), - move_path: move_path.clone(), - }, - }; - result.insert(path.clone(), protocol_change); - } - result -} - -fn apply_changes_from_apply_patch_and_report( - action: &ApplyPatchAction, - stdout: &mut impl std::io::Write, - stderr: &mut impl std::io::Write, -) -> std::io::Result<()> { - match apply_changes_from_apply_patch(action) { - Ok(affected_paths) => { - print_summary(&affected_paths, stdout)?; - } - Err(err) => { - writeln!(stderr, "{err:?}")?; - } - } - - Ok(()) -} - -fn apply_changes_from_apply_patch(action: &ApplyPatchAction) -> anyhow::Result { - let mut added: Vec = Vec::new(); - let mut modified: Vec = Vec::new(); - let mut deleted: Vec = Vec::new(); - - let changes = action.changes(); - for (path, change) in changes { - match change { - ApplyPatchFileChange::Add { content } => { - if let Some(parent) = path.parent() { - if !parent.as_os_str().is_empty() { - std::fs::create_dir_all(parent).with_context(|| { - format!("Failed to create parent directories for {}", path.display()) - })?; - } - } - std::fs::write(path, content) - .with_context(|| format!("Failed to write file {}", path.display()))?; - added.push(path.clone()); - } - ApplyPatchFileChange::Delete => { - std::fs::remove_file(path) - .with_context(|| format!("Failed to delete file {}", path.display()))?; - deleted.push(path.clone()); - } - ApplyPatchFileChange::Update { - unified_diff: _unified_diff, - move_path, - new_content, - } => { - if let Some(move_path) = move_path { - if let Some(parent) = move_path.parent() { - if !parent.as_os_str().is_empty() { - std::fs::create_dir_all(parent).with_context(|| { - format!( - "Failed to create parent directories for {}", - move_path.display() - ) - })?; - } - } - - std::fs::rename(path, move_path) - .with_context(|| format!("Failed to rename file {}", path.display()))?; - std::fs::write(move_path, new_content)?; - modified.push(move_path.clone()); - deleted.push(path.clone()); - } else { - std::fs::write(path, new_content)?; - modified.push(path.clone()); - } - } - } - } - - Ok(AffectedPaths { - added, - modified, - deleted, - }) -} - -fn get_writable_roots(cwd: &Path) -> Vec { - let mut writable_roots = Vec::new(); - if cfg!(target_os = "macos") { - // On macOS, $TMPDIR is private to the user. - writable_roots.push(std::env::temp_dir()); - - // Allow pyenv to update its shims directory. Without this, any tool - // that happens to be managed by `pyenv` will fail with an error like: - // - // pyenv: cannot rehash: $HOME/.pyenv/shims isn't writable - // - // which is emitted every time `pyenv` tries to run `rehash` (for - // example, after installing a new Python package that drops an entry - // point). Although the sandbox is intentionally read‑only by default, - // writing to the user's local `pyenv` directory is safe because it - // is already user‑writable and scoped to the current user account. - if let Ok(home_dir) = std::env::var("HOME") { - let pyenv_dir = PathBuf::from(home_dir).join(".pyenv"); - writable_roots.push(pyenv_dir); - } - } - - writable_roots.push(cwd.to_path_buf()); - - writable_roots -} - /// Exec output is a pre-serialized JSON payload fn format_exec_output(output: &str, exit_code: i32, duration: Duration) -> String { #[derive(Serialize)] @@ -2024,7 +1778,7 @@ fn format_exec_output(output: &str, exit_code: i32, duration: Duration) -> Strin fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option { responses.iter().rev().find_map(|item| { - if let ResponseItem::Message { role, content } = item { + if let ResponseItem::Message { role, content, .. } = item { if role == "assistant" { content.iter().rev().find_map(|ci| { if let ContentItem::OutputText { text } = ci { @@ -2041,15 +1795,3 @@ fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option bool { - if disable_response_storage { - return true; - } - - match wire_api { - WireApi::Responses => false, - WireApi::Chat => true, - } -} diff --git a/codex-rs/core/src/codex_wrapper.rs b/codex-rs/core/src/codex_wrapper.rs index f2ece22da7..1e26a9ebed 100644 --- a/codex-rs/core/src/codex_wrapper.rs +++ b/codex-rs/core/src/codex_wrapper.rs @@ -1,20 +1,37 @@ use std::sync::Arc; use crate::Codex; +use crate::CodexSpawnOk; use crate::config::Config; use crate::protocol::Event; use crate::protocol::EventMsg; use crate::util::notify_on_sigint; +use codex_login::load_auth; use tokio::sync::Notify; +use uuid::Uuid; + +/// Represents an active Codex conversation, including the first event +/// (which is [`EventMsg::SessionConfigured`]). +pub struct CodexConversation { + pub codex: Codex, + pub session_id: Uuid, + pub session_configured: Event, + pub ctrl_c: Arc, +} /// Spawn a new [`Codex`] and initialize the session. /// /// Returns the wrapped [`Codex`] **and** the `SessionInitialized` event that /// is received as a response to the initial `ConfigureSession` submission so /// that callers can surface the information to the UI. -pub async fn init_codex(config: Config) -> anyhow::Result<(Codex, Event, Arc)> { +pub async fn init_codex(config: Config) -> anyhow::Result { let ctrl_c = notify_on_sigint(); - let (codex, init_id) = Codex::spawn(config, ctrl_c.clone()).await?; + let auth = load_auth(&config.codex_home)?; + let CodexSpawnOk { + codex, + init_id, + session_id, + } = Codex::spawn(config, auth, ctrl_c.clone()).await?; // The first event must be `SessionInitialized`. Validate and forward it to // the caller so that they can display it in the conversation history. @@ -33,5 +50,10 @@ pub async fn init_codex(config: Config) -> anyhow::Result<(Codex, Event, Arc, + pub user_instructions: Option, + + /// Base instructions override. + pub base_instructions: Option, /// Optional external notifier command. When set, Codex will spawn this /// program after each completed *turn* (i.e. when the agent finishes @@ -141,6 +144,12 @@ pub struct Config { /// Base URL for requests to ChatGPT (as opposed to the OpenAI API). pub chatgpt_base_url: String, + + /// Experimental rollout resume path (absolute path to .jsonl; undocumented). + pub experimental_resume: Option, + + /// Include an experimental plan tool that the model can use to update its current plan and status of each step. + pub include_plan_tool: bool, } impl Config { @@ -329,6 +338,12 @@ pub struct ConfigToml { /// Base URL for requests to ChatGPT (as opposed to the OpenAI API). pub chatgpt_base_url: Option, + + /// Experimental rollout resume path (absolute path to .jsonl; undocumented). + pub experimental_resume: Option, + + /// Experimental path to a file whose contents replace the built-in BASE_INSTRUCTIONS. + pub experimental_instructions_file: Option, } impl ConfigToml { @@ -361,6 +376,8 @@ pub struct ConfigOverrides { pub model_provider: Option, pub config_profile: Option, pub codex_linux_sandbox_exe: Option, + pub base_instructions: Option, + pub include_plan_tool: Option, } impl Config { @@ -371,7 +388,7 @@ impl Config { overrides: ConfigOverrides, codex_home: PathBuf, ) -> std::io::Result { - let instructions = Self::load_instructions(Some(&codex_home)); + let user_instructions = Self::load_instructions(Some(&codex_home)); // Destructure ConfigOverrides fully to ensure all overrides are applied. let ConfigOverrides { @@ -382,6 +399,8 @@ impl Config { model_provider, config_profile: config_profile_key, codex_linux_sandbox_exe, + base_instructions, + include_plan_tool, } = overrides; let config_profile = match config_profile_key.as_ref().or(cfg.profile.as_ref()) { @@ -456,6 +475,18 @@ impl Config { .as_ref() .map(|info| info.max_output_tokens) }); + + let experimental_resume = cfg.experimental_resume; + + // Load base instructions override from a file if specified. If the + // path is relative, resolve it against the effective cwd so the + // behaviour matches other path-like config values. + let file_base_instructions = Self::get_base_instructions( + cfg.experimental_instructions_file.as_ref(), + &resolved_cwd, + )?; + let base_instructions = base_instructions.or(file_base_instructions); + let config = Self { model, model_context_window, @@ -474,7 +505,8 @@ impl Config { .or(cfg.disable_response_storage) .unwrap_or(false), notify: cfg.notify, - instructions, + user_instructions, + base_instructions, mcp_servers: cfg.mcp_servers, model_providers, project_doc_max_bytes: cfg.project_doc_max_bytes.unwrap_or(PROJECT_DOC_MAX_BYTES), @@ -503,6 +535,9 @@ impl Config { .chatgpt_base_url .or(cfg.chatgpt_base_url) .unwrap_or("https://chatgpt.com/backend-api/".to_string()), + + experimental_resume, + include_plan_tool: include_plan_tool.unwrap_or(false), }; Ok(config) } @@ -523,6 +558,48 @@ impl Config { } }) } + + fn get_base_instructions( + path: Option<&PathBuf>, + cwd: &Path, + ) -> std::io::Result> { + let p = match path.as_ref() { + None => return Ok(None), + Some(p) => p, + }; + + // Resolve relative paths against the provided cwd to make CLI + // overrides consistent regardless of where the process was launched + // from. + let full_path = if p.is_relative() { + cwd.join(p) + } else { + p.to_path_buf() + }; + + let contents = std::fs::read_to_string(&full_path).map_err(|e| { + std::io::Error::new( + e.kind(), + format!( + "failed to read experimental instructions file {}: {e}", + full_path.display() + ), + ) + })?; + + let s = contents.trim().to_string(); + if s.is_empty() { + Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!( + "experimental instructions file is empty: {}", + full_path.display() + ), + )) + } else { + Ok(Some(s)) + } + } } fn default_model() -> String { @@ -537,7 +614,7 @@ fn default_model() -> String { /// function will Err if the path does not exist. /// - If `CODEX_HOME` is not set, this function does not verify that the /// directory exists. -fn find_codex_home() -> std::io::Result { +pub fn find_codex_home() -> std::io::Result { // Honor the `CODEX_HOME` environment variable when it is set to allow users // (and tests) to override the default location. if let Ok(val) = std::env::var("CODEX_HOME") { @@ -691,6 +768,9 @@ name = "OpenAI using Chat Completions" base_url = "https://api.openai.com/v1" env_key = "OPENAI_API_KEY" wire_api = "chat" +request_max_retries = 4 # retry failed HTTP requests +stream_max_retries = 10 # retry dropped SSE streams +stream_idle_timeout_ms = 300000 # 5m idle timeout [profiles.o3] model = "o3" @@ -724,13 +804,17 @@ disable_response_storage = true let openai_chat_completions_provider = ModelProviderInfo { name: "OpenAI using Chat Completions".to_string(), - base_url: "https://api.openai.com/v1".to_string(), + base_url: Some("https://api.openai.com/v1".to_string()), env_key: Some("OPENAI_API_KEY".to_string()), wire_api: crate::WireApi::Chat, env_key_instructions: None, query_params: None, http_headers: None, env_http_headers: None, + request_max_retries: Some(4), + stream_max_retries: Some(10), + stream_idle_timeout_ms: Some(300_000), + requires_auth: false, }; let model_provider_map = { let mut model_provider_map = built_in_model_providers(); @@ -761,7 +845,7 @@ disable_response_storage = true /// /// 1. custom command-line argument, e.g. `--model o3` /// 2. as part of a profile, where the `--profile` is specified via a CLI - /// (or in the config file itelf) + /// (or in the config file itself) /// 3. as an entry in `config.toml`, e.g. `model = "o3"` /// 4. the default value for a required field defined in code, e.g., /// `crate::flags::OPENAI_DEFAULT_MODEL` @@ -793,7 +877,7 @@ disable_response_storage = true sandbox_policy: SandboxPolicy::new_read_only_policy(), shell_environment_policy: ShellEnvironmentPolicy::default(), disable_response_storage: false, - instructions: None, + user_instructions: None, notify: None, cwd: fixture.cwd(), mcp_servers: HashMap::new(), @@ -810,6 +894,9 @@ disable_response_storage = true model_reasoning_summary: ReasoningSummary::Detailed, model_supports_reasoning_summaries: false, chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(), + experimental_resume: None, + base_instructions: None, + include_plan_tool: false, }, o3_profile_config ); @@ -840,7 +927,7 @@ disable_response_storage = true sandbox_policy: SandboxPolicy::new_read_only_policy(), shell_environment_policy: ShellEnvironmentPolicy::default(), disable_response_storage: false, - instructions: None, + user_instructions: None, notify: None, cwd: fixture.cwd(), mcp_servers: HashMap::new(), @@ -857,6 +944,9 @@ disable_response_storage = true model_reasoning_summary: ReasoningSummary::default(), model_supports_reasoning_summaries: false, chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(), + experimental_resume: None, + base_instructions: None, + include_plan_tool: false, }; assert_eq!(expected_gpt3_profile_config, gpt3_profile_config); @@ -902,7 +992,7 @@ disable_response_storage = true sandbox_policy: SandboxPolicy::new_read_only_policy(), shell_environment_policy: ShellEnvironmentPolicy::default(), disable_response_storage: true, - instructions: None, + user_instructions: None, notify: None, cwd: fixture.cwd(), mcp_servers: HashMap::new(), @@ -919,6 +1009,9 @@ disable_response_storage = true model_reasoning_summary: ReasoningSummary::default(), model_supports_reasoning_summaries: false, chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(), + experimental_resume: None, + base_instructions: None, + include_plan_tool: false, }; assert_eq!(expected_zdr_profile_config, zdr_profile_config); diff --git a/codex-rs/core/src/config_types.rs b/codex-rs/core/src/config_types.rs index 83fe613c86..9bf0d483e1 100644 --- a/codex-rs/core/src/config_types.rs +++ b/codex-rs/core/src/config_types.rs @@ -76,22 +76,9 @@ pub enum HistoryPersistence { /// Collection of settings that are specific to the TUI. #[derive(Deserialize, Debug, Clone, PartialEq, Default)] -pub struct Tui { - /// By default, mouse capture is enabled in the TUI so that it is possible - /// to scroll the conversation history with a mouse. This comes at the cost - /// of not being able to use the mouse to select text in the TUI. - /// (Most terminals support a modifier key to allow this. For example, - /// text selection works in iTerm if you hold down the `Option` key while - /// clicking and dragging.) - /// - /// Setting this option to `true` disables mouse capture, so scrolling with - /// the mouse is not possible, though the keyboard shortcuts e.g. `b` and - /// `space` still work. This allows the user to select text in the TUI - /// using the mouse without needing to hold down a modifier key. - pub disable_mouse_capture: bool, -} +pub struct Tui {} -#[derive(Deserialize, Debug, Clone, Copy, PartialEq, Default)] +#[derive(Deserialize, Debug, Clone, Copy, PartialEq, Default, Serialize)] #[serde(rename_all = "kebab-case")] pub enum SandboxMode { #[serde(rename = "read-only")] @@ -143,6 +130,8 @@ pub struct ShellEnvironmentPolicyToml { /// List of regular expressions. pub include_only: Option>, + + pub experimental_use_profile: Option, } pub type EnvironmentVariablePattern = WildMatchPattern<'*', '?'>; @@ -171,6 +160,9 @@ pub struct ShellEnvironmentPolicy { /// Environment variable names to retain in the environment. pub include_only: Vec, + + /// If true, the shell profile will be used to run the command. + pub use_profile: bool, } impl From for ShellEnvironmentPolicy { @@ -190,6 +182,7 @@ impl From for ShellEnvironmentPolicy { .into_iter() .map(|s| EnvironmentVariablePattern::new_case_insensitive(&s)) .collect(); + let use_profile = toml.experimental_use_profile.unwrap_or(false); Self { inherit, @@ -197,6 +190,7 @@ impl From for ShellEnvironmentPolicy { exclude, r#set, include_only, + use_profile, } } } diff --git a/codex-rs/core/src/conversation_history.rs b/codex-rs/core/src/conversation_history.rs index 52fb1ec4f4..4cd989cbd9 100644 --- a/codex-rs/core/src/conversation_history.rs +++ b/codex-rs/core/src/conversation_history.rs @@ -1,12 +1,7 @@ use crate::models::ResponseItem; -/// Transcript of conversation history that is needed: -/// - for ZDR clients for which previous_response_id is not available, so we -/// must include the transcript with every API call. This must include each -/// `function_call` and its corresponding `function_call_output`. -/// - for clients using the "chat completions" API as opposed to the -/// "responses" API. -#[derive(Debug, Clone)] +/// Transcript of conversation history +#[derive(Debug, Clone, Default)] pub(crate) struct ConversationHistory { /// The oldest items are at the beginning of the vector. items: Vec, @@ -44,7 +39,8 @@ fn is_api_message(message: &ResponseItem) -> bool { ResponseItem::Message { role, .. } => role.as_str() != "system", ResponseItem::FunctionCallOutput { .. } | ResponseItem::FunctionCall { .. } - | ResponseItem::LocalShellCall { .. } => true, - ResponseItem::Reasoning { .. } | ResponseItem::Other => false, + | ResponseItem::LocalShellCall { .. } + | ResponseItem::Reasoning { .. } => true, + ResponseItem::Other => false, } } diff --git a/codex-rs/core/src/exec.rs b/codex-rs/core/src/exec.rs index 3b37cb538d..230c4ec134 100644 --- a/codex-rs/core/src/exec.rs +++ b/codex-rs/core/src/exec.rs @@ -17,6 +17,7 @@ use tokio::io::BufReader; use tokio::process::Child; use tokio::process::Command; use tokio::sync::Notify; +use tracing::trace; use crate::error::CodexErr; use crate::error::Result; @@ -82,7 +83,8 @@ pub async fn process_exec_tool_call( ) -> Result { let start = Instant::now(); - let raw_output_result = match sandbox_type { + let raw_output_result: std::result::Result = match sandbox_type + { SandboxType::None => exec(params, sandbox_policy, ctrl_c).await, SandboxType::MacosSeatbelt => { let ExecParams { @@ -372,6 +374,10 @@ async fn spawn_child_async( stdio_policy: StdioPolicy, env: HashMap, ) -> std::io::Result { + trace!( + "spawn_child_async: {program:?} {args:?} {arg0:?} {cwd:?} {sandbox_policy:?} {stdio_policy:?} {env:?}" + ); + let mut cmd = Command::new(&program); #[cfg(unix)] cmd.arg0(arg0.map_or_else(|| program.to_string_lossy().to_string(), String::from)); @@ -384,6 +390,31 @@ async fn spawn_child_async( cmd.env(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR, "1"); } + // If this Codex process dies (including being killed via SIGKILL), we want + // any child processes that were spawned as part of a `"shell"` tool call + // to also be terminated. + + // This relies on prctl(2), so it only works on Linux. + #[cfg(target_os = "linux")] + unsafe { + cmd.pre_exec(|| { + // This prctl call effectively requests, "deliver SIGTERM when my + // current parent dies." + if libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM) == -1 { + return Err(io::Error::last_os_error()); + } + + // Though if there was a race condition and this pre_exec() block is + // run _after_ the parent (i.e., the Codex process) has already + // exited, then the parent is the _init_ process (which will never + // die), so we should just terminate the child process now. + if libc::getppid() == 1 { + libc::raise(libc::SIGTERM); + } + Ok(()) + }); + } + match stdio_policy { StdioPolicy::RedirectForShellTool => { // Do not create a file descriptor for stdin because otherwise some diff --git a/codex-rs/core/src/flags.rs b/codex-rs/core/src/flags.rs index c21ef67026..c150405491 100644 --- a/codex-rs/core/src/flags.rs +++ b/codex-rs/core/src/flags.rs @@ -11,14 +11,6 @@ env_flags! { pub OPENAI_TIMEOUT_MS: Duration = Duration::from_millis(300_000), |value| { value.parse().map(Duration::from_millis) }; - pub OPENAI_REQUEST_MAX_RETRIES: u64 = 4; - pub OPENAI_STREAM_MAX_RETRIES: u64 = 10; - - // We generally don't want to disconnect; this updates the timeout to be five minutes - // which matches the upstream typescript codex impl. - pub OPENAI_STREAM_IDLE_TIMEOUT_MS: Duration = Duration::from_millis(300_000), |value| { - value.parse().map(Duration::from_millis) - }; /// Fixture path for offline tests (see client.rs). pub CODEX_RS_SSE_FIXTURE: Option<&str> = None; diff --git a/codex-rs/core/src/git_info.rs b/codex-rs/core/src/git_info.rs new file mode 100644 index 0000000000..cf959d32d1 --- /dev/null +++ b/codex-rs/core/src/git_info.rs @@ -0,0 +1,307 @@ +use std::path::Path; + +use serde::Deserialize; +use serde::Serialize; +use tokio::process::Command; +use tokio::time::Duration as TokioDuration; +use tokio::time::timeout; + +/// Timeout for git commands to prevent freezing on large repositories +const GIT_COMMAND_TIMEOUT: TokioDuration = TokioDuration::from_secs(5); + +#[derive(Serialize, Deserialize, Clone)] +pub struct GitInfo { + /// Current commit hash (SHA) + #[serde(skip_serializing_if = "Option::is_none")] + pub commit_hash: Option, + /// Current branch name + #[serde(skip_serializing_if = "Option::is_none")] + pub branch: Option, + /// Repository URL (if available from remote) + #[serde(skip_serializing_if = "Option::is_none")] + pub repository_url: Option, +} + +/// Collect git repository information from the given working directory using command-line git. +/// Returns None if no git repository is found or if git operations fail. +/// Uses timeouts to prevent freezing on large repositories. +/// All git commands (except the initial repo check) run in parallel for better performance. +pub async fn collect_git_info(cwd: &Path) -> Option { + // Check if we're in a git repository first + let is_git_repo = run_git_command_with_timeout(&["rev-parse", "--git-dir"], cwd) + .await? + .status + .success(); + + if !is_git_repo { + return None; + } + + // Run all git info collection commands in parallel + let (commit_result, branch_result, url_result) = tokio::join!( + run_git_command_with_timeout(&["rev-parse", "HEAD"], cwd), + run_git_command_with_timeout(&["rev-parse", "--abbrev-ref", "HEAD"], cwd), + run_git_command_with_timeout(&["remote", "get-url", "origin"], cwd) + ); + + let mut git_info = GitInfo { + commit_hash: None, + branch: None, + repository_url: None, + }; + + // Process commit hash + if let Some(output) = commit_result { + if output.status.success() { + if let Ok(hash) = String::from_utf8(output.stdout) { + git_info.commit_hash = Some(hash.trim().to_string()); + } + } + } + + // Process branch name + if let Some(output) = branch_result { + if output.status.success() { + if let Ok(branch) = String::from_utf8(output.stdout) { + let branch = branch.trim(); + if branch != "HEAD" { + git_info.branch = Some(branch.to_string()); + } + } + } + } + + // Process repository URL + if let Some(output) = url_result { + if output.status.success() { + if let Ok(url) = String::from_utf8(output.stdout) { + git_info.repository_url = Some(url.trim().to_string()); + } + } + } + + Some(git_info) +} + +/// Run a git command with a timeout to prevent blocking on large repositories +async fn run_git_command_with_timeout(args: &[&str], cwd: &Path) -> Option { + let result = timeout( + GIT_COMMAND_TIMEOUT, + Command::new("git").args(args).current_dir(cwd).output(), + ) + .await; + + match result { + Ok(Ok(output)) => Some(output), + _ => None, // Timeout or error + } +} + +#[cfg(test)] +mod tests { + #![allow(clippy::expect_used)] + #![allow(clippy::unwrap_used)] + + use super::*; + + use std::fs; + use std::path::PathBuf; + use tempfile::TempDir; + + // Helper function to create a test git repository + async fn create_test_git_repo(temp_dir: &TempDir) -> PathBuf { + let repo_path = temp_dir.path().to_path_buf(); + + // Initialize git repo + Command::new("git") + .args(["init"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to init git repo"); + + // Configure git user (required for commits) + Command::new("git") + .args(["config", "user.name", "Test User"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to set git user name"); + + Command::new("git") + .args(["config", "user.email", "test@example.com"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to set git user email"); + + // Create a test file and commit it + let test_file = repo_path.join("test.txt"); + fs::write(&test_file, "test content").expect("Failed to write test file"); + + Command::new("git") + .args(["add", "."]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to add files"); + + Command::new("git") + .args(["commit", "-m", "Initial commit"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to commit"); + + repo_path + } + + #[tokio::test] + async fn test_collect_git_info_non_git_directory() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let result = collect_git_info(temp_dir.path()).await; + assert!(result.is_none()); + } + + #[tokio::test] + async fn test_collect_git_info_git_repository() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let repo_path = create_test_git_repo(&temp_dir).await; + + let git_info = collect_git_info(&repo_path) + .await + .expect("Should collect git info from repo"); + + // Should have commit hash + assert!(git_info.commit_hash.is_some()); + let commit_hash = git_info.commit_hash.unwrap(); + assert_eq!(commit_hash.len(), 40); // SHA-1 hash should be 40 characters + assert!(commit_hash.chars().all(|c| c.is_ascii_hexdigit())); + + // Should have branch (likely "main" or "master") + assert!(git_info.branch.is_some()); + let branch = git_info.branch.unwrap(); + assert!(branch == "main" || branch == "master"); + + // Repository URL might be None for local repos without remote + // This is acceptable behavior + } + + #[tokio::test] + async fn test_collect_git_info_with_remote() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let repo_path = create_test_git_repo(&temp_dir).await; + + // Add a remote origin + Command::new("git") + .args([ + "remote", + "add", + "origin", + "https://github.com/example/repo.git", + ]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to add remote"); + + let git_info = collect_git_info(&repo_path) + .await + .expect("Should collect git info from repo"); + + // Should have repository URL + assert_eq!( + git_info.repository_url, + Some("https://github.com/example/repo.git".to_string()) + ); + } + + #[tokio::test] + async fn test_collect_git_info_detached_head() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let repo_path = create_test_git_repo(&temp_dir).await; + + // Get the current commit hash + let output = Command::new("git") + .args(["rev-parse", "HEAD"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to get HEAD"); + let commit_hash = String::from_utf8(output.stdout).unwrap().trim().to_string(); + + // Checkout the commit directly (detached HEAD) + Command::new("git") + .args(["checkout", &commit_hash]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to checkout commit"); + + let git_info = collect_git_info(&repo_path) + .await + .expect("Should collect git info from repo"); + + // Should have commit hash + assert!(git_info.commit_hash.is_some()); + // Branch should be None for detached HEAD (since rev-parse --abbrev-ref HEAD returns "HEAD") + assert!(git_info.branch.is_none()); + } + + #[tokio::test] + async fn test_collect_git_info_with_branch() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let repo_path = create_test_git_repo(&temp_dir).await; + + // Create and checkout a new branch + Command::new("git") + .args(["checkout", "-b", "feature-branch"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to create branch"); + + let git_info = collect_git_info(&repo_path) + .await + .expect("Should collect git info from repo"); + + // Should have the new branch name + assert_eq!(git_info.branch, Some("feature-branch".to_string())); + } + + #[test] + fn test_git_info_serialization() { + let git_info = GitInfo { + commit_hash: Some("abc123def456".to_string()), + branch: Some("main".to_string()), + repository_url: Some("https://github.com/example/repo.git".to_string()), + }; + + let json = serde_json::to_string(&git_info).expect("Should serialize GitInfo"); + let parsed: serde_json::Value = serde_json::from_str(&json).expect("Should parse JSON"); + + assert_eq!(parsed["commit_hash"], "abc123def456"); + assert_eq!(parsed["branch"], "main"); + assert_eq!( + parsed["repository_url"], + "https://github.com/example/repo.git" + ); + } + + #[test] + fn test_git_info_serialization_with_nones() { + let git_info = GitInfo { + commit_hash: None, + branch: None, + repository_url: None, + }; + + let json = serde_json::to_string(&git_info).expect("Should serialize GitInfo"); + let parsed: serde_json::Value = serde_json::from_str(&json).expect("Should parse JSON"); + + // Fields with None values should be omitted due to skip_serializing_if + assert!(!parsed.as_object().unwrap().contains_key("commit_hash")); + assert!(!parsed.as_object().unwrap().contains_key("branch")); + assert!(!parsed.as_object().unwrap().contains_key("repository_url")); + } +} diff --git a/codex-rs/core/src/is_safe_command.rs b/codex-rs/core/src/is_safe_command.rs index 98c41dbdc2..f5f453f8d8 100644 --- a/codex-rs/core/src/is_safe_command.rs +++ b/codex-rs/core/src/is_safe_command.rs @@ -1,31 +1,57 @@ -use tree_sitter::Parser; -use tree_sitter::Tree; -use tree_sitter_bash::LANGUAGE as BASH; +use crate::bash::try_parse_bash; +use crate::bash::try_parse_word_only_commands_sequence; pub fn is_known_safe_command(command: &[String]) -> bool { if is_safe_to_call_with_exec(command) { return true; } - // TODO(mbolin): Also support safe commands that are piped together such - // as `cat foo | wc -l`. - matches!( - command, - [bash, flag, script] - if bash == "bash" - && flag == "-lc" - && try_parse_bash(script).and_then(|tree| - try_parse_single_word_only_command(&tree, script)).is_some_and(|parsed_bash_command| is_safe_to_call_with_exec(&parsed_bash_command)) - ) + // Support `bash -lc "..."` where the script consists solely of one or + // more "plain" commands (only bare words / quoted strings) combined with + // a conservative allow‑list of shell operators that themselves do not + // introduce side effects ( "&&", "||", ";", and "|" ). If every + // individual command in the script is itself a known‑safe command, then + // the composite expression is considered safe. + if let [bash, flag, script] = command { + if bash == "bash" && flag == "-lc" { + if let Some(tree) = try_parse_bash(script) { + if let Some(all_commands) = try_parse_word_only_commands_sequence(&tree, script) { + if !all_commands.is_empty() + && all_commands + .iter() + .all(|cmd| is_safe_to_call_with_exec(cmd)) + { + return true; + } + } + } + } + } + + false } fn is_safe_to_call_with_exec(command: &[String]) -> bool { let cmd0 = command.first().map(String::as_str); match cmd0 { + #[rustfmt::skip] Some( - "cat" | "cd" | "echo" | "grep" | "head" | "ls" | "pwd" | "rg" | "tail" | "wc" | "which", - ) => true, + "cat" | + "cd" | + "echo" | + "false" | + "grep" | + "head" | + "ls" | + "nl" | + "pwd" | + "tail" | + "true" | + "wc" | + "which") => { + true + }, Some("find") => { // Certain options to `find` can delete files, write to files, or @@ -46,6 +72,29 @@ fn is_safe_to_call_with_exec(command: &[String]) -> bool { .any(|arg| UNSAFE_FIND_OPTIONS.contains(&arg.as_str())) } + // Ripgrep + Some("rg") => { + const UNSAFE_RIPGREP_OPTIONS_WITH_ARGS: &[&str] = &[ + // Takes an arbitrary command that is executed for each match. + "--pre", + // Takes a command that can be used to obtain the local hostname. + "--hostname-bin", + ]; + const UNSAFE_RIPGREP_OPTIONS_WITHOUT_ARGS: &[&str] = &[ + // Calls out to other decompression tools, so do not auto-approve + // out of an abundance of caution. + "--search-zip", + "-z", + ]; + + !command.iter().any(|arg| { + UNSAFE_RIPGREP_OPTIONS_WITHOUT_ARGS.contains(&arg.as_str()) + || UNSAFE_RIPGREP_OPTIONS_WITH_ARGS + .iter() + .any(|&opt| arg == opt || arg.starts_with(&format!("{opt}="))) + }) + } + // Git Some("git") => matches!( command.get(1).map(String::as_str), @@ -72,90 +121,7 @@ fn is_safe_to_call_with_exec(command: &[String]) -> bool { } } -fn try_parse_bash(bash_lc_arg: &str) -> Option { - let lang = BASH.into(); - let mut parser = Parser::new(); - #[expect(clippy::expect_used)] - parser.set_language(&lang).expect("load bash grammar"); - - let old_tree: Option<&Tree> = None; - parser.parse(bash_lc_arg, old_tree) -} - -/// If `tree` represents a single Bash command whose name and every argument is -/// an ordinary `word`, return those words in order; otherwise, return `None`. -/// -/// `src` must be the exact source string that was parsed into `tree`, so we can -/// extract the text for every node. -pub fn try_parse_single_word_only_command(tree: &Tree, src: &str) -> Option> { - // Any parse error is an immediate rejection. - if tree.root_node().has_error() { - return None; - } - - // (program …) with exactly one statement - let root = tree.root_node(); - if root.kind() != "program" || root.named_child_count() != 1 { - return None; - } - - let cmd = root.named_child(0)?; // (command …) - if cmd.kind() != "command" { - return None; - } - - let mut words = Vec::new(); - let mut cursor = cmd.walk(); - - for child in cmd.named_children(&mut cursor) { - match child.kind() { - // The command name node wraps one `word` child. - "command_name" => { - let word_node = child.named_child(0)?; // make sure it's only a word - if word_node.kind() != "word" { - return None; - } - words.push(word_node.utf8_text(src.as_bytes()).ok()?.to_owned()); - } - // Positional‑argument word (allowed). - "word" | "number" => { - words.push(child.utf8_text(src.as_bytes()).ok()?.to_owned()); - } - "string" => { - if child.child_count() == 3 - && child.child(0)?.kind() == "\"" - && child.child(1)?.kind() == "string_content" - && child.child(2)?.kind() == "\"" - { - words.push(child.child(1)?.utf8_text(src.as_bytes()).ok()?.to_owned()); - } else { - // Anything else means the command is *not* plain words. - return None; - } - } - "concatenation" => { - // TODO: Consider things like `'ab\'a'`. - return None; - } - "raw_string" => { - // Raw string is a single word, but we need to strip the quotes. - let raw_string = child.utf8_text(src.as_bytes()).ok()?; - let stripped = raw_string - .strip_prefix('\'') - .and_then(|s| s.strip_suffix('\'')); - if let Some(stripped) = stripped { - words.push(stripped.to_owned()); - } else { - return None; - } - } - // Anything else means the command is *not* plain words. - _ => return None, - } - } - - Some(words) -} +// (bash parsing helpers implemented in crate::bash) /* ---------------------------------------------------------- Example @@ -193,6 +159,7 @@ fn is_valid_sed_n_arg(arg: Option<&str>) -> bool { _ => false, } } + #[cfg(test)] mod tests { #![allow(clippy::unwrap_used)] @@ -209,6 +176,11 @@ mod tests { assert!(is_safe_to_call_with_exec(&vec_str(&[ "sed", "-n", "1,5p", "file.txt" ]))); + assert!(is_safe_to_call_with_exec(&vec_str(&[ + "nl", + "-nrz", + "Cargo.toml" + ]))); // Safe `find` command (no unsafe options). assert!(is_safe_to_call_with_exec(&vec_str(&[ @@ -245,6 +217,40 @@ mod tests { } } + #[test] + fn ripgrep_rules() { + // Safe ripgrep invocations – none of the unsafe flags are present. + assert!(is_safe_to_call_with_exec(&vec_str(&[ + "rg", + "Cargo.toml", + "-n" + ]))); + + // Unsafe flags that do not take an argument (present verbatim). + for args in [ + vec_str(&["rg", "--search-zip", "files"]), + vec_str(&["rg", "-z", "files"]), + ] { + assert!( + !is_safe_to_call_with_exec(&args), + "expected {args:?} to be considered unsafe due to zip-search flag", + ); + } + + // Unsafe flags that expect a value, provided in both split and = forms. + for args in [ + vec_str(&["rg", "--pre", "pwned", "files"]), + vec_str(&["rg", "--pre=pwned", "files"]), + vec_str(&["rg", "--hostname-bin", "pwned", "files"]), + vec_str(&["rg", "--hostname-bin=pwned", "files"]), + ] { + assert!( + !is_safe_to_call_with_exec(&args), + "expected {args:?} to be considered unsafe due to external-command flag", + ); + } + } + #[test] fn bash_lc_safe_examples() { assert!(is_known_safe_command(&vec_str(&["bash", "-lc", "ls"]))); @@ -277,6 +283,30 @@ mod tests { ]))); } + #[test] + fn bash_lc_safe_examples_with_operators() { + assert!(is_known_safe_command(&vec_str(&[ + "bash", + "-lc", + "grep -R \"Cargo.toml\" -n || true" + ]))); + assert!(is_known_safe_command(&vec_str(&[ + "bash", + "-lc", + "ls && pwd" + ]))); + assert!(is_known_safe_command(&vec_str(&[ + "bash", + "-lc", + "echo 'hi' ; ls" + ]))); + assert!(is_known_safe_command(&vec_str(&[ + "bash", + "-lc", + "ls | wc -l" + ]))); + } + #[test] fn bash_lc_unsafe_examples() { assert!( @@ -290,44 +320,29 @@ mod tests { assert!( !is_known_safe_command(&vec_str(&["bash", "-lc", "find . -name file.txt -delete"])), - "Unsafe find option should not be auto‑approved." - ); - } - - #[test] - fn test_try_parse_single_word_only_command() { - let script_with_single_quoted_string = "sed -n '1,5p' file.txt"; - let parsed_words = try_parse_bash(script_with_single_quoted_string) - .and_then(|tree| { - try_parse_single_word_only_command(&tree, script_with_single_quoted_string) - }) - .unwrap(); - assert_eq!( - vec![ - "sed".to_string(), - "-n".to_string(), - // Ensure the single quotes are properly removed. - "1,5p".to_string(), - "file.txt".to_string() - ], - parsed_words, + "Unsafe find option should not be auto-approved." ); - let script_with_number_arg = "ls -1"; - let parsed_words = try_parse_bash(script_with_number_arg) - .and_then(|tree| try_parse_single_word_only_command(&tree, script_with_number_arg)) - .unwrap(); - assert_eq!(vec!["ls", "-1"], parsed_words,); + // Disallowed because of unsafe command in sequence. + assert!( + !is_known_safe_command(&vec_str(&["bash", "-lc", "ls && rm -rf /"])), + "Sequence containing unsafe command must be rejected" + ); - let script_with_double_quoted_string_with_no_funny_stuff_arg = "grep -R \"Cargo.toml\" -n"; - let parsed_words = try_parse_bash(script_with_double_quoted_string_with_no_funny_stuff_arg) - .and_then(|tree| { - try_parse_single_word_only_command( - &tree, - script_with_double_quoted_string_with_no_funny_stuff_arg, - ) - }) - .unwrap(); - assert_eq!(vec!["grep", "-R", "Cargo.toml", "-n"], parsed_words); + // Disallowed because of parentheses / subshell. + assert!( + !is_known_safe_command(&vec_str(&["bash", "-lc", "(ls)"])), + "Parentheses (subshell) are not provably safe with the current parser" + ); + assert!( + !is_known_safe_command(&vec_str(&["bash", "-lc", "ls || (pwd && echo hi)"])), + "Nested parentheses are not provably safe with the current parser" + ); + + // Disallowed redirection. + assert!( + !is_known_safe_command(&vec_str(&["bash", "-lc", "ls > out.txt"])), + "> redirection should be rejected" + ); } } diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 6812260c97..054abd742a 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -5,11 +5,14 @@ // the TUI or the tracing stack). #![deny(clippy::print_stdout, clippy::print_stderr)] +mod apply_patch; +mod bash; mod chat_completions; mod client; mod client_common; pub mod codex; pub use codex::Codex; +pub use codex::CodexSpawnOk; pub mod codex_wrapper; pub mod config; pub mod config_profile; @@ -19,6 +22,7 @@ pub mod error; pub mod exec; pub mod exec_env; mod flags; +pub mod git_info; mod is_safe_command; mod mcp_connection_manager; mod mcp_tool_call; @@ -26,15 +30,18 @@ mod message_history; mod model_provider_info; pub use model_provider_info::ModelProviderInfo; pub use model_provider_info::WireApi; +pub use model_provider_info::built_in_model_providers; mod models; -pub mod openai_api_key; mod openai_model_info; mod openai_tools; +pub mod plan_tool; mod project_doc; pub mod protocol; mod rollout; mod safety; +pub mod shell; mod user_notification; pub mod util; +pub use apply_patch::CODEX_APPLY_PATCH_ARG1; pub use client_common::model_supports_reasoning_summaries; diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 6ae1865f16..2e33c8754b 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -7,6 +7,8 @@ //! `""` as the key. use std::collections::HashMap; +use std::collections::HashSet; +use std::ffi::OsString; use std::time::Duration; use anyhow::Context; @@ -16,8 +18,13 @@ use codex_mcp_client::McpClient; use mcp_types::ClientCapabilities; use mcp_types::Implementation; use mcp_types::Tool; + +use serde_json::json; +use sha1::Digest; +use sha1::Sha1; use tokio::task::JoinSet; use tracing::info; +use tracing::warn; use crate::config_types::McpServerConfig; @@ -26,7 +33,8 @@ use crate::config_types::McpServerConfig; /// /// OpenAI requires tool names to conform to `^[a-zA-Z0-9_-]+$`, so we must /// choose a delimiter from this character set. -const MCP_TOOL_NAME_DELIMITER: &str = "__OAI_CODEX_MCP__"; +const MCP_TOOL_NAME_DELIMITER: &str = "__"; +const MAX_TOOL_NAME_LENGTH: usize = 64; /// Timeout for the `tools/list` request. const LIST_TOOLS_TIMEOUT: Duration = Duration::from_secs(10); @@ -35,16 +43,42 @@ const LIST_TOOLS_TIMEOUT: Duration = Duration::from_secs(10); /// spawned successfully. pub type ClientStartErrors = HashMap; -fn fully_qualified_tool_name(server: &str, tool: &str) -> String { - format!("{server}{MCP_TOOL_NAME_DELIMITER}{tool}") +fn qualify_tools(tools: Vec) -> HashMap { + let mut used_names = HashSet::new(); + let mut qualified_tools = HashMap::new(); + for tool in tools { + let mut qualified_name = format!( + "{}{}{}", + tool.server_name, MCP_TOOL_NAME_DELIMITER, tool.tool_name + ); + if qualified_name.len() > MAX_TOOL_NAME_LENGTH { + let mut hasher = Sha1::new(); + hasher.update(qualified_name.as_bytes()); + let sha1 = hasher.finalize(); + let sha1_str = format!("{sha1:x}"); + + // Truncate to make room for the hash suffix + let prefix_len = MAX_TOOL_NAME_LENGTH - sha1_str.len(); + + qualified_name = format!("{}{}", &qualified_name[..prefix_len], sha1_str); + } + + if used_names.contains(&qualified_name) { + warn!("skipping duplicated tool {}", qualified_name); + continue; + } + + used_names.insert(qualified_name.clone()); + qualified_tools.insert(qualified_name, tool); + } + + qualified_tools } -pub(crate) fn try_parse_fully_qualified_tool_name(fq_name: &str) -> Option<(String, String)> { - let (server, tool) = fq_name.split_once(MCP_TOOL_NAME_DELIMITER)?; - if server.is_empty() || tool.is_empty() { - return None; - } - Some((server.to_string(), tool.to_string())) +struct ToolInfo { + server_name: String, + tool_name: String, + tool: Tool, } /// A thin wrapper around a set of running [`McpClient`] instances. @@ -57,7 +91,7 @@ pub(crate) struct McpConnectionManager { clients: HashMap>, /// Fully qualified tool name -> tool instance. - tools: HashMap, + tools: HashMap, } impl McpConnectionManager { @@ -79,12 +113,27 @@ impl McpConnectionManager { // Launch all configured servers concurrently. let mut join_set = JoinSet::new(); + let mut errors = ClientStartErrors::new(); for (server_name, cfg) in mcp_servers { - // TODO: Verify server name: require `^[a-zA-Z0-9_-]+$`? + // Validate server name before spawning + if !is_valid_mcp_server_name(&server_name) { + let error = anyhow::anyhow!( + "invalid server name '{}': must match pattern ^[a-zA-Z0-9_-]+$", + server_name + ); + errors.insert(server_name, error); + continue; + } + join_set.spawn(async move { let McpServerConfig { command, args, env } = cfg; - let client_res = McpClient::new_stdio_client(command, args, env).await; + let client_res = McpClient::new_stdio_client( + command.into(), + args.into_iter().map(OsString::from).collect(), + env, + ) + .await; match client_res { Ok(client) => { // Initialize the client. @@ -93,10 +142,14 @@ impl McpConnectionManager { experimental: None, roots: None, sampling: None, + // https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities + // indicates this should be an empty object. + elicitation: Some(json!({})), }, client_info: Implementation { name: "codex-mcp-client".to_owned(), version: env!("CARGO_PKG_VERSION").to_owned(), + title: Some("Codex".into()), }, protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(), }; @@ -117,7 +170,6 @@ impl McpConnectionManager { let mut clients: HashMap> = HashMap::with_capacity(join_set.len()); - let mut errors = ClientStartErrors::new(); while let Some(res) = join_set.join_next().await { let (server_name, client_res) = res?; // JoinError propagation @@ -132,7 +184,9 @@ impl McpConnectionManager { } } - let tools = list_all_tools(&clients).await?; + let all_tools = list_all_tools(&clients).await?; + + let tools = qualify_tools(all_tools); Ok((Self { clients, tools }, errors)) } @@ -140,7 +194,10 @@ impl McpConnectionManager { /// Returns a single map that contains **all** tools. Each key is the /// fully-qualified name for the tool. pub fn list_all_tools(&self) -> HashMap { - self.tools.clone() + self.tools + .iter() + .map(|(name, tool)| (name.clone(), tool.tool.clone())) + .collect() } /// Invoke the tool indicated by the (server, tool) pair. @@ -162,13 +219,19 @@ impl McpConnectionManager { .await .with_context(|| format!("tool call failed for `{server}/{tool}`")) } + + pub fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> { + self.tools + .get(tool_name) + .map(|tool| (tool.server_name.clone(), tool.tool_name.clone())) + } } /// Query every server for its available tools and return a single map that /// contains **all** tools. Each key is the fully-qualified name for the tool. -pub async fn list_all_tools( +async fn list_all_tools( clients: &HashMap>, -) -> Result> { +) -> Result> { let mut join_set = JoinSet::new(); // Spawn one task per server so we can query them concurrently. This @@ -185,18 +248,19 @@ pub async fn list_all_tools( }); } - let mut aggregated: HashMap = HashMap::with_capacity(join_set.len()); + let mut aggregated: Vec = Vec::with_capacity(join_set.len()); while let Some(join_res) = join_set.join_next().await { let (server_name, list_result) = join_res?; let list_result = list_result?; for tool in list_result.tools { - // TODO(mbolin): escape tool names that contain invalid characters. - let fq_name = fully_qualified_tool_name(&server_name, &tool.name); - if aggregated.insert(fq_name.clone(), tool).is_some() { - panic!("tool name collision for '{fq_name}': suspicious"); - } + let tool_info = ToolInfo { + server_name: server_name.clone(), + tool_name: tool.name.clone(), + tool, + }; + aggregated.push(tool_info); } } @@ -208,3 +272,99 @@ pub async fn list_all_tools( Ok(aggregated) } + +fn is_valid_mcp_server_name(server_name: &str) -> bool { + !server_name.is_empty() + && server_name + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-') +} + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + use mcp_types::ToolInputSchema; + + fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo { + ToolInfo { + server_name: server_name.to_string(), + tool_name: tool_name.to_string(), + tool: Tool { + annotations: None, + description: Some(format!("Test tool: {tool_name}")), + input_schema: ToolInputSchema { + properties: None, + required: None, + r#type: "object".to_string(), + }, + name: tool_name.to_string(), + output_schema: None, + title: None, + }, + } + } + + #[test] + fn test_qualify_tools_short_non_duplicated_names() { + let tools = vec![ + create_test_tool("server1", "tool1"), + create_test_tool("server1", "tool2"), + ]; + + let qualified_tools = qualify_tools(tools); + + assert_eq!(qualified_tools.len(), 2); + assert!(qualified_tools.contains_key("server1__tool1")); + assert!(qualified_tools.contains_key("server1__tool2")); + } + + #[test] + fn test_qualify_tools_duplicated_names_skipped() { + let tools = vec![ + create_test_tool("server1", "duplicate_tool"), + create_test_tool("server1", "duplicate_tool"), + ]; + + let qualified_tools = qualify_tools(tools); + + // Only the first tool should remain, the second is skipped + assert_eq!(qualified_tools.len(), 1); + assert!(qualified_tools.contains_key("server1__duplicate_tool")); + } + + #[test] + fn test_qualify_tools_long_names_same_server() { + let server_name = "my_server"; + + let tools = vec![ + create_test_tool( + server_name, + "extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits", + ), + create_test_tool( + server_name, + "yet_another_extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits", + ), + ]; + + let qualified_tools = qualify_tools(tools); + + assert_eq!(qualified_tools.len(), 2); + + let mut keys: Vec<_> = qualified_tools.keys().cloned().collect(); + keys.sort(); + + assert_eq!(keys[0].len(), 64); + assert_eq!( + keys[0], + "my_server__extremely_lena02e507efc5a9de88637e436690364fd4219e4ef" + ); + + assert_eq!(keys[1].len(), 64); + assert_eq!( + keys[1], + "my_server__yet_another_e1c3987bd9c50b826cbe1687966f79f0c602d19ca" + ); + } +} diff --git a/codex-rs/core/src/mcp_tool_call.rs b/codex-rs/core/src/mcp_tool_call.rs index 61a51a0e7a..e92d7e8481 100644 --- a/codex-rs/core/src/mcp_tool_call.rs +++ b/codex-rs/core/src/mcp_tool_call.rs @@ -1,4 +1,5 @@ use std::time::Duration; +use std::time::Instant; use tracing::error; @@ -7,6 +8,7 @@ use crate::models::FunctionCallOutputPayload; use crate::models::ResponseInputItem; use crate::protocol::Event; use crate::protocol::EventMsg; +use crate::protocol::McpInvocation; use crate::protocol::McpToolCallBeginEvent; use crate::protocol::McpToolCallEndEvent; @@ -41,21 +43,28 @@ pub(crate) async fn handle_mcp_tool_call( } }; - let tool_call_begin_event = EventMsg::McpToolCallBegin(McpToolCallBeginEvent { - call_id: call_id.clone(), + let invocation = McpInvocation { server: server.clone(), tool: tool_name.clone(), arguments: arguments_value.clone(), + }; + + let tool_call_begin_event = EventMsg::McpToolCallBegin(McpToolCallBeginEvent { + call_id: call_id.clone(), + invocation: invocation.clone(), }); notify_mcp_tool_call_event(sess, sub_id, tool_call_begin_event).await; + let start = Instant::now(); // Perform the tool call. let result = sess - .call_tool(&server, &tool_name, arguments_value, timeout) + .call_tool(&server, &tool_name, arguments_value.clone(), timeout) .await .map_err(|e| format!("tool call error: {e}")); let tool_call_end_event = EventMsg::McpToolCallEnd(McpToolCallEndEvent { call_id: call_id.clone(), + invocation, + duration: start.elapsed(), result: result.clone(), }); diff --git a/codex-rs/core/src/model_provider_info.rs b/codex-rs/core/src/model_provider_info.rs index b38c912d34..4640f53ad7 100644 --- a/codex-rs/core/src/model_provider_info.rs +++ b/codex-rs/core/src/model_provider_info.rs @@ -9,13 +9,16 @@ use serde::Deserialize; use serde::Serialize; use std::collections::HashMap; use std::env::VarError; +use std::time::Duration; use crate::error::EnvVarError; -use crate::openai_api_key::get_openai_api_key; /// Value for the `OpenAI-Originator` header that is sent with requests to /// OpenAI. const OPENAI_ORIGINATOR_HEADER: &str = "codex_cli_rs"; +const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000; +const DEFAULT_STREAM_MAX_RETRIES: u64 = 10; +const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4; /// Wire protocol that the provider speaks. Most third-party services only /// implement the classic OpenAI Chat Completions JSON schema, whereas OpenAI @@ -26,7 +29,7 @@ const OPENAI_ORIGINATOR_HEADER: &str = "codex_cli_rs"; #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum WireApi { - /// The experimental “Responses” API exposed by OpenAI at `/v1/responses`. + /// The Responses API exposed by OpenAI at `/v1/responses`. Responses, /// Regular Chat Completions compatible with `/v1/chat/completions`. @@ -40,7 +43,7 @@ pub struct ModelProviderInfo { /// Friendly display name. pub name: String, /// Base URL for the provider's OpenAI-compatible API. - pub base_url: String, + pub base_url: Option, /// Environment variable that stores the user's API key for this provider. pub env_key: Option, @@ -64,6 +67,20 @@ pub struct ModelProviderInfo { /// value should be used. If the environment variable is not set, or the /// value is empty, the header will not be included in the request. pub env_http_headers: Option>, + + /// Maximum number of times to retry a failed HTTP request to this provider. + pub request_max_retries: Option, + + /// Number of times to retry reconnecting a dropped streaming response before failing. + pub stream_max_retries: Option, + + /// Idle timeout (in milliseconds) to wait for activity on a streaming response before treating + /// the connection as lost. + pub stream_idle_timeout_ms: Option, + + /// Whether this provider requires some form of standard authentication (API key, ChatGPT token). + #[serde(default)] + pub requires_auth: bool, } impl ModelProviderInfo { @@ -79,11 +96,11 @@ impl ModelProviderInfo { &'a self, client: &'a reqwest::Client, ) -> crate::error::Result { - let api_key = self.api_key()?; - let url = self.get_full_url(); let mut builder = client.post(url); + + let api_key = self.api_key()?; if let Some(key) = api_key { builder = builder.bearer_auth(key); } @@ -103,9 +120,15 @@ impl ModelProviderInfo { .join("&"); format!("?{full_params}") }); - let base_url = &self.base_url; + let base_url = self + .base_url + .clone() + .unwrap_or("https://api.openai.com/v1".to_string()); + match self.wire_api { - WireApi::Responses => format!("{base_url}/responses{query_string}"), + WireApi::Responses => { + format!("{base_url}/responses{query_string}") + } WireApi::Chat => format!("{base_url}/chat/completions{query_string}"), } } @@ -113,7 +136,10 @@ impl ModelProviderInfo { /// Apply provider-specific HTTP headers (both static and environment-based) /// onto an existing `reqwest::RequestBuilder` and return the updated /// builder. - fn apply_http_headers(&self, mut builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder { + pub fn apply_http_headers( + &self, + mut builder: reqwest::RequestBuilder, + ) -> reqwest::RequestBuilder { if let Some(extra) = &self.http_headers { for (k, v) in extra { builder = builder.header(k, v); @@ -138,11 +164,7 @@ impl ModelProviderInfo { fn api_key(&self) -> crate::error::Result> { match &self.env_key { Some(env_key) => { - let env_value = if env_key == crate::openai_api_key::OPENAI_API_KEY_ENV_VAR { - get_openai_api_key().map_or_else(|| Err(VarError::NotPresent), Ok) - } else { - std::env::var(env_key) - }; + let env_value = std::env::var(env_key); env_value .and_then(|v| { if v.trim().is_empty() { @@ -161,6 +183,25 @@ impl ModelProviderInfo { None => Ok(None), } } + + /// Effective maximum number of request retries for this provider. + pub fn request_max_retries(&self) -> u64 { + self.request_max_retries + .unwrap_or(DEFAULT_REQUEST_MAX_RETRIES) + } + + /// Effective maximum number of stream reconnection attempts for this provider. + pub fn stream_max_retries(&self) -> u64 { + self.stream_max_retries + .unwrap_or(DEFAULT_STREAM_MAX_RETRIES) + } + + /// Effective idle timeout for streaming responses. + pub fn stream_idle_timeout(&self) -> Duration { + self.stream_idle_timeout_ms + .map(Duration::from_millis) + .unwrap_or(Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS)) + } } /// Built-in default provider list. @@ -171,43 +212,51 @@ pub fn built_in_model_providers() -> HashMap { // providers are bundled with Codex CLI, so we only include the OpenAI // provider by default. Users are encouraged to add to `model_providers` // in config.toml to add their own providers. - [ - ( - "openai", - P { - name: "OpenAI".into(), - // Allow users to override the default OpenAI endpoint by - // exporting `OPENAI_BASE_URL`. This is useful when pointing - // Codex at a proxy, mock server, or Azure-style deployment - // without requiring a full TOML override for the built-in - // OpenAI provider. - base_url: std::env::var("OPENAI_BASE_URL") - .ok() - .filter(|v| !v.trim().is_empty()) - .unwrap_or_else(|| "https://api.openai.com/v1".to_string()), - env_key: Some("OPENAI_API_KEY".into()), - env_key_instructions: Some("Create an API key (https://platform.openai.com) and export it as an environment variable.".into()), - wire_api: WireApi::Responses, - query_params: None, - http_headers: Some( - [ - ("originator".to_string(), OPENAI_ORIGINATOR_HEADER.to_string()), - ("version".to_string(), env!("CARGO_PKG_VERSION").to_string()), - ] - .into_iter() - .collect(), - ), - env_http_headers: Some( - [ - ("OpenAI-Organization".to_string(), "OPENAI_ORGANIZATION".to_string()), - ("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()), - ] - .into_iter() - .collect(), - ), - }, - ), - ] + [( + "openai", + P { + name: "OpenAI".into(), + // Allow users to override the default OpenAI endpoint by + // exporting `OPENAI_BASE_URL`. This is useful when pointing + // Codex at a proxy, mock server, or Azure-style deployment + // without requiring a full TOML override for the built-in + // OpenAI provider. + base_url: std::env::var("OPENAI_BASE_URL") + .ok() + .filter(|v| !v.trim().is_empty()), + env_key: None, + env_key_instructions: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: Some( + [ + ( + "originator".to_string(), + OPENAI_ORIGINATOR_HEADER.to_string(), + ), + ("version".to_string(), env!("CARGO_PKG_VERSION").to_string()), + ] + .into_iter() + .collect(), + ), + env_http_headers: Some( + [ + ( + "OpenAI-Organization".to_string(), + "OPENAI_ORGANIZATION".to_string(), + ), + ("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()), + ] + .into_iter() + .collect(), + ), + // Use global defaults for retry/timeout unless overridden in config.toml. + request_max_retries: None, + stream_max_retries: None, + stream_idle_timeout_ms: None, + requires_auth: true, + }, + )] .into_iter() .map(|(k, v)| (k.to_string(), v)) .collect() @@ -227,13 +276,17 @@ base_url = "http://localhost:11434/v1" "#; let expected_provider = ModelProviderInfo { name: "Ollama".into(), - base_url: "http://localhost:11434/v1".into(), + base_url: Some("http://localhost:11434/v1".into()), env_key: None, env_key_instructions: None, wire_api: WireApi::Chat, query_params: None, http_headers: None, env_http_headers: None, + request_max_retries: None, + stream_max_retries: None, + stream_idle_timeout_ms: None, + requires_auth: false, }; let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap(); @@ -250,7 +303,7 @@ query_params = { api-version = "2025-04-01-preview" } "#; let expected_provider = ModelProviderInfo { name: "Azure".into(), - base_url: "https://xxxxx.openai.azure.com/openai".into(), + base_url: Some("https://xxxxx.openai.azure.com/openai".into()), env_key: Some("AZURE_OPENAI_API_KEY".into()), env_key_instructions: None, wire_api: WireApi::Chat, @@ -259,6 +312,10 @@ query_params = { api-version = "2025-04-01-preview" } }), http_headers: None, env_http_headers: None, + request_max_retries: None, + stream_max_retries: None, + stream_idle_timeout_ms: None, + requires_auth: false, }; let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap(); @@ -276,7 +333,7 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" } "#; let expected_provider = ModelProviderInfo { name: "Example".into(), - base_url: "https://example.com".into(), + base_url: Some("https://example.com".into()), env_key: Some("API_KEY".into()), env_key_instructions: None, wire_api: WireApi::Chat, @@ -287,6 +344,10 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" } env_http_headers: Some(maplit::hashmap! { "X-Example-Env-Header".to_string() => "EXAMPLE_ENV_VAR".to_string(), }), + request_max_retries: None, + stream_max_retries: None, + stream_idle_timeout_ms: None, + requires_auth: false, }; let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap(); diff --git a/codex-rs/core/src/models.rs b/codex-rs/core/src/models.rs index 864ab1d799..2a5e8c6f22 100644 --- a/codex-rs/core/src/models.rs +++ b/codex-rs/core/src/models.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use base64::Engine; use mcp_types::CallToolResult; use serde::Deserialize; +use serde::Deserializer; use serde::Serialize; use serde::ser::Serializer; @@ -37,6 +38,7 @@ pub enum ContentItem { #[serde(tag = "type", rename_all = "snake_case")] pub enum ResponseItem { Message { + id: Option, role: String, content: Vec, }, @@ -45,6 +47,7 @@ pub enum ResponseItem { summary: Vec, #[serde(default)] content: Vec, + encrypted_content: Option, }, LocalShellCall { /// Set when using the chat completions API. @@ -55,6 +58,7 @@ pub enum ResponseItem { action: LocalShellAction, }, FunctionCall { + id: Option, name: String, // The Responses API returns the function call arguments as a *string* that contains // JSON, not as an already‑parsed object. We keep it as a raw string here and let @@ -80,7 +84,11 @@ pub enum ResponseItem { impl From for ResponseItem { fn from(item: ResponseInputItem) -> Self { match item { - ResponseInputItem::Message { role, content } => Self::Message { role, content }, + ResponseInputItem::Message { role, content } => Self::Message { + role, + content, + id: None, + }, ResponseInputItem::FunctionCallOutput { call_id, output } => { Self::FunctionCallOutput { call_id, output } } @@ -185,7 +193,7 @@ pub struct ShellToolCallParams { pub timeout_ms: Option, } -#[derive(Deserialize, Debug, Clone)] +#[derive(Debug, Clone)] pub struct FunctionCallOutputPayload { pub content: String, #[expect(dead_code)] @@ -213,6 +221,19 @@ impl Serialize for FunctionCallOutputPayload { } } +impl<'de> Deserialize<'de> for FunctionCallOutputPayload { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + Ok(FunctionCallOutputPayload { + content: s, + success: None, + }) + } +} + // Implement Display so callers can treat the payload like a plain string when logging or doing // trivial substring checks in tests (existing tests call `.contains()` on the output). Display // returns the raw `content` field. diff --git a/codex-rs/core/src/openai_api_key.rs b/codex-rs/core/src/openai_api_key.rs deleted file mode 100644 index 728914c0f2..0000000000 --- a/codex-rs/core/src/openai_api_key.rs +++ /dev/null @@ -1,24 +0,0 @@ -use std::env; -use std::sync::LazyLock; -use std::sync::RwLock; - -pub const OPENAI_API_KEY_ENV_VAR: &str = "OPENAI_API_KEY"; - -static OPENAI_API_KEY: LazyLock>> = LazyLock::new(|| { - let val = env::var(OPENAI_API_KEY_ENV_VAR) - .ok() - .and_then(|s| if s.is_empty() { None } else { Some(s) }); - RwLock::new(val) -}); - -pub fn get_openai_api_key() -> Option { - #![allow(clippy::unwrap_used)] - OPENAI_API_KEY.read().unwrap().clone() -} - -pub fn set_openai_api_key(value: String) { - #![allow(clippy::unwrap_used)] - if !value.is_empty() { - *OPENAI_API_KEY.write().unwrap() = Some(value); - } -} diff --git a/codex-rs/core/src/openai_tools.rs b/codex-rs/core/src/openai_tools.rs index ef12a629b6..0f1e7d9ca7 100644 --- a/codex-rs/core/src/openai_tools.rs +++ b/codex-rs/core/src/openai_tools.rs @@ -4,13 +4,14 @@ use std::collections::BTreeMap; use std::sync::LazyLock; use crate::client_common::Prompt; +use crate::plan_tool::PLAN_TOOL; #[derive(Debug, Clone, Serialize)] pub(crate) struct ResponsesApiTool { - name: &'static str, - description: &'static str, - strict: bool, - parameters: JsonSchema, + pub(crate) name: &'static str, + pub(crate) description: &'static str, + pub(crate) strict: bool, + pub(crate) parameters: JsonSchema, } /// When serialized as JSON, this produces a valid "Tool" in the OpenAI @@ -74,6 +75,7 @@ static DEFAULT_CODEX_MODEL_TOOLS: LazyLock> = pub(crate) fn create_tools_json_for_responses_api( prompt: &Prompt, model: &str, + include_plan_tool: bool, ) -> crate::error::Result> { // Assemble tool list: built-in tools + any extra tools from the prompt. let default_tools = if model.starts_with("codex") { @@ -93,6 +95,10 @@ pub(crate) fn create_tools_json_for_responses_api( .map(|(name, tool)| mcp_tool_to_openai_tool(name, tool)), ); + if include_plan_tool { + tools_json.push(serde_json::to_value(PLAN_TOOL.clone())?); + } + Ok(tools_json) } @@ -102,10 +108,12 @@ pub(crate) fn create_tools_json_for_responses_api( pub(crate) fn create_tools_json_for_chat_completions_api( prompt: &Prompt, model: &str, + include_plan_tool: bool, ) -> crate::error::Result> { // We start with the JSON for the Responses API and than rewrite it to match // the chat completions tool call format. - let responses_api_tools_json = create_tools_json_for_responses_api(prompt, model)?; + let responses_api_tools_json = + create_tools_json_for_responses_api(prompt, model, include_plan_tool)?; let tools_json = responses_api_tools_json .into_iter() .filter_map(|mut tool| { diff --git a/codex-rs/core/src/plan_tool.rs b/codex-rs/core/src/plan_tool.rs new file mode 100644 index 0000000000..dbddb8b5eb --- /dev/null +++ b/codex-rs/core/src/plan_tool.rs @@ -0,0 +1,126 @@ +use std::collections::BTreeMap; +use std::sync::LazyLock; + +use serde::Deserialize; +use serde::Serialize; + +use crate::codex::Session; +use crate::models::FunctionCallOutputPayload; +use crate::models::ResponseInputItem; +use crate::openai_tools::JsonSchema; +use crate::openai_tools::OpenAiTool; +use crate::openai_tools::ResponsesApiTool; +use crate::protocol::Event; +use crate::protocol::EventMsg; + +// Types for the TODO tool arguments matching codex-vscode/todo-mcp/src/main.rs +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum StepStatus { + Pending, + InProgress, + Completed, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct PlanItemArg { + pub step: String, + pub status: StepStatus, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct UpdatePlanArgs { + #[serde(default)] + pub explanation: Option, + pub plan: Vec, +} + +pub(crate) static PLAN_TOOL: LazyLock = LazyLock::new(|| { + let mut plan_item_props = BTreeMap::new(); + plan_item_props.insert("step".to_string(), JsonSchema::String); + plan_item_props.insert("status".to_string(), JsonSchema::String); + + let plan_items_schema = JsonSchema::Array { + items: Box::new(JsonSchema::Object { + properties: plan_item_props, + required: &["step", "status"], + additional_properties: false, + }), + }; + + let mut properties = BTreeMap::new(); + properties.insert("explanation".to_string(), JsonSchema::String); + properties.insert("plan".to_string(), plan_items_schema); + + OpenAiTool::Function(ResponsesApiTool { + name: "update_plan", + description: r#"Use the update_plan tool to keep the user updated on the current plan for the task. +After understanding the user's task, call the update_plan tool with an initial plan. An example of a plan: +1. Explore the codebase to find relevant files (status: in_progress) +2. Implement the feature in the XYZ component (status: pending) +3. Commit changes and make a pull request (status: pending) +Each step should be a short, 1-sentence description. +Until all the steps are finished, there should always be exactly one in_progress step in the plan. +Call the update_plan tool whenever you finish a step, marking the completed step as `completed` and marking the next step as `in_progress`. +Before running a command, consider whether or not you have completed the previous step, and make sure to mark it as completed before moving on to the next step. +Sometimes, you may need to change plans in the middle of a task: call `update_plan` with the updated plan and make sure to provide an `explanation` of the rationale when doing so. +When all steps are completed, call update_plan one last time with all steps marked as `completed`."#, + strict: false, + parameters: JsonSchema::Object { + properties, + required: &["plan"], + additional_properties: false, + }, + }) +}); + +/// This function doesn't do anything useful. However, it gives the model a structured way to record its plan that clients can read and render. +/// So it's the _inputs_ to this function that are useful to clients, not the outputs and neither are actually useful for the model other +/// than forcing it to come up and document a plan (TBD how that affects performance). +pub(crate) async fn handle_update_plan( + session: &Session, + arguments: String, + sub_id: String, + call_id: String, +) -> ResponseInputItem { + match parse_update_plan_arguments(arguments, &call_id) { + Ok(args) => { + let output = ResponseInputItem::FunctionCallOutput { + call_id, + output: FunctionCallOutputPayload { + content: "Plan updated".to_string(), + success: Some(true), + }, + }; + session + .send_event(Event { + id: sub_id.to_string(), + msg: EventMsg::PlanUpdate(args), + }) + .await; + output + } + Err(output) => *output, + } +} + +fn parse_update_plan_arguments( + arguments: String, + call_id: &str, +) -> Result> { + match serde_json::from_str::(&arguments) { + Ok(args) => Ok(args), + Err(e) => { + let output = ResponseInputItem::FunctionCallOutput { + call_id: call_id.to_string(), + output: FunctionCallOutputPayload { + content: format!("failed to parse function arguments: {e}"), + success: None, + }, + }; + Err(Box::new(output)) + } + } +} diff --git a/codex-rs/core/src/project_doc.rs b/codex-rs/core/src/project_doc.rs index ab9d46186f..9f46159d1d 100644 --- a/codex-rs/core/src/project_doc.rs +++ b/codex-rs/core/src/project_doc.rs @@ -27,16 +27,16 @@ const PROJECT_DOC_SEPARATOR: &str = "\n\n--- project-doc ---\n\n"; /// string of instructions. pub(crate) async fn get_user_instructions(config: &Config) -> Option { match find_project_doc(config).await { - Ok(Some(project_doc)) => match &config.instructions { + Ok(Some(project_doc)) => match &config.user_instructions { Some(original_instructions) => Some(format!( "{original_instructions}{PROJECT_DOC_SEPARATOR}{project_doc}" )), None => Some(project_doc), }, - Ok(None) => config.instructions.clone(), + Ok(None) => config.user_instructions.clone(), Err(e) => { error!("error trying to find project doc: {e:#}"); - config.instructions.clone() + config.user_instructions.clone() } } } @@ -159,7 +159,7 @@ mod tests { config.cwd = root.path().to_path_buf(); config.project_doc_max_bytes = limit; - config.instructions = instructions.map(ToOwned::to_owned); + config.user_instructions = instructions.map(ToOwned::to_owned); config } diff --git a/codex-rs/core/src/protocol.rs b/codex-rs/core/src/protocol.rs index dd64b4c8a8..2111aff74f 100644 --- a/codex-rs/core/src/protocol.rs +++ b/codex-rs/core/src/protocol.rs @@ -4,19 +4,23 @@ //! between user and agent. use std::collections::HashMap; +use std::fmt; use std::path::Path; use std::path::PathBuf; use std::str::FromStr; +use std::time::Duration; use mcp_types::CallToolResult; use serde::Deserialize; use serde::Serialize; +use strum_macros::Display; use uuid::Uuid; use crate::config_types::ReasoningEffort as ReasoningEffortConfig; use crate::config_types::ReasoningSummary as ReasoningSummaryConfig; use crate::message_history::HistoryEntry; use crate::model_provider_info::ModelProviderInfo; +use crate::plan_tool::UpdatePlanArgs; /// Submission Queue Entry - requests from user #[derive(Debug, Clone, Deserialize, Serialize)] @@ -44,8 +48,12 @@ pub enum Op { model_reasoning_effort: ReasoningEffortConfig, model_reasoning_summary: ReasoningSummaryConfig, - /// Model instructions - instructions: Option, + /// Model instructions that are appended to the base instructions. + user_instructions: Option, + + /// Base instructions override. + base_instructions: Option, + /// When to escalate for approval for execution approval_policy: AskForApproval, /// How to sandbox commands executed in the system @@ -69,6 +77,10 @@ pub enum Op { /// `ConfigureSession` operation so that the business-logic layer can /// operate deterministically. cwd: std::path::PathBuf, + + /// Path to a rollout file to resume from. + #[serde(skip_serializing_if = "Option::is_none")] + resume_path: Option, }, /// Abort current task. @@ -108,18 +120,23 @@ pub enum Op { /// Request a single history entry identified by `log_id` + `offset`. GetHistoryEntryRequest { offset: usize, log_id: u64 }, + + /// Request to shut down codex instance. + Shutdown, } /// Determines the conditions under which the user is consulted to approve /// running the command proposed by Codex. -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize, Display)] #[serde(rename_all = "kebab-case")] +#[strum(serialize_all = "kebab-case")] pub enum AskForApproval { /// Under this policy, only "known safe" commands—as determined by /// `is_safe_command()`—that **only read files** are auto‑approved. /// Everything else will ask the user to approve. #[default] #[serde(rename = "untrusted")] + #[strum(serialize = "untrusted")] UnlessTrusted, /// *All* commands are auto‑approved, but they are expected to run inside a @@ -263,8 +280,9 @@ pub struct Event { } /// Response event from the agent -#[derive(Debug, Clone, Deserialize, Serialize)] +#[derive(Debug, Clone, Deserialize, Serialize, Display)] #[serde(tag = "type", rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] pub enum EventMsg { /// Error while executing a submission Error(ErrorEvent), @@ -282,9 +300,15 @@ pub enum EventMsg { /// Agent text output message AgentMessage(AgentMessageEvent), - /// Reasoning summary from agent. + /// Agent text output delta message + AgentMessageDelta(AgentMessageDeltaEvent), + + /// Reasoning event from agent. AgentReasoning(AgentReasoningEvent), + /// Agent reasoning delta event from agent. + AgentReasoningDelta(AgentReasoningDeltaEvent), + /// Raw chain-of-thought from agent. AgentReasoningContent(AgentReasoningContentEvent), @@ -315,6 +339,11 @@ pub enum EventMsg { /// Response to GetHistoryEntryRequest. GetHistoryEntryResponse(GetHistoryEntryResponseEvent), + + PlanUpdate(UpdatePlanArgs), + + /// Notification that the agent is shutting down. + ShutdownComplete, } // Individual event payload types matching each `EventMsg` variant. @@ -338,11 +367,46 @@ pub struct TokenUsage { pub total_tokens: u64, } +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct FinalOutput { + pub token_usage: TokenUsage, +} + +impl From for FinalOutput { + fn from(token_usage: TokenUsage) -> Self { + Self { token_usage } + } +} + +impl fmt::Display for FinalOutput { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let u = &self.token_usage; + write!( + f, + "Token usage: total={} input={}{} output={}{}", + u.total_tokens, + u.input_tokens, + u.cached_input_tokens + .map(|c| format!(" (cached {c})")) + .unwrap_or_default(), + u.output_tokens, + u.reasoning_output_tokens + .map(|r| format!(" (reasoning {r})")) + .unwrap_or_default() + ) + } +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct AgentMessageEvent { pub message: String, } +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct AgentMessageDeltaEvent { + pub delta: String, +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct AgentReasoningEvent { pub text: String, @@ -353,10 +417,12 @@ pub struct AgentReasoningContentEvent { pub text: String, } +pub struct AgentReasoningDeltaEvent { + pub delta: String, +} + #[derive(Debug, Clone, Deserialize, Serialize)] -pub struct McpToolCallBeginEvent { - /// Identifier so this can be paired with the McpToolCallEnd event. - pub call_id: String, +pub struct McpInvocation { /// Name of the MCP server as defined in the config. pub server: String, /// Name of the tool as given by the MCP server. @@ -365,10 +431,19 @@ pub struct McpToolCallBeginEvent { pub arguments: Option, } +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct McpToolCallBeginEvent { + /// Identifier so this can be paired with the McpToolCallEnd event. + pub call_id: String, + pub invocation: McpInvocation, +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct McpToolCallEndEvent { /// Identifier for the corresponding McpToolCallBegin that finished. pub call_id: String, + pub invocation: McpInvocation, + pub duration: Duration, /// Result of the tool call. Note this could be an error. pub result: Result, } @@ -406,6 +481,8 @@ pub struct ExecCommandEndEvent { #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ExecApprovalRequestEvent { + /// Identifier for the associated exec call, if available. + pub call_id: String, /// The command to be executed. pub command: Vec, /// The command's working directory. @@ -417,6 +494,8 @@ pub struct ExecApprovalRequestEvent { #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ApplyPatchApprovalRequestEvent { + /// Responses API call id for the associated patch apply call, if available. + pub call_id: String, pub changes: HashMap, /// Optional explanatory reason (e.g. request for extra write access). #[serde(skip_serializing_if = "Option::is_none")] diff --git a/codex-rs/core/src/rollout.rs b/codex-rs/core/src/rollout.rs index c18a58df06..0ccd8e891b 100644 --- a/codex-rs/core/src/rollout.rs +++ b/codex-rs/core/src/rollout.rs @@ -1,33 +1,57 @@ -//! Functionality to persist a Codex conversation *rollout* – a linear list of -//! [`ResponseItem`] objects exchanged during a session – to disk so that -//! sessions can be replayed or inspected later (mirrors the behaviour of the -//! upstream TypeScript implementation). +//! Persist Codex session rollouts (.jsonl) so sessions can be replayed or inspected later. use std::fs::File; use std::fs::{self}; use std::io::Error as IoError; +use std::path::Path; +use serde::Deserialize; use serde::Serialize; +use serde_json::Value; use time::OffsetDateTime; use time::format_description::FormatItem; use time::macros::format_description; use tokio::io::AsyncWriteExt; use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::{self}; +use tokio::sync::oneshot; +use tracing::info; +use tracing::warn; use uuid::Uuid; use crate::config::Config; +use crate::git_info::GitInfo; +use crate::git_info::collect_git_info; use crate::models::ResponseItem; -/// Folder inside `~/.codex` that holds saved rollouts. const SESSIONS_SUBDIR: &str = "sessions"; +#[derive(Serialize, Deserialize, Clone, Default)] +pub struct SessionMeta { + pub id: Uuid, + pub timestamp: String, + pub instructions: Option, +} + #[derive(Serialize)] -struct SessionMeta { - id: String, - timestamp: String, +struct SessionMetaWithGit { + #[serde(flatten)] + meta: SessionMeta, #[serde(skip_serializing_if = "Option::is_none")] - instructions: Option, + git: Option, +} + +#[derive(Serialize, Deserialize, Default, Clone)] +pub struct SessionStateSnapshot {} + +#[derive(Serialize, Deserialize, Default, Clone)] +pub struct SavedSession { + pub session: SessionMeta, + #[serde(default)] + pub items: Vec, + #[serde(default)] + pub state: SessionStateSnapshot, + pub session_id: Uuid, } /// Records all [`ResponseItem`]s for a session and flushes them to disk after @@ -41,7 +65,13 @@ struct SessionMeta { /// ``` #[derive(Clone)] pub(crate) struct RolloutRecorder { - tx: Sender, + tx: Sender, +} + +enum RolloutCmd { + AddItems(Vec), + UpdateState(SessionStateSnapshot), + Shutdown { ack: oneshot::Sender<()> }, } impl RolloutRecorder { @@ -59,7 +89,6 @@ impl RolloutRecorder { timestamp, } = create_log_file(config, uuid)?; - // Build the static session metadata JSON first. let timestamp_format: &[FormatItem] = format_description!( "[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z" ); @@ -67,48 +96,33 @@ impl RolloutRecorder { .format(timestamp_format) .map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?; - let meta = SessionMeta { - timestamp, - id: session_id.to_string(), - instructions, - }; + // Clone the cwd for the spawned task to collect git info asynchronously + let cwd = config.cwd.clone(); // A reasonably-sized bounded channel. If the buffer fills up the send // future will yield, which is fine – we only need to ensure we do not - // perform *blocking* I/O on the caller’s thread. - let (tx, mut rx) = mpsc::channel::(256); + // perform *blocking* I/O on the caller's thread. + let (tx, rx) = mpsc::channel::(256); // Spawn a Tokio task that owns the file handle and performs async // writes. Using `tokio::fs::File` keeps everything on the async I/O // driver instead of blocking the runtime. - tokio::task::spawn(async move { - let mut file = tokio::fs::File::from_std(file); + tokio::task::spawn(rollout_writer( + tokio::fs::File::from_std(file), + rx, + Some(SessionMeta { + timestamp, + id: session_id, + instructions, + }), + cwd, + )); - while let Some(line) = rx.recv().await { - // Write line + newline, then flush to disk. - if let Err(e) = file.write_all(line.as_bytes()).await { - tracing::warn!("rollout writer: failed to write line: {e}"); - break; - } - if let Err(e) = file.write_all(b"\n").await { - tracing::warn!("rollout writer: failed to write newline: {e}"); - break; - } - if let Err(e) = file.flush().await { - tracing::warn!("rollout writer: failed to flush: {e}"); - break; - } - } - }); - - let recorder = Self { tx }; - // Ensure SessionMeta is the first item in the file. - recorder.record_item(&meta).await?; - Ok(recorder) + Ok(Self { tx }) } - /// Append `items` to the rollout file. pub(crate) async fn record_items(&self, items: &[ResponseItem]) -> std::io::Result<()> { + let mut filtered = Vec::new(); for item in items { match item { // Note that function calls may look a bit strange if they are @@ -117,27 +131,114 @@ impl RolloutRecorder { ResponseItem::Message { .. } | ResponseItem::LocalShellCall { .. } | ResponseItem::FunctionCall { .. } - | ResponseItem::FunctionCallOutput { .. } => {} - ResponseItem::Reasoning { .. } | ResponseItem::Other => { + | ResponseItem::FunctionCallOutput { .. } + | ResponseItem::Reasoning { .. } => filtered.push(item.clone()), + ResponseItem::Other => { // These should never be serialized. continue; } } - self.record_item(item).await?; } - Ok(()) + if filtered.is_empty() { + return Ok(()); + } + self.tx + .send(RolloutCmd::AddItems(filtered)) + .await + .map_err(|e| IoError::other(format!("failed to queue rollout items: {e}"))) } - async fn record_item(&self, item: &impl Serialize) -> std::io::Result<()> { - // Serialize the item to JSON first so that the writer thread only has - // to perform the actual write. - let json = serde_json::to_string(item) - .map_err(|e| IoError::other(format!("failed to serialize response items: {e}")))?; - + pub(crate) async fn record_state(&self, state: SessionStateSnapshot) -> std::io::Result<()> { self.tx - .send(json) + .send(RolloutCmd::UpdateState(state)) .await - .map_err(|e| IoError::other(format!("failed to queue rollout item: {e}"))) + .map_err(|e| IoError::other(format!("failed to queue rollout state: {e}"))) + } + + pub async fn resume( + path: &Path, + cwd: std::path::PathBuf, + ) -> std::io::Result<(Self, SavedSession)> { + info!("Resuming rollout from {path:?}"); + let text = tokio::fs::read_to_string(path).await?; + let mut lines = text.lines(); + let meta_line = lines + .next() + .ok_or_else(|| IoError::other("empty session file"))?; + let session: SessionMeta = serde_json::from_str(meta_line) + .map_err(|e| IoError::other(format!("failed to parse session meta: {e}")))?; + let mut items = Vec::new(); + let mut state = SessionStateSnapshot::default(); + + for line in lines { + if line.trim().is_empty() { + continue; + } + let v: Value = match serde_json::from_str(line) { + Ok(v) => v, + Err(_) => continue, + }; + if v.get("record_type") + .and_then(|rt| rt.as_str()) + .map(|s| s == "state") + .unwrap_or(false) + { + if let Ok(s) = serde_json::from_value::(v.clone()) { + state = s + } + continue; + } + match serde_json::from_value::(v.clone()) { + Ok(item) => match item { + ResponseItem::Message { .. } + | ResponseItem::LocalShellCall { .. } + | ResponseItem::FunctionCall { .. } + | ResponseItem::FunctionCallOutput { .. } + | ResponseItem::Reasoning { .. } => items.push(item), + ResponseItem::Other => {} + }, + Err(e) => { + warn!("failed to parse item: {v:?}, error: {e}"); + } + } + } + + let saved = SavedSession { + session: session.clone(), + items: items.clone(), + state: state.clone(), + session_id: session.id, + }; + + let file = std::fs::OpenOptions::new() + .append(true) + .read(true) + .open(path)?; + + let (tx, rx) = mpsc::channel::(256); + tokio::task::spawn(rollout_writer( + tokio::fs::File::from_std(file), + rx, + None, + cwd, + )); + info!("Resumed rollout successfully from {path:?}"); + Ok((Self { tx }, saved)) + } + + pub async fn shutdown(&self) -> std::io::Result<()> { + let (tx_done, rx_done) = oneshot::channel(); + match self.tx.send(RolloutCmd::Shutdown { ack: tx_done }).await { + Ok(_) => rx_done + .await + .map_err(|e| IoError::other(format!("failed waiting for rollout shutdown: {e}"))), + Err(e) => { + warn!("failed to send rollout shutdown command: {e}"); + Err(IoError::other(format!( + "failed to send rollout shutdown command: {e}" + ))) + } + } } } @@ -153,13 +254,15 @@ struct LogFileInfo { } fn create_log_file(config: &Config, session_id: Uuid) -> std::io::Result { - // Resolve ~/.codex/sessions and create it if missing. - let mut dir = config.codex_home.clone(); - dir.push(SESSIONS_SUBDIR); - fs::create_dir_all(&dir)?; - + // Resolve ~/.codex/sessions/YYYY/MM/DD and create it if missing. let timestamp = OffsetDateTime::now_local() .map_err(|e| IoError::other(format!("failed to get local time: {e}")))?; + let mut dir = config.codex_home.clone(); + dir.push(SESSIONS_SUBDIR); + dir.push(timestamp.year().to_string()); + dir.push(format!("{:02}", u8::from(timestamp.month()))); + dir.push(format!("{:02}", timestamp.day())); + fs::create_dir_all(&dir)?; // Custom format for YYYY-MM-DDThh-mm-ss. Use `-` instead of `:` for // compatibility with filesystems that do not allow colons in filenames. @@ -183,3 +286,77 @@ fn create_log_file(config: &Config, session_id: Uuid) -> std::io::Result, + mut meta: Option, + cwd: std::path::PathBuf, +) -> std::io::Result<()> { + let mut writer = JsonlWriter { file }; + + // If we have a meta, collect git info asynchronously and write meta first + if let Some(session_meta) = meta.take() { + let git_info = collect_git_info(&cwd).await; + let session_meta_with_git = SessionMetaWithGit { + meta: session_meta, + git: git_info, + }; + + // Write the SessionMeta as the first item in the file + writer.write_line(&session_meta_with_git).await?; + } + + // Process rollout commands + while let Some(cmd) = rx.recv().await { + match cmd { + RolloutCmd::AddItems(items) => { + for item in items { + match item { + ResponseItem::Message { .. } + | ResponseItem::LocalShellCall { .. } + | ResponseItem::FunctionCall { .. } + | ResponseItem::FunctionCallOutput { .. } + | ResponseItem::Reasoning { .. } => { + writer.write_line(&item).await?; + } + ResponseItem::Other => {} + } + } + } + RolloutCmd::UpdateState(state) => { + #[derive(Serialize)] + struct StateLine<'a> { + record_type: &'static str, + #[serde(flatten)] + state: &'a SessionStateSnapshot, + } + writer + .write_line(&StateLine { + record_type: "state", + state: &state, + }) + .await?; + } + RolloutCmd::Shutdown { ack } => { + let _ = ack.send(()); + } + } + } + + Ok(()) +} + +struct JsonlWriter { + file: tokio::fs::File, +} + +impl JsonlWriter { + async fn write_line(&mut self, item: &impl serde::Serialize) -> std::io::Result<()> { + let mut json = serde_json::to_string(item)?; + json.push('\n'); + let _ = self.file.write_all(json.as_bytes()).await; + self.file.flush().await?; + Ok(()) + } +} diff --git a/codex-rs/core/src/safety.rs b/codex-rs/core/src/safety.rs index 6a3ff29901..f9bc27e058 100644 --- a/codex-rs/core/src/safety.rs +++ b/codex-rs/core/src/safety.rs @@ -75,9 +75,6 @@ pub fn assess_command_safety( sandbox_policy: &SandboxPolicy, approved: &HashSet>, ) -> SafetyCheck { - use AskForApproval::*; - use SandboxPolicy::*; - // A command is "trusted" because either: // - it belongs to a set of commands we consider "safe" by default, or // - the user has explicitly approved the command for this session @@ -97,6 +94,16 @@ pub fn assess_command_safety( }; } + assess_safety_for_untrusted_command(approval_policy, sandbox_policy) +} + +pub(crate) fn assess_safety_for_untrusted_command( + approval_policy: AskForApproval, + sandbox_policy: &SandboxPolicy, +) -> SafetyCheck { + use AskForApproval::*; + use SandboxPolicy::*; + match (approval_policy, sandbox_policy) { (UnlessTrusted, _) => { // Even though the user may have opted into DangerFullAccess, diff --git a/codex-rs/core/src/shell.rs b/codex-rs/core/src/shell.rs new file mode 100644 index 0000000000..98addffce2 --- /dev/null +++ b/codex-rs/core/src/shell.rs @@ -0,0 +1,236 @@ +use shlex; + +#[derive(Debug, PartialEq, Eq)] +pub struct ZshShell { + shell_path: String, + zshrc_path: String, +} + +#[derive(Debug, PartialEq, Eq)] +pub enum Shell { + Zsh(ZshShell), + Unknown, +} + +impl Shell { + pub fn format_default_shell_invocation(&self, command: Vec) -> Option> { + match self { + Shell::Zsh(zsh) => { + if !std::path::Path::new(&zsh.zshrc_path).exists() { + return None; + } + + let mut result = vec![zsh.shell_path.clone()]; + result.push("-lc".to_string()); + + let joined = strip_bash_lc(&command) + .or_else(|| shlex::try_join(command.iter().map(|s| s.as_str())).ok()); + + if let Some(joined) = joined { + result.push(format!("source {} && ({joined})", zsh.zshrc_path)); + } else { + return None; + } + Some(result) + } + Shell::Unknown => None, + } + } +} + +fn strip_bash_lc(command: &Vec) -> Option { + match command.as_slice() { + // exactly three items + [first, second, third] + // first two must be "bash", "-lc" + if first == "bash" && second == "-lc" => + { + Some(third.clone()) + } + _ => None, + } +} + +#[cfg(target_os = "macos")] +pub async fn default_user_shell() -> Shell { + use tokio::process::Command; + use whoami; + + let user = whoami::username(); + let home = format!("/Users/{user}"); + let output = Command::new("dscl") + .args([".", "-read", &home, "UserShell"]) + .output() + .await + .ok(); + match output { + Some(o) => { + if !o.status.success() { + return Shell::Unknown; + } + let stdout = String::from_utf8_lossy(&o.stdout); + for line in stdout.lines() { + if let Some(shell_path) = line.strip_prefix("UserShell: ") { + if shell_path.ends_with("/zsh") { + return Shell::Zsh(ZshShell { + shell_path: shell_path.to_string(), + zshrc_path: format!("{home}/.zshrc"), + }); + } + } + } + + Shell::Unknown + } + _ => Shell::Unknown, + } +} + +#[cfg(not(target_os = "macos"))] +pub async fn default_user_shell() -> Shell { + Shell::Unknown +} + +#[cfg(test)] +#[cfg(target_os = "macos")] +mod tests { + use super::*; + use std::process::Command; + + #[tokio::test] + #[expect(clippy::unwrap_used)] + async fn test_current_shell_detects_zsh() { + let shell = Command::new("sh") + .arg("-c") + .arg("echo $SHELL") + .output() + .unwrap(); + + let home = std::env::var("HOME").unwrap(); + let shell_path = String::from_utf8_lossy(&shell.stdout).trim().to_string(); + if shell_path.ends_with("/zsh") { + assert_eq!( + default_user_shell().await, + Shell::Zsh(ZshShell { + shell_path: shell_path.to_string(), + zshrc_path: format!("{home}/.zshrc",), + }) + ); + } + } + + #[tokio::test] + async fn test_run_with_profile_zshrc_not_exists() { + let shell = Shell::Zsh(ZshShell { + shell_path: "/bin/zsh".to_string(), + zshrc_path: "/does/not/exist/.zshrc".to_string(), + }); + let actual_cmd = shell.format_default_shell_invocation(vec!["myecho".to_string()]); + assert_eq!(actual_cmd, None); + } + + #[expect(clippy::unwrap_used)] + #[tokio::test] + async fn test_run_with_profile_escaping_and_execution() { + let shell_path = "/bin/zsh"; + + let cases = vec![ + ( + vec!["myecho"], + vec![shell_path, "-lc", "source ZSHRC_PATH && (myecho)"], + Some("It works!\n"), + ), + ( + vec!["myecho"], + vec![shell_path, "-lc", "source ZSHRC_PATH && (myecho)"], + Some("It works!\n"), + ), + ( + vec!["bash", "-c", "echo 'single' \"double\""], + vec![ + shell_path, + "-lc", + "source ZSHRC_PATH && (bash -c \"echo 'single' \\\"double\\\"\")", + ], + Some("single double\n"), + ), + ( + vec!["bash", "-lc", "echo 'single' \"double\""], + vec![ + shell_path, + "-lc", + "source ZSHRC_PATH && (echo 'single' \"double\")", + ], + Some("single double\n"), + ), + ]; + for (input, expected_cmd, expected_output) in cases { + use std::collections::HashMap; + use std::path::PathBuf; + use std::sync::Arc; + + use tokio::sync::Notify; + + use crate::exec::ExecParams; + use crate::exec::SandboxType; + use crate::exec::process_exec_tool_call; + use crate::protocol::SandboxPolicy; + + // create a temp directory with a zshrc file in it + let temp_home = tempfile::tempdir().unwrap(); + let zshrc_path = temp_home.path().join(".zshrc"); + std::fs::write( + &zshrc_path, + r#" + set -x + function myecho { + echo 'It works!' + } + "#, + ) + .unwrap(); + let shell = Shell::Zsh(ZshShell { + shell_path: shell_path.to_string(), + zshrc_path: zshrc_path.to_str().unwrap().to_string(), + }); + + let actual_cmd = shell + .format_default_shell_invocation(input.iter().map(|s| s.to_string()).collect()); + let expected_cmd = expected_cmd + .iter() + .map(|s| { + s.replace("ZSHRC_PATH", zshrc_path.to_str().unwrap()) + .to_string() + }) + .collect(); + + assert_eq!(actual_cmd, Some(expected_cmd)); + // Actually run the command and check output/exit code + let output = process_exec_tool_call( + ExecParams { + command: actual_cmd.unwrap(), + cwd: PathBuf::from(temp_home.path()), + timeout_ms: None, + env: HashMap::from([( + "HOME".to_string(), + temp_home.path().to_str().unwrap().to_string(), + )]), + }, + SandboxType::None, + Arc::new(Notify::new()), + &SandboxPolicy::DangerFullAccess, + &None, + ) + .await + .unwrap(); + + assert_eq!(output.exit_code, 0, "input: {input:?} output: {output:?}"); + if let Some(expected) = expected_output { + assert_eq!( + output.stdout, expected, + "input: {input:?} output: {output:?}" + ); + } + } + } +} diff --git a/codex-rs/core/tests/cli_responses_fixture.sse b/codex-rs/core/tests/cli_responses_fixture.sse new file mode 100644 index 0000000000..d297ebafb2 --- /dev/null +++ b/codex-rs/core/tests/cli_responses_fixture.sse @@ -0,0 +1,8 @@ +event: response.created +data: {"type":"response.created","response":{"id":"resp1"}} + +event: response.output_item.done +data: {"type":"response.output_item.done","item":{"type":"message","role":"assistant","content":[{"type":"output_text","text":"fixture hello"}]}} + +event: response.completed +data: {"type":"response.completed","response":{"id":"resp1","output":[]}} diff --git a/codex-rs/core/tests/cli_stream.rs b/codex-rs/core/tests/cli_stream.rs new file mode 100644 index 0000000000..0ab7bd0bb2 --- /dev/null +++ b/codex-rs/core/tests/cli_stream.rs @@ -0,0 +1,574 @@ +#![expect(clippy::unwrap_used)] + +use assert_cmd::Command as AssertCommand; +use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR; +use std::time::Duration; +use std::time::Instant; +use tempfile::TempDir; +use uuid::Uuid; +use walkdir::WalkDir; +use wiremock::Mock; +use wiremock::MockServer; +use wiremock::ResponseTemplate; +use wiremock::matchers::method; +use wiremock::matchers::path; + +/// Tests streaming chat completions through the CLI using a mock server. +/// This test: +/// 1. Sets up a mock server that simulates OpenAI's chat completions API +/// 2. Configures codex to use this mock server via a custom provider +/// 3. Sends a simple "hello?" prompt and verifies the streamed response +/// 4. Ensures the response is received exactly once and contains "hi" +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn chat_mode_stream_cli() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + let server = MockServer::start().await; + let sse = concat!( + "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n", + "data: {\"choices\":[{\"delta\":{}}]}\n\n", + "data: [DONE]\n\n" + ); + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with( + ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(sse, "text/event-stream"), + ) + .expect(1) + .mount(&server) + .await; + + let home = TempDir::new().unwrap(); + let provider_override = format!( + "model_providers.mock={{ name = \"mock\", base_url = \"{}/v1\", env_key = \"PATH\", wire_api = \"chat\" }}", + server.uri() + ); + let mut cmd = AssertCommand::new("cargo"); + cmd.arg("run") + .arg("-p") + .arg("codex-cli") + .arg("--quiet") + .arg("--") + .arg("exec") + .arg("--skip-git-repo-check") + .arg("-c") + .arg(&provider_override) + .arg("-c") + .arg("model_provider=\"mock\"") + .arg("-C") + .arg(env!("CARGO_MANIFEST_DIR")) + .arg("hello?"); + cmd.env("CODEX_HOME", home.path()) + .env("OPENAI_API_KEY", "dummy") + .env("OPENAI_BASE_URL", format!("{}/v1", server.uri())); + + let output = cmd.output().unwrap(); + println!("Status: {}", output.status); + println!("Stdout:\n{}", String::from_utf8_lossy(&output.stdout)); + println!("Stderr:\n{}", String::from_utf8_lossy(&output.stderr)); + assert!(output.status.success()); + let stdout = String::from_utf8_lossy(&output.stdout); + let hi_lines = stdout.lines().filter(|line| line.trim() == "hi").count(); + assert_eq!(hi_lines, 1, "Expected exactly one line with 'hi'"); + + server.verify().await; +} + +/// Verify that passing `-c experimental_instructions_file=...` to the CLI +/// overrides the built-in base instructions by inspecting the request body +/// received by a mock OpenAI Responses endpoint. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn exec_cli_applies_experimental_instructions_file() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + // Start mock server which will capture the request and return a minimal + // SSE stream for a single turn. + let server = MockServer::start().await; + let sse = concat!( + "data: {\"type\":\"response.created\",\"response\":{}}\n\n", + "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"r1\"}}\n\n" + ); + Mock::given(method("POST")) + .and(path("/v1/responses")) + .respond_with( + ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(sse, "text/event-stream"), + ) + .expect(1) + .mount(&server) + .await; + + // Create a temporary instructions file with a unique marker we can assert + // appears in the outbound request payload. + let custom = TempDir::new().unwrap(); + let marker = "cli-experimental-instructions-marker"; + let custom_path = custom.path().join("instr.md"); + std::fs::write(&custom_path, marker).unwrap(); + let custom_path_str = custom_path.to_string_lossy().replace('\\', "/"); + + // Build a provider override that points at the mock server and instructs + // Codex to use the Responses API with the dummy env var. + let provider_override = format!( + "model_providers.mock={{ name = \"mock\", base_url = \"{}/v1\", env_key = \"PATH\", wire_api = \"responses\" }}", + server.uri() + ); + + let home = TempDir::new().unwrap(); + let mut cmd = AssertCommand::new("cargo"); + cmd.arg("run") + .arg("-p") + .arg("codex-cli") + .arg("--quiet") + .arg("--") + .arg("exec") + .arg("--skip-git-repo-check") + .arg("-c") + .arg(&provider_override) + .arg("-c") + .arg("model_provider=\"mock\"") + .arg("-c") + .arg(format!( + "experimental_instructions_file=\"{custom_path_str}\"" + )) + .arg("-C") + .arg(env!("CARGO_MANIFEST_DIR")) + .arg("hello?\n"); + cmd.env("CODEX_HOME", home.path()) + .env("OPENAI_API_KEY", "dummy") + .env("OPENAI_BASE_URL", format!("{}/v1", server.uri())); + + let output = cmd.output().unwrap(); + println!("Status: {}", output.status); + println!("Stdout:\n{}", String::from_utf8_lossy(&output.stdout)); + println!("Stderr:\n{}", String::from_utf8_lossy(&output.stderr)); + assert!(output.status.success()); + + // Inspect the captured request and verify our custom base instructions were + // included in the `instructions` field. + let request = &server.received_requests().await.unwrap()[0]; + let body = request.body_json::().unwrap(); + let instructions = body + .get("instructions") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(); + assert!( + instructions.contains(marker), + "instructions did not contain custom marker; got: {instructions}" + ); +} + +/// Tests streaming responses through the CLI using a local SSE fixture file. +/// This test: +/// 1. Uses a pre-recorded SSE response fixture instead of a live server +/// 2. Configures codex to read from this fixture via CODEX_RS_SSE_FIXTURE env var +/// 3. Sends a "hello?" prompt and verifies the response +/// 4. Ensures the fixture content is correctly streamed through the CLI +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn responses_api_stream_cli() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + let fixture = + std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/cli_responses_fixture.sse"); + + let home = TempDir::new().unwrap(); + let mut cmd = AssertCommand::new("cargo"); + cmd.arg("run") + .arg("-p") + .arg("codex-cli") + .arg("--quiet") + .arg("--") + .arg("exec") + .arg("--skip-git-repo-check") + .arg("-C") + .arg(env!("CARGO_MANIFEST_DIR")) + .arg("hello?"); + cmd.env("CODEX_HOME", home.path()) + .env("OPENAI_API_KEY", "dummy") + .env("CODEX_RS_SSE_FIXTURE", fixture) + .env("OPENAI_BASE_URL", "http://unused.local"); + + let output = cmd.output().unwrap(); + assert!(output.status.success()); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("fixture hello")); +} + +/// End-to-end: create a session (writes rollout), verify the file, then resume and confirm append. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn integration_creates_and_checks_session_file() { + // Honor sandbox network restrictions for CI parity with the other tests. + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + // 1. Temp home so we read/write isolated session files. + let home = TempDir::new().unwrap(); + + // 2. Unique marker we'll look for in the session log. + let marker = format!("integration-test-{}", Uuid::new_v4()); + let prompt = format!("echo {marker}"); + + // 3. Use the same offline SSE fixture as responses_api_stream_cli so the test is hermetic. + let fixture = + std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/cli_responses_fixture.sse"); + + // 4. Run the codex CLI through cargo (ensures the right bin is built) and invoke `exec`, + // which is what records a session. + let mut cmd = AssertCommand::new("cargo"); + cmd.arg("run") + .arg("-p") + .arg("codex-cli") + .arg("--quiet") + .arg("--") + .arg("exec") + .arg("--skip-git-repo-check") + .arg("-C") + .arg(env!("CARGO_MANIFEST_DIR")) + .arg(&prompt); + cmd.env("CODEX_HOME", home.path()) + .env("OPENAI_API_KEY", "dummy") + .env("CODEX_RS_SSE_FIXTURE", &fixture) + // Required for CLI arg parsing even though fixture short-circuits network usage. + .env("OPENAI_BASE_URL", "http://unused.local"); + + let output = cmd.output().unwrap(); + assert!( + output.status.success(), + "codex-cli exec failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + + // Wait for sessions dir to appear. + let sessions_dir = home.path().join("sessions"); + let dir_deadline = Instant::now() + Duration::from_secs(5); + while !sessions_dir.exists() && Instant::now() < dir_deadline { + std::thread::sleep(Duration::from_millis(50)); + } + assert!(sessions_dir.exists(), "sessions directory never appeared"); + + // Find the session file that contains `marker`. + let deadline = Instant::now() + Duration::from_secs(10); + let mut matching_path: Option = None; + while Instant::now() < deadline && matching_path.is_none() { + for entry in WalkDir::new(&sessions_dir) { + let entry = match entry { + Ok(e) => e, + Err(_) => continue, + }; + if !entry.file_type().is_file() { + continue; + } + if !entry.file_name().to_string_lossy().ends_with(".jsonl") { + continue; + } + let path = entry.path(); + let Ok(content) = std::fs::read_to_string(path) else { + continue; + }; + let mut lines = content.lines(); + if lines.next().is_none() { + continue; + } + for line in lines { + if line.trim().is_empty() { + continue; + } + let item: serde_json::Value = match serde_json::from_str(line) { + Ok(v) => v, + Err(_) => continue, + }; + if item.get("type").and_then(|t| t.as_str()) == Some("message") { + if let Some(c) = item.get("content") { + if c.to_string().contains(&marker) { + matching_path = Some(path.to_path_buf()); + break; + } + } + } + } + } + if matching_path.is_none() { + std::thread::sleep(Duration::from_millis(50)); + } + } + + let path = match matching_path { + Some(p) => p, + None => panic!("No session file containing the marker was found"), + }; + + // Basic sanity checks on location and metadata. + let rel = match path.strip_prefix(&sessions_dir) { + Ok(r) => r, + Err(_) => panic!("session file should live under sessions/"), + }; + let comps: Vec = rel + .components() + .map(|c| c.as_os_str().to_string_lossy().into_owned()) + .collect(); + assert_eq!( + comps.len(), + 4, + "Expected sessions/YYYY/MM/DD/, got {rel:?}" + ); + let year = &comps[0]; + let month = &comps[1]; + let day = &comps[2]; + assert!( + year.len() == 4 && year.chars().all(|c| c.is_ascii_digit()), + "Year dir not 4-digit numeric: {year}" + ); + assert!( + month.len() == 2 && month.chars().all(|c| c.is_ascii_digit()), + "Month dir not zero-padded 2-digit numeric: {month}" + ); + assert!( + day.len() == 2 && day.chars().all(|c| c.is_ascii_digit()), + "Day dir not zero-padded 2-digit numeric: {day}" + ); + if let Ok(m) = month.parse::() { + assert!((1..=12).contains(&m), "Month out of range: {m}"); + } + if let Ok(d) = day.parse::() { + assert!((1..=31).contains(&d), "Day out of range: {d}"); + } + + let content = + std::fs::read_to_string(&path).unwrap_or_else(|_| panic!("Failed to read session file")); + let mut lines = content.lines(); + let meta_line = lines + .next() + .ok_or("missing session meta line") + .unwrap_or_else(|_| panic!("missing session meta line")); + let meta: serde_json::Value = serde_json::from_str(meta_line) + .unwrap_or_else(|_| panic!("Failed to parse session meta line as JSON")); + assert!(meta.get("id").is_some(), "SessionMeta missing id"); + assert!( + meta.get("timestamp").is_some(), + "SessionMeta missing timestamp" + ); + + let mut found_message = false; + for line in lines { + if line.trim().is_empty() { + continue; + } + let Ok(item) = serde_json::from_str::(line) else { + continue; + }; + if item.get("type").and_then(|t| t.as_str()) == Some("message") { + if let Some(c) = item.get("content") { + if c.to_string().contains(&marker) { + found_message = true; + break; + } + } + } + } + assert!( + found_message, + "No message found in session file containing the marker" + ); + + // Second run: resume and append. + let orig_len = content.lines().count(); + let marker2 = format!("integration-resume-{}", Uuid::new_v4()); + let prompt2 = format!("echo {marker2}"); + // Cross‑platform safe resume override. On Windows, backslashes in a TOML string must be escaped + // or the parse will fail and the raw literal (including quotes) may be preserved all the way down + // to Config, which in turn breaks resume because the path is invalid. Normalize to forward slashes + // to sidestep the issue. + let resume_path_str = path.to_string_lossy().replace('\\', "/"); + let resume_override = format!("experimental_resume=\"{resume_path_str}\""); + let mut cmd2 = AssertCommand::new("cargo"); + cmd2.arg("run") + .arg("-p") + .arg("codex-cli") + .arg("--quiet") + .arg("--") + .arg("exec") + .arg("--skip-git-repo-check") + .arg("-c") + .arg(&resume_override) + .arg("-C") + .arg(env!("CARGO_MANIFEST_DIR")) + .arg(&prompt2); + cmd2.env("CODEX_HOME", home.path()) + .env("OPENAI_API_KEY", "dummy") + .env("CODEX_RS_SSE_FIXTURE", &fixture) + .env("OPENAI_BASE_URL", "http://unused.local"); + + let output2 = cmd2.output().unwrap(); + assert!(output2.status.success(), "resume codex-cli run failed"); + + // The rollout writer runs on a background async task; give it a moment to flush. + let mut new_len = orig_len; + let deadline = Instant::now() + Duration::from_secs(5); + let mut content2 = String::new(); + while Instant::now() < deadline { + if let Ok(c) = std::fs::read_to_string(&path) { + let count = c.lines().count(); + if count > orig_len { + content2 = c; + new_len = count; + break; + } + } + std::thread::sleep(Duration::from_millis(50)); + } + if content2.is_empty() { + // last attempt + content2 = std::fs::read_to_string(&path).unwrap(); + new_len = content2.lines().count(); + } + assert!(new_len > orig_len, "rollout file did not grow after resume"); + assert!(content2.contains(&marker), "rollout lost original marker"); + assert!( + content2.contains(&marker2), + "rollout missing resumed marker" + ); +} + +/// Integration test to verify git info is collected and recorded in session files. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn integration_git_info_unit_test() { + // This test verifies git info collection works independently + // without depending on the full CLI integration + + // 1. Create temp directory for git repo + let temp_dir = TempDir::new().unwrap(); + let git_repo = temp_dir.path().to_path_buf(); + + // 2. Initialize a git repository with some content + let init_output = std::process::Command::new("git") + .args(["init"]) + .current_dir(&git_repo) + .output() + .unwrap(); + assert!(init_output.status.success(), "git init failed"); + + // Configure git user (required for commits) + std::process::Command::new("git") + .args(["config", "user.name", "Integration Test"]) + .current_dir(&git_repo) + .output() + .unwrap(); + + std::process::Command::new("git") + .args(["config", "user.email", "test@example.com"]) + .current_dir(&git_repo) + .output() + .unwrap(); + + // Create a test file and commit it + let test_file = git_repo.join("test.txt"); + std::fs::write(&test_file, "integration test content").unwrap(); + + std::process::Command::new("git") + .args(["add", "."]) + .current_dir(&git_repo) + .output() + .unwrap(); + + let commit_output = std::process::Command::new("git") + .args(["commit", "-m", "Integration test commit"]) + .current_dir(&git_repo) + .output() + .unwrap(); + assert!(commit_output.status.success(), "git commit failed"); + + // Create a branch to test branch detection + std::process::Command::new("git") + .args(["checkout", "-b", "integration-test-branch"]) + .current_dir(&git_repo) + .output() + .unwrap(); + + // Add a remote to test repository URL detection + std::process::Command::new("git") + .args([ + "remote", + "add", + "origin", + "https://github.com/example/integration-test.git", + ]) + .current_dir(&git_repo) + .output() + .unwrap(); + + // 3. Test git info collection directly + let git_info = codex_core::git_info::collect_git_info(&git_repo).await; + + // 4. Verify git info is present and contains expected data + assert!(git_info.is_some(), "Git info should be collected"); + + let git_info = git_info.unwrap(); + + // Check that we have a commit hash + assert!( + git_info.commit_hash.is_some(), + "Git info should contain commit_hash" + ); + let commit_hash = git_info.commit_hash.as_ref().unwrap(); + assert_eq!(commit_hash.len(), 40, "Commit hash should be 40 characters"); + assert!( + commit_hash.chars().all(|c| c.is_ascii_hexdigit()), + "Commit hash should be hexadecimal" + ); + + // Check that we have the correct branch + assert!(git_info.branch.is_some(), "Git info should contain branch"); + let branch = git_info.branch.as_ref().unwrap(); + assert_eq!( + branch, "integration-test-branch", + "Branch should match what we created" + ); + + // Check that we have the repository URL + assert!( + git_info.repository_url.is_some(), + "Git info should contain repository_url" + ); + let repo_url = git_info.repository_url.as_ref().unwrap(); + assert_eq!( + repo_url, "https://github.com/example/integration-test.git", + "Repository URL should match what we configured" + ); + + println!("✅ Git info collection test passed!"); + println!(" Commit: {commit_hash}"); + println!(" Branch: {branch}"); + println!(" Repo: {repo_url}"); + + // 5. Test serialization to ensure it works in SessionMeta + let serialized = serde_json::to_string(&git_info).unwrap(); + let deserialized: codex_core::git_info::GitInfo = serde_json::from_str(&serialized).unwrap(); + + assert_eq!(git_info.commit_hash, deserialized.commit_hash); + assert_eq!(git_info.branch, deserialized.branch); + assert_eq!(git_info.repository_url, deserialized.repository_url); + + println!("✅ Git info serialization test passed!"); +} diff --git a/codex-rs/core/tests/client.rs b/codex-rs/core/tests/client.rs new file mode 100644 index 0000000000..67d95cb8f6 --- /dev/null +++ b/codex-rs/core/tests/client.rs @@ -0,0 +1,340 @@ +use std::path::PathBuf; + +use chrono::Utc; +use codex_core::Codex; +use codex_core::CodexSpawnOk; +use codex_core::ModelProviderInfo; +use codex_core::built_in_model_providers; +use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR; +use codex_core::protocol::EventMsg; +use codex_core::protocol::InputItem; +use codex_core::protocol::Op; +use codex_core::protocol::SessionConfiguredEvent; +use codex_login::AuthDotJson; +use codex_login::AuthMode; +use codex_login::CodexAuth; +use codex_login::TokenData; +use core_test_support::load_default_config_for_test; +use core_test_support::load_sse_fixture_with_id; +use core_test_support::wait_for_event; +use tempfile::TempDir; +use wiremock::Mock; +use wiremock::MockServer; +use wiremock::ResponseTemplate; +use wiremock::matchers::method; +use wiremock::matchers::path; + +/// Build minimal SSE stream with completed marker using the JSON fixture. +fn sse_completed(id: &str) -> String { + load_sse_fixture_with_id("tests/fixtures/completed_template.json", id) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn includes_session_id_and_model_headers_in_request() { + #![allow(clippy::unwrap_used)] + + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + // Mock server + let server = MockServer::start().await; + + // First request – must NOT include `previous_response_id`. + let first = ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(sse_completed("resp1"), "text/event-stream"); + + Mock::given(method("POST")) + .and(path("/v1/responses")) + .respond_with(first) + .expect(1) + .mount(&server) + .await; + + let model_provider = ModelProviderInfo { + base_url: Some(format!("{}/v1", server.uri())), + ..built_in_model_providers()["openai"].clone() + }; + + // Init session + let codex_home = TempDir::new().unwrap(); + let mut config = load_default_config_for_test(&codex_home); + config.model_provider = model_provider; + + let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new()); + let CodexSpawnOk { codex, .. } = Codex::spawn( + config, + Some(CodexAuth::from_api_key("Test API Key".to_string())), + ctrl_c.clone(), + ) + .await + .unwrap(); + + codex + .submit(Op::UserInput { + items: vec![InputItem::Text { + text: "hello".into(), + }], + }) + .await + .unwrap(); + + let EventMsg::SessionConfigured(SessionConfiguredEvent { session_id, .. }) = + wait_for_event(&codex, |ev| matches!(ev, EventMsg::SessionConfigured(_))).await + else { + unreachable!() + }; + + let current_session_id = Some(session_id.to_string()); + wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + // get request from the server + let request = &server.received_requests().await.unwrap()[0]; + let request_session_id = request.headers.get("session_id").unwrap(); + let request_originator = request.headers.get("originator").unwrap(); + let request_authorization = request.headers.get("authorization").unwrap(); + + assert!(current_session_id.is_some()); + assert_eq!( + request_session_id.to_str().unwrap(), + current_session_id.as_ref().unwrap() + ); + assert_eq!(request_originator.to_str().unwrap(), "codex_cli_rs"); + assert_eq!( + request_authorization.to_str().unwrap(), + "Bearer Test API Key" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn includes_base_instructions_override_in_request() { + #![allow(clippy::unwrap_used)] + + // Mock server + let server = MockServer::start().await; + + // First request – must NOT include `previous_response_id`. + let first = ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(sse_completed("resp1"), "text/event-stream"); + + Mock::given(method("POST")) + .and(path("/v1/responses")) + .respond_with(first) + .expect(1) + .mount(&server) + .await; + + let model_provider = ModelProviderInfo { + base_url: Some(format!("{}/v1", server.uri())), + ..built_in_model_providers()["openai"].clone() + }; + let codex_home = TempDir::new().unwrap(); + let mut config = load_default_config_for_test(&codex_home); + + config.base_instructions = Some("test instructions".to_string()); + config.model_provider = model_provider; + + let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new()); + let CodexSpawnOk { codex, .. } = Codex::spawn( + config, + Some(CodexAuth::from_api_key("Test API Key".to_string())), + ctrl_c.clone(), + ) + .await + .unwrap(); + + codex + .submit(Op::UserInput { + items: vec![InputItem::Text { + text: "hello".into(), + }], + }) + .await + .unwrap(); + + wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + let request = &server.received_requests().await.unwrap()[0]; + let request_body = request.body_json::().unwrap(); + + assert!( + request_body["instructions"] + .as_str() + .unwrap() + .contains("test instructions") + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn chatgpt_auth_sends_correct_request() { + #![allow(clippy::unwrap_used)] + + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + // Mock server + let server = MockServer::start().await; + + // First request – must NOT include `previous_response_id`. + let first = ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(sse_completed("resp1"), "text/event-stream"); + + Mock::given(method("POST")) + .and(path("/api/codex/responses")) + .respond_with(first) + .expect(1) + .mount(&server) + .await; + + let model_provider = ModelProviderInfo { + base_url: Some(format!("{}/api/codex", server.uri())), + ..built_in_model_providers()["openai"].clone() + }; + + // Init session + let codex_home = TempDir::new().unwrap(); + let mut config = load_default_config_for_test(&codex_home); + config.model_provider = model_provider; + let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new()); + let CodexSpawnOk { codex, .. } = Codex::spawn( + config, + Some(auth_from_token("Access Token".to_string())), + ctrl_c.clone(), + ) + .await + .unwrap(); + + codex + .submit(Op::UserInput { + items: vec![InputItem::Text { + text: "hello".into(), + }], + }) + .await + .unwrap(); + + let EventMsg::SessionConfigured(SessionConfiguredEvent { session_id, .. }) = + wait_for_event(&codex, |ev| matches!(ev, EventMsg::SessionConfigured(_))).await + else { + unreachable!() + }; + + let current_session_id = Some(session_id.to_string()); + wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + // get request from the server + let request = &server.received_requests().await.unwrap()[0]; + let request_session_id = request.headers.get("session_id").unwrap(); + let request_originator = request.headers.get("originator").unwrap(); + let request_authorization = request.headers.get("authorization").unwrap(); + let request_body = request.body_json::().unwrap(); + + assert!(current_session_id.is_some()); + assert_eq!( + request_session_id.to_str().unwrap(), + current_session_id.as_ref().unwrap() + ); + assert_eq!(request_originator.to_str().unwrap(), "codex_cli_rs"); + assert_eq!( + request_authorization.to_str().unwrap(), + "Bearer Access Token" + ); + assert!(!request_body["store"].as_bool().unwrap()); + assert!(request_body["stream"].as_bool().unwrap()); + assert_eq!( + request_body["include"][0].as_str().unwrap(), + "reasoning.encrypted_content" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn includes_user_instructions_message_in_request() { + #![allow(clippy::unwrap_used)] + + let server = MockServer::start().await; + + let first = ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(sse_completed("resp1"), "text/event-stream"); + + Mock::given(method("POST")) + .and(path("/v1/responses")) + .respond_with(first) + .expect(1) + .mount(&server) + .await; + + let model_provider = ModelProviderInfo { + base_url: Some(format!("{}/v1", server.uri())), + ..built_in_model_providers()["openai"].clone() + }; + + let codex_home = TempDir::new().unwrap(); + let mut config = load_default_config_for_test(&codex_home); + config.model_provider = model_provider; + config.user_instructions = Some("be nice".to_string()); + + let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new()); + let CodexSpawnOk { codex, .. } = Codex::spawn( + config, + Some(CodexAuth::from_api_key("Test API Key".to_string())), + ctrl_c.clone(), + ) + .await + .unwrap(); + + codex + .submit(Op::UserInput { + items: vec![InputItem::Text { + text: "hello".into(), + }], + }) + .await + .unwrap(); + + wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + let request = &server.received_requests().await.unwrap()[0]; + let request_body = request.body_json::().unwrap(); + + assert!( + !request_body["instructions"] + .as_str() + .unwrap() + .contains("be nice") + ); + assert_eq!(request_body["input"][0]["role"], "user"); + assert!( + request_body["input"][0]["content"][0]["text"] + .as_str() + .unwrap() + .starts_with("be nice") + ); +} +fn auth_from_token(id_token: String) -> CodexAuth { + CodexAuth::new( + None, + AuthMode::ChatGPT, + PathBuf::new(), + Some(AuthDotJson { + tokens: TokenData { + id_token, + access_token: "Access Token".to_string(), + refresh_token: "test".to_string(), + account_id: None, + }, + last_refresh: Utc::now(), + openai_api_key: None, + }), + ) +} diff --git a/codex-rs/core/tests/common/Cargo.toml b/codex-rs/core/tests/common/Cargo.toml new file mode 100644 index 0000000000..9cfd20cdb4 --- /dev/null +++ b/codex-rs/core/tests/common/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "core_test_support" +version = { workspace = true } +edition = "2024" + +[lib] +path = "lib.rs" + +[dependencies] +codex-core = { path = "../.." } +serde_json = "1" +tempfile = "3" +tokio = { version = "1", features = ["time"] } diff --git a/codex-rs/core/tests/test_support.rs b/codex-rs/core/tests/common/lib.rs similarity index 84% rename from codex-rs/core/tests/test_support.rs rename to codex-rs/core/tests/common/lib.rs index 5dbe637101..2577679f65 100644 --- a/codex-rs/core/tests/test_support.rs +++ b/codex-rs/core/tests/common/lib.rs @@ -1,9 +1,5 @@ #![allow(clippy::expect_used)] -// Helpers shared by the integration tests. These are located inside the -// `tests/` tree on purpose so they never become part of the public API surface -// of the `codex-core` crate. - use tempfile::TempDir; use codex_core::config::Config; @@ -74,3 +70,23 @@ pub fn load_sse_fixture_with_id(path: impl AsRef, id: &str) -> }) .collect() } + +pub async fn wait_for_event( + codex: &codex_core::Codex, + mut predicate: F, +) -> codex_core::protocol::EventMsg +where + F: FnMut(&codex_core::protocol::EventMsg) -> bool, +{ + use tokio::time::Duration; + use tokio::time::timeout; + loop { + let ev = timeout(Duration::from_secs(1), codex.next_event()) + .await + .expect("timeout waiting for event") + .expect("stream ended unexpectedly"); + if predicate(&ev.msg) { + return ev.msg; + } + } +} diff --git a/codex-rs/core/tests/live_agent.rs b/codex-rs/core/tests/live_agent.rs index c21f9d0032..95408e20e5 100644 --- a/codex-rs/core/tests/live_agent.rs +++ b/codex-rs/core/tests/live_agent.rs @@ -20,15 +20,15 @@ use std::time::Duration; use codex_core::Codex; +use codex_core::CodexSpawnOk; use codex_core::error::CodexErr; use codex_core::protocol::AgentMessageEvent; use codex_core::protocol::ErrorEvent; use codex_core::protocol::EventMsg; use codex_core::protocol::InputItem; use codex_core::protocol::Op; -mod test_support; +use core_test_support::load_default_config_for_test; use tempfile::TempDir; -use test_support::load_default_config_for_test; use tokio::sync::Notify; use tokio::time::timeout; @@ -45,23 +45,12 @@ async fn spawn_codex() -> Result { "OPENAI_API_KEY must be set for live tests" ); - // Environment tweaks to keep the tests snappy and inexpensive while still - // exercising retry/robustness logic. - // - // NOTE: Starting with the 2024 edition `std::env::set_var` is `unsafe` - // because changing the process environment races with any other threads - // that might be performing environment look-ups at the same time. - // Restrict the unsafety to this tiny block that happens at the very - // beginning of the test, before we spawn any background tasks that could - // observe the environment. - unsafe { - std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "2"); - std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "2"); - } - let codex_home = TempDir::new().unwrap(); - let config = load_default_config_for_test(&codex_home); - let (agent, _init_id) = Codex::spawn(config, std::sync::Arc::new(Notify::new())).await?; + let mut config = load_default_config_for_test(&codex_home); + config.model_provider.request_max_retries = Some(2); + config.model_provider.stream_max_retries = Some(2); + let CodexSpawnOk { codex: agent, .. } = + Codex::spawn(config, None, std::sync::Arc::new(Notify::new())).await?; Ok(agent) } @@ -79,7 +68,7 @@ async fn live_streaming_and_prev_id_reset() { let codex = spawn_codex().await.unwrap(); - // ---------- Task 1 ---------- + // ---------- Task 1 ---------- codex .submit(Op::UserInput { items: vec![InputItem::Text { @@ -113,7 +102,7 @@ async fn live_streaming_and_prev_id_reset() { "Agent did not stream any AgentMessage before TaskComplete" ); - // ---------- Task 2 (same session) ---------- + // ---------- Task 2 (same session) ---------- codex .submit(Op::UserInput { items: vec![InputItem::Text { diff --git a/codex-rs/core/tests/previous_response_id.rs b/codex-rs/core/tests/previous_response_id.rs deleted file mode 100644 index e64271a0ff..0000000000 --- a/codex-rs/core/tests/previous_response_id.rs +++ /dev/null @@ -1,166 +0,0 @@ -use std::time::Duration; - -use codex_core::Codex; -use codex_core::ModelProviderInfo; -use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR; -use codex_core::protocol::ErrorEvent; -use codex_core::protocol::EventMsg; -use codex_core::protocol::InputItem; -use codex_core::protocol::Op; -mod test_support; -use serde_json::Value; -use tempfile::TempDir; -use test_support::load_default_config_for_test; -use test_support::load_sse_fixture_with_id; -use tokio::time::timeout; -use wiremock::Match; -use wiremock::Mock; -use wiremock::MockServer; -use wiremock::Request; -use wiremock::ResponseTemplate; -use wiremock::matchers::method; -use wiremock::matchers::path; - -/// Matcher asserting that JSON body has NO `previous_response_id` field. -struct NoPrevId; - -impl Match for NoPrevId { - fn matches(&self, req: &Request) -> bool { - serde_json::from_slice::(&req.body) - .map(|v| v.get("previous_response_id").is_none()) - .unwrap_or(false) - } -} - -/// Matcher asserting that JSON body HAS a `previous_response_id` field. -struct HasPrevId; - -impl Match for HasPrevId { - fn matches(&self, req: &Request) -> bool { - serde_json::from_slice::(&req.body) - .map(|v| v.get("previous_response_id").is_some()) - .unwrap_or(false) - } -} - -/// Build minimal SSE stream with completed marker using the JSON fixture. -fn sse_completed(id: &str) -> String { - load_sse_fixture_with_id("tests/fixtures/completed_template.json", id) -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn keeps_previous_response_id_between_tasks() { - #![allow(clippy::unwrap_used)] - - if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { - println!( - "Skipping test because it cannot execute when network is disabled in a Codex sandbox." - ); - return; - } - - // Mock server - let server = MockServer::start().await; - - // First request – must NOT include `previous_response_id`. - let first = ResponseTemplate::new(200) - .insert_header("content-type", "text/event-stream") - .set_body_raw(sse_completed("resp1"), "text/event-stream"); - - Mock::given(method("POST")) - .and(path("/v1/responses")) - .and(NoPrevId) - .respond_with(first) - .expect(1) - .mount(&server) - .await; - - // Second request – MUST include `previous_response_id`. - let second = ResponseTemplate::new(200) - .insert_header("content-type", "text/event-stream") - .set_body_raw(sse_completed("resp2"), "text/event-stream"); - - Mock::given(method("POST")) - .and(path("/v1/responses")) - .and(HasPrevId) - .respond_with(second) - .expect(1) - .mount(&server) - .await; - - // Environment - // Update environment – `set_var` is `unsafe` starting with the 2024 - // edition so we group the calls into a single `unsafe { … }` block. - unsafe { - std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0"); - std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "0"); - } - let model_provider = ModelProviderInfo { - name: "openai".into(), - base_url: format!("{}/v1", server.uri()), - // Environment variable that should exist in the test environment. - // ModelClient will return an error if the environment variable for the - // provider is not set. - env_key: Some("PATH".into()), - env_key_instructions: None, - wire_api: codex_core::WireApi::Responses, - query_params: None, - http_headers: None, - env_http_headers: None, - }; - - // Init session - let codex_home = TempDir::new().unwrap(); - let mut config = load_default_config_for_test(&codex_home); - config.model_provider = model_provider; - let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new()); - let (codex, _init_id) = Codex::spawn(config, ctrl_c.clone()).await.unwrap(); - - // Task 1 – triggers first request (no previous_response_id) - codex - .submit(Op::UserInput { - items: vec![InputItem::Text { - text: "hello".into(), - }], - }) - .await - .unwrap(); - - // Wait for TaskComplete - loop { - let ev = timeout(Duration::from_secs(1), codex.next_event()) - .await - .unwrap() - .unwrap(); - if matches!(ev.msg, EventMsg::TaskComplete(_)) { - break; - } - } - - // Task 2 – should include `previous_response_id` (triggers second request) - codex - .submit(Op::UserInput { - items: vec![InputItem::Text { - text: "again".into(), - }], - }) - .await - .unwrap(); - - // Wait for TaskComplete or error - loop { - let ev = timeout(Duration::from_secs(1), codex.next_event()) - .await - .unwrap() - .unwrap(); - match ev.msg { - EventMsg::TaskComplete(_) => break, - EventMsg::Error(ErrorEvent { message }) => { - panic!("unexpected error: {message}") - } - _ => { - // Ignore other events. - } - } - } -} diff --git a/codex-rs/core/tests/stream_no_completed.rs b/codex-rs/core/tests/stream_no_completed.rs index da2736aa77..d2fc035569 100644 --- a/codex-rs/core/tests/stream_no_completed.rs +++ b/codex-rs/core/tests/stream_no_completed.rs @@ -4,16 +4,17 @@ use std::time::Duration; use codex_core::Codex; +use codex_core::CodexSpawnOk; use codex_core::ModelProviderInfo; use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR; use codex_core::protocol::EventMsg; use codex_core::protocol::InputItem; use codex_core::protocol::Op; -mod test_support; +use codex_login::CodexAuth; +use core_test_support::load_default_config_for_test; +use core_test_support::load_sse_fixture; +use core_test_support::load_sse_fixture_with_id; use tempfile::TempDir; -use test_support::load_default_config_for_test; -use test_support::load_sse_fixture; -use test_support::load_sse_fixture_with_id; use tokio::time::timeout; use wiremock::Mock; use wiremock::MockServer; @@ -70,23 +71,12 @@ async fn retries_on_early_close() { .mount(&server) .await; - // Environment - // - // As of Rust 2024 `std::env::set_var` has been made `unsafe` because - // mutating the process environment is inherently racy when other threads - // are running. We therefore have to wrap every call in an explicit - // `unsafe` block. These are limited to the test-setup section so the - // scope is very small and clearly delineated. - - unsafe { - std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0"); - std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "1"); - std::env::set_var("OPENAI_STREAM_IDLE_TIMEOUT_MS", "2000"); - } + // Configure retry behavior explicitly to avoid mutating process-wide + // environment variables. let model_provider = ModelProviderInfo { name: "openai".into(), - base_url: format!("{}/v1", server.uri()), + base_url: Some(format!("{}/v1", server.uri())), // Environment variable that should exist in the test environment. // ModelClient will return an error if the environment variable for the // provider is not set. @@ -96,13 +86,24 @@ async fn retries_on_early_close() { query_params: None, http_headers: None, env_http_headers: None, + // exercise retry path: first attempt yields incomplete stream, so allow 1 retry + request_max_retries: Some(0), + stream_max_retries: Some(1), + stream_idle_timeout_ms: Some(2000), + requires_auth: false, }; let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new()); let codex_home = TempDir::new().unwrap(); let mut config = load_default_config_for_test(&codex_home); config.model_provider = model_provider; - let (codex, _init_id) = Codex::spawn(config, ctrl_c).await.unwrap(); + let CodexSpawnOk { codex, .. } = Codex::spawn( + config, + Some(CodexAuth::from_api_key("Test API Key".to_string())), + ctrl_c, + ) + .await + .unwrap(); codex .submit(Op::UserInput { diff --git a/codex-rs/exec/Cargo.toml b/codex-rs/exec/Cargo.toml index ed01b78ec8..cd521410b1 100644 --- a/codex-rs/exec/Cargo.toml +++ b/codex-rs/exec/Cargo.toml @@ -1,7 +1,7 @@ [package] +edition = "2024" name = "codex-exec" version = { workspace = true } -edition = "2024" [[bin]] name = "codex-exec" @@ -18,13 +18,13 @@ workspace = true anyhow = "1" chrono = "0.4.40" clap = { version = "4", features = ["derive"] } -codex-core = { path = "../core" } +codex-arg0 = { path = "../arg0" } codex-common = { path = "../common", features = [ "cli", "elapsed", "sandbox_summary", ] } -codex-linux-sandbox = { path = "../linux-sandbox" } +codex-core = { path = "../core" } owo-colors = "4.2.0" serde_json = "1" shlex = "1.3.0" @@ -37,3 +37,8 @@ tokio = { version = "1", features = [ ] } tracing = { version = "0.1.41", features = ["log"] } tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } + +[dev-dependencies] +assert_cmd = "2" +predicates = "3" +tempfile = "3.13.0" diff --git a/codex-rs/exec/src/cli.rs b/codex-rs/exec/src/cli.rs index 613fedf0a1..53af25c7e9 100644 --- a/codex-rs/exec/src/cli.rs +++ b/codex-rs/exec/src/cli.rs @@ -51,6 +51,10 @@ pub struct Cli { #[arg(long = "color", value_enum, default_value_t = Color::Auto)] pub color: Color, + /// Print events to stdout as JSONL. + #[arg(long = "json", default_value_t = false)] + pub json: bool, + /// Specifies file where the last message from the agent should be written. #[arg(long = "output-last-message")] pub last_message_file: Option, diff --git a/codex-rs/exec/src/event_processor.rs b/codex-rs/exec/src/event_processor.rs index 57a364c3f1..dbd6bd40ec 100644 --- a/codex-rs/exec/src/event_processor.rs +++ b/codex-rs/exec/src/event_processor.rs @@ -1,509 +1,69 @@ -use codex_common::elapsed::format_elapsed; +use std::path::Path; + use codex_common::summarize_sandbox_policy; use codex_core::WireApi; use codex_core::config::Config; use codex_core::model_supports_reasoning_summaries; -use codex_core::protocol::AgentMessageEvent; -use codex_core::protocol::BackgroundEventEvent; -use codex_core::protocol::ErrorEvent; use codex_core::protocol::Event; -use codex_core::protocol::EventMsg; -use codex_core::protocol::ExecCommandBeginEvent; -use codex_core::protocol::ExecCommandEndEvent; -use codex_core::protocol::FileChange; -use codex_core::protocol::McpToolCallBeginEvent; -use codex_core::protocol::McpToolCallEndEvent; -use codex_core::protocol::PatchApplyBeginEvent; -use codex_core::protocol::PatchApplyEndEvent; -use codex_core::protocol::SessionConfiguredEvent; -use codex_core::protocol::TokenUsage; -use owo_colors::OwoColorize; -use owo_colors::Style; -use shlex::try_join; -use std::collections::HashMap; -use std::time::Instant; -/// This should be configurable. When used in CI, users may not want to impose -/// a limit so they can see the full transcript. -const MAX_OUTPUT_LINES_FOR_EXEC_TOOL_CALL: usize = 20; - -pub(crate) struct EventProcessor { - call_id_to_command: HashMap, - call_id_to_patch: HashMap, - - /// Tracks in-flight MCP tool calls so we can calculate duration and print - /// a concise summary when the corresponding `McpToolCallEnd` event is - /// received. - call_id_to_tool_call: HashMap, - - // To ensure that --color=never is respected, ANSI escapes _must_ be added - // using .style() with one of these fields. If you need a new style, add a - // new field here. - bold: Style, - italic: Style, - dimmed: Style, - - magenta: Style, - red: Style, - green: Style, - cyan: Style, - - /// Whether to include `AgentReasoning` events in the output. - show_agent_reasoning: bool, - show_agent_reasoning_content: bool, +pub(crate) enum CodexStatus { + Running, + InitiateShutdown, + Shutdown, } -impl EventProcessor { - pub(crate) fn create_with_ansi( - with_ansi: bool, - show_agent_reasoning: bool, - show_agent_reasoning_content: bool, - ) -> Self { - let call_id_to_command = HashMap::new(); - let call_id_to_patch = HashMap::new(); - let call_id_to_tool_call = HashMap::new(); +pub(crate) trait EventProcessor { + /// Print summary of effective configuration and user prompt. + fn print_config_summary(&mut self, config: &Config, prompt: &str); - if with_ansi { - Self { - call_id_to_command, - call_id_to_patch, - bold: Style::new().bold(), - italic: Style::new().italic(), - dimmed: Style::new().dimmed(), - magenta: Style::new().magenta(), - red: Style::new().red(), - green: Style::new().green(), - cyan: Style::new().cyan(), - call_id_to_tool_call, - show_agent_reasoning, - show_agent_reasoning_content, - } - } else { - Self { - call_id_to_command, - call_id_to_patch, - bold: Style::new(), - italic: Style::new(), - dimmed: Style::new(), - magenta: Style::new(), - red: Style::new(), - green: Style::new(), - cyan: Style::new(), - call_id_to_tool_call, - show_agent_reasoning, - show_agent_reasoning_content, - } + /// Handle a single event emitted by the agent. + fn process_event(&mut self, event: Event) -> CodexStatus; +} + +pub(crate) fn create_config_summary_entries(config: &Config) -> Vec<(&'static str, String)> { + let mut entries = vec![ + ("workdir", config.cwd.display().to_string()), + ("model", config.model.clone()), + ("provider", config.model_provider_id.clone()), + ("approval", config.approval_policy.to_string()), + ("sandbox", summarize_sandbox_policy(&config.sandbox_policy)), + ]; + if config.model_provider.wire_api == WireApi::Responses + && model_supports_reasoning_summaries(config) + { + entries.push(( + "reasoning effort", + config.model_reasoning_effort.to_string(), + )); + entries.push(( + "reasoning summaries", + config.model_reasoning_summary.to_string(), + )); + } + entries +} + +pub(crate) fn handle_last_message( + last_agent_message: Option<&str>, + last_message_path: Option<&Path>, +) { + match (last_message_path, last_agent_message) { + (Some(path), Some(msg)) => write_last_message_file(msg, Some(path)), + (Some(path), None) => { + write_last_message_file("", Some(path)); + eprintln!( + "Warning: no last agent message; wrote empty content to {}", + path.display() + ); } + (None, _) => eprintln!("Warning: no file to write last message to."), } } -struct ExecCommandBegin { - command: Vec, - start_time: Instant, -} - -/// Metadata captured when an `McpToolCallBegin` event is received. -struct McpToolCallBegin { - /// Formatted invocation string, e.g. `server.tool({"city":"sf"})`. - invocation: String, - /// Timestamp when the call started so we can compute duration later. - start_time: Instant, -} - -struct PatchApplyBegin { - start_time: Instant, - auto_approved: bool, -} - -// Timestamped println helper. The timestamp is styled with self.dimmed. -#[macro_export] -macro_rules! ts_println { - ($self:ident, $($arg:tt)*) => {{ - let now = chrono::Utc::now(); - let formatted = now.format("[%Y-%m-%dT%H:%M:%S]"); - print!("{} ", formatted.style($self.dimmed)); - println!($($arg)*); - }}; -} - -impl EventProcessor { - /// Print a concise summary of the effective configuration that will be used - /// for the session. This mirrors the information shown in the TUI welcome - /// screen. - pub(crate) fn print_config_summary(&mut self, config: &Config, prompt: &str) { - const VERSION: &str = env!("CARGO_PKG_VERSION"); - ts_println!( - self, - "OpenAI Codex v{} (research preview)\n--------", - VERSION - ); - - let mut entries = vec![ - ("workdir", config.cwd.display().to_string()), - ("model", config.model.clone()), - ("provider", config.model_provider_id.clone()), - ("approval", format!("{:?}", config.approval_policy)), - ("sandbox", summarize_sandbox_policy(&config.sandbox_policy)), - ]; - if config.model_provider.wire_api == WireApi::Responses - && model_supports_reasoning_summaries(config) - { - entries.push(( - "reasoning effort", - config.model_reasoning_effort.to_string(), - )); - entries.push(( - "reasoning summaries", - config.model_reasoning_summary.to_string(), - )); - } - - for (key, value) in entries { - println!("{} {}", format!("{key}:").style(self.bold), value); - } - - println!("--------"); - - // Echo the prompt that will be sent to the agent so it is visible in the - // transcript/logs before any events come in. Note the prompt may have been - // read from stdin, so it may not be visible in the terminal otherwise. - ts_println!( - self, - "{}\n{}", - "User instructions:".style(self.bold).style(self.cyan), - prompt - ); - } - - pub(crate) fn process_event(&mut self, event: Event) { - let Event { id: _, msg } = event; - match msg { - EventMsg::Error(ErrorEvent { message }) => { - let prefix = "ERROR:".style(self.red); - ts_println!(self, "{prefix} {message}"); - } - EventMsg::BackgroundEvent(BackgroundEventEvent { message }) => { - ts_println!(self, "{}", message.style(self.dimmed)); - } - EventMsg::TaskStarted | EventMsg::TaskComplete(_) => { - // Ignore. - } - EventMsg::TokenCount(TokenUsage { total_tokens, .. }) => { - ts_println!(self, "tokens used: {total_tokens}"); - } - EventMsg::AgentMessage(AgentMessageEvent { message }) => { - ts_println!( - self, - "{}\n{message}", - "codex".style(self.bold).style(self.magenta) - ); - } - EventMsg::ExecCommandBegin(ExecCommandBeginEvent { - call_id, - command, - cwd, - }) => { - self.call_id_to_command.insert( - call_id.clone(), - ExecCommandBegin { - command: command.clone(), - start_time: Instant::now(), - }, - ); - ts_println!( - self, - "{} {} in {}", - "exec".style(self.magenta), - escape_command(&command).style(self.bold), - cwd.to_string_lossy(), - ); - } - EventMsg::ExecCommandEnd(ExecCommandEndEvent { - call_id, - stdout, - stderr, - exit_code, - }) => { - let exec_command = self.call_id_to_command.remove(&call_id); - let (duration, call) = if let Some(ExecCommandBegin { - command, - start_time, - }) = exec_command - { - ( - format!(" in {}", format_elapsed(start_time)), - format!("{}", escape_command(&command).style(self.bold)), - ) - } else { - ("".to_string(), format!("exec('{call_id}')")) - }; - - let output = if exit_code == 0 { stdout } else { stderr }; - let truncated_output = output - .lines() - .take(MAX_OUTPUT_LINES_FOR_EXEC_TOOL_CALL) - .collect::>() - .join("\n"); - match exit_code { - 0 => { - let title = format!("{call} succeeded{duration}:"); - ts_println!(self, "{}", title.style(self.green)); - } - _ => { - let title = format!("{call} exited {exit_code}{duration}:"); - ts_println!(self, "{}", title.style(self.red)); - } - } - println!("{}", truncated_output.style(self.dimmed)); - } - EventMsg::McpToolCallBegin(McpToolCallBeginEvent { - call_id, - server, - tool, - arguments, - }) => { - // Build fully-qualified tool name: server.tool - let fq_tool_name = format!("{server}.{tool}"); - - // Format arguments as compact JSON so they fit on one line. - let args_str = arguments - .as_ref() - .map(|v: &serde_json::Value| { - serde_json::to_string(v).unwrap_or_else(|_| v.to_string()) - }) - .unwrap_or_default(); - - let invocation = if args_str.is_empty() { - format!("{fq_tool_name}()") - } else { - format!("{fq_tool_name}({args_str})") - }; - - self.call_id_to_tool_call.insert( - call_id.clone(), - McpToolCallBegin { - invocation: invocation.clone(), - start_time: Instant::now(), - }, - ); - - ts_println!( - self, - "{} {}", - "tool".style(self.magenta), - invocation.style(self.bold), - ); - } - EventMsg::McpToolCallEnd(tool_call_end_event) => { - let is_success = tool_call_end_event.is_success(); - let McpToolCallEndEvent { call_id, result } = tool_call_end_event; - // Retrieve start time and invocation for duration calculation and labeling. - let info = self.call_id_to_tool_call.remove(&call_id); - - let (duration, invocation) = if let Some(McpToolCallBegin { - invocation, - start_time, - .. - }) = info - { - (format!(" in {}", format_elapsed(start_time)), invocation) - } else { - (String::new(), format!("tool('{call_id}')")) - }; - - let status_str = if is_success { "success" } else { "failed" }; - let title_style = if is_success { self.green } else { self.red }; - let title = format!("{invocation} {status_str}{duration}:"); - - ts_println!(self, "{}", title.style(title_style)); - - if let Ok(res) = result { - let val: serde_json::Value = res.into(); - let pretty = - serde_json::to_string_pretty(&val).unwrap_or_else(|_| val.to_string()); - - for line in pretty.lines().take(MAX_OUTPUT_LINES_FOR_EXEC_TOOL_CALL) { - println!("{}", line.style(self.dimmed)); - } - } - } - EventMsg::PatchApplyBegin(PatchApplyBeginEvent { - call_id, - auto_approved, - changes, - }) => { - // Store metadata so we can calculate duration later when we - // receive the corresponding PatchApplyEnd event. - self.call_id_to_patch.insert( - call_id.clone(), - PatchApplyBegin { - start_time: Instant::now(), - auto_approved, - }, - ); - - ts_println!( - self, - "{} auto_approved={}:", - "apply_patch".style(self.magenta), - auto_approved, - ); - - // Pretty-print the patch summary with colored diff markers so - // it’s easy to scan in the terminal output. - for (path, change) in changes.iter() { - match change { - FileChange::Add { content } => { - let header = format!( - "{} {}", - format_file_change(change), - path.to_string_lossy() - ); - println!("{}", header.style(self.magenta)); - for line in content.lines() { - println!("{}", line.style(self.green)); - } - } - FileChange::Delete => { - let header = format!( - "{} {}", - format_file_change(change), - path.to_string_lossy() - ); - println!("{}", header.style(self.magenta)); - } - FileChange::Update { - unified_diff, - move_path, - } => { - let header = if let Some(dest) = move_path { - format!( - "{} {} -> {}", - format_file_change(change), - path.to_string_lossy(), - dest.to_string_lossy() - ) - } else { - format!("{} {}", format_file_change(change), path.to_string_lossy()) - }; - println!("{}", header.style(self.magenta)); - - // Colorize diff lines. We keep file header lines - // (--- / +++) without extra coloring so they are - // still readable. - for diff_line in unified_diff.lines() { - if diff_line.starts_with('+') && !diff_line.starts_with("+++") { - println!("{}", diff_line.style(self.green)); - } else if diff_line.starts_with('-') - && !diff_line.starts_with("---") - { - println!("{}", diff_line.style(self.red)); - } else { - println!("{diff_line}"); - } - } - } - } - } - } - EventMsg::PatchApplyEnd(PatchApplyEndEvent { - call_id, - stdout, - stderr, - success, - }) => { - let patch_begin = self.call_id_to_patch.remove(&call_id); - - // Compute duration and summary label similar to exec commands. - let (duration, label) = if let Some(PatchApplyBegin { - start_time, - auto_approved, - }) = patch_begin - { - ( - format!(" in {}", format_elapsed(start_time)), - format!("apply_patch(auto_approved={auto_approved})"), - ) - } else { - (String::new(), format!("apply_patch('{call_id}')")) - }; - - let (exit_code, output, title_style) = if success { - (0, stdout, self.green) - } else { - (1, stderr, self.red) - }; - - let title = format!("{label} exited {exit_code}{duration}:"); - ts_println!(self, "{}", title.style(title_style)); - for line in output.lines() { - println!("{}", line.style(self.dimmed)); - } - } - EventMsg::ExecApprovalRequest(_) => { - // Should we exit? - } - EventMsg::ApplyPatchApprovalRequest(_) => { - // Should we exit? - } - EventMsg::AgentReasoning(agent_reasoning_event) => { - if self.show_agent_reasoning { - ts_println!( - self, - "{}\n{}", - "thinking".style(self.italic).style(self.magenta), - agent_reasoning_event.text - ); - } - } - EventMsg::AgentReasoningContent(agent_reasoning_event) => { - if self.show_agent_reasoning && self.show_agent_reasoning_content { - ts_println!( - self, - "{}\n{}", - "thinking".style(self.italic).style(self.magenta), - agent_reasoning_event.text - ); - } - } - EventMsg::SessionConfigured(session_configured_event) => { - let SessionConfiguredEvent { - session_id, - model, - history_log_id: _, - history_entry_count: _, - } = session_configured_event; - - ts_println!( - self, - "{} {}", - "codex session".style(self.magenta).style(self.bold), - session_id.to_string().style(self.dimmed) - ); - - ts_println!(self, "model: {}", model); - println!(); - } - EventMsg::GetHistoryEntryResponse(_) => { - // Currently ignored in exec output. - } +fn write_last_message_file(contents: &str, last_message_path: Option<&Path>) { + if let Some(path) = last_message_path { + if let Err(e) = std::fs::write(path, contents) { + eprintln!("Failed to write last message file {path:?}: {e}"); } } } - -fn escape_command(command: &[String]) -> String { - try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" ")) -} - -fn format_file_change(change: &FileChange) -> &'static str { - match change { - FileChange::Add { .. } => "A", - FileChange::Delete => "D", - FileChange::Update { - move_path: Some(_), .. - } => "R", - FileChange::Update { - move_path: None, .. - } => "M", - } -} diff --git a/codex-rs/exec/src/event_processor_with_human_output.rs b/codex-rs/exec/src/event_processor_with_human_output.rs new file mode 100644 index 0000000000..54604d538b --- /dev/null +++ b/codex-rs/exec/src/event_processor_with_human_output.rs @@ -0,0 +1,520 @@ +use codex_common::elapsed::format_duration; +use codex_common::elapsed::format_elapsed; +use codex_core::config::Config; +use codex_core::plan_tool::UpdatePlanArgs; +use codex_core::protocol::AgentMessageDeltaEvent; +use codex_core::protocol::AgentMessageEvent; +use codex_core::protocol::AgentReasoningDeltaEvent; +use codex_core::protocol::BackgroundEventEvent; +use codex_core::protocol::ErrorEvent; +use codex_core::protocol::Event; +use codex_core::protocol::EventMsg; +use codex_core::protocol::ExecCommandBeginEvent; +use codex_core::protocol::ExecCommandEndEvent; +use codex_core::protocol::FileChange; +use codex_core::protocol::McpInvocation; +use codex_core::protocol::McpToolCallBeginEvent; +use codex_core::protocol::McpToolCallEndEvent; +use codex_core::protocol::PatchApplyBeginEvent; +use codex_core::protocol::PatchApplyEndEvent; +use codex_core::protocol::SessionConfiguredEvent; +use codex_core::protocol::TaskCompleteEvent; +use codex_core::protocol::TokenUsage; +use owo_colors::OwoColorize; +use owo_colors::Style; +use shlex::try_join; +use std::collections::HashMap; +use std::io::Write; +use std::path::PathBuf; +use std::time::Instant; + +use crate::event_processor::CodexStatus; +use crate::event_processor::EventProcessor; +use crate::event_processor::create_config_summary_entries; +use crate::event_processor::handle_last_message; + +/// This should be configurable. When used in CI, users may not want to impose +/// a limit so they can see the full transcript. +const MAX_OUTPUT_LINES_FOR_EXEC_TOOL_CALL: usize = 20; +pub(crate) struct EventProcessorWithHumanOutput { + call_id_to_command: HashMap, + call_id_to_patch: HashMap, + + // To ensure that --color=never is respected, ANSI escapes _must_ be added + // using .style() with one of these fields. If you need a new style, add a + // new field here. + bold: Style, + italic: Style, + dimmed: Style, + + magenta: Style, + red: Style, + green: Style, + cyan: Style, + + /// Whether to include `AgentReasoning` events in the output. + show_agent_reasoning: bool, + answer_started: bool, + reasoning_started: bool, + last_message_path: Option, +} + +impl EventProcessorWithHumanOutput { + pub(crate) fn create_with_ansi( + with_ansi: bool, + config: &Config, + last_message_path: Option, + ) -> Self { + let call_id_to_command = HashMap::new(); + let call_id_to_patch = HashMap::new(); + + if with_ansi { + Self { + call_id_to_command, + call_id_to_patch, + bold: Style::new().bold(), + italic: Style::new().italic(), + dimmed: Style::new().dimmed(), + magenta: Style::new().magenta(), + red: Style::new().red(), + green: Style::new().green(), + cyan: Style::new().cyan(), + show_agent_reasoning: !config.hide_agent_reasoning, + answer_started: false, + reasoning_started: false, + last_message_path, + } + } else { + Self { + call_id_to_command, + call_id_to_patch, + bold: Style::new(), + italic: Style::new(), + dimmed: Style::new(), + magenta: Style::new(), + red: Style::new(), + green: Style::new(), + cyan: Style::new(), + show_agent_reasoning: !config.hide_agent_reasoning, + answer_started: false, + reasoning_started: false, + last_message_path, + } + } + } +} + +struct ExecCommandBegin { + command: Vec, + start_time: Instant, +} + +struct PatchApplyBegin { + start_time: Instant, + auto_approved: bool, +} + +// Timestamped println helper. The timestamp is styled with self.dimmed. +#[macro_export] +macro_rules! ts_println { + ($self:ident, $($arg:tt)*) => {{ + let now = chrono::Utc::now(); + let formatted = now.format("[%Y-%m-%dT%H:%M:%S]"); + print!("{} ", formatted.style($self.dimmed)); + println!($($arg)*); + }}; +} + +impl EventProcessor for EventProcessorWithHumanOutput { + /// Print a concise summary of the effective configuration that will be used + /// for the session. This mirrors the information shown in the TUI welcome + /// screen. + fn print_config_summary(&mut self, config: &Config, prompt: &str) { + const VERSION: &str = env!("CARGO_PKG_VERSION"); + ts_println!( + self, + "OpenAI Codex v{} (research preview)\n--------", + VERSION + ); + + let entries = create_config_summary_entries(config); + + for (key, value) in entries { + println!("{} {}", format!("{key}:").style(self.bold), value); + } + + println!("--------"); + + // Echo the prompt that will be sent to the agent so it is visible in the + // transcript/logs before any events come in. Note the prompt may have been + // read from stdin, so it may not be visible in the terminal otherwise. + ts_println!( + self, + "{}\n{}", + "User instructions:".style(self.bold).style(self.cyan), + prompt + ); + } + + fn process_event(&mut self, event: Event) -> CodexStatus { + let Event { id: _, msg } = event; + match msg { + EventMsg::Error(ErrorEvent { message }) => { + let prefix = "ERROR:".style(self.red); + ts_println!(self, "{prefix} {message}"); + } + EventMsg::BackgroundEvent(BackgroundEventEvent { message }) => { + ts_println!(self, "{}", message.style(self.dimmed)); + } + EventMsg::TaskStarted => { + // Ignore. + } + EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => { + handle_last_message( + last_agent_message.as_deref(), + self.last_message_path.as_deref(), + ); + return CodexStatus::InitiateShutdown; + } + EventMsg::TokenCount(TokenUsage { total_tokens, .. }) => { + ts_println!(self, "tokens used: {total_tokens}"); + } + EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }) => { + if !self.answer_started { + ts_println!(self, "{}\n", "codex".style(self.italic).style(self.magenta)); + self.answer_started = true; + } + print!("{delta}"); + #[allow(clippy::expect_used)] + std::io::stdout().flush().expect("could not flush stdout"); + } + EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta }) => { + if !self.show_agent_reasoning { + return CodexStatus::Running; + } + if !self.reasoning_started { + ts_println!( + self, + "{}\n", + "thinking".style(self.italic).style(self.magenta), + ); + self.reasoning_started = true; + } + print!("{delta}"); + #[allow(clippy::expect_used)] + std::io::stdout().flush().expect("could not flush stdout"); + } + EventMsg::AgentMessage(AgentMessageEvent { message }) => { + // if answer_started is false, this means we haven't received any + // delta. Thus, we need to print the message as a new answer. + if !self.answer_started { + ts_println!( + self, + "{}\n{}", + "codex".style(self.italic).style(self.magenta), + message, + ); + } else { + println!(); + self.answer_started = false; + } + } + EventMsg::ExecCommandBegin(ExecCommandBeginEvent { + call_id, + command, + cwd, + }) => { + self.call_id_to_command.insert( + call_id.clone(), + ExecCommandBegin { + command: command.clone(), + start_time: Instant::now(), + }, + ); + ts_println!( + self, + "{} {} in {}", + "exec".style(self.magenta), + escape_command(&command).style(self.bold), + cwd.to_string_lossy(), + ); + } + EventMsg::ExecCommandEnd(ExecCommandEndEvent { + call_id, + stdout, + stderr, + exit_code, + }) => { + let exec_command = self.call_id_to_command.remove(&call_id); + let (duration, call) = if let Some(ExecCommandBegin { + command, + start_time, + }) = exec_command + { + ( + format!(" in {}", format_elapsed(start_time)), + format!("{}", escape_command(&command).style(self.bold)), + ) + } else { + ("".to_string(), format!("exec('{call_id}')")) + }; + + let output = if exit_code == 0 { stdout } else { stderr }; + let truncated_output = output + .lines() + .take(MAX_OUTPUT_LINES_FOR_EXEC_TOOL_CALL) + .collect::>() + .join("\n"); + match exit_code { + 0 => { + let title = format!("{call} succeeded{duration}:"); + ts_println!(self, "{}", title.style(self.green)); + } + _ => { + let title = format!("{call} exited {exit_code}{duration}:"); + ts_println!(self, "{}", title.style(self.red)); + } + } + println!("{}", truncated_output.style(self.dimmed)); + } + EventMsg::McpToolCallBegin(McpToolCallBeginEvent { + call_id: _, + invocation, + }) => { + ts_println!( + self, + "{} {}", + "tool".style(self.magenta), + format_mcp_invocation(&invocation).style(self.bold), + ); + } + EventMsg::McpToolCallEnd(tool_call_end_event) => { + let is_success = tool_call_end_event.is_success(); + let McpToolCallEndEvent { + call_id: _, + result, + invocation, + duration, + } = tool_call_end_event; + + let duration = format!(" in {}", format_duration(duration)); + + let status_str = if is_success { "success" } else { "failed" }; + let title_style = if is_success { self.green } else { self.red }; + let title = format!( + "{} {status_str}{duration}:", + format_mcp_invocation(&invocation) + ); + + ts_println!(self, "{}", title.style(title_style)); + + if let Ok(res) = result { + let val: serde_json::Value = res.into(); + let pretty = + serde_json::to_string_pretty(&val).unwrap_or_else(|_| val.to_string()); + + for line in pretty.lines().take(MAX_OUTPUT_LINES_FOR_EXEC_TOOL_CALL) { + println!("{}", line.style(self.dimmed)); + } + } + } + EventMsg::PatchApplyBegin(PatchApplyBeginEvent { + call_id, + auto_approved, + changes, + }) => { + // Store metadata so we can calculate duration later when we + // receive the corresponding PatchApplyEnd event. + self.call_id_to_patch.insert( + call_id.clone(), + PatchApplyBegin { + start_time: Instant::now(), + auto_approved, + }, + ); + + ts_println!( + self, + "{} auto_approved={}:", + "apply_patch".style(self.magenta), + auto_approved, + ); + + // Pretty-print the patch summary with colored diff markers so + // it's easy to scan in the terminal output. + for (path, change) in changes.iter() { + match change { + FileChange::Add { content } => { + let header = format!( + "{} {}", + format_file_change(change), + path.to_string_lossy() + ); + println!("{}", header.style(self.magenta)); + for line in content.lines() { + println!("{}", line.style(self.green)); + } + } + FileChange::Delete => { + let header = format!( + "{} {}", + format_file_change(change), + path.to_string_lossy() + ); + println!("{}", header.style(self.magenta)); + } + FileChange::Update { + unified_diff, + move_path, + } => { + let header = if let Some(dest) = move_path { + format!( + "{} {} -> {}", + format_file_change(change), + path.to_string_lossy(), + dest.to_string_lossy() + ) + } else { + format!("{} {}", format_file_change(change), path.to_string_lossy()) + }; + println!("{}", header.style(self.magenta)); + + // Colorize diff lines. We keep file header lines + // (--- / +++) without extra coloring so they are + // still readable. + for diff_line in unified_diff.lines() { + if diff_line.starts_with('+') && !diff_line.starts_with("+++") { + println!("{}", diff_line.style(self.green)); + } else if diff_line.starts_with('-') + && !diff_line.starts_with("---") + { + println!("{}", diff_line.style(self.red)); + } else { + println!("{diff_line}"); + } + } + } + } + } + } + EventMsg::PatchApplyEnd(PatchApplyEndEvent { + call_id, + stdout, + stderr, + success, + }) => { + let patch_begin = self.call_id_to_patch.remove(&call_id); + + // Compute duration and summary label similar to exec commands. + let (duration, label) = if let Some(PatchApplyBegin { + start_time, + auto_approved, + }) = patch_begin + { + ( + format!(" in {}", format_elapsed(start_time)), + format!("apply_patch(auto_approved={auto_approved})"), + ) + } else { + (String::new(), format!("apply_patch('{call_id}')")) + }; + + let (exit_code, output, title_style) = if success { + (0, stdout, self.green) + } else { + (1, stderr, self.red) + }; + + let title = format!("{label} exited {exit_code}{duration}:"); + ts_println!(self, "{}", title.style(title_style)); + for line in output.lines() { + println!("{}", line.style(self.dimmed)); + } + } + EventMsg::ExecApprovalRequest(_) => { + // Should we exit? + } + EventMsg::ApplyPatchApprovalRequest(_) => { + // Should we exit? + } + EventMsg::AgentReasoning(agent_reasoning_event) => { + if self.show_agent_reasoning { + if !self.reasoning_started { + ts_println!( + self, + "{}\n{}", + "codex".style(self.italic).style(self.magenta), + agent_reasoning_event.text, + ); + } else { + println!(); + self.reasoning_started = false; + } + } + } + EventMsg::SessionConfigured(session_configured_event) => { + let SessionConfiguredEvent { + session_id, + model, + history_log_id: _, + history_entry_count: _, + } = session_configured_event; + + ts_println!( + self, + "{} {}", + "codex session".style(self.magenta).style(self.bold), + session_id.to_string().style(self.dimmed) + ); + + ts_println!(self, "model: {}", model); + println!(); + } + EventMsg::PlanUpdate(plan_update_event) => { + let UpdatePlanArgs { explanation, plan } = plan_update_event; + ts_println!(self, "explanation: {explanation:?}"); + ts_println!(self, "plan: {plan:?}"); + } + EventMsg::GetHistoryEntryResponse(_) => { + // Currently ignored in exec output. + } + EventMsg::ShutdownComplete => return CodexStatus::Shutdown, + } + CodexStatus::Running + } +} + +fn escape_command(command: &[String]) -> String { + try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" ")) +} + +fn format_file_change(change: &FileChange) -> &'static str { + match change { + FileChange::Add { .. } => "A", + FileChange::Delete => "D", + FileChange::Update { + move_path: Some(_), .. + } => "R", + FileChange::Update { + move_path: None, .. + } => "M", + } +} + +fn format_mcp_invocation(invocation: &McpInvocation) -> String { + // Build fully-qualified tool name: server.tool + let fq_tool_name = format!("{}.{}", invocation.server, invocation.tool); + + // Format arguments as compact JSON so they fit on one line. + let args_str = invocation + .arguments + .as_ref() + .map(|v: &serde_json::Value| serde_json::to_string(v).unwrap_or_else(|_| v.to_string())) + .unwrap_or_default(); + + if args_str.is_empty() { + format!("{fq_tool_name}()") + } else { + format!("{fq_tool_name}({args_str})") + } +} diff --git a/codex-rs/exec/src/event_processor_with_json_output.rs b/codex-rs/exec/src/event_processor_with_json_output.rs new file mode 100644 index 0000000000..e7a658b76f --- /dev/null +++ b/codex-rs/exec/src/event_processor_with_json_output.rs @@ -0,0 +1,64 @@ +use std::collections::HashMap; +use std::path::PathBuf; + +use codex_core::config::Config; +use codex_core::protocol::Event; +use codex_core::protocol::EventMsg; +use codex_core::protocol::TaskCompleteEvent; +use serde_json::json; + +use crate::event_processor::CodexStatus; +use crate::event_processor::EventProcessor; +use crate::event_processor::create_config_summary_entries; +use crate::event_processor::handle_last_message; + +pub(crate) struct EventProcessorWithJsonOutput { + last_message_path: Option, +} + +impl EventProcessorWithJsonOutput { + pub fn new(last_message_path: Option) -> Self { + Self { last_message_path } + } +} + +impl EventProcessor for EventProcessorWithJsonOutput { + fn print_config_summary(&mut self, config: &Config, prompt: &str) { + let entries = create_config_summary_entries(config) + .into_iter() + .map(|(key, value)| (key.to_string(), value)) + .collect::>(); + #[allow(clippy::expect_used)] + let config_json = + serde_json::to_string(&entries).expect("Failed to serialize config summary to JSON"); + println!("{config_json}"); + + let prompt_json = json!({ + "prompt": prompt, + }); + println!("{prompt_json}"); + } + + fn process_event(&mut self, event: Event) -> CodexStatus { + match event.msg { + EventMsg::AgentMessageDelta(_) | EventMsg::AgentReasoningDelta(_) => { + // Suppress streaming events in JSON mode. + CodexStatus::Running + } + EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => { + handle_last_message( + last_agent_message.as_deref(), + self.last_message_path.as_deref(), + ); + CodexStatus::InitiateShutdown + } + EventMsg::ShutdownComplete => CodexStatus::Shutdown, + _ => { + if let Ok(line) = serde_json::to_string(&event) { + println!("{line}"); + } + CodexStatus::Running + } + } + } +} diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index 85f154c5aa..ce4d7f65cc 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -1,14 +1,16 @@ mod cli; mod event_processor; +mod event_processor_with_human_output; +mod event_processor_with_json_output; use std::io::IsTerminal; use std::io::Read; -use std::path::Path; use std::path::PathBuf; use std::sync::Arc; pub use cli::Cli; -use codex_core::codex_wrapper; +use codex_core::codex_wrapper::CodexConversation; +use codex_core::codex_wrapper::{self}; use codex_core::config::Config; use codex_core::config::ConfigOverrides; use codex_core::config_types::SandboxMode; @@ -19,12 +21,16 @@ use codex_core::protocol::InputItem; use codex_core::protocol::Op; use codex_core::protocol::TaskCompleteEvent; use codex_core::util::is_inside_git_repo; -use event_processor::EventProcessor; +use event_processor_with_human_output::EventProcessorWithHumanOutput; +use event_processor_with_json_output::EventProcessorWithJsonOutput; use tracing::debug; use tracing::error; use tracing::info; use tracing_subscriber::EnvFilter; +use crate::event_processor::CodexStatus; +use crate::event_processor::EventProcessor; + pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> anyhow::Result<()> { let Cli { images, @@ -36,6 +42,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any skip_git_repo_check, color, last_message_file, + json: json_mode, sandbox_mode: sandbox_mode_cli_arg, prompt, config_overrides, @@ -85,6 +92,20 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any ), }; + // TODO(mbolin): Take a more thoughtful approach to logging. + let default_level = "error"; + let _ = tracing_subscriber::fmt() + // Fallback to the `default_level` log filter if the environment + // variable is not set _or_ contains an invalid value + .with_env_filter( + EnvFilter::try_from_default_env() + .or_else(|_| EnvFilter::try_new(default_level)) + .unwrap_or_else(|_| EnvFilter::new(default_level)), + ) + .with_ansi(stderr_with_ansi) + .with_writer(std::io::stderr) + .try_init(); + let sandbox_mode = if full_auto { Some(SandboxMode::WorkspaceWrite) } else if dangerously_bypass_approvals_and_sandbox { @@ -104,6 +125,8 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)), model_provider: None, codex_linux_sandbox_exe, + base_instructions: None, + include_plan_tool: None, }; // Parse `-c` overrides. let cli_kv_overrides = match config_overrides.parse_overrides() { @@ -115,11 +138,16 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any }; let config = Config::load_with_cli_overrides(cli_kv_overrides, overrides)?; - let mut event_processor = EventProcessor::create_with_ansi( - stdout_with_ansi, - !config.hide_agent_reasoning, - config.show_reasoning_content, - ); + let mut event_processor: Box = if json_mode { + Box::new(EventProcessorWithJsonOutput::new(last_message_file.clone())) + } else { + Box::new(EventProcessorWithHumanOutput::create_with_ansi( + stdout_with_ansi, + &config, + last_message_file.clone(), + )) + }; + // Print the effective configuration and prompt so users can see what Codex // is using. event_processor.print_config_summary(&config, &prompt); @@ -129,23 +157,14 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any std::process::exit(1); } - // TODO(mbolin): Take a more thoughtful approach to logging. - let default_level = "error"; - let _ = tracing_subscriber::fmt() - // Fallback to the `default_level` log filter if the environment - // variable is not set _or_ contains an invalid value - .with_env_filter( - EnvFilter::try_from_default_env() - .or_else(|_| EnvFilter::try_new(default_level)) - .unwrap_or_else(|_| EnvFilter::new(default_level)), - ) - .with_ansi(stderr_with_ansi) - .with_writer(std::io::stderr) - .try_init(); - - let (codex_wrapper, event, ctrl_c) = codex_wrapper::init_codex(config).await?; + let CodexConversation { + codex: codex_wrapper, + session_configured, + ctrl_c, + .. + } = codex_wrapper::init_codex(config).await?; let codex = Arc::new(codex_wrapper); - info!("Codex initialized with event: {event:?}"); + info!("Codex initialized with event: {session_configured:?}"); let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); { @@ -213,40 +232,17 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any // Run the loop until the task is complete. while let Some(event) = rx.recv().await { - let (is_last_event, last_assistant_message) = match &event.msg { - EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => { - (true, last_agent_message.clone()) + let shutdown: CodexStatus = event_processor.process_event(event); + match shutdown { + CodexStatus::Running => continue, + CodexStatus::InitiateShutdown => { + codex.submit(Op::Shutdown).await?; + } + CodexStatus::Shutdown => { + break; } - _ => (false, None), - }; - event_processor.process_event(event); - if is_last_event { - handle_last_message(last_assistant_message, last_message_file.as_deref())?; - break; } } Ok(()) } - -fn handle_last_message( - last_agent_message: Option, - last_message_file: Option<&Path>, -) -> std::io::Result<()> { - match (last_agent_message, last_message_file) { - (Some(last_agent_message), Some(last_message_file)) => { - // Last message and a file to write to. - std::fs::write(last_message_file, last_agent_message)?; - } - (None, Some(last_message_file)) => { - eprintln!( - "Warning: No last message to write to file: {}", - last_message_file.to_string_lossy() - ); - } - (_, None) => { - // No last message and no file to write to. - } - } - Ok(()) -} diff --git a/codex-rs/exec/src/main.rs b/codex-rs/exec/src/main.rs index 3a8e1f9411..03ee533ea9 100644 --- a/codex-rs/exec/src/main.rs +++ b/codex-rs/exec/src/main.rs @@ -10,6 +10,7 @@ //! This allows us to ship a completely separate set of functionality as part //! of the `codex-exec` binary. use clap::Parser; +use codex_arg0::arg0_dispatch_or_else; use codex_common::CliConfigOverrides; use codex_exec::Cli; use codex_exec::run_main; @@ -24,7 +25,7 @@ struct TopCli { } fn main() -> anyhow::Result<()> { - codex_linux_sandbox::run_with_sandbox(|codex_linux_sandbox_exe| async move { + arg0_dispatch_or_else(|codex_linux_sandbox_exe| async move { let top_cli = TopCli::parse(); // Merge root-level overrides into inner CLI struct so downstream logic remains unchanged. let mut inner = top_cli.inner; diff --git a/codex-rs/exec/tests/apply_patch.rs b/codex-rs/exec/tests/apply_patch.rs new file mode 100644 index 0000000000..f65d32e1c8 --- /dev/null +++ b/codex-rs/exec/tests/apply_patch.rs @@ -0,0 +1,39 @@ +use anyhow::Context; +use assert_cmd::prelude::*; +use codex_core::CODEX_APPLY_PATCH_ARG1; +use std::fs; +use std::process::Command; +use tempfile::tempdir; + +/// While we may add an `apply-patch` subcommand to the `codex` CLI multitool +/// at some point, we must ensure that the smaller `codex-exec` CLI can still +/// emulate the `apply_patch` CLI. +#[test] +fn test_standalone_exec_cli_can_use_apply_patch() -> anyhow::Result<()> { + let tmp = tempdir()?; + let relative_path = "source.txt"; + let absolute_path = tmp.path().join(relative_path); + fs::write(&absolute_path, "original content\n")?; + + Command::cargo_bin("codex-exec") + .context("should find binary for codex-exec")? + .arg(CODEX_APPLY_PATCH_ARG1) + .arg( + r#"*** Begin Patch +*** Update File: source.txt +@@ +-original content ++modified by apply_patch +*** End Patch"#, + ) + .current_dir(tmp.path()) + .assert() + .success() + .stdout("Success. Updated the following files:\nM source.txt\n") + .stderr(predicates::str::is_empty()); + assert_eq!( + fs::read_to_string(absolute_path)?, + "modified by apply_patch\n" + ); + Ok(()) +} diff --git a/codex-rs/file-search/Cargo.toml b/codex-rs/file-search/Cargo.toml index bb5b80b2cf..3f70377183 100644 --- a/codex-rs/file-search/Cargo.toml +++ b/codex-rs/file-search/Cargo.toml @@ -1,7 +1,7 @@ [package] +edition = "2024" name = "codex-file-search" version = { workspace = true } -edition = "2024" [[bin]] name = "codex-file-search" diff --git a/codex-rs/justfile b/codex-rs/justfile index 83a390ec56..3e1336be43 100644 --- a/codex-rs/justfile +++ b/codex-rs/justfile @@ -23,3 +23,10 @@ file-search *args: # format code fmt: cargo fmt -- --config imports_granularity=Item + +fix: + cargo clippy --fix --all-features --tests --allow-dirty + +install: + rustup show active-toolchain + cargo fetch diff --git a/codex-rs/linux-sandbox/Cargo.toml b/codex-rs/linux-sandbox/Cargo.toml index c8cd1078c0..ea7052c409 100644 --- a/codex-rs/linux-sandbox/Cargo.toml +++ b/codex-rs/linux-sandbox/Cargo.toml @@ -1,7 +1,7 @@ [package] +edition = "2024" name = "codex-linux-sandbox" version = { workspace = true } -edition = "2024" [[bin]] name = "codex-linux-sandbox" @@ -14,13 +14,16 @@ path = "src/lib.rs" [lints] workspace = true -[dependencies] +[target.'cfg(target_os = "linux")'.dependencies] anyhow = "1" clap = { version = "4", features = ["derive"] } +codex-common = { path = "../common", features = ["cli"] } codex-core = { path = "../core" } -tokio = { version = "1", features = ["rt-multi-thread"] } +landlock = "0.4.1" +libc = "0.2.172" +seccompiler = "0.5.0" -[dev-dependencies] +[target.'cfg(target_os = "linux")'.dev-dependencies] tempfile = "3" tokio = { version = "1", features = [ "io-std", @@ -29,8 +32,3 @@ tokio = { version = "1", features = [ "rt-multi-thread", "signal", ] } - -[target.'cfg(target_os = "linux")'.dependencies] -libc = "0.2.172" -landlock = "0.4.1" -seccompiler = "0.5.0" diff --git a/codex-rs/linux-sandbox/src/lib.rs b/codex-rs/linux-sandbox/src/lib.rs index 568f015822..80453c7f96 100644 --- a/codex-rs/linux-sandbox/src/lib.rs +++ b/codex-rs/linux-sandbox/src/lib.rs @@ -4,57 +4,8 @@ mod landlock; mod linux_run_main; #[cfg(target_os = "linux")] -pub use linux_run_main::run_main; - -use std::future::Future; -use std::path::PathBuf; - -/// Helper that consolidates the common boilerplate found in several Codex -/// binaries (`codex`, `codex-exec`, `codex-tui`) around dispatching to the -/// `codex-linux-sandbox` sub-command. -/// -/// When the current executable is invoked through the hard-link or alias -/// named `codex-linux-sandbox` we *directly* execute [`run_main`](crate::run_main) -/// (which never returns). Otherwise we: -/// 1. Construct a Tokio multi-thread runtime. -/// 2. Derive the path to the current executable (so children can re-invoke -/// the sandbox) when running on Linux. -/// 3. Execute the provided async `main_fn` inside that runtime, forwarding -/// any error. -/// -/// This function eliminates duplicated code across the various `main.rs` -/// entry-points. -pub fn run_with_sandbox(main_fn: F) -> anyhow::Result<()> -where - F: FnOnce(Option) -> Fut, - Fut: Future>, -{ - use std::path::Path; - - // Determine if we were invoked via the special alias. - let argv0 = std::env::args().next().unwrap_or_default(); - let exe_name = Path::new(&argv0) - .file_name() - .and_then(|s| s.to_str()) - .unwrap_or(""); - - if exe_name == "codex-linux-sandbox" { - // Safety: [`run_main`] never returns. - crate::run_main(); - } - - // Regular invocation – create a Tokio runtime and execute the provided - // async entry-point. - let runtime = tokio::runtime::Runtime::new()?; - runtime.block_on(async move { - let codex_linux_sandbox_exe: Option = if cfg!(target_os = "linux") { - std::env::current_exe().ok() - } else { - None - }; - - main_fn(codex_linux_sandbox_exe).await - }) +pub fn run_main() -> ! { + linux_run_main::run_main(); } #[cfg(not(target_os = "linux"))] diff --git a/codex-rs/login/Cargo.toml b/codex-rs/login/Cargo.toml index e6eba6fd4f..e10666b092 100644 --- a/codex-rs/login/Cargo.toml +++ b/codex-rs/login/Cargo.toml @@ -1,7 +1,7 @@ [package] +edition = "2024" name = "codex-login" version = { workspace = true } -edition = "2024" [lints] workspace = true diff --git a/codex-rs/login/src/lib.rs b/codex-rs/login/src/lib.rs index 99d2f7f983..47dbbca9fb 100644 --- a/codex-rs/login/src/lib.rs +++ b/codex-rs/login/src/lib.rs @@ -1,19 +1,152 @@ use chrono::DateTime; + use chrono::Utc; use serde::Deserialize; use serde::Serialize; +use std::env; use std::fs::OpenOptions; use std::io::Read; use std::io::Write; #[cfg(unix)] use std::os::unix::fs::OpenOptionsExt; use std::path::Path; +use std::path::PathBuf; use std::process::Stdio; +use std::sync::Arc; +use std::sync::Mutex; +use std::time::Duration; use tokio::process::Command; const SOURCE_FOR_PYTHON_SERVER: &str = include_str!("./login_with_chatgpt.py"); const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann"; +const OPENAI_API_KEY_ENV_VAR: &str = "OPENAI_API_KEY"; + +#[derive(Clone, Debug, PartialEq)] +pub enum AuthMode { + ApiKey, + ChatGPT, +} + +#[derive(Debug, Clone)] +pub struct CodexAuth { + pub api_key: Option, + pub mode: AuthMode, + auth_dot_json: Arc>>, + auth_file: PathBuf, +} + +impl PartialEq for CodexAuth { + fn eq(&self, other: &Self) -> bool { + self.mode == other.mode + } +} + +impl CodexAuth { + pub fn new( + api_key: Option, + mode: AuthMode, + auth_file: PathBuf, + auth_dot_json: Option, + ) -> Self { + let auth_dot_json = Arc::new(Mutex::new(auth_dot_json)); + Self { + api_key, + mode, + auth_file, + auth_dot_json, + } + } + + pub fn from_api_key(api_key: String) -> Self { + Self { + api_key: Some(api_key), + mode: AuthMode::ApiKey, + auth_file: PathBuf::new(), + auth_dot_json: Arc::new(Mutex::new(None)), + } + } + + pub async fn get_token_data(&self) -> Result { + #[expect(clippy::unwrap_used)] + let auth_dot_json = self.auth_dot_json.lock().unwrap().clone(); + + match auth_dot_json { + Some(auth_dot_json) => { + if auth_dot_json.last_refresh < Utc::now() - chrono::Duration::days(28) { + let refresh_response = tokio::time::timeout( + Duration::from_secs(60), + try_refresh_token(auth_dot_json.tokens.refresh_token.clone()), + ) + .await + .map_err(|_| { + std::io::Error::other("timed out while refreshing OpenAI API key") + })? + .map_err(std::io::Error::other)?; + + let updated_auth_dot_json = update_tokens( + &self.auth_file, + refresh_response.id_token, + refresh_response.access_token, + refresh_response.refresh_token, + ) + .await?; + + #[expect(clippy::unwrap_used)] + let mut auth_dot_json = self.auth_dot_json.lock().unwrap(); + *auth_dot_json = Some(updated_auth_dot_json); + } + Ok(auth_dot_json.tokens.clone()) + } + None => Err(std::io::Error::other("Token data is not available.")), + } + } + + pub async fn get_token(&self) -> Result { + match self.mode { + AuthMode::ApiKey => Ok(self.api_key.clone().unwrap_or_default()), + AuthMode::ChatGPT => { + let id_token = self.get_token_data().await?.access_token; + + Ok(id_token) + } + } + } +} + +// Loads the available auth information from the auth.json or OPENAI_API_KEY environment variable. +pub fn load_auth(codex_home: &Path) -> std::io::Result> { + let auth_file = codex_home.join("auth.json"); + + let auth_dot_json = try_read_auth_json(&auth_file).ok(); + + let auth_json_api_key = auth_dot_json + .as_ref() + .and_then(|a| a.openai_api_key.clone()) + .filter(|s| !s.is_empty()); + + let openai_api_key = env::var(OPENAI_API_KEY_ENV_VAR) + .ok() + .filter(|s| !s.is_empty()) + .or(auth_json_api_key); + + if openai_api_key.is_none() && auth_dot_json.is_none() { + return Ok(None); + } + + let mode = if openai_api_key.is_some() { + AuthMode::ApiKey + } else { + AuthMode::ChatGPT + }; + + Ok(Some(CodexAuth { + api_key: openai_api_key, + mode, + auth_file, + auth_dot_json: Arc::new(Mutex::new(auth_dot_json)), + })) +} /// Run `python3 -c {{SOURCE_FOR_PYTHON_SERVER}}` with the CODEX_HOME /// environment variable set to the provided `codex_home` path. If the @@ -24,14 +157,12 @@ const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann"; /// If `capture_output` is true, the subprocess's output will be captured and /// recorded in memory. Otherwise, the subprocess's output will be sent to the /// current process's stdout/stderr. -pub async fn login_with_chatgpt( - codex_home: &Path, - capture_output: bool, -) -> std::io::Result { +pub async fn login_with_chatgpt(codex_home: &Path, capture_output: bool) -> std::io::Result<()> { let child = Command::new("python3") .arg("-c") .arg(SOURCE_FOR_PYTHON_SERVER) .env("CODEX_HOME", codex_home) + .env("CODEX_CLIENT_ID", CLIENT_ID) .stdin(Stdio::null()) .stdout(if capture_output { Stdio::piped() @@ -47,7 +178,7 @@ pub async fn login_with_chatgpt( let output = child.wait_with_output().await?; if output.status.success() { - try_read_openai_api_key(codex_home).await + Ok(()) } else { let stderr = String::from_utf8_lossy(&output.stderr); Err(std::io::Error::other(format!( @@ -56,61 +187,54 @@ pub async fn login_with_chatgpt( } } -/// Attempt to read the `OPENAI_API_KEY` from the `auth.json` file in the given -/// `CODEX_HOME` directory, refreshing it, if necessary. -pub async fn try_read_openai_api_key(codex_home: &Path) -> std::io::Result { - let auth_dot_json = try_read_auth_json(codex_home).await?; - Ok(auth_dot_json.openai_api_key) -} - /// Attempt to read and refresh the `auth.json` file in the given `CODEX_HOME` directory. /// Returns the full AuthDotJson structure after refreshing if necessary. -pub async fn try_read_auth_json(codex_home: &Path) -> std::io::Result { - let auth_path = codex_home.join("auth.json"); - let mut file = std::fs::File::open(&auth_path)?; +pub fn try_read_auth_json(auth_file: &Path) -> std::io::Result { + let mut file = std::fs::File::open(auth_file)?; let mut contents = String::new(); file.read_to_string(&mut contents)?; let auth_dot_json: AuthDotJson = serde_json::from_str(&contents)?; - if is_expired(&auth_dot_json) { - let refresh_response = try_refresh_token(&auth_dot_json).await?; - let mut auth_dot_json = auth_dot_json; - auth_dot_json.tokens.id_token = refresh_response.id_token; - if let Some(refresh_token) = refresh_response.refresh_token { - auth_dot_json.tokens.refresh_token = refresh_token; - } - auth_dot_json.last_refresh = Utc::now(); + Ok(auth_dot_json) +} - let mut options = OpenOptions::new(); - options.truncate(true).write(true).create(true); - #[cfg(unix)] - { - options.mode(0o600); - } - - let json_data = serde_json::to_string(&auth_dot_json)?; - { - let mut file = options.open(&auth_path)?; - file.write_all(json_data.as_bytes())?; - file.flush()?; - } - - Ok(auth_dot_json) - } else { - Ok(auth_dot_json) +async fn update_tokens( + auth_file: &Path, + id_token: String, + access_token: Option, + refresh_token: Option, +) -> std::io::Result { + let mut options = OpenOptions::new(); + options.truncate(true).write(true).create(true); + #[cfg(unix)] + { + options.mode(0o600); } + let mut auth_dot_json = try_read_auth_json(auth_file)?; + + auth_dot_json.tokens.id_token = id_token.to_string(); + if let Some(access_token) = access_token { + auth_dot_json.tokens.access_token = access_token.to_string(); + } + if let Some(refresh_token) = refresh_token { + auth_dot_json.tokens.refresh_token = refresh_token.to_string(); + } + auth_dot_json.last_refresh = Utc::now(); + + let json_data = serde_json::to_string_pretty(&auth_dot_json)?; + { + let mut file = options.open(auth_file)?; + file.write_all(json_data.as_bytes())?; + file.flush()?; + } + Ok(auth_dot_json) } -fn is_expired(auth_dot_json: &AuthDotJson) -> bool { - let last_refresh = auth_dot_json.last_refresh; - last_refresh < Utc::now() - chrono::Duration::days(28) -} - -async fn try_refresh_token(auth_dot_json: &AuthDotJson) -> std::io::Result { +async fn try_refresh_token(refresh_token: String) -> std::io::Result { let refresh_request = RefreshRequest { client_id: CLIENT_ID, grant_type: "refresh_token", - refresh_token: auth_dot_json.tokens.refresh_token.clone(), + refresh_token, scope: "openid profile email", }; @@ -145,24 +269,25 @@ struct RefreshRequest { scope: &'static str, } -#[derive(Deserialize)] +#[derive(Deserialize, Clone)] struct RefreshResponse { id_token: String, + access_token: Option, refresh_token: Option, } /// Expected structure for $CODEX_HOME/auth.json. -#[derive(Deserialize, Serialize)] +#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] pub struct AuthDotJson { #[serde(rename = "OPENAI_API_KEY")] - pub openai_api_key: String, + pub openai_api_key: Option, pub tokens: TokenData, pub last_refresh: DateTime, } -#[derive(Deserialize, Serialize, Clone)] +#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] pub struct TokenData { /// This is a JWT. pub id_token: String, @@ -172,5 +297,5 @@ pub struct TokenData { pub refresh_token: String, - pub account_id: String, + pub account_id: Option, } diff --git a/codex-rs/login/src/login_with_chatgpt.py b/codex-rs/login/src/login_with_chatgpt.py index ccb051c0af..2dbf5be58a 100644 --- a/codex-rs/login/src/login_with_chatgpt.py +++ b/codex-rs/login/src/login_with_chatgpt.py @@ -41,7 +41,6 @@ from typing import Any, Dict # for type hints REQUIRED_PORT = 1455 URL_BASE = f"http://localhost:{REQUIRED_PORT}" DEFAULT_ISSUER = "https://auth.openai.com" -DEFAULT_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" EXIT_CODE_WHEN_ADDRESS_ALREADY_IN_USE = 13 @@ -58,7 +57,7 @@ class TokenData: class AuthBundle: """Aggregates authentication data produced after successful OAuth flow.""" - api_key: str + api_key: str | None token_data: TokenData last_refresh: str @@ -78,12 +77,18 @@ def main() -> None: eprint("ERROR: CODEX_HOME environment variable is not set") sys.exit(1) + client_id = os.getenv("CODEX_CLIENT_ID") + if not client_id: + eprint("ERROR: CODEX_CLIENT_ID environment variable is not set") + sys.exit(1) + # Spawn server. try: httpd = _ApiKeyHTTPServer( ("127.0.0.1", REQUIRED_PORT), _ApiKeyHTTPHandler, codex_home=codex_home, + client_id=client_id, verbose=args.verbose, ) except OSError as e: @@ -157,7 +162,7 @@ class _ApiKeyHTTPHandler(http.server.BaseHTTPRequestHandler): return try: - auth_bundle, success_url = self._exchange_code_for_api_key(code) + auth_bundle, success_url = self._exchange_code(code) except Exception as exc: # noqa: BLE001 – propagate to client self.send_error(500, f"Token exchange failed: {exc}") return @@ -211,68 +216,22 @@ class _ApiKeyHTTPHandler(http.server.BaseHTTPRequestHandler): if getattr(self.server, "verbose", False): # type: ignore[attr-defined] super().log_message(fmt, *args) - def _exchange_code_for_api_key(self, code: str) -> tuple[AuthBundle, str]: - """Perform token + token-exchange to obtain an OpenAI API key. + def _obtain_api_key( + self, + token_claims: Dict[str, Any], + access_claims: Dict[str, Any], + token_data: TokenData, + ) -> tuple[str | None, str | None]: + """Obtain an API key from the auth service. - Returns (AuthBundle, success_url). + Returns (api_key, success_url) if successful, None otherwise. """ - token_endpoint = f"{self.server.issuer}/oauth/token" - - # 1. Authorization-code -> (id_token, access_token, refresh_token) - data = urllib.parse.urlencode( - { - "grant_type": "authorization_code", - "code": code, - "redirect_uri": self.server.redirect_uri, - "client_id": self.server.client_id, - "code_verifier": self.server.pkce.code_verifier, - } - ).encode() - - token_data: TokenData - - with urllib.request.urlopen( - urllib.request.Request( - token_endpoint, - data=data, - method="POST", - headers={"Content-Type": "application/x-www-form-urlencoded"}, - ) - ) as resp: - payload = json.loads(resp.read().decode()) - - # Extract chatgpt_account_id from id_token - id_token_parts = payload["id_token"].split(".") - if len(id_token_parts) != 3: - raise ValueError("Invalid ID token") - id_token_claims = _decode_jwt_segment(id_token_parts[1]) - auth_claims = id_token_claims.get("https://api.openai.com/auth", {}) - chatgpt_account_id = auth_claims.get("chatgpt_account_id", "") - - token_data = TokenData( - id_token=payload["id_token"], - access_token=payload["access_token"], - refresh_token=payload["refresh_token"], - account_id=chatgpt_account_id, - ) - - access_token_parts = token_data.access_token.split(".") - if len(access_token_parts) != 3: - raise ValueError("Invalid access token") - - access_token_claims = _decode_jwt_segment(access_token_parts[1]) - - token_claims = id_token_claims.get("https://api.openai.com/auth", {}) - access_claims = access_token_claims.get("https://api.openai.com/auth", {}) - org_id = token_claims.get("organization_id") - if not org_id: - raise ValueError("Missing organization in id_token claims") - project_id = token_claims.get("project_id") - if not project_id: - raise ValueError("Missing project in id_token claims") + + if not org_id or not project_id: + return (None, None) random_id = secrets.token_hex(6) @@ -292,7 +251,7 @@ class _ApiKeyHTTPHandler(http.server.BaseHTTPRequestHandler): exchanged_access_token: str with urllib.request.urlopen( urllib.request.Request( - token_endpoint, + self.server.token_endpoint, data=exchange_data, method="POST", headers={"Content-Type": "application/x-www-form-urlencoded"}, @@ -340,6 +299,65 @@ class _ApiKeyHTTPHandler(http.server.BaseHTTPRequestHandler): except Exception as exc: # pragma: no cover – best-effort only eprint(f"Unable to redeem ChatGPT subscriber API credits: {exc}") + return (exchanged_access_token, success_url) + + def _exchange_code(self, code: str) -> tuple[AuthBundle, str]: + """Perform token + token-exchange to obtain an OpenAI API key. + + Returns (AuthBundle, success_url). + """ + + # 1. Authorization-code -> (id_token, access_token, refresh_token) + data = urllib.parse.urlencode( + { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": self.server.redirect_uri, + "client_id": self.server.client_id, + "code_verifier": self.server.pkce.code_verifier, + } + ).encode() + + token_data: TokenData + + with urllib.request.urlopen( + urllib.request.Request( + self.server.token_endpoint, + data=data, + method="POST", + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + ) as resp: + payload = json.loads(resp.read().decode()) + + # Extract chatgpt_account_id from id_token + id_token_parts = payload["id_token"].split(".") + if len(id_token_parts) != 3: + raise ValueError("Invalid ID token") + id_token_claims = _decode_jwt_segment(id_token_parts[1]) + auth_claims = id_token_claims.get("https://api.openai.com/auth", {}) + chatgpt_account_id = auth_claims.get("chatgpt_account_id", "") + + token_data = TokenData( + id_token=payload["id_token"], + access_token=payload["access_token"], + refresh_token=payload["refresh_token"], + account_id=chatgpt_account_id, + ) + + access_token_parts = token_data.access_token.split(".") + if len(access_token_parts) != 3: + raise ValueError("Invalid access token") + + access_token_claims = _decode_jwt_segment(access_token_parts[1]) + + token_claims = id_token_claims.get("https://api.openai.com/auth", {}) + access_claims = access_token_claims.get("https://api.openai.com/auth", {}) + + exchanged_access_token, success_url = self._obtain_api_key( + token_claims, access_claims, token_data + ) + # Persist refresh_token/id_token for future use (redeem credits etc.) last_refresh_str = ( datetime.datetime.now(datetime.timezone.utc) @@ -353,7 +371,7 @@ class _ApiKeyHTTPHandler(http.server.BaseHTTPRequestHandler): last_refresh=last_refresh_str, ) - return (auth_bundle, success_url) + return (auth_bundle, success_url or f"{URL_BASE}/success") def request_shutdown(self) -> None: # shutdown() must be invoked from another thread to avoid @@ -413,6 +431,7 @@ class _ApiKeyHTTPServer(http.server.HTTPServer): request_handler_class: type[http.server.BaseHTTPRequestHandler], *, codex_home: str, + client_id: str, verbose: bool = False, ) -> None: super().__init__(server_address, request_handler_class, bind_and_activate=True) @@ -422,7 +441,8 @@ class _ApiKeyHTTPServer(http.server.HTTPServer): self.verbose: bool = verbose self.issuer: str = DEFAULT_ISSUER - self.client_id: str = DEFAULT_CLIENT_ID + self.token_endpoint: str = f"{self.issuer}/oauth/token" + self.client_id: str = client_id port = server_address[1] self.redirect_uri: str = f"http://localhost:{port}/auth/callback" self.pkce: PkceCodes = _generate_pkce() @@ -581,8 +601,8 @@ def maybe_redeem_credits( granted = redeem_data.get("granted_chatgpt_subscriber_api_credits", 0) if granted and granted > 0: eprint( - f"""Thanks for being a ChatGPT {'Plus' if plan_type=='plus' else 'Pro'} subscriber! -If you haven't already redeemed, you should receive {'$5' if plan_type=='plus' else '$50'} in API credits. + f"""Thanks for being a ChatGPT {"Plus" if plan_type == "plus" else "Pro"} subscriber! +If you haven't already redeemed, you should receive {"$5" if plan_type == "plus" else "$50"} in API credits. Credits: https://platform.openai.com/settings/organization/billing/credit-grants More info: https://help.openai.com/en/articles/11381614""", diff --git a/codex-rs/mcp-client/src/main.rs b/codex-rs/mcp-client/src/main.rs index 518383d1ea..10cfe389bf 100644 --- a/codex-rs/mcp-client/src/main.rs +++ b/codex-rs/mcp-client/src/main.rs @@ -10,6 +10,7 @@ //! program. The utility connects, issues a `tools/list` request and prints the //! server's response as pretty JSON. +use std::ffi::OsString; use std::time::Duration; use anyhow::Context; @@ -37,7 +38,7 @@ async fn main() -> Result<()> { .try_init(); // Collect command-line arguments excluding the program name itself. - let mut args: Vec = std::env::args().skip(1).collect(); + let mut args: Vec = std::env::args_os().skip(1).collect(); if args.is_empty() || args[0] == "--help" || args[0] == "-h" { eprintln!("Usage: mcp-client [args..]\n\nExample: mcp-client codex-mcp-server"); @@ -57,10 +58,12 @@ async fn main() -> Result<()> { experimental: None, roots: None, sampling: None, + elicitation: None, }, client_info: Implementation { name: "codex-mcp-client".to_owned(), version: env!("CARGO_PKG_VERSION").to_owned(), + title: Some("Codex".to_string()), }, protocol_version: MCP_SCHEMA_VERSION.to_owned(), }; diff --git a/codex-rs/mcp-client/src/mcp_client.rs b/codex-rs/mcp-client/src/mcp_client.rs index 6a9111e69f..084d0bf4ba 100644 --- a/codex-rs/mcp-client/src/mcp_client.rs +++ b/codex-rs/mcp-client/src/mcp_client.rs @@ -12,6 +12,7 @@ //! issue requests and receive strongly-typed results. use std::collections::HashMap; +use std::ffi::OsString; use std::sync::Arc; use std::sync::atomic::AtomicI64; use std::sync::atomic::Ordering; @@ -82,8 +83,8 @@ impl McpClient { /// Caller is responsible for sending the `initialize` request. See /// [`initialize`](Self::initialize) for details. pub async fn new_stdio_client( - program: String, - args: Vec, + program: OsString, + args: Vec, env: Option>, ) -> std::io::Result { let mut child = Command::new(program) diff --git a/codex-rs/mcp-server/Cargo.toml b/codex-rs/mcp-server/Cargo.toml index f91a3dc8f8..2f618808c1 100644 --- a/codex-rs/mcp-server/Cargo.toml +++ b/codex-rs/mcp-server/Cargo.toml @@ -1,7 +1,7 @@ [package] +edition = "2024" name = "codex-mcp-server" version = { workspace = true } -edition = "2024" [[bin]] name = "codex-mcp-server" @@ -16,15 +16,14 @@ workspace = true [dependencies] anyhow = "1" +codex-arg0 = { path = "../arg0" } codex-core = { path = "../core" } -codex-linux-sandbox = { path = "../linux-sandbox" } mcp-types = { path = "../mcp-types" } schemars = "0.8.22" serde = { version = "1", features = ["derive"] } serde_json = "1" -toml = "0.9" -tracing = { version = "0.1.41", features = ["log"] } -tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } +shlex = "1.3.0" +strum_macros = "0.27.2" tokio = { version = "1", features = [ "io-std", "macros", @@ -32,6 +31,15 @@ tokio = { version = "1", features = [ "rt-multi-thread", "signal", ] } +toml = "0.9" +tracing = { version = "0.1.41", features = ["log"] } +tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } +uuid = { version = "1", features = ["serde", "v4"] } [dev-dependencies] +assert_cmd = "2" +mcp_test_support = { path = "tests/common" } pretty_assertions = "1.4.1" +tempfile = "3" +tokio-test = "0.4" +wiremock = "0.6" diff --git a/codex-rs/mcp-server/src/codex_tool_config.rs b/codex-rs/mcp-server/src/codex_tool_config.rs index 8555524942..877d0e05f7 100644 --- a/codex-rs/mcp-server/src/codex_tool_config.rs +++ b/codex-rs/mcp-server/src/codex_tool_config.rs @@ -7,15 +7,16 @@ use mcp_types::ToolInputSchema; use schemars::JsonSchema; use schemars::r#gen::SchemaSettings; use serde::Deserialize; +use serde::Serialize; use std::collections::HashMap; use std::path::PathBuf; use crate::json_to_toml::json_to_toml; /// Client-supplied configuration for a `codex` tool-call. -#[derive(Debug, Clone, Deserialize, JsonSchema)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)] #[serde(rename_all = "kebab-case")] -pub(crate) struct CodexToolCallParam { +pub struct CodexToolCallParam { /// The *initial user prompt* to start the Codex conversation. pub prompt: String, @@ -45,13 +46,21 @@ pub(crate) struct CodexToolCallParam { /// CODEX_HOME/config.toml. #[serde(default, skip_serializing_if = "Option::is_none")] pub config: Option>, + + /// The set of instructions to use instead of the default ones. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub base_instructions: Option, + + /// Whether to include the plan tool in the conversation. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub include_plan_tool: Option, } /// Custom enum mirroring [`AskForApproval`], but has an extra dependency on /// [`JsonSchema`]. -#[derive(Debug, Clone, Deserialize, JsonSchema)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "kebab-case")] -pub(crate) enum CodexToolCallApprovalPolicy { +pub enum CodexToolCallApprovalPolicy { Untrusted, OnFailure, Never, @@ -69,9 +78,9 @@ impl From for AskForApproval { /// Custom enum mirroring [`SandboxMode`] from config_types.rs, but with /// `JsonSchema` support. -#[derive(Debug, Clone, Deserialize, JsonSchema)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "kebab-case")] -pub(crate) enum CodexToolCallSandboxMode { +pub enum CodexToolCallSandboxMode { ReadOnly, WorkspaceWrite, DangerFullAccess, @@ -108,7 +117,10 @@ pub(crate) fn create_tool_for_codex_tool_call_param() -> Tool { Tool { name: "codex".to_string(), + title: Some("Codex".to_string()), input_schema: tool_input_schema, + // TODO(mbolin): This should be defined. + output_schema: None, description: Some( "Run a Codex session. Accepts configuration parameters matching the Codex Config struct.".to_string(), ), @@ -131,9 +143,11 @@ impl CodexToolCallParam { approval_policy, sandbox, config: cli_overrides, + base_instructions, + include_plan_tool, } = self; - // Build the `ConfigOverrides` recognised by codex-core. + // Build the `ConfigOverrides` recognized by codex-core. let overrides = codex_core::config::ConfigOverrides { model, config_profile: profile, @@ -142,6 +156,8 @@ impl CodexToolCallParam { sandbox_mode: sandbox.map(Into::into), model_provider: None, codex_linux_sandbox_exe, + base_instructions, + include_plan_tool, }; let cli_overrides = cli_overrides @@ -156,6 +172,47 @@ impl CodexToolCallParam { } } +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] +pub struct CodexToolCallReplyParam { + /// The *session id* for this conversation. + pub session_id: String, + + /// The *next user prompt* to continue the Codex conversation. + pub prompt: String, +} + +/// Builds a `Tool` definition for the `codex-reply` tool-call. +pub(crate) fn create_tool_for_codex_tool_call_reply_param() -> Tool { + let schema = SchemaSettings::draft2019_09() + .with(|s| { + s.inline_subschemas = true; + s.option_add_null_type = false; + }) + .into_generator() + .into_root_schema_for::(); + + #[expect(clippy::expect_used)] + let schema_value = + serde_json::to_value(&schema).expect("Codex reply tool schema should serialise to JSON"); + + let tool_input_schema = + serde_json::from_value::(schema_value).unwrap_or_else(|e| { + panic!("failed to create Tool from schema: {e}"); + }); + + Tool { + name: "codex-reply".to_string(), + title: Some("Codex Reply".to_string()), + input_schema: tool_input_schema, + output_schema: None, + description: Some( + "Continue a Codex session by providing the session id and prompt.".to_string(), + ), + annotations: None, + } +} + #[cfg(test)] mod tests { use super::*; @@ -179,6 +236,7 @@ mod tests { let tool_json = serde_json::to_value(&tool).expect("tool serializes"); let expected_tool_json = serde_json::json!({ "name": "codex", + "title": "Codex", "description": "Run a Codex session. Accepts configuration parameters matching the Codex Config struct.", "inputSchema": { "type": "object", @@ -210,6 +268,10 @@ mod tests { "description": "Working directory for the session. If relative, it is resolved against the server process's current working directory.", "type": "string" }, + "include-plan-tool": { + "description": "Whether to include the plan tool in the conversation.", + "type": "boolean" + }, "model": { "description": "Optional override for the model name (e.g. \"o3\", \"o4-mini\").", "type": "string" @@ -222,6 +284,10 @@ mod tests { "description": "The *initial user prompt* to start the Codex conversation.", "type": "string" }, + "base-instructions": { + "description": "The set of instructions to use instead of the default ones.", + "type": "string" + }, }, "required": [ "prompt" @@ -230,4 +296,34 @@ mod tests { }); assert_eq!(expected_tool_json, tool_json); } + + #[test] + fn verify_codex_tool_reply_json_schema() { + let tool = create_tool_for_codex_tool_call_reply_param(); + #[expect(clippy::expect_used)] + let tool_json = serde_json::to_value(&tool).expect("tool serializes"); + let expected_tool_json = serde_json::json!({ + "description": "Continue a Codex session by providing the session id and prompt.", + "inputSchema": { + "properties": { + "prompt": { + "description": "The *next user prompt* to continue the Codex conversation.", + "type": "string" + }, + "sessionId": { + "description": "The *session id* for this conversation.", + "type": "string" + }, + }, + "required": [ + "prompt", + "sessionId", + ], + "type": "object", + }, + "name": "codex-reply", + "title": "Codex Reply", + }); + assert_eq!(expected_tool_json, tool_json); + } } diff --git a/codex-rs/mcp-server/src/codex_tool_runner.rs b/codex-rs/mcp-server/src/codex_tool_runner.rs index 26f0636ae6..8bbf17c2ea 100644 --- a/codex-rs/mcp-server/src/codex_tool_runner.rs +++ b/codex-rs/mcp-server/src/codex_tool_runner.rs @@ -2,33 +2,35 @@ //! Tokio task. Separated from `message_processor.rs` to keep that file small //! and to make future feature-growth easier to manage. +use std::collections::HashMap; +use std::sync::Arc; + +use codex_core::Codex; +use codex_core::codex_wrapper::CodexConversation; use codex_core::codex_wrapper::init_codex; use codex_core::config::Config as CodexConfig; use codex_core::protocol::AgentMessageEvent; -use codex_core::protocol::Event; +use codex_core::protocol::ApplyPatchApprovalRequestEvent; use codex_core::protocol::EventMsg; +use codex_core::protocol::ExecApprovalRequestEvent; use codex_core::protocol::InputItem; use codex_core::protocol::Op; use codex_core::protocol::Submission; use codex_core::protocol::TaskCompleteEvent; use mcp_types::CallToolResult; -use mcp_types::CallToolResultContent; -use mcp_types::JSONRPC_VERSION; -use mcp_types::JSONRPCMessage; -use mcp_types::JSONRPCResponse; +use mcp_types::ContentBlock; use mcp_types::RequestId; use mcp_types::TextContent; -use tokio::sync::mpsc::Sender; +use serde_json::json; +use tokio::sync::Mutex; +use uuid::Uuid; -/// Convert a Codex [`Event`] to an MCP notification. -fn codex_event_to_notification(event: &Event) -> JSONRPCMessage { - #[expect(clippy::expect_used)] - JSONRPCMessage::Notification(mcp_types::JSONRPCNotification { - jsonrpc: JSONRPC_VERSION.into(), - method: "codex/event".into(), - params: Some(serde_json::to_value(event).expect("Event must serialize")), - }) -} +use crate::exec_approval::handle_exec_approval_request; +use crate::outgoing_message::OutgoingMessageSender; +use crate::outgoing_message::OutgoingNotificationMeta; +use crate::patch_approval::handle_patch_approval_request; + +pub(crate) const INVALID_PARAMS_ERROR_CODE: i64 = -32602; /// Run a complete Codex session and stream events back to the client. /// @@ -38,33 +40,43 @@ pub async fn run_codex_tool_session( id: RequestId, initial_prompt: String, config: CodexConfig, - outgoing: Sender, + outgoing: Arc, + session_map: Arc>>>, + running_requests_id_to_codex_uuid: Arc>>, ) { - let (codex, first_event, _ctrl_c) = match init_codex(config).await { + let CodexConversation { + codex, + session_configured, + session_id, + .. + } = match init_codex(config).await { Ok(res) => res, Err(e) => { let result = CallToolResult { - content: vec![CallToolResultContent::TextContent(TextContent { + content: vec![ContentBlock::TextContent(TextContent { r#type: "text".to_string(), text: format!("Failed to start Codex session: {e}"), annotations: None, })], is_error: Some(true), + structured_content: None, }; - let _ = outgoing - .send(JSONRPCMessage::Response(JSONRPCResponse { - jsonrpc: JSONRPC_VERSION.into(), - id, - result: result.into(), - })) - .await; + outgoing.send_response(id.clone(), result.into()).await; return; } }; + let codex = Arc::new(codex); - // Send initial SessionConfigured event. - let _ = outgoing - .send(codex_event_to_notification(&first_event)) + // update the session map so we can retrieve the session in a reply, and then drop it, since + // we no longer need it for this function + session_map.lock().await.insert(session_id, codex.clone()); + drop(session_map); + + outgoing + .send_event_as_notification( + &session_configured, + Some(OutgoingNotificationMeta::new(Some(id.clone()))), + ) .await; // Use the original MCP request ID as the `sub_id` for the Codex submission so that @@ -74,9 +86,12 @@ pub async fn run_codex_tool_session( RequestId::String(s) => s.clone(), RequestId::Integer(n) => n.to_string(), }; - + running_requests_id_to_codex_uuid + .lock() + .await + .insert(id.clone(), session_id); let submission = Submission { - id: sub_id, + id: sub_id.clone(), op: Op::UserInput { items: vec![InputItem::Text { text: initial_prompt.clone(), @@ -86,93 +101,158 @@ pub async fn run_codex_tool_session( if let Err(e) = codex.submit_with_id(submission).await { tracing::error!("Failed to submit initial prompt: {e}"); + // unregister the id so we don't keep it in the map + running_requests_id_to_codex_uuid.lock().await.remove(&id); + return; } - let mut last_agent_message: Option = None; + run_codex_tool_session_inner(codex, outgoing, id, running_requests_id_to_codex_uuid).await; +} + +pub async fn run_codex_tool_session_reply( + codex: Arc, + outgoing: Arc, + request_id: RequestId, + prompt: String, + running_requests_id_to_codex_uuid: Arc>>, + session_id: Uuid, +) { + running_requests_id_to_codex_uuid + .lock() + .await + .insert(request_id.clone(), session_id); + if let Err(e) = codex + .submit(Op::UserInput { + items: vec![InputItem::Text { text: prompt }], + }) + .await + { + tracing::error!("Failed to submit user input: {e}"); + // unregister the id so we don't keep it in the map + running_requests_id_to_codex_uuid + .lock() + .await + .remove(&request_id); + return; + } + + run_codex_tool_session_inner( + codex, + outgoing, + request_id, + running_requests_id_to_codex_uuid, + ) + .await; +} + +async fn run_codex_tool_session_inner( + codex: Arc, + outgoing: Arc, + request_id: RequestId, + running_requests_id_to_codex_uuid: Arc>>, +) { + let request_id_str = match &request_id { + RequestId::String(s) => s.clone(), + RequestId::Integer(n) => n.to_string(), + }; // Stream events until the task needs to pause for user interaction or // completes. loop { match codex.next_event().await { Ok(event) => { - let _ = outgoing.send(codex_event_to_notification(&event)).await; + outgoing + .send_event_as_notification( + &event, + Some(OutgoingNotificationMeta::new(Some(request_id.clone()))), + ) + .await; - match &event.msg { - EventMsg::AgentMessage(AgentMessageEvent { message }) => { - last_agent_message = Some(message.clone()); - } - EventMsg::ExecApprovalRequest(_) => { - let result = CallToolResult { - content: vec![CallToolResultContent::TextContent(TextContent { - r#type: "text".to_string(), - text: "EXEC_APPROVAL_REQUIRED".to_string(), - annotations: None, - })], - is_error: None, - }; - let _ = outgoing - .send(JSONRPCMessage::Response(JSONRPCResponse { - jsonrpc: JSONRPC_VERSION.into(), - id: id.clone(), - result: result.into(), - })) - .await; - break; - } - EventMsg::ApplyPatchApprovalRequest(_) => { - let result = CallToolResult { - content: vec![CallToolResultContent::TextContent(TextContent { - r#type: "text".to_string(), - text: "PATCH_APPROVAL_REQUIRED".to_string(), - annotations: None, - })], - is_error: None, - }; - let _ = outgoing - .send(JSONRPCMessage::Response(JSONRPCResponse { - jsonrpc: JSONRPC_VERSION.into(), - id: id.clone(), - result: result.into(), - })) - .await; - break; - } - EventMsg::TaskComplete(TaskCompleteEvent { - last_agent_message: _, + match event.msg { + EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent { + command, + cwd, + call_id, + reason: _, }) => { - let result = if let Some(msg) = last_agent_message { - CallToolResult { - content: vec![CallToolResultContent::TextContent(TextContent { - r#type: "text".to_string(), - text: msg, - annotations: None, - })], - is_error: None, - } - } else { - CallToolResult { - content: vec![CallToolResultContent::TextContent(TextContent { - r#type: "text".to_string(), - text: String::new(), - annotations: None, - })], - is_error: None, - } + handle_exec_approval_request( + command, + cwd, + outgoing.clone(), + codex.clone(), + request_id.clone(), + request_id_str.clone(), + event.id.clone(), + call_id, + ) + .await; + continue; + } + EventMsg::Error(err_event) => { + // Return a response to conclude the tool call when the Codex session reports an error (e.g., interruption). + let result = json!({ + "error": err_event.message, + }); + outgoing.send_response(request_id.clone(), result).await; + break; + } + EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { + call_id, + reason, + grant_root, + changes, + }) => { + handle_patch_approval_request( + call_id, + reason, + grant_root, + changes, + outgoing.clone(), + codex.clone(), + request_id.clone(), + request_id_str.clone(), + event.id.clone(), + ) + .await; + continue; + } + EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => { + let text = match last_agent_message { + Some(msg) => msg.clone(), + None => "".to_string(), }; - let _ = outgoing - .send(JSONRPCMessage::Response(JSONRPCResponse { - jsonrpc: JSONRPC_VERSION.into(), - id: id.clone(), - result: result.into(), - })) + let result = CallToolResult { + content: vec![ContentBlock::TextContent(TextContent { + r#type: "text".to_string(), + text, + annotations: None, + })], + is_error: None, + structured_content: None, + }; + outgoing + .send_response(request_id.clone(), result.into()) .await; + // unregister the id so we don't keep it in the map + running_requests_id_to_codex_uuid + .lock() + .await + .remove(&request_id); break; } EventMsg::SessionConfigured(_) => { tracing::error!("unexpected SessionConfigured event"); } - EventMsg::Error(_) - | EventMsg::TaskStarted + EventMsg::AgentMessageDelta(_) => { + // TODO: think how we want to support this in the MCP + } + EventMsg::AgentReasoningDelta(_) => { + // TODO: think how we want to support this in the MCP + } + EventMsg::AgentMessage(AgentMessageEvent { .. }) => { + // TODO: think how we want to support this in the MCP + } + EventMsg::TaskStarted | EventMsg::TokenCount(_) | EventMsg::AgentReasoning(_) | EventMsg::AgentReasoningContent(_) @@ -183,7 +263,9 @@ pub async fn run_codex_tool_session( | EventMsg::BackgroundEvent(_) | EventMsg::PatchApplyBegin(_) | EventMsg::PatchApplyEnd(_) - | EventMsg::GetHistoryEntryResponse(_) => { + | EventMsg::GetHistoryEntryResponse(_) + | EventMsg::PlanUpdate(_) + | EventMsg::ShutdownComplete => { // For now, we do not do anything extra for these // events. Note that // send(codex_event_to_notification(&event)) above has @@ -195,19 +277,18 @@ pub async fn run_codex_tool_session( } Err(e) => { let result = CallToolResult { - content: vec![CallToolResultContent::TextContent(TextContent { + content: vec![ContentBlock::TextContent(TextContent { r#type: "text".to_string(), text: format!("Codex runtime error: {e}"), annotations: None, })], is_error: Some(true), + // TODO(mbolin): Could present the error in a more + // structured way. + structured_content: None, }; - let _ = outgoing - .send(JSONRPCMessage::Response(JSONRPCResponse { - jsonrpc: JSONRPC_VERSION.into(), - id: id.clone(), - result: result.into(), - })) + outgoing + .send_response(request_id.clone(), result.into()) .await; break; } diff --git a/codex-rs/mcp-server/src/exec_approval.rs b/codex-rs/mcp-server/src/exec_approval.rs new file mode 100644 index 0000000000..f073214bf5 --- /dev/null +++ b/codex-rs/mcp-server/src/exec_approval.rs @@ -0,0 +1,149 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use codex_core::Codex; +use codex_core::protocol::Op; +use codex_core::protocol::ReviewDecision; +use mcp_types::ElicitRequest; +use mcp_types::ElicitRequestParamsRequestedSchema; +use mcp_types::JSONRPCErrorError; +use mcp_types::ModelContextProtocolRequest; +use mcp_types::RequestId; +use serde::Deserialize; +use serde::Serialize; +use serde_json::json; +use tracing::error; + +use crate::codex_tool_runner::INVALID_PARAMS_ERROR_CODE; + +/// Conforms to [`mcp_types::ElicitRequestParams`] so that it can be used as the +/// `params` field of an [`ElicitRequest`]. +#[derive(Debug, Serialize)] +pub struct ExecApprovalElicitRequestParams { + // These fields are required so that `params` + // conforms to ElicitRequestParams. + pub message: String, + + #[serde(rename = "requestedSchema")] + pub requested_schema: ElicitRequestParamsRequestedSchema, + + // These are additional fields the client can use to + // correlate the request with the codex tool call. + pub codex_elicitation: String, + pub codex_mcp_tool_call_id: String, + pub codex_event_id: String, + pub codex_call_id: String, + pub codex_command: Vec, + pub codex_cwd: PathBuf, +} + +// TODO(mbolin): ExecApprovalResponse does not conform to ElicitResult. See: +// - https://github.com/modelcontextprotocol/modelcontextprotocol/blob/f962dc1780fa5eed7fb7c8a0232f1fc83ef220cd/schema/2025-06-18/schema.json#L617-L636 +// - https://modelcontextprotocol.io/specification/draft/client/elicitation#protocol-messages +// It should have "action" and "content" fields. +#[derive(Debug, Serialize, Deserialize)] +pub struct ExecApprovalResponse { + pub decision: ReviewDecision, +} + +#[allow(clippy::too_many_arguments)] +pub(crate) async fn handle_exec_approval_request( + command: Vec, + cwd: PathBuf, + outgoing: Arc, + codex: Arc, + request_id: RequestId, + tool_call_id: String, + event_id: String, + call_id: String, +) { + let escaped_command = + shlex::try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" ")); + let message = format!( + "Allow Codex to run `{escaped_command}` in `{cwd}`?", + cwd = cwd.to_string_lossy() + ); + + let params = ExecApprovalElicitRequestParams { + message, + requested_schema: ElicitRequestParamsRequestedSchema { + r#type: "object".to_string(), + properties: json!({}), + required: None, + }, + codex_elicitation: "exec-approval".to_string(), + codex_mcp_tool_call_id: tool_call_id.clone(), + codex_event_id: event_id.clone(), + codex_call_id: call_id, + codex_command: command, + codex_cwd: cwd, + }; + let params_json = match serde_json::to_value(¶ms) { + Ok(value) => value, + Err(err) => { + let message = format!("Failed to serialize ExecApprovalElicitRequestParams: {err}"); + error!("{message}"); + + outgoing + .send_error( + request_id.clone(), + JSONRPCErrorError { + code: INVALID_PARAMS_ERROR_CODE, + message, + data: None, + }, + ) + .await; + + return; + } + }; + + let on_response = outgoing + .send_request(ElicitRequest::METHOD, Some(params_json)) + .await; + + // Listen for the response on a separate task so we don't block the main agent loop. + { + let codex = codex.clone(); + let event_id = event_id.clone(); + tokio::spawn(async move { + on_exec_approval_response(event_id, on_response, codex).await; + }); + } +} + +async fn on_exec_approval_response( + event_id: String, + receiver: tokio::sync::oneshot::Receiver, + codex: Arc, +) { + let response = receiver.await; + let value = match response { + Ok(value) => value, + Err(err) => { + error!("request failed: {err:?}"); + return; + } + }; + + // Try to deserialize `value` and then make the appropriate call to `codex`. + let response = serde_json::from_value::(value).unwrap_or_else(|err| { + error!("failed to deserialize ExecApprovalResponse: {err}"); + // If we cannot deserialize the response, we deny the request to be + // conservative. + ExecApprovalResponse { + decision: ReviewDecision::Denied, + } + }); + + if let Err(err) = codex + .submit(Op::ExecApproval { + id: event_id, + decision: response.decision, + }) + .await + { + error!("failed to submit ExecApproval: {err}"); + } +} diff --git a/codex-rs/mcp-server/src/lib.rs b/codex-rs/mcp-server/src/lib.rs index b2a7797fe6..0912fed118 100644 --- a/codex-rs/mcp-server/src/lib.rs +++ b/codex-rs/mcp-server/src/lib.rs @@ -13,13 +13,27 @@ use tokio::sync::mpsc; use tracing::debug; use tracing::error; use tracing::info; +use tracing_subscriber::EnvFilter; mod codex_tool_config; mod codex_tool_runner; +mod exec_approval; mod json_to_toml; +mod mcp_protocol; mod message_processor; +mod outgoing_message; +mod patch_approval; use crate::message_processor::MessageProcessor; +use crate::outgoing_message::OutgoingMessage; +use crate::outgoing_message::OutgoingMessageSender; + +pub use crate::codex_tool_config::CodexToolCallParam; +pub use crate::codex_tool_config::CodexToolCallReplyParam; +pub use crate::exec_approval::ExecApprovalElicitRequestParams; +pub use crate::exec_approval::ExecApprovalResponse; +pub use crate::patch_approval::PatchApprovalElicitRequestParams; +pub use crate::patch_approval::PatchApprovalResponse; /// Size of the bounded channels used to communicate between tasks. The value /// is a balance between throughput and memory usage – 128 messages should be @@ -31,11 +45,12 @@ pub async fn run_main(codex_linux_sandbox_exe: Option) -> IoResult<()> // control the log level with `RUST_LOG`. tracing_subscriber::fmt() .with_writer(std::io::stderr) + .with_env_filter(EnvFilter::from_default_env()) .init(); // Set up channels. let (incoming_tx, mut incoming_rx) = mpsc::channel::(CHANNEL_CAPACITY); - let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(CHANNEL_CAPACITY); + let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(CHANNEL_CAPACITY); // Task: read from stdin, push to `incoming_tx`. let stdin_reader_handle = tokio::spawn({ @@ -63,16 +78,15 @@ pub async fn run_main(codex_linux_sandbox_exe: Option) -> IoResult<()> // Task: process incoming messages. let processor_handle = tokio::spawn({ - let mut processor = MessageProcessor::new(outgoing_tx.clone(), codex_linux_sandbox_exe); + let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx); + let mut processor = MessageProcessor::new(outgoing_message_sender, codex_linux_sandbox_exe); async move { while let Some(msg) = incoming_rx.recv().await { match msg { - JSONRPCMessage::Request(r) => processor.process_request(r), - JSONRPCMessage::Response(r) => processor.process_response(r), - JSONRPCMessage::Notification(n) => processor.process_notification(n), - JSONRPCMessage::BatchRequest(b) => processor.process_batch_request(b), + JSONRPCMessage::Request(r) => processor.process_request(r).await, + JSONRPCMessage::Response(r) => processor.process_response(r).await, + JSONRPCMessage::Notification(n) => processor.process_notification(n).await, JSONRPCMessage::Error(e) => processor.process_error(e), - JSONRPCMessage::BatchResponse(b) => processor.process_batch_response(b), } } @@ -83,7 +97,8 @@ pub async fn run_main(codex_linux_sandbox_exe: Option) -> IoResult<()> // Task: write outgoing messages to stdout. let stdout_writer_handle = tokio::spawn(async move { let mut stdout = io::stdout(); - while let Some(msg) = outgoing_rx.recv().await { + while let Some(outgoing_message) = outgoing_rx.recv().await { + let msg: JSONRPCMessage = outgoing_message.into(); match serde_json::to_string(&msg) { Ok(json) => { if let Err(e) = stdout.write_all(json.as_bytes()).await { diff --git a/codex-rs/mcp-server/src/main.rs b/codex-rs/mcp-server/src/main.rs index 51c46c44d2..60ddeeab41 100644 --- a/codex-rs/mcp-server/src/main.rs +++ b/codex-rs/mcp-server/src/main.rs @@ -1,7 +1,8 @@ +use codex_arg0::arg0_dispatch_or_else; use codex_mcp_server::run_main; fn main() -> anyhow::Result<()> { - codex_linux_sandbox::run_with_sandbox(|codex_linux_sandbox_exe| async move { + arg0_dispatch_or_else(|codex_linux_sandbox_exe| async move { run_main(codex_linux_sandbox_exe).await?; Ok(()) }) diff --git a/codex-rs/mcp-server/src/mcp_protocol.rs b/codex-rs/mcp-server/src/mcp_protocol.rs new file mode 100644 index 0000000000..e507376c16 --- /dev/null +++ b/codex-rs/mcp-server/src/mcp_protocol.rs @@ -0,0 +1,962 @@ +use codex_core::config_types::SandboxMode; +use codex_core::protocol::AskForApproval; +use codex_core::protocol::EventMsg; +use codex_core::protocol::InputItem; +use serde::Deserialize; +use serde::Serialize; +use strum_macros::Display; +use uuid::Uuid; + +use mcp_types::RequestId; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(transparent)] +pub struct ConversationId(pub Uuid); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(transparent)] +pub struct MessageId(pub Uuid); + +// Requests +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCallRequest { + #[serde(rename = "jsonrpc")] + pub jsonrpc: &'static str, + pub id: RequestId, + pub method: &'static str, + pub params: ToolCallRequestParams, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(tag = "name", content = "arguments", rename_all = "camelCase")] +pub enum ToolCallRequestParams { + ConversationCreate(ConversationCreateArgs), + ConversationStream(ConversationStreamArgs), + ConversationSendMessage(ConversationSendMessageArgs), + ConversationsList(ConversationsListArgs), +} + +impl ToolCallRequestParams { + /// Wrap this request in a JSON-RPC request. + #[allow(dead_code)] + pub fn into_request(self, id: RequestId) -> ToolCallRequest { + ToolCallRequest { + jsonrpc: "2.0", + id, + method: "tools/call", + params: self, + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ConversationCreateArgs { + pub prompt: String, + pub model: String, + pub cwd: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub approval_policy: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub sandbox: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub config: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub profile: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub base_instructions: Option, +} + +/// Optional overrides for an existing conversation's execution context when sending a message. +/// Fields left as `None` inherit the current conversation/session settings. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ConversationOverrides { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub cwd: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub approval_policy: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub sandbox: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub config: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub profile: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub base_instructions: Option, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ConversationStreamArgs { + pub conversation_id: ConversationId, +} + +/// If omitted, the message continues from the latest turn. +/// Set to resume/edit from an earlier parent message in the thread. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ConversationSendMessageArgs { + pub conversation_id: ConversationId, + pub content: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub parent_message_id: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[serde(flatten)] + pub conversation_overrides: Option, +} +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ConversationsListArgs { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub limit: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub cursor: Option, +} + +// Responses +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolCallResponse { + pub request_id: RequestId, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub is_error: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub result: Option, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ToolCallResponseResult { + ConversationCreate(ConversationCreateResult), + ConversationStream(ConversationStreamResult), + ConversationSendMessage(ConversationSendMessageResult), + ConversationsList(ConversationsListResult), +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ConversationCreateResult { + pub conversation_id: ConversationId, + pub model: String, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ConversationStreamResult {} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ConversationSendMessageResult { + pub success: bool, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ConversationsListResult { + pub conversations: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ConversationSummary { + pub conversation_id: ConversationId, + pub title: String, +} + +// Notifications +#[derive(Debug, Clone, Deserialize, Display)] +pub enum ServerNotification { + InitialState(InitialStateNotificationParams), + StreamDisconnected(StreamDisconnectedNotificationParams), + CodexEvent(Box), +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct NotificationMeta { + #[serde(skip_serializing_if = "Option::is_none")] + pub conversation_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub request_id: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InitialStateNotificationParams { + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option, + pub initial_state: InitialStatePayload, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InitialStatePayload { + #[serde(default)] + pub events: Vec, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct StreamDisconnectedNotificationParams { + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option, + pub reason: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CodexEventNotificationParams { + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option, + pub msg: EventMsg, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CancelNotificationParams { + pub request_id: RequestId, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub reason: Option, +} + +impl Serialize for ServerNotification { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::SerializeMap; + + let mut map = serializer.serialize_map(Some(2))?; + match self { + ServerNotification::CodexEvent(p) => { + map.serialize_entry("method", &format!("notifications/{}", p.msg))?; + map.serialize_entry("params", p)?; + } + ServerNotification::InitialState(p) => { + map.serialize_entry("method", "notifications/initial_state")?; + map.serialize_entry("params", p)?; + } + ServerNotification::StreamDisconnected(p) => { + map.serialize_entry("method", "notifications/stream_disconnected")?; + map.serialize_entry("params", p)?; + } + } + map.end() + } +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "method", content = "params", rename_all = "camelCase")] +pub enum ClientNotification { + #[serde(rename = "notifications/cancelled")] + Cancelled(CancelNotificationParams), +} + +#[cfg(test)] +#[allow(clippy::expect_used)] +#[allow(clippy::unwrap_used)] +mod tests { + use std::path::PathBuf; + + use super::*; + use codex_core::protocol::McpInvocation; + use codex_core::protocol::McpToolCallBeginEvent; + use pretty_assertions::assert_eq; + use serde::Serialize; + use serde_json::Value; + use serde_json::json; + use uuid::uuid; + + fn to_val(v: &T) -> Value { + serde_json::to_value(v).expect("serialize to Value") + } + + // ----- Requests ----- + + #[test] + fn serialize_tool_call_request_params_conversation_create_minimal() { + let req = ToolCallRequestParams::ConversationCreate(ConversationCreateArgs { + prompt: "".into(), + model: "o3".into(), + cwd: "/repo".into(), + approval_policy: None, + sandbox: None, + config: None, + profile: None, + base_instructions: None, + }); + + let observed = to_val(&req.into_request(mcp_types::RequestId::Integer(2))); + let expected = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "conversationCreate", + "arguments": { + "prompt": "", + "model": "o3", + "cwd": "/repo" + } + } + }); + assert_eq!(observed, expected); + } + + #[test] + fn serialize_tool_call_request_params_conversation_send_message_with_overrides_and_parent_message_id() + { + let req = ToolCallRequestParams::ConversationSendMessage(ConversationSendMessageArgs { + conversation_id: ConversationId(uuid!("d0f6ecbe-84a2-41c1-b23d-b20473b25eab")), + content: vec![ + InputItem::Text { text: "Hi".into() }, + InputItem::Image { + image_url: "https://example.com/cat.jpg".into(), + }, + InputItem::LocalImage { + path: "notes.txt".into(), + }, + ], + parent_message_id: Some(MessageId(uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8"))), + conversation_overrides: Some(ConversationOverrides { + model: Some("o4-mini".into()), + cwd: Some("/workdir".into()), + approval_policy: None, + sandbox: Some(SandboxMode::DangerFullAccess), + config: Some(json!({"temp": 0.2})), + profile: Some("eng".into()), + base_instructions: Some("Be terse".into()), + }), + }); + + let observed = to_val(&req.into_request(mcp_types::RequestId::Integer(2))); + let expected = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "conversationSendMessage", + "arguments": { + "conversation_id": "d0f6ecbe-84a2-41c1-b23d-b20473b25eab", + "content": [ + { "type": "text", "text": "Hi" }, + { "type": "image", "image_url": "https://example.com/cat.jpg" }, + { "type": "local_image", "path": "notes.txt" } + ], + "parent_message_id": "67e55044-10b1-426f-9247-bb680e5fe0c8", + "model": "o4-mini", + "cwd": "/workdir", + "sandbox": "danger-full-access", + "config": { "temp": 0.2 }, + "profile": "eng", + "base_instructions": "Be terse" + } + } + }); + assert_eq!(observed, expected); + } + + #[test] + fn serialize_tool_call_request_params_conversations_list_with_opts() { + let req = ToolCallRequestParams::ConversationsList(ConversationsListArgs { + limit: Some(50), + cursor: Some("abc".into()), + }); + + let observed = to_val(&req.into_request(RequestId::Integer(2))); + let expected = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "conversationsList", + "arguments": { + "limit": 50, + "cursor": "abc" + } + } + }); + assert_eq!(observed, expected); + } + + #[test] + fn serialize_tool_call_request_params_conversation_stream() { + let req = ToolCallRequestParams::ConversationStream(ConversationStreamArgs { + conversation_id: ConversationId(uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8")), + }); + + let observed = to_val(&req.into_request(mcp_types::RequestId::Integer(2))); + let expected = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "conversationStream", + "arguments": { + "conversation_id": "67e55044-10b1-426f-9247-bb680e5fe0c8" + } + } + }); + assert_eq!(observed, expected); + } + + // ----- Message inputs / sources ----- + + #[test] + fn serialize_message_input_image_url() { + let item = InputItem::Image { + image_url: "https://example.com/x.png".into(), + }; + let observed = to_val(&item); + let expected = json!({ + "type": "image", + "image_url": "https://example.com/x.png" + }); + assert_eq!(observed, expected); + } + + #[test] + fn serialize_message_input_local_image_path() { + let url = InputItem::LocalImage { + path: PathBuf::from("https://example.com/a.pdf"), + }; + let id = InputItem::LocalImage { + path: PathBuf::from("file_456"), + }; + let observed_url = to_val(&url); + let expected_url = json!({"type":"local_image","path":"https://example.com/a.pdf"}); + assert_eq!( + observed_url, expected_url, + "LocalImage with URL path should serialize as image_url" + ); + let observed_id = to_val(&id); + let expected_id = json!({"type":"local_image","path":"file_456"}); + assert_eq!( + observed_id, expected_id, + "LocalImage with file id should serialize as image_url" + ); + } + + #[test] + fn serialize_message_input_image_url_without_detail() { + let item = InputItem::Image { + image_url: "https://example.com/x.png".into(), + }; + let observed = to_val(&item); + let expected = json!({ + "type": "image", + "image_url": "https://example.com/x.png" + }); + assert_eq!(observed, expected); + } + + // ----- Responses ----- + + #[test] + fn response_success_conversation_create_full_schema() { + let env = ToolCallResponse { + request_id: RequestId::Integer(1), + is_error: None, + result: Some(ToolCallResponseResult::ConversationCreate( + ConversationCreateResult { + conversation_id: ConversationId(uuid!("d0f6ecbe-84a2-41c1-b23d-b20473b25eab")), + model: "o3".into(), + }, + )), + }; + let observed = to_val(&env); + let expected = json!({ + "requestId": 1, + "result": { + "conversation_id": "d0f6ecbe-84a2-41c1-b23d-b20473b25eab", + "model": "o3" + } + }); + assert_eq!( + observed, expected, + "response (ConversationCreate) must match" + ); + } + + #[test] + fn response_success_conversation_stream_empty_result_object() { + let env = ToolCallResponse { + request_id: RequestId::Integer(2), + is_error: None, + result: Some(ToolCallResponseResult::ConversationStream( + ConversationStreamResult {}, + )), + }; + let observed = to_val(&env); + let expected = json!({ + "requestId": 2, + "result": {} + }); + assert_eq!( + observed, expected, + "response (ConversationStream) must have empty object result" + ); + } + + #[test] + fn response_success_send_message_accepted_full_schema() { + let env = ToolCallResponse { + request_id: RequestId::Integer(3), + is_error: None, + result: Some(ToolCallResponseResult::ConversationSendMessage( + ConversationSendMessageResult { success: true }, + )), + }; + let observed = to_val(&env); + let expected = json!({ + "requestId": 3, + "result": { "success": true } + }); + assert_eq!( + observed, expected, + "response (ConversationSendMessageAccepted) must match" + ); + } + + #[test] + fn response_success_conversations_list_with_next_cursor_full_schema() { + let env = ToolCallResponse { + request_id: RequestId::Integer(4), + is_error: None, + result: Some(ToolCallResponseResult::ConversationsList( + ConversationsListResult { + conversations: vec![ConversationSummary { + conversation_id: ConversationId(uuid!( + "67e55044-10b1-426f-9247-bb680e5fe0c8" + )), + title: "Refactor config loader".into(), + }], + next_cursor: Some("next123".into()), + }, + )), + }; + let observed = to_val(&env); + let expected = json!({ + "requestId": 4, + "result": { + "conversations": [ + { + "conversation_id": "67e55044-10b1-426f-9247-bb680e5fe0c8", + "title": "Refactor config loader" + } + ], + "next_cursor": "next123" + } + }); + assert_eq!( + observed, expected, + "response (ConversationsList with cursor) must match" + ); + } + + #[test] + fn response_error_only_is_error_and_request_id_string() { + let env = ToolCallResponse { + request_id: RequestId::Integer(4), + is_error: Some(true), + result: None, + }; + let observed = to_val(&env); + let expected = json!({ + "requestId": 4, + "isError": true + }); + assert_eq!( + observed, expected, + "error response must omit `result` and include `isError`" + ); + } + + // ----- Notifications ----- + + #[test] + fn serialize_notification_initial_state_minimal() { + let params = InitialStateNotificationParams { + meta: Some(NotificationMeta { + conversation_id: Some(ConversationId(uuid!( + "67e55044-10b1-426f-9247-bb680e5fe0c8" + ))), + request_id: Some(RequestId::Integer(44)), + }), + initial_state: InitialStatePayload { + events: vec![ + CodexEventNotificationParams { + meta: None, + msg: EventMsg::TaskStarted, + }, + CodexEventNotificationParams { + meta: None, + msg: EventMsg::AgentMessageDelta( + codex_core::protocol::AgentMessageDeltaEvent { + delta: "Loading...".into(), + }, + ), + }, + ], + }, + }; + + let observed = to_val(&ServerNotification::InitialState(params.clone())); + let expected = json!({ + "method": "notifications/initial_state", + "params": { + "_meta": { + "conversationId": "67e55044-10b1-426f-9247-bb680e5fe0c8", + "requestId": 44 + }, + "initial_state": { + "events": [ + { "msg": { "type": "task_started" } }, + { "msg": { "type": "agent_message_delta", "delta": "Loading..." } } + ] + } + } + }); + assert_eq!(observed, expected); + } + + #[test] + fn serialize_notification_initial_state_omits_empty_events_full_json() { + let params = InitialStateNotificationParams { + meta: None, + initial_state: InitialStatePayload { events: vec![] }, + }; + + let observed = to_val(&ServerNotification::InitialState(params)); + let expected = json!({ + "method": "notifications/initial_state", + "params": { + "initial_state": { "events": [] } + } + }); + assert_eq!(observed, expected); + } + + #[test] + fn serialize_notification_stream_disconnected() { + let params = StreamDisconnectedNotificationParams { + meta: Some(NotificationMeta { + conversation_id: Some(ConversationId(uuid!( + "67e55044-10b1-426f-9247-bb680e5fe0c8" + ))), + request_id: None, + }), + reason: "New stream() took over".into(), + }; + + let observed = to_val(&ServerNotification::StreamDisconnected(params)); + let expected = json!({ + "method": "notifications/stream_disconnected", + "params": { + "_meta": { "conversationId": "67e55044-10b1-426f-9247-bb680e5fe0c8" }, + "reason": "New stream() took over" + } + }); + assert_eq!(observed, expected); + } + + #[test] + fn serialize_notification_codex_event_uses_eventmsg_type_in_method() { + let params = CodexEventNotificationParams { + meta: Some(NotificationMeta { + conversation_id: Some(ConversationId(uuid!( + "67e55044-10b1-426f-9247-bb680e5fe0c8" + ))), + request_id: Some(RequestId::Integer(44)), + }), + msg: EventMsg::AgentMessage(codex_core::protocol::AgentMessageEvent { + message: "hi".into(), + }), + }; + + let observed = to_val(&ServerNotification::CodexEvent(Box::new(params))); + let expected = json!({ + "method": "notifications/agent_message", + "params": { + "_meta": { + "conversationId": "67e55044-10b1-426f-9247-bb680e5fe0c8", + "requestId": 44 + }, + "msg": { "type": "agent_message", "message": "hi" } + } + }); + assert_eq!(observed, expected); + } + + #[test] + fn serialize_notification_codex_event_task_started_full_json() { + let params = CodexEventNotificationParams { + meta: Some(NotificationMeta { + conversation_id: Some(ConversationId(uuid!( + "67e55044-10b1-426f-9247-bb680e5fe0c8" + ))), + request_id: Some(RequestId::Integer(7)), + }), + msg: EventMsg::TaskStarted, + }; + + let observed = to_val(&ServerNotification::CodexEvent(Box::new(params))); + let expected = json!({ + "method": "notifications/task_started", + "params": { + "_meta": { + "conversationId": "67e55044-10b1-426f-9247-bb680e5fe0c8", + "requestId": 7 + }, + "msg": { "type": "task_started" } + } + }); + assert_eq!(observed, expected); + } + + #[test] + fn serialize_notification_codex_event_agent_message_delta_full_json() { + let params = CodexEventNotificationParams { + meta: None, + msg: EventMsg::AgentMessageDelta(codex_core::protocol::AgentMessageDeltaEvent { + delta: "stream...".into(), + }), + }; + + let observed = to_val(&ServerNotification::CodexEvent(Box::new(params))); + let expected = json!({ + "method": "notifications/agent_message_delta", + "params": { + "msg": { "type": "agent_message_delta", "delta": "stream..." } + } + }); + assert_eq!(observed, expected); + } + + #[test] + fn serialize_notification_codex_event_agent_message_full_json() { + let params = CodexEventNotificationParams { + meta: Some(NotificationMeta { + conversation_id: Some(ConversationId(uuid!( + "67e55044-10b1-426f-9247-bb680e5fe0c8" + ))), + request_id: Some(RequestId::Integer(44)), + }), + msg: EventMsg::AgentMessage(codex_core::protocol::AgentMessageEvent { + message: "hi".into(), + }), + }; + + let observed = to_val(&ServerNotification::CodexEvent(Box::new(params))); + let expected = json!({ + "method": "notifications/agent_message", + "params": { + "_meta": { + "conversationId": "67e55044-10b1-426f-9247-bb680e5fe0c8", + "requestId": 44 + }, + "msg": { "type": "agent_message", "message": "hi" } + } + }); + assert_eq!(observed, expected); + } + + #[test] + fn serialize_notification_codex_event_agent_reasoning_full_json() { + let params = CodexEventNotificationParams { + meta: None, + msg: EventMsg::AgentReasoning(codex_core::protocol::AgentReasoningEvent { + text: "thinking…".into(), + }), + }; + + let observed = to_val(&ServerNotification::CodexEvent(Box::new(params))); + let expected = json!({ + "method": "notifications/agent_reasoning", + "params": { + "msg": { "type": "agent_reasoning", "text": "thinking…" } + } + }); + assert_eq!(observed, expected); + } + + #[test] + fn serialize_notification_codex_event_token_count_full_json() { + let usage = codex_core::protocol::TokenUsage { + input_tokens: 10, + cached_input_tokens: Some(2), + output_tokens: 5, + reasoning_output_tokens: Some(1), + total_tokens: 16, + }; + let params = CodexEventNotificationParams { + meta: None, + msg: EventMsg::TokenCount(usage), + }; + + let observed = to_val(&ServerNotification::CodexEvent(Box::new(params))); + let expected = json!({ + "method": "notifications/token_count", + "params": { + "msg": { + "type": "token_count", + "input_tokens": 10, + "cached_input_tokens": 2, + "output_tokens": 5, + "reasoning_output_tokens": 1, + "total_tokens": 16 + } + } + }); + assert_eq!(observed, expected); + } + + #[test] + fn serialize_notification_codex_event_session_configured_full_json() { + let params = CodexEventNotificationParams { + meta: Some(NotificationMeta { + conversation_id: Some(ConversationId(uuid!( + "67e55044-10b1-426f-9247-bb680e5fe0c8" + ))), + request_id: None, + }), + msg: EventMsg::SessionConfigured(codex_core::protocol::SessionConfiguredEvent { + session_id: uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8"), + model: "codex-mini-latest".into(), + history_log_id: 42, + history_entry_count: 3, + }), + }; + + let observed = to_val(&ServerNotification::CodexEvent(Box::new(params))); + let expected = json!({ + "method": "notifications/session_configured", + "params": { + "_meta": { "conversationId": "67e55044-10b1-426f-9247-bb680e5fe0c8" }, + "msg": { + "type": "session_configured", + "session_id": "67e55044-10b1-426f-9247-bb680e5fe0c8", + "model": "codex-mini-latest", + "history_log_id": 42, + "history_entry_count": 3 + } + } + }); + assert_eq!(observed, expected); + } + + #[test] + fn serialize_notification_codex_event_exec_command_begin_full_json() { + let params = CodexEventNotificationParams { + meta: None, + msg: EventMsg::ExecCommandBegin(codex_core::protocol::ExecCommandBeginEvent { + call_id: "c1".into(), + command: vec!["bash".into(), "-lc".into(), "echo hi".into()], + cwd: std::path::PathBuf::from("/work"), + }), + }; + + let observed = to_val(&ServerNotification::CodexEvent(Box::new(params))); + let expected = json!({ + "method": "notifications/exec_command_begin", + "params": { + "msg": { + "type": "exec_command_begin", + "call_id": "c1", + "command": ["bash", "-lc", "echo hi"], + "cwd": "/work" + } + } + }); + assert_eq!(observed, expected); + } + + #[test] + fn serialize_notification_codex_event_mcp_tool_call_begin_full_json() { + let params = CodexEventNotificationParams { + meta: None, + msg: EventMsg::McpToolCallBegin(McpToolCallBeginEvent { + call_id: "m1".into(), + invocation: McpInvocation { + server: "calc".into(), + tool: "add".into(), + arguments: Some(json!({"a":1,"b":2})), + }, + }), + }; + + let observed = to_val(&ServerNotification::CodexEvent(Box::new(params))); + let expected = json!({ + "method": "notifications/mcp_tool_call_begin", + "params": { + "msg": { + "type": "mcp_tool_call_begin", + "call_id": "m1", + "invocation": { + "server": "calc", + "tool": "add", + "arguments": { "a": 1, "b": 2 } + } + } + } + }); + assert_eq!(observed, expected); + } + + #[test] + fn serialize_notification_codex_event_patch_apply_end_full_json() { + let params = CodexEventNotificationParams { + meta: None, + msg: EventMsg::PatchApplyEnd(codex_core::protocol::PatchApplyEndEvent { + call_id: "p1".into(), + stdout: "ok".into(), + stderr: "".into(), + success: true, + }), + }; + + let observed = to_val(&ServerNotification::CodexEvent(Box::new(params))); + let expected = json!({ + "method": "notifications/patch_apply_end", + "params": { + "msg": { + "type": "patch_apply_end", + "call_id": "p1", + "stdout": "ok", + "stderr": "", + "success": true + } + } + }); + assert_eq!(observed, expected); + } + + // ----- Cancelled notifications ----- + + #[test] + fn serialize_notification_cancelled_with_reason_full_json() { + let params = CancelNotificationParams { + request_id: RequestId::String("r-123".into()), + reason: Some("user_cancelled".into()), + }; + + let observed = to_val(&ClientNotification::Cancelled(params)); + let expected = json!({ + "method": "notifications/cancelled", + "params": { + "requestId": "r-123", + "reason": "user_cancelled" + } + }); + assert_eq!(observed, expected); + } + + #[test] + fn serialize_notification_cancelled_without_reason_full_json() { + let params = CancelNotificationParams { + request_id: RequestId::Integer(77), + reason: None, + }; + + let observed = to_val(&ClientNotification::Cancelled(params)); + + // Check exact structure: reason must be omitted. + assert_eq!(observed["method"], "notifications/cancelled"); + assert_eq!(observed["params"]["requestId"], 77); + assert!( + observed["params"].get("reason").is_none(), + "reason must be omitted when None" + ); + } +} diff --git a/codex-rs/mcp-server/src/message_processor.rs b/codex-rs/mcp-server/src/message_processor.rs index bf6f42e569..7ba827d60b 100644 --- a/codex-rs/mcp-server/src/message_processor.rs +++ b/codex-rs/mcp-server/src/message_processor.rs @@ -1,19 +1,22 @@ +use std::collections::HashMap; use std::path::PathBuf; +use std::sync::Arc; use crate::codex_tool_config::CodexToolCallParam; +use crate::codex_tool_config::CodexToolCallReplyParam; use crate::codex_tool_config::create_tool_for_codex_tool_call_param; +use crate::codex_tool_config::create_tool_for_codex_tool_call_reply_param; +use crate::outgoing_message::OutgoingMessageSender; +use codex_core::Codex; use codex_core::config::Config as CodexConfig; +use codex_core::protocol::Submission; use mcp_types::CallToolRequestParams; use mcp_types::CallToolResult; -use mcp_types::CallToolResultContent; use mcp_types::ClientRequest; -use mcp_types::JSONRPC_VERSION; -use mcp_types::JSONRPCBatchRequest; -use mcp_types::JSONRPCBatchResponse; +use mcp_types::ContentBlock; use mcp_types::JSONRPCError; use mcp_types::JSONRPCErrorError; -use mcp_types::JSONRPCMessage; use mcp_types::JSONRPCNotification; use mcp_types::JSONRPCRequest; use mcp_types::JSONRPCResponse; @@ -24,30 +27,35 @@ use mcp_types::ServerCapabilitiesTools; use mcp_types::ServerNotification; use mcp_types::TextContent; use serde_json::json; -use tokio::sync::mpsc; +use tokio::sync::Mutex; use tokio::task; +use uuid::Uuid; pub(crate) struct MessageProcessor { - outgoing: mpsc::Sender, + outgoing: Arc, initialized: bool, codex_linux_sandbox_exe: Option, + session_map: Arc>>>, + running_requests_id_to_codex_uuid: Arc>>, } impl MessageProcessor { /// Create a new `MessageProcessor`, retaining a handle to the outgoing /// `Sender` so handlers can enqueue messages to be written to stdout. pub(crate) fn new( - outgoing: mpsc::Sender, + outgoing: OutgoingMessageSender, codex_linux_sandbox_exe: Option, ) -> Self { Self { - outgoing, + outgoing: Arc::new(outgoing), initialized: false, codex_linux_sandbox_exe, + session_map: Arc::new(Mutex::new(HashMap::new())), + running_requests_id_to_codex_uuid: Arc::new(Mutex::new(HashMap::new())), } } - pub(crate) fn process_request(&mut self, request: JSONRPCRequest) { + pub(crate) async fn process_request(&mut self, request: JSONRPCRequest) { // Hold on to the ID so we can respond. let request_id = request.id.clone(); @@ -62,10 +70,10 @@ impl MessageProcessor { // Dispatch to a dedicated handler for each request type. match client_request { ClientRequest::InitializeRequest(params) => { - self.handle_initialize(request_id, params); + self.handle_initialize(request_id, params).await; } ClientRequest::PingRequest(params) => { - self.handle_ping(request_id, params); + self.handle_ping(request_id, params).await; } ClientRequest::ListResourcesRequest(params) => { self.handle_list_resources(params); @@ -89,10 +97,10 @@ impl MessageProcessor { self.handle_get_prompt(params); } ClientRequest::ListToolsRequest(params) => { - self.handle_list_tools(request_id, params); + self.handle_list_tools(request_id, params).await; } ClientRequest::CallToolRequest(params) => { - self.handle_call_tool(request_id, params); + self.handle_call_tool(request_id, params).await; } ClientRequest::SetLevelRequest(params) => { self.handle_set_level(params); @@ -104,12 +112,14 @@ impl MessageProcessor { } /// Handle a standalone JSON-RPC response originating from the peer. - pub(crate) fn process_response(&mut self, response: JSONRPCResponse) { + pub(crate) async fn process_response(&mut self, response: JSONRPCResponse) { tracing::info!("<- response: {:?}", response); + let JSONRPCResponse { id, result, .. } = response; + self.outgoing.notify_client_response(id, result).await } /// Handle a fire-and-forget JSON-RPC notification. - pub(crate) fn process_notification(&mut self, notification: JSONRPCNotification) { + pub(crate) async fn process_notification(&mut self, notification: JSONRPCNotification) { let server_notification = match ServerNotification::try_from(notification) { Ok(n) => n, Err(e) => { @@ -122,7 +132,7 @@ impl MessageProcessor { // handler so additional logic can be implemented incrementally. match server_notification { ServerNotification::CancelledNotification(params) => { - self.handle_cancelled_notification(params); + self.handle_cancelled_notification(params).await; } ServerNotification::ProgressNotification(params) => { self.handle_progress_notification(params); @@ -145,42 +155,12 @@ impl MessageProcessor { } } - /// Handle a batch of requests and/or notifications. - pub(crate) fn process_batch_request(&mut self, batch: JSONRPCBatchRequest) { - tracing::info!("<- batch request containing {} item(s)", batch.len()); - for item in batch { - match item { - mcp_types::JSONRPCBatchRequestItem::JSONRPCRequest(req) => { - self.process_request(req); - } - mcp_types::JSONRPCBatchRequestItem::JSONRPCNotification(note) => { - self.process_notification(note); - } - } - } - } - /// Handle an error object received from the peer. pub(crate) fn process_error(&mut self, err: JSONRPCError) { tracing::error!("<- error: {:?}", err); } - /// Handle a batch of responses/errors. - pub(crate) fn process_batch_response(&mut self, batch: JSONRPCBatchResponse) { - tracing::info!("<- batch response containing {} item(s)", batch.len()); - for item in batch { - match item { - mcp_types::JSONRPCBatchResponseItem::JSONRPCResponse(resp) => { - self.process_response(resp); - } - mcp_types::JSONRPCBatchResponseItem::JSONRPCError(err) => { - self.process_error(err); - } - } - } - } - - fn handle_initialize( + async fn handle_initialize( &mut self, id: RequestId, params: ::Params, @@ -189,19 +169,12 @@ impl MessageProcessor { if self.initialized { // Already initialised: send JSON-RPC error response. - let error_msg = JSONRPCMessage::Error(JSONRPCError { - jsonrpc: JSONRPC_VERSION.into(), - id, - error: JSONRPCErrorError { - code: -32600, // Invalid Request - message: "initialize called more than once".to_string(), - data: None, - }, - }); - - if let Err(e) = self.outgoing.try_send(error_msg) { - tracing::error!("Failed to send initialization error: {e}"); - } + let error = JSONRPCErrorError { + code: -32600, // Invalid Request + message: "initialize called more than once".to_string(), + data: None, + }; + self.outgoing.send_error(id, error).await; return; } @@ -223,38 +196,34 @@ impl MessageProcessor { protocol_version: params.protocol_version.clone(), server_info: mcp_types::Implementation { name: "codex-mcp-server".to_string(), - version: mcp_types::MCP_SCHEMA_VERSION.to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + title: Some("Codex".to_string()), }, }; - self.send_response::(id, result); + self.send_response::(id, result) + .await; } - fn send_response(&self, id: RequestId, result: T::Result) + async fn send_response(&self, id: RequestId, result: T::Result) where T: ModelContextProtocolRequest, { // result has `Serialized` instance so should never fail #[expect(clippy::unwrap_used)] - let response = JSONRPCMessage::Response(JSONRPCResponse { - jsonrpc: JSONRPC_VERSION.into(), - id, - result: serde_json::to_value(result).unwrap(), - }); - - if let Err(e) = self.outgoing.try_send(response) { - tracing::error!("Failed to send response: {e}"); - } + let result = serde_json::to_value(result).unwrap(); + self.outgoing.send_response(id, result).await; } - fn handle_ping( + async fn handle_ping( &self, id: RequestId, params: ::Params, ) { tracing::info!("ping -> params: {:?}", params); let result = json!({}); - self.send_response::(id, result); + self.send_response::(id, result) + .await; } fn handle_list_resources( @@ -307,21 +276,25 @@ impl MessageProcessor { tracing::info!("prompts/get -> params: {:?}", params); } - fn handle_list_tools( + async fn handle_list_tools( &self, id: RequestId, params: ::Params, ) { tracing::trace!("tools/list -> {params:?}"); let result = ListToolsResult { - tools: vec![create_tool_for_codex_tool_call_param()], + tools: vec![ + create_tool_for_codex_tool_call_param(), + create_tool_for_codex_tool_call_reply_param(), + ], next_cursor: None, }; - self.send_response::(id, result); + self.send_response::(id, result) + .await; } - fn handle_call_tool( + async fn handle_call_tool( &self, id: RequestId, params: ::Params, @@ -329,28 +302,36 @@ impl MessageProcessor { tracing::info!("tools/call -> params: {:?}", params); let CallToolRequestParams { name, arguments } = params; - // We only support the "codex" tool for now. - if name != "codex" { - // Tool not found – return error result so the LLM can react. - let result = CallToolResult { - content: vec![CallToolResultContent::TextContent(TextContent { - r#type: "text".to_string(), - text: format!("Unknown tool '{name}'"), - annotations: None, - })], - is_error: Some(true), - }; - self.send_response::(id, result); - return; + match name.as_str() { + "codex" => self.handle_tool_call_codex(id, arguments).await, + "codex-reply" => { + self.handle_tool_call_codex_session_reply(id, arguments) + .await + } + _ => { + let result = CallToolResult { + content: vec![ContentBlock::TextContent(TextContent { + r#type: "text".to_string(), + text: format!("Unknown tool '{name}'"), + annotations: None, + })], + is_error: Some(true), + structured_content: None, + }; + self.send_response::(id, result) + .await; + } } + } + async fn handle_tool_call_codex(&self, id: RequestId, arguments: Option) { let (initial_prompt, config): (String, CodexConfig) = match arguments { Some(json_val) => match serde_json::from_value::(json_val) { Ok(tool_cfg) => match tool_cfg.into_config(self.codex_linux_sandbox_exe.clone()) { Ok(cfg) => cfg, Err(e) => { let result = CallToolResult { - content: vec![CallToolResultContent::TextContent(TextContent { + content: vec![ContentBlock::TextContent(TextContent { r#type: "text".to_owned(), text: format!( "Failed to load Codex configuration from overrides: {e}" @@ -358,27 +339,31 @@ impl MessageProcessor { annotations: None, })], is_error: Some(true), + structured_content: None, }; - self.send_response::(id, result); + self.send_response::(id, result) + .await; return; } }, Err(e) => { let result = CallToolResult { - content: vec![CallToolResultContent::TextContent(TextContent { + content: vec![ContentBlock::TextContent(TextContent { r#type: "text".to_owned(), text: format!("Failed to parse configuration for Codex tool: {e}"), annotations: None, })], is_error: Some(true), + structured_content: None, }; - self.send_response::(id, result); + self.send_response::(id, result) + .await; return; } }, None => { let result = CallToolResult { - content: vec![CallToolResultContent::TextContent(TextContent { + content: vec![ContentBlock::TextContent(TextContent { r#type: "text".to_string(), text: "Missing arguments for codex tool-call; the `prompt` field is required." @@ -386,21 +371,147 @@ impl MessageProcessor { annotations: None, })], is_error: Some(true), + structured_content: None, }; - self.send_response::(id, result); + self.send_response::(id, result) + .await; return; } }; - // Clone outgoing sender to move into async task. + // Clone outgoing and session map to move into async task. let outgoing = self.outgoing.clone(); + let session_map = self.session_map.clone(); + let running_requests_id_to_codex_uuid = self.running_requests_id_to_codex_uuid.clone(); // Spawn an async task to handle the Codex session so that we do not // block the synchronous message-processing loop. task::spawn(async move { // Run the Codex session and stream events back to the client. - crate::codex_tool_runner::run_codex_tool_session(id, initial_prompt, config, outgoing) + crate::codex_tool_runner::run_codex_tool_session( + id, + initial_prompt, + config, + outgoing, + session_map, + running_requests_id_to_codex_uuid, + ) + .await; + }); + } + + async fn handle_tool_call_codex_session_reply( + &self, + request_id: RequestId, + arguments: Option, + ) { + tracing::info!("tools/call -> params: {:?}", arguments); + + // parse arguments + let CodexToolCallReplyParam { session_id, prompt } = match arguments { + Some(json_val) => match serde_json::from_value::(json_val) { + Ok(params) => params, + Err(e) => { + tracing::error!("Failed to parse Codex tool call reply parameters: {e}"); + let result = CallToolResult { + content: vec![ContentBlock::TextContent(TextContent { + r#type: "text".to_owned(), + text: format!("Failed to parse configuration for Codex tool: {e}"), + annotations: None, + })], + is_error: Some(true), + structured_content: None, + }; + self.send_response::(request_id, result) + .await; + return; + } + }, + None => { + tracing::error!( + "Missing arguments for codex-reply tool-call; the `session_id` and `prompt` fields are required." + ); + let result = CallToolResult { + content: vec![ContentBlock::TextContent(TextContent { + r#type: "text".to_owned(), + text: "Missing arguments for codex-reply tool-call; the `session_id` and `prompt` fields are required.".to_owned(), + annotations: None, + })], + is_error: Some(true), + structured_content: None, + }; + self.send_response::(request_id, result) + .await; + return; + } + }; + let session_id = match Uuid::parse_str(&session_id) { + Ok(id) => id, + Err(e) => { + tracing::error!("Failed to parse session_id: {e}"); + let result = CallToolResult { + content: vec![ContentBlock::TextContent(TextContent { + r#type: "text".to_owned(), + text: format!("Failed to parse session_id: {e}"), + annotations: None, + })], + is_error: Some(true), + structured_content: None, + }; + self.send_response::(request_id, result) + .await; + return; + } + }; + + // load codex from session map + let session_map_mutex = Arc::clone(&self.session_map); + + // Clone outgoing and session map to move into async task. + let outgoing = self.outgoing.clone(); + let running_requests_id_to_codex_uuid = self.running_requests_id_to_codex_uuid.clone(); + + let codex = { + let session_map = session_map_mutex.lock().await; + match session_map.get(&session_id).cloned() { + Some(c) => c, + None => { + tracing::warn!("Session not found for session_id: {session_id}"); + let result = CallToolResult { + content: vec![ContentBlock::TextContent(TextContent { + r#type: "text".to_owned(), + text: format!("Session not found for session_id: {session_id}"), + annotations: None, + })], + is_error: Some(true), + structured_content: None, + }; + outgoing + .send_response(request_id, serde_json::to_value(result).unwrap_or_default()) + .await; + return; + } + } + }; + + // Spawn the long-running reply handler. + tokio::spawn({ + let codex = codex.clone(); + let outgoing = outgoing.clone(); + let prompt = prompt.clone(); + let running_requests_id_to_codex_uuid = running_requests_id_to_codex_uuid.clone(); + + async move { + crate::codex_tool_runner::run_codex_tool_session_reply( + codex, + outgoing, + request_id, + prompt, + running_requests_id_to_codex_uuid, + session_id, + ) .await; + } }); } @@ -422,11 +533,58 @@ impl MessageProcessor { // Notification handlers // --------------------------------------------------------------------- - fn handle_cancelled_notification( + async fn handle_cancelled_notification( &self, params: ::Params, ) { - tracing::info!("notifications/cancelled -> params: {:?}", params); + let request_id = params.request_id; + // Create a stable string form early for logging and submission id. + let request_id_string = match &request_id { + RequestId::String(s) => s.clone(), + RequestId::Integer(i) => i.to_string(), + }; + + // Obtain the session_id while holding the first lock, then release. + let session_id = { + let map_guard = self.running_requests_id_to_codex_uuid.lock().await; + match map_guard.get(&request_id) { + Some(id) => *id, // Uuid is Copy + None => { + tracing::warn!("Session not found for request_id: {}", request_id_string); + return; + } + } + }; + tracing::info!("session_id: {session_id}"); + + // Obtain the Codex Arc while holding the session_map lock, then release. + let codex_arc = { + let sessions_guard = self.session_map.lock().await; + match sessions_guard.get(&session_id) { + Some(codex) => Arc::clone(codex), + None => { + tracing::warn!("Session not found for session_id: {session_id}"); + return; + } + } + }; + + // Submit interrupt to Codex. + let err = codex_arc + .submit_with_id(Submission { + id: request_id_string, + op: codex_core::protocol::Op::Interrupt, + }) + .await; + if let Err(e) = err { + tracing::error!("Failed to submit interrupt to Codex: {e}"); + return; + } + // unregister the id so we don't keep it in the map + self.running_requests_id_to_codex_uuid + .lock() + .await + .remove(&request_id); } fn handle_progress_notification( diff --git a/codex-rs/mcp-server/src/outgoing_message.rs b/codex-rs/mcp-server/src/outgoing_message.rs new file mode 100644 index 0000000000..e7b0b9b63c --- /dev/null +++ b/codex-rs/mcp-server/src/outgoing_message.rs @@ -0,0 +1,331 @@ +use std::collections::HashMap; +use std::sync::atomic::AtomicI64; +use std::sync::atomic::Ordering; + +use codex_core::protocol::Event; +use mcp_types::JSONRPC_VERSION; +use mcp_types::JSONRPCError; +use mcp_types::JSONRPCErrorError; +use mcp_types::JSONRPCMessage; +use mcp_types::JSONRPCNotification; +use mcp_types::JSONRPCRequest; +use mcp_types::JSONRPCResponse; +use mcp_types::RequestId; +use mcp_types::Result; +use serde::Serialize; +use tokio::sync::Mutex; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tracing::warn; + +/// Sends messages to the client and manages request callbacks. +pub(crate) struct OutgoingMessageSender { + next_request_id: AtomicI64, + sender: mpsc::Sender, + request_id_to_callback: Mutex>>, +} + +impl OutgoingMessageSender { + pub(crate) fn new(sender: mpsc::Sender) -> Self { + Self { + next_request_id: AtomicI64::new(0), + sender, + request_id_to_callback: Mutex::new(HashMap::new()), + } + } + + pub(crate) async fn send_request( + &self, + method: &str, + params: Option, + ) -> oneshot::Receiver { + let id = RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::Relaxed)); + let outgoing_message_id = id.clone(); + let (tx_approve, rx_approve) = oneshot::channel(); + { + let mut request_id_to_callback = self.request_id_to_callback.lock().await; + request_id_to_callback.insert(id, tx_approve); + } + + let outgoing_message = OutgoingMessage::Request(OutgoingRequest { + id: outgoing_message_id, + method: method.to_string(), + params, + }); + let _ = self.sender.send(outgoing_message).await; + rx_approve + } + + pub(crate) async fn notify_client_response(&self, id: RequestId, result: Result) { + let entry = { + let mut request_id_to_callback = self.request_id_to_callback.lock().await; + request_id_to_callback.remove_entry(&id) + }; + + match entry { + Some((id, sender)) => { + if let Err(err) = sender.send(result) { + warn!("could not notify callback for {id:?} due to: {err:?}"); + } + } + None => { + warn!("could not find callback for {id:?}"); + } + } + } + + pub(crate) async fn send_response(&self, id: RequestId, result: Result) { + let outgoing_message = OutgoingMessage::Response(OutgoingResponse { id, result }); + let _ = self.sender.send(outgoing_message).await; + } + + pub(crate) async fn send_event_as_notification( + &self, + event: &Event, + meta: Option, + ) { + #[allow(clippy::expect_used)] + let event_json = serde_json::to_value(event).expect("Event must serialize"); + + let params = if let Ok(params) = serde_json::to_value(OutgoingNotificationParams { + meta, + event: event_json.clone(), + }) { + params + } else { + warn!("Failed to serialize event as OutgoingNotificationParams"); + event_json + }; + + let outgoing_message = OutgoingMessage::Notification(OutgoingNotification { + method: "codex/event".to_string(), + params: Some(params.clone()), + }); + let _ = self.sender.send(outgoing_message).await; + + self.send_event_as_notification_new_schema(event, Some(params.clone())) + .await; + } + + // should be backwards compatible. + // it will replace send_event_as_notification eventually. + async fn send_event_as_notification_new_schema( + &self, + event: &Event, + params: Option, + ) { + let outgoing_message = OutgoingMessage::Notification(OutgoingNotification { + method: event.msg.to_string(), + params, + }); + let _ = self.sender.send(outgoing_message).await; + } + pub(crate) async fn send_error(&self, id: RequestId, error: JSONRPCErrorError) { + let outgoing_message = OutgoingMessage::Error(OutgoingError { id, error }); + let _ = self.sender.send(outgoing_message).await; + } +} + +/// Outgoing message from the server to the client. +pub(crate) enum OutgoingMessage { + Request(OutgoingRequest), + Notification(OutgoingNotification), + Response(OutgoingResponse), + Error(OutgoingError), +} + +impl From for JSONRPCMessage { + fn from(val: OutgoingMessage) -> Self { + use OutgoingMessage::*; + match val { + Request(OutgoingRequest { id, method, params }) => { + JSONRPCMessage::Request(JSONRPCRequest { + jsonrpc: JSONRPC_VERSION.into(), + id, + method, + params, + }) + } + Notification(OutgoingNotification { method, params }) => { + JSONRPCMessage::Notification(JSONRPCNotification { + jsonrpc: JSONRPC_VERSION.into(), + method, + params, + }) + } + Response(OutgoingResponse { id, result }) => { + JSONRPCMessage::Response(JSONRPCResponse { + jsonrpc: JSONRPC_VERSION.into(), + id, + result, + }) + } + Error(OutgoingError { id, error }) => JSONRPCMessage::Error(JSONRPCError { + jsonrpc: JSONRPC_VERSION.into(), + id, + error, + }), + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize)] +pub(crate) struct OutgoingRequest { + pub id: RequestId, + pub method: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub params: Option, +} + +#[derive(Debug, Clone, PartialEq, Serialize)] +pub(crate) struct OutgoingNotification { + pub method: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub params: Option, +} + +#[derive(Debug, Clone, PartialEq, Serialize)] +pub(crate) struct OutgoingNotificationParams { + #[serde(rename = "_meta", default, skip_serializing_if = "Option::is_none")] + pub meta: Option, + + #[serde(flatten)] + pub event: serde_json::Value, +} + +// Additional mcp-specific data to be added to a [`codex_core::protocol::Event`] as notification.params._meta +// MCP Spec: https://modelcontextprotocol.io/specification/2025-06-18/basic#meta +// Typescript Schema: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/0695a497eb50a804fc0e88c18a93a21a675d6b3e/schema/2025-06-18/schema.ts +#[derive(Debug, Clone, PartialEq, Serialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct OutgoingNotificationMeta { + pub request_id: Option, +} + +impl OutgoingNotificationMeta { + pub(crate) fn new(request_id: Option) -> Self { + Self { request_id } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize)] +pub(crate) struct OutgoingResponse { + pub id: RequestId, + pub result: Result, +} + +#[derive(Debug, Clone, PartialEq, Serialize)] +pub(crate) struct OutgoingError { + pub error: JSONRPCErrorError, + pub id: RequestId, +} + +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used)] + + use codex_core::protocol::EventMsg; + use codex_core::protocol::SessionConfiguredEvent; + use pretty_assertions::assert_eq; + use serde_json::json; + use uuid::Uuid; + + use super::*; + + #[tokio::test] + async fn test_send_event_as_notification() { + let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(2); + let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx); + + let event = Event { + id: "1".to_string(), + msg: EventMsg::SessionConfigured(SessionConfiguredEvent { + session_id: Uuid::new_v4(), + model: "gpt-4o".to_string(), + history_log_id: 1, + history_entry_count: 1000, + }), + }; + + outgoing_message_sender + .send_event_as_notification(&event, None) + .await; + + let result = outgoing_rx.recv().await.unwrap(); + let OutgoingMessage::Notification(OutgoingNotification { method, params }) = result else { + panic!("expected Notification for first message"); + }; + assert_eq!(method, "codex/event"); + + let Ok(expected_params) = serde_json::to_value(&event) else { + panic!("Event must serialize"); + }; + assert_eq!(params, Some(expected_params.clone())); + + let result2 = outgoing_rx.recv().await.unwrap(); + let OutgoingMessage::Notification(OutgoingNotification { + method: method2, + params: params2, + }) = result2 + else { + panic!("expected Notification for second message"); + }; + assert_eq!(method2, event.msg.to_string()); + assert_eq!(params2, Some(expected_params)); + } + + #[tokio::test] + async fn test_send_event_as_notification_with_meta() { + let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(2); + let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx); + + let session_configured_event = SessionConfiguredEvent { + session_id: Uuid::new_v4(), + model: "gpt-4o".to_string(), + history_log_id: 1, + history_entry_count: 1000, + }; + let event = Event { + id: "1".to_string(), + msg: EventMsg::SessionConfigured(session_configured_event.clone()), + }; + let meta = OutgoingNotificationMeta { + request_id: Some(RequestId::String("123".to_string())), + }; + + outgoing_message_sender + .send_event_as_notification(&event, Some(meta)) + .await; + + let result = outgoing_rx.recv().await.unwrap(); + let OutgoingMessage::Notification(OutgoingNotification { method, params }) = result else { + panic!("expected Notification for first message"); + }; + assert_eq!(method, "codex/event"); + let expected_params = json!({ + "_meta": { + "requestId": "123", + }, + "id": "1", + "msg": { + "session_id": session_configured_event.session_id, + "model": session_configured_event.model, + "history_log_id": session_configured_event.history_log_id, + "history_entry_count": session_configured_event.history_entry_count, + "type": "session_configured", + } + }); + assert_eq!(params.unwrap(), expected_params); + + let result2 = outgoing_rx.recv().await.unwrap(); + let OutgoingMessage::Notification(OutgoingNotification { + method: method2, + params: params2, + }) = result2 + else { + panic!("expected Notification for second message"); + }; + assert_eq!(method2, event.msg.to_string()); + assert_eq!(params2.unwrap(), expected_params); + } +} diff --git a/codex-rs/mcp-server/src/patch_approval.rs b/codex-rs/mcp-server/src/patch_approval.rs new file mode 100644 index 0000000000..db99ee5f27 --- /dev/null +++ b/codex-rs/mcp-server/src/patch_approval.rs @@ -0,0 +1,150 @@ +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::Arc; + +use codex_core::Codex; +use codex_core::protocol::FileChange; +use codex_core::protocol::Op; +use codex_core::protocol::ReviewDecision; +use mcp_types::ElicitRequest; +use mcp_types::ElicitRequestParamsRequestedSchema; +use mcp_types::JSONRPCErrorError; +use mcp_types::ModelContextProtocolRequest; +use mcp_types::RequestId; +use serde::Deserialize; +use serde::Serialize; +use serde_json::json; +use tracing::error; + +use crate::codex_tool_runner::INVALID_PARAMS_ERROR_CODE; +use crate::outgoing_message::OutgoingMessageSender; + +#[derive(Debug, Serialize)] +pub struct PatchApprovalElicitRequestParams { + pub message: String, + #[serde(rename = "requestedSchema")] + pub requested_schema: ElicitRequestParamsRequestedSchema, + pub codex_elicitation: String, + pub codex_mcp_tool_call_id: String, + pub codex_event_id: String, + pub codex_call_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub codex_reason: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub codex_grant_root: Option, + pub codex_changes: HashMap, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct PatchApprovalResponse { + pub decision: ReviewDecision, +} + +#[allow(clippy::too_many_arguments)] +pub(crate) async fn handle_patch_approval_request( + call_id: String, + reason: Option, + grant_root: Option, + changes: HashMap, + outgoing: Arc, + codex: Arc, + request_id: RequestId, + tool_call_id: String, + event_id: String, +) { + let mut message_lines = Vec::new(); + if let Some(r) = &reason { + message_lines.push(r.clone()); + } + message_lines.push("Allow Codex to apply proposed code changes?".to_string()); + + let params = PatchApprovalElicitRequestParams { + message: message_lines.join("\n"), + requested_schema: ElicitRequestParamsRequestedSchema { + r#type: "object".to_string(), + properties: json!({}), + required: None, + }, + codex_elicitation: "patch-approval".to_string(), + codex_mcp_tool_call_id: tool_call_id.clone(), + codex_event_id: event_id.clone(), + codex_call_id: call_id, + codex_reason: reason, + codex_grant_root: grant_root, + codex_changes: changes, + }; + let params_json = match serde_json::to_value(¶ms) { + Ok(value) => value, + Err(err) => { + let message = format!("Failed to serialize PatchApprovalElicitRequestParams: {err}"); + error!("{message}"); + + outgoing + .send_error( + request_id.clone(), + JSONRPCErrorError { + code: INVALID_PARAMS_ERROR_CODE, + message, + data: None, + }, + ) + .await; + + return; + } + }; + + let on_response = outgoing + .send_request(ElicitRequest::METHOD, Some(params_json)) + .await; + + // Listen for the response on a separate task so we don't block the main agent loop. + { + let codex = codex.clone(); + let event_id = event_id.clone(); + tokio::spawn(async move { + on_patch_approval_response(event_id, on_response, codex).await; + }); + } +} + +pub(crate) async fn on_patch_approval_response( + event_id: String, + receiver: tokio::sync::oneshot::Receiver, + codex: Arc, +) { + let response = receiver.await; + let value = match response { + Ok(value) => value, + Err(err) => { + error!("request failed: {err:?}"); + if let Err(submit_err) = codex + .submit(Op::PatchApproval { + id: event_id.clone(), + decision: ReviewDecision::Denied, + }) + .await + { + error!("failed to submit denied PatchApproval after request failure: {submit_err}"); + } + return; + } + }; + + let response = serde_json::from_value::(value).unwrap_or_else(|err| { + error!("failed to deserialize PatchApprovalResponse: {err}"); + PatchApprovalResponse { + decision: ReviewDecision::Denied, + } + }); + + if let Err(err) = codex + .submit(Op::PatchApproval { + id: event_id, + decision: response.decision, + }) + .await + { + error!("failed to submit PatchApproval: {err}"); + } +} diff --git a/codex-rs/mcp-server/tests/codex_tool.rs b/codex-rs/mcp-server/tests/codex_tool.rs new file mode 100644 index 0000000000..0f06483f24 --- /dev/null +++ b/codex-rs/mcp-server/tests/codex_tool.rs @@ -0,0 +1,440 @@ +use std::collections::HashMap; +use std::env; +use std::path::Path; +use std::path::PathBuf; + +use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR; +use codex_core::protocol::FileChange; +use codex_core::protocol::ReviewDecision; +use codex_mcp_server::CodexToolCallParam; +use codex_mcp_server::ExecApprovalElicitRequestParams; +use codex_mcp_server::ExecApprovalResponse; +use codex_mcp_server::PatchApprovalElicitRequestParams; +use codex_mcp_server::PatchApprovalResponse; +use mcp_types::ElicitRequest; +use mcp_types::ElicitRequestParamsRequestedSchema; +use mcp_types::JSONRPC_VERSION; +use mcp_types::JSONRPCRequest; +use mcp_types::JSONRPCResponse; +use mcp_types::ModelContextProtocolRequest; +use mcp_types::RequestId; +use pretty_assertions::assert_eq; +use serde_json::json; +use tempfile::TempDir; +use tokio::time::timeout; +use wiremock::MockServer; + +use mcp_test_support::McpProcess; +use mcp_test_support::create_apply_patch_sse_response; +use mcp_test_support::create_final_assistant_message_sse_response; +use mcp_test_support::create_mock_chat_completions_server; +use mcp_test_support::create_shell_sse_response; + +const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); + +/// Test that a shell command that is not on the "trusted" list triggers an +/// elicitation request to the MCP and that sending the approval runs the +/// command, as expected. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_shell_command_approval_triggers_elicitation() { + if env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + // Apparently `#[tokio::test]` must return `()`, so we create a helper + // function that returns `Result` so we can use `?` in favor of `unwrap`. + if let Err(err) = shell_command_approval_triggers_elicitation().await { + panic!("failure: {err}"); + } +} + +async fn shell_command_approval_triggers_elicitation() -> anyhow::Result<()> { + // We use `git init` because it will not be on the "trusted" list. + let shell_command = vec!["git".to_string(), "init".to_string()]; + let workdir_for_shell_function_call = TempDir::new()?; + + let McpHandle { + process: mut mcp_process, + server: _server, + dir: _dir, + } = create_mcp_process(vec![ + create_shell_sse_response( + shell_command.clone(), + Some(workdir_for_shell_function_call.path()), + Some(5_000), + "call1234", + )?, + create_final_assistant_message_sse_response("Enjoy your new git repo!")?, + ]) + .await?; + + // Send a "codex" tool request, which should hit the completions endpoint. + // In turn, it should reply with a tool call, which the MCP should forward + // as an elicitation. + let codex_request_id = mcp_process + .send_codex_tool_call(CodexToolCallParam { + prompt: "run `git init`".to_string(), + ..Default::default() + }) + .await?; + let elicitation_request = timeout( + DEFAULT_READ_TIMEOUT, + mcp_process.read_stream_until_request_message(), + ) + .await??; + + // This is the first request from the server, so the id should be 0 given + // how things are currently implemented. + let elicitation_request_id = RequestId::Integer(0); + let expected_elicitation_request = create_expected_elicitation_request( + elicitation_request_id.clone(), + shell_command.clone(), + workdir_for_shell_function_call.path(), + codex_request_id.to_string(), + // Internal Codex id: empirically it is 1, but this is + // admittedly an internal detail that could change. + "1".to_string(), + )?; + assert_eq!(expected_elicitation_request, elicitation_request); + + // Accept the `git init` request by responding to the elicitation. + mcp_process + .send_response( + elicitation_request_id, + serde_json::to_value(ExecApprovalResponse { + decision: ReviewDecision::Approved, + })?, + ) + .await?; + + // Verify the original `codex` tool call completes and that `git init` ran + // successfully. + let codex_response = timeout( + DEFAULT_READ_TIMEOUT, + mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)), + ) + .await??; + assert_eq!( + JSONRPCResponse { + jsonrpc: JSONRPC_VERSION.into(), + id: RequestId::Integer(codex_request_id), + result: json!({ + "content": [ + { + "text": "Enjoy your new git repo!", + "type": "text" + } + ] + }), + }, + codex_response + ); + + assert!( + workdir_for_shell_function_call.path().join(".git").is_dir(), + ".git folder should have been created" + ); + + Ok(()) +} + +fn create_expected_elicitation_request( + elicitation_request_id: RequestId, + command: Vec, + workdir: &Path, + codex_mcp_tool_call_id: String, + codex_event_id: String, +) -> anyhow::Result { + let expected_message = format!( + "Allow Codex to run `{}` in `{}`?", + shlex::try_join(command.iter().map(|s| s.as_ref()))?, + workdir.to_string_lossy() + ); + Ok(JSONRPCRequest { + jsonrpc: JSONRPC_VERSION.into(), + id: elicitation_request_id, + method: ElicitRequest::METHOD.to_string(), + params: Some(serde_json::to_value(&ExecApprovalElicitRequestParams { + message: expected_message, + requested_schema: ElicitRequestParamsRequestedSchema { + r#type: "object".to_string(), + properties: json!({}), + required: None, + }, + codex_elicitation: "exec-approval".to_string(), + codex_mcp_tool_call_id, + codex_event_id, + codex_command: command, + codex_cwd: workdir.to_path_buf(), + codex_call_id: "call1234".to_string(), + })?), + }) +} + +/// Test that patch approval triggers an elicitation request to the MCP and that +/// sending the approval applies the patch, as expected. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_patch_approval_triggers_elicitation() { + if env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + if let Err(err) = patch_approval_triggers_elicitation().await { + panic!("failure: {err}"); + } +} + +async fn patch_approval_triggers_elicitation() -> anyhow::Result<()> { + let cwd = TempDir::new()?; + let test_file = cwd.path().join("destination_file.txt"); + std::fs::write(&test_file, "original content\n")?; + + let patch_content = format!( + "*** Begin Patch\n*** Update File: {}\n-original content\n+modified content\n*** End Patch", + test_file.as_path().to_string_lossy() + ); + + let McpHandle { + process: mut mcp_process, + server: _server, + dir: _dir, + } = create_mcp_process(vec![ + create_apply_patch_sse_response(&patch_content, "call1234")?, + create_final_assistant_message_sse_response("Patch has been applied successfully!")?, + ]) + .await?; + + // Send a "codex" tool request that will trigger the apply_patch command + let codex_request_id = mcp_process + .send_codex_tool_call(CodexToolCallParam { + cwd: Some(cwd.path().to_string_lossy().to_string()), + prompt: "please modify the test file".to_string(), + ..Default::default() + }) + .await?; + let elicitation_request = timeout( + DEFAULT_READ_TIMEOUT, + mcp_process.read_stream_until_request_message(), + ) + .await??; + + let elicitation_request_id = RequestId::Integer(0); + + let mut expected_changes = HashMap::new(); + expected_changes.insert( + test_file.as_path().to_path_buf(), + FileChange::Update { + unified_diff: "@@ -1 +1 @@\n-original content\n+modified content\n".to_string(), + move_path: None, + }, + ); + + let expected_elicitation_request = create_expected_patch_approval_elicitation_request( + elicitation_request_id.clone(), + expected_changes, + None, // No grant_root expected + None, // No reason expected + codex_request_id.to_string(), + "1".to_string(), + )?; + assert_eq!(expected_elicitation_request, elicitation_request); + + // Accept the patch approval request by responding to the elicitation + mcp_process + .send_response( + elicitation_request_id, + serde_json::to_value(PatchApprovalResponse { + decision: ReviewDecision::Approved, + })?, + ) + .await?; + + // Verify the original `codex` tool call completes + let codex_response = timeout( + DEFAULT_READ_TIMEOUT, + mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)), + ) + .await??; + assert_eq!( + JSONRPCResponse { + jsonrpc: JSONRPC_VERSION.into(), + id: RequestId::Integer(codex_request_id), + result: json!({ + "content": [ + { + "text": "Patch has been applied successfully!", + "type": "text" + } + ] + }), + }, + codex_response + ); + + let file_contents = std::fs::read_to_string(test_file.as_path())?; + assert_eq!(file_contents, "modified content\n"); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_codex_tool_passes_base_instructions() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + // Apparently `#[tokio::test]` must return `()`, so we create a helper + // function that returns `Result` so we can use `?` in favor of `unwrap`. + if let Err(err) = codex_tool_passes_base_instructions().await { + panic!("failure: {err}"); + } +} + +async fn codex_tool_passes_base_instructions() -> anyhow::Result<()> { + #![allow(clippy::unwrap_used)] + + let server = + create_mock_chat_completions_server(vec![create_final_assistant_message_sse_response( + "Enjoy!", + )?]) + .await; + + // Run `codex mcp` with a specific config.toml. + let codex_home = TempDir::new()?; + create_config_toml(codex_home.path(), &server.uri())?; + let mut mcp_process = McpProcess::new(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??; + + // Send a "codex" tool request, which should hit the completions endpoint. + let codex_request_id = mcp_process + .send_codex_tool_call(CodexToolCallParam { + prompt: "How are you?".to_string(), + base_instructions: Some("You are a helpful assistant.".to_string()), + ..Default::default() + }) + .await?; + + let codex_response = timeout( + DEFAULT_READ_TIMEOUT, + mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)), + ) + .await??; + assert_eq!( + JSONRPCResponse { + jsonrpc: JSONRPC_VERSION.into(), + id: RequestId::Integer(codex_request_id), + result: json!({ + "content": [ + { + "text": "Enjoy!", + "type": "text" + } + ] + }), + }, + codex_response + ); + + let requests = server.received_requests().await.unwrap(); + let request = requests[0].body_json::().unwrap(); + let instructions = request["messages"][0]["content"].as_str().unwrap(); + assert!(instructions.starts_with("You are a helpful assistant.")); + + Ok(()) +} + +fn create_expected_patch_approval_elicitation_request( + elicitation_request_id: RequestId, + changes: HashMap, + grant_root: Option, + reason: Option, + codex_mcp_tool_call_id: String, + codex_event_id: String, +) -> anyhow::Result { + let mut message_lines = Vec::new(); + if let Some(r) = &reason { + message_lines.push(r.clone()); + } + message_lines.push("Allow Codex to apply proposed code changes?".to_string()); + + Ok(JSONRPCRequest { + jsonrpc: JSONRPC_VERSION.into(), + id: elicitation_request_id, + method: ElicitRequest::METHOD.to_string(), + params: Some(serde_json::to_value(&PatchApprovalElicitRequestParams { + message: message_lines.join("\n"), + requested_schema: ElicitRequestParamsRequestedSchema { + r#type: "object".to_string(), + properties: json!({}), + required: None, + }, + codex_elicitation: "patch-approval".to_string(), + codex_mcp_tool_call_id, + codex_event_id, + codex_reason: reason, + codex_grant_root: grant_root, + codex_changes: changes, + codex_call_id: "call1234".to_string(), + })?), + }) +} + +/// This handle is used to ensure that the MockServer and TempDir are not dropped while +/// the McpProcess is still running. +pub struct McpHandle { + pub process: McpProcess, + /// Retain the server for the lifetime of the McpProcess. + #[allow(dead_code)] + server: MockServer, + /// Retain the temporary directory for the lifetime of the McpProcess. + #[allow(dead_code)] + dir: TempDir, +} + +async fn create_mcp_process(responses: Vec) -> anyhow::Result { + let server = create_mock_chat_completions_server(responses).await; + let codex_home = TempDir::new()?; + create_config_toml(codex_home.path(), &server.uri())?; + let mut mcp_process = McpProcess::new(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??; + Ok(McpHandle { + process: mcp_process, + server, + dir: codex_home, + }) +} + +/// Create a Codex config that uses the mock server as the model provider. +/// It also uses `approval_policy = "untrusted"` so that we exercise the +/// elicitation code path for shell commands. +fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> { + let config_toml = codex_home.join("config.toml"); + std::fs::write( + config_toml, + format!( + r#" +model = "mock-model" +approval_policy = "untrusted" +sandbox_policy = "read-only" + +model_provider = "mock_provider" + +[model_providers.mock_provider] +name = "Mock provider for test" +base_url = "{server_uri}/v1" +wire_api = "chat" +request_max_retries = 0 +stream_max_retries = 0 +"# + ), + ) +} diff --git a/codex-rs/mcp-server/tests/common/Cargo.toml b/codex-rs/mcp-server/tests/common/Cargo.toml new file mode 100644 index 0000000000..3aa246f154 --- /dev/null +++ b/codex-rs/mcp-server/tests/common/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "mcp_test_support" +version = { workspace = true } +edition = "2024" + +[lib] +path = "lib.rs" + +[dependencies] +anyhow = "1" +assert_cmd = "2" +codex-mcp-server = { path = "../.." } +mcp-types = { path = "../../../mcp-types" } +pretty_assertions = "1.4.1" +serde_json = "1" +shlex = "1.3.0" +tempfile = "3" +tokio = { version = "1", features = [ + "io-std", + "macros", + "process", + "rt-multi-thread", +] } +wiremock = "0.6" diff --git a/codex-rs/mcp-server/tests/common/lib.rs b/codex-rs/mcp-server/tests/common/lib.rs new file mode 100644 index 0000000000..b338e2e8ce --- /dev/null +++ b/codex-rs/mcp-server/tests/common/lib.rs @@ -0,0 +1,9 @@ +mod mcp_process; +mod mock_model_server; +mod responses; + +pub use mcp_process::McpProcess; +pub use mock_model_server::create_mock_chat_completions_server; +pub use responses::create_apply_patch_sse_response; +pub use responses::create_final_assistant_message_sse_response; +pub use responses::create_shell_sse_response; diff --git a/codex-rs/mcp-server/tests/common/mcp_process.rs b/codex-rs/mcp-server/tests/common/mcp_process.rs new file mode 100644 index 0000000000..8138749c40 --- /dev/null +++ b/codex-rs/mcp-server/tests/common/mcp_process.rs @@ -0,0 +1,343 @@ +use std::path::Path; +use std::process::Stdio; +use std::sync::atomic::AtomicI64; +use std::sync::atomic::Ordering; +use tokio::io::AsyncBufReadExt; +use tokio::io::AsyncWriteExt; +use tokio::io::BufReader; +use tokio::process::Child; +use tokio::process::ChildStdin; +use tokio::process::ChildStdout; + +use anyhow::Context; +use assert_cmd::prelude::*; +use codex_mcp_server::CodexToolCallParam; +use codex_mcp_server::CodexToolCallReplyParam; +use mcp_types::CallToolRequestParams; +use mcp_types::ClientCapabilities; +use mcp_types::Implementation; +use mcp_types::InitializeRequestParams; +use mcp_types::JSONRPC_VERSION; +use mcp_types::JSONRPCMessage; +use mcp_types::JSONRPCNotification; +use mcp_types::JSONRPCRequest; +use mcp_types::JSONRPCResponse; +use mcp_types::ModelContextProtocolNotification; +use mcp_types::ModelContextProtocolRequest; +use mcp_types::RequestId; +use pretty_assertions::assert_eq; +use serde_json::json; +use std::process::Command as StdCommand; +use tokio::process::Command; + +pub struct McpProcess { + next_request_id: AtomicI64, + /// Retain this child process until the client is dropped. The Tokio runtime + /// will make a "best effort" to reap the process after it exits, but it is + /// not a guarantee. See the `kill_on_drop` documentation for details. + #[allow(dead_code)] + process: Child, + stdin: ChildStdin, + stdout: BufReader, +} + +impl McpProcess { + pub async fn new(codex_home: &Path) -> anyhow::Result { + // Use assert_cmd to locate the binary path and then switch to tokio::process::Command + let std_cmd = StdCommand::cargo_bin("codex-mcp-server") + .context("should find binary for codex-mcp-server")?; + + let program = std_cmd.get_program().to_owned(); + + let mut cmd = Command::new(program); + + cmd.stdin(Stdio::piped()); + cmd.stdout(Stdio::piped()); + cmd.env("CODEX_HOME", codex_home); + cmd.env("RUST_LOG", "debug"); + + let mut process = cmd + .kill_on_drop(true) + .spawn() + .context("codex-mcp-server proc should start")?; + let stdin = process + .stdin + .take() + .ok_or_else(|| anyhow::format_err!("mcp should have stdin fd"))?; + let stdout = process + .stdout + .take() + .ok_or_else(|| anyhow::format_err!("mcp should have stdout fd"))?; + let stdout = BufReader::new(stdout); + Ok(Self { + next_request_id: AtomicI64::new(0), + process, + stdin, + stdout, + }) + } + + /// Performs the initialization handshake with the MCP server. + pub async fn initialize(&mut self) -> anyhow::Result<()> { + let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed); + + let params = InitializeRequestParams { + capabilities: ClientCapabilities { + elicitation: Some(json!({})), + experimental: None, + roots: None, + sampling: None, + }, + client_info: Implementation { + name: "elicitation test".into(), + title: Some("Elicitation Test".into()), + version: "0.0.0".into(), + }, + protocol_version: mcp_types::MCP_SCHEMA_VERSION.into(), + }; + let params_value = serde_json::to_value(params)?; + + self.send_jsonrpc_message(JSONRPCMessage::Request(JSONRPCRequest { + jsonrpc: JSONRPC_VERSION.into(), + id: RequestId::Integer(request_id), + method: mcp_types::InitializeRequest::METHOD.into(), + params: Some(params_value), + })) + .await?; + + let initialized = self.read_jsonrpc_message().await?; + assert_eq!( + JSONRPCMessage::Response(JSONRPCResponse { + jsonrpc: JSONRPC_VERSION.into(), + id: RequestId::Integer(request_id), + result: json!({ + "capabilities": { + "tools": { + "listChanged": true + }, + }, + "serverInfo": { + "name": "codex-mcp-server", + "title": "Codex", + "version": "0.0.0" + }, + "protocolVersion": mcp_types::MCP_SCHEMA_VERSION + }) + }), + initialized + ); + + // Send notifications/initialized to ack the response. + self.send_jsonrpc_message(JSONRPCMessage::Notification(JSONRPCNotification { + jsonrpc: JSONRPC_VERSION.into(), + method: mcp_types::InitializedNotification::METHOD.into(), + params: None, + })) + .await?; + + Ok(()) + } + + /// Returns the id used to make the request so it can be used when + /// correlating notifications. + pub async fn send_codex_tool_call( + &mut self, + params: CodexToolCallParam, + ) -> anyhow::Result { + let codex_tool_call_params = CallToolRequestParams { + name: "codex".to_string(), + arguments: Some(serde_json::to_value(params)?), + }; + self.send_request( + mcp_types::CallToolRequest::METHOD, + Some(serde_json::to_value(codex_tool_call_params)?), + ) + .await + } + + pub async fn send_codex_reply_tool_call( + &mut self, + session_id: &str, + prompt: &str, + ) -> anyhow::Result { + let codex_tool_call_params = CallToolRequestParams { + name: "codex-reply".to_string(), + arguments: Some(serde_json::to_value(CodexToolCallReplyParam { + prompt: prompt.to_string(), + session_id: session_id.to_string(), + })?), + }; + self.send_request( + mcp_types::CallToolRequest::METHOD, + Some(serde_json::to_value(codex_tool_call_params)?), + ) + .await + } + + async fn send_request( + &mut self, + method: &str, + params: Option, + ) -> anyhow::Result { + let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed); + + let message = JSONRPCMessage::Request(JSONRPCRequest { + jsonrpc: JSONRPC_VERSION.into(), + id: RequestId::Integer(request_id), + method: method.to_string(), + params, + }); + self.send_jsonrpc_message(message).await?; + Ok(request_id) + } + + pub async fn send_response( + &mut self, + id: RequestId, + result: serde_json::Value, + ) -> anyhow::Result<()> { + self.send_jsonrpc_message(JSONRPCMessage::Response(JSONRPCResponse { + jsonrpc: JSONRPC_VERSION.into(), + id, + result, + })) + .await + } + + async fn send_jsonrpc_message(&mut self, message: JSONRPCMessage) -> anyhow::Result<()> { + let payload = serde_json::to_string(&message)?; + self.stdin.write_all(payload.as_bytes()).await?; + self.stdin.write_all(b"\n").await?; + self.stdin.flush().await?; + Ok(()) + } + + async fn read_jsonrpc_message(&mut self) -> anyhow::Result { + let mut line = String::new(); + self.stdout.read_line(&mut line).await?; + let message = serde_json::from_str::(&line)?; + Ok(message) + } + pub async fn read_stream_until_request_message(&mut self) -> anyhow::Result { + loop { + let message = self.read_jsonrpc_message().await?; + eprint!("message: {message:?}"); + + match message { + JSONRPCMessage::Notification(_) => { + eprintln!("notification: {message:?}"); + } + JSONRPCMessage::Request(jsonrpc_request) => { + return Ok(jsonrpc_request); + } + JSONRPCMessage::Error(_) => { + anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}"); + } + JSONRPCMessage::Response(_) => { + anyhow::bail!("unexpected JSONRPCMessage::Response: {message:?}"); + } + } + } + } + + pub async fn read_stream_until_response_message( + &mut self, + request_id: RequestId, + ) -> anyhow::Result { + loop { + let message = self.read_jsonrpc_message().await?; + eprint!("message: {message:?}"); + + match message { + JSONRPCMessage::Notification(_) => { + eprintln!("notification: {message:?}"); + } + JSONRPCMessage::Request(_) => { + anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}"); + } + JSONRPCMessage::Error(_) => { + anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}"); + } + JSONRPCMessage::Response(jsonrpc_response) => { + if jsonrpc_response.id == request_id { + return Ok(jsonrpc_response); + } + } + } + } + } + + pub async fn read_stream_until_configured_response_message( + &mut self, + ) -> anyhow::Result { + let mut sid_old: Option = None; + let mut sid_new: Option = None; + loop { + let message = self.read_jsonrpc_message().await?; + eprint!("message: {message:?}"); + + match message { + JSONRPCMessage::Notification(notification) => { + if let Some(params) = notification.params { + // Back-compat schema: method == "codex/event" and msg.type == "session_configured" + if notification.method == "codex/event" { + if let Some(msg) = params.get("msg") { + if msg.get("type").and_then(|v| v.as_str()) + == Some("session_configured") + { + if let Some(session_id) = + msg.get("session_id").and_then(|v| v.as_str()) + { + sid_old = Some(session_id.to_string()); + } + } + } + } + // New schema: method is the Display of EventMsg::SessionConfigured => "SessionConfigured" + if notification.method == "session_configured" { + if let Some(msg) = params.get("msg") { + if let Some(session_id) = + msg.get("session_id").and_then(|v| v.as_str()) + { + sid_new = Some(session_id.to_string()); + } + } + } + } + + if sid_old.is_some() && sid_new.is_some() { + // Both seen, they must match + assert_eq!( + sid_old.as_ref().unwrap(), + sid_new.as_ref().unwrap(), + "session_id mismatch between old and new schema" + ); + return Ok(sid_old.unwrap()); + } + } + JSONRPCMessage::Request(_) => { + anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}"); + } + JSONRPCMessage::Error(_) => { + anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}"); + } + JSONRPCMessage::Response(_) => { + anyhow::bail!("unexpected JSONRPCMessage::Response: {message:?}"); + } + } + } + } + + pub async fn send_notification( + &mut self, + method: &str, + params: Option, + ) -> anyhow::Result<()> { + self.send_jsonrpc_message(JSONRPCMessage::Notification(JSONRPCNotification { + jsonrpc: JSONRPC_VERSION.into(), + method: method.to_string(), + params, + })) + .await + } +} diff --git a/codex-rs/mcp-server/tests/common/mock_model_server.rs b/codex-rs/mcp-server/tests/common/mock_model_server.rs new file mode 100644 index 0000000000..be7f3eb5b3 --- /dev/null +++ b/codex-rs/mcp-server/tests/common/mock_model_server.rs @@ -0,0 +1,47 @@ +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; + +use wiremock::Mock; +use wiremock::MockServer; +use wiremock::Respond; +use wiremock::ResponseTemplate; +use wiremock::matchers::method; +use wiremock::matchers::path; + +/// Create a mock server that will provide the responses, in order, for +/// requests to the `/v1/chat/completions` endpoint. +pub async fn create_mock_chat_completions_server(responses: Vec) -> MockServer { + let server = MockServer::start().await; + + let num_calls = responses.len(); + let seq_responder = SeqResponder { + num_calls: AtomicUsize::new(0), + responses, + }; + + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(seq_responder) + .expect(num_calls as u64) + .mount(&server) + .await; + + server +} + +struct SeqResponder { + num_calls: AtomicUsize, + responses: Vec, +} + +impl Respond for SeqResponder { + fn respond(&self, _: &wiremock::Request) -> ResponseTemplate { + let call_num = self.num_calls.fetch_add(1, Ordering::SeqCst); + match self.responses.get(call_num) { + Some(response) => ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(response.clone(), "text/event-stream"), + None => panic!("no response for {call_num}"), + } + } +} diff --git a/codex-rs/mcp-server/tests/common/responses.rs b/codex-rs/mcp-server/tests/common/responses.rs new file mode 100644 index 0000000000..9a827fb986 --- /dev/null +++ b/codex-rs/mcp-server/tests/common/responses.rs @@ -0,0 +1,95 @@ +use serde_json::json; +use std::path::Path; + +pub fn create_shell_sse_response( + command: Vec, + workdir: Option<&Path>, + timeout_ms: Option, + call_id: &str, +) -> anyhow::Result { + // The `arguments`` for the `shell` tool is a serialized JSON object. + let tool_call_arguments = serde_json::to_string(&json!({ + "command": command, + "workdir": workdir.map(|w| w.to_string_lossy()), + "timeout": timeout_ms + }))?; + let tool_call = json!({ + "choices": [ + { + "delta": { + "tool_calls": [ + { + "id": call_id, + "function": { + "name": "shell", + "arguments": tool_call_arguments + } + } + ] + }, + "finish_reason": "tool_calls" + } + ] + }); + + let sse = format!( + "data: {}\n\ndata: DONE\n\n", + serde_json::to_string(&tool_call)? + ); + Ok(sse) +} + +pub fn create_final_assistant_message_sse_response(message: &str) -> anyhow::Result { + let assistant_message = json!({ + "choices": [ + { + "delta": { + "content": message + }, + "finish_reason": "stop" + } + ] + }); + + let sse = format!( + "data: {}\n\ndata: DONE\n\n", + serde_json::to_string(&assistant_message)? + ); + Ok(sse) +} + +pub fn create_apply_patch_sse_response( + patch_content: &str, + call_id: &str, +) -> anyhow::Result { + // Use shell command to call apply_patch with heredoc format + let shell_command = format!("apply_patch <<'EOF'\n{patch_content}\nEOF"); + let tool_call_arguments = serde_json::to_string(&json!({ + "command": ["bash", "-lc", shell_command] + }))?; + + let tool_call = json!({ + "choices": [ + { + "delta": { + "tool_calls": [ + { + "id": call_id, + "function": { + "name": "shell", + "arguments": tool_call_arguments + } + } + ] + }, + "finish_reason": "tool_calls" + } + ] + }); + + let sse = format!( + "data: {}\n\ndata: DONE\n\n", + serde_json::to_string(&tool_call)? + ); + Ok(sse) +} diff --git a/codex-rs/mcp-server/tests/interrupt.rs b/codex-rs/mcp-server/tests/interrupt.rs new file mode 100644 index 0000000000..313bc7afab --- /dev/null +++ b/codex-rs/mcp-server/tests/interrupt.rs @@ -0,0 +1,177 @@ +#![cfg(unix)] +// Support code lives in the `mcp_test_support` crate under tests/common. + +use std::path::Path; + +use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR; +use codex_mcp_server::CodexToolCallParam; +use mcp_types::JSONRPCResponse; +use mcp_types::RequestId; +use serde_json::json; +use tempfile::TempDir; +use tokio::time::timeout; + +use mcp_test_support::McpProcess; +use mcp_test_support::create_mock_chat_completions_server; +use mcp_test_support::create_shell_sse_response; + +const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_shell_command_interruption() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + if let Err(err) = shell_command_interruption().await { + panic!("failure: {err}"); + } +} + +async fn shell_command_interruption() -> anyhow::Result<()> { + // Use a cross-platform blocking command. On Windows plain `sleep` is not guaranteed to exist + // (MSYS/GNU coreutils may be absent) and the failure causes the tool call to finish immediately, + // which triggers a second model request before the test sends the explicit follow-up. That + // prematurely consumes the second mocked SSE response and leads to a third POST (panic: no response for 2). + // Powershell Start-Sleep is always available on Windows runners. On Unix we keep using `sleep`. + #[cfg(target_os = "windows")] + let shell_command = vec![ + "powershell".to_string(), + "-Command".to_string(), + "Start-Sleep -Seconds 60".to_string(), + ]; + #[cfg(not(target_os = "windows"))] + let shell_command = vec!["sleep".to_string(), "60".to_string()]; + let workdir_for_shell_function_call = TempDir::new()?; + + // Create mock server with a single SSE response: the long sleep command + let server = create_mock_chat_completions_server(vec![ + create_shell_sse_response( + shell_command.clone(), + Some(workdir_for_shell_function_call.path()), + Some(60_000), // 60 seconds timeout in ms + "call_sleep", + )?, + create_shell_sse_response( + shell_command.clone(), + Some(workdir_for_shell_function_call.path()), + Some(60_000), // 60 seconds timeout in ms + "call_sleep", + )?, + ]) + .await; + + // Create Codex configuration + let codex_home = TempDir::new()?; + create_config_toml(codex_home.path(), server.uri())?; + let mut mcp_process = McpProcess::new(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??; + + // Send codex tool call that triggers "sleep 60" + let codex_request_id = mcp_process + .send_codex_tool_call(CodexToolCallParam { + cwd: None, + prompt: "First Run: run `sleep 60`".to_string(), + model: None, + profile: None, + approval_policy: None, + sandbox: None, + config: None, + base_instructions: None, + include_plan_tool: None, + }) + .await?; + + let session_id = mcp_process + .read_stream_until_configured_response_message() + .await?; + + // Give the command a moment to start + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + // Send interrupt notification + mcp_process + .send_notification( + "notifications/cancelled", + Some(json!({ "requestId": codex_request_id })), + ) + .await?; + + // Expect Codex to return an error or interruption response + let codex_response: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)), + ) + .await??; + + assert!( + codex_response + .result + .as_object() + .map(|o| o.contains_key("error")) + .unwrap_or(false), + "Expected an interruption or error result, got: {codex_response:?}" + ); + + let codex_reply_request_id = mcp_process + .send_codex_reply_tool_call(&session_id, "Second Run: run `sleep 60`") + .await?; + + // Give the command a moment to start + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + // Send interrupt notification + mcp_process + .send_notification( + "notifications/cancelled", + Some(json!({ "requestId": codex_reply_request_id })), + ) + .await?; + + // Expect Codex to return an error or interruption response + let codex_response: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp_process.read_stream_until_response_message(RequestId::Integer(codex_reply_request_id)), + ) + .await??; + + assert!( + codex_response + .result + .as_object() + .map(|o| o.contains_key("error")) + .unwrap_or(false), + "Expected an interruption or error result, got: {codex_response:?}" + ); + Ok(()) +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn create_config_toml(codex_home: &Path, server_uri: String) -> std::io::Result<()> { + let config_toml = codex_home.join("config.toml"); + std::fs::write( + config_toml, + format!( + r#" +model = "mock-model" +approval_policy = "never" +sandbox_mode = "danger-full-access" + +model_provider = "mock_provider" + +[model_providers.mock_provider] +name = "Mock provider for test" +base_url = "{server_uri}/v1" +wire_api = "chat" +request_max_retries = 0 +stream_max_retries = 0 +"# + ), + ) +} diff --git a/codex-rs/mcp-types/Cargo.toml b/codex-rs/mcp-types/Cargo.toml index 81ac2d9761..db849d5f0e 100644 --- a/codex-rs/mcp-types/Cargo.toml +++ b/codex-rs/mcp-types/Cargo.toml @@ -1,7 +1,7 @@ [package] +edition = "2024" name = "mcp-types" version = { workspace = true } -edition = "2024" [lints] workspace = true diff --git a/codex-rs/mcp-types/README.md b/codex-rs/mcp-types/README.md index 2ac613ea96..66ea540cc4 100644 --- a/codex-rs/mcp-types/README.md +++ b/codex-rs/mcp-types/README.md @@ -2,7 +2,7 @@ Types for Model Context Protocol. Inspired by https://crates.io/crates/lsp-types. -As documented on https://modelcontextprotocol.io/specification/2025-03-26/basic: +As documented on https://modelcontextprotocol.io/specification/2025-06-18/basic: -- TypeScript schema is the source of truth: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.ts -- JSON schema is amenable to automated tooling: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.json +- TypeScript schema is the source of truth: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-06-18/schema.ts +- JSON schema is amenable to automated tooling: https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-06-18/schema.json diff --git a/codex-rs/mcp-types/generate_mcp_types.py b/codex-rs/mcp-types/generate_mcp_types.py index ff11dbf0dc..38f57e9a1b 100755 --- a/codex-rs/mcp-types/generate_mcp_types.py +++ b/codex-rs/mcp-types/generate_mcp_types.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # flake8: noqa: E501 +import argparse import json import subprocess import sys @@ -13,10 +14,13 @@ from pathlib import Path # Helper first so it is defined when other functions call it. from typing import Any, Literal -SCHEMA_VERSION = "2025-03-26" +SCHEMA_VERSION = "2025-06-18" JSONRPC_VERSION = "2.0" STANDARD_DERIVE = "#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]\n" +STANDARD_HASHABLE_DERIVE = ( + "#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Hash, Eq)]\n" +) # Will be populated with the schema's `definitions` map in `main()` so that # helper functions (for example `define_any_of`) can perform look-ups while @@ -26,19 +30,27 @@ DEFINITIONS: dict[str, Any] = {} CLIENT_REQUEST_TYPE_NAMES: list[str] = [] # Concrete *Notification types that make up the ServerNotification enum. SERVER_NOTIFICATION_TYPE_NAMES: list[str] = [] +# Enum types that will need a `allow(clippy::large_enum_variant)` annotation in +# order to compile without warnings. +LARGE_ENUMS = {"ServerResult"} def main() -> int: - num_args = len(sys.argv) - if num_args == 1: - schema_file = ( - Path(__file__).resolve().parent / "schema" / SCHEMA_VERSION / "schema.json" - ) - elif num_args == 2: - schema_file = Path(sys.argv[1]) - else: - print("Usage: python3 codegen.py ") - return 1 + parser = argparse.ArgumentParser( + description="Embed, cluster and analyse text prompts via the OpenAI API.", + ) + + default_schema_file = ( + Path(__file__).resolve().parent / "schema" / SCHEMA_VERSION / "schema.json" + ) + parser.add_argument( + "schema_file", + nargs="?", + default=default_schema_file, + help="schema.json file to process", + ) + args = parser.parse_args() + schema_file = args.schema_file lib_rs = Path(__file__).resolve().parent / "src/lib.rs" @@ -197,6 +209,8 @@ def add_definition(name: str, definition: dict[str, Any], out: list[str]) -> Non if name.endswith("Result"): out.extend(f"impl From<{name}> for serde_json::Value {{\n") out.append(f" fn from(value: {name}) -> Self {{\n") + out.append(" // Leave this as it should never fail\n") + out.append(" #[expect(clippy::unwrap_used)]\n") out.append(" serde_json::to_value(value).unwrap()\n") out.append(" }\n") out.append("}\n\n") @@ -211,20 +225,7 @@ def add_definition(name: str, definition: dict[str, Any], out: list[str]) -> Non any_of = definition.get("anyOf", []) if any_of: assert isinstance(any_of, list) - if name == "JSONRPCMessage": - # Special case for JSONRPCMessage because its definition in the - # JSON schema does not quite match how we think about this type - # definition in Rust. - deep_copied_any_of = json.loads(json.dumps(any_of)) - deep_copied_any_of[2] = { - "$ref": "#/definitions/JSONRPCBatchRequest", - } - deep_copied_any_of[5] = { - "$ref": "#/definitions/JSONRPCBatchResponse", - } - out.extend(define_any_of(name, deep_copied_any_of, description)) - else: - out.extend(define_any_of(name, any_of, description)) + out.extend(define_any_of(name, any_of, description)) return type_prop = definition.get("type", None) @@ -393,7 +394,7 @@ def define_string_enum( def define_untagged_enum(name: str, type_list: list[str], out: list[str]) -> None: - out.append(STANDARD_DERIVE) + out.append(STANDARD_HASHABLE_DERIVE) out.append("#[serde(untagged)]\n") out.append(f"pub enum {name} {{\n") for simple_type in type_list: @@ -439,6 +440,8 @@ def define_any_of( if serde := get_serde_annotation_for_anyof_type(name): out.append(serde + "\n") + if name in LARGE_ENUMS: + out.append("#[allow(clippy::large_enum_variant)]\n") out.append(f"pub enum {name} {{\n") if name == "ClientRequest": @@ -596,6 +599,8 @@ def rust_prop_name(name: str, is_optional: bool) -> RustProp: prop_name = "r#type" elif name == "ref": prop_name = "r#ref" + elif name == "enum": + prop_name = "r#enum" elif snake_case := to_snake_case(name): prop_name = snake_case is_rename = True diff --git a/codex-rs/mcp-types/schema/2025-06-18/schema.json b/codex-rs/mcp-types/schema/2025-06-18/schema.json new file mode 100644 index 0000000000..24ba4f6309 --- /dev/null +++ b/codex-rs/mcp-types/schema/2025-06-18/schema.json @@ -0,0 +1,2517 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "definitions": { + "Annotations": { + "description": "Optional annotations for the client. The client can use annotations to inform how objects are used or displayed", + "properties": { + "audience": { + "description": "Describes who the intended customer of this object or data is.\n\nIt can include multiple entries to indicate content useful for multiple audiences (e.g., `[\"user\", \"assistant\"]`).", + "items": { + "$ref": "#/definitions/Role" + }, + "type": "array" + }, + "lastModified": { + "description": "The moment the resource was last modified, as an ISO 8601 formatted string.\n\nShould be an ISO 8601 formatted string (e.g., \"2025-01-12T15:00:58Z\").\n\nExamples: last activity timestamp in an open file, timestamp when the resource\nwas attached, etc.", + "type": "string" + }, + "priority": { + "description": "Describes how important this data is for operating the server.\n\nA value of 1 means \"most important,\" and indicates that the data is\neffectively required, while 0 means \"least important,\" and indicates that\nthe data is entirely optional.", + "maximum": 1, + "minimum": 0, + "type": "number" + } + }, + "type": "object" + }, + "AudioContent": { + "description": "Audio provided to or from an LLM.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "annotations": { + "$ref": "#/definitions/Annotations", + "description": "Optional annotations for the client." + }, + "data": { + "description": "The base64-encoded audio data.", + "format": "byte", + "type": "string" + }, + "mimeType": { + "description": "The MIME type of the audio. Different providers may support different audio types.", + "type": "string" + }, + "type": { + "const": "audio", + "type": "string" + } + }, + "required": [ + "data", + "mimeType", + "type" + ], + "type": "object" + }, + "BaseMetadata": { + "description": "Base interface for metadata with name (identifier) and title (display name) properties.", + "properties": { + "name": { + "description": "Intended for programmatic or logical use, but used as a display name in past specs or fallback (if title isn't present).", + "type": "string" + }, + "title": { + "description": "Intended for UI and end-user contexts — optimized to be human-readable and easily understood,\neven by those unfamiliar with domain-specific terminology.\n\nIf not provided, the name should be used for display (except for Tool,\nwhere `annotations.title` should be given precedence over using `name`,\nif present).", + "type": "string" + } + }, + "required": [ + "name" + ], + "type": "object" + }, + "BlobResourceContents": { + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "blob": { + "description": "A base64-encoded string representing the binary data of the item.", + "format": "byte", + "type": "string" + }, + "mimeType": { + "description": "The MIME type of this resource, if known.", + "type": "string" + }, + "uri": { + "description": "The URI of this resource.", + "format": "uri", + "type": "string" + } + }, + "required": [ + "blob", + "uri" + ], + "type": "object" + }, + "BooleanSchema": { + "properties": { + "default": { + "type": "boolean" + }, + "description": { + "type": "string" + }, + "title": { + "type": "string" + }, + "type": { + "const": "boolean", + "type": "string" + } + }, + "required": [ + "type" + ], + "type": "object" + }, + "CallToolRequest": { + "description": "Used by the client to invoke a tool provided by the server.", + "properties": { + "method": { + "const": "tools/call", + "type": "string" + }, + "params": { + "properties": { + "arguments": { + "additionalProperties": {}, + "type": "object" + }, + "name": { + "type": "string" + } + }, + "required": [ + "name" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "CallToolResult": { + "description": "The server's response to a tool call.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "content": { + "description": "A list of content objects that represent the unstructured result of the tool call.", + "items": { + "$ref": "#/definitions/ContentBlock" + }, + "type": "array" + }, + "isError": { + "description": "Whether the tool call ended in an error.\n\nIf not set, this is assumed to be false (the call was successful).\n\nAny errors that originate from the tool SHOULD be reported inside the result\nobject, with `isError` set to true, _not_ as an MCP protocol-level error\nresponse. Otherwise, the LLM would not be able to see that an error occurred\nand self-correct.\n\nHowever, any errors in _finding_ the tool, an error indicating that the\nserver does not support tool calls, or any other exceptional conditions,\nshould be reported as an MCP error response.", + "type": "boolean" + }, + "structuredContent": { + "additionalProperties": {}, + "description": "An optional JSON object that represents the structured result of the tool call.", + "type": "object" + } + }, + "required": [ + "content" + ], + "type": "object" + }, + "CancelledNotification": { + "description": "This notification can be sent by either side to indicate that it is cancelling a previously-issued request.\n\nThe request SHOULD still be in-flight, but due to communication latency, it is always possible that this notification MAY arrive after the request has already finished.\n\nThis notification indicates that the result will be unused, so any associated processing SHOULD cease.\n\nA client MUST NOT attempt to cancel its `initialize` request.", + "properties": { + "method": { + "const": "notifications/cancelled", + "type": "string" + }, + "params": { + "properties": { + "reason": { + "description": "An optional string describing the reason for the cancellation. This MAY be logged or presented to the user.", + "type": "string" + }, + "requestId": { + "$ref": "#/definitions/RequestId", + "description": "The ID of the request to cancel.\n\nThis MUST correspond to the ID of a request previously issued in the same direction." + } + }, + "required": [ + "requestId" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "ClientCapabilities": { + "description": "Capabilities a client may support. Known capabilities are defined here, in this schema, but this is not a closed set: any client can define its own, additional capabilities.", + "properties": { + "elicitation": { + "additionalProperties": true, + "description": "Present if the client supports elicitation from the server.", + "properties": {}, + "type": "object" + }, + "experimental": { + "additionalProperties": { + "additionalProperties": true, + "properties": {}, + "type": "object" + }, + "description": "Experimental, non-standard capabilities that the client supports.", + "type": "object" + }, + "roots": { + "description": "Present if the client supports listing roots.", + "properties": { + "listChanged": { + "description": "Whether the client supports notifications for changes to the roots list.", + "type": "boolean" + } + }, + "type": "object" + }, + "sampling": { + "additionalProperties": true, + "description": "Present if the client supports sampling from an LLM.", + "properties": {}, + "type": "object" + } + }, + "type": "object" + }, + "ClientNotification": { + "anyOf": [ + { + "$ref": "#/definitions/CancelledNotification" + }, + { + "$ref": "#/definitions/InitializedNotification" + }, + { + "$ref": "#/definitions/ProgressNotification" + }, + { + "$ref": "#/definitions/RootsListChangedNotification" + } + ] + }, + "ClientRequest": { + "anyOf": [ + { + "$ref": "#/definitions/InitializeRequest" + }, + { + "$ref": "#/definitions/PingRequest" + }, + { + "$ref": "#/definitions/ListResourcesRequest" + }, + { + "$ref": "#/definitions/ListResourceTemplatesRequest" + }, + { + "$ref": "#/definitions/ReadResourceRequest" + }, + { + "$ref": "#/definitions/SubscribeRequest" + }, + { + "$ref": "#/definitions/UnsubscribeRequest" + }, + { + "$ref": "#/definitions/ListPromptsRequest" + }, + { + "$ref": "#/definitions/GetPromptRequest" + }, + { + "$ref": "#/definitions/ListToolsRequest" + }, + { + "$ref": "#/definitions/CallToolRequest" + }, + { + "$ref": "#/definitions/SetLevelRequest" + }, + { + "$ref": "#/definitions/CompleteRequest" + } + ] + }, + "ClientResult": { + "anyOf": [ + { + "$ref": "#/definitions/Result" + }, + { + "$ref": "#/definitions/CreateMessageResult" + }, + { + "$ref": "#/definitions/ListRootsResult" + }, + { + "$ref": "#/definitions/ElicitResult" + } + ] + }, + "CompleteRequest": { + "description": "A request from the client to the server, to ask for completion options.", + "properties": { + "method": { + "const": "completion/complete", + "type": "string" + }, + "params": { + "properties": { + "argument": { + "description": "The argument's information", + "properties": { + "name": { + "description": "The name of the argument", + "type": "string" + }, + "value": { + "description": "The value of the argument to use for completion matching.", + "type": "string" + } + }, + "required": [ + "name", + "value" + ], + "type": "object" + }, + "context": { + "description": "Additional, optional context for completions", + "properties": { + "arguments": { + "additionalProperties": { + "type": "string" + }, + "description": "Previously-resolved variables in a URI template or prompt.", + "type": "object" + } + }, + "type": "object" + }, + "ref": { + "anyOf": [ + { + "$ref": "#/definitions/PromptReference" + }, + { + "$ref": "#/definitions/ResourceTemplateReference" + } + ] + } + }, + "required": [ + "argument", + "ref" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "CompleteResult": { + "description": "The server's response to a completion/complete request", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "completion": { + "properties": { + "hasMore": { + "description": "Indicates whether there are additional completion options beyond those provided in the current response, even if the exact total is unknown.", + "type": "boolean" + }, + "total": { + "description": "The total number of completion options available. This can exceed the number of values actually sent in the response.", + "type": "integer" + }, + "values": { + "description": "An array of completion values. Must not exceed 100 items.", + "items": { + "type": "string" + }, + "type": "array" + } + }, + "required": [ + "values" + ], + "type": "object" + } + }, + "required": [ + "completion" + ], + "type": "object" + }, + "ContentBlock": { + "anyOf": [ + { + "$ref": "#/definitions/TextContent" + }, + { + "$ref": "#/definitions/ImageContent" + }, + { + "$ref": "#/definitions/AudioContent" + }, + { + "$ref": "#/definitions/ResourceLink" + }, + { + "$ref": "#/definitions/EmbeddedResource" + } + ] + }, + "CreateMessageRequest": { + "description": "A request from the server to sample an LLM via the client. The client has full discretion over which model to select. The client should also inform the user before beginning sampling, to allow them to inspect the request (human in the loop) and decide whether to approve it.", + "properties": { + "method": { + "const": "sampling/createMessage", + "type": "string" + }, + "params": { + "properties": { + "includeContext": { + "description": "A request to include context from one or more MCP servers (including the caller), to be attached to the prompt. The client MAY ignore this request.", + "enum": [ + "allServers", + "none", + "thisServer" + ], + "type": "string" + }, + "maxTokens": { + "description": "The maximum number of tokens to sample, as requested by the server. The client MAY choose to sample fewer tokens than requested.", + "type": "integer" + }, + "messages": { + "items": { + "$ref": "#/definitions/SamplingMessage" + }, + "type": "array" + }, + "metadata": { + "additionalProperties": true, + "description": "Optional metadata to pass through to the LLM provider. The format of this metadata is provider-specific.", + "properties": {}, + "type": "object" + }, + "modelPreferences": { + "$ref": "#/definitions/ModelPreferences", + "description": "The server's preferences for which model to select. The client MAY ignore these preferences." + }, + "stopSequences": { + "items": { + "type": "string" + }, + "type": "array" + }, + "systemPrompt": { + "description": "An optional system prompt the server wants to use for sampling. The client MAY modify or omit this prompt.", + "type": "string" + }, + "temperature": { + "type": "number" + } + }, + "required": [ + "maxTokens", + "messages" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "CreateMessageResult": { + "description": "The client's response to a sampling/create_message request from the server. The client should inform the user before returning the sampled message, to allow them to inspect the response (human in the loop) and decide whether to allow the server to see it.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "content": { + "anyOf": [ + { + "$ref": "#/definitions/TextContent" + }, + { + "$ref": "#/definitions/ImageContent" + }, + { + "$ref": "#/definitions/AudioContent" + } + ] + }, + "model": { + "description": "The name of the model that generated the message.", + "type": "string" + }, + "role": { + "$ref": "#/definitions/Role" + }, + "stopReason": { + "description": "The reason why sampling stopped, if known.", + "type": "string" + } + }, + "required": [ + "content", + "model", + "role" + ], + "type": "object" + }, + "Cursor": { + "description": "An opaque token used to represent a cursor for pagination.", + "type": "string" + }, + "ElicitRequest": { + "description": "A request from the server to elicit additional information from the user via the client.", + "properties": { + "method": { + "const": "elicitation/create", + "type": "string" + }, + "params": { + "properties": { + "message": { + "description": "The message to present to the user.", + "type": "string" + }, + "requestedSchema": { + "description": "A restricted subset of JSON Schema.\nOnly top-level properties are allowed, without nesting.", + "properties": { + "properties": { + "additionalProperties": { + "$ref": "#/definitions/PrimitiveSchemaDefinition" + }, + "type": "object" + }, + "required": { + "items": { + "type": "string" + }, + "type": "array" + }, + "type": { + "const": "object", + "type": "string" + } + }, + "required": [ + "properties", + "type" + ], + "type": "object" + } + }, + "required": [ + "message", + "requestedSchema" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "ElicitResult": { + "description": "The client's response to an elicitation request.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "action": { + "description": "The user action in response to the elicitation.\n- \"accept\": User submitted the form/confirmed the action\n- \"decline\": User explicitly declined the action\n- \"cancel\": User dismissed without making an explicit choice", + "enum": [ + "accept", + "cancel", + "decline" + ], + "type": "string" + }, + "content": { + "additionalProperties": { + "type": [ + "string", + "integer", + "boolean" + ] + }, + "description": "The submitted form data, only present when action is \"accept\".\nContains values matching the requested schema.", + "type": "object" + } + }, + "required": [ + "action" + ], + "type": "object" + }, + "EmbeddedResource": { + "description": "The contents of a resource, embedded into a prompt or tool call result.\n\nIt is up to the client how best to render embedded resources for the benefit\nof the LLM and/or the user.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "annotations": { + "$ref": "#/definitions/Annotations", + "description": "Optional annotations for the client." + }, + "resource": { + "anyOf": [ + { + "$ref": "#/definitions/TextResourceContents" + }, + { + "$ref": "#/definitions/BlobResourceContents" + } + ] + }, + "type": { + "const": "resource", + "type": "string" + } + }, + "required": [ + "resource", + "type" + ], + "type": "object" + }, + "EmptyResult": { + "$ref": "#/definitions/Result" + }, + "EnumSchema": { + "properties": { + "description": { + "type": "string" + }, + "enum": { + "items": { + "type": "string" + }, + "type": "array" + }, + "enumNames": { + "items": { + "type": "string" + }, + "type": "array" + }, + "title": { + "type": "string" + }, + "type": { + "const": "string", + "type": "string" + } + }, + "required": [ + "enum", + "type" + ], + "type": "object" + }, + "GetPromptRequest": { + "description": "Used by the client to get a prompt provided by the server.", + "properties": { + "method": { + "const": "prompts/get", + "type": "string" + }, + "params": { + "properties": { + "arguments": { + "additionalProperties": { + "type": "string" + }, + "description": "Arguments to use for templating the prompt.", + "type": "object" + }, + "name": { + "description": "The name of the prompt or prompt template.", + "type": "string" + } + }, + "required": [ + "name" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "GetPromptResult": { + "description": "The server's response to a prompts/get request from the client.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "description": { + "description": "An optional description for the prompt.", + "type": "string" + }, + "messages": { + "items": { + "$ref": "#/definitions/PromptMessage" + }, + "type": "array" + } + }, + "required": [ + "messages" + ], + "type": "object" + }, + "ImageContent": { + "description": "An image provided to or from an LLM.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "annotations": { + "$ref": "#/definitions/Annotations", + "description": "Optional annotations for the client." + }, + "data": { + "description": "The base64-encoded image data.", + "format": "byte", + "type": "string" + }, + "mimeType": { + "description": "The MIME type of the image. Different providers may support different image types.", + "type": "string" + }, + "type": { + "const": "image", + "type": "string" + } + }, + "required": [ + "data", + "mimeType", + "type" + ], + "type": "object" + }, + "Implementation": { + "description": "Describes the name and version of an MCP implementation, with an optional title for UI representation.", + "properties": { + "name": { + "description": "Intended for programmatic or logical use, but used as a display name in past specs or fallback (if title isn't present).", + "type": "string" + }, + "title": { + "description": "Intended for UI and end-user contexts — optimized to be human-readable and easily understood,\neven by those unfamiliar with domain-specific terminology.\n\nIf not provided, the name should be used for display (except for Tool,\nwhere `annotations.title` should be given precedence over using `name`,\nif present).", + "type": "string" + }, + "version": { + "type": "string" + } + }, + "required": [ + "name", + "version" + ], + "type": "object" + }, + "InitializeRequest": { + "description": "This request is sent from the client to the server when it first connects, asking it to begin initialization.", + "properties": { + "method": { + "const": "initialize", + "type": "string" + }, + "params": { + "properties": { + "capabilities": { + "$ref": "#/definitions/ClientCapabilities" + }, + "clientInfo": { + "$ref": "#/definitions/Implementation" + }, + "protocolVersion": { + "description": "The latest version of the Model Context Protocol that the client supports. The client MAY decide to support older versions as well.", + "type": "string" + } + }, + "required": [ + "capabilities", + "clientInfo", + "protocolVersion" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "InitializeResult": { + "description": "After receiving an initialize request from the client, the server sends this response.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "capabilities": { + "$ref": "#/definitions/ServerCapabilities" + }, + "instructions": { + "description": "Instructions describing how to use the server and its features.\n\nThis can be used by clients to improve the LLM's understanding of available tools, resources, etc. It can be thought of like a \"hint\" to the model. For example, this information MAY be added to the system prompt.", + "type": "string" + }, + "protocolVersion": { + "description": "The version of the Model Context Protocol that the server wants to use. This may not match the version that the client requested. If the client cannot support this version, it MUST disconnect.", + "type": "string" + }, + "serverInfo": { + "$ref": "#/definitions/Implementation" + } + }, + "required": [ + "capabilities", + "protocolVersion", + "serverInfo" + ], + "type": "object" + }, + "InitializedNotification": { + "description": "This notification is sent from the client to the server after initialization has finished.", + "properties": { + "method": { + "const": "notifications/initialized", + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "JSONRPCError": { + "description": "A response to a request that indicates an error occurred.", + "properties": { + "error": { + "properties": { + "code": { + "description": "The error type that occurred.", + "type": "integer" + }, + "data": { + "description": "Additional information about the error. The value of this member is defined by the sender (e.g. detailed error information, nested errors etc.)." + }, + "message": { + "description": "A short description of the error. The message SHOULD be limited to a concise single sentence.", + "type": "string" + } + }, + "required": [ + "code", + "message" + ], + "type": "object" + }, + "id": { + "$ref": "#/definitions/RequestId" + }, + "jsonrpc": { + "const": "2.0", + "type": "string" + } + }, + "required": [ + "error", + "id", + "jsonrpc" + ], + "type": "object" + }, + "JSONRPCMessage": { + "anyOf": [ + { + "$ref": "#/definitions/JSONRPCRequest" + }, + { + "$ref": "#/definitions/JSONRPCNotification" + }, + { + "$ref": "#/definitions/JSONRPCResponse" + }, + { + "$ref": "#/definitions/JSONRPCError" + } + ], + "description": "Refers to any valid JSON-RPC object that can be decoded off the wire, or encoded to be sent." + }, + "JSONRPCNotification": { + "description": "A notification which does not expect a response.", + "properties": { + "jsonrpc": { + "const": "2.0", + "type": "string" + }, + "method": { + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "jsonrpc", + "method" + ], + "type": "object" + }, + "JSONRPCRequest": { + "description": "A request that expects a response.", + "properties": { + "id": { + "$ref": "#/definitions/RequestId" + }, + "jsonrpc": { + "const": "2.0", + "type": "string" + }, + "method": { + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "properties": { + "progressToken": { + "$ref": "#/definitions/ProgressToken", + "description": "If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications." + } + }, + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "id", + "jsonrpc", + "method" + ], + "type": "object" + }, + "JSONRPCResponse": { + "description": "A successful (non-error) response to a request.", + "properties": { + "id": { + "$ref": "#/definitions/RequestId" + }, + "jsonrpc": { + "const": "2.0", + "type": "string" + }, + "result": { + "$ref": "#/definitions/Result" + } + }, + "required": [ + "id", + "jsonrpc", + "result" + ], + "type": "object" + }, + "ListPromptsRequest": { + "description": "Sent from the client to request a list of prompts and prompt templates the server has.", + "properties": { + "method": { + "const": "prompts/list", + "type": "string" + }, + "params": { + "properties": { + "cursor": { + "description": "An opaque token representing the current pagination position.\nIf provided, the server should return results starting after this cursor.", + "type": "string" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "ListPromptsResult": { + "description": "The server's response to a prompts/list request from the client.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "nextCursor": { + "description": "An opaque token representing the pagination position after the last returned result.\nIf present, there may be more results available.", + "type": "string" + }, + "prompts": { + "items": { + "$ref": "#/definitions/Prompt" + }, + "type": "array" + } + }, + "required": [ + "prompts" + ], + "type": "object" + }, + "ListResourceTemplatesRequest": { + "description": "Sent from the client to request a list of resource templates the server has.", + "properties": { + "method": { + "const": "resources/templates/list", + "type": "string" + }, + "params": { + "properties": { + "cursor": { + "description": "An opaque token representing the current pagination position.\nIf provided, the server should return results starting after this cursor.", + "type": "string" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "ListResourceTemplatesResult": { + "description": "The server's response to a resources/templates/list request from the client.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "nextCursor": { + "description": "An opaque token representing the pagination position after the last returned result.\nIf present, there may be more results available.", + "type": "string" + }, + "resourceTemplates": { + "items": { + "$ref": "#/definitions/ResourceTemplate" + }, + "type": "array" + } + }, + "required": [ + "resourceTemplates" + ], + "type": "object" + }, + "ListResourcesRequest": { + "description": "Sent from the client to request a list of resources the server has.", + "properties": { + "method": { + "const": "resources/list", + "type": "string" + }, + "params": { + "properties": { + "cursor": { + "description": "An opaque token representing the current pagination position.\nIf provided, the server should return results starting after this cursor.", + "type": "string" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "ListResourcesResult": { + "description": "The server's response to a resources/list request from the client.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "nextCursor": { + "description": "An opaque token representing the pagination position after the last returned result.\nIf present, there may be more results available.", + "type": "string" + }, + "resources": { + "items": { + "$ref": "#/definitions/Resource" + }, + "type": "array" + } + }, + "required": [ + "resources" + ], + "type": "object" + }, + "ListRootsRequest": { + "description": "Sent from the server to request a list of root URIs from the client. Roots allow\nservers to ask for specific directories or files to operate on. A common example\nfor roots is providing a set of repositories or directories a server should operate\non.\n\nThis request is typically used when the server needs to understand the file system\nstructure or access specific locations that the client has permission to read from.", + "properties": { + "method": { + "const": "roots/list", + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "properties": { + "progressToken": { + "$ref": "#/definitions/ProgressToken", + "description": "If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications." + } + }, + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "ListRootsResult": { + "description": "The client's response to a roots/list request from the server.\nThis result contains an array of Root objects, each representing a root directory\nor file that the server can operate on.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "roots": { + "items": { + "$ref": "#/definitions/Root" + }, + "type": "array" + } + }, + "required": [ + "roots" + ], + "type": "object" + }, + "ListToolsRequest": { + "description": "Sent from the client to request a list of tools the server has.", + "properties": { + "method": { + "const": "tools/list", + "type": "string" + }, + "params": { + "properties": { + "cursor": { + "description": "An opaque token representing the current pagination position.\nIf provided, the server should return results starting after this cursor.", + "type": "string" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "ListToolsResult": { + "description": "The server's response to a tools/list request from the client.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "nextCursor": { + "description": "An opaque token representing the pagination position after the last returned result.\nIf present, there may be more results available.", + "type": "string" + }, + "tools": { + "items": { + "$ref": "#/definitions/Tool" + }, + "type": "array" + } + }, + "required": [ + "tools" + ], + "type": "object" + }, + "LoggingLevel": { + "description": "The severity of a log message.\n\nThese map to syslog message severities, as specified in RFC-5424:\nhttps://datatracker.ietf.org/doc/html/rfc5424#section-6.2.1", + "enum": [ + "alert", + "critical", + "debug", + "emergency", + "error", + "info", + "notice", + "warning" + ], + "type": "string" + }, + "LoggingMessageNotification": { + "description": "Notification of a log message passed from server to client. If no logging/setLevel request has been sent from the client, the server MAY decide which messages to send automatically.", + "properties": { + "method": { + "const": "notifications/message", + "type": "string" + }, + "params": { + "properties": { + "data": { + "description": "The data to be logged, such as a string message or an object. Any JSON serializable type is allowed here." + }, + "level": { + "$ref": "#/definitions/LoggingLevel", + "description": "The severity of this log message." + }, + "logger": { + "description": "An optional name of the logger issuing this message.", + "type": "string" + } + }, + "required": [ + "data", + "level" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "ModelHint": { + "description": "Hints to use for model selection.\n\nKeys not declared here are currently left unspecified by the spec and are up\nto the client to interpret.", + "properties": { + "name": { + "description": "A hint for a model name.\n\nThe client SHOULD treat this as a substring of a model name; for example:\n - `claude-3-5-sonnet` should match `claude-3-5-sonnet-20241022`\n - `sonnet` should match `claude-3-5-sonnet-20241022`, `claude-3-sonnet-20240229`, etc.\n - `claude` should match any Claude model\n\nThe client MAY also map the string to a different provider's model name or a different model family, as long as it fills a similar niche; for example:\n - `gemini-1.5-flash` could match `claude-3-haiku-20240307`", + "type": "string" + } + }, + "type": "object" + }, + "ModelPreferences": { + "description": "The server's preferences for model selection, requested of the client during sampling.\n\nBecause LLMs can vary along multiple dimensions, choosing the \"best\" model is\nrarely straightforward. Different models excel in different areas—some are\nfaster but less capable, others are more capable but more expensive, and so\non. This interface allows servers to express their priorities across multiple\ndimensions to help clients make an appropriate selection for their use case.\n\nThese preferences are always advisory. The client MAY ignore them. It is also\nup to the client to decide how to interpret these preferences and how to\nbalance them against other considerations.", + "properties": { + "costPriority": { + "description": "How much to prioritize cost when selecting a model. A value of 0 means cost\nis not important, while a value of 1 means cost is the most important\nfactor.", + "maximum": 1, + "minimum": 0, + "type": "number" + }, + "hints": { + "description": "Optional hints to use for model selection.\n\nIf multiple hints are specified, the client MUST evaluate them in order\n(such that the first match is taken).\n\nThe client SHOULD prioritize these hints over the numeric priorities, but\nMAY still use the priorities to select from ambiguous matches.", + "items": { + "$ref": "#/definitions/ModelHint" + }, + "type": "array" + }, + "intelligencePriority": { + "description": "How much to prioritize intelligence and capabilities when selecting a\nmodel. A value of 0 means intelligence is not important, while a value of 1\nmeans intelligence is the most important factor.", + "maximum": 1, + "minimum": 0, + "type": "number" + }, + "speedPriority": { + "description": "How much to prioritize sampling speed (latency) when selecting a model. A\nvalue of 0 means speed is not important, while a value of 1 means speed is\nthe most important factor.", + "maximum": 1, + "minimum": 0, + "type": "number" + } + }, + "type": "object" + }, + "Notification": { + "properties": { + "method": { + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "NumberSchema": { + "properties": { + "description": { + "type": "string" + }, + "maximum": { + "type": "integer" + }, + "minimum": { + "type": "integer" + }, + "title": { + "type": "string" + }, + "type": { + "enum": [ + "integer", + "number" + ], + "type": "string" + } + }, + "required": [ + "type" + ], + "type": "object" + }, + "PaginatedRequest": { + "properties": { + "method": { + "type": "string" + }, + "params": { + "properties": { + "cursor": { + "description": "An opaque token representing the current pagination position.\nIf provided, the server should return results starting after this cursor.", + "type": "string" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "PaginatedResult": { + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "nextCursor": { + "description": "An opaque token representing the pagination position after the last returned result.\nIf present, there may be more results available.", + "type": "string" + } + }, + "type": "object" + }, + "PingRequest": { + "description": "A ping, issued by either the server or the client, to check that the other party is still alive. The receiver must promptly respond, or else may be disconnected.", + "properties": { + "method": { + "const": "ping", + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "properties": { + "progressToken": { + "$ref": "#/definitions/ProgressToken", + "description": "If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications." + } + }, + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "PrimitiveSchemaDefinition": { + "anyOf": [ + { + "$ref": "#/definitions/StringSchema" + }, + { + "$ref": "#/definitions/NumberSchema" + }, + { + "$ref": "#/definitions/BooleanSchema" + }, + { + "$ref": "#/definitions/EnumSchema" + } + ], + "description": "Restricted schema definitions that only allow primitive types\nwithout nested objects or arrays." + }, + "ProgressNotification": { + "description": "An out-of-band notification used to inform the receiver of a progress update for a long-running request.", + "properties": { + "method": { + "const": "notifications/progress", + "type": "string" + }, + "params": { + "properties": { + "message": { + "description": "An optional message describing the current progress.", + "type": "string" + }, + "progress": { + "description": "The progress thus far. This should increase every time progress is made, even if the total is unknown.", + "type": "number" + }, + "progressToken": { + "$ref": "#/definitions/ProgressToken", + "description": "The progress token which was given in the initial request, used to associate this notification with the request that is proceeding." + }, + "total": { + "description": "Total number of items to process (or total progress required), if known.", + "type": "number" + } + }, + "required": [ + "progress", + "progressToken" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "ProgressToken": { + "description": "A progress token, used to associate progress notifications with the original request.", + "type": [ + "string", + "integer" + ] + }, + "Prompt": { + "description": "A prompt or prompt template that the server offers.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "arguments": { + "description": "A list of arguments to use for templating the prompt.", + "items": { + "$ref": "#/definitions/PromptArgument" + }, + "type": "array" + }, + "description": { + "description": "An optional description of what this prompt provides", + "type": "string" + }, + "name": { + "description": "Intended for programmatic or logical use, but used as a display name in past specs or fallback (if title isn't present).", + "type": "string" + }, + "title": { + "description": "Intended for UI and end-user contexts — optimized to be human-readable and easily understood,\neven by those unfamiliar with domain-specific terminology.\n\nIf not provided, the name should be used for display (except for Tool,\nwhere `annotations.title` should be given precedence over using `name`,\nif present).", + "type": "string" + } + }, + "required": [ + "name" + ], + "type": "object" + }, + "PromptArgument": { + "description": "Describes an argument that a prompt can accept.", + "properties": { + "description": { + "description": "A human-readable description of the argument.", + "type": "string" + }, + "name": { + "description": "Intended for programmatic or logical use, but used as a display name in past specs or fallback (if title isn't present).", + "type": "string" + }, + "required": { + "description": "Whether this argument must be provided.", + "type": "boolean" + }, + "title": { + "description": "Intended for UI and end-user contexts — optimized to be human-readable and easily understood,\neven by those unfamiliar with domain-specific terminology.\n\nIf not provided, the name should be used for display (except for Tool,\nwhere `annotations.title` should be given precedence over using `name`,\nif present).", + "type": "string" + } + }, + "required": [ + "name" + ], + "type": "object" + }, + "PromptListChangedNotification": { + "description": "An optional notification from the server to the client, informing it that the list of prompts it offers has changed. This may be issued by servers without any previous subscription from the client.", + "properties": { + "method": { + "const": "notifications/prompts/list_changed", + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "PromptMessage": { + "description": "Describes a message returned as part of a prompt.\n\nThis is similar to `SamplingMessage`, but also supports the embedding of\nresources from the MCP server.", + "properties": { + "content": { + "$ref": "#/definitions/ContentBlock" + }, + "role": { + "$ref": "#/definitions/Role" + } + }, + "required": [ + "content", + "role" + ], + "type": "object" + }, + "PromptReference": { + "description": "Identifies a prompt.", + "properties": { + "name": { + "description": "Intended for programmatic or logical use, but used as a display name in past specs or fallback (if title isn't present).", + "type": "string" + }, + "title": { + "description": "Intended for UI and end-user contexts — optimized to be human-readable and easily understood,\neven by those unfamiliar with domain-specific terminology.\n\nIf not provided, the name should be used for display (except for Tool,\nwhere `annotations.title` should be given precedence over using `name`,\nif present).", + "type": "string" + }, + "type": { + "const": "ref/prompt", + "type": "string" + } + }, + "required": [ + "name", + "type" + ], + "type": "object" + }, + "ReadResourceRequest": { + "description": "Sent from the client to the server, to read a specific resource URI.", + "properties": { + "method": { + "const": "resources/read", + "type": "string" + }, + "params": { + "properties": { + "uri": { + "description": "The URI of the resource to read. The URI can use any protocol; it is up to the server how to interpret it.", + "format": "uri", + "type": "string" + } + }, + "required": [ + "uri" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "ReadResourceResult": { + "description": "The server's response to a resources/read request from the client.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "contents": { + "items": { + "anyOf": [ + { + "$ref": "#/definitions/TextResourceContents" + }, + { + "$ref": "#/definitions/BlobResourceContents" + } + ] + }, + "type": "array" + } + }, + "required": [ + "contents" + ], + "type": "object" + }, + "Request": { + "properties": { + "method": { + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "properties": { + "progressToken": { + "$ref": "#/definitions/ProgressToken", + "description": "If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications." + } + }, + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "RequestId": { + "description": "A uniquely identifying ID for a request in JSON-RPC.", + "type": [ + "string", + "integer" + ] + }, + "Resource": { + "description": "A known resource that the server is capable of reading.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "annotations": { + "$ref": "#/definitions/Annotations", + "description": "Optional annotations for the client." + }, + "description": { + "description": "A description of what this resource represents.\n\nThis can be used by clients to improve the LLM's understanding of available resources. It can be thought of like a \"hint\" to the model.", + "type": "string" + }, + "mimeType": { + "description": "The MIME type of this resource, if known.", + "type": "string" + }, + "name": { + "description": "Intended for programmatic or logical use, but used as a display name in past specs or fallback (if title isn't present).", + "type": "string" + }, + "size": { + "description": "The size of the raw resource content, in bytes (i.e., before base64 encoding or any tokenization), if known.\n\nThis can be used by Hosts to display file sizes and estimate context window usage.", + "type": "integer" + }, + "title": { + "description": "Intended for UI and end-user contexts — optimized to be human-readable and easily understood,\neven by those unfamiliar with domain-specific terminology.\n\nIf not provided, the name should be used for display (except for Tool,\nwhere `annotations.title` should be given precedence over using `name`,\nif present).", + "type": "string" + }, + "uri": { + "description": "The URI of this resource.", + "format": "uri", + "type": "string" + } + }, + "required": [ + "name", + "uri" + ], + "type": "object" + }, + "ResourceContents": { + "description": "The contents of a specific resource or sub-resource.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "mimeType": { + "description": "The MIME type of this resource, if known.", + "type": "string" + }, + "uri": { + "description": "The URI of this resource.", + "format": "uri", + "type": "string" + } + }, + "required": [ + "uri" + ], + "type": "object" + }, + "ResourceLink": { + "description": "A resource that the server is capable of reading, included in a prompt or tool call result.\n\nNote: resource links returned by tools are not guaranteed to appear in the results of `resources/list` requests.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "annotations": { + "$ref": "#/definitions/Annotations", + "description": "Optional annotations for the client." + }, + "description": { + "description": "A description of what this resource represents.\n\nThis can be used by clients to improve the LLM's understanding of available resources. It can be thought of like a \"hint\" to the model.", + "type": "string" + }, + "mimeType": { + "description": "The MIME type of this resource, if known.", + "type": "string" + }, + "name": { + "description": "Intended for programmatic or logical use, but used as a display name in past specs or fallback (if title isn't present).", + "type": "string" + }, + "size": { + "description": "The size of the raw resource content, in bytes (i.e., before base64 encoding or any tokenization), if known.\n\nThis can be used by Hosts to display file sizes and estimate context window usage.", + "type": "integer" + }, + "title": { + "description": "Intended for UI and end-user contexts — optimized to be human-readable and easily understood,\neven by those unfamiliar with domain-specific terminology.\n\nIf not provided, the name should be used for display (except for Tool,\nwhere `annotations.title` should be given precedence over using `name`,\nif present).", + "type": "string" + }, + "type": { + "const": "resource_link", + "type": "string" + }, + "uri": { + "description": "The URI of this resource.", + "format": "uri", + "type": "string" + } + }, + "required": [ + "name", + "type", + "uri" + ], + "type": "object" + }, + "ResourceListChangedNotification": { + "description": "An optional notification from the server to the client, informing it that the list of resources it can read from has changed. This may be issued by servers without any previous subscription from the client.", + "properties": { + "method": { + "const": "notifications/resources/list_changed", + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "ResourceTemplate": { + "description": "A template description for resources available on the server.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "annotations": { + "$ref": "#/definitions/Annotations", + "description": "Optional annotations for the client." + }, + "description": { + "description": "A description of what this template is for.\n\nThis can be used by clients to improve the LLM's understanding of available resources. It can be thought of like a \"hint\" to the model.", + "type": "string" + }, + "mimeType": { + "description": "The MIME type for all resources that match this template. This should only be included if all resources matching this template have the same type.", + "type": "string" + }, + "name": { + "description": "Intended for programmatic or logical use, but used as a display name in past specs or fallback (if title isn't present).", + "type": "string" + }, + "title": { + "description": "Intended for UI and end-user contexts — optimized to be human-readable and easily understood,\neven by those unfamiliar with domain-specific terminology.\n\nIf not provided, the name should be used for display (except for Tool,\nwhere `annotations.title` should be given precedence over using `name`,\nif present).", + "type": "string" + }, + "uriTemplate": { + "description": "A URI template (according to RFC 6570) that can be used to construct resource URIs.", + "format": "uri-template", + "type": "string" + } + }, + "required": [ + "name", + "uriTemplate" + ], + "type": "object" + }, + "ResourceTemplateReference": { + "description": "A reference to a resource or resource template definition.", + "properties": { + "type": { + "const": "ref/resource", + "type": "string" + }, + "uri": { + "description": "The URI or URI template of the resource.", + "format": "uri-template", + "type": "string" + } + }, + "required": [ + "type", + "uri" + ], + "type": "object" + }, + "ResourceUpdatedNotification": { + "description": "A notification from the server to the client, informing it that a resource has changed and may need to be read again. This should only be sent if the client previously sent a resources/subscribe request.", + "properties": { + "method": { + "const": "notifications/resources/updated", + "type": "string" + }, + "params": { + "properties": { + "uri": { + "description": "The URI of the resource that has been updated. This might be a sub-resource of the one that the client actually subscribed to.", + "format": "uri", + "type": "string" + } + }, + "required": [ + "uri" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "Result": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + } + }, + "type": "object" + }, + "Role": { + "description": "The sender or recipient of messages and data in a conversation.", + "enum": [ + "assistant", + "user" + ], + "type": "string" + }, + "Root": { + "description": "Represents a root directory or file that the server can operate on.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "name": { + "description": "An optional name for the root. This can be used to provide a human-readable\nidentifier for the root, which may be useful for display purposes or for\nreferencing the root in other parts of the application.", + "type": "string" + }, + "uri": { + "description": "The URI identifying the root. This *must* start with file:// for now.\nThis restriction may be relaxed in future versions of the protocol to allow\nother URI schemes.", + "format": "uri", + "type": "string" + } + }, + "required": [ + "uri" + ], + "type": "object" + }, + "RootsListChangedNotification": { + "description": "A notification from the client to the server, informing it that the list of roots has changed.\nThis notification should be sent whenever the client adds, removes, or modifies any root.\nThe server should then request an updated list of roots using the ListRootsRequest.", + "properties": { + "method": { + "const": "notifications/roots/list_changed", + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "SamplingMessage": { + "description": "Describes a message issued to or received from an LLM API.", + "properties": { + "content": { + "anyOf": [ + { + "$ref": "#/definitions/TextContent" + }, + { + "$ref": "#/definitions/ImageContent" + }, + { + "$ref": "#/definitions/AudioContent" + } + ] + }, + "role": { + "$ref": "#/definitions/Role" + } + }, + "required": [ + "content", + "role" + ], + "type": "object" + }, + "ServerCapabilities": { + "description": "Capabilities that a server may support. Known capabilities are defined here, in this schema, but this is not a closed set: any server can define its own, additional capabilities.", + "properties": { + "completions": { + "additionalProperties": true, + "description": "Present if the server supports argument autocompletion suggestions.", + "properties": {}, + "type": "object" + }, + "experimental": { + "additionalProperties": { + "additionalProperties": true, + "properties": {}, + "type": "object" + }, + "description": "Experimental, non-standard capabilities that the server supports.", + "type": "object" + }, + "logging": { + "additionalProperties": true, + "description": "Present if the server supports sending log messages to the client.", + "properties": {}, + "type": "object" + }, + "prompts": { + "description": "Present if the server offers any prompt templates.", + "properties": { + "listChanged": { + "description": "Whether this server supports notifications for changes to the prompt list.", + "type": "boolean" + } + }, + "type": "object" + }, + "resources": { + "description": "Present if the server offers any resources to read.", + "properties": { + "listChanged": { + "description": "Whether this server supports notifications for changes to the resource list.", + "type": "boolean" + }, + "subscribe": { + "description": "Whether this server supports subscribing to resource updates.", + "type": "boolean" + } + }, + "type": "object" + }, + "tools": { + "description": "Present if the server offers any tools to call.", + "properties": { + "listChanged": { + "description": "Whether this server supports notifications for changes to the tool list.", + "type": "boolean" + } + }, + "type": "object" + } + }, + "type": "object" + }, + "ServerNotification": { + "anyOf": [ + { + "$ref": "#/definitions/CancelledNotification" + }, + { + "$ref": "#/definitions/ProgressNotification" + }, + { + "$ref": "#/definitions/ResourceListChangedNotification" + }, + { + "$ref": "#/definitions/ResourceUpdatedNotification" + }, + { + "$ref": "#/definitions/PromptListChangedNotification" + }, + { + "$ref": "#/definitions/ToolListChangedNotification" + }, + { + "$ref": "#/definitions/LoggingMessageNotification" + } + ] + }, + "ServerRequest": { + "anyOf": [ + { + "$ref": "#/definitions/PingRequest" + }, + { + "$ref": "#/definitions/CreateMessageRequest" + }, + { + "$ref": "#/definitions/ListRootsRequest" + }, + { + "$ref": "#/definitions/ElicitRequest" + } + ] + }, + "ServerResult": { + "anyOf": [ + { + "$ref": "#/definitions/Result" + }, + { + "$ref": "#/definitions/InitializeResult" + }, + { + "$ref": "#/definitions/ListResourcesResult" + }, + { + "$ref": "#/definitions/ListResourceTemplatesResult" + }, + { + "$ref": "#/definitions/ReadResourceResult" + }, + { + "$ref": "#/definitions/ListPromptsResult" + }, + { + "$ref": "#/definitions/GetPromptResult" + }, + { + "$ref": "#/definitions/ListToolsResult" + }, + { + "$ref": "#/definitions/CallToolResult" + }, + { + "$ref": "#/definitions/CompleteResult" + } + ] + }, + "SetLevelRequest": { + "description": "A request from the client to the server, to enable or adjust logging.", + "properties": { + "method": { + "const": "logging/setLevel", + "type": "string" + }, + "params": { + "properties": { + "level": { + "$ref": "#/definitions/LoggingLevel", + "description": "The level of logging that the client wants to receive from the server. The server should send all logs at this level and higher (i.e., more severe) to the client as notifications/message." + } + }, + "required": [ + "level" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "StringSchema": { + "properties": { + "description": { + "type": "string" + }, + "format": { + "enum": [ + "date", + "date-time", + "email", + "uri" + ], + "type": "string" + }, + "maxLength": { + "type": "integer" + }, + "minLength": { + "type": "integer" + }, + "title": { + "type": "string" + }, + "type": { + "const": "string", + "type": "string" + } + }, + "required": [ + "type" + ], + "type": "object" + }, + "SubscribeRequest": { + "description": "Sent from the client to request resources/updated notifications from the server whenever a particular resource changes.", + "properties": { + "method": { + "const": "resources/subscribe", + "type": "string" + }, + "params": { + "properties": { + "uri": { + "description": "The URI of the resource to subscribe to. The URI can use any protocol; it is up to the server how to interpret it.", + "format": "uri", + "type": "string" + } + }, + "required": [ + "uri" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + }, + "TextContent": { + "description": "Text provided to or from an LLM.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "annotations": { + "$ref": "#/definitions/Annotations", + "description": "Optional annotations for the client." + }, + "text": { + "description": "The text content of the message.", + "type": "string" + }, + "type": { + "const": "text", + "type": "string" + } + }, + "required": [ + "text", + "type" + ], + "type": "object" + }, + "TextResourceContents": { + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "mimeType": { + "description": "The MIME type of this resource, if known.", + "type": "string" + }, + "text": { + "description": "The text of the item. This must only be set if the item can actually be represented as text (not binary data).", + "type": "string" + }, + "uri": { + "description": "The URI of this resource.", + "format": "uri", + "type": "string" + } + }, + "required": [ + "text", + "uri" + ], + "type": "object" + }, + "Tool": { + "description": "Definition for a tool the client can call.", + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + }, + "annotations": { + "$ref": "#/definitions/ToolAnnotations", + "description": "Optional additional tool information.\n\nDisplay name precedence order is: title, annotations.title, then name." + }, + "description": { + "description": "A human-readable description of the tool.\n\nThis can be used by clients to improve the LLM's understanding of available tools. It can be thought of like a \"hint\" to the model.", + "type": "string" + }, + "inputSchema": { + "description": "A JSON Schema object defining the expected parameters for the tool.", + "properties": { + "properties": { + "additionalProperties": { + "additionalProperties": true, + "properties": {}, + "type": "object" + }, + "type": "object" + }, + "required": { + "items": { + "type": "string" + }, + "type": "array" + }, + "type": { + "const": "object", + "type": "string" + } + }, + "required": [ + "type" + ], + "type": "object" + }, + "name": { + "description": "Intended for programmatic or logical use, but used as a display name in past specs or fallback (if title isn't present).", + "type": "string" + }, + "outputSchema": { + "description": "An optional JSON Schema object defining the structure of the tool's output returned in\nthe structuredContent field of a CallToolResult.", + "properties": { + "properties": { + "additionalProperties": { + "additionalProperties": true, + "properties": {}, + "type": "object" + }, + "type": "object" + }, + "required": { + "items": { + "type": "string" + }, + "type": "array" + }, + "type": { + "const": "object", + "type": "string" + } + }, + "required": [ + "type" + ], + "type": "object" + }, + "title": { + "description": "Intended for UI and end-user contexts — optimized to be human-readable and easily understood,\neven by those unfamiliar with domain-specific terminology.\n\nIf not provided, the name should be used for display (except for Tool,\nwhere `annotations.title` should be given precedence over using `name`,\nif present).", + "type": "string" + } + }, + "required": [ + "inputSchema", + "name" + ], + "type": "object" + }, + "ToolAnnotations": { + "description": "Additional properties describing a Tool to clients.\n\nNOTE: all properties in ToolAnnotations are **hints**.\nThey are not guaranteed to provide a faithful description of\ntool behavior (including descriptive properties like `title`).\n\nClients should never make tool use decisions based on ToolAnnotations\nreceived from untrusted servers.", + "properties": { + "destructiveHint": { + "description": "If true, the tool may perform destructive updates to its environment.\nIf false, the tool performs only additive updates.\n\n(This property is meaningful only when `readOnlyHint == false`)\n\nDefault: true", + "type": "boolean" + }, + "idempotentHint": { + "description": "If true, calling the tool repeatedly with the same arguments\nwill have no additional effect on the its environment.\n\n(This property is meaningful only when `readOnlyHint == false`)\n\nDefault: false", + "type": "boolean" + }, + "openWorldHint": { + "description": "If true, this tool may interact with an \"open world\" of external\nentities. If false, the tool's domain of interaction is closed.\nFor example, the world of a web search tool is open, whereas that\nof a memory tool is not.\n\nDefault: true", + "type": "boolean" + }, + "readOnlyHint": { + "description": "If true, the tool does not modify its environment.\n\nDefault: false", + "type": "boolean" + }, + "title": { + "description": "A human-readable title for the tool.", + "type": "string" + } + }, + "type": "object" + }, + "ToolListChangedNotification": { + "description": "An optional notification from the server to the client, informing it that the list of tools it offers has changed. This may be issued by servers without any previous subscription from the client.", + "properties": { + "method": { + "const": "notifications/tools/list_changed", + "type": "string" + }, + "params": { + "additionalProperties": {}, + "properties": { + "_meta": { + "additionalProperties": {}, + "description": "See [specification/2025-06-18/basic/index#general-fields] for notes on _meta usage.", + "type": "object" + } + }, + "type": "object" + } + }, + "required": [ + "method" + ], + "type": "object" + }, + "UnsubscribeRequest": { + "description": "Sent from the client to request cancellation of resources/updated notifications from the server. This should follow a previous resources/subscribe request.", + "properties": { + "method": { + "const": "resources/unsubscribe", + "type": "string" + }, + "params": { + "properties": { + "uri": { + "description": "The URI of the resource to unsubscribe from.", + "format": "uri", + "type": "string" + } + }, + "required": [ + "uri" + ], + "type": "object" + } + }, + "required": [ + "method", + "params" + ], + "type": "object" + } + } +} + diff --git a/codex-rs/mcp-types/src/lib.rs b/codex-rs/mcp-types/src/lib.rs index 0ed518535f..cf09d67e35 100644 --- a/codex-rs/mcp-types/src/lib.rs +++ b/codex-rs/mcp-types/src/lib.rs @@ -10,7 +10,7 @@ use serde::Serialize; use serde::de::DeserializeOwned; use std::convert::TryFrom; -pub const MCP_SCHEMA_VERSION: &str = "2025-03-26"; +pub const MCP_SCHEMA_VERSION: &str = "2025-06-18"; pub const JSONRPC_VERSION: &str = "2.0"; /// Paired request/response types for the Model Context Protocol (MCP). @@ -35,6 +35,12 @@ fn default_jsonrpc() -> String { pub struct Annotations { #[serde(default, skip_serializing_if = "Option::is_none")] pub audience: Option>, + #[serde( + rename = "lastModified", + default, + skip_serializing_if = "Option::is_none" + )] + pub last_modified: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub priority: Option, } @@ -50,6 +56,14 @@ pub struct AudioContent { pub r#type: String, // &'static str = "audio" } +/// Base interface for metadata with name (identifier) and title (display name) properties. +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +pub struct BaseMetadata { + pub name: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub title: Option, +} + #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct BlobResourceContents { pub blob: String, @@ -58,6 +72,17 @@ pub struct BlobResourceContents { pub uri: String, } +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +pub struct BooleanSchema { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub default: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub title: Option, + pub r#type: String, // &'static str = "boolean" +} + #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub enum CallToolRequest {} @@ -75,29 +100,17 @@ pub struct CallToolRequestParams { } /// The server's response to a tool call. -/// -/// Any errors that originate from the tool SHOULD be reported inside the result -/// object, with `isError` set to true, _not_ as an MCP protocol-level error -/// response. Otherwise, the LLM would not be able to see that an error occurred -/// and self-correct. -/// -/// However, any errors in _finding_ the tool, an error indicating that the -/// server does not support tool calls, or any other exceptional conditions, -/// should be reported as an MCP error response. #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct CallToolResult { - pub content: Vec, + pub content: Vec, #[serde(rename = "isError", default, skip_serializing_if = "Option::is_none")] pub is_error: Option, -} - -#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] -#[serde(untagged)] -pub enum CallToolResultContent { - TextContent(TextContent), - ImageContent(ImageContent), - AudioContent(AudioContent), - EmbeddedResource(EmbeddedResource), + #[serde( + rename = "structuredContent", + default, + skip_serializing_if = "Option::is_none" + )] + pub structured_content: Option, } impl From for serde_json::Value { @@ -127,6 +140,8 @@ pub struct CancelledNotificationParams { /// Capabilities a client may support. Known capabilities are defined here, in this schema, but this is not a closed set: any client can define its own, additional capabilities. #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct ClientCapabilities { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub elicitation: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub experimental: Option, #[serde(default, skip_serializing_if = "Option::is_none")] @@ -194,6 +209,7 @@ pub enum ClientResult { Result(Result), CreateMessageResult(CreateMessageResult), ListRootsResult(ListRootsResult), + ElicitResult(ElicitResult), } #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] @@ -208,9 +224,18 @@ impl ModelContextProtocolRequest for CompleteRequest { #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct CompleteRequestParams { pub argument: CompleteRequestParamsArgument, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub context: Option, pub r#ref: CompleteRequestParamsRef, } +/// Additional, optional context for completions +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +pub struct CompleteRequestParamsContext { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub arguments: Option, +} + /// The argument's information #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct CompleteRequestParamsArgument { @@ -222,7 +247,7 @@ pub struct CompleteRequestParamsArgument { #[serde(untagged)] pub enum CompleteRequestParamsRef { PromptReference(PromptReference), - ResourceReference(ResourceReference), + ResourceTemplateReference(ResourceTemplateReference), } /// The server's response to a completion/complete request @@ -248,6 +273,16 @@ impl From for serde_json::Value { } } +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[serde(untagged)] +pub enum ContentBlock { + TextContent(TextContent), + ImageContent(ImageContent), + AudioContent(AudioContent), + ResourceLink(ResourceLink), + EmbeddedResource(EmbeddedResource), +} + #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub enum CreateMessageRequest {} @@ -325,6 +360,48 @@ impl From for serde_json::Value { #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct Cursor(String); +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +pub enum ElicitRequest {} + +impl ModelContextProtocolRequest for ElicitRequest { + const METHOD: &'static str = "elicitation/create"; + type Params = ElicitRequestParams; + type Result = ElicitResult; +} + +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +pub struct ElicitRequestParams { + pub message: String, + #[serde(rename = "requestedSchema")] + pub requested_schema: ElicitRequestParamsRequestedSchema, +} + +/// A restricted subset of JSON Schema. +/// Only top-level properties are allowed, without nesting. +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +pub struct ElicitRequestParamsRequestedSchema { + pub properties: serde_json::Value, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub required: Option>, + pub r#type: String, // &'static str = "object" +} + +/// The client's response to an elicitation request. +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +pub struct ElicitResult { + pub action: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub content: Option, +} + +impl From for serde_json::Value { + fn from(value: ElicitResult) -> Self { + // Leave this as it should never fail + #[expect(clippy::unwrap_used)] + serde_json::to_value(value).unwrap() + } +} + /// The contents of a resource, embedded into a prompt or tool call result. /// /// It is up to the client how best to render embedded resources for the benefit @@ -346,6 +423,18 @@ pub enum EmbeddedResourceResource { pub type EmptyResult = Result; +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +pub struct EnumSchema { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub description: Option, + pub r#enum: Vec, + #[serde(rename = "enumNames", default, skip_serializing_if = "Option::is_none")] + pub enum_names: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub title: Option, + pub r#type: String, // &'static str = "string" +} + #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub enum GetPromptRequest {} @@ -389,10 +478,12 @@ pub struct ImageContent { pub r#type: String, // &'static str = "image" } -/// Describes the name and version of an MCP implementation. +/// Describes the name and version of an MCP implementation, with an optional title for UI representation. #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct Implementation { pub name: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub title: Option, pub version: String, } @@ -442,24 +533,6 @@ impl ModelContextProtocolNotification for InitializedNotification { type Params = Option; } -#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] -#[serde(untagged)] -pub enum JSONRPCBatchRequestItem { - JSONRPCRequest(JSONRPCRequest), - JSONRPCNotification(JSONRPCNotification), -} - -pub type JSONRPCBatchRequest = Vec; - -#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] -#[serde(untagged)] -pub enum JSONRPCBatchResponseItem { - JSONRPCResponse(JSONRPCResponse), - JSONRPCError(JSONRPCError), -} - -pub type JSONRPCBatchResponse = Vec; - /// A response to a request that indicates an error occurred. #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct JSONRPCError { @@ -483,10 +556,8 @@ pub struct JSONRPCErrorError { pub enum JSONRPCMessage { Request(JSONRPCRequest), Notification(JSONRPCNotification), - BatchRequest(JSONRPCBatchRequest), Response(JSONRPCResponse), Error(JSONRPCError), - BatchResponse(JSONRPCBatchResponse), } /// A notification which does not expect a response. @@ -777,6 +848,19 @@ pub struct Notification { pub params: Option, } +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +pub struct NumberSchema { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub maximum: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub minimum: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub title: Option, + pub r#type: String, +} + #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct PaginatedRequest { pub method: String, @@ -817,6 +901,17 @@ impl ModelContextProtocolRequest for PingRequest { type Result = Result; } +/// Restricted schema definitions that only allow primitive types +/// without nested objects or arrays. +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[serde(untagged)] +pub enum PrimitiveSchemaDefinition { + StringSchema(StringSchema), + NumberSchema(NumberSchema), + BooleanSchema(BooleanSchema), + EnumSchema(EnumSchema), +} + #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub enum ProgressNotification {} @@ -836,7 +931,7 @@ pub struct ProgressNotificationParams { pub total: Option, } -#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Hash, Eq)] #[serde(untagged)] pub enum ProgressToken { String(String), @@ -851,6 +946,8 @@ pub struct Prompt { #[serde(default, skip_serializing_if = "Option::is_none")] pub description: Option, pub name: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub title: Option, } /// Describes an argument that a prompt can accept. @@ -861,6 +958,8 @@ pub struct PromptArgument { pub name: String, #[serde(default, skip_serializing_if = "Option::is_none")] pub required: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub title: Option, } #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] @@ -877,23 +976,16 @@ impl ModelContextProtocolNotification for PromptListChangedNotification { /// resources from the MCP server. #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct PromptMessage { - pub content: PromptMessageContent, + pub content: ContentBlock, pub role: Role, } -#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] -#[serde(untagged)] -pub enum PromptMessageContent { - TextContent(TextContent), - ImageContent(ImageContent), - AudioContent(AudioContent), - EmbeddedResource(EmbeddedResource), -} - /// Identifies a prompt. #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct PromptReference { pub name: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub title: Option, pub r#type: String, // &'static str = "ref/prompt" } @@ -939,7 +1031,7 @@ pub struct Request { pub params: Option, } -#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Hash, Eq)] #[serde(untagged)] pub enum RequestId { String(String), @@ -958,6 +1050,8 @@ pub struct Resource { pub name: String, #[serde(default, skip_serializing_if = "Option::is_none")] pub size: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub title: Option, pub uri: String, } @@ -969,6 +1063,26 @@ pub struct ResourceContents { pub uri: String, } +/// A resource that the server is capable of reading, included in a prompt or tool call result. +/// +/// Note: resource links returned by tools are not guaranteed to appear in the results of `resources/list` requests. +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +pub struct ResourceLink { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub annotations: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(rename = "mimeType", default, skip_serializing_if = "Option::is_none")] + pub mime_type: Option, + pub name: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub size: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub title: Option, + pub r#type: String, // &'static str = "resource_link" + pub uri: String, +} + #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub enum ResourceListChangedNotification {} @@ -977,13 +1091,6 @@ impl ModelContextProtocolNotification for ResourceListChangedNotification { type Params = Option; } -/// A reference to a resource or resource template definition. -#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] -pub struct ResourceReference { - pub r#type: String, // &'static str = "ref/resource" - pub uri: String, -} - /// A template description for resources available on the server. #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct ResourceTemplate { @@ -994,10 +1101,19 @@ pub struct ResourceTemplate { #[serde(rename = "mimeType", default, skip_serializing_if = "Option::is_none")] pub mime_type: Option, pub name: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub title: Option, #[serde(rename = "uriTemplate")] pub uri_template: String, } +/// A reference to a resource or resource template definition. +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +pub struct ResourceTemplateReference { + pub r#type: String, // &'static str = "ref/resource" + pub uri: String, +} + #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub enum ResourceUpdatedNotification {} @@ -1140,6 +1256,7 @@ pub enum ServerRequest { PingRequest(PingRequest), CreateMessageRequest(CreateMessageRequest), ListRootsRequest(ListRootsRequest), + ElicitRequest(ElicitRequest), } #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] @@ -1172,6 +1289,21 @@ pub struct SetLevelRequestParams { pub level: LoggingLevel, } +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +pub struct StringSchema { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub format: Option, + #[serde(rename = "maxLength", default, skip_serializing_if = "Option::is_none")] + pub max_length: Option, + #[serde(rename = "minLength", default, skip_serializing_if = "Option::is_none")] + pub min_length: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub title: Option, + pub r#type: String, // &'static str = "string" +} + #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub enum SubscribeRequest {} @@ -1213,6 +1345,25 @@ pub struct Tool { #[serde(rename = "inputSchema")] pub input_schema: ToolInputSchema, pub name: String, + #[serde( + rename = "outputSchema", + default, + skip_serializing_if = "Option::is_none" + )] + pub output_schema: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub title: Option, +} + +/// An optional JSON Schema object defining the structure of the tool's output returned in +/// the structuredContent field of a CallToolResult. +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +pub struct ToolOutputSchema { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub properties: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub required: Option>, + pub r#type: String, // &'static str = "object" } /// A JSON Schema object defining the expected parameters for the tool. diff --git a/codex-rs/mcp-types/tests/initialize.rs b/codex-rs/mcp-types/tests/initialize.rs index 27902dce50..c69586f030 100644 --- a/codex-rs/mcp-types/tests/initialize.rs +++ b/codex-rs/mcp-types/tests/initialize.rs @@ -17,8 +17,8 @@ fn deserialize_initialize_request() { "method": "initialize", "params": { "capabilities": {}, - "clientInfo": { "name": "acme-client", "version": "1.2.3" }, - "protocolVersion": "2025-03-26" + "clientInfo": { "name": "acme-client", "title": "Acme", "version": "1.2.3" }, + "protocolVersion": "2025-06-18" } }"#; @@ -37,8 +37,8 @@ fn deserialize_initialize_request() { method: "initialize".into(), params: Some(json!({ "capabilities": {}, - "clientInfo": { "name": "acme-client", "version": "1.2.3" }, - "protocolVersion": "2025-03-26" + "clientInfo": { "name": "acme-client", "title": "Acme", "version": "1.2.3" }, + "protocolVersion": "2025-06-18" })), }; @@ -57,12 +57,14 @@ fn deserialize_initialize_request() { experimental: None, roots: None, sampling: None, + elicitation: None, }, client_info: Implementation { name: "acme-client".into(), + title: Some("Acme".to_string()), version: "1.2.3".into(), }, - protocol_version: "2025-03-26".into(), + protocol_version: "2025-06-18".into(), } ); } diff --git a/codex-rs/rust-toolchain.toml b/codex-rs/rust-toolchain.toml new file mode 100644 index 0000000000..72bafdf4b6 --- /dev/null +++ b/codex-rs/rust-toolchain.toml @@ -0,0 +1,3 @@ +[toolchain] +channel = "1.88.0" +components = [ "clippy", "rustfmt", "rust-src"] diff --git a/codex-rs/tui/Cargo.toml b/codex-rs/tui/Cargo.toml index 74aedfa353..63d287ca11 100644 --- a/codex-rs/tui/Cargo.toml +++ b/codex-rs/tui/Cargo.toml @@ -1,7 +1,7 @@ [package] +edition = "2024" name = "codex-tui" version = { workspace = true } -edition = "2024" [[bin]] name = "codex-tui" @@ -19,14 +19,14 @@ anyhow = "1" base64 = "0.22.1" clap = { version = "4", features = ["derive"] } codex-ansi-escape = { path = "../ansi-escape" } -codex-core = { path = "../core" } +codex-arg0 = { path = "../arg0" } codex-common = { path = "../common", features = [ "cli", "elapsed", "sandbox_summary", ] } +codex-core = { path = "../core" } codex-file-search = { path = "../file-search" } -codex-linux-sandbox = { path = "../linux-sandbox" } codex-login = { path = "../login" } color-eyre = "0.6.3" crossterm = { version = "0.28.1", features = ["bracketed-paste"] } @@ -35,15 +35,16 @@ lazy_static = "1" mcp-types = { path = "../mcp-types" } path-clean = "1.0.1" ratatui = { version = "0.29.0", features = [ - "unstable-widget-ref", + "scrolling-regions", "unstable-rendered-line-info", + "unstable-widget-ref", ] } ratatui-image = "8.0.0" regex-lite = "0.1" serde_json = { version = "1", features = ["preserve_order"] } shlex = "1.3.0" -strum = "0.27.1" -strum_macros = "0.27.1" +strum = "0.27.2" +strum_macros = "0.27.2" tokio = { version = "1", features = [ "io-std", "macros", @@ -58,6 +59,7 @@ tui-input = "0.14.0" tui-markdown = "0.3.3" tui-textarea = "0.7.0" unicode-segmentation = "1.12.0" +unicode-width = "0.1" uuid = "1" [dev-dependencies] diff --git a/codex-rs/tui/src/app.rs b/codex-rs/tui/src/app.rs index e1dde8332d..ad4bc24fd4 100644 --- a/codex-rs/tui/src/app.rs +++ b/codex-rs/tui/src/app.rs @@ -5,21 +5,29 @@ use crate::file_search::FileSearchManager; use crate::get_git_diff::get_git_diff; use crate::git_warning_screen::GitWarningOutcome; use crate::git_warning_screen::GitWarningScreen; -use crate::login_screen::LoginScreen; -use crate::mouse_capture::MouseCapture; -use crate::scroll_event_helper::ScrollEventHelper; use crate::slash_command::SlashCommand; use crate::tui; use codex_core::config::Config; use codex_core::protocol::Event; +use codex_core::protocol::EventMsg; +use codex_core::protocol::ExecApprovalRequestEvent; use color_eyre::eyre::Result; +use crossterm::SynchronizedUpdate; use crossterm::event::KeyCode; use crossterm::event::KeyEvent; -use crossterm::event::MouseEvent; -use crossterm::event::MouseEventKind; +use ratatui::layout::Offset; +use ratatui::prelude::Backend; use std::path::PathBuf; +use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; use std::sync::mpsc::Receiver; use std::sync::mpsc::channel; +use std::thread; +use std::time::Duration; + +/// Time window for debouncing redraw requests. +const REDRAW_DEBOUNCE: Duration = Duration::from_millis(10); /// Top-level application state: which full-screen view is currently active. #[allow(clippy::large_enum_variant)] @@ -30,8 +38,6 @@ enum AppState<'a> { /// `AppState`. widget: Box>, }, - /// The login screen for the OpenAI provider. - Login { screen: LoginScreen }, /// The start-up warning that recommends running codex inside a Git repo. GitWarning { screen: GitWarningScreen }, } @@ -46,6 +52,9 @@ pub(crate) struct App<'a> { file_search: FileSearchManager, + /// True when a redraw has been scheduled but not yet executed. + pending_redraw: Arc, + /// Stored parameters needed to instantiate the ChatWidget later, e.g., /// after dismissing the Git-repo warning. chat_args: Option, @@ -60,66 +69,61 @@ struct ChatWidgetArgs { initial_images: Vec, } -impl<'a> App<'a> { +impl App<'_> { pub(crate) fn new( config: Config, initial_prompt: Option, - show_login_screen: bool, show_git_warning: bool, initial_images: Vec, ) -> Self { let (app_event_tx, app_event_rx) = channel(); let app_event_tx = AppEventSender::new(app_event_tx); - let scroll_event_helper = ScrollEventHelper::new(app_event_tx.clone()); + let pending_redraw = Arc::new(AtomicBool::new(false)); // Spawn a dedicated thread for reading the crossterm event loop and // re-publishing the events as AppEvents, as appropriate. { let app_event_tx = app_event_tx.clone(); std::thread::spawn(move || { - while let Ok(event) = crossterm::event::read() { - match event { - crossterm::event::Event::Key(key_event) => { - app_event_tx.send(AppEvent::KeyEvent(key_event)); - } - crossterm::event::Event::Resize(_, _) => { - app_event_tx.send(AppEvent::Redraw); - } - crossterm::event::Event::Mouse(MouseEvent { - kind: MouseEventKind::ScrollUp, - .. - }) => { - scroll_event_helper.scroll_up(); - } - crossterm::event::Event::Mouse(MouseEvent { - kind: MouseEventKind::ScrollDown, - .. - }) => { - scroll_event_helper.scroll_down(); - } - crossterm::event::Event::Paste(pasted) => { - app_event_tx.send(AppEvent::Paste(pasted)); - } - _ => { - // Ignore any other events. + loop { + // This timeout is necessary to avoid holding the event lock + // that crossterm::event::read() acquires. In particular, + // reading the cursor position (crossterm::cursor::position()) + // needs to acquire the event lock, and so will fail if it + // can't acquire it within 2 sec. Resizing the terminal + // crashes the app if the cursor position can't be read. + if let Ok(true) = crossterm::event::poll(Duration::from_millis(100)) { + if let Ok(event) = crossterm::event::read() { + match event { + crossterm::event::Event::Key(key_event) => { + app_event_tx.send(AppEvent::KeyEvent(key_event)); + } + crossterm::event::Event::Resize(_, _) => { + app_event_tx.send(AppEvent::RequestRedraw); + } + crossterm::event::Event::Paste(pasted) => { + // Many terminals convert newlines to \r when + // pasting, e.g. [iTerm2][]. But [tui-textarea + // expects \n][tui-textarea]. This seems like a bug + // in tui-textarea IMO, but work around it for now. + // [tui-textarea]: https://github.com/rhysd/tui-textarea/blob/4d18622eeac13b309e0ff6a55a46ac6706da68cf/src/textarea.rs#L782-L783 + // [iTerm2]: https://github.com/gnachman/iTerm2/blob/5d0c0d9f68523cbd0494dad5422998964a2ecd8d/sources/iTermPasteHelper.m#L206-L216 + let pasted = pasted.replace("\r", "\n"); + app_event_tx.send(AppEvent::Paste(pasted)); + } + _ => { + // Ignore any other events. + } + } } + } else { + // Timeout expired, no `Event` is available } } }); } - let (app_state, chat_args) = if show_login_screen { - ( - AppState::Login { - screen: LoginScreen::new(app_event_tx.clone(), config.codex_home.clone()), - }, - Some(ChatWidgetArgs { - config: config.clone(), - initial_prompt, - initial_images, - }), - ) - } else if show_git_warning { + let (app_state, chat_args) = if show_git_warning { ( AppState::GitWarning { screen: GitWarningScreen::new(), @@ -152,6 +156,7 @@ impl<'a> App<'a> { app_state, config, file_search, + pending_redraw, chat_args, } } @@ -162,19 +167,44 @@ impl<'a> App<'a> { self.app_event_tx.clone() } - pub(crate) fn run( - &mut self, - terminal: &mut tui::Tui, - mouse_capture: &mut MouseCapture, - ) -> Result<()> { + /// Schedule a redraw if one is not already pending. + #[allow(clippy::unwrap_used)] + fn schedule_redraw(&self) { + // Attempt to set the flag to `true`. If it was already `true`, another + // redraw is already pending so we can return early. + if self + .pending_redraw + .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) + .is_err() + { + return; + } + + let tx = self.app_event_tx.clone(); + let pending_redraw = self.pending_redraw.clone(); + thread::spawn(move || { + thread::sleep(REDRAW_DEBOUNCE); + tx.send(AppEvent::Redraw); + pending_redraw.store(false, Ordering::SeqCst); + }); + } + + pub(crate) fn run(&mut self, terminal: &mut tui::Tui) -> Result<()> { // Insert an event to trigger the first render. let app_event_tx = self.app_event_tx.clone(); - app_event_tx.send(AppEvent::Redraw); + app_event_tx.send(AppEvent::RequestRedraw); while let Ok(event) = self.app_event_rx.recv() { match event { + AppEvent::InsertHistory(lines) => { + crate::insert_history::insert_history_lines(terminal, lines); + self.app_event_tx.send(AppEvent::RequestRedraw); + } + AppEvent::RequestRedraw => { + self.schedule_redraw(); + } AppEvent::Redraw => { - self.draw_next_frame(terminal)?; + std::io::stdout().sync_update(|_| self.draw_next_frame(terminal))??; } AppEvent::KeyEvent(key_event) => { match key_event { @@ -185,11 +215,9 @@ impl<'a> App<'a> { } => { match &mut self.app_state { AppState::Chat { widget } => { - if widget.on_ctrl_c() { - self.app_event_tx.send(AppEvent::ExitRequest); - } + widget.on_ctrl_c(); } - AppState::Login { .. } | AppState::GitWarning { .. } => { + AppState::GitWarning { .. } => { // No-op. } } @@ -199,16 +227,27 @@ impl<'a> App<'a> { modifiers: crossterm::event::KeyModifiers::CONTROL, .. } => { - self.app_event_tx.send(AppEvent::ExitRequest); + match &mut self.app_state { + AppState::Chat { widget } => { + if widget.composer_is_empty() { + self.app_event_tx.send(AppEvent::ExitRequest); + } else { + // Treat Ctrl+D as a normal key event when the composer + // is not empty so that it doesn't quit the application + // prematurely. + self.dispatch_key_event(key_event); + } + } + AppState::GitWarning { .. } => { + self.app_event_tx.send(AppEvent::ExitRequest); + } + } } _ => { self.dispatch_key_event(key_event); } }; } - AppEvent::Scroll(scroll_delta) => { - self.dispatch_scroll_event(scroll_delta); - } AppEvent::Paste(text) => { self.dispatch_paste_event(text); } @@ -220,11 +259,11 @@ impl<'a> App<'a> { } AppEvent::CodexOp(op) => match &mut self.app_state { AppState::Chat { widget } => widget.submit_op(op), - AppState::Login { .. } | AppState::GitWarning { .. } => {} + AppState::GitWarning { .. } => {} }, AppEvent::LatestLog(line) => match &mut self.app_state { AppState::Chat { widget } => widget.update_latest_log(line), - AppState::Login { .. } | AppState::GitWarning { .. } => {} + AppState::GitWarning { .. } => {} }, AppEvent::DispatchCommand(command) => match command { SlashCommand::New => { @@ -235,12 +274,7 @@ impl<'a> App<'a> { Vec::new(), )); self.app_state = AppState::Chat { widget: new_widget }; - self.app_event_tx.send(AppEvent::Redraw); - } - SlashCommand::ToggleMouseMode => { - if let Err(e) = mouse_capture.toggle() { - tracing::error!("Failed to toggle mouse mode: {e}"); - } + self.app_event_tx.send(AppEvent::RequestRedraw); } SlashCommand::Quit => { break; @@ -266,6 +300,18 @@ impl<'a> App<'a> { widget.add_diff_output(text); } } + #[cfg(debug_assertions)] + SlashCommand::TestApproval => { + self.app_event_tx.send(AppEvent::CodexEvent(Event { + id: "1".to_string(), + msg: EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent { + call_id: "1".to_string(), + command: vec!["git".into(), "apply".into()], + cwd: self.config.cwd.clone(), + reason: Some("test".to_string()), + }), + })); + } }, AppEvent::StartFileSearch(query) => { self.file_search.on_user_query(query); @@ -282,14 +328,56 @@ impl<'a> App<'a> { Ok(()) } + pub(crate) fn token_usage(&self) -> codex_core::protocol::TokenUsage { + match &self.app_state { + AppState::Chat { widget } => widget.token_usage().clone(), + AppState::GitWarning { .. } => codex_core::protocol::TokenUsage::default(), + } + } + fn draw_next_frame(&mut self, terminal: &mut tui::Tui) -> Result<()> { + let screen_size = terminal.size()?; + let last_known_screen_size = terminal.last_known_screen_size; + if screen_size != last_known_screen_size { + let cursor_pos = terminal.get_cursor_position()?; + let last_known_cursor_pos = terminal.last_known_cursor_pos; + if cursor_pos.y != last_known_cursor_pos.y { + // The terminal was resized. The only point of reference we have for where our viewport + // was moved is the cursor position. + // NB this assumes that the cursor was not wrapped as part of the resize. + let cursor_delta = cursor_pos.y as i32 - last_known_cursor_pos.y as i32; + + let new_viewport_area = terminal.viewport_area.offset(Offset { + x: 0, + y: cursor_delta, + }); + terminal.set_viewport_area(new_viewport_area); + terminal.clear()?; + } + } + + let size = terminal.size()?; + let desired_height = match &self.app_state { + AppState::Chat { widget } => widget.desired_height(size.width), + AppState::GitWarning { .. } => 10, + }; + let mut area = terminal.viewport_area; + area.height = desired_height; + area.width = size.width; + if area.bottom() > size.height { + terminal + .backend_mut() + .scroll_region_up(0..area.top(), area.bottom() - size.height)?; + area.y = size.height - area.height; + } + if area != terminal.viewport_area { + terminal.clear()?; + terminal.set_viewport_area(area); + } match &mut self.app_state { AppState::Chat { widget } => { terminal.draw(|frame| frame.render_widget_ref(&**widget, frame.area()))?; } - AppState::Login { screen } => { - terminal.draw(|frame| frame.render_widget_ref(&*screen, frame.area()))?; - } AppState::GitWarning { screen } => { terminal.draw(|frame| frame.render_widget_ref(&*screen, frame.area()))?; } @@ -304,7 +392,6 @@ impl<'a> App<'a> { AppState::Chat { widget } => { widget.handle_key_event(key_event); } - AppState::Login { screen } => screen.handle_key_event(key_event), AppState::GitWarning { screen } => match screen.handle_key_event(key_event) { GitWarningOutcome::Continue => { // User accepted – switch to chat view. @@ -320,7 +407,7 @@ impl<'a> App<'a> { args.initial_images, )); self.app_state = AppState::Chat { widget }; - self.app_event_tx.send(AppEvent::Redraw); + self.app_event_tx.send(AppEvent::RequestRedraw); } GitWarningOutcome::Quit => { self.app_event_tx.send(AppEvent::ExitRequest); @@ -335,21 +422,14 @@ impl<'a> App<'a> { fn dispatch_paste_event(&mut self, pasted: String) { match &mut self.app_state { AppState::Chat { widget } => widget.handle_paste(pasted), - AppState::Login { .. } | AppState::GitWarning { .. } => {} - } - } - - fn dispatch_scroll_event(&mut self, scroll_delta: i32) { - match &mut self.app_state { - AppState::Chat { widget } => widget.handle_scroll_delta(scroll_delta), - AppState::Login { .. } | AppState::GitWarning { .. } => {} + AppState::GitWarning { .. } => {} } } fn dispatch_codex_event(&mut self, event: Event) { match &mut self.app_state { AppState::Chat { widget } => widget.handle_codex_event(event), - AppState::Login { .. } | AppState::GitWarning { .. } => {} + AppState::GitWarning { .. } => {} } } } diff --git a/codex-rs/tui/src/app_event.rs b/codex-rs/tui/src/app_event.rs index fd6b2479ee..77a600d304 100644 --- a/codex-rs/tui/src/app_event.rs +++ b/codex-rs/tui/src/app_event.rs @@ -1,6 +1,7 @@ use codex_core::protocol::Event; use codex_file_search::FileMatch; use crossterm::event::KeyEvent; +use ratatui::text::Line; use crate::slash_command::SlashCommand; @@ -8,6 +9,10 @@ use crate::slash_command::SlashCommand; pub(crate) enum AppEvent { CodexEvent(Event), + /// Request a redraw which will be debounced by the [`App`]. + RequestRedraw, + + /// Actually draw the next frame. Redraw, KeyEvent(KeyEvent), @@ -15,10 +20,6 @@ pub(crate) enum AppEvent { /// Text pasted from the terminal clipboard. Paste(String), - /// Scroll event with a value representing the "scroll delta" as the net - /// scroll up/down events within a short time window. - Scroll(i32), - /// Request to exit the application gracefully. ExitRequest, @@ -45,4 +46,6 @@ pub(crate) enum AppEvent { query: String, matches: Vec, }, + + InsertHistory(Vec>), } diff --git a/codex-rs/tui/src/bottom_pane/approval_modal_view.rs b/codex-rs/tui/src/bottom_pane/approval_modal_view.rs index ca33047b1f..4cd952f9eb 100644 --- a/codex-rs/tui/src/bottom_pane/approval_modal_view.rs +++ b/codex-rs/tui/src/bottom_pane/approval_modal_view.rs @@ -9,6 +9,7 @@ use crate::user_approval_widget::UserApprovalWidget; use super::BottomPane; use super::BottomPaneView; +use super::CancellationEvent; /// Modal overlay asking the user to approve/deny a sequence of requests. pub(crate) struct ApprovalModalView<'a> { @@ -46,12 +47,18 @@ impl<'a> BottomPaneView<'a> for ApprovalModalView<'a> { self.maybe_advance(); } + fn on_ctrl_c(&mut self, _pane: &mut BottomPane<'a>) -> CancellationEvent { + self.current.on_ctrl_c(); + self.queue.clear(); + CancellationEvent::Handled + } + fn is_complete(&self) -> bool { self.current.is_complete() && self.queue.is_empty() } - fn calculate_required_height(&self, area: &Rect) -> u16 { - self.current.get_height(area) + fn desired_height(&self, width: u16) -> u16 { + self.current.desired_height(width) } fn render(&self, area: Rect, buf: &mut Buffer) { @@ -63,3 +70,39 @@ impl<'a> BottomPaneView<'a> for ApprovalModalView<'a> { None } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::app_event::AppEvent; + use std::path::PathBuf; + use std::sync::mpsc::channel; + + fn make_exec_request() -> ApprovalRequest { + ApprovalRequest::Exec { + id: "test".to_string(), + command: vec!["echo".to_string(), "hi".to_string()], + cwd: PathBuf::from("/tmp"), + reason: None, + } + } + + #[test] + fn ctrl_c_aborts_and_clears_queue() { + let (tx_raw, _rx) = channel::(); + let tx = AppEventSender::new(tx_raw); + let first = make_exec_request(); + let mut view = ApprovalModalView::new(first, tx); + view.enqueue_request(make_exec_request()); + + let (tx_raw2, _rx2) = channel::(); + let mut pane = BottomPane::new(super::super::BottomPaneParams { + app_event_tx: AppEventSender::new(tx_raw2), + has_input_focus: true, + }); + assert_eq!(CancellationEvent::Handled, view.on_ctrl_c(&mut pane)); + assert!(view.queue.is_empty()); + assert!(view.current.is_complete()); + assert!(view.is_complete()); + } +} diff --git a/codex-rs/tui/src/bottom_pane/bottom_pane_view.rs b/codex-rs/tui/src/bottom_pane/bottom_pane_view.rs index 6abf5399f5..a5616371d2 100644 --- a/codex-rs/tui/src/bottom_pane/bottom_pane_view.rs +++ b/codex-rs/tui/src/bottom_pane/bottom_pane_view.rs @@ -4,6 +4,7 @@ use ratatui::buffer::Buffer; use ratatui::layout::Rect; use super::BottomPane; +use super::CancellationEvent; /// Type to use for a method that may require a redraw of the UI. pub(crate) enum ConditionalUpdate { @@ -22,8 +23,13 @@ pub(crate) trait BottomPaneView<'a> { false } - /// Height required to render the view. - fn calculate_required_height(&self, area: &Rect) -> u16; + /// Handle Ctrl-C while this view is active. + fn on_ctrl_c(&mut self, _pane: &mut BottomPane<'a>) -> CancellationEvent { + CancellationEvent::Ignored + } + + /// Return the desired height of the view. + fn desired_height(&self, width: u16) -> u16; /// Render the view: this will be displayed in place of the composer. fn render(&self, area: Rect, buf: &mut Buffer); diff --git a/codex-rs/tui/src/bottom_pane/chat_composer.rs b/codex-rs/tui/src/bottom_pane/chat_composer.rs index e89187d165..3bc573a003 100644 --- a/codex-rs/tui/src/bottom_pane/chat_composer.rs +++ b/codex-rs/tui/src/bottom_pane/chat_composer.rs @@ -1,11 +1,13 @@ use codex_core::protocol::TokenUsage; use crossterm::event::KeyEvent; use ratatui::buffer::Buffer; -use ratatui::layout::Alignment; use ratatui::layout::Rect; +use ratatui::style::Color; use ratatui::style::Style; +use ratatui::style::Styled; use ratatui::style::Stylize; use ratatui::text::Line; +use ratatui::text::Span; use ratatui::widgets::BorderType; use ratatui::widgets::Borders; use ratatui::widgets::Widget; @@ -22,12 +24,7 @@ use crate::app_event::AppEvent; use crate::app_event_sender::AppEventSender; use codex_file_search::FileMatch; -/// Minimum number of visible text rows inside the textarea. -const MIN_TEXTAREA_ROWS: usize = 1; -/// Rows consumed by the border. -const BORDER_LINES: u16 = 2; - -const BASE_PLACEHOLDER_TEXT: &str = "send a message"; +const BASE_PLACEHOLDER_TEXT: &str = "..."; /// If the pasted content exceeds this number of characters, replace it with a /// placeholder in the UI. const LARGE_PASTE_CHAR_THRESHOLD: usize = 1000; @@ -76,6 +73,20 @@ impl ChatComposer<'_> { this } + pub fn desired_height(&self) -> u16 { + self.textarea.lines().len().max(1) as u16 + + match &self.active_popup { + ActivePopup::None => 1u16, + ActivePopup::Command(c) => c.calculate_required_height(), + ActivePopup::File(c) => c.calculate_required_height(), + } + } + + /// Returns true if the composer currently contains no user input. + pub(crate) fn is_empty(&self) -> bool { + self.textarea.is_empty() + } + /// Update the cached *context-left* percentage and refresh the placeholder /// text. The UI relies on the placeholder to convey the remaining /// context when the composer is empty. @@ -127,10 +138,6 @@ impl ChatComposer<'_> { .on_entry_response(log_id, offset, entry, &mut self.textarea) } - pub fn set_input_focus(&mut self, has_focus: bool) { - self.update_border(has_focus); - } - pub fn handle_paste(&mut self, pasted: String) -> bool { let char_count = pasted.chars().count(); if char_count > LARGE_PASTE_CHAR_THRESHOLD { @@ -464,6 +471,20 @@ impl ChatComposer<'_> { self.textarea.insert_newline(); (InputResult::None, true) } + Input { + key: Key::Char('d'), + ctrl: true, + alt: false, + shift: false, + } => { + self.textarea.input(Input { + key: Key::Delete, + ctrl: false, + alt: false, + shift: false, + }); + (InputResult::None, true) + } input => self.handle_input_basic(input), } } @@ -481,6 +502,17 @@ impl ChatComposer<'_> { } } + if let Input { + key: Key::Char('u'), + ctrl: true, + alt: false, + .. + } = input + { + self.textarea.delete_line_by_head(); + return (InputResult::None, true); + } + // Normal input handling self.textarea.input(input); let text_after = self.textarea.lines().join("\n"); @@ -604,107 +636,95 @@ impl ChatComposer<'_> { self.dismissed_file_popup_token = None; } - pub fn calculate_required_height(&self, area: &Rect) -> u16 { - let rows = self.textarea.lines().len().max(MIN_TEXTAREA_ROWS); - let num_popup_rows = match &self.active_popup { - ActivePopup::Command(popup) => popup.calculate_required_height(area), - ActivePopup::File(popup) => popup.calculate_required_height(area), - ActivePopup::None => 0, - }; - - rows as u16 + BORDER_LINES + num_popup_rows - } - fn update_border(&mut self, has_focus: bool) { - struct BlockState { - right_title: Line<'static>, - border_style: Style, - } - - let bs = if has_focus { - if self.ctrl_c_quit_hint { - BlockState { - right_title: Line::from("Ctrl+C to quit").alignment(Alignment::Right), - border_style: Style::default(), - } - } else { - BlockState { - right_title: Line::from("Enter to send | Ctrl+D to quit | Ctrl+J for newline") - .alignment(Alignment::Right), - border_style: Style::default(), - } - } + let border_style = if has_focus { + Style::default().fg(Color::Cyan) } else { - BlockState { - right_title: Line::from(""), - border_style: Style::default().dim(), - } + Style::default().dim() }; self.textarea.set_block( ratatui::widgets::Block::default() - .title_bottom(bs.right_title) - .borders(Borders::ALL) - .border_type(BorderType::Rounded) - .border_style(bs.border_style), + .borders(Borders::LEFT) + .border_type(BorderType::QuadrantOutside) + .border_style(border_style), ); } - - pub(crate) fn is_popup_visible(&self) -> bool { - match self.active_popup { - ActivePopup::Command(_) | ActivePopup::File(_) => true, - ActivePopup::None => false, - } - } } impl WidgetRef for &ChatComposer<'_> { fn render_ref(&self, area: Rect, buf: &mut Buffer) { match &self.active_popup { ActivePopup::Command(popup) => { - let popup_height = popup.calculate_required_height(&area); + let popup_height = popup.calculate_required_height(); // Split the provided rect so that the popup is rendered at the - // *top* and the textarea occupies the remaining space below. - let popup_rect = Rect { + // **bottom** and the textarea occupies the remaining space above. + let popup_height = popup_height.min(area.height); + let textarea_rect = Rect { x: area.x, y: area.y, width: area.width, - height: popup_height.min(area.height), + height: area.height.saturating_sub(popup_height), }; - - let textarea_rect = Rect { + let popup_rect = Rect { x: area.x, - y: area.y + popup_rect.height, + y: area.y + textarea_rect.height, width: area.width, - height: area.height.saturating_sub(popup_rect.height), + height: popup_height, }; popup.render(popup_rect, buf); self.textarea.render(textarea_rect, buf); } ActivePopup::File(popup) => { - let popup_height = popup.calculate_required_height(&area); + let popup_height = popup.calculate_required_height(); - let popup_rect = Rect { + let popup_height = popup_height.min(area.height); + let textarea_rect = Rect { x: area.x, y: area.y, width: area.width, - height: popup_height.min(area.height), - }; - - let textarea_rect = Rect { - x: area.x, - y: area.y + popup_rect.height, - width: area.width, height: area.height.saturating_sub(popup_height), }; + let popup_rect = Rect { + x: area.x, + y: area.y + textarea_rect.height, + width: area.width, + height: popup_height, + }; popup.render(popup_rect, buf); self.textarea.render(textarea_rect, buf); } ActivePopup::None => { - self.textarea.render(area, buf); + let mut textarea_rect = area; + textarea_rect.height = textarea_rect.height.saturating_sub(1); + self.textarea.render(textarea_rect, buf); + let mut bottom_line_rect = area; + bottom_line_rect.y += textarea_rect.height; + bottom_line_rect.height = 1; + let key_hint_style = Style::default().fg(Color::Cyan); + let hint = if self.ctrl_c_quit_hint { + vec![ + Span::from(" "), + "Ctrl+C again".set_style(key_hint_style), + Span::from(" to quit"), + ] + } else { + vec![ + Span::from(" "), + "⏎".set_style(key_hint_style), + Span::from(" send "), + "Shift+⏎".set_style(key_hint_style), + Span::from(" newline "), + "Ctrl+C".set_style(key_hint_style), + Span::from(" quit"), + ] + }; + Line::from(hint) + .style(Style::default().dim()) + .render_ref(bottom_line_rect, buf); } } } diff --git a/codex-rs/tui/src/bottom_pane/chat_composer_history.rs b/codex-rs/tui/src/bottom_pane/chat_composer_history.rs index fc85c28262..5715c99492 100644 --- a/codex-rs/tui/src/bottom_pane/chat_composer_history.rs +++ b/codex-rs/tui/src/bottom_pane/chat_composer_history.rs @@ -72,8 +72,7 @@ impl ChatComposerHistory { return false; } - let lines = textarea.lines(); - if lines.len() == 1 && lines[0].is_empty() { + if textarea.is_empty() { return true; } @@ -85,6 +84,7 @@ impl ChatComposerHistory { return false; } + let lines = textarea.lines(); matches!(&self.last_history_text, Some(prev) if prev == &lines.join("\n")) } diff --git a/codex-rs/tui/src/bottom_pane/command_popup.rs b/codex-rs/tui/src/bottom_pane/command_popup.rs index fd865047ef..364a8472dc 100644 --- a/codex-rs/tui/src/bottom_pane/command_popup.rs +++ b/codex-rs/tui/src/bottom_pane/command_popup.rs @@ -3,9 +3,9 @@ use ratatui::layout::Rect; use ratatui::style::Color; use ratatui::style::Style; use ratatui::style::Stylize; -use ratatui::widgets::Block; -use ratatui::widgets::BorderType; -use ratatui::widgets::Borders; +use ratatui::symbols::border::QUADRANT_LEFT_HALF; +use ratatui::text::Line; +use ratatui::text::Span; use ratatui::widgets::Cell; use ratatui::widgets::Row; use ratatui::widgets::Table; @@ -71,12 +71,8 @@ impl CommandPopup { /// Determine the preferred height of the popup. This is the number of /// rows required to show **at most** `MAX_POPUP_ROWS` commands plus the /// table/border overhead (one line at the top and one at the bottom). - pub(crate) fn calculate_required_height(&self, _area: &Rect) -> u16 { - let matches = self.filtered_commands(); - let row_count = matches.len().clamp(1, MAX_POPUP_ROWS) as u16; - // Account for the border added by the Block that wraps the table. - // 2 = one line at the top, one at the bottom. - row_count + 2 + pub(crate) fn calculate_required_height(&self) -> u16 { + self.filtered_commands().len().clamp(1, MAX_POPUP_ROWS) as u16 } /// Return the list of commands that match the current filter. Matching is @@ -158,18 +154,19 @@ impl WidgetRef for CommandPopup { let default_style = Style::default(); let command_style = Style::default().fg(Color::LightBlue); for (idx, cmd) in visible_matches.iter().enumerate() { - let (cmd_style, desc_style) = if Some(idx) == self.selected_idx { - ( - command_style.bg(Color::DarkGray), - default_style.bg(Color::DarkGray), - ) - } else { - (command_style, default_style) - }; - rows.push(Row::new(vec![ - Cell::from(format!("/{}", cmd.command())).style(cmd_style), - Cell::from(cmd.description().to_string()).style(desc_style), + Cell::from(Line::from(vec![ + if Some(idx) == self.selected_idx { + Span::styled( + "›", + Style::default().bg(Color::DarkGray).fg(Color::LightCyan), + ) + } else { + Span::styled(QUADRANT_LEFT_HALF, Style::default().fg(Color::DarkGray)) + }, + Span::styled(format!("/{}", cmd.command()), command_style), + ])), + Cell::from(cmd.description().to_string()).style(default_style), ])); } } @@ -180,12 +177,13 @@ impl WidgetRef for CommandPopup { rows, [Constraint::Length(FIRST_COLUMN_WIDTH), Constraint::Min(10)], ) - .column_spacing(0) - .block( - Block::default() - .borders(Borders::ALL) - .border_type(BorderType::Rounded), - ); + .column_spacing(0); + // .block( + // Block::default() + // .borders(Borders::LEFT) + // .border_type(BorderType::QuadrantOutside) + // .border_style(Style::default().fg(Color::DarkGray)), + // ); table.render(area, buf); } diff --git a/codex-rs/tui/src/bottom_pane/file_search_popup.rs b/codex-rs/tui/src/bottom_pane/file_search_popup.rs index 34eb59e4b2..ac6c91cf47 100644 --- a/codex-rs/tui/src/bottom_pane/file_search_popup.rs +++ b/codex-rs/tui/src/bottom_pane/file_search_popup.rs @@ -109,18 +109,14 @@ impl FileSearchPopup { } /// Preferred height (rows) including border. - pub(crate) fn calculate_required_height(&self, _area: &Rect) -> u16 { + pub(crate) fn calculate_required_height(&self) -> u16 { // Row count depends on whether we already have matches. If no matches // yet (e.g. initial search or query with no results) reserve a single // row so the popup is still visible. When matches are present we show // up to MAX_RESULTS regardless of the waiting flag so the list // remains stable while a newer search is in-flight. - let rows = if self.matches.is_empty() { - 1 - } else { - self.matches.len().clamp(1, MAX_RESULTS) - } as u16; - rows + 2 // border + + self.matches.len().clamp(1, MAX_RESULTS) as u16 } } @@ -128,7 +124,14 @@ impl WidgetRef for &FileSearchPopup { fn render_ref(&self, area: Rect, buf: &mut Buffer) { // Prepare rows. let rows: Vec = if self.matches.is_empty() { - vec![Row::new(vec![Cell::from(" no matches ")])] + vec![Row::new(vec![ + Cell::from(if self.waiting { + "(searching …)" + } else { + "no matches" + }) + .style(Style::new().add_modifier(Modifier::ITALIC | Modifier::DIM)), + ])] } else { self.matches .iter() @@ -169,17 +172,12 @@ impl WidgetRef for &FileSearchPopup { .collect() }; - let mut title = format!(" @{} ", self.pending_query); - if self.waiting { - title.push_str(" (searching …)"); - } - let table = Table::new(rows, vec![Constraint::Percentage(100)]) .block( Block::default() - .borders(Borders::ALL) - .border_type(BorderType::Rounded) - .title(title), + .borders(Borders::LEFT) + .border_type(BorderType::QuadrantOutside) + .border_style(Style::default().fg(Color::DarkGray)), ) .widths([Constraint::Percentage(100)]); diff --git a/codex-rs/tui/src/bottom_pane/mod.rs b/codex-rs/tui/src/bottom_pane/mod.rs index 350492b3e9..2710a3e997 100644 --- a/codex-rs/tui/src/bottom_pane/mod.rs +++ b/codex-rs/tui/src/bottom_pane/mod.rs @@ -20,6 +20,12 @@ mod command_popup; mod file_search_popup; mod status_indicator_view; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum CancellationEvent { + Ignored, + Handled, +} + pub(crate) use chat_composer::ChatComposer; pub(crate) use chat_composer::InputResult; @@ -58,6 +64,13 @@ impl BottomPane<'_> { } } + pub fn desired_height(&self, width: u16) -> u16 { + self.active_view + .as_ref() + .map(|v| v.desired_height(width)) + .unwrap_or(self.composer.desired_height()) + } + /// Forward a key event to the active view or the composer. pub fn handle_key_event(&mut self, key_event: KeyEvent) -> InputResult { if let Some(mut view) = self.active_view.take() { @@ -65,10 +78,8 @@ impl BottomPane<'_> { if !view.is_complete() { self.active_view = Some(view); } else if self.is_task_running { - let height = self.composer.calculate_required_height(&Rect::default()); self.active_view = Some(Box::new(StatusIndicatorView::new( self.app_event_tx.clone(), - height, ))); } self.request_redraw(); @@ -82,6 +93,33 @@ impl BottomPane<'_> { } } + /// Handle Ctrl-C in the bottom pane. If a modal view is active it gets a + /// chance to consume the event (e.g. to dismiss itself). + pub(crate) fn on_ctrl_c(&mut self) -> CancellationEvent { + let mut view = match self.active_view.take() { + Some(view) => view, + None => return CancellationEvent::Ignored, + }; + + let event = view.on_ctrl_c(self); + match event { + CancellationEvent::Handled => { + if !view.is_complete() { + self.active_view = Some(view); + } else if self.is_task_running { + self.active_view = Some(Box::new(StatusIndicatorView::new( + self.app_event_tx.clone(), + ))); + } + self.show_ctrl_c_quit_hint(); + } + CancellationEvent::Ignored => { + self.active_view = Some(view); + } + } + event + } + pub fn handle_paste(&mut self, pasted: String) { if self.active_view.is_none() { let needs_redraw = self.composer.handle_paste(pasted); @@ -106,12 +144,6 @@ impl BottomPane<'_> { } } - /// Update the UI to reflect whether this `BottomPane` has input focus. - pub(crate) fn set_input_focus(&mut self, has_focus: bool) { - self.has_input_focus = has_focus; - self.composer.set_input_focus(has_focus); - } - pub(crate) fn show_ctrl_c_quit_hint(&mut self) { self.ctrl_c_quit_hint = true; self.composer @@ -138,10 +170,8 @@ impl BottomPane<'_> { match (running, self.active_view.is_some()) { (true, false) => { // Show status indicator overlay. - let height = self.composer.calculate_required_height(&Rect::default()); self.active_view = Some(Box::new(StatusIndicatorView::new( self.app_event_tx.clone(), - height, ))); self.request_redraw(); } @@ -162,6 +192,10 @@ impl BottomPane<'_> { } } + pub(crate) fn composer_is_empty(&self) -> bool { + self.composer.is_empty() + } + pub(crate) fn is_task_running(&self) -> bool { self.is_task_running } @@ -199,21 +233,8 @@ impl BottomPane<'_> { } /// Height (terminal rows) required by the current bottom pane. - pub fn calculate_required_height(&self, area: &Rect) -> u16 { - if let Some(view) = &self.active_view { - view.calculate_required_height(area) - } else { - self.composer.calculate_required_height(area) - } - } - pub(crate) fn request_redraw(&self) { - self.app_event_tx.send(AppEvent::Redraw) - } - - /// Returns true when a popup inside the composer is visible. - pub(crate) fn is_popup_visible(&self) -> bool { - self.active_view.is_none() && self.composer.is_popup_visible() + self.app_event_tx.send(AppEvent::RequestRedraw) } // --- History helpers --- @@ -253,3 +274,34 @@ impl WidgetRef for &BottomPane<'_> { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::app_event::AppEvent; + use std::path::PathBuf; + use std::sync::mpsc::channel; + + fn exec_request() -> ApprovalRequest { + ApprovalRequest::Exec { + id: "1".to_string(), + command: vec!["echo".into(), "ok".into()], + cwd: PathBuf::from("."), + reason: None, + } + } + + #[test] + fn ctrl_c_on_modal_consumes_and_shows_quit_hint() { + let (tx_raw, _rx) = channel::(); + let tx = AppEventSender::new(tx_raw); + let mut pane = BottomPane::new(BottomPaneParams { + app_event_tx: tx, + has_input_focus: true, + }); + pane.push_approval_request(exec_request()); + assert_eq!(CancellationEvent::Handled, pane.on_ctrl_c()); + assert!(pane.ctrl_c_quit_hint_visible()); + assert_eq!(CancellationEvent::Ignored, pane.on_ctrl_c()); + } +} diff --git a/codex-rs/tui/src/bottom_pane/snapshots/codex_tui__bottom_pane__chat_composer__tests__backspace_after_pastes.snap b/codex-rs/tui/src/bottom_pane/snapshots/codex_tui__bottom_pane__chat_composer__tests__backspace_after_pastes.snap index fa604c862b..4f155dab30 100644 --- a/codex-rs/tui/src/bottom_pane/snapshots/codex_tui__bottom_pane__chat_composer__tests__backspace_after_pastes.snap +++ b/codex-rs/tui/src/bottom_pane/snapshots/codex_tui__bottom_pane__chat_composer__tests__backspace_after_pastes.snap @@ -2,13 +2,13 @@ source: tui/src/bottom_pane/chat_composer.rs expression: terminal.backend() --- -"╭──────────────────────────────────────────────────────────────────────────────────────────────────╮" -"│[Pasted Content 1002 chars][Pasted Content 1004 chars] │" -"│ │" -"│ │" -"│ │" -"│ │" -"│ │" -"│ │" -"│ │" -"╰───────────────────────────────────────────────Enter to send | Ctrl+D to quit | Ctrl+J for newline╯" +"▌[Pasted Content 1002 chars][Pasted Content 1004 chars] " +"▌ " +"▌ " +"▌ " +"▌ " +"▌ " +"▌ " +"▌ " +"▌ " +" ⏎ send Shift+⏎ newline Ctrl+C quit " diff --git a/codex-rs/tui/src/bottom_pane/snapshots/codex_tui__bottom_pane__chat_composer__tests__empty.snap b/codex-rs/tui/src/bottom_pane/snapshots/codex_tui__bottom_pane__chat_composer__tests__empty.snap index a89076d8aa..4e8371f177 100644 --- a/codex-rs/tui/src/bottom_pane/snapshots/codex_tui__bottom_pane__chat_composer__tests__empty.snap +++ b/codex-rs/tui/src/bottom_pane/snapshots/codex_tui__bottom_pane__chat_composer__tests__empty.snap @@ -2,13 +2,13 @@ source: tui/src/bottom_pane/chat_composer.rs expression: terminal.backend() --- -"╭──────────────────────────────────────────────────────────────────────────────────────────────────╮" -"│ send a message │" -"│ │" -"│ │" -"│ │" -"│ │" -"│ │" -"│ │" -"│ │" -"╰───────────────────────────────────────────────Enter to send | Ctrl+D to quit | Ctrl+J for newline╯" +"▌ ... " +"▌ " +"▌ " +"▌ " +"▌ " +"▌ " +"▌ " +"▌ " +"▌ " +" ⏎ send Shift+⏎ newline Ctrl+C quit " diff --git a/codex-rs/tui/src/bottom_pane/snapshots/codex_tui__bottom_pane__chat_composer__tests__large.snap b/codex-rs/tui/src/bottom_pane/snapshots/codex_tui__bottom_pane__chat_composer__tests__large.snap index 39a62da400..80fea40d5f 100644 --- a/codex-rs/tui/src/bottom_pane/snapshots/codex_tui__bottom_pane__chat_composer__tests__large.snap +++ b/codex-rs/tui/src/bottom_pane/snapshots/codex_tui__bottom_pane__chat_composer__tests__large.snap @@ -2,13 +2,13 @@ source: tui/src/bottom_pane/chat_composer.rs expression: terminal.backend() --- -"╭──────────────────────────────────────────────────────────────────────────────────────────────────╮" -"│[Pasted Content 1005 chars] │" -"│ │" -"│ │" -"│ │" -"│ │" -"│ │" -"│ │" -"│ │" -"╰───────────────────────────────────────────────Enter to send | Ctrl+D to quit | Ctrl+J for newline╯" +"▌[Pasted Content 1005 chars] " +"▌ " +"▌ " +"▌ " +"▌ " +"▌ " +"▌ " +"▌ " +"▌ " +" ⏎ send Shift+⏎ newline Ctrl+C quit " diff --git a/codex-rs/tui/src/bottom_pane/snapshots/codex_tui__bottom_pane__chat_composer__tests__multiple_pastes.snap b/codex-rs/tui/src/bottom_pane/snapshots/codex_tui__bottom_pane__chat_composer__tests__multiple_pastes.snap index cd94095431..26e8d26733 100644 --- a/codex-rs/tui/src/bottom_pane/snapshots/codex_tui__bottom_pane__chat_composer__tests__multiple_pastes.snap +++ b/codex-rs/tui/src/bottom_pane/snapshots/codex_tui__bottom_pane__chat_composer__tests__multiple_pastes.snap @@ -2,13 +2,13 @@ source: tui/src/bottom_pane/chat_composer.rs expression: terminal.backend() --- -"╭──────────────────────────────────────────────────────────────────────────────────────────────────╮" -"│[Pasted Content 1003 chars][Pasted Content 1007 chars] another short paste │" -"│ │" -"│ │" -"│ │" -"│ │" -"│ │" -"│ │" -"│ │" -"╰───────────────────────────────────────────────Enter to send | Ctrl+D to quit | Ctrl+J for newline╯" +"▌[Pasted Content 1003 chars][Pasted Content 1007 chars] another short paste " +"▌ " +"▌ " +"▌ " +"▌ " +"▌ " +"▌ " +"▌ " +"▌ " +" ⏎ send Shift+⏎ newline Ctrl+C quit " diff --git a/codex-rs/tui/src/bottom_pane/snapshots/codex_tui__bottom_pane__chat_composer__tests__small.snap b/codex-rs/tui/src/bottom_pane/snapshots/codex_tui__bottom_pane__chat_composer__tests__small.snap index e6b55e36d8..0f1b9e6426 100644 --- a/codex-rs/tui/src/bottom_pane/snapshots/codex_tui__bottom_pane__chat_composer__tests__small.snap +++ b/codex-rs/tui/src/bottom_pane/snapshots/codex_tui__bottom_pane__chat_composer__tests__small.snap @@ -2,13 +2,13 @@ source: tui/src/bottom_pane/chat_composer.rs expression: terminal.backend() --- -"╭──────────────────────────────────────────────────────────────────────────────────────────────────╮" -"│short │" -"│ │" -"│ │" -"│ │" -"│ │" -"│ │" -"│ │" -"│ │" -"╰───────────────────────────────────────────────Enter to send | Ctrl+D to quit | Ctrl+J for newline╯" +"▌short " +"▌ " +"▌ " +"▌ " +"▌ " +"▌ " +"▌ " +"▌ " +"▌ " +" ⏎ send Shift+⏎ newline Ctrl+C quit " diff --git a/codex-rs/tui/src/bottom_pane/status_indicator_view.rs b/codex-rs/tui/src/bottom_pane/status_indicator_view.rs index d9ac57d7b9..a944271e45 100644 --- a/codex-rs/tui/src/bottom_pane/status_indicator_view.rs +++ b/codex-rs/tui/src/bottom_pane/status_indicator_view.rs @@ -1,5 +1,4 @@ use ratatui::buffer::Buffer; -use ratatui::layout::Rect; use ratatui::widgets::WidgetRef; use crate::app_event_sender::AppEventSender; @@ -13,9 +12,9 @@ pub(crate) struct StatusIndicatorView { } impl StatusIndicatorView { - pub fn new(app_event_tx: AppEventSender, height: u16) -> Self { + pub fn new(app_event_tx: AppEventSender) -> Self { Self { - view: StatusIndicatorWidget::new(app_event_tx, height), + view: StatusIndicatorWidget::new(app_event_tx), } } @@ -24,7 +23,7 @@ impl StatusIndicatorView { } } -impl<'a> BottomPaneView<'a> for StatusIndicatorView { +impl BottomPaneView<'_> for StatusIndicatorView { fn update_status_text(&mut self, text: String) -> ConditionalUpdate { self.update_text(text); ConditionalUpdate::NeedsRedraw @@ -34,11 +33,11 @@ impl<'a> BottomPaneView<'a> for StatusIndicatorView { true } - fn calculate_required_height(&self, _area: &Rect) -> u16 { - self.view.get_height() + fn desired_height(&self, width: u16) -> u16 { + self.view.desired_height(width) } - fn render(&self, area: Rect, buf: &mut Buffer) { + fn render(&self, area: ratatui::layout::Rect, buf: &mut Buffer) { self.view.render_ref(area, buf); } } diff --git a/codex-rs/tui/src/cell_widget.rs b/codex-rs/tui/src/cell_widget.rs deleted file mode 100644 index 8acdc0553a..0000000000 --- a/codex-rs/tui/src/cell_widget.rs +++ /dev/null @@ -1,20 +0,0 @@ -use ratatui::prelude::*; - -/// Trait implemented by every type that can live inside the conversation -/// history list. It provides two primitives that the parent scroll-view -/// needs: how *tall* the widget is at a given width and how to render an -/// arbitrary contiguous *window* of that widget. -/// -/// The `first_visible_line` argument to [`render_window`] allows partial -/// rendering when the top of the widget is scrolled off-screen. The caller -/// guarantees that `first_visible_line + area.height as usize` never exceeds -/// the total height previously returned by [`height`]. -pub(crate) trait CellWidget { - /// Total height measured in wrapped terminal lines when drawn with the - /// given *content* width (no scrollbar column included). - fn height(&self, width: u16) -> usize; - - /// Render a *window* that starts `first_visible_line` lines below the top - /// of the widget. The window’s size is given by `area`. - fn render_window(&self, first_visible_line: usize, area: Rect, buf: &mut Buffer); -} diff --git a/codex-rs/tui/src/chatwidget.rs b/codex-rs/tui/src/chatwidget.rs index 3841db1dd9..f81e35d375 100644 --- a/codex-rs/tui/src/chatwidget.rs +++ b/codex-rs/tui/src/chatwidget.rs @@ -1,10 +1,15 @@ +use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; +use std::time::Duration; +use codex_core::codex_wrapper::CodexConversation; use codex_core::codex_wrapper::init_codex; use codex_core::config::Config; +use codex_core::protocol::AgentMessageDeltaEvent; use codex_core::protocol::AgentMessageEvent; use codex_core::protocol::AgentReasoningContentEvent; +use codex_core::protocol::AgentReasoningDeltaEvent; use codex_core::protocol::AgentReasoningEvent; use codex_core::protocol::ApplyPatchApprovalRequestEvent; use codex_core::protocol::ErrorEvent; @@ -22,9 +27,6 @@ use codex_core::protocol::TaskCompleteEvent; use codex_core::protocol::TokenUsage; use crossterm::event::KeyEvent; use ratatui::buffer::Buffer; -use ratatui::layout::Constraint; -use ratatui::layout::Direction; -use ratatui::layout::Layout; use ratatui::layout::Rect; use ratatui::widgets::Widget; use ratatui::widgets::WidgetRef; @@ -35,27 +37,34 @@ use crate::app_event::AppEvent; use crate::app_event_sender::AppEventSender; use crate::bottom_pane::BottomPane; use crate::bottom_pane::BottomPaneParams; +use crate::bottom_pane::CancellationEvent; use crate::bottom_pane::InputResult; -use crate::conversation_history_widget::ConversationHistoryWidget; +use crate::exec_command::strip_bash_lc_and_escape; +use crate::history_cell::CommandOutput; +use crate::history_cell::HistoryCell; use crate::history_cell::PatchEventType; use crate::user_approval_widget::ApprovalRequest; use codex_file_search::FileMatch; +struct RunningCommand { + command: Vec, + #[allow(dead_code)] + cwd: PathBuf, +} + pub(crate) struct ChatWidget<'a> { app_event_tx: AppEventSender, codex_op_tx: UnboundedSender, - conversation_history: ConversationHistoryWidget, bottom_pane: BottomPane<'a>, - input_focus: InputFocus, config: Config, initial_user_message: Option, token_usage: TokenUsage, -} - -#[derive(Clone, Copy, Eq, PartialEq)] -enum InputFocus { - HistoryPane, - BottomPane, + reasoning_buffer: String, + // Buffer for streaming assistant answer text; we do not surface partial + // We wait for the final AgentMessage event and then emit the full text + // at once into scrollback so the history contains a single message. + answer_buffer: String, + running_commands: HashMap, } struct UserMessage { @@ -93,7 +102,11 @@ impl ChatWidget<'_> { // Create the Codex asynchronously so the UI loads as quickly as possible. let config_for_agent_loop = config.clone(); tokio::spawn(async move { - let (codex, session_event, _ctrl_c) = match init_codex(config_for_agent_loop).await { + let CodexConversation { + codex, + session_configured, + .. + } = match init_codex(config_for_agent_loop).await { Ok(vals) => vals, Err(e) => { // TODO: surface this error to the user. @@ -104,7 +117,7 @@ impl ChatWidget<'_> { // Forward the captured `SessionInitialized` event that was consumed // inside `init_codex()` so it can be rendered in the UI. - app_event_tx_clone.send(AppEvent::CodexEvent(session_event.clone())); + app_event_tx_clone.send(AppEvent::CodexEvent(session_configured.clone())); let codex = Arc::new(codex); let codex_clone = codex.clone(); tokio::spawn(async move { @@ -124,61 +137,44 @@ impl ChatWidget<'_> { Self { app_event_tx: app_event_tx.clone(), codex_op_tx, - conversation_history: ConversationHistoryWidget::new(), bottom_pane: BottomPane::new(BottomPaneParams { app_event_tx, has_input_focus: true, }), - input_focus: InputFocus::BottomPane, config, initial_user_message: create_initial_user_message( initial_prompt.unwrap_or_default(), initial_images, ), token_usage: TokenUsage::default(), + reasoning_buffer: String::new(), + answer_buffer: String::new(), + running_commands: HashMap::new(), } } + pub fn desired_height(&self, width: u16) -> u16 { + self.bottom_pane.desired_height(width) + } + pub(crate) fn handle_key_event(&mut self, key_event: KeyEvent) { self.bottom_pane.clear_ctrl_c_quit_hint(); - // Special-case : normally toggles focus between history and bottom panes. - // However, when the slash-command popup is visible we forward the key - // to the bottom pane so it can handle auto-completion. - if matches!(key_event.code, crossterm::event::KeyCode::Tab) - && !self.bottom_pane.is_popup_visible() - { - self.input_focus = match self.input_focus { - InputFocus::HistoryPane => InputFocus::BottomPane, - InputFocus::BottomPane => InputFocus::HistoryPane, - }; - self.conversation_history - .set_input_focus(self.input_focus == InputFocus::HistoryPane); - self.bottom_pane - .set_input_focus(self.input_focus == InputFocus::BottomPane); - self.request_redraw(); - return; - } - match self.input_focus { - InputFocus::HistoryPane => { - let needs_redraw = self.conversation_history.handle_key_event(key_event); - if needs_redraw { - self.request_redraw(); - } + match self.bottom_pane.handle_key_event(key_event) { + InputResult::Submitted(text) => { + self.submit_user_message(text.into()); } - InputFocus::BottomPane => match self.bottom_pane.handle_key_event(key_event) { - InputResult::Submitted(text) => { - self.submit_user_message(text.into()); - } - InputResult::None => {} - }, + InputResult::None => {} } } pub(crate) fn handle_paste(&mut self, text: String) { - if matches!(self.input_focus, InputFocus::BottomPane) { - self.bottom_pane.handle_paste(text); - } + self.bottom_pane.handle_paste(text); + } + + fn add_to_history(&mut self, cell: HistoryCell) { + self.app_event_tx + .send(AppEvent::InsertHistory(cell.plain_lines())); } fn submit_user_message(&mut self, user_message: UserMessage) { @@ -214,23 +210,18 @@ impl ChatWidget<'_> { // Only show text portion in conversation history for now. if !text.is_empty() { - self.conversation_history.add_user_message(text); + self.add_to_history(HistoryCell::new_user_prompt(text.clone())); } - self.conversation_history.scroll_to_bottom(); } pub(crate) fn handle_codex_event(&mut self, event: Event) { let Event { id, msg } = event; match msg { EventMsg::SessionConfigured(event) => { - // Record session information at the top of the conversation. - self.conversation_history - .add_session_info(&self.config, event.clone()); - - // Forward history metadata to the bottom pane so the chat - // composer can navigate through past messages. self.bottom_pane .set_history_metadata(event.history_log_id, event.history_entry_count); + // Record session information at the top of the conversation. + self.add_to_history(HistoryCell::new_session_info(&self.config, event, true)); if let Some(user_message) = self.initial_user_message.take() { // If the user provided an initial message, add it to the @@ -241,16 +232,47 @@ impl ChatWidget<'_> { self.request_redraw(); } EventMsg::AgentMessage(AgentMessageEvent { message }) => { - self.conversation_history - .add_agent_message(&self.config, message); + // Final assistant answer. Prefer the fully provided message + // from the event; if it is empty fall back to any accumulated + // delta buffer (some providers may only stream deltas and send + // an empty final message). + let full = if message.is_empty() { + std::mem::take(&mut self.answer_buffer) + } else { + self.answer_buffer.clear(); + message + }; + if !full.is_empty() { + self.add_to_history(HistoryCell::new_agent_message(&self.config, full)); + } self.request_redraw(); } + EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }) => { + // Buffer only – do not emit partial lines. This avoids cases + // where long responses appear truncated if the terminal + // wrapped early. The full message is emitted on + // AgentMessage. + self.answer_buffer.push_str(&delta); + } + EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta }) => { + // Buffer only – disable incremental reasoning streaming so we + // avoid truncated intermediate lines. Full text emitted on + // AgentReasoning. + self.reasoning_buffer.push_str(&delta); + } EventMsg::AgentReasoning(AgentReasoningEvent { text }) => { - if !self.config.hide_agent_reasoning { - self.conversation_history - .add_agent_reasoning(&self.config, text); - self.request_redraw(); + // Emit full reasoning text once. Some providers might send + // final event with empty text if only deltas were used. + let full = if text.is_empty() { + std::mem::take(&mut self.reasoning_buffer) + } else { + self.reasoning_buffer.clear(); + text + }; + if !full.is_empty() { + self.add_to_history(HistoryCell::new_agent_reasoning(&self.config, full)); } + self.request_redraw(); } EventMsg::AgentReasoningContent(AgentReasoningContentEvent { text }) => { if !self.config.hide_agent_reasoning && self.config.show_reasoning_content { @@ -277,14 +299,27 @@ impl ChatWidget<'_> { .set_token_usage(self.token_usage.clone(), self.config.model_context_window); } EventMsg::Error(ErrorEvent { message }) => { - self.conversation_history.add_error(message); + self.add_to_history(HistoryCell::new_error_event(message.clone())); self.bottom_pane.set_task_running(false); } EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent { + call_id: _, command, cwd, reason, }) => { + // Print the command to the history so it is visible in the + // transcript *before* the modal asks for approval. + let cmdline = strip_bash_lc_and_escape(&command); + let text = format!( + "command requires approval:\n$ {cmdline}{reason}", + reason = reason + .as_ref() + .map(|r| format!("\n{r}")) + .unwrap_or_default() + ); + self.add_to_history(HistoryCell::new_background_event(text)); + let request = ApprovalRequest::Exec { id, command, @@ -292,8 +327,10 @@ impl ChatWidget<'_> { reason, }; self.bottom_pane.push_approval_request(request); + self.request_redraw(); } EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { + call_id: _, changes, reason, grant_root, @@ -309,10 +346,10 @@ impl ChatWidget<'_> { // prompt before they have seen *what* is being requested. // ------------------------------------------------------------------ - self.conversation_history - .add_patch_event(PatchEventType::ApprovalRequest, changes); - - self.conversation_history.scroll_to_bottom(); + self.add_to_history(HistoryCell::new_patch_event( + PatchEventType::ApprovalRequest, + changes, + )); // Now surface the approval request in the BottomPane as before. let request = ApprovalRequest::ApplyPatch { @@ -326,11 +363,16 @@ impl ChatWidget<'_> { EventMsg::ExecCommandBegin(ExecCommandBeginEvent { call_id, command, - cwd: _, + cwd, }) => { - self.conversation_history - .add_active_exec_command(call_id, command); - self.request_redraw(); + self.running_commands.insert( + call_id, + RunningCommand { + command: command.clone(), + cwd: cwd.clone(), + }, + ); + self.add_to_history(HistoryCell::new_active_exec_command(command)); } EventMsg::PatchApplyBegin(PatchApplyBeginEvent { call_id: _, @@ -339,12 +381,10 @@ impl ChatWidget<'_> { }) => { // Even when a patch is auto‑approved we still display the // summary so the user can follow along. - self.conversation_history - .add_patch_event(PatchEventType::ApplyBegin { auto_approved }, changes); - if !auto_approved { - self.conversation_history.scroll_to_bottom(); - } - self.request_redraw(); + self.add_to_history(HistoryCell::new_patch_event( + PatchEventType::ApplyBegin { auto_approved }, + changes, + )); } EventMsg::ExecCommandEnd(ExecCommandEndEvent { call_id, @@ -352,26 +392,39 @@ impl ChatWidget<'_> { stdout, stderr, }) => { - self.conversation_history - .record_completed_exec_command(call_id, stdout, stderr, exit_code); - self.request_redraw(); + let cmd = self.running_commands.remove(&call_id); + self.add_to_history(HistoryCell::new_completed_exec_command( + cmd.map(|cmd| cmd.command).unwrap_or_else(|| vec![call_id]), + CommandOutput { + exit_code, + stdout, + stderr, + duration: Duration::from_secs(0), + }, + )); } EventMsg::McpToolCallBegin(McpToolCallBeginEvent { - call_id, - server, - tool, - arguments, + call_id: _, + invocation, }) => { - self.conversation_history - .add_active_mcp_tool_call(call_id, server, tool, arguments); - self.request_redraw(); + self.add_to_history(HistoryCell::new_active_mcp_tool_call(invocation)); } - EventMsg::McpToolCallEnd(mcp_tool_call_end_event) => { - let success = mcp_tool_call_end_event.is_success(); - let McpToolCallEndEvent { call_id, result } = mcp_tool_call_end_event; - self.conversation_history - .record_completed_mcp_tool_call(call_id, success, result); - self.request_redraw(); + EventMsg::McpToolCallEnd(McpToolCallEndEvent { + call_id: _, + duration, + invocation, + result, + }) => { + self.add_to_history(HistoryCell::new_completed_mcp_tool_call( + 80, + invocation, + duration, + result + .as_ref() + .map(|r| r.is_error.unwrap_or(false)) + .unwrap_or(false), + result, + )); } EventMsg::GetHistoryEntryResponse(event) => { let codex_core::protocol::GetHistoryEntryResponseEvent { @@ -384,10 +437,11 @@ impl ChatWidget<'_> { self.bottom_pane .on_history_entry_response(log_id, offset, entry.map(|e| e.text)); } + EventMsg::ShutdownComplete => { + self.app_event_tx.send(AppEvent::ExitRequest); + } event => { - self.conversation_history - .add_background_event(format!("{event:?}")); - self.request_redraw(); + self.add_to_history(HistoryCell::new_background_event(format!("{event:?}"))); } } } @@ -399,25 +453,11 @@ impl ChatWidget<'_> { } fn request_redraw(&mut self) { - self.app_event_tx.send(AppEvent::Redraw); + self.app_event_tx.send(AppEvent::RequestRedraw); } pub(crate) fn add_diff_output(&mut self, diff_output: String) { - self.conversation_history.add_diff_output(diff_output); - self.request_redraw(); - } - - pub(crate) fn handle_scroll_delta(&mut self, scroll_delta: i32) { - // If the user is trying to scroll exactly one line, we let them, but - // otherwise we assume they are trying to scroll in larger increments. - let magnified_scroll_delta = if scroll_delta == 1 { - 1 - } else { - // Play with this: perhaps it should be non-linear? - scroll_delta * 2 - }; - self.conversation_history.scroll(magnified_scroll_delta); - self.request_redraw(); + self.add_to_history(HistoryCell::new_diff_output(diff_output.clone())); } /// Forward file-search results to the bottom pane. @@ -426,40 +466,50 @@ impl ChatWidget<'_> { } /// Handle Ctrl-C key press. - /// Returns true if the key press was handled, false if it was not. - /// If the key press was not handled, the caller should handle it (likely by exiting the process). - pub(crate) fn on_ctrl_c(&mut self) -> bool { + /// Returns CancellationEvent::Handled if the event was consumed by the UI, or + /// CancellationEvent::Ignored if the caller should handle it (e.g. exit). + pub(crate) fn on_ctrl_c(&mut self) -> CancellationEvent { + match self.bottom_pane.on_ctrl_c() { + CancellationEvent::Handled => return CancellationEvent::Handled, + CancellationEvent::Ignored => {} + } if self.bottom_pane.is_task_running() { self.bottom_pane.clear_ctrl_c_quit_hint(); self.submit_op(Op::Interrupt); - false + self.answer_buffer.clear(); + self.reasoning_buffer.clear(); + CancellationEvent::Ignored } else if self.bottom_pane.ctrl_c_quit_hint_visible() { - true + self.submit_op(Op::Shutdown); + CancellationEvent::Handled } else { self.bottom_pane.show_ctrl_c_quit_hint(); - false + CancellationEvent::Ignored } } + pub(crate) fn composer_is_empty(&self) -> bool { + self.bottom_pane.composer_is_empty() + } + /// Forward an `Op` directly to codex. pub(crate) fn submit_op(&self, op: Op) { if let Err(e) = self.codex_op_tx.send(op) { tracing::error!("failed to submit op: {e}"); } } + + pub(crate) fn token_usage(&self) -> &TokenUsage { + &self.token_usage + } } impl WidgetRef for &ChatWidget<'_> { fn render_ref(&self, area: Rect, buf: &mut Buffer) { - let bottom_height = self.bottom_pane.calculate_required_height(&area); - - let chunks = Layout::default() - .direction(Direction::Vertical) - .constraints([Constraint::Min(0), Constraint::Length(bottom_height)]) - .split(area); - - self.conversation_history.render(chunks[0], buf); - (&self.bottom_pane).render(chunks[1], buf); + // In the hybrid inline viewport mode we only draw the interactive + // bottom pane; history entries are injected directly into scrollback + // via `Terminal::insert_before`. + (&self.bottom_pane).render(area, buf); } } diff --git a/codex-rs/tui/src/conversation_history_widget.rs b/codex-rs/tui/src/conversation_history_widget.rs deleted file mode 100644 index c0e5031d70..0000000000 --- a/codex-rs/tui/src/conversation_history_widget.rs +++ /dev/null @@ -1,499 +0,0 @@ -use crate::cell_widget::CellWidget; -use crate::history_cell::CommandOutput; -use crate::history_cell::HistoryCell; -use crate::history_cell::PatchEventType; -use codex_core::config::Config; -use codex_core::protocol::FileChange; -use codex_core::protocol::SessionConfiguredEvent; -use crossterm::event::KeyCode; -use crossterm::event::KeyEvent; -use ratatui::prelude::*; -use ratatui::style::Style; -use ratatui::widgets::*; -use serde_json::Value as JsonValue; -use std::cell::Cell as StdCell; -use std::cell::Cell; -use std::collections::HashMap; -use std::path::PathBuf; - -/// A single history entry plus its cached wrapped-line count. -struct Entry { - cell: HistoryCell, - line_count: Cell, -} - -pub struct ConversationHistoryWidget { - entries: Vec, - /// The width (in terminal cells/columns) that [`Entry::line_count`] was - /// computed for. When the available width changes we recompute counts. - cached_width: StdCell, - scroll_position: usize, - /// Number of lines the last time render_ref() was called - num_rendered_lines: StdCell, - /// The height of the viewport last time render_ref() was called - last_viewport_height: StdCell, - has_input_focus: bool, -} - -impl ConversationHistoryWidget { - pub fn new() -> Self { - Self { - entries: Vec::new(), - cached_width: StdCell::new(0), - scroll_position: usize::MAX, - num_rendered_lines: StdCell::new(0), - last_viewport_height: StdCell::new(0), - has_input_focus: false, - } - } - - pub(crate) fn set_input_focus(&mut self, has_input_focus: bool) { - self.has_input_focus = has_input_focus; - } - - /// Returns true if it needs a redraw. - pub(crate) fn handle_key_event(&mut self, key_event: KeyEvent) -> bool { - match key_event.code { - KeyCode::Up | KeyCode::Char('k') => { - self.scroll_up(1); - true - } - KeyCode::Down | KeyCode::Char('j') => { - self.scroll_down(1); - true - } - KeyCode::PageUp | KeyCode::Char('b') => { - self.scroll_page_up(); - true - } - KeyCode::PageDown | KeyCode::Char(' ') => { - self.scroll_page_down(); - true - } - _ => false, - } - } - - /// Negative delta scrolls up; positive delta scrolls down. - pub(crate) fn scroll(&mut self, delta: i32) { - match delta.cmp(&0) { - std::cmp::Ordering::Less => self.scroll_up(-delta as u32), - std::cmp::Ordering::Greater => self.scroll_down(delta as u32), - std::cmp::Ordering::Equal => {} - } - } - - fn scroll_up(&mut self, num_lines: u32) { - // If a user is scrolling up from the "stick to bottom" mode, we need to - // map this to a specific scroll position so we can calculate the delta. - // This requires us to care about how tall the screen is. - if self.scroll_position == usize::MAX { - self.scroll_position = self - .num_rendered_lines - .get() - .saturating_sub(self.last_viewport_height.get()); - } - - self.scroll_position = self.scroll_position.saturating_sub(num_lines as usize); - } - - fn scroll_down(&mut self, num_lines: u32) { - // If we're already pinned to the bottom there's nothing to do. - if self.scroll_position == usize::MAX { - return; - } - - let viewport_height = self.last_viewport_height.get().max(1); - let num_rendered_lines = self.num_rendered_lines.get(); - - // Compute the maximum explicit scroll offset that still shows a full - // viewport. This mirrors the calculation in `scroll_page_down()` and - // in the render path. - let max_scroll = num_rendered_lines.saturating_sub(viewport_height); - - let new_pos = self.scroll_position.saturating_add(num_lines as usize); - - if new_pos >= max_scroll { - // Reached (or passed) the bottom – switch to stick‑to‑bottom mode - // so that additional output keeps the view pinned automatically. - self.scroll_position = usize::MAX; - } else { - self.scroll_position = new_pos; - } - } - - /// Scroll up by one full viewport height (Page Up). - fn scroll_page_up(&mut self) { - let viewport_height = self.last_viewport_height.get().max(1); - - // If we are currently in the "stick to bottom" mode, first convert the - // implicit scroll position (`usize::MAX`) into an explicit offset that - // represents the very bottom of the scroll region. This mirrors the - // logic from `scroll_up()`. - if self.scroll_position == usize::MAX { - self.scroll_position = self - .num_rendered_lines - .get() - .saturating_sub(viewport_height); - } - - // Move up by a full page. - self.scroll_position = self.scroll_position.saturating_sub(viewport_height); - } - - /// Scroll down by one full viewport height (Page Down). - fn scroll_page_down(&mut self) { - // Nothing to do if we're already stuck to the bottom. - if self.scroll_position == usize::MAX { - return; - } - - let viewport_height = self.last_viewport_height.get().max(1); - let num_lines = self.num_rendered_lines.get(); - - // Calculate the maximum explicit scroll offset that is still within - // range. This matches the logic in `scroll_down()` and the render - // method. - let max_scroll = num_lines.saturating_sub(viewport_height); - - // Attempt to move down by a full page. - let new_pos = self.scroll_position.saturating_add(viewport_height); - - if new_pos >= max_scroll { - // We have reached (or passed) the bottom – switch back to - // automatic stick‑to‑bottom mode so that subsequent output keeps - // the viewport pinned. - self.scroll_position = usize::MAX; - } else { - self.scroll_position = new_pos; - } - } - - pub fn scroll_to_bottom(&mut self) { - self.scroll_position = usize::MAX; - } - - /// Note `model` could differ from `config.model` if the agent decided to - /// use a different model than the one requested by the user. - pub fn add_session_info(&mut self, config: &Config, event: SessionConfiguredEvent) { - // In practice, SessionConfiguredEvent should always be the first entry - // in the history, but it is possible that an error could be sent - // before the session info. - let has_welcome_message = self - .entries - .iter() - .any(|entry| matches!(entry.cell, HistoryCell::WelcomeMessage { .. })); - self.add_to_history(HistoryCell::new_session_info( - config, - event, - !has_welcome_message, - )); - } - - pub fn add_user_message(&mut self, message: String) { - self.add_to_history(HistoryCell::new_user_prompt(message)); - } - - pub fn add_agent_message(&mut self, config: &Config, message: String) { - self.add_to_history(HistoryCell::new_agent_message(config, message)); - } - - pub fn add_agent_reasoning(&mut self, config: &Config, text: String) { - self.add_to_history(HistoryCell::new_agent_reasoning(config, text)); - } - - pub fn add_background_event(&mut self, message: String) { - self.add_to_history(HistoryCell::new_background_event(message)); - } - - pub fn add_diff_output(&mut self, diff_output: String) { - self.add_to_history(HistoryCell::new_diff_output(diff_output)); - } - - pub fn add_error(&mut self, message: String) { - self.add_to_history(HistoryCell::new_error_event(message)); - } - - /// Add a pending patch entry (before user approval). - pub fn add_patch_event( - &mut self, - event_type: PatchEventType, - changes: HashMap, - ) { - self.add_to_history(HistoryCell::new_patch_event(event_type, changes)); - } - - pub fn add_active_exec_command(&mut self, call_id: String, command: Vec) { - self.add_to_history(HistoryCell::new_active_exec_command(call_id, command)); - } - - pub fn add_active_mcp_tool_call( - &mut self, - call_id: String, - server: String, - tool: String, - arguments: Option, - ) { - self.add_to_history(HistoryCell::new_active_mcp_tool_call( - call_id, server, tool, arguments, - )); - } - - fn add_to_history(&mut self, cell: HistoryCell) { - let width = self.cached_width.get(); - let count = if width > 0 { cell.height(width) } else { 0 }; - - self.entries.push(Entry { - cell, - line_count: Cell::new(count), - }); - } - - pub fn record_completed_exec_command( - &mut self, - call_id: String, - stdout: String, - stderr: String, - exit_code: i32, - ) { - let width = self.cached_width.get(); - for entry in self.entries.iter_mut() { - let cell = &mut entry.cell; - if let HistoryCell::ActiveExecCommand { - call_id: history_id, - command, - start, - .. - } = cell - { - if &call_id == history_id { - *cell = HistoryCell::new_completed_exec_command( - command.clone(), - CommandOutput { - exit_code, - stdout, - stderr, - duration: start.elapsed(), - }, - ); - - // Update cached line count. - if width > 0 { - entry.line_count.set(cell.height(width)); - } - break; - } - } - } - } - - pub fn record_completed_mcp_tool_call( - &mut self, - call_id: String, - success: bool, - result: Result, - ) { - let width = self.cached_width.get(); - for entry in self.entries.iter_mut() { - if let HistoryCell::ActiveMcpToolCall { - call_id: history_id, - invocation, - start, - .. - } = &entry.cell - { - if &call_id == history_id { - let completed = HistoryCell::new_completed_mcp_tool_call( - width, - invocation.clone(), - *start, - success, - result, - ); - entry.cell = completed; - - if width > 0 { - entry.line_count.set(entry.cell.height(width)); - } - - break; - } - } - } - } -} - -impl WidgetRef for ConversationHistoryWidget { - fn render_ref(&self, area: Rect, buf: &mut Buffer) { - let (title, border_style) = if self.has_input_focus { - ( - "Messages (↑/↓ or j/k = line, b/space = page)", - Style::default().fg(Color::LightYellow), - ) - } else { - ("Messages (tab to focus)", Style::default().dim()) - }; - - let block = Block::default() - .title(title) - .borders(Borders::ALL) - .border_type(BorderType::Rounded) - .border_style(border_style); - - // Compute the inner area that will be available for the list after - // the surrounding `Block` is drawn. - let inner = block.inner(area); - let viewport_height = inner.height as usize; - - // Cache (and if necessary recalculate) the wrapped line counts for every - // [`HistoryCell`] so that our scrolling math accounts for text - // wrapping. We always reserve one column on the right-hand side for the - // scrollbar so that the content never renders "under" the scrollbar. - let effective_width = inner.width.saturating_sub(1); - - if effective_width == 0 { - return; // Nothing to draw – avoid division by zero. - } - - // Recompute cache if the effective width changed. - let num_lines: usize = if self.cached_width.get() != effective_width { - self.cached_width.set(effective_width); - - let mut num_lines: usize = 0; - for entry in &self.entries { - let count = entry.cell.height(effective_width); - num_lines += count; - entry.line_count.set(count); - } - num_lines - } else { - self.entries.iter().map(|e| e.line_count.get()).sum() - }; - - // Determine the scroll position. Note the existing value of - // `self.scroll_position` could exceed the maximum scroll offset if the - // user made the window wider since the last render. - let max_scroll = num_lines.saturating_sub(viewport_height); - let scroll_pos = if self.scroll_position == usize::MAX { - max_scroll - } else { - self.scroll_position.min(max_scroll) - }; - - // ------------------------------------------------------------------ - // Render order: - // 1. Clear full widget area (avoid artifacts from prior frame). - // 2. Draw the surrounding Block (border and title). - // 3. Render *each* visible HistoryCell into its own sub-Rect while - // respecting partial visibility at the top and bottom. - // 4. Draw the scrollbar track / thumb in the reserved column. - // ------------------------------------------------------------------ - - // Clear entire widget area first. - Clear.render(area, buf); - - // Draw border + title. - block.render(area, buf); - - // ------------------------------------------------------------------ - // Calculate which cells are visible for the current scroll position - // and paint them one by one. - // ------------------------------------------------------------------ - - let mut y_cursor = inner.y; // first line inside viewport - let mut remaining_height = inner.height as usize; - let mut lines_to_skip = scroll_pos; // number of wrapped lines to skip (above viewport) - - for entry in &self.entries { - let cell_height = entry.line_count.get(); - - // Completely above viewport? Skip whole cell. - if lines_to_skip >= cell_height { - lines_to_skip -= cell_height; - continue; - } - - // Determine how much of this cell is visible. - let visible_height = (cell_height - lines_to_skip).min(remaining_height); - - if visible_height == 0 { - break; // no space left - } - - let cell_rect = Rect { - x: inner.x, - y: y_cursor, - width: effective_width, - height: visible_height as u16, - }; - - entry.cell.render_window(lines_to_skip, cell_rect, buf); - - // Advance cursor inside viewport. - y_cursor += visible_height as u16; - remaining_height -= visible_height; - - // After the first (possibly partially skipped) cell, we no longer - // need to skip lines at the top. - lines_to_skip = 0; - - if remaining_height == 0 { - break; // viewport filled - } - } - - // Always render a scrollbar *track* so the reserved column is filled. - let overflow = num_lines.saturating_sub(viewport_height); - - let mut scroll_state = ScrollbarState::default() - // The Scrollbar widget expects the *content* height minus the - // viewport height. When there is no overflow we still provide 0 - // so that the widget renders only the track without a thumb. - .content_length(overflow) - .position(scroll_pos); - - { - // Choose a thumb color that stands out only when this pane has focus so that the - // user’s attention is naturally drawn to the active viewport. When unfocused we show - // a low-contrast thumb so the scrollbar fades into the background without becoming - // invisible. - let thumb_style = if self.has_input_focus { - Style::reset().fg(Color::LightYellow) - } else { - Style::reset().fg(Color::Gray) - }; - - // By default the Scrollbar widget inherits any style that was - // present in the underlying buffer cells. That means if a colored - // line happens to be underneath the scrollbar, the track (and - // potentially the thumb) adopt that color. Explicitly setting the - // track/thumb styles ensures we always draw the scrollbar with a - // consistent palette regardless of what content is behind it. - StatefulWidget::render( - Scrollbar::new(ScrollbarOrientation::VerticalRight) - .begin_symbol(Some("↑")) - .end_symbol(Some("↓")) - .begin_style(Style::reset().fg(Color::DarkGray)) - .end_style(Style::reset().fg(Color::DarkGray)) - .thumb_symbol("█") - .thumb_style(thumb_style) - .track_symbol(Some("│")) - .track_style(Style::reset().fg(Color::DarkGray)), - inner, - buf, - &mut scroll_state, - ); - } - - // Update auxiliary stats that the scroll handlers rely on. - self.num_rendered_lines.set(num_lines); - self.last_viewport_height.set(viewport_height); - } -} - -/// Common [`Wrap`] configuration used for both measurement and rendering so -/// they stay in sync. -#[inline] -pub(crate) const fn wrap_cfg() -> ratatui::widgets::Wrap { - ratatui::widgets::Wrap { trim: false } -} diff --git a/codex-rs/tui/src/custom_terminal.rs b/codex-rs/tui/src/custom_terminal.rs new file mode 100644 index 0000000000..1ada679fc1 --- /dev/null +++ b/codex-rs/tui/src/custom_terminal.rs @@ -0,0 +1,588 @@ +// This is derived from `ratatui::Terminal`, which is licensed under the following terms: +// +// The MIT License (MIT) +// Copyright (c) 2016-2022 Florian Dehau +// Copyright (c) 2023-2025 The Ratatui Developers +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. +use std::io; + +use ratatui::backend::Backend; +use ratatui::backend::ClearType; +use ratatui::buffer::Buffer; +use ratatui::layout::Position; +use ratatui::layout::Rect; +use ratatui::layout::Size; +use ratatui::widgets::StatefulWidget; +use ratatui::widgets::StatefulWidgetRef; +use ratatui::widgets::Widget; +use ratatui::widgets::WidgetRef; + +#[derive(Debug, Hash)] +pub struct Frame<'a> { + /// Where should the cursor be after drawing this frame? + /// + /// If `None`, the cursor is hidden and its position is controlled by the backend. If `Some((x, + /// y))`, the cursor is shown and placed at `(x, y)` after the call to `Terminal::draw()`. + pub(crate) cursor_position: Option, + + /// The area of the viewport + pub(crate) viewport_area: Rect, + + /// The buffer that is used to draw the current frame + pub(crate) buffer: &'a mut Buffer, + + /// The frame count indicating the sequence number of this frame. + pub(crate) count: usize, +} + +#[allow(dead_code)] +impl Frame<'_> { + /// The area of the current frame + /// + /// This is guaranteed not to change during rendering, so may be called multiple times. + /// + /// If your app listens for a resize event from the backend, it should ignore the values from + /// the event for any calculations that are used to render the current frame and use this value + /// instead as this is the area of the buffer that is used to render the current frame. + pub const fn area(&self) -> Rect { + self.viewport_area + } + + /// Render a [`Widget`] to the current buffer using [`Widget::render`]. + /// + /// Usually the area argument is the size of the current frame or a sub-area of the current + /// frame (which can be obtained using [`Layout`] to split the total area). + /// + /// # Example + /// + /// ```rust + /// # use ratatui::{backend::TestBackend, Terminal}; + /// # let backend = TestBackend::new(5, 5); + /// # let mut terminal = Terminal::new(backend).unwrap(); + /// # let mut frame = terminal.get_frame(); + /// use ratatui::{layout::Rect, widgets::Block}; + /// + /// let block = Block::new(); + /// let area = Rect::new(0, 0, 5, 5); + /// frame.render_widget(block, area); + /// ``` + /// + /// [`Layout`]: crate::layout::Layout + pub fn render_widget(&mut self, widget: W, area: Rect) { + widget.render(area, self.buffer); + } + + /// Render a [`WidgetRef`] to the current buffer using [`WidgetRef::render_ref`]. + /// + /// Usually the area argument is the size of the current frame or a sub-area of the current + /// frame (which can be obtained using [`Layout`] to split the total area). + /// + /// # Example + /// + /// ```rust + /// # #[cfg(feature = "unstable-widget-ref")] { + /// # use ratatui::{backend::TestBackend, Terminal}; + /// # let backend = TestBackend::new(5, 5); + /// # let mut terminal = Terminal::new(backend).unwrap(); + /// # let mut frame = terminal.get_frame(); + /// use ratatui::{layout::Rect, widgets::Block}; + /// + /// let block = Block::new(); + /// let area = Rect::new(0, 0, 5, 5); + /// frame.render_widget_ref(block, area); + /// # } + /// ``` + #[allow(clippy::needless_pass_by_value)] + pub fn render_widget_ref(&mut self, widget: W, area: Rect) { + widget.render_ref(area, self.buffer); + } + + /// Render a [`StatefulWidget`] to the current buffer using [`StatefulWidget::render`]. + /// + /// Usually the area argument is the size of the current frame or a sub-area of the current + /// frame (which can be obtained using [`Layout`] to split the total area). + /// + /// The last argument should be an instance of the [`StatefulWidget::State`] associated to the + /// given [`StatefulWidget`]. + /// + /// # Example + /// + /// ```rust + /// # use ratatui::{backend::TestBackend, Terminal}; + /// # let backend = TestBackend::new(5, 5); + /// # let mut terminal = Terminal::new(backend).unwrap(); + /// # let mut frame = terminal.get_frame(); + /// use ratatui::{ + /// layout::Rect, + /// widgets::{List, ListItem, ListState}, + /// }; + /// + /// let mut state = ListState::default().with_selected(Some(1)); + /// let list = List::new(vec![ListItem::new("Item 1"), ListItem::new("Item 2")]); + /// let area = Rect::new(0, 0, 5, 5); + /// frame.render_stateful_widget(list, area, &mut state); + /// ``` + /// + /// [`Layout`]: crate::layout::Layout + pub fn render_stateful_widget(&mut self, widget: W, area: Rect, state: &mut W::State) + where + W: StatefulWidget, + { + widget.render(area, self.buffer, state); + } + + /// Render a [`StatefulWidgetRef`] to the current buffer using + /// [`StatefulWidgetRef::render_ref`]. + /// + /// Usually the area argument is the size of the current frame or a sub-area of the current + /// frame (which can be obtained using [`Layout`] to split the total area). + /// + /// The last argument should be an instance of the [`StatefulWidgetRef::State`] associated to + /// the given [`StatefulWidgetRef`]. + /// + /// # Example + /// + /// ```rust + /// # #[cfg(feature = "unstable-widget-ref")] { + /// # use ratatui::{backend::TestBackend, Terminal}; + /// # let backend = TestBackend::new(5, 5); + /// # let mut terminal = Terminal::new(backend).unwrap(); + /// # let mut frame = terminal.get_frame(); + /// use ratatui::{ + /// layout::Rect, + /// widgets::{List, ListItem, ListState}, + /// }; + /// + /// let mut state = ListState::default().with_selected(Some(1)); + /// let list = List::new(vec![ListItem::new("Item 1"), ListItem::new("Item 2")]); + /// let area = Rect::new(0, 0, 5, 5); + /// frame.render_stateful_widget_ref(list, area, &mut state); + /// # } + /// ``` + #[allow(clippy::needless_pass_by_value)] + pub fn render_stateful_widget_ref(&mut self, widget: W, area: Rect, state: &mut W::State) + where + W: StatefulWidgetRef, + { + widget.render_ref(area, self.buffer, state); + } + + /// After drawing this frame, make the cursor visible and put it at the specified (x, y) + /// coordinates. If this method is not called, the cursor will be hidden. + /// + /// Note that this will interfere with calls to [`Terminal::hide_cursor`], + /// [`Terminal::show_cursor`], and [`Terminal::set_cursor_position`]. Pick one of the APIs and + /// stick with it. + /// + /// [`Terminal::hide_cursor`]: crate::Terminal::hide_cursor + /// [`Terminal::show_cursor`]: crate::Terminal::show_cursor + /// [`Terminal::set_cursor_position`]: crate::Terminal::set_cursor_position + pub fn set_cursor_position>(&mut self, position: P) { + self.cursor_position = Some(position.into()); + } + + /// Gets the buffer that this `Frame` draws into as a mutable reference. + pub fn buffer_mut(&mut self) -> &mut Buffer { + self.buffer + } + + /// Returns the current frame count. + /// + /// This method provides access to the frame count, which is a sequence number indicating + /// how many frames have been rendered up to (but not including) this one. It can be used + /// for purposes such as animation, performance tracking, or debugging. + /// + /// Each time a frame has been rendered, this count is incremented, + /// providing a consistent way to reference the order and number of frames processed by the + /// terminal. When count reaches its maximum value (`usize::MAX`), it wraps around to zero. + /// + /// This count is particularly useful when dealing with dynamic content or animations where the + /// state of the display changes over time. By tracking the frame count, developers can + /// synchronize updates or changes to the content with the rendering process. + /// + /// # Examples + /// + /// ```rust + /// # use ratatui::{backend::TestBackend, Terminal}; + /// # let backend = TestBackend::new(5, 5); + /// # let mut terminal = Terminal::new(backend).unwrap(); + /// # let mut frame = terminal.get_frame(); + /// let current_count = frame.count(); + /// println!("Current frame count: {}", current_count); + /// ``` + pub const fn count(&self) -> usize { + self.count + } +} + +#[derive(Debug, Default, Clone, Eq, PartialEq, Hash)] +pub struct Terminal +where + B: Backend, +{ + /// The backend used to interface with the terminal + backend: B, + /// Holds the results of the current and previous draw calls. The two are compared at the end + /// of each draw pass to output the necessary updates to the terminal + buffers: [Buffer; 2], + /// Index of the current buffer in the previous array + current: usize, + /// Whether the cursor is currently hidden + hidden_cursor: bool, + /// Area of the viewport + pub viewport_area: Rect, + /// Last known size of the terminal. Used to detect if the internal buffers have to be resized. + pub last_known_screen_size: Size, + /// Last known position of the cursor. Used to find the new area when the viewport is inlined + /// and the terminal resized. + pub last_known_cursor_pos: Position, + /// Number of frames rendered up until current time. + frame_count: usize, +} + +impl Drop for Terminal +where + B: Backend, +{ + #[allow(clippy::print_stderr)] + fn drop(&mut self) { + // Attempt to restore the cursor state + if self.hidden_cursor { + if let Err(err) = self.show_cursor() { + eprintln!("Failed to show the cursor: {err}"); + } + } + } +} + +impl Terminal +where + B: Backend, +{ + /// Creates a new [`Terminal`] with the given [`Backend`] and [`TerminalOptions`]. + /// + /// # Example + /// + /// ```rust + /// use std::io::stdout; + /// + /// use ratatui::{backend::CrosstermBackend, layout::Rect, Terminal, TerminalOptions, Viewport}; + /// + /// let backend = CrosstermBackend::new(stdout()); + /// let viewport = Viewport::Fixed(Rect::new(0, 0, 10, 10)); + /// let terminal = Terminal::with_options(backend, TerminalOptions { viewport })?; + /// # std::io::Result::Ok(()) + /// ``` + pub fn with_options(mut backend: B) -> io::Result { + let screen_size = backend.size()?; + let cursor_pos = backend.get_cursor_position()?; + Ok(Self { + backend, + buffers: [ + Buffer::empty(Rect::new(0, 0, 0, 0)), + Buffer::empty(Rect::new(0, 0, 0, 0)), + ], + current: 0, + hidden_cursor: false, + viewport_area: Rect::new(0, cursor_pos.y, 0, 0), + last_known_screen_size: screen_size, + last_known_cursor_pos: cursor_pos, + frame_count: 0, + }) + } + + /// Get a Frame object which provides a consistent view into the terminal state for rendering. + pub fn get_frame(&mut self) -> Frame { + let count = self.frame_count; + Frame { + cursor_position: None, + viewport_area: self.viewport_area, + buffer: self.current_buffer_mut(), + count, + } + } + + /// Gets the current buffer as a mutable reference. + pub fn current_buffer_mut(&mut self) -> &mut Buffer { + &mut self.buffers[self.current] + } + + /// Gets the backend + pub const fn backend(&self) -> &B { + &self.backend + } + + /// Gets the backend as a mutable reference + pub fn backend_mut(&mut self) -> &mut B { + &mut self.backend + } + + /// Obtains a difference between the previous and the current buffer and passes it to the + /// current backend for drawing. + pub fn flush(&mut self) -> io::Result<()> { + let previous_buffer = &self.buffers[1 - self.current]; + let current_buffer = &self.buffers[self.current]; + let updates = previous_buffer.diff(current_buffer); + if let Some((col, row, _)) = updates.last() { + self.last_known_cursor_pos = Position { x: *col, y: *row }; + } + self.backend.draw(updates.into_iter()) + } + + /// Updates the Terminal so that internal buffers match the requested area. + /// + /// Requested area will be saved to remain consistent when rendering. This leads to a full clear + /// of the screen. + pub fn resize(&mut self, screen_size: Size) -> io::Result<()> { + self.last_known_screen_size = screen_size; + Ok(()) + } + + /// Sets the viewport area. + pub fn set_viewport_area(&mut self, area: Rect) { + self.buffers[self.current].resize(area); + self.buffers[1 - self.current].resize(area); + self.viewport_area = area; + } + + /// Queries the backend for size and resizes if it doesn't match the previous size. + pub fn autoresize(&mut self) -> io::Result<()> { + let screen_size = self.size()?; + if screen_size != self.last_known_screen_size { + self.resize(screen_size)?; + } + Ok(()) + } + + /// Draws a single frame to the terminal. + /// + /// Returns a [`CompletedFrame`] if successful, otherwise a [`std::io::Error`]. + /// + /// If the render callback passed to this method can fail, use [`try_draw`] instead. + /// + /// Applications should call `draw` or [`try_draw`] in a loop to continuously render the + /// terminal. These methods are the main entry points for drawing to the terminal. + /// + /// [`try_draw`]: Terminal::try_draw + /// + /// This method will: + /// + /// - autoresize the terminal if necessary + /// - call the render callback, passing it a [`Frame`] reference to render to + /// - flush the current internal state by copying the current buffer to the backend + /// - move the cursor to the last known position if it was set during the rendering closure + /// + /// The render callback should fully render the entire frame when called, including areas that + /// are unchanged from the previous frame. This is because each frame is compared to the + /// previous frame to determine what has changed, and only the changes are written to the + /// terminal. If the render callback does not fully render the frame, the terminal will not be + /// in a consistent state. + /// + /// # Examples + /// + /// ``` + /// # let backend = ratatui::backend::TestBackend::new(10, 10); + /// # let mut terminal = ratatui::Terminal::new(backend)?; + /// use ratatui::{layout::Position, widgets::Paragraph}; + /// + /// // with a closure + /// terminal.draw(|frame| { + /// let area = frame.area(); + /// frame.render_widget(Paragraph::new("Hello World!"), area); + /// frame.set_cursor_position(Position { x: 0, y: 0 }); + /// })?; + /// + /// // or with a function + /// terminal.draw(render)?; + /// + /// fn render(frame: &mut ratatui::Frame) { + /// frame.render_widget(Paragraph::new("Hello World!"), frame.area()); + /// } + /// # std::io::Result::Ok(()) + /// ``` + pub fn draw(&mut self, render_callback: F) -> io::Result<()> + where + F: FnOnce(&mut Frame), + { + self.try_draw(|frame| { + render_callback(frame); + io::Result::Ok(()) + }) + } + + /// Tries to draw a single frame to the terminal. + /// + /// Returns [`Result::Ok`] containing a [`CompletedFrame`] if successful, otherwise + /// [`Result::Err`] containing the [`std::io::Error`] that caused the failure. + /// + /// This is the equivalent of [`Terminal::draw`] but the render callback is a function or + /// closure that returns a `Result` instead of nothing. + /// + /// Applications should call `try_draw` or [`draw`] in a loop to continuously render the + /// terminal. These methods are the main entry points for drawing to the terminal. + /// + /// [`draw`]: Terminal::draw + /// + /// This method will: + /// + /// - autoresize the terminal if necessary + /// - call the render callback, passing it a [`Frame`] reference to render to + /// - flush the current internal state by copying the current buffer to the backend + /// - move the cursor to the last known position if it was set during the rendering closure + /// - return a [`CompletedFrame`] with the current buffer and the area of the terminal + /// + /// The render callback passed to `try_draw` can return any [`Result`] with an error type that + /// can be converted into an [`std::io::Error`] using the [`Into`] trait. This makes it possible + /// to use the `?` operator to propagate errors that occur during rendering. If the render + /// callback returns an error, the error will be returned from `try_draw` as an + /// [`std::io::Error`] and the terminal will not be updated. + /// + /// The [`CompletedFrame`] returned by this method can be useful for debugging or testing + /// purposes, but it is often not used in regular applicationss. + /// + /// The render callback should fully render the entire frame when called, including areas that + /// are unchanged from the previous frame. This is because each frame is compared to the + /// previous frame to determine what has changed, and only the changes are written to the + /// terminal. If the render function does not fully render the frame, the terminal will not be + /// in a consistent state. + /// + /// # Examples + /// + /// ```should_panic + /// # use ratatui::layout::Position;; + /// # let backend = ratatui::backend::TestBackend::new(10, 10); + /// # let mut terminal = ratatui::Terminal::new(backend)?; + /// use std::io; + /// + /// use ratatui::widgets::Paragraph; + /// + /// // with a closure + /// terminal.try_draw(|frame| { + /// let value: u8 = "not a number".parse().map_err(io::Error::other)?; + /// let area = frame.area(); + /// frame.render_widget(Paragraph::new("Hello World!"), area); + /// frame.set_cursor_position(Position { x: 0, y: 0 }); + /// io::Result::Ok(()) + /// })?; + /// + /// // or with a function + /// terminal.try_draw(render)?; + /// + /// fn render(frame: &mut ratatui::Frame) -> io::Result<()> { + /// let value: u8 = "not a number".parse().map_err(io::Error::other)?; + /// frame.render_widget(Paragraph::new("Hello World!"), frame.area()); + /// Ok(()) + /// } + /// # io::Result::Ok(()) + /// ``` + pub fn try_draw(&mut self, render_callback: F) -> io::Result<()> + where + F: FnOnce(&mut Frame) -> Result<(), E>, + E: Into, + { + // Autoresize - otherwise we get glitches if shrinking or potential desync between widgets + // and the terminal (if growing), which may OOB. + self.autoresize()?; + + let mut frame = self.get_frame(); + + render_callback(&mut frame).map_err(Into::into)?; + + // We can't change the cursor position right away because we have to flush the frame to + // stdout first. But we also can't keep the frame around, since it holds a &mut to + // Buffer. Thus, we're taking the important data out of the Frame and dropping it. + let cursor_position = frame.cursor_position; + + // Draw to stdout + self.flush()?; + + match cursor_position { + None => self.hide_cursor()?, + Some(position) => { + self.show_cursor()?; + self.set_cursor_position(position)?; + } + } + + self.swap_buffers(); + + // Flush + self.backend.flush()?; + + // increment frame count before returning from draw + self.frame_count = self.frame_count.wrapping_add(1); + + Ok(()) + } + + /// Hides the cursor. + pub fn hide_cursor(&mut self) -> io::Result<()> { + self.backend.hide_cursor()?; + self.hidden_cursor = true; + Ok(()) + } + + /// Shows the cursor. + pub fn show_cursor(&mut self) -> io::Result<()> { + self.backend.show_cursor()?; + self.hidden_cursor = false; + Ok(()) + } + + /// Gets the current cursor position. + /// + /// This is the position of the cursor after the last draw call. + #[allow(dead_code)] + pub fn get_cursor_position(&mut self) -> io::Result { + self.backend.get_cursor_position() + } + + /// Sets the cursor position. + pub fn set_cursor_position>(&mut self, position: P) -> io::Result<()> { + let position = position.into(); + self.backend.set_cursor_position(position)?; + self.last_known_cursor_pos = position; + Ok(()) + } + + /// Clear the terminal and force a full redraw on the next draw call. + pub fn clear(&mut self) -> io::Result<()> { + if self.viewport_area.is_empty() { + return Ok(()); + } + self.backend + .set_cursor_position(self.viewport_area.as_position())?; + self.backend.clear_region(ClearType::AfterCursor)?; + // Reset the back buffer to make sure the next update will redraw everything. + self.buffers[1 - self.current].reset(); + Ok(()) + } + + /// Clears the inactive buffer and swaps it with the current buffer + pub fn swap_buffers(&mut self) { + self.buffers[1 - self.current].reset(); + self.current = 1 - self.current; + } + + /// Queries the real size of the backend. + pub fn size(&self) -> io::Result { + self.backend.size() + } +} diff --git a/codex-rs/tui/src/history_cell.rs b/codex-rs/tui/src/history_cell.rs index 0bfbc414b9..956a0cc7ef 100644 --- a/codex-rs/tui/src/history_cell.rs +++ b/codex-rs/tui/src/history_cell.rs @@ -1,5 +1,4 @@ -use crate::cell_widget::CellWidget; -use crate::exec_command::escape_command; +use crate::exec_command::strip_bash_lc_and_escape; use crate::markdown::append_markdown; use crate::text_block::TextBlock; use crate::text_formatting::format_and_truncate_tool_result; @@ -11,26 +10,22 @@ use codex_core::WireApi; use codex_core::config::Config; use codex_core::model_supports_reasoning_summaries; use codex_core::protocol::FileChange; +use codex_core::protocol::McpInvocation; use codex_core::protocol::SessionConfiguredEvent; use image::DynamicImage; -use image::GenericImageView; use image::ImageReader; -use lazy_static::lazy_static; use mcp_types::EmbeddedResourceResource; +use mcp_types::ResourceLink; use ratatui::prelude::*; use ratatui::style::Color; use ratatui::style::Modifier; use ratatui::style::Style; use ratatui::text::Line as RtLine; use ratatui::text::Span as RtSpan; -use ratatui_image::Image as TuiImage; -use ratatui_image::Resize as ImgResize; -use ratatui_image::picker::ProtocolType; use std::collections::HashMap; use std::io::Cursor; use std::path::PathBuf; use std::time::Duration; -use std::time::Instant; use tracing::error; pub(crate) struct CommandOutput { @@ -45,6 +40,21 @@ pub(crate) enum PatchEventType { ApplyBegin { auto_approved: bool }, } +fn span_to_static(span: &Span) -> Span<'static> { + Span { + style: span.style, + content: std::borrow::Cow::Owned(span.content.clone().into_owned()), + } +} + +fn line_to_static(line: &Line) -> Line<'static> { + Line { + style: line.style, + alignment: line.alignment, + spans: line.spans.iter().map(span_to_static).collect(), + } +} + /// Represents an event to display in the conversation history. Returns its /// `Vec>` representation to make it easier to display in a /// scrollable list. @@ -62,25 +72,13 @@ pub(crate) enum HistoryCell { AgentReasoning { view: TextBlock }, /// An exec tool call that has not finished yet. - ActiveExecCommand { - call_id: String, - /// The shell command, escaped and formatted. - command: String, - start: Instant, - view: TextBlock, - }, + ActiveExecCommand { view: TextBlock }, /// Completed exec tool call. CompletedExecCommand { view: TextBlock }, /// An MCP tool call that has not finished yet. - ActiveMcpToolCall { - call_id: String, - /// Formatted line that shows the command name and arguments - invocation: Line<'static>, - start: Instant, - view: TextBlock, - }, + ActiveMcpToolCall { view: TextBlock }, /// Completed MCP tool call where we show the result serialized as JSON. CompletedMcpToolCall { view: TextBlock }, @@ -93,13 +91,7 @@ pub(crate) enum HistoryCell { // resized version avoids doing the potentially expensive rescale twice // because the scroll-view first calls `height()` for layouting and then // `render_window()` for painting. - CompletedMcpToolCallWithImageOutput { - image: DynamicImage, - /// Cached data derived from the current terminal width. The cache is - /// invalidated whenever the width changes (e.g. when the user - /// resizes the window). - render_cache: std::cell::RefCell>, - }, + CompletedMcpToolCallWithImageOutput { _image: DynamicImage }, /// Background event. BackgroundEvent { view: TextBlock }, @@ -122,6 +114,32 @@ pub(crate) enum HistoryCell { const TOOL_CALL_MAX_LINES: usize = 5; impl HistoryCell { + /// Return a cloned, plain representation of the cell's lines suitable for + /// one‑shot insertion into the terminal scrollback. Image cells are + /// represented with a simple placeholder for now. + pub(crate) fn plain_lines(&self) -> Vec> { + match self { + HistoryCell::WelcomeMessage { view } + | HistoryCell::UserPrompt { view } + | HistoryCell::AgentMessage { view } + | HistoryCell::AgentReasoning { view } + | HistoryCell::BackgroundEvent { view } + | HistoryCell::GitDiffOutput { view } + | HistoryCell::ErrorEvent { view } + | HistoryCell::SessionInfo { view } + | HistoryCell::CompletedExecCommand { view } + | HistoryCell::CompletedMcpToolCall { view } + | HistoryCell::PendingPatch { view } + | HistoryCell::ActiveExecCommand { view, .. } + | HistoryCell::ActiveMcpToolCall { view, .. } => { + view.lines.iter().map(line_to_static).collect() + } + HistoryCell::CompletedMcpToolCallWithImageOutput { .. } => vec![ + Line::from("tool result (image output omitted)"), + Line::from(""), + ], + } + } pub(crate) fn new_session_info( config: &Config, event: SessionConfiguredEvent, @@ -155,7 +173,7 @@ impl HistoryCell { ("workdir", config.cwd.display().to_string()), ("model", config.model.clone()), ("provider", config.model_provider_id.clone()), - ("approval", format!("{:?}", config.approval_policy)), + ("approval", config.approval_policy.to_string()), ("sandbox", summarize_sandbox_policy(&config.sandbox_policy)), ]; if config.model_provider.wire_api == WireApi::Responses @@ -227,9 +245,8 @@ impl HistoryCell { } } - pub(crate) fn new_active_exec_command(call_id: String, command: Vec) -> Self { - let command_escaped = escape_command(&command); - let start = Instant::now(); + pub(crate) fn new_active_exec_command(command: Vec) -> Self { + let command_escaped = strip_bash_lc_and_escape(&command); let lines: Vec> = vec![ Line::from(vec!["command".magenta(), " running...".dim()]), @@ -238,14 +255,11 @@ impl HistoryCell { ]; HistoryCell::ActiveExecCommand { - call_id, - command: command_escaped, - start, view: TextBlock::new(lines), } } - pub(crate) fn new_completed_exec_command(command: String, output: CommandOutput) -> Self { + pub(crate) fn new_completed_exec_command(command: Vec, output: CommandOutput) -> Self { let CommandOutput { exit_code, stdout, @@ -269,7 +283,8 @@ impl HistoryCell { let src = if exit_code == 0 { stdout } else { stderr }; - lines.push(Line::from(format!("$ {command}"))); + let cmdline = strip_bash_lc_and_escape(&command); + lines.push(Line::from(format!("$ {cmdline}"))); let mut lines_iter = src.lines(); for raw in lines_iter.by_ref().take(TOOL_CALL_MAX_LINES) { lines.push(ansi_escape_line(raw).dim()); @@ -285,41 +300,15 @@ impl HistoryCell { } } - pub(crate) fn new_active_mcp_tool_call( - call_id: String, - server: String, - tool: String, - arguments: Option, - ) -> Self { - // Format the arguments as compact JSON so they roughly fit on one - // line. If there are no arguments we keep it empty so the invocation - // mirrors a function-style call. - let args_str = arguments - .as_ref() - .map(|v| { - // Use compact form to keep things short but readable. - serde_json::to_string(v).unwrap_or_else(|_| v.to_string()) - }) - .unwrap_or_default(); - - let invocation_spans = vec![ - Span::styled(server, Style::default().fg(Color::Blue)), - Span::raw("."), - Span::styled(tool, Style::default().fg(Color::Blue)), - Span::raw("("), - Span::styled(args_str, Style::default().fg(Color::Gray)), - Span::raw(")"), - ]; - let invocation = Line::from(invocation_spans); - - let start = Instant::now(); + pub(crate) fn new_active_mcp_tool_call(invocation: McpInvocation) -> Self { let title_line = Line::from(vec!["tool".magenta(), " running...".dim()]); - let lines: Vec> = vec![title_line, invocation.clone(), Line::from("")]; + let lines: Vec = vec![ + title_line, + format_mcp_invocation(invocation.clone()), + Line::from(""), + ]; HistoryCell::ActiveMcpToolCall { - call_id, - invocation, - start, view: TextBlock::new(lines), } } @@ -331,8 +320,7 @@ impl HistoryCell { ) -> Option { match result { Ok(mcp_types::CallToolResult { content, .. }) => { - if let Some(mcp_types::CallToolResultContent::ImageContent(image)) = content.first() - { + if let Some(mcp_types::ContentBlock::ImageContent(image)) = content.first() { let raw_data = match base64::engine::general_purpose::STANDARD.decode(&image.data) { Ok(data) => data, @@ -358,10 +346,7 @@ impl HistoryCell { } }; - Some(HistoryCell::CompletedMcpToolCallWithImageOutput { - image, - render_cache: std::cell::RefCell::new(None), - }) + Some(HistoryCell::CompletedMcpToolCallWithImageOutput { _image: image }) } else { None } @@ -372,8 +357,8 @@ impl HistoryCell { pub(crate) fn new_completed_mcp_tool_call( num_cols: u16, - invocation: Line<'static>, - start: Instant, + invocation: McpInvocation, + duration: Duration, success: bool, result: Result, ) -> Self { @@ -381,7 +366,7 @@ impl HistoryCell { return cell; } - let duration = format_duration(start.elapsed()); + let duration = format_duration(duration); let status_str = if success { "success" } else { "failed" }; let title_line = Line::from(vec![ "tool".magenta(), @@ -396,7 +381,7 @@ impl HistoryCell { let mut lines: Vec> = Vec::new(); lines.push(title_line); - lines.push(invocation); + lines.push(format_mcp_invocation(invocation)); match result { Ok(mcp_types::CallToolResult { content, .. }) => { @@ -405,21 +390,21 @@ impl HistoryCell { for tool_call_result in content { let line_text = match tool_call_result { - mcp_types::CallToolResultContent::TextContent(text) => { + mcp_types::ContentBlock::TextContent(text) => { format_and_truncate_tool_result( &text.text, TOOL_CALL_MAX_LINES, num_cols as usize, ) } - mcp_types::CallToolResultContent::ImageContent(_) => { + mcp_types::ContentBlock::ImageContent(_) => { // TODO show images even if they're not the first result, will require a refactor of `CompletedMcpToolCall` "".to_string() } - mcp_types::CallToolResultContent::AudioContent(_) => { + mcp_types::ContentBlock::AudioContent(_) => { "