feat: support mcp in-session login (#7751)

### Summary
* Added `mcpServer/oauthLogin` in app server for supporting in session
MCP server login
* Added `McpServerOauthLoginParams` and `McpServerOauthLoginResponse` to
support above method with response returning the auth URL for consumer
to open browser or display accordingly.
* Added `McpServerOauthLoginCompletedNotification` which the app server
would emit on MCP server login success or failure (i.e. timeout).
* Refactored rmcp-client oath_login to have the ability on starting a
auth server which the codex_message_processor uses for in-session auth.
This commit is contained in:
Shijie Rao
2025-12-09 17:43:53 -08:00
committed by GitHub
parent fa4cac1e6b
commit 893f5261eb
8 changed files with 392 additions and 59 deletions

View File

@@ -16,7 +16,9 @@ pub use oauth::WrappedOAuthTokenResponse;
pub use oauth::delete_oauth_tokens;
pub(crate) use oauth::load_oauth_tokens;
pub use oauth::save_oauth_tokens;
pub use perform_oauth_login::OauthLoginHandle;
pub use perform_oauth_login::perform_oauth_login;
pub use perform_oauth_login::perform_oauth_login_return_url;
pub use rmcp::model::ElicitationAction;
pub use rmcp_client::Elicitation;
pub use rmcp_client::ElicitationResponse;

View File

@@ -22,6 +22,11 @@ use crate::save_oauth_tokens;
use crate::utils::apply_default_headers;
use crate::utils::build_default_headers;
struct OauthHeaders {
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
}
struct CallbackServerGuard {
server: Arc<Server>,
}
@@ -40,70 +45,52 @@ pub async fn perform_oauth_login(
env_http_headers: Option<HashMap<String, String>>,
scopes: &[String],
) -> Result<()> {
let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?);
let guard = CallbackServerGuard {
server: Arc::clone(&server),
let headers = OauthHeaders {
http_headers,
env_http_headers,
};
OauthLoginFlow::new(
server_name,
server_url,
store_mode,
headers,
scopes,
true,
None,
)
.await?
.finish()
.await
}
let redirect_uri = match server.server_addr() {
tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => {
format!("http://{}:{}/callback", addr.ip(), addr.port())
}
tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => {
format!("http://[{}]:{}/callback", addr.ip(), addr.port())
}
#[cfg(not(target_os = "windows"))]
_ => return Err(anyhow!("unable to determine callback address")),
pub async fn perform_oauth_login_return_url(
server_name: &str,
server_url: &str,
store_mode: OAuthCredentialsStoreMode,
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
scopes: &[String],
timeout_secs: Option<i64>,
) -> Result<OauthLoginHandle> {
let headers = OauthHeaders {
http_headers,
env_http_headers,
};
let flow = OauthLoginFlow::new(
server_name,
server_url,
store_mode,
headers,
scopes,
false,
timeout_secs,
)
.await?;
let (tx, rx) = oneshot::channel();
spawn_callback_server(server, tx);
let authorization_url = flow.authorization_url();
let completion = flow.spawn();
let default_headers = build_default_headers(http_headers, env_http_headers)?;
let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?;
let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?;
let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect();
oauth_state
.start_authorization(&scope_refs, &redirect_uri, Some("Codex"))
.await?;
let auth_url = oauth_state.get_authorization_url().await?;
println!("Authorize `{server_name}` by opening this URL in your browser:\n{auth_url}\n");
if webbrowser::open(&auth_url).is_err() {
println!("(Browser launch failed; please copy the URL above manually.)");
}
let (code, csrf_state) = timeout(Duration::from_secs(300), rx)
.await
.context("timed out waiting for OAuth callback")?
.context("OAuth callback was cancelled")?;
oauth_state
.handle_callback(&code, &csrf_state)
.await
.context("failed to handle OAuth callback")?;
let (client_id, credentials_opt) = oauth_state
.get_credentials()
.await
.context("failed to retrieve OAuth credentials")?;
let credentials =
credentials_opt.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?;
let expires_at = compute_expires_at_millis(&credentials);
let stored = StoredOAuthTokens {
server_name: server_name.to_string(),
url: server_url.to_string(),
client_id,
token_response: WrappedOAuthTokenResponse(credentials),
expires_at,
};
save_oauth_tokens(server_name, &stored, store_mode)?;
drop(guard);
Ok(())
Ok(OauthLoginHandle::new(authorization_url, completion))
}
fn spawn_callback_server(server: Arc<Server>, tx: oneshot::Sender<(String, String)>) {
@@ -160,3 +147,181 @@ fn parse_oauth_callback(path: &str) -> Option<OauthCallbackResult> {
state: state?,
})
}
pub struct OauthLoginHandle {
authorization_url: String,
completion: oneshot::Receiver<Result<()>>,
}
impl OauthLoginHandle {
fn new(authorization_url: String, completion: oneshot::Receiver<Result<()>>) -> Self {
Self {
authorization_url,
completion,
}
}
pub fn authorization_url(&self) -> &str {
&self.authorization_url
}
pub fn into_parts(self) -> (String, oneshot::Receiver<Result<()>>) {
(self.authorization_url, self.completion)
}
pub async fn wait(self) -> Result<()> {
self.completion
.await
.map_err(|err| anyhow!("OAuth login task was cancelled: {err}"))?
}
}
struct OauthLoginFlow {
auth_url: String,
oauth_state: OAuthState,
rx: oneshot::Receiver<(String, String)>,
guard: CallbackServerGuard,
server_name: String,
server_url: String,
store_mode: OAuthCredentialsStoreMode,
launch_browser: bool,
timeout: Duration,
}
impl OauthLoginFlow {
async fn new(
server_name: &str,
server_url: &str,
store_mode: OAuthCredentialsStoreMode,
headers: OauthHeaders,
scopes: &[String],
launch_browser: bool,
timeout_secs: Option<i64>,
) -> Result<Self> {
const DEFAULT_OAUTH_TIMEOUT_SECS: i64 = 300;
let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?);
let guard = CallbackServerGuard {
server: Arc::clone(&server),
};
let redirect_uri = match server.server_addr() {
tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => {
let ip = addr.ip();
let port = addr.port();
format!("http://{ip}:{port}/callback")
}
tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => {
let ip = addr.ip();
let port = addr.port();
format!("http://[{ip}]:{port}/callback")
}
#[cfg(not(target_os = "windows"))]
_ => return Err(anyhow!("unable to determine callback address")),
};
let (tx, rx) = oneshot::channel();
spawn_callback_server(server, tx);
let OauthHeaders {
http_headers,
env_http_headers,
} = headers;
let default_headers = build_default_headers(http_headers, env_http_headers)?;
let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?;
let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?;
let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect();
oauth_state
.start_authorization(&scope_refs, &redirect_uri, Some("Codex"))
.await?;
let auth_url = oauth_state.get_authorization_url().await?;
let timeout_secs = timeout_secs.unwrap_or(DEFAULT_OAUTH_TIMEOUT_SECS).max(1);
let timeout = Duration::from_secs(timeout_secs as u64);
Ok(Self {
auth_url,
oauth_state,
rx,
guard,
server_name: server_name.to_string(),
server_url: server_url.to_string(),
store_mode,
launch_browser,
timeout,
})
}
fn authorization_url(&self) -> String {
self.auth_url.clone()
}
async fn finish(mut self) -> Result<()> {
if self.launch_browser {
let server_name = &self.server_name;
let auth_url = &self.auth_url;
println!(
"Authorize `{server_name}` by opening this URL in your browser:\n{auth_url}\n"
);
if webbrowser::open(auth_url).is_err() {
println!("(Browser launch failed; please copy the URL above manually.)");
}
}
let result = async {
let (code, csrf_state) = timeout(self.timeout, &mut self.rx)
.await
.context("timed out waiting for OAuth callback")?
.context("OAuth callback was cancelled")?;
self.oauth_state
.handle_callback(&code, &csrf_state)
.await
.context("failed to handle OAuth callback")?;
let (client_id, credentials_opt) = self
.oauth_state
.get_credentials()
.await
.context("failed to retrieve OAuth credentials")?;
let credentials = credentials_opt
.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?;
let expires_at = compute_expires_at_millis(&credentials);
let stored = StoredOAuthTokens {
server_name: self.server_name.clone(),
url: self.server_url.clone(),
client_id,
token_response: WrappedOAuthTokenResponse(credentials),
expires_at,
};
save_oauth_tokens(&self.server_name, &stored, self.store_mode)?;
Ok(())
}
.await;
drop(self.guard);
result
}
fn spawn(self) -> oneshot::Receiver<Result<()>> {
let server_name_for_logging = self.server_name.clone();
let (tx, rx) = oneshot::channel();
tokio::spawn(async move {
let result = self.finish().await;
if let Err(err) = &result {
eprintln!(
"Failed to complete OAuth login for '{server_name_for_logging}': {err:#}"
);
}
let _ = tx.send(result);
});
rx
}
}