Compare commits

...

5 Commits

Author SHA1 Message Date
Ahmed Ibrahim
374bff8cc0 codex: fix realtime bazel build 2026-03-04 21:17:04 -08:00
Ahmed Ibrahim
ef5a45e585 codex: format realtime websocket and audio code 2026-03-04 20:07:50 -08:00
Ahmed Ibrahim
a5420779c4 realtime: disable output interruption from turn detection 2026-03-04 17:14:23 -08:00
Ahmed Ibrahim
159bda93c6 realtime: fix output format rate and smooth playback buffer 2026-03-04 17:09:40 -08:00
Ahmed Ibrahim
f13917d50e realtime: use codex tool handoff and fix audio playback 2026-03-04 16:55:47 -08:00
9 changed files with 785 additions and 215 deletions

View File

@@ -50,7 +50,7 @@ async fn realtime_conversation_streams_v2_notifications() -> Result<()> {
vec![],
vec![
json!({
"type": "conversation.output_audio.delta",
"type": "response.output_audio.delta",
"delta": "AQID",
"sample_rate": 24_000,
"channels": 1,

View File

@@ -8,6 +8,12 @@ use crate::endpoint::realtime_websocket::protocol::SessionAudio;
use crate::endpoint::realtime_websocket::protocol::SessionAudioFormat;
use crate::endpoint::realtime_websocket::protocol::SessionAudioInput;
use crate::endpoint::realtime_websocket::protocol::SessionAudioOutput;
use crate::endpoint::realtime_websocket::protocol::SessionAudioOutputFormat;
use crate::endpoint::realtime_websocket::protocol::SessionTool;
use crate::endpoint::realtime_websocket::protocol::SessionToolParameters;
use crate::endpoint::realtime_websocket::protocol::SessionToolProperties;
use crate::endpoint::realtime_websocket::protocol::SessionToolProperty;
use crate::endpoint::realtime_websocket::protocol::SessionTurnDetection;
use crate::endpoint::realtime_websocket::protocol::SessionUpdateSession;
use crate::endpoint::realtime_websocket::protocol::parse_realtime_event;
use crate::error::ApiError;
@@ -17,6 +23,7 @@ use futures::SinkExt;
use futures::StreamExt;
use http::HeaderMap;
use http::HeaderValue;
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
@@ -220,6 +227,20 @@ impl RealtimeWebsocketConnection {
.await
}
/// Sends the output of the function call identified by `call_id` over this connection.
///
/// Both arguments are forwarded unchanged to the writer's
/// `send_function_call_output`, and its result is returned as-is.
pub async fn send_function_call_output(
&self,
call_id: String,
output_text: String,
) -> Result<(), ApiError> {
self.writer
.send_function_call_output(call_id, output_text)
.await
}
/// Forwards to the writer's `send_response_create` and returns its result.
pub async fn send_response_create(&self) -> Result<(), ApiError> {
self.writer.send_response_create().await
}
/// Forwards to the writer's `close` and returns its result.
pub async fn close(&self) -> Result<(), ApiError> {
self.writer.close().await
}
@@ -263,11 +284,10 @@ impl RealtimeWebsocketWriter {
pub async fn send_conversation_item_create(&self, text: String) -> Result<(), ApiError> {
self.send_json(RealtimeOutboundMessage::ConversationItemCreate {
item: ConversationItem {
kind: "message".to_string(),
item: ConversationItem::Message {
role: "user".to_string(),
content: vec![ConversationItemContent {
kind: "text".to_string(),
kind: "input_text".to_string(),
text,
}],
},
@@ -277,32 +297,85 @@ impl RealtimeWebsocketWriter {
pub async fn send_conversation_handoff_append(
&self,
handoff_id: String,
_handoff_id: String,
output_text: String,
) -> Result<(), ApiError> {
self.send_json(RealtimeOutboundMessage::ConversationHandoffAppend {
handoff_id,
output_text,
self.send_json(RealtimeOutboundMessage::ConversationItemCreate {
item: ConversationItem::Message {
role: "assistant".to_string(),
content: vec![ConversationItemContent {
kind: "output_text".to_string(),
text: output_text,
}],
},
})
.await
}
/// Sends a `function_call_output` conversation item for `call_id`.
///
/// `output_text` is wrapped in a JSON object under the `content` key; that
/// object is serialized to a string and used as the item's `output`.
pub async fn send_function_call_output(
    &self,
    call_id: String,
    output_text: String,
) -> Result<(), ApiError> {
    let payload = json!({ "content": output_text });
    let item = ConversationItem::FunctionCallOutput {
        call_id,
        output: payload.to_string(),
    };
    let message = RealtimeOutboundMessage::ConversationItemCreate { item };
    self.send_json(message).await
}
/// Sends the `RealtimeOutboundMessage::ResponseCreate` message, which carries no fields.
pub async fn send_response_create(&self) -> Result<(), ApiError> {
self.send_json(RealtimeOutboundMessage::ResponseCreate)
.await
}
pub async fn send_session_update(&self, instructions: String) -> Result<(), ApiError> {
self.send_json(RealtimeOutboundMessage::SessionUpdate {
session: SessionUpdateSession {
kind: "quicksilver".to_string(),
kind: "realtime".to_string(),
instructions,
output_modalities: vec!["audio".to_string()],
audio: SessionAudio {
input: SessionAudioInput {
format: SessionAudioFormat {
kind: "audio/pcm".to_string(),
rate: 24_000,
},
turn_detection: SessionTurnDetection {
kind: "semantic_vad".to_string(),
interrupt_response: false,
create_response: true,
},
},
output: SessionAudioOutput {
voice: "mundo".to_string(),
format: SessionAudioOutputFormat {
kind: "audio/pcm".to_string(),
rate: 24_000,
},
voice: "marin".to_string(),
},
},
tools: vec![SessionTool {
kind: "function".to_string(),
name: "codex".to_string(),
description:
"Delegate a request to Codex and return the final result to the user."
.to_string(),
parameters: SessionToolParameters {
kind: "object".to_string(),
properties: SessionToolProperties {
prompt: SessionToolProperty {
kind: "string".to_string(),
description: "The user request to delegate to Codex.".to_string(),
},
},
required: vec!["prompt".to_string()],
},
}],
tool_choice: "auto".to_string(),
},
})
.await
@@ -510,15 +583,16 @@ fn websocket_url_from_api_url(
}
}
{
let has_additional_query_params = query_params
.is_some_and(|params| params.keys().any(|key| key != "model" || model.is_none()));
if model.is_some() || has_additional_query_params {
let mut query = url.query_pairs_mut();
query.append_pair("intent", "quicksilver");
if let Some(model) = model {
query.append_pair("model", model);
}
if let Some(query_params) = query_params {
for (key, value) in query_params {
if key == "intent" || (key == "model" && model.is_some()) {
if key == "model" && model.is_some() {
continue;
}
query.append_pair(key, value);
@@ -590,7 +664,7 @@ mod tests {
#[test]
fn parse_audio_delta_event() {
let payload = json!({
"type": "conversation.output_audio.delta",
"type": "response.output_audio.delta",
"delta": "AAA=",
"sample_rate": 48000,
"channels": 1,
@@ -608,6 +682,25 @@ mod tests {
);
}
#[test]
fn parse_audio_delta_event_defaults_audio_shape() {
    // The event deliberately omits `sample_rate` and `channels`, so the
    // parsed frame must carry the parser's fallback values for both.
    let raw_event = json!({
        "type": "response.output_audio.delta",
        "delta": "AAA="
    })
    .to_string();

    let expected_frame = RealtimeAudioFrame {
        data: "AAA=".to_string(),
        sample_rate: 24_000,
        num_channels: 1,
        samples_per_channel: None,
    };
    let actual = parse_realtime_event(&raw_event);

    assert_eq!(actual, Some(RealtimeEvent::AudioOut(expected_frame)));
}
#[test]
fn parse_conversation_item_added_event() {
let payload = json!({
@@ -641,13 +734,18 @@ mod tests {
#[test]
fn parse_handoff_requested_event() {
let payload = json!({
"type": "conversation.handoff.requested",
"handoff_id": "handoff_123",
"item_id": "item_123",
"input_transcript": "delegate this",
"messages": [
{"role": "user", "text": "delegate this"}
]
"type": "response.done",
"response": {
"output": [
{
"id": "item_123",
"type": "function_call",
"name": "codex",
"call_id": "handoff_123",
"arguments": "{\"prompt\":\"delegate this\"}"
}
]
}
})
.to_string();
@@ -665,6 +763,24 @@ mod tests {
);
}
#[test]
fn parse_unknown_event_as_conversation_item_added() {
    // The whole JSON payload is expected back, wrapped in `ConversationItemAdded`.
    let event = json!({
        "type": "response.output_text.delta",
        "delta": "hello",
        "response_id": "resp_1",
    });

    let parsed = parse_realtime_event(&event.to_string());

    assert_eq!(parsed, Some(RealtimeEvent::ConversationItemAdded(event)));
}
#[test]
fn merge_request_headers_matches_http_precedence() {
let mut provider_headers = HeaderMap::new();
@@ -702,10 +818,7 @@ mod tests {
fn websocket_url_from_http_base_defaults_to_ws_path() {
let url =
websocket_url_from_api_url("http://127.0.0.1:8011", None, None).expect("build ws url");
assert_eq!(
url.as_str(),
"ws://127.0.0.1:8011/v1/realtime?intent=quicksilver"
);
assert_eq!(url.as_str(), "ws://127.0.0.1:8011/v1/realtime");
}
#[test]
@@ -715,7 +828,7 @@ mod tests {
.expect("build ws url");
assert_eq!(
url.as_str(),
"wss://example.com/v1/realtime?intent=quicksilver&model=realtime-test-model"
"wss://example.com/v1/realtime?model=realtime-test-model"
);
}
@@ -725,7 +838,7 @@ mod tests {
.expect("build ws url");
assert_eq!(
url.as_str(),
"wss://api.openai.com/v1/realtime?intent=quicksilver&model=snapshot"
"wss://api.openai.com/v1/realtime?model=snapshot"
);
}
@@ -736,7 +849,7 @@ mod tests {
.expect("build ws url");
assert_eq!(
url.as_str(),
"wss://example.com/openai/v1/realtime?intent=quicksilver&model=snapshot"
"wss://example.com/openai/v1/realtime?model=snapshot"
);
}
@@ -744,16 +857,13 @@ mod tests {
fn websocket_url_preserves_existing_realtime_path_and_extra_query_params() {
let url = websocket_url_from_api_url(
"https://example.com/v1/realtime?foo=bar",
Some(&HashMap::from([
("trace".to_string(), "1".to_string()),
("intent".to_string(), "ignored".to_string()),
])),
Some(&HashMap::from([("trace".to_string(), "1".to_string())])),
Some("snapshot"),
)
.expect("build ws url");
assert_eq!(
url.as_str(),
"wss://example.com/v1/realtime?foo=bar&intent=quicksilver&model=snapshot&trace=1"
"wss://example.com/v1/realtime?foo=bar&model=snapshot&trace=1"
);
}
@@ -777,12 +887,16 @@ mod tests {
assert_eq!(first_json["type"], "session.update");
assert_eq!(
first_json["session"]["type"],
Value::String("quicksilver".to_string())
Value::String("realtime".to_string())
);
assert_eq!(
first_json["session"]["instructions"],
Value::String("backend prompt".to_string())
);
assert_eq!(
first_json["session"]["output_modalities"][0],
Value::String("audio".to_string())
);
assert_eq!(
first_json["session"]["audio"]["input"]["format"]["type"],
Value::String("audio/pcm".to_string())
@@ -791,9 +905,45 @@ mod tests {
first_json["session"]["audio"]["input"]["format"]["rate"],
Value::from(24_000)
);
assert_eq!(
first_json["session"]["audio"]["input"]["turn_detection"]["type"],
Value::String("semantic_vad".to_string())
);
assert_eq!(
first_json["session"]["audio"]["input"]["turn_detection"]["interrupt_response"],
Value::Bool(false)
);
assert_eq!(
first_json["session"]["audio"]["input"]["turn_detection"]["create_response"],
Value::Bool(true)
);
assert_eq!(
first_json["session"]["audio"]["output"]["format"]["type"],
Value::String("audio/pcm".to_string())
);
assert_eq!(
first_json["session"]["audio"]["output"]["format"]["rate"],
Value::from(24_000)
);
assert_eq!(
first_json["session"]["audio"]["output"]["voice"],
Value::String("mundo".to_string())
Value::String("marin".to_string())
);
assert_eq!(
first_json["session"]["tool_choice"],
Value::String("auto".to_string())
);
assert_eq!(
first_json["session"]["tools"][0]["type"],
Value::String("function".to_string())
);
assert_eq!(
first_json["session"]["tools"][0]["name"],
Value::String("codex".to_string())
);
assert_eq!(
first_json["session"]["tools"][0]["parameters"]["required"][0],
Value::String("prompt".to_string())
);
ws.send(Message::Text(
@@ -836,13 +986,43 @@ mod tests {
.into_text()
.expect("text");
let fourth_json: Value = serde_json::from_str(&fourth).expect("json");
assert_eq!(fourth_json["type"], "conversation.handoff.append");
assert_eq!(fourth_json["handoff_id"], "handoff_1");
assert_eq!(fourth_json["output_text"], "hello from codex");
assert_eq!(fourth_json["type"], "conversation.item.create");
assert_eq!(fourth_json["item"]["type"], "message");
assert_eq!(fourth_json["item"]["role"], "assistant");
assert_eq!(
fourth_json["item"]["content"][0]["type"],
Value::String("output_text".to_string())
);
assert_eq!(
fourth_json["item"]["content"][0]["text"],
Value::String("hello from codex".to_string())
);
let fifth = ws
.next()
.await
.expect("fifth msg")
.expect("fifth msg ok")
.into_text()
.expect("text");
let fifth_json: Value = serde_json::from_str(&fifth).expect("json");
assert_eq!(fifth_json["type"], "conversation.item.create");
assert_eq!(fifth_json["item"]["type"], "function_call_output");
assert_eq!(fifth_json["item"]["call_id"], "handoff_1");
let sixth = ws
.next()
.await
.expect("sixth msg")
.expect("sixth msg ok")
.into_text()
.expect("text");
let sixth_json: Value = serde_json::from_str(&sixth).expect("json");
assert_eq!(sixth_json["type"], "response.create");
ws.send(Message::Text(
json!({
"type": "conversation.output_audio.delta",
"type": "response.output_audio.delta",
"delta": "AQID",
"sample_rate": 48000,
"channels": 1
@@ -855,11 +1035,18 @@ mod tests {
ws.send(Message::Text(
json!({
"type": "conversation.handoff.requested",
"handoff_id": "handoff_1",
"item_id": "item_2",
"input_transcript": "delegate now",
"messages": [{"role": "user", "text": "delegate now"}]
"type": "response.done",
"response": {
"output": [
{
"id": "item_2",
"type": "function_call",
"name": "codex",
"call_id": "handoff_1",
"arguments": "{\"prompt\":\"delegate now\"}"
}
]
}
})
.to_string()
.into(),
@@ -929,6 +1116,14 @@ mod tests {
)
.await
.expect("send handoff");
connection
.send_function_call_output("handoff_1".to_string(), "final from codex".to_string())
.await
.expect("send function output");
connection
.send_response_create()
.await
.expect("send response.create");
let audio_event = connection
.next_event()

View File

@@ -2,8 +2,10 @@ pub use codex_protocol::protocol::RealtimeAudioFrame;
pub use codex_protocol::protocol::RealtimeEvent;
pub use codex_protocol::protocol::RealtimeHandoffMessage;
pub use codex_protocol::protocol::RealtimeHandoffRequested;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
use std::string::ToString;
use tracing::debug;
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -18,11 +20,8 @@ pub struct RealtimeSessionConfig {
pub(super) enum RealtimeOutboundMessage {
#[serde(rename = "input_audio_buffer.append")]
InputAudioBufferAppend { audio: String },
#[serde(rename = "conversation.handoff.append")]
ConversationHandoffAppend {
handoff_id: String,
output_text: String,
},
#[serde(rename = "response.create")]
ResponseCreate,
#[serde(rename = "session.update")]
SessionUpdate { session: SessionUpdateSession },
#[serde(rename = "conversation.item.create")]
@@ -34,7 +33,10 @@ pub(super) struct SessionUpdateSession {
#[serde(rename = "type")]
pub(super) kind: String,
pub(super) instructions: String,
pub(super) output_modalities: Vec<String>,
pub(super) audio: SessionAudio,
pub(super) tools: Vec<SessionTool>,
pub(super) tool_choice: String,
}
#[derive(Debug, Clone, Serialize)]
@@ -46,6 +48,7 @@ pub(super) struct SessionAudio {
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionAudioInput {
pub(super) format: SessionAudioFormat,
pub(super) turn_detection: SessionTurnDetection,
}
#[derive(Debug, Clone, Serialize)]
@@ -55,17 +58,66 @@ pub(super) struct SessionAudioFormat {
pub(super) rate: u32,
}
/// Turn-detection settings, serialized as part of the session's audio input configuration.
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionTurnDetection {
// Emitted under the JSON key `type`.
#[serde(rename = "type")]
pub(super) kind: String,
pub(super) interrupt_response: bool,
pub(super) create_response: bool,
}
/// Audio output settings of a session: the output format and the voice name.
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionAudioOutput {
pub(super) format: SessionAudioOutputFormat,
pub(super) voice: String,
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct ConversationItem {
pub(super) struct SessionAudioOutputFormat {
#[serde(rename = "type")]
pub(super) kind: String,
pub(super) role: String,
pub(super) content: Vec<ConversationItemContent>,
pub(super) rate: u32,
}
/// One tool definition, serialized into the `tools` list of a session update.
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionTool {
// Emitted under the JSON key `type`.
#[serde(rename = "type")]
pub(super) kind: String,
pub(super) name: String,
pub(super) description: String,
pub(super) parameters: SessionToolParameters,
}
/// Parameter schema of a `SessionTool`: a type tag, the properties, and the
/// names of the required properties.
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionToolParameters {
// Emitted under the JSON key `type`.
#[serde(rename = "type")]
pub(super) kind: String,
pub(super) properties: SessionToolProperties,
pub(super) required: Vec<String>,
}
/// The properties of a tool's parameter schema; `prompt` is the only one defined.
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionToolProperties {
pub(super) prompt: SessionToolProperty,
}
/// A single property in a tool's parameter schema: its type tag and a description.
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionToolProperty {
// Emitted under the JSON key `type`.
#[serde(rename = "type")]
pub(super) kind: String,
pub(super) description: String,
}
/// A conversation item payload.
///
/// Serialized as an internally tagged object: the `type` field holds the
/// snake_case variant name (`message` or `function_call_output`), and the
/// variant's fields appear next to it.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub(super) enum ConversationItem {
    /// A message with a role and a list of content parts.
    Message {
        role: String,
        content: Vec<ConversationItemContent>,
    },
    /// The output of a function call, tied to that call by `call_id`.
    FunctionCallOutput { call_id: String, output: String },
}
#[derive(Debug, Clone, Serialize)]
@@ -92,7 +144,7 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
}
};
match message_type {
"session.updated" => {
"session.created" | "session.updated" => {
let session_id = parsed
.get("session")
.and_then(Value::as_object)
@@ -110,7 +162,7 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
instructions,
})
}
"conversation.output_audio.delta" => {
"response.output_audio.delta" => {
let data = parsed
.get("delta")
.and_then(Value::as_str)
@@ -119,12 +171,14 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
let sample_rate = parsed
.get("sample_rate")
.and_then(Value::as_u64)
.and_then(|v| u32::try_from(v).ok())?;
.and_then(|v| u32::try_from(v).ok())
.unwrap_or(24_000);
let num_channels = parsed
.get("channels")
.or_else(|| parsed.get("num_channels"))
.and_then(Value::as_u64)
.and_then(|v| u16::try_from(v).ok())?;
.and_then(|v| u16::try_from(v).ok())
.unwrap_or(1);
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
data,
sample_rate,
@@ -146,35 +200,11 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
.and_then(Value::as_str)
.map(str::to_string)
.map(|item_id| RealtimeEvent::ConversationItemDone { item_id }),
"conversation.handoff.requested" => {
let handoff_id = parsed
.get("handoff_id")
.and_then(Value::as_str)
.map(str::to_string)?;
let item_id = parsed
.get("item_id")
.and_then(Value::as_str)
.map(str::to_string)?;
let input_transcript = parsed
.get("input_transcript")
.and_then(Value::as_str)
.map(str::to_string)?;
let messages = parsed
.get("messages")
.and_then(Value::as_array)?
.iter()
.filter_map(|message| {
let role = message.get("role").and_then(Value::as_str)?.to_string();
let text = message.get("text").and_then(Value::as_str)?.to_string();
Some(RealtimeHandoffMessage { role, text })
})
.collect();
Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
handoff_id,
item_id,
input_transcript,
messages,
}))
"response.done" => {
if let Some(handoff) = parse_handoff_requested(&parsed) {
return Some(RealtimeEvent::HandoffRequested(handoff));
}
Some(RealtimeEvent::ConversationItemAdded(parsed))
}
"error" => parsed
.get("message")
@@ -188,11 +218,100 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
.and_then(Value::as_str)
.map(str::to_string)
})
.or_else(|| parsed.get("error").map(std::string::ToString::to_string))
.or_else(|| parsed.get("error").map(ToString::to_string))
.map(RealtimeEvent::Error),
_ => {
debug!("received unsupported realtime event type: {message_type}, data: {payload}");
None
}
_ => Some(RealtimeEvent::ConversationItemAdded(parsed)),
}
}
/// Extracts a Codex handoff request from a parsed event payload.
///
/// Looks in `response.output` for the first item whose `type` is
/// `function_call` and whose `name` is `codex`. Returns `None` when that
/// array is missing, no such item exists, or the item has no string
/// `call_id`. A missing `id` falls back to the `call_id`; a missing
/// `arguments` string is treated as empty.
fn parse_handoff_requested(parsed: &Value) -> Option<RealtimeHandoffRequested> {
    let response = parsed.get("response")?.as_object()?;
    let outputs = response.get("output")?.as_array()?;

    let mut codex_call = None;
    for item in outputs {
        let kind = item.get("type").and_then(Value::as_str);
        let name = item.get("name").and_then(Value::as_str);
        if kind == Some("function_call") && name == Some("codex") {
            codex_call = Some(item);
            break;
        }
    }
    let codex_call = codex_call?;

    let handoff_id = codex_call.get("call_id")?.as_str()?.to_string();
    let item_id = match codex_call.get("id").and_then(Value::as_str) {
        Some(id) => id.to_string(),
        None => handoff_id.clone(),
    };
    let arguments = codex_call
        .get("arguments")
        .and_then(Value::as_str)
        .unwrap_or("");

    let (input_transcript, messages) = parse_handoff_arguments(arguments);
    Some(RealtimeHandoffRequested {
        handoff_id,
        item_id,
        input_transcript,
        messages,
    })
}
/// Turns the raw `arguments` string of a handoff function call into an input
/// transcript and a list of handoff messages.
///
/// * If `arguments` does not deserialize as the expected JSON object, the raw
///   string is used both as the transcript and as a single `user` message.
/// * Otherwise the transcript is the first non-empty value among `prompt`,
///   `text`, `input`, `message` and `input_transcript`, in that order. If no
///   non-empty entries were supplied in `messages`, a single `user` message
///   carrying the transcript is synthesized.
/// * If none of those fields is set, the transcript is the text of the first
///   non-empty message, or an empty string when there are no messages.
fn parse_handoff_arguments(arguments: &str) -> (String, Vec<RealtimeHandoffMessage>) {
    #[derive(Debug, Deserialize)]
    struct HandoffArguments {
        #[serde(default)]
        prompt: Option<String>,
        #[serde(default)]
        text: Option<String>,
        #[serde(default)]
        input: Option<String>,
        #[serde(default)]
        message: Option<String>,
        #[serde(default)]
        input_transcript: Option<String>,
        #[serde(default)]
        messages: Vec<RealtimeHandoffMessage>,
    }

    let user_message = |text: String| RealtimeHandoffMessage {
        role: "user".to_string(),
        text,
    };

    let decoded: HandoffArguments = match serde_json::from_str(arguments) {
        Ok(decoded) => decoded,
        Err(_) => {
            return (
                arguments.to_string(),
                vec![user_message(arguments.to_string())],
            );
        }
    };

    let HandoffArguments {
        prompt,
        text,
        input,
        message,
        input_transcript,
        mut messages,
    } = decoded;

    // Messages without text carry no information for the handoff.
    messages.retain(|entry| !entry.text.is_empty());

    // Field order here is the priority order for picking the transcript.
    let transcript = [prompt, text, input, message, input_transcript]
        .into_iter()
        .flatten()
        .find(|candidate| !candidate.is_empty());

    match transcript {
        Some(transcript) if messages.is_empty() => {
            (transcript.clone(), vec![user_message(transcript)])
        }
        Some(transcript) => (transcript, messages),
        None => {
            let transcript = messages
                .first()
                .map(|first| first.text.clone())
                .unwrap_or_default();
            (transcript, messages)
        }
    }
}

View File

@@ -81,7 +81,7 @@ async fn realtime_ws_e2e_session_create_and_event_flow() {
assert_eq!(first_json["type"], "session.update");
assert_eq!(
first_json["session"]["type"],
Value::String("quicksilver".to_string())
Value::String("realtime".to_string())
);
assert_eq!(
first_json["session"]["instructions"],
@@ -95,6 +95,30 @@ async fn realtime_ws_e2e_session_create_and_event_flow() {
first_json["session"]["audio"]["input"]["format"]["rate"],
Value::from(24_000)
);
assert_eq!(
first_json["session"]["audio"]["input"]["turn_detection"]["type"],
Value::String("semantic_vad".to_string())
);
assert_eq!(
first_json["session"]["audio"]["input"]["turn_detection"]["interrupt_response"],
Value::Bool(false)
);
assert_eq!(
first_json["session"]["audio"]["input"]["turn_detection"]["create_response"],
Value::Bool(true)
);
assert_eq!(
first_json["session"]["tool_choice"],
Value::String("auto".to_string())
);
assert_eq!(
first_json["session"]["tools"][0]["type"],
Value::String("function".to_string())
);
assert_eq!(
first_json["session"]["tools"][0]["name"],
Value::String("codex".to_string())
);
ws.send(Message::Text(
json!({
@@ -119,7 +143,7 @@ async fn realtime_ws_e2e_session_create_and_event_flow() {
ws.send(Message::Text(
json!({
"type": "conversation.output_audio.delta",
"type": "response.output_audio.delta",
"delta": "AQID",
"sample_rate": 48000,
"channels": 1
@@ -311,7 +335,7 @@ async fn realtime_ws_e2e_disconnected_emitted_once() {
}
#[tokio::test]
async fn realtime_ws_e2e_ignores_unknown_text_events() {
async fn realtime_ws_e2e_forwards_unknown_text_events() {
let (addr, server) = spawn_realtime_ws_server(|mut ws: RealtimeWsStream| async move {
let first = ws
.next()
@@ -361,13 +385,26 @@ async fn realtime_ws_e2e_ignores_unknown_text_events() {
.await
.expect("connect");
let event = connection
let first_event = connection
.next_event()
.await
.expect("next event")
.expect("event");
assert_eq!(
event,
first_event,
RealtimeEvent::ConversationItemAdded(json!({
"type": "response.created",
"response": {"id": "resp_unknown"}
}))
);
let second_event = connection
.next_event()
.await
.expect("next event")
.expect("event");
assert_eq!(
second_event,
RealtimeEvent::SessionUpdated {
session_id: "sess_after_unknown".to_string(),
instructions: Some("backend prompt".to_string()),

View File

@@ -2447,6 +2447,9 @@ impl Session {
if !matches!(msg, EventMsg::TurnComplete(_)) {
return;
}
if let Err(err) = self.conversation.handoff_complete().await {
debug!("failed to send final realtime handoff tool output: {err}");
}
self.conversation.clear_active_handoff().await;
}

View File

@@ -43,6 +43,7 @@ const AUDIO_IN_QUEUE_CAPACITY: usize = 256;
const USER_TEXT_IN_QUEUE_CAPACITY: usize = 64;
const HANDOFF_OUT_QUEUE_CAPACITY: usize = 64;
const OUTPUT_EVENTS_QUEUE_CAPACITY: usize = 256;
/// Name of the realtime model used by default.
const DEFAULT_REALTIME_MODEL: &str = "gpt-realtime-1.5";
pub(crate) struct RealtimeConversationManager {
state: Mutex<Option<ConversationState>>,
@@ -52,12 +53,19 @@ pub(crate) struct RealtimeConversationManager {
/// State for relaying handoff output to the realtime connection.
struct RealtimeHandoffState {
// Channel on which `HandoffOutput` messages are queued.
output_tx: Sender<HandoffOutput>,
// Id of the handoff currently in progress, if any.
active_handoff: Arc<Mutex<Option<String>>>,
// Most recent output text recorded for the active handoff, if any.
last_output_text: Arc<Mutex<Option<String>>>,
}
#[derive(Debug, PartialEq, Eq)]
struct HandoffOutput {
handoff_id: String,
output_text: String,
enum HandoffOutput {
TextUpdate {
handoff_id: String,
output_text: String,
},
FinalToolCall {
call_id: String,
output_text: String,
},
}
impl RealtimeHandoffState {
@@ -65,6 +73,7 @@ impl RealtimeHandoffState {
Self {
output_tx,
active_handoff: Arc::new(Mutex::new(None)),
last_output_text: Arc::new(Mutex::new(None)),
}
}
@@ -72,9 +81,10 @@ impl RealtimeHandoffState {
let Some(handoff_id) = self.active_handoff.lock().await.clone() else {
return Ok(());
};
*self.last_output_text.lock().await = Some(output_text.clone());
self.output_tx
.send(HandoffOutput {
.send(HandoffOutput::TextUpdate {
handoff_id,
output_text,
})
@@ -82,6 +92,23 @@ impl RealtimeHandoffState {
.map_err(|_| CodexErr::InvalidRequest("conversation is not running".to_string()))?;
Ok(())
}
/// Queues the final tool-call output for the currently active handoff.
///
/// Returns `Ok(())` without sending anything when no handoff is active or
/// when no output text has been recorded. Returns
/// `CodexErr::InvalidRequest` if the message cannot be queued on the
/// output channel.
async fn send_final_output(&self) -> CodexResult<()> {
    let active = self.active_handoff.lock().await.clone();
    let Some(call_id) = active else {
        return Ok(());
    };

    let latest_text = self.last_output_text.lock().await.clone();
    let Some(output_text) = latest_text else {
        return Ok(());
    };

    let final_call = HandoffOutput::FinalToolCall {
        call_id,
        output_text,
    };
    let sent = self.output_tx.send(final_call).await;
    sent.map_err(|_| CodexErr::InvalidRequest("conversation is not running".to_string()))?;
    Ok(())
}
}
#[allow(dead_code)]
@@ -232,6 +259,17 @@ impl RealtimeConversationManager {
handoff.send_output(output_text).await
}
/// Sends the final handoff output for the current conversation, if one exists.
///
/// The handoff state is cloned out while the state lock is held, and the
/// lock is released before the send. Returns `Ok(())` when there is no
/// conversation state.
pub(crate) async fn handoff_complete(&self) -> CodexResult<()> {
    // The guard is a temporary of this statement, so it is dropped before the await below.
    let handoff = self
        .state
        .lock()
        .await
        .as_ref()
        .map(|state| state.handoff.clone());

    match handoff {
        Some(handoff) => handoff.send_final_output().await,
        None => Ok(()),
    }
}
pub(crate) async fn active_handoff_id(&self) -> Option<String> {
let handoff = {
let guard = self.state.lock().await;
@@ -247,6 +285,7 @@ impl RealtimeConversationManager {
};
if let Some(handoff) = handoff {
*handoff.active_handoff.lock().await = None;
*handoff.last_output_text.lock().await = None;
}
}
@@ -282,7 +321,7 @@ pub(crate) async fn handle_start(
.experimental_realtime_ws_backend_prompt
.clone()
.unwrap_or(params.prompt);
let model = config.experimental_realtime_ws_model.clone();
let model = Some(DEFAULT_REALTIME_MODEL.to_string());
let requested_session_id = params
.session_id
@@ -489,17 +528,39 @@ fn spawn_realtime_input_task(
}
handoff_output = handoff_output_rx.recv() => {
match handoff_output {
Ok(HandoffOutput {
handoff_id,
output_text,
}) => {
if let Err(err) = writer
.send_conversation_handoff_append(handoff_id, output_text)
.await
{
let mapped_error = map_api_error(err);
warn!("failed to send handoff output: {mapped_error}");
break;
Ok(handoff_output) => {
match handoff_output {
HandoffOutput::TextUpdate {
handoff_id,
output_text,
} => {
if let Err(err) = writer
.send_conversation_handoff_append(handoff_id, output_text)
.await
{
let mapped_error = map_api_error(err);
warn!("failed to send handoff output: {mapped_error}");
break;
}
}
HandoffOutput::FinalToolCall {
call_id,
output_text,
} => {
if let Err(err) = writer
.send_function_call_output(call_id, output_text)
.await
{
let mapped_error = map_api_error(err);
warn!("failed to send handoff tool output: {mapped_error}");
break;
}
if let Err(err) = writer.send_response_create().await {
let mapped_error = map_api_error(err);
warn!("failed to send handoff response.create: {mapped_error}");
break;
}
}
}
}
Err(_) => break,
@@ -511,6 +572,7 @@ fn spawn_realtime_input_task(
if let RealtimeEvent::HandoffRequested(handoff) = &event {
*handoff_state.active_handoff.lock().await =
Some(handoff.handoff_id.clone());
*handoff_state.last_output_text.lock().await = None;
}
let should_stop = matches!(&event, RealtimeEvent::Error(_));
if events_tx.send(event).await.is_err() {
@@ -670,7 +732,7 @@ mod tests {
let output_1 = rx.recv().await.expect("recv");
assert_eq!(
output_1,
HandoffOutput {
HandoffOutput::TextUpdate {
handoff_id: "handoff_1".to_string(),
output_text: "result".to_string(),
}
@@ -679,7 +741,7 @@ mod tests {
let output_2 = rx.recv().await.expect("recv");
assert_eq!(
output_2,
HandoffOutput {
HandoffOutput::TextUpdate {
handoff_id: "handoff_1".to_string(),
output_text: "result 2".to_string(),
}
@@ -692,4 +754,27 @@ mod tests {
.expect("send");
assert!(rx.is_empty());
}
#[tokio::test]
async fn sends_final_tool_call_output_for_active_handoff() {
    let (tx, rx) = bounded(4);
    let state = RealtimeHandoffState::new(tx);
    *state.active_handoff.lock().await = Some("handoff_2".to_string());

    // Sending output queues a text update; drain it so the next message
    // received is the final tool call.
    state
        .send_output("final text".to_string())
        .await
        .expect("send");
    let _ = rx.recv().await.expect("recv text update");

    state.send_final_output().await.expect("send final output");

    let expected = HandoffOutput::FinalToolCall {
        call_id: "handoff_2".to_string(),
        output_text: "final text".to_string(),
    };
    let final_output = rx.recv().await.expect("recv final output");
    assert_eq!(final_output, expected);
}
}

View File

@@ -42,7 +42,7 @@ async fn conversation_start_audio_text_close_round_trip() -> Result<()> {
vec![],
vec![
json!({
"type": "conversation.output_audio.delta",
"type": "response.output_audio.delta",
"delta": "AQID",
"sample_rate": 24000,
"channels": 1
@@ -141,7 +141,7 @@ async fn conversation_start_audio_text_close_round_trip() -> Result<()> {
);
assert_eq!(
server.handshakes()[1].uri(),
"/v1/realtime?intent=quicksilver&model=realtime-test-model"
"/v1/realtime?model=gpt-realtime-1.5"
);
let mut request_types = [
connection[1].body_json()["type"]
@@ -387,7 +387,7 @@ async fn conversation_second_start_replaces_runtime() -> Result<()> {
"session": { "id": "sess_new", "instructions": "new" }
})],
vec![json!({
"type": "conversation.output_audio.delta",
"type": "response.output_audio.delta",
"delta": "AQID",
"sample_rate": 24000,
"channels": 1
@@ -600,15 +600,11 @@ async fn conversation_mirrors_assistant_message_text_to_realtime_handoff() -> Re
"type": "session.updated",
"session": { "id": "sess_1", "instructions": "backend prompt" }
}),
json!({
"type": "conversation.handoff.requested",
"handoff_id": "handoff_1",
"item_id": "item_1",
"input_transcript": "delegate hello",
"messages": [{ "role": "user", "text": "delegate hello" }]
}),
realtime_handoff_requested_event("handoff_1", "item_1", "delegate hello"),
],
vec![],
vec![],
vec![],
]])
.await;
@@ -652,7 +648,7 @@ async fn conversation_mirrors_assistant_message_text_to_realtime_handoff() -> Re
let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
while tokio::time::Instant::now() < deadline {
let connections = realtime_server.connections();
if connections.len() == 1 && connections[0].len() >= 2 {
if connections.len() == 1 && connections[0].len() >= 4 {
break;
}
tokio::time::sleep(Duration::from_millis(10)).await;
@@ -660,22 +656,46 @@ async fn conversation_mirrors_assistant_message_text_to_realtime_handoff() -> Re
let realtime_connections = realtime_server.connections();
assert_eq!(realtime_connections.len(), 1);
assert_eq!(realtime_connections[0].len(), 2);
assert_eq!(realtime_connections[0].len(), 4);
assert_eq!(
realtime_connections[0][0].body_json()["type"].as_str(),
Some("session.update")
);
assert_eq!(
realtime_connections[0][1].body_json()["type"].as_str(),
Some("conversation.handoff.append")
Some("conversation.item.create")
);
assert_eq!(
realtime_connections[0][1].body_json()["handoff_id"].as_str(),
realtime_connections[0][1].body_json()["item"]["type"].as_str(),
Some("message")
);
assert_eq!(
realtime_connections[0][1].body_json()["item"]["role"].as_str(),
Some("assistant")
);
assert_eq!(
realtime_connections[0][1].body_json()["item"]["content"][0]["type"].as_str(),
Some("output_text")
);
assert_eq!(
realtime_connections[0][1].body_json()["item"]["content"][0]["text"].as_str(),
Some("assistant says hi")
);
assert_eq!(
realtime_connections[0][2].body_json()["type"].as_str(),
Some("conversation.item.create")
);
assert_eq!(
realtime_connections[0][2].body_json()["item"]["type"].as_str(),
Some("function_call_output")
);
assert_eq!(
realtime_connections[0][2].body_json()["item"]["call_id"].as_str(),
Some("handoff_1")
);
assert_eq!(
realtime_connections[0][1].body_json()["output_text"].as_str(),
Some("assistant says hi")
realtime_connections[0][3].body_json()["type"].as_str(),
Some("response.create")
);
realtime_server.shutdown().await;
@@ -719,19 +739,15 @@ async fn conversation_handoff_persists_across_item_done_until_turn_complete() ->
"type": "session.updated",
"session": { "id": "sess_item_done", "instructions": "backend prompt" }
}),
json!({
"type": "conversation.handoff.requested",
"handoff_id": "handoff_item_done",
"item_id": "item_item_done",
"input_transcript": "delegate now",
"messages": [{ "role": "user", "text": "delegate now" }]
}),
realtime_handoff_requested_event("handoff_item_done", "item_item_done", "delegate now"),
],
vec![json!({
"type": "conversation.item.done",
"item": { "id": "item_item_done" }
})],
vec![],
vec![],
vec![],
]])
.await;
@@ -769,14 +785,22 @@ async fn conversation_handoff_persists_across_item_done_until_turn_complete() ->
let first_append = realtime_server.wait_for_request(0, 1).await;
assert_eq!(
first_append.body_json()["type"].as_str(),
Some("conversation.handoff.append")
Some("conversation.item.create")
);
assert_eq!(
first_append.body_json()["handoff_id"].as_str(),
Some("handoff_item_done")
first_append.body_json()["item"]["type"].as_str(),
Some("message")
);
assert_eq!(
first_append.body_json()["output_text"].as_str(),
first_append.body_json()["item"]["role"].as_str(),
Some("assistant")
);
assert_eq!(
first_append.body_json()["item"]["content"][0]["type"].as_str(),
Some("output_text")
);
assert_eq!(
first_append.body_json()["item"]["content"][0]["text"].as_str(),
Some("assistant message 1")
);
@@ -793,14 +817,22 @@ async fn conversation_handoff_persists_across_item_done_until_turn_complete() ->
let second_append = realtime_server.wait_for_request(0, 2).await;
assert_eq!(
second_append.body_json()["type"].as_str(),
Some("conversation.handoff.append")
Some("conversation.item.create")
);
assert_eq!(
second_append.body_json()["handoff_id"].as_str(),
Some("handoff_item_done")
second_append.body_json()["item"]["type"].as_str(),
Some("message")
);
assert_eq!(
second_append.body_json()["output_text"].as_str(),
second_append.body_json()["item"]["role"].as_str(),
Some("assistant")
);
assert_eq!(
second_append.body_json()["item"]["content"][0]["type"].as_str(),
Some("output_text")
);
assert_eq!(
second_append.body_json()["item"]["content"][0]["text"].as_str(),
Some("assistant message 2")
);
@@ -816,6 +848,30 @@ async fn conversation_handoff_persists_across_item_done_until_turn_complete() ->
})
.await;
let final_tool_call = realtime_server.wait_for_request(0, 3).await;
assert_eq!(
final_tool_call.body_json()["type"].as_str(),
Some("conversation.item.create")
);
assert_eq!(
final_tool_call.body_json()["item"]["type"].as_str(),
Some("function_call_output")
);
assert_eq!(
final_tool_call.body_json()["item"]["call_id"].as_str(),
Some("handoff_item_done")
);
assert_eq!(
final_tool_call.body_json()["item"]["output"].as_str(),
Some("{\"content\":\"assistant message 2\"}")
);
let response_create = realtime_server.wait_for_request(0, 4).await;
assert_eq!(
response_create.body_json()["type"].as_str(),
Some("response.create")
);
realtime_server.shutdown().await;
api_server.shutdown().await;
Ok(())
@@ -825,6 +881,23 @@ fn sse_event(event: Value) -> String {
responses::sse(vec![event])
}
/// Builds the realtime server event that the tests use to request a handoff.
///
/// The event is a `response.done` payload whose single output item is a
/// `function_call` to the `codex` tool:
/// - `handoff_id` becomes the item's `call_id`,
/// - `item_id` becomes the item's `id`,
/// - `prompt` is wrapped as `{"prompt": ...}` and stored, serialized to a JSON
///   string, in the item's `arguments` field.
fn realtime_handoff_requested_event(handoff_id: &str, item_id: &str, prompt: &str) -> Value {
    // `arguments` is a JSON document encoded as a string, not a nested object.
    let encoded_arguments = json!({ "prompt": prompt }).to_string();
    let codex_call = json!({
        "id": item_id,
        "type": "function_call",
        "name": "codex",
        "call_id": handoff_id,
        "arguments": encoded_arguments,
    });
    json!({
        "type": "response.done",
        "response": {
            "output": [codex_call]
        }
    })
}
fn message_input_texts(body: &Value, role: &str) -> Vec<String> {
body.get("input")
.and_then(Value::as_array)
@@ -859,13 +932,7 @@ async fn inbound_handoff_request_starts_turn() -> Result<()> {
"type": "session.updated",
"session": { "id": "sess_inbound", "instructions": "backend prompt" }
}),
json!({
"type": "conversation.handoff.requested",
"handoff_id": "handoff_inbound",
"item_id": "item_inbound",
"input_transcript": "text from realtime",
"messages": [{ "role": "user", "text": "text from realtime" }]
}),
realtime_handoff_requested_event("handoff_inbound", "item_inbound", "text from realtime"),
]]])
.await;
@@ -939,15 +1006,25 @@ async fn inbound_handoff_request_uses_all_messages() -> Result<()> {
"session": { "id": "sess_inbound_multi", "instructions": "backend prompt" }
}),
json!({
"type": "conversation.handoff.requested",
"handoff_id": "handoff_inbound_multi",
"item_id": "item_inbound_multi",
"input_transcript": "ignored",
"messages": [
{ "role": "assistant", "text": "assistant context" },
{ "role": "user", "text": "delegated query" },
{ "role": "assistant", "text": "assist confirm" },
]
"type": "response.done",
"response": {
"output": [
{
"id": "item_inbound_multi",
"type": "function_call",
"name": "codex",
"call_id": "handoff_inbound_multi",
"arguments": json!({
"input_transcript": "ignored",
"messages": [
{ "role": "assistant", "text": "assistant context" },
{ "role": "user", "text": "delegated query" },
{ "role": "assistant", "text": "assist confirm" },
]
}).to_string(),
}
]
}
}),
]]])
.await;
@@ -1012,7 +1089,7 @@ async fn inbound_conversation_item_does_not_start_turn_and_still_forwards_audio(
}
}),
json!({
"type": "conversation.output_audio.delta",
"type": "response.output_audio.delta",
"delta": "AQID",
"sample_rate": 24000,
"channels": 1
@@ -1102,13 +1179,11 @@ async fn delegated_turn_user_role_echo_does_not_redelegate_and_still_forwards_au
"type": "session.updated",
"session": { "id": "sess_echo_guard", "instructions": "backend prompt" }
}),
json!({
"type": "conversation.handoff.requested",
"handoff_id": "handoff_echo_guard",
"item_id": "item_echo_guard",
"input_transcript": "delegate now",
"messages": [{"role": "user", "text": "delegate now"}]
}),
realtime_handoff_requested_event(
"handoff_echo_guard",
"item_echo_guard",
"delegate now",
),
],
vec![
json!({
@@ -1120,7 +1195,7 @@ async fn delegated_turn_user_role_echo_does_not_redelegate_and_still_forwards_au
}
}),
json!({
"type": "conversation.output_audio.delta",
"type": "response.output_audio.delta",
"delta": "AQID",
"sample_rate": 24000,
"channels": 1
@@ -1168,22 +1243,30 @@ async fn delegated_turn_user_role_echo_does_not_redelegate_and_still_forwards_au
let mirrored_request = realtime_server.wait_for_request(0, 1).await;
let mirrored_request_body = mirrored_request.body_json();
eprintln!(
"[realtime test +{}ms] saw mirrored request type={:?} handoff_id={:?} text={:?}",
"[realtime test +{}ms] saw mirrored request type={:?} role={:?} text={:?}",
start.elapsed().as_millis(),
mirrored_request_body["type"].as_str(),
mirrored_request_body["handoff_id"].as_str(),
mirrored_request_body["output_text"].as_str(),
mirrored_request_body["item"]["role"].as_str(),
mirrored_request_body["item"]["content"][0]["text"].as_str(),
);
assert_eq!(
mirrored_request_body["type"].as_str(),
Some("conversation.handoff.append")
Some("conversation.item.create")
);
assert_eq!(
mirrored_request_body["handoff_id"].as_str(),
Some("handoff_echo_guard")
mirrored_request_body["item"]["type"].as_str(),
Some("message")
);
assert_eq!(
mirrored_request_body["output_text"].as_str(),
mirrored_request_body["item"]["role"].as_str(),
Some("assistant")
);
assert_eq!(
mirrored_request_body["item"]["content"][0]["type"].as_str(),
Some("output_text")
);
assert_eq!(
mirrored_request_body["item"]["content"][0]["text"].as_str(),
Some("assistant says hi")
);
@@ -1250,15 +1333,13 @@ async fn inbound_handoff_request_does_not_block_realtime_event_forwarding() -> R
"type": "session.updated",
"session": { "id": "sess_non_blocking", "instructions": "backend prompt" }
}),
realtime_handoff_requested_event(
"handoff_non_blocking",
"item_non_blocking",
"delegate now",
),
json!({
"type": "conversation.handoff.requested",
"handoff_id": "handoff_non_blocking",
"item_id": "item_non_blocking",
"input_transcript": "delegate now",
"messages": [{"role": "user", "text": "delegate now"}]
}),
json!({
"type": "conversation.output_audio.delta",
"type": "response.output_audio.delta",
"delta": "AQID",
"sample_rate": 24000,
"channels": 1
@@ -1377,13 +1458,11 @@ async fn inbound_handoff_request_steers_active_turn() -> Result<()> {
"type": "session.updated",
"session": { "id": "sess_steer", "instructions": "backend prompt" }
})],
vec![json!({
"type": "conversation.handoff.requested",
"handoff_id": "handoff_steer",
"item_id": "item_steer",
"input_transcript": "steer via realtime",
"messages": [{ "role": "user", "text": "steer via realtime" }]
})],
vec![realtime_handoff_requested_event(
"handoff_steer",
"item_steer",
"steer via realtime",
)],
]])
.await;
@@ -1500,15 +1579,9 @@ async fn inbound_handoff_request_starts_turn_and_does_not_block_realtime_audio()
"type": "session.updated",
"session": { "id": "sess_handoff_request", "instructions": "backend prompt" }
}),
realtime_handoff_requested_event("handoff_audio", "item_audio", delegated_text),
json!({
"type": "conversation.handoff.requested",
"handoff_id": "handoff_audio",
"item_id": "item_audio",
"input_transcript": delegated_text,
"messages": [{ "role": "user", "text": delegated_text }]
}),
json!({
"type": "conversation.output_audio.delta",
"type": "response.output_audio.delta",
"delta": "AQID",
"sample_rate": 24000,
"channels": 1

View File

@@ -304,7 +304,7 @@ impl ChatWidget {
}
}
#[cfg(not(target_os = "linux"))]
#[cfg(all(not(target_os = "linux"), feature = "voice-input"))]
fn start_realtime_local_audio(&mut self) {
if self.realtime_conversation.capture_stop_flag.is_some() {
return;
@@ -361,6 +361,9 @@ impl ChatWidget {
/// Linux build of `start_realtime_local_audio`: intentionally a no-op, so
/// callers can invoke it unconditionally without their own `cfg` guards.
#[cfg(target_os = "linux")]
fn start_realtime_local_audio(&mut self) {}
/// Non-Linux build without the `voice-input` feature: intentionally a no-op,
/// so the method exists for callers even when the feature is compiled out.
#[cfg(all(not(target_os = "linux"), not(feature = "voice-input")))]
fn start_realtime_local_audio(&mut self) {}
#[cfg(all(not(target_os = "linux"), feature = "voice-input"))]
pub(crate) fn restart_realtime_audio_device(&mut self, kind: RealtimeAudioDeviceKind) {
if !self.realtime_conversation.is_active() {

View File

@@ -484,7 +484,7 @@ fn convert_u16_to_i16_and_peak(input: &[u16], out: &mut Vec<i16>) -> u16 {
pub(crate) struct RealtimeAudioPlayer {
_stream: cpal::Stream,
queue: Arc<Mutex<VecDeque<i16>>>,
queue: Arc<Mutex<OutputAudioQueue>>,
output_sample_rate: u32,
output_channels: u16,
}
@@ -495,8 +495,9 @@ impl RealtimeAudioPlayer {
crate::audio_device::select_configured_output_device_and_config(config)?;
let output_sample_rate = config.sample_rate().0;
let output_channels = config.channels();
let queue = Arc::new(Mutex::new(VecDeque::new()));
let stream = build_output_stream(&device, &config, Arc::clone(&queue))?;
let prebuffer_samples = output_prebuffer_samples(output_sample_rate, output_channels);
let queue = Arc::new(Mutex::new(OutputAudioQueue::default()));
let stream = build_output_stream(&device, &config, Arc::clone(&queue), prebuffer_samples)?;
stream
.play()
.map_err(|e| format!("failed to start output stream: {e}"))?;
@@ -537,13 +538,14 @@ impl RealtimeAudioPlayer {
.lock()
.map_err(|_| "failed to lock output audio queue".to_string())?;
// TODO(aibrahim): Cap or trim this queue if we observe producer bursts outrunning playback.
guard.extend(converted);
guard.samples.extend(converted);
Ok(())
}
pub(crate) fn clear(&self) {
if let Ok(mut guard) = self.queue.lock() {
guard.clear();
guard.samples.clear();
guard.primed = false;
}
}
}
@@ -551,14 +553,15 @@ impl RealtimeAudioPlayer {
fn build_output_stream(
device: &cpal::Device,
config: &cpal::SupportedStreamConfig,
queue: Arc<Mutex<VecDeque<i16>>>,
queue: Arc<Mutex<OutputAudioQueue>>,
prebuffer_samples: usize,
) -> Result<cpal::Stream, String> {
let config_any: cpal::StreamConfig = config.clone().into();
match config.sample_format() {
cpal::SampleFormat::F32 => device
.build_output_stream(
&config_any,
move |output: &mut [f32], _| fill_output_f32(output, &queue),
move |output: &mut [f32], _| fill_output_f32(output, &queue, prebuffer_samples),
move |err| error!("audio output error: {err}"),
None,
)
@@ -566,7 +569,7 @@ fn build_output_stream(
cpal::SampleFormat::I16 => device
.build_output_stream(
&config_any,
move |output: &mut [i16], _| fill_output_i16(output, &queue),
move |output: &mut [i16], _| fill_output_i16(output, &queue, prebuffer_samples),
move |err| error!("audio output error: {err}"),
None,
)
@@ -574,7 +577,7 @@ fn build_output_stream(
cpal::SampleFormat::U16 => device
.build_output_stream(
&config_any,
move |output: &mut [u16], _| fill_output_u16(output, &queue),
move |output: &mut [u16], _| fill_output_u16(output, &queue, prebuffer_samples),
move |err| error!("audio output error: {err}"),
None,
)
@@ -583,20 +586,64 @@ fn build_output_stream(
}
}
fn fill_output_i16(output: &mut [i16], queue: &Arc<Mutex<VecDeque<i16>>>) {
/// Playback state shared (behind `Arc<Mutex<..>>`) between the audio player,
/// which appends samples, and the output-stream callbacks, which drain them.
///
/// `Default` yields an empty, unprimed queue.
#[derive(Default)]
struct OutputAudioQueue {
/// Pending `i16` samples: appended at the back by the player and popped
/// from the front by the `fill_output_*` callbacks.
/// NOTE(review): presumably interleaved across output channels, since the
/// prebuffer size is computed as sample rate x channels; confirm.
samples: VecDeque<i16>,
/// `true` once enough samples were buffered to start playback; reset to
/// `false` when the queue runs empty or is cleared, forcing a re-buffer.
primed: bool,
}
/// Returns how many queued samples (counted across all channels) make up the
/// playback jitter buffer for the given output format.
///
/// The buffer spans 120 ms of audio: long enough to smooth bursty websocket
/// delivery, short enough not to add noticeable latency. The result is
/// truncated toward zero, and is `0` when either argument is `0`.
fn output_prebuffer_samples(sample_rate: u32, channels: u16) -> usize {
    const PREBUFFER_MILLIS: u64 = 120;
    const MILLIS_PER_SECOND: u64 = 1_000;

    // Saturate in `usize` so narrow targets clamp instead of wrapping.
    let rate = sample_rate as usize;
    let channel_count = channels as usize;
    let samples_per_second = rate.saturating_mul(channel_count) as u64;

    (samples_per_second * PREBUFFER_MILLIS / MILLIS_PER_SECOND) as usize
}
/// Decides whether the output callback should emit silence instead of
/// draining `queue`, updating the queue's `primed` flag as a side effect.
///
/// - While the queue is unprimed and holds fewer than `min_buffer_samples`
///   samples, returns `true` and leaves the queue untouched (still buffering).
/// - Otherwise the queue is primed exactly when it is non-empty: an empty
///   queue is marked unprimed (so playback re-buffers) and `true` is returned;
///   a non-empty queue is marked primed and `false` is returned.
///
/// NOTE(review): samples that arrive while unprimed and never reach
/// `min_buffer_samples` stay queued and unplayed until more audio arrives;
/// confirm this is acceptable for short trailing audio.
fn should_output_silence(queue: &mut OutputAudioQueue, min_buffer_samples: usize) -> bool {
    let buffered = queue.samples.len();

    // Still filling the jitter buffer: keep playing silence.
    if !queue.primed && buffered < min_buffer_samples {
        return true;
    }

    // Either already primed or the threshold was just reached; an underrun
    // (nothing left to play) drops back to the unprimed state.
    queue.primed = buffered != 0;
    !queue.primed
}
/// Output-stream callback body for `i16` devices: fills `output` from the
/// shared queue, or with silence (`0`) when nothing should be played.
///
/// Silence is written when the queue lock cannot be acquired, or when
/// `should_output_silence` reports that the queue is still buffering up to
/// `prebuffer_samples` or has run empty. Otherwise samples are popped from the
/// front of the queue, padding with `0` if it drains mid-buffer.
///
/// NOTE(review): this listing comes from a diff view; the `guard.pop_front()`
/// line appears to be the pre-change form and `guard.samples.pop_front()` its
/// replacement; confirm against the actual file.
fn fill_output_i16(
output: &mut [i16],
queue: &Arc<Mutex<OutputAudioQueue>>,
prebuffer_samples: usize,
) {
if let Ok(mut guard) = queue.lock() {
// Buffering or underrun: emit a full buffer of silence and keep the queue intact.
if should_output_silence(&mut guard, prebuffer_samples) {
output.fill(0);
return;
}
for sample in output {
*sample = guard.pop_front().unwrap_or(0);
*sample = guard.samples.pop_front().unwrap_or(0);
}
return;
}
// Lock acquisition failed: fall back to silence rather than panicking in the callback.
output.fill(0);
}
fn fill_output_f32(output: &mut [f32], queue: &Arc<Mutex<VecDeque<i16>>>) {
fn fill_output_f32(
output: &mut [f32],
queue: &Arc<Mutex<OutputAudioQueue>>,
prebuffer_samples: usize,
) {
if let Ok(mut guard) = queue.lock() {
if should_output_silence(&mut guard, prebuffer_samples) {
output.fill(0.0);
return;
}
for sample in output {
let v = guard.pop_front().unwrap_or(0);
let v = guard.samples.pop_front().unwrap_or(0);
*sample = (v as f32) / (i16::MAX as f32);
}
return;
@@ -604,10 +651,18 @@ fn fill_output_f32(output: &mut [f32], queue: &Arc<Mutex<VecDeque<i16>>>) {
output.fill(0.0);
}
fn fill_output_u16(output: &mut [u16], queue: &Arc<Mutex<VecDeque<i16>>>) {
fn fill_output_u16(
output: &mut [u16],
queue: &Arc<Mutex<OutputAudioQueue>>,
prebuffer_samples: usize,
) {
if let Ok(mut guard) = queue.lock() {
if should_output_silence(&mut guard, prebuffer_samples) {
output.fill(32768);
return;
}
for sample in output {
let v = guard.pop_front().unwrap_or(0);
let v = guard.samples.pop_front().unwrap_or(0);
*sample = (v as i32 + 32768).clamp(0, u16::MAX as i32) as u16;
}
return;