load models from disk and set a ttl and etag (#7722)

# External (non-OpenAI) Pull Request Requirements

Before opening this Pull Request, please read the dedicated
"Contributing" markdown file or your PR may be closed:
https://github.com/openai/codex/blob/main/docs/contributing.md

If your PR conforms to our contribution guidelines, replace this text
with a detailed and high quality description of your changes.

Include a link to a bug report or enhancement request.
This commit is contained in:
Ahmed Ibrahim
2025-12-08 13:43:04 -08:00
committed by GitHub
parent 4a3e9ed88d
commit 222a491570
13 changed files with 414 additions and 70 deletions

View File

@@ -8,6 +8,7 @@ use codex_client::RequestTelemetry;
use codex_protocol::openai_models::ModelsResponse;
use http::HeaderMap;
use http::Method;
use http::header::ETAG;
use std::sync::Arc;
pub struct ModelsClient<T: HttpTransport, A: AuthProvider> {
@@ -59,12 +60,23 @@ impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
)
.await?;
serde_json::from_slice::<ModelsResponse>(&resp.body).map_err(|e| {
ApiError::Stream(format!(
"failed to decode models response: {e}; body: {}",
String::from_utf8_lossy(&resp.body)
))
})
let header_etag = resp
.headers
.get(ETAG)
.and_then(|value| value.to_str().ok())
.map(ToString::to_string);
let ModelsResponse { models, etag } = serde_json::from_slice::<ModelsResponse>(&resp.body)
.map_err(|e| {
ApiError::Stream(format!(
"failed to decode models response: {e}; body: {}",
String::from_utf8_lossy(&resp.body)
))
})?;
let etag = header_etag.unwrap_or(etag);
Ok(ModelsResponse { models, etag })
}
}
@@ -86,20 +98,36 @@ mod tests {
use std::sync::Mutex;
use std::time::Duration;
#[derive(Clone, Default)]
#[derive(Clone)]
struct CapturingTransport {
last_request: Arc<Mutex<Option<Request>>>,
body: Arc<ModelsResponse>,
}
impl Default for CapturingTransport {
fn default() -> Self {
Self {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(ModelsResponse {
models: Vec::new(),
etag: String::new(),
}),
}
}
}
#[async_trait]
impl HttpTransport for CapturingTransport {
async fn execute(&self, req: Request) -> Result<Response, TransportError> {
*self.last_request.lock().unwrap() = Some(req);
let body = serde_json::to_vec(&*self.body).unwrap();
let mut headers = HeaderMap::new();
if !self.body.etag.is_empty() {
headers.insert(ETAG, self.body.etag.parse().unwrap());
}
Ok(Response {
status: StatusCode::OK,
headers: HeaderMap::new(),
headers,
body: body.into(),
})
}
@@ -138,7 +166,10 @@ mod tests {
#[tokio::test]
async fn appends_client_version_query() {
let response = ModelsResponse { models: Vec::new() };
let response = ModelsResponse {
models: Vec::new(),
etag: String::new(),
};
let transport = CapturingTransport {
last_request: Arc::new(Mutex::new(None)),
@@ -191,6 +222,7 @@ mod tests {
}))
.unwrap(),
],
etag: String::new(),
};
let transport = CapturingTransport {
@@ -214,4 +246,31 @@ mod tests {
assert_eq!(result.models[0].supported_in_api, true);
assert_eq!(result.models[0].priority, 1);
}
#[tokio::test]
async fn list_models_includes_etag() {
let response = ModelsResponse {
models: Vec::new(),
etag: "\"abc\"".to_string(),
};
let transport = CapturingTransport {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(response),
};
let client = ModelsClient::new(
transport,
provider("https://example.com/api/codex"),
DummyAuth,
);
let result = client
.list_models("0.1.0", HeaderMap::new())
.await
.expect("request should succeed");
assert_eq!(result.models.len(), 0);
assert_eq!(result.etag, "\"abc\"");
}
}