mirror of
https://github.com/openai/codex.git
synced 2026-05-06 14:21:08 +03:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0aca760ba8 |
15
codex-rs/Cargo.lock
generated
15
codex-rs/Cargo.lock
generated
@@ -2374,6 +2374,7 @@ dependencies = [
|
||||
"codex-feedback",
|
||||
"codex-git-utils",
|
||||
"codex-hooks",
|
||||
"codex-kernel",
|
||||
"codex-login",
|
||||
"codex-mcp",
|
||||
"codex-memories-read",
|
||||
@@ -2815,6 +2816,20 @@ dependencies = [
|
||||
"tempfile",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-kernel"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"codex-protocol",
|
||||
"codex-tools",
|
||||
"pretty_assertions",
|
||||
"rand 0.9.3",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-keyring-store"
|
||||
version = "0.0.0"
|
||||
|
||||
@@ -43,6 +43,7 @@ members = [
|
||||
"external-agent-migration",
|
||||
"external-agent-sessions",
|
||||
"keyring-store",
|
||||
"kernel",
|
||||
"file-search",
|
||||
"linux-sandbox",
|
||||
"lmstudio",
|
||||
@@ -158,6 +159,7 @@ codex-file-search = { path = "file-search" }
|
||||
codex-git-utils = { path = "git-utils" }
|
||||
codex-hooks = { path = "hooks" }
|
||||
codex-keyring-store = { path = "keyring-store" }
|
||||
codex-kernel = { path = "kernel" }
|
||||
codex-linux-sandbox = { path = "linux-sandbox" }
|
||||
codex-lmstudio = { path = "lmstudio" }
|
||||
codex-login = { path = "login" }
|
||||
|
||||
@@ -331,10 +331,20 @@ async fn turn_start_emits_thread_scoped_warning_notification_for_trimmed_skills(
|
||||
let warning: WarningNotification =
|
||||
serde_json::from_value(params).expect("deserialize warning notification");
|
||||
assert_eq!(warning.thread_id.as_deref(), Some(thread.id.as_str()));
|
||||
assert_eq!(
|
||||
warning.message,
|
||||
"Exceeded skills context budget of 2%. All skill descriptions were removed and 7 additional skills were not included in the model-visible skills list."
|
||||
);
|
||||
let prefix = "Exceeded skills context budget of 2%. All skill descriptions were removed and ";
|
||||
let plural_suffix = " additional skills were not included in the model-visible skills list.";
|
||||
let singular_suffix = " additional skill was not included in the model-visible skills list.";
|
||||
let omitted_count = warning
|
||||
.message
|
||||
.strip_prefix(prefix)
|
||||
.and_then(|message| {
|
||||
message
|
||||
.strip_suffix(plural_suffix)
|
||||
.or_else(|| message.strip_suffix(singular_suffix))
|
||||
})
|
||||
.and_then(|count| count.parse::<usize>().ok())
|
||||
.expect("warning should report an omitted skill count");
|
||||
assert!(omitted_count >= 1, "expected at least one omitted skill");
|
||||
|
||||
timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
|
||||
@@ -47,6 +47,7 @@ codex-shell-command = { workspace = true }
|
||||
codex-execpolicy = { workspace = true }
|
||||
codex-git-utils = { workspace = true }
|
||||
codex-hooks = { workspace = true }
|
||||
codex-kernel = { workspace = true }
|
||||
codex-network-proxy = { workspace = true }
|
||||
codex-otel = { workspace = true }
|
||||
codex-plugin = { workspace = true }
|
||||
|
||||
@@ -1,14 +1,7 @@
|
||||
pub use codex_api::ResponseEvent;
|
||||
use codex_config::types::Personality;
|
||||
pub use codex_kernel::Prompt;
|
||||
use codex_protocol::error::Result;
|
||||
use codex_protocol::models::BaseInstructions;
|
||||
use codex_protocol::models::FunctionCallOutputBody;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_tools::ToolSpec;
|
||||
use futures::Stream;
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashSet;
|
||||
use std::pin::Pin;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
@@ -23,157 +16,6 @@ pub const REVIEW_EXIT_SUCCESS_TMPL: &str = include_str!("../templates/review/exi
|
||||
pub const REVIEW_EXIT_INTERRUPTED_TMPL: &str =
|
||||
include_str!("../templates/review/exit_interrupted.xml");
|
||||
|
||||
/// API request payload for a single model turn
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Prompt {
|
||||
/// Conversation context input items.
|
||||
pub input: Vec<ResponseItem>,
|
||||
|
||||
/// Tools available to the model, including additional tools sourced from
|
||||
/// external MCP servers.
|
||||
pub(crate) tools: Vec<ToolSpec>,
|
||||
|
||||
/// Whether parallel tool calls are permitted for this prompt.
|
||||
pub(crate) parallel_tool_calls: bool,
|
||||
|
||||
pub base_instructions: BaseInstructions,
|
||||
|
||||
/// Optionally specify the personality of the model.
|
||||
pub personality: Option<Personality>,
|
||||
|
||||
/// Optional the output schema for the model's response.
|
||||
pub output_schema: Option<Value>,
|
||||
|
||||
/// Whether the Responses API should strictly validate `output_schema`.
|
||||
pub output_schema_strict: bool,
|
||||
}
|
||||
|
||||
impl Default for Prompt {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
input: Vec::new(),
|
||||
tools: Vec::new(),
|
||||
parallel_tool_calls: false,
|
||||
base_instructions: BaseInstructions::default(),
|
||||
personality: None,
|
||||
output_schema: None,
|
||||
output_schema_strict: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Prompt {
|
||||
pub(crate) fn get_formatted_input(&self) -> Vec<ResponseItem> {
|
||||
let mut input = self.input.clone();
|
||||
|
||||
// when using the *Freeform* apply_patch tool specifically, tool outputs
|
||||
// should be structured text, not json. Do NOT reserialize when using
|
||||
// the Function tool - note that this differs from the check above for
|
||||
// instructions. We declare the result as a named variable for clarity.
|
||||
let is_freeform_apply_patch_tool_present = self.tools.iter().any(|tool| match tool {
|
||||
ToolSpec::Freeform(f) => f.name == "apply_patch",
|
||||
_ => false,
|
||||
});
|
||||
if is_freeform_apply_patch_tool_present {
|
||||
reserialize_shell_outputs(&mut input);
|
||||
}
|
||||
|
||||
input
|
||||
}
|
||||
}
|
||||
|
||||
fn reserialize_shell_outputs(items: &mut [ResponseItem]) {
|
||||
let mut shell_call_ids: HashSet<String> = HashSet::new();
|
||||
|
||||
items.iter_mut().for_each(|item| match item {
|
||||
ResponseItem::LocalShellCall { call_id, id, .. } => {
|
||||
if let Some(identifier) = call_id.clone().or_else(|| id.clone()) {
|
||||
shell_call_ids.insert(identifier);
|
||||
}
|
||||
}
|
||||
ResponseItem::CustomToolCall {
|
||||
id: _,
|
||||
status: _,
|
||||
call_id,
|
||||
name,
|
||||
input: _,
|
||||
} => {
|
||||
if name == "apply_patch" {
|
||||
shell_call_ids.insert(call_id.clone());
|
||||
}
|
||||
}
|
||||
ResponseItem::FunctionCall { name, call_id, .. }
|
||||
if is_shell_tool_name(name) || name == "apply_patch" =>
|
||||
{
|
||||
shell_call_ids.insert(call_id.clone());
|
||||
}
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id, output, ..
|
||||
}
|
||||
| ResponseItem::CustomToolCallOutput {
|
||||
call_id, output, ..
|
||||
} => {
|
||||
if shell_call_ids.remove(call_id)
|
||||
&& let Some(structured) = output
|
||||
.text_content()
|
||||
.and_then(parse_structured_shell_output)
|
||||
{
|
||||
output.body = FunctionCallOutputBody::Text(structured);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
})
|
||||
}
|
||||
|
||||
fn is_shell_tool_name(name: &str) -> bool {
|
||||
matches!(name, "shell" | "container.exec")
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ExecOutputJson {
|
||||
output: String,
|
||||
metadata: ExecOutputMetadataJson,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ExecOutputMetadataJson {
|
||||
exit_code: i32,
|
||||
duration_seconds: f32,
|
||||
}
|
||||
|
||||
fn parse_structured_shell_output(raw: &str) -> Option<String> {
|
||||
let parsed: ExecOutputJson = serde_json::from_str(raw).ok()?;
|
||||
Some(build_structured_output(&parsed))
|
||||
}
|
||||
|
||||
fn build_structured_output(parsed: &ExecOutputJson) -> String {
|
||||
let mut sections = Vec::new();
|
||||
sections.push(format!("Exit code: {}", parsed.metadata.exit_code));
|
||||
sections.push(format!(
|
||||
"Wall time: {} seconds",
|
||||
parsed.metadata.duration_seconds
|
||||
));
|
||||
|
||||
let mut output = parsed.output.clone();
|
||||
if let Some((stripped, total_lines)) = strip_total_output_header(&parsed.output) {
|
||||
sections.push(format!("Total output lines: {total_lines}"));
|
||||
output = stripped.to_string();
|
||||
}
|
||||
|
||||
sections.push("Output:".to_string());
|
||||
sections.push(output);
|
||||
|
||||
sections.join("\n")
|
||||
}
|
||||
|
||||
fn strip_total_output_header(output: &str) -> Option<(&str, u32)> {
|
||||
let after_prefix = output.strip_prefix("Total output lines: ")?;
|
||||
let (total_segment, remainder) = after_prefix.split_once('\n')?;
|
||||
let total_lines = total_segment.parse::<u32>().ok()?;
|
||||
let remainder = remainder.strip_prefix('\n').unwrap_or(remainder);
|
||||
Some((remainder, total_lines))
|
||||
}
|
||||
|
||||
pub struct ResponseStream {
|
||||
pub(crate) rx_event: mpsc::Receiver<Result<ResponseEvent>>,
|
||||
/// Signals the mapper task that the consumer stopped polling before the
|
||||
|
||||
@@ -3,11 +3,9 @@ use codex_api::ResponsesApiRequest;
|
||||
use codex_api::TextControls;
|
||||
use codex_api::create_text_param_for_request;
|
||||
use codex_protocol::config_types::ServiceTier;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn serializes_text_verbosity_when_set() {
|
||||
let input: Vec<ResponseItem> = vec![];
|
||||
@@ -166,65 +164,3 @@ fn serializes_flex_service_tier_when_set() {
|
||||
Some("flex")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reserializes_shell_outputs_for_function_and_custom_tool_calls() {
|
||||
let raw_output = r#"{"output":"hello","metadata":{"exit_code":0,"duration_seconds":0.5}}"#;
|
||||
let expected_output = "Exit code: 0\nWall time: 0.5 seconds\nOutput:\nhello";
|
||||
let mut items = vec![
|
||||
ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: "shell".to_string(),
|
||||
namespace: None,
|
||||
arguments: "{}".to_string(),
|
||||
call_id: "call-1".to_string(),
|
||||
},
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: "call-1".to_string(),
|
||||
output: FunctionCallOutputPayload::from_text(raw_output.to_string()),
|
||||
},
|
||||
ResponseItem::CustomToolCall {
|
||||
id: None,
|
||||
status: None,
|
||||
call_id: "call-2".to_string(),
|
||||
name: "apply_patch".to_string(),
|
||||
input: "*** Begin Patch".to_string(),
|
||||
},
|
||||
ResponseItem::CustomToolCallOutput {
|
||||
call_id: "call-2".to_string(),
|
||||
name: None,
|
||||
output: FunctionCallOutputPayload::from_text(raw_output.to_string()),
|
||||
},
|
||||
];
|
||||
|
||||
reserialize_shell_outputs(&mut items);
|
||||
|
||||
assert_eq!(
|
||||
items,
|
||||
vec![
|
||||
ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: "shell".to_string(),
|
||||
namespace: None,
|
||||
arguments: "{}".to_string(),
|
||||
call_id: "call-1".to_string(),
|
||||
},
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: "call-1".to_string(),
|
||||
output: FunctionCallOutputPayload::from_text(expected_output.to_string()),
|
||||
},
|
||||
ResponseItem::CustomToolCall {
|
||||
id: None,
|
||||
status: None,
|
||||
call_id: "call-2".to_string(),
|
||||
name: "apply_patch".to_string(),
|
||||
input: "*** Begin Patch".to_string(),
|
||||
},
|
||||
ResponseItem::CustomToolCallOutput {
|
||||
call_id: "call-2".to_string(),
|
||||
name: None,
|
||||
output: FunctionCallOutputPayload::from_text(expected_output.to_string()),
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
@@ -188,6 +188,7 @@ mod mcp;
|
||||
mod multi_agents;
|
||||
mod review;
|
||||
mod rollout_reconstruction;
|
||||
mod sampling_loop;
|
||||
#[allow(clippy::module_inception)]
|
||||
pub(crate) mod session;
|
||||
pub(crate) mod turn;
|
||||
|
||||
221
codex-rs/core/src/session/sampling_loop.rs
Normal file
221
codex-rs/core/src/session/sampling_loop.rs
Normal file
@@ -0,0 +1,221 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::SkillLoadOutcome;
|
||||
use crate::client::ModelClientSession;
|
||||
use crate::session::session::Session;
|
||||
use crate::session::turn::build_prompt_config;
|
||||
use crate::session::turn::build_tool_config;
|
||||
use crate::session::turn::built_tools;
|
||||
use crate::session::turn::try_run_sampling_request;
|
||||
use crate::session::turn_context::TurnContext;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use crate::tools::parallel::ToolCallRuntime;
|
||||
use codex_kernel::PreparedSamplingRequest;
|
||||
use codex_kernel::Prompt;
|
||||
use codex_kernel::SamplingLoopHost;
|
||||
use codex_kernel::SamplingRequestResult;
|
||||
use codex_kernel::run_sampling_request_loop;
|
||||
use codex_protocol::error::CodexErr;
|
||||
use codex_protocol::error::Result as CodexResult;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::CodexErrorInfo;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
use codex_protocol::protocol::RateLimitSnapshot;
|
||||
use codex_protocol::protocol::StreamErrorEvent;
|
||||
use codex_protocol::protocol::WarningEvent;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::instrument;
|
||||
use tracing::warn;
|
||||
|
||||
struct SamplingRequestRuntime {
|
||||
_code_mode_worker: Option<codex_code_mode::CodeModeTurnWorker>,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[instrument(
|
||||
level = "trace",
|
||||
skip_all,
|
||||
fields(
|
||||
turn_id = %turn_context.sub_id,
|
||||
model = %turn_context.model_info.slug,
|
||||
cwd = %turn_context.cwd.display()
|
||||
)
|
||||
)]
|
||||
pub(super) async fn run_sampling_request(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
client_session: &mut ModelClientSession,
|
||||
turn_metadata_header: Option<&str>,
|
||||
input: Vec<ResponseItem>,
|
||||
explicitly_enabled_connectors: &HashSet<String>,
|
||||
skills_outcome: Option<&SkillLoadOutcome>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> CodexResult<SamplingRequestResult> {
|
||||
let host = CoreSamplingLoopHost {
|
||||
sess,
|
||||
turn_context,
|
||||
turn_diff_tracker,
|
||||
turn_metadata_header,
|
||||
explicitly_enabled_connectors,
|
||||
skills_outcome,
|
||||
};
|
||||
|
||||
run_sampling_request_loop(&host, client_session, input, cancellation_token).await
|
||||
}
|
||||
|
||||
struct CoreSamplingLoopHost<'a> {
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
turn_metadata_header: Option<&'a str>,
|
||||
explicitly_enabled_connectors: &'a HashSet<String>,
|
||||
skills_outcome: Option<&'a SkillLoadOutcome>,
|
||||
}
|
||||
|
||||
impl SamplingLoopHost for CoreSamplingLoopHost<'_> {
|
||||
type ClientSession = ModelClientSession;
|
||||
type Runtime = SamplingRequestRuntime;
|
||||
type Tools = ToolCallRuntime;
|
||||
|
||||
async fn prepare_sampling_request(
|
||||
&self,
|
||||
input: &[ResponseItem],
|
||||
cancellation_token: &CancellationToken,
|
||||
) -> CodexResult<PreparedSamplingRequest<Self::Runtime, Self::Tools>> {
|
||||
let router = built_tools(
|
||||
self.sess.as_ref(),
|
||||
self.turn_context.as_ref(),
|
||||
input,
|
||||
self.explicitly_enabled_connectors,
|
||||
self.skills_outcome,
|
||||
cancellation_token,
|
||||
)
|
||||
.await?;
|
||||
let base_instructions = self.sess.get_base_instructions().await;
|
||||
let tool_runtime = ToolCallRuntime::new(
|
||||
Arc::clone(&router),
|
||||
Arc::clone(&self.sess),
|
||||
Arc::clone(&self.turn_context),
|
||||
Arc::clone(&self.turn_diff_tracker),
|
||||
);
|
||||
let code_mode_worker = self
|
||||
.sess
|
||||
.services
|
||||
.code_mode_service
|
||||
.start_turn_worker(
|
||||
&self.sess,
|
||||
&self.turn_context,
|
||||
Arc::clone(&router),
|
||||
Arc::clone(&self.turn_diff_tracker),
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(PreparedSamplingRequest {
|
||||
prompt_config: build_prompt_config(self.turn_context.as_ref(), base_instructions),
|
||||
tool_config: build_tool_config(router.as_ref(), self.turn_context.as_ref()),
|
||||
runtime: SamplingRequestRuntime {
|
||||
_code_mode_worker: code_mode_worker,
|
||||
},
|
||||
tools: tool_runtime,
|
||||
})
|
||||
}
|
||||
|
||||
async fn history_prompt_input(&self) -> Vec<ResponseItem> {
|
||||
self.sess
|
||||
.clone_history()
|
||||
.await
|
||||
.for_prompt(&self.turn_context.model_info.input_modalities)
|
||||
}
|
||||
|
||||
async fn run_single_sampling_request(
|
||||
&self,
|
||||
_runtime: &Self::Runtime,
|
||||
tools: &Self::Tools,
|
||||
client_session: &mut Self::ClientSession,
|
||||
prompt: &Prompt,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> CodexResult<SamplingRequestResult> {
|
||||
try_run_sampling_request(
|
||||
tools.clone(),
|
||||
Arc::clone(&self.sess),
|
||||
Arc::clone(&self.turn_context),
|
||||
client_session,
|
||||
self.turn_metadata_header,
|
||||
Arc::clone(&self.turn_diff_tracker),
|
||||
prompt,
|
||||
cancellation_token,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
fn stream_max_retries(&self) -> u64 {
|
||||
self.turn_context.provider.info().stream_max_retries()
|
||||
}
|
||||
|
||||
fn try_switch_fallback_transport(&self, client_session: &mut Self::ClientSession) -> bool {
|
||||
client_session.try_switch_fallback_transport(
|
||||
&self.turn_context.session_telemetry,
|
||||
&self.turn_context.model_info,
|
||||
)
|
||||
}
|
||||
|
||||
fn should_notify_stream_retry(&self, retries: u64, _err: &CodexErr) -> bool {
|
||||
retries > 1
|
||||
|| cfg!(debug_assertions)
|
||||
|| !self
|
||||
.sess
|
||||
.services
|
||||
.model_client
|
||||
.responses_websocket_enabled()
|
||||
}
|
||||
|
||||
async fn handle_context_window_exceeded(&self) {
|
||||
self.sess.set_total_tokens_full(&self.turn_context).await;
|
||||
}
|
||||
|
||||
async fn handle_usage_limit_reached(&self, rate_limits: Option<RateLimitSnapshot>) {
|
||||
if let Some(rate_limits) = rate_limits {
|
||||
self.sess
|
||||
.update_rate_limits(&self.turn_context, rate_limits)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn notify_fallback_to_http(&self, err: &CodexErr) {
|
||||
self.sess
|
||||
.send_event(
|
||||
&self.turn_context,
|
||||
EventMsg::Warning(WarningEvent {
|
||||
message: format!("Falling back from WebSockets to HTTPS transport. {err:#}"),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn notify_stream_retry(
|
||||
&self,
|
||||
retries: u64,
|
||||
max_retries: u64,
|
||||
delay: Duration,
|
||||
err: &CodexErr,
|
||||
) {
|
||||
warn!(
|
||||
"stream disconnected - retrying sampling request ({retries}/{max_retries} in {delay:?})..."
|
||||
);
|
||||
self.sess
|
||||
.send_event(
|
||||
&self.turn_context,
|
||||
EventMsg::StreamError(StreamErrorEvent {
|
||||
message: format!("Reconnecting... {retries}/{max_retries}"),
|
||||
codex_error_info: Some(CodexErrorInfo::ResponseStreamDisconnected {
|
||||
http_status_code: err.http_status_code_value(),
|
||||
}),
|
||||
additional_details: Some(err.to_string()),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
@@ -40,6 +40,7 @@ use crate::parse_turn_item;
|
||||
use crate::plugins::build_plugin_injections;
|
||||
use crate::resolve_skill_dependencies_for_turn;
|
||||
use crate::session::PreviousTurnSettings;
|
||||
use crate::session::sampling_loop::run_sampling_request;
|
||||
use crate::session::session::Session;
|
||||
use crate::session::turn_context::TurnContext;
|
||||
use crate::stream_events_utils::HandleOutputCtx;
|
||||
@@ -57,7 +58,6 @@ use crate::tools::router::ToolRouterParams;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use crate::turn_timing::record_turn_ttft_metric;
|
||||
use crate::unavailable_tool::collect_unavailable_called_tools;
|
||||
use crate::util::backoff;
|
||||
use crate::util::error_or_panic;
|
||||
use codex_analytics::AppInvocation;
|
||||
use codex_analytics::CompactionPhase;
|
||||
@@ -71,6 +71,9 @@ use codex_hooks::HookEvent;
|
||||
use codex_hooks::HookEventAfterAgent;
|
||||
use codex_hooks::HookPayload;
|
||||
use codex_hooks::HookResult;
|
||||
use codex_kernel::PromptConfig;
|
||||
use codex_kernel::SamplingRequestResult;
|
||||
use codex_kernel::ToolConfig;
|
||||
use codex_protocol::config_types::ModeKind;
|
||||
use codex_protocol::error::CodexErr;
|
||||
use codex_protocol::error::Result as CodexResult;
|
||||
@@ -932,16 +935,11 @@ fn connector_inserted_in_messages(
|
||||
connector_count == 1 && skill_count == 0 && mention_names_lower.contains(&mention_slug)
|
||||
}
|
||||
|
||||
pub(crate) fn build_prompt(
|
||||
input: Vec<ResponseItem>,
|
||||
router: &ToolRouter,
|
||||
pub(super) fn build_prompt_config(
|
||||
turn_context: &TurnContext,
|
||||
base_instructions: BaseInstructions,
|
||||
) -> Prompt {
|
||||
Prompt {
|
||||
input,
|
||||
tools: router.model_visible_specs(),
|
||||
parallel_tool_calls: turn_context.model_info.supports_parallel_tool_calls,
|
||||
) -> PromptConfig {
|
||||
PromptConfig {
|
||||
base_instructions,
|
||||
personality: turn_context.personality,
|
||||
output_schema: turn_context.final_output_json_schema.clone(),
|
||||
@@ -951,156 +949,23 @@ pub(crate) fn build_prompt(
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[instrument(level = "trace",
|
||||
skip_all,
|
||||
fields(
|
||||
turn_id = %turn_context.sub_id,
|
||||
model = %turn_context.model_info.slug,
|
||||
cwd = %turn_context.cwd.display()
|
||||
)
|
||||
)]
|
||||
async fn run_sampling_request(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
client_session: &mut ModelClientSession,
|
||||
turn_metadata_header: Option<&str>,
|
||||
input: Vec<ResponseItem>,
|
||||
explicitly_enabled_connectors: &HashSet<String>,
|
||||
skills_outcome: Option<&SkillLoadOutcome>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> CodexResult<SamplingRequestResult> {
|
||||
let router = built_tools(
|
||||
sess.as_ref(),
|
||||
turn_context.as_ref(),
|
||||
&input,
|
||||
explicitly_enabled_connectors,
|
||||
skills_outcome,
|
||||
&cancellation_token,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let base_instructions = sess.get_base_instructions().await;
|
||||
|
||||
let tool_runtime = ToolCallRuntime::new(
|
||||
Arc::clone(&router),
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
);
|
||||
let _code_mode_worker = sess
|
||||
.services
|
||||
.code_mode_service
|
||||
.start_turn_worker(
|
||||
&sess,
|
||||
&turn_context,
|
||||
Arc::clone(&router),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
)
|
||||
.await;
|
||||
let mut retries = 0;
|
||||
let mut initial_input = Some(input);
|
||||
loop {
|
||||
let prompt_input = if let Some(input) = initial_input.take() {
|
||||
input
|
||||
} else {
|
||||
sess.clone_history()
|
||||
.await
|
||||
.for_prompt(&turn_context.model_info.input_modalities)
|
||||
};
|
||||
let prompt = build_prompt(
|
||||
prompt_input,
|
||||
router.as_ref(),
|
||||
turn_context.as_ref(),
|
||||
base_instructions.clone(),
|
||||
);
|
||||
let err = match try_run_sampling_request(
|
||||
tool_runtime.clone(),
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
client_session,
|
||||
turn_metadata_header,
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
&prompt,
|
||||
cancellation_token.child_token(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(output) => {
|
||||
return Ok(output);
|
||||
}
|
||||
Err(CodexErr::ContextWindowExceeded) => {
|
||||
sess.set_total_tokens_full(&turn_context).await;
|
||||
return Err(CodexErr::ContextWindowExceeded);
|
||||
}
|
||||
Err(CodexErr::UsageLimitReached(e)) => {
|
||||
let rate_limits = e.rate_limits.clone();
|
||||
if let Some(rate_limits) = rate_limits {
|
||||
sess.update_rate_limits(&turn_context, *rate_limits).await;
|
||||
}
|
||||
return Err(CodexErr::UsageLimitReached(e));
|
||||
}
|
||||
Err(err) => err,
|
||||
};
|
||||
|
||||
if !err.is_retryable() {
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
// Use the configured provider-specific stream retry budget.
|
||||
let max_retries = turn_context.provider.info().stream_max_retries();
|
||||
if retries >= max_retries
|
||||
&& client_session.try_switch_fallback_transport(
|
||||
&turn_context.session_telemetry,
|
||||
&turn_context.model_info,
|
||||
)
|
||||
{
|
||||
sess.send_event(
|
||||
&turn_context,
|
||||
EventMsg::Warning(WarningEvent {
|
||||
message: format!("Falling back from WebSockets to HTTPS transport. {err:#}"),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
retries = 0;
|
||||
continue;
|
||||
}
|
||||
if retries < max_retries {
|
||||
retries += 1;
|
||||
let delay = match &err {
|
||||
CodexErr::Stream(_, requested_delay) => {
|
||||
requested_delay.unwrap_or_else(|| backoff(retries))
|
||||
}
|
||||
_ => backoff(retries),
|
||||
};
|
||||
warn!(
|
||||
"stream disconnected - retrying sampling request ({retries}/{max_retries} in {delay:?})...",
|
||||
);
|
||||
|
||||
// In release builds, hide the first websocket retry notification to reduce noisy
|
||||
// transient reconnect messages. In debug builds, keep full visibility for diagnosis.
|
||||
let report_error = retries > 1
|
||||
|| cfg!(debug_assertions)
|
||||
|| !sess.services.model_client.responses_websocket_enabled();
|
||||
if report_error {
|
||||
// Surface retry information to any UI/front‑end so the
|
||||
// user understands what is happening instead of staring
|
||||
// at a seemingly frozen screen.
|
||||
sess.notify_stream_error(
|
||||
&turn_context,
|
||||
format!("Reconnecting... {retries}/{max_retries}"),
|
||||
err,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
tokio::time::sleep(delay).await;
|
||||
} else {
|
||||
return Err(err);
|
||||
}
|
||||
pub(super) fn build_tool_config(router: &ToolRouter, turn_context: &TurnContext) -> ToolConfig {
|
||||
ToolConfig {
|
||||
tools: router.model_visible_specs(),
|
||||
parallel_tool_calls: turn_context.model_info.supports_parallel_tool_calls,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn build_prompt(
|
||||
input: Vec<ResponseItem>,
|
||||
router: &ToolRouter,
|
||||
turn_context: &TurnContext,
|
||||
base_instructions: BaseInstructions,
|
||||
) -> Prompt {
|
||||
build_prompt_config(turn_context, base_instructions)
|
||||
.build_prompt(input, &build_tool_config(router, turn_context))
|
||||
}
|
||||
|
||||
#[expect(
|
||||
clippy::await_holding_invalid_type,
|
||||
reason = "tool router construction reads through the session-owned manager guard"
|
||||
@@ -1243,12 +1108,6 @@ pub(crate) async fn built_tools(
|
||||
)))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct SamplingRequestResult {
|
||||
needs_follow_up: bool,
|
||||
last_agent_message: Option<String>,
|
||||
}
|
||||
|
||||
/// Ephemeral per-response state for streaming a single proposed plan.
|
||||
/// This is intentionally not persisted or stored in session/state since it
|
||||
/// only exists while a response is actively streaming. The final plan text
|
||||
@@ -1806,7 +1665,7 @@ async fn drain_in_flight(
|
||||
model = %turn_context.model_info.slug
|
||||
)
|
||||
)]
|
||||
async fn try_run_sampling_request(
|
||||
pub(super) async fn try_run_sampling_request(
|
||||
tool_runtime: ToolCallRuntime,
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
|
||||
@@ -7,21 +7,22 @@ use codex_protocol::config_types::ModeKind;
|
||||
use codex_protocol::items::TurnItem;
|
||||
use codex_utils_stream_parser::strip_citations;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::AbortOnDropHandle;
|
||||
|
||||
use crate::context::ContextualUserFragment;
|
||||
use crate::context::ImageGenerationInstructions;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::parse_turn_item;
|
||||
use crate::session::session::Session;
|
||||
use crate::session::turn_context::TurnContext;
|
||||
use crate::tools::parallel::ToolCallRuntime;
|
||||
use crate::tools::router::ToolRouter;
|
||||
use codex_kernel::CompletedResponseItem;
|
||||
use codex_kernel::KernelToolExecutor;
|
||||
use codex_kernel::execute_tool_call_with_default_output;
|
||||
use codex_kernel::response_input_to_response_item;
|
||||
use codex_memories_read::citations::parse_memory_citation;
|
||||
use codex_memories_read::citations::thread_ids_from_memory_citation;
|
||||
use codex_protocol::error::CodexErr;
|
||||
use codex_protocol::error::Result;
|
||||
use codex_protocol::models::FunctionCallOutputBody;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::MessagePhase;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
@@ -225,9 +226,8 @@ pub(crate) async fn handle_output_item_done(
|
||||
let mut output = OutputItemResult::default();
|
||||
let plan_mode = ctx.turn_context.collaboration_mode.mode == ModeKind::Plan;
|
||||
|
||||
match ToolRouter::build_tool_call(ctx.sess.as_ref(), item.clone()).await {
|
||||
// The model emitted a tool call; log it, persist the item immediately, and queue the tool execution.
|
||||
Ok(Some(call)) => {
|
||||
match ctx.tool_runtime.classify_response_item(item).await? {
|
||||
CompletedResponseItem::ToolCall { item, call } => {
|
||||
ctx.sess
|
||||
.accept_mailbox_delivery_for_current_turn(&ctx.turn_context.sub_id)
|
||||
.await;
|
||||
@@ -244,17 +244,20 @@ pub(crate) async fn handle_output_item_done(
|
||||
.await;
|
||||
|
||||
let cancellation_token = ctx.cancellation_token.child_token();
|
||||
let tool_future: InFlightFuture<'static> = Box::pin(
|
||||
ctx.tool_runtime
|
||||
.clone()
|
||||
.handle_tool_call(call, cancellation_token),
|
||||
);
|
||||
let tool_runtime = ctx.tool_runtime.clone();
|
||||
let handle = AbortOnDropHandle::new(tokio::spawn(async move {
|
||||
execute_tool_call_with_default_output(&tool_runtime, call, cancellation_token).await
|
||||
}));
|
||||
let tool_future: InFlightFuture<'static> = Box::pin(async move {
|
||||
handle.await.map_err(|err| {
|
||||
CodexErr::Fatal(format!("tool task failed to receive: {err:?}"))
|
||||
})?
|
||||
});
|
||||
|
||||
output.needs_follow_up = true;
|
||||
output.tool_future = Some(tool_future);
|
||||
}
|
||||
// No tool call: convert messages/reasoning into turn items and mark them as complete.
|
||||
Ok(None) => {
|
||||
CompletedResponseItem::NonTool(item) => {
|
||||
if let Some(turn_item) = handle_non_tool_response_item(
|
||||
ctx.sess.as_ref(),
|
||||
ctx.turn_context.as_ref(),
|
||||
@@ -286,21 +289,7 @@ pub(crate) async fn handle_output_item_done(
|
||||
|
||||
output.last_agent_message = last_agent_message;
|
||||
}
|
||||
// Guardrail: the model issued a LocalShellCall without an id; surface the error back into history.
|
||||
Err(FunctionCallError::MissingLocalShellCallId) => {
|
||||
let msg = "LocalShellCall without call_id or id";
|
||||
ctx.turn_context
|
||||
.session_telemetry
|
||||
.log_tool_failed("local_shell", msg);
|
||||
tracing::error!(msg);
|
||||
|
||||
let response = ResponseInputItem::FunctionCallOutput {
|
||||
call_id: String::new(),
|
||||
output: FunctionCallOutputPayload {
|
||||
body: FunctionCallOutputBody::Text(msg.to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
};
|
||||
CompletedResponseItem::ImmediateResponse { item, response } => {
|
||||
record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item)
|
||||
.await;
|
||||
if let Some(response_item) = response_input_to_response_item(&response) {
|
||||
@@ -314,32 +303,6 @@ pub(crate) async fn handle_output_item_done(
|
||||
|
||||
output.needs_follow_up = true;
|
||||
}
|
||||
// The tool request should be answered directly (or was denied); push that response into the transcript.
|
||||
Err(FunctionCallError::RespondToModel(message)) => {
|
||||
let response = ResponseInputItem::FunctionCallOutput {
|
||||
call_id: String::new(),
|
||||
output: FunctionCallOutputPayload {
|
||||
body: FunctionCallOutputBody::Text(message),
|
||||
..Default::default()
|
||||
},
|
||||
};
|
||||
record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item)
|
||||
.await;
|
||||
if let Some(response_item) = response_input_to_response_item(&response) {
|
||||
ctx.sess
|
||||
.record_conversation_items(
|
||||
&ctx.turn_context,
|
||||
std::slice::from_ref(&response_item),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
output.needs_follow_up = true;
|
||||
}
|
||||
// A fatal error occurred; surface it back into history.
|
||||
Err(FunctionCallError::Fatal(message)) => {
|
||||
return Err(CodexErr::Fatal(message));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
@@ -465,45 +428,6 @@ fn completed_item_defers_mailbox_delivery_to_next_turn(
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn response_input_to_response_item(input: &ResponseInputItem) -> Option<ResponseItem> {
|
||||
match input {
|
||||
ResponseInputItem::FunctionCallOutput { call_id, output } => {
|
||||
Some(ResponseItem::FunctionCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
output: output.clone(),
|
||||
})
|
||||
}
|
||||
ResponseInputItem::CustomToolCallOutput {
|
||||
call_id,
|
||||
name,
|
||||
output,
|
||||
} => Some(ResponseItem::CustomToolCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
name: name.clone(),
|
||||
output: output.clone(),
|
||||
}),
|
||||
ResponseInputItem::McpToolCallOutput { call_id, output } => {
|
||||
let output = output.as_function_call_output_payload();
|
||||
Some(ResponseItem::FunctionCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
output,
|
||||
})
|
||||
}
|
||||
ResponseInputItem::ToolSearchOutput {
|
||||
call_id,
|
||||
status,
|
||||
execution,
|
||||
tools,
|
||||
} => Some(ResponseItem::ToolSearchOutput {
|
||||
call_id: Some(call_id.clone()),
|
||||
status: status.clone(),
|
||||
execution: execution.clone(),
|
||||
tools: tools.clone(),
|
||||
}),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "stream_events_utils_tests.rs"]
|
||||
mod tests;
|
||||
|
||||
@@ -20,8 +20,14 @@ use crate::tools::registry::ToolArgumentDiffConsumer;
|
||||
use crate::tools::router::ToolCall;
|
||||
use crate::tools::router::ToolCallSource;
|
||||
use crate::tools::router::ToolRouter;
|
||||
use codex_kernel::CompletedResponseItem;
|
||||
use codex_kernel::KernelToolCall;
|
||||
use codex_kernel::KernelToolExecutor;
|
||||
use codex_kernel::ToolCallError;
|
||||
use codex_protocol::error::CodexErr;
|
||||
use codex_protocol::error::Result as CodexResult;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_tools::ToolSpec;
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -60,25 +66,6 @@ impl ToolCallRuntime {
|
||||
self.router.create_diff_consumer(tool_name)
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub(crate) fn handle_tool_call(
|
||||
self,
|
||||
call: ToolCall,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> impl std::future::Future<Output = Result<ResponseInputItem, CodexErr>> {
|
||||
let error_call = call.clone();
|
||||
let future =
|
||||
self.handle_tool_call_with_source(call, ToolCallSource::Direct, cancellation_token);
|
||||
async move {
|
||||
match future.await {
|
||||
Ok(response) => Ok(response.into_response()),
|
||||
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
|
||||
Err(other) => Ok(Self::failure_response(error_call, other)),
|
||||
}
|
||||
}
|
||||
.in_current_span()
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all)]
|
||||
pub(crate) fn handle_tool_call_with_source(
|
||||
self,
|
||||
@@ -144,33 +131,6 @@ impl ToolCallRuntime {
|
||||
}
|
||||
|
||||
impl ToolCallRuntime {
|
||||
fn failure_response(call: ToolCall, err: FunctionCallError) -> ResponseInputItem {
|
||||
let message = err.to_string();
|
||||
match call.payload {
|
||||
ToolPayload::ToolSearch { .. } => ResponseInputItem::ToolSearchOutput {
|
||||
call_id: call.call_id,
|
||||
status: "completed".to_string(),
|
||||
execution: "client".to_string(),
|
||||
tools: Vec::new(),
|
||||
},
|
||||
ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput {
|
||||
call_id: call.call_id,
|
||||
name: None,
|
||||
output: codex_protocol::models::FunctionCallOutputPayload {
|
||||
body: codex_protocol::models::FunctionCallOutputBody::Text(message),
|
||||
success: Some(false),
|
||||
},
|
||||
},
|
||||
_ => ResponseInputItem::FunctionCallOutput {
|
||||
call_id: call.call_id,
|
||||
output: codex_protocol::models::FunctionCallOutputPayload {
|
||||
body: codex_protocol::models::FunctionCallOutputBody::Text(message),
|
||||
success: Some(false),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn aborted_response(call: &ToolCall, secs: f32) -> AnyToolResult {
|
||||
AnyToolResult {
|
||||
call_id: call.call_id.clone(),
|
||||
@@ -195,3 +155,90 @@ impl ToolCallRuntime {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelToolCall for ToolCall {
|
||||
fn error_response(&self, message: String) -> ResponseInputItem {
|
||||
failure_response(self, message)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelToolExecutor for ToolCallRuntime {
|
||||
type Call = ToolCall;
|
||||
|
||||
async fn classify_response_item(
|
||||
&self,
|
||||
item: ResponseItem,
|
||||
) -> CodexResult<CompletedResponseItem<Self::Call>> {
|
||||
match ToolRouter::build_tool_call(self.session.as_ref(), item.clone()).await {
|
||||
Ok(Some(call)) => Ok(CompletedResponseItem::ToolCall { item, call }),
|
||||
Ok(None) => Ok(CompletedResponseItem::NonTool(item)),
|
||||
Err(FunctionCallError::MissingLocalShellCallId) => {
|
||||
let message = "LocalShellCall without call_id or id".to_string();
|
||||
self.turn_context
|
||||
.session_telemetry
|
||||
.log_tool_failed("local_shell", &message);
|
||||
tracing::error!("{message}");
|
||||
let response = ResponseInputItem::FunctionCallOutput {
|
||||
call_id: String::new(),
|
||||
output: codex_protocol::models::FunctionCallOutputPayload::from_text(message),
|
||||
};
|
||||
Ok(CompletedResponseItem::ImmediateResponse { item, response })
|
||||
}
|
||||
Err(FunctionCallError::RespondToModel(message)) => {
|
||||
let response = ResponseInputItem::FunctionCallOutput {
|
||||
call_id: String::new(),
|
||||
output: codex_protocol::models::FunctionCallOutputPayload::from_text(message),
|
||||
};
|
||||
Ok(CompletedResponseItem::ImmediateResponse { item, response })
|
||||
}
|
||||
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute_tool_call(
|
||||
&self,
|
||||
call: Self::Call,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Result<ResponseInputItem, ToolCallError> {
|
||||
match self
|
||||
.clone()
|
||||
.handle_tool_call_with_source(call, ToolCallSource::Direct, cancellation_token)
|
||||
.await
|
||||
{
|
||||
Ok(result) => Ok(result.into_response()),
|
||||
Err(FunctionCallError::RespondToModel(message)) => {
|
||||
Err(ToolCallError::RespondToModel(message))
|
||||
}
|
||||
Err(FunctionCallError::MissingLocalShellCallId) => Err(ToolCallError::RespondToModel(
|
||||
"LocalShellCall without call_id or id".to_string(),
|
||||
)),
|
||||
Err(FunctionCallError::Fatal(message)) => Err(ToolCallError::Fatal(message)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn failure_response(call: &ToolCall, message: String) -> ResponseInputItem {
|
||||
match &call.payload {
|
||||
ToolPayload::ToolSearch { .. } => ResponseInputItem::ToolSearchOutput {
|
||||
call_id: call.call_id.clone(),
|
||||
status: "completed".to_string(),
|
||||
execution: "client".to_string(),
|
||||
tools: Vec::new(),
|
||||
},
|
||||
ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput {
|
||||
call_id: call.call_id.clone(),
|
||||
name: None,
|
||||
output: codex_protocol::models::FunctionCallOutputPayload {
|
||||
body: codex_protocol::models::FunctionCallOutputBody::Text(message),
|
||||
success: Some(false),
|
||||
},
|
||||
},
|
||||
_ => ResponseInputItem::FunctionCallOutput {
|
||||
call_id: call.call_id.clone(),
|
||||
output: codex_protocol::models::FunctionCallOutputPayload {
|
||||
body: codex_protocol::models::FunctionCallOutputBody::Text(message),
|
||||
success: Some(false),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -675,10 +675,13 @@ fn maybe_wrap_shell_lc_with_snapshot_keeps_user_proxy_env_when_proxy_inactive()
|
||||
&HashMap::new(),
|
||||
&HashMap::new(),
|
||||
);
|
||||
let output = Command::new(&rewritten[0])
|
||||
.args(&rewritten[1..])
|
||||
.output()
|
||||
.expect("run rewritten command");
|
||||
let mut process = Command::new(&rewritten[0]);
|
||||
process.args(&rewritten[1..]);
|
||||
for key in PROXY_ENV_KEYS {
|
||||
process.env_remove(key);
|
||||
}
|
||||
process.env_remove(PROXY_ACTIVE_ENV_KEY);
|
||||
let output = process.output().expect("run rewritten command");
|
||||
|
||||
assert!(output.status.success(), "command failed: {output:?}");
|
||||
assert_eq!(
|
||||
|
||||
@@ -338,23 +338,25 @@ async fn execve_permission_request_hook_short_circuits_prompt() -> anyhow::Resul
|
||||
std::fs::write(
|
||||
&script_path,
|
||||
format!(
|
||||
"#!/bin/sh\ncat > {log_path}\nprintf '%s\\n' '{response}'\n",
|
||||
log_path = shlex::try_quote(log_path.to_string_lossy().as_ref())?,
|
||||
r#"import json
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
payload = json.load(sys.stdin)
|
||||
with Path(r"{log_path}").open("a", encoding="utf-8") as handle:
|
||||
handle.write(json.dumps(payload) + "\n")
|
||||
|
||||
print({response:?})
|
||||
"#,
|
||||
log_path = log_path.display(),
|
||||
response = "{\"hookSpecificOutput\":{\"hookEventName\":\"PermissionRequest\",\"decision\":{\"behavior\":\"allow\"}}}",
|
||||
),
|
||||
)
|
||||
.with_context(|| format!("write hook script to {}", script_path.display()))?;
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
let mut permissions = std::fs::metadata(&script_path)
|
||||
.with_context(|| format!("read hook script metadata from {}", script_path.display()))?
|
||||
.permissions();
|
||||
permissions.set_mode(0o755);
|
||||
std::fs::set_permissions(&script_path, permissions)
|
||||
.with_context(|| format!("set hook script permissions on {}", script_path.display()))?;
|
||||
}
|
||||
let script_path_arg = format!(
|
||||
"'{}'",
|
||||
script_path.display().to_string().replace('\'', "'\\''")
|
||||
);
|
||||
std::fs::write(
|
||||
turn_context.config.codex_home.join("hooks.json"),
|
||||
serde_json::json!({
|
||||
@@ -362,7 +364,7 @@ async fn execve_permission_request_hook_short_circuits_prompt() -> anyhow::Resul
|
||||
"PermissionRequest": [{
|
||||
"hooks": [{
|
||||
"type": "command",
|
||||
"command": script_path.display().to_string(),
|
||||
"command": format!("python3 {script_path_arg}"),
|
||||
}]
|
||||
}]
|
||||
}
|
||||
|
||||
@@ -111,15 +111,17 @@ async fn responses_stream_includes_subagent_header_on_review() {
|
||||
);
|
||||
let mut client_session = client.new_session();
|
||||
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input = vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".into(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hello".into(),
|
||||
let prompt = Prompt {
|
||||
input: vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".into(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
phase: None,
|
||||
}],
|
||||
phase: None,
|
||||
}];
|
||||
..Prompt::default()
|
||||
};
|
||||
|
||||
let mut stream = client_session
|
||||
.stream(
|
||||
@@ -237,15 +239,17 @@ async fn responses_stream_includes_subagent_header_on_other() {
|
||||
);
|
||||
let mut client_session = client.new_session();
|
||||
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input = vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".into(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hello".into(),
|
||||
let prompt = Prompt {
|
||||
input: vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".into(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
phase: None,
|
||||
}],
|
||||
phase: None,
|
||||
}];
|
||||
..Prompt::default()
|
||||
};
|
||||
|
||||
let mut stream = client_session
|
||||
.stream(
|
||||
@@ -352,15 +356,17 @@ async fn responses_respects_model_info_overrides_from_config() {
|
||||
);
|
||||
let mut client_session = client.new_session();
|
||||
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input = vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".into(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hello".into(),
|
||||
let prompt = Prompt {
|
||||
input: vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".into(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
phase: None,
|
||||
}],
|
||||
phase: None,
|
||||
}];
|
||||
..Prompt::default()
|
||||
};
|
||||
|
||||
let mut stream = client_session
|
||||
.stream(
|
||||
|
||||
@@ -1744,9 +1744,10 @@ fn assistant_message_item(id: &str, text: &str) -> ResponseItem {
|
||||
}
|
||||
|
||||
fn prompt_with_input(input: Vec<ResponseItem>) -> Prompt {
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input = input;
|
||||
prompt
|
||||
Prompt {
|
||||
input,
|
||||
..Prompt::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn prompt_with_input_and_instructions(input: Vec<ResponseItem>, instructions: &str) -> Prompt {
|
||||
|
||||
6
codex-rs/kernel/BUILD.bazel
Normal file
6
codex-rs/kernel/BUILD.bazel
Normal file
@@ -0,0 +1,6 @@
|
||||
load("//:defs.bzl", "codex_rust_crate")
|
||||
|
||||
codex_rust_crate(
|
||||
name = "kernel",
|
||||
crate_name = "codex_kernel",
|
||||
)
|
||||
25
codex-rs/kernel/Cargo.toml
Normal file
25
codex-rs/kernel/Cargo.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
[package]
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
name = "codex-kernel"
|
||||
version.workspace = true
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "codex_kernel"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
codex-protocol = { workspace = true }
|
||||
codex-tools = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
tokio = { workspace = true, features = ["time"] }
|
||||
tokio-util = { workspace = true, features = ["rt"] }
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = { workspace = true }
|
||||
17
codex-rs/kernel/src/lib.rs
Normal file
17
codex-rs/kernel/src/lib.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
mod prompt;
|
||||
mod sampling;
|
||||
mod tools;
|
||||
|
||||
pub use prompt::Prompt;
|
||||
pub use prompt::PromptConfig;
|
||||
pub use sampling::PreparedSamplingRequest;
|
||||
pub use sampling::SamplingLoopHost;
|
||||
pub use sampling::SamplingRequestResult;
|
||||
pub use sampling::run_sampling_request_loop;
|
||||
pub use tools::CompletedResponseItem;
|
||||
pub use tools::KernelToolCall;
|
||||
pub use tools::KernelToolExecutor;
|
||||
pub use tools::ToolCallError;
|
||||
pub use tools::ToolConfig;
|
||||
pub use tools::execute_tool_call_with_default_output;
|
||||
pub use tools::response_input_to_response_item;
|
||||
257
codex-rs/kernel/src/prompt.rs
Normal file
257
codex-rs/kernel/src/prompt.rs
Normal file
@@ -0,0 +1,257 @@
|
||||
use crate::ToolConfig;
|
||||
use codex_protocol::config_types::Personality;
|
||||
use codex_protocol::models::BaseInstructions;
|
||||
use codex_protocol::models::FunctionCallOutputBody;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_tools::ToolSpec;
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// API request payload for a single model turn.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Prompt {
|
||||
pub input: Vec<ResponseItem>,
|
||||
pub tools: Vec<ToolSpec>,
|
||||
pub parallel_tool_calls: bool,
|
||||
pub base_instructions: BaseInstructions,
|
||||
pub personality: Option<Personality>,
|
||||
pub output_schema: Option<Value>,
|
||||
pub output_schema_strict: bool,
|
||||
}
|
||||
|
||||
/// Retry-stable prompt settings used to rebuild a [`Prompt`] after tool calls or stream retries.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PromptConfig {
|
||||
pub base_instructions: BaseInstructions,
|
||||
pub personality: Option<Personality>,
|
||||
pub output_schema: Option<Value>,
|
||||
pub output_schema_strict: bool,
|
||||
}
|
||||
|
||||
impl Prompt {
|
||||
pub fn get_formatted_input(&self) -> Vec<ResponseItem> {
|
||||
let mut input = self.input.clone();
|
||||
let is_freeform_apply_patch_tool_present = self.tools.iter().any(|tool| match tool {
|
||||
ToolSpec::Freeform(f) => f.name == "apply_patch",
|
||||
_ => false,
|
||||
});
|
||||
if is_freeform_apply_patch_tool_present {
|
||||
reserialize_shell_outputs(&mut input);
|
||||
}
|
||||
|
||||
input
|
||||
}
|
||||
}
|
||||
|
||||
impl PromptConfig {
|
||||
pub fn build_prompt(&self, input: Vec<ResponseItem>, tool_config: &ToolConfig) -> Prompt {
|
||||
Prompt {
|
||||
input,
|
||||
tools: tool_config.tools.clone(),
|
||||
parallel_tool_calls: tool_config.parallel_tool_calls,
|
||||
base_instructions: self.base_instructions.clone(),
|
||||
personality: self.personality,
|
||||
output_schema: self.output_schema.clone(),
|
||||
output_schema_strict: self.output_schema_strict,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Prompt {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
input: Vec::new(),
|
||||
tools: Vec::new(),
|
||||
parallel_tool_calls: false,
|
||||
base_instructions: BaseInstructions::default(),
|
||||
personality: None,
|
||||
output_schema: None,
|
||||
output_schema_strict: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PromptConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
base_instructions: BaseInstructions::default(),
|
||||
personality: None,
|
||||
output_schema: None,
|
||||
output_schema_strict: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn reserialize_shell_outputs(items: &mut [ResponseItem]) {
|
||||
let mut shell_call_ids: HashSet<String> = HashSet::new();
|
||||
|
||||
items.iter_mut().for_each(|item| match item {
|
||||
ResponseItem::LocalShellCall { call_id, id, .. } => {
|
||||
if let Some(identifier) = call_id.clone().or_else(|| id.clone()) {
|
||||
shell_call_ids.insert(identifier);
|
||||
}
|
||||
}
|
||||
ResponseItem::CustomToolCall {
|
||||
id: _,
|
||||
status: _,
|
||||
call_id,
|
||||
name,
|
||||
input: _,
|
||||
} => {
|
||||
if name == "apply_patch" {
|
||||
shell_call_ids.insert(call_id.clone());
|
||||
}
|
||||
}
|
||||
ResponseItem::FunctionCall { name, call_id, .. }
|
||||
if is_shell_tool_name(name) || name == "apply_patch" =>
|
||||
{
|
||||
shell_call_ids.insert(call_id.clone());
|
||||
}
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id, output, ..
|
||||
}
|
||||
| ResponseItem::CustomToolCallOutput {
|
||||
call_id, output, ..
|
||||
} => {
|
||||
if shell_call_ids.remove(call_id)
|
||||
&& let Some(structured) = output
|
||||
.text_content()
|
||||
.and_then(parse_structured_shell_output)
|
||||
{
|
||||
output.body = FunctionCallOutputBody::Text(structured);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
});
|
||||
}
|
||||
|
||||
fn is_shell_tool_name(name: &str) -> bool {
|
||||
matches!(name, "shell" | "container.exec")
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ExecOutputJson {
|
||||
output: String,
|
||||
metadata: ExecOutputMetadataJson,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ExecOutputMetadataJson {
|
||||
exit_code: i32,
|
||||
duration_seconds: f32,
|
||||
}
|
||||
|
||||
fn parse_structured_shell_output(raw: &str) -> Option<String> {
|
||||
let parsed: ExecOutputJson = serde_json::from_str(raw).ok()?;
|
||||
Some(build_structured_output(&parsed))
|
||||
}
|
||||
|
||||
fn build_structured_output(parsed: &ExecOutputJson) -> String {
|
||||
let mut sections = Vec::new();
|
||||
let exit_code = parsed.metadata.exit_code;
|
||||
sections.push(format!("Exit code: {exit_code}"));
|
||||
let duration_seconds = parsed.metadata.duration_seconds;
|
||||
sections.push(format!("Wall time: {duration_seconds} seconds"));
|
||||
|
||||
let mut output = parsed.output.clone();
|
||||
if let Some((stripped, total_lines)) = strip_total_output_header(&parsed.output) {
|
||||
sections.push(format!("Total output lines: {total_lines}"));
|
||||
output = stripped.to_string();
|
||||
}
|
||||
|
||||
sections.push("Output:".to_string());
|
||||
sections.push(output);
|
||||
sections.join("\n")
|
||||
}
|
||||
|
||||
fn strip_total_output_header(output: &str) -> Option<(&str, u32)> {
|
||||
let after_prefix = output.strip_prefix("Total output lines: ")?;
|
||||
let (total_segment, remainder) = after_prefix.split_once('\n')?;
|
||||
let total_lines = total_segment.parse::<u32>().ok()?;
|
||||
let remainder = remainder.strip_prefix('\n').unwrap_or(remainder);
|
||||
Some((remainder, total_lines))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::Prompt;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_tools::FreeformTool;
|
||||
use codex_tools::FreeformToolFormat;
|
||||
use codex_tools::ToolSpec;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn reserializes_shell_outputs_for_function_and_custom_tool_calls() {
|
||||
let raw_output = r#"{"output":"hello","metadata":{"exit_code":0,"duration_seconds":0.5}}"#;
|
||||
let expected_output = "Exit code: 0\nWall time: 0.5 seconds\nOutput:\nhello";
|
||||
let prompt = Prompt {
|
||||
input: vec![
|
||||
ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: "shell".to_string(),
|
||||
namespace: None,
|
||||
arguments: "{}".to_string(),
|
||||
call_id: "call-1".to_string(),
|
||||
},
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: "call-1".to_string(),
|
||||
output: FunctionCallOutputPayload::from_text(raw_output.to_string()),
|
||||
},
|
||||
ResponseItem::CustomToolCall {
|
||||
id: None,
|
||||
status: None,
|
||||
call_id: "call-2".to_string(),
|
||||
name: "apply_patch".to_string(),
|
||||
input: "*** Begin Patch".to_string(),
|
||||
},
|
||||
ResponseItem::CustomToolCallOutput {
|
||||
call_id: "call-2".to_string(),
|
||||
name: None,
|
||||
output: FunctionCallOutputPayload::from_text(raw_output.to_string()),
|
||||
},
|
||||
],
|
||||
tools: vec![ToolSpec::Freeform(FreeformTool {
|
||||
name: "apply_patch".to_string(),
|
||||
description: "patch".to_string(),
|
||||
format: FreeformToolFormat {
|
||||
r#type: "grammar".to_string(),
|
||||
syntax: "lark".to_string(),
|
||||
definition: "patch".to_string(),
|
||||
},
|
||||
})],
|
||||
..Prompt::default()
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
prompt.get_formatted_input(),
|
||||
vec![
|
||||
ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: "shell".to_string(),
|
||||
namespace: None,
|
||||
arguments: "{}".to_string(),
|
||||
call_id: "call-1".to_string(),
|
||||
},
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: "call-1".to_string(),
|
||||
output: FunctionCallOutputPayload::from_text(expected_output.to_string()),
|
||||
},
|
||||
ResponseItem::CustomToolCall {
|
||||
id: None,
|
||||
status: None,
|
||||
call_id: "call-2".to_string(),
|
||||
name: "apply_patch".to_string(),
|
||||
input: "*** Begin Patch".to_string(),
|
||||
},
|
||||
ResponseItem::CustomToolCallOutput {
|
||||
call_id: "call-2".to_string(),
|
||||
name: None,
|
||||
output: FunctionCallOutputPayload::from_text(expected_output.to_string()),
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
}
|
||||
170
codex-rs/kernel/src/sampling.rs
Normal file
170
codex-rs/kernel/src/sampling.rs
Normal file
@@ -0,0 +1,170 @@
|
||||
use crate::KernelToolExecutor;
|
||||
use crate::Prompt;
|
||||
use crate::PromptConfig;
|
||||
use crate::ToolConfig;
|
||||
use codex_protocol::error::CodexErr;
|
||||
use codex_protocol::error::Result;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::RateLimitSnapshot;
|
||||
use rand::Rng;
|
||||
use std::time::Duration;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
const INITIAL_DELAY_MS: u64 = 200;
|
||||
const BACKOFF_FACTOR: f64 = 2.0;
|
||||
|
||||
/// Host-facing bundle created once before the kernel begins retrying a logical sampling request.
|
||||
///
|
||||
/// Implementations should place retry-stable state here so retries can rebuild prompts without
|
||||
/// repeating expensive setup work such as tool routing or per-turn worker startup.
|
||||
pub struct PreparedSamplingRequest<Runtime, Tools> {
|
||||
pub prompt_config: PromptConfig,
|
||||
pub tool_config: ToolConfig,
|
||||
pub runtime: Runtime,
|
||||
pub tools: Tools,
|
||||
}
|
||||
|
||||
/// Final outcome of a logical sampling request after the model/tool loop reaches a stable state.
|
||||
#[derive(Debug)]
|
||||
pub struct SamplingRequestResult {
|
||||
pub needs_follow_up: bool,
|
||||
pub last_agent_message: Option<String>,
|
||||
}
|
||||
|
||||
/// Host adapter for the kernel's retryable sampling loop.
|
||||
///
|
||||
/// Implementations provide the concrete model transport, tool runtime, history access, and
|
||||
/// user-facing retry notifications while allowing `codex-kernel` to stay agnostic of local
|
||||
/// filesystem, network, or tool-execution details.
|
||||
#[allow(async_fn_in_trait)]
|
||||
pub trait SamplingLoopHost {
|
||||
type ClientSession;
|
||||
type Runtime;
|
||||
type Tools: KernelToolExecutor;
|
||||
|
||||
async fn prepare_sampling_request(
|
||||
&self,
|
||||
input: &[ResponseItem],
|
||||
cancellation_token: &CancellationToken,
|
||||
) -> Result<PreparedSamplingRequest<Self::Runtime, Self::Tools>>;
|
||||
|
||||
async fn history_prompt_input(&self) -> Vec<ResponseItem>;
|
||||
|
||||
async fn run_single_sampling_request(
|
||||
&self,
|
||||
runtime: &Self::Runtime,
|
||||
tools: &Self::Tools,
|
||||
client_session: &mut Self::ClientSession,
|
||||
prompt: &Prompt,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Result<SamplingRequestResult>;
|
||||
|
||||
fn stream_max_retries(&self) -> u64;
|
||||
|
||||
fn try_switch_fallback_transport(&self, client_session: &mut Self::ClientSession) -> bool;
|
||||
|
||||
fn should_notify_stream_retry(&self, retries: u64, err: &CodexErr) -> bool;
|
||||
|
||||
async fn handle_context_window_exceeded(&self);
|
||||
|
||||
async fn handle_usage_limit_reached(&self, rate_limits: Option<RateLimitSnapshot>);
|
||||
|
||||
async fn notify_fallback_to_http(&self, err: &CodexErr);
|
||||
|
||||
async fn notify_stream_retry(
|
||||
&self,
|
||||
retries: u64,
|
||||
max_retries: u64,
|
||||
delay: Duration,
|
||||
err: &CodexErr,
|
||||
);
|
||||
}
|
||||
|
||||
pub async fn run_sampling_request_loop<Host>(
|
||||
host: &Host,
|
||||
client_session: &mut Host::ClientSession,
|
||||
initial_input: Vec<ResponseItem>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Result<SamplingRequestResult>
|
||||
where
|
||||
Host: SamplingLoopHost,
|
||||
{
|
||||
let prepared = host
|
||||
.prepare_sampling_request(initial_input.as_slice(), &cancellation_token)
|
||||
.await?;
|
||||
let mut retries = 0;
|
||||
let mut initial_input = Some(initial_input);
|
||||
loop {
|
||||
let prompt_input = if let Some(input) = initial_input.take() {
|
||||
input
|
||||
} else {
|
||||
host.history_prompt_input().await
|
||||
};
|
||||
let prompt = prepared
|
||||
.prompt_config
|
||||
.build_prompt(prompt_input, &prepared.tool_config);
|
||||
let err = match host
|
||||
.run_single_sampling_request(
|
||||
&prepared.runtime,
|
||||
&prepared.tools,
|
||||
client_session,
|
||||
&prompt,
|
||||
cancellation_token.child_token(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(output) => {
|
||||
return Ok(output);
|
||||
}
|
||||
Err(CodexErr::ContextWindowExceeded) => {
|
||||
host.handle_context_window_exceeded().await;
|
||||
return Err(CodexErr::ContextWindowExceeded);
|
||||
}
|
||||
Err(CodexErr::UsageLimitReached(error)) => {
|
||||
host.handle_usage_limit_reached(
|
||||
error
|
||||
.rate_limits
|
||||
.as_ref()
|
||||
.map(|snapshot| (**snapshot).clone()),
|
||||
)
|
||||
.await;
|
||||
return Err(CodexErr::UsageLimitReached(error));
|
||||
}
|
||||
Err(err) => err,
|
||||
};
|
||||
|
||||
if !err.is_retryable() {
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
let max_retries = host.stream_max_retries();
|
||||
if retries >= max_retries && host.try_switch_fallback_transport(client_session) {
|
||||
host.notify_fallback_to_http(&err).await;
|
||||
retries = 0;
|
||||
continue;
|
||||
}
|
||||
if retries < max_retries {
|
||||
retries += 1;
|
||||
let delay = match &err {
|
||||
CodexErr::Stream(_, requested_delay) => {
|
||||
requested_delay.unwrap_or_else(|| backoff(retries))
|
||||
}
|
||||
_ => backoff(retries),
|
||||
};
|
||||
if host.should_notify_stream_retry(retries, &err) {
|
||||
host.notify_stream_retry(retries, max_retries, delay, &err)
|
||||
.await;
|
||||
}
|
||||
tokio::time::sleep(delay).await;
|
||||
} else {
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn backoff(attempt: u64) -> Duration {
|
||||
let exp = BACKOFF_FACTOR.powi(attempt.saturating_sub(1) as i32);
|
||||
let base = (INITIAL_DELAY_MS as f64 * exp) as u64;
|
||||
let jitter = rand::rng().random_range(0.9..1.1);
|
||||
Duration::from_millis((base as f64 * jitter) as u64)
|
||||
}
|
||||
253
codex-rs/kernel/src/tools.rs
Normal file
253
codex-rs/kernel/src/tools.rs
Normal file
@@ -0,0 +1,253 @@
|
||||
use codex_protocol::error::CodexErr;
|
||||
use codex_protocol::error::Result;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_tools::ToolSpec;
|
||||
use std::fmt;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
/// Retry-stable tool registration for a logical model request.
|
||||
///
|
||||
/// The host chooses which tools are exposed to the model and whether the model
|
||||
/// may issue them in parallel. `codex-kernel` treats this as opaque model-facing
|
||||
/// configuration and uses it to rebuild prompts across retries.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ToolConfig {
|
||||
pub tools: Vec<ToolSpec>,
|
||||
pub parallel_tool_calls: bool,
|
||||
}
|
||||
|
||||
/// Metadata a host-specific tool call must expose so `codex-kernel` can
|
||||
/// synthesize a model-visible failure response when execution cannot complete.
|
||||
pub trait KernelToolCall: Clone + Send + Sync {
|
||||
fn error_response(&self, message: String) -> ResponseInputItem;
|
||||
}
|
||||
|
||||
/// Non-fatal tool execution failures are surfaced back to the model as a
|
||||
/// complementary output item, while fatal failures abort the turn.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ToolCallError {
|
||||
RespondToModel(String),
|
||||
Fatal(String),
|
||||
}
|
||||
|
||||
impl fmt::Display for ToolCallError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::RespondToModel(message) => write!(f, "{message}"),
|
||||
Self::Fatal(message) => write!(f, "Fatal error: {message}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Classification of a completed model output item relative to the host's tool
/// catalog.
pub enum CompletedResponseItem<Call> {
    /// An output item that does not correspond to any tool call.
    NonTool(ResponseItem),
    /// A recognized tool call: `item` is the raw model output, `call` the
    /// host-parsed invocation to hand to the executor.
    ToolCall {
        item: ResponseItem,
        call: Call,
    },
    /// An item paired with a ready-made `response` — presumably answered by
    /// the host without invoking the executor (confirm against the loop that
    /// consumes this variant).
    ImmediateResponse {
        item: ResponseItem,
        response: ResponseInputItem,
    },
}
|
||||
|
||||
/// Host-provided tool adapter used by `codex-kernel`'s default model/tool loop.
///
/// Implementations own tool-call parsing and execution, while
/// `codex-kernel` owns the generic behavior of attempting a complementary
/// tool output item for every tool call the model emits.
// `async fn` in a public trait leaves the future's auto traits implicit;
// explicitly allowed here rather than desugaring to `impl Future` bounds.
#[allow(async_fn_in_trait)]
pub trait KernelToolExecutor {
    /// Host-specific parsed tool invocation.
    type Call: KernelToolCall;

    /// Decides whether `item` is a plain output, a tool call to execute, or a
    /// call the host can answer immediately.
    async fn classify_response_item(
        &self,
        item: ResponseItem,
    ) -> Result<CompletedResponseItem<Self::Call>>;

    /// Runs `call` to completion. `cancellation_token` is provided so
    /// implementations can observe turn cancellation. Returning
    /// `ToolCallError::RespondToModel` surfaces the message to the model;
    /// `ToolCallError::Fatal` aborts the turn.
    async fn execute_tool_call(
        &self,
        call: Self::Call,
        cancellation_token: CancellationToken,
    ) -> std::result::Result<ResponseInputItem, ToolCallError>;
}
|
||||
|
||||
pub async fn execute_tool_call_with_default_output<Executor>(
|
||||
executor: &Executor,
|
||||
call: Executor::Call,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Result<ResponseInputItem>
|
||||
where
|
||||
Executor: KernelToolExecutor,
|
||||
{
|
||||
match executor
|
||||
.execute_tool_call(call.clone(), cancellation_token)
|
||||
.await
|
||||
{
|
||||
Ok(response) => Ok(response),
|
||||
Err(ToolCallError::RespondToModel(message)) => Ok(call.error_response(message)),
|
||||
Err(ToolCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn response_input_to_response_item(input: &ResponseInputItem) -> Option<ResponseItem> {
|
||||
match input {
|
||||
ResponseInputItem::FunctionCallOutput { call_id, output } => {
|
||||
if call_id.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some(ResponseItem::FunctionCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
output: output.clone(),
|
||||
})
|
||||
}
|
||||
ResponseInputItem::CustomToolCallOutput {
|
||||
call_id,
|
||||
name,
|
||||
output,
|
||||
} => {
|
||||
if call_id.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some(ResponseItem::CustomToolCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
name: name.clone(),
|
||||
output: output.clone(),
|
||||
})
|
||||
}
|
||||
ResponseInputItem::McpToolCallOutput { call_id, output } => {
|
||||
if call_id.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let output = output.as_function_call_output_payload();
|
||||
Some(ResponseItem::FunctionCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
output,
|
||||
})
|
||||
}
|
||||
ResponseInputItem::ToolSearchOutput {
|
||||
call_id,
|
||||
status,
|
||||
execution,
|
||||
tools,
|
||||
} => {
|
||||
if call_id.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some(ResponseItem::ToolSearchOutput {
|
||||
call_id: Some(call_id.clone()),
|
||||
status: status.clone(),
|
||||
execution: execution.clone(),
|
||||
tools: tools.clone(),
|
||||
})
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::CompletedResponseItem;
    use super::KernelToolCall;
    use super::KernelToolExecutor;
    use super::ToolCallError;
    use super::execute_tool_call_with_default_output;
    use super::response_input_to_response_item;
    use codex_protocol::error::Result;
    use codex_protocol::models::FunctionCallOutputPayload;
    use codex_protocol::models::ResponseInputItem;
    use codex_protocol::models::ResponseItem;
    use pretty_assertions::assert_eq;
    use tokio_util::sync::CancellationToken;

    // Minimal call type for the tests: the call id doubles as the scenario
    // selector consumed by `TestExecutor::execute_tool_call`.
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct TestCall {
        call_id: String,
    }

    impl KernelToolCall for TestCall {
        // Error responses are rendered as a function-call output carrying the
        // error text under this call's id.
        fn error_response(&self, message: String) -> ResponseInputItem {
            ResponseInputItem::FunctionCallOutput {
                call_id: self.call_id.clone(),
                output: FunctionCallOutputPayload::from_text(message),
            }
        }
    }

    // Stateless executor that dispatches on the call id.
    struct TestExecutor;

    impl KernelToolExecutor for TestExecutor {
        type Call = TestCall;

        // Classification is irrelevant to these tests; everything is non-tool.
        async fn classify_response_item(
            &self,
            item: ResponseItem,
        ) -> Result<CompletedResponseItem<Self::Call>> {
            Ok(CompletedResponseItem::NonTool(item))
        }

        // Scenarios: "ok" succeeds, "retry" fails non-fatally (should be
        // surfaced to the model), anything else fails fatally.
        async fn execute_tool_call(
            &self,
            call: Self::Call,
            _cancellation_token: CancellationToken,
        ) -> std::result::Result<ResponseInputItem, ToolCallError> {
            match call.call_id.as_str() {
                "ok" => Ok(ResponseInputItem::FunctionCallOutput {
                    call_id: call.call_id,
                    output: FunctionCallOutputPayload::from_text("done".to_string()),
                }),
                "retry" => Err(ToolCallError::RespondToModel("try again".to_string())),
                _ => Err(ToolCallError::Fatal("boom".to_string())),
            }
        }
    }

    // A non-fatal error becomes a model-visible error output item rather than
    // an `Err`.
    #[tokio::test]
    async fn execute_tool_call_with_default_output_surfaces_non_fatal_errors() {
        let response = execute_tool_call_with_default_output(
            &TestExecutor,
            TestCall {
                call_id: "retry".to_string(),
            },
            CancellationToken::new(),
        )
        .await
        .expect("tool response");

        assert_eq!(
            response,
            ResponseInputItem::FunctionCallOutput {
                call_id: "retry".to_string(),
                output: FunctionCallOutputPayload::from_text("try again".to_string()),
            }
        );
    }

    // A fatal error aborts the turn with the "Fatal error:" prefix attached by
    // the error's `Display` impl.
    #[tokio::test]
    async fn execute_tool_call_with_default_output_propagates_fatal_errors() {
        let err = execute_tool_call_with_default_output(
            &TestExecutor,
            TestCall {
                call_id: "fatal".to_string(),
            },
            CancellationToken::new(),
        )
        .await
        .expect_err("fatal tool error");

        assert_eq!(err.to_string(), "Fatal error: boom");
    }

    // Outputs without a call id cannot be paired with a call and are dropped.
    #[test]
    fn response_input_to_response_item_skips_empty_call_ids() {
        let item = response_input_to_response_item(&ResponseInputItem::FunctionCallOutput {
            call_id: String::new(),
            output: FunctionCallOutputPayload::from_text("boom".to_string()),
        });

        assert_eq!(item, None);
    }
}
|
||||
@@ -285,25 +285,27 @@ mod job {
|
||||
let (rollout_items, _, _) = RolloutRecorder::load_rollout_items(rollout_path).await?;
|
||||
let rollout_contents = serialize_filtered_rollout_response_items(&rollout_items)?;
|
||||
|
||||
let mut prompt = Prompt::default();
|
||||
prompt.input = vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: build_stage_one_input_message(
|
||||
&stage_one_context.model_info,
|
||||
rollout_path,
|
||||
rollout_cwd,
|
||||
&rollout_contents,
|
||||
)?,
|
||||
let prompt = Prompt {
|
||||
input: vec![ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: build_stage_one_input_message(
|
||||
&stage_one_context.model_info,
|
||||
rollout_path,
|
||||
rollout_cwd,
|
||||
&rollout_contents,
|
||||
)?,
|
||||
}],
|
||||
phase: None,
|
||||
}],
|
||||
phase: None,
|
||||
}];
|
||||
prompt.base_instructions = BaseInstructions {
|
||||
text: crate::stage_one::PROMPT.to_string(),
|
||||
base_instructions: BaseInstructions {
|
||||
text: crate::stage_one::PROMPT.to_string(),
|
||||
},
|
||||
output_schema: Some(output_schema()),
|
||||
output_schema_strict: true,
|
||||
..Prompt::default()
|
||||
};
|
||||
prompt.output_schema = Some(output_schema());
|
||||
prompt.output_schema_strict = true;
|
||||
|
||||
let (result, token_usage) = context
|
||||
.stream_stage_one_prompt(config, &prompt, stage_one_context)
|
||||
|
||||
Reference in New Issue
Block a user