mirror of
https://github.com/openai/codex.git
synced 2026-04-28 02:11:08 +03:00
feat: support mcp in-session login (#7751)
### Summary * Added `mcpServer/oauthLogin` in app server for supporting in session MCP server login * Added `McpServerOauthLoginParams` and `McpServerOauthLoginResponse` to support above method with response returning the auth URL for consumer to open browser or display accordingly. * Added `McpServerOauthLoginCompletedNotification` which the app server would emit on MCP server login success or failure (i.e. timeout). * Refactored rmcp-client oath_login to have the ability on starting a auth server which the codex_message_processor uses for in-session auth.
This commit is contained in:
@@ -16,7 +16,9 @@ pub use oauth::WrappedOAuthTokenResponse;
|
||||
pub use oauth::delete_oauth_tokens;
|
||||
pub(crate) use oauth::load_oauth_tokens;
|
||||
pub use oauth::save_oauth_tokens;
|
||||
pub use perform_oauth_login::OauthLoginHandle;
|
||||
pub use perform_oauth_login::perform_oauth_login;
|
||||
pub use perform_oauth_login::perform_oauth_login_return_url;
|
||||
pub use rmcp::model::ElicitationAction;
|
||||
pub use rmcp_client::Elicitation;
|
||||
pub use rmcp_client::ElicitationResponse;
|
||||
|
||||
@@ -22,6 +22,11 @@ use crate::save_oauth_tokens;
|
||||
use crate::utils::apply_default_headers;
|
||||
use crate::utils::build_default_headers;
|
||||
|
||||
struct OauthHeaders {
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
struct CallbackServerGuard {
|
||||
server: Arc<Server>,
|
||||
}
|
||||
@@ -40,70 +45,52 @@ pub async fn perform_oauth_login(
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
scopes: &[String],
|
||||
) -> Result<()> {
|
||||
let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?);
|
||||
let guard = CallbackServerGuard {
|
||||
server: Arc::clone(&server),
|
||||
let headers = OauthHeaders {
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
};
|
||||
OauthLoginFlow::new(
|
||||
server_name,
|
||||
server_url,
|
||||
store_mode,
|
||||
headers,
|
||||
scopes,
|
||||
true,
|
||||
None,
|
||||
)
|
||||
.await?
|
||||
.finish()
|
||||
.await
|
||||
}
|
||||
|
||||
let redirect_uri = match server.server_addr() {
|
||||
tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => {
|
||||
format!("http://{}:{}/callback", addr.ip(), addr.port())
|
||||
}
|
||||
tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => {
|
||||
format!("http://[{}]:{}/callback", addr.ip(), addr.port())
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
_ => return Err(anyhow!("unable to determine callback address")),
|
||||
pub async fn perform_oauth_login_return_url(
|
||||
server_name: &str,
|
||||
server_url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
scopes: &[String],
|
||||
timeout_secs: Option<i64>,
|
||||
) -> Result<OauthLoginHandle> {
|
||||
let headers = OauthHeaders {
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
};
|
||||
let flow = OauthLoginFlow::new(
|
||||
server_name,
|
||||
server_url,
|
||||
store_mode,
|
||||
headers,
|
||||
scopes,
|
||||
false,
|
||||
timeout_secs,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let (tx, rx) = oneshot::channel();
|
||||
spawn_callback_server(server, tx);
|
||||
let authorization_url = flow.authorization_url();
|
||||
let completion = flow.spawn();
|
||||
|
||||
let default_headers = build_default_headers(http_headers, env_http_headers)?;
|
||||
let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?;
|
||||
|
||||
let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?;
|
||||
let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect();
|
||||
oauth_state
|
||||
.start_authorization(&scope_refs, &redirect_uri, Some("Codex"))
|
||||
.await?;
|
||||
let auth_url = oauth_state.get_authorization_url().await?;
|
||||
|
||||
println!("Authorize `{server_name}` by opening this URL in your browser:\n{auth_url}\n");
|
||||
|
||||
if webbrowser::open(&auth_url).is_err() {
|
||||
println!("(Browser launch failed; please copy the URL above manually.)");
|
||||
}
|
||||
|
||||
let (code, csrf_state) = timeout(Duration::from_secs(300), rx)
|
||||
.await
|
||||
.context("timed out waiting for OAuth callback")?
|
||||
.context("OAuth callback was cancelled")?;
|
||||
|
||||
oauth_state
|
||||
.handle_callback(&code, &csrf_state)
|
||||
.await
|
||||
.context("failed to handle OAuth callback")?;
|
||||
|
||||
let (client_id, credentials_opt) = oauth_state
|
||||
.get_credentials()
|
||||
.await
|
||||
.context("failed to retrieve OAuth credentials")?;
|
||||
let credentials =
|
||||
credentials_opt.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?;
|
||||
|
||||
let expires_at = compute_expires_at_millis(&credentials);
|
||||
let stored = StoredOAuthTokens {
|
||||
server_name: server_name.to_string(),
|
||||
url: server_url.to_string(),
|
||||
client_id,
|
||||
token_response: WrappedOAuthTokenResponse(credentials),
|
||||
expires_at,
|
||||
};
|
||||
save_oauth_tokens(server_name, &stored, store_mode)?;
|
||||
|
||||
drop(guard);
|
||||
Ok(())
|
||||
Ok(OauthLoginHandle::new(authorization_url, completion))
|
||||
}
|
||||
|
||||
fn spawn_callback_server(server: Arc<Server>, tx: oneshot::Sender<(String, String)>) {
|
||||
@@ -160,3 +147,181 @@ fn parse_oauth_callback(path: &str) -> Option<OauthCallbackResult> {
|
||||
state: state?,
|
||||
})
|
||||
}
|
||||
|
||||
pub struct OauthLoginHandle {
|
||||
authorization_url: String,
|
||||
completion: oneshot::Receiver<Result<()>>,
|
||||
}
|
||||
|
||||
impl OauthLoginHandle {
|
||||
fn new(authorization_url: String, completion: oneshot::Receiver<Result<()>>) -> Self {
|
||||
Self {
|
||||
authorization_url,
|
||||
completion,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn authorization_url(&self) -> &str {
|
||||
&self.authorization_url
|
||||
}
|
||||
|
||||
pub fn into_parts(self) -> (String, oneshot::Receiver<Result<()>>) {
|
||||
(self.authorization_url, self.completion)
|
||||
}
|
||||
|
||||
pub async fn wait(self) -> Result<()> {
|
||||
self.completion
|
||||
.await
|
||||
.map_err(|err| anyhow!("OAuth login task was cancelled: {err}"))?
|
||||
}
|
||||
}
|
||||
|
||||
struct OauthLoginFlow {
|
||||
auth_url: String,
|
||||
oauth_state: OAuthState,
|
||||
rx: oneshot::Receiver<(String, String)>,
|
||||
guard: CallbackServerGuard,
|
||||
server_name: String,
|
||||
server_url: String,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
launch_browser: bool,
|
||||
timeout: Duration,
|
||||
}
|
||||
|
||||
impl OauthLoginFlow {
|
||||
async fn new(
|
||||
server_name: &str,
|
||||
server_url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
headers: OauthHeaders,
|
||||
scopes: &[String],
|
||||
launch_browser: bool,
|
||||
timeout_secs: Option<i64>,
|
||||
) -> Result<Self> {
|
||||
const DEFAULT_OAUTH_TIMEOUT_SECS: i64 = 300;
|
||||
|
||||
let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?);
|
||||
let guard = CallbackServerGuard {
|
||||
server: Arc::clone(&server),
|
||||
};
|
||||
|
||||
let redirect_uri = match server.server_addr() {
|
||||
tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => {
|
||||
let ip = addr.ip();
|
||||
let port = addr.port();
|
||||
format!("http://{ip}:{port}/callback")
|
||||
}
|
||||
tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => {
|
||||
let ip = addr.ip();
|
||||
let port = addr.port();
|
||||
format!("http://[{ip}]:{port}/callback")
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
_ => return Err(anyhow!("unable to determine callback address")),
|
||||
};
|
||||
|
||||
let (tx, rx) = oneshot::channel();
|
||||
spawn_callback_server(server, tx);
|
||||
|
||||
let OauthHeaders {
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
} = headers;
|
||||
let default_headers = build_default_headers(http_headers, env_http_headers)?;
|
||||
let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?;
|
||||
|
||||
let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?;
|
||||
let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect();
|
||||
oauth_state
|
||||
.start_authorization(&scope_refs, &redirect_uri, Some("Codex"))
|
||||
.await?;
|
||||
let auth_url = oauth_state.get_authorization_url().await?;
|
||||
let timeout_secs = timeout_secs.unwrap_or(DEFAULT_OAUTH_TIMEOUT_SECS).max(1);
|
||||
let timeout = Duration::from_secs(timeout_secs as u64);
|
||||
|
||||
Ok(Self {
|
||||
auth_url,
|
||||
oauth_state,
|
||||
rx,
|
||||
guard,
|
||||
server_name: server_name.to_string(),
|
||||
server_url: server_url.to_string(),
|
||||
store_mode,
|
||||
launch_browser,
|
||||
timeout,
|
||||
})
|
||||
}
|
||||
|
||||
fn authorization_url(&self) -> String {
|
||||
self.auth_url.clone()
|
||||
}
|
||||
|
||||
async fn finish(mut self) -> Result<()> {
|
||||
if self.launch_browser {
|
||||
let server_name = &self.server_name;
|
||||
let auth_url = &self.auth_url;
|
||||
println!(
|
||||
"Authorize `{server_name}` by opening this URL in your browser:\n{auth_url}\n"
|
||||
);
|
||||
|
||||
if webbrowser::open(auth_url).is_err() {
|
||||
println!("(Browser launch failed; please copy the URL above manually.)");
|
||||
}
|
||||
}
|
||||
|
||||
let result = async {
|
||||
let (code, csrf_state) = timeout(self.timeout, &mut self.rx)
|
||||
.await
|
||||
.context("timed out waiting for OAuth callback")?
|
||||
.context("OAuth callback was cancelled")?;
|
||||
|
||||
self.oauth_state
|
||||
.handle_callback(&code, &csrf_state)
|
||||
.await
|
||||
.context("failed to handle OAuth callback")?;
|
||||
|
||||
let (client_id, credentials_opt) = self
|
||||
.oauth_state
|
||||
.get_credentials()
|
||||
.await
|
||||
.context("failed to retrieve OAuth credentials")?;
|
||||
let credentials = credentials_opt
|
||||
.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?;
|
||||
|
||||
let expires_at = compute_expires_at_millis(&credentials);
|
||||
let stored = StoredOAuthTokens {
|
||||
server_name: self.server_name.clone(),
|
||||
url: self.server_url.clone(),
|
||||
client_id,
|
||||
token_response: WrappedOAuthTokenResponse(credentials),
|
||||
expires_at,
|
||||
};
|
||||
save_oauth_tokens(&self.server_name, &stored, self.store_mode)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
.await;
|
||||
|
||||
drop(self.guard);
|
||||
result
|
||||
}
|
||||
|
||||
fn spawn(self) -> oneshot::Receiver<Result<()>> {
|
||||
let server_name_for_logging = self.server_name.clone();
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let result = self.finish().await;
|
||||
|
||||
if let Err(err) = &result {
|
||||
eprintln!(
|
||||
"Failed to complete OAuth login for '{server_name_for_logging}': {err:#}"
|
||||
);
|
||||
}
|
||||
|
||||
let _ = tx.send(result);
|
||||
});
|
||||
|
||||
rx
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user