use std::pin::Pin; use std::time::Duration; use async_trait::async_trait; use bytes::Bytes; use futures_core::stream::Stream; use futures_util::StreamExt; use reqwest::Client; use tracing::{debug, error, info}; use super::LlmProvider; use crate::llm::error::LlmError; use crate::llm::types::{ ChatRequest, ChatResponse, OpenaiChatChunk, OpenaiChatResponse, StreamOptions, }; pub struct OpenaiProvider { http_client: Client, base_url: String, api_key: String, } impl OpenaiProvider { pub fn new(base_url: String, api_key: String, _model: String) -> Self { let http_client = Client::builder() .timeout(Duration::from_secs(120)) .build() .expect("创建 HTTP 客户端失败"); Self { http_client, base_url, api_key, } } pub fn with_client(mut self, client: Client) -> Self { self.http_client = client; self } fn map_reqwest_error(e: reqwest::Error) -> LlmError { if e.is_timeout() { LlmError::Timeout { duration: Duration::from_secs(120), } } else if e.is_connect() { LlmError::Other(format!("连接失败: {}", e)) } else { LlmError::Other(format!("请求失败: {}", e)) } } } #[async_trait] impl LlmProvider for OpenaiProvider { async fn chat(&self, request: ChatRequest) -> Result { let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/')); info!(model = %request.model, max_tokens = request.max_tokens, temperature = request.temperature, "发送 LLM 请求"); let response = self .http_client .post(&url) .header("Authorization", format!("Bearer {}", self.api_key)) .json(&request) .send() .await .map_err(|e| { error!(error = %e, "请求失败"); Self::map_reqwest_error(e) })?; let status = response.status(); let status_code: u16 = status.as_u16(); if !status.is_success() { let retry_after = response .headers() .get("retry-after") .and_then(|v| v.to_str().ok()) .and_then(|v| v.parse::().ok()) .map(Duration::from_secs); let body_text = response.text().await.unwrap_or_default(); error!(status = status_code, body = %body_text, "请求失败"); return match status_code { 401 => Err(LlmError::Authentication(body_text)), 429 => Err(LlmError::RateLimit { retry_after }), _ if status_code >= 500 => Err(LlmError::Request { status: status_code, body: body_text, }), _ if status_code == 400 && body_text.contains("context_length_exceeded") => { Err(LlmError::ContextLength { actual: 0, limit: 0, }) } _ => Err(LlmError::Request { status: status_code, body: body_text, }), }; } let body_text = response.text().await.unwrap_or_default(); debug!(body = %body_text, "收到响应体"); let chat_response: OpenaiChatResponse = serde_json::from_str(&body_text).map_err(|e| { error!(error = %e, body = %body_text, "响应解析失败"); LlmError::Other(format!("响应解析失败: {}", e)) })?; debug!(response = ?chat_response, "收到 LLM 响应"); Ok(ChatResponse::from(chat_response)) } async fn chat_stream( &self, mut request: ChatRequest, ) -> Result< Pin> + Send>>, LlmError, > { let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/')); request.stream = Some(true); request.stream_options = Some(StreamOptions { include_usage: Some(true), include_obfuscation: None, }); info!(model = %request.model, "发送 LLM 流式请求"); let response = self .http_client .post(&url) .header("Authorization", format!("Bearer {}", self.api_key)) .json(&request) .send() .await .map_err(|e| { error!(error = %e, "流式请求失败"); Self::map_reqwest_error(e) })?; let status = response.status(); let status_code: u16 = status.as_u16(); if !status.is_success() { let body_text = response.text().await.unwrap_or_default(); error!(status = status_code, body = %body_text, "流式请求失败"); return Err(LlmError::Request { status: status_code, body: body_text, }); } let byte_stream = response.bytes_stream().map(|r| { r.map_err(|e| LlmError::Other(format!("流式读取失败: {}", e))) }); Ok(Box::pin(SseChunkStream::new(byte_stream))) } } struct SseChunkStream { inner: S, buffer: String, } impl> + Unpin> SseChunkStream { fn new(stream: S) -> Self { Self { inner: stream, buffer: String::new(), } } } impl> + Unpin> Stream for SseChunkStream { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { loop { if let Some(pos) = self.buffer.find("\n") { let line = self.buffer.drain(..pos + 1).collect::(); let trimmed = line.trim(); if trimmed.is_empty() || trimmed == "data: [DONE]" || trimmed == "[DONE]" || trimmed == "data:" { continue; } let data = if let Some(p) = trimmed.strip_prefix("data: ") { p } else { trimmed }; match serde_json::from_str::(data) { Ok(chunk) => return std::task::Poll::Ready(Some(Ok(chunk))), Err(e) => { return std::task::Poll::Ready(Some(Err(LlmError::Other(format!( "Chunk 解析失败: {} | raw: {}", e, data ))))); } } } match Pin::new(&mut self.inner).poll_next(cx) { std::task::Poll::Ready(Some(Ok(bytes))) => { if let Ok(s) = std::str::from_utf8(&bytes) { self.buffer.push_str(s); } } std::task::Poll::Ready(Some(Err(e))) => { return std::task::Poll::Ready(Some(Err(e))); } std::task::Poll::Ready(None) => { if self.buffer.is_empty() { return std::task::Poll::Ready(None); } self.buffer.clear(); return std::task::Poll::Ready(None); } std::task::Poll::Pending => return std::task::Poll::Pending, } } } }