mirror of
https://github.com/openai/codex.git
synced 2026-05-05 05:42:33 +03:00
Separate realtime v2 and v1 runtime behavior
This commit is contained in:
@@ -1,20 +1,26 @@
|
||||
use crate::endpoint::realtime_websocket::protocol::ConversationItem;
|
||||
use crate::endpoint::realtime_websocket::protocol::ConversationItemContent;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeApiMode;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeAudioFrame;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeEvent;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeOutboundMessage;
|
||||
use crate::endpoint::realtime_websocket::protocol::RealtimeSessionConfig;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudio;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudioFormat;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudioInput;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudioOutput;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudioInputV1;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudioInputV2;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudioOutputFormat;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudioOutputV1;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudioOutputV2;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudioV1;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionAudioV2;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionTool;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionToolParameters;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionToolProperties;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionToolProperty;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionTurnDetection;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionUpdateSession;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionUpdateSessionV1;
|
||||
use crate::endpoint::realtime_websocket::protocol::SessionUpdateSessionV2;
|
||||
use crate::endpoint::realtime_websocket::protocol::parse_realtime_event;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
@@ -200,12 +206,14 @@ pub struct RealtimeWebsocketConnection {
|
||||
pub struct RealtimeWebsocketWriter {
|
||||
stream: Arc<WsStream>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
mode: RealtimeApiMode,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RealtimeWebsocketEvents {
|
||||
rx_message: Arc<Mutex<mpsc::UnboundedReceiver<Result<Message, WsError>>>>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
mode: RealtimeApiMode,
|
||||
}
|
||||
|
||||
impl RealtimeWebsocketConnection {
|
||||
@@ -260,6 +268,7 @@ impl RealtimeWebsocketConnection {
|
||||
fn new(
|
||||
stream: WsStream,
|
||||
rx_message: mpsc::UnboundedReceiver<Result<Message, WsError>>,
|
||||
mode: RealtimeApiMode,
|
||||
) -> Self {
|
||||
let stream = Arc::new(stream);
|
||||
let is_closed = Arc::new(AtomicBool::new(false));
|
||||
@@ -267,10 +276,12 @@ impl RealtimeWebsocketConnection {
|
||||
writer: RealtimeWebsocketWriter {
|
||||
stream: Arc::clone(&stream),
|
||||
is_closed: Arc::clone(&is_closed),
|
||||
mode,
|
||||
},
|
||||
events: RealtimeWebsocketEvents {
|
||||
rx_message: Arc::new(Mutex::new(rx_message)),
|
||||
is_closed,
|
||||
mode,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -283,13 +294,14 @@ impl RealtimeWebsocketWriter {
|
||||
}
|
||||
|
||||
pub async fn send_conversation_item_create(&self, text: String) -> Result<(), ApiError> {
|
||||
let kind = match self.mode {
|
||||
RealtimeApiMode::V1 => "text".to_string(),
|
||||
RealtimeApiMode::V2 => "input_text".to_string(),
|
||||
};
|
||||
self.send_json(RealtimeOutboundMessage::ConversationItemCreate {
|
||||
item: ConversationItem::Message {
|
||||
role: "user".to_string(),
|
||||
content: vec![ConversationItemContent {
|
||||
kind: "input_text".to_string(),
|
||||
text,
|
||||
}],
|
||||
content: vec![ConversationItemContent { kind, text }],
|
||||
},
|
||||
})
|
||||
.await
|
||||
@@ -297,19 +309,30 @@ impl RealtimeWebsocketWriter {
|
||||
|
||||
pub async fn send_conversation_handoff_append(
|
||||
&self,
|
||||
_handoff_id: String,
|
||||
handoff_id: String,
|
||||
output_text: String,
|
||||
) -> Result<(), ApiError> {
|
||||
self.send_json(RealtimeOutboundMessage::ConversationItemCreate {
|
||||
item: ConversationItem::Message {
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ConversationItemContent {
|
||||
kind: "output_text".to_string(),
|
||||
text: output_text,
|
||||
}],
|
||||
},
|
||||
})
|
||||
.await
|
||||
match self.mode {
|
||||
RealtimeApiMode::V1 => {
|
||||
self.send_json(RealtimeOutboundMessage::ConversationHandoffAppend {
|
||||
handoff_id,
|
||||
output_text,
|
||||
})
|
||||
.await
|
||||
}
|
||||
RealtimeApiMode::V2 => {
|
||||
self.send_json(RealtimeOutboundMessage::ConversationItemCreate {
|
||||
item: ConversationItem::Message {
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ConversationItemContent {
|
||||
kind: "output_text".to_string(),
|
||||
text: output_text,
|
||||
}],
|
||||
},
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_function_call_output(
|
||||
@@ -317,29 +340,54 @@ impl RealtimeWebsocketWriter {
|
||||
call_id: String,
|
||||
output_text: String,
|
||||
) -> Result<(), ApiError> {
|
||||
let output = json!({
|
||||
"content": output_text,
|
||||
})
|
||||
.to_string();
|
||||
self.send_json(RealtimeOutboundMessage::ConversationItemCreate {
|
||||
item: ConversationItem::FunctionCallOutput { call_id, output },
|
||||
})
|
||||
.await
|
||||
match self.mode {
|
||||
RealtimeApiMode::V1 => Ok(()),
|
||||
RealtimeApiMode::V2 => {
|
||||
let output = json!({
|
||||
"content": output_text,
|
||||
})
|
||||
.to_string();
|
||||
self.send_json(RealtimeOutboundMessage::ConversationItemCreate {
|
||||
item: ConversationItem::FunctionCallOutput { call_id, output },
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_response_create(&self) -> Result<(), ApiError> {
|
||||
self.send_json(RealtimeOutboundMessage::ResponseCreate)
|
||||
.await
|
||||
match self.mode {
|
||||
RealtimeApiMode::V1 => Ok(()),
|
||||
RealtimeApiMode::V2 => {
|
||||
self.send_json(RealtimeOutboundMessage::ResponseCreate)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_session_update(&self, instructions: String) -> Result<(), ApiError> {
|
||||
self.send_json(RealtimeOutboundMessage::SessionUpdate {
|
||||
session: SessionUpdateSession {
|
||||
let session = match self.mode {
|
||||
RealtimeApiMode::V1 => SessionUpdateSession::V1(SessionUpdateSessionV1 {
|
||||
kind: "quicksilver".to_string(),
|
||||
instructions,
|
||||
audio: SessionAudioV1 {
|
||||
input: SessionAudioInputV1 {
|
||||
format: SessionAudioFormat {
|
||||
kind: "audio/pcm".to_string(),
|
||||
rate: 24_000,
|
||||
},
|
||||
},
|
||||
output: SessionAudioOutputV1 {
|
||||
voice: "mundo".to_string(),
|
||||
},
|
||||
},
|
||||
}),
|
||||
RealtimeApiMode::V2 => SessionUpdateSession::V2(SessionUpdateSessionV2 {
|
||||
kind: "realtime".to_string(),
|
||||
instructions,
|
||||
output_modalities: vec!["audio".to_string()],
|
||||
audio: SessionAudio {
|
||||
input: SessionAudioInput {
|
||||
audio: SessionAudioV2 {
|
||||
input: SessionAudioInputV2 {
|
||||
format: SessionAudioFormat {
|
||||
kind: "audio/pcm".to_string(),
|
||||
rate: 24_000,
|
||||
@@ -350,7 +398,7 @@ impl RealtimeWebsocketWriter {
|
||||
create_response: true,
|
||||
},
|
||||
},
|
||||
output: SessionAudioOutput {
|
||||
output: SessionAudioOutputV2 {
|
||||
format: SessionAudioOutputFormat {
|
||||
kind: "audio/pcm".to_string(),
|
||||
rate: 24_000,
|
||||
@@ -376,9 +424,11 @@ impl RealtimeWebsocketWriter {
|
||||
},
|
||||
}],
|
||||
tool_choice: "auto".to_string(),
|
||||
},
|
||||
})
|
||||
.await
|
||||
}),
|
||||
};
|
||||
|
||||
self.send_json(RealtimeOutboundMessage::SessionUpdate { session })
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> Result<(), ApiError> {
|
||||
@@ -415,6 +465,10 @@ impl RealtimeWebsocketWriter {
|
||||
}
|
||||
|
||||
impl RealtimeWebsocketEvents {
|
||||
/// Returns the realtime API mode stored on this event stream, used to pick the
/// matching parser when decoding incoming text frames.
fn mode(&self) -> RealtimeApiMode {
|
||||
self.mode
|
||||
}
|
||||
|
||||
pub async fn next_event(&self) -> Result<Option<RealtimeEvent>, ApiError> {
|
||||
if self.is_closed.load(Ordering::SeqCst) {
|
||||
return Ok(None);
|
||||
@@ -439,7 +493,7 @@ impl RealtimeWebsocketEvents {
|
||||
|
||||
match msg {
|
||||
Message::Text(text) => {
|
||||
if let Some(event) = parse_realtime_event(&text) {
|
||||
if let Some(event) = parse_realtime_event(&text, self.mode()) {
|
||||
debug!(?event, "realtime websocket parsed event");
|
||||
return Ok(Some(event));
|
||||
}
|
||||
@@ -485,6 +539,7 @@ impl RealtimeWebsocketClient {
|
||||
self.provider.base_url.as_str(),
|
||||
self.provider.query_params.as_ref(),
|
||||
config.model.as_deref(),
|
||||
config.mode,
|
||||
)?;
|
||||
|
||||
let mut request = ws_url
|
||||
@@ -512,7 +567,7 @@ impl RealtimeWebsocketClient {
|
||||
);
|
||||
|
||||
let (stream, rx_message) = WsStream::new(stream);
|
||||
let connection = RealtimeWebsocketConnection::new(stream, rx_message);
|
||||
let connection = RealtimeWebsocketConnection::new(stream, rx_message, config.mode);
|
||||
debug!(
|
||||
session_id = config.session_id.as_deref().unwrap_or("<none>"),
|
||||
"realtime websocket sending session.update"
|
||||
@@ -564,6 +619,7 @@ fn websocket_url_from_api_url(
|
||||
api_url: &str,
|
||||
query_params: Option<&HashMap<String, String>>,
|
||||
model: Option<&str>,
|
||||
mode: RealtimeApiMode,
|
||||
) -> Result<Url, ApiError> {
|
||||
let mut url = Url::parse(api_url)
|
||||
.map_err(|err| ApiError::Stream(format!("failed to parse realtime api_url: {err}")))?;
|
||||
@@ -583,16 +639,19 @@ fn websocket_url_from_api_url(
|
||||
}
|
||||
}
|
||||
|
||||
let has_additional_query_params = query_params
|
||||
.is_some_and(|params| params.keys().any(|key| key != "model" || model.is_none()));
|
||||
if model.is_some() || has_additional_query_params {
|
||||
{
|
||||
let mut query = url.query_pairs_mut();
|
||||
if mode == RealtimeApiMode::V1 {
|
||||
query.append_pair("intent", "quicksilver");
|
||||
}
|
||||
if let Some(model) = model {
|
||||
query.append_pair("model", model);
|
||||
}
|
||||
if let Some(query_params) = query_params {
|
||||
for (key, value) in query_params {
|
||||
if key == "model" && model.is_some() {
|
||||
if (key == "model" && model.is_some())
|
||||
|| (key == "intent" && mode == RealtimeApiMode::V1)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
query.append_pair(key, value);
|
||||
@@ -653,7 +712,7 @@ mod tests {
|
||||
.to_string();
|
||||
|
||||
assert_eq!(
|
||||
parse_realtime_event(payload.as_str()),
|
||||
parse_realtime_event(payload.as_str(), RealtimeApiMode::V2),
|
||||
Some(RealtimeEvent::SessionUpdated {
|
||||
session_id: "sess_123".to_string(),
|
||||
instructions: Some("backend prompt".to_string()),
|
||||
@@ -672,7 +731,7 @@ mod tests {
|
||||
})
|
||||
.to_string();
|
||||
assert_eq!(
|
||||
parse_realtime_event(payload.as_str()),
|
||||
parse_realtime_event(payload.as_str(), RealtimeApiMode::V2),
|
||||
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
|
||||
data: "AAA=".to_string(),
|
||||
sample_rate: 48000,
|
||||
@@ -691,7 +750,7 @@ mod tests {
|
||||
.to_string();
|
||||
|
||||
assert_eq!(
|
||||
parse_realtime_event(payload.as_str()),
|
||||
parse_realtime_event(payload.as_str(), RealtimeApiMode::V2),
|
||||
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
|
||||
data: "AAA=".to_string(),
|
||||
sample_rate: 24_000,
|
||||
@@ -709,7 +768,7 @@ mod tests {
|
||||
})
|
||||
.to_string();
|
||||
assert_eq!(
|
||||
parse_realtime_event(payload.as_str()),
|
||||
parse_realtime_event(payload.as_str(), RealtimeApiMode::V2),
|
||||
Some(RealtimeEvent::ConversationItemAdded(
|
||||
json!({"type": "message", "seq": 7})
|
||||
))
|
||||
@@ -724,7 +783,7 @@ mod tests {
|
||||
})
|
||||
.to_string();
|
||||
assert_eq!(
|
||||
parse_realtime_event(payload.as_str()),
|
||||
parse_realtime_event(payload.as_str(), RealtimeApiMode::V2),
|
||||
Some(RealtimeEvent::ConversationItemDone {
|
||||
item_id: "item_123".to_string(),
|
||||
})
|
||||
@@ -750,7 +809,7 @@ mod tests {
|
||||
.to_string();
|
||||
|
||||
assert_eq!(
|
||||
parse_realtime_event(payload.as_str()),
|
||||
parse_realtime_event(payload.as_str(), RealtimeApiMode::V2),
|
||||
Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
|
||||
handoff_id: "handoff_123".to_string(),
|
||||
item_id: "item_123".to_string(),
|
||||
@@ -763,6 +822,54 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_v1_handoff_requested_event() {
|
||||
let payload = json!({
|
||||
"type": "conversation.handoff.requested",
|
||||
"handoff_id": "handoff_legacy",
|
||||
"item_id": "item_legacy",
|
||||
"input_transcript": "delegate legacy",
|
||||
"messages": [
|
||||
{"role": "user", "text": "delegate legacy"}
|
||||
]
|
||||
})
|
||||
.to_string();
|
||||
|
||||
assert_eq!(
|
||||
parse_realtime_event(payload.as_str(), RealtimeApiMode::V1),
|
||||
Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
|
||||
handoff_id: "handoff_legacy".to_string(),
|
||||
item_id: "item_legacy".to_string(),
|
||||
input_transcript: "delegate legacy".to_string(),
|
||||
messages: vec![RealtimeHandoffMessage {
|
||||
role: "user".to_string(),
|
||||
text: "delegate legacy".to_string(),
|
||||
}],
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_v1_audio_delta_event() {
|
||||
let payload = json!({
|
||||
"type": "conversation.output_audio.delta",
|
||||
"delta": "AAA=",
|
||||
"sample_rate": 48000,
|
||||
"channels": 1,
|
||||
"samples_per_channel": 960
|
||||
})
|
||||
.to_string();
|
||||
assert_eq!(
|
||||
parse_realtime_event(payload.as_str(), RealtimeApiMode::V1),
|
||||
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
|
||||
data: "AAA=".to_string(),
|
||||
sample_rate: 48000,
|
||||
num_channels: 1,
|
||||
samples_per_channel: Some(960),
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_unknown_event_as_conversation_item_added() {
|
||||
let payload = json!({
|
||||
@@ -772,7 +879,7 @@ mod tests {
|
||||
})
|
||||
.to_string();
|
||||
assert_eq!(
|
||||
parse_realtime_event(payload.as_str()),
|
||||
parse_realtime_event(payload.as_str(), RealtimeApiMode::V2),
|
||||
Some(RealtimeEvent::ConversationItemAdded(json!({
|
||||
"type": "response.output_text.delta",
|
||||
"delta": "hello",
|
||||
@@ -817,15 +924,20 @@ mod tests {
|
||||
#[test]
|
||||
fn websocket_url_from_http_base_defaults_to_ws_path() {
|
||||
let url =
|
||||
websocket_url_from_api_url("http://127.0.0.1:8011", None, None).expect("build ws url");
|
||||
websocket_url_from_api_url("http://127.0.0.1:8011", None, None, RealtimeApiMode::V2)
|
||||
.expect("build ws url");
|
||||
assert_eq!(url.as_str(), "ws://127.0.0.1:8011/v1/realtime");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn websocket_url_from_ws_base_defaults_to_ws_path() {
|
||||
let url =
|
||||
websocket_url_from_api_url("wss://example.com", None, Some("realtime-test-model"))
|
||||
.expect("build ws url");
|
||||
let url = websocket_url_from_api_url(
|
||||
"wss://example.com",
|
||||
None,
|
||||
Some("realtime-test-model"),
|
||||
RealtimeApiMode::V2,
|
||||
)
|
||||
.expect("build ws url");
|
||||
assert_eq!(
|
||||
url.as_str(),
|
||||
"wss://example.com/v1/realtime?model=realtime-test-model"
|
||||
@@ -834,8 +946,13 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn websocket_url_from_v1_base_appends_realtime_path() {
|
||||
let url = websocket_url_from_api_url("https://api.openai.com/v1", None, Some("snapshot"))
|
||||
.expect("build ws url");
|
||||
let url = websocket_url_from_api_url(
|
||||
"https://api.openai.com/v1",
|
||||
None,
|
||||
Some("snapshot"),
|
||||
RealtimeApiMode::V2,
|
||||
)
|
||||
.expect("build ws url");
|
||||
assert_eq!(
|
||||
url.as_str(),
|
||||
"wss://api.openai.com/v1/realtime?model=snapshot"
|
||||
@@ -844,9 +961,13 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn websocket_url_from_nested_v1_base_appends_realtime_path() {
|
||||
let url =
|
||||
websocket_url_from_api_url("https://example.com/openai/v1", None, Some("snapshot"))
|
||||
.expect("build ws url");
|
||||
let url = websocket_url_from_api_url(
|
||||
"https://example.com/openai/v1",
|
||||
None,
|
||||
Some("snapshot"),
|
||||
RealtimeApiMode::V2,
|
||||
)
|
||||
.expect("build ws url");
|
||||
assert_eq!(
|
||||
url.as_str(),
|
||||
"wss://example.com/openai/v1/realtime?model=snapshot"
|
||||
@@ -859,6 +980,7 @@ mod tests {
|
||||
"https://example.com/v1/realtime?foo=bar",
|
||||
Some(&HashMap::from([("trace".to_string(), "1".to_string())])),
|
||||
Some("snapshot"),
|
||||
RealtimeApiMode::V2,
|
||||
)
|
||||
.expect("build ws url");
|
||||
assert_eq!(
|
||||
@@ -867,6 +989,21 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn websocket_url_v1_mode_adds_quicksilver_intent() {
|
||||
let url = websocket_url_from_api_url(
|
||||
"https://example.com/v1/realtime?foo=bar",
|
||||
Some(&HashMap::from([("trace".to_string(), "1".to_string())])),
|
||||
Some("snapshot"),
|
||||
RealtimeApiMode::V1,
|
||||
)
|
||||
.expect("build ws url");
|
||||
assert_eq!(
|
||||
url.as_str(),
|
||||
"wss://example.com/v1/realtime?foo=bar&intent=quicksilver&model=snapshot&trace=1"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn e2e_connect_and_exchange_events_against_mock_ws_server() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
|
||||
@@ -1076,6 +1213,7 @@ mod tests {
|
||||
instructions: "backend prompt".to_string(),
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_1".to_string()),
|
||||
mode: RealtimeApiMode::V2,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
@@ -1224,6 +1362,7 @@ mod tests {
|
||||
instructions: "backend prompt".to_string(),
|
||||
model: Some("realtime-test-model".to_string()),
|
||||
session_id: Some("conv_1".to_string()),
|
||||
mode: RealtimeApiMode::V2,
|
||||
},
|
||||
HeaderMap::new(),
|
||||
HeaderMap::new(),
|
||||
|
||||
@@ -7,4 +7,5 @@ pub use methods::RealtimeWebsocketClient;
|
||||
pub use methods::RealtimeWebsocketConnection;
|
||||
pub use methods::RealtimeWebsocketEvents;
|
||||
pub use methods::RealtimeWebsocketWriter;
|
||||
pub use protocol::RealtimeApiMode;
|
||||
pub use protocol::RealtimeSessionConfig;
|
||||
|
||||
@@ -8,11 +8,18 @@ use serde_json::Value;
|
||||
use std::string::ToString;
|
||||
use tracing::debug;
|
||||
|
||||
/// Selects which realtime API variant a websocket session uses.
///
/// The mode is carried on `RealtimeSessionConfig` and decides which outbound
/// message shapes are sent and which inbound event parser is applied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum RealtimeApiMode {
|
||||
V1,
|
||||
V2,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct RealtimeSessionConfig {
|
||||
pub instructions: String,
|
||||
pub model: Option<String>,
|
||||
pub session_id: Option<String>,
|
||||
pub mode: RealtimeApiMode,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
@@ -20,6 +27,11 @@ pub struct RealtimeSessionConfig {
|
||||
pub(super) enum RealtimeOutboundMessage {
|
||||
#[serde(rename = "input_audio_buffer.append")]
|
||||
InputAudioBufferAppend { audio: String },
|
||||
#[serde(rename = "conversation.handoff.append")]
|
||||
ConversationHandoffAppend {
|
||||
handoff_id: String,
|
||||
output_text: String,
|
||||
},
|
||||
#[serde(rename = "response.create")]
|
||||
ResponseCreate,
|
||||
#[serde(rename = "session.update")]
|
||||
@@ -29,24 +41,50 @@ pub(super) enum RealtimeOutboundMessage {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionUpdateSession {
|
||||
#[serde(untagged)]
|
||||
pub(super) enum SessionUpdateSession {
|
||||
V1(SessionUpdateSessionV1),
|
||||
V2(SessionUpdateSessionV2),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionUpdateSessionV1 {
|
||||
#[serde(rename = "type")]
|
||||
pub(super) kind: String,
|
||||
pub(super) instructions: String,
|
||||
pub(super) audio: SessionAudioV1,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionUpdateSessionV2 {
|
||||
#[serde(rename = "type")]
|
||||
pub(super) kind: String,
|
||||
pub(super) instructions: String,
|
||||
pub(super) output_modalities: Vec<String>,
|
||||
pub(super) audio: SessionAudio,
|
||||
pub(super) audio: SessionAudioV2,
|
||||
pub(super) tools: Vec<SessionTool>,
|
||||
pub(super) tool_choice: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionAudio {
|
||||
pub(super) input: SessionAudioInput,
|
||||
pub(super) output: SessionAudioOutput,
|
||||
pub(super) struct SessionAudioV1 {
|
||||
pub(super) input: SessionAudioInputV1,
|
||||
pub(super) output: SessionAudioOutputV1,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionAudioInput {
|
||||
pub(super) struct SessionAudioV2 {
|
||||
pub(super) input: SessionAudioInputV2,
|
||||
pub(super) output: SessionAudioOutputV2,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionAudioInputV1 {
|
||||
pub(super) format: SessionAudioFormat,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionAudioInputV2 {
|
||||
pub(super) format: SessionAudioFormat,
|
||||
pub(super) turn_detection: SessionTurnDetection,
|
||||
}
|
||||
@@ -67,7 +105,12 @@ pub(super) struct SessionTurnDetection {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionAudioOutput {
|
||||
pub(super) struct SessionAudioOutputV1 {
|
||||
pub(super) voice: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(super) struct SessionAudioOutputV2 {
|
||||
pub(super) format: SessionAudioOutputFormat,
|
||||
pub(super) voice: String,
|
||||
}
|
||||
@@ -127,7 +170,7 @@ pub(super) struct ConversationItemContent {
|
||||
pub(super) text: String,
|
||||
}
|
||||
|
||||
pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
|
||||
pub(super) fn parse_realtime_event(payload: &str, mode: RealtimeApiMode) -> Option<RealtimeEvent> {
|
||||
let parsed: Value = match serde_json::from_str(payload) {
|
||||
Ok(msg) => msg,
|
||||
Err(err) => {
|
||||
@@ -143,52 +186,44 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
|
||||
return None;
|
||||
}
|
||||
};
|
||||
match mode {
|
||||
RealtimeApiMode::V1 => parse_realtime_event_v1(&parsed, message_type, payload),
|
||||
RealtimeApiMode::V2 => parse_realtime_event_v2(parsed, message_type),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_realtime_event_v1(
|
||||
parsed: &Value,
|
||||
message_type: &str,
|
||||
payload: &str,
|
||||
) -> Option<RealtimeEvent> {
|
||||
match message_type {
|
||||
"session.created" | "session.updated" => {
|
||||
let session_id = parsed
|
||||
.get("session")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|session| session.get("id"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string);
|
||||
let instructions = parsed
|
||||
.get("session")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|session| session.get("instructions"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string);
|
||||
session_id.map(|session_id| RealtimeEvent::SessionUpdated {
|
||||
session_id,
|
||||
instructions,
|
||||
})
|
||||
}
|
||||
"response.output_audio.delta" => {
|
||||
let data = parsed
|
||||
.get("delta")
|
||||
.and_then(Value::as_str)
|
||||
.or_else(|| parsed.get("data").and_then(Value::as_str))
|
||||
.map(str::to_string)?;
|
||||
let sample_rate = parsed
|
||||
.get("sample_rate")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|v| u32::try_from(v).ok())
|
||||
.unwrap_or(24_000);
|
||||
let num_channels = parsed
|
||||
.get("channels")
|
||||
.or_else(|| parsed.get("num_channels"))
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|v| u16::try_from(v).ok())
|
||||
.unwrap_or(1);
|
||||
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
|
||||
data,
|
||||
sample_rate,
|
||||
num_channels,
|
||||
samples_per_channel: parsed
|
||||
.get("samples_per_channel")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|v| u32::try_from(v).ok()),
|
||||
}))
|
||||
"session.updated" => parse_session_updated(parsed),
|
||||
"conversation.output_audio.delta" => parse_audio_delta(parsed, false),
|
||||
"conversation.item.added" => parsed
|
||||
.get("item")
|
||||
.cloned()
|
||||
.map(RealtimeEvent::ConversationItemAdded),
|
||||
"conversation.item.done" => parsed
|
||||
.get("item")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|item| item.get("id"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.map(|item_id| RealtimeEvent::ConversationItemDone { item_id }),
|
||||
"conversation.handoff.requested" => parse_handoff_requested_v1(parsed),
|
||||
"error" => parse_realtime_error(parsed),
|
||||
_ => {
|
||||
debug!("received unsupported realtime event type: {message_type}, data: {payload}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_realtime_event_v2(parsed: Value, message_type: &str) -> Option<RealtimeEvent> {
|
||||
match message_type {
|
||||
"session.created" | "session.updated" => parse_session_updated(&parsed),
|
||||
"response.output_audio.delta" => parse_audio_delta(&parsed, true),
|
||||
"conversation.item.added" => parsed
|
||||
.get("item")
|
||||
.cloned()
|
||||
@@ -201,30 +236,110 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
|
||||
.map(str::to_string)
|
||||
.map(|item_id| RealtimeEvent::ConversationItemDone { item_id }),
|
||||
"response.done" => {
|
||||
if let Some(handoff) = parse_handoff_requested(&parsed) {
|
||||
if let Some(handoff) = parse_handoff_requested_v2(&parsed) {
|
||||
return Some(RealtimeEvent::HandoffRequested(handoff));
|
||||
}
|
||||
Some(RealtimeEvent::ConversationItemAdded(parsed))
|
||||
}
|
||||
"error" => parsed
|
||||
.get("message")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.or_else(|| {
|
||||
parsed
|
||||
.get("error")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|error| error.get("message"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
})
|
||||
.or_else(|| parsed.get("error").map(ToString::to_string))
|
||||
.map(RealtimeEvent::Error),
|
||||
"error" => parse_realtime_error(&parsed),
|
||||
_ => Some(RealtimeEvent::ConversationItemAdded(parsed)),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_handoff_requested(parsed: &Value) -> Option<RealtimeHandoffRequested> {
|
||||
fn parse_session_updated(parsed: &Value) -> Option<RealtimeEvent> {
|
||||
let session_id = parsed
|
||||
.get("session")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|session| session.get("id"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string);
|
||||
let instructions = parsed
|
||||
.get("session")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|session| session.get("instructions"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string);
|
||||
session_id.map(|session_id| RealtimeEvent::SessionUpdated {
|
||||
session_id,
|
||||
instructions,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_audio_delta(parsed: &Value, default_shape: bool) -> Option<RealtimeEvent> {
|
||||
let data = parsed
|
||||
.get("delta")
|
||||
.and_then(Value::as_str)
|
||||
.or_else(|| parsed.get("data").and_then(Value::as_str))
|
||||
.map(str::to_string)?;
|
||||
let sample_rate = parsed
|
||||
.get("sample_rate")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|v| u32::try_from(v).ok());
|
||||
let num_channels = parsed
|
||||
.get("channels")
|
||||
.or_else(|| parsed.get("num_channels"))
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|v| u16::try_from(v).ok());
|
||||
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
|
||||
data,
|
||||
sample_rate: sample_rate.or_else(|| default_shape.then_some(24_000))?,
|
||||
num_channels: num_channels.or_else(|| default_shape.then_some(1))?,
|
||||
samples_per_channel: parsed
|
||||
.get("samples_per_channel")
|
||||
.and_then(Value::as_u64)
|
||||
.and_then(|v| u32::try_from(v).ok()),
|
||||
}))
|
||||
}
|
||||
|
||||
fn parse_realtime_error(parsed: &Value) -> Option<RealtimeEvent> {
|
||||
parsed
|
||||
.get("message")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
.or_else(|| {
|
||||
parsed
|
||||
.get("error")
|
||||
.and_then(Value::as_object)
|
||||
.and_then(|error| error.get("message"))
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
})
|
||||
.or_else(|| parsed.get("error").map(ToString::to_string))
|
||||
.map(RealtimeEvent::Error)
|
||||
}
|
||||
|
||||
fn parse_handoff_requested_v1(parsed: &Value) -> Option<RealtimeHandoffRequested> {
|
||||
let handoff_id = parsed
|
||||
.get("handoff_id")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
let item_id = parsed
|
||||
.get("item_id")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
let input_transcript = parsed
|
||||
.get("input_transcript")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)?;
|
||||
let messages = parsed
|
||||
.get("messages")
|
||||
.and_then(Value::as_array)?
|
||||
.iter()
|
||||
.filter_map(|message| {
|
||||
let role = message.get("role").and_then(Value::as_str)?.to_string();
|
||||
let text = message.get("text").and_then(Value::as_str)?.to_string();
|
||||
Some(RealtimeHandoffMessage { role, text })
|
||||
})
|
||||
.collect();
|
||||
Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
|
||||
handoff_id,
|
||||
item_id,
|
||||
input_transcript,
|
||||
messages,
|
||||
}))
|
||||
}
|
||||
|
||||
fn parse_handoff_requested_v2(parsed: &Value) -> Option<RealtimeHandoffRequested> {
|
||||
let outputs = parsed
|
||||
.get("response")
|
||||
.and_then(Value::as_object)
|
||||
|
||||
@@ -27,6 +27,7 @@ pub use crate::common::create_text_param_for_request;
|
||||
pub use crate::endpoint::compact::CompactClient;
|
||||
pub use crate::endpoint::memories::MemoriesClient;
|
||||
pub use crate::endpoint::models::ModelsClient;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeApiMode;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeSessionConfig;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeWebsocketClient;
|
||||
pub use crate::endpoint::realtime_websocket::RealtimeWebsocketConnection;
|
||||
|
||||
Reference in New Issue
Block a user