Files
agcore/src/llm/provider/openai.rs
T
徐涛 32f3edaf19 feat(llm): 实现 Phase 0 剩余四个模块
实现 ProviderRegistry、HookExecutor、StreamEvents 和 Auto-compaction 模块,并集成到 LlmCycle 中
2026-06-02 08:51:42 +08:00

236 lines
7.6 KiB
Rust

use std::pin::Pin;
use std::time::Duration;
use async_trait::async_trait;
use bytes::Bytes;
use futures_core::stream::Stream;
use futures_util::StreamExt;
use reqwest::Client;
use tracing::{debug, error, info};
use super::LlmProvider;
use crate::llm::error::LlmError;
use crate::llm::types::{
ChatRequest, ChatResponse, OpenaiChatChunk, OpenaiChatResponse, StreamOptions,
};
/// LLM provider that talks to an OpenAI-style `/chat/completions` HTTP endpoint.
pub struct OpenaiProvider {
// HTTP client used to send every request.
http_client: Client,
// Base URL of the API; a trailing `/` is trimmed before `/chat/completions` is appended.
base_url: String,
// API key sent as `Authorization: Bearer <api_key>`.
api_key: String,
}
impl OpenaiProvider {
    /// Timeout configured on the default HTTP client and reported in
    /// [`LlmError::Timeout`].
    const REQUEST_TIMEOUT: Duration = Duration::from_secs(120);

    /// Builds a provider for the given base URL and API key.
    ///
    /// The third argument (a model name) is accepted but not used by this
    /// constructor.
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be constructed.
    pub fn new(base_url: String, api_key: String, _model: String) -> Self {
        Self {
            http_client: Client::builder()
                .timeout(Self::REQUEST_TIMEOUT)
                .build()
                .expect("创建 HTTP 客户端失败"),
            base_url,
            api_key,
        }
    }

    /// Returns the provider with its HTTP client replaced by `client`.
    pub fn with_client(self, client: Client) -> Self {
        Self {
            http_client: client,
            ..self
        }
    }

    /// Translates a transport-level `reqwest` error into an [`LlmError`].
    ///
    /// Timeouts become [`LlmError::Timeout`]; everything else becomes
    /// [`LlmError::Other`] with a message that distinguishes connection
    /// failures from other request failures.
    fn map_reqwest_error(e: reqwest::Error) -> LlmError {
        if e.is_timeout() {
            return LlmError::Timeout {
                duration: Self::REQUEST_TIMEOUT,
            };
        }
        let reason = if e.is_connect() {
            "连接失败"
        } else {
            "请求失败"
        };
        LlmError::Other(format!("{}: {}", reason, e))
    }
}
#[async_trait]
impl LlmProvider for OpenaiProvider {
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError> {
let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
info!(model = %request.model, max_tokens = request.max_tokens, temperature = request.temperature, "发送 LLM 请求");
let response = self
.http_client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&request)
.send()
.await
.map_err(|e| {
error!(error = %e, "请求失败");
Self::map_reqwest_error(e)
})?;
let status = response.status();
let status_code: u16 = status.as_u16();
if !status.is_success() {
let retry_after = response
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok())
.map(Duration::from_secs);
let body_text = response.text().await.unwrap_or_default();
error!(status = status_code, body = %body_text, "请求失败");
return match status_code {
401 => Err(LlmError::Authentication(body_text)),
429 => Err(LlmError::RateLimit { retry_after }),
_ if status_code >= 500 => Err(LlmError::Request {
status: status_code,
body: body_text,
}),
_ if status_code == 400 && body_text.contains("context_length_exceeded") => {
Err(LlmError::ContextLength {
actual: 0,
limit: 0,
})
}
_ => Err(LlmError::Request {
status: status_code,
body: body_text,
}),
};
}
let body_text = response.text().await.unwrap_or_default();
debug!(body = %body_text, "收到响应体");
let chat_response: OpenaiChatResponse = serde_json::from_str(&body_text).map_err(|e| {
error!(error = %e, body = %body_text, "响应解析失败");
LlmError::Other(format!("响应解析失败: {}", e))
})?;
debug!(response = ?chat_response, "收到 LLM 响应");
Ok(ChatResponse::from(chat_response))
}
async fn chat_stream(
&self,
mut request: ChatRequest,
) -> Result<
Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
LlmError,
> {
let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
request.stream = Some(true);
request.stream_options = Some(StreamOptions {
include_usage: Some(true),
include_obfuscation: None,
});
info!(model = %request.model, "发送 LLM 流式请求");
let response = self
.http_client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&request)
.send()
.await
.map_err(|e| {
error!(error = %e, "流式请求失败");
Self::map_reqwest_error(e)
})?;
let status = response.status();
let status_code: u16 = status.as_u16();
if !status.is_success() {
let body_text = response.text().await.unwrap_or_default();
error!(status = status_code, body = %body_text, "流式请求失败");
return Err(LlmError::Request {
status: status_code,
body: body_text,
});
}
let byte_stream = response.bytes_stream().map(|r| {
r.map_err(|e| LlmError::Other(format!("流式读取失败: {}", e)))
});
Ok(Box::pin(SseChunkStream::new(byte_stream)))
}
}
/// Adapts the raw byte stream of a server-sent-events response into a stream
/// of parsed [`OpenaiChatChunk`] values.
struct SseChunkStream<S> {
    /// Upstream source of raw response bytes.
    inner: S,
    /// Bytes received so far that have not yet been consumed as lines.
    ///
    /// Kept as raw bytes (not `String`) because a multi-byte UTF-8 character
    /// may be split across two network chunks; decoding happens per complete
    /// line instead of per chunk.
    buffer: Vec<u8>,
    /// Set once `inner` has returned `None`, so it is never polled again.
    finished: bool,
}

impl<S: Stream<Item = Result<Bytes, LlmError>> + Unpin> SseChunkStream<S> {
    /// Wraps `stream` with an empty line buffer.
    fn new(stream: S) -> Self {
        Self {
            inner: stream,
            buffer: Vec::new(),
            finished: false,
        }
    }
}

impl<S: Stream<Item = Result<Bytes, LlmError>> + Unpin> Stream for SseChunkStream<S> {
    type Item = Result<OpenaiChatChunk, LlmError>;

    /// Yields the next parsed chunk.
    ///
    /// Buffered lines are drained first; only when no complete line is
    /// available is the inner stream polled for more bytes. When the inner
    /// stream ends, any remaining bytes are treated as a final line even if
    /// they lack a trailing newline.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
        use std::task::Poll;

        // `S: Unpin` makes `Self: Unpin`, so a plain mutable reference is fine.
        let this = &mut *self;
        loop {
            while let Some(line) = take_sse_line(&mut this.buffer, this.finished) {
                if let Some(item) = parse_sse_line(&line) {
                    return Poll::Ready(Some(item));
                }
            }
            if this.finished {
                return Poll::Ready(None);
            }
            match Pin::new(&mut this.inner).poll_next(cx) {
                Poll::Ready(Some(Ok(bytes))) => this.buffer.extend_from_slice(&bytes),
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                // Loop once more so a trailing unterminated line is still parsed.
                Poll::Ready(None) => this.finished = true,
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}

/// Removes and returns the next line (including its `\n`) from `buffer`.
///
/// When `at_eof` is true and no newline remains, the whole non-empty
/// remainder is returned as the last line. Returns `None` when no line is
/// available.
fn take_sse_line(buffer: &mut Vec<u8>, at_eof: bool) -> Option<Vec<u8>> {
    match buffer.iter().position(|&b| b == b'\n') {
        Some(pos) => Some(buffer.drain(..=pos).collect()),
        None if at_eof && !buffer.is_empty() => Some(std::mem::take(buffer)),
        None => None,
    }
}

/// Interprets one raw SSE line.
///
/// Returns `None` for lines that carry no chunk: blank lines, comment lines
/// starting with `:`, `event:` / `id:` / `retry:` fields, empty `data:`
/// fields and the `[DONE]` terminator. Otherwise returns the result of
/// parsing the payload as an [`OpenaiChatChunk`]. A line without a `data:`
/// prefix is parsed as a bare JSON payload.
fn parse_sse_line(raw: &[u8]) -> Option<Result<OpenaiChatChunk, LlmError>> {
    // Best effort: invalid bytes are replaced rather than dropping the line.
    let text = String::from_utf8_lossy(raw);
    let trimmed = text.trim();
    if trimmed.is_empty() || trimmed.starts_with(':') {
        return None;
    }
    let data = match trimmed.strip_prefix("data:") {
        // The space after the colon is optional in the SSE format.
        Some(rest) => rest.trim_start(),
        None => {
            let is_other_field = ["event:", "id:", "retry:"]
                .iter()
                .any(|prefix| trimmed.starts_with(prefix));
            if is_other_field {
                return None;
            }
            trimmed
        }
    };
    if data.is_empty() || data == "[DONE]" {
        return None;
    }
    Some(serde_json::from_str::<OpenaiChatChunk>(data).map_err(|e| {
        LlmError::Other(format!("Chunk 解析失败: {} | raw: {}", e, data))
    }))
}