realtime: use codex tool handoff and fix audio playback

This commit is contained in:
Ahmed Ibrahim
2026-03-04 16:55:47 -08:00
parent 8a59386273
commit f13917d50e
8 changed files with 667 additions and 200 deletions

View File

@@ -2,8 +2,10 @@ pub use codex_protocol::protocol::RealtimeAudioFrame;
pub use codex_protocol::protocol::RealtimeEvent;
pub use codex_protocol::protocol::RealtimeHandoffMessage;
pub use codex_protocol::protocol::RealtimeHandoffRequested;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
use std::string::ToString;
use tracing::debug;
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -18,11 +20,8 @@ pub struct RealtimeSessionConfig {
pub(super) enum RealtimeOutboundMessage {
#[serde(rename = "input_audio_buffer.append")]
InputAudioBufferAppend { audio: String },
#[serde(rename = "conversation.handoff.append")]
ConversationHandoffAppend {
handoff_id: String,
output_text: String,
},
#[serde(rename = "response.create")]
ResponseCreate,
#[serde(rename = "session.update")]
SessionUpdate { session: SessionUpdateSession },
#[serde(rename = "conversation.item.create")]
@@ -34,7 +33,10 @@ pub(super) struct SessionUpdateSession {
#[serde(rename = "type")]
pub(super) kind: String,
pub(super) instructions: String,
pub(super) output_modalities: Vec<String>,
pub(super) audio: SessionAudio,
pub(super) tools: Vec<SessionTool>,
pub(super) tool_choice: String,
}
#[derive(Debug, Clone, Serialize)]
@@ -57,15 +59,55 @@ pub(super) struct SessionAudioFormat {
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionAudioOutput {
pub(super) format: SessionAudioOutputFormat,
pub(super) voice: String,
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct ConversationItem {
pub(super) struct SessionAudioOutputFormat {
#[serde(rename = "type")]
pub(super) kind: String,
pub(super) role: String,
pub(super) content: Vec<ConversationItemContent>,
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionTool {
#[serde(rename = "type")]
pub(super) kind: String,
pub(super) name: String,
pub(super) description: String,
pub(super) parameters: SessionToolParameters,
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionToolParameters {
#[serde(rename = "type")]
pub(super) kind: String,
pub(super) properties: SessionToolProperties,
pub(super) required: Vec<String>,
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionToolProperties {
pub(super) prompt: SessionToolProperty,
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionToolProperty {
#[serde(rename = "type")]
pub(super) kind: String,
pub(super) description: String,
}
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type")]
pub(super) enum ConversationItem {
#[serde(rename = "message")]
Message {
role: String,
content: Vec<ConversationItemContent>,
},
#[serde(rename = "function_call_output")]
FunctionCallOutput { call_id: String, output: String },
}
#[derive(Debug, Clone, Serialize)]
@@ -92,7 +134,7 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
}
};
match message_type {
"session.updated" => {
"session.created" | "session.updated" => {
let session_id = parsed
.get("session")
.and_then(Value::as_object)
@@ -110,7 +152,7 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
instructions,
})
}
"conversation.output_audio.delta" => {
"response.output_audio.delta" => {
let data = parsed
.get("delta")
.and_then(Value::as_str)
@@ -119,12 +161,14 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
let sample_rate = parsed
.get("sample_rate")
.and_then(Value::as_u64)
.and_then(|v| u32::try_from(v).ok())?;
.and_then(|v| u32::try_from(v).ok())
.unwrap_or(24_000);
let num_channels = parsed
.get("channels")
.or_else(|| parsed.get("num_channels"))
.and_then(Value::as_u64)
.and_then(|v| u16::try_from(v).ok())?;
.and_then(|v| u16::try_from(v).ok())
.unwrap_or(1);
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
data,
sample_rate,
@@ -146,35 +190,11 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
.and_then(Value::as_str)
.map(str::to_string)
.map(|item_id| RealtimeEvent::ConversationItemDone { item_id }),
"conversation.handoff.requested" => {
let handoff_id = parsed
.get("handoff_id")
.and_then(Value::as_str)
.map(str::to_string)?;
let item_id = parsed
.get("item_id")
.and_then(Value::as_str)
.map(str::to_string)?;
let input_transcript = parsed
.get("input_transcript")
.and_then(Value::as_str)
.map(str::to_string)?;
let messages = parsed
.get("messages")
.and_then(Value::as_array)?
.iter()
.filter_map(|message| {
let role = message.get("role").and_then(Value::as_str)?.to_string();
let text = message.get("text").and_then(Value::as_str)?.to_string();
Some(RealtimeHandoffMessage { role, text })
})
.collect();
Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
handoff_id,
item_id,
input_transcript,
messages,
}))
"response.done" => {
if let Some(handoff) = parse_handoff_requested(&parsed) {
return Some(RealtimeEvent::HandoffRequested(handoff));
}
Some(RealtimeEvent::ConversationItemAdded(parsed))
}
"error" => parsed
.get("message")
@@ -188,11 +208,100 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
.and_then(Value::as_str)
.map(str::to_string)
})
.or_else(|| parsed.get("error").map(std::string::ToString::to_string))
.or_else(|| parsed.get("error").map(ToString::to_string))
.map(RealtimeEvent::Error),
_ => {
debug!("received unsupported realtime event type: {message_type}, data: {payload}");
None
}
_ => Some(RealtimeEvent::ConversationItemAdded(parsed)),
}
}
fn parse_handoff_requested(parsed: &Value) -> Option<RealtimeHandoffRequested> {
let outputs = parsed
.get("response")
.and_then(Value::as_object)
.and_then(|response| response.get("output"))
.and_then(Value::as_array)?;
let function_call = outputs.iter().find(|item| {
item.get("type").and_then(Value::as_str) == Some("function_call")
&& item.get("name").and_then(Value::as_str) == Some("codex")
})?;
let handoff_id = function_call
.get("call_id")
.and_then(Value::as_str)
.map(str::to_string)?;
let item_id = function_call
.get("id")
.and_then(Value::as_str)
.map(str::to_string)
.unwrap_or_else(|| handoff_id.clone());
let arguments = function_call
.get("arguments")
.and_then(Value::as_str)
.unwrap_or_default();
let (input_transcript, messages) = parse_handoff_arguments(arguments);
Some(RealtimeHandoffRequested {
handoff_id,
item_id,
input_transcript,
messages,
})
}
fn parse_handoff_arguments(arguments: &str) -> (String, Vec<RealtimeHandoffMessage>) {
#[derive(Debug, Deserialize)]
struct HandoffArguments {
#[serde(default)]
prompt: Option<String>,
#[serde(default)]
text: Option<String>,
#[serde(default)]
input: Option<String>,
#[serde(default)]
message: Option<String>,
#[serde(default)]
input_transcript: Option<String>,
#[serde(default)]
messages: Vec<RealtimeHandoffMessage>,
}
let Some(parsed) = serde_json::from_str::<HandoffArguments>(arguments).ok() else {
return (
arguments.to_string(),
vec![RealtimeHandoffMessage {
role: "user".to_string(),
text: arguments.to_string(),
}],
);
};
let messages = parsed
.messages
.into_iter()
.filter(|message| !message.text.is_empty())
.collect::<Vec<_>>();
for value in [
parsed.prompt,
parsed.text,
parsed.input,
parsed.message,
parsed.input_transcript,
]
.into_iter()
.flatten()
{
if !value.is_empty() {
if messages.is_empty() {
return (
value.clone(),
vec![RealtimeHandoffMessage {
role: "user".to_string(),
text: value,
}],
);
}
return (value, messages);
}
}
if let Some(first_message) = messages.first() {
return (first_message.text.clone(), messages);
}
(String::new(), messages)
}