Compare commits

...

2 Commits

Author SHA1 Message Date
celia-oai
09637409b0 changes 2026-04-30 20:25:16 -07:00
celia-oai
79908b64a1 changes 2026-04-30 20:08:54 -07:00
9 changed files with 113 additions and 21 deletions

View File

@@ -97,6 +97,11 @@ impl AwsAuthContext {
&self.service
}
pub async fn preload_credentials(&self) -> Result<(), AwsAuthError> {
let _ = self.credentials_provider.provide_credentials().await?;
Ok(())
}
pub async fn sign(&self, request: AwsRequestToSign) -> Result<AwsSignedRequest, AwsAuthError> {
self.sign_at(request, SystemTime::now()).await
}
@@ -202,6 +207,14 @@ mod tests {
assert!(signing::header_value(&signed.headers, "x-amz-date").is_some());
}
#[tokio::test]
async fn preload_credentials_resolves_provider() {
test_context(/*session_token*/ None)
.preload_credentials()
.await
.expect("static credentials should resolve");
}
#[test]
fn credentials_provider_failures_are_retryable() {
assert!(

View File

@@ -673,6 +673,15 @@ impl ModelClient {
true
}
/// Resolves provider credentials during session startup when the provider requests it.
pub(crate) async fn prewarm_provider_auth(&self) -> Result<()> {
if !self.state.provider.prewarms_auth_on_startup() {
return Ok(());
}
self.state.provider.prewarm_auth().await
}
/// Returns auth + provider configuration resolved from the current session auth state.
///
/// This centralizes setup used by both prewarm and normal request paths so they stay in

View File

@@ -82,6 +82,7 @@ pub(crate) mod mentions {
mod sandbox_tags;
pub mod sandboxing;
mod session_prefix;
mod session_startup_auth_prewarm;
mod session_startup_prewarm;
mod shell_detect;
pub mod skills;

View File

@@ -980,6 +980,7 @@ impl Session {
anyhow::bail!("required MCP servers failed to initialize: {details}");
}
}
sess.schedule_startup_auth_prewarm().await;
sess.schedule_startup_prewarm(session_configuration.base_instructions.clone())
.await;
let session_start_source = match &initial_history {

View File

@@ -0,0 +1,16 @@
use std::sync::Arc;
use tracing::debug;
use crate::session::session::Session;
impl Session {
pub(crate) async fn schedule_startup_auth_prewarm(self: &Arc<Self>) {
let model_client = self.services.model_client.clone();
tokio::spawn(async move {
if let Err(err) = model_client.prewarm_provider_auth().await {
debug!("startup provider auth prewarm failed: {err:#}");
}
});
}
}

View File

@@ -22,6 +22,7 @@ use super::mantle::region_from_config;
const AWS_BEARER_TOKEN_BEDROCK_ENV_VAR: &str = "AWS_BEARER_TOKEN_BEDROCK";
const LEGACY_SESSION_ID_HEADER: &str = "session_id";
#[derive(Clone, Debug)]
pub(super) enum BedrockAuthMethod {
EnvBearerToken { token: String, region: String },
AwsSdkAuth { context: AwsAuthContext },
@@ -42,17 +43,25 @@ pub(super) async fn resolve_auth_method(
Ok(BedrockAuthMethod::AwsSdkAuth { context })
}
pub(super) async fn resolve_provider_auth(
aws: &ModelProviderAwsAuthInfo,
) -> Result<SharedAuthProvider> {
match resolve_auth_method(aws).await? {
BedrockAuthMethod::EnvBearerToken { token, .. } => Ok(Arc::new(BearerAuthProvider {
pub(super) async fn prewarm_credentials(auth_method: &BedrockAuthMethod) -> Result<()> {
match auth_method {
BedrockAuthMethod::EnvBearerToken { .. } => Ok(()),
BedrockAuthMethod::AwsSdkAuth { context } => context
.preload_credentials()
.await
.map_err(aws_auth_error_to_codex_error),
}
}
pub(super) fn provider_auth_from_method(auth_method: BedrockAuthMethod) -> SharedAuthProvider {
match auth_method {
BedrockAuthMethod::EnvBearerToken { token, .. } => Arc::new(BearerAuthProvider {
token: Some(token),
account_id: None,
is_fedramp_account: false,
})),
}),
BedrockAuthMethod::AwsSdkAuth { context } => {
Ok(Arc::new(BedrockMantleSigV4AuthProvider::new(context)))
Arc::new(BedrockMantleSigV4AuthProvider::new(context))
}
}
}

View File

@@ -4,7 +4,6 @@ use codex_protocol::error::CodexErr;
use codex_protocol::error::Result;
use super::auth::BedrockAuthMethod;
use super::auth::resolve_auth_method;
const BEDROCK_MANTLE_SERVICE_NAME: &str = "bedrock-mantle";
const BEDROCK_MANTLE_SUPPORTED_REGIONS: [&str; 12] = [
@@ -48,16 +47,15 @@ pub(super) fn base_url(region: &str) -> Result<String> {
}
}
pub(super) async fn runtime_base_url(aws: &ModelProviderAwsAuthInfo) -> Result<String> {
let region = resolve_region(aws).await?;
base_url(&region)
pub(super) fn region_from_auth_method(auth_method: &BedrockAuthMethod) -> String {
match auth_method {
BedrockAuthMethod::EnvBearerToken { region, .. } => region.clone(),
BedrockAuthMethod::AwsSdkAuth { context } => context.region().to_string(),
}
}
async fn resolve_region(aws: &ModelProviderAwsAuthInfo) -> Result<String> {
match resolve_auth_method(aws).await? {
BedrockAuthMethod::EnvBearerToken { region, .. } => Ok(region),
BedrockAuthMethod::AwsSdkAuth { context } => Ok(context.region().to_string()),
}
pub(super) fn runtime_base_url_from_auth_method(auth_method: &BedrockAuthMethod) -> Result<String> {
base_url(&region_from_auth_method(auth_method))
}
#[cfg(test)]

View File

@@ -16,20 +16,26 @@ use codex_models_manager::manager::StaticModelsManager;
use codex_protocol::account::ProviderAccount;
use codex_protocol::error::Result;
use codex_protocol::openai_models::ModelsResponse;
use tokio::sync::OnceCell;
use crate::provider::ModelProvider;
use crate::provider::ProviderAccountResult;
use crate::provider::ProviderAccountState;
use crate::provider::ProviderCapabilities;
use auth::resolve_provider_auth;
use auth::BedrockAuthMethod;
use auth::prewarm_credentials;
use auth::provider_auth_from_method;
use auth::resolve_auth_method;
pub(crate) use catalog::static_model_catalog;
use mantle::runtime_base_url;
use mantle::runtime_base_url_from_auth_method;
/// Runtime provider for Amazon Bedrock's OpenAI-compatible Mantle endpoint.
#[derive(Clone, Debug)]
pub(crate) struct AmazonBedrockModelProvider {
pub(crate) info: ModelProviderInfo,
pub(crate) aws: ModelProviderAwsAuthInfo,
auth_method: Arc<OnceCell<BedrockAuthMethod>>,
credentials_prewarmed: Arc<OnceCell<()>>,
}
impl AmazonBedrockModelProvider {
@@ -44,8 +50,25 @@ impl AmazonBedrockModelProvider {
Self {
info: provider_info,
aws,
auth_method: Arc::new(OnceCell::new()),
credentials_prewarmed: Arc::new(OnceCell::new()),
}
}
async fn auth_method(&self) -> Result<BedrockAuthMethod> {
self.auth_method
.get_or_try_init(|| resolve_auth_method(&self.aws))
.await
.cloned()
}
async fn prewarm_bedrock_credentials(&self) -> Result<()> {
let auth_method = self.auth_method().await?;
self.credentials_prewarmed
.get_or_try_init(|| async move { prewarm_credentials(&auth_method).await })
.await?;
Ok(())
}
}
#[async_trait::async_trait]
@@ -70,6 +93,14 @@ impl ModelProvider for AmazonBedrockModelProvider {
None
}
fn prewarms_auth_on_startup(&self) -> bool {
true
}
async fn prewarm_auth(&self) -> Result<()> {
self.prewarm_bedrock_credentials().await
}
fn account_state(&self) -> ProviderAccountResult {
Ok(ProviderAccountState {
account: Some(ProviderAccount::AmazonBedrock),
@@ -79,16 +110,20 @@ impl ModelProvider for AmazonBedrockModelProvider {
async fn api_provider(&self) -> Result<Provider> {
let mut api_provider_info = self.info.clone();
api_provider_info.base_url = Some(runtime_base_url(&self.aws).await?);
api_provider_info.base_url = Some(runtime_base_url_from_auth_method(
&self.auth_method().await?,
)?);
api_provider_info.to_api_provider(/*auth_mode*/ None)
}
async fn runtime_base_url(&self) -> Result<Option<String>> {
Ok(Some(runtime_base_url(&self.aws).await?))
Ok(Some(runtime_base_url_from_auth_method(
&self.auth_method().await?,
)?))
}
async fn api_auth(&self) -> Result<SharedAuthProvider> {
resolve_provider_auth(&self.aws).await
Ok(provider_auth_from_method(self.auth_method().await?))
}
fn models_manager(

View File

@@ -96,6 +96,16 @@ pub trait ModelProvider: fmt::Debug + Send + Sync {
/// Returns the current provider-scoped auth value, if one is configured.
async fn auth(&self) -> Option<CodexAuth>;
/// Returns whether this provider should resolve request credentials during session startup.
fn prewarms_auth_on_startup(&self) -> bool {
false
}
/// Resolves provider credentials before the first model request when startup prewarm is enabled.
async fn prewarm_auth(&self) -> codex_protocol::error::Result<()> {
Ok(())
}
/// Returns the current app-visible account state for this provider.
fn account_state(&self) -> ProviderAccountResult;