This commit is contained in:
Ahmed Ibrahim
2025-12-18 20:03:11 -08:00
parent ecff4d4f72
commit 359142f22f

View File

@@ -2411,19 +2411,12 @@ async fn run_turn(
let mut retries = 0;
loop {
let model_supports_parallel = turn_context
.client
.get_model_family()
.supports_parallel_tool_calls;
let prompt = Prompt {
input: input.clone(),
tools: router.specs(),
parallel_tool_calls: model_supports_parallel
&& sess.enabled(Feature::ParallelToolCalls),
base_instructions_override: turn_context.base_instructions.clone(),
output_schema: turn_context.final_output_json_schema.clone(),
};
let prompt = build_prompt(
sess.as_ref(),
turn_context.as_ref(),
router.as_ref(),
&input,
);
match try_run_turn(
Arc::clone(&router),
@@ -2460,38 +2453,15 @@ async fn run_turn(
Err(e @ CodexErr::InvalidRequest(_)) => return Err(e),
Err(e @ CodexErr::RefreshTokenFailed(_)) => return Err(e),
Err(e) => {
// Refresh models if we got an outdated models error
if matches!(e, CodexErr::OutdatedModels) {
let config = {
let state = sess.state.lock().await;
state
.session_configuration
.original_config_do_not_use
.clone()
};
if let Err(err) = sess
.services
.models_manager
.refresh_available_models(&config)
.await
{
error!("failed to refresh models after outdated models error: {err}");
}
let models_etag = sess.services.models_manager.get_models_etag().await;
let model = turn_context.client.get_model();
let model_family = sess
.services
.models_manager
.construct_model_family(&model, &config)
.await;
turn_context.client.update_model_family(model_family);
turn_context.client.update_models_etag(models_etag);
}
// Use the configured provider-specific stream retry budget.
let max_retries = turn_context.client.get_provider().stream_max_retries();
if retries < max_retries {
retries += 1;
// Refresh models if we got an outdated models error
if matches!(e, CodexErr::OutdatedModels) {
refresh_models_after_outdated_error(sess.as_ref(), turn_context.as_ref())
.await;
}
let delay = match e {
CodexErr::Stream(_, Some(delay)) => delay,
_ => backoff(retries),
@@ -2519,6 +2489,53 @@ async fn run_turn(
}
}
fn build_prompt(
sess: &Session,
turn_context: &TurnContext,
router: &ToolRouter,
input: &[ResponseItem],
) -> Prompt {
let model_supports_parallel = turn_context
.client
.get_model_family()
.supports_parallel_tool_calls;
Prompt {
input: input.to_vec(),
tools: router.specs(),
parallel_tool_calls: model_supports_parallel && sess.enabled(Feature::ParallelToolCalls),
base_instructions_override: turn_context.base_instructions.clone(),
output_schema: turn_context.final_output_json_schema.clone(),
}
}
async fn refresh_models_after_outdated_error(sess: &Session, turn_context: &TurnContext) {
let config = {
let state = sess.state.lock().await;
state
.session_configuration
.original_config_do_not_use
.clone()
};
if let Err(err) = sess
.services
.models_manager
.refresh_available_models(&config)
.await
{
error!("failed to refresh models after outdated models error: {err}");
}
let models_etag = sess.services.models_manager.get_models_etag().await;
let model = turn_context.client.get_model();
let model_family = sess
.services
.models_manager
.construct_model_family(&model, &config)
.await;
turn_context.client.update_model_family(model_family);
turn_context.client.update_models_etag(models_etag);
}
#[derive(Debug)]
struct TurnRunResult {
needs_follow_up: bool,