mirror of
https://github.com/openai/codex.git
synced 2026-04-30 11:21:34 +03:00
Move config types into a separate crate because their macros expand into a lot of new code.
374 lines
11 KiB
Rust
374 lines
11 KiB
Rust
use std::collections::HashMap;
|
|
use std::time::Duration;
|
|
|
|
use anyhow::Error;
|
|
use anyhow::Result;
|
|
use codex_protocol::protocol::McpAuthStatus;
|
|
use reqwest::Client;
|
|
use reqwest::StatusCode;
|
|
use reqwest::Url;
|
|
use reqwest::header::AUTHORIZATION;
|
|
use reqwest::header::HeaderMap;
|
|
use serde::Deserialize;
|
|
use tracing::debug;
|
|
|
|
use crate::oauth::has_oauth_tokens;
|
|
use crate::utils::apply_default_headers;
|
|
use crate::utils::build_default_headers;
|
|
use codex_config::types::OAuthCredentialsStoreMode;
|
|
|
|
/// Timeout configured on the HTTP client used for OAuth discovery requests.
const DISCOVERY_TIMEOUT: Duration = Duration::from_secs(5);

/// Name of the header attached to every OAuth discovery request.
const OAUTH_DISCOVERY_HEADER: &str = "MCP-Protocol-Version";

/// Value sent in the [`OAUTH_DISCOVERY_HEADER`] header.
const OAUTH_DISCOVERY_VERSION: &str = "2024-11-05";
|
|
|
|
/// Result of a successful OAuth discovery against a streamable HTTP MCP server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamableHttpOAuthDiscovery {
    /// Scopes taken from the server's discovery metadata after normalization
    /// (trimmed, blank entries dropped, duplicates removed). `None` when the
    /// metadata carried no scopes or nothing survived normalization.
    pub scopes_supported: Option<Vec<String>>,
}
|
|
|
|
/// Determine the authentication status for a streamable HTTP MCP server.
|
|
pub async fn determine_streamable_http_auth_status(
|
|
server_name: &str,
|
|
url: &str,
|
|
bearer_token_env_var: Option<&str>,
|
|
http_headers: Option<HashMap<String, String>>,
|
|
env_http_headers: Option<HashMap<String, String>>,
|
|
store_mode: OAuthCredentialsStoreMode,
|
|
) -> Result<McpAuthStatus> {
|
|
if bearer_token_env_var.is_some() {
|
|
return Ok(McpAuthStatus::BearerToken);
|
|
}
|
|
|
|
let default_headers = build_default_headers(http_headers, env_http_headers)?;
|
|
if default_headers.contains_key(AUTHORIZATION) {
|
|
return Ok(McpAuthStatus::BearerToken);
|
|
}
|
|
|
|
if has_oauth_tokens(server_name, url, store_mode)? {
|
|
return Ok(McpAuthStatus::OAuth);
|
|
}
|
|
|
|
match discover_streamable_http_oauth_with_headers(url, &default_headers).await {
|
|
Ok(Some(_)) => Ok(McpAuthStatus::NotLoggedIn),
|
|
Ok(None) => Ok(McpAuthStatus::Unsupported),
|
|
Err(error) => {
|
|
debug!(
|
|
"failed to detect OAuth support for MCP server `{server_name}` at {url}: {error:?}"
|
|
);
|
|
Ok(McpAuthStatus::Unsupported)
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Attempt to determine whether a streamable HTTP MCP server advertises OAuth login.
|
|
pub async fn supports_oauth_login(url: &str) -> Result<bool> {
|
|
Ok(discover_streamable_http_oauth(
|
|
url, /*http_headers*/ None, /*env_http_headers*/ None,
|
|
)
|
|
.await?
|
|
.is_some())
|
|
}
|
|
|
|
pub async fn discover_streamable_http_oauth(
|
|
url: &str,
|
|
http_headers: Option<HashMap<String, String>>,
|
|
env_http_headers: Option<HashMap<String, String>>,
|
|
) -> Result<Option<StreamableHttpOAuthDiscovery>> {
|
|
let default_headers = build_default_headers(http_headers, env_http_headers)?;
|
|
discover_streamable_http_oauth_with_headers(url, &default_headers).await
|
|
}
|
|
|
|
async fn discover_streamable_http_oauth_with_headers(
|
|
url: &str,
|
|
default_headers: &HeaderMap,
|
|
) -> Result<Option<StreamableHttpOAuthDiscovery>> {
|
|
let base_url = Url::parse(url)?;
|
|
|
|
// Use no_proxy to avoid a bug in the system-configuration crate that
|
|
// can result in a panic. See #8912.
|
|
let builder = Client::builder().timeout(DISCOVERY_TIMEOUT).no_proxy();
|
|
let client = apply_default_headers(builder, default_headers).build()?;
|
|
|
|
let mut last_error: Option<Error> = None;
|
|
for candidate_path in discovery_paths(base_url.path()) {
|
|
let mut discovery_url = base_url.clone();
|
|
discovery_url.set_path(&candidate_path);
|
|
|
|
let response = match client
|
|
.get(discovery_url.clone())
|
|
.header(OAUTH_DISCOVERY_HEADER, OAUTH_DISCOVERY_VERSION)
|
|
.send()
|
|
.await
|
|
{
|
|
Ok(response) => response,
|
|
Err(err) => {
|
|
last_error = Some(err.into());
|
|
continue;
|
|
}
|
|
};
|
|
|
|
if response.status() != StatusCode::OK {
|
|
continue;
|
|
}
|
|
|
|
let metadata = match response.json::<OAuthDiscoveryMetadata>().await {
|
|
Ok(metadata) => metadata,
|
|
Err(err) => {
|
|
last_error = Some(err.into());
|
|
continue;
|
|
}
|
|
};
|
|
|
|
if metadata.authorization_endpoint.is_some() && metadata.token_endpoint.is_some() {
|
|
return Ok(Some(StreamableHttpOAuthDiscovery {
|
|
scopes_supported: normalize_scopes(metadata.scopes_supported),
|
|
}));
|
|
}
|
|
}
|
|
|
|
if let Some(err) = last_error {
|
|
debug!("OAuth discovery requests failed for {url}: {err:?}");
|
|
}
|
|
|
|
Ok(None)
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct OAuthDiscoveryMetadata {
|
|
#[serde(default)]
|
|
authorization_endpoint: Option<String>,
|
|
#[serde(default)]
|
|
token_endpoint: Option<String>,
|
|
#[serde(default)]
|
|
scopes_supported: Option<Vec<String>>,
|
|
}
|
|
|
|
/// Clean up a raw `scopes_supported` list.
///
/// Each entry is trimmed; blank entries are dropped and duplicates are
/// removed while keeping first-occurrence order. Returns `None` when the
/// input is `None` or when no entry survives.
fn normalize_scopes(scopes_supported: Option<Vec<String>>) -> Option<Vec<String>> {
    let mut unique: Vec<String> = Vec::new();
    for raw in scopes_supported? {
        let cleaned = raw.trim();
        let already_seen = unique.iter().any(|known| known.as_str() == cleaned);
        if !cleaned.is_empty() && !already_seen {
            unique.push(cleaned.to_string());
        }
    }

    (!unique.is_empty()).then_some(unique)
}
|
|
|
|
/// Implements RFC 8414 section 3.1 for discovering well-known oauth endpoints.
/// This is a requirement for MCP servers to support OAuth.
/// <https://datatracker.ietf.org/doc/html/rfc8414#section-3.1>
/// <https://github.com/modelcontextprotocol/rust-sdk/blob/main/crates/rmcp/src/transport/auth.rs#L182>
///
/// `base_path` is stripped of leading and trailing slashes. When nothing is
/// left, only the canonical well-known path is returned. Otherwise the
/// candidates are, in order and without duplicates:
///
/// 1. `/.well-known/oauth-authorization-server/<path>`
/// 2. `/<path>/.well-known/oauth-authorization-server`
/// 3. `/.well-known/oauth-authorization-server`
fn discovery_paths(base_path: &str) -> Vec<String> {
    const WELL_KNOWN: &str = "/.well-known/oauth-authorization-server";

    let segment = base_path.trim_matches('/');
    if segment.is_empty() {
        return vec![WELL_KNOWN.to_string()];
    }

    let ordered = [
        format!("{WELL_KNOWN}/{segment}"),
        format!("/{segment}{WELL_KNOWN}"),
        WELL_KNOWN.to_string(),
    ];

    // Some paths make two of the candidates coincide, so filter repeats
    // while keeping the order above.
    let mut paths: Vec<String> = Vec::with_capacity(ordered.len());
    for path in ordered {
        if !paths.contains(&path) {
            paths.push(path);
        }
    }
    paths
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Json;
    use axum::Router;
    use axum::routing::get;
    use pretty_assertions::assert_eq;
    use serial_test::serial;
    use std::collections::HashMap;
    use std::ffi::OsString;
    use tokio::task::JoinHandle;

    /// Local HTTP server used by the discovery tests; its task is aborted
    /// when the value is dropped.
    struct TestServer {
        url: String,
        handle: JoinHandle<()>,
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.handle.abort();
        }
    }

    /// Start a server on an ephemeral localhost port that answers
    /// `GET /.well-known/oauth-authorization-server/mcp` with `metadata` as
    /// JSON. The returned `url` points at `/mcp` on that server.
    async fn spawn_oauth_discovery_server(metadata: serde_json::Value) -> TestServer {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("listener should bind");
        let address = listener.local_addr().expect("listener should have address");

        // Every request gets its own copy of the metadata document.
        let serve_metadata = move || {
            let body = metadata.clone();
            async move { Json(body) }
        };
        let app = Router::new().route(
            "/.well-known/oauth-authorization-server/mcp",
            get(serve_metadata),
        );

        let handle = tokio::spawn(async move {
            axum::serve(listener, app).await.expect("server should run");
        });

        TestServer {
            url: format!("http://{address}/mcp"),
            handle,
        }
    }

    /// Sets an environment variable and restores its previous state
    /// (old value, or unset) on drop.
    struct EnvVarGuard {
        key: String,
        original: Option<OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &str, value: &str) -> Self {
            let original = std::env::var_os(key);
            // NOTE(review): mutating the process environment is only sound
            // when nothing else accesses it concurrently. The test using
            // this guard is `#[serial(auth_status_env)]`, which serializes
            // that group only; confirm no other test touches the
            // environment in parallel.
            unsafe {
                std::env::set_var(key, value);
            }
            Self {
                key: key.to_string(),
                original,
            }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // NOTE(review): same concurrency assumption as in `set`.
            match &self.original {
                Some(value) => unsafe {
                    std::env::set_var(&self.key, value);
                },
                None => unsafe {
                    std::env::remove_var(&self.key);
                },
            }
        }
    }

    #[tokio::test]
    async fn determine_auth_status_uses_bearer_token_when_authorization_header_present() {
        let http_headers = HashMap::from([(
            "Authorization".to_string(),
            "Bearer token".to_string(),
        )]);

        // The URL is deliberately invalid: the header check must short-circuit
        // before any discovery request is attempted.
        let status = determine_streamable_http_auth_status(
            "server",
            "not-a-url",
            /*bearer_token_env_var*/ None,
            Some(http_headers),
            /*env_http_headers*/ None,
            OAuthCredentialsStoreMode::Keyring,
        )
        .await
        .expect("status should compute");

        assert_eq!(status, McpAuthStatus::BearerToken);
    }

    #[tokio::test]
    #[serial(auth_status_env)]
    async fn determine_auth_status_uses_bearer_token_when_env_authorization_header_present() {
        let _guard = EnvVarGuard::set("CODEX_RMCP_CLIENT_AUTH_STATUS_TEST_TOKEN", "Bearer token");
        let env_http_headers = HashMap::from([(
            "Authorization".to_string(),
            "CODEX_RMCP_CLIENT_AUTH_STATUS_TEST_TOKEN".to_string(),
        )]);

        let status = determine_streamable_http_auth_status(
            "server",
            "not-a-url",
            /*bearer_token_env_var*/ None,
            /*http_headers*/ None,
            Some(env_http_headers),
            OAuthCredentialsStoreMode::Keyring,
        )
        .await
        .expect("status should compute");

        assert_eq!(status, McpAuthStatus::BearerToken);
    }

    #[tokio::test]
    async fn discover_streamable_http_oauth_returns_normalized_scopes() {
        let metadata = serde_json::json!({
            "authorization_endpoint": "https://example.com/authorize",
            "token_endpoint": "https://example.com/token",
            "scopes_supported": ["profile", " email ", "profile", "", " "],
        });
        let server = spawn_oauth_discovery_server(metadata).await;

        let discovery = discover_streamable_http_oauth(
            &server.url,
            /*http_headers*/ None,
            /*env_http_headers*/ None,
        )
        .await
        .expect("discovery should succeed")
        .expect("oauth support should be detected");

        assert_eq!(
            discovery.scopes_supported,
            Some(vec!["profile".to_string(), "email".to_string()])
        );
    }

    #[tokio::test]
    async fn discover_streamable_http_oauth_ignores_empty_scopes() {
        let metadata = serde_json::json!({
            "authorization_endpoint": "https://example.com/authorize",
            "token_endpoint": "https://example.com/token",
            "scopes_supported": ["", " "],
        });
        let server = spawn_oauth_discovery_server(metadata).await;

        let discovery = discover_streamable_http_oauth(
            &server.url,
            /*http_headers*/ None,
            /*env_http_headers*/ None,
        )
        .await
        .expect("discovery should succeed")
        .expect("oauth support should be detected");

        assert_eq!(discovery.scopes_supported, None);
    }

    #[tokio::test]
    async fn supports_oauth_login_does_not_require_scopes_supported() {
        let metadata = serde_json::json!({
            "authorization_endpoint": "https://example.com/authorize",
            "token_endpoint": "https://example.com/token",
        });
        let server = spawn_oauth_discovery_server(metadata).await;

        let supported = supports_oauth_login(&server.url)
            .await
            .expect("support check should succeed");

        assert!(supported);
    }
}
|