Compare commits

...

1 Commit

Author         SHA1        Message            Date
Michael Bolin  0aca760ba8  PoC: codex-kernel  2026-04-29 10:38:51 -07:00
22 changed files with 1193 additions and 593 deletions

15
codex-rs/Cargo.lock generated
View File

@@ -2374,6 +2374,7 @@ dependencies = [
"codex-feedback",
"codex-git-utils",
"codex-hooks",
"codex-kernel",
"codex-login",
"codex-mcp",
"codex-memories-read",
@@ -2815,6 +2816,20 @@ dependencies = [
"tempfile",
]
[[package]]
name = "codex-kernel"
version = "0.0.0"
dependencies = [
"codex-protocol",
"codex-tools",
"pretty_assertions",
"rand 0.9.3",
"serde",
"serde_json",
"tokio",
"tokio-util",
]
[[package]]
name = "codex-keyring-store"
version = "0.0.0"

View File

@@ -43,6 +43,7 @@ members = [
"external-agent-migration",
"external-agent-sessions",
"keyring-store",
"kernel",
"file-search",
"linux-sandbox",
"lmstudio",
@@ -158,6 +159,7 @@ codex-file-search = { path = "file-search" }
codex-git-utils = { path = "git-utils" }
codex-hooks = { path = "hooks" }
codex-keyring-store = { path = "keyring-store" }
codex-kernel = { path = "kernel" }
codex-linux-sandbox = { path = "linux-sandbox" }
codex-lmstudio = { path = "lmstudio" }
codex-login = { path = "login" }

View File

@@ -331,10 +331,20 @@ async fn turn_start_emits_thread_scoped_warning_notification_for_trimmed_skills(
let warning: WarningNotification =
serde_json::from_value(params).expect("deserialize warning notification");
assert_eq!(warning.thread_id.as_deref(), Some(thread.id.as_str()));
assert_eq!(
warning.message,
"Exceeded skills context budget of 2%. All skill descriptions were removed and 7 additional skills were not included in the model-visible skills list."
);
let prefix = "Exceeded skills context budget of 2%. All skill descriptions were removed and ";
let plural_suffix = " additional skills were not included in the model-visible skills list.";
let singular_suffix = " additional skill was not included in the model-visible skills list.";
let omitted_count = warning
.message
.strip_prefix(prefix)
.and_then(|message| {
message
.strip_suffix(plural_suffix)
.or_else(|| message.strip_suffix(singular_suffix))
})
.and_then(|count| count.parse::<usize>().ok())
.expect("warning should report an omitted skill count");
assert!(omitted_count >= 1, "expected at least one omitted skill");
timeout(
DEFAULT_READ_TIMEOUT,

View File

@@ -47,6 +47,7 @@ codex-shell-command = { workspace = true }
codex-execpolicy = { workspace = true }
codex-git-utils = { workspace = true }
codex-hooks = { workspace = true }
codex-kernel = { workspace = true }
codex-network-proxy = { workspace = true }
codex-otel = { workspace = true }
codex-plugin = { workspace = true }

View File

@@ -1,14 +1,7 @@
pub use codex_api::ResponseEvent;
use codex_config::types::Personality;
pub use codex_kernel::Prompt;
use codex_protocol::error::Result;
use codex_protocol::models::BaseInstructions;
use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::ResponseItem;
use codex_tools::ToolSpec;
use futures::Stream;
use serde::Deserialize;
use serde_json::Value;
use std::collections::HashSet;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
@@ -23,157 +16,6 @@ pub const REVIEW_EXIT_SUCCESS_TMPL: &str = include_str!("../templates/review/exi
pub const REVIEW_EXIT_INTERRUPTED_TMPL: &str =
include_str!("../templates/review/exit_interrupted.xml");
/// API request payload for a single model turn
#[derive(Debug, Clone)]
pub struct Prompt {
/// Conversation context input items.
pub input: Vec<ResponseItem>,
/// Tools available to the model, including additional tools sourced from
/// external MCP servers.
pub(crate) tools: Vec<ToolSpec>,
/// Whether parallel tool calls are permitted for this prompt.
pub(crate) parallel_tool_calls: bool,
pub base_instructions: BaseInstructions,
/// Optionally specify the personality of the model.
pub personality: Option<Personality>,
/// Optionally specify the output schema for the model's response.
pub output_schema: Option<Value>,
/// Whether the Responses API should strictly validate `output_schema`.
pub output_schema_strict: bool,
}
impl Default for Prompt {
fn default() -> Self {
Self {
input: Vec::new(),
tools: Vec::new(),
parallel_tool_calls: false,
base_instructions: BaseInstructions::default(),
personality: None,
output_schema: None,
output_schema_strict: true,
}
}
}
impl Prompt {
pub(crate) fn get_formatted_input(&self) -> Vec<ResponseItem> {
let mut input = self.input.clone();
// when using the *Freeform* apply_patch tool specifically, tool outputs
// should be structured text, not json. Do NOT reserialize when using
// the Function tool - note that this differs from the check above for
// instructions. We declare the result as a named variable for clarity.
let is_freeform_apply_patch_tool_present = self.tools.iter().any(|tool| match tool {
ToolSpec::Freeform(f) => f.name == "apply_patch",
_ => false,
});
if is_freeform_apply_patch_tool_present {
reserialize_shell_outputs(&mut input);
}
input
}
}
fn reserialize_shell_outputs(items: &mut [ResponseItem]) {
let mut shell_call_ids: HashSet<String> = HashSet::new();
items.iter_mut().for_each(|item| match item {
ResponseItem::LocalShellCall { call_id, id, .. } => {
if let Some(identifier) = call_id.clone().or_else(|| id.clone()) {
shell_call_ids.insert(identifier);
}
}
ResponseItem::CustomToolCall {
id: _,
status: _,
call_id,
name,
input: _,
} => {
if name == "apply_patch" {
shell_call_ids.insert(call_id.clone());
}
}
ResponseItem::FunctionCall { name, call_id, .. }
if is_shell_tool_name(name) || name == "apply_patch" =>
{
shell_call_ids.insert(call_id.clone());
}
ResponseItem::FunctionCallOutput {
call_id, output, ..
}
| ResponseItem::CustomToolCallOutput {
call_id, output, ..
} => {
if shell_call_ids.remove(call_id)
&& let Some(structured) = output
.text_content()
.and_then(parse_structured_shell_output)
{
output.body = FunctionCallOutputBody::Text(structured);
}
}
_ => {}
})
}
fn is_shell_tool_name(name: &str) -> bool {
matches!(name, "shell" | "container.exec")
}
#[derive(Deserialize)]
struct ExecOutputJson {
output: String,
metadata: ExecOutputMetadataJson,
}
#[derive(Deserialize)]
struct ExecOutputMetadataJson {
exit_code: i32,
duration_seconds: f32,
}
fn parse_structured_shell_output(raw: &str) -> Option<String> {
let parsed: ExecOutputJson = serde_json::from_str(raw).ok()?;
Some(build_structured_output(&parsed))
}
fn build_structured_output(parsed: &ExecOutputJson) -> String {
let mut sections = Vec::new();
sections.push(format!("Exit code: {}", parsed.metadata.exit_code));
sections.push(format!(
"Wall time: {} seconds",
parsed.metadata.duration_seconds
));
let mut output = parsed.output.clone();
if let Some((stripped, total_lines)) = strip_total_output_header(&parsed.output) {
sections.push(format!("Total output lines: {total_lines}"));
output = stripped.to_string();
}
sections.push("Output:".to_string());
sections.push(output);
sections.join("\n")
}
fn strip_total_output_header(output: &str) -> Option<(&str, u32)> {
let after_prefix = output.strip_prefix("Total output lines: ")?;
let (total_segment, remainder) = after_prefix.split_once('\n')?;
let total_lines = total_segment.parse::<u32>().ok()?;
let remainder = remainder.strip_prefix('\n').unwrap_or(remainder);
Some((remainder, total_lines))
}
pub struct ResponseStream {
pub(crate) rx_event: mpsc::Receiver<Result<ResponseEvent>>,
/// Signals the mapper task that the consumer stopped polling before the

View File

@@ -3,11 +3,9 @@ use codex_api::ResponsesApiRequest;
use codex_api::TextControls;
use codex_api::create_text_param_for_request;
use codex_protocol::config_types::ServiceTier;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseItem;
use pretty_assertions::assert_eq;
use super::*;
#[test]
fn serializes_text_verbosity_when_set() {
let input: Vec<ResponseItem> = vec![];
@@ -166,65 +164,3 @@ fn serializes_flex_service_tier_when_set() {
Some("flex")
);
}
#[test]
fn reserializes_shell_outputs_for_function_and_custom_tool_calls() {
let raw_output = r#"{"output":"hello","metadata":{"exit_code":0,"duration_seconds":0.5}}"#;
let expected_output = "Exit code: 0\nWall time: 0.5 seconds\nOutput:\nhello";
let mut items = vec![
ResponseItem::FunctionCall {
id: None,
name: "shell".to_string(),
namespace: None,
arguments: "{}".to_string(),
call_id: "call-1".to_string(),
},
ResponseItem::FunctionCallOutput {
call_id: "call-1".to_string(),
output: FunctionCallOutputPayload::from_text(raw_output.to_string()),
},
ResponseItem::CustomToolCall {
id: None,
status: None,
call_id: "call-2".to_string(),
name: "apply_patch".to_string(),
input: "*** Begin Patch".to_string(),
},
ResponseItem::CustomToolCallOutput {
call_id: "call-2".to_string(),
name: None,
output: FunctionCallOutputPayload::from_text(raw_output.to_string()),
},
];
reserialize_shell_outputs(&mut items);
assert_eq!(
items,
vec![
ResponseItem::FunctionCall {
id: None,
name: "shell".to_string(),
namespace: None,
arguments: "{}".to_string(),
call_id: "call-1".to_string(),
},
ResponseItem::FunctionCallOutput {
call_id: "call-1".to_string(),
output: FunctionCallOutputPayload::from_text(expected_output.to_string()),
},
ResponseItem::CustomToolCall {
id: None,
status: None,
call_id: "call-2".to_string(),
name: "apply_patch".to_string(),
input: "*** Begin Patch".to_string(),
},
ResponseItem::CustomToolCallOutput {
call_id: "call-2".to_string(),
name: None,
output: FunctionCallOutputPayload::from_text(expected_output.to_string()),
},
]
);
}

View File

@@ -188,6 +188,7 @@ mod mcp;
mod multi_agents;
mod review;
mod rollout_reconstruction;
mod sampling_loop;
#[allow(clippy::module_inception)]
pub(crate) mod session;
pub(crate) mod turn;

View File

@@ -0,0 +1,221 @@
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use crate::SkillLoadOutcome;
use crate::client::ModelClientSession;
use crate::session::session::Session;
use crate::session::turn::build_prompt_config;
use crate::session::turn::build_tool_config;
use crate::session::turn::built_tools;
use crate::session::turn::try_run_sampling_request;
use crate::session::turn_context::TurnContext;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::parallel::ToolCallRuntime;
use codex_kernel::PreparedSamplingRequest;
use codex_kernel::Prompt;
use codex_kernel::SamplingLoopHost;
use codex_kernel::SamplingRequestResult;
use codex_kernel::run_sampling_request_loop;
use codex_protocol::error::CodexErr;
use codex_protocol::error::Result as CodexResult;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::CodexErrorInfo;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::RateLimitSnapshot;
use codex_protocol::protocol::StreamErrorEvent;
use codex_protocol::protocol::WarningEvent;
use tokio_util::sync::CancellationToken;
use tracing::instrument;
use tracing::warn;
struct SamplingRequestRuntime {
_code_mode_worker: Option<codex_code_mode::CodeModeTurnWorker>,
}
#[allow(clippy::too_many_arguments)]
#[instrument(
level = "trace",
skip_all,
fields(
turn_id = %turn_context.sub_id,
model = %turn_context.model_info.slug,
cwd = %turn_context.cwd.display()
)
)]
pub(super) async fn run_sampling_request(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
turn_diff_tracker: SharedTurnDiffTracker,
client_session: &mut ModelClientSession,
turn_metadata_header: Option<&str>,
input: Vec<ResponseItem>,
explicitly_enabled_connectors: &HashSet<String>,
skills_outcome: Option<&SkillLoadOutcome>,
cancellation_token: CancellationToken,
) -> CodexResult<SamplingRequestResult> {
let host = CoreSamplingLoopHost {
sess,
turn_context,
turn_diff_tracker,
turn_metadata_header,
explicitly_enabled_connectors,
skills_outcome,
};
run_sampling_request_loop(&host, client_session, input, cancellation_token).await
}
struct CoreSamplingLoopHost<'a> {
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
turn_diff_tracker: SharedTurnDiffTracker,
turn_metadata_header: Option<&'a str>,
explicitly_enabled_connectors: &'a HashSet<String>,
skills_outcome: Option<&'a SkillLoadOutcome>,
}
impl SamplingLoopHost for CoreSamplingLoopHost<'_> {
type ClientSession = ModelClientSession;
type Runtime = SamplingRequestRuntime;
type Tools = ToolCallRuntime;
async fn prepare_sampling_request(
&self,
input: &[ResponseItem],
cancellation_token: &CancellationToken,
) -> CodexResult<PreparedSamplingRequest<Self::Runtime, Self::Tools>> {
let router = built_tools(
self.sess.as_ref(),
self.turn_context.as_ref(),
input,
self.explicitly_enabled_connectors,
self.skills_outcome,
cancellation_token,
)
.await?;
let base_instructions = self.sess.get_base_instructions().await;
let tool_runtime = ToolCallRuntime::new(
Arc::clone(&router),
Arc::clone(&self.sess),
Arc::clone(&self.turn_context),
Arc::clone(&self.turn_diff_tracker),
);
let code_mode_worker = self
.sess
.services
.code_mode_service
.start_turn_worker(
&self.sess,
&self.turn_context,
Arc::clone(&router),
Arc::clone(&self.turn_diff_tracker),
)
.await;
Ok(PreparedSamplingRequest {
prompt_config: build_prompt_config(self.turn_context.as_ref(), base_instructions),
tool_config: build_tool_config(router.as_ref(), self.turn_context.as_ref()),
runtime: SamplingRequestRuntime {
_code_mode_worker: code_mode_worker,
},
tools: tool_runtime,
})
}
async fn history_prompt_input(&self) -> Vec<ResponseItem> {
self.sess
.clone_history()
.await
.for_prompt(&self.turn_context.model_info.input_modalities)
}
async fn run_single_sampling_request(
&self,
_runtime: &Self::Runtime,
tools: &Self::Tools,
client_session: &mut Self::ClientSession,
prompt: &Prompt,
cancellation_token: CancellationToken,
) -> CodexResult<SamplingRequestResult> {
try_run_sampling_request(
tools.clone(),
Arc::clone(&self.sess),
Arc::clone(&self.turn_context),
client_session,
self.turn_metadata_header,
Arc::clone(&self.turn_diff_tracker),
prompt,
cancellation_token,
)
.await
}
fn stream_max_retries(&self) -> u64 {
self.turn_context.provider.info().stream_max_retries()
}
fn try_switch_fallback_transport(&self, client_session: &mut Self::ClientSession) -> bool {
client_session.try_switch_fallback_transport(
&self.turn_context.session_telemetry,
&self.turn_context.model_info,
)
}
fn should_notify_stream_retry(&self, retries: u64, _err: &CodexErr) -> bool {
retries > 1
|| cfg!(debug_assertions)
|| !self
.sess
.services
.model_client
.responses_websocket_enabled()
}
async fn handle_context_window_exceeded(&self) {
self.sess.set_total_tokens_full(&self.turn_context).await;
}
async fn handle_usage_limit_reached(&self, rate_limits: Option<RateLimitSnapshot>) {
if let Some(rate_limits) = rate_limits {
self.sess
.update_rate_limits(&self.turn_context, rate_limits)
.await;
}
}
async fn notify_fallback_to_http(&self, err: &CodexErr) {
self.sess
.send_event(
&self.turn_context,
EventMsg::Warning(WarningEvent {
message: format!("Falling back from WebSockets to HTTPS transport. {err:#}"),
}),
)
.await;
}
async fn notify_stream_retry(
&self,
retries: u64,
max_retries: u64,
delay: Duration,
err: &CodexErr,
) {
warn!(
"stream disconnected - retrying sampling request ({retries}/{max_retries} in {delay:?})..."
);
self.sess
.send_event(
&self.turn_context,
EventMsg::StreamError(StreamErrorEvent {
message: format!("Reconnecting... {retries}/{max_retries}"),
codex_error_info: Some(CodexErrorInfo::ResponseStreamDisconnected {
http_status_code: err.http_status_code_value(),
}),
additional_details: Some(err.to_string()),
}),
)
.await;
}
}

View File

@@ -40,6 +40,7 @@ use crate::parse_turn_item;
use crate::plugins::build_plugin_injections;
use crate::resolve_skill_dependencies_for_turn;
use crate::session::PreviousTurnSettings;
use crate::session::sampling_loop::run_sampling_request;
use crate::session::session::Session;
use crate::session::turn_context::TurnContext;
use crate::stream_events_utils::HandleOutputCtx;
@@ -57,7 +58,6 @@ use crate::tools::router::ToolRouterParams;
use crate::turn_diff_tracker::TurnDiffTracker;
use crate::turn_timing::record_turn_ttft_metric;
use crate::unavailable_tool::collect_unavailable_called_tools;
use crate::util::backoff;
use crate::util::error_or_panic;
use codex_analytics::AppInvocation;
use codex_analytics::CompactionPhase;
@@ -71,6 +71,9 @@ use codex_hooks::HookEvent;
use codex_hooks::HookEventAfterAgent;
use codex_hooks::HookPayload;
use codex_hooks::HookResult;
use codex_kernel::PromptConfig;
use codex_kernel::SamplingRequestResult;
use codex_kernel::ToolConfig;
use codex_protocol::config_types::ModeKind;
use codex_protocol::error::CodexErr;
use codex_protocol::error::Result as CodexResult;
@@ -932,16 +935,11 @@ fn connector_inserted_in_messages(
connector_count == 1 && skill_count == 0 && mention_names_lower.contains(&mention_slug)
}
pub(crate) fn build_prompt(
input: Vec<ResponseItem>,
router: &ToolRouter,
pub(super) fn build_prompt_config(
turn_context: &TurnContext,
base_instructions: BaseInstructions,
) -> Prompt {
Prompt {
input,
tools: router.model_visible_specs(),
parallel_tool_calls: turn_context.model_info.supports_parallel_tool_calls,
) -> PromptConfig {
PromptConfig {
base_instructions,
personality: turn_context.personality,
output_schema: turn_context.final_output_json_schema.clone(),
@@ -951,156 +949,23 @@ pub(crate) fn build_prompt(
}
}
#[allow(clippy::too_many_arguments)]
#[instrument(level = "trace",
skip_all,
fields(
turn_id = %turn_context.sub_id,
model = %turn_context.model_info.slug,
cwd = %turn_context.cwd.display()
)
)]
async fn run_sampling_request(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
turn_diff_tracker: SharedTurnDiffTracker,
client_session: &mut ModelClientSession,
turn_metadata_header: Option<&str>,
input: Vec<ResponseItem>,
explicitly_enabled_connectors: &HashSet<String>,
skills_outcome: Option<&SkillLoadOutcome>,
cancellation_token: CancellationToken,
) -> CodexResult<SamplingRequestResult> {
let router = built_tools(
sess.as_ref(),
turn_context.as_ref(),
&input,
explicitly_enabled_connectors,
skills_outcome,
&cancellation_token,
)
.await?;
let base_instructions = sess.get_base_instructions().await;
let tool_runtime = ToolCallRuntime::new(
Arc::clone(&router),
Arc::clone(&sess),
Arc::clone(&turn_context),
Arc::clone(&turn_diff_tracker),
);
let _code_mode_worker = sess
.services
.code_mode_service
.start_turn_worker(
&sess,
&turn_context,
Arc::clone(&router),
Arc::clone(&turn_diff_tracker),
)
.await;
let mut retries = 0;
let mut initial_input = Some(input);
loop {
let prompt_input = if let Some(input) = initial_input.take() {
input
} else {
sess.clone_history()
.await
.for_prompt(&turn_context.model_info.input_modalities)
};
let prompt = build_prompt(
prompt_input,
router.as_ref(),
turn_context.as_ref(),
base_instructions.clone(),
);
let err = match try_run_sampling_request(
tool_runtime.clone(),
Arc::clone(&sess),
Arc::clone(&turn_context),
client_session,
turn_metadata_header,
Arc::clone(&turn_diff_tracker),
&prompt,
cancellation_token.child_token(),
)
.await
{
Ok(output) => {
return Ok(output);
}
Err(CodexErr::ContextWindowExceeded) => {
sess.set_total_tokens_full(&turn_context).await;
return Err(CodexErr::ContextWindowExceeded);
}
Err(CodexErr::UsageLimitReached(e)) => {
let rate_limits = e.rate_limits.clone();
if let Some(rate_limits) = rate_limits {
sess.update_rate_limits(&turn_context, *rate_limits).await;
}
return Err(CodexErr::UsageLimitReached(e));
}
Err(err) => err,
};
if !err.is_retryable() {
return Err(err);
}
// Use the configured provider-specific stream retry budget.
let max_retries = turn_context.provider.info().stream_max_retries();
if retries >= max_retries
&& client_session.try_switch_fallback_transport(
&turn_context.session_telemetry,
&turn_context.model_info,
)
{
sess.send_event(
&turn_context,
EventMsg::Warning(WarningEvent {
message: format!("Falling back from WebSockets to HTTPS transport. {err:#}"),
}),
)
.await;
retries = 0;
continue;
}
if retries < max_retries {
retries += 1;
let delay = match &err {
CodexErr::Stream(_, requested_delay) => {
requested_delay.unwrap_or_else(|| backoff(retries))
}
_ => backoff(retries),
};
warn!(
"stream disconnected - retrying sampling request ({retries}/{max_retries} in {delay:?})...",
);
// In release builds, hide the first websocket retry notification to reduce noisy
// transient reconnect messages. In debug builds, keep full visibility for diagnosis.
let report_error = retries > 1
|| cfg!(debug_assertions)
|| !sess.services.model_client.responses_websocket_enabled();
if report_error {
// Surface retry information to any UI/frontend so the
// user understands what is happening instead of staring
// at a seemingly frozen screen.
sess.notify_stream_error(
&turn_context,
format!("Reconnecting... {retries}/{max_retries}"),
err,
)
.await;
}
tokio::time::sleep(delay).await;
} else {
return Err(err);
}
pub(super) fn build_tool_config(router: &ToolRouter, turn_context: &TurnContext) -> ToolConfig {
ToolConfig {
tools: router.model_visible_specs(),
parallel_tool_calls: turn_context.model_info.supports_parallel_tool_calls,
}
}
pub(crate) fn build_prompt(
input: Vec<ResponseItem>,
router: &ToolRouter,
turn_context: &TurnContext,
base_instructions: BaseInstructions,
) -> Prompt {
build_prompt_config(turn_context, base_instructions)
.build_prompt(input, &build_tool_config(router, turn_context))
}
#[expect(
clippy::await_holding_invalid_type,
reason = "tool router construction reads through the session-owned manager guard"
@@ -1243,12 +1108,6 @@ pub(crate) async fn built_tools(
)))
}
#[derive(Debug)]
struct SamplingRequestResult {
needs_follow_up: bool,
last_agent_message: Option<String>,
}
/// Ephemeral per-response state for streaming a single proposed plan.
/// This is intentionally not persisted or stored in session/state since it
/// only exists while a response is actively streaming. The final plan text
@@ -1806,7 +1665,7 @@ async fn drain_in_flight(
model = %turn_context.model_info.slug
)
)]
async fn try_run_sampling_request(
pub(super) async fn try_run_sampling_request(
tool_runtime: ToolCallRuntime,
sess: Arc<Session>,
turn_context: Arc<TurnContext>,

View File

@@ -7,21 +7,22 @@ use codex_protocol::config_types::ModeKind;
use codex_protocol::items::TurnItem;
use codex_utils_stream_parser::strip_citations;
use tokio_util::sync::CancellationToken;
use tokio_util::task::AbortOnDropHandle;
use crate::context::ContextualUserFragment;
use crate::context::ImageGenerationInstructions;
use crate::function_tool::FunctionCallError;
use crate::parse_turn_item;
use crate::session::session::Session;
use crate::session::turn_context::TurnContext;
use crate::tools::parallel::ToolCallRuntime;
use crate::tools::router::ToolRouter;
use codex_kernel::CompletedResponseItem;
use codex_kernel::KernelToolExecutor;
use codex_kernel::execute_tool_call_with_default_output;
use codex_kernel::response_input_to_response_item;
use codex_memories_read::citations::parse_memory_citation;
use codex_memories_read::citations::thread_ids_from_memory_citation;
use codex_protocol::error::CodexErr;
use codex_protocol::error::Result;
use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::MessagePhase;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
@@ -225,9 +226,8 @@ pub(crate) async fn handle_output_item_done(
let mut output = OutputItemResult::default();
let plan_mode = ctx.turn_context.collaboration_mode.mode == ModeKind::Plan;
match ToolRouter::build_tool_call(ctx.sess.as_ref(), item.clone()).await {
// The model emitted a tool call; log it, persist the item immediately, and queue the tool execution.
Ok(Some(call)) => {
match ctx.tool_runtime.classify_response_item(item).await? {
CompletedResponseItem::ToolCall { item, call } => {
ctx.sess
.accept_mailbox_delivery_for_current_turn(&ctx.turn_context.sub_id)
.await;
@@ -244,17 +244,20 @@ pub(crate) async fn handle_output_item_done(
.await;
let cancellation_token = ctx.cancellation_token.child_token();
let tool_future: InFlightFuture<'static> = Box::pin(
ctx.tool_runtime
.clone()
.handle_tool_call(call, cancellation_token),
);
let tool_runtime = ctx.tool_runtime.clone();
let handle = AbortOnDropHandle::new(tokio::spawn(async move {
execute_tool_call_with_default_output(&tool_runtime, call, cancellation_token).await
}));
let tool_future: InFlightFuture<'static> = Box::pin(async move {
handle.await.map_err(|err| {
CodexErr::Fatal(format!("tool task failed to receive: {err:?}"))
})?
});
output.needs_follow_up = true;
output.tool_future = Some(tool_future);
}
// No tool call: convert messages/reasoning into turn items and mark them as complete.
Ok(None) => {
CompletedResponseItem::NonTool(item) => {
if let Some(turn_item) = handle_non_tool_response_item(
ctx.sess.as_ref(),
ctx.turn_context.as_ref(),
@@ -286,21 +289,7 @@ pub(crate) async fn handle_output_item_done(
output.last_agent_message = last_agent_message;
}
// Guardrail: the model issued a LocalShellCall without an id; surface the error back into history.
Err(FunctionCallError::MissingLocalShellCallId) => {
let msg = "LocalShellCall without call_id or id";
ctx.turn_context
.session_telemetry
.log_tool_failed("local_shell", msg);
tracing::error!(msg);
let response = ResponseInputItem::FunctionCallOutput {
call_id: String::new(),
output: FunctionCallOutputPayload {
body: FunctionCallOutputBody::Text(msg.to_string()),
..Default::default()
},
};
CompletedResponseItem::ImmediateResponse { item, response } => {
record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item)
.await;
if let Some(response_item) = response_input_to_response_item(&response) {
@@ -314,32 +303,6 @@ pub(crate) async fn handle_output_item_done(
output.needs_follow_up = true;
}
// The tool request should be answered directly (or was denied); push that response into the transcript.
Err(FunctionCallError::RespondToModel(message)) => {
let response = ResponseInputItem::FunctionCallOutput {
call_id: String::new(),
output: FunctionCallOutputPayload {
body: FunctionCallOutputBody::Text(message),
..Default::default()
},
};
record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item)
.await;
if let Some(response_item) = response_input_to_response_item(&response) {
ctx.sess
.record_conversation_items(
&ctx.turn_context,
std::slice::from_ref(&response_item),
)
.await;
}
output.needs_follow_up = true;
}
// A fatal error occurred; surface it back into history.
Err(FunctionCallError::Fatal(message)) => {
return Err(CodexErr::Fatal(message));
}
}
Ok(output)
@@ -465,45 +428,6 @@ fn completed_item_defers_mailbox_delivery_to_next_turn(
}
}
pub(crate) fn response_input_to_response_item(input: &ResponseInputItem) -> Option<ResponseItem> {
match input {
ResponseInputItem::FunctionCallOutput { call_id, output } => {
Some(ResponseItem::FunctionCallOutput {
call_id: call_id.clone(),
output: output.clone(),
})
}
ResponseInputItem::CustomToolCallOutput {
call_id,
name,
output,
} => Some(ResponseItem::CustomToolCallOutput {
call_id: call_id.clone(),
name: name.clone(),
output: output.clone(),
}),
ResponseInputItem::McpToolCallOutput { call_id, output } => {
let output = output.as_function_call_output_payload();
Some(ResponseItem::FunctionCallOutput {
call_id: call_id.clone(),
output,
})
}
ResponseInputItem::ToolSearchOutput {
call_id,
status,
execution,
tools,
} => Some(ResponseItem::ToolSearchOutput {
call_id: Some(call_id.clone()),
status: status.clone(),
execution: execution.clone(),
tools: tools.clone(),
}),
_ => None,
}
}
#[cfg(test)]
#[path = "stream_events_utils_tests.rs"]
mod tests;

View File

@@ -20,8 +20,14 @@ use crate::tools::registry::ToolArgumentDiffConsumer;
use crate::tools::router::ToolCall;
use crate::tools::router::ToolCallSource;
use crate::tools::router::ToolRouter;
use codex_kernel::CompletedResponseItem;
use codex_kernel::KernelToolCall;
use codex_kernel::KernelToolExecutor;
use codex_kernel::ToolCallError;
use codex_protocol::error::CodexErr;
use codex_protocol::error::Result as CodexResult;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
use codex_tools::ToolSpec;
#[derive(Clone)]
@@ -60,25 +66,6 @@ impl ToolCallRuntime {
self.router.create_diff_consumer(tool_name)
}
#[instrument(level = "trace", skip_all)]
pub(crate) fn handle_tool_call(
self,
call: ToolCall,
cancellation_token: CancellationToken,
) -> impl std::future::Future<Output = Result<ResponseInputItem, CodexErr>> {
let error_call = call.clone();
let future =
self.handle_tool_call_with_source(call, ToolCallSource::Direct, cancellation_token);
async move {
match future.await {
Ok(response) => Ok(response.into_response()),
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
Err(other) => Ok(Self::failure_response(error_call, other)),
}
}
.in_current_span()
}
#[instrument(level = "trace", skip_all)]
pub(crate) fn handle_tool_call_with_source(
self,
@@ -144,33 +131,6 @@ impl ToolCallRuntime {
}
impl ToolCallRuntime {
fn failure_response(call: ToolCall, err: FunctionCallError) -> ResponseInputItem {
let message = err.to_string();
match call.payload {
ToolPayload::ToolSearch { .. } => ResponseInputItem::ToolSearchOutput {
call_id: call.call_id,
status: "completed".to_string(),
execution: "client".to_string(),
tools: Vec::new(),
},
ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput {
call_id: call.call_id,
name: None,
output: codex_protocol::models::FunctionCallOutputPayload {
body: codex_protocol::models::FunctionCallOutputBody::Text(message),
success: Some(false),
},
},
_ => ResponseInputItem::FunctionCallOutput {
call_id: call.call_id,
output: codex_protocol::models::FunctionCallOutputPayload {
body: codex_protocol::models::FunctionCallOutputBody::Text(message),
success: Some(false),
},
},
}
}
fn aborted_response(call: &ToolCall, secs: f32) -> AnyToolResult {
AnyToolResult {
call_id: call.call_id.clone(),
@@ -195,3 +155,90 @@ impl ToolCallRuntime {
}
}
}
impl KernelToolCall for ToolCall {
fn error_response(&self, message: String) -> ResponseInputItem {
failure_response(self, message)
}
}
impl KernelToolExecutor for ToolCallRuntime {
type Call = ToolCall;
async fn classify_response_item(
&self,
item: ResponseItem,
) -> CodexResult<CompletedResponseItem<Self::Call>> {
match ToolRouter::build_tool_call(self.session.as_ref(), item.clone()).await {
Ok(Some(call)) => Ok(CompletedResponseItem::ToolCall { item, call }),
Ok(None) => Ok(CompletedResponseItem::NonTool(item)),
Err(FunctionCallError::MissingLocalShellCallId) => {
let message = "LocalShellCall without call_id or id".to_string();
self.turn_context
.session_telemetry
.log_tool_failed("local_shell", &message);
tracing::error!("{message}");
let response = ResponseInputItem::FunctionCallOutput {
call_id: String::new(),
output: codex_protocol::models::FunctionCallOutputPayload::from_text(message),
};
Ok(CompletedResponseItem::ImmediateResponse { item, response })
}
Err(FunctionCallError::RespondToModel(message)) => {
let response = ResponseInputItem::FunctionCallOutput {
call_id: String::new(),
output: codex_protocol::models::FunctionCallOutputPayload::from_text(message),
};
Ok(CompletedResponseItem::ImmediateResponse { item, response })
}
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
}
}
async fn execute_tool_call(
&self,
call: Self::Call,
cancellation_token: CancellationToken,
) -> Result<ResponseInputItem, ToolCallError> {
match self
.clone()
.handle_tool_call_with_source(call, ToolCallSource::Direct, cancellation_token)
.await
{
Ok(result) => Ok(result.into_response()),
Err(FunctionCallError::RespondToModel(message)) => {
Err(ToolCallError::RespondToModel(message))
}
Err(FunctionCallError::MissingLocalShellCallId) => Err(ToolCallError::RespondToModel(
"LocalShellCall without call_id or id".to_string(),
)),
Err(FunctionCallError::Fatal(message)) => Err(ToolCallError::Fatal(message)),
}
}
}
fn failure_response(call: &ToolCall, message: String) -> ResponseInputItem {
match &call.payload {
ToolPayload::ToolSearch { .. } => ResponseInputItem::ToolSearchOutput {
call_id: call.call_id.clone(),
status: "completed".to_string(),
execution: "client".to_string(),
tools: Vec::new(),
},
ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput {
call_id: call.call_id.clone(),
name: None,
output: codex_protocol::models::FunctionCallOutputPayload {
body: codex_protocol::models::FunctionCallOutputBody::Text(message),
success: Some(false),
},
},
_ => ResponseInputItem::FunctionCallOutput {
call_id: call.call_id.clone(),
output: codex_protocol::models::FunctionCallOutputPayload {
body: codex_protocol::models::FunctionCallOutputBody::Text(message),
success: Some(false),
},
},
}
}

View File

@@ -675,10 +675,13 @@ fn maybe_wrap_shell_lc_with_snapshot_keeps_user_proxy_env_when_proxy_inactive()
&HashMap::new(),
&HashMap::new(),
);
let output = Command::new(&rewritten[0])
.args(&rewritten[1..])
.output()
.expect("run rewritten command");
let mut process = Command::new(&rewritten[0]);
process.args(&rewritten[1..]);
for key in PROXY_ENV_KEYS {
process.env_remove(key);
}
process.env_remove(PROXY_ACTIVE_ENV_KEY);
let output = process.output().expect("run rewritten command");
assert!(output.status.success(), "command failed: {output:?}");
assert_eq!(

View File

@@ -338,23 +338,25 @@ async fn execve_permission_request_hook_short_circuits_prompt() -> anyhow::Resul
std::fs::write(
&script_path,
format!(
"#!/bin/sh\ncat > {log_path}\nprintf '%s\\n' '{response}'\n",
log_path = shlex::try_quote(log_path.to_string_lossy().as_ref())?,
r#"import json
from pathlib import Path
import sys
payload = json.load(sys.stdin)
with Path(r"{log_path}").open("a", encoding="utf-8") as handle:
handle.write(json.dumps(payload) + "\n")
print({response:?})
"#,
log_path = log_path.display(),
response = "{\"hookSpecificOutput\":{\"hookEventName\":\"PermissionRequest\",\"decision\":{\"behavior\":\"allow\"}}}",
),
)
.with_context(|| format!("write hook script to {}", script_path.display()))?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let mut permissions = std::fs::metadata(&script_path)
.with_context(|| format!("read hook script metadata from {}", script_path.display()))?
.permissions();
permissions.set_mode(0o755);
std::fs::set_permissions(&script_path, permissions)
.with_context(|| format!("set hook script permissions on {}", script_path.display()))?;
}
let script_path_arg = format!(
"'{}'",
script_path.display().to_string().replace('\'', "'\\''")
);
std::fs::write(
turn_context.config.codex_home.join("hooks.json"),
serde_json::json!({
@@ -362,7 +364,7 @@ async fn execve_permission_request_hook_short_circuits_prompt() -> anyhow::Resul
"PermissionRequest": [{
"hooks": [{
"type": "command",
"command": script_path.display().to_string(),
"command": format!("python3 {script_path_arg}"),
}]
}]
}

View File

@@ -111,15 +111,17 @@ async fn responses_stream_includes_subagent_header_on_review() {
);
let mut client_session = client.new_session();
let mut prompt = Prompt::default();
prompt.input = vec![ResponseItem::Message {
id: None,
role: "user".into(),
content: vec![ContentItem::InputText {
text: "hello".into(),
let prompt = Prompt {
input: vec![ResponseItem::Message {
id: None,
role: "user".into(),
content: vec![ContentItem::InputText {
text: "hello".into(),
}],
phase: None,
}],
phase: None,
}];
..Prompt::default()
};
let mut stream = client_session
.stream(
@@ -237,15 +239,17 @@ async fn responses_stream_includes_subagent_header_on_other() {
);
let mut client_session = client.new_session();
let mut prompt = Prompt::default();
prompt.input = vec![ResponseItem::Message {
id: None,
role: "user".into(),
content: vec![ContentItem::InputText {
text: "hello".into(),
let prompt = Prompt {
input: vec![ResponseItem::Message {
id: None,
role: "user".into(),
content: vec![ContentItem::InputText {
text: "hello".into(),
}],
phase: None,
}],
phase: None,
}];
..Prompt::default()
};
let mut stream = client_session
.stream(
@@ -352,15 +356,17 @@ async fn responses_respects_model_info_overrides_from_config() {
);
let mut client_session = client.new_session();
let mut prompt = Prompt::default();
prompt.input = vec![ResponseItem::Message {
id: None,
role: "user".into(),
content: vec![ContentItem::InputText {
text: "hello".into(),
let prompt = Prompt {
input: vec![ResponseItem::Message {
id: None,
role: "user".into(),
content: vec![ContentItem::InputText {
text: "hello".into(),
}],
phase: None,
}],
phase: None,
}];
..Prompt::default()
};
let mut stream = client_session
.stream(

View File

@@ -1744,9 +1744,10 @@ fn assistant_message_item(id: &str, text: &str) -> ResponseItem {
}
fn prompt_with_input(input: Vec<ResponseItem>) -> Prompt {
let mut prompt = Prompt::default();
prompt.input = input;
prompt
Prompt {
input,
..Prompt::default()
}
}
fn prompt_with_input_and_instructions(input: Vec<ResponseItem>, instructions: &str) -> Prompt {

View File

@@ -0,0 +1,6 @@
load("//:defs.bzl", "codex_rust_crate")
codex_rust_crate(
name = "kernel",
crate_name = "codex_kernel",
)

View File

@@ -0,0 +1,25 @@
[package]
edition.workspace = true
license.workspace = true
name = "codex-kernel"
version.workspace = true
[lib]
doctest = false
name = "codex_kernel"
path = "src/lib.rs"
[lints]
workspace = true
[dependencies]
codex-protocol = { workspace = true }
codex-tools = { workspace = true }
rand = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tokio = { workspace = true, features = ["time"] }
tokio-util = { workspace = true, features = ["rt"] }
[dev-dependencies]
pretty_assertions = { workspace = true }

View File

@@ -0,0 +1,17 @@
mod prompt;
mod sampling;
mod tools;
pub use prompt::Prompt;
pub use prompt::PromptConfig;
pub use sampling::PreparedSamplingRequest;
pub use sampling::SamplingLoopHost;
pub use sampling::SamplingRequestResult;
pub use sampling::run_sampling_request_loop;
pub use tools::CompletedResponseItem;
pub use tools::KernelToolCall;
pub use tools::KernelToolExecutor;
pub use tools::ToolCallError;
pub use tools::ToolConfig;
pub use tools::execute_tool_call_with_default_output;
pub use tools::response_input_to_response_item;

View File

@@ -0,0 +1,257 @@
use crate::ToolConfig;
use codex_protocol::config_types::Personality;
use codex_protocol::models::BaseInstructions;
use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::ResponseItem;
use codex_tools::ToolSpec;
use serde::Deserialize;
use serde_json::Value;
use std::collections::HashSet;
/// API request payload for a single model turn.
#[derive(Debug, Clone)]
pub struct Prompt {
pub input: Vec<ResponseItem>,
pub tools: Vec<ToolSpec>,
pub parallel_tool_calls: bool,
pub base_instructions: BaseInstructions,
pub personality: Option<Personality>,
pub output_schema: Option<Value>,
pub output_schema_strict: bool,
}
/// Retry-stable prompt settings used to rebuild a [`Prompt`] after tool calls or stream retries.
#[derive(Debug, Clone)]
pub struct PromptConfig {
pub base_instructions: BaseInstructions,
pub personality: Option<Personality>,
pub output_schema: Option<Value>,
pub output_schema_strict: bool,
}
impl Prompt {
pub fn get_formatted_input(&self) -> Vec<ResponseItem> {
let mut input = self.input.clone();
let is_freeform_apply_patch_tool_present = self.tools.iter().any(|tool| match tool {
ToolSpec::Freeform(f) => f.name == "apply_patch",
_ => false,
});
if is_freeform_apply_patch_tool_present {
reserialize_shell_outputs(&mut input);
}
input
}
}
impl PromptConfig {
pub fn build_prompt(&self, input: Vec<ResponseItem>, tool_config: &ToolConfig) -> Prompt {
Prompt {
input,
tools: tool_config.tools.clone(),
parallel_tool_calls: tool_config.parallel_tool_calls,
base_instructions: self.base_instructions.clone(),
personality: self.personality,
output_schema: self.output_schema.clone(),
output_schema_strict: self.output_schema_strict,
}
}
}
impl Default for Prompt {
fn default() -> Self {
Self {
input: Vec::new(),
tools: Vec::new(),
parallel_tool_calls: false,
base_instructions: BaseInstructions::default(),
personality: None,
output_schema: None,
output_schema_strict: true,
}
}
}
impl Default for PromptConfig {
fn default() -> Self {
Self {
base_instructions: BaseInstructions::default(),
personality: None,
output_schema: None,
output_schema_strict: true,
}
}
}
fn reserialize_shell_outputs(items: &mut [ResponseItem]) {
let mut shell_call_ids: HashSet<String> = HashSet::new();
items.iter_mut().for_each(|item| match item {
ResponseItem::LocalShellCall { call_id, id, .. } => {
if let Some(identifier) = call_id.clone().or_else(|| id.clone()) {
shell_call_ids.insert(identifier);
}
}
ResponseItem::CustomToolCall {
id: _,
status: _,
call_id,
name,
input: _,
} => {
if name == "apply_patch" {
shell_call_ids.insert(call_id.clone());
}
}
ResponseItem::FunctionCall { name, call_id, .. }
if is_shell_tool_name(name) || name == "apply_patch" =>
{
shell_call_ids.insert(call_id.clone());
}
ResponseItem::FunctionCallOutput {
call_id, output, ..
}
| ResponseItem::CustomToolCallOutput {
call_id, output, ..
} => {
if shell_call_ids.remove(call_id)
&& let Some(structured) = output
.text_content()
.and_then(parse_structured_shell_output)
{
output.body = FunctionCallOutputBody::Text(structured);
}
}
_ => {}
});
}
fn is_shell_tool_name(name: &str) -> bool {
matches!(name, "shell" | "container.exec")
}
#[derive(Deserialize)]
struct ExecOutputJson {
output: String,
metadata: ExecOutputMetadataJson,
}
#[derive(Deserialize)]
struct ExecOutputMetadataJson {
exit_code: i32,
duration_seconds: f32,
}
fn parse_structured_shell_output(raw: &str) -> Option<String> {
let parsed: ExecOutputJson = serde_json::from_str(raw).ok()?;
Some(build_structured_output(&parsed))
}
fn build_structured_output(parsed: &ExecOutputJson) -> String {
let mut sections = Vec::new();
let exit_code = parsed.metadata.exit_code;
sections.push(format!("Exit code: {exit_code}"));
let duration_seconds = parsed.metadata.duration_seconds;
sections.push(format!("Wall time: {duration_seconds} seconds"));
let mut output = parsed.output.clone();
if let Some((stripped, total_lines)) = strip_total_output_header(&parsed.output) {
sections.push(format!("Total output lines: {total_lines}"));
output = stripped.to_string();
}
sections.push("Output:".to_string());
sections.push(output);
sections.join("\n")
}
fn strip_total_output_header(output: &str) -> Option<(&str, u32)> {
let after_prefix = output.strip_prefix("Total output lines: ")?;
let (total_segment, remainder) = after_prefix.split_once('\n')?;
let total_lines = total_segment.parse::<u32>().ok()?;
let remainder = remainder.strip_prefix('\n').unwrap_or(remainder);
Some((remainder, total_lines))
}
#[cfg(test)]
mod tests {
use super::Prompt;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseItem;
use codex_tools::FreeformTool;
use codex_tools::FreeformToolFormat;
use codex_tools::ToolSpec;
use pretty_assertions::assert_eq;
#[test]
fn reserializes_shell_outputs_for_function_and_custom_tool_calls() {
let raw_output = r#"{"output":"hello","metadata":{"exit_code":0,"duration_seconds":0.5}}"#;
let expected_output = "Exit code: 0\nWall time: 0.5 seconds\nOutput:\nhello";
let prompt = Prompt {
input: vec![
ResponseItem::FunctionCall {
id: None,
name: "shell".to_string(),
namespace: None,
arguments: "{}".to_string(),
call_id: "call-1".to_string(),
},
ResponseItem::FunctionCallOutput {
call_id: "call-1".to_string(),
output: FunctionCallOutputPayload::from_text(raw_output.to_string()),
},
ResponseItem::CustomToolCall {
id: None,
status: None,
call_id: "call-2".to_string(),
name: "apply_patch".to_string(),
input: "*** Begin Patch".to_string(),
},
ResponseItem::CustomToolCallOutput {
call_id: "call-2".to_string(),
name: None,
output: FunctionCallOutputPayload::from_text(raw_output.to_string()),
},
],
tools: vec![ToolSpec::Freeform(FreeformTool {
name: "apply_patch".to_string(),
description: "patch".to_string(),
format: FreeformToolFormat {
r#type: "grammar".to_string(),
syntax: "lark".to_string(),
definition: "patch".to_string(),
},
})],
..Prompt::default()
};
assert_eq!(
prompt.get_formatted_input(),
vec![
ResponseItem::FunctionCall {
id: None,
name: "shell".to_string(),
namespace: None,
arguments: "{}".to_string(),
call_id: "call-1".to_string(),
},
ResponseItem::FunctionCallOutput {
call_id: "call-1".to_string(),
output: FunctionCallOutputPayload::from_text(expected_output.to_string()),
},
ResponseItem::CustomToolCall {
id: None,
status: None,
call_id: "call-2".to_string(),
name: "apply_patch".to_string(),
input: "*** Begin Patch".to_string(),
},
ResponseItem::CustomToolCallOutput {
call_id: "call-2".to_string(),
name: None,
output: FunctionCallOutputPayload::from_text(expected_output.to_string()),
},
]
);
}
}
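
For reference, a minimal test-style sketch of the transformation performed by parse_structured_shell_output above when the shell output carries the "Total output lines" header; the sample payload and expected string below are illustrative, not taken from the commit:

    #[test]
    fn reserializes_output_with_total_lines_header() {
        // Hypothetical payload: the leading "Total output lines: 12" header is lifted into
        // its own section and stripped from the output body.
        let raw = r#"{"output":"Total output lines: 12\n\ntruncated body","metadata":{"exit_code":1,"duration_seconds":2.5}}"#;
        let expected =
            "Exit code: 1\nWall time: 2.5 seconds\nTotal output lines: 12\nOutput:\ntruncated body";
        assert_eq!(
            super::parse_structured_shell_output(raw).as_deref(),
            Some(expected)
        );
    }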

View File

@@ -0,0 +1,170 @@
use crate::KernelToolExecutor;
use crate::Prompt;
use crate::PromptConfig;
use crate::ToolConfig;
use codex_protocol::error::CodexErr;
use codex_protocol::error::Result;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::RateLimitSnapshot;
use rand::Rng;
use std::time::Duration;
use tokio_util::sync::CancellationToken;
const INITIAL_DELAY_MS: u64 = 200;
const BACKOFF_FACTOR: f64 = 2.0;
/// Host-facing bundle created once before the kernel begins retrying a logical sampling request.
///
/// Implementations should place retry-stable state here so retries can rebuild prompts without
/// repeating expensive setup work such as tool routing or per-turn worker startup.
pub struct PreparedSamplingRequest<Runtime, Tools> {
pub prompt_config: PromptConfig,
pub tool_config: ToolConfig,
pub runtime: Runtime,
pub tools: Tools,
}
/// Final outcome of a logical sampling request after the model/tool loop reaches a stable state.
#[derive(Debug)]
pub struct SamplingRequestResult {
pub needs_follow_up: bool,
pub last_agent_message: Option<String>,
}
/// Host adapter for the kernel's retryable sampling loop.
///
/// Implementations provide the concrete model transport, tool runtime, history access, and
/// user-facing retry notifications while allowing `codex-kernel` to stay agnostic of local
/// filesystem, network, or tool-execution details.
#[allow(async_fn_in_trait)]
pub trait SamplingLoopHost {
type ClientSession;
type Runtime;
type Tools: KernelToolExecutor;
async fn prepare_sampling_request(
&self,
input: &[ResponseItem],
cancellation_token: &CancellationToken,
) -> Result<PreparedSamplingRequest<Self::Runtime, Self::Tools>>;
async fn history_prompt_input(&self) -> Vec<ResponseItem>;
async fn run_single_sampling_request(
&self,
runtime: &Self::Runtime,
tools: &Self::Tools,
client_session: &mut Self::ClientSession,
prompt: &Prompt,
cancellation_token: CancellationToken,
) -> Result<SamplingRequestResult>;
fn stream_max_retries(&self) -> u64;
fn try_switch_fallback_transport(&self, client_session: &mut Self::ClientSession) -> bool;
fn should_notify_stream_retry(&self, retries: u64, err: &CodexErr) -> bool;
async fn handle_context_window_exceeded(&self);
async fn handle_usage_limit_reached(&self, rate_limits: Option<RateLimitSnapshot>);
async fn notify_fallback_to_http(&self, err: &CodexErr);
async fn notify_stream_retry(
&self,
retries: u64,
max_retries: u64,
delay: Duration,
err: &CodexErr,
);
}
pub async fn run_sampling_request_loop<Host>(
host: &Host,
client_session: &mut Host::ClientSession,
initial_input: Vec<ResponseItem>,
cancellation_token: CancellationToken,
) -> Result<SamplingRequestResult>
where
Host: SamplingLoopHost,
{
let prepared = host
.prepare_sampling_request(initial_input.as_slice(), &cancellation_token)
.await?;
let mut retries = 0;
let mut initial_input = Some(initial_input);
loop {
let prompt_input = if let Some(input) = initial_input.take() {
input
} else {
host.history_prompt_input().await
};
let prompt = prepared
.prompt_config
.build_prompt(prompt_input, &prepared.tool_config);
let err = match host
.run_single_sampling_request(
&prepared.runtime,
&prepared.tools,
client_session,
&prompt,
cancellation_token.child_token(),
)
.await
{
Ok(output) => {
return Ok(output);
}
Err(CodexErr::ContextWindowExceeded) => {
host.handle_context_window_exceeded().await;
return Err(CodexErr::ContextWindowExceeded);
}
Err(CodexErr::UsageLimitReached(error)) => {
host.handle_usage_limit_reached(
error
.rate_limits
.as_ref()
.map(|snapshot| (**snapshot).clone()),
)
.await;
return Err(CodexErr::UsageLimitReached(error));
}
Err(err) => err,
};
if !err.is_retryable() {
return Err(err);
}
let max_retries = host.stream_max_retries();
if retries >= max_retries && host.try_switch_fallback_transport(client_session) {
host.notify_fallback_to_http(&err).await;
retries = 0;
continue;
}
if retries < max_retries {
retries += 1;
let delay = match &err {
CodexErr::Stream(_, requested_delay) => {
requested_delay.unwrap_or_else(|| backoff(retries))
}
_ => backoff(retries),
};
if host.should_notify_stream_retry(retries, &err) {
host.notify_stream_retry(retries, max_retries, delay, &err)
.await;
}
tokio::time::sleep(delay).await;
} else {
return Err(err);
}
}
}
fn backoff(attempt: u64) -> Duration {
let exp = BACKOFF_FACTOR.powi(attempt.saturating_sub(1) as i32);
let base = (INITIAL_DELAY_MS as f64 * exp) as u64;
let jitter = rand::rng().random_range(0.9..1.1);
Duration::from_millis((base as f64 * jitter) as u64)
}
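
As a quick worked example of the retry schedule implemented by backoff above: with INITIAL_DELAY_MS = 200 and BACKOFF_FACTOR = 2.0, the nominal delay doubles on each retry before the ±10% jitter is applied. A standalone, illustrative sketch of the resulting schedule:

    fn main() {
        // Nominal (pre-jitter) delays: 200 ms, 400 ms, 800 ms, 1600 ms, ...
        for attempt in 1u64..=4 {
            let nominal_ms = 200u64 * 2u64.pow((attempt - 1) as u32);
            println!("retry {attempt}: ~{nominal_ms} ms, then +/-10% jitter");
        }
    }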

View File

@@ -0,0 +1,253 @@
use codex_protocol::error::CodexErr;
use codex_protocol::error::Result;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
use codex_tools::ToolSpec;
use std::fmt;
use tokio_util::sync::CancellationToken;
/// Retry-stable tool registration for a logical model request.
///
/// The host chooses which tools are exposed to the model and whether the model
/// may issue them in parallel. `codex-kernel` treats this as opaque model-facing
/// configuration and uses it to rebuild prompts across retries.
#[derive(Debug, Clone, Default)]
pub struct ToolConfig {
pub tools: Vec<ToolSpec>,
pub parallel_tool_calls: bool,
}
/// Metadata a host-specific tool call must expose so `codex-kernel` can
/// synthesize a model-visible failure response when execution cannot complete.
pub trait KernelToolCall: Clone + Send + Sync {
fn error_response(&self, message: String) -> ResponseInputItem;
}
/// Non-fatal tool execution failures are surfaced back to the model as a
/// complementary output item, while fatal failures abort the turn.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolCallError {
RespondToModel(String),
Fatal(String),
}
impl fmt::Display for ToolCallError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::RespondToModel(message) => write!(f, "{message}"),
Self::Fatal(message) => write!(f, "Fatal error: {message}"),
}
}
}
/// Classification of a completed model output item relative to the host's tool
/// catalog.
pub enum CompletedResponseItem<Call> {
NonTool(ResponseItem),
ToolCall {
item: ResponseItem,
call: Call,
},
ImmediateResponse {
item: ResponseItem,
response: ResponseInputItem,
},
}
/// Host-provided tool adapter used by `codex-kernel`'s default model/tool loop.
///
/// Implementations own tool-call parsing and execution, while
/// `codex-kernel` owns the generic behavior of attempting a complementary
/// tool output item for every tool call the model emits.
#[allow(async_fn_in_trait)]
pub trait KernelToolExecutor {
type Call: KernelToolCall;
async fn classify_response_item(
&self,
item: ResponseItem,
) -> Result<CompletedResponseItem<Self::Call>>;
async fn execute_tool_call(
&self,
call: Self::Call,
cancellation_token: CancellationToken,
) -> std::result::Result<ResponseInputItem, ToolCallError>;
}
pub async fn execute_tool_call_with_default_output<Executor>(
executor: &Executor,
call: Executor::Call,
cancellation_token: CancellationToken,
) -> Result<ResponseInputItem>
where
Executor: KernelToolExecutor,
{
match executor
.execute_tool_call(call.clone(), cancellation_token)
.await
{
Ok(response) => Ok(response),
Err(ToolCallError::RespondToModel(message)) => Ok(call.error_response(message)),
Err(ToolCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
}
}
pub fn response_input_to_response_item(input: &ResponseInputItem) -> Option<ResponseItem> {
match input {
ResponseInputItem::FunctionCallOutput { call_id, output } => {
if call_id.is_empty() {
return None;
}
Some(ResponseItem::FunctionCallOutput {
call_id: call_id.clone(),
output: output.clone(),
})
}
ResponseInputItem::CustomToolCallOutput {
call_id,
name,
output,
} => {
if call_id.is_empty() {
return None;
}
Some(ResponseItem::CustomToolCallOutput {
call_id: call_id.clone(),
name: name.clone(),
output: output.clone(),
})
}
ResponseInputItem::McpToolCallOutput { call_id, output } => {
if call_id.is_empty() {
return None;
}
let output = output.as_function_call_output_payload();
Some(ResponseItem::FunctionCallOutput {
call_id: call_id.clone(),
output,
})
}
ResponseInputItem::ToolSearchOutput {
call_id,
status,
execution,
tools,
} => {
if call_id.is_empty() {
return None;
}
Some(ResponseItem::ToolSearchOutput {
call_id: Some(call_id.clone()),
status: status.clone(),
execution: execution.clone(),
tools: tools.clone(),
})
}
_ => None,
}
}
#[cfg(test)]
mod tests {
use super::CompletedResponseItem;
use super::KernelToolCall;
use super::KernelToolExecutor;
use super::ToolCallError;
use super::execute_tool_call_with_default_output;
use super::response_input_to_response_item;
use codex_protocol::error::Result;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
use pretty_assertions::assert_eq;
use tokio_util::sync::CancellationToken;
#[derive(Clone, Debug, PartialEq, Eq)]
struct TestCall {
call_id: String,
}
impl KernelToolCall for TestCall {
fn error_response(&self, message: String) -> ResponseInputItem {
ResponseInputItem::FunctionCallOutput {
call_id: self.call_id.clone(),
output: FunctionCallOutputPayload::from_text(message),
}
}
}
struct TestExecutor;
impl KernelToolExecutor for TestExecutor {
type Call = TestCall;
async fn classify_response_item(
&self,
item: ResponseItem,
) -> Result<CompletedResponseItem<Self::Call>> {
Ok(CompletedResponseItem::NonTool(item))
}
async fn execute_tool_call(
&self,
call: Self::Call,
_cancellation_token: CancellationToken,
) -> std::result::Result<ResponseInputItem, ToolCallError> {
match call.call_id.as_str() {
"ok" => Ok(ResponseInputItem::FunctionCallOutput {
call_id: call.call_id,
output: FunctionCallOutputPayload::from_text("done".to_string()),
}),
"retry" => Err(ToolCallError::RespondToModel("try again".to_string())),
_ => Err(ToolCallError::Fatal("boom".to_string())),
}
}
}
#[tokio::test]
async fn execute_tool_call_with_default_output_surfaces_non_fatal_errors() {
let response = execute_tool_call_with_default_output(
&TestExecutor,
TestCall {
call_id: "retry".to_string(),
},
CancellationToken::new(),
)
.await
.expect("tool response");
assert_eq!(
response,
ResponseInputItem::FunctionCallOutput {
call_id: "retry".to_string(),
output: FunctionCallOutputPayload::from_text("try again".to_string()),
}
);
}
#[tokio::test]
async fn execute_tool_call_with_default_output_propagates_fatal_errors() {
let err = execute_tool_call_with_default_output(
&TestExecutor,
TestCall {
call_id: "fatal".to_string(),
},
CancellationToken::new(),
)
.await
.expect_err("fatal tool error");
assert_eq!(err.to_string(), "Fatal error: boom");
}
#[test]
fn response_input_to_response_item_skips_empty_call_ids() {
let item = response_input_to_response_item(&ResponseInputItem::FunctionCallOutput {
call_id: String::new(),
output: FunctionCallOutputPayload::from_text("boom".to_string()),
});
assert_eq!(item, None);
}
}
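
To make the intended composition concrete, here is a minimal sketch (the drive_item helper is hypothetical, not part of the commit) of how a host loop can combine classify_response_item with execute_tool_call_with_default_output, mirroring the handle_output_item_done rewrite earlier in this diff:

    use codex_kernel::CompletedResponseItem;
    use codex_kernel::KernelToolExecutor;
    use codex_kernel::execute_tool_call_with_default_output;
    use codex_protocol::error::Result;
    use codex_protocol::models::ResponseInputItem;
    use codex_protocol::models::ResponseItem;
    use tokio_util::sync::CancellationToken;

    // Hypothetical helper: classify one completed model item and, when it is a tool call,
    // run it so that non-fatal failures become a model-visible error output item.
    async fn drive_item<E: KernelToolExecutor>(
        executor: &E,
        item: ResponseItem,
        cancellation_token: CancellationToken,
    ) -> Result<Option<ResponseInputItem>> {
        match executor.classify_response_item(item).await? {
            // Messages/reasoning need no synthesized response.
            CompletedResponseItem::NonTool(_) => Ok(None),
            // Classification already produced a response (e.g. a guardrail error).
            CompletedResponseItem::ImmediateResponse { response, .. } => Ok(Some(response)),
            // A real tool call: execute it via the kernel's default error handling.
            CompletedResponseItem::ToolCall { call, .. } => {
                execute_tool_call_with_default_output(executor, call, cancellation_token)
                    .await
                    .map(Some)
            }
        }
    }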

View File

@@ -285,25 +285,27 @@ mod job {
let (rollout_items, _, _) = RolloutRecorder::load_rollout_items(rollout_path).await?;
let rollout_contents = serialize_filtered_rollout_response_items(&rollout_items)?;
let mut prompt = Prompt::default();
prompt.input = vec![ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: build_stage_one_input_message(
&stage_one_context.model_info,
rollout_path,
rollout_cwd,
&rollout_contents,
)?,
let prompt = Prompt {
input: vec![ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: build_stage_one_input_message(
&stage_one_context.model_info,
rollout_path,
rollout_cwd,
&rollout_contents,
)?,
}],
phase: None,
}],
phase: None,
}];
prompt.base_instructions = BaseInstructions {
text: crate::stage_one::PROMPT.to_string(),
base_instructions: BaseInstructions {
text: crate::stage_one::PROMPT.to_string(),
},
output_schema: Some(output_schema()),
output_schema_strict: true,
..Prompt::default()
};
prompt.output_schema = Some(output_schema());
prompt.output_schema_strict = true;
let (result, token_usage) = context
.stream_stage_one_prompt(config, &prompt, stage_one_context)