This commit is contained in:
jif-oai
2025-11-11 14:45:06 +00:00
parent 5fc0c39386
commit f6494aa85c
22 changed files with 703 additions and 1102 deletions

View File

@@ -12,17 +12,15 @@ use tokio::sync::mpsc;
use tracing::debug;
use tracing::trace;
use crate::api::ApiClient;
use crate::api::PayloadClient;
use crate::auth::AuthProvider;
use crate::client::PayloadBuilder;
use crate::common::backoff;
use crate::decode::responses::ErrorResponse;
use crate::error::Error;
use crate::error::Result;
use crate::model_provider::ModelProviderInfo;
use crate::prompt::Prompt;
use crate::stream::ResponseEvent;
use crate::stream::ResponseStream;
use codex_provider_config::ModelProviderInfo;
#[derive(Clone)]
/// Configuration for the OpenAI Responses API client (`/v1/responses`).
@@ -47,62 +45,39 @@ pub struct ResponsesApiClient {
config: ResponsesApiClientConfig,
}
#[async_trait]
impl ApiClient for ResponsesApiClient {
impl PayloadClient for ResponsesApiClient {
type Config = ResponsesApiClientConfig;
fn new(config: Self::Config) -> Result<Self> {
Ok(Self { config })
}
async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
if self.config.provider.wire_api != crate::model_provider::WireApi::Responses {
async fn stream_payload(
&self,
payload_json: &Value,
session_source: Option<&codex_protocol::protocol::SessionSource>,
) -> Result<ResponseStream> {
if self.config.provider.wire_api != codex_provider_config::WireApi::Responses {
return Err(Error::UnsupportedOperation(
"ResponsesApiClient requires a Responses provider".to_string(),
));
}
let payload_json = crate::payload::responses::ResponsesPayloadBuilder::new(
self.config.model.clone(),
self.config.conversation_id,
self.config.provider.is_azure_responses_endpoint(),
)
.build(prompt)?;
let max_attempts = self.config.provider.request_max_retries();
for attempt in 0..=max_attempts {
match self
.attempt_stream_responses(attempt, prompt, &payload_json)
.await
{
Ok(stream) => return Ok(stream),
Err(StreamAttemptError::Fatal(err)) => return Err(err),
Err(retryable) => {
if attempt == max_attempts {
return Err(retryable.into_error());
}
tokio::time::sleep(retryable.delay(attempt)).await;
}
}
}
unreachable!("attempt_stream_responses should always return");
}
}
impl ResponsesApiClient {
async fn attempt_stream_responses(
&self,
attempt: i64,
prompt: &Prompt,
payload_json: &Value,
) -> std::result::Result<ResponseStream, StreamAttemptError> {
let auth = crate::client::http::resolve_auth(&self.config.auth_provider).await;
trace!(
"POST to {}: {:?}",
self.config.provider.get_full_url(auth.as_ref()),
self.config.provider.get_full_url(
auth.as_ref()
.map(|a| codex_provider_config::AuthContext {
mode: a.mode,
bearer_token: a.bearer_token.clone(),
account_id: a.account_id.clone(),
})
.as_ref()
),
serde_json::to_string(payload_json)
.unwrap_or_else(|_| "<unable to serialize payload>".to_string())
);
@@ -115,11 +90,10 @@ impl ResponsesApiClient {
&self.config.http_client,
&self.config.provider,
&auth,
prompt.session_source.as_ref(),
session_source,
&extra_headers,
)
.await
.map_err(StreamAttemptError::Fatal)?;
.await?;
req_builder = req_builder
.header(reqwest::header::ACCEPT, "text/event-stream")
@@ -135,152 +109,40 @@ impl ResponsesApiClient {
let res = self
.config
.otel_event_manager
.log_request(attempt as u64, || req_builder.send())
.await;
.log_request(0, || req_builder.send())
.await
.map_err(|source| Error::ResponseStreamFailed {
source,
request_id: None,
})?;
let mut request_id = None;
if let Ok(resp) = &res {
request_id = resp
.headers()
.get("cf-ray")
.and_then(|v| v.to_str().ok())
.map(std::string::ToString::to_string);
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
if let Some(snapshot) = crate::client::rate_limits::parse_rate_limit_snapshot(res.headers())
&& tx_event
.send(Ok(ResponseEvent::RateLimits(snapshot)))
.await
.is_err()
{
debug!("receiver dropped rate limit snapshot event");
}
match res {
Ok(resp) if resp.status().is_success() => {
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
let stream = res
.bytes_stream()
.map_err(|err| Error::ResponseStreamFailed {
source: err,
request_id: None,
});
let idle_timeout = self.config.provider.stream_idle_timeout();
let otel = self.config.otel_event_manager.clone();
tokio::spawn(crate::client::sse::process_sse(
stream,
tx_event,
idle_timeout,
otel,
crate::decode::responses::ResponsesSseDecoder,
));
if let Some(snapshot) =
crate::client::rate_limits::parse_rate_limit_snapshot(resp.headers())
&& tx_event
.send(Ok(ResponseEvent::RateLimits(snapshot)))
.await
.is_err()
{
debug!("receiver dropped rate limit snapshot event");
}
let stream = resp
.bytes_stream()
.map_err(move |err| Error::ResponseStreamFailed {
source: err,
request_id: request_id.clone(),
});
let idle_timeout = self.config.provider.stream_idle_timeout();
let otel = self.config.otel_event_manager.clone();
tokio::spawn(crate::client::sse::process_sse(
stream,
tx_event,
idle_timeout,
otel,
crate::decode::responses::ResponsesSseDecoder,
));
Ok(ResponseStream { rx_event })
}
Ok(resp) => Err(handle_error_response(resp, request_id, &self.config).await),
Err(err) => Err(StreamAttemptError::RetryableTransportError(Error::Http(
err,
))),
}
Ok(crate::stream::EventStream::from_receiver(rx_event))
}
}
// Payload building is provided by `crate::payload::responses`.
async fn handle_error_response(
resp: reqwest::Response,
request_id: Option<String>,
_config: &ResponsesApiClientConfig,
) -> StreamAttemptError {
let status = resp.status();
let retry_after_secs = resp
.headers()
.get(reqwest::header::RETRY_AFTER)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<i64>().ok());
let retry_after = retry_after_secs.map(|secs| {
let clamped = if secs < 0 { 0 } else { secs as u64 };
Duration::from_secs(clamped)
});
if !(status == StatusCode::TOO_MANY_REQUESTS
|| status == StatusCode::UNAUTHORIZED
|| status.is_server_error())
{
let body = resp.text().await.unwrap_or_default();
return StreamAttemptError::Fatal(Error::UnexpectedStatus { status, body });
}
if status == StatusCode::TOO_MANY_REQUESTS {
let rate_limits = crate::client::rate_limits::parse_rate_limit_snapshot(resp.headers());
let body = resp.json::<ErrorResponse>().await.ok();
if let Some(ErrorResponse { error }) = body {
if error.r#type.as_deref() == Some("usage_limit_reached") {
return StreamAttemptError::Fatal(Error::UsageLimitReached {
plan_type: error.plan_type,
resets_at: error.resets_at,
rate_limits,
});
} else if error.r#type.as_deref() == Some("usage_not_included") {
return StreamAttemptError::Fatal(Error::Stream(
"usage not included".to_string(),
None,
));
} else if crate::decode::responses::is_quota_exceeded_error(&error) {
return StreamAttemptError::Fatal(Error::Stream(
"quota exceeded".to_string(),
None,
));
}
}
}
StreamAttemptError::RetryableHttpError {
status,
retry_after,
request_id,
}
}
/// Outcome of a single failed request attempt, split by whether it may be
/// retried.
enum StreamAttemptError {
/// The server answered with an HTTP status that is eligible for a retry.
RetryableHttpError {
/// Status code of the failed response.
status: StatusCode,
/// Delay to use before the next attempt, when one was supplied; `delay`
/// falls back to `backoff` when this is `None`.
retry_after: Option<Duration>,
/// Request identifier forwarded into `Error::RetryLimit` by `into_error`.
request_id: Option<String>,
},
/// A retry-eligible failure that carries an `Error` instead of an HTTP
/// status; `delay` always uses `backoff` for it.
RetryableTransportError(Error),
/// A failure that should not be retried; `delay` reports zero for it and
/// `into_error` returns the wrapped error unchanged.
Fatal(Error),
}
impl StreamAttemptError {
    /// Returns how long to wait before retrying after this failure.
    ///
    /// An HTTP failure with an explicit `retry_after` uses that value. Other
    /// retryable failures use `backoff(attempt)`. Fatal errors report a
    /// zero-length delay.
    fn delay(&self, attempt: i64) -> Duration {
        match self {
            Self::Fatal(_) => Duration::from_millis(0),
            Self::RetryableTransportError(_) => backoff(attempt),
            Self::RetryableHttpError { retry_after, .. } => {
                (*retry_after).unwrap_or_else(|| backoff(attempt))
            }
        }
    }

    /// Converts this attempt failure into the crate-level `Error`.
    ///
    /// Retryable HTTP failures become `Error::RetryLimit` carrying the status
    /// and request id; the other variants return their wrapped error as-is.
    fn into_error(self) -> Error {
        match self {
            Self::Fatal(inner) | Self::RetryableTransportError(inner) => inner,
            Self::RetryableHttpError {
                status,
                request_id,
                retry_after: _,
            } => Error::RetryLimit {
                status: Some(status),
                request_id,
            },
        }
    }
}