Compare commits

...

1 Commits

Author SHA1 Message Date
viyatb-oai
b21c3253e0 feat(core): add turn-scoped taint tracking 2026-03-03 13:10:12 -08:00
29 changed files with 641 additions and 9 deletions

5
codex-rs/Cargo.lock generated
View File

@@ -1717,6 +1717,7 @@ dependencies = [
"codex-shell-escalation",
"codex-skills",
"codex-state",
"codex-taint",
"codex-utils-absolute-path",
"codex-utils-cargo-bin",
"codex-utils-home-dir",
@@ -2297,6 +2298,10 @@ dependencies = [
"uds_windows",
]
[[package]]
name = "codex-taint"
version = "0.0.0"
[[package]]
name = "codex-tui"
version = "0.0.0"

View File

@@ -38,6 +38,7 @@ members = [
"rmcp-client",
"responses-api-proxy",
"stdio-to-uds",
"taint",
"otel",
"tui",
"utils/absolute-path",
@@ -116,6 +117,7 @@ codex-shell-command = { path = "shell-command" }
codex-shell-escalation = { path = "shell-escalation" }
codex-skills = { path = "skills" }
codex-state = { path = "state" }
codex-taint = { path = "taint" }
codex-stdio-to-uds = { path = "stdio-to-uds" }
codex-tui = { path = "tui" }
codex-utils-absolute-path = { path = "utils/absolute-path" }

View File

@@ -34,6 +34,7 @@ codex-client = { workspace = true }
codex-config = { workspace = true }
codex-shell-command = { workspace = true }
codex-skills = { workspace = true }
codex-taint = { workspace = true }
codex-execpolicy = { workspace = true }
codex-file-search = { workspace = true }
codex-git = { workspace = true }

View File

@@ -29,6 +29,7 @@ use crate::features::FEATURES;
use crate::features::Feature;
use crate::features::Features;
use crate::features::maybe_push_unstable_features_warning;
use crate::function_tool::FunctionCallError;
use crate::models_manager::manager::ModelsManager;
use crate::parse_command::parse_command;
use crate::parse_turn_item;
@@ -92,6 +93,9 @@ use codex_protocol::request_user_input::RequestUserInputResponse;
use codex_protocol::skill_approval::SkillApprovalResponse;
use codex_rmcp_client::ElicitationResponse;
use codex_rmcp_client::OAuthCredentialsStoreMode;
use codex_taint::TaintEffect;
use codex_taint::TaintSink;
use codex_taint::TaintState;
use futures::future::BoxFuture;
use futures::prelude::*;
use futures::stream::FuturesOrdered;
@@ -3370,10 +3374,65 @@ impl Session {
}
let mut turn_state = active_turn.turn_state.lock().await;
turn_state.reset_taint();
turn_state.push_pending_input(input.into());
Ok(active_turn_id.clone())
}
pub(crate) async fn apply_taint_effect(&self, turn_id: &str, effect: TaintEffect) {
if matches!(effect, TaintEffect::None) {
return;
}
let active = self.active_turn.lock().await;
let Some(active_turn) = active.as_ref() else {
return;
};
let Some((active_turn_id, _)) = active_turn.tasks.first() else {
return;
};
if active_turn_id != turn_id {
return;
}
let mut turn_state = active_turn.turn_state.lock().await;
tracing::debug!(%turn_id, ?effect, "applying taint effect");
turn_state.apply_taint_effect(effect);
}
pub(crate) async fn current_taint(&self, turn_id: &str) -> TaintState {
let active = self.active_turn.lock().await;
let Some(active_turn) = active.as_ref() else {
return TaintState::default();
};
let Some((active_turn_id, _)) = active_turn.tasks.first() else {
return TaintState::default();
};
if active_turn_id != turn_id {
return TaintState::default();
}
let turn_state = active_turn.turn_state.lock().await;
turn_state.current_taint()
}
pub(crate) async fn ensure_taint_sink_allowed(
&self,
turn_id: &str,
sink: TaintSink,
) -> Result<(), FunctionCallError> {
let taint = self.current_taint(turn_id).await;
if let Err(err) = taint.check_sink(sink) {
tracing::warn!(%turn_id, ?sink, recent_sources = ?err.recent_sources, "blocked taint sink");
return Err(FunctionCallError::RespondToModel(err.to_string()));
}
Ok(())
}
/// Returns the input if there was no task running to inject into
pub async fn inject_response_items(
&self,

View File

@@ -27,6 +27,7 @@ use codex_protocol::request_user_input::RequestUserInputArgs;
use codex_protocol::request_user_input::RequestUserInputQuestion;
use codex_protocol::request_user_input::RequestUserInputQuestionOption;
use codex_protocol::request_user_input::RequestUserInputResponse;
use codex_taint::TaintSink;
use rmcp::model::ToolAnnotations;
use serde::Serialize;
use std::sync::Arc;
@@ -102,6 +103,25 @@ pub(crate) async fn handle_mcp_tool_call(
return ResponseInputItem::McpToolCallOutput { call_id, result };
}
let is_read_only_tool = metadata
.as_ref()
.and_then(|metadata| metadata.annotations.as_ref())
.and_then(|annotations| annotations.read_only_hint)
== Some(true);
if !is_read_only_tool
&& let Err(err) = sess
.ensure_taint_sink_allowed(&turn_context.sub_id, TaintSink::external_dispatch())
.await
{
return ResponseInputItem::FunctionCallOutput {
call_id: call_id.clone(),
output: FunctionCallOutputPayload {
body: FunctionCallOutputBody::Text(err.to_string()),
success: Some(false),
},
};
}
if let Some(decision) = maybe_request_mcp_tool_approval(
sess.as_ref(),
turn_context,

View File

@@ -12,6 +12,8 @@ use codex_protocol::dynamic_tools::DynamicToolResponse;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::request_user_input::RequestUserInputResponse;
use codex_protocol::skill_approval::SkillApprovalResponse;
use codex_taint::TaintEffect;
use codex_taint::TaintState;
use tokio::sync::oneshot;
use crate::codex::TurnContext;
@@ -75,6 +77,7 @@ pub(crate) struct TurnState {
pending_skill_approvals: HashMap<String, oneshot::Sender<SkillApprovalResponse>>,
pending_dynamic_tools: HashMap<String, oneshot::Sender<DynamicToolResponse>>,
pending_input: Vec<ResponseInputItem>,
control_taint: TaintState,
}
impl TurnState {
@@ -150,6 +153,18 @@ impl TurnState {
self.pending_input.push(input);
}
pub(crate) fn apply_taint_effect(&mut self, effect: TaintEffect) {
self.control_taint.apply(effect);
}
pub(crate) fn current_taint(&self) -> TaintState {
self.control_taint.clone()
}
pub(crate) fn reset_taint(&mut self) {
self.control_taint.reset();
}
pub(crate) fn take_pending_input(&mut self) -> Vec<ResponseInputItem> {
if self.pending_input.is_empty() {
Vec::with_capacity(0)
@@ -172,3 +187,31 @@ impl ActiveTurn {
ts.clear_pending();
}
}
#[cfg(test)]
mod tests {
use super::*;
use codex_taint::TaintLabel;
use codex_taint::TaintSource;
use pretty_assertions::assert_eq;
#[test]
fn turn_state_applies_and_resets_taint() {
let mut state = TurnState::default();
state.apply_taint_effect(TaintEffect::Mark {
label: TaintLabel::WorkspaceContent,
source: TaintSource::ReadFile,
});
let taint = state.current_taint();
assert_eq!(
taint.labels().iter().copied().collect::<Vec<_>>(),
vec![TaintLabel::WorkspaceContent]
);
assert_eq!(taint.recent_sources(), &[TaintSource::ReadFile]);
state.apply_taint_effect(TaintEffect::Reset);
assert!(state.current_taint().is_clean());
}
}

View File

@@ -18,6 +18,9 @@ use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
use codex_taint::TaintEffect;
use codex_taint::TaintLabel;
use codex_taint::TaintSource;
use futures::Future;
use tracing::debug;
use tracing::instrument;
@@ -78,6 +81,17 @@ pub(crate) async fn handle_output_item_done(
}
// No tool call: convert messages/reasoning into turn items and mark them as complete.
Ok(None) => {
if matches!(item, ResponseItem::WebSearchCall { .. }) {
ctx.sess
.apply_taint_effect(
&ctx.turn_context.sub_id,
TaintEffect::Mark {
label: TaintLabel::ExternalContent,
source: TaintSource::WebSearch,
},
)
.await;
}
if let Some(turn_item) = handle_non_tool_response_item(&item, plan_mode).await {
if previously_active_item.is_none() {
ctx.sess

View File

@@ -5,6 +5,9 @@ use async_trait::async_trait;
use codex_async_utils::CancelErr;
use codex_async_utils::OrCancelExt;
use codex_protocol::user_input::UserInput;
use codex_taint::TaintEffect;
use codex_taint::TaintLabel;
use codex_taint::TaintSource;
use tokio_util::sync::CancellationToken;
use tracing::error;
use uuid::Uuid;
@@ -322,6 +325,16 @@ async fn persist_user_shell_output(
_ => unreachable!("user shell command output record should always be a message"),
};
session
.apply_taint_effect(
&turn_context.sub_id,
TaintEffect::Mark {
label: TaintLabel::WorkspaceContent,
source: TaintSource::UserShellOutput,
},
)
.await;
if let Err(items) = session
.inject_response_items(vec![response_input_item])
.await

View File

@@ -9,6 +9,7 @@ use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ShellToolCallParams;
use codex_taint::TaintEffect;
use codex_utils_string::take_bytes_at_char_boundary;
use std::borrow::Cow;
use std::sync::Arc;
@@ -68,9 +69,11 @@ pub enum ToolOutput {
// or structured content items.
body: FunctionCallOutputBody,
success: Option<bool>,
taint_effect: TaintEffect,
},
Mcp {
result: Result<CallToolResult, String>,
taint_effect: TaintEffect,
},
}
@@ -80,20 +83,32 @@ impl ToolOutput {
ToolOutput::Function { body, .. } => {
telemetry_preview(&body.to_text().unwrap_or_default())
}
ToolOutput::Mcp { result } => format!("{result:?}"),
ToolOutput::Mcp { result, .. } => format!("{result:?}"),
}
}
pub fn success_for_logging(&self) -> bool {
match self {
ToolOutput::Function { success, .. } => success.unwrap_or(true),
ToolOutput::Mcp { result } => result.is_ok(),
ToolOutput::Mcp { result, .. } => result.is_ok(),
}
}
pub fn taint_effect(&self) -> &TaintEffect {
match self {
ToolOutput::Function { taint_effect, .. } | ToolOutput::Mcp { taint_effect, .. } => {
taint_effect
}
}
}
pub fn into_response(self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
match self {
ToolOutput::Function { body, success } => {
ToolOutput::Function {
body,
success,
taint_effect: _,
} => {
// `custom_tool_call` is the Responses API item type for freeform
// tools (`ToolSpec::Freeform`, e.g. freeform `apply_patch`).
// Those payloads must round-trip as `custom_tool_call_output`
@@ -115,7 +130,10 @@ impl ToolOutput {
}
}
// Direct MCP response path for MCP tool result envelopes.
ToolOutput::Mcp { result } => ResponseInputItem::McpToolCallOutput {
ToolOutput::Mcp {
result,
taint_effect: _,
} => ResponseInputItem::McpToolCallOutput {
call_id: call_id.to_string(),
result,
},
@@ -177,6 +195,7 @@ mod tests {
let response = ToolOutput::Function {
body: FunctionCallOutputBody::Text("patched".to_string()),
success: Some(true),
taint_effect: TaintEffect::None,
}
.into_response("call-42", &payload);
@@ -197,6 +216,7 @@ mod tests {
let response = ToolOutput::Function {
body: FunctionCallOutputBody::Text("ok".to_string()),
success: Some(true),
taint_effect: TaintEffect::None,
}
.into_response("fn-1", &payload);
@@ -229,6 +249,7 @@ mod tests {
},
]),
success: Some(true),
taint_effect: TaintEffect::None,
}
.into_response("call-99", &payload);
@@ -250,6 +271,7 @@ mod tests {
},
]),
success: Some(true),
taint_effect: TaintEffect::None,
};
assert_eq!(output.log_preview(), "preview");
@@ -286,4 +308,15 @@ mod tests {
assert!(lines.len() <= TELEMETRY_PREVIEW_MAX_LINES + 1);
assert_eq!(lines.last(), Some(&TELEMETRY_PREVIEW_TRUNCATION_NOTICE));
}
#[test]
fn taint_effect_defaults_to_none_when_unmarked() {
let output = ToolOutput::Function {
body: FunctionCallOutputBody::Text("ok".to_string()),
success: Some(true),
taint_effect: TaintEffect::None,
};
assert_eq!(output.taint_effect(), &TaintEffect::None);
}
}

View File

@@ -460,6 +460,7 @@ mod spawn_agents_on_csv {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text(content),
success: Some(true),
taint_effect: codex_taint::TaintEffect::None,
})
}
}
@@ -509,6 +510,7 @@ mod report_agent_job_result {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text(content),
success: Some(true),
taint_effect: codex_taint::TaintEffect::None,
})
}
}

View File

@@ -114,6 +114,7 @@ impl ToolHandler for ApplyPatchHandler {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text(content),
success: Some(true),
taint_effect: codex_taint::TaintEffect::None,
})
}
InternalApplyPatchInvocation::DelegateToExec(apply) => {
@@ -166,6 +167,7 @@ impl ToolHandler for ApplyPatchHandler {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text(content),
success: Some(true),
taint_effect: codex_taint::TaintEffect::None,
})
}
}
@@ -217,6 +219,7 @@ pub(crate) async fn intercept_apply_patch(
Ok(Some(ToolOutput::Function {
body: FunctionCallOutputBody::Text(content),
success: Some(true),
taint_effect: codex_taint::TaintEffect::None,
}))
}
InternalApplyPatchInvocation::DelegateToExec(apply) => {
@@ -268,6 +271,7 @@ pub(crate) async fn intercept_apply_patch(
Ok(Some(ToolOutput::Function {
body: FunctionCallOutputBody::Text(content),
success: Some(true),
taint_effect: codex_taint::TaintEffect::None,
}))
}
}

View File

@@ -13,6 +13,10 @@ use codex_protocol::dynamic_tools::DynamicToolResponse;
use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::FunctionCallOutputContentItem;
use codex_protocol::protocol::EventMsg;
use codex_taint::TaintEffect;
use codex_taint::TaintLabel;
use codex_taint::TaintSink;
use codex_taint::TaintSource;
use serde_json::Value;
use tokio::sync::oneshot;
use tracing::warn;
@@ -49,6 +53,9 @@ impl ToolHandler for DynamicToolHandler {
};
let args: Value = parse_arguments(&arguments)?;
session
.ensure_taint_sink_allowed(&turn.sub_id, TaintSink::external_dispatch())
.await?;
let response = request_dynamic_tool(&session, turn.as_ref(), call_id, tool_name, args)
.await
.ok_or_else(|| {
@@ -70,6 +77,10 @@ impl ToolHandler for DynamicToolHandler {
Ok(ToolOutput::Function {
body,
success: Some(success),
taint_effect: TaintEffect::Mark {
label: TaintLabel::ExternalContent,
source: TaintSource::DynamicTool,
},
})
}
}

View File

@@ -1,4 +1,7 @@
use codex_protocol::models::FunctionCallOutputBody;
use codex_taint::TaintEffect;
use codex_taint::TaintLabel;
use codex_taint::TaintSource;
use std::path::Path;
use std::time::Duration;
@@ -89,11 +92,19 @@ impl ToolHandler for GrepFilesHandler {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text("No matches found.".to_string()),
success: Some(false),
taint_effect: TaintEffect::Mark {
label: TaintLabel::WorkspaceContent,
source: TaintSource::GrepFiles,
},
})
} else {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text(search_results.join("\n")),
success: Some(true),
taint_effect: TaintEffect::Mark {
label: TaintLabel::WorkspaceContent,
source: TaintSource::GrepFiles,
},
})
}
}

View File

@@ -172,6 +172,7 @@ impl ToolHandler for JsReplHandler {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::ContentItems(items),
success: Some(true),
taint_effect: codex_taint::TaintEffect::None,
})
}
}
@@ -193,6 +194,7 @@ impl ToolHandler for JsReplResetHandler {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text("js_repl kernel reset".to_string()),
success: Some(true),
taint_effect: codex_taint::TaintEffect::None,
})
}
}

View File

@@ -1,4 +1,7 @@
use codex_protocol::models::FunctionCallOutputBody;
use codex_taint::TaintEffect;
use codex_taint::TaintLabel;
use codex_taint::TaintSource;
use std::collections::VecDeque;
use std::ffi::OsStr;
use std::fs::FileType;
@@ -105,6 +108,10 @@ impl ToolHandler for ListDirHandler {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text(output.join("\n")),
success: Some(true),
taint_effect: TaintEffect::Mark {
label: TaintLabel::WorkspaceContent,
source: TaintSource::ListDir,
},
})
}
}

View File

@@ -9,6 +9,9 @@ use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use codex_protocol::models::ResponseInputItem;
use codex_taint::TaintEffect;
use codex_taint::TaintLabel;
use codex_taint::TaintSource;
pub struct McpHandler;
@@ -54,11 +57,28 @@ impl ToolHandler for McpHandler {
.await;
match response {
ResponseInputItem::McpToolCallOutput { result, .. } => Ok(ToolOutput::Mcp { result }),
ResponseInputItem::McpToolCallOutput { result, .. } => {
let taint_effect = if result.is_ok() {
TaintEffect::Mark {
label: TaintLabel::ExternalContent,
source: TaintSource::McpTool,
}
} else {
TaintEffect::None
};
Ok(ToolOutput::Mcp {
result,
taint_effect,
})
}
ResponseInputItem::FunctionCallOutput { output, .. } => {
let success = output.success;
let body = output.body;
Ok(ToolOutput::Function { body, success })
Ok(ToolOutput::Function {
body,
success,
taint_effect: TaintEffect::None,
})
}
_ => Err(FunctionCallError::RespondToModel(
"mcp handler received unexpected response variant".to_string(),

View File

@@ -6,6 +6,9 @@ use std::time::Instant;
use async_trait::async_trait;
use codex_protocol::mcp::CallToolResult;
use codex_taint::TaintEffect;
use codex_taint::TaintLabel;
use codex_taint::TaintSource;
use rmcp::model::ListResourceTemplatesResult;
use rmcp::model::ListResourcesResult;
use rmcp::model::PaginatedRequestParams;
@@ -298,7 +301,7 @@ async fn handle_list_resources(
match payload_result {
Ok(payload) => match serialize_function_output(payload) {
Ok(output) => {
let ToolOutput::Function { body, success } = &output else {
let ToolOutput::Function { body, success, .. } = &output else {
unreachable!("MCP resource handler should return function output");
};
let content = body.to_text().unwrap_or_default();
@@ -406,7 +409,7 @@ async fn handle_list_resource_templates(
match payload_result {
Ok(payload) => match serialize_function_output(payload) {
Ok(output) => {
let ToolOutput::Function { body, success } = &output else {
let ToolOutput::Function { body, success, .. } = &output else {
unreachable!("MCP resource handler should return function output");
};
let content = body.to_text().unwrap_or_default();
@@ -499,7 +502,7 @@ async fn handle_read_resource(
match payload_result {
Ok(payload) => match serialize_function_output(payload) {
Ok(output) => {
let ToolOutput::Function { body, success } = &output else {
let ToolOutput::Function { body, success, .. } = &output else {
unreachable!("MCP resource handler should return function output");
};
let content = body.to_text().unwrap_or_default();
@@ -627,6 +630,10 @@ where
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text(content),
success: Some(true),
taint_effect: TaintEffect::Mark {
label: TaintLabel::ExternalContent,
source: TaintSource::McpResource,
},
})
}

View File

@@ -34,6 +34,10 @@ use codex_protocol::protocol::CollabWaitingEndEvent;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use codex_protocol::user_input::UserInput;
use codex_taint::TaintEffect;
use codex_taint::TaintLabel;
use codex_taint::TaintSink;
use codex_taint::TaintSource;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
@@ -221,6 +225,7 @@ mod spawn {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text(content),
success: Some(true),
taint_effect: TaintEffect::None,
})
}
}
@@ -253,6 +258,9 @@ mod send_input {
let receiver_thread_id = agent_id(&args.id)?;
let input_items = parse_collab_input(args.message, args.items)?;
let prompt = input_preview(&input_items);
session
.ensure_taint_sink_allowed(&turn.sub_id, TaintSink::agent_forward())
.await?;
let (receiver_agent_nickname, receiver_agent_role) = session
.services
.agent_control
@@ -314,6 +322,7 @@ mod send_input {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text(content),
success: Some(true),
taint_effect: TaintEffect::None,
})
}
}
@@ -428,6 +437,7 @@ mod resume_agent {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text(content),
success: Some(true),
taint_effect: TaintEffect::None,
})
}
@@ -641,6 +651,10 @@ pub(crate) mod wait {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text(content),
success: None,
taint_effect: TaintEffect::Mark {
label: TaintLabel::AgentContent,
source: TaintSource::AgentResult,
},
})
}
@@ -761,6 +775,10 @@ pub mod close_agent {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text(content),
success: Some(true),
taint_effect: TaintEffect::Mark {
label: TaintLabel::AgentContent,
source: TaintSource::AgentResult,
},
})
}
}

View File

@@ -91,6 +91,7 @@ impl ToolHandler for PlanHandler {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text(content),
success: Some(true),
taint_effect: codex_taint::TaintEffect::None,
})
}
}

View File

@@ -1,4 +1,7 @@
use codex_protocol::models::FunctionCallOutputBody;
use codex_taint::TaintEffect;
use codex_taint::TaintLabel;
use codex_taint::TaintSource;
use std::collections::VecDeque;
use std::path::PathBuf;
@@ -149,6 +152,10 @@ impl ToolHandler for ReadFileHandler {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text(collected.join("\n")),
success: Some(true),
taint_effect: TaintEffect::Mark {
label: TaintLabel::WorkspaceContent,
source: TaintSource::ReadFile,
},
})
}
}

View File

@@ -11,6 +11,7 @@ use crate::tools::registry::ToolKind;
use codex_protocol::config_types::ModeKind;
use codex_protocol::config_types::TUI_VISIBLE_COLLABORATION_MODES;
use codex_protocol::request_user_input::RequestUserInputArgs;
use codex_taint::TaintEffect;
fn format_allowed_modes() -> String {
let mode_names: Vec<&str> = TUI_VISIBLE_COLLABORATION_MODES
@@ -107,6 +108,7 @@ impl ToolHandler for RequestUserInputHandler {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text(content),
success: Some(true),
taint_effect: TaintEffect::Reset,
})
}
}

View File

@@ -144,6 +144,7 @@ impl ToolHandler for SearchToolBm25Handler {
return Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text(content),
success: Some(true),
taint_effect: codex_taint::TaintEffect::None,
});
}
@@ -187,6 +188,7 @@ impl ToolHandler for SearchToolBm25Handler {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text(content),
success: Some(true),
taint_effect: codex_taint::TaintEffect::None,
})
}
}

View File

@@ -3,6 +3,10 @@ use codex_protocol::ThreadId;
use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::ShellCommandToolCallParams;
use codex_protocol::models::ShellToolCallParams;
use codex_taint::TaintEffect;
use codex_taint::TaintLabel;
use codex_taint::TaintSink;
use codex_taint::TaintSource;
use std::sync::Arc;
use crate::codex::TurnContext;
@@ -387,6 +391,10 @@ impl ShellHandler {
return Ok(output);
}
session
.ensure_taint_sink_allowed(&turn.sub_id, TaintSink::shell_exec())
.await?;
let source = ExecCommandSource::Agent;
let emitter = ToolEmitter::shell(
exec_params.command.clone(),
@@ -452,6 +460,10 @@ impl ShellHandler {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text(content),
success: Some(true),
taint_effect: TaintEffect::Mark {
label: TaintLabel::WorkspaceContent,
source: TaintSource::ShellOutput,
},
})
}
}

View File

@@ -94,6 +94,7 @@ impl ToolHandler for TestSyncHandler {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text("ok".to_string()),
success: Some(true),
taint_effect: codex_taint::TaintEffect::None,
})
}
}

View File

@@ -23,6 +23,10 @@ use crate::unified_exec::WriteStdinRequest;
use async_trait::async_trait;
use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::PermissionProfile;
use codex_taint::TaintEffect;
use codex_taint::TaintLabel;
use codex_taint::TaintSink;
use codex_taint::TaintSource;
use serde::Deserialize;
use std::path::PathBuf;
use std::sync::Arc;
@@ -215,6 +219,10 @@ impl ToolHandler for UnifiedExecHandler {
return Ok(output);
}
session
.ensure_taint_sink_allowed(&turn.sub_id, TaintSink::shell_exec())
.await?;
manager
.exec_command(
ExecCommandRequest {
@@ -274,6 +282,10 @@ impl ToolHandler for UnifiedExecHandler {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::Text(content),
success: Some(true),
taint_effect: TaintEffect::Mark {
label: TaintLabel::WorkspaceContent,
source: TaintSource::ShellOutput,
},
})
}
}

View File

@@ -2,6 +2,9 @@ use async_trait::async_trait;
use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::FunctionCallOutputContentItem;
use codex_protocol::openai_models::InputModality;
use codex_taint::TaintEffect;
use codex_taint::TaintLabel;
use codex_taint::TaintSource;
use serde::Deserialize;
use tokio::fs;
@@ -110,6 +113,10 @@ impl ToolHandler for ViewImageHandler {
Ok(ToolOutput::Function {
body: FunctionCallOutputBody::ContentItems(content),
success: Some(true),
taint_effect: TaintEffect::Mark {
label: TaintLabel::WorkspaceContent,
source: TaintSource::ViewImage,
},
})
}
}

View File

@@ -216,6 +216,11 @@ impl ToolRegistry {
let output = guard.take().ok_or_else(|| {
FunctionCallError::Fatal("tool produced no output".to_string())
})?;
let taint_effect = output.taint_effect().clone();
invocation
.session
.apply_taint_effect(&invocation.turn.sub_id, taint_effect)
.await;
Ok(output.into_response(&call_id_owned, &payload_for_response))
}
Err(err) => Err(err),

12
codex-rs/taint/Cargo.toml Normal file
View File

@@ -0,0 +1,12 @@
[package]
edition.workspace = true
license.workspace = true
name = "codex-taint"
version.workspace = true
[lib]
name = "codex_taint"
path = "src/lib.rs"
[lints]
workspace = true

299
codex-rs/taint/src/lib.rs Normal file
View File

@@ -0,0 +1,299 @@
use std::collections::BTreeSet;
use std::fmt;
const MAX_RECENT_SOURCES: usize = 8;
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum TaintLabel {
WorkspaceContent,
ExternalContent,
AgentContent,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaintSource {
ReadFile,
GrepFiles,
ListDir,
ViewImage,
ShellOutput,
UserShellOutput,
McpTool,
McpResource,
DynamicTool,
WebSearch,
AgentResult,
}
impl TaintSource {
fn as_str(self) -> &'static str {
match self {
Self::ReadFile => "read_file",
Self::GrepFiles => "grep_files",
Self::ListDir => "list_dir",
Self::ViewImage => "view_image",
Self::ShellOutput => "shell_output",
Self::UserShellOutput => "user_shell_output",
Self::McpTool => "mcp_tool",
Self::McpResource => "mcp_resource",
Self::DynamicTool => "dynamic_tool",
Self::WebSearch => "web_search",
Self::AgentResult => "agent_result",
}
}
}
impl fmt::Display for TaintSource {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.as_str())
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaintSink {
ShellExec,
ExternalDispatch,
AgentForward,
}
impl TaintSink {
pub const fn shell_exec() -> Self {
Self::ShellExec
}
pub const fn external_dispatch() -> Self {
Self::ExternalDispatch
}
pub const fn agent_forward() -> Self {
Self::AgentForward
}
fn as_str(self) -> &'static str {
match self {
Self::ShellExec => "shell execution",
Self::ExternalDispatch => "external tool dispatch",
Self::AgentForward => "agent forwarding",
}
}
fn blocked_labels(self) -> &'static [TaintLabel] {
match self {
Self::ShellExec => &[
TaintLabel::WorkspaceContent,
TaintLabel::ExternalContent,
TaintLabel::AgentContent,
],
Self::ExternalDispatch => &[TaintLabel::ExternalContent, TaintLabel::AgentContent],
Self::AgentForward => &[TaintLabel::ExternalContent, TaintLabel::AgentContent],
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaintEffect {
None,
Mark {
label: TaintLabel,
source: TaintSource,
},
Reset,
}
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TaintState {
labels: BTreeSet<TaintLabel>,
recent_sources: Vec<TaintSource>,
}
impl TaintState {
pub fn apply(&mut self, effect: TaintEffect) {
match effect {
TaintEffect::None => {}
TaintEffect::Reset => self.reset(),
TaintEffect::Mark { label, source } => {
self.labels.insert(label);
if self.recent_sources.last().copied() != Some(source) {
self.recent_sources.push(source);
if self.recent_sources.len() > MAX_RECENT_SOURCES {
let overflow = self.recent_sources.len() - MAX_RECENT_SOURCES;
self.recent_sources.drain(0..overflow);
}
}
}
}
}
pub fn reset(&mut self) {
self.labels.clear();
self.recent_sources.clear();
}
pub fn labels(&self) -> &BTreeSet<TaintLabel> {
&self.labels
}
pub fn recent_sources(&self) -> &[TaintSource] {
&self.recent_sources
}
pub fn check_sink(&self, sink: TaintSink) -> Result<(), TaintViolation> {
let blocked_labels = self
.labels
.iter()
.copied()
.filter(|label| sink.blocked_labels().contains(label))
.collect::<Vec<_>>();
if blocked_labels.is_empty() {
Ok(())
} else {
Err(TaintViolation {
sink,
labels: blocked_labels,
recent_sources: self.recent_sources.clone(),
})
}
}
pub fn is_clean(&self) -> bool {
self.labels.is_empty()
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaintViolation {
pub sink: TaintSink,
pub labels: Vec<TaintLabel>,
pub recent_sources: Vec<TaintSource>,
}
impl fmt::Display for TaintViolation {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let sources = if self.recent_sources.is_empty() {
"an unknown source".to_string()
} else {
self.recent_sources
.iter()
.map(ToString::to_string)
.collect::<Vec<_>>()
.join(", ")
};
write!(
f,
"Refusing to use {} because this turn includes recent untrusted content from {}. Ask the user to confirm the next action in a new message if they want to proceed.",
self.sink.as_str(),
sources
)
}
}
impl std::error::Error for TaintViolation {}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn mark_adds_labels_and_sources_deterministically() {
let mut state = TaintState::default();
state.apply(TaintEffect::Mark {
label: TaintLabel::ExternalContent,
source: TaintSource::McpTool,
});
state.apply(TaintEffect::Mark {
label: TaintLabel::WorkspaceContent,
source: TaintSource::ReadFile,
});
assert_eq!(
state.labels().iter().copied().collect::<Vec<_>>(),
vec![TaintLabel::WorkspaceContent, TaintLabel::ExternalContent]
);
assert_eq!(
state.recent_sources(),
&[TaintSource::McpTool, TaintSource::ReadFile]
);
}
#[test]
fn repeated_mark_does_not_repeat_same_source_back_to_back() {
let mut state = TaintState::default();
state.apply(TaintEffect::Mark {
label: TaintLabel::WorkspaceContent,
source: TaintSource::ReadFile,
});
state.apply(TaintEffect::Mark {
label: TaintLabel::WorkspaceContent,
source: TaintSource::ReadFile,
});
assert_eq!(state.recent_sources(), &[TaintSource::ReadFile]);
}
#[test]
fn reset_clears_labels_and_sources() {
let mut state = TaintState::default();
state.apply(TaintEffect::Mark {
label: TaintLabel::AgentContent,
source: TaintSource::AgentResult,
});
state.apply(TaintEffect::Reset);
assert!(state.is_clean());
assert!(state.recent_sources().is_empty());
}
#[test]
fn shell_exec_blocks_any_active_label() {
let mut state = TaintState::default();
state.apply(TaintEffect::Mark {
label: TaintLabel::WorkspaceContent,
source: TaintSource::ReadFile,
});
let violation = state
.check_sink(TaintSink::shell_exec())
.expect_err("shell should be blocked");
assert_eq!(violation.labels, vec![TaintLabel::WorkspaceContent]);
assert_eq!(violation.recent_sources, vec![TaintSource::ReadFile]);
}
#[test]
fn external_dispatch_blocks_external_and_agent_content_only() {
let mut workspace_only = TaintState::default();
workspace_only.apply(TaintEffect::Mark {
label: TaintLabel::WorkspaceContent,
source: TaintSource::ReadFile,
});
assert!(
workspace_only
.check_sink(TaintSink::external_dispatch())
.is_ok()
);
let mut external = TaintState::default();
external.apply(TaintEffect::Mark {
label: TaintLabel::ExternalContent,
source: TaintSource::McpTool,
});
assert!(external.check_sink(TaintSink::external_dispatch()).is_err());
}
#[test]
fn violation_messages_include_sink_and_sources() {
let mut state = TaintState::default();
state.apply(TaintEffect::Mark {
label: TaintLabel::ExternalContent,
source: TaintSource::WebSearch,
});
let message = state
.check_sink(TaintSink::agent_forward())
.expect_err("agent forwarding should be blocked")
.to_string();
assert!(message.contains("agent forwarding"));
assert!(message.contains("web_search"));
}
}