diff --git a/codex-rs/core/src/agent/control.rs b/codex-rs/core/src/agent/control.rs index 9599e2dc69..c9ac18a026 100644 --- a/codex-rs/core/src/agent/control.rs +++ b/codex-rs/core/src/agent/control.rs @@ -487,1100 +487,5 @@ impl AgentControl { } } #[cfg(test)] -mod tests { - use super::*; - use crate::CodexAuth; - use crate::CodexThread; - use crate::ThreadManager; - use crate::agent::agent_status_from_event; - use crate::config::AgentRoleConfig; - use crate::config::Config; - use crate::config::ConfigBuilder; - use crate::config_loader::LoaderOverrides; - use crate::contextual_user_message::SUBAGENT_NOTIFICATION_OPEN_TAG; - use crate::features::Feature; - use assert_matches::assert_matches; - use codex_protocol::config_types::ModeKind; - use codex_protocol::models::ContentItem; - use codex_protocol::models::ResponseItem; - use codex_protocol::protocol::ErrorEvent; - use codex_protocol::protocol::EventMsg; - use codex_protocol::protocol::SessionSource; - use codex_protocol::protocol::SubAgentSource; - use codex_protocol::protocol::TurnAbortReason; - use codex_protocol::protocol::TurnAbortedEvent; - use codex_protocol::protocol::TurnCompleteEvent; - use codex_protocol::protocol::TurnStartedEvent; - use pretty_assertions::assert_eq; - use tempfile::TempDir; - use tokio::time::Duration; - use tokio::time::sleep; - use tokio::time::timeout; - use toml::Value as TomlValue; - - async fn test_config_with_cli_overrides( - cli_overrides: Vec<(String, TomlValue)>, - ) -> (TempDir, Config) { - let home = TempDir::new().expect("create temp dir"); - let config = ConfigBuilder::default() - .codex_home(home.path().to_path_buf()) - .cli_overrides(cli_overrides) - .loader_overrides(LoaderOverrides { - #[cfg(target_os = "macos")] - managed_preferences_base64: Some(String::new()), - macos_managed_config_requirements_base64: Some(String::new()), - ..LoaderOverrides::default() - }) - .build() - .await - .expect("load default test config"); - (home, config) - } - - async fn test_config() -> (TempDir, Config) { - test_config_with_cli_overrides(Vec::new()).await - } - - fn text_input(text: &str) -> Vec { - vec![UserInput::Text { - text: text.to_string(), - text_elements: Vec::new(), - }] - } - - struct AgentControlHarness { - _home: TempDir, - config: Config, - manager: ThreadManager, - control: AgentControl, - } - - impl AgentControlHarness { - async fn new() -> Self { - let (home, config) = test_config().await; - let manager = ThreadManager::with_models_provider_and_home_for_tests( - CodexAuth::from_api_key("dummy"), - config.model_provider.clone(), - config.codex_home.clone(), - ); - let control = manager.agent_control(); - Self { - _home: home, - config, - manager, - control, - } - } - - async fn start_thread(&self) -> (ThreadId, Arc) { - let new_thread = self - .manager - .start_thread(self.config.clone()) - .await - .expect("start thread"); - (new_thread.thread_id, new_thread.thread) - } - } - - fn has_subagent_notification(history_items: &[ResponseItem]) -> bool { - history_items.iter().any(|item| { - let ResponseItem::Message { role, content, .. } = item else { - return false; - }; - if role != "user" { - return false; - } - content.iter().any(|content_item| match content_item { - ContentItem::InputText { text } | ContentItem::OutputText { text } => { - text.contains(SUBAGENT_NOTIFICATION_OPEN_TAG) - } - ContentItem::InputImage { .. } => false, - }) - }) - } - - /// Returns true when any message item contains `needle` in a text span. - fn history_contains_text(history_items: &[ResponseItem], needle: &str) -> bool { - history_items.iter().any(|item| { - let ResponseItem::Message { content, .. } = item else { - return false; - }; - content.iter().any(|content_item| match content_item { - ContentItem::InputText { text } | ContentItem::OutputText { text } => { - text.contains(needle) - } - ContentItem::InputImage { .. } => false, - }) - }) - } - - async fn wait_for_subagent_notification(parent_thread: &Arc) -> bool { - let wait = async { - loop { - let history_items = parent_thread - .codex - .session - .clone_history() - .await - .raw_items() - .to_vec(); - if has_subagent_notification(&history_items) { - return true; - } - sleep(Duration::from_millis(25)).await; - } - }; - timeout(Duration::from_secs(2), wait).await.is_ok() - } - - #[tokio::test] - async fn send_input_errors_when_manager_dropped() { - let control = AgentControl::default(); - let err = control - .send_input( - ThreadId::new(), - vec![UserInput::Text { - text: "hello".to_string(), - text_elements: Vec::new(), - }], - ) - .await - .expect_err("send_input should fail without a manager"); - assert_eq!( - err.to_string(), - "unsupported operation: thread manager dropped" - ); - } - - #[tokio::test] - async fn get_status_returns_not_found_without_manager() { - let control = AgentControl::default(); - let got = control.get_status(ThreadId::new()).await; - assert_eq!(got, AgentStatus::NotFound); - } - - #[tokio::test] - async fn on_event_updates_status_from_task_started() { - let status = agent_status_from_event(&EventMsg::TurnStarted(TurnStartedEvent { - turn_id: "turn-1".to_string(), - model_context_window: None, - collaboration_mode_kind: ModeKind::Default, - })); - assert_eq!(status, Some(AgentStatus::Running)); - } - - #[tokio::test] - async fn on_event_updates_status_from_task_complete() { - let status = agent_status_from_event(&EventMsg::TurnComplete(TurnCompleteEvent { - turn_id: "turn-1".to_string(), - last_agent_message: Some("done".to_string()), - })); - let expected = AgentStatus::Completed(Some("done".to_string())); - assert_eq!(status, Some(expected)); - } - - #[tokio::test] - async fn on_event_updates_status_from_error() { - let status = agent_status_from_event(&EventMsg::Error(ErrorEvent { - message: "boom".to_string(), - codex_error_info: None, - })); - - let expected = AgentStatus::Errored("boom".to_string()); - assert_eq!(status, Some(expected)); - } - - #[tokio::test] - async fn on_event_updates_status_from_turn_aborted() { - let status = agent_status_from_event(&EventMsg::TurnAborted(TurnAbortedEvent { - turn_id: Some("turn-1".to_string()), - reason: TurnAbortReason::Interrupted, - })); - - let expected = AgentStatus::Errored("Interrupted".to_string()); - assert_eq!(status, Some(expected)); - } - - #[tokio::test] - async fn on_event_updates_status_from_shutdown_complete() { - let status = agent_status_from_event(&EventMsg::ShutdownComplete); - assert_eq!(status, Some(AgentStatus::Shutdown)); - } - - #[tokio::test] - async fn spawn_agent_errors_when_manager_dropped() { - let control = AgentControl::default(); - let (_home, config) = test_config().await; - let err = control - .spawn_agent(config, text_input("hello"), None) - .await - .expect_err("spawn_agent should fail without a manager"); - assert_eq!( - err.to_string(), - "unsupported operation: thread manager dropped" - ); - } - - #[tokio::test] - async fn resume_agent_errors_when_manager_dropped() { - let control = AgentControl::default(); - let (_home, config) = test_config().await; - let err = control - .resume_agent_from_rollout(config, ThreadId::new(), SessionSource::Exec) - .await - .expect_err("resume_agent should fail without a manager"); - assert_eq!( - err.to_string(), - "unsupported operation: thread manager dropped" - ); - } - - #[tokio::test] - async fn send_input_errors_when_thread_missing() { - let harness = AgentControlHarness::new().await; - let thread_id = ThreadId::new(); - let err = harness - .control - .send_input( - thread_id, - vec![UserInput::Text { - text: "hello".to_string(), - text_elements: Vec::new(), - }], - ) - .await - .expect_err("send_input should fail for missing thread"); - assert_matches!(err, CodexErr::ThreadNotFound(id) if id == thread_id); - } - - #[tokio::test] - async fn get_status_returns_not_found_for_missing_thread() { - let harness = AgentControlHarness::new().await; - let status = harness.control.get_status(ThreadId::new()).await; - assert_eq!(status, AgentStatus::NotFound); - } - - #[tokio::test] - async fn get_status_returns_pending_init_for_new_thread() { - let harness = AgentControlHarness::new().await; - let (thread_id, _) = harness.start_thread().await; - let status = harness.control.get_status(thread_id).await; - assert_eq!(status, AgentStatus::PendingInit); - } - - #[tokio::test] - async fn subscribe_status_errors_for_missing_thread() { - let harness = AgentControlHarness::new().await; - let thread_id = ThreadId::new(); - let err = harness - .control - .subscribe_status(thread_id) - .await - .expect_err("subscribe_status should fail for missing thread"); - assert_matches!(err, CodexErr::ThreadNotFound(id) if id == thread_id); - } - - #[tokio::test] - async fn subscribe_status_updates_on_shutdown() { - let harness = AgentControlHarness::new().await; - let (thread_id, thread) = harness.start_thread().await; - let mut status_rx = harness - .control - .subscribe_status(thread_id) - .await - .expect("subscribe_status should succeed"); - assert_eq!(status_rx.borrow().clone(), AgentStatus::PendingInit); - - let _ = thread - .submit(Op::Shutdown {}) - .await - .expect("shutdown should submit"); - - let _ = status_rx.changed().await; - assert_eq!(status_rx.borrow().clone(), AgentStatus::Shutdown); - } - - #[tokio::test] - async fn send_input_submits_user_message() { - let harness = AgentControlHarness::new().await; - let (thread_id, _thread) = harness.start_thread().await; - - let submission_id = harness - .control - .send_input( - thread_id, - vec![UserInput::Text { - text: "hello from tests".to_string(), - text_elements: Vec::new(), - }], - ) - .await - .expect("send_input should succeed"); - assert!(!submission_id.is_empty()); - let expected = ( - thread_id, - Op::UserInput { - items: vec![UserInput::Text { - text: "hello from tests".to_string(), - text_elements: Vec::new(), - }], - final_output_json_schema: None, - }, - ); - let captured = harness - .manager - .captured_ops() - .into_iter() - .find(|entry| *entry == expected); - assert_eq!(captured, Some(expected)); - } - - #[tokio::test] - async fn spawn_agent_creates_thread_and_sends_prompt() { - let harness = AgentControlHarness::new().await; - let thread_id = harness - .control - .spawn_agent(harness.config.clone(), text_input("spawned"), None) - .await - .expect("spawn_agent should succeed"); - let _thread = harness - .manager - .get_thread(thread_id) - .await - .expect("thread should be registered"); - let expected = ( - thread_id, - Op::UserInput { - items: vec![UserInput::Text { - text: "spawned".to_string(), - text_elements: Vec::new(), - }], - final_output_json_schema: None, - }, - ); - let captured = harness - .manager - .captured_ops() - .into_iter() - .find(|entry| *entry == expected); - assert_eq!(captured, Some(expected)); - } - - #[tokio::test] - async fn spawn_agent_can_fork_parent_thread_history() { - let harness = AgentControlHarness::new().await; - let (parent_thread_id, parent_thread) = harness.start_thread().await; - parent_thread - .inject_user_message_without_turn("parent seed context".to_string()) - .await; - let turn_context = parent_thread.codex.session.new_default_turn().await; - let parent_spawn_call_id = "spawn-call-history".to_string(); - let parent_spawn_call = ResponseItem::FunctionCall { - id: None, - name: "spawn_agent".to_string(), - namespace: None, - arguments: "{}".to_string(), - call_id: parent_spawn_call_id.clone(), - }; - parent_thread - .codex - .session - .record_conversation_items(turn_context.as_ref(), &[parent_spawn_call]) - .await; - parent_thread - .codex - .session - .ensure_rollout_materialized() - .await; - parent_thread.codex.session.flush_rollout().await; - - let child_thread_id = harness - .control - .spawn_agent_with_options( - harness.config.clone(), - text_input("child task"), - Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { - parent_thread_id, - depth: 1, - agent_nickname: None, - agent_role: None, - })), - SpawnAgentOptions { - fork_parent_spawn_call_id: Some(parent_spawn_call_id), - }, - ) - .await - .expect("forked spawn should succeed"); - - let child_thread = harness - .manager - .get_thread(child_thread_id) - .await - .expect("child thread should be registered"); - assert_ne!(child_thread_id, parent_thread_id); - let history = child_thread.codex.session.clone_history().await; - assert!(history_contains_text( - history.raw_items(), - "parent seed context" - )); - - let expected = ( - child_thread_id, - Op::UserInput { - items: vec![UserInput::Text { - text: "child task".to_string(), - text_elements: Vec::new(), - }], - final_output_json_schema: None, - }, - ); - let captured = harness - .manager - .captured_ops() - .into_iter() - .find(|entry| *entry == expected); - assert_eq!(captured, Some(expected)); - - let _ = harness - .control - .shutdown_agent(child_thread_id) - .await - .expect("child shutdown should submit"); - let _ = parent_thread - .submit(Op::Shutdown {}) - .await - .expect("parent shutdown should submit"); - } - - #[tokio::test] - async fn spawn_agent_fork_injects_output_for_parent_spawn_call() { - let harness = AgentControlHarness::new().await; - let (parent_thread_id, parent_thread) = harness.start_thread().await; - let turn_context = parent_thread.codex.session.new_default_turn().await; - let parent_spawn_call_id = "spawn-call-1".to_string(); - let parent_spawn_call = ResponseItem::FunctionCall { - id: None, - name: "spawn_agent".to_string(), - namespace: None, - arguments: "{}".to_string(), - call_id: parent_spawn_call_id.clone(), - }; - parent_thread - .codex - .session - .record_conversation_items(turn_context.as_ref(), &[parent_spawn_call]) - .await; - parent_thread - .codex - .session - .ensure_rollout_materialized() - .await; - parent_thread.codex.session.flush_rollout().await; - - let child_thread_id = harness - .control - .spawn_agent_with_options( - harness.config.clone(), - text_input("child task"), - Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { - parent_thread_id, - depth: 1, - agent_nickname: None, - agent_role: None, - })), - SpawnAgentOptions { - fork_parent_spawn_call_id: Some(parent_spawn_call_id.clone()), - }, - ) - .await - .expect("forked spawn should succeed"); - - let child_thread = harness - .manager - .get_thread(child_thread_id) - .await - .expect("child thread should be registered"); - let history = child_thread.codex.session.clone_history().await; - let injected_output = history.raw_items().iter().find_map(|item| match item { - ResponseItem::FunctionCallOutput { call_id, output } - if call_id == &parent_spawn_call_id => - { - Some(output) - } - _ => None, - }); - let injected_output = - injected_output.expect("forked child should contain synthetic tool output"); - assert_eq!( - injected_output.text_content(), - Some(FORKED_SPAWN_AGENT_OUTPUT_MESSAGE) - ); - assert_eq!(injected_output.success, Some(true)); - - let _ = harness - .control - .shutdown_agent(child_thread_id) - .await - .expect("child shutdown should submit"); - let _ = parent_thread - .submit(Op::Shutdown {}) - .await - .expect("parent shutdown should submit"); - } - - #[tokio::test] - async fn spawn_agent_fork_flushes_parent_rollout_before_loading_history() { - let harness = AgentControlHarness::new().await; - let (parent_thread_id, parent_thread) = harness.start_thread().await; - let turn_context = parent_thread.codex.session.new_default_turn().await; - let parent_spawn_call_id = "spawn-call-unflushed".to_string(); - let parent_spawn_call = ResponseItem::FunctionCall { - id: None, - name: "spawn_agent".to_string(), - namespace: None, - arguments: "{}".to_string(), - call_id: parent_spawn_call_id.clone(), - }; - parent_thread - .codex - .session - .record_conversation_items(turn_context.as_ref(), &[parent_spawn_call]) - .await; - - let child_thread_id = harness - .control - .spawn_agent_with_options( - harness.config.clone(), - text_input("child task"), - Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { - parent_thread_id, - depth: 1, - agent_nickname: None, - agent_role: None, - })), - SpawnAgentOptions { - fork_parent_spawn_call_id: Some(parent_spawn_call_id.clone()), - }, - ) - .await - .expect("forked spawn should flush parent rollout before loading history"); - - let child_thread = harness - .manager - .get_thread(child_thread_id) - .await - .expect("child thread should be registered"); - let history = child_thread.codex.session.clone_history().await; - - let mut parent_call_index = None; - let mut injected_output_index = None; - for (idx, item) in history.raw_items().iter().enumerate() { - match item { - ResponseItem::FunctionCall { call_id, .. } if call_id == &parent_spawn_call_id => { - parent_call_index = Some(idx); - } - ResponseItem::FunctionCallOutput { call_id, .. } - if call_id == &parent_spawn_call_id => - { - injected_output_index = Some(idx); - } - _ => {} - } - } - - let parent_call_index = - parent_call_index.expect("forked child should include the parent spawn_agent call"); - let injected_output_index = injected_output_index - .expect("forked child should include synthetic output for the parent spawn_agent call"); - assert!(parent_call_index < injected_output_index); - - let _ = harness - .control - .shutdown_agent(child_thread_id) - .await - .expect("child shutdown should submit"); - let _ = parent_thread - .submit(Op::Shutdown {}) - .await - .expect("parent shutdown should submit"); - } - - #[tokio::test] - async fn spawn_agent_respects_max_threads_limit() { - let max_threads = 1usize; - let (_home, config) = test_config_with_cli_overrides(vec![( - "agents.max_threads".to_string(), - TomlValue::Integer(max_threads as i64), - )]) - .await; - let manager = ThreadManager::with_models_provider_and_home_for_tests( - CodexAuth::from_api_key("dummy"), - config.model_provider.clone(), - config.codex_home.clone(), - ); - let control = manager.agent_control(); - - let _ = manager - .start_thread(config.clone()) - .await - .expect("start thread"); - - let first_agent_id = control - .spawn_agent(config.clone(), text_input("hello"), None) - .await - .expect("spawn_agent should succeed"); - - let err = control - .spawn_agent(config, text_input("hello again"), None) - .await - .expect_err("spawn_agent should respect max threads"); - let CodexErr::AgentLimitReached { - max_threads: seen_max_threads, - } = err - else { - panic!("expected CodexErr::AgentLimitReached"); - }; - assert_eq!(seen_max_threads, max_threads); - - let _ = control - .shutdown_agent(first_agent_id) - .await - .expect("shutdown agent"); - } - - #[tokio::test] - async fn spawn_agent_releases_slot_after_shutdown() { - let max_threads = 1usize; - let (_home, config) = test_config_with_cli_overrides(vec![( - "agents.max_threads".to_string(), - TomlValue::Integer(max_threads as i64), - )]) - .await; - let manager = ThreadManager::with_models_provider_and_home_for_tests( - CodexAuth::from_api_key("dummy"), - config.model_provider.clone(), - config.codex_home.clone(), - ); - let control = manager.agent_control(); - - let first_agent_id = control - .spawn_agent(config.clone(), text_input("hello"), None) - .await - .expect("spawn_agent should succeed"); - let _ = control - .shutdown_agent(first_agent_id) - .await - .expect("shutdown agent"); - - let second_agent_id = control - .spawn_agent(config.clone(), text_input("hello again"), None) - .await - .expect("spawn_agent should succeed after shutdown"); - let _ = control - .shutdown_agent(second_agent_id) - .await - .expect("shutdown agent"); - } - - #[tokio::test] - async fn spawn_agent_limit_shared_across_clones() { - let max_threads = 1usize; - let (_home, config) = test_config_with_cli_overrides(vec![( - "agents.max_threads".to_string(), - TomlValue::Integer(max_threads as i64), - )]) - .await; - let manager = ThreadManager::with_models_provider_and_home_for_tests( - CodexAuth::from_api_key("dummy"), - config.model_provider.clone(), - config.codex_home.clone(), - ); - let control = manager.agent_control(); - let cloned = control.clone(); - - let first_agent_id = cloned - .spawn_agent(config.clone(), text_input("hello"), None) - .await - .expect("spawn_agent should succeed"); - - let err = control - .spawn_agent(config, text_input("hello again"), None) - .await - .expect_err("spawn_agent should respect shared guard"); - let CodexErr::AgentLimitReached { max_threads } = err else { - panic!("expected CodexErr::AgentLimitReached"); - }; - assert_eq!(max_threads, 1); - - let _ = control - .shutdown_agent(first_agent_id) - .await - .expect("shutdown agent"); - } - - #[tokio::test] - async fn resume_agent_respects_max_threads_limit() { - let max_threads = 1usize; - let (_home, config) = test_config_with_cli_overrides(vec![( - "agents.max_threads".to_string(), - TomlValue::Integer(max_threads as i64), - )]) - .await; - let manager = ThreadManager::with_models_provider_and_home_for_tests( - CodexAuth::from_api_key("dummy"), - config.model_provider.clone(), - config.codex_home.clone(), - ); - let control = manager.agent_control(); - - let resumable_id = control - .spawn_agent(config.clone(), text_input("hello"), None) - .await - .expect("spawn_agent should succeed"); - let _ = control - .shutdown_agent(resumable_id) - .await - .expect("shutdown resumable thread"); - - let active_id = control - .spawn_agent(config.clone(), text_input("occupy"), None) - .await - .expect("spawn_agent should succeed for active slot"); - - let err = control - .resume_agent_from_rollout(config, resumable_id, SessionSource::Exec) - .await - .expect_err("resume should respect max threads"); - let CodexErr::AgentLimitReached { - max_threads: seen_max_threads, - } = err - else { - panic!("expected CodexErr::AgentLimitReached"); - }; - assert_eq!(seen_max_threads, max_threads); - - let _ = control - .shutdown_agent(active_id) - .await - .expect("shutdown active thread"); - } - - #[tokio::test] - async fn resume_agent_releases_slot_after_resume_failure() { - let max_threads = 1usize; - let (_home, config) = test_config_with_cli_overrides(vec![( - "agents.max_threads".to_string(), - TomlValue::Integer(max_threads as i64), - )]) - .await; - let manager = ThreadManager::with_models_provider_and_home_for_tests( - CodexAuth::from_api_key("dummy"), - config.model_provider.clone(), - config.codex_home.clone(), - ); - let control = manager.agent_control(); - - let _ = control - .resume_agent_from_rollout(config.clone(), ThreadId::new(), SessionSource::Exec) - .await - .expect_err("resume should fail for missing rollout path"); - - let resumed_id = control - .spawn_agent(config, text_input("hello"), None) - .await - .expect("spawn should succeed after failed resume"); - let _ = control - .shutdown_agent(resumed_id) - .await - .expect("shutdown resumed thread"); - } - - #[tokio::test] - async fn spawn_child_completion_notifies_parent_history() { - let harness = AgentControlHarness::new().await; - let (parent_thread_id, parent_thread) = harness.start_thread().await; - - let child_thread_id = harness - .control - .spawn_agent( - harness.config.clone(), - text_input("hello child"), - Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { - parent_thread_id, - depth: 1, - agent_nickname: None, - agent_role: Some("explorer".to_string()), - })), - ) - .await - .expect("child spawn should succeed"); - - let child_thread = harness - .manager - .get_thread(child_thread_id) - .await - .expect("child thread should exist"); - let _ = child_thread - .submit(Op::Shutdown {}) - .await - .expect("child shutdown should submit"); - - assert_eq!(wait_for_subagent_notification(&parent_thread).await, true); - } - - #[tokio::test] - async fn completion_watcher_notifies_parent_when_child_is_missing() { - let harness = AgentControlHarness::new().await; - let (parent_thread_id, parent_thread) = harness.start_thread().await; - let child_thread_id = ThreadId::new(); - - harness.control.maybe_start_completion_watcher( - child_thread_id, - Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { - parent_thread_id, - depth: 1, - agent_nickname: None, - agent_role: Some("explorer".to_string()), - })), - ); - - assert_eq!(wait_for_subagent_notification(&parent_thread).await, true); - - let history_items = parent_thread - .codex - .session - .clone_history() - .await - .raw_items() - .to_vec(); - assert_eq!( - history_contains_text( - &history_items, - &format!("\"agent_id\":\"{child_thread_id}\"") - ), - true - ); - assert_eq!( - history_contains_text(&history_items, "\"status\":\"not_found\""), - true - ); - } - - #[tokio::test] - async fn spawn_thread_subagent_gets_random_nickname_in_session_source() { - let harness = AgentControlHarness::new().await; - let (parent_thread_id, _parent_thread) = harness.start_thread().await; - - let child_thread_id = harness - .control - .spawn_agent( - harness.config.clone(), - text_input("hello child"), - Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { - parent_thread_id, - depth: 1, - agent_nickname: None, - agent_role: Some("explorer".to_string()), - })), - ) - .await - .expect("child spawn should succeed"); - - let child_thread = harness - .manager - .get_thread(child_thread_id) - .await - .expect("child thread should be registered"); - let snapshot = child_thread.config_snapshot().await; - - let SessionSource::SubAgent(SubAgentSource::ThreadSpawn { - parent_thread_id: seen_parent_thread_id, - depth, - agent_nickname, - agent_role, - }) = snapshot.session_source - else { - panic!("expected thread-spawn sub-agent source"); - }; - assert_eq!(seen_parent_thread_id, parent_thread_id); - assert_eq!(depth, 1); - assert!(agent_nickname.is_some()); - assert_eq!(agent_role, Some("explorer".to_string())); - } - - #[tokio::test] - async fn spawn_thread_subagent_uses_role_specific_nickname_candidates() { - let mut harness = AgentControlHarness::new().await; - harness.config.agent_roles.insert( - "researcher".to_string(), - AgentRoleConfig { - description: Some("Research role".to_string()), - config_file: None, - nickname_candidates: Some(vec!["Atlas".to_string()]), - }, - ); - let (parent_thread_id, _parent_thread) = harness.start_thread().await; - - let child_thread_id = harness - .control - .spawn_agent( - harness.config.clone(), - text_input("hello child"), - Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { - parent_thread_id, - depth: 1, - agent_nickname: None, - agent_role: Some("researcher".to_string()), - })), - ) - .await - .expect("child spawn should succeed"); - - let child_thread = harness - .manager - .get_thread(child_thread_id) - .await - .expect("child thread should be registered"); - let snapshot = child_thread.config_snapshot().await; - - let SessionSource::SubAgent(SubAgentSource::ThreadSpawn { agent_nickname, .. }) = - snapshot.session_source - else { - panic!("expected thread-spawn sub-agent source"); - }; - assert_eq!(agent_nickname, Some("Atlas".to_string())); - } - - #[tokio::test] - async fn resume_thread_subagent_restores_stored_nickname_and_role() { - let (home, mut config) = test_config().await; - config - .features - .enable(Feature::Sqlite) - .expect("test config should allow sqlite"); - let manager = ThreadManager::with_models_provider_and_home_for_tests( - CodexAuth::from_api_key("dummy"), - config.model_provider.clone(), - config.codex_home.clone(), - ); - let control = manager.agent_control(); - let harness = AgentControlHarness { - _home: home, - config, - manager, - control, - }; - let (parent_thread_id, _parent_thread) = harness.start_thread().await; - - let child_thread_id = harness - .control - .spawn_agent( - harness.config.clone(), - text_input("hello child"), - Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { - parent_thread_id, - depth: 1, - agent_nickname: None, - agent_role: Some("explorer".to_string()), - })), - ) - .await - .expect("child spawn should succeed"); - - let child_thread = harness - .manager - .get_thread(child_thread_id) - .await - .expect("child thread should exist"); - let mut status_rx = harness - .control - .subscribe_status(child_thread_id) - .await - .expect("status subscription should succeed"); - if matches!(status_rx.borrow().clone(), AgentStatus::PendingInit) { - timeout(Duration::from_secs(5), async { - loop { - status_rx - .changed() - .await - .expect("child status should advance past pending init"); - if !matches!(status_rx.borrow().clone(), AgentStatus::PendingInit) { - break; - } - } - }) - .await - .expect("child should initialize before shutdown"); - } - let original_snapshot = child_thread.config_snapshot().await; - let original_nickname = original_snapshot - .session_source - .get_nickname() - .expect("spawned sub-agent should have a nickname"); - let state_db = child_thread - .state_db() - .expect("sqlite state db should be available for nickname resume test"); - timeout(Duration::from_secs(5), async { - loop { - if let Ok(Some(metadata)) = state_db.get_thread(child_thread_id).await - && metadata.agent_nickname.is_some() - && metadata.agent_role.as_deref() == Some("explorer") - { - break; - } - sleep(Duration::from_millis(10)).await; - } - }) - .await - .expect("child thread metadata should be persisted to sqlite before shutdown"); - - let _ = harness - .control - .shutdown_agent(child_thread_id) - .await - .expect("child shutdown should submit"); - - let resumed_thread_id = harness - .control - .resume_agent_from_rollout( - harness.config.clone(), - child_thread_id, - SessionSource::SubAgent(SubAgentSource::ThreadSpawn { - parent_thread_id, - depth: 1, - agent_nickname: None, - agent_role: None, - }), - ) - .await - .expect("resume should succeed"); - assert_eq!(resumed_thread_id, child_thread_id); - - let resumed_snapshot = harness - .manager - .get_thread(resumed_thread_id) - .await - .expect("resumed child thread should exist") - .config_snapshot() - .await; - let SessionSource::SubAgent(SubAgentSource::ThreadSpawn { - parent_thread_id: resumed_parent_thread_id, - depth: resumed_depth, - agent_nickname: resumed_nickname, - agent_role: resumed_role, - }) = resumed_snapshot.session_source - else { - panic!("expected thread-spawn sub-agent source"); - }; - assert_eq!(resumed_parent_thread_id, parent_thread_id); - assert_eq!(resumed_depth, 1); - assert_eq!(resumed_nickname, Some(original_nickname)); - assert_eq!(resumed_role, Some("explorer".to_string())); - - let _ = harness - .control - .shutdown_agent(resumed_thread_id) - .await - .expect("resumed child shutdown should submit"); - } -} +#[path = "control_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/agent/control_tests.rs b/codex-rs/core/src/agent/control_tests.rs new file mode 100644 index 0000000000..d78c448b29 --- /dev/null +++ b/codex-rs/core/src/agent/control_tests.rs @@ -0,0 +1,1095 @@ +use super::*; +use crate::CodexAuth; +use crate::CodexThread; +use crate::ThreadManager; +use crate::agent::agent_status_from_event; +use crate::config::AgentRoleConfig; +use crate::config::Config; +use crate::config::ConfigBuilder; +use crate::config_loader::LoaderOverrides; +use crate::contextual_user_message::SUBAGENT_NOTIFICATION_OPEN_TAG; +use crate::features::Feature; +use assert_matches::assert_matches; +use codex_protocol::config_types::ModeKind; +use codex_protocol::models::ContentItem; +use codex_protocol::models::ResponseItem; +use codex_protocol::protocol::ErrorEvent; +use codex_protocol::protocol::EventMsg; +use codex_protocol::protocol::SessionSource; +use codex_protocol::protocol::SubAgentSource; +use codex_protocol::protocol::TurnAbortReason; +use codex_protocol::protocol::TurnAbortedEvent; +use codex_protocol::protocol::TurnCompleteEvent; +use codex_protocol::protocol::TurnStartedEvent; +use pretty_assertions::assert_eq; +use tempfile::TempDir; +use tokio::time::Duration; +use tokio::time::sleep; +use tokio::time::timeout; +use toml::Value as TomlValue; + +async fn test_config_with_cli_overrides( + cli_overrides: Vec<(String, TomlValue)>, +) -> (TempDir, Config) { + let home = TempDir::new().expect("create temp dir"); + let config = ConfigBuilder::default() + .codex_home(home.path().to_path_buf()) + .cli_overrides(cli_overrides) + .loader_overrides(LoaderOverrides { + #[cfg(target_os = "macos")] + managed_preferences_base64: Some(String::new()), + macos_managed_config_requirements_base64: Some(String::new()), + ..LoaderOverrides::default() + }) + .build() + .await + .expect("load default test config"); + (home, config) +} + +async fn test_config() -> (TempDir, Config) { + test_config_with_cli_overrides(Vec::new()).await +} + +fn text_input(text: &str) -> Vec { + vec![UserInput::Text { + text: text.to_string(), + text_elements: Vec::new(), + }] +} + +struct AgentControlHarness { + _home: TempDir, + config: Config, + manager: ThreadManager, + control: AgentControl, +} + +impl AgentControlHarness { + async fn new() -> Self { + let (home, config) = test_config().await; + let manager = ThreadManager::with_models_provider_and_home_for_tests( + CodexAuth::from_api_key("dummy"), + config.model_provider.clone(), + config.codex_home.clone(), + ); + let control = manager.agent_control(); + Self { + _home: home, + config, + manager, + control, + } + } + + async fn start_thread(&self) -> (ThreadId, Arc) { + let new_thread = self + .manager + .start_thread(self.config.clone()) + .await + .expect("start thread"); + (new_thread.thread_id, new_thread.thread) + } +} + +fn has_subagent_notification(history_items: &[ResponseItem]) -> bool { + history_items.iter().any(|item| { + let ResponseItem::Message { role, content, .. } = item else { + return false; + }; + if role != "user" { + return false; + } + content.iter().any(|content_item| match content_item { + ContentItem::InputText { text } | ContentItem::OutputText { text } => { + text.contains(SUBAGENT_NOTIFICATION_OPEN_TAG) + } + ContentItem::InputImage { .. } => false, + }) + }) +} + +/// Returns true when any message item contains `needle` in a text span. +fn history_contains_text(history_items: &[ResponseItem], needle: &str) -> bool { + history_items.iter().any(|item| { + let ResponseItem::Message { content, .. } = item else { + return false; + }; + content.iter().any(|content_item| match content_item { + ContentItem::InputText { text } | ContentItem::OutputText { text } => { + text.contains(needle) + } + ContentItem::InputImage { .. } => false, + }) + }) +} + +async fn wait_for_subagent_notification(parent_thread: &Arc) -> bool { + let wait = async { + loop { + let history_items = parent_thread + .codex + .session + .clone_history() + .await + .raw_items() + .to_vec(); + if has_subagent_notification(&history_items) { + return true; + } + sleep(Duration::from_millis(25)).await; + } + }; + timeout(Duration::from_secs(2), wait).await.is_ok() +} + +#[tokio::test] +async fn send_input_errors_when_manager_dropped() { + let control = AgentControl::default(); + let err = control + .send_input( + ThreadId::new(), + vec![UserInput::Text { + text: "hello".to_string(), + text_elements: Vec::new(), + }], + ) + .await + .expect_err("send_input should fail without a manager"); + assert_eq!( + err.to_string(), + "unsupported operation: thread manager dropped" + ); +} + +#[tokio::test] +async fn get_status_returns_not_found_without_manager() { + let control = AgentControl::default(); + let got = control.get_status(ThreadId::new()).await; + assert_eq!(got, AgentStatus::NotFound); +} + +#[tokio::test] +async fn on_event_updates_status_from_task_started() { + let status = agent_status_from_event(&EventMsg::TurnStarted(TurnStartedEvent { + turn_id: "turn-1".to_string(), + model_context_window: None, + collaboration_mode_kind: ModeKind::Default, + })); + assert_eq!(status, Some(AgentStatus::Running)); +} + +#[tokio::test] +async fn on_event_updates_status_from_task_complete() { + let status = agent_status_from_event(&EventMsg::TurnComplete(TurnCompleteEvent { + turn_id: "turn-1".to_string(), + last_agent_message: Some("done".to_string()), + })); + let expected = AgentStatus::Completed(Some("done".to_string())); + assert_eq!(status, Some(expected)); +} + +#[tokio::test] +async fn on_event_updates_status_from_error() { + let status = agent_status_from_event(&EventMsg::Error(ErrorEvent { + message: "boom".to_string(), + codex_error_info: None, + })); + + let expected = AgentStatus::Errored("boom".to_string()); + assert_eq!(status, Some(expected)); +} + +#[tokio::test] +async fn on_event_updates_status_from_turn_aborted() { + let status = agent_status_from_event(&EventMsg::TurnAborted(TurnAbortedEvent { + turn_id: Some("turn-1".to_string()), + reason: TurnAbortReason::Interrupted, + })); + + let expected = AgentStatus::Errored("Interrupted".to_string()); + assert_eq!(status, Some(expected)); +} + +#[tokio::test] +async fn on_event_updates_status_from_shutdown_complete() { + let status = agent_status_from_event(&EventMsg::ShutdownComplete); + assert_eq!(status, Some(AgentStatus::Shutdown)); +} + +#[tokio::test] +async fn spawn_agent_errors_when_manager_dropped() { + let control = AgentControl::default(); + let (_home, config) = test_config().await; + let err = control + .spawn_agent(config, text_input("hello"), None) + .await + .expect_err("spawn_agent should fail without a manager"); + assert_eq!( + err.to_string(), + "unsupported operation: thread manager dropped" + ); +} + +#[tokio::test] +async fn resume_agent_errors_when_manager_dropped() { + let control = AgentControl::default(); + let (_home, config) = test_config().await; + let err = control + .resume_agent_from_rollout(config, ThreadId::new(), SessionSource::Exec) + .await + .expect_err("resume_agent should fail without a manager"); + assert_eq!( + err.to_string(), + "unsupported operation: thread manager dropped" + ); +} + +#[tokio::test] +async fn send_input_errors_when_thread_missing() { + let harness = AgentControlHarness::new().await; + let thread_id = ThreadId::new(); + let err = harness + .control + .send_input( + thread_id, + vec![UserInput::Text { + text: "hello".to_string(), + text_elements: Vec::new(), + }], + ) + .await + .expect_err("send_input should fail for missing thread"); + assert_matches!(err, CodexErr::ThreadNotFound(id) if id == thread_id); +} + +#[tokio::test] +async fn get_status_returns_not_found_for_missing_thread() { + let harness = AgentControlHarness::new().await; + let status = harness.control.get_status(ThreadId::new()).await; + assert_eq!(status, AgentStatus::NotFound); +} + +#[tokio::test] +async fn get_status_returns_pending_init_for_new_thread() { + let harness = AgentControlHarness::new().await; + let (thread_id, _) = harness.start_thread().await; + let status = harness.control.get_status(thread_id).await; + assert_eq!(status, AgentStatus::PendingInit); +} + +#[tokio::test] +async fn subscribe_status_errors_for_missing_thread() { + let harness = AgentControlHarness::new().await; + let thread_id = ThreadId::new(); + let err = harness + .control + .subscribe_status(thread_id) + .await + .expect_err("subscribe_status should fail for missing thread"); + assert_matches!(err, CodexErr::ThreadNotFound(id) if id == thread_id); +} + +#[tokio::test] +async fn subscribe_status_updates_on_shutdown() { + let harness = AgentControlHarness::new().await; + let (thread_id, thread) = harness.start_thread().await; + let mut status_rx = harness + .control + .subscribe_status(thread_id) + .await + .expect("subscribe_status should succeed"); + assert_eq!(status_rx.borrow().clone(), AgentStatus::PendingInit); + + let _ = thread + .submit(Op::Shutdown {}) + .await + .expect("shutdown should submit"); + + let _ = status_rx.changed().await; + assert_eq!(status_rx.borrow().clone(), AgentStatus::Shutdown); +} + +#[tokio::test] +async fn send_input_submits_user_message() { + let harness = AgentControlHarness::new().await; + let (thread_id, _thread) = harness.start_thread().await; + + let submission_id = harness + .control + .send_input( + thread_id, + vec![UserInput::Text { + text: "hello from tests".to_string(), + text_elements: Vec::new(), + }], + ) + .await + .expect("send_input should succeed"); + assert!(!submission_id.is_empty()); + let expected = ( + thread_id, + Op::UserInput { + items: vec![UserInput::Text { + text: "hello from tests".to_string(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }, + ); + let captured = harness + .manager + .captured_ops() + .into_iter() + .find(|entry| *entry == expected); + assert_eq!(captured, Some(expected)); +} + +#[tokio::test] +async fn spawn_agent_creates_thread_and_sends_prompt() { + let harness = AgentControlHarness::new().await; + let thread_id = harness + .control + .spawn_agent(harness.config.clone(), text_input("spawned"), None) + .await + .expect("spawn_agent should succeed"); + let _thread = harness + .manager + .get_thread(thread_id) + .await + .expect("thread should be registered"); + let expected = ( + thread_id, + Op::UserInput { + items: vec![UserInput::Text { + text: "spawned".to_string(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }, + ); + let captured = harness + .manager + .captured_ops() + .into_iter() + .find(|entry| *entry == expected); + assert_eq!(captured, Some(expected)); +} + +#[tokio::test] +async fn spawn_agent_can_fork_parent_thread_history() { + let harness = AgentControlHarness::new().await; + let (parent_thread_id, parent_thread) = harness.start_thread().await; + parent_thread + .inject_user_message_without_turn("parent seed context".to_string()) + .await; + let turn_context = parent_thread.codex.session.new_default_turn().await; + let parent_spawn_call_id = "spawn-call-history".to_string(); + let parent_spawn_call = ResponseItem::FunctionCall { + id: None, + name: "spawn_agent".to_string(), + namespace: None, + arguments: "{}".to_string(), + call_id: parent_spawn_call_id.clone(), + }; + parent_thread + .codex + .session + .record_conversation_items(turn_context.as_ref(), &[parent_spawn_call]) + .await; + parent_thread + .codex + .session + .ensure_rollout_materialized() + .await; + parent_thread.codex.session.flush_rollout().await; + + let child_thread_id = harness + .control + .spawn_agent_with_options( + harness.config.clone(), + text_input("child task"), + Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id, + depth: 1, + agent_nickname: None, + agent_role: None, + })), + SpawnAgentOptions { + fork_parent_spawn_call_id: Some(parent_spawn_call_id), + }, + ) + .await + .expect("forked spawn should succeed"); + + let child_thread = harness + .manager + .get_thread(child_thread_id) + .await + .expect("child thread should be registered"); + assert_ne!(child_thread_id, parent_thread_id); + let history = child_thread.codex.session.clone_history().await; + assert!(history_contains_text( + history.raw_items(), + "parent seed context" + )); + + let expected = ( + child_thread_id, + Op::UserInput { + items: vec![UserInput::Text { + text: "child task".to_string(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }, + ); + let captured = harness + .manager + .captured_ops() + .into_iter() + .find(|entry| *entry == expected); + assert_eq!(captured, Some(expected)); + + let _ = harness + .control + .shutdown_agent(child_thread_id) + .await + .expect("child shutdown should submit"); + let _ = parent_thread + .submit(Op::Shutdown {}) + .await + .expect("parent shutdown should submit"); +} + +#[tokio::test] +async fn spawn_agent_fork_injects_output_for_parent_spawn_call() { + let harness = AgentControlHarness::new().await; + let (parent_thread_id, parent_thread) = harness.start_thread().await; + let turn_context = parent_thread.codex.session.new_default_turn().await; + let parent_spawn_call_id = "spawn-call-1".to_string(); + let parent_spawn_call = ResponseItem::FunctionCall { + id: None, + name: "spawn_agent".to_string(), + namespace: None, + arguments: "{}".to_string(), + call_id: parent_spawn_call_id.clone(), + }; + parent_thread + .codex + .session + .record_conversation_items(turn_context.as_ref(), &[parent_spawn_call]) + .await; + parent_thread + .codex + .session + .ensure_rollout_materialized() + .await; + parent_thread.codex.session.flush_rollout().await; + + let child_thread_id = harness + .control + .spawn_agent_with_options( + harness.config.clone(), + text_input("child task"), + Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id, + depth: 1, + agent_nickname: None, + agent_role: None, + })), + SpawnAgentOptions { + fork_parent_spawn_call_id: Some(parent_spawn_call_id.clone()), + }, + ) + .await + .expect("forked spawn should succeed"); + + let child_thread = harness + .manager + .get_thread(child_thread_id) + .await + .expect("child thread should be registered"); + let history = child_thread.codex.session.clone_history().await; + let injected_output = history.raw_items().iter().find_map(|item| match item { + ResponseItem::FunctionCallOutput { call_id, output } + if call_id == &parent_spawn_call_id => + { + Some(output) + } + _ => None, + }); + let injected_output = + injected_output.expect("forked child should contain synthetic tool output"); + assert_eq!( + injected_output.text_content(), + Some(FORKED_SPAWN_AGENT_OUTPUT_MESSAGE) + ); + assert_eq!(injected_output.success, Some(true)); + + let _ = harness + .control + .shutdown_agent(child_thread_id) + .await + .expect("child shutdown should submit"); + let _ = parent_thread + .submit(Op::Shutdown {}) + .await + .expect("parent shutdown should submit"); +} + +#[tokio::test] +async fn spawn_agent_fork_flushes_parent_rollout_before_loading_history() { + let harness = AgentControlHarness::new().await; + let (parent_thread_id, parent_thread) = harness.start_thread().await; + let turn_context = parent_thread.codex.session.new_default_turn().await; + let parent_spawn_call_id = "spawn-call-unflushed".to_string(); + let parent_spawn_call = ResponseItem::FunctionCall { + id: None, + name: "spawn_agent".to_string(), + namespace: None, + arguments: "{}".to_string(), + call_id: parent_spawn_call_id.clone(), + }; + parent_thread + .codex + .session + .record_conversation_items(turn_context.as_ref(), &[parent_spawn_call]) + .await; + + let child_thread_id = harness + .control + .spawn_agent_with_options( + harness.config.clone(), + text_input("child task"), + Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id, + depth: 1, + agent_nickname: None, + agent_role: None, + })), + SpawnAgentOptions { + fork_parent_spawn_call_id: Some(parent_spawn_call_id.clone()), + }, + ) + .await + .expect("forked spawn should flush parent rollout before loading history"); + + let child_thread = harness + .manager + .get_thread(child_thread_id) + .await + .expect("child thread should be registered"); + let history = child_thread.codex.session.clone_history().await; + + let mut parent_call_index = None; + let mut injected_output_index = None; + for (idx, item) in history.raw_items().iter().enumerate() { + match item { + ResponseItem::FunctionCall { call_id, .. } if call_id == &parent_spawn_call_id => { + parent_call_index = Some(idx); + } + ResponseItem::FunctionCallOutput { call_id, .. } + if call_id == &parent_spawn_call_id => + { + injected_output_index = Some(idx); + } + _ => {} + } + } + + let parent_call_index = + parent_call_index.expect("forked child should include the parent spawn_agent call"); + let injected_output_index = injected_output_index + .expect("forked child should include synthetic output for the parent spawn_agent call"); + assert!(parent_call_index < injected_output_index); + + let _ = harness + .control + .shutdown_agent(child_thread_id) + .await + .expect("child shutdown should submit"); + let _ = parent_thread + .submit(Op::Shutdown {}) + .await + .expect("parent shutdown should submit"); +} + +#[tokio::test] +async fn spawn_agent_respects_max_threads_limit() { + let max_threads = 1usize; + let (_home, config) = test_config_with_cli_overrides(vec![( + "agents.max_threads".to_string(), + TomlValue::Integer(max_threads as i64), + )]) + .await; + let manager = ThreadManager::with_models_provider_and_home_for_tests( + CodexAuth::from_api_key("dummy"), + config.model_provider.clone(), + config.codex_home.clone(), + ); + let control = manager.agent_control(); + + let _ = manager + .start_thread(config.clone()) + .await + .expect("start thread"); + + let first_agent_id = control + .spawn_agent(config.clone(), text_input("hello"), None) + .await + .expect("spawn_agent should succeed"); + + let err = control + .spawn_agent(config, text_input("hello again"), None) + .await + .expect_err("spawn_agent should respect max threads"); + let CodexErr::AgentLimitReached { + max_threads: seen_max_threads, + } = err + else { + panic!("expected CodexErr::AgentLimitReached"); + }; + assert_eq!(seen_max_threads, max_threads); + + let _ = control + .shutdown_agent(first_agent_id) + .await + .expect("shutdown agent"); +} + +#[tokio::test] +async fn spawn_agent_releases_slot_after_shutdown() { + let max_threads = 1usize; + let (_home, config) = test_config_with_cli_overrides(vec![( + "agents.max_threads".to_string(), + TomlValue::Integer(max_threads as i64), + )]) + .await; + let manager = ThreadManager::with_models_provider_and_home_for_tests( + CodexAuth::from_api_key("dummy"), + config.model_provider.clone(), + config.codex_home.clone(), + ); + let control = manager.agent_control(); + + let first_agent_id = control + .spawn_agent(config.clone(), text_input("hello"), None) + .await + .expect("spawn_agent should succeed"); + let _ = control + .shutdown_agent(first_agent_id) + .await + .expect("shutdown agent"); + + let second_agent_id = control + .spawn_agent(config.clone(), text_input("hello again"), None) + .await + .expect("spawn_agent should succeed after shutdown"); + let _ = control + .shutdown_agent(second_agent_id) + .await + .expect("shutdown agent"); +} + +#[tokio::test] +async fn spawn_agent_limit_shared_across_clones() { + let max_threads = 1usize; + let (_home, config) = test_config_with_cli_overrides(vec![( + "agents.max_threads".to_string(), + TomlValue::Integer(max_threads as i64), + )]) + .await; + let manager = ThreadManager::with_models_provider_and_home_for_tests( + CodexAuth::from_api_key("dummy"), + config.model_provider.clone(), + config.codex_home.clone(), + ); + let control = manager.agent_control(); + let cloned = control.clone(); + + let first_agent_id = cloned + .spawn_agent(config.clone(), text_input("hello"), None) + .await + .expect("spawn_agent should succeed"); + + let err = control + .spawn_agent(config, text_input("hello again"), None) + .await + .expect_err("spawn_agent should respect shared guard"); + let CodexErr::AgentLimitReached { max_threads } = err else { + panic!("expected CodexErr::AgentLimitReached"); + }; + assert_eq!(max_threads, 1); + + let _ = control + .shutdown_agent(first_agent_id) + .await + .expect("shutdown agent"); +} + +#[tokio::test] +async fn resume_agent_respects_max_threads_limit() { + let max_threads = 1usize; + let (_home, config) = test_config_with_cli_overrides(vec![( + "agents.max_threads".to_string(), + TomlValue::Integer(max_threads as i64), + )]) + .await; + let manager = ThreadManager::with_models_provider_and_home_for_tests( + CodexAuth::from_api_key("dummy"), + config.model_provider.clone(), + config.codex_home.clone(), + ); + let control = manager.agent_control(); + + let resumable_id = control + .spawn_agent(config.clone(), text_input("hello"), None) + .await + .expect("spawn_agent should succeed"); + let _ = control + .shutdown_agent(resumable_id) + .await + .expect("shutdown resumable thread"); + + let active_id = control + .spawn_agent(config.clone(), text_input("occupy"), None) + .await + .expect("spawn_agent should succeed for active slot"); + + let err = control + .resume_agent_from_rollout(config, resumable_id, SessionSource::Exec) + .await + .expect_err("resume should respect max threads"); + let CodexErr::AgentLimitReached { + max_threads: seen_max_threads, + } = err + else { + panic!("expected CodexErr::AgentLimitReached"); + }; + assert_eq!(seen_max_threads, max_threads); + + let _ = control + .shutdown_agent(active_id) + .await + .expect("shutdown active thread"); +} + +#[tokio::test] +async fn resume_agent_releases_slot_after_resume_failure() { + let max_threads = 1usize; + let (_home, config) = test_config_with_cli_overrides(vec![( + "agents.max_threads".to_string(), + TomlValue::Integer(max_threads as i64), + )]) + .await; + let manager = ThreadManager::with_models_provider_and_home_for_tests( + CodexAuth::from_api_key("dummy"), + config.model_provider.clone(), + config.codex_home.clone(), + ); + let control = manager.agent_control(); + + let _ = control + .resume_agent_from_rollout(config.clone(), ThreadId::new(), SessionSource::Exec) + .await + .expect_err("resume should fail for missing rollout path"); + + let resumed_id = control + .spawn_agent(config, text_input("hello"), None) + .await + .expect("spawn should succeed after failed resume"); + let _ = control + .shutdown_agent(resumed_id) + .await + .expect("shutdown resumed thread"); +} + +#[tokio::test] +async fn spawn_child_completion_notifies_parent_history() { + let harness = AgentControlHarness::new().await; + let (parent_thread_id, parent_thread) = harness.start_thread().await; + + let child_thread_id = harness + .control + .spawn_agent( + harness.config.clone(), + text_input("hello child"), + Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id, + depth: 1, + agent_nickname: None, + agent_role: Some("explorer".to_string()), + })), + ) + .await + .expect("child spawn should succeed"); + + let child_thread = harness + .manager + .get_thread(child_thread_id) + .await + .expect("child thread should exist"); + let _ = child_thread + .submit(Op::Shutdown {}) + .await + .expect("child shutdown should submit"); + + assert_eq!(wait_for_subagent_notification(&parent_thread).await, true); +} + +#[tokio::test] +async fn completion_watcher_notifies_parent_when_child_is_missing() { + let harness = AgentControlHarness::new().await; + let (parent_thread_id, parent_thread) = harness.start_thread().await; + let child_thread_id = ThreadId::new(); + + harness.control.maybe_start_completion_watcher( + child_thread_id, + Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id, + depth: 1, + agent_nickname: None, + agent_role: Some("explorer".to_string()), + })), + ); + + assert_eq!(wait_for_subagent_notification(&parent_thread).await, true); + + let history_items = parent_thread + .codex + .session + .clone_history() + .await + .raw_items() + .to_vec(); + assert_eq!( + history_contains_text( + &history_items, + &format!("\"agent_id\":\"{child_thread_id}\"") + ), + true + ); + assert_eq!( + history_contains_text(&history_items, "\"status\":\"not_found\""), + true + ); +} + +#[tokio::test] +async fn spawn_thread_subagent_gets_random_nickname_in_session_source() { + let harness = AgentControlHarness::new().await; + let (parent_thread_id, _parent_thread) = harness.start_thread().await; + + let child_thread_id = harness + .control + .spawn_agent( + harness.config.clone(), + text_input("hello child"), + Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id, + depth: 1, + agent_nickname: None, + agent_role: Some("explorer".to_string()), + })), + ) + .await + .expect("child spawn should succeed"); + + let child_thread = harness + .manager + .get_thread(child_thread_id) + .await + .expect("child thread should be registered"); + let snapshot = child_thread.config_snapshot().await; + + let SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id: seen_parent_thread_id, + depth, + agent_nickname, + agent_role, + }) = snapshot.session_source + else { + panic!("expected thread-spawn sub-agent source"); + }; + assert_eq!(seen_parent_thread_id, parent_thread_id); + assert_eq!(depth, 1); + assert!(agent_nickname.is_some()); + assert_eq!(agent_role, Some("explorer".to_string())); +} + +#[tokio::test] +async fn spawn_thread_subagent_uses_role_specific_nickname_candidates() { + let mut harness = AgentControlHarness::new().await; + harness.config.agent_roles.insert( + "researcher".to_string(), + AgentRoleConfig { + description: Some("Research role".to_string()), + config_file: None, + nickname_candidates: Some(vec!["Atlas".to_string()]), + }, + ); + let (parent_thread_id, _parent_thread) = harness.start_thread().await; + + let child_thread_id = harness + .control + .spawn_agent( + harness.config.clone(), + text_input("hello child"), + Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id, + depth: 1, + agent_nickname: None, + agent_role: Some("researcher".to_string()), + })), + ) + .await + .expect("child spawn should succeed"); + + let child_thread = harness + .manager + .get_thread(child_thread_id) + .await + .expect("child thread should be registered"); + let snapshot = child_thread.config_snapshot().await; + + let SessionSource::SubAgent(SubAgentSource::ThreadSpawn { agent_nickname, .. }) = + snapshot.session_source + else { + panic!("expected thread-spawn sub-agent source"); + }; + assert_eq!(agent_nickname, Some("Atlas".to_string())); +} + +#[tokio::test] +async fn resume_thread_subagent_restores_stored_nickname_and_role() { + let (home, mut config) = test_config().await; + config + .features + .enable(Feature::Sqlite) + .expect("test config should allow sqlite"); + let manager = ThreadManager::with_models_provider_and_home_for_tests( + CodexAuth::from_api_key("dummy"), + config.model_provider.clone(), + config.codex_home.clone(), + ); + let control = manager.agent_control(); + let harness = AgentControlHarness { + _home: home, + config, + manager, + control, + }; + let (parent_thread_id, _parent_thread) = harness.start_thread().await; + + let child_thread_id = harness + .control + .spawn_agent( + harness.config.clone(), + text_input("hello child"), + Some(SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id, + depth: 1, + agent_nickname: None, + agent_role: Some("explorer".to_string()), + })), + ) + .await + .expect("child spawn should succeed"); + + let child_thread = harness + .manager + .get_thread(child_thread_id) + .await + .expect("child thread should exist"); + let mut status_rx = harness + .control + .subscribe_status(child_thread_id) + .await + .expect("status subscription should succeed"); + if matches!(status_rx.borrow().clone(), AgentStatus::PendingInit) { + timeout(Duration::from_secs(5), async { + loop { + status_rx + .changed() + .await + .expect("child status should advance past pending init"); + if !matches!(status_rx.borrow().clone(), AgentStatus::PendingInit) { + break; + } + } + }) + .await + .expect("child should initialize before shutdown"); + } + let original_snapshot = child_thread.config_snapshot().await; + let original_nickname = original_snapshot + .session_source + .get_nickname() + .expect("spawned sub-agent should have a nickname"); + let state_db = child_thread + .state_db() + .expect("sqlite state db should be available for nickname resume test"); + timeout(Duration::from_secs(5), async { + loop { + if let Ok(Some(metadata)) = state_db.get_thread(child_thread_id).await + && metadata.agent_nickname.is_some() + && metadata.agent_role.as_deref() == Some("explorer") + { + break; + } + sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("child thread metadata should be persisted to sqlite before shutdown"); + + let _ = harness + .control + .shutdown_agent(child_thread_id) + .await + .expect("child shutdown should submit"); + + let resumed_thread_id = harness + .control + .resume_agent_from_rollout( + harness.config.clone(), + child_thread_id, + SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id, + depth: 1, + agent_nickname: None, + agent_role: None, + }), + ) + .await + .expect("resume should succeed"); + assert_eq!(resumed_thread_id, child_thread_id); + + let resumed_snapshot = harness + .manager + .get_thread(resumed_thread_id) + .await + .expect("resumed child thread should exist") + .config_snapshot() + .await; + let SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id: resumed_parent_thread_id, + depth: resumed_depth, + agent_nickname: resumed_nickname, + agent_role: resumed_role, + }) = resumed_snapshot.session_source + else { + panic!("expected thread-spawn sub-agent source"); + }; + assert_eq!(resumed_parent_thread_id, parent_thread_id); + assert_eq!(resumed_depth, 1); + assert_eq!(resumed_nickname, Some(original_nickname)); + assert_eq!(resumed_role, Some("explorer".to_string())); + + let _ = harness + .control + .shutdown_agent(resumed_thread_id) + .await + .expect("resumed child shutdown should submit"); +} diff --git a/codex-rs/core/src/agent/guards.rs b/codex-rs/core/src/agent/guards.rs index 056d2b7f6a..167b993013 100644 --- a/codex-rs/core/src/agent/guards.rs +++ b/codex-rs/core/src/agent/guards.rs @@ -222,249 +222,5 @@ impl Drop for SpawnReservation { } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - use std::collections::HashSet; - - #[test] - fn format_agent_nickname_adds_ordinals_after_reset() { - assert_eq!(format_agent_nickname("Plato", 0), "Plato"); - assert_eq!(format_agent_nickname("Plato", 1), "Plato the 2nd"); - assert_eq!(format_agent_nickname("Plato", 2), "Plato the 3rd"); - assert_eq!(format_agent_nickname("Plato", 10), "Plato the 11th"); - assert_eq!(format_agent_nickname("Plato", 20), "Plato the 21st"); - } - - #[test] - fn session_depth_defaults_to_zero_for_root_sources() { - assert_eq!(session_depth(&SessionSource::Cli), 0); - } - - #[test] - fn thread_spawn_depth_increments_and_enforces_limit() { - let session_source = SessionSource::SubAgent(SubAgentSource::ThreadSpawn { - parent_thread_id: ThreadId::new(), - depth: 1, - agent_nickname: None, - agent_role: None, - }); - let child_depth = next_thread_spawn_depth(&session_source); - assert_eq!(child_depth, 2); - assert!(exceeds_thread_spawn_depth_limit(child_depth, 1)); - } - - #[test] - fn non_thread_spawn_subagents_default_to_depth_zero() { - let session_source = SessionSource::SubAgent(SubAgentSource::Review); - assert_eq!(session_depth(&session_source), 0); - assert_eq!(next_thread_spawn_depth(&session_source), 1); - assert!(!exceeds_thread_spawn_depth_limit(1, 1)); - } - - #[test] - fn reservation_drop_releases_slot() { - let guards = Arc::new(Guards::default()); - let reservation = guards.reserve_spawn_slot(Some(1)).expect("reserve slot"); - drop(reservation); - - let reservation = guards.reserve_spawn_slot(Some(1)).expect("slot released"); - drop(reservation); - } - - #[test] - fn commit_holds_slot_until_release() { - let guards = Arc::new(Guards::default()); - let reservation = guards.reserve_spawn_slot(Some(1)).expect("reserve slot"); - let thread_id = ThreadId::new(); - reservation.commit(thread_id); - - let err = match guards.reserve_spawn_slot(Some(1)) { - Ok(_) => panic!("limit should be enforced"), - Err(err) => err, - }; - let CodexErr::AgentLimitReached { max_threads } = err else { - panic!("expected CodexErr::AgentLimitReached"); - }; - assert_eq!(max_threads, 1); - - guards.release_spawned_thread(thread_id); - let reservation = guards - .reserve_spawn_slot(Some(1)) - .expect("slot released after thread removal"); - drop(reservation); - } - - #[test] - fn release_ignores_unknown_thread_id() { - let guards = Arc::new(Guards::default()); - let reservation = guards.reserve_spawn_slot(Some(1)).expect("reserve slot"); - let thread_id = ThreadId::new(); - reservation.commit(thread_id); - - guards.release_spawned_thread(ThreadId::new()); - - let err = match guards.reserve_spawn_slot(Some(1)) { - Ok(_) => panic!("limit should still be enforced"), - Err(err) => err, - }; - let CodexErr::AgentLimitReached { max_threads } = err else { - panic!("expected CodexErr::AgentLimitReached"); - }; - assert_eq!(max_threads, 1); - - guards.release_spawned_thread(thread_id); - let reservation = guards - .reserve_spawn_slot(Some(1)) - .expect("slot released after real thread removal"); - drop(reservation); - } - - #[test] - fn release_is_idempotent_for_registered_threads() { - let guards = Arc::new(Guards::default()); - let reservation = guards.reserve_spawn_slot(Some(1)).expect("reserve slot"); - let first_id = ThreadId::new(); - reservation.commit(first_id); - - guards.release_spawned_thread(first_id); - - let reservation = guards.reserve_spawn_slot(Some(1)).expect("slot reused"); - let second_id = ThreadId::new(); - reservation.commit(second_id); - - guards.release_spawned_thread(first_id); - - let err = match guards.reserve_spawn_slot(Some(1)) { - Ok(_) => panic!("limit should still be enforced"), - Err(err) => err, - }; - let CodexErr::AgentLimitReached { max_threads } = err else { - panic!("expected CodexErr::AgentLimitReached"); - }; - assert_eq!(max_threads, 1); - - guards.release_spawned_thread(second_id); - let reservation = guards - .reserve_spawn_slot(Some(1)) - .expect("slot released after second thread removal"); - drop(reservation); - } - - #[test] - fn failed_spawn_keeps_nickname_marked_used() { - let guards = Arc::new(Guards::default()); - let mut reservation = guards.reserve_spawn_slot(None).expect("reserve slot"); - let agent_nickname = reservation - .reserve_agent_nickname(&["alpha"]) - .expect("reserve agent name"); - assert_eq!(agent_nickname, "alpha"); - drop(reservation); - - let mut reservation = guards.reserve_spawn_slot(None).expect("reserve slot"); - let agent_nickname = reservation - .reserve_agent_nickname(&["alpha", "beta"]) - .expect("unused name should still be preferred"); - assert_eq!(agent_nickname, "beta"); - } - - #[test] - fn agent_nickname_resets_used_pool_when_exhausted() { - let guards = Arc::new(Guards::default()); - let mut first = guards.reserve_spawn_slot(None).expect("reserve first slot"); - let first_name = first - .reserve_agent_nickname(&["alpha"]) - .expect("reserve first agent name"); - let first_id = ThreadId::new(); - first.commit(first_id); - assert_eq!(first_name, "alpha"); - - let mut second = guards - .reserve_spawn_slot(None) - .expect("reserve second slot"); - let second_name = second - .reserve_agent_nickname(&["alpha"]) - .expect("name should be reused after pool reset"); - assert_eq!(second_name, "alpha the 2nd"); - let active_agents = guards - .active_agents - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - assert_eq!(active_agents.nickname_reset_count, 1); - } - - #[test] - fn released_nickname_stays_used_until_pool_reset() { - let guards = Arc::new(Guards::default()); - - let mut first = guards.reserve_spawn_slot(None).expect("reserve first slot"); - let first_name = first - .reserve_agent_nickname(&["alpha"]) - .expect("reserve first agent name"); - let first_id = ThreadId::new(); - first.commit(first_id); - assert_eq!(first_name, "alpha"); - - guards.release_spawned_thread(first_id); - - let mut second = guards - .reserve_spawn_slot(None) - .expect("reserve second slot"); - let second_name = second - .reserve_agent_nickname(&["alpha", "beta"]) - .expect("released name should still be marked used"); - assert_eq!(second_name, "beta"); - let second_id = ThreadId::new(); - second.commit(second_id); - guards.release_spawned_thread(second_id); - - let mut third = guards.reserve_spawn_slot(None).expect("reserve third slot"); - let third_name = third - .reserve_agent_nickname(&["alpha", "beta"]) - .expect("pool reset should permit a duplicate"); - let expected_names = - HashSet::from(["alpha the 2nd".to_string(), "beta the 2nd".to_string()]); - assert!(expected_names.contains(&third_name)); - let active_agents = guards - .active_agents - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - assert_eq!(active_agents.nickname_reset_count, 1); - } - - #[test] - fn repeated_resets_advance_the_ordinal_suffix() { - let guards = Arc::new(Guards::default()); - - let mut first = guards.reserve_spawn_slot(None).expect("reserve first slot"); - let first_name = first - .reserve_agent_nickname(&["Plato"]) - .expect("reserve first agent name"); - let first_id = ThreadId::new(); - first.commit(first_id); - assert_eq!(first_name, "Plato"); - guards.release_spawned_thread(first_id); - - let mut second = guards - .reserve_spawn_slot(None) - .expect("reserve second slot"); - let second_name = second - .reserve_agent_nickname(&["Plato"]) - .expect("reserve second agent name"); - let second_id = ThreadId::new(); - second.commit(second_id); - assert_eq!(second_name, "Plato the 2nd"); - guards.release_spawned_thread(second_id); - - let mut third = guards.reserve_spawn_slot(None).expect("reserve third slot"); - let third_name = third - .reserve_agent_nickname(&["Plato"]) - .expect("reserve third agent name"); - assert_eq!(third_name, "Plato the 3rd"); - let active_agents = guards - .active_agents - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - assert_eq!(active_agents.nickname_reset_count, 2); - } -} +#[path = "guards_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/agent/guards_tests.rs b/codex-rs/core/src/agent/guards_tests.rs new file mode 100644 index 0000000000..53bb5f3b30 --- /dev/null +++ b/codex-rs/core/src/agent/guards_tests.rs @@ -0,0 +1,243 @@ +use super::*; +use pretty_assertions::assert_eq; +use std::collections::HashSet; + +#[test] +fn format_agent_nickname_adds_ordinals_after_reset() { + assert_eq!(format_agent_nickname("Plato", 0), "Plato"); + assert_eq!(format_agent_nickname("Plato", 1), "Plato the 2nd"); + assert_eq!(format_agent_nickname("Plato", 2), "Plato the 3rd"); + assert_eq!(format_agent_nickname("Plato", 10), "Plato the 11th"); + assert_eq!(format_agent_nickname("Plato", 20), "Plato the 21st"); +} + +#[test] +fn session_depth_defaults_to_zero_for_root_sources() { + assert_eq!(session_depth(&SessionSource::Cli), 0); +} + +#[test] +fn thread_spawn_depth_increments_and_enforces_limit() { + let session_source = SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id: ThreadId::new(), + depth: 1, + agent_nickname: None, + agent_role: None, + }); + let child_depth = next_thread_spawn_depth(&session_source); + assert_eq!(child_depth, 2); + assert!(exceeds_thread_spawn_depth_limit(child_depth, 1)); +} + +#[test] +fn non_thread_spawn_subagents_default_to_depth_zero() { + let session_source = SessionSource::SubAgent(SubAgentSource::Review); + assert_eq!(session_depth(&session_source), 0); + assert_eq!(next_thread_spawn_depth(&session_source), 1); + assert!(!exceeds_thread_spawn_depth_limit(1, 1)); +} + +#[test] +fn reservation_drop_releases_slot() { + let guards = Arc::new(Guards::default()); + let reservation = guards.reserve_spawn_slot(Some(1)).expect("reserve slot"); + drop(reservation); + + let reservation = guards.reserve_spawn_slot(Some(1)).expect("slot released"); + drop(reservation); +} + +#[test] +fn commit_holds_slot_until_release() { + let guards = Arc::new(Guards::default()); + let reservation = guards.reserve_spawn_slot(Some(1)).expect("reserve slot"); + let thread_id = ThreadId::new(); + reservation.commit(thread_id); + + let err = match guards.reserve_spawn_slot(Some(1)) { + Ok(_) => panic!("limit should be enforced"), + Err(err) => err, + }; + let CodexErr::AgentLimitReached { max_threads } = err else { + panic!("expected CodexErr::AgentLimitReached"); + }; + assert_eq!(max_threads, 1); + + guards.release_spawned_thread(thread_id); + let reservation = guards + .reserve_spawn_slot(Some(1)) + .expect("slot released after thread removal"); + drop(reservation); +} + +#[test] +fn release_ignores_unknown_thread_id() { + let guards = Arc::new(Guards::default()); + let reservation = guards.reserve_spawn_slot(Some(1)).expect("reserve slot"); + let thread_id = ThreadId::new(); + reservation.commit(thread_id); + + guards.release_spawned_thread(ThreadId::new()); + + let err = match guards.reserve_spawn_slot(Some(1)) { + Ok(_) => panic!("limit should still be enforced"), + Err(err) => err, + }; + let CodexErr::AgentLimitReached { max_threads } = err else { + panic!("expected CodexErr::AgentLimitReached"); + }; + assert_eq!(max_threads, 1); + + guards.release_spawned_thread(thread_id); + let reservation = guards + .reserve_spawn_slot(Some(1)) + .expect("slot released after real thread removal"); + drop(reservation); +} + +#[test] +fn release_is_idempotent_for_registered_threads() { + let guards = Arc::new(Guards::default()); + let reservation = guards.reserve_spawn_slot(Some(1)).expect("reserve slot"); + let first_id = ThreadId::new(); + reservation.commit(first_id); + + guards.release_spawned_thread(first_id); + + let reservation = guards.reserve_spawn_slot(Some(1)).expect("slot reused"); + let second_id = ThreadId::new(); + reservation.commit(second_id); + + guards.release_spawned_thread(first_id); + + let err = match guards.reserve_spawn_slot(Some(1)) { + Ok(_) => panic!("limit should still be enforced"), + Err(err) => err, + }; + let CodexErr::AgentLimitReached { max_threads } = err else { + panic!("expected CodexErr::AgentLimitReached"); + }; + assert_eq!(max_threads, 1); + + guards.release_spawned_thread(second_id); + let reservation = guards + .reserve_spawn_slot(Some(1)) + .expect("slot released after second thread removal"); + drop(reservation); +} + +#[test] +fn failed_spawn_keeps_nickname_marked_used() { + let guards = Arc::new(Guards::default()); + let mut reservation = guards.reserve_spawn_slot(None).expect("reserve slot"); + let agent_nickname = reservation + .reserve_agent_nickname(&["alpha"]) + .expect("reserve agent name"); + assert_eq!(agent_nickname, "alpha"); + drop(reservation); + + let mut reservation = guards.reserve_spawn_slot(None).expect("reserve slot"); + let agent_nickname = reservation + .reserve_agent_nickname(&["alpha", "beta"]) + .expect("unused name should still be preferred"); + assert_eq!(agent_nickname, "beta"); +} + +#[test] +fn agent_nickname_resets_used_pool_when_exhausted() { + let guards = Arc::new(Guards::default()); + let mut first = guards.reserve_spawn_slot(None).expect("reserve first slot"); + let first_name = first + .reserve_agent_nickname(&["alpha"]) + .expect("reserve first agent name"); + let first_id = ThreadId::new(); + first.commit(first_id); + assert_eq!(first_name, "alpha"); + + let mut second = guards + .reserve_spawn_slot(None) + .expect("reserve second slot"); + let second_name = second + .reserve_agent_nickname(&["alpha"]) + .expect("name should be reused after pool reset"); + assert_eq!(second_name, "alpha the 2nd"); + let active_agents = guards + .active_agents + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + assert_eq!(active_agents.nickname_reset_count, 1); +} + +#[test] +fn released_nickname_stays_used_until_pool_reset() { + let guards = Arc::new(Guards::default()); + + let mut first = guards.reserve_spawn_slot(None).expect("reserve first slot"); + let first_name = first + .reserve_agent_nickname(&["alpha"]) + .expect("reserve first agent name"); + let first_id = ThreadId::new(); + first.commit(first_id); + assert_eq!(first_name, "alpha"); + + guards.release_spawned_thread(first_id); + + let mut second = guards + .reserve_spawn_slot(None) + .expect("reserve second slot"); + let second_name = second + .reserve_agent_nickname(&["alpha", "beta"]) + .expect("released name should still be marked used"); + assert_eq!(second_name, "beta"); + let second_id = ThreadId::new(); + second.commit(second_id); + guards.release_spawned_thread(second_id); + + let mut third = guards.reserve_spawn_slot(None).expect("reserve third slot"); + let third_name = third + .reserve_agent_nickname(&["alpha", "beta"]) + .expect("pool reset should permit a duplicate"); + let expected_names = HashSet::from(["alpha the 2nd".to_string(), "beta the 2nd".to_string()]); + assert!(expected_names.contains(&third_name)); + let active_agents = guards + .active_agents + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + assert_eq!(active_agents.nickname_reset_count, 1); +} + +#[test] +fn repeated_resets_advance_the_ordinal_suffix() { + let guards = Arc::new(Guards::default()); + + let mut first = guards.reserve_spawn_slot(None).expect("reserve first slot"); + let first_name = first + .reserve_agent_nickname(&["Plato"]) + .expect("reserve first agent name"); + let first_id = ThreadId::new(); + first.commit(first_id); + assert_eq!(first_name, "Plato"); + guards.release_spawned_thread(first_id); + + let mut second = guards + .reserve_spawn_slot(None) + .expect("reserve second slot"); + let second_name = second + .reserve_agent_nickname(&["Plato"]) + .expect("reserve second agent name"); + let second_id = ThreadId::new(); + second.commit(second_id); + assert_eq!(second_name, "Plato the 2nd"); + guards.release_spawned_thread(second_id); + + let mut third = guards.reserve_spawn_slot(None).expect("reserve third slot"); + let third_name = third + .reserve_agent_nickname(&["Plato"]) + .expect("reserve third agent name"); + assert_eq!(third_name, "Plato the 3rd"); + let active_agents = guards + .active_agents + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + assert_eq!(active_agents.nickname_reset_count, 2); +} diff --git a/codex-rs/core/src/agent/role.rs b/codex-rs/core/src/agent/role.rs index 23b60583e7..8d607c5436 100644 --- a/codex-rs/core/src/agent/role.rs +++ b/codex-rs/core/src/agent/role.rs @@ -309,685 +309,5 @@ Rules: } #[cfg(test)] -mod tests { - use super::*; - use crate::config::CONFIG_TOML_FILE; - use crate::config::ConfigBuilder; - use crate::config_loader::ConfigLayerStackOrdering; - use crate::plugins::PluginsManager; - use crate::skills::SkillsManager; - use codex_protocol::openai_models::ReasoningEffort; - use pretty_assertions::assert_eq; - use std::fs; - use std::path::PathBuf; - use std::sync::Arc; - use tempfile::TempDir; - - async fn test_config_with_cli_overrides( - cli_overrides: Vec<(String, TomlValue)>, - ) -> (TempDir, Config) { - let home = TempDir::new().expect("create temp dir"); - let home_path = home.path().to_path_buf(); - let config = ConfigBuilder::default() - .codex_home(home_path.clone()) - .cli_overrides(cli_overrides) - .fallback_cwd(Some(home_path)) - .build() - .await - .expect("load test config"); - (home, config) - } - - async fn write_role_config(home: &TempDir, name: &str, contents: &str) -> PathBuf { - let role_path = home.path().join(name); - tokio::fs::write(&role_path, contents) - .await - .expect("write role config"); - role_path - } - - fn session_flags_layer_count(config: &Config) -> usize { - config - .config_layer_stack - .get_layers(ConfigLayerStackOrdering::LowestPrecedenceFirst, true) - .into_iter() - .filter(|layer| layer.name == ConfigLayerSource::SessionFlags) - .count() - } - - #[tokio::test] - async fn apply_role_defaults_to_default_and_leaves_config_unchanged() { - let (_home, mut config) = test_config_with_cli_overrides(Vec::new()).await; - let before = config.clone(); - - apply_role_to_config(&mut config, None) - .await - .expect("default role should apply"); - - assert_eq!(before, config); - } - - #[tokio::test] - async fn apply_role_returns_error_for_unknown_role() { - let (_home, mut config) = test_config_with_cli_overrides(Vec::new()).await; - - let err = apply_role_to_config(&mut config, Some("missing-role")) - .await - .expect_err("unknown role should fail"); - - assert_eq!(err, "unknown agent_type 'missing-role'"); - } - - #[tokio::test] - #[ignore = "No role requiring it for now"] - async fn apply_explorer_role_sets_model_and_adds_session_flags_layer() { - let (_home, mut config) = test_config_with_cli_overrides(Vec::new()).await; - let before_layers = session_flags_layer_count(&config); - - apply_role_to_config(&mut config, Some("explorer")) - .await - .expect("explorer role should apply"); - - assert_eq!(config.model.as_deref(), Some("gpt-5.1-codex-mini")); - assert_eq!(config.model_reasoning_effort, Some(ReasoningEffort::Medium)); - assert_eq!(session_flags_layer_count(&config), before_layers + 1); - } - - #[tokio::test] - async fn apply_role_returns_unavailable_for_missing_user_role_file() { - let (_home, mut config) = test_config_with_cli_overrides(Vec::new()).await; - config.agent_roles.insert( - "custom".to_string(), - AgentRoleConfig { - description: None, - config_file: Some(PathBuf::from("/path/does/not/exist.toml")), - nickname_candidates: None, - }, - ); - - let err = apply_role_to_config(&mut config, Some("custom")) - .await - .expect_err("missing role file should fail"); - - assert_eq!(err, AGENT_TYPE_UNAVAILABLE_ERROR); - } - - #[tokio::test] - async fn apply_role_returns_unavailable_for_invalid_user_role_toml() { - let (home, mut config) = test_config_with_cli_overrides(Vec::new()).await; - let role_path = write_role_config(&home, "invalid-role.toml", "model = [").await; - config.agent_roles.insert( - "custom".to_string(), - AgentRoleConfig { - description: None, - config_file: Some(role_path), - nickname_candidates: None, - }, - ); - - let err = apply_role_to_config(&mut config, Some("custom")) - .await - .expect_err("invalid role file should fail"); - - assert_eq!(err, AGENT_TYPE_UNAVAILABLE_ERROR); - } - - #[tokio::test] - async fn apply_role_ignores_agent_metadata_fields_in_user_role_file() { - let (home, mut config) = test_config_with_cli_overrides(Vec::new()).await; - let role_path = write_role_config( - &home, - "metadata-role.toml", - r#" -name = "archivist" -description = "Role metadata" -nickname_candidates = ["Hypatia"] -developer_instructions = "Stay focused" -model = "role-model" -"#, - ) - .await; - config.agent_roles.insert( - "custom".to_string(), - AgentRoleConfig { - description: None, - config_file: Some(role_path), - nickname_candidates: None, - }, - ); - - apply_role_to_config(&mut config, Some("custom")) - .await - .expect("custom role should apply"); - - assert_eq!(config.model.as_deref(), Some("role-model")); - } - - #[tokio::test] - async fn apply_role_preserves_unspecified_keys() { - let (home, mut config) = test_config_with_cli_overrides(vec![( - "model".to_string(), - TomlValue::String("base-model".to_string()), - )]) - .await; - config.codex_linux_sandbox_exe = Some(PathBuf::from("/tmp/codex-linux-sandbox")); - config.main_execve_wrapper_exe = Some(PathBuf::from("/tmp/codex-execve-wrapper")); - let role_path = write_role_config( - &home, - "effort-only.toml", - "developer_instructions = \"Stay focused\"\nmodel_reasoning_effort = \"high\"", - ) - .await; - config.agent_roles.insert( - "custom".to_string(), - AgentRoleConfig { - description: None, - config_file: Some(role_path), - nickname_candidates: None, - }, - ); - - apply_role_to_config(&mut config, Some("custom")) - .await - .expect("custom role should apply"); - - assert_eq!(config.model.as_deref(), Some("base-model")); - assert_eq!(config.model_reasoning_effort, Some(ReasoningEffort::High)); - assert_eq!( - config.codex_linux_sandbox_exe, - Some(PathBuf::from("/tmp/codex-linux-sandbox")) - ); - assert_eq!( - config.main_execve_wrapper_exe, - Some(PathBuf::from("/tmp/codex-execve-wrapper")) - ); - } - - #[tokio::test] - async fn apply_role_preserves_active_profile_and_model_provider() { - let home = TempDir::new().expect("create temp dir"); - tokio::fs::write( - home.path().join(CONFIG_TOML_FILE), - r#" -[model_providers.test-provider] -name = "Test Provider" -base_url = "https://example.com/v1" -env_key = "TEST_PROVIDER_API_KEY" -wire_api = "responses" - -[profiles.test-profile] -model_provider = "test-provider" -"#, - ) - .await - .expect("write config.toml"); - let mut config = ConfigBuilder::default() - .codex_home(home.path().to_path_buf()) - .harness_overrides(ConfigOverrides { - config_profile: Some("test-profile".to_string()), - ..Default::default() - }) - .fallback_cwd(Some(home.path().to_path_buf())) - .build() - .await - .expect("load config"); - let role_path = write_role_config( - &home, - "empty-role.toml", - "developer_instructions = \"Stay focused\"", - ) - .await; - config.agent_roles.insert( - "custom".to_string(), - AgentRoleConfig { - description: None, - config_file: Some(role_path), - nickname_candidates: None, - }, - ); - - apply_role_to_config(&mut config, Some("custom")) - .await - .expect("custom role should apply"); - - assert_eq!(config.active_profile.as_deref(), Some("test-profile")); - assert_eq!(config.model_provider_id, "test-provider"); - assert_eq!(config.model_provider.name, "Test Provider"); - } - - #[tokio::test] - async fn apply_role_uses_role_profile_instead_of_current_profile() { - let home = TempDir::new().expect("create temp dir"); - tokio::fs::write( - home.path().join(CONFIG_TOML_FILE), - r#" -[model_providers.base-provider] -name = "Base Provider" -base_url = "https://base.example.com/v1" -env_key = "BASE_PROVIDER_API_KEY" -wire_api = "responses" - -[model_providers.role-provider] -name = "Role Provider" -base_url = "https://role.example.com/v1" -env_key = "ROLE_PROVIDER_API_KEY" -wire_api = "responses" - -[profiles.base-profile] -model_provider = "base-provider" - -[profiles.role-profile] -model_provider = "role-provider" -"#, - ) - .await - .expect("write config.toml"); - let mut config = ConfigBuilder::default() - .codex_home(home.path().to_path_buf()) - .harness_overrides(ConfigOverrides { - config_profile: Some("base-profile".to_string()), - ..Default::default() - }) - .fallback_cwd(Some(home.path().to_path_buf())) - .build() - .await - .expect("load config"); - let role_path = write_role_config( - &home, - "profile-role.toml", - "developer_instructions = \"Stay focused\"\nprofile = \"role-profile\"", - ) - .await; - config.agent_roles.insert( - "custom".to_string(), - AgentRoleConfig { - description: None, - config_file: Some(role_path), - nickname_candidates: None, - }, - ); - - apply_role_to_config(&mut config, Some("custom")) - .await - .expect("custom role should apply"); - - assert_eq!(config.active_profile.as_deref(), Some("role-profile")); - assert_eq!(config.model_provider_id, "role-provider"); - assert_eq!(config.model_provider.name, "Role Provider"); - } - - #[tokio::test] - async fn apply_role_uses_role_model_provider_instead_of_current_profile_provider() { - let home = TempDir::new().expect("create temp dir"); - tokio::fs::write( - home.path().join(CONFIG_TOML_FILE), - r#" -[model_providers.base-provider] -name = "Base Provider" -base_url = "https://base.example.com/v1" -env_key = "BASE_PROVIDER_API_KEY" -wire_api = "responses" - -[model_providers.role-provider] -name = "Role Provider" -base_url = "https://role.example.com/v1" -env_key = "ROLE_PROVIDER_API_KEY" -wire_api = "responses" - -[profiles.base-profile] -model_provider = "base-provider" -"#, - ) - .await - .expect("write config.toml"); - let mut config = ConfigBuilder::default() - .codex_home(home.path().to_path_buf()) - .harness_overrides(ConfigOverrides { - config_profile: Some("base-profile".to_string()), - ..Default::default() - }) - .fallback_cwd(Some(home.path().to_path_buf())) - .build() - .await - .expect("load config"); - let role_path = write_role_config( - &home, - "provider-role.toml", - "developer_instructions = \"Stay focused\"\nmodel_provider = \"role-provider\"", - ) - .await; - config.agent_roles.insert( - "custom".to_string(), - AgentRoleConfig { - description: None, - config_file: Some(role_path), - nickname_candidates: None, - }, - ); - - apply_role_to_config(&mut config, Some("custom")) - .await - .expect("custom role should apply"); - - assert_eq!(config.active_profile, None); - assert_eq!(config.model_provider_id, "role-provider"); - assert_eq!(config.model_provider.name, "Role Provider"); - } - - #[tokio::test] - async fn apply_role_uses_active_profile_model_provider_update() { - let home = TempDir::new().expect("create temp dir"); - tokio::fs::write( - home.path().join(CONFIG_TOML_FILE), - r#" -[model_providers.base-provider] -name = "Base Provider" -base_url = "https://base.example.com/v1" -env_key = "BASE_PROVIDER_API_KEY" -wire_api = "responses" - -[model_providers.role-provider] -name = "Role Provider" -base_url = "https://role.example.com/v1" -env_key = "ROLE_PROVIDER_API_KEY" -wire_api = "responses" - -[profiles.base-profile] -model_provider = "base-provider" -model_reasoning_effort = "low" -"#, - ) - .await - .expect("write config.toml"); - let mut config = ConfigBuilder::default() - .codex_home(home.path().to_path_buf()) - .harness_overrides(ConfigOverrides { - config_profile: Some("base-profile".to_string()), - ..Default::default() - }) - .fallback_cwd(Some(home.path().to_path_buf())) - .build() - .await - .expect("load config"); - let role_path = write_role_config( - &home, - "profile-edit-role.toml", - r#"developer_instructions = "Stay focused" - -[profiles.base-profile] -model_provider = "role-provider" -model_reasoning_effort = "high" -"#, - ) - .await; - config.agent_roles.insert( - "custom".to_string(), - AgentRoleConfig { - description: None, - config_file: Some(role_path), - nickname_candidates: None, - }, - ); - - apply_role_to_config(&mut config, Some("custom")) - .await - .expect("custom role should apply"); - - assert_eq!(config.active_profile.as_deref(), Some("base-profile")); - assert_eq!(config.model_provider_id, "role-provider"); - assert_eq!(config.model_provider.name, "Role Provider"); - assert_eq!(config.model_reasoning_effort, Some(ReasoningEffort::High)); - } - - #[tokio::test] - #[cfg(not(windows))] - async fn apply_role_does_not_materialize_default_sandbox_workspace_write_fields() { - use codex_protocol::protocol::SandboxPolicy; - let (home, mut config) = test_config_with_cli_overrides(vec![ - ( - "sandbox_mode".to_string(), - TomlValue::String("workspace-write".to_string()), - ), - ( - "sandbox_workspace_write.network_access".to_string(), - TomlValue::Boolean(true), - ), - ]) - .await; - let role_path = write_role_config( - &home, - "sandbox-role.toml", - r#"developer_instructions = "Stay focused" - -[sandbox_workspace_write] -writable_roots = ["./sandbox-root"] -"#, - ) - .await; - config.agent_roles.insert( - "custom".to_string(), - AgentRoleConfig { - description: None, - config_file: Some(role_path), - nickname_candidates: None, - }, - ); - - apply_role_to_config(&mut config, Some("custom")) - .await - .expect("custom role should apply"); - - let role_layer = config - .config_layer_stack - .get_layers(ConfigLayerStackOrdering::LowestPrecedenceFirst, true) - .into_iter() - .rfind(|layer| layer.name == ConfigLayerSource::SessionFlags) - .expect("expected a session flags layer"); - let sandbox_workspace_write = role_layer - .config - .get("sandbox_workspace_write") - .and_then(TomlValue::as_table) - .expect("role layer should include sandbox_workspace_write"); - assert_eq!( - sandbox_workspace_write.contains_key("network_access"), - false - ); - assert_eq!( - sandbox_workspace_write.contains_key("exclude_tmpdir_env_var"), - false - ); - assert_eq!( - sandbox_workspace_write.contains_key("exclude_slash_tmp"), - false - ); - - match &*config.permissions.sandbox_policy { - SandboxPolicy::WorkspaceWrite { network_access, .. } => { - assert_eq!(*network_access, true); - } - other => panic!("expected workspace-write sandbox policy, got {other:?}"), - } - } - - #[tokio::test] - async fn apply_role_takes_precedence_over_existing_session_flags_for_same_key() { - let (home, mut config) = test_config_with_cli_overrides(vec![( - "model".to_string(), - TomlValue::String("cli-model".to_string()), - )]) - .await; - let before_layers = session_flags_layer_count(&config); - let role_path = write_role_config( - &home, - "model-role.toml", - "developer_instructions = \"Stay focused\"\nmodel = \"role-model\"", - ) - .await; - config.agent_roles.insert( - "custom".to_string(), - AgentRoleConfig { - description: None, - config_file: Some(role_path), - nickname_candidates: None, - }, - ); - - apply_role_to_config(&mut config, Some("custom")) - .await - .expect("custom role should apply"); - - assert_eq!(config.model.as_deref(), Some("role-model")); - assert_eq!(session_flags_layer_count(&config), before_layers + 1); - } - - #[cfg_attr(windows, ignore)] - #[tokio::test] - async fn apply_role_skills_config_disables_skill_for_spawned_agent() { - let (home, mut config) = test_config_with_cli_overrides(Vec::new()).await; - let skill_dir = home.path().join("skills").join("demo"); - fs::create_dir_all(&skill_dir).expect("create skill dir"); - let skill_path = skill_dir.join("SKILL.md"); - fs::write( - &skill_path, - "---\nname: demo-skill\ndescription: demo description\n---\n\n# Body\n", - ) - .expect("write skill"); - let role_path = write_role_config( - &home, - "skills-role.toml", - &format!( - r#"developer_instructions = "Stay focused" - -[[skills.config]] -path = "{}" -enabled = false -"#, - skill_path.display() - ), - ) - .await; - config.agent_roles.insert( - "custom".to_string(), - AgentRoleConfig { - description: None, - config_file: Some(role_path), - nickname_candidates: None, - }, - ); - - apply_role_to_config(&mut config, Some("custom")) - .await - .expect("custom role should apply"); - - let plugins_manager = Arc::new(PluginsManager::new(home.path().to_path_buf())); - let skills_manager = SkillsManager::new(home.path().to_path_buf(), plugins_manager, true); - let outcome = skills_manager.skills_for_config(&config); - let skill = outcome - .skills - .iter() - .find(|skill| skill.name == "demo-skill") - .expect("demo skill should be discovered"); - - assert_eq!(outcome.is_skill_enabled(skill), false); - } - - #[test] - fn spawn_tool_spec_build_deduplicates_user_defined_built_in_roles() { - let user_defined_roles = BTreeMap::from([ - ( - "explorer".to_string(), - AgentRoleConfig { - description: Some("user override".to_string()), - config_file: None, - nickname_candidates: None, - }, - ), - ("researcher".to_string(), AgentRoleConfig::default()), - ]); - - let spec = spawn_tool_spec::build(&user_defined_roles); - - assert!(spec.contains("researcher: no description")); - assert!(spec.contains("explorer: {\nuser override\n}")); - assert!(spec.contains("default: {\nDefault agent.\n}")); - assert!(!spec.contains("Explorers are fast and authoritative.")); - } - - #[test] - fn spawn_tool_spec_lists_user_defined_roles_before_built_ins() { - let user_defined_roles = BTreeMap::from([( - "aaa".to_string(), - AgentRoleConfig { - description: Some("first".to_string()), - config_file: None, - nickname_candidates: None, - }, - )]); - - let spec = spawn_tool_spec::build(&user_defined_roles); - let user_index = spec.find("aaa: {\nfirst\n}").expect("find user role"); - let built_in_index = spec - .find("default: {\nDefault agent.\n}") - .expect("find built-in role"); - - assert!(user_index < built_in_index); - } - - #[test] - fn spawn_tool_spec_marks_role_locked_model_and_reasoning_effort() { - let tempdir = TempDir::new().expect("create temp dir"); - let role_path = tempdir.path().join("researcher.toml"); - fs::write( - &role_path, - "developer_instructions = \"Research carefully\"\nmodel = \"gpt-5\"\nmodel_reasoning_effort = \"high\"\n", - ) - .expect("write role config"); - let user_defined_roles = BTreeMap::from([( - "researcher".to_string(), - AgentRoleConfig { - description: Some("Research carefully.".to_string()), - config_file: Some(role_path), - nickname_candidates: None, - }, - )]); - - let spec = spawn_tool_spec::build(&user_defined_roles); - - assert!(spec.contains( - "Research carefully.\n- This role's model is set to `gpt-5` and its reasoning effort is set to `high`. These settings cannot be changed." - )); - } - - #[test] - fn spawn_tool_spec_marks_role_locked_reasoning_effort_only() { - let tempdir = TempDir::new().expect("create temp dir"); - let role_path = tempdir.path().join("reviewer.toml"); - fs::write( - &role_path, - "developer_instructions = \"Review carefully\"\nmodel_reasoning_effort = \"medium\"\n", - ) - .expect("write role config"); - let user_defined_roles = BTreeMap::from([( - "reviewer".to_string(), - AgentRoleConfig { - description: Some("Review carefully.".to_string()), - config_file: Some(role_path), - nickname_candidates: None, - }, - )]); - - let spec = spawn_tool_spec::build(&user_defined_roles); - - assert!(spec.contains( - "Review carefully.\n- This role's reasoning effort is set to `medium` and cannot be changed." - )); - } - - #[test] - fn built_in_config_file_contents_resolves_explorer_only() { - assert_eq!( - built_in::config_file_contents(Path::new("missing.toml")), - None - ); - } -} +#[path = "role_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/agent/role_tests.rs b/codex-rs/core/src/agent/role_tests.rs new file mode 100644 index 0000000000..cb04aa4e8f --- /dev/null +++ b/codex-rs/core/src/agent/role_tests.rs @@ -0,0 +1,680 @@ +use super::*; +use crate::config::CONFIG_TOML_FILE; +use crate::config::ConfigBuilder; +use crate::config_loader::ConfigLayerStackOrdering; +use crate::plugins::PluginsManager; +use crate::skills::SkillsManager; +use codex_protocol::openai_models::ReasoningEffort; +use pretty_assertions::assert_eq; +use std::fs; +use std::path::PathBuf; +use std::sync::Arc; +use tempfile::TempDir; + +async fn test_config_with_cli_overrides( + cli_overrides: Vec<(String, TomlValue)>, +) -> (TempDir, Config) { + let home = TempDir::new().expect("create temp dir"); + let home_path = home.path().to_path_buf(); + let config = ConfigBuilder::default() + .codex_home(home_path.clone()) + .cli_overrides(cli_overrides) + .fallback_cwd(Some(home_path)) + .build() + .await + .expect("load test config"); + (home, config) +} + +async fn write_role_config(home: &TempDir, name: &str, contents: &str) -> PathBuf { + let role_path = home.path().join(name); + tokio::fs::write(&role_path, contents) + .await + .expect("write role config"); + role_path +} + +fn session_flags_layer_count(config: &Config) -> usize { + config + .config_layer_stack + .get_layers(ConfigLayerStackOrdering::LowestPrecedenceFirst, true) + .into_iter() + .filter(|layer| layer.name == ConfigLayerSource::SessionFlags) + .count() +} + +#[tokio::test] +async fn apply_role_defaults_to_default_and_leaves_config_unchanged() { + let (_home, mut config) = test_config_with_cli_overrides(Vec::new()).await; + let before = config.clone(); + + apply_role_to_config(&mut config, None) + .await + .expect("default role should apply"); + + assert_eq!(before, config); +} + +#[tokio::test] +async fn apply_role_returns_error_for_unknown_role() { + let (_home, mut config) = test_config_with_cli_overrides(Vec::new()).await; + + let err = apply_role_to_config(&mut config, Some("missing-role")) + .await + .expect_err("unknown role should fail"); + + assert_eq!(err, "unknown agent_type 'missing-role'"); +} + +#[tokio::test] +#[ignore = "No role requiring it for now"] +async fn apply_explorer_role_sets_model_and_adds_session_flags_layer() { + let (_home, mut config) = test_config_with_cli_overrides(Vec::new()).await; + let before_layers = session_flags_layer_count(&config); + + apply_role_to_config(&mut config, Some("explorer")) + .await + .expect("explorer role should apply"); + + assert_eq!(config.model.as_deref(), Some("gpt-5.1-codex-mini")); + assert_eq!(config.model_reasoning_effort, Some(ReasoningEffort::Medium)); + assert_eq!(session_flags_layer_count(&config), before_layers + 1); +} + +#[tokio::test] +async fn apply_role_returns_unavailable_for_missing_user_role_file() { + let (_home, mut config) = test_config_with_cli_overrides(Vec::new()).await; + config.agent_roles.insert( + "custom".to_string(), + AgentRoleConfig { + description: None, + config_file: Some(PathBuf::from("/path/does/not/exist.toml")), + nickname_candidates: None, + }, + ); + + let err = apply_role_to_config(&mut config, Some("custom")) + .await + .expect_err("missing role file should fail"); + + assert_eq!(err, AGENT_TYPE_UNAVAILABLE_ERROR); +} + +#[tokio::test] +async fn apply_role_returns_unavailable_for_invalid_user_role_toml() { + let (home, mut config) = test_config_with_cli_overrides(Vec::new()).await; + let role_path = write_role_config(&home, "invalid-role.toml", "model = [").await; + config.agent_roles.insert( + "custom".to_string(), + AgentRoleConfig { + description: None, + config_file: Some(role_path), + nickname_candidates: None, + }, + ); + + let err = apply_role_to_config(&mut config, Some("custom")) + .await + .expect_err("invalid role file should fail"); + + assert_eq!(err, AGENT_TYPE_UNAVAILABLE_ERROR); +} + +#[tokio::test] +async fn apply_role_ignores_agent_metadata_fields_in_user_role_file() { + let (home, mut config) = test_config_with_cli_overrides(Vec::new()).await; + let role_path = write_role_config( + &home, + "metadata-role.toml", + r#" +name = "archivist" +description = "Role metadata" +nickname_candidates = ["Hypatia"] +developer_instructions = "Stay focused" +model = "role-model" +"#, + ) + .await; + config.agent_roles.insert( + "custom".to_string(), + AgentRoleConfig { + description: None, + config_file: Some(role_path), + nickname_candidates: None, + }, + ); + + apply_role_to_config(&mut config, Some("custom")) + .await + .expect("custom role should apply"); + + assert_eq!(config.model.as_deref(), Some("role-model")); +} + +#[tokio::test] +async fn apply_role_preserves_unspecified_keys() { + let (home, mut config) = test_config_with_cli_overrides(vec![( + "model".to_string(), + TomlValue::String("base-model".to_string()), + )]) + .await; + config.codex_linux_sandbox_exe = Some(PathBuf::from("/tmp/codex-linux-sandbox")); + config.main_execve_wrapper_exe = Some(PathBuf::from("/tmp/codex-execve-wrapper")); + let role_path = write_role_config( + &home, + "effort-only.toml", + "developer_instructions = \"Stay focused\"\nmodel_reasoning_effort = \"high\"", + ) + .await; + config.agent_roles.insert( + "custom".to_string(), + AgentRoleConfig { + description: None, + config_file: Some(role_path), + nickname_candidates: None, + }, + ); + + apply_role_to_config(&mut config, Some("custom")) + .await + .expect("custom role should apply"); + + assert_eq!(config.model.as_deref(), Some("base-model")); + assert_eq!(config.model_reasoning_effort, Some(ReasoningEffort::High)); + assert_eq!( + config.codex_linux_sandbox_exe, + Some(PathBuf::from("/tmp/codex-linux-sandbox")) + ); + assert_eq!( + config.main_execve_wrapper_exe, + Some(PathBuf::from("/tmp/codex-execve-wrapper")) + ); +} + +#[tokio::test] +async fn apply_role_preserves_active_profile_and_model_provider() { + let home = TempDir::new().expect("create temp dir"); + tokio::fs::write( + home.path().join(CONFIG_TOML_FILE), + r#" +[model_providers.test-provider] +name = "Test Provider" +base_url = "https://example.com/v1" +env_key = "TEST_PROVIDER_API_KEY" +wire_api = "responses" + +[profiles.test-profile] +model_provider = "test-provider" +"#, + ) + .await + .expect("write config.toml"); + let mut config = ConfigBuilder::default() + .codex_home(home.path().to_path_buf()) + .harness_overrides(ConfigOverrides { + config_profile: Some("test-profile".to_string()), + ..Default::default() + }) + .fallback_cwd(Some(home.path().to_path_buf())) + .build() + .await + .expect("load config"); + let role_path = write_role_config( + &home, + "empty-role.toml", + "developer_instructions = \"Stay focused\"", + ) + .await; + config.agent_roles.insert( + "custom".to_string(), + AgentRoleConfig { + description: None, + config_file: Some(role_path), + nickname_candidates: None, + }, + ); + + apply_role_to_config(&mut config, Some("custom")) + .await + .expect("custom role should apply"); + + assert_eq!(config.active_profile.as_deref(), Some("test-profile")); + assert_eq!(config.model_provider_id, "test-provider"); + assert_eq!(config.model_provider.name, "Test Provider"); +} + +#[tokio::test] +async fn apply_role_uses_role_profile_instead_of_current_profile() { + let home = TempDir::new().expect("create temp dir"); + tokio::fs::write( + home.path().join(CONFIG_TOML_FILE), + r#" +[model_providers.base-provider] +name = "Base Provider" +base_url = "https://base.example.com/v1" +env_key = "BASE_PROVIDER_API_KEY" +wire_api = "responses" + +[model_providers.role-provider] +name = "Role Provider" +base_url = "https://role.example.com/v1" +env_key = "ROLE_PROVIDER_API_KEY" +wire_api = "responses" + +[profiles.base-profile] +model_provider = "base-provider" + +[profiles.role-profile] +model_provider = "role-provider" +"#, + ) + .await + .expect("write config.toml"); + let mut config = ConfigBuilder::default() + .codex_home(home.path().to_path_buf()) + .harness_overrides(ConfigOverrides { + config_profile: Some("base-profile".to_string()), + ..Default::default() + }) + .fallback_cwd(Some(home.path().to_path_buf())) + .build() + .await + .expect("load config"); + let role_path = write_role_config( + &home, + "profile-role.toml", + "developer_instructions = \"Stay focused\"\nprofile = \"role-profile\"", + ) + .await; + config.agent_roles.insert( + "custom".to_string(), + AgentRoleConfig { + description: None, + config_file: Some(role_path), + nickname_candidates: None, + }, + ); + + apply_role_to_config(&mut config, Some("custom")) + .await + .expect("custom role should apply"); + + assert_eq!(config.active_profile.as_deref(), Some("role-profile")); + assert_eq!(config.model_provider_id, "role-provider"); + assert_eq!(config.model_provider.name, "Role Provider"); +} + +#[tokio::test] +async fn apply_role_uses_role_model_provider_instead_of_current_profile_provider() { + let home = TempDir::new().expect("create temp dir"); + tokio::fs::write( + home.path().join(CONFIG_TOML_FILE), + r#" +[model_providers.base-provider] +name = "Base Provider" +base_url = "https://base.example.com/v1" +env_key = "BASE_PROVIDER_API_KEY" +wire_api = "responses" + +[model_providers.role-provider] +name = "Role Provider" +base_url = "https://role.example.com/v1" +env_key = "ROLE_PROVIDER_API_KEY" +wire_api = "responses" + +[profiles.base-profile] +model_provider = "base-provider" +"#, + ) + .await + .expect("write config.toml"); + let mut config = ConfigBuilder::default() + .codex_home(home.path().to_path_buf()) + .harness_overrides(ConfigOverrides { + config_profile: Some("base-profile".to_string()), + ..Default::default() + }) + .fallback_cwd(Some(home.path().to_path_buf())) + .build() + .await + .expect("load config"); + let role_path = write_role_config( + &home, + "provider-role.toml", + "developer_instructions = \"Stay focused\"\nmodel_provider = \"role-provider\"", + ) + .await; + config.agent_roles.insert( + "custom".to_string(), + AgentRoleConfig { + description: None, + config_file: Some(role_path), + nickname_candidates: None, + }, + ); + + apply_role_to_config(&mut config, Some("custom")) + .await + .expect("custom role should apply"); + + assert_eq!(config.active_profile, None); + assert_eq!(config.model_provider_id, "role-provider"); + assert_eq!(config.model_provider.name, "Role Provider"); +} + +#[tokio::test] +async fn apply_role_uses_active_profile_model_provider_update() { + let home = TempDir::new().expect("create temp dir"); + tokio::fs::write( + home.path().join(CONFIG_TOML_FILE), + r#" +[model_providers.base-provider] +name = "Base Provider" +base_url = "https://base.example.com/v1" +env_key = "BASE_PROVIDER_API_KEY" +wire_api = "responses" + +[model_providers.role-provider] +name = "Role Provider" +base_url = "https://role.example.com/v1" +env_key = "ROLE_PROVIDER_API_KEY" +wire_api = "responses" + +[profiles.base-profile] +model_provider = "base-provider" +model_reasoning_effort = "low" +"#, + ) + .await + .expect("write config.toml"); + let mut config = ConfigBuilder::default() + .codex_home(home.path().to_path_buf()) + .harness_overrides(ConfigOverrides { + config_profile: Some("base-profile".to_string()), + ..Default::default() + }) + .fallback_cwd(Some(home.path().to_path_buf())) + .build() + .await + .expect("load config"); + let role_path = write_role_config( + &home, + "profile-edit-role.toml", + r#"developer_instructions = "Stay focused" + +[profiles.base-profile] +model_provider = "role-provider" +model_reasoning_effort = "high" +"#, + ) + .await; + config.agent_roles.insert( + "custom".to_string(), + AgentRoleConfig { + description: None, + config_file: Some(role_path), + nickname_candidates: None, + }, + ); + + apply_role_to_config(&mut config, Some("custom")) + .await + .expect("custom role should apply"); + + assert_eq!(config.active_profile.as_deref(), Some("base-profile")); + assert_eq!(config.model_provider_id, "role-provider"); + assert_eq!(config.model_provider.name, "Role Provider"); + assert_eq!(config.model_reasoning_effort, Some(ReasoningEffort::High)); +} + +#[tokio::test] +#[cfg(not(windows))] +async fn apply_role_does_not_materialize_default_sandbox_workspace_write_fields() { + use codex_protocol::protocol::SandboxPolicy; + let (home, mut config) = test_config_with_cli_overrides(vec![ + ( + "sandbox_mode".to_string(), + TomlValue::String("workspace-write".to_string()), + ), + ( + "sandbox_workspace_write.network_access".to_string(), + TomlValue::Boolean(true), + ), + ]) + .await; + let role_path = write_role_config( + &home, + "sandbox-role.toml", + r#"developer_instructions = "Stay focused" + +[sandbox_workspace_write] +writable_roots = ["./sandbox-root"] +"#, + ) + .await; + config.agent_roles.insert( + "custom".to_string(), + AgentRoleConfig { + description: None, + config_file: Some(role_path), + nickname_candidates: None, + }, + ); + + apply_role_to_config(&mut config, Some("custom")) + .await + .expect("custom role should apply"); + + let role_layer = config + .config_layer_stack + .get_layers(ConfigLayerStackOrdering::LowestPrecedenceFirst, true) + .into_iter() + .rfind(|layer| layer.name == ConfigLayerSource::SessionFlags) + .expect("expected a session flags layer"); + let sandbox_workspace_write = role_layer + .config + .get("sandbox_workspace_write") + .and_then(TomlValue::as_table) + .expect("role layer should include sandbox_workspace_write"); + assert_eq!( + sandbox_workspace_write.contains_key("network_access"), + false + ); + assert_eq!( + sandbox_workspace_write.contains_key("exclude_tmpdir_env_var"), + false + ); + assert_eq!( + sandbox_workspace_write.contains_key("exclude_slash_tmp"), + false + ); + + match &*config.permissions.sandbox_policy { + SandboxPolicy::WorkspaceWrite { network_access, .. } => { + assert_eq!(*network_access, true); + } + other => panic!("expected workspace-write sandbox policy, got {other:?}"), + } +} + +#[tokio::test] +async fn apply_role_takes_precedence_over_existing_session_flags_for_same_key() { + let (home, mut config) = test_config_with_cli_overrides(vec![( + "model".to_string(), + TomlValue::String("cli-model".to_string()), + )]) + .await; + let before_layers = session_flags_layer_count(&config); + let role_path = write_role_config( + &home, + "model-role.toml", + "developer_instructions = \"Stay focused\"\nmodel = \"role-model\"", + ) + .await; + config.agent_roles.insert( + "custom".to_string(), + AgentRoleConfig { + description: None, + config_file: Some(role_path), + nickname_candidates: None, + }, + ); + + apply_role_to_config(&mut config, Some("custom")) + .await + .expect("custom role should apply"); + + assert_eq!(config.model.as_deref(), Some("role-model")); + assert_eq!(session_flags_layer_count(&config), before_layers + 1); +} + +#[cfg_attr(windows, ignore)] +#[tokio::test] +async fn apply_role_skills_config_disables_skill_for_spawned_agent() { + let (home, mut config) = test_config_with_cli_overrides(Vec::new()).await; + let skill_dir = home.path().join("skills").join("demo"); + fs::create_dir_all(&skill_dir).expect("create skill dir"); + let skill_path = skill_dir.join("SKILL.md"); + fs::write( + &skill_path, + "---\nname: demo-skill\ndescription: demo description\n---\n\n# Body\n", + ) + .expect("write skill"); + let role_path = write_role_config( + &home, + "skills-role.toml", + &format!( + r#"developer_instructions = "Stay focused" + +[[skills.config]] +path = "{}" +enabled = false +"#, + skill_path.display() + ), + ) + .await; + config.agent_roles.insert( + "custom".to_string(), + AgentRoleConfig { + description: None, + config_file: Some(role_path), + nickname_candidates: None, + }, + ); + + apply_role_to_config(&mut config, Some("custom")) + .await + .expect("custom role should apply"); + + let plugins_manager = Arc::new(PluginsManager::new(home.path().to_path_buf())); + let skills_manager = SkillsManager::new(home.path().to_path_buf(), plugins_manager, true); + let outcome = skills_manager.skills_for_config(&config); + let skill = outcome + .skills + .iter() + .find(|skill| skill.name == "demo-skill") + .expect("demo skill should be discovered"); + + assert_eq!(outcome.is_skill_enabled(skill), false); +} + +#[test] +fn spawn_tool_spec_build_deduplicates_user_defined_built_in_roles() { + let user_defined_roles = BTreeMap::from([ + ( + "explorer".to_string(), + AgentRoleConfig { + description: Some("user override".to_string()), + config_file: None, + nickname_candidates: None, + }, + ), + ("researcher".to_string(), AgentRoleConfig::default()), + ]); + + let spec = spawn_tool_spec::build(&user_defined_roles); + + assert!(spec.contains("researcher: no description")); + assert!(spec.contains("explorer: {\nuser override\n}")); + assert!(spec.contains("default: {\nDefault agent.\n}")); + assert!(!spec.contains("Explorers are fast and authoritative.")); +} + +#[test] +fn spawn_tool_spec_lists_user_defined_roles_before_built_ins() { + let user_defined_roles = BTreeMap::from([( + "aaa".to_string(), + AgentRoleConfig { + description: Some("first".to_string()), + config_file: None, + nickname_candidates: None, + }, + )]); + + let spec = spawn_tool_spec::build(&user_defined_roles); + let user_index = spec.find("aaa: {\nfirst\n}").expect("find user role"); + let built_in_index = spec + .find("default: {\nDefault agent.\n}") + .expect("find built-in role"); + + assert!(user_index < built_in_index); +} + +#[test] +fn spawn_tool_spec_marks_role_locked_model_and_reasoning_effort() { + let tempdir = TempDir::new().expect("create temp dir"); + let role_path = tempdir.path().join("researcher.toml"); + fs::write( + &role_path, + "developer_instructions = \"Research carefully\"\nmodel = \"gpt-5\"\nmodel_reasoning_effort = \"high\"\n", + ) + .expect("write role config"); + let user_defined_roles = BTreeMap::from([( + "researcher".to_string(), + AgentRoleConfig { + description: Some("Research carefully.".to_string()), + config_file: Some(role_path), + nickname_candidates: None, + }, + )]); + + let spec = spawn_tool_spec::build(&user_defined_roles); + + assert!(spec.contains( + "Research carefully.\n- This role's model is set to `gpt-5` and its reasoning effort is set to `high`. These settings cannot be changed." + )); +} + +#[test] +fn spawn_tool_spec_marks_role_locked_reasoning_effort_only() { + let tempdir = TempDir::new().expect("create temp dir"); + let role_path = tempdir.path().join("reviewer.toml"); + fs::write( + &role_path, + "developer_instructions = \"Review carefully\"\nmodel_reasoning_effort = \"medium\"\n", + ) + .expect("write role config"); + let user_defined_roles = BTreeMap::from([( + "reviewer".to_string(), + AgentRoleConfig { + description: Some("Review carefully.".to_string()), + config_file: Some(role_path), + nickname_candidates: None, + }, + )]); + + let spec = spawn_tool_spec::build(&user_defined_roles); + + assert!(spec.contains( + "Review carefully.\n- This role's reasoning effort is set to `medium` and cannot be changed." + )); +} + +#[test] +fn built_in_config_file_contents_resolves_explorer_only() { + assert_eq!( + built_in::config_file_contents(Path::new("missing.toml")), + None + ); +} diff --git a/codex-rs/core/src/analytics_client.rs b/codex-rs/core/src/analytics_client.rs index c8829eda00..f2df7e010d 100644 --- a/codex-rs/core/src/analytics_client.rs +++ b/codex-rs/core/src/analytics_client.rs @@ -489,182 +489,5 @@ fn normalize_path_for_skill_id( } #[cfg(test)] -mod tests { - use super::AnalyticsEventsQueue; - use super::AppInvocation; - use super::CodexAppMentionedEventRequest; - use super::CodexAppUsedEventRequest; - use super::InvocationType; - use super::TrackEventRequest; - use super::TrackEventsContext; - use super::codex_app_metadata; - use super::normalize_path_for_skill_id; - use pretty_assertions::assert_eq; - use serde_json::json; - use std::collections::HashSet; - use std::path::PathBuf; - use std::sync::Arc; - use std::sync::Mutex; - use tokio::sync::mpsc; - - fn expected_absolute_path(path: &PathBuf) -> String { - std::fs::canonicalize(path) - .unwrap_or_else(|_| path.to_path_buf()) - .to_string_lossy() - .replace('\\', "/") - } - - #[test] - fn normalize_path_for_skill_id_repo_scoped_uses_relative_path() { - let repo_root = PathBuf::from("/repo/root"); - let skill_path = PathBuf::from("/repo/root/.codex/skills/doc/SKILL.md"); - - let path = normalize_path_for_skill_id( - Some("https://example.com/repo.git"), - Some(repo_root.as_path()), - skill_path.as_path(), - ); - - assert_eq!(path, ".codex/skills/doc/SKILL.md"); - } - - #[test] - fn normalize_path_for_skill_id_user_scoped_uses_absolute_path() { - let skill_path = PathBuf::from("/Users/abc/.codex/skills/doc/SKILL.md"); - - let path = normalize_path_for_skill_id(None, None, skill_path.as_path()); - let expected = expected_absolute_path(&skill_path); - - assert_eq!(path, expected); - } - - #[test] - fn normalize_path_for_skill_id_admin_scoped_uses_absolute_path() { - let skill_path = PathBuf::from("/etc/codex/skills/doc/SKILL.md"); - - let path = normalize_path_for_skill_id(None, None, skill_path.as_path()); - let expected = expected_absolute_path(&skill_path); - - assert_eq!(path, expected); - } - - #[test] - fn normalize_path_for_skill_id_repo_root_not_in_skill_path_uses_absolute_path() { - let repo_root = PathBuf::from("/repo/root"); - let skill_path = PathBuf::from("/other/path/.codex/skills/doc/SKILL.md"); - - let path = normalize_path_for_skill_id( - Some("https://example.com/repo.git"), - Some(repo_root.as_path()), - skill_path.as_path(), - ); - let expected = expected_absolute_path(&skill_path); - - assert_eq!(path, expected); - } - - #[test] - fn app_mentioned_event_serializes_expected_shape() { - let tracking = TrackEventsContext { - model_slug: "gpt-5".to_string(), - thread_id: "thread-1".to_string(), - turn_id: "turn-1".to_string(), - }; - let event = TrackEventRequest::AppMentioned(CodexAppMentionedEventRequest { - event_type: "codex_app_mentioned", - event_params: codex_app_metadata( - &tracking, - AppInvocation { - connector_id: Some("calendar".to_string()), - app_name: Some("Calendar".to_string()), - invocation_type: Some(InvocationType::Explicit), - }, - ), - }); - - let payload = serde_json::to_value(&event).expect("serialize app mentioned event"); - - assert_eq!( - payload, - json!({ - "event_type": "codex_app_mentioned", - "event_params": { - "connector_id": "calendar", - "thread_id": "thread-1", - "turn_id": "turn-1", - "app_name": "Calendar", - "product_client_id": crate::default_client::originator().value, - "invoke_type": "explicit", - "model_slug": "gpt-5" - } - }) - ); - } - - #[test] - fn app_used_event_serializes_expected_shape() { - let tracking = TrackEventsContext { - model_slug: "gpt-5".to_string(), - thread_id: "thread-2".to_string(), - turn_id: "turn-2".to_string(), - }; - let event = TrackEventRequest::AppUsed(CodexAppUsedEventRequest { - event_type: "codex_app_used", - event_params: codex_app_metadata( - &tracking, - AppInvocation { - connector_id: Some("drive".to_string()), - app_name: Some("Google Drive".to_string()), - invocation_type: Some(InvocationType::Implicit), - }, - ), - }); - - let payload = serde_json::to_value(&event).expect("serialize app used event"); - - assert_eq!( - payload, - json!({ - "event_type": "codex_app_used", - "event_params": { - "connector_id": "drive", - "thread_id": "thread-2", - "turn_id": "turn-2", - "app_name": "Google Drive", - "product_client_id": crate::default_client::originator().value, - "invoke_type": "implicit", - "model_slug": "gpt-5" - } - }) - ); - } - - #[test] - fn app_used_dedupe_is_keyed_by_turn_and_connector() { - let (sender, _receiver) = mpsc::channel(1); - let queue = AnalyticsEventsQueue { - sender, - app_used_emitted_keys: Arc::new(Mutex::new(HashSet::new())), - }; - let app = AppInvocation { - connector_id: Some("calendar".to_string()), - app_name: Some("Calendar".to_string()), - invocation_type: Some(InvocationType::Implicit), - }; - - let turn_1 = TrackEventsContext { - model_slug: "gpt-5".to_string(), - thread_id: "thread-1".to_string(), - turn_id: "turn-1".to_string(), - }; - let turn_2 = TrackEventsContext { - model_slug: "gpt-5".to_string(), - thread_id: "thread-1".to_string(), - turn_id: "turn-2".to_string(), - }; - - assert_eq!(queue.should_enqueue_app_used(&turn_1, &app), true); - assert_eq!(queue.should_enqueue_app_used(&turn_1, &app), false); - assert_eq!(queue.should_enqueue_app_used(&turn_2, &app), true); - } -} +#[path = "analytics_client_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/analytics_client_tests.rs b/codex-rs/core/src/analytics_client_tests.rs new file mode 100644 index 0000000000..66e9a1234b --- /dev/null +++ b/codex-rs/core/src/analytics_client_tests.rs @@ -0,0 +1,177 @@ +use super::AnalyticsEventsQueue; +use super::AppInvocation; +use super::CodexAppMentionedEventRequest; +use super::CodexAppUsedEventRequest; +use super::InvocationType; +use super::TrackEventRequest; +use super::TrackEventsContext; +use super::codex_app_metadata; +use super::normalize_path_for_skill_id; +use pretty_assertions::assert_eq; +use serde_json::json; +use std::collections::HashSet; +use std::path::PathBuf; +use std::sync::Arc; +use std::sync::Mutex; +use tokio::sync::mpsc; + +fn expected_absolute_path(path: &PathBuf) -> String { + std::fs::canonicalize(path) + .unwrap_or_else(|_| path.to_path_buf()) + .to_string_lossy() + .replace('\\', "/") +} + +#[test] +fn normalize_path_for_skill_id_repo_scoped_uses_relative_path() { + let repo_root = PathBuf::from("/repo/root"); + let skill_path = PathBuf::from("/repo/root/.codex/skills/doc/SKILL.md"); + + let path = normalize_path_for_skill_id( + Some("https://example.com/repo.git"), + Some(repo_root.as_path()), + skill_path.as_path(), + ); + + assert_eq!(path, ".codex/skills/doc/SKILL.md"); +} + +#[test] +fn normalize_path_for_skill_id_user_scoped_uses_absolute_path() { + let skill_path = PathBuf::from("/Users/abc/.codex/skills/doc/SKILL.md"); + + let path = normalize_path_for_skill_id(None, None, skill_path.as_path()); + let expected = expected_absolute_path(&skill_path); + + assert_eq!(path, expected); +} + +#[test] +fn normalize_path_for_skill_id_admin_scoped_uses_absolute_path() { + let skill_path = PathBuf::from("/etc/codex/skills/doc/SKILL.md"); + + let path = normalize_path_for_skill_id(None, None, skill_path.as_path()); + let expected = expected_absolute_path(&skill_path); + + assert_eq!(path, expected); +} + +#[test] +fn normalize_path_for_skill_id_repo_root_not_in_skill_path_uses_absolute_path() { + let repo_root = PathBuf::from("/repo/root"); + let skill_path = PathBuf::from("/other/path/.codex/skills/doc/SKILL.md"); + + let path = normalize_path_for_skill_id( + Some("https://example.com/repo.git"), + Some(repo_root.as_path()), + skill_path.as_path(), + ); + let expected = expected_absolute_path(&skill_path); + + assert_eq!(path, expected); +} + +#[test] +fn app_mentioned_event_serializes_expected_shape() { + let tracking = TrackEventsContext { + model_slug: "gpt-5".to_string(), + thread_id: "thread-1".to_string(), + turn_id: "turn-1".to_string(), + }; + let event = TrackEventRequest::AppMentioned(CodexAppMentionedEventRequest { + event_type: "codex_app_mentioned", + event_params: codex_app_metadata( + &tracking, + AppInvocation { + connector_id: Some("calendar".to_string()), + app_name: Some("Calendar".to_string()), + invocation_type: Some(InvocationType::Explicit), + }, + ), + }); + + let payload = serde_json::to_value(&event).expect("serialize app mentioned event"); + + assert_eq!( + payload, + json!({ + "event_type": "codex_app_mentioned", + "event_params": { + "connector_id": "calendar", + "thread_id": "thread-1", + "turn_id": "turn-1", + "app_name": "Calendar", + "product_client_id": crate::default_client::originator().value, + "invoke_type": "explicit", + "model_slug": "gpt-5" + } + }) + ); +} + +#[test] +fn app_used_event_serializes_expected_shape() { + let tracking = TrackEventsContext { + model_slug: "gpt-5".to_string(), + thread_id: "thread-2".to_string(), + turn_id: "turn-2".to_string(), + }; + let event = TrackEventRequest::AppUsed(CodexAppUsedEventRequest { + event_type: "codex_app_used", + event_params: codex_app_metadata( + &tracking, + AppInvocation { + connector_id: Some("drive".to_string()), + app_name: Some("Google Drive".to_string()), + invocation_type: Some(InvocationType::Implicit), + }, + ), + }); + + let payload = serde_json::to_value(&event).expect("serialize app used event"); + + assert_eq!( + payload, + json!({ + "event_type": "codex_app_used", + "event_params": { + "connector_id": "drive", + "thread_id": "thread-2", + "turn_id": "turn-2", + "app_name": "Google Drive", + "product_client_id": crate::default_client::originator().value, + "invoke_type": "implicit", + "model_slug": "gpt-5" + } + }) + ); +} + +#[test] +fn app_used_dedupe_is_keyed_by_turn_and_connector() { + let (sender, _receiver) = mpsc::channel(1); + let queue = AnalyticsEventsQueue { + sender, + app_used_emitted_keys: Arc::new(Mutex::new(HashSet::new())), + }; + let app = AppInvocation { + connector_id: Some("calendar".to_string()), + app_name: Some("Calendar".to_string()), + invocation_type: Some(InvocationType::Implicit), + }; + + let turn_1 = TrackEventsContext { + model_slug: "gpt-5".to_string(), + thread_id: "thread-1".to_string(), + turn_id: "turn-1".to_string(), + }; + let turn_2 = TrackEventsContext { + model_slug: "gpt-5".to_string(), + thread_id: "thread-1".to_string(), + turn_id: "turn-2".to_string(), + }; + + assert_eq!(queue.should_enqueue_app_used(&turn_1, &app), true); + assert_eq!(queue.should_enqueue_app_used(&turn_1, &app), false); + assert_eq!(queue.should_enqueue_app_used(&turn_2, &app), true); +} diff --git a/codex-rs/core/src/api_bridge.rs b/codex-rs/core/src/api_bridge.rs index 3b2024f58b..f363201ae9 100644 --- a/codex-rs/core/src/api_bridge.rs +++ b/codex-rs/core/src/api_bridge.rs @@ -120,104 +120,8 @@ const OAI_REQUEST_ID_HEADER: &str = "x-oai-request-id"; const CF_RAY_HEADER: &str = "cf-ray"; #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - - #[test] - fn map_api_error_maps_server_overloaded() { - let err = map_api_error(ApiError::ServerOverloaded); - assert!(matches!(err, CodexErr::ServerOverloaded)); - } - - #[test] - fn map_api_error_maps_server_overloaded_from_503_body() { - let body = serde_json::json!({ - "error": { - "code": "server_is_overloaded" - } - }) - .to_string(); - let err = map_api_error(ApiError::Transport(TransportError::Http { - status: http::StatusCode::SERVICE_UNAVAILABLE, - url: Some("http://example.com/v1/responses".to_string()), - headers: None, - body: Some(body), - })); - - assert!(matches!(err, CodexErr::ServerOverloaded)); - } - - #[test] - fn map_api_error_maps_usage_limit_limit_name_header() { - let mut headers = HeaderMap::new(); - headers.insert( - ACTIVE_LIMIT_HEADER, - http::HeaderValue::from_static("codex_other"), - ); - headers.insert( - "x-codex-other-limit-name", - http::HeaderValue::from_static("codex_other"), - ); - let body = serde_json::json!({ - "error": { - "type": "usage_limit_reached", - "plan_type": "pro", - } - }) - .to_string(); - let err = map_api_error(ApiError::Transport(TransportError::Http { - status: http::StatusCode::TOO_MANY_REQUESTS, - url: Some("http://example.com/v1/responses".to_string()), - headers: Some(headers), - body: Some(body), - })); - - let CodexErr::UsageLimitReached(usage_limit) = err else { - panic!("expected CodexErr::UsageLimitReached, got {err:?}"); - }; - assert_eq!( - usage_limit - .rate_limits - .as_ref() - .and_then(|snapshot| snapshot.limit_name.as_deref()), - Some("codex_other") - ); - } - - #[test] - fn map_api_error_does_not_fallback_limit_name_to_limit_id() { - let mut headers = HeaderMap::new(); - headers.insert( - ACTIVE_LIMIT_HEADER, - http::HeaderValue::from_static("codex_other"), - ); - let body = serde_json::json!({ - "error": { - "type": "usage_limit_reached", - "plan_type": "pro", - } - }) - .to_string(); - let err = map_api_error(ApiError::Transport(TransportError::Http { - status: http::StatusCode::TOO_MANY_REQUESTS, - url: Some("http://example.com/v1/responses".to_string()), - headers: Some(headers), - body: Some(body), - })); - - let CodexErr::UsageLimitReached(usage_limit) = err else { - panic!("expected CodexErr::UsageLimitReached, got {err:?}"); - }; - assert_eq!( - usage_limit - .rate_limits - .as_ref() - .and_then(|snapshot| snapshot.limit_name.as_deref()), - None - ); - } -} +#[path = "api_bridge_tests.rs"] +mod tests; fn extract_request_tracking_id(headers: Option<&HeaderMap>) -> Option { extract_request_id(headers).or_else(|| extract_header(headers, CF_RAY_HEADER)) diff --git a/codex-rs/core/src/api_bridge_tests.rs b/codex-rs/core/src/api_bridge_tests.rs new file mode 100644 index 0000000000..e8391021b1 --- /dev/null +++ b/codex-rs/core/src/api_bridge_tests.rs @@ -0,0 +1,96 @@ +use super::*; +use pretty_assertions::assert_eq; + +#[test] +fn map_api_error_maps_server_overloaded() { + let err = map_api_error(ApiError::ServerOverloaded); + assert!(matches!(err, CodexErr::ServerOverloaded)); +} + +#[test] +fn map_api_error_maps_server_overloaded_from_503_body() { + let body = serde_json::json!({ + "error": { + "code": "server_is_overloaded" + } + }) + .to_string(); + let err = map_api_error(ApiError::Transport(TransportError::Http { + status: http::StatusCode::SERVICE_UNAVAILABLE, + url: Some("http://example.com/v1/responses".to_string()), + headers: None, + body: Some(body), + })); + + assert!(matches!(err, CodexErr::ServerOverloaded)); +} + +#[test] +fn map_api_error_maps_usage_limit_limit_name_header() { + let mut headers = HeaderMap::new(); + headers.insert( + ACTIVE_LIMIT_HEADER, + http::HeaderValue::from_static("codex_other"), + ); + headers.insert( + "x-codex-other-limit-name", + http::HeaderValue::from_static("codex_other"), + ); + let body = serde_json::json!({ + "error": { + "type": "usage_limit_reached", + "plan_type": "pro", + } + }) + .to_string(); + let err = map_api_error(ApiError::Transport(TransportError::Http { + status: http::StatusCode::TOO_MANY_REQUESTS, + url: Some("http://example.com/v1/responses".to_string()), + headers: Some(headers), + body: Some(body), + })); + + let CodexErr::UsageLimitReached(usage_limit) = err else { + panic!("expected CodexErr::UsageLimitReached, got {err:?}"); + }; + assert_eq!( + usage_limit + .rate_limits + .as_ref() + .and_then(|snapshot| snapshot.limit_name.as_deref()), + Some("codex_other") + ); +} + +#[test] +fn map_api_error_does_not_fallback_limit_name_to_limit_id() { + let mut headers = HeaderMap::new(); + headers.insert( + ACTIVE_LIMIT_HEADER, + http::HeaderValue::from_static("codex_other"), + ); + let body = serde_json::json!({ + "error": { + "type": "usage_limit_reached", + "plan_type": "pro", + } + }) + .to_string(); + let err = map_api_error(ApiError::Transport(TransportError::Http { + status: http::StatusCode::TOO_MANY_REQUESTS, + url: Some("http://example.com/v1/responses".to_string()), + headers: Some(headers), + body: Some(body), + })); + + let CodexErr::UsageLimitReached(usage_limit) = err else { + panic!("expected CodexErr::UsageLimitReached, got {err:?}"); + }; + assert_eq!( + usage_limit + .rate_limits + .as_ref() + .and_then(|snapshot| snapshot.limit_name.as_deref()), + None + ); +} diff --git a/codex-rs/core/src/apply_patch.rs b/codex-rs/core/src/apply_patch.rs index 0d64934cad..9a09ae9f09 100644 --- a/codex-rs/core/src/apply_patch.rs +++ b/codex-rs/core/src/apply_patch.rs @@ -104,26 +104,5 @@ pub(crate) fn convert_apply_patch_to_protocol( } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - - use tempfile::tempdir; - - #[test] - fn convert_apply_patch_maps_add_variant() { - let tmp = tempdir().expect("tmp"); - let p = tmp.path().join("a.txt"); - // Create an action with a single Add change - let action = ApplyPatchAction::new_add_for_test(&p, "hello".to_string()); - - let got = convert_apply_patch_to_protocol(&action); - - assert_eq!( - got.get(&p), - Some(&FileChange::Add { - content: "hello".to_string() - }) - ); - } -} +#[path = "apply_patch_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/apply_patch_tests.rs b/codex-rs/core/src/apply_patch_tests.rs new file mode 100644 index 0000000000..1b9e722d5a --- /dev/null +++ b/codex-rs/core/src/apply_patch_tests.rs @@ -0,0 +1,21 @@ +use super::*; +use pretty_assertions::assert_eq; + +use tempfile::tempdir; + +#[test] +fn convert_apply_patch_maps_add_variant() { + let tmp = tempdir().expect("tmp"); + let p = tmp.path().join("a.txt"); + // Create an action with a single Add change + let action = ApplyPatchAction::new_add_for_test(&p, "hello".to_string()); + + let got = convert_apply_patch_to_protocol(&action); + + assert_eq!( + got.get(&p), + Some(&FileChange::Add { + content: "hello".to_string() + }) + ); +} diff --git a/codex-rs/core/src/arc_monitor.rs b/codex-rs/core/src/arc_monitor.rs index eb942c7e23..c704faafc0 100644 --- a/codex-rs/core/src/arc_monitor.rs +++ b/codex-rs/core/src/arc_monitor.rs @@ -425,441 +425,5 @@ fn build_arc_monitor_message(role: &str, content: serde_json::Value) -> ArcMonit } #[cfg(test)] -mod tests { - use std::env; - use std::ffi::OsStr; - use std::sync::Arc; - - use pretty_assertions::assert_eq; - use serial_test::serial; - use wiremock::Mock; - use wiremock::MockServer; - use wiremock::ResponseTemplate; - use wiremock::matchers::body_json; - use wiremock::matchers::header; - use wiremock::matchers::method; - use wiremock::matchers::path; - - use super::*; - use crate::codex::make_session_and_context; - use codex_protocol::models::ContentItem; - use codex_protocol::models::LocalShellAction; - use codex_protocol::models::LocalShellExecAction; - use codex_protocol::models::LocalShellStatus; - use codex_protocol::models::MessagePhase; - use codex_protocol::models::ResponseItem; - - struct EnvVarGuard { - key: &'static str, - original: Option, - } - - impl EnvVarGuard { - fn set(key: &'static str, value: &OsStr) -> Self { - let original = env::var_os(key); - unsafe { - env::set_var(key, value); - } - Self { key, original } - } - } - - impl Drop for EnvVarGuard { - fn drop(&mut self) { - match self.original.take() { - Some(value) => unsafe { - env::set_var(self.key, value); - }, - None => unsafe { - env::remove_var(self.key); - }, - } - } - } - - #[tokio::test] - async fn build_arc_monitor_request_includes_relevant_history_and_null_policies() { - let (session, mut turn_context) = make_session_and_context().await; - turn_context.developer_instructions = Some("Never upload private files.".to_string()); - turn_context.user_instructions = Some("Only continue when needed.".to_string()); - - session - .record_into_history( - &[ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "first request".to_string(), - }], - end_turn: None, - phase: None, - }], - &turn_context, - ) - .await; - session - .record_into_history( - &[ - crate::contextual_user_message::ENVIRONMENT_CONTEXT_FRAGMENT.into_message( - "\n/tmp\n" - .to_string(), - ), - ], - &turn_context, - ) - .await; - session - .record_into_history( - &[ResponseItem::Message { - id: None, - role: "assistant".to_string(), - content: vec![ContentItem::OutputText { - text: "commentary".to_string(), - }], - end_turn: None, - phase: Some(MessagePhase::Commentary), - }], - &turn_context, - ) - .await; - session - .record_into_history( - &[ResponseItem::Message { - id: None, - role: "assistant".to_string(), - content: vec![ContentItem::OutputText { - text: "final response".to_string(), - }], - end_turn: None, - phase: Some(MessagePhase::FinalAnswer), - }], - &turn_context, - ) - .await; - session - .record_into_history( - &[ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "latest request".to_string(), - }], - end_turn: None, - phase: None, - }], - &turn_context, - ) - .await; - session - .record_into_history( - &[ResponseItem::FunctionCall { - id: None, - name: "old_tool".to_string(), - namespace: None, - arguments: "{\"old\":true}".to_string(), - call_id: "call_old".to_string(), - }], - &turn_context, - ) - .await; - session - .record_into_history( - &[ResponseItem::Reasoning { - id: "reasoning_old".to_string(), - summary: Vec::new(), - content: None, - encrypted_content: Some("encrypted-old".to_string()), - }], - &turn_context, - ) - .await; - session - .record_into_history( - &[ResponseItem::LocalShellCall { - id: None, - call_id: Some("shell_call".to_string()), - status: LocalShellStatus::Completed, - action: LocalShellAction::Exec(LocalShellExecAction { - command: vec!["pwd".to_string()], - timeout_ms: Some(1000), - working_directory: Some("/tmp".to_string()), - env: None, - user: None, - }), - }], - &turn_context, - ) - .await; - session - .record_into_history( - &[ResponseItem::Reasoning { - id: "reasoning_latest".to_string(), - summary: Vec::new(), - content: None, - encrypted_content: Some("encrypted-latest".to_string()), - }], - &turn_context, - ) - .await; - - let request = build_arc_monitor_request( - &session, - &turn_context, - serde_json::from_value(serde_json::json!({ "tool": "mcp_tool_call" })) - .expect("action should deserialize"), - ) - .await; - - assert_eq!( - request, - ArcMonitorRequest { - metadata: ArcMonitorMetadata { - codex_thread_id: session.conversation_id.to_string(), - codex_turn_id: turn_context.sub_id.clone(), - conversation_id: Some(session.conversation_id.to_string()), - protection_client_callsite: None, - }, - messages: Some(vec![ - ArcMonitorChatMessage { - role: "user".to_string(), - content: serde_json::json!([{ - "type": "input_text", - "text": "first request", - }]), - }, - ArcMonitorChatMessage { - role: "assistant".to_string(), - content: serde_json::json!([{ - "type": "output_text", - "text": "final response", - }]), - }, - ArcMonitorChatMessage { - role: "user".to_string(), - content: serde_json::json!([{ - "type": "input_text", - "text": "latest request", - }]), - }, - ArcMonitorChatMessage { - role: "assistant".to_string(), - content: serde_json::json!([{ - "type": "tool_call", - "tool_name": "shell", - "action": { - "type": "exec", - "command": ["pwd"], - "timeout_ms": 1000, - "working_directory": "/tmp", - "env": null, - "user": null, - }, - }]), - }, - ArcMonitorChatMessage { - role: "assistant".to_string(), - content: serde_json::json!([{ - "type": "encrypted_reasoning", - "encrypted_content": "encrypted-latest", - }]), - }, - ]), - input: None, - policies: Some(ArcMonitorPolicies { - user: None, - developer: None, - }), - action: serde_json::from_value(serde_json::json!({ "tool": "mcp_tool_call" })) - .expect("action should deserialize"), - } - ); - } - - #[tokio::test] - #[serial(arc_monitor_env)] - async fn monitor_action_posts_expected_arc_request() { - let server = MockServer::start().await; - let (session, mut turn_context) = make_session_and_context().await; - turn_context.auth_manager = Some(crate::test_support::auth_manager_from_auth( - crate::CodexAuth::create_dummy_chatgpt_auth_for_testing(), - )); - turn_context.developer_instructions = Some("Developer policy".to_string()); - turn_context.user_instructions = Some("User policy".to_string()); - - let mut config = (*turn_context.config).clone(); - config.chatgpt_base_url = server.uri(); - turn_context.config = Arc::new(config); - - session - .record_into_history( - &[ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "please run the tool".to_string(), - }], - end_turn: None, - phase: None, - }], - &turn_context, - ) - .await; - - Mock::given(method("POST")) - .and(path("/codex/safety/arc")) - .and(header("authorization", "Bearer Access Token")) - .and(header("chatgpt-account-id", "account_id")) - .and(body_json(serde_json::json!({ - "metadata": { - "codex_thread_id": session.conversation_id.to_string(), - "codex_turn_id": turn_context.sub_id.clone(), - "conversation_id": session.conversation_id.to_string(), - }, - "messages": [{ - "role": "user", - "content": [{ - "type": "input_text", - "text": "please run the tool", - }], - }], - "policies": { - "developer": null, - "user": null, - }, - "action": { - "tool": "mcp_tool_call", - }, - }))) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "outcome": "ask-user", - "short_reason": "needs confirmation", - "rationale": "tool call needs additional review", - "risk_score": 42, - "risk_level": "medium", - "evidence": [{ - "message": "browser_navigate", - "why": "tool call needs additional review", - }], - }))) - .expect(1) - .mount(&server) - .await; - - let outcome = monitor_action( - &session, - &turn_context, - serde_json::json!({ "tool": "mcp_tool_call" }), - ) - .await; - - assert_eq!( - outcome, - ArcMonitorOutcome::AskUser("needs confirmation".to_string()) - ); - } - - #[tokio::test] - #[serial(arc_monitor_env)] - async fn monitor_action_uses_env_url_and_token_overrides() { - let server = MockServer::start().await; - let _url_guard = EnvVarGuard::set( - CODEX_ARC_MONITOR_ENDPOINT_OVERRIDE, - OsStr::new(&format!("{}/override/arc", server.uri())), - ); - let _token_guard = EnvVarGuard::set(CODEX_ARC_MONITOR_TOKEN, OsStr::new("override-token")); - - let (session, turn_context) = make_session_and_context().await; - session - .record_into_history( - &[ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "please run the tool".to_string(), - }], - end_turn: None, - phase: None, - }], - &turn_context, - ) - .await; - - Mock::given(method("POST")) - .and(path("/override/arc")) - .and(header("authorization", "Bearer override-token")) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "outcome": "steer-model", - "short_reason": "needs approval", - "rationale": "high-risk action", - "risk_score": 96, - "risk_level": "critical", - "evidence": [{ - "message": "browser_navigate", - "why": "high-risk action", - }], - }))) - .expect(1) - .mount(&server) - .await; - - let outcome = monitor_action( - &session, - &turn_context, - serde_json::json!({ "tool": "mcp_tool_call" }), - ) - .await; - - assert_eq!( - outcome, - ArcMonitorOutcome::SteerModel("high-risk action".to_string()) - ); - } - - #[tokio::test] - #[serial(arc_monitor_env)] - async fn monitor_action_rejects_legacy_response_fields() { - let server = MockServer::start().await; - Mock::given(method("POST")) - .and(path("/codex/safety/arc")) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "outcome": "steer-model", - "reason": "legacy high-risk action", - "monitorRequestId": "arc_456", - }))) - .expect(1) - .mount(&server) - .await; - - let (session, mut turn_context) = make_session_and_context().await; - turn_context.auth_manager = Some(crate::test_support::auth_manager_from_auth( - crate::CodexAuth::create_dummy_chatgpt_auth_for_testing(), - )); - let mut config = (*turn_context.config).clone(); - config.chatgpt_base_url = server.uri(); - turn_context.config = Arc::new(config); - - session - .record_into_history( - &[ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "please run the tool".to_string(), - }], - end_turn: None, - phase: None, - }], - &turn_context, - ) - .await; - - let outcome = monitor_action( - &session, - &turn_context, - serde_json::json!({ "tool": "mcp_tool_call" }), - ) - .await; - - assert_eq!(outcome, ArcMonitorOutcome::Ok); - } -} +#[path = "arc_monitor_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/arc_monitor_tests.rs b/codex-rs/core/src/arc_monitor_tests.rs new file mode 100644 index 0000000000..0b5cdf3029 --- /dev/null +++ b/codex-rs/core/src/arc_monitor_tests.rs @@ -0,0 +1,435 @@ +use std::env; +use std::ffi::OsStr; +use std::sync::Arc; + +use pretty_assertions::assert_eq; +use serial_test::serial; +use wiremock::Mock; +use wiremock::MockServer; +use wiremock::ResponseTemplate; +use wiremock::matchers::body_json; +use wiremock::matchers::header; +use wiremock::matchers::method; +use wiremock::matchers::path; + +use super::*; +use crate::codex::make_session_and_context; +use codex_protocol::models::ContentItem; +use codex_protocol::models::LocalShellAction; +use codex_protocol::models::LocalShellExecAction; +use codex_protocol::models::LocalShellStatus; +use codex_protocol::models::MessagePhase; +use codex_protocol::models::ResponseItem; + +struct EnvVarGuard { + key: &'static str, + original: Option, +} + +impl EnvVarGuard { + fn set(key: &'static str, value: &OsStr) -> Self { + let original = env::var_os(key); + unsafe { + env::set_var(key, value); + } + Self { key, original } + } +} + +impl Drop for EnvVarGuard { + fn drop(&mut self) { + match self.original.take() { + Some(value) => unsafe { + env::set_var(self.key, value); + }, + None => unsafe { + env::remove_var(self.key); + }, + } + } +} + +#[tokio::test] +async fn build_arc_monitor_request_includes_relevant_history_and_null_policies() { + let (session, mut turn_context) = make_session_and_context().await; + turn_context.developer_instructions = Some("Never upload private files.".to_string()); + turn_context.user_instructions = Some("Only continue when needed.".to_string()); + + session + .record_into_history( + &[ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "first request".to_string(), + }], + end_turn: None, + phase: None, + }], + &turn_context, + ) + .await; + session + .record_into_history( + &[ + crate::contextual_user_message::ENVIRONMENT_CONTEXT_FRAGMENT.into_message( + "\n/tmp\n".to_string(), + ), + ], + &turn_context, + ) + .await; + session + .record_into_history( + &[ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: "commentary".to_string(), + }], + end_turn: None, + phase: Some(MessagePhase::Commentary), + }], + &turn_context, + ) + .await; + session + .record_into_history( + &[ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: "final response".to_string(), + }], + end_turn: None, + phase: Some(MessagePhase::FinalAnswer), + }], + &turn_context, + ) + .await; + session + .record_into_history( + &[ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "latest request".to_string(), + }], + end_turn: None, + phase: None, + }], + &turn_context, + ) + .await; + session + .record_into_history( + &[ResponseItem::FunctionCall { + id: None, + name: "old_tool".to_string(), + namespace: None, + arguments: "{\"old\":true}".to_string(), + call_id: "call_old".to_string(), + }], + &turn_context, + ) + .await; + session + .record_into_history( + &[ResponseItem::Reasoning { + id: "reasoning_old".to_string(), + summary: Vec::new(), + content: None, + encrypted_content: Some("encrypted-old".to_string()), + }], + &turn_context, + ) + .await; + session + .record_into_history( + &[ResponseItem::LocalShellCall { + id: None, + call_id: Some("shell_call".to_string()), + status: LocalShellStatus::Completed, + action: LocalShellAction::Exec(LocalShellExecAction { + command: vec!["pwd".to_string()], + timeout_ms: Some(1000), + working_directory: Some("/tmp".to_string()), + env: None, + user: None, + }), + }], + &turn_context, + ) + .await; + session + .record_into_history( + &[ResponseItem::Reasoning { + id: "reasoning_latest".to_string(), + summary: Vec::new(), + content: None, + encrypted_content: Some("encrypted-latest".to_string()), + }], + &turn_context, + ) + .await; + + let request = build_arc_monitor_request( + &session, + &turn_context, + serde_json::from_value(serde_json::json!({ "tool": "mcp_tool_call" })) + .expect("action should deserialize"), + ) + .await; + + assert_eq!( + request, + ArcMonitorRequest { + metadata: ArcMonitorMetadata { + codex_thread_id: session.conversation_id.to_string(), + codex_turn_id: turn_context.sub_id.clone(), + conversation_id: Some(session.conversation_id.to_string()), + protection_client_callsite: None, + }, + messages: Some(vec![ + ArcMonitorChatMessage { + role: "user".to_string(), + content: serde_json::json!([{ + "type": "input_text", + "text": "first request", + }]), + }, + ArcMonitorChatMessage { + role: "assistant".to_string(), + content: serde_json::json!([{ + "type": "output_text", + "text": "final response", + }]), + }, + ArcMonitorChatMessage { + role: "user".to_string(), + content: serde_json::json!([{ + "type": "input_text", + "text": "latest request", + }]), + }, + ArcMonitorChatMessage { + role: "assistant".to_string(), + content: serde_json::json!([{ + "type": "tool_call", + "tool_name": "shell", + "action": { + "type": "exec", + "command": ["pwd"], + "timeout_ms": 1000, + "working_directory": "/tmp", + "env": null, + "user": null, + }, + }]), + }, + ArcMonitorChatMessage { + role: "assistant".to_string(), + content: serde_json::json!([{ + "type": "encrypted_reasoning", + "encrypted_content": "encrypted-latest", + }]), + }, + ]), + input: None, + policies: Some(ArcMonitorPolicies { + user: None, + developer: None, + }), + action: serde_json::from_value(serde_json::json!({ "tool": "mcp_tool_call" })) + .expect("action should deserialize"), + } + ); +} + +#[tokio::test] +#[serial(arc_monitor_env)] +async fn monitor_action_posts_expected_arc_request() { + let server = MockServer::start().await; + let (session, mut turn_context) = make_session_and_context().await; + turn_context.auth_manager = Some(crate::test_support::auth_manager_from_auth( + crate::CodexAuth::create_dummy_chatgpt_auth_for_testing(), + )); + turn_context.developer_instructions = Some("Developer policy".to_string()); + turn_context.user_instructions = Some("User policy".to_string()); + + let mut config = (*turn_context.config).clone(); + config.chatgpt_base_url = server.uri(); + turn_context.config = Arc::new(config); + + session + .record_into_history( + &[ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "please run the tool".to_string(), + }], + end_turn: None, + phase: None, + }], + &turn_context, + ) + .await; + + Mock::given(method("POST")) + .and(path("/codex/safety/arc")) + .and(header("authorization", "Bearer Access Token")) + .and(header("chatgpt-account-id", "account_id")) + .and(body_json(serde_json::json!({ + "metadata": { + "codex_thread_id": session.conversation_id.to_string(), + "codex_turn_id": turn_context.sub_id.clone(), + "conversation_id": session.conversation_id.to_string(), + }, + "messages": [{ + "role": "user", + "content": [{ + "type": "input_text", + "text": "please run the tool", + }], + }], + "policies": { + "developer": null, + "user": null, + }, + "action": { + "tool": "mcp_tool_call", + }, + }))) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "outcome": "ask-user", + "short_reason": "needs confirmation", + "rationale": "tool call needs additional review", + "risk_score": 42, + "risk_level": "medium", + "evidence": [{ + "message": "browser_navigate", + "why": "tool call needs additional review", + }], + }))) + .expect(1) + .mount(&server) + .await; + + let outcome = monitor_action( + &session, + &turn_context, + serde_json::json!({ "tool": "mcp_tool_call" }), + ) + .await; + + assert_eq!( + outcome, + ArcMonitorOutcome::AskUser("needs confirmation".to_string()) + ); +} + +#[tokio::test] +#[serial(arc_monitor_env)] +async fn monitor_action_uses_env_url_and_token_overrides() { + let server = MockServer::start().await; + let _url_guard = EnvVarGuard::set( + CODEX_ARC_MONITOR_ENDPOINT_OVERRIDE, + OsStr::new(&format!("{}/override/arc", server.uri())), + ); + let _token_guard = EnvVarGuard::set(CODEX_ARC_MONITOR_TOKEN, OsStr::new("override-token")); + + let (session, turn_context) = make_session_and_context().await; + session + .record_into_history( + &[ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "please run the tool".to_string(), + }], + end_turn: None, + phase: None, + }], + &turn_context, + ) + .await; + + Mock::given(method("POST")) + .and(path("/override/arc")) + .and(header("authorization", "Bearer override-token")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "outcome": "steer-model", + "short_reason": "needs approval", + "rationale": "high-risk action", + "risk_score": 96, + "risk_level": "critical", + "evidence": [{ + "message": "browser_navigate", + "why": "high-risk action", + }], + }))) + .expect(1) + .mount(&server) + .await; + + let outcome = monitor_action( + &session, + &turn_context, + serde_json::json!({ "tool": "mcp_tool_call" }), + ) + .await; + + assert_eq!( + outcome, + ArcMonitorOutcome::SteerModel("high-risk action".to_string()) + ); +} + +#[tokio::test] +#[serial(arc_monitor_env)] +async fn monitor_action_rejects_legacy_response_fields() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/codex/safety/arc")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "outcome": "steer-model", + "reason": "legacy high-risk action", + "monitorRequestId": "arc_456", + }))) + .expect(1) + .mount(&server) + .await; + + let (session, mut turn_context) = make_session_and_context().await; + turn_context.auth_manager = Some(crate::test_support::auth_manager_from_auth( + crate::CodexAuth::create_dummy_chatgpt_auth_for_testing(), + )); + let mut config = (*turn_context.config).clone(); + config.chatgpt_base_url = server.uri(); + turn_context.config = Arc::new(config); + + session + .record_into_history( + &[ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "please run the tool".to_string(), + }], + end_turn: None, + phase: None, + }], + &turn_context, + ) + .await; + + let outcome = monitor_action( + &session, + &turn_context, + serde_json::json!({ "tool": "mcp_tool_call" }), + ) + .await; + + assert_eq!(outcome, ArcMonitorOutcome::Ok); +} diff --git a/codex-rs/core/src/auth.rs b/codex-rs/core/src/auth.rs index 9f13cdf2b5..8bb2b23d87 100644 --- a/codex-rs/core/src/auth.rs +++ b/codex-rs/core/src/auth.rs @@ -1366,437 +1366,5 @@ impl AuthManager { } #[cfg(test)] -mod tests { - use super::*; - use crate::auth::storage::FileAuthStorage; - use crate::auth::storage::get_auth_file; - use crate::config::Config; - use crate::config::ConfigBuilder; - use crate::token_data::IdTokenInfo; - use crate::token_data::KnownPlan as InternalKnownPlan; - use crate::token_data::PlanType as InternalPlanType; - use codex_protocol::account::PlanType as AccountPlanType; - - use base64::Engine; - use codex_protocol::config_types::ForcedLoginMethod; - use pretty_assertions::assert_eq; - use serde::Serialize; - use serde_json::json; - use tempfile::tempdir; - - #[tokio::test] - async fn refresh_without_id_token() { - let codex_home = tempdir().unwrap(); - let fake_jwt = write_auth_file( - AuthFileParams { - openai_api_key: None, - chatgpt_plan_type: Some("pro".to_string()), - chatgpt_account_id: None, - }, - codex_home.path(), - ) - .expect("failed to write auth file"); - - let storage = create_auth_storage( - codex_home.path().to_path_buf(), - AuthCredentialsStoreMode::File, - ); - let updated = super::persist_tokens( - &storage, - None, - Some("new-access-token".to_string()), - Some("new-refresh-token".to_string()), - ) - .expect("update_tokens should succeed"); - - let tokens = updated.tokens.expect("tokens should exist"); - assert_eq!(tokens.id_token.raw_jwt, fake_jwt); - assert_eq!(tokens.access_token, "new-access-token"); - assert_eq!(tokens.refresh_token, "new-refresh-token"); - } - - #[test] - fn login_with_api_key_overwrites_existing_auth_json() { - let dir = tempdir().unwrap(); - let auth_path = dir.path().join("auth.json"); - let stale_auth = json!({ - "OPENAI_API_KEY": "sk-old", - "tokens": { - "id_token": "stale.header.payload", - "access_token": "stale-access", - "refresh_token": "stale-refresh", - "account_id": "stale-acc" - } - }); - std::fs::write( - &auth_path, - serde_json::to_string_pretty(&stale_auth).unwrap(), - ) - .unwrap(); - - super::login_with_api_key(dir.path(), "sk-new", AuthCredentialsStoreMode::File) - .expect("login_with_api_key should succeed"); - - let storage = FileAuthStorage::new(dir.path().to_path_buf()); - let auth = storage - .try_read_auth_json(&auth_path) - .expect("auth.json should parse"); - assert_eq!(auth.openai_api_key.as_deref(), Some("sk-new")); - assert!(auth.tokens.is_none(), "tokens should be cleared"); - } - - #[test] - fn missing_auth_json_returns_none() { - let dir = tempdir().unwrap(); - let auth = CodexAuth::from_auth_storage(dir.path(), AuthCredentialsStoreMode::File) - .expect("call should succeed"); - assert_eq!(auth, None); - } - - #[tokio::test] - #[serial(codex_api_key)] - async fn pro_account_with_no_api_key_uses_chatgpt_auth() { - let codex_home = tempdir().unwrap(); - let fake_jwt = write_auth_file( - AuthFileParams { - openai_api_key: None, - chatgpt_plan_type: Some("pro".to_string()), - chatgpt_account_id: None, - }, - codex_home.path(), - ) - .expect("failed to write auth file"); - - let auth = super::load_auth(codex_home.path(), false, AuthCredentialsStoreMode::File) - .unwrap() - .unwrap(); - assert_eq!(None, auth.api_key()); - assert_eq!(AuthMode::Chatgpt, auth.auth_mode()); - assert_eq!(auth.get_chatgpt_user_id().as_deref(), Some("user-12345")); - - let auth_dot_json = auth - .get_current_auth_json() - .expect("AuthDotJson should exist"); - let last_refresh = auth_dot_json - .last_refresh - .expect("last_refresh should be recorded"); - - assert_eq!( - AuthDotJson { - auth_mode: None, - openai_api_key: None, - tokens: Some(TokenData { - id_token: IdTokenInfo { - email: Some("user@example.com".to_string()), - chatgpt_plan_type: Some(InternalPlanType::Known(InternalKnownPlan::Pro)), - chatgpt_user_id: Some("user-12345".to_string()), - chatgpt_account_id: None, - raw_jwt: fake_jwt, - }, - access_token: "test-access-token".to_string(), - refresh_token: "test-refresh-token".to_string(), - account_id: None, - }), - last_refresh: Some(last_refresh), - }, - auth_dot_json - ); - } - - #[tokio::test] - #[serial(codex_api_key)] - async fn loads_api_key_from_auth_json() { - let dir = tempdir().unwrap(); - let auth_file = dir.path().join("auth.json"); - std::fs::write( - auth_file, - r#"{"OPENAI_API_KEY":"sk-test-key","tokens":null,"last_refresh":null}"#, - ) - .unwrap(); - - let auth = super::load_auth(dir.path(), false, AuthCredentialsStoreMode::File) - .unwrap() - .unwrap(); - assert_eq!(auth.auth_mode(), AuthMode::ApiKey); - assert_eq!(auth.api_key(), Some("sk-test-key")); - - assert!(auth.get_token_data().is_err()); - } - - #[test] - fn logout_removes_auth_file() -> Result<(), std::io::Error> { - let dir = tempdir()?; - let auth_dot_json = AuthDotJson { - auth_mode: Some(ApiAuthMode::ApiKey), - openai_api_key: Some("sk-test-key".to_string()), - tokens: None, - last_refresh: None, - }; - super::save_auth(dir.path(), &auth_dot_json, AuthCredentialsStoreMode::File)?; - let auth_file = get_auth_file(dir.path()); - assert!(auth_file.exists()); - assert!(logout(dir.path(), AuthCredentialsStoreMode::File)?); - assert!(!auth_file.exists()); - Ok(()) - } - - struct AuthFileParams { - openai_api_key: Option, - chatgpt_plan_type: Option, - chatgpt_account_id: Option, - } - - fn write_auth_file(params: AuthFileParams, codex_home: &Path) -> std::io::Result { - let auth_file = get_auth_file(codex_home); - // Create a minimal valid JWT for the id_token field. - #[derive(Serialize)] - struct Header { - alg: &'static str, - typ: &'static str, - } - let header = Header { - alg: "none", - typ: "JWT", - }; - let mut auth_payload = serde_json::json!({ - "chatgpt_user_id": "user-12345", - "user_id": "user-12345", - }); - - if let Some(chatgpt_plan_type) = params.chatgpt_plan_type { - auth_payload["chatgpt_plan_type"] = serde_json::Value::String(chatgpt_plan_type); - } - - if let Some(chatgpt_account_id) = params.chatgpt_account_id { - let org_value = serde_json::Value::String(chatgpt_account_id); - auth_payload["chatgpt_account_id"] = org_value; - } - - let payload = serde_json::json!({ - "email": "user@example.com", - "email_verified": true, - "https://api.openai.com/auth": auth_payload, - }); - let b64 = |b: &[u8]| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(b); - let header_b64 = b64(&serde_json::to_vec(&header)?); - let payload_b64 = b64(&serde_json::to_vec(&payload)?); - let signature_b64 = b64(b"sig"); - let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}"); - - let auth_json_data = json!({ - "OPENAI_API_KEY": params.openai_api_key, - "tokens": { - "id_token": fake_jwt, - "access_token": "test-access-token", - "refresh_token": "test-refresh-token" - }, - "last_refresh": Utc::now(), - }); - let auth_json = serde_json::to_string_pretty(&auth_json_data)?; - std::fs::write(auth_file, auth_json)?; - Ok(fake_jwt) - } - - async fn build_config( - codex_home: &Path, - forced_login_method: Option, - forced_chatgpt_workspace_id: Option, - ) -> Config { - let mut config = ConfigBuilder::default() - .codex_home(codex_home.to_path_buf()) - .build() - .await - .expect("config should load"); - config.forced_login_method = forced_login_method; - config.forced_chatgpt_workspace_id = forced_chatgpt_workspace_id; - config - } - - /// Use sparingly. - /// TODO (gpeal): replace this with an injectable env var provider. - #[cfg(test)] - struct EnvVarGuard { - key: &'static str, - original: Option, - } - - #[cfg(test)] - impl EnvVarGuard { - fn set(key: &'static str, value: &str) -> Self { - let original = env::var_os(key); - unsafe { - env::set_var(key, value); - } - Self { key, original } - } - } - - #[cfg(test)] - impl Drop for EnvVarGuard { - fn drop(&mut self) { - unsafe { - match &self.original { - Some(value) => env::set_var(self.key, value), - None => env::remove_var(self.key), - } - } - } - } - - #[tokio::test] - async fn enforce_login_restrictions_logs_out_for_method_mismatch() { - let codex_home = tempdir().unwrap(); - login_with_api_key(codex_home.path(), "sk-test", AuthCredentialsStoreMode::File) - .expect("seed api key"); - - let config = build_config(codex_home.path(), Some(ForcedLoginMethod::Chatgpt), None).await; - - let err = super::enforce_login_restrictions(&config) - .expect_err("expected method mismatch to error"); - assert!(err.to_string().contains("ChatGPT login is required")); - assert!( - !codex_home.path().join("auth.json").exists(), - "auth.json should be removed on mismatch" - ); - } - - #[tokio::test] - #[serial(codex_api_key)] - async fn enforce_login_restrictions_logs_out_for_workspace_mismatch() { - let codex_home = tempdir().unwrap(); - let _jwt = write_auth_file( - AuthFileParams { - openai_api_key: None, - chatgpt_plan_type: Some("pro".to_string()), - chatgpt_account_id: Some("org_another_org".to_string()), - }, - codex_home.path(), - ) - .expect("failed to write auth file"); - - let config = build_config(codex_home.path(), None, Some("org_mine".to_string())).await; - - let err = super::enforce_login_restrictions(&config) - .expect_err("expected workspace mismatch to error"); - assert!(err.to_string().contains("workspace org_mine")); - assert!( - !codex_home.path().join("auth.json").exists(), - "auth.json should be removed on mismatch" - ); - } - - #[tokio::test] - #[serial(codex_api_key)] - async fn enforce_login_restrictions_allows_matching_workspace() { - let codex_home = tempdir().unwrap(); - let _jwt = write_auth_file( - AuthFileParams { - openai_api_key: None, - chatgpt_plan_type: Some("pro".to_string()), - chatgpt_account_id: Some("org_mine".to_string()), - }, - codex_home.path(), - ) - .expect("failed to write auth file"); - - let config = build_config(codex_home.path(), None, Some("org_mine".to_string())).await; - - super::enforce_login_restrictions(&config).expect("matching workspace should succeed"); - assert!( - codex_home.path().join("auth.json").exists(), - "auth.json should remain when restrictions pass" - ); - } - - #[tokio::test] - async fn enforce_login_restrictions_allows_api_key_if_login_method_not_set_but_forced_chatgpt_workspace_id_is_set() - { - let codex_home = tempdir().unwrap(); - login_with_api_key(codex_home.path(), "sk-test", AuthCredentialsStoreMode::File) - .expect("seed api key"); - - let config = build_config(codex_home.path(), None, Some("org_mine".to_string())).await; - - super::enforce_login_restrictions(&config).expect("matching workspace should succeed"); - assert!( - codex_home.path().join("auth.json").exists(), - "auth.json should remain when restrictions pass" - ); - } - - #[tokio::test] - #[serial(codex_api_key)] - async fn enforce_login_restrictions_blocks_env_api_key_when_chatgpt_required() { - let _guard = EnvVarGuard::set(CODEX_API_KEY_ENV_VAR, "sk-env"); - let codex_home = tempdir().unwrap(); - - let config = build_config(codex_home.path(), Some(ForcedLoginMethod::Chatgpt), None).await; - - let err = super::enforce_login_restrictions(&config) - .expect_err("environment API key should not satisfy forced ChatGPT login"); - assert!( - err.to_string() - .contains("ChatGPT login is required, but an API key is currently being used.") - ); - } - - #[test] - fn plan_type_maps_known_plan() { - let codex_home = tempdir().unwrap(); - let _jwt = write_auth_file( - AuthFileParams { - openai_api_key: None, - chatgpt_plan_type: Some("pro".to_string()), - chatgpt_account_id: None, - }, - codex_home.path(), - ) - .expect("failed to write auth file"); - - let auth = super::load_auth(codex_home.path(), false, AuthCredentialsStoreMode::File) - .expect("load auth") - .expect("auth available"); - - pretty_assertions::assert_eq!(auth.account_plan_type(), Some(AccountPlanType::Pro)); - } - - #[test] - fn plan_type_maps_unknown_to_unknown() { - let codex_home = tempdir().unwrap(); - let _jwt = write_auth_file( - AuthFileParams { - openai_api_key: None, - chatgpt_plan_type: Some("mystery-tier".to_string()), - chatgpt_account_id: None, - }, - codex_home.path(), - ) - .expect("failed to write auth file"); - - let auth = super::load_auth(codex_home.path(), false, AuthCredentialsStoreMode::File) - .expect("load auth") - .expect("auth available"); - - pretty_assertions::assert_eq!(auth.account_plan_type(), Some(AccountPlanType::Unknown)); - } - - #[test] - fn missing_plan_type_maps_to_unknown() { - let codex_home = tempdir().unwrap(); - let _jwt = write_auth_file( - AuthFileParams { - openai_api_key: None, - chatgpt_plan_type: None, - chatgpt_account_id: None, - }, - codex_home.path(), - ) - .expect("failed to write auth file"); - - let auth = super::load_auth(codex_home.path(), false, AuthCredentialsStoreMode::File) - .expect("load auth") - .expect("auth available"); - - pretty_assertions::assert_eq!(auth.account_plan_type(), Some(AccountPlanType::Unknown)); - } -} +#[path = "auth_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/auth/storage.rs b/codex-rs/core/src/auth/storage.rs index 81d17e4e1e..b1e04b8685 100644 --- a/codex-rs/core/src/auth/storage.rs +++ b/codex-rs/core/src/auth/storage.rs @@ -332,427 +332,5 @@ fn create_auth_storage_with_keyring_store( } #[cfg(test)] -mod tests { - use super::*; - use crate::token_data::IdTokenInfo; - use anyhow::Context; - use base64::Engine; - use pretty_assertions::assert_eq; - use serde_json::json; - use tempfile::tempdir; - - use codex_keyring_store::tests::MockKeyringStore; - use keyring::Error as KeyringError; - - #[tokio::test] - async fn file_storage_load_returns_auth_dot_json() -> anyhow::Result<()> { - let codex_home = tempdir()?; - let storage = FileAuthStorage::new(codex_home.path().to_path_buf()); - let auth_dot_json = AuthDotJson { - auth_mode: Some(AuthMode::ApiKey), - openai_api_key: Some("test-key".to_string()), - tokens: None, - last_refresh: Some(Utc::now()), - }; - - storage - .save(&auth_dot_json) - .context("failed to save auth file")?; - - let loaded = storage.load().context("failed to load auth file")?; - assert_eq!(Some(auth_dot_json), loaded); - Ok(()) - } - - #[tokio::test] - async fn file_storage_save_persists_auth_dot_json() -> anyhow::Result<()> { - let codex_home = tempdir()?; - let storage = FileAuthStorage::new(codex_home.path().to_path_buf()); - let auth_dot_json = AuthDotJson { - auth_mode: Some(AuthMode::ApiKey), - openai_api_key: Some("test-key".to_string()), - tokens: None, - last_refresh: Some(Utc::now()), - }; - - let file = get_auth_file(codex_home.path()); - storage - .save(&auth_dot_json) - .context("failed to save auth file")?; - - let same_auth_dot_json = storage - .try_read_auth_json(&file) - .context("failed to read auth file after save")?; - assert_eq!(auth_dot_json, same_auth_dot_json); - Ok(()) - } - - #[test] - fn file_storage_delete_removes_auth_file() -> anyhow::Result<()> { - let dir = tempdir()?; - let auth_dot_json = AuthDotJson { - auth_mode: Some(AuthMode::ApiKey), - openai_api_key: Some("sk-test-key".to_string()), - tokens: None, - last_refresh: None, - }; - let storage = create_auth_storage(dir.path().to_path_buf(), AuthCredentialsStoreMode::File); - storage.save(&auth_dot_json)?; - assert!(dir.path().join("auth.json").exists()); - let storage = FileAuthStorage::new(dir.path().to_path_buf()); - let removed = storage.delete()?; - assert!(removed); - assert!(!dir.path().join("auth.json").exists()); - Ok(()) - } - - #[test] - fn ephemeral_storage_save_load_delete_is_in_memory_only() -> anyhow::Result<()> { - let dir = tempdir()?; - let storage = create_auth_storage( - dir.path().to_path_buf(), - AuthCredentialsStoreMode::Ephemeral, - ); - let auth_dot_json = AuthDotJson { - auth_mode: Some(AuthMode::ApiKey), - openai_api_key: Some("sk-ephemeral".to_string()), - tokens: None, - last_refresh: Some(Utc::now()), - }; - - storage.save(&auth_dot_json)?; - let loaded = storage.load()?; - assert_eq!(Some(auth_dot_json), loaded); - - let removed = storage.delete()?; - assert!(removed); - let loaded = storage.load()?; - assert_eq!(None, loaded); - assert!(!get_auth_file(dir.path()).exists()); - Ok(()) - } - - fn seed_keyring_and_fallback_auth_file_for_delete( - mock_keyring: &MockKeyringStore, - codex_home: &Path, - compute_key: F, - ) -> anyhow::Result<(String, PathBuf)> - where - F: FnOnce() -> std::io::Result, - { - let key = compute_key()?; - mock_keyring.save(KEYRING_SERVICE, &key, "{}")?; - let auth_file = get_auth_file(codex_home); - std::fs::write(&auth_file, "stale")?; - Ok((key, auth_file)) - } - - fn seed_keyring_with_auth( - mock_keyring: &MockKeyringStore, - compute_key: F, - auth: &AuthDotJson, - ) -> anyhow::Result<()> - where - F: FnOnce() -> std::io::Result, - { - let key = compute_key()?; - let serialized = serde_json::to_string(auth)?; - mock_keyring.save(KEYRING_SERVICE, &key, &serialized)?; - Ok(()) - } - - fn assert_keyring_saved_auth_and_removed_fallback( - mock_keyring: &MockKeyringStore, - key: &str, - codex_home: &Path, - expected: &AuthDotJson, - ) { - let saved_value = mock_keyring - .saved_value(key) - .expect("keyring entry should exist"); - let expected_serialized = serde_json::to_string(expected).expect("serialize expected auth"); - assert_eq!(saved_value, expected_serialized); - let auth_file = get_auth_file(codex_home); - assert!( - !auth_file.exists(), - "fallback auth.json should be removed after keyring save" - ); - } - - fn id_token_with_prefix(prefix: &str) -> IdTokenInfo { - #[derive(Serialize)] - struct Header { - alg: &'static str, - typ: &'static str, - } - - let header = Header { - alg: "none", - typ: "JWT", - }; - let payload = json!({ - "email": format!("{prefix}@example.com"), - "https://api.openai.com/auth": { - "chatgpt_account_id": format!("{prefix}-account"), - }, - }); - let encode = |bytes: &[u8]| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes); - let header_b64 = encode(&serde_json::to_vec(&header).expect("serialize header")); - let payload_b64 = encode(&serde_json::to_vec(&payload).expect("serialize payload")); - let signature_b64 = encode(b"sig"); - let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}"); - - crate::token_data::parse_chatgpt_jwt_claims(&fake_jwt).expect("fake JWT should parse") - } - - fn auth_with_prefix(prefix: &str) -> AuthDotJson { - AuthDotJson { - auth_mode: Some(AuthMode::ApiKey), - openai_api_key: Some(format!("{prefix}-api-key")), - tokens: Some(TokenData { - id_token: id_token_with_prefix(prefix), - access_token: format!("{prefix}-access"), - refresh_token: format!("{prefix}-refresh"), - account_id: Some(format!("{prefix}-account-id")), - }), - last_refresh: None, - } - } - - #[test] - fn keyring_auth_storage_load_returns_deserialized_auth() -> anyhow::Result<()> { - let codex_home = tempdir()?; - let mock_keyring = MockKeyringStore::default(); - let storage = KeyringAuthStorage::new( - codex_home.path().to_path_buf(), - Arc::new(mock_keyring.clone()), - ); - let expected = AuthDotJson { - auth_mode: Some(AuthMode::ApiKey), - openai_api_key: Some("sk-test".to_string()), - tokens: None, - last_refresh: None, - }; - seed_keyring_with_auth( - &mock_keyring, - || compute_store_key(codex_home.path()), - &expected, - )?; - - let loaded = storage.load()?; - assert_eq!(Some(expected), loaded); - Ok(()) - } - - #[test] - fn keyring_auth_storage_compute_store_key_for_home_directory() -> anyhow::Result<()> { - let codex_home = PathBuf::from("~/.codex"); - - let key = compute_store_key(codex_home.as_path())?; - - assert_eq!(key, "cli|940db7b1d0e4eb40"); - Ok(()) - } - - #[test] - fn keyring_auth_storage_save_persists_and_removes_fallback_file() -> anyhow::Result<()> { - let codex_home = tempdir()?; - let mock_keyring = MockKeyringStore::default(); - let storage = KeyringAuthStorage::new( - codex_home.path().to_path_buf(), - Arc::new(mock_keyring.clone()), - ); - let auth_file = get_auth_file(codex_home.path()); - std::fs::write(&auth_file, "stale")?; - let auth = AuthDotJson { - auth_mode: Some(AuthMode::Chatgpt), - openai_api_key: None, - tokens: Some(TokenData { - id_token: Default::default(), - access_token: "access".to_string(), - refresh_token: "refresh".to_string(), - account_id: Some("account".to_string()), - }), - last_refresh: Some(Utc::now()), - }; - - storage.save(&auth)?; - - let key = compute_store_key(codex_home.path())?; - assert_keyring_saved_auth_and_removed_fallback( - &mock_keyring, - &key, - codex_home.path(), - &auth, - ); - Ok(()) - } - - #[test] - fn keyring_auth_storage_delete_removes_keyring_and_file() -> anyhow::Result<()> { - let codex_home = tempdir()?; - let mock_keyring = MockKeyringStore::default(); - let storage = KeyringAuthStorage::new( - codex_home.path().to_path_buf(), - Arc::new(mock_keyring.clone()), - ); - let (key, auth_file) = seed_keyring_and_fallback_auth_file_for_delete( - &mock_keyring, - codex_home.path(), - || compute_store_key(codex_home.path()), - )?; - - let removed = storage.delete()?; - - assert!(removed, "delete should report removal"); - assert!( - !mock_keyring.contains(&key), - "keyring entry should be removed" - ); - assert!( - !auth_file.exists(), - "fallback auth.json should be removed after keyring delete" - ); - Ok(()) - } - - #[test] - fn auto_auth_storage_load_prefers_keyring_value() -> anyhow::Result<()> { - let codex_home = tempdir()?; - let mock_keyring = MockKeyringStore::default(); - let storage = AutoAuthStorage::new( - codex_home.path().to_path_buf(), - Arc::new(mock_keyring.clone()), - ); - let keyring_auth = auth_with_prefix("keyring"); - seed_keyring_with_auth( - &mock_keyring, - || compute_store_key(codex_home.path()), - &keyring_auth, - )?; - - let file_auth = auth_with_prefix("file"); - storage.file_storage.save(&file_auth)?; - - let loaded = storage.load()?; - assert_eq!(loaded, Some(keyring_auth)); - Ok(()) - } - - #[test] - fn auto_auth_storage_load_uses_file_when_keyring_empty() -> anyhow::Result<()> { - let codex_home = tempdir()?; - let mock_keyring = MockKeyringStore::default(); - let storage = AutoAuthStorage::new(codex_home.path().to_path_buf(), Arc::new(mock_keyring)); - - let expected = auth_with_prefix("file-only"); - storage.file_storage.save(&expected)?; - - let loaded = storage.load()?; - assert_eq!(loaded, Some(expected)); - Ok(()) - } - - #[test] - fn auto_auth_storage_load_falls_back_when_keyring_errors() -> anyhow::Result<()> { - let codex_home = tempdir()?; - let mock_keyring = MockKeyringStore::default(); - let storage = AutoAuthStorage::new( - codex_home.path().to_path_buf(), - Arc::new(mock_keyring.clone()), - ); - let key = compute_store_key(codex_home.path())?; - mock_keyring.set_error(&key, KeyringError::Invalid("error".into(), "load".into())); - - let expected = auth_with_prefix("fallback"); - storage.file_storage.save(&expected)?; - - let loaded = storage.load()?; - assert_eq!(loaded, Some(expected)); - Ok(()) - } - - #[test] - fn auto_auth_storage_save_prefers_keyring() -> anyhow::Result<()> { - let codex_home = tempdir()?; - let mock_keyring = MockKeyringStore::default(); - let storage = AutoAuthStorage::new( - codex_home.path().to_path_buf(), - Arc::new(mock_keyring.clone()), - ); - let key = compute_store_key(codex_home.path())?; - - let stale = auth_with_prefix("stale"); - storage.file_storage.save(&stale)?; - - let expected = auth_with_prefix("to-save"); - storage.save(&expected)?; - - assert_keyring_saved_auth_and_removed_fallback( - &mock_keyring, - &key, - codex_home.path(), - &expected, - ); - Ok(()) - } - - #[test] - fn auto_auth_storage_save_falls_back_when_keyring_errors() -> anyhow::Result<()> { - let codex_home = tempdir()?; - let mock_keyring = MockKeyringStore::default(); - let storage = AutoAuthStorage::new( - codex_home.path().to_path_buf(), - Arc::new(mock_keyring.clone()), - ); - let key = compute_store_key(codex_home.path())?; - mock_keyring.set_error(&key, KeyringError::Invalid("error".into(), "save".into())); - - let auth = auth_with_prefix("fallback"); - storage.save(&auth)?; - - let auth_file = get_auth_file(codex_home.path()); - assert!( - auth_file.exists(), - "fallback auth.json should be created when keyring save fails" - ); - let saved = storage - .file_storage - .load()? - .context("fallback auth should exist")?; - assert_eq!(saved, auth); - assert!( - mock_keyring.saved_value(&key).is_none(), - "keyring should not contain value when save fails" - ); - Ok(()) - } - - #[test] - fn auto_auth_storage_delete_removes_keyring_and_file() -> anyhow::Result<()> { - let codex_home = tempdir()?; - let mock_keyring = MockKeyringStore::default(); - let storage = AutoAuthStorage::new( - codex_home.path().to_path_buf(), - Arc::new(mock_keyring.clone()), - ); - let (key, auth_file) = seed_keyring_and_fallback_auth_file_for_delete( - &mock_keyring, - codex_home.path(), - || compute_store_key(codex_home.path()), - )?; - - let removed = storage.delete()?; - - assert!(removed, "delete should report removal"); - assert!( - !mock_keyring.contains(&key), - "keyring entry should be removed" - ); - assert!( - !auth_file.exists(), - "fallback auth.json should be removed after delete" - ); - Ok(()) - } -} +#[path = "storage_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/auth/storage_tests.rs b/codex-rs/core/src/auth/storage_tests.rs new file mode 100644 index 0000000000..4bf72c11b9 --- /dev/null +++ b/codex-rs/core/src/auth/storage_tests.rs @@ -0,0 +1,415 @@ +use super::*; +use crate::token_data::IdTokenInfo; +use anyhow::Context; +use base64::Engine; +use pretty_assertions::assert_eq; +use serde_json::json; +use tempfile::tempdir; + +use codex_keyring_store::tests::MockKeyringStore; +use keyring::Error as KeyringError; + +#[tokio::test] +async fn file_storage_load_returns_auth_dot_json() -> anyhow::Result<()> { + let codex_home = tempdir()?; + let storage = FileAuthStorage::new(codex_home.path().to_path_buf()); + let auth_dot_json = AuthDotJson { + auth_mode: Some(AuthMode::ApiKey), + openai_api_key: Some("test-key".to_string()), + tokens: None, + last_refresh: Some(Utc::now()), + }; + + storage + .save(&auth_dot_json) + .context("failed to save auth file")?; + + let loaded = storage.load().context("failed to load auth file")?; + assert_eq!(Some(auth_dot_json), loaded); + Ok(()) +} + +#[tokio::test] +async fn file_storage_save_persists_auth_dot_json() -> anyhow::Result<()> { + let codex_home = tempdir()?; + let storage = FileAuthStorage::new(codex_home.path().to_path_buf()); + let auth_dot_json = AuthDotJson { + auth_mode: Some(AuthMode::ApiKey), + openai_api_key: Some("test-key".to_string()), + tokens: None, + last_refresh: Some(Utc::now()), + }; + + let file = get_auth_file(codex_home.path()); + storage + .save(&auth_dot_json) + .context("failed to save auth file")?; + + let same_auth_dot_json = storage + .try_read_auth_json(&file) + .context("failed to read auth file after save")?; + assert_eq!(auth_dot_json, same_auth_dot_json); + Ok(()) +} + +#[test] +fn file_storage_delete_removes_auth_file() -> anyhow::Result<()> { + let dir = tempdir()?; + let auth_dot_json = AuthDotJson { + auth_mode: Some(AuthMode::ApiKey), + openai_api_key: Some("sk-test-key".to_string()), + tokens: None, + last_refresh: None, + }; + let storage = create_auth_storage(dir.path().to_path_buf(), AuthCredentialsStoreMode::File); + storage.save(&auth_dot_json)?; + assert!(dir.path().join("auth.json").exists()); + let storage = FileAuthStorage::new(dir.path().to_path_buf()); + let removed = storage.delete()?; + assert!(removed); + assert!(!dir.path().join("auth.json").exists()); + Ok(()) +} + +#[test] +fn ephemeral_storage_save_load_delete_is_in_memory_only() -> anyhow::Result<()> { + let dir = tempdir()?; + let storage = create_auth_storage( + dir.path().to_path_buf(), + AuthCredentialsStoreMode::Ephemeral, + ); + let auth_dot_json = AuthDotJson { + auth_mode: Some(AuthMode::ApiKey), + openai_api_key: Some("sk-ephemeral".to_string()), + tokens: None, + last_refresh: Some(Utc::now()), + }; + + storage.save(&auth_dot_json)?; + let loaded = storage.load()?; + assert_eq!(Some(auth_dot_json), loaded); + + let removed = storage.delete()?; + assert!(removed); + let loaded = storage.load()?; + assert_eq!(None, loaded); + assert!(!get_auth_file(dir.path()).exists()); + Ok(()) +} + +fn seed_keyring_and_fallback_auth_file_for_delete( + mock_keyring: &MockKeyringStore, + codex_home: &Path, + compute_key: F, +) -> anyhow::Result<(String, PathBuf)> +where + F: FnOnce() -> std::io::Result, +{ + let key = compute_key()?; + mock_keyring.save(KEYRING_SERVICE, &key, "{}")?; + let auth_file = get_auth_file(codex_home); + std::fs::write(&auth_file, "stale")?; + Ok((key, auth_file)) +} + +fn seed_keyring_with_auth( + mock_keyring: &MockKeyringStore, + compute_key: F, + auth: &AuthDotJson, +) -> anyhow::Result<()> +where + F: FnOnce() -> std::io::Result, +{ + let key = compute_key()?; + let serialized = serde_json::to_string(auth)?; + mock_keyring.save(KEYRING_SERVICE, &key, &serialized)?; + Ok(()) +} + +fn assert_keyring_saved_auth_and_removed_fallback( + mock_keyring: &MockKeyringStore, + key: &str, + codex_home: &Path, + expected: &AuthDotJson, +) { + let saved_value = mock_keyring + .saved_value(key) + .expect("keyring entry should exist"); + let expected_serialized = serde_json::to_string(expected).expect("serialize expected auth"); + assert_eq!(saved_value, expected_serialized); + let auth_file = get_auth_file(codex_home); + assert!( + !auth_file.exists(), + "fallback auth.json should be removed after keyring save" + ); +} + +fn id_token_with_prefix(prefix: &str) -> IdTokenInfo { + #[derive(Serialize)] + struct Header { + alg: &'static str, + typ: &'static str, + } + + let header = Header { + alg: "none", + typ: "JWT", + }; + let payload = json!({ + "email": format!("{prefix}@example.com"), + "https://api.openai.com/auth": { + "chatgpt_account_id": format!("{prefix}-account"), + }, + }); + let encode = |bytes: &[u8]| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes); + let header_b64 = encode(&serde_json::to_vec(&header).expect("serialize header")); + let payload_b64 = encode(&serde_json::to_vec(&payload).expect("serialize payload")); + let signature_b64 = encode(b"sig"); + let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}"); + + crate::token_data::parse_chatgpt_jwt_claims(&fake_jwt).expect("fake JWT should parse") +} + +fn auth_with_prefix(prefix: &str) -> AuthDotJson { + AuthDotJson { + auth_mode: Some(AuthMode::ApiKey), + openai_api_key: Some(format!("{prefix}-api-key")), + tokens: Some(TokenData { + id_token: id_token_with_prefix(prefix), + access_token: format!("{prefix}-access"), + refresh_token: format!("{prefix}-refresh"), + account_id: Some(format!("{prefix}-account-id")), + }), + last_refresh: None, + } +} + +#[test] +fn keyring_auth_storage_load_returns_deserialized_auth() -> anyhow::Result<()> { + let codex_home = tempdir()?; + let mock_keyring = MockKeyringStore::default(); + let storage = KeyringAuthStorage::new( + codex_home.path().to_path_buf(), + Arc::new(mock_keyring.clone()), + ); + let expected = AuthDotJson { + auth_mode: Some(AuthMode::ApiKey), + openai_api_key: Some("sk-test".to_string()), + tokens: None, + last_refresh: None, + }; + seed_keyring_with_auth( + &mock_keyring, + || compute_store_key(codex_home.path()), + &expected, + )?; + + let loaded = storage.load()?; + assert_eq!(Some(expected), loaded); + Ok(()) +} + +#[test] +fn keyring_auth_storage_compute_store_key_for_home_directory() -> anyhow::Result<()> { + let codex_home = PathBuf::from("~/.codex"); + + let key = compute_store_key(codex_home.as_path())?; + + assert_eq!(key, "cli|940db7b1d0e4eb40"); + Ok(()) +} + +#[test] +fn keyring_auth_storage_save_persists_and_removes_fallback_file() -> anyhow::Result<()> { + let codex_home = tempdir()?; + let mock_keyring = MockKeyringStore::default(); + let storage = KeyringAuthStorage::new( + codex_home.path().to_path_buf(), + Arc::new(mock_keyring.clone()), + ); + let auth_file = get_auth_file(codex_home.path()); + std::fs::write(&auth_file, "stale")?; + let auth = AuthDotJson { + auth_mode: Some(AuthMode::Chatgpt), + openai_api_key: None, + tokens: Some(TokenData { + id_token: Default::default(), + access_token: "access".to_string(), + refresh_token: "refresh".to_string(), + account_id: Some("account".to_string()), + }), + last_refresh: Some(Utc::now()), + }; + + storage.save(&auth)?; + + let key = compute_store_key(codex_home.path())?; + assert_keyring_saved_auth_and_removed_fallback(&mock_keyring, &key, codex_home.path(), &auth); + Ok(()) +} + +#[test] +fn keyring_auth_storage_delete_removes_keyring_and_file() -> anyhow::Result<()> { + let codex_home = tempdir()?; + let mock_keyring = MockKeyringStore::default(); + let storage = KeyringAuthStorage::new( + codex_home.path().to_path_buf(), + Arc::new(mock_keyring.clone()), + ); + let (key, auth_file) = + seed_keyring_and_fallback_auth_file_for_delete(&mock_keyring, codex_home.path(), || { + compute_store_key(codex_home.path()) + })?; + + let removed = storage.delete()?; + + assert!(removed, "delete should report removal"); + assert!( + !mock_keyring.contains(&key), + "keyring entry should be removed" + ); + assert!( + !auth_file.exists(), + "fallback auth.json should be removed after keyring delete" + ); + Ok(()) +} + +#[test] +fn auto_auth_storage_load_prefers_keyring_value() -> anyhow::Result<()> { + let codex_home = tempdir()?; + let mock_keyring = MockKeyringStore::default(); + let storage = AutoAuthStorage::new( + codex_home.path().to_path_buf(), + Arc::new(mock_keyring.clone()), + ); + let keyring_auth = auth_with_prefix("keyring"); + seed_keyring_with_auth( + &mock_keyring, + || compute_store_key(codex_home.path()), + &keyring_auth, + )?; + + let file_auth = auth_with_prefix("file"); + storage.file_storage.save(&file_auth)?; + + let loaded = storage.load()?; + assert_eq!(loaded, Some(keyring_auth)); + Ok(()) +} + +#[test] +fn auto_auth_storage_load_uses_file_when_keyring_empty() -> anyhow::Result<()> { + let codex_home = tempdir()?; + let mock_keyring = MockKeyringStore::default(); + let storage = AutoAuthStorage::new(codex_home.path().to_path_buf(), Arc::new(mock_keyring)); + + let expected = auth_with_prefix("file-only"); + storage.file_storage.save(&expected)?; + + let loaded = storage.load()?; + assert_eq!(loaded, Some(expected)); + Ok(()) +} + +#[test] +fn auto_auth_storage_load_falls_back_when_keyring_errors() -> anyhow::Result<()> { + let codex_home = tempdir()?; + let mock_keyring = MockKeyringStore::default(); + let storage = AutoAuthStorage::new( + codex_home.path().to_path_buf(), + Arc::new(mock_keyring.clone()), + ); + let key = compute_store_key(codex_home.path())?; + mock_keyring.set_error(&key, KeyringError::Invalid("error".into(), "load".into())); + + let expected = auth_with_prefix("fallback"); + storage.file_storage.save(&expected)?; + + let loaded = storage.load()?; + assert_eq!(loaded, Some(expected)); + Ok(()) +} + +#[test] +fn auto_auth_storage_save_prefers_keyring() -> anyhow::Result<()> { + let codex_home = tempdir()?; + let mock_keyring = MockKeyringStore::default(); + let storage = AutoAuthStorage::new( + codex_home.path().to_path_buf(), + Arc::new(mock_keyring.clone()), + ); + let key = compute_store_key(codex_home.path())?; + + let stale = auth_with_prefix("stale"); + storage.file_storage.save(&stale)?; + + let expected = auth_with_prefix("to-save"); + storage.save(&expected)?; + + assert_keyring_saved_auth_and_removed_fallback( + &mock_keyring, + &key, + codex_home.path(), + &expected, + ); + Ok(()) +} + +#[test] +fn auto_auth_storage_save_falls_back_when_keyring_errors() -> anyhow::Result<()> { + let codex_home = tempdir()?; + let mock_keyring = MockKeyringStore::default(); + let storage = AutoAuthStorage::new( + codex_home.path().to_path_buf(), + Arc::new(mock_keyring.clone()), + ); + let key = compute_store_key(codex_home.path())?; + mock_keyring.set_error(&key, KeyringError::Invalid("error".into(), "save".into())); + + let auth = auth_with_prefix("fallback"); + storage.save(&auth)?; + + let auth_file = get_auth_file(codex_home.path()); + assert!( + auth_file.exists(), + "fallback auth.json should be created when keyring save fails" + ); + let saved = storage + .file_storage + .load()? + .context("fallback auth should exist")?; + assert_eq!(saved, auth); + assert!( + mock_keyring.saved_value(&key).is_none(), + "keyring should not contain value when save fails" + ); + Ok(()) +} + +#[test] +fn auto_auth_storage_delete_removes_keyring_and_file() -> anyhow::Result<()> { + let codex_home = tempdir()?; + let mock_keyring = MockKeyringStore::default(); + let storage = AutoAuthStorage::new( + codex_home.path().to_path_buf(), + Arc::new(mock_keyring.clone()), + ); + let (key, auth_file) = + seed_keyring_and_fallback_auth_file_for_delete(&mock_keyring, codex_home.path(), || { + compute_store_key(codex_home.path()) + })?; + + let removed = storage.delete()?; + + assert!(removed, "delete should report removal"); + assert!( + !mock_keyring.contains(&key), + "keyring entry should be removed" + ); + assert!( + !auth_file.exists(), + "fallback auth.json should be removed after delete" + ); + Ok(()) +} diff --git a/codex-rs/core/src/auth_tests.rs b/codex-rs/core/src/auth_tests.rs new file mode 100644 index 0000000000..0c4a574f34 --- /dev/null +++ b/codex-rs/core/src/auth_tests.rs @@ -0,0 +1,432 @@ +use super::*; +use crate::auth::storage::FileAuthStorage; +use crate::auth::storage::get_auth_file; +use crate::config::Config; +use crate::config::ConfigBuilder; +use crate::token_data::IdTokenInfo; +use crate::token_data::KnownPlan as InternalKnownPlan; +use crate::token_data::PlanType as InternalPlanType; +use codex_protocol::account::PlanType as AccountPlanType; + +use base64::Engine; +use codex_protocol::config_types::ForcedLoginMethod; +use pretty_assertions::assert_eq; +use serde::Serialize; +use serde_json::json; +use tempfile::tempdir; + +#[tokio::test] +async fn refresh_without_id_token() { + let codex_home = tempdir().unwrap(); + let fake_jwt = write_auth_file( + AuthFileParams { + openai_api_key: None, + chatgpt_plan_type: Some("pro".to_string()), + chatgpt_account_id: None, + }, + codex_home.path(), + ) + .expect("failed to write auth file"); + + let storage = create_auth_storage( + codex_home.path().to_path_buf(), + AuthCredentialsStoreMode::File, + ); + let updated = super::persist_tokens( + &storage, + None, + Some("new-access-token".to_string()), + Some("new-refresh-token".to_string()), + ) + .expect("update_tokens should succeed"); + + let tokens = updated.tokens.expect("tokens should exist"); + assert_eq!(tokens.id_token.raw_jwt, fake_jwt); + assert_eq!(tokens.access_token, "new-access-token"); + assert_eq!(tokens.refresh_token, "new-refresh-token"); +} + +#[test] +fn login_with_api_key_overwrites_existing_auth_json() { + let dir = tempdir().unwrap(); + let auth_path = dir.path().join("auth.json"); + let stale_auth = json!({ + "OPENAI_API_KEY": "sk-old", + "tokens": { + "id_token": "stale.header.payload", + "access_token": "stale-access", + "refresh_token": "stale-refresh", + "account_id": "stale-acc" + } + }); + std::fs::write( + &auth_path, + serde_json::to_string_pretty(&stale_auth).unwrap(), + ) + .unwrap(); + + super::login_with_api_key(dir.path(), "sk-new", AuthCredentialsStoreMode::File) + .expect("login_with_api_key should succeed"); + + let storage = FileAuthStorage::new(dir.path().to_path_buf()); + let auth = storage + .try_read_auth_json(&auth_path) + .expect("auth.json should parse"); + assert_eq!(auth.openai_api_key.as_deref(), Some("sk-new")); + assert!(auth.tokens.is_none(), "tokens should be cleared"); +} + +#[test] +fn missing_auth_json_returns_none() { + let dir = tempdir().unwrap(); + let auth = CodexAuth::from_auth_storage(dir.path(), AuthCredentialsStoreMode::File) + .expect("call should succeed"); + assert_eq!(auth, None); +} + +#[tokio::test] +#[serial(codex_api_key)] +async fn pro_account_with_no_api_key_uses_chatgpt_auth() { + let codex_home = tempdir().unwrap(); + let fake_jwt = write_auth_file( + AuthFileParams { + openai_api_key: None, + chatgpt_plan_type: Some("pro".to_string()), + chatgpt_account_id: None, + }, + codex_home.path(), + ) + .expect("failed to write auth file"); + + let auth = super::load_auth(codex_home.path(), false, AuthCredentialsStoreMode::File) + .unwrap() + .unwrap(); + assert_eq!(None, auth.api_key()); + assert_eq!(AuthMode::Chatgpt, auth.auth_mode()); + assert_eq!(auth.get_chatgpt_user_id().as_deref(), Some("user-12345")); + + let auth_dot_json = auth + .get_current_auth_json() + .expect("AuthDotJson should exist"); + let last_refresh = auth_dot_json + .last_refresh + .expect("last_refresh should be recorded"); + + assert_eq!( + AuthDotJson { + auth_mode: None, + openai_api_key: None, + tokens: Some(TokenData { + id_token: IdTokenInfo { + email: Some("user@example.com".to_string()), + chatgpt_plan_type: Some(InternalPlanType::Known(InternalKnownPlan::Pro)), + chatgpt_user_id: Some("user-12345".to_string()), + chatgpt_account_id: None, + raw_jwt: fake_jwt, + }, + access_token: "test-access-token".to_string(), + refresh_token: "test-refresh-token".to_string(), + account_id: None, + }), + last_refresh: Some(last_refresh), + }, + auth_dot_json + ); +} + +#[tokio::test] +#[serial(codex_api_key)] +async fn loads_api_key_from_auth_json() { + let dir = tempdir().unwrap(); + let auth_file = dir.path().join("auth.json"); + std::fs::write( + auth_file, + r#"{"OPENAI_API_KEY":"sk-test-key","tokens":null,"last_refresh":null}"#, + ) + .unwrap(); + + let auth = super::load_auth(dir.path(), false, AuthCredentialsStoreMode::File) + .unwrap() + .unwrap(); + assert_eq!(auth.auth_mode(), AuthMode::ApiKey); + assert_eq!(auth.api_key(), Some("sk-test-key")); + + assert!(auth.get_token_data().is_err()); +} + +#[test] +fn logout_removes_auth_file() -> Result<(), std::io::Error> { + let dir = tempdir()?; + let auth_dot_json = AuthDotJson { + auth_mode: Some(ApiAuthMode::ApiKey), + openai_api_key: Some("sk-test-key".to_string()), + tokens: None, + last_refresh: None, + }; + super::save_auth(dir.path(), &auth_dot_json, AuthCredentialsStoreMode::File)?; + let auth_file = get_auth_file(dir.path()); + assert!(auth_file.exists()); + assert!(logout(dir.path(), AuthCredentialsStoreMode::File)?); + assert!(!auth_file.exists()); + Ok(()) +} + +struct AuthFileParams { + openai_api_key: Option, + chatgpt_plan_type: Option, + chatgpt_account_id: Option, +} + +fn write_auth_file(params: AuthFileParams, codex_home: &Path) -> std::io::Result { + let auth_file = get_auth_file(codex_home); + // Create a minimal valid JWT for the id_token field. + #[derive(Serialize)] + struct Header { + alg: &'static str, + typ: &'static str, + } + let header = Header { + alg: "none", + typ: "JWT", + }; + let mut auth_payload = serde_json::json!({ + "chatgpt_user_id": "user-12345", + "user_id": "user-12345", + }); + + if let Some(chatgpt_plan_type) = params.chatgpt_plan_type { + auth_payload["chatgpt_plan_type"] = serde_json::Value::String(chatgpt_plan_type); + } + + if let Some(chatgpt_account_id) = params.chatgpt_account_id { + let org_value = serde_json::Value::String(chatgpt_account_id); + auth_payload["chatgpt_account_id"] = org_value; + } + + let payload = serde_json::json!({ + "email": "user@example.com", + "email_verified": true, + "https://api.openai.com/auth": auth_payload, + }); + let b64 = |b: &[u8]| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(b); + let header_b64 = b64(&serde_json::to_vec(&header)?); + let payload_b64 = b64(&serde_json::to_vec(&payload)?); + let signature_b64 = b64(b"sig"); + let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}"); + + let auth_json_data = json!({ + "OPENAI_API_KEY": params.openai_api_key, + "tokens": { + "id_token": fake_jwt, + "access_token": "test-access-token", + "refresh_token": "test-refresh-token" + }, + "last_refresh": Utc::now(), + }); + let auth_json = serde_json::to_string_pretty(&auth_json_data)?; + std::fs::write(auth_file, auth_json)?; + Ok(fake_jwt) +} + +async fn build_config( + codex_home: &Path, + forced_login_method: Option, + forced_chatgpt_workspace_id: Option, +) -> Config { + let mut config = ConfigBuilder::default() + .codex_home(codex_home.to_path_buf()) + .build() + .await + .expect("config should load"); + config.forced_login_method = forced_login_method; + config.forced_chatgpt_workspace_id = forced_chatgpt_workspace_id; + config +} + +/// Use sparingly. +/// TODO (gpeal): replace this with an injectable env var provider. +#[cfg(test)] +struct EnvVarGuard { + key: &'static str, + original: Option, +} + +#[cfg(test)] +impl EnvVarGuard { + fn set(key: &'static str, value: &str) -> Self { + let original = env::var_os(key); + unsafe { + env::set_var(key, value); + } + Self { key, original } + } +} + +#[cfg(test)] +impl Drop for EnvVarGuard { + fn drop(&mut self) { + unsafe { + match &self.original { + Some(value) => env::set_var(self.key, value), + None => env::remove_var(self.key), + } + } + } +} + +#[tokio::test] +async fn enforce_login_restrictions_logs_out_for_method_mismatch() { + let codex_home = tempdir().unwrap(); + login_with_api_key(codex_home.path(), "sk-test", AuthCredentialsStoreMode::File) + .expect("seed api key"); + + let config = build_config(codex_home.path(), Some(ForcedLoginMethod::Chatgpt), None).await; + + let err = + super::enforce_login_restrictions(&config).expect_err("expected method mismatch to error"); + assert!(err.to_string().contains("ChatGPT login is required")); + assert!( + !codex_home.path().join("auth.json").exists(), + "auth.json should be removed on mismatch" + ); +} + +#[tokio::test] +#[serial(codex_api_key)] +async fn enforce_login_restrictions_logs_out_for_workspace_mismatch() { + let codex_home = tempdir().unwrap(); + let _jwt = write_auth_file( + AuthFileParams { + openai_api_key: None, + chatgpt_plan_type: Some("pro".to_string()), + chatgpt_account_id: Some("org_another_org".to_string()), + }, + codex_home.path(), + ) + .expect("failed to write auth file"); + + let config = build_config(codex_home.path(), None, Some("org_mine".to_string())).await; + + let err = super::enforce_login_restrictions(&config) + .expect_err("expected workspace mismatch to error"); + assert!(err.to_string().contains("workspace org_mine")); + assert!( + !codex_home.path().join("auth.json").exists(), + "auth.json should be removed on mismatch" + ); +} + +#[tokio::test] +#[serial(codex_api_key)] +async fn enforce_login_restrictions_allows_matching_workspace() { + let codex_home = tempdir().unwrap(); + let _jwt = write_auth_file( + AuthFileParams { + openai_api_key: None, + chatgpt_plan_type: Some("pro".to_string()), + chatgpt_account_id: Some("org_mine".to_string()), + }, + codex_home.path(), + ) + .expect("failed to write auth file"); + + let config = build_config(codex_home.path(), None, Some("org_mine".to_string())).await; + + super::enforce_login_restrictions(&config).expect("matching workspace should succeed"); + assert!( + codex_home.path().join("auth.json").exists(), + "auth.json should remain when restrictions pass" + ); +} + +#[tokio::test] +async fn enforce_login_restrictions_allows_api_key_if_login_method_not_set_but_forced_chatgpt_workspace_id_is_set() + { + let codex_home = tempdir().unwrap(); + login_with_api_key(codex_home.path(), "sk-test", AuthCredentialsStoreMode::File) + .expect("seed api key"); + + let config = build_config(codex_home.path(), None, Some("org_mine".to_string())).await; + + super::enforce_login_restrictions(&config).expect("matching workspace should succeed"); + assert!( + codex_home.path().join("auth.json").exists(), + "auth.json should remain when restrictions pass" + ); +} + +#[tokio::test] +#[serial(codex_api_key)] +async fn enforce_login_restrictions_blocks_env_api_key_when_chatgpt_required() { + let _guard = EnvVarGuard::set(CODEX_API_KEY_ENV_VAR, "sk-env"); + let codex_home = tempdir().unwrap(); + + let config = build_config(codex_home.path(), Some(ForcedLoginMethod::Chatgpt), None).await; + + let err = super::enforce_login_restrictions(&config) + .expect_err("environment API key should not satisfy forced ChatGPT login"); + assert!( + err.to_string() + .contains("ChatGPT login is required, but an API key is currently being used.") + ); +} + +#[test] +fn plan_type_maps_known_plan() { + let codex_home = tempdir().unwrap(); + let _jwt = write_auth_file( + AuthFileParams { + openai_api_key: None, + chatgpt_plan_type: Some("pro".to_string()), + chatgpt_account_id: None, + }, + codex_home.path(), + ) + .expect("failed to write auth file"); + + let auth = super::load_auth(codex_home.path(), false, AuthCredentialsStoreMode::File) + .expect("load auth") + .expect("auth available"); + + pretty_assertions::assert_eq!(auth.account_plan_type(), Some(AccountPlanType::Pro)); +} + +#[test] +fn plan_type_maps_unknown_to_unknown() { + let codex_home = tempdir().unwrap(); + let _jwt = write_auth_file( + AuthFileParams { + openai_api_key: None, + chatgpt_plan_type: Some("mystery-tier".to_string()), + chatgpt_account_id: None, + }, + codex_home.path(), + ) + .expect("failed to write auth file"); + + let auth = super::load_auth(codex_home.path(), false, AuthCredentialsStoreMode::File) + .expect("load auth") + .expect("auth available"); + + pretty_assertions::assert_eq!(auth.account_plan_type(), Some(AccountPlanType::Unknown)); +} + +#[test] +fn missing_plan_type_maps_to_unknown() { + let codex_home = tempdir().unwrap(); + let _jwt = write_auth_file( + AuthFileParams { + openai_api_key: None, + chatgpt_plan_type: None, + chatgpt_account_id: None, + }, + codex_home.path(), + ) + .expect("failed to write auth file"); + + let auth = super::load_auth(codex_home.path(), false, AuthCredentialsStoreMode::File) + .expect("load auth") + .expect("auth available"); + + pretty_assertions::assert_eq!(auth.account_plan_type(), Some(AccountPlanType::Unknown)); +} diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index dcecad6b79..47d01d4a44 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -1261,101 +1261,5 @@ impl WebsocketTelemetry for ApiTelemetry { } #[cfg(test)] -mod tests { - use super::ModelClient; - use codex_otel::SessionTelemetry; - use codex_protocol::ThreadId; - use codex_protocol::openai_models::ModelInfo; - use codex_protocol::protocol::SessionSource; - use codex_protocol::protocol::SubAgentSource; - use pretty_assertions::assert_eq; - use serde_json::json; - - fn test_model_client(session_source: SessionSource) -> ModelClient { - let provider = crate::model_provider_info::create_oss_provider_with_base_url( - "https://example.com/v1", - crate::model_provider_info::WireApi::Responses, - ); - ModelClient::new( - None, - ThreadId::new(), - provider, - session_source, - None, - false, - false, - false, - None, - ) - } - - fn test_model_info() -> ModelInfo { - serde_json::from_value(json!({ - "slug": "gpt-test", - "display_name": "gpt-test", - "description": "desc", - "default_reasoning_level": "medium", - "supported_reasoning_levels": [ - {"effort": "medium", "description": "medium"} - ], - "shell_type": "shell_command", - "visibility": "list", - "supported_in_api": true, - "priority": 1, - "upgrade": null, - "base_instructions": "base instructions", - "model_messages": null, - "supports_reasoning_summaries": false, - "support_verbosity": false, - "default_verbosity": null, - "apply_patch_tool_type": null, - "truncation_policy": {"mode": "bytes", "limit": 10000}, - "supports_parallel_tool_calls": false, - "supports_image_detail_original": false, - "context_window": 272000, - "auto_compact_token_limit": null, - "experimental_supported_tools": [] - })) - .expect("deserialize test model info") - } - - fn test_session_telemetry() -> SessionTelemetry { - SessionTelemetry::new( - ThreadId::new(), - "gpt-test", - "gpt-test", - None, - None, - None, - "test-originator".to_string(), - false, - "test-terminal".to_string(), - SessionSource::Cli, - ) - } - - #[test] - fn build_subagent_headers_sets_other_subagent_label() { - let client = test_model_client(SessionSource::SubAgent(SubAgentSource::Other( - "memory_consolidation".to_string(), - ))); - let headers = client.build_subagent_headers(); - let value = headers - .get("x-openai-subagent") - .and_then(|value| value.to_str().ok()); - assert_eq!(value, Some("memory_consolidation")); - } - - #[tokio::test] - async fn summarize_memories_returns_empty_for_empty_input() { - let client = test_model_client(SessionSource::Cli); - let model_info = test_model_info(); - let session_telemetry = test_session_telemetry(); - - let output = client - .summarize_memories(Vec::new(), &model_info, None, &session_telemetry) - .await - .expect("empty summarize request should succeed"); - assert_eq!(output.len(), 0); - } -} +#[path = "client_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs index 4bf26b476f..e88e1af124 100644 --- a/codex-rs/core/src/client_common.rs +++ b/codex-rs/core/src/client_common.rs @@ -320,248 +320,5 @@ impl Stream for ResponseStream { } #[cfg(test)] -mod tests { - use codex_api::ResponsesApiRequest; - use codex_api::common::OpenAiVerbosity; - use codex_api::common::TextControls; - use codex_api::create_text_param_for_request; - use codex_protocol::config_types::ServiceTier; - use codex_protocol::models::FunctionCallOutputPayload; - use pretty_assertions::assert_eq; - - use super::*; - - #[test] - fn serializes_text_verbosity_when_set() { - let input: Vec = vec![]; - let tools: Vec = vec![]; - let req = ResponsesApiRequest { - model: "gpt-5.1".to_string(), - instructions: "i".to_string(), - input, - tools, - tool_choice: "auto".to_string(), - parallel_tool_calls: true, - reasoning: None, - store: false, - stream: true, - include: vec![], - prompt_cache_key: None, - service_tier: None, - text: Some(TextControls { - verbosity: Some(OpenAiVerbosity::Low), - format: None, - }), - }; - - let v = serde_json::to_value(&req).expect("json"); - assert_eq!( - v.get("text") - .and_then(|t| t.get("verbosity")) - .and_then(|s| s.as_str()), - Some("low") - ); - } - - #[test] - fn serializes_text_schema_with_strict_format() { - let input: Vec = vec![]; - let tools: Vec = vec![]; - let schema = serde_json::json!({ - "type": "object", - "properties": { - "answer": {"type": "string"} - }, - "required": ["answer"], - }); - let text_controls = - create_text_param_for_request(None, &Some(schema.clone())).expect("text controls"); - - let req = ResponsesApiRequest { - model: "gpt-5.1".to_string(), - instructions: "i".to_string(), - input, - tools, - tool_choice: "auto".to_string(), - parallel_tool_calls: true, - reasoning: None, - store: false, - stream: true, - include: vec![], - prompt_cache_key: None, - service_tier: None, - text: Some(text_controls), - }; - - let v = serde_json::to_value(&req).expect("json"); - let text = v.get("text").expect("text field"); - assert!(text.get("verbosity").is_none()); - let format = text.get("format").expect("format field"); - - assert_eq!( - format.get("name"), - Some(&serde_json::Value::String("codex_output_schema".into())) - ); - assert_eq!( - format.get("type"), - Some(&serde_json::Value::String("json_schema".into())) - ); - assert_eq!(format.get("strict"), Some(&serde_json::Value::Bool(true))); - assert_eq!(format.get("schema"), Some(&schema)); - } - - #[test] - fn omits_text_when_not_set() { - let input: Vec = vec![]; - let tools: Vec = vec![]; - let req = ResponsesApiRequest { - model: "gpt-5.1".to_string(), - instructions: "i".to_string(), - input, - tools, - tool_choice: "auto".to_string(), - parallel_tool_calls: true, - reasoning: None, - store: false, - stream: true, - include: vec![], - prompt_cache_key: None, - service_tier: None, - text: None, - }; - - let v = serde_json::to_value(&req).expect("json"); - assert!(v.get("text").is_none()); - } - - #[test] - fn serializes_flex_service_tier_when_set() { - let req = ResponsesApiRequest { - model: "gpt-5.1".to_string(), - instructions: "i".to_string(), - input: vec![], - tools: vec![], - tool_choice: "auto".to_string(), - parallel_tool_calls: true, - reasoning: None, - store: false, - stream: true, - include: vec![], - prompt_cache_key: None, - service_tier: Some(ServiceTier::Flex.to_string()), - text: None, - }; - - let v = serde_json::to_value(&req).expect("json"); - assert_eq!( - v.get("service_tier").and_then(|tier| tier.as_str()), - Some("flex") - ); - } - - #[test] - fn reserializes_shell_outputs_for_function_and_custom_tool_calls() { - let raw_output = r#"{"output":"hello","metadata":{"exit_code":0,"duration_seconds":0.5}}"#; - let expected_output = "Exit code: 0\nWall time: 0.5 seconds\nOutput:\nhello"; - let mut items = vec![ - ResponseItem::FunctionCall { - id: None, - name: "shell".to_string(), - namespace: None, - arguments: "{}".to_string(), - call_id: "call-1".to_string(), - }, - ResponseItem::FunctionCallOutput { - call_id: "call-1".to_string(), - output: FunctionCallOutputPayload::from_text(raw_output.to_string()), - }, - ResponseItem::CustomToolCall { - id: None, - status: None, - call_id: "call-2".to_string(), - name: "apply_patch".to_string(), - input: "*** Begin Patch".to_string(), - }, - ResponseItem::CustomToolCallOutput { - call_id: "call-2".to_string(), - output: FunctionCallOutputPayload::from_text(raw_output.to_string()), - }, - ]; - - reserialize_shell_outputs(&mut items); - - assert_eq!( - items, - vec![ - ResponseItem::FunctionCall { - id: None, - name: "shell".to_string(), - namespace: None, - arguments: "{}".to_string(), - call_id: "call-1".to_string(), - }, - ResponseItem::FunctionCallOutput { - call_id: "call-1".to_string(), - output: FunctionCallOutputPayload::from_text(expected_output.to_string()), - }, - ResponseItem::CustomToolCall { - id: None, - status: None, - call_id: "call-2".to_string(), - name: "apply_patch".to_string(), - input: "*** Begin Patch".to_string(), - }, - ResponseItem::CustomToolCallOutput { - call_id: "call-2".to_string(), - output: FunctionCallOutputPayload::from_text(expected_output.to_string()), - }, - ] - ); - } - - #[test] - fn tool_search_output_namespace_serializes_with_deferred_child_tools() { - let namespace = tools::ToolSearchOutputTool::Namespace(tools::ResponsesApiNamespace { - name: "mcp__codex_apps__calendar".to_string(), - description: "Plan events".to_string(), - tools: vec![tools::ResponsesApiNamespaceTool::Function( - tools::ResponsesApiTool { - name: "create_event".to_string(), - description: "Create a calendar event.".to_string(), - strict: false, - defer_loading: Some(true), - parameters: crate::tools::spec::JsonSchema::Object { - properties: Default::default(), - required: None, - additional_properties: None, - }, - output_schema: None, - }, - )], - }); - - let value = serde_json::to_value(namespace).expect("serialize namespace"); - - assert_eq!( - value, - serde_json::json!({ - "type": "namespace", - "name": "mcp__codex_apps__calendar", - "description": "Plan events", - "tools": [ - { - "type": "function", - "name": "create_event", - "description": "Create a calendar event.", - "strict": false, - "defer_loading": true, - "parameters": { - "type": "object", - "properties": {} - } - } - ] - }) - ); - } -} +#[path = "client_common_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/client_common_tests.rs b/codex-rs/core/src/client_common_tests.rs new file mode 100644 index 0000000000..769defabbc --- /dev/null +++ b/codex-rs/core/src/client_common_tests.rs @@ -0,0 +1,243 @@ +use codex_api::ResponsesApiRequest; +use codex_api::common::OpenAiVerbosity; +use codex_api::common::TextControls; +use codex_api::create_text_param_for_request; +use codex_protocol::config_types::ServiceTier; +use codex_protocol::models::FunctionCallOutputPayload; +use pretty_assertions::assert_eq; + +use super::*; + +#[test] +fn serializes_text_verbosity_when_set() { + let input: Vec = vec![]; + let tools: Vec = vec![]; + let req = ResponsesApiRequest { + model: "gpt-5.1".to_string(), + instructions: "i".to_string(), + input, + tools, + tool_choice: "auto".to_string(), + parallel_tool_calls: true, + reasoning: None, + store: false, + stream: true, + include: vec![], + prompt_cache_key: None, + service_tier: None, + text: Some(TextControls { + verbosity: Some(OpenAiVerbosity::Low), + format: None, + }), + }; + + let v = serde_json::to_value(&req).expect("json"); + assert_eq!( + v.get("text") + .and_then(|t| t.get("verbosity")) + .and_then(|s| s.as_str()), + Some("low") + ); +} + +#[test] +fn serializes_text_schema_with_strict_format() { + let input: Vec = vec![]; + let tools: Vec = vec![]; + let schema = serde_json::json!({ + "type": "object", + "properties": { + "answer": {"type": "string"} + }, + "required": ["answer"], + }); + let text_controls = + create_text_param_for_request(None, &Some(schema.clone())).expect("text controls"); + + let req = ResponsesApiRequest { + model: "gpt-5.1".to_string(), + instructions: "i".to_string(), + input, + tools, + tool_choice: "auto".to_string(), + parallel_tool_calls: true, + reasoning: None, + store: false, + stream: true, + include: vec![], + prompt_cache_key: None, + service_tier: None, + text: Some(text_controls), + }; + + let v = serde_json::to_value(&req).expect("json"); + let text = v.get("text").expect("text field"); + assert!(text.get("verbosity").is_none()); + let format = text.get("format").expect("format field"); + + assert_eq!( + format.get("name"), + Some(&serde_json::Value::String("codex_output_schema".into())) + ); + assert_eq!( + format.get("type"), + Some(&serde_json::Value::String("json_schema".into())) + ); + assert_eq!(format.get("strict"), Some(&serde_json::Value::Bool(true))); + assert_eq!(format.get("schema"), Some(&schema)); +} + +#[test] +fn omits_text_when_not_set() { + let input: Vec = vec![]; + let tools: Vec = vec![]; + let req = ResponsesApiRequest { + model: "gpt-5.1".to_string(), + instructions: "i".to_string(), + input, + tools, + tool_choice: "auto".to_string(), + parallel_tool_calls: true, + reasoning: None, + store: false, + stream: true, + include: vec![], + prompt_cache_key: None, + service_tier: None, + text: None, + }; + + let v = serde_json::to_value(&req).expect("json"); + assert!(v.get("text").is_none()); +} + +#[test] +fn serializes_flex_service_tier_when_set() { + let req = ResponsesApiRequest { + model: "gpt-5.1".to_string(), + instructions: "i".to_string(), + input: vec![], + tools: vec![], + tool_choice: "auto".to_string(), + parallel_tool_calls: true, + reasoning: None, + store: false, + stream: true, + include: vec![], + prompt_cache_key: None, + service_tier: Some(ServiceTier::Flex.to_string()), + text: None, + }; + + let v = serde_json::to_value(&req).expect("json"); + assert_eq!( + v.get("service_tier").and_then(|tier| tier.as_str()), + Some("flex") + ); +} + +#[test] +fn reserializes_shell_outputs_for_function_and_custom_tool_calls() { + let raw_output = r#"{"output":"hello","metadata":{"exit_code":0,"duration_seconds":0.5}}"#; + let expected_output = "Exit code: 0\nWall time: 0.5 seconds\nOutput:\nhello"; + let mut items = vec![ + ResponseItem::FunctionCall { + id: None, + name: "shell".to_string(), + namespace: None, + arguments: "{}".to_string(), + call_id: "call-1".to_string(), + }, + ResponseItem::FunctionCallOutput { + call_id: "call-1".to_string(), + output: FunctionCallOutputPayload::from_text(raw_output.to_string()), + }, + ResponseItem::CustomToolCall { + id: None, + status: None, + call_id: "call-2".to_string(), + name: "apply_patch".to_string(), + input: "*** Begin Patch".to_string(), + }, + ResponseItem::CustomToolCallOutput { + call_id: "call-2".to_string(), + output: FunctionCallOutputPayload::from_text(raw_output.to_string()), + }, + ]; + + reserialize_shell_outputs(&mut items); + + assert_eq!( + items, + vec![ + ResponseItem::FunctionCall { + id: None, + name: "shell".to_string(), + namespace: None, + arguments: "{}".to_string(), + call_id: "call-1".to_string(), + }, + ResponseItem::FunctionCallOutput { + call_id: "call-1".to_string(), + output: FunctionCallOutputPayload::from_text(expected_output.to_string()), + }, + ResponseItem::CustomToolCall { + id: None, + status: None, + call_id: "call-2".to_string(), + name: "apply_patch".to_string(), + input: "*** Begin Patch".to_string(), + }, + ResponseItem::CustomToolCallOutput { + call_id: "call-2".to_string(), + output: FunctionCallOutputPayload::from_text(expected_output.to_string()), + }, + ] + ); +} + +#[test] +fn tool_search_output_namespace_serializes_with_deferred_child_tools() { + let namespace = tools::ToolSearchOutputTool::Namespace(tools::ResponsesApiNamespace { + name: "mcp__codex_apps__calendar".to_string(), + description: "Plan events".to_string(), + tools: vec![tools::ResponsesApiNamespaceTool::Function( + tools::ResponsesApiTool { + name: "create_event".to_string(), + description: "Create a calendar event.".to_string(), + strict: false, + defer_loading: Some(true), + parameters: crate::tools::spec::JsonSchema::Object { + properties: Default::default(), + required: None, + additional_properties: None, + }, + output_schema: None, + }, + )], + }); + + let value = serde_json::to_value(namespace).expect("serialize namespace"); + + assert_eq!( + value, + serde_json::json!({ + "type": "namespace", + "name": "mcp__codex_apps__calendar", + "description": "Plan events", + "tools": [ + { + "type": "function", + "name": "create_event", + "description": "Create a calendar event.", + "strict": false, + "defer_loading": true, + "parameters": { + "type": "object", + "properties": {} + } + } + ] + }) + ); +} diff --git a/codex-rs/core/src/client_tests.rs b/codex-rs/core/src/client_tests.rs new file mode 100644 index 0000000000..138b61ffbb --- /dev/null +++ b/codex-rs/core/src/client_tests.rs @@ -0,0 +1,96 @@ +use super::ModelClient; +use codex_otel::SessionTelemetry; +use codex_protocol::ThreadId; +use codex_protocol::openai_models::ModelInfo; +use codex_protocol::protocol::SessionSource; +use codex_protocol::protocol::SubAgentSource; +use pretty_assertions::assert_eq; +use serde_json::json; + +fn test_model_client(session_source: SessionSource) -> ModelClient { + let provider = crate::model_provider_info::create_oss_provider_with_base_url( + "https://example.com/v1", + crate::model_provider_info::WireApi::Responses, + ); + ModelClient::new( + None, + ThreadId::new(), + provider, + session_source, + None, + false, + false, + false, + None, + ) +} + +fn test_model_info() -> ModelInfo { + serde_json::from_value(json!({ + "slug": "gpt-test", + "display_name": "gpt-test", + "description": "desc", + "default_reasoning_level": "medium", + "supported_reasoning_levels": [ + {"effort": "medium", "description": "medium"} + ], + "shell_type": "shell_command", + "visibility": "list", + "supported_in_api": true, + "priority": 1, + "upgrade": null, + "base_instructions": "base instructions", + "model_messages": null, + "supports_reasoning_summaries": false, + "support_verbosity": false, + "default_verbosity": null, + "apply_patch_tool_type": null, + "truncation_policy": {"mode": "bytes", "limit": 10000}, + "supports_parallel_tool_calls": false, + "supports_image_detail_original": false, + "context_window": 272000, + "auto_compact_token_limit": null, + "experimental_supported_tools": [] + })) + .expect("deserialize test model info") +} + +fn test_session_telemetry() -> SessionTelemetry { + SessionTelemetry::new( + ThreadId::new(), + "gpt-test", + "gpt-test", + None, + None, + None, + "test-originator".to_string(), + false, + "test-terminal".to_string(), + SessionSource::Cli, + ) +} + +#[test] +fn build_subagent_headers_sets_other_subagent_label() { + let client = test_model_client(SessionSource::SubAgent(SubAgentSource::Other( + "memory_consolidation".to_string(), + ))); + let headers = client.build_subagent_headers(); + let value = headers + .get("x-openai-subagent") + .and_then(|value| value.to_str().ok()); + assert_eq!(value, Some("memory_consolidation")); +} + +#[tokio::test] +async fn summarize_memories_returns_empty_for_empty_input() { + let client = test_model_client(SessionSource::Cli); + let model_info = test_model_info(); + let session_telemetry = test_session_telemetry(); + + let output = client + .summarize_memories(Vec::new(), &model_info, None, &session_telemetry) + .await + .expect("empty summarize request should succeed"); + assert_eq!(output.len(), 0); +} diff --git a/codex-rs/core/src/codex_delegate.rs b/codex-rs/core/src/codex_delegate.rs index 3219a10339..91a6fb2d61 100644 --- a/codex-rs/core/src/codex_delegate.rs +++ b/codex-rs/core/src/codex_delegate.rs @@ -553,225 +553,5 @@ where } #[cfg(test)] -mod tests { - use super::*; - use async_channel::bounded; - use codex_protocol::models::NetworkPermissions; - use codex_protocol::models::PermissionProfile; - use codex_protocol::models::ResponseItem; - use codex_protocol::protocol::AgentStatus; - use codex_protocol::protocol::EventMsg; - use codex_protocol::protocol::RawResponseItemEvent; - use codex_protocol::protocol::TurnAbortReason; - use codex_protocol::protocol::TurnAbortedEvent; - use codex_protocol::request_permissions::RequestPermissionsEvent; - use codex_protocol::request_permissions::RequestPermissionsResponse; - use pretty_assertions::assert_eq; - use tokio::sync::watch; - - #[tokio::test] - async fn forward_events_cancelled_while_send_blocked_shuts_down_delegate() { - let (tx_events, rx_events) = bounded(1); - let (tx_sub, rx_sub) = bounded(SUBMISSION_CHANNEL_CAPACITY); - let (_agent_status_tx, agent_status) = watch::channel(AgentStatus::PendingInit); - let (session, ctx, _rx_evt) = crate::codex::make_session_and_context_with_rx().await; - let codex = Arc::new(Codex { - tx_sub, - rx_event: rx_events, - agent_status, - session: Arc::clone(&session), - session_loop_termination: completed_session_loop_termination(), - }); - - let (tx_out, rx_out) = bounded(1); - tx_out - .send(Event { - id: "full".to_string(), - msg: EventMsg::TurnAborted(TurnAbortedEvent { - turn_id: Some("turn-1".to_string()), - reason: TurnAbortReason::Interrupted, - }), - }) - .await - .unwrap(); - - let cancel = CancellationToken::new(); - let forward = tokio::spawn(forward_events( - Arc::clone(&codex), - tx_out.clone(), - session, - ctx, - cancel.clone(), - )); - - tx_events - .send(Event { - id: "evt".to_string(), - msg: EventMsg::RawResponseItem(RawResponseItemEvent { - item: ResponseItem::CustomToolCall { - id: None, - status: None, - call_id: "call-1".to_string(), - name: "tool".to_string(), - input: "{}".to_string(), - }, - }), - }) - .await - .unwrap(); - - drop(tx_events); - cancel.cancel(); - timeout(std::time::Duration::from_millis(1000), forward) - .await - .expect("forward_events hung") - .expect("forward_events join error"); - - let received = rx_out.recv().await.expect("prefilled event missing"); - assert_eq!("full", received.id); - let mut ops = Vec::new(); - while let Ok(sub) = rx_sub.try_recv() { - ops.push(sub.op); - } - assert!( - ops.iter().any(|op| matches!(op, Op::Interrupt)), - "expected Interrupt op after cancellation" - ); - assert!( - ops.iter().any(|op| matches!(op, Op::Shutdown)), - "expected Shutdown op after cancellation" - ); - } - - #[tokio::test] - async fn forward_ops_preserves_submission_trace_context() { - let (tx_sub, rx_sub) = bounded(SUBMISSION_CHANNEL_CAPACITY); - let (_tx_events, rx_events) = bounded(SUBMISSION_CHANNEL_CAPACITY); - let (_agent_status_tx, agent_status) = watch::channel(AgentStatus::PendingInit); - let (session, _ctx, _rx_evt) = crate::codex::make_session_and_context_with_rx().await; - let codex = Arc::new(Codex { - tx_sub, - rx_event: rx_events, - agent_status, - session, - session_loop_termination: completed_session_loop_termination(), - }); - let (tx_ops, rx_ops) = bounded(1); - let cancel = CancellationToken::new(); - let forward = tokio::spawn(forward_ops(Arc::clone(&codex), rx_ops, cancel)); - - let submission = Submission { - id: "sub-1".to_string(), - op: Op::Interrupt, - trace: Some(codex_protocol::protocol::W3cTraceContext { - traceparent: Some( - "00-1234567890abcdef1234567890abcdef-1234567890abcdef-01".to_string(), - ), - tracestate: Some("vendor=state".to_string()), - }), - }; - tx_ops.send(submission.clone()).await.unwrap(); - drop(tx_ops); - - let forwarded = timeout(Duration::from_secs(1), rx_sub.recv()) - .await - .expect("forward_ops hung") - .expect("forwarded submission missing"); - assert_eq!(submission.id, forwarded.id); - assert_eq!(submission.op, forwarded.op); - assert_eq!(submission.trace, forwarded.trace); - - timeout(Duration::from_secs(1), forward) - .await - .expect("forward_ops did not exit") - .expect("forward_ops join error"); - } - - #[tokio::test] - async fn handle_request_permissions_uses_tool_call_id_for_round_trip() { - let (parent_session, parent_ctx, rx_events) = - crate::codex::make_session_and_context_with_rx().await; - *parent_session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); - - let (tx_sub, rx_sub) = bounded(SUBMISSION_CHANNEL_CAPACITY); - let (_tx_events, rx_events_child) = bounded(SUBMISSION_CHANNEL_CAPACITY); - let (_agent_status_tx, agent_status) = watch::channel(AgentStatus::PendingInit); - let codex = Arc::new(Codex { - tx_sub, - rx_event: rx_events_child, - agent_status, - session: Arc::clone(&parent_session), - session_loop_termination: completed_session_loop_termination(), - }); - - let call_id = "tool-call-1".to_string(); - let expected_response = RequestPermissionsResponse { - permissions: PermissionProfile { - network: Some(NetworkPermissions { - enabled: Some(true), - }), - ..PermissionProfile::default() - }, - scope: PermissionGrantScope::Turn, - }; - let cancel_token = CancellationToken::new(); - let request_call_id = call_id.clone(); - - let handle = tokio::spawn({ - let codex = Arc::clone(&codex); - let parent_session = Arc::clone(&parent_session); - let parent_ctx = Arc::clone(&parent_ctx); - let cancel_token = cancel_token.clone(); - async move { - handle_request_permissions( - codex.as_ref(), - parent_session.as_ref(), - parent_ctx.as_ref(), - RequestPermissionsEvent { - call_id: request_call_id, - turn_id: "child-turn-1".to_string(), - reason: Some("need access".to_string()), - permissions: PermissionProfile { - network: Some(NetworkPermissions { - enabled: Some(true), - }), - ..PermissionProfile::default() - }, - }, - &cancel_token, - ) - .await; - } - }); - - let request_event = timeout(Duration::from_secs(1), rx_events.recv()) - .await - .expect("request_permissions event timed out") - .expect("request_permissions event missing"); - let EventMsg::RequestPermissions(request) = request_event.msg else { - panic!("expected RequestPermissions event"); - }; - assert_eq!(request.call_id, call_id.clone()); - - parent_session - .notify_request_permissions_response(&call_id, expected_response.clone()) - .await; - - timeout(Duration::from_secs(1), handle) - .await - .expect("handle_request_permissions hung") - .expect("handle_request_permissions join error"); - - let submission = timeout(Duration::from_secs(1), rx_sub.recv()) - .await - .expect("request_permissions response timed out") - .expect("request_permissions response missing"); - assert_eq!( - submission.op, - Op::RequestPermissionsResponse { - id: call_id, - response: expected_response, - } - ); - } -} +#[path = "codex_delegate_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/codex_delegate_tests.rs b/codex-rs/core/src/codex_delegate_tests.rs new file mode 100644 index 0000000000..3d4c49c9df --- /dev/null +++ b/codex-rs/core/src/codex_delegate_tests.rs @@ -0,0 +1,220 @@ +use super::*; +use async_channel::bounded; +use codex_protocol::models::NetworkPermissions; +use codex_protocol::models::PermissionProfile; +use codex_protocol::models::ResponseItem; +use codex_protocol::protocol::AgentStatus; +use codex_protocol::protocol::EventMsg; +use codex_protocol::protocol::RawResponseItemEvent; +use codex_protocol::protocol::TurnAbortReason; +use codex_protocol::protocol::TurnAbortedEvent; +use codex_protocol::request_permissions::RequestPermissionsEvent; +use codex_protocol::request_permissions::RequestPermissionsResponse; +use pretty_assertions::assert_eq; +use tokio::sync::watch; + +#[tokio::test] +async fn forward_events_cancelled_while_send_blocked_shuts_down_delegate() { + let (tx_events, rx_events) = bounded(1); + let (tx_sub, rx_sub) = bounded(SUBMISSION_CHANNEL_CAPACITY); + let (_agent_status_tx, agent_status) = watch::channel(AgentStatus::PendingInit); + let (session, ctx, _rx_evt) = crate::codex::make_session_and_context_with_rx().await; + let codex = Arc::new(Codex { + tx_sub, + rx_event: rx_events, + agent_status, + session: Arc::clone(&session), + session_loop_termination: completed_session_loop_termination(), + }); + + let (tx_out, rx_out) = bounded(1); + tx_out + .send(Event { + id: "full".to_string(), + msg: EventMsg::TurnAborted(TurnAbortedEvent { + turn_id: Some("turn-1".to_string()), + reason: TurnAbortReason::Interrupted, + }), + }) + .await + .unwrap(); + + let cancel = CancellationToken::new(); + let forward = tokio::spawn(forward_events( + Arc::clone(&codex), + tx_out.clone(), + session, + ctx, + cancel.clone(), + )); + + tx_events + .send(Event { + id: "evt".to_string(), + msg: EventMsg::RawResponseItem(RawResponseItemEvent { + item: ResponseItem::CustomToolCall { + id: None, + status: None, + call_id: "call-1".to_string(), + name: "tool".to_string(), + input: "{}".to_string(), + }, + }), + }) + .await + .unwrap(); + + drop(tx_events); + cancel.cancel(); + timeout(std::time::Duration::from_millis(1000), forward) + .await + .expect("forward_events hung") + .expect("forward_events join error"); + + let received = rx_out.recv().await.expect("prefilled event missing"); + assert_eq!("full", received.id); + let mut ops = Vec::new(); + while let Ok(sub) = rx_sub.try_recv() { + ops.push(sub.op); + } + assert!( + ops.iter().any(|op| matches!(op, Op::Interrupt)), + "expected Interrupt op after cancellation" + ); + assert!( + ops.iter().any(|op| matches!(op, Op::Shutdown)), + "expected Shutdown op after cancellation" + ); +} + +#[tokio::test] +async fn forward_ops_preserves_submission_trace_context() { + let (tx_sub, rx_sub) = bounded(SUBMISSION_CHANNEL_CAPACITY); + let (_tx_events, rx_events) = bounded(SUBMISSION_CHANNEL_CAPACITY); + let (_agent_status_tx, agent_status) = watch::channel(AgentStatus::PendingInit); + let (session, _ctx, _rx_evt) = crate::codex::make_session_and_context_with_rx().await; + let codex = Arc::new(Codex { + tx_sub, + rx_event: rx_events, + agent_status, + session, + session_loop_termination: completed_session_loop_termination(), + }); + let (tx_ops, rx_ops) = bounded(1); + let cancel = CancellationToken::new(); + let forward = tokio::spawn(forward_ops(Arc::clone(&codex), rx_ops, cancel)); + + let submission = Submission { + id: "sub-1".to_string(), + op: Op::Interrupt, + trace: Some(codex_protocol::protocol::W3cTraceContext { + traceparent: Some( + "00-1234567890abcdef1234567890abcdef-1234567890abcdef-01".to_string(), + ), + tracestate: Some("vendor=state".to_string()), + }), + }; + tx_ops.send(submission.clone()).await.unwrap(); + drop(tx_ops); + + let forwarded = timeout(Duration::from_secs(1), rx_sub.recv()) + .await + .expect("forward_ops hung") + .expect("forwarded submission missing"); + assert_eq!(submission.id, forwarded.id); + assert_eq!(submission.op, forwarded.op); + assert_eq!(submission.trace, forwarded.trace); + + timeout(Duration::from_secs(1), forward) + .await + .expect("forward_ops did not exit") + .expect("forward_ops join error"); +} + +#[tokio::test] +async fn handle_request_permissions_uses_tool_call_id_for_round_trip() { + let (parent_session, parent_ctx, rx_events) = + crate::codex::make_session_and_context_with_rx().await; + *parent_session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); + + let (tx_sub, rx_sub) = bounded(SUBMISSION_CHANNEL_CAPACITY); + let (_tx_events, rx_events_child) = bounded(SUBMISSION_CHANNEL_CAPACITY); + let (_agent_status_tx, agent_status) = watch::channel(AgentStatus::PendingInit); + let codex = Arc::new(Codex { + tx_sub, + rx_event: rx_events_child, + agent_status, + session: Arc::clone(&parent_session), + session_loop_termination: completed_session_loop_termination(), + }); + + let call_id = "tool-call-1".to_string(); + let expected_response = RequestPermissionsResponse { + permissions: PermissionProfile { + network: Some(NetworkPermissions { + enabled: Some(true), + }), + ..PermissionProfile::default() + }, + scope: PermissionGrantScope::Turn, + }; + let cancel_token = CancellationToken::new(); + let request_call_id = call_id.clone(); + + let handle = tokio::spawn({ + let codex = Arc::clone(&codex); + let parent_session = Arc::clone(&parent_session); + let parent_ctx = Arc::clone(&parent_ctx); + let cancel_token = cancel_token.clone(); + async move { + handle_request_permissions( + codex.as_ref(), + parent_session.as_ref(), + parent_ctx.as_ref(), + RequestPermissionsEvent { + call_id: request_call_id, + turn_id: "child-turn-1".to_string(), + reason: Some("need access".to_string()), + permissions: PermissionProfile { + network: Some(NetworkPermissions { + enabled: Some(true), + }), + ..PermissionProfile::default() + }, + }, + &cancel_token, + ) + .await; + } + }); + + let request_event = timeout(Duration::from_secs(1), rx_events.recv()) + .await + .expect("request_permissions event timed out") + .expect("request_permissions event missing"); + let EventMsg::RequestPermissions(request) = request_event.msg else { + panic!("expected RequestPermissions event"); + }; + assert_eq!(request.call_id, call_id.clone()); + + parent_session + .notify_request_permissions_response(&call_id, expected_response.clone()) + .await; + + timeout(Duration::from_secs(1), handle) + .await + .expect("handle_request_permissions hung") + .expect("handle_request_permissions join error"); + + let submission = timeout(Duration::from_secs(1), rx_sub.recv()) + .await + .expect("request_permissions response timed out") + .expect("request_permissions response missing"); + assert_eq!( + submission.op, + Op::RequestPermissionsResponse { + id: call_id, + response: expected_response, + } + ); +} diff --git a/codex-rs/core/src/command_canonicalization.rs b/codex-rs/core/src/command_canonicalization.rs index 0708e41e19..3457fa2f63 100644 --- a/codex-rs/core/src/command_canonicalization.rs +++ b/codex-rs/core/src/command_canonicalization.rs @@ -38,93 +38,5 @@ pub(crate) fn canonicalize_command_for_approval(command: &[String]) -> Vec) -> Option } #[cfg(test)] -mod tests { - use super::build_commit_message_trailer; - use super::commit_message_trailer_instruction; - use super::resolve_attribution_value; - - #[test] - fn blank_attribution_disables_trailer_prompt() { - assert_eq!(build_commit_message_trailer(Some("")), None); - assert_eq!(commit_message_trailer_instruction(Some(" ")), None); - } - - #[test] - fn default_attribution_uses_codex_trailer() { - assert_eq!( - build_commit_message_trailer(None).as_deref(), - Some("Co-authored-by: Codex ") - ); - } - - #[test] - fn resolve_value_handles_default_custom_and_blank() { - assert_eq!( - resolve_attribution_value(None), - Some("Codex ".to_string()) - ); - assert_eq!( - resolve_attribution_value(Some("MyAgent ")), - Some("MyAgent ".to_string()) - ); - assert_eq!( - resolve_attribution_value(Some("MyAgent")), - Some("MyAgent".to_string()) - ); - assert_eq!(resolve_attribution_value(Some(" ")), None); - } - - #[test] - fn instruction_mentions_trailer_and_omits_generated_with() { - let instruction = commit_message_trailer_instruction(Some("AgentX ")) - .expect("instruction expected"); - assert!(instruction.contains("Co-authored-by: AgentX ")); - assert!(instruction.contains("exactly once")); - assert!(!instruction.contains("Generated-with")); - } -} +#[path = "commit_attribution_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/commit_attribution_tests.rs b/codex-rs/core/src/commit_attribution_tests.rs new file mode 100644 index 0000000000..be7661a604 --- /dev/null +++ b/codex-rs/core/src/commit_attribution_tests.rs @@ -0,0 +1,43 @@ +use super::build_commit_message_trailer; +use super::commit_message_trailer_instruction; +use super::resolve_attribution_value; + +#[test] +fn blank_attribution_disables_trailer_prompt() { + assert_eq!(build_commit_message_trailer(Some("")), None); + assert_eq!(commit_message_trailer_instruction(Some(" ")), None); +} + +#[test] +fn default_attribution_uses_codex_trailer() { + assert_eq!( + build_commit_message_trailer(None).as_deref(), + Some("Co-authored-by: Codex ") + ); +} + +#[test] +fn resolve_value_handles_default_custom_and_blank() { + assert_eq!( + resolve_attribution_value(None), + Some("Codex ".to_string()) + ); + assert_eq!( + resolve_attribution_value(Some("MyAgent ")), + Some("MyAgent ".to_string()) + ); + assert_eq!( + resolve_attribution_value(Some("MyAgent")), + Some("MyAgent".to_string()) + ); + assert_eq!(resolve_attribution_value(Some(" ")), None); +} + +#[test] +fn instruction_mentions_trailer_and_omits_generated_with() { + let instruction = commit_message_trailer_instruction(Some("AgentX ")) + .expect("instruction expected"); + assert!(instruction.contains("Co-authored-by: AgentX ")); + assert!(instruction.contains("exactly once")); + assert!(!instruction.contains("Generated-with")); +} diff --git a/codex-rs/core/src/compact.rs b/codex-rs/core/src/compact.rs index 42e338443d..4900051938 100644 --- a/codex-rs/core/src/compact.rs +++ b/codex-rs/core/src/compact.rs @@ -438,571 +438,5 @@ async fn drain_to_completed( } #[cfg(test)] -mod tests { - - use super::*; - use pretty_assertions::assert_eq; - - async fn process_compacted_history_with_test_session( - compacted_history: Vec, - previous_turn_settings: Option<&PreviousTurnSettings>, - ) -> (Vec, Vec) { - let (session, turn_context) = crate::codex::make_session_and_context().await; - session - .set_previous_turn_settings(previous_turn_settings.cloned()) - .await; - let initial_context = session.build_initial_context(&turn_context).await; - let refreshed = crate::compact_remote::process_compacted_history( - &session, - &turn_context, - compacted_history, - InitialContextInjection::BeforeLastUserMessage, - ) - .await; - (refreshed, initial_context) - } - - #[test] - fn content_items_to_text_joins_non_empty_segments() { - let items = vec![ - ContentItem::InputText { - text: "hello".to_string(), - }, - ContentItem::OutputText { - text: String::new(), - }, - ContentItem::OutputText { - text: "world".to_string(), - }, - ]; - - let joined = content_items_to_text(&items); - - assert_eq!(Some("hello\nworld".to_string()), joined); - } - - #[test] - fn content_items_to_text_ignores_image_only_content() { - let items = vec![ContentItem::InputImage { - image_url: "file://image.png".to_string(), - }]; - - let joined = content_items_to_text(&items); - - assert_eq!(None, joined); - } - - #[test] - fn collect_user_messages_extracts_user_text_only() { - let items = vec![ - ResponseItem::Message { - id: Some("assistant".to_string()), - role: "assistant".to_string(), - content: vec![ContentItem::OutputText { - text: "ignored".to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: Some("user".to_string()), - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "first".to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Other, - ]; - - let collected = collect_user_messages(&items); - - assert_eq!(vec!["first".to_string()], collected); - } - - #[test] - fn collect_user_messages_filters_session_prefix_entries() { - let items = vec![ - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: r#"# AGENTS.md instructions for project - - -do things -"# - .to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "cwd=/tmp".to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "real user message".to_string(), - }], - end_turn: None, - phase: None, - }, - ]; - - let collected = collect_user_messages(&items); - - assert_eq!(vec!["real user message".to_string()], collected); - } - - #[test] - fn build_token_limited_compacted_history_truncates_overlong_user_messages() { - // Use a small truncation limit so the test remains fast while still validating - // that oversized user content is truncated. - let max_tokens = 16; - let big = "word ".repeat(200); - let history = super::build_compacted_history_with_limit( - Vec::new(), - std::slice::from_ref(&big), - "SUMMARY", - max_tokens, - ); - assert_eq!(history.len(), 2); - - let truncated_message = &history[0]; - let summary_message = &history[1]; - - let truncated_text = match truncated_message { - ResponseItem::Message { role, content, .. } if role == "user" => { - content_items_to_text(content).unwrap_or_default() - } - other => panic!("unexpected item in history: {other:?}"), - }; - - assert!( - truncated_text.contains("tokens truncated"), - "expected truncation marker in truncated user message" - ); - assert!( - !truncated_text.contains(&big), - "truncated user message should not include the full oversized user text" - ); - - let summary_text = match summary_message { - ResponseItem::Message { role, content, .. } if role == "user" => { - content_items_to_text(content).unwrap_or_default() - } - other => panic!("unexpected item in history: {other:?}"), - }; - assert_eq!(summary_text, "SUMMARY"); - } - - #[test] - fn build_token_limited_compacted_history_appends_summary_message() { - let initial_context: Vec = Vec::new(); - let user_messages = vec!["first user message".to_string()]; - let summary_text = "summary text"; - - let history = build_compacted_history(initial_context, &user_messages, summary_text); - assert!( - !history.is_empty(), - "expected compacted history to include summary" - ); - - let last = history.last().expect("history should have a summary entry"); - let summary = match last { - ResponseItem::Message { role, content, .. } if role == "user" => { - content_items_to_text(content).unwrap_or_default() - } - other => panic!("expected summary message, found {other:?}"), - }; - assert_eq!(summary, summary_text); - } - - #[tokio::test] - async fn process_compacted_history_replaces_developer_messages() { - let compacted_history = vec![ - ResponseItem::Message { - id: None, - role: "developer".to_string(), - content: vec![ContentItem::InputText { - text: "stale permissions".to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "summary".to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: None, - role: "developer".to_string(), - content: vec![ContentItem::InputText { - text: "stale personality".to_string(), - }], - end_turn: None, - phase: None, - }, - ]; - let (refreshed, mut expected) = - process_compacted_history_with_test_session(compacted_history, None).await; - expected.push(ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "summary".to_string(), - }], - end_turn: None, - phase: None, - }); - assert_eq!(refreshed, expected); - } - - #[tokio::test] - async fn process_compacted_history_reinjects_full_initial_context() { - let compacted_history = vec![ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "summary".to_string(), - }], - end_turn: None, - phase: None, - }]; - let (refreshed, mut expected) = - process_compacted_history_with_test_session(compacted_history, None).await; - expected.push(ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "summary".to_string(), - }], - end_turn: None, - phase: None, - }); - assert_eq!(refreshed, expected); - } - - #[tokio::test] - async fn process_compacted_history_drops_non_user_content_messages() { - let compacted_history = vec![ - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: r#"# AGENTS.md instructions for /repo - - -keep me updated -"# - .to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: r#" - /repo - zsh -"# - .to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: r#" - turn-1 - interrupted -"# - .to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "summary".to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: None, - role: "developer".to_string(), - content: vec![ContentItem::InputText { - text: "stale developer instructions".to_string(), - }], - end_turn: None, - phase: None, - }, - ]; - let (refreshed, mut expected) = - process_compacted_history_with_test_session(compacted_history, None).await; - expected.push(ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "summary".to_string(), - }], - end_turn: None, - phase: None, - }); - assert_eq!(refreshed, expected); - } - - #[tokio::test] - async fn process_compacted_history_inserts_context_before_last_real_user_message_only() { - let compacted_history = vec![ - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "older user".to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: format!("{SUMMARY_PREFIX}\nsummary text"), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "latest user".to_string(), - }], - end_turn: None, - phase: None, - }, - ]; - - let (refreshed, initial_context) = - process_compacted_history_with_test_session(compacted_history, None).await; - let mut expected = vec![ - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "older user".to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: format!("{SUMMARY_PREFIX}\nsummary text"), - }], - end_turn: None, - phase: None, - }, - ]; - expected.extend(initial_context); - expected.push(ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "latest user".to_string(), - }], - end_turn: None, - phase: None, - }); - assert_eq!(refreshed, expected); - } - - #[tokio::test] - async fn process_compacted_history_reinjects_model_switch_message() { - let compacted_history = vec![ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "summary".to_string(), - }], - end_turn: None, - phase: None, - }]; - let previous_turn_settings = PreviousTurnSettings { - model: "previous-regular-model".to_string(), - realtime_active: None, - }; - - let (refreshed, initial_context) = process_compacted_history_with_test_session( - compacted_history, - Some(&previous_turn_settings), - ) - .await; - - let ResponseItem::Message { role, content, .. } = &initial_context[0] else { - panic!("expected developer message"); - }; - assert_eq!(role, "developer"); - let [ContentItem::InputText { text }, ..] = content.as_slice() else { - panic!("expected developer text"); - }; - assert!(text.contains("")); - - let mut expected = initial_context; - expected.push(ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "summary".to_string(), - }], - end_turn: None, - phase: None, - }); - assert_eq!(refreshed, expected); - } - - #[test] - fn insert_initial_context_before_last_real_user_or_summary_keeps_summary_last() { - let compacted_history = vec![ - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "older user".to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "latest user".to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: format!("{SUMMARY_PREFIX}\nsummary text"), - }], - end_turn: None, - phase: None, - }, - ]; - let initial_context = vec![ResponseItem::Message { - id: None, - role: "developer".to_string(), - content: vec![ContentItem::InputText { - text: "fresh permissions".to_string(), - }], - end_turn: None, - phase: None, - }]; - - let refreshed = insert_initial_context_before_last_real_user_or_summary( - compacted_history, - initial_context, - ); - let expected = vec![ - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "older user".to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: None, - role: "developer".to_string(), - content: vec![ContentItem::InputText { - text: "fresh permissions".to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "latest user".to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: format!("{SUMMARY_PREFIX}\nsummary text"), - }], - end_turn: None, - phase: None, - }, - ]; - assert_eq!(refreshed, expected); - } - - #[test] - fn insert_initial_context_before_last_real_user_or_summary_keeps_compaction_last() { - let compacted_history = vec![ResponseItem::Compaction { - encrypted_content: "encrypted".to_string(), - }]; - let initial_context = vec![ResponseItem::Message { - id: None, - role: "developer".to_string(), - content: vec![ContentItem::InputText { - text: "fresh permissions".to_string(), - }], - end_turn: None, - phase: None, - }]; - - let refreshed = insert_initial_context_before_last_real_user_or_summary( - compacted_history, - initial_context, - ); - let expected = vec![ - ResponseItem::Message { - id: None, - role: "developer".to_string(), - content: vec![ContentItem::InputText { - text: "fresh permissions".to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Compaction { - encrypted_content: "encrypted".to_string(), - }, - ]; - assert_eq!(refreshed, expected); - } -} +#[path = "compact_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/compact_tests.rs b/codex-rs/core/src/compact_tests.rs new file mode 100644 index 0000000000..92e889d647 --- /dev/null +++ b/codex-rs/core/src/compact_tests.rs @@ -0,0 +1,561 @@ +use super::*; +use pretty_assertions::assert_eq; + +async fn process_compacted_history_with_test_session( + compacted_history: Vec, + previous_turn_settings: Option<&PreviousTurnSettings>, +) -> (Vec, Vec) { + let (session, turn_context) = crate::codex::make_session_and_context().await; + session + .set_previous_turn_settings(previous_turn_settings.cloned()) + .await; + let initial_context = session.build_initial_context(&turn_context).await; + let refreshed = crate::compact_remote::process_compacted_history( + &session, + &turn_context, + compacted_history, + InitialContextInjection::BeforeLastUserMessage, + ) + .await; + (refreshed, initial_context) +} + +#[test] +fn content_items_to_text_joins_non_empty_segments() { + let items = vec![ + ContentItem::InputText { + text: "hello".to_string(), + }, + ContentItem::OutputText { + text: String::new(), + }, + ContentItem::OutputText { + text: "world".to_string(), + }, + ]; + + let joined = content_items_to_text(&items); + + assert_eq!(Some("hello\nworld".to_string()), joined); +} + +#[test] +fn content_items_to_text_ignores_image_only_content() { + let items = vec![ContentItem::InputImage { + image_url: "file://image.png".to_string(), + }]; + + let joined = content_items_to_text(&items); + + assert_eq!(None, joined); +} + +#[test] +fn collect_user_messages_extracts_user_text_only() { + let items = vec![ + ResponseItem::Message { + id: Some("assistant".to_string()), + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: "ignored".to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: Some("user".to_string()), + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "first".to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Other, + ]; + + let collected = collect_user_messages(&items); + + assert_eq!(vec!["first".to_string()], collected); +} + +#[test] +fn collect_user_messages_filters_session_prefix_entries() { + let items = vec![ + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: r#"# AGENTS.md instructions for project + + +do things +"# + .to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "cwd=/tmp".to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "real user message".to_string(), + }], + end_turn: None, + phase: None, + }, + ]; + + let collected = collect_user_messages(&items); + + assert_eq!(vec!["real user message".to_string()], collected); +} + +#[test] +fn build_token_limited_compacted_history_truncates_overlong_user_messages() { + // Use a small truncation limit so the test remains fast while still validating + // that oversized user content is truncated. + let max_tokens = 16; + let big = "word ".repeat(200); + let history = super::build_compacted_history_with_limit( + Vec::new(), + std::slice::from_ref(&big), + "SUMMARY", + max_tokens, + ); + assert_eq!(history.len(), 2); + + let truncated_message = &history[0]; + let summary_message = &history[1]; + + let truncated_text = match truncated_message { + ResponseItem::Message { role, content, .. } if role == "user" => { + content_items_to_text(content).unwrap_or_default() + } + other => panic!("unexpected item in history: {other:?}"), + }; + + assert!( + truncated_text.contains("tokens truncated"), + "expected truncation marker in truncated user message" + ); + assert!( + !truncated_text.contains(&big), + "truncated user message should not include the full oversized user text" + ); + + let summary_text = match summary_message { + ResponseItem::Message { role, content, .. } if role == "user" => { + content_items_to_text(content).unwrap_or_default() + } + other => panic!("unexpected item in history: {other:?}"), + }; + assert_eq!(summary_text, "SUMMARY"); +} + +#[test] +fn build_token_limited_compacted_history_appends_summary_message() { + let initial_context: Vec = Vec::new(); + let user_messages = vec!["first user message".to_string()]; + let summary_text = "summary text"; + + let history = build_compacted_history(initial_context, &user_messages, summary_text); + assert!( + !history.is_empty(), + "expected compacted history to include summary" + ); + + let last = history.last().expect("history should have a summary entry"); + let summary = match last { + ResponseItem::Message { role, content, .. } if role == "user" => { + content_items_to_text(content).unwrap_or_default() + } + other => panic!("expected summary message, found {other:?}"), + }; + assert_eq!(summary, summary_text); +} + +#[tokio::test] +async fn process_compacted_history_replaces_developer_messages() { + let compacted_history = vec![ + ResponseItem::Message { + id: None, + role: "developer".to_string(), + content: vec![ContentItem::InputText { + text: "stale permissions".to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "summary".to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: None, + role: "developer".to_string(), + content: vec![ContentItem::InputText { + text: "stale personality".to_string(), + }], + end_turn: None, + phase: None, + }, + ]; + let (refreshed, mut expected) = + process_compacted_history_with_test_session(compacted_history, None).await; + expected.push(ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "summary".to_string(), + }], + end_turn: None, + phase: None, + }); + assert_eq!(refreshed, expected); +} + +#[tokio::test] +async fn process_compacted_history_reinjects_full_initial_context() { + let compacted_history = vec![ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "summary".to_string(), + }], + end_turn: None, + phase: None, + }]; + let (refreshed, mut expected) = + process_compacted_history_with_test_session(compacted_history, None).await; + expected.push(ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "summary".to_string(), + }], + end_turn: None, + phase: None, + }); + assert_eq!(refreshed, expected); +} + +#[tokio::test] +async fn process_compacted_history_drops_non_user_content_messages() { + let compacted_history = vec![ + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: r#"# AGENTS.md instructions for /repo + + +keep me updated +"# + .to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: r#" + /repo + zsh +"# + .to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: r#" + turn-1 + interrupted +"# + .to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "summary".to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: None, + role: "developer".to_string(), + content: vec![ContentItem::InputText { + text: "stale developer instructions".to_string(), + }], + end_turn: None, + phase: None, + }, + ]; + let (refreshed, mut expected) = + process_compacted_history_with_test_session(compacted_history, None).await; + expected.push(ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "summary".to_string(), + }], + end_turn: None, + phase: None, + }); + assert_eq!(refreshed, expected); +} + +#[tokio::test] +async fn process_compacted_history_inserts_context_before_last_real_user_message_only() { + let compacted_history = vec![ + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "older user".to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: format!("{SUMMARY_PREFIX}\nsummary text"), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "latest user".to_string(), + }], + end_turn: None, + phase: None, + }, + ]; + + let (refreshed, initial_context) = + process_compacted_history_with_test_session(compacted_history, None).await; + let mut expected = vec![ + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "older user".to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: format!("{SUMMARY_PREFIX}\nsummary text"), + }], + end_turn: None, + phase: None, + }, + ]; + expected.extend(initial_context); + expected.push(ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "latest user".to_string(), + }], + end_turn: None, + phase: None, + }); + assert_eq!(refreshed, expected); +} + +#[tokio::test] +async fn process_compacted_history_reinjects_model_switch_message() { + let compacted_history = vec![ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "summary".to_string(), + }], + end_turn: None, + phase: None, + }]; + let previous_turn_settings = PreviousTurnSettings { + model: "previous-regular-model".to_string(), + realtime_active: None, + }; + + let (refreshed, initial_context) = process_compacted_history_with_test_session( + compacted_history, + Some(&previous_turn_settings), + ) + .await; + + let ResponseItem::Message { role, content, .. } = &initial_context[0] else { + panic!("expected developer message"); + }; + assert_eq!(role, "developer"); + let [ContentItem::InputText { text }, ..] = content.as_slice() else { + panic!("expected developer text"); + }; + assert!(text.contains("")); + + let mut expected = initial_context; + expected.push(ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "summary".to_string(), + }], + end_turn: None, + phase: None, + }); + assert_eq!(refreshed, expected); +} + +#[test] +fn insert_initial_context_before_last_real_user_or_summary_keeps_summary_last() { + let compacted_history = vec![ + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "older user".to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "latest user".to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: format!("{SUMMARY_PREFIX}\nsummary text"), + }], + end_turn: None, + phase: None, + }, + ]; + let initial_context = vec![ResponseItem::Message { + id: None, + role: "developer".to_string(), + content: vec![ContentItem::InputText { + text: "fresh permissions".to_string(), + }], + end_turn: None, + phase: None, + }]; + + let refreshed = + insert_initial_context_before_last_real_user_or_summary(compacted_history, initial_context); + let expected = vec![ + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "older user".to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: None, + role: "developer".to_string(), + content: vec![ContentItem::InputText { + text: "fresh permissions".to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "latest user".to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: format!("{SUMMARY_PREFIX}\nsummary text"), + }], + end_turn: None, + phase: None, + }, + ]; + assert_eq!(refreshed, expected); +} + +#[test] +fn insert_initial_context_before_last_real_user_or_summary_keeps_compaction_last() { + let compacted_history = vec![ResponseItem::Compaction { + encrypted_content: "encrypted".to_string(), + }]; + let initial_context = vec![ResponseItem::Message { + id: None, + role: "developer".to_string(), + content: vec![ContentItem::InputText { + text: "fresh permissions".to_string(), + }], + end_turn: None, + phase: None, + }]; + + let refreshed = + insert_initial_context_before_last_real_user_or_summary(compacted_history, initial_context); + let expected = vec![ + ResponseItem::Message { + id: None, + role: "developer".to_string(), + content: vec![ContentItem::InputText { + text: "fresh permissions".to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Compaction { + encrypted_content: "encrypted".to_string(), + }, + ]; + assert_eq!(refreshed, expected); +} diff --git a/codex-rs/core/src/config/edit.rs b/codex-rs/core/src/config/edit.rs index cf139d3a5e..8d0011b5e8 100644 --- a/codex-rs/core/src/config/edit.rs +++ b/codex-rs/core/src/config/edit.rs @@ -952,1016 +952,5 @@ impl ConfigEditsBuilder { } #[cfg(test)] -mod tests { - use super::*; - use crate::config::types::McpServerTransportConfig; - use codex_protocol::openai_models::ReasoningEffort; - use pretty_assertions::assert_eq; - #[cfg(unix)] - use std::os::unix::fs::symlink; - use tempfile::tempdir; - use toml::Value as TomlValue; - - #[test] - fn blocking_set_model_top_level() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - - apply_blocking( - codex_home, - None, - &[ConfigEdit::SetModel { - model: Some("gpt-5.1-codex".to_string()), - effort: Some(ReasoningEffort::High), - }], - ) - .expect("persist"); - - let contents = - std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let expected = r#"model = "gpt-5.1-codex" -model_reasoning_effort = "high" -"#; - assert_eq!(contents, expected); - } - - #[test] - fn builder_with_edits_applies_custom_paths() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - - ConfigEditsBuilder::new(codex_home) - .with_edits(vec![ConfigEdit::SetPath { - segments: vec!["enabled".to_string()], - value: value(true), - }]) - .apply_blocking() - .expect("persist"); - - let contents = - std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - assert_eq!(contents, "enabled = true\n"); - } - - #[test] - fn set_model_availability_nux_count_writes_shown_count() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - let shown_count = HashMap::from([("gpt-foo".to_string(), 4)]); - - ConfigEditsBuilder::new(codex_home) - .set_model_availability_nux_count(&shown_count) - .apply_blocking() - .expect("persist"); - - let contents = - std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let expected = r#"[tui.model_availability_nux] -gpt-foo = 4 -"#; - assert_eq!(contents, expected); - } - - #[test] - fn set_skill_config_writes_disabled_entry() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - - ConfigEditsBuilder::new(codex_home) - .with_edits([ConfigEdit::SetSkillConfig { - path: PathBuf::from("/tmp/skills/demo/SKILL.md"), - enabled: false, - }]) - .apply_blocking() - .expect("persist"); - - let contents = - std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let expected = r#"[[skills.config]] -path = "/tmp/skills/demo/SKILL.md" -enabled = false -"#; - assert_eq!(contents, expected); - } - - #[test] - fn set_skill_config_removes_entry_when_enabled() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - std::fs::write( - codex_home.join(CONFIG_TOML_FILE), - r#"[[skills.config]] -path = "/tmp/skills/demo/SKILL.md" -enabled = false -"#, - ) - .expect("seed config"); - - ConfigEditsBuilder::new(codex_home) - .with_edits([ConfigEdit::SetSkillConfig { - path: PathBuf::from("/tmp/skills/demo/SKILL.md"), - enabled: true, - }]) - .apply_blocking() - .expect("persist"); - - let contents = - std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - assert_eq!(contents, ""); - } - - #[test] - fn blocking_set_model_preserves_inline_table_contents() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - - // Seed with inline tables for profiles to simulate common user config. - std::fs::write( - codex_home.join(CONFIG_TOML_FILE), - r#"profile = "fast" - -profiles = { fast = { model = "gpt-4o", sandbox_mode = "strict" } } -"#, - ) - .expect("seed"); - - apply_blocking( - codex_home, - None, - &[ConfigEdit::SetModel { - model: Some("o4-mini".to_string()), - effort: None, - }], - ) - .expect("persist"); - - let raw = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let value: TomlValue = toml::from_str(&raw).expect("parse config"); - - // Ensure sandbox_mode is preserved under profiles.fast and model updated. - let profiles_tbl = value - .get("profiles") - .and_then(|v| v.as_table()) - .expect("profiles table"); - let fast_tbl = profiles_tbl - .get("fast") - .and_then(|v| v.as_table()) - .expect("fast table"); - assert_eq!( - fast_tbl.get("sandbox_mode").and_then(|v| v.as_str()), - Some("strict") - ); - assert_eq!( - fast_tbl.get("model").and_then(|v| v.as_str()), - Some("o4-mini") - ); - } - - #[cfg(unix)] - #[test] - fn blocking_set_model_writes_through_symlink_chain() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - let target_dir = tempdir().expect("target dir"); - let target_path = target_dir.path().join(CONFIG_TOML_FILE); - let link_path = codex_home.join("config-link.toml"); - let config_path = codex_home.join(CONFIG_TOML_FILE); - - symlink(&target_path, &link_path).expect("symlink link"); - symlink("config-link.toml", &config_path).expect("symlink config"); - - apply_blocking( - codex_home, - None, - &[ConfigEdit::SetModel { - model: Some("gpt-5.1-codex".to_string()), - effort: Some(ReasoningEffort::High), - }], - ) - .expect("persist"); - - let meta = std::fs::symlink_metadata(&config_path).expect("config metadata"); - assert!(meta.file_type().is_symlink()); - - let contents = std::fs::read_to_string(&target_path).expect("read target"); - let expected = r#"model = "gpt-5.1-codex" -model_reasoning_effort = "high" -"#; - assert_eq!(contents, expected); - } - - #[cfg(unix)] - #[test] - fn blocking_set_model_replaces_symlink_on_cycle() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - let link_a = codex_home.join("a.toml"); - let link_b = codex_home.join("b.toml"); - let config_path = codex_home.join(CONFIG_TOML_FILE); - - symlink("b.toml", &link_a).expect("symlink a"); - symlink("a.toml", &link_b).expect("symlink b"); - symlink("a.toml", &config_path).expect("symlink config"); - - apply_blocking( - codex_home, - None, - &[ConfigEdit::SetModel { - model: Some("gpt-5.1-codex".to_string()), - effort: None, - }], - ) - .expect("persist"); - - let meta = std::fs::symlink_metadata(&config_path).expect("config metadata"); - assert!(!meta.file_type().is_symlink()); - - let contents = std::fs::read_to_string(&config_path).expect("read config"); - let expected = r#"model = "gpt-5.1-codex" -"#; - assert_eq!(contents, expected); - } - - #[test] - fn batch_write_table_upsert_preserves_inline_comments() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - let original = r#"approval_policy = "never" - -[mcp_servers.linear] -name = "linear" -# ok -url = "https://linear.example" - -[mcp_servers.linear.http_headers] -foo = "bar" - -[sandbox_workspace_write] -# ok 3 -network_access = false -"#; - std::fs::write(codex_home.join(CONFIG_TOML_FILE), original).expect("seed config"); - - apply_blocking( - codex_home, - None, - &[ - ConfigEdit::SetPath { - segments: vec![ - "mcp_servers".to_string(), - "linear".to_string(), - "url".to_string(), - ], - value: value("https://linear.example/v2"), - }, - ConfigEdit::SetPath { - segments: vec![ - "sandbox_workspace_write".to_string(), - "network_access".to_string(), - ], - value: value(true), - }, - ], - ) - .expect("apply"); - - let updated = - std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let expected = r#"approval_policy = "never" - -[mcp_servers.linear] -name = "linear" -# ok -url = "https://linear.example/v2" - -[mcp_servers.linear.http_headers] -foo = "bar" - -[sandbox_workspace_write] -# ok 3 -network_access = true -"#; - assert_eq!(updated, expected); - } - - #[test] - fn blocking_clear_model_removes_inline_table_entry() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - - std::fs::write( - codex_home.join(CONFIG_TOML_FILE), - r#"profile = "fast" - -profiles = { fast = { model = "gpt-4o", sandbox_mode = "strict" } } -"#, - ) - .expect("seed"); - - apply_blocking( - codex_home, - None, - &[ConfigEdit::SetModel { - model: None, - effort: Some(ReasoningEffort::High), - }], - ) - .expect("persist"); - - let contents = - std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let expected = r#"profile = "fast" - -[profiles.fast] -sandbox_mode = "strict" -model_reasoning_effort = "high" -"#; - assert_eq!(contents, expected); - } - - #[test] - fn blocking_set_model_scopes_to_active_profile() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - std::fs::write( - codex_home.join(CONFIG_TOML_FILE), - r#"profile = "team" - -[profiles.team] -model_reasoning_effort = "low" -"#, - ) - .expect("seed"); - - apply_blocking( - codex_home, - None, - &[ConfigEdit::SetModel { - model: Some("o5-preview".to_string()), - effort: Some(ReasoningEffort::Minimal), - }], - ) - .expect("persist"); - - let contents = - std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let expected = r#"profile = "team" - -[profiles.team] -model_reasoning_effort = "minimal" -model = "o5-preview" -"#; - assert_eq!(contents, expected); - } - - #[test] - fn blocking_set_model_with_explicit_profile() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - std::fs::write( - codex_home.join(CONFIG_TOML_FILE), - r#"[profiles."team a"] -model = "gpt-5.1-codex" -"#, - ) - .expect("seed"); - - apply_blocking( - codex_home, - Some("team a"), - &[ConfigEdit::SetModel { - model: Some("o4-mini".to_string()), - effort: None, - }], - ) - .expect("persist"); - - let contents = - std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let expected = r#"[profiles."team a"] -model = "o4-mini" -"#; - assert_eq!(contents, expected); - } - - #[test] - fn blocking_set_hide_full_access_warning_preserves_table() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - std::fs::write( - codex_home.join(CONFIG_TOML_FILE), - r#"# Global comment - -[notice] -# keep me -existing = "value" -"#, - ) - .expect("seed"); - - apply_blocking( - codex_home, - None, - &[ConfigEdit::SetNoticeHideFullAccessWarning(true)], - ) - .expect("persist"); - - let contents = - std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let expected = r#"# Global comment - -[notice] -# keep me -existing = "value" -hide_full_access_warning = true -"#; - assert_eq!(contents, expected); - } - - #[test] - fn blocking_set_hide_rate_limit_model_nudge_preserves_table() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - std::fs::write( - codex_home.join(CONFIG_TOML_FILE), - r#"[notice] -existing = "value" -"#, - ) - .expect("seed"); - - apply_blocking( - codex_home, - None, - &[ConfigEdit::SetNoticeHideRateLimitModelNudge(true)], - ) - .expect("persist"); - - let contents = - std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let expected = r#"[notice] -existing = "value" -hide_rate_limit_model_nudge = true -"#; - assert_eq!(contents, expected); - } - - #[test] - fn blocking_set_hide_gpt5_1_migration_prompt_preserves_table() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - std::fs::write( - codex_home.join(CONFIG_TOML_FILE), - r#"[notice] -existing = "value" -"#, - ) - .expect("seed"); - apply_blocking( - codex_home, - None, - &[ConfigEdit::SetNoticeHideModelMigrationPrompt( - "hide_gpt5_1_migration_prompt".to_string(), - true, - )], - ) - .expect("persist"); - - let contents = - std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let expected = r#"[notice] -existing = "value" -hide_gpt5_1_migration_prompt = true -"#; - assert_eq!(contents, expected); - } - - #[test] - fn blocking_set_hide_gpt_5_1_codex_max_migration_prompt_preserves_table() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - std::fs::write( - codex_home.join(CONFIG_TOML_FILE), - r#"[notice] -existing = "value" -"#, - ) - .expect("seed"); - apply_blocking( - codex_home, - None, - &[ConfigEdit::SetNoticeHideModelMigrationPrompt( - "hide_gpt-5.1-codex-max_migration_prompt".to_string(), - true, - )], - ) - .expect("persist"); - - let contents = - std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let expected = r#"[notice] -existing = "value" -"hide_gpt-5.1-codex-max_migration_prompt" = true -"#; - assert_eq!(contents, expected); - } - - #[test] - fn blocking_record_model_migration_seen_preserves_table() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - std::fs::write( - codex_home.join(CONFIG_TOML_FILE), - r#"[notice] -existing = "value" -"#, - ) - .expect("seed"); - apply_blocking( - codex_home, - None, - &[ConfigEdit::RecordModelMigrationSeen { - from: "gpt-5".to_string(), - to: "gpt-5.1".to_string(), - }], - ) - .expect("persist"); - - let contents = - std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let expected = r#"[notice] -existing = "value" - -[notice.model_migrations] -gpt-5 = "gpt-5.1" -"#; - assert_eq!(contents, expected); - } - - #[test] - fn blocking_replace_mcp_servers_round_trips() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - - let mut servers = BTreeMap::new(); - servers.insert( - "stdio".to_string(), - McpServerConfig { - transport: McpServerTransportConfig::Stdio { - command: "cmd".to_string(), - args: vec!["--flag".to_string()], - env: Some( - [ - ("B".to_string(), "2".to_string()), - ("A".to_string(), "1".to_string()), - ] - .into_iter() - .collect(), - ), - env_vars: vec!["FOO".to_string()], - cwd: None, - }, - enabled: true, - required: false, - disabled_reason: None, - startup_timeout_sec: None, - tool_timeout_sec: None, - enabled_tools: Some(vec!["one".to_string(), "two".to_string()]), - disabled_tools: None, - scopes: None, - oauth_resource: None, - }, - ); - - servers.insert( - "http".to_string(), - McpServerConfig { - transport: McpServerTransportConfig::StreamableHttp { - url: "https://example.com".to_string(), - bearer_token_env_var: Some("TOKEN".to_string()), - http_headers: Some( - [("Z-Header".to_string(), "z".to_string())] - .into_iter() - .collect(), - ), - env_http_headers: None, - }, - enabled: false, - required: false, - disabled_reason: None, - startup_timeout_sec: Some(std::time::Duration::from_secs(5)), - tool_timeout_sec: None, - enabled_tools: None, - disabled_tools: Some(vec!["forbidden".to_string()]), - scopes: None, - oauth_resource: Some("https://resource.example.com".to_string()), - }, - ); - - apply_blocking( - codex_home, - None, - &[ConfigEdit::ReplaceMcpServers(servers.clone())], - ) - .expect("persist"); - - let raw = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let expected = "\ -[mcp_servers.http] -url = \"https://example.com\" -bearer_token_env_var = \"TOKEN\" -enabled = false -startup_timeout_sec = 5.0 -disabled_tools = [\"forbidden\"] -oauth_resource = \"https://resource.example.com\" - -[mcp_servers.http.http_headers] -Z-Header = \"z\" - -[mcp_servers.stdio] -command = \"cmd\" -args = [\"--flag\"] -env_vars = [\"FOO\"] -enabled_tools = [\"one\", \"two\"] - -[mcp_servers.stdio.env] -A = \"1\" -B = \"2\" -"; - assert_eq!(raw, expected); - } - - #[test] - fn blocking_replace_mcp_servers_preserves_inline_comments() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - std::fs::write( - codex_home.join(CONFIG_TOML_FILE), - r#"[mcp_servers] -# keep me -foo = { command = "cmd" } -"#, - ) - .expect("seed"); - - let mut servers = BTreeMap::new(); - servers.insert( - "foo".to_string(), - McpServerConfig { - transport: McpServerTransportConfig::Stdio { - command: "cmd".to_string(), - args: Vec::new(), - env: None, - env_vars: Vec::new(), - cwd: None, - }, - enabled: true, - required: false, - disabled_reason: None, - startup_timeout_sec: None, - tool_timeout_sec: None, - enabled_tools: None, - disabled_tools: None, - scopes: None, - oauth_resource: None, - }, - ); - - apply_blocking(codex_home, None, &[ConfigEdit::ReplaceMcpServers(servers)]) - .expect("persist"); - - let contents = - std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let expected = r#"[mcp_servers] -# keep me -foo = { command = "cmd" } -"#; - assert_eq!(contents, expected); - } - - #[test] - fn blocking_replace_mcp_servers_preserves_inline_comment_suffix() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - std::fs::write( - codex_home.join(CONFIG_TOML_FILE), - r#"[mcp_servers] -foo = { command = "cmd" } # keep me -"#, - ) - .expect("seed"); - - let mut servers = BTreeMap::new(); - servers.insert( - "foo".to_string(), - McpServerConfig { - transport: McpServerTransportConfig::Stdio { - command: "cmd".to_string(), - args: Vec::new(), - env: None, - env_vars: Vec::new(), - cwd: None, - }, - enabled: false, - required: false, - disabled_reason: None, - startup_timeout_sec: None, - tool_timeout_sec: None, - enabled_tools: None, - disabled_tools: None, - scopes: None, - oauth_resource: None, - }, - ); - - apply_blocking(codex_home, None, &[ConfigEdit::ReplaceMcpServers(servers)]) - .expect("persist"); - - let contents = - std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let expected = r#"[mcp_servers] -foo = { command = "cmd" , enabled = false } # keep me -"#; - assert_eq!(contents, expected); - } - - #[test] - fn blocking_replace_mcp_servers_preserves_inline_comment_after_removing_keys() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - std::fs::write( - codex_home.join(CONFIG_TOML_FILE), - r#"[mcp_servers] -foo = { command = "cmd", args = ["--flag"] } # keep me -"#, - ) - .expect("seed"); - - let mut servers = BTreeMap::new(); - servers.insert( - "foo".to_string(), - McpServerConfig { - transport: McpServerTransportConfig::Stdio { - command: "cmd".to_string(), - args: Vec::new(), - env: None, - env_vars: Vec::new(), - cwd: None, - }, - enabled: true, - required: false, - disabled_reason: None, - startup_timeout_sec: None, - tool_timeout_sec: None, - enabled_tools: None, - disabled_tools: None, - scopes: None, - oauth_resource: None, - }, - ); - - apply_blocking(codex_home, None, &[ConfigEdit::ReplaceMcpServers(servers)]) - .expect("persist"); - - let contents = - std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let expected = r#"[mcp_servers] -foo = { command = "cmd"} # keep me -"#; - assert_eq!(contents, expected); - } - - #[test] - fn blocking_replace_mcp_servers_preserves_inline_comment_prefix_on_update() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - std::fs::write( - codex_home.join(CONFIG_TOML_FILE), - r#"[mcp_servers] -# keep me -foo = { command = "cmd" } -"#, - ) - .expect("seed"); - - let mut servers = BTreeMap::new(); - servers.insert( - "foo".to_string(), - McpServerConfig { - transport: McpServerTransportConfig::Stdio { - command: "cmd".to_string(), - args: Vec::new(), - env: None, - env_vars: Vec::new(), - cwd: None, - }, - enabled: false, - required: false, - disabled_reason: None, - startup_timeout_sec: None, - tool_timeout_sec: None, - enabled_tools: None, - disabled_tools: None, - scopes: None, - oauth_resource: None, - }, - ); - - apply_blocking(codex_home, None, &[ConfigEdit::ReplaceMcpServers(servers)]) - .expect("persist"); - - let contents = - std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let expected = r#"[mcp_servers] -# keep me -foo = { command = "cmd" , enabled = false } -"#; - assert_eq!(contents, expected); - } - - #[test] - fn blocking_clear_path_noop_when_missing() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - - apply_blocking( - codex_home, - None, - &[ConfigEdit::ClearPath { - segments: vec!["missing".to_string()], - }], - ) - .expect("apply"); - - assert!( - !codex_home.join(CONFIG_TOML_FILE).exists(), - "config.toml should not be created on noop" - ); - } - - #[test] - fn blocking_set_path_updates_notifications() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - - let item = value(false); - apply_blocking( - codex_home, - None, - &[ConfigEdit::SetPath { - segments: vec!["tui".to_string(), "notifications".to_string()], - value: item, - }], - ) - .expect("apply"); - - let raw = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let config: TomlValue = toml::from_str(&raw).expect("parse config"); - let notifications = config - .get("tui") - .and_then(|item| item.as_table()) - .and_then(|tbl| tbl.get("notifications")) - .and_then(toml::Value::as_bool); - assert_eq!(notifications, Some(false)); - } - - #[tokio::test] - async fn async_builder_set_model_persists() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path().to_path_buf(); - - ConfigEditsBuilder::new(&codex_home) - .set_model(Some("gpt-5.1-codex"), Some(ReasoningEffort::High)) - .apply() - .await - .expect("persist"); - - let contents = - std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let expected = r#"model = "gpt-5.1-codex" -model_reasoning_effort = "high" -"#; - assert_eq!(contents, expected); - } - - #[test] - fn blocking_builder_set_model_round_trips_back_and_forth() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - - let initial_expected = r#"model = "o4-mini" -model_reasoning_effort = "low" -"#; - ConfigEditsBuilder::new(codex_home) - .set_model(Some("o4-mini"), Some(ReasoningEffort::Low)) - .apply_blocking() - .expect("persist initial"); - let mut contents = - std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - assert_eq!(contents, initial_expected); - - let updated_expected = r#"model = "gpt-5.1-codex" -model_reasoning_effort = "high" -"#; - ConfigEditsBuilder::new(codex_home) - .set_model(Some("gpt-5.1-codex"), Some(ReasoningEffort::High)) - .apply_blocking() - .expect("persist update"); - contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - assert_eq!(contents, updated_expected); - - ConfigEditsBuilder::new(codex_home) - .set_model(Some("o4-mini"), Some(ReasoningEffort::Low)) - .apply_blocking() - .expect("persist revert"); - contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - assert_eq!(contents, initial_expected); - } - - #[tokio::test] - async fn blocking_set_asynchronous_helpers_available() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path().to_path_buf(); - - ConfigEditsBuilder::new(&codex_home) - .set_hide_full_access_warning(true) - .apply() - .await - .expect("persist"); - - let raw = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let notice = toml::from_str::(&raw) - .expect("parse config") - .get("notice") - .and_then(|item| item.as_table()) - .and_then(|tbl| tbl.get("hide_full_access_warning")) - .and_then(toml::Value::as_bool); - assert_eq!(notice, Some(true)); - } - - #[test] - fn blocking_builder_set_realtime_audio_persists_and_clears() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - - ConfigEditsBuilder::new(codex_home) - .set_realtime_microphone(Some("USB Mic")) - .set_realtime_speaker(Some("Desk Speakers")) - .apply_blocking() - .expect("persist realtime audio"); - - let raw = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let config: TomlValue = toml::from_str(&raw).expect("parse config"); - let realtime_audio = config - .get("audio") - .and_then(TomlValue::as_table) - .expect("audio table should exist"); - assert_eq!( - realtime_audio.get("microphone").and_then(TomlValue::as_str), - Some("USB Mic") - ); - assert_eq!( - realtime_audio.get("speaker").and_then(TomlValue::as_str), - Some("Desk Speakers") - ); - - ConfigEditsBuilder::new(codex_home) - .set_realtime_microphone(None) - .apply_blocking() - .expect("clear realtime microphone"); - - let raw = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - let config: TomlValue = toml::from_str(&raw).expect("parse config"); - let realtime_audio = config - .get("audio") - .and_then(TomlValue::as_table) - .expect("audio table should exist"); - assert_eq!(realtime_audio.get("microphone"), None); - assert_eq!( - realtime_audio.get("speaker").and_then(TomlValue::as_str), - Some("Desk Speakers") - ); - } - - #[test] - fn replace_mcp_servers_blocking_clears_table_when_empty() { - let tmp = tempdir().expect("tmpdir"); - let codex_home = tmp.path(); - std::fs::write( - codex_home.join(CONFIG_TOML_FILE), - "[mcp_servers]\nfoo = { command = \"cmd\" }\n", - ) - .expect("seed"); - - apply_blocking( - codex_home, - None, - &[ConfigEdit::ReplaceMcpServers(BTreeMap::new())], - ) - .expect("persist"); - - let contents = - std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); - assert!(!contents.contains("mcp_servers")); - } -} +#[path = "edit_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/config/edit_tests.rs b/codex-rs/core/src/config/edit_tests.rs new file mode 100644 index 0000000000..5a31d84dd0 --- /dev/null +++ b/codex-rs/core/src/config/edit_tests.rs @@ -0,0 +1,987 @@ +use super::*; +use crate::config::types::McpServerTransportConfig; +use codex_protocol::openai_models::ReasoningEffort; +use pretty_assertions::assert_eq; +#[cfg(unix)] +use std::os::unix::fs::symlink; +use tempfile::tempdir; +use toml::Value as TomlValue; + +#[test] +fn blocking_set_model_top_level() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + + apply_blocking( + codex_home, + None, + &[ConfigEdit::SetModel { + model: Some("gpt-5.1-codex".to_string()), + effort: Some(ReasoningEffort::High), + }], + ) + .expect("persist"); + + let contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let expected = r#"model = "gpt-5.1-codex" +model_reasoning_effort = "high" +"#; + assert_eq!(contents, expected); +} + +#[test] +fn builder_with_edits_applies_custom_paths() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + + ConfigEditsBuilder::new(codex_home) + .with_edits(vec![ConfigEdit::SetPath { + segments: vec!["enabled".to_string()], + value: value(true), + }]) + .apply_blocking() + .expect("persist"); + + let contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + assert_eq!(contents, "enabled = true\n"); +} + +#[test] +fn set_model_availability_nux_count_writes_shown_count() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + let shown_count = HashMap::from([("gpt-foo".to_string(), 4)]); + + ConfigEditsBuilder::new(codex_home) + .set_model_availability_nux_count(&shown_count) + .apply_blocking() + .expect("persist"); + + let contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let expected = r#"[tui.model_availability_nux] +gpt-foo = 4 +"#; + assert_eq!(contents, expected); +} + +#[test] +fn set_skill_config_writes_disabled_entry() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + + ConfigEditsBuilder::new(codex_home) + .with_edits([ConfigEdit::SetSkillConfig { + path: PathBuf::from("/tmp/skills/demo/SKILL.md"), + enabled: false, + }]) + .apply_blocking() + .expect("persist"); + + let contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let expected = r#"[[skills.config]] +path = "/tmp/skills/demo/SKILL.md" +enabled = false +"#; + assert_eq!(contents, expected); +} + +#[test] +fn set_skill_config_removes_entry_when_enabled() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + std::fs::write( + codex_home.join(CONFIG_TOML_FILE), + r#"[[skills.config]] +path = "/tmp/skills/demo/SKILL.md" +enabled = false +"#, + ) + .expect("seed config"); + + ConfigEditsBuilder::new(codex_home) + .with_edits([ConfigEdit::SetSkillConfig { + path: PathBuf::from("/tmp/skills/demo/SKILL.md"), + enabled: true, + }]) + .apply_blocking() + .expect("persist"); + + let contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + assert_eq!(contents, ""); +} + +#[test] +fn blocking_set_model_preserves_inline_table_contents() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + + // Seed with inline tables for profiles to simulate common user config. + std::fs::write( + codex_home.join(CONFIG_TOML_FILE), + r#"profile = "fast" + +profiles = { fast = { model = "gpt-4o", sandbox_mode = "strict" } } +"#, + ) + .expect("seed"); + + apply_blocking( + codex_home, + None, + &[ConfigEdit::SetModel { + model: Some("o4-mini".to_string()), + effort: None, + }], + ) + .expect("persist"); + + let raw = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let value: TomlValue = toml::from_str(&raw).expect("parse config"); + + // Ensure sandbox_mode is preserved under profiles.fast and model updated. + let profiles_tbl = value + .get("profiles") + .and_then(|v| v.as_table()) + .expect("profiles table"); + let fast_tbl = profiles_tbl + .get("fast") + .and_then(|v| v.as_table()) + .expect("fast table"); + assert_eq!( + fast_tbl.get("sandbox_mode").and_then(|v| v.as_str()), + Some("strict") + ); + assert_eq!( + fast_tbl.get("model").and_then(|v| v.as_str()), + Some("o4-mini") + ); +} + +#[cfg(unix)] +#[test] +fn blocking_set_model_writes_through_symlink_chain() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + let target_dir = tempdir().expect("target dir"); + let target_path = target_dir.path().join(CONFIG_TOML_FILE); + let link_path = codex_home.join("config-link.toml"); + let config_path = codex_home.join(CONFIG_TOML_FILE); + + symlink(&target_path, &link_path).expect("symlink link"); + symlink("config-link.toml", &config_path).expect("symlink config"); + + apply_blocking( + codex_home, + None, + &[ConfigEdit::SetModel { + model: Some("gpt-5.1-codex".to_string()), + effort: Some(ReasoningEffort::High), + }], + ) + .expect("persist"); + + let meta = std::fs::symlink_metadata(&config_path).expect("config metadata"); + assert!(meta.file_type().is_symlink()); + + let contents = std::fs::read_to_string(&target_path).expect("read target"); + let expected = r#"model = "gpt-5.1-codex" +model_reasoning_effort = "high" +"#; + assert_eq!(contents, expected); +} + +#[cfg(unix)] +#[test] +fn blocking_set_model_replaces_symlink_on_cycle() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + let link_a = codex_home.join("a.toml"); + let link_b = codex_home.join("b.toml"); + let config_path = codex_home.join(CONFIG_TOML_FILE); + + symlink("b.toml", &link_a).expect("symlink a"); + symlink("a.toml", &link_b).expect("symlink b"); + symlink("a.toml", &config_path).expect("symlink config"); + + apply_blocking( + codex_home, + None, + &[ConfigEdit::SetModel { + model: Some("gpt-5.1-codex".to_string()), + effort: None, + }], + ) + .expect("persist"); + + let meta = std::fs::symlink_metadata(&config_path).expect("config metadata"); + assert!(!meta.file_type().is_symlink()); + + let contents = std::fs::read_to_string(&config_path).expect("read config"); + let expected = r#"model = "gpt-5.1-codex" +"#; + assert_eq!(contents, expected); +} + +#[test] +fn batch_write_table_upsert_preserves_inline_comments() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + let original = r#"approval_policy = "never" + +[mcp_servers.linear] +name = "linear" +# ok +url = "https://linear.example" + +[mcp_servers.linear.http_headers] +foo = "bar" + +[sandbox_workspace_write] +# ok 3 +network_access = false +"#; + std::fs::write(codex_home.join(CONFIG_TOML_FILE), original).expect("seed config"); + + apply_blocking( + codex_home, + None, + &[ + ConfigEdit::SetPath { + segments: vec![ + "mcp_servers".to_string(), + "linear".to_string(), + "url".to_string(), + ], + value: value("https://linear.example/v2"), + }, + ConfigEdit::SetPath { + segments: vec![ + "sandbox_workspace_write".to_string(), + "network_access".to_string(), + ], + value: value(true), + }, + ], + ) + .expect("apply"); + + let updated = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let expected = r#"approval_policy = "never" + +[mcp_servers.linear] +name = "linear" +# ok +url = "https://linear.example/v2" + +[mcp_servers.linear.http_headers] +foo = "bar" + +[sandbox_workspace_write] +# ok 3 +network_access = true +"#; + assert_eq!(updated, expected); +} + +#[test] +fn blocking_clear_model_removes_inline_table_entry() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + + std::fs::write( + codex_home.join(CONFIG_TOML_FILE), + r#"profile = "fast" + +profiles = { fast = { model = "gpt-4o", sandbox_mode = "strict" } } +"#, + ) + .expect("seed"); + + apply_blocking( + codex_home, + None, + &[ConfigEdit::SetModel { + model: None, + effort: Some(ReasoningEffort::High), + }], + ) + .expect("persist"); + + let contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let expected = r#"profile = "fast" + +[profiles.fast] +sandbox_mode = "strict" +model_reasoning_effort = "high" +"#; + assert_eq!(contents, expected); +} + +#[test] +fn blocking_set_model_scopes_to_active_profile() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + std::fs::write( + codex_home.join(CONFIG_TOML_FILE), + r#"profile = "team" + +[profiles.team] +model_reasoning_effort = "low" +"#, + ) + .expect("seed"); + + apply_blocking( + codex_home, + None, + &[ConfigEdit::SetModel { + model: Some("o5-preview".to_string()), + effort: Some(ReasoningEffort::Minimal), + }], + ) + .expect("persist"); + + let contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let expected = r#"profile = "team" + +[profiles.team] +model_reasoning_effort = "minimal" +model = "o5-preview" +"#; + assert_eq!(contents, expected); +} + +#[test] +fn blocking_set_model_with_explicit_profile() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + std::fs::write( + codex_home.join(CONFIG_TOML_FILE), + r#"[profiles."team a"] +model = "gpt-5.1-codex" +"#, + ) + .expect("seed"); + + apply_blocking( + codex_home, + Some("team a"), + &[ConfigEdit::SetModel { + model: Some("o4-mini".to_string()), + effort: None, + }], + ) + .expect("persist"); + + let contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let expected = r#"[profiles."team a"] +model = "o4-mini" +"#; + assert_eq!(contents, expected); +} + +#[test] +fn blocking_set_hide_full_access_warning_preserves_table() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + std::fs::write( + codex_home.join(CONFIG_TOML_FILE), + r#"# Global comment + +[notice] +# keep me +existing = "value" +"#, + ) + .expect("seed"); + + apply_blocking( + codex_home, + None, + &[ConfigEdit::SetNoticeHideFullAccessWarning(true)], + ) + .expect("persist"); + + let contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let expected = r#"# Global comment + +[notice] +# keep me +existing = "value" +hide_full_access_warning = true +"#; + assert_eq!(contents, expected); +} + +#[test] +fn blocking_set_hide_rate_limit_model_nudge_preserves_table() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + std::fs::write( + codex_home.join(CONFIG_TOML_FILE), + r#"[notice] +existing = "value" +"#, + ) + .expect("seed"); + + apply_blocking( + codex_home, + None, + &[ConfigEdit::SetNoticeHideRateLimitModelNudge(true)], + ) + .expect("persist"); + + let contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let expected = r#"[notice] +existing = "value" +hide_rate_limit_model_nudge = true +"#; + assert_eq!(contents, expected); +} + +#[test] +fn blocking_set_hide_gpt5_1_migration_prompt_preserves_table() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + std::fs::write( + codex_home.join(CONFIG_TOML_FILE), + r#"[notice] +existing = "value" +"#, + ) + .expect("seed"); + apply_blocking( + codex_home, + None, + &[ConfigEdit::SetNoticeHideModelMigrationPrompt( + "hide_gpt5_1_migration_prompt".to_string(), + true, + )], + ) + .expect("persist"); + + let contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let expected = r#"[notice] +existing = "value" +hide_gpt5_1_migration_prompt = true +"#; + assert_eq!(contents, expected); +} + +#[test] +fn blocking_set_hide_gpt_5_1_codex_max_migration_prompt_preserves_table() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + std::fs::write( + codex_home.join(CONFIG_TOML_FILE), + r#"[notice] +existing = "value" +"#, + ) + .expect("seed"); + apply_blocking( + codex_home, + None, + &[ConfigEdit::SetNoticeHideModelMigrationPrompt( + "hide_gpt-5.1-codex-max_migration_prompt".to_string(), + true, + )], + ) + .expect("persist"); + + let contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let expected = r#"[notice] +existing = "value" +"hide_gpt-5.1-codex-max_migration_prompt" = true +"#; + assert_eq!(contents, expected); +} + +#[test] +fn blocking_record_model_migration_seen_preserves_table() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + std::fs::write( + codex_home.join(CONFIG_TOML_FILE), + r#"[notice] +existing = "value" +"#, + ) + .expect("seed"); + apply_blocking( + codex_home, + None, + &[ConfigEdit::RecordModelMigrationSeen { + from: "gpt-5".to_string(), + to: "gpt-5.1".to_string(), + }], + ) + .expect("persist"); + + let contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let expected = r#"[notice] +existing = "value" + +[notice.model_migrations] +gpt-5 = "gpt-5.1" +"#; + assert_eq!(contents, expected); +} + +#[test] +fn blocking_replace_mcp_servers_round_trips() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + + let mut servers = BTreeMap::new(); + servers.insert( + "stdio".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: "cmd".to_string(), + args: vec!["--flag".to_string()], + env: Some( + [ + ("B".to_string(), "2".to_string()), + ("A".to_string(), "1".to_string()), + ] + .into_iter() + .collect(), + ), + env_vars: vec!["FOO".to_string()], + cwd: None, + }, + enabled: true, + required: false, + disabled_reason: None, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: Some(vec!["one".to_string(), "two".to_string()]), + disabled_tools: None, + scopes: None, + oauth_resource: None, + }, + ); + + servers.insert( + "http".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: "https://example.com".to_string(), + bearer_token_env_var: Some("TOKEN".to_string()), + http_headers: Some( + [("Z-Header".to_string(), "z".to_string())] + .into_iter() + .collect(), + ), + env_http_headers: None, + }, + enabled: false, + required: false, + disabled_reason: None, + startup_timeout_sec: Some(std::time::Duration::from_secs(5)), + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: Some(vec!["forbidden".to_string()]), + scopes: None, + oauth_resource: Some("https://resource.example.com".to_string()), + }, + ); + + apply_blocking( + codex_home, + None, + &[ConfigEdit::ReplaceMcpServers(servers.clone())], + ) + .expect("persist"); + + let raw = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let expected = "\ +[mcp_servers.http] +url = \"https://example.com\" +bearer_token_env_var = \"TOKEN\" +enabled = false +startup_timeout_sec = 5.0 +disabled_tools = [\"forbidden\"] +oauth_resource = \"https://resource.example.com\" + +[mcp_servers.http.http_headers] +Z-Header = \"z\" + +[mcp_servers.stdio] +command = \"cmd\" +args = [\"--flag\"] +env_vars = [\"FOO\"] +enabled_tools = [\"one\", \"two\"] + +[mcp_servers.stdio.env] +A = \"1\" +B = \"2\" +"; + assert_eq!(raw, expected); +} + +#[test] +fn blocking_replace_mcp_servers_preserves_inline_comments() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + std::fs::write( + codex_home.join(CONFIG_TOML_FILE), + r#"[mcp_servers] +# keep me +foo = { command = "cmd" } +"#, + ) + .expect("seed"); + + let mut servers = BTreeMap::new(); + servers.insert( + "foo".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: "cmd".to_string(), + args: Vec::new(), + env: None, + env_vars: Vec::new(), + cwd: None, + }, + enabled: true, + required: false, + disabled_reason: None, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + oauth_resource: None, + }, + ); + + apply_blocking(codex_home, None, &[ConfigEdit::ReplaceMcpServers(servers)]).expect("persist"); + + let contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let expected = r#"[mcp_servers] +# keep me +foo = { command = "cmd" } +"#; + assert_eq!(contents, expected); +} + +#[test] +fn blocking_replace_mcp_servers_preserves_inline_comment_suffix() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + std::fs::write( + codex_home.join(CONFIG_TOML_FILE), + r#"[mcp_servers] +foo = { command = "cmd" } # keep me +"#, + ) + .expect("seed"); + + let mut servers = BTreeMap::new(); + servers.insert( + "foo".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: "cmd".to_string(), + args: Vec::new(), + env: None, + env_vars: Vec::new(), + cwd: None, + }, + enabled: false, + required: false, + disabled_reason: None, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + oauth_resource: None, + }, + ); + + apply_blocking(codex_home, None, &[ConfigEdit::ReplaceMcpServers(servers)]).expect("persist"); + + let contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let expected = r#"[mcp_servers] +foo = { command = "cmd" , enabled = false } # keep me +"#; + assert_eq!(contents, expected); +} + +#[test] +fn blocking_replace_mcp_servers_preserves_inline_comment_after_removing_keys() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + std::fs::write( + codex_home.join(CONFIG_TOML_FILE), + r#"[mcp_servers] +foo = { command = "cmd", args = ["--flag"] } # keep me +"#, + ) + .expect("seed"); + + let mut servers = BTreeMap::new(); + servers.insert( + "foo".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: "cmd".to_string(), + args: Vec::new(), + env: None, + env_vars: Vec::new(), + cwd: None, + }, + enabled: true, + required: false, + disabled_reason: None, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + oauth_resource: None, + }, + ); + + apply_blocking(codex_home, None, &[ConfigEdit::ReplaceMcpServers(servers)]).expect("persist"); + + let contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let expected = r#"[mcp_servers] +foo = { command = "cmd"} # keep me +"#; + assert_eq!(contents, expected); +} + +#[test] +fn blocking_replace_mcp_servers_preserves_inline_comment_prefix_on_update() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + std::fs::write( + codex_home.join(CONFIG_TOML_FILE), + r#"[mcp_servers] +# keep me +foo = { command = "cmd" } +"#, + ) + .expect("seed"); + + let mut servers = BTreeMap::new(); + servers.insert( + "foo".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: "cmd".to_string(), + args: Vec::new(), + env: None, + env_vars: Vec::new(), + cwd: None, + }, + enabled: false, + required: false, + disabled_reason: None, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + oauth_resource: None, + }, + ); + + apply_blocking(codex_home, None, &[ConfigEdit::ReplaceMcpServers(servers)]).expect("persist"); + + let contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let expected = r#"[mcp_servers] +# keep me +foo = { command = "cmd" , enabled = false } +"#; + assert_eq!(contents, expected); +} + +#[test] +fn blocking_clear_path_noop_when_missing() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + + apply_blocking( + codex_home, + None, + &[ConfigEdit::ClearPath { + segments: vec!["missing".to_string()], + }], + ) + .expect("apply"); + + assert!( + !codex_home.join(CONFIG_TOML_FILE).exists(), + "config.toml should not be created on noop" + ); +} + +#[test] +fn blocking_set_path_updates_notifications() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + + let item = value(false); + apply_blocking( + codex_home, + None, + &[ConfigEdit::SetPath { + segments: vec!["tui".to_string(), "notifications".to_string()], + value: item, + }], + ) + .expect("apply"); + + let raw = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let config: TomlValue = toml::from_str(&raw).expect("parse config"); + let notifications = config + .get("tui") + .and_then(|item| item.as_table()) + .and_then(|tbl| tbl.get("notifications")) + .and_then(toml::Value::as_bool); + assert_eq!(notifications, Some(false)); +} + +#[tokio::test] +async fn async_builder_set_model_persists() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path().to_path_buf(); + + ConfigEditsBuilder::new(&codex_home) + .set_model(Some("gpt-5.1-codex"), Some(ReasoningEffort::High)) + .apply() + .await + .expect("persist"); + + let contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let expected = r#"model = "gpt-5.1-codex" +model_reasoning_effort = "high" +"#; + assert_eq!(contents, expected); +} + +#[test] +fn blocking_builder_set_model_round_trips_back_and_forth() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + + let initial_expected = r#"model = "o4-mini" +model_reasoning_effort = "low" +"#; + ConfigEditsBuilder::new(codex_home) + .set_model(Some("o4-mini"), Some(ReasoningEffort::Low)) + .apply_blocking() + .expect("persist initial"); + let mut contents = + std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + assert_eq!(contents, initial_expected); + + let updated_expected = r#"model = "gpt-5.1-codex" +model_reasoning_effort = "high" +"#; + ConfigEditsBuilder::new(codex_home) + .set_model(Some("gpt-5.1-codex"), Some(ReasoningEffort::High)) + .apply_blocking() + .expect("persist update"); + contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + assert_eq!(contents, updated_expected); + + ConfigEditsBuilder::new(codex_home) + .set_model(Some("o4-mini"), Some(ReasoningEffort::Low)) + .apply_blocking() + .expect("persist revert"); + contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + assert_eq!(contents, initial_expected); +} + +#[tokio::test] +async fn blocking_set_asynchronous_helpers_available() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path().to_path_buf(); + + ConfigEditsBuilder::new(&codex_home) + .set_hide_full_access_warning(true) + .apply() + .await + .expect("persist"); + + let raw = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let notice = toml::from_str::(&raw) + .expect("parse config") + .get("notice") + .and_then(|item| item.as_table()) + .and_then(|tbl| tbl.get("hide_full_access_warning")) + .and_then(toml::Value::as_bool); + assert_eq!(notice, Some(true)); +} + +#[test] +fn blocking_builder_set_realtime_audio_persists_and_clears() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + + ConfigEditsBuilder::new(codex_home) + .set_realtime_microphone(Some("USB Mic")) + .set_realtime_speaker(Some("Desk Speakers")) + .apply_blocking() + .expect("persist realtime audio"); + + let raw = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let config: TomlValue = toml::from_str(&raw).expect("parse config"); + let realtime_audio = config + .get("audio") + .and_then(TomlValue::as_table) + .expect("audio table should exist"); + assert_eq!( + realtime_audio.get("microphone").and_then(TomlValue::as_str), + Some("USB Mic") + ); + assert_eq!( + realtime_audio.get("speaker").and_then(TomlValue::as_str), + Some("Desk Speakers") + ); + + ConfigEditsBuilder::new(codex_home) + .set_realtime_microphone(None) + .apply_blocking() + .expect("clear realtime microphone"); + + let raw = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + let config: TomlValue = toml::from_str(&raw).expect("parse config"); + let realtime_audio = config + .get("audio") + .and_then(TomlValue::as_table) + .expect("audio table should exist"); + assert_eq!(realtime_audio.get("microphone"), None); + assert_eq!( + realtime_audio.get("speaker").and_then(TomlValue::as_str), + Some("Desk Speakers") + ); +} + +#[test] +fn replace_mcp_servers_blocking_clears_table_when_empty() { + let tmp = tempdir().expect("tmpdir"); + let codex_home = tmp.path(); + std::fs::write( + codex_home.join(CONFIG_TOML_FILE), + "[mcp_servers]\nfoo = { command = \"cmd\" }\n", + ) + .expect("seed"); + + apply_blocking( + codex_home, + None, + &[ConfigEdit::ReplaceMcpServers(BTreeMap::new())], + ) + .expect("persist"); + + let contents = std::fs::read_to_string(codex_home.join(CONFIG_TOML_FILE)).expect("read config"); + assert!(!contents.contains("mcp_servers")); +} diff --git a/codex-rs/core/src/config/network_proxy_spec.rs b/codex-rs/core/src/config/network_proxy_spec.rs index 9c70cf084c..de77e4426e 100644 --- a/codex-rs/core/src/config/network_proxy_spec.rs +++ b/codex-rs/core/src/config/network_proxy_spec.rs @@ -280,207 +280,5 @@ impl NetworkProxySpec { } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - - #[test] - fn build_state_with_audit_metadata_threads_metadata_to_state() { - let spec = NetworkProxySpec { - config: NetworkProxyConfig::default(), - constraints: NetworkProxyConstraints::default(), - hard_deny_allowlist_misses: false, - }; - let metadata = NetworkProxyAuditMetadata { - conversation_id: Some("conversation-1".to_string()), - app_version: Some("1.2.3".to_string()), - user_account_id: Some("acct-1".to_string()), - ..NetworkProxyAuditMetadata::default() - }; - - let state = spec - .build_state_with_audit_metadata(metadata.clone()) - .expect("state should build"); - assert_eq!(state.audit_metadata(), &metadata); - } - - #[test] - fn requirements_allowed_domains_are_a_baseline_for_user_allowlist() { - let mut config = NetworkProxyConfig::default(); - config.network.allowed_domains = vec!["api.example.com".to_string()]; - let requirements = NetworkConstraints { - allowed_domains: Some(vec!["*.example.com".to_string()]), - ..Default::default() - }; - - let spec = NetworkProxySpec::from_config_and_constraints( - config, - Some(requirements), - &SandboxPolicy::new_read_only_policy(), - ) - .expect("config should stay within the managed allowlist"); - - assert_eq!( - spec.config.network.allowed_domains, - vec!["*.example.com".to_string(), "api.example.com".to_string()] - ); - assert_eq!( - spec.constraints.allowed_domains, - Some(vec!["*.example.com".to_string()]) - ); - assert_eq!(spec.constraints.allowlist_expansion_enabled, Some(true)); - } - - #[test] - fn danger_full_access_keeps_managed_allowlist_and_denylist_fixed() { - let mut config = NetworkProxyConfig::default(); - config.network.allowed_domains = vec!["evil.com".to_string()]; - config.network.denied_domains = vec!["more-blocked.example.com".to_string()]; - let requirements = NetworkConstraints { - allowed_domains: Some(vec!["*.example.com".to_string()]), - denied_domains: Some(vec!["blocked.example.com".to_string()]), - ..Default::default() - }; - - let spec = NetworkProxySpec::from_config_and_constraints( - config, - Some(requirements), - &SandboxPolicy::DangerFullAccess, - ) - .expect("yolo mode should pin the effective policy to the managed baseline"); - - assert_eq!( - spec.config.network.allowed_domains, - vec!["*.example.com".to_string()] - ); - assert_eq!( - spec.config.network.denied_domains, - vec!["blocked.example.com".to_string()] - ); - assert_eq!(spec.constraints.allowlist_expansion_enabled, Some(false)); - assert_eq!(spec.constraints.denylist_expansion_enabled, Some(false)); - } - - #[test] - fn managed_allowed_domains_only_disables_default_mode_allowlist_expansion() { - let mut config = NetworkProxyConfig::default(); - config.network.allowed_domains = vec!["api.example.com".to_string()]; - let requirements = NetworkConstraints { - allowed_domains: Some(vec!["*.example.com".to_string()]), - managed_allowed_domains_only: Some(true), - ..Default::default() - }; - - let spec = NetworkProxySpec::from_config_and_constraints( - config, - Some(requirements), - &SandboxPolicy::new_workspace_write_policy(), - ) - .expect("managed baseline should still load"); - - assert_eq!( - spec.config.network.allowed_domains, - vec!["*.example.com".to_string()] - ); - assert_eq!(spec.constraints.allowlist_expansion_enabled, Some(false)); - } - - #[test] - fn managed_allowed_domains_only_ignores_user_allowlist_and_hard_denies_misses() { - let mut config = NetworkProxyConfig::default(); - config.network.allowed_domains = vec!["api.example.com".to_string()]; - let requirements = NetworkConstraints { - allowed_domains: Some(vec!["managed.example.com".to_string()]), - managed_allowed_domains_only: Some(true), - ..Default::default() - }; - - let spec = NetworkProxySpec::from_config_and_constraints( - config, - Some(requirements), - &SandboxPolicy::new_workspace_write_policy(), - ) - .expect("managed-only allowlist should still load"); - - assert_eq!( - spec.config.network.allowed_domains, - vec!["managed.example.com".to_string()] - ); - assert_eq!( - spec.constraints.allowed_domains, - Some(vec!["managed.example.com".to_string()]) - ); - assert_eq!(spec.constraints.allowlist_expansion_enabled, Some(false)); - assert!(spec.hard_deny_allowlist_misses); - } - - #[test] - fn managed_allowed_domains_only_without_managed_allowlist_blocks_all_user_domains() { - let mut config = NetworkProxyConfig::default(); - config.network.allowed_domains = vec!["api.example.com".to_string()]; - let requirements = NetworkConstraints { - managed_allowed_domains_only: Some(true), - ..Default::default() - }; - - let spec = NetworkProxySpec::from_config_and_constraints( - config, - Some(requirements), - &SandboxPolicy::new_workspace_write_policy(), - ) - .expect("managed-only mode should treat missing managed allowlist as empty"); - - assert!(spec.config.network.allowed_domains.is_empty()); - assert_eq!(spec.constraints.allowed_domains, Some(Vec::new())); - assert_eq!(spec.constraints.allowlist_expansion_enabled, Some(false)); - assert!(spec.hard_deny_allowlist_misses); - } - - #[test] - fn managed_allowed_domains_only_blocks_all_user_domains_in_full_access_without_managed_list() { - let mut config = NetworkProxyConfig::default(); - config.network.allowed_domains = vec!["api.example.com".to_string()]; - let requirements = NetworkConstraints { - managed_allowed_domains_only: Some(true), - ..Default::default() - }; - - let spec = NetworkProxySpec::from_config_and_constraints( - config, - Some(requirements), - &SandboxPolicy::DangerFullAccess, - ) - .expect("managed-only mode should treat missing managed allowlist as empty"); - - assert!(spec.config.network.allowed_domains.is_empty()); - assert_eq!(spec.constraints.allowed_domains, Some(Vec::new())); - assert_eq!(spec.constraints.allowlist_expansion_enabled, Some(false)); - assert!(spec.hard_deny_allowlist_misses); - } - - #[test] - fn requirements_denied_domains_are_a_baseline_for_default_mode() { - let mut config = NetworkProxyConfig::default(); - config.network.denied_domains = vec!["blocked.example.com".to_string()]; - let requirements = NetworkConstraints { - denied_domains: Some(vec!["managed-blocked.example.com".to_string()]), - ..Default::default() - }; - - let spec = NetworkProxySpec::from_config_and_constraints( - config, - Some(requirements), - &SandboxPolicy::new_workspace_write_policy(), - ) - .expect("default mode should merge managed and user deny entries"); - - assert_eq!( - spec.config.network.denied_domains, - vec![ - "managed-blocked.example.com".to_string(), - "blocked.example.com".to_string() - ] - ); - assert_eq!(spec.constraints.denylist_expansion_enabled, Some(true)); - } -} +#[path = "network_proxy_spec_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/config/network_proxy_spec_tests.rs b/codex-rs/core/src/config/network_proxy_spec_tests.rs new file mode 100644 index 0000000000..4c6e82358e --- /dev/null +++ b/codex-rs/core/src/config/network_proxy_spec_tests.rs @@ -0,0 +1,202 @@ +use super::*; +use pretty_assertions::assert_eq; + +#[test] +fn build_state_with_audit_metadata_threads_metadata_to_state() { + let spec = NetworkProxySpec { + config: NetworkProxyConfig::default(), + constraints: NetworkProxyConstraints::default(), + hard_deny_allowlist_misses: false, + }; + let metadata = NetworkProxyAuditMetadata { + conversation_id: Some("conversation-1".to_string()), + app_version: Some("1.2.3".to_string()), + user_account_id: Some("acct-1".to_string()), + ..NetworkProxyAuditMetadata::default() + }; + + let state = spec + .build_state_with_audit_metadata(metadata.clone()) + .expect("state should build"); + assert_eq!(state.audit_metadata(), &metadata); +} + +#[test] +fn requirements_allowed_domains_are_a_baseline_for_user_allowlist() { + let mut config = NetworkProxyConfig::default(); + config.network.allowed_domains = vec!["api.example.com".to_string()]; + let requirements = NetworkConstraints { + allowed_domains: Some(vec!["*.example.com".to_string()]), + ..Default::default() + }; + + let spec = NetworkProxySpec::from_config_and_constraints( + config, + Some(requirements), + &SandboxPolicy::new_read_only_policy(), + ) + .expect("config should stay within the managed allowlist"); + + assert_eq!( + spec.config.network.allowed_domains, + vec!["*.example.com".to_string(), "api.example.com".to_string()] + ); + assert_eq!( + spec.constraints.allowed_domains, + Some(vec!["*.example.com".to_string()]) + ); + assert_eq!(spec.constraints.allowlist_expansion_enabled, Some(true)); +} + +#[test] +fn danger_full_access_keeps_managed_allowlist_and_denylist_fixed() { + let mut config = NetworkProxyConfig::default(); + config.network.allowed_domains = vec!["evil.com".to_string()]; + config.network.denied_domains = vec!["more-blocked.example.com".to_string()]; + let requirements = NetworkConstraints { + allowed_domains: Some(vec!["*.example.com".to_string()]), + denied_domains: Some(vec!["blocked.example.com".to_string()]), + ..Default::default() + }; + + let spec = NetworkProxySpec::from_config_and_constraints( + config, + Some(requirements), + &SandboxPolicy::DangerFullAccess, + ) + .expect("yolo mode should pin the effective policy to the managed baseline"); + + assert_eq!( + spec.config.network.allowed_domains, + vec!["*.example.com".to_string()] + ); + assert_eq!( + spec.config.network.denied_domains, + vec!["blocked.example.com".to_string()] + ); + assert_eq!(spec.constraints.allowlist_expansion_enabled, Some(false)); + assert_eq!(spec.constraints.denylist_expansion_enabled, Some(false)); +} + +#[test] +fn managed_allowed_domains_only_disables_default_mode_allowlist_expansion() { + let mut config = NetworkProxyConfig::default(); + config.network.allowed_domains = vec!["api.example.com".to_string()]; + let requirements = NetworkConstraints { + allowed_domains: Some(vec!["*.example.com".to_string()]), + managed_allowed_domains_only: Some(true), + ..Default::default() + }; + + let spec = NetworkProxySpec::from_config_and_constraints( + config, + Some(requirements), + &SandboxPolicy::new_workspace_write_policy(), + ) + .expect("managed baseline should still load"); + + assert_eq!( + spec.config.network.allowed_domains, + vec!["*.example.com".to_string()] + ); + assert_eq!(spec.constraints.allowlist_expansion_enabled, Some(false)); +} + +#[test] +fn managed_allowed_domains_only_ignores_user_allowlist_and_hard_denies_misses() { + let mut config = NetworkProxyConfig::default(); + config.network.allowed_domains = vec!["api.example.com".to_string()]; + let requirements = NetworkConstraints { + allowed_domains: Some(vec!["managed.example.com".to_string()]), + managed_allowed_domains_only: Some(true), + ..Default::default() + }; + + let spec = NetworkProxySpec::from_config_and_constraints( + config, + Some(requirements), + &SandboxPolicy::new_workspace_write_policy(), + ) + .expect("managed-only allowlist should still load"); + + assert_eq!( + spec.config.network.allowed_domains, + vec!["managed.example.com".to_string()] + ); + assert_eq!( + spec.constraints.allowed_domains, + Some(vec!["managed.example.com".to_string()]) + ); + assert_eq!(spec.constraints.allowlist_expansion_enabled, Some(false)); + assert!(spec.hard_deny_allowlist_misses); +} + +#[test] +fn managed_allowed_domains_only_without_managed_allowlist_blocks_all_user_domains() { + let mut config = NetworkProxyConfig::default(); + config.network.allowed_domains = vec!["api.example.com".to_string()]; + let requirements = NetworkConstraints { + managed_allowed_domains_only: Some(true), + ..Default::default() + }; + + let spec = NetworkProxySpec::from_config_and_constraints( + config, + Some(requirements), + &SandboxPolicy::new_workspace_write_policy(), + ) + .expect("managed-only mode should treat missing managed allowlist as empty"); + + assert!(spec.config.network.allowed_domains.is_empty()); + assert_eq!(spec.constraints.allowed_domains, Some(Vec::new())); + assert_eq!(spec.constraints.allowlist_expansion_enabled, Some(false)); + assert!(spec.hard_deny_allowlist_misses); +} + +#[test] +fn managed_allowed_domains_only_blocks_all_user_domains_in_full_access_without_managed_list() { + let mut config = NetworkProxyConfig::default(); + config.network.allowed_domains = vec!["api.example.com".to_string()]; + let requirements = NetworkConstraints { + managed_allowed_domains_only: Some(true), + ..Default::default() + }; + + let spec = NetworkProxySpec::from_config_and_constraints( + config, + Some(requirements), + &SandboxPolicy::DangerFullAccess, + ) + .expect("managed-only mode should treat missing managed allowlist as empty"); + + assert!(spec.config.network.allowed_domains.is_empty()); + assert_eq!(spec.constraints.allowed_domains, Some(Vec::new())); + assert_eq!(spec.constraints.allowlist_expansion_enabled, Some(false)); + assert!(spec.hard_deny_allowlist_misses); +} + +#[test] +fn requirements_denied_domains_are_a_baseline_for_default_mode() { + let mut config = NetworkProxyConfig::default(); + config.network.denied_domains = vec!["blocked.example.com".to_string()]; + let requirements = NetworkConstraints { + denied_domains: Some(vec!["managed-blocked.example.com".to_string()]), + ..Default::default() + }; + + let spec = NetworkProxySpec::from_config_and_constraints( + config, + Some(requirements), + &SandboxPolicy::new_workspace_write_policy(), + ) + .expect("default mode should merge managed and user deny entries"); + + assert_eq!( + spec.config.network.denied_domains, + vec![ + "managed-blocked.example.com".to_string(), + "blocked.example.com".to_string() + ] + ); + assert_eq!(spec.constraints.denylist_expansion_enabled, Some(true)); +} diff --git a/codex-rs/core/src/config/permissions.rs b/codex-rs/core/src/config/permissions.rs index b931c0f415..0ad98068fa 100644 --- a/codex-rs/core/src/config/permissions.rs +++ b/codex-rs/core/src/config/permissions.rs @@ -410,14 +410,5 @@ fn maybe_push_unknown_special_path_warning( } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - - #[test] - fn normalize_absolute_path_for_platform_simplifies_windows_verbatim_paths() { - let parsed = - normalize_absolute_path_for_platform(r"\\?\D:\c\x\worktrees\2508\swift-base", true); - assert_eq!(parsed, PathBuf::from(r"D:\c\x\worktrees\2508\swift-base")); - } -} +#[path = "permissions_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/config/permissions_tests.rs b/codex-rs/core/src/config/permissions_tests.rs new file mode 100644 index 0000000000..036c8450cd --- /dev/null +++ b/codex-rs/core/src/config/permissions_tests.rs @@ -0,0 +1,9 @@ +use super::*; +use pretty_assertions::assert_eq; + +#[test] +fn normalize_absolute_path_for_platform_simplifies_windows_verbatim_paths() { + let parsed = + normalize_absolute_path_for_platform(r"\\?\D:\c\x\worktrees\2508\swift-base", true); + assert_eq!(parsed, PathBuf::from(r"D:\c\x\worktrees\2508\swift-base")); +} diff --git a/codex-rs/core/src/config/schema.rs b/codex-rs/core/src/config/schema.rs index 95aea130e6..851f4d19ee 100644 --- a/codex-rs/core/src/config/schema.rs +++ b/codex-rs/core/src/config/schema.rs @@ -96,54 +96,5 @@ pub fn write_config_schema(out_path: &Path) -> anyhow::Result<()> { } #[cfg(test)] -mod tests { - use super::canonicalize; - use super::config_schema_json; - use super::write_config_schema; - - use pretty_assertions::assert_eq; - use similar::TextDiff; - use tempfile::TempDir; - - #[test] - fn config_schema_matches_fixture() { - let fixture_path = codex_utils_cargo_bin::find_resource!("config.schema.json") - .expect("resolve config schema fixture path"); - let fixture = std::fs::read_to_string(fixture_path).expect("read config schema fixture"); - let fixture_value: serde_json::Value = - serde_json::from_str(&fixture).expect("parse config schema fixture"); - let schema_json = config_schema_json().expect("serialize config schema"); - let schema_value: serde_json::Value = - serde_json::from_slice(&schema_json).expect("decode schema json"); - let fixture_value = canonicalize(&fixture_value); - let schema_value = canonicalize(&schema_value); - if fixture_value != schema_value { - let expected = - serde_json::to_string_pretty(&fixture_value).expect("serialize fixture json"); - let actual = - serde_json::to_string_pretty(&schema_value).expect("serialize schema json"); - let diff = TextDiff::from_lines(&expected, &actual) - .unified_diff() - .header("fixture", "generated") - .to_string(); - panic!( - "Current schema for `config.toml` doesn't match the fixture. \ -Run `just write-config-schema` to overwrite with your changes.\n\n{diff}" - ); - } - - // Make sure the version in the repo matches exactly: https://github.com/openai/codex/pull/10977. - let tmp = TempDir::new().expect("create temp dir"); - let tmp_path = tmp.path().join("config.schema.json"); - write_config_schema(&tmp_path).expect("write config schema to temp path"); - let tmp_contents = - std::fs::read_to_string(&tmp_path).expect("read back config schema from temp path"); - #[cfg(windows)] - let fixture = fixture.replace("\r\n", "\n"); - - assert_eq!( - fixture, tmp_contents, - "fixture should match exactly with generated schema" - ); - } -} +#[path = "schema_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/config/schema_tests.rs b/codex-rs/core/src/config/schema_tests.rs new file mode 100644 index 0000000000..6205d43f40 --- /dev/null +++ b/codex-rs/core/src/config/schema_tests.rs @@ -0,0 +1,48 @@ +use super::canonicalize; +use super::config_schema_json; +use super::write_config_schema; + +use pretty_assertions::assert_eq; +use similar::TextDiff; +use tempfile::TempDir; + +#[test] +fn config_schema_matches_fixture() { + let fixture_path = codex_utils_cargo_bin::find_resource!("config.schema.json") + .expect("resolve config schema fixture path"); + let fixture = std::fs::read_to_string(fixture_path).expect("read config schema fixture"); + let fixture_value: serde_json::Value = + serde_json::from_str(&fixture).expect("parse config schema fixture"); + let schema_json = config_schema_json().expect("serialize config schema"); + let schema_value: serde_json::Value = + serde_json::from_slice(&schema_json).expect("decode schema json"); + let fixture_value = canonicalize(&fixture_value); + let schema_value = canonicalize(&schema_value); + if fixture_value != schema_value { + let expected = + serde_json::to_string_pretty(&fixture_value).expect("serialize fixture json"); + let actual = serde_json::to_string_pretty(&schema_value).expect("serialize schema json"); + let diff = TextDiff::from_lines(&expected, &actual) + .unified_diff() + .header("fixture", "generated") + .to_string(); + panic!( + "Current schema for `config.toml` doesn't match the fixture. \ +Run `just write-config-schema` to overwrite with your changes.\n\n{diff}" + ); + } + + // Make sure the version in the repo matches exactly: https://github.com/openai/codex/pull/10977. + let tmp = TempDir::new().expect("create temp dir"); + let tmp_path = tmp.path().join("config.schema.json"); + write_config_schema(&tmp_path).expect("write config schema to temp path"); + let tmp_contents = + std::fs::read_to_string(&tmp_path).expect("read back config schema from temp path"); + #[cfg(windows)] + let fixture = fixture.replace("\r\n", "\n"); + + assert_eq!( + fixture, tmp_contents, + "fixture should match exactly with generated schema" + ); +} diff --git a/codex-rs/core/src/config/service.rs b/codex-rs/core/src/config/service.rs index df344afb40..7d3e2200e9 100644 --- a/codex-rs/core/src/config/service.rs +++ b/codex-rs/core/src/config/service.rs @@ -731,690 +731,5 @@ fn find_effective_layer( } #[cfg(test)] -mod tests { - use super::*; - use anyhow::Result; - use codex_app_server_protocol::AppConfig; - use codex_app_server_protocol::AppToolApproval; - use codex_app_server_protocol::AppsConfig; - use codex_app_server_protocol::AskForApproval; - use codex_utils_absolute_path::AbsolutePathBuf; - use pretty_assertions::assert_eq; - use std::collections::BTreeMap; - use tempfile::tempdir; - - #[test] - fn toml_value_to_item_handles_nested_config_tables() { - let config = r#" -[mcp_servers.docs] -command = "docs-server" - -[mcp_servers.docs.http_headers] -X-Doc = "42" -"#; - - let value: TomlValue = toml::from_str(config).expect("parse config example"); - let item = toml_value_to_item(&value).expect("convert to toml_edit item"); - - let root = item.as_table().expect("root table"); - assert!(!root.is_implicit(), "root table should be explicit"); - - let mcp_servers = root - .get("mcp_servers") - .and_then(TomlItem::as_table) - .expect("mcp_servers table"); - assert!( - !mcp_servers.is_implicit(), - "mcp_servers table should be explicit" - ); - - let docs = mcp_servers - .get("docs") - .and_then(TomlItem::as_table) - .expect("docs table"); - assert_eq!( - docs.get("command") - .and_then(TomlItem::as_value) - .and_then(toml_edit::Value::as_str), - Some("docs-server") - ); - - let http_headers = docs - .get("http_headers") - .and_then(TomlItem::as_table) - .expect("http_headers table"); - assert_eq!( - http_headers - .get("X-Doc") - .and_then(TomlItem::as_value) - .and_then(toml_edit::Value::as_str), - Some("42") - ); - } - - #[tokio::test] - async fn write_value_preserves_comments_and_order() -> Result<()> { - let tmp = tempdir().expect("tempdir"); - let original = r#"# Codex user configuration -model = "gpt-5" -approval_policy = "on-request" - -[notice] -# Preserve this comment -hide_full_access_warning = true - -[features] -unified_exec = true -"#; - std::fs::write(tmp.path().join(CONFIG_TOML_FILE), original)?; - - let service = ConfigService::new_with_defaults(tmp.path().to_path_buf()); - service - .write_value(ConfigValueWriteParams { - file_path: Some(tmp.path().join(CONFIG_TOML_FILE).display().to_string()), - key_path: "features.personality".to_string(), - value: serde_json::json!(true), - merge_strategy: MergeStrategy::Replace, - expected_version: None, - }) - .await - .expect("write succeeds"); - - let updated = - std::fs::read_to_string(tmp.path().join(CONFIG_TOML_FILE)).expect("read config"); - let expected = r#"# Codex user configuration -model = "gpt-5" -approval_policy = "on-request" - -[notice] -# Preserve this comment -hide_full_access_warning = true - -[features] -unified_exec = true -personality = true -"#; - assert_eq!(updated, expected); - Ok(()) - } - - #[tokio::test] - async fn write_value_supports_nested_app_paths() -> Result<()> { - let tmp = tempdir().expect("tempdir"); - std::fs::write(tmp.path().join(CONFIG_TOML_FILE), "")?; - - let service = ConfigService::new_with_defaults(tmp.path().to_path_buf()); - service - .write_value(ConfigValueWriteParams { - file_path: Some(tmp.path().join(CONFIG_TOML_FILE).display().to_string()), - key_path: "apps".to_string(), - value: serde_json::json!({ - "app1": { - "enabled": false, - }, - }), - merge_strategy: MergeStrategy::Replace, - expected_version: None, - }) - .await - .expect("write apps succeeds"); - - service - .write_value(ConfigValueWriteParams { - file_path: Some(tmp.path().join(CONFIG_TOML_FILE).display().to_string()), - key_path: "apps.app1.default_tools_approval_mode".to_string(), - value: serde_json::json!("prompt"), - merge_strategy: MergeStrategy::Replace, - expected_version: None, - }) - .await - .expect("write apps.app1.default_tools_approval_mode succeeds"); - - let read = service - .read(ConfigReadParams { - include_layers: false, - cwd: None, - }) - .await - .expect("config read succeeds"); - - assert_eq!( - read.config.apps, - Some(AppsConfig { - default: None, - apps: std::collections::HashMap::from([( - "app1".to_string(), - AppConfig { - enabled: false, - destructive_enabled: None, - open_world_enabled: None, - default_tools_approval_mode: Some(AppToolApproval::Prompt), - default_tools_enabled: None, - tools: None, - }, - )]), - }) - ); - - Ok(()) - } - - #[tokio::test] - async fn read_includes_origins_and_layers() { - let tmp = tempdir().expect("tempdir"); - let user_path = tmp.path().join(CONFIG_TOML_FILE); - std::fs::write(&user_path, "model = \"user\"").unwrap(); - let user_file = AbsolutePathBuf::try_from(user_path.clone()).expect("user file"); - - let managed_path = tmp.path().join("managed_config.toml"); - std::fs::write(&managed_path, "approval_policy = \"never\"").unwrap(); - let managed_file = AbsolutePathBuf::try_from(managed_path.clone()).expect("managed file"); - - let service = ConfigService::new( - tmp.path().to_path_buf(), - vec![], - LoaderOverrides { - managed_config_path: Some(managed_path.clone()), - #[cfg(target_os = "macos")] - managed_preferences_base64: None, - macos_managed_config_requirements_base64: None, - }, - CloudRequirementsLoader::default(), - ); - - let response = service - .read(ConfigReadParams { - include_layers: true, - cwd: None, - }) - .await - .expect("response"); - - assert_eq!(response.config.approval_policy, Some(AskForApproval::Never)); - - assert_eq!( - response - .origins - .get("approval_policy") - .expect("origin") - .name, - ConfigLayerSource::LegacyManagedConfigTomlFromFile { - file: managed_file.clone() - }, - ); - let layers = response.layers.expect("layers present"); - // Local macOS machines can surface an MDM-managed config layer at the - // top of the stack; ignore it so this test stays focused on file/user/system ordering. - let layers = if matches!( - layers.first().map(|layer| &layer.name), - Some(ConfigLayerSource::LegacyManagedConfigTomlFromMdm) - ) { - &layers[1..] - } else { - layers.as_slice() - }; - assert_eq!(layers.len(), 3, "expected three layers"); - assert_eq!( - layers.first().unwrap().name, - ConfigLayerSource::LegacyManagedConfigTomlFromFile { - file: managed_file.clone() - } - ); - assert_eq!( - layers.get(1).unwrap().name, - ConfigLayerSource::User { - file: user_file.clone() - } - ); - assert!(matches!( - layers.get(2).unwrap().name, - ConfigLayerSource::System { .. } - )); - } - - #[tokio::test] - async fn write_value_reports_override() { - let tmp = tempdir().expect("tempdir"); - std::fs::write( - tmp.path().join(CONFIG_TOML_FILE), - "approval_policy = \"on-request\"", - ) - .unwrap(); - - let managed_path = tmp.path().join("managed_config.toml"); - std::fs::write(&managed_path, "approval_policy = \"never\"").unwrap(); - let managed_file = AbsolutePathBuf::try_from(managed_path.clone()).expect("managed file"); - - let service = ConfigService::new( - tmp.path().to_path_buf(), - vec![], - LoaderOverrides { - managed_config_path: Some(managed_path.clone()), - #[cfg(target_os = "macos")] - managed_preferences_base64: None, - macos_managed_config_requirements_base64: None, - }, - CloudRequirementsLoader::default(), - ); - - let result = service - .write_value(ConfigValueWriteParams { - file_path: Some(tmp.path().join(CONFIG_TOML_FILE).display().to_string()), - key_path: "approval_policy".to_string(), - value: serde_json::json!("never"), - merge_strategy: MergeStrategy::Replace, - expected_version: None, - }) - .await - .expect("result"); - - let read_after = service - .read(ConfigReadParams { - include_layers: true, - cwd: None, - }) - .await - .expect("read"); - assert_eq!( - read_after.config.approval_policy, - Some(AskForApproval::Never) - ); - assert_eq!( - read_after - .origins - .get("approval_policy") - .expect("origin") - .name, - ConfigLayerSource::LegacyManagedConfigTomlFromFile { - file: managed_file.clone() - } - ); - assert_eq!(result.status, WriteStatus::Ok); - assert!(result.overridden_metadata.is_none()); - } - - #[tokio::test] - async fn version_conflict_rejected() { - let tmp = tempdir().expect("tempdir"); - let user_path = tmp.path().join(CONFIG_TOML_FILE); - std::fs::write(&user_path, "model = \"user\"").unwrap(); - - let service = ConfigService::new_with_defaults(tmp.path().to_path_buf()); - let error = service - .write_value(ConfigValueWriteParams { - file_path: Some(tmp.path().join(CONFIG_TOML_FILE).display().to_string()), - key_path: "model".to_string(), - value: serde_json::json!("gpt-5"), - merge_strategy: MergeStrategy::Replace, - expected_version: Some("sha256:bogus".to_string()), - }) - .await - .expect_err("should fail"); - - assert_eq!( - error.write_error_code(), - Some(ConfigWriteErrorCode::ConfigVersionConflict) - ); - } - - #[tokio::test] - async fn write_value_defaults_to_user_config_path() { - let tmp = tempdir().expect("tempdir"); - std::fs::write(tmp.path().join(CONFIG_TOML_FILE), "").unwrap(); - - let service = ConfigService::new_with_defaults(tmp.path().to_path_buf()); - service - .write_value(ConfigValueWriteParams { - file_path: None, - key_path: "model".to_string(), - value: serde_json::json!("gpt-new"), - merge_strategy: MergeStrategy::Replace, - expected_version: None, - }) - .await - .expect("write succeeds"); - - let contents = - std::fs::read_to_string(tmp.path().join(CONFIG_TOML_FILE)).expect("read config"); - assert!( - contents.contains("model = \"gpt-new\""), - "config.toml should be updated even when file_path is omitted" - ); - } - - #[tokio::test] - async fn invalid_user_value_rejected_even_if_overridden_by_managed() { - let tmp = tempdir().expect("tempdir"); - std::fs::write(tmp.path().join(CONFIG_TOML_FILE), "model = \"user\"").unwrap(); - - let managed_path = tmp.path().join("managed_config.toml"); - std::fs::write(&managed_path, "approval_policy = \"never\"").unwrap(); - - let service = ConfigService::new( - tmp.path().to_path_buf(), - vec![], - LoaderOverrides { - managed_config_path: Some(managed_path.clone()), - #[cfg(target_os = "macos")] - managed_preferences_base64: None, - macos_managed_config_requirements_base64: None, - }, - CloudRequirementsLoader::default(), - ); - - let error = service - .write_value(ConfigValueWriteParams { - file_path: Some(tmp.path().join(CONFIG_TOML_FILE).display().to_string()), - key_path: "approval_policy".to_string(), - value: serde_json::json!("bogus"), - merge_strategy: MergeStrategy::Replace, - expected_version: None, - }) - .await - .expect_err("should fail validation"); - - assert_eq!( - error.write_error_code(), - Some(ConfigWriteErrorCode::ConfigValidationError) - ); - - let contents = - std::fs::read_to_string(tmp.path().join(CONFIG_TOML_FILE)).expect("read config"); - assert_eq!(contents.trim(), "model = \"user\""); - } - - #[tokio::test] - async fn write_value_rejects_feature_requirement_conflict() { - let tmp = tempdir().expect("tempdir"); - std::fs::write(tmp.path().join(CONFIG_TOML_FILE), "").unwrap(); - - let service = ConfigService::new( - tmp.path().to_path_buf(), - vec![], - LoaderOverrides { - managed_config_path: None, - #[cfg(target_os = "macos")] - managed_preferences_base64: None, - macos_managed_config_requirements_base64: None, - }, - CloudRequirementsLoader::new(async { - Ok(Some(ConfigRequirementsToml { - feature_requirements: Some(crate::config_loader::FeatureRequirementsToml { - entries: BTreeMap::from([("personality".to_string(), true)]), - }), - ..Default::default() - })) - }), - ); - - let error = service - .write_value(ConfigValueWriteParams { - file_path: Some(tmp.path().join(CONFIG_TOML_FILE).display().to_string()), - key_path: "features.personality".to_string(), - value: serde_json::json!(false), - merge_strategy: MergeStrategy::Replace, - expected_version: None, - }) - .await - .expect_err("conflicting feature write should fail"); - - assert_eq!( - error.write_error_code(), - Some(ConfigWriteErrorCode::ConfigValidationError) - ); - assert!( - error - .to_string() - .contains("invalid value for `features`: `features.personality=false`"), - "{error}" - ); - assert_eq!( - std::fs::read_to_string(tmp.path().join(CONFIG_TOML_FILE)).unwrap(), - "" - ); - } - - #[tokio::test] - async fn write_value_rejects_profile_feature_requirement_conflict() { - let tmp = tempdir().expect("tempdir"); - std::fs::write(tmp.path().join(CONFIG_TOML_FILE), "").unwrap(); - - let service = ConfigService::new( - tmp.path().to_path_buf(), - vec![], - LoaderOverrides { - managed_config_path: None, - #[cfg(target_os = "macos")] - managed_preferences_base64: None, - macos_managed_config_requirements_base64: None, - }, - CloudRequirementsLoader::new(async { - Ok(Some(ConfigRequirementsToml { - feature_requirements: Some(crate::config_loader::FeatureRequirementsToml { - entries: BTreeMap::from([("personality".to_string(), true)]), - }), - ..Default::default() - })) - }), - ); - - let error = service - .write_value(ConfigValueWriteParams { - file_path: Some(tmp.path().join(CONFIG_TOML_FILE).display().to_string()), - key_path: "profiles.enterprise.features.personality".to_string(), - value: serde_json::json!(false), - merge_strategy: MergeStrategy::Replace, - expected_version: None, - }) - .await - .expect_err("conflicting profile feature write should fail"); - - assert_eq!( - error.write_error_code(), - Some(ConfigWriteErrorCode::ConfigValidationError) - ); - assert!( - error.to_string().contains( - "invalid value for `features`: `profiles.enterprise.features.personality=false`" - ), - "{error}" - ); - assert_eq!( - std::fs::read_to_string(tmp.path().join(CONFIG_TOML_FILE)).unwrap(), - "" - ); - } - - #[tokio::test] - async fn read_reports_managed_overrides_user_and_session_flags() { - let tmp = tempdir().expect("tempdir"); - let user_path = tmp.path().join(CONFIG_TOML_FILE); - std::fs::write(&user_path, "model = \"user\"").unwrap(); - let user_file = AbsolutePathBuf::try_from(user_path.clone()).expect("user file"); - - let managed_path = tmp.path().join("managed_config.toml"); - std::fs::write(&managed_path, "model = \"system\"").unwrap(); - let managed_file = AbsolutePathBuf::try_from(managed_path.clone()).expect("managed file"); - - let cli_overrides = vec![( - "model".to_string(), - TomlValue::String("session".to_string()), - )]; - - let service = ConfigService::new( - tmp.path().to_path_buf(), - cli_overrides, - LoaderOverrides { - managed_config_path: Some(managed_path.clone()), - #[cfg(target_os = "macos")] - managed_preferences_base64: None, - macos_managed_config_requirements_base64: None, - }, - CloudRequirementsLoader::default(), - ); - - let response = service - .read(ConfigReadParams { - include_layers: true, - cwd: None, - }) - .await - .expect("response"); - - assert_eq!(response.config.model.as_deref(), Some("system")); - assert_eq!( - response.origins.get("model").expect("origin").name, - ConfigLayerSource::LegacyManagedConfigTomlFromFile { - file: managed_file.clone() - }, - ); - let layers = response.layers.expect("layers"); - // Local macOS machines can surface an MDM-managed config layer at the - // top of the stack; ignore it so this test stays focused on file/session/user ordering. - let layers = if matches!( - layers.first().map(|layer| &layer.name), - Some(ConfigLayerSource::LegacyManagedConfigTomlFromMdm) - ) { - &layers[1..] - } else { - layers.as_slice() - }; - assert_eq!( - layers.first().unwrap().name, - ConfigLayerSource::LegacyManagedConfigTomlFromFile { file: managed_file } - ); - assert_eq!(layers.get(1).unwrap().name, ConfigLayerSource::SessionFlags); - assert_eq!( - layers.get(2).unwrap().name, - ConfigLayerSource::User { file: user_file } - ); - } - - #[tokio::test] - async fn write_value_reports_managed_override() { - let tmp = tempdir().expect("tempdir"); - std::fs::write(tmp.path().join(CONFIG_TOML_FILE), "").unwrap(); - - let managed_path = tmp.path().join("managed_config.toml"); - std::fs::write(&managed_path, "approval_policy = \"never\"").unwrap(); - let managed_file = AbsolutePathBuf::try_from(managed_path.clone()).expect("managed file"); - - let service = ConfigService::new( - tmp.path().to_path_buf(), - vec![], - LoaderOverrides { - managed_config_path: Some(managed_path.clone()), - #[cfg(target_os = "macos")] - managed_preferences_base64: None, - macos_managed_config_requirements_base64: None, - }, - CloudRequirementsLoader::default(), - ); - - let result = service - .write_value(ConfigValueWriteParams { - file_path: Some(tmp.path().join(CONFIG_TOML_FILE).display().to_string()), - key_path: "approval_policy".to_string(), - value: serde_json::json!("on-request"), - merge_strategy: MergeStrategy::Replace, - expected_version: None, - }) - .await - .expect("result"); - - assert_eq!(result.status, WriteStatus::OkOverridden); - let overridden = result.overridden_metadata.expect("overridden metadata"); - assert_eq!( - overridden.overriding_layer.name, - ConfigLayerSource::LegacyManagedConfigTomlFromFile { file: managed_file } - ); - assert_eq!(overridden.effective_value, serde_json::json!("never")); - } - - #[tokio::test] - async fn upsert_merges_tables_replace_overwrites() -> Result<()> { - let tmp = tempdir().expect("tempdir"); - let path = tmp.path().join(CONFIG_TOML_FILE); - let base = r#"[mcp_servers.linear] -bearer_token_env_var = "TOKEN" -name = "linear" -url = "https://linear.example" - -[mcp_servers.linear.env_http_headers] -existing = "keep" - -[mcp_servers.linear.http_headers] -alpha = "a" -"#; - - let overlay = serde_json::json!({ - "bearer_token_env_var": "NEW_TOKEN", - "http_headers": { - "alpha": "updated", - "beta": "b" - }, - "name": "linear", - "url": "https://linear.example" - }); - - std::fs::write(&path, base)?; - - let service = ConfigService::new_with_defaults(tmp.path().to_path_buf()); - service - .write_value(ConfigValueWriteParams { - file_path: Some(path.display().to_string()), - key_path: "mcp_servers.linear".to_string(), - value: overlay.clone(), - merge_strategy: MergeStrategy::Upsert, - expected_version: None, - }) - .await - .expect("upsert succeeds"); - - let upserted: TomlValue = toml::from_str(&std::fs::read_to_string(&path)?)?; - let expected_upsert: TomlValue = toml::from_str( - r#"[mcp_servers.linear] -bearer_token_env_var = "NEW_TOKEN" -name = "linear" -url = "https://linear.example" - -[mcp_servers.linear.env_http_headers] -existing = "keep" - -[mcp_servers.linear.http_headers] -alpha = "updated" -beta = "b" -"#, - )?; - assert_eq!(upserted, expected_upsert); - - std::fs::write(&path, base)?; - - service - .write_value(ConfigValueWriteParams { - file_path: Some(path.display().to_string()), - key_path: "mcp_servers.linear".to_string(), - value: overlay, - merge_strategy: MergeStrategy::Replace, - expected_version: None, - }) - .await - .expect("replace succeeds"); - - let replaced: TomlValue = toml::from_str(&std::fs::read_to_string(&path)?)?; - let expected_replace: TomlValue = toml::from_str( - r#"[mcp_servers.linear] -bearer_token_env_var = "NEW_TOKEN" -name = "linear" -url = "https://linear.example" - -[mcp_servers.linear.http_headers] -alpha = "updated" -beta = "b" -"#, - )?; - assert_eq!(replaced, expected_replace); - - Ok(()) - } -} +#[path = "service_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/config/service_tests.rs b/codex-rs/core/src/config/service_tests.rs new file mode 100644 index 0000000000..a23537e139 --- /dev/null +++ b/codex-rs/core/src/config/service_tests.rs @@ -0,0 +1,682 @@ +use super::*; +use anyhow::Result; +use codex_app_server_protocol::AppConfig; +use codex_app_server_protocol::AppToolApproval; +use codex_app_server_protocol::AppsConfig; +use codex_app_server_protocol::AskForApproval; +use codex_utils_absolute_path::AbsolutePathBuf; +use pretty_assertions::assert_eq; +use std::collections::BTreeMap; +use tempfile::tempdir; + +#[test] +fn toml_value_to_item_handles_nested_config_tables() { + let config = r#" +[mcp_servers.docs] +command = "docs-server" + +[mcp_servers.docs.http_headers] +X-Doc = "42" +"#; + + let value: TomlValue = toml::from_str(config).expect("parse config example"); + let item = toml_value_to_item(&value).expect("convert to toml_edit item"); + + let root = item.as_table().expect("root table"); + assert!(!root.is_implicit(), "root table should be explicit"); + + let mcp_servers = root + .get("mcp_servers") + .and_then(TomlItem::as_table) + .expect("mcp_servers table"); + assert!( + !mcp_servers.is_implicit(), + "mcp_servers table should be explicit" + ); + + let docs = mcp_servers + .get("docs") + .and_then(TomlItem::as_table) + .expect("docs table"); + assert_eq!( + docs.get("command") + .and_then(TomlItem::as_value) + .and_then(toml_edit::Value::as_str), + Some("docs-server") + ); + + let http_headers = docs + .get("http_headers") + .and_then(TomlItem::as_table) + .expect("http_headers table"); + assert_eq!( + http_headers + .get("X-Doc") + .and_then(TomlItem::as_value) + .and_then(toml_edit::Value::as_str), + Some("42") + ); +} + +#[tokio::test] +async fn write_value_preserves_comments_and_order() -> Result<()> { + let tmp = tempdir().expect("tempdir"); + let original = r#"# Codex user configuration +model = "gpt-5" +approval_policy = "on-request" + +[notice] +# Preserve this comment +hide_full_access_warning = true + +[features] +unified_exec = true +"#; + std::fs::write(tmp.path().join(CONFIG_TOML_FILE), original)?; + + let service = ConfigService::new_with_defaults(tmp.path().to_path_buf()); + service + .write_value(ConfigValueWriteParams { + file_path: Some(tmp.path().join(CONFIG_TOML_FILE).display().to_string()), + key_path: "features.personality".to_string(), + value: serde_json::json!(true), + merge_strategy: MergeStrategy::Replace, + expected_version: None, + }) + .await + .expect("write succeeds"); + + let updated = std::fs::read_to_string(tmp.path().join(CONFIG_TOML_FILE)).expect("read config"); + let expected = r#"# Codex user configuration +model = "gpt-5" +approval_policy = "on-request" + +[notice] +# Preserve this comment +hide_full_access_warning = true + +[features] +unified_exec = true +personality = true +"#; + assert_eq!(updated, expected); + Ok(()) +} + +#[tokio::test] +async fn write_value_supports_nested_app_paths() -> Result<()> { + let tmp = tempdir().expect("tempdir"); + std::fs::write(tmp.path().join(CONFIG_TOML_FILE), "")?; + + let service = ConfigService::new_with_defaults(tmp.path().to_path_buf()); + service + .write_value(ConfigValueWriteParams { + file_path: Some(tmp.path().join(CONFIG_TOML_FILE).display().to_string()), + key_path: "apps".to_string(), + value: serde_json::json!({ + "app1": { + "enabled": false, + }, + }), + merge_strategy: MergeStrategy::Replace, + expected_version: None, + }) + .await + .expect("write apps succeeds"); + + service + .write_value(ConfigValueWriteParams { + file_path: Some(tmp.path().join(CONFIG_TOML_FILE).display().to_string()), + key_path: "apps.app1.default_tools_approval_mode".to_string(), + value: serde_json::json!("prompt"), + merge_strategy: MergeStrategy::Replace, + expected_version: None, + }) + .await + .expect("write apps.app1.default_tools_approval_mode succeeds"); + + let read = service + .read(ConfigReadParams { + include_layers: false, + cwd: None, + }) + .await + .expect("config read succeeds"); + + assert_eq!( + read.config.apps, + Some(AppsConfig { + default: None, + apps: std::collections::HashMap::from([( + "app1".to_string(), + AppConfig { + enabled: false, + destructive_enabled: None, + open_world_enabled: None, + default_tools_approval_mode: Some(AppToolApproval::Prompt), + default_tools_enabled: None, + tools: None, + }, + )]), + }) + ); + + Ok(()) +} + +#[tokio::test] +async fn read_includes_origins_and_layers() { + let tmp = tempdir().expect("tempdir"); + let user_path = tmp.path().join(CONFIG_TOML_FILE); + std::fs::write(&user_path, "model = \"user\"").unwrap(); + let user_file = AbsolutePathBuf::try_from(user_path.clone()).expect("user file"); + + let managed_path = tmp.path().join("managed_config.toml"); + std::fs::write(&managed_path, "approval_policy = \"never\"").unwrap(); + let managed_file = AbsolutePathBuf::try_from(managed_path.clone()).expect("managed file"); + + let service = ConfigService::new( + tmp.path().to_path_buf(), + vec![], + LoaderOverrides { + managed_config_path: Some(managed_path.clone()), + #[cfg(target_os = "macos")] + managed_preferences_base64: None, + macos_managed_config_requirements_base64: None, + }, + CloudRequirementsLoader::default(), + ); + + let response = service + .read(ConfigReadParams { + include_layers: true, + cwd: None, + }) + .await + .expect("response"); + + assert_eq!(response.config.approval_policy, Some(AskForApproval::Never)); + + assert_eq!( + response + .origins + .get("approval_policy") + .expect("origin") + .name, + ConfigLayerSource::LegacyManagedConfigTomlFromFile { + file: managed_file.clone() + }, + ); + let layers = response.layers.expect("layers present"); + // Local macOS machines can surface an MDM-managed config layer at the + // top of the stack; ignore it so this test stays focused on file/user/system ordering. + let layers = if matches!( + layers.first().map(|layer| &layer.name), + Some(ConfigLayerSource::LegacyManagedConfigTomlFromMdm) + ) { + &layers[1..] + } else { + layers.as_slice() + }; + assert_eq!(layers.len(), 3, "expected three layers"); + assert_eq!( + layers.first().unwrap().name, + ConfigLayerSource::LegacyManagedConfigTomlFromFile { + file: managed_file.clone() + } + ); + assert_eq!( + layers.get(1).unwrap().name, + ConfigLayerSource::User { + file: user_file.clone() + } + ); + assert!(matches!( + layers.get(2).unwrap().name, + ConfigLayerSource::System { .. } + )); +} + +#[tokio::test] +async fn write_value_reports_override() { + let tmp = tempdir().expect("tempdir"); + std::fs::write( + tmp.path().join(CONFIG_TOML_FILE), + "approval_policy = \"on-request\"", + ) + .unwrap(); + + let managed_path = tmp.path().join("managed_config.toml"); + std::fs::write(&managed_path, "approval_policy = \"never\"").unwrap(); + let managed_file = AbsolutePathBuf::try_from(managed_path.clone()).expect("managed file"); + + let service = ConfigService::new( + tmp.path().to_path_buf(), + vec![], + LoaderOverrides { + managed_config_path: Some(managed_path.clone()), + #[cfg(target_os = "macos")] + managed_preferences_base64: None, + macos_managed_config_requirements_base64: None, + }, + CloudRequirementsLoader::default(), + ); + + let result = service + .write_value(ConfigValueWriteParams { + file_path: Some(tmp.path().join(CONFIG_TOML_FILE).display().to_string()), + key_path: "approval_policy".to_string(), + value: serde_json::json!("never"), + merge_strategy: MergeStrategy::Replace, + expected_version: None, + }) + .await + .expect("result"); + + let read_after = service + .read(ConfigReadParams { + include_layers: true, + cwd: None, + }) + .await + .expect("read"); + assert_eq!( + read_after.config.approval_policy, + Some(AskForApproval::Never) + ); + assert_eq!( + read_after + .origins + .get("approval_policy") + .expect("origin") + .name, + ConfigLayerSource::LegacyManagedConfigTomlFromFile { + file: managed_file.clone() + } + ); + assert_eq!(result.status, WriteStatus::Ok); + assert!(result.overridden_metadata.is_none()); +} + +#[tokio::test] +async fn version_conflict_rejected() { + let tmp = tempdir().expect("tempdir"); + let user_path = tmp.path().join(CONFIG_TOML_FILE); + std::fs::write(&user_path, "model = \"user\"").unwrap(); + + let service = ConfigService::new_with_defaults(tmp.path().to_path_buf()); + let error = service + .write_value(ConfigValueWriteParams { + file_path: Some(tmp.path().join(CONFIG_TOML_FILE).display().to_string()), + key_path: "model".to_string(), + value: serde_json::json!("gpt-5"), + merge_strategy: MergeStrategy::Replace, + expected_version: Some("sha256:bogus".to_string()), + }) + .await + .expect_err("should fail"); + + assert_eq!( + error.write_error_code(), + Some(ConfigWriteErrorCode::ConfigVersionConflict) + ); +} + +#[tokio::test] +async fn write_value_defaults_to_user_config_path() { + let tmp = tempdir().expect("tempdir"); + std::fs::write(tmp.path().join(CONFIG_TOML_FILE), "").unwrap(); + + let service = ConfigService::new_with_defaults(tmp.path().to_path_buf()); + service + .write_value(ConfigValueWriteParams { + file_path: None, + key_path: "model".to_string(), + value: serde_json::json!("gpt-new"), + merge_strategy: MergeStrategy::Replace, + expected_version: None, + }) + .await + .expect("write succeeds"); + + let contents = std::fs::read_to_string(tmp.path().join(CONFIG_TOML_FILE)).expect("read config"); + assert!( + contents.contains("model = \"gpt-new\""), + "config.toml should be updated even when file_path is omitted" + ); +} + +#[tokio::test] +async fn invalid_user_value_rejected_even_if_overridden_by_managed() { + let tmp = tempdir().expect("tempdir"); + std::fs::write(tmp.path().join(CONFIG_TOML_FILE), "model = \"user\"").unwrap(); + + let managed_path = tmp.path().join("managed_config.toml"); + std::fs::write(&managed_path, "approval_policy = \"never\"").unwrap(); + + let service = ConfigService::new( + tmp.path().to_path_buf(), + vec![], + LoaderOverrides { + managed_config_path: Some(managed_path.clone()), + #[cfg(target_os = "macos")] + managed_preferences_base64: None, + macos_managed_config_requirements_base64: None, + }, + CloudRequirementsLoader::default(), + ); + + let error = service + .write_value(ConfigValueWriteParams { + file_path: Some(tmp.path().join(CONFIG_TOML_FILE).display().to_string()), + key_path: "approval_policy".to_string(), + value: serde_json::json!("bogus"), + merge_strategy: MergeStrategy::Replace, + expected_version: None, + }) + .await + .expect_err("should fail validation"); + + assert_eq!( + error.write_error_code(), + Some(ConfigWriteErrorCode::ConfigValidationError) + ); + + let contents = std::fs::read_to_string(tmp.path().join(CONFIG_TOML_FILE)).expect("read config"); + assert_eq!(contents.trim(), "model = \"user\""); +} + +#[tokio::test] +async fn write_value_rejects_feature_requirement_conflict() { + let tmp = tempdir().expect("tempdir"); + std::fs::write(tmp.path().join(CONFIG_TOML_FILE), "").unwrap(); + + let service = ConfigService::new( + tmp.path().to_path_buf(), + vec![], + LoaderOverrides { + managed_config_path: None, + #[cfg(target_os = "macos")] + managed_preferences_base64: None, + macos_managed_config_requirements_base64: None, + }, + CloudRequirementsLoader::new(async { + Ok(Some(ConfigRequirementsToml { + feature_requirements: Some(crate::config_loader::FeatureRequirementsToml { + entries: BTreeMap::from([("personality".to_string(), true)]), + }), + ..Default::default() + })) + }), + ); + + let error = service + .write_value(ConfigValueWriteParams { + file_path: Some(tmp.path().join(CONFIG_TOML_FILE).display().to_string()), + key_path: "features.personality".to_string(), + value: serde_json::json!(false), + merge_strategy: MergeStrategy::Replace, + expected_version: None, + }) + .await + .expect_err("conflicting feature write should fail"); + + assert_eq!( + error.write_error_code(), + Some(ConfigWriteErrorCode::ConfigValidationError) + ); + assert!( + error + .to_string() + .contains("invalid value for `features`: `features.personality=false`"), + "{error}" + ); + assert_eq!( + std::fs::read_to_string(tmp.path().join(CONFIG_TOML_FILE)).unwrap(), + "" + ); +} + +#[tokio::test] +async fn write_value_rejects_profile_feature_requirement_conflict() { + let tmp = tempdir().expect("tempdir"); + std::fs::write(tmp.path().join(CONFIG_TOML_FILE), "").unwrap(); + + let service = ConfigService::new( + tmp.path().to_path_buf(), + vec![], + LoaderOverrides { + managed_config_path: None, + #[cfg(target_os = "macos")] + managed_preferences_base64: None, + macos_managed_config_requirements_base64: None, + }, + CloudRequirementsLoader::new(async { + Ok(Some(ConfigRequirementsToml { + feature_requirements: Some(crate::config_loader::FeatureRequirementsToml { + entries: BTreeMap::from([("personality".to_string(), true)]), + }), + ..Default::default() + })) + }), + ); + + let error = service + .write_value(ConfigValueWriteParams { + file_path: Some(tmp.path().join(CONFIG_TOML_FILE).display().to_string()), + key_path: "profiles.enterprise.features.personality".to_string(), + value: serde_json::json!(false), + merge_strategy: MergeStrategy::Replace, + expected_version: None, + }) + .await + .expect_err("conflicting profile feature write should fail"); + + assert_eq!( + error.write_error_code(), + Some(ConfigWriteErrorCode::ConfigValidationError) + ); + assert!( + error.to_string().contains( + "invalid value for `features`: `profiles.enterprise.features.personality=false`" + ), + "{error}" + ); + assert_eq!( + std::fs::read_to_string(tmp.path().join(CONFIG_TOML_FILE)).unwrap(), + "" + ); +} + +#[tokio::test] +async fn read_reports_managed_overrides_user_and_session_flags() { + let tmp = tempdir().expect("tempdir"); + let user_path = tmp.path().join(CONFIG_TOML_FILE); + std::fs::write(&user_path, "model = \"user\"").unwrap(); + let user_file = AbsolutePathBuf::try_from(user_path.clone()).expect("user file"); + + let managed_path = tmp.path().join("managed_config.toml"); + std::fs::write(&managed_path, "model = \"system\"").unwrap(); + let managed_file = AbsolutePathBuf::try_from(managed_path.clone()).expect("managed file"); + + let cli_overrides = vec![( + "model".to_string(), + TomlValue::String("session".to_string()), + )]; + + let service = ConfigService::new( + tmp.path().to_path_buf(), + cli_overrides, + LoaderOverrides { + managed_config_path: Some(managed_path.clone()), + #[cfg(target_os = "macos")] + managed_preferences_base64: None, + macos_managed_config_requirements_base64: None, + }, + CloudRequirementsLoader::default(), + ); + + let response = service + .read(ConfigReadParams { + include_layers: true, + cwd: None, + }) + .await + .expect("response"); + + assert_eq!(response.config.model.as_deref(), Some("system")); + assert_eq!( + response.origins.get("model").expect("origin").name, + ConfigLayerSource::LegacyManagedConfigTomlFromFile { + file: managed_file.clone() + }, + ); + let layers = response.layers.expect("layers"); + // Local macOS machines can surface an MDM-managed config layer at the + // top of the stack; ignore it so this test stays focused on file/session/user ordering. + let layers = if matches!( + layers.first().map(|layer| &layer.name), + Some(ConfigLayerSource::LegacyManagedConfigTomlFromMdm) + ) { + &layers[1..] + } else { + layers.as_slice() + }; + assert_eq!( + layers.first().unwrap().name, + ConfigLayerSource::LegacyManagedConfigTomlFromFile { file: managed_file } + ); + assert_eq!(layers.get(1).unwrap().name, ConfigLayerSource::SessionFlags); + assert_eq!( + layers.get(2).unwrap().name, + ConfigLayerSource::User { file: user_file } + ); +} + +#[tokio::test] +async fn write_value_reports_managed_override() { + let tmp = tempdir().expect("tempdir"); + std::fs::write(tmp.path().join(CONFIG_TOML_FILE), "").unwrap(); + + let managed_path = tmp.path().join("managed_config.toml"); + std::fs::write(&managed_path, "approval_policy = \"never\"").unwrap(); + let managed_file = AbsolutePathBuf::try_from(managed_path.clone()).expect("managed file"); + + let service = ConfigService::new( + tmp.path().to_path_buf(), + vec![], + LoaderOverrides { + managed_config_path: Some(managed_path.clone()), + #[cfg(target_os = "macos")] + managed_preferences_base64: None, + macos_managed_config_requirements_base64: None, + }, + CloudRequirementsLoader::default(), + ); + + let result = service + .write_value(ConfigValueWriteParams { + file_path: Some(tmp.path().join(CONFIG_TOML_FILE).display().to_string()), + key_path: "approval_policy".to_string(), + value: serde_json::json!("on-request"), + merge_strategy: MergeStrategy::Replace, + expected_version: None, + }) + .await + .expect("result"); + + assert_eq!(result.status, WriteStatus::OkOverridden); + let overridden = result.overridden_metadata.expect("overridden metadata"); + assert_eq!( + overridden.overriding_layer.name, + ConfigLayerSource::LegacyManagedConfigTomlFromFile { file: managed_file } + ); + assert_eq!(overridden.effective_value, serde_json::json!("never")); +} + +#[tokio::test] +async fn upsert_merges_tables_replace_overwrites() -> Result<()> { + let tmp = tempdir().expect("tempdir"); + let path = tmp.path().join(CONFIG_TOML_FILE); + let base = r#"[mcp_servers.linear] +bearer_token_env_var = "TOKEN" +name = "linear" +url = "https://linear.example" + +[mcp_servers.linear.env_http_headers] +existing = "keep" + +[mcp_servers.linear.http_headers] +alpha = "a" +"#; + + let overlay = serde_json::json!({ + "bearer_token_env_var": "NEW_TOKEN", + "http_headers": { + "alpha": "updated", + "beta": "b" + }, + "name": "linear", + "url": "https://linear.example" + }); + + std::fs::write(&path, base)?; + + let service = ConfigService::new_with_defaults(tmp.path().to_path_buf()); + service + .write_value(ConfigValueWriteParams { + file_path: Some(path.display().to_string()), + key_path: "mcp_servers.linear".to_string(), + value: overlay.clone(), + merge_strategy: MergeStrategy::Upsert, + expected_version: None, + }) + .await + .expect("upsert succeeds"); + + let upserted: TomlValue = toml::from_str(&std::fs::read_to_string(&path)?)?; + let expected_upsert: TomlValue = toml::from_str( + r#"[mcp_servers.linear] +bearer_token_env_var = "NEW_TOKEN" +name = "linear" +url = "https://linear.example" + +[mcp_servers.linear.env_http_headers] +existing = "keep" + +[mcp_servers.linear.http_headers] +alpha = "updated" +beta = "b" +"#, + )?; + assert_eq!(upserted, expected_upsert); + + std::fs::write(&path, base)?; + + service + .write_value(ConfigValueWriteParams { + file_path: Some(path.display().to_string()), + key_path: "mcp_servers.linear".to_string(), + value: overlay, + merge_strategy: MergeStrategy::Replace, + expected_version: None, + }) + .await + .expect("replace succeeds"); + + let replaced: TomlValue = toml::from_str(&std::fs::read_to_string(&path)?)?; + let expected_replace: TomlValue = toml::from_str( + r#"[mcp_servers.linear] +bearer_token_env_var = "NEW_TOKEN" +name = "linear" +url = "https://linear.example" + +[mcp_servers.linear.http_headers] +alpha = "updated" +beta = "b" +"#, + )?; + assert_eq!(replaced, expected_replace); + + Ok(()) +} diff --git a/codex-rs/core/src/config/types.rs b/codex-rs/core/src/config/types.rs index d3e3542c57..68ef2a630e 100644 --- a/codex-rs/core/src/config/types.rs +++ b/codex-rs/core/src/config/types.rs @@ -940,320 +940,5 @@ impl Default for ShellEnvironmentPolicy { } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - - #[test] - fn deserialize_stdio_command_server_config() { - let cfg: McpServerConfig = toml::from_str( - r#" - command = "echo" - "#, - ) - .expect("should deserialize command config"); - - assert_eq!( - cfg.transport, - McpServerTransportConfig::Stdio { - command: "echo".to_string(), - args: vec![], - env: None, - env_vars: Vec::new(), - cwd: None, - } - ); - assert!(cfg.enabled); - assert!(!cfg.required); - assert!(cfg.enabled_tools.is_none()); - assert!(cfg.disabled_tools.is_none()); - } - - #[test] - fn deserialize_stdio_command_server_config_with_args() { - let cfg: McpServerConfig = toml::from_str( - r#" - command = "echo" - args = ["hello", "world"] - "#, - ) - .expect("should deserialize command config"); - - assert_eq!( - cfg.transport, - McpServerTransportConfig::Stdio { - command: "echo".to_string(), - args: vec!["hello".to_string(), "world".to_string()], - env: None, - env_vars: Vec::new(), - cwd: None, - } - ); - assert!(cfg.enabled); - } - - #[test] - fn deserialize_stdio_command_server_config_with_arg_with_args_and_env() { - let cfg: McpServerConfig = toml::from_str( - r#" - command = "echo" - args = ["hello", "world"] - env = { "FOO" = "BAR" } - "#, - ) - .expect("should deserialize command config"); - - assert_eq!( - cfg.transport, - McpServerTransportConfig::Stdio { - command: "echo".to_string(), - args: vec!["hello".to_string(), "world".to_string()], - env: Some(HashMap::from([("FOO".to_string(), "BAR".to_string())])), - env_vars: Vec::new(), - cwd: None, - } - ); - assert!(cfg.enabled); - } - - #[test] - fn deserialize_stdio_command_server_config_with_env_vars() { - let cfg: McpServerConfig = toml::from_str( - r#" - command = "echo" - env_vars = ["FOO", "BAR"] - "#, - ) - .expect("should deserialize command config with env_vars"); - - assert_eq!( - cfg.transport, - McpServerTransportConfig::Stdio { - command: "echo".to_string(), - args: vec![], - env: None, - env_vars: vec!["FOO".to_string(), "BAR".to_string()], - cwd: None, - } - ); - } - - #[test] - fn deserialize_stdio_command_server_config_with_cwd() { - let cfg: McpServerConfig = toml::from_str( - r#" - command = "echo" - cwd = "/tmp" - "#, - ) - .expect("should deserialize command config with cwd"); - - assert_eq!( - cfg.transport, - McpServerTransportConfig::Stdio { - command: "echo".to_string(), - args: vec![], - env: None, - env_vars: Vec::new(), - cwd: Some(PathBuf::from("/tmp")), - } - ); - } - - #[test] - fn deserialize_disabled_server_config() { - let cfg: McpServerConfig = toml::from_str( - r#" - command = "echo" - enabled = false - "#, - ) - .expect("should deserialize disabled server config"); - - assert!(!cfg.enabled); - assert!(!cfg.required); - } - - #[test] - fn deserialize_required_server_config() { - let cfg: McpServerConfig = toml::from_str( - r#" - command = "echo" - required = true - "#, - ) - .expect("should deserialize required server config"); - - assert!(cfg.required); - } - - #[test] - fn deserialize_streamable_http_server_config() { - let cfg: McpServerConfig = toml::from_str( - r#" - url = "https://example.com/mcp" - "#, - ) - .expect("should deserialize http config"); - - assert_eq!( - cfg.transport, - McpServerTransportConfig::StreamableHttp { - url: "https://example.com/mcp".to_string(), - bearer_token_env_var: None, - http_headers: None, - env_http_headers: None, - } - ); - assert!(cfg.enabled); - } - - #[test] - fn deserialize_streamable_http_server_config_with_env_var() { - let cfg: McpServerConfig = toml::from_str( - r#" - url = "https://example.com/mcp" - bearer_token_env_var = "GITHUB_TOKEN" - "#, - ) - .expect("should deserialize http config"); - - assert_eq!( - cfg.transport, - McpServerTransportConfig::StreamableHttp { - url: "https://example.com/mcp".to_string(), - bearer_token_env_var: Some("GITHUB_TOKEN".to_string()), - http_headers: None, - env_http_headers: None, - } - ); - assert!(cfg.enabled); - } - - #[test] - fn deserialize_streamable_http_server_config_with_headers() { - let cfg: McpServerConfig = toml::from_str( - r#" - url = "https://example.com/mcp" - http_headers = { "X-Foo" = "bar" } - env_http_headers = { "X-Token" = "TOKEN_ENV" } - "#, - ) - .expect("should deserialize http config with headers"); - - assert_eq!( - cfg.transport, - McpServerTransportConfig::StreamableHttp { - url: "https://example.com/mcp".to_string(), - bearer_token_env_var: None, - http_headers: Some(HashMap::from([("X-Foo".to_string(), "bar".to_string())])), - env_http_headers: Some(HashMap::from([( - "X-Token".to_string(), - "TOKEN_ENV".to_string() - )])), - } - ); - } - - #[test] - fn deserialize_streamable_http_server_config_with_oauth_resource() { - let cfg: McpServerConfig = toml::from_str( - r#" - url = "https://example.com/mcp" - oauth_resource = "https://api.example.com" - "#, - ) - .expect("should deserialize http config with oauth_resource"); - - assert_eq!( - cfg.oauth_resource, - Some("https://api.example.com".to_string()) - ); - } - - #[test] - fn deserialize_server_config_with_tool_filters() { - let cfg: McpServerConfig = toml::from_str( - r#" - command = "echo" - enabled_tools = ["allowed"] - disabled_tools = ["blocked"] - "#, - ) - .expect("should deserialize tool filters"); - - assert_eq!(cfg.enabled_tools, Some(vec!["allowed".to_string()])); - assert_eq!(cfg.disabled_tools, Some(vec!["blocked".to_string()])); - } - - #[test] - fn deserialize_rejects_command_and_url() { - toml::from_str::( - r#" - command = "echo" - url = "https://example.com" - "#, - ) - .expect_err("should reject command+url"); - } - - #[test] - fn deserialize_rejects_env_for_http_transport() { - toml::from_str::( - r#" - url = "https://example.com" - env = { "FOO" = "BAR" } - "#, - ) - .expect_err("should reject env for http transport"); - } - - #[test] - fn deserialize_rejects_headers_for_stdio() { - toml::from_str::( - r#" - command = "echo" - http_headers = { "X-Foo" = "bar" } - "#, - ) - .expect_err("should reject http_headers for stdio transport"); - - toml::from_str::( - r#" - command = "echo" - env_http_headers = { "X-Foo" = "BAR_ENV" } - "#, - ) - .expect_err("should reject env_http_headers for stdio transport"); - - let err = toml::from_str::( - r#" - command = "echo" - oauth_resource = "https://api.example.com" - "#, - ) - .expect_err("should reject oauth_resource for stdio transport"); - - assert!( - err.to_string() - .contains("oauth_resource is not supported for stdio"), - "unexpected error: {err}" - ); - } - - #[test] - fn deserialize_rejects_inline_bearer_token_field() { - let err = toml::from_str::( - r#" - url = "https://example.com" - bearer_token = "secret" - "#, - ) - .expect_err("should reject bearer_token field"); - - assert!( - err.to_string().contains("bearer_token is not supported"), - "unexpected error: {err}" - ); - } -} +#[path = "types_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/config/types_tests.rs b/codex-rs/core/src/config/types_tests.rs new file mode 100644 index 0000000000..adb65e1673 --- /dev/null +++ b/codex-rs/core/src/config/types_tests.rs @@ -0,0 +1,315 @@ +use super::*; +use pretty_assertions::assert_eq; + +#[test] +fn deserialize_stdio_command_server_config() { + let cfg: McpServerConfig = toml::from_str( + r#" + command = "echo" + "#, + ) + .expect("should deserialize command config"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::Stdio { + command: "echo".to_string(), + args: vec![], + env: None, + env_vars: Vec::new(), + cwd: None, + } + ); + assert!(cfg.enabled); + assert!(!cfg.required); + assert!(cfg.enabled_tools.is_none()); + assert!(cfg.disabled_tools.is_none()); +} + +#[test] +fn deserialize_stdio_command_server_config_with_args() { + let cfg: McpServerConfig = toml::from_str( + r#" + command = "echo" + args = ["hello", "world"] + "#, + ) + .expect("should deserialize command config"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::Stdio { + command: "echo".to_string(), + args: vec!["hello".to_string(), "world".to_string()], + env: None, + env_vars: Vec::new(), + cwd: None, + } + ); + assert!(cfg.enabled); +} + +#[test] +fn deserialize_stdio_command_server_config_with_arg_with_args_and_env() { + let cfg: McpServerConfig = toml::from_str( + r#" + command = "echo" + args = ["hello", "world"] + env = { "FOO" = "BAR" } + "#, + ) + .expect("should deserialize command config"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::Stdio { + command: "echo".to_string(), + args: vec!["hello".to_string(), "world".to_string()], + env: Some(HashMap::from([("FOO".to_string(), "BAR".to_string())])), + env_vars: Vec::new(), + cwd: None, + } + ); + assert!(cfg.enabled); +} + +#[test] +fn deserialize_stdio_command_server_config_with_env_vars() { + let cfg: McpServerConfig = toml::from_str( + r#" + command = "echo" + env_vars = ["FOO", "BAR"] + "#, + ) + .expect("should deserialize command config with env_vars"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::Stdio { + command: "echo".to_string(), + args: vec![], + env: None, + env_vars: vec!["FOO".to_string(), "BAR".to_string()], + cwd: None, + } + ); +} + +#[test] +fn deserialize_stdio_command_server_config_with_cwd() { + let cfg: McpServerConfig = toml::from_str( + r#" + command = "echo" + cwd = "/tmp" + "#, + ) + .expect("should deserialize command config with cwd"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::Stdio { + command: "echo".to_string(), + args: vec![], + env: None, + env_vars: Vec::new(), + cwd: Some(PathBuf::from("/tmp")), + } + ); +} + +#[test] +fn deserialize_disabled_server_config() { + let cfg: McpServerConfig = toml::from_str( + r#" + command = "echo" + enabled = false + "#, + ) + .expect("should deserialize disabled server config"); + + assert!(!cfg.enabled); + assert!(!cfg.required); +} + +#[test] +fn deserialize_required_server_config() { + let cfg: McpServerConfig = toml::from_str( + r#" + command = "echo" + required = true + "#, + ) + .expect("should deserialize required server config"); + + assert!(cfg.required); +} + +#[test] +fn deserialize_streamable_http_server_config() { + let cfg: McpServerConfig = toml::from_str( + r#" + url = "https://example.com/mcp" + "#, + ) + .expect("should deserialize http config"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::StreamableHttp { + url: "https://example.com/mcp".to_string(), + bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, + } + ); + assert!(cfg.enabled); +} + +#[test] +fn deserialize_streamable_http_server_config_with_env_var() { + let cfg: McpServerConfig = toml::from_str( + r#" + url = "https://example.com/mcp" + bearer_token_env_var = "GITHUB_TOKEN" + "#, + ) + .expect("should deserialize http config"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::StreamableHttp { + url: "https://example.com/mcp".to_string(), + bearer_token_env_var: Some("GITHUB_TOKEN".to_string()), + http_headers: None, + env_http_headers: None, + } + ); + assert!(cfg.enabled); +} + +#[test] +fn deserialize_streamable_http_server_config_with_headers() { + let cfg: McpServerConfig = toml::from_str( + r#" + url = "https://example.com/mcp" + http_headers = { "X-Foo" = "bar" } + env_http_headers = { "X-Token" = "TOKEN_ENV" } + "#, + ) + .expect("should deserialize http config with headers"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::StreamableHttp { + url: "https://example.com/mcp".to_string(), + bearer_token_env_var: None, + http_headers: Some(HashMap::from([("X-Foo".to_string(), "bar".to_string())])), + env_http_headers: Some(HashMap::from([( + "X-Token".to_string(), + "TOKEN_ENV".to_string() + )])), + } + ); +} + +#[test] +fn deserialize_streamable_http_server_config_with_oauth_resource() { + let cfg: McpServerConfig = toml::from_str( + r#" + url = "https://example.com/mcp" + oauth_resource = "https://api.example.com" + "#, + ) + .expect("should deserialize http config with oauth_resource"); + + assert_eq!( + cfg.oauth_resource, + Some("https://api.example.com".to_string()) + ); +} + +#[test] +fn deserialize_server_config_with_tool_filters() { + let cfg: McpServerConfig = toml::from_str( + r#" + command = "echo" + enabled_tools = ["allowed"] + disabled_tools = ["blocked"] + "#, + ) + .expect("should deserialize tool filters"); + + assert_eq!(cfg.enabled_tools, Some(vec!["allowed".to_string()])); + assert_eq!(cfg.disabled_tools, Some(vec!["blocked".to_string()])); +} + +#[test] +fn deserialize_rejects_command_and_url() { + toml::from_str::( + r#" + command = "echo" + url = "https://example.com" + "#, + ) + .expect_err("should reject command+url"); +} + +#[test] +fn deserialize_rejects_env_for_http_transport() { + toml::from_str::( + r#" + url = "https://example.com" + env = { "FOO" = "BAR" } + "#, + ) + .expect_err("should reject env for http transport"); +} + +#[test] +fn deserialize_rejects_headers_for_stdio() { + toml::from_str::( + r#" + command = "echo" + http_headers = { "X-Foo" = "bar" } + "#, + ) + .expect_err("should reject http_headers for stdio transport"); + + toml::from_str::( + r#" + command = "echo" + env_http_headers = { "X-Foo" = "BAR_ENV" } + "#, + ) + .expect_err("should reject env_http_headers for stdio transport"); + + let err = toml::from_str::( + r#" + command = "echo" + oauth_resource = "https://api.example.com" + "#, + ) + .expect_err("should reject oauth_resource for stdio transport"); + + assert!( + err.to_string() + .contains("oauth_resource is not supported for stdio"), + "unexpected error: {err}" + ); +} + +#[test] +fn deserialize_rejects_inline_bearer_token_field() { + let err = toml::from_str::( + r#" + url = "https://example.com" + bearer_token = "secret" + "#, + ) + .expect_err("should reject bearer_token field"); + + assert!( + err.to_string().contains("bearer_token is not supported"), + "unexpected error: {err}" + ); +} diff --git a/codex-rs/core/src/connectors.rs b/codex-rs/core/src/connectors.rs index 55318bc3d5..cced54174a 100644 --- a/codex-rs/core/src/connectors.rs +++ b/codex-rs/core/src/connectors.rs @@ -891,778 +891,5 @@ fn format_connector_label(name: &str, _id: &str) -> String { } #[cfg(test)] -mod tests { - use super::*; - use crate::config::ConfigBuilder; - use crate::config::types::AppConfig; - use crate::config::types::AppToolConfig; - use crate::config::types::AppToolsConfig; - use crate::config::types::AppsDefaultConfig; - use crate::features::Feature; - use crate::mcp::CODEX_APPS_MCP_SERVER_NAME; - use crate::mcp_connection_manager::ToolInfo; - use pretty_assertions::assert_eq; - use rmcp::model::JsonObject; - use rmcp::model::Tool; - use std::collections::HashMap; - use std::sync::Arc; - use tempfile::tempdir; - - fn annotations( - destructive_hint: Option, - open_world_hint: Option, - ) -> ToolAnnotations { - ToolAnnotations { - destructive_hint, - idempotent_hint: None, - open_world_hint, - read_only_hint: None, - title: None, - } - } - - fn app(id: &str) -> AppInfo { - AppInfo { - id: id.to_string(), - name: id.to_string(), - description: None, - logo_url: None, - logo_url_dark: None, - distribution_channel: None, - install_url: None, - branding: None, - app_metadata: None, - labels: None, - is_accessible: false, - is_enabled: true, - plugin_display_names: Vec::new(), - } - } - - fn named_app(id: &str, name: &str) -> AppInfo { - AppInfo { - id: id.to_string(), - name: name.to_string(), - install_url: Some(connector_install_url(name, id)), - ..app(id) - } - } - - fn plugin_names(names: &[&str]) -> Vec { - names.iter().map(ToString::to_string).collect() - } - - fn test_tool_definition(tool_name: &str) -> Tool { - Tool { - name: tool_name.to_string().into(), - title: None, - description: None, - input_schema: Arc::new(JsonObject::default()), - output_schema: None, - annotations: None, - execution: None, - icons: None, - meta: None, - } - } - - fn google_calendar_accessible_connector(plugin_display_names: &[&str]) -> AppInfo { - AppInfo { - id: "calendar".to_string(), - name: "Google Calendar".to_string(), - description: Some("Plan events".to_string()), - logo_url: Some("https://example.com/logo.png".to_string()), - logo_url_dark: Some("https://example.com/logo-dark.png".to_string()), - distribution_channel: Some("workspace".to_string()), - branding: None, - app_metadata: None, - labels: None, - install_url: None, - is_accessible: true, - is_enabled: true, - plugin_display_names: plugin_names(plugin_display_names), - } - } - - fn codex_app_tool( - tool_name: &str, - connector_id: &str, - connector_name: Option<&str>, - plugin_display_names: &[&str], - ) -> ToolInfo { - let tool_namespace = connector_name - .map(sanitize_name) - .map(|connector_name| format!("mcp__{CODEX_APPS_MCP_SERVER_NAME}__{connector_name}")) - .unwrap_or_else(|| CODEX_APPS_MCP_SERVER_NAME.to_string()); - - ToolInfo { - server_name: CODEX_APPS_MCP_SERVER_NAME.to_string(), - tool_name: tool_name.to_string(), - tool_namespace, - tool: test_tool_definition(tool_name), - connector_id: Some(connector_id.to_string()), - connector_name: connector_name.map(ToOwned::to_owned), - connector_description: None, - plugin_display_names: plugin_names(plugin_display_names), - } - } - - fn with_accessible_connectors_cache_cleared(f: impl FnOnce() -> R) -> R { - let previous = { - let mut cache_guard = ACCESSIBLE_CONNECTORS_CACHE - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - cache_guard.take() - }; - let result = f(); - let mut cache_guard = ACCESSIBLE_CONNECTORS_CACHE - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - *cache_guard = previous; - result - } - - #[test] - fn merge_connectors_replaces_plugin_placeholder_name_with_accessible_name() { - let plugin = plugin_app_to_app_info(AppConnectorId("calendar".to_string())); - let accessible = google_calendar_accessible_connector(&[]); - - let merged = merge_connectors(vec![plugin], vec![accessible]); - - assert_eq!( - merged, - vec![AppInfo { - id: "calendar".to_string(), - name: "Google Calendar".to_string(), - description: Some("Plan events".to_string()), - logo_url: Some("https://example.com/logo.png".to_string()), - logo_url_dark: Some("https://example.com/logo-dark.png".to_string()), - distribution_channel: Some("workspace".to_string()), - branding: None, - app_metadata: None, - labels: None, - install_url: Some(connector_install_url("calendar", "calendar")), - is_accessible: true, - is_enabled: true, - plugin_display_names: Vec::new(), - }] - ); - assert_eq!(connector_mention_slug(&merged[0]), "google-calendar"); - } - - #[test] - fn accessible_connectors_from_mcp_tools_carries_plugin_display_names() { - let tools = HashMap::from([ - ( - "mcp__codex_apps__calendar_list_events".to_string(), - codex_app_tool( - "calendar_list_events", - "calendar", - None, - &["sample", "sample"], - ), - ), - ( - "mcp__codex_apps__calendar_create_event".to_string(), - codex_app_tool( - "calendar_create_event", - "calendar", - Some("Google Calendar"), - &["beta", "sample"], - ), - ), - ( - "mcp__sample__echo".to_string(), - ToolInfo { - server_name: "sample".to_string(), - tool_name: "echo".to_string(), - tool_namespace: "sample".to_string(), - tool: test_tool_definition("echo"), - connector_id: None, - connector_name: None, - connector_description: None, - plugin_display_names: plugin_names(&["ignored"]), - }, - ), - ]); - - let connectors = accessible_connectors_from_mcp_tools(&tools); - - assert_eq!( - connectors, - vec![AppInfo { - id: "calendar".to_string(), - name: "Google Calendar".to_string(), - description: None, - logo_url: None, - logo_url_dark: None, - distribution_channel: None, - install_url: Some(connector_install_url("Google Calendar", "calendar")), - branding: None, - app_metadata: None, - labels: None, - is_accessible: true, - is_enabled: true, - plugin_display_names: plugin_names(&["beta", "sample"]), - }] - ); - } - - #[tokio::test] - async fn refresh_accessible_connectors_cache_from_mcp_tools_writes_latest_installed_apps() { - let codex_home = tempdir().expect("tempdir should succeed"); - let mut config = ConfigBuilder::default() - .codex_home(codex_home.path().to_path_buf()) - .build() - .await - .expect("config should load"); - let _ = config.features.set_enabled(Feature::Apps, true); - let cache_key = accessible_connectors_cache_key(&config, None); - let tools = HashMap::from([ - ( - "mcp__codex_apps__calendar_list_events".to_string(), - codex_app_tool( - "calendar_list_events", - "calendar", - Some("Google Calendar"), - &["calendar-plugin"], - ), - ), - ( - "mcp__codex_apps__openai_hidden".to_string(), - codex_app_tool( - "openai_hidden", - "connector_openai_hidden", - Some("Hidden"), - &[], - ), - ), - ]); - - let cached = with_accessible_connectors_cache_cleared(|| { - refresh_accessible_connectors_cache_from_mcp_tools(&config, None, &tools); - read_cached_accessible_connectors(&cache_key).expect("cache should be populated") - }); - - assert_eq!( - cached, - vec![AppInfo { - id: "calendar".to_string(), - name: "Google Calendar".to_string(), - description: None, - logo_url: None, - logo_url_dark: None, - distribution_channel: None, - install_url: Some(connector_install_url("Google Calendar", "calendar")), - branding: None, - app_metadata: None, - labels: None, - is_accessible: true, - is_enabled: true, - plugin_display_names: plugin_names(&["calendar-plugin"]), - }] - ); - } - - #[test] - fn merge_connectors_unions_and_dedupes_plugin_display_names() { - let mut plugin = plugin_app_to_app_info(AppConnectorId("calendar".to_string())); - plugin.plugin_display_names = plugin_names(&["sample", "alpha", "sample"]); - - let accessible = google_calendar_accessible_connector(&["beta", "alpha"]); - - let merged = merge_connectors(vec![plugin], vec![accessible]); - - assert_eq!( - merged, - vec![AppInfo { - id: "calendar".to_string(), - name: "Google Calendar".to_string(), - description: Some("Plan events".to_string()), - logo_url: Some("https://example.com/logo.png".to_string()), - logo_url_dark: Some("https://example.com/logo-dark.png".to_string()), - distribution_channel: Some("workspace".to_string()), - branding: None, - app_metadata: None, - labels: None, - install_url: Some(connector_install_url("calendar", "calendar")), - is_accessible: true, - is_enabled: true, - plugin_display_names: plugin_names(&["alpha", "beta", "sample"]), - }] - ); - } - - #[test] - fn accessible_connectors_from_mcp_tools_preserves_description() { - let mcp_tools = HashMap::from([( - "mcp__codex_apps__calendar_create_event".to_string(), - ToolInfo { - server_name: CODEX_APPS_MCP_SERVER_NAME.to_string(), - tool_name: "calendar_create_event".to_string(), - tool_namespace: "mcp__codex_apps__calendar".to_string(), - tool: Tool { - name: "calendar_create_event".to_string().into(), - title: None, - description: Some("Create a calendar event".into()), - input_schema: Arc::new(JsonObject::default()), - output_schema: None, - annotations: None, - execution: None, - icons: None, - meta: None, - }, - connector_id: Some("calendar".to_string()), - connector_name: Some("Calendar".to_string()), - connector_description: Some("Plan events".to_string()), - plugin_display_names: Vec::new(), - }, - )]); - - assert_eq!( - accessible_connectors_from_mcp_tools(&mcp_tools), - vec![AppInfo { - id: "calendar".to_string(), - name: "Calendar".to_string(), - description: Some("Plan events".to_string()), - logo_url: None, - logo_url_dark: None, - distribution_channel: None, - branding: None, - app_metadata: None, - labels: None, - install_url: Some(connector_install_url("Calendar", "calendar")), - is_accessible: true, - is_enabled: true, - plugin_display_names: Vec::new(), - }] - ); - } - - #[test] - fn app_tool_policy_uses_global_defaults_for_destructive_hints() { - let apps_config = AppsConfigToml { - default: Some(AppsDefaultConfig { - enabled: true, - destructive_enabled: false, - open_world_enabled: true, - }), - apps: HashMap::new(), - }; - - let policy = app_tool_policy_from_apps_config( - Some(&apps_config), - Some("calendar"), - "events/create", - None, - Some(&annotations(Some(true), None)), - ); - - assert_eq!( - policy, - AppToolPolicy { - enabled: false, - approval: AppToolApproval::Auto, - } - ); - } - - #[test] - fn app_is_enabled_uses_default_for_unconfigured_apps() { - let apps_config = AppsConfigToml { - default: Some(AppsDefaultConfig { - enabled: false, - destructive_enabled: true, - open_world_enabled: true, - }), - apps: HashMap::new(), - }; - - assert!(!app_is_enabled(&apps_config, Some("calendar"))); - assert!(!app_is_enabled(&apps_config, None)); - } - - #[test] - fn app_is_enabled_prefers_per_app_override_over_default() { - let apps_config = AppsConfigToml { - default: Some(AppsDefaultConfig { - enabled: false, - destructive_enabled: true, - open_world_enabled: true, - }), - apps: HashMap::from([( - "calendar".to_string(), - AppConfig { - enabled: true, - destructive_enabled: None, - open_world_enabled: None, - default_tools_approval_mode: None, - default_tools_enabled: None, - tools: None, - }, - )]), - }; - - assert!(app_is_enabled(&apps_config, Some("calendar"))); - assert!(!app_is_enabled(&apps_config, Some("drive"))); - } - - #[test] - fn app_tool_policy_honors_default_app_enabled_false() { - let apps_config = AppsConfigToml { - default: Some(AppsDefaultConfig { - enabled: false, - destructive_enabled: true, - open_world_enabled: true, - }), - apps: HashMap::new(), - }; - - let policy = app_tool_policy_from_apps_config( - Some(&apps_config), - Some("calendar"), - "events/list", - None, - Some(&annotations(None, None)), - ); - - assert_eq!( - policy, - AppToolPolicy { - enabled: false, - approval: AppToolApproval::Auto, - } - ); - } - - #[test] - fn app_tool_policy_allows_per_app_enable_when_default_is_disabled() { - let apps_config = AppsConfigToml { - default: Some(AppsDefaultConfig { - enabled: false, - destructive_enabled: true, - open_world_enabled: true, - }), - apps: HashMap::from([( - "calendar".to_string(), - AppConfig { - enabled: true, - destructive_enabled: None, - open_world_enabled: None, - default_tools_approval_mode: None, - default_tools_enabled: None, - tools: None, - }, - )]), - }; - - let policy = app_tool_policy_from_apps_config( - Some(&apps_config), - Some("calendar"), - "events/list", - None, - Some(&annotations(None, None)), - ); - - assert_eq!( - policy, - AppToolPolicy { - enabled: true, - approval: AppToolApproval::Auto, - } - ); - } - - #[test] - fn app_tool_policy_per_tool_enabled_true_overrides_app_level_disable_flags() { - let apps_config = AppsConfigToml { - default: None, - apps: HashMap::from([( - "calendar".to_string(), - AppConfig { - enabled: true, - destructive_enabled: Some(false), - open_world_enabled: Some(false), - default_tools_approval_mode: None, - default_tools_enabled: None, - tools: Some(AppToolsConfig { - tools: HashMap::from([( - "events/create".to_string(), - AppToolConfig { - enabled: Some(true), - approval_mode: None, - }, - )]), - }), - }, - )]), - }; - - let policy = app_tool_policy_from_apps_config( - Some(&apps_config), - Some("calendar"), - "events/create", - None, - Some(&annotations(Some(true), Some(true))), - ); - - assert_eq!( - policy, - AppToolPolicy { - enabled: true, - approval: AppToolApproval::Auto, - } - ); - } - - #[test] - fn app_tool_policy_default_tools_enabled_true_overrides_app_level_tool_hints() { - let apps_config = AppsConfigToml { - default: None, - apps: HashMap::from([( - "calendar".to_string(), - AppConfig { - enabled: true, - destructive_enabled: Some(false), - open_world_enabled: Some(false), - default_tools_approval_mode: None, - default_tools_enabled: Some(true), - tools: None, - }, - )]), - }; - - let policy = app_tool_policy_from_apps_config( - Some(&apps_config), - Some("calendar"), - "events/create", - None, - Some(&annotations(Some(true), Some(true))), - ); - - assert_eq!( - policy, - AppToolPolicy { - enabled: true, - approval: AppToolApproval::Auto, - } - ); - } - - #[test] - fn app_tool_policy_default_tools_enabled_false_overrides_app_level_tool_hints() { - let apps_config = AppsConfigToml { - default: None, - apps: HashMap::from([( - "calendar".to_string(), - AppConfig { - enabled: true, - destructive_enabled: Some(true), - open_world_enabled: Some(true), - default_tools_approval_mode: Some(AppToolApproval::Approve), - default_tools_enabled: Some(false), - tools: None, - }, - )]), - }; - - let policy = app_tool_policy_from_apps_config( - Some(&apps_config), - Some("calendar"), - "events/list", - None, - Some(&annotations(None, None)), - ); - - assert_eq!( - policy, - AppToolPolicy { - enabled: false, - approval: AppToolApproval::Approve, - } - ); - } - - #[test] - fn app_tool_policy_uses_default_tools_approval_mode() { - let apps_config = AppsConfigToml { - default: None, - apps: HashMap::from([( - "calendar".to_string(), - AppConfig { - enabled: true, - destructive_enabled: None, - open_world_enabled: None, - default_tools_approval_mode: Some(AppToolApproval::Prompt), - default_tools_enabled: None, - tools: Some(AppToolsConfig { - tools: HashMap::new(), - }), - }, - )]), - }; - - let policy = app_tool_policy_from_apps_config( - Some(&apps_config), - Some("calendar"), - "events/list", - None, - Some(&annotations(None, None)), - ); - - assert_eq!( - policy, - AppToolPolicy { - enabled: true, - approval: AppToolApproval::Prompt, - } - ); - } - - #[test] - fn app_tool_policy_matches_prefix_stripped_tool_name_for_tool_config() { - let apps_config = AppsConfigToml { - default: None, - apps: HashMap::from([( - "calendar".to_string(), - AppConfig { - enabled: true, - destructive_enabled: Some(false), - open_world_enabled: Some(false), - default_tools_approval_mode: Some(AppToolApproval::Auto), - default_tools_enabled: Some(false), - tools: Some(AppToolsConfig { - tools: HashMap::from([( - "events/create".to_string(), - AppToolConfig { - enabled: Some(true), - approval_mode: Some(AppToolApproval::Approve), - }, - )]), - }), - }, - )]), - }; - - let policy = app_tool_policy_from_apps_config( - Some(&apps_config), - Some("calendar"), - "calendar_events/create", - Some("events/create"), - Some(&annotations(Some(true), Some(true))), - ); - - assert_eq!( - policy, - AppToolPolicy { - enabled: true, - approval: AppToolApproval::Approve, - } - ); - } - - #[test] - fn filter_disallowed_connectors_allows_non_disallowed_connectors() { - let filtered = filter_disallowed_connectors(vec![app("asdk_app_hidden"), app("alpha")]); - assert_eq!(filtered, vec![app("asdk_app_hidden"), app("alpha")]); - } - - #[test] - fn filter_disallowed_connectors_filters_openai_prefix() { - let filtered = filter_disallowed_connectors(vec![ - app("connector_openai_foo"), - app("connector_openai_bar"), - app("gamma"), - ]); - assert_eq!(filtered, vec![app("gamma")]); - } - - #[test] - fn filter_disallowed_connectors_filters_disallowed_connector_ids() { - let filtered = filter_disallowed_connectors(vec![ - app("asdk_app_6938a94a61d881918ef32cb999ff937c"), - app("delta"), - ]); - assert_eq!(filtered, vec![app("delta")]); - } - - #[test] - fn first_party_chat_originator_filters_target_and_openai_prefixed_connectors() { - let filtered = filter_disallowed_connectors_for_originator( - vec![ - app("connector_openai_foo"), - app("asdk_app_6938a94a61d881918ef32cb999ff937c"), - app("connector_0f9c9d4592e54d0a9a12b3f44a1e2010"), - ], - "codex_atlas", - ); - assert_eq!( - filtered, - vec![app("asdk_app_6938a94a61d881918ef32cb999ff937c")] - ); - } - - #[test] - fn filter_tool_suggest_discoverable_tools_keeps_only_allowlisted_uninstalled_apps() { - let filtered = filter_tool_suggest_discoverable_tools( - vec![ - named_app( - "connector_2128aebfecb84f64a069897515042a44", - "Google Calendar", - ), - named_app("connector_68df038e0ba48191908c8434991bbac2", "Gmail"), - named_app("connector_other", "Other"), - ], - &[AppInfo { - is_accessible: true, - ..named_app( - "connector_2128aebfecb84f64a069897515042a44", - "Google Calendar", - ) - }], - ); - - assert_eq!( - filtered, - vec![named_app( - "connector_68df038e0ba48191908c8434991bbac2", - "Gmail", - )] - ); - } - - #[test] - fn filter_tool_suggest_discoverable_tools_keeps_disabled_accessible_apps() { - let filtered = filter_tool_suggest_discoverable_tools( - vec![ - named_app( - "connector_2128aebfecb84f64a069897515042a44", - "Google Calendar", - ), - named_app("connector_68df038e0ba48191908c8434991bbac2", "Gmail"), - ], - &[ - AppInfo { - is_accessible: true, - ..named_app( - "connector_2128aebfecb84f64a069897515042a44", - "Google Calendar", - ) - }, - AppInfo { - is_accessible: true, - is_enabled: false, - ..named_app("connector_68df038e0ba48191908c8434991bbac2", "Gmail") - }, - ], - ); - - assert_eq!( - filtered, - vec![named_app( - "connector_68df038e0ba48191908c8434991bbac2", - "Gmail" - )] - ); - } -} +#[path = "connectors_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/connectors_tests.rs b/codex-rs/core/src/connectors_tests.rs new file mode 100644 index 0000000000..3743fb9f87 --- /dev/null +++ b/codex-rs/core/src/connectors_tests.rs @@ -0,0 +1,770 @@ +use super::*; +use crate::config::ConfigBuilder; +use crate::config::types::AppConfig; +use crate::config::types::AppToolConfig; +use crate::config::types::AppToolsConfig; +use crate::config::types::AppsDefaultConfig; +use crate::features::Feature; +use crate::mcp::CODEX_APPS_MCP_SERVER_NAME; +use crate::mcp_connection_manager::ToolInfo; +use pretty_assertions::assert_eq; +use rmcp::model::JsonObject; +use rmcp::model::Tool; +use std::collections::HashMap; +use std::sync::Arc; +use tempfile::tempdir; + +fn annotations(destructive_hint: Option, open_world_hint: Option) -> ToolAnnotations { + ToolAnnotations { + destructive_hint, + idempotent_hint: None, + open_world_hint, + read_only_hint: None, + title: None, + } +} + +fn app(id: &str) -> AppInfo { + AppInfo { + id: id.to_string(), + name: id.to_string(), + description: None, + logo_url: None, + logo_url_dark: None, + distribution_channel: None, + install_url: None, + branding: None, + app_metadata: None, + labels: None, + is_accessible: false, + is_enabled: true, + plugin_display_names: Vec::new(), + } +} + +fn named_app(id: &str, name: &str) -> AppInfo { + AppInfo { + id: id.to_string(), + name: name.to_string(), + install_url: Some(connector_install_url(name, id)), + ..app(id) + } +} + +fn plugin_names(names: &[&str]) -> Vec { + names.iter().map(ToString::to_string).collect() +} + +fn test_tool_definition(tool_name: &str) -> Tool { + Tool { + name: tool_name.to_string().into(), + title: None, + description: None, + input_schema: Arc::new(JsonObject::default()), + output_schema: None, + annotations: None, + execution: None, + icons: None, + meta: None, + } +} + +fn google_calendar_accessible_connector(plugin_display_names: &[&str]) -> AppInfo { + AppInfo { + id: "calendar".to_string(), + name: "Google Calendar".to_string(), + description: Some("Plan events".to_string()), + logo_url: Some("https://example.com/logo.png".to_string()), + logo_url_dark: Some("https://example.com/logo-dark.png".to_string()), + distribution_channel: Some("workspace".to_string()), + branding: None, + app_metadata: None, + labels: None, + install_url: None, + is_accessible: true, + is_enabled: true, + plugin_display_names: plugin_names(plugin_display_names), + } +} + +fn codex_app_tool( + tool_name: &str, + connector_id: &str, + connector_name: Option<&str>, + plugin_display_names: &[&str], +) -> ToolInfo { + let tool_namespace = connector_name + .map(sanitize_name) + .map(|connector_name| format!("mcp__{CODEX_APPS_MCP_SERVER_NAME}__{connector_name}")) + .unwrap_or_else(|| CODEX_APPS_MCP_SERVER_NAME.to_string()); + + ToolInfo { + server_name: CODEX_APPS_MCP_SERVER_NAME.to_string(), + tool_name: tool_name.to_string(), + tool_namespace, + tool: test_tool_definition(tool_name), + connector_id: Some(connector_id.to_string()), + connector_name: connector_name.map(ToOwned::to_owned), + connector_description: None, + plugin_display_names: plugin_names(plugin_display_names), + } +} + +fn with_accessible_connectors_cache_cleared(f: impl FnOnce() -> R) -> R { + let previous = { + let mut cache_guard = ACCESSIBLE_CONNECTORS_CACHE + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + cache_guard.take() + }; + let result = f(); + let mut cache_guard = ACCESSIBLE_CONNECTORS_CACHE + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + *cache_guard = previous; + result +} + +#[test] +fn merge_connectors_replaces_plugin_placeholder_name_with_accessible_name() { + let plugin = plugin_app_to_app_info(AppConnectorId("calendar".to_string())); + let accessible = google_calendar_accessible_connector(&[]); + + let merged = merge_connectors(vec![plugin], vec![accessible]); + + assert_eq!( + merged, + vec![AppInfo { + id: "calendar".to_string(), + name: "Google Calendar".to_string(), + description: Some("Plan events".to_string()), + logo_url: Some("https://example.com/logo.png".to_string()), + logo_url_dark: Some("https://example.com/logo-dark.png".to_string()), + distribution_channel: Some("workspace".to_string()), + branding: None, + app_metadata: None, + labels: None, + install_url: Some(connector_install_url("calendar", "calendar")), + is_accessible: true, + is_enabled: true, + plugin_display_names: Vec::new(), + }] + ); + assert_eq!(connector_mention_slug(&merged[0]), "google-calendar"); +} + +#[test] +fn accessible_connectors_from_mcp_tools_carries_plugin_display_names() { + let tools = HashMap::from([ + ( + "mcp__codex_apps__calendar_list_events".to_string(), + codex_app_tool( + "calendar_list_events", + "calendar", + None, + &["sample", "sample"], + ), + ), + ( + "mcp__codex_apps__calendar_create_event".to_string(), + codex_app_tool( + "calendar_create_event", + "calendar", + Some("Google Calendar"), + &["beta", "sample"], + ), + ), + ( + "mcp__sample__echo".to_string(), + ToolInfo { + server_name: "sample".to_string(), + tool_name: "echo".to_string(), + tool_namespace: "sample".to_string(), + tool: test_tool_definition("echo"), + connector_id: None, + connector_name: None, + connector_description: None, + plugin_display_names: plugin_names(&["ignored"]), + }, + ), + ]); + + let connectors = accessible_connectors_from_mcp_tools(&tools); + + assert_eq!( + connectors, + vec![AppInfo { + id: "calendar".to_string(), + name: "Google Calendar".to_string(), + description: None, + logo_url: None, + logo_url_dark: None, + distribution_channel: None, + install_url: Some(connector_install_url("Google Calendar", "calendar")), + branding: None, + app_metadata: None, + labels: None, + is_accessible: true, + is_enabled: true, + plugin_display_names: plugin_names(&["beta", "sample"]), + }] + ); +} + +#[tokio::test] +async fn refresh_accessible_connectors_cache_from_mcp_tools_writes_latest_installed_apps() { + let codex_home = tempdir().expect("tempdir should succeed"); + let mut config = ConfigBuilder::default() + .codex_home(codex_home.path().to_path_buf()) + .build() + .await + .expect("config should load"); + let _ = config.features.set_enabled(Feature::Apps, true); + let cache_key = accessible_connectors_cache_key(&config, None); + let tools = HashMap::from([ + ( + "mcp__codex_apps__calendar_list_events".to_string(), + codex_app_tool( + "calendar_list_events", + "calendar", + Some("Google Calendar"), + &["calendar-plugin"], + ), + ), + ( + "mcp__codex_apps__openai_hidden".to_string(), + codex_app_tool( + "openai_hidden", + "connector_openai_hidden", + Some("Hidden"), + &[], + ), + ), + ]); + + let cached = with_accessible_connectors_cache_cleared(|| { + refresh_accessible_connectors_cache_from_mcp_tools(&config, None, &tools); + read_cached_accessible_connectors(&cache_key).expect("cache should be populated") + }); + + assert_eq!( + cached, + vec![AppInfo { + id: "calendar".to_string(), + name: "Google Calendar".to_string(), + description: None, + logo_url: None, + logo_url_dark: None, + distribution_channel: None, + install_url: Some(connector_install_url("Google Calendar", "calendar")), + branding: None, + app_metadata: None, + labels: None, + is_accessible: true, + is_enabled: true, + plugin_display_names: plugin_names(&["calendar-plugin"]), + }] + ); +} + +#[test] +fn merge_connectors_unions_and_dedupes_plugin_display_names() { + let mut plugin = plugin_app_to_app_info(AppConnectorId("calendar".to_string())); + plugin.plugin_display_names = plugin_names(&["sample", "alpha", "sample"]); + + let accessible = google_calendar_accessible_connector(&["beta", "alpha"]); + + let merged = merge_connectors(vec![plugin], vec![accessible]); + + assert_eq!( + merged, + vec![AppInfo { + id: "calendar".to_string(), + name: "Google Calendar".to_string(), + description: Some("Plan events".to_string()), + logo_url: Some("https://example.com/logo.png".to_string()), + logo_url_dark: Some("https://example.com/logo-dark.png".to_string()), + distribution_channel: Some("workspace".to_string()), + branding: None, + app_metadata: None, + labels: None, + install_url: Some(connector_install_url("calendar", "calendar")), + is_accessible: true, + is_enabled: true, + plugin_display_names: plugin_names(&["alpha", "beta", "sample"]), + }] + ); +} + +#[test] +fn accessible_connectors_from_mcp_tools_preserves_description() { + let mcp_tools = HashMap::from([( + "mcp__codex_apps__calendar_create_event".to_string(), + ToolInfo { + server_name: CODEX_APPS_MCP_SERVER_NAME.to_string(), + tool_name: "calendar_create_event".to_string(), + tool_namespace: "mcp__codex_apps__calendar".to_string(), + tool: Tool { + name: "calendar_create_event".to_string().into(), + title: None, + description: Some("Create a calendar event".into()), + input_schema: Arc::new(JsonObject::default()), + output_schema: None, + annotations: None, + execution: None, + icons: None, + meta: None, + }, + connector_id: Some("calendar".to_string()), + connector_name: Some("Calendar".to_string()), + connector_description: Some("Plan events".to_string()), + plugin_display_names: Vec::new(), + }, + )]); + + assert_eq!( + accessible_connectors_from_mcp_tools(&mcp_tools), + vec![AppInfo { + id: "calendar".to_string(), + name: "Calendar".to_string(), + description: Some("Plan events".to_string()), + logo_url: None, + logo_url_dark: None, + distribution_channel: None, + branding: None, + app_metadata: None, + labels: None, + install_url: Some(connector_install_url("Calendar", "calendar")), + is_accessible: true, + is_enabled: true, + plugin_display_names: Vec::new(), + }] + ); +} + +#[test] +fn app_tool_policy_uses_global_defaults_for_destructive_hints() { + let apps_config = AppsConfigToml { + default: Some(AppsDefaultConfig { + enabled: true, + destructive_enabled: false, + open_world_enabled: true, + }), + apps: HashMap::new(), + }; + + let policy = app_tool_policy_from_apps_config( + Some(&apps_config), + Some("calendar"), + "events/create", + None, + Some(&annotations(Some(true), None)), + ); + + assert_eq!( + policy, + AppToolPolicy { + enabled: false, + approval: AppToolApproval::Auto, + } + ); +} + +#[test] +fn app_is_enabled_uses_default_for_unconfigured_apps() { + let apps_config = AppsConfigToml { + default: Some(AppsDefaultConfig { + enabled: false, + destructive_enabled: true, + open_world_enabled: true, + }), + apps: HashMap::new(), + }; + + assert!(!app_is_enabled(&apps_config, Some("calendar"))); + assert!(!app_is_enabled(&apps_config, None)); +} + +#[test] +fn app_is_enabled_prefers_per_app_override_over_default() { + let apps_config = AppsConfigToml { + default: Some(AppsDefaultConfig { + enabled: false, + destructive_enabled: true, + open_world_enabled: true, + }), + apps: HashMap::from([( + "calendar".to_string(), + AppConfig { + enabled: true, + destructive_enabled: None, + open_world_enabled: None, + default_tools_approval_mode: None, + default_tools_enabled: None, + tools: None, + }, + )]), + }; + + assert!(app_is_enabled(&apps_config, Some("calendar"))); + assert!(!app_is_enabled(&apps_config, Some("drive"))); +} + +#[test] +fn app_tool_policy_honors_default_app_enabled_false() { + let apps_config = AppsConfigToml { + default: Some(AppsDefaultConfig { + enabled: false, + destructive_enabled: true, + open_world_enabled: true, + }), + apps: HashMap::new(), + }; + + let policy = app_tool_policy_from_apps_config( + Some(&apps_config), + Some("calendar"), + "events/list", + None, + Some(&annotations(None, None)), + ); + + assert_eq!( + policy, + AppToolPolicy { + enabled: false, + approval: AppToolApproval::Auto, + } + ); +} + +#[test] +fn app_tool_policy_allows_per_app_enable_when_default_is_disabled() { + let apps_config = AppsConfigToml { + default: Some(AppsDefaultConfig { + enabled: false, + destructive_enabled: true, + open_world_enabled: true, + }), + apps: HashMap::from([( + "calendar".to_string(), + AppConfig { + enabled: true, + destructive_enabled: None, + open_world_enabled: None, + default_tools_approval_mode: None, + default_tools_enabled: None, + tools: None, + }, + )]), + }; + + let policy = app_tool_policy_from_apps_config( + Some(&apps_config), + Some("calendar"), + "events/list", + None, + Some(&annotations(None, None)), + ); + + assert_eq!( + policy, + AppToolPolicy { + enabled: true, + approval: AppToolApproval::Auto, + } + ); +} + +#[test] +fn app_tool_policy_per_tool_enabled_true_overrides_app_level_disable_flags() { + let apps_config = AppsConfigToml { + default: None, + apps: HashMap::from([( + "calendar".to_string(), + AppConfig { + enabled: true, + destructive_enabled: Some(false), + open_world_enabled: Some(false), + default_tools_approval_mode: None, + default_tools_enabled: None, + tools: Some(AppToolsConfig { + tools: HashMap::from([( + "events/create".to_string(), + AppToolConfig { + enabled: Some(true), + approval_mode: None, + }, + )]), + }), + }, + )]), + }; + + let policy = app_tool_policy_from_apps_config( + Some(&apps_config), + Some("calendar"), + "events/create", + None, + Some(&annotations(Some(true), Some(true))), + ); + + assert_eq!( + policy, + AppToolPolicy { + enabled: true, + approval: AppToolApproval::Auto, + } + ); +} + +#[test] +fn app_tool_policy_default_tools_enabled_true_overrides_app_level_tool_hints() { + let apps_config = AppsConfigToml { + default: None, + apps: HashMap::from([( + "calendar".to_string(), + AppConfig { + enabled: true, + destructive_enabled: Some(false), + open_world_enabled: Some(false), + default_tools_approval_mode: None, + default_tools_enabled: Some(true), + tools: None, + }, + )]), + }; + + let policy = app_tool_policy_from_apps_config( + Some(&apps_config), + Some("calendar"), + "events/create", + None, + Some(&annotations(Some(true), Some(true))), + ); + + assert_eq!( + policy, + AppToolPolicy { + enabled: true, + approval: AppToolApproval::Auto, + } + ); +} + +#[test] +fn app_tool_policy_default_tools_enabled_false_overrides_app_level_tool_hints() { + let apps_config = AppsConfigToml { + default: None, + apps: HashMap::from([( + "calendar".to_string(), + AppConfig { + enabled: true, + destructive_enabled: Some(true), + open_world_enabled: Some(true), + default_tools_approval_mode: Some(AppToolApproval::Approve), + default_tools_enabled: Some(false), + tools: None, + }, + )]), + }; + + let policy = app_tool_policy_from_apps_config( + Some(&apps_config), + Some("calendar"), + "events/list", + None, + Some(&annotations(None, None)), + ); + + assert_eq!( + policy, + AppToolPolicy { + enabled: false, + approval: AppToolApproval::Approve, + } + ); +} + +#[test] +fn app_tool_policy_uses_default_tools_approval_mode() { + let apps_config = AppsConfigToml { + default: None, + apps: HashMap::from([( + "calendar".to_string(), + AppConfig { + enabled: true, + destructive_enabled: None, + open_world_enabled: None, + default_tools_approval_mode: Some(AppToolApproval::Prompt), + default_tools_enabled: None, + tools: Some(AppToolsConfig { + tools: HashMap::new(), + }), + }, + )]), + }; + + let policy = app_tool_policy_from_apps_config( + Some(&apps_config), + Some("calendar"), + "events/list", + None, + Some(&annotations(None, None)), + ); + + assert_eq!( + policy, + AppToolPolicy { + enabled: true, + approval: AppToolApproval::Prompt, + } + ); +} + +#[test] +fn app_tool_policy_matches_prefix_stripped_tool_name_for_tool_config() { + let apps_config = AppsConfigToml { + default: None, + apps: HashMap::from([( + "calendar".to_string(), + AppConfig { + enabled: true, + destructive_enabled: Some(false), + open_world_enabled: Some(false), + default_tools_approval_mode: Some(AppToolApproval::Auto), + default_tools_enabled: Some(false), + tools: Some(AppToolsConfig { + tools: HashMap::from([( + "events/create".to_string(), + AppToolConfig { + enabled: Some(true), + approval_mode: Some(AppToolApproval::Approve), + }, + )]), + }), + }, + )]), + }; + + let policy = app_tool_policy_from_apps_config( + Some(&apps_config), + Some("calendar"), + "calendar_events/create", + Some("events/create"), + Some(&annotations(Some(true), Some(true))), + ); + + assert_eq!( + policy, + AppToolPolicy { + enabled: true, + approval: AppToolApproval::Approve, + } + ); +} + +#[test] +fn filter_disallowed_connectors_allows_non_disallowed_connectors() { + let filtered = filter_disallowed_connectors(vec![app("asdk_app_hidden"), app("alpha")]); + assert_eq!(filtered, vec![app("asdk_app_hidden"), app("alpha")]); +} + +#[test] +fn filter_disallowed_connectors_filters_openai_prefix() { + let filtered = filter_disallowed_connectors(vec![ + app("connector_openai_foo"), + app("connector_openai_bar"), + app("gamma"), + ]); + assert_eq!(filtered, vec![app("gamma")]); +} + +#[test] +fn filter_disallowed_connectors_filters_disallowed_connector_ids() { + let filtered = filter_disallowed_connectors(vec![ + app("asdk_app_6938a94a61d881918ef32cb999ff937c"), + app("delta"), + ]); + assert_eq!(filtered, vec![app("delta")]); +} + +#[test] +fn first_party_chat_originator_filters_target_and_openai_prefixed_connectors() { + let filtered = filter_disallowed_connectors_for_originator( + vec![ + app("connector_openai_foo"), + app("asdk_app_6938a94a61d881918ef32cb999ff937c"), + app("connector_0f9c9d4592e54d0a9a12b3f44a1e2010"), + ], + "codex_atlas", + ); + assert_eq!( + filtered, + vec![app("asdk_app_6938a94a61d881918ef32cb999ff937c")] + ); +} + +#[test] +fn filter_tool_suggest_discoverable_tools_keeps_only_allowlisted_uninstalled_apps() { + let filtered = filter_tool_suggest_discoverable_tools( + vec![ + named_app( + "connector_2128aebfecb84f64a069897515042a44", + "Google Calendar", + ), + named_app("connector_68df038e0ba48191908c8434991bbac2", "Gmail"), + named_app("connector_other", "Other"), + ], + &[AppInfo { + is_accessible: true, + ..named_app( + "connector_2128aebfecb84f64a069897515042a44", + "Google Calendar", + ) + }], + ); + + assert_eq!( + filtered, + vec![named_app( + "connector_68df038e0ba48191908c8434991bbac2", + "Gmail", + )] + ); +} + +#[test] +fn filter_tool_suggest_discoverable_tools_keeps_disabled_accessible_apps() { + let filtered = filter_tool_suggest_discoverable_tools( + vec![ + named_app( + "connector_2128aebfecb84f64a069897515042a44", + "Google Calendar", + ), + named_app("connector_68df038e0ba48191908c8434991bbac2", "Gmail"), + ], + &[ + AppInfo { + is_accessible: true, + ..named_app( + "connector_2128aebfecb84f64a069897515042a44", + "Google Calendar", + ) + }, + AppInfo { + is_accessible: true, + is_enabled: false, + ..named_app("connector_68df038e0ba48191908c8434991bbac2", "Gmail") + }, + ], + ); + + assert_eq!( + filtered, + vec![named_app( + "connector_68df038e0ba48191908c8434991bbac2", + "Gmail" + )] + ); +} diff --git a/codex-rs/core/src/contextual_user_message.rs b/codex-rs/core/src/contextual_user_message.rs index 51a2d23ea9..d10f7a9fc7 100644 --- a/codex-rs/core/src/contextual_user_message.rs +++ b/codex-rs/core/src/contextual_user_message.rs @@ -104,36 +104,5 @@ pub(crate) fn is_contextual_user_fragment(content_item: &ContentItem) -> bool { } #[cfg(test)] -mod tests { - use super::*; - - #[test] - fn detects_environment_context_fragment() { - assert!(is_contextual_user_fragment(&ContentItem::InputText { - text: "\n/tmp\n".to_string(), - })); - } - - #[test] - fn detects_agents_instructions_fragment() { - assert!(is_contextual_user_fragment(&ContentItem::InputText { - text: "# AGENTS.md instructions for /tmp\n\n\nbody\n" - .to_string(), - })); - } - - #[test] - fn detects_subagent_notification_fragment_case_insensitively() { - assert!( - SUBAGENT_NOTIFICATION_FRAGMENT - .matches_text("{}") - ); - } - - #[test] - fn ignores_regular_user_text() { - assert!(!is_contextual_user_fragment(&ContentItem::InputText { - text: "hello".to_string(), - })); - } -} +#[path = "contextual_user_message_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/contextual_user_message_tests.rs b/codex-rs/core/src/contextual_user_message_tests.rs new file mode 100644 index 0000000000..df3a9daeca --- /dev/null +++ b/codex-rs/core/src/contextual_user_message_tests.rs @@ -0,0 +1,31 @@ +use super::*; + +#[test] +fn detects_environment_context_fragment() { + assert!(is_contextual_user_fragment(&ContentItem::InputText { + text: "\n/tmp\n".to_string(), + })); +} + +#[test] +fn detects_agents_instructions_fragment() { + assert!(is_contextual_user_fragment(&ContentItem::InputText { + text: "# AGENTS.md instructions for /tmp\n\n\nbody\n" + .to_string(), + })); +} + +#[test] +fn detects_subagent_notification_fragment_case_insensitively() { + assert!( + SUBAGENT_NOTIFICATION_FRAGMENT + .matches_text("{}") + ); +} + +#[test] +fn ignores_regular_user_text() { + assert!(!is_contextual_user_fragment(&ContentItem::InputText { + text: "hello".to_string(), + })); +} diff --git a/codex-rs/core/src/custom_prompts.rs b/codex-rs/core/src/custom_prompts.rs index 66b2bab32c..54ccaa62fe 100644 --- a/codex-rs/core/src/custom_prompts.rs +++ b/codex-rs/core/src/custom_prompts.rs @@ -145,100 +145,5 @@ fn parse_frontmatter(content: &str) -> (Option, Option, String) } #[cfg(test)] -mod tests { - use super::*; - use std::fs; - use tempfile::tempdir; - - #[tokio::test] - async fn empty_when_dir_missing() { - let tmp = tempdir().expect("create TempDir"); - let missing = tmp.path().join("nope"); - let found = discover_prompts_in(&missing).await; - assert!(found.is_empty()); - } - - #[tokio::test] - async fn discovers_and_sorts_files() { - let tmp = tempdir().expect("create TempDir"); - let dir = tmp.path(); - fs::write(dir.join("b.md"), b"b").unwrap(); - fs::write(dir.join("a.md"), b"a").unwrap(); - fs::create_dir(dir.join("subdir")).unwrap(); - let found = discover_prompts_in(dir).await; - let names: Vec = found.into_iter().map(|e| e.name).collect(); - assert_eq!(names, vec!["a", "b"]); - } - - #[tokio::test] - async fn excludes_builtins() { - let tmp = tempdir().expect("create TempDir"); - let dir = tmp.path(); - fs::write(dir.join("init.md"), b"ignored").unwrap(); - fs::write(dir.join("foo.md"), b"ok").unwrap(); - let mut exclude = HashSet::new(); - exclude.insert("init".to_string()); - let found = discover_prompts_in_excluding(dir, &exclude).await; - let names: Vec = found.into_iter().map(|e| e.name).collect(); - assert_eq!(names, vec!["foo"]); - } - - #[tokio::test] - async fn skips_non_utf8_files() { - let tmp = tempdir().expect("create TempDir"); - let dir = tmp.path(); - // Valid UTF-8 file - fs::write(dir.join("good.md"), b"hello").unwrap(); - // Invalid UTF-8 content in .md file (e.g., lone 0xFF byte) - fs::write(dir.join("bad.md"), vec![0xFF, 0xFE, b'\n']).unwrap(); - let found = discover_prompts_in(dir).await; - let names: Vec = found.into_iter().map(|e| e.name).collect(); - assert_eq!(names, vec!["good"]); - } - - #[tokio::test] - #[cfg(unix)] - async fn discovers_symlinked_md_files() { - let tmp = tempdir().expect("create TempDir"); - let dir = tmp.path(); - - // Create a real file - fs::write(dir.join("real.md"), b"real content").unwrap(); - - // Create a symlink to the real file - std::os::unix::fs::symlink(dir.join("real.md"), dir.join("link.md")).unwrap(); - - let found = discover_prompts_in(dir).await; - let names: Vec = found.into_iter().map(|e| e.name).collect(); - - // Both real and link should be discovered, sorted alphabetically - assert_eq!(names, vec!["link", "real"]); - } - - #[tokio::test] - async fn parses_frontmatter_and_strips_from_body() { - let tmp = tempdir().expect("create TempDir"); - let dir = tmp.path(); - let file = dir.join("withmeta.md"); - let text = "---\nname: ignored\ndescription: \"Quick review command\"\nargument-hint: \"[file] [priority]\"\n---\nActual body with $1 and $ARGUMENTS"; - fs::write(&file, text).unwrap(); - - let found = discover_prompts_in(dir).await; - assert_eq!(found.len(), 1); - let p = &found[0]; - assert_eq!(p.name, "withmeta"); - assert_eq!(p.description.as_deref(), Some("Quick review command")); - assert_eq!(p.argument_hint.as_deref(), Some("[file] [priority]")); - // Body should not include the frontmatter delimiters. - assert_eq!(p.content, "Actual body with $1 and $ARGUMENTS"); - } - - #[test] - fn parse_frontmatter_preserves_body_newlines() { - let content = "---\r\ndescription: \"Line endings\"\r\nargument_hint: \"[arg]\"\r\n---\r\nFirst line\r\nSecond line\r\n"; - let (desc, hint, body) = parse_frontmatter(content); - assert_eq!(desc.as_deref(), Some("Line endings")); - assert_eq!(hint.as_deref(), Some("[arg]")); - assert_eq!(body, "First line\r\nSecond line\r\n"); - } -} +#[path = "custom_prompts_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/custom_prompts_tests.rs b/codex-rs/core/src/custom_prompts_tests.rs new file mode 100644 index 0000000000..b1208a04e0 --- /dev/null +++ b/codex-rs/core/src/custom_prompts_tests.rs @@ -0,0 +1,95 @@ +use super::*; +use std::fs; +use tempfile::tempdir; + +#[tokio::test] +async fn empty_when_dir_missing() { + let tmp = tempdir().expect("create TempDir"); + let missing = tmp.path().join("nope"); + let found = discover_prompts_in(&missing).await; + assert!(found.is_empty()); +} + +#[tokio::test] +async fn discovers_and_sorts_files() { + let tmp = tempdir().expect("create TempDir"); + let dir = tmp.path(); + fs::write(dir.join("b.md"), b"b").unwrap(); + fs::write(dir.join("a.md"), b"a").unwrap(); + fs::create_dir(dir.join("subdir")).unwrap(); + let found = discover_prompts_in(dir).await; + let names: Vec = found.into_iter().map(|e| e.name).collect(); + assert_eq!(names, vec!["a", "b"]); +} + +#[tokio::test] +async fn excludes_builtins() { + let tmp = tempdir().expect("create TempDir"); + let dir = tmp.path(); + fs::write(dir.join("init.md"), b"ignored").unwrap(); + fs::write(dir.join("foo.md"), b"ok").unwrap(); + let mut exclude = HashSet::new(); + exclude.insert("init".to_string()); + let found = discover_prompts_in_excluding(dir, &exclude).await; + let names: Vec = found.into_iter().map(|e| e.name).collect(); + assert_eq!(names, vec!["foo"]); +} + +#[tokio::test] +async fn skips_non_utf8_files() { + let tmp = tempdir().expect("create TempDir"); + let dir = tmp.path(); + // Valid UTF-8 file + fs::write(dir.join("good.md"), b"hello").unwrap(); + // Invalid UTF-8 content in .md file (e.g., lone 0xFF byte) + fs::write(dir.join("bad.md"), vec![0xFF, 0xFE, b'\n']).unwrap(); + let found = discover_prompts_in(dir).await; + let names: Vec = found.into_iter().map(|e| e.name).collect(); + assert_eq!(names, vec!["good"]); +} + +#[tokio::test] +#[cfg(unix)] +async fn discovers_symlinked_md_files() { + let tmp = tempdir().expect("create TempDir"); + let dir = tmp.path(); + + // Create a real file + fs::write(dir.join("real.md"), b"real content").unwrap(); + + // Create a symlink to the real file + std::os::unix::fs::symlink(dir.join("real.md"), dir.join("link.md")).unwrap(); + + let found = discover_prompts_in(dir).await; + let names: Vec = found.into_iter().map(|e| e.name).collect(); + + // Both real and link should be discovered, sorted alphabetically + assert_eq!(names, vec!["link", "real"]); +} + +#[tokio::test] +async fn parses_frontmatter_and_strips_from_body() { + let tmp = tempdir().expect("create TempDir"); + let dir = tmp.path(); + let file = dir.join("withmeta.md"); + let text = "---\nname: ignored\ndescription: \"Quick review command\"\nargument-hint: \"[file] [priority]\"\n---\nActual body with $1 and $ARGUMENTS"; + fs::write(&file, text).unwrap(); + + let found = discover_prompts_in(dir).await; + assert_eq!(found.len(), 1); + let p = &found[0]; + assert_eq!(p.name, "withmeta"); + assert_eq!(p.description.as_deref(), Some("Quick review command")); + assert_eq!(p.argument_hint.as_deref(), Some("[file] [priority]")); + // Body should not include the frontmatter delimiters. + assert_eq!(p.content, "Actual body with $1 and $ARGUMENTS"); +} + +#[test] +fn parse_frontmatter_preserves_body_newlines() { + let content = "---\r\ndescription: \"Line endings\"\r\nargument_hint: \"[arg]\"\r\n---\r\nFirst line\r\nSecond line\r\n"; + let (desc, hint, body) = parse_frontmatter(content); + assert_eq!(desc.as_deref(), Some("Line endings")); + assert_eq!(hint.as_deref(), Some("[arg]")); + assert_eq!(body, "First line\r\nSecond line\r\n"); +} diff --git a/codex-rs/core/src/default_client.rs b/codex-rs/core/src/default_client.rs index aa490e8264..6d0b5496ce 100644 --- a/codex-rs/core/src/default_client.rs +++ b/codex-rs/core/src/default_client.rs @@ -216,128 +216,5 @@ fn is_sandboxed() -> bool { } #[cfg(test)] -mod tests { - use super::*; - use core_test_support::skip_if_no_network; - use pretty_assertions::assert_eq; - - #[test] - fn test_get_codex_user_agent() { - let user_agent = get_codex_user_agent(); - let originator = originator().value; - let prefix = format!("{originator}/"); - assert!(user_agent.starts_with(&prefix)); - } - - #[test] - fn is_first_party_originator_matches_known_values() { - assert_eq!(is_first_party_originator(DEFAULT_ORIGINATOR), true); - assert_eq!(is_first_party_originator("codex_vscode"), true); - assert_eq!(is_first_party_originator("Codex Something Else"), true); - assert_eq!(is_first_party_originator("codex_cli"), false); - assert_eq!(is_first_party_originator("Other"), false); - } - - #[test] - fn is_first_party_chat_originator_matches_known_values() { - assert_eq!(is_first_party_chat_originator("codex_atlas"), true); - assert_eq!( - is_first_party_chat_originator("codex_chatgpt_desktop"), - true - ); - assert_eq!(is_first_party_chat_originator(DEFAULT_ORIGINATOR), false); - assert_eq!(is_first_party_chat_originator("codex_vscode"), false); - } - - #[tokio::test] - async fn test_create_client_sets_default_headers() { - skip_if_no_network!(); - - set_default_client_residency_requirement(Some(ResidencyRequirement::Us)); - - use wiremock::Mock; - use wiremock::MockServer; - use wiremock::ResponseTemplate; - use wiremock::matchers::method; - use wiremock::matchers::path; - - let client = create_client(); - - // Spin up a local mock server and capture a request. - let server = MockServer::start().await; - Mock::given(method("GET")) - .and(path("/")) - .respond_with(ResponseTemplate::new(200)) - .mount(&server) - .await; - - let resp = client - .get(server.uri()) - .send() - .await - .expect("failed to send request"); - assert!(resp.status().is_success()); - - let requests = server - .received_requests() - .await - .expect("failed to fetch received requests"); - assert!(!requests.is_empty()); - let headers = &requests[0].headers; - - // originator header is set to the provided value - let originator_header = headers - .get("originator") - .expect("originator header missing"); - assert_eq!(originator_header.to_str().unwrap(), originator().value); - - // User-Agent matches the computed Codex UA for that originator - let expected_ua = get_codex_user_agent(); - let ua_header = headers - .get("user-agent") - .expect("user-agent header missing"); - assert_eq!(ua_header.to_str().unwrap(), expected_ua); - - let residency_header = headers - .get(RESIDENCY_HEADER_NAME) - .expect("residency header missing"); - assert_eq!(residency_header.to_str().unwrap(), "us"); - - set_default_client_residency_requirement(None); - } - - #[test] - fn test_invalid_suffix_is_sanitized() { - let prefix = "codex_cli_rs/0.0.0"; - let suffix = "bad\rsuffix"; - - assert_eq!( - sanitize_user_agent(format!("{prefix} ({suffix})"), prefix), - "codex_cli_rs/0.0.0 (bad_suffix)" - ); - } - - #[test] - fn test_invalid_suffix_is_sanitized2() { - let prefix = "codex_cli_rs/0.0.0"; - let suffix = "bad\0suffix"; - - assert_eq!( - sanitize_user_agent(format!("{prefix} ({suffix})"), prefix), - "codex_cli_rs/0.0.0 (bad_suffix)" - ); - } - - #[test] - #[cfg(target_os = "macos")] - fn test_macos() { - use regex_lite::Regex; - let user_agent = get_codex_user_agent(); - let originator = regex_lite::escape(originator().value.as_str()); - let re = Regex::new(&format!( - r"^{originator}/\d+\.\d+\.\d+ \(Mac OS \d+\.\d+\.\d+; (x86_64|arm64)\) (\S+)$" - )) - .unwrap(); - assert!(re.is_match(&user_agent)); - } -} +#[path = "default_client_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/default_client_tests.rs b/codex-rs/core/src/default_client_tests.rs new file mode 100644 index 0000000000..44d5e2c3c9 --- /dev/null +++ b/codex-rs/core/src/default_client_tests.rs @@ -0,0 +1,123 @@ +use super::*; +use core_test_support::skip_if_no_network; +use pretty_assertions::assert_eq; + +#[test] +fn test_get_codex_user_agent() { + let user_agent = get_codex_user_agent(); + let originator = originator().value; + let prefix = format!("{originator}/"); + assert!(user_agent.starts_with(&prefix)); +} + +#[test] +fn is_first_party_originator_matches_known_values() { + assert_eq!(is_first_party_originator(DEFAULT_ORIGINATOR), true); + assert_eq!(is_first_party_originator("codex_vscode"), true); + assert_eq!(is_first_party_originator("Codex Something Else"), true); + assert_eq!(is_first_party_originator("codex_cli"), false); + assert_eq!(is_first_party_originator("Other"), false); +} + +#[test] +fn is_first_party_chat_originator_matches_known_values() { + assert_eq!(is_first_party_chat_originator("codex_atlas"), true); + assert_eq!( + is_first_party_chat_originator("codex_chatgpt_desktop"), + true + ); + assert_eq!(is_first_party_chat_originator(DEFAULT_ORIGINATOR), false); + assert_eq!(is_first_party_chat_originator("codex_vscode"), false); +} + +#[tokio::test] +async fn test_create_client_sets_default_headers() { + skip_if_no_network!(); + + set_default_client_residency_requirement(Some(ResidencyRequirement::Us)); + + use wiremock::Mock; + use wiremock::MockServer; + use wiremock::ResponseTemplate; + use wiremock::matchers::method; + use wiremock::matchers::path; + + let client = create_client(); + + // Spin up a local mock server and capture a request. + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/")) + .respond_with(ResponseTemplate::new(200)) + .mount(&server) + .await; + + let resp = client + .get(server.uri()) + .send() + .await + .expect("failed to send request"); + assert!(resp.status().is_success()); + + let requests = server + .received_requests() + .await + .expect("failed to fetch received requests"); + assert!(!requests.is_empty()); + let headers = &requests[0].headers; + + // originator header is set to the provided value + let originator_header = headers + .get("originator") + .expect("originator header missing"); + assert_eq!(originator_header.to_str().unwrap(), originator().value); + + // User-Agent matches the computed Codex UA for that originator + let expected_ua = get_codex_user_agent(); + let ua_header = headers + .get("user-agent") + .expect("user-agent header missing"); + assert_eq!(ua_header.to_str().unwrap(), expected_ua); + + let residency_header = headers + .get(RESIDENCY_HEADER_NAME) + .expect("residency header missing"); + assert_eq!(residency_header.to_str().unwrap(), "us"); + + set_default_client_residency_requirement(None); +} + +#[test] +fn test_invalid_suffix_is_sanitized() { + let prefix = "codex_cli_rs/0.0.0"; + let suffix = "bad\rsuffix"; + + assert_eq!( + sanitize_user_agent(format!("{prefix} ({suffix})"), prefix), + "codex_cli_rs/0.0.0 (bad_suffix)" + ); +} + +#[test] +fn test_invalid_suffix_is_sanitized2() { + let prefix = "codex_cli_rs/0.0.0"; + let suffix = "bad\0suffix"; + + assert_eq!( + sanitize_user_agent(format!("{prefix} ({suffix})"), prefix), + "codex_cli_rs/0.0.0 (bad_suffix)" + ); +} + +#[test] +#[cfg(target_os = "macos")] +fn test_macos() { + use regex_lite::Regex; + let user_agent = get_codex_user_agent(); + let originator = regex_lite::escape(originator().value.as_str()); + let re = Regex::new(&format!( + r"^{originator}/\d+\.\d+\.\d+ \(Mac OS \d+\.\d+\.\d+; (x86_64|arm64)\) (\S+)$" + )) + .unwrap(); + assert!(re.is_match(&user_agent)); +} diff --git a/codex-rs/core/src/environment_context.rs b/codex-rs/core/src/environment_context.rs index 3e9ed871e3..0c42cd0900 100644 --- a/codex-rs/core/src/environment_context.rs +++ b/codex-rs/core/src/environment_context.rs @@ -199,279 +199,5 @@ impl From for ResponseItem { } #[cfg(test)] -mod tests { - use crate::shell::ShellType; - - use super::*; - use core_test_support::test_path_buf; - use pretty_assertions::assert_eq; - - fn fake_shell() -> Shell { - Shell { - shell_type: ShellType::Bash, - shell_path: PathBuf::from("/bin/bash"), - shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), - } - } - - #[test] - fn serialize_workspace_write_environment_context() { - let cwd = test_path_buf("/repo"); - let context = EnvironmentContext::new( - Some(cwd.clone()), - fake_shell(), - Some("2026-02-26".to_string()), - Some("America/Los_Angeles".to_string()), - None, - None, - ); - - let expected = format!( - r#" - {cwd} - bash - 2026-02-26 - America/Los_Angeles -"#, - cwd = cwd.display(), - ); - - assert_eq!(context.serialize_to_xml(), expected); - } - - #[test] - fn serialize_environment_context_with_network() { - let network = NetworkContext { - allowed_domains: vec!["api.example.com".to_string(), "*.openai.com".to_string()], - denied_domains: vec!["blocked.example.com".to_string()], - }; - let context = EnvironmentContext::new( - Some(test_path_buf("/repo")), - fake_shell(), - Some("2026-02-26".to_string()), - Some("America/Los_Angeles".to_string()), - Some(network), - None, - ); - - let expected = format!( - r#" - {} - bash - 2026-02-26 - America/Los_Angeles - - api.example.com - *.openai.com - blocked.example.com - -"#, - test_path_buf("/repo").display() - ); - - assert_eq!(context.serialize_to_xml(), expected); - } - - #[test] - fn serialize_read_only_environment_context() { - let context = EnvironmentContext::new( - None, - fake_shell(), - Some("2026-02-26".to_string()), - Some("America/Los_Angeles".to_string()), - None, - None, - ); - - let expected = r#" - bash - 2026-02-26 - America/Los_Angeles -"#; - - assert_eq!(context.serialize_to_xml(), expected); - } - - #[test] - fn serialize_external_sandbox_environment_context() { - let context = EnvironmentContext::new( - None, - fake_shell(), - Some("2026-02-26".to_string()), - Some("America/Los_Angeles".to_string()), - None, - None, - ); - - let expected = r#" - bash - 2026-02-26 - America/Los_Angeles -"#; - - assert_eq!(context.serialize_to_xml(), expected); - } - - #[test] - fn serialize_external_sandbox_with_restricted_network_environment_context() { - let context = EnvironmentContext::new( - None, - fake_shell(), - Some("2026-02-26".to_string()), - Some("America/Los_Angeles".to_string()), - None, - None, - ); - - let expected = r#" - bash - 2026-02-26 - America/Los_Angeles -"#; - - assert_eq!(context.serialize_to_xml(), expected); - } - - #[test] - fn serialize_full_access_environment_context() { - let context = EnvironmentContext::new( - None, - fake_shell(), - Some("2026-02-26".to_string()), - Some("America/Los_Angeles".to_string()), - None, - None, - ); - - let expected = r#" - bash - 2026-02-26 - America/Los_Angeles -"#; - - assert_eq!(context.serialize_to_xml(), expected); - } - - #[test] - fn equals_except_shell_compares_cwd() { - let context1 = EnvironmentContext::new( - Some(PathBuf::from("/repo")), - fake_shell(), - None, - None, - None, - None, - ); - let context2 = EnvironmentContext::new( - Some(PathBuf::from("/repo")), - fake_shell(), - None, - None, - None, - None, - ); - assert!(context1.equals_except_shell(&context2)); - } - - #[test] - fn equals_except_shell_ignores_sandbox_policy() { - let context1 = EnvironmentContext::new( - Some(PathBuf::from("/repo")), - fake_shell(), - None, - None, - None, - None, - ); - let context2 = EnvironmentContext::new( - Some(PathBuf::from("/repo")), - fake_shell(), - None, - None, - None, - None, - ); - - assert!(context1.equals_except_shell(&context2)); - } - - #[test] - fn equals_except_shell_compares_cwd_differences() { - let context1 = EnvironmentContext::new( - Some(PathBuf::from("/repo1")), - fake_shell(), - None, - None, - None, - None, - ); - let context2 = EnvironmentContext::new( - Some(PathBuf::from("/repo2")), - fake_shell(), - None, - None, - None, - None, - ); - - assert!(!context1.equals_except_shell(&context2)); - } - - #[test] - fn equals_except_shell_ignores_shell() { - let context1 = EnvironmentContext::new( - Some(PathBuf::from("/repo")), - Shell { - shell_type: ShellType::Bash, - shell_path: "/bin/bash".into(), - shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), - }, - None, - None, - None, - None, - ); - let context2 = EnvironmentContext::new( - Some(PathBuf::from("/repo")), - Shell { - shell_type: ShellType::Zsh, - shell_path: "/bin/zsh".into(), - shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), - }, - None, - None, - None, - None, - ); - - assert!(context1.equals_except_shell(&context2)); - } - - #[test] - fn serialize_environment_context_with_subagents() { - let context = EnvironmentContext::new( - Some(test_path_buf("/repo")), - fake_shell(), - Some("2026-02-26".to_string()), - Some("America/Los_Angeles".to_string()), - None, - Some("- agent-1: atlas\n- agent-2".to_string()), - ); - - let expected = format!( - r#" - {} - bash - 2026-02-26 - America/Los_Angeles - - - agent-1: atlas - - agent-2 - -"#, - test_path_buf("/repo").display() - ); - - assert_eq!(context.serialize_to_xml(), expected); - } -} +#[path = "environment_context_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/environment_context_tests.rs b/codex-rs/core/src/environment_context_tests.rs new file mode 100644 index 0000000000..5718c09de4 --- /dev/null +++ b/codex-rs/core/src/environment_context_tests.rs @@ -0,0 +1,274 @@ +use crate::shell::ShellType; + +use super::*; +use core_test_support::test_path_buf; +use pretty_assertions::assert_eq; + +fn fake_shell() -> Shell { + Shell { + shell_type: ShellType::Bash, + shell_path: PathBuf::from("/bin/bash"), + shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), + } +} + +#[test] +fn serialize_workspace_write_environment_context() { + let cwd = test_path_buf("/repo"); + let context = EnvironmentContext::new( + Some(cwd.clone()), + fake_shell(), + Some("2026-02-26".to_string()), + Some("America/Los_Angeles".to_string()), + None, + None, + ); + + let expected = format!( + r#" + {cwd} + bash + 2026-02-26 + America/Los_Angeles +"#, + cwd = cwd.display(), + ); + + assert_eq!(context.serialize_to_xml(), expected); +} + +#[test] +fn serialize_environment_context_with_network() { + let network = NetworkContext { + allowed_domains: vec!["api.example.com".to_string(), "*.openai.com".to_string()], + denied_domains: vec!["blocked.example.com".to_string()], + }; + let context = EnvironmentContext::new( + Some(test_path_buf("/repo")), + fake_shell(), + Some("2026-02-26".to_string()), + Some("America/Los_Angeles".to_string()), + Some(network), + None, + ); + + let expected = format!( + r#" + {} + bash + 2026-02-26 + America/Los_Angeles + + api.example.com + *.openai.com + blocked.example.com + +"#, + test_path_buf("/repo").display() + ); + + assert_eq!(context.serialize_to_xml(), expected); +} + +#[test] +fn serialize_read_only_environment_context() { + let context = EnvironmentContext::new( + None, + fake_shell(), + Some("2026-02-26".to_string()), + Some("America/Los_Angeles".to_string()), + None, + None, + ); + + let expected = r#" + bash + 2026-02-26 + America/Los_Angeles +"#; + + assert_eq!(context.serialize_to_xml(), expected); +} + +#[test] +fn serialize_external_sandbox_environment_context() { + let context = EnvironmentContext::new( + None, + fake_shell(), + Some("2026-02-26".to_string()), + Some("America/Los_Angeles".to_string()), + None, + None, + ); + + let expected = r#" + bash + 2026-02-26 + America/Los_Angeles +"#; + + assert_eq!(context.serialize_to_xml(), expected); +} + +#[test] +fn serialize_external_sandbox_with_restricted_network_environment_context() { + let context = EnvironmentContext::new( + None, + fake_shell(), + Some("2026-02-26".to_string()), + Some("America/Los_Angeles".to_string()), + None, + None, + ); + + let expected = r#" + bash + 2026-02-26 + America/Los_Angeles +"#; + + assert_eq!(context.serialize_to_xml(), expected); +} + +#[test] +fn serialize_full_access_environment_context() { + let context = EnvironmentContext::new( + None, + fake_shell(), + Some("2026-02-26".to_string()), + Some("America/Los_Angeles".to_string()), + None, + None, + ); + + let expected = r#" + bash + 2026-02-26 + America/Los_Angeles +"#; + + assert_eq!(context.serialize_to_xml(), expected); +} + +#[test] +fn equals_except_shell_compares_cwd() { + let context1 = EnvironmentContext::new( + Some(PathBuf::from("/repo")), + fake_shell(), + None, + None, + None, + None, + ); + let context2 = EnvironmentContext::new( + Some(PathBuf::from("/repo")), + fake_shell(), + None, + None, + None, + None, + ); + assert!(context1.equals_except_shell(&context2)); +} + +#[test] +fn equals_except_shell_ignores_sandbox_policy() { + let context1 = EnvironmentContext::new( + Some(PathBuf::from("/repo")), + fake_shell(), + None, + None, + None, + None, + ); + let context2 = EnvironmentContext::new( + Some(PathBuf::from("/repo")), + fake_shell(), + None, + None, + None, + None, + ); + + assert!(context1.equals_except_shell(&context2)); +} + +#[test] +fn equals_except_shell_compares_cwd_differences() { + let context1 = EnvironmentContext::new( + Some(PathBuf::from("/repo1")), + fake_shell(), + None, + None, + None, + None, + ); + let context2 = EnvironmentContext::new( + Some(PathBuf::from("/repo2")), + fake_shell(), + None, + None, + None, + None, + ); + + assert!(!context1.equals_except_shell(&context2)); +} + +#[test] +fn equals_except_shell_ignores_shell() { + let context1 = EnvironmentContext::new( + Some(PathBuf::from("/repo")), + Shell { + shell_type: ShellType::Bash, + shell_path: "/bin/bash".into(), + shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), + }, + None, + None, + None, + None, + ); + let context2 = EnvironmentContext::new( + Some(PathBuf::from("/repo")), + Shell { + shell_type: ShellType::Zsh, + shell_path: "/bin/zsh".into(), + shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), + }, + None, + None, + None, + None, + ); + + assert!(context1.equals_except_shell(&context2)); +} + +#[test] +fn serialize_environment_context_with_subagents() { + let context = EnvironmentContext::new( + Some(test_path_buf("/repo")), + fake_shell(), + Some("2026-02-26".to_string()), + Some("America/Los_Angeles".to_string()), + None, + Some("- agent-1: atlas\n- agent-2".to_string()), + ); + + let expected = format!( + r#" + {} + bash + 2026-02-26 + America/Los_Angeles + + - agent-1: atlas + - agent-2 + +"#, + test_path_buf("/repo").display() + ); + + assert_eq!(context.serialize_to_xml(), expected); +} diff --git a/codex-rs/core/src/error.rs b/codex-rs/core/src/error.rs index ad49c611fc..f3bb4dc8e5 100644 --- a/codex-rs/core/src/error.rs +++ b/codex-rs/core/src/error.rs @@ -655,492 +655,5 @@ pub fn get_error_message_ui(e: &CodexErr) -> String { } #[cfg(test)] -mod tests { - use super::*; - use crate::exec::StreamOutput; - use chrono::DateTime; - use chrono::Duration as ChronoDuration; - use chrono::TimeZone; - use chrono::Utc; - use codex_protocol::protocol::RateLimitWindow; - use pretty_assertions::assert_eq; - use reqwest::Response; - use reqwest::ResponseBuilderExt; - use reqwest::StatusCode; - use reqwest::Url; - - fn rate_limit_snapshot() -> RateLimitSnapshot { - let primary_reset_at = Utc - .with_ymd_and_hms(2024, 1, 1, 1, 0, 0) - .unwrap() - .timestamp(); - let secondary_reset_at = Utc - .with_ymd_and_hms(2024, 1, 1, 2, 0, 0) - .unwrap() - .timestamp(); - RateLimitSnapshot { - limit_id: None, - limit_name: None, - primary: Some(RateLimitWindow { - used_percent: 50.0, - window_minutes: Some(60), - resets_at: Some(primary_reset_at), - }), - secondary: Some(RateLimitWindow { - used_percent: 30.0, - window_minutes: Some(120), - resets_at: Some(secondary_reset_at), - }), - credits: None, - plan_type: None, - } - } - - fn with_now_override(now: DateTime, f: impl FnOnce() -> T) -> T { - NOW_OVERRIDE.with(|cell| { - *cell.borrow_mut() = Some(now); - let result = f(); - *cell.borrow_mut() = None; - result - }) - } - - #[test] - fn usage_limit_reached_error_formats_plus_plan() { - let err = UsageLimitReachedError { - plan_type: Some(PlanType::Known(KnownPlan::Plus)), - resets_at: None, - rate_limits: Some(Box::new(rate_limit_snapshot())), - promo_message: None, - }; - assert_eq!( - err.to_string(), - "You've hit your usage limit. Upgrade to Pro (https://chatgpt.com/explore/pro), visit https://chatgpt.com/codex/settings/usage to purchase more credits or try again later." - ); - } - - #[test] - fn server_overloaded_maps_to_protocol() { - let err = CodexErr::ServerOverloaded; - assert_eq!( - err.to_codex_protocol_error(), - CodexErrorInfo::ServerOverloaded - ); - } - - #[test] - fn sandbox_denied_uses_aggregated_output_when_stderr_empty() { - let output = ExecToolCallOutput { - exit_code: 77, - stdout: StreamOutput::new(String::new()), - stderr: StreamOutput::new(String::new()), - aggregated_output: StreamOutput::new("aggregate detail".to_string()), - duration: Duration::from_millis(10), - timed_out: false, - }; - let err = CodexErr::Sandbox(SandboxErr::Denied { - output: Box::new(output), - network_policy_decision: None, - }); - assert_eq!(get_error_message_ui(&err), "aggregate detail"); - } - - #[test] - fn sandbox_denied_reports_both_streams_when_available() { - let output = ExecToolCallOutput { - exit_code: 9, - stdout: StreamOutput::new("stdout detail".to_string()), - stderr: StreamOutput::new("stderr detail".to_string()), - aggregated_output: StreamOutput::new(String::new()), - duration: Duration::from_millis(10), - timed_out: false, - }; - let err = CodexErr::Sandbox(SandboxErr::Denied { - output: Box::new(output), - network_policy_decision: None, - }); - assert_eq!(get_error_message_ui(&err), "stderr detail\nstdout detail"); - } - - #[test] - fn sandbox_denied_reports_stdout_when_no_stderr() { - let output = ExecToolCallOutput { - exit_code: 11, - stdout: StreamOutput::new("stdout only".to_string()), - stderr: StreamOutput::new(String::new()), - aggregated_output: StreamOutput::new(String::new()), - duration: Duration::from_millis(8), - timed_out: false, - }; - let err = CodexErr::Sandbox(SandboxErr::Denied { - output: Box::new(output), - network_policy_decision: None, - }); - assert_eq!(get_error_message_ui(&err), "stdout only"); - } - - #[test] - fn to_error_event_handles_response_stream_failed() { - let response = http::Response::builder() - .status(StatusCode::TOO_MANY_REQUESTS) - .url(Url::parse("http://example.com").unwrap()) - .body("") - .unwrap(); - let source = Response::from(response).error_for_status_ref().unwrap_err(); - let err = CodexErr::ResponseStreamFailed(ResponseStreamFailed { - source, - request_id: Some("req-123".to_string()), - }); - - let event = err.to_error_event(Some("prefix".to_string())); - - assert_eq!( - event.message, - "prefix: Error while reading the server response: HTTP status client error (429 Too Many Requests) for url (http://example.com/), request id: req-123" - ); - assert_eq!( - event.codex_error_info, - Some(CodexErrorInfo::ResponseStreamConnectionFailed { - http_status_code: Some(429) - }) - ); - } - - #[test] - fn sandbox_denied_reports_exit_code_when_no_output_available() { - let output = ExecToolCallOutput { - exit_code: 13, - stdout: StreamOutput::new(String::new()), - stderr: StreamOutput::new(String::new()), - aggregated_output: StreamOutput::new(String::new()), - duration: Duration::from_millis(5), - timed_out: false, - }; - let err = CodexErr::Sandbox(SandboxErr::Denied { - output: Box::new(output), - network_policy_decision: None, - }); - assert_eq!( - get_error_message_ui(&err), - "command failed inside sandbox with exit code 13" - ); - } - - #[test] - fn usage_limit_reached_error_formats_free_plan() { - let err = UsageLimitReachedError { - plan_type: Some(PlanType::Known(KnownPlan::Free)), - resets_at: None, - rate_limits: Some(Box::new(rate_limit_snapshot())), - promo_message: None, - }; - assert_eq!( - err.to_string(), - "You've hit your usage limit. Upgrade to Plus to continue using Codex (https://chatgpt.com/explore/plus), or try again later." - ); - } - - #[test] - fn usage_limit_reached_error_formats_go_plan() { - let err = UsageLimitReachedError { - plan_type: Some(PlanType::Known(KnownPlan::Go)), - resets_at: None, - rate_limits: Some(Box::new(rate_limit_snapshot())), - promo_message: None, - }; - assert_eq!( - err.to_string(), - "You've hit your usage limit. Upgrade to Plus to continue using Codex (https://chatgpt.com/explore/plus), or try again later." - ); - } - - #[test] - fn usage_limit_reached_error_formats_default_when_none() { - let err = UsageLimitReachedError { - plan_type: None, - resets_at: None, - rate_limits: Some(Box::new(rate_limit_snapshot())), - promo_message: None, - }; - assert_eq!( - err.to_string(), - "You've hit your usage limit. Try again later." - ); - } - - #[test] - fn usage_limit_reached_error_formats_team_plan() { - let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); - let resets_at = base + ChronoDuration::hours(1); - with_now_override(base, move || { - let expected_time = format_retry_timestamp(&resets_at); - let err = UsageLimitReachedError { - plan_type: Some(PlanType::Known(KnownPlan::Team)), - resets_at: Some(resets_at), - rate_limits: Some(Box::new(rate_limit_snapshot())), - promo_message: None, - }; - let expected = format!( - "You've hit your usage limit. To get more access now, send a request to your admin or try again at {expected_time}." - ); - assert_eq!(err.to_string(), expected); - }); - } - - #[test] - fn usage_limit_reached_error_formats_business_plan_without_reset() { - let err = UsageLimitReachedError { - plan_type: Some(PlanType::Known(KnownPlan::Business)), - resets_at: None, - rate_limits: Some(Box::new(rate_limit_snapshot())), - promo_message: None, - }; - assert_eq!( - err.to_string(), - "You've hit your usage limit. To get more access now, send a request to your admin or try again later." - ); - } - - #[test] - fn usage_limit_reached_error_formats_default_for_other_plans() { - let err = UsageLimitReachedError { - plan_type: Some(PlanType::Known(KnownPlan::Enterprise)), - resets_at: None, - rate_limits: Some(Box::new(rate_limit_snapshot())), - promo_message: None, - }; - assert_eq!( - err.to_string(), - "You've hit your usage limit. Try again later." - ); - } - - #[test] - fn usage_limit_reached_error_formats_pro_plan_with_reset() { - let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); - let resets_at = base + ChronoDuration::hours(1); - with_now_override(base, move || { - let expected_time = format_retry_timestamp(&resets_at); - let err = UsageLimitReachedError { - plan_type: Some(PlanType::Known(KnownPlan::Pro)), - resets_at: Some(resets_at), - rate_limits: Some(Box::new(rate_limit_snapshot())), - promo_message: None, - }; - let expected = format!( - "You've hit your usage limit. Visit https://chatgpt.com/codex/settings/usage to purchase more credits or try again at {expected_time}." - ); - assert_eq!(err.to_string(), expected); - }); - } - - #[test] - fn usage_limit_reached_error_hides_upsell_for_non_codex_limit_name() { - let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); - let resets_at = base + ChronoDuration::hours(1); - with_now_override(base, move || { - let expected_time = format_retry_timestamp(&resets_at); - let err = UsageLimitReachedError { - plan_type: Some(PlanType::Known(KnownPlan::Plus)), - resets_at: Some(resets_at), - rate_limits: Some(Box::new(RateLimitSnapshot { - limit_id: Some("codex_other".to_string()), - limit_name: Some("codex_other".to_string()), - ..rate_limit_snapshot() - })), - promo_message: Some( - "Visit https://chatgpt.com/codex/settings/usage to purchase more credits" - .to_string(), - ), - }; - let expected = format!( - "You've hit your usage limit for codex_other. Switch to another model now, or try again at {expected_time}." - ); - assert_eq!(err.to_string(), expected); - }); - } - - #[test] - fn usage_limit_reached_includes_minutes_when_available() { - let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); - let resets_at = base + ChronoDuration::minutes(5); - with_now_override(base, move || { - let expected_time = format_retry_timestamp(&resets_at); - let err = UsageLimitReachedError { - plan_type: None, - resets_at: Some(resets_at), - rate_limits: Some(Box::new(rate_limit_snapshot())), - promo_message: None, - }; - let expected = format!("You've hit your usage limit. Try again at {expected_time}."); - assert_eq!(err.to_string(), expected); - }); - } - - #[test] - fn unexpected_status_cloudflare_html_is_simplified() { - let err = UnexpectedResponseError { - status: StatusCode::FORBIDDEN, - body: "Cloudflare error: Sorry, you have been blocked" - .to_string(), - url: Some("http://example.com/blocked".to_string()), - cf_ray: Some("ray-id".to_string()), - request_id: None, - }; - let status = StatusCode::FORBIDDEN.to_string(); - let url = "http://example.com/blocked"; - assert_eq!( - err.to_string(), - format!("{CLOUDFLARE_BLOCKED_MESSAGE} (status {status}), url: {url}, cf-ray: ray-id") - ); - } - - #[test] - fn unexpected_status_non_html_is_unchanged() { - let err = UnexpectedResponseError { - status: StatusCode::FORBIDDEN, - body: "plain text error".to_string(), - url: Some("http://example.com/plain".to_string()), - cf_ray: None, - request_id: None, - }; - let status = StatusCode::FORBIDDEN.to_string(); - let url = "http://example.com/plain"; - assert_eq!( - err.to_string(), - format!("unexpected status {status}: plain text error, url: {url}") - ); - } - - #[test] - fn unexpected_status_prefers_error_message_when_present() { - let err = UnexpectedResponseError { - status: StatusCode::UNAUTHORIZED, - body: r#"{"error":{"message":"Workspace is not authorized in this region."},"status":401}"# - .to_string(), - url: Some("https://chatgpt.com/backend-api/codex/responses".to_string()), - cf_ray: None, - request_id: Some("req-123".to_string()), - }; - let status = StatusCode::UNAUTHORIZED.to_string(); - assert_eq!( - err.to_string(), - format!( - "unexpected status {status}: Workspace is not authorized in this region., url: https://chatgpt.com/backend-api/codex/responses, request id: req-123" - ) - ); - } - - #[test] - fn unexpected_status_truncates_long_body_with_ellipsis() { - let long_body = "x".repeat(UNEXPECTED_RESPONSE_BODY_MAX_BYTES + 10); - let err = UnexpectedResponseError { - status: StatusCode::BAD_GATEWAY, - body: long_body, - url: Some("http://example.com/long".to_string()), - cf_ray: None, - request_id: Some("req-long".to_string()), - }; - let status = StatusCode::BAD_GATEWAY.to_string(); - let expected_body = format!("{}...", "x".repeat(UNEXPECTED_RESPONSE_BODY_MAX_BYTES)); - assert_eq!( - err.to_string(), - format!( - "unexpected status {status}: {expected_body}, url: http://example.com/long, request id: req-long" - ) - ); - } - - #[test] - fn unexpected_status_includes_cf_ray_and_request_id() { - let err = UnexpectedResponseError { - status: StatusCode::UNAUTHORIZED, - body: "plain text error".to_string(), - url: Some("https://chatgpt.com/backend-api/codex/responses".to_string()), - cf_ray: Some("9c81f9f18f2fa49d-LHR".to_string()), - request_id: Some("req-xyz".to_string()), - }; - let status = StatusCode::UNAUTHORIZED.to_string(); - assert_eq!( - err.to_string(), - format!( - "unexpected status {status}: plain text error, url: https://chatgpt.com/backend-api/codex/responses, cf-ray: 9c81f9f18f2fa49d-LHR, request id: req-xyz" - ) - ); - } - - #[test] - fn usage_limit_reached_includes_hours_and_minutes() { - let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); - let resets_at = base + ChronoDuration::hours(3) + ChronoDuration::minutes(32); - with_now_override(base, move || { - let expected_time = format_retry_timestamp(&resets_at); - let err = UsageLimitReachedError { - plan_type: Some(PlanType::Known(KnownPlan::Plus)), - resets_at: Some(resets_at), - rate_limits: Some(Box::new(rate_limit_snapshot())), - promo_message: None, - }; - let expected = format!( - "You've hit your usage limit. Upgrade to Pro (https://chatgpt.com/explore/pro), visit https://chatgpt.com/codex/settings/usage to purchase more credits or try again at {expected_time}." - ); - assert_eq!(err.to_string(), expected); - }); - } - - #[test] - fn usage_limit_reached_includes_days_hours_minutes() { - let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); - let resets_at = - base + ChronoDuration::days(2) + ChronoDuration::hours(3) + ChronoDuration::minutes(5); - with_now_override(base, move || { - let expected_time = format_retry_timestamp(&resets_at); - let err = UsageLimitReachedError { - plan_type: None, - resets_at: Some(resets_at), - rate_limits: Some(Box::new(rate_limit_snapshot())), - promo_message: None, - }; - let expected = format!("You've hit your usage limit. Try again at {expected_time}."); - assert_eq!(err.to_string(), expected); - }); - } - - #[test] - fn usage_limit_reached_less_than_minute() { - let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); - let resets_at = base + ChronoDuration::seconds(30); - with_now_override(base, move || { - let expected_time = format_retry_timestamp(&resets_at); - let err = UsageLimitReachedError { - plan_type: None, - resets_at: Some(resets_at), - rate_limits: Some(Box::new(rate_limit_snapshot())), - promo_message: None, - }; - let expected = format!("You've hit your usage limit. Try again at {expected_time}."); - assert_eq!(err.to_string(), expected); - }); - } - - #[test] - fn usage_limit_reached_with_promo_message() { - let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); - let resets_at = base + ChronoDuration::seconds(30); - with_now_override(base, move || { - let expected_time = format_retry_timestamp(&resets_at); - let err = UsageLimitReachedError { - plan_type: None, - resets_at: Some(resets_at), - rate_limits: Some(Box::new(rate_limit_snapshot())), - promo_message: Some( - "To continue using Codex, start a free trial of today".to_string(), - ), - }; - let expected = format!( - "You've hit your usage limit. To continue using Codex, start a free trial of today, or try again at {expected_time}." - ); - assert_eq!(err.to_string(), expected); - }); - } -} +#[path = "error_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/error_tests.rs b/codex-rs/core/src/error_tests.rs new file mode 100644 index 0000000000..fa2bd4a6e1 --- /dev/null +++ b/codex-rs/core/src/error_tests.rs @@ -0,0 +1,487 @@ +use super::*; +use crate::exec::StreamOutput; +use chrono::DateTime; +use chrono::Duration as ChronoDuration; +use chrono::TimeZone; +use chrono::Utc; +use codex_protocol::protocol::RateLimitWindow; +use pretty_assertions::assert_eq; +use reqwest::Response; +use reqwest::ResponseBuilderExt; +use reqwest::StatusCode; +use reqwest::Url; + +fn rate_limit_snapshot() -> RateLimitSnapshot { + let primary_reset_at = Utc + .with_ymd_and_hms(2024, 1, 1, 1, 0, 0) + .unwrap() + .timestamp(); + let secondary_reset_at = Utc + .with_ymd_and_hms(2024, 1, 1, 2, 0, 0) + .unwrap() + .timestamp(); + RateLimitSnapshot { + limit_id: None, + limit_name: None, + primary: Some(RateLimitWindow { + used_percent: 50.0, + window_minutes: Some(60), + resets_at: Some(primary_reset_at), + }), + secondary: Some(RateLimitWindow { + used_percent: 30.0, + window_minutes: Some(120), + resets_at: Some(secondary_reset_at), + }), + credits: None, + plan_type: None, + } +} + +fn with_now_override(now: DateTime, f: impl FnOnce() -> T) -> T { + NOW_OVERRIDE.with(|cell| { + *cell.borrow_mut() = Some(now); + let result = f(); + *cell.borrow_mut() = None; + result + }) +} + +#[test] +fn usage_limit_reached_error_formats_plus_plan() { + let err = UsageLimitReachedError { + plan_type: Some(PlanType::Known(KnownPlan::Plus)), + resets_at: None, + rate_limits: Some(Box::new(rate_limit_snapshot())), + promo_message: None, + }; + assert_eq!( + err.to_string(), + "You've hit your usage limit. Upgrade to Pro (https://chatgpt.com/explore/pro), visit https://chatgpt.com/codex/settings/usage to purchase more credits or try again later." + ); +} + +#[test] +fn server_overloaded_maps_to_protocol() { + let err = CodexErr::ServerOverloaded; + assert_eq!( + err.to_codex_protocol_error(), + CodexErrorInfo::ServerOverloaded + ); +} + +#[test] +fn sandbox_denied_uses_aggregated_output_when_stderr_empty() { + let output = ExecToolCallOutput { + exit_code: 77, + stdout: StreamOutput::new(String::new()), + stderr: StreamOutput::new(String::new()), + aggregated_output: StreamOutput::new("aggregate detail".to_string()), + duration: Duration::from_millis(10), + timed_out: false, + }; + let err = CodexErr::Sandbox(SandboxErr::Denied { + output: Box::new(output), + network_policy_decision: None, + }); + assert_eq!(get_error_message_ui(&err), "aggregate detail"); +} + +#[test] +fn sandbox_denied_reports_both_streams_when_available() { + let output = ExecToolCallOutput { + exit_code: 9, + stdout: StreamOutput::new("stdout detail".to_string()), + stderr: StreamOutput::new("stderr detail".to_string()), + aggregated_output: StreamOutput::new(String::new()), + duration: Duration::from_millis(10), + timed_out: false, + }; + let err = CodexErr::Sandbox(SandboxErr::Denied { + output: Box::new(output), + network_policy_decision: None, + }); + assert_eq!(get_error_message_ui(&err), "stderr detail\nstdout detail"); +} + +#[test] +fn sandbox_denied_reports_stdout_when_no_stderr() { + let output = ExecToolCallOutput { + exit_code: 11, + stdout: StreamOutput::new("stdout only".to_string()), + stderr: StreamOutput::new(String::new()), + aggregated_output: StreamOutput::new(String::new()), + duration: Duration::from_millis(8), + timed_out: false, + }; + let err = CodexErr::Sandbox(SandboxErr::Denied { + output: Box::new(output), + network_policy_decision: None, + }); + assert_eq!(get_error_message_ui(&err), "stdout only"); +} + +#[test] +fn to_error_event_handles_response_stream_failed() { + let response = http::Response::builder() + .status(StatusCode::TOO_MANY_REQUESTS) + .url(Url::parse("http://example.com").unwrap()) + .body("") + .unwrap(); + let source = Response::from(response).error_for_status_ref().unwrap_err(); + let err = CodexErr::ResponseStreamFailed(ResponseStreamFailed { + source, + request_id: Some("req-123".to_string()), + }); + + let event = err.to_error_event(Some("prefix".to_string())); + + assert_eq!( + event.message, + "prefix: Error while reading the server response: HTTP status client error (429 Too Many Requests) for url (http://example.com/), request id: req-123" + ); + assert_eq!( + event.codex_error_info, + Some(CodexErrorInfo::ResponseStreamConnectionFailed { + http_status_code: Some(429) + }) + ); +} + +#[test] +fn sandbox_denied_reports_exit_code_when_no_output_available() { + let output = ExecToolCallOutput { + exit_code: 13, + stdout: StreamOutput::new(String::new()), + stderr: StreamOutput::new(String::new()), + aggregated_output: StreamOutput::new(String::new()), + duration: Duration::from_millis(5), + timed_out: false, + }; + let err = CodexErr::Sandbox(SandboxErr::Denied { + output: Box::new(output), + network_policy_decision: None, + }); + assert_eq!( + get_error_message_ui(&err), + "command failed inside sandbox with exit code 13" + ); +} + +#[test] +fn usage_limit_reached_error_formats_free_plan() { + let err = UsageLimitReachedError { + plan_type: Some(PlanType::Known(KnownPlan::Free)), + resets_at: None, + rate_limits: Some(Box::new(rate_limit_snapshot())), + promo_message: None, + }; + assert_eq!( + err.to_string(), + "You've hit your usage limit. Upgrade to Plus to continue using Codex (https://chatgpt.com/explore/plus), or try again later." + ); +} + +#[test] +fn usage_limit_reached_error_formats_go_plan() { + let err = UsageLimitReachedError { + plan_type: Some(PlanType::Known(KnownPlan::Go)), + resets_at: None, + rate_limits: Some(Box::new(rate_limit_snapshot())), + promo_message: None, + }; + assert_eq!( + err.to_string(), + "You've hit your usage limit. Upgrade to Plus to continue using Codex (https://chatgpt.com/explore/plus), or try again later." + ); +} + +#[test] +fn usage_limit_reached_error_formats_default_when_none() { + let err = UsageLimitReachedError { + plan_type: None, + resets_at: None, + rate_limits: Some(Box::new(rate_limit_snapshot())), + promo_message: None, + }; + assert_eq!( + err.to_string(), + "You've hit your usage limit. Try again later." + ); +} + +#[test] +fn usage_limit_reached_error_formats_team_plan() { + let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); + let resets_at = base + ChronoDuration::hours(1); + with_now_override(base, move || { + let expected_time = format_retry_timestamp(&resets_at); + let err = UsageLimitReachedError { + plan_type: Some(PlanType::Known(KnownPlan::Team)), + resets_at: Some(resets_at), + rate_limits: Some(Box::new(rate_limit_snapshot())), + promo_message: None, + }; + let expected = format!( + "You've hit your usage limit. To get more access now, send a request to your admin or try again at {expected_time}." + ); + assert_eq!(err.to_string(), expected); + }); +} + +#[test] +fn usage_limit_reached_error_formats_business_plan_without_reset() { + let err = UsageLimitReachedError { + plan_type: Some(PlanType::Known(KnownPlan::Business)), + resets_at: None, + rate_limits: Some(Box::new(rate_limit_snapshot())), + promo_message: None, + }; + assert_eq!( + err.to_string(), + "You've hit your usage limit. To get more access now, send a request to your admin or try again later." + ); +} + +#[test] +fn usage_limit_reached_error_formats_default_for_other_plans() { + let err = UsageLimitReachedError { + plan_type: Some(PlanType::Known(KnownPlan::Enterprise)), + resets_at: None, + rate_limits: Some(Box::new(rate_limit_snapshot())), + promo_message: None, + }; + assert_eq!( + err.to_string(), + "You've hit your usage limit. Try again later." + ); +} + +#[test] +fn usage_limit_reached_error_formats_pro_plan_with_reset() { + let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); + let resets_at = base + ChronoDuration::hours(1); + with_now_override(base, move || { + let expected_time = format_retry_timestamp(&resets_at); + let err = UsageLimitReachedError { + plan_type: Some(PlanType::Known(KnownPlan::Pro)), + resets_at: Some(resets_at), + rate_limits: Some(Box::new(rate_limit_snapshot())), + promo_message: None, + }; + let expected = format!( + "You've hit your usage limit. Visit https://chatgpt.com/codex/settings/usage to purchase more credits or try again at {expected_time}." + ); + assert_eq!(err.to_string(), expected); + }); +} + +#[test] +fn usage_limit_reached_error_hides_upsell_for_non_codex_limit_name() { + let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); + let resets_at = base + ChronoDuration::hours(1); + with_now_override(base, move || { + let expected_time = format_retry_timestamp(&resets_at); + let err = UsageLimitReachedError { + plan_type: Some(PlanType::Known(KnownPlan::Plus)), + resets_at: Some(resets_at), + rate_limits: Some(Box::new(RateLimitSnapshot { + limit_id: Some("codex_other".to_string()), + limit_name: Some("codex_other".to_string()), + ..rate_limit_snapshot() + })), + promo_message: Some( + "Visit https://chatgpt.com/codex/settings/usage to purchase more credits" + .to_string(), + ), + }; + let expected = format!( + "You've hit your usage limit for codex_other. Switch to another model now, or try again at {expected_time}." + ); + assert_eq!(err.to_string(), expected); + }); +} + +#[test] +fn usage_limit_reached_includes_minutes_when_available() { + let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); + let resets_at = base + ChronoDuration::minutes(5); + with_now_override(base, move || { + let expected_time = format_retry_timestamp(&resets_at); + let err = UsageLimitReachedError { + plan_type: None, + resets_at: Some(resets_at), + rate_limits: Some(Box::new(rate_limit_snapshot())), + promo_message: None, + }; + let expected = format!("You've hit your usage limit. Try again at {expected_time}."); + assert_eq!(err.to_string(), expected); + }); +} + +#[test] +fn unexpected_status_cloudflare_html_is_simplified() { + let err = UnexpectedResponseError { + status: StatusCode::FORBIDDEN, + body: "Cloudflare error: Sorry, you have been blocked" + .to_string(), + url: Some("http://example.com/blocked".to_string()), + cf_ray: Some("ray-id".to_string()), + request_id: None, + }; + let status = StatusCode::FORBIDDEN.to_string(); + let url = "http://example.com/blocked"; + assert_eq!( + err.to_string(), + format!("{CLOUDFLARE_BLOCKED_MESSAGE} (status {status}), url: {url}, cf-ray: ray-id") + ); +} + +#[test] +fn unexpected_status_non_html_is_unchanged() { + let err = UnexpectedResponseError { + status: StatusCode::FORBIDDEN, + body: "plain text error".to_string(), + url: Some("http://example.com/plain".to_string()), + cf_ray: None, + request_id: None, + }; + let status = StatusCode::FORBIDDEN.to_string(); + let url = "http://example.com/plain"; + assert_eq!( + err.to_string(), + format!("unexpected status {status}: plain text error, url: {url}") + ); +} + +#[test] +fn unexpected_status_prefers_error_message_when_present() { + let err = UnexpectedResponseError { + status: StatusCode::UNAUTHORIZED, + body: r#"{"error":{"message":"Workspace is not authorized in this region."},"status":401}"# + .to_string(), + url: Some("https://chatgpt.com/backend-api/codex/responses".to_string()), + cf_ray: None, + request_id: Some("req-123".to_string()), + }; + let status = StatusCode::UNAUTHORIZED.to_string(); + assert_eq!( + err.to_string(), + format!( + "unexpected status {status}: Workspace is not authorized in this region., url: https://chatgpt.com/backend-api/codex/responses, request id: req-123" + ) + ); +} + +#[test] +fn unexpected_status_truncates_long_body_with_ellipsis() { + let long_body = "x".repeat(UNEXPECTED_RESPONSE_BODY_MAX_BYTES + 10); + let err = UnexpectedResponseError { + status: StatusCode::BAD_GATEWAY, + body: long_body, + url: Some("http://example.com/long".to_string()), + cf_ray: None, + request_id: Some("req-long".to_string()), + }; + let status = StatusCode::BAD_GATEWAY.to_string(); + let expected_body = format!("{}...", "x".repeat(UNEXPECTED_RESPONSE_BODY_MAX_BYTES)); + assert_eq!( + err.to_string(), + format!( + "unexpected status {status}: {expected_body}, url: http://example.com/long, request id: req-long" + ) + ); +} + +#[test] +fn unexpected_status_includes_cf_ray_and_request_id() { + let err = UnexpectedResponseError { + status: StatusCode::UNAUTHORIZED, + body: "plain text error".to_string(), + url: Some("https://chatgpt.com/backend-api/codex/responses".to_string()), + cf_ray: Some("9c81f9f18f2fa49d-LHR".to_string()), + request_id: Some("req-xyz".to_string()), + }; + let status = StatusCode::UNAUTHORIZED.to_string(); + assert_eq!( + err.to_string(), + format!( + "unexpected status {status}: plain text error, url: https://chatgpt.com/backend-api/codex/responses, cf-ray: 9c81f9f18f2fa49d-LHR, request id: req-xyz" + ) + ); +} + +#[test] +fn usage_limit_reached_includes_hours_and_minutes() { + let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); + let resets_at = base + ChronoDuration::hours(3) + ChronoDuration::minutes(32); + with_now_override(base, move || { + let expected_time = format_retry_timestamp(&resets_at); + let err = UsageLimitReachedError { + plan_type: Some(PlanType::Known(KnownPlan::Plus)), + resets_at: Some(resets_at), + rate_limits: Some(Box::new(rate_limit_snapshot())), + promo_message: None, + }; + let expected = format!( + "You've hit your usage limit. Upgrade to Pro (https://chatgpt.com/explore/pro), visit https://chatgpt.com/codex/settings/usage to purchase more credits or try again at {expected_time}." + ); + assert_eq!(err.to_string(), expected); + }); +} + +#[test] +fn usage_limit_reached_includes_days_hours_minutes() { + let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); + let resets_at = + base + ChronoDuration::days(2) + ChronoDuration::hours(3) + ChronoDuration::minutes(5); + with_now_override(base, move || { + let expected_time = format_retry_timestamp(&resets_at); + let err = UsageLimitReachedError { + plan_type: None, + resets_at: Some(resets_at), + rate_limits: Some(Box::new(rate_limit_snapshot())), + promo_message: None, + }; + let expected = format!("You've hit your usage limit. Try again at {expected_time}."); + assert_eq!(err.to_string(), expected); + }); +} + +#[test] +fn usage_limit_reached_less_than_minute() { + let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); + let resets_at = base + ChronoDuration::seconds(30); + with_now_override(base, move || { + let expected_time = format_retry_timestamp(&resets_at); + let err = UsageLimitReachedError { + plan_type: None, + resets_at: Some(resets_at), + rate_limits: Some(Box::new(rate_limit_snapshot())), + promo_message: None, + }; + let expected = format!("You've hit your usage limit. Try again at {expected_time}."); + assert_eq!(err.to_string(), expected); + }); +} + +#[test] +fn usage_limit_reached_with_promo_message() { + let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); + let resets_at = base + ChronoDuration::seconds(30); + with_now_override(base, move || { + let expected_time = format_retry_timestamp(&resets_at); + let err = UsageLimitReachedError { + plan_type: None, + resets_at: Some(resets_at), + rate_limits: Some(Box::new(rate_limit_snapshot())), + promo_message: Some( + "To continue using Codex, start a free trial of today".to_string(), + ), + }; + let expected = format!( + "You've hit your usage limit. To continue using Codex, start a free trial of today, or try again at {expected_time}." + ); + assert_eq!(err.to_string(), expected); + }); +} diff --git a/codex-rs/core/src/event_mapping.rs b/codex-rs/core/src/event_mapping.rs index 09f1235718..72372b24cd 100644 --- a/codex-rs/core/src/event_mapping.rs +++ b/codex-rs/core/src/event_mapping.rs @@ -161,410 +161,5 @@ pub fn parse_turn_item(item: &ResponseItem) -> Option { } #[cfg(test)] -mod tests { - use super::parse_turn_item; - use codex_protocol::items::AgentMessageContent; - use codex_protocol::items::TurnItem; - use codex_protocol::items::WebSearchItem; - use codex_protocol::models::ContentItem; - use codex_protocol::models::ReasoningItemContent; - use codex_protocol::models::ReasoningItemReasoningSummary; - use codex_protocol::models::ResponseItem; - use codex_protocol::models::WebSearchAction; - use codex_protocol::user_input::UserInput; - use pretty_assertions::assert_eq; - - #[test] - fn parses_user_message_with_text_and_two_images() { - let img1 = "https://example.com/one.png".to_string(); - let img2 = "https://example.com/two.jpg".to_string(); - - let item = ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ - ContentItem::InputText { - text: "Hello world".to_string(), - }, - ContentItem::InputImage { - image_url: img1.clone(), - }, - ContentItem::InputImage { - image_url: img2.clone(), - }, - ], - end_turn: None, - phase: None, - }; - - let turn_item = parse_turn_item(&item).expect("expected user message turn item"); - - match turn_item { - TurnItem::UserMessage(user) => { - let expected_content = vec![ - UserInput::Text { - text: "Hello world".to_string(), - text_elements: Vec::new(), - }, - UserInput::Image { image_url: img1 }, - UserInput::Image { image_url: img2 }, - ]; - assert_eq!(user.content, expected_content); - } - other => panic!("expected TurnItem::UserMessage, got {other:?}"), - } - } - - #[test] - fn skips_local_image_label_text() { - let image_url = "data:image/png;base64,abc".to_string(); - let label = codex_protocol::models::local_image_open_tag_text(1); - let user_text = "Please review this image.".to_string(); - - let item = ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ - ContentItem::InputText { text: label }, - ContentItem::InputImage { - image_url: image_url.clone(), - }, - ContentItem::InputText { - text: "".to_string(), - }, - ContentItem::InputText { - text: user_text.clone(), - }, - ], - end_turn: None, - phase: None, - }; - - let turn_item = parse_turn_item(&item).expect("expected user message turn item"); - - match turn_item { - TurnItem::UserMessage(user) => { - let expected_content = vec![ - UserInput::Image { image_url }, - UserInput::Text { - text: user_text, - text_elements: Vec::new(), - }, - ]; - assert_eq!(user.content, expected_content); - } - other => panic!("expected TurnItem::UserMessage, got {other:?}"), - } - } - - #[test] - fn skips_unnamed_image_label_text() { - let image_url = "data:image/png;base64,abc".to_string(); - let label = codex_protocol::models::image_open_tag_text(); - let user_text = "Please review this image.".to_string(); - - let item = ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ - ContentItem::InputText { text: label }, - ContentItem::InputImage { - image_url: image_url.clone(), - }, - ContentItem::InputText { - text: codex_protocol::models::image_close_tag_text(), - }, - ContentItem::InputText { - text: user_text.clone(), - }, - ], - end_turn: None, - phase: None, - }; - - let turn_item = parse_turn_item(&item).expect("expected user message turn item"); - - match turn_item { - TurnItem::UserMessage(user) => { - let expected_content = vec![ - UserInput::Image { image_url }, - UserInput::Text { - text: user_text, - text_elements: Vec::new(), - }, - ]; - assert_eq!(user.content, expected_content); - } - other => panic!("expected TurnItem::UserMessage, got {other:?}"), - } - } - - #[test] - fn skips_user_instructions_and_env() { - let items = vec![ - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "# AGENTS.md instructions for test_directory\n\n\ntest_text\n".to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "test_text".to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "# AGENTS.md instructions for test_directory\n\n\ntest_text\n".to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "\ndemo\nskills/demo/SKILL.md\nbody\n" - .to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "echo 42".to_string(), - }], - end_turn: None, - phase: None, - }, - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ - ContentItem::InputText { - text: "ctx".to_string(), - }, - ContentItem::InputText { - text: - "# AGENTS.md instructions for dir\n\n\nbody\n" - .to_string(), - }, - ], - end_turn: None, - phase: None, - }, - ]; - - for item in items { - let turn_item = parse_turn_item(&item); - assert!(turn_item.is_none(), "expected none, got {turn_item:?}"); - } - } - - #[test] - fn parses_agent_message() { - let item = ResponseItem::Message { - id: Some("msg-1".to_string()), - role: "assistant".to_string(), - content: vec![ContentItem::OutputText { - text: "Hello from Codex".to_string(), - }], - end_turn: None, - phase: None, - }; - - let turn_item = parse_turn_item(&item).expect("expected agent message turn item"); - - match turn_item { - TurnItem::AgentMessage(message) => { - let Some(AgentMessageContent::Text { text }) = message.content.first() else { - panic!("expected agent message text content"); - }; - assert_eq!(text, "Hello from Codex"); - } - other => panic!("expected TurnItem::AgentMessage, got {other:?}"), - } - } - - #[test] - fn parses_reasoning_summary_and_raw_content() { - let item = ResponseItem::Reasoning { - id: "reasoning_1".to_string(), - summary: vec![ - ReasoningItemReasoningSummary::SummaryText { - text: "Step 1".to_string(), - }, - ReasoningItemReasoningSummary::SummaryText { - text: "Step 2".to_string(), - }, - ], - content: Some(vec![ReasoningItemContent::ReasoningText { - text: "raw details".to_string(), - }]), - encrypted_content: None, - }; - - let turn_item = parse_turn_item(&item).expect("expected reasoning turn item"); - - match turn_item { - TurnItem::Reasoning(reasoning) => { - assert_eq!( - reasoning.summary_text, - vec!["Step 1".to_string(), "Step 2".to_string()] - ); - assert_eq!(reasoning.raw_content, vec!["raw details".to_string()]); - } - other => panic!("expected TurnItem::Reasoning, got {other:?}"), - } - } - - #[test] - fn parses_reasoning_including_raw_content() { - let item = ResponseItem::Reasoning { - id: "reasoning_2".to_string(), - summary: vec![ReasoningItemReasoningSummary::SummaryText { - text: "Summarized step".to_string(), - }], - content: Some(vec![ - ReasoningItemContent::ReasoningText { - text: "raw step".to_string(), - }, - ReasoningItemContent::Text { - text: "final thought".to_string(), - }, - ]), - encrypted_content: None, - }; - - let turn_item = parse_turn_item(&item).expect("expected reasoning turn item"); - - match turn_item { - TurnItem::Reasoning(reasoning) => { - assert_eq!(reasoning.summary_text, vec!["Summarized step".to_string()]); - assert_eq!( - reasoning.raw_content, - vec!["raw step".to_string(), "final thought".to_string()] - ); - } - other => panic!("expected TurnItem::Reasoning, got {other:?}"), - } - } - - #[test] - fn parses_web_search_call() { - let item = ResponseItem::WebSearchCall { - id: Some("ws_1".to_string()), - status: Some("completed".to_string()), - action: Some(WebSearchAction::Search { - query: Some("weather".to_string()), - queries: None, - }), - }; - - let turn_item = parse_turn_item(&item).expect("expected web search turn item"); - - match turn_item { - TurnItem::WebSearch(search) => assert_eq!( - search, - WebSearchItem { - id: "ws_1".to_string(), - query: "weather".to_string(), - action: WebSearchAction::Search { - query: Some("weather".to_string()), - queries: None, - }, - } - ), - other => panic!("expected TurnItem::WebSearch, got {other:?}"), - } - } - - #[test] - fn parses_web_search_open_page_call() { - let item = ResponseItem::WebSearchCall { - id: Some("ws_open".to_string()), - status: Some("completed".to_string()), - action: Some(WebSearchAction::OpenPage { - url: Some("https://example.com".to_string()), - }), - }; - - let turn_item = parse_turn_item(&item).expect("expected web search turn item"); - - match turn_item { - TurnItem::WebSearch(search) => assert_eq!( - search, - WebSearchItem { - id: "ws_open".to_string(), - query: "https://example.com".to_string(), - action: WebSearchAction::OpenPage { - url: Some("https://example.com".to_string()), - }, - } - ), - other => panic!("expected TurnItem::WebSearch, got {other:?}"), - } - } - - #[test] - fn parses_web_search_find_in_page_call() { - let item = ResponseItem::WebSearchCall { - id: Some("ws_find".to_string()), - status: Some("completed".to_string()), - action: Some(WebSearchAction::FindInPage { - url: Some("https://example.com".to_string()), - pattern: Some("needle".to_string()), - }), - }; - - let turn_item = parse_turn_item(&item).expect("expected web search turn item"); - - match turn_item { - TurnItem::WebSearch(search) => assert_eq!( - search, - WebSearchItem { - id: "ws_find".to_string(), - query: "'needle' in https://example.com".to_string(), - action: WebSearchAction::FindInPage { - url: Some("https://example.com".to_string()), - pattern: Some("needle".to_string()), - }, - } - ), - other => panic!("expected TurnItem::WebSearch, got {other:?}"), - } - } - - #[test] - fn parses_partial_web_search_call_without_action_as_other() { - let item = ResponseItem::WebSearchCall { - id: Some("ws_partial".to_string()), - status: Some("in_progress".to_string()), - action: None, - }; - - let turn_item = parse_turn_item(&item).expect("expected web search turn item"); - match turn_item { - TurnItem::WebSearch(search) => assert_eq!( - search, - WebSearchItem { - id: "ws_partial".to_string(), - query: String::new(), - action: WebSearchAction::Other, - } - ), - other => panic!("expected TurnItem::WebSearch, got {other:?}"), - } - } -} +#[path = "event_mapping_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/event_mapping_tests.rs b/codex-rs/core/src/event_mapping_tests.rs new file mode 100644 index 0000000000..7a9b7076be --- /dev/null +++ b/codex-rs/core/src/event_mapping_tests.rs @@ -0,0 +1,405 @@ +use super::parse_turn_item; +use codex_protocol::items::AgentMessageContent; +use codex_protocol::items::TurnItem; +use codex_protocol::items::WebSearchItem; +use codex_protocol::models::ContentItem; +use codex_protocol::models::ReasoningItemContent; +use codex_protocol::models::ReasoningItemReasoningSummary; +use codex_protocol::models::ResponseItem; +use codex_protocol::models::WebSearchAction; +use codex_protocol::user_input::UserInput; +use pretty_assertions::assert_eq; + +#[test] +fn parses_user_message_with_text_and_two_images() { + let img1 = "https://example.com/one.png".to_string(); + let img2 = "https://example.com/two.jpg".to_string(); + + let item = ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ + ContentItem::InputText { + text: "Hello world".to_string(), + }, + ContentItem::InputImage { + image_url: img1.clone(), + }, + ContentItem::InputImage { + image_url: img2.clone(), + }, + ], + end_turn: None, + phase: None, + }; + + let turn_item = parse_turn_item(&item).expect("expected user message turn item"); + + match turn_item { + TurnItem::UserMessage(user) => { + let expected_content = vec![ + UserInput::Text { + text: "Hello world".to_string(), + text_elements: Vec::new(), + }, + UserInput::Image { image_url: img1 }, + UserInput::Image { image_url: img2 }, + ]; + assert_eq!(user.content, expected_content); + } + other => panic!("expected TurnItem::UserMessage, got {other:?}"), + } +} + +#[test] +fn skips_local_image_label_text() { + let image_url = "data:image/png;base64,abc".to_string(); + let label = codex_protocol::models::local_image_open_tag_text(1); + let user_text = "Please review this image.".to_string(); + + let item = ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ + ContentItem::InputText { text: label }, + ContentItem::InputImage { + image_url: image_url.clone(), + }, + ContentItem::InputText { + text: "".to_string(), + }, + ContentItem::InputText { + text: user_text.clone(), + }, + ], + end_turn: None, + phase: None, + }; + + let turn_item = parse_turn_item(&item).expect("expected user message turn item"); + + match turn_item { + TurnItem::UserMessage(user) => { + let expected_content = vec![ + UserInput::Image { image_url }, + UserInput::Text { + text: user_text, + text_elements: Vec::new(), + }, + ]; + assert_eq!(user.content, expected_content); + } + other => panic!("expected TurnItem::UserMessage, got {other:?}"), + } +} + +#[test] +fn skips_unnamed_image_label_text() { + let image_url = "data:image/png;base64,abc".to_string(); + let label = codex_protocol::models::image_open_tag_text(); + let user_text = "Please review this image.".to_string(); + + let item = ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ + ContentItem::InputText { text: label }, + ContentItem::InputImage { + image_url: image_url.clone(), + }, + ContentItem::InputText { + text: codex_protocol::models::image_close_tag_text(), + }, + ContentItem::InputText { + text: user_text.clone(), + }, + ], + end_turn: None, + phase: None, + }; + + let turn_item = parse_turn_item(&item).expect("expected user message turn item"); + + match turn_item { + TurnItem::UserMessage(user) => { + let expected_content = vec![ + UserInput::Image { image_url }, + UserInput::Text { + text: user_text, + text_elements: Vec::new(), + }, + ]; + assert_eq!(user.content, expected_content); + } + other => panic!("expected TurnItem::UserMessage, got {other:?}"), + } +} + +#[test] +fn skips_user_instructions_and_env() { + let items = vec![ + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "# AGENTS.md instructions for test_directory\n\n\ntest_text\n".to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "test_text".to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "# AGENTS.md instructions for test_directory\n\n\ntest_text\n".to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "\ndemo\nskills/demo/SKILL.md\nbody\n" + .to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "echo 42".to_string(), + }], + end_turn: None, + phase: None, + }, + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ + ContentItem::InputText { + text: "ctx".to_string(), + }, + ContentItem::InputText { + text: + "# AGENTS.md instructions for dir\n\n\nbody\n" + .to_string(), + }, + ], + end_turn: None, + phase: None, + }, + ]; + + for item in items { + let turn_item = parse_turn_item(&item); + assert!(turn_item.is_none(), "expected none, got {turn_item:?}"); + } +} + +#[test] +fn parses_agent_message() { + let item = ResponseItem::Message { + id: Some("msg-1".to_string()), + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: "Hello from Codex".to_string(), + }], + end_turn: None, + phase: None, + }; + + let turn_item = parse_turn_item(&item).expect("expected agent message turn item"); + + match turn_item { + TurnItem::AgentMessage(message) => { + let Some(AgentMessageContent::Text { text }) = message.content.first() else { + panic!("expected agent message text content"); + }; + assert_eq!(text, "Hello from Codex"); + } + other => panic!("expected TurnItem::AgentMessage, got {other:?}"), + } +} + +#[test] +fn parses_reasoning_summary_and_raw_content() { + let item = ResponseItem::Reasoning { + id: "reasoning_1".to_string(), + summary: vec![ + ReasoningItemReasoningSummary::SummaryText { + text: "Step 1".to_string(), + }, + ReasoningItemReasoningSummary::SummaryText { + text: "Step 2".to_string(), + }, + ], + content: Some(vec![ReasoningItemContent::ReasoningText { + text: "raw details".to_string(), + }]), + encrypted_content: None, + }; + + let turn_item = parse_turn_item(&item).expect("expected reasoning turn item"); + + match turn_item { + TurnItem::Reasoning(reasoning) => { + assert_eq!( + reasoning.summary_text, + vec!["Step 1".to_string(), "Step 2".to_string()] + ); + assert_eq!(reasoning.raw_content, vec!["raw details".to_string()]); + } + other => panic!("expected TurnItem::Reasoning, got {other:?}"), + } +} + +#[test] +fn parses_reasoning_including_raw_content() { + let item = ResponseItem::Reasoning { + id: "reasoning_2".to_string(), + summary: vec![ReasoningItemReasoningSummary::SummaryText { + text: "Summarized step".to_string(), + }], + content: Some(vec![ + ReasoningItemContent::ReasoningText { + text: "raw step".to_string(), + }, + ReasoningItemContent::Text { + text: "final thought".to_string(), + }, + ]), + encrypted_content: None, + }; + + let turn_item = parse_turn_item(&item).expect("expected reasoning turn item"); + + match turn_item { + TurnItem::Reasoning(reasoning) => { + assert_eq!(reasoning.summary_text, vec!["Summarized step".to_string()]); + assert_eq!( + reasoning.raw_content, + vec!["raw step".to_string(), "final thought".to_string()] + ); + } + other => panic!("expected TurnItem::Reasoning, got {other:?}"), + } +} + +#[test] +fn parses_web_search_call() { + let item = ResponseItem::WebSearchCall { + id: Some("ws_1".to_string()), + status: Some("completed".to_string()), + action: Some(WebSearchAction::Search { + query: Some("weather".to_string()), + queries: None, + }), + }; + + let turn_item = parse_turn_item(&item).expect("expected web search turn item"); + + match turn_item { + TurnItem::WebSearch(search) => assert_eq!( + search, + WebSearchItem { + id: "ws_1".to_string(), + query: "weather".to_string(), + action: WebSearchAction::Search { + query: Some("weather".to_string()), + queries: None, + }, + } + ), + other => panic!("expected TurnItem::WebSearch, got {other:?}"), + } +} + +#[test] +fn parses_web_search_open_page_call() { + let item = ResponseItem::WebSearchCall { + id: Some("ws_open".to_string()), + status: Some("completed".to_string()), + action: Some(WebSearchAction::OpenPage { + url: Some("https://example.com".to_string()), + }), + }; + + let turn_item = parse_turn_item(&item).expect("expected web search turn item"); + + match turn_item { + TurnItem::WebSearch(search) => assert_eq!( + search, + WebSearchItem { + id: "ws_open".to_string(), + query: "https://example.com".to_string(), + action: WebSearchAction::OpenPage { + url: Some("https://example.com".to_string()), + }, + } + ), + other => panic!("expected TurnItem::WebSearch, got {other:?}"), + } +} + +#[test] +fn parses_web_search_find_in_page_call() { + let item = ResponseItem::WebSearchCall { + id: Some("ws_find".to_string()), + status: Some("completed".to_string()), + action: Some(WebSearchAction::FindInPage { + url: Some("https://example.com".to_string()), + pattern: Some("needle".to_string()), + }), + }; + + let turn_item = parse_turn_item(&item).expect("expected web search turn item"); + + match turn_item { + TurnItem::WebSearch(search) => assert_eq!( + search, + WebSearchItem { + id: "ws_find".to_string(), + query: "'needle' in https://example.com".to_string(), + action: WebSearchAction::FindInPage { + url: Some("https://example.com".to_string()), + pattern: Some("needle".to_string()), + }, + } + ), + other => panic!("expected TurnItem::WebSearch, got {other:?}"), + } +} + +#[test] +fn parses_partial_web_search_call_without_action_as_other() { + let item = ResponseItem::WebSearchCall { + id: Some("ws_partial".to_string()), + status: Some("in_progress".to_string()), + action: None, + }; + + let turn_item = parse_turn_item(&item).expect("expected web search turn item"); + match turn_item { + TurnItem::WebSearch(search) => assert_eq!( + search, + WebSearchItem { + id: "ws_partial".to_string(), + query: String::new(), + action: WebSearchAction::Other, + } + ), + other => panic!("expected TurnItem::WebSearch, got {other:?}"), + } +} diff --git a/codex-rs/core/src/exec.rs b/codex-rs/core/src/exec.rs index 1b9456fb87..867b93ab5a 100644 --- a/codex-rs/core/src/exec.rs +++ b/codex-rs/core/src/exec.rs @@ -1003,428 +1003,5 @@ fn synthetic_exit_status(code: i32) -> ExitStatus { } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - use std::time::Duration; - use tokio::io::AsyncWriteExt; - - fn make_exec_output( - exit_code: i32, - stdout: &str, - stderr: &str, - aggregated: &str, - ) -> ExecToolCallOutput { - ExecToolCallOutput { - exit_code, - stdout: StreamOutput::new(stdout.to_string()), - stderr: StreamOutput::new(stderr.to_string()), - aggregated_output: StreamOutput::new(aggregated.to_string()), - duration: Duration::from_millis(1), - timed_out: false, - } - } - - #[test] - fn sandbox_detection_requires_keywords() { - let output = make_exec_output(1, "", "", ""); - assert!(!is_likely_sandbox_denied( - SandboxType::LinuxSeccomp, - &output - )); - } - - #[test] - fn sandbox_detection_identifies_keyword_in_stderr() { - let output = make_exec_output(1, "", "Operation not permitted", ""); - assert!(is_likely_sandbox_denied(SandboxType::LinuxSeccomp, &output)); - } - - #[test] - fn sandbox_detection_respects_quick_reject_exit_codes() { - let output = make_exec_output(127, "", "command not found", ""); - assert!(!is_likely_sandbox_denied( - SandboxType::LinuxSeccomp, - &output - )); - } - - #[test] - fn sandbox_detection_ignores_non_sandbox_mode() { - let output = make_exec_output(1, "", "Operation not permitted", ""); - assert!(!is_likely_sandbox_denied(SandboxType::None, &output)); - } - - #[test] - fn sandbox_detection_ignores_network_policy_text_in_non_sandbox_mode() { - let output = make_exec_output( - 0, - "", - "", - r#"CODEX_NETWORK_POLICY_DECISION {"decision":"ask","reason":"not_allowed","source":"decider","protocol":"http","host":"google.com","port":80}"#, - ); - assert!(!is_likely_sandbox_denied(SandboxType::None, &output)); - } - - #[test] - fn sandbox_detection_uses_aggregated_output() { - let output = make_exec_output( - 101, - "", - "", - "cargo failed: Read-only file system when writing target", - ); - assert!(is_likely_sandbox_denied( - SandboxType::MacosSeatbelt, - &output - )); - } - - #[test] - fn sandbox_detection_ignores_network_policy_text_with_zero_exit_code() { - let output = make_exec_output( - 0, - "", - "", - r#"CODEX_NETWORK_POLICY_DECISION {"decision":"ask","source":"decider","protocol":"http","host":"google.com","port":80}"#, - ); - - assert!(!is_likely_sandbox_denied( - SandboxType::LinuxSeccomp, - &output - )); - } - - #[tokio::test] - async fn read_capped_limits_retained_bytes() { - let (mut writer, reader) = tokio::io::duplex(1024); - let bytes = vec![b'a'; EXEC_OUTPUT_MAX_BYTES.saturating_add(128 * 1024)]; - tokio::spawn(async move { - writer.write_all(&bytes).await.expect("write"); - }); - - let out = read_capped(reader, None, false).await.expect("read"); - assert_eq!(out.text.len(), EXEC_OUTPUT_MAX_BYTES); - } - - #[test] - fn aggregate_output_prefers_stderr_on_contention() { - let stdout = StreamOutput { - text: vec![b'a'; EXEC_OUTPUT_MAX_BYTES], - truncated_after_lines: None, - }; - let stderr = StreamOutput { - text: vec![b'b'; EXEC_OUTPUT_MAX_BYTES], - truncated_after_lines: None, - }; - - let aggregated = aggregate_output(&stdout, &stderr); - let stdout_cap = EXEC_OUTPUT_MAX_BYTES / 3; - let stderr_cap = EXEC_OUTPUT_MAX_BYTES.saturating_sub(stdout_cap); - - assert_eq!(aggregated.text.len(), EXEC_OUTPUT_MAX_BYTES); - assert_eq!(aggregated.text[..stdout_cap], vec![b'a'; stdout_cap]); - assert_eq!(aggregated.text[stdout_cap..], vec![b'b'; stderr_cap]); - } - - #[test] - fn aggregate_output_fills_remaining_capacity_with_stderr() { - let stdout_len = EXEC_OUTPUT_MAX_BYTES / 10; - let stdout = StreamOutput { - text: vec![b'a'; stdout_len], - truncated_after_lines: None, - }; - let stderr = StreamOutput { - text: vec![b'b'; EXEC_OUTPUT_MAX_BYTES], - truncated_after_lines: None, - }; - - let aggregated = aggregate_output(&stdout, &stderr); - let stderr_cap = EXEC_OUTPUT_MAX_BYTES.saturating_sub(stdout_len); - - assert_eq!(aggregated.text.len(), EXEC_OUTPUT_MAX_BYTES); - assert_eq!(aggregated.text[..stdout_len], vec![b'a'; stdout_len]); - assert_eq!(aggregated.text[stdout_len..], vec![b'b'; stderr_cap]); - } - - #[test] - fn aggregate_output_rebalances_when_stderr_is_small() { - let stdout = StreamOutput { - text: vec![b'a'; EXEC_OUTPUT_MAX_BYTES], - truncated_after_lines: None, - }; - let stderr = StreamOutput { - text: vec![b'b'; 1], - truncated_after_lines: None, - }; - - let aggregated = aggregate_output(&stdout, &stderr); - let stdout_len = EXEC_OUTPUT_MAX_BYTES.saturating_sub(1); - - assert_eq!(aggregated.text.len(), EXEC_OUTPUT_MAX_BYTES); - assert_eq!(aggregated.text[..stdout_len], vec![b'a'; stdout_len]); - assert_eq!(aggregated.text[stdout_len..], vec![b'b'; 1]); - } - - #[test] - fn aggregate_output_keeps_stdout_then_stderr_when_under_cap() { - let stdout = StreamOutput { - text: vec![b'a'; 4], - truncated_after_lines: None, - }; - let stderr = StreamOutput { - text: vec![b'b'; 3], - truncated_after_lines: None, - }; - - let aggregated = aggregate_output(&stdout, &stderr); - let mut expected = Vec::new(); - expected.extend_from_slice(&stdout.text); - expected.extend_from_slice(&stderr.text); - - assert_eq!(aggregated.text, expected); - assert_eq!(aggregated.truncated_after_lines, None); - } - - #[test] - fn windows_restricted_token_skips_external_sandbox_policies() { - let policy = SandboxPolicy::ExternalSandbox { - network_access: codex_protocol::protocol::NetworkAccess::Restricted, - }; - let file_system_policy = FileSystemSandboxPolicy::restricted(vec![]); - - assert_eq!( - should_use_windows_restricted_token_sandbox( - SandboxType::WindowsRestrictedToken, - &policy, - &file_system_policy, - ), - false - ); - } - - #[test] - fn windows_restricted_token_runs_for_legacy_restricted_policies() { - let policy = SandboxPolicy::new_read_only_policy(); - let file_system_policy = FileSystemSandboxPolicy::restricted(vec![]); - - assert_eq!( - should_use_windows_restricted_token_sandbox( - SandboxType::WindowsRestrictedToken, - &policy, - &file_system_policy, - ), - true - ); - } - - #[test] - fn windows_restricted_token_rejects_network_only_restrictions() { - let policy = SandboxPolicy::ExternalSandbox { - network_access: codex_protocol::protocol::NetworkAccess::Restricted, - }; - let file_system_policy = FileSystemSandboxPolicy::unrestricted(); - - assert_eq!( - unsupported_windows_restricted_token_sandbox_reason( - SandboxType::WindowsRestrictedToken, - &policy, - &file_system_policy, - NetworkSandboxPolicy::Restricted, - ), - Some( - "windows sandbox backend cannot enforce file_system=Unrestricted, network=Restricted, legacy_policy=ExternalSandbox { network_access: Restricted }; refusing to run unsandboxed".to_string() - ) - ); - } - - #[test] - fn windows_restricted_token_allows_legacy_restricted_policies() { - let policy = SandboxPolicy::new_read_only_policy(); - let file_system_policy = FileSystemSandboxPolicy::restricted(vec![]); - - assert_eq!( - unsupported_windows_restricted_token_sandbox_reason( - SandboxType::WindowsRestrictedToken, - &policy, - &file_system_policy, - NetworkSandboxPolicy::Restricted, - ), - None - ); - } - - #[test] - fn windows_restricted_token_allows_legacy_workspace_write_policies() { - let policy = SandboxPolicy::WorkspaceWrite { - writable_roots: vec![], - read_only_access: codex_protocol::protocol::ReadOnlyAccess::FullAccess, - network_access: false, - exclude_tmpdir_env_var: false, - exclude_slash_tmp: false, - }; - let file_system_policy = FileSystemSandboxPolicy::from(&policy); - - assert_eq!( - unsupported_windows_restricted_token_sandbox_reason( - SandboxType::WindowsRestrictedToken, - &policy, - &file_system_policy, - NetworkSandboxPolicy::Restricted, - ), - None - ); - } - - #[test] - fn process_exec_tool_call_uses_platform_sandbox_for_network_only_restrictions() { - let expected = crate::get_platform_sandbox(false).unwrap_or(SandboxType::None); - - assert_eq!( - select_process_exec_tool_sandbox_type( - &FileSystemSandboxPolicy::unrestricted(), - NetworkSandboxPolicy::Restricted, - codex_protocol::config_types::WindowsSandboxLevel::Disabled, - false, - ), - expected - ); - } - - #[cfg(unix)] - #[test] - fn sandbox_detection_flags_sigsys_exit_code() { - let exit_code = EXIT_CODE_SIGNAL_BASE + libc::SIGSYS; - let output = make_exec_output(exit_code, "", "", ""); - assert!(is_likely_sandbox_denied(SandboxType::LinuxSeccomp, &output)); - } - - #[cfg(unix)] - #[tokio::test] - async fn kill_child_process_group_kills_grandchildren_on_timeout() -> Result<()> { - // On Linux/macOS, /bin/bash is typically present; on FreeBSD/OpenBSD, - // prefer /bin/sh to avoid NotFound errors. - #[cfg(any(target_os = "freebsd", target_os = "openbsd"))] - let command = vec![ - "/bin/sh".to_string(), - "-c".to_string(), - "sleep 60 & echo $!; sleep 60".to_string(), - ]; - #[cfg(all(unix, not(any(target_os = "freebsd", target_os = "openbsd"))))] - let command = vec![ - "/bin/bash".to_string(), - "-c".to_string(), - "sleep 60 & echo $!; sleep 60".to_string(), - ]; - let env: HashMap = std::env::vars().collect(); - let params = ExecParams { - command, - cwd: std::env::current_dir()?, - expiration: 500.into(), - env, - network: None, - sandbox_permissions: SandboxPermissions::UseDefault, - windows_sandbox_level: codex_protocol::config_types::WindowsSandboxLevel::Disabled, - justification: None, - arg0: None, - }; - - let output = exec( - params, - SandboxType::None, - &SandboxPolicy::new_read_only_policy(), - &FileSystemSandboxPolicy::from(&SandboxPolicy::new_read_only_policy()), - NetworkSandboxPolicy::Restricted, - None, - None, - ) - .await?; - assert!(output.timed_out); - - let stdout = output.stdout.from_utf8_lossy().text; - let pid_line = stdout.lines().next().unwrap_or("").trim(); - let pid: i32 = pid_line.parse().map_err(|error| { - io::Error::new( - io::ErrorKind::InvalidData, - format!("Failed to parse pid from stdout '{pid_line}': {error}"), - ) - })?; - - let mut killed = false; - for _ in 0..20 { - // Use kill(pid, 0) to check if the process is alive. - if unsafe { libc::kill(pid, 0) } == -1 - && let Some(libc::ESRCH) = std::io::Error::last_os_error().raw_os_error() - { - killed = true; - break; - } - tokio::time::sleep(Duration::from_millis(100)).await; - } - - assert!(killed, "grandchild process with pid {pid} is still alive"); - Ok(()) - } - - #[tokio::test] - async fn process_exec_tool_call_respects_cancellation_token() -> Result<()> { - let command = long_running_command(); - let cwd = std::env::current_dir()?; - let env: HashMap = std::env::vars().collect(); - let cancel_token = CancellationToken::new(); - let cancel_tx = cancel_token.clone(); - let params = ExecParams { - command, - cwd: cwd.clone(), - expiration: ExecExpiration::Cancellation(cancel_token), - env, - network: None, - sandbox_permissions: SandboxPermissions::UseDefault, - windows_sandbox_level: codex_protocol::config_types::WindowsSandboxLevel::Disabled, - justification: None, - arg0: None, - }; - tokio::spawn(async move { - tokio::time::sleep(Duration::from_millis(1_000)).await; - cancel_tx.cancel(); - }); - let result = process_exec_tool_call( - params, - &SandboxPolicy::DangerFullAccess, - &FileSystemSandboxPolicy::from(&SandboxPolicy::DangerFullAccess), - NetworkSandboxPolicy::Enabled, - cwd.as_path(), - &None, - false, - None, - ) - .await; - let output = match result { - Err(CodexErr::Sandbox(SandboxErr::Timeout { output })) => output, - other => panic!("expected timeout error, got {other:?}"), - }; - assert!(output.timed_out); - assert_eq!(output.exit_code, EXEC_TIMEOUT_EXIT_CODE); - Ok(()) - } - - #[cfg(unix)] - fn long_running_command() -> Vec { - vec![ - "/bin/sh".to_string(), - "-c".to_string(), - "sleep 30".to_string(), - ] - } - - #[cfg(windows)] - fn long_running_command() -> Vec { - vec![ - "powershell.exe".to_string(), - "-NonInteractive".to_string(), - "-NoLogo".to_string(), - "-Command".to_string(), - "Start-Sleep -Seconds 30".to_string(), - ] - } -} +#[path = "exec_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/exec_env.rs b/codex-rs/core/src/exec_env.rs index eabd35b410..83ac8ad379 100644 --- a/codex-rs/core/src/exec_env.rs +++ b/codex-rs/core/src/exec_env.rs @@ -94,220 +94,5 @@ where } #[cfg(test)] -mod tests { - use super::*; - use crate::config::types::ShellEnvironmentPolicyInherit; - use maplit::hashmap; - - fn make_vars(pairs: &[(&str, &str)]) -> Vec<(String, String)> { - pairs - .iter() - .map(|(k, v)| (k.to_string(), v.to_string())) - .collect() - } - - #[test] - fn test_core_inherit_defaults_keep_sensitive_vars() { - let vars = make_vars(&[ - ("PATH", "/usr/bin"), - ("HOME", "/home/user"), - ("API_KEY", "secret"), - ("SECRET_TOKEN", "t"), - ]); - - let policy = ShellEnvironmentPolicy::default(); // inherit All, default excludes ignored - let thread_id = ThreadId::new(); - let result = populate_env(vars, &policy, Some(thread_id)); - - let mut expected: HashMap = hashmap! { - "PATH".to_string() => "/usr/bin".to_string(), - "HOME".to_string() => "/home/user".to_string(), - "API_KEY".to_string() => "secret".to_string(), - "SECRET_TOKEN".to_string() => "t".to_string(), - }; - expected.insert(CODEX_THREAD_ID_ENV_VAR.to_string(), thread_id.to_string()); - - assert_eq!(result, expected); - } - - #[test] - fn test_core_inherit_with_default_excludes_enabled() { - let vars = make_vars(&[ - ("PATH", "/usr/bin"), - ("HOME", "/home/user"), - ("API_KEY", "secret"), - ("SECRET_TOKEN", "t"), - ]); - - let policy = ShellEnvironmentPolicy { - ignore_default_excludes: false, // apply KEY/SECRET/TOKEN filter - ..Default::default() - }; - let thread_id = ThreadId::new(); - let result = populate_env(vars, &policy, Some(thread_id)); - - let mut expected: HashMap = hashmap! { - "PATH".to_string() => "/usr/bin".to_string(), - "HOME".to_string() => "/home/user".to_string(), - }; - expected.insert(CODEX_THREAD_ID_ENV_VAR.to_string(), thread_id.to_string()); - - assert_eq!(result, expected); - } - - #[test] - fn test_include_only() { - let vars = make_vars(&[("PATH", "/usr/bin"), ("FOO", "bar")]); - - let policy = ShellEnvironmentPolicy { - // skip default excludes so nothing is removed prematurely - ignore_default_excludes: true, - include_only: vec![EnvironmentVariablePattern::new_case_insensitive("*PATH")], - ..Default::default() - }; - - let thread_id = ThreadId::new(); - let result = populate_env(vars, &policy, Some(thread_id)); - - let mut expected: HashMap = hashmap! { - "PATH".to_string() => "/usr/bin".to_string(), - }; - expected.insert(CODEX_THREAD_ID_ENV_VAR.to_string(), thread_id.to_string()); - - assert_eq!(result, expected); - } - - #[test] - fn test_set_overrides() { - let vars = make_vars(&[("PATH", "/usr/bin")]); - - let mut policy = ShellEnvironmentPolicy { - ignore_default_excludes: true, - ..Default::default() - }; - policy.r#set.insert("NEW_VAR".to_string(), "42".to_string()); - - let thread_id = ThreadId::new(); - let result = populate_env(vars, &policy, Some(thread_id)); - - let mut expected: HashMap = hashmap! { - "PATH".to_string() => "/usr/bin".to_string(), - "NEW_VAR".to_string() => "42".to_string(), - }; - expected.insert(CODEX_THREAD_ID_ENV_VAR.to_string(), thread_id.to_string()); - - assert_eq!(result, expected); - } - - #[test] - fn populate_env_inserts_thread_id() { - let vars = make_vars(&[("PATH", "/usr/bin")]); - let policy = ShellEnvironmentPolicy::default(); - let thread_id = ThreadId::new(); - let result = populate_env(vars, &policy, Some(thread_id)); - - let mut expected: HashMap = hashmap! { - "PATH".to_string() => "/usr/bin".to_string(), - }; - expected.insert(CODEX_THREAD_ID_ENV_VAR.to_string(), thread_id.to_string()); - - assert_eq!(result, expected); - } - - #[test] - fn populate_env_omits_thread_id_when_missing() { - let vars = make_vars(&[("PATH", "/usr/bin")]); - let policy = ShellEnvironmentPolicy::default(); - let result = populate_env(vars, &policy, None); - - let expected: HashMap = hashmap! { - "PATH".to_string() => "/usr/bin".to_string(), - }; - - assert_eq!(result, expected); - } - - #[test] - fn test_inherit_all() { - let vars = make_vars(&[("PATH", "/usr/bin"), ("FOO", "bar")]); - - let policy = ShellEnvironmentPolicy { - inherit: ShellEnvironmentPolicyInherit::All, - ignore_default_excludes: true, // keep everything - ..Default::default() - }; - - let thread_id = ThreadId::new(); - let result = populate_env(vars.clone(), &policy, Some(thread_id)); - let mut expected: HashMap = vars.into_iter().collect(); - expected.insert(CODEX_THREAD_ID_ENV_VAR.to_string(), thread_id.to_string()); - assert_eq!(result, expected); - } - - #[test] - fn test_inherit_all_with_default_excludes() { - let vars = make_vars(&[("PATH", "/usr/bin"), ("API_KEY", "secret")]); - - let policy = ShellEnvironmentPolicy { - inherit: ShellEnvironmentPolicyInherit::All, - ignore_default_excludes: false, - ..Default::default() - }; - - let thread_id = ThreadId::new(); - let result = populate_env(vars, &policy, Some(thread_id)); - let mut expected: HashMap = hashmap! { - "PATH".to_string() => "/usr/bin".to_string(), - }; - expected.insert(CODEX_THREAD_ID_ENV_VAR.to_string(), thread_id.to_string()); - assert_eq!(result, expected); - } - - #[test] - #[cfg(target_os = "windows")] - fn test_core_inherit_respects_case_insensitive_names_on_windows() { - let vars = make_vars(&[ - ("Path", "C:\\Windows\\System32"), - ("TEMP", "C:\\Temp"), - ("FOO", "bar"), - ]); - - let policy = ShellEnvironmentPolicy { - inherit: ShellEnvironmentPolicyInherit::Core, - ignore_default_excludes: true, - ..Default::default() - }; - - let thread_id = ThreadId::new(); - let result = populate_env(vars, &policy, Some(thread_id)); - let mut expected: HashMap = hashmap! { - "Path".to_string() => "C:\\Windows\\System32".to_string(), - "TEMP".to_string() => "C:\\Temp".to_string(), - }; - expected.insert(CODEX_THREAD_ID_ENV_VAR.to_string(), thread_id.to_string()); - - assert_eq!(result, expected); - } - - #[test] - fn test_inherit_none() { - let vars = make_vars(&[("PATH", "/usr/bin"), ("HOME", "/home")]); - - let mut policy = ShellEnvironmentPolicy { - inherit: ShellEnvironmentPolicyInherit::None, - ignore_default_excludes: true, - ..Default::default() - }; - policy - .r#set - .insert("ONLY_VAR".to_string(), "yes".to_string()); - - let thread_id = ThreadId::new(); - let result = populate_env(vars, &policy, Some(thread_id)); - let mut expected: HashMap = hashmap! { - "ONLY_VAR".to_string() => "yes".to_string(), - }; - expected.insert(CODEX_THREAD_ID_ENV_VAR.to_string(), thread_id.to_string()); - assert_eq!(result, expected); - } -} +#[path = "exec_env_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/exec_env_tests.rs b/codex-rs/core/src/exec_env_tests.rs new file mode 100644 index 0000000000..6f001b5828 --- /dev/null +++ b/codex-rs/core/src/exec_env_tests.rs @@ -0,0 +1,215 @@ +use super::*; +use crate::config::types::ShellEnvironmentPolicyInherit; +use maplit::hashmap; + +fn make_vars(pairs: &[(&str, &str)]) -> Vec<(String, String)> { + pairs + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect() +} + +#[test] +fn test_core_inherit_defaults_keep_sensitive_vars() { + let vars = make_vars(&[ + ("PATH", "/usr/bin"), + ("HOME", "/home/user"), + ("API_KEY", "secret"), + ("SECRET_TOKEN", "t"), + ]); + + let policy = ShellEnvironmentPolicy::default(); // inherit All, default excludes ignored + let thread_id = ThreadId::new(); + let result = populate_env(vars, &policy, Some(thread_id)); + + let mut expected: HashMap = hashmap! { + "PATH".to_string() => "/usr/bin".to_string(), + "HOME".to_string() => "/home/user".to_string(), + "API_KEY".to_string() => "secret".to_string(), + "SECRET_TOKEN".to_string() => "t".to_string(), + }; + expected.insert(CODEX_THREAD_ID_ENV_VAR.to_string(), thread_id.to_string()); + + assert_eq!(result, expected); +} + +#[test] +fn test_core_inherit_with_default_excludes_enabled() { + let vars = make_vars(&[ + ("PATH", "/usr/bin"), + ("HOME", "/home/user"), + ("API_KEY", "secret"), + ("SECRET_TOKEN", "t"), + ]); + + let policy = ShellEnvironmentPolicy { + ignore_default_excludes: false, // apply KEY/SECRET/TOKEN filter + ..Default::default() + }; + let thread_id = ThreadId::new(); + let result = populate_env(vars, &policy, Some(thread_id)); + + let mut expected: HashMap = hashmap! { + "PATH".to_string() => "/usr/bin".to_string(), + "HOME".to_string() => "/home/user".to_string(), + }; + expected.insert(CODEX_THREAD_ID_ENV_VAR.to_string(), thread_id.to_string()); + + assert_eq!(result, expected); +} + +#[test] +fn test_include_only() { + let vars = make_vars(&[("PATH", "/usr/bin"), ("FOO", "bar")]); + + let policy = ShellEnvironmentPolicy { + // skip default excludes so nothing is removed prematurely + ignore_default_excludes: true, + include_only: vec![EnvironmentVariablePattern::new_case_insensitive("*PATH")], + ..Default::default() + }; + + let thread_id = ThreadId::new(); + let result = populate_env(vars, &policy, Some(thread_id)); + + let mut expected: HashMap = hashmap! { + "PATH".to_string() => "/usr/bin".to_string(), + }; + expected.insert(CODEX_THREAD_ID_ENV_VAR.to_string(), thread_id.to_string()); + + assert_eq!(result, expected); +} + +#[test] +fn test_set_overrides() { + let vars = make_vars(&[("PATH", "/usr/bin")]); + + let mut policy = ShellEnvironmentPolicy { + ignore_default_excludes: true, + ..Default::default() + }; + policy.r#set.insert("NEW_VAR".to_string(), "42".to_string()); + + let thread_id = ThreadId::new(); + let result = populate_env(vars, &policy, Some(thread_id)); + + let mut expected: HashMap = hashmap! { + "PATH".to_string() => "/usr/bin".to_string(), + "NEW_VAR".to_string() => "42".to_string(), + }; + expected.insert(CODEX_THREAD_ID_ENV_VAR.to_string(), thread_id.to_string()); + + assert_eq!(result, expected); +} + +#[test] +fn populate_env_inserts_thread_id() { + let vars = make_vars(&[("PATH", "/usr/bin")]); + let policy = ShellEnvironmentPolicy::default(); + let thread_id = ThreadId::new(); + let result = populate_env(vars, &policy, Some(thread_id)); + + let mut expected: HashMap = hashmap! { + "PATH".to_string() => "/usr/bin".to_string(), + }; + expected.insert(CODEX_THREAD_ID_ENV_VAR.to_string(), thread_id.to_string()); + + assert_eq!(result, expected); +} + +#[test] +fn populate_env_omits_thread_id_when_missing() { + let vars = make_vars(&[("PATH", "/usr/bin")]); + let policy = ShellEnvironmentPolicy::default(); + let result = populate_env(vars, &policy, None); + + let expected: HashMap = hashmap! { + "PATH".to_string() => "/usr/bin".to_string(), + }; + + assert_eq!(result, expected); +} + +#[test] +fn test_inherit_all() { + let vars = make_vars(&[("PATH", "/usr/bin"), ("FOO", "bar")]); + + let policy = ShellEnvironmentPolicy { + inherit: ShellEnvironmentPolicyInherit::All, + ignore_default_excludes: true, // keep everything + ..Default::default() + }; + + let thread_id = ThreadId::new(); + let result = populate_env(vars.clone(), &policy, Some(thread_id)); + let mut expected: HashMap = vars.into_iter().collect(); + expected.insert(CODEX_THREAD_ID_ENV_VAR.to_string(), thread_id.to_string()); + assert_eq!(result, expected); +} + +#[test] +fn test_inherit_all_with_default_excludes() { + let vars = make_vars(&[("PATH", "/usr/bin"), ("API_KEY", "secret")]); + + let policy = ShellEnvironmentPolicy { + inherit: ShellEnvironmentPolicyInherit::All, + ignore_default_excludes: false, + ..Default::default() + }; + + let thread_id = ThreadId::new(); + let result = populate_env(vars, &policy, Some(thread_id)); + let mut expected: HashMap = hashmap! { + "PATH".to_string() => "/usr/bin".to_string(), + }; + expected.insert(CODEX_THREAD_ID_ENV_VAR.to_string(), thread_id.to_string()); + assert_eq!(result, expected); +} + +#[test] +#[cfg(target_os = "windows")] +fn test_core_inherit_respects_case_insensitive_names_on_windows() { + let vars = make_vars(&[ + ("Path", "C:\\Windows\\System32"), + ("TEMP", "C:\\Temp"), + ("FOO", "bar"), + ]); + + let policy = ShellEnvironmentPolicy { + inherit: ShellEnvironmentPolicyInherit::Core, + ignore_default_excludes: true, + ..Default::default() + }; + + let thread_id = ThreadId::new(); + let result = populate_env(vars, &policy, Some(thread_id)); + let mut expected: HashMap = hashmap! { + "Path".to_string() => "C:\\Windows\\System32".to_string(), + "TEMP".to_string() => "C:\\Temp".to_string(), + }; + expected.insert(CODEX_THREAD_ID_ENV_VAR.to_string(), thread_id.to_string()); + + assert_eq!(result, expected); +} + +#[test] +fn test_inherit_none() { + let vars = make_vars(&[("PATH", "/usr/bin"), ("HOME", "/home")]); + + let mut policy = ShellEnvironmentPolicy { + inherit: ShellEnvironmentPolicyInherit::None, + ignore_default_excludes: true, + ..Default::default() + }; + policy + .r#set + .insert("ONLY_VAR".to_string(), "yes".to_string()); + + let thread_id = ThreadId::new(); + let result = populate_env(vars, &policy, Some(thread_id)); + let mut expected: HashMap = hashmap! { + "ONLY_VAR".to_string() => "yes".to_string(), + }; + expected.insert(CODEX_THREAD_ID_ENV_VAR.to_string(), thread_id.to_string()); + assert_eq!(result, expected); +} diff --git a/codex-rs/core/src/exec_policy.rs b/codex-rs/core/src/exec_policy.rs index 830fbb2ce5..2c9ba28b11 100644 --- a/codex-rs/core/src/exec_policy.rs +++ b/codex-rs/core/src/exec_policy.rs @@ -823,1606 +823,5 @@ async fn collect_policy_files(dir: impl AsRef) -> Result, Exe } #[cfg(test)] -mod tests { - use super::*; - use crate::config_loader::ConfigLayerEntry; - use crate::config_loader::ConfigLayerStack; - use crate::config_loader::ConfigRequirements; - use crate::config_loader::ConfigRequirementsToml; - use codex_app_server_protocol::ConfigLayerSource; - use codex_protocol::permissions::FileSystemAccessMode; - use codex_protocol::permissions::FileSystemPath; - use codex_protocol::permissions::FileSystemSandboxEntry; - use codex_protocol::permissions::FileSystemSpecialPath; - use codex_protocol::protocol::AskForApproval; - use codex_protocol::protocol::RejectConfig; - use codex_protocol::protocol::SandboxPolicy; - use codex_utils_absolute_path::AbsolutePathBuf; - use pretty_assertions::assert_eq; - use std::fs; - use std::path::Path; - use std::path::PathBuf; - use std::sync::Arc; - use tempfile::tempdir; - use toml::Value as TomlValue; - - fn config_stack_for_dot_codex_folder(dot_codex_folder: &Path) -> ConfigLayerStack { - let dot_codex_folder = AbsolutePathBuf::from_absolute_path(dot_codex_folder) - .expect("absolute dot_codex_folder"); - let layer = ConfigLayerEntry::new( - ConfigLayerSource::Project { dot_codex_folder }, - TomlValue::Table(Default::default()), - ); - ConfigLayerStack::new( - vec![layer], - ConfigRequirements::default(), - ConfigRequirementsToml::default(), - ) - .expect("ConfigLayerStack") - } - - fn host_absolute_path(segments: &[&str]) -> String { - let mut path = if cfg!(windows) { - PathBuf::from(r"C:\") - } else { - PathBuf::from("/") - }; - for segment in segments { - path.push(segment); - } - path.to_string_lossy().into_owned() - } - - fn host_program_path(name: &str) -> String { - let executable_name = if cfg!(windows) { - format!("{name}.exe") - } else { - name.to_string() - }; - host_absolute_path(&["usr", "bin", &executable_name]) - } - - fn starlark_string(value: &str) -> String { - value.replace('\\', "\\\\").replace('"', "\\\"") - } - - fn read_only_file_system_sandbox_policy() -> FileSystemSandboxPolicy { - FileSystemSandboxPolicy::restricted(vec![FileSystemSandboxEntry { - path: FileSystemPath::Special { - value: FileSystemSpecialPath::Root, - }, - access: FileSystemAccessMode::Read, - }]) - } - - fn unrestricted_file_system_sandbox_policy() -> FileSystemSandboxPolicy { - FileSystemSandboxPolicy::unrestricted() - } - - #[tokio::test] - async fn returns_empty_policy_when_no_policy_files_exist() { - let temp_dir = tempdir().expect("create temp dir"); - let config_stack = config_stack_for_dot_codex_folder(temp_dir.path()); - - let manager = ExecPolicyManager::load(&config_stack) - .await - .expect("manager result"); - let policy = manager.current(); - - let commands = [vec!["rm".to_string()]]; - assert_eq!( - Evaluation { - decision: Decision::Allow, - matched_rules: vec![RuleMatch::HeuristicsRuleMatch { - command: vec!["rm".to_string()], - decision: Decision::Allow - }], - }, - policy.check_multiple(commands.iter(), &|_| Decision::Allow) - ); - assert!(!temp_dir.path().join(RULES_DIR_NAME).exists()); - } - - #[tokio::test] - async fn collect_policy_files_returns_empty_when_dir_missing() { - let temp_dir = tempdir().expect("create temp dir"); - - let policy_dir = temp_dir.path().join(RULES_DIR_NAME); - let files = collect_policy_files(&policy_dir) - .await - .expect("collect policy files"); - - assert!(files.is_empty()); - } - - #[tokio::test] - async fn format_exec_policy_error_with_source_renders_range() { - let temp_dir = tempdir().expect("create temp dir"); - let config_stack = config_stack_for_dot_codex_folder(temp_dir.path()); - let policy_dir = temp_dir.path().join(RULES_DIR_NAME); - fs::create_dir_all(&policy_dir).expect("create policy dir"); - let broken_path = policy_dir.join("broken.rules"); - fs::write( - &broken_path, - r#"prefix_rule( - pattern = ["tmux capture-pane"], - decision = "allow", - match = ["tmux capture-pane -p"], -)"#, - ) - .expect("write broken policy file"); - - let err = load_exec_policy(&config_stack) - .await - .expect_err("expected parse error"); - let rendered = format_exec_policy_error_with_source(&err); - - assert!(rendered.contains("broken.rules:1:")); - assert!(rendered.contains("on or around line 1")); - } - - #[test] - fn parse_starlark_line_from_message_extracts_path_and_line() { - let parsed = parse_starlark_line_from_message( - "/tmp/default.rules:143:1: starlark error: error: Parse error: unexpected new line", - ) - .expect("parse should succeed"); - - assert_eq!(parsed.0, PathBuf::from("/tmp/default.rules")); - assert_eq!(parsed.1, 143); - } - - #[test] - fn parse_starlark_line_from_message_rejects_zero_line() { - let parsed = parse_starlark_line_from_message( - "/tmp/default.rules:0:1: starlark error: error: Parse error: unexpected new line", - ); - assert_eq!(parsed, None); - } - - #[tokio::test] - async fn loads_policies_from_policy_subdirectory() { - let temp_dir = tempdir().expect("create temp dir"); - let config_stack = config_stack_for_dot_codex_folder(temp_dir.path()); - let policy_dir = temp_dir.path().join(RULES_DIR_NAME); - fs::create_dir_all(&policy_dir).expect("create policy dir"); - fs::write( - policy_dir.join("deny.rules"), - r#"prefix_rule(pattern=["rm"], decision="forbidden")"#, - ) - .expect("write policy file"); - - let policy = load_exec_policy(&config_stack) - .await - .expect("policy result"); - let command = [vec!["rm".to_string()]]; - assert_eq!( - Evaluation { - decision: Decision::Forbidden, - matched_rules: vec![RuleMatch::PrefixRuleMatch { - matched_prefix: vec!["rm".to_string()], - decision: Decision::Forbidden, - resolved_program: None, - justification: None, - }], - }, - policy.check_multiple(command.iter(), &|_| Decision::Allow) - ); - } - - #[tokio::test] - async fn merges_requirements_exec_policy_network_rules() -> anyhow::Result<()> { - let temp_dir = tempdir()?; - - let mut requirements_exec_policy = Policy::empty(); - requirements_exec_policy.add_network_rule( - "blocked.example.com", - codex_execpolicy::NetworkRuleProtocol::Https, - Decision::Forbidden, - None, - )?; - - let requirements = ConfigRequirements { - exec_policy: Some(codex_config::Sourced::new( - codex_config::RequirementsExecPolicy::new(requirements_exec_policy), - codex_config::RequirementSource::Unknown, - )), - ..ConfigRequirements::default() - }; - let dot_codex_folder = AbsolutePathBuf::from_absolute_path(temp_dir.path())?; - let layer = ConfigLayerEntry::new( - ConfigLayerSource::Project { dot_codex_folder }, - TomlValue::Table(Default::default()), - ); - let config_stack = - ConfigLayerStack::new(vec![layer], requirements, ConfigRequirementsToml::default())?; - - let policy = load_exec_policy(&config_stack).await?; - let (allowed, denied) = policy.compiled_network_domains(); - - assert!(allowed.is_empty()); - assert_eq!(denied, vec!["blocked.example.com".to_string()]); - Ok(()) - } - - #[tokio::test] - async fn preserves_host_executables_when_requirements_overlay_is_present() -> anyhow::Result<()> - { - let temp_dir = tempdir()?; - let policy_dir = temp_dir.path().join(RULES_DIR_NAME); - fs::create_dir_all(&policy_dir)?; - let git_path = host_absolute_path(&["usr", "bin", "git"]); - let git_path_literal = starlark_string(&git_path); - fs::write( - policy_dir.join("host.rules"), - format!( - r#" -host_executable(name = "git", paths = ["{git_path_literal}"]) -"# - ), - )?; - - let mut requirements_exec_policy = Policy::empty(); - requirements_exec_policy.add_network_rule( - "blocked.example.com", - codex_execpolicy::NetworkRuleProtocol::Https, - Decision::Forbidden, - None, - )?; - - let requirements = ConfigRequirements { - exec_policy: Some(codex_config::Sourced::new( - codex_config::RequirementsExecPolicy::new(requirements_exec_policy), - codex_config::RequirementSource::Unknown, - )), - ..ConfigRequirements::default() - }; - let dot_codex_folder = AbsolutePathBuf::from_absolute_path(temp_dir.path())?; - let layer = ConfigLayerEntry::new( - ConfigLayerSource::Project { dot_codex_folder }, - TomlValue::Table(Default::default()), - ); - let config_stack = - ConfigLayerStack::new(vec![layer], requirements, ConfigRequirementsToml::default())?; - - let policy = load_exec_policy(&config_stack).await?; - - assert_eq!( - policy - .host_executables() - .get("git") - .expect("missing git host executable") - .as_ref(), - [AbsolutePathBuf::try_from(git_path)?] - ); - Ok(()) - } - - #[tokio::test] - async fn ignores_policies_outside_policy_dir() { - let temp_dir = tempdir().expect("create temp dir"); - let config_stack = config_stack_for_dot_codex_folder(temp_dir.path()); - fs::write( - temp_dir.path().join("root.rules"), - r#"prefix_rule(pattern=["ls"], decision="prompt")"#, - ) - .expect("write policy file"); - - let policy = load_exec_policy(&config_stack) - .await - .expect("policy result"); - let command = [vec!["ls".to_string()]]; - assert_eq!( - Evaluation { - decision: Decision::Allow, - matched_rules: vec![RuleMatch::HeuristicsRuleMatch { - command: vec!["ls".to_string()], - decision: Decision::Allow - }], - }, - policy.check_multiple(command.iter(), &|_| Decision::Allow) - ); - } - - #[tokio::test] - async fn ignores_rules_from_untrusted_project_layers() -> anyhow::Result<()> { - let project_dir = tempdir()?; - let policy_dir = project_dir.path().join(RULES_DIR_NAME); - fs::create_dir_all(&policy_dir)?; - fs::write( - policy_dir.join("untrusted.rules"), - r#"prefix_rule(pattern=["ls"], decision="forbidden")"#, - )?; - - let project_dot_codex_folder = AbsolutePathBuf::from_absolute_path(project_dir.path())?; - let layers = vec![ConfigLayerEntry::new_disabled( - ConfigLayerSource::Project { - dot_codex_folder: project_dot_codex_folder, - }, - TomlValue::Table(Default::default()), - "marked untrusted", - )]; - let config_stack = ConfigLayerStack::new( - layers, - ConfigRequirements::default(), - ConfigRequirementsToml::default(), - )?; - - let policy = load_exec_policy(&config_stack).await?; - - assert_eq!( - Evaluation { - decision: Decision::Allow, - matched_rules: vec![RuleMatch::HeuristicsRuleMatch { - command: vec!["ls".to_string()], - decision: Decision::Allow, - }], - }, - policy.check_multiple([vec!["ls".to_string()]].iter(), &|_| Decision::Allow) - ); - Ok(()) - } - - #[tokio::test] - async fn loads_policies_from_multiple_config_layers() -> anyhow::Result<()> { - let user_dir = tempdir()?; - let project_dir = tempdir()?; - - let user_policy_dir = user_dir.path().join(RULES_DIR_NAME); - fs::create_dir_all(&user_policy_dir)?; - fs::write( - user_policy_dir.join("user.rules"), - r#"prefix_rule(pattern=["rm"], decision="forbidden")"#, - )?; - - let project_policy_dir = project_dir.path().join(RULES_DIR_NAME); - fs::create_dir_all(&project_policy_dir)?; - fs::write( - project_policy_dir.join("project.rules"), - r#"prefix_rule(pattern=["ls"], decision="prompt")"#, - )?; - - let user_config_toml = - AbsolutePathBuf::from_absolute_path(user_dir.path().join("config.toml"))?; - let project_dot_codex_folder = AbsolutePathBuf::from_absolute_path(project_dir.path())?; - let layers = vec![ - ConfigLayerEntry::new( - ConfigLayerSource::User { - file: user_config_toml, - }, - TomlValue::Table(Default::default()), - ), - ConfigLayerEntry::new( - ConfigLayerSource::Project { - dot_codex_folder: project_dot_codex_folder, - }, - TomlValue::Table(Default::default()), - ), - ]; - let config_stack = ConfigLayerStack::new( - layers, - ConfigRequirements::default(), - ConfigRequirementsToml::default(), - )?; - - let policy = load_exec_policy(&config_stack).await?; - - assert_eq!( - Evaluation { - decision: Decision::Forbidden, - matched_rules: vec![RuleMatch::PrefixRuleMatch { - matched_prefix: vec!["rm".to_string()], - decision: Decision::Forbidden, - resolved_program: None, - justification: None, - }], - }, - policy.check_multiple([vec!["rm".to_string()]].iter(), &|_| Decision::Allow) - ); - assert_eq!( - Evaluation { - decision: Decision::Prompt, - matched_rules: vec![RuleMatch::PrefixRuleMatch { - matched_prefix: vec!["ls".to_string()], - decision: Decision::Prompt, - resolved_program: None, - justification: None, - }], - }, - policy.check_multiple([vec!["ls".to_string()]].iter(), &|_| Decision::Allow) - ); - Ok(()) - } - - #[tokio::test] - async fn evaluates_bash_lc_inner_commands() { - let policy_src = r#" -prefix_rule(pattern=["rm"], decision="forbidden") -"#; - let mut parser = PolicyParser::new(); - parser - .parse("test.rules", policy_src) - .expect("parse policy"); - let policy = Arc::new(parser.build()); - - let forbidden_script = vec![ - "bash".to_string(), - "-lc".to_string(), - "rm -rf /some/important/folder".to_string(), - ]; - - let manager = ExecPolicyManager::new(policy); - let requirement = manager - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &forbidden_script, - approval_policy: AskForApproval::OnRequest, - sandbox_policy: &SandboxPolicy::DangerFullAccess, - file_system_sandbox_policy: &unrestricted_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::UseDefault, - prefix_rule: None, - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::Forbidden { - reason: "`bash -lc 'rm -rf /some/important/folder'` rejected: policy forbids commands starting with `rm`".to_string() - } - ); - } - - #[test] - fn commands_for_exec_policy_falls_back_for_empty_shell_script() { - let command = vec!["bash".to_string(), "-lc".to_string(), "".to_string()]; - - assert_eq!(commands_for_exec_policy(&command), (vec![command], false)); - } - - #[test] - fn commands_for_exec_policy_falls_back_for_whitespace_shell_script() { - let command = vec![ - "bash".to_string(), - "-lc".to_string(), - " \n\t ".to_string(), - ]; - - assert_eq!(commands_for_exec_policy(&command), (vec![command], false)); - } - - #[tokio::test] - async fn evaluates_heredoc_script_against_prefix_rules() { - let policy_src = r#"prefix_rule(pattern=["python3"], decision="allow")"#; - let mut parser = PolicyParser::new(); - parser - .parse("test.rules", policy_src) - .expect("parse policy"); - let policy = Arc::new(parser.build()); - let command = vec![ - "bash".to_string(), - "-lc".to_string(), - "python3 <<'PY'\nprint('hello')\nPY".to_string(), - ]; - - let requirement = ExecPolicyManager::new(policy) - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::OnRequest, - sandbox_policy: &SandboxPolicy::new_read_only_policy(), - file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::UseDefault, - prefix_rule: None, - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::Skip { - bypass_sandbox: true, - proposed_execpolicy_amendment: None, - } - ); - } - - #[tokio::test] - async fn omits_auto_amendment_for_heredoc_fallback_prompts() { - let command = vec![ - "bash".to_string(), - "-lc".to_string(), - "python3 <<'PY'\nprint('hello')\nPY".to_string(), - ]; - - let requirement = ExecPolicyManager::default() - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::UnlessTrusted, - sandbox_policy: &SandboxPolicy::new_read_only_policy(), - file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::UseDefault, - prefix_rule: None, - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::NeedsApproval { - reason: None, - proposed_execpolicy_amendment: None, - } - ); - } - - #[tokio::test] - async fn drops_requested_amendment_for_heredoc_fallback_prompts_when_it_wont_match() { - let command = vec![ - "bash".to_string(), - "-lc".to_string(), - "python3 <<'PY'\nprint('hello')\nPY".to_string(), - ]; - let requested_prefix = vec!["python3".to_string(), "-m".to_string(), "pip".to_string()]; - - let requirement = ExecPolicyManager::default() - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::UnlessTrusted, - sandbox_policy: &SandboxPolicy::new_read_only_policy(), - file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::UseDefault, - prefix_rule: Some(requested_prefix.clone()), - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::NeedsApproval { - reason: None, - proposed_execpolicy_amendment: None, - } - ); - } - - #[tokio::test] - async fn justification_is_included_in_forbidden_exec_approval_requirement() { - let policy_src = r#" -prefix_rule( - pattern=["rm"], - decision="forbidden", - justification="destructive command", -) -"#; - let mut parser = PolicyParser::new(); - parser - .parse("test.rules", policy_src) - .expect("parse policy"); - let policy = Arc::new(parser.build()); - - let manager = ExecPolicyManager::new(policy); - let requirement = manager - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &[ - "rm".to_string(), - "-rf".to_string(), - "/some/important/folder".to_string(), - ], - approval_policy: AskForApproval::OnRequest, - sandbox_policy: &SandboxPolicy::DangerFullAccess, - file_system_sandbox_policy: &unrestricted_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::UseDefault, - prefix_rule: None, - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::Forbidden { - reason: "`rm -rf /some/important/folder` rejected: destructive command".to_string() - } - ); - } - - #[tokio::test] - async fn exec_approval_requirement_prefers_execpolicy_match() { - let policy_src = r#"prefix_rule(pattern=["rm"], decision="prompt")"#; - let mut parser = PolicyParser::new(); - parser - .parse("test.rules", policy_src) - .expect("parse policy"); - let policy = Arc::new(parser.build()); - let command = vec!["rm".to_string()]; - - let manager = ExecPolicyManager::new(policy); - let requirement = manager - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::OnRequest, - sandbox_policy: &SandboxPolicy::DangerFullAccess, - file_system_sandbox_policy: &unrestricted_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::UseDefault, - prefix_rule: None, - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::NeedsApproval { - reason: Some("`rm` requires approval by policy".to_string()), - proposed_execpolicy_amendment: None, - } - ); - } - - #[tokio::test] - async fn absolute_path_exec_approval_requirement_matches_host_executable_rules() { - let git_path = host_program_path("git"); - let git_path_literal = starlark_string(&git_path); - let policy_src = format!( - r#" -host_executable(name = "git", paths = ["{git_path_literal}"]) -prefix_rule(pattern=["git"], decision="allow") -"# - ); - let mut parser = PolicyParser::new(); - parser - .parse("test.rules", &policy_src) - .expect("parse policy"); - let manager = ExecPolicyManager::new(Arc::new(parser.build())); - let command = vec![git_path, "status".to_string()]; - - let requirement = manager - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::UnlessTrusted, - sandbox_policy: &SandboxPolicy::new_read_only_policy(), - file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::UseDefault, - prefix_rule: None, - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::Skip { - bypass_sandbox: true, - proposed_execpolicy_amendment: None, - } - ); - } - - #[tokio::test] - async fn absolute_path_exec_approval_requirement_ignores_disallowed_host_executable_paths() { - let allowed_git_path = host_program_path("git"); - let disallowed_git_path = host_absolute_path(&[ - "opt", - "homebrew", - "bin", - if cfg!(windows) { "git.exe" } else { "git" }, - ]); - let allowed_git_path_literal = starlark_string(&allowed_git_path); - let policy_src = format!( - r#" -host_executable(name = "git", paths = ["{allowed_git_path_literal}"]) -prefix_rule(pattern=["git"], decision="prompt") -"# - ); - let mut parser = PolicyParser::new(); - parser - .parse("test.rules", &policy_src) - .expect("parse policy"); - let manager = ExecPolicyManager::new(Arc::new(parser.build())); - let command = vec![disallowed_git_path, "status".to_string()]; - - let requirement = manager - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::UnlessTrusted, - sandbox_policy: &SandboxPolicy::new_read_only_policy(), - file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::UseDefault, - prefix_rule: None, - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::Skip { - bypass_sandbox: false, - proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(command)), - } - ); - } - - #[tokio::test] - async fn requested_prefix_rule_can_approve_absolute_path_commands() { - let command = vec![ - host_program_path("cargo"), - "install".to_string(), - "cargo-insta".to_string(), - ]; - let manager = ExecPolicyManager::default(); - - let requirement = manager - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::UnlessTrusted, - sandbox_policy: &SandboxPolicy::new_read_only_policy(), - file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::UseDefault, - prefix_rule: Some(vec!["cargo".to_string(), "install".to_string()]), - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::NeedsApproval { - reason: None, - proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(vec![ - "cargo".to_string(), - "install".to_string(), - ])), - } - ); - } - - #[tokio::test] - async fn exec_approval_requirement_respects_approval_policy() { - let policy_src = r#"prefix_rule(pattern=["rm"], decision="prompt")"#; - let mut parser = PolicyParser::new(); - parser - .parse("test.rules", policy_src) - .expect("parse policy"); - let policy = Arc::new(parser.build()); - let command = vec!["rm".to_string()]; - - let manager = ExecPolicyManager::new(policy); - let requirement = manager - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::Never, - sandbox_policy: &SandboxPolicy::DangerFullAccess, - file_system_sandbox_policy: &unrestricted_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::UseDefault, - prefix_rule: None, - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::Forbidden { - reason: PROMPT_CONFLICT_REASON.to_string() - } - ); - } - - #[test] - fn unmatched_reject_policy_still_prompts_for_restricted_sandbox_escalation() { - let command = vec!["madeup-cmd".to_string()]; - - assert_eq!( - Decision::Prompt, - render_decision_for_unmatched_command( - AskForApproval::Reject(RejectConfig { - sandbox_approval: false, - rules: false, - skill_approval: false, - request_permissions: false, - mcp_elicitations: false, - }), - &SandboxPolicy::new_read_only_policy(), - &read_only_file_system_sandbox_policy(), - &command, - SandboxPermissions::RequireEscalated, - false, - ) - ); - } - - #[test] - fn unmatched_on_request_uses_split_filesystem_policy_for_escalation_prompts() { - let command = vec!["madeup-cmd".to_string()]; - let restricted_file_system_policy = FileSystemSandboxPolicy::restricted(vec![]); - - assert_eq!( - Decision::Prompt, - render_decision_for_unmatched_command( - AskForApproval::OnRequest, - &SandboxPolicy::DangerFullAccess, - &restricted_file_system_policy, - &command, - SandboxPermissions::RequireEscalated, - false, - ) - ); - } - - #[tokio::test] - async fn exec_approval_requirement_rejects_unmatched_sandbox_escalation_when_sandbox_rejection_enabled() - { - let command = vec!["madeup-cmd".to_string()]; - - let requirement = ExecPolicyManager::default() - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::Reject(RejectConfig { - sandbox_approval: true, - rules: false, - skill_approval: false, - request_permissions: false, - mcp_elicitations: false, - }), - sandbox_policy: &SandboxPolicy::new_read_only_policy(), - file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::RequireEscalated, - prefix_rule: None, - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::Forbidden { - reason: REJECT_SANDBOX_APPROVAL_REASON.to_string(), - } - ); - } - - #[tokio::test] - async fn mixed_rule_and_sandbox_prompt_prioritizes_rule_for_rejection_decision() { - let policy_src = r#"prefix_rule(pattern=["git"], decision="prompt")"#; - let mut parser = PolicyParser::new(); - parser - .parse("test.rules", policy_src) - .expect("parse policy"); - let manager = ExecPolicyManager::new(Arc::new(parser.build())); - let command = vec![ - "bash".to_string(), - "-lc".to_string(), - "git status && madeup-cmd".to_string(), - ]; - - let requirement = manager - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::Reject(RejectConfig { - sandbox_approval: true, - rules: false, - skill_approval: false, - request_permissions: false, - mcp_elicitations: false, - }), - sandbox_policy: &SandboxPolicy::new_read_only_policy(), - file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::RequireEscalated, - prefix_rule: None, - }) - .await; - - assert!(matches!( - requirement, - ExecApprovalRequirement::NeedsApproval { .. } - )); - } - - #[tokio::test] - async fn mixed_rule_and_sandbox_prompt_rejects_when_rules_rejection_enabled() { - let policy_src = r#"prefix_rule(pattern=["git"], decision="prompt")"#; - let mut parser = PolicyParser::new(); - parser - .parse("test.rules", policy_src) - .expect("parse policy"); - let manager = ExecPolicyManager::new(Arc::new(parser.build())); - let command = vec![ - "bash".to_string(), - "-lc".to_string(), - "git status && madeup-cmd".to_string(), - ]; - - let requirement = manager - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::Reject(RejectConfig { - sandbox_approval: false, - rules: true, - skill_approval: false, - request_permissions: false, - mcp_elicitations: false, - }), - sandbox_policy: &SandboxPolicy::new_read_only_policy(), - file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::RequireEscalated, - prefix_rule: None, - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::Forbidden { - reason: REJECT_RULES_APPROVAL_REASON.to_string(), - } - ); - } - - #[tokio::test] - async fn exec_approval_requirement_falls_back_to_heuristics() { - let command = vec!["cargo".to_string(), "build".to_string()]; - - let manager = ExecPolicyManager::default(); - let requirement = manager - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::UnlessTrusted, - sandbox_policy: &SandboxPolicy::new_read_only_policy(), - file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::UseDefault, - prefix_rule: None, - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::NeedsApproval { - reason: None, - proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(command)) - } - ); - } - - #[tokio::test] - async fn empty_bash_lc_script_falls_back_to_original_command() { - let command = vec!["bash".to_string(), "-lc".to_string(), "".to_string()]; - - let manager = ExecPolicyManager::default(); - let requirement = manager - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::UnlessTrusted, - sandbox_policy: &SandboxPolicy::new_read_only_policy(), - file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::UseDefault, - prefix_rule: None, - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::NeedsApproval { - reason: None, - proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(command)), - } - ); - } - - #[tokio::test] - async fn whitespace_bash_lc_script_falls_back_to_original_command() { - let command = vec![ - "bash".to_string(), - "-lc".to_string(), - " \n\t ".to_string(), - ]; - - let manager = ExecPolicyManager::default(); - let requirement = manager - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::UnlessTrusted, - sandbox_policy: &SandboxPolicy::new_read_only_policy(), - file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::UseDefault, - prefix_rule: None, - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::NeedsApproval { - reason: None, - proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(command)), - } - ); - } - - #[tokio::test] - async fn request_rule_uses_prefix_rule() { - let command = vec![ - "cargo".to_string(), - "install".to_string(), - "cargo-insta".to_string(), - ]; - let manager = ExecPolicyManager::default(); - - let requirement = manager - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::OnRequest, - sandbox_policy: &SandboxPolicy::new_read_only_policy(), - file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::RequireEscalated, - prefix_rule: Some(vec!["cargo".to_string(), "install".to_string()]), - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::NeedsApproval { - reason: None, - proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(vec![ - "cargo".to_string(), - "install".to_string(), - ])), - } - ); - } - - #[tokio::test] - async fn request_rule_falls_back_when_prefix_rule_does_not_approve_all_commands() { - let command = vec![ - "bash".to_string(), - "-lc".to_string(), - "cargo install cargo-insta && rm -rf /tmp/codex".to_string(), - ]; - let manager = ExecPolicyManager::default(); - - let requirement = manager - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::OnRequest, - sandbox_policy: &SandboxPolicy::DangerFullAccess, - file_system_sandbox_policy: &unrestricted_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::RequireEscalated, - prefix_rule: Some(vec!["cargo".to_string(), "install".to_string()]), - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::NeedsApproval { - reason: None, - proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(vec![ - "rm".to_string(), - "-rf".to_string(), - "/tmp/codex".to_string(), - ])), - } - ); - } - - #[tokio::test] - async fn heuristics_apply_when_other_commands_match_policy() { - let policy_src = r#"prefix_rule(pattern=["apple"], decision="allow")"#; - let mut parser = PolicyParser::new(); - parser - .parse("test.rules", policy_src) - .expect("parse policy"); - let policy = Arc::new(parser.build()); - let command = vec![ - "bash".to_string(), - "-lc".to_string(), - "apple | orange".to_string(), - ]; - - assert_eq!( - ExecPolicyManager::new(policy) - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::UnlessTrusted, - sandbox_policy: &SandboxPolicy::DangerFullAccess, - file_system_sandbox_policy: &unrestricted_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::UseDefault, - prefix_rule: None, - }) - .await, - ExecApprovalRequirement::NeedsApproval { - reason: None, - proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(vec![ - "orange".to_string() - ])) - } - ); - } - - #[tokio::test] - async fn append_execpolicy_amendment_updates_policy_and_file() { - let codex_home = tempdir().expect("create temp dir"); - let prefix = vec!["echo".to_string(), "hello".to_string()]; - let manager = ExecPolicyManager::default(); - - manager - .append_amendment_and_update(codex_home.path(), &ExecPolicyAmendment::from(prefix)) - .await - .expect("update policy"); - let updated_policy = manager.current(); - - let evaluation = updated_policy.check( - &["echo".to_string(), "hello".to_string(), "world".to_string()], - &|_| Decision::Allow, - ); - assert!(matches!( - evaluation, - Evaluation { - decision: Decision::Allow, - .. - } - )); - - let contents = fs::read_to_string(default_policy_path(codex_home.path())) - .expect("policy file should have been created"); - assert_eq!( - contents, - r#"prefix_rule(pattern=["echo", "hello"], decision="allow") -"# - ); - } - - #[tokio::test] - async fn append_execpolicy_amendment_rejects_empty_prefix() { - let codex_home = tempdir().expect("create temp dir"); - let manager = ExecPolicyManager::default(); - - let result = manager - .append_amendment_and_update(codex_home.path(), &ExecPolicyAmendment::from(vec![])) - .await; - - assert!(matches!( - result, - Err(ExecPolicyUpdateError::AppendRule { - source: AmendError::EmptyPrefix, - .. - }) - )); - } - - #[tokio::test] - async fn proposed_execpolicy_amendment_is_present_for_single_command_without_policy_match() { - let command = vec!["cargo".to_string(), "build".to_string()]; - - let manager = ExecPolicyManager::default(); - let requirement = manager - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::UnlessTrusted, - sandbox_policy: &SandboxPolicy::new_read_only_policy(), - file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::UseDefault, - prefix_rule: None, - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::NeedsApproval { - reason: None, - proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(command)) - } - ); - } - - #[tokio::test] - async fn proposed_execpolicy_amendment_is_omitted_when_policy_prompts() { - let policy_src = r#"prefix_rule(pattern=["rm"], decision="prompt")"#; - let mut parser = PolicyParser::new(); - parser - .parse("test.rules", policy_src) - .expect("parse policy"); - let policy = Arc::new(parser.build()); - let command = vec!["rm".to_string()]; - - let manager = ExecPolicyManager::new(policy); - let requirement = manager - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::OnRequest, - sandbox_policy: &SandboxPolicy::DangerFullAccess, - file_system_sandbox_policy: &unrestricted_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::UseDefault, - prefix_rule: None, - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::NeedsApproval { - reason: Some("`rm` requires approval by policy".to_string()), - proposed_execpolicy_amendment: None, - } - ); - } - - #[tokio::test] - async fn proposed_execpolicy_amendment_is_present_for_multi_command_scripts() { - let command = vec![ - "bash".to_string(), - "-lc".to_string(), - "cargo build && echo ok".to_string(), - ]; - let manager = ExecPolicyManager::default(); - let requirement = manager - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::UnlessTrusted, - sandbox_policy: &SandboxPolicy::new_read_only_policy(), - file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::UseDefault, - prefix_rule: None, - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::NeedsApproval { - reason: None, - proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(vec![ - "cargo".to_string(), - "build".to_string() - ])), - } - ); - } - - #[tokio::test] - async fn proposed_execpolicy_amendment_uses_first_no_match_in_multi_command_scripts() { - let policy_src = r#"prefix_rule(pattern=["cat"], decision="allow")"#; - let mut parser = PolicyParser::new(); - parser - .parse("test.rules", policy_src) - .expect("parse policy"); - let policy = Arc::new(parser.build()); - - let command = vec![ - "bash".to_string(), - "-lc".to_string(), - "cat && apple".to_string(), - ]; - - assert_eq!( - ExecPolicyManager::new(policy) - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::UnlessTrusted, - sandbox_policy: &SandboxPolicy::new_read_only_policy(), - file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::UseDefault, - prefix_rule: None, - }) - .await, - ExecApprovalRequirement::NeedsApproval { - reason: None, - proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(vec![ - "apple".to_string() - ])), - } - ); - } - - #[tokio::test] - async fn proposed_execpolicy_amendment_is_present_when_heuristics_allow() { - let command = vec!["echo".to_string(), "safe".to_string()]; - - let manager = ExecPolicyManager::default(); - let requirement = manager - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::OnRequest, - sandbox_policy: &SandboxPolicy::new_read_only_policy(), - file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::UseDefault, - prefix_rule: None, - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::Skip { - bypass_sandbox: false, - proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(command)), - } - ); - } - - #[tokio::test] - async fn proposed_execpolicy_amendment_is_suppressed_when_policy_matches_allow() { - let policy_src = r#"prefix_rule(pattern=["echo"], decision="allow")"#; - let mut parser = PolicyParser::new(); - parser - .parse("test.rules", policy_src) - .expect("parse policy"); - let policy = Arc::new(parser.build()); - let command = vec!["echo".to_string(), "safe".to_string()]; - - let manager = ExecPolicyManager::new(policy); - let requirement = manager - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::OnRequest, - sandbox_policy: &SandboxPolicy::new_read_only_policy(), - file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::UseDefault, - prefix_rule: None, - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::Skip { - bypass_sandbox: true, - proposed_execpolicy_amendment: None, - } - ); - } - - fn derive_requested_execpolicy_amendment_for_test( - prefix_rule: Option<&Vec>, - matched_rules: &[RuleMatch], - ) -> Option { - let commands = prefix_rule - .cloned() - .map(|prefix_rule| vec![prefix_rule]) - .unwrap_or_else(|| vec![vec!["echo".to_string()]]); - derive_requested_execpolicy_amendment_from_prefix_rule( - prefix_rule, - matched_rules, - &Policy::empty(), - &commands, - &|_: &[String]| Decision::Allow, - &MatchOptions::default(), - ) - } - - #[test] - fn derive_requested_execpolicy_amendment_returns_none_for_missing_prefix_rule() { - assert_eq!( - None, - derive_requested_execpolicy_amendment_for_test(None, &[]) - ); - } - - #[test] - fn derive_requested_execpolicy_amendment_returns_none_for_empty_prefix_rule() { - assert_eq!( - None, - derive_requested_execpolicy_amendment_for_test(Some(&Vec::new()), &[]) - ); - } - - #[test] - fn derive_requested_execpolicy_amendment_returns_none_for_exact_banned_prefix_rule() { - assert_eq!( - None, - derive_requested_execpolicy_amendment_for_test( - Some(&vec!["python".to_string(), "-c".to_string()]), - &[], - ) - ); - } - - #[test] - fn derive_requested_execpolicy_amendment_returns_none_for_windows_and_pypy_variants() { - for prefix_rule in [ - vec!["py".to_string()], - vec!["py".to_string(), "-3".to_string()], - vec!["pythonw".to_string()], - vec!["pyw".to_string()], - vec!["pypy".to_string()], - vec!["pypy3".to_string()], - ] { - assert_eq!( - None, - derive_requested_execpolicy_amendment_for_test(Some(&prefix_rule), &[]) - ); - } - } - - #[test] - fn derive_requested_execpolicy_amendment_returns_none_for_shell_and_powershell_variants() { - for prefix_rule in [ - vec!["bash".to_string(), "-lc".to_string()], - vec!["sh".to_string(), "-c".to_string()], - vec!["sh".to_string(), "-lc".to_string()], - vec!["zsh".to_string(), "-lc".to_string()], - vec!["/bin/bash".to_string(), "-lc".to_string()], - vec!["/bin/zsh".to_string(), "-lc".to_string()], - vec!["pwsh".to_string()], - vec!["pwsh".to_string(), "-Command".to_string()], - vec!["pwsh".to_string(), "-c".to_string()], - vec!["powershell".to_string()], - vec!["powershell".to_string(), "-Command".to_string()], - vec!["powershell".to_string(), "-c".to_string()], - vec!["powershell.exe".to_string()], - vec!["powershell.exe".to_string(), "-Command".to_string()], - vec!["powershell.exe".to_string(), "-c".to_string()], - ] { - assert_eq!( - None, - derive_requested_execpolicy_amendment_for_test(Some(&prefix_rule), &[]) - ); - } - } - - #[test] - fn derive_requested_execpolicy_amendment_allows_non_exact_banned_prefix_rule_match() { - let prefix_rule = vec![ - "python".to_string(), - "-c".to_string(), - "print('hi')".to_string(), - ]; - - assert_eq!( - Some(ExecPolicyAmendment::new(prefix_rule.clone())), - derive_requested_execpolicy_amendment_for_test(Some(&prefix_rule), &[]) - ); - } - - #[test] - fn derive_requested_execpolicy_amendment_returns_none_when_policy_matches() { - let prefix_rule = vec!["cargo".to_string(), "build".to_string()]; - - let matched_rules_prompt = vec![RuleMatch::PrefixRuleMatch { - matched_prefix: vec!["cargo".to_string()], - decision: Decision::Prompt, - resolved_program: None, - justification: None, - }]; - assert_eq!( - None, - derive_requested_execpolicy_amendment_for_test( - Some(&prefix_rule), - &matched_rules_prompt - ), - "should return none when prompt policy matches" - ); - let matched_rules_allow = vec![RuleMatch::PrefixRuleMatch { - matched_prefix: vec!["cargo".to_string()], - decision: Decision::Allow, - resolved_program: None, - justification: None, - }]; - assert_eq!( - None, - derive_requested_execpolicy_amendment_for_test( - Some(&prefix_rule), - &matched_rules_allow - ), - "should return none when prompt policy matches" - ); - let matched_rules_forbidden = vec![RuleMatch::PrefixRuleMatch { - matched_prefix: vec!["cargo".to_string()], - decision: Decision::Forbidden, - resolved_program: None, - justification: None, - }]; - assert_eq!( - None, - derive_requested_execpolicy_amendment_for_test( - Some(&prefix_rule), - &matched_rules_forbidden, - ), - "should return none when prompt policy matches" - ); - } - - #[tokio::test] - async fn dangerous_rm_rf_requires_approval_in_danger_full_access() { - let command = vec_str(&["rm", "-rf", "/tmp/nonexistent"]); - let manager = ExecPolicyManager::default(); - let requirement = manager - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &command, - approval_policy: AskForApproval::OnRequest, - sandbox_policy: &SandboxPolicy::DangerFullAccess, - file_system_sandbox_policy: &unrestricted_file_system_sandbox_policy(), - sandbox_permissions: SandboxPermissions::UseDefault, - prefix_rule: None, - }) - .await; - - assert_eq!( - requirement, - ExecApprovalRequirement::NeedsApproval { - reason: None, - proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(command)), - } - ); - } - - fn vec_str(items: &[&str]) -> Vec { - items.iter().map(std::string::ToString::to_string).collect() - } - - /// Note this test behaves differently on Windows because it exercises an - /// `if cfg!(windows)` code path in render_decision_for_unmatched_command(). - #[tokio::test] - async fn verify_approval_requirement_for_unsafe_powershell_command() { - // `brew install powershell` to run this test on a Mac! - // Note `pwsh` is required to parse a PowerShell command to see if it - // is safe. - if which::which("pwsh").is_err() { - return; - } - - let policy = ExecPolicyManager::new(Arc::new(Policy::empty())); - let permissions = SandboxPermissions::UseDefault; - - // This command should not be run without user approval unless there is - // a proper sandbox in place to ensure safety. - let sneaky_command = vec_str(&["pwsh", "-Command", "echo hi @(calc)"]); - let expected_amendment = Some(ExecPolicyAmendment::new(vec_str(&[ - "pwsh", - "-Command", - "echo hi @(calc)", - ]))); - let (pwsh_approval_reason, expected_req) = if cfg!(windows) { - ( - r#"On Windows, SandboxPolicy::ReadOnly should be assumed to mean - that no sandbox is present, so anything that is not "provably - safe" should require approval."#, - ExecApprovalRequirement::NeedsApproval { - reason: None, - proposed_execpolicy_amendment: expected_amendment.clone(), - }, - ) - } else { - ( - "On non-Windows, rely on the read-only sandbox to prevent harm.", - ExecApprovalRequirement::Skip { - bypass_sandbox: false, - proposed_execpolicy_amendment: expected_amendment.clone(), - }, - ) - }; - assert_eq!( - expected_req, - policy - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &sneaky_command, - approval_policy: AskForApproval::OnRequest, - sandbox_policy: &SandboxPolicy::new_read_only_policy(), - file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), - sandbox_permissions: permissions, - prefix_rule: None, - }) - .await, - "{pwsh_approval_reason}" - ); - - // This is flagged as a dangerous command on all platforms. - let dangerous_command = vec_str(&["rm", "-rf", "/important/data"]); - assert_eq!( - ExecApprovalRequirement::NeedsApproval { - reason: None, - proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(vec_str(&[ - "rm", - "-rf", - "/important/data", - ]))), - }, - policy - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &dangerous_command, - approval_policy: AskForApproval::OnRequest, - sandbox_policy: &SandboxPolicy::new_read_only_policy(), - file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), - sandbox_permissions: permissions, - prefix_rule: None, - }) - .await, - r#"On all platforms, a forbidden command should require approval - (unless AskForApproval::Never is specified)."# - ); - - // A dangerous command should be forbidden if the user has specified - // AskForApproval::Never. - assert_eq!( - ExecApprovalRequirement::Forbidden { - reason: "`rm -rf /important/data` rejected: blocked by policy".to_string(), - }, - policy - .create_exec_approval_requirement_for_command(ExecApprovalRequest { - command: &dangerous_command, - approval_policy: AskForApproval::Never, - sandbox_policy: &SandboxPolicy::new_read_only_policy(), - file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), - sandbox_permissions: permissions, - prefix_rule: None, - }) - .await, - r#"On all platforms, a forbidden command should require approval - (unless AskForApproval::Never is specified)."# - ); - } -} +#[path = "exec_policy_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/exec_policy_tests.rs b/codex-rs/core/src/exec_policy_tests.rs new file mode 100644 index 0000000000..8c7286635e --- /dev/null +++ b/codex-rs/core/src/exec_policy_tests.rs @@ -0,0 +1,1594 @@ +use super::*; +use crate::config_loader::ConfigLayerEntry; +use crate::config_loader::ConfigLayerStack; +use crate::config_loader::ConfigRequirements; +use crate::config_loader::ConfigRequirementsToml; +use codex_app_server_protocol::ConfigLayerSource; +use codex_protocol::permissions::FileSystemAccessMode; +use codex_protocol::permissions::FileSystemPath; +use codex_protocol::permissions::FileSystemSandboxEntry; +use codex_protocol::permissions::FileSystemSpecialPath; +use codex_protocol::protocol::AskForApproval; +use codex_protocol::protocol::RejectConfig; +use codex_protocol::protocol::SandboxPolicy; +use codex_utils_absolute_path::AbsolutePathBuf; +use pretty_assertions::assert_eq; +use std::fs; +use std::path::Path; +use std::path::PathBuf; +use std::sync::Arc; +use tempfile::tempdir; +use toml::Value as TomlValue; + +fn config_stack_for_dot_codex_folder(dot_codex_folder: &Path) -> ConfigLayerStack { + let dot_codex_folder = + AbsolutePathBuf::from_absolute_path(dot_codex_folder).expect("absolute dot_codex_folder"); + let layer = ConfigLayerEntry::new( + ConfigLayerSource::Project { dot_codex_folder }, + TomlValue::Table(Default::default()), + ); + ConfigLayerStack::new( + vec![layer], + ConfigRequirements::default(), + ConfigRequirementsToml::default(), + ) + .expect("ConfigLayerStack") +} + +fn host_absolute_path(segments: &[&str]) -> String { + let mut path = if cfg!(windows) { + PathBuf::from(r"C:\") + } else { + PathBuf::from("/") + }; + for segment in segments { + path.push(segment); + } + path.to_string_lossy().into_owned() +} + +fn host_program_path(name: &str) -> String { + let executable_name = if cfg!(windows) { + format!("{name}.exe") + } else { + name.to_string() + }; + host_absolute_path(&["usr", "bin", &executable_name]) +} + +fn starlark_string(value: &str) -> String { + value.replace('\\', "\\\\").replace('"', "\\\"") +} + +fn read_only_file_system_sandbox_policy() -> FileSystemSandboxPolicy { + FileSystemSandboxPolicy::restricted(vec![FileSystemSandboxEntry { + path: FileSystemPath::Special { + value: FileSystemSpecialPath::Root, + }, + access: FileSystemAccessMode::Read, + }]) +} + +fn unrestricted_file_system_sandbox_policy() -> FileSystemSandboxPolicy { + FileSystemSandboxPolicy::unrestricted() +} + +#[tokio::test] +async fn returns_empty_policy_when_no_policy_files_exist() { + let temp_dir = tempdir().expect("create temp dir"); + let config_stack = config_stack_for_dot_codex_folder(temp_dir.path()); + + let manager = ExecPolicyManager::load(&config_stack) + .await + .expect("manager result"); + let policy = manager.current(); + + let commands = [vec!["rm".to_string()]]; + assert_eq!( + Evaluation { + decision: Decision::Allow, + matched_rules: vec![RuleMatch::HeuristicsRuleMatch { + command: vec!["rm".to_string()], + decision: Decision::Allow + }], + }, + policy.check_multiple(commands.iter(), &|_| Decision::Allow) + ); + assert!(!temp_dir.path().join(RULES_DIR_NAME).exists()); +} + +#[tokio::test] +async fn collect_policy_files_returns_empty_when_dir_missing() { + let temp_dir = tempdir().expect("create temp dir"); + + let policy_dir = temp_dir.path().join(RULES_DIR_NAME); + let files = collect_policy_files(&policy_dir) + .await + .expect("collect policy files"); + + assert!(files.is_empty()); +} + +#[tokio::test] +async fn format_exec_policy_error_with_source_renders_range() { + let temp_dir = tempdir().expect("create temp dir"); + let config_stack = config_stack_for_dot_codex_folder(temp_dir.path()); + let policy_dir = temp_dir.path().join(RULES_DIR_NAME); + fs::create_dir_all(&policy_dir).expect("create policy dir"); + let broken_path = policy_dir.join("broken.rules"); + fs::write( + &broken_path, + r#"prefix_rule( + pattern = ["tmux capture-pane"], + decision = "allow", + match = ["tmux capture-pane -p"], +)"#, + ) + .expect("write broken policy file"); + + let err = load_exec_policy(&config_stack) + .await + .expect_err("expected parse error"); + let rendered = format_exec_policy_error_with_source(&err); + + assert!(rendered.contains("broken.rules:1:")); + assert!(rendered.contains("on or around line 1")); +} + +#[test] +fn parse_starlark_line_from_message_extracts_path_and_line() { + let parsed = parse_starlark_line_from_message( + "/tmp/default.rules:143:1: starlark error: error: Parse error: unexpected new line", + ) + .expect("parse should succeed"); + + assert_eq!(parsed.0, PathBuf::from("/tmp/default.rules")); + assert_eq!(parsed.1, 143); +} + +#[test] +fn parse_starlark_line_from_message_rejects_zero_line() { + let parsed = parse_starlark_line_from_message( + "/tmp/default.rules:0:1: starlark error: error: Parse error: unexpected new line", + ); + assert_eq!(parsed, None); +} + +#[tokio::test] +async fn loads_policies_from_policy_subdirectory() { + let temp_dir = tempdir().expect("create temp dir"); + let config_stack = config_stack_for_dot_codex_folder(temp_dir.path()); + let policy_dir = temp_dir.path().join(RULES_DIR_NAME); + fs::create_dir_all(&policy_dir).expect("create policy dir"); + fs::write( + policy_dir.join("deny.rules"), + r#"prefix_rule(pattern=["rm"], decision="forbidden")"#, + ) + .expect("write policy file"); + + let policy = load_exec_policy(&config_stack) + .await + .expect("policy result"); + let command = [vec!["rm".to_string()]]; + assert_eq!( + Evaluation { + decision: Decision::Forbidden, + matched_rules: vec![RuleMatch::PrefixRuleMatch { + matched_prefix: vec!["rm".to_string()], + decision: Decision::Forbidden, + resolved_program: None, + justification: None, + }], + }, + policy.check_multiple(command.iter(), &|_| Decision::Allow) + ); +} + +#[tokio::test] +async fn merges_requirements_exec_policy_network_rules() -> anyhow::Result<()> { + let temp_dir = tempdir()?; + + let mut requirements_exec_policy = Policy::empty(); + requirements_exec_policy.add_network_rule( + "blocked.example.com", + codex_execpolicy::NetworkRuleProtocol::Https, + Decision::Forbidden, + None, + )?; + + let requirements = ConfigRequirements { + exec_policy: Some(codex_config::Sourced::new( + codex_config::RequirementsExecPolicy::new(requirements_exec_policy), + codex_config::RequirementSource::Unknown, + )), + ..ConfigRequirements::default() + }; + let dot_codex_folder = AbsolutePathBuf::from_absolute_path(temp_dir.path())?; + let layer = ConfigLayerEntry::new( + ConfigLayerSource::Project { dot_codex_folder }, + TomlValue::Table(Default::default()), + ); + let config_stack = + ConfigLayerStack::new(vec![layer], requirements, ConfigRequirementsToml::default())?; + + let policy = load_exec_policy(&config_stack).await?; + let (allowed, denied) = policy.compiled_network_domains(); + + assert!(allowed.is_empty()); + assert_eq!(denied, vec!["blocked.example.com".to_string()]); + Ok(()) +} + +#[tokio::test] +async fn preserves_host_executables_when_requirements_overlay_is_present() -> anyhow::Result<()> { + let temp_dir = tempdir()?; + let policy_dir = temp_dir.path().join(RULES_DIR_NAME); + fs::create_dir_all(&policy_dir)?; + let git_path = host_absolute_path(&["usr", "bin", "git"]); + let git_path_literal = starlark_string(&git_path); + fs::write( + policy_dir.join("host.rules"), + format!( + r#" +host_executable(name = "git", paths = ["{git_path_literal}"]) +"# + ), + )?; + + let mut requirements_exec_policy = Policy::empty(); + requirements_exec_policy.add_network_rule( + "blocked.example.com", + codex_execpolicy::NetworkRuleProtocol::Https, + Decision::Forbidden, + None, + )?; + + let requirements = ConfigRequirements { + exec_policy: Some(codex_config::Sourced::new( + codex_config::RequirementsExecPolicy::new(requirements_exec_policy), + codex_config::RequirementSource::Unknown, + )), + ..ConfigRequirements::default() + }; + let dot_codex_folder = AbsolutePathBuf::from_absolute_path(temp_dir.path())?; + let layer = ConfigLayerEntry::new( + ConfigLayerSource::Project { dot_codex_folder }, + TomlValue::Table(Default::default()), + ); + let config_stack = + ConfigLayerStack::new(vec![layer], requirements, ConfigRequirementsToml::default())?; + + let policy = load_exec_policy(&config_stack).await?; + + assert_eq!( + policy + .host_executables() + .get("git") + .expect("missing git host executable") + .as_ref(), + [AbsolutePathBuf::try_from(git_path)?] + ); + Ok(()) +} + +#[tokio::test] +async fn ignores_policies_outside_policy_dir() { + let temp_dir = tempdir().expect("create temp dir"); + let config_stack = config_stack_for_dot_codex_folder(temp_dir.path()); + fs::write( + temp_dir.path().join("root.rules"), + r#"prefix_rule(pattern=["ls"], decision="prompt")"#, + ) + .expect("write policy file"); + + let policy = load_exec_policy(&config_stack) + .await + .expect("policy result"); + let command = [vec!["ls".to_string()]]; + assert_eq!( + Evaluation { + decision: Decision::Allow, + matched_rules: vec![RuleMatch::HeuristicsRuleMatch { + command: vec!["ls".to_string()], + decision: Decision::Allow + }], + }, + policy.check_multiple(command.iter(), &|_| Decision::Allow) + ); +} + +#[tokio::test] +async fn ignores_rules_from_untrusted_project_layers() -> anyhow::Result<()> { + let project_dir = tempdir()?; + let policy_dir = project_dir.path().join(RULES_DIR_NAME); + fs::create_dir_all(&policy_dir)?; + fs::write( + policy_dir.join("untrusted.rules"), + r#"prefix_rule(pattern=["ls"], decision="forbidden")"#, + )?; + + let project_dot_codex_folder = AbsolutePathBuf::from_absolute_path(project_dir.path())?; + let layers = vec![ConfigLayerEntry::new_disabled( + ConfigLayerSource::Project { + dot_codex_folder: project_dot_codex_folder, + }, + TomlValue::Table(Default::default()), + "marked untrusted", + )]; + let config_stack = ConfigLayerStack::new( + layers, + ConfigRequirements::default(), + ConfigRequirementsToml::default(), + )?; + + let policy = load_exec_policy(&config_stack).await?; + + assert_eq!( + Evaluation { + decision: Decision::Allow, + matched_rules: vec![RuleMatch::HeuristicsRuleMatch { + command: vec!["ls".to_string()], + decision: Decision::Allow, + }], + }, + policy.check_multiple([vec!["ls".to_string()]].iter(), &|_| Decision::Allow) + ); + Ok(()) +} + +#[tokio::test] +async fn loads_policies_from_multiple_config_layers() -> anyhow::Result<()> { + let user_dir = tempdir()?; + let project_dir = tempdir()?; + + let user_policy_dir = user_dir.path().join(RULES_DIR_NAME); + fs::create_dir_all(&user_policy_dir)?; + fs::write( + user_policy_dir.join("user.rules"), + r#"prefix_rule(pattern=["rm"], decision="forbidden")"#, + )?; + + let project_policy_dir = project_dir.path().join(RULES_DIR_NAME); + fs::create_dir_all(&project_policy_dir)?; + fs::write( + project_policy_dir.join("project.rules"), + r#"prefix_rule(pattern=["ls"], decision="prompt")"#, + )?; + + let user_config_toml = + AbsolutePathBuf::from_absolute_path(user_dir.path().join("config.toml"))?; + let project_dot_codex_folder = AbsolutePathBuf::from_absolute_path(project_dir.path())?; + let layers = vec![ + ConfigLayerEntry::new( + ConfigLayerSource::User { + file: user_config_toml, + }, + TomlValue::Table(Default::default()), + ), + ConfigLayerEntry::new( + ConfigLayerSource::Project { + dot_codex_folder: project_dot_codex_folder, + }, + TomlValue::Table(Default::default()), + ), + ]; + let config_stack = ConfigLayerStack::new( + layers, + ConfigRequirements::default(), + ConfigRequirementsToml::default(), + )?; + + let policy = load_exec_policy(&config_stack).await?; + + assert_eq!( + Evaluation { + decision: Decision::Forbidden, + matched_rules: vec![RuleMatch::PrefixRuleMatch { + matched_prefix: vec!["rm".to_string()], + decision: Decision::Forbidden, + resolved_program: None, + justification: None, + }], + }, + policy.check_multiple([vec!["rm".to_string()]].iter(), &|_| Decision::Allow) + ); + assert_eq!( + Evaluation { + decision: Decision::Prompt, + matched_rules: vec![RuleMatch::PrefixRuleMatch { + matched_prefix: vec!["ls".to_string()], + decision: Decision::Prompt, + resolved_program: None, + justification: None, + }], + }, + policy.check_multiple([vec!["ls".to_string()]].iter(), &|_| Decision::Allow) + ); + Ok(()) +} + +#[tokio::test] +async fn evaluates_bash_lc_inner_commands() { + let policy_src = r#" +prefix_rule(pattern=["rm"], decision="forbidden") +"#; + let mut parser = PolicyParser::new(); + parser + .parse("test.rules", policy_src) + .expect("parse policy"); + let policy = Arc::new(parser.build()); + + let forbidden_script = vec![ + "bash".to_string(), + "-lc".to_string(), + "rm -rf /some/important/folder".to_string(), + ]; + + let manager = ExecPolicyManager::new(policy); + let requirement = manager + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &forbidden_script, + approval_policy: AskForApproval::OnRequest, + sandbox_policy: &SandboxPolicy::DangerFullAccess, + file_system_sandbox_policy: &unrestricted_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::UseDefault, + prefix_rule: None, + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::Forbidden { + reason: "`bash -lc 'rm -rf /some/important/folder'` rejected: policy forbids commands starting with `rm`".to_string() + } + ); +} + +#[test] +fn commands_for_exec_policy_falls_back_for_empty_shell_script() { + let command = vec!["bash".to_string(), "-lc".to_string(), "".to_string()]; + + assert_eq!(commands_for_exec_policy(&command), (vec![command], false)); +} + +#[test] +fn commands_for_exec_policy_falls_back_for_whitespace_shell_script() { + let command = vec![ + "bash".to_string(), + "-lc".to_string(), + " \n\t ".to_string(), + ]; + + assert_eq!(commands_for_exec_policy(&command), (vec![command], false)); +} + +#[tokio::test] +async fn evaluates_heredoc_script_against_prefix_rules() { + let policy_src = r#"prefix_rule(pattern=["python3"], decision="allow")"#; + let mut parser = PolicyParser::new(); + parser + .parse("test.rules", policy_src) + .expect("parse policy"); + let policy = Arc::new(parser.build()); + let command = vec![ + "bash".to_string(), + "-lc".to_string(), + "python3 <<'PY'\nprint('hello')\nPY".to_string(), + ]; + + let requirement = ExecPolicyManager::new(policy) + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::OnRequest, + sandbox_policy: &SandboxPolicy::new_read_only_policy(), + file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::UseDefault, + prefix_rule: None, + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::Skip { + bypass_sandbox: true, + proposed_execpolicy_amendment: None, + } + ); +} + +#[tokio::test] +async fn omits_auto_amendment_for_heredoc_fallback_prompts() { + let command = vec![ + "bash".to_string(), + "-lc".to_string(), + "python3 <<'PY'\nprint('hello')\nPY".to_string(), + ]; + + let requirement = ExecPolicyManager::default() + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::UnlessTrusted, + sandbox_policy: &SandboxPolicy::new_read_only_policy(), + file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::UseDefault, + prefix_rule: None, + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::NeedsApproval { + reason: None, + proposed_execpolicy_amendment: None, + } + ); +} + +#[tokio::test] +async fn drops_requested_amendment_for_heredoc_fallback_prompts_when_it_wont_match() { + let command = vec![ + "bash".to_string(), + "-lc".to_string(), + "python3 <<'PY'\nprint('hello')\nPY".to_string(), + ]; + let requested_prefix = vec!["python3".to_string(), "-m".to_string(), "pip".to_string()]; + + let requirement = ExecPolicyManager::default() + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::UnlessTrusted, + sandbox_policy: &SandboxPolicy::new_read_only_policy(), + file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::UseDefault, + prefix_rule: Some(requested_prefix.clone()), + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::NeedsApproval { + reason: None, + proposed_execpolicy_amendment: None, + } + ); +} + +#[tokio::test] +async fn justification_is_included_in_forbidden_exec_approval_requirement() { + let policy_src = r#" +prefix_rule( + pattern=["rm"], + decision="forbidden", + justification="destructive command", +) +"#; + let mut parser = PolicyParser::new(); + parser + .parse("test.rules", policy_src) + .expect("parse policy"); + let policy = Arc::new(parser.build()); + + let manager = ExecPolicyManager::new(policy); + let requirement = manager + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &[ + "rm".to_string(), + "-rf".to_string(), + "/some/important/folder".to_string(), + ], + approval_policy: AskForApproval::OnRequest, + sandbox_policy: &SandboxPolicy::DangerFullAccess, + file_system_sandbox_policy: &unrestricted_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::UseDefault, + prefix_rule: None, + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::Forbidden { + reason: "`rm -rf /some/important/folder` rejected: destructive command".to_string() + } + ); +} + +#[tokio::test] +async fn exec_approval_requirement_prefers_execpolicy_match() { + let policy_src = r#"prefix_rule(pattern=["rm"], decision="prompt")"#; + let mut parser = PolicyParser::new(); + parser + .parse("test.rules", policy_src) + .expect("parse policy"); + let policy = Arc::new(parser.build()); + let command = vec!["rm".to_string()]; + + let manager = ExecPolicyManager::new(policy); + let requirement = manager + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::OnRequest, + sandbox_policy: &SandboxPolicy::DangerFullAccess, + file_system_sandbox_policy: &unrestricted_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::UseDefault, + prefix_rule: None, + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::NeedsApproval { + reason: Some("`rm` requires approval by policy".to_string()), + proposed_execpolicy_amendment: None, + } + ); +} + +#[tokio::test] +async fn absolute_path_exec_approval_requirement_matches_host_executable_rules() { + let git_path = host_program_path("git"); + let git_path_literal = starlark_string(&git_path); + let policy_src = format!( + r#" +host_executable(name = "git", paths = ["{git_path_literal}"]) +prefix_rule(pattern=["git"], decision="allow") +"# + ); + let mut parser = PolicyParser::new(); + parser + .parse("test.rules", &policy_src) + .expect("parse policy"); + let manager = ExecPolicyManager::new(Arc::new(parser.build())); + let command = vec![git_path, "status".to_string()]; + + let requirement = manager + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::UnlessTrusted, + sandbox_policy: &SandboxPolicy::new_read_only_policy(), + file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::UseDefault, + prefix_rule: None, + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::Skip { + bypass_sandbox: true, + proposed_execpolicy_amendment: None, + } + ); +} + +#[tokio::test] +async fn absolute_path_exec_approval_requirement_ignores_disallowed_host_executable_paths() { + let allowed_git_path = host_program_path("git"); + let disallowed_git_path = host_absolute_path(&[ + "opt", + "homebrew", + "bin", + if cfg!(windows) { "git.exe" } else { "git" }, + ]); + let allowed_git_path_literal = starlark_string(&allowed_git_path); + let policy_src = format!( + r#" +host_executable(name = "git", paths = ["{allowed_git_path_literal}"]) +prefix_rule(pattern=["git"], decision="prompt") +"# + ); + let mut parser = PolicyParser::new(); + parser + .parse("test.rules", &policy_src) + .expect("parse policy"); + let manager = ExecPolicyManager::new(Arc::new(parser.build())); + let command = vec![disallowed_git_path, "status".to_string()]; + + let requirement = manager + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::UnlessTrusted, + sandbox_policy: &SandboxPolicy::new_read_only_policy(), + file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::UseDefault, + prefix_rule: None, + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::Skip { + bypass_sandbox: false, + proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(command)), + } + ); +} + +#[tokio::test] +async fn requested_prefix_rule_can_approve_absolute_path_commands() { + let command = vec![ + host_program_path("cargo"), + "install".to_string(), + "cargo-insta".to_string(), + ]; + let manager = ExecPolicyManager::default(); + + let requirement = manager + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::UnlessTrusted, + sandbox_policy: &SandboxPolicy::new_read_only_policy(), + file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::UseDefault, + prefix_rule: Some(vec!["cargo".to_string(), "install".to_string()]), + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::NeedsApproval { + reason: None, + proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(vec![ + "cargo".to_string(), + "install".to_string(), + ])), + } + ); +} + +#[tokio::test] +async fn exec_approval_requirement_respects_approval_policy() { + let policy_src = r#"prefix_rule(pattern=["rm"], decision="prompt")"#; + let mut parser = PolicyParser::new(); + parser + .parse("test.rules", policy_src) + .expect("parse policy"); + let policy = Arc::new(parser.build()); + let command = vec!["rm".to_string()]; + + let manager = ExecPolicyManager::new(policy); + let requirement = manager + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::Never, + sandbox_policy: &SandboxPolicy::DangerFullAccess, + file_system_sandbox_policy: &unrestricted_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::UseDefault, + prefix_rule: None, + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::Forbidden { + reason: PROMPT_CONFLICT_REASON.to_string() + } + ); +} + +#[test] +fn unmatched_reject_policy_still_prompts_for_restricted_sandbox_escalation() { + let command = vec!["madeup-cmd".to_string()]; + + assert_eq!( + Decision::Prompt, + render_decision_for_unmatched_command( + AskForApproval::Reject(RejectConfig { + sandbox_approval: false, + rules: false, + skill_approval: false, + request_permissions: false, + mcp_elicitations: false, + }), + &SandboxPolicy::new_read_only_policy(), + &read_only_file_system_sandbox_policy(), + &command, + SandboxPermissions::RequireEscalated, + false, + ) + ); +} + +#[test] +fn unmatched_on_request_uses_split_filesystem_policy_for_escalation_prompts() { + let command = vec!["madeup-cmd".to_string()]; + let restricted_file_system_policy = FileSystemSandboxPolicy::restricted(vec![]); + + assert_eq!( + Decision::Prompt, + render_decision_for_unmatched_command( + AskForApproval::OnRequest, + &SandboxPolicy::DangerFullAccess, + &restricted_file_system_policy, + &command, + SandboxPermissions::RequireEscalated, + false, + ) + ); +} + +#[tokio::test] +async fn exec_approval_requirement_rejects_unmatched_sandbox_escalation_when_sandbox_rejection_enabled() + { + let command = vec!["madeup-cmd".to_string()]; + + let requirement = ExecPolicyManager::default() + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::Reject(RejectConfig { + sandbox_approval: true, + rules: false, + skill_approval: false, + request_permissions: false, + mcp_elicitations: false, + }), + sandbox_policy: &SandboxPolicy::new_read_only_policy(), + file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::RequireEscalated, + prefix_rule: None, + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::Forbidden { + reason: REJECT_SANDBOX_APPROVAL_REASON.to_string(), + } + ); +} + +#[tokio::test] +async fn mixed_rule_and_sandbox_prompt_prioritizes_rule_for_rejection_decision() { + let policy_src = r#"prefix_rule(pattern=["git"], decision="prompt")"#; + let mut parser = PolicyParser::new(); + parser + .parse("test.rules", policy_src) + .expect("parse policy"); + let manager = ExecPolicyManager::new(Arc::new(parser.build())); + let command = vec![ + "bash".to_string(), + "-lc".to_string(), + "git status && madeup-cmd".to_string(), + ]; + + let requirement = manager + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::Reject(RejectConfig { + sandbox_approval: true, + rules: false, + skill_approval: false, + request_permissions: false, + mcp_elicitations: false, + }), + sandbox_policy: &SandboxPolicy::new_read_only_policy(), + file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::RequireEscalated, + prefix_rule: None, + }) + .await; + + assert!(matches!( + requirement, + ExecApprovalRequirement::NeedsApproval { .. } + )); +} + +#[tokio::test] +async fn mixed_rule_and_sandbox_prompt_rejects_when_rules_rejection_enabled() { + let policy_src = r#"prefix_rule(pattern=["git"], decision="prompt")"#; + let mut parser = PolicyParser::new(); + parser + .parse("test.rules", policy_src) + .expect("parse policy"); + let manager = ExecPolicyManager::new(Arc::new(parser.build())); + let command = vec![ + "bash".to_string(), + "-lc".to_string(), + "git status && madeup-cmd".to_string(), + ]; + + let requirement = manager + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::Reject(RejectConfig { + sandbox_approval: false, + rules: true, + skill_approval: false, + request_permissions: false, + mcp_elicitations: false, + }), + sandbox_policy: &SandboxPolicy::new_read_only_policy(), + file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::RequireEscalated, + prefix_rule: None, + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::Forbidden { + reason: REJECT_RULES_APPROVAL_REASON.to_string(), + } + ); +} + +#[tokio::test] +async fn exec_approval_requirement_falls_back_to_heuristics() { + let command = vec!["cargo".to_string(), "build".to_string()]; + + let manager = ExecPolicyManager::default(); + let requirement = manager + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::UnlessTrusted, + sandbox_policy: &SandboxPolicy::new_read_only_policy(), + file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::UseDefault, + prefix_rule: None, + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::NeedsApproval { + reason: None, + proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(command)) + } + ); +} + +#[tokio::test] +async fn empty_bash_lc_script_falls_back_to_original_command() { + let command = vec!["bash".to_string(), "-lc".to_string(), "".to_string()]; + + let manager = ExecPolicyManager::default(); + let requirement = manager + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::UnlessTrusted, + sandbox_policy: &SandboxPolicy::new_read_only_policy(), + file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::UseDefault, + prefix_rule: None, + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::NeedsApproval { + reason: None, + proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(command)), + } + ); +} + +#[tokio::test] +async fn whitespace_bash_lc_script_falls_back_to_original_command() { + let command = vec![ + "bash".to_string(), + "-lc".to_string(), + " \n\t ".to_string(), + ]; + + let manager = ExecPolicyManager::default(); + let requirement = manager + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::UnlessTrusted, + sandbox_policy: &SandboxPolicy::new_read_only_policy(), + file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::UseDefault, + prefix_rule: None, + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::NeedsApproval { + reason: None, + proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(command)), + } + ); +} + +#[tokio::test] +async fn request_rule_uses_prefix_rule() { + let command = vec![ + "cargo".to_string(), + "install".to_string(), + "cargo-insta".to_string(), + ]; + let manager = ExecPolicyManager::default(); + + let requirement = manager + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::OnRequest, + sandbox_policy: &SandboxPolicy::new_read_only_policy(), + file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::RequireEscalated, + prefix_rule: Some(vec!["cargo".to_string(), "install".to_string()]), + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::NeedsApproval { + reason: None, + proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(vec![ + "cargo".to_string(), + "install".to_string(), + ])), + } + ); +} + +#[tokio::test] +async fn request_rule_falls_back_when_prefix_rule_does_not_approve_all_commands() { + let command = vec![ + "bash".to_string(), + "-lc".to_string(), + "cargo install cargo-insta && rm -rf /tmp/codex".to_string(), + ]; + let manager = ExecPolicyManager::default(); + + let requirement = manager + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::OnRequest, + sandbox_policy: &SandboxPolicy::DangerFullAccess, + file_system_sandbox_policy: &unrestricted_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::RequireEscalated, + prefix_rule: Some(vec!["cargo".to_string(), "install".to_string()]), + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::NeedsApproval { + reason: None, + proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(vec![ + "rm".to_string(), + "-rf".to_string(), + "/tmp/codex".to_string(), + ])), + } + ); +} + +#[tokio::test] +async fn heuristics_apply_when_other_commands_match_policy() { + let policy_src = r#"prefix_rule(pattern=["apple"], decision="allow")"#; + let mut parser = PolicyParser::new(); + parser + .parse("test.rules", policy_src) + .expect("parse policy"); + let policy = Arc::new(parser.build()); + let command = vec![ + "bash".to_string(), + "-lc".to_string(), + "apple | orange".to_string(), + ]; + + assert_eq!( + ExecPolicyManager::new(policy) + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::UnlessTrusted, + sandbox_policy: &SandboxPolicy::DangerFullAccess, + file_system_sandbox_policy: &unrestricted_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::UseDefault, + prefix_rule: None, + }) + .await, + ExecApprovalRequirement::NeedsApproval { + reason: None, + proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(vec![ + "orange".to_string() + ])) + } + ); +} + +#[tokio::test] +async fn append_execpolicy_amendment_updates_policy_and_file() { + let codex_home = tempdir().expect("create temp dir"); + let prefix = vec!["echo".to_string(), "hello".to_string()]; + let manager = ExecPolicyManager::default(); + + manager + .append_amendment_and_update(codex_home.path(), &ExecPolicyAmendment::from(prefix)) + .await + .expect("update policy"); + let updated_policy = manager.current(); + + let evaluation = updated_policy.check( + &["echo".to_string(), "hello".to_string(), "world".to_string()], + &|_| Decision::Allow, + ); + assert!(matches!( + evaluation, + Evaluation { + decision: Decision::Allow, + .. + } + )); + + let contents = fs::read_to_string(default_policy_path(codex_home.path())) + .expect("policy file should have been created"); + assert_eq!( + contents, + r#"prefix_rule(pattern=["echo", "hello"], decision="allow") +"# + ); +} + +#[tokio::test] +async fn append_execpolicy_amendment_rejects_empty_prefix() { + let codex_home = tempdir().expect("create temp dir"); + let manager = ExecPolicyManager::default(); + + let result = manager + .append_amendment_and_update(codex_home.path(), &ExecPolicyAmendment::from(vec![])) + .await; + + assert!(matches!( + result, + Err(ExecPolicyUpdateError::AppendRule { + source: AmendError::EmptyPrefix, + .. + }) + )); +} + +#[tokio::test] +async fn proposed_execpolicy_amendment_is_present_for_single_command_without_policy_match() { + let command = vec!["cargo".to_string(), "build".to_string()]; + + let manager = ExecPolicyManager::default(); + let requirement = manager + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::UnlessTrusted, + sandbox_policy: &SandboxPolicy::new_read_only_policy(), + file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::UseDefault, + prefix_rule: None, + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::NeedsApproval { + reason: None, + proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(command)) + } + ); +} + +#[tokio::test] +async fn proposed_execpolicy_amendment_is_omitted_when_policy_prompts() { + let policy_src = r#"prefix_rule(pattern=["rm"], decision="prompt")"#; + let mut parser = PolicyParser::new(); + parser + .parse("test.rules", policy_src) + .expect("parse policy"); + let policy = Arc::new(parser.build()); + let command = vec!["rm".to_string()]; + + let manager = ExecPolicyManager::new(policy); + let requirement = manager + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::OnRequest, + sandbox_policy: &SandboxPolicy::DangerFullAccess, + file_system_sandbox_policy: &unrestricted_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::UseDefault, + prefix_rule: None, + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::NeedsApproval { + reason: Some("`rm` requires approval by policy".to_string()), + proposed_execpolicy_amendment: None, + } + ); +} + +#[tokio::test] +async fn proposed_execpolicy_amendment_is_present_for_multi_command_scripts() { + let command = vec![ + "bash".to_string(), + "-lc".to_string(), + "cargo build && echo ok".to_string(), + ]; + let manager = ExecPolicyManager::default(); + let requirement = manager + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::UnlessTrusted, + sandbox_policy: &SandboxPolicy::new_read_only_policy(), + file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::UseDefault, + prefix_rule: None, + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::NeedsApproval { + reason: None, + proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(vec![ + "cargo".to_string(), + "build".to_string() + ])), + } + ); +} + +#[tokio::test] +async fn proposed_execpolicy_amendment_uses_first_no_match_in_multi_command_scripts() { + let policy_src = r#"prefix_rule(pattern=["cat"], decision="allow")"#; + let mut parser = PolicyParser::new(); + parser + .parse("test.rules", policy_src) + .expect("parse policy"); + let policy = Arc::new(parser.build()); + + let command = vec![ + "bash".to_string(), + "-lc".to_string(), + "cat && apple".to_string(), + ]; + + assert_eq!( + ExecPolicyManager::new(policy) + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::UnlessTrusted, + sandbox_policy: &SandboxPolicy::new_read_only_policy(), + file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::UseDefault, + prefix_rule: None, + }) + .await, + ExecApprovalRequirement::NeedsApproval { + reason: None, + proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(vec![ + "apple".to_string() + ])), + } + ); +} + +#[tokio::test] +async fn proposed_execpolicy_amendment_is_present_when_heuristics_allow() { + let command = vec!["echo".to_string(), "safe".to_string()]; + + let manager = ExecPolicyManager::default(); + let requirement = manager + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::OnRequest, + sandbox_policy: &SandboxPolicy::new_read_only_policy(), + file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::UseDefault, + prefix_rule: None, + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::Skip { + bypass_sandbox: false, + proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(command)), + } + ); +} + +#[tokio::test] +async fn proposed_execpolicy_amendment_is_suppressed_when_policy_matches_allow() { + let policy_src = r#"prefix_rule(pattern=["echo"], decision="allow")"#; + let mut parser = PolicyParser::new(); + parser + .parse("test.rules", policy_src) + .expect("parse policy"); + let policy = Arc::new(parser.build()); + let command = vec!["echo".to_string(), "safe".to_string()]; + + let manager = ExecPolicyManager::new(policy); + let requirement = manager + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::OnRequest, + sandbox_policy: &SandboxPolicy::new_read_only_policy(), + file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::UseDefault, + prefix_rule: None, + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::Skip { + bypass_sandbox: true, + proposed_execpolicy_amendment: None, + } + ); +} + +fn derive_requested_execpolicy_amendment_for_test( + prefix_rule: Option<&Vec>, + matched_rules: &[RuleMatch], +) -> Option { + let commands = prefix_rule + .cloned() + .map(|prefix_rule| vec![prefix_rule]) + .unwrap_or_else(|| vec![vec!["echo".to_string()]]); + derive_requested_execpolicy_amendment_from_prefix_rule( + prefix_rule, + matched_rules, + &Policy::empty(), + &commands, + &|_: &[String]| Decision::Allow, + &MatchOptions::default(), + ) +} + +#[test] +fn derive_requested_execpolicy_amendment_returns_none_for_missing_prefix_rule() { + assert_eq!( + None, + derive_requested_execpolicy_amendment_for_test(None, &[]) + ); +} + +#[test] +fn derive_requested_execpolicy_amendment_returns_none_for_empty_prefix_rule() { + assert_eq!( + None, + derive_requested_execpolicy_amendment_for_test(Some(&Vec::new()), &[]) + ); +} + +#[test] +fn derive_requested_execpolicy_amendment_returns_none_for_exact_banned_prefix_rule() { + assert_eq!( + None, + derive_requested_execpolicy_amendment_for_test( + Some(&vec!["python".to_string(), "-c".to_string()]), + &[], + ) + ); +} + +#[test] +fn derive_requested_execpolicy_amendment_returns_none_for_windows_and_pypy_variants() { + for prefix_rule in [ + vec!["py".to_string()], + vec!["py".to_string(), "-3".to_string()], + vec!["pythonw".to_string()], + vec!["pyw".to_string()], + vec!["pypy".to_string()], + vec!["pypy3".to_string()], + ] { + assert_eq!( + None, + derive_requested_execpolicy_amendment_for_test(Some(&prefix_rule), &[]) + ); + } +} + +#[test] +fn derive_requested_execpolicy_amendment_returns_none_for_shell_and_powershell_variants() { + for prefix_rule in [ + vec!["bash".to_string(), "-lc".to_string()], + vec!["sh".to_string(), "-c".to_string()], + vec!["sh".to_string(), "-lc".to_string()], + vec!["zsh".to_string(), "-lc".to_string()], + vec!["/bin/bash".to_string(), "-lc".to_string()], + vec!["/bin/zsh".to_string(), "-lc".to_string()], + vec!["pwsh".to_string()], + vec!["pwsh".to_string(), "-Command".to_string()], + vec!["pwsh".to_string(), "-c".to_string()], + vec!["powershell".to_string()], + vec!["powershell".to_string(), "-Command".to_string()], + vec!["powershell".to_string(), "-c".to_string()], + vec!["powershell.exe".to_string()], + vec!["powershell.exe".to_string(), "-Command".to_string()], + vec!["powershell.exe".to_string(), "-c".to_string()], + ] { + assert_eq!( + None, + derive_requested_execpolicy_amendment_for_test(Some(&prefix_rule), &[]) + ); + } +} + +#[test] +fn derive_requested_execpolicy_amendment_allows_non_exact_banned_prefix_rule_match() { + let prefix_rule = vec![ + "python".to_string(), + "-c".to_string(), + "print('hi')".to_string(), + ]; + + assert_eq!( + Some(ExecPolicyAmendment::new(prefix_rule.clone())), + derive_requested_execpolicy_amendment_for_test(Some(&prefix_rule), &[]) + ); +} + +#[test] +fn derive_requested_execpolicy_amendment_returns_none_when_policy_matches() { + let prefix_rule = vec!["cargo".to_string(), "build".to_string()]; + + let matched_rules_prompt = vec![RuleMatch::PrefixRuleMatch { + matched_prefix: vec!["cargo".to_string()], + decision: Decision::Prompt, + resolved_program: None, + justification: None, + }]; + assert_eq!( + None, + derive_requested_execpolicy_amendment_for_test(Some(&prefix_rule), &matched_rules_prompt), + "should return none when prompt policy matches" + ); + let matched_rules_allow = vec![RuleMatch::PrefixRuleMatch { + matched_prefix: vec!["cargo".to_string()], + decision: Decision::Allow, + resolved_program: None, + justification: None, + }]; + assert_eq!( + None, + derive_requested_execpolicy_amendment_for_test(Some(&prefix_rule), &matched_rules_allow), + "should return none when prompt policy matches" + ); + let matched_rules_forbidden = vec![RuleMatch::PrefixRuleMatch { + matched_prefix: vec!["cargo".to_string()], + decision: Decision::Forbidden, + resolved_program: None, + justification: None, + }]; + assert_eq!( + None, + derive_requested_execpolicy_amendment_for_test( + Some(&prefix_rule), + &matched_rules_forbidden, + ), + "should return none when prompt policy matches" + ); +} + +#[tokio::test] +async fn dangerous_rm_rf_requires_approval_in_danger_full_access() { + let command = vec_str(&["rm", "-rf", "/tmp/nonexistent"]); + let manager = ExecPolicyManager::default(); + let requirement = manager + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &command, + approval_policy: AskForApproval::OnRequest, + sandbox_policy: &SandboxPolicy::DangerFullAccess, + file_system_sandbox_policy: &unrestricted_file_system_sandbox_policy(), + sandbox_permissions: SandboxPermissions::UseDefault, + prefix_rule: None, + }) + .await; + + assert_eq!( + requirement, + ExecApprovalRequirement::NeedsApproval { + reason: None, + proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(command)), + } + ); +} + +fn vec_str(items: &[&str]) -> Vec { + items.iter().map(std::string::ToString::to_string).collect() +} + +/// Note this test behaves differently on Windows because it exercises an +/// `if cfg!(windows)` code path in render_decision_for_unmatched_command(). +#[tokio::test] +async fn verify_approval_requirement_for_unsafe_powershell_command() { + // `brew install powershell` to run this test on a Mac! + // Note `pwsh` is required to parse a PowerShell command to see if it + // is safe. + if which::which("pwsh").is_err() { + return; + } + + let policy = ExecPolicyManager::new(Arc::new(Policy::empty())); + let permissions = SandboxPermissions::UseDefault; + + // This command should not be run without user approval unless there is + // a proper sandbox in place to ensure safety. + let sneaky_command = vec_str(&["pwsh", "-Command", "echo hi @(calc)"]); + let expected_amendment = Some(ExecPolicyAmendment::new(vec_str(&[ + "pwsh", + "-Command", + "echo hi @(calc)", + ]))); + let (pwsh_approval_reason, expected_req) = if cfg!(windows) { + ( + r#"On Windows, SandboxPolicy::ReadOnly should be assumed to mean + that no sandbox is present, so anything that is not "provably + safe" should require approval."#, + ExecApprovalRequirement::NeedsApproval { + reason: None, + proposed_execpolicy_amendment: expected_amendment.clone(), + }, + ) + } else { + ( + "On non-Windows, rely on the read-only sandbox to prevent harm.", + ExecApprovalRequirement::Skip { + bypass_sandbox: false, + proposed_execpolicy_amendment: expected_amendment.clone(), + }, + ) + }; + assert_eq!( + expected_req, + policy + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &sneaky_command, + approval_policy: AskForApproval::OnRequest, + sandbox_policy: &SandboxPolicy::new_read_only_policy(), + file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), + sandbox_permissions: permissions, + prefix_rule: None, + }) + .await, + "{pwsh_approval_reason}" + ); + + // This is flagged as a dangerous command on all platforms. + let dangerous_command = vec_str(&["rm", "-rf", "/important/data"]); + assert_eq!( + ExecApprovalRequirement::NeedsApproval { + reason: None, + proposed_execpolicy_amendment: Some(ExecPolicyAmendment::new(vec_str(&[ + "rm", + "-rf", + "/important/data", + ]))), + }, + policy + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &dangerous_command, + approval_policy: AskForApproval::OnRequest, + sandbox_policy: &SandboxPolicy::new_read_only_policy(), + file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), + sandbox_permissions: permissions, + prefix_rule: None, + }) + .await, + r#"On all platforms, a forbidden command should require approval + (unless AskForApproval::Never is specified)."# + ); + + // A dangerous command should be forbidden if the user has specified + // AskForApproval::Never. + assert_eq!( + ExecApprovalRequirement::Forbidden { + reason: "`rm -rf /important/data` rejected: blocked by policy".to_string(), + }, + policy + .create_exec_approval_requirement_for_command(ExecApprovalRequest { + command: &dangerous_command, + approval_policy: AskForApproval::Never, + sandbox_policy: &SandboxPolicy::new_read_only_policy(), + file_system_sandbox_policy: &read_only_file_system_sandbox_policy(), + sandbox_permissions: permissions, + prefix_rule: None, + }) + .await, + r#"On all platforms, a forbidden command should require approval + (unless AskForApproval::Never is specified)."# + ); +} diff --git a/codex-rs/core/src/exec_tests.rs b/codex-rs/core/src/exec_tests.rs new file mode 100644 index 0000000000..550b41af7f --- /dev/null +++ b/codex-rs/core/src/exec_tests.rs @@ -0,0 +1,423 @@ +use super::*; +use pretty_assertions::assert_eq; +use std::time::Duration; +use tokio::io::AsyncWriteExt; + +fn make_exec_output( + exit_code: i32, + stdout: &str, + stderr: &str, + aggregated: &str, +) -> ExecToolCallOutput { + ExecToolCallOutput { + exit_code, + stdout: StreamOutput::new(stdout.to_string()), + stderr: StreamOutput::new(stderr.to_string()), + aggregated_output: StreamOutput::new(aggregated.to_string()), + duration: Duration::from_millis(1), + timed_out: false, + } +} + +#[test] +fn sandbox_detection_requires_keywords() { + let output = make_exec_output(1, "", "", ""); + assert!(!is_likely_sandbox_denied( + SandboxType::LinuxSeccomp, + &output + )); +} + +#[test] +fn sandbox_detection_identifies_keyword_in_stderr() { + let output = make_exec_output(1, "", "Operation not permitted", ""); + assert!(is_likely_sandbox_denied(SandboxType::LinuxSeccomp, &output)); +} + +#[test] +fn sandbox_detection_respects_quick_reject_exit_codes() { + let output = make_exec_output(127, "", "command not found", ""); + assert!(!is_likely_sandbox_denied( + SandboxType::LinuxSeccomp, + &output + )); +} + +#[test] +fn sandbox_detection_ignores_non_sandbox_mode() { + let output = make_exec_output(1, "", "Operation not permitted", ""); + assert!(!is_likely_sandbox_denied(SandboxType::None, &output)); +} + +#[test] +fn sandbox_detection_ignores_network_policy_text_in_non_sandbox_mode() { + let output = make_exec_output( + 0, + "", + "", + r#"CODEX_NETWORK_POLICY_DECISION {"decision":"ask","reason":"not_allowed","source":"decider","protocol":"http","host":"google.com","port":80}"#, + ); + assert!(!is_likely_sandbox_denied(SandboxType::None, &output)); +} + +#[test] +fn sandbox_detection_uses_aggregated_output() { + let output = make_exec_output( + 101, + "", + "", + "cargo failed: Read-only file system when writing target", + ); + assert!(is_likely_sandbox_denied( + SandboxType::MacosSeatbelt, + &output + )); +} + +#[test] +fn sandbox_detection_ignores_network_policy_text_with_zero_exit_code() { + let output = make_exec_output( + 0, + "", + "", + r#"CODEX_NETWORK_POLICY_DECISION {"decision":"ask","source":"decider","protocol":"http","host":"google.com","port":80}"#, + ); + + assert!(!is_likely_sandbox_denied( + SandboxType::LinuxSeccomp, + &output + )); +} + +#[tokio::test] +async fn read_capped_limits_retained_bytes() { + let (mut writer, reader) = tokio::io::duplex(1024); + let bytes = vec![b'a'; EXEC_OUTPUT_MAX_BYTES.saturating_add(128 * 1024)]; + tokio::spawn(async move { + writer.write_all(&bytes).await.expect("write"); + }); + + let out = read_capped(reader, None, false).await.expect("read"); + assert_eq!(out.text.len(), EXEC_OUTPUT_MAX_BYTES); +} + +#[test] +fn aggregate_output_prefers_stderr_on_contention() { + let stdout = StreamOutput { + text: vec![b'a'; EXEC_OUTPUT_MAX_BYTES], + truncated_after_lines: None, + }; + let stderr = StreamOutput { + text: vec![b'b'; EXEC_OUTPUT_MAX_BYTES], + truncated_after_lines: None, + }; + + let aggregated = aggregate_output(&stdout, &stderr); + let stdout_cap = EXEC_OUTPUT_MAX_BYTES / 3; + let stderr_cap = EXEC_OUTPUT_MAX_BYTES.saturating_sub(stdout_cap); + + assert_eq!(aggregated.text.len(), EXEC_OUTPUT_MAX_BYTES); + assert_eq!(aggregated.text[..stdout_cap], vec![b'a'; stdout_cap]); + assert_eq!(aggregated.text[stdout_cap..], vec![b'b'; stderr_cap]); +} + +#[test] +fn aggregate_output_fills_remaining_capacity_with_stderr() { + let stdout_len = EXEC_OUTPUT_MAX_BYTES / 10; + let stdout = StreamOutput { + text: vec![b'a'; stdout_len], + truncated_after_lines: None, + }; + let stderr = StreamOutput { + text: vec![b'b'; EXEC_OUTPUT_MAX_BYTES], + truncated_after_lines: None, + }; + + let aggregated = aggregate_output(&stdout, &stderr); + let stderr_cap = EXEC_OUTPUT_MAX_BYTES.saturating_sub(stdout_len); + + assert_eq!(aggregated.text.len(), EXEC_OUTPUT_MAX_BYTES); + assert_eq!(aggregated.text[..stdout_len], vec![b'a'; stdout_len]); + assert_eq!(aggregated.text[stdout_len..], vec![b'b'; stderr_cap]); +} + +#[test] +fn aggregate_output_rebalances_when_stderr_is_small() { + let stdout = StreamOutput { + text: vec![b'a'; EXEC_OUTPUT_MAX_BYTES], + truncated_after_lines: None, + }; + let stderr = StreamOutput { + text: vec![b'b'; 1], + truncated_after_lines: None, + }; + + let aggregated = aggregate_output(&stdout, &stderr); + let stdout_len = EXEC_OUTPUT_MAX_BYTES.saturating_sub(1); + + assert_eq!(aggregated.text.len(), EXEC_OUTPUT_MAX_BYTES); + assert_eq!(aggregated.text[..stdout_len], vec![b'a'; stdout_len]); + assert_eq!(aggregated.text[stdout_len..], vec![b'b'; 1]); +} + +#[test] +fn aggregate_output_keeps_stdout_then_stderr_when_under_cap() { + let stdout = StreamOutput { + text: vec![b'a'; 4], + truncated_after_lines: None, + }; + let stderr = StreamOutput { + text: vec![b'b'; 3], + truncated_after_lines: None, + }; + + let aggregated = aggregate_output(&stdout, &stderr); + let mut expected = Vec::new(); + expected.extend_from_slice(&stdout.text); + expected.extend_from_slice(&stderr.text); + + assert_eq!(aggregated.text, expected); + assert_eq!(aggregated.truncated_after_lines, None); +} + +#[test] +fn windows_restricted_token_skips_external_sandbox_policies() { + let policy = SandboxPolicy::ExternalSandbox { + network_access: codex_protocol::protocol::NetworkAccess::Restricted, + }; + let file_system_policy = FileSystemSandboxPolicy::restricted(vec![]); + + assert_eq!( + should_use_windows_restricted_token_sandbox( + SandboxType::WindowsRestrictedToken, + &policy, + &file_system_policy, + ), + false + ); +} + +#[test] +fn windows_restricted_token_runs_for_legacy_restricted_policies() { + let policy = SandboxPolicy::new_read_only_policy(); + let file_system_policy = FileSystemSandboxPolicy::restricted(vec![]); + + assert_eq!( + should_use_windows_restricted_token_sandbox( + SandboxType::WindowsRestrictedToken, + &policy, + &file_system_policy, + ), + true + ); +} + +#[test] +fn windows_restricted_token_rejects_network_only_restrictions() { + let policy = SandboxPolicy::ExternalSandbox { + network_access: codex_protocol::protocol::NetworkAccess::Restricted, + }; + let file_system_policy = FileSystemSandboxPolicy::unrestricted(); + + assert_eq!( + unsupported_windows_restricted_token_sandbox_reason( + SandboxType::WindowsRestrictedToken, + &policy, + &file_system_policy, + NetworkSandboxPolicy::Restricted, + ), + Some( + "windows sandbox backend cannot enforce file_system=Unrestricted, network=Restricted, legacy_policy=ExternalSandbox { network_access: Restricted }; refusing to run unsandboxed".to_string() + ) + ); +} + +#[test] +fn windows_restricted_token_allows_legacy_restricted_policies() { + let policy = SandboxPolicy::new_read_only_policy(); + let file_system_policy = FileSystemSandboxPolicy::restricted(vec![]); + + assert_eq!( + unsupported_windows_restricted_token_sandbox_reason( + SandboxType::WindowsRestrictedToken, + &policy, + &file_system_policy, + NetworkSandboxPolicy::Restricted, + ), + None + ); +} + +#[test] +fn windows_restricted_token_allows_legacy_workspace_write_policies() { + let policy = SandboxPolicy::WorkspaceWrite { + writable_roots: vec![], + read_only_access: codex_protocol::protocol::ReadOnlyAccess::FullAccess, + network_access: false, + exclude_tmpdir_env_var: false, + exclude_slash_tmp: false, + }; + let file_system_policy = FileSystemSandboxPolicy::from(&policy); + + assert_eq!( + unsupported_windows_restricted_token_sandbox_reason( + SandboxType::WindowsRestrictedToken, + &policy, + &file_system_policy, + NetworkSandboxPolicy::Restricted, + ), + None + ); +} + +#[test] +fn process_exec_tool_call_uses_platform_sandbox_for_network_only_restrictions() { + let expected = crate::get_platform_sandbox(false).unwrap_or(SandboxType::None); + + assert_eq!( + select_process_exec_tool_sandbox_type( + &FileSystemSandboxPolicy::unrestricted(), + NetworkSandboxPolicy::Restricted, + codex_protocol::config_types::WindowsSandboxLevel::Disabled, + false, + ), + expected + ); +} + +#[cfg(unix)] +#[test] +fn sandbox_detection_flags_sigsys_exit_code() { + let exit_code = EXIT_CODE_SIGNAL_BASE + libc::SIGSYS; + let output = make_exec_output(exit_code, "", "", ""); + assert!(is_likely_sandbox_denied(SandboxType::LinuxSeccomp, &output)); +} + +#[cfg(unix)] +#[tokio::test] +async fn kill_child_process_group_kills_grandchildren_on_timeout() -> Result<()> { + // On Linux/macOS, /bin/bash is typically present; on FreeBSD/OpenBSD, + // prefer /bin/sh to avoid NotFound errors. + #[cfg(any(target_os = "freebsd", target_os = "openbsd"))] + let command = vec![ + "/bin/sh".to_string(), + "-c".to_string(), + "sleep 60 & echo $!; sleep 60".to_string(), + ]; + #[cfg(all(unix, not(any(target_os = "freebsd", target_os = "openbsd"))))] + let command = vec![ + "/bin/bash".to_string(), + "-c".to_string(), + "sleep 60 & echo $!; sleep 60".to_string(), + ]; + let env: HashMap = std::env::vars().collect(); + let params = ExecParams { + command, + cwd: std::env::current_dir()?, + expiration: 500.into(), + env, + network: None, + sandbox_permissions: SandboxPermissions::UseDefault, + windows_sandbox_level: codex_protocol::config_types::WindowsSandboxLevel::Disabled, + justification: None, + arg0: None, + }; + + let output = exec( + params, + SandboxType::None, + &SandboxPolicy::new_read_only_policy(), + &FileSystemSandboxPolicy::from(&SandboxPolicy::new_read_only_policy()), + NetworkSandboxPolicy::Restricted, + None, + None, + ) + .await?; + assert!(output.timed_out); + + let stdout = output.stdout.from_utf8_lossy().text; + let pid_line = stdout.lines().next().unwrap_or("").trim(); + let pid: i32 = pid_line.parse().map_err(|error| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("Failed to parse pid from stdout '{pid_line}': {error}"), + ) + })?; + + let mut killed = false; + for _ in 0..20 { + // Use kill(pid, 0) to check if the process is alive. + if unsafe { libc::kill(pid, 0) } == -1 + && let Some(libc::ESRCH) = std::io::Error::last_os_error().raw_os_error() + { + killed = true; + break; + } + tokio::time::sleep(Duration::from_millis(100)).await; + } + + assert!(killed, "grandchild process with pid {pid} is still alive"); + Ok(()) +} + +#[tokio::test] +async fn process_exec_tool_call_respects_cancellation_token() -> Result<()> { + let command = long_running_command(); + let cwd = std::env::current_dir()?; + let env: HashMap = std::env::vars().collect(); + let cancel_token = CancellationToken::new(); + let cancel_tx = cancel_token.clone(); + let params = ExecParams { + command, + cwd: cwd.clone(), + expiration: ExecExpiration::Cancellation(cancel_token), + env, + network: None, + sandbox_permissions: SandboxPermissions::UseDefault, + windows_sandbox_level: codex_protocol::config_types::WindowsSandboxLevel::Disabled, + justification: None, + arg0: None, + }; + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(1_000)).await; + cancel_tx.cancel(); + }); + let result = process_exec_tool_call( + params, + &SandboxPolicy::DangerFullAccess, + &FileSystemSandboxPolicy::from(&SandboxPolicy::DangerFullAccess), + NetworkSandboxPolicy::Enabled, + cwd.as_path(), + &None, + false, + None, + ) + .await; + let output = match result { + Err(CodexErr::Sandbox(SandboxErr::Timeout { output })) => output, + other => panic!("expected timeout error, got {other:?}"), + }; + assert!(output.timed_out); + assert_eq!(output.exit_code, EXEC_TIMEOUT_EXIT_CODE); + Ok(()) +} + +#[cfg(unix)] +fn long_running_command() -> Vec { + vec![ + "/bin/sh".to_string(), + "-c".to_string(), + "sleep 30".to_string(), + ] +} + +#[cfg(windows)] +fn long_running_command() -> Vec { + vec![ + "powershell.exe".to_string(), + "-NonInteractive".to_string(), + "-NoLogo".to_string(), + "-Command".to_string(), + "Start-Sleep -Seconds 30".to_string(), + ] +} diff --git a/codex-rs/core/src/external_agent_config.rs b/codex-rs/core/src/external_agent_config.rs index 24319fe064..4df4d60ebe 100644 --- a/codex-rs/core/src/external_agent_config.rs +++ b/codex-rs/core/src/external_agent_config.rs @@ -688,402 +688,5 @@ fn emit_migration_metric( } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - use tempfile::TempDir; - - fn fixture_paths() -> (TempDir, PathBuf, PathBuf) { - let root = TempDir::new().expect("create tempdir"); - let claude_home = root.path().join(".claude"); - let codex_home = root.path().join(".codex"); - (root, claude_home, codex_home) - } - - fn service_for_paths(claude_home: PathBuf, codex_home: PathBuf) -> ExternalAgentConfigService { - ExternalAgentConfigService::new_for_test(codex_home, claude_home) - } - - #[test] - fn detect_home_lists_config_skills_and_agents_md() { - let (_root, claude_home, codex_home) = fixture_paths(); - let agents_skills = codex_home - .parent() - .map(|parent| parent.join(".agents").join("skills")) - .unwrap_or_else(|| PathBuf::from(".agents").join("skills")); - fs::create_dir_all(claude_home.join("skills").join("skill-a")).expect("create skills"); - fs::write(claude_home.join("CLAUDE.md"), "claude rules").expect("write claude md"); - fs::write( - claude_home.join("settings.json"), - r#"{"model":"claude","env":{"FOO":"bar"}}"#, - ) - .expect("write settings"); - - let items = service_for_paths(claude_home.clone(), codex_home.clone()) - .detect(ExternalAgentConfigDetectOptions { - include_home: true, - cwds: None, - }) - .expect("detect"); - - let expected = vec![ - ExternalAgentConfigMigrationItem { - item_type: ExternalAgentConfigMigrationItemType::Config, - description: format!( - "Migrate {} into {}", - claude_home.join("settings.json").display(), - codex_home.join("config.toml").display() - ), - cwd: None, - }, - ExternalAgentConfigMigrationItem { - item_type: ExternalAgentConfigMigrationItemType::Skills, - description: format!( - "Copy skill folders from {} to {}", - claude_home.join("skills").display(), - agents_skills.display() - ), - cwd: None, - }, - ExternalAgentConfigMigrationItem { - item_type: ExternalAgentConfigMigrationItemType::AgentsMd, - description: format!( - "Import {} to {}", - claude_home.join("CLAUDE.md").display(), - codex_home.join("AGENTS.md").display() - ), - cwd: None, - }, - ]; - - assert_eq!(items, expected); - } - - #[test] - fn detect_repo_lists_agents_md_for_each_cwd() { - let root = TempDir::new().expect("create tempdir"); - let repo_root = root.path().join("repo"); - let nested = repo_root.join("nested").join("child"); - fs::create_dir_all(repo_root.join(".git")).expect("create git dir"); - fs::create_dir_all(&nested).expect("create nested"); - fs::write(repo_root.join("CLAUDE.md"), "Claude code guidance").expect("write source"); - - let items = service_for_paths(root.path().join(".claude"), root.path().join(".codex")) - .detect(ExternalAgentConfigDetectOptions { - include_home: false, - cwds: Some(vec![nested, repo_root.clone()]), - }) - .expect("detect"); - - let expected = vec![ - ExternalAgentConfigMigrationItem { - item_type: ExternalAgentConfigMigrationItemType::AgentsMd, - description: format!( - "Import {} to {}", - repo_root.join("CLAUDE.md").display(), - repo_root.join("AGENTS.md").display(), - ), - cwd: Some(repo_root.clone()), - }, - ExternalAgentConfigMigrationItem { - item_type: ExternalAgentConfigMigrationItemType::AgentsMd, - description: format!( - "Import {} to {}", - repo_root.join("CLAUDE.md").display(), - repo_root.join("AGENTS.md").display(), - ), - cwd: Some(repo_root), - }, - ]; - - assert_eq!(items, expected); - } - - #[test] - fn import_home_migrates_supported_config_fields_skills_and_agents_md() { - let (_root, claude_home, codex_home) = fixture_paths(); - let agents_skills = codex_home - .parent() - .map(|parent| parent.join(".agents").join("skills")) - .unwrap_or_else(|| PathBuf::from(".agents").join("skills")); - fs::create_dir_all(claude_home.join("skills").join("skill-a")).expect("create skills"); - fs::write( - claude_home.join("settings.json"), - r#"{"model":"claude","permissions":{"ask":["git push"]},"env":{"FOO":"bar","CI":false,"MAX_RETRIES":3,"MY_TEAM":"codex","IGNORED":null,"LIST":["a","b"],"MAP":{"x":1}},"sandbox":{"enabled":true,"network":{"allowLocalBinding":true}}}"#, - ) - .expect("write settings"); - fs::write( - claude_home.join("skills").join("skill-a").join("SKILL.md"), - "Use Claude Code and CLAUDE utilities.", - ) - .expect("write skill"); - fs::write(claude_home.join("CLAUDE.md"), "Claude code guidance").expect("write agents"); - - service_for_paths(claude_home, codex_home.clone()) - .import(vec![ - ExternalAgentConfigMigrationItem { - item_type: ExternalAgentConfigMigrationItemType::AgentsMd, - description: String::new(), - cwd: None, - }, - ExternalAgentConfigMigrationItem { - item_type: ExternalAgentConfigMigrationItemType::Config, - description: String::new(), - cwd: None, - }, - ExternalAgentConfigMigrationItem { - item_type: ExternalAgentConfigMigrationItemType::Skills, - description: String::new(), - cwd: None, - }, - ]) - .expect("import"); - - assert_eq!( - fs::read_to_string(codex_home.join("AGENTS.md")).expect("read agents"), - "Codex guidance" - ); - - assert_eq!( - fs::read_to_string(codex_home.join("config.toml")).expect("read config"), - "sandbox_mode = \"workspace-write\"\n\n[shell_environment_policy]\ninherit = \"core\"\n\n[shell_environment_policy.set]\nCI = \"false\"\nFOO = \"bar\"\nMAX_RETRIES = \"3\"\nMY_TEAM = \"codex\"\n" - ); - assert_eq!( - fs::read_to_string(agents_skills.join("skill-a").join("SKILL.md")) - .expect("read copied skill"), - "Use Codex and Codex utilities." - ); - } - - #[test] - fn import_home_skips_empty_config_migration() { - let (_root, claude_home, codex_home) = fixture_paths(); - fs::create_dir_all(&claude_home).expect("create claude home"); - fs::write( - claude_home.join("settings.json"), - r#"{"model":"claude","sandbox":{"enabled":false}}"#, - ) - .expect("write settings"); - - service_for_paths(claude_home, codex_home.clone()) - .import(vec![ExternalAgentConfigMigrationItem { - item_type: ExternalAgentConfigMigrationItemType::Config, - description: String::new(), - cwd: None, - }]) - .expect("import"); - - assert!(!codex_home.join("config.toml").exists()); - } - - #[test] - fn detect_home_skips_config_when_target_already_has_supported_fields() { - let (_root, claude_home, codex_home) = fixture_paths(); - fs::create_dir_all(&claude_home).expect("create claude home"); - fs::create_dir_all(&codex_home).expect("create codex home"); - fs::write( - claude_home.join("settings.json"), - r#"{"env":{"FOO":"bar"},"sandbox":{"enabled":true}}"#, - ) - .expect("write settings"); - fs::write( - codex_home.join("config.toml"), - r#" - sandbox_mode = "workspace-write" - - [shell_environment_policy] - inherit = "core" - - [shell_environment_policy.set] - FOO = "bar" - "#, - ) - .expect("write config"); - - let items = service_for_paths(claude_home, codex_home) - .detect(ExternalAgentConfigDetectOptions { - include_home: true, - cwds: None, - }) - .expect("detect"); - - assert_eq!(items, Vec::::new()); - } - - #[test] - fn detect_home_skips_skills_when_all_skill_directories_exist() { - let (_root, claude_home, codex_home) = fixture_paths(); - let agents_skills = codex_home - .parent() - .map(|parent| parent.join(".agents").join("skills")) - .unwrap_or_else(|| PathBuf::from(".agents").join("skills")); - fs::create_dir_all(claude_home.join("skills").join("skill-a")).expect("create source"); - fs::create_dir_all(agents_skills.join("skill-a")).expect("create target"); - - let items = service_for_paths(claude_home, codex_home) - .detect(ExternalAgentConfigDetectOptions { - include_home: true, - cwds: None, - }) - .expect("detect"); - - assert_eq!(items, Vec::::new()); - } - - #[test] - fn import_repo_agents_md_rewrites_terms_and_skips_non_empty_targets() { - let root = TempDir::new().expect("create tempdir"); - let repo_root = root.path().join("repo-a"); - let repo_with_existing_target = root.path().join("repo-b"); - fs::create_dir_all(repo_root.join(".git")).expect("create git"); - fs::create_dir_all(repo_with_existing_target.join(".git")).expect("create git"); - fs::write( - repo_root.join("CLAUDE.md"), - "Claude code\nclaude\nCLAUDE-CODE\nSee CLAUDE.md\n", - ) - .expect("write source"); - fs::write(repo_with_existing_target.join("CLAUDE.md"), "new source").expect("write source"); - fs::write( - repo_with_existing_target.join("AGENTS.md"), - "keep existing target", - ) - .expect("write target"); - - service_for_paths(root.path().join(".claude"), root.path().join(".codex")) - .import(vec![ - ExternalAgentConfigMigrationItem { - item_type: ExternalAgentConfigMigrationItemType::AgentsMd, - description: String::new(), - cwd: Some(repo_root.clone()), - }, - ExternalAgentConfigMigrationItem { - item_type: ExternalAgentConfigMigrationItemType::AgentsMd, - description: String::new(), - cwd: Some(repo_with_existing_target.clone()), - }, - ]) - .expect("import"); - - assert_eq!( - fs::read_to_string(repo_root.join("AGENTS.md")).expect("read target"), - "Codex\nCodex\nCodex\nSee AGENTS.md\n" - ); - assert_eq!( - fs::read_to_string(repo_with_existing_target.join("AGENTS.md")) - .expect("read existing target"), - "keep existing target" - ); - } - - #[test] - fn import_repo_agents_md_overwrites_empty_targets() { - let root = TempDir::new().expect("create tempdir"); - let repo_root = root.path().join("repo"); - fs::create_dir_all(repo_root.join(".git")).expect("create git"); - fs::write(repo_root.join("CLAUDE.md"), "Claude code guidance").expect("write source"); - fs::write(repo_root.join("AGENTS.md"), " \n\t").expect("write empty target"); - - service_for_paths(root.path().join(".claude"), root.path().join(".codex")) - .import(vec![ExternalAgentConfigMigrationItem { - item_type: ExternalAgentConfigMigrationItemType::AgentsMd, - description: String::new(), - cwd: Some(repo_root.clone()), - }]) - .expect("import"); - - assert_eq!( - fs::read_to_string(repo_root.join("AGENTS.md")).expect("read target"), - "Codex guidance" - ); - } - - #[test] - fn detect_repo_prefers_non_empty_dot_claude_agents_source() { - let root = TempDir::new().expect("create tempdir"); - let repo_root = root.path().join("repo"); - fs::create_dir_all(repo_root.join(".git")).expect("create git"); - fs::create_dir_all(repo_root.join(".claude")).expect("create dot claude"); - fs::write(repo_root.join("CLAUDE.md"), " \n\t").expect("write empty root source"); - fs::write( - repo_root.join(".claude").join("CLAUDE.md"), - "Claude code guidance", - ) - .expect("write dot claude source"); - - let items = service_for_paths(root.path().join(".claude"), root.path().join(".codex")) - .detect(ExternalAgentConfigDetectOptions { - include_home: false, - cwds: Some(vec![repo_root.clone()]), - }) - .expect("detect"); - - assert_eq!( - items, - vec![ExternalAgentConfigMigrationItem { - item_type: ExternalAgentConfigMigrationItemType::AgentsMd, - description: format!( - "Import {} to {}", - repo_root.join(".claude").join("CLAUDE.md").display(), - repo_root.join("AGENTS.md").display(), - ), - cwd: Some(repo_root), - }] - ); - } - - #[test] - fn import_repo_uses_non_empty_dot_claude_agents_source() { - let root = TempDir::new().expect("create tempdir"); - let repo_root = root.path().join("repo"); - fs::create_dir_all(repo_root.join(".git")).expect("create git"); - fs::create_dir_all(repo_root.join(".claude")).expect("create dot claude"); - fs::write(repo_root.join("CLAUDE.md"), "").expect("write empty root source"); - fs::write( - repo_root.join(".claude").join("CLAUDE.md"), - "Claude code guidance", - ) - .expect("write dot claude source"); - - service_for_paths(root.path().join(".claude"), root.path().join(".codex")) - .import(vec![ExternalAgentConfigMigrationItem { - item_type: ExternalAgentConfigMigrationItemType::AgentsMd, - description: String::new(), - cwd: Some(repo_root.clone()), - }]) - .expect("import"); - - assert_eq!( - fs::read_to_string(repo_root.join("AGENTS.md")).expect("read target"), - "Codex guidance" - ); - } - - #[test] - fn migration_metric_tags_for_skills_include_skills_count() { - assert_eq!( - migration_metric_tags(ExternalAgentConfigMigrationItemType::Skills, Some(3)), - vec![ - ("migration_type", "skills".to_string()), - ("skills_count", "3".to_string()), - ] - ); - } - - #[test] - fn import_skills_returns_only_new_skill_directory_count() { - let (_root, claude_home, codex_home) = fixture_paths(); - let agents_skills = codex_home - .parent() - .map(|parent| parent.join(".agents").join("skills")) - .unwrap_or_else(|| PathBuf::from(".agents").join("skills")); - fs::create_dir_all(claude_home.join("skills").join("skill-a")).expect("create source a"); - fs::create_dir_all(claude_home.join("skills").join("skill-b")).expect("create source b"); - fs::create_dir_all(agents_skills.join("skill-a")).expect("create existing target"); - - let copied_count = service_for_paths(claude_home, codex_home) - .import_skills(None) - .expect("import skills"); - - assert_eq!(copied_count, 1); - } -} +#[path = "external_agent_config_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/external_agent_config_tests.rs b/codex-rs/core/src/external_agent_config_tests.rs new file mode 100644 index 0000000000..a760f73e19 --- /dev/null +++ b/codex-rs/core/src/external_agent_config_tests.rs @@ -0,0 +1,397 @@ +use super::*; +use pretty_assertions::assert_eq; +use tempfile::TempDir; + +fn fixture_paths() -> (TempDir, PathBuf, PathBuf) { + let root = TempDir::new().expect("create tempdir"); + let claude_home = root.path().join(".claude"); + let codex_home = root.path().join(".codex"); + (root, claude_home, codex_home) +} + +fn service_for_paths(claude_home: PathBuf, codex_home: PathBuf) -> ExternalAgentConfigService { + ExternalAgentConfigService::new_for_test(codex_home, claude_home) +} + +#[test] +fn detect_home_lists_config_skills_and_agents_md() { + let (_root, claude_home, codex_home) = fixture_paths(); + let agents_skills = codex_home + .parent() + .map(|parent| parent.join(".agents").join("skills")) + .unwrap_or_else(|| PathBuf::from(".agents").join("skills")); + fs::create_dir_all(claude_home.join("skills").join("skill-a")).expect("create skills"); + fs::write(claude_home.join("CLAUDE.md"), "claude rules").expect("write claude md"); + fs::write( + claude_home.join("settings.json"), + r#"{"model":"claude","env":{"FOO":"bar"}}"#, + ) + .expect("write settings"); + + let items = service_for_paths(claude_home.clone(), codex_home.clone()) + .detect(ExternalAgentConfigDetectOptions { + include_home: true, + cwds: None, + }) + .expect("detect"); + + let expected = vec![ + ExternalAgentConfigMigrationItem { + item_type: ExternalAgentConfigMigrationItemType::Config, + description: format!( + "Migrate {} into {}", + claude_home.join("settings.json").display(), + codex_home.join("config.toml").display() + ), + cwd: None, + }, + ExternalAgentConfigMigrationItem { + item_type: ExternalAgentConfigMigrationItemType::Skills, + description: format!( + "Copy skill folders from {} to {}", + claude_home.join("skills").display(), + agents_skills.display() + ), + cwd: None, + }, + ExternalAgentConfigMigrationItem { + item_type: ExternalAgentConfigMigrationItemType::AgentsMd, + description: format!( + "Import {} to {}", + claude_home.join("CLAUDE.md").display(), + codex_home.join("AGENTS.md").display() + ), + cwd: None, + }, + ]; + + assert_eq!(items, expected); +} + +#[test] +fn detect_repo_lists_agents_md_for_each_cwd() { + let root = TempDir::new().expect("create tempdir"); + let repo_root = root.path().join("repo"); + let nested = repo_root.join("nested").join("child"); + fs::create_dir_all(repo_root.join(".git")).expect("create git dir"); + fs::create_dir_all(&nested).expect("create nested"); + fs::write(repo_root.join("CLAUDE.md"), "Claude code guidance").expect("write source"); + + let items = service_for_paths(root.path().join(".claude"), root.path().join(".codex")) + .detect(ExternalAgentConfigDetectOptions { + include_home: false, + cwds: Some(vec![nested, repo_root.clone()]), + }) + .expect("detect"); + + let expected = vec![ + ExternalAgentConfigMigrationItem { + item_type: ExternalAgentConfigMigrationItemType::AgentsMd, + description: format!( + "Import {} to {}", + repo_root.join("CLAUDE.md").display(), + repo_root.join("AGENTS.md").display(), + ), + cwd: Some(repo_root.clone()), + }, + ExternalAgentConfigMigrationItem { + item_type: ExternalAgentConfigMigrationItemType::AgentsMd, + description: format!( + "Import {} to {}", + repo_root.join("CLAUDE.md").display(), + repo_root.join("AGENTS.md").display(), + ), + cwd: Some(repo_root), + }, + ]; + + assert_eq!(items, expected); +} + +#[test] +fn import_home_migrates_supported_config_fields_skills_and_agents_md() { + let (_root, claude_home, codex_home) = fixture_paths(); + let agents_skills = codex_home + .parent() + .map(|parent| parent.join(".agents").join("skills")) + .unwrap_or_else(|| PathBuf::from(".agents").join("skills")); + fs::create_dir_all(claude_home.join("skills").join("skill-a")).expect("create skills"); + fs::write( + claude_home.join("settings.json"), + r#"{"model":"claude","permissions":{"ask":["git push"]},"env":{"FOO":"bar","CI":false,"MAX_RETRIES":3,"MY_TEAM":"codex","IGNORED":null,"LIST":["a","b"],"MAP":{"x":1}},"sandbox":{"enabled":true,"network":{"allowLocalBinding":true}}}"#, + ) + .expect("write settings"); + fs::write( + claude_home.join("skills").join("skill-a").join("SKILL.md"), + "Use Claude Code and CLAUDE utilities.", + ) + .expect("write skill"); + fs::write(claude_home.join("CLAUDE.md"), "Claude code guidance").expect("write agents"); + + service_for_paths(claude_home, codex_home.clone()) + .import(vec![ + ExternalAgentConfigMigrationItem { + item_type: ExternalAgentConfigMigrationItemType::AgentsMd, + description: String::new(), + cwd: None, + }, + ExternalAgentConfigMigrationItem { + item_type: ExternalAgentConfigMigrationItemType::Config, + description: String::new(), + cwd: None, + }, + ExternalAgentConfigMigrationItem { + item_type: ExternalAgentConfigMigrationItemType::Skills, + description: String::new(), + cwd: None, + }, + ]) + .expect("import"); + + assert_eq!( + fs::read_to_string(codex_home.join("AGENTS.md")).expect("read agents"), + "Codex guidance" + ); + + assert_eq!( + fs::read_to_string(codex_home.join("config.toml")).expect("read config"), + "sandbox_mode = \"workspace-write\"\n\n[shell_environment_policy]\ninherit = \"core\"\n\n[shell_environment_policy.set]\nCI = \"false\"\nFOO = \"bar\"\nMAX_RETRIES = \"3\"\nMY_TEAM = \"codex\"\n" + ); + assert_eq!( + fs::read_to_string(agents_skills.join("skill-a").join("SKILL.md")) + .expect("read copied skill"), + "Use Codex and Codex utilities." + ); +} + +#[test] +fn import_home_skips_empty_config_migration() { + let (_root, claude_home, codex_home) = fixture_paths(); + fs::create_dir_all(&claude_home).expect("create claude home"); + fs::write( + claude_home.join("settings.json"), + r#"{"model":"claude","sandbox":{"enabled":false}}"#, + ) + .expect("write settings"); + + service_for_paths(claude_home, codex_home.clone()) + .import(vec![ExternalAgentConfigMigrationItem { + item_type: ExternalAgentConfigMigrationItemType::Config, + description: String::new(), + cwd: None, + }]) + .expect("import"); + + assert!(!codex_home.join("config.toml").exists()); +} + +#[test] +fn detect_home_skips_config_when_target_already_has_supported_fields() { + let (_root, claude_home, codex_home) = fixture_paths(); + fs::create_dir_all(&claude_home).expect("create claude home"); + fs::create_dir_all(&codex_home).expect("create codex home"); + fs::write( + claude_home.join("settings.json"), + r#"{"env":{"FOO":"bar"},"sandbox":{"enabled":true}}"#, + ) + .expect("write settings"); + fs::write( + codex_home.join("config.toml"), + r#" + sandbox_mode = "workspace-write" + + [shell_environment_policy] + inherit = "core" + + [shell_environment_policy.set] + FOO = "bar" + "#, + ) + .expect("write config"); + + let items = service_for_paths(claude_home, codex_home) + .detect(ExternalAgentConfigDetectOptions { + include_home: true, + cwds: None, + }) + .expect("detect"); + + assert_eq!(items, Vec::::new()); +} + +#[test] +fn detect_home_skips_skills_when_all_skill_directories_exist() { + let (_root, claude_home, codex_home) = fixture_paths(); + let agents_skills = codex_home + .parent() + .map(|parent| parent.join(".agents").join("skills")) + .unwrap_or_else(|| PathBuf::from(".agents").join("skills")); + fs::create_dir_all(claude_home.join("skills").join("skill-a")).expect("create source"); + fs::create_dir_all(agents_skills.join("skill-a")).expect("create target"); + + let items = service_for_paths(claude_home, codex_home) + .detect(ExternalAgentConfigDetectOptions { + include_home: true, + cwds: None, + }) + .expect("detect"); + + assert_eq!(items, Vec::::new()); +} + +#[test] +fn import_repo_agents_md_rewrites_terms_and_skips_non_empty_targets() { + let root = TempDir::new().expect("create tempdir"); + let repo_root = root.path().join("repo-a"); + let repo_with_existing_target = root.path().join("repo-b"); + fs::create_dir_all(repo_root.join(".git")).expect("create git"); + fs::create_dir_all(repo_with_existing_target.join(".git")).expect("create git"); + fs::write( + repo_root.join("CLAUDE.md"), + "Claude code\nclaude\nCLAUDE-CODE\nSee CLAUDE.md\n", + ) + .expect("write source"); + fs::write(repo_with_existing_target.join("CLAUDE.md"), "new source").expect("write source"); + fs::write( + repo_with_existing_target.join("AGENTS.md"), + "keep existing target", + ) + .expect("write target"); + + service_for_paths(root.path().join(".claude"), root.path().join(".codex")) + .import(vec![ + ExternalAgentConfigMigrationItem { + item_type: ExternalAgentConfigMigrationItemType::AgentsMd, + description: String::new(), + cwd: Some(repo_root.clone()), + }, + ExternalAgentConfigMigrationItem { + item_type: ExternalAgentConfigMigrationItemType::AgentsMd, + description: String::new(), + cwd: Some(repo_with_existing_target.clone()), + }, + ]) + .expect("import"); + + assert_eq!( + fs::read_to_string(repo_root.join("AGENTS.md")).expect("read target"), + "Codex\nCodex\nCodex\nSee AGENTS.md\n" + ); + assert_eq!( + fs::read_to_string(repo_with_existing_target.join("AGENTS.md")) + .expect("read existing target"), + "keep existing target" + ); +} + +#[test] +fn import_repo_agents_md_overwrites_empty_targets() { + let root = TempDir::new().expect("create tempdir"); + let repo_root = root.path().join("repo"); + fs::create_dir_all(repo_root.join(".git")).expect("create git"); + fs::write(repo_root.join("CLAUDE.md"), "Claude code guidance").expect("write source"); + fs::write(repo_root.join("AGENTS.md"), " \n\t").expect("write empty target"); + + service_for_paths(root.path().join(".claude"), root.path().join(".codex")) + .import(vec![ExternalAgentConfigMigrationItem { + item_type: ExternalAgentConfigMigrationItemType::AgentsMd, + description: String::new(), + cwd: Some(repo_root.clone()), + }]) + .expect("import"); + + assert_eq!( + fs::read_to_string(repo_root.join("AGENTS.md")).expect("read target"), + "Codex guidance" + ); +} + +#[test] +fn detect_repo_prefers_non_empty_dot_claude_agents_source() { + let root = TempDir::new().expect("create tempdir"); + let repo_root = root.path().join("repo"); + fs::create_dir_all(repo_root.join(".git")).expect("create git"); + fs::create_dir_all(repo_root.join(".claude")).expect("create dot claude"); + fs::write(repo_root.join("CLAUDE.md"), " \n\t").expect("write empty root source"); + fs::write( + repo_root.join(".claude").join("CLAUDE.md"), + "Claude code guidance", + ) + .expect("write dot claude source"); + + let items = service_for_paths(root.path().join(".claude"), root.path().join(".codex")) + .detect(ExternalAgentConfigDetectOptions { + include_home: false, + cwds: Some(vec![repo_root.clone()]), + }) + .expect("detect"); + + assert_eq!( + items, + vec![ExternalAgentConfigMigrationItem { + item_type: ExternalAgentConfigMigrationItemType::AgentsMd, + description: format!( + "Import {} to {}", + repo_root.join(".claude").join("CLAUDE.md").display(), + repo_root.join("AGENTS.md").display(), + ), + cwd: Some(repo_root), + }] + ); +} + +#[test] +fn import_repo_uses_non_empty_dot_claude_agents_source() { + let root = TempDir::new().expect("create tempdir"); + let repo_root = root.path().join("repo"); + fs::create_dir_all(repo_root.join(".git")).expect("create git"); + fs::create_dir_all(repo_root.join(".claude")).expect("create dot claude"); + fs::write(repo_root.join("CLAUDE.md"), "").expect("write empty root source"); + fs::write( + repo_root.join(".claude").join("CLAUDE.md"), + "Claude code guidance", + ) + .expect("write dot claude source"); + + service_for_paths(root.path().join(".claude"), root.path().join(".codex")) + .import(vec![ExternalAgentConfigMigrationItem { + item_type: ExternalAgentConfigMigrationItemType::AgentsMd, + description: String::new(), + cwd: Some(repo_root.clone()), + }]) + .expect("import"); + + assert_eq!( + fs::read_to_string(repo_root.join("AGENTS.md")).expect("read target"), + "Codex guidance" + ); +} + +#[test] +fn migration_metric_tags_for_skills_include_skills_count() { + assert_eq!( + migration_metric_tags(ExternalAgentConfigMigrationItemType::Skills, Some(3)), + vec![ + ("migration_type", "skills".to_string()), + ("skills_count", "3".to_string()), + ] + ); +} + +#[test] +fn import_skills_returns_only_new_skill_directory_count() { + let (_root, claude_home, codex_home) = fixture_paths(); + let agents_skills = codex_home + .parent() + .map(|parent| parent.join(".agents").join("skills")) + .unwrap_or_else(|| PathBuf::from(".agents").join("skills")); + fs::create_dir_all(claude_home.join("skills").join("skill-a")).expect("create source a"); + fs::create_dir_all(claude_home.join("skills").join("skill-b")).expect("create source b"); + fs::create_dir_all(agents_skills.join("skill-a")).expect("create existing target"); + + let copied_count = service_for_paths(claude_home, codex_home) + .import_skills(None) + .expect("import skills"); + + assert_eq!(copied_count, 1); +} diff --git a/codex-rs/core/src/features.rs b/codex-rs/core/src/features.rs index 27f27f55e7..d1da63cfb4 100644 --- a/codex-rs/core/src/features.rs +++ b/codex-rs/core/src/features.rs @@ -898,156 +898,5 @@ pub fn maybe_push_unstable_features_warning( } #[cfg(test)] -mod tests { - use super::*; - - use pretty_assertions::assert_eq; - - #[test] - fn under_development_features_are_disabled_by_default() { - for spec in FEATURES { - if matches!(spec.stage, Stage::UnderDevelopment) { - assert_eq!( - spec.default_enabled, false, - "feature `{}` is under development and must be disabled by default", - spec.key - ); - } - } - } - - #[test] - fn default_enabled_features_are_stable() { - for spec in FEATURES { - if spec.default_enabled { - assert!( - matches!(spec.stage, Stage::Stable | Stage::Removed), - "feature `{}` is enabled by default but is not stable/removed ({:?})", - spec.key, - spec.stage - ); - } - } - } - - #[test] - fn use_legacy_landlock_is_stable_and_disabled_by_default() { - assert_eq!(Feature::UseLegacyLandlock.stage(), Stage::Stable); - assert_eq!(Feature::UseLegacyLandlock.default_enabled(), false); - } - - #[test] - fn js_repl_is_experimental_and_user_toggleable() { - let spec = Feature::JsRepl.info(); - let stage = spec.stage; - let expected_node_version = include_str!("../../node-version.txt").trim_end(); - - assert!(matches!(stage, Stage::Experimental { .. })); - assert_eq!(stage.experimental_menu_name(), Some("JavaScript REPL")); - assert_eq!( - stage.experimental_menu_description().map(str::to_owned), - Some(format!( - "Enable a persistent Node-backed JavaScript REPL for interactive website debugging and other inline JavaScript execution capabilities. Requires Node >= v{expected_node_version} installed." - )) - ); - assert_eq!(Feature::JsRepl.default_enabled(), false); - } - - #[test] - fn guardian_approval_is_experimental_and_user_toggleable() { - let spec = Feature::GuardianApproval.info(); - let stage = spec.stage; - - assert!(matches!(stage, Stage::Experimental { .. })); - assert_eq!( - stage.experimental_menu_name(), - Some("Automatic approval review") - ); - assert_eq!( - stage.experimental_menu_description().map(str::to_owned), - Some( - "Dispatch `on-request` approval prompts (for e.g. sandbox escapes or blocked network access) to a carefully-prompted security reviewer subagent rather than blocking the agent on your input.".to_string() - ) - ); - assert_eq!(stage.experimental_announcement(), None); - assert_eq!(Feature::GuardianApproval.default_enabled(), false); - } - - #[test] - fn request_permissions_is_under_development() { - assert_eq!(Feature::RequestPermissions.stage(), Stage::UnderDevelopment); - assert_eq!(Feature::RequestPermissions.default_enabled(), false); - } - - #[test] - fn request_permissions_tool_is_under_development() { - assert_eq!( - Feature::RequestPermissionsTool.stage(), - Stage::UnderDevelopment - ); - assert_eq!(Feature::RequestPermissionsTool.default_enabled(), false); - } - - #[test] - fn tool_suggest_is_under_development() { - assert_eq!(Feature::ToolSuggest.stage(), Stage::UnderDevelopment); - assert_eq!(Feature::ToolSuggest.default_enabled(), false); - } - - #[test] - fn image_generation_is_under_development() { - assert_eq!(Feature::ImageGeneration.stage(), Stage::UnderDevelopment); - assert_eq!(Feature::ImageGeneration.default_enabled(), false); - } - - #[test] - fn image_detail_original_feature_is_under_development() { - assert_eq!( - Feature::ImageDetailOriginal.stage(), - Stage::UnderDevelopment - ); - assert_eq!(Feature::ImageDetailOriginal.default_enabled(), false); - } - - #[test] - fn collab_is_legacy_alias_for_multi_agent() { - assert_eq!(feature_for_key("multi_agent"), Some(Feature::Collab)); - assert_eq!(feature_for_key("collab"), Some(Feature::Collab)); - } - - #[test] - fn spawn_csv_is_under_development() { - assert_eq!(Feature::SpawnCsv.stage(), Stage::UnderDevelopment); - assert_eq!(Feature::SpawnCsv.default_enabled(), false); - } - - #[test] - fn spawn_csv_normalization_enables_multi_agent_one_way() { - let mut spawn_csv_features = Features::with_defaults(); - spawn_csv_features.enable(Feature::SpawnCsv); - spawn_csv_features.normalize_dependencies(); - assert_eq!(spawn_csv_features.enabled(Feature::SpawnCsv), true); - assert_eq!(spawn_csv_features.enabled(Feature::Collab), true); - - let mut collab_features = Features::with_defaults(); - collab_features.enable(Feature::Collab); - collab_features.normalize_dependencies(); - assert_eq!(collab_features.enabled(Feature::Collab), true); - assert_eq!(collab_features.enabled(Feature::SpawnCsv), false); - } - - #[test] - fn apps_require_feature_flag_and_chatgpt_auth() { - let mut features = Features::with_defaults(); - assert!(!features.apps_enabled_for_auth(None)); - - features.enable(Feature::Apps); - assert!(!features.apps_enabled_for_auth(None)); - - let api_key_auth = CodexAuth::from_api_key("test-api-key"); - assert!(!features.apps_enabled_for_auth(Some(&api_key_auth))); - - let chatgpt_auth = CodexAuth::create_dummy_chatgpt_auth_for_testing(); - assert!(features.apps_enabled_for_auth(Some(&chatgpt_auth))); - } -} +#[path = "features_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/features_tests.rs b/codex-rs/core/src/features_tests.rs new file mode 100644 index 0000000000..d8a5d1df4b --- /dev/null +++ b/codex-rs/core/src/features_tests.rs @@ -0,0 +1,151 @@ +use super::*; + +use pretty_assertions::assert_eq; + +#[test] +fn under_development_features_are_disabled_by_default() { + for spec in FEATURES { + if matches!(spec.stage, Stage::UnderDevelopment) { + assert_eq!( + spec.default_enabled, false, + "feature `{}` is under development and must be disabled by default", + spec.key + ); + } + } +} + +#[test] +fn default_enabled_features_are_stable() { + for spec in FEATURES { + if spec.default_enabled { + assert!( + matches!(spec.stage, Stage::Stable | Stage::Removed), + "feature `{}` is enabled by default but is not stable/removed ({:?})", + spec.key, + spec.stage + ); + } + } +} + +#[test] +fn use_legacy_landlock_is_stable_and_disabled_by_default() { + assert_eq!(Feature::UseLegacyLandlock.stage(), Stage::Stable); + assert_eq!(Feature::UseLegacyLandlock.default_enabled(), false); +} + +#[test] +fn js_repl_is_experimental_and_user_toggleable() { + let spec = Feature::JsRepl.info(); + let stage = spec.stage; + let expected_node_version = include_str!("../../node-version.txt").trim_end(); + + assert!(matches!(stage, Stage::Experimental { .. })); + assert_eq!(stage.experimental_menu_name(), Some("JavaScript REPL")); + assert_eq!( + stage.experimental_menu_description().map(str::to_owned), + Some(format!( + "Enable a persistent Node-backed JavaScript REPL for interactive website debugging and other inline JavaScript execution capabilities. Requires Node >= v{expected_node_version} installed." + )) + ); + assert_eq!(Feature::JsRepl.default_enabled(), false); +} + +#[test] +fn guardian_approval_is_experimental_and_user_toggleable() { + let spec = Feature::GuardianApproval.info(); + let stage = spec.stage; + + assert!(matches!(stage, Stage::Experimental { .. })); + assert_eq!( + stage.experimental_menu_name(), + Some("Automatic approval review") + ); + assert_eq!( + stage.experimental_menu_description().map(str::to_owned), + Some( + "Dispatch `on-request` approval prompts (for e.g. sandbox escapes or blocked network access) to a carefully-prompted security reviewer subagent rather than blocking the agent on your input.".to_string() + ) + ); + assert_eq!(stage.experimental_announcement(), None); + assert_eq!(Feature::GuardianApproval.default_enabled(), false); +} + +#[test] +fn request_permissions_is_under_development() { + assert_eq!(Feature::RequestPermissions.stage(), Stage::UnderDevelopment); + assert_eq!(Feature::RequestPermissions.default_enabled(), false); +} + +#[test] +fn request_permissions_tool_is_under_development() { + assert_eq!( + Feature::RequestPermissionsTool.stage(), + Stage::UnderDevelopment + ); + assert_eq!(Feature::RequestPermissionsTool.default_enabled(), false); +} + +#[test] +fn tool_suggest_is_under_development() { + assert_eq!(Feature::ToolSuggest.stage(), Stage::UnderDevelopment); + assert_eq!(Feature::ToolSuggest.default_enabled(), false); +} + +#[test] +fn image_generation_is_under_development() { + assert_eq!(Feature::ImageGeneration.stage(), Stage::UnderDevelopment); + assert_eq!(Feature::ImageGeneration.default_enabled(), false); +} + +#[test] +fn image_detail_original_feature_is_under_development() { + assert_eq!( + Feature::ImageDetailOriginal.stage(), + Stage::UnderDevelopment + ); + assert_eq!(Feature::ImageDetailOriginal.default_enabled(), false); +} + +#[test] +fn collab_is_legacy_alias_for_multi_agent() { + assert_eq!(feature_for_key("multi_agent"), Some(Feature::Collab)); + assert_eq!(feature_for_key("collab"), Some(Feature::Collab)); +} + +#[test] +fn spawn_csv_is_under_development() { + assert_eq!(Feature::SpawnCsv.stage(), Stage::UnderDevelopment); + assert_eq!(Feature::SpawnCsv.default_enabled(), false); +} + +#[test] +fn spawn_csv_normalization_enables_multi_agent_one_way() { + let mut spawn_csv_features = Features::with_defaults(); + spawn_csv_features.enable(Feature::SpawnCsv); + spawn_csv_features.normalize_dependencies(); + assert_eq!(spawn_csv_features.enabled(Feature::SpawnCsv), true); + assert_eq!(spawn_csv_features.enabled(Feature::Collab), true); + + let mut collab_features = Features::with_defaults(); + collab_features.enable(Feature::Collab); + collab_features.normalize_dependencies(); + assert_eq!(collab_features.enabled(Feature::Collab), true); + assert_eq!(collab_features.enabled(Feature::SpawnCsv), false); +} + +#[test] +fn apps_require_feature_flag_and_chatgpt_auth() { + let mut features = Features::with_defaults(); + assert!(!features.apps_enabled_for_auth(None)); + + features.enable(Feature::Apps); + assert!(!features.apps_enabled_for_auth(None)); + + let api_key_auth = CodexAuth::from_api_key("test-api-key"); + assert!(!features.apps_enabled_for_auth(Some(&api_key_auth))); + + let chatgpt_auth = CodexAuth::create_dummy_chatgpt_auth_for_testing(); + assert!(features.apps_enabled_for_auth(Some(&chatgpt_auth))); +} diff --git a/codex-rs/core/src/file_watcher.rs b/codex-rs/core/src/file_watcher.rs index f44a431568..7b2cbd76b0 100644 --- a/codex-rs/core/src/file_watcher.rs +++ b/codex-rs/core/src/file_watcher.rs @@ -350,251 +350,5 @@ fn is_skills_path(path: &Path, roots: &HashSet) -> bool { } #[cfg(test)] -mod tests { - use super::*; - use notify::EventKind; - use notify::event::AccessKind; - use notify::event::AccessMode; - use notify::event::CreateKind; - use notify::event::ModifyKind; - use notify::event::RemoveKind; - use pretty_assertions::assert_eq; - use tokio::time::timeout; - - fn path(name: &str) -> PathBuf { - PathBuf::from(name) - } - - fn notify_event(kind: EventKind, paths: Vec) -> Event { - let mut event = Event::new(kind); - for path in paths { - event = event.add_path(path); - } - event - } - - #[test] - fn throttles_and_coalesces_within_interval() { - let start = Instant::now(); - let mut throttled = ThrottledPaths::new(start); - - throttled.add(vec![path("a")]); - let first = throttled.take_ready(start).expect("first emit"); - assert_eq!(first, vec![path("a")]); - - throttled.add(vec![path("b"), path("c")]); - assert_eq!(throttled.take_ready(start), None); - - let second = throttled - .take_ready(start + WATCHER_THROTTLE_INTERVAL) - .expect("coalesced emit"); - assert_eq!(second, vec![path("b"), path("c")]); - } - - #[test] - fn flushes_pending_on_shutdown() { - let start = Instant::now(); - let mut throttled = ThrottledPaths::new(start); - - throttled.add(vec![path("a")]); - let _ = throttled.take_ready(start).expect("first emit"); - - throttled.add(vec![path("b")]); - assert_eq!(throttled.take_ready(start), None); - - let flushed = throttled - .take_pending(start) - .expect("shutdown flush emits pending paths"); - assert_eq!(flushed, vec![path("b")]); - } - - #[test] - fn classify_event_filters_to_skills_roots() { - let root = path("/tmp/skills"); - let state = RwLock::new(WatchState { - skills_root_ref_counts: HashMap::from([(root.clone(), 1)]), - }); - let event = notify_event( - EventKind::Create(CreateKind::Any), - vec![ - root.join("demo/SKILL.md"), - path("/tmp/other/not-a-skill.txt"), - ], - ); - - let classified = classify_event(&event, &state); - assert_eq!(classified, vec![root.join("demo/SKILL.md")]); - } - - #[test] - fn classify_event_supports_multiple_roots_without_prefix_false_positives() { - let root_a = path("/tmp/skills"); - let root_b = path("/tmp/workspace/.codex/skills"); - let state = RwLock::new(WatchState { - skills_root_ref_counts: HashMap::from([(root_a.clone(), 1), (root_b.clone(), 1)]), - }); - let event = notify_event( - EventKind::Modify(ModifyKind::Any), - vec![ - root_a.join("alpha/SKILL.md"), - path("/tmp/skills-extra/not-under-skills.txt"), - root_b.join("beta/SKILL.md"), - ], - ); - - let classified = classify_event(&event, &state); - assert_eq!( - classified, - vec![root_a.join("alpha/SKILL.md"), root_b.join("beta/SKILL.md")] - ); - } - - #[test] - fn classify_event_ignores_non_mutating_event_kinds() { - let root = path("/tmp/skills"); - let state = RwLock::new(WatchState { - skills_root_ref_counts: HashMap::from([(root.clone(), 1)]), - }); - let path = root.join("demo/SKILL.md"); - - let access_event = notify_event( - EventKind::Access(AccessKind::Open(AccessMode::Any)), - vec![path.clone()], - ); - assert_eq!(classify_event(&access_event, &state), Vec::::new()); - - let any_event = notify_event(EventKind::Any, vec![path.clone()]); - assert_eq!(classify_event(&any_event, &state), Vec::::new()); - - let other_event = notify_event(EventKind::Other, vec![path]); - assert_eq!(classify_event(&other_event, &state), Vec::::new()); - } - - #[test] - fn register_skills_root_dedupes_state_entries() { - let watcher = FileWatcher::noop(); - let root = path("/tmp/skills"); - watcher.register_skills_root(root.clone()); - watcher.register_skills_root(root); - watcher.register_skills_root(path("/tmp/other-skills")); - - let state = watcher.state.read().expect("state lock"); - assert_eq!(state.skills_root_ref_counts.len(), 2); - } - - #[test] - fn watch_registration_drop_unregisters_roots() { - let watcher = Arc::new(FileWatcher::noop()); - let root = path("/tmp/skills"); - watcher.register_skills_root(root.clone()); - let registration = WatchRegistration { - file_watcher: Arc::downgrade(&watcher), - roots: vec![root], - }; - - drop(registration); - - let state = watcher.state.read().expect("state lock"); - assert_eq!(state.skills_root_ref_counts.len(), 0); - } - - #[test] - fn unregister_holds_state_lock_until_unwatch_finishes() { - let temp_dir = tempfile::tempdir().expect("temp dir"); - let root = temp_dir.path().join("skills"); - std::fs::create_dir(&root).expect("create root"); - - let watcher = Arc::new(FileWatcher::new(temp_dir.path().to_path_buf()).expect("watcher")); - watcher.register_skills_root(root.clone()); - - let inner = watcher.inner.as_ref().expect("watcher inner"); - let inner_guard = inner.lock().expect("inner lock"); - - let unregister_watcher = Arc::clone(&watcher); - let unregister_root = root.clone(); - let unregister_thread = std::thread::spawn(move || { - unregister_watcher.unregister_roots(&[unregister_root]); - }); - - let state_lock_observed = (0..100).any(|_| { - let locked = watcher.state.try_write().is_err(); - if !locked { - std::thread::sleep(Duration::from_millis(10)); - } - locked - }); - assert_eq!(state_lock_observed, true); - - let register_watcher = Arc::clone(&watcher); - let register_root = root.clone(); - let register_thread = std::thread::spawn(move || { - register_watcher.register_skills_root(register_root); - }); - - drop(inner_guard); - - unregister_thread.join().expect("unregister join"); - register_thread.join().expect("register join"); - - let state = watcher.state.read().expect("state lock"); - assert_eq!(state.skills_root_ref_counts.get(&root), Some(&1)); - drop(state); - - let inner = watcher.inner.as_ref().expect("watcher inner"); - let inner = inner.lock().expect("inner lock"); - assert_eq!( - inner.watched_paths.get(&root), - Some(&RecursiveMode::Recursive) - ); - } - - #[tokio::test] - async fn spawn_event_loop_flushes_pending_changes_on_shutdown() { - let watcher = FileWatcher::noop(); - let root = path("/tmp/skills"); - { - let mut state = watcher.state.write().expect("state lock"); - state.skills_root_ref_counts.insert(root.clone(), 1); - } - - let (raw_tx, raw_rx) = mpsc::unbounded_channel(); - let (tx, mut rx) = broadcast::channel(8); - watcher.spawn_event_loop(raw_rx, Arc::clone(&watcher.state), tx); - - raw_tx - .send(Ok(notify_event( - EventKind::Create(CreateKind::File), - vec![root.join("a/SKILL.md")], - ))) - .expect("send first event"); - let first = timeout(Duration::from_secs(2), rx.recv()) - .await - .expect("first watcher event") - .expect("broadcast recv first"); - assert_eq!( - first, - FileWatcherEvent::SkillsChanged { - paths: vec![root.join("a/SKILL.md")] - } - ); - - raw_tx - .send(Ok(notify_event( - EventKind::Remove(RemoveKind::File), - vec![root.join("b/SKILL.md")], - ))) - .expect("send second event"); - drop(raw_tx); - - let second = timeout(Duration::from_secs(2), rx.recv()) - .await - .expect("second watcher event") - .expect("broadcast recv second"); - assert_eq!( - second, - FileWatcherEvent::SkillsChanged { - paths: vec![root.join("b/SKILL.md")] - } - ); - } -} +#[path = "file_watcher_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/file_watcher_tests.rs b/codex-rs/core/src/file_watcher_tests.rs new file mode 100644 index 0000000000..995e7f7cea --- /dev/null +++ b/codex-rs/core/src/file_watcher_tests.rs @@ -0,0 +1,246 @@ +use super::*; +use notify::EventKind; +use notify::event::AccessKind; +use notify::event::AccessMode; +use notify::event::CreateKind; +use notify::event::ModifyKind; +use notify::event::RemoveKind; +use pretty_assertions::assert_eq; +use tokio::time::timeout; + +fn path(name: &str) -> PathBuf { + PathBuf::from(name) +} + +fn notify_event(kind: EventKind, paths: Vec) -> Event { + let mut event = Event::new(kind); + for path in paths { + event = event.add_path(path); + } + event +} + +#[test] +fn throttles_and_coalesces_within_interval() { + let start = Instant::now(); + let mut throttled = ThrottledPaths::new(start); + + throttled.add(vec![path("a")]); + let first = throttled.take_ready(start).expect("first emit"); + assert_eq!(first, vec![path("a")]); + + throttled.add(vec![path("b"), path("c")]); + assert_eq!(throttled.take_ready(start), None); + + let second = throttled + .take_ready(start + WATCHER_THROTTLE_INTERVAL) + .expect("coalesced emit"); + assert_eq!(second, vec![path("b"), path("c")]); +} + +#[test] +fn flushes_pending_on_shutdown() { + let start = Instant::now(); + let mut throttled = ThrottledPaths::new(start); + + throttled.add(vec![path("a")]); + let _ = throttled.take_ready(start).expect("first emit"); + + throttled.add(vec![path("b")]); + assert_eq!(throttled.take_ready(start), None); + + let flushed = throttled + .take_pending(start) + .expect("shutdown flush emits pending paths"); + assert_eq!(flushed, vec![path("b")]); +} + +#[test] +fn classify_event_filters_to_skills_roots() { + let root = path("/tmp/skills"); + let state = RwLock::new(WatchState { + skills_root_ref_counts: HashMap::from([(root.clone(), 1)]), + }); + let event = notify_event( + EventKind::Create(CreateKind::Any), + vec![ + root.join("demo/SKILL.md"), + path("/tmp/other/not-a-skill.txt"), + ], + ); + + let classified = classify_event(&event, &state); + assert_eq!(classified, vec![root.join("demo/SKILL.md")]); +} + +#[test] +fn classify_event_supports_multiple_roots_without_prefix_false_positives() { + let root_a = path("/tmp/skills"); + let root_b = path("/tmp/workspace/.codex/skills"); + let state = RwLock::new(WatchState { + skills_root_ref_counts: HashMap::from([(root_a.clone(), 1), (root_b.clone(), 1)]), + }); + let event = notify_event( + EventKind::Modify(ModifyKind::Any), + vec![ + root_a.join("alpha/SKILL.md"), + path("/tmp/skills-extra/not-under-skills.txt"), + root_b.join("beta/SKILL.md"), + ], + ); + + let classified = classify_event(&event, &state); + assert_eq!( + classified, + vec![root_a.join("alpha/SKILL.md"), root_b.join("beta/SKILL.md")] + ); +} + +#[test] +fn classify_event_ignores_non_mutating_event_kinds() { + let root = path("/tmp/skills"); + let state = RwLock::new(WatchState { + skills_root_ref_counts: HashMap::from([(root.clone(), 1)]), + }); + let path = root.join("demo/SKILL.md"); + + let access_event = notify_event( + EventKind::Access(AccessKind::Open(AccessMode::Any)), + vec![path.clone()], + ); + assert_eq!(classify_event(&access_event, &state), Vec::::new()); + + let any_event = notify_event(EventKind::Any, vec![path.clone()]); + assert_eq!(classify_event(&any_event, &state), Vec::::new()); + + let other_event = notify_event(EventKind::Other, vec![path]); + assert_eq!(classify_event(&other_event, &state), Vec::::new()); +} + +#[test] +fn register_skills_root_dedupes_state_entries() { + let watcher = FileWatcher::noop(); + let root = path("/tmp/skills"); + watcher.register_skills_root(root.clone()); + watcher.register_skills_root(root); + watcher.register_skills_root(path("/tmp/other-skills")); + + let state = watcher.state.read().expect("state lock"); + assert_eq!(state.skills_root_ref_counts.len(), 2); +} + +#[test] +fn watch_registration_drop_unregisters_roots() { + let watcher = Arc::new(FileWatcher::noop()); + let root = path("/tmp/skills"); + watcher.register_skills_root(root.clone()); + let registration = WatchRegistration { + file_watcher: Arc::downgrade(&watcher), + roots: vec![root], + }; + + drop(registration); + + let state = watcher.state.read().expect("state lock"); + assert_eq!(state.skills_root_ref_counts.len(), 0); +} + +#[test] +fn unregister_holds_state_lock_until_unwatch_finishes() { + let temp_dir = tempfile::tempdir().expect("temp dir"); + let root = temp_dir.path().join("skills"); + std::fs::create_dir(&root).expect("create root"); + + let watcher = Arc::new(FileWatcher::new(temp_dir.path().to_path_buf()).expect("watcher")); + watcher.register_skills_root(root.clone()); + + let inner = watcher.inner.as_ref().expect("watcher inner"); + let inner_guard = inner.lock().expect("inner lock"); + + let unregister_watcher = Arc::clone(&watcher); + let unregister_root = root.clone(); + let unregister_thread = std::thread::spawn(move || { + unregister_watcher.unregister_roots(&[unregister_root]); + }); + + let state_lock_observed = (0..100).any(|_| { + let locked = watcher.state.try_write().is_err(); + if !locked { + std::thread::sleep(Duration::from_millis(10)); + } + locked + }); + assert_eq!(state_lock_observed, true); + + let register_watcher = Arc::clone(&watcher); + let register_root = root.clone(); + let register_thread = std::thread::spawn(move || { + register_watcher.register_skills_root(register_root); + }); + + drop(inner_guard); + + unregister_thread.join().expect("unregister join"); + register_thread.join().expect("register join"); + + let state = watcher.state.read().expect("state lock"); + assert_eq!(state.skills_root_ref_counts.get(&root), Some(&1)); + drop(state); + + let inner = watcher.inner.as_ref().expect("watcher inner"); + let inner = inner.lock().expect("inner lock"); + assert_eq!( + inner.watched_paths.get(&root), + Some(&RecursiveMode::Recursive) + ); +} + +#[tokio::test] +async fn spawn_event_loop_flushes_pending_changes_on_shutdown() { + let watcher = FileWatcher::noop(); + let root = path("/tmp/skills"); + { + let mut state = watcher.state.write().expect("state lock"); + state.skills_root_ref_counts.insert(root.clone(), 1); + } + + let (raw_tx, raw_rx) = mpsc::unbounded_channel(); + let (tx, mut rx) = broadcast::channel(8); + watcher.spawn_event_loop(raw_rx, Arc::clone(&watcher.state), tx); + + raw_tx + .send(Ok(notify_event( + EventKind::Create(CreateKind::File), + vec![root.join("a/SKILL.md")], + ))) + .expect("send first event"); + let first = timeout(Duration::from_secs(2), rx.recv()) + .await + .expect("first watcher event") + .expect("broadcast recv first"); + assert_eq!( + first, + FileWatcherEvent::SkillsChanged { + paths: vec![root.join("a/SKILL.md")] + } + ); + + raw_tx + .send(Ok(notify_event( + EventKind::Remove(RemoveKind::File), + vec![root.join("b/SKILL.md")], + ))) + .expect("send second event"); + drop(raw_tx); + + let second = timeout(Duration::from_secs(2), rx.recv()) + .await + .expect("second watcher event") + .expect("broadcast recv second"); + assert_eq!( + second, + FileWatcherEvent::SkillsChanged { + paths: vec![root.join("b/SKILL.md")] + } + ); +} diff --git a/codex-rs/core/src/git_info.rs b/codex-rs/core/src/git_info.rs index 676d230c20..052f786bfa 100644 --- a/codex-rs/core/src/git_info.rs +++ b/codex-rs/core/src/git_info.rs @@ -691,597 +691,5 @@ pub async fn current_branch_name(cwd: &Path) -> Option { } #[cfg(test)] -mod tests { - use super::*; - - use core_test_support::skip_if_sandbox; - use std::fs; - use std::path::PathBuf; - use tempfile::TempDir; - - // Helper function to create a test git repository - async fn create_test_git_repo(temp_dir: &TempDir) -> PathBuf { - let repo_path = temp_dir.path().join("repo"); - fs::create_dir(&repo_path).expect("Failed to create repo dir"); - let envs = vec![ - ("GIT_CONFIG_GLOBAL", "/dev/null"), - ("GIT_CONFIG_NOSYSTEM", "1"), - ]; - - // Initialize git repo - Command::new("git") - .envs(envs.clone()) - .args(["init"]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to init git repo"); - - // Configure git user (required for commits) - Command::new("git") - .envs(envs.clone()) - .args(["config", "user.name", "Test User"]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to set git user name"); - - Command::new("git") - .envs(envs.clone()) - .args(["config", "user.email", "test@example.com"]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to set git user email"); - - // Create a test file and commit it - let test_file = repo_path.join("test.txt"); - fs::write(&test_file, "test content").expect("Failed to write test file"); - - Command::new("git") - .envs(envs.clone()) - .args(["add", "."]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to add files"); - - Command::new("git") - .envs(envs.clone()) - .args(["commit", "-m", "Initial commit"]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to commit"); - - repo_path - } - - #[tokio::test] - async fn test_recent_commits_non_git_directory_returns_empty() { - let temp_dir = TempDir::new().expect("Failed to create temp dir"); - let entries = recent_commits(temp_dir.path(), 10).await; - assert!(entries.is_empty(), "expected no commits outside a git repo"); - } - - #[tokio::test] - async fn test_recent_commits_orders_and_limits() { - skip_if_sandbox!(); - use tokio::time::Duration; - use tokio::time::sleep; - - let temp_dir = TempDir::new().expect("Failed to create temp dir"); - let repo_path = create_test_git_repo(&temp_dir).await; - - // Make three distinct commits with small delays to ensure ordering by timestamp. - fs::write(repo_path.join("file.txt"), "one").unwrap(); - Command::new("git") - .args(["add", "file.txt"]) - .current_dir(&repo_path) - .output() - .await - .expect("git add"); - Command::new("git") - .args(["commit", "-m", "first change"]) - .current_dir(&repo_path) - .output() - .await - .expect("git commit 1"); - - sleep(Duration::from_millis(1100)).await; - - fs::write(repo_path.join("file.txt"), "two").unwrap(); - Command::new("git") - .args(["add", "file.txt"]) - .current_dir(&repo_path) - .output() - .await - .expect("git add 2"); - Command::new("git") - .args(["commit", "-m", "second change"]) - .current_dir(&repo_path) - .output() - .await - .expect("git commit 2"); - - sleep(Duration::from_millis(1100)).await; - - fs::write(repo_path.join("file.txt"), "three").unwrap(); - Command::new("git") - .args(["add", "file.txt"]) - .current_dir(&repo_path) - .output() - .await - .expect("git add 3"); - Command::new("git") - .args(["commit", "-m", "third change"]) - .current_dir(&repo_path) - .output() - .await - .expect("git commit 3"); - - // Request the latest 3 commits; should be our three changes in reverse time order. - let entries = recent_commits(&repo_path, 3).await; - assert_eq!(entries.len(), 3); - assert_eq!(entries[0].subject, "third change"); - assert_eq!(entries[1].subject, "second change"); - assert_eq!(entries[2].subject, "first change"); - // Basic sanity on SHA formatting - for e in entries { - assert!(e.sha.len() >= 7 && e.sha.chars().all(|c| c.is_ascii_hexdigit())); - } - } - - async fn create_test_git_repo_with_remote(temp_dir: &TempDir) -> (PathBuf, String) { - let repo_path = create_test_git_repo(temp_dir).await; - let remote_path = temp_dir.path().join("remote.git"); - - Command::new("git") - .args(["init", "--bare", remote_path.to_str().unwrap()]) - .output() - .await - .expect("Failed to init bare remote"); - - Command::new("git") - .args(["remote", "add", "origin", remote_path.to_str().unwrap()]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to add remote"); - - let output = Command::new("git") - .args(["rev-parse", "--abbrev-ref", "HEAD"]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to get branch"); - let branch = String::from_utf8(output.stdout).unwrap().trim().to_string(); - - Command::new("git") - .args(["push", "-u", "origin", &branch]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to push initial commit"); - - (repo_path, branch) - } - - #[tokio::test] - async fn test_collect_git_info_non_git_directory() { - let temp_dir = TempDir::new().expect("Failed to create temp dir"); - let result = collect_git_info(temp_dir.path()).await; - assert!(result.is_none()); - } - - #[tokio::test] - async fn test_collect_git_info_git_repository() { - let temp_dir = TempDir::new().expect("Failed to create temp dir"); - let repo_path = create_test_git_repo(&temp_dir).await; - - let git_info = collect_git_info(&repo_path) - .await - .expect("Should collect git info from repo"); - - // Should have commit hash - assert!(git_info.commit_hash.is_some()); - let commit_hash = git_info.commit_hash.unwrap(); - assert_eq!(commit_hash.len(), 40); // SHA-1 hash should be 40 characters - assert!(commit_hash.chars().all(|c| c.is_ascii_hexdigit())); - - // Should have branch (likely "main" or "master") - assert!(git_info.branch.is_some()); - let branch = git_info.branch.unwrap(); - assert!(branch == "main" || branch == "master"); - - // Repository URL might be None for local repos without remote - // This is acceptable behavior - } - - #[tokio::test] - async fn test_collect_git_info_with_remote() { - let temp_dir = TempDir::new().expect("Failed to create temp dir"); - let repo_path = create_test_git_repo(&temp_dir).await; - - // Add a remote origin - Command::new("git") - .args([ - "remote", - "add", - "origin", - "https://github.com/example/repo.git", - ]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to add remote"); - - let git_info = collect_git_info(&repo_path) - .await - .expect("Should collect git info from repo"); - - let remote_url_output = Command::new("git") - .args(["remote", "get-url", "origin"]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to read remote url"); - // Some dev environments rewrite remotes (e.g., force SSH), so compare against - // whatever URL Git reports instead of a fixed placeholder. - let expected_remote = String::from_utf8(remote_url_output.stdout) - .unwrap() - .trim() - .to_string(); - - // Should have repository URL - assert_eq!(git_info.repository_url, Some(expected_remote)); - } - - #[tokio::test] - async fn test_collect_git_info_detached_head() { - let temp_dir = TempDir::new().expect("Failed to create temp dir"); - let repo_path = create_test_git_repo(&temp_dir).await; - - // Get the current commit hash - let output = Command::new("git") - .args(["rev-parse", "HEAD"]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to get HEAD"); - let commit_hash = String::from_utf8(output.stdout).unwrap().trim().to_string(); - - // Checkout the commit directly (detached HEAD) - Command::new("git") - .args(["checkout", &commit_hash]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to checkout commit"); - - let git_info = collect_git_info(&repo_path) - .await - .expect("Should collect git info from repo"); - - // Should have commit hash - assert!(git_info.commit_hash.is_some()); - // Branch should be None for detached HEAD (since rev-parse --abbrev-ref HEAD returns "HEAD") - assert!(git_info.branch.is_none()); - } - - #[tokio::test] - async fn test_collect_git_info_with_branch() { - let temp_dir = TempDir::new().expect("Failed to create temp dir"); - let repo_path = create_test_git_repo(&temp_dir).await; - - // Create and checkout a new branch - Command::new("git") - .args(["checkout", "-b", "feature-branch"]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to create branch"); - - let git_info = collect_git_info(&repo_path) - .await - .expect("Should collect git info from repo"); - - // Should have the new branch name - assert_eq!(git_info.branch, Some("feature-branch".to_string())); - } - - #[tokio::test] - async fn test_get_has_changes_non_git_directory_returns_none() { - let temp_dir = TempDir::new().expect("Failed to create temp dir"); - assert_eq!(get_has_changes(temp_dir.path()).await, None); - } - - #[tokio::test] - async fn test_get_has_changes_clean_repo_returns_false() { - let temp_dir = TempDir::new().expect("Failed to create temp dir"); - let repo_path = create_test_git_repo(&temp_dir).await; - assert_eq!(get_has_changes(&repo_path).await, Some(false)); - } - - #[tokio::test] - async fn test_get_has_changes_with_tracked_change_returns_true() { - let temp_dir = TempDir::new().expect("Failed to create temp dir"); - let repo_path = create_test_git_repo(&temp_dir).await; - - fs::write(repo_path.join("test.txt"), "updated tracked file").expect("write tracked file"); - assert_eq!(get_has_changes(&repo_path).await, Some(true)); - } - - #[tokio::test] - async fn test_get_has_changes_with_untracked_change_returns_true() { - let temp_dir = TempDir::new().expect("Failed to create temp dir"); - let repo_path = create_test_git_repo(&temp_dir).await; - - fs::write(repo_path.join("new_file.txt"), "untracked").expect("write untracked file"); - assert_eq!(get_has_changes(&repo_path).await, Some(true)); - } - - #[tokio::test] - async fn test_get_git_working_tree_state_clean_repo() { - let temp_dir = TempDir::new().expect("Failed to create temp dir"); - let (repo_path, branch) = create_test_git_repo_with_remote(&temp_dir).await; - - let remote_sha = Command::new("git") - .args(["rev-parse", &format!("origin/{branch}")]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to rev-parse remote"); - let remote_sha = String::from_utf8(remote_sha.stdout) - .unwrap() - .trim() - .to_string(); - - let state = git_diff_to_remote(&repo_path) - .await - .expect("Should collect working tree state"); - assert_eq!(state.sha, GitSha::new(&remote_sha)); - assert!(state.diff.is_empty()); - } - - #[tokio::test] - async fn test_get_git_working_tree_state_with_changes() { - let temp_dir = TempDir::new().expect("Failed to create temp dir"); - let (repo_path, branch) = create_test_git_repo_with_remote(&temp_dir).await; - - let tracked = repo_path.join("test.txt"); - fs::write(&tracked, "modified").unwrap(); - fs::write(repo_path.join("untracked.txt"), "new").unwrap(); - - let remote_sha = Command::new("git") - .args(["rev-parse", &format!("origin/{branch}")]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to rev-parse remote"); - let remote_sha = String::from_utf8(remote_sha.stdout) - .unwrap() - .trim() - .to_string(); - - let state = git_diff_to_remote(&repo_path) - .await - .expect("Should collect working tree state"); - assert_eq!(state.sha, GitSha::new(&remote_sha)); - assert!(state.diff.contains("test.txt")); - assert!(state.diff.contains("untracked.txt")); - } - - #[tokio::test] - async fn test_get_git_working_tree_state_branch_fallback() { - let temp_dir = TempDir::new().expect("Failed to create temp dir"); - let (repo_path, _branch) = create_test_git_repo_with_remote(&temp_dir).await; - - Command::new("git") - .args(["checkout", "-b", "feature"]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to create feature branch"); - Command::new("git") - .args(["push", "-u", "origin", "feature"]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to push feature branch"); - - Command::new("git") - .args(["checkout", "-b", "local-branch"]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to create local branch"); - - let remote_sha = Command::new("git") - .args(["rev-parse", "origin/feature"]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to rev-parse remote"); - let remote_sha = String::from_utf8(remote_sha.stdout) - .unwrap() - .trim() - .to_string(); - - let state = git_diff_to_remote(&repo_path) - .await - .expect("Should collect working tree state"); - assert_eq!(state.sha, GitSha::new(&remote_sha)); - } - - #[test] - fn resolve_root_git_project_for_trust_returns_none_outside_repo() { - let tmp = TempDir::new().expect("tempdir"); - assert!(resolve_root_git_project_for_trust(tmp.path()).is_none()); - } - - #[tokio::test] - async fn resolve_root_git_project_for_trust_regular_repo_returns_repo_root() { - let temp_dir = TempDir::new().expect("Failed to create temp dir"); - let repo_path = create_test_git_repo(&temp_dir).await; - let expected = std::fs::canonicalize(&repo_path).unwrap(); - - assert_eq!( - resolve_root_git_project_for_trust(&repo_path), - Some(expected.clone()) - ); - let nested = repo_path.join("sub/dir"); - std::fs::create_dir_all(&nested).unwrap(); - assert_eq!(resolve_root_git_project_for_trust(&nested), Some(expected)); - } - - #[tokio::test] - async fn resolve_root_git_project_for_trust_detects_worktree_and_returns_main_root() { - let temp_dir = TempDir::new().expect("Failed to create temp dir"); - let repo_path = create_test_git_repo(&temp_dir).await; - - // Create a linked worktree - let wt_root = temp_dir.path().join("wt"); - let _ = std::process::Command::new("git") - .args([ - "worktree", - "add", - wt_root.to_str().unwrap(), - "-b", - "feature/x", - ]) - .current_dir(&repo_path) - .output() - .expect("git worktree add"); - - let expected = std::fs::canonicalize(&repo_path).ok(); - let got = resolve_root_git_project_for_trust(&wt_root) - .and_then(|p| std::fs::canonicalize(p).ok()); - assert_eq!(got, expected); - let nested = wt_root.join("nested/sub"); - std::fs::create_dir_all(&nested).unwrap(); - let got_nested = - resolve_root_git_project_for_trust(&nested).and_then(|p| std::fs::canonicalize(p).ok()); - assert_eq!(got_nested, expected); - } - - #[test] - fn resolve_root_git_project_for_trust_detects_worktree_pointer_without_git_command() { - let tmp = TempDir::new().expect("tempdir"); - let repo_root = tmp.path().join("repo"); - let common_dir = repo_root.join(".git"); - let worktree_git_dir = common_dir.join("worktrees").join("feature-x"); - let worktree_root = tmp.path().join("wt"); - std::fs::create_dir_all(&worktree_git_dir).unwrap(); - std::fs::create_dir_all(&worktree_root).unwrap(); - std::fs::create_dir_all(worktree_root.join("nested")).unwrap(); - std::fs::write( - worktree_root.join(".git"), - format!("gitdir: {}\n", worktree_git_dir.display()), - ) - .unwrap(); - - let expected = std::fs::canonicalize(&repo_root).unwrap(); - assert_eq!( - resolve_root_git_project_for_trust(&worktree_root), - Some(expected.clone()) - ); - assert_eq!( - resolve_root_git_project_for_trust(&worktree_root.join("nested")), - Some(expected) - ); - } - - #[test] - fn resolve_root_git_project_for_trust_non_worktrees_gitdir_returns_none() { - let tmp = TempDir::new().expect("tempdir"); - let proj = tmp.path().join("proj"); - std::fs::create_dir_all(proj.join("nested")).unwrap(); - - // `.git` is a file but does not point to a worktrees path - std::fs::write( - proj.join(".git"), - format!( - "gitdir: {}\n", - tmp.path().join("some/other/location").display() - ), - ) - .unwrap(); - - assert!(resolve_root_git_project_for_trust(&proj).is_none()); - assert!(resolve_root_git_project_for_trust(&proj.join("nested")).is_none()); - } - - #[tokio::test] - async fn test_get_git_working_tree_state_unpushed_commit() { - let temp_dir = TempDir::new().expect("Failed to create temp dir"); - let (repo_path, branch) = create_test_git_repo_with_remote(&temp_dir).await; - - let remote_sha = Command::new("git") - .args(["rev-parse", &format!("origin/{branch}")]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to rev-parse remote"); - let remote_sha = String::from_utf8(remote_sha.stdout) - .unwrap() - .trim() - .to_string(); - - fs::write(repo_path.join("test.txt"), "updated").unwrap(); - Command::new("git") - .args(["add", "test.txt"]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to add file"); - Command::new("git") - .args(["commit", "-m", "local change"]) - .current_dir(&repo_path) - .output() - .await - .expect("Failed to commit"); - - let state = git_diff_to_remote(&repo_path) - .await - .expect("Should collect working tree state"); - assert_eq!(state.sha, GitSha::new(&remote_sha)); - assert!(state.diff.contains("updated")); - } - - #[test] - fn test_git_info_serialization() { - let git_info = GitInfo { - commit_hash: Some("abc123def456".to_string()), - branch: Some("main".to_string()), - repository_url: Some("https://github.com/example/repo.git".to_string()), - }; - - let json = serde_json::to_string(&git_info).expect("Should serialize GitInfo"); - let parsed: serde_json::Value = serde_json::from_str(&json).expect("Should parse JSON"); - - assert_eq!(parsed["commit_hash"], "abc123def456"); - assert_eq!(parsed["branch"], "main"); - assert_eq!( - parsed["repository_url"], - "https://github.com/example/repo.git" - ); - } - - #[test] - fn test_git_info_serialization_with_nones() { - let git_info = GitInfo { - commit_hash: None, - branch: None, - repository_url: None, - }; - - let json = serde_json::to_string(&git_info).expect("Should serialize GitInfo"); - let parsed: serde_json::Value = serde_json::from_str(&json).expect("Should parse JSON"); - - // Fields with None values should be omitted due to skip_serializing_if - assert!(!parsed.as_object().unwrap().contains_key("commit_hash")); - assert!(!parsed.as_object().unwrap().contains_key("branch")); - assert!(!parsed.as_object().unwrap().contains_key("repository_url")); - } -} +#[path = "git_info_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/git_info_tests.rs b/codex-rs/core/src/git_info_tests.rs new file mode 100644 index 0000000000..73714ce42f --- /dev/null +++ b/codex-rs/core/src/git_info_tests.rs @@ -0,0 +1,592 @@ +use super::*; + +use core_test_support::skip_if_sandbox; +use std::fs; +use std::path::PathBuf; +use tempfile::TempDir; + +// Helper function to create a test git repository +async fn create_test_git_repo(temp_dir: &TempDir) -> PathBuf { + let repo_path = temp_dir.path().join("repo"); + fs::create_dir(&repo_path).expect("Failed to create repo dir"); + let envs = vec![ + ("GIT_CONFIG_GLOBAL", "/dev/null"), + ("GIT_CONFIG_NOSYSTEM", "1"), + ]; + + // Initialize git repo + Command::new("git") + .envs(envs.clone()) + .args(["init"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to init git repo"); + + // Configure git user (required for commits) + Command::new("git") + .envs(envs.clone()) + .args(["config", "user.name", "Test User"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to set git user name"); + + Command::new("git") + .envs(envs.clone()) + .args(["config", "user.email", "test@example.com"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to set git user email"); + + // Create a test file and commit it + let test_file = repo_path.join("test.txt"); + fs::write(&test_file, "test content").expect("Failed to write test file"); + + Command::new("git") + .envs(envs.clone()) + .args(["add", "."]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to add files"); + + Command::new("git") + .envs(envs.clone()) + .args(["commit", "-m", "Initial commit"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to commit"); + + repo_path +} + +#[tokio::test] +async fn test_recent_commits_non_git_directory_returns_empty() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let entries = recent_commits(temp_dir.path(), 10).await; + assert!(entries.is_empty(), "expected no commits outside a git repo"); +} + +#[tokio::test] +async fn test_recent_commits_orders_and_limits() { + skip_if_sandbox!(); + use tokio::time::Duration; + use tokio::time::sleep; + + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let repo_path = create_test_git_repo(&temp_dir).await; + + // Make three distinct commits with small delays to ensure ordering by timestamp. + fs::write(repo_path.join("file.txt"), "one").unwrap(); + Command::new("git") + .args(["add", "file.txt"]) + .current_dir(&repo_path) + .output() + .await + .expect("git add"); + Command::new("git") + .args(["commit", "-m", "first change"]) + .current_dir(&repo_path) + .output() + .await + .expect("git commit 1"); + + sleep(Duration::from_millis(1100)).await; + + fs::write(repo_path.join("file.txt"), "two").unwrap(); + Command::new("git") + .args(["add", "file.txt"]) + .current_dir(&repo_path) + .output() + .await + .expect("git add 2"); + Command::new("git") + .args(["commit", "-m", "second change"]) + .current_dir(&repo_path) + .output() + .await + .expect("git commit 2"); + + sleep(Duration::from_millis(1100)).await; + + fs::write(repo_path.join("file.txt"), "three").unwrap(); + Command::new("git") + .args(["add", "file.txt"]) + .current_dir(&repo_path) + .output() + .await + .expect("git add 3"); + Command::new("git") + .args(["commit", "-m", "third change"]) + .current_dir(&repo_path) + .output() + .await + .expect("git commit 3"); + + // Request the latest 3 commits; should be our three changes in reverse time order. + let entries = recent_commits(&repo_path, 3).await; + assert_eq!(entries.len(), 3); + assert_eq!(entries[0].subject, "third change"); + assert_eq!(entries[1].subject, "second change"); + assert_eq!(entries[2].subject, "first change"); + // Basic sanity on SHA formatting + for e in entries { + assert!(e.sha.len() >= 7 && e.sha.chars().all(|c| c.is_ascii_hexdigit())); + } +} + +async fn create_test_git_repo_with_remote(temp_dir: &TempDir) -> (PathBuf, String) { + let repo_path = create_test_git_repo(temp_dir).await; + let remote_path = temp_dir.path().join("remote.git"); + + Command::new("git") + .args(["init", "--bare", remote_path.to_str().unwrap()]) + .output() + .await + .expect("Failed to init bare remote"); + + Command::new("git") + .args(["remote", "add", "origin", remote_path.to_str().unwrap()]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to add remote"); + + let output = Command::new("git") + .args(["rev-parse", "--abbrev-ref", "HEAD"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to get branch"); + let branch = String::from_utf8(output.stdout).unwrap().trim().to_string(); + + Command::new("git") + .args(["push", "-u", "origin", &branch]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to push initial commit"); + + (repo_path, branch) +} + +#[tokio::test] +async fn test_collect_git_info_non_git_directory() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let result = collect_git_info(temp_dir.path()).await; + assert!(result.is_none()); +} + +#[tokio::test] +async fn test_collect_git_info_git_repository() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let repo_path = create_test_git_repo(&temp_dir).await; + + let git_info = collect_git_info(&repo_path) + .await + .expect("Should collect git info from repo"); + + // Should have commit hash + assert!(git_info.commit_hash.is_some()); + let commit_hash = git_info.commit_hash.unwrap(); + assert_eq!(commit_hash.len(), 40); // SHA-1 hash should be 40 characters + assert!(commit_hash.chars().all(|c| c.is_ascii_hexdigit())); + + // Should have branch (likely "main" or "master") + assert!(git_info.branch.is_some()); + let branch = git_info.branch.unwrap(); + assert!(branch == "main" || branch == "master"); + + // Repository URL might be None for local repos without remote + // This is acceptable behavior +} + +#[tokio::test] +async fn test_collect_git_info_with_remote() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let repo_path = create_test_git_repo(&temp_dir).await; + + // Add a remote origin + Command::new("git") + .args([ + "remote", + "add", + "origin", + "https://github.com/example/repo.git", + ]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to add remote"); + + let git_info = collect_git_info(&repo_path) + .await + .expect("Should collect git info from repo"); + + let remote_url_output = Command::new("git") + .args(["remote", "get-url", "origin"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to read remote url"); + // Some dev environments rewrite remotes (e.g., force SSH), so compare against + // whatever URL Git reports instead of a fixed placeholder. + let expected_remote = String::from_utf8(remote_url_output.stdout) + .unwrap() + .trim() + .to_string(); + + // Should have repository URL + assert_eq!(git_info.repository_url, Some(expected_remote)); +} + +#[tokio::test] +async fn test_collect_git_info_detached_head() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let repo_path = create_test_git_repo(&temp_dir).await; + + // Get the current commit hash + let output = Command::new("git") + .args(["rev-parse", "HEAD"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to get HEAD"); + let commit_hash = String::from_utf8(output.stdout).unwrap().trim().to_string(); + + // Checkout the commit directly (detached HEAD) + Command::new("git") + .args(["checkout", &commit_hash]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to checkout commit"); + + let git_info = collect_git_info(&repo_path) + .await + .expect("Should collect git info from repo"); + + // Should have commit hash + assert!(git_info.commit_hash.is_some()); + // Branch should be None for detached HEAD (since rev-parse --abbrev-ref HEAD returns "HEAD") + assert!(git_info.branch.is_none()); +} + +#[tokio::test] +async fn test_collect_git_info_with_branch() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let repo_path = create_test_git_repo(&temp_dir).await; + + // Create and checkout a new branch + Command::new("git") + .args(["checkout", "-b", "feature-branch"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to create branch"); + + let git_info = collect_git_info(&repo_path) + .await + .expect("Should collect git info from repo"); + + // Should have the new branch name + assert_eq!(git_info.branch, Some("feature-branch".to_string())); +} + +#[tokio::test] +async fn test_get_has_changes_non_git_directory_returns_none() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + assert_eq!(get_has_changes(temp_dir.path()).await, None); +} + +#[tokio::test] +async fn test_get_has_changes_clean_repo_returns_false() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let repo_path = create_test_git_repo(&temp_dir).await; + assert_eq!(get_has_changes(&repo_path).await, Some(false)); +} + +#[tokio::test] +async fn test_get_has_changes_with_tracked_change_returns_true() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let repo_path = create_test_git_repo(&temp_dir).await; + + fs::write(repo_path.join("test.txt"), "updated tracked file").expect("write tracked file"); + assert_eq!(get_has_changes(&repo_path).await, Some(true)); +} + +#[tokio::test] +async fn test_get_has_changes_with_untracked_change_returns_true() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let repo_path = create_test_git_repo(&temp_dir).await; + + fs::write(repo_path.join("new_file.txt"), "untracked").expect("write untracked file"); + assert_eq!(get_has_changes(&repo_path).await, Some(true)); +} + +#[tokio::test] +async fn test_get_git_working_tree_state_clean_repo() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let (repo_path, branch) = create_test_git_repo_with_remote(&temp_dir).await; + + let remote_sha = Command::new("git") + .args(["rev-parse", &format!("origin/{branch}")]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to rev-parse remote"); + let remote_sha = String::from_utf8(remote_sha.stdout) + .unwrap() + .trim() + .to_string(); + + let state = git_diff_to_remote(&repo_path) + .await + .expect("Should collect working tree state"); + assert_eq!(state.sha, GitSha::new(&remote_sha)); + assert!(state.diff.is_empty()); +} + +#[tokio::test] +async fn test_get_git_working_tree_state_with_changes() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let (repo_path, branch) = create_test_git_repo_with_remote(&temp_dir).await; + + let tracked = repo_path.join("test.txt"); + fs::write(&tracked, "modified").unwrap(); + fs::write(repo_path.join("untracked.txt"), "new").unwrap(); + + let remote_sha = Command::new("git") + .args(["rev-parse", &format!("origin/{branch}")]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to rev-parse remote"); + let remote_sha = String::from_utf8(remote_sha.stdout) + .unwrap() + .trim() + .to_string(); + + let state = git_diff_to_remote(&repo_path) + .await + .expect("Should collect working tree state"); + assert_eq!(state.sha, GitSha::new(&remote_sha)); + assert!(state.diff.contains("test.txt")); + assert!(state.diff.contains("untracked.txt")); +} + +#[tokio::test] +async fn test_get_git_working_tree_state_branch_fallback() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let (repo_path, _branch) = create_test_git_repo_with_remote(&temp_dir).await; + + Command::new("git") + .args(["checkout", "-b", "feature"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to create feature branch"); + Command::new("git") + .args(["push", "-u", "origin", "feature"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to push feature branch"); + + Command::new("git") + .args(["checkout", "-b", "local-branch"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to create local branch"); + + let remote_sha = Command::new("git") + .args(["rev-parse", "origin/feature"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to rev-parse remote"); + let remote_sha = String::from_utf8(remote_sha.stdout) + .unwrap() + .trim() + .to_string(); + + let state = git_diff_to_remote(&repo_path) + .await + .expect("Should collect working tree state"); + assert_eq!(state.sha, GitSha::new(&remote_sha)); +} + +#[test] +fn resolve_root_git_project_for_trust_returns_none_outside_repo() { + let tmp = TempDir::new().expect("tempdir"); + assert!(resolve_root_git_project_for_trust(tmp.path()).is_none()); +} + +#[tokio::test] +async fn resolve_root_git_project_for_trust_regular_repo_returns_repo_root() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let repo_path = create_test_git_repo(&temp_dir).await; + let expected = std::fs::canonicalize(&repo_path).unwrap(); + + assert_eq!( + resolve_root_git_project_for_trust(&repo_path), + Some(expected.clone()) + ); + let nested = repo_path.join("sub/dir"); + std::fs::create_dir_all(&nested).unwrap(); + assert_eq!(resolve_root_git_project_for_trust(&nested), Some(expected)); +} + +#[tokio::test] +async fn resolve_root_git_project_for_trust_detects_worktree_and_returns_main_root() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let repo_path = create_test_git_repo(&temp_dir).await; + + // Create a linked worktree + let wt_root = temp_dir.path().join("wt"); + let _ = std::process::Command::new("git") + .args([ + "worktree", + "add", + wt_root.to_str().unwrap(), + "-b", + "feature/x", + ]) + .current_dir(&repo_path) + .output() + .expect("git worktree add"); + + let expected = std::fs::canonicalize(&repo_path).ok(); + let got = + resolve_root_git_project_for_trust(&wt_root).and_then(|p| std::fs::canonicalize(p).ok()); + assert_eq!(got, expected); + let nested = wt_root.join("nested/sub"); + std::fs::create_dir_all(&nested).unwrap(); + let got_nested = + resolve_root_git_project_for_trust(&nested).and_then(|p| std::fs::canonicalize(p).ok()); + assert_eq!(got_nested, expected); +} + +#[test] +fn resolve_root_git_project_for_trust_detects_worktree_pointer_without_git_command() { + let tmp = TempDir::new().expect("tempdir"); + let repo_root = tmp.path().join("repo"); + let common_dir = repo_root.join(".git"); + let worktree_git_dir = common_dir.join("worktrees").join("feature-x"); + let worktree_root = tmp.path().join("wt"); + std::fs::create_dir_all(&worktree_git_dir).unwrap(); + std::fs::create_dir_all(&worktree_root).unwrap(); + std::fs::create_dir_all(worktree_root.join("nested")).unwrap(); + std::fs::write( + worktree_root.join(".git"), + format!("gitdir: {}\n", worktree_git_dir.display()), + ) + .unwrap(); + + let expected = std::fs::canonicalize(&repo_root).unwrap(); + assert_eq!( + resolve_root_git_project_for_trust(&worktree_root), + Some(expected.clone()) + ); + assert_eq!( + resolve_root_git_project_for_trust(&worktree_root.join("nested")), + Some(expected) + ); +} + +#[test] +fn resolve_root_git_project_for_trust_non_worktrees_gitdir_returns_none() { + let tmp = TempDir::new().expect("tempdir"); + let proj = tmp.path().join("proj"); + std::fs::create_dir_all(proj.join("nested")).unwrap(); + + // `.git` is a file but does not point to a worktrees path + std::fs::write( + proj.join(".git"), + format!( + "gitdir: {}\n", + tmp.path().join("some/other/location").display() + ), + ) + .unwrap(); + + assert!(resolve_root_git_project_for_trust(&proj).is_none()); + assert!(resolve_root_git_project_for_trust(&proj.join("nested")).is_none()); +} + +#[tokio::test] +async fn test_get_git_working_tree_state_unpushed_commit() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let (repo_path, branch) = create_test_git_repo_with_remote(&temp_dir).await; + + let remote_sha = Command::new("git") + .args(["rev-parse", &format!("origin/{branch}")]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to rev-parse remote"); + let remote_sha = String::from_utf8(remote_sha.stdout) + .unwrap() + .trim() + .to_string(); + + fs::write(repo_path.join("test.txt"), "updated").unwrap(); + Command::new("git") + .args(["add", "test.txt"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to add file"); + Command::new("git") + .args(["commit", "-m", "local change"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to commit"); + + let state = git_diff_to_remote(&repo_path) + .await + .expect("Should collect working tree state"); + assert_eq!(state.sha, GitSha::new(&remote_sha)); + assert!(state.diff.contains("updated")); +} + +#[test] +fn test_git_info_serialization() { + let git_info = GitInfo { + commit_hash: Some("abc123def456".to_string()), + branch: Some("main".to_string()), + repository_url: Some("https://github.com/example/repo.git".to_string()), + }; + + let json = serde_json::to_string(&git_info).expect("Should serialize GitInfo"); + let parsed: serde_json::Value = serde_json::from_str(&json).expect("Should parse JSON"); + + assert_eq!(parsed["commit_hash"], "abc123def456"); + assert_eq!(parsed["branch"], "main"); + assert_eq!( + parsed["repository_url"], + "https://github.com/example/repo.git" + ); +} + +#[test] +fn test_git_info_serialization_with_nones() { + let git_info = GitInfo { + commit_hash: None, + branch: None, + repository_url: None, + }; + + let json = serde_json::to_string(&git_info).expect("Should serialize GitInfo"); + let parsed: serde_json::Value = serde_json::from_str(&json).expect("Should parse JSON"); + + // Fields with None values should be omitted due to skip_serializing_if + assert!(!parsed.as_object().unwrap().contains_key("commit_hash")); + assert!(!parsed.as_object().unwrap().contains_key("branch")); + assert!(!parsed.as_object().unwrap().contains_key("repository_url")); +} diff --git a/codex-rs/core/src/instructions/user_instructions.rs b/codex-rs/core/src/instructions/user_instructions.rs index 09e6e4c2f7..a0389c9ff8 100644 --- a/codex-rs/core/src/instructions/user_instructions.rs +++ b/codex-rs/core/src/instructions/user_instructions.rs @@ -53,73 +53,5 @@ impl From for ResponseItem { } #[cfg(test)] -mod tests { - use super::*; - use codex_protocol::models::ContentItem; - use pretty_assertions::assert_eq; - - #[test] - fn test_user_instructions() { - let user_instructions = UserInstructions { - directory: "test_directory".to_string(), - text: "test_text".to_string(), - }; - let response_item: ResponseItem = user_instructions.into(); - - let ResponseItem::Message { role, content, .. } = response_item else { - panic!("expected ResponseItem::Message"); - }; - - assert_eq!(role, "user"); - - let [ContentItem::InputText { text }] = content.as_slice() else { - panic!("expected one InputText content item"); - }; - - assert_eq!( - text, - "# AGENTS.md instructions for test_directory\n\n\ntest_text\n", - ); - } - - #[test] - fn test_is_user_instructions() { - assert!(AGENTS_MD_FRAGMENT.matches_text( - "# AGENTS.md instructions for test_directory\n\n\ntest_text\n" - )); - assert!(!AGENTS_MD_FRAGMENT.matches_text("test_text")); - } - - #[test] - fn test_skill_instructions() { - let skill_instructions = SkillInstructions { - name: "demo-skill".to_string(), - path: "skills/demo/SKILL.md".to_string(), - contents: "body".to_string(), - }; - let response_item: ResponseItem = skill_instructions.into(); - - let ResponseItem::Message { role, content, .. } = response_item else { - panic!("expected ResponseItem::Message"); - }; - - assert_eq!(role, "user"); - - let [ContentItem::InputText { text }] = content.as_slice() else { - panic!("expected one InputText content item"); - }; - - assert_eq!( - text, - "\ndemo-skill\nskills/demo/SKILL.md\nbody\n", - ); - } - - #[test] - fn test_is_skill_instructions() { - assert!(SKILL_FRAGMENT.matches_text( - "\ndemo-skill\nskills/demo/SKILL.md\nbody\n" - )); - assert!(!SKILL_FRAGMENT.matches_text("regular text")); - } -} +#[path = "user_instructions_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/instructions/user_instructions_tests.rs b/codex-rs/core/src/instructions/user_instructions_tests.rs new file mode 100644 index 0000000000..58442600a8 --- /dev/null +++ b/codex-rs/core/src/instructions/user_instructions_tests.rs @@ -0,0 +1,68 @@ +use super::*; +use codex_protocol::models::ContentItem; +use pretty_assertions::assert_eq; + +#[test] +fn test_user_instructions() { + let user_instructions = UserInstructions { + directory: "test_directory".to_string(), + text: "test_text".to_string(), + }; + let response_item: ResponseItem = user_instructions.into(); + + let ResponseItem::Message { role, content, .. } = response_item else { + panic!("expected ResponseItem::Message"); + }; + + assert_eq!(role, "user"); + + let [ContentItem::InputText { text }] = content.as_slice() else { + panic!("expected one InputText content item"); + }; + + assert_eq!( + text, + "# AGENTS.md instructions for test_directory\n\n\ntest_text\n", + ); +} + +#[test] +fn test_is_user_instructions() { + assert!(AGENTS_MD_FRAGMENT.matches_text( + "# AGENTS.md instructions for test_directory\n\n\ntest_text\n" + )); + assert!(!AGENTS_MD_FRAGMENT.matches_text("test_text")); +} + +#[test] +fn test_skill_instructions() { + let skill_instructions = SkillInstructions { + name: "demo-skill".to_string(), + path: "skills/demo/SKILL.md".to_string(), + contents: "body".to_string(), + }; + let response_item: ResponseItem = skill_instructions.into(); + + let ResponseItem::Message { role, content, .. } = response_item else { + panic!("expected ResponseItem::Message"); + }; + + assert_eq!(role, "user"); + + let [ContentItem::InputText { text }] = content.as_slice() else { + panic!("expected one InputText content item"); + }; + + assert_eq!( + text, + "\ndemo-skill\nskills/demo/SKILL.md\nbody\n", + ); +} + +#[test] +fn test_is_skill_instructions() { + assert!(SKILL_FRAGMENT.matches_text( + "\ndemo-skill\nskills/demo/SKILL.md\nbody\n" + )); + assert!(!SKILL_FRAGMENT.matches_text("regular text")); +} diff --git a/codex-rs/core/src/landlock.rs b/codex-rs/core/src/landlock.rs index 1072854883..ea65595e11 100644 --- a/codex-rs/core/src/landlock.rs +++ b/codex-rs/core/src/landlock.rs @@ -148,75 +148,5 @@ pub(crate) fn create_linux_sandbox_command_args( } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - - #[test] - fn legacy_landlock_flag_is_included_when_requested() { - let command = vec!["/bin/true".to_string()]; - let cwd = Path::new("/tmp"); - - let default_bwrap = create_linux_sandbox_command_args(command.clone(), cwd, false, false); - assert_eq!( - default_bwrap.contains(&"--use-legacy-landlock".to_string()), - false - ); - - let legacy_landlock = create_linux_sandbox_command_args(command, cwd, true, false); - assert_eq!( - legacy_landlock.contains(&"--use-legacy-landlock".to_string()), - true - ); - } - - #[test] - fn proxy_flag_is_included_when_requested() { - let command = vec!["/bin/true".to_string()]; - let cwd = Path::new("/tmp"); - - let args = create_linux_sandbox_command_args(command, cwd, true, true); - assert_eq!( - args.contains(&"--allow-network-for-proxy".to_string()), - true - ); - } - - #[test] - fn split_policy_flags_are_included() { - let command = vec!["/bin/true".to_string()]; - let cwd = Path::new("/tmp"); - let sandbox_policy = SandboxPolicy::new_read_only_policy(); - let file_system_sandbox_policy = FileSystemSandboxPolicy::from(&sandbox_policy); - let network_sandbox_policy = NetworkSandboxPolicy::from(&sandbox_policy); - - let args = create_linux_sandbox_command_args_for_policies( - command, - &sandbox_policy, - &file_system_sandbox_policy, - network_sandbox_policy, - cwd, - true, - false, - ); - - assert_eq!( - args.windows(2).any(|window| { - window[0] == "--file-system-sandbox-policy" && !window[1].is_empty() - }), - true - ); - assert_eq!( - args.windows(2) - .any(|window| window[0] == "--network-sandbox-policy" - && window[1] == "\"restricted\""), - true - ); - } - - #[test] - fn proxy_network_requires_managed_requirements() { - assert_eq!(allow_network_for_proxy(false), false); - assert_eq!(allow_network_for_proxy(true), true); - } -} +#[path = "landlock_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/landlock_tests.rs b/codex-rs/core/src/landlock_tests.rs new file mode 100644 index 0000000000..75b887e267 --- /dev/null +++ b/codex-rs/core/src/landlock_tests.rs @@ -0,0 +1,68 @@ +use super::*; +use pretty_assertions::assert_eq; + +#[test] +fn legacy_landlock_flag_is_included_when_requested() { + let command = vec!["/bin/true".to_string()]; + let cwd = Path::new("/tmp"); + + let default_bwrap = create_linux_sandbox_command_args(command.clone(), cwd, false, false); + assert_eq!( + default_bwrap.contains(&"--use-legacy-landlock".to_string()), + false + ); + + let legacy_landlock = create_linux_sandbox_command_args(command, cwd, true, false); + assert_eq!( + legacy_landlock.contains(&"--use-legacy-landlock".to_string()), + true + ); +} + +#[test] +fn proxy_flag_is_included_when_requested() { + let command = vec!["/bin/true".to_string()]; + let cwd = Path::new("/tmp"); + + let args = create_linux_sandbox_command_args(command, cwd, true, true); + assert_eq!( + args.contains(&"--allow-network-for-proxy".to_string()), + true + ); +} + +#[test] +fn split_policy_flags_are_included() { + let command = vec!["/bin/true".to_string()]; + let cwd = Path::new("/tmp"); + let sandbox_policy = SandboxPolicy::new_read_only_policy(); + let file_system_sandbox_policy = FileSystemSandboxPolicy::from(&sandbox_policy); + let network_sandbox_policy = NetworkSandboxPolicy::from(&sandbox_policy); + + let args = create_linux_sandbox_command_args_for_policies( + command, + &sandbox_policy, + &file_system_sandbox_policy, + network_sandbox_policy, + cwd, + true, + false, + ); + + assert_eq!( + args.windows(2) + .any(|window| { window[0] == "--file-system-sandbox-policy" && !window[1].is_empty() }), + true + ); + assert_eq!( + args.windows(2) + .any(|window| window[0] == "--network-sandbox-policy" && window[1] == "\"restricted\""), + true + ); +} + +#[test] +fn proxy_network_requires_managed_requirements() { + assert_eq!(allow_network_for_proxy(false), false); + assert_eq!(allow_network_for_proxy(true), true); +} diff --git a/codex-rs/core/src/mcp/mod.rs b/codex-rs/core/src/mcp/mod.rs index ed93106fe9..3140f5bcff 100644 --- a/codex-rs/core/src/mcp/mod.rs +++ b/codex-rs/core/src/mcp/mod.rs @@ -469,344 +469,5 @@ pub(crate) async fn collect_mcp_snapshot_from_manager( } #[cfg(test)] -mod tests { - use super::*; - use crate::config::CONFIG_TOML_FILE; - use crate::config::ConfigBuilder; - use crate::plugins::AppConnectorId; - use crate::plugins::PluginCapabilitySummary; - use pretty_assertions::assert_eq; - use std::fs; - use std::path::Path; - use toml::Value; - - fn write_file(path: &Path, contents: &str) { - fs::create_dir_all(path.parent().expect("file should have a parent")).unwrap(); - fs::write(path, contents).unwrap(); - } - - fn plugin_config_toml() -> String { - let mut root = toml::map::Map::new(); - - let mut features = toml::map::Map::new(); - features.insert("plugins".to_string(), Value::Boolean(true)); - root.insert("features".to_string(), Value::Table(features)); - - let mut plugin = toml::map::Map::new(); - plugin.insert("enabled".to_string(), Value::Boolean(true)); - - let mut plugins = toml::map::Map::new(); - plugins.insert("sample@test".to_string(), Value::Table(plugin)); - root.insert("plugins".to_string(), Value::Table(plugins)); - - toml::to_string(&Value::Table(root)).expect("plugin test config should serialize") - } - - fn make_tool(name: &str) -> Tool { - Tool { - name: name.to_string(), - title: None, - description: None, - input_schema: serde_json::json!({"type": "object", "properties": {}}), - output_schema: None, - annotations: None, - icons: None, - meta: None, - } - } - - #[test] - fn split_qualified_tool_name_returns_server_and_tool() { - assert_eq!( - split_qualified_tool_name("mcp__alpha__do_thing"), - Some(("alpha".to_string(), "do_thing".to_string())) - ); - } - - #[test] - fn split_qualified_tool_name_rejects_invalid_names() { - assert_eq!(split_qualified_tool_name("other__alpha__do_thing"), None); - assert_eq!(split_qualified_tool_name("mcp__alpha__"), None); - } - - #[test] - fn group_tools_by_server_strips_prefix_and_groups() { - let mut tools = HashMap::new(); - tools.insert("mcp__alpha__do_thing".to_string(), make_tool("do_thing")); - tools.insert( - "mcp__alpha__nested__op".to_string(), - make_tool("nested__op"), - ); - tools.insert("mcp__beta__do_other".to_string(), make_tool("do_other")); - - let mut expected_alpha = HashMap::new(); - expected_alpha.insert("do_thing".to_string(), make_tool("do_thing")); - expected_alpha.insert("nested__op".to_string(), make_tool("nested__op")); - - let mut expected_beta = HashMap::new(); - expected_beta.insert("do_other".to_string(), make_tool("do_other")); - - let mut expected = HashMap::new(); - expected.insert("alpha".to_string(), expected_alpha); - expected.insert("beta".to_string(), expected_beta); - - assert_eq!(group_tools_by_server(&tools), expected); - } - - #[test] - fn tool_plugin_provenance_collects_app_and_mcp_sources() { - let provenance = ToolPluginProvenance::from_capability_summaries(&[ - PluginCapabilitySummary { - display_name: "alpha-plugin".to_string(), - app_connector_ids: vec![AppConnectorId("connector_example".to_string())], - mcp_server_names: vec!["alpha".to_string()], - ..PluginCapabilitySummary::default() - }, - PluginCapabilitySummary { - display_name: "beta-plugin".to_string(), - app_connector_ids: vec![ - AppConnectorId("connector_example".to_string()), - AppConnectorId("connector_gmail".to_string()), - ], - mcp_server_names: vec!["beta".to_string()], - ..PluginCapabilitySummary::default() - }, - ]); - - assert_eq!( - provenance, - ToolPluginProvenance { - plugin_display_names_by_connector_id: HashMap::from([ - ( - "connector_example".to_string(), - vec!["alpha-plugin".to_string(), "beta-plugin".to_string()], - ), - ( - "connector_gmail".to_string(), - vec!["beta-plugin".to_string()], - ), - ]), - plugin_display_names_by_mcp_server_name: HashMap::from([ - ("alpha".to_string(), vec!["alpha-plugin".to_string()]), - ("beta".to_string(), vec!["beta-plugin".to_string()]), - ]), - } - ); - } - - #[test] - fn codex_apps_mcp_url_for_default_gateway_keeps_existing_paths() { - assert_eq!( - codex_apps_mcp_url_for_gateway( - "https://chatgpt.com/backend-api", - CodexAppsMcpGateway::LegacyMCPGateway - ), - "https://chatgpt.com/backend-api/wham/apps" - ); - assert_eq!( - codex_apps_mcp_url_for_gateway( - "https://chat.openai.com", - CodexAppsMcpGateway::LegacyMCPGateway - ), - "https://chat.openai.com/backend-api/wham/apps" - ); - assert_eq!( - codex_apps_mcp_url_for_gateway( - "http://localhost:8080/api/codex", - CodexAppsMcpGateway::LegacyMCPGateway - ), - "http://localhost:8080/api/codex/apps" - ); - assert_eq!( - codex_apps_mcp_url_for_gateway( - "http://localhost:8080", - CodexAppsMcpGateway::LegacyMCPGateway - ), - "http://localhost:8080/api/codex/apps" - ); - } - - #[test] - fn codex_apps_mcp_url_for_gateway_uses_openai_connectors_gateway() { - let expected_url = format!("{OPENAI_CONNECTORS_MCP_BASE_URL}{OPENAI_CONNECTORS_MCP_PATH}"); - - assert_eq!( - codex_apps_mcp_url_for_gateway( - "https://chatgpt.com/backend-api", - CodexAppsMcpGateway::MCPGateway - ), - expected_url.as_str() - ); - assert_eq!( - codex_apps_mcp_url_for_gateway( - "https://chat.openai.com", - CodexAppsMcpGateway::MCPGateway - ), - expected_url.as_str() - ); - assert_eq!( - codex_apps_mcp_url_for_gateway( - "http://localhost:8080/api/codex", - CodexAppsMcpGateway::MCPGateway - ), - expected_url.as_str() - ); - assert_eq!( - codex_apps_mcp_url_for_gateway( - "http://localhost:8080", - CodexAppsMcpGateway::MCPGateway - ), - expected_url.as_str() - ); - } - - #[test] - fn codex_apps_mcp_url_uses_default_gateway_when_feature_is_disabled() { - let mut config = crate::config::test_config(); - config.chatgpt_base_url = "https://chatgpt.com".to_string(); - - assert_eq!( - codex_apps_mcp_url(&config), - "https://chatgpt.com/backend-api/wham/apps" - ); - } - - #[test] - fn codex_apps_mcp_url_uses_openai_connectors_gateway_when_feature_is_enabled() { - let mut config = crate::config::test_config(); - config.chatgpt_base_url = "https://chatgpt.com".to_string(); - config - .features - .enable(Feature::AppsMcpGateway) - .expect("test config should allow apps gateway"); - - assert_eq!( - codex_apps_mcp_url(&config), - format!("{OPENAI_CONNECTORS_MCP_BASE_URL}{OPENAI_CONNECTORS_MCP_PATH}") - ); - } - - #[test] - fn codex_apps_server_config_switches_gateway_with_flags() { - let mut config = crate::config::test_config(); - config.chatgpt_base_url = "https://chatgpt.com".to_string(); - - let mut servers = with_codex_apps_mcp(HashMap::new(), false, None, &config); - assert!(!servers.contains_key(CODEX_APPS_MCP_SERVER_NAME)); - - config - .features - .enable(Feature::Apps) - .expect("test config should allow apps"); - - servers = with_codex_apps_mcp(servers, true, None, &config); - let server = servers - .get(CODEX_APPS_MCP_SERVER_NAME) - .expect("codex apps should be present when apps is enabled"); - let url = match &server.transport { - McpServerTransportConfig::StreamableHttp { url, .. } => url, - _ => panic!("expected streamable http transport for codex apps"), - }; - - assert_eq!(url, "https://chatgpt.com/backend-api/wham/apps"); - - config - .features - .enable(Feature::AppsMcpGateway) - .expect("test config should allow apps gateway"); - servers = with_codex_apps_mcp(servers, true, None, &config); - let server = servers - .get(CODEX_APPS_MCP_SERVER_NAME) - .expect("codex apps should remain present when apps stays enabled"); - let url = match &server.transport { - McpServerTransportConfig::StreamableHttp { url, .. } => url, - _ => panic!("expected streamable http transport for codex apps"), - }; - - let expected_url = format!("{OPENAI_CONNECTORS_MCP_BASE_URL}{OPENAI_CONNECTORS_MCP_PATH}"); - assert_eq!(url, &expected_url); - } - - #[tokio::test] - async fn effective_mcp_servers_include_plugins_without_overriding_user_config() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let plugin_root = codex_home - .path() - .join("plugins/cache") - .join("test/sample/local"); - write_file( - &plugin_root.join(".codex-plugin/plugin.json"), - r#"{"name":"sample"}"#, - ); - write_file( - &plugin_root.join(".mcp.json"), - r#"{ - "mcpServers": { - "sample": { - "type": "http", - "url": "https://plugin.example/mcp" - }, - "docs": { - "type": "http", - "url": "https://docs.example/mcp" - } - } -}"#, - ); - write_file( - &codex_home.path().join(CONFIG_TOML_FILE), - &plugin_config_toml(), - ); - - let mut config = ConfigBuilder::default() - .codex_home(codex_home.path().to_path_buf()) - .build() - .await - .expect("config should load"); - - let mut configured_servers = config.mcp_servers.get().clone(); - configured_servers.insert( - "sample".to_string(), - McpServerConfig { - transport: McpServerTransportConfig::StreamableHttp { - url: "https://user.example/mcp".to_string(), - bearer_token_env_var: None, - http_headers: None, - env_http_headers: None, - }, - enabled: true, - required: false, - disabled_reason: None, - startup_timeout_sec: None, - tool_timeout_sec: None, - enabled_tools: None, - disabled_tools: None, - scopes: None, - oauth_resource: None, - }, - ); - config - .mcp_servers - .set(configured_servers) - .expect("test config should accept MCP servers"); - - let mcp_manager = McpManager::new(Arc::new(PluginsManager::new(config.codex_home.clone()))); - let effective = mcp_manager.effective_servers(&config, None); - - let sample = effective.get("sample").expect("user server should exist"); - let docs = effective.get("docs").expect("plugin server should exist"); - - match &sample.transport { - McpServerTransportConfig::StreamableHttp { url, .. } => { - assert_eq!(url, "https://user.example/mcp"); - } - other => panic!("expected streamable http transport, got {other:?}"), - } - match &docs.transport { - McpServerTransportConfig::StreamableHttp { url, .. } => { - assert_eq!(url, "https://docs.example/mcp"); - } - other => panic!("expected streamable http transport, got {other:?}"), - } - } -} +#[path = "mod_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/mcp/mod_tests.rs b/codex-rs/core/src/mcp/mod_tests.rs new file mode 100644 index 0000000000..cdbcda2ea0 --- /dev/null +++ b/codex-rs/core/src/mcp/mod_tests.rs @@ -0,0 +1,333 @@ +use super::*; +use crate::config::CONFIG_TOML_FILE; +use crate::config::ConfigBuilder; +use crate::plugins::AppConnectorId; +use crate::plugins::PluginCapabilitySummary; +use pretty_assertions::assert_eq; +use std::fs; +use std::path::Path; +use toml::Value; + +fn write_file(path: &Path, contents: &str) { + fs::create_dir_all(path.parent().expect("file should have a parent")).unwrap(); + fs::write(path, contents).unwrap(); +} + +fn plugin_config_toml() -> String { + let mut root = toml::map::Map::new(); + + let mut features = toml::map::Map::new(); + features.insert("plugins".to_string(), Value::Boolean(true)); + root.insert("features".to_string(), Value::Table(features)); + + let mut plugin = toml::map::Map::new(); + plugin.insert("enabled".to_string(), Value::Boolean(true)); + + let mut plugins = toml::map::Map::new(); + plugins.insert("sample@test".to_string(), Value::Table(plugin)); + root.insert("plugins".to_string(), Value::Table(plugins)); + + toml::to_string(&Value::Table(root)).expect("plugin test config should serialize") +} + +fn make_tool(name: &str) -> Tool { + Tool { + name: name.to_string(), + title: None, + description: None, + input_schema: serde_json::json!({"type": "object", "properties": {}}), + output_schema: None, + annotations: None, + icons: None, + meta: None, + } +} + +#[test] +fn split_qualified_tool_name_returns_server_and_tool() { + assert_eq!( + split_qualified_tool_name("mcp__alpha__do_thing"), + Some(("alpha".to_string(), "do_thing".to_string())) + ); +} + +#[test] +fn split_qualified_tool_name_rejects_invalid_names() { + assert_eq!(split_qualified_tool_name("other__alpha__do_thing"), None); + assert_eq!(split_qualified_tool_name("mcp__alpha__"), None); +} + +#[test] +fn group_tools_by_server_strips_prefix_and_groups() { + let mut tools = HashMap::new(); + tools.insert("mcp__alpha__do_thing".to_string(), make_tool("do_thing")); + tools.insert( + "mcp__alpha__nested__op".to_string(), + make_tool("nested__op"), + ); + tools.insert("mcp__beta__do_other".to_string(), make_tool("do_other")); + + let mut expected_alpha = HashMap::new(); + expected_alpha.insert("do_thing".to_string(), make_tool("do_thing")); + expected_alpha.insert("nested__op".to_string(), make_tool("nested__op")); + + let mut expected_beta = HashMap::new(); + expected_beta.insert("do_other".to_string(), make_tool("do_other")); + + let mut expected = HashMap::new(); + expected.insert("alpha".to_string(), expected_alpha); + expected.insert("beta".to_string(), expected_beta); + + assert_eq!(group_tools_by_server(&tools), expected); +} + +#[test] +fn tool_plugin_provenance_collects_app_and_mcp_sources() { + let provenance = ToolPluginProvenance::from_capability_summaries(&[ + PluginCapabilitySummary { + display_name: "alpha-plugin".to_string(), + app_connector_ids: vec![AppConnectorId("connector_example".to_string())], + mcp_server_names: vec!["alpha".to_string()], + ..PluginCapabilitySummary::default() + }, + PluginCapabilitySummary { + display_name: "beta-plugin".to_string(), + app_connector_ids: vec![ + AppConnectorId("connector_example".to_string()), + AppConnectorId("connector_gmail".to_string()), + ], + mcp_server_names: vec!["beta".to_string()], + ..PluginCapabilitySummary::default() + }, + ]); + + assert_eq!( + provenance, + ToolPluginProvenance { + plugin_display_names_by_connector_id: HashMap::from([ + ( + "connector_example".to_string(), + vec!["alpha-plugin".to_string(), "beta-plugin".to_string()], + ), + ( + "connector_gmail".to_string(), + vec!["beta-plugin".to_string()], + ), + ]), + plugin_display_names_by_mcp_server_name: HashMap::from([ + ("alpha".to_string(), vec!["alpha-plugin".to_string()]), + ("beta".to_string(), vec!["beta-plugin".to_string()]), + ]), + } + ); +} + +#[test] +fn codex_apps_mcp_url_for_default_gateway_keeps_existing_paths() { + assert_eq!( + codex_apps_mcp_url_for_gateway( + "https://chatgpt.com/backend-api", + CodexAppsMcpGateway::LegacyMCPGateway + ), + "https://chatgpt.com/backend-api/wham/apps" + ); + assert_eq!( + codex_apps_mcp_url_for_gateway( + "https://chat.openai.com", + CodexAppsMcpGateway::LegacyMCPGateway + ), + "https://chat.openai.com/backend-api/wham/apps" + ); + assert_eq!( + codex_apps_mcp_url_for_gateway( + "http://localhost:8080/api/codex", + CodexAppsMcpGateway::LegacyMCPGateway + ), + "http://localhost:8080/api/codex/apps" + ); + assert_eq!( + codex_apps_mcp_url_for_gateway( + "http://localhost:8080", + CodexAppsMcpGateway::LegacyMCPGateway + ), + "http://localhost:8080/api/codex/apps" + ); +} + +#[test] +fn codex_apps_mcp_url_for_gateway_uses_openai_connectors_gateway() { + let expected_url = format!("{OPENAI_CONNECTORS_MCP_BASE_URL}{OPENAI_CONNECTORS_MCP_PATH}"); + + assert_eq!( + codex_apps_mcp_url_for_gateway( + "https://chatgpt.com/backend-api", + CodexAppsMcpGateway::MCPGateway + ), + expected_url.as_str() + ); + assert_eq!( + codex_apps_mcp_url_for_gateway("https://chat.openai.com", CodexAppsMcpGateway::MCPGateway), + expected_url.as_str() + ); + assert_eq!( + codex_apps_mcp_url_for_gateway( + "http://localhost:8080/api/codex", + CodexAppsMcpGateway::MCPGateway + ), + expected_url.as_str() + ); + assert_eq!( + codex_apps_mcp_url_for_gateway("http://localhost:8080", CodexAppsMcpGateway::MCPGateway), + expected_url.as_str() + ); +} + +#[test] +fn codex_apps_mcp_url_uses_default_gateway_when_feature_is_disabled() { + let mut config = crate::config::test_config(); + config.chatgpt_base_url = "https://chatgpt.com".to_string(); + + assert_eq!( + codex_apps_mcp_url(&config), + "https://chatgpt.com/backend-api/wham/apps" + ); +} + +#[test] +fn codex_apps_mcp_url_uses_openai_connectors_gateway_when_feature_is_enabled() { + let mut config = crate::config::test_config(); + config.chatgpt_base_url = "https://chatgpt.com".to_string(); + config + .features + .enable(Feature::AppsMcpGateway) + .expect("test config should allow apps gateway"); + + assert_eq!( + codex_apps_mcp_url(&config), + format!("{OPENAI_CONNECTORS_MCP_BASE_URL}{OPENAI_CONNECTORS_MCP_PATH}") + ); +} + +#[test] +fn codex_apps_server_config_switches_gateway_with_flags() { + let mut config = crate::config::test_config(); + config.chatgpt_base_url = "https://chatgpt.com".to_string(); + + let mut servers = with_codex_apps_mcp(HashMap::new(), false, None, &config); + assert!(!servers.contains_key(CODEX_APPS_MCP_SERVER_NAME)); + + config + .features + .enable(Feature::Apps) + .expect("test config should allow apps"); + + servers = with_codex_apps_mcp(servers, true, None, &config); + let server = servers + .get(CODEX_APPS_MCP_SERVER_NAME) + .expect("codex apps should be present when apps is enabled"); + let url = match &server.transport { + McpServerTransportConfig::StreamableHttp { url, .. } => url, + _ => panic!("expected streamable http transport for codex apps"), + }; + + assert_eq!(url, "https://chatgpt.com/backend-api/wham/apps"); + + config + .features + .enable(Feature::AppsMcpGateway) + .expect("test config should allow apps gateway"); + servers = with_codex_apps_mcp(servers, true, None, &config); + let server = servers + .get(CODEX_APPS_MCP_SERVER_NAME) + .expect("codex apps should remain present when apps stays enabled"); + let url = match &server.transport { + McpServerTransportConfig::StreamableHttp { url, .. } => url, + _ => panic!("expected streamable http transport for codex apps"), + }; + + let expected_url = format!("{OPENAI_CONNECTORS_MCP_BASE_URL}{OPENAI_CONNECTORS_MCP_PATH}"); + assert_eq!(url, &expected_url); +} + +#[tokio::test] +async fn effective_mcp_servers_include_plugins_without_overriding_user_config() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let plugin_root = codex_home + .path() + .join("plugins/cache") + .join("test/sample/local"); + write_file( + &plugin_root.join(".codex-plugin/plugin.json"), + r#"{"name":"sample"}"#, + ); + write_file( + &plugin_root.join(".mcp.json"), + r#"{ + "mcpServers": { + "sample": { + "type": "http", + "url": "https://plugin.example/mcp" + }, + "docs": { + "type": "http", + "url": "https://docs.example/mcp" + } + } +}"#, + ); + write_file( + &codex_home.path().join(CONFIG_TOML_FILE), + &plugin_config_toml(), + ); + + let mut config = ConfigBuilder::default() + .codex_home(codex_home.path().to_path_buf()) + .build() + .await + .expect("config should load"); + + let mut configured_servers = config.mcp_servers.get().clone(); + configured_servers.insert( + "sample".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: "https://user.example/mcp".to_string(), + bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, + }, + enabled: true, + required: false, + disabled_reason: None, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + oauth_resource: None, + }, + ); + config + .mcp_servers + .set(configured_servers) + .expect("test config should accept MCP servers"); + + let mcp_manager = McpManager::new(Arc::new(PluginsManager::new(config.codex_home.clone()))); + let effective = mcp_manager.effective_servers(&config, None); + + let sample = effective.get("sample").expect("user server should exist"); + let docs = effective.get("docs").expect("plugin server should exist"); + + match &sample.transport { + McpServerTransportConfig::StreamableHttp { url, .. } => { + assert_eq!(url, "https://user.example/mcp"); + } + other => panic!("expected streamable http transport, got {other:?}"), + } + match &docs.transport { + McpServerTransportConfig::StreamableHttp { url, .. } => { + assert_eq!(url, "https://docs.example/mcp"); + } + other => panic!("expected streamable http transport, got {other:?}"), + } +} diff --git a/codex-rs/core/src/mcp/skill_dependencies.rs b/codex-rs/core/src/mcp/skill_dependencies.rs index f15bb6ec57..e9d77a33f4 100644 --- a/codex-rs/core/src/mcp/skill_dependencies.rs +++ b/codex-rs/core/src/mcp/skill_dependencies.rs @@ -426,111 +426,5 @@ fn mcp_dependency_to_server_config( } #[cfg(test)] -mod tests { - use super::*; - use crate::skills::model::SkillDependencies; - use codex_protocol::protocol::SkillScope; - use pretty_assertions::assert_eq; - use std::path::PathBuf; - - fn skill_with_tools(tools: Vec) -> SkillMetadata { - SkillMetadata { - name: "skill".to_string(), - description: "skill".to_string(), - short_description: None, - interface: None, - dependencies: Some(SkillDependencies { tools }), - policy: None, - permission_profile: None, - path_to_skills_md: PathBuf::from("skill"), - scope: SkillScope::User, - } - } - - #[test] - fn collect_missing_respects_canonical_installed_key() { - let url = "https://example.com/mcp".to_string(); - let skills = vec![skill_with_tools(vec![SkillToolDependency { - r#type: "mcp".to_string(), - value: "github".to_string(), - description: None, - transport: Some("streamable_http".to_string()), - command: None, - url: Some(url.clone()), - }])]; - let installed = HashMap::from([( - "alias".to_string(), - McpServerConfig { - transport: McpServerTransportConfig::StreamableHttp { - url, - bearer_token_env_var: None, - http_headers: None, - env_http_headers: None, - }, - enabled: true, - required: false, - disabled_reason: None, - startup_timeout_sec: None, - tool_timeout_sec: None, - enabled_tools: None, - disabled_tools: None, - scopes: None, - oauth_resource: None, - }, - )]); - - assert_eq!( - collect_missing_mcp_dependencies(&skills, &installed), - HashMap::new() - ); - } - - #[test] - fn collect_missing_dedupes_by_canonical_key_but_preserves_original_name() { - let url = "https://example.com/one".to_string(); - let skills = vec![skill_with_tools(vec![ - SkillToolDependency { - r#type: "mcp".to_string(), - value: "alias-one".to_string(), - description: None, - transport: Some("streamable_http".to_string()), - command: None, - url: Some(url.clone()), - }, - SkillToolDependency { - r#type: "mcp".to_string(), - value: "alias-two".to_string(), - description: None, - transport: Some("streamable_http".to_string()), - command: None, - url: Some(url.clone()), - }, - ])]; - - let expected = HashMap::from([( - "alias-one".to_string(), - McpServerConfig { - transport: McpServerTransportConfig::StreamableHttp { - url, - bearer_token_env_var: None, - http_headers: None, - env_http_headers: None, - }, - enabled: true, - required: false, - disabled_reason: None, - startup_timeout_sec: None, - tool_timeout_sec: None, - enabled_tools: None, - disabled_tools: None, - scopes: None, - oauth_resource: None, - }, - )]); - - assert_eq!( - collect_missing_mcp_dependencies(&skills, &HashMap::new()), - expected - ); - } -} +#[path = "skill_dependencies_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/mcp/skill_dependencies_tests.rs b/codex-rs/core/src/mcp/skill_dependencies_tests.rs new file mode 100644 index 0000000000..68af0df984 --- /dev/null +++ b/codex-rs/core/src/mcp/skill_dependencies_tests.rs @@ -0,0 +1,106 @@ +use super::*; +use crate::skills::model::SkillDependencies; +use codex_protocol::protocol::SkillScope; +use pretty_assertions::assert_eq; +use std::path::PathBuf; + +fn skill_with_tools(tools: Vec) -> SkillMetadata { + SkillMetadata { + name: "skill".to_string(), + description: "skill".to_string(), + short_description: None, + interface: None, + dependencies: Some(SkillDependencies { tools }), + policy: None, + permission_profile: None, + path_to_skills_md: PathBuf::from("skill"), + scope: SkillScope::User, + } +} + +#[test] +fn collect_missing_respects_canonical_installed_key() { + let url = "https://example.com/mcp".to_string(); + let skills = vec![skill_with_tools(vec![SkillToolDependency { + r#type: "mcp".to_string(), + value: "github".to_string(), + description: None, + transport: Some("streamable_http".to_string()), + command: None, + url: Some(url.clone()), + }])]; + let installed = HashMap::from([( + "alias".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url, + bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, + }, + enabled: true, + required: false, + disabled_reason: None, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + oauth_resource: None, + }, + )]); + + assert_eq!( + collect_missing_mcp_dependencies(&skills, &installed), + HashMap::new() + ); +} + +#[test] +fn collect_missing_dedupes_by_canonical_key_but_preserves_original_name() { + let url = "https://example.com/one".to_string(); + let skills = vec![skill_with_tools(vec![ + SkillToolDependency { + r#type: "mcp".to_string(), + value: "alias-one".to_string(), + description: None, + transport: Some("streamable_http".to_string()), + command: None, + url: Some(url.clone()), + }, + SkillToolDependency { + r#type: "mcp".to_string(), + value: "alias-two".to_string(), + description: None, + transport: Some("streamable_http".to_string()), + command: None, + url: Some(url.clone()), + }, + ])]; + + let expected = HashMap::from([( + "alias-one".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url, + bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, + }, + enabled: true, + required: false, + disabled_reason: None, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + oauth_resource: None, + }, + )]); + + assert_eq!( + collect_missing_mcp_dependencies(&skills, &HashMap::new()), + expected + ); +} diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 5eb5ebc793..c93bb13d7c 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -1703,653 +1703,5 @@ fn startup_outcome_error_message(error: StartupOutcomeError) -> String { mod mcp_init_error_display_tests {} #[cfg(test)] -mod tests { - use super::*; - use codex_protocol::protocol::McpAuthStatus; - use codex_protocol::protocol::RejectConfig; - use rmcp::model::JsonObject; - use std::collections::HashSet; - use std::sync::Arc; - use tempfile::tempdir; - - fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo { - ToolInfo { - server_name: server_name.to_string(), - tool_name: tool_name.to_string(), - tool_namespace: if server_name == CODEX_APPS_MCP_SERVER_NAME { - format!("mcp__{server_name}__") - } else { - server_name.to_string() - }, - tool: Tool { - name: tool_name.to_string().into(), - title: None, - description: Some(format!("Test tool: {tool_name}").into()), - input_schema: Arc::new(JsonObject::default()), - output_schema: None, - annotations: None, - execution: None, - icons: None, - meta: None, - }, - connector_id: None, - connector_name: None, - plugin_display_names: Vec::new(), - connector_description: None, - } - } - - fn create_test_tool_with_connector( - server_name: &str, - tool_name: &str, - connector_id: &str, - connector_name: Option<&str>, - ) -> ToolInfo { - let mut tool = create_test_tool(server_name, tool_name); - tool.connector_id = Some(connector_id.to_string()); - tool.connector_name = connector_name.map(ToOwned::to_owned); - tool - } - - fn create_codex_apps_tools_cache_context( - codex_home: PathBuf, - account_id: Option<&str>, - chatgpt_user_id: Option<&str>, - ) -> CodexAppsToolsCacheContext { - CodexAppsToolsCacheContext { - codex_home, - user_key: CodexAppsToolsCacheKey { - account_id: account_id.map(ToOwned::to_owned), - chatgpt_user_id: chatgpt_user_id.map(ToOwned::to_owned), - is_workspace_account: false, - }, - } - } - - #[test] - fn elicitation_reject_policy_defaults_to_prompting() { - assert!(!elicitation_is_rejected_by_policy( - AskForApproval::OnFailure - )); - assert!(!elicitation_is_rejected_by_policy( - AskForApproval::OnRequest - )); - assert!(!elicitation_is_rejected_by_policy( - AskForApproval::UnlessTrusted - )); - assert!(!elicitation_is_rejected_by_policy(AskForApproval::Reject( - RejectConfig { - sandbox_approval: false, - rules: false, - skill_approval: false, - request_permissions: false, - mcp_elicitations: false, - } - ))); - } - - #[test] - fn elicitation_reject_policy_respects_never_and_reject_config() { - assert!(elicitation_is_rejected_by_policy(AskForApproval::Never)); - assert!(elicitation_is_rejected_by_policy(AskForApproval::Reject( - RejectConfig { - sandbox_approval: false, - rules: false, - skill_approval: false, - request_permissions: false, - mcp_elicitations: true, - } - ))); - } - - #[test] - fn test_qualify_tools_short_non_duplicated_names() { - let tools = vec![ - create_test_tool("server1", "tool1"), - create_test_tool("server1", "tool2"), - ]; - - let qualified_tools = qualify_tools(tools); - - assert_eq!(qualified_tools.len(), 2); - assert!(qualified_tools.contains_key("mcp__server1__tool1")); - assert!(qualified_tools.contains_key("mcp__server1__tool2")); - } - - #[test] - fn test_qualify_tools_duplicated_names_skipped() { - let tools = vec![ - create_test_tool("server1", "duplicate_tool"), - create_test_tool("server1", "duplicate_tool"), - ]; - - let qualified_tools = qualify_tools(tools); - - // Only the first tool should remain, the second is skipped - assert_eq!(qualified_tools.len(), 1); - assert!(qualified_tools.contains_key("mcp__server1__duplicate_tool")); - } - - #[test] - fn test_qualify_tools_long_names_same_server() { - let server_name = "my_server"; - - let tools = vec![ - create_test_tool( - server_name, - "extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits", - ), - create_test_tool( - server_name, - "yet_another_extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits", - ), - ]; - - let qualified_tools = qualify_tools(tools); - - assert_eq!(qualified_tools.len(), 2); - - let mut keys: Vec<_> = qualified_tools.keys().cloned().collect(); - keys.sort(); - - assert_eq!(keys[0].len(), 64); - assert_eq!( - keys[0], - "mcp__my_server__extremel119a2b97664e41363932dc84de21e2ff1b93b3e9" - ); - - assert_eq!(keys[1].len(), 64); - assert_eq!( - keys[1], - "mcp__my_server__yet_anot419a82a89325c1b477274a41f8c65ea5f3a7f341" - ); - } - - #[test] - fn test_qualify_tools_sanitizes_invalid_characters() { - let tools = vec![create_test_tool("server.one", "tool.two")]; - - let qualified_tools = qualify_tools(tools); - - assert_eq!(qualified_tools.len(), 1); - let (qualified_name, tool) = qualified_tools.into_iter().next().expect("one tool"); - assert_eq!(qualified_name, "mcp__server_one__tool_two"); - - // The key is sanitized for OpenAI, but we keep original parts for the actual MCP call. - assert_eq!(tool.server_name, "server.one"); - assert_eq!(tool.tool_name, "tool.two"); - - assert!( - qualified_name - .chars() - .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-'), - "qualified name must be Responses API compatible: {qualified_name:?}" - ); - } - - #[test] - fn tool_filter_allows_by_default() { - let filter = ToolFilter::default(); - - assert!(filter.allows("any")); - } - - #[test] - fn tool_filter_applies_enabled_list() { - let filter = ToolFilter { - enabled: Some(HashSet::from(["allowed".to_string()])), - disabled: HashSet::new(), - }; - - assert!(filter.allows("allowed")); - assert!(!filter.allows("denied")); - } - - #[test] - fn tool_filter_applies_disabled_list() { - let filter = ToolFilter { - enabled: None, - disabled: HashSet::from(["blocked".to_string()]), - }; - - assert!(!filter.allows("blocked")); - assert!(filter.allows("open")); - } - - #[test] - fn tool_filter_applies_enabled_then_disabled() { - let filter = ToolFilter { - enabled: Some(HashSet::from(["keep".to_string(), "remove".to_string()])), - disabled: HashSet::from(["remove".to_string()]), - }; - - assert!(filter.allows("keep")); - assert!(!filter.allows("remove")); - assert!(!filter.allows("unknown")); - } - - #[test] - fn filter_tools_applies_per_server_filters() { - let server1_tools = vec![ - create_test_tool("server1", "tool_a"), - create_test_tool("server1", "tool_b"), - ]; - let server2_tools = vec![create_test_tool("server2", "tool_a")]; - let server1_filter = ToolFilter { - enabled: Some(HashSet::from(["tool_a".to_string(), "tool_b".to_string()])), - disabled: HashSet::from(["tool_b".to_string()]), - }; - let server2_filter = ToolFilter { - enabled: None, - disabled: HashSet::from(["tool_a".to_string()]), - }; - - let filtered: Vec<_> = filter_tools(server1_tools, &server1_filter) - .into_iter() - .chain(filter_tools(server2_tools, &server2_filter)) - .collect(); - - assert_eq!(filtered.len(), 1); - assert_eq!(filtered[0].server_name, "server1"); - assert_eq!(filtered[0].tool_name, "tool_a"); - } - - #[test] - fn codex_apps_tools_cache_is_overwritten_by_last_write() { - let codex_home = tempdir().expect("tempdir"); - let cache_context = create_codex_apps_tools_cache_context( - codex_home.path().to_path_buf(), - Some("account-one"), - Some("user-one"), - ); - let tools_gateway_1 = vec![create_test_tool(CODEX_APPS_MCP_SERVER_NAME, "one")]; - let tools_gateway_2 = vec![create_test_tool(CODEX_APPS_MCP_SERVER_NAME, "two")]; - - write_cached_codex_apps_tools(&cache_context, &tools_gateway_1); - let cached_gateway_1 = read_cached_codex_apps_tools(&cache_context) - .expect("cache entry exists for first write"); - assert_eq!(cached_gateway_1[0].tool_name, "one"); - - write_cached_codex_apps_tools(&cache_context, &tools_gateway_2); - let cached_gateway_2 = read_cached_codex_apps_tools(&cache_context) - .expect("cache entry exists for second write"); - assert_eq!(cached_gateway_2[0].tool_name, "two"); - } - - #[test] - fn codex_apps_tools_cache_is_scoped_per_user() { - let codex_home = tempdir().expect("tempdir"); - let cache_context_user_1 = create_codex_apps_tools_cache_context( - codex_home.path().to_path_buf(), - Some("account-one"), - Some("user-one"), - ); - let cache_context_user_2 = create_codex_apps_tools_cache_context( - codex_home.path().to_path_buf(), - Some("account-two"), - Some("user-two"), - ); - let tools_user_1 = vec![create_test_tool(CODEX_APPS_MCP_SERVER_NAME, "one")]; - let tools_user_2 = vec![create_test_tool(CODEX_APPS_MCP_SERVER_NAME, "two")]; - - write_cached_codex_apps_tools(&cache_context_user_1, &tools_user_1); - write_cached_codex_apps_tools(&cache_context_user_2, &tools_user_2); - - let read_user_1 = - read_cached_codex_apps_tools(&cache_context_user_1).expect("cache entry for user one"); - let read_user_2 = - read_cached_codex_apps_tools(&cache_context_user_2).expect("cache entry for user two"); - - assert_eq!(read_user_1[0].tool_name, "one"); - assert_eq!(read_user_2[0].tool_name, "two"); - assert_ne!( - cache_context_user_1.cache_path(), - cache_context_user_2.cache_path(), - "each user should get an isolated cache file" - ); - } - - #[test] - fn codex_apps_tools_cache_filters_disallowed_connectors() { - let codex_home = tempdir().expect("tempdir"); - let cache_context = create_codex_apps_tools_cache_context( - codex_home.path().to_path_buf(), - Some("account-one"), - Some("user-one"), - ); - let tools = vec![ - create_test_tool_with_connector( - CODEX_APPS_MCP_SERVER_NAME, - "blocked_tool", - "connector_openai_hidden", - Some("Hidden"), - ), - create_test_tool_with_connector( - CODEX_APPS_MCP_SERVER_NAME, - "allowed_tool", - "calendar", - Some("Calendar"), - ), - ]; - - write_cached_codex_apps_tools(&cache_context, &tools); - let cached = - read_cached_codex_apps_tools(&cache_context).expect("cache entry exists for user"); - - assert_eq!(cached.len(), 1); - assert_eq!(cached[0].tool_name, "allowed_tool"); - assert_eq!(cached[0].connector_id.as_deref(), Some("calendar")); - } - - #[test] - fn codex_apps_tools_cache_is_ignored_when_schema_version_mismatches() { - let codex_home = tempdir().expect("tempdir"); - let cache_context = create_codex_apps_tools_cache_context( - codex_home.path().to_path_buf(), - Some("account-one"), - Some("user-one"), - ); - let cache_path = cache_context.cache_path(); - if let Some(parent) = cache_path.parent() { - std::fs::create_dir_all(parent).expect("create parent"); - } - let bytes = serde_json::to_vec_pretty(&serde_json::json!({ - "schema_version": CODEX_APPS_TOOLS_CACHE_SCHEMA_VERSION + 1, - "tools": [create_test_tool(CODEX_APPS_MCP_SERVER_NAME, "one")], - })) - .expect("serialize"); - std::fs::write(cache_path, bytes).expect("write"); - - assert!(read_cached_codex_apps_tools(&cache_context).is_none()); - } - - #[test] - fn codex_apps_tools_cache_is_ignored_when_json_is_invalid() { - let codex_home = tempdir().expect("tempdir"); - let cache_context = create_codex_apps_tools_cache_context( - codex_home.path().to_path_buf(), - Some("account-one"), - Some("user-one"), - ); - let cache_path = cache_context.cache_path(); - if let Some(parent) = cache_path.parent() { - std::fs::create_dir_all(parent).expect("create parent"); - } - std::fs::write(cache_path, b"{not json").expect("write"); - - assert!(read_cached_codex_apps_tools(&cache_context).is_none()); - } - - #[test] - fn startup_cached_codex_apps_tools_loads_from_disk_cache() { - let codex_home = tempdir().expect("tempdir"); - let cache_context = create_codex_apps_tools_cache_context( - codex_home.path().to_path_buf(), - Some("account-one"), - Some("user-one"), - ); - let cached_tools = vec![create_test_tool( - CODEX_APPS_MCP_SERVER_NAME, - "calendar_search", - )]; - write_cached_codex_apps_tools(&cache_context, &cached_tools); - - let startup_snapshot = load_startup_cached_codex_apps_tools_snapshot( - CODEX_APPS_MCP_SERVER_NAME, - Some(&cache_context), - ); - let startup_tools = startup_snapshot.expect("expected startup snapshot to load from cache"); - - assert_eq!(startup_tools.len(), 1); - assert_eq!(startup_tools[0].server_name, CODEX_APPS_MCP_SERVER_NAME); - assert_eq!(startup_tools[0].tool_name, "calendar_search"); - } - - #[tokio::test] - async fn list_all_tools_uses_startup_snapshot_while_client_is_pending() { - let startup_tools = vec![create_test_tool( - CODEX_APPS_MCP_SERVER_NAME, - "calendar_create_event", - )]; - let pending_client = - futures::future::pending::>() - .boxed() - .shared(); - let approval_policy = Constrained::allow_any(AskForApproval::OnFailure); - let mut manager = McpConnectionManager::new_uninitialized(&approval_policy); - manager.clients.insert( - CODEX_APPS_MCP_SERVER_NAME.to_string(), - AsyncManagedClient { - client: pending_client, - startup_snapshot: Some(startup_tools), - startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), - tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), - }, - ); - - let tools = manager.list_all_tools().await; - let tool = tools - .get("mcp__codex_apps__calendar_create_event") - .expect("tool from startup cache"); - assert_eq!(tool.server_name, CODEX_APPS_MCP_SERVER_NAME); - assert_eq!(tool.tool_name, "calendar_create_event"); - } - - #[tokio::test] - async fn list_all_tools_blocks_while_client_is_pending_without_startup_snapshot() { - let pending_client = - futures::future::pending::>() - .boxed() - .shared(); - let approval_policy = Constrained::allow_any(AskForApproval::OnFailure); - let mut manager = McpConnectionManager::new_uninitialized(&approval_policy); - manager.clients.insert( - CODEX_APPS_MCP_SERVER_NAME.to_string(), - AsyncManagedClient { - client: pending_client, - startup_snapshot: None, - startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), - tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), - }, - ); - - let timeout_result = - tokio::time::timeout(Duration::from_millis(10), manager.list_all_tools()).await; - assert!(timeout_result.is_err()); - } - - #[tokio::test] - async fn list_all_tools_does_not_block_when_startup_snapshot_cache_hit_is_empty() { - let pending_client = - futures::future::pending::>() - .boxed() - .shared(); - let approval_policy = Constrained::allow_any(AskForApproval::OnFailure); - let mut manager = McpConnectionManager::new_uninitialized(&approval_policy); - manager.clients.insert( - CODEX_APPS_MCP_SERVER_NAME.to_string(), - AsyncManagedClient { - client: pending_client, - startup_snapshot: Some(Vec::new()), - startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), - tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), - }, - ); - - let timeout_result = - tokio::time::timeout(Duration::from_millis(10), manager.list_all_tools()).await; - let tools = timeout_result.expect("cache-hit startup snapshot should not block"); - assert!(tools.is_empty()); - } - - #[tokio::test] - async fn list_all_tools_uses_startup_snapshot_when_client_startup_fails() { - let startup_tools = vec![create_test_tool( - CODEX_APPS_MCP_SERVER_NAME, - "calendar_create_event", - )]; - let failed_client = futures::future::ready::>( - Err(StartupOutcomeError::Failed { - error: "startup failed".to_string(), - }), - ) - .boxed() - .shared(); - let approval_policy = Constrained::allow_any(AskForApproval::OnFailure); - let mut manager = McpConnectionManager::new_uninitialized(&approval_policy); - let startup_complete = Arc::new(std::sync::atomic::AtomicBool::new(true)); - manager.clients.insert( - CODEX_APPS_MCP_SERVER_NAME.to_string(), - AsyncManagedClient { - client: failed_client, - startup_snapshot: Some(startup_tools), - startup_complete, - tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), - }, - ); - - let tools = manager.list_all_tools().await; - let tool = tools - .get("mcp__codex_apps__calendar_create_event") - .expect("tool from startup cache"); - assert_eq!(tool.server_name, CODEX_APPS_MCP_SERVER_NAME); - assert_eq!(tool.tool_name, "calendar_create_event"); - } - - #[test] - fn elicitation_capability_enabled_only_for_codex_apps() { - let codex_apps_capability = elicitation_capability_for_server(CODEX_APPS_MCP_SERVER_NAME); - assert!(matches!( - codex_apps_capability, - Some(ElicitationCapability { - form: Some(FormElicitationCapability { - schema_validation: None - }), - url: None, - }) - )); - - assert!(elicitation_capability_for_server("custom_mcp").is_none()); - } - - #[test] - fn mcp_init_error_display_prompts_for_github_pat() { - let server_name = "github"; - let entry = McpAuthStatusEntry { - config: McpServerConfig { - transport: McpServerTransportConfig::StreamableHttp { - url: "https://api.githubcopilot.com/mcp/".to_string(), - bearer_token_env_var: None, - http_headers: None, - env_http_headers: None, - }, - enabled: true, - required: false, - disabled_reason: None, - startup_timeout_sec: None, - tool_timeout_sec: None, - enabled_tools: None, - disabled_tools: None, - scopes: None, - oauth_resource: None, - }, - auth_status: McpAuthStatus::Unsupported, - }; - let err: StartupOutcomeError = anyhow::anyhow!("OAuth is unsupported").into(); - - let display = mcp_init_error_display(server_name, Some(&entry), &err); - - let expected = format!( - "GitHub MCP does not support OAuth. Log in by adding a personal access token (https://github.com/settings/personal-access-tokens) to your environment and config.toml:\n[mcp_servers.{server_name}]\nbearer_token_env_var = CODEX_GITHUB_PERSONAL_ACCESS_TOKEN" - ); - - assert_eq!(expected, display); - } - - #[test] - fn mcp_init_error_display_prompts_for_login_when_auth_required() { - let server_name = "example"; - let err: StartupOutcomeError = anyhow::anyhow!("Auth required for server").into(); - - let display = mcp_init_error_display(server_name, None, &err); - - let expected = format!( - "The {server_name} MCP server is not logged in. Run `codex mcp login {server_name}`." - ); - - assert_eq!(expected, display); - } - - #[test] - fn mcp_init_error_display_reports_generic_errors() { - let server_name = "custom"; - let entry = McpAuthStatusEntry { - config: McpServerConfig { - transport: McpServerTransportConfig::StreamableHttp { - url: "https://example.com".to_string(), - bearer_token_env_var: Some("TOKEN".to_string()), - http_headers: None, - env_http_headers: None, - }, - enabled: true, - required: false, - disabled_reason: None, - startup_timeout_sec: None, - tool_timeout_sec: None, - enabled_tools: None, - disabled_tools: None, - scopes: None, - oauth_resource: None, - }, - auth_status: McpAuthStatus::Unsupported, - }; - let err: StartupOutcomeError = anyhow::anyhow!("boom").into(); - - let display = mcp_init_error_display(server_name, Some(&entry), &err); - - let expected = format!("MCP client for `{server_name}` failed to start: {err:#}"); - - assert_eq!(expected, display); - } - - #[test] - fn mcp_init_error_display_includes_startup_timeout_hint() { - let server_name = "slow"; - let err: StartupOutcomeError = anyhow::anyhow!("request timed out").into(); - - let display = mcp_init_error_display(server_name, None, &err); - - assert_eq!( - "MCP client for `slow` timed out after 10 seconds. Add or adjust `startup_timeout_sec` in your config.toml:\n[mcp_servers.slow]\nstartup_timeout_sec = XX", - display - ); - } - - #[test] - fn transport_origin_extracts_http_origin() { - let transport = McpServerTransportConfig::StreamableHttp { - url: "https://example.com:8443/path?query=1".to_string(), - bearer_token_env_var: None, - http_headers: None, - env_http_headers: None, - }; - - assert_eq!( - transport_origin(&transport), - Some("https://example.com:8443".to_string()) - ); - } - - #[test] - fn transport_origin_is_stdio_for_stdio_transport() { - let transport = McpServerTransportConfig::Stdio { - command: "server".to_string(), - args: Vec::new(), - env: None, - env_vars: Vec::new(), - cwd: None, - }; - - assert_eq!(transport_origin(&transport), Some("stdio".to_string())); - } -} +#[path = "mcp_connection_manager_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/mcp_connection_manager_tests.rs b/codex-rs/core/src/mcp_connection_manager_tests.rs new file mode 100644 index 0000000000..51eaa67b17 --- /dev/null +++ b/codex-rs/core/src/mcp_connection_manager_tests.rs @@ -0,0 +1,644 @@ +use super::*; +use codex_protocol::protocol::McpAuthStatus; +use codex_protocol::protocol::RejectConfig; +use rmcp::model::JsonObject; +use std::collections::HashSet; +use std::sync::Arc; +use tempfile::tempdir; + +fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo { + ToolInfo { + server_name: server_name.to_string(), + tool_name: tool_name.to_string(), + tool_namespace: if server_name == CODEX_APPS_MCP_SERVER_NAME { + format!("mcp__{server_name}__") + } else { + server_name.to_string() + }, + tool: Tool { + name: tool_name.to_string().into(), + title: None, + description: Some(format!("Test tool: {tool_name}").into()), + input_schema: Arc::new(JsonObject::default()), + output_schema: None, + annotations: None, + execution: None, + icons: None, + meta: None, + }, + connector_id: None, + connector_name: None, + plugin_display_names: Vec::new(), + connector_description: None, + } +} + +fn create_test_tool_with_connector( + server_name: &str, + tool_name: &str, + connector_id: &str, + connector_name: Option<&str>, +) -> ToolInfo { + let mut tool = create_test_tool(server_name, tool_name); + tool.connector_id = Some(connector_id.to_string()); + tool.connector_name = connector_name.map(ToOwned::to_owned); + tool +} + +fn create_codex_apps_tools_cache_context( + codex_home: PathBuf, + account_id: Option<&str>, + chatgpt_user_id: Option<&str>, +) -> CodexAppsToolsCacheContext { + CodexAppsToolsCacheContext { + codex_home, + user_key: CodexAppsToolsCacheKey { + account_id: account_id.map(ToOwned::to_owned), + chatgpt_user_id: chatgpt_user_id.map(ToOwned::to_owned), + is_workspace_account: false, + }, + } +} + +#[test] +fn elicitation_reject_policy_defaults_to_prompting() { + assert!(!elicitation_is_rejected_by_policy( + AskForApproval::OnFailure + )); + assert!(!elicitation_is_rejected_by_policy( + AskForApproval::OnRequest + )); + assert!(!elicitation_is_rejected_by_policy( + AskForApproval::UnlessTrusted + )); + assert!(!elicitation_is_rejected_by_policy(AskForApproval::Reject( + RejectConfig { + sandbox_approval: false, + rules: false, + skill_approval: false, + request_permissions: false, + mcp_elicitations: false, + } + ))); +} + +#[test] +fn elicitation_reject_policy_respects_never_and_reject_config() { + assert!(elicitation_is_rejected_by_policy(AskForApproval::Never)); + assert!(elicitation_is_rejected_by_policy(AskForApproval::Reject( + RejectConfig { + sandbox_approval: false, + rules: false, + skill_approval: false, + request_permissions: false, + mcp_elicitations: true, + } + ))); +} + +#[test] +fn test_qualify_tools_short_non_duplicated_names() { + let tools = vec![ + create_test_tool("server1", "tool1"), + create_test_tool("server1", "tool2"), + ]; + + let qualified_tools = qualify_tools(tools); + + assert_eq!(qualified_tools.len(), 2); + assert!(qualified_tools.contains_key("mcp__server1__tool1")); + assert!(qualified_tools.contains_key("mcp__server1__tool2")); +} + +#[test] +fn test_qualify_tools_duplicated_names_skipped() { + let tools = vec![ + create_test_tool("server1", "duplicate_tool"), + create_test_tool("server1", "duplicate_tool"), + ]; + + let qualified_tools = qualify_tools(tools); + + // Only the first tool should remain, the second is skipped + assert_eq!(qualified_tools.len(), 1); + assert!(qualified_tools.contains_key("mcp__server1__duplicate_tool")); +} + +#[test] +fn test_qualify_tools_long_names_same_server() { + let server_name = "my_server"; + + let tools = vec![ + create_test_tool( + server_name, + "extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits", + ), + create_test_tool( + server_name, + "yet_another_extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits", + ), + ]; + + let qualified_tools = qualify_tools(tools); + + assert_eq!(qualified_tools.len(), 2); + + let mut keys: Vec<_> = qualified_tools.keys().cloned().collect(); + keys.sort(); + + assert_eq!(keys[0].len(), 64); + assert_eq!( + keys[0], + "mcp__my_server__extremel119a2b97664e41363932dc84de21e2ff1b93b3e9" + ); + + assert_eq!(keys[1].len(), 64); + assert_eq!( + keys[1], + "mcp__my_server__yet_anot419a82a89325c1b477274a41f8c65ea5f3a7f341" + ); +} + +#[test] +fn test_qualify_tools_sanitizes_invalid_characters() { + let tools = vec![create_test_tool("server.one", "tool.two")]; + + let qualified_tools = qualify_tools(tools); + + assert_eq!(qualified_tools.len(), 1); + let (qualified_name, tool) = qualified_tools.into_iter().next().expect("one tool"); + assert_eq!(qualified_name, "mcp__server_one__tool_two"); + + // The key is sanitized for OpenAI, but we keep original parts for the actual MCP call. + assert_eq!(tool.server_name, "server.one"); + assert_eq!(tool.tool_name, "tool.two"); + + assert!( + qualified_name + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-'), + "qualified name must be Responses API compatible: {qualified_name:?}" + ); +} + +#[test] +fn tool_filter_allows_by_default() { + let filter = ToolFilter::default(); + + assert!(filter.allows("any")); +} + +#[test] +fn tool_filter_applies_enabled_list() { + let filter = ToolFilter { + enabled: Some(HashSet::from(["allowed".to_string()])), + disabled: HashSet::new(), + }; + + assert!(filter.allows("allowed")); + assert!(!filter.allows("denied")); +} + +#[test] +fn tool_filter_applies_disabled_list() { + let filter = ToolFilter { + enabled: None, + disabled: HashSet::from(["blocked".to_string()]), + }; + + assert!(!filter.allows("blocked")); + assert!(filter.allows("open")); +} + +#[test] +fn tool_filter_applies_enabled_then_disabled() { + let filter = ToolFilter { + enabled: Some(HashSet::from(["keep".to_string(), "remove".to_string()])), + disabled: HashSet::from(["remove".to_string()]), + }; + + assert!(filter.allows("keep")); + assert!(!filter.allows("remove")); + assert!(!filter.allows("unknown")); +} + +#[test] +fn filter_tools_applies_per_server_filters() { + let server1_tools = vec![ + create_test_tool("server1", "tool_a"), + create_test_tool("server1", "tool_b"), + ]; + let server2_tools = vec![create_test_tool("server2", "tool_a")]; + let server1_filter = ToolFilter { + enabled: Some(HashSet::from(["tool_a".to_string(), "tool_b".to_string()])), + disabled: HashSet::from(["tool_b".to_string()]), + }; + let server2_filter = ToolFilter { + enabled: None, + disabled: HashSet::from(["tool_a".to_string()]), + }; + + let filtered: Vec<_> = filter_tools(server1_tools, &server1_filter) + .into_iter() + .chain(filter_tools(server2_tools, &server2_filter)) + .collect(); + + assert_eq!(filtered.len(), 1); + assert_eq!(filtered[0].server_name, "server1"); + assert_eq!(filtered[0].tool_name, "tool_a"); +} + +#[test] +fn codex_apps_tools_cache_is_overwritten_by_last_write() { + let codex_home = tempdir().expect("tempdir"); + let cache_context = create_codex_apps_tools_cache_context( + codex_home.path().to_path_buf(), + Some("account-one"), + Some("user-one"), + ); + let tools_gateway_1 = vec![create_test_tool(CODEX_APPS_MCP_SERVER_NAME, "one")]; + let tools_gateway_2 = vec![create_test_tool(CODEX_APPS_MCP_SERVER_NAME, "two")]; + + write_cached_codex_apps_tools(&cache_context, &tools_gateway_1); + let cached_gateway_1 = + read_cached_codex_apps_tools(&cache_context).expect("cache entry exists for first write"); + assert_eq!(cached_gateway_1[0].tool_name, "one"); + + write_cached_codex_apps_tools(&cache_context, &tools_gateway_2); + let cached_gateway_2 = + read_cached_codex_apps_tools(&cache_context).expect("cache entry exists for second write"); + assert_eq!(cached_gateway_2[0].tool_name, "two"); +} + +#[test] +fn codex_apps_tools_cache_is_scoped_per_user() { + let codex_home = tempdir().expect("tempdir"); + let cache_context_user_1 = create_codex_apps_tools_cache_context( + codex_home.path().to_path_buf(), + Some("account-one"), + Some("user-one"), + ); + let cache_context_user_2 = create_codex_apps_tools_cache_context( + codex_home.path().to_path_buf(), + Some("account-two"), + Some("user-two"), + ); + let tools_user_1 = vec![create_test_tool(CODEX_APPS_MCP_SERVER_NAME, "one")]; + let tools_user_2 = vec![create_test_tool(CODEX_APPS_MCP_SERVER_NAME, "two")]; + + write_cached_codex_apps_tools(&cache_context_user_1, &tools_user_1); + write_cached_codex_apps_tools(&cache_context_user_2, &tools_user_2); + + let read_user_1 = + read_cached_codex_apps_tools(&cache_context_user_1).expect("cache entry for user one"); + let read_user_2 = + read_cached_codex_apps_tools(&cache_context_user_2).expect("cache entry for user two"); + + assert_eq!(read_user_1[0].tool_name, "one"); + assert_eq!(read_user_2[0].tool_name, "two"); + assert_ne!( + cache_context_user_1.cache_path(), + cache_context_user_2.cache_path(), + "each user should get an isolated cache file" + ); +} + +#[test] +fn codex_apps_tools_cache_filters_disallowed_connectors() { + let codex_home = tempdir().expect("tempdir"); + let cache_context = create_codex_apps_tools_cache_context( + codex_home.path().to_path_buf(), + Some("account-one"), + Some("user-one"), + ); + let tools = vec![ + create_test_tool_with_connector( + CODEX_APPS_MCP_SERVER_NAME, + "blocked_tool", + "connector_openai_hidden", + Some("Hidden"), + ), + create_test_tool_with_connector( + CODEX_APPS_MCP_SERVER_NAME, + "allowed_tool", + "calendar", + Some("Calendar"), + ), + ]; + + write_cached_codex_apps_tools(&cache_context, &tools); + let cached = read_cached_codex_apps_tools(&cache_context).expect("cache entry exists for user"); + + assert_eq!(cached.len(), 1); + assert_eq!(cached[0].tool_name, "allowed_tool"); + assert_eq!(cached[0].connector_id.as_deref(), Some("calendar")); +} + +#[test] +fn codex_apps_tools_cache_is_ignored_when_schema_version_mismatches() { + let codex_home = tempdir().expect("tempdir"); + let cache_context = create_codex_apps_tools_cache_context( + codex_home.path().to_path_buf(), + Some("account-one"), + Some("user-one"), + ); + let cache_path = cache_context.cache_path(); + if let Some(parent) = cache_path.parent() { + std::fs::create_dir_all(parent).expect("create parent"); + } + let bytes = serde_json::to_vec_pretty(&serde_json::json!({ + "schema_version": CODEX_APPS_TOOLS_CACHE_SCHEMA_VERSION + 1, + "tools": [create_test_tool(CODEX_APPS_MCP_SERVER_NAME, "one")], + })) + .expect("serialize"); + std::fs::write(cache_path, bytes).expect("write"); + + assert!(read_cached_codex_apps_tools(&cache_context).is_none()); +} + +#[test] +fn codex_apps_tools_cache_is_ignored_when_json_is_invalid() { + let codex_home = tempdir().expect("tempdir"); + let cache_context = create_codex_apps_tools_cache_context( + codex_home.path().to_path_buf(), + Some("account-one"), + Some("user-one"), + ); + let cache_path = cache_context.cache_path(); + if let Some(parent) = cache_path.parent() { + std::fs::create_dir_all(parent).expect("create parent"); + } + std::fs::write(cache_path, b"{not json").expect("write"); + + assert!(read_cached_codex_apps_tools(&cache_context).is_none()); +} + +#[test] +fn startup_cached_codex_apps_tools_loads_from_disk_cache() { + let codex_home = tempdir().expect("tempdir"); + let cache_context = create_codex_apps_tools_cache_context( + codex_home.path().to_path_buf(), + Some("account-one"), + Some("user-one"), + ); + let cached_tools = vec![create_test_tool( + CODEX_APPS_MCP_SERVER_NAME, + "calendar_search", + )]; + write_cached_codex_apps_tools(&cache_context, &cached_tools); + + let startup_snapshot = load_startup_cached_codex_apps_tools_snapshot( + CODEX_APPS_MCP_SERVER_NAME, + Some(&cache_context), + ); + let startup_tools = startup_snapshot.expect("expected startup snapshot to load from cache"); + + assert_eq!(startup_tools.len(), 1); + assert_eq!(startup_tools[0].server_name, CODEX_APPS_MCP_SERVER_NAME); + assert_eq!(startup_tools[0].tool_name, "calendar_search"); +} + +#[tokio::test] +async fn list_all_tools_uses_startup_snapshot_while_client_is_pending() { + let startup_tools = vec![create_test_tool( + CODEX_APPS_MCP_SERVER_NAME, + "calendar_create_event", + )]; + let pending_client = futures::future::pending::>() + .boxed() + .shared(); + let approval_policy = Constrained::allow_any(AskForApproval::OnFailure); + let mut manager = McpConnectionManager::new_uninitialized(&approval_policy); + manager.clients.insert( + CODEX_APPS_MCP_SERVER_NAME.to_string(), + AsyncManagedClient { + client: pending_client, + startup_snapshot: Some(startup_tools), + startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), + tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), + }, + ); + + let tools = manager.list_all_tools().await; + let tool = tools + .get("mcp__codex_apps__calendar_create_event") + .expect("tool from startup cache"); + assert_eq!(tool.server_name, CODEX_APPS_MCP_SERVER_NAME); + assert_eq!(tool.tool_name, "calendar_create_event"); +} + +#[tokio::test] +async fn list_all_tools_blocks_while_client_is_pending_without_startup_snapshot() { + let pending_client = futures::future::pending::>() + .boxed() + .shared(); + let approval_policy = Constrained::allow_any(AskForApproval::OnFailure); + let mut manager = McpConnectionManager::new_uninitialized(&approval_policy); + manager.clients.insert( + CODEX_APPS_MCP_SERVER_NAME.to_string(), + AsyncManagedClient { + client: pending_client, + startup_snapshot: None, + startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), + tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), + }, + ); + + let timeout_result = + tokio::time::timeout(Duration::from_millis(10), manager.list_all_tools()).await; + assert!(timeout_result.is_err()); +} + +#[tokio::test] +async fn list_all_tools_does_not_block_when_startup_snapshot_cache_hit_is_empty() { + let pending_client = futures::future::pending::>() + .boxed() + .shared(); + let approval_policy = Constrained::allow_any(AskForApproval::OnFailure); + let mut manager = McpConnectionManager::new_uninitialized(&approval_policy); + manager.clients.insert( + CODEX_APPS_MCP_SERVER_NAME.to_string(), + AsyncManagedClient { + client: pending_client, + startup_snapshot: Some(Vec::new()), + startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), + tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), + }, + ); + + let timeout_result = + tokio::time::timeout(Duration::from_millis(10), manager.list_all_tools()).await; + let tools = timeout_result.expect("cache-hit startup snapshot should not block"); + assert!(tools.is_empty()); +} + +#[tokio::test] +async fn list_all_tools_uses_startup_snapshot_when_client_startup_fails() { + let startup_tools = vec![create_test_tool( + CODEX_APPS_MCP_SERVER_NAME, + "calendar_create_event", + )]; + let failed_client = futures::future::ready::>(Err( + StartupOutcomeError::Failed { + error: "startup failed".to_string(), + }, + )) + .boxed() + .shared(); + let approval_policy = Constrained::allow_any(AskForApproval::OnFailure); + let mut manager = McpConnectionManager::new_uninitialized(&approval_policy); + let startup_complete = Arc::new(std::sync::atomic::AtomicBool::new(true)); + manager.clients.insert( + CODEX_APPS_MCP_SERVER_NAME.to_string(), + AsyncManagedClient { + client: failed_client, + startup_snapshot: Some(startup_tools), + startup_complete, + tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), + }, + ); + + let tools = manager.list_all_tools().await; + let tool = tools + .get("mcp__codex_apps__calendar_create_event") + .expect("tool from startup cache"); + assert_eq!(tool.server_name, CODEX_APPS_MCP_SERVER_NAME); + assert_eq!(tool.tool_name, "calendar_create_event"); +} + +#[test] +fn elicitation_capability_enabled_only_for_codex_apps() { + let codex_apps_capability = elicitation_capability_for_server(CODEX_APPS_MCP_SERVER_NAME); + assert!(matches!( + codex_apps_capability, + Some(ElicitationCapability { + form: Some(FormElicitationCapability { + schema_validation: None + }), + url: None, + }) + )); + + assert!(elicitation_capability_for_server("custom_mcp").is_none()); +} + +#[test] +fn mcp_init_error_display_prompts_for_github_pat() { + let server_name = "github"; + let entry = McpAuthStatusEntry { + config: McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: "https://api.githubcopilot.com/mcp/".to_string(), + bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, + }, + enabled: true, + required: false, + disabled_reason: None, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + oauth_resource: None, + }, + auth_status: McpAuthStatus::Unsupported, + }; + let err: StartupOutcomeError = anyhow::anyhow!("OAuth is unsupported").into(); + + let display = mcp_init_error_display(server_name, Some(&entry), &err); + + let expected = format!( + "GitHub MCP does not support OAuth. Log in by adding a personal access token (https://github.com/settings/personal-access-tokens) to your environment and config.toml:\n[mcp_servers.{server_name}]\nbearer_token_env_var = CODEX_GITHUB_PERSONAL_ACCESS_TOKEN" + ); + + assert_eq!(expected, display); +} + +#[test] +fn mcp_init_error_display_prompts_for_login_when_auth_required() { + let server_name = "example"; + let err: StartupOutcomeError = anyhow::anyhow!("Auth required for server").into(); + + let display = mcp_init_error_display(server_name, None, &err); + + let expected = format!( + "The {server_name} MCP server is not logged in. Run `codex mcp login {server_name}`." + ); + + assert_eq!(expected, display); +} + +#[test] +fn mcp_init_error_display_reports_generic_errors() { + let server_name = "custom"; + let entry = McpAuthStatusEntry { + config: McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: "https://example.com".to_string(), + bearer_token_env_var: Some("TOKEN".to_string()), + http_headers: None, + env_http_headers: None, + }, + enabled: true, + required: false, + disabled_reason: None, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + oauth_resource: None, + }, + auth_status: McpAuthStatus::Unsupported, + }; + let err: StartupOutcomeError = anyhow::anyhow!("boom").into(); + + let display = mcp_init_error_display(server_name, Some(&entry), &err); + + let expected = format!("MCP client for `{server_name}` failed to start: {err:#}"); + + assert_eq!(expected, display); +} + +#[test] +fn mcp_init_error_display_includes_startup_timeout_hint() { + let server_name = "slow"; + let err: StartupOutcomeError = anyhow::anyhow!("request timed out").into(); + + let display = mcp_init_error_display(server_name, None, &err); + + assert_eq!( + "MCP client for `slow` timed out after 10 seconds. Add or adjust `startup_timeout_sec` in your config.toml:\n[mcp_servers.slow]\nstartup_timeout_sec = XX", + display + ); +} + +#[test] +fn transport_origin_extracts_http_origin() { + let transport = McpServerTransportConfig::StreamableHttp { + url: "https://example.com:8443/path?query=1".to_string(), + bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, + }; + + assert_eq!( + transport_origin(&transport), + Some("https://example.com:8443".to_string()) + ); +} + +#[test] +fn transport_origin_is_stdio_for_stdio_transport() { + let transport = McpServerTransportConfig::Stdio { + command: "server".to_string(), + args: Vec::new(), + env: None, + env_vars: Vec::new(), + cwd: None, + }; + + assert_eq!(transport_origin(&transport), Some("stdio".to_string())); +} diff --git a/codex-rs/core/src/mcp_tool_call.rs b/codex-rs/core/src/mcp_tool_call.rs index 888d86c5ff..7430312136 100644 --- a/codex-rs/core/src/mcp_tool_call.rs +++ b/codex-rs/core/src/mcp_tool_call.rs @@ -1234,921 +1234,5 @@ async fn notify_mcp_tool_call_skip( } #[cfg(test)] -mod tests { - use super::*; - use crate::codex::make_session_and_context; - use crate::config::ConfigToml; - use crate::config::types::AppConfig; - use crate::config::types::AppToolConfig; - use crate::config::types::AppToolsConfig; - use crate::config::types::AppsConfigToml; - use codex_config::CONFIG_TOML_FILE; - use pretty_assertions::assert_eq; - use serde::Deserialize; - use std::collections::HashMap; - use std::sync::Arc; - use tempfile::tempdir; - - fn annotations( - read_only: Option, - destructive: Option, - open_world: Option, - ) -> ToolAnnotations { - ToolAnnotations { - destructive_hint: destructive, - idempotent_hint: None, - open_world_hint: open_world, - read_only_hint: read_only, - title: None, - } - } - - fn approval_metadata( - connector_id: Option<&str>, - connector_name: Option<&str>, - connector_description: Option<&str>, - tool_title: Option<&str>, - tool_description: Option<&str>, - ) -> McpToolApprovalMetadata { - McpToolApprovalMetadata { - annotations: None, - connector_id: connector_id.map(str::to_string), - connector_name: connector_name.map(str::to_string), - connector_description: connector_description.map(str::to_string), - tool_title: tool_title.map(str::to_string), - tool_description: tool_description.map(str::to_string), - } - } - - fn prompt_options( - allow_session_remember: bool, - allow_persistent_approval: bool, - ) -> McpToolApprovalPromptOptions { - McpToolApprovalPromptOptions { - allow_session_remember, - allow_persistent_approval, - } - } - - #[test] - fn approval_required_when_read_only_false_and_destructive() { - let annotations = annotations(Some(false), Some(true), None); - assert_eq!(requires_mcp_tool_approval(&annotations), true); - } - - #[test] - fn approval_required_when_read_only_false_and_open_world() { - let annotations = annotations(Some(false), None, Some(true)); - assert_eq!(requires_mcp_tool_approval(&annotations), true); - } - - #[test] - fn approval_required_when_destructive_even_if_read_only_true() { - let annotations = annotations(Some(true), Some(true), Some(true)); - assert_eq!(requires_mcp_tool_approval(&annotations), true); - } - - #[test] - fn prompt_mode_does_not_allow_persistent_remember() { - assert_eq!( - normalize_approval_decision_for_mode( - McpToolApprovalDecision::AcceptForSession, - AppToolApproval::Prompt, - ), - McpToolApprovalDecision::Accept - ); - assert_eq!( - normalize_approval_decision_for_mode( - McpToolApprovalDecision::AcceptAndRemember, - AppToolApproval::Prompt, - ), - McpToolApprovalDecision::Accept - ); - } - - #[test] - fn approval_question_text_prepends_safety_reason() { - assert_eq!( - mcp_tool_approval_question_text( - "Allow this action?".to_string(), - Some("This tool may contact an external system."), - ), - "Tool call needs your approval. Reason: This tool may contact an external system." - ); - } - - #[tokio::test] - async fn approval_elicitation_request_uses_message_override_and_readable_tool_params() { - let (session, turn_context) = make_session_and_context().await; - let question = build_mcp_tool_approval_question( - "q".to_string(), - CODEX_APPS_MCP_SERVER_NAME, - "create_event", - Some("Calendar"), - prompt_options(true, true), - Some("Allow Calendar to create an event?"), - ); - - let request = build_mcp_tool_approval_elicitation_request( - &session, - &turn_context, - McpToolApprovalElicitationRequest { - server: CODEX_APPS_MCP_SERVER_NAME, - metadata: Some(&approval_metadata( - Some("calendar"), - Some("Calendar"), - Some("Manage events and schedules."), - Some("Create Event"), - Some("Create a calendar event."), - )), - tool_params: Some(&serde_json::json!({ - "Calendar": "primary", - "Title": "Roadmap review", - })), - tool_params_display: None, - question, - message_override: Some("Allow Calendar to create an event?"), - prompt_options: prompt_options(true, true), - }, - ); - - assert_eq!( - request, - McpServerElicitationRequestParams { - thread_id: session.conversation_id.to_string(), - turn_id: Some(turn_context.sub_id), - server_name: CODEX_APPS_MCP_SERVER_NAME.to_string(), - request: McpServerElicitationRequest::Form { - meta: Some(serde_json::json!({ - MCP_TOOL_APPROVAL_KIND_KEY: MCP_TOOL_APPROVAL_KIND_MCP_TOOL_CALL, - MCP_TOOL_APPROVAL_PERSIST_KEY: [ - MCP_TOOL_APPROVAL_PERSIST_SESSION, - MCP_TOOL_APPROVAL_PERSIST_ALWAYS, - ], - MCP_TOOL_APPROVAL_SOURCE_KEY: MCP_TOOL_APPROVAL_SOURCE_CONNECTOR, - MCP_TOOL_APPROVAL_CONNECTOR_ID_KEY: "calendar", - MCP_TOOL_APPROVAL_CONNECTOR_NAME_KEY: "Calendar", - MCP_TOOL_APPROVAL_CONNECTOR_DESCRIPTION_KEY: "Manage events and schedules.", - MCP_TOOL_APPROVAL_TOOL_TITLE_KEY: "Create Event", - MCP_TOOL_APPROVAL_TOOL_DESCRIPTION_KEY: "Create a calendar event.", - MCP_TOOL_APPROVAL_TOOL_PARAMS_KEY: { - "Calendar": "primary", - "Title": "Roadmap review", - }, - })), - message: "Allow Calendar to create an event?".to_string(), - requested_schema: McpElicitationSchema { - schema_uri: None, - type_: McpElicitationObjectType::Object, - properties: BTreeMap::new(), - required: None, - }, - }, - } - ); - } - - #[test] - fn custom_mcp_tool_question_mentions_server_name() { - let question = build_mcp_tool_approval_question( - "q".to_string(), - "custom_server", - "run_action", - None, - prompt_options(false, false), - None, - ); - - assert_eq!(question.header, "Approve app tool call?"); - assert_eq!( - question.question, - "Allow the custom_server MCP server to run tool \"run_action\"?" - ); - assert!( - !question - .options - .expect("options") - .into_iter() - .map(|option| option.label) - .any(|label| label == MCP_TOOL_APPROVAL_ACCEPT_AND_REMEMBER) - ); - } - - #[test] - fn codex_apps_tool_question_uses_fallback_app_label() { - let question = build_mcp_tool_approval_question( - "q".to_string(), - CODEX_APPS_MCP_SERVER_NAME, - "run_action", - None, - prompt_options(true, true), - None, - ); - - assert_eq!( - question.question, - "Allow this app to run tool \"run_action\"?" - ); - } - - #[test] - fn trusted_codex_apps_tool_question_offers_always_allow() { - let question = build_mcp_tool_approval_question( - "q".to_string(), - CODEX_APPS_MCP_SERVER_NAME, - "run_action", - Some("Calendar"), - prompt_options(true, true), - None, - ); - let options = question.options.expect("options"); - - assert!(options.iter().any(|option| { - option.label == MCP_TOOL_APPROVAL_ACCEPT_FOR_SESSION - && option.description == "Run the tool and remember this choice for this session." - })); - assert!(options.iter().any(|option| { - option.label == MCP_TOOL_APPROVAL_ACCEPT_AND_REMEMBER - && option.description - == "Run the tool and remember this choice for future tool calls." - })); - assert_eq!( - options - .into_iter() - .map(|option| option.label) - .collect::>(), - vec![ - MCP_TOOL_APPROVAL_ACCEPT.to_string(), - MCP_TOOL_APPROVAL_ACCEPT_FOR_SESSION.to_string(), - MCP_TOOL_APPROVAL_ACCEPT_AND_REMEMBER.to_string(), - MCP_TOOL_APPROVAL_CANCEL.to_string(), - ] - ); - } - - #[test] - fn codex_apps_tool_question_without_elicitation_omits_always_allow() { - let session_key = McpToolApprovalKey { - server: CODEX_APPS_MCP_SERVER_NAME.to_string(), - connector_id: Some("calendar".to_string()), - tool_name: "run_action".to_string(), - }; - let persistent_key = session_key.clone(); - let question = build_mcp_tool_approval_question( - "q".to_string(), - CODEX_APPS_MCP_SERVER_NAME, - "run_action", - Some("Calendar"), - mcp_tool_approval_prompt_options(Some(&session_key), Some(&persistent_key), false), - None, - ); - - assert_eq!( - question - .options - .expect("options") - .into_iter() - .map(|option| option.label) - .collect::>(), - vec![ - MCP_TOOL_APPROVAL_ACCEPT.to_string(), - MCP_TOOL_APPROVAL_ACCEPT_FOR_SESSION.to_string(), - MCP_TOOL_APPROVAL_CANCEL.to_string(), - ] - ); - } - - #[test] - fn custom_mcp_tool_question_offers_session_remember_without_always_allow() { - let question = build_mcp_tool_approval_question( - "q".to_string(), - "custom_server", - "run_action", - None, - prompt_options(true, false), - None, - ); - - assert_eq!( - question - .options - .expect("options") - .into_iter() - .map(|option| option.label) - .collect::>(), - vec![ - MCP_TOOL_APPROVAL_ACCEPT.to_string(), - MCP_TOOL_APPROVAL_ACCEPT_FOR_SESSION.to_string(), - MCP_TOOL_APPROVAL_CANCEL.to_string(), - ] - ); - } - - #[test] - fn custom_servers_keep_session_remember_without_persistent_approval() { - let invocation = McpInvocation { - server: "custom_server".to_string(), - tool: "run_action".to_string(), - arguments: None, - }; - let expected = McpToolApprovalKey { - server: "custom_server".to_string(), - connector_id: None, - tool_name: "run_action".to_string(), - }; - - assert_eq!( - session_mcp_tool_approval_key(&invocation, None, AppToolApproval::Auto), - Some(expected) - ); - assert_eq!( - persistent_mcp_tool_approval_key(&invocation, None, AppToolApproval::Auto), - None - ); - } - - #[test] - fn codex_apps_connectors_support_persistent_approval() { - let invocation = McpInvocation { - server: CODEX_APPS_MCP_SERVER_NAME.to_string(), - tool: "calendar/list_events".to_string(), - arguments: None, - }; - let metadata = approval_metadata(Some("calendar"), Some("Calendar"), None, None, None); - let expected = McpToolApprovalKey { - server: CODEX_APPS_MCP_SERVER_NAME.to_string(), - connector_id: Some("calendar".to_string()), - tool_name: "calendar/list_events".to_string(), - }; - - assert_eq!( - session_mcp_tool_approval_key(&invocation, Some(&metadata), AppToolApproval::Auto), - Some(expected.clone()) - ); - assert_eq!( - persistent_mcp_tool_approval_key(&invocation, Some(&metadata), AppToolApproval::Auto), - Some(expected) - ); - } - - #[test] - fn sanitize_mcp_tool_result_for_model_rewrites_image_content() { - let result = Ok(CallToolResult { - content: vec![ - serde_json::json!({ - "type": "image", - "data": "Zm9v", - "mimeType": "image/png", - }), - serde_json::json!({ - "type": "text", - "text": "hello", - }), - ], - structured_content: None, - is_error: Some(false), - meta: None, - }); - - let got = sanitize_mcp_tool_result_for_model(false, result).expect("sanitized result"); - - assert_eq!( - got.content, - vec![ - serde_json::json!({ - "type": "text", - "text": "", - }), - serde_json::json!({ - "type": "text", - "text": "hello", - }), - ] - ); - } - - #[test] - fn sanitize_mcp_tool_result_for_model_preserves_image_when_supported() { - let original = CallToolResult { - content: vec![serde_json::json!({ - "type": "image", - "data": "Zm9v", - "mimeType": "image/png", - })], - structured_content: Some(serde_json::json!({"x": 1})), - is_error: Some(false), - meta: Some(serde_json::json!({"k": "v"})), - }; - - let got = sanitize_mcp_tool_result_for_model(true, Ok(original.clone())) - .expect("unsanitized result"); - - assert_eq!(got, original); - } - - #[test] - fn accepted_elicitation_content_converts_to_request_user_input_response() { - let response = - request_user_input_response_from_elicitation_content(Some(serde_json::json!( - { - "approval": MCP_TOOL_APPROVAL_ACCEPT_AND_REMEMBER, - } - ))); - - assert_eq!( - response, - Some(RequestUserInputResponse { - answers: std::collections::HashMap::from([( - "approval".to_string(), - RequestUserInputAnswer { - answers: vec![MCP_TOOL_APPROVAL_ACCEPT_AND_REMEMBER.to_string()], - }, - )]), - }) - ); - } - - #[test] - fn approval_elicitation_meta_marks_tool_approvals() { - assert_eq!( - build_mcp_tool_approval_elicitation_meta( - "custom_server", - None, - None, - None, - prompt_options(false, false), - ), - Some(serde_json::json!({ - MCP_TOOL_APPROVAL_KIND_KEY: MCP_TOOL_APPROVAL_KIND_MCP_TOOL_CALL, - })) - ); - } - - #[test] - fn approval_elicitation_meta_keeps_session_persist_behavior_for_custom_servers() { - assert_eq!( - build_mcp_tool_approval_elicitation_meta( - "custom_server", - Some(&approval_metadata( - None, - None, - None, - Some("Run Action"), - Some("Runs the selected action."), - )), - Some(&serde_json::json!({"id": 1})), - None, - prompt_options(true, false), - ), - Some(serde_json::json!({ - MCP_TOOL_APPROVAL_KIND_KEY: MCP_TOOL_APPROVAL_KIND_MCP_TOOL_CALL, - MCP_TOOL_APPROVAL_PERSIST_KEY: MCP_TOOL_APPROVAL_PERSIST_SESSION, - MCP_TOOL_APPROVAL_TOOL_TITLE_KEY: "Run Action", - MCP_TOOL_APPROVAL_TOOL_DESCRIPTION_KEY: "Runs the selected action.", - MCP_TOOL_APPROVAL_TOOL_PARAMS_KEY: { - "id": 1, - }, - })) - ); - } - - #[test] - fn guardian_mcp_review_request_includes_invocation_metadata() { - let invocation = McpInvocation { - server: CODEX_APPS_MCP_SERVER_NAME.to_string(), - tool: "browser_navigate".to_string(), - arguments: Some(serde_json::json!({ - "url": "https://example.com", - })), - }; - - let request = build_guardian_mcp_tool_review_request( - &invocation, - Some(&approval_metadata( - Some("playwright"), - Some("Playwright"), - Some("Browser automation"), - Some("Navigate"), - Some("Open a page"), - )), - ); - - assert_eq!( - request, - GuardianApprovalRequest::McpToolCall { - server: CODEX_APPS_MCP_SERVER_NAME.to_string(), - tool_name: "browser_navigate".to_string(), - arguments: Some(serde_json::json!({ - "url": "https://example.com", - })), - connector_id: Some("playwright".to_string()), - connector_name: Some("Playwright".to_string()), - connector_description: Some("Browser automation".to_string()), - tool_title: Some("Navigate".to_string()), - tool_description: Some("Open a page".to_string()), - annotations: None, - } - ); - } - - #[test] - fn guardian_mcp_review_request_includes_annotations_when_present() { - let invocation = McpInvocation { - server: "custom_server".to_string(), - tool: "dangerous_tool".to_string(), - arguments: None, - }; - let metadata = McpToolApprovalMetadata { - annotations: Some(annotations(Some(false), Some(true), Some(true))), - connector_id: None, - connector_name: None, - connector_description: None, - tool_title: None, - tool_description: None, - }; - - let request = build_guardian_mcp_tool_review_request(&invocation, Some(&metadata)); - - assert_eq!( - request, - GuardianApprovalRequest::McpToolCall { - server: "custom_server".to_string(), - tool_name: "dangerous_tool".to_string(), - arguments: None, - connector_id: None, - connector_name: None, - connector_description: None, - tool_title: None, - tool_description: None, - annotations: Some(GuardianMcpAnnotations { - destructive_hint: Some(true), - open_world_hint: Some(true), - read_only_hint: Some(false), - }), - } - ); - } - - #[test] - fn prepare_arc_request_action_serializes_mcp_tool_call_shape() { - let invocation = McpInvocation { - server: CODEX_APPS_MCP_SERVER_NAME.to_string(), - tool: "browser_navigate".to_string(), - arguments: Some(serde_json::json!({ - "url": "https://example.com", - })), - }; - - let action = prepare_arc_request_action( - &invocation, - Some(&approval_metadata( - None, - Some("Playwright"), - None, - Some("Navigate"), - None, - )), - ); - - assert_eq!( - action, - serde_json::json!({ - "tool": "mcp_tool_call", - "server": CODEX_APPS_MCP_SERVER_NAME, - "tool_name": "browser_navigate", - "arguments": { - "url": "https://example.com", - }, - "connector_name": "Playwright", - "tool_title": "Navigate", - }) - ); - } - - #[test] - fn guardian_review_decision_maps_to_mcp_tool_decision() { - assert_eq!( - mcp_tool_approval_decision_from_guardian(ReviewDecision::Approved), - McpToolApprovalDecision::Accept - ); - assert_eq!( - mcp_tool_approval_decision_from_guardian(ReviewDecision::Denied), - McpToolApprovalDecision::Decline - ); - assert_eq!( - mcp_tool_approval_decision_from_guardian(ReviewDecision::Abort), - McpToolApprovalDecision::Decline - ); - } - - #[test] - fn approval_elicitation_meta_includes_connector_source_for_codex_apps() { - assert_eq!( - build_mcp_tool_approval_elicitation_meta( - CODEX_APPS_MCP_SERVER_NAME, - Some(&approval_metadata( - Some("calendar"), - Some("Calendar"), - Some("Manage events and schedules."), - Some("Run Action"), - Some("Runs the selected action."), - )), - Some(&serde_json::json!({ - "calendar_id": "primary", - })), - None, - prompt_options(false, false), - ), - Some(serde_json::json!({ - MCP_TOOL_APPROVAL_KIND_KEY: MCP_TOOL_APPROVAL_KIND_MCP_TOOL_CALL, - MCP_TOOL_APPROVAL_SOURCE_KEY: MCP_TOOL_APPROVAL_SOURCE_CONNECTOR, - MCP_TOOL_APPROVAL_CONNECTOR_ID_KEY: "calendar", - MCP_TOOL_APPROVAL_CONNECTOR_NAME_KEY: "Calendar", - MCP_TOOL_APPROVAL_CONNECTOR_DESCRIPTION_KEY: "Manage events and schedules.", - MCP_TOOL_APPROVAL_TOOL_TITLE_KEY: "Run Action", - MCP_TOOL_APPROVAL_TOOL_DESCRIPTION_KEY: "Runs the selected action.", - MCP_TOOL_APPROVAL_TOOL_PARAMS_KEY: { - "calendar_id": "primary", - }, - })) - ); - } - - #[test] - fn approval_elicitation_meta_merges_session_and_always_persist_with_connector_source() { - assert_eq!( - build_mcp_tool_approval_elicitation_meta( - CODEX_APPS_MCP_SERVER_NAME, - Some(&approval_metadata( - Some("calendar"), - Some("Calendar"), - Some("Manage events and schedules."), - Some("Run Action"), - Some("Runs the selected action."), - )), - Some(&serde_json::json!({ - "calendar_id": "primary", - })), - None, - prompt_options(true, true), - ), - Some(serde_json::json!({ - MCP_TOOL_APPROVAL_KIND_KEY: MCP_TOOL_APPROVAL_KIND_MCP_TOOL_CALL, - MCP_TOOL_APPROVAL_PERSIST_KEY: [ - MCP_TOOL_APPROVAL_PERSIST_SESSION, - MCP_TOOL_APPROVAL_PERSIST_ALWAYS, - ], - MCP_TOOL_APPROVAL_SOURCE_KEY: MCP_TOOL_APPROVAL_SOURCE_CONNECTOR, - MCP_TOOL_APPROVAL_CONNECTOR_ID_KEY: "calendar", - MCP_TOOL_APPROVAL_CONNECTOR_NAME_KEY: "Calendar", - MCP_TOOL_APPROVAL_CONNECTOR_DESCRIPTION_KEY: "Manage events and schedules.", - MCP_TOOL_APPROVAL_TOOL_TITLE_KEY: "Run Action", - MCP_TOOL_APPROVAL_TOOL_DESCRIPTION_KEY: "Runs the selected action.", - MCP_TOOL_APPROVAL_TOOL_PARAMS_KEY: { - "calendar_id": "primary", - }, - })) - ); - } - - #[test] - fn declined_elicitation_response_stays_decline() { - let response = parse_mcp_tool_approval_elicitation_response( - Some(ElicitationResponse { - action: ElicitationAction::Decline, - content: Some(serde_json::json!({ - "approval": MCP_TOOL_APPROVAL_ACCEPT, - })), - meta: None, - }), - "approval", - ); - - assert_eq!(response, McpToolApprovalDecision::Decline); - } - - #[test] - fn accepted_elicitation_response_uses_always_persist_meta() { - let response = parse_mcp_tool_approval_elicitation_response( - Some(ElicitationResponse { - action: ElicitationAction::Accept, - content: None, - meta: Some(serde_json::json!({ - MCP_TOOL_APPROVAL_PERSIST_KEY: MCP_TOOL_APPROVAL_PERSIST_ALWAYS, - })), - }), - "approval", - ); - - assert_eq!(response, McpToolApprovalDecision::AcceptAndRemember); - } - - #[test] - fn accepted_elicitation_response_uses_session_persist_meta() { - let response = parse_mcp_tool_approval_elicitation_response( - Some(ElicitationResponse { - action: ElicitationAction::Accept, - content: None, - meta: Some(serde_json::json!({ - MCP_TOOL_APPROVAL_PERSIST_KEY: MCP_TOOL_APPROVAL_PERSIST_SESSION, - })), - }), - "approval", - ); - - assert_eq!(response, McpToolApprovalDecision::AcceptForSession); - } - - #[test] - fn accepted_elicitation_without_content_defaults_to_accept() { - let response = parse_mcp_tool_approval_elicitation_response( - Some(ElicitationResponse { - action: ElicitationAction::Accept, - content: None, - meta: None, - }), - "approval", - ); - - assert_eq!(response, McpToolApprovalDecision::Accept); - } - - #[tokio::test] - async fn persist_codex_app_tool_approval_writes_tool_override() { - let tmp = tempdir().expect("tempdir"); - - persist_codex_app_tool_approval(tmp.path(), "calendar", "calendar/list_events") - .await - .expect("persist approval"); - - let contents = - std::fs::read_to_string(tmp.path().join(CONFIG_TOML_FILE)).expect("read config"); - let parsed: ConfigToml = toml::from_str(&contents).expect("parse config"); - - assert_eq!( - parsed.apps, - Some(AppsConfigToml { - default: None, - apps: HashMap::from([( - "calendar".to_string(), - AppConfig { - enabled: true, - destructive_enabled: None, - open_world_enabled: None, - default_tools_approval_mode: None, - default_tools_enabled: None, - tools: Some(AppToolsConfig { - tools: HashMap::from([( - "calendar/list_events".to_string(), - AppToolConfig { - enabled: None, - approval_mode: Some(AppToolApproval::Approve), - }, - )]), - }), - }, - )]), - }) - ); - assert!(contents.contains("[apps.calendar.tools.\"calendar/list_events\"]")); - } - - #[tokio::test] - async fn maybe_persist_mcp_tool_approval_reloads_session_config() { - let (session, turn_context) = make_session_and_context().await; - let codex_home = session.codex_home().await; - std::fs::create_dir_all(&codex_home).expect("create codex home"); - let key = McpToolApprovalKey { - server: CODEX_APPS_MCP_SERVER_NAME.to_string(), - connector_id: Some("calendar".to_string()), - tool_name: "calendar/list_events".to_string(), - }; - - maybe_persist_mcp_tool_approval(&session, &turn_context, key.clone()).await; - - let config = session.get_config().await; - let apps_toml = config - .config_layer_stack - .effective_config() - .as_table() - .and_then(|table| table.get("apps")) - .cloned() - .expect("apps table"); - let apps = AppsConfigToml::deserialize(apps_toml).expect("deserialize apps config"); - let tool = apps - .apps - .get("calendar") - .and_then(|app| app.tools.as_ref()) - .and_then(|tools| tools.tools.get("calendar/list_events")) - .expect("calendar/list_events tool config exists"); - - assert_eq!( - tool, - &AppToolConfig { - enabled: None, - approval_mode: Some(AppToolApproval::Approve), - } - ); - assert_eq!(mcp_tool_approval_is_remembered(&session, &key).await, true); - } - - #[tokio::test] - async fn approve_mode_skips_when_annotations_do_not_require_approval() { - let (session, turn_context) = make_session_and_context().await; - let session = Arc::new(session); - let turn_context = Arc::new(turn_context); - let invocation = McpInvocation { - server: "custom_server".to_string(), - tool: "read_only_tool".to_string(), - arguments: None, - }; - let metadata = McpToolApprovalMetadata { - annotations: Some(annotations(Some(true), None, None)), - connector_id: None, - connector_name: None, - connector_description: None, - tool_title: Some("Read Only Tool".to_string()), - tool_description: None, - }; - - let decision = maybe_request_mcp_tool_approval( - &session, - &turn_context, - "call-1", - &invocation, - Some(&metadata), - AppToolApproval::Approve, - ) - .await; - - assert_eq!(decision, None); - } - - #[tokio::test] - async fn approve_mode_blocks_when_arc_returns_interrupt_for_model() { - use wiremock::Mock; - use wiremock::MockServer; - use wiremock::ResponseTemplate; - use wiremock::matchers::method; - use wiremock::matchers::path; - - let server = MockServer::start().await; - Mock::given(method("POST")) - .and(path("/codex/safety/arc")) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "outcome": "steer-model", - "short_reason": "needs approval", - "rationale": "high-risk action", - "risk_score": 96, - "risk_level": "critical", - "evidence": [{ - "message": "dangerous_tool", - "why": "high-risk action", - }], - }))) - .expect(1) - .mount(&server) - .await; - - let (session, mut turn_context) = make_session_and_context().await; - turn_context.auth_manager = Some(crate::test_support::auth_manager_from_auth( - crate::CodexAuth::create_dummy_chatgpt_auth_for_testing(), - )); - let mut config = (*turn_context.config).clone(); - config.chatgpt_base_url = server.uri(); - turn_context.config = Arc::new(config); - - let session = Arc::new(session); - let turn_context = Arc::new(turn_context); - let invocation = McpInvocation { - server: CODEX_APPS_MCP_SERVER_NAME.to_string(), - tool: "dangerous_tool".to_string(), - arguments: Some(serde_json::json!({ "id": 1 })), - }; - let metadata = McpToolApprovalMetadata { - annotations: Some(annotations(Some(false), Some(true), Some(true))), - connector_id: Some("calendar".to_string()), - connector_name: Some("Calendar".to_string()), - connector_description: Some("Manage events".to_string()), - tool_title: Some("Dangerous Tool".to_string()), - tool_description: Some("Performs a risky action.".to_string()), - }; - - let decision = maybe_request_mcp_tool_approval( - &session, - &turn_context, - "call-2", - &invocation, - Some(&metadata), - AppToolApproval::Approve, - ) - .await; - - assert_eq!( - decision, - Some(McpToolApprovalDecision::BlockedBySafetyMonitor( - "Tool call was cancelled because of safety risks: high-risk action".to_string(), - )) - ); - } -} +#[path = "mcp_tool_call_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/mcp_tool_call_tests.rs b/codex-rs/core/src/mcp_tool_call_tests.rs new file mode 100644 index 0000000000..5e8cb7f873 --- /dev/null +++ b/codex-rs/core/src/mcp_tool_call_tests.rs @@ -0,0 +1,913 @@ +use super::*; +use crate::codex::make_session_and_context; +use crate::config::ConfigToml; +use crate::config::types::AppConfig; +use crate::config::types::AppToolConfig; +use crate::config::types::AppToolsConfig; +use crate::config::types::AppsConfigToml; +use codex_config::CONFIG_TOML_FILE; +use pretty_assertions::assert_eq; +use serde::Deserialize; +use std::collections::HashMap; +use std::sync::Arc; +use tempfile::tempdir; + +fn annotations( + read_only: Option, + destructive: Option, + open_world: Option, +) -> ToolAnnotations { + ToolAnnotations { + destructive_hint: destructive, + idempotent_hint: None, + open_world_hint: open_world, + read_only_hint: read_only, + title: None, + } +} + +fn approval_metadata( + connector_id: Option<&str>, + connector_name: Option<&str>, + connector_description: Option<&str>, + tool_title: Option<&str>, + tool_description: Option<&str>, +) -> McpToolApprovalMetadata { + McpToolApprovalMetadata { + annotations: None, + connector_id: connector_id.map(str::to_string), + connector_name: connector_name.map(str::to_string), + connector_description: connector_description.map(str::to_string), + tool_title: tool_title.map(str::to_string), + tool_description: tool_description.map(str::to_string), + } +} + +fn prompt_options( + allow_session_remember: bool, + allow_persistent_approval: bool, +) -> McpToolApprovalPromptOptions { + McpToolApprovalPromptOptions { + allow_session_remember, + allow_persistent_approval, + } +} + +#[test] +fn approval_required_when_read_only_false_and_destructive() { + let annotations = annotations(Some(false), Some(true), None); + assert_eq!(requires_mcp_tool_approval(&annotations), true); +} + +#[test] +fn approval_required_when_read_only_false_and_open_world() { + let annotations = annotations(Some(false), None, Some(true)); + assert_eq!(requires_mcp_tool_approval(&annotations), true); +} + +#[test] +fn approval_required_when_destructive_even_if_read_only_true() { + let annotations = annotations(Some(true), Some(true), Some(true)); + assert_eq!(requires_mcp_tool_approval(&annotations), true); +} + +#[test] +fn prompt_mode_does_not_allow_persistent_remember() { + assert_eq!( + normalize_approval_decision_for_mode( + McpToolApprovalDecision::AcceptForSession, + AppToolApproval::Prompt, + ), + McpToolApprovalDecision::Accept + ); + assert_eq!( + normalize_approval_decision_for_mode( + McpToolApprovalDecision::AcceptAndRemember, + AppToolApproval::Prompt, + ), + McpToolApprovalDecision::Accept + ); +} + +#[test] +fn approval_question_text_prepends_safety_reason() { + assert_eq!( + mcp_tool_approval_question_text( + "Allow this action?".to_string(), + Some("This tool may contact an external system."), + ), + "Tool call needs your approval. Reason: This tool may contact an external system." + ); +} + +#[tokio::test] +async fn approval_elicitation_request_uses_message_override_and_readable_tool_params() { + let (session, turn_context) = make_session_and_context().await; + let question = build_mcp_tool_approval_question( + "q".to_string(), + CODEX_APPS_MCP_SERVER_NAME, + "create_event", + Some("Calendar"), + prompt_options(true, true), + Some("Allow Calendar to create an event?"), + ); + + let request = build_mcp_tool_approval_elicitation_request( + &session, + &turn_context, + McpToolApprovalElicitationRequest { + server: CODEX_APPS_MCP_SERVER_NAME, + metadata: Some(&approval_metadata( + Some("calendar"), + Some("Calendar"), + Some("Manage events and schedules."), + Some("Create Event"), + Some("Create a calendar event."), + )), + tool_params: Some(&serde_json::json!({ + "Calendar": "primary", + "Title": "Roadmap review", + })), + tool_params_display: None, + question, + message_override: Some("Allow Calendar to create an event?"), + prompt_options: prompt_options(true, true), + }, + ); + + assert_eq!( + request, + McpServerElicitationRequestParams { + thread_id: session.conversation_id.to_string(), + turn_id: Some(turn_context.sub_id), + server_name: CODEX_APPS_MCP_SERVER_NAME.to_string(), + request: McpServerElicitationRequest::Form { + meta: Some(serde_json::json!({ + MCP_TOOL_APPROVAL_KIND_KEY: MCP_TOOL_APPROVAL_KIND_MCP_TOOL_CALL, + MCP_TOOL_APPROVAL_PERSIST_KEY: [ + MCP_TOOL_APPROVAL_PERSIST_SESSION, + MCP_TOOL_APPROVAL_PERSIST_ALWAYS, + ], + MCP_TOOL_APPROVAL_SOURCE_KEY: MCP_TOOL_APPROVAL_SOURCE_CONNECTOR, + MCP_TOOL_APPROVAL_CONNECTOR_ID_KEY: "calendar", + MCP_TOOL_APPROVAL_CONNECTOR_NAME_KEY: "Calendar", + MCP_TOOL_APPROVAL_CONNECTOR_DESCRIPTION_KEY: "Manage events and schedules.", + MCP_TOOL_APPROVAL_TOOL_TITLE_KEY: "Create Event", + MCP_TOOL_APPROVAL_TOOL_DESCRIPTION_KEY: "Create a calendar event.", + MCP_TOOL_APPROVAL_TOOL_PARAMS_KEY: { + "Calendar": "primary", + "Title": "Roadmap review", + }, + })), + message: "Allow Calendar to create an event?".to_string(), + requested_schema: McpElicitationSchema { + schema_uri: None, + type_: McpElicitationObjectType::Object, + properties: BTreeMap::new(), + required: None, + }, + }, + } + ); +} + +#[test] +fn custom_mcp_tool_question_mentions_server_name() { + let question = build_mcp_tool_approval_question( + "q".to_string(), + "custom_server", + "run_action", + None, + prompt_options(false, false), + None, + ); + + assert_eq!(question.header, "Approve app tool call?"); + assert_eq!( + question.question, + "Allow the custom_server MCP server to run tool \"run_action\"?" + ); + assert!( + !question + .options + .expect("options") + .into_iter() + .map(|option| option.label) + .any(|label| label == MCP_TOOL_APPROVAL_ACCEPT_AND_REMEMBER) + ); +} + +#[test] +fn codex_apps_tool_question_uses_fallback_app_label() { + let question = build_mcp_tool_approval_question( + "q".to_string(), + CODEX_APPS_MCP_SERVER_NAME, + "run_action", + None, + prompt_options(true, true), + None, + ); + + assert_eq!( + question.question, + "Allow this app to run tool \"run_action\"?" + ); +} + +#[test] +fn trusted_codex_apps_tool_question_offers_always_allow() { + let question = build_mcp_tool_approval_question( + "q".to_string(), + CODEX_APPS_MCP_SERVER_NAME, + "run_action", + Some("Calendar"), + prompt_options(true, true), + None, + ); + let options = question.options.expect("options"); + + assert!(options.iter().any(|option| { + option.label == MCP_TOOL_APPROVAL_ACCEPT_FOR_SESSION + && option.description == "Run the tool and remember this choice for this session." + })); + assert!(options.iter().any(|option| { + option.label == MCP_TOOL_APPROVAL_ACCEPT_AND_REMEMBER + && option.description == "Run the tool and remember this choice for future tool calls." + })); + assert_eq!( + options + .into_iter() + .map(|option| option.label) + .collect::>(), + vec![ + MCP_TOOL_APPROVAL_ACCEPT.to_string(), + MCP_TOOL_APPROVAL_ACCEPT_FOR_SESSION.to_string(), + MCP_TOOL_APPROVAL_ACCEPT_AND_REMEMBER.to_string(), + MCP_TOOL_APPROVAL_CANCEL.to_string(), + ] + ); +} + +#[test] +fn codex_apps_tool_question_without_elicitation_omits_always_allow() { + let session_key = McpToolApprovalKey { + server: CODEX_APPS_MCP_SERVER_NAME.to_string(), + connector_id: Some("calendar".to_string()), + tool_name: "run_action".to_string(), + }; + let persistent_key = session_key.clone(); + let question = build_mcp_tool_approval_question( + "q".to_string(), + CODEX_APPS_MCP_SERVER_NAME, + "run_action", + Some("Calendar"), + mcp_tool_approval_prompt_options(Some(&session_key), Some(&persistent_key), false), + None, + ); + + assert_eq!( + question + .options + .expect("options") + .into_iter() + .map(|option| option.label) + .collect::>(), + vec![ + MCP_TOOL_APPROVAL_ACCEPT.to_string(), + MCP_TOOL_APPROVAL_ACCEPT_FOR_SESSION.to_string(), + MCP_TOOL_APPROVAL_CANCEL.to_string(), + ] + ); +} + +#[test] +fn custom_mcp_tool_question_offers_session_remember_without_always_allow() { + let question = build_mcp_tool_approval_question( + "q".to_string(), + "custom_server", + "run_action", + None, + prompt_options(true, false), + None, + ); + + assert_eq!( + question + .options + .expect("options") + .into_iter() + .map(|option| option.label) + .collect::>(), + vec![ + MCP_TOOL_APPROVAL_ACCEPT.to_string(), + MCP_TOOL_APPROVAL_ACCEPT_FOR_SESSION.to_string(), + MCP_TOOL_APPROVAL_CANCEL.to_string(), + ] + ); +} + +#[test] +fn custom_servers_keep_session_remember_without_persistent_approval() { + let invocation = McpInvocation { + server: "custom_server".to_string(), + tool: "run_action".to_string(), + arguments: None, + }; + let expected = McpToolApprovalKey { + server: "custom_server".to_string(), + connector_id: None, + tool_name: "run_action".to_string(), + }; + + assert_eq!( + session_mcp_tool_approval_key(&invocation, None, AppToolApproval::Auto), + Some(expected) + ); + assert_eq!( + persistent_mcp_tool_approval_key(&invocation, None, AppToolApproval::Auto), + None + ); +} + +#[test] +fn codex_apps_connectors_support_persistent_approval() { + let invocation = McpInvocation { + server: CODEX_APPS_MCP_SERVER_NAME.to_string(), + tool: "calendar/list_events".to_string(), + arguments: None, + }; + let metadata = approval_metadata(Some("calendar"), Some("Calendar"), None, None, None); + let expected = McpToolApprovalKey { + server: CODEX_APPS_MCP_SERVER_NAME.to_string(), + connector_id: Some("calendar".to_string()), + tool_name: "calendar/list_events".to_string(), + }; + + assert_eq!( + session_mcp_tool_approval_key(&invocation, Some(&metadata), AppToolApproval::Auto), + Some(expected.clone()) + ); + assert_eq!( + persistent_mcp_tool_approval_key(&invocation, Some(&metadata), AppToolApproval::Auto), + Some(expected) + ); +} + +#[test] +fn sanitize_mcp_tool_result_for_model_rewrites_image_content() { + let result = Ok(CallToolResult { + content: vec![ + serde_json::json!({ + "type": "image", + "data": "Zm9v", + "mimeType": "image/png", + }), + serde_json::json!({ + "type": "text", + "text": "hello", + }), + ], + structured_content: None, + is_error: Some(false), + meta: None, + }); + + let got = sanitize_mcp_tool_result_for_model(false, result).expect("sanitized result"); + + assert_eq!( + got.content, + vec![ + serde_json::json!({ + "type": "text", + "text": "", + }), + serde_json::json!({ + "type": "text", + "text": "hello", + }), + ] + ); +} + +#[test] +fn sanitize_mcp_tool_result_for_model_preserves_image_when_supported() { + let original = CallToolResult { + content: vec![serde_json::json!({ + "type": "image", + "data": "Zm9v", + "mimeType": "image/png", + })], + structured_content: Some(serde_json::json!({"x": 1})), + is_error: Some(false), + meta: Some(serde_json::json!({"k": "v"})), + }; + + let got = + sanitize_mcp_tool_result_for_model(true, Ok(original.clone())).expect("unsanitized result"); + + assert_eq!(got, original); +} + +#[test] +fn accepted_elicitation_content_converts_to_request_user_input_response() { + let response = request_user_input_response_from_elicitation_content(Some(serde_json::json!( + { + "approval": MCP_TOOL_APPROVAL_ACCEPT_AND_REMEMBER, + } + ))); + + assert_eq!( + response, + Some(RequestUserInputResponse { + answers: std::collections::HashMap::from([( + "approval".to_string(), + RequestUserInputAnswer { + answers: vec![MCP_TOOL_APPROVAL_ACCEPT_AND_REMEMBER.to_string()], + }, + )]), + }) + ); +} + +#[test] +fn approval_elicitation_meta_marks_tool_approvals() { + assert_eq!( + build_mcp_tool_approval_elicitation_meta( + "custom_server", + None, + None, + None, + prompt_options(false, false), + ), + Some(serde_json::json!({ + MCP_TOOL_APPROVAL_KIND_KEY: MCP_TOOL_APPROVAL_KIND_MCP_TOOL_CALL, + })) + ); +} + +#[test] +fn approval_elicitation_meta_keeps_session_persist_behavior_for_custom_servers() { + assert_eq!( + build_mcp_tool_approval_elicitation_meta( + "custom_server", + Some(&approval_metadata( + None, + None, + None, + Some("Run Action"), + Some("Runs the selected action."), + )), + Some(&serde_json::json!({"id": 1})), + None, + prompt_options(true, false), + ), + Some(serde_json::json!({ + MCP_TOOL_APPROVAL_KIND_KEY: MCP_TOOL_APPROVAL_KIND_MCP_TOOL_CALL, + MCP_TOOL_APPROVAL_PERSIST_KEY: MCP_TOOL_APPROVAL_PERSIST_SESSION, + MCP_TOOL_APPROVAL_TOOL_TITLE_KEY: "Run Action", + MCP_TOOL_APPROVAL_TOOL_DESCRIPTION_KEY: "Runs the selected action.", + MCP_TOOL_APPROVAL_TOOL_PARAMS_KEY: { + "id": 1, + }, + })) + ); +} + +#[test] +fn guardian_mcp_review_request_includes_invocation_metadata() { + let invocation = McpInvocation { + server: CODEX_APPS_MCP_SERVER_NAME.to_string(), + tool: "browser_navigate".to_string(), + arguments: Some(serde_json::json!({ + "url": "https://example.com", + })), + }; + + let request = build_guardian_mcp_tool_review_request( + &invocation, + Some(&approval_metadata( + Some("playwright"), + Some("Playwright"), + Some("Browser automation"), + Some("Navigate"), + Some("Open a page"), + )), + ); + + assert_eq!( + request, + GuardianApprovalRequest::McpToolCall { + server: CODEX_APPS_MCP_SERVER_NAME.to_string(), + tool_name: "browser_navigate".to_string(), + arguments: Some(serde_json::json!({ + "url": "https://example.com", + })), + connector_id: Some("playwright".to_string()), + connector_name: Some("Playwright".to_string()), + connector_description: Some("Browser automation".to_string()), + tool_title: Some("Navigate".to_string()), + tool_description: Some("Open a page".to_string()), + annotations: None, + } + ); +} + +#[test] +fn guardian_mcp_review_request_includes_annotations_when_present() { + let invocation = McpInvocation { + server: "custom_server".to_string(), + tool: "dangerous_tool".to_string(), + arguments: None, + }; + let metadata = McpToolApprovalMetadata { + annotations: Some(annotations(Some(false), Some(true), Some(true))), + connector_id: None, + connector_name: None, + connector_description: None, + tool_title: None, + tool_description: None, + }; + + let request = build_guardian_mcp_tool_review_request(&invocation, Some(&metadata)); + + assert_eq!( + request, + GuardianApprovalRequest::McpToolCall { + server: "custom_server".to_string(), + tool_name: "dangerous_tool".to_string(), + arguments: None, + connector_id: None, + connector_name: None, + connector_description: None, + tool_title: None, + tool_description: None, + annotations: Some(GuardianMcpAnnotations { + destructive_hint: Some(true), + open_world_hint: Some(true), + read_only_hint: Some(false), + }), + } + ); +} + +#[test] +fn prepare_arc_request_action_serializes_mcp_tool_call_shape() { + let invocation = McpInvocation { + server: CODEX_APPS_MCP_SERVER_NAME.to_string(), + tool: "browser_navigate".to_string(), + arguments: Some(serde_json::json!({ + "url": "https://example.com", + })), + }; + + let action = prepare_arc_request_action( + &invocation, + Some(&approval_metadata( + None, + Some("Playwright"), + None, + Some("Navigate"), + None, + )), + ); + + assert_eq!( + action, + serde_json::json!({ + "tool": "mcp_tool_call", + "server": CODEX_APPS_MCP_SERVER_NAME, + "tool_name": "browser_navigate", + "arguments": { + "url": "https://example.com", + }, + "connector_name": "Playwright", + "tool_title": "Navigate", + }) + ); +} + +#[test] +fn guardian_review_decision_maps_to_mcp_tool_decision() { + assert_eq!( + mcp_tool_approval_decision_from_guardian(ReviewDecision::Approved), + McpToolApprovalDecision::Accept + ); + assert_eq!( + mcp_tool_approval_decision_from_guardian(ReviewDecision::Denied), + McpToolApprovalDecision::Decline + ); + assert_eq!( + mcp_tool_approval_decision_from_guardian(ReviewDecision::Abort), + McpToolApprovalDecision::Decline + ); +} + +#[test] +fn approval_elicitation_meta_includes_connector_source_for_codex_apps() { + assert_eq!( + build_mcp_tool_approval_elicitation_meta( + CODEX_APPS_MCP_SERVER_NAME, + Some(&approval_metadata( + Some("calendar"), + Some("Calendar"), + Some("Manage events and schedules."), + Some("Run Action"), + Some("Runs the selected action."), + )), + Some(&serde_json::json!({ + "calendar_id": "primary", + })), + None, + prompt_options(false, false), + ), + Some(serde_json::json!({ + MCP_TOOL_APPROVAL_KIND_KEY: MCP_TOOL_APPROVAL_KIND_MCP_TOOL_CALL, + MCP_TOOL_APPROVAL_SOURCE_KEY: MCP_TOOL_APPROVAL_SOURCE_CONNECTOR, + MCP_TOOL_APPROVAL_CONNECTOR_ID_KEY: "calendar", + MCP_TOOL_APPROVAL_CONNECTOR_NAME_KEY: "Calendar", + MCP_TOOL_APPROVAL_CONNECTOR_DESCRIPTION_KEY: "Manage events and schedules.", + MCP_TOOL_APPROVAL_TOOL_TITLE_KEY: "Run Action", + MCP_TOOL_APPROVAL_TOOL_DESCRIPTION_KEY: "Runs the selected action.", + MCP_TOOL_APPROVAL_TOOL_PARAMS_KEY: { + "calendar_id": "primary", + }, + })) + ); +} + +#[test] +fn approval_elicitation_meta_merges_session_and_always_persist_with_connector_source() { + assert_eq!( + build_mcp_tool_approval_elicitation_meta( + CODEX_APPS_MCP_SERVER_NAME, + Some(&approval_metadata( + Some("calendar"), + Some("Calendar"), + Some("Manage events and schedules."), + Some("Run Action"), + Some("Runs the selected action."), + )), + Some(&serde_json::json!({ + "calendar_id": "primary", + })), + None, + prompt_options(true, true), + ), + Some(serde_json::json!({ + MCP_TOOL_APPROVAL_KIND_KEY: MCP_TOOL_APPROVAL_KIND_MCP_TOOL_CALL, + MCP_TOOL_APPROVAL_PERSIST_KEY: [ + MCP_TOOL_APPROVAL_PERSIST_SESSION, + MCP_TOOL_APPROVAL_PERSIST_ALWAYS, + ], + MCP_TOOL_APPROVAL_SOURCE_KEY: MCP_TOOL_APPROVAL_SOURCE_CONNECTOR, + MCP_TOOL_APPROVAL_CONNECTOR_ID_KEY: "calendar", + MCP_TOOL_APPROVAL_CONNECTOR_NAME_KEY: "Calendar", + MCP_TOOL_APPROVAL_CONNECTOR_DESCRIPTION_KEY: "Manage events and schedules.", + MCP_TOOL_APPROVAL_TOOL_TITLE_KEY: "Run Action", + MCP_TOOL_APPROVAL_TOOL_DESCRIPTION_KEY: "Runs the selected action.", + MCP_TOOL_APPROVAL_TOOL_PARAMS_KEY: { + "calendar_id": "primary", + }, + })) + ); +} + +#[test] +fn declined_elicitation_response_stays_decline() { + let response = parse_mcp_tool_approval_elicitation_response( + Some(ElicitationResponse { + action: ElicitationAction::Decline, + content: Some(serde_json::json!({ + "approval": MCP_TOOL_APPROVAL_ACCEPT, + })), + meta: None, + }), + "approval", + ); + + assert_eq!(response, McpToolApprovalDecision::Decline); +} + +#[test] +fn accepted_elicitation_response_uses_always_persist_meta() { + let response = parse_mcp_tool_approval_elicitation_response( + Some(ElicitationResponse { + action: ElicitationAction::Accept, + content: None, + meta: Some(serde_json::json!({ + MCP_TOOL_APPROVAL_PERSIST_KEY: MCP_TOOL_APPROVAL_PERSIST_ALWAYS, + })), + }), + "approval", + ); + + assert_eq!(response, McpToolApprovalDecision::AcceptAndRemember); +} + +#[test] +fn accepted_elicitation_response_uses_session_persist_meta() { + let response = parse_mcp_tool_approval_elicitation_response( + Some(ElicitationResponse { + action: ElicitationAction::Accept, + content: None, + meta: Some(serde_json::json!({ + MCP_TOOL_APPROVAL_PERSIST_KEY: MCP_TOOL_APPROVAL_PERSIST_SESSION, + })), + }), + "approval", + ); + + assert_eq!(response, McpToolApprovalDecision::AcceptForSession); +} + +#[test] +fn accepted_elicitation_without_content_defaults_to_accept() { + let response = parse_mcp_tool_approval_elicitation_response( + Some(ElicitationResponse { + action: ElicitationAction::Accept, + content: None, + meta: None, + }), + "approval", + ); + + assert_eq!(response, McpToolApprovalDecision::Accept); +} + +#[tokio::test] +async fn persist_codex_app_tool_approval_writes_tool_override() { + let tmp = tempdir().expect("tempdir"); + + persist_codex_app_tool_approval(tmp.path(), "calendar", "calendar/list_events") + .await + .expect("persist approval"); + + let contents = std::fs::read_to_string(tmp.path().join(CONFIG_TOML_FILE)).expect("read config"); + let parsed: ConfigToml = toml::from_str(&contents).expect("parse config"); + + assert_eq!( + parsed.apps, + Some(AppsConfigToml { + default: None, + apps: HashMap::from([( + "calendar".to_string(), + AppConfig { + enabled: true, + destructive_enabled: None, + open_world_enabled: None, + default_tools_approval_mode: None, + default_tools_enabled: None, + tools: Some(AppToolsConfig { + tools: HashMap::from([( + "calendar/list_events".to_string(), + AppToolConfig { + enabled: None, + approval_mode: Some(AppToolApproval::Approve), + }, + )]), + }), + }, + )]), + }) + ); + assert!(contents.contains("[apps.calendar.tools.\"calendar/list_events\"]")); +} + +#[tokio::test] +async fn maybe_persist_mcp_tool_approval_reloads_session_config() { + let (session, turn_context) = make_session_and_context().await; + let codex_home = session.codex_home().await; + std::fs::create_dir_all(&codex_home).expect("create codex home"); + let key = McpToolApprovalKey { + server: CODEX_APPS_MCP_SERVER_NAME.to_string(), + connector_id: Some("calendar".to_string()), + tool_name: "calendar/list_events".to_string(), + }; + + maybe_persist_mcp_tool_approval(&session, &turn_context, key.clone()).await; + + let config = session.get_config().await; + let apps_toml = config + .config_layer_stack + .effective_config() + .as_table() + .and_then(|table| table.get("apps")) + .cloned() + .expect("apps table"); + let apps = AppsConfigToml::deserialize(apps_toml).expect("deserialize apps config"); + let tool = apps + .apps + .get("calendar") + .and_then(|app| app.tools.as_ref()) + .and_then(|tools| tools.tools.get("calendar/list_events")) + .expect("calendar/list_events tool config exists"); + + assert_eq!( + tool, + &AppToolConfig { + enabled: None, + approval_mode: Some(AppToolApproval::Approve), + } + ); + assert_eq!(mcp_tool_approval_is_remembered(&session, &key).await, true); +} + +#[tokio::test] +async fn approve_mode_skips_when_annotations_do_not_require_approval() { + let (session, turn_context) = make_session_and_context().await; + let session = Arc::new(session); + let turn_context = Arc::new(turn_context); + let invocation = McpInvocation { + server: "custom_server".to_string(), + tool: "read_only_tool".to_string(), + arguments: None, + }; + let metadata = McpToolApprovalMetadata { + annotations: Some(annotations(Some(true), None, None)), + connector_id: None, + connector_name: None, + connector_description: None, + tool_title: Some("Read Only Tool".to_string()), + tool_description: None, + }; + + let decision = maybe_request_mcp_tool_approval( + &session, + &turn_context, + "call-1", + &invocation, + Some(&metadata), + AppToolApproval::Approve, + ) + .await; + + assert_eq!(decision, None); +} + +#[tokio::test] +async fn approve_mode_blocks_when_arc_returns_interrupt_for_model() { + use wiremock::Mock; + use wiremock::MockServer; + use wiremock::ResponseTemplate; + use wiremock::matchers::method; + use wiremock::matchers::path; + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/codex/safety/arc")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "outcome": "steer-model", + "short_reason": "needs approval", + "rationale": "high-risk action", + "risk_score": 96, + "risk_level": "critical", + "evidence": [{ + "message": "dangerous_tool", + "why": "high-risk action", + }], + }))) + .expect(1) + .mount(&server) + .await; + + let (session, mut turn_context) = make_session_and_context().await; + turn_context.auth_manager = Some(crate::test_support::auth_manager_from_auth( + crate::CodexAuth::create_dummy_chatgpt_auth_for_testing(), + )); + let mut config = (*turn_context.config).clone(); + config.chatgpt_base_url = server.uri(); + turn_context.config = Arc::new(config); + + let session = Arc::new(session); + let turn_context = Arc::new(turn_context); + let invocation = McpInvocation { + server: CODEX_APPS_MCP_SERVER_NAME.to_string(), + tool: "dangerous_tool".to_string(), + arguments: Some(serde_json::json!({ "id": 1 })), + }; + let metadata = McpToolApprovalMetadata { + annotations: Some(annotations(Some(false), Some(true), Some(true))), + connector_id: Some("calendar".to_string()), + connector_name: Some("Calendar".to_string()), + connector_description: Some("Manage events".to_string()), + tool_title: Some("Dangerous Tool".to_string()), + tool_description: Some("Performs a risky action.".to_string()), + }; + + let decision = maybe_request_mcp_tool_approval( + &session, + &turn_context, + "call-2", + &invocation, + Some(&metadata), + AppToolApproval::Approve, + ) + .await; + + assert_eq!( + decision, + Some(McpToolApprovalDecision::BlockedBySafetyMonitor( + "Tool call was cancelled because of safety risks: high-risk action".to_string(), + )) + ); +} diff --git a/codex-rs/core/src/memories/citations.rs b/codex-rs/core/src/memories/citations.rs index 91c7778266..ed620e853b 100644 --- a/codex-rs/core/src/memories/citations.rs +++ b/codex-rs/core/src/memories/citations.rs @@ -32,31 +32,5 @@ pub fn get_thread_id_from_citations(citations: Vec) -> Vec { } #[cfg(test)] -mod tests { - use super::get_thread_id_from_citations; - use codex_protocol::ThreadId; - use pretty_assertions::assert_eq; - - #[test] - fn get_thread_id_from_citations_extracts_thread_ids() { - let first = ThreadId::new(); - let second = ThreadId::new(); - - let citations = vec![format!( - "\n\nMEMORY.md:1-2|note=[x]\n\n\n{first}\nnot-a-uuid\n{second}\n\n" - )]; - - assert_eq!(get_thread_id_from_citations(citations), vec![first, second]); - } - - #[test] - fn get_thread_id_from_citations_supports_legacy_rollout_ids() { - let thread_id = ThreadId::new(); - - let citations = vec![format!( - "\n\n{thread_id}\n\n" - )]; - - assert_eq!(get_thread_id_from_citations(citations), vec![thread_id]); - } -} +#[path = "citations_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/memories/citations_tests.rs b/codex-rs/core/src/memories/citations_tests.rs new file mode 100644 index 0000000000..b6783dea7c --- /dev/null +++ b/codex-rs/core/src/memories/citations_tests.rs @@ -0,0 +1,26 @@ +use super::get_thread_id_from_citations; +use codex_protocol::ThreadId; +use pretty_assertions::assert_eq; + +#[test] +fn get_thread_id_from_citations_extracts_thread_ids() { + let first = ThreadId::new(); + let second = ThreadId::new(); + + let citations = vec![format!( + "\n\nMEMORY.md:1-2|note=[x]\n\n\n{first}\nnot-a-uuid\n{second}\n\n" + )]; + + assert_eq!(get_thread_id_from_citations(citations), vec![first, second]); +} + +#[test] +fn get_thread_id_from_citations_supports_legacy_rollout_ids() { + let thread_id = ThreadId::new(); + + let citations = vec![format!( + "\n\n{thread_id}\n\n" + )]; + + assert_eq!(get_thread_id_from_citations(citations), vec![thread_id]); +} diff --git a/codex-rs/core/src/memories/phase1.rs b/codex-rs/core/src/memories/phase1.rs index ad4f29a0d2..c7e88f07ef 100644 --- a/codex-rs/core/src/memories/phase1.rs +++ b/codex-rs/core/src/memories/phase1.rs @@ -578,72 +578,5 @@ fn emit_metrics(session: &Session, counts: &Stats) { } #[cfg(test)] -mod tests { - use super::JobOutcome; - use super::JobResult; - use super::aggregate_stats; - use codex_protocol::protocol::TokenUsage; - use pretty_assertions::assert_eq; - - #[test] - fn count_outcomes_sums_token_usage_across_all_jobs() { - let counts = aggregate_stats(vec![ - JobResult { - outcome: JobOutcome::SucceededWithOutput, - token_usage: Some(TokenUsage { - input_tokens: 10, - cached_input_tokens: 2, - output_tokens: 3, - reasoning_output_tokens: 1, - total_tokens: 13, - }), - }, - JobResult { - outcome: JobOutcome::SucceededNoOutput, - token_usage: Some(TokenUsage { - input_tokens: 7, - cached_input_tokens: 1, - output_tokens: 2, - reasoning_output_tokens: 0, - total_tokens: 9, - }), - }, - JobResult { - outcome: JobOutcome::Failed, - token_usage: None, - }, - ]); - - assert_eq!(counts.claimed, 3); - assert_eq!(counts.succeeded_with_output, 1); - assert_eq!(counts.succeeded_no_output, 1); - assert_eq!(counts.failed, 1); - assert_eq!( - counts.total_token_usage, - Some(TokenUsage { - input_tokens: 17, - cached_input_tokens: 3, - output_tokens: 5, - reasoning_output_tokens: 1, - total_tokens: 22, - }) - ); - } - - #[test] - fn count_outcomes_keeps_usage_empty_when_no_job_reports_it() { - let counts = aggregate_stats(vec![ - JobResult { - outcome: JobOutcome::SucceededWithOutput, - token_usage: None, - }, - JobResult { - outcome: JobOutcome::Failed, - token_usage: None, - }, - ]); - - assert_eq!(counts.claimed, 2); - assert_eq!(counts.total_token_usage, None); - } -} +#[path = "phase1_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/memories/phase1_tests.rs b/codex-rs/core/src/memories/phase1_tests.rs new file mode 100644 index 0000000000..c3e358187e --- /dev/null +++ b/codex-rs/core/src/memories/phase1_tests.rs @@ -0,0 +1,67 @@ +use super::JobOutcome; +use super::JobResult; +use super::aggregate_stats; +use codex_protocol::protocol::TokenUsage; +use pretty_assertions::assert_eq; + +#[test] +fn count_outcomes_sums_token_usage_across_all_jobs() { + let counts = aggregate_stats(vec![ + JobResult { + outcome: JobOutcome::SucceededWithOutput, + token_usage: Some(TokenUsage { + input_tokens: 10, + cached_input_tokens: 2, + output_tokens: 3, + reasoning_output_tokens: 1, + total_tokens: 13, + }), + }, + JobResult { + outcome: JobOutcome::SucceededNoOutput, + token_usage: Some(TokenUsage { + input_tokens: 7, + cached_input_tokens: 1, + output_tokens: 2, + reasoning_output_tokens: 0, + total_tokens: 9, + }), + }, + JobResult { + outcome: JobOutcome::Failed, + token_usage: None, + }, + ]); + + assert_eq!(counts.claimed, 3); + assert_eq!(counts.succeeded_with_output, 1); + assert_eq!(counts.succeeded_no_output, 1); + assert_eq!(counts.failed, 1); + assert_eq!( + counts.total_token_usage, + Some(TokenUsage { + input_tokens: 17, + cached_input_tokens: 3, + output_tokens: 5, + reasoning_output_tokens: 1, + total_tokens: 22, + }) + ); +} + +#[test] +fn count_outcomes_keeps_usage_empty_when_no_job_reports_it() { + let counts = aggregate_stats(vec![ + JobResult { + outcome: JobOutcome::SucceededWithOutput, + token_usage: None, + }, + JobResult { + outcome: JobOutcome::Failed, + token_usage: None, + }, + ]); + + assert_eq!(counts.claimed, 2); + assert_eq!(counts.total_token_usage, None); +} diff --git a/codex-rs/core/src/memories/prompts.rs b/codex-rs/core/src/memories/prompts.rs index 35cfe1edf0..1659e1c1c7 100644 --- a/codex-rs/core/src/memories/prompts.rs +++ b/codex-rs/core/src/memories/prompts.rs @@ -179,56 +179,5 @@ pub(crate) async fn build_memory_tool_developer_instructions(codex_home: &Path) } #[cfg(test)] -mod tests { - use super::*; - use crate::models_manager::model_info::model_info_from_slug; - - #[test] - fn build_stage_one_input_message_truncates_rollout_using_model_context_window() { - let input = format!("{}{}{}", "a".repeat(700_000), "middle", "z".repeat(700_000)); - let mut model_info = model_info_from_slug("gpt-5.2-codex"); - model_info.context_window = Some(123_000); - let expected_rollout_token_limit = usize::try_from( - ((123_000_i64 * model_info.effective_context_window_percent) / 100) - * phase_one::CONTEXT_WINDOW_PERCENT - / 100, - ) - .unwrap(); - let expected_truncated = truncate_text( - &input, - TruncationPolicy::Tokens(expected_rollout_token_limit), - ); - let message = build_stage_one_input_message( - &model_info, - Path::new("/tmp/rollout.jsonl"), - Path::new("/tmp"), - &input, - ) - .unwrap(); - - assert!(expected_truncated.contains("tokens truncated")); - assert!(expected_truncated.starts_with('a')); - assert!(expected_truncated.ends_with('z')); - assert!(message.contains(&expected_truncated)); - } - - #[test] - fn build_stage_one_input_message_uses_default_limit_when_model_context_window_missing() { - let input = format!("{}{}{}", "a".repeat(700_000), "middle", "z".repeat(700_000)); - let mut model_info = model_info_from_slug("gpt-5.2-codex"); - model_info.context_window = None; - let expected_truncated = truncate_text( - &input, - TruncationPolicy::Tokens(phase_one::DEFAULT_STAGE_ONE_ROLLOUT_TOKEN_LIMIT), - ); - let message = build_stage_one_input_message( - &model_info, - Path::new("/tmp/rollout.jsonl"), - Path::new("/tmp"), - &input, - ) - .unwrap(); - - assert!(message.contains(&expected_truncated)); - } -} +#[path = "prompts_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/memories/prompts_tests.rs b/codex-rs/core/src/memories/prompts_tests.rs new file mode 100644 index 0000000000..acbe5785a2 --- /dev/null +++ b/codex-rs/core/src/memories/prompts_tests.rs @@ -0,0 +1,51 @@ +use super::*; +use crate::models_manager::model_info::model_info_from_slug; + +#[test] +fn build_stage_one_input_message_truncates_rollout_using_model_context_window() { + let input = format!("{}{}{}", "a".repeat(700_000), "middle", "z".repeat(700_000)); + let mut model_info = model_info_from_slug("gpt-5.2-codex"); + model_info.context_window = Some(123_000); + let expected_rollout_token_limit = usize::try_from( + ((123_000_i64 * model_info.effective_context_window_percent) / 100) + * phase_one::CONTEXT_WINDOW_PERCENT + / 100, + ) + .unwrap(); + let expected_truncated = truncate_text( + &input, + TruncationPolicy::Tokens(expected_rollout_token_limit), + ); + let message = build_stage_one_input_message( + &model_info, + Path::new("/tmp/rollout.jsonl"), + Path::new("/tmp"), + &input, + ) + .unwrap(); + + assert!(expected_truncated.contains("tokens truncated")); + assert!(expected_truncated.starts_with('a')); + assert!(expected_truncated.ends_with('z')); + assert!(message.contains(&expected_truncated)); +} + +#[test] +fn build_stage_one_input_message_uses_default_limit_when_model_context_window_missing() { + let input = format!("{}{}{}", "a".repeat(700_000), "middle", "z".repeat(700_000)); + let mut model_info = model_info_from_slug("gpt-5.2-codex"); + model_info.context_window = None; + let expected_truncated = truncate_text( + &input, + TruncationPolicy::Tokens(phase_one::DEFAULT_STAGE_ONE_ROLLOUT_TOKEN_LIMIT), + ); + let message = build_stage_one_input_message( + &model_info, + Path::new("/tmp/rollout.jsonl"), + Path::new("/tmp"), + &input, + ) + .unwrap(); + + assert!(message.contains(&expected_truncated)); +} diff --git a/codex-rs/core/src/memories/storage.rs b/codex-rs/core/src/memories/storage.rs index 68f75a095f..2455ae40df 100644 --- a/codex-rs/core/src/memories/storage.rs +++ b/codex-rs/core/src/memories/storage.rs @@ -256,75 +256,5 @@ pub(super) fn rollout_summary_file_stem_from_parts( } #[cfg(test)] -mod tests { - use super::rollout_summary_file_stem; - use super::rollout_summary_file_stem_from_parts; - use chrono::TimeZone; - use chrono::Utc; - use codex_protocol::ThreadId; - use codex_state::Stage1Output; - use pretty_assertions::assert_eq; - use std::path::PathBuf; - const FIXED_PREFIX: &str = "2025-02-11T15-35-19-jqmb"; - - fn stage1_output_with_slug(thread_id: ThreadId, rollout_slug: Option<&str>) -> Stage1Output { - Stage1Output { - thread_id, - source_updated_at: Utc.timestamp_opt(123, 0).single().expect("timestamp"), - raw_memory: "raw memory".to_string(), - rollout_summary: "summary".to_string(), - rollout_slug: rollout_slug.map(ToString::to_string), - rollout_path: PathBuf::from("/tmp/rollout.jsonl"), - cwd: PathBuf::from("/tmp/workspace"), - git_branch: None, - generated_at: Utc.timestamp_opt(124, 0).single().expect("timestamp"), - } - } - - fn fixed_thread_id() -> ThreadId { - ThreadId::try_from("0194f5a6-89ab-7cde-8123-456789abcdef").expect("valid thread id") - } - - #[test] - fn rollout_summary_file_stem_uses_uuid_timestamp_and_hash_when_slug_missing() { - let thread_id = fixed_thread_id(); - let memory = stage1_output_with_slug(thread_id, None); - - assert_eq!(rollout_summary_file_stem(&memory), FIXED_PREFIX); - assert_eq!( - rollout_summary_file_stem_from_parts( - memory.thread_id, - memory.source_updated_at, - memory.rollout_slug.as_deref(), - ), - FIXED_PREFIX - ); - } - - #[test] - fn rollout_summary_file_stem_sanitizes_and_truncates_slug() { - let thread_id = fixed_thread_id(); - let memory = stage1_output_with_slug( - thread_id, - Some("Unsafe Slug/With Spaces & Symbols + EXTRA_LONG_12345_67890_ABCDE_fghij_klmno"), - ); - - let stem = rollout_summary_file_stem(&memory); - let slug = stem - .strip_prefix(&format!("{FIXED_PREFIX}-")) - .expect("slug suffix should be present"); - assert_eq!(slug.len(), 60); - assert_eq!( - slug, - "unsafe_slug_with_spaces___symbols___extra_long_12345_67890_a" - ); - } - - #[test] - fn rollout_summary_file_stem_uses_uuid_timestamp_and_hash_when_slug_is_empty() { - let thread_id = fixed_thread_id(); - let memory = stage1_output_with_slug(thread_id, Some("")); - - assert_eq!(rollout_summary_file_stem(&memory), FIXED_PREFIX); - } -} +#[path = "storage_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/memories/storage_tests.rs b/codex-rs/core/src/memories/storage_tests.rs new file mode 100644 index 0000000000..5e0f2ce89c --- /dev/null +++ b/codex-rs/core/src/memories/storage_tests.rs @@ -0,0 +1,70 @@ +use super::rollout_summary_file_stem; +use super::rollout_summary_file_stem_from_parts; +use chrono::TimeZone; +use chrono::Utc; +use codex_protocol::ThreadId; +use codex_state::Stage1Output; +use pretty_assertions::assert_eq; +use std::path::PathBuf; +const FIXED_PREFIX: &str = "2025-02-11T15-35-19-jqmb"; + +fn stage1_output_with_slug(thread_id: ThreadId, rollout_slug: Option<&str>) -> Stage1Output { + Stage1Output { + thread_id, + source_updated_at: Utc.timestamp_opt(123, 0).single().expect("timestamp"), + raw_memory: "raw memory".to_string(), + rollout_summary: "summary".to_string(), + rollout_slug: rollout_slug.map(ToString::to_string), + rollout_path: PathBuf::from("/tmp/rollout.jsonl"), + cwd: PathBuf::from("/tmp/workspace"), + git_branch: None, + generated_at: Utc.timestamp_opt(124, 0).single().expect("timestamp"), + } +} + +fn fixed_thread_id() -> ThreadId { + ThreadId::try_from("0194f5a6-89ab-7cde-8123-456789abcdef").expect("valid thread id") +} + +#[test] +fn rollout_summary_file_stem_uses_uuid_timestamp_and_hash_when_slug_missing() { + let thread_id = fixed_thread_id(); + let memory = stage1_output_with_slug(thread_id, None); + + assert_eq!(rollout_summary_file_stem(&memory), FIXED_PREFIX); + assert_eq!( + rollout_summary_file_stem_from_parts( + memory.thread_id, + memory.source_updated_at, + memory.rollout_slug.as_deref(), + ), + FIXED_PREFIX + ); +} + +#[test] +fn rollout_summary_file_stem_sanitizes_and_truncates_slug() { + let thread_id = fixed_thread_id(); + let memory = stage1_output_with_slug( + thread_id, + Some("Unsafe Slug/With Spaces & Symbols + EXTRA_LONG_12345_67890_ABCDE_fghij_klmno"), + ); + + let stem = rollout_summary_file_stem(&memory); + let slug = stem + .strip_prefix(&format!("{FIXED_PREFIX}-")) + .expect("slug suffix should be present"); + assert_eq!(slug.len(), 60); + assert_eq!( + slug, + "unsafe_slug_with_spaces___symbols___extra_long_12345_67890_a" + ); +} + +#[test] +fn rollout_summary_file_stem_uses_uuid_timestamp_and_hash_when_slug_is_empty() { + let thread_id = fixed_thread_id(); + let memory = stage1_output_with_slug(thread_id, Some("")); + + assert_eq!(rollout_summary_file_stem(&memory), FIXED_PREFIX); +} diff --git a/codex-rs/core/src/memory_trace.rs b/codex-rs/core/src/memory_trace.rs index 5cc4994427..2e613e6713 100644 --- a/codex-rs/core/src/memory_trace.rs +++ b/codex-rs/core/src/memory_trace.rs @@ -226,78 +226,5 @@ fn build_memory_id(index: usize, path: &Path) -> String { } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - use tempfile::tempdir; - - #[test] - fn normalize_trace_items_handles_payload_wrapper_and_message_role_filtering() { - let items = vec![ - serde_json::json!({ - "type": "response_item", - "payload": {"type": "message", "role": "assistant", "content": []} - }), - serde_json::json!({ - "type": "response_item", - "payload": [ - {"type": "message", "role": "user", "content": []}, - {"type": "message", "role": "tool", "content": []}, - {"type": "function_call", "name": "shell", "arguments": "{}", "call_id": "c1"} - ] - }), - serde_json::json!({ - "type": "not_response_item", - "payload": {"type": "message", "role": "assistant", "content": []} - }), - serde_json::json!({ - "type": "message", - "role": "developer", - "content": [] - }), - ]; - - let normalized = normalize_trace_items(items, Path::new("trace.json")).expect("normalize"); - let expected = vec![ - serde_json::json!({"type": "message", "role": "assistant", "content": []}), - serde_json::json!({"type": "message", "role": "user", "content": []}), - serde_json::json!({"type": "function_call", "name": "shell", "arguments": "{}", "call_id": "c1"}), - serde_json::json!({"type": "message", "role": "developer", "content": []}), - ]; - assert_eq!(normalized, expected); - } - - #[test] - fn load_trace_items_supports_jsonl_arrays_and_objects() { - let text = r#" -{"type":"response_item","payload":{"type":"message","role":"assistant","content":[]}} -[{"type":"message","role":"user","content":[]},{"type":"message","role":"tool","content":[]}] -"#; - let loaded = load_trace_items(Path::new("trace.jsonl"), text).expect("load"); - let expected = vec![ - serde_json::json!({"type":"message","role":"assistant","content":[]}), - serde_json::json!({"type":"message","role":"user","content":[]}), - ]; - assert_eq!(loaded, expected); - } - - #[tokio::test] - async fn load_trace_text_decodes_utf8_sig() { - let dir = tempdir().expect("tempdir"); - let path = dir.path().join("trace.json"); - tokio::fs::write( - &path, - [ - 0xEF, 0xBB, 0xBF, b'[', b'{', b'"', b't', b'y', b'p', b'e', b'"', b':', b'"', b'm', - b'e', b's', b's', b'a', b'g', b'e', b'"', b',', b'"', b'r', b'o', b'l', b'e', b'"', - b':', b'"', b'u', b's', b'e', b'r', b'"', b',', b'"', b'c', b'o', b'n', b't', b'e', - b'n', b't', b'"', b':', b'[', b']', b'}', b']', - ], - ) - .await - .expect("write"); - - let text = load_trace_text(&path).await.expect("decode"); - assert!(text.starts_with('[')); - } -} +#[path = "memory_trace_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/memory_trace_tests.rs b/codex-rs/core/src/memory_trace_tests.rs new file mode 100644 index 0000000000..e4014ef7a4 --- /dev/null +++ b/codex-rs/core/src/memory_trace_tests.rs @@ -0,0 +1,73 @@ +use super::*; +use pretty_assertions::assert_eq; +use tempfile::tempdir; + +#[test] +fn normalize_trace_items_handles_payload_wrapper_and_message_role_filtering() { + let items = vec![ + serde_json::json!({ + "type": "response_item", + "payload": {"type": "message", "role": "assistant", "content": []} + }), + serde_json::json!({ + "type": "response_item", + "payload": [ + {"type": "message", "role": "user", "content": []}, + {"type": "message", "role": "tool", "content": []}, + {"type": "function_call", "name": "shell", "arguments": "{}", "call_id": "c1"} + ] + }), + serde_json::json!({ + "type": "not_response_item", + "payload": {"type": "message", "role": "assistant", "content": []} + }), + serde_json::json!({ + "type": "message", + "role": "developer", + "content": [] + }), + ]; + + let normalized = normalize_trace_items(items, Path::new("trace.json")).expect("normalize"); + let expected = vec![ + serde_json::json!({"type": "message", "role": "assistant", "content": []}), + serde_json::json!({"type": "message", "role": "user", "content": []}), + serde_json::json!({"type": "function_call", "name": "shell", "arguments": "{}", "call_id": "c1"}), + serde_json::json!({"type": "message", "role": "developer", "content": []}), + ]; + assert_eq!(normalized, expected); +} + +#[test] +fn load_trace_items_supports_jsonl_arrays_and_objects() { + let text = r#" +{"type":"response_item","payload":{"type":"message","role":"assistant","content":[]}} +[{"type":"message","role":"user","content":[]},{"type":"message","role":"tool","content":[]}] +"#; + let loaded = load_trace_items(Path::new("trace.jsonl"), text).expect("load"); + let expected = vec![ + serde_json::json!({"type":"message","role":"assistant","content":[]}), + serde_json::json!({"type":"message","role":"user","content":[]}), + ]; + assert_eq!(loaded, expected); +} + +#[tokio::test] +async fn load_trace_text_decodes_utf8_sig() { + let dir = tempdir().expect("tempdir"); + let path = dir.path().join("trace.json"); + tokio::fs::write( + &path, + [ + 0xEF, 0xBB, 0xBF, b'[', b'{', b'"', b't', b'y', b'p', b'e', b'"', b':', b'"', b'm', + b'e', b's', b's', b'a', b'g', b'e', b'"', b',', b'"', b'r', b'o', b'l', b'e', b'"', + b':', b'"', b'u', b's', b'e', b'r', b'"', b',', b'"', b'c', b'o', b'n', b't', b'e', + b'n', b't', b'"', b':', b'[', b']', b'}', b']', + ], + ) + .await + .expect("write"); + + let text = load_trace_text(&path).await.expect("decode"); + assert!(text.starts_with('[')); +} diff --git a/codex-rs/core/src/mentions.rs b/codex-rs/core/src/mentions.rs index ceaced7faa..fec7c40e4b 100644 --- a/codex-rs/core/src/mentions.rs +++ b/codex-rs/core/src/mentions.rs @@ -132,160 +132,5 @@ pub(crate) fn build_connector_slug_counts( } #[cfg(test)] -mod tests { - use std::collections::HashSet; - - use codex_protocol::user_input::UserInput; - use pretty_assertions::assert_eq; - - use super::collect_explicit_app_ids; - use super::collect_explicit_plugin_mentions; - use crate::plugins::PluginCapabilitySummary; - - fn text_input(text: &str) -> UserInput { - UserInput::Text { - text: text.to_string(), - text_elements: Vec::new(), - } - } - - fn plugin(config_name: &str, display_name: &str) -> PluginCapabilitySummary { - PluginCapabilitySummary { - config_name: config_name.to_string(), - display_name: display_name.to_string(), - description: None, - has_skills: true, - mcp_server_names: Vec::new(), - app_connector_ids: Vec::new(), - } - } - - #[test] - fn collect_explicit_app_ids_from_linked_text_mentions() { - let input = vec![text_input("use [$calendar](app://calendar)")]; - - let app_ids = collect_explicit_app_ids(&input); - - assert_eq!(app_ids, HashSet::from(["calendar".to_string()])); - } - - #[test] - fn collect_explicit_app_ids_dedupes_structured_and_linked_mentions() { - let input = vec![ - text_input("use [$calendar](app://calendar)"), - UserInput::Mention { - name: "calendar".to_string(), - path: "app://calendar".to_string(), - }, - ]; - - let app_ids = collect_explicit_app_ids(&input); - - assert_eq!(app_ids, HashSet::from(["calendar".to_string()])); - } - - #[test] - fn collect_explicit_app_ids_ignores_non_app_paths() { - let input = vec![ - text_input( - "use [$docs](mcp://docs) and [$skill](skill://team/skill) and [$file](/tmp/file.txt)", - ), - UserInput::Mention { - name: "docs".to_string(), - path: "mcp://docs".to_string(), - }, - UserInput::Mention { - name: "skill".to_string(), - path: "skill://team/skill".to_string(), - }, - UserInput::Mention { - name: "file".to_string(), - path: "/tmp/file.txt".to_string(), - }, - ]; - - let app_ids = collect_explicit_app_ids(&input); - - assert_eq!(app_ids, HashSet::::new()); - } - - #[test] - fn collect_explicit_plugin_mentions_from_structured_paths() { - let plugins = vec![ - plugin("sample@test", "sample"), - plugin("other@test", "other"), - ]; - - let mentioned = collect_explicit_plugin_mentions( - &[UserInput::Mention { - name: "sample".to_string(), - path: "plugin://sample@test".to_string(), - }], - &plugins, - ); - - assert_eq!(mentioned, vec![plugin("sample@test", "sample")]); - } - - #[test] - fn collect_explicit_plugin_mentions_from_linked_text_mentions() { - let plugins = vec![ - plugin("sample@test", "sample"), - plugin("other@test", "other"), - ]; - - let mentioned = collect_explicit_plugin_mentions( - &[text_input("use [@sample](plugin://sample@test)")], - &plugins, - ); - - assert_eq!(mentioned, vec![plugin("sample@test", "sample")]); - } - - #[test] - fn collect_explicit_plugin_mentions_dedupes_structured_and_linked_mentions() { - let plugins = vec![ - plugin("sample@test", "sample"), - plugin("other@test", "other"), - ]; - - let mentioned = collect_explicit_plugin_mentions( - &[ - text_input("use [@sample](plugin://sample@test)"), - UserInput::Mention { - name: "sample".to_string(), - path: "plugin://sample@test".to_string(), - }, - ], - &plugins, - ); - - assert_eq!(mentioned, vec![plugin("sample@test", "sample")]); - } - - #[test] - fn collect_explicit_plugin_mentions_ignores_non_plugin_paths() { - let plugins = vec![plugin("sample@test", "sample")]; - - let mentioned = collect_explicit_plugin_mentions( - &[text_input( - "use [$app](app://calendar) and [$skill](skill://team/skill) and [$file](/tmp/file.txt)", - )], - &plugins, - ); - - assert_eq!(mentioned, Vec::::new()); - } - - #[test] - fn collect_explicit_plugin_mentions_ignores_dollar_linked_plugin_mentions() { - let plugins = vec![plugin("sample@test", "sample")]; - - let mentioned = collect_explicit_plugin_mentions( - &[text_input("use [$sample](plugin://sample@test)")], - &plugins, - ); - - assert_eq!(mentioned, Vec::::new()); - } -} +#[path = "mentions_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/mentions_tests.rs b/codex-rs/core/src/mentions_tests.rs new file mode 100644 index 0000000000..37c9adb886 --- /dev/null +++ b/codex-rs/core/src/mentions_tests.rs @@ -0,0 +1,155 @@ +use std::collections::HashSet; + +use codex_protocol::user_input::UserInput; +use pretty_assertions::assert_eq; + +use super::collect_explicit_app_ids; +use super::collect_explicit_plugin_mentions; +use crate::plugins::PluginCapabilitySummary; + +fn text_input(text: &str) -> UserInput { + UserInput::Text { + text: text.to_string(), + text_elements: Vec::new(), + } +} + +fn plugin(config_name: &str, display_name: &str) -> PluginCapabilitySummary { + PluginCapabilitySummary { + config_name: config_name.to_string(), + display_name: display_name.to_string(), + description: None, + has_skills: true, + mcp_server_names: Vec::new(), + app_connector_ids: Vec::new(), + } +} + +#[test] +fn collect_explicit_app_ids_from_linked_text_mentions() { + let input = vec![text_input("use [$calendar](app://calendar)")]; + + let app_ids = collect_explicit_app_ids(&input); + + assert_eq!(app_ids, HashSet::from(["calendar".to_string()])); +} + +#[test] +fn collect_explicit_app_ids_dedupes_structured_and_linked_mentions() { + let input = vec![ + text_input("use [$calendar](app://calendar)"), + UserInput::Mention { + name: "calendar".to_string(), + path: "app://calendar".to_string(), + }, + ]; + + let app_ids = collect_explicit_app_ids(&input); + + assert_eq!(app_ids, HashSet::from(["calendar".to_string()])); +} + +#[test] +fn collect_explicit_app_ids_ignores_non_app_paths() { + let input = vec![ + text_input( + "use [$docs](mcp://docs) and [$skill](skill://team/skill) and [$file](/tmp/file.txt)", + ), + UserInput::Mention { + name: "docs".to_string(), + path: "mcp://docs".to_string(), + }, + UserInput::Mention { + name: "skill".to_string(), + path: "skill://team/skill".to_string(), + }, + UserInput::Mention { + name: "file".to_string(), + path: "/tmp/file.txt".to_string(), + }, + ]; + + let app_ids = collect_explicit_app_ids(&input); + + assert_eq!(app_ids, HashSet::::new()); +} + +#[test] +fn collect_explicit_plugin_mentions_from_structured_paths() { + let plugins = vec![ + plugin("sample@test", "sample"), + plugin("other@test", "other"), + ]; + + let mentioned = collect_explicit_plugin_mentions( + &[UserInput::Mention { + name: "sample".to_string(), + path: "plugin://sample@test".to_string(), + }], + &plugins, + ); + + assert_eq!(mentioned, vec![plugin("sample@test", "sample")]); +} + +#[test] +fn collect_explicit_plugin_mentions_from_linked_text_mentions() { + let plugins = vec![ + plugin("sample@test", "sample"), + plugin("other@test", "other"), + ]; + + let mentioned = collect_explicit_plugin_mentions( + &[text_input("use [@sample](plugin://sample@test)")], + &plugins, + ); + + assert_eq!(mentioned, vec![plugin("sample@test", "sample")]); +} + +#[test] +fn collect_explicit_plugin_mentions_dedupes_structured_and_linked_mentions() { + let plugins = vec![ + plugin("sample@test", "sample"), + plugin("other@test", "other"), + ]; + + let mentioned = collect_explicit_plugin_mentions( + &[ + text_input("use [@sample](plugin://sample@test)"), + UserInput::Mention { + name: "sample".to_string(), + path: "plugin://sample@test".to_string(), + }, + ], + &plugins, + ); + + assert_eq!(mentioned, vec![plugin("sample@test", "sample")]); +} + +#[test] +fn collect_explicit_plugin_mentions_ignores_non_plugin_paths() { + let plugins = vec![plugin("sample@test", "sample")]; + + let mentioned = collect_explicit_plugin_mentions( + &[text_input( + "use [$app](app://calendar) and [$skill](skill://team/skill) and [$file](/tmp/file.txt)", + )], + &plugins, + ); + + assert_eq!(mentioned, Vec::::new()); +} + +#[test] +fn collect_explicit_plugin_mentions_ignores_dollar_linked_plugin_mentions() { + let plugins = vec![plugin("sample@test", "sample")]; + + let mentioned = collect_explicit_plugin_mentions( + &[text_input("use [$sample](plugin://sample@test)")], + &plugins, + ); + + assert_eq!(mentioned, Vec::::new()); +} diff --git a/codex-rs/core/src/message_history.rs b/codex-rs/core/src/message_history.rs index cb3b10098c..9a2c534890 100644 --- a/codex-rs/core/src/message_history.rs +++ b/codex-rs/core/src/message_history.rs @@ -401,216 +401,5 @@ fn history_log_id(_metadata: &std::fs::Metadata) -> Option { } #[cfg(test)] -mod tests { - use super::*; - use crate::config::ConfigBuilder; - use codex_protocol::ThreadId; - use pretty_assertions::assert_eq; - use std::fs::File; - use std::io::Write; - use tempfile::TempDir; - - #[tokio::test] - async fn lookup_reads_history_entries() { - let temp_dir = TempDir::new().expect("create temp dir"); - let history_path = temp_dir.path().join(HISTORY_FILENAME); - - let entries = vec![ - HistoryEntry { - session_id: "first-session".to_string(), - ts: 1, - text: "first".to_string(), - }, - HistoryEntry { - session_id: "second-session".to_string(), - ts: 2, - text: "second".to_string(), - }, - ]; - - let mut file = File::create(&history_path).expect("create history file"); - for entry in &entries { - writeln!( - file, - "{}", - serde_json::to_string(entry).expect("serialize history entry") - ) - .expect("write history entry"); - } - - let (log_id, count) = history_metadata_for_file(&history_path).await; - assert_eq!(count, entries.len()); - - let second_entry = - lookup_history_entry(&history_path, log_id, 1).expect("fetch second history entry"); - assert_eq!(second_entry, entries[1]); - } - - #[tokio::test] - async fn lookup_uses_stable_log_id_after_appends() { - let temp_dir = TempDir::new().expect("create temp dir"); - let history_path = temp_dir.path().join(HISTORY_FILENAME); - - let initial = HistoryEntry { - session_id: "first-session".to_string(), - ts: 1, - text: "first".to_string(), - }; - let appended = HistoryEntry { - session_id: "second-session".to_string(), - ts: 2, - text: "second".to_string(), - }; - - let mut file = File::create(&history_path).expect("create history file"); - writeln!( - file, - "{}", - serde_json::to_string(&initial).expect("serialize initial entry") - ) - .expect("write initial entry"); - - let (log_id, count) = history_metadata_for_file(&history_path).await; - assert_eq!(count, 1); - - let mut append = std::fs::OpenOptions::new() - .append(true) - .open(&history_path) - .expect("open history file for append"); - writeln!( - append, - "{}", - serde_json::to_string(&appended).expect("serialize appended entry") - ) - .expect("append history entry"); - - let fetched = - lookup_history_entry(&history_path, log_id, 1).expect("lookup appended history entry"); - assert_eq!(fetched, appended); - } - - #[tokio::test] - async fn append_entry_trims_history_when_beyond_max_bytes() { - let codex_home = TempDir::new().expect("create temp dir"); - - let mut config = ConfigBuilder::default() - .codex_home(codex_home.path().to_path_buf()) - .build() - .await - .expect("load config"); - - let conversation_id = ThreadId::new(); - - let entry_one = "a".repeat(200); - let entry_two = "b".repeat(200); - - let history_path = codex_home.path().join("history.jsonl"); - - append_entry(&entry_one, &conversation_id, &config) - .await - .expect("write first entry"); - - let first_len = std::fs::metadata(&history_path).expect("metadata").len(); - let limit_bytes = first_len + 10; - - config.history.max_bytes = - Some(usize::try_from(limit_bytes).expect("limit should fit into usize")); - - append_entry(&entry_two, &conversation_id, &config) - .await - .expect("write second entry"); - - let contents = std::fs::read_to_string(&history_path).expect("read history"); - - let entries = contents - .lines() - .map(|line| serde_json::from_str::(line).expect("parse entry")) - .collect::>(); - - assert_eq!( - entries.len(), - 1, - "only one entry left because entry_one should be evicted" - ); - assert_eq!(entries[0].text, entry_two); - assert!(std::fs::metadata(&history_path).expect("metadata").len() <= limit_bytes); - } - - #[tokio::test] - async fn append_entry_trims_history_to_soft_cap() { - let codex_home = TempDir::new().expect("create temp dir"); - - let mut config = ConfigBuilder::default() - .codex_home(codex_home.path().to_path_buf()) - .build() - .await - .expect("load config"); - - let conversation_id = ThreadId::new(); - - let short_entry = "a".repeat(200); - let long_entry = "b".repeat(400); - - let history_path = codex_home.path().join("history.jsonl"); - - append_entry(&short_entry, &conversation_id, &config) - .await - .expect("write first entry"); - - let short_entry_len = std::fs::metadata(&history_path).expect("metadata").len(); - - append_entry(&long_entry, &conversation_id, &config) - .await - .expect("write second entry"); - - let two_entry_len = std::fs::metadata(&history_path).expect("metadata").len(); - - let long_entry_len = two_entry_len - .checked_sub(short_entry_len) - .expect("second entry length should be larger than first entry length"); - - config.history.max_bytes = Some( - usize::try_from((2 * long_entry_len) + (short_entry_len / 2)) - .expect("max bytes should fit into usize"), - ); - - append_entry(&long_entry, &conversation_id, &config) - .await - .expect("write third entry"); - - let contents = std::fs::read_to_string(&history_path).expect("read history"); - - let entries = contents - .lines() - .map(|line| serde_json::from_str::(line).expect("parse entry")) - .collect::>(); - - assert_eq!(entries.len(), 1); - assert_eq!(entries[0].text, long_entry); - - let pruned_len = std::fs::metadata(&history_path).expect("metadata").len(); - let max_bytes = config - .history - .max_bytes - .expect("max bytes should be configured") as u64; - - assert!(pruned_len <= max_bytes); - - let soft_cap_bytes = ((max_bytes as f64) * HISTORY_SOFT_CAP_RATIO) - .floor() - .clamp(1.0, max_bytes as f64) as u64; - let len_without_first = 2 * long_entry_len; - - assert!( - len_without_first <= max_bytes, - "dropping only the first entry would satisfy the hard cap" - ); - assert!( - len_without_first > soft_cap_bytes, - "soft cap should require more aggressive trimming than the hard cap" - ); - - assert_eq!(pruned_len, long_entry_len); - assert!(pruned_len <= soft_cap_bytes.max(long_entry_len)); - } -} +#[path = "message_history_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/message_history_tests.rs b/codex-rs/core/src/message_history_tests.rs new file mode 100644 index 0000000000..59b9c8c7b7 --- /dev/null +++ b/codex-rs/core/src/message_history_tests.rs @@ -0,0 +1,211 @@ +use super::*; +use crate::config::ConfigBuilder; +use codex_protocol::ThreadId; +use pretty_assertions::assert_eq; +use std::fs::File; +use std::io::Write; +use tempfile::TempDir; + +#[tokio::test] +async fn lookup_reads_history_entries() { + let temp_dir = TempDir::new().expect("create temp dir"); + let history_path = temp_dir.path().join(HISTORY_FILENAME); + + let entries = vec![ + HistoryEntry { + session_id: "first-session".to_string(), + ts: 1, + text: "first".to_string(), + }, + HistoryEntry { + session_id: "second-session".to_string(), + ts: 2, + text: "second".to_string(), + }, + ]; + + let mut file = File::create(&history_path).expect("create history file"); + for entry in &entries { + writeln!( + file, + "{}", + serde_json::to_string(entry).expect("serialize history entry") + ) + .expect("write history entry"); + } + + let (log_id, count) = history_metadata_for_file(&history_path).await; + assert_eq!(count, entries.len()); + + let second_entry = + lookup_history_entry(&history_path, log_id, 1).expect("fetch second history entry"); + assert_eq!(second_entry, entries[1]); +} + +#[tokio::test] +async fn lookup_uses_stable_log_id_after_appends() { + let temp_dir = TempDir::new().expect("create temp dir"); + let history_path = temp_dir.path().join(HISTORY_FILENAME); + + let initial = HistoryEntry { + session_id: "first-session".to_string(), + ts: 1, + text: "first".to_string(), + }; + let appended = HistoryEntry { + session_id: "second-session".to_string(), + ts: 2, + text: "second".to_string(), + }; + + let mut file = File::create(&history_path).expect("create history file"); + writeln!( + file, + "{}", + serde_json::to_string(&initial).expect("serialize initial entry") + ) + .expect("write initial entry"); + + let (log_id, count) = history_metadata_for_file(&history_path).await; + assert_eq!(count, 1); + + let mut append = std::fs::OpenOptions::new() + .append(true) + .open(&history_path) + .expect("open history file for append"); + writeln!( + append, + "{}", + serde_json::to_string(&appended).expect("serialize appended entry") + ) + .expect("append history entry"); + + let fetched = + lookup_history_entry(&history_path, log_id, 1).expect("lookup appended history entry"); + assert_eq!(fetched, appended); +} + +#[tokio::test] +async fn append_entry_trims_history_when_beyond_max_bytes() { + let codex_home = TempDir::new().expect("create temp dir"); + + let mut config = ConfigBuilder::default() + .codex_home(codex_home.path().to_path_buf()) + .build() + .await + .expect("load config"); + + let conversation_id = ThreadId::new(); + + let entry_one = "a".repeat(200); + let entry_two = "b".repeat(200); + + let history_path = codex_home.path().join("history.jsonl"); + + append_entry(&entry_one, &conversation_id, &config) + .await + .expect("write first entry"); + + let first_len = std::fs::metadata(&history_path).expect("metadata").len(); + let limit_bytes = first_len + 10; + + config.history.max_bytes = + Some(usize::try_from(limit_bytes).expect("limit should fit into usize")); + + append_entry(&entry_two, &conversation_id, &config) + .await + .expect("write second entry"); + + let contents = std::fs::read_to_string(&history_path).expect("read history"); + + let entries = contents + .lines() + .map(|line| serde_json::from_str::(line).expect("parse entry")) + .collect::>(); + + assert_eq!( + entries.len(), + 1, + "only one entry left because entry_one should be evicted" + ); + assert_eq!(entries[0].text, entry_two); + assert!(std::fs::metadata(&history_path).expect("metadata").len() <= limit_bytes); +} + +#[tokio::test] +async fn append_entry_trims_history_to_soft_cap() { + let codex_home = TempDir::new().expect("create temp dir"); + + let mut config = ConfigBuilder::default() + .codex_home(codex_home.path().to_path_buf()) + .build() + .await + .expect("load config"); + + let conversation_id = ThreadId::new(); + + let short_entry = "a".repeat(200); + let long_entry = "b".repeat(400); + + let history_path = codex_home.path().join("history.jsonl"); + + append_entry(&short_entry, &conversation_id, &config) + .await + .expect("write first entry"); + + let short_entry_len = std::fs::metadata(&history_path).expect("metadata").len(); + + append_entry(&long_entry, &conversation_id, &config) + .await + .expect("write second entry"); + + let two_entry_len = std::fs::metadata(&history_path).expect("metadata").len(); + + let long_entry_len = two_entry_len + .checked_sub(short_entry_len) + .expect("second entry length should be larger than first entry length"); + + config.history.max_bytes = Some( + usize::try_from((2 * long_entry_len) + (short_entry_len / 2)) + .expect("max bytes should fit into usize"), + ); + + append_entry(&long_entry, &conversation_id, &config) + .await + .expect("write third entry"); + + let contents = std::fs::read_to_string(&history_path).expect("read history"); + + let entries = contents + .lines() + .map(|line| serde_json::from_str::(line).expect("parse entry")) + .collect::>(); + + assert_eq!(entries.len(), 1); + assert_eq!(entries[0].text, long_entry); + + let pruned_len = std::fs::metadata(&history_path).expect("metadata").len(); + let max_bytes = config + .history + .max_bytes + .expect("max bytes should be configured") as u64; + + assert!(pruned_len <= max_bytes); + + let soft_cap_bytes = ((max_bytes as f64) * HISTORY_SOFT_CAP_RATIO) + .floor() + .clamp(1.0, max_bytes as f64) as u64; + let len_without_first = 2 * long_entry_len; + + assert!( + len_without_first <= max_bytes, + "dropping only the first entry would satisfy the hard cap" + ); + assert!( + len_without_first > soft_cap_bytes, + "soft cap should require more aggressive trimming than the hard cap" + ); + + assert_eq!(pruned_len, long_entry_len); + assert!(pruned_len <= soft_cap_bytes.max(long_entry_len)); +} diff --git a/codex-rs/core/src/model_provider_info.rs b/codex-rs/core/src/model_provider_info.rs index 5d5ee36692..d8e2ea35ed 100644 --- a/codex-rs/core/src/model_provider_info.rs +++ b/codex-rs/core/src/model_provider_info.rs @@ -330,112 +330,5 @@ pub fn create_oss_provider_with_base_url(base_url: &str, wire_api: WireApi) -> M } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - - #[test] - fn test_deserialize_ollama_model_provider_toml() { - let azure_provider_toml = r#" -name = "Ollama" -base_url = "http://localhost:11434/v1" - "#; - let expected_provider = ModelProviderInfo { - name: "Ollama".into(), - base_url: Some("http://localhost:11434/v1".into()), - env_key: None, - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: None, - stream_max_retries: None, - stream_idle_timeout_ms: None, - requires_openai_auth: false, - supports_websockets: false, - }; - - let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap(); - assert_eq!(expected_provider, provider); - } - - #[test] - fn test_deserialize_azure_model_provider_toml() { - let azure_provider_toml = r#" -name = "Azure" -base_url = "https://xxxxx.openai.azure.com/openai" -env_key = "AZURE_OPENAI_API_KEY" -query_params = { api-version = "2025-04-01-preview" } - "#; - let expected_provider = ModelProviderInfo { - name: "Azure".into(), - base_url: Some("https://xxxxx.openai.azure.com/openai".into()), - env_key: Some("AZURE_OPENAI_API_KEY".into()), - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Responses, - query_params: Some(maplit::hashmap! { - "api-version".to_string() => "2025-04-01-preview".to_string(), - }), - http_headers: None, - env_http_headers: None, - request_max_retries: None, - stream_max_retries: None, - stream_idle_timeout_ms: None, - requires_openai_auth: false, - supports_websockets: false, - }; - - let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap(); - assert_eq!(expected_provider, provider); - } - - #[test] - fn test_deserialize_example_model_provider_toml() { - let azure_provider_toml = r#" -name = "Example" -base_url = "https://example.com" -env_key = "API_KEY" -http_headers = { "X-Example-Header" = "example-value" } -env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" } - "#; - let expected_provider = ModelProviderInfo { - name: "Example".into(), - base_url: Some("https://example.com".into()), - env_key: Some("API_KEY".into()), - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: Some(maplit::hashmap! { - "X-Example-Header".to_string() => "example-value".to_string(), - }), - env_http_headers: Some(maplit::hashmap! { - "X-Example-Env-Header".to_string() => "EXAMPLE_ENV_VAR".to_string(), - }), - request_max_retries: None, - stream_max_retries: None, - stream_idle_timeout_ms: None, - requires_openai_auth: false, - supports_websockets: false, - }; - - let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap(); - assert_eq!(expected_provider, provider); - } - - #[test] - fn test_deserialize_chat_wire_api_shows_helpful_error() { - let provider_toml = r#" -name = "OpenAI using Chat Completions" -base_url = "https://api.openai.com/v1" -env_key = "OPENAI_API_KEY" -wire_api = "chat" - "#; - - let err = toml::from_str::(provider_toml).unwrap_err(); - assert!(err.to_string().contains(CHAT_WIRE_API_REMOVED_ERROR)); - } -} +#[path = "model_provider_info_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/model_provider_info_tests.rs b/codex-rs/core/src/model_provider_info_tests.rs new file mode 100644 index 0000000000..e6d5cea36b --- /dev/null +++ b/codex-rs/core/src/model_provider_info_tests.rs @@ -0,0 +1,107 @@ +use super::*; +use pretty_assertions::assert_eq; + +#[test] +fn test_deserialize_ollama_model_provider_toml() { + let azure_provider_toml = r#" +name = "Ollama" +base_url = "http://localhost:11434/v1" + "#; + let expected_provider = ModelProviderInfo { + name: "Ollama".into(), + base_url: Some("http://localhost:11434/v1".into()), + env_key: None, + env_key_instructions: None, + experimental_bearer_token: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: None, + env_http_headers: None, + request_max_retries: None, + stream_max_retries: None, + stream_idle_timeout_ms: None, + requires_openai_auth: false, + supports_websockets: false, + }; + + let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap(); + assert_eq!(expected_provider, provider); +} + +#[test] +fn test_deserialize_azure_model_provider_toml() { + let azure_provider_toml = r#" +name = "Azure" +base_url = "https://xxxxx.openai.azure.com/openai" +env_key = "AZURE_OPENAI_API_KEY" +query_params = { api-version = "2025-04-01-preview" } + "#; + let expected_provider = ModelProviderInfo { + name: "Azure".into(), + base_url: Some("https://xxxxx.openai.azure.com/openai".into()), + env_key: Some("AZURE_OPENAI_API_KEY".into()), + env_key_instructions: None, + experimental_bearer_token: None, + wire_api: WireApi::Responses, + query_params: Some(maplit::hashmap! { + "api-version".to_string() => "2025-04-01-preview".to_string(), + }), + http_headers: None, + env_http_headers: None, + request_max_retries: None, + stream_max_retries: None, + stream_idle_timeout_ms: None, + requires_openai_auth: false, + supports_websockets: false, + }; + + let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap(); + assert_eq!(expected_provider, provider); +} + +#[test] +fn test_deserialize_example_model_provider_toml() { + let azure_provider_toml = r#" +name = "Example" +base_url = "https://example.com" +env_key = "API_KEY" +http_headers = { "X-Example-Header" = "example-value" } +env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" } + "#; + let expected_provider = ModelProviderInfo { + name: "Example".into(), + base_url: Some("https://example.com".into()), + env_key: Some("API_KEY".into()), + env_key_instructions: None, + experimental_bearer_token: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: Some(maplit::hashmap! { + "X-Example-Header".to_string() => "example-value".to_string(), + }), + env_http_headers: Some(maplit::hashmap! { + "X-Example-Env-Header".to_string() => "EXAMPLE_ENV_VAR".to_string(), + }), + request_max_retries: None, + stream_max_retries: None, + stream_idle_timeout_ms: None, + requires_openai_auth: false, + supports_websockets: false, + }; + + let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap(); + assert_eq!(expected_provider, provider); +} + +#[test] +fn test_deserialize_chat_wire_api_shows_helpful_error() { + let provider_toml = r#" +name = "OpenAI using Chat Completions" +base_url = "https://api.openai.com/v1" +env_key = "OPENAI_API_KEY" +wire_api = "chat" + "#; + + let err = toml::from_str::(provider_toml).unwrap_err(); + assert!(err.to_string().contains(CHAT_WIRE_API_REMOVED_ERROR)); +} diff --git a/codex-rs/core/src/models_manager/collaboration_mode_presets.rs b/codex-rs/core/src/models_manager/collaboration_mode_presets.rs index 5c5b212403..dceab9f3bd 100644 --- a/codex-rs/core/src/models_manager/collaboration_mode_presets.rs +++ b/codex-rs/core/src/models_manager/collaboration_mode_presets.rs @@ -103,57 +103,5 @@ fn asking_questions_guidance_message(default_mode_request_user_input: bool) -> S } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - - #[test] - fn preset_names_use_mode_display_names() { - assert_eq!(plan_preset().name, ModeKind::Plan.display_name()); - assert_eq!( - default_preset(CollaborationModesConfig::default()).name, - ModeKind::Default.display_name() - ); - assert_eq!( - plan_preset().reasoning_effort, - Some(Some(ReasoningEffort::Medium)) - ); - } - - #[test] - fn default_mode_instructions_replace_mode_names_placeholder() { - let default_instructions = default_preset(CollaborationModesConfig { - default_mode_request_user_input: true, - }) - .developer_instructions - .expect("default preset should include instructions") - .expect("default instructions should be set"); - - assert!(!default_instructions.contains(KNOWN_MODE_NAMES_PLACEHOLDER)); - assert!(!default_instructions.contains(REQUEST_USER_INPUT_AVAILABILITY_PLACEHOLDER)); - assert!(!default_instructions.contains(ASKING_QUESTIONS_GUIDANCE_PLACEHOLDER)); - - let known_mode_names = format_mode_names(&TUI_VISIBLE_COLLABORATION_MODES); - let expected_snippet = format!("Known mode names are {known_mode_names}."); - assert!(default_instructions.contains(&expected_snippet)); - - let expected_availability_message = - request_user_input_availability_message(ModeKind::Default, true); - assert!(default_instructions.contains(&expected_availability_message)); - assert!(default_instructions.contains("prefer using the `request_user_input` tool")); - } - - #[test] - fn default_mode_instructions_use_plain_text_questions_when_feature_disabled() { - let default_instructions = default_preset(CollaborationModesConfig::default()) - .developer_instructions - .expect("default preset should include instructions") - .expect("default instructions should be set"); - - assert!(!default_instructions.contains("prefer using the `request_user_input` tool")); - assert!( - default_instructions - .contains("ask the user directly with a concise plain-text question") - ); - } -} +#[path = "collaboration_mode_presets_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/models_manager/collaboration_mode_presets_tests.rs b/codex-rs/core/src/models_manager/collaboration_mode_presets_tests.rs new file mode 100644 index 0000000000..b0969f6eba --- /dev/null +++ b/codex-rs/core/src/models_manager/collaboration_mode_presets_tests.rs @@ -0,0 +1,51 @@ +use super::*; +use pretty_assertions::assert_eq; + +#[test] +fn preset_names_use_mode_display_names() { + assert_eq!(plan_preset().name, ModeKind::Plan.display_name()); + assert_eq!( + default_preset(CollaborationModesConfig::default()).name, + ModeKind::Default.display_name() + ); + assert_eq!( + plan_preset().reasoning_effort, + Some(Some(ReasoningEffort::Medium)) + ); +} + +#[test] +fn default_mode_instructions_replace_mode_names_placeholder() { + let default_instructions = default_preset(CollaborationModesConfig { + default_mode_request_user_input: true, + }) + .developer_instructions + .expect("default preset should include instructions") + .expect("default instructions should be set"); + + assert!(!default_instructions.contains(KNOWN_MODE_NAMES_PLACEHOLDER)); + assert!(!default_instructions.contains(REQUEST_USER_INPUT_AVAILABILITY_PLACEHOLDER)); + assert!(!default_instructions.contains(ASKING_QUESTIONS_GUIDANCE_PLACEHOLDER)); + + let known_mode_names = format_mode_names(&TUI_VISIBLE_COLLABORATION_MODES); + let expected_snippet = format!("Known mode names are {known_mode_names}."); + assert!(default_instructions.contains(&expected_snippet)); + + let expected_availability_message = + request_user_input_availability_message(ModeKind::Default, true); + assert!(default_instructions.contains(&expected_availability_message)); + assert!(default_instructions.contains("prefer using the `request_user_input` tool")); +} + +#[test] +fn default_mode_instructions_use_plain_text_questions_when_feature_disabled() { + let default_instructions = default_preset(CollaborationModesConfig::default()) + .developer_instructions + .expect("default preset should include instructions") + .expect("default instructions should be set"); + + assert!(!default_instructions.contains("prefer using the `request_user_input` tool")); + assert!( + default_instructions.contains("ask the user directly with a concise plain-text question") + ); +} diff --git a/codex-rs/core/src/models_manager/manager.rs b/codex-rs/core/src/models_manager/manager.rs index 35e723d253..89c6cdcb35 100644 --- a/codex-rs/core/src/models_manager/manager.rs +++ b/codex-rs/core/src/models_manager/manager.rs @@ -428,584 +428,5 @@ impl ModelsManager { } #[cfg(test)] -mod tests { - use super::*; - use crate::CodexAuth; - use crate::auth::AuthCredentialsStoreMode; - use crate::config::ConfigBuilder; - use crate::model_provider_info::WireApi; - use chrono::Utc; - use codex_protocol::openai_models::ModelsResponse; - use core_test_support::responses::mount_models_once; - use pretty_assertions::assert_eq; - use serde_json::json; - use tempfile::tempdir; - use wiremock::MockServer; - - fn remote_model(slug: &str, display: &str, priority: i32) -> ModelInfo { - remote_model_with_visibility(slug, display, priority, "list") - } - - fn remote_model_with_visibility( - slug: &str, - display: &str, - priority: i32, - visibility: &str, - ) -> ModelInfo { - serde_json::from_value(json!({ - "slug": slug, - "display_name": display, - "description": format!("{display} desc"), - "default_reasoning_level": "medium", - "supported_reasoning_levels": [{"effort": "low", "description": "low"}, {"effort": "medium", "description": "medium"}], - "shell_type": "shell_command", - "visibility": visibility, - "minimal_client_version": [0, 1, 0], - "supported_in_api": true, - "priority": priority, - "upgrade": null, - "base_instructions": "base instructions", - "supports_reasoning_summaries": false, - "support_verbosity": false, - "default_verbosity": null, - "apply_patch_tool_type": null, - "truncation_policy": {"mode": "bytes", "limit": 10_000}, - "supports_parallel_tool_calls": false, - "supports_image_detail_original": false, - "context_window": 272_000, - "experimental_supported_tools": [], - })) - .expect("valid model") - } - - fn assert_models_contain(actual: &[ModelInfo], expected: &[ModelInfo]) { - for model in expected { - assert!( - actual.iter().any(|candidate| candidate.slug == model.slug), - "expected model {} in cached list", - model.slug - ); - } - } - - fn provider_for(base_url: String) -> ModelProviderInfo { - ModelProviderInfo { - name: "mock".into(), - base_url: Some(base_url), - env_key: None, - env_key_instructions: None, - experimental_bearer_token: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: Some(0), - stream_max_retries: Some(0), - stream_idle_timeout_ms: Some(5_000), - requires_openai_auth: false, - supports_websockets: false, - } - } - - #[tokio::test] - async fn get_model_info_tracks_fallback_usage() { - let codex_home = tempdir().expect("temp dir"); - let config = ConfigBuilder::default() - .codex_home(codex_home.path().to_path_buf()) - .build() - .await - .expect("load default test config"); - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); - let manager = ModelsManager::new( - codex_home.path().to_path_buf(), - auth_manager, - None, - CollaborationModesConfig::default(), - ); - let known_slug = manager - .get_remote_models() - .await - .first() - .expect("bundled models should include at least one model") - .slug - .clone(); - - let known = manager.get_model_info(known_slug.as_str(), &config).await; - assert!(!known.used_fallback_model_metadata); - assert_eq!(known.slug, known_slug); - - let unknown = manager - .get_model_info("model-that-does-not-exist", &config) - .await; - assert!(unknown.used_fallback_model_metadata); - assert_eq!(unknown.slug, "model-that-does-not-exist"); - } - - #[tokio::test] - async fn get_model_info_uses_custom_catalog() { - let codex_home = tempdir().expect("temp dir"); - let config = ConfigBuilder::default() - .codex_home(codex_home.path().to_path_buf()) - .build() - .await - .expect("load default test config"); - let mut overlay = remote_model("gpt-overlay", "Overlay", 0); - overlay.supports_image_detail_original = true; - - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); - let manager = ModelsManager::new( - codex_home.path().to_path_buf(), - auth_manager, - Some(ModelsResponse { - models: vec![overlay], - }), - CollaborationModesConfig::default(), - ); - - let model_info = manager - .get_model_info("gpt-overlay-experiment", &config) - .await; - - assert_eq!(model_info.slug, "gpt-overlay-experiment"); - assert_eq!(model_info.display_name, "Overlay"); - assert_eq!(model_info.context_window, Some(272_000)); - assert!(model_info.supports_image_detail_original); - assert!(!model_info.supports_parallel_tool_calls); - assert!(!model_info.used_fallback_model_metadata); - } - - #[tokio::test] - async fn get_model_info_matches_namespaced_suffix() { - let codex_home = tempdir().expect("temp dir"); - let config = ConfigBuilder::default() - .codex_home(codex_home.path().to_path_buf()) - .build() - .await - .expect("load default test config"); - let mut remote = remote_model("gpt-image", "Image", 0); - remote.supports_image_detail_original = true; - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); - let manager = ModelsManager::new( - codex_home.path().to_path_buf(), - auth_manager, - Some(ModelsResponse { - models: vec![remote], - }), - CollaborationModesConfig::default(), - ); - let namespaced_model = "custom/gpt-image".to_string(); - - let model_info = manager.get_model_info(&namespaced_model, &config).await; - - assert_eq!(model_info.slug, namespaced_model); - assert!(model_info.supports_image_detail_original); - assert!(!model_info.used_fallback_model_metadata); - } - - #[tokio::test] - async fn get_model_info_rejects_multi_segment_namespace_suffix_matching() { - let codex_home = tempdir().expect("temp dir"); - let config = ConfigBuilder::default() - .codex_home(codex_home.path().to_path_buf()) - .build() - .await - .expect("load default test config"); - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); - let manager = ModelsManager::new( - codex_home.path().to_path_buf(), - auth_manager, - None, - CollaborationModesConfig::default(), - ); - let known_slug = manager - .get_remote_models() - .await - .first() - .expect("bundled models should include at least one model") - .slug - .clone(); - let namespaced_model = format!("ns1/ns2/{known_slug}"); - - let model_info = manager.get_model_info(&namespaced_model, &config).await; - - assert_eq!(model_info.slug, namespaced_model); - assert!(model_info.used_fallback_model_metadata); - } - - #[tokio::test] - async fn refresh_available_models_sorts_by_priority() { - let server = MockServer::start().await; - let remote_models = vec![ - remote_model("priority-low", "Low", 1), - remote_model("priority-high", "High", 0), - ]; - let models_mock = mount_models_once( - &server, - ModelsResponse { - models: remote_models.clone(), - }, - ) - .await; - - let codex_home = tempdir().expect("temp dir"); - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); - let provider = provider_for(server.uri()); - let manager = ModelsManager::with_provider_for_tests( - codex_home.path().to_path_buf(), - auth_manager, - provider, - ); - - manager - .refresh_available_models(RefreshStrategy::OnlineIfUncached) - .await - .expect("refresh succeeds"); - let cached_remote = manager.get_remote_models().await; - assert_models_contain(&cached_remote, &remote_models); - - let available = manager.list_models(RefreshStrategy::OnlineIfUncached).await; - let high_idx = available - .iter() - .position(|model| model.model == "priority-high") - .expect("priority-high should be listed"); - let low_idx = available - .iter() - .position(|model| model.model == "priority-low") - .expect("priority-low should be listed"); - assert!( - high_idx < low_idx, - "higher priority should be listed before lower priority" - ); - assert_eq!( - models_mock.requests().len(), - 1, - "expected a single /models request" - ); - } - - #[tokio::test] - async fn refresh_available_models_uses_cache_when_fresh() { - let server = MockServer::start().await; - let remote_models = vec![remote_model("cached", "Cached", 5)]; - let models_mock = mount_models_once( - &server, - ModelsResponse { - models: remote_models.clone(), - }, - ) - .await; - - let codex_home = tempdir().expect("temp dir"); - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); - let provider = provider_for(server.uri()); - let manager = ModelsManager::with_provider_for_tests( - codex_home.path().to_path_buf(), - auth_manager, - provider, - ); - - manager - .refresh_available_models(RefreshStrategy::OnlineIfUncached) - .await - .expect("first refresh succeeds"); - assert_models_contain(&manager.get_remote_models().await, &remote_models); - - // Second call should read from cache and avoid the network. - manager - .refresh_available_models(RefreshStrategy::OnlineIfUncached) - .await - .expect("cached refresh succeeds"); - assert_models_contain(&manager.get_remote_models().await, &remote_models); - assert_eq!( - models_mock.requests().len(), - 1, - "cache hit should avoid a second /models request" - ); - } - - #[tokio::test] - async fn refresh_available_models_refetches_when_cache_stale() { - let server = MockServer::start().await; - let initial_models = vec![remote_model("stale", "Stale", 1)]; - let initial_mock = mount_models_once( - &server, - ModelsResponse { - models: initial_models.clone(), - }, - ) - .await; - - let codex_home = tempdir().expect("temp dir"); - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); - let provider = provider_for(server.uri()); - let manager = ModelsManager::with_provider_for_tests( - codex_home.path().to_path_buf(), - auth_manager, - provider, - ); - - manager - .refresh_available_models(RefreshStrategy::OnlineIfUncached) - .await - .expect("initial refresh succeeds"); - - // Rewrite cache with an old timestamp so it is treated as stale. - manager - .cache_manager - .manipulate_cache_for_test(|fetched_at| { - *fetched_at = Utc::now() - chrono::Duration::hours(1); - }) - .await - .expect("cache manipulation succeeds"); - - let updated_models = vec![remote_model("fresh", "Fresh", 9)]; - server.reset().await; - let refreshed_mock = mount_models_once( - &server, - ModelsResponse { - models: updated_models.clone(), - }, - ) - .await; - - manager - .refresh_available_models(RefreshStrategy::OnlineIfUncached) - .await - .expect("second refresh succeeds"); - assert_models_contain(&manager.get_remote_models().await, &updated_models); - assert_eq!( - initial_mock.requests().len(), - 1, - "initial refresh should only hit /models once" - ); - assert_eq!( - refreshed_mock.requests().len(), - 1, - "stale cache refresh should fetch /models once" - ); - } - - #[tokio::test] - async fn refresh_available_models_refetches_when_version_mismatch() { - let server = MockServer::start().await; - let initial_models = vec![remote_model("old", "Old", 1)]; - let initial_mock = mount_models_once( - &server, - ModelsResponse { - models: initial_models.clone(), - }, - ) - .await; - - let codex_home = tempdir().expect("temp dir"); - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); - let provider = provider_for(server.uri()); - let manager = ModelsManager::with_provider_for_tests( - codex_home.path().to_path_buf(), - auth_manager, - provider, - ); - - manager - .refresh_available_models(RefreshStrategy::OnlineIfUncached) - .await - .expect("initial refresh succeeds"); - - manager - .cache_manager - .mutate_cache_for_test(|cache| { - let client_version = crate::models_manager::client_version_to_whole(); - cache.client_version = Some(format!("{client_version}-mismatch")); - }) - .await - .expect("cache mutation succeeds"); - - let updated_models = vec![remote_model("new", "New", 2)]; - server.reset().await; - let refreshed_mock = mount_models_once( - &server, - ModelsResponse { - models: updated_models.clone(), - }, - ) - .await; - - manager - .refresh_available_models(RefreshStrategy::OnlineIfUncached) - .await - .expect("second refresh succeeds"); - assert_models_contain(&manager.get_remote_models().await, &updated_models); - assert_eq!( - initial_mock.requests().len(), - 1, - "initial refresh should only hit /models once" - ); - assert_eq!( - refreshed_mock.requests().len(), - 1, - "version mismatch should fetch /models once" - ); - } - - #[tokio::test] - async fn refresh_available_models_drops_removed_remote_models() { - let server = MockServer::start().await; - let initial_models = vec![remote_model("remote-old", "Remote Old", 1)]; - let initial_mock = mount_models_once( - &server, - ModelsResponse { - models: initial_models, - }, - ) - .await; - - let codex_home = tempdir().expect("temp dir"); - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); - let provider = provider_for(server.uri()); - let mut manager = ModelsManager::with_provider_for_tests( - codex_home.path().to_path_buf(), - auth_manager, - provider, - ); - manager.cache_manager.set_ttl(Duration::ZERO); - - manager - .refresh_available_models(RefreshStrategy::OnlineIfUncached) - .await - .expect("initial refresh succeeds"); - - server.reset().await; - let refreshed_models = vec![remote_model("remote-new", "Remote New", 1)]; - let refreshed_mock = mount_models_once( - &server, - ModelsResponse { - models: refreshed_models, - }, - ) - .await; - - manager - .refresh_available_models(RefreshStrategy::OnlineIfUncached) - .await - .expect("second refresh succeeds"); - - let available = manager - .try_list_models() - .expect("models should be available"); - assert!( - available.iter().any(|preset| preset.model == "remote-new"), - "new remote model should be listed" - ); - assert!( - !available.iter().any(|preset| preset.model == "remote-old"), - "removed remote model should not be listed" - ); - assert_eq!( - initial_mock.requests().len(), - 1, - "initial refresh should only hit /models once" - ); - assert_eq!( - refreshed_mock.requests().len(), - 1, - "second refresh should only hit /models once" - ); - } - - #[tokio::test] - async fn refresh_available_models_skips_network_without_chatgpt_auth() { - let server = MockServer::start().await; - let dynamic_slug = "dynamic-model-only-for-test-noauth"; - let models_mock = mount_models_once( - &server, - ModelsResponse { - models: vec![remote_model(dynamic_slug, "No Auth", 1)], - }, - ) - .await; - - let codex_home = tempdir().expect("temp dir"); - let auth_manager = Arc::new(AuthManager::new( - codex_home.path().to_path_buf(), - false, - AuthCredentialsStoreMode::File, - )); - let provider = provider_for(server.uri()); - let manager = ModelsManager::with_provider_for_tests( - codex_home.path().to_path_buf(), - auth_manager, - provider, - ); - - manager - .refresh_available_models(RefreshStrategy::Online) - .await - .expect("refresh should no-op without chatgpt auth"); - let cached_remote = manager.get_remote_models().await; - assert!( - !cached_remote - .iter() - .any(|candidate| candidate.slug == dynamic_slug), - "remote refresh should be skipped without chatgpt auth" - ); - assert_eq!( - models_mock.requests().len(), - 0, - "no auth should avoid /models requests" - ); - } - - #[test] - fn build_available_models_picks_default_after_hiding_hidden_models() { - let codex_home = tempdir().expect("temp dir"); - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); - let provider = provider_for("http://example.test".to_string()); - let manager = ModelsManager::with_provider_for_tests( - codex_home.path().to_path_buf(), - auth_manager, - provider, - ); - - let hidden_model = remote_model_with_visibility("hidden", "Hidden", 0, "hide"); - let visible_model = remote_model_with_visibility("visible", "Visible", 1, "list"); - - let expected_hidden = ModelPreset::from(hidden_model.clone()); - let mut expected_visible = ModelPreset::from(visible_model.clone()); - expected_visible.is_default = true; - - let available = manager.build_available_models(vec![hidden_model, visible_model]); - - assert_eq!(available, vec![expected_hidden, expected_visible]); - } - - #[test] - fn bundled_models_json_roundtrips() { - let file_contents = include_str!("../../models.json"); - let response: ModelsResponse = - serde_json::from_str(file_contents).expect("bundled models.json should deserialize"); - - let serialized = - serde_json::to_string(&response).expect("bundled models.json should serialize"); - let roundtripped: ModelsResponse = - serde_json::from_str(&serialized).expect("serialized models.json should deserialize"); - - assert_eq!( - response, roundtripped, - "bundled models.json should round trip through serde" - ); - assert!( - !response.models.is_empty(), - "bundled models.json should contain at least one model" - ); - } -} +#[path = "manager_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/models_manager/manager_tests.rs b/codex-rs/core/src/models_manager/manager_tests.rs new file mode 100644 index 0000000000..6981d6d799 --- /dev/null +++ b/codex-rs/core/src/models_manager/manager_tests.rs @@ -0,0 +1,574 @@ +use super::*; +use crate::CodexAuth; +use crate::auth::AuthCredentialsStoreMode; +use crate::config::ConfigBuilder; +use crate::model_provider_info::WireApi; +use chrono::Utc; +use codex_protocol::openai_models::ModelsResponse; +use core_test_support::responses::mount_models_once; +use pretty_assertions::assert_eq; +use serde_json::json; +use tempfile::tempdir; +use wiremock::MockServer; + +fn remote_model(slug: &str, display: &str, priority: i32) -> ModelInfo { + remote_model_with_visibility(slug, display, priority, "list") +} + +fn remote_model_with_visibility( + slug: &str, + display: &str, + priority: i32, + visibility: &str, +) -> ModelInfo { + serde_json::from_value(json!({ + "slug": slug, + "display_name": display, + "description": format!("{display} desc"), + "default_reasoning_level": "medium", + "supported_reasoning_levels": [{"effort": "low", "description": "low"}, {"effort": "medium", "description": "medium"}], + "shell_type": "shell_command", + "visibility": visibility, + "minimal_client_version": [0, 1, 0], + "supported_in_api": true, + "priority": priority, + "upgrade": null, + "base_instructions": "base instructions", + "supports_reasoning_summaries": false, + "support_verbosity": false, + "default_verbosity": null, + "apply_patch_tool_type": null, + "truncation_policy": {"mode": "bytes", "limit": 10_000}, + "supports_parallel_tool_calls": false, + "supports_image_detail_original": false, + "context_window": 272_000, + "experimental_supported_tools": [], + })) + .expect("valid model") +} + +fn assert_models_contain(actual: &[ModelInfo], expected: &[ModelInfo]) { + for model in expected { + assert!( + actual.iter().any(|candidate| candidate.slug == model.slug), + "expected model {} in cached list", + model.slug + ); + } +} + +fn provider_for(base_url: String) -> ModelProviderInfo { + ModelProviderInfo { + name: "mock".into(), + base_url: Some(base_url), + env_key: None, + env_key_instructions: None, + experimental_bearer_token: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: None, + env_http_headers: None, + request_max_retries: Some(0), + stream_max_retries: Some(0), + stream_idle_timeout_ms: Some(5_000), + requires_openai_auth: false, + supports_websockets: false, + } +} + +#[tokio::test] +async fn get_model_info_tracks_fallback_usage() { + let codex_home = tempdir().expect("temp dir"); + let config = ConfigBuilder::default() + .codex_home(codex_home.path().to_path_buf()) + .build() + .await + .expect("load default test config"); + let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); + let manager = ModelsManager::new( + codex_home.path().to_path_buf(), + auth_manager, + None, + CollaborationModesConfig::default(), + ); + let known_slug = manager + .get_remote_models() + .await + .first() + .expect("bundled models should include at least one model") + .slug + .clone(); + + let known = manager.get_model_info(known_slug.as_str(), &config).await; + assert!(!known.used_fallback_model_metadata); + assert_eq!(known.slug, known_slug); + + let unknown = manager + .get_model_info("model-that-does-not-exist", &config) + .await; + assert!(unknown.used_fallback_model_metadata); + assert_eq!(unknown.slug, "model-that-does-not-exist"); +} + +#[tokio::test] +async fn get_model_info_uses_custom_catalog() { + let codex_home = tempdir().expect("temp dir"); + let config = ConfigBuilder::default() + .codex_home(codex_home.path().to_path_buf()) + .build() + .await + .expect("load default test config"); + let mut overlay = remote_model("gpt-overlay", "Overlay", 0); + overlay.supports_image_detail_original = true; + + let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); + let manager = ModelsManager::new( + codex_home.path().to_path_buf(), + auth_manager, + Some(ModelsResponse { + models: vec![overlay], + }), + CollaborationModesConfig::default(), + ); + + let model_info = manager + .get_model_info("gpt-overlay-experiment", &config) + .await; + + assert_eq!(model_info.slug, "gpt-overlay-experiment"); + assert_eq!(model_info.display_name, "Overlay"); + assert_eq!(model_info.context_window, Some(272_000)); + assert!(model_info.supports_image_detail_original); + assert!(!model_info.supports_parallel_tool_calls); + assert!(!model_info.used_fallback_model_metadata); +} + +#[tokio::test] +async fn get_model_info_matches_namespaced_suffix() { + let codex_home = tempdir().expect("temp dir"); + let config = ConfigBuilder::default() + .codex_home(codex_home.path().to_path_buf()) + .build() + .await + .expect("load default test config"); + let mut remote = remote_model("gpt-image", "Image", 0); + remote.supports_image_detail_original = true; + let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); + let manager = ModelsManager::new( + codex_home.path().to_path_buf(), + auth_manager, + Some(ModelsResponse { + models: vec![remote], + }), + CollaborationModesConfig::default(), + ); + let namespaced_model = "custom/gpt-image".to_string(); + + let model_info = manager.get_model_info(&namespaced_model, &config).await; + + assert_eq!(model_info.slug, namespaced_model); + assert!(model_info.supports_image_detail_original); + assert!(!model_info.used_fallback_model_metadata); +} + +#[tokio::test] +async fn get_model_info_rejects_multi_segment_namespace_suffix_matching() { + let codex_home = tempdir().expect("temp dir"); + let config = ConfigBuilder::default() + .codex_home(codex_home.path().to_path_buf()) + .build() + .await + .expect("load default test config"); + let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); + let manager = ModelsManager::new( + codex_home.path().to_path_buf(), + auth_manager, + None, + CollaborationModesConfig::default(), + ); + let known_slug = manager + .get_remote_models() + .await + .first() + .expect("bundled models should include at least one model") + .slug + .clone(); + let namespaced_model = format!("ns1/ns2/{known_slug}"); + + let model_info = manager.get_model_info(&namespaced_model, &config).await; + + assert_eq!(model_info.slug, namespaced_model); + assert!(model_info.used_fallback_model_metadata); +} + +#[tokio::test] +async fn refresh_available_models_sorts_by_priority() { + let server = MockServer::start().await; + let remote_models = vec![ + remote_model("priority-low", "Low", 1), + remote_model("priority-high", "High", 0), + ]; + let models_mock = mount_models_once( + &server, + ModelsResponse { + models: remote_models.clone(), + }, + ) + .await; + + let codex_home = tempdir().expect("temp dir"); + let auth_manager = + AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); + let provider = provider_for(server.uri()); + let manager = ModelsManager::with_provider_for_tests( + codex_home.path().to_path_buf(), + auth_manager, + provider, + ); + + manager + .refresh_available_models(RefreshStrategy::OnlineIfUncached) + .await + .expect("refresh succeeds"); + let cached_remote = manager.get_remote_models().await; + assert_models_contain(&cached_remote, &remote_models); + + let available = manager.list_models(RefreshStrategy::OnlineIfUncached).await; + let high_idx = available + .iter() + .position(|model| model.model == "priority-high") + .expect("priority-high should be listed"); + let low_idx = available + .iter() + .position(|model| model.model == "priority-low") + .expect("priority-low should be listed"); + assert!( + high_idx < low_idx, + "higher priority should be listed before lower priority" + ); + assert_eq!( + models_mock.requests().len(), + 1, + "expected a single /models request" + ); +} + +#[tokio::test] +async fn refresh_available_models_uses_cache_when_fresh() { + let server = MockServer::start().await; + let remote_models = vec![remote_model("cached", "Cached", 5)]; + let models_mock = mount_models_once( + &server, + ModelsResponse { + models: remote_models.clone(), + }, + ) + .await; + + let codex_home = tempdir().expect("temp dir"); + let auth_manager = + AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); + let provider = provider_for(server.uri()); + let manager = ModelsManager::with_provider_for_tests( + codex_home.path().to_path_buf(), + auth_manager, + provider, + ); + + manager + .refresh_available_models(RefreshStrategy::OnlineIfUncached) + .await + .expect("first refresh succeeds"); + assert_models_contain(&manager.get_remote_models().await, &remote_models); + + // Second call should read from cache and avoid the network. + manager + .refresh_available_models(RefreshStrategy::OnlineIfUncached) + .await + .expect("cached refresh succeeds"); + assert_models_contain(&manager.get_remote_models().await, &remote_models); + assert_eq!( + models_mock.requests().len(), + 1, + "cache hit should avoid a second /models request" + ); +} + +#[tokio::test] +async fn refresh_available_models_refetches_when_cache_stale() { + let server = MockServer::start().await; + let initial_models = vec![remote_model("stale", "Stale", 1)]; + let initial_mock = mount_models_once( + &server, + ModelsResponse { + models: initial_models.clone(), + }, + ) + .await; + + let codex_home = tempdir().expect("temp dir"); + let auth_manager = + AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); + let provider = provider_for(server.uri()); + let manager = ModelsManager::with_provider_for_tests( + codex_home.path().to_path_buf(), + auth_manager, + provider, + ); + + manager + .refresh_available_models(RefreshStrategy::OnlineIfUncached) + .await + .expect("initial refresh succeeds"); + + // Rewrite cache with an old timestamp so it is treated as stale. + manager + .cache_manager + .manipulate_cache_for_test(|fetched_at| { + *fetched_at = Utc::now() - chrono::Duration::hours(1); + }) + .await + .expect("cache manipulation succeeds"); + + let updated_models = vec![remote_model("fresh", "Fresh", 9)]; + server.reset().await; + let refreshed_mock = mount_models_once( + &server, + ModelsResponse { + models: updated_models.clone(), + }, + ) + .await; + + manager + .refresh_available_models(RefreshStrategy::OnlineIfUncached) + .await + .expect("second refresh succeeds"); + assert_models_contain(&manager.get_remote_models().await, &updated_models); + assert_eq!( + initial_mock.requests().len(), + 1, + "initial refresh should only hit /models once" + ); + assert_eq!( + refreshed_mock.requests().len(), + 1, + "stale cache refresh should fetch /models once" + ); +} + +#[tokio::test] +async fn refresh_available_models_refetches_when_version_mismatch() { + let server = MockServer::start().await; + let initial_models = vec![remote_model("old", "Old", 1)]; + let initial_mock = mount_models_once( + &server, + ModelsResponse { + models: initial_models.clone(), + }, + ) + .await; + + let codex_home = tempdir().expect("temp dir"); + let auth_manager = + AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); + let provider = provider_for(server.uri()); + let manager = ModelsManager::with_provider_for_tests( + codex_home.path().to_path_buf(), + auth_manager, + provider, + ); + + manager + .refresh_available_models(RefreshStrategy::OnlineIfUncached) + .await + .expect("initial refresh succeeds"); + + manager + .cache_manager + .mutate_cache_for_test(|cache| { + let client_version = crate::models_manager::client_version_to_whole(); + cache.client_version = Some(format!("{client_version}-mismatch")); + }) + .await + .expect("cache mutation succeeds"); + + let updated_models = vec![remote_model("new", "New", 2)]; + server.reset().await; + let refreshed_mock = mount_models_once( + &server, + ModelsResponse { + models: updated_models.clone(), + }, + ) + .await; + + manager + .refresh_available_models(RefreshStrategy::OnlineIfUncached) + .await + .expect("second refresh succeeds"); + assert_models_contain(&manager.get_remote_models().await, &updated_models); + assert_eq!( + initial_mock.requests().len(), + 1, + "initial refresh should only hit /models once" + ); + assert_eq!( + refreshed_mock.requests().len(), + 1, + "version mismatch should fetch /models once" + ); +} + +#[tokio::test] +async fn refresh_available_models_drops_removed_remote_models() { + let server = MockServer::start().await; + let initial_models = vec![remote_model("remote-old", "Remote Old", 1)]; + let initial_mock = mount_models_once( + &server, + ModelsResponse { + models: initial_models, + }, + ) + .await; + + let codex_home = tempdir().expect("temp dir"); + let auth_manager = + AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); + let provider = provider_for(server.uri()); + let mut manager = ModelsManager::with_provider_for_tests( + codex_home.path().to_path_buf(), + auth_manager, + provider, + ); + manager.cache_manager.set_ttl(Duration::ZERO); + + manager + .refresh_available_models(RefreshStrategy::OnlineIfUncached) + .await + .expect("initial refresh succeeds"); + + server.reset().await; + let refreshed_models = vec![remote_model("remote-new", "Remote New", 1)]; + let refreshed_mock = mount_models_once( + &server, + ModelsResponse { + models: refreshed_models, + }, + ) + .await; + + manager + .refresh_available_models(RefreshStrategy::OnlineIfUncached) + .await + .expect("second refresh succeeds"); + + let available = manager + .try_list_models() + .expect("models should be available"); + assert!( + available.iter().any(|preset| preset.model == "remote-new"), + "new remote model should be listed" + ); + assert!( + !available.iter().any(|preset| preset.model == "remote-old"), + "removed remote model should not be listed" + ); + assert_eq!( + initial_mock.requests().len(), + 1, + "initial refresh should only hit /models once" + ); + assert_eq!( + refreshed_mock.requests().len(), + 1, + "second refresh should only hit /models once" + ); +} + +#[tokio::test] +async fn refresh_available_models_skips_network_without_chatgpt_auth() { + let server = MockServer::start().await; + let dynamic_slug = "dynamic-model-only-for-test-noauth"; + let models_mock = mount_models_once( + &server, + ModelsResponse { + models: vec![remote_model(dynamic_slug, "No Auth", 1)], + }, + ) + .await; + + let codex_home = tempdir().expect("temp dir"); + let auth_manager = Arc::new(AuthManager::new( + codex_home.path().to_path_buf(), + false, + AuthCredentialsStoreMode::File, + )); + let provider = provider_for(server.uri()); + let manager = ModelsManager::with_provider_for_tests( + codex_home.path().to_path_buf(), + auth_manager, + provider, + ); + + manager + .refresh_available_models(RefreshStrategy::Online) + .await + .expect("refresh should no-op without chatgpt auth"); + let cached_remote = manager.get_remote_models().await; + assert!( + !cached_remote + .iter() + .any(|candidate| candidate.slug == dynamic_slug), + "remote refresh should be skipped without chatgpt auth" + ); + assert_eq!( + models_mock.requests().len(), + 0, + "no auth should avoid /models requests" + ); +} + +#[test] +fn build_available_models_picks_default_after_hiding_hidden_models() { + let codex_home = tempdir().expect("temp dir"); + let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); + let provider = provider_for("http://example.test".to_string()); + let manager = ModelsManager::with_provider_for_tests( + codex_home.path().to_path_buf(), + auth_manager, + provider, + ); + + let hidden_model = remote_model_with_visibility("hidden", "Hidden", 0, "hide"); + let visible_model = remote_model_with_visibility("visible", "Visible", 1, "list"); + + let expected_hidden = ModelPreset::from(hidden_model.clone()); + let mut expected_visible = ModelPreset::from(visible_model.clone()); + expected_visible.is_default = true; + + let available = manager.build_available_models(vec![hidden_model, visible_model]); + + assert_eq!(available, vec![expected_hidden, expected_visible]); +} + +#[test] +fn bundled_models_json_roundtrips() { + let file_contents = include_str!("../../models.json"); + let response: ModelsResponse = + serde_json::from_str(file_contents).expect("bundled models.json should deserialize"); + + let serialized = + serde_json::to_string(&response).expect("bundled models.json should serialize"); + let roundtripped: ModelsResponse = + serde_json::from_str(&serialized).expect("serialized models.json should deserialize"); + + assert_eq!( + response, roundtripped, + "bundled models.json should round trip through serde" + ); + assert!( + !response.models.is_empty(), + "bundled models.json should contain at least one model" + ); +} diff --git a/codex-rs/core/src/models_manager/model_info.rs b/codex-rs/core/src/models_manager/model_info.rs index 3664a52660..d82cb92b21 100644 --- a/codex-rs/core/src/models_manager/model_info.rs +++ b/codex-rs/core/src/models_manager/model_info.rs @@ -110,44 +110,5 @@ fn local_personality_messages_for_slug(slug: &str) -> Option { } #[cfg(test)] -mod tests { - use super::*; - use crate::config::test_config; - use pretty_assertions::assert_eq; - - #[test] - fn reasoning_summaries_override_true_enables_support() { - let model = model_info_from_slug("unknown-model"); - let mut config = test_config(); - config.model_supports_reasoning_summaries = Some(true); - - let updated = with_config_overrides(model.clone(), &config); - let mut expected = model; - expected.supports_reasoning_summaries = true; - - assert_eq!(updated, expected); - } - - #[test] - fn reasoning_summaries_override_false_does_not_disable_support() { - let mut model = model_info_from_slug("unknown-model"); - model.supports_reasoning_summaries = true; - let mut config = test_config(); - config.model_supports_reasoning_summaries = Some(false); - - let updated = with_config_overrides(model.clone(), &config); - - assert_eq!(updated, model); - } - - #[test] - fn reasoning_summaries_override_false_is_noop_when_model_is_false() { - let model = model_info_from_slug("unknown-model"); - let mut config = test_config(); - config.model_supports_reasoning_summaries = Some(false); - - let updated = with_config_overrides(model.clone(), &config); - - assert_eq!(updated, model); - } -} +#[path = "model_info_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/models_manager/model_info_tests.rs b/codex-rs/core/src/models_manager/model_info_tests.rs new file mode 100644 index 0000000000..27bac8d70b --- /dev/null +++ b/codex-rs/core/src/models_manager/model_info_tests.rs @@ -0,0 +1,39 @@ +use super::*; +use crate::config::test_config; +use pretty_assertions::assert_eq; + +#[test] +fn reasoning_summaries_override_true_enables_support() { + let model = model_info_from_slug("unknown-model"); + let mut config = test_config(); + config.model_supports_reasoning_summaries = Some(true); + + let updated = with_config_overrides(model.clone(), &config); + let mut expected = model; + expected.supports_reasoning_summaries = true; + + assert_eq!(updated, expected); +} + +#[test] +fn reasoning_summaries_override_false_does_not_disable_support() { + let mut model = model_info_from_slug("unknown-model"); + model.supports_reasoning_summaries = true; + let mut config = test_config(); + config.model_supports_reasoning_summaries = Some(false); + + let updated = with_config_overrides(model.clone(), &config); + + assert_eq!(updated, model); +} + +#[test] +fn reasoning_summaries_override_false_is_noop_when_model_is_false() { + let model = model_info_from_slug("unknown-model"); + let mut config = test_config(); + config.model_supports_reasoning_summaries = Some(false); + + let updated = with_config_overrides(model.clone(), &config); + + assert_eq!(updated, model); +} diff --git a/codex-rs/core/src/network_policy_decision.rs b/codex-rs/core/src/network_policy_decision.rs index e40ae854c4..484905cfd9 100644 --- a/codex-rs/core/src/network_policy_decision.rs +++ b/codex-rs/core/src/network_policy_decision.rs @@ -121,196 +121,5 @@ pub(crate) fn execpolicy_network_rule_amendment( } #[cfg(test)] -mod tests { - use super::*; - use codex_network_proxy::BlockedRequest; - use codex_protocol::approvals::NetworkPolicyAmendment; - use codex_protocol::approvals::NetworkPolicyRuleAction; - use pretty_assertions::assert_eq; - - #[test] - fn network_approval_context_requires_ask_from_decider() { - let payload = NetworkPolicyDecisionPayload { - decision: NetworkPolicyDecision::Deny, - source: NetworkDecisionSource::Decider, - protocol: Some(NetworkApprovalProtocol::Https), - host: Some("example.com".to_string()), - reason: Some("not_allowed".to_string()), - port: Some(443), - }; - - assert_eq!(network_approval_context_from_payload(&payload), None); - } - - #[test] - fn network_approval_context_maps_http_https_and_socks_protocols() { - let http_payload = NetworkPolicyDecisionPayload { - decision: NetworkPolicyDecision::Ask, - source: NetworkDecisionSource::Decider, - protocol: Some(NetworkApprovalProtocol::Http), - host: Some("example.com".to_string()), - reason: Some("not_allowed".to_string()), - port: Some(80), - }; - assert_eq!( - network_approval_context_from_payload(&http_payload), - Some(NetworkApprovalContext { - host: "example.com".to_string(), - protocol: NetworkApprovalProtocol::Http, - }) - ); - - let https_payload = NetworkPolicyDecisionPayload { - decision: NetworkPolicyDecision::Ask, - source: NetworkDecisionSource::Decider, - protocol: Some(NetworkApprovalProtocol::Https), - host: Some("example.com".to_string()), - reason: Some("not_allowed".to_string()), - port: Some(443), - }; - assert_eq!( - network_approval_context_from_payload(&https_payload), - Some(NetworkApprovalContext { - host: "example.com".to_string(), - protocol: NetworkApprovalProtocol::Https, - }) - ); - - let http_connect_payload = NetworkPolicyDecisionPayload { - decision: NetworkPolicyDecision::Ask, - source: NetworkDecisionSource::Decider, - protocol: Some(NetworkApprovalProtocol::Https), - host: Some("example.com".to_string()), - reason: Some("not_allowed".to_string()), - port: Some(443), - }; - assert_eq!( - network_approval_context_from_payload(&http_connect_payload), - Some(NetworkApprovalContext { - host: "example.com".to_string(), - protocol: NetworkApprovalProtocol::Https, - }) - ); - - let socks5_tcp_payload = NetworkPolicyDecisionPayload { - decision: NetworkPolicyDecision::Ask, - source: NetworkDecisionSource::Decider, - protocol: Some(NetworkApprovalProtocol::Socks5Tcp), - host: Some("example.com".to_string()), - reason: Some("not_allowed".to_string()), - port: Some(443), - }; - assert_eq!( - network_approval_context_from_payload(&socks5_tcp_payload), - Some(NetworkApprovalContext { - host: "example.com".to_string(), - protocol: NetworkApprovalProtocol::Socks5Tcp, - }) - ); - - let socks5_udp_payload = NetworkPolicyDecisionPayload { - decision: NetworkPolicyDecision::Ask, - source: NetworkDecisionSource::Decider, - protocol: Some(NetworkApprovalProtocol::Socks5Udp), - host: Some("example.com".to_string()), - reason: Some("not_allowed".to_string()), - port: Some(443), - }; - assert_eq!( - network_approval_context_from_payload(&socks5_udp_payload), - Some(NetworkApprovalContext { - host: "example.com".to_string(), - protocol: NetworkApprovalProtocol::Socks5Udp, - }) - ); - } - - #[test] - fn network_policy_decision_payload_deserializes_proxy_protocol_aliases() { - let payload: NetworkPolicyDecisionPayload = serde_json::from_str( - r#"{ - "decision":"ask", - "source":"decider", - "protocol":"https_connect", - "host":"example.com", - "reason":"not_allowed", - "port":443 - }"#, - ) - .expect("payload should deserialize"); - assert_eq!(payload.protocol, Some(NetworkApprovalProtocol::Https)); - - let payload: NetworkPolicyDecisionPayload = serde_json::from_str( - r#"{ - "decision":"ask", - "source":"decider", - "protocol":"http-connect", - "host":"example.com", - "reason":"not_allowed", - "port":443 - }"#, - ) - .expect("payload should deserialize"); - assert_eq!(payload.protocol, Some(NetworkApprovalProtocol::Https)); - } - - #[test] - fn execpolicy_network_rule_amendment_maps_protocol_action_and_justification() { - let amendment = NetworkPolicyAmendment { - action: NetworkPolicyRuleAction::Deny, - host: "example.com".to_string(), - }; - let context = NetworkApprovalContext { - host: "example.com".to_string(), - protocol: NetworkApprovalProtocol::Socks5Udp, - }; - - assert_eq!( - execpolicy_network_rule_amendment(&amendment, &context, "example.com"), - ExecPolicyNetworkRuleAmendment { - protocol: ExecPolicyNetworkRuleProtocol::Socks5Udp, - decision: ExecPolicyDecision::Forbidden, - justification: "Deny socks5_udp access to example.com".to_string(), - } - ); - } - - #[test] - fn denied_network_policy_message_requires_deny_decision() { - let blocked = BlockedRequest { - host: "example.com".to_string(), - reason: "not_allowed".to_string(), - client: None, - method: Some("GET".to_string()), - mode: None, - protocol: "http".to_string(), - decision: Some("ask".to_string()), - source: Some("decider".to_string()), - port: Some(80), - timestamp: 0, - }; - assert_eq!(denied_network_policy_message(&blocked), None); - } - - #[test] - fn denied_network_policy_message_for_denylist_block_is_explicit() { - let blocked = BlockedRequest { - host: "example.com".to_string(), - reason: "denied".to_string(), - client: None, - method: Some("GET".to_string()), - mode: None, - protocol: "http".to_string(), - decision: Some("deny".to_string()), - source: Some("baseline_policy".to_string()), - port: Some(80), - timestamp: 0, - }; - assert_eq!( - denied_network_policy_message(&blocked), - Some( - "Network access to \"example.com\" was blocked: domain is explicitly denied by policy and cannot be approved from this prompt.".to_string() - ) - ); - } -} +#[path = "network_policy_decision_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/network_policy_decision_tests.rs b/codex-rs/core/src/network_policy_decision_tests.rs new file mode 100644 index 0000000000..ebb17f724f --- /dev/null +++ b/codex-rs/core/src/network_policy_decision_tests.rs @@ -0,0 +1,191 @@ +use super::*; +use codex_network_proxy::BlockedRequest; +use codex_protocol::approvals::NetworkPolicyAmendment; +use codex_protocol::approvals::NetworkPolicyRuleAction; +use pretty_assertions::assert_eq; + +#[test] +fn network_approval_context_requires_ask_from_decider() { + let payload = NetworkPolicyDecisionPayload { + decision: NetworkPolicyDecision::Deny, + source: NetworkDecisionSource::Decider, + protocol: Some(NetworkApprovalProtocol::Https), + host: Some("example.com".to_string()), + reason: Some("not_allowed".to_string()), + port: Some(443), + }; + + assert_eq!(network_approval_context_from_payload(&payload), None); +} + +#[test] +fn network_approval_context_maps_http_https_and_socks_protocols() { + let http_payload = NetworkPolicyDecisionPayload { + decision: NetworkPolicyDecision::Ask, + source: NetworkDecisionSource::Decider, + protocol: Some(NetworkApprovalProtocol::Http), + host: Some("example.com".to_string()), + reason: Some("not_allowed".to_string()), + port: Some(80), + }; + assert_eq!( + network_approval_context_from_payload(&http_payload), + Some(NetworkApprovalContext { + host: "example.com".to_string(), + protocol: NetworkApprovalProtocol::Http, + }) + ); + + let https_payload = NetworkPolicyDecisionPayload { + decision: NetworkPolicyDecision::Ask, + source: NetworkDecisionSource::Decider, + protocol: Some(NetworkApprovalProtocol::Https), + host: Some("example.com".to_string()), + reason: Some("not_allowed".to_string()), + port: Some(443), + }; + assert_eq!( + network_approval_context_from_payload(&https_payload), + Some(NetworkApprovalContext { + host: "example.com".to_string(), + protocol: NetworkApprovalProtocol::Https, + }) + ); + + let http_connect_payload = NetworkPolicyDecisionPayload { + decision: NetworkPolicyDecision::Ask, + source: NetworkDecisionSource::Decider, + protocol: Some(NetworkApprovalProtocol::Https), + host: Some("example.com".to_string()), + reason: Some("not_allowed".to_string()), + port: Some(443), + }; + assert_eq!( + network_approval_context_from_payload(&http_connect_payload), + Some(NetworkApprovalContext { + host: "example.com".to_string(), + protocol: NetworkApprovalProtocol::Https, + }) + ); + + let socks5_tcp_payload = NetworkPolicyDecisionPayload { + decision: NetworkPolicyDecision::Ask, + source: NetworkDecisionSource::Decider, + protocol: Some(NetworkApprovalProtocol::Socks5Tcp), + host: Some("example.com".to_string()), + reason: Some("not_allowed".to_string()), + port: Some(443), + }; + assert_eq!( + network_approval_context_from_payload(&socks5_tcp_payload), + Some(NetworkApprovalContext { + host: "example.com".to_string(), + protocol: NetworkApprovalProtocol::Socks5Tcp, + }) + ); + + let socks5_udp_payload = NetworkPolicyDecisionPayload { + decision: NetworkPolicyDecision::Ask, + source: NetworkDecisionSource::Decider, + protocol: Some(NetworkApprovalProtocol::Socks5Udp), + host: Some("example.com".to_string()), + reason: Some("not_allowed".to_string()), + port: Some(443), + }; + assert_eq!( + network_approval_context_from_payload(&socks5_udp_payload), + Some(NetworkApprovalContext { + host: "example.com".to_string(), + protocol: NetworkApprovalProtocol::Socks5Udp, + }) + ); +} + +#[test] +fn network_policy_decision_payload_deserializes_proxy_protocol_aliases() { + let payload: NetworkPolicyDecisionPayload = serde_json::from_str( + r#"{ + "decision":"ask", + "source":"decider", + "protocol":"https_connect", + "host":"example.com", + "reason":"not_allowed", + "port":443 + }"#, + ) + .expect("payload should deserialize"); + assert_eq!(payload.protocol, Some(NetworkApprovalProtocol::Https)); + + let payload: NetworkPolicyDecisionPayload = serde_json::from_str( + r#"{ + "decision":"ask", + "source":"decider", + "protocol":"http-connect", + "host":"example.com", + "reason":"not_allowed", + "port":443 + }"#, + ) + .expect("payload should deserialize"); + assert_eq!(payload.protocol, Some(NetworkApprovalProtocol::Https)); +} + +#[test] +fn execpolicy_network_rule_amendment_maps_protocol_action_and_justification() { + let amendment = NetworkPolicyAmendment { + action: NetworkPolicyRuleAction::Deny, + host: "example.com".to_string(), + }; + let context = NetworkApprovalContext { + host: "example.com".to_string(), + protocol: NetworkApprovalProtocol::Socks5Udp, + }; + + assert_eq!( + execpolicy_network_rule_amendment(&amendment, &context, "example.com"), + ExecPolicyNetworkRuleAmendment { + protocol: ExecPolicyNetworkRuleProtocol::Socks5Udp, + decision: ExecPolicyDecision::Forbidden, + justification: "Deny socks5_udp access to example.com".to_string(), + } + ); +} + +#[test] +fn denied_network_policy_message_requires_deny_decision() { + let blocked = BlockedRequest { + host: "example.com".to_string(), + reason: "not_allowed".to_string(), + client: None, + method: Some("GET".to_string()), + mode: None, + protocol: "http".to_string(), + decision: Some("ask".to_string()), + source: Some("decider".to_string()), + port: Some(80), + timestamp: 0, + }; + assert_eq!(denied_network_policy_message(&blocked), None); +} + +#[test] +fn denied_network_policy_message_for_denylist_block_is_explicit() { + let blocked = BlockedRequest { + host: "example.com".to_string(), + reason: "denied".to_string(), + client: None, + method: Some("GET".to_string()), + mode: None, + protocol: "http".to_string(), + decision: Some("deny".to_string()), + source: Some("baseline_policy".to_string()), + port: Some(80), + timestamp: 0, + }; + assert_eq!( + denied_network_policy_message(&blocked), + Some( + "Network access to \"example.com\" was blocked: domain is explicitly denied by policy and cannot be approved from this prompt.".to_string() + ) + ); +} diff --git a/codex-rs/core/src/network_proxy_loader.rs b/codex-rs/core/src/network_proxy_loader.rs index 1c8244f70b..0509988510 100644 --- a/codex-rs/core/src/network_proxy_loader.rs +++ b/codex-rs/core/src/network_proxy_loader.rs @@ -304,109 +304,5 @@ impl ConfigReloader for MtimeConfigReloader { } #[cfg(test)] -mod tests { - use super::*; - - use codex_execpolicy::Decision; - use codex_execpolicy::NetworkRuleProtocol; - use codex_execpolicy::Policy; - use pretty_assertions::assert_eq; - - #[test] - fn higher_precedence_profile_network_beats_lower_profile_network() { - let lower_network: toml::Value = toml::from_str( - r#" -default_permissions = "workspace" - -[permissions.workspace.network] -allowed_domains = ["lower.example.com"] -"#, - ) - .expect("lower layer should parse"); - let higher_network: toml::Value = toml::from_str( - r#" -default_permissions = "workspace" - -[permissions.workspace.network] -allowed_domains = ["higher.example.com"] -"#, - ) - .expect("higher layer should parse"); - - let mut config = NetworkProxyConfig::default(); - apply_network_tables( - &mut config, - network_tables_from_toml(&lower_network).expect("lower layer should deserialize"), - ) - .expect("lower layer should apply"); - apply_network_tables( - &mut config, - network_tables_from_toml(&higher_network).expect("higher layer should deserialize"), - ) - .expect("higher layer should apply"); - - assert_eq!(config.network.allowed_domains, vec!["higher.example.com"]); - } - - #[test] - fn execpolicy_network_rules_overlay_network_lists() { - let mut config = NetworkProxyConfig::default(); - config.network.allowed_domains = vec!["config.example.com".to_string()]; - config.network.denied_domains = vec!["blocked.example.com".to_string()]; - - let mut exec_policy = Policy::empty(); - exec_policy - .add_network_rule( - "blocked.example.com", - NetworkRuleProtocol::Https, - Decision::Allow, - None, - ) - .expect("allow rule should be valid"); - exec_policy - .add_network_rule( - "api.example.com", - NetworkRuleProtocol::Http, - Decision::Forbidden, - None, - ) - .expect("deny rule should be valid"); - - apply_exec_policy_network_rules(&mut config, &exec_policy); - - assert_eq!( - config.network.allowed_domains, - vec![ - "config.example.com".to_string(), - "blocked.example.com".to_string() - ] - ); - assert_eq!( - config.network.denied_domains, - vec!["api.example.com".to_string()] - ); - } - - #[test] - fn apply_network_constraints_includes_allow_all_unix_sockets_flag() { - let config: toml::Value = toml::from_str( - r#" -default_permissions = "workspace" - -[permissions.workspace.network] -dangerously_allow_all_unix_sockets = true -"#, - ) - .expect("permissions profile should parse"); - let network = selected_network_from_tables( - network_tables_from_toml(&config).expect("permissions profile should deserialize"), - ) - .expect("permissions profile should select a network table") - .expect("network table should be present"); - - let mut constraints = NetworkProxyConstraints::default(); - apply_network_constraints(network, &mut constraints); - - assert_eq!(constraints.dangerously_allow_all_unix_sockets, Some(true)); - } -} +#[path = "network_proxy_loader_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/network_proxy_loader_tests.rs b/codex-rs/core/src/network_proxy_loader_tests.rs new file mode 100644 index 0000000000..018061463b --- /dev/null +++ b/codex-rs/core/src/network_proxy_loader_tests.rs @@ -0,0 +1,104 @@ +use super::*; + +use codex_execpolicy::Decision; +use codex_execpolicy::NetworkRuleProtocol; +use codex_execpolicy::Policy; +use pretty_assertions::assert_eq; + +#[test] +fn higher_precedence_profile_network_beats_lower_profile_network() { + let lower_network: toml::Value = toml::from_str( + r#" +default_permissions = "workspace" + +[permissions.workspace.network] +allowed_domains = ["lower.example.com"] +"#, + ) + .expect("lower layer should parse"); + let higher_network: toml::Value = toml::from_str( + r#" +default_permissions = "workspace" + +[permissions.workspace.network] +allowed_domains = ["higher.example.com"] +"#, + ) + .expect("higher layer should parse"); + + let mut config = NetworkProxyConfig::default(); + apply_network_tables( + &mut config, + network_tables_from_toml(&lower_network).expect("lower layer should deserialize"), + ) + .expect("lower layer should apply"); + apply_network_tables( + &mut config, + network_tables_from_toml(&higher_network).expect("higher layer should deserialize"), + ) + .expect("higher layer should apply"); + + assert_eq!(config.network.allowed_domains, vec!["higher.example.com"]); +} + +#[test] +fn execpolicy_network_rules_overlay_network_lists() { + let mut config = NetworkProxyConfig::default(); + config.network.allowed_domains = vec!["config.example.com".to_string()]; + config.network.denied_domains = vec!["blocked.example.com".to_string()]; + + let mut exec_policy = Policy::empty(); + exec_policy + .add_network_rule( + "blocked.example.com", + NetworkRuleProtocol::Https, + Decision::Allow, + None, + ) + .expect("allow rule should be valid"); + exec_policy + .add_network_rule( + "api.example.com", + NetworkRuleProtocol::Http, + Decision::Forbidden, + None, + ) + .expect("deny rule should be valid"); + + apply_exec_policy_network_rules(&mut config, &exec_policy); + + assert_eq!( + config.network.allowed_domains, + vec![ + "config.example.com".to_string(), + "blocked.example.com".to_string() + ] + ); + assert_eq!( + config.network.denied_domains, + vec!["api.example.com".to_string()] + ); +} + +#[test] +fn apply_network_constraints_includes_allow_all_unix_sockets_flag() { + let config: toml::Value = toml::from_str( + r#" +default_permissions = "workspace" + +[permissions.workspace.network] +dangerously_allow_all_unix_sockets = true +"#, + ) + .expect("permissions profile should parse"); + let network = selected_network_from_tables( + network_tables_from_toml(&config).expect("permissions profile should deserialize"), + ) + .expect("permissions profile should select a network table") + .expect("network table should be present"); + + let mut constraints = NetworkProxyConstraints::default(); + apply_network_constraints(network, &mut constraints); + + assert_eq!(constraints.dangerously_allow_all_unix_sockets, Some(true)); +} diff --git a/codex-rs/core/src/original_image_detail.rs b/codex-rs/core/src/original_image_detail.rs index 06da60dff8..d5bb6d24cd 100644 --- a/codex-rs/core/src/original_image_detail.rs +++ b/codex-rs/core/src/original_image_detail.rs @@ -24,68 +24,5 @@ pub(crate) fn normalize_output_image_detail( } #[cfg(test)] -mod tests { - use super::*; - - use crate::config::test_config; - use crate::features::Features; - use crate::models_manager::manager::ModelsManager; - use pretty_assertions::assert_eq; - - #[test] - fn image_detail_original_feature_enables_explicit_original_without_force() { - let config = test_config(); - let mut model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - model_info.supports_image_detail_original = true; - let mut features = Features::with_defaults(); - features.enable(Feature::ImageDetailOriginal); - - assert!(can_request_original_image_detail(&features, &model_info)); - assert_eq!( - normalize_output_image_detail(&features, &model_info, Some(ImageDetail::Original)), - Some(ImageDetail::Original) - ); - assert_eq!( - normalize_output_image_detail(&features, &model_info, None), - None - ); - } - - #[test] - fn explicit_original_is_dropped_without_feature_or_model_support() { - let config = test_config(); - let mut model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - model_info.supports_image_detail_original = true; - let features = Features::with_defaults(); - - assert_eq!( - normalize_output_image_detail(&features, &model_info, Some(ImageDetail::Original)), - None - ); - - let mut features = Features::with_defaults(); - features.enable(Feature::ImageDetailOriginal); - model_info.supports_image_detail_original = false; - assert_eq!( - normalize_output_image_detail(&features, &model_info, Some(ImageDetail::Original)), - None - ); - } - - #[test] - fn unsupported_non_original_detail_is_dropped() { - let config = test_config(); - let mut model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - model_info.supports_image_detail_original = true; - let mut features = Features::with_defaults(); - features.enable(Feature::ImageDetailOriginal); - - assert_eq!( - normalize_output_image_detail(&features, &model_info, Some(ImageDetail::Low)), - None - ); - } -} +#[path = "original_image_detail_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/original_image_detail_tests.rs b/codex-rs/core/src/original_image_detail_tests.rs new file mode 100644 index 0000000000..b771e87bb4 --- /dev/null +++ b/codex-rs/core/src/original_image_detail_tests.rs @@ -0,0 +1,63 @@ +use super::*; + +use crate::config::test_config; +use crate::features::Features; +use crate::models_manager::manager::ModelsManager; +use pretty_assertions::assert_eq; + +#[test] +fn image_detail_original_feature_enables_explicit_original_without_force() { + let config = test_config(); + let mut model_info = + ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + model_info.supports_image_detail_original = true; + let mut features = Features::with_defaults(); + features.enable(Feature::ImageDetailOriginal); + + assert!(can_request_original_image_detail(&features, &model_info)); + assert_eq!( + normalize_output_image_detail(&features, &model_info, Some(ImageDetail::Original)), + Some(ImageDetail::Original) + ); + assert_eq!( + normalize_output_image_detail(&features, &model_info, None), + None + ); +} + +#[test] +fn explicit_original_is_dropped_without_feature_or_model_support() { + let config = test_config(); + let mut model_info = + ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + model_info.supports_image_detail_original = true; + let features = Features::with_defaults(); + + assert_eq!( + normalize_output_image_detail(&features, &model_info, Some(ImageDetail::Original)), + None + ); + + let mut features = Features::with_defaults(); + features.enable(Feature::ImageDetailOriginal); + model_info.supports_image_detail_original = false; + assert_eq!( + normalize_output_image_detail(&features, &model_info, Some(ImageDetail::Original)), + None + ); +} + +#[test] +fn unsupported_non_original_detail_is_dropped() { + let config = test_config(); + let mut model_info = + ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + model_info.supports_image_detail_original = true; + let mut features = Features::with_defaults(); + features.enable(Feature::ImageDetailOriginal); + + assert_eq!( + normalize_output_image_detail(&features, &model_info, Some(ImageDetail::Low)), + None + ); +} diff --git a/codex-rs/core/src/path_utils.rs b/codex-rs/core/src/path_utils.rs index af7d620777..eca2ce1663 100644 --- a/codex-rs/core/src/path_utils.rs +++ b/codex-rs/core/src/path_utils.rs @@ -199,83 +199,5 @@ fn lower_ascii_path(path: PathBuf) -> PathBuf { } #[cfg(test)] -mod tests { - #[cfg(unix)] - mod symlinks { - use super::super::resolve_symlink_write_paths; - use pretty_assertions::assert_eq; - use std::os::unix::fs::symlink; - - #[test] - fn symlink_cycles_fall_back_to_root_write_path() -> std::io::Result<()> { - let dir = tempfile::tempdir()?; - let a = dir.path().join("a"); - let b = dir.path().join("b"); - - symlink(&b, &a)?; - symlink(&a, &b)?; - - let resolved = resolve_symlink_write_paths(&a)?; - - assert_eq!(resolved.read_path, None); - assert_eq!(resolved.write_path, a); - Ok(()) - } - } - - #[cfg(target_os = "linux")] - mod wsl { - use super::super::normalize_for_wsl_with_flag; - use pretty_assertions::assert_eq; - use std::path::PathBuf; - - #[test] - fn wsl_mnt_drive_paths_lowercase() { - let normalized = normalize_for_wsl_with_flag(PathBuf::from("/mnt/C/Users/Dev"), true); - - assert_eq!(normalized, PathBuf::from("/mnt/c/users/dev")); - } - - #[test] - fn wsl_non_drive_paths_unchanged() { - let path = PathBuf::from("/mnt/cc/Users/Dev"); - let normalized = normalize_for_wsl_with_flag(path.clone(), true); - - assert_eq!(normalized, path); - } - - #[test] - fn wsl_non_mnt_paths_unchanged() { - let path = PathBuf::from("/home/Dev"); - let normalized = normalize_for_wsl_with_flag(path.clone(), true); - - assert_eq!(normalized, path); - } - } - - mod native_workdir { - use super::super::normalize_for_native_workdir_with_flag; - use pretty_assertions::assert_eq; - use std::path::PathBuf; - - #[cfg(target_os = "windows")] - #[test] - fn windows_verbatim_paths_are_simplified() { - let path = PathBuf::from(r"\\?\D:\c\x\worktrees\2508\swift-base"); - let normalized = normalize_for_native_workdir_with_flag(path, true); - - assert_eq!( - normalized, - PathBuf::from(r"D:\c\x\worktrees\2508\swift-base") - ); - } - - #[test] - fn non_windows_paths_are_unchanged() { - let path = PathBuf::from(r"\\?\D:\c\x\worktrees\2508\swift-base"); - let normalized = normalize_for_native_workdir_with_flag(path.clone(), false); - - assert_eq!(normalized, path); - } - } -} +#[path = "path_utils_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/path_utils_tests.rs b/codex-rs/core/src/path_utils_tests.rs new file mode 100644 index 0000000000..f028133f12 --- /dev/null +++ b/codex-rs/core/src/path_utils_tests.rs @@ -0,0 +1,78 @@ +#[cfg(unix)] +mod symlinks { + use super::super::resolve_symlink_write_paths; + use pretty_assertions::assert_eq; + use std::os::unix::fs::symlink; + + #[test] + fn symlink_cycles_fall_back_to_root_write_path() -> std::io::Result<()> { + let dir = tempfile::tempdir()?; + let a = dir.path().join("a"); + let b = dir.path().join("b"); + + symlink(&b, &a)?; + symlink(&a, &b)?; + + let resolved = resolve_symlink_write_paths(&a)?; + + assert_eq!(resolved.read_path, None); + assert_eq!(resolved.write_path, a); + Ok(()) + } +} + +#[cfg(target_os = "linux")] +mod wsl { + use super::super::normalize_for_wsl_with_flag; + use pretty_assertions::assert_eq; + use std::path::PathBuf; + + #[test] + fn wsl_mnt_drive_paths_lowercase() { + let normalized = normalize_for_wsl_with_flag(PathBuf::from("/mnt/C/Users/Dev"), true); + + assert_eq!(normalized, PathBuf::from("/mnt/c/users/dev")); + } + + #[test] + fn wsl_non_drive_paths_unchanged() { + let path = PathBuf::from("/mnt/cc/Users/Dev"); + let normalized = normalize_for_wsl_with_flag(path.clone(), true); + + assert_eq!(normalized, path); + } + + #[test] + fn wsl_non_mnt_paths_unchanged() { + let path = PathBuf::from("/home/Dev"); + let normalized = normalize_for_wsl_with_flag(path.clone(), true); + + assert_eq!(normalized, path); + } +} + +mod native_workdir { + use super::super::normalize_for_native_workdir_with_flag; + use pretty_assertions::assert_eq; + use std::path::PathBuf; + + #[cfg(target_os = "windows")] + #[test] + fn windows_verbatim_paths_are_simplified() { + let path = PathBuf::from(r"\\?\D:\c\x\worktrees\2508\swift-base"); + let normalized = normalize_for_native_workdir_with_flag(path, true); + + assert_eq!( + normalized, + PathBuf::from(r"D:\c\x\worktrees\2508\swift-base") + ); + } + + #[test] + fn non_windows_paths_are_unchanged() { + let path = PathBuf::from(r"\\?\D:\c\x\worktrees\2508\swift-base"); + let normalized = normalize_for_native_workdir_with_flag(path.clone(), false); + + assert_eq!(normalized, path); + } +} diff --git a/codex-rs/core/src/personality_migration.rs b/codex-rs/core/src/personality_migration.rs index 934ff89379..9e541c30a2 100644 --- a/codex-rs/core/src/personality_migration.rs +++ b/codex-rs/core/src/personality_migration.rs @@ -131,138 +131,5 @@ async fn create_marker(marker_path: &Path) -> io::Result<()> { } #[cfg(test)] -mod tests { - use super::*; - use codex_protocol::ThreadId; - use codex_protocol::protocol::EventMsg; - use codex_protocol::protocol::RolloutItem; - use codex_protocol::protocol::RolloutLine; - use codex_protocol::protocol::SessionMeta; - use codex_protocol::protocol::SessionMetaLine; - use codex_protocol::protocol::SessionSource; - use codex_protocol::protocol::UserMessageEvent; - use pretty_assertions::assert_eq; - use tempfile::TempDir; - use tokio::io::AsyncWriteExt; - - const TEST_TIMESTAMP: &str = "2025-01-01T00-00-00"; - - async fn read_config_toml(codex_home: &Path) -> io::Result { - let contents = tokio::fs::read_to_string(codex_home.join("config.toml")).await?; - toml::from_str(&contents).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) - } - - async fn write_session_with_user_event(codex_home: &Path) -> io::Result<()> { - let thread_id = ThreadId::new(); - let dir = codex_home - .join(SESSIONS_SUBDIR) - .join("2025") - .join("01") - .join("01"); - tokio::fs::create_dir_all(&dir).await?; - let file_path = dir.join(format!("rollout-{TEST_TIMESTAMP}-{thread_id}.jsonl")); - let mut file = tokio::fs::File::create(&file_path).await?; - - let session_meta = SessionMetaLine { - meta: SessionMeta { - id: thread_id, - forked_from_id: None, - timestamp: TEST_TIMESTAMP.to_string(), - cwd: std::path::PathBuf::from("."), - originator: "test_originator".to_string(), - cli_version: "test_version".to_string(), - source: SessionSource::Cli, - agent_nickname: None, - agent_role: None, - model_provider: None, - base_instructions: None, - dynamic_tools: None, - memory_mode: None, - }, - git: None, - }; - let meta_line = RolloutLine { - timestamp: TEST_TIMESTAMP.to_string(), - item: RolloutItem::SessionMeta(session_meta), - }; - let user_event = RolloutLine { - timestamp: TEST_TIMESTAMP.to_string(), - item: RolloutItem::EventMsg(EventMsg::UserMessage(UserMessageEvent { - message: "hello".to_string(), - images: None, - local_images: Vec::new(), - text_elements: Vec::new(), - })), - }; - - file.write_all(format!("{}\n", serde_json::to_string(&meta_line)?).as_bytes()) - .await?; - file.write_all(format!("{}\n", serde_json::to_string(&user_event)?).as_bytes()) - .await?; - Ok(()) - } - - #[tokio::test] - async fn applies_when_sessions_exist_and_no_personality() -> io::Result<()> { - let temp = TempDir::new()?; - write_session_with_user_event(temp.path()).await?; - - let config_toml = ConfigToml::default(); - let status = maybe_migrate_personality(temp.path(), &config_toml).await?; - - assert_eq!(status, PersonalityMigrationStatus::Applied); - assert!(temp.path().join(PERSONALITY_MIGRATION_FILENAME).exists()); - - let persisted = read_config_toml(temp.path()).await?; - assert_eq!(persisted.personality, Some(Personality::Pragmatic)); - Ok(()) - } - - #[tokio::test] - async fn skips_when_marker_exists() -> io::Result<()> { - let temp = TempDir::new()?; - create_marker(&temp.path().join(PERSONALITY_MIGRATION_FILENAME)).await?; - - let config_toml = ConfigToml::default(); - let status = maybe_migrate_personality(temp.path(), &config_toml).await?; - - assert_eq!(status, PersonalityMigrationStatus::SkippedMarker); - assert!(!temp.path().join("config.toml").exists()); - Ok(()) - } - - #[tokio::test] - async fn skips_when_personality_explicit() -> io::Result<()> { - let temp = TempDir::new()?; - ConfigEditsBuilder::new(temp.path()) - .set_personality(Some(Personality::Friendly)) - .apply() - .await - .map_err(|err| io::Error::other(format!("failed to write config: {err}")))?; - - let config_toml = read_config_toml(temp.path()).await?; - let status = maybe_migrate_personality(temp.path(), &config_toml).await?; - - assert_eq!( - status, - PersonalityMigrationStatus::SkippedExplicitPersonality - ); - assert!(temp.path().join(PERSONALITY_MIGRATION_FILENAME).exists()); - - let persisted = read_config_toml(temp.path()).await?; - assert_eq!(persisted.personality, Some(Personality::Friendly)); - Ok(()) - } - - #[tokio::test] - async fn skips_when_no_sessions() -> io::Result<()> { - let temp = TempDir::new()?; - let config_toml = ConfigToml::default(); - let status = maybe_migrate_personality(temp.path(), &config_toml).await?; - - assert_eq!(status, PersonalityMigrationStatus::SkippedNoSessions); - assert!(temp.path().join(PERSONALITY_MIGRATION_FILENAME).exists()); - assert!(!temp.path().join("config.toml").exists()); - Ok(()) - } -} +#[path = "personality_migration_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/personality_migration_tests.rs b/codex-rs/core/src/personality_migration_tests.rs new file mode 100644 index 0000000000..fef1297a97 --- /dev/null +++ b/codex-rs/core/src/personality_migration_tests.rs @@ -0,0 +1,133 @@ +use super::*; +use codex_protocol::ThreadId; +use codex_protocol::protocol::EventMsg; +use codex_protocol::protocol::RolloutItem; +use codex_protocol::protocol::RolloutLine; +use codex_protocol::protocol::SessionMeta; +use codex_protocol::protocol::SessionMetaLine; +use codex_protocol::protocol::SessionSource; +use codex_protocol::protocol::UserMessageEvent; +use pretty_assertions::assert_eq; +use tempfile::TempDir; +use tokio::io::AsyncWriteExt; + +const TEST_TIMESTAMP: &str = "2025-01-01T00-00-00"; + +async fn read_config_toml(codex_home: &Path) -> io::Result { + let contents = tokio::fs::read_to_string(codex_home.join("config.toml")).await?; + toml::from_str(&contents).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) +} + +async fn write_session_with_user_event(codex_home: &Path) -> io::Result<()> { + let thread_id = ThreadId::new(); + let dir = codex_home + .join(SESSIONS_SUBDIR) + .join("2025") + .join("01") + .join("01"); + tokio::fs::create_dir_all(&dir).await?; + let file_path = dir.join(format!("rollout-{TEST_TIMESTAMP}-{thread_id}.jsonl")); + let mut file = tokio::fs::File::create(&file_path).await?; + + let session_meta = SessionMetaLine { + meta: SessionMeta { + id: thread_id, + forked_from_id: None, + timestamp: TEST_TIMESTAMP.to_string(), + cwd: std::path::PathBuf::from("."), + originator: "test_originator".to_string(), + cli_version: "test_version".to_string(), + source: SessionSource::Cli, + agent_nickname: None, + agent_role: None, + model_provider: None, + base_instructions: None, + dynamic_tools: None, + memory_mode: None, + }, + git: None, + }; + let meta_line = RolloutLine { + timestamp: TEST_TIMESTAMP.to_string(), + item: RolloutItem::SessionMeta(session_meta), + }; + let user_event = RolloutLine { + timestamp: TEST_TIMESTAMP.to_string(), + item: RolloutItem::EventMsg(EventMsg::UserMessage(UserMessageEvent { + message: "hello".to_string(), + images: None, + local_images: Vec::new(), + text_elements: Vec::new(), + })), + }; + + file.write_all(format!("{}\n", serde_json::to_string(&meta_line)?).as_bytes()) + .await?; + file.write_all(format!("{}\n", serde_json::to_string(&user_event)?).as_bytes()) + .await?; + Ok(()) +} + +#[tokio::test] +async fn applies_when_sessions_exist_and_no_personality() -> io::Result<()> { + let temp = TempDir::new()?; + write_session_with_user_event(temp.path()).await?; + + let config_toml = ConfigToml::default(); + let status = maybe_migrate_personality(temp.path(), &config_toml).await?; + + assert_eq!(status, PersonalityMigrationStatus::Applied); + assert!(temp.path().join(PERSONALITY_MIGRATION_FILENAME).exists()); + + let persisted = read_config_toml(temp.path()).await?; + assert_eq!(persisted.personality, Some(Personality::Pragmatic)); + Ok(()) +} + +#[tokio::test] +async fn skips_when_marker_exists() -> io::Result<()> { + let temp = TempDir::new()?; + create_marker(&temp.path().join(PERSONALITY_MIGRATION_FILENAME)).await?; + + let config_toml = ConfigToml::default(); + let status = maybe_migrate_personality(temp.path(), &config_toml).await?; + + assert_eq!(status, PersonalityMigrationStatus::SkippedMarker); + assert!(!temp.path().join("config.toml").exists()); + Ok(()) +} + +#[tokio::test] +async fn skips_when_personality_explicit() -> io::Result<()> { + let temp = TempDir::new()?; + ConfigEditsBuilder::new(temp.path()) + .set_personality(Some(Personality::Friendly)) + .apply() + .await + .map_err(|err| io::Error::other(format!("failed to write config: {err}")))?; + + let config_toml = read_config_toml(temp.path()).await?; + let status = maybe_migrate_personality(temp.path(), &config_toml).await?; + + assert_eq!( + status, + PersonalityMigrationStatus::SkippedExplicitPersonality + ); + assert!(temp.path().join(PERSONALITY_MIGRATION_FILENAME).exists()); + + let persisted = read_config_toml(temp.path()).await?; + assert_eq!(persisted.personality, Some(Personality::Friendly)); + Ok(()) +} + +#[tokio::test] +async fn skips_when_no_sessions() -> io::Result<()> { + let temp = TempDir::new()?; + let config_toml = ConfigToml::default(); + let status = maybe_migrate_personality(temp.path(), &config_toml).await?; + + assert_eq!(status, PersonalityMigrationStatus::SkippedNoSessions); + assert!(temp.path().join(PERSONALITY_MIGRATION_FILENAME).exists()); + assert!(!temp.path().join("config.toml").exists()); + Ok(()) +} diff --git a/codex-rs/core/src/plugins/curated_repo.rs b/codex-rs/core/src/plugins/curated_repo.rs index 41f8347992..3307f28ffc 100644 --- a/codex-rs/core/src/plugins/curated_repo.rs +++ b/codex-rs/core/src/plugins/curated_repo.rs @@ -352,168 +352,5 @@ fn apply_zip_permissions( } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - use std::io::Write; - use tempfile::tempdir; - use wiremock::Mock; - use wiremock::MockServer; - use wiremock::ResponseTemplate; - use wiremock::matchers::method; - use wiremock::matchers::path; - use zip::ZipWriter; - use zip::write::SimpleFileOptions; - - #[test] - fn curated_plugins_repo_path_uses_codex_home_tmp_dir() { - let tmp = tempdir().expect("tempdir"); - assert_eq!( - curated_plugins_repo_path(tmp.path()), - tmp.path().join(".tmp/plugins") - ); - } - - #[test] - fn read_curated_plugins_sha_reads_trimmed_sha_file() { - let tmp = tempdir().expect("tempdir"); - fs::create_dir_all(tmp.path().join(".tmp")).expect("create tmp"); - fs::write(tmp.path().join(".tmp/plugins.sha"), "abc123\n").expect("write sha"); - - assert_eq!( - read_curated_plugins_sha(tmp.path()).as_deref(), - Some("abc123") - ); - } - - #[tokio::test] - async fn sync_openai_plugins_repo_downloads_zipball_and_records_sha() { - let tmp = tempdir().expect("tempdir"); - let server = MockServer::start().await; - let sha = "0123456789abcdef0123456789abcdef01234567"; - - Mock::given(method("GET")) - .and(path("/repos/openai/plugins")) - .respond_with( - ResponseTemplate::new(200).set_body_string(r#"{"default_branch":"main"}"#), - ) - .mount(&server) - .await; - Mock::given(method("GET")) - .and(path("/repos/openai/plugins/git/ref/heads/main")) - .respond_with( - ResponseTemplate::new(200) - .set_body_string(format!(r#"{{"object":{{"sha":"{sha}"}}}}"#)), - ) - .mount(&server) - .await; - Mock::given(method("GET")) - .and(path(format!("/repos/openai/plugins/zipball/{sha}"))) - .respond_with( - ResponseTemplate::new(200) - .insert_header("content-type", "application/zip") - .set_body_bytes(curated_repo_zipball_bytes(sha)), - ) - .mount(&server) - .await; - - let server_uri = server.uri(); - let tmp_path = tmp.path().to_path_buf(); - tokio::task::spawn_blocking(move || { - sync_openai_plugins_repo_with_api_base_url(tmp_path.as_path(), &server_uri) - }) - .await - .expect("sync task should join") - .expect("sync should succeed"); - - let repo_path = curated_plugins_repo_path(tmp.path()); - assert!(repo_path.join(".agents/plugins/marketplace.json").is_file()); - assert!( - repo_path - .join("plugins/gmail/.codex-plugin/plugin.json") - .is_file() - ); - assert_eq!(read_curated_plugins_sha(tmp.path()).as_deref(), Some(sha)); - } - - #[tokio::test] - async fn sync_openai_plugins_repo_skips_archive_download_when_sha_matches() { - let tmp = tempdir().expect("tempdir"); - let repo_path = curated_plugins_repo_path(tmp.path()); - fs::create_dir_all(repo_path.join(".agents/plugins")).expect("create repo"); - fs::write( - repo_path.join(".agents/plugins/marketplace.json"), - r#"{"name":"openai-curated","plugins":[]}"#, - ) - .expect("write marketplace"); - fs::create_dir_all(tmp.path().join(".tmp")).expect("create tmp"); - let sha = "fedcba9876543210fedcba9876543210fedcba98"; - fs::write(tmp.path().join(".tmp/plugins.sha"), format!("{sha}\n")).expect("write sha"); - - let server = MockServer::start().await; - Mock::given(method("GET")) - .and(path("/repos/openai/plugins")) - .respond_with( - ResponseTemplate::new(200).set_body_string(r#"{"default_branch":"main"}"#), - ) - .mount(&server) - .await; - Mock::given(method("GET")) - .and(path("/repos/openai/plugins/git/ref/heads/main")) - .respond_with( - ResponseTemplate::new(200) - .set_body_string(format!(r#"{{"object":{{"sha":"{sha}"}}}}"#)), - ) - .mount(&server) - .await; - - let server_uri = server.uri(); - let tmp_path = tmp.path().to_path_buf(); - tokio::task::spawn_blocking(move || { - sync_openai_plugins_repo_with_api_base_url(tmp_path.as_path(), &server_uri) - }) - .await - .expect("sync task should join") - .expect("sync should succeed"); - - assert_eq!(read_curated_plugins_sha(tmp.path()).as_deref(), Some(sha)); - assert!(repo_path.join(".agents/plugins/marketplace.json").is_file()); - } - - fn curated_repo_zipball_bytes(sha: &str) -> Vec { - let cursor = Cursor::new(Vec::new()); - let mut writer = ZipWriter::new(cursor); - let options = SimpleFileOptions::default(); - let root = format!("openai-plugins-{sha}"); - writer - .start_file(format!("{root}/.agents/plugins/marketplace.json"), options) - .expect("start marketplace entry"); - writer - .write_all( - br#"{ - "name": "openai-curated", - "plugins": [ - { - "name": "gmail", - "source": { - "source": "local", - "path": "./plugins/gmail" - } - } - ] -}"#, - ) - .expect("write marketplace"); - writer - .start_file( - format!("{root}/plugins/gmail/.codex-plugin/plugin.json"), - options, - ) - .expect("start plugin manifest entry"); - writer - .write_all(br#"{"name":"gmail"}"#) - .expect("write plugin manifest"); - - writer.finish().expect("finish zip writer").into_inner() - } -} +#[path = "curated_repo_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/plugins/curated_repo_tests.rs b/codex-rs/core/src/plugins/curated_repo_tests.rs new file mode 100644 index 0000000000..5a14124d06 --- /dev/null +++ b/codex-rs/core/src/plugins/curated_repo_tests.rs @@ -0,0 +1,159 @@ +use super::*; +use pretty_assertions::assert_eq; +use std::io::Write; +use tempfile::tempdir; +use wiremock::Mock; +use wiremock::MockServer; +use wiremock::ResponseTemplate; +use wiremock::matchers::method; +use wiremock::matchers::path; +use zip::ZipWriter; +use zip::write::SimpleFileOptions; + +#[test] +fn curated_plugins_repo_path_uses_codex_home_tmp_dir() { + let tmp = tempdir().expect("tempdir"); + assert_eq!( + curated_plugins_repo_path(tmp.path()), + tmp.path().join(".tmp/plugins") + ); +} + +#[test] +fn read_curated_plugins_sha_reads_trimmed_sha_file() { + let tmp = tempdir().expect("tempdir"); + fs::create_dir_all(tmp.path().join(".tmp")).expect("create tmp"); + fs::write(tmp.path().join(".tmp/plugins.sha"), "abc123\n").expect("write sha"); + + assert_eq!( + read_curated_plugins_sha(tmp.path()).as_deref(), + Some("abc123") + ); +} + +#[tokio::test] +async fn sync_openai_plugins_repo_downloads_zipball_and_records_sha() { + let tmp = tempdir().expect("tempdir"); + let server = MockServer::start().await; + let sha = "0123456789abcdef0123456789abcdef01234567"; + + Mock::given(method("GET")) + .and(path("/repos/openai/plugins")) + .respond_with(ResponseTemplate::new(200).set_body_string(r#"{"default_branch":"main"}"#)) + .mount(&server) + .await; + Mock::given(method("GET")) + .and(path("/repos/openai/plugins/git/ref/heads/main")) + .respond_with( + ResponseTemplate::new(200) + .set_body_string(format!(r#"{{"object":{{"sha":"{sha}"}}}}"#)), + ) + .mount(&server) + .await; + Mock::given(method("GET")) + .and(path(format!("/repos/openai/plugins/zipball/{sha}"))) + .respond_with( + ResponseTemplate::new(200) + .insert_header("content-type", "application/zip") + .set_body_bytes(curated_repo_zipball_bytes(sha)), + ) + .mount(&server) + .await; + + let server_uri = server.uri(); + let tmp_path = tmp.path().to_path_buf(); + tokio::task::spawn_blocking(move || { + sync_openai_plugins_repo_with_api_base_url(tmp_path.as_path(), &server_uri) + }) + .await + .expect("sync task should join") + .expect("sync should succeed"); + + let repo_path = curated_plugins_repo_path(tmp.path()); + assert!(repo_path.join(".agents/plugins/marketplace.json").is_file()); + assert!( + repo_path + .join("plugins/gmail/.codex-plugin/plugin.json") + .is_file() + ); + assert_eq!(read_curated_plugins_sha(tmp.path()).as_deref(), Some(sha)); +} + +#[tokio::test] +async fn sync_openai_plugins_repo_skips_archive_download_when_sha_matches() { + let tmp = tempdir().expect("tempdir"); + let repo_path = curated_plugins_repo_path(tmp.path()); + fs::create_dir_all(repo_path.join(".agents/plugins")).expect("create repo"); + fs::write( + repo_path.join(".agents/plugins/marketplace.json"), + r#"{"name":"openai-curated","plugins":[]}"#, + ) + .expect("write marketplace"); + fs::create_dir_all(tmp.path().join(".tmp")).expect("create tmp"); + let sha = "fedcba9876543210fedcba9876543210fedcba98"; + fs::write(tmp.path().join(".tmp/plugins.sha"), format!("{sha}\n")).expect("write sha"); + + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/repos/openai/plugins")) + .respond_with(ResponseTemplate::new(200).set_body_string(r#"{"default_branch":"main"}"#)) + .mount(&server) + .await; + Mock::given(method("GET")) + .and(path("/repos/openai/plugins/git/ref/heads/main")) + .respond_with( + ResponseTemplate::new(200) + .set_body_string(format!(r#"{{"object":{{"sha":"{sha}"}}}}"#)), + ) + .mount(&server) + .await; + + let server_uri = server.uri(); + let tmp_path = tmp.path().to_path_buf(); + tokio::task::spawn_blocking(move || { + sync_openai_plugins_repo_with_api_base_url(tmp_path.as_path(), &server_uri) + }) + .await + .expect("sync task should join") + .expect("sync should succeed"); + + assert_eq!(read_curated_plugins_sha(tmp.path()).as_deref(), Some(sha)); + assert!(repo_path.join(".agents/plugins/marketplace.json").is_file()); +} + +fn curated_repo_zipball_bytes(sha: &str) -> Vec { + let cursor = Cursor::new(Vec::new()); + let mut writer = ZipWriter::new(cursor); + let options = SimpleFileOptions::default(); + let root = format!("openai-plugins-{sha}"); + writer + .start_file(format!("{root}/.agents/plugins/marketplace.json"), options) + .expect("start marketplace entry"); + writer + .write_all( + br#"{ + "name": "openai-curated", + "plugins": [ + { + "name": "gmail", + "source": { + "source": "local", + "path": "./plugins/gmail" + } + } + ] +}"#, + ) + .expect("write marketplace"); + writer + .start_file( + format!("{root}/plugins/gmail/.codex-plugin/plugin.json"), + options, + ) + .expect("start plugin manifest entry"); + writer + .write_all(br#"{"name":"gmail"}"#) + .expect("write plugin manifest"); + + writer.finish().expect("finish zip writer").into_inner() +} diff --git a/codex-rs/core/src/plugins/manager.rs b/codex-rs/core/src/plugins/manager.rs index a9d7cbdc79..032437c199 100644 --- a/codex-rs/core/src/plugins/manager.rs +++ b/codex-rs/core/src/plugins/manager.rs @@ -1342,1642 +1342,5 @@ struct PluginMcpDiscovery { } #[cfg(test)] -mod tests { - use super::*; - use crate::auth::CodexAuth; - use crate::config::CONFIG_TOML_FILE; - use crate::config::ConfigBuilder; - use crate::config::types::McpServerTransportConfig; - use crate::config_loader::ConfigLayerEntry; - use crate::config_loader::ConfigLayerStack; - use crate::config_loader::ConfigRequirements; - use crate::config_loader::ConfigRequirementsToml; - use codex_app_server_protocol::ConfigLayerSource; - use pretty_assertions::assert_eq; - use std::fs; - use tempfile::TempDir; - use toml::Value; - use wiremock::Mock; - use wiremock::MockServer; - use wiremock::ResponseTemplate; - use wiremock::matchers::header; - use wiremock::matchers::method; - use wiremock::matchers::path; - - const TEST_CURATED_PLUGIN_SHA: &str = "0123456789abcdef0123456789abcdef01234567"; - - fn write_file(path: &Path, contents: &str) { - fs::create_dir_all(path.parent().expect("file should have a parent")).unwrap(); - fs::write(path, contents).unwrap(); - } - - fn write_plugin(root: &Path, dir_name: &str, manifest_name: &str) { - let plugin_root = root.join(dir_name); - fs::create_dir_all(plugin_root.join(".codex-plugin")).unwrap(); - fs::create_dir_all(plugin_root.join("skills")).unwrap(); - fs::write( - plugin_root.join(".codex-plugin/plugin.json"), - format!(r#"{{"name":"{manifest_name}"}}"#), - ) - .unwrap(); - fs::write(plugin_root.join("skills/SKILL.md"), "skill").unwrap(); - fs::write(plugin_root.join(".mcp.json"), r#"{"mcpServers":{}}"#).unwrap(); - } - - fn write_openai_curated_marketplace(root: &Path, plugin_names: &[&str]) { - fs::create_dir_all(root.join(".agents/plugins")).unwrap(); - let plugins = plugin_names - .iter() - .map(|plugin_name| { - format!( - r#"{{ - "name": "{plugin_name}", - "source": {{ - "source": "local", - "path": "./plugins/{plugin_name}" - }} - }}"# - ) - }) - .collect::>() - .join(",\n"); - fs::write( - root.join(".agents/plugins/marketplace.json"), - format!( - r#"{{ - "name": "{OPENAI_CURATED_MARKETPLACE_NAME}", - "plugins": [ -{plugins} - ] -}}"# - ), - ) - .unwrap(); - for plugin_name in plugin_names { - write_plugin(root, &format!("plugins/{plugin_name}"), plugin_name); - } - } - - fn write_curated_plugin_sha(codex_home: &Path, sha: &str) { - write_file(&codex_home.join(".tmp/plugins.sha"), &format!("{sha}\n")); - } - - fn plugin_config_toml(enabled: bool, plugins_feature_enabled: bool) -> String { - let mut root = toml::map::Map::new(); - - let mut features = toml::map::Map::new(); - features.insert( - "plugins".to_string(), - Value::Boolean(plugins_feature_enabled), - ); - root.insert("features".to_string(), Value::Table(features)); - - let mut plugin = toml::map::Map::new(); - plugin.insert("enabled".to_string(), Value::Boolean(enabled)); - - let mut plugins = toml::map::Map::new(); - plugins.insert("sample@test".to_string(), Value::Table(plugin)); - root.insert("plugins".to_string(), Value::Table(plugins)); - - toml::to_string(&Value::Table(root)).expect("plugin test config should serialize") - } - - fn load_plugins_from_config(config_toml: &str, codex_home: &Path) -> PluginLoadOutcome { - write_file(&codex_home.join(CONFIG_TOML_FILE), config_toml); - let stack = ConfigLayerStack::new( - vec![ConfigLayerEntry::new( - ConfigLayerSource::User { - file: AbsolutePathBuf::try_from(codex_home.join(CONFIG_TOML_FILE)).unwrap(), - }, - toml::from_str(config_toml).expect("plugin test config should parse"), - )], - ConfigRequirements::default(), - ConfigRequirementsToml::default(), - ) - .expect("config layer stack should build"); - PluginsManager::new(codex_home.to_path_buf()) - .plugins_for_layer_stack(codex_home, &stack, false) - } - - async fn load_config(codex_home: &Path, cwd: &Path) -> crate::config::Config { - ConfigBuilder::default() - .codex_home(codex_home.to_path_buf()) - .fallback_cwd(Some(cwd.to_path_buf())) - .build() - .await - .expect("config should load") - } - - #[test] - fn load_plugins_loads_default_skills_and_mcp_servers() { - let codex_home = TempDir::new().unwrap(); - let plugin_root = codex_home - .path() - .join("plugins/cache") - .join("test/sample/local"); - - write_file( - &plugin_root.join(".codex-plugin/plugin.json"), - r#"{ - "name": "sample", - "description": "Plugin that includes the sample MCP server and Skills" -}"#, - ); - write_file( - &plugin_root.join("skills/sample-search/SKILL.md"), - "---\nname: sample-search\ndescription: search sample data\n---\n", - ); - write_file( - &plugin_root.join(".mcp.json"), - r#"{ - "mcpServers": { - "sample": { - "type": "http", - "url": "https://sample.example/mcp", - "oauth": { - "clientId": "client-id", - "callbackPort": 3118 - } - } - } -}"#, - ); - write_file( - &plugin_root.join(".app.json"), - r#"{ - "apps": { - "example": { - "id": "connector_example" - } - } -}"#, - ); - - let outcome = load_plugins_from_config(&plugin_config_toml(true, true), codex_home.path()); - - assert_eq!( - outcome.plugins, - vec![LoadedPlugin { - config_name: "sample@test".to_string(), - manifest_name: Some("sample".to_string()), - manifest_description: Some( - "Plugin that includes the sample MCP server and Skills".to_string(), - ), - root: AbsolutePathBuf::try_from(plugin_root.clone()).unwrap(), - enabled: true, - skill_roots: vec![plugin_root.join("skills")], - mcp_servers: HashMap::from([( - "sample".to_string(), - McpServerConfig { - transport: McpServerTransportConfig::StreamableHttp { - url: "https://sample.example/mcp".to_string(), - bearer_token_env_var: None, - http_headers: None, - env_http_headers: None, - }, - enabled: true, - required: false, - disabled_reason: None, - startup_timeout_sec: None, - tool_timeout_sec: None, - enabled_tools: None, - disabled_tools: None, - scopes: None, - oauth_resource: None, - }, - )]), - apps: vec![AppConnectorId("connector_example".to_string())], - error: None, - }] - ); - assert_eq!( - outcome.capability_summaries(), - &[PluginCapabilitySummary { - config_name: "sample@test".to_string(), - display_name: "sample".to_string(), - description: Some( - "Plugin that includes the sample MCP server and Skills".to_string(), - ), - has_skills: true, - mcp_server_names: vec!["sample".to_string()], - app_connector_ids: vec![AppConnectorId("connector_example".to_string())], - }] - ); - assert_eq!( - outcome.effective_skill_roots(), - vec![plugin_root.join("skills")] - ); - assert_eq!(outcome.effective_mcp_servers().len(), 1); - assert_eq!( - outcome.effective_apps(), - vec![AppConnectorId("connector_example".to_string())] - ); - } - - #[test] - fn load_plugins_uses_manifest_configured_component_paths() { - let codex_home = TempDir::new().unwrap(); - let plugin_root = codex_home - .path() - .join("plugins/cache") - .join("test/sample/local"); - - write_file( - &plugin_root.join(".codex-plugin/plugin.json"), - r#"{ - "name": "sample", - "skills": "./custom-skills/", - "mcpServers": "./config/custom.mcp.json", - "apps": "./config/custom.app.json" -}"#, - ); - write_file( - &plugin_root.join("skills/default-skill/SKILL.md"), - "---\nname: default-skill\ndescription: default skill\n---\n", - ); - write_file( - &plugin_root.join("custom-skills/custom-skill/SKILL.md"), - "---\nname: custom-skill\ndescription: custom skill\n---\n", - ); - write_file( - &plugin_root.join(".mcp.json"), - r#"{ - "mcpServers": { - "default": { - "type": "http", - "url": "https://default.example/mcp" - } - } -}"#, - ); - write_file( - &plugin_root.join("config/custom.mcp.json"), - r#"{ - "mcpServers": { - "custom": { - "type": "http", - "url": "https://custom.example/mcp" - } - } -}"#, - ); - write_file( - &plugin_root.join(".app.json"), - r#"{ - "apps": { - "default": { - "id": "connector_default" - } - } -}"#, - ); - write_file( - &plugin_root.join("config/custom.app.json"), - r#"{ - "apps": { - "custom": { - "id": "connector_custom" - } - } -}"#, - ); - - let outcome = load_plugins_from_config(&plugin_config_toml(true, true), codex_home.path()); - - assert_eq!( - outcome.plugins[0].skill_roots, - vec![ - plugin_root.join("custom-skills"), - plugin_root.join("skills") - ] - ); - assert_eq!( - outcome.plugins[0].mcp_servers, - HashMap::from([( - "custom".to_string(), - McpServerConfig { - transport: McpServerTransportConfig::StreamableHttp { - url: "https://custom.example/mcp".to_string(), - bearer_token_env_var: None, - http_headers: None, - env_http_headers: None, - }, - enabled: true, - required: false, - disabled_reason: None, - startup_timeout_sec: None, - tool_timeout_sec: None, - enabled_tools: None, - disabled_tools: None, - scopes: None, - oauth_resource: None, - }, - )]) - ); - assert_eq!( - outcome.plugins[0].apps, - vec![AppConnectorId("connector_custom".to_string())] - ); - } - - #[test] - fn load_plugins_ignores_manifest_component_paths_without_dot_slash() { - let codex_home = TempDir::new().unwrap(); - let plugin_root = codex_home - .path() - .join("plugins/cache") - .join("test/sample/local"); - - write_file( - &plugin_root.join(".codex-plugin/plugin.json"), - r#"{ - "name": "sample", - "skills": "custom-skills", - "mcpServers": "config/custom.mcp.json", - "apps": "config/custom.app.json" -}"#, - ); - write_file( - &plugin_root.join("skills/default-skill/SKILL.md"), - "---\nname: default-skill\ndescription: default skill\n---\n", - ); - write_file( - &plugin_root.join("custom-skills/custom-skill/SKILL.md"), - "---\nname: custom-skill\ndescription: custom skill\n---\n", - ); - write_file( - &plugin_root.join(".mcp.json"), - r#"{ - "mcpServers": { - "default": { - "type": "http", - "url": "https://default.example/mcp" - } - } -}"#, - ); - write_file( - &plugin_root.join("config/custom.mcp.json"), - r#"{ - "mcpServers": { - "custom": { - "type": "http", - "url": "https://custom.example/mcp" - } - } -}"#, - ); - write_file( - &plugin_root.join(".app.json"), - r#"{ - "apps": { - "default": { - "id": "connector_default" - } - } -}"#, - ); - write_file( - &plugin_root.join("config/custom.app.json"), - r#"{ - "apps": { - "custom": { - "id": "connector_custom" - } - } -}"#, - ); - - let outcome = load_plugins_from_config(&plugin_config_toml(true, true), codex_home.path()); - - assert_eq!( - outcome.plugins[0].skill_roots, - vec![plugin_root.join("skills")] - ); - assert_eq!( - outcome.plugins[0].mcp_servers, - HashMap::from([( - "default".to_string(), - McpServerConfig { - transport: McpServerTransportConfig::StreamableHttp { - url: "https://default.example/mcp".to_string(), - bearer_token_env_var: None, - http_headers: None, - env_http_headers: None, - }, - enabled: true, - required: false, - disabled_reason: None, - startup_timeout_sec: None, - tool_timeout_sec: None, - enabled_tools: None, - disabled_tools: None, - scopes: None, - oauth_resource: None, - }, - )]) - ); - assert_eq!( - outcome.plugins[0].apps, - vec![AppConnectorId("connector_default".to_string())] - ); - } - - #[test] - fn load_plugins_preserves_disabled_plugins_without_effective_contributions() { - let codex_home = TempDir::new().unwrap(); - let plugin_root = codex_home - .path() - .join("plugins/cache") - .join("test/sample/local"); - - write_file( - &plugin_root.join(".codex-plugin/plugin.json"), - r#"{"name":"sample"}"#, - ); - write_file( - &plugin_root.join(".mcp.json"), - r#"{ - "mcpServers": { - "sample": { - "type": "http", - "url": "https://sample.example/mcp" - } - } -}"#, - ); - - let outcome = load_plugins_from_config(&plugin_config_toml(false, true), codex_home.path()); - - assert_eq!( - outcome.plugins, - vec![LoadedPlugin { - config_name: "sample@test".to_string(), - manifest_name: None, - manifest_description: None, - root: AbsolutePathBuf::try_from(plugin_root).unwrap(), - enabled: false, - skill_roots: Vec::new(), - mcp_servers: HashMap::new(), - apps: Vec::new(), - error: None, - }] - ); - assert!(outcome.effective_skill_roots().is_empty()); - assert!(outcome.effective_mcp_servers().is_empty()); - } - - #[test] - fn effective_apps_dedupes_connector_ids_across_plugins() { - let codex_home = TempDir::new().unwrap(); - let plugin_a_root = codex_home - .path() - .join("plugins/cache") - .join("test/plugin-a/local"); - let plugin_b_root = codex_home - .path() - .join("plugins/cache") - .join("test/plugin-b/local"); - - write_file( - &plugin_a_root.join(".codex-plugin/plugin.json"), - r#"{"name":"plugin-a"}"#, - ); - write_file( - &plugin_a_root.join(".app.json"), - r#"{ - "apps": { - "example": { - "id": "connector_example" - } - } -}"#, - ); - write_file( - &plugin_b_root.join(".codex-plugin/plugin.json"), - r#"{"name":"plugin-b"}"#, - ); - write_file( - &plugin_b_root.join(".app.json"), - r#"{ - "apps": { - "chat": { - "id": "connector_example" - }, - "gmail": { - "id": "connector_gmail" - } - } -}"#, - ); - - let mut root = toml::map::Map::new(); - let mut features = toml::map::Map::new(); - features.insert("plugins".to_string(), Value::Boolean(true)); - root.insert("features".to_string(), Value::Table(features)); - - let mut plugins = toml::map::Map::new(); - - let mut plugin_a = toml::map::Map::new(); - plugin_a.insert("enabled".to_string(), Value::Boolean(true)); - plugins.insert("plugin-a@test".to_string(), Value::Table(plugin_a)); - - let mut plugin_b = toml::map::Map::new(); - plugin_b.insert("enabled".to_string(), Value::Boolean(true)); - plugins.insert("plugin-b@test".to_string(), Value::Table(plugin_b)); - - root.insert("plugins".to_string(), Value::Table(plugins)); - let config_toml = - toml::to_string(&Value::Table(root)).expect("plugin test config should serialize"); - - let outcome = load_plugins_from_config(&config_toml, codex_home.path()); - - assert_eq!( - outcome.effective_apps(), - vec![ - AppConnectorId("connector_example".to_string()), - AppConnectorId("connector_gmail".to_string()), - ] - ); - } - - #[test] - fn capability_index_filters_inactive_and_zero_capability_plugins() { - let codex_home = TempDir::new().unwrap(); - let connector = |id: &str| AppConnectorId(id.to_string()); - let http_server = |url: &str| McpServerConfig { - transport: McpServerTransportConfig::StreamableHttp { - url: url.to_string(), - bearer_token_env_var: None, - http_headers: None, - env_http_headers: None, - }, - enabled: true, - required: false, - disabled_reason: None, - startup_timeout_sec: None, - tool_timeout_sec: None, - enabled_tools: None, - disabled_tools: None, - scopes: None, - oauth_resource: None, - }; - let plugin = |config_name: &str, dir_name: &str, manifest_name: &str| LoadedPlugin { - config_name: config_name.to_string(), - manifest_name: Some(manifest_name.to_string()), - manifest_description: None, - root: AbsolutePathBuf::try_from(codex_home.path().join(dir_name)).unwrap(), - enabled: true, - skill_roots: Vec::new(), - mcp_servers: HashMap::new(), - apps: Vec::new(), - error: None, - }; - let summary = |config_name: &str, display_name: &str| PluginCapabilitySummary { - config_name: config_name.to_string(), - display_name: display_name.to_string(), - description: None, - ..PluginCapabilitySummary::default() - }; - let outcome = PluginLoadOutcome::from_plugins(vec![ - LoadedPlugin { - skill_roots: vec![codex_home.path().join("skills-plugin/skills")], - ..plugin("skills@test", "skills-plugin", "skills-plugin") - }, - LoadedPlugin { - mcp_servers: HashMap::from([("alpha".to_string(), http_server("https://alpha"))]), - apps: vec![connector("connector_example")], - ..plugin("alpha@test", "alpha-plugin", "alpha-plugin") - }, - LoadedPlugin { - mcp_servers: HashMap::from([("beta".to_string(), http_server("https://beta"))]), - apps: vec![connector("connector_example"), connector("connector_gmail")], - ..plugin("beta@test", "beta-plugin", "beta-plugin") - }, - plugin("empty@test", "empty-plugin", "empty-plugin"), - LoadedPlugin { - enabled: false, - skill_roots: vec![codex_home.path().join("disabled-plugin/skills")], - apps: vec![connector("connector_hidden")], - ..plugin("disabled@test", "disabled-plugin", "disabled-plugin") - }, - LoadedPlugin { - apps: vec![connector("connector_broken")], - error: Some("failed to load".to_string()), - ..plugin("broken@test", "broken-plugin", "broken-plugin") - }, - ]); - - assert_eq!( - outcome.capability_summaries(), - &[ - PluginCapabilitySummary { - has_skills: true, - ..summary("skills@test", "skills-plugin") - }, - PluginCapabilitySummary { - mcp_server_names: vec!["alpha".to_string()], - app_connector_ids: vec![connector("connector_example")], - ..summary("alpha@test", "alpha-plugin") - }, - PluginCapabilitySummary { - mcp_server_names: vec!["beta".to_string()], - app_connector_ids: vec![ - connector("connector_example"), - connector("connector_gmail"), - ], - ..summary("beta@test", "beta-plugin") - }, - ] - ); - } - - #[test] - fn plugin_namespace_for_skill_path_uses_manifest_name() { - let codex_home = TempDir::new().unwrap(); - let plugin_root = codex_home.path().join("plugins/sample"); - let skill_path = plugin_root.join("skills/search/SKILL.md"); - - write_file( - &plugin_root.join(".codex-plugin/plugin.json"), - r#"{"name":"sample"}"#, - ); - write_file(&skill_path, "---\ndescription: search\n---\n"); - - assert_eq!( - plugin_namespace_for_skill_path(&skill_path), - Some("sample".to_string()) - ); - } - - #[test] - fn load_plugins_returns_empty_when_feature_disabled() { - let codex_home = TempDir::new().unwrap(); - let plugin_root = codex_home - .path() - .join("plugins/cache") - .join("test/sample/local"); - - write_file( - &plugin_root.join(".codex-plugin/plugin.json"), - r#"{"name":"sample"}"#, - ); - write_file( - &plugin_root.join("skills/sample-search/SKILL.md"), - "---\nname: sample-search\ndescription: search sample data\n---\n", - ); - - let outcome = load_plugins_from_config(&plugin_config_toml(true, false), codex_home.path()); - - assert_eq!(outcome, PluginLoadOutcome::default()); - } - - #[test] - fn load_plugins_rejects_invalid_plugin_keys() { - let codex_home = TempDir::new().unwrap(); - let plugin_root = codex_home - .path() - .join("plugins/cache") - .join("test/sample/local"); - - write_file( - &plugin_root.join(".codex-plugin/plugin.json"), - r#"{"name":"sample"}"#, - ); - - let mut root = toml::map::Map::new(); - let mut features = toml::map::Map::new(); - features.insert("plugins".to_string(), Value::Boolean(true)); - root.insert("features".to_string(), Value::Table(features)); - - let mut plugin = toml::map::Map::new(); - plugin.insert("enabled".to_string(), Value::Boolean(true)); - - let mut plugins = toml::map::Map::new(); - plugins.insert("sample".to_string(), Value::Table(plugin)); - root.insert("plugins".to_string(), Value::Table(plugins)); - - let outcome = load_plugins_from_config( - &toml::to_string(&Value::Table(root)).expect("plugin test config should serialize"), - codex_home.path(), - ); - - assert_eq!(outcome.plugins.len(), 1); - assert_eq!( - outcome.plugins[0].error.as_deref(), - Some("invalid plugin key `sample`; expected @") - ); - assert!(outcome.effective_skill_roots().is_empty()); - assert!(outcome.effective_mcp_servers().is_empty()); - } - - #[tokio::test] - async fn install_plugin_updates_config_with_relative_path_and_plugin_key() { - let tmp = tempfile::tempdir().unwrap(); - let repo_root = tmp.path().join("repo"); - fs::create_dir_all(repo_root.join(".git")).unwrap(); - fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); - write_plugin(&repo_root, "sample-plugin", "sample-plugin"); - fs::write( - repo_root.join(".agents/plugins/marketplace.json"), - r#"{ - "name": "debug", - "plugins": [ - { - "name": "sample-plugin", - "source": { - "source": "local", - "path": "./sample-plugin" - }, - "authPolicy": "ON_USE" - } - ] -}"#, - ) - .unwrap(); - - let result = PluginsManager::new(tmp.path().to_path_buf()) - .install_plugin(PluginInstallRequest { - plugin_name: "sample-plugin".to_string(), - marketplace_path: AbsolutePathBuf::try_from( - repo_root.join(".agents/plugins/marketplace.json"), - ) - .unwrap(), - }) - .await - .unwrap(); - - let installed_path = tmp.path().join("plugins/cache/debug/sample-plugin/local"); - assert_eq!( - result, - PluginInstallOutcome { - plugin_id: PluginId::new("sample-plugin".to_string(), "debug".to_string()).unwrap(), - plugin_version: "local".to_string(), - installed_path: AbsolutePathBuf::try_from(installed_path).unwrap(), - auth_policy: MarketplacePluginAuthPolicy::OnUse, - } - ); - - let config = fs::read_to_string(tmp.path().join("config.toml")).unwrap(); - assert!(config.contains(r#"[plugins."sample-plugin@debug"]"#)); - assert!(config.contains("enabled = true")); - } - - #[tokio::test] - async fn uninstall_plugin_removes_cache_and_config_entry() { - let tmp = tempfile::tempdir().unwrap(); - write_plugin( - &tmp.path().join("plugins/cache/debug"), - "sample-plugin/local", - "sample-plugin", - ); - write_file( - &tmp.path().join(CONFIG_TOML_FILE), - r#"[features] -plugins = true - -[plugins."sample-plugin@debug"] -enabled = true -"#, - ); - - let manager = PluginsManager::new(tmp.path().to_path_buf()); - manager - .uninstall_plugin("sample-plugin@debug".to_string()) - .await - .unwrap(); - manager - .uninstall_plugin("sample-plugin@debug".to_string()) - .await - .unwrap(); - - assert!( - !tmp.path() - .join("plugins/cache/debug/sample-plugin") - .exists() - ); - let config = fs::read_to_string(tmp.path().join(CONFIG_TOML_FILE)).unwrap(); - assert!(!config.contains(r#"[plugins."sample-plugin@debug"]"#)); - } - - #[tokio::test] - async fn list_marketplaces_includes_enabled_state() { - let tmp = tempfile::tempdir().unwrap(); - let repo_root = tmp.path().join("repo"); - fs::create_dir_all(repo_root.join(".git")).unwrap(); - fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); - write_plugin( - &tmp.path().join("plugins/cache/debug"), - "enabled-plugin/local", - "enabled-plugin", - ); - write_plugin( - &tmp.path().join("plugins/cache/debug"), - "disabled-plugin/local", - "disabled-plugin", - ); - fs::write( - repo_root.join(".agents/plugins/marketplace.json"), - r#"{ - "name": "debug", - "plugins": [ - { - "name": "enabled-plugin", - "source": { - "source": "local", - "path": "./enabled-plugin" - } - }, - { - "name": "disabled-plugin", - "source": { - "source": "local", - "path": "./disabled-plugin" - } - } - ] -}"#, - ) - .unwrap(); - write_file( - &tmp.path().join(CONFIG_TOML_FILE), - r#"[features] -plugins = true - -[plugins."enabled-plugin@debug"] -enabled = true - -[plugins."disabled-plugin@debug"] -enabled = false -"#, - ); - - let config = load_config(tmp.path(), &repo_root).await; - let marketplaces = PluginsManager::new(tmp.path().to_path_buf()) - .list_marketplaces_for_config(&config, &[AbsolutePathBuf::try_from(repo_root).unwrap()]) - .unwrap(); - - let marketplace = marketplaces - .into_iter() - .find(|marketplace| { - marketplace.path - == AbsolutePathBuf::try_from( - tmp.path().join("repo/.agents/plugins/marketplace.json"), - ) - .unwrap() - }) - .expect("expected repo marketplace entry"); - - assert_eq!( - marketplace, - ConfiguredMarketplaceSummary { - name: "debug".to_string(), - path: AbsolutePathBuf::try_from( - tmp.path().join("repo/.agents/plugins/marketplace.json"), - ) - .unwrap(), - plugins: vec![ - ConfiguredMarketplacePluginSummary { - id: "enabled-plugin@debug".to_string(), - name: "enabled-plugin".to_string(), - source: MarketplacePluginSourceSummary::Local { - path: AbsolutePathBuf::try_from(tmp.path().join("repo/enabled-plugin")) - .unwrap(), - }, - install_policy: MarketplacePluginInstallPolicy::Available, - auth_policy: MarketplacePluginAuthPolicy::OnInstall, - interface: None, - installed: true, - enabled: true, - }, - ConfiguredMarketplacePluginSummary { - id: "disabled-plugin@debug".to_string(), - name: "disabled-plugin".to_string(), - source: MarketplacePluginSourceSummary::Local { - path: AbsolutePathBuf::try_from( - tmp.path().join("repo/disabled-plugin"), - ) - .unwrap(), - }, - install_policy: MarketplacePluginInstallPolicy::Available, - auth_policy: MarketplacePluginAuthPolicy::OnInstall, - interface: None, - installed: true, - enabled: false, - }, - ], - } - ); - } - - #[tokio::test] - async fn list_marketplaces_includes_curated_repo_marketplace() { - let tmp = tempfile::tempdir().unwrap(); - let curated_root = curated_plugins_repo_path(tmp.path()); - let plugin_root = curated_root.join("plugins/linear"); - - fs::create_dir_all(curated_root.join(".agents/plugins")).unwrap(); - fs::create_dir_all(plugin_root.join(".codex-plugin")).unwrap(); - fs::write( - curated_root.join(".agents/plugins/marketplace.json"), - r#"{ - "name": "openai-curated", - "plugins": [ - { - "name": "linear", - "source": { - "source": "local", - "path": "./plugins/linear" - } - } - ] -}"#, - ) - .unwrap(); - fs::write( - plugin_root.join(".codex-plugin/plugin.json"), - r#"{"name":"linear"}"#, - ) - .unwrap(); - - let config = load_config(tmp.path(), tmp.path()).await; - let marketplaces = PluginsManager::new(tmp.path().to_path_buf()) - .list_marketplaces_for_config(&config, &[]) - .unwrap(); - - let curated_marketplace = marketplaces - .into_iter() - .find(|marketplace| marketplace.name == "openai-curated") - .expect("curated marketplace should be listed"); - - assert_eq!( - curated_marketplace, - ConfiguredMarketplaceSummary { - name: "openai-curated".to_string(), - path: AbsolutePathBuf::try_from( - curated_root.join(".agents/plugins/marketplace.json") - ) - .unwrap(), - plugins: vec![ConfiguredMarketplacePluginSummary { - id: "linear@openai-curated".to_string(), - name: "linear".to_string(), - source: MarketplacePluginSourceSummary::Local { - path: AbsolutePathBuf::try_from(curated_root.join("plugins/linear")) - .unwrap(), - }, - install_policy: MarketplacePluginInstallPolicy::Available, - auth_policy: MarketplacePluginAuthPolicy::OnInstall, - interface: None, - installed: false, - enabled: false, - }], - } - ); - } - - #[tokio::test] - async fn list_marketplaces_uses_first_duplicate_plugin_entry() { - let tmp = tempfile::tempdir().unwrap(); - let repo_a_root = tmp.path().join("repo-a"); - let repo_b_root = tmp.path().join("repo-b"); - fs::create_dir_all(repo_a_root.join(".git")).unwrap(); - fs::create_dir_all(repo_b_root.join(".git")).unwrap(); - fs::create_dir_all(repo_a_root.join(".agents/plugins")).unwrap(); - fs::create_dir_all(repo_b_root.join(".agents/plugins")).unwrap(); - fs::write( - repo_a_root.join(".agents/plugins/marketplace.json"), - r#"{ - "name": "debug", - "plugins": [ - { - "name": "dup-plugin", - "source": { - "source": "local", - "path": "./from-a" - } - } - ] -}"#, - ) - .unwrap(); - fs::write( - repo_b_root.join(".agents/plugins/marketplace.json"), - r#"{ - "name": "debug", - "plugins": [ - { - "name": "dup-plugin", - "source": { - "source": "local", - "path": "./from-b" - } - }, - { - "name": "b-only-plugin", - "source": { - "source": "local", - "path": "./from-b-only" - } - } - ] -}"#, - ) - .unwrap(); - write_file( - &tmp.path().join(CONFIG_TOML_FILE), - r#"[features] -plugins = true - -[plugins."dup-plugin@debug"] -enabled = true - -[plugins."b-only-plugin@debug"] -enabled = false -"#, - ); - - let config = load_config(tmp.path(), &repo_a_root).await; - let marketplaces = PluginsManager::new(tmp.path().to_path_buf()) - .list_marketplaces_for_config( - &config, - &[ - AbsolutePathBuf::try_from(repo_a_root).unwrap(), - AbsolutePathBuf::try_from(repo_b_root).unwrap(), - ], - ) - .unwrap(); - - let repo_a_marketplace = marketplaces - .iter() - .find(|marketplace| { - marketplace.path - == AbsolutePathBuf::try_from( - tmp.path().join("repo-a/.agents/plugins/marketplace.json"), - ) - .unwrap() - }) - .expect("repo-a marketplace should be listed"); - assert_eq!( - repo_a_marketplace.plugins, - vec![ConfiguredMarketplacePluginSummary { - id: "dup-plugin@debug".to_string(), - name: "dup-plugin".to_string(), - source: MarketplacePluginSourceSummary::Local { - path: AbsolutePathBuf::try_from(tmp.path().join("repo-a/from-a")).unwrap(), - }, - install_policy: MarketplacePluginInstallPolicy::Available, - auth_policy: MarketplacePluginAuthPolicy::OnInstall, - interface: None, - installed: false, - enabled: true, - }] - ); - - let repo_b_marketplace = marketplaces - .iter() - .find(|marketplace| { - marketplace.path - == AbsolutePathBuf::try_from( - tmp.path().join("repo-b/.agents/plugins/marketplace.json"), - ) - .unwrap() - }) - .expect("repo-b marketplace should be listed"); - assert_eq!( - repo_b_marketplace.plugins, - vec![ConfiguredMarketplacePluginSummary { - id: "b-only-plugin@debug".to_string(), - name: "b-only-plugin".to_string(), - source: MarketplacePluginSourceSummary::Local { - path: AbsolutePathBuf::try_from(tmp.path().join("repo-b/from-b-only")).unwrap(), - }, - install_policy: MarketplacePluginInstallPolicy::Available, - auth_policy: MarketplacePluginAuthPolicy::OnInstall, - interface: None, - installed: false, - enabled: false, - }] - ); - - let duplicate_plugin_count = marketplaces - .iter() - .flat_map(|marketplace| marketplace.plugins.iter()) - .filter(|plugin| plugin.name == "dup-plugin") - .count(); - assert_eq!(duplicate_plugin_count, 1); - } - - #[tokio::test] - async fn list_marketplaces_marks_configured_plugin_uninstalled_when_cache_is_missing() { - let tmp = tempfile::tempdir().unwrap(); - let repo_root = tmp.path().join("repo"); - fs::create_dir_all(repo_root.join(".git")).unwrap(); - fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); - fs::write( - repo_root.join(".agents/plugins/marketplace.json"), - r#"{ - "name": "debug", - "plugins": [ - { - "name": "sample-plugin", - "source": { - "source": "local", - "path": "./sample-plugin" - } - } - ] -}"#, - ) - .unwrap(); - write_file( - &tmp.path().join(CONFIG_TOML_FILE), - r#"[features] -plugins = true - -[plugins."sample-plugin@debug"] -enabled = true -"#, - ); - - let config = load_config(tmp.path(), &repo_root).await; - let marketplaces = PluginsManager::new(tmp.path().to_path_buf()) - .list_marketplaces_for_config(&config, &[AbsolutePathBuf::try_from(repo_root).unwrap()]) - .unwrap(); - - let marketplace = marketplaces - .into_iter() - .find(|marketplace| { - marketplace.path - == AbsolutePathBuf::try_from( - tmp.path().join("repo/.agents/plugins/marketplace.json"), - ) - .unwrap() - }) - .expect("expected repo marketplace entry"); - - assert_eq!( - marketplace, - ConfiguredMarketplaceSummary { - name: "debug".to_string(), - path: AbsolutePathBuf::try_from( - tmp.path().join("repo/.agents/plugins/marketplace.json"), - ) - .unwrap(), - plugins: vec![ConfiguredMarketplacePluginSummary { - id: "sample-plugin@debug".to_string(), - name: "sample-plugin".to_string(), - source: MarketplacePluginSourceSummary::Local { - path: AbsolutePathBuf::try_from(tmp.path().join("repo/sample-plugin")) - .unwrap(), - }, - install_policy: MarketplacePluginInstallPolicy::Available, - auth_policy: MarketplacePluginAuthPolicy::OnInstall, - interface: None, - installed: false, - enabled: true, - }], - } - ); - } - - #[tokio::test] - async fn sync_plugins_from_remote_reconciles_cache_and_config() { - let tmp = tempfile::tempdir().unwrap(); - let curated_root = curated_plugins_repo_path(tmp.path()); - write_openai_curated_marketplace(&curated_root, &["linear", "gmail", "calendar"]); - write_curated_plugin_sha(tmp.path(), TEST_CURATED_PLUGIN_SHA); - write_plugin( - &tmp.path().join("plugins/cache/openai-curated"), - "linear/local", - "linear", - ); - write_plugin( - &tmp.path().join("plugins/cache/openai-curated"), - "calendar/local", - "calendar", - ); - write_file( - &tmp.path().join(CONFIG_TOML_FILE), - r#"[features] -plugins = true - -[plugins."linear@openai-curated"] -enabled = false - -[plugins."calendar@openai-curated"] -enabled = true -"#, - ); - - let server = MockServer::start().await; - Mock::given(method("GET")) - .and(path("/backend-api/plugins/list")) - .and(header("authorization", "Bearer Access Token")) - .and(header("chatgpt-account-id", "account_id")) - .respond_with(ResponseTemplate::new(200).set_body_string( - r#"[ - {"id":"1","name":"linear","marketplace_name":"openai-curated","version":"1.0.0","enabled":true}, - {"id":"2","name":"gmail","marketplace_name":"openai-curated","version":"1.0.0","enabled":false} -]"#, - )) - .mount(&server) - .await; - - let mut config = load_config(tmp.path(), tmp.path()).await; - config.chatgpt_base_url = format!("{}/backend-api/", server.uri()); - let manager = PluginsManager::new(tmp.path().to_path_buf()); - let result = manager - .sync_plugins_from_remote( - &config, - Some(&CodexAuth::create_dummy_chatgpt_auth_for_testing()), - ) - .await - .unwrap(); - - assert_eq!( - result, - RemotePluginSyncResult { - installed_plugin_ids: vec!["gmail@openai-curated".to_string()], - enabled_plugin_ids: vec!["linear@openai-curated".to_string()], - disabled_plugin_ids: vec!["gmail@openai-curated".to_string()], - uninstalled_plugin_ids: vec!["calendar@openai-curated".to_string()], - } - ); - - assert!( - tmp.path() - .join("plugins/cache/openai-curated/linear/local") - .is_dir() - ); - assert!( - tmp.path() - .join(format!( - "plugins/cache/openai-curated/gmail/{TEST_CURATED_PLUGIN_SHA}" - )) - .is_dir() - ); - assert!( - !tmp.path() - .join("plugins/cache/openai-curated/calendar") - .exists() - ); - - let config = fs::read_to_string(tmp.path().join(CONFIG_TOML_FILE)).unwrap(); - assert!(config.contains(r#"[plugins."linear@openai-curated"]"#)); - assert!(config.contains(r#"[plugins."gmail@openai-curated"]"#)); - assert!(config.contains("enabled = true")); - assert!(config.contains("enabled = false")); - assert!(!config.contains(r#"[plugins."calendar@openai-curated"]"#)); - - let synced_config = load_config(tmp.path(), tmp.path()).await; - let curated_marketplace = manager - .list_marketplaces_for_config(&synced_config, &[]) - .unwrap() - .into_iter() - .find(|marketplace| marketplace.name == OPENAI_CURATED_MARKETPLACE_NAME) - .unwrap(); - assert_eq!( - curated_marketplace - .plugins - .into_iter() - .map(|plugin| (plugin.id, plugin.installed, plugin.enabled)) - .collect::>(), - vec![ - ("linear@openai-curated".to_string(), true, true), - ("gmail@openai-curated".to_string(), true, false), - ("calendar@openai-curated".to_string(), false, false), - ] - ); - } - - #[tokio::test] - async fn sync_plugins_from_remote_ignores_unknown_remote_plugins() { - let tmp = tempfile::tempdir().unwrap(); - let curated_root = curated_plugins_repo_path(tmp.path()); - write_openai_curated_marketplace(&curated_root, &["linear"]); - write_curated_plugin_sha(tmp.path(), TEST_CURATED_PLUGIN_SHA); - write_file( - &tmp.path().join(CONFIG_TOML_FILE), - r#"[features] -plugins = true - -[plugins."linear@openai-curated"] -enabled = false -"#, - ); - - let server = MockServer::start().await; - Mock::given(method("GET")) - .and(path("/backend-api/plugins/list")) - .respond_with(ResponseTemplate::new(200).set_body_string( - r#"[ - {"id":"1","name":"plugin-one","marketplace_name":"openai-curated","version":"1.0.0","enabled":true} -]"#, - )) - .mount(&server) - .await; - - let mut config = load_config(tmp.path(), tmp.path()).await; - config.chatgpt_base_url = format!("{}/backend-api/", server.uri()); - let manager = PluginsManager::new(tmp.path().to_path_buf()); - let result = manager - .sync_plugins_from_remote( - &config, - Some(&CodexAuth::create_dummy_chatgpt_auth_for_testing()), - ) - .await - .unwrap(); - - assert_eq!( - result, - RemotePluginSyncResult { - installed_plugin_ids: Vec::new(), - enabled_plugin_ids: Vec::new(), - disabled_plugin_ids: Vec::new(), - uninstalled_plugin_ids: vec!["linear@openai-curated".to_string()], - } - ); - let config = fs::read_to_string(tmp.path().join(CONFIG_TOML_FILE)).unwrap(); - assert!(!config.contains(r#"[plugins."linear@openai-curated"]"#)); - assert!( - !tmp.path() - .join("plugins/cache/openai-curated/linear") - .exists() - ); - } - - #[tokio::test] - async fn sync_plugins_from_remote_keeps_existing_plugins_when_install_fails() { - let tmp = tempfile::tempdir().unwrap(); - let curated_root = curated_plugins_repo_path(tmp.path()); - write_openai_curated_marketplace(&curated_root, &["linear", "gmail"]); - write_curated_plugin_sha(tmp.path(), TEST_CURATED_PLUGIN_SHA); - fs::remove_dir_all(curated_root.join("plugins/gmail")).unwrap(); - write_plugin( - &tmp.path().join("plugins/cache/openai-curated"), - "linear/local", - "linear", - ); - write_file( - &tmp.path().join(CONFIG_TOML_FILE), - r#"[features] -plugins = true - -[plugins."linear@openai-curated"] -enabled = false -"#, - ); - - let server = MockServer::start().await; - Mock::given(method("GET")) - .and(path("/backend-api/plugins/list")) - .respond_with(ResponseTemplate::new(200).set_body_string( - r#"[ - {"id":"1","name":"gmail","marketplace_name":"openai-curated","version":"1.0.0","enabled":true} -]"#, - )) - .mount(&server) - .await; - - let mut config = load_config(tmp.path(), tmp.path()).await; - config.chatgpt_base_url = format!("{}/backend-api/", server.uri()); - let manager = PluginsManager::new(tmp.path().to_path_buf()); - let err = manager - .sync_plugins_from_remote( - &config, - Some(&CodexAuth::create_dummy_chatgpt_auth_for_testing()), - ) - .await - .unwrap_err(); - - assert!(matches!( - err, - PluginRemoteSyncError::Store(PluginStoreError::Invalid(ref message)) - if message.contains("plugin source path is not a directory") - )); - assert!( - tmp.path() - .join("plugins/cache/openai-curated/linear/local") - .is_dir() - ); - assert!( - !tmp.path() - .join("plugins/cache/openai-curated/gmail") - .exists() - ); - - let config = fs::read_to_string(tmp.path().join(CONFIG_TOML_FILE)).unwrap(); - assert!(config.contains(r#"[plugins."linear@openai-curated"]"#)); - assert!(!config.contains(r#"[plugins."gmail@openai-curated"]"#)); - assert!(config.contains("enabled = false")); - } - - #[tokio::test] - async fn sync_plugins_from_remote_uses_first_duplicate_local_plugin_entry() { - let tmp = tempfile::tempdir().unwrap(); - let curated_root = curated_plugins_repo_path(tmp.path()); - write_curated_plugin_sha(tmp.path(), TEST_CURATED_PLUGIN_SHA); - fs::create_dir_all(curated_root.join(".agents/plugins")).unwrap(); - fs::write( - curated_root.join(".agents/plugins/marketplace.json"), - r#"{ - "name": "openai-curated", - "plugins": [ - { - "name": "gmail", - "source": { - "source": "local", - "path": "./plugins/gmail-first" - } - }, - { - "name": "gmail", - "source": { - "source": "local", - "path": "./plugins/gmail-second" - } - } - ] -}"#, - ) - .unwrap(); - write_plugin(&curated_root, "plugins/gmail-first", "gmail"); - write_plugin(&curated_root, "plugins/gmail-second", "gmail"); - fs::write(curated_root.join("plugins/gmail-first/marker.txt"), "first").unwrap(); - fs::write( - curated_root.join("plugins/gmail-second/marker.txt"), - "second", - ) - .unwrap(); - write_file( - &tmp.path().join(CONFIG_TOML_FILE), - r#"[features] -plugins = true -"#, - ); - - let server = MockServer::start().await; - Mock::given(method("GET")) - .and(path("/backend-api/plugins/list")) - .respond_with(ResponseTemplate::new(200).set_body_string( - r#"[ - {"id":"1","name":"gmail","marketplace_name":"openai-curated","version":"1.0.0","enabled":true} -]"#, - )) - .mount(&server) - .await; - - let mut config = load_config(tmp.path(), tmp.path()).await; - config.chatgpt_base_url = format!("{}/backend-api/", server.uri()); - let manager = PluginsManager::new(tmp.path().to_path_buf()); - let result = manager - .sync_plugins_from_remote( - &config, - Some(&CodexAuth::create_dummy_chatgpt_auth_for_testing()), - ) - .await - .unwrap(); - - assert_eq!( - result, - RemotePluginSyncResult { - installed_plugin_ids: vec!["gmail@openai-curated".to_string()], - enabled_plugin_ids: vec!["gmail@openai-curated".to_string()], - disabled_plugin_ids: Vec::new(), - uninstalled_plugin_ids: Vec::new(), - } - ); - assert_eq!( - fs::read_to_string(tmp.path().join(format!( - "plugins/cache/openai-curated/gmail/{TEST_CURATED_PLUGIN_SHA}/marker.txt" - ))) - .unwrap(), - "first" - ); - } - - #[test] - fn refresh_curated_plugin_cache_replaces_existing_local_version_with_sha() { - let tmp = tempfile::tempdir().unwrap(); - let curated_root = curated_plugins_repo_path(tmp.path()); - write_openai_curated_marketplace(&curated_root, &["slack"]); - write_curated_plugin_sha(tmp.path(), TEST_CURATED_PLUGIN_SHA); - let plugin_id = PluginId::new( - "slack".to_string(), - OPENAI_CURATED_MARKETPLACE_NAME.to_string(), - ) - .unwrap(); - write_plugin( - &tmp.path().join("plugins/cache/openai-curated"), - "slack/local", - "slack", - ); - - assert!( - refresh_curated_plugin_cache(tmp.path(), TEST_CURATED_PLUGIN_SHA, &[plugin_id]) - .expect("cache refresh should succeed") - ); - - assert!( - !tmp.path() - .join("plugins/cache/openai-curated/slack/local") - .exists() - ); - assert!( - tmp.path() - .join(format!( - "plugins/cache/openai-curated/slack/{TEST_CURATED_PLUGIN_SHA}" - )) - .is_dir() - ); - } - - #[test] - fn refresh_curated_plugin_cache_reinstalls_missing_configured_plugin_with_current_sha() { - let tmp = tempfile::tempdir().unwrap(); - let curated_root = curated_plugins_repo_path(tmp.path()); - write_openai_curated_marketplace(&curated_root, &["slack"]); - write_curated_plugin_sha(tmp.path(), TEST_CURATED_PLUGIN_SHA); - let plugin_id = PluginId::new( - "slack".to_string(), - OPENAI_CURATED_MARKETPLACE_NAME.to_string(), - ) - .unwrap(); - - assert!( - refresh_curated_plugin_cache(tmp.path(), TEST_CURATED_PLUGIN_SHA, &[plugin_id]) - .expect("cache refresh should recreate missing configured plugin") - ); - - assert!( - tmp.path() - .join(format!( - "plugins/cache/openai-curated/slack/{TEST_CURATED_PLUGIN_SHA}" - )) - .is_dir() - ); - } - - #[test] - fn refresh_curated_plugin_cache_returns_false_when_configured_plugins_are_current() { - let tmp = tempfile::tempdir().unwrap(); - let curated_root = curated_plugins_repo_path(tmp.path()); - write_openai_curated_marketplace(&curated_root, &["slack"]); - let plugin_id = PluginId::new( - "slack".to_string(), - OPENAI_CURATED_MARKETPLACE_NAME.to_string(), - ) - .unwrap(); - write_plugin( - &tmp.path().join("plugins/cache/openai-curated"), - &format!("slack/{TEST_CURATED_PLUGIN_SHA}"), - "slack", - ); - - assert!( - !refresh_curated_plugin_cache(tmp.path(), TEST_CURATED_PLUGIN_SHA, &[plugin_id]) - .expect("cache refresh should be a no-op when configured plugins are current") - ); - } - - #[test] - fn load_plugins_ignores_project_config_files() { - let codex_home = TempDir::new().unwrap(); - let project_root = codex_home.path().join("project"); - let plugin_root = codex_home - .path() - .join("plugins/cache") - .join("test/sample/local"); - - write_file( - &plugin_root.join(".codex-plugin/plugin.json"), - r#"{"name":"sample"}"#, - ); - write_file( - &project_root.join(".codex/config.toml"), - &plugin_config_toml(true, true), - ); - - let stack = ConfigLayerStack::new( - vec![ConfigLayerEntry::new( - ConfigLayerSource::Project { - dot_codex_folder: AbsolutePathBuf::try_from(project_root.join(".codex")) - .unwrap(), - }, - toml::from_str(&plugin_config_toml(true, true)) - .expect("project config should parse"), - )], - ConfigRequirements::default(), - ConfigRequirementsToml::default(), - ) - .expect("config layer stack should build"); - - let outcome = PluginsManager::new(codex_home.path().to_path_buf()).plugins_for_layer_stack( - &project_root, - &stack, - false, - ); - - assert_eq!(outcome, PluginLoadOutcome::default()); - } -} +#[path = "manager_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/plugins/manager_tests.rs b/codex-rs/core/src/plugins/manager_tests.rs new file mode 100644 index 0000000000..2f543aa679 --- /dev/null +++ b/codex-rs/core/src/plugins/manager_tests.rs @@ -0,0 +1,1626 @@ +use super::*; +use crate::auth::CodexAuth; +use crate::config::CONFIG_TOML_FILE; +use crate::config::ConfigBuilder; +use crate::config::types::McpServerTransportConfig; +use crate::config_loader::ConfigLayerEntry; +use crate::config_loader::ConfigLayerStack; +use crate::config_loader::ConfigRequirements; +use crate::config_loader::ConfigRequirementsToml; +use codex_app_server_protocol::ConfigLayerSource; +use pretty_assertions::assert_eq; +use std::fs; +use tempfile::TempDir; +use toml::Value; +use wiremock::Mock; +use wiremock::MockServer; +use wiremock::ResponseTemplate; +use wiremock::matchers::header; +use wiremock::matchers::method; +use wiremock::matchers::path; + +const TEST_CURATED_PLUGIN_SHA: &str = "0123456789abcdef0123456789abcdef01234567"; + +fn write_file(path: &Path, contents: &str) { + fs::create_dir_all(path.parent().expect("file should have a parent")).unwrap(); + fs::write(path, contents).unwrap(); +} + +fn write_plugin(root: &Path, dir_name: &str, manifest_name: &str) { + let plugin_root = root.join(dir_name); + fs::create_dir_all(plugin_root.join(".codex-plugin")).unwrap(); + fs::create_dir_all(plugin_root.join("skills")).unwrap(); + fs::write( + plugin_root.join(".codex-plugin/plugin.json"), + format!(r#"{{"name":"{manifest_name}"}}"#), + ) + .unwrap(); + fs::write(plugin_root.join("skills/SKILL.md"), "skill").unwrap(); + fs::write(plugin_root.join(".mcp.json"), r#"{"mcpServers":{}}"#).unwrap(); +} + +fn write_openai_curated_marketplace(root: &Path, plugin_names: &[&str]) { + fs::create_dir_all(root.join(".agents/plugins")).unwrap(); + let plugins = plugin_names + .iter() + .map(|plugin_name| { + format!( + r#"{{ + "name": "{plugin_name}", + "source": {{ + "source": "local", + "path": "./plugins/{plugin_name}" + }} + }}"# + ) + }) + .collect::>() + .join(",\n"); + fs::write( + root.join(".agents/plugins/marketplace.json"), + format!( + r#"{{ + "name": "{OPENAI_CURATED_MARKETPLACE_NAME}", + "plugins": [ +{plugins} + ] +}}"# + ), + ) + .unwrap(); + for plugin_name in plugin_names { + write_plugin(root, &format!("plugins/{plugin_name}"), plugin_name); + } +} + +fn write_curated_plugin_sha(codex_home: &Path, sha: &str) { + write_file(&codex_home.join(".tmp/plugins.sha"), &format!("{sha}\n")); +} + +fn plugin_config_toml(enabled: bool, plugins_feature_enabled: bool) -> String { + let mut root = toml::map::Map::new(); + + let mut features = toml::map::Map::new(); + features.insert( + "plugins".to_string(), + Value::Boolean(plugins_feature_enabled), + ); + root.insert("features".to_string(), Value::Table(features)); + + let mut plugin = toml::map::Map::new(); + plugin.insert("enabled".to_string(), Value::Boolean(enabled)); + + let mut plugins = toml::map::Map::new(); + plugins.insert("sample@test".to_string(), Value::Table(plugin)); + root.insert("plugins".to_string(), Value::Table(plugins)); + + toml::to_string(&Value::Table(root)).expect("plugin test config should serialize") +} + +fn load_plugins_from_config(config_toml: &str, codex_home: &Path) -> PluginLoadOutcome { + write_file(&codex_home.join(CONFIG_TOML_FILE), config_toml); + let stack = ConfigLayerStack::new( + vec![ConfigLayerEntry::new( + ConfigLayerSource::User { + file: AbsolutePathBuf::try_from(codex_home.join(CONFIG_TOML_FILE)).unwrap(), + }, + toml::from_str(config_toml).expect("plugin test config should parse"), + )], + ConfigRequirements::default(), + ConfigRequirementsToml::default(), + ) + .expect("config layer stack should build"); + PluginsManager::new(codex_home.to_path_buf()).plugins_for_layer_stack(codex_home, &stack, false) +} + +async fn load_config(codex_home: &Path, cwd: &Path) -> crate::config::Config { + ConfigBuilder::default() + .codex_home(codex_home.to_path_buf()) + .fallback_cwd(Some(cwd.to_path_buf())) + .build() + .await + .expect("config should load") +} + +#[test] +fn load_plugins_loads_default_skills_and_mcp_servers() { + let codex_home = TempDir::new().unwrap(); + let plugin_root = codex_home + .path() + .join("plugins/cache") + .join("test/sample/local"); + + write_file( + &plugin_root.join(".codex-plugin/plugin.json"), + r#"{ + "name": "sample", + "description": "Plugin that includes the sample MCP server and Skills" +}"#, + ); + write_file( + &plugin_root.join("skills/sample-search/SKILL.md"), + "---\nname: sample-search\ndescription: search sample data\n---\n", + ); + write_file( + &plugin_root.join(".mcp.json"), + r#"{ + "mcpServers": { + "sample": { + "type": "http", + "url": "https://sample.example/mcp", + "oauth": { + "clientId": "client-id", + "callbackPort": 3118 + } + } + } +}"#, + ); + write_file( + &plugin_root.join(".app.json"), + r#"{ + "apps": { + "example": { + "id": "connector_example" + } + } +}"#, + ); + + let outcome = load_plugins_from_config(&plugin_config_toml(true, true), codex_home.path()); + + assert_eq!( + outcome.plugins, + vec![LoadedPlugin { + config_name: "sample@test".to_string(), + manifest_name: Some("sample".to_string()), + manifest_description: Some( + "Plugin that includes the sample MCP server and Skills".to_string(), + ), + root: AbsolutePathBuf::try_from(plugin_root.clone()).unwrap(), + enabled: true, + skill_roots: vec![plugin_root.join("skills")], + mcp_servers: HashMap::from([( + "sample".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: "https://sample.example/mcp".to_string(), + bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, + }, + enabled: true, + required: false, + disabled_reason: None, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + oauth_resource: None, + }, + )]), + apps: vec![AppConnectorId("connector_example".to_string())], + error: None, + }] + ); + assert_eq!( + outcome.capability_summaries(), + &[PluginCapabilitySummary { + config_name: "sample@test".to_string(), + display_name: "sample".to_string(), + description: Some("Plugin that includes the sample MCP server and Skills".to_string(),), + has_skills: true, + mcp_server_names: vec!["sample".to_string()], + app_connector_ids: vec![AppConnectorId("connector_example".to_string())], + }] + ); + assert_eq!( + outcome.effective_skill_roots(), + vec![plugin_root.join("skills")] + ); + assert_eq!(outcome.effective_mcp_servers().len(), 1); + assert_eq!( + outcome.effective_apps(), + vec![AppConnectorId("connector_example".to_string())] + ); +} + +#[test] +fn load_plugins_uses_manifest_configured_component_paths() { + let codex_home = TempDir::new().unwrap(); + let plugin_root = codex_home + .path() + .join("plugins/cache") + .join("test/sample/local"); + + write_file( + &plugin_root.join(".codex-plugin/plugin.json"), + r#"{ + "name": "sample", + "skills": "./custom-skills/", + "mcpServers": "./config/custom.mcp.json", + "apps": "./config/custom.app.json" +}"#, + ); + write_file( + &plugin_root.join("skills/default-skill/SKILL.md"), + "---\nname: default-skill\ndescription: default skill\n---\n", + ); + write_file( + &plugin_root.join("custom-skills/custom-skill/SKILL.md"), + "---\nname: custom-skill\ndescription: custom skill\n---\n", + ); + write_file( + &plugin_root.join(".mcp.json"), + r#"{ + "mcpServers": { + "default": { + "type": "http", + "url": "https://default.example/mcp" + } + } +}"#, + ); + write_file( + &plugin_root.join("config/custom.mcp.json"), + r#"{ + "mcpServers": { + "custom": { + "type": "http", + "url": "https://custom.example/mcp" + } + } +}"#, + ); + write_file( + &plugin_root.join(".app.json"), + r#"{ + "apps": { + "default": { + "id": "connector_default" + } + } +}"#, + ); + write_file( + &plugin_root.join("config/custom.app.json"), + r#"{ + "apps": { + "custom": { + "id": "connector_custom" + } + } +}"#, + ); + + let outcome = load_plugins_from_config(&plugin_config_toml(true, true), codex_home.path()); + + assert_eq!( + outcome.plugins[0].skill_roots, + vec![ + plugin_root.join("custom-skills"), + plugin_root.join("skills") + ] + ); + assert_eq!( + outcome.plugins[0].mcp_servers, + HashMap::from([( + "custom".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: "https://custom.example/mcp".to_string(), + bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, + }, + enabled: true, + required: false, + disabled_reason: None, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + oauth_resource: None, + }, + )]) + ); + assert_eq!( + outcome.plugins[0].apps, + vec![AppConnectorId("connector_custom".to_string())] + ); +} + +#[test] +fn load_plugins_ignores_manifest_component_paths_without_dot_slash() { + let codex_home = TempDir::new().unwrap(); + let plugin_root = codex_home + .path() + .join("plugins/cache") + .join("test/sample/local"); + + write_file( + &plugin_root.join(".codex-plugin/plugin.json"), + r#"{ + "name": "sample", + "skills": "custom-skills", + "mcpServers": "config/custom.mcp.json", + "apps": "config/custom.app.json" +}"#, + ); + write_file( + &plugin_root.join("skills/default-skill/SKILL.md"), + "---\nname: default-skill\ndescription: default skill\n---\n", + ); + write_file( + &plugin_root.join("custom-skills/custom-skill/SKILL.md"), + "---\nname: custom-skill\ndescription: custom skill\n---\n", + ); + write_file( + &plugin_root.join(".mcp.json"), + r#"{ + "mcpServers": { + "default": { + "type": "http", + "url": "https://default.example/mcp" + } + } +}"#, + ); + write_file( + &plugin_root.join("config/custom.mcp.json"), + r#"{ + "mcpServers": { + "custom": { + "type": "http", + "url": "https://custom.example/mcp" + } + } +}"#, + ); + write_file( + &plugin_root.join(".app.json"), + r#"{ + "apps": { + "default": { + "id": "connector_default" + } + } +}"#, + ); + write_file( + &plugin_root.join("config/custom.app.json"), + r#"{ + "apps": { + "custom": { + "id": "connector_custom" + } + } +}"#, + ); + + let outcome = load_plugins_from_config(&plugin_config_toml(true, true), codex_home.path()); + + assert_eq!( + outcome.plugins[0].skill_roots, + vec![plugin_root.join("skills")] + ); + assert_eq!( + outcome.plugins[0].mcp_servers, + HashMap::from([( + "default".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: "https://default.example/mcp".to_string(), + bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, + }, + enabled: true, + required: false, + disabled_reason: None, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + oauth_resource: None, + }, + )]) + ); + assert_eq!( + outcome.plugins[0].apps, + vec![AppConnectorId("connector_default".to_string())] + ); +} + +#[test] +fn load_plugins_preserves_disabled_plugins_without_effective_contributions() { + let codex_home = TempDir::new().unwrap(); + let plugin_root = codex_home + .path() + .join("plugins/cache") + .join("test/sample/local"); + + write_file( + &plugin_root.join(".codex-plugin/plugin.json"), + r#"{"name":"sample"}"#, + ); + write_file( + &plugin_root.join(".mcp.json"), + r#"{ + "mcpServers": { + "sample": { + "type": "http", + "url": "https://sample.example/mcp" + } + } +}"#, + ); + + let outcome = load_plugins_from_config(&plugin_config_toml(false, true), codex_home.path()); + + assert_eq!( + outcome.plugins, + vec![LoadedPlugin { + config_name: "sample@test".to_string(), + manifest_name: None, + manifest_description: None, + root: AbsolutePathBuf::try_from(plugin_root).unwrap(), + enabled: false, + skill_roots: Vec::new(), + mcp_servers: HashMap::new(), + apps: Vec::new(), + error: None, + }] + ); + assert!(outcome.effective_skill_roots().is_empty()); + assert!(outcome.effective_mcp_servers().is_empty()); +} + +#[test] +fn effective_apps_dedupes_connector_ids_across_plugins() { + let codex_home = TempDir::new().unwrap(); + let plugin_a_root = codex_home + .path() + .join("plugins/cache") + .join("test/plugin-a/local"); + let plugin_b_root = codex_home + .path() + .join("plugins/cache") + .join("test/plugin-b/local"); + + write_file( + &plugin_a_root.join(".codex-plugin/plugin.json"), + r#"{"name":"plugin-a"}"#, + ); + write_file( + &plugin_a_root.join(".app.json"), + r#"{ + "apps": { + "example": { + "id": "connector_example" + } + } +}"#, + ); + write_file( + &plugin_b_root.join(".codex-plugin/plugin.json"), + r#"{"name":"plugin-b"}"#, + ); + write_file( + &plugin_b_root.join(".app.json"), + r#"{ + "apps": { + "chat": { + "id": "connector_example" + }, + "gmail": { + "id": "connector_gmail" + } + } +}"#, + ); + + let mut root = toml::map::Map::new(); + let mut features = toml::map::Map::new(); + features.insert("plugins".to_string(), Value::Boolean(true)); + root.insert("features".to_string(), Value::Table(features)); + + let mut plugins = toml::map::Map::new(); + + let mut plugin_a = toml::map::Map::new(); + plugin_a.insert("enabled".to_string(), Value::Boolean(true)); + plugins.insert("plugin-a@test".to_string(), Value::Table(plugin_a)); + + let mut plugin_b = toml::map::Map::new(); + plugin_b.insert("enabled".to_string(), Value::Boolean(true)); + plugins.insert("plugin-b@test".to_string(), Value::Table(plugin_b)); + + root.insert("plugins".to_string(), Value::Table(plugins)); + let config_toml = + toml::to_string(&Value::Table(root)).expect("plugin test config should serialize"); + + let outcome = load_plugins_from_config(&config_toml, codex_home.path()); + + assert_eq!( + outcome.effective_apps(), + vec![ + AppConnectorId("connector_example".to_string()), + AppConnectorId("connector_gmail".to_string()), + ] + ); +} + +#[test] +fn capability_index_filters_inactive_and_zero_capability_plugins() { + let codex_home = TempDir::new().unwrap(); + let connector = |id: &str| AppConnectorId(id.to_string()); + let http_server = |url: &str| McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: url.to_string(), + bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, + }, + enabled: true, + required: false, + disabled_reason: None, + startup_timeout_sec: None, + tool_timeout_sec: None, + enabled_tools: None, + disabled_tools: None, + scopes: None, + oauth_resource: None, + }; + let plugin = |config_name: &str, dir_name: &str, manifest_name: &str| LoadedPlugin { + config_name: config_name.to_string(), + manifest_name: Some(manifest_name.to_string()), + manifest_description: None, + root: AbsolutePathBuf::try_from(codex_home.path().join(dir_name)).unwrap(), + enabled: true, + skill_roots: Vec::new(), + mcp_servers: HashMap::new(), + apps: Vec::new(), + error: None, + }; + let summary = |config_name: &str, display_name: &str| PluginCapabilitySummary { + config_name: config_name.to_string(), + display_name: display_name.to_string(), + description: None, + ..PluginCapabilitySummary::default() + }; + let outcome = PluginLoadOutcome::from_plugins(vec![ + LoadedPlugin { + skill_roots: vec![codex_home.path().join("skills-plugin/skills")], + ..plugin("skills@test", "skills-plugin", "skills-plugin") + }, + LoadedPlugin { + mcp_servers: HashMap::from([("alpha".to_string(), http_server("https://alpha"))]), + apps: vec![connector("connector_example")], + ..plugin("alpha@test", "alpha-plugin", "alpha-plugin") + }, + LoadedPlugin { + mcp_servers: HashMap::from([("beta".to_string(), http_server("https://beta"))]), + apps: vec![connector("connector_example"), connector("connector_gmail")], + ..plugin("beta@test", "beta-plugin", "beta-plugin") + }, + plugin("empty@test", "empty-plugin", "empty-plugin"), + LoadedPlugin { + enabled: false, + skill_roots: vec![codex_home.path().join("disabled-plugin/skills")], + apps: vec![connector("connector_hidden")], + ..plugin("disabled@test", "disabled-plugin", "disabled-plugin") + }, + LoadedPlugin { + apps: vec![connector("connector_broken")], + error: Some("failed to load".to_string()), + ..plugin("broken@test", "broken-plugin", "broken-plugin") + }, + ]); + + assert_eq!( + outcome.capability_summaries(), + &[ + PluginCapabilitySummary { + has_skills: true, + ..summary("skills@test", "skills-plugin") + }, + PluginCapabilitySummary { + mcp_server_names: vec!["alpha".to_string()], + app_connector_ids: vec![connector("connector_example")], + ..summary("alpha@test", "alpha-plugin") + }, + PluginCapabilitySummary { + mcp_server_names: vec!["beta".to_string()], + app_connector_ids: vec![ + connector("connector_example"), + connector("connector_gmail"), + ], + ..summary("beta@test", "beta-plugin") + }, + ] + ); +} + +#[test] +fn plugin_namespace_for_skill_path_uses_manifest_name() { + let codex_home = TempDir::new().unwrap(); + let plugin_root = codex_home.path().join("plugins/sample"); + let skill_path = plugin_root.join("skills/search/SKILL.md"); + + write_file( + &plugin_root.join(".codex-plugin/plugin.json"), + r#"{"name":"sample"}"#, + ); + write_file(&skill_path, "---\ndescription: search\n---\n"); + + assert_eq!( + plugin_namespace_for_skill_path(&skill_path), + Some("sample".to_string()) + ); +} + +#[test] +fn load_plugins_returns_empty_when_feature_disabled() { + let codex_home = TempDir::new().unwrap(); + let plugin_root = codex_home + .path() + .join("plugins/cache") + .join("test/sample/local"); + + write_file( + &plugin_root.join(".codex-plugin/plugin.json"), + r#"{"name":"sample"}"#, + ); + write_file( + &plugin_root.join("skills/sample-search/SKILL.md"), + "---\nname: sample-search\ndescription: search sample data\n---\n", + ); + + let outcome = load_plugins_from_config(&plugin_config_toml(true, false), codex_home.path()); + + assert_eq!(outcome, PluginLoadOutcome::default()); +} + +#[test] +fn load_plugins_rejects_invalid_plugin_keys() { + let codex_home = TempDir::new().unwrap(); + let plugin_root = codex_home + .path() + .join("plugins/cache") + .join("test/sample/local"); + + write_file( + &plugin_root.join(".codex-plugin/plugin.json"), + r#"{"name":"sample"}"#, + ); + + let mut root = toml::map::Map::new(); + let mut features = toml::map::Map::new(); + features.insert("plugins".to_string(), Value::Boolean(true)); + root.insert("features".to_string(), Value::Table(features)); + + let mut plugin = toml::map::Map::new(); + plugin.insert("enabled".to_string(), Value::Boolean(true)); + + let mut plugins = toml::map::Map::new(); + plugins.insert("sample".to_string(), Value::Table(plugin)); + root.insert("plugins".to_string(), Value::Table(plugins)); + + let outcome = load_plugins_from_config( + &toml::to_string(&Value::Table(root)).expect("plugin test config should serialize"), + codex_home.path(), + ); + + assert_eq!(outcome.plugins.len(), 1); + assert_eq!( + outcome.plugins[0].error.as_deref(), + Some("invalid plugin key `sample`; expected @") + ); + assert!(outcome.effective_skill_roots().is_empty()); + assert!(outcome.effective_mcp_servers().is_empty()); +} + +#[tokio::test] +async fn install_plugin_updates_config_with_relative_path_and_plugin_key() { + let tmp = tempfile::tempdir().unwrap(); + let repo_root = tmp.path().join("repo"); + fs::create_dir_all(repo_root.join(".git")).unwrap(); + fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); + write_plugin(&repo_root, "sample-plugin", "sample-plugin"); + fs::write( + repo_root.join(".agents/plugins/marketplace.json"), + r#"{ + "name": "debug", + "plugins": [ + { + "name": "sample-plugin", + "source": { + "source": "local", + "path": "./sample-plugin" + }, + "authPolicy": "ON_USE" + } + ] +}"#, + ) + .unwrap(); + + let result = PluginsManager::new(tmp.path().to_path_buf()) + .install_plugin(PluginInstallRequest { + plugin_name: "sample-plugin".to_string(), + marketplace_path: AbsolutePathBuf::try_from( + repo_root.join(".agents/plugins/marketplace.json"), + ) + .unwrap(), + }) + .await + .unwrap(); + + let installed_path = tmp.path().join("plugins/cache/debug/sample-plugin/local"); + assert_eq!( + result, + PluginInstallOutcome { + plugin_id: PluginId::new("sample-plugin".to_string(), "debug".to_string()).unwrap(), + plugin_version: "local".to_string(), + installed_path: AbsolutePathBuf::try_from(installed_path).unwrap(), + auth_policy: MarketplacePluginAuthPolicy::OnUse, + } + ); + + let config = fs::read_to_string(tmp.path().join("config.toml")).unwrap(); + assert!(config.contains(r#"[plugins."sample-plugin@debug"]"#)); + assert!(config.contains("enabled = true")); +} + +#[tokio::test] +async fn uninstall_plugin_removes_cache_and_config_entry() { + let tmp = tempfile::tempdir().unwrap(); + write_plugin( + &tmp.path().join("plugins/cache/debug"), + "sample-plugin/local", + "sample-plugin", + ); + write_file( + &tmp.path().join(CONFIG_TOML_FILE), + r#"[features] +plugins = true + +[plugins."sample-plugin@debug"] +enabled = true +"#, + ); + + let manager = PluginsManager::new(tmp.path().to_path_buf()); + manager + .uninstall_plugin("sample-plugin@debug".to_string()) + .await + .unwrap(); + manager + .uninstall_plugin("sample-plugin@debug".to_string()) + .await + .unwrap(); + + assert!( + !tmp.path() + .join("plugins/cache/debug/sample-plugin") + .exists() + ); + let config = fs::read_to_string(tmp.path().join(CONFIG_TOML_FILE)).unwrap(); + assert!(!config.contains(r#"[plugins."sample-plugin@debug"]"#)); +} + +#[tokio::test] +async fn list_marketplaces_includes_enabled_state() { + let tmp = tempfile::tempdir().unwrap(); + let repo_root = tmp.path().join("repo"); + fs::create_dir_all(repo_root.join(".git")).unwrap(); + fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); + write_plugin( + &tmp.path().join("plugins/cache/debug"), + "enabled-plugin/local", + "enabled-plugin", + ); + write_plugin( + &tmp.path().join("plugins/cache/debug"), + "disabled-plugin/local", + "disabled-plugin", + ); + fs::write( + repo_root.join(".agents/plugins/marketplace.json"), + r#"{ + "name": "debug", + "plugins": [ + { + "name": "enabled-plugin", + "source": { + "source": "local", + "path": "./enabled-plugin" + } + }, + { + "name": "disabled-plugin", + "source": { + "source": "local", + "path": "./disabled-plugin" + } + } + ] +}"#, + ) + .unwrap(); + write_file( + &tmp.path().join(CONFIG_TOML_FILE), + r#"[features] +plugins = true + +[plugins."enabled-plugin@debug"] +enabled = true + +[plugins."disabled-plugin@debug"] +enabled = false +"#, + ); + + let config = load_config(tmp.path(), &repo_root).await; + let marketplaces = PluginsManager::new(tmp.path().to_path_buf()) + .list_marketplaces_for_config(&config, &[AbsolutePathBuf::try_from(repo_root).unwrap()]) + .unwrap(); + + let marketplace = marketplaces + .into_iter() + .find(|marketplace| { + marketplace.path + == AbsolutePathBuf::try_from( + tmp.path().join("repo/.agents/plugins/marketplace.json"), + ) + .unwrap() + }) + .expect("expected repo marketplace entry"); + + assert_eq!( + marketplace, + ConfiguredMarketplaceSummary { + name: "debug".to_string(), + path: AbsolutePathBuf::try_from( + tmp.path().join("repo/.agents/plugins/marketplace.json"), + ) + .unwrap(), + plugins: vec![ + ConfiguredMarketplacePluginSummary { + id: "enabled-plugin@debug".to_string(), + name: "enabled-plugin".to_string(), + source: MarketplacePluginSourceSummary::Local { + path: AbsolutePathBuf::try_from(tmp.path().join("repo/enabled-plugin")) + .unwrap(), + }, + install_policy: MarketplacePluginInstallPolicy::Available, + auth_policy: MarketplacePluginAuthPolicy::OnInstall, + interface: None, + installed: true, + enabled: true, + }, + ConfiguredMarketplacePluginSummary { + id: "disabled-plugin@debug".to_string(), + name: "disabled-plugin".to_string(), + source: MarketplacePluginSourceSummary::Local { + path: AbsolutePathBuf::try_from(tmp.path().join("repo/disabled-plugin"),) + .unwrap(), + }, + install_policy: MarketplacePluginInstallPolicy::Available, + auth_policy: MarketplacePluginAuthPolicy::OnInstall, + interface: None, + installed: true, + enabled: false, + }, + ], + } + ); +} + +#[tokio::test] +async fn list_marketplaces_includes_curated_repo_marketplace() { + let tmp = tempfile::tempdir().unwrap(); + let curated_root = curated_plugins_repo_path(tmp.path()); + let plugin_root = curated_root.join("plugins/linear"); + + fs::create_dir_all(curated_root.join(".agents/plugins")).unwrap(); + fs::create_dir_all(plugin_root.join(".codex-plugin")).unwrap(); + fs::write( + curated_root.join(".agents/plugins/marketplace.json"), + r#"{ + "name": "openai-curated", + "plugins": [ + { + "name": "linear", + "source": { + "source": "local", + "path": "./plugins/linear" + } + } + ] +}"#, + ) + .unwrap(); + fs::write( + plugin_root.join(".codex-plugin/plugin.json"), + r#"{"name":"linear"}"#, + ) + .unwrap(); + + let config = load_config(tmp.path(), tmp.path()).await; + let marketplaces = PluginsManager::new(tmp.path().to_path_buf()) + .list_marketplaces_for_config(&config, &[]) + .unwrap(); + + let curated_marketplace = marketplaces + .into_iter() + .find(|marketplace| marketplace.name == "openai-curated") + .expect("curated marketplace should be listed"); + + assert_eq!( + curated_marketplace, + ConfiguredMarketplaceSummary { + name: "openai-curated".to_string(), + path: AbsolutePathBuf::try_from(curated_root.join(".agents/plugins/marketplace.json")) + .unwrap(), + plugins: vec![ConfiguredMarketplacePluginSummary { + id: "linear@openai-curated".to_string(), + name: "linear".to_string(), + source: MarketplacePluginSourceSummary::Local { + path: AbsolutePathBuf::try_from(curated_root.join("plugins/linear")).unwrap(), + }, + install_policy: MarketplacePluginInstallPolicy::Available, + auth_policy: MarketplacePluginAuthPolicy::OnInstall, + interface: None, + installed: false, + enabled: false, + }], + } + ); +} + +#[tokio::test] +async fn list_marketplaces_uses_first_duplicate_plugin_entry() { + let tmp = tempfile::tempdir().unwrap(); + let repo_a_root = tmp.path().join("repo-a"); + let repo_b_root = tmp.path().join("repo-b"); + fs::create_dir_all(repo_a_root.join(".git")).unwrap(); + fs::create_dir_all(repo_b_root.join(".git")).unwrap(); + fs::create_dir_all(repo_a_root.join(".agents/plugins")).unwrap(); + fs::create_dir_all(repo_b_root.join(".agents/plugins")).unwrap(); + fs::write( + repo_a_root.join(".agents/plugins/marketplace.json"), + r#"{ + "name": "debug", + "plugins": [ + { + "name": "dup-plugin", + "source": { + "source": "local", + "path": "./from-a" + } + } + ] +}"#, + ) + .unwrap(); + fs::write( + repo_b_root.join(".agents/plugins/marketplace.json"), + r#"{ + "name": "debug", + "plugins": [ + { + "name": "dup-plugin", + "source": { + "source": "local", + "path": "./from-b" + } + }, + { + "name": "b-only-plugin", + "source": { + "source": "local", + "path": "./from-b-only" + } + } + ] +}"#, + ) + .unwrap(); + write_file( + &tmp.path().join(CONFIG_TOML_FILE), + r#"[features] +plugins = true + +[plugins."dup-plugin@debug"] +enabled = true + +[plugins."b-only-plugin@debug"] +enabled = false +"#, + ); + + let config = load_config(tmp.path(), &repo_a_root).await; + let marketplaces = PluginsManager::new(tmp.path().to_path_buf()) + .list_marketplaces_for_config( + &config, + &[ + AbsolutePathBuf::try_from(repo_a_root).unwrap(), + AbsolutePathBuf::try_from(repo_b_root).unwrap(), + ], + ) + .unwrap(); + + let repo_a_marketplace = marketplaces + .iter() + .find(|marketplace| { + marketplace.path + == AbsolutePathBuf::try_from( + tmp.path().join("repo-a/.agents/plugins/marketplace.json"), + ) + .unwrap() + }) + .expect("repo-a marketplace should be listed"); + assert_eq!( + repo_a_marketplace.plugins, + vec![ConfiguredMarketplacePluginSummary { + id: "dup-plugin@debug".to_string(), + name: "dup-plugin".to_string(), + source: MarketplacePluginSourceSummary::Local { + path: AbsolutePathBuf::try_from(tmp.path().join("repo-a/from-a")).unwrap(), + }, + install_policy: MarketplacePluginInstallPolicy::Available, + auth_policy: MarketplacePluginAuthPolicy::OnInstall, + interface: None, + installed: false, + enabled: true, + }] + ); + + let repo_b_marketplace = marketplaces + .iter() + .find(|marketplace| { + marketplace.path + == AbsolutePathBuf::try_from( + tmp.path().join("repo-b/.agents/plugins/marketplace.json"), + ) + .unwrap() + }) + .expect("repo-b marketplace should be listed"); + assert_eq!( + repo_b_marketplace.plugins, + vec![ConfiguredMarketplacePluginSummary { + id: "b-only-plugin@debug".to_string(), + name: "b-only-plugin".to_string(), + source: MarketplacePluginSourceSummary::Local { + path: AbsolutePathBuf::try_from(tmp.path().join("repo-b/from-b-only")).unwrap(), + }, + install_policy: MarketplacePluginInstallPolicy::Available, + auth_policy: MarketplacePluginAuthPolicy::OnInstall, + interface: None, + installed: false, + enabled: false, + }] + ); + + let duplicate_plugin_count = marketplaces + .iter() + .flat_map(|marketplace| marketplace.plugins.iter()) + .filter(|plugin| plugin.name == "dup-plugin") + .count(); + assert_eq!(duplicate_plugin_count, 1); +} + +#[tokio::test] +async fn list_marketplaces_marks_configured_plugin_uninstalled_when_cache_is_missing() { + let tmp = tempfile::tempdir().unwrap(); + let repo_root = tmp.path().join("repo"); + fs::create_dir_all(repo_root.join(".git")).unwrap(); + fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); + fs::write( + repo_root.join(".agents/plugins/marketplace.json"), + r#"{ + "name": "debug", + "plugins": [ + { + "name": "sample-plugin", + "source": { + "source": "local", + "path": "./sample-plugin" + } + } + ] +}"#, + ) + .unwrap(); + write_file( + &tmp.path().join(CONFIG_TOML_FILE), + r#"[features] +plugins = true + +[plugins."sample-plugin@debug"] +enabled = true +"#, + ); + + let config = load_config(tmp.path(), &repo_root).await; + let marketplaces = PluginsManager::new(tmp.path().to_path_buf()) + .list_marketplaces_for_config(&config, &[AbsolutePathBuf::try_from(repo_root).unwrap()]) + .unwrap(); + + let marketplace = marketplaces + .into_iter() + .find(|marketplace| { + marketplace.path + == AbsolutePathBuf::try_from( + tmp.path().join("repo/.agents/plugins/marketplace.json"), + ) + .unwrap() + }) + .expect("expected repo marketplace entry"); + + assert_eq!( + marketplace, + ConfiguredMarketplaceSummary { + name: "debug".to_string(), + path: AbsolutePathBuf::try_from( + tmp.path().join("repo/.agents/plugins/marketplace.json"), + ) + .unwrap(), + plugins: vec![ConfiguredMarketplacePluginSummary { + id: "sample-plugin@debug".to_string(), + name: "sample-plugin".to_string(), + source: MarketplacePluginSourceSummary::Local { + path: AbsolutePathBuf::try_from(tmp.path().join("repo/sample-plugin")).unwrap(), + }, + install_policy: MarketplacePluginInstallPolicy::Available, + auth_policy: MarketplacePluginAuthPolicy::OnInstall, + interface: None, + installed: false, + enabled: true, + }], + } + ); +} + +#[tokio::test] +async fn sync_plugins_from_remote_reconciles_cache_and_config() { + let tmp = tempfile::tempdir().unwrap(); + let curated_root = curated_plugins_repo_path(tmp.path()); + write_openai_curated_marketplace(&curated_root, &["linear", "gmail", "calendar"]); + write_curated_plugin_sha(tmp.path(), TEST_CURATED_PLUGIN_SHA); + write_plugin( + &tmp.path().join("plugins/cache/openai-curated"), + "linear/local", + "linear", + ); + write_plugin( + &tmp.path().join("plugins/cache/openai-curated"), + "calendar/local", + "calendar", + ); + write_file( + &tmp.path().join(CONFIG_TOML_FILE), + r#"[features] +plugins = true + +[plugins."linear@openai-curated"] +enabled = false + +[plugins."calendar@openai-curated"] +enabled = true +"#, + ); + + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/backend-api/plugins/list")) + .and(header("authorization", "Bearer Access Token")) + .and(header("chatgpt-account-id", "account_id")) + .respond_with(ResponseTemplate::new(200).set_body_string( + r#"[ + {"id":"1","name":"linear","marketplace_name":"openai-curated","version":"1.0.0","enabled":true}, + {"id":"2","name":"gmail","marketplace_name":"openai-curated","version":"1.0.0","enabled":false} +]"#, + )) + .mount(&server) + .await; + + let mut config = load_config(tmp.path(), tmp.path()).await; + config.chatgpt_base_url = format!("{}/backend-api/", server.uri()); + let manager = PluginsManager::new(tmp.path().to_path_buf()); + let result = manager + .sync_plugins_from_remote( + &config, + Some(&CodexAuth::create_dummy_chatgpt_auth_for_testing()), + ) + .await + .unwrap(); + + assert_eq!( + result, + RemotePluginSyncResult { + installed_plugin_ids: vec!["gmail@openai-curated".to_string()], + enabled_plugin_ids: vec!["linear@openai-curated".to_string()], + disabled_plugin_ids: vec!["gmail@openai-curated".to_string()], + uninstalled_plugin_ids: vec!["calendar@openai-curated".to_string()], + } + ); + + assert!( + tmp.path() + .join("plugins/cache/openai-curated/linear/local") + .is_dir() + ); + assert!( + tmp.path() + .join(format!( + "plugins/cache/openai-curated/gmail/{TEST_CURATED_PLUGIN_SHA}" + )) + .is_dir() + ); + assert!( + !tmp.path() + .join("plugins/cache/openai-curated/calendar") + .exists() + ); + + let config = fs::read_to_string(tmp.path().join(CONFIG_TOML_FILE)).unwrap(); + assert!(config.contains(r#"[plugins."linear@openai-curated"]"#)); + assert!(config.contains(r#"[plugins."gmail@openai-curated"]"#)); + assert!(config.contains("enabled = true")); + assert!(config.contains("enabled = false")); + assert!(!config.contains(r#"[plugins."calendar@openai-curated"]"#)); + + let synced_config = load_config(tmp.path(), tmp.path()).await; + let curated_marketplace = manager + .list_marketplaces_for_config(&synced_config, &[]) + .unwrap() + .into_iter() + .find(|marketplace| marketplace.name == OPENAI_CURATED_MARKETPLACE_NAME) + .unwrap(); + assert_eq!( + curated_marketplace + .plugins + .into_iter() + .map(|plugin| (plugin.id, plugin.installed, plugin.enabled)) + .collect::>(), + vec![ + ("linear@openai-curated".to_string(), true, true), + ("gmail@openai-curated".to_string(), true, false), + ("calendar@openai-curated".to_string(), false, false), + ] + ); +} + +#[tokio::test] +async fn sync_plugins_from_remote_ignores_unknown_remote_plugins() { + let tmp = tempfile::tempdir().unwrap(); + let curated_root = curated_plugins_repo_path(tmp.path()); + write_openai_curated_marketplace(&curated_root, &["linear"]); + write_curated_plugin_sha(tmp.path(), TEST_CURATED_PLUGIN_SHA); + write_file( + &tmp.path().join(CONFIG_TOML_FILE), + r#"[features] +plugins = true + +[plugins."linear@openai-curated"] +enabled = false +"#, + ); + + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/backend-api/plugins/list")) + .respond_with(ResponseTemplate::new(200).set_body_string( + r#"[ + {"id":"1","name":"plugin-one","marketplace_name":"openai-curated","version":"1.0.0","enabled":true} +]"#, + )) + .mount(&server) + .await; + + let mut config = load_config(tmp.path(), tmp.path()).await; + config.chatgpt_base_url = format!("{}/backend-api/", server.uri()); + let manager = PluginsManager::new(tmp.path().to_path_buf()); + let result = manager + .sync_plugins_from_remote( + &config, + Some(&CodexAuth::create_dummy_chatgpt_auth_for_testing()), + ) + .await + .unwrap(); + + assert_eq!( + result, + RemotePluginSyncResult { + installed_plugin_ids: Vec::new(), + enabled_plugin_ids: Vec::new(), + disabled_plugin_ids: Vec::new(), + uninstalled_plugin_ids: vec!["linear@openai-curated".to_string()], + } + ); + let config = fs::read_to_string(tmp.path().join(CONFIG_TOML_FILE)).unwrap(); + assert!(!config.contains(r#"[plugins."linear@openai-curated"]"#)); + assert!( + !tmp.path() + .join("plugins/cache/openai-curated/linear") + .exists() + ); +} + +#[tokio::test] +async fn sync_plugins_from_remote_keeps_existing_plugins_when_install_fails() { + let tmp = tempfile::tempdir().unwrap(); + let curated_root = curated_plugins_repo_path(tmp.path()); + write_openai_curated_marketplace(&curated_root, &["linear", "gmail"]); + write_curated_plugin_sha(tmp.path(), TEST_CURATED_PLUGIN_SHA); + fs::remove_dir_all(curated_root.join("plugins/gmail")).unwrap(); + write_plugin( + &tmp.path().join("plugins/cache/openai-curated"), + "linear/local", + "linear", + ); + write_file( + &tmp.path().join(CONFIG_TOML_FILE), + r#"[features] +plugins = true + +[plugins."linear@openai-curated"] +enabled = false +"#, + ); + + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/backend-api/plugins/list")) + .respond_with(ResponseTemplate::new(200).set_body_string( + r#"[ + {"id":"1","name":"gmail","marketplace_name":"openai-curated","version":"1.0.0","enabled":true} +]"#, + )) + .mount(&server) + .await; + + let mut config = load_config(tmp.path(), tmp.path()).await; + config.chatgpt_base_url = format!("{}/backend-api/", server.uri()); + let manager = PluginsManager::new(tmp.path().to_path_buf()); + let err = manager + .sync_plugins_from_remote( + &config, + Some(&CodexAuth::create_dummy_chatgpt_auth_for_testing()), + ) + .await + .unwrap_err(); + + assert!(matches!( + err, + PluginRemoteSyncError::Store(PluginStoreError::Invalid(ref message)) + if message.contains("plugin source path is not a directory") + )); + assert!( + tmp.path() + .join("plugins/cache/openai-curated/linear/local") + .is_dir() + ); + assert!( + !tmp.path() + .join("plugins/cache/openai-curated/gmail") + .exists() + ); + + let config = fs::read_to_string(tmp.path().join(CONFIG_TOML_FILE)).unwrap(); + assert!(config.contains(r#"[plugins."linear@openai-curated"]"#)); + assert!(!config.contains(r#"[plugins."gmail@openai-curated"]"#)); + assert!(config.contains("enabled = false")); +} + +#[tokio::test] +async fn sync_plugins_from_remote_uses_first_duplicate_local_plugin_entry() { + let tmp = tempfile::tempdir().unwrap(); + let curated_root = curated_plugins_repo_path(tmp.path()); + write_curated_plugin_sha(tmp.path(), TEST_CURATED_PLUGIN_SHA); + fs::create_dir_all(curated_root.join(".agents/plugins")).unwrap(); + fs::write( + curated_root.join(".agents/plugins/marketplace.json"), + r#"{ + "name": "openai-curated", + "plugins": [ + { + "name": "gmail", + "source": { + "source": "local", + "path": "./plugins/gmail-first" + } + }, + { + "name": "gmail", + "source": { + "source": "local", + "path": "./plugins/gmail-second" + } + } + ] +}"#, + ) + .unwrap(); + write_plugin(&curated_root, "plugins/gmail-first", "gmail"); + write_plugin(&curated_root, "plugins/gmail-second", "gmail"); + fs::write(curated_root.join("plugins/gmail-first/marker.txt"), "first").unwrap(); + fs::write( + curated_root.join("plugins/gmail-second/marker.txt"), + "second", + ) + .unwrap(); + write_file( + &tmp.path().join(CONFIG_TOML_FILE), + r#"[features] +plugins = true +"#, + ); + + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/backend-api/plugins/list")) + .respond_with(ResponseTemplate::new(200).set_body_string( + r#"[ + {"id":"1","name":"gmail","marketplace_name":"openai-curated","version":"1.0.0","enabled":true} +]"#, + )) + .mount(&server) + .await; + + let mut config = load_config(tmp.path(), tmp.path()).await; + config.chatgpt_base_url = format!("{}/backend-api/", server.uri()); + let manager = PluginsManager::new(tmp.path().to_path_buf()); + let result = manager + .sync_plugins_from_remote( + &config, + Some(&CodexAuth::create_dummy_chatgpt_auth_for_testing()), + ) + .await + .unwrap(); + + assert_eq!( + result, + RemotePluginSyncResult { + installed_plugin_ids: vec!["gmail@openai-curated".to_string()], + enabled_plugin_ids: vec!["gmail@openai-curated".to_string()], + disabled_plugin_ids: Vec::new(), + uninstalled_plugin_ids: Vec::new(), + } + ); + assert_eq!( + fs::read_to_string(tmp.path().join(format!( + "plugins/cache/openai-curated/gmail/{TEST_CURATED_PLUGIN_SHA}/marker.txt" + ))) + .unwrap(), + "first" + ); +} + +#[test] +fn refresh_curated_plugin_cache_replaces_existing_local_version_with_sha() { + let tmp = tempfile::tempdir().unwrap(); + let curated_root = curated_plugins_repo_path(tmp.path()); + write_openai_curated_marketplace(&curated_root, &["slack"]); + write_curated_plugin_sha(tmp.path(), TEST_CURATED_PLUGIN_SHA); + let plugin_id = PluginId::new( + "slack".to_string(), + OPENAI_CURATED_MARKETPLACE_NAME.to_string(), + ) + .unwrap(); + write_plugin( + &tmp.path().join("plugins/cache/openai-curated"), + "slack/local", + "slack", + ); + + assert!( + refresh_curated_plugin_cache(tmp.path(), TEST_CURATED_PLUGIN_SHA, &[plugin_id]) + .expect("cache refresh should succeed") + ); + + assert!( + !tmp.path() + .join("plugins/cache/openai-curated/slack/local") + .exists() + ); + assert!( + tmp.path() + .join(format!( + "plugins/cache/openai-curated/slack/{TEST_CURATED_PLUGIN_SHA}" + )) + .is_dir() + ); +} + +#[test] +fn refresh_curated_plugin_cache_reinstalls_missing_configured_plugin_with_current_sha() { + let tmp = tempfile::tempdir().unwrap(); + let curated_root = curated_plugins_repo_path(tmp.path()); + write_openai_curated_marketplace(&curated_root, &["slack"]); + write_curated_plugin_sha(tmp.path(), TEST_CURATED_PLUGIN_SHA); + let plugin_id = PluginId::new( + "slack".to_string(), + OPENAI_CURATED_MARKETPLACE_NAME.to_string(), + ) + .unwrap(); + + assert!( + refresh_curated_plugin_cache(tmp.path(), TEST_CURATED_PLUGIN_SHA, &[plugin_id]) + .expect("cache refresh should recreate missing configured plugin") + ); + + assert!( + tmp.path() + .join(format!( + "plugins/cache/openai-curated/slack/{TEST_CURATED_PLUGIN_SHA}" + )) + .is_dir() + ); +} + +#[test] +fn refresh_curated_plugin_cache_returns_false_when_configured_plugins_are_current() { + let tmp = tempfile::tempdir().unwrap(); + let curated_root = curated_plugins_repo_path(tmp.path()); + write_openai_curated_marketplace(&curated_root, &["slack"]); + let plugin_id = PluginId::new( + "slack".to_string(), + OPENAI_CURATED_MARKETPLACE_NAME.to_string(), + ) + .unwrap(); + write_plugin( + &tmp.path().join("plugins/cache/openai-curated"), + &format!("slack/{TEST_CURATED_PLUGIN_SHA}"), + "slack", + ); + + assert!( + !refresh_curated_plugin_cache(tmp.path(), TEST_CURATED_PLUGIN_SHA, &[plugin_id]) + .expect("cache refresh should be a no-op when configured plugins are current") + ); +} + +#[test] +fn load_plugins_ignores_project_config_files() { + let codex_home = TempDir::new().unwrap(); + let project_root = codex_home.path().join("project"); + let plugin_root = codex_home + .path() + .join("plugins/cache") + .join("test/sample/local"); + + write_file( + &plugin_root.join(".codex-plugin/plugin.json"), + r#"{"name":"sample"}"#, + ); + write_file( + &project_root.join(".codex/config.toml"), + &plugin_config_toml(true, true), + ); + + let stack = ConfigLayerStack::new( + vec![ConfigLayerEntry::new( + ConfigLayerSource::Project { + dot_codex_folder: AbsolutePathBuf::try_from(project_root.join(".codex")).unwrap(), + }, + toml::from_str(&plugin_config_toml(true, true)).expect("project config should parse"), + )], + ConfigRequirements::default(), + ConfigRequirementsToml::default(), + ) + .expect("config layer stack should build"); + + let outcome = PluginsManager::new(codex_home.path().to_path_buf()).plugins_for_layer_stack( + &project_root, + &stack, + false, + ); + + assert_eq!(outcome, PluginLoadOutcome::default()); +} diff --git a/codex-rs/core/src/plugins/marketplace.rs b/codex-rs/core/src/plugins/marketplace.rs index cac6d1b027..622ed9dd17 100644 --- a/codex-rs/core/src/plugins/marketplace.rs +++ b/codex-rs/core/src/plugins/marketplace.rs @@ -390,582 +390,5 @@ enum MarketplacePluginSource { } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - use tempfile::tempdir; - - #[test] - fn resolve_marketplace_plugin_finds_repo_marketplace_plugin() { - let tmp = tempdir().unwrap(); - let repo_root = tmp.path().join("repo"); - fs::create_dir_all(repo_root.join(".git")).unwrap(); - fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); - fs::create_dir_all(repo_root.join("nested")).unwrap(); - fs::write( - repo_root.join(".agents/plugins/marketplace.json"), - r#"{ - "name": "codex-curated", - "plugins": [ - { - "name": "local-plugin", - "source": { - "source": "local", - "path": "./plugin-1" - } - } - ] -}"#, - ) - .unwrap(); - - let resolved = resolve_marketplace_plugin( - &AbsolutePathBuf::try_from(repo_root.join(".agents/plugins/marketplace.json")).unwrap(), - "local-plugin", - ) - .unwrap(); - - assert_eq!( - resolved, - ResolvedMarketplacePlugin { - plugin_id: PluginId::new("local-plugin".to_string(), "codex-curated".to_string()) - .unwrap(), - source_path: AbsolutePathBuf::try_from(repo_root.join("plugin-1")).unwrap(), - auth_policy: MarketplacePluginAuthPolicy::OnInstall, - } - ); - } - - #[test] - fn resolve_marketplace_plugin_reports_missing_plugin() { - let tmp = tempdir().unwrap(); - let repo_root = tmp.path().join("repo"); - fs::create_dir_all(repo_root.join(".git")).unwrap(); - fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); - fs::write( - repo_root.join(".agents/plugins/marketplace.json"), - r#"{"name":"codex-curated","plugins":[]}"#, - ) - .unwrap(); - - let err = resolve_marketplace_plugin( - &AbsolutePathBuf::try_from(repo_root.join(".agents/plugins/marketplace.json")).unwrap(), - "missing", - ) - .unwrap_err(); - - assert_eq!( - err.to_string(), - "plugin `missing` was not found in marketplace `codex-curated`" - ); - } - - #[test] - fn list_marketplaces_returns_home_and_repo_marketplaces() { - let tmp = tempdir().unwrap(); - let home_root = tmp.path().join("home"); - let repo_root = tmp.path().join("repo"); - - fs::create_dir_all(repo_root.join(".git")).unwrap(); - fs::create_dir_all(home_root.join(".agents/plugins")).unwrap(); - fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); - fs::write( - home_root.join(".agents/plugins/marketplace.json"), - r#"{ - "name": "codex-curated", - "plugins": [ - { - "name": "shared-plugin", - "source": { - "source": "local", - "path": "./home-shared" - } - }, - { - "name": "home-only", - "source": { - "source": "local", - "path": "./home-only" - } - } - ] -}"#, - ) - .unwrap(); - fs::write( - repo_root.join(".agents/plugins/marketplace.json"), - r#"{ - "name": "codex-curated", - "plugins": [ - { - "name": "shared-plugin", - "source": { - "source": "local", - "path": "./repo-shared" - } - }, - { - "name": "repo-only", - "source": { - "source": "local", - "path": "./repo-only" - } - } - ] -}"#, - ) - .unwrap(); - - let marketplaces = list_marketplaces_with_home( - &[AbsolutePathBuf::try_from(repo_root.clone()).unwrap()], - Some(&home_root), - ) - .unwrap(); - - assert_eq!( - marketplaces, - vec![ - MarketplaceSummary { - name: "codex-curated".to_string(), - path: AbsolutePathBuf::try_from( - home_root.join(".agents/plugins/marketplace.json"), - ) - .unwrap(), - plugins: vec![ - MarketplacePluginSummary { - name: "shared-plugin".to_string(), - source: MarketplacePluginSourceSummary::Local { - path: AbsolutePathBuf::try_from(home_root.join("home-shared")) - .unwrap(), - }, - install_policy: MarketplacePluginInstallPolicy::Available, - auth_policy: MarketplacePluginAuthPolicy::OnInstall, - interface: None, - }, - MarketplacePluginSummary { - name: "home-only".to_string(), - source: MarketplacePluginSourceSummary::Local { - path: AbsolutePathBuf::try_from(home_root.join("home-only")) - .unwrap(), - }, - install_policy: MarketplacePluginInstallPolicy::Available, - auth_policy: MarketplacePluginAuthPolicy::OnInstall, - interface: None, - }, - ], - }, - MarketplaceSummary { - name: "codex-curated".to_string(), - path: AbsolutePathBuf::try_from( - repo_root.join(".agents/plugins/marketplace.json"), - ) - .unwrap(), - plugins: vec![ - MarketplacePluginSummary { - name: "shared-plugin".to_string(), - source: MarketplacePluginSourceSummary::Local { - path: AbsolutePathBuf::try_from(repo_root.join("repo-shared")) - .unwrap(), - }, - install_policy: MarketplacePluginInstallPolicy::Available, - auth_policy: MarketplacePluginAuthPolicy::OnInstall, - interface: None, - }, - MarketplacePluginSummary { - name: "repo-only".to_string(), - source: MarketplacePluginSourceSummary::Local { - path: AbsolutePathBuf::try_from(repo_root.join("repo-only")) - .unwrap(), - }, - install_policy: MarketplacePluginInstallPolicy::Available, - auth_policy: MarketplacePluginAuthPolicy::OnInstall, - interface: None, - }, - ], - }, - ] - ); - } - - #[test] - fn list_marketplaces_keeps_distinct_entries_for_same_name() { - let tmp = tempdir().unwrap(); - let home_root = tmp.path().join("home"); - let repo_root = tmp.path().join("repo"); - let home_marketplace = home_root.join(".agents/plugins/marketplace.json"); - let repo_marketplace = repo_root.join(".agents/plugins/marketplace.json"); - - fs::create_dir_all(repo_root.join(".git")).unwrap(); - fs::create_dir_all(home_root.join(".agents/plugins")).unwrap(); - fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); - - fs::write( - home_marketplace.clone(), - r#"{ - "name": "codex-curated", - "plugins": [ - { - "name": "local-plugin", - "source": { - "source": "local", - "path": "./home-plugin" - } - } - ] -}"#, - ) - .unwrap(); - fs::write( - repo_marketplace.clone(), - r#"{ - "name": "codex-curated", - "plugins": [ - { - "name": "local-plugin", - "source": { - "source": "local", - "path": "./repo-plugin" - } - } - ] -}"#, - ) - .unwrap(); - - let marketplaces = list_marketplaces_with_home( - &[AbsolutePathBuf::try_from(repo_root.clone()).unwrap()], - Some(&home_root), - ) - .unwrap(); - - assert_eq!( - marketplaces, - vec![ - MarketplaceSummary { - name: "codex-curated".to_string(), - path: AbsolutePathBuf::try_from(home_marketplace).unwrap(), - plugins: vec![MarketplacePluginSummary { - name: "local-plugin".to_string(), - source: MarketplacePluginSourceSummary::Local { - path: AbsolutePathBuf::try_from(home_root.join("home-plugin")).unwrap(), - }, - install_policy: MarketplacePluginInstallPolicy::Available, - auth_policy: MarketplacePluginAuthPolicy::OnInstall, - interface: None, - }], - }, - MarketplaceSummary { - name: "codex-curated".to_string(), - path: AbsolutePathBuf::try_from(repo_marketplace.clone()).unwrap(), - plugins: vec![MarketplacePluginSummary { - name: "local-plugin".to_string(), - source: MarketplacePluginSourceSummary::Local { - path: AbsolutePathBuf::try_from(repo_root.join("repo-plugin")).unwrap(), - }, - install_policy: MarketplacePluginInstallPolicy::Available, - auth_policy: MarketplacePluginAuthPolicy::OnInstall, - interface: None, - }], - }, - ] - ); - - let resolved = resolve_marketplace_plugin( - &AbsolutePathBuf::try_from(repo_marketplace).unwrap(), - "local-plugin", - ) - .unwrap(); - - assert_eq!( - resolved.source_path, - AbsolutePathBuf::try_from(repo_root.join("repo-plugin")).unwrap() - ); - } - - #[test] - fn list_marketplaces_dedupes_multiple_roots_in_same_repo() { - let tmp = tempdir().unwrap(); - let repo_root = tmp.path().join("repo"); - let nested_root = repo_root.join("nested/project"); - - fs::create_dir_all(repo_root.join(".git")).unwrap(); - fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); - fs::create_dir_all(&nested_root).unwrap(); - fs::write( - repo_root.join(".agents/plugins/marketplace.json"), - r#"{ - "name": "codex-curated", - "plugins": [ - { - "name": "local-plugin", - "source": { - "source": "local", - "path": "./plugin" - } - } - ] -}"#, - ) - .unwrap(); - - let marketplaces = list_marketplaces_with_home( - &[ - AbsolutePathBuf::try_from(repo_root.clone()).unwrap(), - AbsolutePathBuf::try_from(nested_root).unwrap(), - ], - None, - ) - .unwrap(); - - assert_eq!( - marketplaces, - vec![MarketplaceSummary { - name: "codex-curated".to_string(), - path: AbsolutePathBuf::try_from(repo_root.join(".agents/plugins/marketplace.json")) - .unwrap(), - plugins: vec![MarketplacePluginSummary { - name: "local-plugin".to_string(), - source: MarketplacePluginSourceSummary::Local { - path: AbsolutePathBuf::try_from(repo_root.join("plugin")).unwrap(), - }, - install_policy: MarketplacePluginInstallPolicy::Available, - auth_policy: MarketplacePluginAuthPolicy::OnInstall, - interface: None, - }], - }] - ); - } - - #[test] - fn list_marketplaces_resolves_plugin_interface_paths_to_absolute() { - let tmp = tempdir().unwrap(); - let repo_root = tmp.path().join("repo"); - let plugin_root = repo_root.join("plugins/demo-plugin"); - fs::create_dir_all(repo_root.join(".git")).unwrap(); - fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); - fs::create_dir_all(plugin_root.join(".codex-plugin")).unwrap(); - fs::write( - repo_root.join(".agents/plugins/marketplace.json"), - r#"{ - "name": "codex-curated", - "plugins": [ - { - "name": "demo-plugin", - "source": { - "source": "local", - "path": "./plugins/demo-plugin" - }, - "installPolicy": "AVAILABLE", - "authPolicy": "ON_INSTALL", - "category": "Design" - } - ] -}"#, - ) - .unwrap(); - fs::write( - plugin_root.join(".codex-plugin/plugin.json"), - r#"{ - "name": "demo-plugin", - "interface": { - "displayName": "Demo", - "category": "Productivity", - "capabilities": ["Interactive", "Write"], - "composerIcon": "./assets/icon.png", - "logo": "./assets/logo.png", - "screenshots": ["./assets/shot1.png"] - } -}"#, - ) - .unwrap(); - - let marketplaces = - list_marketplaces_with_home(&[AbsolutePathBuf::try_from(repo_root).unwrap()], None) - .unwrap(); - - assert_eq!( - marketplaces[0].plugins[0].install_policy, - MarketplacePluginInstallPolicy::Available - ); - assert_eq!( - marketplaces[0].plugins[0].auth_policy, - MarketplacePluginAuthPolicy::OnInstall - ); - assert_eq!( - marketplaces[0].plugins[0].interface, - Some(PluginManifestInterfaceSummary { - display_name: Some("Demo".to_string()), - short_description: None, - long_description: None, - developer_name: None, - category: Some("Design".to_string()), - capabilities: vec!["Interactive".to_string(), "Write".to_string()], - website_url: None, - privacy_policy_url: None, - terms_of_service_url: None, - default_prompt: None, - brand_color: None, - composer_icon: Some( - AbsolutePathBuf::try_from(plugin_root.join("assets/icon.png")).unwrap(), - ), - logo: Some(AbsolutePathBuf::try_from(plugin_root.join("assets/logo.png")).unwrap()), - screenshots: vec![ - AbsolutePathBuf::try_from(plugin_root.join("assets/shot1.png")).unwrap(), - ], - }) - ); - } - - #[test] - fn list_marketplaces_ignores_plugin_interface_assets_without_dot_slash() { - let tmp = tempdir().unwrap(); - let repo_root = tmp.path().join("repo"); - let plugin_root = repo_root.join("plugins/demo-plugin"); - - fs::create_dir_all(repo_root.join(".git")).unwrap(); - fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); - fs::create_dir_all(plugin_root.join(".codex-plugin")).unwrap(); - fs::write( - repo_root.join(".agents/plugins/marketplace.json"), - r#"{ - "name": "codex-curated", - "plugins": [ - { - "name": "demo-plugin", - "source": { - "source": "local", - "path": "./plugins/demo-plugin" - } - } - ] -}"#, - ) - .unwrap(); - fs::write( - plugin_root.join(".codex-plugin/plugin.json"), - r#"{ - "name": "demo-plugin", - "interface": { - "displayName": "Demo", - "capabilities": ["Interactive"], - "composerIcon": "assets/icon.png", - "logo": "/tmp/logo.png", - "screenshots": ["assets/shot1.png"] - } -}"#, - ) - .unwrap(); - - let marketplaces = - list_marketplaces_with_home(&[AbsolutePathBuf::try_from(repo_root).unwrap()], None) - .unwrap(); - - assert_eq!( - marketplaces[0].plugins[0].interface, - Some(PluginManifestInterfaceSummary { - display_name: Some("Demo".to_string()), - short_description: None, - long_description: None, - developer_name: None, - category: None, - capabilities: vec!["Interactive".to_string()], - website_url: None, - privacy_policy_url: None, - terms_of_service_url: None, - default_prompt: None, - brand_color: None, - composer_icon: None, - logo: None, - screenshots: Vec::new(), - }) - ); - assert_eq!( - marketplaces[0].plugins[0].install_policy, - MarketplacePluginInstallPolicy::Available - ); - assert_eq!( - marketplaces[0].plugins[0].auth_policy, - MarketplacePluginAuthPolicy::OnInstall - ); - } - - #[test] - fn resolve_marketplace_plugin_rejects_non_relative_local_paths() { - let tmp = tempdir().unwrap(); - let repo_root = tmp.path().join("repo"); - fs::create_dir_all(repo_root.join(".git")).unwrap(); - fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); - fs::write( - repo_root.join(".agents/plugins/marketplace.json"), - r#"{ - "name": "codex-curated", - "plugins": [ - { - "name": "local-plugin", - "source": { - "source": "local", - "path": "../plugin-1" - } - } - ] -}"#, - ) - .unwrap(); - - let err = resolve_marketplace_plugin( - &AbsolutePathBuf::try_from(repo_root.join(".agents/plugins/marketplace.json")).unwrap(), - "local-plugin", - ) - .unwrap_err(); - - assert_eq!( - err.to_string(), - format!( - "invalid marketplace file `{}`: local plugin source path must start with `./`", - repo_root.join(".agents/plugins/marketplace.json").display() - ) - ); - } - - #[test] - fn resolve_marketplace_plugin_uses_first_duplicate_entry() { - let tmp = tempdir().unwrap(); - let repo_root = tmp.path().join("repo"); - fs::create_dir_all(repo_root.join(".git")).unwrap(); - fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); - fs::write( - repo_root.join(".agents/plugins/marketplace.json"), - r#"{ - "name": "codex-curated", - "plugins": [ - { - "name": "local-plugin", - "source": { - "source": "local", - "path": "./first" - } - }, - { - "name": "local-plugin", - "source": { - "source": "local", - "path": "./second" - } - } - ] -}"#, - ) - .unwrap(); - - let resolved = resolve_marketplace_plugin( - &AbsolutePathBuf::try_from(repo_root.join(".agents/plugins/marketplace.json")).unwrap(), - "local-plugin", - ) - .unwrap(); - - assert_eq!( - resolved.source_path, - AbsolutePathBuf::try_from(repo_root.join("first")).unwrap() - ); - } -} +#[path = "marketplace_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/plugins/marketplace_tests.rs b/codex-rs/core/src/plugins/marketplace_tests.rs new file mode 100644 index 0000000000..b6b5050e8f --- /dev/null +++ b/codex-rs/core/src/plugins/marketplace_tests.rs @@ -0,0 +1,571 @@ +use super::*; +use pretty_assertions::assert_eq; +use tempfile::tempdir; + +#[test] +fn resolve_marketplace_plugin_finds_repo_marketplace_plugin() { + let tmp = tempdir().unwrap(); + let repo_root = tmp.path().join("repo"); + fs::create_dir_all(repo_root.join(".git")).unwrap(); + fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); + fs::create_dir_all(repo_root.join("nested")).unwrap(); + fs::write( + repo_root.join(".agents/plugins/marketplace.json"), + r#"{ + "name": "codex-curated", + "plugins": [ + { + "name": "local-plugin", + "source": { + "source": "local", + "path": "./plugin-1" + } + } + ] +}"#, + ) + .unwrap(); + + let resolved = resolve_marketplace_plugin( + &AbsolutePathBuf::try_from(repo_root.join(".agents/plugins/marketplace.json")).unwrap(), + "local-plugin", + ) + .unwrap(); + + assert_eq!( + resolved, + ResolvedMarketplacePlugin { + plugin_id: PluginId::new("local-plugin".to_string(), "codex-curated".to_string()) + .unwrap(), + source_path: AbsolutePathBuf::try_from(repo_root.join("plugin-1")).unwrap(), + auth_policy: MarketplacePluginAuthPolicy::OnInstall, + } + ); +} + +#[test] +fn resolve_marketplace_plugin_reports_missing_plugin() { + let tmp = tempdir().unwrap(); + let repo_root = tmp.path().join("repo"); + fs::create_dir_all(repo_root.join(".git")).unwrap(); + fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); + fs::write( + repo_root.join(".agents/plugins/marketplace.json"), + r#"{"name":"codex-curated","plugins":[]}"#, + ) + .unwrap(); + + let err = resolve_marketplace_plugin( + &AbsolutePathBuf::try_from(repo_root.join(".agents/plugins/marketplace.json")).unwrap(), + "missing", + ) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "plugin `missing` was not found in marketplace `codex-curated`" + ); +} + +#[test] +fn list_marketplaces_returns_home_and_repo_marketplaces() { + let tmp = tempdir().unwrap(); + let home_root = tmp.path().join("home"); + let repo_root = tmp.path().join("repo"); + + fs::create_dir_all(repo_root.join(".git")).unwrap(); + fs::create_dir_all(home_root.join(".agents/plugins")).unwrap(); + fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); + fs::write( + home_root.join(".agents/plugins/marketplace.json"), + r#"{ + "name": "codex-curated", + "plugins": [ + { + "name": "shared-plugin", + "source": { + "source": "local", + "path": "./home-shared" + } + }, + { + "name": "home-only", + "source": { + "source": "local", + "path": "./home-only" + } + } + ] +}"#, + ) + .unwrap(); + fs::write( + repo_root.join(".agents/plugins/marketplace.json"), + r#"{ + "name": "codex-curated", + "plugins": [ + { + "name": "shared-plugin", + "source": { + "source": "local", + "path": "./repo-shared" + } + }, + { + "name": "repo-only", + "source": { + "source": "local", + "path": "./repo-only" + } + } + ] +}"#, + ) + .unwrap(); + + let marketplaces = list_marketplaces_with_home( + &[AbsolutePathBuf::try_from(repo_root.clone()).unwrap()], + Some(&home_root), + ) + .unwrap(); + + assert_eq!( + marketplaces, + vec![ + MarketplaceSummary { + name: "codex-curated".to_string(), + path: + AbsolutePathBuf::try_from(home_root.join(".agents/plugins/marketplace.json"),) + .unwrap(), + plugins: vec![ + MarketplacePluginSummary { + name: "shared-plugin".to_string(), + source: MarketplacePluginSourceSummary::Local { + path: AbsolutePathBuf::try_from(home_root.join("home-shared")).unwrap(), + }, + install_policy: MarketplacePluginInstallPolicy::Available, + auth_policy: MarketplacePluginAuthPolicy::OnInstall, + interface: None, + }, + MarketplacePluginSummary { + name: "home-only".to_string(), + source: MarketplacePluginSourceSummary::Local { + path: AbsolutePathBuf::try_from(home_root.join("home-only")).unwrap(), + }, + install_policy: MarketplacePluginInstallPolicy::Available, + auth_policy: MarketplacePluginAuthPolicy::OnInstall, + interface: None, + }, + ], + }, + MarketplaceSummary { + name: "codex-curated".to_string(), + path: + AbsolutePathBuf::try_from(repo_root.join(".agents/plugins/marketplace.json"),) + .unwrap(), + plugins: vec![ + MarketplacePluginSummary { + name: "shared-plugin".to_string(), + source: MarketplacePluginSourceSummary::Local { + path: AbsolutePathBuf::try_from(repo_root.join("repo-shared")).unwrap(), + }, + install_policy: MarketplacePluginInstallPolicy::Available, + auth_policy: MarketplacePluginAuthPolicy::OnInstall, + interface: None, + }, + MarketplacePluginSummary { + name: "repo-only".to_string(), + source: MarketplacePluginSourceSummary::Local { + path: AbsolutePathBuf::try_from(repo_root.join("repo-only")).unwrap(), + }, + install_policy: MarketplacePluginInstallPolicy::Available, + auth_policy: MarketplacePluginAuthPolicy::OnInstall, + interface: None, + }, + ], + }, + ] + ); +} + +#[test] +fn list_marketplaces_keeps_distinct_entries_for_same_name() { + let tmp = tempdir().unwrap(); + let home_root = tmp.path().join("home"); + let repo_root = tmp.path().join("repo"); + let home_marketplace = home_root.join(".agents/plugins/marketplace.json"); + let repo_marketplace = repo_root.join(".agents/plugins/marketplace.json"); + + fs::create_dir_all(repo_root.join(".git")).unwrap(); + fs::create_dir_all(home_root.join(".agents/plugins")).unwrap(); + fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); + + fs::write( + home_marketplace.clone(), + r#"{ + "name": "codex-curated", + "plugins": [ + { + "name": "local-plugin", + "source": { + "source": "local", + "path": "./home-plugin" + } + } + ] +}"#, + ) + .unwrap(); + fs::write( + repo_marketplace.clone(), + r#"{ + "name": "codex-curated", + "plugins": [ + { + "name": "local-plugin", + "source": { + "source": "local", + "path": "./repo-plugin" + } + } + ] +}"#, + ) + .unwrap(); + + let marketplaces = list_marketplaces_with_home( + &[AbsolutePathBuf::try_from(repo_root.clone()).unwrap()], + Some(&home_root), + ) + .unwrap(); + + assert_eq!( + marketplaces, + vec![ + MarketplaceSummary { + name: "codex-curated".to_string(), + path: AbsolutePathBuf::try_from(home_marketplace).unwrap(), + plugins: vec![MarketplacePluginSummary { + name: "local-plugin".to_string(), + source: MarketplacePluginSourceSummary::Local { + path: AbsolutePathBuf::try_from(home_root.join("home-plugin")).unwrap(), + }, + install_policy: MarketplacePluginInstallPolicy::Available, + auth_policy: MarketplacePluginAuthPolicy::OnInstall, + interface: None, + }], + }, + MarketplaceSummary { + name: "codex-curated".to_string(), + path: AbsolutePathBuf::try_from(repo_marketplace.clone()).unwrap(), + plugins: vec![MarketplacePluginSummary { + name: "local-plugin".to_string(), + source: MarketplacePluginSourceSummary::Local { + path: AbsolutePathBuf::try_from(repo_root.join("repo-plugin")).unwrap(), + }, + install_policy: MarketplacePluginInstallPolicy::Available, + auth_policy: MarketplacePluginAuthPolicy::OnInstall, + interface: None, + }], + }, + ] + ); + + let resolved = resolve_marketplace_plugin( + &AbsolutePathBuf::try_from(repo_marketplace).unwrap(), + "local-plugin", + ) + .unwrap(); + + assert_eq!( + resolved.source_path, + AbsolutePathBuf::try_from(repo_root.join("repo-plugin")).unwrap() + ); +} + +#[test] +fn list_marketplaces_dedupes_multiple_roots_in_same_repo() { + let tmp = tempdir().unwrap(); + let repo_root = tmp.path().join("repo"); + let nested_root = repo_root.join("nested/project"); + + fs::create_dir_all(repo_root.join(".git")).unwrap(); + fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); + fs::create_dir_all(&nested_root).unwrap(); + fs::write( + repo_root.join(".agents/plugins/marketplace.json"), + r#"{ + "name": "codex-curated", + "plugins": [ + { + "name": "local-plugin", + "source": { + "source": "local", + "path": "./plugin" + } + } + ] +}"#, + ) + .unwrap(); + + let marketplaces = list_marketplaces_with_home( + &[ + AbsolutePathBuf::try_from(repo_root.clone()).unwrap(), + AbsolutePathBuf::try_from(nested_root).unwrap(), + ], + None, + ) + .unwrap(); + + assert_eq!( + marketplaces, + vec![MarketplaceSummary { + name: "codex-curated".to_string(), + path: AbsolutePathBuf::try_from(repo_root.join(".agents/plugins/marketplace.json")) + .unwrap(), + plugins: vec![MarketplacePluginSummary { + name: "local-plugin".to_string(), + source: MarketplacePluginSourceSummary::Local { + path: AbsolutePathBuf::try_from(repo_root.join("plugin")).unwrap(), + }, + install_policy: MarketplacePluginInstallPolicy::Available, + auth_policy: MarketplacePluginAuthPolicy::OnInstall, + interface: None, + }], + }] + ); +} + +#[test] +fn list_marketplaces_resolves_plugin_interface_paths_to_absolute() { + let tmp = tempdir().unwrap(); + let repo_root = tmp.path().join("repo"); + let plugin_root = repo_root.join("plugins/demo-plugin"); + fs::create_dir_all(repo_root.join(".git")).unwrap(); + fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); + fs::create_dir_all(plugin_root.join(".codex-plugin")).unwrap(); + fs::write( + repo_root.join(".agents/plugins/marketplace.json"), + r#"{ + "name": "codex-curated", + "plugins": [ + { + "name": "demo-plugin", + "source": { + "source": "local", + "path": "./plugins/demo-plugin" + }, + "installPolicy": "AVAILABLE", + "authPolicy": "ON_INSTALL", + "category": "Design" + } + ] +}"#, + ) + .unwrap(); + fs::write( + plugin_root.join(".codex-plugin/plugin.json"), + r#"{ + "name": "demo-plugin", + "interface": { + "displayName": "Demo", + "category": "Productivity", + "capabilities": ["Interactive", "Write"], + "composerIcon": "./assets/icon.png", + "logo": "./assets/logo.png", + "screenshots": ["./assets/shot1.png"] + } +}"#, + ) + .unwrap(); + + let marketplaces = + list_marketplaces_with_home(&[AbsolutePathBuf::try_from(repo_root).unwrap()], None) + .unwrap(); + + assert_eq!( + marketplaces[0].plugins[0].install_policy, + MarketplacePluginInstallPolicy::Available + ); + assert_eq!( + marketplaces[0].plugins[0].auth_policy, + MarketplacePluginAuthPolicy::OnInstall + ); + assert_eq!( + marketplaces[0].plugins[0].interface, + Some(PluginManifestInterfaceSummary { + display_name: Some("Demo".to_string()), + short_description: None, + long_description: None, + developer_name: None, + category: Some("Design".to_string()), + capabilities: vec!["Interactive".to_string(), "Write".to_string()], + website_url: None, + privacy_policy_url: None, + terms_of_service_url: None, + default_prompt: None, + brand_color: None, + composer_icon: Some( + AbsolutePathBuf::try_from(plugin_root.join("assets/icon.png")).unwrap(), + ), + logo: Some(AbsolutePathBuf::try_from(plugin_root.join("assets/logo.png")).unwrap()), + screenshots: vec![ + AbsolutePathBuf::try_from(plugin_root.join("assets/shot1.png")).unwrap(), + ], + }) + ); +} + +#[test] +fn list_marketplaces_ignores_plugin_interface_assets_without_dot_slash() { + let tmp = tempdir().unwrap(); + let repo_root = tmp.path().join("repo"); + let plugin_root = repo_root.join("plugins/demo-plugin"); + + fs::create_dir_all(repo_root.join(".git")).unwrap(); + fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); + fs::create_dir_all(plugin_root.join(".codex-plugin")).unwrap(); + fs::write( + repo_root.join(".agents/plugins/marketplace.json"), + r#"{ + "name": "codex-curated", + "plugins": [ + { + "name": "demo-plugin", + "source": { + "source": "local", + "path": "./plugins/demo-plugin" + } + } + ] +}"#, + ) + .unwrap(); + fs::write( + plugin_root.join(".codex-plugin/plugin.json"), + r#"{ + "name": "demo-plugin", + "interface": { + "displayName": "Demo", + "capabilities": ["Interactive"], + "composerIcon": "assets/icon.png", + "logo": "/tmp/logo.png", + "screenshots": ["assets/shot1.png"] + } +}"#, + ) + .unwrap(); + + let marketplaces = + list_marketplaces_with_home(&[AbsolutePathBuf::try_from(repo_root).unwrap()], None) + .unwrap(); + + assert_eq!( + marketplaces[0].plugins[0].interface, + Some(PluginManifestInterfaceSummary { + display_name: Some("Demo".to_string()), + short_description: None, + long_description: None, + developer_name: None, + category: None, + capabilities: vec!["Interactive".to_string()], + website_url: None, + privacy_policy_url: None, + terms_of_service_url: None, + default_prompt: None, + brand_color: None, + composer_icon: None, + logo: None, + screenshots: Vec::new(), + }) + ); + assert_eq!( + marketplaces[0].plugins[0].install_policy, + MarketplacePluginInstallPolicy::Available + ); + assert_eq!( + marketplaces[0].plugins[0].auth_policy, + MarketplacePluginAuthPolicy::OnInstall + ); +} + +#[test] +fn resolve_marketplace_plugin_rejects_non_relative_local_paths() { + let tmp = tempdir().unwrap(); + let repo_root = tmp.path().join("repo"); + fs::create_dir_all(repo_root.join(".git")).unwrap(); + fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); + fs::write( + repo_root.join(".agents/plugins/marketplace.json"), + r#"{ + "name": "codex-curated", + "plugins": [ + { + "name": "local-plugin", + "source": { + "source": "local", + "path": "../plugin-1" + } + } + ] +}"#, + ) + .unwrap(); + + let err = resolve_marketplace_plugin( + &AbsolutePathBuf::try_from(repo_root.join(".agents/plugins/marketplace.json")).unwrap(), + "local-plugin", + ) + .unwrap_err(); + + assert_eq!( + err.to_string(), + format!( + "invalid marketplace file `{}`: local plugin source path must start with `./`", + repo_root.join(".agents/plugins/marketplace.json").display() + ) + ); +} + +#[test] +fn resolve_marketplace_plugin_uses_first_duplicate_entry() { + let tmp = tempdir().unwrap(); + let repo_root = tmp.path().join("repo"); + fs::create_dir_all(repo_root.join(".git")).unwrap(); + fs::create_dir_all(repo_root.join(".agents/plugins")).unwrap(); + fs::write( + repo_root.join(".agents/plugins/marketplace.json"), + r#"{ + "name": "codex-curated", + "plugins": [ + { + "name": "local-plugin", + "source": { + "source": "local", + "path": "./first" + } + }, + { + "name": "local-plugin", + "source": { + "source": "local", + "path": "./second" + } + } + ] +}"#, + ) + .unwrap(); + + let resolved = resolve_marketplace_plugin( + &AbsolutePathBuf::try_from(repo_root.join(".agents/plugins/marketplace.json")).unwrap(), + "local-plugin", + ) + .unwrap(); + + assert_eq!( + resolved.source_path, + AbsolutePathBuf::try_from(repo_root.join("first")).unwrap() + ); +} diff --git a/codex-rs/core/src/plugins/render.rs b/codex-rs/core/src/plugins/render.rs index 1111ea46be..4b7627a48a 100644 --- a/codex-rs/core/src/plugins/render.rs +++ b/codex-rs/core/src/plugins/render.rs @@ -79,12 +79,5 @@ pub(crate) fn render_explicit_plugin_instructions( } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - - #[test] - fn render_plugins_section_returns_none_for_empty_plugins() { - assert_eq!(render_plugins_section(&[]), None); - } -} +#[path = "render_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/plugins/render_tests.rs b/codex-rs/core/src/plugins/render_tests.rs new file mode 100644 index 0000000000..6ca86d0d41 --- /dev/null +++ b/codex-rs/core/src/plugins/render_tests.rs @@ -0,0 +1,7 @@ +use super::*; +use pretty_assertions::assert_eq; + +#[test] +fn render_plugins_section_returns_none_for_empty_plugins() { + assert_eq!(render_plugins_section(&[]), None); +} diff --git a/codex-rs/core/src/plugins/store.rs b/codex-rs/core/src/plugins/store.rs index 806c43f4d8..22452767ce 100644 --- a/codex-rs/core/src/plugins/store.rs +++ b/codex-rs/core/src/plugins/store.rs @@ -342,197 +342,5 @@ fn copy_dir_recursive(source: &Path, target: &Path) -> Result<(), PluginStoreErr } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - use tempfile::tempdir; - - fn write_plugin(root: &Path, dir_name: &str, manifest_name: &str) { - let plugin_root = root.join(dir_name); - fs::create_dir_all(plugin_root.join(".codex-plugin")).unwrap(); - fs::create_dir_all(plugin_root.join("skills")).unwrap(); - fs::write( - plugin_root.join(".codex-plugin/plugin.json"), - format!(r#"{{"name":"{manifest_name}"}}"#), - ) - .unwrap(); - fs::write(plugin_root.join("skills/SKILL.md"), "skill").unwrap(); - fs::write(plugin_root.join(".mcp.json"), r#"{"mcpServers":{}}"#).unwrap(); - } - - #[test] - fn install_copies_plugin_into_default_marketplace() { - let tmp = tempdir().unwrap(); - write_plugin(tmp.path(), "sample-plugin", "sample-plugin"); - let plugin_id = PluginId::new("sample-plugin".to_string(), "debug".to_string()).unwrap(); - - let result = PluginStore::new(tmp.path().to_path_buf()) - .install( - AbsolutePathBuf::try_from(tmp.path().join("sample-plugin")).unwrap(), - plugin_id.clone(), - ) - .unwrap(); - - let installed_path = tmp.path().join("plugins/cache/debug/sample-plugin/local"); - assert_eq!( - result, - PluginInstallResult { - plugin_id, - plugin_version: "local".to_string(), - installed_path: AbsolutePathBuf::try_from(installed_path.clone()).unwrap(), - } - ); - assert!(installed_path.join(".codex-plugin/plugin.json").is_file()); - assert!(installed_path.join("skills/SKILL.md").is_file()); - } - - #[test] - fn install_uses_manifest_name_for_destination_and_key() { - let tmp = tempdir().unwrap(); - write_plugin(tmp.path(), "source-dir", "manifest-name"); - let plugin_id = PluginId::new("manifest-name".to_string(), "market".to_string()).unwrap(); - - let result = PluginStore::new(tmp.path().to_path_buf()) - .install( - AbsolutePathBuf::try_from(tmp.path().join("source-dir")).unwrap(), - plugin_id.clone(), - ) - .unwrap(); - - assert_eq!( - result, - PluginInstallResult { - plugin_id, - plugin_version: "local".to_string(), - installed_path: AbsolutePathBuf::try_from( - tmp.path().join("plugins/cache/market/manifest-name/local"), - ) - .unwrap(), - } - ); - } - - #[test] - fn plugin_root_derives_path_from_key_and_version() { - let tmp = tempdir().unwrap(); - let store = PluginStore::new(tmp.path().to_path_buf()); - let plugin_id = PluginId::new("sample".to_string(), "debug".to_string()).unwrap(); - - assert_eq!( - store.plugin_root(&plugin_id, "local").as_path(), - tmp.path().join("plugins/cache/debug/sample/local") - ); - } - - #[test] - fn install_with_version_uses_requested_cache_version() { - let tmp = tempdir().unwrap(); - write_plugin(tmp.path(), "sample-plugin", "sample-plugin"); - let plugin_id = - PluginId::new("sample-plugin".to_string(), "openai-curated".to_string()).unwrap(); - let plugin_version = "0123456789abcdef".to_string(); - - let result = PluginStore::new(tmp.path().to_path_buf()) - .install_with_version( - AbsolutePathBuf::try_from(tmp.path().join("sample-plugin")).unwrap(), - plugin_id.clone(), - plugin_version.clone(), - ) - .unwrap(); - - let installed_path = tmp.path().join(format!( - "plugins/cache/openai-curated/sample-plugin/{plugin_version}" - )); - assert_eq!( - result, - PluginInstallResult { - plugin_id, - plugin_version, - installed_path: AbsolutePathBuf::try_from(installed_path.clone()).unwrap(), - } - ); - assert!(installed_path.join(".codex-plugin/plugin.json").is_file()); - } - - #[test] - fn active_plugin_version_reads_version_directory_name() { - let tmp = tempdir().unwrap(); - write_plugin( - &tmp.path().join("plugins/cache/debug"), - "sample-plugin/local", - "sample-plugin", - ); - let store = PluginStore::new(tmp.path().to_path_buf()); - let plugin_id = PluginId::new("sample-plugin".to_string(), "debug".to_string()).unwrap(); - - assert_eq!( - store.active_plugin_version(&plugin_id), - Some("local".to_string()) - ); - assert_eq!( - store.active_plugin_root(&plugin_id).unwrap().as_path(), - tmp.path().join("plugins/cache/debug/sample-plugin/local") - ); - } - - #[test] - fn plugin_root_rejects_path_separators_in_key_segments() { - let err = PluginId::parse("../../etc@debug").unwrap_err(); - assert_eq!( - err.to_string(), - "invalid plugin name: only ASCII letters, digits, `_`, and `-` are allowed in `../../etc@debug`" - ); - - let err = PluginId::parse("sample@../../etc").unwrap_err(); - assert_eq!( - err.to_string(), - "invalid marketplace name: only ASCII letters, digits, `_`, and `-` are allowed in `sample@../../etc`" - ); - } - - #[test] - fn install_rejects_manifest_names_with_path_separators() { - let tmp = tempdir().unwrap(); - write_plugin(tmp.path(), "source-dir", "../../etc"); - - let err = PluginStore::new(tmp.path().to_path_buf()) - .install( - AbsolutePathBuf::try_from(tmp.path().join("source-dir")).unwrap(), - PluginId::new("source-dir".to_string(), "debug".to_string()).unwrap(), - ) - .unwrap_err(); - - assert_eq!( - err.to_string(), - "invalid plugin name: only ASCII letters, digits, `_`, and `-` are allowed" - ); - } - - #[test] - fn install_rejects_marketplace_names_with_path_separators() { - let err = PluginId::new("sample-plugin".to_string(), "../../etc".to_string()).unwrap_err(); - - assert_eq!( - err.to_string(), - "invalid marketplace name: only ASCII letters, digits, `_`, and `-` are allowed" - ); - } - - #[test] - fn install_rejects_manifest_names_that_do_not_match_marketplace_plugin_name() { - let tmp = tempdir().unwrap(); - write_plugin(tmp.path(), "source-dir", "manifest-name"); - - let err = PluginStore::new(tmp.path().to_path_buf()) - .install( - AbsolutePathBuf::try_from(tmp.path().join("source-dir")).unwrap(), - PluginId::new("different-name".to_string(), "debug".to_string()).unwrap(), - ) - .unwrap_err(); - - assert_eq!( - err.to_string(), - "plugin manifest name `manifest-name` does not match marketplace plugin name `different-name`" - ); - } -} +#[path = "store_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/plugins/store_tests.rs b/codex-rs/core/src/plugins/store_tests.rs new file mode 100644 index 0000000000..b1da11a8a6 --- /dev/null +++ b/codex-rs/core/src/plugins/store_tests.rs @@ -0,0 +1,192 @@ +use super::*; +use pretty_assertions::assert_eq; +use tempfile::tempdir; + +fn write_plugin(root: &Path, dir_name: &str, manifest_name: &str) { + let plugin_root = root.join(dir_name); + fs::create_dir_all(plugin_root.join(".codex-plugin")).unwrap(); + fs::create_dir_all(plugin_root.join("skills")).unwrap(); + fs::write( + plugin_root.join(".codex-plugin/plugin.json"), + format!(r#"{{"name":"{manifest_name}"}}"#), + ) + .unwrap(); + fs::write(plugin_root.join("skills/SKILL.md"), "skill").unwrap(); + fs::write(plugin_root.join(".mcp.json"), r#"{"mcpServers":{}}"#).unwrap(); +} + +#[test] +fn install_copies_plugin_into_default_marketplace() { + let tmp = tempdir().unwrap(); + write_plugin(tmp.path(), "sample-plugin", "sample-plugin"); + let plugin_id = PluginId::new("sample-plugin".to_string(), "debug".to_string()).unwrap(); + + let result = PluginStore::new(tmp.path().to_path_buf()) + .install( + AbsolutePathBuf::try_from(tmp.path().join("sample-plugin")).unwrap(), + plugin_id.clone(), + ) + .unwrap(); + + let installed_path = tmp.path().join("plugins/cache/debug/sample-plugin/local"); + assert_eq!( + result, + PluginInstallResult { + plugin_id, + plugin_version: "local".to_string(), + installed_path: AbsolutePathBuf::try_from(installed_path.clone()).unwrap(), + } + ); + assert!(installed_path.join(".codex-plugin/plugin.json").is_file()); + assert!(installed_path.join("skills/SKILL.md").is_file()); +} + +#[test] +fn install_uses_manifest_name_for_destination_and_key() { + let tmp = tempdir().unwrap(); + write_plugin(tmp.path(), "source-dir", "manifest-name"); + let plugin_id = PluginId::new("manifest-name".to_string(), "market".to_string()).unwrap(); + + let result = PluginStore::new(tmp.path().to_path_buf()) + .install( + AbsolutePathBuf::try_from(tmp.path().join("source-dir")).unwrap(), + plugin_id.clone(), + ) + .unwrap(); + + assert_eq!( + result, + PluginInstallResult { + plugin_id, + plugin_version: "local".to_string(), + installed_path: AbsolutePathBuf::try_from( + tmp.path().join("plugins/cache/market/manifest-name/local"), + ) + .unwrap(), + } + ); +} + +#[test] +fn plugin_root_derives_path_from_key_and_version() { + let tmp = tempdir().unwrap(); + let store = PluginStore::new(tmp.path().to_path_buf()); + let plugin_id = PluginId::new("sample".to_string(), "debug".to_string()).unwrap(); + + assert_eq!( + store.plugin_root(&plugin_id, "local").as_path(), + tmp.path().join("plugins/cache/debug/sample/local") + ); +} + +#[test] +fn install_with_version_uses_requested_cache_version() { + let tmp = tempdir().unwrap(); + write_plugin(tmp.path(), "sample-plugin", "sample-plugin"); + let plugin_id = + PluginId::new("sample-plugin".to_string(), "openai-curated".to_string()).unwrap(); + let plugin_version = "0123456789abcdef".to_string(); + + let result = PluginStore::new(tmp.path().to_path_buf()) + .install_with_version( + AbsolutePathBuf::try_from(tmp.path().join("sample-plugin")).unwrap(), + plugin_id.clone(), + plugin_version.clone(), + ) + .unwrap(); + + let installed_path = tmp.path().join(format!( + "plugins/cache/openai-curated/sample-plugin/{plugin_version}" + )); + assert_eq!( + result, + PluginInstallResult { + plugin_id, + plugin_version, + installed_path: AbsolutePathBuf::try_from(installed_path.clone()).unwrap(), + } + ); + assert!(installed_path.join(".codex-plugin/plugin.json").is_file()); +} + +#[test] +fn active_plugin_version_reads_version_directory_name() { + let tmp = tempdir().unwrap(); + write_plugin( + &tmp.path().join("plugins/cache/debug"), + "sample-plugin/local", + "sample-plugin", + ); + let store = PluginStore::new(tmp.path().to_path_buf()); + let plugin_id = PluginId::new("sample-plugin".to_string(), "debug".to_string()).unwrap(); + + assert_eq!( + store.active_plugin_version(&plugin_id), + Some("local".to_string()) + ); + assert_eq!( + store.active_plugin_root(&plugin_id).unwrap().as_path(), + tmp.path().join("plugins/cache/debug/sample-plugin/local") + ); +} + +#[test] +fn plugin_root_rejects_path_separators_in_key_segments() { + let err = PluginId::parse("../../etc@debug").unwrap_err(); + assert_eq!( + err.to_string(), + "invalid plugin name: only ASCII letters, digits, `_`, and `-` are allowed in `../../etc@debug`" + ); + + let err = PluginId::parse("sample@../../etc").unwrap_err(); + assert_eq!( + err.to_string(), + "invalid marketplace name: only ASCII letters, digits, `_`, and `-` are allowed in `sample@../../etc`" + ); +} + +#[test] +fn install_rejects_manifest_names_with_path_separators() { + let tmp = tempdir().unwrap(); + write_plugin(tmp.path(), "source-dir", "../../etc"); + + let err = PluginStore::new(tmp.path().to_path_buf()) + .install( + AbsolutePathBuf::try_from(tmp.path().join("source-dir")).unwrap(), + PluginId::new("source-dir".to_string(), "debug".to_string()).unwrap(), + ) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "invalid plugin name: only ASCII letters, digits, `_`, and `-` are allowed" + ); +} + +#[test] +fn install_rejects_marketplace_names_with_path_separators() { + let err = PluginId::new("sample-plugin".to_string(), "../../etc".to_string()).unwrap_err(); + + assert_eq!( + err.to_string(), + "invalid marketplace name: only ASCII letters, digits, `_`, and `-` are allowed" + ); +} + +#[test] +fn install_rejects_manifest_names_that_do_not_match_marketplace_plugin_name() { + let tmp = tempdir().unwrap(); + write_plugin(tmp.path(), "source-dir", "manifest-name"); + + let err = PluginStore::new(tmp.path().to_path_buf()) + .install( + AbsolutePathBuf::try_from(tmp.path().join("source-dir")).unwrap(), + PluginId::new("different-name".to_string(), "debug".to_string()).unwrap(), + ) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "plugin manifest name `manifest-name` does not match marketplace plugin name `different-name`" + ); +} diff --git a/codex-rs/core/src/project_doc.rs b/codex-rs/core/src/project_doc.rs index 958feb4db1..1dc5189821 100644 --- a/codex-rs/core/src/project_doc.rs +++ b/codex-rs/core/src/project_doc.rs @@ -318,483 +318,5 @@ fn candidate_filenames<'a>(config: &'a Config) -> Vec<&'a str> { } #[cfg(test)] -mod tests { - use super::*; - use crate::config::ConfigBuilder; - use crate::features::Feature; - use crate::skills::loader::SkillRoot; - use crate::skills::loader::load_skills_from_roots; - use codex_protocol::protocol::SkillScope; - use std::fs; - use std::path::PathBuf; - use tempfile::TempDir; - - /// Helper that returns a `Config` pointing at `root` and using `limit` as - /// the maximum number of bytes to embed from AGENTS.md. The caller can - /// optionally specify a custom `instructions` string – when `None` the - /// value is cleared to mimic a scenario where no system instructions have - /// been configured. - async fn make_config(root: &TempDir, limit: usize, instructions: Option<&str>) -> Config { - let codex_home = TempDir::new().unwrap(); - let mut config = ConfigBuilder::default() - .codex_home(codex_home.path().to_path_buf()) - .build() - .await - .expect("defaults for test should always succeed"); - - config.cwd = root.path().to_path_buf(); - config.project_doc_max_bytes = limit; - - config.user_instructions = instructions.map(ToOwned::to_owned); - config - } - - async fn make_config_with_fallback( - root: &TempDir, - limit: usize, - instructions: Option<&str>, - fallbacks: &[&str], - ) -> Config { - let mut config = make_config(root, limit, instructions).await; - config.project_doc_fallback_filenames = fallbacks - .iter() - .map(std::string::ToString::to_string) - .collect(); - config - } - - async fn make_config_with_project_root_markers( - root: &TempDir, - limit: usize, - instructions: Option<&str>, - markers: &[&str], - ) -> Config { - let codex_home = TempDir::new().unwrap(); - let cli_overrides = vec![( - "project_root_markers".to_string(), - TomlValue::Array( - markers - .iter() - .map(|marker| TomlValue::String((*marker).to_string())) - .collect(), - ), - )]; - let mut config = ConfigBuilder::default() - .codex_home(codex_home.path().to_path_buf()) - .cli_overrides(cli_overrides) - .build() - .await - .expect("defaults for test should always succeed"); - - config.cwd = root.path().to_path_buf(); - config.project_doc_max_bytes = limit; - config.user_instructions = instructions.map(ToOwned::to_owned); - config - } - - fn load_test_skills(config: &Config) -> crate::skills::SkillLoadOutcome { - load_skills_from_roots([SkillRoot { - path: config.codex_home.join("skills"), - scope: SkillScope::User, - }]) - } - - /// AGENTS.md missing – should yield `None`. - #[tokio::test] - async fn no_doc_file_returns_none() { - let tmp = tempfile::tempdir().expect("tempdir"); - - let res = get_user_instructions(&make_config(&tmp, 4096, None).await, None, None).await; - assert!( - res.is_none(), - "Expected None when AGENTS.md is absent and no system instructions provided" - ); - assert!(res.is_none(), "Expected None when AGENTS.md is absent"); - } - - /// Small file within the byte-limit is returned unmodified. - #[tokio::test] - async fn doc_smaller_than_limit_is_returned() { - let tmp = tempfile::tempdir().expect("tempdir"); - fs::write(tmp.path().join("AGENTS.md"), "hello world").unwrap(); - - let res = get_user_instructions(&make_config(&tmp, 4096, None).await, None, None) - .await - .expect("doc expected"); - - assert_eq!( - res, "hello world", - "The document should be returned verbatim when it is smaller than the limit and there are no existing instructions" - ); - } - - /// Oversize file is truncated to `project_doc_max_bytes`. - #[tokio::test] - async fn doc_larger_than_limit_is_truncated() { - const LIMIT: usize = 1024; - let tmp = tempfile::tempdir().expect("tempdir"); - - let huge = "A".repeat(LIMIT * 2); // 2 KiB - fs::write(tmp.path().join("AGENTS.md"), &huge).unwrap(); - - let res = get_user_instructions(&make_config(&tmp, LIMIT, None).await, None, None) - .await - .expect("doc expected"); - - assert_eq!(res.len(), LIMIT, "doc should be truncated to LIMIT bytes"); - assert_eq!(res, huge[..LIMIT]); - } - - /// When `cwd` is nested inside a repo, the search should locate AGENTS.md - /// placed at the repository root (identified by `.git`). - #[tokio::test] - async fn finds_doc_in_repo_root() { - let repo = tempfile::tempdir().expect("tempdir"); - - // Simulate a git repository. Note .git can be a file or a directory. - std::fs::write( - repo.path().join(".git"), - "gitdir: /path/to/actual/git/dir\n", - ) - .unwrap(); - - // Put the doc at the repo root. - fs::write(repo.path().join("AGENTS.md"), "root level doc").unwrap(); - - // Now create a nested working directory: repo/workspace/crate_a - let nested = repo.path().join("workspace/crate_a"); - std::fs::create_dir_all(&nested).unwrap(); - - // Build config pointing at the nested dir. - let mut cfg = make_config(&repo, 4096, None).await; - cfg.cwd = nested; - - let res = get_user_instructions(&cfg, None, None) - .await - .expect("doc expected"); - assert_eq!(res, "root level doc"); - } - - /// Explicitly setting the byte-limit to zero disables project docs. - #[tokio::test] - async fn zero_byte_limit_disables_docs() { - let tmp = tempfile::tempdir().expect("tempdir"); - fs::write(tmp.path().join("AGENTS.md"), "something").unwrap(); - - let res = get_user_instructions(&make_config(&tmp, 0, None).await, None, None).await; - assert!( - res.is_none(), - "With limit 0 the function should return None" - ); - } - - #[tokio::test] - async fn js_repl_instructions_are_appended_when_enabled() { - let tmp = tempfile::tempdir().expect("tempdir"); - let mut cfg = make_config(&tmp, 4096, None).await; - cfg.features - .enable(Feature::JsRepl) - .expect("test config should allow js_repl"); - - let res = get_user_instructions(&cfg, None, None) - .await - .expect("js_repl instructions expected"); - let expected = "## JavaScript REPL (Node)\n- Use `js_repl` for Node-backed JavaScript with top-level await in a persistent kernel.\n- `js_repl` is a freeform/custom tool. Direct `js_repl` calls must send raw JavaScript tool input (optionally with first-line `// codex-js-repl: timeout_ms=15000`). Do not wrap code in JSON (for example `{\"code\":\"...\"}`), quotes, or markdown code fences.\n- Helpers: `codex.cwd`, `codex.homeDir`, `codex.tmpDir`, `codex.tool(name, args?)`, and `codex.emitImage(imageLike)`.\n- `codex.tool` executes a normal tool call and resolves to the raw tool output object. Use it for shell and non-shell tools alike. Nested tool outputs stay inside JavaScript unless you emit them explicitly.\n- `codex.emitImage(...)` adds one image to the outer `js_repl` function output each time you call it, so you can call it multiple times to emit multiple images. It accepts a data URL, a single `input_image` item, an object like `{ bytes, mimeType }`, or a raw tool response object with exactly one image and no text. It rejects mixed text-and-image content.\n- Request full-resolution image processing with `detail: \"original\"` only when the `view_image` tool schema includes a `detail` argument. The same availability applies to `codex.emitImage(...)`: if `view_image.detail` is present, you may also pass `detail: \"original\"` there. Use this when high-fidelity image perception or precise localization is needed, especially for CUA agents.\n- Example of sharing an in-memory Playwright screenshot: `await codex.emitImage({ bytes: await page.screenshot({ type: \"jpeg\", quality: 85 }), mimeType: \"image/jpeg\", detail: \"original\" })`.\n- Example of sharing a local image tool result: `await codex.emitImage(codex.tool(\"view_image\", { path: \"/absolute/path\", detail: \"original\" }))`.\n- When encoding an image to send with `codex.emitImage(...)` or `view_image`, prefer JPEG at about 85 quality when lossy compression is acceptable; use PNG when transparency or lossless detail matters. Smaller uploads are faster and less likely to hit size limits.\n- Top-level bindings persist across cells. If a cell throws, prior bindings remain available and bindings that finished initializing before the throw often remain usable in later cells. For code you plan to reuse across cells, prefer declaring or assigning it in direct top-level statements before operations that might throw. If you hit `SyntaxError: Identifier 'x' has already been declared`, first reuse the existing binding, reassign a previously declared `let`, or pick a new descriptive name. Use `{ ... }` only for a short temporary block when you specifically need local scratch names; do not wrap an entire cell in block scope if you want those names reusable later. Reset the kernel with `js_repl_reset` only when you need a clean state.\n- Top-level static import declarations (for example `import x from \"./file.js\"`) are currently unsupported in `js_repl`; use dynamic imports with `await import(\"pkg\")`, `await import(\"./file.js\")`, or `await import(\"/abs/path/file.mjs\")` instead. Imported local files must be ESM `.js`/`.mjs` files and run in the same REPL VM context. Bare package imports always resolve from REPL-global search roots (`CODEX_JS_REPL_NODE_MODULE_DIRS`, then cwd), not relative to the imported file location. Local files may statically import only other local relative/absolute/`file://` `.js`/`.mjs` files; package and builtin imports from local files must stay dynamic. `import.meta.resolve()` returns importable strings such as `file://...`, bare package names, and `node:...` specifiers. Local file modules reload between execs, while top-level bindings persist until `js_repl_reset`.\n- Avoid direct access to `process.stdout` / `process.stderr` / `process.stdin`; it can corrupt the JSON line protocol. Use `console.log`, `codex.tool(...)`, and `codex.emitImage(...)`."; - assert_eq!(res, expected); - } - - #[tokio::test] - async fn js_repl_tools_only_instructions_are_feature_gated() { - let tmp = tempfile::tempdir().expect("tempdir"); - let mut cfg = make_config(&tmp, 4096, None).await; - let mut features = cfg.features.get().clone(); - features - .enable(Feature::JsRepl) - .enable(Feature::JsReplToolsOnly); - cfg.features - .set(features) - .expect("test config should allow js_repl tool restrictions"); - - let res = get_user_instructions(&cfg, None, None) - .await - .expect("js_repl instructions expected"); - let expected = "## JavaScript REPL (Node)\n- Use `js_repl` for Node-backed JavaScript with top-level await in a persistent kernel.\n- `js_repl` is a freeform/custom tool. Direct `js_repl` calls must send raw JavaScript tool input (optionally with first-line `// codex-js-repl: timeout_ms=15000`). Do not wrap code in JSON (for example `{\"code\":\"...\"}`), quotes, or markdown code fences.\n- Helpers: `codex.cwd`, `codex.homeDir`, `codex.tmpDir`, `codex.tool(name, args?)`, and `codex.emitImage(imageLike)`.\n- `codex.tool` executes a normal tool call and resolves to the raw tool output object. Use it for shell and non-shell tools alike. Nested tool outputs stay inside JavaScript unless you emit them explicitly.\n- `codex.emitImage(...)` adds one image to the outer `js_repl` function output each time you call it, so you can call it multiple times to emit multiple images. It accepts a data URL, a single `input_image` item, an object like `{ bytes, mimeType }`, or a raw tool response object with exactly one image and no text. It rejects mixed text-and-image content.\n- Request full-resolution image processing with `detail: \"original\"` only when the `view_image` tool schema includes a `detail` argument. The same availability applies to `codex.emitImage(...)`: if `view_image.detail` is present, you may also pass `detail: \"original\"` there. Use this when high-fidelity image perception or precise localization is needed, especially for CUA agents.\n- Example of sharing an in-memory Playwright screenshot: `await codex.emitImage({ bytes: await page.screenshot({ type: \"jpeg\", quality: 85 }), mimeType: \"image/jpeg\", detail: \"original\" })`.\n- Example of sharing a local image tool result: `await codex.emitImage(codex.tool(\"view_image\", { path: \"/absolute/path\", detail: \"original\" }))`.\n- When encoding an image to send with `codex.emitImage(...)` or `view_image`, prefer JPEG at about 85 quality when lossy compression is acceptable; use PNG when transparency or lossless detail matters. Smaller uploads are faster and less likely to hit size limits.\n- Top-level bindings persist across cells. If a cell throws, prior bindings remain available and bindings that finished initializing before the throw often remain usable in later cells. For code you plan to reuse across cells, prefer declaring or assigning it in direct top-level statements before operations that might throw. If you hit `SyntaxError: Identifier 'x' has already been declared`, first reuse the existing binding, reassign a previously declared `let`, or pick a new descriptive name. Use `{ ... }` only for a short temporary block when you specifically need local scratch names; do not wrap an entire cell in block scope if you want those names reusable later. Reset the kernel with `js_repl_reset` only when you need a clean state.\n- Top-level static import declarations (for example `import x from \"./file.js\"`) are currently unsupported in `js_repl`; use dynamic imports with `await import(\"pkg\")`, `await import(\"./file.js\")`, or `await import(\"/abs/path/file.mjs\")` instead. Imported local files must be ESM `.js`/`.mjs` files and run in the same REPL VM context. Bare package imports always resolve from REPL-global search roots (`CODEX_JS_REPL_NODE_MODULE_DIRS`, then cwd), not relative to the imported file location. Local files may statically import only other local relative/absolute/`file://` `.js`/`.mjs` files; package and builtin imports from local files must stay dynamic. `import.meta.resolve()` returns importable strings such as `file://...`, bare package names, and `node:...` specifiers. Local file modules reload between execs, while top-level bindings persist until `js_repl_reset`.\n- Do not call tools directly; use `js_repl` + `codex.tool(...)` for all tool calls, including shell commands.\n- MCP tools (if any) can also be called by name via `codex.tool(...)`.\n- Avoid direct access to `process.stdout` / `process.stderr` / `process.stdin`; it can corrupt the JSON line protocol. Use `console.log`, `codex.tool(...)`, and `codex.emitImage(...)`."; - assert_eq!(res, expected); - } - - #[tokio::test] - async fn js_repl_image_detail_original_does_not_change_instructions() { - let tmp = tempfile::tempdir().expect("tempdir"); - let mut cfg = make_config(&tmp, 4096, None).await; - let mut features = cfg.features.get().clone(); - features - .enable(Feature::JsRepl) - .enable(Feature::ImageDetailOriginal); - cfg.features - .set(features) - .expect("test config should allow js_repl image detail settings"); - - let res = get_user_instructions(&cfg, None, None) - .await - .expect("js_repl instructions expected"); - let expected = "## JavaScript REPL (Node)\n- Use `js_repl` for Node-backed JavaScript with top-level await in a persistent kernel.\n- `js_repl` is a freeform/custom tool. Direct `js_repl` calls must send raw JavaScript tool input (optionally with first-line `// codex-js-repl: timeout_ms=15000`). Do not wrap code in JSON (for example `{\"code\":\"...\"}`), quotes, or markdown code fences.\n- Helpers: `codex.cwd`, `codex.homeDir`, `codex.tmpDir`, `codex.tool(name, args?)`, and `codex.emitImage(imageLike)`.\n- `codex.tool` executes a normal tool call and resolves to the raw tool output object. Use it for shell and non-shell tools alike. Nested tool outputs stay inside JavaScript unless you emit them explicitly.\n- `codex.emitImage(...)` adds one image to the outer `js_repl` function output each time you call it, so you can call it multiple times to emit multiple images. It accepts a data URL, a single `input_image` item, an object like `{ bytes, mimeType }`, or a raw tool response object with exactly one image and no text. It rejects mixed text-and-image content.\n- Request full-resolution image processing with `detail: \"original\"` only when the `view_image` tool schema includes a `detail` argument. The same availability applies to `codex.emitImage(...)`: if `view_image.detail` is present, you may also pass `detail: \"original\"` there. Use this when high-fidelity image perception or precise localization is needed, especially for CUA agents.\n- Example of sharing an in-memory Playwright screenshot: `await codex.emitImage({ bytes: await page.screenshot({ type: \"jpeg\", quality: 85 }), mimeType: \"image/jpeg\", detail: \"original\" })`.\n- Example of sharing a local image tool result: `await codex.emitImage(codex.tool(\"view_image\", { path: \"/absolute/path\", detail: \"original\" }))`.\n- When encoding an image to send with `codex.emitImage(...)` or `view_image`, prefer JPEG at about 85 quality when lossy compression is acceptable; use PNG when transparency or lossless detail matters. Smaller uploads are faster and less likely to hit size limits.\n- Top-level bindings persist across cells. If a cell throws, prior bindings remain available and bindings that finished initializing before the throw often remain usable in later cells. For code you plan to reuse across cells, prefer declaring or assigning it in direct top-level statements before operations that might throw. If you hit `SyntaxError: Identifier 'x' has already been declared`, first reuse the existing binding, reassign a previously declared `let`, or pick a new descriptive name. Use `{ ... }` only for a short temporary block when you specifically need local scratch names; do not wrap an entire cell in block scope if you want those names reusable later. Reset the kernel with `js_repl_reset` only when you need a clean state.\n- Top-level static import declarations (for example `import x from \"./file.js\"`) are currently unsupported in `js_repl`; use dynamic imports with `await import(\"pkg\")`, `await import(\"./file.js\")`, or `await import(\"/abs/path/file.mjs\")` instead. Imported local files must be ESM `.js`/`.mjs` files and run in the same REPL VM context. Bare package imports always resolve from REPL-global search roots (`CODEX_JS_REPL_NODE_MODULE_DIRS`, then cwd), not relative to the imported file location. Local files may statically import only other local relative/absolute/`file://` `.js`/`.mjs` files; package and builtin imports from local files must stay dynamic. `import.meta.resolve()` returns importable strings such as `file://...`, bare package names, and `node:...` specifiers. Local file modules reload between execs, while top-level bindings persist until `js_repl_reset`.\n- Avoid direct access to `process.stdout` / `process.stderr` / `process.stdin`; it can corrupt the JSON line protocol. Use `console.log`, `codex.tool(...)`, and `codex.emitImage(...)`."; - assert_eq!(res, expected); - } - - /// When both system instructions *and* a project doc are present the two - /// should be concatenated with the separator. - #[tokio::test] - async fn merges_existing_instructions_with_project_doc() { - let tmp = tempfile::tempdir().expect("tempdir"); - fs::write(tmp.path().join("AGENTS.md"), "proj doc").unwrap(); - - const INSTRUCTIONS: &str = "base instructions"; - - let res = get_user_instructions( - &make_config(&tmp, 4096, Some(INSTRUCTIONS)).await, - None, - None, - ) - .await - .expect("should produce a combined instruction string"); - - let expected = format!("{INSTRUCTIONS}{PROJECT_DOC_SEPARATOR}{}", "proj doc"); - - assert_eq!(res, expected); - } - - /// If there are existing system instructions but the project doc is - /// missing we expect the original instructions to be returned unchanged. - #[tokio::test] - async fn keeps_existing_instructions_when_doc_missing() { - let tmp = tempfile::tempdir().expect("tempdir"); - - const INSTRUCTIONS: &str = "some instructions"; - - let res = get_user_instructions( - &make_config(&tmp, 4096, Some(INSTRUCTIONS)).await, - None, - None, - ) - .await; - - assert_eq!(res, Some(INSTRUCTIONS.to_string())); - } - - /// When both the repository root and the working directory contain - /// AGENTS.md files, their contents are concatenated from root to cwd. - #[tokio::test] - async fn concatenates_root_and_cwd_docs() { - let repo = tempfile::tempdir().expect("tempdir"); - - // Simulate a git repository. - std::fs::write( - repo.path().join(".git"), - "gitdir: /path/to/actual/git/dir\n", - ) - .unwrap(); - - // Repo root doc. - fs::write(repo.path().join("AGENTS.md"), "root doc").unwrap(); - - // Nested working directory with its own doc. - let nested = repo.path().join("workspace/crate_a"); - std::fs::create_dir_all(&nested).unwrap(); - fs::write(nested.join("AGENTS.md"), "crate doc").unwrap(); - - let mut cfg = make_config(&repo, 4096, None).await; - cfg.cwd = nested; - - let res = get_user_instructions(&cfg, None, None) - .await - .expect("doc expected"); - assert_eq!(res, "root doc\n\ncrate doc"); - } - - #[tokio::test] - async fn project_root_markers_are_honored_for_agents_discovery() { - let root = tempfile::tempdir().expect("tempdir"); - fs::write(root.path().join(".codex-root"), "").unwrap(); - fs::write(root.path().join("AGENTS.md"), "parent doc").unwrap(); - - let nested = root.path().join("dir1"); - fs::create_dir_all(nested.join(".git")).unwrap(); - fs::write(nested.join("AGENTS.md"), "child doc").unwrap(); - - let mut cfg = - make_config_with_project_root_markers(&root, 4096, None, &[".codex-root"]).await; - cfg.cwd = nested; - - let discovery = discover_project_doc_paths(&cfg).expect("discover paths"); - let expected_parent = - dunce::canonicalize(root.path().join("AGENTS.md")).expect("canonical parent doc path"); - let expected_child = - dunce::canonicalize(cfg.cwd.join("AGENTS.md")).expect("canonical child doc path"); - assert_eq!(discovery.len(), 2); - assert_eq!(discovery[0], expected_parent); - assert_eq!(discovery[1], expected_child); - - let res = get_user_instructions(&cfg, None, None) - .await - .expect("doc expected"); - assert_eq!(res, "parent doc\n\nchild doc"); - } - - /// AGENTS.override.md is preferred over AGENTS.md when both are present. - #[tokio::test] - async fn agents_local_md_preferred() { - let tmp = tempfile::tempdir().expect("tempdir"); - fs::write(tmp.path().join(DEFAULT_PROJECT_DOC_FILENAME), "versioned").unwrap(); - fs::write(tmp.path().join(LOCAL_PROJECT_DOC_FILENAME), "local").unwrap(); - - let cfg = make_config(&tmp, 4096, None).await; - - let res = get_user_instructions(&cfg, None, None) - .await - .expect("local doc expected"); - - assert_eq!(res, "local"); - - let discovery = discover_project_doc_paths(&cfg).expect("discover paths"); - assert_eq!(discovery.len(), 1); - assert_eq!( - discovery[0].file_name().unwrap().to_string_lossy(), - LOCAL_PROJECT_DOC_FILENAME - ); - } - - /// When AGENTS.md is absent but a configured fallback exists, the fallback is used. - #[tokio::test] - async fn uses_configured_fallback_when_agents_missing() { - let tmp = tempfile::tempdir().expect("tempdir"); - fs::write(tmp.path().join("EXAMPLE.md"), "example instructions").unwrap(); - - let cfg = make_config_with_fallback(&tmp, 4096, None, &["EXAMPLE.md"]).await; - - let res = get_user_instructions(&cfg, None, None) - .await - .expect("fallback doc expected"); - - assert_eq!(res, "example instructions"); - } - - /// AGENTS.md remains preferred when both AGENTS.md and fallbacks are present. - #[tokio::test] - async fn agents_md_preferred_over_fallbacks() { - let tmp = tempfile::tempdir().expect("tempdir"); - fs::write(tmp.path().join("AGENTS.md"), "primary").unwrap(); - fs::write(tmp.path().join("EXAMPLE.md"), "secondary").unwrap(); - - let cfg = make_config_with_fallback(&tmp, 4096, None, &["EXAMPLE.md", ".example.md"]).await; - - let res = get_user_instructions(&cfg, None, None) - .await - .expect("AGENTS.md should win"); - - assert_eq!(res, "primary"); - - let discovery = discover_project_doc_paths(&cfg).expect("discover paths"); - assert_eq!(discovery.len(), 1); - assert!( - discovery[0] - .file_name() - .unwrap() - .to_string_lossy() - .eq(DEFAULT_PROJECT_DOC_FILENAME) - ); - } - - #[tokio::test] - async fn skills_are_appended_to_project_doc() { - let tmp = tempfile::tempdir().expect("tempdir"); - fs::write(tmp.path().join("AGENTS.md"), "base doc").unwrap(); - - let cfg = make_config(&tmp, 4096, None).await; - create_skill( - cfg.codex_home.clone(), - "pdf-processing", - "extract from pdfs", - ); - - let skills = load_test_skills(&cfg); - let res = get_user_instructions( - &cfg, - skills.errors.is_empty().then_some(skills.skills.as_slice()), - None, - ) - .await - .expect("instructions expected"); - let expected_path = dunce::canonicalize( - cfg.codex_home - .join("skills/pdf-processing/SKILL.md") - .as_path(), - ) - .unwrap_or_else(|_| cfg.codex_home.join("skills/pdf-processing/SKILL.md")); - let expected_path_str = expected_path.to_string_lossy().replace('\\', "/"); - let usage_rules = "- Discovery: The list above is the skills available in this session (name + description + file path). Skill bodies live on disk at the listed paths.\n- Trigger rules: If the user names a skill (with `$SkillName` or plain text) OR the task clearly matches a skill's description shown above, you must use that skill for that turn. Multiple mentions mean use them all. Do not carry skills across turns unless re-mentioned.\n- Missing/blocked: If a named skill isn't in the list or the path can't be read, say so briefly and continue with the best fallback.\n- How to use a skill (progressive disclosure):\n 1) After deciding to use a skill, open its `SKILL.md`. Read only enough to follow the workflow.\n 2) When `SKILL.md` references relative paths (e.g., `scripts/foo.py`), resolve them relative to the skill directory listed above first, and only consider other paths if needed.\n 3) If `SKILL.md` points to extra folders such as `references/`, load only the specific files needed for the request; don't bulk-load everything.\n 4) If `scripts/` exist, prefer running or patching them instead of retyping large code blocks.\n 5) If `assets/` or templates exist, reuse them instead of recreating from scratch.\n- Coordination and sequencing:\n - If multiple skills apply, choose the minimal set that covers the request and state the order you'll use them.\n - Announce which skill(s) you're using and why (one short line). If you skip an obvious skill, say why.\n- Context hygiene:\n - Keep context small: summarize long sections instead of pasting them; only load extra files when needed.\n - Avoid deep reference-chasing: prefer opening only files directly linked from `SKILL.md` unless you're blocked.\n - When variants exist (frameworks, providers, domains), pick only the relevant reference file(s) and note that choice.\n- Safety and fallback: If a skill can't be applied cleanly (missing files, unclear instructions), state the issue, pick the next-best approach, and continue."; - let expected = format!( - "base doc\n\n## Skills\nA skill is a set of local instructions to follow that is stored in a `SKILL.md` file. Below is the list of skills that can be used. Each entry includes a name, description, and file path so you can open the source for full instructions when using a specific skill.\n### Available skills\n- pdf-processing: extract from pdfs (file: {expected_path_str})\n### How to use skills\n{usage_rules}" - ); - assert_eq!(res, expected); - } - - #[tokio::test] - async fn skills_render_without_project_doc() { - let tmp = tempfile::tempdir().expect("tempdir"); - let cfg = make_config(&tmp, 4096, None).await; - create_skill(cfg.codex_home.clone(), "linting", "run clippy"); - - let skills = load_test_skills(&cfg); - let res = get_user_instructions( - &cfg, - skills.errors.is_empty().then_some(skills.skills.as_slice()), - None, - ) - .await - .expect("instructions expected"); - let expected_path = - dunce::canonicalize(cfg.codex_home.join("skills/linting/SKILL.md").as_path()) - .unwrap_or_else(|_| cfg.codex_home.join("skills/linting/SKILL.md")); - let expected_path_str = expected_path.to_string_lossy().replace('\\', "/"); - let usage_rules = "- Discovery: The list above is the skills available in this session (name + description + file path). Skill bodies live on disk at the listed paths.\n- Trigger rules: If the user names a skill (with `$SkillName` or plain text) OR the task clearly matches a skill's description shown above, you must use that skill for that turn. Multiple mentions mean use them all. Do not carry skills across turns unless re-mentioned.\n- Missing/blocked: If a named skill isn't in the list or the path can't be read, say so briefly and continue with the best fallback.\n- How to use a skill (progressive disclosure):\n 1) After deciding to use a skill, open its `SKILL.md`. Read only enough to follow the workflow.\n 2) When `SKILL.md` references relative paths (e.g., `scripts/foo.py`), resolve them relative to the skill directory listed above first, and only consider other paths if needed.\n 3) If `SKILL.md` points to extra folders such as `references/`, load only the specific files needed for the request; don't bulk-load everything.\n 4) If `scripts/` exist, prefer running or patching them instead of retyping large code blocks.\n 5) If `assets/` or templates exist, reuse them instead of recreating from scratch.\n- Coordination and sequencing:\n - If multiple skills apply, choose the minimal set that covers the request and state the order you'll use them.\n - Announce which skill(s) you're using and why (one short line). If you skip an obvious skill, say why.\n- Context hygiene:\n - Keep context small: summarize long sections instead of pasting them; only load extra files when needed.\n - Avoid deep reference-chasing: prefer opening only files directly linked from `SKILL.md` unless you're blocked.\n - When variants exist (frameworks, providers, domains), pick only the relevant reference file(s) and note that choice.\n- Safety and fallback: If a skill can't be applied cleanly (missing files, unclear instructions), state the issue, pick the next-best approach, and continue."; - let expected = format!( - "## Skills\nA skill is a set of local instructions to follow that is stored in a `SKILL.md` file. Below is the list of skills that can be used. Each entry includes a name, description, and file path so you can open the source for full instructions when using a specific skill.\n### Available skills\n- linting: run clippy (file: {expected_path_str})\n### How to use skills\n{usage_rules}" - ); - assert_eq!(res, expected); - } - - #[tokio::test] - async fn apps_feature_does_not_emit_user_instructions_by_itself() { - let tmp = tempfile::tempdir().expect("tempdir"); - let mut cfg = make_config(&tmp, 4096, None).await; - cfg.features - .enable(Feature::Apps) - .expect("test config should allow apps"); - - let res = get_user_instructions(&cfg, None, None).await; - assert_eq!(res, None); - } - - #[tokio::test] - async fn apps_feature_does_not_append_to_project_doc_user_instructions() { - let tmp = tempfile::tempdir().expect("tempdir"); - fs::write(tmp.path().join("AGENTS.md"), "base doc").unwrap(); - - let mut cfg = make_config(&tmp, 4096, None).await; - cfg.features - .enable(Feature::Apps) - .expect("test config should allow apps"); - - let res = get_user_instructions(&cfg, None, None) - .await - .expect("instructions expected"); - assert_eq!(res, "base doc"); - } - - fn create_skill(codex_home: PathBuf, name: &str, description: &str) { - let skill_dir = codex_home.join(format!("skills/{name}")); - fs::create_dir_all(&skill_dir).unwrap(); - let content = format!("---\nname: {name}\ndescription: {description}\n---\n\n# Body\n"); - fs::write(skill_dir.join("SKILL.md"), content).unwrap(); - } -} +#[path = "project_doc_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/project_doc_tests.rs b/codex-rs/core/src/project_doc_tests.rs new file mode 100644 index 0000000000..f801bcf139 --- /dev/null +++ b/codex-rs/core/src/project_doc_tests.rs @@ -0,0 +1,477 @@ +use super::*; +use crate::config::ConfigBuilder; +use crate::features::Feature; +use crate::skills::loader::SkillRoot; +use crate::skills::loader::load_skills_from_roots; +use codex_protocol::protocol::SkillScope; +use std::fs; +use std::path::PathBuf; +use tempfile::TempDir; + +/// Helper that returns a `Config` pointing at `root` and using `limit` as +/// the maximum number of bytes to embed from AGENTS.md. The caller can +/// optionally specify a custom `instructions` string – when `None` the +/// value is cleared to mimic a scenario where no system instructions have +/// been configured. +async fn make_config(root: &TempDir, limit: usize, instructions: Option<&str>) -> Config { + let codex_home = TempDir::new().unwrap(); + let mut config = ConfigBuilder::default() + .codex_home(codex_home.path().to_path_buf()) + .build() + .await + .expect("defaults for test should always succeed"); + + config.cwd = root.path().to_path_buf(); + config.project_doc_max_bytes = limit; + + config.user_instructions = instructions.map(ToOwned::to_owned); + config +} + +async fn make_config_with_fallback( + root: &TempDir, + limit: usize, + instructions: Option<&str>, + fallbacks: &[&str], +) -> Config { + let mut config = make_config(root, limit, instructions).await; + config.project_doc_fallback_filenames = fallbacks + .iter() + .map(std::string::ToString::to_string) + .collect(); + config +} + +async fn make_config_with_project_root_markers( + root: &TempDir, + limit: usize, + instructions: Option<&str>, + markers: &[&str], +) -> Config { + let codex_home = TempDir::new().unwrap(); + let cli_overrides = vec![( + "project_root_markers".to_string(), + TomlValue::Array( + markers + .iter() + .map(|marker| TomlValue::String((*marker).to_string())) + .collect(), + ), + )]; + let mut config = ConfigBuilder::default() + .codex_home(codex_home.path().to_path_buf()) + .cli_overrides(cli_overrides) + .build() + .await + .expect("defaults for test should always succeed"); + + config.cwd = root.path().to_path_buf(); + config.project_doc_max_bytes = limit; + config.user_instructions = instructions.map(ToOwned::to_owned); + config +} + +fn load_test_skills(config: &Config) -> crate::skills::SkillLoadOutcome { + load_skills_from_roots([SkillRoot { + path: config.codex_home.join("skills"), + scope: SkillScope::User, + }]) +} + +/// AGENTS.md missing – should yield `None`. +#[tokio::test] +async fn no_doc_file_returns_none() { + let tmp = tempfile::tempdir().expect("tempdir"); + + let res = get_user_instructions(&make_config(&tmp, 4096, None).await, None, None).await; + assert!( + res.is_none(), + "Expected None when AGENTS.md is absent and no system instructions provided" + ); + assert!(res.is_none(), "Expected None when AGENTS.md is absent"); +} + +/// Small file within the byte-limit is returned unmodified. +#[tokio::test] +async fn doc_smaller_than_limit_is_returned() { + let tmp = tempfile::tempdir().expect("tempdir"); + fs::write(tmp.path().join("AGENTS.md"), "hello world").unwrap(); + + let res = get_user_instructions(&make_config(&tmp, 4096, None).await, None, None) + .await + .expect("doc expected"); + + assert_eq!( + res, "hello world", + "The document should be returned verbatim when it is smaller than the limit and there are no existing instructions" + ); +} + +/// Oversize file is truncated to `project_doc_max_bytes`. +#[tokio::test] +async fn doc_larger_than_limit_is_truncated() { + const LIMIT: usize = 1024; + let tmp = tempfile::tempdir().expect("tempdir"); + + let huge = "A".repeat(LIMIT * 2); // 2 KiB + fs::write(tmp.path().join("AGENTS.md"), &huge).unwrap(); + + let res = get_user_instructions(&make_config(&tmp, LIMIT, None).await, None, None) + .await + .expect("doc expected"); + + assert_eq!(res.len(), LIMIT, "doc should be truncated to LIMIT bytes"); + assert_eq!(res, huge[..LIMIT]); +} + +/// When `cwd` is nested inside a repo, the search should locate AGENTS.md +/// placed at the repository root (identified by `.git`). +#[tokio::test] +async fn finds_doc_in_repo_root() { + let repo = tempfile::tempdir().expect("tempdir"); + + // Simulate a git repository. Note .git can be a file or a directory. + std::fs::write( + repo.path().join(".git"), + "gitdir: /path/to/actual/git/dir\n", + ) + .unwrap(); + + // Put the doc at the repo root. + fs::write(repo.path().join("AGENTS.md"), "root level doc").unwrap(); + + // Now create a nested working directory: repo/workspace/crate_a + let nested = repo.path().join("workspace/crate_a"); + std::fs::create_dir_all(&nested).unwrap(); + + // Build config pointing at the nested dir. + let mut cfg = make_config(&repo, 4096, None).await; + cfg.cwd = nested; + + let res = get_user_instructions(&cfg, None, None) + .await + .expect("doc expected"); + assert_eq!(res, "root level doc"); +} + +/// Explicitly setting the byte-limit to zero disables project docs. +#[tokio::test] +async fn zero_byte_limit_disables_docs() { + let tmp = tempfile::tempdir().expect("tempdir"); + fs::write(tmp.path().join("AGENTS.md"), "something").unwrap(); + + let res = get_user_instructions(&make_config(&tmp, 0, None).await, None, None).await; + assert!( + res.is_none(), + "With limit 0 the function should return None" + ); +} + +#[tokio::test] +async fn js_repl_instructions_are_appended_when_enabled() { + let tmp = tempfile::tempdir().expect("tempdir"); + let mut cfg = make_config(&tmp, 4096, None).await; + cfg.features + .enable(Feature::JsRepl) + .expect("test config should allow js_repl"); + + let res = get_user_instructions(&cfg, None, None) + .await + .expect("js_repl instructions expected"); + let expected = "## JavaScript REPL (Node)\n- Use `js_repl` for Node-backed JavaScript with top-level await in a persistent kernel.\n- `js_repl` is a freeform/custom tool. Direct `js_repl` calls must send raw JavaScript tool input (optionally with first-line `// codex-js-repl: timeout_ms=15000`). Do not wrap code in JSON (for example `{\"code\":\"...\"}`), quotes, or markdown code fences.\n- Helpers: `codex.cwd`, `codex.homeDir`, `codex.tmpDir`, `codex.tool(name, args?)`, and `codex.emitImage(imageLike)`.\n- `codex.tool` executes a normal tool call and resolves to the raw tool output object. Use it for shell and non-shell tools alike. Nested tool outputs stay inside JavaScript unless you emit them explicitly.\n- `codex.emitImage(...)` adds one image to the outer `js_repl` function output each time you call it, so you can call it multiple times to emit multiple images. It accepts a data URL, a single `input_image` item, an object like `{ bytes, mimeType }`, or a raw tool response object with exactly one image and no text. It rejects mixed text-and-image content.\n- Request full-resolution image processing with `detail: \"original\"` only when the `view_image` tool schema includes a `detail` argument. The same availability applies to `codex.emitImage(...)`: if `view_image.detail` is present, you may also pass `detail: \"original\"` there. Use this when high-fidelity image perception or precise localization is needed, especially for CUA agents.\n- Example of sharing an in-memory Playwright screenshot: `await codex.emitImage({ bytes: await page.screenshot({ type: \"jpeg\", quality: 85 }), mimeType: \"image/jpeg\", detail: \"original\" })`.\n- Example of sharing a local image tool result: `await codex.emitImage(codex.tool(\"view_image\", { path: \"/absolute/path\", detail: \"original\" }))`.\n- When encoding an image to send with `codex.emitImage(...)` or `view_image`, prefer JPEG at about 85 quality when lossy compression is acceptable; use PNG when transparency or lossless detail matters. Smaller uploads are faster and less likely to hit size limits.\n- Top-level bindings persist across cells. If a cell throws, prior bindings remain available and bindings that finished initializing before the throw often remain usable in later cells. For code you plan to reuse across cells, prefer declaring or assigning it in direct top-level statements before operations that might throw. If you hit `SyntaxError: Identifier 'x' has already been declared`, first reuse the existing binding, reassign a previously declared `let`, or pick a new descriptive name. Use `{ ... }` only for a short temporary block when you specifically need local scratch names; do not wrap an entire cell in block scope if you want those names reusable later. Reset the kernel with `js_repl_reset` only when you need a clean state.\n- Top-level static import declarations (for example `import x from \"./file.js\"`) are currently unsupported in `js_repl`; use dynamic imports with `await import(\"pkg\")`, `await import(\"./file.js\")`, or `await import(\"/abs/path/file.mjs\")` instead. Imported local files must be ESM `.js`/`.mjs` files and run in the same REPL VM context. Bare package imports always resolve from REPL-global search roots (`CODEX_JS_REPL_NODE_MODULE_DIRS`, then cwd), not relative to the imported file location. Local files may statically import only other local relative/absolute/`file://` `.js`/`.mjs` files; package and builtin imports from local files must stay dynamic. `import.meta.resolve()` returns importable strings such as `file://...`, bare package names, and `node:...` specifiers. Local file modules reload between execs, while top-level bindings persist until `js_repl_reset`.\n- Avoid direct access to `process.stdout` / `process.stderr` / `process.stdin`; it can corrupt the JSON line protocol. Use `console.log`, `codex.tool(...)`, and `codex.emitImage(...)`."; + assert_eq!(res, expected); +} + +#[tokio::test] +async fn js_repl_tools_only_instructions_are_feature_gated() { + let tmp = tempfile::tempdir().expect("tempdir"); + let mut cfg = make_config(&tmp, 4096, None).await; + let mut features = cfg.features.get().clone(); + features + .enable(Feature::JsRepl) + .enable(Feature::JsReplToolsOnly); + cfg.features + .set(features) + .expect("test config should allow js_repl tool restrictions"); + + let res = get_user_instructions(&cfg, None, None) + .await + .expect("js_repl instructions expected"); + let expected = "## JavaScript REPL (Node)\n- Use `js_repl` for Node-backed JavaScript with top-level await in a persistent kernel.\n- `js_repl` is a freeform/custom tool. Direct `js_repl` calls must send raw JavaScript tool input (optionally with first-line `// codex-js-repl: timeout_ms=15000`). Do not wrap code in JSON (for example `{\"code\":\"...\"}`), quotes, or markdown code fences.\n- Helpers: `codex.cwd`, `codex.homeDir`, `codex.tmpDir`, `codex.tool(name, args?)`, and `codex.emitImage(imageLike)`.\n- `codex.tool` executes a normal tool call and resolves to the raw tool output object. Use it for shell and non-shell tools alike. Nested tool outputs stay inside JavaScript unless you emit them explicitly.\n- `codex.emitImage(...)` adds one image to the outer `js_repl` function output each time you call it, so you can call it multiple times to emit multiple images. It accepts a data URL, a single `input_image` item, an object like `{ bytes, mimeType }`, or a raw tool response object with exactly one image and no text. It rejects mixed text-and-image content.\n- Request full-resolution image processing with `detail: \"original\"` only when the `view_image` tool schema includes a `detail` argument. The same availability applies to `codex.emitImage(...)`: if `view_image.detail` is present, you may also pass `detail: \"original\"` there. Use this when high-fidelity image perception or precise localization is needed, especially for CUA agents.\n- Example of sharing an in-memory Playwright screenshot: `await codex.emitImage({ bytes: await page.screenshot({ type: \"jpeg\", quality: 85 }), mimeType: \"image/jpeg\", detail: \"original\" })`.\n- Example of sharing a local image tool result: `await codex.emitImage(codex.tool(\"view_image\", { path: \"/absolute/path\", detail: \"original\" }))`.\n- When encoding an image to send with `codex.emitImage(...)` or `view_image`, prefer JPEG at about 85 quality when lossy compression is acceptable; use PNG when transparency or lossless detail matters. Smaller uploads are faster and less likely to hit size limits.\n- Top-level bindings persist across cells. If a cell throws, prior bindings remain available and bindings that finished initializing before the throw often remain usable in later cells. For code you plan to reuse across cells, prefer declaring or assigning it in direct top-level statements before operations that might throw. If you hit `SyntaxError: Identifier 'x' has already been declared`, first reuse the existing binding, reassign a previously declared `let`, or pick a new descriptive name. Use `{ ... }` only for a short temporary block when you specifically need local scratch names; do not wrap an entire cell in block scope if you want those names reusable later. Reset the kernel with `js_repl_reset` only when you need a clean state.\n- Top-level static import declarations (for example `import x from \"./file.js\"`) are currently unsupported in `js_repl`; use dynamic imports with `await import(\"pkg\")`, `await import(\"./file.js\")`, or `await import(\"/abs/path/file.mjs\")` instead. Imported local files must be ESM `.js`/`.mjs` files and run in the same REPL VM context. Bare package imports always resolve from REPL-global search roots (`CODEX_JS_REPL_NODE_MODULE_DIRS`, then cwd), not relative to the imported file location. Local files may statically import only other local relative/absolute/`file://` `.js`/`.mjs` files; package and builtin imports from local files must stay dynamic. `import.meta.resolve()` returns importable strings such as `file://...`, bare package names, and `node:...` specifiers. Local file modules reload between execs, while top-level bindings persist until `js_repl_reset`.\n- Do not call tools directly; use `js_repl` + `codex.tool(...)` for all tool calls, including shell commands.\n- MCP tools (if any) can also be called by name via `codex.tool(...)`.\n- Avoid direct access to `process.stdout` / `process.stderr` / `process.stdin`; it can corrupt the JSON line protocol. Use `console.log`, `codex.tool(...)`, and `codex.emitImage(...)`."; + assert_eq!(res, expected); +} + +#[tokio::test] +async fn js_repl_image_detail_original_does_not_change_instructions() { + let tmp = tempfile::tempdir().expect("tempdir"); + let mut cfg = make_config(&tmp, 4096, None).await; + let mut features = cfg.features.get().clone(); + features + .enable(Feature::JsRepl) + .enable(Feature::ImageDetailOriginal); + cfg.features + .set(features) + .expect("test config should allow js_repl image detail settings"); + + let res = get_user_instructions(&cfg, None, None) + .await + .expect("js_repl instructions expected"); + let expected = "## JavaScript REPL (Node)\n- Use `js_repl` for Node-backed JavaScript with top-level await in a persistent kernel.\n- `js_repl` is a freeform/custom tool. Direct `js_repl` calls must send raw JavaScript tool input (optionally with first-line `// codex-js-repl: timeout_ms=15000`). Do not wrap code in JSON (for example `{\"code\":\"...\"}`), quotes, or markdown code fences.\n- Helpers: `codex.cwd`, `codex.homeDir`, `codex.tmpDir`, `codex.tool(name, args?)`, and `codex.emitImage(imageLike)`.\n- `codex.tool` executes a normal tool call and resolves to the raw tool output object. Use it for shell and non-shell tools alike. Nested tool outputs stay inside JavaScript unless you emit them explicitly.\n- `codex.emitImage(...)` adds one image to the outer `js_repl` function output each time you call it, so you can call it multiple times to emit multiple images. It accepts a data URL, a single `input_image` item, an object like `{ bytes, mimeType }`, or a raw tool response object with exactly one image and no text. It rejects mixed text-and-image content.\n- Request full-resolution image processing with `detail: \"original\"` only when the `view_image` tool schema includes a `detail` argument. The same availability applies to `codex.emitImage(...)`: if `view_image.detail` is present, you may also pass `detail: \"original\"` there. Use this when high-fidelity image perception or precise localization is needed, especially for CUA agents.\n- Example of sharing an in-memory Playwright screenshot: `await codex.emitImage({ bytes: await page.screenshot({ type: \"jpeg\", quality: 85 }), mimeType: \"image/jpeg\", detail: \"original\" })`.\n- Example of sharing a local image tool result: `await codex.emitImage(codex.tool(\"view_image\", { path: \"/absolute/path\", detail: \"original\" }))`.\n- When encoding an image to send with `codex.emitImage(...)` or `view_image`, prefer JPEG at about 85 quality when lossy compression is acceptable; use PNG when transparency or lossless detail matters. Smaller uploads are faster and less likely to hit size limits.\n- Top-level bindings persist across cells. If a cell throws, prior bindings remain available and bindings that finished initializing before the throw often remain usable in later cells. For code you plan to reuse across cells, prefer declaring or assigning it in direct top-level statements before operations that might throw. If you hit `SyntaxError: Identifier 'x' has already been declared`, first reuse the existing binding, reassign a previously declared `let`, or pick a new descriptive name. Use `{ ... }` only for a short temporary block when you specifically need local scratch names; do not wrap an entire cell in block scope if you want those names reusable later. Reset the kernel with `js_repl_reset` only when you need a clean state.\n- Top-level static import declarations (for example `import x from \"./file.js\"`) are currently unsupported in `js_repl`; use dynamic imports with `await import(\"pkg\")`, `await import(\"./file.js\")`, or `await import(\"/abs/path/file.mjs\")` instead. Imported local files must be ESM `.js`/`.mjs` files and run in the same REPL VM context. Bare package imports always resolve from REPL-global search roots (`CODEX_JS_REPL_NODE_MODULE_DIRS`, then cwd), not relative to the imported file location. Local files may statically import only other local relative/absolute/`file://` `.js`/`.mjs` files; package and builtin imports from local files must stay dynamic. `import.meta.resolve()` returns importable strings such as `file://...`, bare package names, and `node:...` specifiers. Local file modules reload between execs, while top-level bindings persist until `js_repl_reset`.\n- Avoid direct access to `process.stdout` / `process.stderr` / `process.stdin`; it can corrupt the JSON line protocol. Use `console.log`, `codex.tool(...)`, and `codex.emitImage(...)`."; + assert_eq!(res, expected); +} + +/// When both system instructions *and* a project doc are present the two +/// should be concatenated with the separator. +#[tokio::test] +async fn merges_existing_instructions_with_project_doc() { + let tmp = tempfile::tempdir().expect("tempdir"); + fs::write(tmp.path().join("AGENTS.md"), "proj doc").unwrap(); + + const INSTRUCTIONS: &str = "base instructions"; + + let res = get_user_instructions( + &make_config(&tmp, 4096, Some(INSTRUCTIONS)).await, + None, + None, + ) + .await + .expect("should produce a combined instruction string"); + + let expected = format!("{INSTRUCTIONS}{PROJECT_DOC_SEPARATOR}{}", "proj doc"); + + assert_eq!(res, expected); +} + +/// If there are existing system instructions but the project doc is +/// missing we expect the original instructions to be returned unchanged. +#[tokio::test] +async fn keeps_existing_instructions_when_doc_missing() { + let tmp = tempfile::tempdir().expect("tempdir"); + + const INSTRUCTIONS: &str = "some instructions"; + + let res = get_user_instructions( + &make_config(&tmp, 4096, Some(INSTRUCTIONS)).await, + None, + None, + ) + .await; + + assert_eq!(res, Some(INSTRUCTIONS.to_string())); +} + +/// When both the repository root and the working directory contain +/// AGENTS.md files, their contents are concatenated from root to cwd. +#[tokio::test] +async fn concatenates_root_and_cwd_docs() { + let repo = tempfile::tempdir().expect("tempdir"); + + // Simulate a git repository. + std::fs::write( + repo.path().join(".git"), + "gitdir: /path/to/actual/git/dir\n", + ) + .unwrap(); + + // Repo root doc. + fs::write(repo.path().join("AGENTS.md"), "root doc").unwrap(); + + // Nested working directory with its own doc. + let nested = repo.path().join("workspace/crate_a"); + std::fs::create_dir_all(&nested).unwrap(); + fs::write(nested.join("AGENTS.md"), "crate doc").unwrap(); + + let mut cfg = make_config(&repo, 4096, None).await; + cfg.cwd = nested; + + let res = get_user_instructions(&cfg, None, None) + .await + .expect("doc expected"); + assert_eq!(res, "root doc\n\ncrate doc"); +} + +#[tokio::test] +async fn project_root_markers_are_honored_for_agents_discovery() { + let root = tempfile::tempdir().expect("tempdir"); + fs::write(root.path().join(".codex-root"), "").unwrap(); + fs::write(root.path().join("AGENTS.md"), "parent doc").unwrap(); + + let nested = root.path().join("dir1"); + fs::create_dir_all(nested.join(".git")).unwrap(); + fs::write(nested.join("AGENTS.md"), "child doc").unwrap(); + + let mut cfg = make_config_with_project_root_markers(&root, 4096, None, &[".codex-root"]).await; + cfg.cwd = nested; + + let discovery = discover_project_doc_paths(&cfg).expect("discover paths"); + let expected_parent = + dunce::canonicalize(root.path().join("AGENTS.md")).expect("canonical parent doc path"); + let expected_child = + dunce::canonicalize(cfg.cwd.join("AGENTS.md")).expect("canonical child doc path"); + assert_eq!(discovery.len(), 2); + assert_eq!(discovery[0], expected_parent); + assert_eq!(discovery[1], expected_child); + + let res = get_user_instructions(&cfg, None, None) + .await + .expect("doc expected"); + assert_eq!(res, "parent doc\n\nchild doc"); +} + +/// AGENTS.override.md is preferred over AGENTS.md when both are present. +#[tokio::test] +async fn agents_local_md_preferred() { + let tmp = tempfile::tempdir().expect("tempdir"); + fs::write(tmp.path().join(DEFAULT_PROJECT_DOC_FILENAME), "versioned").unwrap(); + fs::write(tmp.path().join(LOCAL_PROJECT_DOC_FILENAME), "local").unwrap(); + + let cfg = make_config(&tmp, 4096, None).await; + + let res = get_user_instructions(&cfg, None, None) + .await + .expect("local doc expected"); + + assert_eq!(res, "local"); + + let discovery = discover_project_doc_paths(&cfg).expect("discover paths"); + assert_eq!(discovery.len(), 1); + assert_eq!( + discovery[0].file_name().unwrap().to_string_lossy(), + LOCAL_PROJECT_DOC_FILENAME + ); +} + +/// When AGENTS.md is absent but a configured fallback exists, the fallback is used. +#[tokio::test] +async fn uses_configured_fallback_when_agents_missing() { + let tmp = tempfile::tempdir().expect("tempdir"); + fs::write(tmp.path().join("EXAMPLE.md"), "example instructions").unwrap(); + + let cfg = make_config_with_fallback(&tmp, 4096, None, &["EXAMPLE.md"]).await; + + let res = get_user_instructions(&cfg, None, None) + .await + .expect("fallback doc expected"); + + assert_eq!(res, "example instructions"); +} + +/// AGENTS.md remains preferred when both AGENTS.md and fallbacks are present. +#[tokio::test] +async fn agents_md_preferred_over_fallbacks() { + let tmp = tempfile::tempdir().expect("tempdir"); + fs::write(tmp.path().join("AGENTS.md"), "primary").unwrap(); + fs::write(tmp.path().join("EXAMPLE.md"), "secondary").unwrap(); + + let cfg = make_config_with_fallback(&tmp, 4096, None, &["EXAMPLE.md", ".example.md"]).await; + + let res = get_user_instructions(&cfg, None, None) + .await + .expect("AGENTS.md should win"); + + assert_eq!(res, "primary"); + + let discovery = discover_project_doc_paths(&cfg).expect("discover paths"); + assert_eq!(discovery.len(), 1); + assert!( + discovery[0] + .file_name() + .unwrap() + .to_string_lossy() + .eq(DEFAULT_PROJECT_DOC_FILENAME) + ); +} + +#[tokio::test] +async fn skills_are_appended_to_project_doc() { + let tmp = tempfile::tempdir().expect("tempdir"); + fs::write(tmp.path().join("AGENTS.md"), "base doc").unwrap(); + + let cfg = make_config(&tmp, 4096, None).await; + create_skill( + cfg.codex_home.clone(), + "pdf-processing", + "extract from pdfs", + ); + + let skills = load_test_skills(&cfg); + let res = get_user_instructions( + &cfg, + skills.errors.is_empty().then_some(skills.skills.as_slice()), + None, + ) + .await + .expect("instructions expected"); + let expected_path = dunce::canonicalize( + cfg.codex_home + .join("skills/pdf-processing/SKILL.md") + .as_path(), + ) + .unwrap_or_else(|_| cfg.codex_home.join("skills/pdf-processing/SKILL.md")); + let expected_path_str = expected_path.to_string_lossy().replace('\\', "/"); + let usage_rules = "- Discovery: The list above is the skills available in this session (name + description + file path). Skill bodies live on disk at the listed paths.\n- Trigger rules: If the user names a skill (with `$SkillName` or plain text) OR the task clearly matches a skill's description shown above, you must use that skill for that turn. Multiple mentions mean use them all. Do not carry skills across turns unless re-mentioned.\n- Missing/blocked: If a named skill isn't in the list or the path can't be read, say so briefly and continue with the best fallback.\n- How to use a skill (progressive disclosure):\n 1) After deciding to use a skill, open its `SKILL.md`. Read only enough to follow the workflow.\n 2) When `SKILL.md` references relative paths (e.g., `scripts/foo.py`), resolve them relative to the skill directory listed above first, and only consider other paths if needed.\n 3) If `SKILL.md` points to extra folders such as `references/`, load only the specific files needed for the request; don't bulk-load everything.\n 4) If `scripts/` exist, prefer running or patching them instead of retyping large code blocks.\n 5) If `assets/` or templates exist, reuse them instead of recreating from scratch.\n- Coordination and sequencing:\n - If multiple skills apply, choose the minimal set that covers the request and state the order you'll use them.\n - Announce which skill(s) you're using and why (one short line). If you skip an obvious skill, say why.\n- Context hygiene:\n - Keep context small: summarize long sections instead of pasting them; only load extra files when needed.\n - Avoid deep reference-chasing: prefer opening only files directly linked from `SKILL.md` unless you're blocked.\n - When variants exist (frameworks, providers, domains), pick only the relevant reference file(s) and note that choice.\n- Safety and fallback: If a skill can't be applied cleanly (missing files, unclear instructions), state the issue, pick the next-best approach, and continue."; + let expected = format!( + "base doc\n\n## Skills\nA skill is a set of local instructions to follow that is stored in a `SKILL.md` file. Below is the list of skills that can be used. Each entry includes a name, description, and file path so you can open the source for full instructions when using a specific skill.\n### Available skills\n- pdf-processing: extract from pdfs (file: {expected_path_str})\n### How to use skills\n{usage_rules}" + ); + assert_eq!(res, expected); +} + +#[tokio::test] +async fn skills_render_without_project_doc() { + let tmp = tempfile::tempdir().expect("tempdir"); + let cfg = make_config(&tmp, 4096, None).await; + create_skill(cfg.codex_home.clone(), "linting", "run clippy"); + + let skills = load_test_skills(&cfg); + let res = get_user_instructions( + &cfg, + skills.errors.is_empty().then_some(skills.skills.as_slice()), + None, + ) + .await + .expect("instructions expected"); + let expected_path = + dunce::canonicalize(cfg.codex_home.join("skills/linting/SKILL.md").as_path()) + .unwrap_or_else(|_| cfg.codex_home.join("skills/linting/SKILL.md")); + let expected_path_str = expected_path.to_string_lossy().replace('\\', "/"); + let usage_rules = "- Discovery: The list above is the skills available in this session (name + description + file path). Skill bodies live on disk at the listed paths.\n- Trigger rules: If the user names a skill (with `$SkillName` or plain text) OR the task clearly matches a skill's description shown above, you must use that skill for that turn. Multiple mentions mean use them all. Do not carry skills across turns unless re-mentioned.\n- Missing/blocked: If a named skill isn't in the list or the path can't be read, say so briefly and continue with the best fallback.\n- How to use a skill (progressive disclosure):\n 1) After deciding to use a skill, open its `SKILL.md`. Read only enough to follow the workflow.\n 2) When `SKILL.md` references relative paths (e.g., `scripts/foo.py`), resolve them relative to the skill directory listed above first, and only consider other paths if needed.\n 3) If `SKILL.md` points to extra folders such as `references/`, load only the specific files needed for the request; don't bulk-load everything.\n 4) If `scripts/` exist, prefer running or patching them instead of retyping large code blocks.\n 5) If `assets/` or templates exist, reuse them instead of recreating from scratch.\n- Coordination and sequencing:\n - If multiple skills apply, choose the minimal set that covers the request and state the order you'll use them.\n - Announce which skill(s) you're using and why (one short line). If you skip an obvious skill, say why.\n- Context hygiene:\n - Keep context small: summarize long sections instead of pasting them; only load extra files when needed.\n - Avoid deep reference-chasing: prefer opening only files directly linked from `SKILL.md` unless you're blocked.\n - When variants exist (frameworks, providers, domains), pick only the relevant reference file(s) and note that choice.\n- Safety and fallback: If a skill can't be applied cleanly (missing files, unclear instructions), state the issue, pick the next-best approach, and continue."; + let expected = format!( + "## Skills\nA skill is a set of local instructions to follow that is stored in a `SKILL.md` file. Below is the list of skills that can be used. Each entry includes a name, description, and file path so you can open the source for full instructions when using a specific skill.\n### Available skills\n- linting: run clippy (file: {expected_path_str})\n### How to use skills\n{usage_rules}" + ); + assert_eq!(res, expected); +} + +#[tokio::test] +async fn apps_feature_does_not_emit_user_instructions_by_itself() { + let tmp = tempfile::tempdir().expect("tempdir"); + let mut cfg = make_config(&tmp, 4096, None).await; + cfg.features + .enable(Feature::Apps) + .expect("test config should allow apps"); + + let res = get_user_instructions(&cfg, None, None).await; + assert_eq!(res, None); +} + +#[tokio::test] +async fn apps_feature_does_not_append_to_project_doc_user_instructions() { + let tmp = tempfile::tempdir().expect("tempdir"); + fs::write(tmp.path().join("AGENTS.md"), "base doc").unwrap(); + + let mut cfg = make_config(&tmp, 4096, None).await; + cfg.features + .enable(Feature::Apps) + .expect("test config should allow apps"); + + let res = get_user_instructions(&cfg, None, None) + .await + .expect("instructions expected"); + assert_eq!(res, "base doc"); +} + +fn create_skill(codex_home: PathBuf, name: &str, description: &str) { + let skill_dir = codex_home.join(format!("skills/{name}")); + fs::create_dir_all(&skill_dir).unwrap(); + let content = format!("---\nname: {name}\ndescription: {description}\n---\n\n# Body\n"); + fs::write(skill_dir.join("SKILL.md"), content).unwrap(); +} diff --git a/codex-rs/core/src/realtime_context.rs b/codex-rs/core/src/realtime_context.rs index e15adabc4e..8016763514 100644 --- a/codex-rs/core/src/realtime_context.rs +++ b/codex-rs/core/src/realtime_context.rs @@ -395,138 +395,5 @@ fn approx_token_count(text: &str) -> usize { } #[cfg(test)] -mod tests { - use super::build_recent_work_section; - use super::build_workspace_section; - use super::build_workspace_section_with_user_root; - use chrono::TimeZone; - use chrono::Utc; - use codex_protocol::ThreadId; - use codex_state::ThreadMetadata; - use pretty_assertions::assert_eq; - use std::fs; - use std::path::PathBuf; - use std::process::Command; - use tempfile::TempDir; - - fn thread_metadata(cwd: &str, title: &str, first_user_message: &str) -> ThreadMetadata { - ThreadMetadata { - id: ThreadId::new(), - rollout_path: PathBuf::from("/tmp/rollout.jsonl"), - created_at: Utc - .timestamp_opt(1_709_251_100, 0) - .single() - .expect("valid timestamp"), - updated_at: Utc - .timestamp_opt(1_709_251_200, 0) - .single() - .expect("valid timestamp"), - source: "cli".to_string(), - agent_nickname: None, - agent_role: None, - model_provider: "test-provider".to_string(), - cwd: PathBuf::from(cwd), - cli_version: "test".to_string(), - title: title.to_string(), - sandbox_policy: "workspace-write".to_string(), - approval_mode: "never".to_string(), - tokens_used: 0, - first_user_message: Some(first_user_message.to_string()), - archived_at: None, - git_sha: None, - git_branch: Some("main".to_string()), - git_origin_url: None, - } - } - - #[test] - fn workspace_section_requires_meaningful_structure() { - let cwd = TempDir::new().expect("tempdir"); - assert_eq!( - build_workspace_section_with_user_root(cwd.path(), None), - None - ); - } - - #[test] - fn workspace_section_includes_tree_when_entries_exist() { - let cwd = TempDir::new().expect("tempdir"); - fs::create_dir(cwd.path().join("docs")).expect("create docs dir"); - fs::write(cwd.path().join("README.md"), "hello").expect("write readme"); - - let section = build_workspace_section(cwd.path()).expect("workspace section"); - assert!(section.contains("Working directory tree:")); - assert!(section.contains("- docs/")); - assert!(section.contains("- README.md")); - } - - #[test] - fn workspace_section_includes_user_root_tree_when_distinct() { - let root = TempDir::new().expect("tempdir"); - let cwd = root.path().join("cwd"); - let git_root = root.path().join("git"); - let user_root = root.path().join("home"); - - fs::create_dir_all(cwd.join("docs")).expect("create cwd docs dir"); - fs::write(cwd.join("README.md"), "hello").expect("write cwd readme"); - fs::create_dir_all(git_root.join(".git")).expect("create git dir"); - fs::write(git_root.join("Cargo.toml"), "[workspace]").expect("write git root marker"); - fs::create_dir_all(user_root.join("code")).expect("create user root child"); - fs::write(user_root.join(".zshrc"), "export TEST=1").expect("write home file"); - - let section = build_workspace_section_with_user_root(cwd.as_path(), Some(user_root)) - .expect("workspace section"); - assert!(section.contains("User root tree:")); - assert!(section.contains("- code/")); - assert!(!section.contains("- .zshrc")); - } - - #[test] - fn recent_work_section_groups_threads_by_cwd() { - let root = TempDir::new().expect("tempdir"); - let repo = root.path().join("repo"); - let workspace_a = repo.join("workspace-a"); - let workspace_b = repo.join("workspace-b"); - let outside = root.path().join("outside"); - - fs::create_dir(&repo).expect("create repo dir"); - Command::new("git") - .env("GIT_CONFIG_GLOBAL", "/dev/null") - .env("GIT_CONFIG_NOSYSTEM", "1") - .args(["init"]) - .current_dir(&repo) - .output() - .expect("git init"); - fs::create_dir_all(&workspace_a).expect("create workspace a"); - fs::create_dir_all(&workspace_b).expect("create workspace b"); - fs::create_dir_all(&outside).expect("create outside dir"); - - let recent_threads = vec![ - thread_metadata( - workspace_a.to_string_lossy().as_ref(), - "Investigate realtime startup context", - "Log the startup context before sending it", - ), - thread_metadata( - workspace_b.to_string_lossy().as_ref(), - "Trim websocket startup payload", - "Remove memories from the realtime startup context", - ), - thread_metadata(outside.to_string_lossy().as_ref(), "", "Inspect flaky test"), - ]; - let current_cwd = workspace_a; - let repo = fs::canonicalize(repo).expect("canonicalize repo"); - - let section = build_recent_work_section(current_cwd.as_path(), &recent_threads) - .expect("recent work section"); - assert!(section.contains(&format!("### Git repo: {}", repo.display()))); - assert!(section.contains("Recent sessions: 2")); - assert!(section.contains("User asks:")); - assert!(section.contains(&format!( - "- {}: Log the startup context before sending it", - current_cwd.display() - ))); - assert!(section.contains(&format!("### Directory: {}", outside.display()))); - assert!(section.contains(&format!("- {}: Inspect flaky test", outside.display()))); - } -} +#[path = "realtime_context_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/realtime_context_tests.rs b/codex-rs/core/src/realtime_context_tests.rs new file mode 100644 index 0000000000..b23c2743cf --- /dev/null +++ b/codex-rs/core/src/realtime_context_tests.rs @@ -0,0 +1,133 @@ +use super::build_recent_work_section; +use super::build_workspace_section; +use super::build_workspace_section_with_user_root; +use chrono::TimeZone; +use chrono::Utc; +use codex_protocol::ThreadId; +use codex_state::ThreadMetadata; +use pretty_assertions::assert_eq; +use std::fs; +use std::path::PathBuf; +use std::process::Command; +use tempfile::TempDir; + +fn thread_metadata(cwd: &str, title: &str, first_user_message: &str) -> ThreadMetadata { + ThreadMetadata { + id: ThreadId::new(), + rollout_path: PathBuf::from("/tmp/rollout.jsonl"), + created_at: Utc + .timestamp_opt(1_709_251_100, 0) + .single() + .expect("valid timestamp"), + updated_at: Utc + .timestamp_opt(1_709_251_200, 0) + .single() + .expect("valid timestamp"), + source: "cli".to_string(), + agent_nickname: None, + agent_role: None, + model_provider: "test-provider".to_string(), + cwd: PathBuf::from(cwd), + cli_version: "test".to_string(), + title: title.to_string(), + sandbox_policy: "workspace-write".to_string(), + approval_mode: "never".to_string(), + tokens_used: 0, + first_user_message: Some(first_user_message.to_string()), + archived_at: None, + git_sha: None, + git_branch: Some("main".to_string()), + git_origin_url: None, + } +} + +#[test] +fn workspace_section_requires_meaningful_structure() { + let cwd = TempDir::new().expect("tempdir"); + assert_eq!( + build_workspace_section_with_user_root(cwd.path(), None), + None + ); +} + +#[test] +fn workspace_section_includes_tree_when_entries_exist() { + let cwd = TempDir::new().expect("tempdir"); + fs::create_dir(cwd.path().join("docs")).expect("create docs dir"); + fs::write(cwd.path().join("README.md"), "hello").expect("write readme"); + + let section = build_workspace_section(cwd.path()).expect("workspace section"); + assert!(section.contains("Working directory tree:")); + assert!(section.contains("- docs/")); + assert!(section.contains("- README.md")); +} + +#[test] +fn workspace_section_includes_user_root_tree_when_distinct() { + let root = TempDir::new().expect("tempdir"); + let cwd = root.path().join("cwd"); + let git_root = root.path().join("git"); + let user_root = root.path().join("home"); + + fs::create_dir_all(cwd.join("docs")).expect("create cwd docs dir"); + fs::write(cwd.join("README.md"), "hello").expect("write cwd readme"); + fs::create_dir_all(git_root.join(".git")).expect("create git dir"); + fs::write(git_root.join("Cargo.toml"), "[workspace]").expect("write git root marker"); + fs::create_dir_all(user_root.join("code")).expect("create user root child"); + fs::write(user_root.join(".zshrc"), "export TEST=1").expect("write home file"); + + let section = build_workspace_section_with_user_root(cwd.as_path(), Some(user_root)) + .expect("workspace section"); + assert!(section.contains("User root tree:")); + assert!(section.contains("- code/")); + assert!(!section.contains("- .zshrc")); +} + +#[test] +fn recent_work_section_groups_threads_by_cwd() { + let root = TempDir::new().expect("tempdir"); + let repo = root.path().join("repo"); + let workspace_a = repo.join("workspace-a"); + let workspace_b = repo.join("workspace-b"); + let outside = root.path().join("outside"); + + fs::create_dir(&repo).expect("create repo dir"); + Command::new("git") + .env("GIT_CONFIG_GLOBAL", "/dev/null") + .env("GIT_CONFIG_NOSYSTEM", "1") + .args(["init"]) + .current_dir(&repo) + .output() + .expect("git init"); + fs::create_dir_all(&workspace_a).expect("create workspace a"); + fs::create_dir_all(&workspace_b).expect("create workspace b"); + fs::create_dir_all(&outside).expect("create outside dir"); + + let recent_threads = vec![ + thread_metadata( + workspace_a.to_string_lossy().as_ref(), + "Investigate realtime startup context", + "Log the startup context before sending it", + ), + thread_metadata( + workspace_b.to_string_lossy().as_ref(), + "Trim websocket startup payload", + "Remove memories from the realtime startup context", + ), + thread_metadata(outside.to_string_lossy().as_ref(), "", "Inspect flaky test"), + ]; + let current_cwd = workspace_a; + let repo = fs::canonicalize(repo).expect("canonicalize repo"); + + let section = build_recent_work_section(current_cwd.as_path(), &recent_threads) + .expect("recent work section"); + assert!(section.contains(&format!("### Git repo: {}", repo.display()))); + assert!(section.contains("Recent sessions: 2")); + assert!(section.contains("User asks:")); + assert!(section.contains(&format!( + "- {}: Log the startup context before sending it", + current_cwd.display() + ))); + assert!(section.contains(&format!("### Directory: {}", outside.display()))); + assert!(section.contains(&format!("- {}: Inspect flaky test", outside.display()))); +} diff --git a/codex-rs/core/src/realtime_conversation.rs b/codex-rs/core/src/realtime_conversation.rs index 3baea265ad..f1ce8398e3 100644 --- a/codex-rs/core/src/realtime_conversation.rs +++ b/codex-rs/core/src/realtime_conversation.rs @@ -600,119 +600,5 @@ async fn send_conversation_error( } #[cfg(test)] -mod tests { - use super::HandoffOutput; - use super::RealtimeHandoffState; - use super::realtime_text_from_handoff_request; - use async_channel::bounded; - use codex_protocol::protocol::RealtimeHandoffRequested; - use codex_protocol::protocol::RealtimeTranscriptEntry; - use pretty_assertions::assert_eq; - - #[test] - fn extracts_text_from_handoff_request_active_transcript() { - let handoff = RealtimeHandoffRequested { - handoff_id: "handoff_1".to_string(), - item_id: "item_1".to_string(), - input_transcript: "ignored".to_string(), - active_transcript: vec![ - RealtimeTranscriptEntry { - role: "user".to_string(), - text: "hello".to_string(), - }, - RealtimeTranscriptEntry { - role: "assistant".to_string(), - text: "hi there".to_string(), - }, - ], - }; - assert_eq!( - realtime_text_from_handoff_request(&handoff), - Some("user: hello\nassistant: hi there".to_string()) - ); - } - - #[test] - fn extracts_text_from_handoff_request_input_transcript_if_messages_missing() { - let handoff = RealtimeHandoffRequested { - handoff_id: "handoff_1".to_string(), - item_id: "item_1".to_string(), - input_transcript: "ignored".to_string(), - active_transcript: vec![], - }; - assert_eq!( - realtime_text_from_handoff_request(&handoff), - Some("ignored".to_string()) - ); - } - - #[test] - fn ignores_empty_handoff_request_input_transcript() { - let handoff = RealtimeHandoffRequested { - handoff_id: "handoff_1".to_string(), - item_id: "item_1".to_string(), - input_transcript: String::new(), - active_transcript: vec![], - }; - assert_eq!(realtime_text_from_handoff_request(&handoff), None); - } - - #[tokio::test] - async fn clears_active_handoff_explicitly() { - let (tx, _rx) = bounded(1); - let state = RealtimeHandoffState::new(tx); - - *state.active_handoff.lock().await = Some("handoff_1".to_string()); - assert_eq!( - state.active_handoff.lock().await.clone(), - Some("handoff_1".to_string()) - ); - - *state.active_handoff.lock().await = None; - assert_eq!(state.active_handoff.lock().await.clone(), None); - } - - #[tokio::test] - async fn sends_multiple_handoff_outputs_until_cleared() { - let (tx, rx) = bounded(4); - let state = RealtimeHandoffState::new(tx); - - state - .send_output("ignored".to_string()) - .await - .expect("send"); - assert!(rx.is_empty()); - - *state.active_handoff.lock().await = Some("handoff_1".to_string()); - state.send_output("result".to_string()).await.expect("send"); - state - .send_output("result 2".to_string()) - .await - .expect("send"); - - let output_1 = rx.recv().await.expect("recv"); - assert_eq!( - output_1, - HandoffOutput { - handoff_id: "handoff_1".to_string(), - output_text: "result".to_string(), - } - ); - - let output_2 = rx.recv().await.expect("recv"); - assert_eq!( - output_2, - HandoffOutput { - handoff_id: "handoff_1".to_string(), - output_text: "result 2".to_string(), - } - ); - - *state.active_handoff.lock().await = None; - state - .send_output("ignored after clear".to_string()) - .await - .expect("send"); - assert!(rx.is_empty()); - } -} +#[path = "realtime_conversation_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/realtime_conversation_tests.rs b/codex-rs/core/src/realtime_conversation_tests.rs new file mode 100644 index 0000000000..d6b85a92da --- /dev/null +++ b/codex-rs/core/src/realtime_conversation_tests.rs @@ -0,0 +1,114 @@ +use super::HandoffOutput; +use super::RealtimeHandoffState; +use super::realtime_text_from_handoff_request; +use async_channel::bounded; +use codex_protocol::protocol::RealtimeHandoffRequested; +use codex_protocol::protocol::RealtimeTranscriptEntry; +use pretty_assertions::assert_eq; + +#[test] +fn extracts_text_from_handoff_request_active_transcript() { + let handoff = RealtimeHandoffRequested { + handoff_id: "handoff_1".to_string(), + item_id: "item_1".to_string(), + input_transcript: "ignored".to_string(), + active_transcript: vec![ + RealtimeTranscriptEntry { + role: "user".to_string(), + text: "hello".to_string(), + }, + RealtimeTranscriptEntry { + role: "assistant".to_string(), + text: "hi there".to_string(), + }, + ], + }; + assert_eq!( + realtime_text_from_handoff_request(&handoff), + Some("user: hello\nassistant: hi there".to_string()) + ); +} + +#[test] +fn extracts_text_from_handoff_request_input_transcript_if_messages_missing() { + let handoff = RealtimeHandoffRequested { + handoff_id: "handoff_1".to_string(), + item_id: "item_1".to_string(), + input_transcript: "ignored".to_string(), + active_transcript: vec![], + }; + assert_eq!( + realtime_text_from_handoff_request(&handoff), + Some("ignored".to_string()) + ); +} + +#[test] +fn ignores_empty_handoff_request_input_transcript() { + let handoff = RealtimeHandoffRequested { + handoff_id: "handoff_1".to_string(), + item_id: "item_1".to_string(), + input_transcript: String::new(), + active_transcript: vec![], + }; + assert_eq!(realtime_text_from_handoff_request(&handoff), None); +} + +#[tokio::test] +async fn clears_active_handoff_explicitly() { + let (tx, _rx) = bounded(1); + let state = RealtimeHandoffState::new(tx); + + *state.active_handoff.lock().await = Some("handoff_1".to_string()); + assert_eq!( + state.active_handoff.lock().await.clone(), + Some("handoff_1".to_string()) + ); + + *state.active_handoff.lock().await = None; + assert_eq!(state.active_handoff.lock().await.clone(), None); +} + +#[tokio::test] +async fn sends_multiple_handoff_outputs_until_cleared() { + let (tx, rx) = bounded(4); + let state = RealtimeHandoffState::new(tx); + + state + .send_output("ignored".to_string()) + .await + .expect("send"); + assert!(rx.is_empty()); + + *state.active_handoff.lock().await = Some("handoff_1".to_string()); + state.send_output("result".to_string()).await.expect("send"); + state + .send_output("result 2".to_string()) + .await + .expect("send"); + + let output_1 = rx.recv().await.expect("recv"); + assert_eq!( + output_1, + HandoffOutput { + handoff_id: "handoff_1".to_string(), + output_text: "result".to_string(), + } + ); + + let output_2 = rx.recv().await.expect("recv"); + assert_eq!( + output_2, + HandoffOutput { + handoff_id: "handoff_1".to_string(), + output_text: "result 2".to_string(), + } + ); + + *state.active_handoff.lock().await = None; + state + .send_output("ignored after clear".to_string()) + .await + .expect("send"); + assert!(rx.is_empty()); +} diff --git a/codex-rs/core/src/rollout/metadata.rs b/codex-rs/core/src/rollout/metadata.rs index 3b18520ee0..d2edfbb0d8 100644 --- a/codex-rs/core/src/rollout/metadata.rs +++ b/codex-rs/core/src/rollout/metadata.rs @@ -437,387 +437,5 @@ async fn collect_rollout_paths(root: &Path) -> std::io::Result> { } #[cfg(test)] -mod tests { - use super::*; - use chrono::DateTime; - use chrono::NaiveDateTime; - use chrono::Timelike; - use chrono::Utc; - use codex_protocol::ThreadId; - use codex_protocol::protocol::CompactedItem; - use codex_protocol::protocol::GitInfo; - use codex_protocol::protocol::RolloutItem; - use codex_protocol::protocol::RolloutLine; - use codex_protocol::protocol::SessionMeta; - use codex_protocol::protocol::SessionMetaLine; - use codex_protocol::protocol::SessionSource; - use codex_state::BackfillStatus; - use codex_state::ThreadMetadataBuilder; - use pretty_assertions::assert_eq; - use std::fs::File; - use std::io::Write; - use std::path::Path; - use std::path::PathBuf; - use tempfile::tempdir; - use uuid::Uuid; - - #[tokio::test] - async fn extract_metadata_from_rollout_uses_session_meta() { - let dir = tempdir().expect("tempdir"); - let uuid = Uuid::new_v4(); - let id = ThreadId::from_string(&uuid.to_string()).expect("thread id"); - let path = dir - .path() - .join(format!("rollout-2026-01-27T12-34-56-{uuid}.jsonl")); - - let session_meta = SessionMeta { - id, - forked_from_id: None, - timestamp: "2026-01-27T12:34:56Z".to_string(), - cwd: dir.path().to_path_buf(), - originator: "cli".to_string(), - cli_version: "0.0.0".to_string(), - source: SessionSource::default(), - agent_nickname: None, - agent_role: None, - model_provider: Some("openai".to_string()), - base_instructions: None, - dynamic_tools: None, - memory_mode: None, - }; - let session_meta_line = SessionMetaLine { - meta: session_meta, - git: None, - }; - let rollout_line = RolloutLine { - timestamp: "2026-01-27T12:34:56Z".to_string(), - item: RolloutItem::SessionMeta(session_meta_line.clone()), - }; - let json = serde_json::to_string(&rollout_line).expect("rollout json"); - let mut file = File::create(&path).expect("create rollout"); - writeln!(file, "{json}").expect("write rollout"); - - let outcome = extract_metadata_from_rollout(&path, "openai") - .await - .expect("extract"); - - let builder = - builder_from_session_meta(&session_meta_line, path.as_path()).expect("builder"); - let mut expected = builder.build("openai"); - apply_rollout_item(&mut expected, &rollout_line.item, "openai"); - expected.updated_at = file_modified_time_utc(&path).await.expect("mtime"); - - assert_eq!(outcome.metadata, expected); - assert_eq!(outcome.memory_mode, None); - assert_eq!(outcome.parse_errors, 0); - } - - #[tokio::test] - async fn extract_metadata_from_rollout_returns_latest_memory_mode() { - let dir = tempdir().expect("tempdir"); - let uuid = Uuid::new_v4(); - let id = ThreadId::from_string(&uuid.to_string()).expect("thread id"); - let path = dir - .path() - .join(format!("rollout-2026-01-27T12-34-56-{uuid}.jsonl")); - - let session_meta = SessionMeta { - id, - forked_from_id: None, - timestamp: "2026-01-27T12:34:56Z".to_string(), - cwd: dir.path().to_path_buf(), - originator: "cli".to_string(), - cli_version: "0.0.0".to_string(), - source: SessionSource::default(), - agent_nickname: None, - agent_role: None, - model_provider: Some("openai".to_string()), - base_instructions: None, - dynamic_tools: None, - memory_mode: None, - }; - let polluted_meta = SessionMeta { - memory_mode: Some("polluted".to_string()), - ..session_meta.clone() - }; - let lines = vec![ - RolloutLine { - timestamp: "2026-01-27T12:34:56Z".to_string(), - item: RolloutItem::SessionMeta(SessionMetaLine { - meta: session_meta, - git: None, - }), - }, - RolloutLine { - timestamp: "2026-01-27T12:35:00Z".to_string(), - item: RolloutItem::SessionMeta(SessionMetaLine { - meta: polluted_meta, - git: None, - }), - }, - ]; - let mut file = File::create(&path).expect("create rollout"); - for line in lines { - writeln!( - file, - "{}", - serde_json::to_string(&line).expect("serialize rollout line") - ) - .expect("write rollout line"); - } - - let outcome = extract_metadata_from_rollout(&path, "openai") - .await - .expect("extract"); - - assert_eq!(outcome.memory_mode.as_deref(), Some("polluted")); - } - - #[test] - fn builder_from_items_falls_back_to_filename() { - let dir = tempdir().expect("tempdir"); - let uuid = Uuid::new_v4(); - let path = dir - .path() - .join(format!("rollout-2026-01-27T12-34-56-{uuid}.jsonl")); - let items = vec![RolloutItem::Compacted(CompactedItem { - message: "noop".to_string(), - replacement_history: None, - })]; - - let builder = builder_from_items(items.as_slice(), path.as_path()).expect("builder"); - let naive = NaiveDateTime::parse_from_str("2026-01-27T12-34-56", "%Y-%m-%dT%H-%M-%S") - .expect("timestamp"); - let created_at = DateTime::::from_naive_utc_and_offset(naive, Utc) - .with_nanosecond(0) - .expect("nanosecond"); - let expected = ThreadMetadataBuilder::new( - ThreadId::from_string(&uuid.to_string()).expect("thread id"), - path, - created_at, - SessionSource::default(), - ); - - assert_eq!(builder, expected); - } - - #[tokio::test] - async fn backfill_sessions_resumes_from_watermark_and_marks_complete() { - let dir = tempdir().expect("tempdir"); - let codex_home = dir.path().to_path_buf(); - let first_uuid = Uuid::new_v4(); - let second_uuid = Uuid::new_v4(); - let first_path = write_rollout_in_sessions( - codex_home.as_path(), - "2026-01-27T12-34-56", - "2026-01-27T12:34:56Z", - first_uuid, - None, - ); - let second_path = write_rollout_in_sessions( - codex_home.as_path(), - "2026-01-27T12-35-56", - "2026-01-27T12:35:56Z", - second_uuid, - None, - ); - - let runtime = - codex_state::StateRuntime::init(codex_home.clone(), "test-provider".to_string()) - .await - .expect("initialize runtime"); - let first_watermark = - backfill_watermark_for_path(codex_home.as_path(), first_path.as_path()); - runtime.mark_backfill_running().await.expect("mark running"); - runtime - .checkpoint_backfill(first_watermark.as_str()) - .await - .expect("checkpoint first watermark"); - tokio::time::sleep(std::time::Duration::from_secs( - (BACKFILL_LEASE_SECONDS + 1) as u64, - )) - .await; - - let mut config = crate::config::test_config(); - config.codex_home = codex_home.clone(); - config.model_provider_id = "test-provider".to_string(); - backfill_sessions(runtime.as_ref(), &config).await; - - let first_id = ThreadId::from_string(&first_uuid.to_string()).expect("first thread id"); - let second_id = ThreadId::from_string(&second_uuid.to_string()).expect("second thread id"); - assert_eq!( - runtime - .get_thread(first_id) - .await - .expect("get first thread"), - None - ); - assert!( - runtime - .get_thread(second_id) - .await - .expect("get second thread") - .is_some() - ); - - let state = runtime - .get_backfill_state() - .await - .expect("get backfill state"); - assert_eq!(state.status, BackfillStatus::Complete); - assert_eq!( - state.last_watermark, - Some(backfill_watermark_for_path( - codex_home.as_path(), - second_path.as_path() - )) - ); - assert!(state.last_success_at.is_some()); - } - - #[tokio::test] - async fn backfill_sessions_preserves_existing_git_branch_and_fills_missing_git_fields() { - let dir = tempdir().expect("tempdir"); - let codex_home = dir.path().to_path_buf(); - let thread_uuid = Uuid::new_v4(); - let rollout_path = write_rollout_in_sessions( - codex_home.as_path(), - "2026-01-27T12-34-56", - "2026-01-27T12:34:56Z", - thread_uuid, - Some(GitInfo { - commit_hash: Some("rollout-sha".to_string()), - branch: Some("rollout-branch".to_string()), - repository_url: Some("git@example.com:openai/codex.git".to_string()), - }), - ); - - let runtime = - codex_state::StateRuntime::init(codex_home.clone(), "test-provider".to_string()) - .await - .expect("initialize runtime"); - let thread_id = ThreadId::from_string(&thread_uuid.to_string()).expect("thread id"); - let mut existing = extract_metadata_from_rollout(&rollout_path, "test-provider") - .await - .expect("extract") - .metadata; - existing.git_sha = None; - existing.git_branch = Some("sqlite-branch".to_string()); - existing.git_origin_url = None; - runtime - .upsert_thread(&existing) - .await - .expect("existing metadata upsert"); - - let mut config = crate::config::test_config(); - config.codex_home = codex_home.clone(); - config.model_provider_id = "test-provider".to_string(); - backfill_sessions(runtime.as_ref(), &config).await; - - let persisted = runtime - .get_thread(thread_id) - .await - .expect("get thread") - .expect("thread exists"); - assert_eq!(persisted.git_sha.as_deref(), Some("rollout-sha")); - assert_eq!(persisted.git_branch.as_deref(), Some("sqlite-branch")); - assert_eq!( - persisted.git_origin_url.as_deref(), - Some("git@example.com:openai/codex.git") - ); - } - - #[tokio::test] - async fn backfill_sessions_normalizes_cwd_before_upsert() { - let dir = tempdir().expect("tempdir"); - let codex_home = dir.path().to_path_buf(); - let thread_uuid = Uuid::new_v4(); - let session_cwd = codex_home.join("."); - let rollout_path = write_rollout_in_sessions_with_cwd( - codex_home.as_path(), - "2026-01-27T12-34-56", - "2026-01-27T12:34:56Z", - thread_uuid, - session_cwd.clone(), - None, - ); - - let runtime = - codex_state::StateRuntime::init(codex_home.clone(), "test-provider".to_string()) - .await - .expect("initialize runtime"); - - let mut config = crate::config::test_config(); - config.codex_home = codex_home.clone(); - config.model_provider_id = "test-provider".to_string(); - backfill_sessions(runtime.as_ref(), &config).await; - - let thread_id = ThreadId::from_string(&thread_uuid.to_string()).expect("thread id"); - let stored = runtime - .get_thread(thread_id) - .await - .expect("get thread") - .expect("thread should be backfilled"); - - assert_eq!(stored.rollout_path, rollout_path); - assert_eq!(stored.cwd, normalize_cwd_for_state_db(&session_cwd)); - } - - fn write_rollout_in_sessions( - codex_home: &Path, - filename_ts: &str, - event_ts: &str, - thread_uuid: Uuid, - git: Option, - ) -> PathBuf { - write_rollout_in_sessions_with_cwd( - codex_home, - filename_ts, - event_ts, - thread_uuid, - codex_home.to_path_buf(), - git, - ) - } - - fn write_rollout_in_sessions_with_cwd( - codex_home: &Path, - filename_ts: &str, - event_ts: &str, - thread_uuid: Uuid, - cwd: PathBuf, - git: Option, - ) -> PathBuf { - let id = ThreadId::from_string(&thread_uuid.to_string()).expect("thread id"); - let sessions_dir = codex_home.join("sessions"); - std::fs::create_dir_all(sessions_dir.as_path()).expect("create sessions dir"); - let path = sessions_dir.join(format!("rollout-{filename_ts}-{thread_uuid}.jsonl")); - let session_meta = SessionMeta { - id, - forked_from_id: None, - timestamp: event_ts.to_string(), - cwd, - originator: "cli".to_string(), - cli_version: "0.0.0".to_string(), - source: SessionSource::default(), - agent_nickname: None, - agent_role: None, - model_provider: Some("test-provider".to_string()), - base_instructions: None, - dynamic_tools: None, - memory_mode: None, - }; - let session_meta_line = SessionMetaLine { - meta: session_meta, - git, - }; - let rollout_line = RolloutLine { - timestamp: event_ts.to_string(), - item: RolloutItem::SessionMeta(session_meta_line), - }; - let json = serde_json::to_string(&rollout_line).expect("serialize rollout"); - let mut file = File::create(&path).expect("create rollout"); - writeln!(file, "{json}").expect("write rollout"); - path - } -} +#[path = "metadata_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/rollout/metadata_tests.rs b/codex-rs/core/src/rollout/metadata_tests.rs new file mode 100644 index 0000000000..5556d7002d --- /dev/null +++ b/codex-rs/core/src/rollout/metadata_tests.rs @@ -0,0 +1,377 @@ +use super::*; +use chrono::DateTime; +use chrono::NaiveDateTime; +use chrono::Timelike; +use chrono::Utc; +use codex_protocol::ThreadId; +use codex_protocol::protocol::CompactedItem; +use codex_protocol::protocol::GitInfo; +use codex_protocol::protocol::RolloutItem; +use codex_protocol::protocol::RolloutLine; +use codex_protocol::protocol::SessionMeta; +use codex_protocol::protocol::SessionMetaLine; +use codex_protocol::protocol::SessionSource; +use codex_state::BackfillStatus; +use codex_state::ThreadMetadataBuilder; +use pretty_assertions::assert_eq; +use std::fs::File; +use std::io::Write; +use std::path::Path; +use std::path::PathBuf; +use tempfile::tempdir; +use uuid::Uuid; + +#[tokio::test] +async fn extract_metadata_from_rollout_uses_session_meta() { + let dir = tempdir().expect("tempdir"); + let uuid = Uuid::new_v4(); + let id = ThreadId::from_string(&uuid.to_string()).expect("thread id"); + let path = dir + .path() + .join(format!("rollout-2026-01-27T12-34-56-{uuid}.jsonl")); + + let session_meta = SessionMeta { + id, + forked_from_id: None, + timestamp: "2026-01-27T12:34:56Z".to_string(), + cwd: dir.path().to_path_buf(), + originator: "cli".to_string(), + cli_version: "0.0.0".to_string(), + source: SessionSource::default(), + agent_nickname: None, + agent_role: None, + model_provider: Some("openai".to_string()), + base_instructions: None, + dynamic_tools: None, + memory_mode: None, + }; + let session_meta_line = SessionMetaLine { + meta: session_meta, + git: None, + }; + let rollout_line = RolloutLine { + timestamp: "2026-01-27T12:34:56Z".to_string(), + item: RolloutItem::SessionMeta(session_meta_line.clone()), + }; + let json = serde_json::to_string(&rollout_line).expect("rollout json"); + let mut file = File::create(&path).expect("create rollout"); + writeln!(file, "{json}").expect("write rollout"); + + let outcome = extract_metadata_from_rollout(&path, "openai") + .await + .expect("extract"); + + let builder = builder_from_session_meta(&session_meta_line, path.as_path()).expect("builder"); + let mut expected = builder.build("openai"); + apply_rollout_item(&mut expected, &rollout_line.item, "openai"); + expected.updated_at = file_modified_time_utc(&path).await.expect("mtime"); + + assert_eq!(outcome.metadata, expected); + assert_eq!(outcome.memory_mode, None); + assert_eq!(outcome.parse_errors, 0); +} + +#[tokio::test] +async fn extract_metadata_from_rollout_returns_latest_memory_mode() { + let dir = tempdir().expect("tempdir"); + let uuid = Uuid::new_v4(); + let id = ThreadId::from_string(&uuid.to_string()).expect("thread id"); + let path = dir + .path() + .join(format!("rollout-2026-01-27T12-34-56-{uuid}.jsonl")); + + let session_meta = SessionMeta { + id, + forked_from_id: None, + timestamp: "2026-01-27T12:34:56Z".to_string(), + cwd: dir.path().to_path_buf(), + originator: "cli".to_string(), + cli_version: "0.0.0".to_string(), + source: SessionSource::default(), + agent_nickname: None, + agent_role: None, + model_provider: Some("openai".to_string()), + base_instructions: None, + dynamic_tools: None, + memory_mode: None, + }; + let polluted_meta = SessionMeta { + memory_mode: Some("polluted".to_string()), + ..session_meta.clone() + }; + let lines = vec![ + RolloutLine { + timestamp: "2026-01-27T12:34:56Z".to_string(), + item: RolloutItem::SessionMeta(SessionMetaLine { + meta: session_meta, + git: None, + }), + }, + RolloutLine { + timestamp: "2026-01-27T12:35:00Z".to_string(), + item: RolloutItem::SessionMeta(SessionMetaLine { + meta: polluted_meta, + git: None, + }), + }, + ]; + let mut file = File::create(&path).expect("create rollout"); + for line in lines { + writeln!( + file, + "{}", + serde_json::to_string(&line).expect("serialize rollout line") + ) + .expect("write rollout line"); + } + + let outcome = extract_metadata_from_rollout(&path, "openai") + .await + .expect("extract"); + + assert_eq!(outcome.memory_mode.as_deref(), Some("polluted")); +} + +#[test] +fn builder_from_items_falls_back_to_filename() { + let dir = tempdir().expect("tempdir"); + let uuid = Uuid::new_v4(); + let path = dir + .path() + .join(format!("rollout-2026-01-27T12-34-56-{uuid}.jsonl")); + let items = vec![RolloutItem::Compacted(CompactedItem { + message: "noop".to_string(), + replacement_history: None, + })]; + + let builder = builder_from_items(items.as_slice(), path.as_path()).expect("builder"); + let naive = NaiveDateTime::parse_from_str("2026-01-27T12-34-56", "%Y-%m-%dT%H-%M-%S") + .expect("timestamp"); + let created_at = DateTime::::from_naive_utc_and_offset(naive, Utc) + .with_nanosecond(0) + .expect("nanosecond"); + let expected = ThreadMetadataBuilder::new( + ThreadId::from_string(&uuid.to_string()).expect("thread id"), + path, + created_at, + SessionSource::default(), + ); + + assert_eq!(builder, expected); +} + +#[tokio::test] +async fn backfill_sessions_resumes_from_watermark_and_marks_complete() { + let dir = tempdir().expect("tempdir"); + let codex_home = dir.path().to_path_buf(); + let first_uuid = Uuid::new_v4(); + let second_uuid = Uuid::new_v4(); + let first_path = write_rollout_in_sessions( + codex_home.as_path(), + "2026-01-27T12-34-56", + "2026-01-27T12:34:56Z", + first_uuid, + None, + ); + let second_path = write_rollout_in_sessions( + codex_home.as_path(), + "2026-01-27T12-35-56", + "2026-01-27T12:35:56Z", + second_uuid, + None, + ); + + let runtime = codex_state::StateRuntime::init(codex_home.clone(), "test-provider".to_string()) + .await + .expect("initialize runtime"); + let first_watermark = backfill_watermark_for_path(codex_home.as_path(), first_path.as_path()); + runtime.mark_backfill_running().await.expect("mark running"); + runtime + .checkpoint_backfill(first_watermark.as_str()) + .await + .expect("checkpoint first watermark"); + tokio::time::sleep(std::time::Duration::from_secs( + (BACKFILL_LEASE_SECONDS + 1) as u64, + )) + .await; + + let mut config = crate::config::test_config(); + config.codex_home = codex_home.clone(); + config.model_provider_id = "test-provider".to_string(); + backfill_sessions(runtime.as_ref(), &config).await; + + let first_id = ThreadId::from_string(&first_uuid.to_string()).expect("first thread id"); + let second_id = ThreadId::from_string(&second_uuid.to_string()).expect("second thread id"); + assert_eq!( + runtime + .get_thread(first_id) + .await + .expect("get first thread"), + None + ); + assert!( + runtime + .get_thread(second_id) + .await + .expect("get second thread") + .is_some() + ); + + let state = runtime + .get_backfill_state() + .await + .expect("get backfill state"); + assert_eq!(state.status, BackfillStatus::Complete); + assert_eq!( + state.last_watermark, + Some(backfill_watermark_for_path( + codex_home.as_path(), + second_path.as_path() + )) + ); + assert!(state.last_success_at.is_some()); +} + +#[tokio::test] +async fn backfill_sessions_preserves_existing_git_branch_and_fills_missing_git_fields() { + let dir = tempdir().expect("tempdir"); + let codex_home = dir.path().to_path_buf(); + let thread_uuid = Uuid::new_v4(); + let rollout_path = write_rollout_in_sessions( + codex_home.as_path(), + "2026-01-27T12-34-56", + "2026-01-27T12:34:56Z", + thread_uuid, + Some(GitInfo { + commit_hash: Some("rollout-sha".to_string()), + branch: Some("rollout-branch".to_string()), + repository_url: Some("git@example.com:openai/codex.git".to_string()), + }), + ); + + let runtime = codex_state::StateRuntime::init(codex_home.clone(), "test-provider".to_string()) + .await + .expect("initialize runtime"); + let thread_id = ThreadId::from_string(&thread_uuid.to_string()).expect("thread id"); + let mut existing = extract_metadata_from_rollout(&rollout_path, "test-provider") + .await + .expect("extract") + .metadata; + existing.git_sha = None; + existing.git_branch = Some("sqlite-branch".to_string()); + existing.git_origin_url = None; + runtime + .upsert_thread(&existing) + .await + .expect("existing metadata upsert"); + + let mut config = crate::config::test_config(); + config.codex_home = codex_home.clone(); + config.model_provider_id = "test-provider".to_string(); + backfill_sessions(runtime.as_ref(), &config).await; + + let persisted = runtime + .get_thread(thread_id) + .await + .expect("get thread") + .expect("thread exists"); + assert_eq!(persisted.git_sha.as_deref(), Some("rollout-sha")); + assert_eq!(persisted.git_branch.as_deref(), Some("sqlite-branch")); + assert_eq!( + persisted.git_origin_url.as_deref(), + Some("git@example.com:openai/codex.git") + ); +} + +#[tokio::test] +async fn backfill_sessions_normalizes_cwd_before_upsert() { + let dir = tempdir().expect("tempdir"); + let codex_home = dir.path().to_path_buf(); + let thread_uuid = Uuid::new_v4(); + let session_cwd = codex_home.join("."); + let rollout_path = write_rollout_in_sessions_with_cwd( + codex_home.as_path(), + "2026-01-27T12-34-56", + "2026-01-27T12:34:56Z", + thread_uuid, + session_cwd.clone(), + None, + ); + + let runtime = codex_state::StateRuntime::init(codex_home.clone(), "test-provider".to_string()) + .await + .expect("initialize runtime"); + + let mut config = crate::config::test_config(); + config.codex_home = codex_home.clone(); + config.model_provider_id = "test-provider".to_string(); + backfill_sessions(runtime.as_ref(), &config).await; + + let thread_id = ThreadId::from_string(&thread_uuid.to_string()).expect("thread id"); + let stored = runtime + .get_thread(thread_id) + .await + .expect("get thread") + .expect("thread should be backfilled"); + + assert_eq!(stored.rollout_path, rollout_path); + assert_eq!(stored.cwd, normalize_cwd_for_state_db(&session_cwd)); +} + +fn write_rollout_in_sessions( + codex_home: &Path, + filename_ts: &str, + event_ts: &str, + thread_uuid: Uuid, + git: Option, +) -> PathBuf { + write_rollout_in_sessions_with_cwd( + codex_home, + filename_ts, + event_ts, + thread_uuid, + codex_home.to_path_buf(), + git, + ) +} + +fn write_rollout_in_sessions_with_cwd( + codex_home: &Path, + filename_ts: &str, + event_ts: &str, + thread_uuid: Uuid, + cwd: PathBuf, + git: Option, +) -> PathBuf { + let id = ThreadId::from_string(&thread_uuid.to_string()).expect("thread id"); + let sessions_dir = codex_home.join("sessions"); + std::fs::create_dir_all(sessions_dir.as_path()).expect("create sessions dir"); + let path = sessions_dir.join(format!("rollout-{filename_ts}-{thread_uuid}.jsonl")); + let session_meta = SessionMeta { + id, + forked_from_id: None, + timestamp: event_ts.to_string(), + cwd, + originator: "cli".to_string(), + cli_version: "0.0.0".to_string(), + source: SessionSource::default(), + agent_nickname: None, + agent_role: None, + model_provider: Some("test-provider".to_string()), + base_instructions: None, + dynamic_tools: None, + memory_mode: None, + }; + let session_meta_line = SessionMetaLine { + meta: session_meta, + git, + }; + let rollout_line = RolloutLine { + timestamp: event_ts.to_string(), + item: RolloutItem::SessionMeta(session_meta_line), + }; + let json = serde_json::to_string(&rollout_line).expect("serialize rollout"); + let mut file = File::create(&path).expect("create rollout"); + writeln!(file, "{json}").expect("write rollout"); + path +} diff --git a/codex-rs/core/src/rollout/recorder.rs b/codex-rs/core/src/rollout/recorder.rs index 851dcfb727..c4ec88784b 100644 --- a/codex-rs/core/src/rollout/recorder.rs +++ b/codex-rs/core/src/rollout/recorder.rs @@ -1102,517 +1102,5 @@ fn cwd_matches(session_cwd: &Path, cwd: &Path) -> bool { } #[cfg(test)] -mod tests { - use super::*; - use crate::config::ConfigBuilder; - use crate::features::Feature; - use chrono::TimeZone; - use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig; - use codex_protocol::protocol::AgentMessageEvent; - use codex_protocol::protocol::AskForApproval; - use codex_protocol::protocol::EventMsg; - use codex_protocol::protocol::SandboxPolicy; - use codex_protocol::protocol::TurnContextItem; - use codex_protocol::protocol::UserMessageEvent; - use pretty_assertions::assert_eq; - use std::fs::File; - use std::fs::{self}; - use std::io::Write; - use std::path::Path; - use std::path::PathBuf; - use std::time::Duration; - use tempfile::TempDir; - use uuid::Uuid; - - fn write_session_file(root: &Path, ts: &str, uuid: Uuid) -> std::io::Result { - let day_dir = root.join("sessions/2025/01/03"); - fs::create_dir_all(&day_dir)?; - let path = day_dir.join(format!("rollout-{ts}-{uuid}.jsonl")); - let mut file = File::create(&path)?; - let meta = serde_json::json!({ - "timestamp": ts, - "type": "session_meta", - "payload": { - "id": uuid, - "timestamp": ts, - "cwd": ".", - "originator": "test_originator", - "cli_version": "test_version", - "source": "cli", - "model_provider": "test-provider", - }, - }); - writeln!(file, "{meta}")?; - let user_event = serde_json::json!({ - "timestamp": ts, - "type": "event_msg", - "payload": { - "type": "user_message", - "message": "Hello from user", - "kind": "plain", - }, - }); - writeln!(file, "{user_event}")?; - Ok(path) - } - - #[tokio::test] - async fn recorder_materializes_only_after_explicit_persist() -> std::io::Result<()> { - let home = TempDir::new().expect("temp dir"); - let config = ConfigBuilder::default() - .codex_home(home.path().to_path_buf()) - .build() - .await?; - let thread_id = ThreadId::new(); - let recorder = RolloutRecorder::new( - &config, - RolloutRecorderParams::new( - thread_id, - None, - SessionSource::Exec, - BaseInstructions::default(), - Vec::new(), - EventPersistenceMode::Limited, - ), - None, - None, - ) - .await?; - - let rollout_path = recorder.rollout_path().to_path_buf(); - assert!( - !rollout_path.exists(), - "rollout file should not exist before first user message" - ); - - recorder - .record_items(&[RolloutItem::EventMsg(EventMsg::AgentMessage( - AgentMessageEvent { - message: "buffered-event".to_string(), - phase: None, - }, - ))]) - .await?; - recorder.flush().await?; - assert!( - !rollout_path.exists(), - "rollout file should remain deferred before first user message" - ); - - recorder - .record_items(&[RolloutItem::EventMsg(EventMsg::UserMessage( - UserMessageEvent { - message: "first-user-message".to_string(), - images: None, - local_images: Vec::new(), - text_elements: Vec::new(), - }, - ))]) - .await?; - recorder.flush().await?; - assert!( - !rollout_path.exists(), - "user-message-like items should not materialize without explicit persist" - ); - - recorder.persist().await?; - // Second call verifies `persist()` is idempotent after materialization. - recorder.persist().await?; - assert!(rollout_path.exists(), "rollout file should be materialized"); - - let text = std::fs::read_to_string(&rollout_path)?; - assert!( - text.contains("\"type\":\"session_meta\""), - "expected session metadata in rollout" - ); - let buffered_idx = text - .find("buffered-event") - .expect("buffered event in rollout"); - let user_idx = text - .find("first-user-message") - .expect("first user message in rollout"); - assert!( - buffered_idx < user_idx, - "buffered items should preserve ordering" - ); - let text_after_second_persist = std::fs::read_to_string(&rollout_path)?; - assert_eq!(text_after_second_persist, text); - - recorder.shutdown().await?; - Ok(()) - } - - #[tokio::test] - async fn metadata_irrelevant_events_touch_state_db_updated_at() -> std::io::Result<()> { - let home = TempDir::new().expect("temp dir"); - let mut config = ConfigBuilder::default() - .codex_home(home.path().to_path_buf()) - .build() - .await?; - config - .features - .enable(Feature::Sqlite) - .expect("test config should allow sqlite"); - - let state_db = - StateRuntime::init(home.path().to_path_buf(), config.model_provider_id.clone()) - .await - .expect("state db should initialize"); - state_db - .mark_backfill_complete(None) - .await - .expect("backfill should be complete"); - - let thread_id = ThreadId::new(); - let recorder = RolloutRecorder::new( - &config, - RolloutRecorderParams::new( - thread_id, - None, - SessionSource::Cli, - BaseInstructions::default(), - Vec::new(), - EventPersistenceMode::Limited, - ), - Some(state_db.clone()), - None, - ) - .await?; - - recorder - .record_items(&[RolloutItem::EventMsg(EventMsg::UserMessage( - UserMessageEvent { - message: "first-user-message".to_string(), - images: None, - local_images: Vec::new(), - text_elements: Vec::new(), - }, - ))]) - .await?; - recorder.persist().await?; - recorder.flush().await?; - let initial_thread = state_db - .get_thread(thread_id) - .await - .expect("thread should load") - .expect("thread should exist"); - let initial_updated_at = initial_thread.updated_at; - let initial_title = initial_thread.title.clone(); - let initial_first_user_message = initial_thread.first_user_message.clone(); - - tokio::time::sleep(Duration::from_secs(1)).await; - - recorder - .record_items(&[RolloutItem::EventMsg(EventMsg::AgentMessage( - AgentMessageEvent { - message: "assistant text".to_string(), - phase: None, - }, - ))]) - .await?; - recorder.flush().await?; - - let updated_thread = state_db - .get_thread(thread_id) - .await - .expect("thread should load after agent message") - .expect("thread should still exist"); - - assert!(updated_thread.updated_at > initial_updated_at); - assert_eq!(updated_thread.title, initial_title); - assert_eq!( - updated_thread.first_user_message, - initial_first_user_message - ); - - recorder.shutdown().await?; - Ok(()) - } - - #[tokio::test] - async fn metadata_irrelevant_events_fall_back_to_upsert_when_thread_missing() - -> std::io::Result<()> { - let home = TempDir::new().expect("temp dir"); - let mut config = ConfigBuilder::default() - .codex_home(home.path().to_path_buf()) - .build() - .await?; - config - .features - .enable(Feature::Sqlite) - .expect("test config should allow sqlite"); - - let state_db = - StateRuntime::init(home.path().to_path_buf(), config.model_provider_id.clone()) - .await - .expect("state db should initialize"); - let thread_id = ThreadId::new(); - let rollout_path = home.path().join("rollout.jsonl"); - let builder = ThreadMetadataBuilder::new( - thread_id, - rollout_path.clone(), - Utc::now(), - SessionSource::Cli, - ); - let items = vec![RolloutItem::EventMsg(EventMsg::AgentMessage( - AgentMessageEvent { - message: "assistant text".to_string(), - phase: None, - }, - ))]; - - sync_thread_state_after_write( - Some(state_db.as_ref()), - rollout_path.as_path(), - Some(&builder), - items.as_slice(), - config.model_provider_id.as_str(), - None, - ) - .await; - - let thread = state_db - .get_thread(thread_id) - .await - .expect("thread should load after fallback") - .expect("thread should be inserted after fallback"); - assert_eq!(thread.id, thread_id); - - Ok(()) - } - - #[tokio::test] - async fn list_threads_db_disabled_does_not_skip_paginated_items() -> std::io::Result<()> { - let home = TempDir::new().expect("temp dir"); - let mut config = ConfigBuilder::default() - .codex_home(home.path().to_path_buf()) - .build() - .await?; - config - .features - .disable(Feature::Sqlite) - .expect("test config should allow sqlite to be disabled"); - - let newest = write_session_file(home.path(), "2025-01-03T12-00-00", Uuid::from_u128(9001))?; - let middle = write_session_file(home.path(), "2025-01-02T12-00-00", Uuid::from_u128(9002))?; - let _oldest = - write_session_file(home.path(), "2025-01-01T12-00-00", Uuid::from_u128(9003))?; - - let default_provider = config.model_provider_id.clone(); - let page1 = RolloutRecorder::list_threads( - &config, - 1, - None, - ThreadSortKey::CreatedAt, - &[], - None, - default_provider.as_str(), - None, - ) - .await?; - assert_eq!(page1.items.len(), 1); - assert_eq!(page1.items[0].path, newest); - let cursor = page1.next_cursor.clone().expect("cursor should be present"); - - let page2 = RolloutRecorder::list_threads( - &config, - 1, - Some(&cursor), - ThreadSortKey::CreatedAt, - &[], - None, - default_provider.as_str(), - None, - ) - .await?; - assert_eq!(page2.items.len(), 1); - assert_eq!(page2.items[0].path, middle); - Ok(()) - } - - #[tokio::test] - async fn list_threads_db_enabled_drops_missing_rollout_paths() -> std::io::Result<()> { - let home = TempDir::new().expect("temp dir"); - let mut config = ConfigBuilder::default() - .codex_home(home.path().to_path_buf()) - .build() - .await?; - config - .features - .enable(Feature::Sqlite) - .expect("test config should allow sqlite"); - - let uuid = Uuid::from_u128(9010); - let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); - let stale_path = home.path().join(format!( - "sessions/2099/01/01/rollout-2099-01-01T00-00-00-{uuid}.jsonl" - )); - - let runtime = codex_state::StateRuntime::init( - home.path().to_path_buf(), - config.model_provider_id.clone(), - ) - .await - .expect("state db should initialize"); - runtime - .mark_backfill_complete(None) - .await - .expect("backfill should be complete"); - let created_at = chrono::Utc - .with_ymd_and_hms(2025, 1, 3, 13, 0, 0) - .single() - .expect("valid datetime"); - let mut builder = codex_state::ThreadMetadataBuilder::new( - thread_id, - stale_path, - created_at, - SessionSource::Cli, - ); - builder.model_provider = Some(config.model_provider_id.clone()); - builder.cwd = home.path().to_path_buf(); - let mut metadata = builder.build(config.model_provider_id.as_str()); - metadata.first_user_message = Some("Hello from user".to_string()); - runtime - .upsert_thread(&metadata) - .await - .expect("state db upsert should succeed"); - - let default_provider = config.model_provider_id.clone(); - let page = RolloutRecorder::list_threads( - &config, - 10, - None, - ThreadSortKey::CreatedAt, - &[], - None, - default_provider.as_str(), - None, - ) - .await?; - assert_eq!(page.items.len(), 0); - let stored_path = runtime - .find_rollout_path_by_id(thread_id, Some(false)) - .await - .expect("state db lookup should succeed"); - assert_eq!(stored_path, None); - Ok(()) - } - - #[tokio::test] - async fn list_threads_db_enabled_repairs_stale_rollout_paths() -> std::io::Result<()> { - let home = TempDir::new().expect("temp dir"); - let mut config = ConfigBuilder::default() - .codex_home(home.path().to_path_buf()) - .build() - .await?; - config - .features - .enable(Feature::Sqlite) - .expect("test config should allow sqlite"); - - let uuid = Uuid::from_u128(9011); - let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); - let real_path = write_session_file(home.path(), "2025-01-03T13-00-00", uuid)?; - let stale_path = home.path().join(format!( - "sessions/2099/01/01/rollout-2099-01-01T00-00-00-{uuid}.jsonl" - )); - - let runtime = codex_state::StateRuntime::init( - home.path().to_path_buf(), - config.model_provider_id.clone(), - ) - .await - .expect("state db should initialize"); - runtime - .mark_backfill_complete(None) - .await - .expect("backfill should be complete"); - let created_at = chrono::Utc - .with_ymd_and_hms(2025, 1, 3, 13, 0, 0) - .single() - .expect("valid datetime"); - let mut builder = codex_state::ThreadMetadataBuilder::new( - thread_id, - stale_path, - created_at, - SessionSource::Cli, - ); - builder.model_provider = Some(config.model_provider_id.clone()); - builder.cwd = home.path().to_path_buf(); - let mut metadata = builder.build(config.model_provider_id.as_str()); - metadata.first_user_message = Some("Hello from user".to_string()); - runtime - .upsert_thread(&metadata) - .await - .expect("state db upsert should succeed"); - - let default_provider = config.model_provider_id.clone(); - let page = RolloutRecorder::list_threads( - &config, - 1, - None, - ThreadSortKey::CreatedAt, - &[], - None, - default_provider.as_str(), - None, - ) - .await?; - assert_eq!(page.items.len(), 1); - assert_eq!(page.items[0].path, real_path); - - let repaired_path = runtime - .find_rollout_path_by_id(thread_id, Some(false)) - .await - .expect("state db lookup should succeed"); - assert_eq!(repaired_path, Some(real_path)); - Ok(()) - } - - #[tokio::test] - async fn resume_candidate_matches_cwd_reads_latest_turn_context() -> std::io::Result<()> { - let home = TempDir::new().expect("temp dir"); - let stale_cwd = home.path().join("stale"); - let latest_cwd = home.path().join("latest"); - fs::create_dir_all(&stale_cwd)?; - fs::create_dir_all(&latest_cwd)?; - - let path = write_session_file(home.path(), "2025-01-03T13-00-00", Uuid::from_u128(9012))?; - let mut file = std::fs::OpenOptions::new().append(true).open(&path)?; - let turn_context = RolloutLine { - timestamp: "2025-01-03T13:00:01Z".to_string(), - item: RolloutItem::TurnContext(TurnContextItem { - turn_id: Some("turn-1".to_string()), - trace_id: None, - cwd: latest_cwd.clone(), - current_date: None, - timezone: None, - approval_policy: AskForApproval::Never, - sandbox_policy: SandboxPolicy::new_read_only_policy(), - network: None, - model: "test-model".to_string(), - personality: None, - collaboration_mode: None, - realtime_active: None, - effort: None, - summary: ReasoningSummaryConfig::Auto, - user_instructions: None, - developer_instructions: None, - final_output_json_schema: None, - truncation_policy: None, - }), - }; - writeln!(file, "{}", serde_json::to_string(&turn_context)?)?; - - assert!( - resume_candidate_matches_cwd( - path.as_path(), - Some(stale_cwd.as_path()), - latest_cwd.as_path(), - "test-provider", - ) - .await - ); - Ok(()) - } -} +#[path = "recorder_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/rollout/recorder_tests.rs b/codex-rs/core/src/rollout/recorder_tests.rs new file mode 100644 index 0000000000..f6f588574a --- /dev/null +++ b/codex-rs/core/src/rollout/recorder_tests.rs @@ -0,0 +1,509 @@ +use super::*; +use crate::config::ConfigBuilder; +use crate::features::Feature; +use chrono::TimeZone; +use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig; +use codex_protocol::protocol::AgentMessageEvent; +use codex_protocol::protocol::AskForApproval; +use codex_protocol::protocol::EventMsg; +use codex_protocol::protocol::SandboxPolicy; +use codex_protocol::protocol::TurnContextItem; +use codex_protocol::protocol::UserMessageEvent; +use pretty_assertions::assert_eq; +use std::fs::File; +use std::fs::{self}; +use std::io::Write; +use std::path::Path; +use std::path::PathBuf; +use std::time::Duration; +use tempfile::TempDir; +use uuid::Uuid; + +fn write_session_file(root: &Path, ts: &str, uuid: Uuid) -> std::io::Result { + let day_dir = root.join("sessions/2025/01/03"); + fs::create_dir_all(&day_dir)?; + let path = day_dir.join(format!("rollout-{ts}-{uuid}.jsonl")); + let mut file = File::create(&path)?; + let meta = serde_json::json!({ + "timestamp": ts, + "type": "session_meta", + "payload": { + "id": uuid, + "timestamp": ts, + "cwd": ".", + "originator": "test_originator", + "cli_version": "test_version", + "source": "cli", + "model_provider": "test-provider", + }, + }); + writeln!(file, "{meta}")?; + let user_event = serde_json::json!({ + "timestamp": ts, + "type": "event_msg", + "payload": { + "type": "user_message", + "message": "Hello from user", + "kind": "plain", + }, + }); + writeln!(file, "{user_event}")?; + Ok(path) +} + +#[tokio::test] +async fn recorder_materializes_only_after_explicit_persist() -> std::io::Result<()> { + let home = TempDir::new().expect("temp dir"); + let config = ConfigBuilder::default() + .codex_home(home.path().to_path_buf()) + .build() + .await?; + let thread_id = ThreadId::new(); + let recorder = RolloutRecorder::new( + &config, + RolloutRecorderParams::new( + thread_id, + None, + SessionSource::Exec, + BaseInstructions::default(), + Vec::new(), + EventPersistenceMode::Limited, + ), + None, + None, + ) + .await?; + + let rollout_path = recorder.rollout_path().to_path_buf(); + assert!( + !rollout_path.exists(), + "rollout file should not exist before first user message" + ); + + recorder + .record_items(&[RolloutItem::EventMsg(EventMsg::AgentMessage( + AgentMessageEvent { + message: "buffered-event".to_string(), + phase: None, + }, + ))]) + .await?; + recorder.flush().await?; + assert!( + !rollout_path.exists(), + "rollout file should remain deferred before first user message" + ); + + recorder + .record_items(&[RolloutItem::EventMsg(EventMsg::UserMessage( + UserMessageEvent { + message: "first-user-message".to_string(), + images: None, + local_images: Vec::new(), + text_elements: Vec::new(), + }, + ))]) + .await?; + recorder.flush().await?; + assert!( + !rollout_path.exists(), + "user-message-like items should not materialize without explicit persist" + ); + + recorder.persist().await?; + // Second call verifies `persist()` is idempotent after materialization. + recorder.persist().await?; + assert!(rollout_path.exists(), "rollout file should be materialized"); + + let text = std::fs::read_to_string(&rollout_path)?; + assert!( + text.contains("\"type\":\"session_meta\""), + "expected session metadata in rollout" + ); + let buffered_idx = text + .find("buffered-event") + .expect("buffered event in rollout"); + let user_idx = text + .find("first-user-message") + .expect("first user message in rollout"); + assert!( + buffered_idx < user_idx, + "buffered items should preserve ordering" + ); + let text_after_second_persist = std::fs::read_to_string(&rollout_path)?; + assert_eq!(text_after_second_persist, text); + + recorder.shutdown().await?; + Ok(()) +} + +#[tokio::test] +async fn metadata_irrelevant_events_touch_state_db_updated_at() -> std::io::Result<()> { + let home = TempDir::new().expect("temp dir"); + let mut config = ConfigBuilder::default() + .codex_home(home.path().to_path_buf()) + .build() + .await?; + config + .features + .enable(Feature::Sqlite) + .expect("test config should allow sqlite"); + + let state_db = StateRuntime::init(home.path().to_path_buf(), config.model_provider_id.clone()) + .await + .expect("state db should initialize"); + state_db + .mark_backfill_complete(None) + .await + .expect("backfill should be complete"); + + let thread_id = ThreadId::new(); + let recorder = RolloutRecorder::new( + &config, + RolloutRecorderParams::new( + thread_id, + None, + SessionSource::Cli, + BaseInstructions::default(), + Vec::new(), + EventPersistenceMode::Limited, + ), + Some(state_db.clone()), + None, + ) + .await?; + + recorder + .record_items(&[RolloutItem::EventMsg(EventMsg::UserMessage( + UserMessageEvent { + message: "first-user-message".to_string(), + images: None, + local_images: Vec::new(), + text_elements: Vec::new(), + }, + ))]) + .await?; + recorder.persist().await?; + recorder.flush().await?; + let initial_thread = state_db + .get_thread(thread_id) + .await + .expect("thread should load") + .expect("thread should exist"); + let initial_updated_at = initial_thread.updated_at; + let initial_title = initial_thread.title.clone(); + let initial_first_user_message = initial_thread.first_user_message.clone(); + + tokio::time::sleep(Duration::from_secs(1)).await; + + recorder + .record_items(&[RolloutItem::EventMsg(EventMsg::AgentMessage( + AgentMessageEvent { + message: "assistant text".to_string(), + phase: None, + }, + ))]) + .await?; + recorder.flush().await?; + + let updated_thread = state_db + .get_thread(thread_id) + .await + .expect("thread should load after agent message") + .expect("thread should still exist"); + + assert!(updated_thread.updated_at > initial_updated_at); + assert_eq!(updated_thread.title, initial_title); + assert_eq!( + updated_thread.first_user_message, + initial_first_user_message + ); + + recorder.shutdown().await?; + Ok(()) +} + +#[tokio::test] +async fn metadata_irrelevant_events_fall_back_to_upsert_when_thread_missing() -> std::io::Result<()> +{ + let home = TempDir::new().expect("temp dir"); + let mut config = ConfigBuilder::default() + .codex_home(home.path().to_path_buf()) + .build() + .await?; + config + .features + .enable(Feature::Sqlite) + .expect("test config should allow sqlite"); + + let state_db = StateRuntime::init(home.path().to_path_buf(), config.model_provider_id.clone()) + .await + .expect("state db should initialize"); + let thread_id = ThreadId::new(); + let rollout_path = home.path().join("rollout.jsonl"); + let builder = ThreadMetadataBuilder::new( + thread_id, + rollout_path.clone(), + Utc::now(), + SessionSource::Cli, + ); + let items = vec![RolloutItem::EventMsg(EventMsg::AgentMessage( + AgentMessageEvent { + message: "assistant text".to_string(), + phase: None, + }, + ))]; + + sync_thread_state_after_write( + Some(state_db.as_ref()), + rollout_path.as_path(), + Some(&builder), + items.as_slice(), + config.model_provider_id.as_str(), + None, + ) + .await; + + let thread = state_db + .get_thread(thread_id) + .await + .expect("thread should load after fallback") + .expect("thread should be inserted after fallback"); + assert_eq!(thread.id, thread_id); + + Ok(()) +} + +#[tokio::test] +async fn list_threads_db_disabled_does_not_skip_paginated_items() -> std::io::Result<()> { + let home = TempDir::new().expect("temp dir"); + let mut config = ConfigBuilder::default() + .codex_home(home.path().to_path_buf()) + .build() + .await?; + config + .features + .disable(Feature::Sqlite) + .expect("test config should allow sqlite to be disabled"); + + let newest = write_session_file(home.path(), "2025-01-03T12-00-00", Uuid::from_u128(9001))?; + let middle = write_session_file(home.path(), "2025-01-02T12-00-00", Uuid::from_u128(9002))?; + let _oldest = write_session_file(home.path(), "2025-01-01T12-00-00", Uuid::from_u128(9003))?; + + let default_provider = config.model_provider_id.clone(); + let page1 = RolloutRecorder::list_threads( + &config, + 1, + None, + ThreadSortKey::CreatedAt, + &[], + None, + default_provider.as_str(), + None, + ) + .await?; + assert_eq!(page1.items.len(), 1); + assert_eq!(page1.items[0].path, newest); + let cursor = page1.next_cursor.clone().expect("cursor should be present"); + + let page2 = RolloutRecorder::list_threads( + &config, + 1, + Some(&cursor), + ThreadSortKey::CreatedAt, + &[], + None, + default_provider.as_str(), + None, + ) + .await?; + assert_eq!(page2.items.len(), 1); + assert_eq!(page2.items[0].path, middle); + Ok(()) +} + +#[tokio::test] +async fn list_threads_db_enabled_drops_missing_rollout_paths() -> std::io::Result<()> { + let home = TempDir::new().expect("temp dir"); + let mut config = ConfigBuilder::default() + .codex_home(home.path().to_path_buf()) + .build() + .await?; + config + .features + .enable(Feature::Sqlite) + .expect("test config should allow sqlite"); + + let uuid = Uuid::from_u128(9010); + let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); + let stale_path = home.path().join(format!( + "sessions/2099/01/01/rollout-2099-01-01T00-00-00-{uuid}.jsonl" + )); + + let runtime = codex_state::StateRuntime::init( + home.path().to_path_buf(), + config.model_provider_id.clone(), + ) + .await + .expect("state db should initialize"); + runtime + .mark_backfill_complete(None) + .await + .expect("backfill should be complete"); + let created_at = chrono::Utc + .with_ymd_and_hms(2025, 1, 3, 13, 0, 0) + .single() + .expect("valid datetime"); + let mut builder = codex_state::ThreadMetadataBuilder::new( + thread_id, + stale_path, + created_at, + SessionSource::Cli, + ); + builder.model_provider = Some(config.model_provider_id.clone()); + builder.cwd = home.path().to_path_buf(); + let mut metadata = builder.build(config.model_provider_id.as_str()); + metadata.first_user_message = Some("Hello from user".to_string()); + runtime + .upsert_thread(&metadata) + .await + .expect("state db upsert should succeed"); + + let default_provider = config.model_provider_id.clone(); + let page = RolloutRecorder::list_threads( + &config, + 10, + None, + ThreadSortKey::CreatedAt, + &[], + None, + default_provider.as_str(), + None, + ) + .await?; + assert_eq!(page.items.len(), 0); + let stored_path = runtime + .find_rollout_path_by_id(thread_id, Some(false)) + .await + .expect("state db lookup should succeed"); + assert_eq!(stored_path, None); + Ok(()) +} + +#[tokio::test] +async fn list_threads_db_enabled_repairs_stale_rollout_paths() -> std::io::Result<()> { + let home = TempDir::new().expect("temp dir"); + let mut config = ConfigBuilder::default() + .codex_home(home.path().to_path_buf()) + .build() + .await?; + config + .features + .enable(Feature::Sqlite) + .expect("test config should allow sqlite"); + + let uuid = Uuid::from_u128(9011); + let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id"); + let real_path = write_session_file(home.path(), "2025-01-03T13-00-00", uuid)?; + let stale_path = home.path().join(format!( + "sessions/2099/01/01/rollout-2099-01-01T00-00-00-{uuid}.jsonl" + )); + + let runtime = codex_state::StateRuntime::init( + home.path().to_path_buf(), + config.model_provider_id.clone(), + ) + .await + .expect("state db should initialize"); + runtime + .mark_backfill_complete(None) + .await + .expect("backfill should be complete"); + let created_at = chrono::Utc + .with_ymd_and_hms(2025, 1, 3, 13, 0, 0) + .single() + .expect("valid datetime"); + let mut builder = codex_state::ThreadMetadataBuilder::new( + thread_id, + stale_path, + created_at, + SessionSource::Cli, + ); + builder.model_provider = Some(config.model_provider_id.clone()); + builder.cwd = home.path().to_path_buf(); + let mut metadata = builder.build(config.model_provider_id.as_str()); + metadata.first_user_message = Some("Hello from user".to_string()); + runtime + .upsert_thread(&metadata) + .await + .expect("state db upsert should succeed"); + + let default_provider = config.model_provider_id.clone(); + let page = RolloutRecorder::list_threads( + &config, + 1, + None, + ThreadSortKey::CreatedAt, + &[], + None, + default_provider.as_str(), + None, + ) + .await?; + assert_eq!(page.items.len(), 1); + assert_eq!(page.items[0].path, real_path); + + let repaired_path = runtime + .find_rollout_path_by_id(thread_id, Some(false)) + .await + .expect("state db lookup should succeed"); + assert_eq!(repaired_path, Some(real_path)); + Ok(()) +} + +#[tokio::test] +async fn resume_candidate_matches_cwd_reads_latest_turn_context() -> std::io::Result<()> { + let home = TempDir::new().expect("temp dir"); + let stale_cwd = home.path().join("stale"); + let latest_cwd = home.path().join("latest"); + fs::create_dir_all(&stale_cwd)?; + fs::create_dir_all(&latest_cwd)?; + + let path = write_session_file(home.path(), "2025-01-03T13-00-00", Uuid::from_u128(9012))?; + let mut file = std::fs::OpenOptions::new().append(true).open(&path)?; + let turn_context = RolloutLine { + timestamp: "2025-01-03T13:00:01Z".to_string(), + item: RolloutItem::TurnContext(TurnContextItem { + turn_id: Some("turn-1".to_string()), + trace_id: None, + cwd: latest_cwd.clone(), + current_date: None, + timezone: None, + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::new_read_only_policy(), + network: None, + model: "test-model".to_string(), + personality: None, + collaboration_mode: None, + realtime_active: None, + effort: None, + summary: ReasoningSummaryConfig::Auto, + user_instructions: None, + developer_instructions: None, + final_output_json_schema: None, + truncation_policy: None, + }), + }; + writeln!(file, "{}", serde_json::to_string(&turn_context)?)?; + + assert!( + resume_candidate_matches_cwd( + path.as_path(), + Some(stale_cwd.as_path()), + latest_cwd.as_path(), + "test-provider", + ) + .await + ); + Ok(()) +} diff --git a/codex-rs/core/src/rollout/session_index.rs b/codex-rs/core/src/rollout/session_index.rs index c546dca331..8c88dd39ad 100644 --- a/codex-rs/core/src/rollout/session_index.rs +++ b/codex-rs/core/src/rollout/session_index.rs @@ -229,172 +229,5 @@ where } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - use std::collections::HashMap; - use std::collections::HashSet; - use tempfile::TempDir; - fn write_index(path: &Path, lines: &[SessionIndexEntry]) -> std::io::Result<()> { - let mut out = String::new(); - for entry in lines { - out.push_str(&serde_json::to_string(entry).unwrap()); - out.push('\n'); - } - std::fs::write(path, out) - } - - #[test] - fn find_thread_id_by_name_prefers_latest_entry() -> std::io::Result<()> { - let temp = TempDir::new()?; - let path = session_index_path(temp.path()); - let id1 = ThreadId::new(); - let id2 = ThreadId::new(); - let lines = vec![ - SessionIndexEntry { - id: id1, - thread_name: "same".to_string(), - updated_at: "2024-01-01T00:00:00Z".to_string(), - }, - SessionIndexEntry { - id: id2, - thread_name: "same".to_string(), - updated_at: "2024-01-02T00:00:00Z".to_string(), - }, - ]; - write_index(&path, &lines)?; - - let found = scan_index_from_end_by_name(&path, "same")?; - assert_eq!(found.map(|entry| entry.id), Some(id2)); - Ok(()) - } - - #[test] - fn find_thread_name_by_id_prefers_latest_entry() -> std::io::Result<()> { - let temp = TempDir::new()?; - let path = session_index_path(temp.path()); - let id = ThreadId::new(); - let lines = vec![ - SessionIndexEntry { - id, - thread_name: "first".to_string(), - updated_at: "2024-01-01T00:00:00Z".to_string(), - }, - SessionIndexEntry { - id, - thread_name: "second".to_string(), - updated_at: "2024-01-02T00:00:00Z".to_string(), - }, - ]; - write_index(&path, &lines)?; - - let found = scan_index_from_end_by_id(&path, &id)?; - assert_eq!( - found.map(|entry| entry.thread_name), - Some("second".to_string()) - ); - Ok(()) - } - - #[test] - fn scan_index_returns_none_when_entry_missing() -> std::io::Result<()> { - let temp = TempDir::new()?; - let path = session_index_path(temp.path()); - let id = ThreadId::new(); - let lines = vec![SessionIndexEntry { - id, - thread_name: "present".to_string(), - updated_at: "2024-01-01T00:00:00Z".to_string(), - }]; - write_index(&path, &lines)?; - - let missing_name = scan_index_from_end_by_name(&path, "missing")?; - assert_eq!(missing_name, None); - - let missing_id = scan_index_from_end_by_id(&path, &ThreadId::new())?; - assert_eq!(missing_id, None); - Ok(()) - } - - #[tokio::test] - async fn find_thread_names_by_ids_prefers_latest_entry() -> std::io::Result<()> { - let temp = TempDir::new()?; - let path = session_index_path(temp.path()); - let id1 = ThreadId::new(); - let id2 = ThreadId::new(); - let lines = vec![ - SessionIndexEntry { - id: id1, - thread_name: "first".to_string(), - updated_at: "2024-01-01T00:00:00Z".to_string(), - }, - SessionIndexEntry { - id: id2, - thread_name: "other".to_string(), - updated_at: "2024-01-01T00:00:00Z".to_string(), - }, - SessionIndexEntry { - id: id1, - thread_name: "latest".to_string(), - updated_at: "2024-01-02T00:00:00Z".to_string(), - }, - ]; - write_index(&path, &lines)?; - - let mut ids = HashSet::new(); - ids.insert(id1); - ids.insert(id2); - - let mut expected = HashMap::new(); - expected.insert(id1, "latest".to_string()); - expected.insert(id2, "other".to_string()); - - let found = find_thread_names_by_ids(temp.path(), &ids).await?; - assert_eq!(found, expected); - Ok(()) - } - - #[test] - fn scan_index_finds_latest_match_among_mixed_entries() -> std::io::Result<()> { - let temp = TempDir::new()?; - let path = session_index_path(temp.path()); - let id_target = ThreadId::new(); - let id_other = ThreadId::new(); - let expected = SessionIndexEntry { - id: id_target, - thread_name: "target".to_string(), - updated_at: "2024-01-03T00:00:00Z".to_string(), - }; - let expected_other = SessionIndexEntry { - id: id_other, - thread_name: "target".to_string(), - updated_at: "2024-01-02T00:00:00Z".to_string(), - }; - // Resolution is based on append order (scan from end), not updated_at. - let lines = vec![ - SessionIndexEntry { - id: id_target, - thread_name: "target".to_string(), - updated_at: "2024-01-01T00:00:00Z".to_string(), - }, - expected_other.clone(), - expected.clone(), - SessionIndexEntry { - id: ThreadId::new(), - thread_name: "another".to_string(), - updated_at: "2024-01-04T00:00:00Z".to_string(), - }, - ]; - write_index(&path, &lines)?; - - let found_by_name = scan_index_from_end_by_name(&path, "target")?; - assert_eq!(found_by_name, Some(expected.clone())); - - let found_by_id = scan_index_from_end_by_id(&path, &id_target)?; - assert_eq!(found_by_id, Some(expected)); - - let found_other_by_id = scan_index_from_end_by_id(&path, &id_other)?; - assert_eq!(found_other_by_id, Some(expected_other)); - Ok(()) - } -} +#[path = "session_index_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/rollout/session_index_tests.rs b/codex-rs/core/src/rollout/session_index_tests.rs new file mode 100644 index 0000000000..864c4c5cf8 --- /dev/null +++ b/codex-rs/core/src/rollout/session_index_tests.rs @@ -0,0 +1,167 @@ +use super::*; +use pretty_assertions::assert_eq; +use std::collections::HashMap; +use std::collections::HashSet; +use tempfile::TempDir; +fn write_index(path: &Path, lines: &[SessionIndexEntry]) -> std::io::Result<()> { + let mut out = String::new(); + for entry in lines { + out.push_str(&serde_json::to_string(entry).unwrap()); + out.push('\n'); + } + std::fs::write(path, out) +} + +#[test] +fn find_thread_id_by_name_prefers_latest_entry() -> std::io::Result<()> { + let temp = TempDir::new()?; + let path = session_index_path(temp.path()); + let id1 = ThreadId::new(); + let id2 = ThreadId::new(); + let lines = vec![ + SessionIndexEntry { + id: id1, + thread_name: "same".to_string(), + updated_at: "2024-01-01T00:00:00Z".to_string(), + }, + SessionIndexEntry { + id: id2, + thread_name: "same".to_string(), + updated_at: "2024-01-02T00:00:00Z".to_string(), + }, + ]; + write_index(&path, &lines)?; + + let found = scan_index_from_end_by_name(&path, "same")?; + assert_eq!(found.map(|entry| entry.id), Some(id2)); + Ok(()) +} + +#[test] +fn find_thread_name_by_id_prefers_latest_entry() -> std::io::Result<()> { + let temp = TempDir::new()?; + let path = session_index_path(temp.path()); + let id = ThreadId::new(); + let lines = vec![ + SessionIndexEntry { + id, + thread_name: "first".to_string(), + updated_at: "2024-01-01T00:00:00Z".to_string(), + }, + SessionIndexEntry { + id, + thread_name: "second".to_string(), + updated_at: "2024-01-02T00:00:00Z".to_string(), + }, + ]; + write_index(&path, &lines)?; + + let found = scan_index_from_end_by_id(&path, &id)?; + assert_eq!( + found.map(|entry| entry.thread_name), + Some("second".to_string()) + ); + Ok(()) +} + +#[test] +fn scan_index_returns_none_when_entry_missing() -> std::io::Result<()> { + let temp = TempDir::new()?; + let path = session_index_path(temp.path()); + let id = ThreadId::new(); + let lines = vec![SessionIndexEntry { + id, + thread_name: "present".to_string(), + updated_at: "2024-01-01T00:00:00Z".to_string(), + }]; + write_index(&path, &lines)?; + + let missing_name = scan_index_from_end_by_name(&path, "missing")?; + assert_eq!(missing_name, None); + + let missing_id = scan_index_from_end_by_id(&path, &ThreadId::new())?; + assert_eq!(missing_id, None); + Ok(()) +} + +#[tokio::test] +async fn find_thread_names_by_ids_prefers_latest_entry() -> std::io::Result<()> { + let temp = TempDir::new()?; + let path = session_index_path(temp.path()); + let id1 = ThreadId::new(); + let id2 = ThreadId::new(); + let lines = vec![ + SessionIndexEntry { + id: id1, + thread_name: "first".to_string(), + updated_at: "2024-01-01T00:00:00Z".to_string(), + }, + SessionIndexEntry { + id: id2, + thread_name: "other".to_string(), + updated_at: "2024-01-01T00:00:00Z".to_string(), + }, + SessionIndexEntry { + id: id1, + thread_name: "latest".to_string(), + updated_at: "2024-01-02T00:00:00Z".to_string(), + }, + ]; + write_index(&path, &lines)?; + + let mut ids = HashSet::new(); + ids.insert(id1); + ids.insert(id2); + + let mut expected = HashMap::new(); + expected.insert(id1, "latest".to_string()); + expected.insert(id2, "other".to_string()); + + let found = find_thread_names_by_ids(temp.path(), &ids).await?; + assert_eq!(found, expected); + Ok(()) +} + +#[test] +fn scan_index_finds_latest_match_among_mixed_entries() -> std::io::Result<()> { + let temp = TempDir::new()?; + let path = session_index_path(temp.path()); + let id_target = ThreadId::new(); + let id_other = ThreadId::new(); + let expected = SessionIndexEntry { + id: id_target, + thread_name: "target".to_string(), + updated_at: "2024-01-03T00:00:00Z".to_string(), + }; + let expected_other = SessionIndexEntry { + id: id_other, + thread_name: "target".to_string(), + updated_at: "2024-01-02T00:00:00Z".to_string(), + }; + // Resolution is based on append order (scan from end), not updated_at. + let lines = vec![ + SessionIndexEntry { + id: id_target, + thread_name: "target".to_string(), + updated_at: "2024-01-01T00:00:00Z".to_string(), + }, + expected_other.clone(), + expected.clone(), + SessionIndexEntry { + id: ThreadId::new(), + thread_name: "another".to_string(), + updated_at: "2024-01-04T00:00:00Z".to_string(), + }, + ]; + write_index(&path, &lines)?; + + let found_by_name = scan_index_from_end_by_name(&path, "target")?; + assert_eq!(found_by_name, Some(expected.clone())); + + let found_by_id = scan_index_from_end_by_id(&path, &id_target)?; + assert_eq!(found_by_id, Some(expected)); + + let found_other_by_id = scan_index_from_end_by_id(&path, &id_other)?; + assert_eq!(found_other_by_id, Some(expected_other)); + Ok(()) +} diff --git a/codex-rs/core/src/rollout/truncation.rs b/codex-rs/core/src/rollout/truncation.rs index 6aacc43946..490bf42b97 100644 --- a/codex-rs/core/src/rollout/truncation.rs +++ b/codex-rs/core/src/rollout/truncation.rs @@ -69,154 +69,5 @@ pub(crate) fn truncate_rollout_before_nth_user_message_from_start( } #[cfg(test)] -mod tests { - use super::*; - use crate::codex::make_session_and_context; - use assert_matches::assert_matches; - use codex_protocol::models::ContentItem; - use codex_protocol::models::ReasoningItemReasoningSummary; - use codex_protocol::protocol::ThreadRolledBackEvent; - use pretty_assertions::assert_eq; - - fn user_msg(text: &str) -> ResponseItem { - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::OutputText { - text: text.to_string(), - }], - end_turn: None, - phase: None, - } - } - - fn assistant_msg(text: &str) -> ResponseItem { - ResponseItem::Message { - id: None, - role: "assistant".to_string(), - content: vec![ContentItem::OutputText { - text: text.to_string(), - }], - end_turn: None, - phase: None, - } - } - - #[test] - fn truncates_rollout_from_start_before_nth_user_only() { - let items = [ - user_msg("u1"), - assistant_msg("a1"), - assistant_msg("a2"), - user_msg("u2"), - assistant_msg("a3"), - ResponseItem::Reasoning { - id: "r1".to_string(), - summary: vec![ReasoningItemReasoningSummary::SummaryText { - text: "s".to_string(), - }], - content: None, - encrypted_content: None, - }, - ResponseItem::FunctionCall { - id: None, - call_id: "c1".to_string(), - name: "tool".to_string(), - namespace: None, - arguments: "{}".to_string(), - }, - assistant_msg("a4"), - ]; - - let rollout: Vec = items - .iter() - .cloned() - .map(RolloutItem::ResponseItem) - .collect(); - - let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout, 1); - let expected = vec![ - RolloutItem::ResponseItem(items[0].clone()), - RolloutItem::ResponseItem(items[1].clone()), - RolloutItem::ResponseItem(items[2].clone()), - ]; - assert_eq!( - serde_json::to_value(&truncated).unwrap(), - serde_json::to_value(&expected).unwrap() - ); - - let truncated2 = truncate_rollout_before_nth_user_message_from_start(&rollout, 2); - assert_matches!(truncated2.as_slice(), []); - } - - #[test] - fn truncation_max_keeps_full_rollout() { - let rollout = vec![ - RolloutItem::ResponseItem(user_msg("u1")), - RolloutItem::ResponseItem(assistant_msg("a1")), - RolloutItem::ResponseItem(user_msg("u2")), - ]; - - let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout, usize::MAX); - - assert_eq!( - serde_json::to_value(&truncated).unwrap(), - serde_json::to_value(&rollout).unwrap() - ); - } - - #[test] - fn truncates_rollout_from_start_applies_thread_rollback_markers() { - let rollout_items = vec![ - RolloutItem::ResponseItem(user_msg("u1")), - RolloutItem::ResponseItem(assistant_msg("a1")), - RolloutItem::ResponseItem(user_msg("u2")), - RolloutItem::ResponseItem(assistant_msg("a2")), - RolloutItem::EventMsg(EventMsg::ThreadRolledBack(ThreadRolledBackEvent { - num_turns: 1, - })), - RolloutItem::ResponseItem(user_msg("u3")), - RolloutItem::ResponseItem(assistant_msg("a3")), - RolloutItem::ResponseItem(user_msg("u4")), - RolloutItem::ResponseItem(assistant_msg("a4")), - ]; - - // Effective user history after applying rollback(1) is: u1, u3, u4. - // So n_from_start=2 should cut before u4 (not u3). - let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout_items, 2); - let expected = rollout_items[..7].to_vec(); - assert_eq!( - serde_json::to_value(&truncated).unwrap(), - serde_json::to_value(&expected).unwrap() - ); - } - - #[tokio::test] - async fn ignores_session_prefix_messages_when_truncating_rollout_from_start() { - let (session, turn_context) = make_session_and_context().await; - let mut items = session.build_initial_context(&turn_context).await; - items.push(user_msg("feature request")); - items.push(assistant_msg("ack")); - items.push(user_msg("second question")); - items.push(assistant_msg("answer")); - - let rollout_items: Vec = items - .iter() - .cloned() - .map(RolloutItem::ResponseItem) - .collect(); - - let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout_items, 1); - let expected: Vec = vec![ - RolloutItem::ResponseItem(items[0].clone()), - RolloutItem::ResponseItem(items[1].clone()), - RolloutItem::ResponseItem(items[2].clone()), - RolloutItem::ResponseItem(items[3].clone()), - ]; - - assert_eq!( - serde_json::to_value(&truncated).unwrap(), - serde_json::to_value(&expected).unwrap() - ); - } -} +#[path = "truncation_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/rollout/truncation_tests.rs b/codex-rs/core/src/rollout/truncation_tests.rs new file mode 100644 index 0000000000..f7dd206264 --- /dev/null +++ b/codex-rs/core/src/rollout/truncation_tests.rs @@ -0,0 +1,149 @@ +use super::*; +use crate::codex::make_session_and_context; +use assert_matches::assert_matches; +use codex_protocol::models::ContentItem; +use codex_protocol::models::ReasoningItemReasoningSummary; +use codex_protocol::protocol::ThreadRolledBackEvent; +use pretty_assertions::assert_eq; + +fn user_msg(text: &str) -> ResponseItem { + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::OutputText { + text: text.to_string(), + }], + end_turn: None, + phase: None, + } +} + +fn assistant_msg(text: &str) -> ResponseItem { + ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: text.to_string(), + }], + end_turn: None, + phase: None, + } +} + +#[test] +fn truncates_rollout_from_start_before_nth_user_only() { + let items = [ + user_msg("u1"), + assistant_msg("a1"), + assistant_msg("a2"), + user_msg("u2"), + assistant_msg("a3"), + ResponseItem::Reasoning { + id: "r1".to_string(), + summary: vec![ReasoningItemReasoningSummary::SummaryText { + text: "s".to_string(), + }], + content: None, + encrypted_content: None, + }, + ResponseItem::FunctionCall { + id: None, + call_id: "c1".to_string(), + name: "tool".to_string(), + namespace: None, + arguments: "{}".to_string(), + }, + assistant_msg("a4"), + ]; + + let rollout: Vec = items + .iter() + .cloned() + .map(RolloutItem::ResponseItem) + .collect(); + + let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout, 1); + let expected = vec![ + RolloutItem::ResponseItem(items[0].clone()), + RolloutItem::ResponseItem(items[1].clone()), + RolloutItem::ResponseItem(items[2].clone()), + ]; + assert_eq!( + serde_json::to_value(&truncated).unwrap(), + serde_json::to_value(&expected).unwrap() + ); + + let truncated2 = truncate_rollout_before_nth_user_message_from_start(&rollout, 2); + assert_matches!(truncated2.as_slice(), []); +} + +#[test] +fn truncation_max_keeps_full_rollout() { + let rollout = vec![ + RolloutItem::ResponseItem(user_msg("u1")), + RolloutItem::ResponseItem(assistant_msg("a1")), + RolloutItem::ResponseItem(user_msg("u2")), + ]; + + let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout, usize::MAX); + + assert_eq!( + serde_json::to_value(&truncated).unwrap(), + serde_json::to_value(&rollout).unwrap() + ); +} + +#[test] +fn truncates_rollout_from_start_applies_thread_rollback_markers() { + let rollout_items = vec![ + RolloutItem::ResponseItem(user_msg("u1")), + RolloutItem::ResponseItem(assistant_msg("a1")), + RolloutItem::ResponseItem(user_msg("u2")), + RolloutItem::ResponseItem(assistant_msg("a2")), + RolloutItem::EventMsg(EventMsg::ThreadRolledBack(ThreadRolledBackEvent { + num_turns: 1, + })), + RolloutItem::ResponseItem(user_msg("u3")), + RolloutItem::ResponseItem(assistant_msg("a3")), + RolloutItem::ResponseItem(user_msg("u4")), + RolloutItem::ResponseItem(assistant_msg("a4")), + ]; + + // Effective user history after applying rollback(1) is: u1, u3, u4. + // So n_from_start=2 should cut before u4 (not u3). + let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout_items, 2); + let expected = rollout_items[..7].to_vec(); + assert_eq!( + serde_json::to_value(&truncated).unwrap(), + serde_json::to_value(&expected).unwrap() + ); +} + +#[tokio::test] +async fn ignores_session_prefix_messages_when_truncating_rollout_from_start() { + let (session, turn_context) = make_session_and_context().await; + let mut items = session.build_initial_context(&turn_context).await; + items.push(user_msg("feature request")); + items.push(assistant_msg("ack")); + items.push(user_msg("second question")); + items.push(assistant_msg("answer")); + + let rollout_items: Vec = items + .iter() + .cloned() + .map(RolloutItem::ResponseItem) + .collect(); + + let truncated = truncate_rollout_before_nth_user_message_from_start(&rollout_items, 1); + let expected: Vec = vec![ + RolloutItem::ResponseItem(items[0].clone()), + RolloutItem::ResponseItem(items[1].clone()), + RolloutItem::ResponseItem(items[2].clone()), + RolloutItem::ResponseItem(items[3].clone()), + ]; + + assert_eq!( + serde_json::to_value(&truncated).unwrap(), + serde_json::to_value(&expected).unwrap() + ); +} diff --git a/codex-rs/core/src/safety.rs b/codex-rs/core/src/safety.rs index d9b5368fc9..1fdd51a1b9 100644 --- a/codex-rs/core/src/safety.rs +++ b/codex-rs/core/src/safety.rs @@ -180,259 +180,5 @@ fn is_write_patch_constrained_to_writable_paths( } #[cfg(test)] -mod tests { - use super::*; - use codex_protocol::protocol::FileSystemAccessMode; - use codex_protocol::protocol::FileSystemPath; - use codex_protocol::protocol::FileSystemSandboxEntry; - use codex_protocol::protocol::FileSystemSpecialPath; - use codex_protocol::protocol::RejectConfig; - use codex_utils_absolute_path::AbsolutePathBuf; - use pretty_assertions::assert_eq; - use tempfile::TempDir; - - #[test] - fn test_writable_roots_constraint() { - // Use a temporary directory as our workspace to avoid touching - // the real current working directory. - let tmp = TempDir::new().unwrap(); - let cwd = tmp.path().to_path_buf(); - let parent = cwd.parent().unwrap().to_path_buf(); - - // Helper to build a single‑entry patch that adds a file at `p`. - let make_add_change = |p: PathBuf| ApplyPatchAction::new_add_for_test(&p, "".to_string()); - - let add_inside = make_add_change(cwd.join("inner.txt")); - let add_outside = make_add_change(parent.join("outside.txt")); - - // Policy limited to the workspace only; exclude system temp roots so - // only `cwd` is writable by default. - let policy_workspace_only = SandboxPolicy::WorkspaceWrite { - writable_roots: vec![], - read_only_access: Default::default(), - network_access: false, - exclude_tmpdir_env_var: true, - exclude_slash_tmp: true, - }; - - assert!(is_write_patch_constrained_to_writable_paths( - &add_inside, - &FileSystemSandboxPolicy::from(&policy_workspace_only), - &cwd, - )); - - assert!(!is_write_patch_constrained_to_writable_paths( - &add_outside, - &FileSystemSandboxPolicy::from(&policy_workspace_only), - &cwd, - )); - - // With the parent dir explicitly added as a writable root, the - // outside write should be permitted. - let policy_with_parent = SandboxPolicy::WorkspaceWrite { - writable_roots: vec![AbsolutePathBuf::try_from(parent).unwrap()], - read_only_access: Default::default(), - network_access: false, - exclude_tmpdir_env_var: true, - exclude_slash_tmp: true, - }; - assert!(is_write_patch_constrained_to_writable_paths( - &add_outside, - &FileSystemSandboxPolicy::from(&policy_with_parent), - &cwd, - )); - } - - #[test] - fn external_sandbox_auto_approves_in_on_request() { - let tmp = TempDir::new().unwrap(); - let cwd = tmp.path().to_path_buf(); - let add_inside = ApplyPatchAction::new_add_for_test(&cwd.join("inner.txt"), "".to_string()); - - let policy = SandboxPolicy::ExternalSandbox { - network_access: codex_protocol::protocol::NetworkAccess::Enabled, - }; - - assert_eq!( - assess_patch_safety( - &add_inside, - AskForApproval::OnRequest, - &policy, - &FileSystemSandboxPolicy::from(&policy), - &cwd, - WindowsSandboxLevel::Disabled - ), - SafetyCheck::AutoApprove { - sandbox_type: SandboxType::None, - user_explicitly_approved: false, - } - ); - } - - #[test] - fn reject_with_all_flags_false_matches_on_request_for_out_of_root_patch() { - let tmp = TempDir::new().unwrap(); - let cwd = tmp.path().to_path_buf(); - let parent = cwd.parent().unwrap().to_path_buf(); - let add_outside = - ApplyPatchAction::new_add_for_test(&parent.join("outside.txt"), "".to_string()); - let policy_workspace_only = SandboxPolicy::WorkspaceWrite { - writable_roots: vec![], - read_only_access: Default::default(), - network_access: false, - exclude_tmpdir_env_var: true, - exclude_slash_tmp: true, - }; - - assert_eq!( - assess_patch_safety( - &add_outside, - AskForApproval::OnRequest, - &policy_workspace_only, - &FileSystemSandboxPolicy::from(&policy_workspace_only), - &cwd, - WindowsSandboxLevel::Disabled, - ), - SafetyCheck::AskUser, - ); - assert_eq!( - assess_patch_safety( - &add_outside, - AskForApproval::Reject(RejectConfig { - sandbox_approval: false, - rules: false, - skill_approval: false, - request_permissions: false, - mcp_elicitations: false, - }), - &policy_workspace_only, - &FileSystemSandboxPolicy::from(&policy_workspace_only), - &cwd, - WindowsSandboxLevel::Disabled, - ), - SafetyCheck::AskUser, - ); - } - - #[test] - fn reject_sandbox_approval_rejects_out_of_root_patch() { - let tmp = TempDir::new().unwrap(); - let cwd = tmp.path().to_path_buf(); - let parent = cwd.parent().unwrap().to_path_buf(); - let add_outside = - ApplyPatchAction::new_add_for_test(&parent.join("outside.txt"), "".to_string()); - let policy_workspace_only = SandboxPolicy::WorkspaceWrite { - writable_roots: vec![], - read_only_access: Default::default(), - network_access: false, - exclude_tmpdir_env_var: true, - exclude_slash_tmp: true, - }; - - assert_eq!( - assess_patch_safety( - &add_outside, - AskForApproval::Reject(RejectConfig { - sandbox_approval: true, - rules: false, - skill_approval: false, - request_permissions: false, - mcp_elicitations: false, - }), - &policy_workspace_only, - &FileSystemSandboxPolicy::from(&policy_workspace_only), - &cwd, - WindowsSandboxLevel::Disabled, - ), - SafetyCheck::Reject { - reason: "writing outside of the project; rejected by user approval settings" - .to_string(), - }, - ); - } - #[test] - fn explicit_unreadable_paths_prevent_auto_approval_for_external_sandbox() { - let tmp = TempDir::new().unwrap(); - let cwd = tmp.path().to_path_buf(); - let blocked_path = cwd.join("blocked.txt"); - let blocked_absolute = AbsolutePathBuf::from_absolute_path(blocked_path.clone()).unwrap(); - let action = ApplyPatchAction::new_add_for_test(&blocked_path, "".to_string()); - let sandbox_policy = SandboxPolicy::ExternalSandbox { - network_access: codex_protocol::protocol::NetworkAccess::Restricted, - }; - let file_system_sandbox_policy = FileSystemSandboxPolicy::restricted(vec![ - FileSystemSandboxEntry { - path: FileSystemPath::Special { - value: FileSystemSpecialPath::Root, - }, - access: FileSystemAccessMode::Write, - }, - FileSystemSandboxEntry { - path: FileSystemPath::Path { - path: blocked_absolute, - }, - access: FileSystemAccessMode::None, - }, - ]); - - assert!(!is_write_patch_constrained_to_writable_paths( - &action, - &file_system_sandbox_policy, - &cwd, - )); - assert_eq!( - assess_patch_safety( - &action, - AskForApproval::OnRequest, - &sandbox_policy, - &file_system_sandbox_policy, - &cwd, - WindowsSandboxLevel::Disabled, - ), - SafetyCheck::AskUser, - ); - } - - #[test] - fn explicit_read_only_subpaths_prevent_auto_approval_for_external_sandbox() { - let tmp = TempDir::new().unwrap(); - let cwd = tmp.path().to_path_buf(); - let blocked_path = cwd.join("docs").join("blocked.txt"); - let docs_absolute = AbsolutePathBuf::resolve_path_against_base("docs", &cwd).unwrap(); - let action = ApplyPatchAction::new_add_for_test(&blocked_path, "".to_string()); - let sandbox_policy = SandboxPolicy::ExternalSandbox { - network_access: codex_protocol::protocol::NetworkAccess::Restricted, - }; - let file_system_sandbox_policy = FileSystemSandboxPolicy::restricted(vec![ - FileSystemSandboxEntry { - path: FileSystemPath::Special { - value: FileSystemSpecialPath::CurrentWorkingDirectory, - }, - access: FileSystemAccessMode::Write, - }, - FileSystemSandboxEntry { - path: FileSystemPath::Path { - path: docs_absolute, - }, - access: FileSystemAccessMode::Read, - }, - ]); - - assert!(!is_write_patch_constrained_to_writable_paths( - &action, - &file_system_sandbox_policy, - &cwd, - )); - assert_eq!( - assess_patch_safety( - &action, - AskForApproval::OnRequest, - &sandbox_policy, - &file_system_sandbox_policy, - &cwd, - WindowsSandboxLevel::Disabled, - ), - SafetyCheck::AskUser, - ); - } -} +#[path = "safety_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/safety_tests.rs b/codex-rs/core/src/safety_tests.rs new file mode 100644 index 0000000000..555d557d3c --- /dev/null +++ b/codex-rs/core/src/safety_tests.rs @@ -0,0 +1,254 @@ +use super::*; +use codex_protocol::protocol::FileSystemAccessMode; +use codex_protocol::protocol::FileSystemPath; +use codex_protocol::protocol::FileSystemSandboxEntry; +use codex_protocol::protocol::FileSystemSpecialPath; +use codex_protocol::protocol::RejectConfig; +use codex_utils_absolute_path::AbsolutePathBuf; +use pretty_assertions::assert_eq; +use tempfile::TempDir; + +#[test] +fn test_writable_roots_constraint() { + // Use a temporary directory as our workspace to avoid touching + // the real current working directory. + let tmp = TempDir::new().unwrap(); + let cwd = tmp.path().to_path_buf(); + let parent = cwd.parent().unwrap().to_path_buf(); + + // Helper to build a single‑entry patch that adds a file at `p`. + let make_add_change = |p: PathBuf| ApplyPatchAction::new_add_for_test(&p, "".to_string()); + + let add_inside = make_add_change(cwd.join("inner.txt")); + let add_outside = make_add_change(parent.join("outside.txt")); + + // Policy limited to the workspace only; exclude system temp roots so + // only `cwd` is writable by default. + let policy_workspace_only = SandboxPolicy::WorkspaceWrite { + writable_roots: vec![], + read_only_access: Default::default(), + network_access: false, + exclude_tmpdir_env_var: true, + exclude_slash_tmp: true, + }; + + assert!(is_write_patch_constrained_to_writable_paths( + &add_inside, + &FileSystemSandboxPolicy::from(&policy_workspace_only), + &cwd, + )); + + assert!(!is_write_patch_constrained_to_writable_paths( + &add_outside, + &FileSystemSandboxPolicy::from(&policy_workspace_only), + &cwd, + )); + + // With the parent dir explicitly added as a writable root, the + // outside write should be permitted. + let policy_with_parent = SandboxPolicy::WorkspaceWrite { + writable_roots: vec![AbsolutePathBuf::try_from(parent).unwrap()], + read_only_access: Default::default(), + network_access: false, + exclude_tmpdir_env_var: true, + exclude_slash_tmp: true, + }; + assert!(is_write_patch_constrained_to_writable_paths( + &add_outside, + &FileSystemSandboxPolicy::from(&policy_with_parent), + &cwd, + )); +} + +#[test] +fn external_sandbox_auto_approves_in_on_request() { + let tmp = TempDir::new().unwrap(); + let cwd = tmp.path().to_path_buf(); + let add_inside = ApplyPatchAction::new_add_for_test(&cwd.join("inner.txt"), "".to_string()); + + let policy = SandboxPolicy::ExternalSandbox { + network_access: codex_protocol::protocol::NetworkAccess::Enabled, + }; + + assert_eq!( + assess_patch_safety( + &add_inside, + AskForApproval::OnRequest, + &policy, + &FileSystemSandboxPolicy::from(&policy), + &cwd, + WindowsSandboxLevel::Disabled + ), + SafetyCheck::AutoApprove { + sandbox_type: SandboxType::None, + user_explicitly_approved: false, + } + ); +} + +#[test] +fn reject_with_all_flags_false_matches_on_request_for_out_of_root_patch() { + let tmp = TempDir::new().unwrap(); + let cwd = tmp.path().to_path_buf(); + let parent = cwd.parent().unwrap().to_path_buf(); + let add_outside = + ApplyPatchAction::new_add_for_test(&parent.join("outside.txt"), "".to_string()); + let policy_workspace_only = SandboxPolicy::WorkspaceWrite { + writable_roots: vec![], + read_only_access: Default::default(), + network_access: false, + exclude_tmpdir_env_var: true, + exclude_slash_tmp: true, + }; + + assert_eq!( + assess_patch_safety( + &add_outside, + AskForApproval::OnRequest, + &policy_workspace_only, + &FileSystemSandboxPolicy::from(&policy_workspace_only), + &cwd, + WindowsSandboxLevel::Disabled, + ), + SafetyCheck::AskUser, + ); + assert_eq!( + assess_patch_safety( + &add_outside, + AskForApproval::Reject(RejectConfig { + sandbox_approval: false, + rules: false, + skill_approval: false, + request_permissions: false, + mcp_elicitations: false, + }), + &policy_workspace_only, + &FileSystemSandboxPolicy::from(&policy_workspace_only), + &cwd, + WindowsSandboxLevel::Disabled, + ), + SafetyCheck::AskUser, + ); +} + +#[test] +fn reject_sandbox_approval_rejects_out_of_root_patch() { + let tmp = TempDir::new().unwrap(); + let cwd = tmp.path().to_path_buf(); + let parent = cwd.parent().unwrap().to_path_buf(); + let add_outside = + ApplyPatchAction::new_add_for_test(&parent.join("outside.txt"), "".to_string()); + let policy_workspace_only = SandboxPolicy::WorkspaceWrite { + writable_roots: vec![], + read_only_access: Default::default(), + network_access: false, + exclude_tmpdir_env_var: true, + exclude_slash_tmp: true, + }; + + assert_eq!( + assess_patch_safety( + &add_outside, + AskForApproval::Reject(RejectConfig { + sandbox_approval: true, + rules: false, + skill_approval: false, + request_permissions: false, + mcp_elicitations: false, + }), + &policy_workspace_only, + &FileSystemSandboxPolicy::from(&policy_workspace_only), + &cwd, + WindowsSandboxLevel::Disabled, + ), + SafetyCheck::Reject { + reason: "writing outside of the project; rejected by user approval settings" + .to_string(), + }, + ); +} +#[test] +fn explicit_unreadable_paths_prevent_auto_approval_for_external_sandbox() { + let tmp = TempDir::new().unwrap(); + let cwd = tmp.path().to_path_buf(); + let blocked_path = cwd.join("blocked.txt"); + let blocked_absolute = AbsolutePathBuf::from_absolute_path(blocked_path.clone()).unwrap(); + let action = ApplyPatchAction::new_add_for_test(&blocked_path, "".to_string()); + let sandbox_policy = SandboxPolicy::ExternalSandbox { + network_access: codex_protocol::protocol::NetworkAccess::Restricted, + }; + let file_system_sandbox_policy = FileSystemSandboxPolicy::restricted(vec![ + FileSystemSandboxEntry { + path: FileSystemPath::Special { + value: FileSystemSpecialPath::Root, + }, + access: FileSystemAccessMode::Write, + }, + FileSystemSandboxEntry { + path: FileSystemPath::Path { + path: blocked_absolute, + }, + access: FileSystemAccessMode::None, + }, + ]); + + assert!(!is_write_patch_constrained_to_writable_paths( + &action, + &file_system_sandbox_policy, + &cwd, + )); + assert_eq!( + assess_patch_safety( + &action, + AskForApproval::OnRequest, + &sandbox_policy, + &file_system_sandbox_policy, + &cwd, + WindowsSandboxLevel::Disabled, + ), + SafetyCheck::AskUser, + ); +} + +#[test] +fn explicit_read_only_subpaths_prevent_auto_approval_for_external_sandbox() { + let tmp = TempDir::new().unwrap(); + let cwd = tmp.path().to_path_buf(); + let blocked_path = cwd.join("docs").join("blocked.txt"); + let docs_absolute = AbsolutePathBuf::resolve_path_against_base("docs", &cwd).unwrap(); + let action = ApplyPatchAction::new_add_for_test(&blocked_path, "".to_string()); + let sandbox_policy = SandboxPolicy::ExternalSandbox { + network_access: codex_protocol::protocol::NetworkAccess::Restricted, + }; + let file_system_sandbox_policy = FileSystemSandboxPolicy::restricted(vec![ + FileSystemSandboxEntry { + path: FileSystemPath::Special { + value: FileSystemSpecialPath::CurrentWorkingDirectory, + }, + access: FileSystemAccessMode::Write, + }, + FileSystemSandboxEntry { + path: FileSystemPath::Path { + path: docs_absolute, + }, + access: FileSystemAccessMode::Read, + }, + ]); + + assert!(!is_write_patch_constrained_to_writable_paths( + &action, + &file_system_sandbox_policy, + &cwd, + )); + assert_eq!( + assess_patch_safety( + &action, + AskForApproval::OnRequest, + &sandbox_policy, + &file_system_sandbox_policy, + &cwd, + WindowsSandboxLevel::Disabled, + ), + SafetyCheck::AskUser, + ); +} diff --git a/codex-rs/core/src/sandbox_tags.rs b/codex-rs/core/src/sandbox_tags.rs index 767d2bbf8b..6a66d17dd0 100644 --- a/codex-rs/core/src/sandbox_tags.rs +++ b/codex-rs/core/src/sandbox_tags.rs @@ -24,44 +24,5 @@ pub(crate) fn sandbox_tag( } #[cfg(test)] -mod tests { - use super::sandbox_tag; - use crate::exec::SandboxType; - use crate::protocol::SandboxPolicy; - use crate::safety::get_platform_sandbox; - use codex_protocol::config_types::WindowsSandboxLevel; - use codex_protocol::protocol::NetworkAccess; - use pretty_assertions::assert_eq; - - #[test] - fn danger_full_access_is_untagged_even_when_linux_sandbox_defaults_apply() { - let actual = sandbox_tag( - &SandboxPolicy::DangerFullAccess, - WindowsSandboxLevel::Disabled, - ); - assert_eq!(actual, "none"); - } - - #[test] - fn external_sandbox_keeps_external_tag_when_linux_sandbox_defaults_apply() { - let actual = sandbox_tag( - &SandboxPolicy::ExternalSandbox { - network_access: NetworkAccess::Enabled, - }, - WindowsSandboxLevel::Disabled, - ); - assert_eq!(actual, "external"); - } - - #[test] - fn default_linux_sandbox_uses_platform_sandbox_tag() { - let actual = sandbox_tag( - &SandboxPolicy::new_read_only_policy(), - WindowsSandboxLevel::Disabled, - ); - let expected = get_platform_sandbox(false) - .map(SandboxType::as_metric_tag) - .unwrap_or("none"); - assert_eq!(actual, expected); - } -} +#[path = "sandbox_tags_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/sandbox_tags_tests.rs b/codex-rs/core/src/sandbox_tags_tests.rs new file mode 100644 index 0000000000..7084d5ff92 --- /dev/null +++ b/codex-rs/core/src/sandbox_tags_tests.rs @@ -0,0 +1,39 @@ +use super::sandbox_tag; +use crate::exec::SandboxType; +use crate::protocol::SandboxPolicy; +use crate::safety::get_platform_sandbox; +use codex_protocol::config_types::WindowsSandboxLevel; +use codex_protocol::protocol::NetworkAccess; +use pretty_assertions::assert_eq; + +#[test] +fn danger_full_access_is_untagged_even_when_linux_sandbox_defaults_apply() { + let actual = sandbox_tag( + &SandboxPolicy::DangerFullAccess, + WindowsSandboxLevel::Disabled, + ); + assert_eq!(actual, "none"); +} + +#[test] +fn external_sandbox_keeps_external_tag_when_linux_sandbox_defaults_apply() { + let actual = sandbox_tag( + &SandboxPolicy::ExternalSandbox { + network_access: NetworkAccess::Enabled, + }, + WindowsSandboxLevel::Disabled, + ); + assert_eq!(actual, "external"); +} + +#[test] +fn default_linux_sandbox_uses_platform_sandbox_tag() { + let actual = sandbox_tag( + &SandboxPolicy::new_read_only_policy(), + WindowsSandboxLevel::Disabled, + ); + let expected = get_platform_sandbox(false) + .map(SandboxType::as_metric_tag) + .unwrap_or("none"); + assert_eq!(actual, expected); +} diff --git a/codex-rs/core/src/sandboxing/macos_permissions.rs b/codex-rs/core/src/sandboxing/macos_permissions.rs index 5717a558cf..1a409d4bdf 100644 --- a/codex-rs/core/src/sandboxing/macos_permissions.rs +++ b/codex-rs/core/src/sandboxing/macos_permissions.rs @@ -150,129 +150,5 @@ fn intersect_macos_automation_permission( } #[cfg(all(test, target_os = "macos"))] -mod tests { - use super::intersect_macos_automation_permission; - use super::intersect_macos_seatbelt_profile_extensions; - use super::merge_macos_seatbelt_profile_extensions; - use super::union_macos_automation_permission; - use super::union_macos_contacts_permission; - use super::union_macos_preferences_permission; - use codex_protocol::models::MacOsAutomationPermission; - use codex_protocol::models::MacOsContactsPermission; - use codex_protocol::models::MacOsPreferencesPermission; - use codex_protocol::models::MacOsSeatbeltProfileExtensions; - use pretty_assertions::assert_eq; - - #[test] - fn merge_extensions_widens_permissions() { - let base = MacOsSeatbeltProfileExtensions { - macos_preferences: MacOsPreferencesPermission::ReadOnly, - macos_automation: MacOsAutomationPermission::BundleIds(vec![ - "com.apple.Calendar".to_string(), - ]), - macos_launch_services: false, - macos_accessibility: false, - macos_calendar: false, - macos_reminders: false, - macos_contacts: MacOsContactsPermission::ReadOnly, - }; - let requested = MacOsSeatbeltProfileExtensions { - macos_preferences: MacOsPreferencesPermission::ReadWrite, - macos_automation: MacOsAutomationPermission::BundleIds(vec![ - "com.apple.Notes".to_string(), - "com.apple.Calendar".to_string(), - ]), - macos_launch_services: true, - macos_accessibility: true, - macos_calendar: true, - macos_reminders: true, - macos_contacts: MacOsContactsPermission::ReadWrite, - }; - - let merged = - merge_macos_seatbelt_profile_extensions(Some(&base), Some(&requested)).expect("merge"); - - assert_eq!( - merged, - MacOsSeatbeltProfileExtensions { - macos_preferences: MacOsPreferencesPermission::ReadWrite, - macos_automation: MacOsAutomationPermission::BundleIds(vec![ - "com.apple.Calendar".to_string(), - "com.apple.Notes".to_string(), - ]), - macos_launch_services: true, - macos_accessibility: true, - macos_calendar: true, - macos_reminders: true, - macos_contacts: MacOsContactsPermission::ReadWrite, - } - ); - } - - #[test] - fn union_macos_preferences_permission_does_not_downgrade() { - let base = MacOsPreferencesPermission::ReadWrite; - let requested = MacOsPreferencesPermission::ReadOnly; - - let merged = union_macos_preferences_permission(&base, &requested); - - assert_eq!(merged, MacOsPreferencesPermission::ReadWrite); - } - - #[test] - fn union_macos_automation_permission_all_is_dominant() { - let base = MacOsAutomationPermission::BundleIds(vec!["com.apple.Notes".to_string()]); - let requested = MacOsAutomationPermission::All; - - let merged = union_macos_automation_permission(&base, &requested); - - assert_eq!(merged, MacOsAutomationPermission::All); - } - - #[test] - fn intersect_macos_automation_permission_keeps_common_bundle_ids() { - let requested = MacOsAutomationPermission::BundleIds(vec![ - "com.apple.Notes".to_string(), - "com.apple.Calendar".to_string(), - ]); - let granted = MacOsAutomationPermission::BundleIds(vec!["com.apple.Notes".to_string()]); - - let intersected = intersect_macos_automation_permission(&requested, &granted); - - assert_eq!( - intersected, - MacOsAutomationPermission::BundleIds(vec!["com.apple.Notes".to_string()]) - ); - } - - #[test] - fn intersect_macos_seatbelt_profile_extensions_preserves_default_grant() { - let requested = MacOsSeatbeltProfileExtensions { - macos_preferences: MacOsPreferencesPermission::ReadWrite, - macos_automation: MacOsAutomationPermission::BundleIds(vec![ - "com.apple.Notes".to_string(), - ]), - macos_launch_services: false, - macos_accessibility: true, - macos_calendar: true, - macos_reminders: false, - macos_contacts: MacOsContactsPermission::None, - }; - let granted = MacOsSeatbeltProfileExtensions::default(); - - let intersected = - intersect_macos_seatbelt_profile_extensions(Some(requested), Some(granted)); - - assert_eq!(intersected, Some(MacOsSeatbeltProfileExtensions::default())); - } - - #[test] - fn union_macos_contacts_permission_does_not_downgrade() { - let base = MacOsContactsPermission::ReadWrite; - let requested = MacOsContactsPermission::ReadOnly; - - let merged = union_macos_contacts_permission(&base, &requested); - - assert_eq!(merged, MacOsContactsPermission::ReadWrite); - } -} +#[path = "macos_permissions_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/sandboxing/macos_permissions_tests.rs b/codex-rs/core/src/sandboxing/macos_permissions_tests.rs new file mode 100644 index 0000000000..97a2a2c753 --- /dev/null +++ b/codex-rs/core/src/sandboxing/macos_permissions_tests.rs @@ -0,0 +1,121 @@ +use super::intersect_macos_automation_permission; +use super::intersect_macos_seatbelt_profile_extensions; +use super::merge_macos_seatbelt_profile_extensions; +use super::union_macos_automation_permission; +use super::union_macos_contacts_permission; +use super::union_macos_preferences_permission; +use codex_protocol::models::MacOsAutomationPermission; +use codex_protocol::models::MacOsContactsPermission; +use codex_protocol::models::MacOsPreferencesPermission; +use codex_protocol::models::MacOsSeatbeltProfileExtensions; +use pretty_assertions::assert_eq; + +#[test] +fn merge_extensions_widens_permissions() { + let base = MacOsSeatbeltProfileExtensions { + macos_preferences: MacOsPreferencesPermission::ReadOnly, + macos_automation: MacOsAutomationPermission::BundleIds(vec![ + "com.apple.Calendar".to_string(), + ]), + macos_launch_services: false, + macos_accessibility: false, + macos_calendar: false, + macos_reminders: false, + macos_contacts: MacOsContactsPermission::ReadOnly, + }; + let requested = MacOsSeatbeltProfileExtensions { + macos_preferences: MacOsPreferencesPermission::ReadWrite, + macos_automation: MacOsAutomationPermission::BundleIds(vec![ + "com.apple.Notes".to_string(), + "com.apple.Calendar".to_string(), + ]), + macos_launch_services: true, + macos_accessibility: true, + macos_calendar: true, + macos_reminders: true, + macos_contacts: MacOsContactsPermission::ReadWrite, + }; + + let merged = + merge_macos_seatbelt_profile_extensions(Some(&base), Some(&requested)).expect("merge"); + + assert_eq!( + merged, + MacOsSeatbeltProfileExtensions { + macos_preferences: MacOsPreferencesPermission::ReadWrite, + macos_automation: MacOsAutomationPermission::BundleIds(vec![ + "com.apple.Calendar".to_string(), + "com.apple.Notes".to_string(), + ]), + macos_launch_services: true, + macos_accessibility: true, + macos_calendar: true, + macos_reminders: true, + macos_contacts: MacOsContactsPermission::ReadWrite, + } + ); +} + +#[test] +fn union_macos_preferences_permission_does_not_downgrade() { + let base = MacOsPreferencesPermission::ReadWrite; + let requested = MacOsPreferencesPermission::ReadOnly; + + let merged = union_macos_preferences_permission(&base, &requested); + + assert_eq!(merged, MacOsPreferencesPermission::ReadWrite); +} + +#[test] +fn union_macos_automation_permission_all_is_dominant() { + let base = MacOsAutomationPermission::BundleIds(vec!["com.apple.Notes".to_string()]); + let requested = MacOsAutomationPermission::All; + + let merged = union_macos_automation_permission(&base, &requested); + + assert_eq!(merged, MacOsAutomationPermission::All); +} + +#[test] +fn intersect_macos_automation_permission_keeps_common_bundle_ids() { + let requested = MacOsAutomationPermission::BundleIds(vec![ + "com.apple.Notes".to_string(), + "com.apple.Calendar".to_string(), + ]); + let granted = MacOsAutomationPermission::BundleIds(vec!["com.apple.Notes".to_string()]); + + let intersected = intersect_macos_automation_permission(&requested, &granted); + + assert_eq!( + intersected, + MacOsAutomationPermission::BundleIds(vec!["com.apple.Notes".to_string()]) + ); +} + +#[test] +fn intersect_macos_seatbelt_profile_extensions_preserves_default_grant() { + let requested = MacOsSeatbeltProfileExtensions { + macos_preferences: MacOsPreferencesPermission::ReadWrite, + macos_automation: MacOsAutomationPermission::BundleIds(vec!["com.apple.Notes".to_string()]), + macos_launch_services: false, + macos_accessibility: true, + macos_calendar: true, + macos_reminders: false, + macos_contacts: MacOsContactsPermission::None, + }; + let granted = MacOsSeatbeltProfileExtensions::default(); + + let intersected = intersect_macos_seatbelt_profile_extensions(Some(requested), Some(granted)); + + assert_eq!(intersected, Some(MacOsSeatbeltProfileExtensions::default())); +} + +#[test] +fn union_macos_contacts_permission_does_not_downgrade() { + let base = MacOsContactsPermission::ReadWrite; + let requested = MacOsContactsPermission::ReadOnly; + + let merged = union_macos_contacts_permission(&base, &requested); + + assert_eq!(merged, MacOsContactsPermission::ReadWrite); +} diff --git a/codex-rs/core/src/sandboxing/mod.rs b/codex-rs/core/src/sandboxing/mod.rs index 9d0912faa1..fe4918a266 100644 --- a/codex-rs/core/src/sandboxing/mod.rs +++ b/codex-rs/core/src/sandboxing/mod.rs @@ -607,10 +607,18 @@ impl SandboxManager { ); let (effective_file_system_policy, effective_network_policy) = if let Some(additional_permissions) = additional_permissions { - let file_system_sandbox_policy = effective_file_system_sandbox_policy( - file_system_policy, - Some(&additional_permissions), - ); + let (extra_reads, extra_writes) = + additional_permission_roots(&additional_permissions); + let file_system_sandbox_policy = + if extra_reads.is_empty() && extra_writes.is_empty() { + file_system_policy.clone() + } else { + merge_file_system_policy_with_additional_permissions( + file_system_policy, + extra_reads, + extra_writes, + ) + }; let network_sandbox_policy = if merge_network_access(network_policy.is_enabled(), &additional_permissions) { NetworkSandboxPolicy::Enabled @@ -729,728 +737,5 @@ pub async fn execute_exec_request_with_after_spawn( } #[cfg(test)] -mod tests { - #[cfg(target_os = "macos")] - use super::EffectiveSandboxPermissions; - use super::SandboxManager; - use super::effective_file_system_sandbox_policy; - #[cfg(target_os = "macos")] - use super::intersect_permission_profiles; - use super::merge_file_system_policy_with_additional_permissions; - use super::normalize_additional_permissions; - use super::sandbox_policy_with_additional_permissions; - use super::should_require_platform_sandbox; - use crate::exec::SandboxType; - use crate::protocol::NetworkAccess; - use crate::protocol::ReadOnlyAccess; - use crate::protocol::SandboxPolicy; - use crate::tools::sandboxing::SandboxablePreference; - use codex_protocol::config_types::WindowsSandboxLevel; - use codex_protocol::models::FileSystemPermissions; - #[cfg(target_os = "macos")] - use codex_protocol::models::MacOsAutomationPermission; - #[cfg(target_os = "macos")] - use codex_protocol::models::MacOsContactsPermission; - #[cfg(target_os = "macos")] - use codex_protocol::models::MacOsPreferencesPermission; - #[cfg(target_os = "macos")] - use codex_protocol::models::MacOsSeatbeltProfileExtensions; - use codex_protocol::models::NetworkPermissions; - use codex_protocol::models::PermissionProfile; - use codex_protocol::permissions::FileSystemAccessMode; - use codex_protocol::permissions::FileSystemPath; - use codex_protocol::permissions::FileSystemSandboxEntry; - use codex_protocol::permissions::FileSystemSandboxPolicy; - use codex_protocol::permissions::FileSystemSpecialPath; - use codex_protocol::permissions::NetworkSandboxPolicy; - use codex_utils_absolute_path::AbsolutePathBuf; - use dunce::canonicalize; - use pretty_assertions::assert_eq; - use std::collections::HashMap; - use tempfile::TempDir; - - #[test] - fn danger_full_access_defaults_to_no_sandbox_without_network_requirements() { - let manager = SandboxManager::new(); - let sandbox = manager.select_initial( - &FileSystemSandboxPolicy::unrestricted(), - NetworkSandboxPolicy::Enabled, - SandboxablePreference::Auto, - WindowsSandboxLevel::Disabled, - false, - ); - assert_eq!(sandbox, SandboxType::None); - } - - #[test] - fn danger_full_access_uses_platform_sandbox_with_network_requirements() { - let manager = SandboxManager::new(); - let expected = crate::safety::get_platform_sandbox(false).unwrap_or(SandboxType::None); - let sandbox = manager.select_initial( - &FileSystemSandboxPolicy::unrestricted(), - NetworkSandboxPolicy::Enabled, - SandboxablePreference::Auto, - WindowsSandboxLevel::Disabled, - true, - ); - assert_eq!(sandbox, expected); - } - - #[test] - fn restricted_file_system_uses_platform_sandbox_without_managed_network() { - let manager = SandboxManager::new(); - let expected = crate::safety::get_platform_sandbox(false).unwrap_or(SandboxType::None); - let sandbox = manager.select_initial( - &FileSystemSandboxPolicy::restricted(vec![FileSystemSandboxEntry { - path: FileSystemPath::Special { - value: FileSystemSpecialPath::Root, - }, - access: FileSystemAccessMode::Read, - }]), - NetworkSandboxPolicy::Enabled, - SandboxablePreference::Auto, - WindowsSandboxLevel::Disabled, - false, - ); - assert_eq!(sandbox, expected); - } - - #[test] - fn full_access_restricted_policy_skips_platform_sandbox_when_network_is_enabled() { - let policy = FileSystemSandboxPolicy::restricted(vec![FileSystemSandboxEntry { - path: FileSystemPath::Special { - value: FileSystemSpecialPath::Root, - }, - access: FileSystemAccessMode::Write, - }]); - - assert_eq!( - should_require_platform_sandbox(&policy, NetworkSandboxPolicy::Enabled, false), - false - ); - } - - #[test] - fn root_write_policy_with_carveouts_still_uses_platform_sandbox() { - let blocked = AbsolutePathBuf::resolve_path_against_base( - "blocked", - std::env::current_dir().expect("current dir"), - ) - .expect("blocked path"); - let policy = FileSystemSandboxPolicy::restricted(vec![ - FileSystemSandboxEntry { - path: FileSystemPath::Special { - value: FileSystemSpecialPath::Root, - }, - access: FileSystemAccessMode::Write, - }, - FileSystemSandboxEntry { - path: FileSystemPath::Path { path: blocked }, - access: FileSystemAccessMode::None, - }, - ]); - - assert_eq!( - should_require_platform_sandbox(&policy, NetworkSandboxPolicy::Enabled, false), - true - ); - } - - #[test] - fn full_access_restricted_policy_still_uses_platform_sandbox_for_restricted_network() { - let policy = FileSystemSandboxPolicy::restricted(vec![FileSystemSandboxEntry { - path: FileSystemPath::Special { - value: FileSystemSpecialPath::Root, - }, - access: FileSystemAccessMode::Write, - }]); - - assert_eq!( - should_require_platform_sandbox(&policy, NetworkSandboxPolicy::Restricted, false), - true - ); - } - - #[test] - fn transform_preserves_unrestricted_file_system_policy_for_restricted_network() { - let manager = SandboxManager::new(); - let cwd = std::env::current_dir().expect("current dir"); - let exec_request = manager - .transform(super::SandboxTransformRequest { - spec: super::CommandSpec { - program: "true".to_string(), - args: Vec::new(), - cwd: cwd.clone(), - env: HashMap::new(), - expiration: crate::exec::ExecExpiration::DefaultTimeout, - sandbox_permissions: super::SandboxPermissions::UseDefault, - additional_permissions: None, - justification: None, - }, - policy: &SandboxPolicy::ExternalSandbox { - network_access: crate::protocol::NetworkAccess::Restricted, - }, - file_system_policy: &FileSystemSandboxPolicy::unrestricted(), - network_policy: NetworkSandboxPolicy::Restricted, - sandbox: SandboxType::None, - enforce_managed_network: false, - network: None, - sandbox_policy_cwd: cwd.as_path(), - #[cfg(target_os = "macos")] - macos_seatbelt_profile_extensions: None, - codex_linux_sandbox_exe: None, - use_legacy_landlock: false, - windows_sandbox_level: WindowsSandboxLevel::Disabled, - }) - .expect("transform"); - - assert_eq!( - exec_request.file_system_sandbox_policy, - FileSystemSandboxPolicy::unrestricted() - ); - assert_eq!( - exec_request.network_sandbox_policy, - NetworkSandboxPolicy::Restricted - ); - } - - #[test] - fn normalize_additional_permissions_preserves_network() { - let temp_dir = TempDir::new().expect("create temp dir"); - let path = AbsolutePathBuf::from_absolute_path( - canonicalize(temp_dir.path()).expect("canonicalize temp dir"), - ) - .expect("absolute temp dir"); - let permissions = normalize_additional_permissions(PermissionProfile { - network: Some(NetworkPermissions { - enabled: Some(true), - }), - file_system: Some(FileSystemPermissions { - read: Some(vec![path.clone()]), - write: Some(vec![path.clone()]), - }), - ..Default::default() - }) - .expect("permissions"); - - assert_eq!( - permissions.network, - Some(NetworkPermissions { - enabled: Some(true), - }) - ); - assert_eq!( - permissions.file_system, - Some(FileSystemPermissions { - read: Some(vec![path.clone()]), - write: Some(vec![path]), - }) - ); - } - - #[test] - fn normalize_additional_permissions_drops_empty_nested_profiles() { - let permissions = normalize_additional_permissions(PermissionProfile { - network: Some(NetworkPermissions { enabled: None }), - file_system: Some(FileSystemPermissions { - read: None, - write: None, - }), - macos: None, - }) - .expect("permissions"); - - assert_eq!(permissions, PermissionProfile::default()); - } - - #[cfg(target_os = "macos")] - #[test] - fn normalize_additional_permissions_preserves_default_macos_preferences_permission() { - let permissions = normalize_additional_permissions(PermissionProfile { - macos: Some(MacOsSeatbeltProfileExtensions::default()), - ..Default::default() - }) - .expect("permissions"); - - assert_eq!( - permissions, - PermissionProfile { - macos: Some(MacOsSeatbeltProfileExtensions::default()), - ..Default::default() - } - ); - } - - #[cfg(target_os = "macos")] - #[test] - fn intersect_permission_profiles_preserves_default_macos_grants() { - let requested = PermissionProfile { - file_system: Some(FileSystemPermissions { - read: Some(Vec::from(["/tmp/requested" - .try_into() - .expect("absolute path")])), - write: None, - }), - macos: Some(MacOsSeatbeltProfileExtensions { - macos_preferences: MacOsPreferencesPermission::ReadWrite, - macos_automation: MacOsAutomationPermission::BundleIds(vec![ - "com.apple.Notes".to_string(), - ]), - macos_launch_services: false, - macos_accessibility: true, - macos_calendar: true, - macos_reminders: false, - macos_contacts: MacOsContactsPermission::None, - }), - ..Default::default() - }; - let granted = PermissionProfile { - file_system: Some(FileSystemPermissions { - read: Some(Vec::new()), - write: None, - }), - macos: Some(MacOsSeatbeltProfileExtensions::default()), - ..Default::default() - }; - - assert_eq!( - intersect_permission_profiles(requested, granted), - PermissionProfile { - macos: Some(MacOsSeatbeltProfileExtensions::default()), - ..Default::default() - } - ); - } - - #[cfg(target_os = "macos")] - #[test] - fn normalize_additional_permissions_preserves_macos_permissions() { - let permissions = normalize_additional_permissions(PermissionProfile { - macos: Some(MacOsSeatbeltProfileExtensions { - macos_preferences: MacOsPreferencesPermission::ReadWrite, - macos_automation: MacOsAutomationPermission::BundleIds(vec![ - "com.apple.Notes".to_string(), - ]), - macos_launch_services: true, - macos_accessibility: true, - macos_calendar: true, - macos_reminders: false, - macos_contacts: MacOsContactsPermission::None, - }), - ..Default::default() - }) - .expect("permissions"); - - assert_eq!( - permissions.macos, - Some(MacOsSeatbeltProfileExtensions { - macos_preferences: MacOsPreferencesPermission::ReadWrite, - macos_automation: MacOsAutomationPermission::BundleIds(vec![ - "com.apple.Notes".to_string(), - ]), - macos_launch_services: true, - macos_accessibility: true, - macos_calendar: true, - macos_reminders: false, - macos_contacts: MacOsContactsPermission::None, - }) - ); - } - - #[test] - fn read_only_additional_permissions_can_enable_network_without_writes() { - let temp_dir = TempDir::new().expect("create temp dir"); - let path = AbsolutePathBuf::from_absolute_path( - canonicalize(temp_dir.path()).expect("canonicalize temp dir"), - ) - .expect("absolute temp dir"); - let policy = sandbox_policy_with_additional_permissions( - &SandboxPolicy::ReadOnly { - access: ReadOnlyAccess::Restricted { - include_platform_defaults: true, - readable_roots: vec![path.clone()], - }, - network_access: false, - }, - &PermissionProfile { - network: Some(NetworkPermissions { - enabled: Some(true), - }), - file_system: Some(FileSystemPermissions { - read: Some(vec![path.clone()]), - write: Some(Vec::new()), - }), - ..Default::default() - }, - ); - - assert_eq!( - policy, - SandboxPolicy::ReadOnly { - access: ReadOnlyAccess::Restricted { - include_platform_defaults: true, - readable_roots: vec![path], - }, - network_access: true, - } - ); - } - #[cfg(target_os = "macos")] - #[test] - fn effective_permissions_merge_macos_extensions_with_additional_permissions() { - let temp_dir = TempDir::new().expect("create temp dir"); - let path = AbsolutePathBuf::from_absolute_path( - canonicalize(temp_dir.path()).expect("canonicalize temp dir"), - ) - .expect("absolute temp dir"); - let effective_permissions = EffectiveSandboxPermissions::new( - &SandboxPolicy::ReadOnly { - access: ReadOnlyAccess::Restricted { - include_platform_defaults: true, - readable_roots: vec![path.clone()], - }, - network_access: false, - }, - Some(&MacOsSeatbeltProfileExtensions { - macos_preferences: MacOsPreferencesPermission::ReadOnly, - macos_automation: MacOsAutomationPermission::BundleIds(vec![ - "com.apple.Calendar".to_string(), - ]), - macos_launch_services: false, - macos_accessibility: false, - macos_calendar: false, - macos_reminders: false, - macos_contacts: MacOsContactsPermission::None, - }), - Some(&PermissionProfile { - file_system: Some(FileSystemPermissions { - read: Some(vec![path]), - write: Some(Vec::new()), - }), - macos: Some(MacOsSeatbeltProfileExtensions { - macos_preferences: MacOsPreferencesPermission::ReadWrite, - macos_automation: MacOsAutomationPermission::BundleIds(vec![ - "com.apple.Notes".to_string(), - ]), - macos_launch_services: true, - macos_accessibility: true, - macos_calendar: true, - macos_reminders: false, - macos_contacts: MacOsContactsPermission::None, - }), - ..Default::default() - }), - ); - - assert_eq!( - effective_permissions.macos_seatbelt_profile_extensions, - Some(MacOsSeatbeltProfileExtensions { - macos_preferences: MacOsPreferencesPermission::ReadWrite, - macos_automation: MacOsAutomationPermission::BundleIds(vec![ - "com.apple.Calendar".to_string(), - "com.apple.Notes".to_string(), - ]), - macos_launch_services: true, - macos_accessibility: true, - macos_calendar: true, - macos_reminders: false, - macos_contacts: MacOsContactsPermission::None, - }) - ); - } - - #[test] - fn external_sandbox_additional_permissions_can_enable_network() { - let temp_dir = TempDir::new().expect("create temp dir"); - let path = AbsolutePathBuf::from_absolute_path( - canonicalize(temp_dir.path()).expect("canonicalize temp dir"), - ) - .expect("absolute temp dir"); - let policy = sandbox_policy_with_additional_permissions( - &SandboxPolicy::ExternalSandbox { - network_access: NetworkAccess::Restricted, - }, - &PermissionProfile { - network: Some(NetworkPermissions { - enabled: Some(true), - }), - file_system: Some(FileSystemPermissions { - read: Some(vec![path]), - write: Some(Vec::new()), - }), - ..Default::default() - }, - ); - - assert_eq!( - policy, - SandboxPolicy::ExternalSandbox { - network_access: NetworkAccess::Enabled, - } - ); - } - - #[test] - fn transform_additional_permissions_enable_network_for_external_sandbox() { - let manager = SandboxManager::new(); - let cwd = std::env::current_dir().expect("current dir"); - let temp_dir = TempDir::new().expect("create temp dir"); - let path = AbsolutePathBuf::from_absolute_path( - canonicalize(temp_dir.path()).expect("canonicalize temp dir"), - ) - .expect("absolute temp dir"); - let exec_request = manager - .transform(super::SandboxTransformRequest { - spec: super::CommandSpec { - program: "true".to_string(), - args: Vec::new(), - cwd: cwd.clone(), - env: HashMap::new(), - expiration: crate::exec::ExecExpiration::DefaultTimeout, - sandbox_permissions: super::SandboxPermissions::WithAdditionalPermissions, - additional_permissions: Some(PermissionProfile { - network: Some(NetworkPermissions { - enabled: Some(true), - }), - file_system: Some(FileSystemPermissions { - read: Some(vec![path]), - write: Some(Vec::new()), - }), - ..Default::default() - }), - justification: None, - }, - policy: &SandboxPolicy::ExternalSandbox { - network_access: NetworkAccess::Restricted, - }, - file_system_policy: &FileSystemSandboxPolicy::unrestricted(), - network_policy: NetworkSandboxPolicy::Restricted, - sandbox: SandboxType::None, - enforce_managed_network: false, - network: None, - sandbox_policy_cwd: cwd.as_path(), - #[cfg(target_os = "macos")] - macos_seatbelt_profile_extensions: None, - codex_linux_sandbox_exe: None, - use_legacy_landlock: false, - windows_sandbox_level: WindowsSandboxLevel::Disabled, - }) - .expect("transform"); - - assert_eq!( - exec_request.sandbox_policy, - SandboxPolicy::ExternalSandbox { - network_access: NetworkAccess::Enabled, - } - ); - assert_eq!( - exec_request.network_sandbox_policy, - NetworkSandboxPolicy::Enabled - ); - } - - #[test] - fn transform_additional_permissions_preserves_denied_entries() { - let manager = SandboxManager::new(); - let cwd = std::env::current_dir().expect("current dir"); - let temp_dir = TempDir::new().expect("create temp dir"); - let workspace_root = AbsolutePathBuf::from_absolute_path( - canonicalize(temp_dir.path()).expect("canonicalize temp dir"), - ) - .expect("absolute temp dir"); - let allowed_path = workspace_root.join("allowed").expect("allowed path"); - let denied_path = workspace_root.join("denied").expect("denied path"); - let exec_request = manager - .transform(super::SandboxTransformRequest { - spec: super::CommandSpec { - program: "true".to_string(), - args: Vec::new(), - cwd: cwd.clone(), - env: HashMap::new(), - expiration: crate::exec::ExecExpiration::DefaultTimeout, - sandbox_permissions: super::SandboxPermissions::WithAdditionalPermissions, - additional_permissions: Some(PermissionProfile { - file_system: Some(FileSystemPermissions { - read: None, - write: Some(vec![allowed_path.clone()]), - }), - ..Default::default() - }), - justification: None, - }, - policy: &SandboxPolicy::ReadOnly { - access: ReadOnlyAccess::FullAccess, - network_access: false, - }, - file_system_policy: &FileSystemSandboxPolicy::restricted(vec![ - FileSystemSandboxEntry { - path: FileSystemPath::Special { - value: FileSystemSpecialPath::Root, - }, - access: FileSystemAccessMode::Read, - }, - FileSystemSandboxEntry { - path: FileSystemPath::Path { - path: denied_path.clone(), - }, - access: FileSystemAccessMode::None, - }, - ]), - network_policy: NetworkSandboxPolicy::Restricted, - sandbox: SandboxType::None, - enforce_managed_network: false, - network: None, - sandbox_policy_cwd: cwd.as_path(), - #[cfg(target_os = "macos")] - macos_seatbelt_profile_extensions: None, - codex_linux_sandbox_exe: None, - use_legacy_landlock: false, - windows_sandbox_level: WindowsSandboxLevel::Disabled, - }) - .expect("transform"); - - assert_eq!( - exec_request.file_system_sandbox_policy, - FileSystemSandboxPolicy::restricted(vec![ - FileSystemSandboxEntry { - path: FileSystemPath::Special { - value: FileSystemSpecialPath::Root, - }, - access: FileSystemAccessMode::Read, - }, - FileSystemSandboxEntry { - path: FileSystemPath::Path { path: denied_path }, - access: FileSystemAccessMode::None, - }, - FileSystemSandboxEntry { - path: FileSystemPath::Path { path: allowed_path }, - access: FileSystemAccessMode::Write, - }, - ]) - ); - assert_eq!( - exec_request.network_sandbox_policy, - NetworkSandboxPolicy::Restricted - ); - } - - #[test] - fn merge_file_system_policy_with_additional_permissions_preserves_unreadable_roots() { - let temp_dir = TempDir::new().expect("create temp dir"); - let cwd = AbsolutePathBuf::from_absolute_path( - canonicalize(temp_dir.path()).expect("canonicalize temp dir"), - ) - .expect("absolute temp dir"); - let allowed_path = cwd.join("allowed").expect("allowed path"); - let denied_path = cwd.join("denied").expect("denied path"); - let merged_policy = merge_file_system_policy_with_additional_permissions( - &FileSystemSandboxPolicy::restricted(vec![ - FileSystemSandboxEntry { - path: FileSystemPath::Special { - value: FileSystemSpecialPath::Root, - }, - access: FileSystemAccessMode::Read, - }, - FileSystemSandboxEntry { - path: FileSystemPath::Path { - path: denied_path.clone(), - }, - access: FileSystemAccessMode::None, - }, - ]), - vec![allowed_path.clone()], - Vec::new(), - ); - - assert_eq!( - merged_policy.entries.contains(&FileSystemSandboxEntry { - path: FileSystemPath::Path { path: denied_path }, - access: FileSystemAccessMode::None, - }), - true - ); - assert_eq!( - merged_policy.entries.contains(&FileSystemSandboxEntry { - path: FileSystemPath::Path { path: allowed_path }, - access: FileSystemAccessMode::Read, - }), - true - ); - } - - #[test] - fn effective_file_system_sandbox_policy_returns_base_policy_without_additional_permissions() { - let temp_dir = TempDir::new().expect("create temp dir"); - let cwd = AbsolutePathBuf::from_absolute_path( - canonicalize(temp_dir.path()).expect("canonicalize temp dir"), - ) - .expect("absolute temp dir"); - let denied_path = cwd.join("denied").expect("denied path"); - let base_policy = FileSystemSandboxPolicy::restricted(vec![ - FileSystemSandboxEntry { - path: FileSystemPath::Special { - value: FileSystemSpecialPath::Root, - }, - access: FileSystemAccessMode::Read, - }, - FileSystemSandboxEntry { - path: FileSystemPath::Path { path: denied_path }, - access: FileSystemAccessMode::None, - }, - ]); - - let effective_policy = effective_file_system_sandbox_policy(&base_policy, None); - - assert_eq!(effective_policy, base_policy); - } - - #[test] - fn effective_file_system_sandbox_policy_merges_additional_write_roots() { - let temp_dir = TempDir::new().expect("create temp dir"); - let cwd = AbsolutePathBuf::from_absolute_path( - canonicalize(temp_dir.path()).expect("canonicalize temp dir"), - ) - .expect("absolute temp dir"); - let allowed_path = cwd.join("allowed").expect("allowed path"); - let denied_path = cwd.join("denied").expect("denied path"); - let base_policy = FileSystemSandboxPolicy::restricted(vec![ - FileSystemSandboxEntry { - path: FileSystemPath::Special { - value: FileSystemSpecialPath::Root, - }, - access: FileSystemAccessMode::Read, - }, - FileSystemSandboxEntry { - path: FileSystemPath::Path { - path: denied_path.clone(), - }, - access: FileSystemAccessMode::None, - }, - ]); - let additional_permissions = PermissionProfile { - file_system: Some(FileSystemPermissions { - read: Some(vec![]), - write: Some(vec![allowed_path.clone()]), - }), - ..Default::default() - }; - - let effective_policy = - effective_file_system_sandbox_policy(&base_policy, Some(&additional_permissions)); - - assert_eq!( - effective_policy.entries.contains(&FileSystemSandboxEntry { - path: FileSystemPath::Path { path: denied_path }, - access: FileSystemAccessMode::None, - }), - true - ); - assert_eq!( - effective_policy.entries.contains(&FileSystemSandboxEntry { - path: FileSystemPath::Path { path: allowed_path }, - access: FileSystemAccessMode::Write, - }), - true - ); - } -} +#[path = "mod_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/sandboxing/mod_tests.rs b/codex-rs/core/src/sandboxing/mod_tests.rs new file mode 100644 index 0000000000..c20c2ef413 --- /dev/null +++ b/codex-rs/core/src/sandboxing/mod_tests.rs @@ -0,0 +1,723 @@ +#[cfg(target_os = "macos")] +use super::EffectiveSandboxPermissions; +use super::SandboxManager; +use super::effective_file_system_sandbox_policy; +#[cfg(target_os = "macos")] +use super::intersect_permission_profiles; +use super::merge_file_system_policy_with_additional_permissions; +use super::normalize_additional_permissions; +use super::sandbox_policy_with_additional_permissions; +use super::should_require_platform_sandbox; +use crate::exec::SandboxType; +use crate::protocol::NetworkAccess; +use crate::protocol::ReadOnlyAccess; +use crate::protocol::SandboxPolicy; +use crate::tools::sandboxing::SandboxablePreference; +use codex_protocol::config_types::WindowsSandboxLevel; +use codex_protocol::models::FileSystemPermissions; +#[cfg(target_os = "macos")] +use codex_protocol::models::MacOsAutomationPermission; +#[cfg(target_os = "macos")] +use codex_protocol::models::MacOsContactsPermission; +#[cfg(target_os = "macos")] +use codex_protocol::models::MacOsPreferencesPermission; +#[cfg(target_os = "macos")] +use codex_protocol::models::MacOsSeatbeltProfileExtensions; +use codex_protocol::models::NetworkPermissions; +use codex_protocol::models::PermissionProfile; +use codex_protocol::permissions::FileSystemAccessMode; +use codex_protocol::permissions::FileSystemPath; +use codex_protocol::permissions::FileSystemSandboxEntry; +use codex_protocol::permissions::FileSystemSandboxPolicy; +use codex_protocol::permissions::FileSystemSpecialPath; +use codex_protocol::permissions::NetworkSandboxPolicy; +use codex_utils_absolute_path::AbsolutePathBuf; +use dunce::canonicalize; +use pretty_assertions::assert_eq; +use std::collections::HashMap; +use tempfile::TempDir; + +#[test] +fn danger_full_access_defaults_to_no_sandbox_without_network_requirements() { + let manager = SandboxManager::new(); + let sandbox = manager.select_initial( + &FileSystemSandboxPolicy::unrestricted(), + NetworkSandboxPolicy::Enabled, + SandboxablePreference::Auto, + WindowsSandboxLevel::Disabled, + false, + ); + assert_eq!(sandbox, SandboxType::None); +} + +#[test] +fn danger_full_access_uses_platform_sandbox_with_network_requirements() { + let manager = SandboxManager::new(); + let expected = crate::safety::get_platform_sandbox(false).unwrap_or(SandboxType::None); + let sandbox = manager.select_initial( + &FileSystemSandboxPolicy::unrestricted(), + NetworkSandboxPolicy::Enabled, + SandboxablePreference::Auto, + WindowsSandboxLevel::Disabled, + true, + ); + assert_eq!(sandbox, expected); +} + +#[test] +fn restricted_file_system_uses_platform_sandbox_without_managed_network() { + let manager = SandboxManager::new(); + let expected = crate::safety::get_platform_sandbox(false).unwrap_or(SandboxType::None); + let sandbox = manager.select_initial( + &FileSystemSandboxPolicy::restricted(vec![FileSystemSandboxEntry { + path: FileSystemPath::Special { + value: FileSystemSpecialPath::Root, + }, + access: FileSystemAccessMode::Read, + }]), + NetworkSandboxPolicy::Enabled, + SandboxablePreference::Auto, + WindowsSandboxLevel::Disabled, + false, + ); + assert_eq!(sandbox, expected); +} + +#[test] +fn full_access_restricted_policy_skips_platform_sandbox_when_network_is_enabled() { + let policy = FileSystemSandboxPolicy::restricted(vec![FileSystemSandboxEntry { + path: FileSystemPath::Special { + value: FileSystemSpecialPath::Root, + }, + access: FileSystemAccessMode::Write, + }]); + + assert_eq!( + should_require_platform_sandbox(&policy, NetworkSandboxPolicy::Enabled, false), + false + ); +} + +#[test] +fn root_write_policy_with_carveouts_still_uses_platform_sandbox() { + let blocked = AbsolutePathBuf::resolve_path_against_base( + "blocked", + std::env::current_dir().expect("current dir"), + ) + .expect("blocked path"); + let policy = FileSystemSandboxPolicy::restricted(vec![ + FileSystemSandboxEntry { + path: FileSystemPath::Special { + value: FileSystemSpecialPath::Root, + }, + access: FileSystemAccessMode::Write, + }, + FileSystemSandboxEntry { + path: FileSystemPath::Path { path: blocked }, + access: FileSystemAccessMode::None, + }, + ]); + + assert_eq!( + should_require_platform_sandbox(&policy, NetworkSandboxPolicy::Enabled, false), + true + ); +} + +#[test] +fn full_access_restricted_policy_still_uses_platform_sandbox_for_restricted_network() { + let policy = FileSystemSandboxPolicy::restricted(vec![FileSystemSandboxEntry { + path: FileSystemPath::Special { + value: FileSystemSpecialPath::Root, + }, + access: FileSystemAccessMode::Write, + }]); + + assert_eq!( + should_require_platform_sandbox(&policy, NetworkSandboxPolicy::Restricted, false), + true + ); +} + +#[test] +fn transform_preserves_unrestricted_file_system_policy_for_restricted_network() { + let manager = SandboxManager::new(); + let cwd = std::env::current_dir().expect("current dir"); + let exec_request = manager + .transform(super::SandboxTransformRequest { + spec: super::CommandSpec { + program: "true".to_string(), + args: Vec::new(), + cwd: cwd.clone(), + env: HashMap::new(), + expiration: crate::exec::ExecExpiration::DefaultTimeout, + sandbox_permissions: super::SandboxPermissions::UseDefault, + additional_permissions: None, + justification: None, + }, + policy: &SandboxPolicy::ExternalSandbox { + network_access: crate::protocol::NetworkAccess::Restricted, + }, + file_system_policy: &FileSystemSandboxPolicy::unrestricted(), + network_policy: NetworkSandboxPolicy::Restricted, + sandbox: SandboxType::None, + enforce_managed_network: false, + network: None, + sandbox_policy_cwd: cwd.as_path(), + #[cfg(target_os = "macos")] + macos_seatbelt_profile_extensions: None, + codex_linux_sandbox_exe: None, + use_legacy_landlock: false, + windows_sandbox_level: WindowsSandboxLevel::Disabled, + }) + .expect("transform"); + + assert_eq!( + exec_request.file_system_sandbox_policy, + FileSystemSandboxPolicy::unrestricted() + ); + assert_eq!( + exec_request.network_sandbox_policy, + NetworkSandboxPolicy::Restricted + ); +} + +#[test] +fn normalize_additional_permissions_preserves_network() { + let temp_dir = TempDir::new().expect("create temp dir"); + let path = AbsolutePathBuf::from_absolute_path( + canonicalize(temp_dir.path()).expect("canonicalize temp dir"), + ) + .expect("absolute temp dir"); + let permissions = normalize_additional_permissions(PermissionProfile { + network: Some(NetworkPermissions { + enabled: Some(true), + }), + file_system: Some(FileSystemPermissions { + read: Some(vec![path.clone()]), + write: Some(vec![path.clone()]), + }), + ..Default::default() + }) + .expect("permissions"); + + assert_eq!( + permissions.network, + Some(NetworkPermissions { + enabled: Some(true), + }) + ); + assert_eq!( + permissions.file_system, + Some(FileSystemPermissions { + read: Some(vec![path.clone()]), + write: Some(vec![path]), + }) + ); +} + +#[test] +fn normalize_additional_permissions_drops_empty_nested_profiles() { + let permissions = normalize_additional_permissions(PermissionProfile { + network: Some(NetworkPermissions { enabled: None }), + file_system: Some(FileSystemPermissions { + read: None, + write: None, + }), + macos: None, + }) + .expect("permissions"); + + assert_eq!(permissions, PermissionProfile::default()); +} + +#[cfg(target_os = "macos")] +#[test] +fn normalize_additional_permissions_preserves_default_macos_preferences_permission() { + let permissions = normalize_additional_permissions(PermissionProfile { + macos: Some(MacOsSeatbeltProfileExtensions::default()), + ..Default::default() + }) + .expect("permissions"); + + assert_eq!( + permissions, + PermissionProfile { + macos: Some(MacOsSeatbeltProfileExtensions::default()), + ..Default::default() + } + ); +} + +#[cfg(target_os = "macos")] +#[test] +fn intersect_permission_profiles_preserves_default_macos_grants() { + let requested = PermissionProfile { + file_system: Some(FileSystemPermissions { + read: Some(Vec::from(["/tmp/requested" + .try_into() + .expect("absolute path")])), + write: None, + }), + macos: Some(MacOsSeatbeltProfileExtensions { + macos_preferences: MacOsPreferencesPermission::ReadWrite, + macos_automation: MacOsAutomationPermission::BundleIds(vec![ + "com.apple.Notes".to_string(), + ]), + macos_launch_services: false, + macos_accessibility: true, + macos_calendar: true, + macos_reminders: false, + macos_contacts: MacOsContactsPermission::None, + }), + ..Default::default() + }; + let granted = PermissionProfile { + file_system: Some(FileSystemPermissions { + read: Some(Vec::new()), + write: None, + }), + macos: Some(MacOsSeatbeltProfileExtensions::default()), + ..Default::default() + }; + + assert_eq!( + intersect_permission_profiles(requested, granted), + PermissionProfile { + macos: Some(MacOsSeatbeltProfileExtensions::default()), + ..Default::default() + } + ); +} + +#[cfg(target_os = "macos")] +#[test] +fn normalize_additional_permissions_preserves_macos_permissions() { + let permissions = normalize_additional_permissions(PermissionProfile { + macos: Some(MacOsSeatbeltProfileExtensions { + macos_preferences: MacOsPreferencesPermission::ReadWrite, + macos_automation: MacOsAutomationPermission::BundleIds(vec![ + "com.apple.Notes".to_string(), + ]), + macos_launch_services: true, + macos_accessibility: true, + macos_calendar: true, + macos_reminders: false, + macos_contacts: MacOsContactsPermission::None, + }), + ..Default::default() + }) + .expect("permissions"); + + assert_eq!( + permissions.macos, + Some(MacOsSeatbeltProfileExtensions { + macos_preferences: MacOsPreferencesPermission::ReadWrite, + macos_automation: MacOsAutomationPermission::BundleIds(vec![ + "com.apple.Notes".to_string(), + ]), + macos_launch_services: true, + macos_accessibility: true, + macos_calendar: true, + macos_reminders: false, + macos_contacts: MacOsContactsPermission::None, + }) + ); +} + +#[test] +fn read_only_additional_permissions_can_enable_network_without_writes() { + let temp_dir = TempDir::new().expect("create temp dir"); + let path = AbsolutePathBuf::from_absolute_path( + canonicalize(temp_dir.path()).expect("canonicalize temp dir"), + ) + .expect("absolute temp dir"); + let policy = sandbox_policy_with_additional_permissions( + &SandboxPolicy::ReadOnly { + access: ReadOnlyAccess::Restricted { + include_platform_defaults: true, + readable_roots: vec![path.clone()], + }, + network_access: false, + }, + &PermissionProfile { + network: Some(NetworkPermissions { + enabled: Some(true), + }), + file_system: Some(FileSystemPermissions { + read: Some(vec![path.clone()]), + write: Some(Vec::new()), + }), + ..Default::default() + }, + ); + + assert_eq!( + policy, + SandboxPolicy::ReadOnly { + access: ReadOnlyAccess::Restricted { + include_platform_defaults: true, + readable_roots: vec![path], + }, + network_access: true, + } + ); +} +#[cfg(target_os = "macos")] +#[test] +fn effective_permissions_merge_macos_extensions_with_additional_permissions() { + let temp_dir = TempDir::new().expect("create temp dir"); + let path = AbsolutePathBuf::from_absolute_path( + canonicalize(temp_dir.path()).expect("canonicalize temp dir"), + ) + .expect("absolute temp dir"); + let effective_permissions = EffectiveSandboxPermissions::new( + &SandboxPolicy::ReadOnly { + access: ReadOnlyAccess::Restricted { + include_platform_defaults: true, + readable_roots: vec![path.clone()], + }, + network_access: false, + }, + Some(&MacOsSeatbeltProfileExtensions { + macos_preferences: MacOsPreferencesPermission::ReadOnly, + macos_automation: MacOsAutomationPermission::BundleIds(vec![ + "com.apple.Calendar".to_string(), + ]), + macos_launch_services: false, + macos_accessibility: false, + macos_calendar: false, + macos_reminders: false, + macos_contacts: MacOsContactsPermission::None, + }), + Some(&PermissionProfile { + file_system: Some(FileSystemPermissions { + read: Some(vec![path]), + write: Some(Vec::new()), + }), + macos: Some(MacOsSeatbeltProfileExtensions { + macos_preferences: MacOsPreferencesPermission::ReadWrite, + macos_automation: MacOsAutomationPermission::BundleIds(vec![ + "com.apple.Notes".to_string(), + ]), + macos_launch_services: true, + macos_accessibility: true, + macos_calendar: true, + macos_reminders: false, + macos_contacts: MacOsContactsPermission::None, + }), + ..Default::default() + }), + ); + + assert_eq!( + effective_permissions.macos_seatbelt_profile_extensions, + Some(MacOsSeatbeltProfileExtensions { + macos_preferences: MacOsPreferencesPermission::ReadWrite, + macos_automation: MacOsAutomationPermission::BundleIds(vec![ + "com.apple.Calendar".to_string(), + "com.apple.Notes".to_string(), + ]), + macos_launch_services: true, + macos_accessibility: true, + macos_calendar: true, + macos_reminders: false, + macos_contacts: MacOsContactsPermission::None, + }) + ); +} + +#[test] +fn external_sandbox_additional_permissions_can_enable_network() { + let temp_dir = TempDir::new().expect("create temp dir"); + let path = AbsolutePathBuf::from_absolute_path( + canonicalize(temp_dir.path()).expect("canonicalize temp dir"), + ) + .expect("absolute temp dir"); + let policy = sandbox_policy_with_additional_permissions( + &SandboxPolicy::ExternalSandbox { + network_access: NetworkAccess::Restricted, + }, + &PermissionProfile { + network: Some(NetworkPermissions { + enabled: Some(true), + }), + file_system: Some(FileSystemPermissions { + read: Some(vec![path]), + write: Some(Vec::new()), + }), + ..Default::default() + }, + ); + + assert_eq!( + policy, + SandboxPolicy::ExternalSandbox { + network_access: NetworkAccess::Enabled, + } + ); +} + +#[test] +fn transform_additional_permissions_enable_network_for_external_sandbox() { + let manager = SandboxManager::new(); + let cwd = std::env::current_dir().expect("current dir"); + let temp_dir = TempDir::new().expect("create temp dir"); + let path = AbsolutePathBuf::from_absolute_path( + canonicalize(temp_dir.path()).expect("canonicalize temp dir"), + ) + .expect("absolute temp dir"); + let exec_request = manager + .transform(super::SandboxTransformRequest { + spec: super::CommandSpec { + program: "true".to_string(), + args: Vec::new(), + cwd: cwd.clone(), + env: HashMap::new(), + expiration: crate::exec::ExecExpiration::DefaultTimeout, + sandbox_permissions: super::SandboxPermissions::WithAdditionalPermissions, + additional_permissions: Some(PermissionProfile { + network: Some(NetworkPermissions { + enabled: Some(true), + }), + file_system: Some(FileSystemPermissions { + read: Some(vec![path]), + write: Some(Vec::new()), + }), + ..Default::default() + }), + justification: None, + }, + policy: &SandboxPolicy::ExternalSandbox { + network_access: NetworkAccess::Restricted, + }, + file_system_policy: &FileSystemSandboxPolicy::unrestricted(), + network_policy: NetworkSandboxPolicy::Restricted, + sandbox: SandboxType::None, + enforce_managed_network: false, + network: None, + sandbox_policy_cwd: cwd.as_path(), + #[cfg(target_os = "macos")] + macos_seatbelt_profile_extensions: None, + codex_linux_sandbox_exe: None, + use_legacy_landlock: false, + windows_sandbox_level: WindowsSandboxLevel::Disabled, + }) + .expect("transform"); + + assert_eq!( + exec_request.sandbox_policy, + SandboxPolicy::ExternalSandbox { + network_access: NetworkAccess::Enabled, + } + ); + assert_eq!( + exec_request.network_sandbox_policy, + NetworkSandboxPolicy::Enabled + ); +} + +#[test] +fn transform_additional_permissions_preserves_denied_entries() { + let manager = SandboxManager::new(); + let cwd = std::env::current_dir().expect("current dir"); + let temp_dir = TempDir::new().expect("create temp dir"); + let workspace_root = AbsolutePathBuf::from_absolute_path( + canonicalize(temp_dir.path()).expect("canonicalize temp dir"), + ) + .expect("absolute temp dir"); + let allowed_path = workspace_root.join("allowed").expect("allowed path"); + let denied_path = workspace_root.join("denied").expect("denied path"); + let exec_request = manager + .transform(super::SandboxTransformRequest { + spec: super::CommandSpec { + program: "true".to_string(), + args: Vec::new(), + cwd: cwd.clone(), + env: HashMap::new(), + expiration: crate::exec::ExecExpiration::DefaultTimeout, + sandbox_permissions: super::SandboxPermissions::WithAdditionalPermissions, + additional_permissions: Some(PermissionProfile { + file_system: Some(FileSystemPermissions { + read: None, + write: Some(vec![allowed_path.clone()]), + }), + ..Default::default() + }), + justification: None, + }, + policy: &SandboxPolicy::ReadOnly { + access: ReadOnlyAccess::FullAccess, + network_access: false, + }, + file_system_policy: &FileSystemSandboxPolicy::restricted(vec![ + FileSystemSandboxEntry { + path: FileSystemPath::Special { + value: FileSystemSpecialPath::Root, + }, + access: FileSystemAccessMode::Read, + }, + FileSystemSandboxEntry { + path: FileSystemPath::Path { + path: denied_path.clone(), + }, + access: FileSystemAccessMode::None, + }, + ]), + network_policy: NetworkSandboxPolicy::Restricted, + sandbox: SandboxType::None, + enforce_managed_network: false, + network: None, + sandbox_policy_cwd: cwd.as_path(), + #[cfg(target_os = "macos")] + macos_seatbelt_profile_extensions: None, + codex_linux_sandbox_exe: None, + use_legacy_landlock: false, + windows_sandbox_level: WindowsSandboxLevel::Disabled, + }) + .expect("transform"); + + assert_eq!( + exec_request.file_system_sandbox_policy, + FileSystemSandboxPolicy::restricted(vec![ + FileSystemSandboxEntry { + path: FileSystemPath::Special { + value: FileSystemSpecialPath::Root, + }, + access: FileSystemAccessMode::Read, + }, + FileSystemSandboxEntry { + path: FileSystemPath::Path { path: denied_path }, + access: FileSystemAccessMode::None, + }, + FileSystemSandboxEntry { + path: FileSystemPath::Path { path: allowed_path }, + access: FileSystemAccessMode::Write, + }, + ]) + ); + assert_eq!( + exec_request.network_sandbox_policy, + NetworkSandboxPolicy::Restricted + ); +} + +#[test] +fn merge_file_system_policy_with_additional_permissions_preserves_unreadable_roots() { + let temp_dir = TempDir::new().expect("create temp dir"); + let cwd = AbsolutePathBuf::from_absolute_path( + canonicalize(temp_dir.path()).expect("canonicalize temp dir"), + ) + .expect("absolute temp dir"); + let allowed_path = cwd.join("allowed").expect("allowed path"); + let denied_path = cwd.join("denied").expect("denied path"); + let merged_policy = merge_file_system_policy_with_additional_permissions( + &FileSystemSandboxPolicy::restricted(vec![ + FileSystemSandboxEntry { + path: FileSystemPath::Special { + value: FileSystemSpecialPath::Root, + }, + access: FileSystemAccessMode::Read, + }, + FileSystemSandboxEntry { + path: FileSystemPath::Path { + path: denied_path.clone(), + }, + access: FileSystemAccessMode::None, + }, + ]), + vec![allowed_path.clone()], + Vec::new(), + ); + + assert_eq!( + merged_policy.entries.contains(&FileSystemSandboxEntry { + path: FileSystemPath::Path { path: denied_path }, + access: FileSystemAccessMode::None, + }), + true + ); + assert_eq!( + merged_policy.entries.contains(&FileSystemSandboxEntry { + path: FileSystemPath::Path { path: allowed_path }, + access: FileSystemAccessMode::Read, + }), + true + ); +} + +#[test] +fn effective_file_system_sandbox_policy_returns_base_policy_without_additional_permissions() { + let temp_dir = TempDir::new().expect("create temp dir"); + let cwd = AbsolutePathBuf::from_absolute_path( + canonicalize(temp_dir.path()).expect("canonicalize temp dir"), + ) + .expect("absolute temp dir"); + let denied_path = cwd.join("denied").expect("denied path"); + let base_policy = FileSystemSandboxPolicy::restricted(vec![ + FileSystemSandboxEntry { + path: FileSystemPath::Special { + value: FileSystemSpecialPath::Root, + }, + access: FileSystemAccessMode::Read, + }, + FileSystemSandboxEntry { + path: FileSystemPath::Path { path: denied_path }, + access: FileSystemAccessMode::None, + }, + ]); + + let effective_policy = effective_file_system_sandbox_policy(&base_policy, None); + + assert_eq!(effective_policy, base_policy); +} + +#[test] +fn effective_file_system_sandbox_policy_merges_additional_write_roots() { + let temp_dir = TempDir::new().expect("create temp dir"); + let cwd = AbsolutePathBuf::from_absolute_path( + canonicalize(temp_dir.path()).expect("canonicalize temp dir"), + ) + .expect("absolute temp dir"); + let allowed_path = cwd.join("allowed").expect("allowed path"); + let denied_path = cwd.join("denied").expect("denied path"); + let base_policy = FileSystemSandboxPolicy::restricted(vec![ + FileSystemSandboxEntry { + path: FileSystemPath::Special { + value: FileSystemSpecialPath::Root, + }, + access: FileSystemAccessMode::Read, + }, + FileSystemSandboxEntry { + path: FileSystemPath::Path { + path: denied_path.clone(), + }, + access: FileSystemAccessMode::None, + }, + ]); + let additional_permissions = PermissionProfile { + file_system: Some(FileSystemPermissions { + read: Some(vec![]), + write: Some(vec![allowed_path.clone()]), + }), + ..Default::default() + }; + + let effective_policy = + effective_file_system_sandbox_policy(&base_policy, Some(&additional_permissions)); + + assert_eq!( + effective_policy.entries.contains(&FileSystemSandboxEntry { + path: FileSystemPath::Path { path: denied_path }, + access: FileSystemAccessMode::None, + }), + true + ); + assert_eq!( + effective_policy.entries.contains(&FileSystemSandboxEntry { + path: FileSystemPath::Path { path: allowed_path }, + access: FileSystemAccessMode::Write, + }), + true + ); +} diff --git a/codex-rs/core/src/seatbelt.rs b/codex-rs/core/src/seatbelt.rs index fa0538e384..672165526b 100644 --- a/codex-rs/core/src/seatbelt.rs +++ b/codex-rs/core/src/seatbelt.rs @@ -584,1064 +584,5 @@ fn macos_dir_params() -> Vec<(String, PathBuf)> { } #[cfg(test)] -mod tests { - use super::MACOS_SEATBELT_BASE_POLICY; - use super::ProxyPolicyInputs; - use super::UnixDomainSocketPolicy; - use super::create_seatbelt_command_args; - use super::create_seatbelt_command_args_for_policies_with_extensions; - use super::create_seatbelt_command_args_with_extensions; - use super::dynamic_network_policy; - use super::macos_dir_params; - use super::normalize_path_for_sandbox; - use super::unix_socket_dir_params; - use super::unix_socket_policy; - use crate::protocol::ReadOnlyAccess; - use crate::protocol::SandboxPolicy; - use crate::seatbelt::MACOS_PATH_TO_SEATBELT_EXECUTABLE; - use crate::seatbelt_permissions::MacOsAutomationPermission; - use crate::seatbelt_permissions::MacOsContactsPermission; - use crate::seatbelt_permissions::MacOsPreferencesPermission; - use crate::seatbelt_permissions::MacOsSeatbeltProfileExtensions; - use codex_protocol::permissions::FileSystemAccessMode; - use codex_protocol::permissions::FileSystemPath; - use codex_protocol::permissions::FileSystemSandboxEntry; - use codex_protocol::permissions::FileSystemSandboxPolicy; - use codex_protocol::permissions::NetworkSandboxPolicy; - use codex_utils_absolute_path::AbsolutePathBuf; - use pretty_assertions::assert_eq; - use std::fs; - use std::path::Path; - use std::path::PathBuf; - use std::process::Command; - use tempfile::TempDir; - - fn assert_seatbelt_denied(stderr: &[u8], path: &Path) { - let stderr = String::from_utf8_lossy(stderr); - let expected = format!("bash: {}: Operation not permitted\n", path.display()); - assert!( - stderr == expected - || stderr.contains("sandbox-exec: sandbox_apply: Operation not permitted"), - "unexpected stderr: {stderr}" - ); - } - - fn absolute_path(path: &str) -> AbsolutePathBuf { - AbsolutePathBuf::from_absolute_path(Path::new(path)).expect("absolute path") - } - - fn seatbelt_policy_arg(args: &[String]) -> &str { - let policy_index = args - .iter() - .position(|arg| arg == "-p") - .expect("seatbelt args should include -p"); - args.get(policy_index + 1) - .expect("seatbelt args should include policy text") - } - - #[test] - fn base_policy_allows_node_cpu_sysctls() { - assert!( - MACOS_SEATBELT_BASE_POLICY.contains("(sysctl-name \"machdep.cpu.brand_string\")"), - "base policy must allow CPU brand lookup for os.cpus()" - ); - assert!( - MACOS_SEATBELT_BASE_POLICY.contains("(sysctl-name \"hw.model\")"), - "base policy must allow hardware model lookup for os.cpus()" - ); - } - - #[test] - fn create_seatbelt_args_routes_network_through_proxy_ports() { - let policy = dynamic_network_policy( - &SandboxPolicy::new_read_only_policy(), - false, - &ProxyPolicyInputs { - ports: vec![43128, 48081], - has_proxy_config: true, - allow_local_binding: false, - ..ProxyPolicyInputs::default() - }, - ); - - assert!( - policy.contains("(allow network-outbound (remote ip \"localhost:43128\"))"), - "expected HTTP proxy port allow rule in policy:\n{policy}" - ); - assert!( - policy.contains("(allow network-outbound (remote ip \"localhost:48081\"))"), - "expected SOCKS proxy port allow rule in policy:\n{policy}" - ); - assert!( - !policy.contains("\n(allow network-outbound)\n"), - "policy should not include blanket outbound allowance when proxy ports are present:\n{policy}" - ); - assert!( - !policy.contains("(allow network-bind (local ip \"localhost:*\"))"), - "policy should not allow loopback binding unless explicitly enabled:\n{policy}" - ); - assert!( - !policy.contains("(allow network-inbound (local ip \"localhost:*\"))"), - "policy should not allow loopback inbound unless explicitly enabled:\n{policy}" - ); - } - - #[test] - fn explicit_unreadable_paths_are_excluded_from_full_disk_read_and_write_access() { - let unreadable = absolute_path("/tmp/codex-unreadable"); - let file_system_policy = FileSystemSandboxPolicy::restricted(vec![ - FileSystemSandboxEntry { - path: FileSystemPath::Special { - value: crate::protocol::FileSystemSpecialPath::Root, - }, - access: FileSystemAccessMode::Write, - }, - FileSystemSandboxEntry { - path: FileSystemPath::Path { path: unreadable }, - access: FileSystemAccessMode::None, - }, - ]); - - let args = create_seatbelt_command_args_for_policies_with_extensions( - vec!["/bin/true".to_string()], - &file_system_policy, - NetworkSandboxPolicy::Restricted, - Path::new("/"), - false, - None, - None, - ); - - let policy = seatbelt_policy_arg(&args); - assert!( - policy.contains("(require-not (subpath (param \"READABLE_ROOT_0_RO_0\")))"), - "expected read carveout in policy:\n{policy}" - ); - assert!( - policy.contains("(require-not (subpath (param \"WRITABLE_ROOT_0_RO_0\")))"), - "expected write carveout in policy:\n{policy}" - ); - assert!( - args.iter() - .any(|arg| arg == "-DREADABLE_ROOT_0_RO_0=/tmp/codex-unreadable"), - "expected read carveout parameter in args: {args:#?}" - ); - assert!( - args.iter() - .any(|arg| arg == "-DWRITABLE_ROOT_0_RO_0=/tmp/codex-unreadable"), - "expected write carveout parameter in args: {args:#?}" - ); - } - - #[test] - fn explicit_unreadable_paths_are_excluded_from_readable_roots() { - let root = absolute_path("/tmp/codex-readable"); - let unreadable = absolute_path("/tmp/codex-readable/private"); - let file_system_policy = FileSystemSandboxPolicy::restricted(vec![ - FileSystemSandboxEntry { - path: FileSystemPath::Path { path: root }, - access: FileSystemAccessMode::Read, - }, - FileSystemSandboxEntry { - path: FileSystemPath::Path { path: unreadable }, - access: FileSystemAccessMode::None, - }, - ]); - - let args = create_seatbelt_command_args_for_policies_with_extensions( - vec!["/bin/true".to_string()], - &file_system_policy, - NetworkSandboxPolicy::Restricted, - Path::new("/"), - false, - None, - None, - ); - - let policy = seatbelt_policy_arg(&args); - assert!( - policy.contains("(require-not (subpath (param \"READABLE_ROOT_0_RO_0\")))"), - "expected read carveout in policy:\n{policy}" - ); - assert!( - args.iter() - .any(|arg| arg == "-DREADABLE_ROOT_0=/tmp/codex-readable"), - "expected readable root parameter in args: {args:#?}" - ); - assert!( - args.iter() - .any(|arg| arg == "-DREADABLE_ROOT_0_RO_0=/tmp/codex-readable/private"), - "expected read carveout parameter in args: {args:#?}" - ); - } - - #[test] - fn seatbelt_args_include_macos_permission_extensions() { - let cwd = std::env::temp_dir(); - let args = create_seatbelt_command_args_with_extensions( - vec!["echo".to_string(), "ok".to_string()], - &SandboxPolicy::new_read_only_policy(), - cwd.as_path(), - false, - None, - Some(&MacOsSeatbeltProfileExtensions { - macos_preferences: MacOsPreferencesPermission::ReadWrite, - macos_automation: MacOsAutomationPermission::BundleIds(vec![ - "com.apple.Notes".to_string(), - ]), - macos_launch_services: true, - macos_accessibility: true, - macos_calendar: true, - macos_reminders: false, - macos_contacts: MacOsContactsPermission::None, - }), - ); - let policy = &args[1]; - - assert!(policy.contains("(allow user-preference-write)")); - assert!(policy.contains("(appleevent-destination \"com.apple.Notes\")")); - assert!(policy.contains("com.apple.axserver")); - assert!(policy.contains("com.apple.CalendarAgent")); - } - - #[test] - fn bundle_id_automation_keeps_lsopen_denied() { - let tmp = TempDir::new().expect("tempdir"); - let cwd = tmp.path().join("cwd"); - fs::create_dir_all(&cwd).expect("create cwd"); - - let args = create_seatbelt_command_args_with_extensions( - vec![ - "/usr/bin/python3".to_string(), - "-c".to_string(), - r#"import ctypes -import os -import sys -lib = ctypes.CDLL("/usr/lib/libsandbox.1.dylib") -lib.sandbox_check.restype = ctypes.c_int -allowed = lib.sandbox_check(os.getpid(), b"lsopen", 0) == 0 -sys.exit(0 if allowed else 13) -"# - .to_string(), - ], - &SandboxPolicy::new_read_only_policy(), - cwd.as_path(), - false, - None, - Some(&MacOsSeatbeltProfileExtensions { - macos_automation: MacOsAutomationPermission::BundleIds(vec![ - "com.apple.Notes".to_string(), - ]), - ..Default::default() - }), - ); - - let output = Command::new(MACOS_PATH_TO_SEATBELT_EXECUTABLE) - .args(&args) - .current_dir(&cwd) - .output() - .expect("execute seatbelt command"); - - let stderr = String::from_utf8_lossy(&output.stderr); - if stderr.contains("sandbox-exec: sandbox_apply: Operation not permitted") { - return; - } - - assert_eq!( - Some(13), - output.status.code(), - "lsopen should remain denied even with bundle-scoped automation\nstdout: {}\nstderr: {stderr}", - String::from_utf8_lossy(&output.stdout), - ); - } - - #[test] - fn seatbelt_args_without_extension_profile_keep_legacy_preferences_read_access() { - let cwd = std::env::temp_dir(); - let args = create_seatbelt_command_args( - vec!["echo".to_string(), "ok".to_string()], - &SandboxPolicy::new_read_only_policy(), - cwd.as_path(), - false, - None, - ); - let policy = &args[1]; - assert!(policy.contains("(allow user-preference-read)")); - assert!(!policy.contains("(allow user-preference-write)")); - } - - #[test] - fn seatbelt_legacy_workspace_write_nested_readable_root_stays_writable() { - let tmp = TempDir::new().expect("tempdir"); - let cwd = tmp.path().join("workspace"); - fs::create_dir_all(cwd.join("docs")).expect("create docs"); - let docs = AbsolutePathBuf::from_absolute_path(cwd.join("docs")).expect("absolute docs"); - let args = create_seatbelt_command_args( - vec!["/bin/true".to_string()], - &SandboxPolicy::WorkspaceWrite { - writable_roots: vec![], - read_only_access: ReadOnlyAccess::Restricted { - include_platform_defaults: true, - readable_roots: vec![docs.clone()], - }, - network_access: false, - exclude_tmpdir_env_var: true, - exclude_slash_tmp: true, - }, - cwd.as_path(), - false, - None, - ); - - let docs_param = format!("-DWRITABLE_ROOT_0_RO_0={}", docs.as_path().display()); - assert!( - !seatbelt_policy_arg(&args).contains("WRITABLE_ROOT_0_RO_0"), - "legacy workspace-write readable roots under cwd should not become seatbelt carveouts:\n{args:#?}" - ); - assert!( - !args.iter().any(|arg| arg == &docs_param), - "unexpected seatbelt carveout parameter for redundant legacy readable root: {args:#?}" - ); - } - - #[test] - fn seatbelt_args_default_extension_profile_keeps_preferences_read_access() { - let cwd = std::env::temp_dir(); - let args = create_seatbelt_command_args_with_extensions( - vec!["echo".to_string(), "ok".to_string()], - &SandboxPolicy::new_read_only_policy(), - cwd.as_path(), - false, - None, - Some(&MacOsSeatbeltProfileExtensions::default()), - ); - let policy = &args[1]; - assert!(!policy.contains("appleevent-send")); - assert!(!policy.contains("com.apple.axserver")); - assert!(!policy.contains("com.apple.CalendarAgent")); - assert!(policy.contains("(allow user-preference-read)")); - assert!(!policy.contains("user-preference-write")); - } - - #[test] - fn create_seatbelt_args_allows_local_binding_when_explicitly_enabled() { - let policy = dynamic_network_policy( - &SandboxPolicy::new_read_only_policy(), - false, - &ProxyPolicyInputs { - ports: vec![43128], - has_proxy_config: true, - allow_local_binding: true, - ..ProxyPolicyInputs::default() - }, - ); - - assert!( - policy.contains("(allow network-bind (local ip \"localhost:*\"))"), - "policy should allow loopback local binding when explicitly enabled:\n{policy}" - ); - assert!( - policy.contains("(allow network-inbound (local ip \"localhost:*\"))"), - "policy should allow loopback inbound when explicitly enabled:\n{policy}" - ); - assert!( - policy.contains("(allow network-outbound (remote ip \"localhost:*\"))"), - "policy should allow loopback outbound when explicitly enabled:\n{policy}" - ); - assert!( - !policy.contains("\n(allow network-outbound)\n"), - "policy should keep proxy-routed behavior without blanket outbound allowance:\n{policy}" - ); - } - - #[test] - fn dynamic_network_policy_preserves_restricted_policy_when_proxy_config_without_ports() { - let policy = dynamic_network_policy( - &SandboxPolicy::WorkspaceWrite { - writable_roots: vec![], - read_only_access: Default::default(), - network_access: true, - exclude_tmpdir_env_var: false, - exclude_slash_tmp: false, - }, - false, - &ProxyPolicyInputs { - ports: vec![], - has_proxy_config: true, - allow_local_binding: false, - ..ProxyPolicyInputs::default() - }, - ); - - assert!( - policy.contains("(socket-domain AF_SYSTEM)"), - "policy should keep the restricted network profile when proxy config is present without ports:\n{policy}" - ); - assert!( - !policy.contains("\n(allow network-outbound)\n"), - "policy should not include blanket outbound allowance when proxy config is present without ports:\n{policy}" - ); - assert!( - !policy.contains("(allow network-outbound (remote ip \"localhost:"), - "policy should not include proxy port allowance when proxy config is present without ports:\n{policy}" - ); - } - - #[test] - fn dynamic_network_policy_preserves_restricted_policy_for_managed_network_without_proxy_config() - { - let policy = dynamic_network_policy( - &SandboxPolicy::WorkspaceWrite { - writable_roots: vec![], - read_only_access: Default::default(), - network_access: true, - exclude_tmpdir_env_var: false, - exclude_slash_tmp: false, - }, - true, - &ProxyPolicyInputs { - ports: vec![], - has_proxy_config: false, - allow_local_binding: false, - ..ProxyPolicyInputs::default() - }, - ); - - assert!( - policy.contains("(socket-domain AF_SYSTEM)"), - "policy should keep the restricted network profile when managed network is active without proxy endpoints:\n{policy}" - ); - assert!( - !policy.contains("\n(allow network-outbound)\n"), - "policy should not include blanket outbound allowance when managed network is active without proxy endpoints:\n{policy}" - ); - } - - #[test] - fn create_seatbelt_args_allowlists_unix_socket_paths() { - let policy = dynamic_network_policy( - &SandboxPolicy::new_read_only_policy(), - false, - &ProxyPolicyInputs { - ports: vec![43128], - has_proxy_config: true, - allow_local_binding: false, - unix_domain_socket_policy: UnixDomainSocketPolicy::Restricted { - allowed: vec![absolute_path("/tmp/example.sock")], - }, - }, - ); - - assert!( - policy.contains("(allow system-socket (socket-domain AF_UNIX))"), - "policy should allow AF_UNIX socket creation for configured unix sockets:\n{policy}" - ); - assert!( - policy.contains( - "(allow network-bind (local unix-socket (subpath (param \"UNIX_SOCKET_PATH_0\"))))" - ), - "policy should allow binding explicitly configured unix sockets:\n{policy}" - ); - assert!( - policy.contains( - "(allow network-outbound (remote unix-socket (subpath (param \"UNIX_SOCKET_PATH_0\"))))" - ), - "policy should allow connecting to explicitly configured unix sockets:\n{policy}" - ); - assert!( - !policy.contains("(allow network* (subpath"), - "policy should no longer use the generic subpath unix-socket rules:\n{policy}" - ); - } - - #[test] - fn unix_socket_policy_non_empty_output_is_newline_terminated() { - let allowlist_policy = unix_socket_policy(&ProxyPolicyInputs { - unix_domain_socket_policy: UnixDomainSocketPolicy::Restricted { - allowed: vec![absolute_path("/tmp/example.sock")], - }, - ..ProxyPolicyInputs::default() - }); - assert!( - allowlist_policy.ends_with('\n'), - "allowlist unix socket policy should end with a newline:\n{allowlist_policy}" - ); - - let allow_all_policy = unix_socket_policy(&ProxyPolicyInputs { - unix_domain_socket_policy: UnixDomainSocketPolicy::AllowAll, - ..ProxyPolicyInputs::default() - }); - assert!( - allow_all_policy.ends_with('\n'), - "allow-all unix socket policy should end with a newline:\n{allow_all_policy}" - ); - } - - #[test] - fn unix_socket_dir_params_use_stable_param_names() { - let params = unix_socket_dir_params(&ProxyPolicyInputs { - unix_domain_socket_policy: UnixDomainSocketPolicy::Restricted { - allowed: vec![ - absolute_path("/tmp/b.sock"), - absolute_path("/tmp/a.sock"), - absolute_path("/tmp/a.sock"), - ], - }, - ..ProxyPolicyInputs::default() - }); - - assert_eq!( - params, - vec![ - ( - "UNIX_SOCKET_PATH_0".to_string(), - PathBuf::from("/tmp/a.sock") - ), - ( - "UNIX_SOCKET_PATH_1".to_string(), - PathBuf::from("/tmp/b.sock") - ), - ] - ); - } - - #[test] - fn normalize_path_for_sandbox_rejects_relative_paths() { - assert_eq!(normalize_path_for_sandbox(Path::new("relative.sock")), None); - } - - #[test] - fn create_seatbelt_args_allows_all_unix_sockets_when_enabled() { - let policy = dynamic_network_policy( - &SandboxPolicy::new_read_only_policy(), - false, - &ProxyPolicyInputs { - ports: vec![43128], - has_proxy_config: true, - allow_local_binding: false, - unix_domain_socket_policy: UnixDomainSocketPolicy::AllowAll, - }, - ); - - assert!( - policy.contains("(allow system-socket (socket-domain AF_UNIX))"), - "policy should allow AF_UNIX socket creation when unix sockets are enabled:\n{policy}" - ); - assert!( - policy.contains("(allow network-bind (local unix-socket))"), - "policy should allow binding unix sockets when enabled:\n{policy}" - ); - assert!( - policy.contains("(allow network-outbound (remote unix-socket))"), - "policy should allow connecting to unix sockets when enabled:\n{policy}" - ); - assert!( - !policy.contains("(allow network* (subpath"), - "policy should no longer use the generic subpath unix-socket rules:\n{policy}" - ); - } - - #[test] - fn create_seatbelt_args_full_network_with_proxy_is_still_proxy_only() { - let policy = dynamic_network_policy( - &SandboxPolicy::WorkspaceWrite { - writable_roots: vec![], - read_only_access: Default::default(), - network_access: true, - exclude_tmpdir_env_var: false, - exclude_slash_tmp: false, - }, - false, - &ProxyPolicyInputs { - ports: vec![43128], - has_proxy_config: true, - allow_local_binding: false, - ..ProxyPolicyInputs::default() - }, - ); - - assert!( - policy.contains("(allow network-outbound (remote ip \"localhost:43128\"))"), - "expected proxy endpoint allow rule in policy:\n{policy}" - ); - assert!( - !policy.contains("\n(allow network-outbound)\n"), - "policy should not include blanket outbound allowance when proxy is configured:\n{policy}" - ); - assert!( - !policy.contains("\n(allow network-inbound)\n"), - "policy should not include blanket inbound allowance when proxy is configured:\n{policy}" - ); - } - - #[test] - fn create_seatbelt_args_with_read_only_git_and_codex_subpaths() { - // Create a temporary workspace with two writable roots: one containing - // top-level .git and .codex directories and one without them. - let tmp = TempDir::new().expect("tempdir"); - let PopulatedTmp { - vulnerable_root, - vulnerable_root_canonical, - dot_git_canonical, - dot_codex_canonical, - empty_root, - empty_root_canonical, - } = populate_tmpdir(tmp.path()); - let cwd = tmp.path().join("cwd"); - fs::create_dir_all(&cwd).expect("create cwd"); - - // Build a policy that only includes the two test roots as writable and - // does not automatically include defaults TMPDIR or /tmp. - let policy = SandboxPolicy::WorkspaceWrite { - writable_roots: vec![vulnerable_root, empty_root] - .into_iter() - .map(|p| p.try_into().unwrap()) - .collect(), - read_only_access: Default::default(), - network_access: false, - exclude_tmpdir_env_var: true, - exclude_slash_tmp: true, - }; - - // Create the Seatbelt command to wrap a shell command that tries to - // write to .codex/config.toml in the vulnerable root. - let shell_command: Vec = [ - "bash", - "-c", - "echo 'sandbox_mode = \"danger-full-access\"' > \"$1\"", - "bash", - dot_codex_canonical - .join("config.toml") - .to_string_lossy() - .as_ref(), - ] - .iter() - .map(std::string::ToString::to_string) - .collect(); - let args = create_seatbelt_command_args(shell_command.clone(), &policy, &cwd, false, None); - - // Build the expected policy text using a raw string for readability. - // Note that the policy includes: - // - the base policy, - // - read-only access to the filesystem, - // - write access to WRITABLE_ROOT_0 (but not its .git or .codex), WRITABLE_ROOT_1, and cwd as WRITABLE_ROOT_2. - let expected_policy = format!( - r#"{MACOS_SEATBELT_BASE_POLICY} -; allow read-only file operations -(allow file-read*) -(allow file-write* -(subpath (param "WRITABLE_ROOT_0")) (require-all (subpath (param "WRITABLE_ROOT_1")) (require-not (subpath (param "WRITABLE_ROOT_1_RO_0"))) (require-not (subpath (param "WRITABLE_ROOT_1_RO_1"))) ) (subpath (param "WRITABLE_ROOT_2")) -) - -; macOS permission profile extensions -(allow ipc-posix-shm-read* (ipc-posix-name-prefix "apple.cfprefs.")) -(allow mach-lookup - (global-name "com.apple.cfprefsd.daemon") - (global-name "com.apple.cfprefsd.agent") - (local-name "com.apple.cfprefsd.agent")) -(allow user-preference-read) -"#, - ); - - assert_eq!(seatbelt_policy_arg(&args), expected_policy); - - let expected_definitions = [ - format!( - "-DWRITABLE_ROOT_0={}", - cwd.canonicalize() - .expect("canonicalize cwd") - .to_string_lossy() - ), - format!( - "-DWRITABLE_ROOT_1={}", - vulnerable_root_canonical.to_string_lossy() - ), - format!( - "-DWRITABLE_ROOT_1_RO_0={}", - dot_git_canonical.to_string_lossy() - ), - format!( - "-DWRITABLE_ROOT_1_RO_1={}", - dot_codex_canonical.to_string_lossy() - ), - format!( - "-DWRITABLE_ROOT_2={}", - empty_root_canonical.to_string_lossy() - ), - ]; - for expected_definition in expected_definitions { - assert!( - args.contains(&expected_definition), - "expected definition arg `{expected_definition}` in {args:#?}" - ); - } - for (key, value) in macos_dir_params() { - let expected_definition = format!("-D{key}={}", value.to_string_lossy()); - assert!( - args.contains(&expected_definition), - "expected definition arg `{expected_definition}` in {args:#?}" - ); - } - - let command_index = args - .iter() - .position(|arg| arg == "--") - .expect("seatbelt args should include command separator"); - assert_eq!(args[command_index + 1..], shell_command); - - // Verify that .codex/config.toml cannot be modified under the generated - // Seatbelt policy. - let config_toml = dot_codex_canonical.join("config.toml"); - let output = Command::new(MACOS_PATH_TO_SEATBELT_EXECUTABLE) - .args(&args) - .current_dir(&cwd) - .output() - .expect("execute seatbelt command"); - assert_eq!( - "sandbox_mode = \"read-only\"\n", - String::from_utf8_lossy(&fs::read(&config_toml).expect("read config.toml")), - "config.toml should contain its original contents because it should not have been modified" - ); - assert!( - !output.status.success(), - "command to write {} should fail under seatbelt", - &config_toml.display() - ); - assert_seatbelt_denied(&output.stderr, &config_toml); - - // Create a similar Seatbelt command that tries to write to a file in - // the .git folder, which should also be blocked. - let pre_commit_hook = dot_git_canonical.join("hooks").join("pre-commit"); - let shell_command_git: Vec = [ - "bash", - "-c", - "echo 'pwned!' > \"$1\"", - "bash", - pre_commit_hook.to_string_lossy().as_ref(), - ] - .iter() - .map(std::string::ToString::to_string) - .collect(); - let write_hooks_file_args = - create_seatbelt_command_args(shell_command_git, &policy, &cwd, false, None); - let output = Command::new(MACOS_PATH_TO_SEATBELT_EXECUTABLE) - .args(&write_hooks_file_args) - .current_dir(&cwd) - .output() - .expect("execute seatbelt command"); - assert!( - !fs::exists(&pre_commit_hook).expect("exists pre-commit hook"), - "{} should not exist because it should not have been created", - pre_commit_hook.display() - ); - assert!( - !output.status.success(), - "command to write {} should fail under seatbelt", - &pre_commit_hook.display() - ); - assert_seatbelt_denied(&output.stderr, &pre_commit_hook); - - // Verify that writing a file to the folder containing .git and .codex is allowed. - let allowed_file = vulnerable_root_canonical.join("allowed.txt"); - let shell_command_allowed: Vec = [ - "bash", - "-c", - "echo 'this is allowed' > \"$1\"", - "bash", - allowed_file.to_string_lossy().as_ref(), - ] - .iter() - .map(std::string::ToString::to_string) - .collect(); - let write_allowed_file_args = - create_seatbelt_command_args(shell_command_allowed, &policy, &cwd, false, None); - let output = Command::new(MACOS_PATH_TO_SEATBELT_EXECUTABLE) - .args(&write_allowed_file_args) - .current_dir(&cwd) - .output() - .expect("execute seatbelt command"); - let stderr = String::from_utf8_lossy(&output.stderr); - if !output.status.success() - && stderr.contains("sandbox-exec: sandbox_apply: Operation not permitted") - { - return; - } - assert!( - output.status.success(), - "command to write {} should succeed under seatbelt", - &allowed_file.display() - ); - assert_eq!( - "this is allowed\n", - String::from_utf8_lossy(&fs::read(&allowed_file).expect("read allowed.txt")), - "{} should contain the written text", - allowed_file.display() - ); - } - - #[test] - fn create_seatbelt_args_with_read_only_git_pointer_file() { - let tmp = TempDir::new().expect("tempdir"); - let worktree_root = tmp.path().join("worktree_root"); - fs::create_dir_all(&worktree_root).expect("create worktree_root"); - let gitdir = worktree_root.join("actual-gitdir"); - fs::create_dir_all(&gitdir).expect("create gitdir"); - let gitdir_config = gitdir.join("config"); - let gitdir_config_contents = "[core]\n"; - fs::write(&gitdir_config, gitdir_config_contents).expect("write gitdir config"); - - let dot_git = worktree_root.join(".git"); - let dot_git_contents = format!("gitdir: {}\n", gitdir.to_string_lossy()); - fs::write(&dot_git, &dot_git_contents).expect("write .git pointer"); - - let cwd = tmp.path().join("cwd"); - fs::create_dir_all(&cwd).expect("create cwd"); - - let policy = SandboxPolicy::WorkspaceWrite { - writable_roots: vec![worktree_root.try_into().expect("worktree_root is absolute")], - read_only_access: Default::default(), - network_access: false, - exclude_tmpdir_env_var: true, - exclude_slash_tmp: true, - }; - - let shell_command: Vec = [ - "bash", - "-c", - "echo 'pwned!' > \"$1\"", - "bash", - dot_git.to_string_lossy().as_ref(), - ] - .iter() - .map(std::string::ToString::to_string) - .collect(); - let args = create_seatbelt_command_args(shell_command, &policy, &cwd, false, None); - - let output = Command::new(MACOS_PATH_TO_SEATBELT_EXECUTABLE) - .args(&args) - .current_dir(&cwd) - .output() - .expect("execute seatbelt command"); - - assert_eq!( - dot_git_contents, - String::from_utf8_lossy(&fs::read(&dot_git).expect("read .git pointer")), - ".git pointer file should not be modified under seatbelt" - ); - assert!( - !output.status.success(), - "command to write {} should fail under seatbelt", - dot_git.display() - ); - assert_seatbelt_denied(&output.stderr, &dot_git); - - let shell_command_gitdir: Vec = [ - "bash", - "-c", - "echo 'pwned!' > \"$1\"", - "bash", - gitdir_config.to_string_lossy().as_ref(), - ] - .iter() - .map(std::string::ToString::to_string) - .collect(); - let gitdir_args = - create_seatbelt_command_args(shell_command_gitdir, &policy, &cwd, false, None); - let output = Command::new(MACOS_PATH_TO_SEATBELT_EXECUTABLE) - .args(&gitdir_args) - .current_dir(&cwd) - .output() - .expect("execute seatbelt command"); - - assert_eq!( - gitdir_config_contents, - String::from_utf8_lossy(&fs::read(&gitdir_config).expect("read gitdir config")), - "gitdir config should contain its original contents because it should not have been modified" - ); - assert!( - !output.status.success(), - "command to write {} should fail under seatbelt", - gitdir_config.display() - ); - assert_seatbelt_denied(&output.stderr, &gitdir_config); - } - - #[test] - fn create_seatbelt_args_for_cwd_as_git_repo() { - // Create a temporary workspace with two writable roots: one containing - // top-level .git and .codex directories and one without them. - let tmp = TempDir::new().expect("tempdir"); - let PopulatedTmp { - vulnerable_root, - vulnerable_root_canonical, - dot_git_canonical, - dot_codex_canonical, - .. - } = populate_tmpdir(tmp.path()); - - // Build a policy that does not specify any writable_roots, but does - // use the default ones (cwd and TMPDIR) and verifies the `.git` and - // `.codex` checks are done properly for cwd. - let policy = SandboxPolicy::WorkspaceWrite { - writable_roots: vec![], - read_only_access: Default::default(), - network_access: false, - exclude_tmpdir_env_var: false, - exclude_slash_tmp: false, - }; - - let shell_command: Vec = [ - "bash", - "-c", - "echo 'sandbox_mode = \"danger-full-access\"' > \"$1\"", - "bash", - dot_codex_canonical - .join("config.toml") - .to_string_lossy() - .as_ref(), - ] - .iter() - .map(std::string::ToString::to_string) - .collect(); - let args = create_seatbelt_command_args( - shell_command.clone(), - &policy, - vulnerable_root.as_path(), - false, - None, - ); - - let tmpdir_env_var = std::env::var("TMPDIR") - .ok() - .map(PathBuf::from) - .and_then(|p| p.canonicalize().ok()) - .map(|p| p.to_string_lossy().to_string()); - - let tempdir_policy_entry = if tmpdir_env_var.is_some() { - r#" (subpath (param "WRITABLE_ROOT_2"))"# - } else { - "" - }; - - // Build the expected policy text using a raw string for readability. - // Note that the policy includes: - // - the base policy, - // - read-only access to the filesystem, - // - write access to WRITABLE_ROOT_0 (but not its .git or .codex), WRITABLE_ROOT_1, and cwd as WRITABLE_ROOT_2. - let expected_policy = format!( - r#"{MACOS_SEATBELT_BASE_POLICY} -; allow read-only file operations -(allow file-read*) -(allow file-write* -(require-all (subpath (param "WRITABLE_ROOT_0")) (require-not (subpath (param "WRITABLE_ROOT_0_RO_0"))) (require-not (subpath (param "WRITABLE_ROOT_0_RO_1"))) ) (subpath (param "WRITABLE_ROOT_1")){tempdir_policy_entry} -) - -; macOS permission profile extensions -(allow ipc-posix-shm-read* (ipc-posix-name-prefix "apple.cfprefs.")) -(allow mach-lookup - (global-name "com.apple.cfprefsd.daemon") - (global-name "com.apple.cfprefsd.agent") - (local-name "com.apple.cfprefsd.agent")) -(allow user-preference-read) -"#, - ); - - let mut expected_args = vec![ - "-p".to_string(), - expected_policy, - format!( - "-DWRITABLE_ROOT_0={}", - vulnerable_root_canonical.to_string_lossy() - ), - format!( - "-DWRITABLE_ROOT_0_RO_0={}", - dot_git_canonical.to_string_lossy() - ), - format!( - "-DWRITABLE_ROOT_0_RO_1={}", - dot_codex_canonical.to_string_lossy() - ), - format!( - "-DWRITABLE_ROOT_1={}", - PathBuf::from("/tmp") - .canonicalize() - .expect("canonicalize /tmp") - .to_string_lossy() - ), - ]; - - if let Some(p) = tmpdir_env_var { - expected_args.push(format!("-DWRITABLE_ROOT_2={p}")); - } - - expected_args.extend( - macos_dir_params() - .into_iter() - .map(|(key, value)| format!("-D{key}={value}", value = value.to_string_lossy())), - ); - - expected_args.push("--".to_string()); - expected_args.extend(shell_command); - - assert_eq!(expected_args, args); - } - - struct PopulatedTmp { - /// Path containing a .git and .codex subfolder. - /// For the purposes of this test, we consider this a "vulnerable" root - /// because a bad actor could write to .git/hooks/pre-commit so an - /// unsuspecting user would run code as privileged the next time they - /// ran `git commit` themselves, or modified .codex/config.toml to - /// contain `sandbox_mode = "danger-full-access"` so the agent would - /// have full privileges the next time it ran in that repo. - vulnerable_root: PathBuf, - vulnerable_root_canonical: PathBuf, - dot_git_canonical: PathBuf, - dot_codex_canonical: PathBuf, - - /// Path without .git or .codex subfolders. - empty_root: PathBuf, - /// Canonicalized version of `empty_root`. - empty_root_canonical: PathBuf, - } - - fn populate_tmpdir(tmp: &Path) -> PopulatedTmp { - let vulnerable_root = tmp.join("vulnerable_root"); - fs::create_dir_all(&vulnerable_root).expect("create vulnerable_root"); - - // TODO(mbolin): Should also support the case where `.git` is a file - // with a gitdir: ... line. - Command::new("git") - .arg("init") - .arg(".") - .current_dir(&vulnerable_root) - .output() - .expect("git init ."); - - fs::create_dir_all(vulnerable_root.join(".codex")).expect("create .codex"); - fs::write( - vulnerable_root.join(".codex").join("config.toml"), - "sandbox_mode = \"read-only\"\n", - ) - .expect("write .codex/config.toml"); - - let empty_root = tmp.join("empty_root"); - fs::create_dir_all(&empty_root).expect("create empty_root"); - - // Ensure we have canonical paths for -D parameter matching. - let vulnerable_root_canonical = vulnerable_root - .canonicalize() - .expect("canonicalize vulnerable_root"); - let dot_git_canonical = vulnerable_root_canonical.join(".git"); - let dot_codex_canonical = vulnerable_root_canonical.join(".codex"); - let empty_root_canonical = empty_root.canonicalize().expect("canonicalize empty_root"); - PopulatedTmp { - vulnerable_root, - vulnerable_root_canonical, - dot_git_canonical, - dot_codex_canonical, - empty_root, - empty_root_canonical, - } - } -} +#[path = "seatbelt_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/seatbelt_permissions.rs b/codex-rs/core/src/seatbelt_permissions.rs index 219ca332e7..5cbfd65094 100644 --- a/codex-rs/core/src/seatbelt_permissions.rs +++ b/codex-rs/core/src/seatbelt_permissions.rs @@ -188,157 +188,5 @@ fn is_valid_bundle_id(bundle_id: &str) -> bool { } #[cfg(test)] -mod tests { - use super::MacOsAutomationPermission; - use super::MacOsContactsPermission; - use super::MacOsPreferencesPermission; - use super::MacOsSeatbeltProfileExtensions; - use super::build_seatbelt_extensions; - - #[test] - fn preferences_read_only_emits_read_clauses_only() { - let policy = build_seatbelt_extensions(&MacOsSeatbeltProfileExtensions { - macos_preferences: MacOsPreferencesPermission::ReadOnly, - ..Default::default() - }); - assert!(policy.policy.contains("(allow user-preference-read)")); - assert!(!policy.policy.contains("(allow user-preference-write)")); - } - - #[test] - fn preferences_read_write_emits_write_clauses() { - let policy = build_seatbelt_extensions(&MacOsSeatbeltProfileExtensions { - macos_preferences: MacOsPreferencesPermission::ReadWrite, - ..Default::default() - }); - assert!(policy.policy.contains("(allow user-preference-read)")); - assert!(policy.policy.contains("(allow user-preference-write)")); - assert!(policy.policy.contains( - "(allow ipc-posix-shm-write-create (ipc-posix-name-prefix \"apple.cfprefs.\"))" - )); - } - - #[test] - fn automation_all_emits_unscoped_appleevents() { - let policy = build_seatbelt_extensions(&MacOsSeatbeltProfileExtensions { - macos_automation: MacOsAutomationPermission::All, - ..Default::default() - }); - assert!(policy.policy.contains("(allow appleevent-send)")); - assert!(policy.policy.contains("com.apple.coreservices.appleevents")); - } - - #[test] - fn automation_bundle_ids_are_normalized_and_scoped() { - let policy = build_seatbelt_extensions(&MacOsSeatbeltProfileExtensions { - macos_automation: MacOsAutomationPermission::BundleIds(vec![ - " com.apple.Notes ".to_string(), - "com.apple.Calendar".to_string(), - "bad bundle".to_string(), - "com.apple.Notes".to_string(), - ]), - ..Default::default() - }); - assert!( - policy - .policy - .contains("(appleevent-destination \"com.apple.Calendar\")") - ); - assert!( - policy - .policy - .contains("(appleevent-destination \"com.apple.Notes\")") - ); - assert!(!policy.policy.contains("bad bundle")); - assert!(policy.policy.contains("com.apple.coreservices.appleevents")); - } - - #[test] - fn launch_services_emit_launch_clauses() { - let policy = build_seatbelt_extensions(&MacOsSeatbeltProfileExtensions { - macos_launch_services: true, - ..Default::default() - }); - assert!( - policy - .policy - .contains("com.apple.coreservices.launchservicesd") - ); - assert!(policy.policy.contains("com.apple.lsd.mapdb")); - assert!( - policy - .policy - .contains("com.apple.coreservices.quarantine-resolver") - ); - assert!(policy.policy.contains("com.apple.lsd.modifydb")); - assert!(policy.policy.contains("(allow lsopen)")); - } - - #[test] - fn accessibility_and_calendar_emit_mach_lookups() { - let policy = build_seatbelt_extensions(&MacOsSeatbeltProfileExtensions { - macos_accessibility: true, - macos_calendar: true, - ..Default::default() - }); - assert!(policy.policy.contains("com.apple.axserver")); - assert!(policy.policy.contains("com.apple.CalendarAgent")); - } - - #[test] - fn reminders_emit_calendar_agent_and_remindd_lookups() { - let policy = build_seatbelt_extensions(&MacOsSeatbeltProfileExtensions { - macos_reminders: true, - ..Default::default() - }); - assert!(policy.policy.contains("com.apple.CalendarAgent")); - assert!(policy.policy.contains("com.apple.remindd")); - } - - #[test] - fn contacts_read_only_emit_contacts_read_clauses() { - let policy = build_seatbelt_extensions(&MacOsSeatbeltProfileExtensions { - macos_contacts: MacOsContactsPermission::ReadOnly, - ..Default::default() - }); - - assert!( - policy - .policy - .contains("(subpath \"/System/Library/Address Book Plug-Ins\")") - ); - assert!( - policy - .policy - .contains("(subpath (param \"ADDRESSBOOK_DIR\"))") - ); - assert!(policy.policy.contains("com.apple.contactsd.persistence")); - assert!(policy.policy.contains("com.apple.accountsd.accountmanager")); - assert!(!policy.policy.contains("com.apple.securityd.xpc")); - assert!( - policy - .dir_params - .iter() - .any(|(key, _)| key == "ADDRESSBOOK_DIR") - ); - } - - #[test] - fn contacts_read_write_emit_write_clauses() { - let policy = build_seatbelt_extensions(&MacOsSeatbeltProfileExtensions { - macos_contacts: MacOsContactsPermission::ReadWrite, - ..Default::default() - }); - - assert!(policy.policy.contains("(subpath \"/var/folders\")")); - assert!(policy.policy.contains("(subpath \"/private/var/folders\")")); - assert!(policy.policy.contains("com.apple.securityd.xpc")); - } - - #[test] - fn default_extensions_emit_preferences_read_only_policy() { - let policy = build_seatbelt_extensions(&MacOsSeatbeltProfileExtensions::default()); - assert!(policy.policy.contains("(allow user-preference-read)")); - assert!(!policy.policy.contains("(allow user-preference-write)")); - } -} +#[path = "seatbelt_permissions_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/seatbelt_permissions_tests.rs b/codex-rs/core/src/seatbelt_permissions_tests.rs new file mode 100644 index 0000000000..b52ccdfcb2 --- /dev/null +++ b/codex-rs/core/src/seatbelt_permissions_tests.rs @@ -0,0 +1,154 @@ +use super::MacOsAutomationPermission; +use super::MacOsContactsPermission; +use super::MacOsPreferencesPermission; +use super::MacOsSeatbeltProfileExtensions; +use super::build_seatbelt_extensions; + +#[test] +fn preferences_read_only_emits_read_clauses_only() { + let policy = build_seatbelt_extensions(&MacOsSeatbeltProfileExtensions { + macos_preferences: MacOsPreferencesPermission::ReadOnly, + ..Default::default() + }); + assert!(policy.policy.contains("(allow user-preference-read)")); + assert!(!policy.policy.contains("(allow user-preference-write)")); +} + +#[test] +fn preferences_read_write_emits_write_clauses() { + let policy = build_seatbelt_extensions(&MacOsSeatbeltProfileExtensions { + macos_preferences: MacOsPreferencesPermission::ReadWrite, + ..Default::default() + }); + assert!(policy.policy.contains("(allow user-preference-read)")); + assert!(policy.policy.contains("(allow user-preference-write)")); + assert!( + policy.policy.contains( + "(allow ipc-posix-shm-write-create (ipc-posix-name-prefix \"apple.cfprefs.\"))" + ) + ); +} + +#[test] +fn automation_all_emits_unscoped_appleevents() { + let policy = build_seatbelt_extensions(&MacOsSeatbeltProfileExtensions { + macos_automation: MacOsAutomationPermission::All, + ..Default::default() + }); + assert!(policy.policy.contains("(allow appleevent-send)")); + assert!(policy.policy.contains("com.apple.coreservices.appleevents")); +} + +#[test] +fn automation_bundle_ids_are_normalized_and_scoped() { + let policy = build_seatbelt_extensions(&MacOsSeatbeltProfileExtensions { + macos_automation: MacOsAutomationPermission::BundleIds(vec![ + " com.apple.Notes ".to_string(), + "com.apple.Calendar".to_string(), + "bad bundle".to_string(), + "com.apple.Notes".to_string(), + ]), + ..Default::default() + }); + assert!( + policy + .policy + .contains("(appleevent-destination \"com.apple.Calendar\")") + ); + assert!( + policy + .policy + .contains("(appleevent-destination \"com.apple.Notes\")") + ); + assert!(!policy.policy.contains("bad bundle")); + assert!(policy.policy.contains("com.apple.coreservices.appleevents")); +} + +#[test] +fn launch_services_emit_launch_clauses() { + let policy = build_seatbelt_extensions(&MacOsSeatbeltProfileExtensions { + macos_launch_services: true, + ..Default::default() + }); + assert!( + policy + .policy + .contains("com.apple.coreservices.launchservicesd") + ); + assert!(policy.policy.contains("com.apple.lsd.mapdb")); + assert!( + policy + .policy + .contains("com.apple.coreservices.quarantine-resolver") + ); + assert!(policy.policy.contains("com.apple.lsd.modifydb")); + assert!(policy.policy.contains("(allow lsopen)")); +} + +#[test] +fn accessibility_and_calendar_emit_mach_lookups() { + let policy = build_seatbelt_extensions(&MacOsSeatbeltProfileExtensions { + macos_accessibility: true, + macos_calendar: true, + ..Default::default() + }); + assert!(policy.policy.contains("com.apple.axserver")); + assert!(policy.policy.contains("com.apple.CalendarAgent")); +} + +#[test] +fn reminders_emit_calendar_agent_and_remindd_lookups() { + let policy = build_seatbelt_extensions(&MacOsSeatbeltProfileExtensions { + macos_reminders: true, + ..Default::default() + }); + assert!(policy.policy.contains("com.apple.CalendarAgent")); + assert!(policy.policy.contains("com.apple.remindd")); +} + +#[test] +fn contacts_read_only_emit_contacts_read_clauses() { + let policy = build_seatbelt_extensions(&MacOsSeatbeltProfileExtensions { + macos_contacts: MacOsContactsPermission::ReadOnly, + ..Default::default() + }); + + assert!( + policy + .policy + .contains("(subpath \"/System/Library/Address Book Plug-Ins\")") + ); + assert!( + policy + .policy + .contains("(subpath (param \"ADDRESSBOOK_DIR\"))") + ); + assert!(policy.policy.contains("com.apple.contactsd.persistence")); + assert!(policy.policy.contains("com.apple.accountsd.accountmanager")); + assert!(!policy.policy.contains("com.apple.securityd.xpc")); + assert!( + policy + .dir_params + .iter() + .any(|(key, _)| key == "ADDRESSBOOK_DIR") + ); +} + +#[test] +fn contacts_read_write_emit_write_clauses() { + let policy = build_seatbelt_extensions(&MacOsSeatbeltProfileExtensions { + macos_contacts: MacOsContactsPermission::ReadWrite, + ..Default::default() + }); + + assert!(policy.policy.contains("(subpath \"/var/folders\")")); + assert!(policy.policy.contains("(subpath \"/private/var/folders\")")); + assert!(policy.policy.contains("com.apple.securityd.xpc")); +} + +#[test] +fn default_extensions_emit_preferences_read_only_policy() { + let policy = build_seatbelt_extensions(&MacOsSeatbeltProfileExtensions::default()); + assert!(policy.policy.contains("(allow user-preference-read)")); + assert!(!policy.policy.contains("(allow user-preference-write)")); +} diff --git a/codex-rs/core/src/seatbelt_tests.rs b/codex-rs/core/src/seatbelt_tests.rs new file mode 100644 index 0000000000..9ac5eaa7b0 --- /dev/null +++ b/codex-rs/core/src/seatbelt_tests.rs @@ -0,0 +1,1058 @@ +use super::MACOS_SEATBELT_BASE_POLICY; +use super::ProxyPolicyInputs; +use super::UnixDomainSocketPolicy; +use super::create_seatbelt_command_args; +use super::create_seatbelt_command_args_for_policies_with_extensions; +use super::create_seatbelt_command_args_with_extensions; +use super::dynamic_network_policy; +use super::macos_dir_params; +use super::normalize_path_for_sandbox; +use super::unix_socket_dir_params; +use super::unix_socket_policy; +use crate::protocol::ReadOnlyAccess; +use crate::protocol::SandboxPolicy; +use crate::seatbelt::MACOS_PATH_TO_SEATBELT_EXECUTABLE; +use crate::seatbelt_permissions::MacOsAutomationPermission; +use crate::seatbelt_permissions::MacOsContactsPermission; +use crate::seatbelt_permissions::MacOsPreferencesPermission; +use crate::seatbelt_permissions::MacOsSeatbeltProfileExtensions; +use codex_protocol::permissions::FileSystemAccessMode; +use codex_protocol::permissions::FileSystemPath; +use codex_protocol::permissions::FileSystemSandboxEntry; +use codex_protocol::permissions::FileSystemSandboxPolicy; +use codex_protocol::permissions::NetworkSandboxPolicy; +use codex_utils_absolute_path::AbsolutePathBuf; +use pretty_assertions::assert_eq; +use std::fs; +use std::path::Path; +use std::path::PathBuf; +use std::process::Command; +use tempfile::TempDir; + +fn assert_seatbelt_denied(stderr: &[u8], path: &Path) { + let stderr = String::from_utf8_lossy(stderr); + let expected = format!("bash: {}: Operation not permitted\n", path.display()); + assert!( + stderr == expected + || stderr.contains("sandbox-exec: sandbox_apply: Operation not permitted"), + "unexpected stderr: {stderr}" + ); +} + +fn absolute_path(path: &str) -> AbsolutePathBuf { + AbsolutePathBuf::from_absolute_path(Path::new(path)).expect("absolute path") +} + +fn seatbelt_policy_arg(args: &[String]) -> &str { + let policy_index = args + .iter() + .position(|arg| arg == "-p") + .expect("seatbelt args should include -p"); + args.get(policy_index + 1) + .expect("seatbelt args should include policy text") +} + +#[test] +fn base_policy_allows_node_cpu_sysctls() { + assert!( + MACOS_SEATBELT_BASE_POLICY.contains("(sysctl-name \"machdep.cpu.brand_string\")"), + "base policy must allow CPU brand lookup for os.cpus()" + ); + assert!( + MACOS_SEATBELT_BASE_POLICY.contains("(sysctl-name \"hw.model\")"), + "base policy must allow hardware model lookup for os.cpus()" + ); +} + +#[test] +fn create_seatbelt_args_routes_network_through_proxy_ports() { + let policy = dynamic_network_policy( + &SandboxPolicy::new_read_only_policy(), + false, + &ProxyPolicyInputs { + ports: vec![43128, 48081], + has_proxy_config: true, + allow_local_binding: false, + ..ProxyPolicyInputs::default() + }, + ); + + assert!( + policy.contains("(allow network-outbound (remote ip \"localhost:43128\"))"), + "expected HTTP proxy port allow rule in policy:\n{policy}" + ); + assert!( + policy.contains("(allow network-outbound (remote ip \"localhost:48081\"))"), + "expected SOCKS proxy port allow rule in policy:\n{policy}" + ); + assert!( + !policy.contains("\n(allow network-outbound)\n"), + "policy should not include blanket outbound allowance when proxy ports are present:\n{policy}" + ); + assert!( + !policy.contains("(allow network-bind (local ip \"localhost:*\"))"), + "policy should not allow loopback binding unless explicitly enabled:\n{policy}" + ); + assert!( + !policy.contains("(allow network-inbound (local ip \"localhost:*\"))"), + "policy should not allow loopback inbound unless explicitly enabled:\n{policy}" + ); +} + +#[test] +fn explicit_unreadable_paths_are_excluded_from_full_disk_read_and_write_access() { + let unreadable = absolute_path("/tmp/codex-unreadable"); + let file_system_policy = FileSystemSandboxPolicy::restricted(vec![ + FileSystemSandboxEntry { + path: FileSystemPath::Special { + value: crate::protocol::FileSystemSpecialPath::Root, + }, + access: FileSystemAccessMode::Write, + }, + FileSystemSandboxEntry { + path: FileSystemPath::Path { path: unreadable }, + access: FileSystemAccessMode::None, + }, + ]); + + let args = create_seatbelt_command_args_for_policies_with_extensions( + vec!["/bin/true".to_string()], + &file_system_policy, + NetworkSandboxPolicy::Restricted, + Path::new("/"), + false, + None, + None, + ); + + let policy = seatbelt_policy_arg(&args); + assert!( + policy.contains("(require-not (subpath (param \"READABLE_ROOT_0_RO_0\")))"), + "expected read carveout in policy:\n{policy}" + ); + assert!( + policy.contains("(require-not (subpath (param \"WRITABLE_ROOT_0_RO_0\")))"), + "expected write carveout in policy:\n{policy}" + ); + assert!( + args.iter() + .any(|arg| arg == "-DREADABLE_ROOT_0_RO_0=/tmp/codex-unreadable"), + "expected read carveout parameter in args: {args:#?}" + ); + assert!( + args.iter() + .any(|arg| arg == "-DWRITABLE_ROOT_0_RO_0=/tmp/codex-unreadable"), + "expected write carveout parameter in args: {args:#?}" + ); +} + +#[test] +fn explicit_unreadable_paths_are_excluded_from_readable_roots() { + let root = absolute_path("/tmp/codex-readable"); + let unreadable = absolute_path("/tmp/codex-readable/private"); + let file_system_policy = FileSystemSandboxPolicy::restricted(vec![ + FileSystemSandboxEntry { + path: FileSystemPath::Path { path: root }, + access: FileSystemAccessMode::Read, + }, + FileSystemSandboxEntry { + path: FileSystemPath::Path { path: unreadable }, + access: FileSystemAccessMode::None, + }, + ]); + + let args = create_seatbelt_command_args_for_policies_with_extensions( + vec!["/bin/true".to_string()], + &file_system_policy, + NetworkSandboxPolicy::Restricted, + Path::new("/"), + false, + None, + None, + ); + + let policy = seatbelt_policy_arg(&args); + assert!( + policy.contains("(require-not (subpath (param \"READABLE_ROOT_0_RO_0\")))"), + "expected read carveout in policy:\n{policy}" + ); + assert!( + args.iter() + .any(|arg| arg == "-DREADABLE_ROOT_0=/tmp/codex-readable"), + "expected readable root parameter in args: {args:#?}" + ); + assert!( + args.iter() + .any(|arg| arg == "-DREADABLE_ROOT_0_RO_0=/tmp/codex-readable/private"), + "expected read carveout parameter in args: {args:#?}" + ); +} + +#[test] +fn seatbelt_args_include_macos_permission_extensions() { + let cwd = std::env::temp_dir(); + let args = create_seatbelt_command_args_with_extensions( + vec!["echo".to_string(), "ok".to_string()], + &SandboxPolicy::new_read_only_policy(), + cwd.as_path(), + false, + None, + Some(&MacOsSeatbeltProfileExtensions { + macos_preferences: MacOsPreferencesPermission::ReadWrite, + macos_automation: MacOsAutomationPermission::BundleIds(vec![ + "com.apple.Notes".to_string(), + ]), + macos_launch_services: true, + macos_accessibility: true, + macos_calendar: true, + macos_reminders: false, + macos_contacts: MacOsContactsPermission::None, + }), + ); + let policy = &args[1]; + + assert!(policy.contains("(allow user-preference-write)")); + assert!(policy.contains("(appleevent-destination \"com.apple.Notes\")")); + assert!(policy.contains("com.apple.axserver")); + assert!(policy.contains("com.apple.CalendarAgent")); +} + +#[test] +fn bundle_id_automation_keeps_lsopen_denied() { + let tmp = TempDir::new().expect("tempdir"); + let cwd = tmp.path().join("cwd"); + fs::create_dir_all(&cwd).expect("create cwd"); + + let args = create_seatbelt_command_args_with_extensions( + vec![ + "/usr/bin/python3".to_string(), + "-c".to_string(), + r#"import ctypes +import os +import sys +lib = ctypes.CDLL("/usr/lib/libsandbox.1.dylib") +lib.sandbox_check.restype = ctypes.c_int +allowed = lib.sandbox_check(os.getpid(), b"lsopen", 0) == 0 +sys.exit(0 if allowed else 13) +"# + .to_string(), + ], + &SandboxPolicy::new_read_only_policy(), + cwd.as_path(), + false, + None, + Some(&MacOsSeatbeltProfileExtensions { + macos_automation: MacOsAutomationPermission::BundleIds(vec![ + "com.apple.Notes".to_string(), + ]), + ..Default::default() + }), + ); + + let output = Command::new(MACOS_PATH_TO_SEATBELT_EXECUTABLE) + .args(&args) + .current_dir(&cwd) + .output() + .expect("execute seatbelt command"); + + let stderr = String::from_utf8_lossy(&output.stderr); + if stderr.contains("sandbox-exec: sandbox_apply: Operation not permitted") { + return; + } + + assert_eq!( + Some(13), + output.status.code(), + "lsopen should remain denied even with bundle-scoped automation\nstdout: {}\nstderr: {stderr}", + String::from_utf8_lossy(&output.stdout), + ); +} + +#[test] +fn seatbelt_args_without_extension_profile_keep_legacy_preferences_read_access() { + let cwd = std::env::temp_dir(); + let args = create_seatbelt_command_args( + vec!["echo".to_string(), "ok".to_string()], + &SandboxPolicy::new_read_only_policy(), + cwd.as_path(), + false, + None, + ); + let policy = &args[1]; + assert!(policy.contains("(allow user-preference-read)")); + assert!(!policy.contains("(allow user-preference-write)")); +} + +#[test] +fn seatbelt_legacy_workspace_write_nested_readable_root_stays_writable() { + let tmp = TempDir::new().expect("tempdir"); + let cwd = tmp.path().join("workspace"); + fs::create_dir_all(cwd.join("docs")).expect("create docs"); + let docs = AbsolutePathBuf::from_absolute_path(cwd.join("docs")).expect("absolute docs"); + let args = create_seatbelt_command_args( + vec!["/bin/true".to_string()], + &SandboxPolicy::WorkspaceWrite { + writable_roots: vec![], + read_only_access: ReadOnlyAccess::Restricted { + include_platform_defaults: true, + readable_roots: vec![docs.clone()], + }, + network_access: false, + exclude_tmpdir_env_var: true, + exclude_slash_tmp: true, + }, + cwd.as_path(), + false, + None, + ); + + let docs_param = format!("-DWRITABLE_ROOT_0_RO_0={}", docs.as_path().display()); + assert!( + !seatbelt_policy_arg(&args).contains("WRITABLE_ROOT_0_RO_0"), + "legacy workspace-write readable roots under cwd should not become seatbelt carveouts:\n{args:#?}" + ); + assert!( + !args.iter().any(|arg| arg == &docs_param), + "unexpected seatbelt carveout parameter for redundant legacy readable root: {args:#?}" + ); +} + +#[test] +fn seatbelt_args_default_extension_profile_keeps_preferences_read_access() { + let cwd = std::env::temp_dir(); + let args = create_seatbelt_command_args_with_extensions( + vec!["echo".to_string(), "ok".to_string()], + &SandboxPolicy::new_read_only_policy(), + cwd.as_path(), + false, + None, + Some(&MacOsSeatbeltProfileExtensions::default()), + ); + let policy = &args[1]; + assert!(!policy.contains("appleevent-send")); + assert!(!policy.contains("com.apple.axserver")); + assert!(!policy.contains("com.apple.CalendarAgent")); + assert!(policy.contains("(allow user-preference-read)")); + assert!(!policy.contains("user-preference-write")); +} + +#[test] +fn create_seatbelt_args_allows_local_binding_when_explicitly_enabled() { + let policy = dynamic_network_policy( + &SandboxPolicy::new_read_only_policy(), + false, + &ProxyPolicyInputs { + ports: vec![43128], + has_proxy_config: true, + allow_local_binding: true, + ..ProxyPolicyInputs::default() + }, + ); + + assert!( + policy.contains("(allow network-bind (local ip \"localhost:*\"))"), + "policy should allow loopback local binding when explicitly enabled:\n{policy}" + ); + assert!( + policy.contains("(allow network-inbound (local ip \"localhost:*\"))"), + "policy should allow loopback inbound when explicitly enabled:\n{policy}" + ); + assert!( + policy.contains("(allow network-outbound (remote ip \"localhost:*\"))"), + "policy should allow loopback outbound when explicitly enabled:\n{policy}" + ); + assert!( + !policy.contains("\n(allow network-outbound)\n"), + "policy should keep proxy-routed behavior without blanket outbound allowance:\n{policy}" + ); +} + +#[test] +fn dynamic_network_policy_preserves_restricted_policy_when_proxy_config_without_ports() { + let policy = dynamic_network_policy( + &SandboxPolicy::WorkspaceWrite { + writable_roots: vec![], + read_only_access: Default::default(), + network_access: true, + exclude_tmpdir_env_var: false, + exclude_slash_tmp: false, + }, + false, + &ProxyPolicyInputs { + ports: vec![], + has_proxy_config: true, + allow_local_binding: false, + ..ProxyPolicyInputs::default() + }, + ); + + assert!( + policy.contains("(socket-domain AF_SYSTEM)"), + "policy should keep the restricted network profile when proxy config is present without ports:\n{policy}" + ); + assert!( + !policy.contains("\n(allow network-outbound)\n"), + "policy should not include blanket outbound allowance when proxy config is present without ports:\n{policy}" + ); + assert!( + !policy.contains("(allow network-outbound (remote ip \"localhost:"), + "policy should not include proxy port allowance when proxy config is present without ports:\n{policy}" + ); +} + +#[test] +fn dynamic_network_policy_preserves_restricted_policy_for_managed_network_without_proxy_config() { + let policy = dynamic_network_policy( + &SandboxPolicy::WorkspaceWrite { + writable_roots: vec![], + read_only_access: Default::default(), + network_access: true, + exclude_tmpdir_env_var: false, + exclude_slash_tmp: false, + }, + true, + &ProxyPolicyInputs { + ports: vec![], + has_proxy_config: false, + allow_local_binding: false, + ..ProxyPolicyInputs::default() + }, + ); + + assert!( + policy.contains("(socket-domain AF_SYSTEM)"), + "policy should keep the restricted network profile when managed network is active without proxy endpoints:\n{policy}" + ); + assert!( + !policy.contains("\n(allow network-outbound)\n"), + "policy should not include blanket outbound allowance when managed network is active without proxy endpoints:\n{policy}" + ); +} + +#[test] +fn create_seatbelt_args_allowlists_unix_socket_paths() { + let policy = dynamic_network_policy( + &SandboxPolicy::new_read_only_policy(), + false, + &ProxyPolicyInputs { + ports: vec![43128], + has_proxy_config: true, + allow_local_binding: false, + unix_domain_socket_policy: UnixDomainSocketPolicy::Restricted { + allowed: vec![absolute_path("/tmp/example.sock")], + }, + }, + ); + + assert!( + policy.contains("(allow system-socket (socket-domain AF_UNIX))"), + "policy should allow AF_UNIX socket creation for configured unix sockets:\n{policy}" + ); + assert!( + policy.contains( + "(allow network-bind (local unix-socket (subpath (param \"UNIX_SOCKET_PATH_0\"))))" + ), + "policy should allow binding explicitly configured unix sockets:\n{policy}" + ); + assert!( + policy.contains( + "(allow network-outbound (remote unix-socket (subpath (param \"UNIX_SOCKET_PATH_0\"))))" + ), + "policy should allow connecting to explicitly configured unix sockets:\n{policy}" + ); + assert!( + !policy.contains("(allow network* (subpath"), + "policy should no longer use the generic subpath unix-socket rules:\n{policy}" + ); +} + +#[test] +fn unix_socket_policy_non_empty_output_is_newline_terminated() { + let allowlist_policy = unix_socket_policy(&ProxyPolicyInputs { + unix_domain_socket_policy: UnixDomainSocketPolicy::Restricted { + allowed: vec![absolute_path("/tmp/example.sock")], + }, + ..ProxyPolicyInputs::default() + }); + assert!( + allowlist_policy.ends_with('\n'), + "allowlist unix socket policy should end with a newline:\n{allowlist_policy}" + ); + + let allow_all_policy = unix_socket_policy(&ProxyPolicyInputs { + unix_domain_socket_policy: UnixDomainSocketPolicy::AllowAll, + ..ProxyPolicyInputs::default() + }); + assert!( + allow_all_policy.ends_with('\n'), + "allow-all unix socket policy should end with a newline:\n{allow_all_policy}" + ); +} + +#[test] +fn unix_socket_dir_params_use_stable_param_names() { + let params = unix_socket_dir_params(&ProxyPolicyInputs { + unix_domain_socket_policy: UnixDomainSocketPolicy::Restricted { + allowed: vec![ + absolute_path("/tmp/b.sock"), + absolute_path("/tmp/a.sock"), + absolute_path("/tmp/a.sock"), + ], + }, + ..ProxyPolicyInputs::default() + }); + + assert_eq!( + params, + vec![ + ( + "UNIX_SOCKET_PATH_0".to_string(), + PathBuf::from("/tmp/a.sock") + ), + ( + "UNIX_SOCKET_PATH_1".to_string(), + PathBuf::from("/tmp/b.sock") + ), + ] + ); +} + +#[test] +fn normalize_path_for_sandbox_rejects_relative_paths() { + assert_eq!(normalize_path_for_sandbox(Path::new("relative.sock")), None); +} + +#[test] +fn create_seatbelt_args_allows_all_unix_sockets_when_enabled() { + let policy = dynamic_network_policy( + &SandboxPolicy::new_read_only_policy(), + false, + &ProxyPolicyInputs { + ports: vec![43128], + has_proxy_config: true, + allow_local_binding: false, + unix_domain_socket_policy: UnixDomainSocketPolicy::AllowAll, + }, + ); + + assert!( + policy.contains("(allow system-socket (socket-domain AF_UNIX))"), + "policy should allow AF_UNIX socket creation when unix sockets are enabled:\n{policy}" + ); + assert!( + policy.contains("(allow network-bind (local unix-socket))"), + "policy should allow binding unix sockets when enabled:\n{policy}" + ); + assert!( + policy.contains("(allow network-outbound (remote unix-socket))"), + "policy should allow connecting to unix sockets when enabled:\n{policy}" + ); + assert!( + !policy.contains("(allow network* (subpath"), + "policy should no longer use the generic subpath unix-socket rules:\n{policy}" + ); +} + +#[test] +fn create_seatbelt_args_full_network_with_proxy_is_still_proxy_only() { + let policy = dynamic_network_policy( + &SandboxPolicy::WorkspaceWrite { + writable_roots: vec![], + read_only_access: Default::default(), + network_access: true, + exclude_tmpdir_env_var: false, + exclude_slash_tmp: false, + }, + false, + &ProxyPolicyInputs { + ports: vec![43128], + has_proxy_config: true, + allow_local_binding: false, + ..ProxyPolicyInputs::default() + }, + ); + + assert!( + policy.contains("(allow network-outbound (remote ip \"localhost:43128\"))"), + "expected proxy endpoint allow rule in policy:\n{policy}" + ); + assert!( + !policy.contains("\n(allow network-outbound)\n"), + "policy should not include blanket outbound allowance when proxy is configured:\n{policy}" + ); + assert!( + !policy.contains("\n(allow network-inbound)\n"), + "policy should not include blanket inbound allowance when proxy is configured:\n{policy}" + ); +} + +#[test] +fn create_seatbelt_args_with_read_only_git_and_codex_subpaths() { + // Create a temporary workspace with two writable roots: one containing + // top-level .git and .codex directories and one without them. + let tmp = TempDir::new().expect("tempdir"); + let PopulatedTmp { + vulnerable_root, + vulnerable_root_canonical, + dot_git_canonical, + dot_codex_canonical, + empty_root, + empty_root_canonical, + } = populate_tmpdir(tmp.path()); + let cwd = tmp.path().join("cwd"); + fs::create_dir_all(&cwd).expect("create cwd"); + + // Build a policy that only includes the two test roots as writable and + // does not automatically include defaults TMPDIR or /tmp. + let policy = SandboxPolicy::WorkspaceWrite { + writable_roots: vec![vulnerable_root, empty_root] + .into_iter() + .map(|p| p.try_into().unwrap()) + .collect(), + read_only_access: Default::default(), + network_access: false, + exclude_tmpdir_env_var: true, + exclude_slash_tmp: true, + }; + + // Create the Seatbelt command to wrap a shell command that tries to + // write to .codex/config.toml in the vulnerable root. + let shell_command: Vec = [ + "bash", + "-c", + "echo 'sandbox_mode = \"danger-full-access\"' > \"$1\"", + "bash", + dot_codex_canonical + .join("config.toml") + .to_string_lossy() + .as_ref(), + ] + .iter() + .map(std::string::ToString::to_string) + .collect(); + let args = create_seatbelt_command_args(shell_command.clone(), &policy, &cwd, false, None); + + // Build the expected policy text using a raw string for readability. + // Note that the policy includes: + // - the base policy, + // - read-only access to the filesystem, + // - write access to WRITABLE_ROOT_0 (but not its .git or .codex), WRITABLE_ROOT_1, and cwd as WRITABLE_ROOT_2. + let expected_policy = format!( + r#"{MACOS_SEATBELT_BASE_POLICY} +; allow read-only file operations +(allow file-read*) +(allow file-write* +(subpath (param "WRITABLE_ROOT_0")) (require-all (subpath (param "WRITABLE_ROOT_1")) (require-not (subpath (param "WRITABLE_ROOT_1_RO_0"))) (require-not (subpath (param "WRITABLE_ROOT_1_RO_1"))) ) (subpath (param "WRITABLE_ROOT_2")) +) + +; macOS permission profile extensions +(allow ipc-posix-shm-read* (ipc-posix-name-prefix "apple.cfprefs.")) +(allow mach-lookup + (global-name "com.apple.cfprefsd.daemon") + (global-name "com.apple.cfprefsd.agent") + (local-name "com.apple.cfprefsd.agent")) +(allow user-preference-read) +"#, + ); + + assert_eq!(seatbelt_policy_arg(&args), expected_policy); + + let expected_definitions = [ + format!( + "-DWRITABLE_ROOT_0={}", + cwd.canonicalize() + .expect("canonicalize cwd") + .to_string_lossy() + ), + format!( + "-DWRITABLE_ROOT_1={}", + vulnerable_root_canonical.to_string_lossy() + ), + format!( + "-DWRITABLE_ROOT_1_RO_0={}", + dot_git_canonical.to_string_lossy() + ), + format!( + "-DWRITABLE_ROOT_1_RO_1={}", + dot_codex_canonical.to_string_lossy() + ), + format!( + "-DWRITABLE_ROOT_2={}", + empty_root_canonical.to_string_lossy() + ), + ]; + for expected_definition in expected_definitions { + assert!( + args.contains(&expected_definition), + "expected definition arg `{expected_definition}` in {args:#?}" + ); + } + for (key, value) in macos_dir_params() { + let expected_definition = format!("-D{key}={}", value.to_string_lossy()); + assert!( + args.contains(&expected_definition), + "expected definition arg `{expected_definition}` in {args:#?}" + ); + } + + let command_index = args + .iter() + .position(|arg| arg == "--") + .expect("seatbelt args should include command separator"); + assert_eq!(args[command_index + 1..], shell_command); + + // Verify that .codex/config.toml cannot be modified under the generated + // Seatbelt policy. + let config_toml = dot_codex_canonical.join("config.toml"); + let output = Command::new(MACOS_PATH_TO_SEATBELT_EXECUTABLE) + .args(&args) + .current_dir(&cwd) + .output() + .expect("execute seatbelt command"); + assert_eq!( + "sandbox_mode = \"read-only\"\n", + String::from_utf8_lossy(&fs::read(&config_toml).expect("read config.toml")), + "config.toml should contain its original contents because it should not have been modified" + ); + assert!( + !output.status.success(), + "command to write {} should fail under seatbelt", + &config_toml.display() + ); + assert_seatbelt_denied(&output.stderr, &config_toml); + + // Create a similar Seatbelt command that tries to write to a file in + // the .git folder, which should also be blocked. + let pre_commit_hook = dot_git_canonical.join("hooks").join("pre-commit"); + let shell_command_git: Vec = [ + "bash", + "-c", + "echo 'pwned!' > \"$1\"", + "bash", + pre_commit_hook.to_string_lossy().as_ref(), + ] + .iter() + .map(std::string::ToString::to_string) + .collect(); + let write_hooks_file_args = + create_seatbelt_command_args(shell_command_git, &policy, &cwd, false, None); + let output = Command::new(MACOS_PATH_TO_SEATBELT_EXECUTABLE) + .args(&write_hooks_file_args) + .current_dir(&cwd) + .output() + .expect("execute seatbelt command"); + assert!( + !fs::exists(&pre_commit_hook).expect("exists pre-commit hook"), + "{} should not exist because it should not have been created", + pre_commit_hook.display() + ); + assert!( + !output.status.success(), + "command to write {} should fail under seatbelt", + &pre_commit_hook.display() + ); + assert_seatbelt_denied(&output.stderr, &pre_commit_hook); + + // Verify that writing a file to the folder containing .git and .codex is allowed. + let allowed_file = vulnerable_root_canonical.join("allowed.txt"); + let shell_command_allowed: Vec = [ + "bash", + "-c", + "echo 'this is allowed' > \"$1\"", + "bash", + allowed_file.to_string_lossy().as_ref(), + ] + .iter() + .map(std::string::ToString::to_string) + .collect(); + let write_allowed_file_args = + create_seatbelt_command_args(shell_command_allowed, &policy, &cwd, false, None); + let output = Command::new(MACOS_PATH_TO_SEATBELT_EXECUTABLE) + .args(&write_allowed_file_args) + .current_dir(&cwd) + .output() + .expect("execute seatbelt command"); + let stderr = String::from_utf8_lossy(&output.stderr); + if !output.status.success() + && stderr.contains("sandbox-exec: sandbox_apply: Operation not permitted") + { + return; + } + assert!( + output.status.success(), + "command to write {} should succeed under seatbelt", + &allowed_file.display() + ); + assert_eq!( + "this is allowed\n", + String::from_utf8_lossy(&fs::read(&allowed_file).expect("read allowed.txt")), + "{} should contain the written text", + allowed_file.display() + ); +} + +#[test] +fn create_seatbelt_args_with_read_only_git_pointer_file() { + let tmp = TempDir::new().expect("tempdir"); + let worktree_root = tmp.path().join("worktree_root"); + fs::create_dir_all(&worktree_root).expect("create worktree_root"); + let gitdir = worktree_root.join("actual-gitdir"); + fs::create_dir_all(&gitdir).expect("create gitdir"); + let gitdir_config = gitdir.join("config"); + let gitdir_config_contents = "[core]\n"; + fs::write(&gitdir_config, gitdir_config_contents).expect("write gitdir config"); + + let dot_git = worktree_root.join(".git"); + let dot_git_contents = format!("gitdir: {}\n", gitdir.to_string_lossy()); + fs::write(&dot_git, &dot_git_contents).expect("write .git pointer"); + + let cwd = tmp.path().join("cwd"); + fs::create_dir_all(&cwd).expect("create cwd"); + + let policy = SandboxPolicy::WorkspaceWrite { + writable_roots: vec![worktree_root.try_into().expect("worktree_root is absolute")], + read_only_access: Default::default(), + network_access: false, + exclude_tmpdir_env_var: true, + exclude_slash_tmp: true, + }; + + let shell_command: Vec = [ + "bash", + "-c", + "echo 'pwned!' > \"$1\"", + "bash", + dot_git.to_string_lossy().as_ref(), + ] + .iter() + .map(std::string::ToString::to_string) + .collect(); + let args = create_seatbelt_command_args(shell_command, &policy, &cwd, false, None); + + let output = Command::new(MACOS_PATH_TO_SEATBELT_EXECUTABLE) + .args(&args) + .current_dir(&cwd) + .output() + .expect("execute seatbelt command"); + + assert_eq!( + dot_git_contents, + String::from_utf8_lossy(&fs::read(&dot_git).expect("read .git pointer")), + ".git pointer file should not be modified under seatbelt" + ); + assert!( + !output.status.success(), + "command to write {} should fail under seatbelt", + dot_git.display() + ); + assert_seatbelt_denied(&output.stderr, &dot_git); + + let shell_command_gitdir: Vec = [ + "bash", + "-c", + "echo 'pwned!' > \"$1\"", + "bash", + gitdir_config.to_string_lossy().as_ref(), + ] + .iter() + .map(std::string::ToString::to_string) + .collect(); + let gitdir_args = + create_seatbelt_command_args(shell_command_gitdir, &policy, &cwd, false, None); + let output = Command::new(MACOS_PATH_TO_SEATBELT_EXECUTABLE) + .args(&gitdir_args) + .current_dir(&cwd) + .output() + .expect("execute seatbelt command"); + + assert_eq!( + gitdir_config_contents, + String::from_utf8_lossy(&fs::read(&gitdir_config).expect("read gitdir config")), + "gitdir config should contain its original contents because it should not have been modified" + ); + assert!( + !output.status.success(), + "command to write {} should fail under seatbelt", + gitdir_config.display() + ); + assert_seatbelt_denied(&output.stderr, &gitdir_config); +} + +#[test] +fn create_seatbelt_args_for_cwd_as_git_repo() { + // Create a temporary workspace with two writable roots: one containing + // top-level .git and .codex directories and one without them. + let tmp = TempDir::new().expect("tempdir"); + let PopulatedTmp { + vulnerable_root, + vulnerable_root_canonical, + dot_git_canonical, + dot_codex_canonical, + .. + } = populate_tmpdir(tmp.path()); + + // Build a policy that does not specify any writable_roots, but does + // use the default ones (cwd and TMPDIR) and verifies the `.git` and + // `.codex` checks are done properly for cwd. + let policy = SandboxPolicy::WorkspaceWrite { + writable_roots: vec![], + read_only_access: Default::default(), + network_access: false, + exclude_tmpdir_env_var: false, + exclude_slash_tmp: false, + }; + + let shell_command: Vec = [ + "bash", + "-c", + "echo 'sandbox_mode = \"danger-full-access\"' > \"$1\"", + "bash", + dot_codex_canonical + .join("config.toml") + .to_string_lossy() + .as_ref(), + ] + .iter() + .map(std::string::ToString::to_string) + .collect(); + let args = create_seatbelt_command_args( + shell_command.clone(), + &policy, + vulnerable_root.as_path(), + false, + None, + ); + + let tmpdir_env_var = std::env::var("TMPDIR") + .ok() + .map(PathBuf::from) + .and_then(|p| p.canonicalize().ok()) + .map(|p| p.to_string_lossy().to_string()); + + let tempdir_policy_entry = if tmpdir_env_var.is_some() { + r#" (subpath (param "WRITABLE_ROOT_2"))"# + } else { + "" + }; + + // Build the expected policy text using a raw string for readability. + // Note that the policy includes: + // - the base policy, + // - read-only access to the filesystem, + // - write access to WRITABLE_ROOT_0 (but not its .git or .codex), WRITABLE_ROOT_1, and cwd as WRITABLE_ROOT_2. + let expected_policy = format!( + r#"{MACOS_SEATBELT_BASE_POLICY} +; allow read-only file operations +(allow file-read*) +(allow file-write* +(require-all (subpath (param "WRITABLE_ROOT_0")) (require-not (subpath (param "WRITABLE_ROOT_0_RO_0"))) (require-not (subpath (param "WRITABLE_ROOT_0_RO_1"))) ) (subpath (param "WRITABLE_ROOT_1")){tempdir_policy_entry} +) + +; macOS permission profile extensions +(allow ipc-posix-shm-read* (ipc-posix-name-prefix "apple.cfprefs.")) +(allow mach-lookup + (global-name "com.apple.cfprefsd.daemon") + (global-name "com.apple.cfprefsd.agent") + (local-name "com.apple.cfprefsd.agent")) +(allow user-preference-read) +"#, + ); + + let mut expected_args = vec![ + "-p".to_string(), + expected_policy, + format!( + "-DWRITABLE_ROOT_0={}", + vulnerable_root_canonical.to_string_lossy() + ), + format!( + "-DWRITABLE_ROOT_0_RO_0={}", + dot_git_canonical.to_string_lossy() + ), + format!( + "-DWRITABLE_ROOT_0_RO_1={}", + dot_codex_canonical.to_string_lossy() + ), + format!( + "-DWRITABLE_ROOT_1={}", + PathBuf::from("/tmp") + .canonicalize() + .expect("canonicalize /tmp") + .to_string_lossy() + ), + ]; + + if let Some(p) = tmpdir_env_var { + expected_args.push(format!("-DWRITABLE_ROOT_2={p}")); + } + + expected_args.extend( + macos_dir_params() + .into_iter() + .map(|(key, value)| format!("-D{key}={value}", value = value.to_string_lossy())), + ); + + expected_args.push("--".to_string()); + expected_args.extend(shell_command); + + assert_eq!(expected_args, args); +} + +struct PopulatedTmp { + /// Path containing a .git and .codex subfolder. + /// For the purposes of this test, we consider this a "vulnerable" root + /// because a bad actor could write to .git/hooks/pre-commit so an + /// unsuspecting user would run code as privileged the next time they + /// ran `git commit` themselves, or modified .codex/config.toml to + /// contain `sandbox_mode = "danger-full-access"` so the agent would + /// have full privileges the next time it ran in that repo. + vulnerable_root: PathBuf, + vulnerable_root_canonical: PathBuf, + dot_git_canonical: PathBuf, + dot_codex_canonical: PathBuf, + + /// Path without .git or .codex subfolders. + empty_root: PathBuf, + /// Canonicalized version of `empty_root`. + empty_root_canonical: PathBuf, +} + +fn populate_tmpdir(tmp: &Path) -> PopulatedTmp { + let vulnerable_root = tmp.join("vulnerable_root"); + fs::create_dir_all(&vulnerable_root).expect("create vulnerable_root"); + + // TODO(mbolin): Should also support the case where `.git` is a file + // with a gitdir: ... line. + Command::new("git") + .arg("init") + .arg(".") + .current_dir(&vulnerable_root) + .output() + .expect("git init ."); + + fs::create_dir_all(vulnerable_root.join(".codex")).expect("create .codex"); + fs::write( + vulnerable_root.join(".codex").join("config.toml"), + "sandbox_mode = \"read-only\"\n", + ) + .expect("write .codex/config.toml"); + + let empty_root = tmp.join("empty_root"); + fs::create_dir_all(&empty_root).expect("create empty_root"); + + // Ensure we have canonical paths for -D parameter matching. + let vulnerable_root_canonical = vulnerable_root + .canonicalize() + .expect("canonicalize vulnerable_root"); + let dot_git_canonical = vulnerable_root_canonical.join(".git"); + let dot_codex_canonical = vulnerable_root_canonical.join(".codex"); + let empty_root_canonical = empty_root.canonicalize().expect("canonicalize empty_root"); + PopulatedTmp { + vulnerable_root, + vulnerable_root_canonical, + dot_git_canonical, + dot_codex_canonical, + empty_root, + empty_root_canonical, + } +} diff --git a/codex-rs/core/src/shell.rs b/codex-rs/core/src/shell.rs index 4cd728992e..dba595c1ec 100644 --- a/codex-rs/core/src/shell.rs +++ b/codex-rs/core/src/shell.rs @@ -381,173 +381,5 @@ mod detect_shell_type_tests { #[cfg(test)] #[cfg(unix)] -mod tests { - use super::*; - use std::path::PathBuf; - use std::process::Command; - - #[test] - #[cfg(target_os = "macos")] - fn detects_zsh() { - let zsh_shell = get_shell(ShellType::Zsh, None).unwrap(); - - let shell_path = zsh_shell.shell_path; - - assert_eq!(shell_path, std::path::Path::new("/bin/zsh")); - } - - #[test] - #[cfg(target_os = "macos")] - fn fish_fallback_to_zsh() { - let zsh_shell = default_user_shell_from_path(Some(PathBuf::from("/bin/fish"))); - - let shell_path = zsh_shell.shell_path; - - assert_eq!(shell_path, std::path::Path::new("/bin/zsh")); - } - - #[test] - fn detects_bash() { - let bash_shell = get_shell(ShellType::Bash, None).unwrap(); - let shell_path = bash_shell.shell_path; - - assert!( - shell_path.file_name().and_then(|name| name.to_str()) == Some("bash"), - "shell path: {shell_path:?}", - ); - } - - #[test] - fn detects_sh() { - let sh_shell = get_shell(ShellType::Sh, None).unwrap(); - let shell_path = sh_shell.shell_path; - assert!( - shell_path.file_name().and_then(|name| name.to_str()) == Some("sh"), - "shell path: {shell_path:?}", - ); - } - - #[test] - fn can_run_on_shell_test() { - let cmd = "echo \"Works\""; - if cfg!(windows) { - assert!(shell_works( - get_shell(ShellType::PowerShell, None), - "Out-String 'Works'", - true, - )); - assert!(shell_works(get_shell(ShellType::Cmd, None), cmd, true,)); - assert!(shell_works(Some(ultimate_fallback_shell()), cmd, true)); - } else { - assert!(shell_works(Some(ultimate_fallback_shell()), cmd, true)); - assert!(shell_works(get_shell(ShellType::Zsh, None), cmd, false)); - assert!(shell_works(get_shell(ShellType::Bash, None), cmd, true)); - assert!(shell_works(get_shell(ShellType::Sh, None), cmd, true)); - } - } - - fn shell_works(shell: Option, command: &str, required: bool) -> bool { - if let Some(shell) = shell { - let args = shell.derive_exec_args(command, false); - let output = Command::new(args[0].clone()) - .args(&args[1..]) - .output() - .unwrap(); - assert!(output.status.success()); - assert!(String::from_utf8_lossy(&output.stdout).contains("Works")); - true - } else { - !required - } - } - - #[test] - fn derive_exec_args() { - let test_bash_shell = Shell { - shell_type: ShellType::Bash, - shell_path: PathBuf::from("/bin/bash"), - shell_snapshot: empty_shell_snapshot_receiver(), - }; - assert_eq!( - test_bash_shell.derive_exec_args("echo hello", false), - vec!["/bin/bash", "-c", "echo hello"] - ); - assert_eq!( - test_bash_shell.derive_exec_args("echo hello", true), - vec!["/bin/bash", "-lc", "echo hello"] - ); - - let test_zsh_shell = Shell { - shell_type: ShellType::Zsh, - shell_path: PathBuf::from("/bin/zsh"), - shell_snapshot: empty_shell_snapshot_receiver(), - }; - assert_eq!( - test_zsh_shell.derive_exec_args("echo hello", false), - vec!["/bin/zsh", "-c", "echo hello"] - ); - assert_eq!( - test_zsh_shell.derive_exec_args("echo hello", true), - vec!["/bin/zsh", "-lc", "echo hello"] - ); - - let test_powershell_shell = Shell { - shell_type: ShellType::PowerShell, - shell_path: PathBuf::from("pwsh.exe"), - shell_snapshot: empty_shell_snapshot_receiver(), - }; - assert_eq!( - test_powershell_shell.derive_exec_args("echo hello", false), - vec!["pwsh.exe", "-NoProfile", "-Command", "echo hello"] - ); - assert_eq!( - test_powershell_shell.derive_exec_args("echo hello", true), - vec!["pwsh.exe", "-Command", "echo hello"] - ); - } - - #[tokio::test] - async fn test_current_shell_detects_zsh() { - let shell = Command::new("sh") - .arg("-c") - .arg("echo $SHELL") - .output() - .unwrap(); - - let shell_path = String::from_utf8_lossy(&shell.stdout).trim().to_string(); - if shell_path.ends_with("/zsh") { - assert_eq!( - default_user_shell(), - Shell { - shell_type: ShellType::Zsh, - shell_path: PathBuf::from(shell_path), - shell_snapshot: empty_shell_snapshot_receiver(), - } - ); - } - } - - #[tokio::test] - async fn detects_powershell_as_default() { - if !cfg!(windows) { - return; - } - - let powershell_shell = default_user_shell(); - let shell_path = powershell_shell.shell_path; - - assert!(shell_path.ends_with("pwsh.exe") || shell_path.ends_with("powershell.exe")); - } - - #[test] - fn finds_powershell() { - if !cfg!(windows) { - return; - } - - let powershell_shell = get_shell(ShellType::PowerShell, None).unwrap(); - let shell_path = powershell_shell.shell_path; - - assert!(shell_path.ends_with("pwsh.exe") || shell_path.ends_with("powershell.exe")); - } -} +#[path = "shell_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/shell_snapshot.rs b/codex-rs/core/src/shell_snapshot.rs index c10a332245..2c6c9b5295 100644 --- a/codex-rs/core/src/shell_snapshot.rs +++ b/codex-rs/core/src/shell_snapshot.rs @@ -544,425 +544,5 @@ async fn remove_snapshot_file(path: &Path) { } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - #[cfg(unix)] - use std::os::unix::ffi::OsStrExt; - #[cfg(unix)] - use std::process::Command; - #[cfg(target_os = "linux")] - use std::process::Command as StdCommand; - - use tempfile::tempdir; - - #[cfg(unix)] - struct BlockingStdinPipe { - original: i32, - write_end: i32, - } - - #[cfg(unix)] - impl BlockingStdinPipe { - fn install() -> Result { - let mut fds = [0i32; 2]; - if unsafe { libc::pipe(fds.as_mut_ptr()) } == -1 { - return Err(std::io::Error::last_os_error()).context("create stdin pipe"); - } - - let original = unsafe { libc::dup(libc::STDIN_FILENO) }; - if original == -1 { - let err = std::io::Error::last_os_error(); - unsafe { - libc::close(fds[0]); - libc::close(fds[1]); - } - return Err(err).context("dup stdin"); - } - - if unsafe { libc::dup2(fds[0], libc::STDIN_FILENO) } == -1 { - let err = std::io::Error::last_os_error(); - unsafe { - libc::close(fds[0]); - libc::close(fds[1]); - libc::close(original); - } - return Err(err).context("replace stdin"); - } - - unsafe { - libc::close(fds[0]); - } - - Ok(Self { - original, - write_end: fds[1], - }) - } - } - - #[cfg(unix)] - impl Drop for BlockingStdinPipe { - fn drop(&mut self) { - unsafe { - libc::dup2(self.original, libc::STDIN_FILENO); - libc::close(self.original); - libc::close(self.write_end); - } - } - } - - #[cfg(not(target_os = "windows"))] - fn assert_posix_snapshot_sections(snapshot: &str) { - assert!(snapshot.contains("# Snapshot file")); - assert!(snapshot.contains("aliases ")); - assert!(snapshot.contains("exports ")); - assert!( - snapshot.contains("PATH"), - "snapshot should capture a PATH export" - ); - assert!(snapshot.contains("setopts ")); - } - - async fn get_snapshot(shell_type: ShellType) -> Result { - let dir = tempdir()?; - let path = dir.path().join("snapshot.sh"); - write_shell_snapshot(shell_type, &path, dir.path()).await?; - let content = fs::read_to_string(&path).await?; - Ok(content) - } - - #[test] - fn strip_snapshot_preamble_removes_leading_output() { - let snapshot = "noise\n# Snapshot file\nexport PATH=/bin\n"; - let cleaned = strip_snapshot_preamble(snapshot).expect("snapshot marker exists"); - assert_eq!(cleaned, "# Snapshot file\nexport PATH=/bin\n"); - } - - #[test] - fn strip_snapshot_preamble_requires_marker() { - let result = strip_snapshot_preamble("missing header"); - assert!(result.is_err()); - } - - #[cfg(unix)] - #[test] - fn bash_snapshot_filters_invalid_exports() -> Result<()> { - let output = Command::new("/bin/bash") - .arg("-c") - .arg(bash_snapshot_script()) - .env("BASH_ENV", "/dev/null") - .env("VALID_NAME", "ok") - .env("PWD", "/tmp/stale") - .env("NEXTEST_BIN_EXE_codex-write-config-schema", "/path/to/bin") - .env("BAD-NAME", "broken") - .output()?; - - assert!(output.status.success()); - - let stdout = String::from_utf8_lossy(&output.stdout); - assert!(stdout.contains("VALID_NAME")); - assert!(!stdout.contains("PWD=/tmp/stale")); - assert!(!stdout.contains("NEXTEST_BIN_EXE_codex-write-config-schema")); - assert!(!stdout.contains("BAD-NAME")); - - Ok(()) - } - - #[cfg(unix)] - #[test] - fn bash_snapshot_preserves_multiline_exports() -> Result<()> { - let multiline_cert = "-----BEGIN CERTIFICATE-----\nabc\n-----END CERTIFICATE-----"; - let output = Command::new("/bin/bash") - .arg("-c") - .arg(bash_snapshot_script()) - .env("BASH_ENV", "/dev/null") - .env("MULTILINE_CERT", multiline_cert) - .output()?; - - assert!(output.status.success()); - - let stdout = String::from_utf8_lossy(&output.stdout); - assert!( - stdout.contains("MULTILINE_CERT=") || stdout.contains("MULTILINE_CERT"), - "snapshot should include the multiline export name" - ); - - let dir = tempdir()?; - let snapshot_path = dir.path().join("snapshot.sh"); - std::fs::write(&snapshot_path, stdout.as_bytes())?; - - let validate = Command::new("/bin/bash") - .arg("-c") - .arg("set -e; . \"$1\"") - .arg("bash") - .arg(&snapshot_path) - .env("BASH_ENV", "/dev/null") - .output()?; - - assert!( - validate.status.success(), - "snapshot validation failed: {}", - String::from_utf8_lossy(&validate.stderr) - ); - - Ok(()) - } - - #[cfg(unix)] - #[tokio::test] - async fn try_new_creates_and_deletes_snapshot_file() -> Result<()> { - let dir = tempdir()?; - let shell = Shell { - shell_type: ShellType::Bash, - shell_path: PathBuf::from("/bin/bash"), - shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), - }; - - let snapshot = ShellSnapshot::try_new(dir.path(), ThreadId::new(), dir.path(), &shell) - .await - .expect("snapshot should be created"); - let path = snapshot.path.clone(); - assert!(path.exists()); - assert_eq!(snapshot.cwd, dir.path().to_path_buf()); - - drop(snapshot); - - assert!(!path.exists()); - - Ok(()) - } - - #[cfg(unix)] - #[tokio::test] - async fn snapshot_shell_does_not_inherit_stdin() -> Result<()> { - let _stdin_guard = BlockingStdinPipe::install()?; - - let dir = tempdir()?; - let home = dir.path(); - let read_status_path = home.join("stdin-read-status"); - let read_status_display = read_status_path.display(); - // Persist the startup `read` exit status so the test can assert whether - // bash saw EOF on stdin after the snapshot process exits. - let bashrc = - format!("read -t 1 -r ignored\nprintf '%s' \"$?\" > \"{read_status_display}\"\n"); - fs::write(home.join(".bashrc"), bashrc).await?; - - let shell = Shell { - shell_type: ShellType::Bash, - shell_path: PathBuf::from("/bin/bash"), - shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), - }; - - let home_display = home.display(); - let script = format!( - "HOME=\"{home_display}\"; export HOME; {}", - bash_snapshot_script() - ); - let output = run_script_with_timeout(&shell, &script, Duration::from_secs(2), true, home) - .await - .context("run snapshot command")?; - let read_status = fs::read_to_string(&read_status_path) - .await - .context("read stdin probe status")?; - - assert_eq!( - read_status, "1", - "expected shell startup read to see EOF on stdin; status={read_status:?}" - ); - - assert!( - output.contains("# Snapshot file"), - "expected snapshot marker in output; output={output:?}" - ); - - Ok(()) - } - - #[cfg(target_os = "linux")] - #[tokio::test] - async fn timed_out_snapshot_shell_is_terminated() -> Result<()> { - use std::process::Stdio; - use tokio::time::Duration as TokioDuration; - use tokio::time::Instant; - use tokio::time::sleep; - - let dir = tempdir()?; - let pid_path = dir.path().join("pid"); - let script = format!("echo $$ > \"{}\"; sleep 30", pid_path.display()); - - let shell = Shell { - shell_type: ShellType::Sh, - shell_path: PathBuf::from("/bin/sh"), - shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), - }; - - let err = - run_script_with_timeout(&shell, &script, Duration::from_secs(1), true, dir.path()) - .await - .expect_err("snapshot shell should time out"); - assert!( - err.to_string().contains("timed out"), - "expected timeout error, got {err:?}" - ); - - let pid = fs::read_to_string(&pid_path) - .await - .expect("snapshot shell writes its pid before timing out") - .trim() - .parse::()?; - - let deadline = Instant::now() + TokioDuration::from_secs(1); - loop { - let kill_status = StdCommand::new("kill") - .arg("-0") - .arg(pid.to_string()) - .stderr(Stdio::null()) - .stdout(Stdio::null()) - .status()?; - if !kill_status.success() { - break; - } - if Instant::now() >= deadline { - panic!("timed out snapshot shell is still alive after grace period"); - } - sleep(TokioDuration::from_millis(50)).await; - } - - Ok(()) - } - - #[cfg(target_os = "macos")] - #[tokio::test] - async fn macos_zsh_snapshot_includes_sections() -> Result<()> { - let snapshot = get_snapshot(ShellType::Zsh).await?; - assert_posix_snapshot_sections(&snapshot); - Ok(()) - } - - #[cfg(target_os = "linux")] - #[tokio::test] - async fn linux_bash_snapshot_includes_sections() -> Result<()> { - let snapshot = get_snapshot(ShellType::Bash).await?; - assert_posix_snapshot_sections(&snapshot); - Ok(()) - } - - #[cfg(target_os = "linux")] - #[tokio::test] - async fn linux_sh_snapshot_includes_sections() -> Result<()> { - let snapshot = get_snapshot(ShellType::Sh).await?; - assert_posix_snapshot_sections(&snapshot); - Ok(()) - } - - #[cfg(target_os = "windows")] - #[ignore] - #[tokio::test] - async fn windows_powershell_snapshot_includes_sections() -> Result<()> { - let snapshot = get_snapshot(ShellType::PowerShell).await?; - assert!(snapshot.contains("# Snapshot file")); - assert!(snapshot.contains("aliases ")); - assert!(snapshot.contains("exports ")); - Ok(()) - } - - async fn write_rollout_stub(codex_home: &Path, session_id: ThreadId) -> Result { - let dir = codex_home - .join("sessions") - .join("2025") - .join("01") - .join("01"); - fs::create_dir_all(&dir).await?; - let path = dir.join(format!("rollout-2025-01-01T00-00-00-{session_id}.jsonl")); - fs::write(&path, "").await?; - Ok(path) - } - - #[tokio::test] - async fn cleanup_stale_snapshots_removes_orphans_and_keeps_live() -> Result<()> { - let dir = tempdir()?; - let codex_home = dir.path(); - let snapshot_dir = codex_home.join(SNAPSHOT_DIR); - fs::create_dir_all(&snapshot_dir).await?; - - let live_session = ThreadId::new(); - let orphan_session = ThreadId::new(); - let live_snapshot = snapshot_dir.join(format!("{live_session}.sh")); - let orphan_snapshot = snapshot_dir.join(format!("{orphan_session}.sh")); - let invalid_snapshot = snapshot_dir.join("not-a-snapshot.txt"); - - write_rollout_stub(codex_home, live_session).await?; - fs::write(&live_snapshot, "live").await?; - fs::write(&orphan_snapshot, "orphan").await?; - fs::write(&invalid_snapshot, "invalid").await?; - - cleanup_stale_snapshots(codex_home, ThreadId::new()).await?; - - assert_eq!(live_snapshot.exists(), true); - assert_eq!(orphan_snapshot.exists(), false); - assert_eq!(invalid_snapshot.exists(), false); - Ok(()) - } - - #[cfg(unix)] - #[tokio::test] - async fn cleanup_stale_snapshots_removes_stale_rollouts() -> Result<()> { - let dir = tempdir()?; - let codex_home = dir.path(); - let snapshot_dir = codex_home.join(SNAPSHOT_DIR); - fs::create_dir_all(&snapshot_dir).await?; - - let stale_session = ThreadId::new(); - let stale_snapshot = snapshot_dir.join(format!("{stale_session}.sh")); - let rollout_path = write_rollout_stub(codex_home, stale_session).await?; - fs::write(&stale_snapshot, "stale").await?; - - set_file_mtime(&rollout_path, SNAPSHOT_RETENTION + Duration::from_secs(60))?; - - cleanup_stale_snapshots(codex_home, ThreadId::new()).await?; - - assert_eq!(stale_snapshot.exists(), false); - Ok(()) - } - - #[cfg(unix)] - #[tokio::test] - async fn cleanup_stale_snapshots_skips_active_session() -> Result<()> { - let dir = tempdir()?; - let codex_home = dir.path(); - let snapshot_dir = codex_home.join(SNAPSHOT_DIR); - fs::create_dir_all(&snapshot_dir).await?; - - let active_session = ThreadId::new(); - let active_snapshot = snapshot_dir.join(format!("{active_session}.sh")); - let rollout_path = write_rollout_stub(codex_home, active_session).await?; - fs::write(&active_snapshot, "active").await?; - - set_file_mtime(&rollout_path, SNAPSHOT_RETENTION + Duration::from_secs(60))?; - - cleanup_stale_snapshots(codex_home, active_session).await?; - - assert_eq!(active_snapshot.exists(), true); - Ok(()) - } - - #[cfg(unix)] - fn set_file_mtime(path: &Path, age: Duration) -> Result<()> { - let now = SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH)? - .as_secs() - .saturating_sub(age.as_secs()); - let tv_sec = now - .try_into() - .map_err(|_| anyhow!("Snapshot mtime is out of range for libc::timespec"))?; - let ts = libc::timespec { tv_sec, tv_nsec: 0 }; - let times = [ts, ts]; - let c_path = std::ffi::CString::new(path.as_os_str().as_bytes())?; - let result = unsafe { libc::utimensat(libc::AT_FDCWD, c_path.as_ptr(), times.as_ptr(), 0) }; - if result != 0 { - return Err(std::io::Error::last_os_error().into()); - } - Ok(()) - } -} +#[path = "shell_snapshot_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/shell_snapshot_tests.rs b/codex-rs/core/src/shell_snapshot_tests.rs new file mode 100644 index 0000000000..558a40e2e2 --- /dev/null +++ b/codex-rs/core/src/shell_snapshot_tests.rs @@ -0,0 +1,418 @@ +use super::*; +use pretty_assertions::assert_eq; +#[cfg(unix)] +use std::os::unix::ffi::OsStrExt; +#[cfg(unix)] +use std::process::Command; +#[cfg(target_os = "linux")] +use std::process::Command as StdCommand; + +use tempfile::tempdir; + +#[cfg(unix)] +struct BlockingStdinPipe { + original: i32, + write_end: i32, +} + +#[cfg(unix)] +impl BlockingStdinPipe { + fn install() -> Result { + let mut fds = [0i32; 2]; + if unsafe { libc::pipe(fds.as_mut_ptr()) } == -1 { + return Err(std::io::Error::last_os_error()).context("create stdin pipe"); + } + + let original = unsafe { libc::dup(libc::STDIN_FILENO) }; + if original == -1 { + let err = std::io::Error::last_os_error(); + unsafe { + libc::close(fds[0]); + libc::close(fds[1]); + } + return Err(err).context("dup stdin"); + } + + if unsafe { libc::dup2(fds[0], libc::STDIN_FILENO) } == -1 { + let err = std::io::Error::last_os_error(); + unsafe { + libc::close(fds[0]); + libc::close(fds[1]); + libc::close(original); + } + return Err(err).context("replace stdin"); + } + + unsafe { + libc::close(fds[0]); + } + + Ok(Self { + original, + write_end: fds[1], + }) + } +} + +#[cfg(unix)] +impl Drop for BlockingStdinPipe { + fn drop(&mut self) { + unsafe { + libc::dup2(self.original, libc::STDIN_FILENO); + libc::close(self.original); + libc::close(self.write_end); + } + } +} + +#[cfg(not(target_os = "windows"))] +fn assert_posix_snapshot_sections(snapshot: &str) { + assert!(snapshot.contains("# Snapshot file")); + assert!(snapshot.contains("aliases ")); + assert!(snapshot.contains("exports ")); + assert!( + snapshot.contains("PATH"), + "snapshot should capture a PATH export" + ); + assert!(snapshot.contains("setopts ")); +} + +async fn get_snapshot(shell_type: ShellType) -> Result { + let dir = tempdir()?; + let path = dir.path().join("snapshot.sh"); + write_shell_snapshot(shell_type, &path, dir.path()).await?; + let content = fs::read_to_string(&path).await?; + Ok(content) +} + +#[test] +fn strip_snapshot_preamble_removes_leading_output() { + let snapshot = "noise\n# Snapshot file\nexport PATH=/bin\n"; + let cleaned = strip_snapshot_preamble(snapshot).expect("snapshot marker exists"); + assert_eq!(cleaned, "# Snapshot file\nexport PATH=/bin\n"); +} + +#[test] +fn strip_snapshot_preamble_requires_marker() { + let result = strip_snapshot_preamble("missing header"); + assert!(result.is_err()); +} + +#[cfg(unix)] +#[test] +fn bash_snapshot_filters_invalid_exports() -> Result<()> { + let output = Command::new("/bin/bash") + .arg("-c") + .arg(bash_snapshot_script()) + .env("BASH_ENV", "/dev/null") + .env("VALID_NAME", "ok") + .env("PWD", "/tmp/stale") + .env("NEXTEST_BIN_EXE_codex-write-config-schema", "/path/to/bin") + .env("BAD-NAME", "broken") + .output()?; + + assert!(output.status.success()); + + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("VALID_NAME")); + assert!(!stdout.contains("PWD=/tmp/stale")); + assert!(!stdout.contains("NEXTEST_BIN_EXE_codex-write-config-schema")); + assert!(!stdout.contains("BAD-NAME")); + + Ok(()) +} + +#[cfg(unix)] +#[test] +fn bash_snapshot_preserves_multiline_exports() -> Result<()> { + let multiline_cert = "-----BEGIN CERTIFICATE-----\nabc\n-----END CERTIFICATE-----"; + let output = Command::new("/bin/bash") + .arg("-c") + .arg(bash_snapshot_script()) + .env("BASH_ENV", "/dev/null") + .env("MULTILINE_CERT", multiline_cert) + .output()?; + + assert!(output.status.success()); + + let stdout = String::from_utf8_lossy(&output.stdout); + assert!( + stdout.contains("MULTILINE_CERT=") || stdout.contains("MULTILINE_CERT"), + "snapshot should include the multiline export name" + ); + + let dir = tempdir()?; + let snapshot_path = dir.path().join("snapshot.sh"); + std::fs::write(&snapshot_path, stdout.as_bytes())?; + + let validate = Command::new("/bin/bash") + .arg("-c") + .arg("set -e; . \"$1\"") + .arg("bash") + .arg(&snapshot_path) + .env("BASH_ENV", "/dev/null") + .output()?; + + assert!( + validate.status.success(), + "snapshot validation failed: {}", + String::from_utf8_lossy(&validate.stderr) + ); + + Ok(()) +} + +#[cfg(unix)] +#[tokio::test] +async fn try_new_creates_and_deletes_snapshot_file() -> Result<()> { + let dir = tempdir()?; + let shell = Shell { + shell_type: ShellType::Bash, + shell_path: PathBuf::from("/bin/bash"), + shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), + }; + + let snapshot = ShellSnapshot::try_new(dir.path(), ThreadId::new(), dir.path(), &shell) + .await + .expect("snapshot should be created"); + let path = snapshot.path.clone(); + assert!(path.exists()); + assert_eq!(snapshot.cwd, dir.path().to_path_buf()); + + drop(snapshot); + + assert!(!path.exists()); + + Ok(()) +} + +#[cfg(unix)] +#[tokio::test] +async fn snapshot_shell_does_not_inherit_stdin() -> Result<()> { + let _stdin_guard = BlockingStdinPipe::install()?; + + let dir = tempdir()?; + let home = dir.path(); + let read_status_path = home.join("stdin-read-status"); + let read_status_display = read_status_path.display(); + // Persist the startup `read` exit status so the test can assert whether + // bash saw EOF on stdin after the snapshot process exits. + let bashrc = format!("read -t 1 -r ignored\nprintf '%s' \"$?\" > \"{read_status_display}\"\n"); + fs::write(home.join(".bashrc"), bashrc).await?; + + let shell = Shell { + shell_type: ShellType::Bash, + shell_path: PathBuf::from("/bin/bash"), + shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), + }; + + let home_display = home.display(); + let script = format!( + "HOME=\"{home_display}\"; export HOME; {}", + bash_snapshot_script() + ); + let output = run_script_with_timeout(&shell, &script, Duration::from_secs(2), true, home) + .await + .context("run snapshot command")?; + let read_status = fs::read_to_string(&read_status_path) + .await + .context("read stdin probe status")?; + + assert_eq!( + read_status, "1", + "expected shell startup read to see EOF on stdin; status={read_status:?}" + ); + + assert!( + output.contains("# Snapshot file"), + "expected snapshot marker in output; output={output:?}" + ); + + Ok(()) +} + +#[cfg(target_os = "linux")] +#[tokio::test] +async fn timed_out_snapshot_shell_is_terminated() -> Result<()> { + use std::process::Stdio; + use tokio::time::Duration as TokioDuration; + use tokio::time::Instant; + use tokio::time::sleep; + + let dir = tempdir()?; + let pid_path = dir.path().join("pid"); + let script = format!("echo $$ > \"{}\"; sleep 30", pid_path.display()); + + let shell = Shell { + shell_type: ShellType::Sh, + shell_path: PathBuf::from("/bin/sh"), + shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), + }; + + let err = run_script_with_timeout(&shell, &script, Duration::from_secs(1), true, dir.path()) + .await + .expect_err("snapshot shell should time out"); + assert!( + err.to_string().contains("timed out"), + "expected timeout error, got {err:?}" + ); + + let pid = fs::read_to_string(&pid_path) + .await + .expect("snapshot shell writes its pid before timing out") + .trim() + .parse::()?; + + let deadline = Instant::now() + TokioDuration::from_secs(1); + loop { + let kill_status = StdCommand::new("kill") + .arg("-0") + .arg(pid.to_string()) + .stderr(Stdio::null()) + .stdout(Stdio::null()) + .status()?; + if !kill_status.success() { + break; + } + if Instant::now() >= deadline { + panic!("timed out snapshot shell is still alive after grace period"); + } + sleep(TokioDuration::from_millis(50)).await; + } + + Ok(()) +} + +#[cfg(target_os = "macos")] +#[tokio::test] +async fn macos_zsh_snapshot_includes_sections() -> Result<()> { + let snapshot = get_snapshot(ShellType::Zsh).await?; + assert_posix_snapshot_sections(&snapshot); + Ok(()) +} + +#[cfg(target_os = "linux")] +#[tokio::test] +async fn linux_bash_snapshot_includes_sections() -> Result<()> { + let snapshot = get_snapshot(ShellType::Bash).await?; + assert_posix_snapshot_sections(&snapshot); + Ok(()) +} + +#[cfg(target_os = "linux")] +#[tokio::test] +async fn linux_sh_snapshot_includes_sections() -> Result<()> { + let snapshot = get_snapshot(ShellType::Sh).await?; + assert_posix_snapshot_sections(&snapshot); + Ok(()) +} + +#[cfg(target_os = "windows")] +#[ignore] +#[tokio::test] +async fn windows_powershell_snapshot_includes_sections() -> Result<()> { + let snapshot = get_snapshot(ShellType::PowerShell).await?; + assert!(snapshot.contains("# Snapshot file")); + assert!(snapshot.contains("aliases ")); + assert!(snapshot.contains("exports ")); + Ok(()) +} + +async fn write_rollout_stub(codex_home: &Path, session_id: ThreadId) -> Result { + let dir = codex_home + .join("sessions") + .join("2025") + .join("01") + .join("01"); + fs::create_dir_all(&dir).await?; + let path = dir.join(format!("rollout-2025-01-01T00-00-00-{session_id}.jsonl")); + fs::write(&path, "").await?; + Ok(path) +} + +#[tokio::test] +async fn cleanup_stale_snapshots_removes_orphans_and_keeps_live() -> Result<()> { + let dir = tempdir()?; + let codex_home = dir.path(); + let snapshot_dir = codex_home.join(SNAPSHOT_DIR); + fs::create_dir_all(&snapshot_dir).await?; + + let live_session = ThreadId::new(); + let orphan_session = ThreadId::new(); + let live_snapshot = snapshot_dir.join(format!("{live_session}.sh")); + let orphan_snapshot = snapshot_dir.join(format!("{orphan_session}.sh")); + let invalid_snapshot = snapshot_dir.join("not-a-snapshot.txt"); + + write_rollout_stub(codex_home, live_session).await?; + fs::write(&live_snapshot, "live").await?; + fs::write(&orphan_snapshot, "orphan").await?; + fs::write(&invalid_snapshot, "invalid").await?; + + cleanup_stale_snapshots(codex_home, ThreadId::new()).await?; + + assert_eq!(live_snapshot.exists(), true); + assert_eq!(orphan_snapshot.exists(), false); + assert_eq!(invalid_snapshot.exists(), false); + Ok(()) +} + +#[cfg(unix)] +#[tokio::test] +async fn cleanup_stale_snapshots_removes_stale_rollouts() -> Result<()> { + let dir = tempdir()?; + let codex_home = dir.path(); + let snapshot_dir = codex_home.join(SNAPSHOT_DIR); + fs::create_dir_all(&snapshot_dir).await?; + + let stale_session = ThreadId::new(); + let stale_snapshot = snapshot_dir.join(format!("{stale_session}.sh")); + let rollout_path = write_rollout_stub(codex_home, stale_session).await?; + fs::write(&stale_snapshot, "stale").await?; + + set_file_mtime(&rollout_path, SNAPSHOT_RETENTION + Duration::from_secs(60))?; + + cleanup_stale_snapshots(codex_home, ThreadId::new()).await?; + + assert_eq!(stale_snapshot.exists(), false); + Ok(()) +} + +#[cfg(unix)] +#[tokio::test] +async fn cleanup_stale_snapshots_skips_active_session() -> Result<()> { + let dir = tempdir()?; + let codex_home = dir.path(); + let snapshot_dir = codex_home.join(SNAPSHOT_DIR); + fs::create_dir_all(&snapshot_dir).await?; + + let active_session = ThreadId::new(); + let active_snapshot = snapshot_dir.join(format!("{active_session}.sh")); + let rollout_path = write_rollout_stub(codex_home, active_session).await?; + fs::write(&active_snapshot, "active").await?; + + set_file_mtime(&rollout_path, SNAPSHOT_RETENTION + Duration::from_secs(60))?; + + cleanup_stale_snapshots(codex_home, active_session).await?; + + assert_eq!(active_snapshot.exists(), true); + Ok(()) +} + +#[cfg(unix)] +fn set_file_mtime(path: &Path, age: Duration) -> Result<()> { + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH)? + .as_secs() + .saturating_sub(age.as_secs()); + let tv_sec = now + .try_into() + .map_err(|_| anyhow!("Snapshot mtime is out of range for libc::timespec"))?; + let ts = libc::timespec { tv_sec, tv_nsec: 0 }; + let times = [ts, ts]; + let c_path = std::ffi::CString::new(path.as_os_str().as_bytes())?; + let result = unsafe { libc::utimensat(libc::AT_FDCWD, c_path.as_ptr(), times.as_ptr(), 0) }; + if result != 0 { + return Err(std::io::Error::last_os_error().into()); + } + Ok(()) +} diff --git a/codex-rs/core/src/shell_tests.rs b/codex-rs/core/src/shell_tests.rs new file mode 100644 index 0000000000..88025e8b85 --- /dev/null +++ b/codex-rs/core/src/shell_tests.rs @@ -0,0 +1,168 @@ +use super::*; +use std::path::PathBuf; +use std::process::Command; + +#[test] +#[cfg(target_os = "macos")] +fn detects_zsh() { + let zsh_shell = get_shell(ShellType::Zsh, None).unwrap(); + + let shell_path = zsh_shell.shell_path; + + assert_eq!(shell_path, std::path::Path::new("/bin/zsh")); +} + +#[test] +#[cfg(target_os = "macos")] +fn fish_fallback_to_zsh() { + let zsh_shell = default_user_shell_from_path(Some(PathBuf::from("/bin/fish"))); + + let shell_path = zsh_shell.shell_path; + + assert_eq!(shell_path, std::path::Path::new("/bin/zsh")); +} + +#[test] +fn detects_bash() { + let bash_shell = get_shell(ShellType::Bash, None).unwrap(); + let shell_path = bash_shell.shell_path; + + assert!( + shell_path.file_name().and_then(|name| name.to_str()) == Some("bash"), + "shell path: {shell_path:?}", + ); +} + +#[test] +fn detects_sh() { + let sh_shell = get_shell(ShellType::Sh, None).unwrap(); + let shell_path = sh_shell.shell_path; + assert!( + shell_path.file_name().and_then(|name| name.to_str()) == Some("sh"), + "shell path: {shell_path:?}", + ); +} + +#[test] +fn can_run_on_shell_test() { + let cmd = "echo \"Works\""; + if cfg!(windows) { + assert!(shell_works( + get_shell(ShellType::PowerShell, None), + "Out-String 'Works'", + true, + )); + assert!(shell_works(get_shell(ShellType::Cmd, None), cmd, true,)); + assert!(shell_works(Some(ultimate_fallback_shell()), cmd, true)); + } else { + assert!(shell_works(Some(ultimate_fallback_shell()), cmd, true)); + assert!(shell_works(get_shell(ShellType::Zsh, None), cmd, false)); + assert!(shell_works(get_shell(ShellType::Bash, None), cmd, true)); + assert!(shell_works(get_shell(ShellType::Sh, None), cmd, true)); + } +} + +fn shell_works(shell: Option, command: &str, required: bool) -> bool { + if let Some(shell) = shell { + let args = shell.derive_exec_args(command, false); + let output = Command::new(args[0].clone()) + .args(&args[1..]) + .output() + .unwrap(); + assert!(output.status.success()); + assert!(String::from_utf8_lossy(&output.stdout).contains("Works")); + true + } else { + !required + } +} + +#[test] +fn derive_exec_args() { + let test_bash_shell = Shell { + shell_type: ShellType::Bash, + shell_path: PathBuf::from("/bin/bash"), + shell_snapshot: empty_shell_snapshot_receiver(), + }; + assert_eq!( + test_bash_shell.derive_exec_args("echo hello", false), + vec!["/bin/bash", "-c", "echo hello"] + ); + assert_eq!( + test_bash_shell.derive_exec_args("echo hello", true), + vec!["/bin/bash", "-lc", "echo hello"] + ); + + let test_zsh_shell = Shell { + shell_type: ShellType::Zsh, + shell_path: PathBuf::from("/bin/zsh"), + shell_snapshot: empty_shell_snapshot_receiver(), + }; + assert_eq!( + test_zsh_shell.derive_exec_args("echo hello", false), + vec!["/bin/zsh", "-c", "echo hello"] + ); + assert_eq!( + test_zsh_shell.derive_exec_args("echo hello", true), + vec!["/bin/zsh", "-lc", "echo hello"] + ); + + let test_powershell_shell = Shell { + shell_type: ShellType::PowerShell, + shell_path: PathBuf::from("pwsh.exe"), + shell_snapshot: empty_shell_snapshot_receiver(), + }; + assert_eq!( + test_powershell_shell.derive_exec_args("echo hello", false), + vec!["pwsh.exe", "-NoProfile", "-Command", "echo hello"] + ); + assert_eq!( + test_powershell_shell.derive_exec_args("echo hello", true), + vec!["pwsh.exe", "-Command", "echo hello"] + ); +} + +#[tokio::test] +async fn test_current_shell_detects_zsh() { + let shell = Command::new("sh") + .arg("-c") + .arg("echo $SHELL") + .output() + .unwrap(); + + let shell_path = String::from_utf8_lossy(&shell.stdout).trim().to_string(); + if shell_path.ends_with("/zsh") { + assert_eq!( + default_user_shell(), + Shell { + shell_type: ShellType::Zsh, + shell_path: PathBuf::from(shell_path), + shell_snapshot: empty_shell_snapshot_receiver(), + } + ); + } +} + +#[tokio::test] +async fn detects_powershell_as_default() { + if !cfg!(windows) { + return; + } + + let powershell_shell = default_user_shell(); + let shell_path = powershell_shell.shell_path; + + assert!(shell_path.ends_with("pwsh.exe") || shell_path.ends_with("powershell.exe")); +} + +#[test] +fn finds_powershell() { + if !cfg!(windows) { + return; + } + + let powershell_shell = get_shell(ShellType::PowerShell, None).unwrap(); + let shell_path = powershell_shell.shell_path; + + assert!(shell_path.ends_with("pwsh.exe") || shell_path.ends_with("powershell.exe")); +} diff --git a/codex-rs/core/src/skills/injection.rs b/codex-rs/core/src/skills/injection.rs index d40e1bed02..549ccca87e 100644 --- a/codex-rs/core/src/skills/injection.rs +++ b/codex-rs/core/src/skills/injection.rs @@ -489,352 +489,5 @@ fn is_mention_name_char(byte: u8) -> bool { } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - use std::collections::HashMap; - use std::collections::HashSet; - - fn make_skill(name: &str, path: &str) -> SkillMetadata { - SkillMetadata { - name: name.to_string(), - description: format!("{name} skill"), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: PathBuf::from(path), - scope: codex_protocol::protocol::SkillScope::User, - } - } - - fn set<'a>(items: &'a [&'a str]) -> HashSet<&'a str> { - items.iter().copied().collect() - } - - fn assert_mentions(text: &str, expected_names: &[&str], expected_paths: &[&str]) { - let mentions = extract_tool_mentions(text); - assert_eq!(mentions.names, set(expected_names)); - assert_eq!(mentions.paths, set(expected_paths)); - } - - fn collect_mentions( - inputs: &[UserInput], - skills: &[SkillMetadata], - disabled_paths: &HashSet, - connector_slug_counts: &HashMap, - ) -> Vec { - collect_explicit_skill_mentions(inputs, skills, disabled_paths, connector_slug_counts) - } - - #[test] - fn text_mentions_skill_requires_exact_boundary() { - assert_eq!( - true, - text_mentions_skill("use $notion-research-doc please", "notion-research-doc") - ); - assert_eq!( - true, - text_mentions_skill("($notion-research-doc)", "notion-research-doc") - ); - assert_eq!( - true, - text_mentions_skill("$notion-research-doc.", "notion-research-doc") - ); - assert_eq!( - false, - text_mentions_skill("$notion-research-docs", "notion-research-doc") - ); - assert_eq!( - false, - text_mentions_skill("$notion-research-doc_extra", "notion-research-doc") - ); - } - - #[test] - fn text_mentions_skill_handles_end_boundary_and_near_misses() { - assert_eq!(true, text_mentions_skill("$alpha-skill", "alpha-skill")); - assert_eq!(false, text_mentions_skill("$alpha-skillx", "alpha-skill")); - assert_eq!( - true, - text_mentions_skill("$alpha-skillx and later $alpha-skill ", "alpha-skill") - ); - } - - #[test] - fn text_mentions_skill_handles_many_dollars_without_looping() { - let prefix = "$".repeat(256); - let text = format!("{prefix} not-a-mention"); - assert_eq!(false, text_mentions_skill(&text, "alpha-skill")); - } - - #[test] - fn extract_tool_mentions_handles_plain_and_linked_mentions() { - assert_mentions( - "use $alpha and [$beta](/tmp/beta)", - &["alpha", "beta"], - &["/tmp/beta"], - ); - } - - #[test] - fn extract_tool_mentions_skips_common_env_vars() { - assert_mentions("use $PATH and $alpha", &["alpha"], &[]); - assert_mentions("use [$HOME](/tmp/skill)", &[], &[]); - assert_mentions("use $XDG_CONFIG_HOME and $beta", &["beta"], &[]); - } - - #[test] - fn extract_tool_mentions_requires_link_syntax() { - assert_mentions("[beta](/tmp/beta)", &[], &[]); - assert_mentions("[$beta] /tmp/beta", &["beta"], &[]); - assert_mentions("[$beta]()", &["beta"], &[]); - } - - #[test] - fn extract_tool_mentions_trims_linked_paths_and_allows_spacing() { - assert_mentions("use [$beta] ( /tmp/beta )", &["beta"], &["/tmp/beta"]); - } - - #[test] - fn extract_tool_mentions_stops_at_non_name_chars() { - assert_mentions( - "use $alpha.skill and $beta_extra", - &["alpha", "beta_extra"], - &[], - ); - } - - #[test] - fn extract_tool_mentions_keeps_plugin_skill_namespaces() { - assert_mentions( - "use $slack:search and $alpha", - &["alpha", "slack:search"], - &[], - ); - } - - #[test] - fn collect_explicit_skill_mentions_text_respects_skill_order() { - let alpha = make_skill("alpha-skill", "/tmp/alpha"); - let beta = make_skill("beta-skill", "/tmp/beta"); - let skills = vec![beta.clone(), alpha.clone()]; - let inputs = vec![UserInput::Text { - text: "first $alpha-skill then $beta-skill".to_string(), - text_elements: Vec::new(), - }]; - let connector_counts = HashMap::new(); - - let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); - - // Text scanning should not change the previous selection ordering semantics. - assert_eq!(selected, vec![beta, alpha]); - } - - #[test] - fn collect_explicit_skill_mentions_prioritizes_structured_inputs() { - let alpha = make_skill("alpha-skill", "/tmp/alpha"); - let beta = make_skill("beta-skill", "/tmp/beta"); - let skills = vec![alpha.clone(), beta.clone()]; - let inputs = vec![ - UserInput::Text { - text: "please run $alpha-skill".to_string(), - text_elements: Vec::new(), - }, - UserInput::Skill { - name: "beta-skill".to_string(), - path: PathBuf::from("/tmp/beta"), - }, - ]; - let connector_counts = HashMap::new(); - - let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); - - assert_eq!(selected, vec![beta, alpha]); - } - - #[test] - fn collect_explicit_skill_mentions_skips_invalid_structured_and_blocks_plain_fallback() { - let alpha = make_skill("alpha-skill", "/tmp/alpha"); - let skills = vec![alpha]; - let inputs = vec![ - UserInput::Text { - text: "please run $alpha-skill".to_string(), - text_elements: Vec::new(), - }, - UserInput::Skill { - name: "alpha-skill".to_string(), - path: PathBuf::from("/tmp/missing"), - }, - ]; - let connector_counts = HashMap::new(); - - let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); - - assert_eq!(selected, Vec::new()); - } - - #[test] - fn collect_explicit_skill_mentions_skips_disabled_structured_and_blocks_plain_fallback() { - let alpha = make_skill("alpha-skill", "/tmp/alpha"); - let skills = vec![alpha]; - let inputs = vec![ - UserInput::Text { - text: "please run $alpha-skill".to_string(), - text_elements: Vec::new(), - }, - UserInput::Skill { - name: "alpha-skill".to_string(), - path: PathBuf::from("/tmp/alpha"), - }, - ]; - let disabled = HashSet::from([PathBuf::from("/tmp/alpha")]); - let connector_counts = HashMap::new(); - - let selected = collect_mentions(&inputs, &skills, &disabled, &connector_counts); - - assert_eq!(selected, Vec::new()); - } - - #[test] - fn collect_explicit_skill_mentions_dedupes_by_path() { - let alpha = make_skill("alpha-skill", "/tmp/alpha"); - let skills = vec![alpha.clone()]; - let inputs = vec![UserInput::Text { - text: "use [$alpha-skill](/tmp/alpha) and [$alpha-skill](/tmp/alpha)".to_string(), - text_elements: Vec::new(), - }]; - let connector_counts = HashMap::new(); - - let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); - - assert_eq!(selected, vec![alpha]); - } - - #[test] - fn collect_explicit_skill_mentions_skips_ambiguous_name() { - let alpha = make_skill("demo-skill", "/tmp/alpha"); - let beta = make_skill("demo-skill", "/tmp/beta"); - let skills = vec![alpha, beta]; - let inputs = vec![UserInput::Text { - text: "use $demo-skill and again $demo-skill".to_string(), - text_elements: Vec::new(), - }]; - let connector_counts = HashMap::new(); - - let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); - - assert_eq!(selected, Vec::new()); - } - - #[test] - fn collect_explicit_skill_mentions_prefers_linked_path_over_name() { - let alpha = make_skill("demo-skill", "/tmp/alpha"); - let beta = make_skill("demo-skill", "/tmp/beta"); - let skills = vec![alpha, beta.clone()]; - let inputs = vec![UserInput::Text { - text: "use $demo-skill and [$demo-skill](/tmp/beta)".to_string(), - text_elements: Vec::new(), - }]; - let connector_counts = HashMap::new(); - - let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); - - assert_eq!(selected, vec![beta]); - } - - #[test] - fn collect_explicit_skill_mentions_skips_plain_name_when_connector_matches() { - let alpha = make_skill("alpha-skill", "/tmp/alpha"); - let skills = vec![alpha]; - let inputs = vec![UserInput::Text { - text: "use $alpha-skill".to_string(), - text_elements: Vec::new(), - }]; - let connector_counts = HashMap::from([("alpha-skill".to_string(), 1)]); - - let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); - - assert_eq!(selected, Vec::new()); - } - - #[test] - fn collect_explicit_skill_mentions_allows_explicit_path_with_connector_conflict() { - let alpha = make_skill("alpha-skill", "/tmp/alpha"); - let skills = vec![alpha.clone()]; - let inputs = vec![UserInput::Text { - text: "use [$alpha-skill](/tmp/alpha)".to_string(), - text_elements: Vec::new(), - }]; - let connector_counts = HashMap::from([("alpha-skill".to_string(), 1)]); - - let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); - - assert_eq!(selected, vec![alpha]); - } - - #[test] - fn collect_explicit_skill_mentions_skips_when_linked_path_disabled() { - let alpha = make_skill("demo-skill", "/tmp/alpha"); - let beta = make_skill("demo-skill", "/tmp/beta"); - let skills = vec![alpha, beta]; - let inputs = vec![UserInput::Text { - text: "use [$demo-skill](/tmp/alpha)".to_string(), - text_elements: Vec::new(), - }]; - let disabled = HashSet::from([PathBuf::from("/tmp/alpha")]); - let connector_counts = HashMap::new(); - - let selected = collect_mentions(&inputs, &skills, &disabled, &connector_counts); - - assert_eq!(selected, Vec::new()); - } - - #[test] - fn collect_explicit_skill_mentions_prefers_resource_path() { - let alpha = make_skill("demo-skill", "/tmp/alpha"); - let beta = make_skill("demo-skill", "/tmp/beta"); - let skills = vec![alpha, beta.clone()]; - let inputs = vec![UserInput::Text { - text: "use [$demo-skill](/tmp/beta)".to_string(), - text_elements: Vec::new(), - }]; - let connector_counts = HashMap::new(); - - let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); - - assert_eq!(selected, vec![beta]); - } - - #[test] - fn collect_explicit_skill_mentions_skips_missing_path_with_no_fallback() { - let alpha = make_skill("demo-skill", "/tmp/alpha"); - let beta = make_skill("demo-skill", "/tmp/beta"); - let skills = vec![alpha, beta]; - let inputs = vec![UserInput::Text { - text: "use [$demo-skill](/tmp/missing)".to_string(), - text_elements: Vec::new(), - }]; - let connector_counts = HashMap::new(); - - let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); - - assert_eq!(selected, Vec::new()); - } - - #[test] - fn collect_explicit_skill_mentions_skips_missing_path_without_fallback() { - let alpha = make_skill("demo-skill", "/tmp/alpha"); - let skills = vec![alpha]; - let inputs = vec![UserInput::Text { - text: "use [$demo-skill](/tmp/missing)".to_string(), - text_elements: Vec::new(), - }]; - let connector_counts = HashMap::new(); - - let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); - - assert_eq!(selected, Vec::new()); - } -} +#[path = "injection_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/skills/injection_tests.rs b/codex-rs/core/src/skills/injection_tests.rs new file mode 100644 index 0000000000..74ff315bb8 --- /dev/null +++ b/codex-rs/core/src/skills/injection_tests.rs @@ -0,0 +1,347 @@ +use super::*; +use pretty_assertions::assert_eq; +use std::collections::HashMap; +use std::collections::HashSet; + +fn make_skill(name: &str, path: &str) -> SkillMetadata { + SkillMetadata { + name: name.to_string(), + description: format!("{name} skill"), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: PathBuf::from(path), + scope: codex_protocol::protocol::SkillScope::User, + } +} + +fn set<'a>(items: &'a [&'a str]) -> HashSet<&'a str> { + items.iter().copied().collect() +} + +fn assert_mentions(text: &str, expected_names: &[&str], expected_paths: &[&str]) { + let mentions = extract_tool_mentions(text); + assert_eq!(mentions.names, set(expected_names)); + assert_eq!(mentions.paths, set(expected_paths)); +} + +fn collect_mentions( + inputs: &[UserInput], + skills: &[SkillMetadata], + disabled_paths: &HashSet, + connector_slug_counts: &HashMap, +) -> Vec { + collect_explicit_skill_mentions(inputs, skills, disabled_paths, connector_slug_counts) +} + +#[test] +fn text_mentions_skill_requires_exact_boundary() { + assert_eq!( + true, + text_mentions_skill("use $notion-research-doc please", "notion-research-doc") + ); + assert_eq!( + true, + text_mentions_skill("($notion-research-doc)", "notion-research-doc") + ); + assert_eq!( + true, + text_mentions_skill("$notion-research-doc.", "notion-research-doc") + ); + assert_eq!( + false, + text_mentions_skill("$notion-research-docs", "notion-research-doc") + ); + assert_eq!( + false, + text_mentions_skill("$notion-research-doc_extra", "notion-research-doc") + ); +} + +#[test] +fn text_mentions_skill_handles_end_boundary_and_near_misses() { + assert_eq!(true, text_mentions_skill("$alpha-skill", "alpha-skill")); + assert_eq!(false, text_mentions_skill("$alpha-skillx", "alpha-skill")); + assert_eq!( + true, + text_mentions_skill("$alpha-skillx and later $alpha-skill ", "alpha-skill") + ); +} + +#[test] +fn text_mentions_skill_handles_many_dollars_without_looping() { + let prefix = "$".repeat(256); + let text = format!("{prefix} not-a-mention"); + assert_eq!(false, text_mentions_skill(&text, "alpha-skill")); +} + +#[test] +fn extract_tool_mentions_handles_plain_and_linked_mentions() { + assert_mentions( + "use $alpha and [$beta](/tmp/beta)", + &["alpha", "beta"], + &["/tmp/beta"], + ); +} + +#[test] +fn extract_tool_mentions_skips_common_env_vars() { + assert_mentions("use $PATH and $alpha", &["alpha"], &[]); + assert_mentions("use [$HOME](/tmp/skill)", &[], &[]); + assert_mentions("use $XDG_CONFIG_HOME and $beta", &["beta"], &[]); +} + +#[test] +fn extract_tool_mentions_requires_link_syntax() { + assert_mentions("[beta](/tmp/beta)", &[], &[]); + assert_mentions("[$beta] /tmp/beta", &["beta"], &[]); + assert_mentions("[$beta]()", &["beta"], &[]); +} + +#[test] +fn extract_tool_mentions_trims_linked_paths_and_allows_spacing() { + assert_mentions("use [$beta] ( /tmp/beta )", &["beta"], &["/tmp/beta"]); +} + +#[test] +fn extract_tool_mentions_stops_at_non_name_chars() { + assert_mentions( + "use $alpha.skill and $beta_extra", + &["alpha", "beta_extra"], + &[], + ); +} + +#[test] +fn extract_tool_mentions_keeps_plugin_skill_namespaces() { + assert_mentions( + "use $slack:search and $alpha", + &["alpha", "slack:search"], + &[], + ); +} + +#[test] +fn collect_explicit_skill_mentions_text_respects_skill_order() { + let alpha = make_skill("alpha-skill", "/tmp/alpha"); + let beta = make_skill("beta-skill", "/tmp/beta"); + let skills = vec![beta.clone(), alpha.clone()]; + let inputs = vec![UserInput::Text { + text: "first $alpha-skill then $beta-skill".to_string(), + text_elements: Vec::new(), + }]; + let connector_counts = HashMap::new(); + + let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); + + // Text scanning should not change the previous selection ordering semantics. + assert_eq!(selected, vec![beta, alpha]); +} + +#[test] +fn collect_explicit_skill_mentions_prioritizes_structured_inputs() { + let alpha = make_skill("alpha-skill", "/tmp/alpha"); + let beta = make_skill("beta-skill", "/tmp/beta"); + let skills = vec![alpha.clone(), beta.clone()]; + let inputs = vec![ + UserInput::Text { + text: "please run $alpha-skill".to_string(), + text_elements: Vec::new(), + }, + UserInput::Skill { + name: "beta-skill".to_string(), + path: PathBuf::from("/tmp/beta"), + }, + ]; + let connector_counts = HashMap::new(); + + let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); + + assert_eq!(selected, vec![beta, alpha]); +} + +#[test] +fn collect_explicit_skill_mentions_skips_invalid_structured_and_blocks_plain_fallback() { + let alpha = make_skill("alpha-skill", "/tmp/alpha"); + let skills = vec![alpha]; + let inputs = vec![ + UserInput::Text { + text: "please run $alpha-skill".to_string(), + text_elements: Vec::new(), + }, + UserInput::Skill { + name: "alpha-skill".to_string(), + path: PathBuf::from("/tmp/missing"), + }, + ]; + let connector_counts = HashMap::new(); + + let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); + + assert_eq!(selected, Vec::new()); +} + +#[test] +fn collect_explicit_skill_mentions_skips_disabled_structured_and_blocks_plain_fallback() { + let alpha = make_skill("alpha-skill", "/tmp/alpha"); + let skills = vec![alpha]; + let inputs = vec![ + UserInput::Text { + text: "please run $alpha-skill".to_string(), + text_elements: Vec::new(), + }, + UserInput::Skill { + name: "alpha-skill".to_string(), + path: PathBuf::from("/tmp/alpha"), + }, + ]; + let disabled = HashSet::from([PathBuf::from("/tmp/alpha")]); + let connector_counts = HashMap::new(); + + let selected = collect_mentions(&inputs, &skills, &disabled, &connector_counts); + + assert_eq!(selected, Vec::new()); +} + +#[test] +fn collect_explicit_skill_mentions_dedupes_by_path() { + let alpha = make_skill("alpha-skill", "/tmp/alpha"); + let skills = vec![alpha.clone()]; + let inputs = vec![UserInput::Text { + text: "use [$alpha-skill](/tmp/alpha) and [$alpha-skill](/tmp/alpha)".to_string(), + text_elements: Vec::new(), + }]; + let connector_counts = HashMap::new(); + + let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); + + assert_eq!(selected, vec![alpha]); +} + +#[test] +fn collect_explicit_skill_mentions_skips_ambiguous_name() { + let alpha = make_skill("demo-skill", "/tmp/alpha"); + let beta = make_skill("demo-skill", "/tmp/beta"); + let skills = vec![alpha, beta]; + let inputs = vec![UserInput::Text { + text: "use $demo-skill and again $demo-skill".to_string(), + text_elements: Vec::new(), + }]; + let connector_counts = HashMap::new(); + + let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); + + assert_eq!(selected, Vec::new()); +} + +#[test] +fn collect_explicit_skill_mentions_prefers_linked_path_over_name() { + let alpha = make_skill("demo-skill", "/tmp/alpha"); + let beta = make_skill("demo-skill", "/tmp/beta"); + let skills = vec![alpha, beta.clone()]; + let inputs = vec![UserInput::Text { + text: "use $demo-skill and [$demo-skill](/tmp/beta)".to_string(), + text_elements: Vec::new(), + }]; + let connector_counts = HashMap::new(); + + let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); + + assert_eq!(selected, vec![beta]); +} + +#[test] +fn collect_explicit_skill_mentions_skips_plain_name_when_connector_matches() { + let alpha = make_skill("alpha-skill", "/tmp/alpha"); + let skills = vec![alpha]; + let inputs = vec![UserInput::Text { + text: "use $alpha-skill".to_string(), + text_elements: Vec::new(), + }]; + let connector_counts = HashMap::from([("alpha-skill".to_string(), 1)]); + + let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); + + assert_eq!(selected, Vec::new()); +} + +#[test] +fn collect_explicit_skill_mentions_allows_explicit_path_with_connector_conflict() { + let alpha = make_skill("alpha-skill", "/tmp/alpha"); + let skills = vec![alpha.clone()]; + let inputs = vec![UserInput::Text { + text: "use [$alpha-skill](/tmp/alpha)".to_string(), + text_elements: Vec::new(), + }]; + let connector_counts = HashMap::from([("alpha-skill".to_string(), 1)]); + + let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); + + assert_eq!(selected, vec![alpha]); +} + +#[test] +fn collect_explicit_skill_mentions_skips_when_linked_path_disabled() { + let alpha = make_skill("demo-skill", "/tmp/alpha"); + let beta = make_skill("demo-skill", "/tmp/beta"); + let skills = vec![alpha, beta]; + let inputs = vec![UserInput::Text { + text: "use [$demo-skill](/tmp/alpha)".to_string(), + text_elements: Vec::new(), + }]; + let disabled = HashSet::from([PathBuf::from("/tmp/alpha")]); + let connector_counts = HashMap::new(); + + let selected = collect_mentions(&inputs, &skills, &disabled, &connector_counts); + + assert_eq!(selected, Vec::new()); +} + +#[test] +fn collect_explicit_skill_mentions_prefers_resource_path() { + let alpha = make_skill("demo-skill", "/tmp/alpha"); + let beta = make_skill("demo-skill", "/tmp/beta"); + let skills = vec![alpha, beta.clone()]; + let inputs = vec![UserInput::Text { + text: "use [$demo-skill](/tmp/beta)".to_string(), + text_elements: Vec::new(), + }]; + let connector_counts = HashMap::new(); + + let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); + + assert_eq!(selected, vec![beta]); +} + +#[test] +fn collect_explicit_skill_mentions_skips_missing_path_with_no_fallback() { + let alpha = make_skill("demo-skill", "/tmp/alpha"); + let beta = make_skill("demo-skill", "/tmp/beta"); + let skills = vec![alpha, beta]; + let inputs = vec![UserInput::Text { + text: "use [$demo-skill](/tmp/missing)".to_string(), + text_elements: Vec::new(), + }]; + let connector_counts = HashMap::new(); + + let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); + + assert_eq!(selected, Vec::new()); +} + +#[test] +fn collect_explicit_skill_mentions_skips_missing_path_without_fallback() { + let alpha = make_skill("demo-skill", "/tmp/alpha"); + let skills = vec![alpha]; + let inputs = vec![UserInput::Text { + text: "use [$demo-skill](/tmp/missing)".to_string(), + text_elements: Vec::new(), + }]; + let connector_counts = HashMap::new(); + + let selected = collect_mentions(&inputs, &skills, &HashSet::new(), &connector_counts); + + assert_eq!(selected, Vec::new()); +} diff --git a/codex-rs/core/src/skills/invocation_utils.rs b/codex-rs/core/src/skills/invocation_utils.rs index c4310baceb..d9040d3b91 100644 --- a/codex-rs/core/src/skills/invocation_utils.rs +++ b/codex-rs/core/src/skills/invocation_utils.rs @@ -231,126 +231,5 @@ fn normalize_path(path: &Path) -> PathBuf { } #[cfg(test)] -mod tests { - use super::SkillLoadOutcome; - use super::SkillMetadata; - use super::detect_skill_doc_read; - use super::detect_skill_script_run; - use super::normalize_path; - use super::script_run_token; - use pretty_assertions::assert_eq; - use std::collections::HashMap; - use std::path::Path; - use std::path::PathBuf; - use std::sync::Arc; - - fn test_skill_metadata(skill_doc_path: PathBuf) -> SkillMetadata { - SkillMetadata { - name: "test-skill".to_string(), - description: "test".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: skill_doc_path, - scope: codex_protocol::protocol::SkillScope::User, - } - } - - #[test] - fn script_run_detection_matches_runner_plus_extension() { - let tokens = vec![ - "python3".to_string(), - "-u".to_string(), - "scripts/fetch_comments.py".to_string(), - ]; - - assert_eq!(script_run_token(&tokens).is_some(), true); - } - - #[test] - fn script_run_detection_excludes_python_c() { - let tokens = vec![ - "python3".to_string(), - "-c".to_string(), - "print(1)".to_string(), - ]; - - assert_eq!(script_run_token(&tokens).is_some(), false); - } - - #[test] - fn skill_doc_read_detection_matches_absolute_path() { - let skill_doc_path = PathBuf::from("/tmp/skill-test/SKILL.md"); - let normalized_skill_doc_path = normalize_path(skill_doc_path.as_path()); - let skill = test_skill_metadata(skill_doc_path); - let outcome = SkillLoadOutcome { - implicit_skills_by_scripts_dir: Arc::new(HashMap::new()), - implicit_skills_by_doc_path: Arc::new(HashMap::from([( - normalized_skill_doc_path, - skill, - )])), - ..Default::default() - }; - - let tokens = vec![ - "cat".to_string(), - "/tmp/skill-test/SKILL.md".to_string(), - "|".to_string(), - "head".to_string(), - ]; - let found = detect_skill_doc_read(&outcome, &tokens, Path::new("/tmp")); - - assert_eq!( - found.map(|value| value.name), - Some("test-skill".to_string()) - ); - } - - #[test] - fn skill_script_run_detection_matches_relative_path_from_skill_root() { - let skill_doc_path = PathBuf::from("/tmp/skill-test/SKILL.md"); - let scripts_dir = normalize_path(Path::new("/tmp/skill-test/scripts")); - let skill = test_skill_metadata(skill_doc_path); - let outcome = SkillLoadOutcome { - implicit_skills_by_scripts_dir: Arc::new(HashMap::from([(scripts_dir, skill)])), - implicit_skills_by_doc_path: Arc::new(HashMap::new()), - ..Default::default() - }; - let tokens = vec![ - "python3".to_string(), - "scripts/fetch_comments.py".to_string(), - ]; - - let found = detect_skill_script_run(&outcome, &tokens, Path::new("/tmp/skill-test")); - - assert_eq!( - found.map(|value| value.name), - Some("test-skill".to_string()) - ); - } - - #[test] - fn skill_script_run_detection_matches_absolute_path_from_any_workdir() { - let skill_doc_path = PathBuf::from("/tmp/skill-test/SKILL.md"); - let scripts_dir = normalize_path(Path::new("/tmp/skill-test/scripts")); - let skill = test_skill_metadata(skill_doc_path); - let outcome = SkillLoadOutcome { - implicit_skills_by_scripts_dir: Arc::new(HashMap::from([(scripts_dir, skill)])), - implicit_skills_by_doc_path: Arc::new(HashMap::new()), - ..Default::default() - }; - let tokens = vec![ - "python3".to_string(), - "/tmp/skill-test/scripts/fetch_comments.py".to_string(), - ]; - - let found = detect_skill_script_run(&outcome, &tokens, Path::new("/tmp/other")); - - assert_eq!( - found.map(|value| value.name), - Some("test-skill".to_string()) - ); - } -} +#[path = "invocation_utils_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/skills/invocation_utils_tests.rs b/codex-rs/core/src/skills/invocation_utils_tests.rs new file mode 100644 index 0000000000..bf244ce767 --- /dev/null +++ b/codex-rs/core/src/skills/invocation_utils_tests.rs @@ -0,0 +1,118 @@ +use super::SkillLoadOutcome; +use super::SkillMetadata; +use super::detect_skill_doc_read; +use super::detect_skill_script_run; +use super::normalize_path; +use super::script_run_token; +use pretty_assertions::assert_eq; +use std::collections::HashMap; +use std::path::Path; +use std::path::PathBuf; +use std::sync::Arc; + +fn test_skill_metadata(skill_doc_path: PathBuf) -> SkillMetadata { + SkillMetadata { + name: "test-skill".to_string(), + description: "test".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: skill_doc_path, + scope: codex_protocol::protocol::SkillScope::User, + } +} + +#[test] +fn script_run_detection_matches_runner_plus_extension() { + let tokens = vec![ + "python3".to_string(), + "-u".to_string(), + "scripts/fetch_comments.py".to_string(), + ]; + + assert_eq!(script_run_token(&tokens).is_some(), true); +} + +#[test] +fn script_run_detection_excludes_python_c() { + let tokens = vec![ + "python3".to_string(), + "-c".to_string(), + "print(1)".to_string(), + ]; + + assert_eq!(script_run_token(&tokens).is_some(), false); +} + +#[test] +fn skill_doc_read_detection_matches_absolute_path() { + let skill_doc_path = PathBuf::from("/tmp/skill-test/SKILL.md"); + let normalized_skill_doc_path = normalize_path(skill_doc_path.as_path()); + let skill = test_skill_metadata(skill_doc_path); + let outcome = SkillLoadOutcome { + implicit_skills_by_scripts_dir: Arc::new(HashMap::new()), + implicit_skills_by_doc_path: Arc::new(HashMap::from([(normalized_skill_doc_path, skill)])), + ..Default::default() + }; + + let tokens = vec![ + "cat".to_string(), + "/tmp/skill-test/SKILL.md".to_string(), + "|".to_string(), + "head".to_string(), + ]; + let found = detect_skill_doc_read(&outcome, &tokens, Path::new("/tmp")); + + assert_eq!( + found.map(|value| value.name), + Some("test-skill".to_string()) + ); +} + +#[test] +fn skill_script_run_detection_matches_relative_path_from_skill_root() { + let skill_doc_path = PathBuf::from("/tmp/skill-test/SKILL.md"); + let scripts_dir = normalize_path(Path::new("/tmp/skill-test/scripts")); + let skill = test_skill_metadata(skill_doc_path); + let outcome = SkillLoadOutcome { + implicit_skills_by_scripts_dir: Arc::new(HashMap::from([(scripts_dir, skill)])), + implicit_skills_by_doc_path: Arc::new(HashMap::new()), + ..Default::default() + }; + let tokens = vec![ + "python3".to_string(), + "scripts/fetch_comments.py".to_string(), + ]; + + let found = detect_skill_script_run(&outcome, &tokens, Path::new("/tmp/skill-test")); + + assert_eq!( + found.map(|value| value.name), + Some("test-skill".to_string()) + ); +} + +#[test] +fn skill_script_run_detection_matches_absolute_path_from_any_workdir() { + let skill_doc_path = PathBuf::from("/tmp/skill-test/SKILL.md"); + let scripts_dir = normalize_path(Path::new("/tmp/skill-test/scripts")); + let skill = test_skill_metadata(skill_doc_path); + let outcome = SkillLoadOutcome { + implicit_skills_by_scripts_dir: Arc::new(HashMap::from([(scripts_dir, skill)])), + implicit_skills_by_doc_path: Arc::new(HashMap::new()), + ..Default::default() + }; + let tokens = vec![ + "python3".to_string(), + "/tmp/skill-test/scripts/fetch_comments.py".to_string(), + ]; + + let found = detect_skill_script_run(&outcome, &tokens, Path::new("/tmp/other")); + + assert_eq!( + found.map(|value| value.name), + Some("test-skill".to_string()) + ); +} diff --git a/codex-rs/core/src/skills/loader.rs b/codex-rs/core/src/skills/loader.rs index 84c73f9e20..b2dd48b3ae 100644 --- a/codex-rs/core/src/skills/loader.rs +++ b/codex-rs/core/src/skills/loader.rs @@ -853,1914 +853,5 @@ pub(crate) fn skill_roots_from_layer_stack( } #[cfg(test)] -mod tests { - use super::*; - use crate::config::ConfigBuilder; - use crate::config::ConfigOverrides; - use crate::config::ConfigToml; - use crate::config::ProjectConfig; - use crate::config_loader::ConfigLayerEntry; - use crate::config_loader::ConfigLayerStack; - use crate::config_loader::ConfigRequirements; - use crate::config_loader::ConfigRequirementsToml; - use codex_config::CONFIG_TOML_FILE; - use codex_protocol::config_types::TrustLevel; - use codex_protocol::models::FileSystemPermissions; - use codex_protocol::models::MacOsAutomationPermission; - use codex_protocol::models::MacOsContactsPermission; - use codex_protocol::models::MacOsPreferencesPermission; - use codex_protocol::models::MacOsSeatbeltProfileExtensions; - use codex_protocol::models::PermissionProfile; - use codex_protocol::protocol::SkillScope; - use codex_utils_absolute_path::AbsolutePathBuf; - use pretty_assertions::assert_eq; - use std::collections::HashMap; - use std::path::Path; - use tempfile::TempDir; - use toml::Value as TomlValue; - - const REPO_ROOT_CONFIG_DIR_NAME: &str = ".codex"; - - async fn make_config(codex_home: &TempDir) -> Config { - make_config_for_cwd(codex_home, codex_home.path().to_path_buf()).await - } - - async fn make_config_for_cwd(codex_home: &TempDir, cwd: PathBuf) -> Config { - let trust_root = cwd - .ancestors() - .find(|ancestor| ancestor.join(".git").exists()) - .map(Path::to_path_buf) - .unwrap_or_else(|| cwd.clone()); - - fs::write( - codex_home.path().join(CONFIG_TOML_FILE), - toml::to_string(&ConfigToml { - projects: Some(HashMap::from([( - trust_root.to_string_lossy().to_string(), - ProjectConfig { - trust_level: Some(TrustLevel::Trusted), - }, - )])), - ..Default::default() - }) - .expect("serialize config"), - ) - .unwrap(); - - let harness_overrides = ConfigOverrides { - cwd: Some(cwd), - ..Default::default() - }; - - ConfigBuilder::default() - .codex_home(codex_home.path().to_path_buf()) - .harness_overrides(harness_overrides) - .build() - .await - .expect("defaults for test should always succeed") - } - - fn load_skills_for_test(config: &Config) -> SkillLoadOutcome { - // Keep unit tests hermetic by never scanning the real `$HOME/.agents/skills`. - super::load_skills_from_roots(super::skill_roots_with_home_dir( - &config.config_layer_stack, - &config.cwd, - None, - Vec::new(), - )) - } - - fn mark_as_git_repo(dir: &Path) { - // Config/project-root discovery only checks for the presence of `.git` (file or dir), - // so we can avoid shelling out to `git init` in tests. - fs::write(dir.join(".git"), "gitdir: fake\n").unwrap(); - } - - fn normalized(path: &Path) -> PathBuf { - canonicalize_path(path).unwrap_or_else(|_| path.to_path_buf()) - } - - #[test] - fn skill_roots_from_layer_stack_maps_user_to_user_and_system_cache_and_system_to_admin() - -> anyhow::Result<()> { - let tmp = tempfile::tempdir()?; - - let system_folder = tmp.path().join("etc/codex"); - let home_folder = tmp.path().join("home"); - let user_folder = home_folder.join("codex"); - fs::create_dir_all(&system_folder)?; - fs::create_dir_all(&user_folder)?; - - // The file path doesn't need to exist; it's only used to derive the config folder. - let system_file = AbsolutePathBuf::from_absolute_path(system_folder.join("config.toml"))?; - let user_file = AbsolutePathBuf::from_absolute_path(user_folder.join("config.toml"))?; - - let layers = vec![ - ConfigLayerEntry::new( - ConfigLayerSource::System { file: system_file }, - TomlValue::Table(toml::map::Map::new()), - ), - ConfigLayerEntry::new( - ConfigLayerSource::User { file: user_file }, - TomlValue::Table(toml::map::Map::new()), - ), - ]; - let stack = ConfigLayerStack::new( - layers, - ConfigRequirements::default(), - ConfigRequirementsToml::default(), - )?; - - let got = skill_roots_from_layer_stack(&stack, Some(&home_folder)) - .into_iter() - .map(|root| (root.scope, root.path)) - .collect::>(); - - assert_eq!( - got, - vec![ - (SkillScope::User, user_folder.join("skills")), - ( - SkillScope::User, - home_folder.join(AGENTS_DIR_NAME).join(SKILLS_DIR_NAME) - ), - ( - SkillScope::System, - user_folder.join("skills").join(".system") - ), - (SkillScope::Admin, system_folder.join("skills")), - ] - ); - - Ok(()) - } - - #[test] - fn skill_roots_from_layer_stack_includes_disabled_project_layers() -> anyhow::Result<()> { - let tmp = tempfile::tempdir()?; - - let home_folder = tmp.path().join("home"); - let user_folder = home_folder.join("codex"); - fs::create_dir_all(&user_folder)?; - - let project_root = tmp.path().join("repo"); - let dot_codex = project_root.join(".codex"); - fs::create_dir_all(&dot_codex)?; - - let user_file = AbsolutePathBuf::from_absolute_path(user_folder.join("config.toml"))?; - let project_dot_codex = AbsolutePathBuf::from_absolute_path(&dot_codex)?; - - let layers = vec![ - ConfigLayerEntry::new( - ConfigLayerSource::User { file: user_file }, - TomlValue::Table(toml::map::Map::new()), - ), - ConfigLayerEntry::new_disabled( - ConfigLayerSource::Project { - dot_codex_folder: project_dot_codex, - }, - TomlValue::Table(toml::map::Map::new()), - "marked untrusted", - ), - ]; - let stack = ConfigLayerStack::new( - layers, - ConfigRequirements::default(), - ConfigRequirementsToml::default(), - )?; - - let got = skill_roots_from_layer_stack(&stack, Some(&home_folder)) - .into_iter() - .map(|root| (root.scope, root.path)) - .collect::>(); - - assert_eq!( - got, - vec![ - (SkillScope::Repo, dot_codex.join("skills")), - (SkillScope::User, user_folder.join("skills")), - ( - SkillScope::User, - home_folder.join(AGENTS_DIR_NAME).join(SKILLS_DIR_NAME) - ), - ( - SkillScope::System, - user_folder.join("skills").join(".system") - ), - ] - ); - - Ok(()) - } - - #[test] - fn loads_skills_from_home_agents_dir_for_user_scope() -> anyhow::Result<()> { - let tmp = tempfile::tempdir()?; - - let home_folder = tmp.path().join("home"); - let user_folder = home_folder.join("codex"); - fs::create_dir_all(&user_folder)?; - - let user_file = AbsolutePathBuf::from_absolute_path(user_folder.join("config.toml"))?; - let layers = vec![ConfigLayerEntry::new( - ConfigLayerSource::User { file: user_file }, - TomlValue::Table(toml::map::Map::new()), - )]; - let stack = ConfigLayerStack::new( - layers, - ConfigRequirements::default(), - ConfigRequirementsToml::default(), - )?; - - let skill_path = write_skill_at( - &home_folder.join(AGENTS_DIR_NAME).join(SKILLS_DIR_NAME), - "agents-home", - "agents-home-skill", - "from home agents", - ); - - let outcome = - load_skills_from_roots(skill_roots_from_layer_stack(&stack, Some(&home_folder))); - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![SkillMetadata { - name: "agents-home-skill".to_string(), - description: "from home agents".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&skill_path), - scope: SkillScope::User, - }] - ); - - Ok(()) - } - - fn write_skill(codex_home: &TempDir, dir: &str, name: &str, description: &str) -> PathBuf { - write_skill_at(&codex_home.path().join("skills"), dir, name, description) - } - - fn write_system_skill( - codex_home: &TempDir, - dir: &str, - name: &str, - description: &str, - ) -> PathBuf { - write_skill_at( - &codex_home.path().join("skills/.system"), - dir, - name, - description, - ) - } - - fn write_skill_at(root: &Path, dir: &str, name: &str, description: &str) -> PathBuf { - let skill_dir = root.join(dir); - fs::create_dir_all(&skill_dir).unwrap(); - let indented_description = description.replace('\n', "\n "); - let content = format!( - "---\nname: {name}\ndescription: |-\n {indented_description}\n---\n\n# Body\n" - ); - let path = skill_dir.join(SKILLS_FILENAME); - fs::write(&path, content).unwrap(); - path - } - - fn write_raw_skill_at(root: &Path, dir: &str, frontmatter: &str) -> PathBuf { - let skill_dir = root.join(dir); - fs::create_dir_all(&skill_dir).unwrap(); - let path = skill_dir.join(SKILLS_FILENAME); - let content = format!("---\n{frontmatter}\n---\n\n# Body\n"); - fs::write(&path, content).unwrap(); - path - } - - fn write_skill_metadata_at(skill_dir: &Path, contents: &str) -> PathBuf { - let path = skill_dir - .join(SKILLS_METADATA_DIR) - .join(SKILLS_METADATA_FILENAME); - if let Some(parent) = path.parent() { - fs::create_dir_all(parent).unwrap(); - } - fs::write(&path, contents).unwrap(); - path - } - - fn write_skill_interface_at(skill_dir: &Path, contents: &str) -> PathBuf { - write_skill_metadata_at(skill_dir, contents) - } - - #[tokio::test] - async fn loads_skill_dependencies_metadata_from_yaml() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let skill_path = write_skill(&codex_home, "demo", "dep-skill", "from json"); - let skill_dir = skill_path.parent().expect("skill dir"); - - write_skill_metadata_at( - skill_dir, - r#" -{ - "dependencies": { - "tools": [ - { - "type": "env_var", - "value": "GITHUB_TOKEN", - "description": "GitHub API token with repo scopes" - }, - { - "type": "mcp", - "value": "github", - "description": "GitHub MCP server", - "transport": "streamable_http", - "url": "https://example.com/mcp" - }, - { - "type": "cli", - "value": "gh", - "description": "GitHub CLI" - }, - { - "type": "mcp", - "value": "local-gh", - "description": "Local GH MCP server", - "transport": "stdio", - "command": "gh-mcp" - } - ] - } -} -"#, - ); - - let cfg = make_config(&codex_home).await; - let outcome = load_skills_for_test(&cfg); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![SkillMetadata { - name: "dep-skill".to_string(), - description: "from json".to_string(), - short_description: None, - interface: None, - dependencies: Some(SkillDependencies { - tools: vec![ - SkillToolDependency { - r#type: "env_var".to_string(), - value: "GITHUB_TOKEN".to_string(), - description: Some("GitHub API token with repo scopes".to_string()), - transport: None, - command: None, - url: None, - }, - SkillToolDependency { - r#type: "mcp".to_string(), - value: "github".to_string(), - description: Some("GitHub MCP server".to_string()), - transport: Some("streamable_http".to_string()), - command: None, - url: Some("https://example.com/mcp".to_string()), - }, - SkillToolDependency { - r#type: "cli".to_string(), - value: "gh".to_string(), - description: Some("GitHub CLI".to_string()), - transport: None, - command: None, - url: None, - }, - SkillToolDependency { - r#type: "mcp".to_string(), - value: "local-gh".to_string(), - description: Some("Local GH MCP server".to_string()), - transport: Some("stdio".to_string()), - command: Some("gh-mcp".to_string()), - url: None, - }, - ], - }), - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&skill_path), - scope: SkillScope::User, - }] - ); - } - - #[tokio::test] - async fn loads_skill_interface_metadata_from_yaml() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let skill_path = write_skill(&codex_home, "demo", "ui-skill", "from json"); - let skill_dir = skill_path.parent().expect("skill dir"); - let normalized_skill_dir = normalized(skill_dir); - - write_skill_interface_at( - skill_dir, - r##" -interface: - display_name: "UI Skill" - short_description: " short desc " - icon_small: "./assets/small-400px.png" - icon_large: "./assets/large-logo.svg" - brand_color: "#3B82F6" - default_prompt: " default prompt " -"##, - ); - - let cfg = make_config(&codex_home).await; - let outcome = load_skills_for_test(&cfg); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - let user_skills: Vec = outcome - .skills - .into_iter() - .filter(|skill| skill.scope == SkillScope::User) - .collect(); - assert_eq!( - user_skills, - vec![SkillMetadata { - name: "ui-skill".to_string(), - description: "from json".to_string(), - short_description: None, - interface: Some(SkillInterface { - display_name: Some("UI Skill".to_string()), - short_description: Some("short desc".to_string()), - icon_small: Some(normalized_skill_dir.join("assets/small-400px.png")), - icon_large: Some(normalized_skill_dir.join("assets/large-logo.svg")), - brand_color: Some("#3B82F6".to_string()), - default_prompt: Some("default prompt".to_string()), - }), - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(skill_path.as_path()), - scope: SkillScope::User, - }] - ); - } - - #[tokio::test] - async fn loads_skill_policy_from_yaml() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let skill_path = write_skill(&codex_home, "demo", "policy-skill", "from json"); - let skill_dir = skill_path.parent().expect("skill dir"); - - write_skill_metadata_at( - skill_dir, - r#" -policy: - allow_implicit_invocation: false -"#, - ); - - let cfg = make_config(&codex_home).await; - let outcome = load_skills_for_test(&cfg); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!(outcome.skills.len(), 1); - assert_eq!( - outcome.skills[0].policy, - Some(SkillPolicy { - allow_implicit_invocation: Some(false), - }) - ); - assert!(outcome.allowed_skills_for_implicit_invocation().is_empty()); - } - - #[tokio::test] - async fn empty_skill_policy_defaults_to_allow_implicit_invocation() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let skill_path = write_skill(&codex_home, "demo", "policy-empty", "from json"); - let skill_dir = skill_path.parent().expect("skill dir"); - - write_skill_metadata_at( - skill_dir, - r#" -policy: {} -"#, - ); - - let cfg = make_config(&codex_home).await; - let outcome = load_skills_for_test(&cfg); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!(outcome.skills.len(), 1); - assert_eq!( - outcome.skills[0].policy, - Some(SkillPolicy { - allow_implicit_invocation: None, - }) - ); - assert_eq!( - outcome.allowed_skills_for_implicit_invocation(), - outcome.skills - ); - } - - #[tokio::test] - async fn loads_skill_permissions_from_yaml() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let skill_path = write_skill(&codex_home, "demo", "permissions-skill", "from yaml"); - let skill_dir = skill_path.parent().expect("skill dir"); - fs::create_dir_all(skill_dir.join("data")).expect("create read path"); - fs::create_dir_all(skill_dir.join("output")).expect("create write path"); - - write_skill_metadata_at( - skill_dir, - r#" -permissions: - network: - enabled: true - file_system: - read: - - "./data" - write: - - "./output" -"#, - ); - - let cfg = make_config(&codex_home).await; - let outcome = load_skills_for_test(&cfg); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!(outcome.skills.len(), 1); - assert_eq!( - outcome.skills[0].permission_profile, - Some(PermissionProfile { - network: Some(NetworkPermissions { - enabled: Some(true), - }), - file_system: Some(FileSystemPermissions { - read: Some(vec![ - AbsolutePathBuf::try_from(normalized(skill_dir.join("data").as_path())) - .expect("absolute data path"), - ]), - write: Some(vec![ - AbsolutePathBuf::try_from(normalized(skill_dir.join("output").as_path())) - .expect("absolute output path"), - ]), - }), - macos: None, - }) - ); - } - - #[tokio::test] - async fn empty_skill_permissions_do_not_create_profile() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let skill_path = write_skill(&codex_home, "demo", "permissions-empty", "from yaml"); - let skill_dir = skill_path.parent().expect("skill dir"); - - write_skill_metadata_at( - skill_dir, - r#" -permissions: {} -"#, - ); - - let cfg = make_config(&codex_home).await; - let outcome = load_skills_for_test(&cfg); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!(outcome.skills.len(), 1); - assert_eq!(outcome.skills[0].permission_profile, None); - } - - #[test] - fn skill_metadata_parses_macos_permissions_yaml() { - let parsed = serde_yaml::from_str::( - r#" -permissions: - macos: - macos_preferences: "read_write" - macos_automation: - - "com.apple.Notes" - macos_launch_services: true - macos_accessibility: true - macos_calendar: true -"#, - ) - .expect("parse skill metadata"); - - assert_eq!( - parsed.permissions, - Some(PermissionProfile { - macos: Some(MacOsSeatbeltProfileExtensions { - macos_preferences: MacOsPreferencesPermission::ReadWrite, - macos_automation: MacOsAutomationPermission::BundleIds(vec![ - "com.apple.Notes".to_string(), - ]), - macos_launch_services: true, - macos_accessibility: true, - macos_calendar: true, - macos_reminders: false, - macos_contacts: MacOsContactsPermission::None, - }), - ..Default::default() - }) - ); - } - - #[test] - fn skill_metadata_parses_macos_reminders_permission_yaml() { - let parsed = serde_yaml::from_str::( - r#" -permissions: - macos: - macos_reminders: true -"#, - ) - .expect("parse reminders skill metadata"); - - assert_eq!( - parsed.permissions, - Some(PermissionProfile { - macos: Some(MacOsSeatbeltProfileExtensions { - macos_preferences: MacOsPreferencesPermission::ReadOnly, - macos_automation: MacOsAutomationPermission::None, - macos_launch_services: false, - macos_accessibility: false, - macos_calendar: false, - macos_reminders: true, - macos_contacts: MacOsContactsPermission::None, - }), - ..Default::default() - }) - ); - } - - #[cfg(target_os = "macos")] - #[tokio::test] - async fn loads_skill_macos_permissions_from_yaml() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let skill_path = write_skill(&codex_home, "demo", "permissions-macos", "from yaml"); - let skill_dir = skill_path.parent().expect("skill dir"); - - write_skill_metadata_at( - skill_dir, - r#" -permissions: - macos: - macos_preferences: "read_write" - macos_automation: - - "com.apple.Notes" - macos_launch_services: true - macos_accessibility: true - macos_calendar: true -"#, - ); - - let cfg = make_config(&codex_home).await; - let outcome = load_skills_for_test(&cfg); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!(outcome.skills.len(), 1); - assert_eq!( - outcome.skills[0].permission_profile, - Some(PermissionProfile { - macos: Some(MacOsSeatbeltProfileExtensions { - macos_preferences: MacOsPreferencesPermission::ReadWrite, - macos_automation: MacOsAutomationPermission::BundleIds(vec![ - "com.apple.Notes".to_string() - ],), - macos_launch_services: true, - macos_accessibility: true, - macos_calendar: true, - macos_reminders: false, - macos_contacts: MacOsContactsPermission::None, - }), - ..Default::default() - }) - ); - } - - #[cfg(not(target_os = "macos"))] - #[tokio::test] - async fn loads_skill_macos_permissions_from_yaml_non_macos_does_not_create_profile() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let skill_path = write_skill(&codex_home, "demo", "permissions-macos", "from yaml"); - let skill_dir = skill_path.parent().expect("skill dir"); - - write_skill_metadata_at( - skill_dir, - r#" -permissions: - macos: - macos_preferences: "read_write" - macos_automation: - - "com.apple.Notes" - macos_launch_services: true - macos_accessibility: true - macos_calendar: true -"#, - ); - - let cfg = make_config(&codex_home).await; - let outcome = load_skills_for_test(&cfg); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!(outcome.skills.len(), 1); - assert_eq!( - outcome.skills[0].permission_profile, - Some(PermissionProfile { - macos: Some(MacOsSeatbeltProfileExtensions { - macos_preferences: MacOsPreferencesPermission::ReadWrite, - macos_automation: MacOsAutomationPermission::BundleIds(vec![ - "com.apple.Notes".to_string() - ],), - macos_launch_services: true, - macos_accessibility: true, - macos_calendar: true, - macos_reminders: false, - macos_contacts: MacOsContactsPermission::None, - }), - ..Default::default() - }) - ); - } - - #[tokio::test] - async fn accepts_icon_paths_under_assets_dir() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let skill_path = write_skill(&codex_home, "demo", "ui-skill", "from json"); - let skill_dir = skill_path.parent().expect("skill dir"); - let normalized_skill_dir = normalized(skill_dir); - - write_skill_interface_at( - skill_dir, - r#" -{ - "interface": { - "display_name": "UI Skill", - "icon_small": "assets/icon.png", - "icon_large": "./assets/logo.svg" - } -} -"#, - ); - - let cfg = make_config(&codex_home).await; - let outcome = load_skills_for_test(&cfg); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![SkillMetadata { - name: "ui-skill".to_string(), - description: "from json".to_string(), - short_description: None, - interface: Some(SkillInterface { - display_name: Some("UI Skill".to_string()), - short_description: None, - icon_small: Some(normalized_skill_dir.join("assets/icon.png")), - icon_large: Some(normalized_skill_dir.join("assets/logo.svg")), - brand_color: None, - default_prompt: None, - }), - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&skill_path), - scope: SkillScope::User, - }] - ); - } - - #[tokio::test] - async fn ignores_invalid_brand_color() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let skill_path = write_skill(&codex_home, "demo", "ui-skill", "from json"); - let skill_dir = skill_path.parent().expect("skill dir"); - - write_skill_interface_at( - skill_dir, - r#" -{ - "interface": { - "brand_color": "blue" - } -} -"#, - ); - - let cfg = make_config(&codex_home).await; - let outcome = load_skills_for_test(&cfg); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![SkillMetadata { - name: "ui-skill".to_string(), - description: "from json".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&skill_path), - scope: SkillScope::User, - }] - ); - } - - #[tokio::test] - async fn ignores_default_prompt_over_max_length() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let skill_path = write_skill(&codex_home, "demo", "ui-skill", "from json"); - let skill_dir = skill_path.parent().expect("skill dir"); - let normalized_skill_dir = normalized(skill_dir); - let too_long = "x".repeat(MAX_DEFAULT_PROMPT_LEN + 1); - - write_skill_interface_at( - skill_dir, - &format!( - r##" -{{ - "interface": {{ - "display_name": "UI Skill", - "icon_small": "./assets/small-400px.png", - "default_prompt": "{too_long}" - }} -}} -"## - ), - ); - - let cfg = make_config(&codex_home).await; - let outcome = load_skills_for_test(&cfg); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![SkillMetadata { - name: "ui-skill".to_string(), - description: "from json".to_string(), - short_description: None, - interface: Some(SkillInterface { - display_name: Some("UI Skill".to_string()), - short_description: None, - icon_small: Some(normalized_skill_dir.join("assets/small-400px.png")), - icon_large: None, - brand_color: None, - default_prompt: None, - }), - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&skill_path), - scope: SkillScope::User, - }] - ); - } - - #[tokio::test] - async fn drops_interface_when_icons_are_invalid() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let skill_path = write_skill(&codex_home, "demo", "ui-skill", "from json"); - let skill_dir = skill_path.parent().expect("skill dir"); - - write_skill_interface_at( - skill_dir, - r#" -{ - "interface": { - "icon_small": "icon.png", - "icon_large": "./assets/../logo.svg" - } -} -"#, - ); - - let cfg = make_config(&codex_home).await; - let outcome = load_skills_for_test(&cfg); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![SkillMetadata { - name: "ui-skill".to_string(), - description: "from json".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&skill_path), - scope: SkillScope::User, - }] - ); - } - - #[cfg(unix)] - fn symlink_dir(target: &Path, link: &Path) { - std::os::unix::fs::symlink(target, link).unwrap(); - } - - #[cfg(unix)] - fn symlink_file(target: &Path, link: &Path) { - std::os::unix::fs::symlink(target, link).unwrap(); - } - - #[tokio::test] - #[cfg(unix)] - async fn loads_skills_via_symlinked_subdir_for_user_scope() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let shared = tempfile::tempdir().expect("tempdir"); - - let shared_skill_path = write_skill_at(shared.path(), "demo", "linked-skill", "from link"); - - fs::create_dir_all(codex_home.path().join("skills")).unwrap(); - symlink_dir(shared.path(), &codex_home.path().join("skills/shared")); - - let cfg = make_config(&codex_home).await; - let outcome = load_skills_for_test(&cfg); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![SkillMetadata { - name: "linked-skill".to_string(), - description: "from link".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&shared_skill_path), - scope: SkillScope::User, - }] - ); - } - - #[tokio::test] - #[cfg(unix)] - async fn ignores_symlinked_skill_file_for_user_scope() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let shared = tempfile::tempdir().expect("tempdir"); - - let shared_skill_path = - write_skill_at(shared.path(), "demo", "linked-file-skill", "from link"); - - let skill_dir = codex_home.path().join("skills/demo"); - fs::create_dir_all(&skill_dir).unwrap(); - symlink_file(&shared_skill_path, &skill_dir.join(SKILLS_FILENAME)); - - let cfg = make_config(&codex_home).await; - let outcome = load_skills_for_test(&cfg); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!(outcome.skills, Vec::new()); - } - - #[tokio::test] - #[cfg(unix)] - async fn does_not_loop_on_symlink_cycle_for_user_scope() { - let codex_home = tempfile::tempdir().expect("tempdir"); - - // Create a cycle: - // $CODEX_HOME/skills/cycle/loop -> $CODEX_HOME/skills/cycle - let cycle_dir = codex_home.path().join("skills/cycle"); - fs::create_dir_all(&cycle_dir).unwrap(); - symlink_dir(&cycle_dir, &cycle_dir.join("loop")); - - let skill_path = write_skill_at(&cycle_dir, "demo", "cycle-skill", "still loads"); - - let cfg = make_config(&codex_home).await; - let outcome = load_skills_for_test(&cfg); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![SkillMetadata { - name: "cycle-skill".to_string(), - description: "still loads".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&skill_path), - scope: SkillScope::User, - }] - ); - } - - #[test] - #[cfg(unix)] - fn loads_skills_via_symlinked_subdir_for_admin_scope() { - let admin_root = tempfile::tempdir().expect("tempdir"); - let shared = tempfile::tempdir().expect("tempdir"); - - let shared_skill_path = - write_skill_at(shared.path(), "demo", "admin-linked-skill", "from link"); - fs::create_dir_all(admin_root.path()).unwrap(); - symlink_dir(shared.path(), &admin_root.path().join("shared")); - - let outcome = load_skills_from_roots([SkillRoot { - path: admin_root.path().to_path_buf(), - scope: SkillScope::Admin, - }]); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![SkillMetadata { - name: "admin-linked-skill".to_string(), - description: "from link".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&shared_skill_path), - scope: SkillScope::Admin, - }] - ); - } - - #[tokio::test] - #[cfg(unix)] - async fn loads_skills_via_symlinked_subdir_for_repo_scope() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let repo_dir = tempfile::tempdir().expect("tempdir"); - mark_as_git_repo(repo_dir.path()); - let shared = tempfile::tempdir().expect("tempdir"); - - let linked_skill_path = - write_skill_at(shared.path(), "demo", "repo-linked-skill", "from link"); - let repo_skills_root = repo_dir - .path() - .join(REPO_ROOT_CONFIG_DIR_NAME) - .join(SKILLS_DIR_NAME); - fs::create_dir_all(&repo_skills_root).unwrap(); - symlink_dir(shared.path(), &repo_skills_root.join("shared")); - - let cfg = make_config_for_cwd(&codex_home, repo_dir.path().to_path_buf()).await; - let outcome = load_skills_for_test(&cfg); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![SkillMetadata { - name: "repo-linked-skill".to_string(), - description: "from link".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&linked_skill_path), - scope: SkillScope::Repo, - }] - ); - } - - #[tokio::test] - #[cfg(unix)] - async fn system_scope_ignores_symlinked_subdir() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let shared = tempfile::tempdir().expect("tempdir"); - - write_skill_at(shared.path(), "demo", "system-linked-skill", "from link"); - - let system_root = codex_home.path().join("skills/.system"); - fs::create_dir_all(&system_root).unwrap(); - symlink_dir(shared.path(), &system_root.join("shared")); - - let outcome = load_skills_from_roots([SkillRoot { - path: system_root, - scope: SkillScope::System, - }]); - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!(outcome.skills.len(), 0); - } - - #[tokio::test] - async fn respects_max_scan_depth_for_user_scope() { - let codex_home = tempfile::tempdir().expect("tempdir"); - - let within_depth_path = write_skill( - &codex_home, - "d0/d1/d2/d3/d4/d5", - "within-depth-skill", - "loads", - ); - let _too_deep_path = write_skill( - &codex_home, - "d0/d1/d2/d3/d4/d5/d6", - "too-deep-skill", - "should not load", - ); - - let skills_root = codex_home.path().join("skills"); - let outcome = load_skills_from_roots([SkillRoot { - path: skills_root, - scope: SkillScope::User, - }]); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![SkillMetadata { - name: "within-depth-skill".to_string(), - description: "loads".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&within_depth_path), - scope: SkillScope::User, - }] - ); - } - - #[tokio::test] - async fn loads_valid_skill() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let skill_path = write_skill(&codex_home, "demo", "demo-skill", "does things\ncarefully"); - let cfg = make_config(&codex_home).await; - - let outcome = load_skills_for_test(&cfg); - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![SkillMetadata { - name: "demo-skill".to_string(), - description: "does things carefully".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&skill_path), - scope: SkillScope::User, - }] - ); - } - - #[tokio::test] - async fn falls_back_to_directory_name_when_skill_name_is_missing() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let skill_path = write_raw_skill_at( - &codex_home.path().join("skills"), - "directory-derived", - "description: fallback name", - ); - let cfg = make_config(&codex_home).await; - - let outcome = load_skills_for_test(&cfg); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![SkillMetadata { - name: "directory-derived".to_string(), - description: "fallback name".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&skill_path), - scope: SkillScope::User, - }] - ); - } - - #[tokio::test] - async fn namespaces_plugin_skills_using_plugin_name() { - let root = tempfile::tempdir().expect("tempdir"); - let plugin_root = root.path().join("plugins/sample"); - let skill_path = write_raw_skill_at( - &plugin_root.join("skills"), - "sample-search", - "description: search sample data", - ); - fs::create_dir_all(plugin_root.join(".codex-plugin")).unwrap(); - fs::write( - plugin_root.join(".codex-plugin/plugin.json"), - r#"{"name":"sample"}"#, - ) - .unwrap(); - - let outcome = load_skills_from_roots([SkillRoot { - path: plugin_root.join("skills"), - scope: SkillScope::User, - }]); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![SkillMetadata { - name: "sample:sample-search".to_string(), - description: "search sample data".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&skill_path), - scope: SkillScope::User, - }] - ); - } - - #[tokio::test] - async fn loads_short_description_from_metadata() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let skill_dir = codex_home.path().join("skills/demo"); - fs::create_dir_all(&skill_dir).unwrap(); - let contents = "---\nname: demo-skill\ndescription: long description\nmetadata:\n short-description: short summary\n---\n\n# Body\n"; - let skill_path = skill_dir.join(SKILLS_FILENAME); - fs::write(&skill_path, contents).unwrap(); - - let cfg = make_config(&codex_home).await; - let outcome = load_skills_for_test(&cfg); - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![SkillMetadata { - name: "demo-skill".to_string(), - description: "long description".to_string(), - short_description: Some("short summary".to_string()), - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&skill_path), - scope: SkillScope::User, - }] - ); - } - - #[tokio::test] - async fn enforces_short_description_length_limits() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let skill_dir = codex_home.path().join("skills/demo"); - fs::create_dir_all(&skill_dir).unwrap(); - let too_long = "x".repeat(MAX_SHORT_DESCRIPTION_LEN + 1); - let contents = format!( - "---\nname: demo-skill\ndescription: long description\nmetadata:\n short-description: {too_long}\n---\n\n# Body\n" - ); - fs::write(skill_dir.join(SKILLS_FILENAME), contents).unwrap(); - - let cfg = make_config(&codex_home).await; - let outcome = load_skills_for_test(&cfg); - assert_eq!(outcome.skills.len(), 0); - assert_eq!(outcome.errors.len(), 1); - assert!( - outcome.errors[0] - .message - .contains("invalid metadata.short-description"), - "expected length error, got: {:?}", - outcome.errors - ); - } - - #[tokio::test] - async fn skips_hidden_and_invalid() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let hidden_dir = codex_home.path().join("skills/.hidden"); - fs::create_dir_all(&hidden_dir).unwrap(); - fs::write( - hidden_dir.join(SKILLS_FILENAME), - "---\nname: hidden\ndescription: hidden\n---\n", - ) - .unwrap(); - - // Invalid because missing closing frontmatter. - let invalid_dir = codex_home.path().join("skills/invalid"); - fs::create_dir_all(&invalid_dir).unwrap(); - fs::write(invalid_dir.join(SKILLS_FILENAME), "---\nname: bad").unwrap(); - - let cfg = make_config(&codex_home).await; - let outcome = load_skills_for_test(&cfg); - assert_eq!(outcome.skills.len(), 0); - assert_eq!(outcome.errors.len(), 1); - assert!( - outcome.errors[0] - .message - .contains("missing YAML frontmatter"), - "expected frontmatter error" - ); - } - - #[tokio::test] - async fn enforces_length_limits() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let max_desc = "\u{1F4A1}".repeat(MAX_DESCRIPTION_LEN); - write_skill(&codex_home, "max-len", "max-len", &max_desc); - let cfg = make_config(&codex_home).await; - - let outcome = load_skills_for_test(&cfg); - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!(outcome.skills.len(), 1); - - let too_long_desc = "\u{1F4A1}".repeat(MAX_DESCRIPTION_LEN + 1); - write_skill(&codex_home, "too-long", "too-long", &too_long_desc); - let outcome = load_skills_for_test(&cfg); - assert_eq!(outcome.skills.len(), 1); - assert_eq!(outcome.errors.len(), 1); - assert!( - outcome.errors[0].message.contains("invalid description"), - "expected length error" - ); - } - - #[tokio::test] - async fn loads_skills_from_repo_root() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let repo_dir = tempfile::tempdir().expect("tempdir"); - mark_as_git_repo(repo_dir.path()); - - let skills_root = repo_dir - .path() - .join(REPO_ROOT_CONFIG_DIR_NAME) - .join(SKILLS_DIR_NAME); - let skill_path = write_skill_at(&skills_root, "repo", "repo-skill", "from repo"); - let cfg = make_config_for_cwd(&codex_home, repo_dir.path().to_path_buf()).await; - - let outcome = load_skills_for_test(&cfg); - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![SkillMetadata { - name: "repo-skill".to_string(), - description: "from repo".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&skill_path), - scope: SkillScope::Repo, - }] - ); - } - - #[tokio::test] - async fn loads_skills_from_agents_dir_without_codex_dir() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let repo_dir = tempfile::tempdir().expect("tempdir"); - mark_as_git_repo(repo_dir.path()); - - let skill_path = write_skill_at( - &repo_dir.path().join(AGENTS_DIR_NAME).join(SKILLS_DIR_NAME), - "agents", - "agents-skill", - "from agents", - ); - let cfg = make_config_for_cwd(&codex_home, repo_dir.path().to_path_buf()).await; - - let outcome = load_skills_for_test(&cfg); - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![SkillMetadata { - name: "agents-skill".to_string(), - description: "from agents".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&skill_path), - scope: SkillScope::Repo, - }] - ); - } - - #[tokio::test] - async fn loads_skills_from_all_codex_dirs_under_project_root() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let repo_dir = tempfile::tempdir().expect("tempdir"); - mark_as_git_repo(repo_dir.path()); - - let nested_dir = repo_dir.path().join("nested/inner"); - fs::create_dir_all(&nested_dir).unwrap(); - - let root_skill_path = write_skill_at( - &repo_dir - .path() - .join(REPO_ROOT_CONFIG_DIR_NAME) - .join(SKILLS_DIR_NAME), - "root", - "root-skill", - "from root", - ); - let nested_skill_path = write_skill_at( - &repo_dir - .path() - .join("nested") - .join(REPO_ROOT_CONFIG_DIR_NAME) - .join(SKILLS_DIR_NAME), - "nested", - "nested-skill", - "from nested", - ); - - let cfg = make_config_for_cwd(&codex_home, nested_dir).await; - - let outcome = load_skills_for_test(&cfg); - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![ - SkillMetadata { - name: "nested-skill".to_string(), - description: "from nested".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&nested_skill_path), - scope: SkillScope::Repo, - }, - SkillMetadata { - name: "root-skill".to_string(), - description: "from root".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&root_skill_path), - scope: SkillScope::Repo, - }, - ] - ); - } - - #[tokio::test] - async fn loads_skills_from_codex_dir_when_not_git_repo() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let work_dir = tempfile::tempdir().expect("tempdir"); - - let skill_path = write_skill_at( - &work_dir - .path() - .join(REPO_ROOT_CONFIG_DIR_NAME) - .join(SKILLS_DIR_NAME), - "local", - "local-skill", - "from cwd", - ); - - let cfg = make_config_for_cwd(&codex_home, work_dir.path().to_path_buf()).await; - - let outcome = load_skills_for_test(&cfg); - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![SkillMetadata { - name: "local-skill".to_string(), - description: "from cwd".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&skill_path), - scope: SkillScope::Repo, - }] - ); - } - - #[tokio::test] - async fn deduplicates_by_path_preferring_first_root() { - let root = tempfile::tempdir().expect("tempdir"); - - let skill_path = write_skill_at(root.path(), "dupe", "dupe-skill", "from repo"); - - let outcome = load_skills_from_roots([ - SkillRoot { - path: root.path().to_path_buf(), - scope: SkillScope::Repo, - }, - SkillRoot { - path: root.path().to_path_buf(), - scope: SkillScope::User, - }, - ]); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![SkillMetadata { - name: "dupe-skill".to_string(), - description: "from repo".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&skill_path), - scope: SkillScope::Repo, - }] - ); - } - - #[tokio::test] - async fn keeps_duplicate_names_from_repo_and_user() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let repo_dir = tempfile::tempdir().expect("tempdir"); - mark_as_git_repo(repo_dir.path()); - - let user_skill_path = write_skill(&codex_home, "user", "dupe-skill", "from user"); - let repo_skill_path = write_skill_at( - &repo_dir - .path() - .join(REPO_ROOT_CONFIG_DIR_NAME) - .join(SKILLS_DIR_NAME), - "repo", - "dupe-skill", - "from repo", - ); - - let cfg = make_config_for_cwd(&codex_home, repo_dir.path().to_path_buf()).await; - - let outcome = load_skills_for_test(&cfg); - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![ - SkillMetadata { - name: "dupe-skill".to_string(), - description: "from repo".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&repo_skill_path), - scope: SkillScope::Repo, - }, - SkillMetadata { - name: "dupe-skill".to_string(), - description: "from user".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&user_skill_path), - scope: SkillScope::User, - }, - ] - ); - } - - #[tokio::test] - async fn keeps_duplicate_names_from_nested_codex_dirs() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let repo_dir = tempfile::tempdir().expect("tempdir"); - mark_as_git_repo(repo_dir.path()); - - let nested_dir = repo_dir.path().join("nested/inner"); - fs::create_dir_all(&nested_dir).unwrap(); - - let root_skill_path = write_skill_at( - &repo_dir - .path() - .join(REPO_ROOT_CONFIG_DIR_NAME) - .join(SKILLS_DIR_NAME), - "root", - "dupe-skill", - "from root", - ); - let nested_skill_path = write_skill_at( - &repo_dir - .path() - .join("nested") - .join(REPO_ROOT_CONFIG_DIR_NAME) - .join(SKILLS_DIR_NAME), - "nested", - "dupe-skill", - "from nested", - ); - - let cfg = make_config_for_cwd(&codex_home, nested_dir).await; - let outcome = load_skills_for_test(&cfg); - - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - let root_path = - canonicalize_path(&root_skill_path).unwrap_or_else(|_| root_skill_path.clone()); - let nested_path = - canonicalize_path(&nested_skill_path).unwrap_or_else(|_| nested_skill_path.clone()); - let (first_path, second_path, first_description, second_description) = - if root_path <= nested_path { - (root_path, nested_path, "from root", "from nested") - } else { - (nested_path, root_path, "from nested", "from root") - }; - assert_eq!( - outcome.skills, - vec![ - SkillMetadata { - name: "dupe-skill".to_string(), - description: first_description.to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: first_path, - scope: SkillScope::Repo, - }, - SkillMetadata { - name: "dupe-skill".to_string(), - description: second_description.to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: second_path, - scope: SkillScope::Repo, - }, - ] - ); - } - - #[tokio::test] - async fn repo_skills_search_does_not_escape_repo_root() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let outer_dir = tempfile::tempdir().expect("tempdir"); - let repo_dir = outer_dir.path().join("repo"); - fs::create_dir_all(&repo_dir).unwrap(); - - let _skill_path = write_skill_at( - &outer_dir - .path() - .join(REPO_ROOT_CONFIG_DIR_NAME) - .join(SKILLS_DIR_NAME), - "outer", - "outer-skill", - "from outer", - ); - mark_as_git_repo(&repo_dir); - - let cfg = make_config_for_cwd(&codex_home, repo_dir).await; - - let outcome = load_skills_for_test(&cfg); - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!(outcome.skills.len(), 0); - } - - #[tokio::test] - async fn loads_skills_when_cwd_is_file_in_repo() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let repo_dir = tempfile::tempdir().expect("tempdir"); - mark_as_git_repo(repo_dir.path()); - - let skill_path = write_skill_at( - &repo_dir - .path() - .join(REPO_ROOT_CONFIG_DIR_NAME) - .join(SKILLS_DIR_NAME), - "repo", - "repo-skill", - "from repo", - ); - let file_path = repo_dir.path().join("some-file.txt"); - fs::write(&file_path, "contents").unwrap(); - - let cfg = make_config_for_cwd(&codex_home, file_path).await; - - let outcome = load_skills_for_test(&cfg); - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![SkillMetadata { - name: "repo-skill".to_string(), - description: "from repo".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&skill_path), - scope: SkillScope::Repo, - }] - ); - } - - #[tokio::test] - async fn non_git_repo_skills_search_does_not_walk_parents() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let outer_dir = tempfile::tempdir().expect("tempdir"); - let nested_dir = outer_dir.path().join("nested/inner"); - fs::create_dir_all(&nested_dir).unwrap(); - - write_skill_at( - &outer_dir - .path() - .join(REPO_ROOT_CONFIG_DIR_NAME) - .join(SKILLS_DIR_NAME), - "outer", - "outer-skill", - "from outer", - ); - - let cfg = make_config_for_cwd(&codex_home, nested_dir).await; - - let outcome = load_skills_for_test(&cfg); - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!(outcome.skills.len(), 0); - } - - #[tokio::test] - async fn loads_skills_from_system_cache_when_present() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let work_dir = tempfile::tempdir().expect("tempdir"); - - let skill_path = write_system_skill(&codex_home, "system", "system-skill", "from system"); - - let cfg = make_config_for_cwd(&codex_home, work_dir.path().to_path_buf()).await; - - let outcome = load_skills_for_test(&cfg); - assert!( - outcome.errors.is_empty(), - "unexpected errors: {:?}", - outcome.errors - ); - assert_eq!( - outcome.skills, - vec![SkillMetadata { - name: "system-skill".to_string(), - description: "from system".to_string(), - short_description: None, - interface: None, - dependencies: None, - policy: None, - permission_profile: None, - path_to_skills_md: normalized(&skill_path), - scope: SkillScope::System, - }] - ); - } - - #[tokio::test] - async fn skill_roots_include_admin_with_lowest_priority() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let cfg = make_config(&codex_home).await; - - let scopes: Vec = - super::skill_roots(&cfg.config_layer_stack, &cfg.cwd, Vec::new()) - .into_iter() - .map(|root| root.scope) - .collect(); - let mut expected = vec![SkillScope::User, SkillScope::System]; - if home_dir().is_some() { - expected.insert(1, SkillScope::User); - } - expected.push(SkillScope::Admin); - assert_eq!(scopes, expected); - } -} +#[path = "loader_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/skills/loader_tests.rs b/codex-rs/core/src/skills/loader_tests.rs new file mode 100644 index 0000000000..feee09a9ab --- /dev/null +++ b/codex-rs/core/src/skills/loader_tests.rs @@ -0,0 +1,1898 @@ +use super::*; +use crate::config::ConfigBuilder; +use crate::config::ConfigOverrides; +use crate::config::ConfigToml; +use crate::config::ProjectConfig; +use crate::config_loader::ConfigLayerEntry; +use crate::config_loader::ConfigLayerStack; +use crate::config_loader::ConfigRequirements; +use crate::config_loader::ConfigRequirementsToml; +use codex_config::CONFIG_TOML_FILE; +use codex_protocol::config_types::TrustLevel; +use codex_protocol::models::FileSystemPermissions; +use codex_protocol::models::MacOsAutomationPermission; +use codex_protocol::models::MacOsContactsPermission; +use codex_protocol::models::MacOsPreferencesPermission; +use codex_protocol::models::MacOsSeatbeltProfileExtensions; +use codex_protocol::models::PermissionProfile; +use codex_protocol::protocol::SkillScope; +use codex_utils_absolute_path::AbsolutePathBuf; +use pretty_assertions::assert_eq; +use std::collections::HashMap; +use std::path::Path; +use tempfile::TempDir; +use toml::Value as TomlValue; + +const REPO_ROOT_CONFIG_DIR_NAME: &str = ".codex"; + +async fn make_config(codex_home: &TempDir) -> Config { + make_config_for_cwd(codex_home, codex_home.path().to_path_buf()).await +} + +async fn make_config_for_cwd(codex_home: &TempDir, cwd: PathBuf) -> Config { + let trust_root = cwd + .ancestors() + .find(|ancestor| ancestor.join(".git").exists()) + .map(Path::to_path_buf) + .unwrap_or_else(|| cwd.clone()); + + fs::write( + codex_home.path().join(CONFIG_TOML_FILE), + toml::to_string(&ConfigToml { + projects: Some(HashMap::from([( + trust_root.to_string_lossy().to_string(), + ProjectConfig { + trust_level: Some(TrustLevel::Trusted), + }, + )])), + ..Default::default() + }) + .expect("serialize config"), + ) + .unwrap(); + + let harness_overrides = ConfigOverrides { + cwd: Some(cwd), + ..Default::default() + }; + + ConfigBuilder::default() + .codex_home(codex_home.path().to_path_buf()) + .harness_overrides(harness_overrides) + .build() + .await + .expect("defaults for test should always succeed") +} + +fn load_skills_for_test(config: &Config) -> SkillLoadOutcome { + // Keep unit tests hermetic by never scanning the real `$HOME/.agents/skills`. + super::load_skills_from_roots(super::skill_roots_with_home_dir( + &config.config_layer_stack, + &config.cwd, + None, + Vec::new(), + )) +} + +fn mark_as_git_repo(dir: &Path) { + // Config/project-root discovery only checks for the presence of `.git` (file or dir), + // so we can avoid shelling out to `git init` in tests. + fs::write(dir.join(".git"), "gitdir: fake\n").unwrap(); +} + +fn normalized(path: &Path) -> PathBuf { + canonicalize_path(path).unwrap_or_else(|_| path.to_path_buf()) +} + +#[test] +fn skill_roots_from_layer_stack_maps_user_to_user_and_system_cache_and_system_to_admin() +-> anyhow::Result<()> { + let tmp = tempfile::tempdir()?; + + let system_folder = tmp.path().join("etc/codex"); + let home_folder = tmp.path().join("home"); + let user_folder = home_folder.join("codex"); + fs::create_dir_all(&system_folder)?; + fs::create_dir_all(&user_folder)?; + + // The file path doesn't need to exist; it's only used to derive the config folder. + let system_file = AbsolutePathBuf::from_absolute_path(system_folder.join("config.toml"))?; + let user_file = AbsolutePathBuf::from_absolute_path(user_folder.join("config.toml"))?; + + let layers = vec![ + ConfigLayerEntry::new( + ConfigLayerSource::System { file: system_file }, + TomlValue::Table(toml::map::Map::new()), + ), + ConfigLayerEntry::new( + ConfigLayerSource::User { file: user_file }, + TomlValue::Table(toml::map::Map::new()), + ), + ]; + let stack = ConfigLayerStack::new( + layers, + ConfigRequirements::default(), + ConfigRequirementsToml::default(), + )?; + + let got = skill_roots_from_layer_stack(&stack, Some(&home_folder)) + .into_iter() + .map(|root| (root.scope, root.path)) + .collect::>(); + + assert_eq!( + got, + vec![ + (SkillScope::User, user_folder.join("skills")), + ( + SkillScope::User, + home_folder.join(AGENTS_DIR_NAME).join(SKILLS_DIR_NAME) + ), + ( + SkillScope::System, + user_folder.join("skills").join(".system") + ), + (SkillScope::Admin, system_folder.join("skills")), + ] + ); + + Ok(()) +} + +#[test] +fn skill_roots_from_layer_stack_includes_disabled_project_layers() -> anyhow::Result<()> { + let tmp = tempfile::tempdir()?; + + let home_folder = tmp.path().join("home"); + let user_folder = home_folder.join("codex"); + fs::create_dir_all(&user_folder)?; + + let project_root = tmp.path().join("repo"); + let dot_codex = project_root.join(".codex"); + fs::create_dir_all(&dot_codex)?; + + let user_file = AbsolutePathBuf::from_absolute_path(user_folder.join("config.toml"))?; + let project_dot_codex = AbsolutePathBuf::from_absolute_path(&dot_codex)?; + + let layers = vec![ + ConfigLayerEntry::new( + ConfigLayerSource::User { file: user_file }, + TomlValue::Table(toml::map::Map::new()), + ), + ConfigLayerEntry::new_disabled( + ConfigLayerSource::Project { + dot_codex_folder: project_dot_codex, + }, + TomlValue::Table(toml::map::Map::new()), + "marked untrusted", + ), + ]; + let stack = ConfigLayerStack::new( + layers, + ConfigRequirements::default(), + ConfigRequirementsToml::default(), + )?; + + let got = skill_roots_from_layer_stack(&stack, Some(&home_folder)) + .into_iter() + .map(|root| (root.scope, root.path)) + .collect::>(); + + assert_eq!( + got, + vec![ + (SkillScope::Repo, dot_codex.join("skills")), + (SkillScope::User, user_folder.join("skills")), + ( + SkillScope::User, + home_folder.join(AGENTS_DIR_NAME).join(SKILLS_DIR_NAME) + ), + ( + SkillScope::System, + user_folder.join("skills").join(".system") + ), + ] + ); + + Ok(()) +} + +#[test] +fn loads_skills_from_home_agents_dir_for_user_scope() -> anyhow::Result<()> { + let tmp = tempfile::tempdir()?; + + let home_folder = tmp.path().join("home"); + let user_folder = home_folder.join("codex"); + fs::create_dir_all(&user_folder)?; + + let user_file = AbsolutePathBuf::from_absolute_path(user_folder.join("config.toml"))?; + let layers = vec![ConfigLayerEntry::new( + ConfigLayerSource::User { file: user_file }, + TomlValue::Table(toml::map::Map::new()), + )]; + let stack = ConfigLayerStack::new( + layers, + ConfigRequirements::default(), + ConfigRequirementsToml::default(), + )?; + + let skill_path = write_skill_at( + &home_folder.join(AGENTS_DIR_NAME).join(SKILLS_DIR_NAME), + "agents-home", + "agents-home-skill", + "from home agents", + ); + + let outcome = load_skills_from_roots(skill_roots_from_layer_stack(&stack, Some(&home_folder))); + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "agents-home-skill".to_string(), + description: "from home agents".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&skill_path), + scope: SkillScope::User, + }] + ); + + Ok(()) +} + +fn write_skill(codex_home: &TempDir, dir: &str, name: &str, description: &str) -> PathBuf { + write_skill_at(&codex_home.path().join("skills"), dir, name, description) +} + +fn write_system_skill(codex_home: &TempDir, dir: &str, name: &str, description: &str) -> PathBuf { + write_skill_at( + &codex_home.path().join("skills/.system"), + dir, + name, + description, + ) +} + +fn write_skill_at(root: &Path, dir: &str, name: &str, description: &str) -> PathBuf { + let skill_dir = root.join(dir); + fs::create_dir_all(&skill_dir).unwrap(); + let indented_description = description.replace('\n', "\n "); + let content = + format!("---\nname: {name}\ndescription: |-\n {indented_description}\n---\n\n# Body\n"); + let path = skill_dir.join(SKILLS_FILENAME); + fs::write(&path, content).unwrap(); + path +} + +fn write_raw_skill_at(root: &Path, dir: &str, frontmatter: &str) -> PathBuf { + let skill_dir = root.join(dir); + fs::create_dir_all(&skill_dir).unwrap(); + let path = skill_dir.join(SKILLS_FILENAME); + let content = format!("---\n{frontmatter}\n---\n\n# Body\n"); + fs::write(&path, content).unwrap(); + path +} + +fn write_skill_metadata_at(skill_dir: &Path, contents: &str) -> PathBuf { + let path = skill_dir + .join(SKILLS_METADATA_DIR) + .join(SKILLS_METADATA_FILENAME); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).unwrap(); + } + fs::write(&path, contents).unwrap(); + path +} + +fn write_skill_interface_at(skill_dir: &Path, contents: &str) -> PathBuf { + write_skill_metadata_at(skill_dir, contents) +} + +#[tokio::test] +async fn loads_skill_dependencies_metadata_from_yaml() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let skill_path = write_skill(&codex_home, "demo", "dep-skill", "from json"); + let skill_dir = skill_path.parent().expect("skill dir"); + + write_skill_metadata_at( + skill_dir, + r#" +{ + "dependencies": { + "tools": [ + { + "type": "env_var", + "value": "GITHUB_TOKEN", + "description": "GitHub API token with repo scopes" + }, + { + "type": "mcp", + "value": "github", + "description": "GitHub MCP server", + "transport": "streamable_http", + "url": "https://example.com/mcp" + }, + { + "type": "cli", + "value": "gh", + "description": "GitHub CLI" + }, + { + "type": "mcp", + "value": "local-gh", + "description": "Local GH MCP server", + "transport": "stdio", + "command": "gh-mcp" + } + ] + } +} +"#, + ); + + let cfg = make_config(&codex_home).await; + let outcome = load_skills_for_test(&cfg); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "dep-skill".to_string(), + description: "from json".to_string(), + short_description: None, + interface: None, + dependencies: Some(SkillDependencies { + tools: vec![ + SkillToolDependency { + r#type: "env_var".to_string(), + value: "GITHUB_TOKEN".to_string(), + description: Some("GitHub API token with repo scopes".to_string()), + transport: None, + command: None, + url: None, + }, + SkillToolDependency { + r#type: "mcp".to_string(), + value: "github".to_string(), + description: Some("GitHub MCP server".to_string()), + transport: Some("streamable_http".to_string()), + command: None, + url: Some("https://example.com/mcp".to_string()), + }, + SkillToolDependency { + r#type: "cli".to_string(), + value: "gh".to_string(), + description: Some("GitHub CLI".to_string()), + transport: None, + command: None, + url: None, + }, + SkillToolDependency { + r#type: "mcp".to_string(), + value: "local-gh".to_string(), + description: Some("Local GH MCP server".to_string()), + transport: Some("stdio".to_string()), + command: Some("gh-mcp".to_string()), + url: None, + }, + ], + }), + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&skill_path), + scope: SkillScope::User, + }] + ); +} + +#[tokio::test] +async fn loads_skill_interface_metadata_from_yaml() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let skill_path = write_skill(&codex_home, "demo", "ui-skill", "from json"); + let skill_dir = skill_path.parent().expect("skill dir"); + let normalized_skill_dir = normalized(skill_dir); + + write_skill_interface_at( + skill_dir, + r##" +interface: + display_name: "UI Skill" + short_description: " short desc " + icon_small: "./assets/small-400px.png" + icon_large: "./assets/large-logo.svg" + brand_color: "#3B82F6" + default_prompt: " default prompt " +"##, + ); + + let cfg = make_config(&codex_home).await; + let outcome = load_skills_for_test(&cfg); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + let user_skills: Vec = outcome + .skills + .into_iter() + .filter(|skill| skill.scope == SkillScope::User) + .collect(); + assert_eq!( + user_skills, + vec![SkillMetadata { + name: "ui-skill".to_string(), + description: "from json".to_string(), + short_description: None, + interface: Some(SkillInterface { + display_name: Some("UI Skill".to_string()), + short_description: Some("short desc".to_string()), + icon_small: Some(normalized_skill_dir.join("assets/small-400px.png")), + icon_large: Some(normalized_skill_dir.join("assets/large-logo.svg")), + brand_color: Some("#3B82F6".to_string()), + default_prompt: Some("default prompt".to_string()), + }), + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(skill_path.as_path()), + scope: SkillScope::User, + }] + ); +} + +#[tokio::test] +async fn loads_skill_policy_from_yaml() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let skill_path = write_skill(&codex_home, "demo", "policy-skill", "from json"); + let skill_dir = skill_path.parent().expect("skill dir"); + + write_skill_metadata_at( + skill_dir, + r#" +policy: + allow_implicit_invocation: false +"#, + ); + + let cfg = make_config(&codex_home).await; + let outcome = load_skills_for_test(&cfg); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!(outcome.skills.len(), 1); + assert_eq!( + outcome.skills[0].policy, + Some(SkillPolicy { + allow_implicit_invocation: Some(false), + }) + ); + assert!(outcome.allowed_skills_for_implicit_invocation().is_empty()); +} + +#[tokio::test] +async fn empty_skill_policy_defaults_to_allow_implicit_invocation() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let skill_path = write_skill(&codex_home, "demo", "policy-empty", "from json"); + let skill_dir = skill_path.parent().expect("skill dir"); + + write_skill_metadata_at( + skill_dir, + r#" +policy: {} +"#, + ); + + let cfg = make_config(&codex_home).await; + let outcome = load_skills_for_test(&cfg); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!(outcome.skills.len(), 1); + assert_eq!( + outcome.skills[0].policy, + Some(SkillPolicy { + allow_implicit_invocation: None, + }) + ); + assert_eq!( + outcome.allowed_skills_for_implicit_invocation(), + outcome.skills + ); +} + +#[tokio::test] +async fn loads_skill_permissions_from_yaml() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let skill_path = write_skill(&codex_home, "demo", "permissions-skill", "from yaml"); + let skill_dir = skill_path.parent().expect("skill dir"); + fs::create_dir_all(skill_dir.join("data")).expect("create read path"); + fs::create_dir_all(skill_dir.join("output")).expect("create write path"); + + write_skill_metadata_at( + skill_dir, + r#" +permissions: + network: + enabled: true + file_system: + read: + - "./data" + write: + - "./output" +"#, + ); + + let cfg = make_config(&codex_home).await; + let outcome = load_skills_for_test(&cfg); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!(outcome.skills.len(), 1); + assert_eq!( + outcome.skills[0].permission_profile, + Some(PermissionProfile { + network: Some(NetworkPermissions { + enabled: Some(true), + }), + file_system: Some(FileSystemPermissions { + read: Some(vec![ + AbsolutePathBuf::try_from(normalized(skill_dir.join("data").as_path())) + .expect("absolute data path"), + ]), + write: Some(vec![ + AbsolutePathBuf::try_from(normalized(skill_dir.join("output").as_path())) + .expect("absolute output path"), + ]), + }), + macos: None, + }) + ); +} + +#[tokio::test] +async fn empty_skill_permissions_do_not_create_profile() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let skill_path = write_skill(&codex_home, "demo", "permissions-empty", "from yaml"); + let skill_dir = skill_path.parent().expect("skill dir"); + + write_skill_metadata_at( + skill_dir, + r#" +permissions: {} +"#, + ); + + let cfg = make_config(&codex_home).await; + let outcome = load_skills_for_test(&cfg); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!(outcome.skills.len(), 1); + assert_eq!(outcome.skills[0].permission_profile, None); +} + +#[test] +fn skill_metadata_parses_macos_permissions_yaml() { + let parsed = serde_yaml::from_str::( + r#" +permissions: + macos: + macos_preferences: "read_write" + macos_automation: + - "com.apple.Notes" + macos_launch_services: true + macos_accessibility: true + macos_calendar: true +"#, + ) + .expect("parse skill metadata"); + + assert_eq!( + parsed.permissions, + Some(PermissionProfile { + macos: Some(MacOsSeatbeltProfileExtensions { + macos_preferences: MacOsPreferencesPermission::ReadWrite, + macos_automation: MacOsAutomationPermission::BundleIds(vec![ + "com.apple.Notes".to_string(), + ]), + macos_launch_services: true, + macos_accessibility: true, + macos_calendar: true, + macos_reminders: false, + macos_contacts: MacOsContactsPermission::None, + }), + ..Default::default() + }) + ); +} + +#[test] +fn skill_metadata_parses_macos_reminders_permission_yaml() { + let parsed = serde_yaml::from_str::( + r#" +permissions: + macos: + macos_reminders: true +"#, + ) + .expect("parse reminders skill metadata"); + + assert_eq!( + parsed.permissions, + Some(PermissionProfile { + macos: Some(MacOsSeatbeltProfileExtensions { + macos_preferences: MacOsPreferencesPermission::ReadOnly, + macos_automation: MacOsAutomationPermission::None, + macos_launch_services: false, + macos_accessibility: false, + macos_calendar: false, + macos_reminders: true, + macos_contacts: MacOsContactsPermission::None, + }), + ..Default::default() + }) + ); +} + +#[cfg(target_os = "macos")] +#[tokio::test] +async fn loads_skill_macos_permissions_from_yaml() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let skill_path = write_skill(&codex_home, "demo", "permissions-macos", "from yaml"); + let skill_dir = skill_path.parent().expect("skill dir"); + + write_skill_metadata_at( + skill_dir, + r#" +permissions: + macos: + macos_preferences: "read_write" + macos_automation: + - "com.apple.Notes" + macos_launch_services: true + macos_accessibility: true + macos_calendar: true +"#, + ); + + let cfg = make_config(&codex_home).await; + let outcome = load_skills_for_test(&cfg); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!(outcome.skills.len(), 1); + assert_eq!( + outcome.skills[0].permission_profile, + Some(PermissionProfile { + macos: Some(MacOsSeatbeltProfileExtensions { + macos_preferences: MacOsPreferencesPermission::ReadWrite, + macos_automation: MacOsAutomationPermission::BundleIds(vec![ + "com.apple.Notes".to_string() + ],), + macos_launch_services: true, + macos_accessibility: true, + macos_calendar: true, + macos_reminders: false, + macos_contacts: MacOsContactsPermission::None, + }), + ..Default::default() + }) + ); +} + +#[cfg(not(target_os = "macos"))] +#[tokio::test] +async fn loads_skill_macos_permissions_from_yaml_non_macos_does_not_create_profile() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let skill_path = write_skill(&codex_home, "demo", "permissions-macos", "from yaml"); + let skill_dir = skill_path.parent().expect("skill dir"); + + write_skill_metadata_at( + skill_dir, + r#" +permissions: + macos: + macos_preferences: "read_write" + macos_automation: + - "com.apple.Notes" + macos_launch_services: true + macos_accessibility: true + macos_calendar: true +"#, + ); + + let cfg = make_config(&codex_home).await; + let outcome = load_skills_for_test(&cfg); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!(outcome.skills.len(), 1); + assert_eq!( + outcome.skills[0].permission_profile, + Some(PermissionProfile { + macos: Some(MacOsSeatbeltProfileExtensions { + macos_preferences: MacOsPreferencesPermission::ReadWrite, + macos_automation: MacOsAutomationPermission::BundleIds(vec![ + "com.apple.Notes".to_string() + ],), + macos_launch_services: true, + macos_accessibility: true, + macos_calendar: true, + macos_reminders: false, + macos_contacts: MacOsContactsPermission::None, + }), + ..Default::default() + }) + ); +} + +#[tokio::test] +async fn accepts_icon_paths_under_assets_dir() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let skill_path = write_skill(&codex_home, "demo", "ui-skill", "from json"); + let skill_dir = skill_path.parent().expect("skill dir"); + let normalized_skill_dir = normalized(skill_dir); + + write_skill_interface_at( + skill_dir, + r#" +{ + "interface": { + "display_name": "UI Skill", + "icon_small": "assets/icon.png", + "icon_large": "./assets/logo.svg" + } +} +"#, + ); + + let cfg = make_config(&codex_home).await; + let outcome = load_skills_for_test(&cfg); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "ui-skill".to_string(), + description: "from json".to_string(), + short_description: None, + interface: Some(SkillInterface { + display_name: Some("UI Skill".to_string()), + short_description: None, + icon_small: Some(normalized_skill_dir.join("assets/icon.png")), + icon_large: Some(normalized_skill_dir.join("assets/logo.svg")), + brand_color: None, + default_prompt: None, + }), + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&skill_path), + scope: SkillScope::User, + }] + ); +} + +#[tokio::test] +async fn ignores_invalid_brand_color() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let skill_path = write_skill(&codex_home, "demo", "ui-skill", "from json"); + let skill_dir = skill_path.parent().expect("skill dir"); + + write_skill_interface_at( + skill_dir, + r#" +{ + "interface": { + "brand_color": "blue" + } +} +"#, + ); + + let cfg = make_config(&codex_home).await; + let outcome = load_skills_for_test(&cfg); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "ui-skill".to_string(), + description: "from json".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&skill_path), + scope: SkillScope::User, + }] + ); +} + +#[tokio::test] +async fn ignores_default_prompt_over_max_length() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let skill_path = write_skill(&codex_home, "demo", "ui-skill", "from json"); + let skill_dir = skill_path.parent().expect("skill dir"); + let normalized_skill_dir = normalized(skill_dir); + let too_long = "x".repeat(MAX_DEFAULT_PROMPT_LEN + 1); + + write_skill_interface_at( + skill_dir, + &format!( + r##" +{{ + "interface": {{ + "display_name": "UI Skill", + "icon_small": "./assets/small-400px.png", + "default_prompt": "{too_long}" + }} +}} +"## + ), + ); + + let cfg = make_config(&codex_home).await; + let outcome = load_skills_for_test(&cfg); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "ui-skill".to_string(), + description: "from json".to_string(), + short_description: None, + interface: Some(SkillInterface { + display_name: Some("UI Skill".to_string()), + short_description: None, + icon_small: Some(normalized_skill_dir.join("assets/small-400px.png")), + icon_large: None, + brand_color: None, + default_prompt: None, + }), + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&skill_path), + scope: SkillScope::User, + }] + ); +} + +#[tokio::test] +async fn drops_interface_when_icons_are_invalid() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let skill_path = write_skill(&codex_home, "demo", "ui-skill", "from json"); + let skill_dir = skill_path.parent().expect("skill dir"); + + write_skill_interface_at( + skill_dir, + r#" +{ + "interface": { + "icon_small": "icon.png", + "icon_large": "./assets/../logo.svg" + } +} +"#, + ); + + let cfg = make_config(&codex_home).await; + let outcome = load_skills_for_test(&cfg); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "ui-skill".to_string(), + description: "from json".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&skill_path), + scope: SkillScope::User, + }] + ); +} + +#[cfg(unix)] +fn symlink_dir(target: &Path, link: &Path) { + std::os::unix::fs::symlink(target, link).unwrap(); +} + +#[cfg(unix)] +fn symlink_file(target: &Path, link: &Path) { + std::os::unix::fs::symlink(target, link).unwrap(); +} + +#[tokio::test] +#[cfg(unix)] +async fn loads_skills_via_symlinked_subdir_for_user_scope() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let shared = tempfile::tempdir().expect("tempdir"); + + let shared_skill_path = write_skill_at(shared.path(), "demo", "linked-skill", "from link"); + + fs::create_dir_all(codex_home.path().join("skills")).unwrap(); + symlink_dir(shared.path(), &codex_home.path().join("skills/shared")); + + let cfg = make_config(&codex_home).await; + let outcome = load_skills_for_test(&cfg); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "linked-skill".to_string(), + description: "from link".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&shared_skill_path), + scope: SkillScope::User, + }] + ); +} + +#[tokio::test] +#[cfg(unix)] +async fn ignores_symlinked_skill_file_for_user_scope() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let shared = tempfile::tempdir().expect("tempdir"); + + let shared_skill_path = write_skill_at(shared.path(), "demo", "linked-file-skill", "from link"); + + let skill_dir = codex_home.path().join("skills/demo"); + fs::create_dir_all(&skill_dir).unwrap(); + symlink_file(&shared_skill_path, &skill_dir.join(SKILLS_FILENAME)); + + let cfg = make_config(&codex_home).await; + let outcome = load_skills_for_test(&cfg); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!(outcome.skills, Vec::new()); +} + +#[tokio::test] +#[cfg(unix)] +async fn does_not_loop_on_symlink_cycle_for_user_scope() { + let codex_home = tempfile::tempdir().expect("tempdir"); + + // Create a cycle: + // $CODEX_HOME/skills/cycle/loop -> $CODEX_HOME/skills/cycle + let cycle_dir = codex_home.path().join("skills/cycle"); + fs::create_dir_all(&cycle_dir).unwrap(); + symlink_dir(&cycle_dir, &cycle_dir.join("loop")); + + let skill_path = write_skill_at(&cycle_dir, "demo", "cycle-skill", "still loads"); + + let cfg = make_config(&codex_home).await; + let outcome = load_skills_for_test(&cfg); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "cycle-skill".to_string(), + description: "still loads".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&skill_path), + scope: SkillScope::User, + }] + ); +} + +#[test] +#[cfg(unix)] +fn loads_skills_via_symlinked_subdir_for_admin_scope() { + let admin_root = tempfile::tempdir().expect("tempdir"); + let shared = tempfile::tempdir().expect("tempdir"); + + let shared_skill_path = + write_skill_at(shared.path(), "demo", "admin-linked-skill", "from link"); + fs::create_dir_all(admin_root.path()).unwrap(); + symlink_dir(shared.path(), &admin_root.path().join("shared")); + + let outcome = load_skills_from_roots([SkillRoot { + path: admin_root.path().to_path_buf(), + scope: SkillScope::Admin, + }]); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "admin-linked-skill".to_string(), + description: "from link".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&shared_skill_path), + scope: SkillScope::Admin, + }] + ); +} + +#[tokio::test] +#[cfg(unix)] +async fn loads_skills_via_symlinked_subdir_for_repo_scope() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let repo_dir = tempfile::tempdir().expect("tempdir"); + mark_as_git_repo(repo_dir.path()); + let shared = tempfile::tempdir().expect("tempdir"); + + let linked_skill_path = write_skill_at(shared.path(), "demo", "repo-linked-skill", "from link"); + let repo_skills_root = repo_dir + .path() + .join(REPO_ROOT_CONFIG_DIR_NAME) + .join(SKILLS_DIR_NAME); + fs::create_dir_all(&repo_skills_root).unwrap(); + symlink_dir(shared.path(), &repo_skills_root.join("shared")); + + let cfg = make_config_for_cwd(&codex_home, repo_dir.path().to_path_buf()).await; + let outcome = load_skills_for_test(&cfg); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "repo-linked-skill".to_string(), + description: "from link".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&linked_skill_path), + scope: SkillScope::Repo, + }] + ); +} + +#[tokio::test] +#[cfg(unix)] +async fn system_scope_ignores_symlinked_subdir() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let shared = tempfile::tempdir().expect("tempdir"); + + write_skill_at(shared.path(), "demo", "system-linked-skill", "from link"); + + let system_root = codex_home.path().join("skills/.system"); + fs::create_dir_all(&system_root).unwrap(); + symlink_dir(shared.path(), &system_root.join("shared")); + + let outcome = load_skills_from_roots([SkillRoot { + path: system_root, + scope: SkillScope::System, + }]); + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!(outcome.skills.len(), 0); +} + +#[tokio::test] +async fn respects_max_scan_depth_for_user_scope() { + let codex_home = tempfile::tempdir().expect("tempdir"); + + let within_depth_path = write_skill( + &codex_home, + "d0/d1/d2/d3/d4/d5", + "within-depth-skill", + "loads", + ); + let _too_deep_path = write_skill( + &codex_home, + "d0/d1/d2/d3/d4/d5/d6", + "too-deep-skill", + "should not load", + ); + + let skills_root = codex_home.path().join("skills"); + let outcome = load_skills_from_roots([SkillRoot { + path: skills_root, + scope: SkillScope::User, + }]); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "within-depth-skill".to_string(), + description: "loads".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&within_depth_path), + scope: SkillScope::User, + }] + ); +} + +#[tokio::test] +async fn loads_valid_skill() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let skill_path = write_skill(&codex_home, "demo", "demo-skill", "does things\ncarefully"); + let cfg = make_config(&codex_home).await; + + let outcome = load_skills_for_test(&cfg); + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "demo-skill".to_string(), + description: "does things carefully".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&skill_path), + scope: SkillScope::User, + }] + ); +} + +#[tokio::test] +async fn falls_back_to_directory_name_when_skill_name_is_missing() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let skill_path = write_raw_skill_at( + &codex_home.path().join("skills"), + "directory-derived", + "description: fallback name", + ); + let cfg = make_config(&codex_home).await; + + let outcome = load_skills_for_test(&cfg); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "directory-derived".to_string(), + description: "fallback name".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&skill_path), + scope: SkillScope::User, + }] + ); +} + +#[tokio::test] +async fn namespaces_plugin_skills_using_plugin_name() { + let root = tempfile::tempdir().expect("tempdir"); + let plugin_root = root.path().join("plugins/sample"); + let skill_path = write_raw_skill_at( + &plugin_root.join("skills"), + "sample-search", + "description: search sample data", + ); + fs::create_dir_all(plugin_root.join(".codex-plugin")).unwrap(); + fs::write( + plugin_root.join(".codex-plugin/plugin.json"), + r#"{"name":"sample"}"#, + ) + .unwrap(); + + let outcome = load_skills_from_roots([SkillRoot { + path: plugin_root.join("skills"), + scope: SkillScope::User, + }]); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "sample:sample-search".to_string(), + description: "search sample data".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&skill_path), + scope: SkillScope::User, + }] + ); +} + +#[tokio::test] +async fn loads_short_description_from_metadata() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let skill_dir = codex_home.path().join("skills/demo"); + fs::create_dir_all(&skill_dir).unwrap(); + let contents = "---\nname: demo-skill\ndescription: long description\nmetadata:\n short-description: short summary\n---\n\n# Body\n"; + let skill_path = skill_dir.join(SKILLS_FILENAME); + fs::write(&skill_path, contents).unwrap(); + + let cfg = make_config(&codex_home).await; + let outcome = load_skills_for_test(&cfg); + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "demo-skill".to_string(), + description: "long description".to_string(), + short_description: Some("short summary".to_string()), + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&skill_path), + scope: SkillScope::User, + }] + ); +} + +#[tokio::test] +async fn enforces_short_description_length_limits() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let skill_dir = codex_home.path().join("skills/demo"); + fs::create_dir_all(&skill_dir).unwrap(); + let too_long = "x".repeat(MAX_SHORT_DESCRIPTION_LEN + 1); + let contents = format!( + "---\nname: demo-skill\ndescription: long description\nmetadata:\n short-description: {too_long}\n---\n\n# Body\n" + ); + fs::write(skill_dir.join(SKILLS_FILENAME), contents).unwrap(); + + let cfg = make_config(&codex_home).await; + let outcome = load_skills_for_test(&cfg); + assert_eq!(outcome.skills.len(), 0); + assert_eq!(outcome.errors.len(), 1); + assert!( + outcome.errors[0] + .message + .contains("invalid metadata.short-description"), + "expected length error, got: {:?}", + outcome.errors + ); +} + +#[tokio::test] +async fn skips_hidden_and_invalid() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let hidden_dir = codex_home.path().join("skills/.hidden"); + fs::create_dir_all(&hidden_dir).unwrap(); + fs::write( + hidden_dir.join(SKILLS_FILENAME), + "---\nname: hidden\ndescription: hidden\n---\n", + ) + .unwrap(); + + // Invalid because missing closing frontmatter. + let invalid_dir = codex_home.path().join("skills/invalid"); + fs::create_dir_all(&invalid_dir).unwrap(); + fs::write(invalid_dir.join(SKILLS_FILENAME), "---\nname: bad").unwrap(); + + let cfg = make_config(&codex_home).await; + let outcome = load_skills_for_test(&cfg); + assert_eq!(outcome.skills.len(), 0); + assert_eq!(outcome.errors.len(), 1); + assert!( + outcome.errors[0] + .message + .contains("missing YAML frontmatter"), + "expected frontmatter error" + ); +} + +#[tokio::test] +async fn enforces_length_limits() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let max_desc = "\u{1F4A1}".repeat(MAX_DESCRIPTION_LEN); + write_skill(&codex_home, "max-len", "max-len", &max_desc); + let cfg = make_config(&codex_home).await; + + let outcome = load_skills_for_test(&cfg); + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!(outcome.skills.len(), 1); + + let too_long_desc = "\u{1F4A1}".repeat(MAX_DESCRIPTION_LEN + 1); + write_skill(&codex_home, "too-long", "too-long", &too_long_desc); + let outcome = load_skills_for_test(&cfg); + assert_eq!(outcome.skills.len(), 1); + assert_eq!(outcome.errors.len(), 1); + assert!( + outcome.errors[0].message.contains("invalid description"), + "expected length error" + ); +} + +#[tokio::test] +async fn loads_skills_from_repo_root() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let repo_dir = tempfile::tempdir().expect("tempdir"); + mark_as_git_repo(repo_dir.path()); + + let skills_root = repo_dir + .path() + .join(REPO_ROOT_CONFIG_DIR_NAME) + .join(SKILLS_DIR_NAME); + let skill_path = write_skill_at(&skills_root, "repo", "repo-skill", "from repo"); + let cfg = make_config_for_cwd(&codex_home, repo_dir.path().to_path_buf()).await; + + let outcome = load_skills_for_test(&cfg); + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "repo-skill".to_string(), + description: "from repo".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&skill_path), + scope: SkillScope::Repo, + }] + ); +} + +#[tokio::test] +async fn loads_skills_from_agents_dir_without_codex_dir() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let repo_dir = tempfile::tempdir().expect("tempdir"); + mark_as_git_repo(repo_dir.path()); + + let skill_path = write_skill_at( + &repo_dir.path().join(AGENTS_DIR_NAME).join(SKILLS_DIR_NAME), + "agents", + "agents-skill", + "from agents", + ); + let cfg = make_config_for_cwd(&codex_home, repo_dir.path().to_path_buf()).await; + + let outcome = load_skills_for_test(&cfg); + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "agents-skill".to_string(), + description: "from agents".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&skill_path), + scope: SkillScope::Repo, + }] + ); +} + +#[tokio::test] +async fn loads_skills_from_all_codex_dirs_under_project_root() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let repo_dir = tempfile::tempdir().expect("tempdir"); + mark_as_git_repo(repo_dir.path()); + + let nested_dir = repo_dir.path().join("nested/inner"); + fs::create_dir_all(&nested_dir).unwrap(); + + let root_skill_path = write_skill_at( + &repo_dir + .path() + .join(REPO_ROOT_CONFIG_DIR_NAME) + .join(SKILLS_DIR_NAME), + "root", + "root-skill", + "from root", + ); + let nested_skill_path = write_skill_at( + &repo_dir + .path() + .join("nested") + .join(REPO_ROOT_CONFIG_DIR_NAME) + .join(SKILLS_DIR_NAME), + "nested", + "nested-skill", + "from nested", + ); + + let cfg = make_config_for_cwd(&codex_home, nested_dir).await; + + let outcome = load_skills_for_test(&cfg); + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![ + SkillMetadata { + name: "nested-skill".to_string(), + description: "from nested".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&nested_skill_path), + scope: SkillScope::Repo, + }, + SkillMetadata { + name: "root-skill".to_string(), + description: "from root".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&root_skill_path), + scope: SkillScope::Repo, + }, + ] + ); +} + +#[tokio::test] +async fn loads_skills_from_codex_dir_when_not_git_repo() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let work_dir = tempfile::tempdir().expect("tempdir"); + + let skill_path = write_skill_at( + &work_dir + .path() + .join(REPO_ROOT_CONFIG_DIR_NAME) + .join(SKILLS_DIR_NAME), + "local", + "local-skill", + "from cwd", + ); + + let cfg = make_config_for_cwd(&codex_home, work_dir.path().to_path_buf()).await; + + let outcome = load_skills_for_test(&cfg); + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "local-skill".to_string(), + description: "from cwd".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&skill_path), + scope: SkillScope::Repo, + }] + ); +} + +#[tokio::test] +async fn deduplicates_by_path_preferring_first_root() { + let root = tempfile::tempdir().expect("tempdir"); + + let skill_path = write_skill_at(root.path(), "dupe", "dupe-skill", "from repo"); + + let outcome = load_skills_from_roots([ + SkillRoot { + path: root.path().to_path_buf(), + scope: SkillScope::Repo, + }, + SkillRoot { + path: root.path().to_path_buf(), + scope: SkillScope::User, + }, + ]); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "dupe-skill".to_string(), + description: "from repo".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&skill_path), + scope: SkillScope::Repo, + }] + ); +} + +#[tokio::test] +async fn keeps_duplicate_names_from_repo_and_user() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let repo_dir = tempfile::tempdir().expect("tempdir"); + mark_as_git_repo(repo_dir.path()); + + let user_skill_path = write_skill(&codex_home, "user", "dupe-skill", "from user"); + let repo_skill_path = write_skill_at( + &repo_dir + .path() + .join(REPO_ROOT_CONFIG_DIR_NAME) + .join(SKILLS_DIR_NAME), + "repo", + "dupe-skill", + "from repo", + ); + + let cfg = make_config_for_cwd(&codex_home, repo_dir.path().to_path_buf()).await; + + let outcome = load_skills_for_test(&cfg); + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![ + SkillMetadata { + name: "dupe-skill".to_string(), + description: "from repo".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&repo_skill_path), + scope: SkillScope::Repo, + }, + SkillMetadata { + name: "dupe-skill".to_string(), + description: "from user".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&user_skill_path), + scope: SkillScope::User, + }, + ] + ); +} + +#[tokio::test] +async fn keeps_duplicate_names_from_nested_codex_dirs() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let repo_dir = tempfile::tempdir().expect("tempdir"); + mark_as_git_repo(repo_dir.path()); + + let nested_dir = repo_dir.path().join("nested/inner"); + fs::create_dir_all(&nested_dir).unwrap(); + + let root_skill_path = write_skill_at( + &repo_dir + .path() + .join(REPO_ROOT_CONFIG_DIR_NAME) + .join(SKILLS_DIR_NAME), + "root", + "dupe-skill", + "from root", + ); + let nested_skill_path = write_skill_at( + &repo_dir + .path() + .join("nested") + .join(REPO_ROOT_CONFIG_DIR_NAME) + .join(SKILLS_DIR_NAME), + "nested", + "dupe-skill", + "from nested", + ); + + let cfg = make_config_for_cwd(&codex_home, nested_dir).await; + let outcome = load_skills_for_test(&cfg); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + let root_path = canonicalize_path(&root_skill_path).unwrap_or_else(|_| root_skill_path.clone()); + let nested_path = + canonicalize_path(&nested_skill_path).unwrap_or_else(|_| nested_skill_path.clone()); + let (first_path, second_path, first_description, second_description) = + if root_path <= nested_path { + (root_path, nested_path, "from root", "from nested") + } else { + (nested_path, root_path, "from nested", "from root") + }; + assert_eq!( + outcome.skills, + vec![ + SkillMetadata { + name: "dupe-skill".to_string(), + description: first_description.to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: first_path, + scope: SkillScope::Repo, + }, + SkillMetadata { + name: "dupe-skill".to_string(), + description: second_description.to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: second_path, + scope: SkillScope::Repo, + }, + ] + ); +} + +#[tokio::test] +async fn repo_skills_search_does_not_escape_repo_root() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let outer_dir = tempfile::tempdir().expect("tempdir"); + let repo_dir = outer_dir.path().join("repo"); + fs::create_dir_all(&repo_dir).unwrap(); + + let _skill_path = write_skill_at( + &outer_dir + .path() + .join(REPO_ROOT_CONFIG_DIR_NAME) + .join(SKILLS_DIR_NAME), + "outer", + "outer-skill", + "from outer", + ); + mark_as_git_repo(&repo_dir); + + let cfg = make_config_for_cwd(&codex_home, repo_dir).await; + + let outcome = load_skills_for_test(&cfg); + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!(outcome.skills.len(), 0); +} + +#[tokio::test] +async fn loads_skills_when_cwd_is_file_in_repo() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let repo_dir = tempfile::tempdir().expect("tempdir"); + mark_as_git_repo(repo_dir.path()); + + let skill_path = write_skill_at( + &repo_dir + .path() + .join(REPO_ROOT_CONFIG_DIR_NAME) + .join(SKILLS_DIR_NAME), + "repo", + "repo-skill", + "from repo", + ); + let file_path = repo_dir.path().join("some-file.txt"); + fs::write(&file_path, "contents").unwrap(); + + let cfg = make_config_for_cwd(&codex_home, file_path).await; + + let outcome = load_skills_for_test(&cfg); + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "repo-skill".to_string(), + description: "from repo".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&skill_path), + scope: SkillScope::Repo, + }] + ); +} + +#[tokio::test] +async fn non_git_repo_skills_search_does_not_walk_parents() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let outer_dir = tempfile::tempdir().expect("tempdir"); + let nested_dir = outer_dir.path().join("nested/inner"); + fs::create_dir_all(&nested_dir).unwrap(); + + write_skill_at( + &outer_dir + .path() + .join(REPO_ROOT_CONFIG_DIR_NAME) + .join(SKILLS_DIR_NAME), + "outer", + "outer-skill", + "from outer", + ); + + let cfg = make_config_for_cwd(&codex_home, nested_dir).await; + + let outcome = load_skills_for_test(&cfg); + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!(outcome.skills.len(), 0); +} + +#[tokio::test] +async fn loads_skills_from_system_cache_when_present() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let work_dir = tempfile::tempdir().expect("tempdir"); + + let skill_path = write_system_skill(&codex_home, "system", "system-skill", "from system"); + + let cfg = make_config_for_cwd(&codex_home, work_dir.path().to_path_buf()).await; + + let outcome = load_skills_for_test(&cfg); + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + assert_eq!( + outcome.skills, + vec![SkillMetadata { + name: "system-skill".to_string(), + description: "from system".to_string(), + short_description: None, + interface: None, + dependencies: None, + policy: None, + permission_profile: None, + path_to_skills_md: normalized(&skill_path), + scope: SkillScope::System, + }] + ); +} + +#[tokio::test] +async fn skill_roots_include_admin_with_lowest_priority() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let cfg = make_config(&codex_home).await; + + let scopes: Vec = super::skill_roots(&cfg.config_layer_stack, &cfg.cwd, Vec::new()) + .into_iter() + .map(|root| root.scope) + .collect(); + let mut expected = vec![SkillScope::User, SkillScope::System]; + if home_dir().is_some() { + expected.insert(1, SkillScope::User); + } + expected.push(SkillScope::Admin); + assert_eq!(scopes, expected); +} diff --git a/codex-rs/core/src/skills/manager.rs b/codex-rs/core/src/skills/manager.rs index ed7471d653..3a824d25fc 100644 --- a/codex-rs/core/src/skills/manager.rs +++ b/codex-rs/core/src/skills/manager.rs @@ -279,364 +279,5 @@ fn normalize_extra_user_roots(extra_user_roots: &[PathBuf]) -> Vec { } #[cfg(test)] -mod tests { - use super::*; - use crate::config::ConfigBuilder; - use crate::config::ConfigOverrides; - use crate::config_loader::ConfigLayerEntry; - use crate::config_loader::ConfigLayerStack; - use crate::config_loader::ConfigRequirementsToml; - use crate::plugins::PluginsManager; - use pretty_assertions::assert_eq; - use std::fs; - use std::path::PathBuf; - use tempfile::TempDir; - - fn write_user_skill(codex_home: &TempDir, dir: &str, name: &str, description: &str) { - let skill_dir = codex_home.path().join("skills").join(dir); - fs::create_dir_all(&skill_dir).unwrap(); - let content = format!("---\nname: {name}\ndescription: {description}\n---\n\n# Body\n"); - fs::write(skill_dir.join("SKILL.md"), content).unwrap(); - } - - #[test] - fn new_with_disabled_bundled_skills_removes_stale_cached_system_skills() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let stale_system_skill_dir = codex_home.path().join("skills/.system/stale-skill"); - fs::create_dir_all(&stale_system_skill_dir).expect("create stale system skill dir"); - fs::write(stale_system_skill_dir.join("SKILL.md"), "# stale\n") - .expect("write stale system skill"); - - let plugins_manager = Arc::new(PluginsManager::new(codex_home.path().to_path_buf())); - let _skills_manager = - SkillsManager::new(codex_home.path().to_path_buf(), plugins_manager, false); - - assert!( - !codex_home.path().join("skills/.system").exists(), - "expected disabling system skills to remove stale cached bundled skills" - ); - } - - #[tokio::test] - async fn skills_for_config_seeds_cache_by_cwd() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let cwd = tempfile::tempdir().expect("tempdir"); - - let cfg = ConfigBuilder::default() - .codex_home(codex_home.path().to_path_buf()) - .harness_overrides(ConfigOverrides { - cwd: Some(cwd.path().to_path_buf()), - ..Default::default() - }) - .build() - .await - .expect("defaults for test should always succeed"); - - let plugins_manager = Arc::new(PluginsManager::new(codex_home.path().to_path_buf())); - let skills_manager = - SkillsManager::new(codex_home.path().to_path_buf(), plugins_manager, true); - - write_user_skill(&codex_home, "a", "skill-a", "from a"); - let outcome1 = skills_manager.skills_for_config(&cfg); - assert!( - outcome1.skills.iter().any(|s| s.name == "skill-a"), - "expected skill-a to be discovered" - ); - - // Write a new skill after the first call; the second call should hit the cache and not - // reflect the new file. - write_user_skill(&codex_home, "b", "skill-b", "from b"); - let outcome2 = skills_manager.skills_for_config(&cfg); - assert_eq!(outcome2.errors, outcome1.errors); - assert_eq!(outcome2.skills, outcome1.skills); - } - - #[tokio::test] - async fn skills_for_cwd_reuses_cached_entry_even_when_entry_has_extra_roots() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let cwd = tempfile::tempdir().expect("tempdir"); - let extra_root = tempfile::tempdir().expect("tempdir"); - - let config = ConfigBuilder::default() - .codex_home(codex_home.path().to_path_buf()) - .harness_overrides(ConfigOverrides { - cwd: Some(cwd.path().to_path_buf()), - ..Default::default() - }) - .build() - .await - .expect("defaults for test should always succeed"); - - let plugins_manager = Arc::new(PluginsManager::new(codex_home.path().to_path_buf())); - let skills_manager = - SkillsManager::new(codex_home.path().to_path_buf(), plugins_manager, true); - let _ = skills_manager.skills_for_config(&config); - - write_user_skill(&extra_root, "x", "extra-skill", "from extra root"); - let extra_root_path = extra_root.path().to_path_buf(); - let outcome_with_extra = skills_manager - .skills_for_cwd_with_extra_user_roots( - cwd.path(), - true, - std::slice::from_ref(&extra_root_path), - ) - .await; - assert!( - outcome_with_extra - .skills - .iter() - .any(|skill| skill.name == "extra-skill") - ); - assert!( - outcome_with_extra - .skills - .iter() - .any(|skill| skill.scope == SkillScope::System) - ); - - // The cwd-only API returns the current cached entry for this cwd, even when that entry - // was produced with extra roots. - let outcome_without_extra = skills_manager.skills_for_cwd(cwd.path(), false).await; - assert_eq!(outcome_without_extra.skills, outcome_with_extra.skills); - assert_eq!(outcome_without_extra.errors, outcome_with_extra.errors); - } - - #[tokio::test] - async fn skills_for_config_excludes_bundled_skills_when_disabled_in_config() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let cwd = tempfile::tempdir().expect("tempdir"); - let bundled_skill_dir = codex_home.path().join("skills/.system/bundled-skill"); - fs::create_dir_all(&bundled_skill_dir).expect("create bundled skill dir"); - fs::write( - bundled_skill_dir.join("SKILL.md"), - "---\nname: bundled-skill\ndescription: from bundled root\n---\n\n# Body\n", - ) - .expect("write bundled skill"); - - fs::write( - codex_home.path().join(crate::config::CONFIG_TOML_FILE), - "[skills.bundled]\nenabled = false\n", - ) - .expect("write config"); - - let config = ConfigBuilder::default() - .codex_home(codex_home.path().to_path_buf()) - .harness_overrides(ConfigOverrides { - cwd: Some(cwd.path().to_path_buf()), - ..Default::default() - }) - .build() - .await - .expect("load config"); - - let plugins_manager = Arc::new(PluginsManager::new(codex_home.path().to_path_buf())); - let skills_manager = SkillsManager::new( - codex_home.path().to_path_buf(), - plugins_manager, - config.bundled_skills_enabled(), - ); - - // Recreate the cached bundled skill after startup cleanup so this assertion exercises - // root selection rather than relying on directory removal succeeding. - fs::create_dir_all(&bundled_skill_dir).expect("recreate bundled skill dir"); - fs::write( - bundled_skill_dir.join("SKILL.md"), - "---\nname: bundled-skill\ndescription: from bundled root\n---\n\n# Body\n", - ) - .expect("rewrite bundled skill"); - - let outcome = skills_manager.skills_for_config(&config); - assert!( - outcome - .skills - .iter() - .all(|skill| skill.name != "bundled-skill") - ); - assert!( - outcome - .skills - .iter() - .all(|skill| skill.scope != SkillScope::System) - ); - } - - #[tokio::test] - async fn skills_for_cwd_with_extra_roots_only_refreshes_on_force_reload() { - let codex_home = tempfile::tempdir().expect("tempdir"); - let cwd = tempfile::tempdir().expect("tempdir"); - let extra_root_a = tempfile::tempdir().expect("tempdir"); - let extra_root_b = tempfile::tempdir().expect("tempdir"); - - let config = ConfigBuilder::default() - .codex_home(codex_home.path().to_path_buf()) - .harness_overrides(ConfigOverrides { - cwd: Some(cwd.path().to_path_buf()), - ..Default::default() - }) - .build() - .await - .expect("defaults for test should always succeed"); - - let plugins_manager = Arc::new(PluginsManager::new(codex_home.path().to_path_buf())); - let skills_manager = - SkillsManager::new(codex_home.path().to_path_buf(), plugins_manager, true); - let _ = skills_manager.skills_for_config(&config); - - write_user_skill(&extra_root_a, "x", "extra-skill-a", "from extra root a"); - write_user_skill(&extra_root_b, "x", "extra-skill-b", "from extra root b"); - - let extra_root_a_path = extra_root_a.path().to_path_buf(); - let outcome_a = skills_manager - .skills_for_cwd_with_extra_user_roots( - cwd.path(), - true, - std::slice::from_ref(&extra_root_a_path), - ) - .await; - assert!( - outcome_a - .skills - .iter() - .any(|skill| skill.name == "extra-skill-a") - ); - assert!( - outcome_a - .skills - .iter() - .all(|skill| skill.name != "extra-skill-b") - ); - - let extra_root_b_path = extra_root_b.path().to_path_buf(); - let outcome_b = skills_manager - .skills_for_cwd_with_extra_user_roots( - cwd.path(), - false, - std::slice::from_ref(&extra_root_b_path), - ) - .await; - assert!( - outcome_b - .skills - .iter() - .any(|skill| skill.name == "extra-skill-a") - ); - assert!( - outcome_b - .skills - .iter() - .all(|skill| skill.name != "extra-skill-b") - ); - - let outcome_reloaded = skills_manager - .skills_for_cwd_with_extra_user_roots( - cwd.path(), - true, - std::slice::from_ref(&extra_root_b_path), - ) - .await; - assert!( - outcome_reloaded - .skills - .iter() - .any(|skill| skill.name == "extra-skill-b") - ); - assert!( - outcome_reloaded - .skills - .iter() - .all(|skill| skill.name != "extra-skill-a") - ); - } - - #[test] - fn normalize_extra_user_roots_is_stable_for_equivalent_inputs() { - let a = PathBuf::from("/tmp/a"); - let b = PathBuf::from("/tmp/b"); - - let first = normalize_extra_user_roots(&[a.clone(), b.clone(), a.clone()]); - let second = normalize_extra_user_roots(&[b, a]); - - assert_eq!(first, second); - } - - #[cfg_attr(windows, ignore)] - #[test] - fn disabled_paths_from_stack_allows_session_flags_to_override_user_layer() { - let tempdir = tempfile::tempdir().expect("tempdir"); - let skill_path = tempdir.path().join("skills").join("demo").join("SKILL.md"); - let user_file = AbsolutePathBuf::try_from(tempdir.path().join("config.toml")) - .expect("user config path should be absolute"); - let user_layer = ConfigLayerEntry::new( - ConfigLayerSource::User { file: user_file }, - toml::from_str(&format!( - r#"[[skills.config]] -path = "{}" -enabled = false -"#, - skill_path.display() - )) - .expect("user layer toml"), - ); - let session_layer = ConfigLayerEntry::new( - ConfigLayerSource::SessionFlags, - toml::from_str(&format!( - r#"[[skills.config]] -path = "{}" -enabled = true -"#, - skill_path.display() - )) - .expect("session layer toml"), - ); - let stack = ConfigLayerStack::new( - vec![user_layer, session_layer], - Default::default(), - ConfigRequirementsToml::default(), - ) - .expect("valid config layer stack"); - - assert_eq!(disabled_paths_from_stack(&stack), HashSet::new()); - } - - #[cfg_attr(windows, ignore)] - #[test] - fn disabled_paths_from_stack_allows_session_flags_to_disable_user_enabled_skill() { - let tempdir = tempfile::tempdir().expect("tempdir"); - let skill_path = tempdir.path().join("skills").join("demo").join("SKILL.md"); - let user_file = AbsolutePathBuf::try_from(tempdir.path().join("config.toml")) - .expect("user config path should be absolute"); - let user_layer = ConfigLayerEntry::new( - ConfigLayerSource::User { file: user_file }, - toml::from_str(&format!( - r#"[[skills.config]] -path = "{}" -enabled = true -"#, - skill_path.display() - )) - .expect("user layer toml"), - ); - let session_layer = ConfigLayerEntry::new( - ConfigLayerSource::SessionFlags, - toml::from_str(&format!( - r#"[[skills.config]] -path = "{}" -enabled = false -"#, - skill_path.display() - )) - .expect("session layer toml"), - ); - let stack = ConfigLayerStack::new( - vec![user_layer, session_layer], - Default::default(), - ConfigRequirementsToml::default(), - ) - .expect("valid config layer stack"); - - assert_eq!( - disabled_paths_from_stack(&stack), - HashSet::from([skill_path]) - ); - } -} +#[path = "manager_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/skills/manager_tests.rs b/codex-rs/core/src/skills/manager_tests.rs new file mode 100644 index 0000000000..f9d6dc2c5f --- /dev/null +++ b/codex-rs/core/src/skills/manager_tests.rs @@ -0,0 +1,356 @@ +use super::*; +use crate::config::ConfigBuilder; +use crate::config::ConfigOverrides; +use crate::config_loader::ConfigLayerEntry; +use crate::config_loader::ConfigLayerStack; +use crate::config_loader::ConfigRequirementsToml; +use crate::plugins::PluginsManager; +use pretty_assertions::assert_eq; +use std::fs; +use std::path::PathBuf; +use tempfile::TempDir; + +fn write_user_skill(codex_home: &TempDir, dir: &str, name: &str, description: &str) { + let skill_dir = codex_home.path().join("skills").join(dir); + fs::create_dir_all(&skill_dir).unwrap(); + let content = format!("---\nname: {name}\ndescription: {description}\n---\n\n# Body\n"); + fs::write(skill_dir.join("SKILL.md"), content).unwrap(); +} + +#[test] +fn new_with_disabled_bundled_skills_removes_stale_cached_system_skills() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let stale_system_skill_dir = codex_home.path().join("skills/.system/stale-skill"); + fs::create_dir_all(&stale_system_skill_dir).expect("create stale system skill dir"); + fs::write(stale_system_skill_dir.join("SKILL.md"), "# stale\n") + .expect("write stale system skill"); + + let plugins_manager = Arc::new(PluginsManager::new(codex_home.path().to_path_buf())); + let _skills_manager = + SkillsManager::new(codex_home.path().to_path_buf(), plugins_manager, false); + + assert!( + !codex_home.path().join("skills/.system").exists(), + "expected disabling system skills to remove stale cached bundled skills" + ); +} + +#[tokio::test] +async fn skills_for_config_seeds_cache_by_cwd() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let cwd = tempfile::tempdir().expect("tempdir"); + + let cfg = ConfigBuilder::default() + .codex_home(codex_home.path().to_path_buf()) + .harness_overrides(ConfigOverrides { + cwd: Some(cwd.path().to_path_buf()), + ..Default::default() + }) + .build() + .await + .expect("defaults for test should always succeed"); + + let plugins_manager = Arc::new(PluginsManager::new(codex_home.path().to_path_buf())); + let skills_manager = SkillsManager::new(codex_home.path().to_path_buf(), plugins_manager, true); + + write_user_skill(&codex_home, "a", "skill-a", "from a"); + let outcome1 = skills_manager.skills_for_config(&cfg); + assert!( + outcome1.skills.iter().any(|s| s.name == "skill-a"), + "expected skill-a to be discovered" + ); + + // Write a new skill after the first call; the second call should hit the cache and not + // reflect the new file. + write_user_skill(&codex_home, "b", "skill-b", "from b"); + let outcome2 = skills_manager.skills_for_config(&cfg); + assert_eq!(outcome2.errors, outcome1.errors); + assert_eq!(outcome2.skills, outcome1.skills); +} + +#[tokio::test] +async fn skills_for_cwd_reuses_cached_entry_even_when_entry_has_extra_roots() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let cwd = tempfile::tempdir().expect("tempdir"); + let extra_root = tempfile::tempdir().expect("tempdir"); + + let config = ConfigBuilder::default() + .codex_home(codex_home.path().to_path_buf()) + .harness_overrides(ConfigOverrides { + cwd: Some(cwd.path().to_path_buf()), + ..Default::default() + }) + .build() + .await + .expect("defaults for test should always succeed"); + + let plugins_manager = Arc::new(PluginsManager::new(codex_home.path().to_path_buf())); + let skills_manager = SkillsManager::new(codex_home.path().to_path_buf(), plugins_manager, true); + let _ = skills_manager.skills_for_config(&config); + + write_user_skill(&extra_root, "x", "extra-skill", "from extra root"); + let extra_root_path = extra_root.path().to_path_buf(); + let outcome_with_extra = skills_manager + .skills_for_cwd_with_extra_user_roots( + cwd.path(), + true, + std::slice::from_ref(&extra_root_path), + ) + .await; + assert!( + outcome_with_extra + .skills + .iter() + .any(|skill| skill.name == "extra-skill") + ); + assert!( + outcome_with_extra + .skills + .iter() + .any(|skill| skill.scope == SkillScope::System) + ); + + // The cwd-only API returns the current cached entry for this cwd, even when that entry + // was produced with extra roots. + let outcome_without_extra = skills_manager.skills_for_cwd(cwd.path(), false).await; + assert_eq!(outcome_without_extra.skills, outcome_with_extra.skills); + assert_eq!(outcome_without_extra.errors, outcome_with_extra.errors); +} + +#[tokio::test] +async fn skills_for_config_excludes_bundled_skills_when_disabled_in_config() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let cwd = tempfile::tempdir().expect("tempdir"); + let bundled_skill_dir = codex_home.path().join("skills/.system/bundled-skill"); + fs::create_dir_all(&bundled_skill_dir).expect("create bundled skill dir"); + fs::write( + bundled_skill_dir.join("SKILL.md"), + "---\nname: bundled-skill\ndescription: from bundled root\n---\n\n# Body\n", + ) + .expect("write bundled skill"); + + fs::write( + codex_home.path().join(crate::config::CONFIG_TOML_FILE), + "[skills.bundled]\nenabled = false\n", + ) + .expect("write config"); + + let config = ConfigBuilder::default() + .codex_home(codex_home.path().to_path_buf()) + .harness_overrides(ConfigOverrides { + cwd: Some(cwd.path().to_path_buf()), + ..Default::default() + }) + .build() + .await + .expect("load config"); + + let plugins_manager = Arc::new(PluginsManager::new(codex_home.path().to_path_buf())); + let skills_manager = SkillsManager::new( + codex_home.path().to_path_buf(), + plugins_manager, + config.bundled_skills_enabled(), + ); + + // Recreate the cached bundled skill after startup cleanup so this assertion exercises + // root selection rather than relying on directory removal succeeding. + fs::create_dir_all(&bundled_skill_dir).expect("recreate bundled skill dir"); + fs::write( + bundled_skill_dir.join("SKILL.md"), + "---\nname: bundled-skill\ndescription: from bundled root\n---\n\n# Body\n", + ) + .expect("rewrite bundled skill"); + + let outcome = skills_manager.skills_for_config(&config); + assert!( + outcome + .skills + .iter() + .all(|skill| skill.name != "bundled-skill") + ); + assert!( + outcome + .skills + .iter() + .all(|skill| skill.scope != SkillScope::System) + ); +} + +#[tokio::test] +async fn skills_for_cwd_with_extra_roots_only_refreshes_on_force_reload() { + let codex_home = tempfile::tempdir().expect("tempdir"); + let cwd = tempfile::tempdir().expect("tempdir"); + let extra_root_a = tempfile::tempdir().expect("tempdir"); + let extra_root_b = tempfile::tempdir().expect("tempdir"); + + let config = ConfigBuilder::default() + .codex_home(codex_home.path().to_path_buf()) + .harness_overrides(ConfigOverrides { + cwd: Some(cwd.path().to_path_buf()), + ..Default::default() + }) + .build() + .await + .expect("defaults for test should always succeed"); + + let plugins_manager = Arc::new(PluginsManager::new(codex_home.path().to_path_buf())); + let skills_manager = SkillsManager::new(codex_home.path().to_path_buf(), plugins_manager, true); + let _ = skills_manager.skills_for_config(&config); + + write_user_skill(&extra_root_a, "x", "extra-skill-a", "from extra root a"); + write_user_skill(&extra_root_b, "x", "extra-skill-b", "from extra root b"); + + let extra_root_a_path = extra_root_a.path().to_path_buf(); + let outcome_a = skills_manager + .skills_for_cwd_with_extra_user_roots( + cwd.path(), + true, + std::slice::from_ref(&extra_root_a_path), + ) + .await; + assert!( + outcome_a + .skills + .iter() + .any(|skill| skill.name == "extra-skill-a") + ); + assert!( + outcome_a + .skills + .iter() + .all(|skill| skill.name != "extra-skill-b") + ); + + let extra_root_b_path = extra_root_b.path().to_path_buf(); + let outcome_b = skills_manager + .skills_for_cwd_with_extra_user_roots( + cwd.path(), + false, + std::slice::from_ref(&extra_root_b_path), + ) + .await; + assert!( + outcome_b + .skills + .iter() + .any(|skill| skill.name == "extra-skill-a") + ); + assert!( + outcome_b + .skills + .iter() + .all(|skill| skill.name != "extra-skill-b") + ); + + let outcome_reloaded = skills_manager + .skills_for_cwd_with_extra_user_roots( + cwd.path(), + true, + std::slice::from_ref(&extra_root_b_path), + ) + .await; + assert!( + outcome_reloaded + .skills + .iter() + .any(|skill| skill.name == "extra-skill-b") + ); + assert!( + outcome_reloaded + .skills + .iter() + .all(|skill| skill.name != "extra-skill-a") + ); +} + +#[test] +fn normalize_extra_user_roots_is_stable_for_equivalent_inputs() { + let a = PathBuf::from("/tmp/a"); + let b = PathBuf::from("/tmp/b"); + + let first = normalize_extra_user_roots(&[a.clone(), b.clone(), a.clone()]); + let second = normalize_extra_user_roots(&[b, a]); + + assert_eq!(first, second); +} + +#[cfg_attr(windows, ignore)] +#[test] +fn disabled_paths_from_stack_allows_session_flags_to_override_user_layer() { + let tempdir = tempfile::tempdir().expect("tempdir"); + let skill_path = tempdir.path().join("skills").join("demo").join("SKILL.md"); + let user_file = AbsolutePathBuf::try_from(tempdir.path().join("config.toml")) + .expect("user config path should be absolute"); + let user_layer = ConfigLayerEntry::new( + ConfigLayerSource::User { file: user_file }, + toml::from_str(&format!( + r#"[[skills.config]] +path = "{}" +enabled = false +"#, + skill_path.display() + )) + .expect("user layer toml"), + ); + let session_layer = ConfigLayerEntry::new( + ConfigLayerSource::SessionFlags, + toml::from_str(&format!( + r#"[[skills.config]] +path = "{}" +enabled = true +"#, + skill_path.display() + )) + .expect("session layer toml"), + ); + let stack = ConfigLayerStack::new( + vec![user_layer, session_layer], + Default::default(), + ConfigRequirementsToml::default(), + ) + .expect("valid config layer stack"); + + assert_eq!(disabled_paths_from_stack(&stack), HashSet::new()); +} + +#[cfg_attr(windows, ignore)] +#[test] +fn disabled_paths_from_stack_allows_session_flags_to_disable_user_enabled_skill() { + let tempdir = tempfile::tempdir().expect("tempdir"); + let skill_path = tempdir.path().join("skills").join("demo").join("SKILL.md"); + let user_file = AbsolutePathBuf::try_from(tempdir.path().join("config.toml")) + .expect("user config path should be absolute"); + let user_layer = ConfigLayerEntry::new( + ConfigLayerSource::User { file: user_file }, + toml::from_str(&format!( + r#"[[skills.config]] +path = "{}" +enabled = true +"#, + skill_path.display() + )) + .expect("user layer toml"), + ); + let session_layer = ConfigLayerEntry::new( + ConfigLayerSource::SessionFlags, + toml::from_str(&format!( + r#"[[skills.config]] +path = "{}" +enabled = false +"#, + skill_path.display() + )) + .expect("session layer toml"), + ); + let stack = ConfigLayerStack::new( + vec![user_layer, session_layer], + Default::default(), + ConfigRequirementsToml::default(), + ) + .expect("valid config layer stack"); + + assert_eq!( + disabled_paths_from_stack(&stack), + HashSet::from([skill_path]) + ); +} diff --git a/codex-rs/core/src/state/session.rs b/codex-rs/core/src/state/session.rs index a40405d1d1..40faa4b856 100644 --- a/codex-rs/core/src/state/session.rs +++ b/codex-rs/core/src/state/session.rs @@ -237,160 +237,5 @@ fn merge_rate_limit_fields( } #[cfg(test)] -mod tests { - use super::*; - use crate::codex::make_session_configuration_for_tests; - use crate::protocol::RateLimitWindow; - use pretty_assertions::assert_eq; - - #[tokio::test] - // Verifies connector merging deduplicates repeated IDs. - async fn merge_connector_selection_deduplicates_entries() { - let session_configuration = make_session_configuration_for_tests().await; - let mut state = SessionState::new(session_configuration); - let merged = state.merge_connector_selection([ - "calendar".to_string(), - "calendar".to_string(), - "drive".to_string(), - ]); - - assert_eq!( - merged, - HashSet::from(["calendar".to_string(), "drive".to_string()]) - ); - } - - #[tokio::test] - // Verifies clearing connector selection removes all saved IDs. - async fn clear_connector_selection_removes_entries() { - let session_configuration = make_session_configuration_for_tests().await; - let mut state = SessionState::new(session_configuration); - state.merge_connector_selection(["calendar".to_string()]); - - state.clear_connector_selection(); - - assert_eq!(state.get_connector_selection(), HashSet::new()); - } - - #[tokio::test] - async fn set_rate_limits_defaults_limit_id_to_codex_when_missing() { - let session_configuration = make_session_configuration_for_tests().await; - let mut state = SessionState::new(session_configuration); - - state.set_rate_limits(RateLimitSnapshot { - limit_id: None, - limit_name: None, - primary: Some(RateLimitWindow { - used_percent: 12.0, - window_minutes: Some(60), - resets_at: Some(100), - }), - secondary: None, - credits: None, - plan_type: None, - }); - - assert_eq!( - state - .latest_rate_limits - .as_ref() - .and_then(|v| v.limit_id.clone()), - Some("codex".to_string()) - ); - } - - #[tokio::test] - async fn set_rate_limits_defaults_to_codex_when_limit_id_missing_after_other_bucket() { - let session_configuration = make_session_configuration_for_tests().await; - let mut state = SessionState::new(session_configuration); - - state.set_rate_limits(RateLimitSnapshot { - limit_id: Some("codex_other".to_string()), - limit_name: Some("codex_other".to_string()), - primary: Some(RateLimitWindow { - used_percent: 20.0, - window_minutes: Some(60), - resets_at: Some(200), - }), - secondary: None, - credits: None, - plan_type: None, - }); - state.set_rate_limits(RateLimitSnapshot { - limit_id: None, - limit_name: None, - primary: Some(RateLimitWindow { - used_percent: 30.0, - window_minutes: Some(60), - resets_at: Some(300), - }), - secondary: None, - credits: None, - plan_type: None, - }); - - assert_eq!( - state - .latest_rate_limits - .as_ref() - .and_then(|v| v.limit_id.clone()), - Some("codex".to_string()) - ); - } - - #[tokio::test] - async fn set_rate_limits_carries_credits_and_plan_type_from_codex_to_codex_other() { - let session_configuration = make_session_configuration_for_tests().await; - let mut state = SessionState::new(session_configuration); - - state.set_rate_limits(RateLimitSnapshot { - limit_id: Some("codex".to_string()), - limit_name: Some("codex".to_string()), - primary: Some(RateLimitWindow { - used_percent: 10.0, - window_minutes: Some(60), - resets_at: Some(100), - }), - secondary: None, - credits: Some(crate::protocol::CreditsSnapshot { - has_credits: true, - unlimited: false, - balance: Some("50".to_string()), - }), - plan_type: Some(codex_protocol::account::PlanType::Plus), - }); - - state.set_rate_limits(RateLimitSnapshot { - limit_id: Some("codex_other".to_string()), - limit_name: None, - primary: Some(RateLimitWindow { - used_percent: 30.0, - window_minutes: Some(120), - resets_at: Some(200), - }), - secondary: None, - credits: None, - plan_type: None, - }); - - assert_eq!( - state.latest_rate_limits, - Some(RateLimitSnapshot { - limit_id: Some("codex_other".to_string()), - limit_name: None, - primary: Some(RateLimitWindow { - used_percent: 30.0, - window_minutes: Some(120), - resets_at: Some(200), - }), - secondary: None, - credits: Some(crate::protocol::CreditsSnapshot { - has_credits: true, - unlimited: false, - balance: Some("50".to_string()), - }), - plan_type: Some(codex_protocol::account::PlanType::Plus), - }) - ); - } -} +#[path = "session_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/state/session_tests.rs b/codex-rs/core/src/state/session_tests.rs new file mode 100644 index 0000000000..2b7c276d7e --- /dev/null +++ b/codex-rs/core/src/state/session_tests.rs @@ -0,0 +1,155 @@ +use super::*; +use crate::codex::make_session_configuration_for_tests; +use crate::protocol::RateLimitWindow; +use pretty_assertions::assert_eq; + +#[tokio::test] +// Verifies connector merging deduplicates repeated IDs. +async fn merge_connector_selection_deduplicates_entries() { + let session_configuration = make_session_configuration_for_tests().await; + let mut state = SessionState::new(session_configuration); + let merged = state.merge_connector_selection([ + "calendar".to_string(), + "calendar".to_string(), + "drive".to_string(), + ]); + + assert_eq!( + merged, + HashSet::from(["calendar".to_string(), "drive".to_string()]) + ); +} + +#[tokio::test] +// Verifies clearing connector selection removes all saved IDs. +async fn clear_connector_selection_removes_entries() { + let session_configuration = make_session_configuration_for_tests().await; + let mut state = SessionState::new(session_configuration); + state.merge_connector_selection(["calendar".to_string()]); + + state.clear_connector_selection(); + + assert_eq!(state.get_connector_selection(), HashSet::new()); +} + +#[tokio::test] +async fn set_rate_limits_defaults_limit_id_to_codex_when_missing() { + let session_configuration = make_session_configuration_for_tests().await; + let mut state = SessionState::new(session_configuration); + + state.set_rate_limits(RateLimitSnapshot { + limit_id: None, + limit_name: None, + primary: Some(RateLimitWindow { + used_percent: 12.0, + window_minutes: Some(60), + resets_at: Some(100), + }), + secondary: None, + credits: None, + plan_type: None, + }); + + assert_eq!( + state + .latest_rate_limits + .as_ref() + .and_then(|v| v.limit_id.clone()), + Some("codex".to_string()) + ); +} + +#[tokio::test] +async fn set_rate_limits_defaults_to_codex_when_limit_id_missing_after_other_bucket() { + let session_configuration = make_session_configuration_for_tests().await; + let mut state = SessionState::new(session_configuration); + + state.set_rate_limits(RateLimitSnapshot { + limit_id: Some("codex_other".to_string()), + limit_name: Some("codex_other".to_string()), + primary: Some(RateLimitWindow { + used_percent: 20.0, + window_minutes: Some(60), + resets_at: Some(200), + }), + secondary: None, + credits: None, + plan_type: None, + }); + state.set_rate_limits(RateLimitSnapshot { + limit_id: None, + limit_name: None, + primary: Some(RateLimitWindow { + used_percent: 30.0, + window_minutes: Some(60), + resets_at: Some(300), + }), + secondary: None, + credits: None, + plan_type: None, + }); + + assert_eq!( + state + .latest_rate_limits + .as_ref() + .and_then(|v| v.limit_id.clone()), + Some("codex".to_string()) + ); +} + +#[tokio::test] +async fn set_rate_limits_carries_credits_and_plan_type_from_codex_to_codex_other() { + let session_configuration = make_session_configuration_for_tests().await; + let mut state = SessionState::new(session_configuration); + + state.set_rate_limits(RateLimitSnapshot { + limit_id: Some("codex".to_string()), + limit_name: Some("codex".to_string()), + primary: Some(RateLimitWindow { + used_percent: 10.0, + window_minutes: Some(60), + resets_at: Some(100), + }), + secondary: None, + credits: Some(crate::protocol::CreditsSnapshot { + has_credits: true, + unlimited: false, + balance: Some("50".to_string()), + }), + plan_type: Some(codex_protocol::account::PlanType::Plus), + }); + + state.set_rate_limits(RateLimitSnapshot { + limit_id: Some("codex_other".to_string()), + limit_name: None, + primary: Some(RateLimitWindow { + used_percent: 30.0, + window_minutes: Some(120), + resets_at: Some(200), + }), + secondary: None, + credits: None, + plan_type: None, + }); + + assert_eq!( + state.latest_rate_limits, + Some(RateLimitSnapshot { + limit_id: Some("codex_other".to_string()), + limit_name: None, + primary: Some(RateLimitWindow { + used_percent: 30.0, + window_minutes: Some(120), + resets_at: Some(200), + }), + secondary: None, + credits: Some(crate::protocol::CreditsSnapshot { + has_credits: true, + unlimited: false, + balance: Some("50".to_string()), + }), + plan_type: Some(codex_protocol::account::PlanType::Plus), + }) + ); +} diff --git a/codex-rs/core/src/state_db.rs b/codex-rs/core/src/state_db.rs index b53b748f3f..536b335035 100644 --- a/codex-rs/core/src/state_db.rs +++ b/codex-rs/core/src/state_db.rs @@ -543,26 +543,5 @@ pub async fn touch_thread_updated_at( } #[cfg(test)] -mod tests { - use super::*; - use crate::rollout::list::parse_cursor; - use pretty_assertions::assert_eq; - - #[test] - fn cursor_to_anchor_normalizes_timestamp_format() { - let uuid = Uuid::new_v4(); - let ts_str = "2026-01-27T12-34-56"; - let token = format!("{ts_str}|{uuid}"); - let cursor = parse_cursor(token.as_str()).expect("cursor should parse"); - let anchor = cursor_to_anchor(Some(&cursor)).expect("anchor should parse"); - - let naive = - NaiveDateTime::parse_from_str(ts_str, "%Y-%m-%dT%H-%M-%S").expect("ts should parse"); - let expected_ts = DateTime::::from_naive_utc_and_offset(naive, Utc) - .with_nanosecond(0) - .expect("nanosecond"); - - assert_eq!(anchor.id, uuid); - assert_eq!(anchor.ts, expected_ts); - } -} +#[path = "state_db_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/state_db_tests.rs b/codex-rs/core/src/state_db_tests.rs new file mode 100644 index 0000000000..adf08197d6 --- /dev/null +++ b/codex-rs/core/src/state_db_tests.rs @@ -0,0 +1,21 @@ +use super::*; +use crate::rollout::list::parse_cursor; +use pretty_assertions::assert_eq; + +#[test] +fn cursor_to_anchor_normalizes_timestamp_format() { + let uuid = Uuid::new_v4(); + let ts_str = "2026-01-27T12-34-56"; + let token = format!("{ts_str}|{uuid}"); + let cursor = parse_cursor(token.as_str()).expect("cursor should parse"); + let anchor = cursor_to_anchor(Some(&cursor)).expect("anchor should parse"); + + let naive = + NaiveDateTime::parse_from_str(ts_str, "%Y-%m-%dT%H-%M-%S").expect("ts should parse"); + let expected_ts = DateTime::::from_naive_utc_and_offset(naive, Utc) + .with_nanosecond(0) + .expect("nanosecond"); + + assert_eq!(anchor.id, uuid); + assert_eq!(anchor.ts, expected_ts); +} diff --git a/codex-rs/core/src/stream_events_utils.rs b/codex-rs/core/src/stream_events_utils.rs index 22351e8138..e5231314e5 100644 --- a/codex-rs/core/src/stream_events_utils.rs +++ b/codex-rs/core/src/stream_events_utils.rs @@ -399,143 +399,5 @@ pub(crate) fn response_input_to_response_item(input: &ResponseInputItem) -> Opti } #[cfg(test)] -mod tests { - use super::default_image_generation_output_dir; - use super::handle_non_tool_response_item; - use super::last_assistant_message_from_item; - use super::save_image_generation_result; - use crate::codex::make_session_and_context; - use crate::error::CodexErr; - use codex_protocol::items::TurnItem; - use codex_protocol::models::ContentItem; - use codex_protocol::models::ResponseItem; - use pretty_assertions::assert_eq; - - fn assistant_output_text(text: &str) -> ResponseItem { - ResponseItem::Message { - id: Some("msg-1".to_string()), - role: "assistant".to_string(), - content: vec![ContentItem::OutputText { - text: text.to_string(), - }], - end_turn: Some(true), - phase: None, - } - } - - #[tokio::test] - async fn handle_non_tool_response_item_strips_citations_from_assistant_message() { - let (session, turn_context) = make_session_and_context().await; - let item = assistant_output_text("hellodoc1 world"); - - let turn_item = handle_non_tool_response_item(&session, &turn_context, &item, false) - .await - .expect("assistant message should parse"); - - let TurnItem::AgentMessage(agent_message) = turn_item else { - panic!("expected agent message"); - }; - let text = agent_message - .content - .iter() - .map(|entry| match entry { - codex_protocol::items::AgentMessageContent::Text { text } => text.as_str(), - }) - .collect::(); - assert_eq!(text, "hello world"); - } - - #[test] - fn last_assistant_message_from_item_strips_citations_and_plan_blocks() { - let item = assistant_output_text( - "beforedoc1\n\n- x\n\nafter", - ); - - let message = last_assistant_message_from_item(&item, true) - .expect("assistant text should remain after stripping"); - - assert_eq!(message, "before\nafter"); - } - - #[test] - fn last_assistant_message_from_item_returns_none_for_citation_only_message() { - let item = assistant_output_text("doc1"); - - assert_eq!(last_assistant_message_from_item(&item, false), None); - } - - #[test] - fn last_assistant_message_from_item_returns_none_for_plan_only_hidden_message() { - let item = assistant_output_text("\n- x\n"); - - assert_eq!(last_assistant_message_from_item(&item, true), None); - } - - #[tokio::test] - async fn save_image_generation_result_saves_base64_to_png_in_temp_dir() { - let expected_path = default_image_generation_output_dir().join("ig_save_base64.png"); - let _ = std::fs::remove_file(&expected_path); - - let saved_path = save_image_generation_result("ig_save_base64", "Zm9v") - .await - .expect("image should be saved"); - - assert_eq!(saved_path, expected_path); - assert_eq!(std::fs::read(&saved_path).expect("saved file"), b"foo"); - let _ = std::fs::remove_file(&saved_path); - } - - #[tokio::test] - async fn save_image_generation_result_rejects_data_url_payload() { - let result = "data:image/jpeg;base64,Zm9v"; - - let err = save_image_generation_result("ig_456", result) - .await - .expect_err("data url payload should error"); - assert!(matches!(err, CodexErr::InvalidRequest(_))); - } - - #[tokio::test] - async fn save_image_generation_result_overwrites_existing_file() { - let existing_path = default_image_generation_output_dir().join("ig_overwrite.png"); - std::fs::write(&existing_path, b"existing").expect("seed existing image"); - - let saved_path = save_image_generation_result("ig_overwrite", "Zm9v") - .await - .expect("image should be saved"); - - assert_eq!(saved_path, existing_path); - assert_eq!(std::fs::read(&saved_path).expect("saved file"), b"foo"); - let _ = std::fs::remove_file(&saved_path); - } - - #[tokio::test] - async fn save_image_generation_result_sanitizes_call_id_for_temp_dir_output_path() { - let expected_path = default_image_generation_output_dir().join("___ig___.png"); - let _ = std::fs::remove_file(&expected_path); - - let saved_path = save_image_generation_result("../ig/..", "Zm9v") - .await - .expect("image should be saved"); - - assert_eq!(saved_path, expected_path); - assert_eq!(std::fs::read(&saved_path).expect("saved file"), b"foo"); - let _ = std::fs::remove_file(&saved_path); - } - - #[tokio::test] - async fn save_image_generation_result_rejects_non_standard_base64() { - let err = save_image_generation_result("ig_urlsafe", "_-8") - .await - .expect_err("non-standard base64 should error"); - assert!(matches!(err, CodexErr::InvalidRequest(_))); - } - - #[tokio::test] - async fn save_image_generation_result_rejects_non_base64_data_urls() { - let err = save_image_generation_result("ig_svg", "data:image/svg+xml,") - .await - .expect_err("non-base64 data url should error"); - assert!(matches!(err, CodexErr::InvalidRequest(_))); - } -} +#[path = "stream_events_utils_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/stream_events_utils_tests.rs b/codex-rs/core/src/stream_events_utils_tests.rs new file mode 100644 index 0000000000..b7f81be735 --- /dev/null +++ b/codex-rs/core/src/stream_events_utils_tests.rs @@ -0,0 +1,138 @@ +use super::default_image_generation_output_dir; +use super::handle_non_tool_response_item; +use super::last_assistant_message_from_item; +use super::save_image_generation_result; +use crate::codex::make_session_and_context; +use crate::error::CodexErr; +use codex_protocol::items::TurnItem; +use codex_protocol::models::ContentItem; +use codex_protocol::models::ResponseItem; +use pretty_assertions::assert_eq; + +fn assistant_output_text(text: &str) -> ResponseItem { + ResponseItem::Message { + id: Some("msg-1".to_string()), + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: text.to_string(), + }], + end_turn: Some(true), + phase: None, + } +} + +#[tokio::test] +async fn handle_non_tool_response_item_strips_citations_from_assistant_message() { + let (session, turn_context) = make_session_and_context().await; + let item = assistant_output_text("hellodoc1 world"); + + let turn_item = handle_non_tool_response_item(&session, &turn_context, &item, false) + .await + .expect("assistant message should parse"); + + let TurnItem::AgentMessage(agent_message) = turn_item else { + panic!("expected agent message"); + }; + let text = agent_message + .content + .iter() + .map(|entry| match entry { + codex_protocol::items::AgentMessageContent::Text { text } => text.as_str(), + }) + .collect::(); + assert_eq!(text, "hello world"); +} + +#[test] +fn last_assistant_message_from_item_strips_citations_and_plan_blocks() { + let item = assistant_output_text( + "beforedoc1\n\n- x\n\nafter", + ); + + let message = last_assistant_message_from_item(&item, true) + .expect("assistant text should remain after stripping"); + + assert_eq!(message, "before\nafter"); +} + +#[test] +fn last_assistant_message_from_item_returns_none_for_citation_only_message() { + let item = assistant_output_text("doc1"); + + assert_eq!(last_assistant_message_from_item(&item, false), None); +} + +#[test] +fn last_assistant_message_from_item_returns_none_for_plan_only_hidden_message() { + let item = assistant_output_text("\n- x\n"); + + assert_eq!(last_assistant_message_from_item(&item, true), None); +} + +#[tokio::test] +async fn save_image_generation_result_saves_base64_to_png_in_temp_dir() { + let expected_path = default_image_generation_output_dir().join("ig_save_base64.png"); + let _ = std::fs::remove_file(&expected_path); + + let saved_path = save_image_generation_result("ig_save_base64", "Zm9v") + .await + .expect("image should be saved"); + + assert_eq!(saved_path, expected_path); + assert_eq!(std::fs::read(&saved_path).expect("saved file"), b"foo"); + let _ = std::fs::remove_file(&saved_path); +} + +#[tokio::test] +async fn save_image_generation_result_rejects_data_url_payload() { + let result = "data:image/jpeg;base64,Zm9v"; + + let err = save_image_generation_result("ig_456", result) + .await + .expect_err("data url payload should error"); + assert!(matches!(err, CodexErr::InvalidRequest(_))); +} + +#[tokio::test] +async fn save_image_generation_result_overwrites_existing_file() { + let existing_path = default_image_generation_output_dir().join("ig_overwrite.png"); + std::fs::write(&existing_path, b"existing").expect("seed existing image"); + + let saved_path = save_image_generation_result("ig_overwrite", "Zm9v") + .await + .expect("image should be saved"); + + assert_eq!(saved_path, existing_path); + assert_eq!(std::fs::read(&saved_path).expect("saved file"), b"foo"); + let _ = std::fs::remove_file(&saved_path); +} + +#[tokio::test] +async fn save_image_generation_result_sanitizes_call_id_for_temp_dir_output_path() { + let expected_path = default_image_generation_output_dir().join("___ig___.png"); + let _ = std::fs::remove_file(&expected_path); + + let saved_path = save_image_generation_result("../ig/..", "Zm9v") + .await + .expect("image should be saved"); + + assert_eq!(saved_path, expected_path); + assert_eq!(std::fs::read(&saved_path).expect("saved file"), b"foo"); + let _ = std::fs::remove_file(&saved_path); +} + +#[tokio::test] +async fn save_image_generation_result_rejects_non_standard_base64() { + let err = save_image_generation_result("ig_urlsafe", "_-8") + .await + .expect_err("non-standard base64 should error"); + assert!(matches!(err, CodexErr::InvalidRequest(_))); +} + +#[tokio::test] +async fn save_image_generation_result_rejects_non_base64_data_urls() { + let err = save_image_generation_result("ig_svg", "data:image/svg+xml,") + .await + .expect_err("non-base64 data url should error"); + assert!(matches!(err, CodexErr::InvalidRequest(_))); +} diff --git a/codex-rs/core/src/tasks/ghost_snapshot.rs b/codex-rs/core/src/tasks/ghost_snapshot.rs index ded8533a23..01aa9758f7 100644 --- a/codex-rs/core/src/tasks/ghost_snapshot.rs +++ b/codex-rs/core/src/tasks/ghost_snapshot.rs @@ -250,36 +250,5 @@ fn format_bytes(bytes: i64) -> String { } #[cfg(test)] -mod tests { - use super::*; - use codex_git::LargeUntrackedDir; - use pretty_assertions::assert_eq; - use std::path::PathBuf; - - #[test] - fn large_untracked_warning_includes_threshold() { - let report = GhostSnapshotReport { - large_untracked_dirs: vec![LargeUntrackedDir { - path: PathBuf::from("models"), - file_count: 250, - }], - ignored_untracked_files: Vec::new(), - }; - - let message = format_large_untracked_warning(Some(200), &report).unwrap(); - assert!(message.contains(">= 200 files")); - } - - #[test] - fn large_untracked_warning_disabled_when_threshold_disabled() { - let report = GhostSnapshotReport { - large_untracked_dirs: vec![LargeUntrackedDir { - path: PathBuf::from("models"), - file_count: 250, - }], - ignored_untracked_files: Vec::new(), - }; - - assert_eq!(format_large_untracked_warning(None, &report), None); - } -} +#[path = "ghost_snapshot_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tasks/ghost_snapshot_tests.rs b/codex-rs/core/src/tasks/ghost_snapshot_tests.rs new file mode 100644 index 0000000000..1884a9bd05 --- /dev/null +++ b/codex-rs/core/src/tasks/ghost_snapshot_tests.rs @@ -0,0 +1,31 @@ +use super::*; +use codex_git::LargeUntrackedDir; +use pretty_assertions::assert_eq; +use std::path::PathBuf; + +#[test] +fn large_untracked_warning_includes_threshold() { + let report = GhostSnapshotReport { + large_untracked_dirs: vec![LargeUntrackedDir { + path: PathBuf::from("models"), + file_count: 250, + }], + ignored_untracked_files: Vec::new(), + }; + + let message = format_large_untracked_warning(Some(200), &report).unwrap(); + assert!(message.contains(">= 200 files")); +} + +#[test] +fn large_untracked_warning_disabled_when_threshold_disabled() { + let report = GhostSnapshotReport { + large_untracked_dirs: vec![LargeUntrackedDir { + path: PathBuf::from("models"), + file_count: 250, + }], + ignored_untracked_files: Vec::new(), + }; + + assert_eq!(format_large_untracked_warning(None, &report), None); +} diff --git a/codex-rs/core/src/tasks/mod.rs b/codex-rs/core/src/tasks/mod.rs index 638cb3febd..652f13525d 100644 --- a/codex-rs/core/src/tasks/mod.rs +++ b/codex-rs/core/src/tasks/mod.rs @@ -454,119 +454,5 @@ impl Session { } #[cfg(test)] -mod tests { - use super::emit_turn_network_proxy_metric; - use codex_otel::SessionTelemetry; - use codex_otel::metrics::MetricsClient; - use codex_otel::metrics::MetricsConfig; - use codex_otel::metrics::names::TURN_NETWORK_PROXY_METRIC; - use codex_protocol::ThreadId; - use codex_protocol::protocol::SessionSource; - use opentelemetry::KeyValue; - use opentelemetry_sdk::metrics::InMemoryMetricExporter; - use opentelemetry_sdk::metrics::data::AggregatedMetrics; - use opentelemetry_sdk::metrics::data::Metric; - use opentelemetry_sdk::metrics::data::MetricData; - use opentelemetry_sdk::metrics::data::ResourceMetrics; - use pretty_assertions::assert_eq; - use std::collections::BTreeMap; - - fn test_session_telemetry() -> SessionTelemetry { - let exporter = InMemoryMetricExporter::default(); - let metrics = MetricsClient::new( - MetricsConfig::in_memory("test", "codex-core", env!("CARGO_PKG_VERSION"), exporter) - .with_runtime_reader(), - ) - .expect("in-memory metrics client"); - SessionTelemetry::new( - ThreadId::new(), - "gpt-5.1", - "gpt-5.1", - None, - None, - None, - "test_originator".to_string(), - false, - "tty".to_string(), - SessionSource::Cli, - ) - .with_metrics_without_metadata_tags(metrics) - } - - fn find_metric<'a>(resource_metrics: &'a ResourceMetrics, name: &str) -> &'a Metric { - for scope_metrics in resource_metrics.scope_metrics() { - for metric in scope_metrics.metrics() { - if metric.name() == name { - return metric; - } - } - } - panic!("metric {name} missing"); - } - - fn attributes_to_map<'a>( - attributes: impl Iterator, - ) -> BTreeMap { - attributes - .map(|kv| (kv.key.as_str().to_string(), kv.value.as_str().to_string())) - .collect() - } - - fn metric_point(resource_metrics: &ResourceMetrics) -> (BTreeMap, u64) { - let metric = find_metric(resource_metrics, TURN_NETWORK_PROXY_METRIC); - match metric.data() { - AggregatedMetrics::U64(data) => match data { - MetricData::Sum(sum) => { - let points: Vec<_> = sum.data_points().collect(); - assert_eq!(points.len(), 1); - let point = points[0]; - (attributes_to_map(point.attributes()), point.value()) - } - _ => panic!("unexpected counter aggregation"), - }, - _ => panic!("unexpected counter data type"), - } - } - - #[test] - fn emit_turn_network_proxy_metric_records_active_turn() { - let session_telemetry = test_session_telemetry(); - - emit_turn_network_proxy_metric(&session_telemetry, true, ("tmp_mem_enabled", "true")); - - let snapshot = session_telemetry - .snapshot_metrics() - .expect("runtime metrics snapshot"); - let (attrs, value) = metric_point(&snapshot); - - assert_eq!(value, 1); - assert_eq!( - attrs, - BTreeMap::from([ - ("active".to_string(), "true".to_string()), - ("tmp_mem_enabled".to_string(), "true".to_string()), - ]) - ); - } - - #[test] - fn emit_turn_network_proxy_metric_records_inactive_turn() { - let session_telemetry = test_session_telemetry(); - - emit_turn_network_proxy_metric(&session_telemetry, false, ("tmp_mem_enabled", "false")); - - let snapshot = session_telemetry - .snapshot_metrics() - .expect("runtime metrics snapshot"); - let (attrs, value) = metric_point(&snapshot); - - assert_eq!(value, 1); - assert_eq!( - attrs, - BTreeMap::from([ - ("active".to_string(), "false".to_string()), - ("tmp_mem_enabled".to_string(), "false".to_string()), - ]) - ); - } -} +#[path = "mod_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tasks/mod_tests.rs b/codex-rs/core/src/tasks/mod_tests.rs new file mode 100644 index 0000000000..7a55d55f6f --- /dev/null +++ b/codex-rs/core/src/tasks/mod_tests.rs @@ -0,0 +1,114 @@ +use super::emit_turn_network_proxy_metric; +use codex_otel::SessionTelemetry; +use codex_otel::metrics::MetricsClient; +use codex_otel::metrics::MetricsConfig; +use codex_otel::metrics::names::TURN_NETWORK_PROXY_METRIC; +use codex_protocol::ThreadId; +use codex_protocol::protocol::SessionSource; +use opentelemetry::KeyValue; +use opentelemetry_sdk::metrics::InMemoryMetricExporter; +use opentelemetry_sdk::metrics::data::AggregatedMetrics; +use opentelemetry_sdk::metrics::data::Metric; +use opentelemetry_sdk::metrics::data::MetricData; +use opentelemetry_sdk::metrics::data::ResourceMetrics; +use pretty_assertions::assert_eq; +use std::collections::BTreeMap; + +fn test_session_telemetry() -> SessionTelemetry { + let exporter = InMemoryMetricExporter::default(); + let metrics = MetricsClient::new( + MetricsConfig::in_memory("test", "codex-core", env!("CARGO_PKG_VERSION"), exporter) + .with_runtime_reader(), + ) + .expect("in-memory metrics client"); + SessionTelemetry::new( + ThreadId::new(), + "gpt-5.1", + "gpt-5.1", + None, + None, + None, + "test_originator".to_string(), + false, + "tty".to_string(), + SessionSource::Cli, + ) + .with_metrics_without_metadata_tags(metrics) +} + +fn find_metric<'a>(resource_metrics: &'a ResourceMetrics, name: &str) -> &'a Metric { + for scope_metrics in resource_metrics.scope_metrics() { + for metric in scope_metrics.metrics() { + if metric.name() == name { + return metric; + } + } + } + panic!("metric {name} missing"); +} + +fn attributes_to_map<'a>( + attributes: impl Iterator, +) -> BTreeMap { + attributes + .map(|kv| (kv.key.as_str().to_string(), kv.value.as_str().to_string())) + .collect() +} + +fn metric_point(resource_metrics: &ResourceMetrics) -> (BTreeMap, u64) { + let metric = find_metric(resource_metrics, TURN_NETWORK_PROXY_METRIC); + match metric.data() { + AggregatedMetrics::U64(data) => match data { + MetricData::Sum(sum) => { + let points: Vec<_> = sum.data_points().collect(); + assert_eq!(points.len(), 1); + let point = points[0]; + (attributes_to_map(point.attributes()), point.value()) + } + _ => panic!("unexpected counter aggregation"), + }, + _ => panic!("unexpected counter data type"), + } +} + +#[test] +fn emit_turn_network_proxy_metric_records_active_turn() { + let session_telemetry = test_session_telemetry(); + + emit_turn_network_proxy_metric(&session_telemetry, true, ("tmp_mem_enabled", "true")); + + let snapshot = session_telemetry + .snapshot_metrics() + .expect("runtime metrics snapshot"); + let (attrs, value) = metric_point(&snapshot); + + assert_eq!(value, 1); + assert_eq!( + attrs, + BTreeMap::from([ + ("active".to_string(), "true".to_string()), + ("tmp_mem_enabled".to_string(), "true".to_string()), + ]) + ); +} + +#[test] +fn emit_turn_network_proxy_metric_records_inactive_turn() { + let session_telemetry = test_session_telemetry(); + + emit_turn_network_proxy_metric(&session_telemetry, false, ("tmp_mem_enabled", "false")); + + let snapshot = session_telemetry + .snapshot_metrics() + .expect("runtime metrics snapshot"); + let (attrs, value) = metric_point(&snapshot); + + assert_eq!(value, 1); + assert_eq!( + attrs, + BTreeMap::from([ + ("active".to_string(), "false".to_string()), + ("tmp_mem_enabled".to_string(), "false".to_string()), + ]) + ); +} diff --git a/codex-rs/core/src/terminal.rs b/codex-rs/core/src/terminal.rs index b91aef106b..f875fcd9e6 100644 --- a/codex-rs/core/src/terminal.rs +++ b/codex-rs/core/src/terminal.rs @@ -461,707 +461,5 @@ fn none_if_whitespace(value: String) -> Option { } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - use std::collections::HashMap; - - struct FakeEnvironment { - vars: HashMap, - tmux_client_info: TmuxClientInfo, - } - - impl FakeEnvironment { - fn new() -> Self { - Self { - vars: HashMap::new(), - tmux_client_info: TmuxClientInfo::default(), - } - } - - fn with_var(mut self, key: &str, value: &str) -> Self { - self.vars.insert(key.to_string(), value.to_string()); - self - } - - fn with_tmux_client_info(mut self, termtype: Option<&str>, termname: Option<&str>) -> Self { - self.tmux_client_info = TmuxClientInfo { - termtype: termtype.map(ToString::to_string), - termname: termname.map(ToString::to_string), - }; - self - } - } - - impl Environment for FakeEnvironment { - fn var(&self, name: &str) -> Option { - self.vars.get(name).cloned() - } - - fn tmux_client_info(&self) -> TmuxClientInfo { - self.tmux_client_info.clone() - } - } - - fn terminal_info( - name: TerminalName, - term_program: Option<&str>, - version: Option<&str>, - term: Option<&str>, - multiplexer: Option, - ) -> TerminalInfo { - TerminalInfo { - name, - term_program: term_program.map(ToString::to_string), - version: version.map(ToString::to_string), - term: term.map(ToString::to_string), - multiplexer, - } - } - - #[test] - fn detects_term_program() { - let env = FakeEnvironment::new() - .with_var("TERM_PROGRAM", "iTerm.app") - .with_var("TERM_PROGRAM_VERSION", "3.5.0") - .with_var("WEZTERM_VERSION", "2024.2"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info( - TerminalName::Iterm2, - Some("iTerm.app"), - Some("3.5.0"), - None, - None, - ), - "term_program_with_version_info" - ); - assert_eq!( - terminal.user_agent_token(), - "iTerm.app/3.5.0", - "term_program_with_version_user_agent" - ); - - let env = FakeEnvironment::new() - .with_var("TERM_PROGRAM", "iTerm.app") - .with_var("TERM_PROGRAM_VERSION", ""); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info(TerminalName::Iterm2, Some("iTerm.app"), None, None, None), - "term_program_without_version_info" - ); - assert_eq!( - terminal.user_agent_token(), - "iTerm.app", - "term_program_without_version_user_agent" - ); - - let env = FakeEnvironment::new() - .with_var("TERM_PROGRAM", "iTerm.app") - .with_var("WEZTERM_VERSION", "2024.2"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info(TerminalName::Iterm2, Some("iTerm.app"), None, None, None), - "term_program_overrides_wezterm_info" - ); - assert_eq!( - terminal.user_agent_token(), - "iTerm.app", - "term_program_overrides_wezterm_user_agent" - ); - } - - #[test] - fn detects_iterm2() { - let env = FakeEnvironment::new().with_var("ITERM_SESSION_ID", "w0t1p0"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info(TerminalName::Iterm2, None, None, None, None), - "iterm_session_id_info" - ); - assert_eq!( - terminal.user_agent_token(), - "iTerm.app", - "iterm_session_id_user_agent" - ); - } - - #[test] - fn detects_apple_terminal() { - let env = FakeEnvironment::new().with_var("TERM_PROGRAM", "Apple_Terminal"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info( - TerminalName::AppleTerminal, - Some("Apple_Terminal"), - None, - None, - None, - ), - "apple_term_program_info" - ); - assert_eq!( - terminal.user_agent_token(), - "Apple_Terminal", - "apple_term_program_user_agent" - ); - - let env = FakeEnvironment::new().with_var("TERM_SESSION_ID", "A1B2C3"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info(TerminalName::AppleTerminal, None, None, None, None), - "apple_term_session_id_info" - ); - assert_eq!( - terminal.user_agent_token(), - "Apple_Terminal", - "apple_term_session_id_user_agent" - ); - } - - #[test] - fn detects_ghostty() { - let env = FakeEnvironment::new().with_var("TERM_PROGRAM", "Ghostty"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info(TerminalName::Ghostty, Some("Ghostty"), None, None, None), - "ghostty_term_program_info" - ); - assert_eq!( - terminal.user_agent_token(), - "Ghostty", - "ghostty_term_program_user_agent" - ); - } - - #[test] - fn detects_vscode() { - let env = FakeEnvironment::new() - .with_var("TERM_PROGRAM", "vscode") - .with_var("TERM_PROGRAM_VERSION", "1.86.0"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info( - TerminalName::VsCode, - Some("vscode"), - Some("1.86.0"), - None, - None - ), - "vscode_term_program_info" - ); - assert_eq!( - terminal.user_agent_token(), - "vscode/1.86.0", - "vscode_term_program_user_agent" - ); - } - - #[test] - fn detects_warp_terminal() { - let env = FakeEnvironment::new() - .with_var("TERM_PROGRAM", "WarpTerminal") - .with_var("TERM_PROGRAM_VERSION", "v0.2025.12.10.08.12.stable_03"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info( - TerminalName::WarpTerminal, - Some("WarpTerminal"), - Some("v0.2025.12.10.08.12.stable_03"), - None, - None, - ), - "warp_term_program_info" - ); - assert_eq!( - terminal.user_agent_token(), - "WarpTerminal/v0.2025.12.10.08.12.stable_03", - "warp_term_program_user_agent" - ); - } - - #[test] - fn detects_tmux_multiplexer() { - let env = FakeEnvironment::new() - .with_var("TMUX", "/tmp/tmux-1000/default,123,0") - .with_var("TERM_PROGRAM", "tmux") - .with_tmux_client_info(Some("xterm-256color"), Some("screen-256color")); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info( - TerminalName::Unknown, - Some("xterm-256color"), - None, - Some("screen-256color"), - Some(Multiplexer::Tmux { version: None }), - ), - "tmux_multiplexer_info" - ); - assert_eq!( - terminal.user_agent_token(), - "xterm-256color", - "tmux_multiplexer_user_agent" - ); - } - - #[test] - fn detects_zellij_multiplexer() { - let env = FakeEnvironment::new().with_var("ZELLIJ", "1"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - TerminalInfo { - name: TerminalName::Unknown, - term_program: None, - version: None, - term: None, - multiplexer: Some(Multiplexer::Zellij {}), - }, - "zellij_multiplexer" - ); - } - - #[test] - fn detects_tmux_client_termtype() { - let env = FakeEnvironment::new() - .with_var("TMUX", "/tmp/tmux-1000/default,123,0") - .with_var("TERM_PROGRAM", "tmux") - .with_tmux_client_info(Some("WezTerm"), None); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info( - TerminalName::WezTerm, - Some("WezTerm"), - None, - None, - Some(Multiplexer::Tmux { version: None }), - ), - "tmux_client_termtype_info" - ); - assert_eq!( - terminal.user_agent_token(), - "WezTerm", - "tmux_client_termtype_user_agent" - ); - } - - #[test] - fn detects_tmux_client_termname() { - let env = FakeEnvironment::new() - .with_var("TMUX", "/tmp/tmux-1000/default,123,0") - .with_var("TERM_PROGRAM", "tmux") - .with_tmux_client_info(None, Some("xterm-256color")); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info( - TerminalName::Unknown, - None, - None, - Some("xterm-256color"), - Some(Multiplexer::Tmux { version: None }) - ), - "tmux_client_termname_info" - ); - assert_eq!( - terminal.user_agent_token(), - "xterm-256color", - "tmux_client_termname_user_agent" - ); - } - - #[test] - fn detects_tmux_term_program_uses_client_termtype() { - let env = FakeEnvironment::new() - .with_var("TMUX", "/tmp/tmux-1000/default,123,0") - .with_var("TERM_PROGRAM", "tmux") - .with_var("TERM_PROGRAM_VERSION", "3.6a") - .with_tmux_client_info(Some("ghostty 1.2.3"), Some("xterm-ghostty")); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info( - TerminalName::Ghostty, - Some("ghostty"), - Some("1.2.3"), - Some("xterm-ghostty"), - Some(Multiplexer::Tmux { - version: Some("3.6a".to_string()), - }), - ), - "tmux_term_program_client_termtype_info" - ); - assert_eq!( - terminal.user_agent_token(), - "ghostty/1.2.3", - "tmux_term_program_client_termtype_user_agent" - ); - } - - #[test] - fn detects_wezterm() { - let env = FakeEnvironment::new().with_var("WEZTERM_VERSION", "2024.2"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info(TerminalName::WezTerm, None, Some("2024.2"), None, None), - "wezterm_version_info" - ); - assert_eq!( - terminal.user_agent_token(), - "WezTerm/2024.2", - "wezterm_version_user_agent" - ); - - let env = FakeEnvironment::new() - .with_var("TERM_PROGRAM", "WezTerm") - .with_var("TERM_PROGRAM_VERSION", "2024.2"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info( - TerminalName::WezTerm, - Some("WezTerm"), - Some("2024.2"), - None, - None - ), - "wezterm_term_program_info" - ); - assert_eq!( - terminal.user_agent_token(), - "WezTerm/2024.2", - "wezterm_term_program_user_agent" - ); - - let env = FakeEnvironment::new().with_var("WEZTERM_VERSION", ""); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info(TerminalName::WezTerm, None, None, None, None), - "wezterm_empty_info" - ); - assert_eq!( - terminal.user_agent_token(), - "WezTerm", - "wezterm_empty_user_agent" - ); - } - - #[test] - fn detects_kitty() { - let env = FakeEnvironment::new().with_var("KITTY_WINDOW_ID", "1"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info(TerminalName::Kitty, None, None, None, None), - "kitty_window_id_info" - ); - assert_eq!( - terminal.user_agent_token(), - "kitty", - "kitty_window_id_user_agent" - ); - - let env = FakeEnvironment::new() - .with_var("TERM_PROGRAM", "kitty") - .with_var("TERM_PROGRAM_VERSION", "0.30.1"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info( - TerminalName::Kitty, - Some("kitty"), - Some("0.30.1"), - None, - None - ), - "kitty_term_program_info" - ); - assert_eq!( - terminal.user_agent_token(), - "kitty/0.30.1", - "kitty_term_program_user_agent" - ); - - let env = FakeEnvironment::new() - .with_var("TERM", "xterm-kitty") - .with_var("ALACRITTY_SOCKET", "/tmp/alacritty"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info(TerminalName::Kitty, None, None, None, None), - "kitty_term_over_alacritty_info" - ); - assert_eq!( - terminal.user_agent_token(), - "kitty", - "kitty_term_over_alacritty_user_agent" - ); - } - - #[test] - fn detects_alacritty() { - let env = FakeEnvironment::new().with_var("ALACRITTY_SOCKET", "/tmp/alacritty"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info(TerminalName::Alacritty, None, None, None, None), - "alacritty_socket_info" - ); - assert_eq!( - terminal.user_agent_token(), - "Alacritty", - "alacritty_socket_user_agent" - ); - - let env = FakeEnvironment::new() - .with_var("TERM_PROGRAM", "Alacritty") - .with_var("TERM_PROGRAM_VERSION", "0.13.2"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info( - TerminalName::Alacritty, - Some("Alacritty"), - Some("0.13.2"), - None, - None, - ), - "alacritty_term_program_info" - ); - assert_eq!( - terminal.user_agent_token(), - "Alacritty/0.13.2", - "alacritty_term_program_user_agent" - ); - - let env = FakeEnvironment::new().with_var("TERM", "alacritty"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info(TerminalName::Alacritty, None, None, None, None), - "alacritty_term_info" - ); - assert_eq!( - terminal.user_agent_token(), - "Alacritty", - "alacritty_term_user_agent" - ); - } - - #[test] - fn detects_konsole() { - let env = FakeEnvironment::new().with_var("KONSOLE_VERSION", "230800"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info(TerminalName::Konsole, None, Some("230800"), None, None), - "konsole_version_info" - ); - assert_eq!( - terminal.user_agent_token(), - "Konsole/230800", - "konsole_version_user_agent" - ); - - let env = FakeEnvironment::new() - .with_var("TERM_PROGRAM", "Konsole") - .with_var("TERM_PROGRAM_VERSION", "230800"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info( - TerminalName::Konsole, - Some("Konsole"), - Some("230800"), - None, - None - ), - "konsole_term_program_info" - ); - assert_eq!( - terminal.user_agent_token(), - "Konsole/230800", - "konsole_term_program_user_agent" - ); - - let env = FakeEnvironment::new().with_var("KONSOLE_VERSION", ""); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info(TerminalName::Konsole, None, None, None, None), - "konsole_empty_info" - ); - assert_eq!( - terminal.user_agent_token(), - "Konsole", - "konsole_empty_user_agent" - ); - } - - #[test] - fn detects_gnome_terminal() { - let env = FakeEnvironment::new().with_var("GNOME_TERMINAL_SCREEN", "1"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info(TerminalName::GnomeTerminal, None, None, None, None), - "gnome_terminal_screen_info" - ); - assert_eq!( - terminal.user_agent_token(), - "gnome-terminal", - "gnome_terminal_screen_user_agent" - ); - - let env = FakeEnvironment::new() - .with_var("TERM_PROGRAM", "gnome-terminal") - .with_var("TERM_PROGRAM_VERSION", "3.50"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info( - TerminalName::GnomeTerminal, - Some("gnome-terminal"), - Some("3.50"), - None, - None, - ), - "gnome_terminal_term_program_info" - ); - assert_eq!( - terminal.user_agent_token(), - "gnome-terminal/3.50", - "gnome_terminal_term_program_user_agent" - ); - } - - #[test] - fn detects_vte() { - let env = FakeEnvironment::new().with_var("VTE_VERSION", "7000"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info(TerminalName::Vte, None, Some("7000"), None, None), - "vte_version_info" - ); - assert_eq!( - terminal.user_agent_token(), - "VTE/7000", - "vte_version_user_agent" - ); - - let env = FakeEnvironment::new() - .with_var("TERM_PROGRAM", "VTE") - .with_var("TERM_PROGRAM_VERSION", "7000"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info(TerminalName::Vte, Some("VTE"), Some("7000"), None, None), - "vte_term_program_info" - ); - assert_eq!( - terminal.user_agent_token(), - "VTE/7000", - "vte_term_program_user_agent" - ); - - let env = FakeEnvironment::new().with_var("VTE_VERSION", ""); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info(TerminalName::Vte, None, None, None, None), - "vte_empty_info" - ); - assert_eq!(terminal.user_agent_token(), "VTE", "vte_empty_user_agent"); - } - - #[test] - fn detects_windows_terminal() { - let env = FakeEnvironment::new().with_var("WT_SESSION", "1"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info(TerminalName::WindowsTerminal, None, None, None, None), - "wt_session_info" - ); - assert_eq!( - terminal.user_agent_token(), - "WindowsTerminal", - "wt_session_user_agent" - ); - - let env = FakeEnvironment::new() - .with_var("TERM_PROGRAM", "WindowsTerminal") - .with_var("TERM_PROGRAM_VERSION", "1.21"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info( - TerminalName::WindowsTerminal, - Some("WindowsTerminal"), - Some("1.21"), - None, - None, - ), - "windows_terminal_term_program_info" - ); - assert_eq!( - terminal.user_agent_token(), - "WindowsTerminal/1.21", - "windows_terminal_term_program_user_agent" - ); - } - - #[test] - fn detects_term_fallbacks() { - let env = FakeEnvironment::new().with_var("TERM", "xterm-256color"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info( - TerminalName::Unknown, - None, - None, - Some("xterm-256color"), - None, - ), - "term_fallback_info" - ); - assert_eq!( - terminal.user_agent_token(), - "xterm-256color", - "term_fallback_user_agent" - ); - - let env = FakeEnvironment::new().with_var("TERM", "dumb"); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info(TerminalName::Dumb, None, None, Some("dumb"), None), - "dumb_term_info" - ); - assert_eq!(terminal.user_agent_token(), "dumb", "dumb_term_user_agent"); - - let env = FakeEnvironment::new(); - let terminal = detect_terminal_info_from_env(&env); - assert_eq!( - terminal, - terminal_info(TerminalName::Unknown, None, None, None, None), - "unknown_info" - ); - assert_eq!(terminal.user_agent_token(), "unknown", "unknown_user_agent"); - } -} +#[path = "terminal_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/terminal_tests.rs b/codex-rs/core/src/terminal_tests.rs new file mode 100644 index 0000000000..d779a54ab2 --- /dev/null +++ b/codex-rs/core/src/terminal_tests.rs @@ -0,0 +1,702 @@ +use super::*; +use pretty_assertions::assert_eq; +use std::collections::HashMap; + +struct FakeEnvironment { + vars: HashMap, + tmux_client_info: TmuxClientInfo, +} + +impl FakeEnvironment { + fn new() -> Self { + Self { + vars: HashMap::new(), + tmux_client_info: TmuxClientInfo::default(), + } + } + + fn with_var(mut self, key: &str, value: &str) -> Self { + self.vars.insert(key.to_string(), value.to_string()); + self + } + + fn with_tmux_client_info(mut self, termtype: Option<&str>, termname: Option<&str>) -> Self { + self.tmux_client_info = TmuxClientInfo { + termtype: termtype.map(ToString::to_string), + termname: termname.map(ToString::to_string), + }; + self + } +} + +impl Environment for FakeEnvironment { + fn var(&self, name: &str) -> Option { + self.vars.get(name).cloned() + } + + fn tmux_client_info(&self) -> TmuxClientInfo { + self.tmux_client_info.clone() + } +} + +fn terminal_info( + name: TerminalName, + term_program: Option<&str>, + version: Option<&str>, + term: Option<&str>, + multiplexer: Option, +) -> TerminalInfo { + TerminalInfo { + name, + term_program: term_program.map(ToString::to_string), + version: version.map(ToString::to_string), + term: term.map(ToString::to_string), + multiplexer, + } +} + +#[test] +fn detects_term_program() { + let env = FakeEnvironment::new() + .with_var("TERM_PROGRAM", "iTerm.app") + .with_var("TERM_PROGRAM_VERSION", "3.5.0") + .with_var("WEZTERM_VERSION", "2024.2"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info( + TerminalName::Iterm2, + Some("iTerm.app"), + Some("3.5.0"), + None, + None, + ), + "term_program_with_version_info" + ); + assert_eq!( + terminal.user_agent_token(), + "iTerm.app/3.5.0", + "term_program_with_version_user_agent" + ); + + let env = FakeEnvironment::new() + .with_var("TERM_PROGRAM", "iTerm.app") + .with_var("TERM_PROGRAM_VERSION", ""); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info(TerminalName::Iterm2, Some("iTerm.app"), None, None, None), + "term_program_without_version_info" + ); + assert_eq!( + terminal.user_agent_token(), + "iTerm.app", + "term_program_without_version_user_agent" + ); + + let env = FakeEnvironment::new() + .with_var("TERM_PROGRAM", "iTerm.app") + .with_var("WEZTERM_VERSION", "2024.2"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info(TerminalName::Iterm2, Some("iTerm.app"), None, None, None), + "term_program_overrides_wezterm_info" + ); + assert_eq!( + terminal.user_agent_token(), + "iTerm.app", + "term_program_overrides_wezterm_user_agent" + ); +} + +#[test] +fn detects_iterm2() { + let env = FakeEnvironment::new().with_var("ITERM_SESSION_ID", "w0t1p0"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info(TerminalName::Iterm2, None, None, None, None), + "iterm_session_id_info" + ); + assert_eq!( + terminal.user_agent_token(), + "iTerm.app", + "iterm_session_id_user_agent" + ); +} + +#[test] +fn detects_apple_terminal() { + let env = FakeEnvironment::new().with_var("TERM_PROGRAM", "Apple_Terminal"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info( + TerminalName::AppleTerminal, + Some("Apple_Terminal"), + None, + None, + None, + ), + "apple_term_program_info" + ); + assert_eq!( + terminal.user_agent_token(), + "Apple_Terminal", + "apple_term_program_user_agent" + ); + + let env = FakeEnvironment::new().with_var("TERM_SESSION_ID", "A1B2C3"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info(TerminalName::AppleTerminal, None, None, None, None), + "apple_term_session_id_info" + ); + assert_eq!( + terminal.user_agent_token(), + "Apple_Terminal", + "apple_term_session_id_user_agent" + ); +} + +#[test] +fn detects_ghostty() { + let env = FakeEnvironment::new().with_var("TERM_PROGRAM", "Ghostty"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info(TerminalName::Ghostty, Some("Ghostty"), None, None, None), + "ghostty_term_program_info" + ); + assert_eq!( + terminal.user_agent_token(), + "Ghostty", + "ghostty_term_program_user_agent" + ); +} + +#[test] +fn detects_vscode() { + let env = FakeEnvironment::new() + .with_var("TERM_PROGRAM", "vscode") + .with_var("TERM_PROGRAM_VERSION", "1.86.0"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info( + TerminalName::VsCode, + Some("vscode"), + Some("1.86.0"), + None, + None + ), + "vscode_term_program_info" + ); + assert_eq!( + terminal.user_agent_token(), + "vscode/1.86.0", + "vscode_term_program_user_agent" + ); +} + +#[test] +fn detects_warp_terminal() { + let env = FakeEnvironment::new() + .with_var("TERM_PROGRAM", "WarpTerminal") + .with_var("TERM_PROGRAM_VERSION", "v0.2025.12.10.08.12.stable_03"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info( + TerminalName::WarpTerminal, + Some("WarpTerminal"), + Some("v0.2025.12.10.08.12.stable_03"), + None, + None, + ), + "warp_term_program_info" + ); + assert_eq!( + terminal.user_agent_token(), + "WarpTerminal/v0.2025.12.10.08.12.stable_03", + "warp_term_program_user_agent" + ); +} + +#[test] +fn detects_tmux_multiplexer() { + let env = FakeEnvironment::new() + .with_var("TMUX", "/tmp/tmux-1000/default,123,0") + .with_var("TERM_PROGRAM", "tmux") + .with_tmux_client_info(Some("xterm-256color"), Some("screen-256color")); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info( + TerminalName::Unknown, + Some("xterm-256color"), + None, + Some("screen-256color"), + Some(Multiplexer::Tmux { version: None }), + ), + "tmux_multiplexer_info" + ); + assert_eq!( + terminal.user_agent_token(), + "xterm-256color", + "tmux_multiplexer_user_agent" + ); +} + +#[test] +fn detects_zellij_multiplexer() { + let env = FakeEnvironment::new().with_var("ZELLIJ", "1"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + TerminalInfo { + name: TerminalName::Unknown, + term_program: None, + version: None, + term: None, + multiplexer: Some(Multiplexer::Zellij {}), + }, + "zellij_multiplexer" + ); +} + +#[test] +fn detects_tmux_client_termtype() { + let env = FakeEnvironment::new() + .with_var("TMUX", "/tmp/tmux-1000/default,123,0") + .with_var("TERM_PROGRAM", "tmux") + .with_tmux_client_info(Some("WezTerm"), None); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info( + TerminalName::WezTerm, + Some("WezTerm"), + None, + None, + Some(Multiplexer::Tmux { version: None }), + ), + "tmux_client_termtype_info" + ); + assert_eq!( + terminal.user_agent_token(), + "WezTerm", + "tmux_client_termtype_user_agent" + ); +} + +#[test] +fn detects_tmux_client_termname() { + let env = FakeEnvironment::new() + .with_var("TMUX", "/tmp/tmux-1000/default,123,0") + .with_var("TERM_PROGRAM", "tmux") + .with_tmux_client_info(None, Some("xterm-256color")); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info( + TerminalName::Unknown, + None, + None, + Some("xterm-256color"), + Some(Multiplexer::Tmux { version: None }) + ), + "tmux_client_termname_info" + ); + assert_eq!( + terminal.user_agent_token(), + "xterm-256color", + "tmux_client_termname_user_agent" + ); +} + +#[test] +fn detects_tmux_term_program_uses_client_termtype() { + let env = FakeEnvironment::new() + .with_var("TMUX", "/tmp/tmux-1000/default,123,0") + .with_var("TERM_PROGRAM", "tmux") + .with_var("TERM_PROGRAM_VERSION", "3.6a") + .with_tmux_client_info(Some("ghostty 1.2.3"), Some("xterm-ghostty")); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info( + TerminalName::Ghostty, + Some("ghostty"), + Some("1.2.3"), + Some("xterm-ghostty"), + Some(Multiplexer::Tmux { + version: Some("3.6a".to_string()), + }), + ), + "tmux_term_program_client_termtype_info" + ); + assert_eq!( + terminal.user_agent_token(), + "ghostty/1.2.3", + "tmux_term_program_client_termtype_user_agent" + ); +} + +#[test] +fn detects_wezterm() { + let env = FakeEnvironment::new().with_var("WEZTERM_VERSION", "2024.2"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info(TerminalName::WezTerm, None, Some("2024.2"), None, None), + "wezterm_version_info" + ); + assert_eq!( + terminal.user_agent_token(), + "WezTerm/2024.2", + "wezterm_version_user_agent" + ); + + let env = FakeEnvironment::new() + .with_var("TERM_PROGRAM", "WezTerm") + .with_var("TERM_PROGRAM_VERSION", "2024.2"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info( + TerminalName::WezTerm, + Some("WezTerm"), + Some("2024.2"), + None, + None + ), + "wezterm_term_program_info" + ); + assert_eq!( + terminal.user_agent_token(), + "WezTerm/2024.2", + "wezterm_term_program_user_agent" + ); + + let env = FakeEnvironment::new().with_var("WEZTERM_VERSION", ""); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info(TerminalName::WezTerm, None, None, None, None), + "wezterm_empty_info" + ); + assert_eq!( + terminal.user_agent_token(), + "WezTerm", + "wezterm_empty_user_agent" + ); +} + +#[test] +fn detects_kitty() { + let env = FakeEnvironment::new().with_var("KITTY_WINDOW_ID", "1"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info(TerminalName::Kitty, None, None, None, None), + "kitty_window_id_info" + ); + assert_eq!( + terminal.user_agent_token(), + "kitty", + "kitty_window_id_user_agent" + ); + + let env = FakeEnvironment::new() + .with_var("TERM_PROGRAM", "kitty") + .with_var("TERM_PROGRAM_VERSION", "0.30.1"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info( + TerminalName::Kitty, + Some("kitty"), + Some("0.30.1"), + None, + None + ), + "kitty_term_program_info" + ); + assert_eq!( + terminal.user_agent_token(), + "kitty/0.30.1", + "kitty_term_program_user_agent" + ); + + let env = FakeEnvironment::new() + .with_var("TERM", "xterm-kitty") + .with_var("ALACRITTY_SOCKET", "/tmp/alacritty"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info(TerminalName::Kitty, None, None, None, None), + "kitty_term_over_alacritty_info" + ); + assert_eq!( + terminal.user_agent_token(), + "kitty", + "kitty_term_over_alacritty_user_agent" + ); +} + +#[test] +fn detects_alacritty() { + let env = FakeEnvironment::new().with_var("ALACRITTY_SOCKET", "/tmp/alacritty"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info(TerminalName::Alacritty, None, None, None, None), + "alacritty_socket_info" + ); + assert_eq!( + terminal.user_agent_token(), + "Alacritty", + "alacritty_socket_user_agent" + ); + + let env = FakeEnvironment::new() + .with_var("TERM_PROGRAM", "Alacritty") + .with_var("TERM_PROGRAM_VERSION", "0.13.2"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info( + TerminalName::Alacritty, + Some("Alacritty"), + Some("0.13.2"), + None, + None, + ), + "alacritty_term_program_info" + ); + assert_eq!( + terminal.user_agent_token(), + "Alacritty/0.13.2", + "alacritty_term_program_user_agent" + ); + + let env = FakeEnvironment::new().with_var("TERM", "alacritty"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info(TerminalName::Alacritty, None, None, None, None), + "alacritty_term_info" + ); + assert_eq!( + terminal.user_agent_token(), + "Alacritty", + "alacritty_term_user_agent" + ); +} + +#[test] +fn detects_konsole() { + let env = FakeEnvironment::new().with_var("KONSOLE_VERSION", "230800"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info(TerminalName::Konsole, None, Some("230800"), None, None), + "konsole_version_info" + ); + assert_eq!( + terminal.user_agent_token(), + "Konsole/230800", + "konsole_version_user_agent" + ); + + let env = FakeEnvironment::new() + .with_var("TERM_PROGRAM", "Konsole") + .with_var("TERM_PROGRAM_VERSION", "230800"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info( + TerminalName::Konsole, + Some("Konsole"), + Some("230800"), + None, + None + ), + "konsole_term_program_info" + ); + assert_eq!( + terminal.user_agent_token(), + "Konsole/230800", + "konsole_term_program_user_agent" + ); + + let env = FakeEnvironment::new().with_var("KONSOLE_VERSION", ""); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info(TerminalName::Konsole, None, None, None, None), + "konsole_empty_info" + ); + assert_eq!( + terminal.user_agent_token(), + "Konsole", + "konsole_empty_user_agent" + ); +} + +#[test] +fn detects_gnome_terminal() { + let env = FakeEnvironment::new().with_var("GNOME_TERMINAL_SCREEN", "1"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info(TerminalName::GnomeTerminal, None, None, None, None), + "gnome_terminal_screen_info" + ); + assert_eq!( + terminal.user_agent_token(), + "gnome-terminal", + "gnome_terminal_screen_user_agent" + ); + + let env = FakeEnvironment::new() + .with_var("TERM_PROGRAM", "gnome-terminal") + .with_var("TERM_PROGRAM_VERSION", "3.50"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info( + TerminalName::GnomeTerminal, + Some("gnome-terminal"), + Some("3.50"), + None, + None, + ), + "gnome_terminal_term_program_info" + ); + assert_eq!( + terminal.user_agent_token(), + "gnome-terminal/3.50", + "gnome_terminal_term_program_user_agent" + ); +} + +#[test] +fn detects_vte() { + let env = FakeEnvironment::new().with_var("VTE_VERSION", "7000"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info(TerminalName::Vte, None, Some("7000"), None, None), + "vte_version_info" + ); + assert_eq!( + terminal.user_agent_token(), + "VTE/7000", + "vte_version_user_agent" + ); + + let env = FakeEnvironment::new() + .with_var("TERM_PROGRAM", "VTE") + .with_var("TERM_PROGRAM_VERSION", "7000"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info(TerminalName::Vte, Some("VTE"), Some("7000"), None, None), + "vte_term_program_info" + ); + assert_eq!( + terminal.user_agent_token(), + "VTE/7000", + "vte_term_program_user_agent" + ); + + let env = FakeEnvironment::new().with_var("VTE_VERSION", ""); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info(TerminalName::Vte, None, None, None, None), + "vte_empty_info" + ); + assert_eq!(terminal.user_agent_token(), "VTE", "vte_empty_user_agent"); +} + +#[test] +fn detects_windows_terminal() { + let env = FakeEnvironment::new().with_var("WT_SESSION", "1"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info(TerminalName::WindowsTerminal, None, None, None, None), + "wt_session_info" + ); + assert_eq!( + terminal.user_agent_token(), + "WindowsTerminal", + "wt_session_user_agent" + ); + + let env = FakeEnvironment::new() + .with_var("TERM_PROGRAM", "WindowsTerminal") + .with_var("TERM_PROGRAM_VERSION", "1.21"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info( + TerminalName::WindowsTerminal, + Some("WindowsTerminal"), + Some("1.21"), + None, + None, + ), + "windows_terminal_term_program_info" + ); + assert_eq!( + terminal.user_agent_token(), + "WindowsTerminal/1.21", + "windows_terminal_term_program_user_agent" + ); +} + +#[test] +fn detects_term_fallbacks() { + let env = FakeEnvironment::new().with_var("TERM", "xterm-256color"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info( + TerminalName::Unknown, + None, + None, + Some("xterm-256color"), + None, + ), + "term_fallback_info" + ); + assert_eq!( + terminal.user_agent_token(), + "xterm-256color", + "term_fallback_user_agent" + ); + + let env = FakeEnvironment::new().with_var("TERM", "dumb"); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info(TerminalName::Dumb, None, None, Some("dumb"), None), + "dumb_term_info" + ); + assert_eq!(terminal.user_agent_token(), "dumb", "dumb_term_user_agent"); + + let env = FakeEnvironment::new(); + let terminal = detect_terminal_info_from_env(&env); + assert_eq!( + terminal, + terminal_info(TerminalName::Unknown, None, None, None, None), + "unknown_info" + ); + assert_eq!(terminal.user_agent_token(), "unknown", "unknown_user_agent"); +} diff --git a/codex-rs/core/src/text_encoding.rs b/codex-rs/core/src/text_encoding.rs index fde44c4195..b70d8af54c 100644 --- a/codex-rs/core/src/text_encoding.rs +++ b/codex-rs/core/src/text_encoding.rs @@ -117,345 +117,5 @@ fn is_windows_1252_punct(byte: u8) -> bool { } #[cfg(test)] -mod tests { - use super::*; - use encoding_rs::BIG5; - use encoding_rs::EUC_KR; - use encoding_rs::GBK; - use encoding_rs::ISO_8859_2; - use encoding_rs::ISO_8859_3; - use encoding_rs::ISO_8859_4; - use encoding_rs::ISO_8859_5; - use encoding_rs::ISO_8859_6; - use encoding_rs::ISO_8859_7; - use encoding_rs::ISO_8859_8; - use encoding_rs::ISO_8859_10; - use encoding_rs::ISO_8859_13; - use encoding_rs::SHIFT_JIS; - use encoding_rs::WINDOWS_874; - use encoding_rs::WINDOWS_1250; - use encoding_rs::WINDOWS_1251; - use encoding_rs::WINDOWS_1253; - use encoding_rs::WINDOWS_1254; - use encoding_rs::WINDOWS_1255; - use encoding_rs::WINDOWS_1256; - use encoding_rs::WINDOWS_1257; - use encoding_rs::WINDOWS_1258; - use pretty_assertions::assert_eq; - - #[test] - fn test_utf8_passthrough() { - // Fast path: when UTF-8 is valid we should avoid copies and return as-is. - let utf8_text = "Hello, мир! 世界"; - let bytes = utf8_text.as_bytes(); - assert_eq!(bytes_to_string_smart(bytes), utf8_text); - } - - #[test] - fn test_cp1251_russian_text() { - // Cyrillic text emitted by PowerShell/WSL in CP1251 should decode cleanly. - let bytes = b"\xEF\xF0\xE8\xEC\xE5\xF0"; // "пример" encoded with Windows-1251 - assert_eq!(bytes_to_string_smart(bytes), "пример"); - } - - #[test] - fn test_cp1251_privet_word() { - // Regression: CP1251 words like "Привет" must not be mis-identified as Windows-1252. - let bytes = b"\xCF\xF0\xE8\xE2\xE5\xF2"; // "Привет" encoded with Windows-1251 - assert_eq!(bytes_to_string_smart(bytes), "Привет"); - } - - #[test] - fn test_koi8_r_privet_word() { - // KOI8-R output should decode to the original Cyrillic as well. - let bytes = b"\xF0\xD2\xC9\xD7\xC5\xD4"; // "Привет" encoded with KOI8-R - assert_eq!(bytes_to_string_smart(bytes), "Привет"); - } - - #[test] - fn test_cp866_russian_text() { - // Legacy consoles (cmd.exe) commonly emit CP866 bytes for Cyrillic content. - let bytes = b"\xAF\xE0\xA8\xAC\xA5\xE0"; // "пример" encoded with CP866 - assert_eq!(bytes_to_string_smart(bytes), "пример"); - } - - #[test] - fn test_cp866_uppercase_text() { - // Ensure the IBM866 heuristic still returns IBM866 for uppercase-only words. - let bytes = b"\x8F\x90\x88"; // "ПРИ" encoded with CP866 uppercase letters - assert_eq!(bytes_to_string_smart(bytes), "ПРИ"); - } - - #[test] - fn test_cp866_uppercase_followed_by_ascii() { - // Regression test: uppercase CP866 tokens next to ASCII text should not be treated as - // CP1252. - let bytes = b"\x8F\x90\x88 test"; // "ПРИ test" encoded with CP866 uppercase letters followed by ASCII - assert_eq!(bytes_to_string_smart(bytes), "ПРИ test"); - } - - #[test] - fn test_windows_1252_quotes() { - // Smart detection should map Windows-1252 punctuation into proper Unicode. - let bytes = b"\x93\x94test"; - assert_eq!(bytes_to_string_smart(bytes), "\u{201C}\u{201D}test"); - } - - #[test] - fn test_windows_1252_multiple_quotes() { - // Longer snippets of punctuation (e.g., “foo” – “bar”) should still flip to CP1252. - let bytes = b"\x93foo\x94 \x96 \x93bar\x94"; - assert_eq!( - bytes_to_string_smart(bytes), - "\u{201C}foo\u{201D} \u{2013} \u{201C}bar\u{201D}" - ); - } - - #[test] - fn test_windows_1252_privet_gibberish_is_preserved() { - // Windows-1252 cannot encode Cyrillic; if the input literally contains "ПÑ..." we should not "fix" it. - let bytes = "Привет".as_bytes(); - assert_eq!(bytes_to_string_smart(bytes), "Привет"); - } - - #[test] - fn test_iso8859_1_latin_text() { - // ISO-8859-1 (code page 28591) is the Latin segment used by LatArCyrHeb. - // encoding_rs unifies ISO-8859-1 with Windows-1252, so reuse that constant here. - let (encoded, _, had_errors) = WINDOWS_1252.encode("Hello"); - assert!(!had_errors, "failed to encode Latin sample"); - assert_eq!(bytes_to_string_smart(encoded.as_ref()), "Hello"); - } - - #[test] - fn test_iso8859_2_central_european_text() { - // ISO-8859-2 (code page 28592) covers additional Central European glyphs. - let (encoded, _, had_errors) = ISO_8859_2.encode("Příliš žluťoučký kůň"); - assert!(!had_errors, "failed to encode ISO-8859-2 sample"); - assert_eq!( - bytes_to_string_smart(encoded.as_ref()), - "Příliš žluťoučký kůň" - ); - } - - #[test] - fn test_iso8859_3_south_europe_text() { - // ISO-8859-3 (code page 28593) adds support for Maltese/Esperanto letters. - // chardetng rarely distinguishes ISO-8859-3 from neighboring Latin code pages, so we rely on - // an ASCII-only sample to ensure round-tripping still succeeds. - let (encoded, _, had_errors) = ISO_8859_3.encode("Esperanto and Maltese"); - assert!(!had_errors, "failed to encode ISO-8859-3 sample"); - assert_eq!( - bytes_to_string_smart(encoded.as_ref()), - "Esperanto and Maltese" - ); - } - - #[test] - fn test_iso8859_4_baltic_text() { - // ISO-8859-4 (code page 28594) targets the Baltic/Nordic repertoire. - let sample = "Šis ir rakstzīmju kodēšanas tests. Dažās valodās, kurās tiek \ - izmantotas latīņu valodas burti, lēmuma pieņemšanai mums ir nepieciešams \ - vairāk ieguldījuma."; - let (encoded, _, had_errors) = ISO_8859_4.encode(sample); - assert!(!had_errors, "failed to encode ISO-8859-4 sample"); - assert_eq!(bytes_to_string_smart(encoded.as_ref()), sample); - } - - #[test] - fn test_iso8859_5_cyrillic_text() { - // ISO-8859-5 (code page 28595) covers the Cyrillic portion. - let (encoded, _, had_errors) = ISO_8859_5.encode("Привет"); - assert!(!had_errors, "failed to encode Cyrillic sample"); - assert_eq!(bytes_to_string_smart(encoded.as_ref()), "Привет"); - } - - #[test] - fn test_iso8859_6_arabic_text() { - // ISO-8859-6 (code page 28596) covers the Arabic glyphs. - let (encoded, _, had_errors) = ISO_8859_6.encode("مرحبا"); - assert!(!had_errors, "failed to encode Arabic sample"); - assert_eq!(bytes_to_string_smart(encoded.as_ref()), "مرحبا"); - } - - #[test] - fn test_iso8859_7_greek_text() { - // ISO-8859-7 (code page 28597) is used for Greek locales. - let (encoded, _, had_errors) = ISO_8859_7.encode("Καλημέρα"); - assert!(!had_errors, "failed to encode ISO-8859-7 sample"); - assert_eq!(bytes_to_string_smart(encoded.as_ref()), "Καλημέρα"); - } - - #[test] - fn test_iso8859_8_hebrew_text() { - // ISO-8859-8 (code page 28598) covers the Hebrew glyphs. - let (encoded, _, had_errors) = ISO_8859_8.encode("שלום"); - assert!(!had_errors, "failed to encode Hebrew sample"); - assert_eq!(bytes_to_string_smart(encoded.as_ref()), "שלום"); - } - - #[test] - fn test_iso8859_9_turkish_text() { - // ISO-8859-9 (code page 28599) mirrors Latin-1 but inserts Turkish letters. - // encoding_rs exposes the equivalent Windows-1254 mapping. - let (encoded, _, had_errors) = WINDOWS_1254.encode("İstanbul"); - assert!(!had_errors, "failed to encode ISO-8859-9 sample"); - assert_eq!(bytes_to_string_smart(encoded.as_ref()), "İstanbul"); - } - - #[test] - fn test_iso8859_10_nordic_text() { - // ISO-8859-10 (code page 28600) adds additional Nordic letters. - let sample = "Þetta er prófun fyrir Ægir og Øystein."; - let (encoded, _, had_errors) = ISO_8859_10.encode(sample); - assert!(!had_errors, "failed to encode ISO-8859-10 sample"); - assert_eq!(bytes_to_string_smart(encoded.as_ref()), sample); - } - - #[test] - fn test_iso8859_11_thai_text() { - // ISO-8859-11 (code page 28601) mirrors TIS-620 / Windows-874 for Thai. - let sample = "ภาษาไทยสำหรับการทดสอบ ISO-8859-11"; - // encoding_rs exposes the equivalent Windows-874 encoding, so use that constant. - let (encoded, _, had_errors) = WINDOWS_874.encode(sample); - assert!(!had_errors, "failed to encode ISO-8859-11 sample"); - assert_eq!(bytes_to_string_smart(encoded.as_ref()), sample); - } - - // ISO-8859-12 was never standardized, and encodings 14–16 cannot be distinguished reliably - // without the heuristics we removed (chardetng generally reports neighboring Latin pages), so - // we intentionally omit coverage for those slots until the detector can identify them. - - #[test] - fn test_iso8859_13_baltic_text() { - // ISO-8859-13 (code page 28603) is common across Baltic languages. - let (encoded, _, had_errors) = ISO_8859_13.encode("Sveiki"); - assert!(!had_errors, "failed to encode ISO-8859-13 sample"); - assert_eq!(bytes_to_string_smart(encoded.as_ref()), "Sveiki"); - } - - #[test] - fn test_windows_1250_central_european_text() { - let (encoded, _, had_errors) = WINDOWS_1250.encode("Příliš žluťoučký kůň"); - assert!(!had_errors, "failed to encode Central European sample"); - assert_eq!( - bytes_to_string_smart(encoded.as_ref()), - "Příliš žluťoučký kůň" - ); - } - - #[test] - fn test_windows_1251_encoded_text() { - let (encoded, _, had_errors) = WINDOWS_1251.encode("Привет из Windows-1251"); - assert!(!had_errors, "failed to encode Windows-1251 sample"); - assert_eq!( - bytes_to_string_smart(encoded.as_ref()), - "Привет из Windows-1251" - ); - } - - #[test] - fn test_windows_1253_greek_text() { - let (encoded, _, had_errors) = WINDOWS_1253.encode("Γειά σου"); - assert!(!had_errors, "failed to encode Greek sample"); - assert_eq!(bytes_to_string_smart(encoded.as_ref()), "Γειά σου"); - } - - #[test] - fn test_windows_1254_turkish_text() { - let (encoded, _, had_errors) = WINDOWS_1254.encode("İstanbul"); - assert!(!had_errors, "failed to encode Turkish sample"); - assert_eq!(bytes_to_string_smart(encoded.as_ref()), "İstanbul"); - } - - #[test] - fn test_windows_1255_hebrew_text() { - let (encoded, _, had_errors) = WINDOWS_1255.encode("שלום"); - assert!(!had_errors, "failed to encode Windows-1255 Hebrew sample"); - assert_eq!(bytes_to_string_smart(encoded.as_ref()), "שלום"); - } - - #[test] - fn test_windows_1256_arabic_text() { - let (encoded, _, had_errors) = WINDOWS_1256.encode("مرحبا"); - assert!(!had_errors, "failed to encode Windows-1256 Arabic sample"); - assert_eq!(bytes_to_string_smart(encoded.as_ref()), "مرحبا"); - } - - #[test] - fn test_windows_1257_baltic_text() { - let (encoded, _, had_errors) = WINDOWS_1257.encode("Pērkons"); - assert!(!had_errors, "failed to encode Baltic sample"); - assert_eq!(bytes_to_string_smart(encoded.as_ref()), "Pērkons"); - } - - #[test] - fn test_windows_1258_vietnamese_text() { - let (encoded, _, had_errors) = WINDOWS_1258.encode("Xin chào"); - assert!(!had_errors, "failed to encode Vietnamese sample"); - assert_eq!(bytes_to_string_smart(encoded.as_ref()), "Xin chào"); - } - - #[test] - fn test_windows_874_thai_text() { - let (encoded, _, had_errors) = WINDOWS_874.encode("สวัสดีครับ นี่คือการทดสอบภาษาไทย"); - assert!(!had_errors, "failed to encode Thai sample"); - assert_eq!( - bytes_to_string_smart(encoded.as_ref()), - "สวัสดีครับ นี่คือการทดสอบภาษาไทย" - ); - } - - #[test] - fn test_windows_932_shift_jis_text() { - let (encoded, _, had_errors) = SHIFT_JIS.encode("こんにちは"); - assert!(!had_errors, "failed to encode Shift-JIS sample"); - assert_eq!(bytes_to_string_smart(encoded.as_ref()), "こんにちは"); - } - - #[test] - fn test_windows_936_gbk_text() { - let (encoded, _, had_errors) = GBK.encode("你好,世界,这是一个测试"); - assert!(!had_errors, "failed to encode GBK sample"); - assert_eq!( - bytes_to_string_smart(encoded.as_ref()), - "你好,世界,这是一个测试" - ); - } - - #[test] - fn test_windows_949_korean_text() { - let (encoded, _, had_errors) = EUC_KR.encode("안녕하세요"); - assert!(!had_errors, "failed to encode Korean sample"); - assert_eq!(bytes_to_string_smart(encoded.as_ref()), "안녕하세요"); - } - - #[test] - fn test_windows_950_big5_text() { - let (encoded, _, had_errors) = BIG5.encode("繁體"); - assert!(!had_errors, "failed to encode Big5 sample"); - assert_eq!(bytes_to_string_smart(encoded.as_ref()), "繁體"); - } - - #[test] - fn test_latin1_cafe() { - // Latin-1 bytes remain common in Western-European locales; decode them directly. - let bytes = b"caf\xE9"; // codespell:ignore caf - assert_eq!(bytes_to_string_smart(bytes), "café"); - } - - #[test] - fn test_preserves_ansi_sequences() { - // ANSI escape sequences should survive regardless of the detected encoding. - let bytes = b"\x1b[31mred\x1b[0m"; - assert_eq!(bytes_to_string_smart(bytes), "\x1b[31mred\x1b[0m"); - } - - #[test] - fn test_fallback_to_lossy() { - // Completely invalid sequences fall back to the old lossy behavior. - let invalid_bytes = [0xFF, 0xFE, 0xFD]; - let result = bytes_to_string_smart(&invalid_bytes); - assert_eq!(result, String::from_utf8_lossy(&invalid_bytes)); - } -} +#[path = "text_encoding_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/text_encoding_tests.rs b/codex-rs/core/src/text_encoding_tests.rs new file mode 100644 index 0000000000..6368f38be4 --- /dev/null +++ b/codex-rs/core/src/text_encoding_tests.rs @@ -0,0 +1,340 @@ +use super::*; +use encoding_rs::BIG5; +use encoding_rs::EUC_KR; +use encoding_rs::GBK; +use encoding_rs::ISO_8859_2; +use encoding_rs::ISO_8859_3; +use encoding_rs::ISO_8859_4; +use encoding_rs::ISO_8859_5; +use encoding_rs::ISO_8859_6; +use encoding_rs::ISO_8859_7; +use encoding_rs::ISO_8859_8; +use encoding_rs::ISO_8859_10; +use encoding_rs::ISO_8859_13; +use encoding_rs::SHIFT_JIS; +use encoding_rs::WINDOWS_874; +use encoding_rs::WINDOWS_1250; +use encoding_rs::WINDOWS_1251; +use encoding_rs::WINDOWS_1253; +use encoding_rs::WINDOWS_1254; +use encoding_rs::WINDOWS_1255; +use encoding_rs::WINDOWS_1256; +use encoding_rs::WINDOWS_1257; +use encoding_rs::WINDOWS_1258; +use pretty_assertions::assert_eq; + +#[test] +fn test_utf8_passthrough() { + // Fast path: when UTF-8 is valid we should avoid copies and return as-is. + let utf8_text = "Hello, мир! 世界"; + let bytes = utf8_text.as_bytes(); + assert_eq!(bytes_to_string_smart(bytes), utf8_text); +} + +#[test] +fn test_cp1251_russian_text() { + // Cyrillic text emitted by PowerShell/WSL in CP1251 should decode cleanly. + let bytes = b"\xEF\xF0\xE8\xEC\xE5\xF0"; // "пример" encoded with Windows-1251 + assert_eq!(bytes_to_string_smart(bytes), "пример"); +} + +#[test] +fn test_cp1251_privet_word() { + // Regression: CP1251 words like "Привет" must not be mis-identified as Windows-1252. + let bytes = b"\xCF\xF0\xE8\xE2\xE5\xF2"; // "Привет" encoded with Windows-1251 + assert_eq!(bytes_to_string_smart(bytes), "Привет"); +} + +#[test] +fn test_koi8_r_privet_word() { + // KOI8-R output should decode to the original Cyrillic as well. + let bytes = b"\xF0\xD2\xC9\xD7\xC5\xD4"; // "Привет" encoded with KOI8-R + assert_eq!(bytes_to_string_smart(bytes), "Привет"); +} + +#[test] +fn test_cp866_russian_text() { + // Legacy consoles (cmd.exe) commonly emit CP866 bytes for Cyrillic content. + let bytes = b"\xAF\xE0\xA8\xAC\xA5\xE0"; // "пример" encoded with CP866 + assert_eq!(bytes_to_string_smart(bytes), "пример"); +} + +#[test] +fn test_cp866_uppercase_text() { + // Ensure the IBM866 heuristic still returns IBM866 for uppercase-only words. + let bytes = b"\x8F\x90\x88"; // "ПРИ" encoded with CP866 uppercase letters + assert_eq!(bytes_to_string_smart(bytes), "ПРИ"); +} + +#[test] +fn test_cp866_uppercase_followed_by_ascii() { + // Regression test: uppercase CP866 tokens next to ASCII text should not be treated as + // CP1252. + let bytes = b"\x8F\x90\x88 test"; // "ПРИ test" encoded with CP866 uppercase letters followed by ASCII + assert_eq!(bytes_to_string_smart(bytes), "ПРИ test"); +} + +#[test] +fn test_windows_1252_quotes() { + // Smart detection should map Windows-1252 punctuation into proper Unicode. + let bytes = b"\x93\x94test"; + assert_eq!(bytes_to_string_smart(bytes), "\u{201C}\u{201D}test"); +} + +#[test] +fn test_windows_1252_multiple_quotes() { + // Longer snippets of punctuation (e.g., “foo” – “bar”) should still flip to CP1252. + let bytes = b"\x93foo\x94 \x96 \x93bar\x94"; + assert_eq!( + bytes_to_string_smart(bytes), + "\u{201C}foo\u{201D} \u{2013} \u{201C}bar\u{201D}" + ); +} + +#[test] +fn test_windows_1252_privet_gibberish_is_preserved() { + // Windows-1252 cannot encode Cyrillic; if the input literally contains "ПÑ..." we should not "fix" it. + let bytes = "Привет".as_bytes(); + assert_eq!(bytes_to_string_smart(bytes), "Привет"); +} + +#[test] +fn test_iso8859_1_latin_text() { + // ISO-8859-1 (code page 28591) is the Latin segment used by LatArCyrHeb. + // encoding_rs unifies ISO-8859-1 with Windows-1252, so reuse that constant here. + let (encoded, _, had_errors) = WINDOWS_1252.encode("Hello"); + assert!(!had_errors, "failed to encode Latin sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "Hello"); +} + +#[test] +fn test_iso8859_2_central_european_text() { + // ISO-8859-2 (code page 28592) covers additional Central European glyphs. + let (encoded, _, had_errors) = ISO_8859_2.encode("Příliš žluťoučký kůň"); + assert!(!had_errors, "failed to encode ISO-8859-2 sample"); + assert_eq!( + bytes_to_string_smart(encoded.as_ref()), + "Příliš žluťoučký kůň" + ); +} + +#[test] +fn test_iso8859_3_south_europe_text() { + // ISO-8859-3 (code page 28593) adds support for Maltese/Esperanto letters. + // chardetng rarely distinguishes ISO-8859-3 from neighboring Latin code pages, so we rely on + // an ASCII-only sample to ensure round-tripping still succeeds. + let (encoded, _, had_errors) = ISO_8859_3.encode("Esperanto and Maltese"); + assert!(!had_errors, "failed to encode ISO-8859-3 sample"); + assert_eq!( + bytes_to_string_smart(encoded.as_ref()), + "Esperanto and Maltese" + ); +} + +#[test] +fn test_iso8859_4_baltic_text() { + // ISO-8859-4 (code page 28594) targets the Baltic/Nordic repertoire. + let sample = "Šis ir rakstzīmju kodēšanas tests. Dažās valodās, kurās tiek \ + izmantotas latīņu valodas burti, lēmuma pieņemšanai mums ir nepieciešams \ + vairāk ieguldījuma."; + let (encoded, _, had_errors) = ISO_8859_4.encode(sample); + assert!(!had_errors, "failed to encode ISO-8859-4 sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), sample); +} + +#[test] +fn test_iso8859_5_cyrillic_text() { + // ISO-8859-5 (code page 28595) covers the Cyrillic portion. + let (encoded, _, had_errors) = ISO_8859_5.encode("Привет"); + assert!(!had_errors, "failed to encode Cyrillic sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "Привет"); +} + +#[test] +fn test_iso8859_6_arabic_text() { + // ISO-8859-6 (code page 28596) covers the Arabic glyphs. + let (encoded, _, had_errors) = ISO_8859_6.encode("مرحبا"); + assert!(!had_errors, "failed to encode Arabic sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "مرحبا"); +} + +#[test] +fn test_iso8859_7_greek_text() { + // ISO-8859-7 (code page 28597) is used for Greek locales. + let (encoded, _, had_errors) = ISO_8859_7.encode("Καλημέρα"); + assert!(!had_errors, "failed to encode ISO-8859-7 sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "Καλημέρα"); +} + +#[test] +fn test_iso8859_8_hebrew_text() { + // ISO-8859-8 (code page 28598) covers the Hebrew glyphs. + let (encoded, _, had_errors) = ISO_8859_8.encode("שלום"); + assert!(!had_errors, "failed to encode Hebrew sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "שלום"); +} + +#[test] +fn test_iso8859_9_turkish_text() { + // ISO-8859-9 (code page 28599) mirrors Latin-1 but inserts Turkish letters. + // encoding_rs exposes the equivalent Windows-1254 mapping. + let (encoded, _, had_errors) = WINDOWS_1254.encode("İstanbul"); + assert!(!had_errors, "failed to encode ISO-8859-9 sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "İstanbul"); +} + +#[test] +fn test_iso8859_10_nordic_text() { + // ISO-8859-10 (code page 28600) adds additional Nordic letters. + let sample = "Þetta er prófun fyrir Ægir og Øystein."; + let (encoded, _, had_errors) = ISO_8859_10.encode(sample); + assert!(!had_errors, "failed to encode ISO-8859-10 sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), sample); +} + +#[test] +fn test_iso8859_11_thai_text() { + // ISO-8859-11 (code page 28601) mirrors TIS-620 / Windows-874 for Thai. + let sample = "ภาษาไทยสำหรับการทดสอบ ISO-8859-11"; + // encoding_rs exposes the equivalent Windows-874 encoding, so use that constant. + let (encoded, _, had_errors) = WINDOWS_874.encode(sample); + assert!(!had_errors, "failed to encode ISO-8859-11 sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), sample); +} + +// ISO-8859-12 was never standardized, and encodings 14–16 cannot be distinguished reliably +// without the heuristics we removed (chardetng generally reports neighboring Latin pages), so +// we intentionally omit coverage for those slots until the detector can identify them. + +#[test] +fn test_iso8859_13_baltic_text() { + // ISO-8859-13 (code page 28603) is common across Baltic languages. + let (encoded, _, had_errors) = ISO_8859_13.encode("Sveiki"); + assert!(!had_errors, "failed to encode ISO-8859-13 sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "Sveiki"); +} + +#[test] +fn test_windows_1250_central_european_text() { + let (encoded, _, had_errors) = WINDOWS_1250.encode("Příliš žluťoučký kůň"); + assert!(!had_errors, "failed to encode Central European sample"); + assert_eq!( + bytes_to_string_smart(encoded.as_ref()), + "Příliš žluťoučký kůň" + ); +} + +#[test] +fn test_windows_1251_encoded_text() { + let (encoded, _, had_errors) = WINDOWS_1251.encode("Привет из Windows-1251"); + assert!(!had_errors, "failed to encode Windows-1251 sample"); + assert_eq!( + bytes_to_string_smart(encoded.as_ref()), + "Привет из Windows-1251" + ); +} + +#[test] +fn test_windows_1253_greek_text() { + let (encoded, _, had_errors) = WINDOWS_1253.encode("Γειά σου"); + assert!(!had_errors, "failed to encode Greek sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "Γειά σου"); +} + +#[test] +fn test_windows_1254_turkish_text() { + let (encoded, _, had_errors) = WINDOWS_1254.encode("İstanbul"); + assert!(!had_errors, "failed to encode Turkish sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "İstanbul"); +} + +#[test] +fn test_windows_1255_hebrew_text() { + let (encoded, _, had_errors) = WINDOWS_1255.encode("שלום"); + assert!(!had_errors, "failed to encode Windows-1255 Hebrew sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "שלום"); +} + +#[test] +fn test_windows_1256_arabic_text() { + let (encoded, _, had_errors) = WINDOWS_1256.encode("مرحبا"); + assert!(!had_errors, "failed to encode Windows-1256 Arabic sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "مرحبا"); +} + +#[test] +fn test_windows_1257_baltic_text() { + let (encoded, _, had_errors) = WINDOWS_1257.encode("Pērkons"); + assert!(!had_errors, "failed to encode Baltic sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "Pērkons"); +} + +#[test] +fn test_windows_1258_vietnamese_text() { + let (encoded, _, had_errors) = WINDOWS_1258.encode("Xin chào"); + assert!(!had_errors, "failed to encode Vietnamese sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "Xin chào"); +} + +#[test] +fn test_windows_874_thai_text() { + let (encoded, _, had_errors) = WINDOWS_874.encode("สวัสดีครับ นี่คือการทดสอบภาษาไทย"); + assert!(!had_errors, "failed to encode Thai sample"); + assert_eq!( + bytes_to_string_smart(encoded.as_ref()), + "สวัสดีครับ นี่คือการทดสอบภาษาไทย" + ); +} + +#[test] +fn test_windows_932_shift_jis_text() { + let (encoded, _, had_errors) = SHIFT_JIS.encode("こんにちは"); + assert!(!had_errors, "failed to encode Shift-JIS sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "こんにちは"); +} + +#[test] +fn test_windows_936_gbk_text() { + let (encoded, _, had_errors) = GBK.encode("你好,世界,这是一个测试"); + assert!(!had_errors, "failed to encode GBK sample"); + assert_eq!( + bytes_to_string_smart(encoded.as_ref()), + "你好,世界,这是一个测试" + ); +} + +#[test] +fn test_windows_949_korean_text() { + let (encoded, _, had_errors) = EUC_KR.encode("안녕하세요"); + assert!(!had_errors, "failed to encode Korean sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "안녕하세요"); +} + +#[test] +fn test_windows_950_big5_text() { + let (encoded, _, had_errors) = BIG5.encode("繁體"); + assert!(!had_errors, "failed to encode Big5 sample"); + assert_eq!(bytes_to_string_smart(encoded.as_ref()), "繁體"); +} + +#[test] +fn test_latin1_cafe() { + // Latin-1 bytes remain common in Western-European locales; decode them directly. + let bytes = b"caf\xE9"; // codespell:ignore caf + assert_eq!(bytes_to_string_smart(bytes), "café"); +} + +#[test] +fn test_preserves_ansi_sequences() { + // ANSI escape sequences should survive regardless of the detected encoding. + let bytes = b"\x1b[31mred\x1b[0m"; + assert_eq!(bytes_to_string_smart(bytes), "\x1b[31mred\x1b[0m"); +} + +#[test] +fn test_fallback_to_lossy() { + // Completely invalid sequences fall back to the old lossy behavior. + let invalid_bytes = [0xFF, 0xFE, 0xFD]; + let result = bytes_to_string_smart(&invalid_bytes); + assert_eq!(result, String::from_utf8_lossy(&invalid_bytes)); +} diff --git a/codex-rs/core/src/thread_manager.rs b/codex-rs/core/src/thread_manager.rs index 4053885027..3ed8e8f0b3 100644 --- a/codex-rs/core/src/thread_manager.rs +++ b/codex-rs/core/src/thread_manager.rs @@ -750,157 +750,5 @@ fn truncate_before_nth_user_message(history: InitialHistory, n: usize) -> Initia } #[cfg(test)] -mod tests { - use super::*; - use crate::codex::make_session_and_context; - use crate::config::test_config; - use assert_matches::assert_matches; - use codex_protocol::models::ContentItem; - use codex_protocol::models::ReasoningItemReasoningSummary; - use codex_protocol::models::ResponseItem; - use pretty_assertions::assert_eq; - use std::time::Duration; - use tempfile::tempdir; - - fn user_msg(text: &str) -> ResponseItem { - ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::OutputText { - text: text.to_string(), - }], - end_turn: None, - phase: None, - } - } - fn assistant_msg(text: &str) -> ResponseItem { - ResponseItem::Message { - id: None, - role: "assistant".to_string(), - content: vec![ContentItem::OutputText { - text: text.to_string(), - }], - end_turn: None, - phase: None, - } - } - - #[test] - fn drops_from_last_user_only() { - let items = [ - user_msg("u1"), - assistant_msg("a1"), - assistant_msg("a2"), - user_msg("u2"), - assistant_msg("a3"), - ResponseItem::Reasoning { - id: "r1".to_string(), - summary: vec![ReasoningItemReasoningSummary::SummaryText { - text: "s".to_string(), - }], - content: None, - encrypted_content: None, - }, - ResponseItem::FunctionCall { - id: None, - call_id: "c1".to_string(), - name: "tool".to_string(), - namespace: None, - arguments: "{}".to_string(), - }, - assistant_msg("a4"), - ]; - - let initial: Vec = items - .iter() - .cloned() - .map(RolloutItem::ResponseItem) - .collect(); - let truncated = truncate_before_nth_user_message(InitialHistory::Forked(initial), 1); - let got_items = truncated.get_rollout_items(); - let expected_items = vec![ - RolloutItem::ResponseItem(items[0].clone()), - RolloutItem::ResponseItem(items[1].clone()), - RolloutItem::ResponseItem(items[2].clone()), - ]; - assert_eq!( - serde_json::to_value(&got_items).unwrap(), - serde_json::to_value(&expected_items).unwrap() - ); - - let initial2: Vec = items - .iter() - .cloned() - .map(RolloutItem::ResponseItem) - .collect(); - let truncated2 = truncate_before_nth_user_message(InitialHistory::Forked(initial2), 2); - assert_matches!(truncated2, InitialHistory::New); - } - - #[tokio::test] - async fn ignores_session_prefix_messages_when_truncating() { - let (session, turn_context) = make_session_and_context().await; - let mut items = session.build_initial_context(&turn_context).await; - items.push(user_msg("feature request")); - items.push(assistant_msg("ack")); - items.push(user_msg("second question")); - items.push(assistant_msg("answer")); - - let rollout_items: Vec = items - .iter() - .cloned() - .map(RolloutItem::ResponseItem) - .collect(); - - let truncated = truncate_before_nth_user_message(InitialHistory::Forked(rollout_items), 1); - let got_items = truncated.get_rollout_items(); - - let expected: Vec = vec![ - RolloutItem::ResponseItem(items[0].clone()), - RolloutItem::ResponseItem(items[1].clone()), - RolloutItem::ResponseItem(items[2].clone()), - RolloutItem::ResponseItem(items[3].clone()), - ]; - - assert_eq!( - serde_json::to_value(&got_items).unwrap(), - serde_json::to_value(&expected).unwrap() - ); - } - - #[tokio::test] - async fn shutdown_all_threads_bounded_submits_shutdown_to_every_thread() { - let temp_dir = tempdir().expect("tempdir"); - let mut config = test_config(); - config.codex_home = temp_dir.path().join("codex-home"); - config.cwd = config.codex_home.clone(); - std::fs::create_dir_all(&config.codex_home).expect("create codex home"); - - let manager = ThreadManager::with_models_provider_and_home_for_tests( - CodexAuth::from_api_key("dummy"), - config.model_provider.clone(), - config.codex_home.clone(), - ); - let thread_1 = manager - .start_thread(config.clone()) - .await - .expect("start first thread") - .thread_id; - let thread_2 = manager - .start_thread(config) - .await - .expect("start second thread") - .thread_id; - - let report = manager - .shutdown_all_threads_bounded(Duration::from_secs(10)) - .await; - - let mut expected_completed = vec![thread_1, thread_2]; - expected_completed.sort_by_key(std::string::ToString::to_string); - assert_eq!(report.completed, expected_completed); - assert!(report.submit_failed.is_empty()); - assert!(report.timed_out.is_empty()); - assert!(manager.list_thread_ids().await.is_empty()); - } -} +#[path = "thread_manager_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/thread_manager_tests.rs b/codex-rs/core/src/thread_manager_tests.rs new file mode 100644 index 0000000000..0172f46a21 --- /dev/null +++ b/codex-rs/core/src/thread_manager_tests.rs @@ -0,0 +1,152 @@ +use super::*; +use crate::codex::make_session_and_context; +use crate::config::test_config; +use assert_matches::assert_matches; +use codex_protocol::models::ContentItem; +use codex_protocol::models::ReasoningItemReasoningSummary; +use codex_protocol::models::ResponseItem; +use pretty_assertions::assert_eq; +use std::time::Duration; +use tempfile::tempdir; + +fn user_msg(text: &str) -> ResponseItem { + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::OutputText { + text: text.to_string(), + }], + end_turn: None, + phase: None, + } +} +fn assistant_msg(text: &str) -> ResponseItem { + ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: text.to_string(), + }], + end_turn: None, + phase: None, + } +} + +#[test] +fn drops_from_last_user_only() { + let items = [ + user_msg("u1"), + assistant_msg("a1"), + assistant_msg("a2"), + user_msg("u2"), + assistant_msg("a3"), + ResponseItem::Reasoning { + id: "r1".to_string(), + summary: vec![ReasoningItemReasoningSummary::SummaryText { + text: "s".to_string(), + }], + content: None, + encrypted_content: None, + }, + ResponseItem::FunctionCall { + id: None, + call_id: "c1".to_string(), + name: "tool".to_string(), + namespace: None, + arguments: "{}".to_string(), + }, + assistant_msg("a4"), + ]; + + let initial: Vec = items + .iter() + .cloned() + .map(RolloutItem::ResponseItem) + .collect(); + let truncated = truncate_before_nth_user_message(InitialHistory::Forked(initial), 1); + let got_items = truncated.get_rollout_items(); + let expected_items = vec![ + RolloutItem::ResponseItem(items[0].clone()), + RolloutItem::ResponseItem(items[1].clone()), + RolloutItem::ResponseItem(items[2].clone()), + ]; + assert_eq!( + serde_json::to_value(&got_items).unwrap(), + serde_json::to_value(&expected_items).unwrap() + ); + + let initial2: Vec = items + .iter() + .cloned() + .map(RolloutItem::ResponseItem) + .collect(); + let truncated2 = truncate_before_nth_user_message(InitialHistory::Forked(initial2), 2); + assert_matches!(truncated2, InitialHistory::New); +} + +#[tokio::test] +async fn ignores_session_prefix_messages_when_truncating() { + let (session, turn_context) = make_session_and_context().await; + let mut items = session.build_initial_context(&turn_context).await; + items.push(user_msg("feature request")); + items.push(assistant_msg("ack")); + items.push(user_msg("second question")); + items.push(assistant_msg("answer")); + + let rollout_items: Vec = items + .iter() + .cloned() + .map(RolloutItem::ResponseItem) + .collect(); + + let truncated = truncate_before_nth_user_message(InitialHistory::Forked(rollout_items), 1); + let got_items = truncated.get_rollout_items(); + + let expected: Vec = vec![ + RolloutItem::ResponseItem(items[0].clone()), + RolloutItem::ResponseItem(items[1].clone()), + RolloutItem::ResponseItem(items[2].clone()), + RolloutItem::ResponseItem(items[3].clone()), + ]; + + assert_eq!( + serde_json::to_value(&got_items).unwrap(), + serde_json::to_value(&expected).unwrap() + ); +} + +#[tokio::test] +async fn shutdown_all_threads_bounded_submits_shutdown_to_every_thread() { + let temp_dir = tempdir().expect("tempdir"); + let mut config = test_config(); + config.codex_home = temp_dir.path().join("codex-home"); + config.cwd = config.codex_home.clone(); + std::fs::create_dir_all(&config.codex_home).expect("create codex home"); + + let manager = ThreadManager::with_models_provider_and_home_for_tests( + CodexAuth::from_api_key("dummy"), + config.model_provider.clone(), + config.codex_home.clone(), + ); + let thread_1 = manager + .start_thread(config.clone()) + .await + .expect("start first thread") + .thread_id; + let thread_2 = manager + .start_thread(config) + .await + .expect("start second thread") + .thread_id; + + let report = manager + .shutdown_all_threads_bounded(Duration::from_secs(10)) + .await; + + let mut expected_completed = vec![thread_1, thread_2]; + expected_completed.sort_by_key(std::string::ToString::to_string); + assert_eq!(report.completed, expected_completed); + assert!(report.submit_failed.is_empty()); + assert!(report.timed_out.is_empty()); + assert!(manager.list_thread_ids().await.is_empty()); +} diff --git a/codex-rs/core/src/token_data.rs b/codex-rs/core/src/token_data.rs index 85babd8511..5952d5940d 100644 --- a/codex-rs/core/src/token_data.rs +++ b/codex-rs/core/src/token_data.rs @@ -175,114 +175,5 @@ where } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - use serde::Serialize; - - #[test] - fn id_token_info_parses_email_and_plan() { - #[derive(Serialize)] - struct Header { - alg: &'static str, - typ: &'static str, - } - let header = Header { - alg: "none", - typ: "JWT", - }; - let payload = serde_json::json!({ - "email": "user@example.com", - "https://api.openai.com/auth": { - "chatgpt_plan_type": "pro" - } - }); - - fn b64url_no_pad(bytes: &[u8]) -> String { - base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes) - } - - let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap()); - let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap()); - let signature_b64 = b64url_no_pad(b"sig"); - let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}"); - - let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse"); - assert_eq!(info.email.as_deref(), Some("user@example.com")); - assert_eq!(info.get_chatgpt_plan_type().as_deref(), Some("Pro")); - } - - #[test] - fn id_token_info_parses_go_plan() { - #[derive(Serialize)] - struct Header { - alg: &'static str, - typ: &'static str, - } - let header = Header { - alg: "none", - typ: "JWT", - }; - let payload = serde_json::json!({ - "email": "user@example.com", - "https://api.openai.com/auth": { - "chatgpt_plan_type": "go" - } - }); - - fn b64url_no_pad(bytes: &[u8]) -> String { - base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes) - } - - let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap()); - let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap()); - let signature_b64 = b64url_no_pad(b"sig"); - let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}"); - - let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse"); - assert_eq!(info.email.as_deref(), Some("user@example.com")); - assert_eq!(info.get_chatgpt_plan_type().as_deref(), Some("Go")); - } - - #[test] - fn id_token_info_handles_missing_fields() { - #[derive(Serialize)] - struct Header { - alg: &'static str, - typ: &'static str, - } - let header = Header { - alg: "none", - typ: "JWT", - }; - let payload = serde_json::json!({ "sub": "123" }); - - fn b64url_no_pad(bytes: &[u8]) -> String { - base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes) - } - - let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap()); - let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap()); - let signature_b64 = b64url_no_pad(b"sig"); - let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}"); - - let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse"); - assert!(info.email.is_none()); - assert!(info.get_chatgpt_plan_type().is_none()); - } - - #[test] - fn workspace_account_detection_matches_workspace_plans() { - let workspace = IdTokenInfo { - chatgpt_plan_type: Some(PlanType::Known(KnownPlan::Business)), - ..IdTokenInfo::default() - }; - assert_eq!(workspace.is_workspace_account(), true); - - let personal = IdTokenInfo { - chatgpt_plan_type: Some(PlanType::Known(KnownPlan::Pro)), - ..IdTokenInfo::default() - }; - assert_eq!(personal.is_workspace_account(), false); - } -} +#[path = "token_data_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/token_data_tests.rs b/codex-rs/core/src/token_data_tests.rs new file mode 100644 index 0000000000..e599379c18 --- /dev/null +++ b/codex-rs/core/src/token_data_tests.rs @@ -0,0 +1,109 @@ +use super::*; +use pretty_assertions::assert_eq; +use serde::Serialize; + +#[test] +fn id_token_info_parses_email_and_plan() { + #[derive(Serialize)] + struct Header { + alg: &'static str, + typ: &'static str, + } + let header = Header { + alg: "none", + typ: "JWT", + }; + let payload = serde_json::json!({ + "email": "user@example.com", + "https://api.openai.com/auth": { + "chatgpt_plan_type": "pro" + } + }); + + fn b64url_no_pad(bytes: &[u8]) -> String { + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes) + } + + let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap()); + let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap()); + let signature_b64 = b64url_no_pad(b"sig"); + let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}"); + + let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse"); + assert_eq!(info.email.as_deref(), Some("user@example.com")); + assert_eq!(info.get_chatgpt_plan_type().as_deref(), Some("Pro")); +} + +#[test] +fn id_token_info_parses_go_plan() { + #[derive(Serialize)] + struct Header { + alg: &'static str, + typ: &'static str, + } + let header = Header { + alg: "none", + typ: "JWT", + }; + let payload = serde_json::json!({ + "email": "user@example.com", + "https://api.openai.com/auth": { + "chatgpt_plan_type": "go" + } + }); + + fn b64url_no_pad(bytes: &[u8]) -> String { + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes) + } + + let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap()); + let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap()); + let signature_b64 = b64url_no_pad(b"sig"); + let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}"); + + let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse"); + assert_eq!(info.email.as_deref(), Some("user@example.com")); + assert_eq!(info.get_chatgpt_plan_type().as_deref(), Some("Go")); +} + +#[test] +fn id_token_info_handles_missing_fields() { + #[derive(Serialize)] + struct Header { + alg: &'static str, + typ: &'static str, + } + let header = Header { + alg: "none", + typ: "JWT", + }; + let payload = serde_json::json!({ "sub": "123" }); + + fn b64url_no_pad(bytes: &[u8]) -> String { + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes) + } + + let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap()); + let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap()); + let signature_b64 = b64url_no_pad(b"sig"); + let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}"); + + let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse"); + assert!(info.email.is_none()); + assert!(info.get_chatgpt_plan_type().is_none()); +} + +#[test] +fn workspace_account_detection_matches_workspace_plans() { + let workspace = IdTokenInfo { + chatgpt_plan_type: Some(PlanType::Known(KnownPlan::Business)), + ..IdTokenInfo::default() + }; + assert_eq!(workspace.is_workspace_account(), true); + + let personal = IdTokenInfo { + chatgpt_plan_type: Some(PlanType::Known(KnownPlan::Pro)), + ..IdTokenInfo::default() + }; + assert_eq!(personal.is_workspace_account(), false); +} diff --git a/codex-rs/core/src/tools/code_mode_description.rs b/codex-rs/core/src/tools/code_mode_description.rs index 2a3ba815cc..8ed9fc6f53 100644 --- a/codex-rs/core/src/tools/code_mode_description.rs +++ b/codex-rs/core/src/tools/code_mode_description.rs @@ -291,80 +291,5 @@ fn render_json_schema_literal(value: &JsonValue) -> String { } #[cfg(test)] -mod tests { - use super::render_json_schema_to_typescript; - use pretty_assertions::assert_eq; - use serde_json::json; - - #[test] - fn render_json_schema_to_typescript_renders_object_properties() { - let schema = json!({ - "type": "object", - "properties": { - "path": {"type": "string"}, - "recursive": {"type": "boolean"} - }, - "required": ["path"], - "additionalProperties": false - }); - - assert_eq!( - render_json_schema_to_typescript(&schema), - "{\n path: string;\n recursive?: boolean;\n}" - ); - } - - #[test] - fn render_json_schema_to_typescript_renders_anyof_unions() { - let schema = json!({ - "anyOf": [ - {"const": "pending"}, - {"const": "done"}, - {"type": "number"} - ] - }); - - assert_eq!( - render_json_schema_to_typescript(&schema), - "\"pending\" | \"done\" | number" - ); - } - - #[test] - fn render_json_schema_to_typescript_renders_additional_properties() { - let schema = json!({ - "type": "object", - "properties": { - "tags": { - "type": "array", - "items": {"type": "string"} - } - }, - "additionalProperties": {"type": "integer"} - }); - - assert_eq!( - render_json_schema_to_typescript(&schema), - "{\n tags?: Array;\n [key: string]: number;\n}" - ); - } - - #[test] - fn render_json_schema_to_typescript_sorts_object_properties() { - let schema = json!({ - "type": "object", - "properties": { - "structuredContent": {"type": "string"}, - "_meta": {"type": "string"}, - "isError": {"type": "boolean"}, - "content": {"type": "array", "items": {"type": "string"}} - }, - "required": ["content"] - }); - - assert_eq!( - render_json_schema_to_typescript(&schema), - "{\n _meta?: string;\n content: Array;\n isError?: boolean;\n structuredContent?: string;\n}" - ); - } -} +#[path = "code_mode_description_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/code_mode_description_tests.rs b/codex-rs/core/src/tools/code_mode_description_tests.rs new file mode 100644 index 0000000000..500d7bf670 --- /dev/null +++ b/codex-rs/core/src/tools/code_mode_description_tests.rs @@ -0,0 +1,75 @@ +use super::render_json_schema_to_typescript; +use pretty_assertions::assert_eq; +use serde_json::json; + +#[test] +fn render_json_schema_to_typescript_renders_object_properties() { + let schema = json!({ + "type": "object", + "properties": { + "path": {"type": "string"}, + "recursive": {"type": "boolean"} + }, + "required": ["path"], + "additionalProperties": false + }); + + assert_eq!( + render_json_schema_to_typescript(&schema), + "{\n path: string;\n recursive?: boolean;\n}" + ); +} + +#[test] +fn render_json_schema_to_typescript_renders_anyof_unions() { + let schema = json!({ + "anyOf": [ + {"const": "pending"}, + {"const": "done"}, + {"type": "number"} + ] + }); + + assert_eq!( + render_json_schema_to_typescript(&schema), + "\"pending\" | \"done\" | number" + ); +} + +#[test] +fn render_json_schema_to_typescript_renders_additional_properties() { + let schema = json!({ + "type": "object", + "properties": { + "tags": { + "type": "array", + "items": {"type": "string"} + } + }, + "additionalProperties": {"type": "integer"} + }); + + assert_eq!( + render_json_schema_to_typescript(&schema), + "{\n tags?: Array;\n [key: string]: number;\n}" + ); +} + +#[test] +fn render_json_schema_to_typescript_sorts_object_properties() { + let schema = json!({ + "type": "object", + "properties": { + "structuredContent": {"type": "string"}, + "_meta": {"type": "string"}, + "isError": {"type": "boolean"}, + "content": {"type": "array", "items": {"type": "string"}} + }, + "required": ["content"] + }); + + assert_eq!( + render_json_schema_to_typescript(&schema), + "{\n _meta?: string;\n content: Array;\n isError?: boolean;\n structuredContent?: string;\n}" + ); +} diff --git a/codex-rs/core/src/tools/context.rs b/codex-rs/core/src/tools/context.rs index 85127059b7..ccb38623ba 100644 --- a/codex-rs/core/src/tools/context.rs +++ b/codex-rs/core/src/tools/context.rs @@ -424,279 +424,5 @@ fn telemetry_preview(content: &str) -> String { } #[cfg(test)] -mod tests { - use super::*; - use core_test_support::assert_regex_match; - use pretty_assertions::assert_eq; - use serde_json::json; - - #[test] - fn custom_tool_calls_should_roundtrip_as_custom_outputs() { - let payload = ToolPayload::Custom { - input: "patch".to_string(), - }; - let response = FunctionToolOutput::from_text("patched".to_string(), Some(true)) - .to_response_item("call-42", &payload); - - match response { - ResponseInputItem::CustomToolCallOutput { call_id, output } => { - assert_eq!(call_id, "call-42"); - assert_eq!(output.content_items(), None); - assert_eq!(output.body.to_text().as_deref(), Some("patched")); - assert_eq!(output.success, Some(true)); - } - other => panic!("expected CustomToolCallOutput, got {other:?}"), - } - } - - #[test] - fn function_payloads_remain_function_outputs() { - let payload = ToolPayload::Function { - arguments: "{}".to_string(), - }; - let response = FunctionToolOutput::from_text("ok".to_string(), Some(true)) - .to_response_item("fn-1", &payload); - - match response { - ResponseInputItem::FunctionCallOutput { call_id, output } => { - assert_eq!(call_id, "fn-1"); - assert_eq!(output.content_items(), None); - assert_eq!(output.body.to_text().as_deref(), Some("ok")); - assert_eq!(output.success, Some(true)); - } - other => panic!("expected FunctionCallOutput, got {other:?}"), - } - } - - #[test] - fn mcp_code_mode_result_serializes_full_call_tool_result() { - let output = CallToolResult { - content: vec![serde_json::json!({ - "type": "text", - "text": "ignored", - })], - structured_content: Some(serde_json::json!({ - "threadId": "thread_123", - "content": "done", - })), - is_error: Some(false), - meta: Some(serde_json::json!({ - "source": "mcp", - })), - }; - - let result = output.code_mode_result(&ToolPayload::Mcp { - server: "server".to_string(), - tool: "tool".to_string(), - raw_arguments: "{}".to_string(), - }); - - assert_eq!( - result, - serde_json::json!({ - "content": [{ - "type": "text", - "text": "ignored", - }], - "structuredContent": { - "threadId": "thread_123", - "content": "done", - }, - "isError": false, - "_meta": { - "source": "mcp", - }, - }) - ); - } - - #[test] - fn custom_tool_calls_can_derive_text_from_content_items() { - let payload = ToolPayload::Custom { - input: "patch".to_string(), - }; - let response = FunctionToolOutput::from_content( - vec![ - FunctionCallOutputContentItem::InputText { - text: "line 1".to_string(), - }, - FunctionCallOutputContentItem::InputImage { - image_url: "data:image/png;base64,AAA".to_string(), - detail: None, - }, - FunctionCallOutputContentItem::InputText { - text: "line 2".to_string(), - }, - ], - Some(true), - ) - .to_response_item("call-99", &payload); - - match response { - ResponseInputItem::CustomToolCallOutput { call_id, output } => { - let expected = vec![ - FunctionCallOutputContentItem::InputText { - text: "line 1".to_string(), - }, - FunctionCallOutputContentItem::InputImage { - image_url: "data:image/png;base64,AAA".to_string(), - detail: None, - }, - FunctionCallOutputContentItem::InputText { - text: "line 2".to_string(), - }, - ]; - assert_eq!(call_id, "call-99"); - assert_eq!(output.content_items(), Some(expected.as_slice())); - assert_eq!(output.body.to_text().as_deref(), Some("line 1\nline 2")); - assert_eq!(output.success, Some(true)); - } - other => panic!("expected CustomToolCallOutput, got {other:?}"), - } - } - - #[test] - fn tool_search_payloads_roundtrip_as_tool_search_outputs() { - let payload = ToolPayload::ToolSearch { - arguments: SearchToolCallParams { - query: "calendar".to_string(), - limit: None, - }, - }; - let response = ToolSearchOutput { - tools: vec![ToolSearchOutputTool::Function( - crate::client_common::tools::ResponsesApiTool { - name: "create_event".to_string(), - description: String::new(), - strict: false, - defer_loading: Some(true), - parameters: crate::tools::spec::JsonSchema::Object { - properties: Default::default(), - required: None, - additional_properties: None, - }, - output_schema: None, - }, - )], - } - .to_response_item("search-1", &payload); - - match response { - ResponseInputItem::ToolSearchOutput { - call_id, - status, - execution, - tools, - } => { - assert_eq!(call_id, "search-1"); - assert_eq!(status, "completed"); - assert_eq!(execution, "client"); - assert_eq!( - tools, - vec![json!({ - "type": "function", - "name": "create_event", - "description": "", - "strict": false, - "defer_loading": true, - "parameters": { - "type": "object", - "properties": {} - } - })] - ); - } - other => panic!("expected ToolSearchOutput, got {other:?}"), - } - } - - #[test] - fn log_preview_uses_content_items_when_plain_text_is_missing() { - let output = FunctionToolOutput::from_content( - vec![FunctionCallOutputContentItem::InputText { - text: "preview".to_string(), - }], - Some(true), - ); - - assert_eq!(output.log_preview(), "preview"); - assert_eq!( - function_call_output_content_items_to_text(&output.body), - Some("preview".to_string()) - ); - } - - #[test] - fn telemetry_preview_returns_original_within_limits() { - let content = "short output"; - assert_eq!(telemetry_preview(content), content); - } - - #[test] - fn telemetry_preview_truncates_by_bytes() { - let content = "x".repeat(TELEMETRY_PREVIEW_MAX_BYTES + 8); - let preview = telemetry_preview(&content); - - assert!(preview.contains(TELEMETRY_PREVIEW_TRUNCATION_NOTICE)); - assert!( - preview.len() - <= TELEMETRY_PREVIEW_MAX_BYTES + TELEMETRY_PREVIEW_TRUNCATION_NOTICE.len() + 1 - ); - } - - #[test] - fn telemetry_preview_truncates_by_lines() { - let content = (0..(TELEMETRY_PREVIEW_MAX_LINES + 5)) - .map(|idx| format!("line {idx}")) - .collect::>() - .join("\n"); - - let preview = telemetry_preview(&content); - let lines: Vec<&str> = preview.lines().collect(); - - assert!(lines.len() <= TELEMETRY_PREVIEW_MAX_LINES + 1); - assert_eq!(lines.last(), Some(&TELEMETRY_PREVIEW_TRUNCATION_NOTICE)); - } - - #[test] - fn exec_command_tool_output_formats_truncated_response() { - let payload = ToolPayload::Function { - arguments: "{}".to_string(), - }; - let response = ExecCommandToolOutput { - event_call_id: "call-42".to_string(), - chunk_id: "abc123".to_string(), - wall_time: std::time::Duration::from_millis(1250), - raw_output: b"token one token two token three token four token five".to_vec(), - max_output_tokens: Some(4), - process_id: None, - exit_code: Some(0), - original_token_count: Some(10), - session_command: None, - } - .to_response_item("call-42", &payload); - - match response { - ResponseInputItem::FunctionCallOutput { call_id, output } => { - assert_eq!(call_id, "call-42"); - assert_eq!(output.success, Some(true)); - let text = output - .body - .to_text() - .expect("exec output should serialize as text"); - assert_regex_match( - r#"(?sx) - ^Chunk\ ID:\ abc123 - \nWall\ time:\ \d+\.\d{4}\ seconds - \nProcess\ exited\ with\ code\ 0 - \nOriginal\ token\ count:\ 10 - \nOutput: - \n.*tokens\ truncated.* - $"#, - &text, - ); - } - other => panic!("expected FunctionCallOutput, got {other:?}"), - } - } -} +#[path = "context_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/context_tests.rs b/codex-rs/core/src/tools/context_tests.rs new file mode 100644 index 0000000000..d6ad76c5d5 --- /dev/null +++ b/codex-rs/core/src/tools/context_tests.rs @@ -0,0 +1,274 @@ +use super::*; +use core_test_support::assert_regex_match; +use pretty_assertions::assert_eq; +use serde_json::json; + +#[test] +fn custom_tool_calls_should_roundtrip_as_custom_outputs() { + let payload = ToolPayload::Custom { + input: "patch".to_string(), + }; + let response = FunctionToolOutput::from_text("patched".to_string(), Some(true)) + .to_response_item("call-42", &payload); + + match response { + ResponseInputItem::CustomToolCallOutput { call_id, output } => { + assert_eq!(call_id, "call-42"); + assert_eq!(output.content_items(), None); + assert_eq!(output.body.to_text().as_deref(), Some("patched")); + assert_eq!(output.success, Some(true)); + } + other => panic!("expected CustomToolCallOutput, got {other:?}"), + } +} + +#[test] +fn function_payloads_remain_function_outputs() { + let payload = ToolPayload::Function { + arguments: "{}".to_string(), + }; + let response = FunctionToolOutput::from_text("ok".to_string(), Some(true)) + .to_response_item("fn-1", &payload); + + match response { + ResponseInputItem::FunctionCallOutput { call_id, output } => { + assert_eq!(call_id, "fn-1"); + assert_eq!(output.content_items(), None); + assert_eq!(output.body.to_text().as_deref(), Some("ok")); + assert_eq!(output.success, Some(true)); + } + other => panic!("expected FunctionCallOutput, got {other:?}"), + } +} + +#[test] +fn mcp_code_mode_result_serializes_full_call_tool_result() { + let output = CallToolResult { + content: vec![serde_json::json!({ + "type": "text", + "text": "ignored", + })], + structured_content: Some(serde_json::json!({ + "threadId": "thread_123", + "content": "done", + })), + is_error: Some(false), + meta: Some(serde_json::json!({ + "source": "mcp", + })), + }; + + let result = output.code_mode_result(&ToolPayload::Mcp { + server: "server".to_string(), + tool: "tool".to_string(), + raw_arguments: "{}".to_string(), + }); + + assert_eq!( + result, + serde_json::json!({ + "content": [{ + "type": "text", + "text": "ignored", + }], + "structuredContent": { + "threadId": "thread_123", + "content": "done", + }, + "isError": false, + "_meta": { + "source": "mcp", + }, + }) + ); +} + +#[test] +fn custom_tool_calls_can_derive_text_from_content_items() { + let payload = ToolPayload::Custom { + input: "patch".to_string(), + }; + let response = FunctionToolOutput::from_content( + vec![ + FunctionCallOutputContentItem::InputText { + text: "line 1".to_string(), + }, + FunctionCallOutputContentItem::InputImage { + image_url: "data:image/png;base64,AAA".to_string(), + detail: None, + }, + FunctionCallOutputContentItem::InputText { + text: "line 2".to_string(), + }, + ], + Some(true), + ) + .to_response_item("call-99", &payload); + + match response { + ResponseInputItem::CustomToolCallOutput { call_id, output } => { + let expected = vec![ + FunctionCallOutputContentItem::InputText { + text: "line 1".to_string(), + }, + FunctionCallOutputContentItem::InputImage { + image_url: "data:image/png;base64,AAA".to_string(), + detail: None, + }, + FunctionCallOutputContentItem::InputText { + text: "line 2".to_string(), + }, + ]; + assert_eq!(call_id, "call-99"); + assert_eq!(output.content_items(), Some(expected.as_slice())); + assert_eq!(output.body.to_text().as_deref(), Some("line 1\nline 2")); + assert_eq!(output.success, Some(true)); + } + other => panic!("expected CustomToolCallOutput, got {other:?}"), + } +} + +#[test] +fn tool_search_payloads_roundtrip_as_tool_search_outputs() { + let payload = ToolPayload::ToolSearch { + arguments: SearchToolCallParams { + query: "calendar".to_string(), + limit: None, + }, + }; + let response = ToolSearchOutput { + tools: vec![ToolSearchOutputTool::Function( + crate::client_common::tools::ResponsesApiTool { + name: "create_event".to_string(), + description: String::new(), + strict: false, + defer_loading: Some(true), + parameters: crate::tools::spec::JsonSchema::Object { + properties: Default::default(), + required: None, + additional_properties: None, + }, + output_schema: None, + }, + )], + } + .to_response_item("search-1", &payload); + + match response { + ResponseInputItem::ToolSearchOutput { + call_id, + status, + execution, + tools, + } => { + assert_eq!(call_id, "search-1"); + assert_eq!(status, "completed"); + assert_eq!(execution, "client"); + assert_eq!( + tools, + vec![json!({ + "type": "function", + "name": "create_event", + "description": "", + "strict": false, + "defer_loading": true, + "parameters": { + "type": "object", + "properties": {} + } + })] + ); + } + other => panic!("expected ToolSearchOutput, got {other:?}"), + } +} + +#[test] +fn log_preview_uses_content_items_when_plain_text_is_missing() { + let output = FunctionToolOutput::from_content( + vec![FunctionCallOutputContentItem::InputText { + text: "preview".to_string(), + }], + Some(true), + ); + + assert_eq!(output.log_preview(), "preview"); + assert_eq!( + function_call_output_content_items_to_text(&output.body), + Some("preview".to_string()) + ); +} + +#[test] +fn telemetry_preview_returns_original_within_limits() { + let content = "short output"; + assert_eq!(telemetry_preview(content), content); +} + +#[test] +fn telemetry_preview_truncates_by_bytes() { + let content = "x".repeat(TELEMETRY_PREVIEW_MAX_BYTES + 8); + let preview = telemetry_preview(&content); + + assert!(preview.contains(TELEMETRY_PREVIEW_TRUNCATION_NOTICE)); + assert!( + preview.len() + <= TELEMETRY_PREVIEW_MAX_BYTES + TELEMETRY_PREVIEW_TRUNCATION_NOTICE.len() + 1 + ); +} + +#[test] +fn telemetry_preview_truncates_by_lines() { + let content = (0..(TELEMETRY_PREVIEW_MAX_LINES + 5)) + .map(|idx| format!("line {idx}")) + .collect::>() + .join("\n"); + + let preview = telemetry_preview(&content); + let lines: Vec<&str> = preview.lines().collect(); + + assert!(lines.len() <= TELEMETRY_PREVIEW_MAX_LINES + 1); + assert_eq!(lines.last(), Some(&TELEMETRY_PREVIEW_TRUNCATION_NOTICE)); +} + +#[test] +fn exec_command_tool_output_formats_truncated_response() { + let payload = ToolPayload::Function { + arguments: "{}".to_string(), + }; + let response = ExecCommandToolOutput { + event_call_id: "call-42".to_string(), + chunk_id: "abc123".to_string(), + wall_time: std::time::Duration::from_millis(1250), + raw_output: b"token one token two token three token four token five".to_vec(), + max_output_tokens: Some(4), + process_id: None, + exit_code: Some(0), + original_token_count: Some(10), + session_command: None, + } + .to_response_item("call-42", &payload); + + match response { + ResponseInputItem::FunctionCallOutput { call_id, output } => { + assert_eq!(call_id, "call-42"); + assert_eq!(output.success, Some(true)); + let text = output + .body + .to_text() + .expect("exec output should serialize as text"); + assert_regex_match( + r#"(?sx) + ^Chunk\ ID:\ abc123 + \nWall\ time:\ \d+\.\d{4}\ seconds + \nProcess\ exited\ with\ code\ 0 + \nOriginal\ token\ count:\ 10 + \nOutput: + \n.*tokens\ truncated.* + $"#, + &text, + ); + } + other => panic!("expected FunctionCallOutput, got {other:?}"), + } +} diff --git a/codex-rs/core/src/tools/handlers/agent_jobs.rs b/codex-rs/core/src/tools/handlers/agent_jobs.rs index ff02a3fbd1..4e786178f8 100644 --- a/codex-rs/core/src/tools/handlers/agent_jobs.rs +++ b/codex-rs/core/src/tools/handlers/agent_jobs.rs @@ -1152,67 +1152,5 @@ fn csv_escape(value: &str) -> String { } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - use serde_json::json; - - #[test] - fn parse_csv_supports_quotes_and_commas() { - let input = "id,name\n1,\"alpha, beta\"\n2,gamma\n"; - let (headers, rows) = parse_csv(input).expect("csv parse"); - assert_eq!(headers, vec!["id".to_string(), "name".to_string()]); - assert_eq!( - rows, - vec![ - vec!["1".to_string(), "alpha, beta".to_string()], - vec!["2".to_string(), "gamma".to_string()] - ] - ); - } - - #[test] - fn csv_escape_quotes_when_needed() { - assert_eq!(csv_escape("simple"), "simple"); - assert_eq!(csv_escape("a,b"), "\"a,b\""); - assert_eq!(csv_escape("a\"b"), "\"a\"\"b\""); - } - - #[test] - fn render_instruction_template_expands_placeholders_and_escapes_braces() { - let row = json!({ - "path": "src/lib.rs", - "area": "test", - "file path": "docs/readme.md", - }); - let rendered = render_instruction_template( - "Review {path} in {area}. Also see {file path}. Use {{literal}}.", - &row, - ); - assert_eq!( - rendered, - "Review src/lib.rs in test. Also see docs/readme.md. Use {literal}." - ); - } - - #[test] - fn render_instruction_template_leaves_unknown_placeholders() { - let row = json!({ - "path": "src/lib.rs", - }); - let rendered = render_instruction_template("Check {path} then {missing}", &row); - assert_eq!(rendered, "Check src/lib.rs then {missing}"); - } - - #[test] - fn ensure_unique_headers_rejects_duplicates() { - let headers = vec!["path".to_string(), "path".to_string()]; - let Err(err) = ensure_unique_headers(headers.as_slice()) else { - panic!("expected duplicate header error"); - }; - assert_eq!( - err, - FunctionCallError::RespondToModel("csv header path is duplicated".to_string()) - ); - } -} +#[path = "agent_jobs_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/handlers/agent_jobs_tests.rs b/codex-rs/core/src/tools/handlers/agent_jobs_tests.rs new file mode 100644 index 0000000000..a2dbe6a480 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/agent_jobs_tests.rs @@ -0,0 +1,62 @@ +use super::*; +use pretty_assertions::assert_eq; +use serde_json::json; + +#[test] +fn parse_csv_supports_quotes_and_commas() { + let input = "id,name\n1,\"alpha, beta\"\n2,gamma\n"; + let (headers, rows) = parse_csv(input).expect("csv parse"); + assert_eq!(headers, vec!["id".to_string(), "name".to_string()]); + assert_eq!( + rows, + vec![ + vec!["1".to_string(), "alpha, beta".to_string()], + vec!["2".to_string(), "gamma".to_string()] + ] + ); +} + +#[test] +fn csv_escape_quotes_when_needed() { + assert_eq!(csv_escape("simple"), "simple"); + assert_eq!(csv_escape("a,b"), "\"a,b\""); + assert_eq!(csv_escape("a\"b"), "\"a\"\"b\""); +} + +#[test] +fn render_instruction_template_expands_placeholders_and_escapes_braces() { + let row = json!({ + "path": "src/lib.rs", + "area": "test", + "file path": "docs/readme.md", + }); + let rendered = render_instruction_template( + "Review {path} in {area}. Also see {file path}. Use {{literal}}.", + &row, + ); + assert_eq!( + rendered, + "Review src/lib.rs in test. Also see docs/readme.md. Use {literal}." + ); +} + +#[test] +fn render_instruction_template_leaves_unknown_placeholders() { + let row = json!({ + "path": "src/lib.rs", + }); + let rendered = render_instruction_template("Check {path} then {missing}", &row); + assert_eq!(rendered, "Check src/lib.rs then {missing}"); +} + +#[test] +fn ensure_unique_headers_rejects_duplicates() { + let headers = vec!["path".to_string(), "path".to_string()]; + let Err(err) = ensure_unique_headers(headers.as_slice()) else { + panic!("expected duplicate header error"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel("csv header path is duplicated".to_string()) + ); +} diff --git a/codex-rs/core/src/tools/handlers/apply_patch.rs b/codex-rs/core/src/tools/handlers/apply_patch.rs index 24119cd23b..12fb904f4d 100644 --- a/codex-rs/core/src/tools/handlers/apply_patch.rs +++ b/codex-rs/core/src/tools/handlers/apply_patch.rs @@ -461,33 +461,5 @@ It is important to remember: } #[cfg(test)] -mod tests { - use super::*; - use codex_apply_patch::MaybeApplyPatchVerified; - use pretty_assertions::assert_eq; - use tempfile::TempDir; - - #[test] - fn approval_keys_include_move_destination() { - let tmp = TempDir::new().expect("tmp"); - let cwd = tmp.path(); - std::fs::create_dir_all(cwd.join("old")).expect("create old dir"); - std::fs::create_dir_all(cwd.join("renamed/dir")).expect("create dest dir"); - std::fs::write(cwd.join("old/name.txt"), "old content\n").expect("write old file"); - let patch = r#"*** Begin Patch -*** Update File: old/name.txt -*** Move to: renamed/dir/name.txt -@@ --old content -+new content -*** End Patch"#; - let argv = vec!["apply_patch".to_string(), patch.to_string()]; - let action = match codex_apply_patch::maybe_parse_apply_patch_verified(&argv, cwd) { - MaybeApplyPatchVerified::Body(action) => action, - other => panic!("expected patch body, got: {other:?}"), - }; - - let keys = file_paths_for_action(&action); - assert_eq!(keys.len(), 2); - } -} +#[path = "apply_patch_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/handlers/apply_patch_tests.rs b/codex-rs/core/src/tools/handlers/apply_patch_tests.rs new file mode 100644 index 0000000000..7f8e8df8b5 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/apply_patch_tests.rs @@ -0,0 +1,28 @@ +use super::*; +use codex_apply_patch::MaybeApplyPatchVerified; +use pretty_assertions::assert_eq; +use tempfile::TempDir; + +#[test] +fn approval_keys_include_move_destination() { + let tmp = TempDir::new().expect("tmp"); + let cwd = tmp.path(); + std::fs::create_dir_all(cwd.join("old")).expect("create old dir"); + std::fs::create_dir_all(cwd.join("renamed/dir")).expect("create dest dir"); + std::fs::write(cwd.join("old/name.txt"), "old content\n").expect("write old file"); + let patch = r#"*** Begin Patch +*** Update File: old/name.txt +*** Move to: renamed/dir/name.txt +@@ +-old content ++new content +*** End Patch"#; + let argv = vec!["apply_patch".to_string(), patch.to_string()]; + let action = match codex_apply_patch::maybe_parse_apply_patch_verified(&argv, cwd) { + MaybeApplyPatchVerified::Body(action) => action, + other => panic!("expected patch body, got: {other:?}"), + }; + + let keys = file_paths_for_action(&action); + assert_eq!(keys.len(), 2); +} diff --git a/codex-rs/core/src/tools/handlers/artifacts.rs b/codex-rs/core/src/tools/handlers/artifacts.rs index df239495ee..bbcbcd3c80 100644 --- a/codex-rs/core/src/tools/handlers/artifacts.rs +++ b/codex-rs/core/src/tools/handlers/artifacts.rs @@ -293,130 +293,5 @@ fn error_output(error: &ArtifactsError) -> ArtifactCommandOutput { } #[cfg(test)] -mod tests { - use super::*; - use codex_artifacts::RuntimeEntrypoints; - use codex_artifacts::RuntimePathEntry; - use tempfile::TempDir; - - #[test] - fn parse_freeform_args_without_pragma() { - let args = parse_freeform_args("console.log('ok');").expect("parse args"); - assert_eq!(args.source, "console.log('ok');"); - assert_eq!(args.timeout_ms, None); - } - - #[test] - fn parse_freeform_args_with_pragma() { - let args = parse_freeform_args("// codex-artifacts: timeout_ms=45000\nconsole.log('ok');") - .expect("parse args"); - assert_eq!(args.source, "console.log('ok');"); - assert_eq!(args.timeout_ms, Some(45_000)); - } - - #[test] - fn parse_freeform_args_with_artifact_tool_pragma() { - let args = - parse_freeform_args("// codex-artifact-tool: timeout_ms=45000\nconsole.log('ok');") - .expect("parse args"); - assert_eq!(args.source, "console.log('ok');"); - assert_eq!(args.timeout_ms, Some(45_000)); - } - - #[test] - fn parse_freeform_args_rejects_json_wrapped_code() { - let err = - parse_freeform_args("{\"code\":\"console.log('ok')\"}").expect_err("expected error"); - assert!( - err.to_string() - .contains("artifacts is a freeform tool and expects raw JavaScript source") - ); - } - - #[test] - fn default_runtime_manager_uses_openai_codex_release_base() { - let codex_home = TempDir::new().expect("create temp codex home"); - let manager = default_runtime_manager(codex_home.path().to_path_buf()); - - assert_eq!( - manager.config().release().base_url().as_str(), - "https://github.com/openai/codex/releases/download/" - ); - assert_eq!( - manager.config().release().runtime_version(), - PINNED_ARTIFACT_RUNTIME_VERSION - ); - } - - #[test] - fn load_cached_runtime_reads_pinned_cache_path() { - let codex_home = TempDir::new().expect("create temp codex home"); - let platform = - codex_artifacts::ArtifactRuntimePlatform::detect_current().expect("detect platform"); - let install_dir = codex_home - .path() - .join("packages") - .join("artifacts") - .join(PINNED_ARTIFACT_RUNTIME_VERSION) - .join(platform.as_str()); - std::fs::create_dir_all(&install_dir).expect("create install dir"); - std::fs::write( - install_dir.join("manifest.json"), - serde_json::json!({ - "schema_version": 1, - "runtime_version": PINNED_ARTIFACT_RUNTIME_VERSION, - "node": { "relative_path": "node/bin/node" }, - "entrypoints": { - "build_js": { "relative_path": "artifact-tool/dist/artifact_tool.mjs" }, - "render_cli": { "relative_path": "granola-render/dist/render_cli.mjs" } - } - }) - .to_string(), - ) - .expect("write manifest"); - std::fs::create_dir_all(install_dir.join("artifact-tool/dist")) - .expect("create build entrypoint dir"); - std::fs::create_dir_all(install_dir.join("granola-render/dist")) - .expect("create render entrypoint dir"); - std::fs::write( - install_dir.join("artifact-tool/dist/artifact_tool.mjs"), - "export const ok = true;\n", - ) - .expect("write build entrypoint"); - std::fs::write( - install_dir.join("granola-render/dist/render_cli.mjs"), - "export const ok = true;\n", - ) - .expect("write render entrypoint"); - - let runtime = codex_artifacts::load_cached_runtime( - &codex_home - .path() - .join(codex_artifacts::DEFAULT_CACHE_ROOT_RELATIVE), - PINNED_ARTIFACT_RUNTIME_VERSION, - ) - .expect("resolve runtime"); - assert_eq!(runtime.runtime_version(), PINNED_ARTIFACT_RUNTIME_VERSION); - assert_eq!( - runtime.manifest().entrypoints, - RuntimeEntrypoints { - build_js: RuntimePathEntry { - relative_path: "artifact-tool/dist/artifact_tool.mjs".to_string(), - }, - render_cli: RuntimePathEntry { - relative_path: "granola-render/dist/render_cli.mjs".to_string(), - }, - } - ); - } - - #[test] - fn format_artifact_output_includes_success_message_when_silent() { - let formatted = format_artifact_output(&ArtifactCommandOutput { - exit_code: Some(0), - stdout: String::new(), - stderr: String::new(), - }); - assert!(formatted.contains("artifact JS completed successfully.")); - } -} +#[path = "artifacts_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/handlers/artifacts_tests.rs b/codex-rs/core/src/tools/handlers/artifacts_tests.rs new file mode 100644 index 0000000000..f28636acc6 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/artifacts_tests.rs @@ -0,0 +1,123 @@ +use super::*; +use codex_artifacts::RuntimeEntrypoints; +use codex_artifacts::RuntimePathEntry; +use tempfile::TempDir; + +#[test] +fn parse_freeform_args_without_pragma() { + let args = parse_freeform_args("console.log('ok');").expect("parse args"); + assert_eq!(args.source, "console.log('ok');"); + assert_eq!(args.timeout_ms, None); +} + +#[test] +fn parse_freeform_args_with_pragma() { + let args = parse_freeform_args("// codex-artifacts: timeout_ms=45000\nconsole.log('ok');") + .expect("parse args"); + assert_eq!(args.source, "console.log('ok');"); + assert_eq!(args.timeout_ms, Some(45_000)); +} + +#[test] +fn parse_freeform_args_with_artifact_tool_pragma() { + let args = parse_freeform_args("// codex-artifact-tool: timeout_ms=45000\nconsole.log('ok');") + .expect("parse args"); + assert_eq!(args.source, "console.log('ok');"); + assert_eq!(args.timeout_ms, Some(45_000)); +} + +#[test] +fn parse_freeform_args_rejects_json_wrapped_code() { + let err = parse_freeform_args("{\"code\":\"console.log('ok')\"}").expect_err("expected error"); + assert!( + err.to_string() + .contains("artifacts is a freeform tool and expects raw JavaScript source") + ); +} + +#[test] +fn default_runtime_manager_uses_openai_codex_release_base() { + let codex_home = TempDir::new().expect("create temp codex home"); + let manager = default_runtime_manager(codex_home.path().to_path_buf()); + + assert_eq!( + manager.config().release().base_url().as_str(), + "https://github.com/openai/codex/releases/download/" + ); + assert_eq!( + manager.config().release().runtime_version(), + PINNED_ARTIFACT_RUNTIME_VERSION + ); +} + +#[test] +fn load_cached_runtime_reads_pinned_cache_path() { + let codex_home = TempDir::new().expect("create temp codex home"); + let platform = + codex_artifacts::ArtifactRuntimePlatform::detect_current().expect("detect platform"); + let install_dir = codex_home + .path() + .join("packages") + .join("artifacts") + .join(PINNED_ARTIFACT_RUNTIME_VERSION) + .join(platform.as_str()); + std::fs::create_dir_all(&install_dir).expect("create install dir"); + std::fs::write( + install_dir.join("manifest.json"), + serde_json::json!({ + "schema_version": 1, + "runtime_version": PINNED_ARTIFACT_RUNTIME_VERSION, + "node": { "relative_path": "node/bin/node" }, + "entrypoints": { + "build_js": { "relative_path": "artifact-tool/dist/artifact_tool.mjs" }, + "render_cli": { "relative_path": "granola-render/dist/render_cli.mjs" } + } + }) + .to_string(), + ) + .expect("write manifest"); + std::fs::create_dir_all(install_dir.join("artifact-tool/dist")) + .expect("create build entrypoint dir"); + std::fs::create_dir_all(install_dir.join("granola-render/dist")) + .expect("create render entrypoint dir"); + std::fs::write( + install_dir.join("artifact-tool/dist/artifact_tool.mjs"), + "export const ok = true;\n", + ) + .expect("write build entrypoint"); + std::fs::write( + install_dir.join("granola-render/dist/render_cli.mjs"), + "export const ok = true;\n", + ) + .expect("write render entrypoint"); + + let runtime = codex_artifacts::load_cached_runtime( + &codex_home + .path() + .join(codex_artifacts::DEFAULT_CACHE_ROOT_RELATIVE), + PINNED_ARTIFACT_RUNTIME_VERSION, + ) + .expect("resolve runtime"); + assert_eq!(runtime.runtime_version(), PINNED_ARTIFACT_RUNTIME_VERSION); + assert_eq!( + runtime.manifest().entrypoints, + RuntimeEntrypoints { + build_js: RuntimePathEntry { + relative_path: "artifact-tool/dist/artifact_tool.mjs".to_string(), + }, + render_cli: RuntimePathEntry { + relative_path: "granola-render/dist/render_cli.mjs".to_string(), + }, + } + ); +} + +#[test] +fn format_artifact_output_includes_success_message_when_silent() { + let formatted = format_artifact_output(&ArtifactCommandOutput { + exit_code: Some(0), + stdout: String::new(), + stderr: String::new(), + }); + assert!(formatted.contains("artifact JS completed successfully.")); +} diff --git a/codex-rs/core/src/tools/handlers/grep_files.rs b/codex-rs/core/src/tools/handlers/grep_files.rs index 071ecec70c..fdb0fce7be 100644 --- a/codex-rs/core/src/tools/handlers/grep_files.rs +++ b/codex-rs/core/src/tools/handlers/grep_files.rs @@ -172,100 +172,5 @@ fn parse_results(stdout: &[u8], limit: usize) -> Vec { } #[cfg(test)] -mod tests { - use super::*; - use std::process::Command as StdCommand; - use tempfile::tempdir; - - #[test] - fn parses_basic_results() { - let stdout = b"/tmp/file_a.rs\n/tmp/file_b.rs\n"; - let parsed = parse_results(stdout, 10); - assert_eq!( - parsed, - vec!["/tmp/file_a.rs".to_string(), "/tmp/file_b.rs".to_string()] - ); - } - - #[test] - fn parse_truncates_after_limit() { - let stdout = b"/tmp/file_a.rs\n/tmp/file_b.rs\n/tmp/file_c.rs\n"; - let parsed = parse_results(stdout, 2); - assert_eq!( - parsed, - vec!["/tmp/file_a.rs".to_string(), "/tmp/file_b.rs".to_string()] - ); - } - - #[tokio::test] - async fn run_search_returns_results() -> anyhow::Result<()> { - if !rg_available() { - return Ok(()); - } - let temp = tempdir().expect("create temp dir"); - let dir = temp.path(); - std::fs::write(dir.join("match_one.txt"), "alpha beta gamma").unwrap(); - std::fs::write(dir.join("match_two.txt"), "alpha delta").unwrap(); - std::fs::write(dir.join("other.txt"), "omega").unwrap(); - - let results = run_rg_search("alpha", None, dir, 10, dir).await?; - assert_eq!(results.len(), 2); - assert!(results.iter().any(|path| path.ends_with("match_one.txt"))); - assert!(results.iter().any(|path| path.ends_with("match_two.txt"))); - Ok(()) - } - - #[tokio::test] - async fn run_search_with_glob_filter() -> anyhow::Result<()> { - if !rg_available() { - return Ok(()); - } - let temp = tempdir().expect("create temp dir"); - let dir = temp.path(); - std::fs::write(dir.join("match_one.rs"), "alpha beta gamma").unwrap(); - std::fs::write(dir.join("match_two.txt"), "alpha delta").unwrap(); - - let results = run_rg_search("alpha", Some("*.rs"), dir, 10, dir).await?; - assert_eq!(results.len(), 1); - assert!(results.iter().all(|path| path.ends_with("match_one.rs"))); - Ok(()) - } - - #[tokio::test] - async fn run_search_respects_limit() -> anyhow::Result<()> { - if !rg_available() { - return Ok(()); - } - let temp = tempdir().expect("create temp dir"); - let dir = temp.path(); - std::fs::write(dir.join("one.txt"), "alpha one").unwrap(); - std::fs::write(dir.join("two.txt"), "alpha two").unwrap(); - std::fs::write(dir.join("three.txt"), "alpha three").unwrap(); - - let results = run_rg_search("alpha", None, dir, 2, dir).await?; - assert_eq!(results.len(), 2); - Ok(()) - } - - #[tokio::test] - async fn run_search_handles_no_matches() -> anyhow::Result<()> { - if !rg_available() { - return Ok(()); - } - let temp = tempdir().expect("create temp dir"); - let dir = temp.path(); - std::fs::write(dir.join("one.txt"), "omega").unwrap(); - - let results = run_rg_search("alpha", None, dir, 5, dir).await?; - assert!(results.is_empty()); - Ok(()) - } - - fn rg_available() -> bool { - StdCommand::new("rg") - .arg("--version") - .output() - .map(|output| output.status.success()) - .unwrap_or(false) - } -} +#[path = "grep_files_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/handlers/grep_files_tests.rs b/codex-rs/core/src/tools/handlers/grep_files_tests.rs new file mode 100644 index 0000000000..0cc247c6f1 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/grep_files_tests.rs @@ -0,0 +1,95 @@ +use super::*; +use std::process::Command as StdCommand; +use tempfile::tempdir; + +#[test] +fn parses_basic_results() { + let stdout = b"/tmp/file_a.rs\n/tmp/file_b.rs\n"; + let parsed = parse_results(stdout, 10); + assert_eq!( + parsed, + vec!["/tmp/file_a.rs".to_string(), "/tmp/file_b.rs".to_string()] + ); +} + +#[test] +fn parse_truncates_after_limit() { + let stdout = b"/tmp/file_a.rs\n/tmp/file_b.rs\n/tmp/file_c.rs\n"; + let parsed = parse_results(stdout, 2); + assert_eq!( + parsed, + vec!["/tmp/file_a.rs".to_string(), "/tmp/file_b.rs".to_string()] + ); +} + +#[tokio::test] +async fn run_search_returns_results() -> anyhow::Result<()> { + if !rg_available() { + return Ok(()); + } + let temp = tempdir().expect("create temp dir"); + let dir = temp.path(); + std::fs::write(dir.join("match_one.txt"), "alpha beta gamma").unwrap(); + std::fs::write(dir.join("match_two.txt"), "alpha delta").unwrap(); + std::fs::write(dir.join("other.txt"), "omega").unwrap(); + + let results = run_rg_search("alpha", None, dir, 10, dir).await?; + assert_eq!(results.len(), 2); + assert!(results.iter().any(|path| path.ends_with("match_one.txt"))); + assert!(results.iter().any(|path| path.ends_with("match_two.txt"))); + Ok(()) +} + +#[tokio::test] +async fn run_search_with_glob_filter() -> anyhow::Result<()> { + if !rg_available() { + return Ok(()); + } + let temp = tempdir().expect("create temp dir"); + let dir = temp.path(); + std::fs::write(dir.join("match_one.rs"), "alpha beta gamma").unwrap(); + std::fs::write(dir.join("match_two.txt"), "alpha delta").unwrap(); + + let results = run_rg_search("alpha", Some("*.rs"), dir, 10, dir).await?; + assert_eq!(results.len(), 1); + assert!(results.iter().all(|path| path.ends_with("match_one.rs"))); + Ok(()) +} + +#[tokio::test] +async fn run_search_respects_limit() -> anyhow::Result<()> { + if !rg_available() { + return Ok(()); + } + let temp = tempdir().expect("create temp dir"); + let dir = temp.path(); + std::fs::write(dir.join("one.txt"), "alpha one").unwrap(); + std::fs::write(dir.join("two.txt"), "alpha two").unwrap(); + std::fs::write(dir.join("three.txt"), "alpha three").unwrap(); + + let results = run_rg_search("alpha", None, dir, 2, dir).await?; + assert_eq!(results.len(), 2); + Ok(()) +} + +#[tokio::test] +async fn run_search_handles_no_matches() -> anyhow::Result<()> { + if !rg_available() { + return Ok(()); + } + let temp = tempdir().expect("create temp dir"); + let dir = temp.path(); + std::fs::write(dir.join("one.txt"), "omega").unwrap(); + + let results = run_rg_search("alpha", None, dir, 5, dir).await?; + assert!(results.is_empty()); + Ok(()) +} + +fn rg_available() -> bool { + StdCommand::new("rg") + .arg("--version") + .output() + .map(|output| output.status.success()) + .unwrap_or(false) +} diff --git a/codex-rs/core/src/tools/handlers/js_repl.rs b/codex-rs/core/src/tools/handlers/js_repl.rs index d0404a9e23..bfb531f923 100644 --- a/codex-rs/core/src/tools/handlers/js_repl.rs +++ b/codex-rs/core/src/tools/handlers/js_repl.rs @@ -292,95 +292,5 @@ fn reject_json_or_quoted_source(code: &str) -> Result<(), FunctionCallError> { } #[cfg(test)] -mod tests { - use std::time::Duration; - - use super::parse_freeform_args; - use crate::codex::make_session_and_context_with_rx; - use crate::protocol::EventMsg; - use crate::protocol::ExecCommandSource; - use pretty_assertions::assert_eq; - - #[test] - fn parse_freeform_args_without_pragma() { - let args = parse_freeform_args("console.log('ok');").expect("parse args"); - assert_eq!(args.code, "console.log('ok');"); - assert_eq!(args.timeout_ms, None); - } - - #[test] - fn parse_freeform_args_with_pragma() { - let input = "// codex-js-repl: timeout_ms=15000\nconsole.log('ok');"; - let args = parse_freeform_args(input).expect("parse args"); - assert_eq!(args.code, "console.log('ok');"); - assert_eq!(args.timeout_ms, Some(15_000)); - } - - #[test] - fn parse_freeform_args_rejects_unknown_key() { - let err = parse_freeform_args("// codex-js-repl: nope=1\nconsole.log('ok');") - .expect_err("expected error"); - assert_eq!( - err.to_string(), - "js_repl pragma only supports timeout_ms; got `nope`" - ); - } - - #[test] - fn parse_freeform_args_rejects_reset_key() { - let err = parse_freeform_args("// codex-js-repl: reset=true\nconsole.log('ok');") - .expect_err("expected error"); - assert_eq!( - err.to_string(), - "js_repl pragma only supports timeout_ms; got `reset`" - ); - } - - #[test] - fn parse_freeform_args_rejects_json_wrapped_code() { - let err = parse_freeform_args(r#"{"code":"await doThing()"}"#).expect_err("expected error"); - assert_eq!( - err.to_string(), - "js_repl is a freeform tool and expects raw JavaScript source. Resend plain JS only (optional first line `// codex-js-repl: ...`); do not send JSON (`{\"code\":...}`), quoted code, or markdown fences." - ); - } - - #[tokio::test] - async fn emit_js_repl_exec_end_sends_event() { - let (session, turn, rx) = make_session_and_context_with_rx().await; - super::emit_js_repl_exec_end( - session.as_ref(), - turn.as_ref(), - "call-1", - "hello", - None, - Duration::from_millis(12), - ) - .await; - - let event = tokio::time::timeout(Duration::from_secs(5), async { - loop { - let event = rx.recv().await.expect("event"); - if let EventMsg::ExecCommandEnd(end) = event.msg { - break end; - } - } - }) - .await - .expect("timed out waiting for exec end"); - - assert_eq!(event.call_id, "call-1"); - assert_eq!(event.turn_id, turn.sub_id); - assert_eq!(event.command, vec!["js_repl".to_string()]); - assert_eq!(event.cwd, turn.cwd); - assert_eq!(event.source, ExecCommandSource::Agent); - assert_eq!(event.interaction_input, None); - assert_eq!(event.stdout, "hello"); - assert_eq!(event.stderr, ""); - assert!(event.aggregated_output.contains("hello")); - assert_eq!(event.exit_code, 0); - assert_eq!(event.duration, Duration::from_millis(12)); - assert!(event.formatted_output.contains("hello")); - assert!(!event.parsed_cmd.is_empty()); - } -} +#[path = "js_repl_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/handlers/js_repl_tests.rs b/codex-rs/core/src/tools/handlers/js_repl_tests.rs new file mode 100644 index 0000000000..14dc222ffe --- /dev/null +++ b/codex-rs/core/src/tools/handlers/js_repl_tests.rs @@ -0,0 +1,90 @@ +use std::time::Duration; + +use super::parse_freeform_args; +use crate::codex::make_session_and_context_with_rx; +use crate::protocol::EventMsg; +use crate::protocol::ExecCommandSource; +use pretty_assertions::assert_eq; + +#[test] +fn parse_freeform_args_without_pragma() { + let args = parse_freeform_args("console.log('ok');").expect("parse args"); + assert_eq!(args.code, "console.log('ok');"); + assert_eq!(args.timeout_ms, None); +} + +#[test] +fn parse_freeform_args_with_pragma() { + let input = "// codex-js-repl: timeout_ms=15000\nconsole.log('ok');"; + let args = parse_freeform_args(input).expect("parse args"); + assert_eq!(args.code, "console.log('ok');"); + assert_eq!(args.timeout_ms, Some(15_000)); +} + +#[test] +fn parse_freeform_args_rejects_unknown_key() { + let err = parse_freeform_args("// codex-js-repl: nope=1\nconsole.log('ok');") + .expect_err("expected error"); + assert_eq!( + err.to_string(), + "js_repl pragma only supports timeout_ms; got `nope`" + ); +} + +#[test] +fn parse_freeform_args_rejects_reset_key() { + let err = parse_freeform_args("// codex-js-repl: reset=true\nconsole.log('ok');") + .expect_err("expected error"); + assert_eq!( + err.to_string(), + "js_repl pragma only supports timeout_ms; got `reset`" + ); +} + +#[test] +fn parse_freeform_args_rejects_json_wrapped_code() { + let err = parse_freeform_args(r#"{"code":"await doThing()"}"#).expect_err("expected error"); + assert_eq!( + err.to_string(), + "js_repl is a freeform tool and expects raw JavaScript source. Resend plain JS only (optional first line `// codex-js-repl: ...`); do not send JSON (`{\"code\":...}`), quoted code, or markdown fences." + ); +} + +#[tokio::test] +async fn emit_js_repl_exec_end_sends_event() { + let (session, turn, rx) = make_session_and_context_with_rx().await; + super::emit_js_repl_exec_end( + session.as_ref(), + turn.as_ref(), + "call-1", + "hello", + None, + Duration::from_millis(12), + ) + .await; + + let event = tokio::time::timeout(Duration::from_secs(5), async { + loop { + let event = rx.recv().await.expect("event"); + if let EventMsg::ExecCommandEnd(end) = event.msg { + break end; + } + } + }) + .await + .expect("timed out waiting for exec end"); + + assert_eq!(event.call_id, "call-1"); + assert_eq!(event.turn_id, turn.sub_id); + assert_eq!(event.command, vec!["js_repl".to_string()]); + assert_eq!(event.cwd, turn.cwd); + assert_eq!(event.source, ExecCommandSource::Agent); + assert_eq!(event.interaction_input, None); + assert_eq!(event.stdout, "hello"); + assert_eq!(event.stderr, ""); + assert!(event.aggregated_output.contains("hello")); + assert_eq!(event.exit_code, 0); + assert_eq!(event.duration, Duration::from_millis(12)); + assert!(event.formatted_output.contains("hello")); + assert!(!event.parsed_cmd.is_empty()); +} diff --git a/codex-rs/core/src/tools/handlers/list_dir.rs b/codex-rs/core/src/tools/handlers/list_dir.rs index fb65b32823..fd461e82e5 100644 --- a/codex-rs/core/src/tools/handlers/list_dir.rs +++ b/codex-rs/core/src/tools/handlers/list_dir.rs @@ -267,246 +267,5 @@ impl From<&FileType> for DirEntryKind { } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - use tempfile::tempdir; - - #[tokio::test] - async fn lists_directory_entries() { - let temp = tempdir().expect("create tempdir"); - let dir_path = temp.path(); - - let sub_dir = dir_path.join("nested"); - tokio::fs::create_dir(&sub_dir) - .await - .expect("create sub dir"); - - let deeper_dir = sub_dir.join("deeper"); - tokio::fs::create_dir(&deeper_dir) - .await - .expect("create deeper dir"); - - tokio::fs::write(dir_path.join("entry.txt"), b"content") - .await - .expect("write file"); - tokio::fs::write(sub_dir.join("child.txt"), b"child") - .await - .expect("write child"); - tokio::fs::write(deeper_dir.join("grandchild.txt"), b"grandchild") - .await - .expect("write grandchild"); - - #[cfg(unix)] - { - use std::os::unix::fs::symlink; - let link_path = dir_path.join("link"); - symlink(dir_path.join("entry.txt"), &link_path).expect("create symlink"); - } - - let entries = list_dir_slice(dir_path, 1, 20, 3) - .await - .expect("list directory"); - - #[cfg(unix)] - let expected = vec![ - "entry.txt".to_string(), - "link@".to_string(), - "nested/".to_string(), - " child.txt".to_string(), - " deeper/".to_string(), - " grandchild.txt".to_string(), - ]; - - #[cfg(not(unix))] - let expected = vec![ - "entry.txt".to_string(), - "nested/".to_string(), - " child.txt".to_string(), - " deeper/".to_string(), - " grandchild.txt".to_string(), - ]; - - assert_eq!(entries, expected); - } - - #[tokio::test] - async fn errors_when_offset_exceeds_entries() { - let temp = tempdir().expect("create tempdir"); - let dir_path = temp.path(); - tokio::fs::create_dir(dir_path.join("nested")) - .await - .expect("create sub dir"); - - let err = list_dir_slice(dir_path, 10, 1, 2) - .await - .expect_err("offset exceeds entries"); - assert_eq!( - err, - FunctionCallError::RespondToModel("offset exceeds directory entry count".to_string()) - ); - } - - #[tokio::test] - async fn respects_depth_parameter() { - let temp = tempdir().expect("create tempdir"); - let dir_path = temp.path(); - let nested = dir_path.join("nested"); - let deeper = nested.join("deeper"); - tokio::fs::create_dir(&nested).await.expect("create nested"); - tokio::fs::create_dir(&deeper).await.expect("create deeper"); - tokio::fs::write(dir_path.join("root.txt"), b"root") - .await - .expect("write root"); - tokio::fs::write(nested.join("child.txt"), b"child") - .await - .expect("write nested"); - tokio::fs::write(deeper.join("grandchild.txt"), b"deep") - .await - .expect("write deeper"); - - let entries_depth_one = list_dir_slice(dir_path, 1, 10, 1) - .await - .expect("list depth 1"); - assert_eq!( - entries_depth_one, - vec!["nested/".to_string(), "root.txt".to_string(),] - ); - - let entries_depth_two = list_dir_slice(dir_path, 1, 20, 2) - .await - .expect("list depth 2"); - assert_eq!( - entries_depth_two, - vec![ - "nested/".to_string(), - " child.txt".to_string(), - " deeper/".to_string(), - "root.txt".to_string(), - ] - ); - - let entries_depth_three = list_dir_slice(dir_path, 1, 30, 3) - .await - .expect("list depth 3"); - assert_eq!( - entries_depth_three, - vec![ - "nested/".to_string(), - " child.txt".to_string(), - " deeper/".to_string(), - " grandchild.txt".to_string(), - "root.txt".to_string(), - ] - ); - } - - #[tokio::test] - async fn paginates_in_sorted_order() { - let temp = tempdir().expect("create tempdir"); - let dir_path = temp.path(); - - let dir_a = dir_path.join("a"); - let dir_b = dir_path.join("b"); - tokio::fs::create_dir(&dir_a).await.expect("create a"); - tokio::fs::create_dir(&dir_b).await.expect("create b"); - - tokio::fs::write(dir_a.join("a_child.txt"), b"a") - .await - .expect("write a child"); - tokio::fs::write(dir_b.join("b_child.txt"), b"b") - .await - .expect("write b child"); - - let first_page = list_dir_slice(dir_path, 1, 2, 2) - .await - .expect("list page one"); - assert_eq!( - first_page, - vec![ - "a/".to_string(), - " a_child.txt".to_string(), - "More than 2 entries found".to_string() - ] - ); - - let second_page = list_dir_slice(dir_path, 3, 2, 2) - .await - .expect("list page two"); - assert_eq!( - second_page, - vec!["b/".to_string(), " b_child.txt".to_string()] - ); - } - - #[tokio::test] - async fn handles_large_limit_without_overflow() { - let temp = tempdir().expect("create tempdir"); - let dir_path = temp.path(); - tokio::fs::write(dir_path.join("alpha.txt"), b"alpha") - .await - .expect("write alpha"); - tokio::fs::write(dir_path.join("beta.txt"), b"beta") - .await - .expect("write beta"); - tokio::fs::write(dir_path.join("gamma.txt"), b"gamma") - .await - .expect("write gamma"); - - let entries = list_dir_slice(dir_path, 2, usize::MAX, 1) - .await - .expect("list without overflow"); - assert_eq!( - entries, - vec!["beta.txt".to_string(), "gamma.txt".to_string(),] - ); - } - - #[tokio::test] - async fn indicates_truncated_results() { - let temp = tempdir().expect("create tempdir"); - let dir_path = temp.path(); - - for idx in 0..40 { - let file = dir_path.join(format!("file_{idx:02}.txt")); - tokio::fs::write(file, b"content") - .await - .expect("write file"); - } - - let entries = list_dir_slice(dir_path, 1, 25, 1) - .await - .expect("list directory"); - assert_eq!(entries.len(), 26); - assert_eq!( - entries.last(), - Some(&"More than 25 entries found".to_string()) - ); - } - - #[tokio::test] - async fn truncation_respects_sorted_order() -> anyhow::Result<()> { - let temp = tempdir()?; - let dir_path = temp.path(); - let nested = dir_path.join("nested"); - let deeper = nested.join("deeper"); - tokio::fs::create_dir(&nested).await?; - tokio::fs::create_dir(&deeper).await?; - tokio::fs::write(dir_path.join("root.txt"), b"root").await?; - tokio::fs::write(nested.join("child.txt"), b"child").await?; - tokio::fs::write(deeper.join("grandchild.txt"), b"deep").await?; - - let entries_depth_three = list_dir_slice(dir_path, 1, 3, 3).await?; - assert_eq!( - entries_depth_three, - vec![ - "nested/".to_string(), - " child.txt".to_string(), - " deeper/".to_string(), - "More than 3 entries found".to_string() - ] - ); - - Ok(()) - } -} +#[path = "list_dir_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/handlers/list_dir_tests.rs b/codex-rs/core/src/tools/handlers/list_dir_tests.rs new file mode 100644 index 0000000000..8e3991a758 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/list_dir_tests.rs @@ -0,0 +1,241 @@ +use super::*; +use pretty_assertions::assert_eq; +use tempfile::tempdir; + +#[tokio::test] +async fn lists_directory_entries() { + let temp = tempdir().expect("create tempdir"); + let dir_path = temp.path(); + + let sub_dir = dir_path.join("nested"); + tokio::fs::create_dir(&sub_dir) + .await + .expect("create sub dir"); + + let deeper_dir = sub_dir.join("deeper"); + tokio::fs::create_dir(&deeper_dir) + .await + .expect("create deeper dir"); + + tokio::fs::write(dir_path.join("entry.txt"), b"content") + .await + .expect("write file"); + tokio::fs::write(sub_dir.join("child.txt"), b"child") + .await + .expect("write child"); + tokio::fs::write(deeper_dir.join("grandchild.txt"), b"grandchild") + .await + .expect("write grandchild"); + + #[cfg(unix)] + { + use std::os::unix::fs::symlink; + let link_path = dir_path.join("link"); + symlink(dir_path.join("entry.txt"), &link_path).expect("create symlink"); + } + + let entries = list_dir_slice(dir_path, 1, 20, 3) + .await + .expect("list directory"); + + #[cfg(unix)] + let expected = vec![ + "entry.txt".to_string(), + "link@".to_string(), + "nested/".to_string(), + " child.txt".to_string(), + " deeper/".to_string(), + " grandchild.txt".to_string(), + ]; + + #[cfg(not(unix))] + let expected = vec![ + "entry.txt".to_string(), + "nested/".to_string(), + " child.txt".to_string(), + " deeper/".to_string(), + " grandchild.txt".to_string(), + ]; + + assert_eq!(entries, expected); +} + +#[tokio::test] +async fn errors_when_offset_exceeds_entries() { + let temp = tempdir().expect("create tempdir"); + let dir_path = temp.path(); + tokio::fs::create_dir(dir_path.join("nested")) + .await + .expect("create sub dir"); + + let err = list_dir_slice(dir_path, 10, 1, 2) + .await + .expect_err("offset exceeds entries"); + assert_eq!( + err, + FunctionCallError::RespondToModel("offset exceeds directory entry count".to_string()) + ); +} + +#[tokio::test] +async fn respects_depth_parameter() { + let temp = tempdir().expect("create tempdir"); + let dir_path = temp.path(); + let nested = dir_path.join("nested"); + let deeper = nested.join("deeper"); + tokio::fs::create_dir(&nested).await.expect("create nested"); + tokio::fs::create_dir(&deeper).await.expect("create deeper"); + tokio::fs::write(dir_path.join("root.txt"), b"root") + .await + .expect("write root"); + tokio::fs::write(nested.join("child.txt"), b"child") + .await + .expect("write nested"); + tokio::fs::write(deeper.join("grandchild.txt"), b"deep") + .await + .expect("write deeper"); + + let entries_depth_one = list_dir_slice(dir_path, 1, 10, 1) + .await + .expect("list depth 1"); + assert_eq!( + entries_depth_one, + vec!["nested/".to_string(), "root.txt".to_string(),] + ); + + let entries_depth_two = list_dir_slice(dir_path, 1, 20, 2) + .await + .expect("list depth 2"); + assert_eq!( + entries_depth_two, + vec![ + "nested/".to_string(), + " child.txt".to_string(), + " deeper/".to_string(), + "root.txt".to_string(), + ] + ); + + let entries_depth_three = list_dir_slice(dir_path, 1, 30, 3) + .await + .expect("list depth 3"); + assert_eq!( + entries_depth_three, + vec![ + "nested/".to_string(), + " child.txt".to_string(), + " deeper/".to_string(), + " grandchild.txt".to_string(), + "root.txt".to_string(), + ] + ); +} + +#[tokio::test] +async fn paginates_in_sorted_order() { + let temp = tempdir().expect("create tempdir"); + let dir_path = temp.path(); + + let dir_a = dir_path.join("a"); + let dir_b = dir_path.join("b"); + tokio::fs::create_dir(&dir_a).await.expect("create a"); + tokio::fs::create_dir(&dir_b).await.expect("create b"); + + tokio::fs::write(dir_a.join("a_child.txt"), b"a") + .await + .expect("write a child"); + tokio::fs::write(dir_b.join("b_child.txt"), b"b") + .await + .expect("write b child"); + + let first_page = list_dir_slice(dir_path, 1, 2, 2) + .await + .expect("list page one"); + assert_eq!( + first_page, + vec![ + "a/".to_string(), + " a_child.txt".to_string(), + "More than 2 entries found".to_string() + ] + ); + + let second_page = list_dir_slice(dir_path, 3, 2, 2) + .await + .expect("list page two"); + assert_eq!( + second_page, + vec!["b/".to_string(), " b_child.txt".to_string()] + ); +} + +#[tokio::test] +async fn handles_large_limit_without_overflow() { + let temp = tempdir().expect("create tempdir"); + let dir_path = temp.path(); + tokio::fs::write(dir_path.join("alpha.txt"), b"alpha") + .await + .expect("write alpha"); + tokio::fs::write(dir_path.join("beta.txt"), b"beta") + .await + .expect("write beta"); + tokio::fs::write(dir_path.join("gamma.txt"), b"gamma") + .await + .expect("write gamma"); + + let entries = list_dir_slice(dir_path, 2, usize::MAX, 1) + .await + .expect("list without overflow"); + assert_eq!( + entries, + vec!["beta.txt".to_string(), "gamma.txt".to_string(),] + ); +} + +#[tokio::test] +async fn indicates_truncated_results() { + let temp = tempdir().expect("create tempdir"); + let dir_path = temp.path(); + + for idx in 0..40 { + let file = dir_path.join(format!("file_{idx:02}.txt")); + tokio::fs::write(file, b"content") + .await + .expect("write file"); + } + + let entries = list_dir_slice(dir_path, 1, 25, 1) + .await + .expect("list directory"); + assert_eq!(entries.len(), 26); + assert_eq!( + entries.last(), + Some(&"More than 25 entries found".to_string()) + ); +} + +#[tokio::test] +async fn truncation_respects_sorted_order() -> anyhow::Result<()> { + let temp = tempdir()?; + let dir_path = temp.path(); + let nested = dir_path.join("nested"); + let deeper = nested.join("deeper"); + tokio::fs::create_dir(&nested).await?; + tokio::fs::create_dir(&deeper).await?; + tokio::fs::write(dir_path.join("root.txt"), b"root").await?; + tokio::fs::write(nested.join("child.txt"), b"child").await?; + tokio::fs::write(deeper.join("grandchild.txt"), b"deep").await?; + + let entries_depth_three = list_dir_slice(dir_path, 1, 3, 3).await?; + assert_eq!( + entries_depth_three, + vec![ + "nested/".to_string(), + " child.txt".to_string(), + " deeper/".to_string(), + "More than 3 entries found".to_string() + ] + ); + + Ok(()) +} diff --git a/codex-rs/core/src/tools/handlers/mcp_resource.rs b/codex-rs/core/src/tools/handlers/mcp_resource.rs index d1480063b5..02253d1080 100644 --- a/codex-rs/core/src/tools/handlers/mcp_resource.rs +++ b/codex-rs/core/src/tools/handlers/mcp_resource.rs @@ -663,131 +663,5 @@ where } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - use rmcp::model::AnnotateAble; - use serde_json::json; - - fn resource(uri: &str, name: &str) -> Resource { - rmcp::model::RawResource { - uri: uri.to_string(), - name: name.to_string(), - title: None, - description: None, - mime_type: None, - size: None, - icons: None, - meta: None, - } - .no_annotation() - } - - fn template(uri_template: &str, name: &str) -> ResourceTemplate { - rmcp::model::RawResourceTemplate { - uri_template: uri_template.to_string(), - name: name.to_string(), - title: None, - description: None, - mime_type: None, - icons: None, - } - .no_annotation() - } - - #[test] - fn resource_with_server_serializes_server_field() { - let entry = ResourceWithServer::new("test".to_string(), resource("memo://id", "memo")); - let value = serde_json::to_value(&entry).expect("serialize resource"); - - assert_eq!(value["server"], json!("test")); - assert_eq!(value["uri"], json!("memo://id")); - assert_eq!(value["name"], json!("memo")); - } - - #[test] - fn list_resources_payload_from_single_server_copies_next_cursor() { - let result = ListResourcesResult { - meta: None, - next_cursor: Some("cursor-1".to_string()), - resources: vec![resource("memo://id", "memo")], - }; - let payload = ListResourcesPayload::from_single_server("srv".to_string(), result); - let value = serde_json::to_value(&payload).expect("serialize payload"); - - assert_eq!(value["server"], json!("srv")); - assert_eq!(value["nextCursor"], json!("cursor-1")); - let resources = value["resources"].as_array().expect("resources array"); - assert_eq!(resources.len(), 1); - assert_eq!(resources[0]["server"], json!("srv")); - } - - #[test] - fn list_resources_payload_from_all_servers_is_sorted() { - let mut map = HashMap::new(); - map.insert("beta".to_string(), vec![resource("memo://b-1", "b-1")]); - map.insert( - "alpha".to_string(), - vec![resource("memo://a-1", "a-1"), resource("memo://a-2", "a-2")], - ); - - let payload = ListResourcesPayload::from_all_servers(map); - let value = serde_json::to_value(&payload).expect("serialize payload"); - let uris: Vec = value["resources"] - .as_array() - .expect("resources array") - .iter() - .map(|entry| entry["uri"].as_str().unwrap().to_string()) - .collect(); - - assert_eq!( - uris, - vec![ - "memo://a-1".to_string(), - "memo://a-2".to_string(), - "memo://b-1".to_string() - ] - ); - } - - #[test] - fn call_tool_result_from_content_marks_success() { - let result = call_tool_result_from_content("{}", Some(true)); - assert_eq!(result.is_error, Some(false)); - assert_eq!(result.content.len(), 1); - } - - #[test] - fn parse_arguments_handles_empty_and_json() { - assert!( - parse_arguments(" \n\t").unwrap().is_none(), - "expected None for empty arguments" - ); - - assert!( - parse_arguments("null").unwrap().is_none(), - "expected None for null arguments" - ); - - let value = parse_arguments(r#"{"server":"figma"}"#) - .expect("parse json") - .expect("value present"); - assert_eq!(value["server"], json!("figma")); - } - - #[test] - fn template_with_server_serializes_server_field() { - let entry = - ResourceTemplateWithServer::new("srv".to_string(), template("memo://{id}", "memo")); - let value = serde_json::to_value(&entry).expect("serialize template"); - - assert_eq!( - value, - json!({ - "server": "srv", - "uriTemplate": "memo://{id}", - "name": "memo" - }) - ); - } -} +#[path = "mcp_resource_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/handlers/mcp_resource_tests.rs b/codex-rs/core/src/tools/handlers/mcp_resource_tests.rs new file mode 100644 index 0000000000..8a8410b0bd --- /dev/null +++ b/codex-rs/core/src/tools/handlers/mcp_resource_tests.rs @@ -0,0 +1,125 @@ +use super::*; +use pretty_assertions::assert_eq; +use rmcp::model::AnnotateAble; +use serde_json::json; + +fn resource(uri: &str, name: &str) -> Resource { + rmcp::model::RawResource { + uri: uri.to_string(), + name: name.to_string(), + title: None, + description: None, + mime_type: None, + size: None, + icons: None, + meta: None, + } + .no_annotation() +} + +fn template(uri_template: &str, name: &str) -> ResourceTemplate { + rmcp::model::RawResourceTemplate { + uri_template: uri_template.to_string(), + name: name.to_string(), + title: None, + description: None, + mime_type: None, + icons: None, + } + .no_annotation() +} + +#[test] +fn resource_with_server_serializes_server_field() { + let entry = ResourceWithServer::new("test".to_string(), resource("memo://id", "memo")); + let value = serde_json::to_value(&entry).expect("serialize resource"); + + assert_eq!(value["server"], json!("test")); + assert_eq!(value["uri"], json!("memo://id")); + assert_eq!(value["name"], json!("memo")); +} + +#[test] +fn list_resources_payload_from_single_server_copies_next_cursor() { + let result = ListResourcesResult { + meta: None, + next_cursor: Some("cursor-1".to_string()), + resources: vec![resource("memo://id", "memo")], + }; + let payload = ListResourcesPayload::from_single_server("srv".to_string(), result); + let value = serde_json::to_value(&payload).expect("serialize payload"); + + assert_eq!(value["server"], json!("srv")); + assert_eq!(value["nextCursor"], json!("cursor-1")); + let resources = value["resources"].as_array().expect("resources array"); + assert_eq!(resources.len(), 1); + assert_eq!(resources[0]["server"], json!("srv")); +} + +#[test] +fn list_resources_payload_from_all_servers_is_sorted() { + let mut map = HashMap::new(); + map.insert("beta".to_string(), vec![resource("memo://b-1", "b-1")]); + map.insert( + "alpha".to_string(), + vec![resource("memo://a-1", "a-1"), resource("memo://a-2", "a-2")], + ); + + let payload = ListResourcesPayload::from_all_servers(map); + let value = serde_json::to_value(&payload).expect("serialize payload"); + let uris: Vec = value["resources"] + .as_array() + .expect("resources array") + .iter() + .map(|entry| entry["uri"].as_str().unwrap().to_string()) + .collect(); + + assert_eq!( + uris, + vec![ + "memo://a-1".to_string(), + "memo://a-2".to_string(), + "memo://b-1".to_string() + ] + ); +} + +#[test] +fn call_tool_result_from_content_marks_success() { + let result = call_tool_result_from_content("{}", Some(true)); + assert_eq!(result.is_error, Some(false)); + assert_eq!(result.content.len(), 1); +} + +#[test] +fn parse_arguments_handles_empty_and_json() { + assert!( + parse_arguments(" \n\t").unwrap().is_none(), + "expected None for empty arguments" + ); + + assert!( + parse_arguments("null").unwrap().is_none(), + "expected None for null arguments" + ); + + let value = parse_arguments(r#"{"server":"figma"}"#) + .expect("parse json") + .expect("value present"); + assert_eq!(value["server"], json!("figma")); +} + +#[test] +fn template_with_server_serializes_server_field() { + let entry = ResourceTemplateWithServer::new("srv".to_string(), template("memo://{id}", "memo")); + let value = serde_json::to_value(&entry).expect("serialize template"); + + assert_eq!( + value, + json!({ + "server": "srv", + "uriTemplate": "memo://{id}", + "name": "memo" + }) + ); +} diff --git a/codex-rs/core/src/tools/handlers/multi_agents.rs b/codex-rs/core/src/tools/handlers/multi_agents.rs index 38594b5029..193a4e5e7b 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents.rs @@ -1075,1115 +1075,5 @@ fn validate_spawn_agent_reasoning_effort( } #[cfg(test)] -mod tests { - use super::*; - use crate::AuthManager; - use crate::CodexAuth; - use crate::ThreadManager; - use crate::built_in_model_providers; - use crate::codex::make_session_and_context; - use crate::config::DEFAULT_AGENT_MAX_DEPTH; - use crate::config::types::ShellEnvironmentPolicy; - use crate::function_tool::FunctionCallError; - use crate::protocol::AskForApproval; - use crate::protocol::Op; - use crate::protocol::SandboxPolicy; - use crate::protocol::SessionSource; - use crate::protocol::SubAgentSource; - use crate::tools::context::FunctionToolOutput; - use crate::turn_diff_tracker::TurnDiffTracker; - use codex_protocol::ThreadId; - use codex_protocol::models::ContentItem; - use codex_protocol::models::ResponseItem; - use codex_protocol::protocol::InitialHistory; - use codex_protocol::protocol::RolloutItem; - use pretty_assertions::assert_eq; - use serde::Deserialize; - use serde_json::json; - use std::collections::HashMap; - use std::path::PathBuf; - use std::sync::Arc; - use std::time::Duration; - use tokio::sync::Mutex; - use tokio::time::timeout; - - fn invocation( - session: Arc, - turn: Arc, - tool_name: &str, - payload: ToolPayload, - ) -> ToolInvocation { - ToolInvocation { - session, - turn, - tracker: Arc::new(Mutex::new(TurnDiffTracker::default())), - call_id: "call-1".to_string(), - tool_name: tool_name.to_string(), - tool_namespace: None, - payload, - } - } - - fn function_payload(args: serde_json::Value) -> ToolPayload { - ToolPayload::Function { - arguments: args.to_string(), - } - } - - fn thread_manager() -> ThreadManager { - ThreadManager::with_models_provider_for_tests( - CodexAuth::from_api_key("dummy"), - built_in_model_providers()["openai"].clone(), - ) - } - - fn expect_text_output(output: FunctionToolOutput) -> (String, Option) { - ( - codex_protocol::models::function_call_output_content_items_to_text(&output.body) - .unwrap_or_default(), - output.success, - ) - } - - #[tokio::test] - async fn handler_rejects_non_function_payloads() { - let (session, turn) = make_session_and_context().await; - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "spawn_agent", - ToolPayload::Custom { - input: "hello".to_string(), - }, - ); - let Err(err) = MultiAgentHandler.handle(invocation).await else { - panic!("payload should be rejected"); - }; - assert_eq!( - err, - FunctionCallError::RespondToModel( - "collab handler received unsupported payload".to_string() - ) - ); - } - - #[tokio::test] - async fn handler_rejects_unknown_tool() { - let (session, turn) = make_session_and_context().await; - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "unknown_tool", - function_payload(json!({})), - ); - let Err(err) = MultiAgentHandler.handle(invocation).await else { - panic!("tool should be rejected"); - }; - assert_eq!( - err, - FunctionCallError::RespondToModel("unsupported collab tool unknown_tool".to_string()) - ); - } - - #[tokio::test] - async fn spawn_agent_rejects_empty_message() { - let (session, turn) = make_session_and_context().await; - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "spawn_agent", - function_payload(json!({"message": " "})), - ); - let Err(err) = MultiAgentHandler.handle(invocation).await else { - panic!("empty message should be rejected"); - }; - assert_eq!( - err, - FunctionCallError::RespondToModel( - "Empty message can't be sent to an agent".to_string() - ) - ); - } - - #[tokio::test] - async fn spawn_agent_rejects_when_message_and_items_are_both_set() { - let (session, turn) = make_session_and_context().await; - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "spawn_agent", - function_payload(json!({ - "message": "hello", - "items": [{"type": "mention", "name": "drive", "path": "app://drive"}] - })), - ); - let Err(err) = MultiAgentHandler.handle(invocation).await else { - panic!("message+items should be rejected"); - }; - assert_eq!( - err, - FunctionCallError::RespondToModel( - "Provide either message or items, but not both".to_string() - ) - ); - } - - #[tokio::test] - async fn spawn_agent_uses_explorer_role_and_preserves_approval_policy() { - #[derive(Debug, Deserialize)] - struct SpawnAgentResult { - agent_id: String, - nickname: Option, - } - - let (mut session, mut turn) = make_session_and_context().await; - let manager = thread_manager(); - session.services.agent_control = manager.agent_control(); - let mut config = (*turn.config).clone(); - let provider = built_in_model_providers()["ollama"].clone(); - config.model_provider_id = "ollama".to_string(); - config.model_provider = provider.clone(); - config - .permissions - .approval_policy - .set(AskForApproval::OnRequest) - .expect("approval policy should be set"); - turn.approval_policy - .set(AskForApproval::OnRequest) - .expect("approval policy should be set"); - turn.provider = provider; - turn.config = Arc::new(config); - - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "spawn_agent", - function_payload(json!({ - "message": "inspect this repo", - "agent_type": "explorer" - })), - ); - let output = MultiAgentHandler - .handle(invocation) - .await - .expect("spawn_agent should succeed"); - let (content, _) = expect_text_output(output); - let result: SpawnAgentResult = - serde_json::from_str(&content).expect("spawn_agent result should be json"); - let agent_id = agent_id(&result.agent_id).expect("agent_id should be valid"); - assert!( - result - .nickname - .as_deref() - .is_some_and(|nickname| !nickname.is_empty()) - ); - let snapshot = manager - .get_thread(agent_id) - .await - .expect("spawned agent thread should exist") - .config_snapshot() - .await; - assert_eq!(snapshot.approval_policy, AskForApproval::OnRequest); - assert_eq!(snapshot.model_provider_id, "ollama"); - } - - #[tokio::test] - async fn spawn_agent_errors_when_manager_dropped() { - let (session, turn) = make_session_and_context().await; - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "spawn_agent", - function_payload(json!({"message": "hello"})), - ); - let Err(err) = MultiAgentHandler.handle(invocation).await else { - panic!("spawn should fail without a manager"); - }; - assert_eq!( - err, - FunctionCallError::RespondToModel("collab manager unavailable".to_string()) - ); - } - - #[tokio::test] - async fn spawn_agent_reapplies_runtime_sandbox_after_role_config() { - fn pick_allowed_sandbox_policy( - constraint: &crate::config::Constrained, - base: SandboxPolicy, - ) -> SandboxPolicy { - let candidates = [ - SandboxPolicy::DangerFullAccess, - SandboxPolicy::new_workspace_write_policy(), - SandboxPolicy::new_read_only_policy(), - ]; - candidates - .into_iter() - .find(|candidate| *candidate != base && constraint.can_set(candidate).is_ok()) - .unwrap_or(base) - } - - #[derive(Debug, Deserialize)] - struct SpawnAgentResult { - agent_id: String, - nickname: Option, - } - - let (mut session, mut turn) = make_session_and_context().await; - let manager = thread_manager(); - session.services.agent_control = manager.agent_control(); - let expected_sandbox = pick_allowed_sandbox_policy( - &turn.config.permissions.sandbox_policy, - turn.config.permissions.sandbox_policy.get().clone(), - ); - turn.approval_policy - .set(AskForApproval::OnRequest) - .expect("approval policy should be set"); - turn.sandbox_policy - .set(expected_sandbox.clone()) - .expect("sandbox policy should be set"); - assert_ne!( - expected_sandbox, - turn.config.permissions.sandbox_policy.get().clone(), - "test requires a runtime sandbox override that differs from base config" - ); - - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "spawn_agent", - function_payload(json!({ - "message": "await this command", - "agent_type": "explorer" - })), - ); - let output = MultiAgentHandler - .handle(invocation) - .await - .expect("spawn_agent should succeed"); - let (content, _) = expect_text_output(output); - let result: SpawnAgentResult = - serde_json::from_str(&content).expect("spawn_agent result should be json"); - let agent_id = agent_id(&result.agent_id).expect("agent_id should be valid"); - assert!( - result - .nickname - .as_deref() - .is_some_and(|nickname| !nickname.is_empty()) - ); - - let snapshot = manager - .get_thread(agent_id) - .await - .expect("spawned agent thread should exist") - .config_snapshot() - .await; - assert_eq!(snapshot.sandbox_policy, expected_sandbox); - assert_eq!(snapshot.approval_policy, AskForApproval::OnRequest); - } - - #[tokio::test] - async fn spawn_agent_rejects_when_depth_limit_exceeded() { - let (mut session, mut turn) = make_session_and_context().await; - let manager = thread_manager(); - session.services.agent_control = manager.agent_control(); - - let max_depth = turn.config.agent_max_depth; - turn.session_source = SessionSource::SubAgent(SubAgentSource::ThreadSpawn { - parent_thread_id: session.conversation_id, - depth: max_depth, - agent_nickname: None, - agent_role: None, - }); - - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "spawn_agent", - function_payload(json!({"message": "hello"})), - ); - let Err(err) = MultiAgentHandler.handle(invocation).await else { - panic!("spawn should fail when depth limit exceeded"); - }; - assert_eq!( - err, - FunctionCallError::RespondToModel( - "Agent depth limit reached. Solve the task yourself.".to_string() - ) - ); - } - - #[tokio::test] - async fn spawn_agent_allows_depth_up_to_configured_max_depth() { - #[derive(Debug, Deserialize)] - struct SpawnAgentResult { - agent_id: String, - nickname: Option, - } - - let (mut session, mut turn) = make_session_and_context().await; - let manager = thread_manager(); - session.services.agent_control = manager.agent_control(); - - let mut config = (*turn.config).clone(); - config.agent_max_depth = DEFAULT_AGENT_MAX_DEPTH + 1; - turn.config = Arc::new(config); - turn.session_source = SessionSource::SubAgent(SubAgentSource::ThreadSpawn { - parent_thread_id: session.conversation_id, - depth: DEFAULT_AGENT_MAX_DEPTH, - agent_nickname: None, - agent_role: None, - }); - - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "spawn_agent", - function_payload(json!({"message": "hello"})), - ); - let output = MultiAgentHandler - .handle(invocation) - .await - .expect("spawn should succeed within configured depth"); - let (content, success) = expect_text_output(output); - let result: SpawnAgentResult = - serde_json::from_str(&content).expect("spawn_agent result should be json"); - assert!(!result.agent_id.is_empty()); - assert!( - result - .nickname - .as_deref() - .is_some_and(|nickname| !nickname.is_empty()) - ); - assert_eq!(success, Some(true)); - } - - #[tokio::test] - async fn send_input_rejects_empty_message() { - let (session, turn) = make_session_and_context().await; - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "send_input", - function_payload(json!({"id": ThreadId::new().to_string(), "message": ""})), - ); - let Err(err) = MultiAgentHandler.handle(invocation).await else { - panic!("empty message should be rejected"); - }; - assert_eq!( - err, - FunctionCallError::RespondToModel( - "Empty message can't be sent to an agent".to_string() - ) - ); - } - - #[tokio::test] - async fn send_input_rejects_when_message_and_items_are_both_set() { - let (session, turn) = make_session_and_context().await; - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "send_input", - function_payload(json!({ - "id": ThreadId::new().to_string(), - "message": "hello", - "items": [{"type": "mention", "name": "drive", "path": "app://drive"}] - })), - ); - let Err(err) = MultiAgentHandler.handle(invocation).await else { - panic!("message+items should be rejected"); - }; - assert_eq!( - err, - FunctionCallError::RespondToModel( - "Provide either message or items, but not both".to_string() - ) - ); - } - - #[tokio::test] - async fn send_input_rejects_invalid_id() { - let (session, turn) = make_session_and_context().await; - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "send_input", - function_payload(json!({"id": "not-a-uuid", "message": "hi"})), - ); - let Err(err) = MultiAgentHandler.handle(invocation).await else { - panic!("invalid id should be rejected"); - }; - let FunctionCallError::RespondToModel(msg) = err else { - panic!("expected respond-to-model error"); - }; - assert!(msg.starts_with("invalid agent id not-a-uuid:")); - } - - #[tokio::test] - async fn send_input_reports_missing_agent() { - let (mut session, turn) = make_session_and_context().await; - let manager = thread_manager(); - session.services.agent_control = manager.agent_control(); - let agent_id = ThreadId::new(); - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "send_input", - function_payload(json!({"id": agent_id.to_string(), "message": "hi"})), - ); - let Err(err) = MultiAgentHandler.handle(invocation).await else { - panic!("missing agent should be reported"); - }; - assert_eq!( - err, - FunctionCallError::RespondToModel(format!("agent with id {agent_id} not found")) - ); - } - - #[tokio::test] - async fn send_input_interrupts_before_prompt() { - let (mut session, turn) = make_session_and_context().await; - let manager = thread_manager(); - session.services.agent_control = manager.agent_control(); - let config = turn.config.as_ref().clone(); - let thread = manager.start_thread(config).await.expect("start thread"); - let agent_id = thread.thread_id; - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "send_input", - function_payload(json!({ - "id": agent_id.to_string(), - "message": "hi", - "interrupt": true - })), - ); - MultiAgentHandler - .handle(invocation) - .await - .expect("send_input should succeed"); - - let ops = manager.captured_ops(); - let ops_for_agent: Vec<&Op> = ops - .iter() - .filter_map(|(id, op)| (*id == agent_id).then_some(op)) - .collect(); - assert_eq!(ops_for_agent.len(), 2); - assert!(matches!(ops_for_agent[0], Op::Interrupt)); - assert!(matches!(ops_for_agent[1], Op::UserInput { .. })); - - let _ = thread - .thread - .submit(Op::Shutdown {}) - .await - .expect("shutdown should submit"); - } - - #[tokio::test] - async fn send_input_accepts_structured_items() { - let (mut session, turn) = make_session_and_context().await; - let manager = thread_manager(); - session.services.agent_control = manager.agent_control(); - let config = turn.config.as_ref().clone(); - let thread = manager.start_thread(config).await.expect("start thread"); - let agent_id = thread.thread_id; - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "send_input", - function_payload(json!({ - "id": agent_id.to_string(), - "items": [ - {"type": "mention", "name": "drive", "path": "app://google_drive"}, - {"type": "text", "text": "read the folder"} - ] - })), - ); - MultiAgentHandler - .handle(invocation) - .await - .expect("send_input should succeed"); - - let expected = Op::UserInput { - items: vec![ - UserInput::Mention { - name: "drive".to_string(), - path: "app://google_drive".to_string(), - }, - UserInput::Text { - text: "read the folder".to_string(), - text_elements: Vec::new(), - }, - ], - final_output_json_schema: None, - }; - let captured = manager - .captured_ops() - .into_iter() - .find(|(id, op)| *id == agent_id && *op == expected); - assert_eq!(captured, Some((agent_id, expected))); - - let _ = thread - .thread - .submit(Op::Shutdown {}) - .await - .expect("shutdown should submit"); - } - - #[tokio::test] - async fn resume_agent_rejects_invalid_id() { - let (session, turn) = make_session_and_context().await; - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "resume_agent", - function_payload(json!({"id": "not-a-uuid"})), - ); - let Err(err) = MultiAgentHandler.handle(invocation).await else { - panic!("invalid id should be rejected"); - }; - let FunctionCallError::RespondToModel(msg) = err else { - panic!("expected respond-to-model error"); - }; - assert!(msg.starts_with("invalid agent id not-a-uuid:")); - } - - #[tokio::test] - async fn resume_agent_reports_missing_agent() { - let (mut session, turn) = make_session_and_context().await; - let manager = thread_manager(); - session.services.agent_control = manager.agent_control(); - let agent_id = ThreadId::new(); - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "resume_agent", - function_payload(json!({"id": agent_id.to_string()})), - ); - let Err(err) = MultiAgentHandler.handle(invocation).await else { - panic!("missing agent should be reported"); - }; - assert_eq!( - err, - FunctionCallError::RespondToModel(format!("agent with id {agent_id} not found")) - ); - } - - #[tokio::test] - async fn resume_agent_noops_for_active_agent() { - let (mut session, turn) = make_session_and_context().await; - let manager = thread_manager(); - session.services.agent_control = manager.agent_control(); - let config = turn.config.as_ref().clone(); - let thread = manager.start_thread(config).await.expect("start thread"); - let agent_id = thread.thread_id; - let status_before = manager.agent_control().get_status(agent_id).await; - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "resume_agent", - function_payload(json!({"id": agent_id.to_string()})), - ); - - let output = MultiAgentHandler - .handle(invocation) - .await - .expect("resume_agent should succeed"); - let (content, success) = expect_text_output(output); - let result: resume_agent::ResumeAgentResult = - serde_json::from_str(&content).expect("resume_agent result should be json"); - assert_eq!(result.status, status_before); - assert_eq!(success, Some(true)); - - let thread_ids = manager.list_thread_ids().await; - assert_eq!(thread_ids, vec![agent_id]); - - let _ = thread - .thread - .submit(Op::Shutdown {}) - .await - .expect("shutdown should submit"); - } - - #[tokio::test] - async fn resume_agent_restores_closed_agent_and_accepts_send_input() { - let (mut session, turn) = make_session_and_context().await; - let manager = thread_manager(); - session.services.agent_control = manager.agent_control(); - let config = turn.config.as_ref().clone(); - let thread = manager - .resume_thread_with_history( - config, - InitialHistory::Forked(vec![RolloutItem::ResponseItem(ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { - text: "materialized".to_string(), - }], - end_turn: None, - phase: None, - })]), - AuthManager::from_auth_for_testing(CodexAuth::from_api_key("dummy")), - false, - None, - ) - .await - .expect("start thread"); - let agent_id = thread.thread_id; - let _ = manager - .agent_control() - .shutdown_agent(agent_id) - .await - .expect("shutdown agent"); - assert_eq!( - manager.agent_control().get_status(agent_id).await, - AgentStatus::NotFound - ); - let session = Arc::new(session); - let turn = Arc::new(turn); - - let resume_invocation = invocation( - session.clone(), - turn.clone(), - "resume_agent", - function_payload(json!({"id": agent_id.to_string()})), - ); - let output = MultiAgentHandler - .handle(resume_invocation) - .await - .expect("resume_agent should succeed"); - let (content, success) = expect_text_output(output); - let result: resume_agent::ResumeAgentResult = - serde_json::from_str(&content).expect("resume_agent result should be json"); - assert_ne!(result.status, AgentStatus::NotFound); - assert_eq!(success, Some(true)); - - let send_invocation = invocation( - session, - turn, - "send_input", - function_payload(json!({"id": agent_id.to_string(), "message": "hello"})), - ); - let output = MultiAgentHandler - .handle(send_invocation) - .await - .expect("send_input should succeed after resume"); - let (content, success) = expect_text_output(output); - let result: serde_json::Value = - serde_json::from_str(&content).expect("send_input result should be json"); - let submission_id = result - .get("submission_id") - .and_then(|value| value.as_str()) - .unwrap_or_default(); - assert!(!submission_id.is_empty()); - assert_eq!(success, Some(true)); - - let _ = manager - .agent_control() - .shutdown_agent(agent_id) - .await - .expect("shutdown resumed agent"); - } - - #[tokio::test] - async fn resume_agent_rejects_when_depth_limit_exceeded() { - let (mut session, mut turn) = make_session_and_context().await; - let manager = thread_manager(); - session.services.agent_control = manager.agent_control(); - - let max_depth = turn.config.agent_max_depth; - turn.session_source = SessionSource::SubAgent(SubAgentSource::ThreadSpawn { - parent_thread_id: session.conversation_id, - depth: max_depth, - agent_nickname: None, - agent_role: None, - }); - - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "resume_agent", - function_payload(json!({"id": ThreadId::new().to_string()})), - ); - let Err(err) = MultiAgentHandler.handle(invocation).await else { - panic!("resume should fail when depth limit exceeded"); - }; - assert_eq!( - err, - FunctionCallError::RespondToModel( - "Agent depth limit reached. Solve the task yourself.".to_string() - ) - ); - } - - #[tokio::test] - async fn wait_rejects_non_positive_timeout() { - let (session, turn) = make_session_and_context().await; - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "wait", - function_payload(json!({ - "ids": [ThreadId::new().to_string()], - "timeout_ms": 0 - })), - ); - let Err(err) = MultiAgentHandler.handle(invocation).await else { - panic!("non-positive timeout should be rejected"); - }; - assert_eq!( - err, - FunctionCallError::RespondToModel("timeout_ms must be greater than zero".to_string()) - ); - } - - #[tokio::test] - async fn wait_rejects_invalid_id() { - let (session, turn) = make_session_and_context().await; - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "wait", - function_payload(json!({"ids": ["invalid"]})), - ); - let Err(err) = MultiAgentHandler.handle(invocation).await else { - panic!("invalid id should be rejected"); - }; - let FunctionCallError::RespondToModel(msg) = err else { - panic!("expected respond-to-model error"); - }; - assert!(msg.starts_with("invalid agent id invalid:")); - } - - #[tokio::test] - async fn wait_rejects_empty_ids() { - let (session, turn) = make_session_and_context().await; - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "wait", - function_payload(json!({"ids": []})), - ); - let Err(err) = MultiAgentHandler.handle(invocation).await else { - panic!("empty ids should be rejected"); - }; - assert_eq!( - err, - FunctionCallError::RespondToModel("ids must be non-empty".to_string()) - ); - } - - #[tokio::test] - async fn wait_returns_not_found_for_missing_agents() { - let (mut session, turn) = make_session_and_context().await; - let manager = thread_manager(); - session.services.agent_control = manager.agent_control(); - let id_a = ThreadId::new(); - let id_b = ThreadId::new(); - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "wait", - function_payload(json!({ - "ids": [id_a.to_string(), id_b.to_string()], - "timeout_ms": 1000 - })), - ); - let output = MultiAgentHandler - .handle(invocation) - .await - .expect("wait should succeed"); - let (content, success) = expect_text_output(output); - let result: wait::WaitResult = - serde_json::from_str(&content).expect("wait result should be json"); - assert_eq!( - result, - wait::WaitResult { - status: HashMap::from([ - (id_a, AgentStatus::NotFound), - (id_b, AgentStatus::NotFound), - ]), - timed_out: false - } - ); - assert_eq!(success, None); - } - - #[tokio::test] - async fn wait_times_out_when_status_is_not_final() { - let (mut session, turn) = make_session_and_context().await; - let manager = thread_manager(); - session.services.agent_control = manager.agent_control(); - let config = turn.config.as_ref().clone(); - let thread = manager.start_thread(config).await.expect("start thread"); - let agent_id = thread.thread_id; - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "wait", - function_payload(json!({ - "ids": [agent_id.to_string()], - "timeout_ms": MIN_WAIT_TIMEOUT_MS - })), - ); - let output = MultiAgentHandler - .handle(invocation) - .await - .expect("wait should succeed"); - let (content, success) = expect_text_output(output); - let result: wait::WaitResult = - serde_json::from_str(&content).expect("wait result should be json"); - assert_eq!( - result, - wait::WaitResult { - status: HashMap::new(), - timed_out: true - } - ); - assert_eq!(success, None); - - let _ = thread - .thread - .submit(Op::Shutdown {}) - .await - .expect("shutdown should submit"); - } - - #[tokio::test] - async fn wait_clamps_short_timeouts_to_minimum() { - let (mut session, turn) = make_session_and_context().await; - let manager = thread_manager(); - session.services.agent_control = manager.agent_control(); - let config = turn.config.as_ref().clone(); - let thread = manager.start_thread(config).await.expect("start thread"); - let agent_id = thread.thread_id; - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "wait", - function_payload(json!({ - "ids": [agent_id.to_string()], - "timeout_ms": 10 - })), - ); - - let early = timeout( - Duration::from_millis(50), - MultiAgentHandler.handle(invocation), - ) - .await; - assert!( - early.is_err(), - "wait should not return before the minimum timeout clamp" - ); - - let _ = thread - .thread - .submit(Op::Shutdown {}) - .await - .expect("shutdown should submit"); - } - - #[tokio::test] - async fn wait_returns_final_status_without_timeout() { - let (mut session, turn) = make_session_and_context().await; - let manager = thread_manager(); - session.services.agent_control = manager.agent_control(); - let config = turn.config.as_ref().clone(); - let thread = manager.start_thread(config).await.expect("start thread"); - let agent_id = thread.thread_id; - let mut status_rx = manager - .agent_control() - .subscribe_status(agent_id) - .await - .expect("subscribe should succeed"); - - let _ = thread - .thread - .submit(Op::Shutdown {}) - .await - .expect("shutdown should submit"); - let _ = timeout(Duration::from_secs(1), status_rx.changed()) - .await - .expect("shutdown status should arrive"); - - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "wait", - function_payload(json!({ - "ids": [agent_id.to_string()], - "timeout_ms": 1000 - })), - ); - let output = MultiAgentHandler - .handle(invocation) - .await - .expect("wait should succeed"); - let (content, success) = expect_text_output(output); - let result: wait::WaitResult = - serde_json::from_str(&content).expect("wait result should be json"); - assert_eq!( - result, - wait::WaitResult { - status: HashMap::from([(agent_id, AgentStatus::Shutdown)]), - timed_out: false - } - ); - assert_eq!(success, None); - } - - #[tokio::test] - async fn close_agent_submits_shutdown_and_returns_status() { - let (mut session, turn) = make_session_and_context().await; - let manager = thread_manager(); - session.services.agent_control = manager.agent_control(); - let config = turn.config.as_ref().clone(); - let thread = manager.start_thread(config).await.expect("start thread"); - let agent_id = thread.thread_id; - let status_before = manager.agent_control().get_status(agent_id).await; - - let invocation = invocation( - Arc::new(session), - Arc::new(turn), - "close_agent", - function_payload(json!({"id": agent_id.to_string()})), - ); - let output = MultiAgentHandler - .handle(invocation) - .await - .expect("close_agent should succeed"); - let (content, success) = expect_text_output(output); - let result: close_agent::CloseAgentResult = - serde_json::from_str(&content).expect("close_agent result should be json"); - assert_eq!(result.status, status_before); - assert_eq!(success, Some(true)); - - let ops = manager.captured_ops(); - let submitted_shutdown = ops - .iter() - .any(|(id, op)| *id == agent_id && matches!(op, Op::Shutdown)); - assert_eq!(submitted_shutdown, true); - - let status_after = manager.agent_control().get_status(agent_id).await; - assert_eq!(status_after, AgentStatus::NotFound); - } - - #[tokio::test] - async fn build_agent_spawn_config_uses_turn_context_values() { - fn pick_allowed_sandbox_policy( - constraint: &crate::config::Constrained, - base: SandboxPolicy, - ) -> SandboxPolicy { - let candidates = [ - SandboxPolicy::new_read_only_policy(), - SandboxPolicy::new_workspace_write_policy(), - SandboxPolicy::DangerFullAccess, - ]; - candidates - .into_iter() - .find(|candidate| *candidate != base && constraint.can_set(candidate).is_ok()) - .unwrap_or(base) - } - - let (_session, mut turn) = make_session_and_context().await; - let base_instructions = BaseInstructions { - text: "base".to_string(), - }; - turn.developer_instructions = Some("dev".to_string()); - turn.compact_prompt = Some("compact".to_string()); - turn.shell_environment_policy = ShellEnvironmentPolicy { - use_profile: true, - ..ShellEnvironmentPolicy::default() - }; - let temp_dir = tempfile::tempdir().expect("temp dir"); - turn.cwd = temp_dir.path().to_path_buf(); - turn.codex_linux_sandbox_exe = Some(PathBuf::from("/bin/echo")); - let sandbox_policy = pick_allowed_sandbox_policy( - &turn.config.permissions.sandbox_policy, - turn.config.permissions.sandbox_policy.get().clone(), - ); - turn.sandbox_policy - .set(sandbox_policy) - .expect("sandbox policy set"); - turn.approval_policy - .set(AskForApproval::OnRequest) - .expect("approval policy set"); - - let config = build_agent_spawn_config(&base_instructions, &turn).expect("spawn config"); - let mut expected = (*turn.config).clone(); - expected.base_instructions = Some(base_instructions.text); - expected.model = Some(turn.model_info.slug.clone()); - expected.model_provider = turn.provider.clone(); - expected.model_reasoning_effort = turn.reasoning_effort; - expected.model_reasoning_summary = Some(turn.reasoning_summary); - expected.developer_instructions = turn.developer_instructions.clone(); - expected.compact_prompt = turn.compact_prompt.clone(); - expected.permissions.shell_environment_policy = turn.shell_environment_policy.clone(); - expected.codex_linux_sandbox_exe = turn.codex_linux_sandbox_exe.clone(); - expected.cwd = turn.cwd.clone(); - expected - .permissions - .approval_policy - .set(AskForApproval::OnRequest) - .expect("approval policy set"); - expected - .permissions - .sandbox_policy - .set(turn.sandbox_policy.get().clone()) - .expect("sandbox policy set"); - assert_eq!(config, expected); - } - - #[tokio::test] - async fn build_agent_spawn_config_preserves_base_user_instructions() { - let (_session, mut turn) = make_session_and_context().await; - let mut base_config = (*turn.config).clone(); - base_config.user_instructions = Some("base-user".to_string()); - turn.user_instructions = Some("resolved-user".to_string()); - turn.config = Arc::new(base_config.clone()); - let base_instructions = BaseInstructions { - text: "base".to_string(), - }; - - let config = build_agent_spawn_config(&base_instructions, &turn).expect("spawn config"); - - assert_eq!(config.user_instructions, base_config.user_instructions); - } - - #[tokio::test] - async fn build_agent_resume_config_clears_base_instructions() { - let (_session, mut turn) = make_session_and_context().await; - let mut base_config = (*turn.config).clone(); - base_config.base_instructions = Some("caller-base".to_string()); - turn.config = Arc::new(base_config); - turn.approval_policy - .set(AskForApproval::OnRequest) - .expect("approval policy set"); - - let config = build_agent_resume_config(&turn, 0).expect("resume config"); - - let mut expected = (*turn.config).clone(); - expected.base_instructions = None; - expected.model = Some(turn.model_info.slug.clone()); - expected.model_provider = turn.provider.clone(); - expected.model_reasoning_effort = turn.reasoning_effort; - expected.model_reasoning_summary = Some(turn.reasoning_summary); - expected.developer_instructions = turn.developer_instructions.clone(); - expected.compact_prompt = turn.compact_prompt.clone(); - expected.permissions.shell_environment_policy = turn.shell_environment_policy.clone(); - expected.codex_linux_sandbox_exe = turn.codex_linux_sandbox_exe.clone(); - expected.cwd = turn.cwd.clone(); - expected - .permissions - .approval_policy - .set(AskForApproval::OnRequest) - .expect("approval policy set"); - expected - .permissions - .sandbox_policy - .set(turn.sandbox_policy.get().clone()) - .expect("sandbox policy set"); - assert_eq!(config, expected); - } -} +#[path = "multi_agents_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/handlers/multi_agents_tests.rs b/codex-rs/core/src/tools/handlers/multi_agents_tests.rs new file mode 100644 index 0000000000..1aee1fbf1c --- /dev/null +++ b/codex-rs/core/src/tools/handlers/multi_agents_tests.rs @@ -0,0 +1,1103 @@ +use super::*; +use crate::AuthManager; +use crate::CodexAuth; +use crate::ThreadManager; +use crate::built_in_model_providers; +use crate::codex::make_session_and_context; +use crate::config::DEFAULT_AGENT_MAX_DEPTH; +use crate::config::types::ShellEnvironmentPolicy; +use crate::function_tool::FunctionCallError; +use crate::protocol::AskForApproval; +use crate::protocol::Op; +use crate::protocol::SandboxPolicy; +use crate::protocol::SessionSource; +use crate::protocol::SubAgentSource; +use crate::tools::context::FunctionToolOutput; +use crate::turn_diff_tracker::TurnDiffTracker; +use codex_protocol::ThreadId; +use codex_protocol::models::ContentItem; +use codex_protocol::models::ResponseItem; +use codex_protocol::protocol::InitialHistory; +use codex_protocol::protocol::RolloutItem; +use pretty_assertions::assert_eq; +use serde::Deserialize; +use serde_json::json; +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Mutex; +use tokio::time::timeout; + +fn invocation( + session: Arc, + turn: Arc, + tool_name: &str, + payload: ToolPayload, +) -> ToolInvocation { + ToolInvocation { + session, + turn, + tracker: Arc::new(Mutex::new(TurnDiffTracker::default())), + call_id: "call-1".to_string(), + tool_name: tool_name.to_string(), + tool_namespace: None, + payload, + } +} + +fn function_payload(args: serde_json::Value) -> ToolPayload { + ToolPayload::Function { + arguments: args.to_string(), + } +} + +fn thread_manager() -> ThreadManager { + ThreadManager::with_models_provider_for_tests( + CodexAuth::from_api_key("dummy"), + built_in_model_providers()["openai"].clone(), + ) +} + +fn expect_text_output(output: FunctionToolOutput) -> (String, Option) { + ( + codex_protocol::models::function_call_output_content_items_to_text(&output.body) + .unwrap_or_default(), + output.success, + ) +} + +#[tokio::test] +async fn handler_rejects_non_function_payloads() { + let (session, turn) = make_session_and_context().await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "spawn_agent", + ToolPayload::Custom { + input: "hello".to_string(), + }, + ); + let Err(err) = MultiAgentHandler.handle(invocation).await else { + panic!("payload should be rejected"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel( + "collab handler received unsupported payload".to_string() + ) + ); +} + +#[tokio::test] +async fn handler_rejects_unknown_tool() { + let (session, turn) = make_session_and_context().await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "unknown_tool", + function_payload(json!({})), + ); + let Err(err) = MultiAgentHandler.handle(invocation).await else { + panic!("tool should be rejected"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel("unsupported collab tool unknown_tool".to_string()) + ); +} + +#[tokio::test] +async fn spawn_agent_rejects_empty_message() { + let (session, turn) = make_session_and_context().await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "spawn_agent", + function_payload(json!({"message": " "})), + ); + let Err(err) = MultiAgentHandler.handle(invocation).await else { + panic!("empty message should be rejected"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel("Empty message can't be sent to an agent".to_string()) + ); +} + +#[tokio::test] +async fn spawn_agent_rejects_when_message_and_items_are_both_set() { + let (session, turn) = make_session_and_context().await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "spawn_agent", + function_payload(json!({ + "message": "hello", + "items": [{"type": "mention", "name": "drive", "path": "app://drive"}] + })), + ); + let Err(err) = MultiAgentHandler.handle(invocation).await else { + panic!("message+items should be rejected"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel( + "Provide either message or items, but not both".to_string() + ) + ); +} + +#[tokio::test] +async fn spawn_agent_uses_explorer_role_and_preserves_approval_policy() { + #[derive(Debug, Deserialize)] + struct SpawnAgentResult { + agent_id: String, + nickname: Option, + } + + let (mut session, mut turn) = make_session_and_context().await; + let manager = thread_manager(); + session.services.agent_control = manager.agent_control(); + let mut config = (*turn.config).clone(); + let provider = built_in_model_providers()["ollama"].clone(); + config.model_provider_id = "ollama".to_string(); + config.model_provider = provider.clone(); + config + .permissions + .approval_policy + .set(AskForApproval::OnRequest) + .expect("approval policy should be set"); + turn.approval_policy + .set(AskForApproval::OnRequest) + .expect("approval policy should be set"); + turn.provider = provider; + turn.config = Arc::new(config); + + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "spawn_agent", + function_payload(json!({ + "message": "inspect this repo", + "agent_type": "explorer" + })), + ); + let output = MultiAgentHandler + .handle(invocation) + .await + .expect("spawn_agent should succeed"); + let (content, _) = expect_text_output(output); + let result: SpawnAgentResult = + serde_json::from_str(&content).expect("spawn_agent result should be json"); + let agent_id = agent_id(&result.agent_id).expect("agent_id should be valid"); + assert!( + result + .nickname + .as_deref() + .is_some_and(|nickname| !nickname.is_empty()) + ); + let snapshot = manager + .get_thread(agent_id) + .await + .expect("spawned agent thread should exist") + .config_snapshot() + .await; + assert_eq!(snapshot.approval_policy, AskForApproval::OnRequest); + assert_eq!(snapshot.model_provider_id, "ollama"); +} + +#[tokio::test] +async fn spawn_agent_errors_when_manager_dropped() { + let (session, turn) = make_session_and_context().await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "spawn_agent", + function_payload(json!({"message": "hello"})), + ); + let Err(err) = MultiAgentHandler.handle(invocation).await else { + panic!("spawn should fail without a manager"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel("collab manager unavailable".to_string()) + ); +} + +#[tokio::test] +async fn spawn_agent_reapplies_runtime_sandbox_after_role_config() { + fn pick_allowed_sandbox_policy( + constraint: &crate::config::Constrained, + base: SandboxPolicy, + ) -> SandboxPolicy { + let candidates = [ + SandboxPolicy::DangerFullAccess, + SandboxPolicy::new_workspace_write_policy(), + SandboxPolicy::new_read_only_policy(), + ]; + candidates + .into_iter() + .find(|candidate| *candidate != base && constraint.can_set(candidate).is_ok()) + .unwrap_or(base) + } + + #[derive(Debug, Deserialize)] + struct SpawnAgentResult { + agent_id: String, + nickname: Option, + } + + let (mut session, mut turn) = make_session_and_context().await; + let manager = thread_manager(); + session.services.agent_control = manager.agent_control(); + let expected_sandbox = pick_allowed_sandbox_policy( + &turn.config.permissions.sandbox_policy, + turn.config.permissions.sandbox_policy.get().clone(), + ); + turn.approval_policy + .set(AskForApproval::OnRequest) + .expect("approval policy should be set"); + turn.sandbox_policy + .set(expected_sandbox.clone()) + .expect("sandbox policy should be set"); + assert_ne!( + expected_sandbox, + turn.config.permissions.sandbox_policy.get().clone(), + "test requires a runtime sandbox override that differs from base config" + ); + + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "spawn_agent", + function_payload(json!({ + "message": "await this command", + "agent_type": "explorer" + })), + ); + let output = MultiAgentHandler + .handle(invocation) + .await + .expect("spawn_agent should succeed"); + let (content, _) = expect_text_output(output); + let result: SpawnAgentResult = + serde_json::from_str(&content).expect("spawn_agent result should be json"); + let agent_id = agent_id(&result.agent_id).expect("agent_id should be valid"); + assert!( + result + .nickname + .as_deref() + .is_some_and(|nickname| !nickname.is_empty()) + ); + + let snapshot = manager + .get_thread(agent_id) + .await + .expect("spawned agent thread should exist") + .config_snapshot() + .await; + assert_eq!(snapshot.sandbox_policy, expected_sandbox); + assert_eq!(snapshot.approval_policy, AskForApproval::OnRequest); +} + +#[tokio::test] +async fn spawn_agent_rejects_when_depth_limit_exceeded() { + let (mut session, mut turn) = make_session_and_context().await; + let manager = thread_manager(); + session.services.agent_control = manager.agent_control(); + + let max_depth = turn.config.agent_max_depth; + turn.session_source = SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id: session.conversation_id, + depth: max_depth, + agent_nickname: None, + agent_role: None, + }); + + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "spawn_agent", + function_payload(json!({"message": "hello"})), + ); + let Err(err) = MultiAgentHandler.handle(invocation).await else { + panic!("spawn should fail when depth limit exceeded"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel( + "Agent depth limit reached. Solve the task yourself.".to_string() + ) + ); +} + +#[tokio::test] +async fn spawn_agent_allows_depth_up_to_configured_max_depth() { + #[derive(Debug, Deserialize)] + struct SpawnAgentResult { + agent_id: String, + nickname: Option, + } + + let (mut session, mut turn) = make_session_and_context().await; + let manager = thread_manager(); + session.services.agent_control = manager.agent_control(); + + let mut config = (*turn.config).clone(); + config.agent_max_depth = DEFAULT_AGENT_MAX_DEPTH + 1; + turn.config = Arc::new(config); + turn.session_source = SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id: session.conversation_id, + depth: DEFAULT_AGENT_MAX_DEPTH, + agent_nickname: None, + agent_role: None, + }); + + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "spawn_agent", + function_payload(json!({"message": "hello"})), + ); + let output = MultiAgentHandler + .handle(invocation) + .await + .expect("spawn should succeed within configured depth"); + let (content, success) = expect_text_output(output); + let result: SpawnAgentResult = + serde_json::from_str(&content).expect("spawn_agent result should be json"); + assert!(!result.agent_id.is_empty()); + assert!( + result + .nickname + .as_deref() + .is_some_and(|nickname| !nickname.is_empty()) + ); + assert_eq!(success, Some(true)); +} + +#[tokio::test] +async fn send_input_rejects_empty_message() { + let (session, turn) = make_session_and_context().await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "send_input", + function_payload(json!({"id": ThreadId::new().to_string(), "message": ""})), + ); + let Err(err) = MultiAgentHandler.handle(invocation).await else { + panic!("empty message should be rejected"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel("Empty message can't be sent to an agent".to_string()) + ); +} + +#[tokio::test] +async fn send_input_rejects_when_message_and_items_are_both_set() { + let (session, turn) = make_session_and_context().await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "send_input", + function_payload(json!({ + "id": ThreadId::new().to_string(), + "message": "hello", + "items": [{"type": "mention", "name": "drive", "path": "app://drive"}] + })), + ); + let Err(err) = MultiAgentHandler.handle(invocation).await else { + panic!("message+items should be rejected"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel( + "Provide either message or items, but not both".to_string() + ) + ); +} + +#[tokio::test] +async fn send_input_rejects_invalid_id() { + let (session, turn) = make_session_and_context().await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "send_input", + function_payload(json!({"id": "not-a-uuid", "message": "hi"})), + ); + let Err(err) = MultiAgentHandler.handle(invocation).await else { + panic!("invalid id should be rejected"); + }; + let FunctionCallError::RespondToModel(msg) = err else { + panic!("expected respond-to-model error"); + }; + assert!(msg.starts_with("invalid agent id not-a-uuid:")); +} + +#[tokio::test] +async fn send_input_reports_missing_agent() { + let (mut session, turn) = make_session_and_context().await; + let manager = thread_manager(); + session.services.agent_control = manager.agent_control(); + let agent_id = ThreadId::new(); + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "send_input", + function_payload(json!({"id": agent_id.to_string(), "message": "hi"})), + ); + let Err(err) = MultiAgentHandler.handle(invocation).await else { + panic!("missing agent should be reported"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel(format!("agent with id {agent_id} not found")) + ); +} + +#[tokio::test] +async fn send_input_interrupts_before_prompt() { + let (mut session, turn) = make_session_and_context().await; + let manager = thread_manager(); + session.services.agent_control = manager.agent_control(); + let config = turn.config.as_ref().clone(); + let thread = manager.start_thread(config).await.expect("start thread"); + let agent_id = thread.thread_id; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "send_input", + function_payload(json!({ + "id": agent_id.to_string(), + "message": "hi", + "interrupt": true + })), + ); + MultiAgentHandler + .handle(invocation) + .await + .expect("send_input should succeed"); + + let ops = manager.captured_ops(); + let ops_for_agent: Vec<&Op> = ops + .iter() + .filter_map(|(id, op)| (*id == agent_id).then_some(op)) + .collect(); + assert_eq!(ops_for_agent.len(), 2); + assert!(matches!(ops_for_agent[0], Op::Interrupt)); + assert!(matches!(ops_for_agent[1], Op::UserInput { .. })); + + let _ = thread + .thread + .submit(Op::Shutdown {}) + .await + .expect("shutdown should submit"); +} + +#[tokio::test] +async fn send_input_accepts_structured_items() { + let (mut session, turn) = make_session_and_context().await; + let manager = thread_manager(); + session.services.agent_control = manager.agent_control(); + let config = turn.config.as_ref().clone(); + let thread = manager.start_thread(config).await.expect("start thread"); + let agent_id = thread.thread_id; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "send_input", + function_payload(json!({ + "id": agent_id.to_string(), + "items": [ + {"type": "mention", "name": "drive", "path": "app://google_drive"}, + {"type": "text", "text": "read the folder"} + ] + })), + ); + MultiAgentHandler + .handle(invocation) + .await + .expect("send_input should succeed"); + + let expected = Op::UserInput { + items: vec![ + UserInput::Mention { + name: "drive".to_string(), + path: "app://google_drive".to_string(), + }, + UserInput::Text { + text: "read the folder".to_string(), + text_elements: Vec::new(), + }, + ], + final_output_json_schema: None, + }; + let captured = manager + .captured_ops() + .into_iter() + .find(|(id, op)| *id == agent_id && *op == expected); + assert_eq!(captured, Some((agent_id, expected))); + + let _ = thread + .thread + .submit(Op::Shutdown {}) + .await + .expect("shutdown should submit"); +} + +#[tokio::test] +async fn resume_agent_rejects_invalid_id() { + let (session, turn) = make_session_and_context().await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "resume_agent", + function_payload(json!({"id": "not-a-uuid"})), + ); + let Err(err) = MultiAgentHandler.handle(invocation).await else { + panic!("invalid id should be rejected"); + }; + let FunctionCallError::RespondToModel(msg) = err else { + panic!("expected respond-to-model error"); + }; + assert!(msg.starts_with("invalid agent id not-a-uuid:")); +} + +#[tokio::test] +async fn resume_agent_reports_missing_agent() { + let (mut session, turn) = make_session_and_context().await; + let manager = thread_manager(); + session.services.agent_control = manager.agent_control(); + let agent_id = ThreadId::new(); + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "resume_agent", + function_payload(json!({"id": agent_id.to_string()})), + ); + let Err(err) = MultiAgentHandler.handle(invocation).await else { + panic!("missing agent should be reported"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel(format!("agent with id {agent_id} not found")) + ); +} + +#[tokio::test] +async fn resume_agent_noops_for_active_agent() { + let (mut session, turn) = make_session_and_context().await; + let manager = thread_manager(); + session.services.agent_control = manager.agent_control(); + let config = turn.config.as_ref().clone(); + let thread = manager.start_thread(config).await.expect("start thread"); + let agent_id = thread.thread_id; + let status_before = manager.agent_control().get_status(agent_id).await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "resume_agent", + function_payload(json!({"id": agent_id.to_string()})), + ); + + let output = MultiAgentHandler + .handle(invocation) + .await + .expect("resume_agent should succeed"); + let (content, success) = expect_text_output(output); + let result: resume_agent::ResumeAgentResult = + serde_json::from_str(&content).expect("resume_agent result should be json"); + assert_eq!(result.status, status_before); + assert_eq!(success, Some(true)); + + let thread_ids = manager.list_thread_ids().await; + assert_eq!(thread_ids, vec![agent_id]); + + let _ = thread + .thread + .submit(Op::Shutdown {}) + .await + .expect("shutdown should submit"); +} + +#[tokio::test] +async fn resume_agent_restores_closed_agent_and_accepts_send_input() { + let (mut session, turn) = make_session_and_context().await; + let manager = thread_manager(); + session.services.agent_control = manager.agent_control(); + let config = turn.config.as_ref().clone(); + let thread = manager + .resume_thread_with_history( + config, + InitialHistory::Forked(vec![RolloutItem::ResponseItem(ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "materialized".to_string(), + }], + end_turn: None, + phase: None, + })]), + AuthManager::from_auth_for_testing(CodexAuth::from_api_key("dummy")), + false, + None, + ) + .await + .expect("start thread"); + let agent_id = thread.thread_id; + let _ = manager + .agent_control() + .shutdown_agent(agent_id) + .await + .expect("shutdown agent"); + assert_eq!( + manager.agent_control().get_status(agent_id).await, + AgentStatus::NotFound + ); + let session = Arc::new(session); + let turn = Arc::new(turn); + + let resume_invocation = invocation( + session.clone(), + turn.clone(), + "resume_agent", + function_payload(json!({"id": agent_id.to_string()})), + ); + let output = MultiAgentHandler + .handle(resume_invocation) + .await + .expect("resume_agent should succeed"); + let (content, success) = expect_text_output(output); + let result: resume_agent::ResumeAgentResult = + serde_json::from_str(&content).expect("resume_agent result should be json"); + assert_ne!(result.status, AgentStatus::NotFound); + assert_eq!(success, Some(true)); + + let send_invocation = invocation( + session, + turn, + "send_input", + function_payload(json!({"id": agent_id.to_string(), "message": "hello"})), + ); + let output = MultiAgentHandler + .handle(send_invocation) + .await + .expect("send_input should succeed after resume"); + let (content, success) = expect_text_output(output); + let result: serde_json::Value = + serde_json::from_str(&content).expect("send_input result should be json"); + let submission_id = result + .get("submission_id") + .and_then(|value| value.as_str()) + .unwrap_or_default(); + assert!(!submission_id.is_empty()); + assert_eq!(success, Some(true)); + + let _ = manager + .agent_control() + .shutdown_agent(agent_id) + .await + .expect("shutdown resumed agent"); +} + +#[tokio::test] +async fn resume_agent_rejects_when_depth_limit_exceeded() { + let (mut session, mut turn) = make_session_and_context().await; + let manager = thread_manager(); + session.services.agent_control = manager.agent_control(); + + let max_depth = turn.config.agent_max_depth; + turn.session_source = SessionSource::SubAgent(SubAgentSource::ThreadSpawn { + parent_thread_id: session.conversation_id, + depth: max_depth, + agent_nickname: None, + agent_role: None, + }); + + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "resume_agent", + function_payload(json!({"id": ThreadId::new().to_string()})), + ); + let Err(err) = MultiAgentHandler.handle(invocation).await else { + panic!("resume should fail when depth limit exceeded"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel( + "Agent depth limit reached. Solve the task yourself.".to_string() + ) + ); +} + +#[tokio::test] +async fn wait_rejects_non_positive_timeout() { + let (session, turn) = make_session_and_context().await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "wait", + function_payload(json!({ + "ids": [ThreadId::new().to_string()], + "timeout_ms": 0 + })), + ); + let Err(err) = MultiAgentHandler.handle(invocation).await else { + panic!("non-positive timeout should be rejected"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel("timeout_ms must be greater than zero".to_string()) + ); +} + +#[tokio::test] +async fn wait_rejects_invalid_id() { + let (session, turn) = make_session_and_context().await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "wait", + function_payload(json!({"ids": ["invalid"]})), + ); + let Err(err) = MultiAgentHandler.handle(invocation).await else { + panic!("invalid id should be rejected"); + }; + let FunctionCallError::RespondToModel(msg) = err else { + panic!("expected respond-to-model error"); + }; + assert!(msg.starts_with("invalid agent id invalid:")); +} + +#[tokio::test] +async fn wait_rejects_empty_ids() { + let (session, turn) = make_session_and_context().await; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "wait", + function_payload(json!({"ids": []})), + ); + let Err(err) = MultiAgentHandler.handle(invocation).await else { + panic!("empty ids should be rejected"); + }; + assert_eq!( + err, + FunctionCallError::RespondToModel("ids must be non-empty".to_string()) + ); +} + +#[tokio::test] +async fn wait_returns_not_found_for_missing_agents() { + let (mut session, turn) = make_session_and_context().await; + let manager = thread_manager(); + session.services.agent_control = manager.agent_control(); + let id_a = ThreadId::new(); + let id_b = ThreadId::new(); + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "wait", + function_payload(json!({ + "ids": [id_a.to_string(), id_b.to_string()], + "timeout_ms": 1000 + })), + ); + let output = MultiAgentHandler + .handle(invocation) + .await + .expect("wait should succeed"); + let (content, success) = expect_text_output(output); + let result: wait::WaitResult = + serde_json::from_str(&content).expect("wait result should be json"); + assert_eq!( + result, + wait::WaitResult { + status: HashMap::from([(id_a, AgentStatus::NotFound), (id_b, AgentStatus::NotFound),]), + timed_out: false + } + ); + assert_eq!(success, None); +} + +#[tokio::test] +async fn wait_times_out_when_status_is_not_final() { + let (mut session, turn) = make_session_and_context().await; + let manager = thread_manager(); + session.services.agent_control = manager.agent_control(); + let config = turn.config.as_ref().clone(); + let thread = manager.start_thread(config).await.expect("start thread"); + let agent_id = thread.thread_id; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "wait", + function_payload(json!({ + "ids": [agent_id.to_string()], + "timeout_ms": MIN_WAIT_TIMEOUT_MS + })), + ); + let output = MultiAgentHandler + .handle(invocation) + .await + .expect("wait should succeed"); + let (content, success) = expect_text_output(output); + let result: wait::WaitResult = + serde_json::from_str(&content).expect("wait result should be json"); + assert_eq!( + result, + wait::WaitResult { + status: HashMap::new(), + timed_out: true + } + ); + assert_eq!(success, None); + + let _ = thread + .thread + .submit(Op::Shutdown {}) + .await + .expect("shutdown should submit"); +} + +#[tokio::test] +async fn wait_clamps_short_timeouts_to_minimum() { + let (mut session, turn) = make_session_and_context().await; + let manager = thread_manager(); + session.services.agent_control = manager.agent_control(); + let config = turn.config.as_ref().clone(); + let thread = manager.start_thread(config).await.expect("start thread"); + let agent_id = thread.thread_id; + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "wait", + function_payload(json!({ + "ids": [agent_id.to_string()], + "timeout_ms": 10 + })), + ); + + let early = timeout( + Duration::from_millis(50), + MultiAgentHandler.handle(invocation), + ) + .await; + assert!( + early.is_err(), + "wait should not return before the minimum timeout clamp" + ); + + let _ = thread + .thread + .submit(Op::Shutdown {}) + .await + .expect("shutdown should submit"); +} + +#[tokio::test] +async fn wait_returns_final_status_without_timeout() { + let (mut session, turn) = make_session_and_context().await; + let manager = thread_manager(); + session.services.agent_control = manager.agent_control(); + let config = turn.config.as_ref().clone(); + let thread = manager.start_thread(config).await.expect("start thread"); + let agent_id = thread.thread_id; + let mut status_rx = manager + .agent_control() + .subscribe_status(agent_id) + .await + .expect("subscribe should succeed"); + + let _ = thread + .thread + .submit(Op::Shutdown {}) + .await + .expect("shutdown should submit"); + let _ = timeout(Duration::from_secs(1), status_rx.changed()) + .await + .expect("shutdown status should arrive"); + + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "wait", + function_payload(json!({ + "ids": [agent_id.to_string()], + "timeout_ms": 1000 + })), + ); + let output = MultiAgentHandler + .handle(invocation) + .await + .expect("wait should succeed"); + let (content, success) = expect_text_output(output); + let result: wait::WaitResult = + serde_json::from_str(&content).expect("wait result should be json"); + assert_eq!( + result, + wait::WaitResult { + status: HashMap::from([(agent_id, AgentStatus::Shutdown)]), + timed_out: false + } + ); + assert_eq!(success, None); +} + +#[tokio::test] +async fn close_agent_submits_shutdown_and_returns_status() { + let (mut session, turn) = make_session_and_context().await; + let manager = thread_manager(); + session.services.agent_control = manager.agent_control(); + let config = turn.config.as_ref().clone(); + let thread = manager.start_thread(config).await.expect("start thread"); + let agent_id = thread.thread_id; + let status_before = manager.agent_control().get_status(agent_id).await; + + let invocation = invocation( + Arc::new(session), + Arc::new(turn), + "close_agent", + function_payload(json!({"id": agent_id.to_string()})), + ); + let output = MultiAgentHandler + .handle(invocation) + .await + .expect("close_agent should succeed"); + let (content, success) = expect_text_output(output); + let result: close_agent::CloseAgentResult = + serde_json::from_str(&content).expect("close_agent result should be json"); + assert_eq!(result.status, status_before); + assert_eq!(success, Some(true)); + + let ops = manager.captured_ops(); + let submitted_shutdown = ops + .iter() + .any(|(id, op)| *id == agent_id && matches!(op, Op::Shutdown)); + assert_eq!(submitted_shutdown, true); + + let status_after = manager.agent_control().get_status(agent_id).await; + assert_eq!(status_after, AgentStatus::NotFound); +} + +#[tokio::test] +async fn build_agent_spawn_config_uses_turn_context_values() { + fn pick_allowed_sandbox_policy( + constraint: &crate::config::Constrained, + base: SandboxPolicy, + ) -> SandboxPolicy { + let candidates = [ + SandboxPolicy::new_read_only_policy(), + SandboxPolicy::new_workspace_write_policy(), + SandboxPolicy::DangerFullAccess, + ]; + candidates + .into_iter() + .find(|candidate| *candidate != base && constraint.can_set(candidate).is_ok()) + .unwrap_or(base) + } + + let (_session, mut turn) = make_session_and_context().await; + let base_instructions = BaseInstructions { + text: "base".to_string(), + }; + turn.developer_instructions = Some("dev".to_string()); + turn.compact_prompt = Some("compact".to_string()); + turn.shell_environment_policy = ShellEnvironmentPolicy { + use_profile: true, + ..ShellEnvironmentPolicy::default() + }; + let temp_dir = tempfile::tempdir().expect("temp dir"); + turn.cwd = temp_dir.path().to_path_buf(); + turn.codex_linux_sandbox_exe = Some(PathBuf::from("/bin/echo")); + let sandbox_policy = pick_allowed_sandbox_policy( + &turn.config.permissions.sandbox_policy, + turn.config.permissions.sandbox_policy.get().clone(), + ); + turn.sandbox_policy + .set(sandbox_policy) + .expect("sandbox policy set"); + turn.approval_policy + .set(AskForApproval::OnRequest) + .expect("approval policy set"); + + let config = build_agent_spawn_config(&base_instructions, &turn).expect("spawn config"); + let mut expected = (*turn.config).clone(); + expected.base_instructions = Some(base_instructions.text); + expected.model = Some(turn.model_info.slug.clone()); + expected.model_provider = turn.provider.clone(); + expected.model_reasoning_effort = turn.reasoning_effort; + expected.model_reasoning_summary = Some(turn.reasoning_summary); + expected.developer_instructions = turn.developer_instructions.clone(); + expected.compact_prompt = turn.compact_prompt.clone(); + expected.permissions.shell_environment_policy = turn.shell_environment_policy.clone(); + expected.codex_linux_sandbox_exe = turn.codex_linux_sandbox_exe.clone(); + expected.cwd = turn.cwd.clone(); + expected + .permissions + .approval_policy + .set(AskForApproval::OnRequest) + .expect("approval policy set"); + expected + .permissions + .sandbox_policy + .set(turn.sandbox_policy.get().clone()) + .expect("sandbox policy set"); + assert_eq!(config, expected); +} + +#[tokio::test] +async fn build_agent_spawn_config_preserves_base_user_instructions() { + let (_session, mut turn) = make_session_and_context().await; + let mut base_config = (*turn.config).clone(); + base_config.user_instructions = Some("base-user".to_string()); + turn.user_instructions = Some("resolved-user".to_string()); + turn.config = Arc::new(base_config.clone()); + let base_instructions = BaseInstructions { + text: "base".to_string(), + }; + + let config = build_agent_spawn_config(&base_instructions, &turn).expect("spawn config"); + + assert_eq!(config.user_instructions, base_config.user_instructions); +} + +#[tokio::test] +async fn build_agent_resume_config_clears_base_instructions() { + let (_session, mut turn) = make_session_and_context().await; + let mut base_config = (*turn.config).clone(); + base_config.base_instructions = Some("caller-base".to_string()); + turn.config = Arc::new(base_config); + turn.approval_policy + .set(AskForApproval::OnRequest) + .expect("approval policy set"); + + let config = build_agent_resume_config(&turn, 0).expect("resume config"); + + let mut expected = (*turn.config).clone(); + expected.base_instructions = None; + expected.model = Some(turn.model_info.slug.clone()); + expected.model_provider = turn.provider.clone(); + expected.model_reasoning_effort = turn.reasoning_effort; + expected.model_reasoning_summary = Some(turn.reasoning_summary); + expected.developer_instructions = turn.developer_instructions.clone(); + expected.compact_prompt = turn.compact_prompt.clone(); + expected.permissions.shell_environment_policy = turn.shell_environment_policy.clone(); + expected.codex_linux_sandbox_exe = turn.codex_linux_sandbox_exe.clone(); + expected.cwd = turn.cwd.clone(); + expected + .permissions + .approval_policy + .set(AskForApproval::OnRequest) + .expect("approval policy set"); + expected + .permissions + .sandbox_policy + .set(turn.sandbox_policy.get().clone()) + .expect("sandbox policy set"); + assert_eq!(config, expected); +} diff --git a/codex-rs/core/src/tools/handlers/read_file.rs b/codex-rs/core/src/tools/handlers/read_file.rs index e88bf9baa4..b868edf5b9 100644 --- a/codex-rs/core/src/tools/handlers/read_file.rs +++ b/codex-rs/core/src/tools/handlers/read_file.rs @@ -485,508 +485,5 @@ mod defaults { } #[cfg(test)] -mod tests { - use super::indentation::read_block; - use super::slice::read; - use super::*; - use pretty_assertions::assert_eq; - use tempfile::NamedTempFile; - - #[tokio::test] - async fn reads_requested_range() -> anyhow::Result<()> { - let mut temp = NamedTempFile::new()?; - use std::io::Write as _; - write!( - temp, - "alpha -beta -gamma -" - )?; - - let lines = read(temp.path(), 2, 2).await?; - assert_eq!(lines, vec!["L2: beta".to_string(), "L3: gamma".to_string()]); - Ok(()) - } - - #[tokio::test] - async fn errors_when_offset_exceeds_length() -> anyhow::Result<()> { - let mut temp = NamedTempFile::new()?; - use std::io::Write as _; - writeln!(temp, "only")?; - - let err = read(temp.path(), 3, 1) - .await - .expect_err("offset exceeds length"); - assert_eq!( - err, - FunctionCallError::RespondToModel("offset exceeds file length".to_string()) - ); - Ok(()) - } - - #[tokio::test] - async fn reads_non_utf8_lines() -> anyhow::Result<()> { - let mut temp = NamedTempFile::new()?; - use std::io::Write as _; - temp.as_file_mut().write_all(b"\xff\xfe\nplain\n")?; - - let lines = read(temp.path(), 1, 2).await?; - let expected_first = format!("L1: {}{}", '\u{FFFD}', '\u{FFFD}'); - assert_eq!(lines, vec![expected_first, "L2: plain".to_string()]); - Ok(()) - } - - #[tokio::test] - async fn trims_crlf_endings() -> anyhow::Result<()> { - let mut temp = NamedTempFile::new()?; - use std::io::Write as _; - write!(temp, "one\r\ntwo\r\n")?; - - let lines = read(temp.path(), 1, 2).await?; - assert_eq!(lines, vec!["L1: one".to_string(), "L2: two".to_string()]); - Ok(()) - } - - #[tokio::test] - async fn respects_limit_even_with_more_lines() -> anyhow::Result<()> { - let mut temp = NamedTempFile::new()?; - use std::io::Write as _; - write!( - temp, - "first -second -third -" - )?; - - let lines = read(temp.path(), 1, 2).await?; - assert_eq!( - lines, - vec!["L1: first".to_string(), "L2: second".to_string()] - ); - Ok(()) - } - - #[tokio::test] - async fn truncates_lines_longer_than_max_length() -> anyhow::Result<()> { - let mut temp = NamedTempFile::new()?; - use std::io::Write as _; - let long_line = "x".repeat(MAX_LINE_LENGTH + 50); - writeln!(temp, "{long_line}")?; - - let lines = read(temp.path(), 1, 1).await?; - let expected = "x".repeat(MAX_LINE_LENGTH); - assert_eq!(lines, vec![format!("L1: {expected}")]); - Ok(()) - } - - #[tokio::test] - async fn indentation_mode_captures_block() -> anyhow::Result<()> { - let mut temp = NamedTempFile::new()?; - use std::io::Write as _; - write!( - temp, - "fn outer() {{ - if cond {{ - inner(); - }} - tail(); -}} -" - )?; - - let options = IndentationArgs { - anchor_line: Some(3), - include_siblings: false, - max_levels: 1, - ..Default::default() - }; - - let lines = read_block(temp.path(), 3, 10, options).await?; - - assert_eq!( - lines, - vec![ - "L2: if cond {".to_string(), - "L3: inner();".to_string(), - "L4: }".to_string() - ] - ); - Ok(()) - } - - #[tokio::test] - async fn indentation_mode_expands_parents() -> anyhow::Result<()> { - let mut temp = NamedTempFile::new()?; - use std::io::Write as _; - write!( - temp, - "mod root {{ - fn outer() {{ - if cond {{ - inner(); - }} - }} -}} -" - )?; - - let mut options = IndentationArgs { - anchor_line: Some(4), - max_levels: 2, - ..Default::default() - }; - - let lines = read_block(temp.path(), 4, 50, options.clone()).await?; - assert_eq!( - lines, - vec![ - "L2: fn outer() {".to_string(), - "L3: if cond {".to_string(), - "L4: inner();".to_string(), - "L5: }".to_string(), - "L6: }".to_string(), - ] - ); - - options.max_levels = 3; - let expanded = read_block(temp.path(), 4, 50, options).await?; - assert_eq!( - expanded, - vec![ - "L1: mod root {".to_string(), - "L2: fn outer() {".to_string(), - "L3: if cond {".to_string(), - "L4: inner();".to_string(), - "L5: }".to_string(), - "L6: }".to_string(), - "L7: }".to_string(), - ] - ); - Ok(()) - } - - #[tokio::test] - async fn indentation_mode_respects_sibling_flag() -> anyhow::Result<()> { - let mut temp = NamedTempFile::new()?; - use std::io::Write as _; - write!( - temp, - "fn wrapper() {{ - if first {{ - do_first(); - }} - if second {{ - do_second(); - }} -}} -" - )?; - - let mut options = IndentationArgs { - anchor_line: Some(3), - include_siblings: false, - max_levels: 1, - ..Default::default() - }; - - let lines = read_block(temp.path(), 3, 50, options.clone()).await?; - assert_eq!( - lines, - vec![ - "L2: if first {".to_string(), - "L3: do_first();".to_string(), - "L4: }".to_string(), - ] - ); - - options.include_siblings = true; - let with_siblings = read_block(temp.path(), 3, 50, options).await?; - assert_eq!( - with_siblings, - vec![ - "L2: if first {".to_string(), - "L3: do_first();".to_string(), - "L4: }".to_string(), - "L5: if second {".to_string(), - "L6: do_second();".to_string(), - "L7: }".to_string(), - ] - ); - Ok(()) - } - - #[tokio::test] - async fn indentation_mode_handles_python_sample() -> anyhow::Result<()> { - let mut temp = NamedTempFile::new()?; - use std::io::Write as _; - write!( - temp, - "class Foo: - def __init__(self, size): - self.size = size - def double(self, value): - if value is None: - return 0 - result = value * self.size - return result -class Bar: - def compute(self): - helper = Foo(2) - return helper.double(5) -" - )?; - - let options = IndentationArgs { - anchor_line: Some(7), - include_siblings: true, - max_levels: 1, - ..Default::default() - }; - - let lines = read_block(temp.path(), 1, 200, options).await?; - assert_eq!( - lines, - vec![ - "L2: def __init__(self, size):".to_string(), - "L3: self.size = size".to_string(), - "L4: def double(self, value):".to_string(), - "L5: if value is None:".to_string(), - "L6: return 0".to_string(), - "L7: result = value * self.size".to_string(), - "L8: return result".to_string(), - ] - ); - Ok(()) - } - - #[tokio::test] - #[ignore] - async fn indentation_mode_handles_javascript_sample() -> anyhow::Result<()> { - let mut temp = NamedTempFile::new()?; - use std::io::Write as _; - write!( - temp, - "export function makeThing() {{ - const cache = new Map(); - function ensure(key) {{ - if (!cache.has(key)) {{ - cache.set(key, []); - }} - return cache.get(key); - }} - const handlers = {{ - init() {{ - console.log(\"init\"); - }}, - run() {{ - if (Math.random() > 0.5) {{ - return \"heads\"; - }} - return \"tails\"; - }}, - }}; - return {{ cache, handlers }}; -}} -export function other() {{ - return makeThing(); -}} -" - )?; - - let options = IndentationArgs { - anchor_line: Some(15), - max_levels: 1, - ..Default::default() - }; - - let lines = read_block(temp.path(), 15, 200, options).await?; - assert_eq!( - lines, - vec![ - "L10: init() {".to_string(), - "L11: console.log(\"init\");".to_string(), - "L12: },".to_string(), - "L13: run() {".to_string(), - "L14: if (Math.random() > 0.5) {".to_string(), - "L15: return \"heads\";".to_string(), - "L16: }".to_string(), - "L17: return \"tails\";".to_string(), - "L18: },".to_string(), - ] - ); - Ok(()) - } - - fn write_cpp_sample() -> anyhow::Result { - let mut temp = NamedTempFile::new()?; - use std::io::Write as _; - write!( - temp, - "#include -#include - -namespace sample {{ -class Runner {{ -public: - void setup() {{ - if (enabled_) {{ - init(); - }} - }} - - // Run the code - int run() const {{ - switch (mode_) {{ - case Mode::Fast: - return fast(); - case Mode::Slow: - return slow(); - default: - return fallback(); - }} - }} - -private: - bool enabled_ = false; - Mode mode_ = Mode::Fast; - - int fast() const {{ - return 1; - }} -}}; -}} // namespace sample -" - )?; - Ok(temp) - } - - #[tokio::test] - async fn indentation_mode_handles_cpp_sample_shallow() -> anyhow::Result<()> { - let temp = write_cpp_sample()?; - - let options = IndentationArgs { - include_siblings: false, - anchor_line: Some(18), - max_levels: 1, - ..Default::default() - }; - - let lines = read_block(temp.path(), 18, 200, options).await?; - assert_eq!( - lines, - vec![ - "L15: switch (mode_) {".to_string(), - "L16: case Mode::Fast:".to_string(), - "L17: return fast();".to_string(), - "L18: case Mode::Slow:".to_string(), - "L19: return slow();".to_string(), - "L20: default:".to_string(), - "L21: return fallback();".to_string(), - "L22: }".to_string(), - ] - ); - Ok(()) - } - - #[tokio::test] - async fn indentation_mode_handles_cpp_sample() -> anyhow::Result<()> { - let temp = write_cpp_sample()?; - - let options = IndentationArgs { - include_siblings: false, - anchor_line: Some(18), - max_levels: 2, - ..Default::default() - }; - - let lines = read_block(temp.path(), 18, 200, options).await?; - assert_eq!( - lines, - vec![ - "L13: // Run the code".to_string(), - "L14: int run() const {".to_string(), - "L15: switch (mode_) {".to_string(), - "L16: case Mode::Fast:".to_string(), - "L17: return fast();".to_string(), - "L18: case Mode::Slow:".to_string(), - "L19: return slow();".to_string(), - "L20: default:".to_string(), - "L21: return fallback();".to_string(), - "L22: }".to_string(), - "L23: }".to_string(), - ] - ); - Ok(()) - } - - #[tokio::test] - async fn indentation_mode_handles_cpp_sample_no_headers() -> anyhow::Result<()> { - let temp = write_cpp_sample()?; - - let options = IndentationArgs { - include_siblings: false, - include_header: false, - anchor_line: Some(18), - max_levels: 2, - ..Default::default() - }; - - let lines = read_block(temp.path(), 18, 200, options).await?; - assert_eq!( - lines, - vec![ - "L14: int run() const {".to_string(), - "L15: switch (mode_) {".to_string(), - "L16: case Mode::Fast:".to_string(), - "L17: return fast();".to_string(), - "L18: case Mode::Slow:".to_string(), - "L19: return slow();".to_string(), - "L20: default:".to_string(), - "L21: return fallback();".to_string(), - "L22: }".to_string(), - "L23: }".to_string(), - ] - ); - Ok(()) - } - - #[tokio::test] - async fn indentation_mode_handles_cpp_sample_siblings() -> anyhow::Result<()> { - let temp = write_cpp_sample()?; - - let options = IndentationArgs { - include_siblings: true, - include_header: false, - anchor_line: Some(18), - max_levels: 2, - ..Default::default() - }; - - let lines = read_block(temp.path(), 18, 200, options).await?; - assert_eq!( - lines, - vec![ - "L7: void setup() {".to_string(), - "L8: if (enabled_) {".to_string(), - "L9: init();".to_string(), - "L10: }".to_string(), - "L11: }".to_string(), - "L12: ".to_string(), - "L13: // Run the code".to_string(), - "L14: int run() const {".to_string(), - "L15: switch (mode_) {".to_string(), - "L16: case Mode::Fast:".to_string(), - "L17: return fast();".to_string(), - "L18: case Mode::Slow:".to_string(), - "L19: return slow();".to_string(), - "L20: default:".to_string(), - "L21: return fallback();".to_string(), - "L22: }".to_string(), - "L23: }".to_string(), - ] - ); - Ok(()) - } -} +#[path = "read_file_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/handlers/read_file_tests.rs b/codex-rs/core/src/tools/handlers/read_file_tests.rs new file mode 100644 index 0000000000..3921a98826 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/read_file_tests.rs @@ -0,0 +1,503 @@ +use super::indentation::read_block; +use super::slice::read; +use super::*; +use pretty_assertions::assert_eq; +use tempfile::NamedTempFile; + +#[tokio::test] +async fn reads_requested_range() -> anyhow::Result<()> { + let mut temp = NamedTempFile::new()?; + use std::io::Write as _; + write!( + temp, + "alpha +beta +gamma +" + )?; + + let lines = read(temp.path(), 2, 2).await?; + assert_eq!(lines, vec!["L2: beta".to_string(), "L3: gamma".to_string()]); + Ok(()) +} + +#[tokio::test] +async fn errors_when_offset_exceeds_length() -> anyhow::Result<()> { + let mut temp = NamedTempFile::new()?; + use std::io::Write as _; + writeln!(temp, "only")?; + + let err = read(temp.path(), 3, 1) + .await + .expect_err("offset exceeds length"); + assert_eq!( + err, + FunctionCallError::RespondToModel("offset exceeds file length".to_string()) + ); + Ok(()) +} + +#[tokio::test] +async fn reads_non_utf8_lines() -> anyhow::Result<()> { + let mut temp = NamedTempFile::new()?; + use std::io::Write as _; + temp.as_file_mut().write_all(b"\xff\xfe\nplain\n")?; + + let lines = read(temp.path(), 1, 2).await?; + let expected_first = format!("L1: {}{}", '\u{FFFD}', '\u{FFFD}'); + assert_eq!(lines, vec![expected_first, "L2: plain".to_string()]); + Ok(()) +} + +#[tokio::test] +async fn trims_crlf_endings() -> anyhow::Result<()> { + let mut temp = NamedTempFile::new()?; + use std::io::Write as _; + write!(temp, "one\r\ntwo\r\n")?; + + let lines = read(temp.path(), 1, 2).await?; + assert_eq!(lines, vec!["L1: one".to_string(), "L2: two".to_string()]); + Ok(()) +} + +#[tokio::test] +async fn respects_limit_even_with_more_lines() -> anyhow::Result<()> { + let mut temp = NamedTempFile::new()?; + use std::io::Write as _; + write!( + temp, + "first +second +third +" + )?; + + let lines = read(temp.path(), 1, 2).await?; + assert_eq!( + lines, + vec!["L1: first".to_string(), "L2: second".to_string()] + ); + Ok(()) +} + +#[tokio::test] +async fn truncates_lines_longer_than_max_length() -> anyhow::Result<()> { + let mut temp = NamedTempFile::new()?; + use std::io::Write as _; + let long_line = "x".repeat(MAX_LINE_LENGTH + 50); + writeln!(temp, "{long_line}")?; + + let lines = read(temp.path(), 1, 1).await?; + let expected = "x".repeat(MAX_LINE_LENGTH); + assert_eq!(lines, vec![format!("L1: {expected}")]); + Ok(()) +} + +#[tokio::test] +async fn indentation_mode_captures_block() -> anyhow::Result<()> { + let mut temp = NamedTempFile::new()?; + use std::io::Write as _; + write!( + temp, + "fn outer() {{ + if cond {{ + inner(); + }} + tail(); +}} +" + )?; + + let options = IndentationArgs { + anchor_line: Some(3), + include_siblings: false, + max_levels: 1, + ..Default::default() + }; + + let lines = read_block(temp.path(), 3, 10, options).await?; + + assert_eq!( + lines, + vec![ + "L2: if cond {".to_string(), + "L3: inner();".to_string(), + "L4: }".to_string() + ] + ); + Ok(()) +} + +#[tokio::test] +async fn indentation_mode_expands_parents() -> anyhow::Result<()> { + let mut temp = NamedTempFile::new()?; + use std::io::Write as _; + write!( + temp, + "mod root {{ + fn outer() {{ + if cond {{ + inner(); + }} + }} +}} +" + )?; + + let mut options = IndentationArgs { + anchor_line: Some(4), + max_levels: 2, + ..Default::default() + }; + + let lines = read_block(temp.path(), 4, 50, options.clone()).await?; + assert_eq!( + lines, + vec![ + "L2: fn outer() {".to_string(), + "L3: if cond {".to_string(), + "L4: inner();".to_string(), + "L5: }".to_string(), + "L6: }".to_string(), + ] + ); + + options.max_levels = 3; + let expanded = read_block(temp.path(), 4, 50, options).await?; + assert_eq!( + expanded, + vec![ + "L1: mod root {".to_string(), + "L2: fn outer() {".to_string(), + "L3: if cond {".to_string(), + "L4: inner();".to_string(), + "L5: }".to_string(), + "L6: }".to_string(), + "L7: }".to_string(), + ] + ); + Ok(()) +} + +#[tokio::test] +async fn indentation_mode_respects_sibling_flag() -> anyhow::Result<()> { + let mut temp = NamedTempFile::new()?; + use std::io::Write as _; + write!( + temp, + "fn wrapper() {{ + if first {{ + do_first(); + }} + if second {{ + do_second(); + }} +}} +" + )?; + + let mut options = IndentationArgs { + anchor_line: Some(3), + include_siblings: false, + max_levels: 1, + ..Default::default() + }; + + let lines = read_block(temp.path(), 3, 50, options.clone()).await?; + assert_eq!( + lines, + vec![ + "L2: if first {".to_string(), + "L3: do_first();".to_string(), + "L4: }".to_string(), + ] + ); + + options.include_siblings = true; + let with_siblings = read_block(temp.path(), 3, 50, options).await?; + assert_eq!( + with_siblings, + vec![ + "L2: if first {".to_string(), + "L3: do_first();".to_string(), + "L4: }".to_string(), + "L5: if second {".to_string(), + "L6: do_second();".to_string(), + "L7: }".to_string(), + ] + ); + Ok(()) +} + +#[tokio::test] +async fn indentation_mode_handles_python_sample() -> anyhow::Result<()> { + let mut temp = NamedTempFile::new()?; + use std::io::Write as _; + write!( + temp, + "class Foo: + def __init__(self, size): + self.size = size + def double(self, value): + if value is None: + return 0 + result = value * self.size + return result +class Bar: + def compute(self): + helper = Foo(2) + return helper.double(5) +" + )?; + + let options = IndentationArgs { + anchor_line: Some(7), + include_siblings: true, + max_levels: 1, + ..Default::default() + }; + + let lines = read_block(temp.path(), 1, 200, options).await?; + assert_eq!( + lines, + vec![ + "L2: def __init__(self, size):".to_string(), + "L3: self.size = size".to_string(), + "L4: def double(self, value):".to_string(), + "L5: if value is None:".to_string(), + "L6: return 0".to_string(), + "L7: result = value * self.size".to_string(), + "L8: return result".to_string(), + ] + ); + Ok(()) +} + +#[tokio::test] +#[ignore] +async fn indentation_mode_handles_javascript_sample() -> anyhow::Result<()> { + let mut temp = NamedTempFile::new()?; + use std::io::Write as _; + write!( + temp, + "export function makeThing() {{ + const cache = new Map(); + function ensure(key) {{ + if (!cache.has(key)) {{ + cache.set(key, []); + }} + return cache.get(key); + }} + const handlers = {{ + init() {{ + console.log(\"init\"); + }}, + run() {{ + if (Math.random() > 0.5) {{ + return \"heads\"; + }} + return \"tails\"; + }}, + }}; + return {{ cache, handlers }}; +}} +export function other() {{ + return makeThing(); +}} +" + )?; + + let options = IndentationArgs { + anchor_line: Some(15), + max_levels: 1, + ..Default::default() + }; + + let lines = read_block(temp.path(), 15, 200, options).await?; + assert_eq!( + lines, + vec![ + "L10: init() {".to_string(), + "L11: console.log(\"init\");".to_string(), + "L12: },".to_string(), + "L13: run() {".to_string(), + "L14: if (Math.random() > 0.5) {".to_string(), + "L15: return \"heads\";".to_string(), + "L16: }".to_string(), + "L17: return \"tails\";".to_string(), + "L18: },".to_string(), + ] + ); + Ok(()) +} + +fn write_cpp_sample() -> anyhow::Result { + let mut temp = NamedTempFile::new()?; + use std::io::Write as _; + write!( + temp, + "#include +#include + +namespace sample {{ +class Runner {{ +public: + void setup() {{ + if (enabled_) {{ + init(); + }} + }} + + // Run the code + int run() const {{ + switch (mode_) {{ + case Mode::Fast: + return fast(); + case Mode::Slow: + return slow(); + default: + return fallback(); + }} + }} + +private: + bool enabled_ = false; + Mode mode_ = Mode::Fast; + + int fast() const {{ + return 1; + }} +}}; +}} // namespace sample +" + )?; + Ok(temp) +} + +#[tokio::test] +async fn indentation_mode_handles_cpp_sample_shallow() -> anyhow::Result<()> { + let temp = write_cpp_sample()?; + + let options = IndentationArgs { + include_siblings: false, + anchor_line: Some(18), + max_levels: 1, + ..Default::default() + }; + + let lines = read_block(temp.path(), 18, 200, options).await?; + assert_eq!( + lines, + vec![ + "L15: switch (mode_) {".to_string(), + "L16: case Mode::Fast:".to_string(), + "L17: return fast();".to_string(), + "L18: case Mode::Slow:".to_string(), + "L19: return slow();".to_string(), + "L20: default:".to_string(), + "L21: return fallback();".to_string(), + "L22: }".to_string(), + ] + ); + Ok(()) +} + +#[tokio::test] +async fn indentation_mode_handles_cpp_sample() -> anyhow::Result<()> { + let temp = write_cpp_sample()?; + + let options = IndentationArgs { + include_siblings: false, + anchor_line: Some(18), + max_levels: 2, + ..Default::default() + }; + + let lines = read_block(temp.path(), 18, 200, options).await?; + assert_eq!( + lines, + vec![ + "L13: // Run the code".to_string(), + "L14: int run() const {".to_string(), + "L15: switch (mode_) {".to_string(), + "L16: case Mode::Fast:".to_string(), + "L17: return fast();".to_string(), + "L18: case Mode::Slow:".to_string(), + "L19: return slow();".to_string(), + "L20: default:".to_string(), + "L21: return fallback();".to_string(), + "L22: }".to_string(), + "L23: }".to_string(), + ] + ); + Ok(()) +} + +#[tokio::test] +async fn indentation_mode_handles_cpp_sample_no_headers() -> anyhow::Result<()> { + let temp = write_cpp_sample()?; + + let options = IndentationArgs { + include_siblings: false, + include_header: false, + anchor_line: Some(18), + max_levels: 2, + ..Default::default() + }; + + let lines = read_block(temp.path(), 18, 200, options).await?; + assert_eq!( + lines, + vec![ + "L14: int run() const {".to_string(), + "L15: switch (mode_) {".to_string(), + "L16: case Mode::Fast:".to_string(), + "L17: return fast();".to_string(), + "L18: case Mode::Slow:".to_string(), + "L19: return slow();".to_string(), + "L20: default:".to_string(), + "L21: return fallback();".to_string(), + "L22: }".to_string(), + "L23: }".to_string(), + ] + ); + Ok(()) +} + +#[tokio::test] +async fn indentation_mode_handles_cpp_sample_siblings() -> anyhow::Result<()> { + let temp = write_cpp_sample()?; + + let options = IndentationArgs { + include_siblings: true, + include_header: false, + anchor_line: Some(18), + max_levels: 2, + ..Default::default() + }; + + let lines = read_block(temp.path(), 18, 200, options).await?; + assert_eq!( + lines, + vec![ + "L7: void setup() {".to_string(), + "L8: if (enabled_) {".to_string(), + "L9: init();".to_string(), + "L10: }".to_string(), + "L11: }".to_string(), + "L12: ".to_string(), + "L13: // Run the code".to_string(), + "L14: int run() const {".to_string(), + "L15: switch (mode_) {".to_string(), + "L16: case Mode::Fast:".to_string(), + "L17: return fast();".to_string(), + "L18: case Mode::Slow:".to_string(), + "L19: return slow();".to_string(), + "L20: default:".to_string(), + "L21: return fallback();".to_string(), + "L22: }".to_string(), + "L23: }".to_string(), + ] + ); + Ok(()) +} diff --git a/codex-rs/core/src/tools/handlers/request_user_input.rs b/codex-rs/core/src/tools/handlers/request_user_input.rs index 77def0b682..4d95a2c20d 100644 --- a/codex-rs/core/src/tools/handlers/request_user_input.rs +++ b/codex-rs/core/src/tools/handlers/request_user_input.rs @@ -121,51 +121,5 @@ impl ToolHandler for RequestUserInputHandler { } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - - #[test] - fn request_user_input_mode_availability_defaults_to_plan_only() { - assert!(ModeKind::Plan.allows_request_user_input()); - assert!(!ModeKind::Default.allows_request_user_input()); - assert!(!ModeKind::Execute.allows_request_user_input()); - assert!(!ModeKind::PairProgramming.allows_request_user_input()); - } - - #[test] - fn request_user_input_unavailable_messages_respect_default_mode_feature_flag() { - assert_eq!( - request_user_input_unavailable_message(ModeKind::Plan, false), - None - ); - assert_eq!( - request_user_input_unavailable_message(ModeKind::Default, false), - Some("request_user_input is unavailable in Default mode".to_string()) - ); - assert_eq!( - request_user_input_unavailable_message(ModeKind::Default, true), - None - ); - assert_eq!( - request_user_input_unavailable_message(ModeKind::Execute, false), - Some("request_user_input is unavailable in Execute mode".to_string()) - ); - assert_eq!( - request_user_input_unavailable_message(ModeKind::PairProgramming, false), - Some("request_user_input is unavailable in Pair Programming mode".to_string()) - ); - } - - #[test] - fn request_user_input_tool_description_mentions_available_modes() { - assert_eq!( - request_user_input_tool_description(false), - "Request user input for one to three short questions and wait for the response. This tool is only available in Plan mode.".to_string() - ); - assert_eq!( - request_user_input_tool_description(true), - "Request user input for one to three short questions and wait for the response. This tool is only available in Default or Plan mode.".to_string() - ); - } -} +#[path = "request_user_input_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/handlers/request_user_input_tests.rs b/codex-rs/core/src/tools/handlers/request_user_input_tests.rs new file mode 100644 index 0000000000..f4df3c43c0 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/request_user_input_tests.rs @@ -0,0 +1,46 @@ +use super::*; +use pretty_assertions::assert_eq; + +#[test] +fn request_user_input_mode_availability_defaults_to_plan_only() { + assert!(ModeKind::Plan.allows_request_user_input()); + assert!(!ModeKind::Default.allows_request_user_input()); + assert!(!ModeKind::Execute.allows_request_user_input()); + assert!(!ModeKind::PairProgramming.allows_request_user_input()); +} + +#[test] +fn request_user_input_unavailable_messages_respect_default_mode_feature_flag() { + assert_eq!( + request_user_input_unavailable_message(ModeKind::Plan, false), + None + ); + assert_eq!( + request_user_input_unavailable_message(ModeKind::Default, false), + Some("request_user_input is unavailable in Default mode".to_string()) + ); + assert_eq!( + request_user_input_unavailable_message(ModeKind::Default, true), + None + ); + assert_eq!( + request_user_input_unavailable_message(ModeKind::Execute, false), + Some("request_user_input is unavailable in Execute mode".to_string()) + ); + assert_eq!( + request_user_input_unavailable_message(ModeKind::PairProgramming, false), + Some("request_user_input is unavailable in Pair Programming mode".to_string()) + ); +} + +#[test] +fn request_user_input_tool_description_mentions_available_modes() { + assert_eq!( + request_user_input_tool_description(false), + "Request user input for one to three short questions and wait for the response. This tool is only available in Plan mode.".to_string() + ); + assert_eq!( + request_user_input_tool_description(true), + "Request user input for one to three short questions and wait for the response. This tool is only available in Default or Plan mode.".to_string() + ); +} diff --git a/codex-rs/core/src/tools/handlers/shell.rs b/codex-rs/core/src/tools/handlers/shell.rs index a9c3aee346..01d7f1b6e9 100644 --- a/codex-rs/core/src/tools/handlers/shell.rs +++ b/codex-rs/core/src/tools/handlers/shell.rs @@ -461,185 +461,5 @@ impl ShellHandler { } #[cfg(test)] -mod tests { - use std::path::PathBuf; - use std::sync::Arc; - - use codex_protocol::models::ShellCommandToolCallParams; - use pretty_assertions::assert_eq; - - use crate::codex::make_session_and_context; - use crate::exec_env::create_env; - use crate::is_safe_command::is_known_safe_command; - use crate::powershell::try_find_powershell_executable_blocking; - use crate::powershell::try_find_pwsh_executable_blocking; - use crate::sandboxing::SandboxPermissions; - use crate::shell::Shell; - use crate::shell::ShellType; - use crate::shell_snapshot::ShellSnapshot; - use crate::tools::handlers::ShellCommandHandler; - use tokio::sync::watch; - - /// The logic for is_known_safe_command() has heuristics for known shells, - /// so we must ensure the commands generated by [ShellCommandHandler] can be - /// recognized as safe if the `command` is safe. - #[test] - fn commands_generated_by_shell_command_handler_can_be_matched_by_is_known_safe_command() { - let bash_shell = Shell { - shell_type: ShellType::Bash, - shell_path: PathBuf::from("/bin/bash"), - shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), - }; - assert_safe(&bash_shell, "ls -la"); - - let zsh_shell = Shell { - shell_type: ShellType::Zsh, - shell_path: PathBuf::from("/bin/zsh"), - shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), - }; - assert_safe(&zsh_shell, "ls -la"); - - if let Some(path) = try_find_powershell_executable_blocking() { - let powershell = Shell { - shell_type: ShellType::PowerShell, - shell_path: path.to_path_buf(), - shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), - }; - assert_safe(&powershell, "ls -Name"); - } - - if let Some(path) = try_find_pwsh_executable_blocking() { - let pwsh = Shell { - shell_type: ShellType::PowerShell, - shell_path: path.to_path_buf(), - shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), - }; - assert_safe(&pwsh, "ls -Name"); - } - } - - fn assert_safe(shell: &Shell, command: &str) { - assert!(is_known_safe_command( - &shell.derive_exec_args(command, /* use_login_shell */ true) - )); - assert!(is_known_safe_command( - &shell.derive_exec_args(command, /* use_login_shell */ false) - )); - } - - #[tokio::test] - async fn shell_command_handler_to_exec_params_uses_session_shell_and_turn_context() { - let (session, turn_context) = make_session_and_context().await; - - let command = "echo hello".to_string(); - let workdir = Some("subdir".to_string()); - let login = None; - let timeout_ms = Some(1234); - let sandbox_permissions = SandboxPermissions::RequireEscalated; - let justification = Some("because tests".to_string()); - - let expected_command = session.user_shell().derive_exec_args(&command, true); - let expected_cwd = turn_context.resolve_path(workdir.clone()); - let expected_env = create_env( - &turn_context.shell_environment_policy, - Some(session.conversation_id), - ); - - let params = ShellCommandToolCallParams { - command, - workdir, - login, - timeout_ms, - sandbox_permissions: Some(sandbox_permissions), - additional_permissions: None, - prefix_rule: None, - justification: justification.clone(), - }; - - let exec_params = ShellCommandHandler::to_exec_params( - ¶ms, - &session, - &turn_context, - session.conversation_id, - true, - ) - .expect("login shells should be allowed"); - - // ExecParams cannot derive Eq due to the CancellationToken field, so we manually compare the fields. - assert_eq!(exec_params.command, expected_command); - assert_eq!(exec_params.cwd, expected_cwd); - assert_eq!(exec_params.env, expected_env); - assert_eq!(exec_params.network, turn_context.network); - assert_eq!(exec_params.expiration.timeout_ms(), timeout_ms); - assert_eq!(exec_params.sandbox_permissions, sandbox_permissions); - assert_eq!(exec_params.justification, justification); - assert_eq!(exec_params.arg0, None); - } - - #[test] - fn shell_command_handler_respects_explicit_login_flag() { - let (_tx, shell_snapshot) = watch::channel(Some(Arc::new(ShellSnapshot { - path: PathBuf::from("/tmp/snapshot.sh"), - cwd: PathBuf::from("/tmp"), - }))); - let shell = Shell { - shell_type: ShellType::Bash, - shell_path: PathBuf::from("/bin/bash"), - shell_snapshot, - }; - - let login_command = ShellCommandHandler::base_command(&shell, "echo login shell", true); - assert_eq!( - login_command, - shell.derive_exec_args("echo login shell", true) - ); - - let non_login_command = - ShellCommandHandler::base_command(&shell, "echo non login shell", false); - assert_eq!( - non_login_command, - shell.derive_exec_args("echo non login shell", false) - ); - } - - #[tokio::test] - async fn shell_command_handler_defaults_to_non_login_when_disallowed() { - let (session, turn_context) = make_session_and_context().await; - let params = ShellCommandToolCallParams { - command: "echo hello".to_string(), - workdir: None, - login: None, - timeout_ms: None, - sandbox_permissions: None, - additional_permissions: None, - prefix_rule: None, - justification: None, - }; - - let exec_params = ShellCommandHandler::to_exec_params( - ¶ms, - &session, - &turn_context, - session.conversation_id, - false, - ) - .expect("non-login shells should still be allowed"); - - assert_eq!( - exec_params.command, - session.user_shell().derive_exec_args("echo hello", false) - ); - } - - #[test] - fn shell_command_handler_rejects_login_when_disallowed() { - let err = ShellCommandHandler::resolve_use_login_shell(Some(true), false) - .expect_err("explicit login should be rejected"); - - assert!( - err.to_string() - .contains("login shell is disabled by config"), - "unexpected error: {err}" - ); - } -} +#[path = "shell_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/handlers/shell_tests.rs b/codex-rs/core/src/tools/handlers/shell_tests.rs new file mode 100644 index 0000000000..b69f3be230 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/shell_tests.rs @@ -0,0 +1,180 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use codex_protocol::models::ShellCommandToolCallParams; +use pretty_assertions::assert_eq; + +use crate::codex::make_session_and_context; +use crate::exec_env::create_env; +use crate::is_safe_command::is_known_safe_command; +use crate::powershell::try_find_powershell_executable_blocking; +use crate::powershell::try_find_pwsh_executable_blocking; +use crate::sandboxing::SandboxPermissions; +use crate::shell::Shell; +use crate::shell::ShellType; +use crate::shell_snapshot::ShellSnapshot; +use crate::tools::handlers::ShellCommandHandler; +use tokio::sync::watch; + +/// The logic for is_known_safe_command() has heuristics for known shells, +/// so we must ensure the commands generated by [ShellCommandHandler] can be +/// recognized as safe if the `command` is safe. +#[test] +fn commands_generated_by_shell_command_handler_can_be_matched_by_is_known_safe_command() { + let bash_shell = Shell { + shell_type: ShellType::Bash, + shell_path: PathBuf::from("/bin/bash"), + shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), + }; + assert_safe(&bash_shell, "ls -la"); + + let zsh_shell = Shell { + shell_type: ShellType::Zsh, + shell_path: PathBuf::from("/bin/zsh"), + shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), + }; + assert_safe(&zsh_shell, "ls -la"); + + if let Some(path) = try_find_powershell_executable_blocking() { + let powershell = Shell { + shell_type: ShellType::PowerShell, + shell_path: path.to_path_buf(), + shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), + }; + assert_safe(&powershell, "ls -Name"); + } + + if let Some(path) = try_find_pwsh_executable_blocking() { + let pwsh = Shell { + shell_type: ShellType::PowerShell, + shell_path: path.to_path_buf(), + shell_snapshot: crate::shell::empty_shell_snapshot_receiver(), + }; + assert_safe(&pwsh, "ls -Name"); + } +} + +fn assert_safe(shell: &Shell, command: &str) { + assert!(is_known_safe_command( + &shell.derive_exec_args(command, /* use_login_shell */ true) + )); + assert!(is_known_safe_command( + &shell.derive_exec_args(command, /* use_login_shell */ false) + )); +} + +#[tokio::test] +async fn shell_command_handler_to_exec_params_uses_session_shell_and_turn_context() { + let (session, turn_context) = make_session_and_context().await; + + let command = "echo hello".to_string(); + let workdir = Some("subdir".to_string()); + let login = None; + let timeout_ms = Some(1234); + let sandbox_permissions = SandboxPermissions::RequireEscalated; + let justification = Some("because tests".to_string()); + + let expected_command = session.user_shell().derive_exec_args(&command, true); + let expected_cwd = turn_context.resolve_path(workdir.clone()); + let expected_env = create_env( + &turn_context.shell_environment_policy, + Some(session.conversation_id), + ); + + let params = ShellCommandToolCallParams { + command, + workdir, + login, + timeout_ms, + sandbox_permissions: Some(sandbox_permissions), + additional_permissions: None, + prefix_rule: None, + justification: justification.clone(), + }; + + let exec_params = ShellCommandHandler::to_exec_params( + ¶ms, + &session, + &turn_context, + session.conversation_id, + true, + ) + .expect("login shells should be allowed"); + + // ExecParams cannot derive Eq due to the CancellationToken field, so we manually compare the fields. + assert_eq!(exec_params.command, expected_command); + assert_eq!(exec_params.cwd, expected_cwd); + assert_eq!(exec_params.env, expected_env); + assert_eq!(exec_params.network, turn_context.network); + assert_eq!(exec_params.expiration.timeout_ms(), timeout_ms); + assert_eq!(exec_params.sandbox_permissions, sandbox_permissions); + assert_eq!(exec_params.justification, justification); + assert_eq!(exec_params.arg0, None); +} + +#[test] +fn shell_command_handler_respects_explicit_login_flag() { + let (_tx, shell_snapshot) = watch::channel(Some(Arc::new(ShellSnapshot { + path: PathBuf::from("/tmp/snapshot.sh"), + cwd: PathBuf::from("/tmp"), + }))); + let shell = Shell { + shell_type: ShellType::Bash, + shell_path: PathBuf::from("/bin/bash"), + shell_snapshot, + }; + + let login_command = ShellCommandHandler::base_command(&shell, "echo login shell", true); + assert_eq!( + login_command, + shell.derive_exec_args("echo login shell", true) + ); + + let non_login_command = + ShellCommandHandler::base_command(&shell, "echo non login shell", false); + assert_eq!( + non_login_command, + shell.derive_exec_args("echo non login shell", false) + ); +} + +#[tokio::test] +async fn shell_command_handler_defaults_to_non_login_when_disallowed() { + let (session, turn_context) = make_session_and_context().await; + let params = ShellCommandToolCallParams { + command: "echo hello".to_string(), + workdir: None, + login: None, + timeout_ms: None, + sandbox_permissions: None, + additional_permissions: None, + prefix_rule: None, + justification: None, + }; + + let exec_params = ShellCommandHandler::to_exec_params( + ¶ms, + &session, + &turn_context, + session.conversation_id, + false, + ) + .expect("non-login shells should still be allowed"); + + assert_eq!( + exec_params.command, + session.user_shell().derive_exec_args("echo hello", false) + ); +} + +#[test] +fn shell_command_handler_rejects_login_when_disallowed() { + let err = ShellCommandHandler::resolve_use_login_shell(Some(true), false) + .expect_err("explicit login should be rejected"); + + assert!( + err.to_string() + .contains("login shell is disabled by config"), + "unexpected error: {err}" + ); +} diff --git a/codex-rs/core/src/tools/handlers/tool_search.rs b/codex-rs/core/src/tools/handlers/tool_search.rs index 356b64ec9a..2005c7d2f1 100644 --- a/codex-rs/core/src/tools/handlers/tool_search.rs +++ b/codex-rs/core/src/tools/handlers/tool_search.rs @@ -188,203 +188,5 @@ fn build_search_text(name: &str, info: &ToolInfo) -> String { } #[cfg(test)] -mod tests { - use super::*; - use crate::mcp::CODEX_APPS_MCP_SERVER_NAME; - use pretty_assertions::assert_eq; - use rmcp::model::JsonObject; - use rmcp::model::Tool; - use serde_json::json; - use std::sync::Arc; - - #[test] - fn serialize_tool_search_output_tools_groups_results_by_namespace() { - let entries = [ - ( - "mcp__codex_apps__calendar-create-event".to_string(), - ToolInfo { - server_name: CODEX_APPS_MCP_SERVER_NAME.to_string(), - tool_name: "-create-event".to_string(), - tool_namespace: "mcp__codex_apps__calendar".to_string(), - tool: Tool { - name: "calendar-create-event".to_string().into(), - title: None, - description: Some("Create a calendar event.".into()), - input_schema: Arc::new(JsonObject::from_iter([( - "type".to_string(), - json!("object"), - )])), - output_schema: None, - annotations: None, - execution: None, - icons: None, - meta: None, - }, - connector_id: Some("calendar".to_string()), - connector_name: Some("Calendar".to_string()), - plugin_display_names: Vec::new(), - connector_description: Some("Plan events".to_string()), - }, - ), - ( - "mcp__codex_apps__gmail-read-email".to_string(), - ToolInfo { - server_name: CODEX_APPS_MCP_SERVER_NAME.to_string(), - tool_name: "-read-email".to_string(), - tool_namespace: "mcp__codex_apps__gmail".to_string(), - tool: Tool { - name: "gmail-read-email".to_string().into(), - title: None, - description: Some("Read an email.".into()), - input_schema: Arc::new(JsonObject::from_iter([( - "type".to_string(), - json!("object"), - )])), - output_schema: None, - annotations: None, - execution: None, - icons: None, - meta: None, - }, - connector_id: Some("gmail".to_string()), - connector_name: Some("Gmail".to_string()), - plugin_display_names: Vec::new(), - connector_description: Some("Read mail".to_string()), - }, - ), - ( - "mcp__codex_apps__calendar-list-events".to_string(), - ToolInfo { - server_name: CODEX_APPS_MCP_SERVER_NAME.to_string(), - tool_name: "-list-events".to_string(), - tool_namespace: "mcp__codex_apps__calendar".to_string(), - tool: Tool { - name: "calendar-list-events".to_string().into(), - title: None, - description: Some("List calendar events.".into()), - input_schema: Arc::new(JsonObject::from_iter([( - "type".to_string(), - json!("object"), - )])), - output_schema: None, - annotations: None, - execution: None, - icons: None, - meta: None, - }, - connector_id: Some("calendar".to_string()), - connector_name: Some("Calendar".to_string()), - plugin_display_names: Vec::new(), - connector_description: Some("Plan events".to_string()), - }, - ), - ]; - - let tools = serialize_tool_search_output_tools(&[&entries[0], &entries[1], &entries[2]]) - .expect("serialize tool search output"); - - assert_eq!( - tools, - vec![ - ToolSearchOutputTool::Namespace(ResponsesApiNamespace { - name: "mcp__codex_apps__calendar".to_string(), - description: "Plan events".to_string(), - tools: vec![ - ResponsesApiNamespaceTool::Function(ResponsesApiTool { - name: "-create-event".to_string(), - description: "Create a calendar event.".to_string(), - strict: false, - defer_loading: Some(true), - parameters: crate::tools::spec::JsonSchema::Object { - properties: Default::default(), - required: None, - additional_properties: None, - }, - output_schema: None, - }), - ResponsesApiNamespaceTool::Function(ResponsesApiTool { - name: "-list-events".to_string(), - description: "List calendar events.".to_string(), - strict: false, - defer_loading: Some(true), - parameters: crate::tools::spec::JsonSchema::Object { - properties: Default::default(), - required: None, - additional_properties: None, - }, - output_schema: None, - }), - ], - }), - ToolSearchOutputTool::Namespace(ResponsesApiNamespace { - name: "mcp__codex_apps__gmail".to_string(), - description: "Read mail".to_string(), - tools: vec![ResponsesApiNamespaceTool::Function(ResponsesApiTool { - name: "-read-email".to_string(), - description: "Read an email.".to_string(), - strict: false, - defer_loading: Some(true), - parameters: crate::tools::spec::JsonSchema::Object { - properties: Default::default(), - required: None, - additional_properties: None, - }, - output_schema: None, - })], - }) - ] - ); - } - - #[test] - fn serialize_tool_search_output_tools_falls_back_to_connector_name_description() { - let entries = [( - "mcp__codex_apps__gmail-batch-read-email".to_string(), - ToolInfo { - server_name: CODEX_APPS_MCP_SERVER_NAME.to_string(), - tool_name: "-batch-read-email".to_string(), - tool_namespace: "mcp__codex_apps__gmail".to_string(), - tool: Tool { - name: "gmail-batch-read-email".to_string().into(), - title: None, - description: Some("Read multiple emails.".into()), - input_schema: Arc::new(JsonObject::from_iter([( - "type".to_string(), - json!("object"), - )])), - output_schema: None, - annotations: None, - execution: None, - icons: None, - meta: None, - }, - connector_id: Some("connector_gmail_456".to_string()), - connector_name: Some("Gmail".to_string()), - plugin_display_names: Vec::new(), - connector_description: None, - }, - )]; - - let tools = serialize_tool_search_output_tools(&[&entries[0]]).expect("serialize"); - - assert_eq!( - tools, - vec![ToolSearchOutputTool::Namespace(ResponsesApiNamespace { - name: "mcp__codex_apps__gmail".to_string(), - description: "Tools for working with Gmail.".to_string(), - tools: vec![ResponsesApiNamespaceTool::Function(ResponsesApiTool { - name: "-batch-read-email".to_string(), - description: "Read multiple emails.".to_string(), - strict: false, - defer_loading: Some(true), - parameters: crate::tools::spec::JsonSchema::Object { - properties: Default::default(), - required: None, - additional_properties: None, - }, - output_schema: None, - })], - })] - ); - } -} +#[path = "tool_search_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/handlers/tool_search_tests.rs b/codex-rs/core/src/tools/handlers/tool_search_tests.rs new file mode 100644 index 0000000000..fc7ef5970e --- /dev/null +++ b/codex-rs/core/src/tools/handlers/tool_search_tests.rs @@ -0,0 +1,198 @@ +use super::*; +use crate::mcp::CODEX_APPS_MCP_SERVER_NAME; +use pretty_assertions::assert_eq; +use rmcp::model::JsonObject; +use rmcp::model::Tool; +use serde_json::json; +use std::sync::Arc; + +#[test] +fn serialize_tool_search_output_tools_groups_results_by_namespace() { + let entries = [ + ( + "mcp__codex_apps__calendar-create-event".to_string(), + ToolInfo { + server_name: CODEX_APPS_MCP_SERVER_NAME.to_string(), + tool_name: "-create-event".to_string(), + tool_namespace: "mcp__codex_apps__calendar".to_string(), + tool: Tool { + name: "calendar-create-event".to_string().into(), + title: None, + description: Some("Create a calendar event.".into()), + input_schema: Arc::new(JsonObject::from_iter([( + "type".to_string(), + json!("object"), + )])), + output_schema: None, + annotations: None, + execution: None, + icons: None, + meta: None, + }, + connector_id: Some("calendar".to_string()), + connector_name: Some("Calendar".to_string()), + plugin_display_names: Vec::new(), + connector_description: Some("Plan events".to_string()), + }, + ), + ( + "mcp__codex_apps__gmail-read-email".to_string(), + ToolInfo { + server_name: CODEX_APPS_MCP_SERVER_NAME.to_string(), + tool_name: "-read-email".to_string(), + tool_namespace: "mcp__codex_apps__gmail".to_string(), + tool: Tool { + name: "gmail-read-email".to_string().into(), + title: None, + description: Some("Read an email.".into()), + input_schema: Arc::new(JsonObject::from_iter([( + "type".to_string(), + json!("object"), + )])), + output_schema: None, + annotations: None, + execution: None, + icons: None, + meta: None, + }, + connector_id: Some("gmail".to_string()), + connector_name: Some("Gmail".to_string()), + plugin_display_names: Vec::new(), + connector_description: Some("Read mail".to_string()), + }, + ), + ( + "mcp__codex_apps__calendar-list-events".to_string(), + ToolInfo { + server_name: CODEX_APPS_MCP_SERVER_NAME.to_string(), + tool_name: "-list-events".to_string(), + tool_namespace: "mcp__codex_apps__calendar".to_string(), + tool: Tool { + name: "calendar-list-events".to_string().into(), + title: None, + description: Some("List calendar events.".into()), + input_schema: Arc::new(JsonObject::from_iter([( + "type".to_string(), + json!("object"), + )])), + output_schema: None, + annotations: None, + execution: None, + icons: None, + meta: None, + }, + connector_id: Some("calendar".to_string()), + connector_name: Some("Calendar".to_string()), + plugin_display_names: Vec::new(), + connector_description: Some("Plan events".to_string()), + }, + ), + ]; + + let tools = serialize_tool_search_output_tools(&[&entries[0], &entries[1], &entries[2]]) + .expect("serialize tool search output"); + + assert_eq!( + tools, + vec![ + ToolSearchOutputTool::Namespace(ResponsesApiNamespace { + name: "mcp__codex_apps__calendar".to_string(), + description: "Plan events".to_string(), + tools: vec![ + ResponsesApiNamespaceTool::Function(ResponsesApiTool { + name: "-create-event".to_string(), + description: "Create a calendar event.".to_string(), + strict: false, + defer_loading: Some(true), + parameters: crate::tools::spec::JsonSchema::Object { + properties: Default::default(), + required: None, + additional_properties: None, + }, + output_schema: None, + }), + ResponsesApiNamespaceTool::Function(ResponsesApiTool { + name: "-list-events".to_string(), + description: "List calendar events.".to_string(), + strict: false, + defer_loading: Some(true), + parameters: crate::tools::spec::JsonSchema::Object { + properties: Default::default(), + required: None, + additional_properties: None, + }, + output_schema: None, + }), + ], + }), + ToolSearchOutputTool::Namespace(ResponsesApiNamespace { + name: "mcp__codex_apps__gmail".to_string(), + description: "Read mail".to_string(), + tools: vec![ResponsesApiNamespaceTool::Function(ResponsesApiTool { + name: "-read-email".to_string(), + description: "Read an email.".to_string(), + strict: false, + defer_loading: Some(true), + parameters: crate::tools::spec::JsonSchema::Object { + properties: Default::default(), + required: None, + additional_properties: None, + }, + output_schema: None, + })], + }) + ] + ); +} + +#[test] +fn serialize_tool_search_output_tools_falls_back_to_connector_name_description() { + let entries = [( + "mcp__codex_apps__gmail-batch-read-email".to_string(), + ToolInfo { + server_name: CODEX_APPS_MCP_SERVER_NAME.to_string(), + tool_name: "-batch-read-email".to_string(), + tool_namespace: "mcp__codex_apps__gmail".to_string(), + tool: Tool { + name: "gmail-batch-read-email".to_string().into(), + title: None, + description: Some("Read multiple emails.".into()), + input_schema: Arc::new(JsonObject::from_iter([( + "type".to_string(), + json!("object"), + )])), + output_schema: None, + annotations: None, + execution: None, + icons: None, + meta: None, + }, + connector_id: Some("connector_gmail_456".to_string()), + connector_name: Some("Gmail".to_string()), + plugin_display_names: Vec::new(), + connector_description: None, + }, + )]; + + let tools = serialize_tool_search_output_tools(&[&entries[0]]).expect("serialize"); + + assert_eq!( + tools, + vec![ToolSearchOutputTool::Namespace(ResponsesApiNamespace { + name: "mcp__codex_apps__gmail".to_string(), + description: "Tools for working with Gmail.".to_string(), + tools: vec![ResponsesApiNamespaceTool::Function(ResponsesApiTool { + name: "-batch-read-email".to_string(), + description: "Read multiple emails.".to_string(), + strict: false, + defer_loading: Some(true), + parameters: crate::tools::spec::JsonSchema::Object { + properties: Default::default(), + required: None, + additional_properties: None, + }, + output_schema: None, + })], + })] + ); +} diff --git a/codex-rs/core/src/tools/handlers/tool_suggest.rs b/codex-rs/core/src/tools/handlers/tool_suggest.rs index 5483cac043..311f191bd0 100644 --- a/codex-rs/core/src/tools/handlers/tool_suggest.rs +++ b/codex-rs/core/src/tools/handlers/tool_suggest.rs @@ -294,172 +294,5 @@ fn verified_connector_suggestion_completed( } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - - #[test] - fn build_tool_suggestion_elicitation_request_uses_expected_shape() { - let args = ToolSuggestArgs { - tool_type: DiscoverableToolType::Connector, - action_type: DiscoverableToolAction::Install, - tool_id: "connector_2128aebfecb84f64a069897515042a44".to_string(), - suggest_reason: "Plan and reference events from your calendar".to_string(), - }; - let connector = AppInfo { - id: "connector_2128aebfecb84f64a069897515042a44".to_string(), - name: "Google Calendar".to_string(), - description: Some("Plan events and schedules.".to_string()), - logo_url: None, - logo_url_dark: None, - distribution_channel: None, - branding: None, - app_metadata: None, - labels: None, - install_url: Some( - "https://chatgpt.com/apps/google-calendar/connector_2128aebfecb84f64a069897515042a44" - .to_string(), - ), - is_accessible: false, - is_enabled: true, - plugin_display_names: Vec::new(), - }; - - let request = build_tool_suggestion_elicitation_request( - "thread-1".to_string(), - "turn-1".to_string(), - &args, - "Plan and reference events from your calendar", - &connector, - ); - - assert_eq!( - request, - McpServerElicitationRequestParams { - thread_id: "thread-1".to_string(), - turn_id: Some("turn-1".to_string()), - server_name: CODEX_APPS_MCP_SERVER_NAME.to_string(), - request: McpServerElicitationRequest::Form { - meta: Some(json!(ToolSuggestMeta { - codex_approval_kind: TOOL_SUGGEST_APPROVAL_KIND_VALUE, - tool_type: DiscoverableToolType::Connector, - suggest_type: DiscoverableToolAction::Install, - suggest_reason: "Plan and reference events from your calendar", - tool_id: "connector_2128aebfecb84f64a069897515042a44", - tool_name: "Google Calendar", - install_url: "https://chatgpt.com/apps/google-calendar/connector_2128aebfecb84f64a069897515042a44", - })), - message: "Google Calendar could help with this request.\n\nPlan and reference events from your calendar\n\nOpen ChatGPT to install it, then confirm here if you finish.".to_string(), - requested_schema: McpElicitationSchema { - schema_uri: None, - type_: McpElicitationObjectType::Object, - properties: BTreeMap::new(), - required: None, - }, - }, - } - ); - } - - #[test] - fn build_tool_suggestion_meta_uses_expected_shape() { - let meta = build_tool_suggestion_meta( - DiscoverableToolType::Connector, - DiscoverableToolAction::Install, - "Find and reference emails from your inbox", - "connector_68df038e0ba48191908c8434991bbac2", - "Gmail", - "https://chatgpt.com/apps/gmail/connector_68df038e0ba48191908c8434991bbac2", - ); - - assert_eq!( - meta, - ToolSuggestMeta { - codex_approval_kind: TOOL_SUGGEST_APPROVAL_KIND_VALUE, - tool_type: DiscoverableToolType::Connector, - suggest_type: DiscoverableToolAction::Install, - suggest_reason: "Find and reference emails from your inbox", - tool_id: "connector_68df038e0ba48191908c8434991bbac2", - tool_name: "Gmail", - install_url: "https://chatgpt.com/apps/gmail/connector_68df038e0ba48191908c8434991bbac2", - } - ); - } - - #[test] - fn verified_connector_suggestion_completed_requires_installed_connector() { - let accessible_connectors = vec![AppInfo { - id: "calendar".to_string(), - name: "Google Calendar".to_string(), - description: None, - logo_url: None, - logo_url_dark: None, - distribution_channel: None, - branding: None, - app_metadata: None, - labels: None, - install_url: None, - is_accessible: true, - is_enabled: true, - plugin_display_names: Vec::new(), - }]; - - assert!(verified_connector_suggestion_completed( - DiscoverableToolAction::Install, - "calendar", - &accessible_connectors, - )); - assert!(!verified_connector_suggestion_completed( - DiscoverableToolAction::Install, - "gmail", - &accessible_connectors, - )); - } - - #[test] - fn verified_connector_suggestion_completed_requires_enabled_connector_for_enable() { - let accessible_connectors = vec![ - AppInfo { - id: "calendar".to_string(), - name: "Google Calendar".to_string(), - description: None, - logo_url: None, - logo_url_dark: None, - distribution_channel: None, - branding: None, - app_metadata: None, - labels: None, - install_url: None, - is_accessible: true, - is_enabled: false, - plugin_display_names: Vec::new(), - }, - AppInfo { - id: "gmail".to_string(), - name: "Gmail".to_string(), - description: None, - logo_url: None, - logo_url_dark: None, - distribution_channel: None, - branding: None, - app_metadata: None, - labels: None, - install_url: None, - is_accessible: true, - is_enabled: true, - plugin_display_names: Vec::new(), - }, - ]; - - assert!(!verified_connector_suggestion_completed( - DiscoverableToolAction::Enable, - "calendar", - &accessible_connectors, - )); - assert!(verified_connector_suggestion_completed( - DiscoverableToolAction::Enable, - "gmail", - &accessible_connectors, - )); - } -} +#[path = "tool_suggest_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/handlers/tool_suggest_tests.rs b/codex-rs/core/src/tools/handlers/tool_suggest_tests.rs new file mode 100644 index 0000000000..a8c4541e91 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/tool_suggest_tests.rs @@ -0,0 +1,167 @@ +use super::*; +use pretty_assertions::assert_eq; + +#[test] +fn build_tool_suggestion_elicitation_request_uses_expected_shape() { + let args = ToolSuggestArgs { + tool_type: DiscoverableToolType::Connector, + action_type: DiscoverableToolAction::Install, + tool_id: "connector_2128aebfecb84f64a069897515042a44".to_string(), + suggest_reason: "Plan and reference events from your calendar".to_string(), + }; + let connector = AppInfo { + id: "connector_2128aebfecb84f64a069897515042a44".to_string(), + name: "Google Calendar".to_string(), + description: Some("Plan events and schedules.".to_string()), + logo_url: None, + logo_url_dark: None, + distribution_channel: None, + branding: None, + app_metadata: None, + labels: None, + install_url: Some( + "https://chatgpt.com/apps/google-calendar/connector_2128aebfecb84f64a069897515042a44" + .to_string(), + ), + is_accessible: false, + is_enabled: true, + plugin_display_names: Vec::new(), + }; + + let request = build_tool_suggestion_elicitation_request( + "thread-1".to_string(), + "turn-1".to_string(), + &args, + "Plan and reference events from your calendar", + &connector, + ); + + assert_eq!( + request, + McpServerElicitationRequestParams { + thread_id: "thread-1".to_string(), + turn_id: Some("turn-1".to_string()), + server_name: CODEX_APPS_MCP_SERVER_NAME.to_string(), + request: McpServerElicitationRequest::Form { + meta: Some(json!(ToolSuggestMeta { + codex_approval_kind: TOOL_SUGGEST_APPROVAL_KIND_VALUE, + tool_type: DiscoverableToolType::Connector, + suggest_type: DiscoverableToolAction::Install, + suggest_reason: "Plan and reference events from your calendar", + tool_id: "connector_2128aebfecb84f64a069897515042a44", + tool_name: "Google Calendar", + install_url: "https://chatgpt.com/apps/google-calendar/connector_2128aebfecb84f64a069897515042a44", + })), + message: "Google Calendar could help with this request.\n\nPlan and reference events from your calendar\n\nOpen ChatGPT to install it, then confirm here if you finish.".to_string(), + requested_schema: McpElicitationSchema { + schema_uri: None, + type_: McpElicitationObjectType::Object, + properties: BTreeMap::new(), + required: None, + }, + }, + } + ); +} + +#[test] +fn build_tool_suggestion_meta_uses_expected_shape() { + let meta = build_tool_suggestion_meta( + DiscoverableToolType::Connector, + DiscoverableToolAction::Install, + "Find and reference emails from your inbox", + "connector_68df038e0ba48191908c8434991bbac2", + "Gmail", + "https://chatgpt.com/apps/gmail/connector_68df038e0ba48191908c8434991bbac2", + ); + + assert_eq!( + meta, + ToolSuggestMeta { + codex_approval_kind: TOOL_SUGGEST_APPROVAL_KIND_VALUE, + tool_type: DiscoverableToolType::Connector, + suggest_type: DiscoverableToolAction::Install, + suggest_reason: "Find and reference emails from your inbox", + tool_id: "connector_68df038e0ba48191908c8434991bbac2", + tool_name: "Gmail", + install_url: "https://chatgpt.com/apps/gmail/connector_68df038e0ba48191908c8434991bbac2", + } + ); +} + +#[test] +fn verified_connector_suggestion_completed_requires_installed_connector() { + let accessible_connectors = vec![AppInfo { + id: "calendar".to_string(), + name: "Google Calendar".to_string(), + description: None, + logo_url: None, + logo_url_dark: None, + distribution_channel: None, + branding: None, + app_metadata: None, + labels: None, + install_url: None, + is_accessible: true, + is_enabled: true, + plugin_display_names: Vec::new(), + }]; + + assert!(verified_connector_suggestion_completed( + DiscoverableToolAction::Install, + "calendar", + &accessible_connectors, + )); + assert!(!verified_connector_suggestion_completed( + DiscoverableToolAction::Install, + "gmail", + &accessible_connectors, + )); +} + +#[test] +fn verified_connector_suggestion_completed_requires_enabled_connector_for_enable() { + let accessible_connectors = vec![ + AppInfo { + id: "calendar".to_string(), + name: "Google Calendar".to_string(), + description: None, + logo_url: None, + logo_url_dark: None, + distribution_channel: None, + branding: None, + app_metadata: None, + labels: None, + install_url: None, + is_accessible: true, + is_enabled: false, + plugin_display_names: Vec::new(), + }, + AppInfo { + id: "gmail".to_string(), + name: "Gmail".to_string(), + description: None, + logo_url: None, + logo_url_dark: None, + distribution_channel: None, + branding: None, + app_metadata: None, + labels: None, + install_url: None, + is_accessible: true, + is_enabled: true, + plugin_display_names: Vec::new(), + }, + ]; + + assert!(!verified_connector_suggestion_completed( + DiscoverableToolAction::Enable, + "calendar", + &accessible_connectors, + )); + assert!(verified_connector_suggestion_completed( + DiscoverableToolAction::Enable, + "gmail", + &accessible_connectors, + )); +} diff --git a/codex-rs/core/src/tools/handlers/unified_exec.rs b/codex-rs/core/src/tools/handlers/unified_exec.rs index edc6763ef2..02c4987dc3 100644 --- a/codex-rs/core/src/tools/handlers/unified_exec.rs +++ b/codex-rs/core/src/tools/handlers/unified_exec.rs @@ -329,131 +329,5 @@ pub(crate) fn get_command( } #[cfg(test)] -mod tests { - use super::*; - use crate::shell::default_user_shell; - use crate::tools::handlers::parse_arguments_with_base_path; - use crate::tools::handlers::resolve_workdir_base_path; - use codex_protocol::models::FileSystemPermissions; - use codex_protocol::models::PermissionProfile; - use codex_utils_absolute_path::AbsolutePathBuf; - use pretty_assertions::assert_eq; - use std::fs; - use std::sync::Arc; - use tempfile::tempdir; - - #[test] - fn test_get_command_uses_default_shell_when_unspecified() -> anyhow::Result<()> { - let json = r#"{"cmd": "echo hello"}"#; - - let args: ExecCommandArgs = parse_arguments(json)?; - - assert!(args.shell.is_none()); - - let command = - get_command(&args, Arc::new(default_user_shell()), true).map_err(anyhow::Error::msg)?; - - assert_eq!(command.len(), 3); - assert_eq!(command[2], "echo hello"); - Ok(()) - } - - #[test] - fn test_get_command_respects_explicit_bash_shell() -> anyhow::Result<()> { - let json = r#"{"cmd": "echo hello", "shell": "/bin/bash"}"#; - - let args: ExecCommandArgs = parse_arguments(json)?; - - assert_eq!(args.shell.as_deref(), Some("/bin/bash")); - - let command = - get_command(&args, Arc::new(default_user_shell()), true).map_err(anyhow::Error::msg)?; - - assert_eq!(command.last(), Some(&"echo hello".to_string())); - if command - .iter() - .any(|arg| arg.eq_ignore_ascii_case("-Command")) - { - assert!(command.contains(&"-NoProfile".to_string())); - } - Ok(()) - } - - #[test] - fn test_get_command_respects_explicit_powershell_shell() -> anyhow::Result<()> { - let json = r#"{"cmd": "echo hello", "shell": "powershell"}"#; - - let args: ExecCommandArgs = parse_arguments(json)?; - - assert_eq!(args.shell.as_deref(), Some("powershell")); - - let command = - get_command(&args, Arc::new(default_user_shell()), true).map_err(anyhow::Error::msg)?; - - assert_eq!(command[2], "echo hello"); - Ok(()) - } - - #[test] - fn test_get_command_respects_explicit_cmd_shell() -> anyhow::Result<()> { - let json = r#"{"cmd": "echo hello", "shell": "cmd"}"#; - - let args: ExecCommandArgs = parse_arguments(json)?; - - assert_eq!(args.shell.as_deref(), Some("cmd")); - - let command = - get_command(&args, Arc::new(default_user_shell()), true).map_err(anyhow::Error::msg)?; - - assert_eq!(command[2], "echo hello"); - Ok(()) - } - - #[test] - fn test_get_command_rejects_explicit_login_when_disallowed() -> anyhow::Result<()> { - let json = r#"{"cmd": "echo hello", "login": true}"#; - - let args: ExecCommandArgs = parse_arguments(json)?; - let err = get_command(&args, Arc::new(default_user_shell()), false) - .expect_err("explicit login should be rejected"); - - assert!( - err.contains("login shell is disabled by config"), - "unexpected error: {err}" - ); - Ok(()) - } - - #[test] - fn exec_command_args_resolve_relative_additional_permissions_against_workdir() - -> anyhow::Result<()> { - let cwd = tempdir()?; - let workdir = cwd.path().join("nested"); - fs::create_dir_all(&workdir)?; - let expected_write = workdir.join("relative-write.txt"); - let json = r#"{ - "cmd": "echo hello", - "workdir": "nested", - "additional_permissions": { - "file_system": { - "write": ["./relative-write.txt"] - } - } - }"#; - - let base_path = resolve_workdir_base_path(json, cwd.path())?; - let args: ExecCommandArgs = parse_arguments_with_base_path(json, base_path.as_path())?; - - assert_eq!( - args.additional_permissions, - Some(PermissionProfile { - file_system: Some(FileSystemPermissions { - read: None, - write: Some(vec![AbsolutePathBuf::try_from(expected_write)?]), - }), - ..Default::default() - }) - ); - Ok(()) - } -} +#[path = "unified_exec_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/handlers/unified_exec_tests.rs b/codex-rs/core/src/tools/handlers/unified_exec_tests.rs new file mode 100644 index 0000000000..ee31aefc14 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/unified_exec_tests.rs @@ -0,0 +1,126 @@ +use super::*; +use crate::shell::default_user_shell; +use crate::tools::handlers::parse_arguments_with_base_path; +use crate::tools::handlers::resolve_workdir_base_path; +use codex_protocol::models::FileSystemPermissions; +use codex_protocol::models::PermissionProfile; +use codex_utils_absolute_path::AbsolutePathBuf; +use pretty_assertions::assert_eq; +use std::fs; +use std::sync::Arc; +use tempfile::tempdir; + +#[test] +fn test_get_command_uses_default_shell_when_unspecified() -> anyhow::Result<()> { + let json = r#"{"cmd": "echo hello"}"#; + + let args: ExecCommandArgs = parse_arguments(json)?; + + assert!(args.shell.is_none()); + + let command = + get_command(&args, Arc::new(default_user_shell()), true).map_err(anyhow::Error::msg)?; + + assert_eq!(command.len(), 3); + assert_eq!(command[2], "echo hello"); + Ok(()) +} + +#[test] +fn test_get_command_respects_explicit_bash_shell() -> anyhow::Result<()> { + let json = r#"{"cmd": "echo hello", "shell": "/bin/bash"}"#; + + let args: ExecCommandArgs = parse_arguments(json)?; + + assert_eq!(args.shell.as_deref(), Some("/bin/bash")); + + let command = + get_command(&args, Arc::new(default_user_shell()), true).map_err(anyhow::Error::msg)?; + + assert_eq!(command.last(), Some(&"echo hello".to_string())); + if command + .iter() + .any(|arg| arg.eq_ignore_ascii_case("-Command")) + { + assert!(command.contains(&"-NoProfile".to_string())); + } + Ok(()) +} + +#[test] +fn test_get_command_respects_explicit_powershell_shell() -> anyhow::Result<()> { + let json = r#"{"cmd": "echo hello", "shell": "powershell"}"#; + + let args: ExecCommandArgs = parse_arguments(json)?; + + assert_eq!(args.shell.as_deref(), Some("powershell")); + + let command = + get_command(&args, Arc::new(default_user_shell()), true).map_err(anyhow::Error::msg)?; + + assert_eq!(command[2], "echo hello"); + Ok(()) +} + +#[test] +fn test_get_command_respects_explicit_cmd_shell() -> anyhow::Result<()> { + let json = r#"{"cmd": "echo hello", "shell": "cmd"}"#; + + let args: ExecCommandArgs = parse_arguments(json)?; + + assert_eq!(args.shell.as_deref(), Some("cmd")); + + let command = + get_command(&args, Arc::new(default_user_shell()), true).map_err(anyhow::Error::msg)?; + + assert_eq!(command[2], "echo hello"); + Ok(()) +} + +#[test] +fn test_get_command_rejects_explicit_login_when_disallowed() -> anyhow::Result<()> { + let json = r#"{"cmd": "echo hello", "login": true}"#; + + let args: ExecCommandArgs = parse_arguments(json)?; + let err = get_command(&args, Arc::new(default_user_shell()), false) + .expect_err("explicit login should be rejected"); + + assert!( + err.contains("login shell is disabled by config"), + "unexpected error: {err}" + ); + Ok(()) +} + +#[test] +fn exec_command_args_resolve_relative_additional_permissions_against_workdir() -> anyhow::Result<()> +{ + let cwd = tempdir()?; + let workdir = cwd.path().join("nested"); + fs::create_dir_all(&workdir)?; + let expected_write = workdir.join("relative-write.txt"); + let json = r#"{ + "cmd": "echo hello", + "workdir": "nested", + "additional_permissions": { + "file_system": { + "write": ["./relative-write.txt"] + } + } + }"#; + + let base_path = resolve_workdir_base_path(json, cwd.path())?; + let args: ExecCommandArgs = parse_arguments_with_base_path(json, base_path.as_path())?; + + assert_eq!( + args.additional_permissions, + Some(PermissionProfile { + file_system: Some(FileSystemPermissions { + read: None, + write: Some(vec![AbsolutePathBuf::try_from(expected_write)?]), + }), + ..Default::default() + }) + ); + Ok(()) +} diff --git a/codex-rs/core/src/tools/js_repl/mod.rs b/codex-rs/core/src/tools/js_repl/mod.rs index a6ce016ffe..7a9089a51c 100644 --- a/codex-rs/core/src/tools/js_repl/mod.rs +++ b/codex-rs/core/src/tools/js_repl/mod.rs @@ -1740,2339 +1740,5 @@ pub(crate) fn resolve_node(config_path: Option<&Path>) -> Option { } #[cfg(test)] -mod tests { - use super::*; - use crate::codex::make_session_and_context; - use crate::codex::make_session_and_context_with_dynamic_tools_and_rx; - use crate::features::Feature; - use crate::protocol::AskForApproval; - use crate::protocol::EventMsg; - use crate::protocol::SandboxPolicy; - use crate::turn_diff_tracker::TurnDiffTracker; - use codex_protocol::dynamic_tools::DynamicToolCallOutputContentItem; - use codex_protocol::dynamic_tools::DynamicToolResponse; - use codex_protocol::dynamic_tools::DynamicToolSpec; - use codex_protocol::models::FunctionCallOutputContentItem; - use codex_protocol::models::FunctionCallOutputPayload; - use codex_protocol::models::ImageDetail; - use codex_protocol::models::ResponseInputItem; - use codex_protocol::openai_models::InputModality; - use pretty_assertions::assert_eq; - use std::fs; - use std::path::Path; - use tempfile::tempdir; - - fn set_danger_full_access(turn: &mut crate::codex::TurnContext) { - turn.sandbox_policy - .set(SandboxPolicy::DangerFullAccess) - .expect("test setup should allow updating sandbox policy"); - turn.file_system_sandbox_policy = - crate::protocol::FileSystemSandboxPolicy::from(turn.sandbox_policy.get()); - turn.network_sandbox_policy = - crate::protocol::NetworkSandboxPolicy::from(turn.sandbox_policy.get()); - } - - #[test] - fn node_version_parses_v_prefix_and_suffix() { - let version = NodeVersion::parse("v25.1.0-nightly.2024").unwrap(); - assert_eq!( - version, - NodeVersion { - major: 25, - minor: 1, - patch: 0, - } - ); - } - - #[test] - fn truncate_utf8_prefix_by_bytes_preserves_character_boundaries() { - let input = "aé🙂z"; - assert_eq!(truncate_utf8_prefix_by_bytes(input, 0), ""); - assert_eq!(truncate_utf8_prefix_by_bytes(input, 1), "a"); - assert_eq!(truncate_utf8_prefix_by_bytes(input, 2), "a"); - assert_eq!(truncate_utf8_prefix_by_bytes(input, 3), "aé"); - assert_eq!(truncate_utf8_prefix_by_bytes(input, 6), "aé"); - assert_eq!(truncate_utf8_prefix_by_bytes(input, 7), "aé🙂"); - assert_eq!(truncate_utf8_prefix_by_bytes(input, 8), "aé🙂z"); - } - - #[test] - fn stderr_tail_applies_line_and_byte_limits() { - let mut lines = VecDeque::new(); - let per_line_cap = JS_REPL_STDERR_TAIL_LINE_MAX_BYTES.min(JS_REPL_STDERR_TAIL_MAX_BYTES); - let long = "x".repeat(per_line_cap + 128); - let bounded = push_stderr_tail_line(&mut lines, &long); - assert_eq!(bounded.len(), per_line_cap); - - for i in 0..50 { - let line = format!("line-{i}-{}", "y".repeat(200)); - push_stderr_tail_line(&mut lines, &line); - } - - assert!(lines.len() <= JS_REPL_STDERR_TAIL_LINE_LIMIT); - assert!(lines.iter().all(|line| line.len() <= per_line_cap)); - assert!(stderr_tail_formatted_bytes(&lines) <= JS_REPL_STDERR_TAIL_MAX_BYTES); - assert_eq!( - format_stderr_tail(&lines).len(), - stderr_tail_formatted_bytes(&lines) - ); - } - - #[test] - fn model_kernel_failure_details_are_structured_and_truncated() { - let snapshot = KernelDebugSnapshot { - pid: Some(42), - status: "exited(code=1)".to_string(), - stderr_tail: "s".repeat(JS_REPL_MODEL_DIAG_STDERR_MAX_BYTES + 400), - }; - let stream_error = "e".repeat(JS_REPL_MODEL_DIAG_ERROR_MAX_BYTES + 200); - let message = with_model_kernel_failure_message( - "js_repl kernel exited unexpectedly", - "stdout_eof", - Some(&stream_error), - &snapshot, - ); - assert!(message.starts_with("js_repl kernel exited unexpectedly\n\njs_repl diagnostics: ")); - let (_prefix, encoded) = message - .split_once("js_repl diagnostics: ") - .expect("diagnostics suffix should be present"); - let parsed: serde_json::Value = - serde_json::from_str(encoded).expect("diagnostics should be valid json"); - assert_eq!( - parsed.get("reason").and_then(|v| v.as_str()), - Some("stdout_eof") - ); - assert_eq!( - parsed.get("kernel_pid").and_then(serde_json::Value::as_u64), - Some(42) - ); - assert_eq!( - parsed.get("kernel_status").and_then(|v| v.as_str()), - Some("exited(code=1)") - ); - assert!( - parsed - .get("kernel_stderr_tail") - .and_then(|v| v.as_str()) - .expect("kernel_stderr_tail should be present") - .len() - <= JS_REPL_MODEL_DIAG_STDERR_MAX_BYTES - ); - assert!( - parsed - .get("stream_error") - .and_then(|v| v.as_str()) - .expect("stream_error should be present") - .len() - <= JS_REPL_MODEL_DIAG_ERROR_MAX_BYTES - ); - } - - #[test] - fn write_error_diagnostics_only_attach_for_likely_kernel_failures() { - let running = KernelDebugSnapshot { - pid: Some(7), - status: "running".to_string(), - stderr_tail: "".to_string(), - }; - let exited = KernelDebugSnapshot { - pid: Some(7), - status: "exited(code=1)".to_string(), - stderr_tail: "".to_string(), - }; - assert!(!should_include_model_diagnostics_for_write_error( - "failed to flush kernel message: other io error", - &running - )); - assert!(should_include_model_diagnostics_for_write_error( - "failed to write to kernel: Broken pipe (os error 32)", - &running - )); - assert!(should_include_model_diagnostics_for_write_error( - "failed to write to kernel: some other io error", - &exited - )); - } - - #[test] - fn js_repl_internal_tool_guard_matches_expected_names() { - assert!(is_js_repl_internal_tool("js_repl")); - assert!(is_js_repl_internal_tool("js_repl_reset")); - assert!(!is_js_repl_internal_tool("shell_command")); - assert!(!is_js_repl_internal_tool("list_mcp_resources")); - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn wait_for_exec_tool_calls_map_drains_inflight_calls_without_hanging() { - let exec_tool_calls = Arc::new(Mutex::new(HashMap::new())); - - for _ in 0..128 { - let exec_id = Uuid::new_v4().to_string(); - exec_tool_calls - .lock() - .await - .insert(exec_id.clone(), ExecToolCalls::default()); - assert!( - JsReplManager::begin_exec_tool_call(&exec_tool_calls, &exec_id) - .await - .is_some() - ); - - let wait_map = Arc::clone(&exec_tool_calls); - let wait_exec_id = exec_id.clone(); - let waiter = tokio::spawn(async move { - JsReplManager::wait_for_exec_tool_calls_map(&wait_map, &wait_exec_id).await; - }); - - let finish_map = Arc::clone(&exec_tool_calls); - let finish_exec_id = exec_id.clone(); - let finisher = tokio::spawn(async move { - tokio::task::yield_now().await; - JsReplManager::finish_exec_tool_call(&finish_map, &finish_exec_id).await; - }); - - tokio::time::timeout(Duration::from_secs(1), waiter) - .await - .expect("wait_for_exec_tool_calls_map should not hang") - .expect("wait task should not panic"); - finisher.await.expect("finish task should not panic"); - - JsReplManager::clear_exec_tool_calls_map(&exec_tool_calls, &exec_id).await; - } - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn reset_waits_for_exec_lock_before_clearing_exec_tool_calls() { - let manager = JsReplManager::new(None, Vec::new()) - .await - .expect("manager should initialize"); - let permit = manager - .exec_lock - .clone() - .acquire_owned() - .await - .expect("lock should be acquirable"); - let exec_id = Uuid::new_v4().to_string(); - manager.register_exec_tool_calls(&exec_id).await; - - let reset_manager = Arc::clone(&manager); - let mut reset_task = tokio::spawn(async move { reset_manager.reset().await }); - tokio::time::sleep(Duration::from_millis(50)).await; - - assert!( - !reset_task.is_finished(), - "reset should wait until execute lock is released" - ); - assert!( - manager.exec_tool_calls.lock().await.contains_key(&exec_id), - "reset must not clear tool-call contexts while execute lock is held" - ); - - drop(permit); - - tokio::time::timeout(Duration::from_secs(1), &mut reset_task) - .await - .expect("reset should complete after execute lock release") - .expect("reset task should not panic") - .expect("reset should succeed"); - assert!( - !manager.exec_tool_calls.lock().await.contains_key(&exec_id), - "reset should clear tool-call contexts after lock acquisition" - ); - } - - #[test] - fn summarize_tool_call_response_for_multimodal_function_output() { - let response = ResponseInputItem::FunctionCallOutput { - call_id: "call-1".to_string(), - output: FunctionCallOutputPayload::from_content_items(vec![ - FunctionCallOutputContentItem::InputImage { - image_url: "data:image/png;base64,abcd".to_string(), - detail: None, - }, - ]), - }; - - let actual = JsReplManager::summarize_tool_call_response(&response); - - assert_eq!( - actual, - JsReplToolCallResponseSummary { - response_type: Some("function_call_output".to_string()), - payload_kind: Some(JsReplToolCallPayloadKind::FunctionContentItems), - payload_text_preview: None, - payload_text_length: None, - payload_item_count: Some(1), - text_item_count: Some(0), - image_item_count: Some(1), - structured_content_present: None, - result_is_error: None, - } - ); - } - - #[tokio::test] - async fn emitted_image_content_item_drops_unsupported_explicit_detail() { - let (_session, turn) = make_session_and_context().await; - let content_item = emitted_image_content_item( - &turn, - "data:image/png;base64,AAA".to_string(), - Some(ImageDetail::Low), - ); - assert_eq!( - content_item, - FunctionCallOutputContentItem::InputImage { - image_url: "data:image/png;base64,AAA".to_string(), - detail: None, - } - ); - } - - #[tokio::test] - async fn emitted_image_content_item_does_not_force_original_when_enabled() { - let (_session, mut turn) = make_session_and_context().await; - Arc::make_mut(&mut turn.config) - .features - .enable(Feature::ImageDetailOriginal) - .expect("test config should allow feature update"); - turn.features - .enable(Feature::ImageDetailOriginal) - .expect("test turn features should allow feature update"); - turn.model_info.supports_image_detail_original = true; - - let content_item = - emitted_image_content_item(&turn, "data:image/png;base64,AAA".to_string(), None); - - assert_eq!( - content_item, - FunctionCallOutputContentItem::InputImage { - image_url: "data:image/png;base64,AAA".to_string(), - detail: None, - } - ); - } - - #[tokio::test] - async fn emitted_image_content_item_allows_explicit_original_detail_when_enabled() { - let (_session, mut turn) = make_session_and_context().await; - Arc::make_mut(&mut turn.config) - .features - .enable(Feature::ImageDetailOriginal) - .expect("test config should allow feature update"); - turn.features - .enable(Feature::ImageDetailOriginal) - .expect("test turn features should allow feature update"); - turn.model_info.supports_image_detail_original = true; - - let content_item = emitted_image_content_item( - &turn, - "data:image/png;base64,AAA".to_string(), - Some(ImageDetail::Original), - ); - - assert_eq!( - content_item, - FunctionCallOutputContentItem::InputImage { - image_url: "data:image/png;base64,AAA".to_string(), - detail: Some(ImageDetail::Original), - } - ); - } - - #[tokio::test] - async fn emitted_image_content_item_drops_explicit_original_detail_when_disabled() { - let (_session, turn) = make_session_and_context().await; - - let content_item = emitted_image_content_item( - &turn, - "data:image/png;base64,AAA".to_string(), - Some(ImageDetail::Original), - ); - - assert_eq!( - content_item, - FunctionCallOutputContentItem::InputImage { - image_url: "data:image/png;base64,AAA".to_string(), - detail: None, - } - ); - } - - #[test] - fn validate_emitted_image_url_accepts_case_insensitive_data_scheme() { - assert_eq!( - validate_emitted_image_url("DATA:image/png;base64,AAA"), - Ok(()) - ); - } - - #[test] - fn validate_emitted_image_url_rejects_non_data_scheme() { - assert_eq!( - validate_emitted_image_url("https://example.com/image.png"), - Err("codex.emitImage only accepts data URLs".to_string()) - ); - } - - #[test] - fn summarize_tool_call_response_for_multimodal_custom_output() { - let response = ResponseInputItem::CustomToolCallOutput { - call_id: "call-1".to_string(), - output: FunctionCallOutputPayload::from_content_items(vec![ - FunctionCallOutputContentItem::InputImage { - image_url: "data:image/png;base64,abcd".to_string(), - detail: None, - }, - ]), - }; - - let actual = JsReplManager::summarize_tool_call_response(&response); - - assert_eq!( - actual, - JsReplToolCallResponseSummary { - response_type: Some("custom_tool_call_output".to_string()), - payload_kind: Some(JsReplToolCallPayloadKind::CustomContentItems), - payload_text_preview: None, - payload_text_length: None, - payload_item_count: Some(1), - text_item_count: Some(0), - image_item_count: Some(1), - structured_content_present: None, - result_is_error: None, - } - ); - } - - #[test] - fn summarize_tool_call_error_marks_error_payload() { - let actual = JsReplManager::summarize_tool_call_error("tool failed"); - - assert_eq!( - actual, - JsReplToolCallResponseSummary { - response_type: None, - payload_kind: Some(JsReplToolCallPayloadKind::Error), - payload_text_preview: Some("tool failed".to_string()), - payload_text_length: Some("tool failed".len()), - payload_item_count: None, - text_item_count: None, - image_item_count: None, - structured_content_present: None, - result_is_error: None, - } - ); - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn reset_clears_inflight_exec_tool_calls_without_waiting() { - let manager = JsReplManager::new(None, Vec::new()) - .await - .expect("manager should initialize"); - let exec_id = Uuid::new_v4().to_string(); - manager.register_exec_tool_calls(&exec_id).await; - assert!( - JsReplManager::begin_exec_tool_call(&manager.exec_tool_calls, &exec_id) - .await - .is_some() - ); - - let wait_manager = Arc::clone(&manager); - let wait_exec_id = exec_id.clone(); - let waiter = tokio::spawn(async move { - wait_manager.wait_for_exec_tool_calls(&wait_exec_id).await; - }); - tokio::task::yield_now().await; - - tokio::time::timeout(Duration::from_secs(1), manager.reset()) - .await - .expect("reset should not hang") - .expect("reset should succeed"); - - tokio::time::timeout(Duration::from_secs(1), waiter) - .await - .expect("waiter should be released") - .expect("wait task should not panic"); - - assert!(manager.exec_tool_calls.lock().await.is_empty()); - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn reset_aborts_inflight_exec_tool_tasks() { - let manager = JsReplManager::new(None, Vec::new()) - .await - .expect("manager should initialize"); - let exec_id = Uuid::new_v4().to_string(); - manager.register_exec_tool_calls(&exec_id).await; - let reset_cancel = JsReplManager::begin_exec_tool_call(&manager.exec_tool_calls, &exec_id) - .await - .expect("exec should be registered"); - - let task = tokio::spawn(async move { - tokio::select! { - _ = reset_cancel.cancelled() => "cancelled", - _ = tokio::time::sleep(Duration::from_secs(60)) => "timed_out", - } - }); - - tokio::time::timeout(Duration::from_secs(1), manager.reset()) - .await - .expect("reset should not hang") - .expect("reset should succeed"); - - let outcome = tokio::time::timeout(Duration::from_secs(1), task) - .await - .expect("cancelled task should resolve promptly") - .expect("task should not panic"); - assert_eq!(outcome, "cancelled"); - } - - async fn can_run_js_repl_runtime_tests() -> bool { - // These white-box runtime tests are required on macOS. Linux relies on - // the codex-linux-sandbox arg0 dispatch path, which is exercised in - // integration tests instead. - cfg!(target_os = "macos") - } - fn write_js_repl_test_package_source( - base: &Path, - name: &str, - source: &str, - ) -> anyhow::Result<()> { - let pkg_dir = base.join("node_modules").join(name); - fs::create_dir_all(&pkg_dir)?; - fs::write( - pkg_dir.join("package.json"), - format!( - "{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"type\": \"module\",\n \"exports\": {{\n \"import\": \"./index.js\"\n }}\n}}\n" - ), - )?; - fs::write(pkg_dir.join("index.js"), source)?; - Ok(()) - } - - fn write_js_repl_test_package(base: &Path, name: &str, value: &str) -> anyhow::Result<()> { - write_js_repl_test_package_source( - base, - name, - &format!("export const value = \"{value}\";\n"), - )?; - Ok(()) - } - - fn write_js_repl_test_module( - base: &Path, - relative: &str, - contents: &str, - ) -> anyhow::Result<()> { - let module_path = base.join(relative); - if let Some(parent) = module_path.parent() { - fs::create_dir_all(parent)?; - } - fs::write(module_path, contents)?; - Ok(()) - } - - #[tokio::test] - async fn js_repl_timeout_does_not_deadlock() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let (session, turn) = make_session_and_context().await; - let session = Arc::new(session); - let turn = Arc::new(turn); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - - let result = tokio::time::timeout( - Duration::from_secs(3), - manager.execute( - session, - turn, - tracker, - JsReplArgs { - code: "while (true) {}".to_string(), - timeout_ms: Some(50), - }, - ), - ) - .await - .expect("execute should return, not deadlock") - .expect_err("expected timeout error"); - - assert_eq!( - result.to_string(), - "js_repl execution timed out; kernel reset, rerun your request" - ); - Ok(()) - } - - #[tokio::test] - async fn js_repl_timeout_kills_kernel_process() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let (session, turn) = make_session_and_context().await; - let session = Arc::new(session); - let turn = Arc::new(turn); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - - manager - .execute( - Arc::clone(&session), - Arc::clone(&turn), - Arc::clone(&tracker), - JsReplArgs { - code: "console.log('warmup');".to_string(), - timeout_ms: Some(10_000), - }, - ) - .await?; - - let child = { - let guard = manager.kernel.lock().await; - let state = guard.as_ref().expect("kernel should exist after warmup"); - Arc::clone(&state.child) - }; - - let result = manager - .execute( - session, - turn, - tracker, - JsReplArgs { - code: "while (true) {}".to_string(), - timeout_ms: Some(50), - }, - ) - .await - .expect_err("expected timeout error"); - - assert_eq!( - result.to_string(), - "js_repl execution timed out; kernel reset, rerun your request" - ); - - let exit_state = { - let mut child = child.lock().await; - child.try_wait()? - }; - assert!( - exit_state.is_some(), - "timed out js_repl execution should kill previous kernel process" - ); - Ok(()) - } - - #[tokio::test] - async fn js_repl_forced_kernel_exit_recovers_on_next_exec() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let (session, turn) = make_session_and_context().await; - let session = Arc::new(session); - let turn = Arc::new(turn); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - - manager - .execute( - Arc::clone(&session), - Arc::clone(&turn), - Arc::clone(&tracker), - JsReplArgs { - code: "console.log('warmup');".to_string(), - timeout_ms: Some(10_000), - }, - ) - .await?; - - let child = { - let guard = manager.kernel.lock().await; - let state = guard.as_ref().expect("kernel should exist after warmup"); - Arc::clone(&state.child) - }; - JsReplManager::kill_kernel_child(&child, "test_crash").await; - tokio::time::timeout(Duration::from_secs(1), async { - loop { - let cleared = { - let guard = manager.kernel.lock().await; - guard - .as_ref() - .is_none_or(|state| !Arc::ptr_eq(&state.child, &child)) - }; - if cleared { - return; - } - tokio::time::sleep(Duration::from_millis(10)).await; - } - }) - .await - .expect("host should clear dead kernel state promptly"); - - let result = manager - .execute( - session, - turn, - tracker, - JsReplArgs { - code: "console.log('after-kill');".to_string(), - timeout_ms: Some(10_000), - }, - ) - .await?; - assert!(result.output.contains("after-kill")); - Ok(()) - } - - #[tokio::test] - async fn js_repl_uncaught_exception_returns_exec_error_and_recovers() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let (session, turn) = crate::codex::make_session_and_context().await; - let session = Arc::new(session); - let turn = Arc::new(turn); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - - manager - .execute( - Arc::clone(&session), - Arc::clone(&turn), - Arc::clone(&tracker), - JsReplArgs { - code: "console.log('warmup');".to_string(), - timeout_ms: Some(10_000), - }, - ) - .await?; - - let child = { - let guard = manager.kernel.lock().await; - let state = guard.as_ref().expect("kernel should exist after warmup"); - Arc::clone(&state.child) - }; - - let err = tokio::time::timeout( - Duration::from_secs(3), - manager.execute( - Arc::clone(&session), - Arc::clone(&turn), - Arc::clone(&tracker), - JsReplArgs { - code: "setTimeout(() => { throw new Error('boom'); }, 0);\nawait new Promise(() => {});".to_string(), - timeout_ms: Some(10_000), - }, - ), - ) - .await - .expect("uncaught exception should fail promptly") - .expect_err("expected uncaught exception to fail the exec"); - - let message = err.to_string(); - assert!(message.contains("js_repl kernel uncaught exception: boom")); - assert!(message.contains("kernel reset.")); - assert!(message.contains("Catch or handle async errors")); - assert!(!message.contains("js_repl kernel exited unexpectedly")); - - tokio::time::timeout(Duration::from_secs(1), async { - loop { - let exited = { - let mut child = child.lock().await; - child.try_wait()?.is_some() - }; - if exited { - return Ok::<(), anyhow::Error>(()); - } - tokio::time::sleep(Duration::from_millis(10)).await; - } - }) - .await - .expect("uncaught exception should terminate the previous kernel process")?; - - tokio::time::timeout(Duration::from_secs(1), async { - loop { - let cleared = { - let guard = manager.kernel.lock().await; - guard - .as_ref() - .is_none_or(|state| !Arc::ptr_eq(&state.child, &child)) - }; - if cleared { - return; - } - tokio::time::sleep(Duration::from_millis(10)).await; - } - }) - .await - .expect("host should clear dead kernel state promptly"); - - let next = manager - .execute( - session, - turn, - tracker, - JsReplArgs { - code: "console.log('after reset');".to_string(), - timeout_ms: Some(10_000), - }, - ) - .await?; - assert!(next.output.contains("after reset")); - Ok(()) - } - - #[tokio::test] - async fn js_repl_waits_for_unawaited_tool_calls_before_completion() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let (session, mut turn) = make_session_and_context().await; - turn.approval_policy - .set(AskForApproval::Never) - .expect("test setup should allow updating approval policy"); - set_danger_full_access(&mut turn); - - let session = Arc::new(session); - let turn = Arc::new(turn); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - - let marker = turn - .cwd - .join(format!("js-repl-unawaited-marker-{}.txt", Uuid::new_v4())); - let marker_json = serde_json::to_string(&marker.to_string_lossy().to_string())?; - let result = manager - .execute( - session, - turn, - tracker, - JsReplArgs { - code: format!( - r#" -const marker = {marker_json}; -void codex.tool("shell_command", {{ command: `sleep 0.35; printf js_repl_unawaited_done > "${{marker}}"` }}); -console.log("cell-complete"); -"# - ), - timeout_ms: Some(10_000), - }, - ) - .await?; - assert!(result.output.contains("cell-complete")); - let marker_contents = tokio::fs::read_to_string(&marker).await?; - assert_eq!(marker_contents, "js_repl_unawaited_done"); - let _ = tokio::fs::remove_file(&marker).await; - Ok(()) - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn js_repl_does_not_auto_attach_image_via_view_image_tool() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let (session, mut turn) = make_session_and_context().await; - if !turn - .model_info - .input_modalities - .contains(&InputModality::Image) - { - return Ok(()); - } - turn.approval_policy - .set(AskForApproval::Never) - .expect("test setup should allow updating approval policy"); - set_danger_full_access(&mut turn); - - let session = Arc::new(session); - let turn = Arc::new(turn); - *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); - - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - let code = r#" -const fs = await import("node:fs/promises"); -const path = await import("node:path"); -const imagePath = path.join(codex.tmpDir, "js-repl-view-image.png"); -const png = Buffer.from( - "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==", - "base64" -); -await fs.writeFile(imagePath, png); -const out = await codex.tool("view_image", { path: imagePath }); -console.log(out.type); -"#; - - let result = manager - .execute( - Arc::clone(&session), - turn, - tracker, - JsReplArgs { - code: code.to_string(), - timeout_ms: Some(15_000), - }, - ) - .await?; - assert!(result.output.contains("function_call_output")); - assert!(result.content_items.is_empty()); - assert!(session.get_pending_input().await.is_empty()); - - Ok(()) - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn js_repl_can_emit_image_via_view_image_tool() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let (session, mut turn) = make_session_and_context().await; - if !turn - .model_info - .input_modalities - .contains(&InputModality::Image) - { - return Ok(()); - } - turn.approval_policy - .set(AskForApproval::Never) - .expect("test setup should allow updating approval policy"); - set_danger_full_access(&mut turn); - - let session = Arc::new(session); - let turn = Arc::new(turn); - *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); - - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - let code = r#" -const fs = await import("node:fs/promises"); -const path = await import("node:path"); -const imagePath = path.join(codex.tmpDir, "js-repl-view-image-explicit.png"); -const png = Buffer.from( - "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==", - "base64" -); -await fs.writeFile(imagePath, png); -const out = await codex.tool("view_image", { path: imagePath }); -await codex.emitImage(out); -console.log(out.type); -"#; - - let result = manager - .execute( - Arc::clone(&session), - turn, - tracker, - JsReplArgs { - code: code.to_string(), - timeout_ms: Some(15_000), - }, - ) - .await?; - assert!(result.output.contains("function_call_output")); - assert_eq!( - result.content_items.as_slice(), - [FunctionCallOutputContentItem::InputImage { - image_url: - "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==" - .to_string(), - detail: None, - }] - .as_slice() - ); - assert!(session.get_pending_input().await.is_empty()); - - Ok(()) - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn js_repl_can_emit_image_from_bytes_and_mime_type() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let (session, turn) = make_session_and_context().await; - if !turn - .model_info - .input_modalities - .contains(&InputModality::Image) - { - return Ok(()); - } - - let session = Arc::new(session); - let turn = Arc::new(turn); - *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); - - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - let code = r#" -const png = Buffer.from( - "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==", - "base64" -); -await codex.emitImage({ bytes: png, mimeType: "image/png" }); -"#; - - let result = manager - .execute( - Arc::clone(&session), - turn, - tracker, - JsReplArgs { - code: code.to_string(), - timeout_ms: Some(15_000), - }, - ) - .await?; - assert_eq!( - result.content_items.as_slice(), - [FunctionCallOutputContentItem::InputImage { - image_url: - "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==" - .to_string(), - detail: None, - }] - .as_slice() - ); - assert!(session.get_pending_input().await.is_empty()); - - Ok(()) - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn js_repl_can_emit_multiple_images_in_one_cell() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let (session, turn) = make_session_and_context().await; - if !turn - .model_info - .input_modalities - .contains(&InputModality::Image) - { - return Ok(()); - } - - let session = Arc::new(session); - let turn = Arc::new(turn); - *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); - - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - let code = r#" -await codex.emitImage( - "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==" -); -await codex.emitImage( - "data:image/gif;base64,R0lGODdhAQABAIAAAP///////ywAAAAAAQABAAACAkQBADs=" -); -"#; - - let result = manager - .execute( - Arc::clone(&session), - turn, - tracker, - JsReplArgs { - code: code.to_string(), - timeout_ms: Some(15_000), - }, - ) - .await?; - assert_eq!( - result.content_items.as_slice(), - [ - FunctionCallOutputContentItem::InputImage { - image_url: - "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==" - .to_string(), - detail: None, - }, - FunctionCallOutputContentItem::InputImage { - image_url: - "data:image/gif;base64,R0lGODdhAQABAIAAAP///////ywAAAAAAQABAAACAkQBADs=" - .to_string(), - detail: None, - }, - ] - .as_slice() - ); - assert!(session.get_pending_input().await.is_empty()); - - Ok(()) - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn js_repl_waits_for_unawaited_emit_image_before_completion() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let (session, turn) = make_session_and_context().await; - if !turn - .model_info - .input_modalities - .contains(&InputModality::Image) - { - return Ok(()); - } - - let session = Arc::new(session); - let turn = Arc::new(turn); - *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); - - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - let code = r#" -void codex.emitImage( - "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==" -); -console.log("cell-complete"); -"#; - - let result = manager - .execute( - Arc::clone(&session), - turn, - tracker, - JsReplArgs { - code: code.to_string(), - timeout_ms: Some(15_000), - }, - ) - .await?; - assert!(result.output.contains("cell-complete")); - assert_eq!( - result.content_items.as_slice(), - [FunctionCallOutputContentItem::InputImage { - image_url: - "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==" - .to_string(), - detail: None, - }] - .as_slice() - ); - assert!(session.get_pending_input().await.is_empty()); - - Ok(()) - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn js_repl_unawaited_emit_image_errors_fail_cell() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let (session, turn) = make_session_and_context().await; - if !turn - .model_info - .input_modalities - .contains(&InputModality::Image) - { - return Ok(()); - } - - let session = Arc::new(session); - let turn = Arc::new(turn); - *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); - - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - let code = r#" -void codex.emitImage({ bytes: new Uint8Array(), mimeType: "image/png" }); -console.log("cell-complete"); -"#; - - let err = manager - .execute( - Arc::clone(&session), - turn, - tracker, - JsReplArgs { - code: code.to_string(), - timeout_ms: Some(15_000), - }, - ) - .await - .expect_err("unawaited invalid emitImage should fail"); - assert!(err.to_string().contains("expected non-empty bytes")); - assert!(session.get_pending_input().await.is_empty()); - - Ok(()) - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn js_repl_caught_emit_image_error_does_not_fail_cell() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let (session, turn) = make_session_and_context().await; - if !turn - .model_info - .input_modalities - .contains(&InputModality::Image) - { - return Ok(()); - } - - let session = Arc::new(session); - let turn = Arc::new(turn); - *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); - - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - let code = r#" -try { - await codex.emitImage({ bytes: new Uint8Array(), mimeType: "image/png" }); -} catch (error) { - console.log(error.message); -} -console.log("cell-complete"); -"#; - - let result = manager - .execute( - Arc::clone(&session), - turn, - tracker, - JsReplArgs { - code: code.to_string(), - timeout_ms: Some(15_000), - }, - ) - .await?; - assert!(result.output.contains("expected non-empty bytes")); - assert!(result.output.contains("cell-complete")); - assert!(result.content_items.is_empty()); - assert!(session.get_pending_input().await.is_empty()); - - Ok(()) - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn js_repl_emit_image_requires_explicit_mime_type_for_bytes() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let (session, turn) = make_session_and_context().await; - if !turn - .model_info - .input_modalities - .contains(&InputModality::Image) - { - return Ok(()); - } - - let session = Arc::new(session); - let turn = Arc::new(turn); - *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); - - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - let code = r#" -const png = Buffer.from( - "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==", - "base64" -); -await codex.emitImage({ bytes: png }); -"#; - - let err = manager - .execute( - Arc::clone(&session), - turn, - tracker, - JsReplArgs { - code: code.to_string(), - timeout_ms: Some(15_000), - }, - ) - .await - .expect_err("missing mimeType should fail"); - assert!(err.to_string().contains("expected a non-empty mimeType")); - assert!(session.get_pending_input().await.is_empty()); - - Ok(()) - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn js_repl_emit_image_rejects_non_data_url() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let (session, turn) = make_session_and_context().await; - if !turn - .model_info - .input_modalities - .contains(&InputModality::Image) - { - return Ok(()); - } - - let session = Arc::new(session); - let turn = Arc::new(turn); - *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); - - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - let code = r#" -await codex.emitImage("https://example.com/image.png"); -"#; - - let err = manager - .execute( - Arc::clone(&session), - turn, - tracker, - JsReplArgs { - code: code.to_string(), - timeout_ms: Some(15_000), - }, - ) - .await - .expect_err("non-data URLs should fail"); - assert!(err.to_string().contains("only accepts data URLs")); - assert!(session.get_pending_input().await.is_empty()); - - Ok(()) - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn js_repl_emit_image_accepts_case_insensitive_data_url() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let (session, turn) = make_session_and_context().await; - if !turn - .model_info - .input_modalities - .contains(&InputModality::Image) - { - return Ok(()); - } - - let session = Arc::new(session); - let turn = Arc::new(turn); - *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); - - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - let code = r#" -await codex.emitImage("DATA:image/png;base64,AAA"); -"#; - - let result = manager - .execute( - Arc::clone(&session), - turn, - tracker, - JsReplArgs { - code: code.to_string(), - timeout_ms: Some(15_000), - }, - ) - .await?; - assert_eq!( - result.content_items.as_slice(), - [FunctionCallOutputContentItem::InputImage { - image_url: "DATA:image/png;base64,AAA".to_string(), - detail: None, - }] - .as_slice() - ); - assert!(session.get_pending_input().await.is_empty()); - - Ok(()) - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn js_repl_emit_image_rejects_invalid_detail() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let (session, turn) = make_session_and_context().await; - if !turn - .model_info - .input_modalities - .contains(&InputModality::Image) - { - return Ok(()); - } - - let session = Arc::new(session); - let turn = Arc::new(turn); - *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); - - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - let code = r#" -const png = Buffer.from( - "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==", - "base64" -); -await codex.emitImage({ bytes: png, mimeType: "image/png", detail: "ultra" }); -"#; - - let err = manager - .execute( - Arc::clone(&session), - turn, - tracker, - JsReplArgs { - code: code.to_string(), - timeout_ms: Some(15_000), - }, - ) - .await - .expect_err("invalid detail should fail"); - assert!( - err.to_string() - .contains("only supports detail \"original\"") - ); - assert!(session.get_pending_input().await.is_empty()); - - Ok(()) - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn js_repl_emit_image_treats_null_detail_as_omitted() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let (session, turn) = make_session_and_context().await; - if !turn - .model_info - .input_modalities - .contains(&InputModality::Image) - { - return Ok(()); - } - - let session = Arc::new(session); - let turn = Arc::new(turn); - *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); - - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - let code = r#" -const png = Buffer.from( - "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==", - "base64" -); -await codex.emitImage({ bytes: png, mimeType: "image/png", detail: null }); -"#; - - let result = manager - .execute( - Arc::clone(&session), - turn, - tracker, - JsReplArgs { - code: code.to_string(), - timeout_ms: Some(15_000), - }, - ) - .await?; - assert_eq!( - result.content_items.as_slice(), - [FunctionCallOutputContentItem::InputImage { - image_url: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==".to_string(), - detail: None, - }] - .as_slice() - ); - assert!(session.get_pending_input().await.is_empty()); - - Ok(()) - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn js_repl_emit_image_rejects_mixed_content() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let (session, turn, rx_event) = - make_session_and_context_with_dynamic_tools_and_rx(vec![DynamicToolSpec { - name: "inline_image".to_string(), - description: "Returns inline text and image content.".to_string(), - input_schema: serde_json::json!({ - "type": "object", - "properties": {}, - "additionalProperties": false - }), - }]) - .await; - if !turn - .model_info - .input_modalities - .contains(&InputModality::Image) - { - return Ok(()); - } - - *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); - - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - let code = r#" -const out = await codex.tool("inline_image", {}); -await codex.emitImage(out); -"#; - let image_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg=="; - - let session_for_response = Arc::clone(&session); - let response_watcher = async move { - loop { - let event = tokio::time::timeout(Duration::from_secs(2), rx_event.recv()).await??; - if let EventMsg::DynamicToolCallRequest(request) = event.msg { - session_for_response - .notify_dynamic_tool_response( - &request.call_id, - DynamicToolResponse { - content_items: vec![ - DynamicToolCallOutputContentItem::InputText { - text: "inline image note".to_string(), - }, - DynamicToolCallOutputContentItem::InputImage { - image_url: image_url.to_string(), - }, - ], - success: true, - }, - ) - .await; - return Ok::<(), anyhow::Error>(()); - } - } - }; - - let (result, response_watcher_result) = tokio::join!( - manager.execute( - Arc::clone(&session), - Arc::clone(&turn), - tracker, - JsReplArgs { - code: code.to_string(), - timeout_ms: Some(15_000), - }, - ), - response_watcher, - ); - response_watcher_result?; - let err = result.expect_err("mixed content should fail"); - assert!( - err.to_string() - .contains("does not accept mixed text and image content") - ); - assert!(session.get_pending_input().await.is_empty()); - - Ok(()) - } - #[tokio::test] - async fn js_repl_prefers_env_node_module_dirs_over_config() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let env_base = tempdir()?; - write_js_repl_test_package(env_base.path(), "repl_probe", "env")?; - - let config_base = tempdir()?; - let cwd_dir = tempdir()?; - - let (session, mut turn) = make_session_and_context().await; - turn.shell_environment_policy.r#set.insert( - "CODEX_JS_REPL_NODE_MODULE_DIRS".to_string(), - env_base.path().to_string_lossy().to_string(), - ); - turn.cwd = cwd_dir.path().to_path_buf(); - turn.js_repl = Arc::new(JsReplHandle::with_node_path( - turn.config.js_repl_node_path.clone(), - vec![config_base.path().to_path_buf()], - )); - - let session = Arc::new(session); - let turn = Arc::new(turn); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - - let result = manager - .execute( - session, - turn, - tracker, - JsReplArgs { - code: "const mod = await import(\"repl_probe\"); console.log(mod.value);" - .to_string(), - timeout_ms: Some(10_000), - }, - ) - .await?; - assert!(result.output.contains("env")); - Ok(()) - } - - #[tokio::test] - async fn js_repl_resolves_from_first_config_dir() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let first_base = tempdir()?; - let second_base = tempdir()?; - write_js_repl_test_package(first_base.path(), "repl_probe", "first")?; - write_js_repl_test_package(second_base.path(), "repl_probe", "second")?; - - let cwd_dir = tempdir()?; - - let (session, mut turn) = make_session_and_context().await; - turn.shell_environment_policy - .r#set - .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); - turn.cwd = cwd_dir.path().to_path_buf(); - turn.js_repl = Arc::new(JsReplHandle::with_node_path( - turn.config.js_repl_node_path.clone(), - vec![ - first_base.path().to_path_buf(), - second_base.path().to_path_buf(), - ], - )); - - let session = Arc::new(session); - let turn = Arc::new(turn); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - - let result = manager - .execute( - session, - turn, - tracker, - JsReplArgs { - code: "const mod = await import(\"repl_probe\"); console.log(mod.value);" - .to_string(), - timeout_ms: Some(10_000), - }, - ) - .await?; - assert!(result.output.contains("first")); - Ok(()) - } - - #[tokio::test] - async fn js_repl_falls_back_to_cwd_node_modules() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let config_base = tempdir()?; - let cwd_dir = tempdir()?; - write_js_repl_test_package(cwd_dir.path(), "repl_probe", "cwd")?; - - let (session, mut turn) = make_session_and_context().await; - turn.shell_environment_policy - .r#set - .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); - turn.cwd = cwd_dir.path().to_path_buf(); - turn.js_repl = Arc::new(JsReplHandle::with_node_path( - turn.config.js_repl_node_path.clone(), - vec![config_base.path().to_path_buf()], - )); - - let session = Arc::new(session); - let turn = Arc::new(turn); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - - let result = manager - .execute( - session, - turn, - tracker, - JsReplArgs { - code: "const mod = await import(\"repl_probe\"); console.log(mod.value);" - .to_string(), - timeout_ms: Some(10_000), - }, - ) - .await?; - assert!(result.output.contains("cwd")); - Ok(()) - } - - #[tokio::test] - async fn js_repl_accepts_node_modules_dir_entries() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let base_dir = tempdir()?; - let cwd_dir = tempdir()?; - write_js_repl_test_package(base_dir.path(), "repl_probe", "normalized")?; - - let (session, mut turn) = make_session_and_context().await; - turn.shell_environment_policy - .r#set - .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); - turn.cwd = cwd_dir.path().to_path_buf(); - turn.js_repl = Arc::new(JsReplHandle::with_node_path( - turn.config.js_repl_node_path.clone(), - vec![base_dir.path().join("node_modules")], - )); - - let session = Arc::new(session); - let turn = Arc::new(turn); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - - let result = manager - .execute( - session, - turn, - tracker, - JsReplArgs { - code: "const mod = await import(\"repl_probe\"); console.log(mod.value);" - .to_string(), - timeout_ms: Some(10_000), - }, - ) - .await?; - assert!(result.output.contains("normalized")); - Ok(()) - } - - #[tokio::test] - async fn js_repl_supports_relative_file_imports() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let cwd_dir = tempdir()?; - write_js_repl_test_module( - cwd_dir.path(), - "child.js", - "export const value = \"child\";\n", - )?; - write_js_repl_test_module( - cwd_dir.path(), - "parent.js", - "import { value as childValue } from \"./child.js\";\nexport const value = `${childValue}-parent`;\n", - )?; - write_js_repl_test_module( - cwd_dir.path(), - "local.mjs", - "export const value = \"mjs\";\n", - )?; - - let (session, mut turn) = make_session_and_context().await; - turn.shell_environment_policy - .r#set - .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); - turn.cwd = cwd_dir.path().to_path_buf(); - turn.js_repl = Arc::new(JsReplHandle::with_node_path( - turn.config.js_repl_node_path.clone(), - Vec::new(), - )); - - let session = Arc::new(session); - let turn = Arc::new(turn); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - - let result = manager - .execute( - session, - turn, - tracker, - JsReplArgs { - code: "const parent = await import(\"./parent.js\"); const other = await import(\"./local.mjs\"); console.log(parent.value); console.log(other.value);".to_string(), - timeout_ms: Some(10_000), - }, - ) - .await?; - assert!(result.output.contains("child-parent")); - assert!(result.output.contains("mjs")); - Ok(()) - } - - #[tokio::test] - async fn js_repl_supports_absolute_file_imports() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let module_dir = tempdir()?; - let cwd_dir = tempdir()?; - write_js_repl_test_module( - module_dir.path(), - "absolute.js", - "export const value = \"absolute\";\n", - )?; - let absolute_path_json = - serde_json::to_string(&module_dir.path().join("absolute.js").display().to_string())?; - - let (session, mut turn) = make_session_and_context().await; - turn.shell_environment_policy - .r#set - .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); - turn.cwd = cwd_dir.path().to_path_buf(); - turn.js_repl = Arc::new(JsReplHandle::with_node_path( - turn.config.js_repl_node_path.clone(), - Vec::new(), - )); - - let session = Arc::new(session); - let turn = Arc::new(turn); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - - let result = manager - .execute( - session, - turn, - tracker, - JsReplArgs { - code: format!( - "const mod = await import({absolute_path_json}); console.log(mod.value);" - ), - timeout_ms: Some(10_000), - }, - ) - .await?; - assert!(result.output.contains("absolute")); - Ok(()) - } - - #[tokio::test] - async fn js_repl_imported_local_files_can_access_repl_globals() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let cwd_dir = tempdir()?; - let expected_home_dir = serde_json::to_string("/tmp/codex-home")?; - write_js_repl_test_module( - cwd_dir.path(), - "globals.js", - &format!( - "const expectedHomeDir = {expected_home_dir};\nconsole.log(`tmp:${{codex.tmpDir === tmpDir}}`);\nconsole.log(`cwd:${{typeof codex.cwd}}:${{codex.cwd.length > 0}}`);\nconsole.log(`home:${{codex.homeDir === expectedHomeDir}}`);\nconsole.log(`tool:${{typeof codex.tool}}`);\nconsole.log(\"local-file-console-ok\");\n" - ), - )?; - - let (session, mut turn) = make_session_and_context().await; - session - .set_dependency_env(HashMap::from([( - "HOME".to_string(), - "/tmp/codex-home".to_string(), - )])) - .await; - turn.shell_environment_policy - .r#set - .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); - turn.cwd = cwd_dir.path().to_path_buf(); - turn.js_repl = Arc::new(JsReplHandle::with_node_path( - turn.config.js_repl_node_path.clone(), - Vec::new(), - )); - - let session = Arc::new(session); - let turn = Arc::new(turn); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - - let result = manager - .execute( - session, - turn, - tracker, - JsReplArgs { - code: "await import(\"./globals.js\");".to_string(), - timeout_ms: Some(10_000), - }, - ) - .await?; - assert!(result.output.contains("tmp:true")); - assert!(result.output.contains("cwd:string:true")); - assert!(result.output.contains("home:true")); - assert!(result.output.contains("tool:function")); - assert!(result.output.contains("local-file-console-ok")); - Ok(()) - } - - #[tokio::test] - async fn js_repl_reimports_local_files_after_edit() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let cwd_dir = tempdir()?; - let helper_path = cwd_dir.path().join("helper.js"); - fs::write(&helper_path, "export const value = \"v1\";\n")?; - - let (session, mut turn) = make_session_and_context().await; - turn.shell_environment_policy - .r#set - .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); - turn.cwd = cwd_dir.path().to_path_buf(); - turn.js_repl = Arc::new(JsReplHandle::with_node_path( - turn.config.js_repl_node_path.clone(), - Vec::new(), - )); - - let session = Arc::new(session); - let turn = Arc::new(turn); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - - let first = manager - .execute( - Arc::clone(&session), - Arc::clone(&turn), - Arc::clone(&tracker), - JsReplArgs { - code: "const { value: firstValue } = await import(\"./helper.js\");\nconsole.log(firstValue);".to_string(), - timeout_ms: Some(10_000), - }, - ) - .await?; - assert!(first.output.contains("v1")); - - fs::write(&helper_path, "export const value = \"v2\";\n")?; - - let second = manager - .execute( - session, - turn, - tracker, - JsReplArgs { - code: "console.log(firstValue);\nconst { value: secondValue } = await import(\"./helper.js\");\nconsole.log(secondValue);".to_string(), - timeout_ms: Some(10_000), - }, - ) - .await?; - assert!(second.output.contains("v1")); - assert!(second.output.contains("v2")); - Ok(()) - } - - #[tokio::test] - async fn js_repl_reimports_local_files_after_fixing_failure() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let cwd_dir = tempdir()?; - let helper_path = cwd_dir.path().join("broken.js"); - fs::write(&helper_path, "throw new Error(\"boom\");\n")?; - - let (session, mut turn) = make_session_and_context().await; - turn.shell_environment_policy - .r#set - .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); - turn.cwd = cwd_dir.path().to_path_buf(); - turn.js_repl = Arc::new(JsReplHandle::with_node_path( - turn.config.js_repl_node_path.clone(), - Vec::new(), - )); - - let session = Arc::new(session); - let turn = Arc::new(turn); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - - let err = manager - .execute( - Arc::clone(&session), - Arc::clone(&turn), - Arc::clone(&tracker), - JsReplArgs { - code: "await import(\"./broken.js\");".to_string(), - timeout_ms: Some(10_000), - }, - ) - .await - .expect_err("expected broken module import to fail"); - assert!(err.to_string().contains("boom")); - - fs::write(&helper_path, "export const value = \"fixed\";\n")?; - - let result = manager - .execute( - session, - turn, - tracker, - JsReplArgs { - code: "console.log((await import(\"./broken.js\")).value);".to_string(), - timeout_ms: Some(10_000), - }, - ) - .await?; - assert!(result.output.contains("fixed")); - Ok(()) - } - - #[tokio::test] - async fn js_repl_local_files_expose_node_like_import_meta() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let cwd_dir = tempdir()?; - let pkg_dir = cwd_dir.path().join("node_modules").join("repl_meta_pkg"); - fs::create_dir_all(&pkg_dir)?; - fs::write( - pkg_dir.join("package.json"), - "{\n \"name\": \"repl_meta_pkg\",\n \"version\": \"1.0.0\",\n \"type\": \"module\",\n \"exports\": {\n \"import\": \"./index.js\"\n }\n}\n", - )?; - fs::write( - pkg_dir.join("index.js"), - "import { sep } from \"node:path\";\nexport const value = `pkg:${typeof sep}`;\n", - )?; - write_js_repl_test_module( - cwd_dir.path(), - "child.js", - "export const value = \"child-export\";\n", - )?; - write_js_repl_test_module( - cwd_dir.path(), - "meta.js", - "console.log(import.meta.url);\nconsole.log(import.meta.filename);\nconsole.log(import.meta.dirname);\nconsole.log(import.meta.main);\nconsole.log(import.meta.resolve(\"./child.js\"));\nconsole.log(import.meta.resolve(\"repl_meta_pkg\"));\nconsole.log(import.meta.resolve(\"node:fs\"));\nconsole.log((await import(import.meta.resolve(\"./child.js\"))).value);\nconsole.log((await import(import.meta.resolve(\"repl_meta_pkg\"))).value);\n", - )?; - let child_path = fs::canonicalize(cwd_dir.path().join("child.js"))?; - let child_url = url::Url::from_file_path(&child_path) - .expect("child path should convert to file URL") - .to_string(); - - let (session, mut turn) = make_session_and_context().await; - turn.shell_environment_policy - .r#set - .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); - turn.cwd = cwd_dir.path().to_path_buf(); - turn.js_repl = Arc::new(JsReplHandle::with_node_path( - turn.config.js_repl_node_path.clone(), - Vec::new(), - )); - - let session = Arc::new(session); - let turn = Arc::new(turn); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - - let result = manager - .execute( - session, - turn, - tracker, - JsReplArgs { - code: "await import(\"./meta.js\");".to_string(), - timeout_ms: Some(10_000), - }, - ) - .await?; - let cwd_display = cwd_dir.path().display().to_string(); - let meta_path_display = cwd_dir.path().join("meta.js").display().to_string(); - assert!(result.output.contains("file://")); - assert!(result.output.contains(&meta_path_display)); - assert!(result.output.contains(&cwd_display)); - assert!(result.output.contains("false")); - assert!(result.output.contains(&child_url)); - assert!(result.output.contains("repl_meta_pkg")); - assert!(result.output.contains("node:fs")); - assert!(result.output.contains("child-export")); - assert!(result.output.contains("pkg:string")); - Ok(()) - } - - #[tokio::test] - async fn js_repl_rejects_top_level_static_imports_with_clear_error() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let (session, turn) = make_session_and_context().await; - let session = Arc::new(session); - let turn = Arc::new(turn); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - - let err = manager - .execute( - session, - turn, - tracker, - JsReplArgs { - code: "import \"./local.js\";".to_string(), - timeout_ms: Some(10_000), - }, - ) - .await - .expect_err("expected top-level static import to be rejected"); - assert!( - err.to_string() - .contains("Top-level static import \"./local.js\" is not supported in js_repl") - ); - Ok(()) - } - - #[tokio::test] - async fn js_repl_local_files_reject_static_bare_imports() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let cwd_dir = tempdir()?; - write_js_repl_test_package(cwd_dir.path(), "repl_counter", "pkg")?; - write_js_repl_test_module( - cwd_dir.path(), - "entry.js", - "import { value } from \"repl_counter\";\nconsole.log(value);\n", - )?; - - let (session, mut turn) = make_session_and_context().await; - turn.shell_environment_policy - .r#set - .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); - turn.cwd = cwd_dir.path().to_path_buf(); - turn.js_repl = Arc::new(JsReplHandle::with_node_path( - turn.config.js_repl_node_path.clone(), - Vec::new(), - )); - - let session = Arc::new(session); - let turn = Arc::new(turn); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - - let err = manager - .execute( - session, - turn, - tracker, - JsReplArgs { - code: "await import(\"./entry.js\");".to_string(), - timeout_ms: Some(10_000), - }, - ) - .await - .expect_err("expected static bare import to be rejected"); - assert!( - err.to_string().contains( - "Static import \"repl_counter\" is not supported from js_repl local files" - ) - ); - Ok(()) - } - - #[tokio::test] - async fn js_repl_rejects_unsupported_file_specifiers() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let cwd_dir = tempdir()?; - write_js_repl_test_module(cwd_dir.path(), "local.ts", "export const value = \"ts\";\n")?; - write_js_repl_test_module(cwd_dir.path(), "local", "export const value = \"noext\";\n")?; - fs::create_dir_all(cwd_dir.path().join("dir"))?; - - let (session, mut turn) = make_session_and_context().await; - turn.shell_environment_policy - .r#set - .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); - turn.cwd = cwd_dir.path().to_path_buf(); - turn.js_repl = Arc::new(JsReplHandle::with_node_path( - turn.config.js_repl_node_path.clone(), - Vec::new(), - )); - - let session = Arc::new(session); - let turn = Arc::new(turn); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - - let unsupported_extension = manager - .execute( - Arc::clone(&session), - Arc::clone(&turn), - Arc::clone(&tracker), - JsReplArgs { - code: "await import(\"./local.ts\");".to_string(), - timeout_ms: Some(10_000), - }, - ) - .await - .expect_err("expected unsupported extension to be rejected"); - assert!( - unsupported_extension - .to_string() - .contains("Only .js and .mjs files are supported") - ); - - let extensionless = manager - .execute( - Arc::clone(&session), - Arc::clone(&turn), - Arc::clone(&tracker), - JsReplArgs { - code: "await import(\"./local\");".to_string(), - timeout_ms: Some(10_000), - }, - ) - .await - .expect_err("expected extensionless import to be rejected"); - assert!( - extensionless - .to_string() - .contains("Only .js and .mjs files are supported") - ); - - let directory = manager - .execute( - Arc::clone(&session), - Arc::clone(&turn), - Arc::clone(&tracker), - JsReplArgs { - code: "await import(\"./dir\");".to_string(), - timeout_ms: Some(10_000), - }, - ) - .await - .expect_err("expected directory import to be rejected"); - assert!( - directory - .to_string() - .contains("Directory imports are not supported") - ); - - let unsupported_url = manager - .execute( - session, - turn, - tracker, - JsReplArgs { - code: "await import(\"https://example.com/test.js\");".to_string(), - timeout_ms: Some(10_000), - }, - ) - .await - .expect_err("expected unsupported url import to be rejected"); - assert!( - unsupported_url - .to_string() - .contains("Unsupported import specifier") - ); - Ok(()) - } - - #[tokio::test] - async fn js_repl_blocks_sensitive_builtin_imports_from_local_files() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let cwd_dir = tempdir()?; - write_js_repl_test_module( - cwd_dir.path(), - "blocked.js", - "import process from \"node:process\";\nconsole.log(process.pid);\n", - )?; - - let (session, mut turn) = make_session_and_context().await; - turn.shell_environment_policy - .r#set - .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); - turn.cwd = cwd_dir.path().to_path_buf(); - turn.js_repl = Arc::new(JsReplHandle::with_node_path( - turn.config.js_repl_node_path.clone(), - Vec::new(), - )); - - let session = Arc::new(session); - let turn = Arc::new(turn); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - - let err = manager - .execute( - session, - turn, - tracker, - JsReplArgs { - code: "await import(\"./blocked.js\");".to_string(), - timeout_ms: Some(10_000), - }, - ) - .await - .expect_err("expected blocked builtin import to be rejected"); - assert!( - err.to_string() - .contains("Importing module \"node:process\" is not allowed in js_repl") - ); - Ok(()) - } - - #[tokio::test] - async fn js_repl_local_files_do_not_escape_node_module_search_roots() -> anyhow::Result<()> { - if !can_run_js_repl_runtime_tests().await { - return Ok(()); - } - - let parent_dir = tempdir()?; - write_js_repl_test_package(parent_dir.path(), "repl_probe", "parent")?; - let cwd_dir = parent_dir.path().join("workspace"); - fs::create_dir_all(&cwd_dir)?; - write_js_repl_test_module( - &cwd_dir, - "entry.js", - "const { value } = await import(\"repl_probe\");\nconsole.log(value);\n", - )?; - - let (session, mut turn) = make_session_and_context().await; - turn.shell_environment_policy - .r#set - .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); - turn.cwd = cwd_dir.clone(); - turn.js_repl = Arc::new(JsReplHandle::with_node_path( - turn.config.js_repl_node_path.clone(), - Vec::new(), - )); - - let session = Arc::new(session); - let turn = Arc::new(turn); - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); - let manager = turn.js_repl.manager().await?; - - let err = manager - .execute( - session, - turn, - tracker, - JsReplArgs { - code: "await import(\"./entry.js\");".to_string(), - timeout_ms: Some(10_000), - }, - ) - .await - .expect_err("expected parent node_modules lookup to be rejected"); - assert!(err.to_string().contains("repl_probe")); - Ok(()) - } -} +#[path = "mod_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/js_repl/mod_tests.rs b/codex-rs/core/src/tools/js_repl/mod_tests.rs new file mode 100644 index 0000000000..2ea0e67f65 --- /dev/null +++ b/codex-rs/core/src/tools/js_repl/mod_tests.rs @@ -0,0 +1,2321 @@ +use super::*; +use crate::codex::make_session_and_context; +use crate::codex::make_session_and_context_with_dynamic_tools_and_rx; +use crate::features::Feature; +use crate::protocol::AskForApproval; +use crate::protocol::EventMsg; +use crate::protocol::SandboxPolicy; +use crate::turn_diff_tracker::TurnDiffTracker; +use codex_protocol::dynamic_tools::DynamicToolCallOutputContentItem; +use codex_protocol::dynamic_tools::DynamicToolResponse; +use codex_protocol::dynamic_tools::DynamicToolSpec; +use codex_protocol::models::FunctionCallOutputContentItem; +use codex_protocol::models::FunctionCallOutputPayload; +use codex_protocol::models::ImageDetail; +use codex_protocol::models::ResponseInputItem; +use codex_protocol::openai_models::InputModality; +use pretty_assertions::assert_eq; +use std::fs; +use std::path::Path; +use tempfile::tempdir; + +fn set_danger_full_access(turn: &mut crate::codex::TurnContext) { + turn.sandbox_policy + .set(SandboxPolicy::DangerFullAccess) + .expect("test setup should allow updating sandbox policy"); + turn.file_system_sandbox_policy = + crate::protocol::FileSystemSandboxPolicy::from(turn.sandbox_policy.get()); + turn.network_sandbox_policy = + crate::protocol::NetworkSandboxPolicy::from(turn.sandbox_policy.get()); +} + +#[test] +fn node_version_parses_v_prefix_and_suffix() { + let version = NodeVersion::parse("v25.1.0-nightly.2024").unwrap(); + assert_eq!( + version, + NodeVersion { + major: 25, + minor: 1, + patch: 0, + } + ); +} + +#[test] +fn truncate_utf8_prefix_by_bytes_preserves_character_boundaries() { + let input = "aé🙂z"; + assert_eq!(truncate_utf8_prefix_by_bytes(input, 0), ""); + assert_eq!(truncate_utf8_prefix_by_bytes(input, 1), "a"); + assert_eq!(truncate_utf8_prefix_by_bytes(input, 2), "a"); + assert_eq!(truncate_utf8_prefix_by_bytes(input, 3), "aé"); + assert_eq!(truncate_utf8_prefix_by_bytes(input, 6), "aé"); + assert_eq!(truncate_utf8_prefix_by_bytes(input, 7), "aé🙂"); + assert_eq!(truncate_utf8_prefix_by_bytes(input, 8), "aé🙂z"); +} + +#[test] +fn stderr_tail_applies_line_and_byte_limits() { + let mut lines = VecDeque::new(); + let per_line_cap = JS_REPL_STDERR_TAIL_LINE_MAX_BYTES.min(JS_REPL_STDERR_TAIL_MAX_BYTES); + let long = "x".repeat(per_line_cap + 128); + let bounded = push_stderr_tail_line(&mut lines, &long); + assert_eq!(bounded.len(), per_line_cap); + + for i in 0..50 { + let line = format!("line-{i}-{}", "y".repeat(200)); + push_stderr_tail_line(&mut lines, &line); + } + + assert!(lines.len() <= JS_REPL_STDERR_TAIL_LINE_LIMIT); + assert!(lines.iter().all(|line| line.len() <= per_line_cap)); + assert!(stderr_tail_formatted_bytes(&lines) <= JS_REPL_STDERR_TAIL_MAX_BYTES); + assert_eq!( + format_stderr_tail(&lines).len(), + stderr_tail_formatted_bytes(&lines) + ); +} + +#[test] +fn model_kernel_failure_details_are_structured_and_truncated() { + let snapshot = KernelDebugSnapshot { + pid: Some(42), + status: "exited(code=1)".to_string(), + stderr_tail: "s".repeat(JS_REPL_MODEL_DIAG_STDERR_MAX_BYTES + 400), + }; + let stream_error = "e".repeat(JS_REPL_MODEL_DIAG_ERROR_MAX_BYTES + 200); + let message = with_model_kernel_failure_message( + "js_repl kernel exited unexpectedly", + "stdout_eof", + Some(&stream_error), + &snapshot, + ); + assert!(message.starts_with("js_repl kernel exited unexpectedly\n\njs_repl diagnostics: ")); + let (_prefix, encoded) = message + .split_once("js_repl diagnostics: ") + .expect("diagnostics suffix should be present"); + let parsed: serde_json::Value = + serde_json::from_str(encoded).expect("diagnostics should be valid json"); + assert_eq!( + parsed.get("reason").and_then(|v| v.as_str()), + Some("stdout_eof") + ); + assert_eq!( + parsed.get("kernel_pid").and_then(serde_json::Value::as_u64), + Some(42) + ); + assert_eq!( + parsed.get("kernel_status").and_then(|v| v.as_str()), + Some("exited(code=1)") + ); + assert!( + parsed + .get("kernel_stderr_tail") + .and_then(|v| v.as_str()) + .expect("kernel_stderr_tail should be present") + .len() + <= JS_REPL_MODEL_DIAG_STDERR_MAX_BYTES + ); + assert!( + parsed + .get("stream_error") + .and_then(|v| v.as_str()) + .expect("stream_error should be present") + .len() + <= JS_REPL_MODEL_DIAG_ERROR_MAX_BYTES + ); +} + +#[test] +fn write_error_diagnostics_only_attach_for_likely_kernel_failures() { + let running = KernelDebugSnapshot { + pid: Some(7), + status: "running".to_string(), + stderr_tail: "".to_string(), + }; + let exited = KernelDebugSnapshot { + pid: Some(7), + status: "exited(code=1)".to_string(), + stderr_tail: "".to_string(), + }; + assert!(!should_include_model_diagnostics_for_write_error( + "failed to flush kernel message: other io error", + &running + )); + assert!(should_include_model_diagnostics_for_write_error( + "failed to write to kernel: Broken pipe (os error 32)", + &running + )); + assert!(should_include_model_diagnostics_for_write_error( + "failed to write to kernel: some other io error", + &exited + )); +} + +#[test] +fn js_repl_internal_tool_guard_matches_expected_names() { + assert!(is_js_repl_internal_tool("js_repl")); + assert!(is_js_repl_internal_tool("js_repl_reset")); + assert!(!is_js_repl_internal_tool("shell_command")); + assert!(!is_js_repl_internal_tool("list_mcp_resources")); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn wait_for_exec_tool_calls_map_drains_inflight_calls_without_hanging() { + let exec_tool_calls = Arc::new(Mutex::new(HashMap::new())); + + for _ in 0..128 { + let exec_id = Uuid::new_v4().to_string(); + exec_tool_calls + .lock() + .await + .insert(exec_id.clone(), ExecToolCalls::default()); + assert!( + JsReplManager::begin_exec_tool_call(&exec_tool_calls, &exec_id) + .await + .is_some() + ); + + let wait_map = Arc::clone(&exec_tool_calls); + let wait_exec_id = exec_id.clone(); + let waiter = tokio::spawn(async move { + JsReplManager::wait_for_exec_tool_calls_map(&wait_map, &wait_exec_id).await; + }); + + let finish_map = Arc::clone(&exec_tool_calls); + let finish_exec_id = exec_id.clone(); + let finisher = tokio::spawn(async move { + tokio::task::yield_now().await; + JsReplManager::finish_exec_tool_call(&finish_map, &finish_exec_id).await; + }); + + tokio::time::timeout(Duration::from_secs(1), waiter) + .await + .expect("wait_for_exec_tool_calls_map should not hang") + .expect("wait task should not panic"); + finisher.await.expect("finish task should not panic"); + + JsReplManager::clear_exec_tool_calls_map(&exec_tool_calls, &exec_id).await; + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn reset_waits_for_exec_lock_before_clearing_exec_tool_calls() { + let manager = JsReplManager::new(None, Vec::new()) + .await + .expect("manager should initialize"); + let permit = manager + .exec_lock + .clone() + .acquire_owned() + .await + .expect("lock should be acquirable"); + let exec_id = Uuid::new_v4().to_string(); + manager.register_exec_tool_calls(&exec_id).await; + + let reset_manager = Arc::clone(&manager); + let mut reset_task = tokio::spawn(async move { reset_manager.reset().await }); + tokio::time::sleep(Duration::from_millis(50)).await; + + assert!( + !reset_task.is_finished(), + "reset should wait until execute lock is released" + ); + assert!( + manager.exec_tool_calls.lock().await.contains_key(&exec_id), + "reset must not clear tool-call contexts while execute lock is held" + ); + + drop(permit); + + tokio::time::timeout(Duration::from_secs(1), &mut reset_task) + .await + .expect("reset should complete after execute lock release") + .expect("reset task should not panic") + .expect("reset should succeed"); + assert!( + !manager.exec_tool_calls.lock().await.contains_key(&exec_id), + "reset should clear tool-call contexts after lock acquisition" + ); +} + +#[test] +fn summarize_tool_call_response_for_multimodal_function_output() { + let response = ResponseInputItem::FunctionCallOutput { + call_id: "call-1".to_string(), + output: FunctionCallOutputPayload::from_content_items(vec![ + FunctionCallOutputContentItem::InputImage { + image_url: "data:image/png;base64,abcd".to_string(), + detail: None, + }, + ]), + }; + + let actual = JsReplManager::summarize_tool_call_response(&response); + + assert_eq!( + actual, + JsReplToolCallResponseSummary { + response_type: Some("function_call_output".to_string()), + payload_kind: Some(JsReplToolCallPayloadKind::FunctionContentItems), + payload_text_preview: None, + payload_text_length: None, + payload_item_count: Some(1), + text_item_count: Some(0), + image_item_count: Some(1), + structured_content_present: None, + result_is_error: None, + } + ); +} + +#[tokio::test] +async fn emitted_image_content_item_drops_unsupported_explicit_detail() { + let (_session, turn) = make_session_and_context().await; + let content_item = emitted_image_content_item( + &turn, + "data:image/png;base64,AAA".to_string(), + Some(ImageDetail::Low), + ); + assert_eq!( + content_item, + FunctionCallOutputContentItem::InputImage { + image_url: "data:image/png;base64,AAA".to_string(), + detail: None, + } + ); +} + +#[tokio::test] +async fn emitted_image_content_item_does_not_force_original_when_enabled() { + let (_session, mut turn) = make_session_and_context().await; + Arc::make_mut(&mut turn.config) + .features + .enable(Feature::ImageDetailOriginal) + .expect("test config should allow feature update"); + turn.features + .enable(Feature::ImageDetailOriginal) + .expect("test turn features should allow feature update"); + turn.model_info.supports_image_detail_original = true; + + let content_item = + emitted_image_content_item(&turn, "data:image/png;base64,AAA".to_string(), None); + + assert_eq!( + content_item, + FunctionCallOutputContentItem::InputImage { + image_url: "data:image/png;base64,AAA".to_string(), + detail: None, + } + ); +} + +#[tokio::test] +async fn emitted_image_content_item_allows_explicit_original_detail_when_enabled() { + let (_session, mut turn) = make_session_and_context().await; + Arc::make_mut(&mut turn.config) + .features + .enable(Feature::ImageDetailOriginal) + .expect("test config should allow feature update"); + turn.features + .enable(Feature::ImageDetailOriginal) + .expect("test turn features should allow feature update"); + turn.model_info.supports_image_detail_original = true; + + let content_item = emitted_image_content_item( + &turn, + "data:image/png;base64,AAA".to_string(), + Some(ImageDetail::Original), + ); + + assert_eq!( + content_item, + FunctionCallOutputContentItem::InputImage { + image_url: "data:image/png;base64,AAA".to_string(), + detail: Some(ImageDetail::Original), + } + ); +} + +#[tokio::test] +async fn emitted_image_content_item_drops_explicit_original_detail_when_disabled() { + let (_session, turn) = make_session_and_context().await; + + let content_item = emitted_image_content_item( + &turn, + "data:image/png;base64,AAA".to_string(), + Some(ImageDetail::Original), + ); + + assert_eq!( + content_item, + FunctionCallOutputContentItem::InputImage { + image_url: "data:image/png;base64,AAA".to_string(), + detail: None, + } + ); +} + +#[test] +fn validate_emitted_image_url_accepts_case_insensitive_data_scheme() { + assert_eq!( + validate_emitted_image_url("DATA:image/png;base64,AAA"), + Ok(()) + ); +} + +#[test] +fn validate_emitted_image_url_rejects_non_data_scheme() { + assert_eq!( + validate_emitted_image_url("https://example.com/image.png"), + Err("codex.emitImage only accepts data URLs".to_string()) + ); +} + +#[test] +fn summarize_tool_call_response_for_multimodal_custom_output() { + let response = ResponseInputItem::CustomToolCallOutput { + call_id: "call-1".to_string(), + output: FunctionCallOutputPayload::from_content_items(vec![ + FunctionCallOutputContentItem::InputImage { + image_url: "data:image/png;base64,abcd".to_string(), + detail: None, + }, + ]), + }; + + let actual = JsReplManager::summarize_tool_call_response(&response); + + assert_eq!( + actual, + JsReplToolCallResponseSummary { + response_type: Some("custom_tool_call_output".to_string()), + payload_kind: Some(JsReplToolCallPayloadKind::CustomContentItems), + payload_text_preview: None, + payload_text_length: None, + payload_item_count: Some(1), + text_item_count: Some(0), + image_item_count: Some(1), + structured_content_present: None, + result_is_error: None, + } + ); +} + +#[test] +fn summarize_tool_call_error_marks_error_payload() { + let actual = JsReplManager::summarize_tool_call_error("tool failed"); + + assert_eq!( + actual, + JsReplToolCallResponseSummary { + response_type: None, + payload_kind: Some(JsReplToolCallPayloadKind::Error), + payload_text_preview: Some("tool failed".to_string()), + payload_text_length: Some("tool failed".len()), + payload_item_count: None, + text_item_count: None, + image_item_count: None, + structured_content_present: None, + result_is_error: None, + } + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn reset_clears_inflight_exec_tool_calls_without_waiting() { + let manager = JsReplManager::new(None, Vec::new()) + .await + .expect("manager should initialize"); + let exec_id = Uuid::new_v4().to_string(); + manager.register_exec_tool_calls(&exec_id).await; + assert!( + JsReplManager::begin_exec_tool_call(&manager.exec_tool_calls, &exec_id) + .await + .is_some() + ); + + let wait_manager = Arc::clone(&manager); + let wait_exec_id = exec_id.clone(); + let waiter = tokio::spawn(async move { + wait_manager.wait_for_exec_tool_calls(&wait_exec_id).await; + }); + tokio::task::yield_now().await; + + tokio::time::timeout(Duration::from_secs(1), manager.reset()) + .await + .expect("reset should not hang") + .expect("reset should succeed"); + + tokio::time::timeout(Duration::from_secs(1), waiter) + .await + .expect("waiter should be released") + .expect("wait task should not panic"); + + assert!(manager.exec_tool_calls.lock().await.is_empty()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn reset_aborts_inflight_exec_tool_tasks() { + let manager = JsReplManager::new(None, Vec::new()) + .await + .expect("manager should initialize"); + let exec_id = Uuid::new_v4().to_string(); + manager.register_exec_tool_calls(&exec_id).await; + let reset_cancel = JsReplManager::begin_exec_tool_call(&manager.exec_tool_calls, &exec_id) + .await + .expect("exec should be registered"); + + let task = tokio::spawn(async move { + tokio::select! { + _ = reset_cancel.cancelled() => "cancelled", + _ = tokio::time::sleep(Duration::from_secs(60)) => "timed_out", + } + }); + + tokio::time::timeout(Duration::from_secs(1), manager.reset()) + .await + .expect("reset should not hang") + .expect("reset should succeed"); + + let outcome = tokio::time::timeout(Duration::from_secs(1), task) + .await + .expect("cancelled task should resolve promptly") + .expect("task should not panic"); + assert_eq!(outcome, "cancelled"); +} + +async fn can_run_js_repl_runtime_tests() -> bool { + // These white-box runtime tests are required on macOS. Linux relies on + // the codex-linux-sandbox arg0 dispatch path, which is exercised in + // integration tests instead. + cfg!(target_os = "macos") +} +fn write_js_repl_test_package_source(base: &Path, name: &str, source: &str) -> anyhow::Result<()> { + let pkg_dir = base.join("node_modules").join(name); + fs::create_dir_all(&pkg_dir)?; + fs::write( + pkg_dir.join("package.json"), + format!( + "{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"type\": \"module\",\n \"exports\": {{\n \"import\": \"./index.js\"\n }}\n}}\n" + ), + )?; + fs::write(pkg_dir.join("index.js"), source)?; + Ok(()) +} + +fn write_js_repl_test_package(base: &Path, name: &str, value: &str) -> anyhow::Result<()> { + write_js_repl_test_package_source(base, name, &format!("export const value = \"{value}\";\n"))?; + Ok(()) +} + +fn write_js_repl_test_module(base: &Path, relative: &str, contents: &str) -> anyhow::Result<()> { + let module_path = base.join(relative); + if let Some(parent) = module_path.parent() { + fs::create_dir_all(parent)?; + } + fs::write(module_path, contents)?; + Ok(()) +} + +#[tokio::test] +async fn js_repl_timeout_does_not_deadlock() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let (session, turn) = make_session_and_context().await; + let session = Arc::new(session); + let turn = Arc::new(turn); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + + let result = tokio::time::timeout( + Duration::from_secs(3), + manager.execute( + session, + turn, + tracker, + JsReplArgs { + code: "while (true) {}".to_string(), + timeout_ms: Some(50), + }, + ), + ) + .await + .expect("execute should return, not deadlock") + .expect_err("expected timeout error"); + + assert_eq!( + result.to_string(), + "js_repl execution timed out; kernel reset, rerun your request" + ); + Ok(()) +} + +#[tokio::test] +async fn js_repl_timeout_kills_kernel_process() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let (session, turn) = make_session_and_context().await; + let session = Arc::new(session); + let turn = Arc::new(turn); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + + manager + .execute( + Arc::clone(&session), + Arc::clone(&turn), + Arc::clone(&tracker), + JsReplArgs { + code: "console.log('warmup');".to_string(), + timeout_ms: Some(10_000), + }, + ) + .await?; + + let child = { + let guard = manager.kernel.lock().await; + let state = guard.as_ref().expect("kernel should exist after warmup"); + Arc::clone(&state.child) + }; + + let result = manager + .execute( + session, + turn, + tracker, + JsReplArgs { + code: "while (true) {}".to_string(), + timeout_ms: Some(50), + }, + ) + .await + .expect_err("expected timeout error"); + + assert_eq!( + result.to_string(), + "js_repl execution timed out; kernel reset, rerun your request" + ); + + let exit_state = { + let mut child = child.lock().await; + child.try_wait()? + }; + assert!( + exit_state.is_some(), + "timed out js_repl execution should kill previous kernel process" + ); + Ok(()) +} + +#[tokio::test] +async fn js_repl_forced_kernel_exit_recovers_on_next_exec() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let (session, turn) = make_session_and_context().await; + let session = Arc::new(session); + let turn = Arc::new(turn); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + + manager + .execute( + Arc::clone(&session), + Arc::clone(&turn), + Arc::clone(&tracker), + JsReplArgs { + code: "console.log('warmup');".to_string(), + timeout_ms: Some(10_000), + }, + ) + .await?; + + let child = { + let guard = manager.kernel.lock().await; + let state = guard.as_ref().expect("kernel should exist after warmup"); + Arc::clone(&state.child) + }; + JsReplManager::kill_kernel_child(&child, "test_crash").await; + tokio::time::timeout(Duration::from_secs(1), async { + loop { + let cleared = { + let guard = manager.kernel.lock().await; + guard + .as_ref() + .is_none_or(|state| !Arc::ptr_eq(&state.child, &child)) + }; + if cleared { + return; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("host should clear dead kernel state promptly"); + + let result = manager + .execute( + session, + turn, + tracker, + JsReplArgs { + code: "console.log('after-kill');".to_string(), + timeout_ms: Some(10_000), + }, + ) + .await?; + assert!(result.output.contains("after-kill")); + Ok(()) +} + +#[tokio::test] +async fn js_repl_uncaught_exception_returns_exec_error_and_recovers() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let (session, turn) = crate::codex::make_session_and_context().await; + let session = Arc::new(session); + let turn = Arc::new(turn); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + + manager + .execute( + Arc::clone(&session), + Arc::clone(&turn), + Arc::clone(&tracker), + JsReplArgs { + code: "console.log('warmup');".to_string(), + timeout_ms: Some(10_000), + }, + ) + .await?; + + let child = { + let guard = manager.kernel.lock().await; + let state = guard.as_ref().expect("kernel should exist after warmup"); + Arc::clone(&state.child) + }; + + let err = tokio::time::timeout( + Duration::from_secs(3), + manager.execute( + Arc::clone(&session), + Arc::clone(&turn), + Arc::clone(&tracker), + JsReplArgs { + code: "setTimeout(() => { throw new Error('boom'); }, 0);\nawait new Promise(() => {});".to_string(), + timeout_ms: Some(10_000), + }, + ), + ) + .await + .expect("uncaught exception should fail promptly") + .expect_err("expected uncaught exception to fail the exec"); + + let message = err.to_string(); + assert!(message.contains("js_repl kernel uncaught exception: boom")); + assert!(message.contains("kernel reset.")); + assert!(message.contains("Catch or handle async errors")); + assert!(!message.contains("js_repl kernel exited unexpectedly")); + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + let exited = { + let mut child = child.lock().await; + child.try_wait()?.is_some() + }; + if exited { + return Ok::<(), anyhow::Error>(()); + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("uncaught exception should terminate the previous kernel process")?; + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + let cleared = { + let guard = manager.kernel.lock().await; + guard + .as_ref() + .is_none_or(|state| !Arc::ptr_eq(&state.child, &child)) + }; + if cleared { + return; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("host should clear dead kernel state promptly"); + + let next = manager + .execute( + session, + turn, + tracker, + JsReplArgs { + code: "console.log('after reset');".to_string(), + timeout_ms: Some(10_000), + }, + ) + .await?; + assert!(next.output.contains("after reset")); + Ok(()) +} + +#[tokio::test] +async fn js_repl_waits_for_unawaited_tool_calls_before_completion() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let (session, mut turn) = make_session_and_context().await; + turn.approval_policy + .set(AskForApproval::Never) + .expect("test setup should allow updating approval policy"); + set_danger_full_access(&mut turn); + + let session = Arc::new(session); + let turn = Arc::new(turn); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + + let marker = turn + .cwd + .join(format!("js-repl-unawaited-marker-{}.txt", Uuid::new_v4())); + let marker_json = serde_json::to_string(&marker.to_string_lossy().to_string())?; + let result = manager + .execute( + session, + turn, + tracker, + JsReplArgs { + code: format!( + r#" +const marker = {marker_json}; +void codex.tool("shell_command", {{ command: `sleep 0.35; printf js_repl_unawaited_done > "${{marker}}"` }}); +console.log("cell-complete"); +"# + ), + timeout_ms: Some(10_000), + }, + ) + .await?; + assert!(result.output.contains("cell-complete")); + let marker_contents = tokio::fs::read_to_string(&marker).await?; + assert_eq!(marker_contents, "js_repl_unawaited_done"); + let _ = tokio::fs::remove_file(&marker).await; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn js_repl_does_not_auto_attach_image_via_view_image_tool() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let (session, mut turn) = make_session_and_context().await; + if !turn + .model_info + .input_modalities + .contains(&InputModality::Image) + { + return Ok(()); + } + turn.approval_policy + .set(AskForApproval::Never) + .expect("test setup should allow updating approval policy"); + set_danger_full_access(&mut turn); + + let session = Arc::new(session); + let turn = Arc::new(turn); + *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); + + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + let code = r#" +const fs = await import("node:fs/promises"); +const path = await import("node:path"); +const imagePath = path.join(codex.tmpDir, "js-repl-view-image.png"); +const png = Buffer.from( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==", + "base64" +); +await fs.writeFile(imagePath, png); +const out = await codex.tool("view_image", { path: imagePath }); +console.log(out.type); +"#; + + let result = manager + .execute( + Arc::clone(&session), + turn, + tracker, + JsReplArgs { + code: code.to_string(), + timeout_ms: Some(15_000), + }, + ) + .await?; + assert!(result.output.contains("function_call_output")); + assert!(result.content_items.is_empty()); + assert!(session.get_pending_input().await.is_empty()); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn js_repl_can_emit_image_via_view_image_tool() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let (session, mut turn) = make_session_and_context().await; + if !turn + .model_info + .input_modalities + .contains(&InputModality::Image) + { + return Ok(()); + } + turn.approval_policy + .set(AskForApproval::Never) + .expect("test setup should allow updating approval policy"); + set_danger_full_access(&mut turn); + + let session = Arc::new(session); + let turn = Arc::new(turn); + *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); + + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + let code = r#" +const fs = await import("node:fs/promises"); +const path = await import("node:path"); +const imagePath = path.join(codex.tmpDir, "js-repl-view-image-explicit.png"); +const png = Buffer.from( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==", + "base64" +); +await fs.writeFile(imagePath, png); +const out = await codex.tool("view_image", { path: imagePath }); +await codex.emitImage(out); +console.log(out.type); +"#; + + let result = manager + .execute( + Arc::clone(&session), + turn, + tracker, + JsReplArgs { + code: code.to_string(), + timeout_ms: Some(15_000), + }, + ) + .await?; + assert!(result.output.contains("function_call_output")); + assert_eq!( + result.content_items.as_slice(), + [FunctionCallOutputContentItem::InputImage { + image_url: + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==" + .to_string(), + detail: None, + }] + .as_slice() + ); + assert!(session.get_pending_input().await.is_empty()); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn js_repl_can_emit_image_from_bytes_and_mime_type() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let (session, turn) = make_session_and_context().await; + if !turn + .model_info + .input_modalities + .contains(&InputModality::Image) + { + return Ok(()); + } + + let session = Arc::new(session); + let turn = Arc::new(turn); + *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); + + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + let code = r#" +const png = Buffer.from( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==", + "base64" +); +await codex.emitImage({ bytes: png, mimeType: "image/png" }); +"#; + + let result = manager + .execute( + Arc::clone(&session), + turn, + tracker, + JsReplArgs { + code: code.to_string(), + timeout_ms: Some(15_000), + }, + ) + .await?; + assert_eq!( + result.content_items.as_slice(), + [FunctionCallOutputContentItem::InputImage { + image_url: + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==" + .to_string(), + detail: None, + }] + .as_slice() + ); + assert!(session.get_pending_input().await.is_empty()); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn js_repl_can_emit_multiple_images_in_one_cell() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let (session, turn) = make_session_and_context().await; + if !turn + .model_info + .input_modalities + .contains(&InputModality::Image) + { + return Ok(()); + } + + let session = Arc::new(session); + let turn = Arc::new(turn); + *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); + + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + let code = r#" +await codex.emitImage( + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==" +); +await codex.emitImage( + "data:image/gif;base64,R0lGODdhAQABAIAAAP///////ywAAAAAAQABAAACAkQBADs=" +); +"#; + + let result = manager + .execute( + Arc::clone(&session), + turn, + tracker, + JsReplArgs { + code: code.to_string(), + timeout_ms: Some(15_000), + }, + ) + .await?; + assert_eq!( + result.content_items.as_slice(), + [ + FunctionCallOutputContentItem::InputImage { + image_url: + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==" + .to_string(), + detail: None, + }, + FunctionCallOutputContentItem::InputImage { + image_url: + "data:image/gif;base64,R0lGODdhAQABAIAAAP///////ywAAAAAAQABAAACAkQBADs=" + .to_string(), + detail: None, + }, + ] + .as_slice() + ); + assert!(session.get_pending_input().await.is_empty()); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn js_repl_waits_for_unawaited_emit_image_before_completion() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let (session, turn) = make_session_and_context().await; + if !turn + .model_info + .input_modalities + .contains(&InputModality::Image) + { + return Ok(()); + } + + let session = Arc::new(session); + let turn = Arc::new(turn); + *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); + + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + let code = r#" +void codex.emitImage( + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==" +); +console.log("cell-complete"); +"#; + + let result = manager + .execute( + Arc::clone(&session), + turn, + tracker, + JsReplArgs { + code: code.to_string(), + timeout_ms: Some(15_000), + }, + ) + .await?; + assert!(result.output.contains("cell-complete")); + assert_eq!( + result.content_items.as_slice(), + [FunctionCallOutputContentItem::InputImage { + image_url: + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==" + .to_string(), + detail: None, + }] + .as_slice() + ); + assert!(session.get_pending_input().await.is_empty()); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn js_repl_unawaited_emit_image_errors_fail_cell() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let (session, turn) = make_session_and_context().await; + if !turn + .model_info + .input_modalities + .contains(&InputModality::Image) + { + return Ok(()); + } + + let session = Arc::new(session); + let turn = Arc::new(turn); + *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); + + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + let code = r#" +void codex.emitImage({ bytes: new Uint8Array(), mimeType: "image/png" }); +console.log("cell-complete"); +"#; + + let err = manager + .execute( + Arc::clone(&session), + turn, + tracker, + JsReplArgs { + code: code.to_string(), + timeout_ms: Some(15_000), + }, + ) + .await + .expect_err("unawaited invalid emitImage should fail"); + assert!(err.to_string().contains("expected non-empty bytes")); + assert!(session.get_pending_input().await.is_empty()); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn js_repl_caught_emit_image_error_does_not_fail_cell() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let (session, turn) = make_session_and_context().await; + if !turn + .model_info + .input_modalities + .contains(&InputModality::Image) + { + return Ok(()); + } + + let session = Arc::new(session); + let turn = Arc::new(turn); + *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); + + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + let code = r#" +try { + await codex.emitImage({ bytes: new Uint8Array(), mimeType: "image/png" }); +} catch (error) { + console.log(error.message); +} +console.log("cell-complete"); +"#; + + let result = manager + .execute( + Arc::clone(&session), + turn, + tracker, + JsReplArgs { + code: code.to_string(), + timeout_ms: Some(15_000), + }, + ) + .await?; + assert!(result.output.contains("expected non-empty bytes")); + assert!(result.output.contains("cell-complete")); + assert!(result.content_items.is_empty()); + assert!(session.get_pending_input().await.is_empty()); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn js_repl_emit_image_requires_explicit_mime_type_for_bytes() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let (session, turn) = make_session_and_context().await; + if !turn + .model_info + .input_modalities + .contains(&InputModality::Image) + { + return Ok(()); + } + + let session = Arc::new(session); + let turn = Arc::new(turn); + *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); + + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + let code = r#" +const png = Buffer.from( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==", + "base64" +); +await codex.emitImage({ bytes: png }); +"#; + + let err = manager + .execute( + Arc::clone(&session), + turn, + tracker, + JsReplArgs { + code: code.to_string(), + timeout_ms: Some(15_000), + }, + ) + .await + .expect_err("missing mimeType should fail"); + assert!(err.to_string().contains("expected a non-empty mimeType")); + assert!(session.get_pending_input().await.is_empty()); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn js_repl_emit_image_rejects_non_data_url() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let (session, turn) = make_session_and_context().await; + if !turn + .model_info + .input_modalities + .contains(&InputModality::Image) + { + return Ok(()); + } + + let session = Arc::new(session); + let turn = Arc::new(turn); + *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); + + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + let code = r#" +await codex.emitImage("https://example.com/image.png"); +"#; + + let err = manager + .execute( + Arc::clone(&session), + turn, + tracker, + JsReplArgs { + code: code.to_string(), + timeout_ms: Some(15_000), + }, + ) + .await + .expect_err("non-data URLs should fail"); + assert!(err.to_string().contains("only accepts data URLs")); + assert!(session.get_pending_input().await.is_empty()); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn js_repl_emit_image_accepts_case_insensitive_data_url() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let (session, turn) = make_session_and_context().await; + if !turn + .model_info + .input_modalities + .contains(&InputModality::Image) + { + return Ok(()); + } + + let session = Arc::new(session); + let turn = Arc::new(turn); + *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); + + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + let code = r#" +await codex.emitImage("DATA:image/png;base64,AAA"); +"#; + + let result = manager + .execute( + Arc::clone(&session), + turn, + tracker, + JsReplArgs { + code: code.to_string(), + timeout_ms: Some(15_000), + }, + ) + .await?; + assert_eq!( + result.content_items.as_slice(), + [FunctionCallOutputContentItem::InputImage { + image_url: "DATA:image/png;base64,AAA".to_string(), + detail: None, + }] + .as_slice() + ); + assert!(session.get_pending_input().await.is_empty()); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn js_repl_emit_image_rejects_invalid_detail() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let (session, turn) = make_session_and_context().await; + if !turn + .model_info + .input_modalities + .contains(&InputModality::Image) + { + return Ok(()); + } + + let session = Arc::new(session); + let turn = Arc::new(turn); + *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); + + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + let code = r#" +const png = Buffer.from( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==", + "base64" +); +await codex.emitImage({ bytes: png, mimeType: "image/png", detail: "ultra" }); +"#; + + let err = manager + .execute( + Arc::clone(&session), + turn, + tracker, + JsReplArgs { + code: code.to_string(), + timeout_ms: Some(15_000), + }, + ) + .await + .expect_err("invalid detail should fail"); + assert!( + err.to_string() + .contains("only supports detail \"original\"") + ); + assert!(session.get_pending_input().await.is_empty()); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn js_repl_emit_image_treats_null_detail_as_omitted() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let (session, turn) = make_session_and_context().await; + if !turn + .model_info + .input_modalities + .contains(&InputModality::Image) + { + return Ok(()); + } + + let session = Arc::new(session); + let turn = Arc::new(turn); + *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); + + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + let code = r#" +const png = Buffer.from( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==", + "base64" +); +await codex.emitImage({ bytes: png, mimeType: "image/png", detail: null }); +"#; + + let result = manager + .execute( + Arc::clone(&session), + turn, + tracker, + JsReplArgs { + code: code.to_string(), + timeout_ms: Some(15_000), + }, + ) + .await?; + assert_eq!( + result.content_items.as_slice(), + [FunctionCallOutputContentItem::InputImage { + image_url: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==".to_string(), + detail: None, + }] + .as_slice() + ); + assert!(session.get_pending_input().await.is_empty()); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn js_repl_emit_image_rejects_mixed_content() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let (session, turn, rx_event) = + make_session_and_context_with_dynamic_tools_and_rx(vec![DynamicToolSpec { + name: "inline_image".to_string(), + description: "Returns inline text and image content.".to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": {}, + "additionalProperties": false + }), + }]) + .await; + if !turn + .model_info + .input_modalities + .contains(&InputModality::Image) + { + return Ok(()); + } + + *session.active_turn.lock().await = Some(crate::state::ActiveTurn::default()); + + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + let code = r#" +const out = await codex.tool("inline_image", {}); +await codex.emitImage(out); +"#; + let image_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg=="; + + let session_for_response = Arc::clone(&session); + let response_watcher = async move { + loop { + let event = tokio::time::timeout(Duration::from_secs(2), rx_event.recv()).await??; + if let EventMsg::DynamicToolCallRequest(request) = event.msg { + session_for_response + .notify_dynamic_tool_response( + &request.call_id, + DynamicToolResponse { + content_items: vec![ + DynamicToolCallOutputContentItem::InputText { + text: "inline image note".to_string(), + }, + DynamicToolCallOutputContentItem::InputImage { + image_url: image_url.to_string(), + }, + ], + success: true, + }, + ) + .await; + return Ok::<(), anyhow::Error>(()); + } + } + }; + + let (result, response_watcher_result) = tokio::join!( + manager.execute( + Arc::clone(&session), + Arc::clone(&turn), + tracker, + JsReplArgs { + code: code.to_string(), + timeout_ms: Some(15_000), + }, + ), + response_watcher, + ); + response_watcher_result?; + let err = result.expect_err("mixed content should fail"); + assert!( + err.to_string() + .contains("does not accept mixed text and image content") + ); + assert!(session.get_pending_input().await.is_empty()); + + Ok(()) +} +#[tokio::test] +async fn js_repl_prefers_env_node_module_dirs_over_config() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let env_base = tempdir()?; + write_js_repl_test_package(env_base.path(), "repl_probe", "env")?; + + let config_base = tempdir()?; + let cwd_dir = tempdir()?; + + let (session, mut turn) = make_session_and_context().await; + turn.shell_environment_policy.r#set.insert( + "CODEX_JS_REPL_NODE_MODULE_DIRS".to_string(), + env_base.path().to_string_lossy().to_string(), + ); + turn.cwd = cwd_dir.path().to_path_buf(); + turn.js_repl = Arc::new(JsReplHandle::with_node_path( + turn.config.js_repl_node_path.clone(), + vec![config_base.path().to_path_buf()], + )); + + let session = Arc::new(session); + let turn = Arc::new(turn); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + + let result = manager + .execute( + session, + turn, + tracker, + JsReplArgs { + code: "const mod = await import(\"repl_probe\"); console.log(mod.value);" + .to_string(), + timeout_ms: Some(10_000), + }, + ) + .await?; + assert!(result.output.contains("env")); + Ok(()) +} + +#[tokio::test] +async fn js_repl_resolves_from_first_config_dir() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let first_base = tempdir()?; + let second_base = tempdir()?; + write_js_repl_test_package(first_base.path(), "repl_probe", "first")?; + write_js_repl_test_package(second_base.path(), "repl_probe", "second")?; + + let cwd_dir = tempdir()?; + + let (session, mut turn) = make_session_and_context().await; + turn.shell_environment_policy + .r#set + .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); + turn.cwd = cwd_dir.path().to_path_buf(); + turn.js_repl = Arc::new(JsReplHandle::with_node_path( + turn.config.js_repl_node_path.clone(), + vec![ + first_base.path().to_path_buf(), + second_base.path().to_path_buf(), + ], + )); + + let session = Arc::new(session); + let turn = Arc::new(turn); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + + let result = manager + .execute( + session, + turn, + tracker, + JsReplArgs { + code: "const mod = await import(\"repl_probe\"); console.log(mod.value);" + .to_string(), + timeout_ms: Some(10_000), + }, + ) + .await?; + assert!(result.output.contains("first")); + Ok(()) +} + +#[tokio::test] +async fn js_repl_falls_back_to_cwd_node_modules() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let config_base = tempdir()?; + let cwd_dir = tempdir()?; + write_js_repl_test_package(cwd_dir.path(), "repl_probe", "cwd")?; + + let (session, mut turn) = make_session_and_context().await; + turn.shell_environment_policy + .r#set + .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); + turn.cwd = cwd_dir.path().to_path_buf(); + turn.js_repl = Arc::new(JsReplHandle::with_node_path( + turn.config.js_repl_node_path.clone(), + vec![config_base.path().to_path_buf()], + )); + + let session = Arc::new(session); + let turn = Arc::new(turn); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + + let result = manager + .execute( + session, + turn, + tracker, + JsReplArgs { + code: "const mod = await import(\"repl_probe\"); console.log(mod.value);" + .to_string(), + timeout_ms: Some(10_000), + }, + ) + .await?; + assert!(result.output.contains("cwd")); + Ok(()) +} + +#[tokio::test] +async fn js_repl_accepts_node_modules_dir_entries() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let base_dir = tempdir()?; + let cwd_dir = tempdir()?; + write_js_repl_test_package(base_dir.path(), "repl_probe", "normalized")?; + + let (session, mut turn) = make_session_and_context().await; + turn.shell_environment_policy + .r#set + .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); + turn.cwd = cwd_dir.path().to_path_buf(); + turn.js_repl = Arc::new(JsReplHandle::with_node_path( + turn.config.js_repl_node_path.clone(), + vec![base_dir.path().join("node_modules")], + )); + + let session = Arc::new(session); + let turn = Arc::new(turn); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + + let result = manager + .execute( + session, + turn, + tracker, + JsReplArgs { + code: "const mod = await import(\"repl_probe\"); console.log(mod.value);" + .to_string(), + timeout_ms: Some(10_000), + }, + ) + .await?; + assert!(result.output.contains("normalized")); + Ok(()) +} + +#[tokio::test] +async fn js_repl_supports_relative_file_imports() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let cwd_dir = tempdir()?; + write_js_repl_test_module( + cwd_dir.path(), + "child.js", + "export const value = \"child\";\n", + )?; + write_js_repl_test_module( + cwd_dir.path(), + "parent.js", + "import { value as childValue } from \"./child.js\";\nexport const value = `${childValue}-parent`;\n", + )?; + write_js_repl_test_module( + cwd_dir.path(), + "local.mjs", + "export const value = \"mjs\";\n", + )?; + + let (session, mut turn) = make_session_and_context().await; + turn.shell_environment_policy + .r#set + .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); + turn.cwd = cwd_dir.path().to_path_buf(); + turn.js_repl = Arc::new(JsReplHandle::with_node_path( + turn.config.js_repl_node_path.clone(), + Vec::new(), + )); + + let session = Arc::new(session); + let turn = Arc::new(turn); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + + let result = manager + .execute( + session, + turn, + tracker, + JsReplArgs { + code: "const parent = await import(\"./parent.js\"); const other = await import(\"./local.mjs\"); console.log(parent.value); console.log(other.value);".to_string(), + timeout_ms: Some(10_000), + }, + ) + .await?; + assert!(result.output.contains("child-parent")); + assert!(result.output.contains("mjs")); + Ok(()) +} + +#[tokio::test] +async fn js_repl_supports_absolute_file_imports() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let module_dir = tempdir()?; + let cwd_dir = tempdir()?; + write_js_repl_test_module( + module_dir.path(), + "absolute.js", + "export const value = \"absolute\";\n", + )?; + let absolute_path_json = + serde_json::to_string(&module_dir.path().join("absolute.js").display().to_string())?; + + let (session, mut turn) = make_session_and_context().await; + turn.shell_environment_policy + .r#set + .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); + turn.cwd = cwd_dir.path().to_path_buf(); + turn.js_repl = Arc::new(JsReplHandle::with_node_path( + turn.config.js_repl_node_path.clone(), + Vec::new(), + )); + + let session = Arc::new(session); + let turn = Arc::new(turn); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + + let result = manager + .execute( + session, + turn, + tracker, + JsReplArgs { + code: format!( + "const mod = await import({absolute_path_json}); console.log(mod.value);" + ), + timeout_ms: Some(10_000), + }, + ) + .await?; + assert!(result.output.contains("absolute")); + Ok(()) +} + +#[tokio::test] +async fn js_repl_imported_local_files_can_access_repl_globals() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let cwd_dir = tempdir()?; + let expected_home_dir = serde_json::to_string("/tmp/codex-home")?; + write_js_repl_test_module( + cwd_dir.path(), + "globals.js", + &format!( + "const expectedHomeDir = {expected_home_dir};\nconsole.log(`tmp:${{codex.tmpDir === tmpDir}}`);\nconsole.log(`cwd:${{typeof codex.cwd}}:${{codex.cwd.length > 0}}`);\nconsole.log(`home:${{codex.homeDir === expectedHomeDir}}`);\nconsole.log(`tool:${{typeof codex.tool}}`);\nconsole.log(\"local-file-console-ok\");\n" + ), + )?; + + let (session, mut turn) = make_session_and_context().await; + session + .set_dependency_env(HashMap::from([( + "HOME".to_string(), + "/tmp/codex-home".to_string(), + )])) + .await; + turn.shell_environment_policy + .r#set + .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); + turn.cwd = cwd_dir.path().to_path_buf(); + turn.js_repl = Arc::new(JsReplHandle::with_node_path( + turn.config.js_repl_node_path.clone(), + Vec::new(), + )); + + let session = Arc::new(session); + let turn = Arc::new(turn); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + + let result = manager + .execute( + session, + turn, + tracker, + JsReplArgs { + code: "await import(\"./globals.js\");".to_string(), + timeout_ms: Some(10_000), + }, + ) + .await?; + assert!(result.output.contains("tmp:true")); + assert!(result.output.contains("cwd:string:true")); + assert!(result.output.contains("home:true")); + assert!(result.output.contains("tool:function")); + assert!(result.output.contains("local-file-console-ok")); + Ok(()) +} + +#[tokio::test] +async fn js_repl_reimports_local_files_after_edit() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let cwd_dir = tempdir()?; + let helper_path = cwd_dir.path().join("helper.js"); + fs::write(&helper_path, "export const value = \"v1\";\n")?; + + let (session, mut turn) = make_session_and_context().await; + turn.shell_environment_policy + .r#set + .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); + turn.cwd = cwd_dir.path().to_path_buf(); + turn.js_repl = Arc::new(JsReplHandle::with_node_path( + turn.config.js_repl_node_path.clone(), + Vec::new(), + )); + + let session = Arc::new(session); + let turn = Arc::new(turn); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + + let first = manager + .execute( + Arc::clone(&session), + Arc::clone(&turn), + Arc::clone(&tracker), + JsReplArgs { + code: "const { value: firstValue } = await import(\"./helper.js\");\nconsole.log(firstValue);".to_string(), + timeout_ms: Some(10_000), + }, + ) + .await?; + assert!(first.output.contains("v1")); + + fs::write(&helper_path, "export const value = \"v2\";\n")?; + + let second = manager + .execute( + session, + turn, + tracker, + JsReplArgs { + code: "console.log(firstValue);\nconst { value: secondValue } = await import(\"./helper.js\");\nconsole.log(secondValue);".to_string(), + timeout_ms: Some(10_000), + }, + ) + .await?; + assert!(second.output.contains("v1")); + assert!(second.output.contains("v2")); + Ok(()) +} + +#[tokio::test] +async fn js_repl_reimports_local_files_after_fixing_failure() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let cwd_dir = tempdir()?; + let helper_path = cwd_dir.path().join("broken.js"); + fs::write(&helper_path, "throw new Error(\"boom\");\n")?; + + let (session, mut turn) = make_session_and_context().await; + turn.shell_environment_policy + .r#set + .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); + turn.cwd = cwd_dir.path().to_path_buf(); + turn.js_repl = Arc::new(JsReplHandle::with_node_path( + turn.config.js_repl_node_path.clone(), + Vec::new(), + )); + + let session = Arc::new(session); + let turn = Arc::new(turn); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + + let err = manager + .execute( + Arc::clone(&session), + Arc::clone(&turn), + Arc::clone(&tracker), + JsReplArgs { + code: "await import(\"./broken.js\");".to_string(), + timeout_ms: Some(10_000), + }, + ) + .await + .expect_err("expected broken module import to fail"); + assert!(err.to_string().contains("boom")); + + fs::write(&helper_path, "export const value = \"fixed\";\n")?; + + let result = manager + .execute( + session, + turn, + tracker, + JsReplArgs { + code: "console.log((await import(\"./broken.js\")).value);".to_string(), + timeout_ms: Some(10_000), + }, + ) + .await?; + assert!(result.output.contains("fixed")); + Ok(()) +} + +#[tokio::test] +async fn js_repl_local_files_expose_node_like_import_meta() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let cwd_dir = tempdir()?; + let pkg_dir = cwd_dir.path().join("node_modules").join("repl_meta_pkg"); + fs::create_dir_all(&pkg_dir)?; + fs::write( + pkg_dir.join("package.json"), + "{\n \"name\": \"repl_meta_pkg\",\n \"version\": \"1.0.0\",\n \"type\": \"module\",\n \"exports\": {\n \"import\": \"./index.js\"\n }\n}\n", + )?; + fs::write( + pkg_dir.join("index.js"), + "import { sep } from \"node:path\";\nexport const value = `pkg:${typeof sep}`;\n", + )?; + write_js_repl_test_module( + cwd_dir.path(), + "child.js", + "export const value = \"child-export\";\n", + )?; + write_js_repl_test_module( + cwd_dir.path(), + "meta.js", + "console.log(import.meta.url);\nconsole.log(import.meta.filename);\nconsole.log(import.meta.dirname);\nconsole.log(import.meta.main);\nconsole.log(import.meta.resolve(\"./child.js\"));\nconsole.log(import.meta.resolve(\"repl_meta_pkg\"));\nconsole.log(import.meta.resolve(\"node:fs\"));\nconsole.log((await import(import.meta.resolve(\"./child.js\"))).value);\nconsole.log((await import(import.meta.resolve(\"repl_meta_pkg\"))).value);\n", + )?; + let child_path = fs::canonicalize(cwd_dir.path().join("child.js"))?; + let child_url = url::Url::from_file_path(&child_path) + .expect("child path should convert to file URL") + .to_string(); + + let (session, mut turn) = make_session_and_context().await; + turn.shell_environment_policy + .r#set + .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); + turn.cwd = cwd_dir.path().to_path_buf(); + turn.js_repl = Arc::new(JsReplHandle::with_node_path( + turn.config.js_repl_node_path.clone(), + Vec::new(), + )); + + let session = Arc::new(session); + let turn = Arc::new(turn); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + + let result = manager + .execute( + session, + turn, + tracker, + JsReplArgs { + code: "await import(\"./meta.js\");".to_string(), + timeout_ms: Some(10_000), + }, + ) + .await?; + let cwd_display = cwd_dir.path().display().to_string(); + let meta_path_display = cwd_dir.path().join("meta.js").display().to_string(); + assert!(result.output.contains("file://")); + assert!(result.output.contains(&meta_path_display)); + assert!(result.output.contains(&cwd_display)); + assert!(result.output.contains("false")); + assert!(result.output.contains(&child_url)); + assert!(result.output.contains("repl_meta_pkg")); + assert!(result.output.contains("node:fs")); + assert!(result.output.contains("child-export")); + assert!(result.output.contains("pkg:string")); + Ok(()) +} + +#[tokio::test] +async fn js_repl_rejects_top_level_static_imports_with_clear_error() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let (session, turn) = make_session_and_context().await; + let session = Arc::new(session); + let turn = Arc::new(turn); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + + let err = manager + .execute( + session, + turn, + tracker, + JsReplArgs { + code: "import \"./local.js\";".to_string(), + timeout_ms: Some(10_000), + }, + ) + .await + .expect_err("expected top-level static import to be rejected"); + assert!( + err.to_string() + .contains("Top-level static import \"./local.js\" is not supported in js_repl") + ); + Ok(()) +} + +#[tokio::test] +async fn js_repl_local_files_reject_static_bare_imports() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let cwd_dir = tempdir()?; + write_js_repl_test_package(cwd_dir.path(), "repl_counter", "pkg")?; + write_js_repl_test_module( + cwd_dir.path(), + "entry.js", + "import { value } from \"repl_counter\";\nconsole.log(value);\n", + )?; + + let (session, mut turn) = make_session_and_context().await; + turn.shell_environment_policy + .r#set + .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); + turn.cwd = cwd_dir.path().to_path_buf(); + turn.js_repl = Arc::new(JsReplHandle::with_node_path( + turn.config.js_repl_node_path.clone(), + Vec::new(), + )); + + let session = Arc::new(session); + let turn = Arc::new(turn); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + + let err = manager + .execute( + session, + turn, + tracker, + JsReplArgs { + code: "await import(\"./entry.js\");".to_string(), + timeout_ms: Some(10_000), + }, + ) + .await + .expect_err("expected static bare import to be rejected"); + assert!( + err.to_string() + .contains("Static import \"repl_counter\" is not supported from js_repl local files") + ); + Ok(()) +} + +#[tokio::test] +async fn js_repl_rejects_unsupported_file_specifiers() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let cwd_dir = tempdir()?; + write_js_repl_test_module(cwd_dir.path(), "local.ts", "export const value = \"ts\";\n")?; + write_js_repl_test_module(cwd_dir.path(), "local", "export const value = \"noext\";\n")?; + fs::create_dir_all(cwd_dir.path().join("dir"))?; + + let (session, mut turn) = make_session_and_context().await; + turn.shell_environment_policy + .r#set + .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); + turn.cwd = cwd_dir.path().to_path_buf(); + turn.js_repl = Arc::new(JsReplHandle::with_node_path( + turn.config.js_repl_node_path.clone(), + Vec::new(), + )); + + let session = Arc::new(session); + let turn = Arc::new(turn); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + + let unsupported_extension = manager + .execute( + Arc::clone(&session), + Arc::clone(&turn), + Arc::clone(&tracker), + JsReplArgs { + code: "await import(\"./local.ts\");".to_string(), + timeout_ms: Some(10_000), + }, + ) + .await + .expect_err("expected unsupported extension to be rejected"); + assert!( + unsupported_extension + .to_string() + .contains("Only .js and .mjs files are supported") + ); + + let extensionless = manager + .execute( + Arc::clone(&session), + Arc::clone(&turn), + Arc::clone(&tracker), + JsReplArgs { + code: "await import(\"./local\");".to_string(), + timeout_ms: Some(10_000), + }, + ) + .await + .expect_err("expected extensionless import to be rejected"); + assert!( + extensionless + .to_string() + .contains("Only .js and .mjs files are supported") + ); + + let directory = manager + .execute( + Arc::clone(&session), + Arc::clone(&turn), + Arc::clone(&tracker), + JsReplArgs { + code: "await import(\"./dir\");".to_string(), + timeout_ms: Some(10_000), + }, + ) + .await + .expect_err("expected directory import to be rejected"); + assert!( + directory + .to_string() + .contains("Directory imports are not supported") + ); + + let unsupported_url = manager + .execute( + session, + turn, + tracker, + JsReplArgs { + code: "await import(\"https://example.com/test.js\");".to_string(), + timeout_ms: Some(10_000), + }, + ) + .await + .expect_err("expected unsupported url import to be rejected"); + assert!( + unsupported_url + .to_string() + .contains("Unsupported import specifier") + ); + Ok(()) +} + +#[tokio::test] +async fn js_repl_blocks_sensitive_builtin_imports_from_local_files() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let cwd_dir = tempdir()?; + write_js_repl_test_module( + cwd_dir.path(), + "blocked.js", + "import process from \"node:process\";\nconsole.log(process.pid);\n", + )?; + + let (session, mut turn) = make_session_and_context().await; + turn.shell_environment_policy + .r#set + .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); + turn.cwd = cwd_dir.path().to_path_buf(); + turn.js_repl = Arc::new(JsReplHandle::with_node_path( + turn.config.js_repl_node_path.clone(), + Vec::new(), + )); + + let session = Arc::new(session); + let turn = Arc::new(turn); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + + let err = manager + .execute( + session, + turn, + tracker, + JsReplArgs { + code: "await import(\"./blocked.js\");".to_string(), + timeout_ms: Some(10_000), + }, + ) + .await + .expect_err("expected blocked builtin import to be rejected"); + assert!( + err.to_string() + .contains("Importing module \"node:process\" is not allowed in js_repl") + ); + Ok(()) +} + +#[tokio::test] +async fn js_repl_local_files_do_not_escape_node_module_search_roots() -> anyhow::Result<()> { + if !can_run_js_repl_runtime_tests().await { + return Ok(()); + } + + let parent_dir = tempdir()?; + write_js_repl_test_package(parent_dir.path(), "repl_probe", "parent")?; + let cwd_dir = parent_dir.path().join("workspace"); + fs::create_dir_all(&cwd_dir)?; + write_js_repl_test_module( + &cwd_dir, + "entry.js", + "const { value } = await import(\"repl_probe\");\nconsole.log(value);\n", + )?; + + let (session, mut turn) = make_session_and_context().await; + turn.shell_environment_policy + .r#set + .remove("CODEX_JS_REPL_NODE_MODULE_DIRS"); + turn.cwd = cwd_dir.clone(); + turn.js_repl = Arc::new(JsReplHandle::with_node_path( + turn.config.js_repl_node_path.clone(), + Vec::new(), + )); + + let session = Arc::new(session); + let turn = Arc::new(turn); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::default())); + let manager = turn.js_repl.manager().await?; + + let err = manager + .execute( + session, + turn, + tracker, + JsReplArgs { + code: "await import(\"./entry.js\");".to_string(), + timeout_ms: Some(10_000), + }, + ) + .await + .expect_err("expected parent node_modules lookup to be rejected"); + assert!(err.to_string().contains("repl_probe")); + Ok(()) +} diff --git a/codex-rs/core/src/tools/network_approval.rs b/codex-rs/core/src/tools/network_approval.rs index 78f1bb0f44..756b643701 100644 --- a/codex-rs/core/src/tools/network_approval.rs +++ b/codex-rs/core/src/tools/network_approval.rs @@ -590,206 +590,5 @@ pub(crate) async fn finish_deferred_network_approval( } #[cfg(test)] -mod tests { - use super::*; - use codex_network_proxy::BlockedRequestArgs; - use codex_protocol::protocol::AskForApproval; - use pretty_assertions::assert_eq; - - #[tokio::test] - async fn pending_approvals_are_deduped_per_host_protocol_and_port() { - let service = NetworkApprovalService::default(); - let key = HostApprovalKey { - host: "example.com".to_string(), - protocol: "http", - port: 443, - }; - - let (first, first_is_owner) = service.get_or_create_pending_approval(key.clone()).await; - let (second, second_is_owner) = service.get_or_create_pending_approval(key).await; - - assert!(first_is_owner); - assert!(!second_is_owner); - assert!(Arc::ptr_eq(&first, &second)); - } - - #[tokio::test] - async fn pending_approvals_do_not_dedupe_across_ports() { - let service = NetworkApprovalService::default(); - let first_key = HostApprovalKey { - host: "example.com".to_string(), - protocol: "https", - port: 443, - }; - let second_key = HostApprovalKey { - host: "example.com".to_string(), - protocol: "https", - port: 8443, - }; - - let (first, first_is_owner) = service.get_or_create_pending_approval(first_key).await; - let (second, second_is_owner) = service.get_or_create_pending_approval(second_key).await; - - assert!(first_is_owner); - assert!(second_is_owner); - assert!(!Arc::ptr_eq(&first, &second)); - } - - #[tokio::test] - async fn session_approved_hosts_preserve_protocol_and_port_scope() { - let source = NetworkApprovalService::default(); - { - let mut approved_hosts = source.session_approved_hosts.lock().await; - approved_hosts.extend([ - HostApprovalKey { - host: "example.com".to_string(), - protocol: "https", - port: 443, - }, - HostApprovalKey { - host: "example.com".to_string(), - protocol: "https", - port: 8443, - }, - HostApprovalKey { - host: "example.com".to_string(), - protocol: "http", - port: 80, - }, - ]); - } - - let seeded = NetworkApprovalService::default(); - source.copy_session_approved_hosts_to(&seeded).await; - - let mut copied = seeded - .session_approved_hosts - .lock() - .await - .iter() - .cloned() - .collect::>(); - copied.sort_by(|a, b| (&a.host, a.protocol, a.port).cmp(&(&b.host, b.protocol, b.port))); - - assert_eq!( - copied, - vec![ - HostApprovalKey { - host: "example.com".to_string(), - protocol: "http", - port: 80, - }, - HostApprovalKey { - host: "example.com".to_string(), - protocol: "https", - port: 443, - }, - HostApprovalKey { - host: "example.com".to_string(), - protocol: "https", - port: 8443, - }, - ] - ); - } - - #[tokio::test] - async fn pending_waiters_receive_owner_decision() { - let pending = Arc::new(PendingHostApproval::new()); - - let waiter = { - let pending = Arc::clone(&pending); - tokio::spawn(async move { pending.wait_for_decision().await }) - }; - - pending - .set_decision(PendingApprovalDecision::AllowOnce) - .await; - - let decision = waiter.await.expect("waiter should complete"); - assert_eq!(decision, PendingApprovalDecision::AllowOnce); - } - - #[test] - fn allow_once_and_allow_for_session_both_allow_network() { - assert_eq!( - PendingApprovalDecision::AllowOnce.to_network_decision(), - NetworkDecision::Allow - ); - assert_eq!( - PendingApprovalDecision::AllowForSession.to_network_decision(), - NetworkDecision::Allow - ); - } - - #[test] - fn only_never_policy_disables_network_approval_flow() { - assert!(!allows_network_approval_flow(AskForApproval::Never)); - assert!(allows_network_approval_flow(AskForApproval::OnRequest)); - assert!(allows_network_approval_flow(AskForApproval::OnFailure)); - assert!(allows_network_approval_flow(AskForApproval::UnlessTrusted)); - } - - fn denied_blocked_request(host: &str) -> BlockedRequest { - BlockedRequest::new(BlockedRequestArgs { - host: host.to_string(), - reason: "not_allowed".to_string(), - client: None, - method: None, - mode: None, - protocol: "http".to_string(), - decision: Some("deny".to_string()), - source: Some("decider".to_string()), - port: Some(80), - }) - } - - #[tokio::test] - async fn record_blocked_request_sets_policy_outcome_for_owner_call() { - let service = NetworkApprovalService::default(); - service.register_call("registration-1".to_string()).await; - - service - .record_blocked_request(denied_blocked_request("example.com")) - .await; - - assert_eq!( - service.take_call_outcome("registration-1").await, - Some(NetworkApprovalOutcome::DeniedByPolicy( - "Network access to \"example.com\" was blocked: domain is not on the allowlist for the current sandbox mode.".to_string() - )) - ); - } - - #[tokio::test] - async fn blocked_request_policy_does_not_override_user_denial_outcome() { - let service = NetworkApprovalService::default(); - service.register_call("registration-1".to_string()).await; - - service - .record_call_outcome("registration-1", NetworkApprovalOutcome::DeniedByUser) - .await; - service - .record_blocked_request(denied_blocked_request("example.com")) - .await; - - assert_eq!( - service.take_call_outcome("registration-1").await, - Some(NetworkApprovalOutcome::DeniedByUser) - ); - } - - #[tokio::test] - async fn record_blocked_request_ignores_ambiguous_unattributed_blocked_requests() { - let service = NetworkApprovalService::default(); - service.register_call("registration-1".to_string()).await; - service.register_call("registration-2".to_string()).await; - - service - .record_blocked_request(denied_blocked_request("example.com")) - .await; - - assert_eq!(service.take_call_outcome("registration-1").await, None); - assert_eq!(service.take_call_outcome("registration-2").await, None); - } -} +#[path = "network_approval_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/network_approval_tests.rs b/codex-rs/core/src/tools/network_approval_tests.rs new file mode 100644 index 0000000000..991d07f042 --- /dev/null +++ b/codex-rs/core/src/tools/network_approval_tests.rs @@ -0,0 +1,201 @@ +use super::*; +use codex_network_proxy::BlockedRequestArgs; +use codex_protocol::protocol::AskForApproval; +use pretty_assertions::assert_eq; + +#[tokio::test] +async fn pending_approvals_are_deduped_per_host_protocol_and_port() { + let service = NetworkApprovalService::default(); + let key = HostApprovalKey { + host: "example.com".to_string(), + protocol: "http", + port: 443, + }; + + let (first, first_is_owner) = service.get_or_create_pending_approval(key.clone()).await; + let (second, second_is_owner) = service.get_or_create_pending_approval(key).await; + + assert!(first_is_owner); + assert!(!second_is_owner); + assert!(Arc::ptr_eq(&first, &second)); +} + +#[tokio::test] +async fn pending_approvals_do_not_dedupe_across_ports() { + let service = NetworkApprovalService::default(); + let first_key = HostApprovalKey { + host: "example.com".to_string(), + protocol: "https", + port: 443, + }; + let second_key = HostApprovalKey { + host: "example.com".to_string(), + protocol: "https", + port: 8443, + }; + + let (first, first_is_owner) = service.get_or_create_pending_approval(first_key).await; + let (second, second_is_owner) = service.get_or_create_pending_approval(second_key).await; + + assert!(first_is_owner); + assert!(second_is_owner); + assert!(!Arc::ptr_eq(&first, &second)); +} + +#[tokio::test] +async fn session_approved_hosts_preserve_protocol_and_port_scope() { + let source = NetworkApprovalService::default(); + { + let mut approved_hosts = source.session_approved_hosts.lock().await; + approved_hosts.extend([ + HostApprovalKey { + host: "example.com".to_string(), + protocol: "https", + port: 443, + }, + HostApprovalKey { + host: "example.com".to_string(), + protocol: "https", + port: 8443, + }, + HostApprovalKey { + host: "example.com".to_string(), + protocol: "http", + port: 80, + }, + ]); + } + + let seeded = NetworkApprovalService::default(); + source.copy_session_approved_hosts_to(&seeded).await; + + let mut copied = seeded + .session_approved_hosts + .lock() + .await + .iter() + .cloned() + .collect::>(); + copied.sort_by(|a, b| (&a.host, a.protocol, a.port).cmp(&(&b.host, b.protocol, b.port))); + + assert_eq!( + copied, + vec![ + HostApprovalKey { + host: "example.com".to_string(), + protocol: "http", + port: 80, + }, + HostApprovalKey { + host: "example.com".to_string(), + protocol: "https", + port: 443, + }, + HostApprovalKey { + host: "example.com".to_string(), + protocol: "https", + port: 8443, + }, + ] + ); +} + +#[tokio::test] +async fn pending_waiters_receive_owner_decision() { + let pending = Arc::new(PendingHostApproval::new()); + + let waiter = { + let pending = Arc::clone(&pending); + tokio::spawn(async move { pending.wait_for_decision().await }) + }; + + pending + .set_decision(PendingApprovalDecision::AllowOnce) + .await; + + let decision = waiter.await.expect("waiter should complete"); + assert_eq!(decision, PendingApprovalDecision::AllowOnce); +} + +#[test] +fn allow_once_and_allow_for_session_both_allow_network() { + assert_eq!( + PendingApprovalDecision::AllowOnce.to_network_decision(), + NetworkDecision::Allow + ); + assert_eq!( + PendingApprovalDecision::AllowForSession.to_network_decision(), + NetworkDecision::Allow + ); +} + +#[test] +fn only_never_policy_disables_network_approval_flow() { + assert!(!allows_network_approval_flow(AskForApproval::Never)); + assert!(allows_network_approval_flow(AskForApproval::OnRequest)); + assert!(allows_network_approval_flow(AskForApproval::OnFailure)); + assert!(allows_network_approval_flow(AskForApproval::UnlessTrusted)); +} + +fn denied_blocked_request(host: &str) -> BlockedRequest { + BlockedRequest::new(BlockedRequestArgs { + host: host.to_string(), + reason: "not_allowed".to_string(), + client: None, + method: None, + mode: None, + protocol: "http".to_string(), + decision: Some("deny".to_string()), + source: Some("decider".to_string()), + port: Some(80), + }) +} + +#[tokio::test] +async fn record_blocked_request_sets_policy_outcome_for_owner_call() { + let service = NetworkApprovalService::default(); + service.register_call("registration-1".to_string()).await; + + service + .record_blocked_request(denied_blocked_request("example.com")) + .await; + + assert_eq!( + service.take_call_outcome("registration-1").await, + Some(NetworkApprovalOutcome::DeniedByPolicy( + "Network access to \"example.com\" was blocked: domain is not on the allowlist for the current sandbox mode.".to_string() + )) + ); +} + +#[tokio::test] +async fn blocked_request_policy_does_not_override_user_denial_outcome() { + let service = NetworkApprovalService::default(); + service.register_call("registration-1".to_string()).await; + + service + .record_call_outcome("registration-1", NetworkApprovalOutcome::DeniedByUser) + .await; + service + .record_blocked_request(denied_blocked_request("example.com")) + .await; + + assert_eq!( + service.take_call_outcome("registration-1").await, + Some(NetworkApprovalOutcome::DeniedByUser) + ); +} + +#[tokio::test] +async fn record_blocked_request_ignores_ambiguous_unattributed_blocked_requests() { + let service = NetworkApprovalService::default(); + service.register_call("registration-1".to_string()).await; + service.register_call("registration-2".to_string()).await; + + service + .record_blocked_request(denied_blocked_request("example.com")) + .await; + + assert_eq!(service.take_call_outcome("registration-1").await, None); + assert_eq!(service.take_call_outcome("registration-2").await, None); +} diff --git a/codex-rs/core/src/tools/registry.rs b/codex-rs/core/src/tools/registry.rs index fcf4921922..e06e3442a6 100644 --- a/codex-rs/core/src/tools/registry.rs +++ b/codex-rs/core/src/tools/registry.rs @@ -538,58 +538,5 @@ async fn dispatch_after_tool_use_hook( } #[cfg(test)] -mod tests { - use super::*; - use crate::tools::context::ToolInvocation; - use async_trait::async_trait; - use pretty_assertions::assert_eq; - - struct TestHandler; - - #[async_trait] - impl ToolHandler for TestHandler { - type Output = crate::tools::context::FunctionToolOutput; - - fn kind(&self) -> ToolKind { - ToolKind::Function - } - - async fn handle( - &self, - _invocation: ToolInvocation, - ) -> Result { - unreachable!("test handler should not be invoked") - } - } - - #[test] - fn handler_looks_up_namespaced_aliases_explicitly() { - let plain_handler = Arc::new(TestHandler) as Arc; - let namespaced_handler = Arc::new(TestHandler) as Arc; - let namespace = "mcp__codex_apps__gmail"; - let tool_name = "gmail_get_recent_emails"; - let namespaced_name = tool_handler_key(tool_name, Some(namespace)); - let registry = ToolRegistry::new(HashMap::from([ - (tool_name.to_string(), Arc::clone(&plain_handler)), - (namespaced_name, Arc::clone(&namespaced_handler)), - ])); - - let plain = registry.handler(tool_name, None); - let namespaced = registry.handler(tool_name, Some(namespace)); - let missing_namespaced = registry.handler(tool_name, Some("mcp__codex_apps__calendar")); - - assert_eq!(plain.is_some(), true); - assert_eq!(namespaced.is_some(), true); - assert_eq!(missing_namespaced.is_none(), true); - assert!( - plain - .as_ref() - .is_some_and(|handler| Arc::ptr_eq(handler, &plain_handler)) - ); - assert!( - namespaced - .as_ref() - .is_some_and(|handler| Arc::ptr_eq(handler, &namespaced_handler)) - ); - } -} +#[path = "registry_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/registry_tests.rs b/codex-rs/core/src/tools/registry_tests.rs new file mode 100644 index 0000000000..5d9e98df35 --- /dev/null +++ b/codex-rs/core/src/tools/registry_tests.rs @@ -0,0 +1,50 @@ +use super::*; +use crate::tools::context::ToolInvocation; +use async_trait::async_trait; +use pretty_assertions::assert_eq; + +struct TestHandler; + +#[async_trait] +impl ToolHandler for TestHandler { + type Output = crate::tools::context::FunctionToolOutput; + + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + async fn handle(&self, _invocation: ToolInvocation) -> Result { + unreachable!("test handler should not be invoked") + } +} + +#[test] +fn handler_looks_up_namespaced_aliases_explicitly() { + let plain_handler = Arc::new(TestHandler) as Arc; + let namespaced_handler = Arc::new(TestHandler) as Arc; + let namespace = "mcp__codex_apps__gmail"; + let tool_name = "gmail_get_recent_emails"; + let namespaced_name = tool_handler_key(tool_name, Some(namespace)); + let registry = ToolRegistry::new(HashMap::from([ + (tool_name.to_string(), Arc::clone(&plain_handler)), + (namespaced_name, Arc::clone(&namespaced_handler)), + ])); + + let plain = registry.handler(tool_name, None); + let namespaced = registry.handler(tool_name, Some(namespace)); + let missing_namespaced = registry.handler(tool_name, Some("mcp__codex_apps__calendar")); + + assert_eq!(plain.is_some(), true); + assert_eq!(namespaced.is_some(), true); + assert_eq!(missing_namespaced.is_none(), true); + assert!( + plain + .as_ref() + .is_some_and(|handler| Arc::ptr_eq(handler, &plain_handler)) + ); + assert!( + namespaced + .as_ref() + .is_some_and(|handler| Arc::ptr_eq(handler, &namespaced_handler)) + ); +} diff --git a/codex-rs/core/src/tools/router.rs b/codex-rs/core/src/tools/router.rs index d311d00702..9d8381c621 100644 --- a/codex-rs/core/src/tools/router.rs +++ b/codex-rs/core/src/tools/router.rs @@ -290,166 +290,5 @@ impl ToolRouter { } } #[cfg(test)] -mod tests { - use std::sync::Arc; - - use crate::codex::make_session_and_context; - use crate::tools::context::ToolPayload; - use crate::turn_diff_tracker::TurnDiffTracker; - use codex_protocol::models::ResponseInputItem; - use codex_protocol::models::ResponseItem; - - use super::ToolCall; - use super::ToolCallSource; - use super::ToolRouter; - use super::ToolRouterParams; - - #[tokio::test] - async fn js_repl_tools_only_blocks_direct_tool_calls() -> anyhow::Result<()> { - let (session, mut turn) = make_session_and_context().await; - turn.tools_config.js_repl_tools_only = true; - - let session = Arc::new(session); - let turn = Arc::new(turn); - let mcp_tools = session - .services - .mcp_connection_manager - .read() - .await - .list_all_tools() - .await; - let app_tools = Some(mcp_tools.clone()); - let router = ToolRouter::from_config( - &turn.tools_config, - ToolRouterParams { - mcp_tools: Some( - mcp_tools - .into_iter() - .map(|(name, tool)| (name, tool.tool)) - .collect(), - ), - app_tools, - discoverable_tools: None, - dynamic_tools: turn.dynamic_tools.as_slice(), - }, - ); - - let call = ToolCall { - tool_name: "shell".to_string(), - tool_namespace: None, - call_id: "call-1".to_string(), - payload: ToolPayload::Function { - arguments: "{}".to_string(), - }, - }; - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); - let response = router - .dispatch_tool_call(session, turn, tracker, call, ToolCallSource::Direct) - .await?; - - match response { - ResponseInputItem::FunctionCallOutput { output, .. } => { - let content = output.text_content().unwrap_or_default(); - assert!( - content.contains("direct tool calls are disabled"), - "unexpected tool call message: {content}", - ); - } - other => panic!("expected function call output, got {other:?}"), - } - - Ok(()) - } - - #[tokio::test] - async fn js_repl_tools_only_allows_js_repl_source_calls() -> anyhow::Result<()> { - let (session, mut turn) = make_session_and_context().await; - turn.tools_config.js_repl_tools_only = true; - - let session = Arc::new(session); - let turn = Arc::new(turn); - let mcp_tools = session - .services - .mcp_connection_manager - .read() - .await - .list_all_tools() - .await; - let app_tools = Some(mcp_tools.clone()); - let router = ToolRouter::from_config( - &turn.tools_config, - ToolRouterParams { - mcp_tools: Some( - mcp_tools - .into_iter() - .map(|(name, tool)| (name, tool.tool)) - .collect(), - ), - app_tools, - discoverable_tools: None, - dynamic_tools: turn.dynamic_tools.as_slice(), - }, - ); - - let call = ToolCall { - tool_name: "shell".to_string(), - tool_namespace: None, - call_id: "call-2".to_string(), - payload: ToolPayload::Function { - arguments: "{}".to_string(), - }, - }; - let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); - let response = router - .dispatch_tool_call(session, turn, tracker, call, ToolCallSource::JsRepl) - .await?; - - match response { - ResponseInputItem::FunctionCallOutput { output, .. } => { - let content = output.text_content().unwrap_or_default(); - assert!( - !content.contains("direct tool calls are disabled"), - "js_repl source should bypass direct-call policy gate" - ); - } - other => panic!("expected function call output, got {other:?}"), - } - - Ok(()) - } - - #[tokio::test] - async fn build_tool_call_uses_namespace_for_registry_name() -> anyhow::Result<()> { - let (session, _) = make_session_and_context().await; - let session = Arc::new(session); - let tool_name = "create_event".to_string(); - - let call = ToolRouter::build_tool_call( - &session, - ResponseItem::FunctionCall { - id: None, - name: tool_name.clone(), - namespace: Some("mcp__codex_apps__calendar".to_string()), - arguments: "{}".to_string(), - call_id: "call-namespace".to_string(), - }, - ) - .await? - .expect("function_call should produce a tool call"); - - assert_eq!(call.tool_name, tool_name); - assert_eq!( - call.tool_namespace, - Some("mcp__codex_apps__calendar".to_string()) - ); - assert_eq!(call.call_id, "call-namespace"); - match call.payload { - ToolPayload::Function { arguments } => { - assert_eq!(arguments, "{}"); - } - other => panic!("expected function payload, got {other:?}"), - } - - Ok(()) - } -} +#[path = "router_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/router_tests.rs b/codex-rs/core/src/tools/router_tests.rs new file mode 100644 index 0000000000..6350323d1b --- /dev/null +++ b/codex-rs/core/src/tools/router_tests.rs @@ -0,0 +1,161 @@ +use std::sync::Arc; + +use crate::codex::make_session_and_context; +use crate::tools::context::ToolPayload; +use crate::turn_diff_tracker::TurnDiffTracker; +use codex_protocol::models::ResponseInputItem; +use codex_protocol::models::ResponseItem; + +use super::ToolCall; +use super::ToolCallSource; +use super::ToolRouter; +use super::ToolRouterParams; + +#[tokio::test] +async fn js_repl_tools_only_blocks_direct_tool_calls() -> anyhow::Result<()> { + let (session, mut turn) = make_session_and_context().await; + turn.tools_config.js_repl_tools_only = true; + + let session = Arc::new(session); + let turn = Arc::new(turn); + let mcp_tools = session + .services + .mcp_connection_manager + .read() + .await + .list_all_tools() + .await; + let app_tools = Some(mcp_tools.clone()); + let router = ToolRouter::from_config( + &turn.tools_config, + ToolRouterParams { + mcp_tools: Some( + mcp_tools + .into_iter() + .map(|(name, tool)| (name, tool.tool)) + .collect(), + ), + app_tools, + discoverable_tools: None, + dynamic_tools: turn.dynamic_tools.as_slice(), + }, + ); + + let call = ToolCall { + tool_name: "shell".to_string(), + tool_namespace: None, + call_id: "call-1".to_string(), + payload: ToolPayload::Function { + arguments: "{}".to_string(), + }, + }; + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); + let response = router + .dispatch_tool_call(session, turn, tracker, call, ToolCallSource::Direct) + .await?; + + match response { + ResponseInputItem::FunctionCallOutput { output, .. } => { + let content = output.text_content().unwrap_or_default(); + assert!( + content.contains("direct tool calls are disabled"), + "unexpected tool call message: {content}", + ); + } + other => panic!("expected function call output, got {other:?}"), + } + + Ok(()) +} + +#[tokio::test] +async fn js_repl_tools_only_allows_js_repl_source_calls() -> anyhow::Result<()> { + let (session, mut turn) = make_session_and_context().await; + turn.tools_config.js_repl_tools_only = true; + + let session = Arc::new(session); + let turn = Arc::new(turn); + let mcp_tools = session + .services + .mcp_connection_manager + .read() + .await + .list_all_tools() + .await; + let app_tools = Some(mcp_tools.clone()); + let router = ToolRouter::from_config( + &turn.tools_config, + ToolRouterParams { + mcp_tools: Some( + mcp_tools + .into_iter() + .map(|(name, tool)| (name, tool.tool)) + .collect(), + ), + app_tools, + discoverable_tools: None, + dynamic_tools: turn.dynamic_tools.as_slice(), + }, + ); + + let call = ToolCall { + tool_name: "shell".to_string(), + tool_namespace: None, + call_id: "call-2".to_string(), + payload: ToolPayload::Function { + arguments: "{}".to_string(), + }, + }; + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); + let response = router + .dispatch_tool_call(session, turn, tracker, call, ToolCallSource::JsRepl) + .await?; + + match response { + ResponseInputItem::FunctionCallOutput { output, .. } => { + let content = output.text_content().unwrap_or_default(); + assert!( + !content.contains("direct tool calls are disabled"), + "js_repl source should bypass direct-call policy gate" + ); + } + other => panic!("expected function call output, got {other:?}"), + } + + Ok(()) +} + +#[tokio::test] +async fn build_tool_call_uses_namespace_for_registry_name() -> anyhow::Result<()> { + let (session, _) = make_session_and_context().await; + let session = Arc::new(session); + let tool_name = "create_event".to_string(); + + let call = ToolRouter::build_tool_call( + &session, + ResponseItem::FunctionCall { + id: None, + name: tool_name.clone(), + namespace: Some("mcp__codex_apps__calendar".to_string()), + arguments: "{}".to_string(), + call_id: "call-namespace".to_string(), + }, + ) + .await? + .expect("function_call should produce a tool call"); + + assert_eq!(call.tool_name, tool_name); + assert_eq!( + call.tool_namespace, + Some("mcp__codex_apps__calendar".to_string()) + ); + assert_eq!(call.call_id, "call-namespace"); + match call.payload { + ToolPayload::Function { arguments } => { + assert_eq!(arguments, "{}"); + } + other => panic!("expected function payload, got {other:?}"), + } + + Ok(()) +} diff --git a/codex-rs/core/src/tools/runtimes/apply_patch.rs b/codex-rs/core/src/tools/runtimes/apply_patch.rs index 18a82bd948..6cf1a4fa05 100644 --- a/codex-rs/core/src/tools/runtimes/apply_patch.rs +++ b/codex-rs/core/src/tools/runtimes/apply_patch.rs @@ -204,74 +204,5 @@ impl ToolRuntime for ApplyPatchRuntime { } #[cfg(test)] -mod tests { - use super::*; - use codex_protocol::protocol::RejectConfig; - use pretty_assertions::assert_eq; - use std::collections::HashMap; - - #[test] - fn wants_no_sandbox_approval_reject_respects_sandbox_flag() { - let runtime = ApplyPatchRuntime::new(); - assert!(runtime.wants_no_sandbox_approval(AskForApproval::OnRequest)); - assert!( - !runtime.wants_no_sandbox_approval(AskForApproval::Reject(RejectConfig { - sandbox_approval: true, - rules: false, - skill_approval: false, - request_permissions: false, - mcp_elicitations: false, - })) - ); - assert!( - runtime.wants_no_sandbox_approval(AskForApproval::Reject(RejectConfig { - sandbox_approval: false, - rules: false, - skill_approval: false, - request_permissions: false, - mcp_elicitations: false, - })) - ); - } - - #[test] - fn guardian_review_request_includes_full_patch_without_duplicate_changes() { - let path = std::env::temp_dir().join("guardian-apply-patch-test.txt"); - let action = ApplyPatchAction::new_add_for_test(&path, "hello".to_string()); - let expected_cwd = action.cwd.clone(); - let expected_patch = action.patch.clone(); - let request = ApplyPatchRequest { - action, - file_paths: vec![ - AbsolutePathBuf::from_absolute_path(&path).expect("temp path should be absolute"), - ], - changes: HashMap::from([( - path, - FileChange::Add { - content: "hello".to_string(), - }, - )]), - exec_approval_requirement: ExecApprovalRequirement::NeedsApproval { - reason: None, - proposed_execpolicy_amendment: None, - }, - sandbox_permissions: SandboxPermissions::UseDefault, - additional_permissions: None, - permissions_preapproved: false, - timeout_ms: None, - codex_exe: None, - }; - - let guardian_request = ApplyPatchRuntime::build_guardian_review_request(&request); - - assert_eq!( - guardian_request, - GuardianApprovalRequest::ApplyPatch { - cwd: expected_cwd, - files: request.file_paths, - change_count: 1usize, - patch: expected_patch, - } - ); - } -} +#[path = "apply_patch_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/runtimes/apply_patch_tests.rs b/codex-rs/core/src/tools/runtimes/apply_patch_tests.rs new file mode 100644 index 0000000000..c598162760 --- /dev/null +++ b/codex-rs/core/src/tools/runtimes/apply_patch_tests.rs @@ -0,0 +1,69 @@ +use super::*; +use codex_protocol::protocol::RejectConfig; +use pretty_assertions::assert_eq; +use std::collections::HashMap; + +#[test] +fn wants_no_sandbox_approval_reject_respects_sandbox_flag() { + let runtime = ApplyPatchRuntime::new(); + assert!(runtime.wants_no_sandbox_approval(AskForApproval::OnRequest)); + assert!( + !runtime.wants_no_sandbox_approval(AskForApproval::Reject(RejectConfig { + sandbox_approval: true, + rules: false, + skill_approval: false, + request_permissions: false, + mcp_elicitations: false, + })) + ); + assert!( + runtime.wants_no_sandbox_approval(AskForApproval::Reject(RejectConfig { + sandbox_approval: false, + rules: false, + skill_approval: false, + request_permissions: false, + mcp_elicitations: false, + })) + ); +} + +#[test] +fn guardian_review_request_includes_full_patch_without_duplicate_changes() { + let path = std::env::temp_dir().join("guardian-apply-patch-test.txt"); + let action = ApplyPatchAction::new_add_for_test(&path, "hello".to_string()); + let expected_cwd = action.cwd.clone(); + let expected_patch = action.patch.clone(); + let request = ApplyPatchRequest { + action, + file_paths: vec![ + AbsolutePathBuf::from_absolute_path(&path).expect("temp path should be absolute"), + ], + changes: HashMap::from([( + path, + FileChange::Add { + content: "hello".to_string(), + }, + )]), + exec_approval_requirement: ExecApprovalRequirement::NeedsApproval { + reason: None, + proposed_execpolicy_amendment: None, + }, + sandbox_permissions: SandboxPermissions::UseDefault, + additional_permissions: None, + permissions_preapproved: false, + timeout_ms: None, + codex_exe: None, + }; + + let guardian_request = ApplyPatchRuntime::build_guardian_review_request(&request); + + assert_eq!( + guardian_request, + GuardianApprovalRequest::ApplyPatch { + cwd: expected_cwd, + files: request.file_paths, + change_count: 1usize, + patch: expected_patch, + } + ); +} diff --git a/codex-rs/core/src/tools/runtimes/mod.rs b/codex-rs/core/src/tools/runtimes/mod.rs index 51002f3ce9..8003819a84 100644 --- a/codex-rs/core/src/tools/runtimes/mod.rs +++ b/codex-rs/core/src/tools/runtimes/mod.rs @@ -177,437 +177,5 @@ fn shell_single_quote(input: &str) -> String { } #[cfg(all(test, unix))] -mod tests { - use super::*; - use crate::shell::ShellType; - use crate::shell_snapshot::ShellSnapshot; - use pretty_assertions::assert_eq; - use std::path::PathBuf; - use std::process::Command; - use std::sync::Arc; - use tempfile::tempdir; - use tokio::sync::watch; - - fn shell_with_snapshot( - shell_type: ShellType, - shell_path: &str, - snapshot_path: PathBuf, - snapshot_cwd: PathBuf, - ) -> Shell { - let (_tx, shell_snapshot) = watch::channel(Some(Arc::new(ShellSnapshot { - path: snapshot_path, - cwd: snapshot_cwd, - }))); - Shell { - shell_type, - shell_path: PathBuf::from(shell_path), - shell_snapshot, - } - } - - #[test] - fn maybe_wrap_shell_lc_with_snapshot_bootstraps_in_user_shell() { - let dir = tempdir().expect("create temp dir"); - let snapshot_path = dir.path().join("snapshot.sh"); - std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot"); - let session_shell = shell_with_snapshot( - ShellType::Zsh, - "/bin/zsh", - snapshot_path, - dir.path().to_path_buf(), - ); - let command = vec![ - "/bin/bash".to_string(), - "-lc".to_string(), - "echo hello".to_string(), - ]; - - let rewritten = maybe_wrap_shell_lc_with_snapshot( - &command, - &session_shell, - dir.path(), - &HashMap::new(), - ); - - assert_eq!(rewritten[0], "/bin/zsh"); - assert_eq!(rewritten[1], "-c"); - assert!(rewritten[2].contains("if . '")); - assert!(rewritten[2].contains("exec '/bin/bash' -c 'echo hello'")); - } - - #[test] - fn maybe_wrap_shell_lc_with_snapshot_escapes_single_quotes() { - let dir = tempdir().expect("create temp dir"); - let snapshot_path = dir.path().join("snapshot.sh"); - std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot"); - let session_shell = shell_with_snapshot( - ShellType::Zsh, - "/bin/zsh", - snapshot_path, - dir.path().to_path_buf(), - ); - let command = vec![ - "/bin/bash".to_string(), - "-lc".to_string(), - "echo 'hello'".to_string(), - ]; - - let rewritten = maybe_wrap_shell_lc_with_snapshot( - &command, - &session_shell, - dir.path(), - &HashMap::new(), - ); - - assert!(rewritten[2].contains(r#"exec '/bin/bash' -c 'echo '"'"'hello'"'"''"#)); - } - - #[test] - fn maybe_wrap_shell_lc_with_snapshot_uses_bash_bootstrap_shell() { - let dir = tempdir().expect("create temp dir"); - let snapshot_path = dir.path().join("snapshot.sh"); - std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot"); - let session_shell = shell_with_snapshot( - ShellType::Bash, - "/bin/bash", - snapshot_path, - dir.path().to_path_buf(), - ); - let command = vec![ - "/bin/zsh".to_string(), - "-lc".to_string(), - "echo hello".to_string(), - ]; - - let rewritten = maybe_wrap_shell_lc_with_snapshot( - &command, - &session_shell, - dir.path(), - &HashMap::new(), - ); - - assert_eq!(rewritten[0], "/bin/bash"); - assert_eq!(rewritten[1], "-c"); - assert!(rewritten[2].contains("if . '")); - assert!(rewritten[2].contains("exec '/bin/zsh' -c 'echo hello'")); - } - - #[test] - fn maybe_wrap_shell_lc_with_snapshot_uses_sh_bootstrap_shell() { - let dir = tempdir().expect("create temp dir"); - let snapshot_path = dir.path().join("snapshot.sh"); - std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot"); - let session_shell = shell_with_snapshot( - ShellType::Sh, - "/bin/sh", - snapshot_path, - dir.path().to_path_buf(), - ); - let command = vec![ - "/bin/bash".to_string(), - "-lc".to_string(), - "echo hello".to_string(), - ]; - - let rewritten = maybe_wrap_shell_lc_with_snapshot( - &command, - &session_shell, - dir.path(), - &HashMap::new(), - ); - - assert_eq!(rewritten[0], "/bin/sh"); - assert_eq!(rewritten[1], "-c"); - assert!(rewritten[2].contains("if . '")); - assert!(rewritten[2].contains("exec '/bin/bash' -c 'echo hello'")); - } - - #[test] - fn maybe_wrap_shell_lc_with_snapshot_preserves_trailing_args() { - let dir = tempdir().expect("create temp dir"); - let snapshot_path = dir.path().join("snapshot.sh"); - std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot"); - let session_shell = shell_with_snapshot( - ShellType::Zsh, - "/bin/zsh", - snapshot_path, - dir.path().to_path_buf(), - ); - let command = vec![ - "/bin/bash".to_string(), - "-lc".to_string(), - "printf '%s %s' \"$0\" \"$1\"".to_string(), - "arg0".to_string(), - "arg1".to_string(), - ]; - - let rewritten = maybe_wrap_shell_lc_with_snapshot( - &command, - &session_shell, - dir.path(), - &HashMap::new(), - ); - - assert!( - rewritten[2].contains( - r#"exec '/bin/bash' -c 'printf '"'"'%s %s'"'"' "$0" "$1"' 'arg0' 'arg1'"# - ) - ); - } - - #[test] - fn maybe_wrap_shell_lc_with_snapshot_skips_when_cwd_mismatch() { - let dir = tempdir().expect("create temp dir"); - let snapshot_path = dir.path().join("snapshot.sh"); - std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot"); - let snapshot_cwd = dir.path().join("worktree-a"); - let command_cwd = dir.path().join("worktree-b"); - std::fs::create_dir_all(&snapshot_cwd).expect("create snapshot cwd"); - std::fs::create_dir_all(&command_cwd).expect("create command cwd"); - let session_shell = - shell_with_snapshot(ShellType::Zsh, "/bin/zsh", snapshot_path, snapshot_cwd); - let command = vec![ - "/bin/bash".to_string(), - "-lc".to_string(), - "echo hello".to_string(), - ]; - - let rewritten = maybe_wrap_shell_lc_with_snapshot( - &command, - &session_shell, - &command_cwd, - &HashMap::new(), - ); - - assert_eq!(rewritten, command); - } - - #[test] - fn maybe_wrap_shell_lc_with_snapshot_accepts_dot_alias_cwd() { - let dir = tempdir().expect("create temp dir"); - let snapshot_path = dir.path().join("snapshot.sh"); - std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot"); - let session_shell = shell_with_snapshot( - ShellType::Zsh, - "/bin/zsh", - snapshot_path, - dir.path().to_path_buf(), - ); - let command = vec![ - "/bin/bash".to_string(), - "-lc".to_string(), - "echo hello".to_string(), - ]; - let command_cwd = dir.path().join("."); - - let rewritten = maybe_wrap_shell_lc_with_snapshot( - &command, - &session_shell, - &command_cwd, - &HashMap::new(), - ); - - assert_eq!(rewritten[0], "/bin/zsh"); - assert_eq!(rewritten[1], "-c"); - assert!(rewritten[2].contains("if . '")); - assert!(rewritten[2].contains("exec '/bin/bash' -c 'echo hello'")); - } - - #[test] - fn maybe_wrap_shell_lc_with_snapshot_restores_explicit_override_precedence() { - let dir = tempdir().expect("create temp dir"); - let snapshot_path = dir.path().join("snapshot.sh"); - std::fs::write( - &snapshot_path, - "# Snapshot file\nexport TEST_ENV_SNAPSHOT=global\nexport SNAPSHOT_ONLY=from_snapshot\n", - ) - .expect("write snapshot"); - let session_shell = shell_with_snapshot( - ShellType::Bash, - "/bin/bash", - snapshot_path, - dir.path().to_path_buf(), - ); - let command = vec![ - "/bin/bash".to_string(), - "-lc".to_string(), - "printf '%s|%s' \"$TEST_ENV_SNAPSHOT\" \"${SNAPSHOT_ONLY-unset}\"".to_string(), - ]; - let explicit_env_overrides = - HashMap::from([("TEST_ENV_SNAPSHOT".to_string(), "worktree".to_string())]); - let rewritten = maybe_wrap_shell_lc_with_snapshot( - &command, - &session_shell, - dir.path(), - &explicit_env_overrides, - ); - let output = Command::new(&rewritten[0]) - .args(&rewritten[1..]) - .env("TEST_ENV_SNAPSHOT", "worktree") - .output() - .expect("run rewritten command"); - - assert!(output.status.success(), "command failed: {output:?}"); - assert_eq!( - String::from_utf8_lossy(&output.stdout), - "worktree|from_snapshot" - ); - } - - #[test] - fn maybe_wrap_shell_lc_with_snapshot_keeps_snapshot_path_without_override() { - let dir = tempdir().expect("create temp dir"); - let snapshot_path = dir.path().join("snapshot.sh"); - std::fs::write( - &snapshot_path, - "# Snapshot file\nexport PATH='/snapshot/bin'\n", - ) - .expect("write snapshot"); - let session_shell = shell_with_snapshot( - ShellType::Bash, - "/bin/bash", - snapshot_path, - dir.path().to_path_buf(), - ); - let command = vec![ - "/bin/bash".to_string(), - "-lc".to_string(), - "printf '%s' \"$PATH\"".to_string(), - ]; - let rewritten = maybe_wrap_shell_lc_with_snapshot( - &command, - &session_shell, - dir.path(), - &HashMap::new(), - ); - let output = Command::new(&rewritten[0]) - .args(&rewritten[1..]) - .output() - .expect("run rewritten command"); - - assert!(output.status.success(), "command failed: {output:?}"); - assert_eq!(String::from_utf8_lossy(&output.stdout), "/snapshot/bin"); - } - - #[test] - fn maybe_wrap_shell_lc_with_snapshot_applies_explicit_path_override() { - let dir = tempdir().expect("create temp dir"); - let snapshot_path = dir.path().join("snapshot.sh"); - std::fs::write( - &snapshot_path, - "# Snapshot file\nexport PATH='/snapshot/bin'\n", - ) - .expect("write snapshot"); - let session_shell = shell_with_snapshot( - ShellType::Bash, - "/bin/bash", - snapshot_path, - dir.path().to_path_buf(), - ); - let command = vec![ - "/bin/bash".to_string(), - "-lc".to_string(), - "printf '%s' \"$PATH\"".to_string(), - ]; - let explicit_env_overrides = - HashMap::from([("PATH".to_string(), "/worktree/bin".to_string())]); - let rewritten = maybe_wrap_shell_lc_with_snapshot( - &command, - &session_shell, - dir.path(), - &explicit_env_overrides, - ); - let output = Command::new(&rewritten[0]) - .args(&rewritten[1..]) - .env("PATH", "/worktree/bin") - .output() - .expect("run rewritten command"); - - assert!(output.status.success(), "command failed: {output:?}"); - assert_eq!(String::from_utf8_lossy(&output.stdout), "/worktree/bin"); - } - - #[test] - fn maybe_wrap_shell_lc_with_snapshot_does_not_embed_override_values_in_argv() { - let dir = tempdir().expect("create temp dir"); - let snapshot_path = dir.path().join("snapshot.sh"); - std::fs::write( - &snapshot_path, - "# Snapshot file\nexport OPENAI_API_KEY='snapshot-value'\n", - ) - .expect("write snapshot"); - let session_shell = shell_with_snapshot( - ShellType::Bash, - "/bin/bash", - snapshot_path, - dir.path().to_path_buf(), - ); - let command = vec![ - "/bin/bash".to_string(), - "-lc".to_string(), - "printf '%s' \"$OPENAI_API_KEY\"".to_string(), - ]; - let explicit_env_overrides = HashMap::from([( - "OPENAI_API_KEY".to_string(), - "super-secret-value".to_string(), - )]); - let rewritten = maybe_wrap_shell_lc_with_snapshot( - &command, - &session_shell, - dir.path(), - &explicit_env_overrides, - ); - - assert!(!rewritten[2].contains("super-secret-value")); - let output = Command::new(&rewritten[0]) - .args(&rewritten[1..]) - .env("OPENAI_API_KEY", "super-secret-value") - .output() - .expect("run rewritten command"); - assert!(output.status.success(), "command failed: {output:?}"); - assert_eq!( - String::from_utf8_lossy(&output.stdout), - "super-secret-value" - ); - } - - #[test] - fn maybe_wrap_shell_lc_with_snapshot_preserves_unset_override_variables() { - let dir = tempdir().expect("create temp dir"); - let snapshot_path = dir.path().join("snapshot.sh"); - std::fs::write( - &snapshot_path, - "# Snapshot file\nexport CODEX_TEST_UNSET_OVERRIDE='snapshot-value'\n", - ) - .expect("write snapshot"); - let session_shell = shell_with_snapshot( - ShellType::Bash, - "/bin/bash", - snapshot_path, - dir.path().to_path_buf(), - ); - let command = vec![ - "/bin/bash".to_string(), - "-lc".to_string(), - "if [ \"${CODEX_TEST_UNSET_OVERRIDE+x}\" = x ]; then printf 'set:%s' \"$CODEX_TEST_UNSET_OVERRIDE\"; else printf 'unset'; fi".to_string(), - ]; - let explicit_env_overrides = HashMap::from([( - "CODEX_TEST_UNSET_OVERRIDE".to_string(), - "worktree-value".to_string(), - )]); - let rewritten = maybe_wrap_shell_lc_with_snapshot( - &command, - &session_shell, - dir.path(), - &explicit_env_overrides, - ); - - let output = Command::new(&rewritten[0]) - .args(&rewritten[1..]) - .env_remove("CODEX_TEST_UNSET_OVERRIDE") - .output() - .expect("run rewritten command"); - assert!(output.status.success(), "command failed: {output:?}"); - assert_eq!(String::from_utf8_lossy(&output.stdout), "unset"); - } -} +#[path = "mod_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/runtimes/mod_tests.rs b/codex-rs/core/src/tools/runtimes/mod_tests.rs new file mode 100644 index 0000000000..dbc341d1de --- /dev/null +++ b/codex-rs/core/src/tools/runtimes/mod_tests.rs @@ -0,0 +1,398 @@ +use super::*; +use crate::shell::ShellType; +use crate::shell_snapshot::ShellSnapshot; +use pretty_assertions::assert_eq; +use std::path::PathBuf; +use std::process::Command; +use std::sync::Arc; +use tempfile::tempdir; +use tokio::sync::watch; + +fn shell_with_snapshot( + shell_type: ShellType, + shell_path: &str, + snapshot_path: PathBuf, + snapshot_cwd: PathBuf, +) -> Shell { + let (_tx, shell_snapshot) = watch::channel(Some(Arc::new(ShellSnapshot { + path: snapshot_path, + cwd: snapshot_cwd, + }))); + Shell { + shell_type, + shell_path: PathBuf::from(shell_path), + shell_snapshot, + } +} + +#[test] +fn maybe_wrap_shell_lc_with_snapshot_bootstraps_in_user_shell() { + let dir = tempdir().expect("create temp dir"); + let snapshot_path = dir.path().join("snapshot.sh"); + std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot"); + let session_shell = shell_with_snapshot( + ShellType::Zsh, + "/bin/zsh", + snapshot_path, + dir.path().to_path_buf(), + ); + let command = vec![ + "/bin/bash".to_string(), + "-lc".to_string(), + "echo hello".to_string(), + ]; + + let rewritten = + maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, dir.path(), &HashMap::new()); + + assert_eq!(rewritten[0], "/bin/zsh"); + assert_eq!(rewritten[1], "-c"); + assert!(rewritten[2].contains("if . '")); + assert!(rewritten[2].contains("exec '/bin/bash' -c 'echo hello'")); +} + +#[test] +fn maybe_wrap_shell_lc_with_snapshot_escapes_single_quotes() { + let dir = tempdir().expect("create temp dir"); + let snapshot_path = dir.path().join("snapshot.sh"); + std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot"); + let session_shell = shell_with_snapshot( + ShellType::Zsh, + "/bin/zsh", + snapshot_path, + dir.path().to_path_buf(), + ); + let command = vec![ + "/bin/bash".to_string(), + "-lc".to_string(), + "echo 'hello'".to_string(), + ]; + + let rewritten = + maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, dir.path(), &HashMap::new()); + + assert!(rewritten[2].contains(r#"exec '/bin/bash' -c 'echo '"'"'hello'"'"''"#)); +} + +#[test] +fn maybe_wrap_shell_lc_with_snapshot_uses_bash_bootstrap_shell() { + let dir = tempdir().expect("create temp dir"); + let snapshot_path = dir.path().join("snapshot.sh"); + std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot"); + let session_shell = shell_with_snapshot( + ShellType::Bash, + "/bin/bash", + snapshot_path, + dir.path().to_path_buf(), + ); + let command = vec![ + "/bin/zsh".to_string(), + "-lc".to_string(), + "echo hello".to_string(), + ]; + + let rewritten = + maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, dir.path(), &HashMap::new()); + + assert_eq!(rewritten[0], "/bin/bash"); + assert_eq!(rewritten[1], "-c"); + assert!(rewritten[2].contains("if . '")); + assert!(rewritten[2].contains("exec '/bin/zsh' -c 'echo hello'")); +} + +#[test] +fn maybe_wrap_shell_lc_with_snapshot_uses_sh_bootstrap_shell() { + let dir = tempdir().expect("create temp dir"); + let snapshot_path = dir.path().join("snapshot.sh"); + std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot"); + let session_shell = shell_with_snapshot( + ShellType::Sh, + "/bin/sh", + snapshot_path, + dir.path().to_path_buf(), + ); + let command = vec![ + "/bin/bash".to_string(), + "-lc".to_string(), + "echo hello".to_string(), + ]; + + let rewritten = + maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, dir.path(), &HashMap::new()); + + assert_eq!(rewritten[0], "/bin/sh"); + assert_eq!(rewritten[1], "-c"); + assert!(rewritten[2].contains("if . '")); + assert!(rewritten[2].contains("exec '/bin/bash' -c 'echo hello'")); +} + +#[test] +fn maybe_wrap_shell_lc_with_snapshot_preserves_trailing_args() { + let dir = tempdir().expect("create temp dir"); + let snapshot_path = dir.path().join("snapshot.sh"); + std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot"); + let session_shell = shell_with_snapshot( + ShellType::Zsh, + "/bin/zsh", + snapshot_path, + dir.path().to_path_buf(), + ); + let command = vec![ + "/bin/bash".to_string(), + "-lc".to_string(), + "printf '%s %s' \"$0\" \"$1\"".to_string(), + "arg0".to_string(), + "arg1".to_string(), + ]; + + let rewritten = + maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, dir.path(), &HashMap::new()); + + assert!( + rewritten[2] + .contains(r#"exec '/bin/bash' -c 'printf '"'"'%s %s'"'"' "$0" "$1"' 'arg0' 'arg1'"#) + ); +} + +#[test] +fn maybe_wrap_shell_lc_with_snapshot_skips_when_cwd_mismatch() { + let dir = tempdir().expect("create temp dir"); + let snapshot_path = dir.path().join("snapshot.sh"); + std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot"); + let snapshot_cwd = dir.path().join("worktree-a"); + let command_cwd = dir.path().join("worktree-b"); + std::fs::create_dir_all(&snapshot_cwd).expect("create snapshot cwd"); + std::fs::create_dir_all(&command_cwd).expect("create command cwd"); + let session_shell = + shell_with_snapshot(ShellType::Zsh, "/bin/zsh", snapshot_path, snapshot_cwd); + let command = vec![ + "/bin/bash".to_string(), + "-lc".to_string(), + "echo hello".to_string(), + ]; + + let rewritten = + maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, &command_cwd, &HashMap::new()); + + assert_eq!(rewritten, command); +} + +#[test] +fn maybe_wrap_shell_lc_with_snapshot_accepts_dot_alias_cwd() { + let dir = tempdir().expect("create temp dir"); + let snapshot_path = dir.path().join("snapshot.sh"); + std::fs::write(&snapshot_path, "# Snapshot file\n").expect("write snapshot"); + let session_shell = shell_with_snapshot( + ShellType::Zsh, + "/bin/zsh", + snapshot_path, + dir.path().to_path_buf(), + ); + let command = vec![ + "/bin/bash".to_string(), + "-lc".to_string(), + "echo hello".to_string(), + ]; + let command_cwd = dir.path().join("."); + + let rewritten = + maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, &command_cwd, &HashMap::new()); + + assert_eq!(rewritten[0], "/bin/zsh"); + assert_eq!(rewritten[1], "-c"); + assert!(rewritten[2].contains("if . '")); + assert!(rewritten[2].contains("exec '/bin/bash' -c 'echo hello'")); +} + +#[test] +fn maybe_wrap_shell_lc_with_snapshot_restores_explicit_override_precedence() { + let dir = tempdir().expect("create temp dir"); + let snapshot_path = dir.path().join("snapshot.sh"); + std::fs::write( + &snapshot_path, + "# Snapshot file\nexport TEST_ENV_SNAPSHOT=global\nexport SNAPSHOT_ONLY=from_snapshot\n", + ) + .expect("write snapshot"); + let session_shell = shell_with_snapshot( + ShellType::Bash, + "/bin/bash", + snapshot_path, + dir.path().to_path_buf(), + ); + let command = vec![ + "/bin/bash".to_string(), + "-lc".to_string(), + "printf '%s|%s' \"$TEST_ENV_SNAPSHOT\" \"${SNAPSHOT_ONLY-unset}\"".to_string(), + ]; + let explicit_env_overrides = + HashMap::from([("TEST_ENV_SNAPSHOT".to_string(), "worktree".to_string())]); + let rewritten = maybe_wrap_shell_lc_with_snapshot( + &command, + &session_shell, + dir.path(), + &explicit_env_overrides, + ); + let output = Command::new(&rewritten[0]) + .args(&rewritten[1..]) + .env("TEST_ENV_SNAPSHOT", "worktree") + .output() + .expect("run rewritten command"); + + assert!(output.status.success(), "command failed: {output:?}"); + assert_eq!( + String::from_utf8_lossy(&output.stdout), + "worktree|from_snapshot" + ); +} + +#[test] +fn maybe_wrap_shell_lc_with_snapshot_keeps_snapshot_path_without_override() { + let dir = tempdir().expect("create temp dir"); + let snapshot_path = dir.path().join("snapshot.sh"); + std::fs::write( + &snapshot_path, + "# Snapshot file\nexport PATH='/snapshot/bin'\n", + ) + .expect("write snapshot"); + let session_shell = shell_with_snapshot( + ShellType::Bash, + "/bin/bash", + snapshot_path, + dir.path().to_path_buf(), + ); + let command = vec![ + "/bin/bash".to_string(), + "-lc".to_string(), + "printf '%s' \"$PATH\"".to_string(), + ]; + let rewritten = + maybe_wrap_shell_lc_with_snapshot(&command, &session_shell, dir.path(), &HashMap::new()); + let output = Command::new(&rewritten[0]) + .args(&rewritten[1..]) + .output() + .expect("run rewritten command"); + + assert!(output.status.success(), "command failed: {output:?}"); + assert_eq!(String::from_utf8_lossy(&output.stdout), "/snapshot/bin"); +} + +#[test] +fn maybe_wrap_shell_lc_with_snapshot_applies_explicit_path_override() { + let dir = tempdir().expect("create temp dir"); + let snapshot_path = dir.path().join("snapshot.sh"); + std::fs::write( + &snapshot_path, + "# Snapshot file\nexport PATH='/snapshot/bin'\n", + ) + .expect("write snapshot"); + let session_shell = shell_with_snapshot( + ShellType::Bash, + "/bin/bash", + snapshot_path, + dir.path().to_path_buf(), + ); + let command = vec![ + "/bin/bash".to_string(), + "-lc".to_string(), + "printf '%s' \"$PATH\"".to_string(), + ]; + let explicit_env_overrides = HashMap::from([("PATH".to_string(), "/worktree/bin".to_string())]); + let rewritten = maybe_wrap_shell_lc_with_snapshot( + &command, + &session_shell, + dir.path(), + &explicit_env_overrides, + ); + let output = Command::new(&rewritten[0]) + .args(&rewritten[1..]) + .env("PATH", "/worktree/bin") + .output() + .expect("run rewritten command"); + + assert!(output.status.success(), "command failed: {output:?}"); + assert_eq!(String::from_utf8_lossy(&output.stdout), "/worktree/bin"); +} + +#[test] +fn maybe_wrap_shell_lc_with_snapshot_does_not_embed_override_values_in_argv() { + let dir = tempdir().expect("create temp dir"); + let snapshot_path = dir.path().join("snapshot.sh"); + std::fs::write( + &snapshot_path, + "# Snapshot file\nexport OPENAI_API_KEY='snapshot-value'\n", + ) + .expect("write snapshot"); + let session_shell = shell_with_snapshot( + ShellType::Bash, + "/bin/bash", + snapshot_path, + dir.path().to_path_buf(), + ); + let command = vec![ + "/bin/bash".to_string(), + "-lc".to_string(), + "printf '%s' \"$OPENAI_API_KEY\"".to_string(), + ]; + let explicit_env_overrides = HashMap::from([( + "OPENAI_API_KEY".to_string(), + "super-secret-value".to_string(), + )]); + let rewritten = maybe_wrap_shell_lc_with_snapshot( + &command, + &session_shell, + dir.path(), + &explicit_env_overrides, + ); + + assert!(!rewritten[2].contains("super-secret-value")); + let output = Command::new(&rewritten[0]) + .args(&rewritten[1..]) + .env("OPENAI_API_KEY", "super-secret-value") + .output() + .expect("run rewritten command"); + assert!(output.status.success(), "command failed: {output:?}"); + assert_eq!( + String::from_utf8_lossy(&output.stdout), + "super-secret-value" + ); +} + +#[test] +fn maybe_wrap_shell_lc_with_snapshot_preserves_unset_override_variables() { + let dir = tempdir().expect("create temp dir"); + let snapshot_path = dir.path().join("snapshot.sh"); + std::fs::write( + &snapshot_path, + "# Snapshot file\nexport CODEX_TEST_UNSET_OVERRIDE='snapshot-value'\n", + ) + .expect("write snapshot"); + let session_shell = shell_with_snapshot( + ShellType::Bash, + "/bin/bash", + snapshot_path, + dir.path().to_path_buf(), + ); + let command = vec![ + "/bin/bash".to_string(), + "-lc".to_string(), + "if [ \"${CODEX_TEST_UNSET_OVERRIDE+x}\" = x ]; then printf 'set:%s' \"$CODEX_TEST_UNSET_OVERRIDE\"; else printf 'unset'; fi".to_string(), + ]; + let explicit_env_overrides = HashMap::from([( + "CODEX_TEST_UNSET_OVERRIDE".to_string(), + "worktree-value".to_string(), + )]); + let rewritten = maybe_wrap_shell_lc_with_snapshot( + &command, + &session_shell, + dir.path(), + &explicit_env_overrides, + ); + + let output = Command::new(&rewritten[0]) + .args(&rewritten[1..]) + .env_remove("CODEX_TEST_UNSET_OVERRIDE") + .output() + .expect("run rewritten command"); + assert!(output.status.success(), "command failed: {output:?}"); + assert_eq!(String::from_utf8_lossy(&output.stdout), "unset"); +} diff --git a/codex-rs/core/src/tools/sandboxing.rs b/codex-rs/core/src/tools/sandboxing.rs index c133dea7d8..16fd5b1a7c 100644 --- a/codex-rs/core/src/tools/sandboxing.rs +++ b/codex-rs/core/src/tools/sandboxing.rs @@ -360,119 +360,5 @@ impl<'a> SandboxAttempt<'a> { } #[cfg(test)] -mod tests { - use super::*; - use crate::sandboxing::SandboxPermissions; - use codex_protocol::protocol::NetworkAccess; - use codex_protocol::protocol::RejectConfig; - use pretty_assertions::assert_eq; - - #[test] - fn external_sandbox_skips_exec_approval_on_request() { - let sandbox_policy = SandboxPolicy::ExternalSandbox { - network_access: NetworkAccess::Restricted, - }; - assert_eq!( - default_exec_approval_requirement( - AskForApproval::OnRequest, - &FileSystemSandboxPolicy::from(&sandbox_policy), - ), - ExecApprovalRequirement::Skip { - bypass_sandbox: false, - proposed_execpolicy_amendment: None, - } - ); - } - - #[test] - fn restricted_sandbox_requires_exec_approval_on_request() { - let sandbox_policy = SandboxPolicy::new_read_only_policy(); - assert_eq!( - default_exec_approval_requirement( - AskForApproval::OnRequest, - &FileSystemSandboxPolicy::from(&sandbox_policy) - ), - ExecApprovalRequirement::NeedsApproval { - reason: None, - proposed_execpolicy_amendment: None, - } - ); - } - - #[test] - fn default_exec_approval_requirement_rejects_sandbox_prompt_when_configured() { - let policy = AskForApproval::Reject(RejectConfig { - sandbox_approval: true, - rules: false, - skill_approval: false, - request_permissions: false, - mcp_elicitations: false, - }); - - let sandbox_policy = SandboxPolicy::new_read_only_policy(); - let requirement = default_exec_approval_requirement( - policy, - &FileSystemSandboxPolicy::from(&sandbox_policy), - ); - - assert_eq!( - requirement, - ExecApprovalRequirement::Forbidden { - reason: "approval policy rejected sandbox approval prompt".to_string(), - } - ); - } - - #[test] - fn default_exec_approval_requirement_keeps_prompt_when_sandbox_rejection_is_disabled() { - let policy = AskForApproval::Reject(RejectConfig { - sandbox_approval: false, - rules: true, - skill_approval: false, - request_permissions: false, - mcp_elicitations: true, - }); - - let sandbox_policy = SandboxPolicy::new_read_only_policy(); - let requirement = default_exec_approval_requirement( - policy, - &FileSystemSandboxPolicy::from(&sandbox_policy), - ); - - assert_eq!( - requirement, - ExecApprovalRequirement::NeedsApproval { - reason: None, - proposed_execpolicy_amendment: None, - } - ); - } - - #[test] - fn additional_permissions_allow_bypass_sandbox_first_attempt_when_execpolicy_skips() { - assert_eq!( - sandbox_override_for_first_attempt( - SandboxPermissions::WithAdditionalPermissions, - &ExecApprovalRequirement::Skip { - bypass_sandbox: true, - proposed_execpolicy_amendment: None, - }, - ), - SandboxOverride::BypassSandboxFirstAttempt - ); - } - - #[test] - fn guardian_bypasses_sandbox_for_explicit_escalation_on_first_attempt() { - assert_eq!( - sandbox_override_for_first_attempt( - SandboxPermissions::RequireEscalated, - &ExecApprovalRequirement::Skip { - bypass_sandbox: false, - proposed_execpolicy_amendment: None, - }, - ), - SandboxOverride::BypassSandboxFirstAttempt - ); - } -} +#[path = "sandboxing_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/sandboxing_tests.rs b/codex-rs/core/src/tools/sandboxing_tests.rs new file mode 100644 index 0000000000..cf68307ada --- /dev/null +++ b/codex-rs/core/src/tools/sandboxing_tests.rs @@ -0,0 +1,110 @@ +use super::*; +use crate::sandboxing::SandboxPermissions; +use codex_protocol::protocol::NetworkAccess; +use codex_protocol::protocol::RejectConfig; +use pretty_assertions::assert_eq; + +#[test] +fn external_sandbox_skips_exec_approval_on_request() { + let sandbox_policy = SandboxPolicy::ExternalSandbox { + network_access: NetworkAccess::Restricted, + }; + assert_eq!( + default_exec_approval_requirement( + AskForApproval::OnRequest, + &FileSystemSandboxPolicy::from(&sandbox_policy), + ), + ExecApprovalRequirement::Skip { + bypass_sandbox: false, + proposed_execpolicy_amendment: None, + } + ); +} + +#[test] +fn restricted_sandbox_requires_exec_approval_on_request() { + let sandbox_policy = SandboxPolicy::new_read_only_policy(); + assert_eq!( + default_exec_approval_requirement( + AskForApproval::OnRequest, + &FileSystemSandboxPolicy::from(&sandbox_policy) + ), + ExecApprovalRequirement::NeedsApproval { + reason: None, + proposed_execpolicy_amendment: None, + } + ); +} + +#[test] +fn default_exec_approval_requirement_rejects_sandbox_prompt_when_configured() { + let policy = AskForApproval::Reject(RejectConfig { + sandbox_approval: true, + rules: false, + skill_approval: false, + request_permissions: false, + mcp_elicitations: false, + }); + + let sandbox_policy = SandboxPolicy::new_read_only_policy(); + let requirement = + default_exec_approval_requirement(policy, &FileSystemSandboxPolicy::from(&sandbox_policy)); + + assert_eq!( + requirement, + ExecApprovalRequirement::Forbidden { + reason: "approval policy rejected sandbox approval prompt".to_string(), + } + ); +} + +#[test] +fn default_exec_approval_requirement_keeps_prompt_when_sandbox_rejection_is_disabled() { + let policy = AskForApproval::Reject(RejectConfig { + sandbox_approval: false, + rules: true, + skill_approval: false, + request_permissions: false, + mcp_elicitations: true, + }); + + let sandbox_policy = SandboxPolicy::new_read_only_policy(); + let requirement = + default_exec_approval_requirement(policy, &FileSystemSandboxPolicy::from(&sandbox_policy)); + + assert_eq!( + requirement, + ExecApprovalRequirement::NeedsApproval { + reason: None, + proposed_execpolicy_amendment: None, + } + ); +} + +#[test] +fn additional_permissions_allow_bypass_sandbox_first_attempt_when_execpolicy_skips() { + assert_eq!( + sandbox_override_for_first_attempt( + SandboxPermissions::WithAdditionalPermissions, + &ExecApprovalRequirement::Skip { + bypass_sandbox: true, + proposed_execpolicy_amendment: None, + }, + ), + SandboxOverride::BypassSandboxFirstAttempt + ); +} + +#[test] +fn guardian_bypasses_sandbox_for_explicit_escalation_on_first_attempt() { + assert_eq!( + sandbox_override_for_first_attempt( + SandboxPermissions::RequireEscalated, + &ExecApprovalRequirement::Skip { + bypass_sandbox: false, + proposed_execpolicy_amendment: None, + }, + ), + SandboxOverride::BypassSandboxFirstAttempt + ); +} diff --git a/codex-rs/core/src/tools/spec.rs b/codex-rs/core/src/tools/spec.rs index 7d95afaebb..ccba578e40 100644 --- a/codex-rs/core/src/tools/spec.rs +++ b/codex-rs/core/src/tools/spec.rs @@ -2784,2442 +2784,5 @@ pub(crate) fn build_specs_with_discoverable_tools( } #[cfg(test)] -mod tests { - use crate::client_common::tools::FreeformTool; - use crate::config::test_config; - use crate::models_manager::manager::ModelsManager; - use crate::models_manager::model_info::with_config_overrides; - use crate::tools::registry::ConfiguredToolSpec; - use codex_app_server_protocol::AppInfo; - use codex_protocol::openai_models::InputModality; - use codex_protocol::openai_models::ModelInfo; - use codex_protocol::openai_models::ModelsResponse; - use pretty_assertions::assert_eq; - - use super::*; - - fn mcp_tool( - name: &str, - description: &str, - input_schema: serde_json::Value, - ) -> rmcp::model::Tool { - rmcp::model::Tool { - name: name.to_string().into(), - title: None, - description: Some(description.to_string().into()), - input_schema: std::sync::Arc::new(rmcp::model::object(input_schema)), - output_schema: None, - annotations: None, - execution: None, - icons: None, - meta: None, - } - } - - fn discoverable_connector(id: &str, name: &str, description: &str) -> DiscoverableTool { - let slug = name.replace(' ', "-").to_lowercase(); - DiscoverableTool::Connector(Box::new(AppInfo { - id: id.to_string(), - name: name.to_string(), - description: Some(description.to_string()), - logo_url: None, - logo_url_dark: None, - distribution_channel: None, - branding: None, - app_metadata: None, - labels: None, - install_url: Some(format!("https://chatgpt.com/apps/{slug}/{id}")), - is_accessible: false, - is_enabled: true, - plugin_display_names: Vec::new(), - })) - } - - #[test] - fn mcp_tool_to_openai_tool_inserts_empty_properties() { - let mut schema = rmcp::model::JsonObject::new(); - schema.insert("type".to_string(), serde_json::json!("object")); - - let tool = rmcp::model::Tool { - name: "no_props".to_string().into(), - title: None, - description: Some("No properties".to_string().into()), - input_schema: std::sync::Arc::new(schema), - output_schema: None, - annotations: None, - execution: None, - icons: None, - meta: None, - }; - - let openai_tool = - mcp_tool_to_openai_tool("server/no_props".to_string(), tool).expect("convert tool"); - let parameters = serde_json::to_value(openai_tool.parameters).expect("serialize schema"); - - assert_eq!(parameters.get("properties"), Some(&serde_json::json!({}))); - } - - #[test] - fn mcp_tool_to_openai_tool_preserves_top_level_output_schema() { - let mut input_schema = rmcp::model::JsonObject::new(); - input_schema.insert("type".to_string(), serde_json::json!("object")); - - let mut output_schema = rmcp::model::JsonObject::new(); - output_schema.insert( - "properties".to_string(), - serde_json::json!({ - "result": { - "properties": { - "nested": {} - } - } - }), - ); - output_schema.insert("required".to_string(), serde_json::json!(["result"])); - - let tool = rmcp::model::Tool { - name: "with_output".to_string().into(), - title: None, - description: Some("Has output schema".to_string().into()), - input_schema: std::sync::Arc::new(input_schema), - output_schema: Some(std::sync::Arc::new(output_schema)), - annotations: None, - execution: None, - icons: None, - meta: None, - }; - - let openai_tool = mcp_tool_to_openai_tool("mcp__server__with_output".to_string(), tool) - .expect("convert tool"); - - assert_eq!( - openai_tool.output_schema, - Some(serde_json::json!({ - "type": "object", - "properties": { - "content": { - "type": "array", - "items": {} - }, - "structuredContent": { - "properties": { - "result": { - "properties": { - "nested": {} - } - } - }, - "required": ["result"] - }, - "isError": { - "type": "boolean" - }, - "_meta": {} - }, - "required": ["content"], - "additionalProperties": false - })) - ); - } - - #[test] - fn mcp_tool_to_openai_tool_preserves_output_schema_without_inferred_type() { - let mut input_schema = rmcp::model::JsonObject::new(); - input_schema.insert("type".to_string(), serde_json::json!("object")); - - let mut output_schema = rmcp::model::JsonObject::new(); - output_schema.insert("enum".to_string(), serde_json::json!(["ok", "error"])); - - let tool = rmcp::model::Tool { - name: "with_enum_output".to_string().into(), - title: None, - description: Some("Has enum output schema".to_string().into()), - input_schema: std::sync::Arc::new(input_schema), - output_schema: Some(std::sync::Arc::new(output_schema)), - annotations: None, - execution: None, - icons: None, - meta: None, - }; - - let openai_tool = - mcp_tool_to_openai_tool("mcp__server__with_enum_output".to_string(), tool) - .expect("convert tool"); - - assert_eq!( - openai_tool.output_schema, - Some(serde_json::json!({ - "type": "object", - "properties": { - "content": { - "type": "array", - "items": {} - }, - "structuredContent": { - "enum": ["ok", "error"] - }, - "isError": { - "type": "boolean" - }, - "_meta": {} - }, - "required": ["content"], - "additionalProperties": false - })) - ); - } - - #[test] - fn search_tool_deferred_tools_always_set_defer_loading_true() { - let tool = mcp_tool( - "lookup_order", - "Look up an order", - serde_json::json!({ - "type": "object", - "properties": { - "order_id": {"type": "string"} - }, - "required": ["order_id"], - "additionalProperties": false, - }), - ); - - let openai_tool = - mcp_tool_to_deferred_openai_tool("mcp__codex_apps__lookup_order".to_string(), tool) - .expect("convert deferred tool"); - - assert_eq!(openai_tool.defer_loading, Some(true)); - } - - #[test] - fn deferred_responses_api_tool_serializes_with_defer_loading() { - let tool = mcp_tool( - "lookup_order", - "Look up an order", - serde_json::json!({ - "type": "object", - "properties": { - "order_id": {"type": "string"} - }, - "required": ["order_id"], - "additionalProperties": false, - }), - ); - - let serialized = serde_json::to_value(ToolSpec::Function( - mcp_tool_to_deferred_openai_tool("mcp__codex_apps__lookup_order".to_string(), tool) - .expect("convert deferred tool"), - )) - .expect("serialize deferred tool"); - - assert_eq!( - serialized, - serde_json::json!({ - "type": "function", - "name": "mcp__codex_apps__lookup_order", - "description": "Look up an order", - "strict": false, - "defer_loading": true, - "parameters": { - "type": "object", - "properties": { - "order_id": {"type": "string"} - }, - "required": ["order_id"], - "additionalProperties": false, - } - }) - ); - } - - fn tool_name(tool: &ToolSpec) -> &str { - match tool { - ToolSpec::Function(ResponsesApiTool { name, .. }) => name, - ToolSpec::ToolSearch { .. } => "tool_search", - ToolSpec::LocalShell {} => "local_shell", - ToolSpec::ImageGeneration { .. } => "image_generation", - ToolSpec::WebSearch { .. } => "web_search", - ToolSpec::Freeform(FreeformTool { name, .. }) => name, - } - } - - // Avoid order-based assertions; compare via set containment instead. - fn assert_contains_tool_names(tools: &[ConfiguredToolSpec], expected_subset: &[&str]) { - use std::collections::HashSet; - let mut names = HashSet::new(); - let mut duplicates = Vec::new(); - for name in tools.iter().map(|t| tool_name(&t.spec)) { - if !names.insert(name) { - duplicates.push(name); - } - } - assert!( - duplicates.is_empty(), - "duplicate tool entries detected: {duplicates:?}" - ); - for expected in expected_subset { - assert!( - names.contains(expected), - "expected tool {expected} to be present; had: {names:?}" - ); - } - } - - fn assert_lacks_tool_name(tools: &[ConfiguredToolSpec], expected_absent: &str) { - let names = tools - .iter() - .map(|tool| tool_name(&tool.spec)) - .collect::>(); - assert!( - !names.contains(&expected_absent), - "expected tool {expected_absent} to be absent; had: {names:?}" - ); - } - - fn shell_tool_name(config: &ToolsConfig) -> Option<&'static str> { - match config.shell_type { - ConfigShellToolType::Default => Some("shell"), - ConfigShellToolType::Local => Some("local_shell"), - ConfigShellToolType::UnifiedExec => None, - ConfigShellToolType::Disabled => None, - ConfigShellToolType::ShellCommand => Some("shell_command"), - } - } - - fn find_tool<'a>( - tools: &'a [ConfiguredToolSpec], - expected_name: &str, - ) -> &'a ConfiguredToolSpec { - tools - .iter() - .find(|tool| tool_name(&tool.spec) == expected_name) - .unwrap_or_else(|| panic!("expected tool {expected_name}")) - } - - fn strip_descriptions_schema(schema: &mut JsonSchema) { - match schema { - JsonSchema::Boolean { description } - | JsonSchema::String { description } - | JsonSchema::Number { description } => { - *description = None; - } - JsonSchema::Array { items, description } => { - strip_descriptions_schema(items); - *description = None; - } - JsonSchema::Object { - properties, - required: _, - additional_properties, - } => { - for v in properties.values_mut() { - strip_descriptions_schema(v); - } - if let Some(AdditionalProperties::Schema(s)) = additional_properties { - strip_descriptions_schema(s); - } - } - } - } - - fn strip_descriptions_tool(spec: &mut ToolSpec) { - match spec { - ToolSpec::ToolSearch { parameters, .. } => strip_descriptions_schema(parameters), - ToolSpec::Function(ResponsesApiTool { parameters, .. }) => { - strip_descriptions_schema(parameters); - } - ToolSpec::Freeform(_) - | ToolSpec::LocalShell {} - | ToolSpec::ImageGeneration { .. } - | ToolSpec::WebSearch { .. } => {} - } - } - - fn model_info_from_models_json(slug: &str) -> ModelInfo { - let config = test_config(); - let response: ModelsResponse = - serde_json::from_str(include_str!("../../models.json")).expect("valid models.json"); - let model = response - .models - .into_iter() - .find(|candidate| candidate.slug == slug) - .unwrap_or_else(|| panic!("model slug {slug} is missing from models.json")); - with_config_overrides(model, &config) - } - - #[test] - fn test_full_toolset_specs_for_gpt5_codex_unified_exec_web_search() { - let model_info = model_info_from_models_json("gpt-5-codex"); - let mut features = Features::with_defaults(); - features.enable(Feature::UnifiedExec); - let available_models = Vec::new(); - let config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Live), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&config, None, None, &[]).build(); - - // Build actual map name -> spec - use std::collections::BTreeMap; - use std::collections::HashSet; - let mut actual: BTreeMap = BTreeMap::from([]); - let mut duplicate_names = Vec::new(); - for t in &tools { - let name = tool_name(&t.spec).to_string(); - if actual.insert(name.clone(), t.spec.clone()).is_some() { - duplicate_names.push(name); - } - } - assert!( - duplicate_names.is_empty(), - "duplicate tool entries detected: {duplicate_names:?}" - ); - - // Build expected from the same helpers used by the builder. - let mut expected: BTreeMap = BTreeMap::from([]); - for spec in [ - create_exec_command_tool(true, false), - create_write_stdin_tool(), - PLAN_TOOL.clone(), - create_request_user_input_tool(CollaborationModesConfig::default()), - create_apply_patch_freeform_tool(), - ToolSpec::WebSearch { - external_web_access: Some(true), - filters: None, - user_location: None, - search_context_size: None, - search_content_types: None, - }, - create_view_image_tool(config.can_request_original_image_detail), - ] { - expected.insert(tool_name(&spec).to_string(), spec); - } - - if config.request_permission_enabled { - let spec = create_request_permissions_tool(); - expected.insert(tool_name(&spec).to_string(), spec); - } - - // Exact name set match — this is the only test allowed to fail when tools change. - let actual_names: HashSet<_> = actual.keys().cloned().collect(); - let expected_names: HashSet<_> = expected.keys().cloned().collect(); - assert_eq!(actual_names, expected_names, "tool name set mismatch"); - - // Compare specs ignoring human-readable descriptions. - for name in expected.keys() { - let mut a = actual.get(name).expect("present").clone(); - let mut e = expected.get(name).expect("present").clone(); - strip_descriptions_tool(&mut a); - strip_descriptions_tool(&mut e); - assert_eq!(a, e, "spec mismatch for {name}"); - } - } - - #[test] - fn test_build_specs_collab_tools_enabled() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::Collab); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - assert_contains_tool_names( - &tools, - &["spawn_agent", "send_input", "wait", "close_agent"], - ); - assert_lacks_tool_name(&tools, "spawn_agents_on_csv"); - } - - #[test] - fn test_build_specs_spawn_csv_enables_agent_jobs_and_collab_tools() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::SpawnCsv); - features.normalize_dependencies(); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - assert_contains_tool_names( - &tools, - &[ - "spawn_agent", - "send_input", - "wait", - "close_agent", - "spawn_agents_on_csv", - ], - ); - } - - #[test] - fn view_image_tool_omits_detail_without_original_detail_feature() { - let config = test_config(); - let mut model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - model_info.supports_image_detail_original = true; - let features = Features::with_defaults(); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - let view_image = find_tool(&tools, VIEW_IMAGE_TOOL_NAME); - let ToolSpec::Function(ResponsesApiTool { parameters, .. }) = &view_image.spec else { - panic!("view_image should be a function tool"); - }; - let JsonSchema::Object { properties, .. } = parameters else { - panic!("view_image should use an object schema"); - }; - assert!(!properties.contains_key("detail")); - } - - #[test] - fn view_image_tool_includes_detail_with_original_detail_feature() { - let config = test_config(); - let mut model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - model_info.supports_image_detail_original = true; - let mut features = Features::with_defaults(); - features.enable(Feature::ImageDetailOriginal); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - let view_image = find_tool(&tools, VIEW_IMAGE_TOOL_NAME); - let ToolSpec::Function(ResponsesApiTool { parameters, .. }) = &view_image.spec else { - panic!("view_image should be a function tool"); - }; - let JsonSchema::Object { properties, .. } = parameters else { - panic!("view_image should use an object schema"); - }; - assert!(properties.contains_key("detail")); - let Some(JsonSchema::String { - description: Some(description), - }) = properties.get("detail") - else { - panic!("view_image detail should include a description"); - }; - assert!(description.contains("only supported value is `original`")); - assert!(description.contains("omit this field for default resized behavior")); - } - - #[test] - fn test_build_specs_artifact_tool_enabled() { - let mut config = test_config(); - let runtime_root = tempfile::TempDir::new().expect("create temp codex home"); - config.codex_home = runtime_root.path().to_path_buf(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::Artifact); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - assert_contains_tool_names(&tools, &["artifacts"]); - } - - #[test] - fn test_build_specs_agent_job_worker_tools_enabled() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::SpawnCsv); - features.normalize_dependencies(); - features.enable(Feature::Sqlite); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::SubAgent(SubAgentSource::Other( - "agent_job:test".to_string(), - )), - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - assert_contains_tool_names( - &tools, - &[ - "spawn_agent", - "send_input", - "resume_agent", - "wait", - "close_agent", - "spawn_agents_on_csv", - "report_agent_job_result", - ], - ); - assert_lacks_tool_name(&tools, "request_user_input"); - } - - #[test] - fn request_user_input_description_reflects_default_mode_feature_flag() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let mut features = Features::with_defaults(); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - let request_user_input_tool = find_tool(&tools, "request_user_input"); - assert_eq!( - request_user_input_tool.spec, - create_request_user_input_tool(CollaborationModesConfig::default()) - ); - - features.enable(Feature::DefaultModeRequestUserInput); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - let request_user_input_tool = find_tool(&tools, "request_user_input"); - assert_eq!( - request_user_input_tool.spec, - create_request_user_input_tool(CollaborationModesConfig { - default_mode_request_user_input: true, - }) - ); - } - - #[test] - fn request_permissions_requires_feature_flag() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let features = Features::with_defaults(); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - assert_lacks_tool_name(&tools, "request_permissions"); - - let mut features = Features::with_defaults(); - features.enable(Feature::RequestPermissionsTool); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - let request_permissions_tool = find_tool(&tools, "request_permissions"); - assert_eq!( - request_permissions_tool.spec, - create_request_permissions_tool() - ); - } - - #[test] - fn request_permissions_tool_is_independent_from_additional_permissions() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::RequestPermissions); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - - assert_lacks_tool_name(&tools, "request_permissions"); - } - - #[test] - fn get_memory_requires_feature_flag() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let mut features = Features::with_defaults(); - features.disable(Feature::MemoryTool); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - assert!( - !tools.iter().any(|t| t.spec.name() == "get_memory"), - "get_memory should be disabled when memory_tool feature is off" - ); - } - - #[test] - fn js_repl_requires_feature_flag() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let features = Features::with_defaults(); - - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - - assert!( - !tools.iter().any(|tool| tool.spec.name() == "js_repl"), - "js_repl should be disabled when the feature is off" - ); - assert!( - !tools.iter().any(|tool| tool.spec.name() == "js_repl_reset"), - "js_repl_reset should be disabled when the feature is off" - ); - } - - #[test] - fn js_repl_enabled_adds_tools() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::JsRepl); - - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - assert_contains_tool_names(&tools, &["js_repl", "js_repl_reset"]); - } - - #[test] - fn image_generation_tools_require_feature_and_supported_model() { - let config = test_config(); - let mut supported_model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5.2", &config); - supported_model_info.slug = "custom/gpt-5.2-variant".to_string(); - let mut unsupported_model_info = supported_model_info.clone(); - unsupported_model_info.input_modalities = vec![InputModality::Text]; - let default_features = Features::with_defaults(); - let mut image_generation_features = default_features.clone(); - image_generation_features.enable(Feature::ImageGeneration); - - let available_models = Vec::new(); - let default_tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &supported_model_info, - available_models: &available_models, - features: &default_features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (default_tools, _) = build_specs(&default_tools_config, None, None, &[]).build(); - assert!( - !default_tools - .iter() - .any(|tool| tool.spec.name() == "image_generation"), - "image_generation should be disabled by default" - ); - - let supported_tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &supported_model_info, - available_models: &available_models, - features: &image_generation_features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (supported_tools, _) = build_specs(&supported_tools_config, None, None, &[]).build(); - assert_contains_tool_names(&supported_tools, &["image_generation"]); - let image_generation_tool = find_tool(&supported_tools, "image_generation"); - assert_eq!( - serde_json::to_value(&image_generation_tool.spec).expect("serialize image tool"), - serde_json::json!({ - "type": "image_generation", - "output_format": "png" - }) - ); - - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &unsupported_model_info, - available_models: &available_models, - features: &image_generation_features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - assert!( - !tools - .iter() - .any(|tool| tool.spec.name() == "image_generation"), - "image_generation should be disabled for unsupported models" - ); - } - - #[test] - fn js_repl_freeform_grammar_blocks_common_non_js_prefixes() { - let ToolSpec::Freeform(FreeformTool { format, .. }) = create_js_repl_tool() else { - panic!("js_repl should use a freeform tool spec"); - }; - - assert_eq!(format.syntax, "lark"); - assert!(format.definition.contains("PRAGMA_LINE")); - assert!(format.definition.contains("`[^`]")); - assert!(format.definition.contains("``[^`]")); - assert!(format.definition.contains("PLAIN_JS_SOURCE")); - assert!(format.definition.contains("codex-js-repl:")); - assert!(!format.definition.contains("(?!")); - } - - fn assert_model_tools( - model_slug: &str, - features: &Features, - web_search_mode: Option, - expected_tools: &[&str], - ) { - let _config = test_config(); - let model_info = model_info_from_models_json(model_slug); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features, - web_search_mode, - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - let tool_names = tools.iter().map(|t| t.spec.name()).collect::>(); - assert_eq!(&tool_names, &expected_tools,); - } - - fn assert_default_model_tools( - model_slug: &str, - features: &Features, - web_search_mode: Option, - shell_tool: &'static str, - expected_tail: &[&str], - ) { - let mut expected = if features.enabled(Feature::UnifiedExec) { - vec!["exec_command", "write_stdin"] - } else { - vec![shell_tool] - }; - expected.extend(expected_tail); - assert_model_tools(model_slug, features, web_search_mode, &expected); - } - - #[test] - fn web_search_mode_cached_sets_external_web_access_false() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let features = Features::with_defaults(); - - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - - let tool = find_tool(&tools, "web_search"); - assert_eq!( - tool.spec, - ToolSpec::WebSearch { - external_web_access: Some(false), - filters: None, - user_location: None, - search_context_size: None, - search_content_types: None, - } - ); - } - - #[test] - fn web_search_mode_live_sets_external_web_access_true() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let features = Features::with_defaults(); - - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Live), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - - let tool = find_tool(&tools, "web_search"); - assert_eq!( - tool.spec, - ToolSpec::WebSearch { - external_web_access: Some(true), - filters: None, - user_location: None, - search_context_size: None, - search_content_types: None, - } - ); - } - - #[test] - fn web_search_config_is_forwarded_to_tool_spec() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let features = Features::with_defaults(); - let web_search_config = WebSearchConfig { - filters: Some(codex_protocol::config_types::WebSearchFilters { - allowed_domains: Some(vec!["example.com".to_string()]), - }), - user_location: Some(codex_protocol::config_types::WebSearchUserLocation { - r#type: codex_protocol::config_types::WebSearchUserLocationType::Approximate, - country: Some("US".to_string()), - region: Some("California".to_string()), - city: Some("San Francisco".to_string()), - timezone: Some("America/Los_Angeles".to_string()), - }), - search_context_size: Some(codex_protocol::config_types::WebSearchContextSize::High), - }; - - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Live), - session_source: SessionSource::Cli, - }) - .with_web_search_config(Some(web_search_config.clone())); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - - let tool = find_tool(&tools, "web_search"); - assert_eq!( - tool.spec, - ToolSpec::WebSearch { - external_web_access: Some(true), - filters: web_search_config - .filters - .map(crate::client_common::tools::ResponsesApiWebSearchFilters::from), - user_location: web_search_config - .user_location - .map(crate::client_common::tools::ResponsesApiWebSearchUserLocation::from), - search_context_size: web_search_config.search_context_size, - search_content_types: None, - } - ); - } - - #[test] - fn web_search_tool_type_text_and_image_sets_search_content_types() { - let config = test_config(); - let mut model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - model_info.web_search_tool_type = WebSearchToolType::TextAndImage; - let features = Features::with_defaults(); - - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Live), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - - let tool = find_tool(&tools, "web_search"); - assert_eq!( - tool.spec, - ToolSpec::WebSearch { - external_web_access: Some(true), - filters: None, - user_location: None, - search_context_size: None, - search_content_types: Some( - WEB_SEARCH_CONTENT_TYPES - .into_iter() - .map(str::to_string) - .collect() - ), - } - ); - } - - #[test] - fn mcp_resource_tools_are_hidden_without_mcp_servers() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let features = Features::with_defaults(); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - - assert!( - !tools.iter().any(|tool| matches!( - tool.spec.name(), - "list_mcp_resources" | "list_mcp_resource_templates" | "read_mcp_resource" - )), - "MCP resource tools should be omitted when no MCP servers are configured" - ); - } - - #[test] - fn mcp_resource_tools_are_included_when_mcp_servers_are_present() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let features = Features::with_defaults(); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, Some(HashMap::new()), None, &[]).build(); - - assert_contains_tool_names( - &tools, - &[ - "list_mcp_resources", - "list_mcp_resource_templates", - "read_mcp_resource", - ], - ); - } - - #[test] - fn test_build_specs_gpt5_codex_default() { - let features = Features::with_defaults(); - assert_default_model_tools( - "gpt-5-codex", - &features, - Some(WebSearchMode::Cached), - "shell_command", - &[ - "update_plan", - "request_user_input", - "apply_patch", - "web_search", - "view_image", - ], - ); - } - - #[test] - fn test_build_specs_gpt51_codex_default() { - let features = Features::with_defaults(); - assert_default_model_tools( - "gpt-5.1-codex", - &features, - Some(WebSearchMode::Cached), - "shell_command", - &[ - "update_plan", - "request_user_input", - "apply_patch", - "web_search", - "view_image", - ], - ); - } - - #[test] - fn test_build_specs_gpt5_codex_unified_exec_web_search() { - let mut features = Features::with_defaults(); - features.enable(Feature::UnifiedExec); - assert_model_tools( - "gpt-5-codex", - &features, - Some(WebSearchMode::Live), - &[ - "exec_command", - "write_stdin", - "update_plan", - "request_user_input", - "apply_patch", - "web_search", - "view_image", - ], - ); - } - - #[test] - fn test_build_specs_gpt51_codex_unified_exec_web_search() { - let mut features = Features::with_defaults(); - features.enable(Feature::UnifiedExec); - assert_model_tools( - "gpt-5.1-codex", - &features, - Some(WebSearchMode::Live), - &[ - "exec_command", - "write_stdin", - "update_plan", - "request_user_input", - "apply_patch", - "web_search", - "view_image", - ], - ); - } - - #[test] - fn test_gpt_5_1_codex_max_defaults() { - let features = Features::with_defaults(); - assert_default_model_tools( - "gpt-5.1-codex-max", - &features, - Some(WebSearchMode::Cached), - "shell_command", - &[ - "update_plan", - "request_user_input", - "apply_patch", - "web_search", - "view_image", - ], - ); - } - - #[test] - fn test_codex_5_1_mini_defaults() { - let features = Features::with_defaults(); - assert_default_model_tools( - "gpt-5.1-codex-mini", - &features, - Some(WebSearchMode::Cached), - "shell_command", - &[ - "update_plan", - "request_user_input", - "apply_patch", - "web_search", - "view_image", - ], - ); - } - - #[test] - fn test_gpt_5_defaults() { - let features = Features::with_defaults(); - assert_default_model_tools( - "gpt-5", - &features, - Some(WebSearchMode::Cached), - "shell", - &[ - "update_plan", - "request_user_input", - "web_search", - "view_image", - ], - ); - } - - #[test] - fn test_gpt_5_1_defaults() { - let features = Features::with_defaults(); - assert_default_model_tools( - "gpt-5.1", - &features, - Some(WebSearchMode::Cached), - "shell_command", - &[ - "update_plan", - "request_user_input", - "apply_patch", - "web_search", - "view_image", - ], - ); - } - - #[test] - fn test_gpt_5_1_codex_max_unified_exec_web_search() { - let mut features = Features::with_defaults(); - features.enable(Feature::UnifiedExec); - assert_model_tools( - "gpt-5.1-codex-max", - &features, - Some(WebSearchMode::Live), - &[ - "exec_command", - "write_stdin", - "update_plan", - "request_user_input", - "apply_patch", - "web_search", - "view_image", - ], - ); - } - - #[test] - fn test_build_specs_default_shell_present() { - let config = test_config(); - let model_info = ModelsManager::construct_model_info_offline_for_tests("o3", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::UnifiedExec); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Live), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, Some(HashMap::new()), None, &[]).build(); - - // Only check the shell variant and a couple of core tools. - let mut subset = vec!["exec_command", "write_stdin", "update_plan"]; - if let Some(shell_tool) = shell_tool_name(&tools_config) { - subset.push(shell_tool); - } - assert_contains_tool_names(&tools, &subset); - } - - #[test] - fn shell_zsh_fork_prefers_shell_command_over_unified_exec() { - let config = test_config(); - let model_info = ModelsManager::construct_model_info_offline_for_tests("o3", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::UnifiedExec); - features.enable(Feature::ShellZshFork); - - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Live), - session_source: SessionSource::Cli, - }); - - assert_eq!(tools_config.shell_type, ConfigShellToolType::ShellCommand); - assert_eq!( - tools_config.shell_command_backend, - ShellCommandBackendConfig::ZshFork - ); - assert_eq!( - tools_config.unified_exec_backend, - UnifiedExecBackendConfig::ZshFork - ); - } - - #[test] - #[ignore] - fn test_parallel_support_flags() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::UnifiedExec); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - - assert!(find_tool(&tools, "exec_command").supports_parallel_tool_calls); - assert!(!find_tool(&tools, "write_stdin").supports_parallel_tool_calls); - assert!(find_tool(&tools, "grep_files").supports_parallel_tool_calls); - assert!(find_tool(&tools, "list_dir").supports_parallel_tool_calls); - assert!(find_tool(&tools, "read_file").supports_parallel_tool_calls); - } - - #[test] - fn test_test_model_info_includes_sync_tool() { - let _config = test_config(); - let mut model_info = model_info_from_models_json("gpt-5-codex"); - model_info.experimental_supported_tools = vec![ - "test_sync_tool".to_string(), - "read_file".to_string(), - "grep_files".to_string(), - "list_dir".to_string(), - ]; - let features = Features::with_defaults(); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - - assert!( - tools - .iter() - .any(|tool| tool_name(&tool.spec) == "test_sync_tool") - ); - assert!( - tools - .iter() - .any(|tool| tool_name(&tool.spec) == "read_file") - ); - assert!( - tools - .iter() - .any(|tool| tool_name(&tool.spec) == "grep_files") - ); - assert!(tools.iter().any(|tool| tool_name(&tool.spec) == "list_dir")); - } - - #[test] - fn test_build_specs_mcp_tools_converted() { - let config = test_config(); - let model_info = ModelsManager::construct_model_info_offline_for_tests("o3", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::UnifiedExec); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Live), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs( - &tools_config, - Some(HashMap::from([( - "test_server/do_something_cool".to_string(), - mcp_tool( - "do_something_cool", - "Do something cool", - serde_json::json!({ - "type": "object", - "properties": { - "string_argument": { "type": "string" }, - "number_argument": { "type": "number" }, - "object_argument": { - "type": "object", - "properties": { - "string_property": { "type": "string" }, - "number_property": { "type": "number" }, - }, - "required": ["string_property", "number_property"], - "additionalProperties": false, - }, - }, - }), - ), - )])), - None, - &[], - ) - .build(); - - let tool = find_tool(&tools, "test_server/do_something_cool"); - assert_eq!( - &tool.spec, - &ToolSpec::Function(ResponsesApiTool { - name: "test_server/do_something_cool".to_string(), - parameters: JsonSchema::Object { - properties: BTreeMap::from([ - ( - "string_argument".to_string(), - JsonSchema::String { description: None } - ), - ( - "number_argument".to_string(), - JsonSchema::Number { description: None } - ), - ( - "object_argument".to_string(), - JsonSchema::Object { - properties: BTreeMap::from([ - ( - "string_property".to_string(), - JsonSchema::String { description: None } - ), - ( - "number_property".to_string(), - JsonSchema::Number { description: None } - ), - ]), - required: Some(vec![ - "string_property".to_string(), - "number_property".to_string(), - ]), - additional_properties: Some(false.into()), - }, - ), - ]), - required: None, - additional_properties: None, - }, - description: "Do something cool".to_string(), - strict: false, - output_schema: Some(mcp_call_tool_result_output_schema(serde_json::json!({}))), - defer_loading: None, - }) - ); - } - - #[test] - fn test_build_specs_mcp_tools_sorted_by_name() { - let config = test_config(); - let model_info = ModelsManager::construct_model_info_offline_for_tests("o3", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::UnifiedExec); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - - // Intentionally construct a map with keys that would sort alphabetically. - let tools_map: HashMap = HashMap::from([ - ( - "test_server/do".to_string(), - mcp_tool("a", "a", serde_json::json!({"type": "object"})), - ), - ( - "test_server/something".to_string(), - mcp_tool("b", "b", serde_json::json!({"type": "object"})), - ), - ( - "test_server/cool".to_string(), - mcp_tool("c", "c", serde_json::json!({"type": "object"})), - ), - ]); - - let (tools, _) = build_specs(&tools_config, Some(tools_map), None, &[]).build(); - - // Only assert that the MCP tools themselves are sorted by fully-qualified name. - let mcp_names: Vec<_> = tools - .iter() - .map(|t| tool_name(&t.spec).to_string()) - .filter(|n| n.starts_with("test_server/")) - .collect(); - let expected = vec![ - "test_server/cool".to_string(), - "test_server/do".to_string(), - "test_server/something".to_string(), - ]; - assert_eq!(mcp_names, expected); - } - - #[test] - fn search_tool_description_includes_only_codex_apps_connector_names() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::Apps); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - - let (tools, _) = build_specs( - &tools_config, - Some(HashMap::from([ - ( - "mcp__codex_apps__calendar_create_event".to_string(), - mcp_tool( - "calendar_create_event", - "Create calendar event", - serde_json::json!({"type": "object"}), - ), - ), - ( - "mcp__rmcp__echo".to_string(), - mcp_tool("echo", "Echo", serde_json::json!({"type": "object"})), - ), - ])), - Some(HashMap::from([ - ( - "mcp__codex_apps__calendar-create-event".to_string(), - ToolInfo { - server_name: crate::mcp::CODEX_APPS_MCP_SERVER_NAME.to_string(), - tool_name: "-create-event".to_string(), - tool_namespace: "mcp__codex_apps__calendar".to_string(), - tool: mcp_tool( - "calendar-create-event", - "Create calendar event", - serde_json::json!({"type": "object"}), - ), - connector_id: Some("calendar".to_string()), - connector_name: Some("Calendar".to_string()), - plugin_display_names: Vec::new(), - connector_description: None, - }, - ), - ( - "mcp__rmcp__echo".to_string(), - ToolInfo { - server_name: "rmcp".to_string(), - tool_name: "echo".to_string(), - tool_namespace: "rmcp".to_string(), - tool: mcp_tool("echo", "Echo", serde_json::json!({"type": "object"})), - connector_id: None, - connector_name: None, - plugin_display_names: Vec::new(), - connector_description: None, - }, - ), - ])), - &[], - ) - .build(); - - let search_tool = find_tool(&tools, TOOL_SEARCH_TOOL_NAME); - let ToolSpec::ToolSearch { description, .. } = &search_tool.spec else { - panic!("expected tool_search tool"); - }; - let description = description.as_str(); - assert!(description.contains("Calendar")); - assert!(!description.contains("mcp__rmcp__echo")); - } - - #[test] - fn search_tool_requires_apps_feature_flag_only() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let app_tools = Some(HashMap::from([( - "mcp__codex_apps__calendar_create_event".to_string(), - ToolInfo { - server_name: crate::mcp::CODEX_APPS_MCP_SERVER_NAME.to_string(), - tool_name: "calendar_create_event".to_string(), - tool_namespace: "mcp__codex_apps__calendar".to_string(), - tool: mcp_tool( - "calendar_create_event", - "Create calendar event", - serde_json::json!({"type": "object"}), - ), - connector_id: Some("calendar".to_string()), - connector_name: Some("Calendar".to_string()), - connector_description: None, - plugin_display_names: Vec::new(), - }, - )])); - - let features = Features::with_defaults(); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, app_tools.clone(), &[]).build(); - assert_lacks_tool_name(&tools, TOOL_SEARCH_TOOL_NAME); - let mut features = Features::with_defaults(); - features.enable(Feature::Apps); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs(&tools_config, None, app_tools, &[]).build(); - assert_contains_tool_names(&tools, &[TOOL_SEARCH_TOOL_NAME]); - } - - #[test] - fn tool_suggest_is_not_registered_without_feature_flag() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::Apps); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs_with_discoverable_tools( - &tools_config, - None, - None, - Some(vec![discoverable_connector( - "connector_2128aebfecb84f64a069897515042a44", - "Google Calendar", - "Plan events and schedules.", - )]), - &[], - ) - .build(); - - assert!( - !tools - .iter() - .any(|tool| tool_name(&tool.spec) == TOOL_SUGGEST_TOOL_NAME) - ); - } - - #[test] - fn search_tool_description_handles_no_enabled_apps() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::Apps); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - - let (tools, _) = build_specs(&tools_config, None, Some(HashMap::new()), &[]).build(); - let search_tool = find_tool(&tools, TOOL_SEARCH_TOOL_NAME); - let ToolSpec::ToolSearch { description, .. } = &search_tool.spec else { - panic!("expected tool_search tool"); - }; - - assert!(description.contains("(None currently enabled)")); - assert!(!description.contains("{{app_names}}")); - } - - #[test] - fn search_tool_registers_namespaced_app_tool_aliases() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::Apps); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - - let (_, registry) = build_specs( - &tools_config, - None, - Some(HashMap::from([ - ( - "mcp__codex_apps__calendar-create-event".to_string(), - ToolInfo { - server_name: crate::mcp::CODEX_APPS_MCP_SERVER_NAME.to_string(), - tool_name: "-create-event".to_string(), - tool_namespace: "mcp__codex_apps__calendar".to_string(), - tool: mcp_tool( - "calendar-create-event", - "Create calendar event", - serde_json::json!({"type": "object"}), - ), - connector_id: Some("calendar".to_string()), - connector_name: Some("Calendar".to_string()), - connector_description: None, - plugin_display_names: Vec::new(), - }, - ), - ( - "mcp__codex_apps__calendar-list-events".to_string(), - ToolInfo { - server_name: crate::mcp::CODEX_APPS_MCP_SERVER_NAME.to_string(), - tool_name: "-list-events".to_string(), - tool_namespace: "mcp__codex_apps__calendar".to_string(), - tool: mcp_tool( - "calendar-list-events", - "List calendar events", - serde_json::json!({"type": "object"}), - ), - connector_id: Some("calendar".to_string()), - connector_name: Some("Calendar".to_string()), - connector_description: None, - plugin_display_names: Vec::new(), - }, - ), - ])), - &[], - ) - .build(); - - let alias = tool_handler_key("-create-event", Some("mcp__codex_apps__calendar")); - - assert!(registry.has_handler(TOOL_SEARCH_TOOL_NAME, None)); - assert!(registry.has_handler(alias.as_str(), None)); - } - - #[test] - fn tool_suggest_description_lists_discoverable_tools() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::Apps); - features.enable(Feature::ToolSuggest); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - - let discoverable_tools = vec![ - discoverable_connector( - "connector_2128aebfecb84f64a069897515042a44", - "Google Calendar", - "Plan events and schedules.", - ), - discoverable_connector( - "connector_68df038e0ba48191908c8434991bbac2", - "Gmail", - "Find and summarize email threads.", - ), - DiscoverableTool::Plugin(Box::new(DiscoverablePluginInfo { - id: "sample@test".to_string(), - name: "Sample Plugin".to_string(), - description: None, - has_skills: true, - mcp_server_names: vec!["sample-docs".to_string()], - app_connector_ids: vec!["connector_sample".to_string()], - })), - ]; - - let (tools, _) = build_specs_with_discoverable_tools( - &tools_config, - None, - None, - Some(discoverable_tools), - &[], - ) - .build(); - - let tool_suggest = find_tool(&tools, TOOL_SUGGEST_TOOL_NAME); - let ToolSpec::Function(ResponsesApiTool { - description, - parameters, - .. - }) = &tool_suggest.spec - else { - panic!("expected function tool"); - }; - assert!(description.contains("Google Calendar")); - assert!(description.contains("Gmail")); - assert!(description.contains("Sample Plugin")); - assert!(description.contains("Plan events and schedules.")); - assert!(description.contains("Find and summarize email threads.")); - assert!(description.contains("id: `sample@test`, type: plugin, action: enable")); - assert!( - description - .contains("skills; MCP servers: sample-docs; app connectors: connector_sample") - ); - assert!( - description.contains("DO NOT explore or recommend tools that are not on this list.") - ); - let JsonSchema::Object { required, .. } = parameters else { - panic!("expected object parameters"); - }; - assert_eq!( - required.as_ref(), - Some(&vec![ - "tool_type".to_string(), - "action_type".to_string(), - "tool_id".to_string(), - "suggest_reason".to_string(), - ]) - ); - } - - #[test] - fn test_mcp_tool_property_missing_type_defaults_to_string() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::UnifiedExec); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - - let (tools, _) = build_specs( - &tools_config, - Some(HashMap::from([( - "dash/search".to_string(), - mcp_tool( - "search", - "Search docs", - serde_json::json!({ - "type": "object", - "properties": { - "query": {"description": "search query"} - } - }), - ), - )])), - None, - &[], - ) - .build(); - - let tool = find_tool(&tools, "dash/search"); - assert_eq!( - tool.spec, - ToolSpec::Function(ResponsesApiTool { - name: "dash/search".to_string(), - parameters: JsonSchema::Object { - properties: BTreeMap::from([( - "query".to_string(), - JsonSchema::String { - description: Some("search query".to_string()) - } - )]), - required: None, - additional_properties: None, - }, - description: "Search docs".to_string(), - strict: false, - output_schema: Some(mcp_call_tool_result_output_schema(serde_json::json!({}))), - defer_loading: None, - }) - ); - } - - #[test] - fn test_mcp_tool_integer_normalized_to_number() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::UnifiedExec); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - - let (tools, _) = build_specs( - &tools_config, - Some(HashMap::from([( - "dash/paginate".to_string(), - mcp_tool( - "paginate", - "Pagination", - serde_json::json!({ - "type": "object", - "properties": {"page": {"type": "integer"}} - }), - ), - )])), - None, - &[], - ) - .build(); - - let tool = find_tool(&tools, "dash/paginate"); - assert_eq!( - tool.spec, - ToolSpec::Function(ResponsesApiTool { - name: "dash/paginate".to_string(), - parameters: JsonSchema::Object { - properties: BTreeMap::from([( - "page".to_string(), - JsonSchema::Number { description: None } - )]), - required: None, - additional_properties: None, - }, - description: "Pagination".to_string(), - strict: false, - output_schema: Some(mcp_call_tool_result_output_schema(serde_json::json!({}))), - defer_loading: None, - }) - ); - } - - #[test] - fn test_mcp_tool_array_without_items_gets_default_string_items() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::UnifiedExec); - features.enable(Feature::ApplyPatchFreeform); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - - let (tools, _) = build_specs( - &tools_config, - Some(HashMap::from([( - "dash/tags".to_string(), - mcp_tool( - "tags", - "Tags", - serde_json::json!({ - "type": "object", - "properties": {"tags": {"type": "array"}} - }), - ), - )])), - None, - &[], - ) - .build(); - - let tool = find_tool(&tools, "dash/tags"); - assert_eq!( - tool.spec, - ToolSpec::Function(ResponsesApiTool { - name: "dash/tags".to_string(), - parameters: JsonSchema::Object { - properties: BTreeMap::from([( - "tags".to_string(), - JsonSchema::Array { - items: Box::new(JsonSchema::String { description: None }), - description: None - } - )]), - required: None, - additional_properties: None, - }, - description: "Tags".to_string(), - strict: false, - output_schema: Some(mcp_call_tool_result_output_schema(serde_json::json!({}))), - defer_loading: None, - }) - ); - } - - #[test] - fn test_mcp_tool_anyof_defaults_to_string() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::UnifiedExec); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - - let (tools, _) = build_specs( - &tools_config, - Some(HashMap::from([( - "dash/value".to_string(), - mcp_tool( - "value", - "AnyOf Value", - serde_json::json!({ - "type": "object", - "properties": { - "value": {"anyOf": [{"type": "string"}, {"type": "number"}]} - } - }), - ), - )])), - None, - &[], - ) - .build(); - - let tool = find_tool(&tools, "dash/value"); - assert_eq!( - tool.spec, - ToolSpec::Function(ResponsesApiTool { - name: "dash/value".to_string(), - parameters: JsonSchema::Object { - properties: BTreeMap::from([( - "value".to_string(), - JsonSchema::String { description: None } - )]), - required: None, - additional_properties: None, - }, - description: "AnyOf Value".to_string(), - strict: false, - output_schema: Some(mcp_call_tool_result_output_schema(serde_json::json!({}))), - defer_loading: None, - }) - ); - } - - #[test] - fn test_shell_tool() { - let tool = super::create_shell_tool(false); - let ToolSpec::Function(ResponsesApiTool { - description, name, .. - }) = &tool - else { - panic!("expected function tool"); - }; - assert_eq!(name, "shell"); - - let expected = if cfg!(windows) { - r#"Runs a Powershell command (Windows) and returns its output. Arguments to `shell` will be passed to CreateProcessW(). Most commands should be prefixed with ["powershell.exe", "-Command"]. - -Examples of valid command strings: - -- ls -a (show hidden): ["powershell.exe", "-Command", "Get-ChildItem -Force"] -- recursive find by name: ["powershell.exe", "-Command", "Get-ChildItem -Recurse -Filter *.py"] -- recursive grep: ["powershell.exe", "-Command", "Get-ChildItem -Path C:\\myrepo -Recurse | Select-String -Pattern 'TODO' -CaseSensitive"] -- ps aux | grep python: ["powershell.exe", "-Command", "Get-Process | Where-Object { $_.ProcessName -like '*python*' }"] -- setting an env var: ["powershell.exe", "-Command", "$env:FOO='bar'; echo $env:FOO"] -- running an inline Python script: ["powershell.exe", "-Command", "@'\\nprint('Hello, world!')\\n'@ | python -"]"# - } else { - r#"Runs a shell command and returns its output. -- The arguments to `shell` will be passed to execvp(). Most terminal commands should be prefixed with ["bash", "-lc"]. -- Always set the `workdir` param when using the shell function. Do not use `cd` unless absolutely necessary."# - }.to_string(); - assert_eq!(description, &expected); - } - - #[test] - fn shell_tool_with_request_permission_includes_additional_permissions() { - let tool = super::create_shell_tool(true); - let ToolSpec::Function(ResponsesApiTool { parameters, .. }) = tool else { - panic!("expected function tool"); - }; - let JsonSchema::Object { properties, .. } = parameters else { - panic!("expected object parameters"); - }; - - assert!(properties.contains_key("additional_permissions")); - - let Some(JsonSchema::String { - description: Some(description), - }) = properties.get("sandbox_permissions") - else { - panic!("expected sandbox_permissions description"); - }; - assert!(description.contains("with_additional_permissions")); - assert!(description.contains("macOS permissions")); - - let Some(JsonSchema::Object { - properties: additional_properties, - .. - }) = properties.get("additional_permissions") - else { - panic!("expected additional_permissions schema"); - }; - assert!(additional_properties.contains_key("network")); - assert!(additional_properties.contains_key("file_system")); - assert!(additional_properties.contains_key("macos")); - } - - #[test] - fn request_permissions_tool_includes_full_permission_schema() { - let tool = super::create_request_permissions_tool(); - let ToolSpec::Function(ResponsesApiTool { parameters, .. }) = tool else { - panic!("expected function tool"); - }; - let JsonSchema::Object { properties, .. } = parameters else { - panic!("expected object parameters"); - }; - let Some(JsonSchema::Object { - properties: permission_properties, - additional_properties, - .. - }) = properties.get("permissions") - else { - panic!("expected permissions object"); - }; - - assert_eq!(additional_properties, &Some(false.into())); - assert!(permission_properties.contains_key("network")); - assert!(permission_properties.contains_key("file_system")); - assert!(permission_properties.contains_key("macos")); - - let Some(JsonSchema::Object { - properties: network_properties, - additional_properties, - .. - }) = permission_properties.get("network") - else { - panic!("expected network object"); - }; - assert_eq!(additional_properties, &Some(false.into())); - assert!(network_properties.contains_key("enabled")); - - let Some(JsonSchema::Object { - properties: file_system_properties, - additional_properties, - .. - }) = permission_properties.get("file_system") - else { - panic!("expected file_system object"); - }; - assert_eq!(additional_properties, &Some(false.into())); - assert!(file_system_properties.contains_key("read")); - assert!(file_system_properties.contains_key("write")); - - let Some(JsonSchema::Object { - properties: macos_properties, - additional_properties, - .. - }) = permission_properties.get("macos") - else { - panic!("expected macos object"); - }; - assert_eq!(additional_properties, &Some(false.into())); - assert!(macos_properties.contains_key("preferences")); - assert!(macos_properties.contains_key("automations")); - assert!(macos_properties.contains_key("accessibility")); - assert!(macos_properties.contains_key("calendar")); - } - - #[test] - fn test_shell_command_tool() { - let tool = super::create_shell_command_tool(true, false); - let ToolSpec::Function(ResponsesApiTool { - description, name, .. - }) = &tool - else { - panic!("expected function tool"); - }; - assert_eq!(name, "shell_command"); - - let expected = if cfg!(windows) { - r#"Runs a Powershell command (Windows) and returns its output. - -Examples of valid command strings: - -- ls -a (show hidden): "Get-ChildItem -Force" -- recursive find by name: "Get-ChildItem -Recurse -Filter *.py" -- recursive grep: "Get-ChildItem -Path C:\\myrepo -Recurse | Select-String -Pattern 'TODO' -CaseSensitive" -- ps aux | grep python: "Get-Process | Where-Object { $_.ProcessName -like '*python*' }" -- setting an env var: "$env:FOO='bar'; echo $env:FOO" -- running an inline Python script: "@'\\nprint('Hello, world!')\\n'@ | python -"#.to_string() - } else { - r#"Runs a shell command and returns its output. -- Always set the `workdir` param when using the shell_command function. Do not use `cd` unless absolutely necessary."#.to_string() - }; - assert_eq!(description, &expected); - } - - #[test] - fn test_get_openai_tools_mcp_tools_with_additional_properties_schema() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::UnifiedExec); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - let (tools, _) = build_specs( - &tools_config, - Some(HashMap::from([( - "test_server/do_something_cool".to_string(), - mcp_tool( - "do_something_cool", - "Do something cool", - serde_json::json!({ - "type": "object", - "properties": { - "string_argument": {"type": "string"}, - "number_argument": {"type": "number"}, - "object_argument": { - "type": "object", - "properties": { - "string_property": {"type": "string"}, - "number_property": {"type": "number"} - }, - "required": ["string_property", "number_property"], - "additionalProperties": { - "type": "object", - "properties": { - "addtl_prop": {"type": "string"} - }, - "required": ["addtl_prop"], - "additionalProperties": false - } - } - } - }), - ), - )])), - None, - &[], - ) - .build(); - - let tool = find_tool(&tools, "test_server/do_something_cool"); - assert_eq!( - tool.spec, - ToolSpec::Function(ResponsesApiTool { - name: "test_server/do_something_cool".to_string(), - parameters: JsonSchema::Object { - properties: BTreeMap::from([ - ( - "string_argument".to_string(), - JsonSchema::String { description: None } - ), - ( - "number_argument".to_string(), - JsonSchema::Number { description: None } - ), - ( - "object_argument".to_string(), - JsonSchema::Object { - properties: BTreeMap::from([ - ( - "string_property".to_string(), - JsonSchema::String { description: None } - ), - ( - "number_property".to_string(), - JsonSchema::Number { description: None } - ), - ]), - required: Some(vec![ - "string_property".to_string(), - "number_property".to_string(), - ]), - additional_properties: Some( - JsonSchema::Object { - properties: BTreeMap::from([( - "addtl_prop".to_string(), - JsonSchema::String { description: None } - ),]), - required: Some(vec!["addtl_prop".to_string(),]), - additional_properties: Some(false.into()), - } - .into() - ), - }, - ), - ]), - required: None, - additional_properties: None, - }, - description: "Do something cool".to_string(), - strict: false, - output_schema: Some(mcp_call_tool_result_output_schema(serde_json::json!({}))), - defer_loading: None, - }) - ); - } - - #[test] - fn code_mode_augments_builtin_tool_descriptions_with_typed_sample() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::CodeMode); - features.enable(Feature::UnifiedExec); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - - let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); - let ToolSpec::Function(ResponsesApiTool { description, .. }) = - &find_tool(&tools, "view_image").spec - else { - panic!("expected function tool"); - }; - - assert_eq!( - description, - "View a local image from the filesystem (only use if given a full filepath by the user, and the image isn't already attached to the thread context within tags).\n\nCode mode declaration:\n```ts\nimport { view_image } from \"tools.js\";\ndeclare function view_image(args: {\n path: string;\n}): Promise;\n```" - ); - } - - #[test] - fn code_mode_augments_mcp_tool_descriptions_with_namespaced_sample() { - let config = test_config(); - let model_info = - ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); - let mut features = Features::with_defaults(); - features.enable(Feature::CodeMode); - features.enable(Feature::UnifiedExec); - let available_models = Vec::new(); - let tools_config = ToolsConfig::new(&ToolsConfigParams { - model_info: &model_info, - available_models: &available_models, - features: &features, - web_search_mode: Some(WebSearchMode::Cached), - session_source: SessionSource::Cli, - }); - - let (tools, _) = build_specs( - &tools_config, - Some(HashMap::from([( - "mcp__sample__echo".to_string(), - mcp_tool( - "echo", - "Echo text", - serde_json::json!({ - "type": "object", - "properties": { - "message": {"type": "string"} - }, - "required": ["message"], - "additionalProperties": false - }), - ), - )])), - None, - &[], - ) - .build(); - - let ToolSpec::Function(ResponsesApiTool { description, .. }) = - &find_tool(&tools, "mcp__sample__echo").spec - else { - panic!("expected function tool"); - }; - - assert_eq!( - description, - "Echo text\n\nCode mode declaration:\n```ts\nimport { echo } from \"tools/mcp/sample.js\";\ndeclare function echo(args: {\n message: string;\n}): Promise<{\n _meta?: unknown;\n content: Array;\n isError?: boolean;\n structuredContent?: unknown;\n}>;\n```" - ); - } - - #[test] - fn chat_tools_include_top_level_name() { - let properties = - BTreeMap::from([("foo".to_string(), JsonSchema::String { description: None })]); - let tools = vec![ToolSpec::Function(ResponsesApiTool { - name: "demo".to_string(), - description: "A demo tool".to_string(), - strict: false, - defer_loading: None, - parameters: JsonSchema::Object { - properties, - required: None, - additional_properties: None, - }, - output_schema: None, - })]; - - let responses_json = create_tools_json_for_responses_api(&tools).unwrap(); - assert_eq!( - responses_json, - vec![json!({ - "type": "function", - "name": "demo", - "description": "A demo tool", - "strict": false, - "parameters": { - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - }, - })] - ); - } -} +#[path = "spec_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/tools/spec_tests.rs b/codex-rs/core/src/tools/spec_tests.rs new file mode 100644 index 0000000000..a1b7857962 --- /dev/null +++ b/codex-rs/core/src/tools/spec_tests.rs @@ -0,0 +1,2397 @@ +use crate::client_common::tools::FreeformTool; +use crate::config::test_config; +use crate::models_manager::manager::ModelsManager; +use crate::models_manager::model_info::with_config_overrides; +use crate::tools::registry::ConfiguredToolSpec; +use codex_app_server_protocol::AppInfo; +use codex_protocol::openai_models::InputModality; +use codex_protocol::openai_models::ModelInfo; +use codex_protocol::openai_models::ModelsResponse; +use pretty_assertions::assert_eq; + +use super::*; + +fn mcp_tool(name: &str, description: &str, input_schema: serde_json::Value) -> rmcp::model::Tool { + rmcp::model::Tool { + name: name.to_string().into(), + title: None, + description: Some(description.to_string().into()), + input_schema: std::sync::Arc::new(rmcp::model::object(input_schema)), + output_schema: None, + annotations: None, + execution: None, + icons: None, + meta: None, + } +} + +fn discoverable_connector(id: &str, name: &str, description: &str) -> DiscoverableTool { + let slug = name.replace(' ', "-").to_lowercase(); + DiscoverableTool::Connector(Box::new(AppInfo { + id: id.to_string(), + name: name.to_string(), + description: Some(description.to_string()), + logo_url: None, + logo_url_dark: None, + distribution_channel: None, + branding: None, + app_metadata: None, + labels: None, + install_url: Some(format!("https://chatgpt.com/apps/{slug}/{id}")), + is_accessible: false, + is_enabled: true, + plugin_display_names: Vec::new(), + })) +} + +#[test] +fn mcp_tool_to_openai_tool_inserts_empty_properties() { + let mut schema = rmcp::model::JsonObject::new(); + schema.insert("type".to_string(), serde_json::json!("object")); + + let tool = rmcp::model::Tool { + name: "no_props".to_string().into(), + title: None, + description: Some("No properties".to_string().into()), + input_schema: std::sync::Arc::new(schema), + output_schema: None, + annotations: None, + execution: None, + icons: None, + meta: None, + }; + + let openai_tool = + mcp_tool_to_openai_tool("server/no_props".to_string(), tool).expect("convert tool"); + let parameters = serde_json::to_value(openai_tool.parameters).expect("serialize schema"); + + assert_eq!(parameters.get("properties"), Some(&serde_json::json!({}))); +} + +#[test] +fn mcp_tool_to_openai_tool_preserves_top_level_output_schema() { + let mut input_schema = rmcp::model::JsonObject::new(); + input_schema.insert("type".to_string(), serde_json::json!("object")); + + let mut output_schema = rmcp::model::JsonObject::new(); + output_schema.insert( + "properties".to_string(), + serde_json::json!({ + "result": { + "properties": { + "nested": {} + } + } + }), + ); + output_schema.insert("required".to_string(), serde_json::json!(["result"])); + + let tool = rmcp::model::Tool { + name: "with_output".to_string().into(), + title: None, + description: Some("Has output schema".to_string().into()), + input_schema: std::sync::Arc::new(input_schema), + output_schema: Some(std::sync::Arc::new(output_schema)), + annotations: None, + execution: None, + icons: None, + meta: None, + }; + + let openai_tool = mcp_tool_to_openai_tool("mcp__server__with_output".to_string(), tool) + .expect("convert tool"); + + assert_eq!( + openai_tool.output_schema, + Some(serde_json::json!({ + "type": "object", + "properties": { + "content": { + "type": "array", + "items": {} + }, + "structuredContent": { + "properties": { + "result": { + "properties": { + "nested": {} + } + } + }, + "required": ["result"] + }, + "isError": { + "type": "boolean" + }, + "_meta": {} + }, + "required": ["content"], + "additionalProperties": false + })) + ); +} + +#[test] +fn mcp_tool_to_openai_tool_preserves_output_schema_without_inferred_type() { + let mut input_schema = rmcp::model::JsonObject::new(); + input_schema.insert("type".to_string(), serde_json::json!("object")); + + let mut output_schema = rmcp::model::JsonObject::new(); + output_schema.insert("enum".to_string(), serde_json::json!(["ok", "error"])); + + let tool = rmcp::model::Tool { + name: "with_enum_output".to_string().into(), + title: None, + description: Some("Has enum output schema".to_string().into()), + input_schema: std::sync::Arc::new(input_schema), + output_schema: Some(std::sync::Arc::new(output_schema)), + annotations: None, + execution: None, + icons: None, + meta: None, + }; + + let openai_tool = mcp_tool_to_openai_tool("mcp__server__with_enum_output".to_string(), tool) + .expect("convert tool"); + + assert_eq!( + openai_tool.output_schema, + Some(serde_json::json!({ + "type": "object", + "properties": { + "content": { + "type": "array", + "items": {} + }, + "structuredContent": { + "enum": ["ok", "error"] + }, + "isError": { + "type": "boolean" + }, + "_meta": {} + }, + "required": ["content"], + "additionalProperties": false + })) + ); +} + +#[test] +fn search_tool_deferred_tools_always_set_defer_loading_true() { + let tool = mcp_tool( + "lookup_order", + "Look up an order", + serde_json::json!({ + "type": "object", + "properties": { + "order_id": {"type": "string"} + }, + "required": ["order_id"], + "additionalProperties": false, + }), + ); + + let openai_tool = + mcp_tool_to_deferred_openai_tool("mcp__codex_apps__lookup_order".to_string(), tool) + .expect("convert deferred tool"); + + assert_eq!(openai_tool.defer_loading, Some(true)); +} + +#[test] +fn deferred_responses_api_tool_serializes_with_defer_loading() { + let tool = mcp_tool( + "lookup_order", + "Look up an order", + serde_json::json!({ + "type": "object", + "properties": { + "order_id": {"type": "string"} + }, + "required": ["order_id"], + "additionalProperties": false, + }), + ); + + let serialized = serde_json::to_value(ToolSpec::Function( + mcp_tool_to_deferred_openai_tool("mcp__codex_apps__lookup_order".to_string(), tool) + .expect("convert deferred tool"), + )) + .expect("serialize deferred tool"); + + assert_eq!( + serialized, + serde_json::json!({ + "type": "function", + "name": "mcp__codex_apps__lookup_order", + "description": "Look up an order", + "strict": false, + "defer_loading": true, + "parameters": { + "type": "object", + "properties": { + "order_id": {"type": "string"} + }, + "required": ["order_id"], + "additionalProperties": false, + } + }) + ); +} + +fn tool_name(tool: &ToolSpec) -> &str { + match tool { + ToolSpec::Function(ResponsesApiTool { name, .. }) => name, + ToolSpec::ToolSearch { .. } => "tool_search", + ToolSpec::LocalShell {} => "local_shell", + ToolSpec::ImageGeneration { .. } => "image_generation", + ToolSpec::WebSearch { .. } => "web_search", + ToolSpec::Freeform(FreeformTool { name, .. }) => name, + } +} + +// Avoid order-based assertions; compare via set containment instead. +fn assert_contains_tool_names(tools: &[ConfiguredToolSpec], expected_subset: &[&str]) { + use std::collections::HashSet; + let mut names = HashSet::new(); + let mut duplicates = Vec::new(); + for name in tools.iter().map(|t| tool_name(&t.spec)) { + if !names.insert(name) { + duplicates.push(name); + } + } + assert!( + duplicates.is_empty(), + "duplicate tool entries detected: {duplicates:?}" + ); + for expected in expected_subset { + assert!( + names.contains(expected), + "expected tool {expected} to be present; had: {names:?}" + ); + } +} + +fn assert_lacks_tool_name(tools: &[ConfiguredToolSpec], expected_absent: &str) { + let names = tools + .iter() + .map(|tool| tool_name(&tool.spec)) + .collect::>(); + assert!( + !names.contains(&expected_absent), + "expected tool {expected_absent} to be absent; had: {names:?}" + ); +} + +fn shell_tool_name(config: &ToolsConfig) -> Option<&'static str> { + match config.shell_type { + ConfigShellToolType::Default => Some("shell"), + ConfigShellToolType::Local => Some("local_shell"), + ConfigShellToolType::UnifiedExec => None, + ConfigShellToolType::Disabled => None, + ConfigShellToolType::ShellCommand => Some("shell_command"), + } +} + +fn find_tool<'a>(tools: &'a [ConfiguredToolSpec], expected_name: &str) -> &'a ConfiguredToolSpec { + tools + .iter() + .find(|tool| tool_name(&tool.spec) == expected_name) + .unwrap_or_else(|| panic!("expected tool {expected_name}")) +} + +fn strip_descriptions_schema(schema: &mut JsonSchema) { + match schema { + JsonSchema::Boolean { description } + | JsonSchema::String { description } + | JsonSchema::Number { description } => { + *description = None; + } + JsonSchema::Array { items, description } => { + strip_descriptions_schema(items); + *description = None; + } + JsonSchema::Object { + properties, + required: _, + additional_properties, + } => { + for v in properties.values_mut() { + strip_descriptions_schema(v); + } + if let Some(AdditionalProperties::Schema(s)) = additional_properties { + strip_descriptions_schema(s); + } + } + } +} + +fn strip_descriptions_tool(spec: &mut ToolSpec) { + match spec { + ToolSpec::ToolSearch { parameters, .. } => strip_descriptions_schema(parameters), + ToolSpec::Function(ResponsesApiTool { parameters, .. }) => { + strip_descriptions_schema(parameters); + } + ToolSpec::Freeform(_) + | ToolSpec::LocalShell {} + | ToolSpec::ImageGeneration { .. } + | ToolSpec::WebSearch { .. } => {} + } +} + +fn model_info_from_models_json(slug: &str) -> ModelInfo { + let config = test_config(); + let response: ModelsResponse = + serde_json::from_str(include_str!("../../models.json")).expect("valid models.json"); + let model = response + .models + .into_iter() + .find(|candidate| candidate.slug == slug) + .unwrap_or_else(|| panic!("model slug {slug} is missing from models.json")); + with_config_overrides(model, &config) +} + +#[test] +fn test_full_toolset_specs_for_gpt5_codex_unified_exec_web_search() { + let model_info = model_info_from_models_json("gpt-5-codex"); + let mut features = Features::with_defaults(); + features.enable(Feature::UnifiedExec); + let available_models = Vec::new(); + let config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Live), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&config, None, None, &[]).build(); + + // Build actual map name -> spec + use std::collections::BTreeMap; + use std::collections::HashSet; + let mut actual: BTreeMap = BTreeMap::from([]); + let mut duplicate_names = Vec::new(); + for t in &tools { + let name = tool_name(&t.spec).to_string(); + if actual.insert(name.clone(), t.spec.clone()).is_some() { + duplicate_names.push(name); + } + } + assert!( + duplicate_names.is_empty(), + "duplicate tool entries detected: {duplicate_names:?}" + ); + + // Build expected from the same helpers used by the builder. + let mut expected: BTreeMap = BTreeMap::from([]); + for spec in [ + create_exec_command_tool(true, false), + create_write_stdin_tool(), + PLAN_TOOL.clone(), + create_request_user_input_tool(CollaborationModesConfig::default()), + create_apply_patch_freeform_tool(), + ToolSpec::WebSearch { + external_web_access: Some(true), + filters: None, + user_location: None, + search_context_size: None, + search_content_types: None, + }, + create_view_image_tool(config.can_request_original_image_detail), + ] { + expected.insert(tool_name(&spec).to_string(), spec); + } + + if config.request_permission_enabled { + let spec = create_request_permissions_tool(); + expected.insert(tool_name(&spec).to_string(), spec); + } + + // Exact name set match — this is the only test allowed to fail when tools change. + let actual_names: HashSet<_> = actual.keys().cloned().collect(); + let expected_names: HashSet<_> = expected.keys().cloned().collect(); + assert_eq!(actual_names, expected_names, "tool name set mismatch"); + + // Compare specs ignoring human-readable descriptions. + for name in expected.keys() { + let mut a = actual.get(name).expect("present").clone(); + let mut e = expected.get(name).expect("present").clone(); + strip_descriptions_tool(&mut a); + strip_descriptions_tool(&mut e); + assert_eq!(a, e, "spec mismatch for {name}"); + } +} + +#[test] +fn test_build_specs_collab_tools_enabled() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::Collab); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + assert_contains_tool_names( + &tools, + &["spawn_agent", "send_input", "wait", "close_agent"], + ); + assert_lacks_tool_name(&tools, "spawn_agents_on_csv"); +} + +#[test] +fn test_build_specs_spawn_csv_enables_agent_jobs_and_collab_tools() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::SpawnCsv); + features.normalize_dependencies(); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + assert_contains_tool_names( + &tools, + &[ + "spawn_agent", + "send_input", + "wait", + "close_agent", + "spawn_agents_on_csv", + ], + ); +} + +#[test] +fn view_image_tool_omits_detail_without_original_detail_feature() { + let config = test_config(); + let mut model_info = + ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + model_info.supports_image_detail_original = true; + let features = Features::with_defaults(); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + let view_image = find_tool(&tools, VIEW_IMAGE_TOOL_NAME); + let ToolSpec::Function(ResponsesApiTool { parameters, .. }) = &view_image.spec else { + panic!("view_image should be a function tool"); + }; + let JsonSchema::Object { properties, .. } = parameters else { + panic!("view_image should use an object schema"); + }; + assert!(!properties.contains_key("detail")); +} + +#[test] +fn view_image_tool_includes_detail_with_original_detail_feature() { + let config = test_config(); + let mut model_info = + ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + model_info.supports_image_detail_original = true; + let mut features = Features::with_defaults(); + features.enable(Feature::ImageDetailOriginal); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + let view_image = find_tool(&tools, VIEW_IMAGE_TOOL_NAME); + let ToolSpec::Function(ResponsesApiTool { parameters, .. }) = &view_image.spec else { + panic!("view_image should be a function tool"); + }; + let JsonSchema::Object { properties, .. } = parameters else { + panic!("view_image should use an object schema"); + }; + assert!(properties.contains_key("detail")); + let Some(JsonSchema::String { + description: Some(description), + }) = properties.get("detail") + else { + panic!("view_image detail should include a description"); + }; + assert!(description.contains("only supported value is `original`")); + assert!(description.contains("omit this field for default resized behavior")); +} + +#[test] +fn test_build_specs_artifact_tool_enabled() { + let mut config = test_config(); + let runtime_root = tempfile::TempDir::new().expect("create temp codex home"); + config.codex_home = runtime_root.path().to_path_buf(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::Artifact); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + assert_contains_tool_names(&tools, &["artifacts"]); +} + +#[test] +fn test_build_specs_agent_job_worker_tools_enabled() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::SpawnCsv); + features.normalize_dependencies(); + features.enable(Feature::Sqlite); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::SubAgent(SubAgentSource::Other( + "agent_job:test".to_string(), + )), + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + assert_contains_tool_names( + &tools, + &[ + "spawn_agent", + "send_input", + "resume_agent", + "wait", + "close_agent", + "spawn_agents_on_csv", + "report_agent_job_result", + ], + ); + assert_lacks_tool_name(&tools, "request_user_input"); +} + +#[test] +fn request_user_input_description_reflects_default_mode_feature_flag() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + let request_user_input_tool = find_tool(&tools, "request_user_input"); + assert_eq!( + request_user_input_tool.spec, + create_request_user_input_tool(CollaborationModesConfig::default()) + ); + + features.enable(Feature::DefaultModeRequestUserInput); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + let request_user_input_tool = find_tool(&tools, "request_user_input"); + assert_eq!( + request_user_input_tool.spec, + create_request_user_input_tool(CollaborationModesConfig { + default_mode_request_user_input: true, + }) + ); +} + +#[test] +fn request_permissions_requires_feature_flag() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let features = Features::with_defaults(); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + assert_lacks_tool_name(&tools, "request_permissions"); + + let mut features = Features::with_defaults(); + features.enable(Feature::RequestPermissionsTool); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + let request_permissions_tool = find_tool(&tools, "request_permissions"); + assert_eq!( + request_permissions_tool.spec, + create_request_permissions_tool() + ); +} + +#[test] +fn request_permissions_tool_is_independent_from_additional_permissions() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::RequestPermissions); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + + assert_lacks_tool_name(&tools, "request_permissions"); +} + +#[test] +fn get_memory_requires_feature_flag() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.disable(Feature::MemoryTool); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + assert!( + !tools.iter().any(|t| t.spec.name() == "get_memory"), + "get_memory should be disabled when memory_tool feature is off" + ); +} + +#[test] +fn js_repl_requires_feature_flag() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let features = Features::with_defaults(); + + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + + assert!( + !tools.iter().any(|tool| tool.spec.name() == "js_repl"), + "js_repl should be disabled when the feature is off" + ); + assert!( + !tools.iter().any(|tool| tool.spec.name() == "js_repl_reset"), + "js_repl_reset should be disabled when the feature is off" + ); +} + +#[test] +fn js_repl_enabled_adds_tools() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::JsRepl); + + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + assert_contains_tool_names(&tools, &["js_repl", "js_repl_reset"]); +} + +#[test] +fn image_generation_tools_require_feature_and_supported_model() { + let config = test_config(); + let mut supported_model_info = + ModelsManager::construct_model_info_offline_for_tests("gpt-5.2", &config); + supported_model_info.slug = "custom/gpt-5.2-variant".to_string(); + let mut unsupported_model_info = supported_model_info.clone(); + unsupported_model_info.input_modalities = vec![InputModality::Text]; + let default_features = Features::with_defaults(); + let mut image_generation_features = default_features.clone(); + image_generation_features.enable(Feature::ImageGeneration); + + let available_models = Vec::new(); + let default_tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &supported_model_info, + available_models: &available_models, + features: &default_features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (default_tools, _) = build_specs(&default_tools_config, None, None, &[]).build(); + assert!( + !default_tools + .iter() + .any(|tool| tool.spec.name() == "image_generation"), + "image_generation should be disabled by default" + ); + + let supported_tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &supported_model_info, + available_models: &available_models, + features: &image_generation_features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (supported_tools, _) = build_specs(&supported_tools_config, None, None, &[]).build(); + assert_contains_tool_names(&supported_tools, &["image_generation"]); + let image_generation_tool = find_tool(&supported_tools, "image_generation"); + assert_eq!( + serde_json::to_value(&image_generation_tool.spec).expect("serialize image tool"), + serde_json::json!({ + "type": "image_generation", + "output_format": "png" + }) + ); + + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &unsupported_model_info, + available_models: &available_models, + features: &image_generation_features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + assert!( + !tools + .iter() + .any(|tool| tool.spec.name() == "image_generation"), + "image_generation should be disabled for unsupported models" + ); +} + +#[test] +fn js_repl_freeform_grammar_blocks_common_non_js_prefixes() { + let ToolSpec::Freeform(FreeformTool { format, .. }) = create_js_repl_tool() else { + panic!("js_repl should use a freeform tool spec"); + }; + + assert_eq!(format.syntax, "lark"); + assert!(format.definition.contains("PRAGMA_LINE")); + assert!(format.definition.contains("`[^`]")); + assert!(format.definition.contains("``[^`]")); + assert!(format.definition.contains("PLAIN_JS_SOURCE")); + assert!(format.definition.contains("codex-js-repl:")); + assert!(!format.definition.contains("(?!")); +} + +fn assert_model_tools( + model_slug: &str, + features: &Features, + web_search_mode: Option, + expected_tools: &[&str], +) { + let _config = test_config(); + let model_info = model_info_from_models_json(model_slug); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features, + web_search_mode, + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + let tool_names = tools.iter().map(|t| t.spec.name()).collect::>(); + assert_eq!(&tool_names, &expected_tools,); +} + +fn assert_default_model_tools( + model_slug: &str, + features: &Features, + web_search_mode: Option, + shell_tool: &'static str, + expected_tail: &[&str], +) { + let mut expected = if features.enabled(Feature::UnifiedExec) { + vec!["exec_command", "write_stdin"] + } else { + vec![shell_tool] + }; + expected.extend(expected_tail); + assert_model_tools(model_slug, features, web_search_mode, &expected); +} + +#[test] +fn web_search_mode_cached_sets_external_web_access_false() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let features = Features::with_defaults(); + + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + + let tool = find_tool(&tools, "web_search"); + assert_eq!( + tool.spec, + ToolSpec::WebSearch { + external_web_access: Some(false), + filters: None, + user_location: None, + search_context_size: None, + search_content_types: None, + } + ); +} + +#[test] +fn web_search_mode_live_sets_external_web_access_true() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let features = Features::with_defaults(); + + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Live), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + + let tool = find_tool(&tools, "web_search"); + assert_eq!( + tool.spec, + ToolSpec::WebSearch { + external_web_access: Some(true), + filters: None, + user_location: None, + search_context_size: None, + search_content_types: None, + } + ); +} + +#[test] +fn web_search_config_is_forwarded_to_tool_spec() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let features = Features::with_defaults(); + let web_search_config = WebSearchConfig { + filters: Some(codex_protocol::config_types::WebSearchFilters { + allowed_domains: Some(vec!["example.com".to_string()]), + }), + user_location: Some(codex_protocol::config_types::WebSearchUserLocation { + r#type: codex_protocol::config_types::WebSearchUserLocationType::Approximate, + country: Some("US".to_string()), + region: Some("California".to_string()), + city: Some("San Francisco".to_string()), + timezone: Some("America/Los_Angeles".to_string()), + }), + search_context_size: Some(codex_protocol::config_types::WebSearchContextSize::High), + }; + + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Live), + session_source: SessionSource::Cli, + }) + .with_web_search_config(Some(web_search_config.clone())); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + + let tool = find_tool(&tools, "web_search"); + assert_eq!( + tool.spec, + ToolSpec::WebSearch { + external_web_access: Some(true), + filters: web_search_config + .filters + .map(crate::client_common::tools::ResponsesApiWebSearchFilters::from), + user_location: web_search_config + .user_location + .map(crate::client_common::tools::ResponsesApiWebSearchUserLocation::from), + search_context_size: web_search_config.search_context_size, + search_content_types: None, + } + ); +} + +#[test] +fn web_search_tool_type_text_and_image_sets_search_content_types() { + let config = test_config(); + let mut model_info = + ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + model_info.web_search_tool_type = WebSearchToolType::TextAndImage; + let features = Features::with_defaults(); + + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Live), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + + let tool = find_tool(&tools, "web_search"); + assert_eq!( + tool.spec, + ToolSpec::WebSearch { + external_web_access: Some(true), + filters: None, + user_location: None, + search_context_size: None, + search_content_types: Some( + WEB_SEARCH_CONTENT_TYPES + .into_iter() + .map(str::to_string) + .collect() + ), + } + ); +} + +#[test] +fn mcp_resource_tools_are_hidden_without_mcp_servers() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let features = Features::with_defaults(); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + + assert!( + !tools.iter().any(|tool| matches!( + tool.spec.name(), + "list_mcp_resources" | "list_mcp_resource_templates" | "read_mcp_resource" + )), + "MCP resource tools should be omitted when no MCP servers are configured" + ); +} + +#[test] +fn mcp_resource_tools_are_included_when_mcp_servers_are_present() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let features = Features::with_defaults(); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, Some(HashMap::new()), None, &[]).build(); + + assert_contains_tool_names( + &tools, + &[ + "list_mcp_resources", + "list_mcp_resource_templates", + "read_mcp_resource", + ], + ); +} + +#[test] +fn test_build_specs_gpt5_codex_default() { + let features = Features::with_defaults(); + assert_default_model_tools( + "gpt-5-codex", + &features, + Some(WebSearchMode::Cached), + "shell_command", + &[ + "update_plan", + "request_user_input", + "apply_patch", + "web_search", + "view_image", + ], + ); +} + +#[test] +fn test_build_specs_gpt51_codex_default() { + let features = Features::with_defaults(); + assert_default_model_tools( + "gpt-5.1-codex", + &features, + Some(WebSearchMode::Cached), + "shell_command", + &[ + "update_plan", + "request_user_input", + "apply_patch", + "web_search", + "view_image", + ], + ); +} + +#[test] +fn test_build_specs_gpt5_codex_unified_exec_web_search() { + let mut features = Features::with_defaults(); + features.enable(Feature::UnifiedExec); + assert_model_tools( + "gpt-5-codex", + &features, + Some(WebSearchMode::Live), + &[ + "exec_command", + "write_stdin", + "update_plan", + "request_user_input", + "apply_patch", + "web_search", + "view_image", + ], + ); +} + +#[test] +fn test_build_specs_gpt51_codex_unified_exec_web_search() { + let mut features = Features::with_defaults(); + features.enable(Feature::UnifiedExec); + assert_model_tools( + "gpt-5.1-codex", + &features, + Some(WebSearchMode::Live), + &[ + "exec_command", + "write_stdin", + "update_plan", + "request_user_input", + "apply_patch", + "web_search", + "view_image", + ], + ); +} + +#[test] +fn test_gpt_5_1_codex_max_defaults() { + let features = Features::with_defaults(); + assert_default_model_tools( + "gpt-5.1-codex-max", + &features, + Some(WebSearchMode::Cached), + "shell_command", + &[ + "update_plan", + "request_user_input", + "apply_patch", + "web_search", + "view_image", + ], + ); +} + +#[test] +fn test_codex_5_1_mini_defaults() { + let features = Features::with_defaults(); + assert_default_model_tools( + "gpt-5.1-codex-mini", + &features, + Some(WebSearchMode::Cached), + "shell_command", + &[ + "update_plan", + "request_user_input", + "apply_patch", + "web_search", + "view_image", + ], + ); +} + +#[test] +fn test_gpt_5_defaults() { + let features = Features::with_defaults(); + assert_default_model_tools( + "gpt-5", + &features, + Some(WebSearchMode::Cached), + "shell", + &[ + "update_plan", + "request_user_input", + "web_search", + "view_image", + ], + ); +} + +#[test] +fn test_gpt_5_1_defaults() { + let features = Features::with_defaults(); + assert_default_model_tools( + "gpt-5.1", + &features, + Some(WebSearchMode::Cached), + "shell_command", + &[ + "update_plan", + "request_user_input", + "apply_patch", + "web_search", + "view_image", + ], + ); +} + +#[test] +fn test_gpt_5_1_codex_max_unified_exec_web_search() { + let mut features = Features::with_defaults(); + features.enable(Feature::UnifiedExec); + assert_model_tools( + "gpt-5.1-codex-max", + &features, + Some(WebSearchMode::Live), + &[ + "exec_command", + "write_stdin", + "update_plan", + "request_user_input", + "apply_patch", + "web_search", + "view_image", + ], + ); +} + +#[test] +fn test_build_specs_default_shell_present() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("o3", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::UnifiedExec); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Live), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, Some(HashMap::new()), None, &[]).build(); + + // Only check the shell variant and a couple of core tools. + let mut subset = vec!["exec_command", "write_stdin", "update_plan"]; + if let Some(shell_tool) = shell_tool_name(&tools_config) { + subset.push(shell_tool); + } + assert_contains_tool_names(&tools, &subset); +} + +#[test] +fn shell_zsh_fork_prefers_shell_command_over_unified_exec() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("o3", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::UnifiedExec); + features.enable(Feature::ShellZshFork); + + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Live), + session_source: SessionSource::Cli, + }); + + assert_eq!(tools_config.shell_type, ConfigShellToolType::ShellCommand); + assert_eq!( + tools_config.shell_command_backend, + ShellCommandBackendConfig::ZshFork + ); + assert_eq!( + tools_config.unified_exec_backend, + UnifiedExecBackendConfig::ZshFork + ); +} + +#[test] +#[ignore] +fn test_parallel_support_flags() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::UnifiedExec); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + + assert!(find_tool(&tools, "exec_command").supports_parallel_tool_calls); + assert!(!find_tool(&tools, "write_stdin").supports_parallel_tool_calls); + assert!(find_tool(&tools, "grep_files").supports_parallel_tool_calls); + assert!(find_tool(&tools, "list_dir").supports_parallel_tool_calls); + assert!(find_tool(&tools, "read_file").supports_parallel_tool_calls); +} + +#[test] +fn test_test_model_info_includes_sync_tool() { + let _config = test_config(); + let mut model_info = model_info_from_models_json("gpt-5-codex"); + model_info.experimental_supported_tools = vec![ + "test_sync_tool".to_string(), + "read_file".to_string(), + "grep_files".to_string(), + "list_dir".to_string(), + ]; + let features = Features::with_defaults(); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + + assert!( + tools + .iter() + .any(|tool| tool_name(&tool.spec) == "test_sync_tool") + ); + assert!( + tools + .iter() + .any(|tool| tool_name(&tool.spec) == "read_file") + ); + assert!( + tools + .iter() + .any(|tool| tool_name(&tool.spec) == "grep_files") + ); + assert!(tools.iter().any(|tool| tool_name(&tool.spec) == "list_dir")); +} + +#[test] +fn test_build_specs_mcp_tools_converted() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("o3", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::UnifiedExec); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Live), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs( + &tools_config, + Some(HashMap::from([( + "test_server/do_something_cool".to_string(), + mcp_tool( + "do_something_cool", + "Do something cool", + serde_json::json!({ + "type": "object", + "properties": { + "string_argument": { "type": "string" }, + "number_argument": { "type": "number" }, + "object_argument": { + "type": "object", + "properties": { + "string_property": { "type": "string" }, + "number_property": { "type": "number" }, + }, + "required": ["string_property", "number_property"], + "additionalProperties": false, + }, + }, + }), + ), + )])), + None, + &[], + ) + .build(); + + let tool = find_tool(&tools, "test_server/do_something_cool"); + assert_eq!( + &tool.spec, + &ToolSpec::Function(ResponsesApiTool { + name: "test_server/do_something_cool".to_string(), + parameters: JsonSchema::Object { + properties: BTreeMap::from([ + ( + "string_argument".to_string(), + JsonSchema::String { description: None } + ), + ( + "number_argument".to_string(), + JsonSchema::Number { description: None } + ), + ( + "object_argument".to_string(), + JsonSchema::Object { + properties: BTreeMap::from([ + ( + "string_property".to_string(), + JsonSchema::String { description: None } + ), + ( + "number_property".to_string(), + JsonSchema::Number { description: None } + ), + ]), + required: Some(vec![ + "string_property".to_string(), + "number_property".to_string(), + ]), + additional_properties: Some(false.into()), + }, + ), + ]), + required: None, + additional_properties: None, + }, + description: "Do something cool".to_string(), + strict: false, + output_schema: Some(mcp_call_tool_result_output_schema(serde_json::json!({}))), + defer_loading: None, + }) + ); +} + +#[test] +fn test_build_specs_mcp_tools_sorted_by_name() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("o3", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::UnifiedExec); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + + // Intentionally construct a map with keys that would sort alphabetically. + let tools_map: HashMap = HashMap::from([ + ( + "test_server/do".to_string(), + mcp_tool("a", "a", serde_json::json!({"type": "object"})), + ), + ( + "test_server/something".to_string(), + mcp_tool("b", "b", serde_json::json!({"type": "object"})), + ), + ( + "test_server/cool".to_string(), + mcp_tool("c", "c", serde_json::json!({"type": "object"})), + ), + ]); + + let (tools, _) = build_specs(&tools_config, Some(tools_map), None, &[]).build(); + + // Only assert that the MCP tools themselves are sorted by fully-qualified name. + let mcp_names: Vec<_> = tools + .iter() + .map(|t| tool_name(&t.spec).to_string()) + .filter(|n| n.starts_with("test_server/")) + .collect(); + let expected = vec![ + "test_server/cool".to_string(), + "test_server/do".to_string(), + "test_server/something".to_string(), + ]; + assert_eq!(mcp_names, expected); +} + +#[test] +fn search_tool_description_includes_only_codex_apps_connector_names() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::Apps); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + + let (tools, _) = build_specs( + &tools_config, + Some(HashMap::from([ + ( + "mcp__codex_apps__calendar_create_event".to_string(), + mcp_tool( + "calendar_create_event", + "Create calendar event", + serde_json::json!({"type": "object"}), + ), + ), + ( + "mcp__rmcp__echo".to_string(), + mcp_tool("echo", "Echo", serde_json::json!({"type": "object"})), + ), + ])), + Some(HashMap::from([ + ( + "mcp__codex_apps__calendar-create-event".to_string(), + ToolInfo { + server_name: crate::mcp::CODEX_APPS_MCP_SERVER_NAME.to_string(), + tool_name: "-create-event".to_string(), + tool_namespace: "mcp__codex_apps__calendar".to_string(), + tool: mcp_tool( + "calendar-create-event", + "Create calendar event", + serde_json::json!({"type": "object"}), + ), + connector_id: Some("calendar".to_string()), + connector_name: Some("Calendar".to_string()), + plugin_display_names: Vec::new(), + connector_description: None, + }, + ), + ( + "mcp__rmcp__echo".to_string(), + ToolInfo { + server_name: "rmcp".to_string(), + tool_name: "echo".to_string(), + tool_namespace: "rmcp".to_string(), + tool: mcp_tool("echo", "Echo", serde_json::json!({"type": "object"})), + connector_id: None, + connector_name: None, + plugin_display_names: Vec::new(), + connector_description: None, + }, + ), + ])), + &[], + ) + .build(); + + let search_tool = find_tool(&tools, TOOL_SEARCH_TOOL_NAME); + let ToolSpec::ToolSearch { description, .. } = &search_tool.spec else { + panic!("expected tool_search tool"); + }; + let description = description.as_str(); + assert!(description.contains("Calendar")); + assert!(!description.contains("mcp__rmcp__echo")); +} + +#[test] +fn search_tool_requires_apps_feature_flag_only() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let app_tools = Some(HashMap::from([( + "mcp__codex_apps__calendar_create_event".to_string(), + ToolInfo { + server_name: crate::mcp::CODEX_APPS_MCP_SERVER_NAME.to_string(), + tool_name: "calendar_create_event".to_string(), + tool_namespace: "mcp__codex_apps__calendar".to_string(), + tool: mcp_tool( + "calendar_create_event", + "Create calendar event", + serde_json::json!({"type": "object"}), + ), + connector_id: Some("calendar".to_string()), + connector_name: Some("Calendar".to_string()), + connector_description: None, + plugin_display_names: Vec::new(), + }, + )])); + + let features = Features::with_defaults(); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, app_tools.clone(), &[]).build(); + assert_lacks_tool_name(&tools, TOOL_SEARCH_TOOL_NAME); + let mut features = Features::with_defaults(); + features.enable(Feature::Apps); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs(&tools_config, None, app_tools, &[]).build(); + assert_contains_tool_names(&tools, &[TOOL_SEARCH_TOOL_NAME]); +} + +#[test] +fn tool_suggest_is_not_registered_without_feature_flag() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::Apps); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs_with_discoverable_tools( + &tools_config, + None, + None, + Some(vec![discoverable_connector( + "connector_2128aebfecb84f64a069897515042a44", + "Google Calendar", + "Plan events and schedules.", + )]), + &[], + ) + .build(); + + assert!( + !tools + .iter() + .any(|tool| tool_name(&tool.spec) == TOOL_SUGGEST_TOOL_NAME) + ); +} + +#[test] +fn search_tool_description_handles_no_enabled_apps() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::Apps); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + + let (tools, _) = build_specs(&tools_config, None, Some(HashMap::new()), &[]).build(); + let search_tool = find_tool(&tools, TOOL_SEARCH_TOOL_NAME); + let ToolSpec::ToolSearch { description, .. } = &search_tool.spec else { + panic!("expected tool_search tool"); + }; + + assert!(description.contains("(None currently enabled)")); + assert!(!description.contains("{{app_names}}")); +} + +#[test] +fn search_tool_registers_namespaced_app_tool_aliases() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::Apps); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + + let (_, registry) = build_specs( + &tools_config, + None, + Some(HashMap::from([ + ( + "mcp__codex_apps__calendar-create-event".to_string(), + ToolInfo { + server_name: crate::mcp::CODEX_APPS_MCP_SERVER_NAME.to_string(), + tool_name: "-create-event".to_string(), + tool_namespace: "mcp__codex_apps__calendar".to_string(), + tool: mcp_tool( + "calendar-create-event", + "Create calendar event", + serde_json::json!({"type": "object"}), + ), + connector_id: Some("calendar".to_string()), + connector_name: Some("Calendar".to_string()), + connector_description: None, + plugin_display_names: Vec::new(), + }, + ), + ( + "mcp__codex_apps__calendar-list-events".to_string(), + ToolInfo { + server_name: crate::mcp::CODEX_APPS_MCP_SERVER_NAME.to_string(), + tool_name: "-list-events".to_string(), + tool_namespace: "mcp__codex_apps__calendar".to_string(), + tool: mcp_tool( + "calendar-list-events", + "List calendar events", + serde_json::json!({"type": "object"}), + ), + connector_id: Some("calendar".to_string()), + connector_name: Some("Calendar".to_string()), + connector_description: None, + plugin_display_names: Vec::new(), + }, + ), + ])), + &[], + ) + .build(); + + let alias = tool_handler_key("-create-event", Some("mcp__codex_apps__calendar")); + + assert!(registry.has_handler(TOOL_SEARCH_TOOL_NAME, None)); + assert!(registry.has_handler(alias.as_str(), None)); +} + +#[test] +fn tool_suggest_description_lists_discoverable_tools() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::Apps); + features.enable(Feature::ToolSuggest); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + + let discoverable_tools = vec![ + discoverable_connector( + "connector_2128aebfecb84f64a069897515042a44", + "Google Calendar", + "Plan events and schedules.", + ), + discoverable_connector( + "connector_68df038e0ba48191908c8434991bbac2", + "Gmail", + "Find and summarize email threads.", + ), + DiscoverableTool::Plugin(Box::new(DiscoverablePluginInfo { + id: "sample@test".to_string(), + name: "Sample Plugin".to_string(), + description: None, + has_skills: true, + mcp_server_names: vec!["sample-docs".to_string()], + app_connector_ids: vec!["connector_sample".to_string()], + })), + ]; + + let (tools, _) = build_specs_with_discoverable_tools( + &tools_config, + None, + None, + Some(discoverable_tools), + &[], + ) + .build(); + + let tool_suggest = find_tool(&tools, TOOL_SUGGEST_TOOL_NAME); + let ToolSpec::Function(ResponsesApiTool { + description, + parameters, + .. + }) = &tool_suggest.spec + else { + panic!("expected function tool"); + }; + assert!(description.contains("Google Calendar")); + assert!(description.contains("Gmail")); + assert!(description.contains("Sample Plugin")); + assert!(description.contains("Plan events and schedules.")); + assert!(description.contains("Find and summarize email threads.")); + assert!(description.contains("id: `sample@test`, type: plugin, action: enable")); + assert!( + description.contains("skills; MCP servers: sample-docs; app connectors: connector_sample") + ); + assert!(description.contains("DO NOT explore or recommend tools that are not on this list.")); + let JsonSchema::Object { required, .. } = parameters else { + panic!("expected object parameters"); + }; + assert_eq!( + required.as_ref(), + Some(&vec![ + "tool_type".to_string(), + "action_type".to_string(), + "tool_id".to_string(), + "suggest_reason".to_string(), + ]) + ); +} + +#[test] +fn test_mcp_tool_property_missing_type_defaults_to_string() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::UnifiedExec); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + + let (tools, _) = build_specs( + &tools_config, + Some(HashMap::from([( + "dash/search".to_string(), + mcp_tool( + "search", + "Search docs", + serde_json::json!({ + "type": "object", + "properties": { + "query": {"description": "search query"} + } + }), + ), + )])), + None, + &[], + ) + .build(); + + let tool = find_tool(&tools, "dash/search"); + assert_eq!( + tool.spec, + ToolSpec::Function(ResponsesApiTool { + name: "dash/search".to_string(), + parameters: JsonSchema::Object { + properties: BTreeMap::from([( + "query".to_string(), + JsonSchema::String { + description: Some("search query".to_string()) + } + )]), + required: None, + additional_properties: None, + }, + description: "Search docs".to_string(), + strict: false, + output_schema: Some(mcp_call_tool_result_output_schema(serde_json::json!({}))), + defer_loading: None, + }) + ); +} + +#[test] +fn test_mcp_tool_integer_normalized_to_number() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::UnifiedExec); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + + let (tools, _) = build_specs( + &tools_config, + Some(HashMap::from([( + "dash/paginate".to_string(), + mcp_tool( + "paginate", + "Pagination", + serde_json::json!({ + "type": "object", + "properties": {"page": {"type": "integer"}} + }), + ), + )])), + None, + &[], + ) + .build(); + + let tool = find_tool(&tools, "dash/paginate"); + assert_eq!( + tool.spec, + ToolSpec::Function(ResponsesApiTool { + name: "dash/paginate".to_string(), + parameters: JsonSchema::Object { + properties: BTreeMap::from([( + "page".to_string(), + JsonSchema::Number { description: None } + )]), + required: None, + additional_properties: None, + }, + description: "Pagination".to_string(), + strict: false, + output_schema: Some(mcp_call_tool_result_output_schema(serde_json::json!({}))), + defer_loading: None, + }) + ); +} + +#[test] +fn test_mcp_tool_array_without_items_gets_default_string_items() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::UnifiedExec); + features.enable(Feature::ApplyPatchFreeform); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + + let (tools, _) = build_specs( + &tools_config, + Some(HashMap::from([( + "dash/tags".to_string(), + mcp_tool( + "tags", + "Tags", + serde_json::json!({ + "type": "object", + "properties": {"tags": {"type": "array"}} + }), + ), + )])), + None, + &[], + ) + .build(); + + let tool = find_tool(&tools, "dash/tags"); + assert_eq!( + tool.spec, + ToolSpec::Function(ResponsesApiTool { + name: "dash/tags".to_string(), + parameters: JsonSchema::Object { + properties: BTreeMap::from([( + "tags".to_string(), + JsonSchema::Array { + items: Box::new(JsonSchema::String { description: None }), + description: None + } + )]), + required: None, + additional_properties: None, + }, + description: "Tags".to_string(), + strict: false, + output_schema: Some(mcp_call_tool_result_output_schema(serde_json::json!({}))), + defer_loading: None, + }) + ); +} + +#[test] +fn test_mcp_tool_anyof_defaults_to_string() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::UnifiedExec); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + + let (tools, _) = build_specs( + &tools_config, + Some(HashMap::from([( + "dash/value".to_string(), + mcp_tool( + "value", + "AnyOf Value", + serde_json::json!({ + "type": "object", + "properties": { + "value": {"anyOf": [{"type": "string"}, {"type": "number"}]} + } + }), + ), + )])), + None, + &[], + ) + .build(); + + let tool = find_tool(&tools, "dash/value"); + assert_eq!( + tool.spec, + ToolSpec::Function(ResponsesApiTool { + name: "dash/value".to_string(), + parameters: JsonSchema::Object { + properties: BTreeMap::from([( + "value".to_string(), + JsonSchema::String { description: None } + )]), + required: None, + additional_properties: None, + }, + description: "AnyOf Value".to_string(), + strict: false, + output_schema: Some(mcp_call_tool_result_output_schema(serde_json::json!({}))), + defer_loading: None, + }) + ); +} + +#[test] +fn test_shell_tool() { + let tool = super::create_shell_tool(false); + let ToolSpec::Function(ResponsesApiTool { + description, name, .. + }) = &tool + else { + panic!("expected function tool"); + }; + assert_eq!(name, "shell"); + + let expected = if cfg!(windows) { + r#"Runs a Powershell command (Windows) and returns its output. Arguments to `shell` will be passed to CreateProcessW(). Most commands should be prefixed with ["powershell.exe", "-Command"]. + +Examples of valid command strings: + +- ls -a (show hidden): ["powershell.exe", "-Command", "Get-ChildItem -Force"] +- recursive find by name: ["powershell.exe", "-Command", "Get-ChildItem -Recurse -Filter *.py"] +- recursive grep: ["powershell.exe", "-Command", "Get-ChildItem -Path C:\\myrepo -Recurse | Select-String -Pattern 'TODO' -CaseSensitive"] +- ps aux | grep python: ["powershell.exe", "-Command", "Get-Process | Where-Object { $_.ProcessName -like '*python*' }"] +- setting an env var: ["powershell.exe", "-Command", "$env:FOO='bar'; echo $env:FOO"] +- running an inline Python script: ["powershell.exe", "-Command", "@'\\nprint('Hello, world!')\\n'@ | python -"]"# + } else { + r#"Runs a shell command and returns its output. +- The arguments to `shell` will be passed to execvp(). Most terminal commands should be prefixed with ["bash", "-lc"]. +- Always set the `workdir` param when using the shell function. Do not use `cd` unless absolutely necessary."# + }.to_string(); + assert_eq!(description, &expected); +} + +#[test] +fn shell_tool_with_request_permission_includes_additional_permissions() { + let tool = super::create_shell_tool(true); + let ToolSpec::Function(ResponsesApiTool { parameters, .. }) = tool else { + panic!("expected function tool"); + }; + let JsonSchema::Object { properties, .. } = parameters else { + panic!("expected object parameters"); + }; + + assert!(properties.contains_key("additional_permissions")); + + let Some(JsonSchema::String { + description: Some(description), + }) = properties.get("sandbox_permissions") + else { + panic!("expected sandbox_permissions description"); + }; + assert!(description.contains("with_additional_permissions")); + assert!(description.contains("macOS permissions")); + + let Some(JsonSchema::Object { + properties: additional_properties, + .. + }) = properties.get("additional_permissions") + else { + panic!("expected additional_permissions schema"); + }; + assert!(additional_properties.contains_key("network")); + assert!(additional_properties.contains_key("file_system")); + assert!(additional_properties.contains_key("macos")); +} + +#[test] +fn request_permissions_tool_includes_full_permission_schema() { + let tool = super::create_request_permissions_tool(); + let ToolSpec::Function(ResponsesApiTool { parameters, .. }) = tool else { + panic!("expected function tool"); + }; + let JsonSchema::Object { properties, .. } = parameters else { + panic!("expected object parameters"); + }; + let Some(JsonSchema::Object { + properties: permission_properties, + additional_properties, + .. + }) = properties.get("permissions") + else { + panic!("expected permissions object"); + }; + + assert_eq!(additional_properties, &Some(false.into())); + assert!(permission_properties.contains_key("network")); + assert!(permission_properties.contains_key("file_system")); + assert!(permission_properties.contains_key("macos")); + + let Some(JsonSchema::Object { + properties: network_properties, + additional_properties, + .. + }) = permission_properties.get("network") + else { + panic!("expected network object"); + }; + assert_eq!(additional_properties, &Some(false.into())); + assert!(network_properties.contains_key("enabled")); + + let Some(JsonSchema::Object { + properties: file_system_properties, + additional_properties, + .. + }) = permission_properties.get("file_system") + else { + panic!("expected file_system object"); + }; + assert_eq!(additional_properties, &Some(false.into())); + assert!(file_system_properties.contains_key("read")); + assert!(file_system_properties.contains_key("write")); + + let Some(JsonSchema::Object { + properties: macos_properties, + additional_properties, + .. + }) = permission_properties.get("macos") + else { + panic!("expected macos object"); + }; + assert_eq!(additional_properties, &Some(false.into())); + assert!(macos_properties.contains_key("preferences")); + assert!(macos_properties.contains_key("automations")); + assert!(macos_properties.contains_key("accessibility")); + assert!(macos_properties.contains_key("calendar")); +} + +#[test] +fn test_shell_command_tool() { + let tool = super::create_shell_command_tool(true, false); + let ToolSpec::Function(ResponsesApiTool { + description, name, .. + }) = &tool + else { + panic!("expected function tool"); + }; + assert_eq!(name, "shell_command"); + + let expected = if cfg!(windows) { + r#"Runs a Powershell command (Windows) and returns its output. + +Examples of valid command strings: + +- ls -a (show hidden): "Get-ChildItem -Force" +- recursive find by name: "Get-ChildItem -Recurse -Filter *.py" +- recursive grep: "Get-ChildItem -Path C:\\myrepo -Recurse | Select-String -Pattern 'TODO' -CaseSensitive" +- ps aux | grep python: "Get-Process | Where-Object { $_.ProcessName -like '*python*' }" +- setting an env var: "$env:FOO='bar'; echo $env:FOO" +- running an inline Python script: "@'\\nprint('Hello, world!')\\n'@ | python -"#.to_string() + } else { + r#"Runs a shell command and returns its output. +- Always set the `workdir` param when using the shell_command function. Do not use `cd` unless absolutely necessary."#.to_string() + }; + assert_eq!(description, &expected); +} + +#[test] +fn test_get_openai_tools_mcp_tools_with_additional_properties_schema() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::UnifiedExec); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + let (tools, _) = build_specs( + &tools_config, + Some(HashMap::from([( + "test_server/do_something_cool".to_string(), + mcp_tool( + "do_something_cool", + "Do something cool", + serde_json::json!({ + "type": "object", + "properties": { + "string_argument": {"type": "string"}, + "number_argument": {"type": "number"}, + "object_argument": { + "type": "object", + "properties": { + "string_property": {"type": "string"}, + "number_property": {"type": "number"} + }, + "required": ["string_property", "number_property"], + "additionalProperties": { + "type": "object", + "properties": { + "addtl_prop": {"type": "string"} + }, + "required": ["addtl_prop"], + "additionalProperties": false + } + } + } + }), + ), + )])), + None, + &[], + ) + .build(); + + let tool = find_tool(&tools, "test_server/do_something_cool"); + assert_eq!( + tool.spec, + ToolSpec::Function(ResponsesApiTool { + name: "test_server/do_something_cool".to_string(), + parameters: JsonSchema::Object { + properties: BTreeMap::from([ + ( + "string_argument".to_string(), + JsonSchema::String { description: None } + ), + ( + "number_argument".to_string(), + JsonSchema::Number { description: None } + ), + ( + "object_argument".to_string(), + JsonSchema::Object { + properties: BTreeMap::from([ + ( + "string_property".to_string(), + JsonSchema::String { description: None } + ), + ( + "number_property".to_string(), + JsonSchema::Number { description: None } + ), + ]), + required: Some(vec![ + "string_property".to_string(), + "number_property".to_string(), + ]), + additional_properties: Some( + JsonSchema::Object { + properties: BTreeMap::from([( + "addtl_prop".to_string(), + JsonSchema::String { description: None } + ),]), + required: Some(vec!["addtl_prop".to_string(),]), + additional_properties: Some(false.into()), + } + .into() + ), + }, + ), + ]), + required: None, + additional_properties: None, + }, + description: "Do something cool".to_string(), + strict: false, + output_schema: Some(mcp_call_tool_result_output_schema(serde_json::json!({}))), + defer_loading: None, + }) + ); +} + +#[test] +fn code_mode_augments_builtin_tool_descriptions_with_typed_sample() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::CodeMode); + features.enable(Feature::UnifiedExec); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + + let (tools, _) = build_specs(&tools_config, None, None, &[]).build(); + let ToolSpec::Function(ResponsesApiTool { description, .. }) = + &find_tool(&tools, "view_image").spec + else { + panic!("expected function tool"); + }; + + assert_eq!( + description, + "View a local image from the filesystem (only use if given a full filepath by the user, and the image isn't already attached to the thread context within tags).\n\nCode mode declaration:\n```ts\nimport { view_image } from \"tools.js\";\ndeclare function view_image(args: {\n path: string;\n}): Promise;\n```" + ); +} + +#[test] +fn code_mode_augments_mcp_tool_descriptions_with_namespaced_sample() { + let config = test_config(); + let model_info = ModelsManager::construct_model_info_offline_for_tests("gpt-5-codex", &config); + let mut features = Features::with_defaults(); + features.enable(Feature::CodeMode); + features.enable(Feature::UnifiedExec); + let available_models = Vec::new(); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_info: &model_info, + available_models: &available_models, + features: &features, + web_search_mode: Some(WebSearchMode::Cached), + session_source: SessionSource::Cli, + }); + + let (tools, _) = build_specs( + &tools_config, + Some(HashMap::from([( + "mcp__sample__echo".to_string(), + mcp_tool( + "echo", + "Echo text", + serde_json::json!({ + "type": "object", + "properties": { + "message": {"type": "string"} + }, + "required": ["message"], + "additionalProperties": false + }), + ), + )])), + None, + &[], + ) + .build(); + + let ToolSpec::Function(ResponsesApiTool { description, .. }) = + &find_tool(&tools, "mcp__sample__echo").spec + else { + panic!("expected function tool"); + }; + + assert_eq!( + description, + "Echo text\n\nCode mode declaration:\n```ts\nimport { echo } from \"tools/mcp/sample.js\";\ndeclare function echo(args: {\n message: string;\n}): Promise<{\n _meta?: unknown;\n content: Array;\n isError?: boolean;\n structuredContent?: unknown;\n}>;\n```" + ); +} + +#[test] +fn chat_tools_include_top_level_name() { + let properties = + BTreeMap::from([("foo".to_string(), JsonSchema::String { description: None })]); + let tools = vec![ToolSpec::Function(ResponsesApiTool { + name: "demo".to_string(), + description: "A demo tool".to_string(), + strict: false, + defer_loading: None, + parameters: JsonSchema::Object { + properties, + required: None, + additional_properties: None, + }, + output_schema: None, + })]; + + let responses_json = create_tools_json_for_responses_api(&tools).unwrap(); + assert_eq!( + responses_json, + vec![json!({ + "type": "function", + "name": "demo", + "description": "A demo tool", + "strict": false, + "parameters": { + "type": "object", + "properties": { + "foo": { "type": "string" } + }, + }, + })] + ); +} diff --git a/codex-rs/core/src/truncate.rs b/codex-rs/core/src/truncate.rs index 927d7c9380..707fbe22ef 100644 --- a/codex-rs/core/src/truncate.rs +++ b/codex-rs/core/src/truncate.rs @@ -359,319 +359,5 @@ pub(crate) fn approx_tokens_from_byte_count_i64(bytes: i64) -> i64 { } #[cfg(test)] -mod tests { - - use super::TruncationPolicy; - use super::approx_token_count; - use super::formatted_truncate_text; - use super::formatted_truncate_text_content_items_with_policy; - use super::split_string; - use super::truncate_function_output_items_with_policy; - use super::truncate_text; - use super::truncate_with_token_budget; - use codex_protocol::models::FunctionCallOutputContentItem; - use pretty_assertions::assert_eq; - - #[test] - fn split_string_works() { - assert_eq!(split_string("hello world", 5, 5), (1, "hello", "world")); - assert_eq!(split_string("abc", 0, 0), (3, "", "")); - } - - #[test] - fn split_string_handles_empty_string() { - assert_eq!(split_string("", 4, 4), (0, "", "")); - } - - #[test] - fn split_string_only_keeps_prefix_when_tail_budget_is_zero() { - assert_eq!(split_string("abcdef", 3, 0), (3, "abc", "")); - } - - #[test] - fn split_string_only_keeps_suffix_when_prefix_budget_is_zero() { - assert_eq!(split_string("abcdef", 0, 3), (3, "", "def")); - } - - #[test] - fn split_string_handles_overlapping_budgets_without_removal() { - assert_eq!(split_string("abcdef", 4, 4), (0, "abcd", "ef")); - } - - #[test] - fn split_string_respects_utf8_boundaries() { - assert_eq!(split_string("😀abc😀", 5, 5), (1, "😀a", "c😀")); - - assert_eq!(split_string("😀😀😀😀😀", 1, 1), (5, "", "")); - assert_eq!(split_string("😀😀😀😀😀", 7, 7), (3, "😀", "😀")); - assert_eq!(split_string("😀😀😀😀😀", 8, 8), (1, "😀😀", "😀😀")); - } - - #[test] - fn truncate_bytes_less_than_placeholder_returns_placeholder() { - let content = "example output"; - - assert_eq!( - "Total output lines: 1\n\n…13 chars truncated…t", - formatted_truncate_text(content, TruncationPolicy::Bytes(1)), - ); - } - - #[test] - fn truncate_tokens_less_than_placeholder_returns_placeholder() { - let content = "example output"; - - assert_eq!( - "Total output lines: 1\n\nex…3 tokens truncated…ut", - formatted_truncate_text(content, TruncationPolicy::Tokens(1)), - ); - } - - #[test] - fn truncate_tokens_under_limit_returns_original() { - let content = "example output"; - - assert_eq!( - content, - formatted_truncate_text(content, TruncationPolicy::Tokens(10)), - ); - } - - #[test] - fn truncate_bytes_under_limit_returns_original() { - let content = "example output"; - - assert_eq!( - content, - formatted_truncate_text(content, TruncationPolicy::Bytes(20)), - ); - } - - #[test] - fn truncate_tokens_over_limit_returns_truncated() { - let content = "this is an example of a long output that should be truncated"; - - assert_eq!( - "Total output lines: 1\n\nthis is an…10 tokens truncated… truncated", - formatted_truncate_text(content, TruncationPolicy::Tokens(5)), - ); - } - - #[test] - fn truncate_bytes_over_limit_returns_truncated() { - let content = "this is an example of a long output that should be truncated"; - - assert_eq!( - "Total output lines: 1\n\nthis is an exam…30 chars truncated…ld be truncated", - formatted_truncate_text(content, TruncationPolicy::Bytes(30)), - ); - } - - #[test] - fn truncate_bytes_reports_original_line_count_when_truncated() { - let content = - "this is an example of a long output that should be truncated\nalso some other line"; - - assert_eq!( - "Total output lines: 2\n\nthis is an exam…51 chars truncated…some other line", - formatted_truncate_text(content, TruncationPolicy::Bytes(30)), - ); - } - - #[test] - fn truncate_tokens_reports_original_line_count_when_truncated() { - let content = - "this is an example of a long output that should be truncated\nalso some other line"; - - assert_eq!( - "Total output lines: 2\n\nthis is an example o…11 tokens truncated…also some other line", - formatted_truncate_text(content, TruncationPolicy::Tokens(10)), - ); - } - - #[test] - fn truncate_with_token_budget_returns_original_when_under_limit() { - let s = "short output"; - let limit = 100; - let (out, original) = truncate_with_token_budget(s, TruncationPolicy::Tokens(limit)); - assert_eq!(out, s); - assert_eq!(original, None); - } - - #[test] - fn truncate_with_token_budget_reports_truncation_at_zero_limit() { - let s = "abcdef"; - let (out, original) = truncate_with_token_budget(s, TruncationPolicy::Tokens(0)); - assert_eq!(out, "…2 tokens truncated…"); - assert_eq!(original, Some(2)); - } - - #[test] - fn truncate_middle_tokens_handles_utf8_content() { - let s = "😀😀😀😀😀😀😀😀😀😀\nsecond line with text\n"; - let (out, tokens) = truncate_with_token_budget(s, TruncationPolicy::Tokens(8)); - assert_eq!(out, "😀😀😀😀…8 tokens truncated… line with text\n"); - assert_eq!(tokens, Some(16)); - } - - #[test] - fn truncate_middle_bytes_handles_utf8_content() { - let s = "😀😀😀😀😀😀😀😀😀😀\nsecond line with text\n"; - let out = truncate_text(s, TruncationPolicy::Bytes(20)); - assert_eq!(out, "😀😀…21 chars truncated…with text\n"); - } - - #[test] - fn truncates_across_multiple_under_limit_texts_and_reports_omitted() { - let chunk = "alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu nu xi omicron pi rho sigma tau upsilon phi chi psi omega.\n"; - let chunk_tokens = approx_token_count(chunk); - assert!(chunk_tokens > 0, "chunk must consume tokens"); - let limit = chunk_tokens * 3; - let t1 = chunk.to_string(); - let t2 = chunk.to_string(); - let t3 = chunk.repeat(10); - let t4 = chunk.to_string(); - let t5 = chunk.to_string(); - - let items = vec![ - FunctionCallOutputContentItem::InputText { text: t1.clone() }, - FunctionCallOutputContentItem::InputText { text: t2.clone() }, - FunctionCallOutputContentItem::InputImage { - image_url: "img:mid".to_string(), - detail: None, - }, - FunctionCallOutputContentItem::InputText { text: t3 }, - FunctionCallOutputContentItem::InputText { text: t4 }, - FunctionCallOutputContentItem::InputText { text: t5 }, - ]; - - let output = - truncate_function_output_items_with_policy(&items, TruncationPolicy::Tokens(limit)); - - // Expect: t1 (full), t2 (full), image, t3 (truncated), summary mentioning 2 omitted. - assert_eq!(output.len(), 5); - - let first_text = match &output[0] { - FunctionCallOutputContentItem::InputText { text } => text, - other => panic!("unexpected first item: {other:?}"), - }; - assert_eq!(first_text, &t1); - - let second_text = match &output[1] { - FunctionCallOutputContentItem::InputText { text } => text, - other => panic!("unexpected second item: {other:?}"), - }; - assert_eq!(second_text, &t2); - - assert_eq!( - output[2], - FunctionCallOutputContentItem::InputImage { - image_url: "img:mid".to_string(), - detail: None, - } - ); - - let fourth_text = match &output[3] { - FunctionCallOutputContentItem::InputText { text } => text, - other => panic!("unexpected fourth item: {other:?}"), - }; - assert!( - fourth_text.contains("tokens truncated"), - "expected marker in truncated snippet: {fourth_text}" - ); - - let summary_text = match &output[4] { - FunctionCallOutputContentItem::InputText { text } => text, - other => panic!("unexpected summary item: {other:?}"), - }; - assert!(summary_text.contains("omitted 2 text items")); - } - - #[test] - fn formatted_truncate_text_content_items_with_policy_returns_original_under_limit() { - let items = vec![ - FunctionCallOutputContentItem::InputText { - text: "alpha".to_string(), - }, - FunctionCallOutputContentItem::InputText { - text: String::new(), - }, - FunctionCallOutputContentItem::InputText { - text: "beta".to_string(), - }, - ]; - - let (output, original_token_count) = - formatted_truncate_text_content_items_with_policy(&items, TruncationPolicy::Bytes(32)); - - assert_eq!(output, items); - assert_eq!(original_token_count, None); - } - - #[test] - fn formatted_truncate_text_content_items_with_policy_merges_text_and_appends_images() { - let items = vec![ - FunctionCallOutputContentItem::InputText { - text: "abcd".to_string(), - }, - FunctionCallOutputContentItem::InputImage { - image_url: "img:one".to_string(), - detail: None, - }, - FunctionCallOutputContentItem::InputText { - text: "efgh".to_string(), - }, - FunctionCallOutputContentItem::InputText { - text: "ijkl".to_string(), - }, - FunctionCallOutputContentItem::InputImage { - image_url: "img:two".to_string(), - detail: None, - }, - ]; - - let (output, original_token_count) = - formatted_truncate_text_content_items_with_policy(&items, TruncationPolicy::Bytes(8)); - - assert_eq!( - output, - vec![ - FunctionCallOutputContentItem::InputText { - text: "Total output lines: 3\n\nabcd…6 chars truncated…ijkl".to_string(), - }, - FunctionCallOutputContentItem::InputImage { - image_url: "img:one".to_string(), - detail: None, - }, - FunctionCallOutputContentItem::InputImage { - image_url: "img:two".to_string(), - detail: None, - }, - ] - ); - assert_eq!(original_token_count, Some(4)); - } - - #[test] - fn formatted_truncate_text_content_items_with_policy_merges_all_text_for_token_budget() { - let items = vec![ - FunctionCallOutputContentItem::InputText { - text: "abcdefgh".to_string(), - }, - FunctionCallOutputContentItem::InputText { - text: "ijklmnop".to_string(), - }, - ]; - - let (output, original_token_count) = - formatted_truncate_text_content_items_with_policy(&items, TruncationPolicy::Tokens(2)); - - assert_eq!( - output, - vec![FunctionCallOutputContentItem::InputText { - text: "Total output lines: 2\n\nabcd…3 tokens truncated…mnop".to_string(), - }] - ); - assert_eq!(original_token_count, Some(5)); - } -} +#[path = "truncate_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/truncate_tests.rs b/codex-rs/core/src/truncate_tests.rs new file mode 100644 index 0000000000..5a61a9a26d --- /dev/null +++ b/codex-rs/core/src/truncate_tests.rs @@ -0,0 +1,313 @@ +use super::TruncationPolicy; +use super::approx_token_count; +use super::formatted_truncate_text; +use super::formatted_truncate_text_content_items_with_policy; +use super::split_string; +use super::truncate_function_output_items_with_policy; +use super::truncate_text; +use super::truncate_with_token_budget; +use codex_protocol::models::FunctionCallOutputContentItem; +use pretty_assertions::assert_eq; + +#[test] +fn split_string_works() { + assert_eq!(split_string("hello world", 5, 5), (1, "hello", "world")); + assert_eq!(split_string("abc", 0, 0), (3, "", "")); +} + +#[test] +fn split_string_handles_empty_string() { + assert_eq!(split_string("", 4, 4), (0, "", "")); +} + +#[test] +fn split_string_only_keeps_prefix_when_tail_budget_is_zero() { + assert_eq!(split_string("abcdef", 3, 0), (3, "abc", "")); +} + +#[test] +fn split_string_only_keeps_suffix_when_prefix_budget_is_zero() { + assert_eq!(split_string("abcdef", 0, 3), (3, "", "def")); +} + +#[test] +fn split_string_handles_overlapping_budgets_without_removal() { + assert_eq!(split_string("abcdef", 4, 4), (0, "abcd", "ef")); +} + +#[test] +fn split_string_respects_utf8_boundaries() { + assert_eq!(split_string("😀abc😀", 5, 5), (1, "😀a", "c😀")); + + assert_eq!(split_string("😀😀😀😀😀", 1, 1), (5, "", "")); + assert_eq!(split_string("😀😀😀😀😀", 7, 7), (3, "😀", "😀")); + assert_eq!(split_string("😀😀😀😀😀", 8, 8), (1, "😀😀", "😀😀")); +} + +#[test] +fn truncate_bytes_less_than_placeholder_returns_placeholder() { + let content = "example output"; + + assert_eq!( + "Total output lines: 1\n\n…13 chars truncated…t", + formatted_truncate_text(content, TruncationPolicy::Bytes(1)), + ); +} + +#[test] +fn truncate_tokens_less_than_placeholder_returns_placeholder() { + let content = "example output"; + + assert_eq!( + "Total output lines: 1\n\nex…3 tokens truncated…ut", + formatted_truncate_text(content, TruncationPolicy::Tokens(1)), + ); +} + +#[test] +fn truncate_tokens_under_limit_returns_original() { + let content = "example output"; + + assert_eq!( + content, + formatted_truncate_text(content, TruncationPolicy::Tokens(10)), + ); +} + +#[test] +fn truncate_bytes_under_limit_returns_original() { + let content = "example output"; + + assert_eq!( + content, + formatted_truncate_text(content, TruncationPolicy::Bytes(20)), + ); +} + +#[test] +fn truncate_tokens_over_limit_returns_truncated() { + let content = "this is an example of a long output that should be truncated"; + + assert_eq!( + "Total output lines: 1\n\nthis is an…10 tokens truncated… truncated", + formatted_truncate_text(content, TruncationPolicy::Tokens(5)), + ); +} + +#[test] +fn truncate_bytes_over_limit_returns_truncated() { + let content = "this is an example of a long output that should be truncated"; + + assert_eq!( + "Total output lines: 1\n\nthis is an exam…30 chars truncated…ld be truncated", + formatted_truncate_text(content, TruncationPolicy::Bytes(30)), + ); +} + +#[test] +fn truncate_bytes_reports_original_line_count_when_truncated() { + let content = + "this is an example of a long output that should be truncated\nalso some other line"; + + assert_eq!( + "Total output lines: 2\n\nthis is an exam…51 chars truncated…some other line", + formatted_truncate_text(content, TruncationPolicy::Bytes(30)), + ); +} + +#[test] +fn truncate_tokens_reports_original_line_count_when_truncated() { + let content = + "this is an example of a long output that should be truncated\nalso some other line"; + + assert_eq!( + "Total output lines: 2\n\nthis is an example o…11 tokens truncated…also some other line", + formatted_truncate_text(content, TruncationPolicy::Tokens(10)), + ); +} + +#[test] +fn truncate_with_token_budget_returns_original_when_under_limit() { + let s = "short output"; + let limit = 100; + let (out, original) = truncate_with_token_budget(s, TruncationPolicy::Tokens(limit)); + assert_eq!(out, s); + assert_eq!(original, None); +} + +#[test] +fn truncate_with_token_budget_reports_truncation_at_zero_limit() { + let s = "abcdef"; + let (out, original) = truncate_with_token_budget(s, TruncationPolicy::Tokens(0)); + assert_eq!(out, "…2 tokens truncated…"); + assert_eq!(original, Some(2)); +} + +#[test] +fn truncate_middle_tokens_handles_utf8_content() { + let s = "😀😀😀😀😀😀😀😀😀😀\nsecond line with text\n"; + let (out, tokens) = truncate_with_token_budget(s, TruncationPolicy::Tokens(8)); + assert_eq!(out, "😀😀😀😀…8 tokens truncated… line with text\n"); + assert_eq!(tokens, Some(16)); +} + +#[test] +fn truncate_middle_bytes_handles_utf8_content() { + let s = "😀😀😀😀😀😀😀😀😀😀\nsecond line with text\n"; + let out = truncate_text(s, TruncationPolicy::Bytes(20)); + assert_eq!(out, "😀😀…21 chars truncated…with text\n"); +} + +#[test] +fn truncates_across_multiple_under_limit_texts_and_reports_omitted() { + let chunk = "alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu nu xi omicron pi rho sigma tau upsilon phi chi psi omega.\n"; + let chunk_tokens = approx_token_count(chunk); + assert!(chunk_tokens > 0, "chunk must consume tokens"); + let limit = chunk_tokens * 3; + let t1 = chunk.to_string(); + let t2 = chunk.to_string(); + let t3 = chunk.repeat(10); + let t4 = chunk.to_string(); + let t5 = chunk.to_string(); + + let items = vec![ + FunctionCallOutputContentItem::InputText { text: t1.clone() }, + FunctionCallOutputContentItem::InputText { text: t2.clone() }, + FunctionCallOutputContentItem::InputImage { + image_url: "img:mid".to_string(), + detail: None, + }, + FunctionCallOutputContentItem::InputText { text: t3 }, + FunctionCallOutputContentItem::InputText { text: t4 }, + FunctionCallOutputContentItem::InputText { text: t5 }, + ]; + + let output = + truncate_function_output_items_with_policy(&items, TruncationPolicy::Tokens(limit)); + + // Expect: t1 (full), t2 (full), image, t3 (truncated), summary mentioning 2 omitted. + assert_eq!(output.len(), 5); + + let first_text = match &output[0] { + FunctionCallOutputContentItem::InputText { text } => text, + other => panic!("unexpected first item: {other:?}"), + }; + assert_eq!(first_text, &t1); + + let second_text = match &output[1] { + FunctionCallOutputContentItem::InputText { text } => text, + other => panic!("unexpected second item: {other:?}"), + }; + assert_eq!(second_text, &t2); + + assert_eq!( + output[2], + FunctionCallOutputContentItem::InputImage { + image_url: "img:mid".to_string(), + detail: None, + } + ); + + let fourth_text = match &output[3] { + FunctionCallOutputContentItem::InputText { text } => text, + other => panic!("unexpected fourth item: {other:?}"), + }; + assert!( + fourth_text.contains("tokens truncated"), + "expected marker in truncated snippet: {fourth_text}" + ); + + let summary_text = match &output[4] { + FunctionCallOutputContentItem::InputText { text } => text, + other => panic!("unexpected summary item: {other:?}"), + }; + assert!(summary_text.contains("omitted 2 text items")); +} + +#[test] +fn formatted_truncate_text_content_items_with_policy_returns_original_under_limit() { + let items = vec![ + FunctionCallOutputContentItem::InputText { + text: "alpha".to_string(), + }, + FunctionCallOutputContentItem::InputText { + text: String::new(), + }, + FunctionCallOutputContentItem::InputText { + text: "beta".to_string(), + }, + ]; + + let (output, original_token_count) = + formatted_truncate_text_content_items_with_policy(&items, TruncationPolicy::Bytes(32)); + + assert_eq!(output, items); + assert_eq!(original_token_count, None); +} + +#[test] +fn formatted_truncate_text_content_items_with_policy_merges_text_and_appends_images() { + let items = vec![ + FunctionCallOutputContentItem::InputText { + text: "abcd".to_string(), + }, + FunctionCallOutputContentItem::InputImage { + image_url: "img:one".to_string(), + detail: None, + }, + FunctionCallOutputContentItem::InputText { + text: "efgh".to_string(), + }, + FunctionCallOutputContentItem::InputText { + text: "ijkl".to_string(), + }, + FunctionCallOutputContentItem::InputImage { + image_url: "img:two".to_string(), + detail: None, + }, + ]; + + let (output, original_token_count) = + formatted_truncate_text_content_items_with_policy(&items, TruncationPolicy::Bytes(8)); + + assert_eq!( + output, + vec![ + FunctionCallOutputContentItem::InputText { + text: "Total output lines: 3\n\nabcd…6 chars truncated…ijkl".to_string(), + }, + FunctionCallOutputContentItem::InputImage { + image_url: "img:one".to_string(), + detail: None, + }, + FunctionCallOutputContentItem::InputImage { + image_url: "img:two".to_string(), + detail: None, + }, + ] + ); + assert_eq!(original_token_count, Some(4)); +} + +#[test] +fn formatted_truncate_text_content_items_with_policy_merges_all_text_for_token_budget() { + let items = vec![ + FunctionCallOutputContentItem::InputText { + text: "abcdefgh".to_string(), + }, + FunctionCallOutputContentItem::InputText { + text: "ijklmnop".to_string(), + }, + ]; + + let (output, original_token_count) = + formatted_truncate_text_content_items_with_policy(&items, TruncationPolicy::Tokens(2)); + + assert_eq!( + output, + vec![FunctionCallOutputContentItem::InputText { + text: "Total output lines: 2\n\nabcd…3 tokens truncated…mnop".to_string(), + }] + ); + assert_eq!(original_token_count, Some(5)); +} diff --git a/codex-rs/core/src/turn_diff_tracker.rs b/codex-rs/core/src/turn_diff_tracker.rs index 06c40deb90..3568c915af 100644 --- a/codex-rs/core/src/turn_diff_tracker.rs +++ b/codex-rs/core/src/turn_diff_tracker.rs @@ -465,432 +465,5 @@ fn is_windows_drive_or_unc_root(p: &std::path::Path) -> bool { } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - use tempfile::tempdir; - - /// Compute the Git SHA-1 blob object ID for the given content (string). - /// This delegates to the bytes version to avoid UTF-8 lossy conversions here. - fn git_blob_sha1_hex(data: &str) -> String { - format!("{:x}", git_blob_sha1_hex_bytes(data.as_bytes())) - } - - fn normalize_diff_for_test(input: &str, root: &Path) -> String { - let root_str = root.display().to_string().replace('\\', "/"); - let replaced = input.replace(&root_str, ""); - // Split into blocks on lines starting with "diff --git ", sort blocks for determinism, and rejoin - let mut blocks: Vec = Vec::new(); - let mut current = String::new(); - for line in replaced.lines() { - if line.starts_with("diff --git ") && !current.is_empty() { - blocks.push(current); - current = String::new(); - } - if !current.is_empty() { - current.push('\n'); - } - current.push_str(line); - } - if !current.is_empty() { - blocks.push(current); - } - blocks.sort(); - let mut out = blocks.join("\n"); - if !out.ends_with('\n') { - out.push('\n'); - } - out - } - - #[test] - fn accumulates_add_and_update() { - let mut acc = TurnDiffTracker::new(); - - let dir = tempdir().unwrap(); - let file = dir.path().join("a.txt"); - - // First patch: add file (baseline should be /dev/null). - let add_changes = HashMap::from([( - file.clone(), - FileChange::Add { - content: "foo\n".to_string(), - }, - )]); - acc.on_patch_begin(&add_changes); - - // Simulate apply: create the file on disk. - fs::write(&file, "foo\n").unwrap(); - let first = acc.get_unified_diff().unwrap().unwrap(); - let first = normalize_diff_for_test(&first, dir.path()); - let expected_first = { - let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular); - let right_oid = git_blob_sha1_hex("foo\n"); - format!( - r#"diff --git a//a.txt b//a.txt -new file mode {mode} -index {ZERO_OID}..{right_oid} ---- {DEV_NULL} -+++ b//a.txt -@@ -0,0 +1 @@ -+foo -"#, - ) - }; - assert_eq!(first, expected_first); - - // Second patch: update the file on disk. - let update_changes = HashMap::from([( - file.clone(), - FileChange::Update { - unified_diff: "".to_owned(), - move_path: None, - }, - )]); - acc.on_patch_begin(&update_changes); - - // Simulate apply: append a new line. - fs::write(&file, "foo\nbar\n").unwrap(); - let combined = acc.get_unified_diff().unwrap().unwrap(); - let combined = normalize_diff_for_test(&combined, dir.path()); - let expected_combined = { - let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular); - let right_oid = git_blob_sha1_hex("foo\nbar\n"); - format!( - r#"diff --git a//a.txt b//a.txt -new file mode {mode} -index {ZERO_OID}..{right_oid} ---- {DEV_NULL} -+++ b//a.txt -@@ -0,0 +1,2 @@ -+foo -+bar -"#, - ) - }; - assert_eq!(combined, expected_combined); - } - - #[test] - fn accumulates_delete() { - let dir = tempdir().unwrap(); - let file = dir.path().join("b.txt"); - fs::write(&file, "x\n").unwrap(); - - let mut acc = TurnDiffTracker::new(); - let del_changes = HashMap::from([( - file.clone(), - FileChange::Delete { - content: "x\n".to_string(), - }, - )]); - acc.on_patch_begin(&del_changes); - - // Simulate apply: delete the file from disk. - let baseline_mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular); - fs::remove_file(&file).unwrap(); - let diff = acc.get_unified_diff().unwrap().unwrap(); - let diff = normalize_diff_for_test(&diff, dir.path()); - let expected = { - let left_oid = git_blob_sha1_hex("x\n"); - format!( - r#"diff --git a//b.txt b//b.txt -deleted file mode {baseline_mode} -index {left_oid}..{ZERO_OID} ---- a//b.txt -+++ {DEV_NULL} -@@ -1 +0,0 @@ --x -"#, - ) - }; - assert_eq!(diff, expected); - } - - #[test] - fn accumulates_move_and_update() { - let dir = tempdir().unwrap(); - let src = dir.path().join("src.txt"); - let dest = dir.path().join("dst.txt"); - fs::write(&src, "line\n").unwrap(); - - let mut acc = TurnDiffTracker::new(); - let mv_changes = HashMap::from([( - src.clone(), - FileChange::Update { - unified_diff: "".to_owned(), - move_path: Some(dest.clone()), - }, - )]); - acc.on_patch_begin(&mv_changes); - - // Simulate apply: move and update content. - fs::rename(&src, &dest).unwrap(); - fs::write(&dest, "line2\n").unwrap(); - - let out = acc.get_unified_diff().unwrap().unwrap(); - let out = normalize_diff_for_test(&out, dir.path()); - let expected = { - let left_oid = git_blob_sha1_hex("line\n"); - let right_oid = git_blob_sha1_hex("line2\n"); - format!( - r#"diff --git a//src.txt b//dst.txt -index {left_oid}..{right_oid} ---- a//src.txt -+++ b//dst.txt -@@ -1 +1 @@ --line -+line2 -"# - ) - }; - assert_eq!(out, expected); - } - - #[test] - fn move_without_1change_yields_no_diff() { - let dir = tempdir().unwrap(); - let src = dir.path().join("moved.txt"); - let dest = dir.path().join("renamed.txt"); - fs::write(&src, "same\n").unwrap(); - - let mut acc = TurnDiffTracker::new(); - let mv_changes = HashMap::from([( - src.clone(), - FileChange::Update { - unified_diff: "".to_owned(), - move_path: Some(dest.clone()), - }, - )]); - acc.on_patch_begin(&mv_changes); - - // Simulate apply: move only, no content change. - fs::rename(&src, &dest).unwrap(); - - let diff = acc.get_unified_diff().unwrap(); - assert_eq!(diff, None); - } - - #[test] - fn move_declared_but_file_only_appears_at_dest_is_add() { - let dir = tempdir().unwrap(); - let src = dir.path().join("src.txt"); - let dest = dir.path().join("dest.txt"); - let mut acc = TurnDiffTracker::new(); - let mv = HashMap::from([( - src, - FileChange::Update { - unified_diff: "".into(), - move_path: Some(dest.clone()), - }, - )]); - acc.on_patch_begin(&mv); - // No file existed initially; create only dest - fs::write(&dest, "hello\n").unwrap(); - let diff = acc.get_unified_diff().unwrap().unwrap(); - let diff = normalize_diff_for_test(&diff, dir.path()); - let expected = { - let mode = file_mode_for_path(&dest).unwrap_or(FileMode::Regular); - let right_oid = git_blob_sha1_hex("hello\n"); - format!( - r#"diff --git a//src.txt b//dest.txt -new file mode {mode} -index {ZERO_OID}..{right_oid} ---- {DEV_NULL} -+++ b//dest.txt -@@ -0,0 +1 @@ -+hello -"#, - ) - }; - assert_eq!(diff, expected); - } - - #[test] - fn update_persists_across_new_baseline_for_new_file() { - let dir = tempdir().unwrap(); - let a = dir.path().join("a.txt"); - let b = dir.path().join("b.txt"); - fs::write(&a, "foo\n").unwrap(); - fs::write(&b, "z\n").unwrap(); - - let mut acc = TurnDiffTracker::new(); - - // First: update existing a.txt (baseline snapshot is created for a). - let update_a = HashMap::from([( - a.clone(), - FileChange::Update { - unified_diff: "".to_owned(), - move_path: None, - }, - )]); - acc.on_patch_begin(&update_a); - // Simulate apply: modify a.txt on disk. - fs::write(&a, "foo\nbar\n").unwrap(); - let first = acc.get_unified_diff().unwrap().unwrap(); - let first = normalize_diff_for_test(&first, dir.path()); - let expected_first = { - let left_oid = git_blob_sha1_hex("foo\n"); - let right_oid = git_blob_sha1_hex("foo\nbar\n"); - format!( - r#"diff --git a//a.txt b//a.txt -index {left_oid}..{right_oid} ---- a//a.txt -+++ b//a.txt -@@ -1 +1,2 @@ - foo -+bar -"# - ) - }; - assert_eq!(first, expected_first); - - // Next: introduce a brand-new path b.txt into baseline snapshots via a delete change. - let del_b = HashMap::from([( - b.clone(), - FileChange::Delete { - content: "z\n".to_string(), - }, - )]); - acc.on_patch_begin(&del_b); - // Simulate apply: delete b.txt. - let baseline_mode = file_mode_for_path(&b).unwrap_or(FileMode::Regular); - fs::remove_file(&b).unwrap(); - - let combined = acc.get_unified_diff().unwrap().unwrap(); - let combined = normalize_diff_for_test(&combined, dir.path()); - let expected = { - let left_oid_a = git_blob_sha1_hex("foo\n"); - let right_oid_a = git_blob_sha1_hex("foo\nbar\n"); - let left_oid_b = git_blob_sha1_hex("z\n"); - format!( - r#"diff --git a//a.txt b//a.txt -index {left_oid_a}..{right_oid_a} ---- a//a.txt -+++ b//a.txt -@@ -1 +1,2 @@ - foo -+bar -diff --git a//b.txt b//b.txt -deleted file mode {baseline_mode} -index {left_oid_b}..{ZERO_OID} ---- a//b.txt -+++ {DEV_NULL} -@@ -1 +0,0 @@ --z -"#, - ) - }; - assert_eq!(combined, expected); - } - - #[test] - fn binary_files_differ_update() { - let dir = tempdir().unwrap(); - let file = dir.path().join("bin.dat"); - - // Initial non-UTF8 bytes - let left_bytes: Vec = vec![0xff, 0xfe, 0xfd, 0x00]; - // Updated non-UTF8 bytes - let right_bytes: Vec = vec![0x01, 0x02, 0x03, 0x00]; - - fs::write(&file, &left_bytes).unwrap(); - - let mut acc = TurnDiffTracker::new(); - let update_changes = HashMap::from([( - file.clone(), - FileChange::Update { - unified_diff: "".to_owned(), - move_path: None, - }, - )]); - acc.on_patch_begin(&update_changes); - - // Apply update on disk - fs::write(&file, &right_bytes).unwrap(); - - let diff = acc.get_unified_diff().unwrap().unwrap(); - let diff = normalize_diff_for_test(&diff, dir.path()); - let expected = { - let left_oid = format!("{:x}", git_blob_sha1_hex_bytes(&left_bytes)); - let right_oid = format!("{:x}", git_blob_sha1_hex_bytes(&right_bytes)); - format!( - r#"diff --git a//bin.dat b//bin.dat -index {left_oid}..{right_oid} ---- a//bin.dat -+++ b//bin.dat -Binary files differ -"# - ) - }; - assert_eq!(diff, expected); - } - - #[test] - fn filenames_with_spaces_add_and_update() { - let mut acc = TurnDiffTracker::new(); - - let dir = tempdir().unwrap(); - let file = dir.path().join("name with spaces.txt"); - - // First patch: add file (baseline should be /dev/null). - let add_changes = HashMap::from([( - file.clone(), - FileChange::Add { - content: "foo\n".to_string(), - }, - )]); - acc.on_patch_begin(&add_changes); - - // Simulate apply: create the file on disk. - fs::write(&file, "foo\n").unwrap(); - let first = acc.get_unified_diff().unwrap().unwrap(); - let first = normalize_diff_for_test(&first, dir.path()); - let expected_first = { - let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular); - let right_oid = git_blob_sha1_hex("foo\n"); - format!( - r#"diff --git a//name with spaces.txt b//name with spaces.txt -new file mode {mode} -index {ZERO_OID}..{right_oid} ---- {DEV_NULL} -+++ b//name with spaces.txt -@@ -0,0 +1 @@ -+foo -"#, - ) - }; - assert_eq!(first, expected_first); - - // Second patch: update the file on disk. - let update_changes = HashMap::from([( - file.clone(), - FileChange::Update { - unified_diff: "".to_owned(), - move_path: None, - }, - )]); - acc.on_patch_begin(&update_changes); - - // Simulate apply: append a new line with a space. - fs::write(&file, "foo\nbar baz\n").unwrap(); - let combined = acc.get_unified_diff().unwrap().unwrap(); - let combined = normalize_diff_for_test(&combined, dir.path()); - let expected_combined = { - let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular); - let right_oid = git_blob_sha1_hex("foo\nbar baz\n"); - format!( - r#"diff --git a//name with spaces.txt b//name with spaces.txt -new file mode {mode} -index {ZERO_OID}..{right_oid} ---- {DEV_NULL} -+++ b//name with spaces.txt -@@ -0,0 +1,2 @@ -+foo -+bar baz -"#, - ) - }; - assert_eq!(combined, expected_combined); - } -} +#[path = "turn_diff_tracker_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/turn_diff_tracker_tests.rs b/codex-rs/core/src/turn_diff_tracker_tests.rs new file mode 100644 index 0000000000..e0ab2dd667 --- /dev/null +++ b/codex-rs/core/src/turn_diff_tracker_tests.rs @@ -0,0 +1,427 @@ +use super::*; +use pretty_assertions::assert_eq; +use tempfile::tempdir; + +/// Compute the Git SHA-1 blob object ID for the given content (string). +/// This delegates to the bytes version to avoid UTF-8 lossy conversions here. +fn git_blob_sha1_hex(data: &str) -> String { + format!("{:x}", git_blob_sha1_hex_bytes(data.as_bytes())) +} + +fn normalize_diff_for_test(input: &str, root: &Path) -> String { + let root_str = root.display().to_string().replace('\\', "/"); + let replaced = input.replace(&root_str, ""); + // Split into blocks on lines starting with "diff --git ", sort blocks for determinism, and rejoin + let mut blocks: Vec = Vec::new(); + let mut current = String::new(); + for line in replaced.lines() { + if line.starts_with("diff --git ") && !current.is_empty() { + blocks.push(current); + current = String::new(); + } + if !current.is_empty() { + current.push('\n'); + } + current.push_str(line); + } + if !current.is_empty() { + blocks.push(current); + } + blocks.sort(); + let mut out = blocks.join("\n"); + if !out.ends_with('\n') { + out.push('\n'); + } + out +} + +#[test] +fn accumulates_add_and_update() { + let mut acc = TurnDiffTracker::new(); + + let dir = tempdir().unwrap(); + let file = dir.path().join("a.txt"); + + // First patch: add file (baseline should be /dev/null). + let add_changes = HashMap::from([( + file.clone(), + FileChange::Add { + content: "foo\n".to_string(), + }, + )]); + acc.on_patch_begin(&add_changes); + + // Simulate apply: create the file on disk. + fs::write(&file, "foo\n").unwrap(); + let first = acc.get_unified_diff().unwrap().unwrap(); + let first = normalize_diff_for_test(&first, dir.path()); + let expected_first = { + let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular); + let right_oid = git_blob_sha1_hex("foo\n"); + format!( + r#"diff --git a//a.txt b//a.txt +new file mode {mode} +index {ZERO_OID}..{right_oid} +--- {DEV_NULL} ++++ b//a.txt +@@ -0,0 +1 @@ ++foo +"#, + ) + }; + assert_eq!(first, expected_first); + + // Second patch: update the file on disk. + let update_changes = HashMap::from([( + file.clone(), + FileChange::Update { + unified_diff: "".to_owned(), + move_path: None, + }, + )]); + acc.on_patch_begin(&update_changes); + + // Simulate apply: append a new line. + fs::write(&file, "foo\nbar\n").unwrap(); + let combined = acc.get_unified_diff().unwrap().unwrap(); + let combined = normalize_diff_for_test(&combined, dir.path()); + let expected_combined = { + let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular); + let right_oid = git_blob_sha1_hex("foo\nbar\n"); + format!( + r#"diff --git a//a.txt b//a.txt +new file mode {mode} +index {ZERO_OID}..{right_oid} +--- {DEV_NULL} ++++ b//a.txt +@@ -0,0 +1,2 @@ ++foo ++bar +"#, + ) + }; + assert_eq!(combined, expected_combined); +} + +#[test] +fn accumulates_delete() { + let dir = tempdir().unwrap(); + let file = dir.path().join("b.txt"); + fs::write(&file, "x\n").unwrap(); + + let mut acc = TurnDiffTracker::new(); + let del_changes = HashMap::from([( + file.clone(), + FileChange::Delete { + content: "x\n".to_string(), + }, + )]); + acc.on_patch_begin(&del_changes); + + // Simulate apply: delete the file from disk. + let baseline_mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular); + fs::remove_file(&file).unwrap(); + let diff = acc.get_unified_diff().unwrap().unwrap(); + let diff = normalize_diff_for_test(&diff, dir.path()); + let expected = { + let left_oid = git_blob_sha1_hex("x\n"); + format!( + r#"diff --git a//b.txt b//b.txt +deleted file mode {baseline_mode} +index {left_oid}..{ZERO_OID} +--- a//b.txt ++++ {DEV_NULL} +@@ -1 +0,0 @@ +-x +"#, + ) + }; + assert_eq!(diff, expected); +} + +#[test] +fn accumulates_move_and_update() { + let dir = tempdir().unwrap(); + let src = dir.path().join("src.txt"); + let dest = dir.path().join("dst.txt"); + fs::write(&src, "line\n").unwrap(); + + let mut acc = TurnDiffTracker::new(); + let mv_changes = HashMap::from([( + src.clone(), + FileChange::Update { + unified_diff: "".to_owned(), + move_path: Some(dest.clone()), + }, + )]); + acc.on_patch_begin(&mv_changes); + + // Simulate apply: move and update content. + fs::rename(&src, &dest).unwrap(); + fs::write(&dest, "line2\n").unwrap(); + + let out = acc.get_unified_diff().unwrap().unwrap(); + let out = normalize_diff_for_test(&out, dir.path()); + let expected = { + let left_oid = git_blob_sha1_hex("line\n"); + let right_oid = git_blob_sha1_hex("line2\n"); + format!( + r#"diff --git a//src.txt b//dst.txt +index {left_oid}..{right_oid} +--- a//src.txt ++++ b//dst.txt +@@ -1 +1 @@ +-line ++line2 +"# + ) + }; + assert_eq!(out, expected); +} + +#[test] +fn move_without_1change_yields_no_diff() { + let dir = tempdir().unwrap(); + let src = dir.path().join("moved.txt"); + let dest = dir.path().join("renamed.txt"); + fs::write(&src, "same\n").unwrap(); + + let mut acc = TurnDiffTracker::new(); + let mv_changes = HashMap::from([( + src.clone(), + FileChange::Update { + unified_diff: "".to_owned(), + move_path: Some(dest.clone()), + }, + )]); + acc.on_patch_begin(&mv_changes); + + // Simulate apply: move only, no content change. + fs::rename(&src, &dest).unwrap(); + + let diff = acc.get_unified_diff().unwrap(); + assert_eq!(diff, None); +} + +#[test] +fn move_declared_but_file_only_appears_at_dest_is_add() { + let dir = tempdir().unwrap(); + let src = dir.path().join("src.txt"); + let dest = dir.path().join("dest.txt"); + let mut acc = TurnDiffTracker::new(); + let mv = HashMap::from([( + src, + FileChange::Update { + unified_diff: "".into(), + move_path: Some(dest.clone()), + }, + )]); + acc.on_patch_begin(&mv); + // No file existed initially; create only dest + fs::write(&dest, "hello\n").unwrap(); + let diff = acc.get_unified_diff().unwrap().unwrap(); + let diff = normalize_diff_for_test(&diff, dir.path()); + let expected = { + let mode = file_mode_for_path(&dest).unwrap_or(FileMode::Regular); + let right_oid = git_blob_sha1_hex("hello\n"); + format!( + r#"diff --git a//src.txt b//dest.txt +new file mode {mode} +index {ZERO_OID}..{right_oid} +--- {DEV_NULL} ++++ b//dest.txt +@@ -0,0 +1 @@ ++hello +"#, + ) + }; + assert_eq!(diff, expected); +} + +#[test] +fn update_persists_across_new_baseline_for_new_file() { + let dir = tempdir().unwrap(); + let a = dir.path().join("a.txt"); + let b = dir.path().join("b.txt"); + fs::write(&a, "foo\n").unwrap(); + fs::write(&b, "z\n").unwrap(); + + let mut acc = TurnDiffTracker::new(); + + // First: update existing a.txt (baseline snapshot is created for a). + let update_a = HashMap::from([( + a.clone(), + FileChange::Update { + unified_diff: "".to_owned(), + move_path: None, + }, + )]); + acc.on_patch_begin(&update_a); + // Simulate apply: modify a.txt on disk. + fs::write(&a, "foo\nbar\n").unwrap(); + let first = acc.get_unified_diff().unwrap().unwrap(); + let first = normalize_diff_for_test(&first, dir.path()); + let expected_first = { + let left_oid = git_blob_sha1_hex("foo\n"); + let right_oid = git_blob_sha1_hex("foo\nbar\n"); + format!( + r#"diff --git a//a.txt b//a.txt +index {left_oid}..{right_oid} +--- a//a.txt ++++ b//a.txt +@@ -1 +1,2 @@ + foo ++bar +"# + ) + }; + assert_eq!(first, expected_first); + + // Next: introduce a brand-new path b.txt into baseline snapshots via a delete change. + let del_b = HashMap::from([( + b.clone(), + FileChange::Delete { + content: "z\n".to_string(), + }, + )]); + acc.on_patch_begin(&del_b); + // Simulate apply: delete b.txt. + let baseline_mode = file_mode_for_path(&b).unwrap_or(FileMode::Regular); + fs::remove_file(&b).unwrap(); + + let combined = acc.get_unified_diff().unwrap().unwrap(); + let combined = normalize_diff_for_test(&combined, dir.path()); + let expected = { + let left_oid_a = git_blob_sha1_hex("foo\n"); + let right_oid_a = git_blob_sha1_hex("foo\nbar\n"); + let left_oid_b = git_blob_sha1_hex("z\n"); + format!( + r#"diff --git a//a.txt b//a.txt +index {left_oid_a}..{right_oid_a} +--- a//a.txt ++++ b//a.txt +@@ -1 +1,2 @@ + foo ++bar +diff --git a//b.txt b//b.txt +deleted file mode {baseline_mode} +index {left_oid_b}..{ZERO_OID} +--- a//b.txt ++++ {DEV_NULL} +@@ -1 +0,0 @@ +-z +"#, + ) + }; + assert_eq!(combined, expected); +} + +#[test] +fn binary_files_differ_update() { + let dir = tempdir().unwrap(); + let file = dir.path().join("bin.dat"); + + // Initial non-UTF8 bytes + let left_bytes: Vec = vec![0xff, 0xfe, 0xfd, 0x00]; + // Updated non-UTF8 bytes + let right_bytes: Vec = vec![0x01, 0x02, 0x03, 0x00]; + + fs::write(&file, &left_bytes).unwrap(); + + let mut acc = TurnDiffTracker::new(); + let update_changes = HashMap::from([( + file.clone(), + FileChange::Update { + unified_diff: "".to_owned(), + move_path: None, + }, + )]); + acc.on_patch_begin(&update_changes); + + // Apply update on disk + fs::write(&file, &right_bytes).unwrap(); + + let diff = acc.get_unified_diff().unwrap().unwrap(); + let diff = normalize_diff_for_test(&diff, dir.path()); + let expected = { + let left_oid = format!("{:x}", git_blob_sha1_hex_bytes(&left_bytes)); + let right_oid = format!("{:x}", git_blob_sha1_hex_bytes(&right_bytes)); + format!( + r#"diff --git a//bin.dat b//bin.dat +index {left_oid}..{right_oid} +--- a//bin.dat ++++ b//bin.dat +Binary files differ +"# + ) + }; + assert_eq!(diff, expected); +} + +#[test] +fn filenames_with_spaces_add_and_update() { + let mut acc = TurnDiffTracker::new(); + + let dir = tempdir().unwrap(); + let file = dir.path().join("name with spaces.txt"); + + // First patch: add file (baseline should be /dev/null). + let add_changes = HashMap::from([( + file.clone(), + FileChange::Add { + content: "foo\n".to_string(), + }, + )]); + acc.on_patch_begin(&add_changes); + + // Simulate apply: create the file on disk. + fs::write(&file, "foo\n").unwrap(); + let first = acc.get_unified_diff().unwrap().unwrap(); + let first = normalize_diff_for_test(&first, dir.path()); + let expected_first = { + let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular); + let right_oid = git_blob_sha1_hex("foo\n"); + format!( + r#"diff --git a//name with spaces.txt b//name with spaces.txt +new file mode {mode} +index {ZERO_OID}..{right_oid} +--- {DEV_NULL} ++++ b//name with spaces.txt +@@ -0,0 +1 @@ ++foo +"#, + ) + }; + assert_eq!(first, expected_first); + + // Second patch: update the file on disk. + let update_changes = HashMap::from([( + file.clone(), + FileChange::Update { + unified_diff: "".to_owned(), + move_path: None, + }, + )]); + acc.on_patch_begin(&update_changes); + + // Simulate apply: append a new line with a space. + fs::write(&file, "foo\nbar baz\n").unwrap(); + let combined = acc.get_unified_diff().unwrap().unwrap(); + let combined = normalize_diff_for_test(&combined, dir.path()); + let expected_combined = { + let mode = file_mode_for_path(&file).unwrap_or(FileMode::Regular); + let right_oid = git_blob_sha1_hex("foo\nbar baz\n"); + format!( + r#"diff --git a//name with spaces.txt b//name with spaces.txt +new file mode {mode} +index {ZERO_OID}..{right_oid} +--- {DEV_NULL} ++++ b//name with spaces.txt +@@ -0,0 +1,2 @@ ++foo ++bar baz +"#, + ) + }; + assert_eq!(combined, expected_combined); +} diff --git a/codex-rs/core/src/turn_metadata.rs b/codex-rs/core/src/turn_metadata.rs index cb09e093b8..44d36b3a14 100644 --- a/codex-rs/core/src/turn_metadata.rs +++ b/codex-rs/core/src/turn_metadata.rs @@ -228,87 +228,5 @@ impl TurnMetadataState { } #[cfg(test)] -mod tests { - use super::*; - - use serde_json::Value; - use tempfile::TempDir; - use tokio::process::Command; - - #[tokio::test] - async fn build_turn_metadata_header_includes_has_changes_for_clean_repo() { - let temp_dir = TempDir::new().expect("temp dir"); - let repo_path = temp_dir.path().join("repo"); - std::fs::create_dir_all(&repo_path).expect("create repo"); - - Command::new("git") - .args(["init"]) - .current_dir(&repo_path) - .output() - .await - .expect("git init"); - Command::new("git") - .args(["config", "user.name", "Test User"]) - .current_dir(&repo_path) - .output() - .await - .expect("git config user.name"); - Command::new("git") - .args(["config", "user.email", "test@example.com"]) - .current_dir(&repo_path) - .output() - .await - .expect("git config user.email"); - - std::fs::write(repo_path.join("README.md"), "hello").expect("write file"); - Command::new("git") - .args(["add", "."]) - .current_dir(&repo_path) - .output() - .await - .expect("git add"); - Command::new("git") - .args(["commit", "-m", "initial"]) - .current_dir(&repo_path) - .output() - .await - .expect("git commit"); - - let header = build_turn_metadata_header(&repo_path, Some("none")) - .await - .expect("header"); - let parsed: Value = serde_json::from_str(&header).expect("valid json"); - let workspace = parsed - .get("workspaces") - .and_then(Value::as_object) - .and_then(|workspaces| workspaces.values().next()) - .cloned() - .expect("workspace"); - - assert_eq!( - workspace.get("has_changes").and_then(Value::as_bool), - Some(false) - ); - } - - #[test] - fn turn_metadata_state_uses_platform_sandbox_tag() { - let temp_dir = TempDir::new().expect("temp dir"); - let cwd = temp_dir.path().to_path_buf(); - let sandbox_policy = SandboxPolicy::new_read_only_policy(); - - let state = TurnMetadataState::new( - "turn-a".to_string(), - cwd, - &sandbox_policy, - WindowsSandboxLevel::Disabled, - ); - - let header = state.current_header_value().expect("header"); - let json: Value = serde_json::from_str(&header).expect("json"); - let sandbox_name = json.get("sandbox").and_then(Value::as_str); - - let expected_sandbox = sandbox_tag(&sandbox_policy, WindowsSandboxLevel::Disabled); - assert_eq!(sandbox_name, Some(expected_sandbox)); - } -} +#[path = "turn_metadata_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/turn_metadata_tests.rs b/codex-rs/core/src/turn_metadata_tests.rs new file mode 100644 index 0000000000..5124213de3 --- /dev/null +++ b/codex-rs/core/src/turn_metadata_tests.rs @@ -0,0 +1,82 @@ +use super::*; + +use serde_json::Value; +use tempfile::TempDir; +use tokio::process::Command; + +#[tokio::test] +async fn build_turn_metadata_header_includes_has_changes_for_clean_repo() { + let temp_dir = TempDir::new().expect("temp dir"); + let repo_path = temp_dir.path().join("repo"); + std::fs::create_dir_all(&repo_path).expect("create repo"); + + Command::new("git") + .args(["init"]) + .current_dir(&repo_path) + .output() + .await + .expect("git init"); + Command::new("git") + .args(["config", "user.name", "Test User"]) + .current_dir(&repo_path) + .output() + .await + .expect("git config user.name"); + Command::new("git") + .args(["config", "user.email", "test@example.com"]) + .current_dir(&repo_path) + .output() + .await + .expect("git config user.email"); + + std::fs::write(repo_path.join("README.md"), "hello").expect("write file"); + Command::new("git") + .args(["add", "."]) + .current_dir(&repo_path) + .output() + .await + .expect("git add"); + Command::new("git") + .args(["commit", "-m", "initial"]) + .current_dir(&repo_path) + .output() + .await + .expect("git commit"); + + let header = build_turn_metadata_header(&repo_path, Some("none")) + .await + .expect("header"); + let parsed: Value = serde_json::from_str(&header).expect("valid json"); + let workspace = parsed + .get("workspaces") + .and_then(Value::as_object) + .and_then(|workspaces| workspaces.values().next()) + .cloned() + .expect("workspace"); + + assert_eq!( + workspace.get("has_changes").and_then(Value::as_bool), + Some(false) + ); +} + +#[test] +fn turn_metadata_state_uses_platform_sandbox_tag() { + let temp_dir = TempDir::new().expect("temp dir"); + let cwd = temp_dir.path().to_path_buf(); + let sandbox_policy = SandboxPolicy::new_read_only_policy(); + + let state = TurnMetadataState::new( + "turn-a".to_string(), + cwd, + &sandbox_policy, + WindowsSandboxLevel::Disabled, + ); + + let header = state.current_header_value().expect("header"); + let json: Value = serde_json::from_str(&header).expect("json"); + let sandbox_name = json.get("sandbox").and_then(Value::as_str); + + let expected_sandbox = sandbox_tag(&sandbox_policy, WindowsSandboxLevel::Disabled); + assert_eq!(sandbox_name, Some(expected_sandbox)); +} diff --git a/codex-rs/core/src/turn_timing.rs b/codex-rs/core/src/turn_timing.rs index f197242f33..c68f16e451 100644 --- a/codex-rs/core/src/turn_timing.rs +++ b/codex-rs/core/src/turn_timing.rs @@ -154,132 +154,5 @@ fn response_item_records_turn_ttft(item: &ResponseItem) -> bool { } #[cfg(test)] -mod tests { - use codex_protocol::items::AgentMessageItem; - use codex_protocol::items::TurnItem; - use codex_protocol::models::ContentItem; - use codex_protocol::models::FunctionCallOutputPayload; - use codex_protocol::models::ResponseItem; - use pretty_assertions::assert_eq; - use std::time::Instant; - - use super::TurnTimingState; - use super::response_item_records_turn_ttft; - use crate::ResponseEvent; - - #[tokio::test] - async fn turn_timing_state_records_ttft_only_once_per_turn() { - let state = TurnTimingState::default(); - assert_eq!( - state - .record_ttft_for_response_event(&ResponseEvent::OutputTextDelta("hi".to_string())) - .await, - None - ); - - state.mark_turn_started(Instant::now()).await; - assert_eq!( - state - .record_ttft_for_response_event(&ResponseEvent::Created) - .await, - None - ); - assert!( - state - .record_ttft_for_response_event(&ResponseEvent::OutputTextDelta("hi".to_string())) - .await - .is_some() - ); - assert_eq!( - state - .record_ttft_for_response_event(&ResponseEvent::OutputTextDelta( - "again".to_string() - )) - .await, - None - ); - } - - #[tokio::test] - async fn turn_timing_state_records_ttfm_independently_of_ttft() { - let state = TurnTimingState::default(); - state.mark_turn_started(Instant::now()).await; - - assert!( - state - .record_ttft_for_response_event(&ResponseEvent::OutputTextDelta("hi".to_string())) - .await - .is_some() - ); - assert!( - state - .record_ttfm_for_turn_item(&TurnItem::AgentMessage(AgentMessageItem { - id: "msg-1".to_string(), - content: Vec::new(), - phase: None, - })) - .await - .is_some() - ); - assert_eq!( - state - .record_ttfm_for_turn_item(&TurnItem::AgentMessage(AgentMessageItem { - id: "msg-2".to_string(), - content: Vec::new(), - phase: None, - })) - .await, - None - ); - } - - #[test] - fn response_item_records_turn_ttft_for_first_output_signals() { - assert!(response_item_records_turn_ttft( - &ResponseItem::FunctionCall { - id: None, - name: "shell".to_string(), - namespace: None, - arguments: "{}".to_string(), - call_id: "call-1".to_string(), - } - )); - assert!(response_item_records_turn_ttft( - &ResponseItem::CustomToolCall { - id: None, - status: None, - call_id: "call-2".to_string(), - name: "custom".to_string(), - input: "echo hi".to_string(), - } - )); - assert!(response_item_records_turn_ttft(&ResponseItem::Message { - id: None, - role: "assistant".to_string(), - content: vec![ContentItem::OutputText { - text: "hello".to_string(), - }], - end_turn: None, - phase: None, - })); - } - - #[test] - fn response_item_records_turn_ttft_ignores_empty_non_output_items() { - assert!(!response_item_records_turn_ttft(&ResponseItem::Message { - id: None, - role: "assistant".to_string(), - content: vec![ContentItem::OutputText { - text: String::new(), - }], - end_turn: None, - phase: None, - })); - assert!(!response_item_records_turn_ttft( - &ResponseItem::FunctionCallOutput { - call_id: "call-1".to_string(), - output: FunctionCallOutputPayload::from_text("ok".to_string()), - } - )); - } -} +#[path = "turn_timing_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/turn_timing_tests.rs b/codex-rs/core/src/turn_timing_tests.rs new file mode 100644 index 0000000000..4f292b40dc --- /dev/null +++ b/codex-rs/core/src/turn_timing_tests.rs @@ -0,0 +1,125 @@ +use codex_protocol::items::AgentMessageItem; +use codex_protocol::items::TurnItem; +use codex_protocol::models::ContentItem; +use codex_protocol::models::FunctionCallOutputPayload; +use codex_protocol::models::ResponseItem; +use pretty_assertions::assert_eq; +use std::time::Instant; + +use super::TurnTimingState; +use super::response_item_records_turn_ttft; +use crate::ResponseEvent; + +#[tokio::test] +async fn turn_timing_state_records_ttft_only_once_per_turn() { + let state = TurnTimingState::default(); + assert_eq!( + state + .record_ttft_for_response_event(&ResponseEvent::OutputTextDelta("hi".to_string())) + .await, + None + ); + + state.mark_turn_started(Instant::now()).await; + assert_eq!( + state + .record_ttft_for_response_event(&ResponseEvent::Created) + .await, + None + ); + assert!( + state + .record_ttft_for_response_event(&ResponseEvent::OutputTextDelta("hi".to_string())) + .await + .is_some() + ); + assert_eq!( + state + .record_ttft_for_response_event(&ResponseEvent::OutputTextDelta("again".to_string())) + .await, + None + ); +} + +#[tokio::test] +async fn turn_timing_state_records_ttfm_independently_of_ttft() { + let state = TurnTimingState::default(); + state.mark_turn_started(Instant::now()).await; + + assert!( + state + .record_ttft_for_response_event(&ResponseEvent::OutputTextDelta("hi".to_string())) + .await + .is_some() + ); + assert!( + state + .record_ttfm_for_turn_item(&TurnItem::AgentMessage(AgentMessageItem { + id: "msg-1".to_string(), + content: Vec::new(), + phase: None, + })) + .await + .is_some() + ); + assert_eq!( + state + .record_ttfm_for_turn_item(&TurnItem::AgentMessage(AgentMessageItem { + id: "msg-2".to_string(), + content: Vec::new(), + phase: None, + })) + .await, + None + ); +} + +#[test] +fn response_item_records_turn_ttft_for_first_output_signals() { + assert!(response_item_records_turn_ttft( + &ResponseItem::FunctionCall { + id: None, + name: "shell".to_string(), + namespace: None, + arguments: "{}".to_string(), + call_id: "call-1".to_string(), + } + )); + assert!(response_item_records_turn_ttft( + &ResponseItem::CustomToolCall { + id: None, + status: None, + call_id: "call-2".to_string(), + name: "custom".to_string(), + input: "echo hi".to_string(), + } + )); + assert!(response_item_records_turn_ttft(&ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: "hello".to_string(), + }], + end_turn: None, + phase: None, + })); +} + +#[test] +fn response_item_records_turn_ttft_ignores_empty_non_output_items() { + assert!(!response_item_records_turn_ttft(&ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: String::new(), + }], + end_turn: None, + phase: None, + })); + assert!(!response_item_records_turn_ttft( + &ResponseItem::FunctionCallOutput { + call_id: "call-1".to_string(), + output: FunctionCallOutputPayload::from_text("ok".to_string()), + } + )); +} diff --git a/codex-rs/core/src/unified_exec/async_watcher.rs b/codex-rs/core/src/unified_exec/async_watcher.rs index 47543a00fc..d595342544 100644 --- a/codex-rs/core/src/unified_exec/async_watcher.rs +++ b/codex-rs/core/src/unified_exec/async_watcher.rs @@ -251,40 +251,5 @@ async fn resolve_aggregated_output( } #[cfg(test)] -mod tests { - use super::split_valid_utf8_prefix_with_max; - - use pretty_assertions::assert_eq; - - #[test] - fn split_valid_utf8_prefix_respects_max_bytes_for_ascii() { - let mut buf = b"hello word!".to_vec(); - - let first = split_valid_utf8_prefix_with_max(&mut buf, 5).expect("expected prefix"); - assert_eq!(first, b"hello".to_vec()); - assert_eq!(buf, b" word!".to_vec()); - - let second = split_valid_utf8_prefix_with_max(&mut buf, 5).expect("expected prefix"); - assert_eq!(second, b" word".to_vec()); - assert_eq!(buf, b"!".to_vec()); - } - - #[test] - fn split_valid_utf8_prefix_avoids_splitting_utf8_codepoints() { - // "é" is 2 bytes in UTF-8. With a max of 3 bytes, we should only emit 1 char (2 bytes). - let mut buf = "ééé".as_bytes().to_vec(); - - let first = split_valid_utf8_prefix_with_max(&mut buf, 3).expect("expected prefix"); - assert_eq!(std::str::from_utf8(&first).unwrap(), "é"); - assert_eq!(buf, "éé".as_bytes().to_vec()); - } - - #[test] - fn split_valid_utf8_prefix_makes_progress_on_invalid_utf8() { - let mut buf = vec![0xff, b'a', b'b']; - - let first = split_valid_utf8_prefix_with_max(&mut buf, 2).expect("expected prefix"); - assert_eq!(first, vec![0xff]); - assert_eq!(buf, b"ab".to_vec()); - } -} +#[path = "async_watcher_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/unified_exec/async_watcher_tests.rs b/codex-rs/core/src/unified_exec/async_watcher_tests.rs new file mode 100644 index 0000000000..bdf8f7534b --- /dev/null +++ b/codex-rs/core/src/unified_exec/async_watcher_tests.rs @@ -0,0 +1,35 @@ +use super::split_valid_utf8_prefix_with_max; + +use pretty_assertions::assert_eq; + +#[test] +fn split_valid_utf8_prefix_respects_max_bytes_for_ascii() { + let mut buf = b"hello word!".to_vec(); + + let first = split_valid_utf8_prefix_with_max(&mut buf, 5).expect("expected prefix"); + assert_eq!(first, b"hello".to_vec()); + assert_eq!(buf, b" word!".to_vec()); + + let second = split_valid_utf8_prefix_with_max(&mut buf, 5).expect("expected prefix"); + assert_eq!(second, b" word".to_vec()); + assert_eq!(buf, b"!".to_vec()); +} + +#[test] +fn split_valid_utf8_prefix_avoids_splitting_utf8_codepoints() { + // "é" is 2 bytes in UTF-8. With a max of 3 bytes, we should only emit 1 char (2 bytes). + let mut buf = "ééé".as_bytes().to_vec(); + + let first = split_valid_utf8_prefix_with_max(&mut buf, 3).expect("expected prefix"); + assert_eq!(std::str::from_utf8(&first).unwrap(), "é"); + assert_eq!(buf, "éé".as_bytes().to_vec()); +} + +#[test] +fn split_valid_utf8_prefix_makes_progress_on_invalid_utf8() { + let mut buf = vec![0xff, b'a', b'b']; + + let first = split_valid_utf8_prefix_with_max(&mut buf, 2).expect("expected prefix"); + assert_eq!(first, vec![0xff]); + assert_eq!(buf, b"ab".to_vec()); +} diff --git a/codex-rs/core/src/unified_exec/head_tail_buffer.rs b/codex-rs/core/src/unified_exec/head_tail_buffer.rs index 8524466048..52039e1495 100644 --- a/codex-rs/core/src/unified_exec/head_tail_buffer.rs +++ b/codex-rs/core/src/unified_exec/head_tail_buffer.rs @@ -179,94 +179,5 @@ impl HeadTailBuffer { } #[cfg(test)] -mod tests { - use super::HeadTailBuffer; - - use pretty_assertions::assert_eq; - - #[test] - fn keeps_prefix_and_suffix_when_over_budget() { - let mut buf = HeadTailBuffer::new(10); - - buf.push_chunk(b"0123456789".to_vec()); - assert_eq!(buf.omitted_bytes(), 0); - - // Exceeds max by 2; we should keep head+tail and omit the middle. - buf.push_chunk(b"ab".to_vec()); - assert!(buf.omitted_bytes() > 0); - - let rendered = String::from_utf8_lossy(&buf.to_bytes()).to_string(); - assert!(rendered.starts_with("01234")); - assert!(rendered.ends_with("89ab")); - } - - #[test] - fn max_bytes_zero_drops_everything() { - let mut buf = HeadTailBuffer::new(0); - buf.push_chunk(b"abc".to_vec()); - - assert_eq!(buf.retained_bytes(), 0); - assert_eq!(buf.omitted_bytes(), 3); - assert_eq!(buf.to_bytes(), b"".to_vec()); - assert_eq!(buf.snapshot_chunks(), Vec::>::new()); - } - - #[test] - fn head_budget_zero_keeps_only_last_byte_in_tail() { - let mut buf = HeadTailBuffer::new(1); - buf.push_chunk(b"abc".to_vec()); - - assert_eq!(buf.retained_bytes(), 1); - assert_eq!(buf.omitted_bytes(), 2); - assert_eq!(buf.to_bytes(), b"c".to_vec()); - } - - #[test] - fn draining_resets_state() { - let mut buf = HeadTailBuffer::new(10); - buf.push_chunk(b"0123456789".to_vec()); - buf.push_chunk(b"ab".to_vec()); - - let drained = buf.drain_chunks(); - assert!(!drained.is_empty()); - - assert_eq!(buf.retained_bytes(), 0); - assert_eq!(buf.omitted_bytes(), 0); - assert_eq!(buf.to_bytes(), b"".to_vec()); - } - - #[test] - fn chunk_larger_than_tail_budget_keeps_only_tail_end() { - let mut buf = HeadTailBuffer::new(10); - buf.push_chunk(b"0123456789".to_vec()); - - // Tail budget is 5 bytes. This chunk should replace the tail and keep only its last 5 bytes. - buf.push_chunk(b"ABCDEFGHIJK".to_vec()); - - let out = String::from_utf8_lossy(&buf.to_bytes()).to_string(); - assert!(out.starts_with("01234")); - assert!(out.ends_with("GHIJK")); - assert!(buf.omitted_bytes() > 0); - } - - #[test] - fn fills_head_then_tail_across_multiple_chunks() { - let mut buf = HeadTailBuffer::new(10); - - // Fill the 5-byte head budget across multiple chunks. - buf.push_chunk(b"01".to_vec()); - buf.push_chunk(b"234".to_vec()); - assert_eq!(buf.to_bytes(), b"01234".to_vec()); - - // Then fill the 5-byte tail budget. - buf.push_chunk(b"567".to_vec()); - buf.push_chunk(b"89".to_vec()); - assert_eq!(buf.to_bytes(), b"0123456789".to_vec()); - assert_eq!(buf.omitted_bytes(), 0); - - // One more byte causes the tail to drop its oldest byte. - buf.push_chunk(b"a".to_vec()); - assert_eq!(buf.to_bytes(), b"012346789a".to_vec()); - assert_eq!(buf.omitted_bytes(), 1); - } -} +#[path = "head_tail_buffer_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/unified_exec/head_tail_buffer_tests.rs b/codex-rs/core/src/unified_exec/head_tail_buffer_tests.rs new file mode 100644 index 0000000000..55493a6b84 --- /dev/null +++ b/codex-rs/core/src/unified_exec/head_tail_buffer_tests.rs @@ -0,0 +1,89 @@ +use super::HeadTailBuffer; + +use pretty_assertions::assert_eq; + +#[test] +fn keeps_prefix_and_suffix_when_over_budget() { + let mut buf = HeadTailBuffer::new(10); + + buf.push_chunk(b"0123456789".to_vec()); + assert_eq!(buf.omitted_bytes(), 0); + + // Exceeds max by 2; we should keep head+tail and omit the middle. + buf.push_chunk(b"ab".to_vec()); + assert!(buf.omitted_bytes() > 0); + + let rendered = String::from_utf8_lossy(&buf.to_bytes()).to_string(); + assert!(rendered.starts_with("01234")); + assert!(rendered.ends_with("89ab")); +} + +#[test] +fn max_bytes_zero_drops_everything() { + let mut buf = HeadTailBuffer::new(0); + buf.push_chunk(b"abc".to_vec()); + + assert_eq!(buf.retained_bytes(), 0); + assert_eq!(buf.omitted_bytes(), 3); + assert_eq!(buf.to_bytes(), b"".to_vec()); + assert_eq!(buf.snapshot_chunks(), Vec::>::new()); +} + +#[test] +fn head_budget_zero_keeps_only_last_byte_in_tail() { + let mut buf = HeadTailBuffer::new(1); + buf.push_chunk(b"abc".to_vec()); + + assert_eq!(buf.retained_bytes(), 1); + assert_eq!(buf.omitted_bytes(), 2); + assert_eq!(buf.to_bytes(), b"c".to_vec()); +} + +#[test] +fn draining_resets_state() { + let mut buf = HeadTailBuffer::new(10); + buf.push_chunk(b"0123456789".to_vec()); + buf.push_chunk(b"ab".to_vec()); + + let drained = buf.drain_chunks(); + assert!(!drained.is_empty()); + + assert_eq!(buf.retained_bytes(), 0); + assert_eq!(buf.omitted_bytes(), 0); + assert_eq!(buf.to_bytes(), b"".to_vec()); +} + +#[test] +fn chunk_larger_than_tail_budget_keeps_only_tail_end() { + let mut buf = HeadTailBuffer::new(10); + buf.push_chunk(b"0123456789".to_vec()); + + // Tail budget is 5 bytes. This chunk should replace the tail and keep only its last 5 bytes. + buf.push_chunk(b"ABCDEFGHIJK".to_vec()); + + let out = String::from_utf8_lossy(&buf.to_bytes()).to_string(); + assert!(out.starts_with("01234")); + assert!(out.ends_with("GHIJK")); + assert!(buf.omitted_bytes() > 0); +} + +#[test] +fn fills_head_then_tail_across_multiple_chunks() { + let mut buf = HeadTailBuffer::new(10); + + // Fill the 5-byte head budget across multiple chunks. + buf.push_chunk(b"01".to_vec()); + buf.push_chunk(b"234".to_vec()); + assert_eq!(buf.to_bytes(), b"01234".to_vec()); + + // Then fill the 5-byte tail budget. + buf.push_chunk(b"567".to_vec()); + buf.push_chunk(b"89".to_vec()); + assert_eq!(buf.to_bytes(), b"0123456789".to_vec()); + assert_eq!(buf.omitted_bytes(), 0); + + // One more byte causes the tail to drop its oldest byte. + buf.push_chunk(b"a".to_vec()); + assert_eq!(buf.to_bytes(), b"012346789a".to_vec()); + assert_eq!(buf.omitted_bytes(), 1); +} diff --git a/codex-rs/core/src/unified_exec/mod.rs b/codex-rs/core/src/unified_exec/mod.rs index 91af47accd..3e69a71eea 100644 --- a/codex-rs/core/src/unified_exec/mod.rs +++ b/codex-rs/core/src/unified_exec/mod.rs @@ -169,350 +169,5 @@ pub(crate) fn generate_chunk_id() -> String { #[cfg(test)] #[cfg(unix)] -mod tests { - use super::head_tail_buffer::HeadTailBuffer; - use super::*; - use crate::codex::Session; - use crate::codex::TurnContext; - use crate::codex::make_session_and_context; - use crate::protocol::AskForApproval; - use crate::protocol::SandboxPolicy; - use crate::tools::context::ExecCommandToolOutput; - use crate::unified_exec::ExecCommandRequest; - use crate::unified_exec::WriteStdinRequest; - use core_test_support::skip_if_sandbox; - use std::sync::Arc; - use tokio::time::Duration; - - async fn test_session_and_turn() -> (Arc, Arc) { - let (session, mut turn) = make_session_and_context().await; - turn.approval_policy - .set(AskForApproval::Never) - .expect("test setup should allow updating approval policy"); - turn.sandbox_policy - .set(SandboxPolicy::DangerFullAccess) - .expect("test setup should allow updating sandbox policy"); - turn.file_system_sandbox_policy = - codex_protocol::permissions::FileSystemSandboxPolicy::from(turn.sandbox_policy.get()); - turn.network_sandbox_policy = - codex_protocol::permissions::NetworkSandboxPolicy::from(turn.sandbox_policy.get()); - (Arc::new(session), Arc::new(turn)) - } - - async fn exec_command( - session: &Arc, - turn: &Arc, - cmd: &str, - yield_time_ms: u64, - ) -> Result { - let context = - UnifiedExecContext::new(Arc::clone(session), Arc::clone(turn), "call".to_string()); - let process_id = session - .services - .unified_exec_manager - .allocate_process_id() - .await; - - session - .services - .unified_exec_manager - .exec_command( - ExecCommandRequest { - command: vec!["bash".to_string(), "-lc".to_string(), cmd.to_string()], - process_id, - yield_time_ms, - max_output_tokens: None, - workdir: None, - network: None, - tty: true, - sandbox_permissions: SandboxPermissions::UseDefault, - additional_permissions: None, - additional_permissions_preapproved: false, - justification: None, - prefix_rule: None, - }, - &context, - ) - .await - } - - async fn write_stdin( - session: &Arc, - process_id: i32, - input: &str, - yield_time_ms: u64, - ) -> Result { - session - .services - .unified_exec_manager - .write_stdin(WriteStdinRequest { - process_id, - input, - yield_time_ms, - max_output_tokens: None, - }) - .await - } - - #[test] - fn push_chunk_preserves_prefix_and_suffix() { - let mut buffer = HeadTailBuffer::default(); - buffer.push_chunk(vec![b'a'; UNIFIED_EXEC_OUTPUT_MAX_BYTES]); - buffer.push_chunk(vec![b'b']); - buffer.push_chunk(vec![b'c']); - - assert_eq!(buffer.retained_bytes(), UNIFIED_EXEC_OUTPUT_MAX_BYTES); - let snapshot = buffer.snapshot_chunks(); - - let first = snapshot.first().expect("expected at least one chunk"); - assert_eq!(first.first(), Some(&b'a')); - assert!(snapshot.iter().any(|chunk| chunk.as_slice() == b"b")); - assert_eq!( - snapshot - .last() - .expect("expected at least one chunk") - .as_slice(), - b"c" - ); - } - - #[test] - fn head_tail_buffer_default_preserves_prefix_and_suffix() { - let mut buffer = HeadTailBuffer::default(); - buffer.push_chunk(vec![b'a'; UNIFIED_EXEC_OUTPUT_MAX_BYTES]); - buffer.push_chunk(b"bc".to_vec()); - - let rendered = buffer.to_bytes(); - assert_eq!(rendered.first(), Some(&b'a')); - assert!(rendered.ends_with(b"bc")); - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn unified_exec_persists_across_requests() -> anyhow::Result<()> { - skip_if_sandbox!(Ok(())); - - let (session, turn) = test_session_and_turn().await; - - let open_shell = exec_command(&session, &turn, "bash -i", 2_500).await?; - let process_id = open_shell.process_id.expect("expected process_id"); - - write_stdin( - &session, - process_id, - "export CODEX_INTERACTIVE_SHELL_VAR=codex\n", - 2_500, - ) - .await?; - - let out_2 = write_stdin( - &session, - process_id, - "echo $CODEX_INTERACTIVE_SHELL_VAR\n", - 2_500, - ) - .await?; - assert!( - out_2.truncated_output().contains("codex"), - "expected environment variable output" - ); - - Ok(()) - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn multi_unified_exec_sessions() -> anyhow::Result<()> { - skip_if_sandbox!(Ok(())); - - let (session, turn) = test_session_and_turn().await; - - let shell_a = exec_command(&session, &turn, "bash -i", 2_500).await?; - let session_a = shell_a.process_id.expect("expected process id"); - - write_stdin( - &session, - session_a, - "export CODEX_INTERACTIVE_SHELL_VAR=codex\n", - 2_500, - ) - .await?; - - let out_2 = - exec_command(&session, &turn, "echo $CODEX_INTERACTIVE_SHELL_VAR", 2_500).await?; - tokio::time::sleep(Duration::from_secs(2)).await; - assert!( - out_2.process_id.is_none(), - "short command should not report a process id if it exits quickly" - ); - assert!( - !out_2.truncated_output().contains("codex"), - "short command should run in a fresh shell" - ); - - let out_3 = write_stdin( - &session, - shell_a.process_id.expect("expected process id"), - "echo $CODEX_INTERACTIVE_SHELL_VAR\n", - 2_500, - ) - .await?; - assert!( - out_3.truncated_output().contains("codex"), - "session should preserve state" - ); - - Ok(()) - } - - #[tokio::test] - async fn unified_exec_timeouts() -> anyhow::Result<()> { - skip_if_sandbox!(Ok(())); - - const TEST_VAR_VALUE: &str = "unified_exec_var_123"; - - let (session, turn) = test_session_and_turn().await; - - let open_shell = exec_command(&session, &turn, "bash -i", 2_500).await?; - let process_id = open_shell.process_id.expect("expected process id"); - - write_stdin( - &session, - process_id, - format!("export CODEX_INTERACTIVE_SHELL_VAR={TEST_VAR_VALUE}\n").as_str(), - 2_500, - ) - .await?; - - let out_2 = write_stdin( - &session, - process_id, - "sleep 5 && echo $CODEX_INTERACTIVE_SHELL_VAR\n", - 10, - ) - .await?; - assert!( - !out_2.truncated_output().contains(TEST_VAR_VALUE), - "timeout too short should yield incomplete output" - ); - - tokio::time::sleep(Duration::from_secs(7)).await; - - let out_3 = write_stdin(&session, process_id, "", 100).await?; - - assert!( - out_3.truncated_output().contains(TEST_VAR_VALUE), - "subsequent poll should retrieve output" - ); - - Ok(()) - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn unified_exec_pause_blocks_yield_timeout() -> anyhow::Result<()> { - skip_if_sandbox!(Ok(())); - - let (session, turn) = test_session_and_turn().await; - session.set_out_of_band_elicitation_pause_state(true); - - let paused_session = Arc::clone(&session); - tokio::spawn(async move { - tokio::time::sleep(Duration::from_secs(2)).await; - paused_session.set_out_of_band_elicitation_pause_state(false); - }); - - let started = tokio::time::Instant::now(); - let response = - exec_command(&session, &turn, "sleep 1 && echo unified-exec-done", 250).await?; - - assert!( - started.elapsed() >= Duration::from_secs(2), - "pause should block the unified exec yield timeout" - ); - assert!( - response.truncated_output().contains("unified-exec-done"), - "exec_command should wait for output after the pause lifts" - ); - assert!( - response.process_id.is_none(), - "completed command should not leave a background process" - ); - - Ok(()) - } - - #[tokio::test] - #[ignore] // Ignored while we have a better way to test this. - async fn requests_with_large_timeout_are_capped() -> anyhow::Result<()> { - let (session, turn) = test_session_and_turn().await; - - let result = exec_command(&session, &turn, "echo codex", 120_000).await?; - - assert!(result.process_id.is_some()); - assert!(result.truncated_output().contains("codex")); - - Ok(()) - } - - #[tokio::test] - #[ignore] // Ignored while we have a better way to test this. - async fn completed_commands_do_not_persist_sessions() -> anyhow::Result<()> { - let (session, turn) = test_session_and_turn().await; - let result = exec_command(&session, &turn, "echo codex", 2_500).await?; - - assert!( - result.process_id.is_some(), - "completed command should report a process id" - ); - assert!(result.truncated_output().contains("codex")); - - assert!( - session - .services - .unified_exec_manager - .process_store - .lock() - .await - .processes - .is_empty() - ); - - Ok(()) - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn reusing_completed_process_returns_unknown_process() -> anyhow::Result<()> { - skip_if_sandbox!(Ok(())); - - let (session, turn) = test_session_and_turn().await; - - let open_shell = exec_command(&session, &turn, "bash -i", 2_500).await?; - let process_id = open_shell.process_id.expect("expected process id"); - - write_stdin(&session, process_id, "exit\n", 2_500).await?; - - tokio::time::sleep(Duration::from_millis(200)).await; - - let err = write_stdin(&session, process_id, "", 100) - .await - .expect_err("expected unknown process error"); - - match err { - UnifiedExecError::UnknownProcessId { process_id: err_id } => { - assert_eq!(err_id, process_id, "process id should match request"); - } - other => panic!("expected UnknownProcessId, got {other:?}"), - } - - assert!( - session - .services - .unified_exec_manager - .process_store - .lock() - .await - .processes - .is_empty() - ); - - Ok(()) - } -} +#[path = "mod_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/unified_exec/mod_tests.rs b/codex-rs/core/src/unified_exec/mod_tests.rs new file mode 100644 index 0000000000..c81d1329d5 --- /dev/null +++ b/codex-rs/core/src/unified_exec/mod_tests.rs @@ -0,0 +1,343 @@ +use super::head_tail_buffer::HeadTailBuffer; +use super::*; +use crate::codex::Session; +use crate::codex::TurnContext; +use crate::codex::make_session_and_context; +use crate::protocol::AskForApproval; +use crate::protocol::SandboxPolicy; +use crate::tools::context::ExecCommandToolOutput; +use crate::unified_exec::ExecCommandRequest; +use crate::unified_exec::WriteStdinRequest; +use core_test_support::skip_if_sandbox; +use std::sync::Arc; +use tokio::time::Duration; + +async fn test_session_and_turn() -> (Arc, Arc) { + let (session, mut turn) = make_session_and_context().await; + turn.approval_policy + .set(AskForApproval::Never) + .expect("test setup should allow updating approval policy"); + turn.sandbox_policy + .set(SandboxPolicy::DangerFullAccess) + .expect("test setup should allow updating sandbox policy"); + turn.file_system_sandbox_policy = + codex_protocol::permissions::FileSystemSandboxPolicy::from(turn.sandbox_policy.get()); + turn.network_sandbox_policy = + codex_protocol::permissions::NetworkSandboxPolicy::from(turn.sandbox_policy.get()); + (Arc::new(session), Arc::new(turn)) +} + +async fn exec_command( + session: &Arc, + turn: &Arc, + cmd: &str, + yield_time_ms: u64, +) -> Result { + let context = + UnifiedExecContext::new(Arc::clone(session), Arc::clone(turn), "call".to_string()); + let process_id = session + .services + .unified_exec_manager + .allocate_process_id() + .await; + + session + .services + .unified_exec_manager + .exec_command( + ExecCommandRequest { + command: vec!["bash".to_string(), "-lc".to_string(), cmd.to_string()], + process_id, + yield_time_ms, + max_output_tokens: None, + workdir: None, + network: None, + tty: true, + sandbox_permissions: SandboxPermissions::UseDefault, + additional_permissions: None, + additional_permissions_preapproved: false, + justification: None, + prefix_rule: None, + }, + &context, + ) + .await +} + +async fn write_stdin( + session: &Arc, + process_id: i32, + input: &str, + yield_time_ms: u64, +) -> Result { + session + .services + .unified_exec_manager + .write_stdin(WriteStdinRequest { + process_id, + input, + yield_time_ms, + max_output_tokens: None, + }) + .await +} + +#[test] +fn push_chunk_preserves_prefix_and_suffix() { + let mut buffer = HeadTailBuffer::default(); + buffer.push_chunk(vec![b'a'; UNIFIED_EXEC_OUTPUT_MAX_BYTES]); + buffer.push_chunk(vec![b'b']); + buffer.push_chunk(vec![b'c']); + + assert_eq!(buffer.retained_bytes(), UNIFIED_EXEC_OUTPUT_MAX_BYTES); + let snapshot = buffer.snapshot_chunks(); + + let first = snapshot.first().expect("expected at least one chunk"); + assert_eq!(first.first(), Some(&b'a')); + assert!(snapshot.iter().any(|chunk| chunk.as_slice() == b"b")); + assert_eq!( + snapshot + .last() + .expect("expected at least one chunk") + .as_slice(), + b"c" + ); +} + +#[test] +fn head_tail_buffer_default_preserves_prefix_and_suffix() { + let mut buffer = HeadTailBuffer::default(); + buffer.push_chunk(vec![b'a'; UNIFIED_EXEC_OUTPUT_MAX_BYTES]); + buffer.push_chunk(b"bc".to_vec()); + + let rendered = buffer.to_bytes(); + assert_eq!(rendered.first(), Some(&b'a')); + assert!(rendered.ends_with(b"bc")); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn unified_exec_persists_across_requests() -> anyhow::Result<()> { + skip_if_sandbox!(Ok(())); + + let (session, turn) = test_session_and_turn().await; + + let open_shell = exec_command(&session, &turn, "bash -i", 2_500).await?; + let process_id = open_shell.process_id.expect("expected process_id"); + + write_stdin( + &session, + process_id, + "export CODEX_INTERACTIVE_SHELL_VAR=codex\n", + 2_500, + ) + .await?; + + let out_2 = write_stdin( + &session, + process_id, + "echo $CODEX_INTERACTIVE_SHELL_VAR\n", + 2_500, + ) + .await?; + assert!( + out_2.truncated_output().contains("codex"), + "expected environment variable output" + ); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn multi_unified_exec_sessions() -> anyhow::Result<()> { + skip_if_sandbox!(Ok(())); + + let (session, turn) = test_session_and_turn().await; + + let shell_a = exec_command(&session, &turn, "bash -i", 2_500).await?; + let session_a = shell_a.process_id.expect("expected process id"); + + write_stdin( + &session, + session_a, + "export CODEX_INTERACTIVE_SHELL_VAR=codex\n", + 2_500, + ) + .await?; + + let out_2 = exec_command(&session, &turn, "echo $CODEX_INTERACTIVE_SHELL_VAR", 2_500).await?; + tokio::time::sleep(Duration::from_secs(2)).await; + assert!( + out_2.process_id.is_none(), + "short command should not report a process id if it exits quickly" + ); + assert!( + !out_2.truncated_output().contains("codex"), + "short command should run in a fresh shell" + ); + + let out_3 = write_stdin( + &session, + shell_a.process_id.expect("expected process id"), + "echo $CODEX_INTERACTIVE_SHELL_VAR\n", + 2_500, + ) + .await?; + assert!( + out_3.truncated_output().contains("codex"), + "session should preserve state" + ); + + Ok(()) +} + +#[tokio::test] +async fn unified_exec_timeouts() -> anyhow::Result<()> { + skip_if_sandbox!(Ok(())); + + const TEST_VAR_VALUE: &str = "unified_exec_var_123"; + + let (session, turn) = test_session_and_turn().await; + + let open_shell = exec_command(&session, &turn, "bash -i", 2_500).await?; + let process_id = open_shell.process_id.expect("expected process id"); + + write_stdin( + &session, + process_id, + format!("export CODEX_INTERACTIVE_SHELL_VAR={TEST_VAR_VALUE}\n").as_str(), + 2_500, + ) + .await?; + + let out_2 = write_stdin( + &session, + process_id, + "sleep 5 && echo $CODEX_INTERACTIVE_SHELL_VAR\n", + 10, + ) + .await?; + assert!( + !out_2.truncated_output().contains(TEST_VAR_VALUE), + "timeout too short should yield incomplete output" + ); + + tokio::time::sleep(Duration::from_secs(7)).await; + + let out_3 = write_stdin(&session, process_id, "", 100).await?; + + assert!( + out_3.truncated_output().contains(TEST_VAR_VALUE), + "subsequent poll should retrieve output" + ); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn unified_exec_pause_blocks_yield_timeout() -> anyhow::Result<()> { + skip_if_sandbox!(Ok(())); + + let (session, turn) = test_session_and_turn().await; + session.set_out_of_band_elicitation_pause_state(true); + + let paused_session = Arc::clone(&session); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(2)).await; + paused_session.set_out_of_band_elicitation_pause_state(false); + }); + + let started = tokio::time::Instant::now(); + let response = exec_command(&session, &turn, "sleep 1 && echo unified-exec-done", 250).await?; + + assert!( + started.elapsed() >= Duration::from_secs(2), + "pause should block the unified exec yield timeout" + ); + assert!( + response.truncated_output().contains("unified-exec-done"), + "exec_command should wait for output after the pause lifts" + ); + assert!( + response.process_id.is_none(), + "completed command should not leave a background process" + ); + + Ok(()) +} + +#[tokio::test] +#[ignore] // Ignored while we have a better way to test this. +async fn requests_with_large_timeout_are_capped() -> anyhow::Result<()> { + let (session, turn) = test_session_and_turn().await; + + let result = exec_command(&session, &turn, "echo codex", 120_000).await?; + + assert!(result.process_id.is_some()); + assert!(result.truncated_output().contains("codex")); + + Ok(()) +} + +#[tokio::test] +#[ignore] // Ignored while we have a better way to test this. +async fn completed_commands_do_not_persist_sessions() -> anyhow::Result<()> { + let (session, turn) = test_session_and_turn().await; + let result = exec_command(&session, &turn, "echo codex", 2_500).await?; + + assert!( + result.process_id.is_some(), + "completed command should report a process id" + ); + assert!(result.truncated_output().contains("codex")); + + assert!( + session + .services + .unified_exec_manager + .process_store + .lock() + .await + .processes + .is_empty() + ); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn reusing_completed_process_returns_unknown_process() -> anyhow::Result<()> { + skip_if_sandbox!(Ok(())); + + let (session, turn) = test_session_and_turn().await; + + let open_shell = exec_command(&session, &turn, "bash -i", 2_500).await?; + let process_id = open_shell.process_id.expect("expected process id"); + + write_stdin(&session, process_id, "exit\n", 2_500).await?; + + tokio::time::sleep(Duration::from_millis(200)).await; + + let err = write_stdin(&session, process_id, "", 100) + .await + .expect_err("expected unknown process error"); + + match err { + UnifiedExecError::UnknownProcessId { process_id: err_id } => { + assert_eq!(err_id, process_id, "process id should match request"); + } + other => panic!("expected UnknownProcessId, got {other:?}"), + } + + assert!( + session + .services + .unified_exec_manager + .process_store + .lock() + .await + .processes + .is_empty() + ); + + Ok(()) +} diff --git a/codex-rs/core/src/unified_exec/process_manager.rs b/codex-rs/core/src/unified_exec/process_manager.rs index dcf2dd090f..9eb269c07d 100644 --- a/codex-rs/core/src/unified_exec/process_manager.rs +++ b/codex-rs/core/src/unified_exec/process_manager.rs @@ -832,104 +832,5 @@ enum ProcessStatus { } #[cfg(test)] -mod tests { - use super::*; - use pretty_assertions::assert_eq; - use tokio::time::Duration; - use tokio::time::Instant; - - #[test] - fn unified_exec_env_injects_defaults() { - let env = apply_unified_exec_env(HashMap::new()); - let expected = HashMap::from([ - ("NO_COLOR".to_string(), "1".to_string()), - ("TERM".to_string(), "dumb".to_string()), - ("LANG".to_string(), "C.UTF-8".to_string()), - ("LC_CTYPE".to_string(), "C.UTF-8".to_string()), - ("LC_ALL".to_string(), "C.UTF-8".to_string()), - ("COLORTERM".to_string(), String::new()), - ("PAGER".to_string(), "cat".to_string()), - ("GIT_PAGER".to_string(), "cat".to_string()), - ("GH_PAGER".to_string(), "cat".to_string()), - ("CODEX_CI".to_string(), "1".to_string()), - ]); - - assert_eq!(env, expected); - } - - #[test] - fn unified_exec_env_overrides_existing_values() { - let mut base = HashMap::new(); - base.insert("NO_COLOR".to_string(), "0".to_string()); - base.insert("PATH".to_string(), "/usr/bin".to_string()); - - let env = apply_unified_exec_env(base); - - assert_eq!(env.get("NO_COLOR"), Some(&"1".to_string())); - assert_eq!(env.get("PATH"), Some(&"/usr/bin".to_string())); - } - - #[test] - fn pruning_prefers_exited_processes_outside_recently_used() { - let now = Instant::now(); - let meta = vec![ - (1, now - Duration::from_secs(40), false), - (2, now - Duration::from_secs(30), true), - (3, now - Duration::from_secs(20), false), - (4, now - Duration::from_secs(19), false), - (5, now - Duration::from_secs(18), false), - (6, now - Duration::from_secs(17), false), - (7, now - Duration::from_secs(16), false), - (8, now - Duration::from_secs(15), false), - (9, now - Duration::from_secs(14), false), - (10, now - Duration::from_secs(13), false), - ]; - - let candidate = UnifiedExecProcessManager::process_id_to_prune_from_meta(&meta); - - assert_eq!(candidate, Some(2)); - } - - #[test] - fn pruning_falls_back_to_lru_when_no_exited() { - let now = Instant::now(); - let meta = vec![ - (1, now - Duration::from_secs(40), false), - (2, now - Duration::from_secs(30), false), - (3, now - Duration::from_secs(20), false), - (4, now - Duration::from_secs(19), false), - (5, now - Duration::from_secs(18), false), - (6, now - Duration::from_secs(17), false), - (7, now - Duration::from_secs(16), false), - (8, now - Duration::from_secs(15), false), - (9, now - Duration::from_secs(14), false), - (10, now - Duration::from_secs(13), false), - ]; - - let candidate = UnifiedExecProcessManager::process_id_to_prune_from_meta(&meta); - - assert_eq!(candidate, Some(1)); - } - - #[test] - fn pruning_protects_recent_processes_even_if_exited() { - let now = Instant::now(); - let meta = vec![ - (1, now - Duration::from_secs(40), false), - (2, now - Duration::from_secs(30), false), - (3, now - Duration::from_secs(20), true), - (4, now - Duration::from_secs(19), false), - (5, now - Duration::from_secs(18), false), - (6, now - Duration::from_secs(17), false), - (7, now - Duration::from_secs(16), false), - (8, now - Duration::from_secs(15), false), - (9, now - Duration::from_secs(14), false), - (10, now - Duration::from_secs(13), true), - ]; - - let candidate = UnifiedExecProcessManager::process_id_to_prune_from_meta(&meta); - - // (10) is exited but among the last 8; we should drop the LRU outside that set. - assert_eq!(candidate, Some(1)); - } -} +#[path = "process_manager_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/unified_exec/process_manager_tests.rs b/codex-rs/core/src/unified_exec/process_manager_tests.rs new file mode 100644 index 0000000000..b145dadb09 --- /dev/null +++ b/codex-rs/core/src/unified_exec/process_manager_tests.rs @@ -0,0 +1,99 @@ +use super::*; +use pretty_assertions::assert_eq; +use tokio::time::Duration; +use tokio::time::Instant; + +#[test] +fn unified_exec_env_injects_defaults() { + let env = apply_unified_exec_env(HashMap::new()); + let expected = HashMap::from([ + ("NO_COLOR".to_string(), "1".to_string()), + ("TERM".to_string(), "dumb".to_string()), + ("LANG".to_string(), "C.UTF-8".to_string()), + ("LC_CTYPE".to_string(), "C.UTF-8".to_string()), + ("LC_ALL".to_string(), "C.UTF-8".to_string()), + ("COLORTERM".to_string(), String::new()), + ("PAGER".to_string(), "cat".to_string()), + ("GIT_PAGER".to_string(), "cat".to_string()), + ("GH_PAGER".to_string(), "cat".to_string()), + ("CODEX_CI".to_string(), "1".to_string()), + ]); + + assert_eq!(env, expected); +} + +#[test] +fn unified_exec_env_overrides_existing_values() { + let mut base = HashMap::new(); + base.insert("NO_COLOR".to_string(), "0".to_string()); + base.insert("PATH".to_string(), "/usr/bin".to_string()); + + let env = apply_unified_exec_env(base); + + assert_eq!(env.get("NO_COLOR"), Some(&"1".to_string())); + assert_eq!(env.get("PATH"), Some(&"/usr/bin".to_string())); +} + +#[test] +fn pruning_prefers_exited_processes_outside_recently_used() { + let now = Instant::now(); + let meta = vec![ + (1, now - Duration::from_secs(40), false), + (2, now - Duration::from_secs(30), true), + (3, now - Duration::from_secs(20), false), + (4, now - Duration::from_secs(19), false), + (5, now - Duration::from_secs(18), false), + (6, now - Duration::from_secs(17), false), + (7, now - Duration::from_secs(16), false), + (8, now - Duration::from_secs(15), false), + (9, now - Duration::from_secs(14), false), + (10, now - Duration::from_secs(13), false), + ]; + + let candidate = UnifiedExecProcessManager::process_id_to_prune_from_meta(&meta); + + assert_eq!(candidate, Some(2)); +} + +#[test] +fn pruning_falls_back_to_lru_when_no_exited() { + let now = Instant::now(); + let meta = vec![ + (1, now - Duration::from_secs(40), false), + (2, now - Duration::from_secs(30), false), + (3, now - Duration::from_secs(20), false), + (4, now - Duration::from_secs(19), false), + (5, now - Duration::from_secs(18), false), + (6, now - Duration::from_secs(17), false), + (7, now - Duration::from_secs(16), false), + (8, now - Duration::from_secs(15), false), + (9, now - Duration::from_secs(14), false), + (10, now - Duration::from_secs(13), false), + ]; + + let candidate = UnifiedExecProcessManager::process_id_to_prune_from_meta(&meta); + + assert_eq!(candidate, Some(1)); +} + +#[test] +fn pruning_protects_recent_processes_even_if_exited() { + let now = Instant::now(); + let meta = vec![ + (1, now - Duration::from_secs(40), false), + (2, now - Duration::from_secs(30), false), + (3, now - Duration::from_secs(20), true), + (4, now - Duration::from_secs(19), false), + (5, now - Duration::from_secs(18), false), + (6, now - Duration::from_secs(17), false), + (7, now - Duration::from_secs(16), false), + (8, now - Duration::from_secs(15), false), + (9, now - Duration::from_secs(14), false), + (10, now - Duration::from_secs(13), true), + ]; + + let candidate = UnifiedExecProcessManager::process_id_to_prune_from_meta(&meta); + + // (10) is exited but among the last 8; we should drop the LRU outside that set. + assert_eq!(candidate, Some(1)); +} diff --git a/codex-rs/core/src/user_shell_command.rs b/codex-rs/core/src/user_shell_command.rs index e7921c69f3..32cf78cf2a 100644 --- a/codex-rs/core/src/user_shell_command.rs +++ b/codex-rs/core/src/user_shell_command.rs @@ -55,61 +55,5 @@ pub fn user_shell_command_record_item( } #[cfg(test)] -mod tests { - use super::*; - use crate::codex::make_session_and_context; - use crate::exec::StreamOutput; - use codex_protocol::models::ContentItem; - use pretty_assertions::assert_eq; - - #[test] - fn detects_user_shell_command_text_variants() { - assert!( - USER_SHELL_COMMAND_FRAGMENT - .matches_text("\necho hi\n") - ); - assert!(!USER_SHELL_COMMAND_FRAGMENT.matches_text("echo hi")); - } - - #[tokio::test] - async fn formats_basic_record() { - let exec_output = ExecToolCallOutput { - exit_code: 0, - stdout: StreamOutput::new("hi".to_string()), - stderr: StreamOutput::new(String::new()), - aggregated_output: StreamOutput::new("hi".to_string()), - duration: Duration::from_secs(1), - timed_out: false, - }; - let (_, turn_context) = make_session_and_context().await; - let item = user_shell_command_record_item("echo hi", &exec_output, &turn_context); - let ResponseItem::Message { content, .. } = item else { - panic!("expected message"); - }; - let [ContentItem::InputText { text }] = content.as_slice() else { - panic!("expected input text"); - }; - assert_eq!( - text, - "\n\necho hi\n\n\nExit code: 0\nDuration: 1.0000 seconds\nOutput:\nhi\n\n" - ); - } - - #[tokio::test] - async fn uses_aggregated_output_over_streams() { - let exec_output = ExecToolCallOutput { - exit_code: 42, - stdout: StreamOutput::new("stdout-only".to_string()), - stderr: StreamOutput::new("stderr-only".to_string()), - aggregated_output: StreamOutput::new("combined output wins".to_string()), - duration: Duration::from_millis(120), - timed_out: false, - }; - let (_, turn_context) = make_session_and_context().await; - let record = format_user_shell_command_record("false", &exec_output, &turn_context); - assert_eq!( - record, - "\n\nfalse\n\n\nExit code: 42\nDuration: 0.1200 seconds\nOutput:\ncombined output wins\n\n" - ); - } -} +#[path = "user_shell_command_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/user_shell_command_tests.rs b/codex-rs/core/src/user_shell_command_tests.rs new file mode 100644 index 0000000000..a034f404e5 --- /dev/null +++ b/codex-rs/core/src/user_shell_command_tests.rs @@ -0,0 +1,56 @@ +use super::*; +use crate::codex::make_session_and_context; +use crate::exec::StreamOutput; +use codex_protocol::models::ContentItem; +use pretty_assertions::assert_eq; + +#[test] +fn detects_user_shell_command_text_variants() { + assert!( + USER_SHELL_COMMAND_FRAGMENT + .matches_text("\necho hi\n") + ); + assert!(!USER_SHELL_COMMAND_FRAGMENT.matches_text("echo hi")); +} + +#[tokio::test] +async fn formats_basic_record() { + let exec_output = ExecToolCallOutput { + exit_code: 0, + stdout: StreamOutput::new("hi".to_string()), + stderr: StreamOutput::new(String::new()), + aggregated_output: StreamOutput::new("hi".to_string()), + duration: Duration::from_secs(1), + timed_out: false, + }; + let (_, turn_context) = make_session_and_context().await; + let item = user_shell_command_record_item("echo hi", &exec_output, &turn_context); + let ResponseItem::Message { content, .. } = item else { + panic!("expected message"); + }; + let [ContentItem::InputText { text }] = content.as_slice() else { + panic!("expected input text"); + }; + assert_eq!( + text, + "\n\necho hi\n\n\nExit code: 0\nDuration: 1.0000 seconds\nOutput:\nhi\n\n" + ); +} + +#[tokio::test] +async fn uses_aggregated_output_over_streams() { + let exec_output = ExecToolCallOutput { + exit_code: 42, + stdout: StreamOutput::new("stdout-only".to_string()), + stderr: StreamOutput::new("stderr-only".to_string()), + aggregated_output: StreamOutput::new("combined output wins".to_string()), + duration: Duration::from_millis(120), + timed_out: false, + }; + let (_, turn_context) = make_session_and_context().await; + let record = format_user_shell_command_record("false", &exec_output, &turn_context); + assert_eq!( + record, + "\n\nfalse\n\n\nExit code: 42\nDuration: 0.1200 seconds\nOutput:\ncombined output wins\n\n" + ); +} diff --git a/codex-rs/core/src/util.rs b/codex-rs/core/src/util.rs index 59fecb0a9e..62e872ae6c 100644 --- a/codex-rs/core/src/util.rs +++ b/codex-rs/core/src/util.rs @@ -102,85 +102,5 @@ pub fn resume_command(thread_name: Option<&str>, thread_id: Option) -> } #[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_try_parse_error_message() { - let text = r#"{ - "error": { - "message": "Your refresh token has already been used to generate a new access token. Please try signing in again.", - "type": "invalid_request_error", - "param": null, - "code": "refresh_token_reused" - } -}"#; - let message = try_parse_error_message(text); - assert_eq!( - message, - "Your refresh token has already been used to generate a new access token. Please try signing in again." - ); - } - - #[test] - fn test_try_parse_error_message_no_error() { - let text = r#"{"message": "test"}"#; - let message = try_parse_error_message(text); - assert_eq!(message, r#"{"message": "test"}"#); - } - - #[test] - fn feedback_tags_macro_compiles() { - #[derive(Debug)] - struct OnlyDebug; - - feedback_tags!(model = "gpt-5", cached = true, debug_only = OnlyDebug); - } - - #[test] - fn normalize_thread_name_trims_and_rejects_empty() { - assert_eq!(normalize_thread_name(" "), None); - assert_eq!( - normalize_thread_name(" my thread "), - Some("my thread".to_string()) - ); - } - - #[test] - fn resume_command_prefers_name_over_id() { - let thread_id = ThreadId::from_string("123e4567-e89b-12d3-a456-426614174000").unwrap(); - let command = resume_command(Some("my-thread"), Some(thread_id)); - assert_eq!(command, Some("codex resume my-thread".to_string())); - } - - #[test] - fn resume_command_with_only_id() { - let thread_id = ThreadId::from_string("123e4567-e89b-12d3-a456-426614174000").unwrap(); - let command = resume_command(None, Some(thread_id)); - assert_eq!( - command, - Some("codex resume 123e4567-e89b-12d3-a456-426614174000".to_string()) - ); - } - - #[test] - fn resume_command_with_no_name_or_id() { - let command = resume_command(None, None); - assert_eq!(command, None); - } - - #[test] - fn resume_command_quotes_thread_name_when_needed() { - let command = resume_command(Some("-starts-with-dash"), None); - assert_eq!( - command, - Some("codex resume -- -starts-with-dash".to_string()) - ); - - let command = resume_command(Some("two words"), None); - assert_eq!(command, Some("codex resume 'two words'".to_string())); - - let command = resume_command(Some("quote'case"), None); - assert_eq!(command, Some("codex resume \"quote'case\"".to_string())); - } -} +#[path = "util_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/util_tests.rs b/codex-rs/core/src/util_tests.rs new file mode 100644 index 0000000000..dd5956bf61 --- /dev/null +++ b/codex-rs/core/src/util_tests.rs @@ -0,0 +1,80 @@ +use super::*; + +#[test] +fn test_try_parse_error_message() { + let text = r#"{ + "error": { + "message": "Your refresh token has already been used to generate a new access token. Please try signing in again.", + "type": "invalid_request_error", + "param": null, + "code": "refresh_token_reused" + } +}"#; + let message = try_parse_error_message(text); + assert_eq!( + message, + "Your refresh token has already been used to generate a new access token. Please try signing in again." + ); +} + +#[test] +fn test_try_parse_error_message_no_error() { + let text = r#"{"message": "test"}"#; + let message = try_parse_error_message(text); + assert_eq!(message, r#"{"message": "test"}"#); +} + +#[test] +fn feedback_tags_macro_compiles() { + #[derive(Debug)] + struct OnlyDebug; + + feedback_tags!(model = "gpt-5", cached = true, debug_only = OnlyDebug); +} + +#[test] +fn normalize_thread_name_trims_and_rejects_empty() { + assert_eq!(normalize_thread_name(" "), None); + assert_eq!( + normalize_thread_name(" my thread "), + Some("my thread".to_string()) + ); +} + +#[test] +fn resume_command_prefers_name_over_id() { + let thread_id = ThreadId::from_string("123e4567-e89b-12d3-a456-426614174000").unwrap(); + let command = resume_command(Some("my-thread"), Some(thread_id)); + assert_eq!(command, Some("codex resume my-thread".to_string())); +} + +#[test] +fn resume_command_with_only_id() { + let thread_id = ThreadId::from_string("123e4567-e89b-12d3-a456-426614174000").unwrap(); + let command = resume_command(None, Some(thread_id)); + assert_eq!( + command, + Some("codex resume 123e4567-e89b-12d3-a456-426614174000".to_string()) + ); +} + +#[test] +fn resume_command_with_no_name_or_id() { + let command = resume_command(None, None); + assert_eq!(command, None); +} + +#[test] +fn resume_command_quotes_thread_name_when_needed() { + let command = resume_command(Some("-starts-with-dash"), None); + assert_eq!( + command, + Some("codex resume -- -starts-with-dash".to_string()) + ); + + let command = resume_command(Some("two words"), None); + assert_eq!(command, Some("codex resume 'two words'".to_string())); + + let command = resume_command(Some("quote'case"), None); + assert_eq!(command, Some("codex resume \"quote'case\"".to_string())); +} diff --git a/codex-rs/core/src/windows_sandbox.rs b/codex-rs/core/src/windows_sandbox.rs index 6e0067aa09..25932a9622 100644 --- a/codex-rs/core/src/windows_sandbox.rs +++ b/codex-rs/core/src/windows_sandbox.rs @@ -427,137 +427,5 @@ fn windows_sandbox_setup_mode_tag(mode: WindowsSandboxSetupMode) -> &'static str } #[cfg(test)] -mod tests { - use super::*; - use crate::config::types::WindowsToml; - use crate::features::Features; - use crate::features::FeaturesToml; - use pretty_assertions::assert_eq; - use std::collections::BTreeMap; - - #[test] - fn elevated_flag_works_by_itself() { - let mut features = Features::with_defaults(); - features.enable(Feature::WindowsSandboxElevated); - - assert_eq!( - WindowsSandboxLevel::from_features(&features), - WindowsSandboxLevel::Elevated - ); - } - - #[test] - fn restricted_token_flag_works_by_itself() { - let mut features = Features::with_defaults(); - features.enable(Feature::WindowsSandbox); - - assert_eq!( - WindowsSandboxLevel::from_features(&features), - WindowsSandboxLevel::RestrictedToken - ); - } - - #[test] - fn no_flags_means_no_sandbox() { - let features = Features::with_defaults(); - - assert_eq!( - WindowsSandboxLevel::from_features(&features), - WindowsSandboxLevel::Disabled - ); - } - - #[test] - fn elevated_wins_when_both_flags_are_enabled() { - let mut features = Features::with_defaults(); - features.enable(Feature::WindowsSandbox); - features.enable(Feature::WindowsSandboxElevated); - - assert_eq!( - WindowsSandboxLevel::from_features(&features), - WindowsSandboxLevel::Elevated - ); - } - - #[test] - fn legacy_mode_prefers_elevated() { - let mut entries = BTreeMap::new(); - entries.insert("experimental_windows_sandbox".to_string(), true); - entries.insert("elevated_windows_sandbox".to_string(), true); - - assert_eq!( - legacy_windows_sandbox_mode_from_entries(&entries), - Some(WindowsSandboxModeToml::Elevated) - ); - } - - #[test] - fn legacy_mode_supports_alias_key() { - let mut entries = BTreeMap::new(); - entries.insert("enable_experimental_windows_sandbox".to_string(), true); - - assert_eq!( - legacy_windows_sandbox_mode_from_entries(&entries), - Some(WindowsSandboxModeToml::Unelevated) - ); - } - - #[test] - fn resolve_windows_sandbox_mode_prefers_profile_windows() { - let cfg = ConfigToml { - windows: Some(WindowsToml { - sandbox: Some(WindowsSandboxModeToml::Unelevated), - }), - ..Default::default() - }; - let profile = ConfigProfile { - windows: Some(WindowsToml { - sandbox: Some(WindowsSandboxModeToml::Elevated), - }), - ..Default::default() - }; - - assert_eq!( - resolve_windows_sandbox_mode(&cfg, &profile), - Some(WindowsSandboxModeToml::Elevated) - ); - } - - #[test] - fn resolve_windows_sandbox_mode_falls_back_to_legacy_keys() { - let mut entries = BTreeMap::new(); - entries.insert("experimental_windows_sandbox".to_string(), true); - let cfg = ConfigToml { - features: Some(FeaturesToml { entries }), - ..Default::default() - }; - - assert_eq!( - resolve_windows_sandbox_mode(&cfg, &ConfigProfile::default()), - Some(WindowsSandboxModeToml::Unelevated) - ); - } - - #[test] - fn resolve_windows_sandbox_mode_profile_legacy_false_blocks_top_level_legacy_true() { - let mut profile_entries = BTreeMap::new(); - profile_entries.insert("experimental_windows_sandbox".to_string(), false); - let profile = ConfigProfile { - features: Some(FeaturesToml { - entries: profile_entries, - }), - ..Default::default() - }; - - let mut cfg_entries = BTreeMap::new(); - cfg_entries.insert("experimental_windows_sandbox".to_string(), true); - let cfg = ConfigToml { - features: Some(FeaturesToml { - entries: cfg_entries, - }), - ..Default::default() - }; - - assert_eq!(resolve_windows_sandbox_mode(&cfg, &profile), None); - } -} +#[path = "windows_sandbox_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/windows_sandbox_read_grants.rs b/codex-rs/core/src/windows_sandbox_read_grants.rs index 8fa843c8b4..ec08e55cfd 100644 --- a/codex-rs/core/src/windows_sandbox_read_grants.rs +++ b/codex-rs/core/src/windows_sandbox_read_grants.rs @@ -36,62 +36,5 @@ pub fn grant_read_root_non_elevated( } #[cfg(test)] -mod tests { - use super::grant_read_root_non_elevated; - use crate::protocol::SandboxPolicy; - use std::collections::HashMap; - use std::path::Path; - use tempfile::TempDir; - - fn policy() -> SandboxPolicy { - SandboxPolicy::new_workspace_write_policy() - } - - #[test] - fn rejects_relative_path() { - let tmp = TempDir::new().expect("tempdir"); - let err = grant_read_root_non_elevated( - &policy(), - tmp.path(), - tmp.path(), - &HashMap::new(), - tmp.path(), - Path::new("relative"), - ) - .expect_err("relative path should fail"); - assert!(err.to_string().contains("path must be absolute")); - } - - #[test] - fn rejects_missing_path() { - let tmp = TempDir::new().expect("tempdir"); - let missing = tmp.path().join("does-not-exist"); - let err = grant_read_root_non_elevated( - &policy(), - tmp.path(), - tmp.path(), - &HashMap::new(), - tmp.path(), - missing.as_path(), - ) - .expect_err("missing path should fail"); - assert!(err.to_string().contains("path does not exist")); - } - - #[test] - fn rejects_file_path() { - let tmp = TempDir::new().expect("tempdir"); - let file_path = tmp.path().join("file.txt"); - std::fs::write(&file_path, "hello").expect("write file"); - let err = grant_read_root_non_elevated( - &policy(), - tmp.path(), - tmp.path(), - &HashMap::new(), - tmp.path(), - file_path.as_path(), - ) - .expect_err("file path should fail"); - assert!(err.to_string().contains("path must be a directory")); - } -} +#[path = "windows_sandbox_read_grants_tests.rs"] +mod tests; diff --git a/codex-rs/core/src/windows_sandbox_read_grants_tests.rs b/codex-rs/core/src/windows_sandbox_read_grants_tests.rs new file mode 100644 index 0000000000..c23920264b --- /dev/null +++ b/codex-rs/core/src/windows_sandbox_read_grants_tests.rs @@ -0,0 +1,57 @@ +use super::grant_read_root_non_elevated; +use crate::protocol::SandboxPolicy; +use std::collections::HashMap; +use std::path::Path; +use tempfile::TempDir; + +fn policy() -> SandboxPolicy { + SandboxPolicy::new_workspace_write_policy() +} + +#[test] +fn rejects_relative_path() { + let tmp = TempDir::new().expect("tempdir"); + let err = grant_read_root_non_elevated( + &policy(), + tmp.path(), + tmp.path(), + &HashMap::new(), + tmp.path(), + Path::new("relative"), + ) + .expect_err("relative path should fail"); + assert!(err.to_string().contains("path must be absolute")); +} + +#[test] +fn rejects_missing_path() { + let tmp = TempDir::new().expect("tempdir"); + let missing = tmp.path().join("does-not-exist"); + let err = grant_read_root_non_elevated( + &policy(), + tmp.path(), + tmp.path(), + &HashMap::new(), + tmp.path(), + missing.as_path(), + ) + .expect_err("missing path should fail"); + assert!(err.to_string().contains("path does not exist")); +} + +#[test] +fn rejects_file_path() { + let tmp = TempDir::new().expect("tempdir"); + let file_path = tmp.path().join("file.txt"); + std::fs::write(&file_path, "hello").expect("write file"); + let err = grant_read_root_non_elevated( + &policy(), + tmp.path(), + tmp.path(), + &HashMap::new(), + tmp.path(), + file_path.as_path(), + ) + .expect_err("file path should fail"); + assert!(err.to_string().contains("path must be a directory")); +} diff --git a/codex-rs/core/src/windows_sandbox_tests.rs b/codex-rs/core/src/windows_sandbox_tests.rs new file mode 100644 index 0000000000..6bcd493ad4 --- /dev/null +++ b/codex-rs/core/src/windows_sandbox_tests.rs @@ -0,0 +1,132 @@ +use super::*; +use crate::config::types::WindowsToml; +use crate::features::Features; +use crate::features::FeaturesToml; +use pretty_assertions::assert_eq; +use std::collections::BTreeMap; + +#[test] +fn elevated_flag_works_by_itself() { + let mut features = Features::with_defaults(); + features.enable(Feature::WindowsSandboxElevated); + + assert_eq!( + WindowsSandboxLevel::from_features(&features), + WindowsSandboxLevel::Elevated + ); +} + +#[test] +fn restricted_token_flag_works_by_itself() { + let mut features = Features::with_defaults(); + features.enable(Feature::WindowsSandbox); + + assert_eq!( + WindowsSandboxLevel::from_features(&features), + WindowsSandboxLevel::RestrictedToken + ); +} + +#[test] +fn no_flags_means_no_sandbox() { + let features = Features::with_defaults(); + + assert_eq!( + WindowsSandboxLevel::from_features(&features), + WindowsSandboxLevel::Disabled + ); +} + +#[test] +fn elevated_wins_when_both_flags_are_enabled() { + let mut features = Features::with_defaults(); + features.enable(Feature::WindowsSandbox); + features.enable(Feature::WindowsSandboxElevated); + + assert_eq!( + WindowsSandboxLevel::from_features(&features), + WindowsSandboxLevel::Elevated + ); +} + +#[test] +fn legacy_mode_prefers_elevated() { + let mut entries = BTreeMap::new(); + entries.insert("experimental_windows_sandbox".to_string(), true); + entries.insert("elevated_windows_sandbox".to_string(), true); + + assert_eq!( + legacy_windows_sandbox_mode_from_entries(&entries), + Some(WindowsSandboxModeToml::Elevated) + ); +} + +#[test] +fn legacy_mode_supports_alias_key() { + let mut entries = BTreeMap::new(); + entries.insert("enable_experimental_windows_sandbox".to_string(), true); + + assert_eq!( + legacy_windows_sandbox_mode_from_entries(&entries), + Some(WindowsSandboxModeToml::Unelevated) + ); +} + +#[test] +fn resolve_windows_sandbox_mode_prefers_profile_windows() { + let cfg = ConfigToml { + windows: Some(WindowsToml { + sandbox: Some(WindowsSandboxModeToml::Unelevated), + }), + ..Default::default() + }; + let profile = ConfigProfile { + windows: Some(WindowsToml { + sandbox: Some(WindowsSandboxModeToml::Elevated), + }), + ..Default::default() + }; + + assert_eq!( + resolve_windows_sandbox_mode(&cfg, &profile), + Some(WindowsSandboxModeToml::Elevated) + ); +} + +#[test] +fn resolve_windows_sandbox_mode_falls_back_to_legacy_keys() { + let mut entries = BTreeMap::new(); + entries.insert("experimental_windows_sandbox".to_string(), true); + let cfg = ConfigToml { + features: Some(FeaturesToml { entries }), + ..Default::default() + }; + + assert_eq!( + resolve_windows_sandbox_mode(&cfg, &ConfigProfile::default()), + Some(WindowsSandboxModeToml::Unelevated) + ); +} + +#[test] +fn resolve_windows_sandbox_mode_profile_legacy_false_blocks_top_level_legacy_true() { + let mut profile_entries = BTreeMap::new(); + profile_entries.insert("experimental_windows_sandbox".to_string(), false); + let profile = ConfigProfile { + features: Some(FeaturesToml { + entries: profile_entries, + }), + ..Default::default() + }; + + let mut cfg_entries = BTreeMap::new(); + cfg_entries.insert("experimental_windows_sandbox".to_string(), true); + let cfg = ConfigToml { + features: Some(FeaturesToml { + entries: cfg_entries, + }), + ..Default::default() + }; + + assert_eq!(resolve_windows_sandbox_mode(&cfg, &profile), None); +}