Simplification 2

This commit is contained in:
jif-oai
2025-11-12 15:53:18 +00:00
parent f5918d7e1b
commit 166ca2fce7
9 changed files with 395 additions and 241 deletions

View File

@@ -12,6 +12,14 @@ use crate::error::Error;
use crate::error::Result;
use crate::stream::WireEvent;
async fn send_wire_event(tx: &mpsc::Sender<crate::error::Result<WireEvent>>, event: WireEvent) {
let _ = tx.send(Ok(event)).await;
}
fn serialize_response_item(item: ResponseItem) -> Value {
serde_json::to_value(item).unwrap_or_else(|_| Value::String(String::new()))
}
#[derive(Default)]
struct FunctionCallState {
active: bool,
@@ -34,6 +42,171 @@ impl WireChatSseDecoder {
pub fn new() -> Self {
Self::default()
}
async fn emit_created_once(&mut self, tx: &mpsc::Sender<crate::error::Result<WireEvent>>) {
if self.created_emitted {
return;
}
send_wire_event(tx, WireEvent::Created).await;
self.created_emitted = true;
}
async fn handle_content_delta(
&mut self,
delta: &Value,
tx: &mpsc::Sender<crate::error::Result<WireEvent>>,
) {
if let Some(content) = delta.get("content").and_then(|c| c.as_array()) {
for piece in content {
if let Some(text) = piece.get("text").and_then(|t| t.as_str()) {
self.push_assistant_text(text, tx).await;
}
}
}
if let Some(reasoning) = delta.get("reasoning_content").and_then(|c| c.as_array()) {
for entry in reasoning {
if let Some(text) = entry.get("text").and_then(|t| t.as_str()) {
self.push_reasoning_text(text, tx).await;
}
}
}
}
async fn push_assistant_text(
&mut self,
text: &str,
tx: &mpsc::Sender<crate::error::Result<WireEvent>>,
) {
self.start_assistant(tx).await;
self.assistant_text.push_str(text);
send_wire_event(tx, WireEvent::OutputTextDelta(text.to_string())).await;
}
async fn push_reasoning_text(
&mut self,
text: &str,
tx: &mpsc::Sender<crate::error::Result<WireEvent>>,
) {
self.start_reasoning(tx).await;
self.reasoning_text.push_str(text);
send_wire_event(tx, WireEvent::ReasoningContentDelta(text.to_string())).await;
}
async fn start_assistant(&mut self, tx: &mpsc::Sender<crate::error::Result<WireEvent>>) {
if self.assistant_started {
return;
}
self.assistant_started = true;
let message = ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: String::new(),
}],
};
send_wire_event(
tx,
WireEvent::OutputItemAdded(serialize_response_item(message)),
)
.await;
}
async fn start_reasoning(&mut self, tx: &mpsc::Sender<crate::error::Result<WireEvent>>) {
if self.reasoning_started {
return;
}
self.reasoning_started = true;
let reasoning_item = ResponseItem::Reasoning {
id: String::new(),
summary: vec![],
content: None,
encrypted_content: None,
};
send_wire_event(
tx,
WireEvent::OutputItemAdded(serialize_response_item(reasoning_item)),
)
.await;
}
fn record_tool_calls(&mut self, delta: &Value) {
if let Some(tool_calls) = delta.get("tool_calls").and_then(|c| c.as_array()) {
for call in tool_calls {
if let Some(id_val) = call.get("id").and_then(|id| id.as_str()) {
self.fn_call_state.call_id = Some(id_val.to_string());
}
if let Some(function) = call.get("function") {
if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
self.fn_call_state.name = Some(name.to_string());
self.fn_call_state.active = true;
}
if let Some(args) = function.get("arguments").and_then(|a| a.as_str()) {
self.fn_call_state.arguments.push_str(args);
}
}
}
}
}
fn finish_function_call(&mut self) -> Option<Value> {
if !self.fn_call_state.active {
return None;
}
let function_name = self.fn_call_state.name.take().unwrap_or_default();
let call_id = self.fn_call_state.call_id.take().unwrap_or_default();
let arguments = std::mem::take(&mut self.fn_call_state.arguments);
self.fn_call_state = FunctionCallState::default();
Some(serde_json::json!({
"type": "function_call",
"id": call_id,
"call_id": call_id,
"name": function_name,
"arguments": arguments,
}))
}
fn finish_reasoning(&mut self) -> Option<Value> {
if !self.reasoning_started {
return None;
}
let mut content = Vec::new();
let text = std::mem::take(&mut self.reasoning_text);
if !text.is_empty() {
content.push(ReasoningItemContent::ReasoningText { text });
}
self.reasoning_started = false;
Some(serialize_response_item(ResponseItem::Reasoning {
id: String::new(),
summary: vec![],
content: Some(content),
encrypted_content: None,
}))
}
fn finish_assistant(&mut self) -> Option<Value> {
if !self.assistant_started {
return None;
}
let text = std::mem::take(&mut self.assistant_text);
self.assistant_started = false;
Some(serialize_response_item(ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText { text }],
}))
}
fn reset_reasoning_and_assistant(&mut self) {
self.assistant_started = false;
self.assistant_text.clear();
self.reasoning_started = false;
self.reasoning_text.clear();
}
}
#[async_trait]
@@ -57,137 +230,39 @@ impl WireResponseDecoder for WireChatSseDecoder {
.unwrap_or_default();
for choice in choices {
if !self.created_emitted {
let _ = tx.send(Ok(WireEvent::Created)).await;
self.created_emitted = true;
}
self.emit_created_once(tx).await;
if let Some(delta) = choice.get("delta") {
if let Some(content) = delta.get("content").and_then(|c| c.as_array()) {
for piece in content {
if let Some(text) = piece.get("text").and_then(|t| t.as_str()) {
if !self.assistant_started {
self.assistant_started = true;
let message = ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: String::new(),
}],
};
let value = serde_json::to_value(message)
.unwrap_or_else(|_| Value::String(String::new()));
let _ = tx.send(Ok(WireEvent::OutputItemAdded(value))).await;
}
self.assistant_text.push_str(text);
let _ = tx
.send(Ok(WireEvent::OutputTextDelta(text.to_string())))
.await;
}
}
}
if let Some(tool_calls) = delta.get("tool_calls").and_then(|c| c.as_array()) {
for call in tool_calls {
if let Some(id_val) = call.get("id").and_then(|id| id.as_str()) {
self.fn_call_state.call_id = Some(id_val.to_string());
}
if let Some(function) = call.get("function") {
if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
self.fn_call_state.name = Some(name.to_string());
self.fn_call_state.active = true;
}
if let Some(args) = function.get("arguments").and_then(|a| a.as_str()) {
self.fn_call_state.arguments.push_str(args);
}
}
}
}
if let Some(reasoning) = delta.get("reasoning_content").and_then(|c| c.as_array()) {
for entry in reasoning {
if let Some(text) = entry.get("text").and_then(|t| t.as_str()) {
if !self.reasoning_started {
self.reasoning_started = true;
let reasoning_item = ResponseItem::Reasoning {
id: String::new(),
summary: vec![],
content: None,
encrypted_content: None,
};
let value = serde_json::to_value(reasoning_item)
.unwrap_or_else(|_| Value::String(String::new()));
let _ = tx.send(Ok(WireEvent::OutputItemAdded(value))).await;
}
self.reasoning_text.push_str(text);
let _ = tx
.send(Ok(WireEvent::ReasoningContentDelta(text.to_string())))
.await;
}
}
}
self.handle_content_delta(delta, tx).await;
self.record_tool_calls(delta);
}
if let Some(finish_reason) = choice.get("finish_reason").and_then(|f| f.as_str()) {
match finish_reason {
"tool_calls" if self.fn_call_state.active => {
let function_name = self.fn_call_state.name.take().unwrap_or_default();
let call_id = self.fn_call_state.call_id.take().unwrap_or_default();
let arguments = self.fn_call_state.arguments.clone();
self.fn_call_state = FunctionCallState::default();
let item = serde_json::json!({
"type": "function_call",
"id": call_id,
"call_id": call_id,
"name": function_name,
"arguments": arguments,
});
let _ = tx.send(Ok(WireEvent::OutputItemDone(item))).await;
"tool_calls" => {
if let Some(item) = self.finish_function_call() {
send_wire_event(tx, WireEvent::OutputItemDone(item)).await;
}
}
"stop" | "length" => {
if self.reasoning_started {
let mut content = Vec::new();
if !self.reasoning_text.is_empty() {
content.push(ReasoningItemContent::ReasoningText {
text: self.reasoning_text.clone(),
});
}
let reasoning_item = ResponseItem::Reasoning {
id: String::new(),
summary: vec![],
content: Some(content),
encrypted_content: None,
};
let value = serde_json::to_value(reasoning_item)
.unwrap_or_else(|_| Value::String(String::new()));
let _ = tx.send(Ok(WireEvent::OutputItemDone(value))).await;
if let Some(reasoning_item) = self.finish_reasoning() {
send_wire_event(tx, WireEvent::OutputItemDone(reasoning_item)).await;
}
if self.assistant_started {
let message = ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: self.assistant_text.clone(),
}],
};
let value = serde_json::to_value(message)
.unwrap_or_else(|_| Value::String(String::new()));
let _ = tx.send(Ok(WireEvent::OutputItemDone(value))).await;
if let Some(message) = self.finish_assistant() {
send_wire_event(tx, WireEvent::OutputItemDone(message)).await;
}
let _ = tx
.send(Ok(WireEvent::Completed {
send_wire_event(
tx,
WireEvent::Completed {
response_id: String::new(),
token_usage: None,
}))
.await;
},
)
.await;
self.assistant_started = false;
self.assistant_text.clear();
self.reasoning_started = false;
self.reasoning_text.clear();
self.reset_reasoning_and_assistant();
}
_ => {}
}