Compare commits

...

2 Commits

Author SHA1 Message Date
gt-oai
761a37d5c1 blocking debug 2026-01-31 01:42:50 +00:00
gt-oai
e04cfc3601 Make cloud requirements load fail-closed 2026-01-31 00:28:03 +00:00
6 changed files with 241 additions and 111 deletions

2
codex-rs/Cargo.lock generated
View File

@@ -1297,6 +1297,7 @@ dependencies = [
name = "codex-cloud-requirements" name = "codex-cloud-requirements"
version = "0.0.0" version = "0.0.0"
dependencies = [ dependencies = [
"anyhow",
"async-trait", "async-trait",
"base64", "base64",
"codex-backend-client", "codex-backend-client",
@@ -1306,6 +1307,7 @@ dependencies = [
"pretty_assertions", "pretty_assertions",
"serde_json", "serde_json",
"tempfile", "tempfile",
"thiserror 2.0.17",
"tokio", "tokio",
"toml 0.9.5", "toml 0.9.5",
"tracing", "tracing",

View File

@@ -14,10 +14,12 @@ codex-core = { workspace = true }
codex-otel = { workspace = true } codex-otel = { workspace = true }
codex-protocol = { workspace = true } codex-protocol = { workspace = true }
tokio = { workspace = true, features = ["sync", "time"] } tokio = { workspace = true, features = ["sync", "time"] }
thiserror = { workspace = true }
toml = { workspace = true } toml = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
[dev-dependencies] [dev-dependencies]
anyhow = { workspace = true }
base64 = { workspace = true } base64 = { workspace = true }
pretty_assertions = { workspace = true } pretty_assertions = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }

View File

@@ -3,9 +3,7 @@
//! This crate fetches `requirements.toml` data from the backend as an alternative to loading it //! This crate fetches `requirements.toml` data from the backend as an alternative to loading it
//! from the local filesystem. It only applies to Enterprise ChatGPT customers. //! from the local filesystem. It only applies to Enterprise ChatGPT customers.
//! //!
//! Today, fetching is best-effort: on error or timeout, Codex continues without cloud requirements. //! Enterprise ChatGPT customers must successfully fetch these requirements before Codex will run.
//! We expect to tighten this so that Enterprise ChatGPT customers must successfully fetch these
//! requirements before Codex will run.
use async_trait::async_trait; use async_trait::async_trait;
use codex_backend_client::Client as BackendClient; use codex_backend_client::Client as BackendClient;
@@ -14,21 +12,73 @@ use codex_core::auth::CodexAuth;
use codex_core::config_loader::CloudRequirementsLoader; use codex_core::config_loader::CloudRequirementsLoader;
use codex_core::config_loader::ConfigRequirementsToml; use codex_core::config_loader::ConfigRequirementsToml;
use codex_protocol::account::PlanType; use codex_protocol::account::PlanType;
use std::io;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use std::time::Instant; use std::time::Instant;
use thiserror::Error;
use tokio::time::timeout; use tokio::time::timeout;
/// This blocks codex startup, so must be short. /// This blocks codex startup, so must be short.
const CLOUD_REQUIREMENTS_TIMEOUT: Duration = Duration::from_secs(5); const CLOUD_REQUIREMENTS_TIMEOUT: Duration = Duration::from_secs(5);
#[derive(Debug, Error, Clone, PartialEq, Eq)]
enum CloudRequirementsError {
#[error("cloud requirements user error: {0}")]
User(CloudRequirementsUserError),
#[error("cloud requirements network error: {0}")]
Network(CloudRequirementsNetworkError),
}
impl From<CloudRequirementsUserError> for CloudRequirementsError {
fn from(err: CloudRequirementsUserError) -> Self {
CloudRequirementsError::User(err)
}
}
impl From<CloudRequirementsNetworkError> for CloudRequirementsError {
fn from(err: CloudRequirementsNetworkError) -> Self {
CloudRequirementsError::Network(err)
}
}
impl From<CloudRequirementsError> for io::Error {
fn from(err: CloudRequirementsError) -> Self {
let kind = match &err {
CloudRequirementsError::User(_) => io::ErrorKind::InvalidData,
CloudRequirementsError::Network(CloudRequirementsNetworkError::Timeout { .. }) => {
io::ErrorKind::TimedOut
}
CloudRequirementsError::Network(_) => io::ErrorKind::Other,
};
io::Error::new(kind, err)
}
}
#[derive(Debug, Error, Clone, PartialEq, Eq)]
enum CloudRequirementsUserError {
#[error("failed to parse requirements TOML: {message}")]
InvalidToml { message: String },
}
#[derive(Debug, Error, Clone, PartialEq, Eq)]
enum CloudRequirementsNetworkError {
#[error("backend client initialization failed: {message}")]
BackendClient { message: String },
#[error("request failed: {message}")]
Request { message: String },
#[error("cloud requirements response missing contents")]
MissingContents,
#[error("timed out after {timeout_ms}ms")]
Timeout { timeout_ms: u64 },
#[error("cloud requirements task failed: {message}")]
Task { message: String },
}
#[async_trait] #[async_trait]
trait RequirementsFetcher: Send + Sync { trait RequirementsFetcher: Send + Sync {
/// Returns requirements as a TOML string. /// Returns requirements as a TOML string.
/// async fn fetch_requirements(&self, auth: &CodexAuth) -> Result<String, CloudRequirementsError>;
/// TODO(gt): For now, returns an Option. But when we want to make this fail-closed, return a
/// Result.
async fn fetch_requirements(&self, auth: &CodexAuth) -> Option<String>;
} }
struct BackendRequirementsFetcher { struct BackendRequirementsFetcher {
@@ -43,7 +93,7 @@ impl BackendRequirementsFetcher {
#[async_trait] #[async_trait]
impl RequirementsFetcher for BackendRequirementsFetcher { impl RequirementsFetcher for BackendRequirementsFetcher {
async fn fetch_requirements(&self, auth: &CodexAuth) -> Option<String> { async fn fetch_requirements(&self, auth: &CodexAuth) -> Result<String, CloudRequirementsError> {
let client = BackendClient::from_auth(self.base_url.clone(), auth) let client = BackendClient::from_auth(self.base_url.clone(), auth)
.inspect_err(|err| { .inspect_err(|err| {
tracing::warn!( tracing::warn!(
@@ -51,20 +101,28 @@ impl RequirementsFetcher for BackendRequirementsFetcher {
"Failed to construct backend client for cloud requirements" "Failed to construct backend client for cloud requirements"
); );
}) })
.ok()?; .map_err(|err| CloudRequirementsNetworkError::BackendClient {
message: err.to_string(),
})
.map_err(CloudRequirementsError::from)?;
let response = client let response = client
.get_config_requirements_file() .get_config_requirements_file()
.await .await
.inspect_err(|err| tracing::warn!(error = %err, "Failed to fetch cloud requirements")) .inspect_err(|err| tracing::warn!(error = %err, "Failed to fetch cloud requirements"))
.ok()?; .map_err(|err| CloudRequirementsNetworkError::Request {
message: err.to_string(),
})
.map_err(CloudRequirementsError::from)?;
let Some(contents) = response.contents else { let Some(contents) = response.contents else {
tracing::warn!("Cloud requirements response missing contents"); tracing::warn!("Cloud requirements response missing contents");
return None; return Err(CloudRequirementsError::from(
CloudRequirementsNetworkError::MissingContents,
));
}; };
Some(contents) Ok(contents)
} }
} }
@@ -87,29 +145,50 @@ impl CloudRequirementsService {
} }
} }
async fn fetch_with_timeout(&self) -> Option<ConfigRequirementsToml> { async fn fetch_with_timeout(
&self,
) -> Result<Option<ConfigRequirementsToml>, CloudRequirementsError> {
let _timer = let _timer =
codex_otel::start_global_timer("codex.cloud_requirements.fetch.duration_ms", &[]); codex_otel::start_global_timer("codex.cloud_requirements.fetch.duration_ms", &[]);
let started_at = Instant::now(); let started_at = Instant::now();
let result = timeout(self.timeout, self.fetch()) let result = timeout(self.timeout, self.fetch()).await.map_err(|_| {
.await CloudRequirementsNetworkError::Timeout {
.inspect_err(|_| { timeout_ms: self.timeout.as_millis() as u64,
tracing::warn!("Timed out waiting for cloud requirements; continuing without them"); }
}) })?;
.ok()?;
let elapsed_ms = started_at.elapsed().as_millis();
match result.as_ref() { match result.as_ref() {
Some(requirements) => { Ok(Some(requirements)) => {
tracing::info!( tracing::info!(
elapsed_ms = started_at.elapsed().as_millis(), elapsed_ms,
status = "success",
requirements = ?requirements, requirements = ?requirements,
"Cloud requirements load completed" "Cloud requirements load completed"
); );
println!(
"cloud_requirements status=success elapsed_ms={elapsed_ms} value={requirements:?}"
);
} }
None => { Ok(None) => {
tracing::info!( tracing::info!(
elapsed_ms = started_at.elapsed().as_millis(), elapsed_ms,
"Cloud requirements load completed (none)" status = "none",
requirements = %"none",
"Cloud requirements load completed"
);
println!("cloud_requirements status=none elapsed_ms={elapsed_ms} value=none");
}
Err(err) => {
tracing::warn!(
elapsed_ms,
status = "error",
requirements = %"none",
error = %err,
"Cloud requirements load failed"
);
println!(
"cloud_requirements status=error elapsed_ms={elapsed_ms} value=none error={err}"
); );
} }
} }
@@ -117,17 +196,19 @@ impl CloudRequirementsService {
result result
} }
async fn fetch(&self) -> Option<ConfigRequirementsToml> { async fn fetch(&self) -> Result<Option<ConfigRequirementsToml>, CloudRequirementsError> {
let auth = self.auth_manager.auth().await?; let auth = match self.auth_manager.auth().await {
Some(auth) => auth,
None => return Ok(None),
};
if !(auth.is_chatgpt_auth() && auth.account_plan_type() == Some(PlanType::Enterprise)) { if !(auth.is_chatgpt_auth() && auth.account_plan_type() == Some(PlanType::Enterprise)) {
return None; return Ok(None);
} }
let contents = self.fetcher.fetch_requirements(&auth).await?; let contents = self.fetcher.fetch_requirements(&auth).await?;
parse_cloud_requirements(&contents) parse_cloud_requirements(&contents)
.inspect_err(|err| tracing::warn!(error = %err, "Failed to parse cloud requirements")) .inspect_err(|err| tracing::warn!(error = %err, "Failed to parse cloud requirements"))
.ok() .map_err(CloudRequirementsError::from)
.flatten()
} }
} }
@@ -143,20 +224,28 @@ pub fn cloud_requirements_loader(
let task = tokio::spawn(async move { service.fetch_with_timeout().await }); let task = tokio::spawn(async move { service.fetch_with_timeout().await });
CloudRequirementsLoader::new(async move { CloudRequirementsLoader::new(async move {
task.await task.await
.map_err(|err| {
CloudRequirementsError::from(CloudRequirementsNetworkError::Task {
message: err.to_string(),
})
})
.and_then(std::convert::identity)
.map_err(io::Error::from)
.inspect_err(|err| tracing::warn!(error = %err, "Cloud requirements task failed")) .inspect_err(|err| tracing::warn!(error = %err, "Cloud requirements task failed"))
.ok()
.flatten()
}) })
} }
fn parse_cloud_requirements( fn parse_cloud_requirements(
contents: &str, contents: &str,
) -> Result<Option<ConfigRequirementsToml>, toml::de::Error> { ) -> Result<Option<ConfigRequirementsToml>, CloudRequirementsUserError> {
if contents.trim().is_empty() { if contents.trim().is_empty() {
return Ok(None); return Ok(None);
} }
let requirements: ConfigRequirementsToml = toml::from_str(contents)?; let requirements: ConfigRequirementsToml =
toml::from_str(contents).map_err(|err| CloudRequirementsUserError::InvalidToml {
message: err.to_string(),
})?;
if requirements.is_empty() { if requirements.is_empty() {
Ok(None) Ok(None)
} else { } else {
@@ -167,6 +256,7 @@ fn parse_cloud_requirements(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use anyhow::Result;
use base64::Engine; use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD; use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use codex_core::auth::AuthCredentialsStoreMode; use codex_core::auth::AuthCredentialsStoreMode;
@@ -177,28 +267,28 @@ mod tests {
use std::path::Path; use std::path::Path;
use tempfile::tempdir; use tempfile::tempdir;
fn write_auth_json(codex_home: &Path, value: serde_json::Value) -> std::io::Result<()> { fn write_auth_json(codex_home: &Path, value: serde_json::Value) -> Result<()> {
std::fs::write(codex_home.join("auth.json"), serde_json::to_string(&value)?)?; std::fs::write(codex_home.join("auth.json"), serde_json::to_string(&value)?)?;
Ok(()) Ok(())
} }
fn auth_manager_with_api_key() -> Arc<AuthManager> { fn auth_manager_with_api_key() -> Result<Arc<AuthManager>> {
let tmp = tempdir().expect("tempdir"); let tmp = tempdir()?;
let auth_json = json!({ let auth_json = json!({
"OPENAI_API_KEY": "sk-test-key", "OPENAI_API_KEY": "sk-test-key",
"tokens": null, "tokens": null,
"last_refresh": null, "last_refresh": null,
}); });
write_auth_json(tmp.path(), auth_json).expect("write auth"); write_auth_json(tmp.path(), auth_json)?;
Arc::new(AuthManager::new( Ok(Arc::new(AuthManager::new(
tmp.path().to_path_buf(), tmp.path().to_path_buf(),
false, false,
AuthCredentialsStoreMode::File, AuthCredentialsStoreMode::File,
)) )))
} }
fn auth_manager_with_plan(plan_type: &str) -> Arc<AuthManager> { fn auth_manager_with_plan(plan_type: &str) -> Result<Arc<AuthManager>> {
let tmp = tempdir().expect("tempdir"); let tmp = tempdir()?;
let header = json!({ "alg": "none", "typ": "JWT" }); let header = json!({ "alg": "none", "typ": "JWT" });
let auth_payload = json!({ let auth_payload = json!({
"chatgpt_plan_type": plan_type, "chatgpt_plan_type": plan_type,
@@ -209,8 +299,8 @@ mod tests {
"email": "user@example.com", "email": "user@example.com",
"https://api.openai.com/auth": auth_payload, "https://api.openai.com/auth": auth_payload,
}); });
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&header).expect("header")); let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&header)?);
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&payload).expect("payload")); let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&payload)?);
let signature_b64 = URL_SAFE_NO_PAD.encode(b"sig"); let signature_b64 = URL_SAFE_NO_PAD.encode(b"sig");
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}"); let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
@@ -223,26 +313,31 @@ mod tests {
}, },
"last_refresh": null, "last_refresh": null,
}); });
write_auth_json(tmp.path(), auth_json).expect("write auth"); write_auth_json(tmp.path(), auth_json)?;
Arc::new(AuthManager::new( Ok(Arc::new(AuthManager::new(
tmp.path().to_path_buf(), tmp.path().to_path_buf(),
false, false,
AuthCredentialsStoreMode::File, AuthCredentialsStoreMode::File,
)) )))
} }
fn parse_for_fetch(contents: Option<&str>) -> Option<ConfigRequirementsToml> { fn parse_for_fetch(
contents.and_then(|contents| parse_cloud_requirements(contents).ok().flatten()) contents: Option<&str>,
) -> Result<Option<ConfigRequirementsToml>, CloudRequirementsUserError> {
contents.map(parse_cloud_requirements).unwrap_or(Ok(None))
} }
struct StaticFetcher { struct StaticFetcher {
contents: Option<String>, result: Result<String, CloudRequirementsError>,
} }
#[async_trait::async_trait] #[async_trait::async_trait]
impl RequirementsFetcher for StaticFetcher { impl RequirementsFetcher for StaticFetcher {
async fn fetch_requirements(&self, _auth: &CodexAuth) -> Option<String> { async fn fetch_requirements(
self.contents.clone() &self,
_auth: &CodexAuth,
) -> Result<String, CloudRequirementsError> {
self.result.clone()
} }
} }
@@ -250,88 +345,115 @@ mod tests {
#[async_trait::async_trait] #[async_trait::async_trait]
impl RequirementsFetcher for PendingFetcher { impl RequirementsFetcher for PendingFetcher {
async fn fetch_requirements(&self, _auth: &CodexAuth) -> Option<String> { async fn fetch_requirements(
&self,
_auth: &CodexAuth,
) -> Result<String, CloudRequirementsError> {
pending::<()>().await; pending::<()>().await;
None Ok(String::new())
} }
} }
#[tokio::test] #[tokio::test]
async fn fetch_cloud_requirements_skips_non_chatgpt_auth() { async fn fetch_cloud_requirements_skips_non_chatgpt_auth() -> Result<()> {
let auth_manager = auth_manager_with_api_key();
let service = CloudRequirementsService::new( let service = CloudRequirementsService::new(
auth_manager, auth_manager_with_api_key()?,
Arc::new(StaticFetcher { contents: None }), Arc::new(StaticFetcher {
result: Ok(String::new()),
}),
CLOUD_REQUIREMENTS_TIMEOUT, CLOUD_REQUIREMENTS_TIMEOUT,
); );
let result = service.fetch().await; assert_eq!(service.fetch().await, Ok(None));
assert!(result.is_none()); Ok(())
} }
#[tokio::test] #[tokio::test]
async fn fetch_cloud_requirements_skips_non_enterprise_plan() { async fn fetch_cloud_requirements_skips_non_enterprise_plan() -> Result<()> {
let auth_manager = auth_manager_with_plan("pro");
let service = CloudRequirementsService::new( let service = CloudRequirementsService::new(
auth_manager, auth_manager_with_plan("pro")?,
Arc::new(StaticFetcher { contents: None }), Arc::new(StaticFetcher {
result: Ok(String::new()),
}),
CLOUD_REQUIREMENTS_TIMEOUT, CLOUD_REQUIREMENTS_TIMEOUT,
); );
let result = service.fetch().await; assert_eq!(service.fetch().await, Ok(None));
assert!(result.is_none()); Ok(())
} }
#[tokio::test] #[tokio::test]
async fn fetch_cloud_requirements_handles_missing_contents() { async fn fetch_cloud_requirements_returns_missing_contents_error() -> Result<()> {
let result = parse_for_fetch(None); let service = CloudRequirementsService::new(
assert!(result.is_none()); auth_manager_with_plan("enterprise")?,
} Arc::new(StaticFetcher {
result: Err(CloudRequirementsError::Network(
#[tokio::test] CloudRequirementsNetworkError::MissingContents,
async fn fetch_cloud_requirements_handles_empty_contents() { )),
let result = parse_for_fetch(Some(" ")); }),
assert!(result.is_none()); CLOUD_REQUIREMENTS_TIMEOUT,
} );
#[tokio::test]
async fn fetch_cloud_requirements_handles_invalid_toml() {
let result = parse_for_fetch(Some("not = ["));
assert!(result.is_none());
}
#[tokio::test]
async fn fetch_cloud_requirements_ignores_empty_requirements() {
let result = parse_for_fetch(Some("# comment"));
assert!(result.is_none());
}
#[tokio::test]
async fn fetch_cloud_requirements_parses_valid_toml() {
let result = parse_for_fetch(Some("allowed_approval_policies = [\"never\"]"));
assert_eq!( assert_eq!(
result, service.fetch().await,
Some(ConfigRequirementsToml { Err(CloudRequirementsError::Network(
CloudRequirementsNetworkError::MissingContents
))
);
Ok(())
}
#[tokio::test]
async fn fetch_cloud_requirements_handles_empty_contents() -> Result<()> {
assert_eq!(parse_for_fetch(Some(" ")), Ok(None));
Ok(())
}
#[tokio::test]
async fn fetch_cloud_requirements_handles_invalid_toml() -> Result<()> {
assert!(matches!(
parse_for_fetch(Some("not = [")),
Err(CloudRequirementsUserError::InvalidToml { .. })
));
Ok(())
}
#[tokio::test]
async fn fetch_cloud_requirements_ignores_empty_requirements() -> Result<()> {
assert_eq!(parse_for_fetch(Some("# comment")), Ok(None));
Ok(())
}
#[tokio::test]
async fn fetch_cloud_requirements_parses_valid_toml() -> Result<()> {
assert_eq!(
parse_for_fetch(Some("allowed_approval_policies = [\"never\"]")),
Ok(Some(ConfigRequirementsToml {
allowed_approval_policies: Some(vec![AskForApproval::Never]), allowed_approval_policies: Some(vec![AskForApproval::Never]),
allowed_sandbox_modes: None, allowed_sandbox_modes: None,
mcp_servers: None, mcp_servers: None,
rules: None, rules: None,
enforce_residency: None, enforce_residency: None,
}) }))
); );
Ok(())
} }
#[tokio::test(start_paused = true)] #[tokio::test(start_paused = true)]
async fn fetch_cloud_requirements_times_out() { async fn fetch_cloud_requirements_times_out() -> Result<()> {
let auth_manager = auth_manager_with_plan("enterprise");
let service = CloudRequirementsService::new( let service = CloudRequirementsService::new(
auth_manager, auth_manager_with_plan("enterprise")?,
Arc::new(PendingFetcher), Arc::new(PendingFetcher),
CLOUD_REQUIREMENTS_TIMEOUT, CLOUD_REQUIREMENTS_TIMEOUT,
); );
let handle = tokio::spawn(async move { service.fetch_with_timeout().await }); let handle = tokio::spawn(async move { service.fetch_with_timeout().await });
tokio::time::advance(CLOUD_REQUIREMENTS_TIMEOUT + Duration::from_millis(1)).await; tokio::time::advance(CLOUD_REQUIREMENTS_TIMEOUT + Duration::from_millis(1)).await;
let result = handle.await.expect("cloud requirements task"); assert_eq!(
assert!(result.is_none()); handle.await?,
Err(CloudRequirementsError::Network(
CloudRequirementsNetworkError::Timeout {
timeout_ms: CLOUD_REQUIREMENTS_TIMEOUT.as_millis() as u64,
}
))
);
Ok(())
} }
} }

View File

@@ -4,25 +4,29 @@ use futures::future::FutureExt;
use futures::future::Shared; use futures::future::Shared;
use std::fmt; use std::fmt;
use std::future::Future; use std::future::Future;
use std::io;
use std::sync::Arc;
#[derive(Clone)] #[derive(Clone)]
pub struct CloudRequirementsLoader { pub struct CloudRequirementsLoader {
// TODO(gt): This should return a Result once we can fail-closed. fut: Shared<BoxFuture<'static, Arc<io::Result<Option<ConfigRequirementsToml>>>>>,
fut: Shared<BoxFuture<'static, Option<ConfigRequirementsToml>>>,
} }
impl CloudRequirementsLoader { impl CloudRequirementsLoader {
pub fn new<F>(fut: F) -> Self pub fn new<F>(fut: F) -> Self
where where
F: Future<Output = Option<ConfigRequirementsToml>> + Send + 'static, F: Future<Output = io::Result<Option<ConfigRequirementsToml>>> + Send + 'static,
{ {
Self { Self {
fut: fut.boxed().shared(), fut: fut.map(Arc::new).boxed().shared(),
} }
} }
pub async fn get(&self) -> Option<ConfigRequirementsToml> { pub async fn get(&self) -> io::Result<Option<ConfigRequirementsToml>> {
self.fut.clone().await match self.fut.clone().await.as_ref() {
Ok(requirements) => Ok(requirements.clone()),
Err(err) => Err(io::Error::new(err.kind(), err.to_string())),
}
} }
} }
@@ -34,7 +38,7 @@ impl fmt::Debug for CloudRequirementsLoader {
impl Default for CloudRequirementsLoader { impl Default for CloudRequirementsLoader {
fn default() -> Self { fn default() -> Self {
Self::new(async { None }) Self::new(async { Ok(None) })
} }
} }
@@ -52,11 +56,11 @@ mod tests {
let counter_clone = Arc::clone(&counter); let counter_clone = Arc::clone(&counter);
let loader = CloudRequirementsLoader::new(async move { let loader = CloudRequirementsLoader::new(async move {
counter_clone.fetch_add(1, Ordering::SeqCst); counter_clone.fetch_add(1, Ordering::SeqCst);
Some(ConfigRequirementsToml::default()) Ok(Some(ConfigRequirementsToml::default()))
}); });
let (first, second) = tokio::join!(loader.get(), loader.get()); let (first, second) = tokio::join!(loader.get(), loader.get());
assert_eq!(first, second); assert_eq!(first.as_ref().ok(), second.as_ref().ok());
assert_eq!(counter.load(Ordering::SeqCst), 1); assert_eq!(counter.load(Ordering::SeqCst), 1);
} }
} }

View File

@@ -115,7 +115,7 @@ pub async fn load_config_layers_state(
) )
.await?; .await?;
if let Some(requirements) = cloud_requirements.get().await { if let Some(requirements) = cloud_requirements.get().await? {
config_requirements_toml config_requirements_toml
.merge_unset_fields(RequirementSource::CloudRequirements, requirements); .merge_unset_fields(RequirementSource::CloudRequirements, requirements);
} }

View File

@@ -545,7 +545,7 @@ async fn load_config_layers_includes_cloud_requirements() -> anyhow::Result<()>
enforce_residency: None, enforce_residency: None,
}; };
let expected = requirements.clone(); let expected = requirements.clone();
let cloud_requirements = CloudRequirementsLoader::new(async move { Some(requirements) }); let cloud_requirements = CloudRequirementsLoader::new(async move { Ok(Some(requirements)) });
let layers = load_config_layers_state( let layers = load_config_layers_state(
&codex_home, &codex_home,