32f3edaf19
实现 ProviderRegistry、HookExecutor、StreamEvents 和 Auto-compaction 模块,并集成到 LlmCycle 中
236 lines
7.6 KiB
Rust
236 lines
7.6 KiB
Rust
use std::pin::Pin;
|
|
use std::time::Duration;
|
|
|
|
use async_trait::async_trait;
|
|
use bytes::Bytes;
|
|
use futures_core::stream::Stream;
|
|
use futures_util::StreamExt;
|
|
use reqwest::Client;
|
|
use tracing::{debug, error, info};
|
|
|
|
use super::LlmProvider;
|
|
use crate::llm::error::LlmError;
|
|
use crate::llm::types::{
|
|
ChatRequest, ChatResponse, OpenaiChatChunk, OpenaiChatResponse, StreamOptions,
|
|
};
|
|
|
|
pub struct OpenaiProvider {
|
|
http_client: Client,
|
|
base_url: String,
|
|
api_key: String,
|
|
}
|
|
|
|
impl OpenaiProvider {
|
|
pub fn new(base_url: String, api_key: String, _model: String) -> Self {
|
|
let http_client = Client::builder()
|
|
.timeout(Duration::from_secs(120))
|
|
.build()
|
|
.expect("创建 HTTP 客户端失败");
|
|
|
|
Self {
|
|
http_client,
|
|
base_url,
|
|
api_key,
|
|
}
|
|
}
|
|
|
|
pub fn with_client(mut self, client: Client) -> Self {
|
|
self.http_client = client;
|
|
self
|
|
}
|
|
|
|
fn map_reqwest_error(e: reqwest::Error) -> LlmError {
|
|
if e.is_timeout() {
|
|
LlmError::Timeout {
|
|
duration: Duration::from_secs(120),
|
|
}
|
|
} else if e.is_connect() {
|
|
LlmError::Other(format!("连接失败: {}", e))
|
|
} else {
|
|
LlmError::Other(format!("请求失败: {}", e))
|
|
}
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl LlmProvider for OpenaiProvider {
|
|
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError> {
|
|
let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
|
|
|
|
info!(model = %request.model, max_tokens = request.max_tokens, temperature = request.temperature, "发送 LLM 请求");
|
|
|
|
let response = self
|
|
.http_client
|
|
.post(&url)
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
.json(&request)
|
|
.send()
|
|
.await
|
|
.map_err(|e| {
|
|
error!(error = %e, "请求失败");
|
|
Self::map_reqwest_error(e)
|
|
})?;
|
|
|
|
let status = response.status();
|
|
let status_code: u16 = status.as_u16();
|
|
|
|
if !status.is_success() {
|
|
let retry_after = response
|
|
.headers()
|
|
.get("retry-after")
|
|
.and_then(|v| v.to_str().ok())
|
|
.and_then(|v| v.parse::<u64>().ok())
|
|
.map(Duration::from_secs);
|
|
let body_text = response.text().await.unwrap_or_default();
|
|
|
|
error!(status = status_code, body = %body_text, "请求失败");
|
|
|
|
return match status_code {
|
|
401 => Err(LlmError::Authentication(body_text)),
|
|
429 => Err(LlmError::RateLimit { retry_after }),
|
|
_ if status_code >= 500 => Err(LlmError::Request {
|
|
status: status_code,
|
|
body: body_text,
|
|
}),
|
|
_ if status_code == 400 && body_text.contains("context_length_exceeded") => {
|
|
Err(LlmError::ContextLength {
|
|
actual: 0,
|
|
limit: 0,
|
|
})
|
|
}
|
|
_ => Err(LlmError::Request {
|
|
status: status_code,
|
|
body: body_text,
|
|
}),
|
|
};
|
|
}
|
|
|
|
let body_text = response.text().await.unwrap_or_default();
|
|
debug!(body = %body_text, "收到响应体");
|
|
|
|
let chat_response: OpenaiChatResponse = serde_json::from_str(&body_text).map_err(|e| {
|
|
error!(error = %e, body = %body_text, "响应解析失败");
|
|
LlmError::Other(format!("响应解析失败: {}", e))
|
|
})?;
|
|
|
|
debug!(response = ?chat_response, "收到 LLM 响应");
|
|
|
|
Ok(ChatResponse::from(chat_response))
|
|
}
|
|
|
|
async fn chat_stream(
|
|
&self,
|
|
mut request: ChatRequest,
|
|
) -> Result<
|
|
Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
|
LlmError,
|
|
> {
|
|
let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
|
|
|
|
request.stream = Some(true);
|
|
request.stream_options = Some(StreamOptions {
|
|
include_usage: Some(true),
|
|
include_obfuscation: None,
|
|
});
|
|
|
|
info!(model = %request.model, "发送 LLM 流式请求");
|
|
|
|
let response = self
|
|
.http_client
|
|
.post(&url)
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
.json(&request)
|
|
.send()
|
|
.await
|
|
.map_err(|e| {
|
|
error!(error = %e, "流式请求失败");
|
|
Self::map_reqwest_error(e)
|
|
})?;
|
|
|
|
let status = response.status();
|
|
let status_code: u16 = status.as_u16();
|
|
|
|
if !status.is_success() {
|
|
let body_text = response.text().await.unwrap_or_default();
|
|
error!(status = status_code, body = %body_text, "流式请求失败");
|
|
return Err(LlmError::Request {
|
|
status: status_code,
|
|
body: body_text,
|
|
});
|
|
}
|
|
|
|
let byte_stream = response.bytes_stream().map(|r| {
|
|
r.map_err(|e| LlmError::Other(format!("流式读取失败: {}", e)))
|
|
});
|
|
|
|
Ok(Box::pin(SseChunkStream::new(byte_stream)))
|
|
}
|
|
}
|
|
|
|
struct SseChunkStream<S> {
|
|
inner: S,
|
|
buffer: String,
|
|
}
|
|
|
|
impl<S: Stream<Item = Result<Bytes, LlmError>> + Unpin> SseChunkStream<S> {
|
|
fn new(stream: S) -> Self {
|
|
Self {
|
|
inner: stream,
|
|
buffer: String::new(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<S: Stream<Item = Result<Bytes, LlmError>> + Unpin> Stream for SseChunkStream<S> {
|
|
type Item = Result<OpenaiChatChunk, LlmError>;
|
|
|
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
|
|
loop {
|
|
if let Some(pos) = self.buffer.find("\n") {
|
|
let line = self.buffer.drain(..pos + 1).collect::<String>();
|
|
let trimmed = line.trim();
|
|
if trimmed.is_empty()
|
|
|| trimmed == "data: [DONE]"
|
|
|| trimmed == "[DONE]"
|
|
|| trimmed == "data:"
|
|
{
|
|
continue;
|
|
}
|
|
let data = if let Some(p) = trimmed.strip_prefix("data: ") {
|
|
p
|
|
} else {
|
|
trimmed
|
|
};
|
|
match serde_json::from_str::<OpenaiChatChunk>(data) {
|
|
Ok(chunk) => return std::task::Poll::Ready(Some(Ok(chunk))),
|
|
Err(e) => {
|
|
return std::task::Poll::Ready(Some(Err(LlmError::Other(format!(
|
|
"Chunk 解析失败: {} | raw: {}",
|
|
e, data
|
|
)))));
|
|
}
|
|
}
|
|
}
|
|
|
|
match Pin::new(&mut self.inner).poll_next(cx) {
|
|
std::task::Poll::Ready(Some(Ok(bytes))) => {
|
|
if let Ok(s) = std::str::from_utf8(&bytes) {
|
|
self.buffer.push_str(s);
|
|
}
|
|
}
|
|
std::task::Poll::Ready(Some(Err(e))) => {
|
|
return std::task::Poll::Ready(Some(Err(e)));
|
|
}
|
|
std::task::Poll::Ready(None) => {
|
|
if self.buffer.is_empty() {
|
|
return std::task::Poll::Ready(None);
|
|
}
|
|
self.buffer.clear();
|
|
return std::task::Poll::Ready(None);
|
|
}
|
|
std::task::Poll::Pending => return std::task::Poll::Pending,
|
|
}
|
|
}
|
|
}
|
|
}
|