mirror of
https://github.com/openai/codex.git
synced 2026-03-06 14:05:29 +03:00
Compare commits
28 Commits
queue/stee
...
rust-v0.11
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
16edc4b1d2 | ||
|
|
751b2b8f8c | ||
|
|
e389a5d876 | ||
|
|
feeb266968 | ||
|
|
6bd7ef3d3a | ||
|
|
e3f8b121db | ||
|
|
9b646b563f | ||
|
|
c07207a4cd | ||
|
|
cb49cef114 | ||
|
|
c76cb1dfc1 | ||
|
|
c66f6be85d | ||
|
|
ceb3e91d5d | ||
|
|
c55b946a08 | ||
|
|
b0f5b65021 | ||
|
|
ecfafcfd3b | ||
|
|
8c37168f62 | ||
|
|
e05d7dc2a3 | ||
|
|
9792f94eef | ||
|
|
ce92794bed | ||
|
|
c8e37631d4 | ||
|
|
50c408b9fd | ||
|
|
fa861e544f | ||
|
|
eeaf2ca0b6 | ||
|
|
b249e5098f | ||
|
|
756708a671 | ||
|
|
38e0461af5 | ||
|
|
3793670da8 | ||
|
|
ea4224b95c |
1
codex-rs/Cargo.lock
generated
1
codex-rs/Cargo.lock
generated
@@ -698,6 +698,7 @@ dependencies = [
|
||||
"serde_json",
|
||||
"sha1",
|
||||
"shlex",
|
||||
"similar",
|
||||
"strum_macros 0.27.2",
|
||||
"tempfile",
|
||||
"thiserror 2.0.12",
|
||||
|
||||
@@ -19,7 +19,7 @@ members = [
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "0.0.0"
|
||||
version = "0.11.0-alpha.4"
|
||||
# Track the edition for all workspace crates in one place. Individual
|
||||
# crates can still override this value, but keeping it here means new
|
||||
# crates created with `cargo new -w ...` automatically inherit the 2024
|
||||
|
||||
@@ -18,6 +18,9 @@ pub enum ApprovalModeCliArg {
|
||||
/// will escalate to the user to ask for un-sandboxed execution.
|
||||
OnFailure,
|
||||
|
||||
/// Only ask for approval if the model requests it.
|
||||
OnRequest,
|
||||
|
||||
/// Never ask for user approval
|
||||
/// Execution failures are immediately returned to the model.
|
||||
Never,
|
||||
@@ -28,6 +31,7 @@ impl From<ApprovalModeCliArg> for AskForApproval {
|
||||
match value {
|
||||
ApprovalModeCliArg::Untrusted => AskForApproval::UnlessTrusted,
|
||||
ApprovalModeCliArg::OnFailure => AskForApproval::OnFailure,
|
||||
ApprovalModeCliArg::OnRequest => AskForApproval::OnRequest,
|
||||
ApprovalModeCliArg::Never => AskForApproval::Never,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
sha1 = "0.10.6"
|
||||
shlex = "1.3.0"
|
||||
similar = "2"
|
||||
strum_macros = "0.27.2"
|
||||
thiserror = "2.0.12"
|
||||
time = { version = "0.3", features = ["formatting", "local-offset", "macros"] }
|
||||
|
||||
@@ -24,6 +24,7 @@ use crate::error::Result;
|
||||
use crate::models::ContentItem;
|
||||
use crate::models::ResponseItem;
|
||||
use crate::openai_tools::create_tools_json_for_chat_completions_api;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use crate::util::backoff;
|
||||
|
||||
/// Implementation for the classic Chat Completions API.
|
||||
@@ -33,6 +34,7 @@ pub(crate) async fn stream_chat_completions(
|
||||
include_plan_tool: bool,
|
||||
client: &reqwest::Client,
|
||||
provider: &ModelProviderInfo,
|
||||
sandbox_policy: Option<SandboxPolicy>,
|
||||
) -> Result<ResponseStream> {
|
||||
// Build messages array
|
||||
let mut messages = Vec::<serde_json::Value>::new();
|
||||
@@ -110,7 +112,12 @@ pub(crate) async fn stream_chat_completions(
|
||||
}
|
||||
}
|
||||
|
||||
let tools_json = create_tools_json_for_chat_completions_api(prompt, model, include_plan_tool)?;
|
||||
let tools_json = create_tools_json_for_chat_completions_api(
|
||||
prompt,
|
||||
model,
|
||||
include_plan_tool,
|
||||
sandbox_policy,
|
||||
)?;
|
||||
let payload = json!({
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
|
||||
@@ -38,6 +38,7 @@ use crate::model_provider_info::WireApi;
|
||||
use crate::models::ContentItem;
|
||||
use crate::models::ResponseItem;
|
||||
use crate::openai_tools::create_tools_json_for_responses_api;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use crate::protocol::TokenUsage;
|
||||
use crate::util::backoff;
|
||||
use std::sync::Arc;
|
||||
@@ -76,9 +77,13 @@ impl ModelClient {
|
||||
/// Dispatches to either the Responses or Chat implementation depending on
|
||||
/// the provider config. Public callers always invoke `stream()` – the
|
||||
/// specialised helpers are private to avoid accidental misuse.
|
||||
pub async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
|
||||
pub async fn stream(
|
||||
&self,
|
||||
prompt: &Prompt,
|
||||
sandbox_policy: Option<SandboxPolicy>,
|
||||
) -> Result<ResponseStream> {
|
||||
match self.provider.wire_api {
|
||||
WireApi::Responses => self.stream_responses(prompt).await,
|
||||
WireApi::Responses => self.stream_responses(prompt, sandbox_policy).await,
|
||||
WireApi::Chat => {
|
||||
// Create the raw streaming connection first.
|
||||
let response_stream = stream_chat_completions(
|
||||
@@ -87,6 +92,7 @@ impl ModelClient {
|
||||
self.config.include_plan_tool,
|
||||
&self.client,
|
||||
&self.provider,
|
||||
sandbox_policy,
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -115,7 +121,11 @@ impl ModelClient {
|
||||
}
|
||||
|
||||
/// Implementation for the OpenAI *Responses* experimental API.
|
||||
async fn stream_responses(&self, prompt: &Prompt) -> Result<ResponseStream> {
|
||||
async fn stream_responses(
|
||||
&self,
|
||||
prompt: &Prompt,
|
||||
sandbox_policy: Option<SandboxPolicy>,
|
||||
) -> Result<ResponseStream> {
|
||||
if let Some(path) = &*CODEX_RS_SSE_FIXTURE {
|
||||
// short circuit for tests
|
||||
warn!(path, "Streaming from fixture");
|
||||
@@ -146,6 +156,7 @@ impl ModelClient {
|
||||
prompt,
|
||||
&self.config.model,
|
||||
self.config.include_plan_tool,
|
||||
sandbox_policy,
|
||||
)?;
|
||||
let reasoning = create_reasoning_param_for_request(&self.config, self.effort, self.summary);
|
||||
|
||||
|
||||
@@ -84,11 +84,13 @@ use crate::protocol::SandboxPolicy;
|
||||
use crate::protocol::SessionConfiguredEvent;
|
||||
use crate::protocol::Submission;
|
||||
use crate::protocol::TaskCompleteEvent;
|
||||
use crate::protocol::TurnDiffEvent;
|
||||
use crate::rollout::RolloutRecorder;
|
||||
use crate::safety::SafetyCheck;
|
||||
use crate::safety::assess_command_safety;
|
||||
use crate::safety::assess_safety_for_untrusted_command;
|
||||
use crate::shell;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use crate::user_notification::UserNotification;
|
||||
use crate::util::backoff;
|
||||
|
||||
@@ -206,7 +208,7 @@ pub(crate) struct Session {
|
||||
base_instructions: Option<String>,
|
||||
user_instructions: Option<String>,
|
||||
pub(crate) approval_policy: AskForApproval,
|
||||
sandbox_policy: SandboxPolicy,
|
||||
pub(crate) sandbox_policy: SandboxPolicy,
|
||||
shell_environment_policy: ShellEnvironmentPolicy,
|
||||
pub(crate) writable_roots: Mutex<Vec<PathBuf>>,
|
||||
disable_response_storage: bool,
|
||||
@@ -361,7 +363,11 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
async fn notify_exec_command_begin(&self, exec_command_context: ExecCommandContext) {
|
||||
async fn on_exec_command_begin(
|
||||
&self,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
exec_command_context: ExecCommandContext,
|
||||
) {
|
||||
let ExecCommandContext {
|
||||
sub_id,
|
||||
call_id,
|
||||
@@ -373,11 +379,15 @@ impl Session {
|
||||
Some(ApplyPatchCommandContext {
|
||||
user_explicitly_approved_this_action,
|
||||
changes,
|
||||
}) => EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
||||
call_id,
|
||||
auto_approved: !user_explicitly_approved_this_action,
|
||||
changes,
|
||||
}),
|
||||
}) => {
|
||||
let _ = turn_diff_tracker.on_patch_begin(&changes);
|
||||
|
||||
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
||||
call_id,
|
||||
auto_approved: !user_explicitly_approved_this_action,
|
||||
changes,
|
||||
})
|
||||
}
|
||||
None => EventMsg::ExecCommandBegin(ExecCommandBeginEvent {
|
||||
call_id,
|
||||
command: command_for_display.clone(),
|
||||
@@ -391,8 +401,10 @@ impl Session {
|
||||
let _ = self.tx_event.send(event).await;
|
||||
}
|
||||
|
||||
async fn notify_exec_command_end(
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn on_exec_command_end(
|
||||
&self,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: &str,
|
||||
call_id: &str,
|
||||
stdout: &str,
|
||||
@@ -427,6 +439,20 @@ impl Session {
|
||||
msg,
|
||||
};
|
||||
let _ = self.tx_event.send(event).await;
|
||||
|
||||
// If this is an apply_patch, after we emit the end patch, emit a second event
|
||||
// with the full turn diff if there is one.
|
||||
if is_apply_patch {
|
||||
let unified_diff = turn_diff_tracker.get_unified_diff();
|
||||
if let Ok(Some(unified_diff)) = unified_diff {
|
||||
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
|
||||
let event = Event {
|
||||
id: sub_id.into(),
|
||||
msg,
|
||||
};
|
||||
let _ = self.tx_event.send(event).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper that emits a BackgroundEvent with the given message. This keeps
|
||||
@@ -1000,6 +1026,10 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
.await;
|
||||
|
||||
let last_agent_message: Option<String>;
|
||||
// Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains
|
||||
// many turns, from the perspective of the user, it is a single turn.
|
||||
let mut turn_diff_tracker = TurnDiffTracker::new();
|
||||
|
||||
loop {
|
||||
// Note that pending_input would be something like a message the user
|
||||
// submitted through the UI while the model was running. Though the UI
|
||||
@@ -1031,7 +1061,7 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
match run_turn(&sess, sub_id.clone(), turn_input).await {
|
||||
match run_turn(&sess, &mut turn_diff_tracker, sub_id.clone(), turn_input).await {
|
||||
Ok(turn_output) => {
|
||||
let mut items_to_record_in_conversation_history = Vec::<ResponseItem>::new();
|
||||
let mut responses = Vec::<ResponseInputItem>::new();
|
||||
@@ -1157,6 +1187,7 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
|
||||
async fn run_turn(
|
||||
sess: &Session,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: String,
|
||||
input: Vec<ResponseItem>,
|
||||
) -> CodexResult<Vec<ProcessedResponseItem>> {
|
||||
@@ -1171,7 +1202,7 @@ async fn run_turn(
|
||||
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
match try_run_turn(sess, &sub_id, &prompt).await {
|
||||
match try_run_turn(sess, turn_diff_tracker, &sub_id, &prompt).await {
|
||||
Ok(output) => return Ok(output),
|
||||
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
|
||||
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
||||
@@ -1217,6 +1248,7 @@ struct ProcessedResponseItem {
|
||||
|
||||
async fn try_run_turn(
|
||||
sess: &Session,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: &str,
|
||||
prompt: &Prompt,
|
||||
) -> CodexResult<Vec<ProcessedResponseItem>> {
|
||||
@@ -1276,7 +1308,13 @@ async fn try_run_turn(
|
||||
})
|
||||
};
|
||||
|
||||
let mut stream = sess.client.clone().stream(&prompt).await?;
|
||||
// only provide the sandbox policy if the approval policy is OnRequest
|
||||
let sandbox_policy = if sess.approval_policy == AskForApproval::OnRequest {
|
||||
Some(sess.sandbox_policy.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut stream = sess.client.clone().stream(&prompt, sandbox_policy).await?;
|
||||
|
||||
let mut output = Vec::new();
|
||||
loop {
|
||||
@@ -1304,7 +1342,8 @@ async fn try_run_turn(
|
||||
match event {
|
||||
ResponseEvent::Created => {}
|
||||
ResponseEvent::OutputItemDone(item) => {
|
||||
let response = handle_response_item(sess, sub_id, item.clone()).await?;
|
||||
let response =
|
||||
handle_response_item(sess, turn_diff_tracker, sub_id, item.clone()).await?;
|
||||
|
||||
output.push(ProcessedResponseItem { item, response });
|
||||
}
|
||||
@@ -1322,6 +1361,16 @@ async fn try_run_turn(
|
||||
.ok();
|
||||
}
|
||||
|
||||
let unified_diff = turn_diff_tracker.get_unified_diff();
|
||||
if let Ok(Some(unified_diff)) = unified_diff {
|
||||
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg,
|
||||
};
|
||||
let _ = sess.tx_event.send(event).await;
|
||||
}
|
||||
|
||||
return Ok(output);
|
||||
}
|
||||
ResponseEvent::OutputTextDelta(delta) => {
|
||||
@@ -1426,6 +1475,7 @@ async fn run_compact_task(
|
||||
|
||||
async fn handle_response_item(
|
||||
sess: &Session,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: &str,
|
||||
item: ResponseItem,
|
||||
) -> CodexResult<Option<ResponseInputItem>> {
|
||||
@@ -1463,7 +1513,17 @@ async fn handle_response_item(
|
||||
..
|
||||
} => {
|
||||
info!("FunctionCall: {arguments}");
|
||||
Some(handle_function_call(sess, sub_id.to_string(), name, arguments, call_id).await)
|
||||
Some(
|
||||
handle_function_call(
|
||||
sess,
|
||||
turn_diff_tracker,
|
||||
sub_id.to_string(),
|
||||
name,
|
||||
arguments,
|
||||
call_id,
|
||||
)
|
||||
.await,
|
||||
)
|
||||
}
|
||||
ResponseItem::LocalShellCall {
|
||||
id,
|
||||
@@ -1477,6 +1537,8 @@ async fn handle_response_item(
|
||||
command: action.command,
|
||||
workdir: action.working_directory,
|
||||
timeout_ms: action.timeout_ms,
|
||||
with_escalated_permissions: None,
|
||||
justification: None,
|
||||
};
|
||||
let effective_call_id = match (call_id, id) {
|
||||
(Some(call_id), _) => call_id,
|
||||
@@ -1498,6 +1560,7 @@ async fn handle_response_item(
|
||||
handle_container_exec_with_params(
|
||||
exec_params,
|
||||
sess,
|
||||
turn_diff_tracker,
|
||||
sub_id.to_string(),
|
||||
effective_call_id,
|
||||
)
|
||||
@@ -1515,6 +1578,7 @@ async fn handle_response_item(
|
||||
|
||||
async fn handle_function_call(
|
||||
sess: &Session,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: String,
|
||||
name: String,
|
||||
arguments: String,
|
||||
@@ -1528,7 +1592,8 @@ async fn handle_function_call(
|
||||
return *output;
|
||||
}
|
||||
};
|
||||
handle_container_exec_with_params(params, sess, sub_id, call_id).await
|
||||
handle_container_exec_with_params(params, sess, turn_diff_tracker, sub_id, call_id)
|
||||
.await
|
||||
}
|
||||
"update_plan" => handle_update_plan(sess, arguments, sub_id, call_id).await,
|
||||
_ => {
|
||||
@@ -1562,6 +1627,8 @@ fn to_exec_params(params: ShellToolCallParams, sess: &Session) -> ExecParams {
|
||||
cwd: sess.resolve_path(params.workdir.clone()),
|
||||
timeout_ms: params.timeout_ms,
|
||||
env: create_env(&sess.shell_environment_policy),
|
||||
with_escalated_permissions: params.with_escalated_permissions,
|
||||
justification: params.justification,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1602,6 +1669,7 @@ fn maybe_run_with_user_profile(params: ExecParams, sess: &Session) -> ExecParams
|
||||
async fn handle_container_exec_with_params(
|
||||
params: ExecParams,
|
||||
sess: &Session,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sub_id: String,
|
||||
call_id: String,
|
||||
) -> ResponseInputItem {
|
||||
@@ -1661,13 +1729,19 @@ async fn handle_container_exec_with_params(
|
||||
cwd: cwd.clone(),
|
||||
timeout_ms: params.timeout_ms,
|
||||
env: HashMap::new(),
|
||||
with_escalated_permissions: None,
|
||||
justification: None,
|
||||
};
|
||||
let safety = if *user_explicitly_approved_this_action {
|
||||
SafetyCheck::AutoApprove {
|
||||
sandbox_type: SandboxType::None,
|
||||
}
|
||||
} else {
|
||||
assess_safety_for_untrusted_command(sess.approval_policy, &sess.sandbox_policy)
|
||||
assess_safety_for_untrusted_command(
|
||||
sess.approval_policy,
|
||||
&sess.sandbox_policy,
|
||||
params.with_escalated_permissions.unwrap_or(false),
|
||||
)
|
||||
};
|
||||
(
|
||||
params,
|
||||
@@ -1683,6 +1757,7 @@ async fn handle_container_exec_with_params(
|
||||
sess.approval_policy,
|
||||
&sess.sandbox_policy,
|
||||
&state.approved_commands,
|
||||
params.with_escalated_permissions.unwrap_or(false),
|
||||
)
|
||||
};
|
||||
let command_for_display = params.command.clone();
|
||||
@@ -1699,7 +1774,7 @@ async fn handle_container_exec_with_params(
|
||||
call_id.clone(),
|
||||
params.command.clone(),
|
||||
params.cwd.clone(),
|
||||
None,
|
||||
params.justification.clone(),
|
||||
)
|
||||
.await;
|
||||
match rx_approve.await.unwrap_or_default() {
|
||||
@@ -1749,7 +1824,7 @@ async fn handle_container_exec_with_params(
|
||||
},
|
||||
),
|
||||
};
|
||||
sess.notify_exec_command_begin(exec_command_context.clone())
|
||||
sess.on_exec_command_begin(turn_diff_tracker, exec_command_context.clone())
|
||||
.await;
|
||||
|
||||
let params = maybe_run_with_user_profile(params, sess);
|
||||
@@ -1771,7 +1846,8 @@ async fn handle_container_exec_with_params(
|
||||
duration,
|
||||
} = output;
|
||||
|
||||
sess.notify_exec_command_end(
|
||||
sess.on_exec_command_end(
|
||||
turn_diff_tracker,
|
||||
&sub_id,
|
||||
&call_id,
|
||||
&stdout,
|
||||
@@ -1797,7 +1873,15 @@ async fn handle_container_exec_with_params(
|
||||
}
|
||||
}
|
||||
Err(CodexErr::Sandbox(error)) => {
|
||||
handle_sandbox_error(params, exec_command_context, error, sandbox_type, sess).await
|
||||
handle_sandbox_error(
|
||||
turn_diff_tracker,
|
||||
params,
|
||||
exec_command_context,
|
||||
error,
|
||||
sandbox_type,
|
||||
sess,
|
||||
)
|
||||
.await
|
||||
}
|
||||
Err(e) => {
|
||||
// Handle non-sandbox errors
|
||||
@@ -1813,6 +1897,7 @@ async fn handle_container_exec_with_params(
|
||||
}
|
||||
|
||||
async fn handle_sandbox_error(
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
params: ExecParams,
|
||||
exec_command_context: ExecCommandContext,
|
||||
error: SandboxErr,
|
||||
@@ -1824,8 +1909,11 @@ async fn handle_sandbox_error(
|
||||
let cwd = exec_command_context.cwd.clone();
|
||||
let is_apply_patch = exec_command_context.apply_patch.is_some();
|
||||
|
||||
// Early out if the user never wants to be asked for approval; just return to the model immediately
|
||||
if sess.approval_policy == AskForApproval::Never {
|
||||
// Early out if either the user never wants to be asked for approval, or
|
||||
// we're letting the model manage escalation requests.
|
||||
if sess.approval_policy == AskForApproval::Never
|
||||
|| sess.approval_policy == AskForApproval::OnRequest
|
||||
{
|
||||
return ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
@@ -1869,7 +1957,8 @@ async fn handle_sandbox_error(
|
||||
sess.notify_background_event(&sub_id, "retrying command without sandbox")
|
||||
.await;
|
||||
|
||||
sess.notify_exec_command_begin(exec_command_context).await;
|
||||
sess.on_exec_command_begin(turn_diff_tracker, exec_command_context)
|
||||
.await;
|
||||
|
||||
// This is an escalated retry; the policy will not be
|
||||
// examined and the sandbox has been set to `None`.
|
||||
@@ -1891,7 +1980,8 @@ async fn handle_sandbox_error(
|
||||
duration,
|
||||
} = retry_output;
|
||||
|
||||
sess.notify_exec_command_end(
|
||||
sess.on_exec_command_end(
|
||||
turn_diff_tracker,
|
||||
&sub_id,
|
||||
&call_id,
|
||||
&stdout,
|
||||
@@ -1991,7 +2081,7 @@ fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<St
|
||||
}
|
||||
|
||||
async fn drain_to_completed(sess: &Session, prompt: &Prompt) -> CodexResult<()> {
|
||||
let mut stream = sess.client.clone().stream(prompt).await?;
|
||||
let mut stream = sess.client.clone().stream(prompt, None).await?;
|
||||
loop {
|
||||
let maybe_event = stream.next().await;
|
||||
let Some(event) = maybe_event else {
|
||||
|
||||
@@ -43,6 +43,8 @@ pub struct ExecParams {
|
||||
pub cwd: PathBuf,
|
||||
pub timeout_ms: Option<u64>,
|
||||
pub env: HashMap<String, String>,
|
||||
pub with_escalated_permissions: Option<bool>,
|
||||
pub justification: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
@@ -74,6 +76,8 @@ pub async fn process_exec_tool_call(
|
||||
cwd,
|
||||
timeout_ms,
|
||||
env,
|
||||
with_escalated_permissions: _,
|
||||
justification: _,
|
||||
} = params;
|
||||
let child = spawn_command_under_seatbelt(
|
||||
command,
|
||||
@@ -91,6 +95,8 @@ pub async fn process_exec_tool_call(
|
||||
cwd,
|
||||
timeout_ms,
|
||||
env,
|
||||
with_escalated_permissions: _,
|
||||
justification: _,
|
||||
} = params;
|
||||
|
||||
let codex_linux_sandbox_exe = codex_linux_sandbox_exe
|
||||
@@ -230,6 +236,8 @@ async fn exec(
|
||||
cwd,
|
||||
timeout_ms,
|
||||
env,
|
||||
with_escalated_permissions: _,
|
||||
justification: _,
|
||||
}: ExecParams,
|
||||
sandbox_policy: &SandboxPolicy,
|
||||
ctrl_c: Arc<Notify>,
|
||||
|
||||
@@ -42,6 +42,7 @@ mod safety;
|
||||
pub mod seatbelt;
|
||||
pub mod shell;
|
||||
pub mod spawn;
|
||||
pub mod turn_diff_tracker;
|
||||
mod user_notification;
|
||||
pub mod util;
|
||||
|
||||
|
||||
@@ -183,6 +183,8 @@ pub struct ShellToolCallParams {
|
||||
// The wire format uses `timeout`, which has ambiguous units, so we use
|
||||
// `timeout_ms` as the field name so it is clear in code.
|
||||
pub timeout_ms: Option<u64>,
|
||||
pub with_escalated_permissions: Option<bool>,
|
||||
pub justification: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -295,6 +297,30 @@ mod tests {
|
||||
command: vec!["ls".to_string(), "-l".to_string()],
|
||||
workdir: Some("/tmp".to_string()),
|
||||
timeout_ms: Some(1000),
|
||||
with_escalated_permissions: None,
|
||||
justification: None,
|
||||
},
|
||||
params
|
||||
);
|
||||
}
|
||||
#[test]
|
||||
fn deserialize_shell_tool_call_params_with_escalated_permissions() {
|
||||
let json = r#"{
|
||||
"command": ["ls", "-l"],
|
||||
"workdir": "/tmp",
|
||||
"timeout": 1000,
|
||||
"with_escalated_permissions": true,
|
||||
"justification": "I need internet access to run npm install"
|
||||
}"#;
|
||||
|
||||
let params: ShellToolCallParams = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(
|
||||
ShellToolCallParams {
|
||||
command: vec!["ls".to_string(), "-l".to_string()],
|
||||
workdir: Some("/tmp".to_string()),
|
||||
timeout_ms: Some(1000),
|
||||
with_escalated_permissions: Some(true),
|
||||
justification: Some("I need internet access to run npm install".to_string()),
|
||||
},
|
||||
params
|
||||
);
|
||||
|
||||
@@ -5,11 +5,12 @@ use std::sync::LazyLock;
|
||||
|
||||
use crate::client_common::Prompt;
|
||||
use crate::plan_tool::PLAN_TOOL;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(crate) struct ResponsesApiTool {
|
||||
pub(crate) name: &'static str,
|
||||
pub(crate) description: &'static str,
|
||||
pub(crate) description: String,
|
||||
pub(crate) strict: bool,
|
||||
pub(crate) parameters: JsonSchema,
|
||||
}
|
||||
@@ -29,8 +30,18 @@ pub(crate) enum OpenAiTool {
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub(crate) enum JsonSchema {
|
||||
String,
|
||||
Number,
|
||||
Boolean {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
},
|
||||
String {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
},
|
||||
Number {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
},
|
||||
Array {
|
||||
items: Box<JsonSchema>,
|
||||
},
|
||||
@@ -48,15 +59,21 @@ static DEFAULT_TOOLS: LazyLock<Vec<OpenAiTool>> = LazyLock::new(|| {
|
||||
properties.insert(
|
||||
"command".to_string(),
|
||||
JsonSchema::Array {
|
||||
items: Box::new(JsonSchema::String),
|
||||
items: Box::new(JsonSchema::String { description: None }),
|
||||
},
|
||||
);
|
||||
properties.insert("workdir".to_string(), JsonSchema::String);
|
||||
properties.insert("timeout".to_string(), JsonSchema::Number);
|
||||
properties.insert(
|
||||
"workdir".to_string(),
|
||||
JsonSchema::String { description: None },
|
||||
);
|
||||
properties.insert(
|
||||
"timeout".to_string(),
|
||||
JsonSchema::Number { description: None },
|
||||
);
|
||||
|
||||
vec![OpenAiTool::Function(ResponsesApiTool {
|
||||
name: "shell",
|
||||
description: "Runs a shell command, and returns its output.",
|
||||
description: "Runs a shell command, and returns its output.".to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
@@ -72,13 +89,18 @@ static DEFAULT_CODEX_MODEL_TOOLS: LazyLock<Vec<OpenAiTool>> =
|
||||
/// Returns JSON values that are compatible with Function Calling in the
|
||||
/// Responses API:
|
||||
/// https://platform.openai.com/docs/guides/function-calling?api-mode=responses
|
||||
/// TODO: we're starting to get more complex with our tool config, let's
|
||||
pub(crate) fn create_tools_json_for_responses_api(
|
||||
prompt: &Prompt,
|
||||
model: &str,
|
||||
include_plan_tool: bool,
|
||||
sandbox_policy: Option<SandboxPolicy>,
|
||||
) -> crate::error::Result<Vec<serde_json::Value>> {
|
||||
// Assemble tool list: built-in tools + any extra tools from the prompt.
|
||||
let default_tools = if model.starts_with("codex") {
|
||||
let default_tools = if let Some(sandbox_policy) = sandbox_policy {
|
||||
// if sandbox_policy is provided, create the shell tool
|
||||
&vec![create_shell_tool_for_sandbox(sandbox_policy)]
|
||||
} else if model.contains("codex") {
|
||||
&DEFAULT_CODEX_MODEL_TOOLS
|
||||
} else {
|
||||
&DEFAULT_TOOLS
|
||||
@@ -102,6 +124,109 @@ pub(crate) fn create_tools_json_for_responses_api(
|
||||
Ok(tools_json)
|
||||
}
|
||||
|
||||
fn create_shell_tool_for_sandbox(sandbox_policy: SandboxPolicy) -> OpenAiTool {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
"command".to_string(),
|
||||
JsonSchema::Array {
|
||||
items: Box::new(JsonSchema::String {
|
||||
description: Some("The command to execute".to_string()),
|
||||
}),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"workdir".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some("The working directory to execute the command in".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"timeout".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some("The timeout for the command in milliseconds".to_string()),
|
||||
},
|
||||
);
|
||||
|
||||
if sandbox_policy != SandboxPolicy::DangerFullAccess {
|
||||
properties.insert(
|
||||
"with_escalated_permissions".to_string(),
|
||||
JsonSchema::Boolean {
|
||||
description: Some("Whether to request escalated permissions. Set to true if command needs to be run without sandbox restrictions".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"justification".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some("Only set if ask_for_escalated_permissions is true. 1-sentence explanation of why we want to run this command.".to_string()),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let description = match sandbox_policy {
|
||||
SandboxPolicy::WorkspaceWrite {
|
||||
writable_roots: _,
|
||||
network_access,
|
||||
} => {
|
||||
format!(
|
||||
r#"
|
||||
The shell tool is used to execute shell commands.
|
||||
|
||||
- When invoking the shell tool, your call will be running in a landlock sandbox, and some shell commands will require escalated privileges:
|
||||
- Types of actions that require escalated privileges:
|
||||
- Reading files outside the current directory
|
||||
- Writing files outside the current directory, and protected folders like .git or .env{}
|
||||
- Examples of commands that require escalated privileges:
|
||||
- git commit
|
||||
- npm install or pnpm install
|
||||
- cargo build
|
||||
- cargo test
|
||||
- When invoking a command that will require escalated privileges:
|
||||
- Provide the with_escalated_permissions parameter with the boolean value true
|
||||
- Include a short, 1 sentence explanation for why we need to run with_escalated_permissions in the justification parameter."#,
|
||||
if !network_access {
|
||||
"\n - Commands that require network access\n"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
)
|
||||
}
|
||||
SandboxPolicy::DangerFullAccess => {
|
||||
"Runs a shell command, and returns its output.".to_string()
|
||||
}
|
||||
SandboxPolicy::ReadOnly => {
|
||||
r#"
|
||||
The shell tool is used to execute shell commands.
|
||||
IMPORTANT: If you are running the apply_patch command, you will need to provide the with_escalated_permissions parameter with the boolean value true.
|
||||
|
||||
- When invoking the shell tool, your call will be running in a landlock sandbox, and some shell commands (including apply_patch) will require escalated permissions:
|
||||
- Types of actions that require escalated privileges:
|
||||
- Reading files outside the current directory
|
||||
- Writing files
|
||||
- Applying patches
|
||||
- Examples of commands that require escalated privileges:
|
||||
- apply_patch
|
||||
- git commit
|
||||
- npm install or pnpm install
|
||||
- cargo build
|
||||
- cargo test
|
||||
- When invoking a command that will require escalated privileges:
|
||||
- Provide the with_escalated_permissions parameter with the boolean value true
|
||||
- Include a short, 1 sentence explanation for why we need to run with_escalated_permissions in the justification parameter"#.to_string()
|
||||
}
|
||||
};
|
||||
|
||||
OpenAiTool::Function(ResponsesApiTool {
|
||||
name: "shell",
|
||||
description,
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
required: &["command"],
|
||||
additional_properties: false,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns JSON values that are compatible with Function Calling in the
|
||||
/// Chat Completions API:
|
||||
/// https://platform.openai.com/docs/guides/function-calling?api-mode=chat
|
||||
@@ -109,11 +234,12 @@ pub(crate) fn create_tools_json_for_chat_completions_api(
|
||||
prompt: &Prompt,
|
||||
model: &str,
|
||||
include_plan_tool: bool,
|
||||
sandbox_policy: Option<SandboxPolicy>,
|
||||
) -> crate::error::Result<Vec<serde_json::Value>> {
|
||||
// We start with the JSON for the Responses API and than rewrite it to match
|
||||
// the chat completions tool call format.
|
||||
let responses_api_tools_json =
|
||||
create_tools_json_for_responses_api(prompt, model, include_plan_tool)?;
|
||||
create_tools_json_for_responses_api(prompt, model, include_plan_tool, sandbox_policy)?;
|
||||
let tools_json = responses_api_tools_json
|
||||
.into_iter()
|
||||
.filter_map(|mut tool| {
|
||||
@@ -163,3 +289,50 @@ fn mcp_tool_to_openai_tool(
|
||||
"type": "function",
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_create_tools_json_for_responses_api() {
|
||||
let prompt = Prompt {
|
||||
..Default::default()
|
||||
};
|
||||
let model = "gpt-4o-mini";
|
||||
let include_plan_tool = true;
|
||||
let sandbox_policy = None;
|
||||
|
||||
let tools_json =
|
||||
create_tools_json_for_responses_api(&prompt, model, include_plan_tool, sandbox_policy)
|
||||
.unwrap();
|
||||
assert_eq!(tools_json[0]["name"], "shell");
|
||||
let properties = tools_json[0]["parameters"]["properties"]
|
||||
.as_object()
|
||||
.unwrap();
|
||||
assert!(!properties.contains_key("with_escalated_permissions"));
|
||||
assert!(!properties.contains_key("justification"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_tools_json_for_responses_api_with_sandbox() {
|
||||
let prompt = Prompt {
|
||||
..Default::default()
|
||||
};
|
||||
let model = "codex-mini-latest";
|
||||
let include_plan_tool = true;
|
||||
let sandbox_policy = Some(SandboxPolicy::ReadOnly);
|
||||
|
||||
let tools_json =
|
||||
create_tools_json_for_responses_api(&prompt, model, include_plan_tool, sandbox_policy)
|
||||
.unwrap();
|
||||
assert_eq!(tools_json[0]["name"], "shell");
|
||||
let properties = tools_json[0]["parameters"]["properties"]
|
||||
.as_object()
|
||||
.unwrap();
|
||||
assert!(properties.contains_key("with_escalated_permissions"));
|
||||
assert!(properties.contains_key("justification"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,8 +39,11 @@ pub struct UpdatePlanArgs {
|
||||
|
||||
pub(crate) static PLAN_TOOL: LazyLock<OpenAiTool> = LazyLock::new(|| {
|
||||
let mut plan_item_props = BTreeMap::new();
|
||||
plan_item_props.insert("step".to_string(), JsonSchema::String);
|
||||
plan_item_props.insert("status".to_string(), JsonSchema::String);
|
||||
plan_item_props.insert("step".to_string(), JsonSchema::String { description: None });
|
||||
plan_item_props.insert(
|
||||
"status".to_string(),
|
||||
JsonSchema::String { description: None },
|
||||
);
|
||||
|
||||
let plan_items_schema = JsonSchema::Array {
|
||||
items: Box::new(JsonSchema::Object {
|
||||
@@ -51,7 +54,10 @@ pub(crate) static PLAN_TOOL: LazyLock<OpenAiTool> = LazyLock::new(|| {
|
||||
};
|
||||
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert("explanation".to_string(), JsonSchema::String);
|
||||
properties.insert(
|
||||
"explanation".to_string(),
|
||||
JsonSchema::String { description: None },
|
||||
);
|
||||
properties.insert("plan".to_string(), plan_items_schema);
|
||||
|
||||
OpenAiTool::Function(ResponsesApiTool {
|
||||
@@ -66,7 +72,7 @@ Until all the steps are finished, there should always be exactly one in_progress
|
||||
Call the update_plan tool whenever you finish a step, marking the completed step as `completed` and marking the next step as `in_progress`.
|
||||
Before running a command, consider whether or not you have completed the previous step, and make sure to mark it as completed before moving on to the next step.
|
||||
Sometimes, you may need to change plans in the middle of a task: call `update_plan` with the updated plan and make sure to provide an `explanation` of the rationale when doing so.
|
||||
When all steps are completed, call update_plan one last time with all steps marked as `completed`."#,
|
||||
When all steps are completed, call update_plan one last time with all steps marked as `completed`."#.to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
|
||||
@@ -149,6 +149,11 @@ pub enum AskForApproval {
|
||||
/// the user to approve execution without a sandbox.
|
||||
OnFailure,
|
||||
|
||||
// Experimental: Commands are run inside a sandbox, and the model can
|
||||
// proactively request escalation of privileges. Failures are handled by
|
||||
// the model.
|
||||
OnRequest,
|
||||
|
||||
/// Never ask the user to approve commands. Failures are immediately returned
|
||||
/// to the model, and never escalated to the user for approval.
|
||||
Never,
|
||||
@@ -338,6 +343,8 @@ pub enum EventMsg {
|
||||
/// Notification that a patch application has finished.
|
||||
PatchApplyEnd(PatchApplyEndEvent),
|
||||
|
||||
TurnDiff(TurnDiffEvent),
|
||||
|
||||
/// Response to GetHistoryEntryRequest.
|
||||
GetHistoryEntryResponse(GetHistoryEntryResponseEvent),
|
||||
|
||||
@@ -529,6 +536,11 @@ pub struct PatchApplyEndEvent {
|
||||
pub success: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct TurnDiffEvent {
|
||||
pub unified_diff: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct GetHistoryEntryResponseEvent {
|
||||
pub offset: usize,
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::is_safe_command::is_known_safe_command;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum SafetyCheck {
|
||||
AutoApprove { sandbox_type: SandboxType },
|
||||
AskUser,
|
||||
@@ -31,7 +31,7 @@ pub fn assess_patch_safety(
|
||||
}
|
||||
|
||||
match policy {
|
||||
AskForApproval::OnFailure | AskForApproval::Never => {
|
||||
AskForApproval::OnFailure | AskForApproval::Never | AskForApproval::OnRequest => {
|
||||
// Continue to see if this can be auto-approved.
|
||||
}
|
||||
// TODO(ragona): I'm not sure this is actually correct? I believe in this case
|
||||
@@ -76,6 +76,7 @@ pub fn assess_command_safety(
|
||||
approval_policy: AskForApproval,
|
||||
sandbox_policy: &SandboxPolicy,
|
||||
approved: &HashSet<Vec<String>>,
|
||||
request_escalated_privileges: bool,
|
||||
) -> SafetyCheck {
|
||||
// A command is "trusted" because either:
|
||||
// - it belongs to a set of commands we consider "safe" by default, or
|
||||
@@ -96,12 +97,17 @@ pub fn assess_command_safety(
|
||||
};
|
||||
}
|
||||
|
||||
assess_safety_for_untrusted_command(approval_policy, sandbox_policy)
|
||||
assess_safety_for_untrusted_command(
|
||||
approval_policy,
|
||||
sandbox_policy,
|
||||
request_escalated_privileges,
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn assess_safety_for_untrusted_command(
|
||||
approval_policy: AskForApproval,
|
||||
sandbox_policy: &SandboxPolicy,
|
||||
with_escalated_permissions: bool,
|
||||
) -> SafetyCheck {
|
||||
use AskForApproval::*;
|
||||
use SandboxPolicy::*;
|
||||
@@ -113,9 +119,23 @@ pub(crate) fn assess_safety_for_untrusted_command(
|
||||
// commands.
|
||||
SafetyCheck::AskUser
|
||||
}
|
||||
(OnFailure, DangerFullAccess) | (Never, DangerFullAccess) => SafetyCheck::AutoApprove {
|
||||
(OnFailure, DangerFullAccess)
|
||||
| (Never, DangerFullAccess)
|
||||
| (OnRequest, DangerFullAccess) => SafetyCheck::AutoApprove {
|
||||
sandbox_type: SandboxType::None,
|
||||
},
|
||||
(OnRequest, ReadOnly) | (OnRequest, WorkspaceWrite { .. }) => {
|
||||
if with_escalated_permissions {
|
||||
SafetyCheck::AskUser
|
||||
} else {
|
||||
match get_platform_sandbox() {
|
||||
Some(sandbox_type) => SafetyCheck::AutoApprove { sandbox_type },
|
||||
// Fall back to asking since the command is untrusted and
|
||||
// we do not have a sandbox available
|
||||
None => SafetyCheck::AskUser,
|
||||
}
|
||||
}
|
||||
}
|
||||
(Never, ReadOnly)
|
||||
| (Never, WorkspaceWrite { .. })
|
||||
| (OnFailure, ReadOnly)
|
||||
@@ -264,4 +284,47 @@ mod tests {
|
||||
&cwd,
|
||||
))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_escalated_privileges() {
|
||||
// Should not be a trusted command
|
||||
let command = vec!["git commit".to_string()];
|
||||
let approval_policy = AskForApproval::OnRequest;
|
||||
let sandbox_policy = SandboxPolicy::ReadOnly;
|
||||
let approved: HashSet<Vec<String>> = HashSet::new();
|
||||
let request_escalated_privileges = true;
|
||||
|
||||
let safety_check = assess_command_safety(
|
||||
&command,
|
||||
approval_policy,
|
||||
&sandbox_policy,
|
||||
&approved,
|
||||
request_escalated_privileges,
|
||||
);
|
||||
|
||||
assert_eq!(safety_check, SafetyCheck::AskUser);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_escalated_privileges_no_sandbox_fallback() {
|
||||
let command = vec!["git commit".to_string()];
|
||||
let approval_policy = AskForApproval::OnRequest;
|
||||
let sandbox_policy = SandboxPolicy::ReadOnly;
|
||||
let approved: HashSet<Vec<String>> = HashSet::new();
|
||||
let request_escalated_privileges = false;
|
||||
|
||||
let safety_check = assess_command_safety(
|
||||
&command,
|
||||
approval_policy,
|
||||
&sandbox_policy,
|
||||
&approved,
|
||||
request_escalated_privileges,
|
||||
);
|
||||
|
||||
let expected = match get_platform_sandbox() {
|
||||
Some(sandbox_type) => SafetyCheck::AutoApprove { sandbox_type },
|
||||
None => SafetyCheck::AskUser,
|
||||
};
|
||||
assert_eq!(safety_check, expected);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -215,6 +215,8 @@ mod tests {
|
||||
"HOME".to_string(),
|
||||
temp_home.path().to_str().unwrap().to_string(),
|
||||
)]),
|
||||
with_escalated_permissions: None,
|
||||
justification: None,
|
||||
},
|
||||
SandboxType::None,
|
||||
Arc::new(Notify::new()),
|
||||
|
||||
458
codex-rs/core/src/turn_diff_tracker.rs
Normal file
458
codex-rs/core/src/turn_diff_tracker.rs
Normal file
@@ -0,0 +1,458 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::protocol::FileChange;
|
||||
|
||||
/// Tracks sets of changes to files and exposes the overall unified diff.
|
||||
/// Internally, the way this works is now:
|
||||
/// 1. Maintain an in-memory baseline snapshot of files when they are first seen.
|
||||
/// For new additions, do not create a baseline so that diffs are shown as proper additions (using /dev/null).
|
||||
/// 2. Keep a stable internal filename (uuid + same extension) per external path for rename tracking.
|
||||
/// 3. To compute the aggregated unified diff, compare each baseline snapshot to the current file on disk entirely in-memory
|
||||
/// using the `similar` crate and emit unified diffs with rewritten external paths.
|
||||
#[derive(Default)]
|
||||
pub struct TurnDiffTracker {
|
||||
/// Map external path -> internal filename (uuid + same extension).
|
||||
external_to_temp_name: HashMap<PathBuf, String>,
|
||||
/// Internal filename -> external path as of baseline snapshot.
|
||||
temp_name_to_baseline_external: HashMap<String, PathBuf>,
|
||||
/// Internal filename -> external path as of current accumulated state (after applying all changes).
|
||||
/// This is where renames are tracked.
|
||||
temp_name_to_current_external: HashMap<String, PathBuf>,
|
||||
/// Internal filename -> baseline file contents (None means the file did not exist, i.e. /dev/null).
|
||||
baseline_contents: HashMap<String, Option<String>>,
|
||||
/// Internal filename -> baseline file mode (100644 or 100755). Only set when baseline file existed.
|
||||
baseline_mode: HashMap<String, String>,
|
||||
/// Aggregated unified diff for all accumulated changes across files.
|
||||
pub unified_diff: Option<String>,
|
||||
}
|
||||
|
||||
impl TurnDiffTracker {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Front-run apply patch calls to track the starting contents of any modified files.
|
||||
/// - Creates an in-memory baseline snapshot for files that already exist on disk when first seen.
|
||||
/// - For additions, we intentionally do not create a baseline snapshot so that diffs are proper additions.
|
||||
/// - Also updates internal mappings for move/rename events.
|
||||
pub fn on_patch_begin(&mut self, changes: &HashMap<PathBuf, FileChange>) -> Result<()> {
|
||||
for (path, change) in changes.iter() {
|
||||
// Ensure a stable internal filename exists for this external path.
|
||||
if !self.external_to_temp_name.contains_key(path) {
|
||||
let internal = uuid_filename_for(path);
|
||||
self.external_to_temp_name
|
||||
.insert(path.clone(), internal.clone());
|
||||
self.temp_name_to_baseline_external
|
||||
.insert(internal.clone(), path.clone());
|
||||
self.temp_name_to_current_external
|
||||
.insert(internal.clone(), path.clone());
|
||||
|
||||
// If the file exists on disk now, snapshot as baseline; else leave missing to represent /dev/null.
|
||||
let baseline = if path.exists() {
|
||||
let contents = fs::read(path)
|
||||
.with_context(|| format!("failed to read original {}", path.display()))?;
|
||||
// Capture baseline mode for later file mode lines.
|
||||
if let Some(mode) = file_mode_for_path(path) {
|
||||
self.baseline_mode.insert(internal.clone(), mode);
|
||||
}
|
||||
Some(String::from_utf8_lossy(&contents).into_owned())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
self.baseline_contents.insert(internal.clone(), baseline);
|
||||
}
|
||||
|
||||
// Track rename/move in current mapping if provided in an Update.
|
||||
if let FileChange::Update {
|
||||
move_path: Some(dest),
|
||||
..
|
||||
} = change
|
||||
{
|
||||
let uuid_filename = match self.external_to_temp_name.get(path) {
|
||||
Some(i) => i.clone(),
|
||||
None => {
|
||||
// This should be rare, but if we haven't mapped the source, create it with no baseline.
|
||||
let i = uuid_filename_for(path);
|
||||
self.external_to_temp_name.insert(path.clone(), i.clone());
|
||||
self.temp_name_to_baseline_external
|
||||
.insert(i.clone(), path.clone());
|
||||
// No on-disk file read here; treat as addition.
|
||||
self.baseline_contents.insert(i.clone(), None);
|
||||
i
|
||||
}
|
||||
};
|
||||
// Update current external mapping for temp file name.
|
||||
self.temp_name_to_current_external
|
||||
.insert(uuid_filename.clone(), dest.clone());
|
||||
// Update forward file_mapping: external current -> internal name.
|
||||
self.external_to_temp_name.remove(path);
|
||||
self.external_to_temp_name
|
||||
.insert(dest.clone(), uuid_filename);
|
||||
};
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_path_for_internal(&self, internal: &str) -> Option<PathBuf> {
|
||||
self.temp_name_to_current_external
|
||||
.get(internal)
|
||||
.or_else(|| self.temp_name_to_baseline_external.get(internal))
|
||||
.cloned()
|
||||
}
|
||||
|
||||
/// Recompute the aggregated unified diff by comparing all of the in-memory snapshots that were
|
||||
/// collected before the first time they were touched by apply_patch during this turn with
|
||||
/// the current repo state.
|
||||
pub fn get_unified_diff(&mut self) -> Result<Option<String>> {
|
||||
let mut aggregated = String::new();
|
||||
|
||||
// Compute diffs per tracked internal file in a stable order by external path.
|
||||
let mut internals: Vec<String> = self
|
||||
.temp_name_to_baseline_external
|
||||
.keys()
|
||||
.cloned()
|
||||
.collect();
|
||||
// Sort lexicographically by external path to match git behavior.
|
||||
internals.sort_by_key(|a| {
|
||||
let path = self.get_path_for_internal(a);
|
||||
match path {
|
||||
Some(p) => p
|
||||
.file_name()
|
||||
.and_then(|s| s.to_str())
|
||||
.map(|s| s.to_owned())
|
||||
.unwrap_or_default(),
|
||||
None => String::new(),
|
||||
}
|
||||
});
|
||||
|
||||
for internal in internals {
|
||||
// Baseline external must exist for any tracked internal.
|
||||
let baseline_external = match self.temp_name_to_baseline_external.get(&internal) {
|
||||
Some(p) => p.clone(),
|
||||
None => continue,
|
||||
};
|
||||
let current_external = match self.get_path_for_internal(&internal) {
|
||||
Some(p) => p,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let left_content = self
|
||||
.baseline_contents
|
||||
.get(&internal)
|
||||
.cloned()
|
||||
.unwrap_or(None);
|
||||
|
||||
let right_content = if current_external.exists() {
|
||||
let contents = fs::read(¤t_external).with_context(|| {
|
||||
format!(
|
||||
"failed to read current file for diff {}",
|
||||
current_external.display()
|
||||
)
|
||||
})?;
|
||||
Some(String::from_utf8_lossy(&contents).into_owned())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let left_text = left_content.as_deref().unwrap_or("");
|
||||
let right_text = right_content.as_deref().unwrap_or("");
|
||||
|
||||
if left_text == right_text {
|
||||
continue;
|
||||
}
|
||||
|
||||
let left_display = baseline_external.display().to_string();
|
||||
let right_display = current_external.display().to_string();
|
||||
|
||||
// Diff the contents.
|
||||
let diff = similar::TextDiff::from_lines(left_text, right_text);
|
||||
|
||||
// Emit a git-style header for better readability and parity with previous behavior.
|
||||
aggregated.push_str(&format!("diff --git a/{left_display} b/{right_display}\n"));
|
||||
|
||||
// Emit file mode lines and index line similar to git.
|
||||
let is_add = left_content.is_none() && right_content.is_some();
|
||||
let is_delete = left_content.is_some() && right_content.is_none();
|
||||
|
||||
// Determine modes.
|
||||
let baseline_mode = self
|
||||
.baseline_mode
|
||||
.get(&internal)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "100644".to_string());
|
||||
let current_mode =
|
||||
file_mode_for_path(¤t_external).unwrap_or_else(|| "100644".to_string());
|
||||
|
||||
if is_add {
|
||||
aggregated.push_str(&format!("new file mode {current_mode}\n"));
|
||||
} else if is_delete {
|
||||
aggregated.push_str(&format!("deleted file mode {baseline_mode}\n"));
|
||||
} else if baseline_mode != current_mode {
|
||||
aggregated.push_str(&format!("old mode {baseline_mode}\n"));
|
||||
aggregated.push_str(&format!("new mode {current_mode}\n"));
|
||||
}
|
||||
|
||||
aggregated.push_str(&format!("index {ZERO_OID}..{ZERO_OID}\n"));
|
||||
|
||||
let old_header = if left_content.is_some() {
|
||||
format!("a/{left_display}")
|
||||
} else {
|
||||
"/dev/null".to_string()
|
||||
};
|
||||
let new_header = if right_content.is_some() {
|
||||
format!("b/{right_display}")
|
||||
} else {
|
||||
"/dev/null".to_string()
|
||||
};
|
||||
|
||||
let unified = diff
|
||||
.unified_diff()
|
||||
.context_radius(3)
|
||||
.header(&old_header, &new_header)
|
||||
.to_string();
|
||||
|
||||
aggregated.push_str(&unified);
|
||||
if !aggregated.ends_with('\n') {
|
||||
aggregated.push('\n');
|
||||
}
|
||||
}
|
||||
|
||||
self.unified_diff = if aggregated.trim().is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(aggregated)
|
||||
};
|
||||
|
||||
Ok(self.unified_diff.clone())
|
||||
}
|
||||
}
|
||||
|
||||
fn uuid_filename_for(path: &Path) -> String {
|
||||
let id = Uuid::new_v4().to_string();
|
||||
match path.extension().and_then(|e| e.to_str()) {
|
||||
Some(ext) if !ext.is_empty() => format!("{id}.{ext}"),
|
||||
_ => id,
|
||||
}
|
||||
}
|
||||
|
||||
const ZERO_OID: &str = "0000000000000000000000000000000000000000";
|
||||
|
||||
fn file_mode_for_path(path: &Path) -> Option<String> {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let meta = fs::metadata(path).ok()?;
|
||||
let mode = meta.permissions().mode();
|
||||
let is_exec = (mode & 0o111) != 0;
|
||||
Some(if is_exec {
|
||||
"100755".to_string()
|
||||
} else {
|
||||
"100644".to_string()
|
||||
})
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
// Default to non-executable on non-unix.
|
||||
Some("100644".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::tempdir;
|
||||
|
||||
fn normalize_diff_for_test(input: &str, root: &Path) -> String {
|
||||
let root_str = root.display().to_string();
|
||||
let replaced = input.replace(&root_str, "<TMP>");
|
||||
// Split into blocks on lines starting with "diff --git ", sort blocks for determinism, and rejoin
|
||||
let mut blocks: Vec<String> = Vec::new();
|
||||
let mut current = String::new();
|
||||
for line in replaced.lines() {
|
||||
if line.starts_with("diff --git ") && !current.is_empty() {
|
||||
blocks.push(current);
|
||||
current = String::new();
|
||||
}
|
||||
if !current.is_empty() {
|
||||
current.push('\n');
|
||||
}
|
||||
current.push_str(line);
|
||||
}
|
||||
if !current.is_empty() {
|
||||
blocks.push(current);
|
||||
}
|
||||
blocks.sort();
|
||||
let mut out = blocks.join("\n");
|
||||
if !out.ends_with('\n') {
|
||||
out.push('\n');
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accumulates_add_and_update() {
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
|
||||
let dir = tempdir().unwrap();
|
||||
let file = dir.path().join("a.txt");
|
||||
|
||||
// First patch: add file (baseline should be /dev/null).
|
||||
let add_changes = HashMap::from([(
|
||||
file.clone(),
|
||||
FileChange::Add {
|
||||
content: "foo\n".to_string(),
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&add_changes).unwrap();
|
||||
|
||||
// Simulate apply: create the file on disk.
|
||||
fs::write(&file, "foo\n").unwrap();
|
||||
acc.get_unified_diff().unwrap();
|
||||
let first = acc.unified_diff.clone().unwrap();
|
||||
let first = normalize_diff_for_test(&first, dir.path());
|
||||
let expected_first = {
|
||||
let mode = file_mode_for_path(&file).unwrap_or_else(|| "100644".to_string());
|
||||
format!(
|
||||
"diff --git a/<TMP>/a.txt b/<TMP>/a.txt\nnew file mode {mode}\nindex {ZERO_OID}..{ZERO_OID}\n--- /dev/null\n+++ b/<TMP>/a.txt\n@@ -0,0 +1 @@\n+foo\n",
|
||||
)
|
||||
};
|
||||
assert_eq!(first, expected_first);
|
||||
|
||||
// Second patch: update the file on disk.
|
||||
let update_changes = HashMap::from([(
|
||||
file.clone(),
|
||||
FileChange::Update {
|
||||
unified_diff: "".to_owned(),
|
||||
move_path: None,
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&update_changes).unwrap();
|
||||
|
||||
// Simulate apply: append a new line.
|
||||
fs::write(&file, "foo\nbar\n").unwrap();
|
||||
acc.get_unified_diff().unwrap();
|
||||
let combined = acc.unified_diff.clone().unwrap();
|
||||
let combined = normalize_diff_for_test(&combined, dir.path());
|
||||
let expected_combined = {
|
||||
let mode = file_mode_for_path(&file).unwrap_or_else(|| "100644".to_string());
|
||||
format!(
|
||||
"diff --git a/<TMP>/a.txt b/<TMP>/a.txt\nnew file mode {mode}\nindex {ZERO_OID}..{ZERO_OID}\n--- /dev/null\n+++ b/<TMP>/a.txt\n@@ -0,0 +1,2 @@\n+foo\n+bar\n",
|
||||
)
|
||||
};
|
||||
assert_eq!(combined, expected_combined);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accumulates_delete() {
|
||||
let dir = tempdir().unwrap();
|
||||
let file = dir.path().join("b.txt");
|
||||
fs::write(&file, "x\n").unwrap();
|
||||
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
let del_changes = HashMap::from([(file.clone(), FileChange::Delete)]);
|
||||
acc.on_patch_begin(&del_changes).unwrap();
|
||||
|
||||
// Simulate apply: delete the file from disk.
|
||||
let baseline_mode = file_mode_for_path(&file).unwrap_or_else(|| "100644".to_string());
|
||||
fs::remove_file(&file).unwrap();
|
||||
acc.get_unified_diff().unwrap();
|
||||
let diff = acc.unified_diff.clone().unwrap();
|
||||
let diff = normalize_diff_for_test(&diff, dir.path());
|
||||
let expected = format!(
|
||||
"diff --git a/<TMP>/b.txt b/<TMP>/b.txt\ndeleted file mode {baseline_mode}\nindex {ZERO_OID}..{ZERO_OID}\n--- a/<TMP>/b.txt\n+++ /dev/null\n@@ -1 +0,0 @@\n-x\n",
|
||||
);
|
||||
assert_eq!(diff, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accumulates_move_and_update() {
|
||||
let dir = tempdir().unwrap();
|
||||
let src = dir.path().join("src.txt");
|
||||
let dest = dir.path().join("dst.txt");
|
||||
fs::write(&src, "line\n").unwrap();
|
||||
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
let mv_changes = HashMap::from([(
|
||||
src.clone(),
|
||||
FileChange::Update {
|
||||
unified_diff: "".to_owned(),
|
||||
move_path: Some(dest.clone()),
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&mv_changes).unwrap();
|
||||
|
||||
// Simulate apply: move and update content.
|
||||
fs::rename(&src, &dest).unwrap();
|
||||
fs::write(&dest, "line2\n").unwrap();
|
||||
|
||||
acc.get_unified_diff().unwrap();
|
||||
let out = acc.unified_diff.clone().unwrap();
|
||||
let out = normalize_diff_for_test(&out, dir.path());
|
||||
let expected = {
|
||||
format!(
|
||||
"diff --git a/<TMP>/src.txt b/<TMP>/dst.txt\nindex {ZERO_OID}..{ZERO_OID}\n--- a/<TMP>/src.txt\n+++ b/<TMP>/dst.txt\n@@ -1 +1 @@\n-line\n+line2\n"
|
||||
)
|
||||
};
|
||||
assert_eq!(out, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_persists_across_new_baseline_for_new_file() {
|
||||
let dir = tempdir().unwrap();
|
||||
let a = dir.path().join("a.txt");
|
||||
let b = dir.path().join("b.txt");
|
||||
fs::write(&a, "foo\n").unwrap();
|
||||
fs::write(&b, "z\n").unwrap();
|
||||
|
||||
let mut acc = TurnDiffTracker::new();
|
||||
|
||||
// First: update existing a.txt (baseline snapshot is created for a).
|
||||
let update_a = HashMap::from([(
|
||||
a.clone(),
|
||||
FileChange::Update {
|
||||
unified_diff: "".to_owned(),
|
||||
move_path: None,
|
||||
},
|
||||
)]);
|
||||
acc.on_patch_begin(&update_a).unwrap();
|
||||
// Simulate apply: modify a.txt on disk.
|
||||
fs::write(&a, "foo\nbar\n").unwrap();
|
||||
acc.get_unified_diff().unwrap();
|
||||
let first = acc.unified_diff.clone().unwrap();
|
||||
let first = normalize_diff_for_test(&first, dir.path());
|
||||
let expected_first = {
|
||||
format!(
|
||||
"diff --git a/<TMP>/a.txt b/<TMP>/a.txt\nindex {ZERO_OID}..{ZERO_OID}\n--- a/<TMP>/a.txt\n+++ b/<TMP>/a.txt\n@@ -1 +1,2 @@\n foo\n+bar\n"
|
||||
)
|
||||
};
|
||||
assert_eq!(first, expected_first);
|
||||
|
||||
// Next: introduce a brand-new path b.txt into baseline snapshots via a delete change.
|
||||
let del_b = HashMap::from([(b.clone(), FileChange::Delete)]);
|
||||
acc.on_patch_begin(&del_b).unwrap();
|
||||
// Simulate apply: delete b.txt.
|
||||
let baseline_mode = file_mode_for_path(&b).unwrap_or_else(|| "100644".to_string());
|
||||
fs::remove_file(&b).unwrap();
|
||||
acc.get_unified_diff().unwrap();
|
||||
|
||||
let combined = acc.unified_diff.clone().unwrap();
|
||||
let combined = normalize_diff_for_test(&combined, dir.path());
|
||||
let expected = {
|
||||
format!(
|
||||
"diff --git a/<TMP>/a.txt b/<TMP>/a.txt\nindex {ZERO_OID}..{ZERO_OID}\n--- a/<TMP>/a.txt\n+++ b/<TMP>/a.txt\n@@ -1 +1,2 @@\n foo\n+bar\n\
|
||||
diff --git a/<TMP>/b.txt b/<TMP>/b.txt\ndeleted file mode {baseline_mode}\nindex {ZERO_OID}..{ZERO_OID}\n--- a/<TMP>/b.txt\n+++ /dev/null\n@@ -1 +0,0 @@\n-z\n",
|
||||
)
|
||||
};
|
||||
assert_eq!(combined, expected);
|
||||
}
|
||||
}
|
||||
@@ -20,6 +20,7 @@ use codex_core::protocol::PatchApplyEndEvent;
|
||||
use codex_core::protocol::SessionConfiguredEvent;
|
||||
use codex_core::protocol::TaskCompleteEvent;
|
||||
use codex_core::protocol::TokenUsage;
|
||||
use codex_core::protocol::TurnDiffEvent;
|
||||
use owo_colors::OwoColorize;
|
||||
use owo_colors::Style;
|
||||
use shlex::try_join;
|
||||
@@ -402,6 +403,7 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
stdout,
|
||||
stderr,
|
||||
success,
|
||||
..
|
||||
}) => {
|
||||
let patch_begin = self.call_id_to_patch.remove(&call_id);
|
||||
|
||||
@@ -431,6 +433,10 @@ impl EventProcessor for EventProcessorWithHumanOutput {
|
||||
println!("{}", line.style(self.dimmed));
|
||||
}
|
||||
}
|
||||
EventMsg::TurnDiff(TurnDiffEvent { unified_diff }) => {
|
||||
ts_println!(self, "{}", "turn diff:".style(self.magenta));
|
||||
println!("{unified_diff}");
|
||||
}
|
||||
EventMsg::ExecApprovalRequest(_) => {
|
||||
// Should we exit?
|
||||
}
|
||||
|
||||
@@ -44,6 +44,8 @@ async fn run_cmd(cmd: &[&str], writable_roots: &[PathBuf], timeout_ms: u64) {
|
||||
cwd: std::env::current_dir().expect("cwd should exist"),
|
||||
timeout_ms: Some(timeout_ms),
|
||||
env: create_env_from_core_vars(),
|
||||
with_escalated_permissions: None,
|
||||
justification: None,
|
||||
};
|
||||
|
||||
let sandbox_policy = SandboxPolicy::WorkspaceWrite {
|
||||
@@ -137,6 +139,8 @@ async fn assert_network_blocked(cmd: &[&str]) {
|
||||
// do not stall the suite.
|
||||
timeout_ms: Some(NETWORK_TIMEOUT_MS),
|
||||
env: create_env_from_core_vars(),
|
||||
with_escalated_permissions: None,
|
||||
justification: None,
|
||||
};
|
||||
|
||||
let sandbox_policy = SandboxPolicy::new_read_only_policy();
|
||||
|
||||
@@ -262,6 +262,7 @@ async fn run_codex_tool_session_inner(
|
||||
| EventMsg::BackgroundEvent(_)
|
||||
| EventMsg::PatchApplyBegin(_)
|
||||
| EventMsg::PatchApplyEnd(_)
|
||||
| EventMsg::TurnDiff(_)
|
||||
| EventMsg::GetHistoryEntryResponse(_)
|
||||
| EventMsg::PlanUpdate(_)
|
||||
| EventMsg::ShutdownComplete => {
|
||||
|
||||
Reference in New Issue
Block a user