Compare commits

...

3 Commits

Author SHA1 Message Date
nicholasclark-openai
1a1ab025e3 codex: scope MCP request headers to turn lifecycle
Set Codex Apps MCP request headers once per active turn and clear them on turn end,
instead of threading request-scoped headers through every tool call. Keep RMCP
header injection limited to streamable HTTP tools/call requests so list/init
paths stay unchanged and concurrent tool calls on the same client are not
serialized.

Co-authored-by: Codex <noreply@openai.com>
2026-03-18 16:48:57 -07:00
nicholasclark-openai
33f5447387 Add outbound HTTP header tracing logs
Co-authored-by: Codex <noreply@openai.com>
2026-03-17 19:37:38 -07:00
nicholasclark-openai
bccce0f2d8 Forward tool call task headers to MCP HTTP requests
Co-authored-by: Codex <noreply@openai.com>
2026-03-17 18:14:32 -07:00
8 changed files with 371 additions and 23 deletions

View File

@@ -5,6 +5,7 @@ use crate::types::RateLimitStatusPayload;
use crate::types::TurnAttemptsSiblingTurnsResponse;
use anyhow::Result;
use codex_client::build_reqwest_client_with_custom_ca;
use codex_client::log_http_request;
use codex_core::auth::CodexAuth;
use codex_core::default_client::get_codex_user_agent;
use codex_protocol::account::PlanType as AccountPlanType;
@@ -259,7 +260,9 @@ impl Client {
PathStyle::CodexApi => format!("{}/api/codex/usage", self.base_url),
PathStyle::ChatGptApi => format!("{}/wham/usage", self.base_url),
};
let req = self.http.get(&url).headers(self.headers());
let headers = self.headers();
log_http_request("GET", &url, &headers);
let req = self.http.get(&url).headers(headers);
let (body, ct) = self.exec_request(req, "GET", &url).await?;
let payload: RateLimitStatusPayload = self.decode_json(&url, &ct, &body)?;
Ok(Self::rate_limit_snapshots_from_payload(payload))
@@ -276,7 +279,9 @@ impl Client {
PathStyle::CodexApi => format!("{}/api/codex/tasks/list", self.base_url),
PathStyle::ChatGptApi => format!("{}/wham/tasks/list", self.base_url),
};
let req = self.http.get(&url).headers(self.headers());
let headers = self.headers();
log_http_request("GET", &url, &headers);
let req = self.http.get(&url).headers(headers);
let req = if let Some(lim) = limit {
req.query(&[("limit", lim)])
} else {
@@ -314,7 +319,9 @@ impl Client {
PathStyle::CodexApi => format!("{}/api/codex/tasks/{}", self.base_url, task_id),
PathStyle::ChatGptApi => format!("{}/wham/tasks/{}", self.base_url, task_id),
};
let req = self.http.get(&url).headers(self.headers());
let headers = self.headers();
log_http_request("GET", &url, &headers);
let req = self.http.get(&url).headers(headers);
let (body, ct) = self.exec_request(req, "GET", &url).await?;
let parsed: CodeTaskDetailsResponse = self.decode_json(&url, &ct, &body)?;
Ok((parsed, body, ct))
@@ -335,7 +342,9 @@ impl Client {
self.base_url, task_id, turn_id
),
};
let req = self.http.get(&url).headers(self.headers());
let headers = self.headers();
log_http_request("GET", &url, &headers);
let req = self.http.get(&url).headers(headers);
let (body, ct) = self.exec_request(req, "GET", &url).await?;
self.decode_json::<TurnAttemptsSiblingTurnsResponse>(&url, &ct, &body)
}
@@ -351,7 +360,9 @@ impl Client {
PathStyle::CodexApi => format!("{}/api/codex/config/requirements", self.base_url),
PathStyle::ChatGptApi => format!("{}/wham/config/requirements", self.base_url),
};
let req = self.http.get(&url).headers(self.headers());
let headers = self.headers();
log_http_request("GET", &url, &headers);
let req = self.http.get(&url).headers(headers);
let (body, ct) = self.exec_request_detailed(req, "GET", &url).await?;
self.decode_json::<ConfigFileResponse>(&url, &ct, &body)
.map_err(RequestError::from)
@@ -364,12 +375,10 @@ impl Client {
PathStyle::CodexApi => format!("{}/api/codex/tasks", self.base_url),
PathStyle::ChatGptApi => format!("{}/wham/tasks", self.base_url),
};
let req = self
.http
.post(&url)
.headers(self.headers())
.header(CONTENT_TYPE, HeaderValue::from_static("application/json"))
.json(&request_body);
let mut headers = self.headers();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
log_http_request("POST", &url, &headers);
let req = self.http.post(&url).headers(headers).json(&request_body);
let (body, ct) = self.exec_request(req, "POST", &url).await?;
// Extract id from JSON: prefer `task.id`; fallback to top-level `id` when present.
match serde_json::from_str::<serde_json::Value>(&body) {

View File

@@ -8,6 +8,7 @@ use reqwest::IntoUrl;
use reqwest::Method;
use reqwest::Response;
use serde::Serialize;
use std::collections::BTreeMap;
use std::fmt::Display;
use std::time::Duration;
use tracing::Span;
@@ -18,6 +19,11 @@ pub struct CodexHttpClient {
inner: reqwest::Client,
}
const CODEX_TRACE_HTTP_HEADERS_ENV: &str = "CODEX_TRACE_HTTP_HEADERS";
const CODEX_TRACE_HTTP_HEADERS_INCLUDE_SENSITIVE_ENV: &str =
"CODEX_TRACE_HTTP_HEADERS_INCLUDE_SENSITIVE";
const REDACTED_HEADER_VALUE: &str = "<redacted>";
impl CodexHttpClient {
pub fn new(inner: reqwest::Client) -> Self {
Self { inner }
@@ -111,9 +117,18 @@ impl CodexRequestBuilder {
}
pub async fn send(self) -> Result<Response, reqwest::Error> {
let headers = trace_headers();
let builder = self.builder.headers(trace_headers());
if let Some(request_builder) = builder.try_clone()
&& let Ok(request) = request_builder.build()
{
log_http_request(
self.method.as_str(),
request.url().as_str(),
request.headers(),
);
}
match self.builder.headers(headers).send().await {
match builder.send().await {
Ok(response) => {
tracing::debug!(
method = %self.method,
@@ -165,6 +180,63 @@ fn trace_headers() -> HeaderMap {
headers
}
pub fn log_http_request(method: &str, url: &str, headers: &HeaderMap) {
if !http_trace_headers_enabled() {
return;
}
tracing::info!(
method,
url,
headers = ?format_headers_for_log(headers),
"Outbound HTTP request"
);
}
pub fn format_headers_for_log(headers: &HeaderMap) -> BTreeMap<String, String> {
let include_sensitive = http_trace_include_sensitive_headers();
headers
.iter()
.map(|(name, value)| {
let name_str = name.as_str().to_ascii_lowercase();
let value_str = if include_sensitive || !is_sensitive_header(&name_str) {
value.to_str().unwrap_or("<binary>").to_string()
} else {
REDACTED_HEADER_VALUE.to_string()
};
(name_str, value_str)
})
.collect()
}
fn http_trace_headers_enabled() -> bool {
env_flag_enabled(CODEX_TRACE_HTTP_HEADERS_ENV)
}
fn http_trace_include_sensitive_headers() -> bool {
env_flag_enabled(CODEX_TRACE_HTTP_HEADERS_INCLUDE_SENSITIVE_ENV)
}
fn env_flag_enabled(name: &str) -> bool {
std::env::var(name)
.map(|value| {
let normalized = value.trim().to_ascii_lowercase();
!normalized.is_empty()
&& normalized != "0"
&& normalized != "false"
&& normalized != "no"
&& normalized != "off"
})
.unwrap_or(false)
}
fn is_sensitive_header(name: &str) -> bool {
matches!(
name,
"authorization" | "proxy-authorization" | "cookie" | "set-cookie" | "x-api-key" | "api-key"
)
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -19,6 +19,8 @@ pub use crate::custom_ca::build_reqwest_client_with_custom_ca;
pub use crate::custom_ca::maybe_build_rustls_client_config_with_custom_ca;
pub use crate::default_client::CodexHttpClient;
pub use crate::default_client::CodexRequestBuilder;
pub use crate::default_client::format_headers_for_log;
pub use crate::default_client::log_http_request;
pub use crate::error::StreamError;
pub use crate::error::TransportError;
pub use crate::request::Request;

View File

@@ -125,6 +125,8 @@ use futures::future::BoxFuture;
use futures::future::Shared;
use futures::prelude::*;
use futures::stream::FuturesOrdered;
use reqwest::header::HeaderMap;
use reqwest::header::HeaderValue;
use rmcp::model::ListResourceTemplatesResult;
use rmcp::model::ListResourcesResult;
use rmcp::model::PaginatedRequestParams;
@@ -3921,6 +3923,42 @@ impl Session {
.await
}
pub(crate) async fn sync_mcp_request_headers_for_turn(&self, turn_context: &TurnContext) {
let mut request_headers = HeaderMap::new();
let session_id = self.conversation_id.to_string();
if let Ok(value) = HeaderValue::from_str(&session_id) {
request_headers.insert("session_id", value.clone());
request_headers.insert("x-client-request-id", value);
}
if let Some(turn_metadata) = turn_context.turn_metadata_state.current_header_value()
&& let Ok(value) = HeaderValue::from_str(&turn_metadata)
{
request_headers.insert(crate::X_CODEX_TURN_METADATA_HEADER, value);
}
let request_headers = if request_headers.is_empty() {
None
} else {
Some(request_headers)
};
self.services
.mcp_connection_manager
.read()
.await
.set_request_headers_for_server(
crate::mcp::CODEX_APPS_MCP_SERVER_NAME,
request_headers,
);
}
pub(crate) async fn clear_mcp_request_headers(&self) {
self.services
.mcp_connection_manager
.read()
.await
.set_request_headers_for_server(crate::mcp::CODEX_APPS_MCP_SERVER_NAME, None);
}
pub(crate) async fn parse_mcp_tool_name(
&self,
name: &str,

View File

@@ -89,6 +89,7 @@ pub(crate) async fn run_codex_thread_interactive(
metrics_service_name: None,
inherited_shell_snapshot: None,
user_shell_override: None,
inherited_exec_policy: Some(Arc::clone(&parent_session.services.exec_policy)),
parent_trace: None,
})
.await?;

View File

@@ -423,6 +423,7 @@ impl ManagedClient {
#[derive(Clone)]
struct AsyncManagedClient {
client: Shared<BoxFuture<'static, Result<ManagedClient, StartupOutcomeError>>>,
request_headers: Arc<StdMutex<Option<reqwest::header::HeaderMap>>>,
startup_snapshot: Option<Vec<ToolInfo>>,
startup_complete: Arc<AtomicBool>,
tool_plugin_provenance: Arc<ToolPluginProvenance>,
@@ -448,17 +449,26 @@ impl AsyncManagedClient {
codex_apps_tools_cache_context.as_ref(),
)
.map(|tools| filter_tools(tools, &tool_filter));
let request_headers = Arc::new(StdMutex::new(None));
let startup_tool_filter = tool_filter;
let startup_complete = Arc::new(AtomicBool::new(false));
let startup_complete_for_fut = Arc::clone(&startup_complete);
let request_headers_for_client = Arc::clone(&request_headers);
let fut = async move {
let outcome = async {
if let Err(error) = validate_mcp_server_name(&server_name) {
return Err(error.into());
}
let client =
Arc::new(make_rmcp_client(&server_name, config.transport, store_mode).await?);
let client = Arc::new(
make_rmcp_client(
&server_name,
config.transport,
store_mode,
request_headers_for_client,
)
.await?,
);
match start_server_task(
server_name,
client,
@@ -495,6 +505,7 @@ impl AsyncManagedClient {
Self {
client,
request_headers,
startup_snapshot,
startup_complete,
tool_plugin_provenance,
@@ -576,6 +587,14 @@ impl AsyncManagedClient {
let managed = self.client().await?;
managed.notify_sandbox_state_change(sandbox_state).await
}
fn set_request_headers(&self, request_headers: Option<reqwest::header::HeaderMap>) {
let mut guard = self
.request_headers
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
*guard = request_headers;
}
}
pub const MCP_SANDBOX_STATE_CAPABILITY: &str = "codex/sandbox-state";
@@ -1046,6 +1065,16 @@ impl McpConnectionManager {
})
}
pub(crate) fn set_request_headers_for_server(
&self,
server_name: &str,
request_headers: Option<reqwest::header::HeaderMap>,
) {
if let Some(client) = self.clients.get(server_name) {
client.set_request_headers(request_headers);
}
}
/// List resources from the specified server.
pub async fn list_resources(
&self,
@@ -1429,6 +1458,7 @@ async fn make_rmcp_client(
server_name: &str,
transport: McpServerTransportConfig,
store_mode: OAuthCredentialsStoreMode,
request_headers: Arc<StdMutex<Option<reqwest::header::HeaderMap>>>,
) -> Result<RmcpClient, StartupOutcomeError> {
match transport {
McpServerTransportConfig::Stdio {
@@ -1462,6 +1492,7 @@ async fn make_rmcp_client(
http_headers,
env_http_headers,
store_mode,
request_headers,
)
.await
.map_err(StartupOutcomeError::from)

View File

@@ -151,6 +151,8 @@ impl Session {
) {
self.abort_all_tasks(TurnAbortReason::Replaced).await;
self.clear_connector_selection().await;
self.sync_mcp_request_headers_for_turn(turn_context.as_ref())
.await;
let task: Arc<dyn SessionTask> = Arc::new(task);
let task_kind = task.kind();
@@ -231,6 +233,7 @@ impl Session {
// in-flight approval wait can surface as a model-visible rejection before TurnAborted.
active_turn.clear_pending().await;
}
self.clear_mcp_request_headers().await;
}
pub async fn on_task_finished(
@@ -260,6 +263,9 @@ impl Session {
*active = None;
}
drop(active);
if should_clear_active_turn {
self.clear_mcp_request_headers().await;
}
if !pending_input.is_empty() {
let pending_response_items = pending_input
.into_iter()

View File

@@ -5,11 +5,14 @@ use std::io;
use std::path::PathBuf;
use std::process::Stdio;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::time::Duration;
use anyhow::Result;
use anyhow::anyhow;
use codex_client::build_reqwest_client_with_custom_ca;
use codex_client::format_headers_for_log;
use codex_client::log_http_request;
use futures::FutureExt;
use futures::StreamExt;
use futures::future::BoxFuture;
@@ -19,9 +22,11 @@ use reqwest::header::ACCEPT;
use reqwest::header::AUTHORIZATION;
use reqwest::header::CONTENT_TYPE;
use reqwest::header::HeaderMap;
use reqwest::header::HeaderValue;
use reqwest::header::WWW_AUTHENTICATE;
use rmcp::model::CallToolRequestParams;
use rmcp::model::CallToolResult;
use rmcp::model::ClientJsonRpcMessage;
use rmcp::model::ClientNotification;
use rmcp::model::ClientRequest;
use rmcp::model::CreateElicitationRequestParams;
@@ -82,15 +87,123 @@ const JSON_MIME_TYPE: &str = "application/json";
const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id";
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192;
const CODEX_TRACE_HTTP_HEADERS_ENV: &str = "CODEX_TRACE_HTTP_HEADERS";
const CODEX_TRACE_HTTP_BODIES_ENV: &str = "CODEX_TRACE_HTTP_BODIES";
fn http_trace_headers_enabled() -> bool {
std::env::var_os(CODEX_TRACE_HTTP_HEADERS_ENV).is_some_and(|value| value != "0")
}
fn http_trace_bodies_enabled() -> bool {
std::env::var_os(CODEX_TRACE_HTTP_BODIES_ENV).is_some_and(|value| value != "0")
}
fn log_mcp_http_request(
method: &str,
url: &str,
headers: &HeaderMap,
message: &rmcp::model::ClientJsonRpcMessage,
) {
log_http_request(method, url, headers);
if !http_trace_headers_enabled() && !http_trace_bodies_enabled() {
return;
}
let body = if http_trace_bodies_enabled() {
serde_json::to_string(message).ok()
} else {
None
};
match message {
rmcp::model::JsonRpcMessage::Request(request) => {
tracing::info!(
method,
url,
rpc_method = request.request.method(),
rpc_id = ?request.id,
headers = ?format_headers_for_log(headers),
body,
"Outbound MCP HTTP request"
);
}
rmcp::model::JsonRpcMessage::Notification(_) => {
tracing::info!(
method,
url,
rpc_kind = "notification",
headers = ?format_headers_for_log(headers),
body,
"Outbound MCP HTTP request"
);
}
rmcp::model::JsonRpcMessage::Response(response) => {
tracing::info!(
method,
url,
rpc_kind = "response",
rpc_id = ?response.id,
headers = ?format_headers_for_log(headers),
body,
"Outbound MCP HTTP request"
);
}
rmcp::model::JsonRpcMessage::Error(error) => {
tracing::info!(
method,
url,
rpc_kind = "error",
rpc_id = ?error.id,
headers = ?format_headers_for_log(headers),
body,
"Outbound MCP HTTP request"
);
}
}
}
fn message_uses_request_scoped_headers(message: &ClientJsonRpcMessage) -> bool {
matches!(
message,
ClientJsonRpcMessage::Request(request)
if request.request.method() == "tools/call"
)
}
fn apply_request_scoped_headers(
mut request: reqwest::RequestBuilder,
request_headers_state: &Arc<StdMutex<Option<HeaderMap>>>,
) -> reqwest::RequestBuilder {
let extra_headers = request_headers_state
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.clone();
if let Some(extra_headers) = extra_headers {
for (name, value) in &extra_headers {
request = request.header(name, value.clone());
}
}
request
}
#[derive(Clone)]
struct StreamableHttpResponseClient {
inner: reqwest::Client,
default_headers: HeaderMap,
request_headers_state: Arc<StdMutex<Option<HeaderMap>>>,
}
impl StreamableHttpResponseClient {
fn new(inner: reqwest::Client) -> Self {
Self { inner }
fn new(
inner: reqwest::Client,
default_headers: HeaderMap,
request_headers_state: Arc<StdMutex<Option<HeaderMap>>>,
) -> Self {
Self {
inner,
default_headers,
request_headers_state,
}
}
fn reqwest_error(
@@ -127,12 +240,37 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
.inner
.post(uri.as_ref())
.header(ACCEPT, [EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", "));
let mut request_headers = self.default_headers.clone();
request_headers.insert(
ACCEPT,
HeaderValue::from_static("text/event-stream, application/json"),
);
if let Some(auth_header) = auth_token {
if let Ok(value) = HeaderValue::from_str(&format!("Bearer {auth_header}")) {
request_headers.insert(AUTHORIZATION, value);
}
request = request.bearer_auth(auth_header);
}
if let Some(session_id_value) = session_id.as_ref() {
if let Ok(value) = HeaderValue::from_str(session_id_value.as_ref()) {
request_headers.insert(HEADER_SESSION_ID, value);
}
request = request.header(HEADER_SESSION_ID, session_id_value.as_ref());
}
if message_uses_request_scoped_headers(&message) {
let extra_headers = self
.request_headers_state
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.clone();
if let Some(extra_headers) = extra_headers {
for (name, value) in &extra_headers {
request_headers.insert(name.clone(), value.clone());
}
}
request = apply_request_scoped_headers(request, &self.request_headers_state);
}
log_mcp_http_request("POST", uri.as_ref(), &request_headers, &message);
let response = request
.json(&message)
@@ -225,9 +363,17 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
auth_token: Option<String>,
) -> std::result::Result<(), StreamableHttpError<Self::Error>> {
let mut request_builder = self.inner.delete(uri.as_ref());
let mut request_headers = self.default_headers.clone();
if let Some(auth_header) = auth_token {
if let Ok(value) = HeaderValue::from_str(&format!("Bearer {auth_header}")) {
request_headers.insert(AUTHORIZATION, value);
}
request_builder = request_builder.bearer_auth(auth_header);
}
if let Ok(value) = HeaderValue::from_str(session.as_ref()) {
request_headers.insert(HEADER_SESSION_ID, value);
}
log_http_request("DELETE", uri.as_ref(), &request_headers);
let response = request_builder
.header(HEADER_SESSION_ID, session.as_ref())
.send()
@@ -259,12 +405,27 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
.get(uri.as_ref())
.header(ACCEPT, [EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", "))
.header(HEADER_SESSION_ID, session_id.as_ref());
let mut request_headers = self.default_headers.clone();
request_headers.insert(
ACCEPT,
HeaderValue::from_static("text/event-stream, application/json"),
);
if let Ok(value) = HeaderValue::from_str(session_id.as_ref()) {
request_headers.insert(HEADER_SESSION_ID, value);
}
if let Some(last_event_id) = last_event_id {
if let Ok(value) = HeaderValue::from_str(&last_event_id) {
request_headers.insert(HEADER_LAST_EVENT_ID, value);
}
request_builder = request_builder.header(HEADER_LAST_EVENT_ID, last_event_id);
}
if let Some(auth_header) = auth_token {
if let Ok(value) = HeaderValue::from_str(&format!("Bearer {auth_header}")) {
request_headers.insert(AUTHORIZATION, value);
}
request_builder = request_builder.bearer_auth(auth_header);
}
log_http_request("GET", uri.as_ref(), &request_headers);
let response = request_builder
.send()
@@ -472,6 +633,7 @@ pub struct RmcpClient {
transport_recipe: TransportRecipe,
initialize_context: Mutex<Option<InitializeContext>>,
session_recovery_lock: Mutex<()>,
request_headers: Option<Arc<StdMutex<Option<HeaderMap>>>>,
}
impl RmcpClient {
@@ -489,7 +651,7 @@ impl RmcpClient {
env_vars: env_vars.to_vec(),
cwd,
};
let transport = Self::create_pending_transport(&transport_recipe)
let transport = Self::create_pending_transport(&transport_recipe, None)
.await
.map_err(io::Error::other)?;
@@ -500,6 +662,7 @@ impl RmcpClient {
transport_recipe,
initialize_context: Mutex::new(None),
session_recovery_lock: Mutex::new(()),
request_headers: None,
})
}
@@ -511,6 +674,7 @@ impl RmcpClient {
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
store_mode: OAuthCredentialsStoreMode,
request_headers: Arc<StdMutex<Option<HeaderMap>>>,
) -> Result<Self> {
let transport_recipe = TransportRecipe::StreamableHttp {
server_name: server_name.to_string(),
@@ -520,7 +684,9 @@ impl RmcpClient {
env_http_headers,
store_mode,
};
let transport = Self::create_pending_transport(&transport_recipe).await?;
let transport =
Self::create_pending_transport(&transport_recipe, Some(Arc::clone(&request_headers)))
.await?;
Ok(Self {
state: Mutex::new(ClientState::Connecting {
transport: Some(transport),
@@ -528,6 +694,7 @@ impl RmcpClient {
transport_recipe,
initialize_context: Mutex::new(None),
session_recovery_lock: Mutex::new(()),
request_headers: Some(request_headers),
})
}
@@ -830,6 +997,7 @@ impl RmcpClient {
async fn create_pending_transport(
transport_recipe: &TransportRecipe,
request_headers: Option<Arc<StdMutex<Option<HeaderMap>>>>,
) -> Result<PendingTransport> {
match transport_recipe {
TransportRecipe::Stdio {
@@ -946,7 +1114,13 @@ impl RmcpClient {
.auth_header(access_token);
let http_client = build_http_client(&default_headers)?;
let transport = StreamableHttpClientTransport::with_client(
StreamableHttpResponseClient::new(http_client),
StreamableHttpResponseClient::new(
http_client,
default_headers.clone(),
request_headers
.clone()
.unwrap_or_else(|| Arc::new(StdMutex::new(None))),
),
http_config,
);
Ok(PendingTransport::StreamableHttp { transport })
@@ -963,7 +1137,13 @@ impl RmcpClient {
let http_client = build_http_client(&default_headers)?;
let transport = StreamableHttpClientTransport::with_client(
StreamableHttpResponseClient::new(http_client),
StreamableHttpResponseClient::new(
http_client,
default_headers.clone(),
request_headers
.clone()
.unwrap_or_else(|| Arc::new(StdMutex::new(None))),
),
http_config,
);
Ok(PendingTransport::StreamableHttp { transport })
@@ -1111,7 +1291,9 @@ impl RmcpClient {
.await
.clone()
.ok_or_else(|| anyhow!("MCP client cannot recover before initialize succeeds"))?;
let pending_transport = Self::create_pending_transport(&self.transport_recipe).await?;
let pending_transport =
Self::create_pending_transport(&self.transport_recipe, self.request_headers.clone())
.await?;
let (service, oauth_persistor, process_group_guard) = Self::connect_pending_transport(
pending_transport,
initialize_context.handler,
@@ -1166,7 +1348,14 @@ async fn create_oauth_transport_and_runtime(
}
};
let auth_client = AuthClient::new(StreamableHttpResponseClient::new(http_client), manager);
let auth_client = AuthClient::new(
StreamableHttpResponseClient::new(
http_client,
default_headers.clone(),
Arc::new(StdMutex::new(None)),
),
manager,
);
let auth_manager = auth_client.auth_manager.clone();
let transport = StreamableHttpClientTransport::with_client(