Compare commits

..

4 Commits

Author SHA1 Message Date
starr-openai
fd58e4621f Add exec-server crate and tests
Co-authored-by: Codex <noreply@openai.com>
2026-03-16 17:30:51 -07:00
iceweasel-oai
d0a693e541 windows-sandbox: add runner IPC foundation for future unified_exec (#14139)
# Summary

This PR introduces the Windows sandbox runner IPC foundation that later
unified_exec work will build on.

The key point is that this is intentionally infrastructure-only. The new
IPC transport, runner plumbing, and ConPTY helpers are added here, but
the active elevated Windows sandbox path still uses the existing
request-file bootstrap. In other words, this change prepares the
transport and module layout we need for unified_exec without switching
production behavior over yet.

Part of this PR is also a source-layout cleanup: some Windows sandbox
files are moved into more explicit `elevated/`, `conpty/`, and shared
locations so it is clearer which code is for the elevated sandbox flow,
which code is legacy/direct-spawn behavior, and which helpers are shared
between them. That reorganization is intentional in this first PR so
later behavioral changes do not also have to carry a large amount of
file-move churn.

# Why This Is Needed For unified_exec

Windows elevated sandboxed unified_exec needs a long-lived,
bidirectional control channel between the CLI and a helper process
running under the sandbox user. That channel has to support:

- starting a process and reporting structured spawn success/failure
- streaming stdout/stderr back incrementally
- forwarding stdin over time
- terminating or polling a long-lived process
- supporting both pipe-backed and PTY-backed sessions

The existing elevated one-shot path is built around a request-file
bootstrap and does not provide those primitives cleanly. Before we can
turn on Windows sandbox unified_exec, we need the underlying runner
protocol and transport layer that can carry those lifecycle events and
streams.

# Why Windows Needs More Machinery Than Linux Or macOS

Linux and macOS can generally build unified_exec on top of the existing
sandbox/process model: the parent can spawn the child directly, retain
normal ownership of stdio or PTY handles, and manage the lifetime of the
sandboxed process without introducing a second control process.

Windows elevated sandboxing is different. To run inside the sandbox
boundary, we cross into a different user/security context and then need
to manage a long-lived process from outside that boundary. That means we
need an explicit helper process plus an IPC transport to carry spawn,
stdin, output, and exit events back and forth. The extra code here is
mostly that missing Windows sandbox infrastructure, not a conceptual
difference in unified_exec itself.

# What This PR Adds

- the framed IPC message types and transport helpers for parent <->
runner communication
- the renamed Windows command runner with both the existing request-file
bootstrap and the dormant IPC bootstrap
- named-pipe helpers for the elevated runner path
- ConPTY helpers and process-thread attribute plumbing needed for
PTY-backed sessions
- shared sandbox/process helpers that later PRs will reuse when
switching live execution paths over
- early file/module moves so later PRs can focus on behavior rather than
layout churn

# What This PR Does Not Yet Do

- it does not switch the active elevated one-shot path over to IPC yet
- it does not enable Windows sandbox unified_exec yet
- it does not remove the existing request-file bootstrap yet

So while this code compiles and the new path has basic validation, it is
not yet the exercised production path. That is intentional for this
first PR: the goal here is to land the transport and runner foundation
cleanly before later PRs start routing real command execution through
it.

# Follow-Ups

Planned follow-up PRs will:

1. switch elevated one-shot Windows sandbox execution to the new runner
IPC path
2. layer Windows sandbox unified_exec sessions on top of the same
transport
3. remove the legacy request-file path once the IPC-based path is live

# Validation

- `cargo build -p codex-windows-sandbox`
2026-03-16 19:45:06 +00:00
Andi Liu
4c9dbc1f88 memories: exclude AGENTS and skills from stage1 input (#14268)
###### Why/Context/Summary
- Exclude injected AGENTS.md instructions and standalone skill payloads
from memory stage 1 inputs so memory generation focuses on conversation
content instead of prompt scaffolding.
- Strip only the AGENTS fragment from mixed contextual user messages
during stage-1 serialization, which preserves environment context in the
same message.
- Keep subagent notifications in the memory input, and add focused unit
coverage for the fragment classifier, rollout policy, and stage-1
serialization path.

###### Test plan
- `just fmt`
- `cargo test -p codex-core --lib contextual_user_message`
- `cargo test -p codex-core --lib rollout::policy`
- `cargo test -p codex-core --lib memories::phase1`
2026-03-16 19:30:38 +00:00
Anton Panasenko
663dd3f935 fix(core): fix sanitize name to use '_' everywhere (#14833) 2026-03-16 12:22:10 -07:00
41 changed files with 3257 additions and 1730 deletions

22
codex-rs/Cargo.lock generated
View File

@@ -1989,6 +1989,23 @@ dependencies = [
"wiremock",
]
[[package]]
name = "codex-exec-server"
version = "0.0.0"
dependencies = [
"anyhow",
"base64 0.22.1",
"codex-app-server-protocol",
"codex-utils-cargo-bin",
"codex-utils-pty",
"pretty_assertions",
"serde",
"serde_json",
"thiserror 2.0.18",
"tokio",
"tracing",
]
[[package]]
name = "codex-execpolicy"
version = "0.0.0"
@@ -2102,9 +2119,6 @@ name = "codex-keyring-store"
version = "0.0.0"
dependencies = [
"keyring",
"pretty_assertions",
"serde",
"serde_json",
"tracing",
]
@@ -2849,6 +2863,7 @@ dependencies = [
"chrono",
"codex-protocol",
"codex-utils-absolute-path",
"codex-utils-pty",
"codex-utils-string",
"dirs-next",
"dunce",
@@ -2857,6 +2872,7 @@ dependencies = [
"serde",
"serde_json",
"tempfile",
"tokio",
"windows 0.58.0",
"windows-sys 0.52.0",
"winres",

View File

@@ -25,6 +25,7 @@ members = [
"hooks",
"secrets",
"exec",
"exec-server",
"execpolicy",
"execpolicy-legacy",
"keyring-store",

View File

@@ -23,9 +23,6 @@ use crate::token_data::TokenData;
use codex_app_server_protocol::AuthMode;
use codex_keyring_store::DefaultKeyringStore;
use codex_keyring_store::KeyringStore;
use codex_keyring_store::delete_json_from_keyring;
use codex_keyring_store::load_json_from_keyring;
use codex_keyring_store::save_json_to_keyring;
use once_cell::sync::Lazy;
/// Determine where Codex should store CLI auth credentials.
@@ -165,39 +162,47 @@ impl KeyringAuthStorage {
}
}
fn load_auth_from_keyring(&self, base_key: &str) -> std::io::Result<Option<AuthDotJson>> {
let Some(value) =
load_json_from_keyring(self.keyring_store.as_ref(), KEYRING_SERVICE, base_key)
.map_err(|err| {
std::io::Error::other(format!("failed to load CLI auth from keyring: {err}"))
})?
else {
return Ok(None);
};
serde_json::from_value(value).map(Some).map_err(|err| {
std::io::Error::other(format!(
"failed to deserialize CLI auth from keyring: {err}"
))
})
fn load_from_keyring(&self, key: &str) -> std::io::Result<Option<AuthDotJson>> {
match self.keyring_store.load(KEYRING_SERVICE, key) {
Ok(Some(serialized)) => serde_json::from_str(&serialized).map(Some).map_err(|err| {
std::io::Error::other(format!(
"failed to deserialize CLI auth from keyring: {err}"
))
}),
Ok(None) => Ok(None),
Err(error) => Err(std::io::Error::other(format!(
"failed to load CLI auth from keyring: {}",
error.message()
))),
}
}
fn save_to_keyring(&self, key: &str, value: &str) -> std::io::Result<()> {
match self.keyring_store.save(KEYRING_SERVICE, key, value) {
Ok(()) => Ok(()),
Err(error) => {
let message = format!(
"failed to write OAuth tokens to keyring: {}",
error.message()
);
warn!("{message}");
Err(std::io::Error::other(message))
}
}
}
}
impl AuthStorageBackend for KeyringAuthStorage {
fn load(&self) -> std::io::Result<Option<AuthDotJson>> {
let key = compute_store_key(&self.codex_home)?;
self.load_auth_from_keyring(&key)
self.load_from_keyring(&key)
}
fn save(&self, auth: &AuthDotJson) -> std::io::Result<()> {
let base_key = compute_store_key(&self.codex_home)?;
let value = serde_json::to_value(auth).map_err(std::io::Error::other)?;
save_json_to_keyring(
self.keyring_store.as_ref(),
KEYRING_SERVICE,
&base_key,
&value,
)
.map_err(|err| std::io::Error::other(format!("failed to write auth to keyring: {err}")))?;
let key = compute_store_key(&self.codex_home)?;
// Simpler error mapping per style: prefer method reference over closure
let serialized = serde_json::to_string(auth).map_err(std::io::Error::other)?;
self.save_to_keyring(&key, &serialized)?;
if let Err(err) = delete_file_if_exists(&self.codex_home) {
warn!("failed to remove CLI auth fallback file: {err}");
}
@@ -205,12 +210,13 @@ impl AuthStorageBackend for KeyringAuthStorage {
}
fn delete(&self) -> std::io::Result<bool> {
let base_key = compute_store_key(&self.codex_home)?;
let keyring_removed =
delete_json_from_keyring(self.keyring_store.as_ref(), KEYRING_SERVICE, &base_key)
.map_err(|err| {
std::io::Error::other(format!("failed to delete auth from keyring: {err}"))
})?;
let key = compute_store_key(&self.codex_home)?;
let keyring_removed = self
.keyring_store
.delete(KEYRING_SERVICE, &key)
.map_err(|err| {
std::io::Error::other(format!("failed to delete auth from keyring: {err}"))
})?;
let file_removed = delete_file_if_exists(&self.codex_home)?;
Ok(keyring_removed || file_removed)
}

View File

@@ -6,88 +6,9 @@ use pretty_assertions::assert_eq;
use serde_json::json;
use tempfile::tempdir;
use codex_keyring_store::CredentialStoreError;
use codex_keyring_store::tests::MockKeyringStore;
use keyring::Error as KeyringError;
#[derive(Clone, Debug)]
struct SaveSecretErrorKeyringStore {
inner: MockKeyringStore,
}
impl KeyringStore for SaveSecretErrorKeyringStore {
fn load(&self, service: &str, account: &str) -> Result<Option<String>, CredentialStoreError> {
self.inner.load(service, account)
}
fn load_secret(
&self,
service: &str,
account: &str,
) -> Result<Option<Vec<u8>>, CredentialStoreError> {
self.inner.load_secret(service, account)
}
fn save(&self, service: &str, account: &str, value: &str) -> Result<(), CredentialStoreError> {
self.inner.save(service, account, value)
}
fn save_secret(
&self,
_service: &str,
_account: &str,
_value: &[u8],
) -> Result<(), CredentialStoreError> {
Err(CredentialStoreError::new(KeyringError::Invalid(
"error".into(),
"save".into(),
)))
}
fn delete(&self, service: &str, account: &str) -> Result<bool, CredentialStoreError> {
self.inner.delete(service, account)
}
}
#[derive(Clone, Debug)]
struct LoadSecretErrorKeyringStore {
inner: MockKeyringStore,
}
impl KeyringStore for LoadSecretErrorKeyringStore {
fn load(&self, service: &str, account: &str) -> Result<Option<String>, CredentialStoreError> {
self.inner.load(service, account)
}
fn load_secret(
&self,
_service: &str,
_account: &str,
) -> Result<Option<Vec<u8>>, CredentialStoreError> {
Err(CredentialStoreError::new(KeyringError::Invalid(
"error".into(),
"load".into(),
)))
}
fn save(&self, service: &str, account: &str, value: &str) -> Result<(), CredentialStoreError> {
self.inner.save(service, account, value)
}
fn save_secret(
&self,
service: &str,
account: &str,
value: &[u8],
) -> Result<(), CredentialStoreError> {
self.inner.save_secret(service, account, value)
}
fn delete(&self, service: &str, account: &str) -> Result<bool, CredentialStoreError> {
self.inner.delete(service, account)
}
}
#[tokio::test]
async fn file_storage_load_returns_auth_dot_json() -> anyhow::Result<()> {
let codex_home = tempdir()?;
@@ -176,16 +97,19 @@ fn ephemeral_storage_save_load_delete_is_in_memory_only() -> anyhow::Result<()>
Ok(())
}
fn seed_keyring_and_fallback_auth_file_for_delete(
storage: &KeyringAuthStorage,
fn seed_keyring_and_fallback_auth_file_for_delete<F>(
mock_keyring: &MockKeyringStore,
codex_home: &Path,
auth: &AuthDotJson,
) -> anyhow::Result<(String, PathBuf)> {
storage.save(auth)?;
let base_key = compute_store_key(codex_home)?;
compute_key: F,
) -> anyhow::Result<(String, PathBuf)>
where
F: FnOnce() -> std::io::Result<String>,
{
let key = compute_key()?;
mock_keyring.save(KEYRING_SERVICE, &key, "{}")?;
let auth_file = get_auth_file(codex_home);
std::fs::write(&auth_file, "stale")?;
Ok((base_key, auth_file))
Ok((key, auth_file))
}
fn seed_keyring_with_auth<F>(
@@ -204,26 +128,15 @@ where
fn assert_keyring_saved_auth_and_removed_fallback(
mock_keyring: &MockKeyringStore,
base_key: &str,
key: &str,
codex_home: &Path,
expected: &AuthDotJson,
) {
let expected_json = serde_json::to_value(expected).expect("auth should serialize");
let loaded = load_json_from_keyring(mock_keyring, KEYRING_SERVICE, base_key)
.expect("auth should load from keyring")
.expect("auth should exist");
assert_eq!(loaded, expected_json);
#[cfg(windows)]
assert!(
mock_keyring.saved_secret(base_key).is_none(),
"windows should store auth using split keyring entries"
);
#[cfg(not(windows))]
assert_eq!(
mock_keyring.saved_secret(base_key),
Some(serde_json::to_vec(&expected_json).expect("auth should serialize")),
"non-windows should store auth as one JSON secret"
);
let saved_value = mock_keyring
.saved_value(key)
.expect("keyring entry should exist");
let expected_serialized = serde_json::to_string(expected).expect("serialize expected auth");
assert_eq!(saved_value, expected_serialized);
let auth_file = get_auth_file(codex_home);
assert!(
!auth_file.exists(),
@@ -272,7 +185,7 @@ fn auth_with_prefix(prefix: &str) -> AuthDotJson {
}
#[test]
fn keyring_auth_storage_load_supports_legacy_single_entry() -> anyhow::Result<()> {
fn keyring_auth_storage_load_returns_deserialized_auth() -> anyhow::Result<()> {
let codex_home = tempdir()?;
let mock_keyring = MockKeyringStore::default();
let storage = KeyringAuthStorage::new(
@@ -296,39 +209,6 @@ fn keyring_auth_storage_load_supports_legacy_single_entry() -> anyhow::Result<()
Ok(())
}
#[test]
fn keyring_auth_storage_load_returns_deserialized_keyring_auth() -> anyhow::Result<()> {
let codex_home = tempdir()?;
let mock_keyring = MockKeyringStore::default();
let storage = KeyringAuthStorage::new(codex_home.path().to_path_buf(), Arc::new(mock_keyring));
let expected = auth_with_prefix("keyring");
storage.save(&expected)?;
let loaded = storage.load()?;
assert_eq!(Some(expected), loaded);
Ok(())
}
#[test]
fn keyring_auth_storage_load_supports_split_json_compatibility() -> anyhow::Result<()> {
let codex_home = tempdir()?;
let mock_keyring = MockKeyringStore::default();
let storage = KeyringAuthStorage::new(
codex_home.path().to_path_buf(),
Arc::new(mock_keyring.clone()),
);
let expected = auth_with_prefix("split-compat");
let key = compute_store_key(codex_home.path())?;
let value = serde_json::to_value(&expected)?;
codex_keyring_store::save_split_json_to_keyring(&mock_keyring, KEYRING_SERVICE, &key, &value)?;
let loaded = storage.load()?;
assert_eq!(Some(expected), loaded);
Ok(())
}
#[test]
fn keyring_auth_storage_compute_store_key_for_home_directory() -> anyhow::Result<()> {
let codex_home = PathBuf::from("~/.codex");
@@ -376,16 +256,17 @@ fn keyring_auth_storage_delete_removes_keyring_and_file() -> anyhow::Result<()>
codex_home.path().to_path_buf(),
Arc::new(mock_keyring.clone()),
);
let auth = auth_with_prefix("delete");
let (base_key, auth_file) =
seed_keyring_and_fallback_auth_file_for_delete(&storage, codex_home.path(), &auth)?;
let (key, auth_file) =
seed_keyring_and_fallback_auth_file_for_delete(&mock_keyring, codex_home.path(), || {
compute_store_key(codex_home.path())
})?;
let removed = storage.delete()?;
assert!(removed, "delete should report removal");
assert!(
load_json_from_keyring(&mock_keyring, KEYRING_SERVICE, &base_key)?.is_none(),
"keyring auth should be removed"
!mock_keyring.contains(&key),
"keyring entry should be removed"
);
assert!(
!auth_file.exists(),
@@ -435,10 +316,12 @@ fn auto_auth_storage_load_uses_file_when_keyring_empty() -> anyhow::Result<()> {
fn auto_auth_storage_load_falls_back_when_keyring_errors() -> anyhow::Result<()> {
let codex_home = tempdir()?;
let mock_keyring = MockKeyringStore::default();
let failing_keyring = LoadSecretErrorKeyringStore {
inner: mock_keyring,
};
let storage = AutoAuthStorage::new(codex_home.path().to_path_buf(), Arc::new(failing_keyring));
let storage = AutoAuthStorage::new(
codex_home.path().to_path_buf(),
Arc::new(mock_keyring.clone()),
);
let key = compute_store_key(codex_home.path())?;
mock_keyring.set_error(&key, KeyringError::Invalid("error".into(), "load".into()));
let expected = auth_with_prefix("fallback");
storage.file_storage.save(&expected)?;
@@ -477,11 +360,12 @@ fn auto_auth_storage_save_prefers_keyring() -> anyhow::Result<()> {
fn auto_auth_storage_save_falls_back_when_keyring_errors() -> anyhow::Result<()> {
let codex_home = tempdir()?;
let mock_keyring = MockKeyringStore::default();
let failing_keyring = SaveSecretErrorKeyringStore {
inner: mock_keyring.clone(),
};
let storage = AutoAuthStorage::new(codex_home.path().to_path_buf(), Arc::new(failing_keyring));
let storage = AutoAuthStorage::new(
codex_home.path().to_path_buf(),
Arc::new(mock_keyring.clone()),
);
let key = compute_store_key(codex_home.path())?;
mock_keyring.set_error(&key, KeyringError::Invalid("error".into(), "save".into()));
let auth = auth_with_prefix("fallback");
storage.save(&auth)?;
@@ -497,8 +381,8 @@ fn auto_auth_storage_save_falls_back_when_keyring_errors() -> anyhow::Result<()>
.context("fallback auth should exist")?;
assert_eq!(saved, auth);
assert!(
load_json_from_keyring(&mock_keyring, KEYRING_SERVICE, &key)?.is_none(),
"keyring should not point to saved auth when save fails"
mock_keyring.saved_value(&key).is_none(),
"keyring should not contain value when save fails"
);
Ok(())
}
@@ -511,19 +395,17 @@ fn auto_auth_storage_delete_removes_keyring_and_file() -> anyhow::Result<()> {
codex_home.path().to_path_buf(),
Arc::new(mock_keyring.clone()),
);
let auth = auth_with_prefix("auto-delete");
let (base_key, auth_file) = seed_keyring_and_fallback_auth_file_for_delete(
storage.keyring_storage.as_ref(),
codex_home.path(),
&auth,
)?;
let (key, auth_file) =
seed_keyring_and_fallback_auth_file_for_delete(&mock_keyring, codex_home.path(), || {
compute_store_key(codex_home.path())
})?;
let removed = storage.delete()?;
assert!(removed, "delete should report removal");
assert!(
load_json_from_keyring(&mock_keyring, KEYRING_SERVICE, &base_key)?.is_none(),
"keyring auth should be removed"
!mock_keyring.contains(&key),
"keyring entry should be removed"
);
assert!(
!auth_file.exists(),

View File

@@ -469,7 +469,7 @@ pub fn connector_display_label(connector: &AppInfo) -> String {
}
pub fn connector_mention_slug(connector: &AppInfo) -> String {
sanitize_name(&connector_display_label(connector))
sanitize_slug(&connector_display_label(connector))
}
pub(crate) fn accessible_connectors_from_mcp_tools(
@@ -918,11 +918,15 @@ fn normalize_connector_value(value: Option<&str>) -> Option<String> {
}
pub fn connector_install_url(name: &str, connector_id: &str) -> String {
let slug = sanitize_name(name);
let slug = sanitize_slug(name);
format!("https://chatgpt.com/apps/{slug}/{connector_id}")
}
pub fn sanitize_name(name: &str) -> String {
sanitize_slug(name).replace("-", "_")
}
fn sanitize_slug(name: &str) -> String {
let mut normalized = String::with_capacity(name.len());
for character in name.chars() {
if character.is_ascii_alphanumeric() {

View File

@@ -103,6 +103,21 @@ pub(crate) fn is_contextual_user_fragment(content_item: &ContentItem) -> bool {
.any(|definition| definition.matches_text(text))
}
/// Returns whether a contextual user fragment should be omitted from memory
/// stage-1 inputs.
///
/// We exclude injected `AGENTS.md` instructions and skill payloads because
/// they are prompt scaffolding rather than conversation content, so they do
/// not improve the resulting memory. We keep environment context and
/// subagent notifications because they can carry useful execution context or
/// subtask outcomes that should remain visible to memory generation.
pub(crate) fn is_memory_excluded_contextual_user_fragment(content_item: &ContentItem) -> bool {
let ContentItem::InputText { text } = content_item else {
return false;
};
AGENTS_MD_FRAGMENT.matches_text(text) || SKILL_FRAGMENT.matches_text(text)
}
#[cfg(test)]
#[path = "contextual_user_message_tests.rs"]
mod tests;

View File

@@ -29,3 +29,35 @@ fn ignores_regular_user_text() {
text: "hello".to_string(),
}));
}
#[test]
fn classifies_memory_excluded_fragments() {
let cases = [
(
"# AGENTS.md instructions for /tmp\n\n<INSTRUCTIONS>\nbody\n</INSTRUCTIONS>",
true,
),
(
"<skill>\n<name>demo</name>\n<path>skills/demo/SKILL.md</path>\nbody\n</skill>",
true,
),
(
"<environment_context>\n<cwd>/tmp</cwd>\n</environment_context>",
false,
),
(
"<subagent_notification>{\"agent_id\":\"a\",\"status\":\"completed\"}</subagent_notification>",
false,
),
];
for (text, expected) in cases {
assert_eq!(
is_memory_excluded_contextual_user_fragment(&ContentItem::InputText {
text: text.to_string(),
}),
expected,
"{text}",
);
}
}

View File

@@ -1231,12 +1231,11 @@ fn normalize_codex_apps_tool_name(
return tool_name.to_string();
}
let tool_name = sanitize_name(tool_name).replace('-', "_");
let tool_name = sanitize_name(tool_name);
if let Some(connector_name) = connector_name
.map(str::trim)
.map(sanitize_name)
.map(|name| name.replace('-', "_"))
.filter(|name| !name.is_empty())
&& let Some(stripped) = tool_name.strip_prefix(&connector_name)
&& !stripped.is_empty()
@@ -1247,7 +1246,6 @@ fn normalize_codex_apps_tool_name(
if let Some(connector_id) = connector_id
.map(str::trim)
.map(sanitize_name)
.map(|name| name.replace('-', "_"))
.filter(|name| !name.is_empty())
&& let Some(stripped) = tool_name.strip_prefix(&connector_id)
&& !stripped.is_empty()

View File

@@ -4,6 +4,7 @@ use crate::codex::Session;
use crate::codex::TurnContext;
use crate::config::Config;
use crate::config::types::MemoriesConfig;
use crate::contextual_user_message::is_memory_excluded_contextual_user_fragment;
use crate::error::CodexErr;
use crate::memories::metrics;
use crate::memories::phase_one;
@@ -463,16 +464,14 @@ mod job {
}
/// Serializes filtered stage-1 memory items for prompt inclusion.
fn serialize_filtered_rollout_response_items(
pub(super) fn serialize_filtered_rollout_response_items(
items: &[RolloutItem],
) -> crate::error::Result<String> {
let filtered = items
.iter()
.filter_map(|item| {
if let RolloutItem::ResponseItem(item) = item
&& should_persist_response_item_for_memories(item)
{
Some(item.clone())
if let RolloutItem::ResponseItem(item) = item {
sanitize_response_item_for_memories(item)
} else {
None
}
@@ -482,6 +481,44 @@ mod job {
CodexErr::InvalidRequest(format!("failed to serialize rollout memory: {err}"))
})
}
fn sanitize_response_item_for_memories(item: &ResponseItem) -> Option<ResponseItem> {
let ResponseItem::Message {
id,
role,
content,
end_turn,
phase,
} = item
else {
return should_persist_response_item_for_memories(item).then(|| item.clone());
};
if role == "developer" {
return None;
}
if role != "user" {
return Some(item.clone());
}
let content = content
.iter()
.filter(|content_item| !is_memory_excluded_contextual_user_fragment(content_item))
.cloned()
.collect::<Vec<_>>();
if content.is_empty() {
return None;
}
Some(ResponseItem::Message {
id: id.clone(),
role: role.clone(),
content,
end_turn: *end_turn,
phase: phase.clone(),
})
}
}
fn aggregate_stats(outcomes: Vec<JobResult>) -> Stats {

View File

@@ -1,9 +1,77 @@
use super::JobOutcome;
use super::JobResult;
use super::aggregate_stats;
use super::job::serialize_filtered_rollout_response_items;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::TokenUsage;
use pretty_assertions::assert_eq;
#[test]
fn serializes_memory_rollout_with_agents_removed_but_environment_kept() {
let mixed_contextual_message = ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![
ContentItem::InputText {
text: "# AGENTS.md instructions for /tmp\n\n<INSTRUCTIONS>\nbody\n</INSTRUCTIONS>"
.to_string(),
},
ContentItem::InputText {
text: "<environment_context>\n<cwd>/tmp</cwd>\n</environment_context>".to_string(),
},
],
end_turn: None,
phase: None,
};
let skill_message = ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "<skill>\n<name>demo</name>\n<path>skills/demo/SKILL.md</path>\nbody\n</skill>"
.to_string(),
}],
end_turn: None,
phase: None,
};
let subagent_message = ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "<subagent_notification>{\"agent_id\":\"a\",\"status\":\"completed\"}</subagent_notification>"
.to_string(),
}],
end_turn: None,
phase: None,
};
let serialized = serialize_filtered_rollout_response_items(&[
RolloutItem::ResponseItem(mixed_contextual_message),
RolloutItem::ResponseItem(skill_message),
RolloutItem::ResponseItem(subagent_message.clone()),
])
.expect("serialize");
let parsed: Vec<ResponseItem> = serde_json::from_str(&serialized).expect("parse");
assert_eq!(
parsed,
vec![
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "<environment_context>\n<cwd>/tmp</cwd>\n</environment_context>"
.to_string(),
}],
end_turn: None,
phase: None,
},
subagent_message,
]
);
}
#[test]
fn count_outcomes_sums_token_usage_across_all_jobs() {
let counts = aggregate_stats(vec![

View File

@@ -0,0 +1,35 @@
[package]
name = "codex-exec-server"
version.workspace = true
edition.workspace = true
license.workspace = true
[[bin]]
name = "codex-exec-server"
path = "src/bin/codex-exec-server.rs"
[lints]
workspace = true
[dependencies]
base64 = { workspace = true }
codex-app-server-protocol = { workspace = true }
codex-utils-pty = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = [
"io-std",
"io-util",
"macros",
"process",
"rt-multi-thread",
"sync",
"time",
] }
tracing = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
codex-utils-cargo-bin = { workspace = true }
pretty_assertions = { workspace = true }

View File

@@ -0,0 +1,7 @@
#[tokio::main]
async fn main() {
if let Err(err) = codex_exec_server::run_main().await {
eprintln!("{err}");
std::process::exit(1);
}
}

View File

@@ -0,0 +1,514 @@
use std::collections::HashMap;
use std::path::PathBuf;
use std::process::Stdio;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicI64;
use std::sync::atomic::Ordering;
use codex_app_server_protocol::JSONRPCError;
use codex_app_server_protocol::JSONRPCErrorError;
use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCRequest;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::RequestId;
use serde::Serialize;
use serde::de::DeserializeOwned;
use serde_json::Value;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
use tokio::process::Child;
use tokio::process::Command;
use tokio::sync::Mutex;
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tracing::debug;
use tracing::warn;
use crate::protocol::EXEC_EXITED_METHOD;
use crate::protocol::EXEC_METHOD;
use crate::protocol::EXEC_OUTPUT_DELTA_METHOD;
use crate::protocol::EXEC_TERMINATE_METHOD;
use crate::protocol::EXEC_WRITE_METHOD;
use crate::protocol::ExecExitedNotification;
use crate::protocol::ExecOutputDeltaNotification;
use crate::protocol::ExecParams;
use crate::protocol::ExecResponse;
use crate::protocol::INITIALIZE_METHOD;
use crate::protocol::INITIALIZED_METHOD;
use crate::protocol::InitializeParams;
use crate::protocol::InitializeResponse;
use crate::protocol::TerminateParams;
use crate::protocol::TerminateResponse;
use crate::protocol::WriteParams;
use crate::protocol::WriteResponse;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecServerLaunchCommand {
pub program: PathBuf,
pub args: Vec<String>,
}
pub struct ExecServerProcess {
process_id: String,
pid: Option<u32>,
output_rx: Option<broadcast::Receiver<Vec<u8>>>,
writer_tx: mpsc::Sender<Vec<u8>>,
status: Arc<RemoteProcessStatus>,
client: ExecServerClient,
}
impl std::fmt::Debug for ExecServerProcess {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ExecServerProcess")
.field("process_id", &self.process_id)
.field("pid", &self.pid)
.field("has_exited", &self.has_exited())
.field("exit_code", &self.exit_code())
.finish()
}
}
impl ExecServerProcess {
pub fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
self.writer_tx.clone()
}
pub fn output_receiver(&self) -> broadcast::Receiver<Vec<u8>> {
match self.output_rx.as_ref() {
Some(output_rx) => output_rx.resubscribe(),
None => panic!("output receiver should still be present"),
}
}
pub fn take_output_receiver(&mut self) -> broadcast::Receiver<Vec<u8>> {
match self.output_rx.take() {
Some(output_rx) => output_rx,
None => panic!("output receiver should only be taken once"),
}
}
pub fn has_exited(&self) -> bool {
self.status.has_exited()
}
pub fn exit_code(&self) -> Option<i32> {
self.status.exit_code()
}
pub fn pid(&self) -> Option<u32> {
self.pid
}
pub fn terminate(&self) {
self.status.mark_exited(None);
let client = self.client.clone();
let process_id = self.process_id.clone();
tokio::spawn(async move {
let _ = client.terminate_process(&process_id).await;
});
}
}
impl std::fmt::Debug for RemoteProcessStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RemoteProcessStatus")
.field("exited", &self.has_exited())
.field("exit_code", &self.exit_code())
.finish()
}
}
struct RemoteProcessStatus {
exited: AtomicBool,
exit_code: StdMutex<Option<i32>>,
}
impl RemoteProcessStatus {
fn new() -> Self {
Self {
exited: AtomicBool::new(false),
exit_code: StdMutex::new(None),
}
}
fn has_exited(&self) -> bool {
self.exited.load(Ordering::SeqCst)
}
fn exit_code(&self) -> Option<i32> {
self.exit_code.lock().ok().and_then(|guard| *guard)
}
fn mark_exited(&self, exit_code: Option<i32>) {
self.exited.store(true, Ordering::SeqCst);
if let Ok(mut guard) = self.exit_code.lock() {
*guard = exit_code;
}
}
}
struct RegisteredProcess {
output_tx: broadcast::Sender<Vec<u8>>,
status: Arc<RemoteProcessStatus>,
}
struct Inner {
child: StdMutex<Option<Child>>,
write_tx: mpsc::UnboundedSender<JSONRPCMessage>,
pending: Mutex<HashMap<RequestId, oneshot::Sender<Result<Value, JSONRPCErrorError>>>>,
processes: Mutex<HashMap<String, RegisteredProcess>>,
next_request_id: AtomicI64,
reader_task: JoinHandle<()>,
writer_task: JoinHandle<()>,
}
impl Drop for Inner {
fn drop(&mut self) {
self.reader_task.abort();
self.writer_task.abort();
if let Ok(mut child_guard) = self.child.lock()
&& let Some(child) = child_guard.as_mut()
{
let _ = child.start_kill();
}
}
}
#[derive(Clone)]
pub struct ExecServerClient {
inner: Arc<Inner>,
}
#[derive(Debug, thiserror::Error)]
pub enum ExecServerError {
#[error("failed to spawn exec-server: {0}")]
Spawn(#[source] std::io::Error),
#[error("exec-server transport closed")]
Closed,
#[error("failed to serialize or deserialize exec-server JSON: {0}")]
Json(#[from] serde_json::Error),
#[error("exec-server protocol error: {0}")]
Protocol(String),
#[error("exec-server rejected request ({code}): {message}")]
Server { code: i64, message: String },
}
impl ExecServerClient {
pub async fn spawn(command: ExecServerLaunchCommand) -> Result<Self, ExecServerError> {
let mut child = Command::new(&command.program);
child.args(&command.args);
child.stdin(Stdio::piped());
child.stdout(Stdio::piped());
child.stderr(Stdio::inherit());
child.kill_on_drop(true);
let mut child = child.spawn().map_err(ExecServerError::Spawn)?;
let stdin = child.stdin.take().ok_or_else(|| {
ExecServerError::Protocol("exec-server stdin was not captured".to_string())
})?;
let stdout = child.stdout.take().ok_or_else(|| {
ExecServerError::Protocol("exec-server stdout was not captured".to_string())
})?;
let (write_tx, mut write_rx) = mpsc::unbounded_channel::<JSONRPCMessage>();
let writer_task = tokio::spawn(async move {
let mut stdin = stdin;
while let Some(message) = write_rx.recv().await {
let encoded = match serde_json::to_vec(&message) {
Ok(encoded) => encoded,
Err(err) => {
warn!("failed to encode exec-server message: {err}");
break;
}
};
if stdin.write_all(&encoded).await.is_err() {
break;
}
if stdin.write_all(b"\n").await.is_err() {
break;
}
if stdin.flush().await.is_err() {
break;
}
}
});
let pending = Mutex::new(HashMap::<
RequestId,
oneshot::Sender<Result<Value, JSONRPCErrorError>>,
>::new());
let processes = Mutex::new(HashMap::<String, RegisteredProcess>::new());
let inner = Arc::new_cyclic(move |weak| {
let weak = weak.clone();
let reader_task = tokio::spawn(async move {
let mut lines = BufReader::new(stdout).lines();
loop {
let Some(inner) = weak.upgrade() else {
break;
};
let next_line = lines.next_line().await;
match next_line {
Ok(Some(line)) => {
if line.trim().is_empty() {
continue;
}
match serde_json::from_str::<JSONRPCMessage>(&line) {
Ok(message) => {
if let Err(err) = handle_server_message(&inner, message).await {
warn!("failed to handle exec-server message: {err}");
break;
}
}
Err(err) => {
warn!("failed to parse exec-server message: {err}");
break;
}
}
}
Ok(None) => break,
Err(err) => {
warn!("failed to read exec-server stdout: {err}");
break;
}
}
}
if let Some(inner) = weak.upgrade() {
handle_transport_shutdown(&inner).await;
}
});
Inner {
child: StdMutex::new(Some(child)),
write_tx,
pending,
processes,
next_request_id: AtomicI64::new(1),
reader_task,
writer_task,
}
});
let client = Self { inner };
client.initialize().await?;
Ok(client)
}
pub async fn start_process(
&self,
params: ExecParams,
) -> Result<ExecServerProcess, ExecServerError> {
let process_id = params.process_id.clone();
let status = Arc::new(RemoteProcessStatus::new());
let (output_tx, output_rx) = broadcast::channel(256);
self.inner.processes.lock().await.insert(
process_id.clone(),
RegisteredProcess {
output_tx,
status: Arc::clone(&status),
},
);
let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
let client = self.clone();
let write_process_id = process_id.clone();
tokio::spawn(async move {
while let Some(chunk) = writer_rx.recv().await {
let request = WriteParams {
process_id: write_process_id.clone(),
chunk: chunk.into(),
};
if client.write_process(request).await.is_err() {
break;
}
}
});
let response = match self.request::<_, ExecResponse>(EXEC_METHOD, &params).await {
Ok(response) => response,
Err(err) => {
self.inner.processes.lock().await.remove(&process_id);
return Err(err);
}
};
if let Some(exit_code) = response.exit_code {
status.mark_exited(Some(exit_code));
}
Ok(ExecServerProcess {
process_id,
pid: response.pid,
output_rx: Some(output_rx),
writer_tx,
status,
client: self.clone(),
})
}
async fn initialize(&self) -> Result<(), ExecServerError> {
let _: InitializeResponse = self
.request(
INITIALIZE_METHOD,
&InitializeParams {
client_name: "codex-core".to_string(),
},
)
.await?;
self.notify(INITIALIZED_METHOD, &serde_json::json!({}))
.await
}
async fn write_process(&self, params: WriteParams) -> Result<WriteResponse, ExecServerError> {
self.request(EXEC_WRITE_METHOD, &params).await
}
async fn terminate_process(
&self,
process_id: &str,
) -> Result<TerminateResponse, ExecServerError> {
self.request(
EXEC_TERMINATE_METHOD,
&TerminateParams {
process_id: process_id.to_string(),
},
)
.await
}
async fn notify<P: Serialize>(&self, method: &str, params: &P) -> Result<(), ExecServerError> {
let params = serde_json::to_value(params)?;
self.inner
.write_tx
.send(JSONRPCMessage::Notification(JSONRPCNotification {
method: method.to_string(),
params: Some(params),
}))
.map_err(|_| ExecServerError::Closed)
}
async fn request<P, R>(&self, method: &str, params: &P) -> Result<R, ExecServerError>
where
P: Serialize,
R: DeserializeOwned,
{
let request_id =
RequestId::Integer(self.inner.next_request_id.fetch_add(1, Ordering::SeqCst));
let (response_tx, response_rx) = oneshot::channel();
self.inner
.pending
.lock()
.await
.insert(request_id.clone(), response_tx);
let params = serde_json::to_value(params)?;
let message = JSONRPCMessage::Request(JSONRPCRequest {
id: request_id.clone(),
method: method.to_string(),
params: Some(params),
trace: None,
});
if self.inner.write_tx.send(message).is_err() {
self.inner.pending.lock().await.remove(&request_id);
return Err(ExecServerError::Closed);
}
let result = response_rx.await.map_err(|_| ExecServerError::Closed)?;
match result {
Ok(value) => serde_json::from_value(value).map_err(ExecServerError::from),
Err(error) => Err(ExecServerError::Server {
code: error.code,
message: error.message,
}),
}
}
}
async fn handle_server_message(
inner: &Arc<Inner>,
message: JSONRPCMessage,
) -> Result<(), ExecServerError> {
match message {
JSONRPCMessage::Response(JSONRPCResponse { id, result }) => {
if let Some(tx) = inner.pending.lock().await.remove(&id) {
let _ = tx.send(Ok(result));
}
}
JSONRPCMessage::Error(JSONRPCError { id, error }) => {
if let Some(tx) = inner.pending.lock().await.remove(&id) {
let _ = tx.send(Err(error));
}
}
JSONRPCMessage::Notification(notification) => {
handle_server_notification(inner, notification).await?;
}
JSONRPCMessage::Request(request) => {
return Err(ExecServerError::Protocol(format!(
"unexpected exec-server request from child: {}",
request.method
)));
}
}
Ok(())
}
async fn handle_server_notification(
inner: &Arc<Inner>,
notification: JSONRPCNotification,
) -> Result<(), ExecServerError> {
match notification.method.as_str() {
EXEC_OUTPUT_DELTA_METHOD => {
let params: ExecOutputDeltaNotification =
serde_json::from_value(notification.params.unwrap_or(Value::Null))?;
let chunk = params.chunk.into_inner();
let processes = inner.processes.lock().await;
if let Some(process) = processes.get(&params.process_id) {
let _ = process.output_tx.send(chunk);
}
}
EXEC_EXITED_METHOD => {
let params: ExecExitedNotification =
serde_json::from_value(notification.params.unwrap_or(Value::Null))?;
let mut processes = inner.processes.lock().await;
if let Some(process) = processes.remove(&params.process_id) {
process.status.mark_exited(Some(params.exit_code));
}
}
other => {
debug!("ignoring unknown exec-server notification: {other}");
}
}
Ok(())
}
async fn handle_transport_shutdown(inner: &Arc<Inner>) {
let pending = {
let mut pending = inner.pending.lock().await;
pending.drain().map(|(_, tx)| tx).collect::<Vec<_>>()
};
for tx in pending {
let _ = tx.send(Err(JSONRPCErrorError {
code: -32000,
data: None,
message: "exec-server transport closed".to_string(),
}));
}
let processes = {
let mut processes = inner.processes.lock().await;
processes
.drain()
.map(|(_, process)| process)
.collect::<Vec<_>>()
};
for process in processes {
process.status.mark_exited(None);
}
}

View File

@@ -0,0 +1,20 @@
mod client;
mod protocol;
mod server;
pub use client::ExecServerClient;
pub use client::ExecServerError;
pub use client::ExecServerLaunchCommand;
pub use client::ExecServerProcess;
pub use protocol::ExecExitedNotification;
pub use protocol::ExecOutputDeltaNotification;
pub use protocol::ExecOutputStream;
pub use protocol::ExecParams;
pub use protocol::ExecResponse;
pub use protocol::InitializeParams;
pub use protocol::InitializeResponse;
pub use protocol::TerminateParams;
pub use protocol::TerminateResponse;
pub use protocol::WriteParams;
pub use protocol::WriteResponse;
pub use server::run_main;

View File

@@ -0,0 +1,144 @@
use std::collections::HashMap;
use std::path::PathBuf;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use codex_utils_pty::DEFAULT_OUTPUT_BYTES_CAP;
use serde::Deserialize;
use serde::Serialize;
pub const INITIALIZE_METHOD: &str = "initialize";
pub const INITIALIZED_METHOD: &str = "initialized";
pub const EXEC_METHOD: &str = "command/exec";
pub const EXEC_WRITE_METHOD: &str = "command/exec/write";
pub const EXEC_TERMINATE_METHOD: &str = "command/exec/terminate";
pub const EXEC_OUTPUT_DELTA_METHOD: &str = "command/exec/outputDelta";
pub const EXEC_EXITED_METHOD: &str = "command/exec/exited";
pub const PROTOCOL_VERSION: &str = "exec-server.v0";
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ByteChunk(#[serde(with = "base64_bytes")] pub Vec<u8>);
impl ByteChunk {
pub fn into_inner(self) -> Vec<u8> {
self.0
}
}
impl From<Vec<u8>> for ByteChunk {
fn from(value: Vec<u8>) -> Self {
Self(value)
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeParams {
pub client_name: String,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeResponse {
pub protocol_version: String,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ExecParams {
pub process_id: String,
pub argv: Vec<String>,
pub cwd: PathBuf,
pub env: HashMap<String, String>,
pub tty: bool,
#[serde(default = "default_output_bytes_cap")]
pub output_bytes_cap: usize,
pub arg0: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ExecResponse {
pub process_id: String,
pub pid: Option<u32>,
pub running: bool,
pub exit_code: Option<i32>,
pub stdout: Option<ByteChunk>,
pub stderr: Option<ByteChunk>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct WriteParams {
pub process_id: String,
pub chunk: ByteChunk,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct WriteResponse {
pub accepted: bool,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TerminateParams {
pub process_id: String,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TerminateResponse {
pub running: bool,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum ExecOutputStream {
Stdout,
Stderr,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ExecOutputDeltaNotification {
pub process_id: String,
pub stream: ExecOutputStream,
pub chunk: ByteChunk,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ExecExitedNotification {
pub process_id: String,
pub exit_code: i32,
}
fn default_output_bytes_cap() -> usize {
DEFAULT_OUTPUT_BYTES_CAP
}
mod base64_bytes {
use super::BASE64_STANDARD;
use base64::Engine as _;
use serde::Deserialize;
use serde::Deserializer;
use serde::Serializer;
pub fn serialize<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&BASE64_STANDARD.encode(bytes))
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
where
D: Deserializer<'de>,
{
let encoded = String::deserialize(deserializer)?;
BASE64_STANDARD
.decode(encoded)
.map_err(serde::de::Error::custom)
}
}

View File

@@ -0,0 +1,422 @@
use std::collections::HashMap;
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use codex_app_server_protocol::JSONRPCError;
use codex_app_server_protocol::JSONRPCErrorError;
use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCRequest;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::RequestId;
use codex_utils_pty::ExecCommandSession;
use codex_utils_pty::TerminalSize;
use serde::Serialize;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
use tokio::io::BufWriter;
use tokio::sync::Mutex;
use crate::protocol::EXEC_EXITED_METHOD;
use crate::protocol::EXEC_METHOD;
use crate::protocol::EXEC_OUTPUT_DELTA_METHOD;
use crate::protocol::EXEC_TERMINATE_METHOD;
use crate::protocol::EXEC_WRITE_METHOD;
use crate::protocol::ExecExitedNotification;
use crate::protocol::ExecOutputDeltaNotification;
use crate::protocol::ExecOutputStream;
use crate::protocol::ExecParams;
use crate::protocol::ExecResponse;
use crate::protocol::INITIALIZE_METHOD;
use crate::protocol::INITIALIZED_METHOD;
use crate::protocol::InitializeResponse;
use crate::protocol::PROTOCOL_VERSION;
use crate::protocol::TerminateParams;
use crate::protocol::TerminateResponse;
use crate::protocol::WriteParams;
use crate::protocol::WriteResponse;
struct RunningProcess {
session: ExecCommandSession,
tty: bool,
stdout_buffer: Arc<StdMutex<BoundedBytesBuffer>>,
stderr_buffer: Arc<StdMutex<BoundedBytesBuffer>>,
}
#[derive(Debug)]
struct BoundedBytesBuffer {
max_bytes: usize,
bytes: VecDeque<u8>,
}
impl BoundedBytesBuffer {
fn new(max_bytes: usize) -> Self {
Self {
max_bytes,
bytes: VecDeque::with_capacity(max_bytes.min(8192)),
}
}
fn push_chunk(&mut self, chunk: &[u8]) {
if self.max_bytes == 0 {
return;
}
for byte in chunk {
self.bytes.push_back(*byte);
if self.bytes.len() > self.max_bytes {
self.bytes.pop_front();
}
}
}
fn snapshot(&self) -> Vec<u8> {
self.bytes.iter().copied().collect()
}
}
pub async fn run_main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let writer = Arc::new(Mutex::new(BufWriter::new(tokio::io::stdout())));
let processes = Arc::new(Mutex::new(HashMap::<String, RunningProcess>::new()));
let mut lines = BufReader::new(tokio::io::stdin()).lines();
while let Some(line) = lines.next_line().await? {
if line.trim().is_empty() {
continue;
}
let message = serde_json::from_str::<JSONRPCMessage>(&line)?;
if let JSONRPCMessage::Request(request) = message {
handle_request(request, &writer, &processes).await;
continue;
}
if let JSONRPCMessage::Notification(notification) = message {
if notification.method != INITIALIZED_METHOD {
send_error(
&writer,
RequestId::Integer(-1),
invalid_request(format!(
"unexpected notification method: {}",
notification.method
)),
)
.await;
}
continue;
}
}
let remaining = {
let mut processes = processes.lock().await;
processes
.drain()
.map(|(_, process)| process)
.collect::<Vec<_>>()
};
for process in remaining {
process.session.terminate();
}
Ok(())
}
async fn handle_request(
request: JSONRPCRequest,
writer: &Arc<Mutex<BufWriter<tokio::io::Stdout>>>,
processes: &Arc<Mutex<HashMap<String, RunningProcess>>>,
) {
let response = match request.method.as_str() {
INITIALIZE_METHOD => serde_json::to_value(InitializeResponse {
protocol_version: PROTOCOL_VERSION.to_string(),
})
.map_err(|err| internal_error(err.to_string())),
EXEC_METHOD => handle_exec_request(request.params, writer, processes).await,
EXEC_WRITE_METHOD => handle_write_request(request.params, processes).await,
EXEC_TERMINATE_METHOD => handle_terminate_request(request.params, processes).await,
other => Err(invalid_request(format!("unknown method: {other}"))),
};
match response {
Ok(result) => {
send_response(
writer,
JSONRPCResponse {
id: request.id,
result,
},
)
.await;
}
Err(err) => {
send_error(writer, request.id, err).await;
}
}
}
async fn handle_exec_request(
params: Option<serde_json::Value>,
writer: &Arc<Mutex<BufWriter<tokio::io::Stdout>>>,
processes: &Arc<Mutex<HashMap<String, RunningProcess>>>,
) -> Result<serde_json::Value, JSONRPCErrorError> {
let params: ExecParams = serde_json::from_value(params.unwrap_or(serde_json::Value::Null))
.map_err(|err| invalid_params(err.to_string()))?;
let (program, args) = params
.argv
.split_first()
.ok_or_else(|| invalid_params("argv must not be empty".to_string()))?;
let spawned = if params.tty {
codex_utils_pty::spawn_pty_process(
program,
args,
params.cwd.as_path(),
&params.env,
&params.arg0,
TerminalSize::default(),
)
.await
} else {
codex_utils_pty::spawn_pipe_process_no_stdin(
program,
args,
params.cwd.as_path(),
&params.env,
&params.arg0,
)
.await
}
.map_err(|err| internal_error(err.to_string()))?;
let pid = spawned.session.pid();
let stdout_buffer = Arc::new(StdMutex::new(BoundedBytesBuffer::new(
params.output_bytes_cap,
)));
let stderr_buffer = Arc::new(StdMutex::new(BoundedBytesBuffer::new(
params.output_bytes_cap,
)));
let process_id = params.process_id.clone();
{
let mut process_map = processes.lock().await;
if process_map.contains_key(&process_id) {
spawned.session.terminate();
return Err(invalid_request(format!(
"process {} already exists",
params.process_id
)));
}
process_map.insert(
process_id.clone(),
RunningProcess {
session: spawned.session,
tty: params.tty,
stdout_buffer: Arc::clone(&stdout_buffer),
stderr_buffer: Arc::clone(&stderr_buffer),
},
);
}
tokio::spawn(stream_output(
process_id.clone(),
ExecOutputStream::Stdout,
spawned.stdout_rx,
Arc::clone(writer),
Arc::clone(&stdout_buffer),
));
tokio::spawn(stream_output(
process_id.clone(),
ExecOutputStream::Stderr,
spawned.stderr_rx,
Arc::clone(writer),
Arc::clone(&stderr_buffer),
));
tokio::spawn(watch_exit(
process_id.clone(),
spawned.exit_rx,
Arc::clone(writer),
Arc::clone(processes),
));
serde_json::to_value(ExecResponse {
process_id,
pid,
running: true,
exit_code: None,
stdout: None,
stderr: None,
})
.map_err(|err| internal_error(err.to_string()))
}
async fn handle_write_request(
params: Option<serde_json::Value>,
processes: &Arc<Mutex<HashMap<String, RunningProcess>>>,
) -> Result<serde_json::Value, JSONRPCErrorError> {
let params: WriteParams = serde_json::from_value(params.unwrap_or(serde_json::Value::Null))
.map_err(|err| invalid_params(err.to_string()))?;
let writer_tx = {
let process_map = processes.lock().await;
let process = process_map
.get(&params.process_id)
.ok_or_else(|| invalid_request(format!("unknown process id {}", params.process_id)))?;
if !process.tty {
return Err(invalid_request(format!(
"stdin is closed for process {}",
params.process_id
)));
}
process.session.writer_sender()
};
writer_tx
.send(params.chunk.into_inner())
.await
.map_err(|_| internal_error("failed to write to process stdin".to_string()))?;
serde_json::to_value(WriteResponse { accepted: true })
.map_err(|err| internal_error(err.to_string()))
}
async fn handle_terminate_request(
params: Option<serde_json::Value>,
processes: &Arc<Mutex<HashMap<String, RunningProcess>>>,
) -> Result<serde_json::Value, JSONRPCErrorError> {
let params: TerminateParams = serde_json::from_value(params.unwrap_or(serde_json::Value::Null))
.map_err(|err| invalid_params(err.to_string()))?;
let process = {
let mut process_map = processes.lock().await;
process_map.remove(&params.process_id)
};
if let Some(process) = process {
process.session.terminate();
serde_json::to_value(TerminateResponse { running: true })
.map_err(|err| internal_error(err.to_string()))
} else {
serde_json::to_value(TerminateResponse { running: false })
.map_err(|err| internal_error(err.to_string()))
}
}
async fn stream_output(
process_id: String,
stream: ExecOutputStream,
mut receiver: tokio::sync::mpsc::Receiver<Vec<u8>>,
writer: Arc<Mutex<BufWriter<tokio::io::Stdout>>>,
buffer: Arc<StdMutex<BoundedBytesBuffer>>,
) {
while let Some(chunk) = receiver.recv().await {
if let Ok(mut guard) = buffer.lock() {
guard.push_chunk(&chunk);
}
let notification = ExecOutputDeltaNotification {
process_id: process_id.clone(),
stream,
chunk: chunk.into(),
};
if send_notification(&writer, EXEC_OUTPUT_DELTA_METHOD, &notification)
.await
.is_err()
{
break;
}
}
}
async fn watch_exit(
process_id: String,
exit_rx: tokio::sync::oneshot::Receiver<i32>,
writer: Arc<Mutex<BufWriter<tokio::io::Stdout>>>,
processes: Arc<Mutex<HashMap<String, RunningProcess>>>,
) {
let exit_code = exit_rx.await.unwrap_or(-1);
let removed = {
let mut processes = processes.lock().await;
processes.remove(&process_id)
};
if let Some(process) = removed {
let _ = process.stdout_buffer.lock().map(|buffer| buffer.snapshot());
let _ = process.stderr_buffer.lock().map(|buffer| buffer.snapshot());
}
let _ = send_notification(
&writer,
EXEC_EXITED_METHOD,
&ExecExitedNotification {
process_id,
exit_code,
},
)
.await;
}
async fn send_response(
writer: &Arc<Mutex<BufWriter<tokio::io::Stdout>>>,
response: JSONRPCResponse,
) {
let _ = send_message(writer, JSONRPCMessage::Response(response)).await;
}
async fn send_error(
writer: &Arc<Mutex<BufWriter<tokio::io::Stdout>>>,
id: RequestId,
error: JSONRPCErrorError,
) {
let _ = send_message(writer, JSONRPCMessage::Error(JSONRPCError { error, id })).await;
}
async fn send_notification<T: Serialize>(
writer: &Arc<Mutex<BufWriter<tokio::io::Stdout>>>,
method: &str,
params: &T,
) -> Result<(), serde_json::Error> {
send_message(
writer,
JSONRPCMessage::Notification(JSONRPCNotification {
method: method.to_string(),
params: Some(serde_json::to_value(params)?),
}),
)
.await
.map_err(serde_json::Error::io)
}
async fn send_message(
writer: &Arc<Mutex<BufWriter<tokio::io::Stdout>>>,
message: JSONRPCMessage,
) -> std::io::Result<()> {
let encoded =
serde_json::to_vec(&message).map_err(|err| std::io::Error::other(err.to_string()))?;
let mut writer = writer.lock().await;
writer.write_all(&encoded).await?;
writer.write_all(b"\n").await?;
writer.flush().await
}
fn invalid_request(message: String) -> JSONRPCErrorError {
JSONRPCErrorError {
code: -32600,
data: None,
message,
}
}
fn invalid_params(message: String) -> JSONRPCErrorError {
JSONRPCErrorError {
code: -32602,
data: None,
message,
}
}
fn internal_error(message: String) -> JSONRPCErrorError {
JSONRPCErrorError {
code: -32603,
data: None,
message,
}
}

View File

@@ -0,0 +1,141 @@
#![cfg(unix)]
use std::process::Stdio;
use std::time::Duration;
use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCRequest;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::RequestId;
use codex_exec_server::ExecParams;
use codex_exec_server::ExecServerClient;
use codex_exec_server::ExecServerLaunchCommand;
use codex_exec_server::InitializeParams;
use codex_exec_server::InitializeResponse;
use codex_utils_cargo_bin::cargo_bin;
use pretty_assertions::assert_eq;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
use tokio::process::Command;
use tokio::sync::broadcast;
use tokio::time::timeout;
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn exec_server_accepts_initialize_over_stdio() -> anyhow::Result<()> {
let binary = cargo_bin("codex-exec-server")?;
let mut child = Command::new(binary);
child.stdin(Stdio::piped());
child.stdout(Stdio::piped());
child.stderr(Stdio::inherit());
let mut child = child.spawn()?;
let mut stdin = child.stdin.take().expect("stdin");
let stdout = child.stdout.take().expect("stdout");
let mut stdout = BufReader::new(stdout).lines();
let initialize = JSONRPCMessage::Request(JSONRPCRequest {
id: RequestId::Integer(1),
method: "initialize".to_string(),
params: Some(serde_json::to_value(InitializeParams {
client_name: "exec-server-test".to_string(),
})?),
trace: None,
});
stdin
.write_all(format!("{}\n", serde_json::to_string(&initialize)?).as_bytes())
.await?;
let response_line = timeout(Duration::from_secs(5), stdout.next_line()).await??;
let response_line = response_line.expect("response line");
let response: JSONRPCMessage = serde_json::from_str(&response_line)?;
let JSONRPCMessage::Response(JSONRPCResponse { id, result }) = response else {
panic!("expected initialize response");
};
assert_eq!(id, RequestId::Integer(1));
let initialize_response: InitializeResponse = serde_json::from_value(result)?;
assert_eq!(initialize_response.protocol_version, "exec-server.v0");
let initialized = JSONRPCMessage::Notification(JSONRPCNotification {
method: "initialized".to_string(),
params: Some(serde_json::json!({})),
});
stdin
.write_all(format!("{}\n", serde_json::to_string(&initialized)?).as_bytes())
.await?;
child.start_kill()?;
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn exec_server_client_streams_output_and_accepts_writes() -> anyhow::Result<()> {
let mut env = std::collections::HashMap::new();
if let Some(path) = std::env::var_os("PATH") {
env.insert("PATH".to_string(), path.to_string_lossy().into_owned());
}
let client = ExecServerClient::spawn(ExecServerLaunchCommand {
program: cargo_bin("codex-exec-server")?,
args: Vec::new(),
})
.await?;
let process = client
.start_process(ExecParams {
process_id: "2001".to_string(),
argv: vec![
"bash".to_string(),
"-lc".to_string(),
"printf 'ready\\n'; while IFS= read -r line; do printf 'echo:%s\\n' \"$line\"; done"
.to_string(),
],
cwd: std::env::current_dir()?,
env,
tty: true,
output_bytes_cap: 4096,
arg0: None,
})
.await?;
let mut output = process.output_receiver();
assert!(
recv_until_contains(&mut output, "ready")
.await?
.contains("ready"),
"expected initial ready output"
);
process
.writer_sender()
.send(b"hello\n".to_vec())
.await
.expect("write should succeed");
assert!(
recv_until_contains(&mut output, "echo:hello")
.await?
.contains("echo:hello"),
"expected echoed output"
);
process.terminate();
Ok(())
}
async fn recv_until_contains(
output: &mut broadcast::Receiver<Vec<u8>>,
needle: &str,
) -> anyhow::Result<String> {
let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
let mut collected = String::new();
loop {
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
let chunk = timeout(remaining, output.recv()).await??;
collected.push_str(&String::from_utf8_lossy(&chunk));
if collected.contains(needle) {
return Ok(collected);
}
}
}

View File

@@ -9,13 +9,8 @@ workspace = true
[dependencies]
keyring = { workspace = true, features = ["crypto-rust"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tracing = { workspace = true }
[dev-dependencies]
pretty_assertions = { workspace = true }
[target.'cfg(target_os = "linux")'.dependencies]
keyring = { workspace = true, features = ["linux-native-async-persistent"] }

View File

@@ -5,17 +5,6 @@ use std::fmt;
use std::fmt::Debug;
use tracing::trace;
mod split_json;
pub use split_json::JsonKeyringError;
pub use split_json::SplitJsonKeyringError;
pub use split_json::delete_json_from_keyring;
pub use split_json::delete_split_json_from_keyring;
pub use split_json::load_json_from_keyring;
pub use split_json::load_split_json_from_keyring;
pub use split_json::save_json_to_keyring;
pub use split_json::save_split_json_to_keyring;
#[derive(Debug)]
pub enum CredentialStoreError {
Other(KeyringError),
@@ -52,18 +41,7 @@ impl Error for CredentialStoreError {}
/// Shared credential store abstraction for keyring-backed implementations.
pub trait KeyringStore: Debug + Send + Sync {
fn load(&self, service: &str, account: &str) -> Result<Option<String>, CredentialStoreError>;
fn load_secret(
&self,
service: &str,
account: &str,
) -> Result<Option<Vec<u8>>, CredentialStoreError>;
fn save(&self, service: &str, account: &str, value: &str) -> Result<(), CredentialStoreError>;
fn save_secret(
&self,
service: &str,
account: &str,
value: &[u8],
) -> Result<(), CredentialStoreError>;
fn delete(&self, service: &str, account: &str) -> Result<bool, CredentialStoreError>;
}
@@ -90,31 +68,6 @@ impl KeyringStore for DefaultKeyringStore {
}
}
fn load_secret(
&self,
service: &str,
account: &str,
) -> Result<Option<Vec<u8>>, CredentialStoreError> {
trace!("keyring.load_secret start, service={service}, account={account}");
let entry = Entry::new(service, account).map_err(CredentialStoreError::new)?;
match entry.get_secret() {
Ok(secret) => {
trace!("keyring.load_secret success, service={service}, account={account}");
Ok(Some(secret))
}
Err(keyring::Error::NoEntry) => {
trace!("keyring.load_secret no entry, service={service}, account={account}");
Ok(None)
}
Err(error) => {
trace!(
"keyring.load_secret error, service={service}, account={account}, error={error}"
);
Err(CredentialStoreError::new(error))
}
}
}
fn save(&self, service: &str, account: &str, value: &str) -> Result<(), CredentialStoreError> {
trace!(
"keyring.save start, service={service}, account={account}, value_len={}",
@@ -133,31 +86,6 @@ impl KeyringStore for DefaultKeyringStore {
}
}
fn save_secret(
&self,
service: &str,
account: &str,
value: &[u8],
) -> Result<(), CredentialStoreError> {
trace!(
"keyring.save_secret start, service={service}, account={account}, value_len={}",
value.len()
);
let entry = Entry::new(service, account).map_err(CredentialStoreError::new)?;
match entry.set_secret(value) {
Ok(()) => {
trace!("keyring.save_secret success, service={service}, account={account}");
Ok(())
}
Err(error) => {
trace!(
"keyring.save_secret error, service={service}, account={account}, error={error}"
);
Err(CredentialStoreError::new(error))
}
}
}
fn delete(&self, service: &str, account: &str) -> Result<bool, CredentialStoreError> {
trace!("keyring.delete start, service={service}, account={account}");
let entry = Entry::new(service, account).map_err(CredentialStoreError::new)?;
@@ -217,22 +145,6 @@ pub mod tests {
credential.get_password().ok()
}
pub fn saved_secret(&self, account: &str) -> Option<Vec<u8>> {
let credential = {
let guard = self
.credentials
.lock()
.unwrap_or_else(PoisonError::into_inner);
guard.get(account).cloned()
}?;
credential.get_secret().ok()
}
pub fn saved_secret_utf8(&self, account: &str) -> Option<String> {
let secret = self.saved_secret(account)?;
String::from_utf8(secret).ok()
}
pub fn set_error(&self, account: &str, error: KeyringError) {
let credential = self.credential(account);
credential.set_error(error);
@@ -272,30 +184,6 @@ pub mod tests {
}
}
fn load_secret(
&self,
_service: &str,
account: &str,
) -> Result<Option<Vec<u8>>, CredentialStoreError> {
let credential = {
let guard = self
.credentials
.lock()
.unwrap_or_else(PoisonError::into_inner);
guard.get(account).cloned()
};
let Some(credential) = credential else {
return Ok(None);
};
match credential.get_secret() {
Ok(secret) => Ok(Some(secret)),
Err(KeyringError::NoEntry) => Ok(None),
Err(error) => Err(CredentialStoreError::new(error)),
}
}
fn save(
&self,
_service: &str,
@@ -308,18 +196,6 @@ pub mod tests {
.map_err(CredentialStoreError::new)
}
fn save_secret(
&self,
_service: &str,
account: &str,
value: &[u8],
) -> Result<(), CredentialStoreError> {
let credential = self.credential(account);
credential
.set_secret(value)
.map_err(CredentialStoreError::new)
}
fn delete(&self, _service: &str, account: &str) -> Result<bool, CredentialStoreError> {
let credential = {
let guard = self

View File

@@ -1,856 +0,0 @@
use crate::CredentialStoreError;
use crate::KeyringStore;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Map;
use serde_json::Value;
use std::fmt;
use std::fmt::Write as _;
use tracing::warn;
const LAYOUT_VERSION: &str = "v1";
const MANIFEST_ENTRY: &str = "manifest";
const VALUE_ENTRY_PREFIX: &str = "value";
const ROOT_PATH_SENTINEL: &str = "root";
#[derive(Debug, Clone)]
pub struct SplitJsonKeyringError {
message: String,
}
pub type JsonKeyringError = SplitJsonKeyringError;
impl SplitJsonKeyringError {
fn new(message: impl Into<String>) -> Self {
Self {
message: message.into(),
}
}
}
impl fmt::Display for SplitJsonKeyringError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.message)
}
}
impl std::error::Error for SplitJsonKeyringError {}
#[derive(Clone, Copy, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
enum JsonNodeKind {
Null,
Bool,
Number,
String,
Object,
Array,
}
impl JsonNodeKind {
fn from_value(value: &Value) -> Self {
match value {
Value::Null => Self::Null,
Value::Bool(_) => Self::Bool,
Value::Number(_) => Self::Number,
Value::String(_) => Self::String,
Value::Object(_) => Self::Object,
Value::Array(_) => Self::Array,
}
}
fn is_container(self) -> bool {
matches!(self, Self::Object | Self::Array)
}
fn empty_value(self) -> Option<Value> {
match self {
Self::Object => Some(Value::Object(Map::new())),
Self::Array => Some(Value::Array(Vec::new())),
Self::Null | Self::Bool | Self::Number | Self::String => None,
}
}
}
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
struct SplitJsonNode {
path: String,
kind: JsonNodeKind,
}
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
struct SplitJsonManifest {
nodes: Vec<SplitJsonNode>,
}
type SplitJsonLeafValues = Vec<(String, Vec<u8>)>;
#[cfg(windows)]
pub fn load_json_from_keyring<K: KeyringStore + ?Sized>(
keyring_store: &K,
service: &str,
base_key: &str,
) -> Result<Option<Value>, JsonKeyringError> {
if let Some(value) = load_split_json_from_keyring(keyring_store, service, base_key)? {
return Ok(Some(value));
}
load_full_json_from_keyring(keyring_store, service, base_key)
}
#[cfg(not(windows))]
pub fn load_json_from_keyring<K: KeyringStore + ?Sized>(
keyring_store: &K,
service: &str,
base_key: &str,
) -> Result<Option<Value>, JsonKeyringError> {
if let Some(value) = load_full_json_from_keyring(keyring_store, service, base_key)? {
return Ok(Some(value));
}
load_split_json_from_keyring(keyring_store, service, base_key)
}
#[cfg(windows)]
pub fn save_json_to_keyring<K: KeyringStore + ?Sized>(
keyring_store: &K,
service: &str,
base_key: &str,
value: &Value,
) -> Result<(), JsonKeyringError> {
save_split_json_to_keyring(keyring_store, service, base_key, value)?;
if let Err(err) = delete_full_json_from_keyring(keyring_store, service, base_key) {
warn!("failed to remove stale full JSON record from keyring: {err}");
}
Ok(())
}
#[cfg(not(windows))]
pub fn save_json_to_keyring<K: KeyringStore + ?Sized>(
keyring_store: &K,
service: &str,
base_key: &str,
value: &Value,
) -> Result<(), JsonKeyringError> {
save_full_json_to_keyring(keyring_store, service, base_key, value)?;
if let Err(err) = delete_split_json_from_keyring(keyring_store, service, base_key) {
warn!("failed to remove stale split JSON record from keyring: {err}");
}
Ok(())
}
#[cfg(windows)]
pub fn delete_json_from_keyring<K: KeyringStore + ?Sized>(
keyring_store: &K,
service: &str,
base_key: &str,
) -> Result<bool, JsonKeyringError> {
let split_removed = delete_split_json_from_keyring(keyring_store, service, base_key)?;
let full_removed = delete_full_json_from_keyring(keyring_store, service, base_key)?;
Ok(split_removed || full_removed)
}
#[cfg(not(windows))]
pub fn delete_json_from_keyring<K: KeyringStore + ?Sized>(
keyring_store: &K,
service: &str,
base_key: &str,
) -> Result<bool, JsonKeyringError> {
let full_removed = delete_full_json_from_keyring(keyring_store, service, base_key)?;
let split_removed = delete_split_json_from_keyring(keyring_store, service, base_key)?;
Ok(full_removed || split_removed)
}
pub fn load_split_json_from_keyring<K: KeyringStore + ?Sized>(
keyring_store: &K,
service: &str,
base_key: &str,
) -> Result<Option<Value>, SplitJsonKeyringError> {
let Some(manifest) = load_manifest(keyring_store, service, base_key)? else {
return Ok(None);
};
inflate_split_json(keyring_store, service, base_key, &manifest).map(Some)
}
pub fn save_split_json_to_keyring<K: KeyringStore + ?Sized>(
keyring_store: &K,
service: &str,
base_key: &str,
value: &Value,
) -> Result<(), SplitJsonKeyringError> {
let previous_manifest = match load_manifest(keyring_store, service, base_key) {
Ok(manifest) => manifest,
Err(err) => {
warn!("failed to read previous split JSON manifest from keyring: {err}");
None
}
};
let (manifest, leaf_values) = flatten_split_json(value)?;
let current_scalar_paths = manifest
.nodes
.iter()
.filter(|node| !node.kind.is_container())
.map(|node| node.path.as_str())
.collect::<std::collections::HashSet<_>>();
for (path, bytes) in leaf_values {
let key = value_key(base_key, &path);
save_secret_to_keyring(
keyring_store,
service,
&key,
&bytes,
&format!("JSON value at {path}"),
)?;
}
let manifest_key = layout_key(base_key, MANIFEST_ENTRY);
let manifest_bytes = serde_json::to_vec(&manifest).map_err(|err| {
SplitJsonKeyringError::new(format!("failed to serialize JSON manifest: {err}"))
})?;
save_secret_to_keyring(
keyring_store,
service,
&manifest_key,
&manifest_bytes,
"JSON manifest",
)?;
if let Some(previous_manifest) = previous_manifest {
for node in previous_manifest.nodes {
if node.kind.is_container() || current_scalar_paths.contains(node.path.as_str()) {
continue;
}
let key = value_key(base_key, &node.path);
if let Err(err) = delete_keyring_entry(
keyring_store,
service,
&key,
&format!("stale JSON value at {}", node.path),
) {
warn!("failed to remove stale split JSON value from keyring: {err}");
}
}
}
Ok(())
}
pub fn delete_split_json_from_keyring<K: KeyringStore + ?Sized>(
keyring_store: &K,
service: &str,
base_key: &str,
) -> Result<bool, SplitJsonKeyringError> {
let Some(manifest) = load_manifest(keyring_store, service, base_key)? else {
return Ok(false);
};
let mut removed = false;
for node in manifest.nodes {
if node.kind.is_container() {
continue;
}
let key = value_key(base_key, &node.path);
removed |= delete_keyring_entry(
keyring_store,
service,
&key,
&format!("JSON value at {}", node.path),
)?;
}
let manifest_key = layout_key(base_key, MANIFEST_ENTRY);
removed |= delete_keyring_entry(keyring_store, service, &manifest_key, "JSON manifest")?;
Ok(removed)
}
fn load_full_json_from_keyring<K: KeyringStore + ?Sized>(
keyring_store: &K,
service: &str,
base_key: &str,
) -> Result<Option<Value>, SplitJsonKeyringError> {
if let Some(bytes) = load_secret_from_keyring(keyring_store, service, base_key, "JSON record")?
{
let value = serde_json::from_slice(&bytes).map_err(|err| {
SplitJsonKeyringError::new(format!(
"failed to deserialize JSON record from keyring secret: {err}"
))
})?;
return Ok(Some(value));
}
match keyring_store.load(service, base_key) {
Ok(Some(serialized)) => serde_json::from_str(&serialized).map(Some).map_err(|err| {
SplitJsonKeyringError::new(format!(
"failed to deserialize JSON record from keyring password: {err}"
))
}),
Ok(None) => Ok(None),
Err(error) => Err(credential_store_error("load", "JSON record", error)),
}
}
#[cfg(not(windows))]
fn save_full_json_to_keyring<K: KeyringStore + ?Sized>(
keyring_store: &K,
service: &str,
base_key: &str,
value: &Value,
) -> Result<(), SplitJsonKeyringError> {
let bytes = serde_json::to_vec(value).map_err(|err| {
SplitJsonKeyringError::new(format!("failed to serialize JSON record: {err}"))
})?;
save_secret_to_keyring(keyring_store, service, base_key, &bytes, "JSON record")
}
fn delete_full_json_from_keyring<K: KeyringStore + ?Sized>(
keyring_store: &K,
service: &str,
base_key: &str,
) -> Result<bool, SplitJsonKeyringError> {
delete_keyring_entry(keyring_store, service, base_key, "JSON record")
}
fn flatten_split_json(
value: &Value,
) -> Result<(SplitJsonManifest, SplitJsonLeafValues), SplitJsonKeyringError> {
let mut nodes = Vec::new();
let mut leaf_values = Vec::new();
collect_nodes("", value, &mut nodes, &mut leaf_values)?;
nodes.sort_by(|left, right| {
path_depth(&left.path)
.cmp(&path_depth(&right.path))
.then_with(|| left.path.cmp(&right.path))
});
leaf_values.sort_by(|left, right| left.0.cmp(&right.0));
Ok((SplitJsonManifest { nodes }, leaf_values))
}
fn collect_nodes(
path: &str,
value: &Value,
nodes: &mut Vec<SplitJsonNode>,
leaf_values: &mut SplitJsonLeafValues,
) -> Result<(), SplitJsonKeyringError> {
let kind = JsonNodeKind::from_value(value);
nodes.push(SplitJsonNode {
path: path.to_string(),
kind,
});
match value {
Value::Object(map) => {
let mut keys = map.keys().cloned().collect::<Vec<_>>();
keys.sort();
for key in keys {
let child_path = append_json_pointer_token(path, &key);
let child_value = map.get(&key).ok_or_else(|| {
SplitJsonKeyringError::new(format!(
"missing object value for path {child_path}"
))
})?;
collect_nodes(&child_path, child_value, nodes, leaf_values)?;
}
}
Value::Array(items) => {
for (index, item) in items.iter().enumerate() {
let child_path = append_json_pointer_token(path, &index.to_string());
collect_nodes(&child_path, item, nodes, leaf_values)?;
}
}
Value::Null | Value::Bool(_) | Value::Number(_) | Value::String(_) => {
let bytes = serde_json::to_vec(value).map_err(|err| {
SplitJsonKeyringError::new(format!(
"failed to serialize JSON value at {path}: {err}"
))
})?;
leaf_values.push((path.to_string(), bytes));
}
}
Ok(())
}
fn inflate_split_json<K: KeyringStore + ?Sized>(
keyring_store: &K,
service: &str,
base_key: &str,
manifest: &SplitJsonManifest,
) -> Result<Value, SplitJsonKeyringError> {
let root_node = manifest
.nodes
.iter()
.find(|node| node.path.is_empty())
.ok_or_else(|| SplitJsonKeyringError::new("missing root JSON node in keyring manifest"))?;
let mut result = if let Some(value) = root_node.kind.empty_value() {
value
} else {
load_value(keyring_store, service, base_key, "")?
};
let mut nodes = manifest.nodes.clone();
nodes.sort_by(|left, right| {
path_depth(&left.path)
.cmp(&path_depth(&right.path))
.then_with(|| left.path.cmp(&right.path))
});
for node in nodes.into_iter().filter(|node| !node.path.is_empty()) {
let value = if let Some(value) = node.kind.empty_value() {
value
} else {
load_value(keyring_store, service, base_key, &node.path)?
};
insert_value_at_pointer(&mut result, &node.path, value)?;
}
Ok(result)
}
fn load_value<K: KeyringStore + ?Sized>(
keyring_store: &K,
service: &str,
base_key: &str,
path: &str,
) -> Result<Value, SplitJsonKeyringError> {
let key = value_key(base_key, path);
let bytes = load_secret_from_keyring(
keyring_store,
service,
&key,
&format!("JSON value at {path}"),
)?
.ok_or_else(|| {
SplitJsonKeyringError::new(format!("missing JSON value at {path} in keyring"))
})?;
serde_json::from_slice(&bytes).map_err(|err| {
SplitJsonKeyringError::new(format!("failed to deserialize JSON value at {path}: {err}"))
})
}
fn insert_value_at_pointer(
root: &mut Value,
pointer: &str,
value: Value,
) -> Result<(), SplitJsonKeyringError> {
if pointer.is_empty() {
*root = value;
return Ok(());
}
let tokens = decode_json_pointer(pointer)?;
let Some((last, parents)) = tokens.split_last() else {
return Err(SplitJsonKeyringError::new(
"missing JSON pointer path tokens",
));
};
let mut current = root;
for token in parents {
current = match current {
Value::Object(map) => map.get_mut(token).ok_or_else(|| {
SplitJsonKeyringError::new(format!(
"missing parent object entry for JSON pointer {pointer}"
))
})?,
Value::Array(items) => {
let index = parse_array_index(token, pointer)?;
items.get_mut(index).ok_or_else(|| {
SplitJsonKeyringError::new(format!(
"missing parent array entry for JSON pointer {pointer}"
))
})?
}
Value::Null | Value::Bool(_) | Value::Number(_) | Value::String(_) => {
return Err(SplitJsonKeyringError::new(format!(
"encountered scalar while walking JSON pointer {pointer}"
)));
}
};
}
match current {
Value::Object(map) => {
map.insert(last.to_string(), value);
Ok(())
}
Value::Array(items) => {
let index = parse_array_index(last, pointer)?;
if index >= items.len() {
items.resize(index + 1, Value::Null);
}
items[index] = value;
Ok(())
}
Value::Null | Value::Bool(_) | Value::Number(_) | Value::String(_) => {
Err(SplitJsonKeyringError::new(format!(
"encountered scalar while assigning JSON pointer {pointer}"
)))
}
}
}
fn load_manifest<K: KeyringStore + ?Sized>(
keyring_store: &K,
service: &str,
base_key: &str,
) -> Result<Option<SplitJsonManifest>, SplitJsonKeyringError> {
let manifest_key = layout_key(base_key, MANIFEST_ENTRY);
let Some(bytes) =
load_secret_from_keyring(keyring_store, service, &manifest_key, "JSON manifest")?
else {
return Ok(None);
};
let manifest: SplitJsonManifest = serde_json::from_slice(&bytes).map_err(|err| {
SplitJsonKeyringError::new(format!("failed to deserialize JSON manifest: {err}"))
})?;
if manifest.nodes.is_empty() {
return Err(SplitJsonKeyringError::new("JSON manifest is empty"));
}
Ok(Some(manifest))
}
fn load_secret_from_keyring<K: KeyringStore + ?Sized>(
keyring_store: &K,
service: &str,
key: &str,
field: &str,
) -> Result<Option<Vec<u8>>, SplitJsonKeyringError> {
keyring_store
.load_secret(service, key)
.map_err(|err| credential_store_error("load", field, err))
}
fn save_secret_to_keyring<K: KeyringStore + ?Sized>(
keyring_store: &K,
service: &str,
key: &str,
value: &[u8],
field: &str,
) -> Result<(), SplitJsonKeyringError> {
keyring_store
.save_secret(service, key, value)
.map_err(|err| credential_store_error("write", field, err))
}
fn delete_keyring_entry<K: KeyringStore + ?Sized>(
keyring_store: &K,
service: &str,
key: &str,
field: &str,
) -> Result<bool, SplitJsonKeyringError> {
keyring_store
.delete(service, key)
.map_err(|err| credential_store_error("delete", field, err))
}
fn credential_store_error(
action: &str,
field: &str,
error: CredentialStoreError,
) -> SplitJsonKeyringError {
SplitJsonKeyringError::new(format!(
"failed to {action} {field} in keyring: {}",
error.message()
))
}
fn layout_key(base_key: &str, suffix: &str) -> String {
format!("{base_key}|{LAYOUT_VERSION}|{suffix}")
}
fn value_key(base_key: &str, path: &str) -> String {
let encoded_path = encode_path(path);
layout_key(base_key, &format!("{VALUE_ENTRY_PREFIX}|{encoded_path}"))
}
fn encode_path(path: &str) -> String {
if path.is_empty() {
return ROOT_PATH_SENTINEL.to_string();
}
let mut encoded = String::with_capacity(path.len() * 2);
for byte in path.as_bytes() {
let _ = write!(&mut encoded, "{byte:02x}");
}
encoded
}
fn append_json_pointer_token(path: &str, token: &str) -> String {
let escaped = token.replace('~', "~0").replace('/', "~1");
if path.is_empty() {
format!("/{escaped}")
} else {
format!("{path}/{escaped}")
}
}
fn decode_json_pointer(pointer: &str) -> Result<Vec<String>, SplitJsonKeyringError> {
if pointer.is_empty() {
return Ok(Vec::new());
}
if !pointer.starts_with('/') {
return Err(SplitJsonKeyringError::new(format!(
"invalid JSON pointer {pointer}: expected leading slash"
)));
}
pointer[1..]
.split('/')
.map(unescape_json_pointer_token)
.collect()
}
fn unescape_json_pointer_token(token: &str) -> Result<String, SplitJsonKeyringError> {
let mut result = String::with_capacity(token.len());
let mut chars = token.chars();
while let Some(ch) = chars.next() {
if ch != '~' {
result.push(ch);
continue;
}
match chars.next() {
Some('0') => result.push('~'),
Some('1') => result.push('/'),
Some(other) => {
return Err(SplitJsonKeyringError::new(format!(
"invalid JSON pointer escape sequence ~{other}"
)));
}
None => {
return Err(SplitJsonKeyringError::new(
"invalid JSON pointer escape sequence at end of token",
));
}
}
}
Ok(result)
}
fn parse_array_index(token: &str, pointer: &str) -> Result<usize, SplitJsonKeyringError> {
token.parse::<usize>().map_err(|err| {
SplitJsonKeyringError::new(format!(
"invalid array index '{token}' in JSON pointer {pointer}: {err}"
))
})
}
fn path_depth(path: &str) -> usize {
path.chars().filter(|ch| *ch == '/').count()
}
#[cfg(test)]
mod tests {
use super::LAYOUT_VERSION;
use super::MANIFEST_ENTRY;
use super::delete_json_from_keyring;
use super::delete_split_json_from_keyring;
use super::layout_key;
use super::load_json_from_keyring;
use super::load_split_json_from_keyring;
use super::save_json_to_keyring;
use super::save_split_json_to_keyring;
use super::value_key;
use crate::KeyringStore;
use crate::tests::MockKeyringStore;
use pretty_assertions::assert_eq;
use serde_json::json;
const SERVICE: &str = "Test Service";
const BASE_KEY: &str = "base";
#[test]
fn json_storage_round_trips_using_platform_backend() {
let store = MockKeyringStore::default();
let expected = json!({
"token": "secret",
"nested": {"id": 7}
});
save_json_to_keyring(&store, SERVICE, BASE_KEY, &expected).expect("JSON should save");
let loaded = load_json_from_keyring(&store, SERVICE, BASE_KEY)
.expect("JSON should load")
.expect("JSON should exist");
assert_eq!(loaded, expected);
#[cfg(windows)]
{
assert!(
store.saved_secret(BASE_KEY).is_none(),
"windows should not store the full JSON record under the base key"
);
assert!(
store.contains(&layout_key(BASE_KEY, MANIFEST_ENTRY)),
"windows should store split JSON manifest metadata"
);
}
#[cfg(not(windows))]
{
assert_eq!(
store.saved_secret(BASE_KEY),
Some(serde_json::to_vec(&expected).expect("JSON should serialize")),
);
assert!(
!store.contains(&layout_key(BASE_KEY, MANIFEST_ENTRY)),
"non-windows should not create split JSON manifest metadata"
);
}
}
#[cfg(not(windows))]
#[test]
fn json_storage_loads_split_json_compatibility_on_non_windows() {
let store = MockKeyringStore::default();
let expected = json!({
"token": "secret",
"nested": {"id": 9}
});
save_split_json_to_keyring(&store, SERVICE, BASE_KEY, &expected)
.expect("split JSON should save");
let loaded = load_json_from_keyring(&store, SERVICE, BASE_KEY)
.expect("JSON should load")
.expect("JSON should exist");
assert_eq!(loaded, expected);
}
#[test]
fn json_storage_delete_removes_platform_and_compat_entries() {
let store = MockKeyringStore::default();
let current = json!({"current": true});
let split = json!({"split": true});
save_json_to_keyring(&store, SERVICE, BASE_KEY, &current).expect("JSON should save");
save_split_json_to_keyring(&store, SERVICE, BASE_KEY, &split)
.expect("split JSON should save");
store
.save(
SERVICE,
BASE_KEY,
&serde_json::to_string(&current).expect("JSON should serialize"),
)
.expect("legacy JSON should save");
let removed = delete_json_from_keyring(&store, SERVICE, BASE_KEY)
.expect("JSON delete should succeed");
assert!(removed);
assert!(
load_json_from_keyring(&store, SERVICE, BASE_KEY)
.expect("JSON load should succeed")
.is_none()
);
assert!(!store.contains(BASE_KEY));
assert!(!store.contains(&layout_key(BASE_KEY, MANIFEST_ENTRY)));
}
#[test]
fn split_json_round_trips_nested_values() {
let store = MockKeyringStore::default();
let expected = json!({
"name": "codex",
"enabled": true,
"count": 3,
"nested": {
"items": [null, {"hello": "world"}],
"slash/key": "~value~",
},
});
save_split_json_to_keyring(&store, SERVICE, BASE_KEY, &expected)
.expect("split JSON should save");
let loaded = load_split_json_from_keyring(&store, SERVICE, BASE_KEY)
.expect("split JSON should load")
.expect("split JSON should exist");
assert_eq!(loaded, expected);
}
#[test]
fn split_json_supports_scalar_root_values() {
let store = MockKeyringStore::default();
let expected = json!("value");
save_split_json_to_keyring(&store, SERVICE, BASE_KEY, &expected)
.expect("split JSON should save");
let root_value_key = value_key(BASE_KEY, "");
assert_eq!(
store.saved_secret_utf8(&root_value_key),
Some("\"value\"".to_string())
);
let loaded = load_split_json_from_keyring(&store, SERVICE, BASE_KEY)
.expect("split JSON should load")
.expect("split JSON should exist");
assert_eq!(loaded, expected);
}
#[test]
fn split_json_delete_removes_saved_entries() {
let store = MockKeyringStore::default();
let expected = json!({
"token": "secret",
"nested": {
"id": 123,
},
});
save_split_json_to_keyring(&store, SERVICE, BASE_KEY, &expected)
.expect("split JSON should save");
let manifest_key = layout_key(BASE_KEY, MANIFEST_ENTRY);
let token_key = value_key(BASE_KEY, "/token");
let nested_id_key = value_key(BASE_KEY, "/nested/id");
let removed = delete_split_json_from_keyring(&store, SERVICE, BASE_KEY)
.expect("split JSON delete should succeed");
assert!(removed);
assert!(!store.contains(&manifest_key));
assert!(!store.contains(&token_key));
assert!(!store.contains(&nested_id_key));
}
#[test]
fn split_json_save_replaces_previous_values() {
let store = MockKeyringStore::default();
let first = json!({"value": "first", "stale": true});
let second = json!({"value": "second", "extra": 1});
save_split_json_to_keyring(&store, SERVICE, BASE_KEY, &first)
.expect("first split JSON save should succeed");
let manifest_key = layout_key(BASE_KEY, MANIFEST_ENTRY);
let stale_value_key = value_key(BASE_KEY, "/stale");
assert!(store.contains(&manifest_key));
assert!(store.contains(&stale_value_key));
save_split_json_to_keyring(&store, SERVICE, BASE_KEY, &second)
.expect("second split JSON save should succeed");
assert!(!store.contains(&stale_value_key));
assert!(store.contains(&manifest_key));
assert_eq!(
store.saved_secret_utf8(&value_key(BASE_KEY, "/value")),
Some("\"second\"".to_string())
);
assert_eq!(
store.saved_secret_utf8(&value_key(BASE_KEY, "/extra")),
Some("1".to_string())
);
let loaded = load_split_json_from_keyring(&store, SERVICE, BASE_KEY)
.expect("split JSON should load")
.expect("split JSON should exist");
assert_eq!(loaded, second);
}
#[test]
fn split_json_uses_distinct_layout_version() {
assert_eq!(LAYOUT_VERSION, "v1");
}
}

View File

@@ -45,9 +45,6 @@ use tracing::warn;
use codex_keyring_store::DefaultKeyringStore;
use codex_keyring_store::KeyringStore;
use codex_keyring_store::delete_json_from_keyring;
use codex_keyring_store::load_json_from_keyring;
use codex_keyring_store::save_json_to_keyring;
use rmcp::transport::auth::AuthorizationManager;
use tokio::sync::Mutex;
@@ -158,15 +155,16 @@ fn load_oauth_tokens_from_keyring<K: KeyringStore>(
url: &str,
) -> Result<Option<StoredOAuthTokens>> {
let key = compute_store_key(server_name, url)?;
let Some(value) = load_json_from_keyring(keyring_store, KEYRING_SERVICE, &key)
.map_err(|err| Error::msg(err.to_string()))?
else {
return Ok(None);
};
let mut tokens: StoredOAuthTokens =
serde_json::from_value(value).context("failed to deserialize OAuth tokens from keyring")?;
refresh_expires_in_from_timestamp(&mut tokens);
Ok(Some(tokens))
match keyring_store.load(KEYRING_SERVICE, &key) {
Ok(Some(serialized)) => {
let mut tokens: StoredOAuthTokens = serde_json::from_str(&serialized)
.context("failed to deserialize OAuth tokens from keyring")?;
refresh_expires_in_from_timestamp(&mut tokens);
Ok(Some(tokens))
}
Ok(None) => Ok(None),
Err(error) => Err(Error::new(error.into_error())),
}
}
pub fn save_oauth_tokens(
@@ -193,9 +191,10 @@ fn save_oauth_tokens_with_keyring<K: KeyringStore>(
server_name: &str,
tokens: &StoredOAuthTokens,
) -> Result<()> {
let value = serde_json::to_value(tokens).context("failed to serialize OAuth tokens")?;
let serialized = serde_json::to_string(tokens).context("failed to serialize OAuth tokens")?;
let key = compute_store_key(server_name, &tokens.url)?;
match save_json_to_keyring(keyring_store, KEYRING_SERVICE, &key, &value) {
match keyring_store.save(KEYRING_SERVICE, &key, &serialized) {
Ok(()) => {
if let Err(error) = delete_oauth_tokens_from_file(&key) {
warn!("failed to remove OAuth tokens from fallback storage: {error:?}");
@@ -203,9 +202,12 @@ fn save_oauth_tokens_with_keyring<K: KeyringStore>(
Ok(())
}
Err(error) => {
let message = format!("failed to write OAuth tokens to keyring: {error}");
let message = format!(
"failed to write OAuth tokens to keyring: {}",
error.message()
);
warn!("{message}");
Err(Error::msg(message))
Err(Error::new(error.into_error()).context(message))
}
}
}
@@ -242,20 +244,22 @@ fn delete_oauth_tokens_from_keyring_and_file<K: KeyringStore>(
url: &str,
) -> Result<bool> {
let key = compute_store_key(server_name, url)?;
let keyring_removed = match delete_json_from_keyring(keyring_store, KEYRING_SERVICE, &key) {
let keyring_result = keyring_store.delete(KEYRING_SERVICE, &key);
let keyring_removed = match keyring_result {
Ok(removed) => removed,
Err(error) => {
let message = error.to_string();
let message = error.message();
warn!("failed to delete OAuth tokens from keyring: {message}");
match store_mode {
OAuthCredentialsStoreMode::Auto | OAuthCredentialsStoreMode::Keyring => {
return Err(Error::msg(message))
return Err(error.into_error())
.context("failed to delete OAuth tokens from keyring");
}
OAuthCredentialsStoreMode::File => false,
}
}
};
let file_removed = delete_oauth_tokens_from_file(&key)?;
Ok(keyring_removed || file_removed)
}
@@ -600,10 +604,6 @@ fn sha_256_prefix(value: &Value) -> Result<String> {
mod tests {
use super::*;
use anyhow::Result;
use codex_keyring_store::CredentialStoreError;
use codex_keyring_store::load_json_from_keyring;
use codex_keyring_store::save_json_to_keyring;
use codex_keyring_store::save_split_json_to_keyring;
use keyring::Error as KeyringError;
use pretty_assertions::assert_eq;
use std::sync::Mutex;
@@ -614,101 +614,6 @@ mod tests {
use codex_keyring_store::tests::MockKeyringStore;
#[derive(Clone, Debug)]
struct KeyringStoreWithError {
inner: MockKeyringStore,
fail_delete: bool,
fail_load_secret: bool,
fail_save_secret: bool,
}
impl KeyringStoreWithError {
fn fail_delete(inner: MockKeyringStore) -> Self {
Self {
inner,
fail_delete: true,
fail_load_secret: false,
fail_save_secret: false,
}
}
fn fail_load_secret(inner: MockKeyringStore) -> Self {
Self {
inner,
fail_delete: false,
fail_load_secret: true,
fail_save_secret: false,
}
}
fn fail_save_secret(inner: MockKeyringStore) -> Self {
Self {
inner,
fail_delete: false,
fail_load_secret: false,
fail_save_secret: true,
}
}
}
impl KeyringStore for KeyringStoreWithError {
fn load(
&self,
service: &str,
account: &str,
) -> Result<Option<String>, CredentialStoreError> {
self.inner.load(service, account)
}
fn load_secret(
&self,
service: &str,
account: &str,
) -> Result<Option<Vec<u8>>, CredentialStoreError> {
if self.fail_load_secret {
return Err(CredentialStoreError::new(KeyringError::Invalid(
"error".into(),
"load".into(),
)));
}
self.inner.load_secret(service, account)
}
fn save(
&self,
service: &str,
account: &str,
value: &str,
) -> Result<(), CredentialStoreError> {
self.inner.save(service, account, value)
}
fn save_secret(
&self,
service: &str,
account: &str,
value: &[u8],
) -> Result<(), CredentialStoreError> {
if self.fail_save_secret {
return Err(CredentialStoreError::new(KeyringError::Invalid(
"error".into(),
"save".into(),
)));
}
self.inner.save_secret(service, account, value)
}
fn delete(&self, service: &str, account: &str) -> Result<bool, CredentialStoreError> {
if self.fail_delete {
return Err(CredentialStoreError::new(KeyringError::Invalid(
"error".into(),
"delete".into(),
)));
}
self.inner.delete(service, account)
}
}
struct TempCodexHome {
_guard: MutexGuard<'static, ()>,
_dir: tempfile::TempDir,
@@ -746,38 +651,6 @@ mod tests {
let store = MockKeyringStore::default();
let tokens = sample_tokens();
let expected = tokens.clone();
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
let value = serde_json::to_value(&tokens)?;
save_json_to_keyring(&store, KEYRING_SERVICE, &key, &value)?;
let loaded =
super::load_oauth_tokens_from_keyring(&store, &tokens.server_name, &tokens.url)?
.expect("tokens should load from keyring");
assert_tokens_match_without_expiry(&loaded, &expected);
Ok(())
}
#[test]
fn load_oauth_tokens_supports_split_json_compatibility() -> Result<()> {
let _env = TempCodexHome::new();
let store = MockKeyringStore::default();
let tokens = sample_tokens();
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
let value = serde_json::to_value(&tokens)?;
save_split_json_to_keyring(&store, KEYRING_SERVICE, &key, &value)?;
let loaded =
super::load_oauth_tokens_from_keyring(&store, &tokens.server_name, &tokens.url)?
.expect("tokens should load from split-json compatibility format");
assert_tokens_match_without_expiry(&loaded, &tokens);
Ok(())
}
#[test]
fn load_oauth_tokens_supports_legacy_single_entry() -> Result<()> {
let _env = TempCodexHome::new();
let store = MockKeyringStore::default();
let tokens = sample_tokens();
let serialized = serde_json::to_string(&tokens)?;
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
store.save(KEYRING_SERVICE, &key, &serialized)?;
@@ -785,7 +658,7 @@ mod tests {
let loaded =
super::load_oauth_tokens_from_keyring(&store, &tokens.server_name, &tokens.url)?
.expect("tokens should load from keyring");
assert_tokens_match_without_expiry(&loaded, &tokens);
assert_tokens_match_without_expiry(&loaded, &expected);
Ok(())
}
@@ -811,9 +684,11 @@ mod tests {
#[test]
fn load_oauth_tokens_falls_back_when_keyring_errors() -> Result<()> {
let _env = TempCodexHome::new();
let store = KeyringStoreWithError::fail_load_secret(MockKeyringStore::default());
let store = MockKeyringStore::default();
let tokens = sample_tokens();
let expected = tokens.clone();
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
store.set_error(&key, KeyringError::Invalid("error".into(), "load".into()));
super::save_oauth_tokens_to_file(&tokens)?;
@@ -844,29 +719,18 @@ mod tests {
let fallback_path = super::fallback_file_path()?;
assert!(!fallback_path.exists(), "fallback file should be removed");
#[cfg(windows)]
assert!(
store.saved_secret(&key).is_none(),
"windows should not store the full JSON record under the base key"
);
#[cfg(not(windows))]
assert!(
store.saved_secret(&key).is_some(),
"non-windows should store the full JSON record as one secret"
);
let stored =
super::load_oauth_tokens_from_keyring(&store, &tokens.server_name, &tokens.url)?
.expect("value saved to keyring");
assert_tokens_match_without_expiry(&stored, &tokens);
let stored = store.saved_value(&key).expect("value saved to keyring");
assert_eq!(serde_json::from_str::<StoredOAuthTokens>(&stored)?, tokens);
Ok(())
}
#[test]
fn save_oauth_tokens_writes_fallback_when_keyring_fails() -> Result<()> {
let _env = TempCodexHome::new();
let mock_keyring = MockKeyringStore::default();
let store = KeyringStoreWithError::fail_save_secret(mock_keyring.clone());
let store = MockKeyringStore::default();
let tokens = sample_tokens();
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
store.set_error(&key, KeyringError::Invalid("error".into(), "save".into()));
super::save_oauth_tokens_with_keyring_with_fallback_to_file(
&store,
@@ -886,11 +750,7 @@ mod tests {
entry.access_token,
tokens.token_response.0.access_token().secret().as_str()
);
assert!(mock_keyring.saved_value(&key).is_none());
assert!(
load_json_from_keyring(&mock_keyring, KEYRING_SERVICE, &key)?.is_none(),
"keyring should not point at saved OAuth tokens when save fails"
);
assert!(store.saved_value(&key).is_none());
Ok(())
}
@@ -899,10 +759,8 @@ mod tests {
let _env = TempCodexHome::new();
let store = MockKeyringStore::default();
let tokens = sample_tokens();
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
let value = serde_json::to_value(&tokens)?;
save_split_json_to_keyring(&store, KEYRING_SERVICE, &key, &value)?;
let serialized = serde_json::to_string(&tokens)?;
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
store.save(KEYRING_SERVICE, &key, &serialized)?;
super::save_oauth_tokens_to_file(&tokens)?;
@@ -913,10 +771,7 @@ mod tests {
&tokens.url,
)?;
assert!(removed);
assert!(
load_json_from_keyring(&store, KEYRING_SERVICE, &key)?.is_none(),
"keyring entry should be removed"
);
assert!(!store.contains(&key));
assert!(!super::fallback_file_path()?.exists());
Ok(())
}
@@ -926,13 +781,10 @@ mod tests {
let _env = TempCodexHome::new();
let store = MockKeyringStore::default();
let tokens = sample_tokens();
let serialized = serde_json::to_string(&tokens)?;
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
let value = serde_json::to_value(&tokens)?;
save_split_json_to_keyring(&store, KEYRING_SERVICE, &key, &value)?;
assert!(
super::load_oauth_tokens_from_keyring(&store, &tokens.server_name, &tokens.url)?
.is_some()
);
store.save(KEYRING_SERVICE, &key, &serialized)?;
assert!(store.contains(&key));
let removed = super::delete_oauth_tokens_from_keyring_and_file(
&store,
@@ -941,14 +793,7 @@ mod tests {
&tokens.url,
)?;
assert!(removed);
assert!(
load_json_from_keyring(&store, KEYRING_SERVICE, &key)?.is_none(),
"keyring entry should be removed"
);
assert!(
super::load_oauth_tokens_from_keyring(&store, &tokens.server_name, &tokens.url)?
.is_none()
);
assert!(!store.contains(&key));
assert!(!super::fallback_file_path()?.exists());
Ok(())
}
@@ -956,8 +801,10 @@ mod tests {
#[test]
fn delete_oauth_tokens_propagates_keyring_errors() -> Result<()> {
let _env = TempCodexHome::new();
let store = KeyringStoreWithError::fail_delete(MockKeyringStore::default());
let store = MockKeyringStore::default();
let tokens = sample_tokens();
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
store.set_error(&key, KeyringError::Invalid("error".into(), "delete".into()));
super::save_oauth_tokens_to_file(&tokens).unwrap();
let result = super::delete_oauth_tokens_from_keyring_and_file(

View File

@@ -29,3 +29,5 @@ pub type SpawnedPty = SpawnedProcess;
pub use pty::conpty_supported;
/// Spawn a process attached to a PTY for interactive use.
pub use pty::spawn_process as spawn_pty_process;
#[cfg(windows)]
pub use win::conpty::RawConPty;

View File

@@ -238,6 +238,7 @@ async fn spawn_process_with_stdin_mode(
wait_handle,
exit_status,
exit_code,
Some(pid),
None,
);

View File

@@ -79,6 +79,7 @@ pub struct ProcessHandle {
wait_handle: StdMutex<Option<JoinHandle<()>>>,
exit_status: Arc<AtomicBool>,
exit_code: Arc<StdMutex<Option<i32>>>,
pid: Option<u32>,
// PtyHandles must be preserved because the process will receive Control+C if the
// slave is closed
_pty_handles: StdMutex<Option<PtyHandles>>,
@@ -101,6 +102,7 @@ impl ProcessHandle {
wait_handle: JoinHandle<()>,
exit_status: Arc<AtomicBool>,
exit_code: Arc<StdMutex<Option<i32>>>,
pid: Option<u32>,
pty_handles: Option<PtyHandles>,
) -> Self {
Self {
@@ -112,6 +114,7 @@ impl ProcessHandle {
wait_handle: StdMutex::new(Some(wait_handle)),
exit_status,
exit_code,
pid,
_pty_handles: StdMutex::new(pty_handles),
}
}
@@ -139,6 +142,11 @@ impl ProcessHandle {
self.exit_code.lock().ok().and_then(|guard| *guard)
}
/// Returns the OS process ID when known.
pub fn pid(&self) -> Option<u32> {
self.pid
}
/// Resize the PTY in character cells.
pub fn resize(&self, size: TerminalSize) -> anyhow::Result<()> {
let handles = self

View File

@@ -159,11 +159,12 @@ async fn spawn_process_portable(
}
let mut child = pair.slave.spawn_command(command_builder)?;
let pid = child.process_id();
#[cfg(unix)]
// portable-pty establishes the spawned PTY child as a new session leader on
// Unix, so PID == PGID and we can reuse the pipe backend's process-group
// hard-kill semantics for descendants.
let process_group_id = child.process_id();
let process_group_id = pid;
let killer = child.clone_killer();
let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
@@ -241,6 +242,7 @@ async fn spawn_process_portable(
wait_handle,
exit_status,
exit_code,
pid,
Some(handles),
);
@@ -394,6 +396,7 @@ async fn spawn_process_preserving_fds(
wait_handle,
exit_status,
exit_code,
Some(process_group_id),
Some(handles),
);

View File

@@ -29,6 +29,9 @@ use portable_pty::PtyPair;
use portable_pty::PtySize;
use portable_pty::PtySystem;
use portable_pty::SlavePty;
use std::mem::ManuallyDrop;
use std::os::windows::io::AsRawHandle;
use std::os::windows::io::RawHandle;
use std::sync::Arc;
use std::sync::Mutex;
use winapi::um::wincon::COORD;
@@ -36,25 +39,68 @@ use winapi::um::wincon::COORD;
#[derive(Default)]
pub struct ConPtySystem {}
fn create_conpty_handles(
size: PtySize,
) -> anyhow::Result<(PsuedoCon, FileDescriptor, FileDescriptor)> {
let stdin = Pipe::new()?;
let stdout = Pipe::new()?;
let con = PsuedoCon::new(
COORD {
X: size.cols as i16,
Y: size.rows as i16,
},
stdin.read,
stdout.write,
)?;
Ok((con, stdin.write, stdout.read))
}
pub struct RawConPty {
con: PsuedoCon,
input_write: FileDescriptor,
output_read: FileDescriptor,
}
impl RawConPty {
pub fn new(cols: i16, rows: i16) -> anyhow::Result<Self> {
let (con, input_write, output_read) = create_conpty_handles(PtySize {
rows: rows as u16,
cols: cols as u16,
pixel_width: 0,
pixel_height: 0,
})?;
Ok(Self {
con,
input_write,
output_read,
})
}
pub fn pseudoconsole_handle(&self) -> RawHandle {
self.con.raw_handle()
}
pub fn into_raw_handles(self) -> (RawHandle, RawHandle, RawHandle) {
let me = ManuallyDrop::new(self);
(
me.con.raw_handle(),
me.input_write.as_raw_handle(),
me.output_read.as_raw_handle(),
)
}
}
impl PtySystem for ConPtySystem {
fn openpty(&self, size: PtySize) -> anyhow::Result<PtyPair> {
let stdin = Pipe::new()?;
let stdout = Pipe::new()?;
let con = PsuedoCon::new(
COORD {
X: size.cols as i16,
Y: size.rows as i16,
},
stdin.read,
stdout.write,
)?;
let (con, writable, readable) = create_conpty_handles(size)?;
let master = ConPtyMasterPty {
inner: Arc::new(Mutex::new(Inner {
con,
readable: stdout.read,
writable: Some(stdin.write),
readable,
writable: Some(writable),
size,
})),
};

View File

@@ -130,6 +130,10 @@ impl Drop for PsuedoCon {
}
impl PsuedoCon {
pub fn raw_handle(&self) -> HPCON {
self.con
}
pub fn new(size: COORD, input: FileDescriptor, output: FileDescriptor) -> Result<Self, Error> {
let mut con: HPCON = INVALID_HANDLE_VALUE;
let result = unsafe {

View File

@@ -24,12 +24,14 @@ chrono = { version = "0.4.42", default-features = false, features = [
"clock",
"std",
] }
codex-utils-pty = { workspace = true }
codex-utils-absolute-path = { workspace = true }
codex-utils-string = { workspace = true }
dunce = "1.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tempfile = "3"
tokio = { workspace = true, features = ["sync", "rt"] }
windows = { version = "0.58", features = [
"Win32_Foundation",
"Win32_NetworkManagement_WindowsFirewall",
@@ -86,3 +88,6 @@ pretty_assertions = { workspace = true }
[build-dependencies]
winres = "0.1"
[package.metadata.cargo-shear]
ignored = ["codex-utils-pty", "tokio"]

View File

@@ -1,4 +1,4 @@
#[path = "../command_runner_win.rs"]
#[path = "../elevated/command_runner_win.rs"]
mod win;
#[cfg(target_os = "windows")]

View File

@@ -1,325 +0,0 @@
#![cfg(target_os = "windows")]
use anyhow::Context;
use anyhow::Result;
use codex_windows_sandbox::allow_null_device;
use codex_windows_sandbox::convert_string_sid_to_sid;
use codex_windows_sandbox::create_process_as_user;
use codex_windows_sandbox::create_readonly_token_with_caps_from;
use codex_windows_sandbox::create_workspace_write_token_with_caps_from;
use codex_windows_sandbox::get_current_token_for_restriction;
use codex_windows_sandbox::hide_current_user_profile_dir;
use codex_windows_sandbox::log_note;
use codex_windows_sandbox::parse_policy;
use codex_windows_sandbox::to_wide;
use codex_windows_sandbox::SandboxPolicy;
use serde::Deserialize;
use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path;
use std::path::PathBuf;
use windows_sys::Win32::Foundation::CloseHandle;
use windows_sys::Win32::Foundation::GetLastError;
use windows_sys::Win32::Foundation::LocalFree;
use windows_sys::Win32::Foundation::HANDLE;
use windows_sys::Win32::Foundation::HLOCAL;
use windows_sys::Win32::Storage::FileSystem::CreateFileW;
use windows_sys::Win32::Storage::FileSystem::FILE_GENERIC_READ;
use windows_sys::Win32::Storage::FileSystem::FILE_GENERIC_WRITE;
use windows_sys::Win32::Storage::FileSystem::OPEN_EXISTING;
use windows_sys::Win32::System::Diagnostics::Debug::SetErrorMode;
use windows_sys::Win32::System::JobObjects::AssignProcessToJobObject;
use windows_sys::Win32::System::JobObjects::CreateJobObjectW;
use windows_sys::Win32::System::JobObjects::JobObjectExtendedLimitInformation;
use windows_sys::Win32::System::JobObjects::SetInformationJobObject;
use windows_sys::Win32::System::JobObjects::JOBOBJECT_EXTENDED_LIMIT_INFORMATION;
use windows_sys::Win32::System::JobObjects::JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
use windows_sys::Win32::System::Threading::TerminateProcess;
use windows_sys::Win32::System::Threading::WaitForSingleObject;
use windows_sys::Win32::System::Threading::INFINITE;
#[path = "cwd_junction.rs"]
mod cwd_junction;
#[allow(dead_code)]
mod read_acl_mutex;
#[derive(Debug, Deserialize)]
struct RunnerRequest {
policy_json_or_preset: String,
// Writable location for logs (sandbox user's .codex).
codex_home: PathBuf,
// Real user's CODEX_HOME for shared data (caps, config).
real_codex_home: PathBuf,
cap_sids: Vec<String>,
command: Vec<String>,
cwd: PathBuf,
env_map: HashMap<String, String>,
timeout_ms: Option<u64>,
use_private_desktop: bool,
stdin_pipe: String,
stdout_pipe: String,
stderr_pipe: String,
}
const WAIT_TIMEOUT: u32 = 0x0000_0102;
unsafe fn create_job_kill_on_close() -> Result<HANDLE> {
let h = CreateJobObjectW(std::ptr::null_mut(), std::ptr::null());
if h == 0 {
return Err(anyhow::anyhow!("CreateJobObjectW failed"));
}
let mut limits: JOBOBJECT_EXTENDED_LIMIT_INFORMATION = std::mem::zeroed();
limits.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
let ok = SetInformationJobObject(
h,
JobObjectExtendedLimitInformation,
&mut limits as *mut _ as *mut _,
std::mem::size_of::<JOBOBJECT_EXTENDED_LIMIT_INFORMATION>() as u32,
);
if ok == 0 {
return Err(anyhow::anyhow!("SetInformationJobObject failed"));
}
Ok(h)
}
fn read_request_file(req_path: &Path) -> Result<String> {
let content = std::fs::read_to_string(req_path)
.with_context(|| format!("read request file {}", req_path.display()));
let _ = std::fs::remove_file(req_path);
content
}
pub fn main() -> Result<()> {
let mut input = String::new();
let mut args = std::env::args().skip(1);
if let Some(first) = args.next() {
if let Some(rest) = first.strip_prefix("--request-file=") {
let req_path = PathBuf::from(rest);
input = read_request_file(&req_path)?;
}
}
if input.is_empty() {
anyhow::bail!("runner: no request-file provided");
}
let req: RunnerRequest = serde_json::from_str(&input).context("parse runner request json")?;
let log_dir = Some(req.codex_home.as_path());
hide_current_user_profile_dir(req.codex_home.as_path());
// Suppress Windows error UI from sandboxed child crashes so callers only observe exit codes.
let _ = unsafe { SetErrorMode(0x0001 | 0x0002) }; // SEM_FAILCRITICALERRORS | SEM_NOGPFAULTERRORBOX
log_note(
&format!(
"runner start cwd={} cmd={:?} real_codex_home={}",
req.cwd.display(),
req.command,
req.real_codex_home.display()
),
Some(&req.codex_home),
);
let policy = parse_policy(&req.policy_json_or_preset).context("parse policy_json_or_preset")?;
if !policy.has_full_disk_read_access() {
anyhow::bail!(
"Restricted read-only access is not yet supported by the Windows sandbox backend"
);
}
let mut cap_psids: Vec<*mut c_void> = Vec::new();
for sid in &req.cap_sids {
let Some(psid) = (unsafe { convert_string_sid_to_sid(sid) }) else {
anyhow::bail!("ConvertStringSidToSidW failed for capability SID");
};
cap_psids.push(psid);
}
if cap_psids.is_empty() {
anyhow::bail!("runner: empty capability SID list");
}
// Create restricted token from current process token.
let base = unsafe { get_current_token_for_restriction()? };
let token_res: Result<HANDLE> = unsafe {
match &policy {
SandboxPolicy::ReadOnly { .. } => {
create_readonly_token_with_caps_from(base, &cap_psids)
}
SandboxPolicy::WorkspaceWrite { .. } => {
create_workspace_write_token_with_caps_from(base, &cap_psids)
}
SandboxPolicy::DangerFullAccess | SandboxPolicy::ExternalSandbox { .. } => {
unreachable!()
}
}
};
let h_token = token_res?;
unsafe {
CloseHandle(base);
}
unsafe {
for psid in &cap_psids {
allow_null_device(*psid);
}
for psid in cap_psids {
if !psid.is_null() {
LocalFree(psid as HLOCAL);
}
}
}
// Open named pipes for stdio.
let open_pipe = |name: &str, access: u32| -> Result<HANDLE> {
let path = to_wide(name);
let handle = unsafe {
CreateFileW(
path.as_ptr(),
access,
0,
std::ptr::null_mut(),
OPEN_EXISTING,
0,
0,
)
};
if handle == windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE {
let err = unsafe { GetLastError() };
log_note(
&format!("CreateFileW failed for pipe {name}: {err}"),
Some(&req.codex_home),
);
return Err(anyhow::anyhow!("CreateFileW failed for pipe {name}: {err}"));
}
Ok(handle)
};
let h_stdin = open_pipe(&req.stdin_pipe, FILE_GENERIC_READ)?;
let h_stdout = open_pipe(&req.stdout_pipe, FILE_GENERIC_WRITE)?;
let h_stderr = open_pipe(&req.stderr_pipe, FILE_GENERIC_WRITE)?;
let stdio = Some((h_stdin, h_stdout, h_stderr));
// While the read-ACL helper is running, PowerShell can fail to start in the requested CWD due
// to unreadable ancestors. Use a junction CWD for that window; once the helper finishes, go
// back to using the real requested CWD (no probing, no extra state).
let use_junction = match read_acl_mutex::read_acl_mutex_exists() {
Ok(exists) => exists,
Err(err) => {
// Fail-safe: if we can't determine the state, assume the helper might be running and
// use the junction path to avoid CWD failures on unreadable ancestors.
log_note(
&format!("junction: read_acl_mutex_exists failed: {err}; assuming read ACL helper is running"),
log_dir,
);
true
}
};
if use_junction {
log_note(
"junction: read ACL helper running; using junction CWD",
log_dir,
);
}
let effective_cwd = if use_junction {
cwd_junction::create_cwd_junction(&req.cwd, log_dir).unwrap_or_else(|| req.cwd.clone())
} else {
req.cwd.clone()
};
log_note(
&format!(
"runner: effective cwd={} (requested {})",
effective_cwd.display(),
req.cwd.display()
),
log_dir,
);
// Build command and env, spawn with CreateProcessAsUserW.
let spawn_result = unsafe {
create_process_as_user(
h_token,
&req.command,
&effective_cwd,
&req.env_map,
Some(&req.codex_home),
stdio,
req.use_private_desktop,
)
};
let created = match spawn_result {
Ok(v) => v,
Err(e) => {
log_note(&format!("runner: spawn failed: {e:?}"), log_dir);
unsafe {
CloseHandle(h_stdin);
CloseHandle(h_stdout);
CloseHandle(h_stderr);
CloseHandle(h_token);
}
return Err(e);
}
};
let proc_info = created.process_info;
let _desktop = created;
// Optional job kill on close.
let h_job = unsafe { create_job_kill_on_close().ok() };
if let Some(job) = h_job {
unsafe {
let _ = AssignProcessToJobObject(job, proc_info.hProcess);
}
}
// Wait for process.
let wait_res = unsafe {
WaitForSingleObject(
proc_info.hProcess,
req.timeout_ms.map(|ms| ms as u32).unwrap_or(INFINITE),
)
};
let timed_out = wait_res == WAIT_TIMEOUT;
let exit_code: i32;
unsafe {
if timed_out {
let _ = TerminateProcess(proc_info.hProcess, 1);
exit_code = 128 + 64;
} else {
let mut raw_exit: u32 = 1;
windows_sys::Win32::System::Threading::GetExitCodeProcess(
proc_info.hProcess,
&mut raw_exit,
);
exit_code = raw_exit as i32;
}
if proc_info.hThread != 0 {
CloseHandle(proc_info.hThread);
}
if proc_info.hProcess != 0 {
CloseHandle(proc_info.hProcess);
}
CloseHandle(h_stdin);
CloseHandle(h_stdout);
CloseHandle(h_stderr);
CloseHandle(h_token);
if let Some(job) = h_job {
CloseHandle(job);
}
}
if exit_code != 0 {
eprintln!("runner child exited with code {}", exit_code);
}
std::process::exit(exit_code);
}
#[cfg(test)]
mod tests {
use super::read_request_file;
use pretty_assertions::assert_eq;
use std::fs;
#[test]
fn removes_request_file_after_read() {
let dir = tempfile::tempdir().expect("tempdir");
let req_path = dir.path().join("request.json");
fs::write(&req_path, "{\"ok\":true}").expect("write request");
let content = read_request_file(&req_path).expect("read request");
assert_eq!(content, "{\"ok\":true}");
assert!(!req_path.exists(), "request file should be removed");
}
}

View File

@@ -0,0 +1,139 @@
//! ConPTY helpers for spawning sandboxed processes with a PTY on Windows.
//!
//! This module encapsulates ConPTY creation and process spawn with the required
//! `PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE` plumbing. It is shared by both the legacy
//! restrictedtoken path and the elevated runner path when unified_exec runs with
//! `tty=true`. The helpers are not tied to the IPC layer and can be reused by other
//! Windows sandbox flows that need a PTY.
mod proc_thread_attr;
use self::proc_thread_attr::ProcThreadAttributeList;
use crate::winutil::format_last_error;
use crate::winutil::quote_windows_arg;
use crate::winutil::to_wide;
use anyhow::Result;
use codex_utils_pty::RawConPty;
use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path;
use windows_sys::Win32::Foundation::CloseHandle;
use windows_sys::Win32::Foundation::GetLastError;
use windows_sys::Win32::Foundation::HANDLE;
use windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE;
use windows_sys::Win32::System::Console::ClosePseudoConsole;
use windows_sys::Win32::System::Threading::CreateProcessAsUserW;
use windows_sys::Win32::System::Threading::CREATE_UNICODE_ENVIRONMENT;
use windows_sys::Win32::System::Threading::EXTENDED_STARTUPINFO_PRESENT;
use windows_sys::Win32::System::Threading::PROCESS_INFORMATION;
use windows_sys::Win32::System::Threading::STARTF_USESTDHANDLES;
use windows_sys::Win32::System::Threading::STARTUPINFOEXW;
use crate::process::make_env_block;
/// Owns a ConPTY handle and its backing pipe handles.
pub struct ConptyInstance {
pub hpc: HANDLE,
pub input_write: HANDLE,
pub output_read: HANDLE,
}
impl Drop for ConptyInstance {
fn drop(&mut self) {
unsafe {
if self.input_write != 0 && self.input_write != INVALID_HANDLE_VALUE {
CloseHandle(self.input_write);
}
if self.output_read != 0 && self.output_read != INVALID_HANDLE_VALUE {
CloseHandle(self.output_read);
}
if self.hpc != 0 && self.hpc != INVALID_HANDLE_VALUE {
ClosePseudoConsole(self.hpc);
}
}
}
}
impl ConptyInstance {
/// Consume the instance and return raw handles without closing them.
pub fn into_raw(self) -> (HANDLE, HANDLE, HANDLE) {
let me = std::mem::ManuallyDrop::new(self);
(me.hpc, me.input_write, me.output_read)
}
}
/// Create a ConPTY with backing pipes.
///
/// This is public so callers that need lower-level PTY setup can build on the same
/// primitive, although the common entry point is `spawn_conpty_process_as_user`.
pub fn create_conpty(cols: i16, rows: i16) -> Result<ConptyInstance> {
let raw = RawConPty::new(cols, rows)?;
let (hpc, input_write, output_read) = raw.into_raw_handles();
Ok(ConptyInstance {
hpc: hpc as HANDLE,
input_write: input_write as HANDLE,
output_read: output_read as HANDLE,
})
}
/// Spawn a process under `h_token` with ConPTY attached.
///
/// This is the main shared ConPTY entry point and is used by both the legacy/direct path
/// and the elevated runner path whenever a PTY-backed sandboxed process is needed.
pub fn spawn_conpty_process_as_user(
h_token: HANDLE,
argv: &[String],
cwd: &Path,
env_map: &HashMap<String, String>,
) -> Result<(PROCESS_INFORMATION, ConptyInstance)> {
let cmdline_str = argv
.iter()
.map(|arg| quote_windows_arg(arg))
.collect::<Vec<_>>()
.join(" ");
let mut cmdline: Vec<u16> = to_wide(&cmdline_str);
let env_block = make_env_block(env_map);
let mut si: STARTUPINFOEXW = unsafe { std::mem::zeroed() };
si.StartupInfo.cb = std::mem::size_of::<STARTUPINFOEXW>() as u32;
si.StartupInfo.dwFlags = STARTF_USESTDHANDLES;
si.StartupInfo.hStdInput = INVALID_HANDLE_VALUE;
si.StartupInfo.hStdOutput = INVALID_HANDLE_VALUE;
si.StartupInfo.hStdError = INVALID_HANDLE_VALUE;
let desktop = to_wide("Winsta0\\Default");
si.StartupInfo.lpDesktop = desktop.as_ptr() as *mut u16;
let conpty = create_conpty(80, 24)?;
let mut attrs = ProcThreadAttributeList::new(1)?;
attrs.set_pseudoconsole(conpty.hpc)?;
si.lpAttributeList = attrs.as_mut_ptr();
let mut pi: PROCESS_INFORMATION = unsafe { std::mem::zeroed() };
let ok = unsafe {
CreateProcessAsUserW(
h_token,
std::ptr::null(),
cmdline.as_mut_ptr(),
std::ptr::null_mut(),
std::ptr::null_mut(),
0,
EXTENDED_STARTUPINFO_PRESENT | CREATE_UNICODE_ENVIRONMENT,
env_block.as_ptr() as *mut c_void,
to_wide(cwd).as_ptr(),
&si.StartupInfo,
&mut pi,
)
};
if ok == 0 {
let err = unsafe { GetLastError() } as i32;
return Err(anyhow::anyhow!(
"CreateProcessAsUserW failed: {} ({}) | cwd={} | cmd={} | env_u16_len={}",
err,
format_last_error(err),
cwd.display(),
cmdline_str,
env_block.len()
));
}
Ok((pi, conpty))
}

View File

@@ -0,0 +1,79 @@
//! Low-level Windows thread attribute helpers used by ConPTY spawn.
//!
//! This module wraps the Win32 `PROC_THREAD_ATTRIBUTE_LIST` APIs so ConPTY handles can
//! be attached to a child process. It is ConPTYspecific and used in both legacy and
//! elevated unified_exec paths when spawning a PTYbacked process.
use std::io;
use windows_sys::Win32::Foundation::GetLastError;
use windows_sys::Win32::System::Threading::DeleteProcThreadAttributeList;
use windows_sys::Win32::System::Threading::InitializeProcThreadAttributeList;
use windows_sys::Win32::System::Threading::UpdateProcThreadAttribute;
use windows_sys::Win32::System::Threading::LPPROC_THREAD_ATTRIBUTE_LIST;
const PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE: usize = 0x00020016;
/// RAII wrapper for Windows PROC_THREAD_ATTRIBUTE_LIST.
pub struct ProcThreadAttributeList {
buffer: Vec<u8>,
}
impl ProcThreadAttributeList {
/// Allocate and initialize a thread attribute list.
pub fn new(attr_count: u32) -> io::Result<Self> {
let mut size: usize = 0;
unsafe {
InitializeProcThreadAttributeList(std::ptr::null_mut(), attr_count, 0, &mut size);
}
if size == 0 {
return Err(io::Error::from_raw_os_error(unsafe {
GetLastError() as i32
}));
}
let mut buffer = vec![0u8; size];
let list = buffer.as_mut_ptr() as LPPROC_THREAD_ATTRIBUTE_LIST;
let ok = unsafe { InitializeProcThreadAttributeList(list, attr_count, 0, &mut size) };
if ok == 0 {
return Err(io::Error::from_raw_os_error(unsafe {
GetLastError() as i32
}));
}
Ok(Self { buffer })
}
/// Return a mutable pointer to the attribute list for Win32 APIs.
pub fn as_mut_ptr(&mut self) -> LPPROC_THREAD_ATTRIBUTE_LIST {
self.buffer.as_mut_ptr() as LPPROC_THREAD_ATTRIBUTE_LIST
}
/// Attach a ConPTY handle to the attribute list.
pub fn set_pseudoconsole(&mut self, hpc: isize) -> io::Result<()> {
let list = self.as_mut_ptr();
let mut hpc_value = hpc;
let ok = unsafe {
UpdateProcThreadAttribute(
list,
0,
PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE,
(&mut hpc_value as *mut isize).cast(),
std::mem::size_of::<isize>(),
std::ptr::null_mut(),
std::ptr::null_mut(),
)
};
if ok == 0 {
return Err(io::Error::from_raw_os_error(unsafe {
GetLastError() as i32
}));
}
Ok(())
}
}
impl Drop for ProcThreadAttributeList {
fn drop(&mut self) {
unsafe {
DeleteProcThreadAttributeList(self.as_mut_ptr());
}
}
}

View File

@@ -0,0 +1,742 @@
//! Windows command runner used by the **elevated** sandbox path.
//!
//! The CLI launches this binary under the sandbox user when Windows sandbox level is
//! Elevated. It connects to the IPC pipes, reads the framed `SpawnRequest`, derives a
//! restricted token from the sandbox user, and spawns the child process via ConPTY
//! (`tty=true`) or pipes (`tty=false`). It then streams output frames back to the parent,
//! accepts stdin/terminate frames, and emits a final exit frame. The legacy restrictedtoken
//! path spawns the child directly and does not use this runner.
#![cfg(target_os = "windows")]
use anyhow::Context;
use anyhow::Result;
use codex_windows_sandbox::allow_null_device;
use codex_windows_sandbox::convert_string_sid_to_sid;
use codex_windows_sandbox::create_process_as_user;
use codex_windows_sandbox::create_readonly_token_with_caps_from;
use codex_windows_sandbox::create_workspace_write_token_with_caps_from;
use codex_windows_sandbox::get_current_token_for_restriction;
use codex_windows_sandbox::hide_current_user_profile_dir;
use codex_windows_sandbox::ipc_framed::decode_bytes;
use codex_windows_sandbox::ipc_framed::encode_bytes;
use codex_windows_sandbox::ipc_framed::read_frame;
use codex_windows_sandbox::ipc_framed::write_frame;
use codex_windows_sandbox::ipc_framed::ErrorPayload;
use codex_windows_sandbox::ipc_framed::ExitPayload;
use codex_windows_sandbox::ipc_framed::FramedMessage;
use codex_windows_sandbox::ipc_framed::Message;
use codex_windows_sandbox::ipc_framed::OutputPayload;
use codex_windows_sandbox::ipc_framed::OutputStream;
use codex_windows_sandbox::log_note;
use codex_windows_sandbox::parse_policy;
use codex_windows_sandbox::read_handle_loop;
use codex_windows_sandbox::spawn_process_with_pipes;
use codex_windows_sandbox::to_wide;
use codex_windows_sandbox::PipeSpawnHandles;
use codex_windows_sandbox::SandboxPolicy;
use codex_windows_sandbox::StderrMode;
use codex_windows_sandbox::StdinMode;
use serde::Deserialize;
use std::collections::HashMap;
use std::ffi::c_void;
use std::fs::File;
use std::os::windows::io::FromRawHandle;
use std::path::Path;
use std::path::PathBuf;
use std::ptr;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use windows_sys::Win32::Foundation::CloseHandle;
use windows_sys::Win32::Foundation::GetLastError;
use windows_sys::Win32::Foundation::LocalFree;
use windows_sys::Win32::Foundation::HANDLE;
use windows_sys::Win32::Foundation::HLOCAL;
use windows_sys::Win32::Storage::FileSystem::CreateFileW;
use windows_sys::Win32::Storage::FileSystem::FILE_GENERIC_READ;
use windows_sys::Win32::Storage::FileSystem::FILE_GENERIC_WRITE;
use windows_sys::Win32::Storage::FileSystem::OPEN_EXISTING;
use windows_sys::Win32::System::Console::ClosePseudoConsole;
use windows_sys::Win32::System::JobObjects::AssignProcessToJobObject;
use windows_sys::Win32::System::JobObjects::CreateJobObjectW;
use windows_sys::Win32::System::JobObjects::JobObjectExtendedLimitInformation;
use windows_sys::Win32::System::JobObjects::SetInformationJobObject;
use windows_sys::Win32::System::JobObjects::JOBOBJECT_EXTENDED_LIMIT_INFORMATION;
use windows_sys::Win32::System::JobObjects::JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
use windows_sys::Win32::System::Threading::GetExitCodeProcess;
use windows_sys::Win32::System::Threading::GetProcessId;
use windows_sys::Win32::System::Threading::TerminateProcess;
use windows_sys::Win32::System::Threading::WaitForSingleObject;
use windows_sys::Win32::System::Threading::INFINITE;
use windows_sys::Win32::System::Threading::PROCESS_INFORMATION;
#[path = "cwd_junction.rs"]
mod cwd_junction;
#[allow(dead_code)]
#[path = "../read_acl_mutex.rs"]
mod read_acl_mutex;
#[derive(Debug, Deserialize)]
struct RunnerRequest {
policy_json_or_preset: String,
codex_home: PathBuf,
real_codex_home: PathBuf,
cap_sids: Vec<String>,
command: Vec<String>,
cwd: PathBuf,
env_map: HashMap<String, String>,
timeout_ms: Option<u64>,
use_private_desktop: bool,
stdin_pipe: String,
stdout_pipe: String,
stderr_pipe: String,
}
const WAIT_TIMEOUT: u32 = 0x0000_0102;
struct IpcSpawnedProcess {
log_dir: PathBuf,
pi: PROCESS_INFORMATION,
stdout_handle: HANDLE,
stderr_handle: HANDLE,
stdin_handle: Option<HANDLE>,
hpc_handle: Option<HANDLE>,
}
unsafe fn create_job_kill_on_close() -> Result<HANDLE> {
let h = CreateJobObjectW(std::ptr::null_mut(), std::ptr::null());
if h == 0 {
return Err(anyhow::anyhow!("CreateJobObjectW failed"));
}
let mut limits: JOBOBJECT_EXTENDED_LIMIT_INFORMATION = std::mem::zeroed();
limits.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
let ok = SetInformationJobObject(
h,
JobObjectExtendedLimitInformation,
&mut limits as *mut _ as *mut _,
std::mem::size_of::<JOBOBJECT_EXTENDED_LIMIT_INFORMATION>() as u32,
);
if ok == 0 {
return Err(anyhow::anyhow!("SetInformationJobObject failed"));
}
Ok(h)
}
/// Open a named pipe created by the parent process.
fn open_pipe(name: &str, access: u32) -> Result<HANDLE> {
let path = to_wide(name);
let handle = unsafe {
CreateFileW(
path.as_ptr(),
access,
0,
std::ptr::null_mut(),
OPEN_EXISTING,
0,
0,
)
};
if handle == windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE {
let err = unsafe { GetLastError() };
return Err(anyhow::anyhow!("CreateFileW failed for pipe {name}: {err}"));
}
Ok(handle)
}
fn read_request_file(req_path: &Path) -> Result<String> {
let content = std::fs::read_to_string(req_path)
.with_context(|| format!("read request file {}", req_path.display()));
let _ = std::fs::remove_file(req_path);
content
}
/// Send an error frame back to the parent process.
fn send_error(writer: &Arc<StdMutex<File>>, code: &str, message: String) -> Result<()> {
let msg = FramedMessage {
version: 1,
message: Message::Error {
payload: ErrorPayload {
message,
code: code.to_string(),
},
},
};
if let Ok(mut guard) = writer.lock() {
write_frame(&mut *guard, &msg)?;
}
Ok(())
}
/// Read and validate the initial spawn request frame.
fn read_spawn_request(
reader: &mut File,
) -> Result<codex_windows_sandbox::ipc_framed::SpawnRequest> {
let Some(msg) = read_frame(reader)? else {
anyhow::bail!("runner: pipe closed before spawn_request");
};
if msg.version != 1 {
anyhow::bail!("runner: unsupported protocol version {}", msg.version);
}
match msg.message {
Message::SpawnRequest { payload } => Ok(*payload),
other => anyhow::bail!("runner: expected spawn_request, got {other:?}"),
}
}
/// Pick an effective CWD, using a junction if the ACL helper is active.
fn effective_cwd(req_cwd: &Path, log_dir: Option<&Path>) -> PathBuf {
let use_junction = match read_acl_mutex::read_acl_mutex_exists() {
Ok(exists) => exists,
Err(err) => {
log_note(
&format!(
"junction: read_acl_mutex_exists failed: {err}; assuming read ACL helper is running"
),
log_dir,
);
true
}
};
if use_junction {
log_note(
"junction: read ACL helper running; using junction CWD",
log_dir,
);
cwd_junction::create_cwd_junction(req_cwd, log_dir).unwrap_or_else(|| req_cwd.to_path_buf())
} else {
req_cwd.to_path_buf()
}
}
fn spawn_ipc_process(
req: &codex_windows_sandbox::ipc_framed::SpawnRequest,
) -> Result<IpcSpawnedProcess> {
let log_dir = req.codex_home.clone();
hide_current_user_profile_dir(req.codex_home.as_path());
log_note(
&format!(
"runner start cwd={} cmd={:?} real_codex_home={}",
req.cwd.display(),
req.command,
req.real_codex_home.display()
),
Some(&req.codex_home),
);
let policy = parse_policy(&req.policy_json_or_preset).context("parse policy_json_or_preset")?;
if !policy.has_full_disk_read_access() {
anyhow::bail!(
"Restricted read-only access is not yet supported by the Windows sandbox backend"
);
}
let mut cap_psids: Vec<*mut c_void> = Vec::new();
for sid in &req.cap_sids {
let Some(psid) = (unsafe { convert_string_sid_to_sid(sid) }) else {
anyhow::bail!("ConvertStringSidToSidW failed for capability SID");
};
cap_psids.push(psid);
}
if cap_psids.is_empty() {
anyhow::bail!("runner: empty capability SID list");
}
let base = unsafe { get_current_token_for_restriction()? };
let token_res: Result<(HANDLE, *mut c_void)> = unsafe {
match &policy {
SandboxPolicy::ReadOnly { .. } => {
create_readonly_token_with_caps_from(base, &cap_psids)
.map(|h_token| (h_token, cap_psids[0]))
}
SandboxPolicy::WorkspaceWrite { .. } => {
create_workspace_write_token_with_caps_from(base, &cap_psids)
.map(|h_token| (h_token, cap_psids[0]))
}
SandboxPolicy::DangerFullAccess | SandboxPolicy::ExternalSandbox { .. } => {
unreachable!()
}
}
};
let (h_token, psid_to_use) = token_res?;
unsafe {
CloseHandle(base);
allow_null_device(psid_to_use);
for psid in &cap_psids {
allow_null_device(*psid);
}
for psid in cap_psids {
if !psid.is_null() {
LocalFree(psid as HLOCAL);
}
}
}
let effective_cwd = effective_cwd(&req.cwd, Some(log_dir.as_path()));
log_note(
&format!(
"runner: effective cwd={} (requested {})",
effective_cwd.display(),
req.cwd.display()
),
Some(log_dir.as_path()),
);
let mut hpc_handle: Option<HANDLE> = None;
let (pi, stdout_handle, stderr_handle, stdin_handle) = if req.tty {
let (pi, conpty) = codex_windows_sandbox::spawn_conpty_process_as_user(
h_token,
&req.command,
&effective_cwd,
&req.env,
)?;
let (hpc, input_write, output_read) = conpty.into_raw();
hpc_handle = Some(hpc);
let stdin_handle = if req.stdin_open {
Some(input_write)
} else {
unsafe {
CloseHandle(input_write);
}
None
};
(
pi,
output_read,
windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE,
stdin_handle,
)
} else {
let stdin_mode = if req.stdin_open {
StdinMode::Open
} else {
StdinMode::Closed
};
let pipe_handles: PipeSpawnHandles = spawn_process_with_pipes(
h_token,
&req.command,
&effective_cwd,
&req.env,
stdin_mode,
StderrMode::Separate,
)?;
(
pipe_handles.process,
pipe_handles.stdout_read,
pipe_handles
.stderr_read
.unwrap_or(windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE),
pipe_handles.stdin_write,
)
};
unsafe {
CloseHandle(h_token);
}
Ok(IpcSpawnedProcess {
log_dir,
pi,
stdout_handle,
stderr_handle,
stdin_handle,
hpc_handle,
})
}
/// Stream stdout/stderr from the child into Output frames.
fn spawn_output_reader(
writer: Arc<StdMutex<File>>,
handle: HANDLE,
stream: OutputStream,
log_dir: Option<PathBuf>,
) -> std::thread::JoinHandle<()> {
read_handle_loop(handle, move |chunk| {
let msg = FramedMessage {
version: 1,
message: Message::Output {
payload: OutputPayload {
data_b64: encode_bytes(chunk),
stream,
},
},
};
if let Ok(mut guard) = writer.lock() {
if let Err(err) = write_frame(&mut *guard, &msg) {
log_note(
&format!("runner output write failed: {err}"),
log_dir.as_deref(),
);
}
}
})
}
/// Read stdin/terminate frames and forward to the child process.
fn spawn_input_loop(
mut reader: File,
stdin_handle: Option<HANDLE>,
process_handle: Arc<StdMutex<Option<HANDLE>>>,
log_dir: Option<PathBuf>,
) -> std::thread::JoinHandle<()> {
std::thread::spawn(move || {
loop {
let msg = match read_frame(&mut reader) {
Ok(Some(v)) => v,
Ok(None) => break,
Err(err) => {
log_note(
&format!("runner input read failed: {err}"),
log_dir.as_deref(),
);
break;
}
};
match msg.message {
Message::Stdin { payload } => {
let Ok(bytes) = decode_bytes(&payload.data_b64) else {
continue;
};
if let Some(handle) = stdin_handle {
let mut written: u32 = 0;
unsafe {
let _ = windows_sys::Win32::Storage::FileSystem::WriteFile(
handle,
bytes.as_ptr(),
bytes.len() as u32,
&mut written,
ptr::null_mut(),
);
}
}
}
Message::Terminate { .. } => {
if let Ok(guard) = process_handle.lock() {
if let Some(handle) = guard.as_ref() {
unsafe {
let _ = TerminateProcess(*handle, 1);
}
}
}
}
Message::SpawnRequest { .. } => {}
Message::SpawnReady { .. } => {}
Message::Output { .. } => {}
Message::Exit { .. } => {}
Message::Error { .. } => {}
}
}
if let Some(handle) = stdin_handle {
unsafe {
CloseHandle(handle);
}
}
})
}
/// Entry point for the Windows command runner process.
pub fn main() -> Result<()> {
let mut request_file = None;
let mut pipe_in = None;
let mut pipe_out = None;
let mut pipe_single = None;
for arg in std::env::args().skip(1) {
if let Some(rest) = arg.strip_prefix("--request-file=") {
request_file = Some(rest.to_string());
} else if let Some(rest) = arg.strip_prefix("--pipe-in=") {
pipe_in = Some(rest.to_string());
} else if let Some(rest) = arg.strip_prefix("--pipe-out=") {
pipe_out = Some(rest.to_string());
} else if let Some(rest) = arg.strip_prefix("--pipe=") {
pipe_single = Some(rest.to_string());
}
}
if pipe_in.is_none() && pipe_out.is_none() {
if let Some(single) = pipe_single {
pipe_in = Some(single.clone());
pipe_out = Some(single);
}
}
if let Some(request_file) = request_file {
let req_path = PathBuf::from(request_file);
let input = read_request_file(&req_path)?;
let req: RunnerRequest =
serde_json::from_str(&input).context("parse runner request json")?;
let log_dir = Some(req.codex_home.as_path());
hide_current_user_profile_dir(req.codex_home.as_path());
log_note(
&format!(
"runner start cwd={} cmd={:?} real_codex_home={}",
req.cwd.display(),
req.command,
req.real_codex_home.display()
),
Some(&req.codex_home),
);
let policy =
parse_policy(&req.policy_json_or_preset).context("parse policy_json_or_preset")?;
if !policy.has_full_disk_read_access() {
anyhow::bail!(
"Restricted read-only access is not yet supported by the Windows sandbox backend"
);
}
let mut cap_psids: Vec<*mut c_void> = Vec::new();
for sid in &req.cap_sids {
let Some(psid) = (unsafe { convert_string_sid_to_sid(sid) }) else {
anyhow::bail!("ConvertStringSidToSidW failed for capability SID");
};
cap_psids.push(psid);
}
if cap_psids.is_empty() {
anyhow::bail!("runner: empty capability SID list");
}
let base = unsafe { get_current_token_for_restriction()? };
let token_res: Result<HANDLE> = unsafe {
match &policy {
SandboxPolicy::ReadOnly { .. } => {
create_readonly_token_with_caps_from(base, &cap_psids)
}
SandboxPolicy::WorkspaceWrite { .. } => {
create_workspace_write_token_with_caps_from(base, &cap_psids)
}
SandboxPolicy::DangerFullAccess | SandboxPolicy::ExternalSandbox { .. } => {
unreachable!()
}
}
};
let h_token = token_res?;
unsafe {
CloseHandle(base);
for psid in &cap_psids {
allow_null_device(*psid);
}
for psid in cap_psids {
if !psid.is_null() {
LocalFree(psid as HLOCAL);
}
}
}
let h_stdin = open_pipe(&req.stdin_pipe, FILE_GENERIC_READ)?;
let h_stdout = open_pipe(&req.stdout_pipe, FILE_GENERIC_WRITE)?;
let h_stderr = open_pipe(&req.stderr_pipe, FILE_GENERIC_WRITE)?;
let stdio = Some((h_stdin, h_stdout, h_stderr));
let effective_cwd = effective_cwd(&req.cwd, log_dir);
log_note(
&format!(
"runner: effective cwd={} (requested {})",
effective_cwd.display(),
req.cwd.display()
),
log_dir,
);
let spawn_result = unsafe {
create_process_as_user(
h_token,
&req.command,
&effective_cwd,
&req.env_map,
Some(&req.codex_home),
stdio,
req.use_private_desktop,
)
};
let created = match spawn_result {
Ok(v) => v,
Err(err) => {
log_note(&format!("runner: spawn failed: {err:?}"), log_dir);
unsafe {
CloseHandle(h_stdin);
CloseHandle(h_stdout);
CloseHandle(h_stderr);
CloseHandle(h_token);
}
return Err(err);
}
};
let proc_info = created.process_info;
let h_job = unsafe { create_job_kill_on_close().ok() };
if let Some(job) = h_job {
unsafe {
let _ = AssignProcessToJobObject(job, proc_info.hProcess);
}
}
let wait_res = unsafe {
WaitForSingleObject(
proc_info.hProcess,
req.timeout_ms.map(|ms| ms as u32).unwrap_or(INFINITE),
)
};
let timed_out = wait_res == WAIT_TIMEOUT;
let exit_code: i32;
unsafe {
if timed_out {
let _ = TerminateProcess(proc_info.hProcess, 1);
exit_code = 128 + 64;
} else {
let mut raw_exit: u32 = 1;
GetExitCodeProcess(proc_info.hProcess, &mut raw_exit);
exit_code = raw_exit as i32;
}
if proc_info.hThread != 0 {
CloseHandle(proc_info.hThread);
}
if proc_info.hProcess != 0 {
CloseHandle(proc_info.hProcess);
}
CloseHandle(h_stdin);
CloseHandle(h_stdout);
CloseHandle(h_stderr);
CloseHandle(h_token);
if let Some(job) = h_job {
CloseHandle(job);
}
}
if exit_code != 0 {
eprintln!("runner child exited with code {exit_code}");
}
std::process::exit(exit_code);
}
let Some(pipe_in) = pipe_in else {
anyhow::bail!("runner: no pipe-in provided");
};
let Some(pipe_out) = pipe_out else {
anyhow::bail!("runner: no pipe-out provided");
};
let h_pipe_in = open_pipe(&pipe_in, FILE_GENERIC_READ)?;
let h_pipe_out = open_pipe(&pipe_out, FILE_GENERIC_WRITE)?;
let mut pipe_read = unsafe { File::from_raw_handle(h_pipe_in as _) };
let pipe_write = Arc::new(StdMutex::new(unsafe {
File::from_raw_handle(h_pipe_out as _)
}));
let req = match read_spawn_request(&mut pipe_read) {
Ok(v) => v,
Err(err) => {
let _ = send_error(&pipe_write, "spawn_failed", err.to_string());
return Err(err);
}
};
let ipc_spawn = match spawn_ipc_process(&req) {
Ok(value) => value,
Err(err) => {
let _ = send_error(&pipe_write, "spawn_failed", err.to_string());
return Err(err);
}
};
let log_dir = Some(ipc_spawn.log_dir.as_path());
let pi = ipc_spawn.pi;
let stdout_handle = ipc_spawn.stdout_handle;
let stderr_handle = ipc_spawn.stderr_handle;
let stdin_handle = ipc_spawn.stdin_handle;
let hpc_handle = ipc_spawn.hpc_handle;
let h_job = unsafe { create_job_kill_on_close().ok() };
if let Some(job) = h_job {
unsafe {
let _ = AssignProcessToJobObject(job, pi.hProcess);
}
}
let process_handle = Arc::new(StdMutex::new(Some(pi.hProcess)));
let msg = FramedMessage {
version: 1,
message: Message::SpawnReady {
payload: codex_windows_sandbox::ipc_framed::SpawnReady {
process_id: unsafe { GetProcessId(pi.hProcess) },
},
},
};
if let Err(err) = if let Ok(mut guard) = pipe_write.lock() {
write_frame(&mut *guard, &msg)
} else {
anyhow::bail!("runner spawn_ready write failed: pipe_write lock poisoned");
} {
log_note(&format!("runner spawn_ready write failed: {err}"), log_dir);
let _ = send_error(&pipe_write, "spawn_failed", err.to_string());
return Err(err);
}
let log_dir_owned = log_dir.map(|p| p.to_path_buf());
let out_thread = spawn_output_reader(
Arc::clone(&pipe_write),
stdout_handle,
OutputStream::Stdout,
log_dir_owned.clone(),
);
let err_thread = if stderr_handle != windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE {
Some(spawn_output_reader(
Arc::clone(&pipe_write),
stderr_handle,
OutputStream::Stderr,
log_dir_owned.clone(),
))
} else {
None
};
let _input_thread = spawn_input_loop(
pipe_read,
stdin_handle,
Arc::clone(&process_handle),
log_dir_owned,
);
let timeout = req.timeout_ms.map(|ms| ms as u32).unwrap_or(INFINITE);
let wait_res = unsafe { WaitForSingleObject(pi.hProcess, timeout) };
let timed_out = wait_res == WAIT_TIMEOUT;
let exit_code: i32;
unsafe {
if timed_out {
let _ = TerminateProcess(pi.hProcess, 1);
exit_code = 128 + 64;
} else {
let mut raw_exit: u32 = 1;
GetExitCodeProcess(pi.hProcess, &mut raw_exit);
exit_code = raw_exit as i32;
}
if let Some(hpc) = hpc_handle {
ClosePseudoConsole(hpc);
}
if pi.hThread != 0 {
CloseHandle(pi.hThread);
}
if pi.hProcess != 0 {
CloseHandle(pi.hProcess);
}
if let Some(job) = h_job {
CloseHandle(job);
}
}
let _ = out_thread.join();
if let Some(err_thread) = err_thread {
let _ = err_thread.join();
}
let exit_msg = FramedMessage {
version: 1,
message: Message::Exit {
payload: ExitPayload {
exit_code,
timed_out,
},
},
};
if let Ok(mut guard) = pipe_write.lock() {
if let Err(err) = write_frame(&mut *guard, &exit_msg) {
log_note(&format!("runner exit write failed: {err}"), log_dir);
}
}
std::process::exit(exit_code);
}

View File

@@ -0,0 +1,181 @@
//! Framed IPC protocol used between the parent (CLI) and the elevated command runner.
//!
//! This module defines the JSON message schema (spawn request/ready, output, stdin,
//! exit, error, terminate) plus lengthprefixed framing helpers for a byte stream.
//! It is **elevated-path only**: the parent uses it to bootstrap the runner and
//! stream unified_exec I/O over named pipes. The legacy restrictedtoken path does
//! not use this protocol, and nonunified exec capture uses it only when running
//! through the elevated runner.
use anyhow::Result;
use base64::engine::general_purpose::STANDARD;
use base64::Engine as _;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::io::Read;
use std::io::Write;
use std::path::PathBuf;
/// Safety cap for a single framed message payload.
///
/// This is not a protocol requirement; it simply bounds memory use and rejects
/// obviously invalid frames.
const MAX_FRAME_LEN: usize = 8 * 1024 * 1024;
/// Length-prefixed, JSON-encoded frame.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FramedMessage {
pub version: u8,
#[serde(flatten)]
pub message: Message,
}
/// IPC message variants exchanged between parent and runner.
///
/// `SpawnRequest`, `Stdin`, and `Terminate` are parent->runner commands. `SpawnReady`,
/// `Output`, `Exit`, and `Error` are runner->parent events/results.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Message {
SpawnRequest { payload: Box<SpawnRequest> },
SpawnReady { payload: SpawnReady },
Output { payload: OutputPayload },
Stdin { payload: StdinPayload },
Exit { payload: ExitPayload },
Error { payload: ErrorPayload },
Terminate { payload: EmptyPayload },
}
/// Spawn parameters sent from parent to runner.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SpawnRequest {
pub command: Vec<String>,
pub cwd: PathBuf,
pub env: HashMap<String, String>,
pub policy_json_or_preset: String,
pub sandbox_policy_cwd: PathBuf,
pub codex_home: PathBuf,
pub real_codex_home: PathBuf,
pub cap_sids: Vec<String>,
pub timeout_ms: Option<u64>,
pub tty: bool,
#[serde(default)]
pub stdin_open: bool,
}
/// Ack from runner after it spawns the child process.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SpawnReady {
pub process_id: u32,
}
/// Output data sent from runner to parent.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OutputPayload {
pub data_b64: String,
pub stream: OutputStream,
}
/// Output stream identifier for `OutputPayload`.
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum OutputStream {
Stdout,
Stderr,
}
/// Stdin bytes sent from parent to runner.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StdinPayload {
pub data_b64: String,
}
/// Exit status sent from runner to parent.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ExitPayload {
pub exit_code: i32,
pub timed_out: bool,
}
/// Error payload sent when the runner fails to spawn or stream.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ErrorPayload {
pub message: String,
pub code: String,
}
/// Empty payload for control messages.
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
pub struct EmptyPayload {}
/// Base64-encode raw bytes for IPC payloads.
pub fn encode_bytes(data: &[u8]) -> String {
STANDARD.encode(data)
}
/// Decode base64 payload data into raw bytes.
pub fn decode_bytes(data: &str) -> Result<Vec<u8>> {
Ok(STANDARD.decode(data.as_bytes())?)
}
/// Write a length-prefixed JSON frame.
pub fn write_frame<W: Write>(mut writer: W, msg: &FramedMessage) -> Result<()> {
let payload = serde_json::to_vec(msg)?;
if payload.len() > MAX_FRAME_LEN {
anyhow::bail!("frame too large: {}", payload.len());
}
let len = payload.len() as u32;
writer.write_all(&len.to_le_bytes())?;
writer.write_all(&payload)?;
writer.flush()?;
Ok(())
}
/// Read a length-prefixed JSON frame; returns `Ok(None)` on EOF.
pub fn read_frame<R: Read>(mut reader: R) -> Result<Option<FramedMessage>> {
let mut len_buf = [0u8; 4];
match reader.read_exact(&mut len_buf) {
Ok(()) => {}
Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
Err(err) => return Err(err.into()),
}
let len = u32::from_le_bytes(len_buf) as usize;
if len > MAX_FRAME_LEN {
anyhow::bail!("frame too large: {}", len);
}
let mut payload = vec![0u8; len];
reader.read_exact(&mut payload)?;
let msg: FramedMessage = serde_json::from_slice(&payload)?;
Ok(Some(msg))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn framed_round_trip() {
let msg = FramedMessage {
version: 1,
message: Message::Output {
payload: OutputPayload {
data_b64: encode_bytes(b"hello"),
stream: OutputStream::Stdout,
},
},
};
let mut buf = Vec::new();
write_frame(&mut buf, &msg).expect("write");
let decoded = read_frame(buf.as_slice()).expect("read").expect("some");
assert_eq!(decoded.version, 1);
match decoded.message {
Message::Output { payload } => {
assert_eq!(payload.stream, OutputStream::Stdout);
let data = decode_bytes(&payload.data_b64).expect("decode");
assert_eq!(data, b"hello");
}
other => panic!("unexpected message: {other:?}"),
}
}
}

View File

@@ -0,0 +1,111 @@
//! Named pipe helpers for the elevated Windows sandbox runner.
//!
//! This module generates paired pipe names, creates serverside pipes with permissive
//! ACLs, and waits for the runner to connect. It is **elevated-path only** and is
//! used by the parent to establish the IPC channel for both unified_exec sessions
//! and elevated capture. The legacy restrictedtoken path spawns the child directly
//! and does not use these helpers.
use crate::helper_materialization::HelperExecutable;
use crate::helper_materialization::resolve_helper_for_launch;
use crate::winutil::resolve_sid;
use crate::winutil::string_from_sid_bytes;
use crate::winutil::to_wide;
use rand::Rng;
use rand::SeedableRng;
use rand::rngs::SmallRng;
use std::io;
use std::path::Path;
use std::path::PathBuf;
use std::ptr;
use windows_sys::Win32::Foundation::GetLastError;
use windows_sys::Win32::Foundation::HANDLE;
use windows_sys::Win32::Security::Authorization::ConvertStringSecurityDescriptorToSecurityDescriptorW;
use windows_sys::Win32::Security::PSECURITY_DESCRIPTOR;
use windows_sys::Win32::Security::SECURITY_ATTRIBUTES;
use windows_sys::Win32::System::Pipes::ConnectNamedPipe;
use windows_sys::Win32::System::Pipes::CreateNamedPipeW;
use windows_sys::Win32::System::Pipes::PIPE_READMODE_BYTE;
use windows_sys::Win32::System::Pipes::PIPE_TYPE_BYTE;
use windows_sys::Win32::System::Pipes::PIPE_WAIT;
/// PIPE_ACCESS_INBOUND (win32 constant), not exposed in windows-sys 0.52.
pub const PIPE_ACCESS_INBOUND: u32 = 0x0000_0001;
/// PIPE_ACCESS_OUTBOUND (win32 constant), not exposed in windows-sys 0.52.
pub const PIPE_ACCESS_OUTBOUND: u32 = 0x0000_0002;
/// Resolves the elevated command runner path, preferring the copied helper under
/// `.sandbox-bin` and falling back to the legacy sibling lookup when needed.
pub fn find_runner_exe(codex_home: &Path, log_dir: Option<&Path>) -> PathBuf {
resolve_helper_for_launch(HelperExecutable::CommandRunner, codex_home, log_dir)
}
/// Generates a unique named-pipe path used to communicate with the runner process.
pub fn pipe_pair() -> (String, String) {
let mut rng = SmallRng::from_entropy();
let base = format!(r"\\.\pipe\codex-runner-{:x}", rng.gen::<u128>());
(format!("{base}-in"), format!("{base}-out"))
}
/// Creates a named pipe whose DACL only allows the sandbox user to connect.
pub fn create_named_pipe(name: &str, access: u32, sandbox_username: &str) -> io::Result<HANDLE> {
let sandbox_sid = resolve_sid(sandbox_username)
.map_err(|err| io::Error::new(io::ErrorKind::PermissionDenied, err.to_string()))?;
let sandbox_sid = string_from_sid_bytes(&sandbox_sid)
.map_err(|err| io::Error::new(io::ErrorKind::PermissionDenied, err))?;
let sddl = to_wide(format!("D:(A;;GA;;;{sandbox_sid})"));
let mut sd: PSECURITY_DESCRIPTOR = ptr::null_mut();
let ok = unsafe {
ConvertStringSecurityDescriptorToSecurityDescriptorW(
sddl.as_ptr(),
1, // SDDL_REVISION_1
&mut sd,
ptr::null_mut(),
)
};
if ok == 0 {
return Err(io::Error::from_raw_os_error(unsafe {
GetLastError() as i32
}));
}
let mut sa = SECURITY_ATTRIBUTES {
nLength: std::mem::size_of::<SECURITY_ATTRIBUTES>() as u32,
lpSecurityDescriptor: sd,
bInheritHandle: 0,
};
let wide = to_wide(name);
let h = unsafe {
CreateNamedPipeW(
wide.as_ptr(),
access,
PIPE_TYPE_BYTE | PIPE_READMODE_BYTE | PIPE_WAIT,
1,
65536,
65536,
0,
&mut sa as *mut SECURITY_ATTRIBUTES,
)
};
if h == 0 || h == windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE {
return Err(io::Error::from_raw_os_error(unsafe {
GetLastError() as i32
}));
}
Ok(h)
}
/// Waits for the runner to connect to a parent-created server pipe.
///
/// This is parent-side only: the runner opens the pipe with `CreateFileW`, while the
/// parent calls `ConnectNamedPipe` and tolerates the already-connected case.
pub fn connect_pipe(h: HANDLE) -> io::Result<()> {
let ok = unsafe { ConnectNamedPipe(h, ptr::null_mut()) };
if ok == 0 {
let err = unsafe { GetLastError() };
const ERROR_PIPE_CONNECTED: u32 = 535;
if err != ERROR_PIPE_CONNECTED {
return Err(io::Error::from_raw_os_error(err as i32));
}
}
Ok(())
}

View File

@@ -17,6 +17,8 @@ mod windows_impl {
use crate::policy::SandboxPolicy;
use crate::token::convert_string_sid_to_sid;
use crate::winutil::quote_windows_arg;
use crate::winutil::resolve_sid;
use crate::winutil::string_from_sid_bytes;
use crate::winutil::to_wide;
use anyhow::Result;
use rand::rngs::SmallRng;
@@ -123,10 +125,9 @@ mod windows_impl {
format!(r"\\.\pipe\codex-runner-{:x}-{}", rng.gen::<u128>(), suffix)
}
/// Creates a named pipe with permissive ACLs so the sandbox user can connect.
fn create_named_pipe(name: &str, access: u32) -> io::Result<HANDLE> {
// Allow sandbox users to connect by granting Everyone full access on the pipe.
let sddl = to_wide("D:(A;;GA;;;WD)");
/// Creates a named pipe whose DACL only allows the sandbox user to connect.
fn create_named_pipe(name: &str, access: u32, sandbox_sid: &str) -> io::Result<HANDLE> {
let sddl = to_wide(format!("D:(A;;GA;;;{sandbox_sid})"));
let mut sd: PSECURITY_DESCRIPTOR = ptr::null_mut();
let ok = unsafe {
ConvertStringSecurityDescriptorToSecurityDescriptorW(
@@ -228,6 +229,11 @@ mod windows_impl {
log_start(&command, logs_base_dir);
let sandbox_creds =
require_logon_sandbox_creds(&policy, sandbox_policy_cwd, cwd, &env_map, codex_home)?;
let sandbox_sid = resolve_sid(&sandbox_creds.username).map_err(|err: anyhow::Error| {
io::Error::new(io::ErrorKind::PermissionDenied, err.to_string())
})?;
let sandbox_sid = string_from_sid_bytes(&sandbox_sid)
.map_err(|err| io::Error::new(io::ErrorKind::PermissionDenied, err))?;
// Build capability SID for ACL grants.
if matches!(
&policy,
@@ -272,14 +278,17 @@ mod windows_impl {
let h_stdin_pipe = create_named_pipe(
&stdin_name,
PIPE_ACCESS_DUPLEX | PIPE_TYPE_BYTE | PIPE_READMODE_BYTE | PIPE_WAIT,
&sandbox_sid,
)?;
let h_stdout_pipe = create_named_pipe(
&stdout_name,
PIPE_ACCESS_DUPLEX | PIPE_TYPE_BYTE | PIPE_READMODE_BYTE | PIPE_WAIT,
&sandbox_sid,
)?;
let h_stderr_pipe = create_named_pipe(
&stderr_name,
PIPE_ACCESS_DUPLEX | PIPE_TYPE_BYTE | PIPE_READMODE_BYTE | PIPE_WAIT,
&sandbox_sid,
)?;
// Launch runner as sandbox user via CreateProcessWithLogonW.

View File

@@ -24,6 +24,14 @@ windows_modules!(
workspace_acl
);
#[cfg(target_os = "windows")]
#[path = "conpty/mod.rs"]
mod conpty;
#[cfg(target_os = "windows")]
#[path = "elevated/ipc_framed.rs"]
pub mod ipc_framed;
#[cfg(target_os = "windows")]
#[path = "setup_orchestrator.rs"]
mod setup;
@@ -36,6 +44,7 @@ mod setup_error;
#[cfg(target_os = "windows")]
pub use acl::add_deny_write_ace;
#[cfg(target_os = "windows")]
pub use acl::allow_null_device;
#[cfg(target_os = "windows")]
@@ -55,6 +64,8 @@ pub use cap::load_or_create_cap_sids;
#[cfg(target_os = "windows")]
pub use cap::workspace_cap_sid_for_cwd;
#[cfg(target_os = "windows")]
pub use conpty::spawn_conpty_process_as_user;
#[cfg(target_os = "windows")]
pub use dpapi::protect as dpapi_protect;
#[cfg(target_os = "windows")]
pub use dpapi::unprotect as dpapi_unprotect;
@@ -83,6 +94,16 @@ pub use policy::SandboxPolicy;
#[cfg(target_os = "windows")]
pub use process::create_process_as_user;
#[cfg(target_os = "windows")]
pub use process::read_handle_loop;
#[cfg(target_os = "windows")]
pub use process::spawn_process_with_pipes;
#[cfg(target_os = "windows")]
pub use process::PipeSpawnHandles;
#[cfg(target_os = "windows")]
pub use process::StderrMode;
#[cfg(target_os = "windows")]
pub use process::StdinMode;
#[cfg(target_os = "windows")]
pub use setup::run_elevated_setup;
#[cfg(target_os = "windows")]
pub use setup::run_setup_refresh;

View File

@@ -8,15 +8,19 @@ use anyhow::Result;
use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path;
use std::ptr;
use windows_sys::Win32::Foundation::GetLastError;
use windows_sys::Win32::Foundation::CloseHandle;
use windows_sys::Win32::Foundation::SetHandleInformation;
use windows_sys::Win32::Foundation::HANDLE;
use windows_sys::Win32::Foundation::HANDLE_FLAG_INHERIT;
use windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE;
use windows_sys::Win32::Storage::FileSystem::ReadFile;
use windows_sys::Win32::System::Console::GetStdHandle;
use windows_sys::Win32::System::Console::STD_ERROR_HANDLE;
use windows_sys::Win32::System::Console::STD_INPUT_HANDLE;
use windows_sys::Win32::System::Console::STD_OUTPUT_HANDLE;
use windows_sys::Win32::System::Pipes::CreatePipe;
use windows_sys::Win32::System::Threading::CreateProcessAsUserW;
use windows_sys::Win32::System::Threading::CREATE_UNICODE_ENVIRONMENT;
use windows_sys::Win32::System::Threading::PROCESS_INFORMATION;
@@ -153,3 +157,141 @@ pub unsafe fn create_process_as_user(
_desktop: desktop,
})
}
/// Controls whether the child's stdin handle is kept open for writing.
#[allow(dead_code)]
pub enum StdinMode {
Closed,
Open,
}
/// Controls how stderr is wired for a pipe-spawned process.
#[allow(dead_code)]
pub enum StderrMode {
MergeStdout,
Separate,
}
/// Handles returned by `spawn_process_with_pipes`.
#[allow(dead_code)]
pub struct PipeSpawnHandles {
pub process: PROCESS_INFORMATION,
pub stdin_write: Option<HANDLE>,
pub stdout_read: HANDLE,
pub stderr_read: Option<HANDLE>,
}
/// Spawns a process with anonymous pipes and returns the relevant handles.
pub fn spawn_process_with_pipes(
h_token: HANDLE,
argv: &[String],
cwd: &Path,
env_map: &HashMap<String, String>,
stdin_mode: StdinMode,
stderr_mode: StderrMode,
) -> Result<PipeSpawnHandles> {
let mut in_r: HANDLE = 0;
let mut in_w: HANDLE = 0;
let mut out_r: HANDLE = 0;
let mut out_w: HANDLE = 0;
let mut err_r: HANDLE = 0;
let mut err_w: HANDLE = 0;
unsafe {
if CreatePipe(&mut in_r, &mut in_w, ptr::null_mut(), 0) == 0 {
return Err(anyhow!("CreatePipe stdin failed: {}", GetLastError()));
}
if CreatePipe(&mut out_r, &mut out_w, ptr::null_mut(), 0) == 0 {
CloseHandle(in_r);
CloseHandle(in_w);
return Err(anyhow!("CreatePipe stdout failed: {}", GetLastError()));
}
if matches!(stderr_mode, StderrMode::Separate)
&& CreatePipe(&mut err_r, &mut err_w, ptr::null_mut(), 0) == 0
{
CloseHandle(in_r);
CloseHandle(in_w);
CloseHandle(out_r);
CloseHandle(out_w);
return Err(anyhow!("CreatePipe stderr failed: {}", GetLastError()));
}
}
let stderr_handle = match stderr_mode {
StderrMode::MergeStdout => out_w,
StderrMode::Separate => err_w,
};
let stdio = Some((in_r, out_w, stderr_handle));
let spawn_result =
unsafe { create_process_as_user(h_token, argv, cwd, env_map, None, stdio, false) };
let created = match spawn_result {
Ok(v) => v,
Err(err) => {
unsafe {
CloseHandle(in_r);
CloseHandle(in_w);
CloseHandle(out_r);
CloseHandle(out_w);
if matches!(stderr_mode, StderrMode::Separate) {
CloseHandle(err_r);
CloseHandle(err_w);
}
}
return Err(err);
}
};
let pi = created.process_info;
unsafe {
CloseHandle(in_r);
CloseHandle(out_w);
if matches!(stderr_mode, StderrMode::Separate) {
CloseHandle(err_w);
}
if matches!(stdin_mode, StdinMode::Closed) {
CloseHandle(in_w);
}
}
Ok(PipeSpawnHandles {
process: pi,
stdin_write: match stdin_mode {
StdinMode::Open => Some(in_w),
StdinMode::Closed => None,
},
stdout_read: out_r,
stderr_read: match stderr_mode {
StderrMode::Separate => Some(err_r),
StderrMode::MergeStdout => None,
},
})
}
/// Reads a HANDLE until EOF and invokes `on_chunk` for each read.
pub fn read_handle_loop<F>(handle: HANDLE, mut on_chunk: F) -> std::thread::JoinHandle<()>
where
F: FnMut(&[u8]) + Send + 'static,
{
std::thread::spawn(move || {
let mut buf = [0u8; 8192];
loop {
let mut read_bytes: u32 = 0;
let ok = unsafe {
ReadFile(
handle,
buf.as_mut_ptr(),
buf.len() as u32,
&mut read_bytes,
ptr::null_mut(),
)
};
if ok == 0 || read_bytes == 0 {
break;
}
on_chunk(&buf[..read_bytes as usize]);
}
unsafe {
CloseHandle(handle);
}
})
}

View File

@@ -0,0 +1,67 @@
//! Shared helper utilities for Windows sandbox setup.
//!
//! These helpers centralize small pieces of setup logic used across both legacy and
//! elevated paths, including unified_exec sessions and capture flows. They cover
//! codex home directory creation and git safe.directory injection so sandboxed
//! users can run git inside a repo owned by the primary user.
use anyhow::Result;
use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;
/// Walk upward from `start` to locate the git worktree root (supports gitfile redirects).
fn find_git_root(start: &Path) -> Option<PathBuf> {
let mut cur = dunce::canonicalize(start).ok()?;
loop {
let marker = cur.join(".git");
if marker.is_dir() {
return Some(cur);
}
if marker.is_file() {
if let Ok(txt) = std::fs::read_to_string(&marker) {
if let Some(rest) = txt.trim().strip_prefix("gitdir:") {
let gitdir = rest.trim();
let resolved = if Path::new(gitdir).is_absolute() {
PathBuf::from(gitdir)
} else {
cur.join(gitdir)
};
return resolved.parent().map(|p| p.to_path_buf()).or(Some(cur));
}
}
return Some(cur);
}
let parent = cur.parent()?;
if parent == cur {
return None;
}
cur = parent.to_path_buf();
}
}
/// Ensure the sandbox codex home directory exists.
pub fn ensure_codex_home_exists(p: &Path) -> Result<()> {
std::fs::create_dir_all(p)?;
Ok(())
}
/// Adds a git safe.directory entry to the environment when running inside a repository.
/// git will not otherwise allow the Sandbox user to run git commands on the repo directory
/// which is owned by the primary user.
pub fn inject_git_safe_directory(env_map: &mut HashMap<String, String>, cwd: &Path) {
if let Some(git_root) = find_git_root(cwd) {
let mut cfg_count: usize = env_map
.get("GIT_CONFIG_COUNT")
.and_then(|v| v.parse::<usize>().ok())
.unwrap_or(0);
let git_path = git_root.to_string_lossy().replace("\\\\", "/");
env_map.insert(
format!("GIT_CONFIG_KEY_{cfg_count}"),
"safe.directory".to_string(),
);
env_map.insert(format!("GIT_CONFIG_VALUE_{cfg_count}"), git_path);
cfg_count += 1;
env_map.insert("GIT_CONFIG_COUNT".to_string(), cfg_count.to_string());
}
}

View File

@@ -1,7 +1,15 @@
use anyhow::Result;
use std::ffi::OsStr;
use std::os::windows::ffi::OsStrExt;
use windows_sys::Win32::Foundation::ERROR_INSUFFICIENT_BUFFER;
use windows_sys::Win32::Foundation::GetLastError;
use windows_sys::Win32::Foundation::LocalFree;
use windows_sys::Win32::Foundation::HLOCAL;
use windows_sys::Win32::Security::Authorization::ConvertStringSidToSidW;
use windows_sys::Win32::Security::CopySid;
use windows_sys::Win32::Security::GetLengthSid;
use windows_sys::Win32::Security::LookupAccountNameW;
use windows_sys::Win32::Security::SID_NAME_USE;
use windows_sys::Win32::System::Diagnostics::Debug::FormatMessageW;
use windows_sys::Win32::System::Diagnostics::Debug::FORMAT_MESSAGE_ALLOCATE_BUFFER;
use windows_sys::Win32::System::Diagnostics::Debug::FORMAT_MESSAGE_FROM_SYSTEM;
@@ -102,3 +110,83 @@ pub fn string_from_sid_bytes(sid: &[u8]) -> Result<String, String> {
Ok(out)
}
}
const SID_ADMINISTRATORS: &str = "S-1-5-32-544";
const SID_USERS: &str = "S-1-5-32-545";
const SID_AUTHENTICATED_USERS: &str = "S-1-5-11";
const SID_EVERYONE: &str = "S-1-1-0";
const SID_SYSTEM: &str = "S-1-5-18";
pub fn resolve_sid(name: &str) -> Result<Vec<u8>> {
if let Some(sid_str) = well_known_sid_str(name) {
return sid_bytes_from_string(sid_str);
}
let name_w = to_wide(OsStr::new(name));
let mut sid_buffer = vec![0u8; 68];
let mut sid_len: u32 = sid_buffer.len() as u32;
let mut domain: Vec<u16> = Vec::new();
let mut domain_len: u32 = 0;
let mut use_type: SID_NAME_USE = 0;
loop {
let ok = unsafe {
LookupAccountNameW(
std::ptr::null(),
name_w.as_ptr(),
sid_buffer.as_mut_ptr() as *mut std::ffi::c_void,
&mut sid_len,
domain.as_mut_ptr(),
&mut domain_len,
&mut use_type,
)
};
if ok != 0 {
sid_buffer.truncate(sid_len as usize);
return Ok(sid_buffer);
}
let err = unsafe { GetLastError() };
if err == ERROR_INSUFFICIENT_BUFFER {
sid_buffer.resize(sid_len as usize, 0);
domain.resize(domain_len as usize, 0);
continue;
}
return Err(anyhow::anyhow!("LookupAccountNameW failed for {name}: {err}"));
}
}
fn well_known_sid_str(name: &str) -> Option<&'static str> {
match name {
"Administrators" => Some(SID_ADMINISTRATORS),
"Users" => Some(SID_USERS),
"Authenticated Users" => Some(SID_AUTHENTICATED_USERS),
"Everyone" => Some(SID_EVERYONE),
"SYSTEM" => Some(SID_SYSTEM),
_ => None,
}
}
fn sid_bytes_from_string(sid_str: &str) -> Result<Vec<u8>> {
let sid_w = to_wide(OsStr::new(sid_str));
let mut psid: *mut std::ffi::c_void = std::ptr::null_mut();
if unsafe { ConvertStringSidToSidW(sid_w.as_ptr(), &mut psid) } == 0 {
return Err(anyhow::anyhow!(
"ConvertStringSidToSidW failed for {sid_str}: {}",
unsafe { GetLastError() }
));
}
let sid_len = unsafe { GetLengthSid(psid) };
if sid_len == 0 {
unsafe {
LocalFree(psid as _);
}
return Err(anyhow::anyhow!("GetLengthSid failed for {sid_str}"));
}
let mut out = vec![0u8; sid_len as usize];
let ok = unsafe { CopySid(sid_len, out.as_mut_ptr() as *mut std::ffi::c_void, psid) };
unsafe {
LocalFree(psid as _);
}
if ok == 0 {
return Err(anyhow::anyhow!("CopySid failed for {sid_str}"));
}
Ok(out)
}