From 5a9f82a572ff8d00a79d8a13be5c96b3922e7627 Mon Sep 17 00:00:00 2001
From: jimmyfraiture
Date: Wed, 17 Sep 2025 15:01:23 +0100
Subject: [PATCH] V1

---
 codex-rs/Cargo.lock                           |   2 +
 codex-rs/core/Cargo.toml                      |   2 +
 codex-rs/core/src/bin/prompt_harness.rs       | 164 ++++++++
 codex-rs/core/src/lib.rs                      |   3 +
 codex-rs/core/src/prompt_harness/driver.py    |  75 ++++
 codex-rs/core/src/prompt_harness/mod.rs       | 369 ++++++++++++++++++
 .../src/prompt_harness/prompt_override.rs     | 136 +++++++
 codex-rs/core/tests/prompt_harness_bin.rs     |  52 +++
 codex-rs/new_prompt.md                        |   2 +
 9 files changed, 805 insertions(+)
 create mode 100644 codex-rs/core/src/bin/prompt_harness.rs
 create mode 100644 codex-rs/core/src/prompt_harness/driver.py
 create mode 100644 codex-rs/core/src/prompt_harness/mod.rs
 create mode 100644 codex-rs/core/src/prompt_harness/prompt_override.rs
 create mode 100644 codex-rs/core/tests/prompt_harness_bin.rs
 create mode 100644 codex-rs/new_prompt.md

diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock
index 81fc8683cb..3e64f580f3 100644
--- a/codex-rs/Cargo.lock
+++ b/codex-rs/Cargo.lock
@@ -669,6 +669,7 @@ dependencies = [
  "base64",
  "bytes",
  "chrono",
+ "clap",
  "codex-apply-patch",
  "codex-file-search",
  "codex-mcp-client",
@@ -706,6 +707,7 @@ dependencies = [
  "toml",
  "toml_edit",
  "tracing",
+ "tracing-subscriber",
  "tree-sitter",
  "tree-sitter-bash",
  "uuid",
diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml
index b4ed4a937a..a13429c389 100644
--- a/codex-rs/core/Cargo.toml
+++ b/codex-rs/core/Cargo.toml
@@ -14,6 +14,7 @@ workspace = true
 [dependencies]
 anyhow = "1"
 askama = "0.12"
+clap = { version = "4", features = ["derive"] }
 async-channel = "2.3.1"
 base64 = "0.22"
 bytes = "1.10.1"
@@ -53,6 +54,7 @@ tokio-util = "0.7.16"
 toml = "0.9.5"
 toml_edit = "0.23.4"
 tracing = { version = "0.1.41", features = ["log"] }
+tracing-subscriber = { version = "0.3", features = ["fmt"] }
 tree-sitter = "0.25.9"
 tree-sitter-bash = "0.25.0"
 uuid = { version = "1", features = ["serde", "v4"] }
diff --git a/codex-rs/core/src/bin/prompt_harness.rs b/codex-rs/core/src/bin/prompt_harness.rs
new file mode 100644
index 0000000000..797a989fca
--- /dev/null
+++ b/codex-rs/core/src/bin/prompt_harness.rs
@@ -0,0 +1,164 @@
+//! Command-line entry point for the Codex prompt harness.
+//!
+//! Parses the CLI, converts `-c key=value` overrides into TOML values, and
+//! hands control to [`run_prompt_harness`].
+
+use std::path::PathBuf;
+
+use anyhow::Context;
+use anyhow::Result;
+use anyhow::anyhow;
+use clap::ArgAction;
+use clap::Parser;
+use codex_core::prompt_harness::PromptHarnessCommand;
+use codex_core::prompt_harness::PromptHarnessOptions;
+use codex_core::prompt_harness::run_prompt_harness;
+
+#[derive(Debug, Parser)]
+#[command(
+    author,
+    version,
+    about = "Run Codex with a system prompt override and attach a JSON protocol script."
+)]
+struct PromptHarnessCli {
+    /// Override configuration values (`toml`-parsed). Repeatable.
+    #[arg(
+        short = 'c',
+        long = "config",
+        value_name = "key=value",
+        action = ArgAction::Append
+    )]
+    raw_overrides: Vec<String>,
+
+    /// Path to the file containing replacement system instructions for Codex.
+    #[arg(long = "system-prompt-file", value_name = "FILE")]
+    system_prompt_file: PathBuf,
+
+    /// Command to execute. Receives Codex protocol events on stdin and must
+    /// emit submissions as JSON on stdout.
+    // NOTE: the default driver path is relative, so it only resolves when the
+    // binary is started from the `codex-rs` directory.
+    #[arg(
+        value_name = "COMMAND",
+        trailing_var_arg = true,
+        default_values = ["python3", "core/src/prompt_harness/driver.py"]
+    )]
+    command: Vec<String>,
+}
+
+#[tokio::main(flavor = "multi_thread")]
+async fn main() -> Result<()> {
+    let cli = PromptHarnessCli::parse();
+    // Log to stderr. `try_init` only fails when a global subscriber is already
+    // installed, which is safe to ignore here.
+    let _ = tracing_subscriber::fmt()
+        .with_writer(std::io::stderr)
+        .try_init();
+
+    let overrides = parse_overrides(&cli.raw_overrides)?;
+    let command = build_command(cli.command).context("command was missing program name")?;
+
+    let options = PromptHarnessOptions {
+        cli_overrides: overrides,
+        prompt_file: cli.system_prompt_file,
+        command,
+    };
+
+    run_prompt_harness(options).await
+}
+
+/// Splits the positional arguments into the program (first element) and its
+/// arguments. Returns `None` when `parts` is empty.
+fn build_command(mut parts: Vec<String>) -> Option<PromptHarnessCommand> {
+    if parts.is_empty() {
+        return None;
+    }
+    let program = PathBuf::from(parts.remove(0));
+    Some(PromptHarnessCommand {
+        program,
+        args: parts,
+    })
+}
+
+/// Parses every `key=value` override, failing on the first malformed entry.
+fn parse_overrides(raw: &[String]) -> Result<Vec<(String, toml::Value)>> {
+    raw.iter()
+        .map(|entry| parse_single_override(entry))
+        .collect()
+}
+
+/// Parses one `key=value` override.
+///
+/// The value is parsed as a TOML value when possible; otherwise it is kept as
+/// a string with surrounding quotes removed. Errors when the key is empty or
+/// the `=` delimiter is missing.
+fn parse_single_override(raw: &str) -> Result<(String, toml::Value)> {
+    let mut split = raw.splitn(2, '=');
+    let key = split
+        .next()
+        .map(str::trim)
+        .filter(|key| !key.is_empty())
+        .ok_or_else(|| anyhow!("override missing key: {raw}"))?;
+    let value = split
+        .next()
+        .map(str::trim)
+        .ok_or_else(|| anyhow!("override missing '=' delimiter: {raw}"))?;
+
+    let parsed =
+        parse_toml_value(value).unwrap_or_else(|| toml::Value::String(trim_override_string(value)));
+
+    Ok((key.to_string(), parsed))
+}
+
+/// Trims whitespace and any leading/trailing single or double quotes.
+fn trim_override_string(raw: &str) -> String {
+    let trimmed = raw.trim();
+    trimmed.trim_matches(|c| c == '\'' || c == '"').to_string()
+}
+
+/// Parses `raw` as a standalone TOML value by wrapping it in a one-key table.
+/// Returns `None` when the wrapped text is not valid TOML.
+fn parse_toml_value(raw: &str) -> Option<toml::Value> {
+    let wrapped = format!("_value_ = {raw}");
+    let mut table: toml::Table = toml::from_str(&wrapped).ok()?;
+    table.remove("_value_")
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn parses_string_literal() {
+        let (k, v) = parse_single_override("model='o4'").expect("override");
+        assert_eq!(k, "model");
+        assert_eq!(v, toml::Value::String("o4".to_string()));
+    }
+
+    #[test]
+    fn parses_toml_array_literal() {
+        let (k, v) = parse_single_override("numbers=[1,2]").expect("override");
+        assert_eq!(k, "numbers");
+        assert_eq!(
+            v,
+            toml::Value::Array(vec![toml::Value::Integer(1), toml::Value::Integer(2)])
+        );
+    }
+
+    #[test]
+    fn rejects_missing_key() {
+        assert!(parse_single_override("=oops").is_err());
+    }
+
+    #[test]
+    fn rejects_missing_value() {
+        assert!(parse_single_override("model").is_err());
+    }
+
+    #[test]
+    fn build_command_splits_program_and_args() {
+        let cmd = build_command(vec!["python".to_string(), "-V".to_string()]).expect("command");
+        assert_eq!(cmd.program, PathBuf::from("python"));
+        assert_eq!(cmd.args, vec!["-V".to_string()]);
+    }
+}
diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs
index e024effbe2..a1b6651bff 100644
--- a/codex-rs/core/src/lib.rs
+++ b/codex-rs/core/src/lib.rs
@@ -59,6 +59,7 @@ mod openai_model_info;
 mod openai_tools;
 pub mod plan_tool;
 pub mod project_doc;
+pub mod prompt_harness;
 mod rollout;
 pub(crate) mod safety;
 pub mod seatbelt;
@@ -98,3 +99,5 @@ pub use codex_protocol::models::LocalShellExecAction;
 pub use codex_protocol::models::LocalShellStatus;
 pub use codex_protocol::models::ReasoningItemContent;
 pub use codex_protocol::models::ResponseItem;
+pub use prompt_harness::SAMPLE_DRIVER;
+pub use prompt_harness::load_system_prompt_override;
diff --git a/codex-rs/core/src/prompt_harness/driver.py b/codex-rs/core/src/prompt_harness/driver.py
new file mode 100644
index 0000000000..97283e47e4
--- /dev/null
+++ b/codex-rs/core/src/prompt_harness/driver.py
@@ -0,0 +1,75 @@
+#!/usr/bin/env python3
+"""Sample prompt-harness driver.
+
+Reads Codex protocol events as JSON lines on stdin and writes user-input
+submissions as JSON lines on stdout. Diagnostics go to stderr so that stdout
+carries only protocol traffic.
+"""
+import json
+import sys
+from typing import Iterator
+
+
+def render_agent_message(message: object) -> str:
+    """Return a printable form of an agent message payload."""
+    if isinstance(message, str):
+        return message
+    if isinstance(message, dict) and "content" in message:
+        return json.dumps(message)
+    return str(message)
+
+
+def send_question(questions: Iterator[str], turn: int) -> bool:
+    """Submit the next question, returning False once none remain."""
+    try:
+        text = next(questions)
+    except StopIteration:
+        return False
+
+    payload = {
+        "id": f"turn-{turn}",
+        "op": {
+            "type": "user_input",
+            "items": [
+                {
+                    "type": "text",
+                    "text": text,
+                }
+            ],
+        },
+    }
+    print(json.dumps(payload), flush=True)
+    print(f"[user] {text}", file=sys.stderr)
+    return True
+
+
+def main() -> None:
+    """Ask each canned question in turn, one per completed task."""
+    questions = iter(["What is your name?", "1+1=?"])
+    turn = 1
+
+    for raw in sys.stdin:
+        # Skip blank lines instead of crashing in json.loads.
+        if not raw.strip():
+            continue
+        event = json.loads(raw)
+        kind = event.get("msg", {}).get("type")
+
+        if kind not in ("agent_message_delta", "agent_reasoning_delta"):
+            print(f"[harness] event {kind}", file=sys.stderr)
+
+        if kind == "session_configured":
+            send_question(questions, turn)
+        elif kind == "user_message":
+            print(f"[user_message raw] {json.dumps(event)}", file=sys.stderr)
+        elif kind == "agent_message":
+            message = event.get("msg", {}).get("message")
+            print(f"[agent] {render_agent_message(message)}", file=sys.stderr)
+        elif kind == "task_complete":
+            turn += 1
+            if not send_question(questions, turn):
+                break
+
+
+if __name__ == "__main__":
+    main()
diff --git a/codex-rs/core/src/prompt_harness/mod.rs b/codex-rs/core/src/prompt_harness/mod.rs
new file mode 100644
index 0000000000..56eccf2039
--- /dev/null
+++ b/codex-rs/core/src/prompt_harness/mod.rs
@@ -0,0 +1,369 @@
+//! Prompt harness: runs a Codex conversation with a replacement system prompt
+//! and bridges it to a child process that speaks the JSON protocol over its
+//! stdin and stdout.
+
+mod prompt_override;
+
+use std::path::PathBuf;
+use std::sync::Arc;
+
+use anyhow::Context;
+use anyhow::Result;
+use tokio::io::AsyncBufReadExt;
+use tokio::io::AsyncWriteExt;
+use tokio::io::BufReader;
+use tokio::io::BufWriter;
+use tokio::io::{self};
+use tokio::process::ChildStdin;
+use tokio::process::ChildStdout;
+use tokio::sync::watch;
+
+use crate::auth::AuthManager;
+use crate::codex::INITIAL_SUBMIT_ID;
+use crate::codex_conversation::CodexConversation;
+use crate::config::Config;
+use crate::config::ConfigOverrides;
+use crate::conversation_manager::ConversationManager;
+use codex_protocol::protocol::Event;
+use codex_protocol::protocol::EventMsg;
+use codex_protocol::protocol::SessionConfiguredEvent;
+use codex_protocol::protocol::Submission;
+use tracing::debug;
+use tracing::error;
+use tracing::info;
+
+pub use prompt_override::load_system_prompt_override;
+
+/// Sample Python harness script that can be used to drive the prompt harness binary.
+pub const SAMPLE_DRIVER: &str = include_str!("driver.py");
+
+/// Program and arguments of the child process that drives the conversation.
+#[derive(Debug, Clone)]
+pub struct PromptHarnessCommand {
+    /// Executable to spawn.
+    pub program: PathBuf,
+    /// Arguments passed to the program.
+    pub args: Vec<String>,
+}
+
+/// Inputs for [`run_prompt_harness`].
+#[derive(Debug, Clone)]
+pub struct PromptHarnessOptions {
+    /// Configuration overrides applied when the config is loaded.
+    pub cli_overrides: Vec<(String, toml::Value)>,
+    /// File holding the replacement system prompt.
+    pub prompt_file: PathBuf,
+    /// Child process to attach to the conversation.
+    pub command: PromptHarnessCommand,
+}
+
+/// Load configuration, override system prompt, and execute the harness.
+///
+/// Returns once the child process has exited and both pump tasks have
+/// finished; fails if the prompt file, configuration, conversation start-up,
+/// child process, or either pump reports an error.
+pub async fn run_prompt_harness(opts: PromptHarnessOptions) -> Result<()> {
+    let PromptHarnessOptions {
+        cli_overrides,
+        prompt_file,
+        command,
+    } = opts;
+
+    let base_instructions = load_system_prompt_override(&prompt_file).with_context(|| {
+        format!(
+            "failed to load system prompt override from {}",
+            prompt_file.display()
+        )
+    })?;
+
+    let config = load_config(cli_overrides, base_instructions.clone())?;
+    let auth_manager = AuthManager::shared(config.codex_home.clone());
+    let conversation_manager = ConversationManager::new(auth_manager);
+
+    let session = conversation_manager
+        .new_conversation(config)
+        .await
+        .context("failed to start Codex conversation")?;
+
+    info!(
+        ?command.program,
+        args = ?command.args,
+        "starting prompt harness child process"
+    );
+
+    run_conversation(command, session.conversation, session.session_configured).await
+}
+
+/// Loads the configuration with `base_instructions` as the override and copies
+/// the resulting instructions into the model family and user instructions.
+fn load_config(
+    cli_overrides: Vec<(String, toml::Value)>,
+    base_instructions: String,
+) -> Result<Config> {
+    let overrides = ConfigOverrides {
+        base_instructions: Some(base_instructions.clone()),
+        ..ConfigOverrides::default()
+    };
+    let mut config = Config::load_with_cli_overrides(cli_overrides, overrides)?;
+    let effective_instructions = config
+        .base_instructions
+        .clone()
+        .unwrap_or(base_instructions);
+    config.model_family.base_instructions = effective_instructions.clone();
+    // Force the override to be the only set of instructions that the model sees.
+    config.user_instructions = Some(effective_instructions);
+
+    Ok(config)
+}
+
+/// Spawns the child process, pumps events and submissions between it and the
+/// conversation, and waits for the child to exit.
+async fn run_conversation(
+    command: PromptHarnessCommand,
+    conversation: Arc<CodexConversation>,
+    session_configured: SessionConfiguredEvent,
+) -> Result<()> {
+    use std::process::Stdio;
+    use tokio::process::Command;
+
+    let mut child = Command::new(&command.program)
+        .args(&command.args)
+        .stdin(Stdio::piped())
+        .stdout(Stdio::piped())
+        .stderr(Stdio::inherit())
+        .spawn()
+        .with_context(|| {
+            format!(
+                "failed to spawn child process `{}`",
+                command.program.display()
+            )
+        })?;
+
+    let child_stdin = child
+        .stdin
+        .take()
+        .context("child process lacks stdin pipe")?;
+    let child_stdout = child
+        .stdout
+        .take()
+        .context("child process lacks stdout pipe")?;
+
+    let (child_exit_tx, child_exit_rx) = watch::channel(false);
+
+    let events_task = tokio::spawn(pump_events(
+        conversation.clone(),
+        session_configured,
+        child_stdin,
+        child_exit_rx.clone(),
+    ));
+
+    let submissions_task = tokio::spawn(pump_submissions(
+        conversation,
+        child_stdout,
+        child_exit_rx.clone(),
+    ));
+
+    let status = child
+        .wait()
+        .await
+        .with_context(|| format!("failed to wait for child `{}`", command.program.display()))?;
+    let _ = child_exit_tx.send(true);
+
+    info!(?status, "prompt harness child exited");
+
+    // Join both pumps before propagating any failure so neither task is left
+    // detached when the other one errors.
+    let events_outcome = events_task.await;
+    let submissions_outcome = submissions_task.await;
+    events_outcome.context("event pump task panicked")??;
+    submissions_outcome.context("submission pump task panicked")??;
+
+    Ok(())
+}
+
+/// Forwards the initial `SessionConfigured` event and every subsequent
+/// conversation event to the child's stdin, one JSON document per line.
+///
+/// Stops when the child exits or closes its stdin.
+async fn pump_events(
+    conversation: Arc<CodexConversation>,
+    session_configured: SessionConfiguredEvent,
+    child_stdin: ChildStdin,
+    mut child_exit_rx: watch::Receiver<bool>,
+) -> Result<()> {
+    let mut writer = BufWriter::new(child_stdin);
+
+    let initial_event = Event {
+        id: INITIAL_SUBMIT_ID.to_string(),
+        msg: EventMsg::SessionConfigured(session_configured),
+    };
+
+    if !write_event(&mut writer, &initial_event).await? {
+        return Ok(());
+    }
+
+    loop {
+        tokio::select! {
+            changed = child_exit_rx.changed() => {
+                if changed.is_err() || *child_exit_rx.borrow() {
+                    break;
+                }
+            }
+            event = conversation.next_event() => {
+                let event = event?;
+                if !write_event(&mut writer, &event).await? {
+                    break;
+                }
+            }
+        }
+    }
+
+    Ok(())
+}
+
+/// Reads JSON submissions from the child's stdout, one per line, and submits
+/// them to the conversation.
+///
+/// Blank lines are skipped. Lines that fail to parse are logged and ignored.
+async fn pump_submissions(
+    conversation: Arc<CodexConversation>,
+    child_stdout: ChildStdout,
+    mut child_exit_rx: watch::Receiver<bool>,
+) -> Result<()> {
+    let mut reader = BufReader::new(child_stdout).lines();
+
+    loop {
+        tokio::select! {
+            // Poll stdout first so lines the child wrote just before exiting
+            // are still submitted; with the default random branch order the
+            // exit notification could win and drop them.
+            biased;
+
+            line = reader.next_line() => {
+                let Some(line) = line? else {
+                    break;
+                };
+                let trimmed = line.trim();
+                if trimmed.is_empty() {
+                    continue;
+                }
+                match serde_json::from_str::<Submission>(trimmed) {
+                    Ok(submission) => {
+                        conversation.submit_with_id(submission).await?;
+                    }
+                    Err(err) => {
+                        if trimmed.starts_with('{') || trimmed.starts_with('[') {
+                            error!("invalid submission from child: {err}");
+                        } else {
+                            debug!("ignoring non-JSON child output line: {trimmed}");
+                        }
+                    }
+                }
+            }
+            changed = child_exit_rx.changed() => {
+                if changed.is_err() || *child_exit_rx.borrow() {
+                    break;
+                }
+            }
+        }
+    }
+
+    Ok(())
+}
+
+/// Serializes `event` as one JSON line on the child's stdin and flushes it.
+///
+/// Returns `Ok(false)` when the child has closed its stdin (see
+/// [`handle_broken_pipe`]) and `Ok(true)` otherwise.
+async fn write_event(writer: &mut BufWriter<ChildStdin>, event: &Event) -> Result<bool> {
+    let json = serde_json::to_string(event).context("failed to serialize event")?;
+    let outcome: io::Result<()> = async {
+        writer.write_all(json.as_bytes()).await?;
+        writer.write_all(b"\n").await?;
+        writer.flush().await
+    }
+    .await;
+    match outcome {
+        Ok(()) => Ok(true),
+        Err(err) => handle_broken_pipe(err),
+    }
+}
+
+/// Maps "peer went away" I/O errors to `Ok(false)` and propagates the rest.
+fn handle_broken_pipe(err: io::Error) -> Result<bool> {
+    match err.kind() {
+        io::ErrorKind::BrokenPipe
+        | io::ErrorKind::ConnectionReset
+        | io::ErrorKind::NotConnected => {
+            info!("child process closed stdin");
+            Ok(false)
+        }
+        _ => Err(err.into()),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::sync::Mutex;
+    use tempfile::NamedTempFile;
+    use tempfile::TempDir;
+
+    static ENV_GUARD: Mutex<()> = Mutex::new(());
+
+    fn set_env_var(key: &str, value: impl AsRef<std::ffi::OsStr>) {
+        // SAFETY: tests in this module serialize environment mutation through
+        // `ENV_GUARD`. NOTE(review): other test modules are not covered by it.
+        unsafe {
+            std::env::set_var(key, value);
+        }
+    }
+
+    fn remove_env_var(key: &str) {
+        // SAFETY: tests in this module serialize environment mutation through
+        // `ENV_GUARD`. NOTE(review): other test modules are not covered by it.
+        unsafe {
+            std::env::remove_var(key);
+        }
+    }
+
+    /// Restores (or removes) an environment variable when dropped.
+    struct EnvVarReset<'a> {
+        key: &'a str,
+        prev: Option<String>,
+    }
+
+    impl<'a> EnvVarReset<'a> {
+        fn new(key: &'a str) -> Self {
+            let prev = std::env::var(key).ok();
+            Self { key, prev }
+        }
+    }
+
+    impl Drop for EnvVarReset<'_> {
+        fn drop(&mut self) {
+            if let Some(prev) = &self.prev {
+                set_env_var(self.key, prev);
+            } else {
+                remove_env_var(self.key);
+            }
+        }
+    }
+
+    #[test]
+    fn load_config_applies_base_instructions() {
+        let _guard = ENV_GUARD.lock().expect("lock env guard");
+        let codex_home = TempDir::new().expect("create codex home");
+        let _reset = EnvVarReset::new("CODEX_HOME");
+        set_env_var("CODEX_HOME", codex_home.path());
+
+        let file = NamedTempFile::new().expect("create temp");
+        std::fs::write(file.path(), "prompt override").expect("write prompt");
+
+        let base = load_system_prompt_override(file.path()).expect("load prompt");
+        let config = load_config(Vec::new(), base.clone()).expect("load config");
+        assert_eq!(config.base_instructions.as_deref(), Some(base.as_str()));
+        assert_eq!(config.user_instructions.as_deref(), Some(base.as_str()));
+        assert_eq!(config.model_family.base_instructions, base);
+    }
+}
diff --git a/codex-rs/core/src/prompt_harness/prompt_override.rs b/codex-rs/core/src/prompt_harness/prompt_override.rs
new file mode 100644
index 0000000000..49880e48cd
--- /dev/null
+++ b/codex-rs/core/src/prompt_harness/prompt_override.rs
@@ -0,0 +1,136 @@
+use std::fs::File;
+use std::io::Read;
+use std::io::{self};
+use std::path::Path;
+
+/// Largest accepted system prompt override, in bytes.
+const PROMPT_OVERRIDE_MAX_BYTES: u64 = 8 * 1024;
+
+/// Reads the system prompt override at `path` and returns it with surrounding
+/// whitespace removed.
+///
+/// The size limit is enforced on the bytes actually read rather than on file
+/// metadata, so it also holds for files that change between the check and the
+/// read and for inputs whose metadata reports a length of zero.
+///
+/// # Errors
+///
+/// Returns the underlying I/O error kind when the file cannot be opened or
+/// read, and [`io::ErrorKind::InvalidData`] when the file is empty, larger than
+/// [`PROMPT_OVERRIDE_MAX_BYTES`], not valid UTF-8, or contains only whitespace.
+pub fn load_system_prompt_override(path: &Path) -> io::Result<String> {
+    let file = File::open(path).map_err(|err| {
+        io::Error::new(
+            err.kind(),
+            format!(
+                "failed to open system prompt override {}: {err}",
+                path.display()
+            ),
+        )
+    })?;
+
+    // Read at most one byte past the limit: enough to detect an oversized
+    // file without buffering all of it.
+    let mut bytes = Vec::new();
+    file.take(PROMPT_OVERRIDE_MAX_BYTES + 1)
+        .read_to_end(&mut bytes)
+        .map_err(|err| {
+            io::Error::new(
+                err.kind(),
+                format!(
+                    "failed to read system prompt override {}: {err}",
+                    path.display()
+                ),
+            )
+        })?;
+
+    if bytes.is_empty() {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            format!("system prompt override file is empty: {}", path.display()),
+        ));
+    }
+
+    // `usize` to `u64` cannot truncate on supported targets.
+    if bytes.len() as u64 > PROMPT_OVERRIDE_MAX_BYTES {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            format!(
+                "system prompt override exceeds limit ({} bytes): {}",
+                PROMPT_OVERRIDE_MAX_BYTES,
+                path.display()
+            ),
+        ));
+    }
+
+    let text = String::from_utf8(bytes).map_err(|err| {
+        io::Error::new(
+            io::ErrorKind::InvalidData,
+            format!(
+                "system prompt override is not valid UTF-8 {}: {err}",
+                path.display()
+            ),
+        )
+    })?;
+
+    let trimmed = text.trim();
+    if trimmed.is_empty() {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            format!(
+                "system prompt override only contained whitespace: {}",
+                path.display()
+            ),
+        ));
+    }
+
+    Ok(trimmed.to_string())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use tempfile::NamedTempFile;
+
+    #[test]
+    fn rejects_missing_file() {
+        let path = std::path::PathBuf::from("/no/such/file");
+        let err = load_system_prompt_override(&path).expect_err("expected error");
+        assert_eq!(err.kind(), io::ErrorKind::NotFound);
+    }
+
+    #[test]
+    fn rejects_zero_byte_file() {
+        let file = NamedTempFile::new().expect("create temp");
+        let err = load_system_prompt_override(file.path()).expect_err("expected error");
+        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
+        assert!(err.to_string().contains("is empty"));
+    }
+
+    #[test]
+    fn rejects_empty_file() {
+        let file = NamedTempFile::new().expect("create temp");
+        std::fs::write(file.path(), " \n \n").expect("write temp");
+        let err = load_system_prompt_override(file.path()).expect_err("expected error");
+        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
+        assert!(err.to_string().contains("whitespace"));
+    }
+
+    #[test]
+    fn rejects_large_file() {
+        let file = NamedTempFile::new().expect("create temp");
+        let large = vec![b'x'; (PROMPT_OVERRIDE_MAX_BYTES + 1) as usize];
+        std::fs::write(file.path(), large).expect("write temp");
+        let err = load_system_prompt_override(file.path()).expect_err("expected error");
+        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
+        assert!(err.to_string().contains("exceeds limit"));
+    }
+
+    #[test]
+    fn trims_and_returns_contents() {
+        let file = NamedTempFile::new().expect("create temp");
+        std::fs::write(file.path(), "\n hello world \n").expect("write temp");
+        let prompt = load_system_prompt_override(file.path()).expect("load prompt");
+        assert_eq!(prompt, "hello world");
+    }
+}
diff --git a/codex-rs/core/tests/prompt_harness_bin.rs b/codex-rs/core/tests/prompt_harness_bin.rs
new file mode 100644
index 0000000000..3aff89ad7b
--- /dev/null
+++ b/codex-rs/core/tests/prompt_harness_bin.rs
@@ -0,0 +1,52 @@
+use std::error::Error;
+use std::fs;
+
+use assert_cmd::Command;
+use tempfile::TempDir;
+
+/// Runs the `prompt_harness` binary against a minimal Python driver and
+/// expects a clean exit. Requires `python3` on `PATH`.
+#[test]
+fn prompt_harness_streams_session_event() -> Result<(), Box<dyn Error>> {
+    let workspace = TempDir::new()?;
+    let codex_home = workspace.path().join("codex_home");
+    fs::create_dir(&codex_home)?;
+
+    let prompt_path = workspace.path().join("override.md");
+    fs::write(&prompt_path, "system override contents")?;
+
+    let script_path = workspace.path().join("driver.py");
+    fs::write(&script_path, driver_script())?;
+
+    let mut cmd = Command::cargo_bin("prompt_harness")?;
+    cmd.env("CODEX_HOME", &codex_home)
+        .arg("--system-prompt-file")
+        .arg(&prompt_path)
+        .arg("python3")
+        .arg(&script_path)
+        .assert()
+        .success();
+
+    Ok(())
+}
+
+/// Python driver that checks the first event is `session_configured`, sends an
+/// interrupt submission, and exits.
+fn driver_script() -> String {
+    r#"#!/usr/bin/env python3
+import json
+import sys
+
+first = sys.stdin.readline()
+if not first:
+    sys.exit("missing session_configured event")
+
+message = json.loads(first)
+if message.get("msg", {}).get("type") != "session_configured":
+    sys.exit("unexpected initial event type")
+
+submission = {"id": "interrupt", "op": {"type": "interrupt"}}
+print(json.dumps(submission), flush=True)
+"#
+    .to_string()
+}
diff --git a/codex-rs/new_prompt.md b/codex-rs/new_prompt.md
new file mode 100644
index 0000000000..c53fc9163c
--- /dev/null
+++ b/codex-rs/new_prompt.md
@@ -0,0 +1,2 @@
+Your name is batman
+Be very talkative (i.e. use a lot of words to answer any questions)
\ No newline at end of file