diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index b5b65011e7..b1d8c25fb0 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -678,6 +678,7 @@ name = "codex-cli" version = "0.0.0" dependencies = [ "anyhow", + "assert_cmd", "clap", "clap_complete", "codex-arg0", @@ -691,7 +692,10 @@ dependencies = [ "codex-protocol-ts", "codex-tui", "serde_json", + "tempfile", "tokio", + "toml 0.9.5", + "toml_edit 0.23.4", "tracing", "tracing-subscriber", ] diff --git a/codex-rs/cli/Cargo.toml b/codex-rs/cli/Cargo.toml index f7af3349e0..bcd35c5a0e 100644 --- a/codex-rs/cli/Cargo.toml +++ b/codex-rs/cli/Cargo.toml @@ -28,6 +28,9 @@ codex-mcp-server = { path = "../mcp-server" } codex-protocol = { path = "../protocol" } codex-tui = { path = "../tui" } serde_json = "1" +toml = "0.9.5" +toml_edit = "0.23.4" +tempfile = "3" tokio = { version = "1", features = [ "io-std", "macros", @@ -38,3 +41,6 @@ tokio = { version = "1", features = [ tracing = "0.1.41" tracing-subscriber = "0.3.19" codex-protocol-ts = { path = "../protocol-ts" } + +[dev-dependencies] +assert_cmd = "2" diff --git a/codex-rs/cli/src/lib.rs b/codex-rs/cli/src/lib.rs index c6d80c0adf..8927be585d 100644 --- a/codex-rs/cli/src/lib.rs +++ b/codex-rs/cli/src/lib.rs @@ -1,6 +1,7 @@ pub mod debug_sandbox; mod exit_status; pub mod login; +pub mod mcp_cmd; pub mod proto; use clap::Parser; diff --git a/codex-rs/cli/src/main.rs b/codex-rs/cli/src/main.rs index 2acc3d84c5..54fbaf7479 100644 --- a/codex-rs/cli/src/main.rs +++ b/codex-rs/cli/src/main.rs @@ -11,6 +11,8 @@ use codex_cli::login::run_login_status; use codex_cli::login::run_login_with_api_key; use codex_cli::login::run_login_with_chatgpt; use codex_cli::login::run_logout; +use codex_cli::mcp_cmd; +use codex_cli::mcp_cmd::McpCli; use codex_cli::proto; use codex_common::CliConfigOverrides; use codex_exec::Cli as ExecCli; @@ -56,8 +58,8 @@ enum Subcommand { /// Remove stored authentication credentials. Logout(LogoutCommand), - /// Experimental: run Codex as an MCP server. - Mcp, + /// Experimental: run Codex as an MCP server and manage MCP config. + Mcp(McpCli), /// Run the Protocol stream via stdin/stdout #[clap(visible_alias = "p")] @@ -158,8 +160,9 @@ async fn cli_main(codex_linux_sandbox_exe: Option) -> anyhow::Result<() prepend_config_flags(&mut exec_cli.config_overrides, cli.config_overrides); codex_exec::run_main(exec_cli, codex_linux_sandbox_exe).await?; } - Some(Subcommand::Mcp) => { - codex_mcp_server::run_main(codex_linux_sandbox_exe, cli.config_overrides).await?; + Some(Subcommand::Mcp(mut mcp_cli)) => { + prepend_config_flags(&mut mcp_cli.config_overrides, cli.config_overrides); + mcp_cmd::run_main(mcp_cli, codex_linux_sandbox_exe).await?; } Some(Subcommand::Login(mut login_cli)) => { prepend_config_flags(&mut login_cli.config_overrides, cli.config_overrides); diff --git a/codex-rs/cli/src/mcp_cmd.rs b/codex-rs/cli/src/mcp_cmd.rs new file mode 100644 index 0000000000..578814ef47 --- /dev/null +++ b/codex-rs/cli/src/mcp_cmd.rs @@ -0,0 +1,641 @@ +use std::collections::BTreeSet; +use std::collections::HashMap; +use std::path::PathBuf; + +use anyhow::Result; +use clap::Parser; +use codex_common::CliConfigOverrides; +use codex_core::config::find_codex_home; +use codex_core::config::load_config_as_toml_with_cli_overrides; +use codex_core::config_types::McpServerConfig; +use codex_core::git_info::resolve_root_git_project_for_trust; +use codex_core::mcp_toml::McpToml; +use codex_core::mcp_toml::McpTomlEntry; +use codex_core::mcp_toml::load_project_overlays; +use codex_core::mcp_toml::to_mcp_server_config; +use serde_json::json; +use tempfile as _; +use toml::Value as TomlValue; +use toml_edit as _; // ensure dependency is linked + +#[derive(Debug, Parser)] +#[command( + about = "Manage MCP servers and run Codex as an MCP server", + long_about = "Manage Model Context Protocol (MCP) servers configured for Codex.\n\nUse subcommands to add, import, list, inspect, or remove servers.\nIf no subcommand is provided, this runs the built-in MCP server (back-compat).", + after_help = "Examples:\n # Add a local stdio server (everything after -- is the server command)\n codex mcp add airtable --env AIRTABLE_API_KEY=YOUR_KEY -- npx -y airtable-mcp-server\n\n # Import multiple servers from a TOML file into project scope\n codex mcp add-toml --scope project ./mcp.toml\n\n # List configured servers (merged view with precedence local > project > user)\n codex mcp list --json\n\n # Show details for a specific server\n codex mcp get airtable --json\n\n # Remove a server from the user scope\n codex mcp remove airtable --scope user\n\n # Remove a server from all scopes\n codex mcp remove airtable --all\n\n # Windows: wrap npx with cmd /c\n codex mcp add my-svc -- cmd /c npx -y @some/package" +)] +pub struct McpCli { + #[clap(skip)] + pub config_overrides: CliConfigOverrides, + + #[command(subcommand)] + pub cmd: Option, +} + +#[derive(Debug, clap::Subcommand)] +pub enum McpSub { + /// Run Codex as an MCP server (back-compat: `codex mcp`). + Serve, + /// List configured MCP servers (merged view). + List { + #[arg(long)] + json: bool, + }, + /// Get details for a specific server name (merged view). + Get { + name: String, + #[arg(long)] + json: bool, + }, + /// Add an MCP stdio server entry to a given scope. + Add(AddArgs), + /// Remove an MCP server entry from a given scope or all scopes. + Remove(RemoveArgs), + /// Import one or more MCP servers from a TOML file with a [mcp_servers] table. + AddToml(AddTomlArgs), +} + +pub async fn run_main(mcp_cli: McpCli, codex_linux_sandbox_exe: Option) -> Result<()> { + match mcp_cli.cmd.unwrap_or(McpSub::Serve) { + McpSub::Serve => { + // Preserve the historical `codex mcp` behavior. + codex_mcp_server::run_main(codex_linux_sandbox_exe, mcp_cli.config_overrides).await? + } + McpSub::List { json } => { + list_servers(mcp_cli.config_overrides, json)?; + } + McpSub::Get { name, json } => { + get_server(mcp_cli.config_overrides, &name, json)?; + } + McpSub::Add(args) => { + add_server(mcp_cli.config_overrides, args)?; + } + McpSub::Remove(args) => { + remove_server(mcp_cli.config_overrides, args)?; + } + McpSub::AddToml(args) => { + add_toml(mcp_cli.config_overrides, args)?; + } + } + Ok(()) +} + +fn parse_cli_overrides(overrides: CliConfigOverrides) -> Vec<(String, TomlValue)> { + overrides.parse_overrides().unwrap_or_default() +} + +fn load_user_project_local_maps( + cli_overrides: CliConfigOverrides, +) -> Result<( + HashMap, + HashMap, + HashMap, +)> { + // User map via `~/.codex/config.toml` (+ -c overrides) + let codex_home = find_codex_home()?; + let user_cfg = + load_config_as_toml_with_cli_overrides(&codex_home, parse_cli_overrides(cli_overrides))?; + let mut user_map = user_cfg.mcp_servers; + + // Project/local overlays via current project root + let cwd = std::env::current_dir()?; + let project_root = resolve_root_git_project_for_trust(&cwd).unwrap_or(cwd); + let overlays = load_project_overlays(&project_root)?; + + let mut project_map = HashMap::new(); + let mut local_map = HashMap::new(); + for (scope, overlay) in overlays { + for (name, entry) in overlay.mcp_servers.into_iter() { + // Convert permissive overlay entry → strict config, expanding env vars. + if let Ok(cfg) = to_mcp_server_config(&entry, |k| std::env::var(k).ok()) { + match scope { + codex_core::mcp_toml::Scope::Project => { + project_map.insert(name, cfg); + } + codex_core::mcp_toml::Scope::Local => { + local_map.insert(name, cfg); + } + codex_core::mcp_toml::Scope::User => { + user_map.insert(name, cfg); + } + } + } + } + } + + Ok((user_map, project_map, local_map)) +} + +fn list_servers(cli_overrides: CliConfigOverrides, json_out: bool) -> Result<()> { + let (user_map, project_map, local_map) = load_user_project_local_maps(cli_overrides)?; + let mut names: BTreeSet = BTreeSet::new(); + names.extend(user_map.keys().cloned()); + names.extend(project_map.keys().cloned()); + names.extend(local_map.keys().cloned()); + + if json_out { + let mut arr = Vec::new(); + for name in names { + let (scope, cfg, shadowed_by) = + pick_with_scope(&name, &user_map, &project_map, &local_map); + arr.push(json!({ + "name": name, + "scope": scope, + "config": cfg_to_json(cfg), + "shadowed_by": shadowed_by, + })); + } + println!("{}", serde_json::to_string_pretty(&arr)?); + } else { + for name in names { + let (scope, cfg, _) = pick_with_scope(&name, &user_map, &project_map, &local_map); + let args_preview = if cfg.args.is_empty() { + String::new() + } else { + format!(" {}", cfg.args.join(" ")) + }; + println!("{} [{}] -> {}{}", name, scope, cfg.command, args_preview); + } + } + Ok(()) +} + +fn get_server(cli_overrides: CliConfigOverrides, name: &str, json_out: bool) -> Result<()> { + let (user_map, project_map, local_map) = load_user_project_local_maps(cli_overrides)?; + if !user_map.contains_key(name) + && !project_map.contains_key(name) + && !local_map.contains_key(name) + { + anyhow::bail!("MCP server `{}` not found in any scope", name); + } + let (scope, cfg, shadowed_by) = pick_with_scope(name, &user_map, &project_map, &local_map); + if json_out { + let obj = json!({ + "name": name, + "scope": scope, + "config": cfg_to_json(cfg), + "shadowed_by": shadowed_by, + }); + println!("{}", serde_json::to_string_pretty(&obj)?); + } else { + let args_preview = if cfg.args.is_empty() { + String::new() + } else { + format!(" {}", cfg.args.join(" ")) + }; + println!("{} [{}] -> {}{}", name, scope, cfg.command, args_preview); + } + Ok(()) +} + +fn pick_with_scope<'a>( + name: &str, + user_map: &'a HashMap, + project_map: &'a HashMap, + local_map: &'a HashMap, +) -> (&'static str, &'a McpServerConfig, Vec<&'static str>) { + if let Some(cfg) = local_map.get(name) { + ( + "local", + cfg, + vec![ + if project_map.contains_key(name) { + "project" + } else { + "" + }, + if user_map.contains_key(name) { + "user" + } else { + "" + }, + ] + .into_iter() + .filter(|s| !s.is_empty()) + .collect(), + ) + } else if let Some(cfg) = project_map.get(name) { + ( + "project", + cfg, + vec![if user_map.contains_key(name) { + "user" + } else { + "" + }] + .into_iter() + .filter(|s| !s.is_empty()) + .collect(), + ) + } else if let Some(cfg) = user_map.get(name) { + ("user", cfg, vec![]) + } else { + // Should not occur because callers pre-check membership. Return a + // fallback to avoid panics in release builds. + let fallback = user_map + .iter() + .next() + .or_else(|| project_map.iter().next()) + .or_else(|| local_map.iter().next()); + let (k, v) = match fallback { + Some(kv) => kv, + None => panic!("internal error: no MCP server entries found across scopes"), + }; + let _ = k; // suppress unused warning + ("user", v, vec![]) + } +} + +fn cfg_to_json(cfg: &McpServerConfig) -> serde_json::Value { + json!({ + "command": cfg.command, + "args": cfg.args, + "env": cfg.env, + }) +} + +// ------------------------------ +// Add/remove writers +// ------------------------------ + +#[derive(Copy, Clone, Debug, clap::ValueEnum)] +enum ScopeArg { + Local, + Project, + User, +} + +#[derive(Debug, Parser)] +pub struct AddArgs { + /// Unique server name (^[A-Za-z0-9_-]+$) + name: String, + /// Target scope + #[arg(long, value_enum, default_value_t = ScopeArg::Local)] + scope: ScopeArg, + /// Environment variables KEY=VALUE (repeatable) + #[arg(long = "env")] + env: Vec, + /// Command and args to launch the MCP server (after `--`) + #[arg(trailing_var_arg = true)] + cmd: Vec, +} + +#[derive(Debug, Parser)] +pub struct RemoveArgs { + /// Server name + name: String, + /// Scope to remove from; omit with --all to remove everywhere + #[arg(long, value_enum)] + scope: Option, + /// Remove from all scopes + #[arg(long)] + all: bool, +} + +fn add_server(cli_overrides: CliConfigOverrides, args: AddArgs) -> Result<()> { + validate_server_name(&args.name)?; + if args.cmd.is_empty() { + anyhow::bail!( + "missing server command; use: codex mcp add [--scope ...] [--env KEY=VALUE]... -- [args...]" + ); + } + let command = args.cmd[0].clone(); + let cmd_args: Vec = args.cmd.iter().skip(1).cloned().collect(); + let env_map = parse_env_kv(args.env.iter())?; + + let path = match args.scope { + ScopeArg::User => { + write_user_scope(&args.name, &command, &cmd_args, &env_map, cli_overrides)? + } + ScopeArg::Project => write_overlay_scope(&args.name, &command, &cmd_args, &env_map, false)?, + ScopeArg::Local => write_overlay_scope(&args.name, &command, &cmd_args, &env_map, true)?, + }; + println!( + "Added MCP server '{}' (scope: {}) → wrote {}", + args.name, + match args.scope { + ScopeArg::Local => "local", + ScopeArg::Project => "project", + ScopeArg::User => "user", + }, + path.display() + ); + Ok(()) +} + +fn remove_server(cli_overrides: CliConfigOverrides, args: RemoveArgs) -> Result<()> { + if args.all && args.scope.is_some() { + anyhow::bail!("cannot use --scope with --all"); + } + + if args.all { + let u = remove_user_scope(&args.name, cli_overrides.clone())?; + if u.wrote { + println!("Removed '{}' → wrote {}", args.name, u.path.display()); + } + let p = remove_overlay_scope(&args.name, false)?; + if p.wrote { + println!("Removed '{}' → wrote {}", args.name, p.path.display()); + } + let l = remove_overlay_scope(&args.name, true)?; + if l.wrote { + println!("Removed '{}' → wrote {}", args.name, l.path.display()); + } + return Ok(()); + } + + let outcome = match args.scope.unwrap_or(ScopeArg::Local) { + ScopeArg::User => remove_user_scope(&args.name, cli_overrides)?, + ScopeArg::Project => remove_overlay_scope(&args.name, false)?, + ScopeArg::Local => remove_overlay_scope(&args.name, true)?, + }; + if outcome.wrote { + println!("Removed '{}' → wrote {}", args.name, outcome.path.display()); + } else { + println!( + "No changes for '{}' at {}", + args.name, + outcome.path.display() + ); + } + Ok(()) +} + +#[derive(Debug, Parser)] +pub struct AddTomlArgs { + /// Path to a TOML file containing a [mcp_servers] table + path: PathBuf, + /// Target scope to import into + #[arg(long, value_enum, default_value_t = ScopeArg::Local)] + scope: ScopeArg, +} + +fn add_toml(_cli_overrides: CliConfigOverrides, args: AddTomlArgs) -> Result<()> { + let contents = std::fs::read_to_string(&args.path)?; + let parsed: McpToml = toml::from_str(&contents)?; + let mut accepted: Vec<(String, McpTomlEntry)> = Vec::new(); + let mut rejected: Vec<(String, String)> = Vec::new(); + for (name, entry) in parsed.mcp_servers.into_iter() { + if let Some(t) = entry.r#type.as_deref() + && !t.eq_ignore_ascii_case("stdio") + { + rejected.push((name, format!("unsupported transport `{}`", t))); + continue; + } + if entry.command.is_none() { + rejected.push((name, "missing command".to_string())); + continue; + } + accepted.push((name, entry)); + } + + let path = match args.scope { + ScopeArg::User => write_user_batch(&accepted)?, + ScopeArg::Project => write_overlay_batch(&accepted, false)?, + ScopeArg::Local => write_overlay_batch(&accepted, true)?, + }; + println!( + "Imported {} MCP server(s) into {}", + accepted.len(), + path.display() + ); + + if !rejected.is_empty() { + for (n, why) in rejected { + eprintln!("skipped `{}`: {}", n, why); + } + } + Ok(()) +} + +fn parse_env_kv<'a>(pairs: impl Iterator) -> Result> { + let mut map = HashMap::new(); + for p in pairs { + if let Some((k, v)) = p.split_once('=') { + if k.is_empty() { + anyhow::bail!("invalid --env '{}': empty key", p); + } + map.insert(k.to_string(), v.to_string()); + } else { + anyhow::bail!("invalid --env '{}': expected KEY=VALUE", p); + } + } + Ok(map) +} + +fn validate_server_name(name: &str) -> Result<()> { + let ok = !name.is_empty() + && name + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-'); + if ok { + Ok(()) + } else { + anyhow::bail!( + "invalid server name '{}': must match ^[a-zA-Z0-9_-]+$", + name + ) + } +} + +fn resolve_codex_home_for_write() -> Result { + if let Ok(val) = std::env::var("CODEX_HOME") + && !val.is_empty() + { + let p = PathBuf::from(val); + if !p.exists() { + std::fs::create_dir_all(&p)?; + } + return Ok(p.canonicalize().unwrap_or(p)); + } + let p = find_codex_home()?; + if !p.exists() { + std::fs::create_dir_all(&p)?; + } + Ok(p) +} + +fn write_user_scope( + name: &str, + command: &str, + args: &[String], + env_map: &HashMap, + cli_overrides: CliConfigOverrides, +) -> Result { + let codex_home = resolve_codex_home_for_write()?; + let path = codex_home.join("config.toml"); + let contents = std::fs::read_to_string(&path).unwrap_or_default(); + let mut doc = contents + .parse::() + .unwrap_or_default(); + upsert_mcp_entry(&mut doc, name, command, args, env_map); + write_doc_atomic(&doc, &path)?; + let _ = cli_overrides; + Ok(path) +} + +fn write_overlay_scope( + name: &str, + command: &str, + args: &[String], + env_map: &HashMap, + local: bool, +) -> Result { + let cwd = std::env::current_dir()?; + let project_root = resolve_root_git_project_for_trust(&cwd).unwrap_or(cwd); + let fname = if local { + ".mcp.local.toml" + } else { + ".mcp.toml" + }; + let path = project_root.join(fname); + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent)?; + } + let contents = std::fs::read_to_string(&path).unwrap_or_default(); + let mut doc = contents + .parse::() + .unwrap_or_default(); + upsert_mcp_entry(&mut doc, name, command, args, env_map); + write_doc_atomic(&doc, &path)?; + Ok(path) +} + +fn write_user_batch(entries: &[(String, McpTomlEntry)]) -> Result { + let codex_home = resolve_codex_home_for_write()?; + let path = codex_home.join("config.toml"); + let contents = std::fs::read_to_string(&path).unwrap_or_default(); + let mut doc = contents + .parse::() + .unwrap_or_default(); + for (name, entry) in entries { + let args = entry.args.clone(); + let env_map = entry.env.clone(); + let command = entry.command.clone().unwrap_or_default(); + upsert_mcp_entry(&mut doc, name, &command, &args, &env_map); + } + write_doc_atomic(&doc, &path)?; + Ok(path) +} + +fn write_overlay_batch(entries: &[(String, McpTomlEntry)], local: bool) -> Result { + let cwd = std::env::current_dir()?; + let project_root = resolve_root_git_project_for_trust(&cwd).unwrap_or(cwd); + let fname = if local { + ".mcp.local.toml" + } else { + ".mcp.toml" + }; + let path = project_root.join(fname); + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent)?; + } + let contents = std::fs::read_to_string(&path).unwrap_or_default(); + let mut doc = contents + .parse::() + .unwrap_or_default(); + for (name, entry) in entries { + let args = entry.args.clone(); + let env_map = entry.env.clone(); + let command = entry.command.clone().unwrap_or_default(); + upsert_mcp_entry(&mut doc, name, &command, &args, &env_map); + } + write_doc_atomic(&doc, &path)?; + Ok(path) +} + +struct RemoveOutcome { + path: PathBuf, + wrote: bool, +} + +fn remove_user_scope(name: &str, _cli_overrides: CliConfigOverrides) -> Result { + let codex_home = resolve_codex_home_for_write()?; + let path = codex_home.join("config.toml"); + if !path.exists() { + return Ok(RemoveOutcome { path, wrote: false }); + } + let contents = std::fs::read_to_string(&path)?; + let mut doc = contents.parse::()?; + if let Some(tbl) = doc.get_mut("mcp_servers").and_then(|i| i.as_table_mut()) { + if tbl.remove(name).is_some() { + write_doc_atomic(&doc, &path)?; + return Ok(RemoveOutcome { path, wrote: true }); + } + } + Ok(RemoveOutcome { path, wrote: false }) +} + +fn remove_overlay_scope(name: &str, local: bool) -> Result { + let cwd = std::env::current_dir()?; + let project_root = resolve_root_git_project_for_trust(&cwd).unwrap_or(cwd); + let fname = if local { + ".mcp.local.toml" + } else { + ".mcp.toml" + }; + let path = project_root.join(fname); + if !path.exists() { + return Ok(RemoveOutcome { path, wrote: false }); + } + let contents = std::fs::read_to_string(&path)?; + let mut doc = contents.parse::()?; + if let Some(tbl) = doc.get_mut("mcp_servers").and_then(|i| i.as_table_mut()) { + if tbl.remove(name).is_some() { + write_doc_atomic(&doc, &path)?; + return Ok(RemoveOutcome { path, wrote: true }); + } + } + Ok(RemoveOutcome { path, wrote: false }) +} + +fn upsert_mcp_entry( + doc: &mut toml_edit::DocumentMut, + name: &str, + command: &str, + args: &[String], + env_map: &HashMap, +) { + if !doc.as_table().contains_key("mcp_servers") { + doc.insert("mcp_servers", toml_edit::table()); + } + let tbl = doc["mcp_servers"].as_table_mut().expect("table"); + tbl.set_implicit(false); + + if !tbl.contains_key(name) { + tbl.insert(name, toml_edit::table()); + } + let st = tbl[name].as_table_mut().expect("subtable"); + st.set_implicit(false); + + st["command"] = toml_edit::value(command); + let mut arr = toml_edit::Array::new(); + for a in args { + arr.push(a.as_str()); + } + st["args"] = toml_edit::Item::Value(toml_edit::Value::Array(arr)); + + if env_map.is_empty() { + if st.contains_key("env") { + st.remove("env"); + } + } else { + let mut kv = toml_edit::InlineTable::new(); + for (k, v) in env_map { + kv.get_or_insert(k, toml_edit::Value::from(v.as_str())); + } + st["env"] = toml_edit::Item::Value(toml_edit::Value::InlineTable(kv)); + } +} + +fn write_doc_atomic(doc: &toml_edit::DocumentMut, path: &PathBuf) -> Result<()> { + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent)?; + } + let tmp = tempfile::NamedTempFile::new_in( + path.parent().unwrap_or_else(|| std::path::Path::new(".")), + )?; + std::fs::write(tmp.path(), doc.to_string())?; + tmp.persist(path)?; + Ok(()) +} diff --git a/codex-rs/cli/tests/mcp_add_remove.rs b/codex-rs/cli/tests/mcp_add_remove.rs new file mode 100644 index 0000000000..da9b8bcd8a --- /dev/null +++ b/codex-rs/cli/tests/mcp_add_remove.rs @@ -0,0 +1,88 @@ +use assert_cmd::prelude::*; +use std::fs; +use std::process::Command; + +fn write(path: &std::path::Path, contents: &str) { + fs::write(path, contents).unwrap(); +} + +#[test] +fn add_and_remove_user_scope() { + let codex_home = tempfile::tempdir().unwrap(); + // Pre-create CODEX_HOME for canonicalization logic + let config_path = codex_home.path().join("config.toml"); + + let project_dir = tempfile::tempdir().unwrap(); + write(&project_dir.path().join(".git"), "gitdir: nowhere"); + + // Add + Command::cargo_bin("codex") + .unwrap() + .current_dir(project_dir.path()) + .env("CODEX_HOME", codex_home.path()) + .args([ + "mcp", "add", "svc", "--scope", "user", "--", "tool", "--flag", + ]) + .assert() + .success(); + + let config = fs::read_to_string(&config_path).unwrap(); + assert!(config.contains("[mcp_servers.svc]")); + assert!(config.contains("command = \"tool\"")); + + // Remove + Command::cargo_bin("codex") + .unwrap() + .current_dir(project_dir.path()) + .env("CODEX_HOME", codex_home.path()) + .args(["mcp", "remove", "svc", "--scope", "user"]) + .assert() + .success(); + + let config_after = fs::read_to_string(&config_path).unwrap(); + assert!(!config_after.contains("[mcp_servers.svc]")); +} + +#[test] +fn add_local_and_project_scopes() { + let codex_home = tempfile::tempdir().unwrap(); + let project_dir = tempfile::tempdir().unwrap(); + write(&project_dir.path().join(".git"), "gitdir: nowhere"); + + // Add project + Command::cargo_bin("codex") + .unwrap() + .current_dir(project_dir.path()) + .env("CODEX_HOME", codex_home.path()) + .args(["mcp", "add", "svc", "--scope", "project", "--", "toolp"]) + .assert() + .success(); + let proj = fs::read_to_string(project_dir.path().join(".mcp.toml")).unwrap(); + assert!(proj.contains("[mcp_servers.svc]")); + assert!(proj.contains("toolp")); + + // Add local (override in precedence for merged view) + Command::cargo_bin("codex") + .unwrap() + .current_dir(project_dir.path()) + .env("CODEX_HOME", codex_home.path()) + .args(["mcp", "add", "svc", "--scope", "local", "--", "tooll"]) + .assert() + .success(); + let local = fs::read_to_string(project_dir.path().join(".mcp.local.toml")).unwrap(); + assert!(local.contains("tooll")); + + // Remove all + Command::cargo_bin("codex") + .unwrap() + .current_dir(project_dir.path()) + .env("CODEX_HOME", codex_home.path()) + .args(["mcp", "remove", "svc", "--all"]) + .assert() + .success(); + + let proj_after = fs::read_to_string(project_dir.path().join(".mcp.toml")).unwrap(); + assert!(!proj_after.contains("[mcp_servers.svc]")); + let local_after = fs::read_to_string(project_dir.path().join(".mcp.local.toml")).unwrap(); + assert!(!local_after.contains("[mcp_servers.svc]")); +} diff --git a/codex-rs/cli/tests/mcp_add_toml.rs b/codex-rs/cli/tests/mcp_add_toml.rs new file mode 100644 index 0000000000..8f4f4192e8 --- /dev/null +++ b/codex-rs/cli/tests/mcp_add_toml.rs @@ -0,0 +1,133 @@ +use assert_cmd::prelude::*; +use serde_json::Value; +use std::fs; +use std::process::Command; + +fn write(path: &std::path::Path, contents: &str) { + fs::write(path, contents).unwrap(); +} + +#[test] +fn add_toml_local_filters_non_stdio_and_lists() { + let codex_home = tempfile::tempdir().unwrap(); + let project_dir = tempfile::tempdir().unwrap(); + write(&project_dir.path().join(".git"), "gitdir: nowhere"); + + let import = tempfile::NamedTempFile::new().unwrap(); + write( + import.path(), + r#"[mcp_servers.ok] +type = "stdio" +command = "tool" +args = ["--x"] +env = { K = "V" } + +[mcp_servers.bad] +type = "http" +url = "https://example.invalid/mcp" + +[mcp_servers.missing] +type = "stdio" +"#, + ); + + // Import into local scope + Command::cargo_bin("codex") + .unwrap() + .current_dir(project_dir.path()) + .env("CODEX_HOME", codex_home.path()) + .args([ + "mcp", + "add-toml", + "--scope", + "local", + import.path().to_str().unwrap(), + ]) + .assert() + .success(); + + // Verify file contents + let local_contents = fs::read_to_string(project_dir.path().join(".mcp.local.toml")).unwrap(); + assert!(local_contents.contains("[mcp_servers.ok]")); + assert!(local_contents.contains("command = \"tool\"")); + assert!(!local_contents.contains("[mcp_servers.bad]")); + assert!(!local_contents.contains("[mcp_servers.missing]")); + + // And list shows only the accepted entry, with local scope + let out = Command::cargo_bin("codex") + .unwrap() + .current_dir(project_dir.path()) + .env("CODEX_HOME", codex_home.path()) + .args(["mcp", "list", "--json"]) + .assert() + .success() + .get_output() + .stdout + .clone(); + let v: Value = serde_json::from_slice(&out).unwrap(); + let arr = v.as_array().unwrap(); + let mut seen_ok = false; + for e in arr { + if e.get("name").and_then(|x| x.as_str()) == Some("ok") { + assert_eq!(e.get("scope").and_then(|x| x.as_str()), Some("local")); + seen_ok = true; + } + assert_ne!(e.get("name").and_then(|x| x.as_str()), Some("bad")); + assert_ne!(e.get("name").and_then(|x| x.as_str()), Some("missing")); + } + assert!( + seen_ok, + "expected to find imported 'ok' entry in list output" + ); +} + +#[test] +fn add_toml_user_and_get() { + let codex_home = tempfile::tempdir().unwrap(); + let project_dir = tempfile::tempdir().unwrap(); + write(&project_dir.path().join(".git"), "gitdir: nowhere"); + + let import = tempfile::NamedTempFile::new().unwrap(); + write( + import.path(), + r#"[mcp_servers.userok] +type = "stdio" +command = "utool" +"#, + ); + + // Import into user scope + Command::cargo_bin("codex") + .unwrap() + .current_dir(project_dir.path()) + .env("CODEX_HOME", codex_home.path()) + .args([ + "mcp", + "add-toml", + "--scope", + "user", + import.path().to_str().unwrap(), + ]) + .assert() + .success(); + + // Get shows the user scope + let out = Command::cargo_bin("codex") + .unwrap() + .current_dir(project_dir.path()) + .env("CODEX_HOME", codex_home.path()) + .args(["mcp", "get", "userok", "--json"]) + .assert() + .success() + .get_output() + .stdout + .clone(); + let v: Value = serde_json::from_slice(&out).unwrap(); + assert_eq!(v.get("scope").and_then(|x| x.as_str()), Some("user")); + assert_eq!( + v.get("config") + .and_then(|c| c.get("command")) + .and_then(|x| x.as_str()), + Some("utool") + ); +} diff --git a/codex-rs/cli/tests/mcp_get.rs b/codex-rs/cli/tests/mcp_get.rs new file mode 100644 index 0000000000..8bda855735 --- /dev/null +++ b/codex-rs/cli/tests/mcp_get.rs @@ -0,0 +1,52 @@ +use assert_cmd::prelude::*; +use serde_json::Value; +use std::fs; +use std::process::Command; + +fn write(path: &std::path::Path, contents: &str) { + fs::write(path, contents).unwrap(); +} + +#[test] +fn get_returns_winning_scope() { + let codex_home = tempfile::tempdir().unwrap(); + write( + &codex_home.path().join("config.toml"), + r#"[mcp_servers.svc] +command = "user-cmd" +"#, + ); + + let project_dir = tempfile::tempdir().unwrap(); + write(&project_dir.path().join(".git"), "gitdir: nowhere"); + write( + &project_dir.path().join(".mcp.toml"), + r#"[mcp_servers.svc] +command = "project-cmd" +"#, + ); + write( + &project_dir.path().join(".mcp.local.toml"), + r#"[mcp_servers.svc] +command = "local-cmd" +"#, + ); + + let assert = Command::cargo_bin("codex") + .unwrap() + .current_dir(project_dir.path()) + .env("CODEX_HOME", codex_home.path()) + .args(["mcp", "get", "svc", "--json"]) + .assert() + .success(); + let out = String::from_utf8(assert.get_output().stdout.clone()).unwrap(); + let v: Value = serde_json::from_str(&out).unwrap(); + assert_eq!(v.get("name").and_then(|x| x.as_str()), Some("svc")); + assert_eq!(v.get("scope").and_then(|x| x.as_str()), Some("local")); + assert_eq!( + v.get("config") + .and_then(|c| c.get("command")) + .and_then(|x| x.as_str()), + Some("local-cmd") + ); +} diff --git a/codex-rs/cli/tests/mcp_list.rs b/codex-rs/cli/tests/mcp_list.rs new file mode 100644 index 0000000000..92429955bb --- /dev/null +++ b/codex-rs/cli/tests/mcp_list.rs @@ -0,0 +1,71 @@ +use assert_cmd::prelude::*; +use serde_json::Value; +use std::fs; +use std::process::Command; + +fn write(path: &std::path::Path, contents: &str) { + fs::write(path, contents).unwrap(); +} + +#[test] +fn list_shows_scopes_for_user_project_local() { + let codex_home = tempfile::tempdir().unwrap(); + write( + &codex_home.path().join("config.toml"), + r#"[mcp_servers.user_svc] +command = "user-cmd" +"#, + ); + + let project_dir = tempfile::tempdir().unwrap(); + // Mark git root for nicer parity with real use + write(&project_dir.path().join(".git"), "gitdir: nowhere"); + write( + &project_dir.path().join(".mcp.toml"), + r#"[mcp_servers.proj_svc] +command = "proj-cmd" +"#, + ); + write( + &project_dir.path().join(".mcp.local.toml"), + r#"[mcp_servers.local_svc] +command = "local-cmd" +"#, + ); + + let assert = Command::cargo_bin("codex") + .unwrap() + .current_dir(project_dir.path()) + .env("CODEX_HOME", codex_home.path()) + .args(["mcp", "list", "--json"]) + .assert() + .success(); + let out = String::from_utf8(assert.get_output().stdout.clone()).unwrap(); + let v: Value = serde_json::from_str(&out).unwrap(); + let arr = v.as_array().unwrap(); + + let mut found = (false, false, false); + for e in arr { + let name = e.get("name").and_then(|x| x.as_str()).unwrap(); + let scope = e.get("scope").and_then(|x| x.as_str()).unwrap(); + match name { + "user_svc" => { + assert_eq!(scope, "user"); + found.0 = true; + } + "proj_svc" => { + assert_eq!(scope, "project"); + found.1 = true; + } + "local_svc" => { + assert_eq!(scope, "local"); + found.2 = true; + } + _ => {} + } + } + assert!( + found.0 && found.1 && found.2, + "expected three entries across scopes" + ); +} diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index 98a8fde135..6d77436869 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -8,6 +8,8 @@ use crate::config_types::Tui; use crate::config_types::UriBasedFileOpener; use crate::config_types::Verbosity; use crate::git_info::resolve_root_git_project_for_trust; +use crate::mcp_toml::load_project_overlays; +use crate::mcp_toml::to_mcp_server_config; use crate::model_family::ModelFamily; use crate::model_family::find_family_for_model; use crate::model_provider_info::ModelProviderInfo; @@ -599,6 +601,7 @@ impl Config { overrides: ConfigOverrides, codex_home: PathBuf, ) -> std::io::Result { + let mut cfg = cfg; let user_instructions = Self::load_instructions(Some(&codex_home)); // Destructure ConfigOverrides fully to ensure all overrides are applied. @@ -710,6 +713,37 @@ impl Config { let experimental_resume = cfg.experimental_resume; + // Merge project overlays (.mcp.toml and .mcp.local.toml) with precedence: + // user (config.toml) < project < local. Skip invalid or non-stdio entries. + // Determine project root using the same logic as trust checks. + let project_root = + resolve_root_git_project_for_trust(&resolved_cwd).unwrap_or(resolved_cwd.clone()); + if let Ok(overlays) = load_project_overlays(&project_root) { + // Start from user-defined servers from config.toml + let mut merged = std::mem::take(&mut cfg.mcp_servers); + + // Apply in ascending precedence order: project then local. + for (scope, overlay) in overlays.iter().rev() { + for (name, entry) in overlay.mcp_servers.iter() { + match to_mcp_server_config(entry, |k| std::env::var(k).ok()) { + Ok(server_cfg) => { + merged.insert(name.clone(), server_cfg); + } + Err(e) => { + tracing::warn!( + "Skipping MCP server `{}` from {:?} overlay: {:#}", + name, + scope, + e + ); + } + } + } + } + + cfg.mcp_servers = merged; + } + // Load base instructions override from a file if specified. If the // path is relative, resolve it against the effective cwd so the // behaviour matches other path-like config values. diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index ae18332087..2bc3f9bc35 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -27,6 +27,7 @@ pub mod git_info; mod is_safe_command; pub mod landlock; mod mcp_connection_manager; +pub mod mcp_toml; mod mcp_tool_call; mod message_history; mod model_provider_info; diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index b5813a0462..078e8ff7f6 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -36,8 +36,33 @@ use crate::config_types::McpServerConfig; const MCP_TOOL_NAME_DELIMITER: &str = "__"; const MAX_TOOL_NAME_LENGTH: usize = 64; -/// Timeout for the `tools/list` request. -const LIST_TOOLS_TIMEOUT: Duration = Duration::from_secs(10); +/// Read MCP timeout (milliseconds) from the environment. +/// +/// Falls back to 10_000 ms (10s) if the variable is not set or cannot be +/// parsed as an integer. +pub(crate) fn mcp_timeout_from_env() -> Duration { + match std::env::var("MCP_TIMEOUT") { + Ok(val) => parse_mcp_timeout(Some(val.trim())), + Err(_) => parse_mcp_timeout(None), + } +} + +pub(crate) fn parse_mcp_timeout(val: Option<&str>) -> Duration { + const DEFAULT_MS: u64 = 10_000; + match val { + Some(s) => match s.parse::() { + Ok(ms) => Duration::from_millis(ms), + Err(_) => { + tracing::warn!( + "Invalid MCP_TIMEOUT value, using default of {} ms", + DEFAULT_MS + ); + Duration::from_millis(DEFAULT_MS) + } + }, + None => Duration::from_millis(DEFAULT_MS), + } +} /// Map that holds a startup error for every MCP server that could **not** be /// spawned successfully. @@ -154,7 +179,7 @@ impl McpConnectionManager { protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(), }; let initialize_notification_params = None; - let timeout = Some(Duration::from_secs(10)); + let timeout = Some(mcp_timeout_from_env()); match client .initialize(params, initialize_notification_params, timeout) .await @@ -242,7 +267,7 @@ async fn list_all_tools( let client_clone = client.clone(); join_set.spawn(async move { let res = client_clone - .list_tools(None, Some(LIST_TOOLS_TIMEOUT)) + .list_tools(None, Some(mcp_timeout_from_env())) .await; (server_name_cloned, res) }); @@ -285,6 +310,24 @@ mod tests { use super::*; use mcp_types::ToolInputSchema; + #[test] + fn test_mcp_timeout_default_is_10s() { + let d = parse_mcp_timeout(None); + assert_eq!(d, Duration::from_millis(10_000)); + } + + #[test] + fn test_mcp_timeout_parses_ms() { + let d = parse_mcp_timeout(Some("1234")); + assert_eq!(d, Duration::from_millis(1234)); + } + + #[test] + fn test_mcp_timeout_invalid_uses_default() { + let d = parse_mcp_timeout(Some("abc")); + assert_eq!(d, Duration::from_millis(10_000)); + } + fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo { ToolInfo { server_name: server_name.to_string(), diff --git a/codex-rs/core/src/mcp_toml.rs b/codex-rs/core/src/mcp_toml.rs new file mode 100644 index 0000000000..9221c19d4b --- /dev/null +++ b/codex-rs/core/src/mcp_toml.rs @@ -0,0 +1,395 @@ +use anyhow::Result; +use anyhow::anyhow; +use serde::Deserialize; +use std::collections::HashMap; +use std::path::Path; + +use crate::config_types::McpServerConfig; + +/// Expand `${VAR}` and `${VAR:-default}` sequences in `input`. +/// +/// - `${VAR}`: replaced by `lookup(VAR)` or returns an error if unset. +/// - `${VAR:-default}`: replaced by `lookup(VAR)` if set; otherwise `default`. +/// +/// No whitespace is trimmed. Defaults are treated as literal strings (no nested +/// expansions inside the default value). Variable names must match +/// `^[A-Za-z_][A-Za-z0-9_]*$`. +pub(crate) fn expand_vars( + input: &str, + mut lookup: impl FnMut(&str) -> Option, + source_label: &str, +) -> Result { + let mut out = String::with_capacity(input.len()); + let bytes = input.as_bytes(); + let mut i = 0; + while i < bytes.len() { + if bytes[i] == b'$' && i + 1 < bytes.len() && bytes[i + 1] == b'{' { + // Find closing brace + let start_inner = i + 2; + let mut end = start_inner; + let mut found = false; + while end < bytes.len() { + if bytes[end] == b'}' { + found = true; + break; + } + end += 1; + } + if !found { + return Err(anyhow!( + "unterminated variable expansion starting at byte {i} in {source_label}" + )); + } + let inner = &input[start_inner..end]; + let (name, default) = match inner.split_once(":-") { + Some((n, d)) => (n, Some(d)), + None => (inner, None), + }; + + if !is_valid_var_name(name) { + return Err(anyhow!( + "invalid variable name `{}` in {} (must match ^[A-Za-z_][A-Za-z0-9_]*$)", + name, + source_label + )); + } + + let replacement = match (lookup(name), default) { + (Some(v), _) => v, + (None, Some(d)) => d.to_string(), + (None, None) => { + return Err(anyhow!( + "environment variable `{}` not set and no default provided in {}", + name, + source_label + )); + } + }; + out.push_str(&replacement); + i = end + 1; + continue; + } + // Copy through single byte as UTF-8 is preserved by slicing boundaries here. + out.push(bytes[i] as char); + i += 1; + } + Ok(out) +} + +fn is_valid_var_name(name: &str) -> bool { + let mut chars = name.chars(); + match chars.next() { + Some(c) if is_alpha_or_underscore(c) => (), + _ => return false, + } + chars.all(|c| is_alnum_or_underscore(c)) +} + +fn is_alpha_or_underscore(c: char) -> bool { + (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c == '_' +} + +fn is_alnum_or_underscore(c: char) -> bool { + is_alpha_or_underscore(c) || (c >= '0' && c <= '9') +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_expand_vars_simple() { + let lookup = |k: &str| match k { + "USER" => Some("alice".into()), + _ => None, + }; + let res = expand_vars("/home/${USER}/bin", lookup, "test"); + match res { + Ok(s) => assert_eq!(s, "/home/alice/bin"), + Err(e) => panic!("unexpected error: {e:#}"), + } + } + + #[test] + fn test_expand_vars_with_default() { + let lookup = |_k: &str| None; + let res = expand_vars("${REGION:-us-east}", lookup, "test"); + match res { + Ok(s) => assert_eq!(s, "us-east"), + Err(e) => panic!("unexpected error: {e:#}"), + } + } + + #[test] + fn test_expand_vars_missing_errors() { + let lookup = |_k: &str| None; + let res = expand_vars("x${REQUIRED}y", lookup, "test"); + let msg = match res { + Ok(v) => panic!("expected error, got {v}"), + Err(e) => format!("{e:#}"), + }; + assert!(msg.contains("environment variable `REQUIRED` not set")); + } + + #[test] + fn test_expand_vars_multiple() { + let lookup = |k: &str| match k { + "A" => Some("1".into()), + "B" => Some("2".into()), + _ => None, + }; + let res = expand_vars("${A}-${B}-${C:-x}", lookup, "test"); + match res { + Ok(s) => assert_eq!(s, "1-2-x"), + Err(e) => panic!("unexpected error: {e:#}"), + } + } + + #[test] + fn test_expand_vars_invalid_name() { + let lookup = |_k: &str| None; + let res = expand_vars("${1BAD}", lookup, "test"); + let msg = match res { + Ok(v) => panic!("expected error, got {v}"), + Err(e) => format!("{e:#}"), + }; + assert!(msg.contains("invalid variable name")); + } + + #[test] + fn test_expand_vars_unterminated() { + let lookup = |_k: &str| None; + let res = expand_vars("abc ${FOO", lookup, "test-file"); + let msg = match res { + Ok(v) => panic!("expected error, got {v}"), + Err(e) => format!("{e:#}"), + }; + assert!(msg.contains("unterminated variable expansion")); + assert!(msg.contains("test-file")); + } +} + +// ------------------------------- +// Serde types and converters +// ------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Scope { + User, + Project, + Local, +} + +#[derive(Debug, Deserialize, Default)] +pub struct McpToml { + #[serde(default)] + pub mcp_servers: HashMap, +} + +#[derive(Debug, Deserialize, Default)] +pub struct McpTomlEntry { + #[serde(default)] + pub r#type: Option, + pub command: Option, + #[serde(default)] + pub args: Vec, + #[serde(default)] + pub env: HashMap, +} + +/// Convert a permissive TOML entry to the strict `McpServerConfig` used by Codex. +/// +/// - Only `stdio` (or missing) transport is accepted; anything else returns an error. +/// - Expands variables in `command`, each `args[]`, and each `env` value. +/// - Returns an error if `command` is missing (after expansion) or if any +/// `${VAR}` expansion fails with no default. +pub fn to_mcp_server_config( + entry: &McpTomlEntry, + mut lookup: impl FnMut(&str) -> Option, +) -> Result { + // Transport check: only allow stdio or unspecified + if let Some(t) = entry.r#type.as_deref() { + let t_lower = t.to_ascii_lowercase(); + if t_lower != "stdio" { + return Err(anyhow!( + "unsupported MCP transport `{}` (only `stdio` supported)", + t + )); + } + } + + // Command is required + let command_raw = entry + .command + .as_ref() + .ok_or_else(|| anyhow!("missing `command` for stdio MCP server"))?; + let command = expand_vars(command_raw, &mut lookup, "overlay:command")?; + + // Args with expansion + let mut args = Vec::with_capacity(entry.args.len()); + for a in &entry.args { + args.push(expand_vars(a, &mut lookup, "overlay:args")?); + } + + // Env values with expansion; keep as None if empty + let mut env_out: HashMap = HashMap::with_capacity(entry.env.len()); + for (k, v) in &entry.env { + env_out.insert(k.clone(), expand_vars(v, &mut lookup, "overlay:env")?); + } + + Ok(McpServerConfig { + command, + args, + env: if env_out.is_empty() { + None + } else { + Some(env_out) + }, + }) +} + +#[cfg(test)] +mod convert_tests { + use super::*; + + #[test] + fn test_to_mcp_server_config_stdio_ok() { + let entry = McpTomlEntry { + r#type: None, + command: Some("${HOME}/bin/svc".to_string()), + args: vec!["--region".into(), "${REGION:-us-east}".into()], + env: HashMap::from([(String::from("API_KEY"), String::from("${KEY}"))]), + }; + let mut map = HashMap::new(); + map.insert("HOME".to_string(), "/home/alice".to_string()); + map.insert("KEY".to_string(), "secret".to_string()); + let lookup = |k: &str| map.get(k).cloned(); + let cfg = match to_mcp_server_config(&entry, lookup) { + Ok(c) => c, + Err(e) => panic!("unexpected error: {e:#}"), + }; + assert_eq!(cfg.command, "/home/alice/bin/svc"); + assert_eq!(cfg.args, vec!["--region", "us-east"]); + let api_key = cfg.env.as_ref().and_then(|m| m.get("API_KEY")).cloned(); + assert_eq!(api_key.as_deref(), Some("secret")); + } + + #[test] + fn test_to_mcp_server_config_reject_non_stdio() { + for t in ["http", "sse", "HTTP", "SSe"] { + let entry = McpTomlEntry { + r#type: Some(t.to_string()), + command: Some("tool".to_string()), + ..Default::default() + }; + let msg = match to_mcp_server_config(&entry, |_k| None) { + Ok(v) => panic!("expected error, got {v:?}"), + Err(e) => format!("{e:#}"), + }; + assert!(msg.to_lowercase().contains("unsupported mcp transport")); + } + } + + #[test] + fn test_to_mcp_server_config_missing_command_errors() { + let entry = McpTomlEntry { + command: None, + ..Default::default() + }; + let msg = match to_mcp_server_config(&entry, |_k| None) { + Ok(v) => panic!("expected error, got {v:?}"), + Err(e) => format!("{e:#}"), + }; + assert!(msg.contains("missing `command`")); + } + + #[test] + fn test_to_mcp_server_config_missing_env_var_errors() { + let entry = McpTomlEntry { + command: Some("tool".into()), + args: vec!["${REQUIRED}".into()], + ..Default::default() + }; + let msg = match to_mcp_server_config(&entry, |_k| None) { + Ok(v) => panic!("expected error, got {v:?}"), + Err(e) => format!("{e:#}"), + }; + assert!(msg.contains("environment variable `REQUIRED` not set")); + } +} + +// ------------------------------- +// Overlay loader +// ------------------------------- + +/// Load `.mcp.local.toml` and `.mcp.toml` from `project_root` if they exist. +/// +/// Returns the successfully parsed overlays in precedence order: Local then Project. +/// Invalid TOML is logged and skipped. +pub fn load_project_overlays(project_root: &Path) -> Result> { + let mut overlays = Vec::new(); + + let local_path = project_root.join(".mcp.local.toml"); + if local_path.exists() { + match std::fs::read_to_string(&local_path) { + Ok(contents) => match toml::from_str::(&contents) { + Ok(parsed) => overlays.push((Scope::Local, parsed)), + Err(e) => tracing::warn!("Failed to parse {}: {e}", local_path.display()), + }, + Err(e) => tracing::warn!("Failed to read {}: {e}", local_path.display()), + } + } + + let project_path = project_root.join(".mcp.toml"); + if project_path.exists() { + match std::fs::read_to_string(&project_path) { + Ok(contents) => match toml::from_str::(&contents) { + Ok(parsed) => overlays.push((Scope::Project, parsed)), + Err(e) => tracing::warn!("Failed to parse {}: {e}", project_path.display()), + }, + Err(e) => tracing::warn!("Failed to read {}: {e}", project_path.display()), + } + } + + Ok(overlays) +} + +#[cfg(test)] +mod overlay_tests { + use super::*; + use std::fs; + + #[test] + fn test_load_project_overlays_reads_both_files() -> Result<()> { + let dir = tempfile::tempdir()?; + let root = dir.path(); + // Pretend it's a git repo to mirror typical layout; not required by loader. + fs::write(root.join(".git"), "gitdir: nowhere")?; + + // Write project overlay + fs::write( + root.join(".mcp.toml"), + r#"[mcp_servers.alpha] +command = "alpha" +"#, + )?; + + // Write local overlay + fs::write( + root.join(".mcp.local.toml"), + r#"[mcp_servers.beta] +command = "beta" +"#, + )?; + + let overlays = load_project_overlays(root)?; + assert_eq!(overlays.len(), 2); + + // Expect Local first, then Project (our precedence order for merging later) + assert!(matches!(overlays[0].0, Scope::Local)); + assert!(overlays[0].1.mcp_servers.contains_key("beta")); + assert!(matches!(overlays[1].0, Scope::Project)); + assert!(overlays[1].1.mcp_servers.contains_key("alpha")); + Ok(()) + } +} diff --git a/codex-rs/core/tests/suite/mcp_overlays.rs b/codex-rs/core/tests/suite/mcp_overlays.rs new file mode 100644 index 0000000000..535d8063c9 --- /dev/null +++ b/codex-rs/core/tests/suite/mcp_overlays.rs @@ -0,0 +1,57 @@ +use std::fs; +use std::path::PathBuf; + +use codex_core::config::Config; +use codex_core::config::ConfigOverrides; + +fn write(path: impl Into, contents: &str) { + let p: PathBuf = path.into(); + fs::write(&p, contents).unwrap_or_else(|e| panic!("failed writing {}: {e}", p.display())); +} + +#[test] +fn test_overlay_precedence_local_over_project_over_user() -> std::io::Result<()> { + // Set up a fake CODEX_HOME with a user-level MCP server. + let codex_home = tempfile::tempdir()?; + std::env::set_var("CODEX_HOME", codex_home.path()); + // Ensure directory exists before canonicalization in find_codex_home(). + let config_toml_path = codex_home.path().join("config.toml"); + write(&config_toml_path, r#"[mcp_servers.svc] +command = "user" +"#); + + // Set up a project directory with overlays. + let project_dir = tempfile::tempdir()?; + // Mark as git repo root (enough for resolve_root_git_project_for_trust()). + write(project_dir.path().join(".git"), "gitdir: nowhere"); + + // Project overlay defines the same server name. + write( + project_dir.path().join(".mcp.toml"), + r#"[mcp_servers.svc] +command = "project" +"#, + ); + // Local overlay should take precedence. + write( + project_dir.path().join(".mcp.local.toml"), + r#"[mcp_servers.svc] +command = "local" +"#, + ); + + let overrides = ConfigOverrides { + cwd: Some(project_dir.path().to_path_buf()), + ..Default::default() + }; + + let cfg = Config::load_with_cli_overrides(vec![], overrides)?; + let svc = cfg + .mcp_servers + .get("svc") + .expect("svc should be present after merge"); + assert_eq!(svc.command, "local"); + + Ok(()) +} +