feat(llm): 添加 LLM 调用周期核心模块

新增 LLM 调用生命周期引擎,包含 Provider 抽象、OpenAI 兼容实现、
可重试机制及 Token 用量追踪。移除原有的占位测试代码。
添加所需的 Rust 依赖(tokio、reqwest、serde 等)。
This commit is contained in:
徐涛
2026-05-12 06:06:24 +08:00
parent b21e163be0
commit 91d32a6a82
10 changed files with 788 additions and 13 deletions
+354
View File
@@ -0,0 +1,354 @@
use std::time::Duration;
use async_trait::async_trait;
use reqwest::Client;
use serde_json::{json, Value};
use crate::llm::cycle::usage::Usage;
use crate::llm::error::LlmError;
use crate::llm::types::{
ChatRequest, ChatResponse, ContentBlock, Message, Role, StopReason, ToolDefinition,
};
use super::LlmProvider;
/// OpenAI 兼容 API 的 Provider 实现。
///
/// 支持任意实现了 `POST /v1/chat/completions` 标准的 API
/// (包括 OpenAI、Azure OpenAI、DashScope、vLLM 等)。
pub struct OpenaiProvider {
http_client: Client,
base_url: String,
api_key: String,
#[allow(dead_code)]
model: String,
}
impl OpenaiProvider {
/// 创建新的 OpenAI Provider。
///
/// 默认使用 120 秒超时的 HTTP 客户端。
pub fn new(base_url: String, api_key: String, model: String) -> Self {
let http_client = Client::builder()
.timeout(Duration::from_secs(120))
.build()
.expect("创建 HTTP 客户端失败");
Self {
http_client,
base_url,
api_key,
model,
}
}
/// 替换为自定义的 HTTP 客户端(用于测试或自定义配置)。
pub fn with_client(mut self, client: Client) -> Self {
self.http_client = client;
self
}
/// 将 ChatRequest 构建为 OpenAI API 请求体 JSON。
fn build_request_body(&self, request: &ChatRequest) -> Value {
let mut body = json!({
"model": request.model,
"messages": Self::serialize_messages(request),
});
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = json!(max_tokens);
}
if let Some(temperature) = request.temperature {
body["temperature"] = json!(temperature);
}
if !request.tools.is_empty() {
body["tools"] = json!(
request
.tools
.iter()
.map(Self::serialize_tool)
.collect::<Vec<_>>()
);
}
// 合并 extra_body 中的扩展参数到请求体顶层
if let Some(ref extra) = request.extra_body
&& let Some(obj) = extra.as_object()
{
for (k, v) in obj {
body[k] = v.clone();
}
}
body
}
/// 将请求中的消息列表序列化为 API 消息数组。
fn serialize_messages(request: &ChatRequest) -> Vec<Value> {
let mut messages: Vec<Value> = Vec::new();
// system_prompt 作为独立的 system 角色消息放在最前面
if let Some(ref system_prompt) = request.system_prompt {
messages.push(json!({
"role": "system",
"content": system_prompt
}));
}
for msg in &request.messages {
messages.push(Self::serialize_message(msg));
}
messages
}
/// 将单条消息序列化为 API 格式。
///
/// 处理逻辑:
/// - 多个 content block 或包含图片 → 使用数组格式
/// - ToolResult → 使用 tool 角色格式
/// - 其他 → 使用纯文本格式
fn serialize_message(msg: &Message) -> Value {
let role_str = match msg.role {
Role::User => "user",
Role::Assistant => "assistant",
Role::System => "system",
Role::Tool => "tool",
};
let has_mixed = msg.content.len() > 1
|| msg
.content
.iter()
.any(|b| matches!(b, ContentBlock::ImageUrl { .. }));
if has_mixed {
let content: Vec<Value> = msg
.content
.iter()
.map(Self::serialize_content_block)
.collect();
json!({ "role": role_str, "content": content })
} else if let Some(ContentBlock::ToolResult {
tool_use_id,
content,
}) = msg.content.first()
{
json!({
"role": "tool",
"tool_call_id": tool_use_id,
"content": content
})
} else {
let text = msg
.content
.first()
.map(|b| match b {
ContentBlock::Text { text } => text.clone(),
_ => String::new(),
})
.unwrap_or_default();
json!({ "role": role_str, "content": text })
}
}
/// 将 ContentBlock 序列化为 API content parts 数组元素。
fn serialize_content_block(block: &ContentBlock) -> Value {
match block {
ContentBlock::Text { text } => {
json!({ "type": "text", "text": text })
}
ContentBlock::ImageUrl { url } => {
json!({ "type": "image_url", "image_url": { "url": url } })
}
ContentBlock::ToolUse { id, name, input } => {
json!({ "type": "tool_use", "id": id, "name": name, "input": input })
}
ContentBlock::ToolResult { .. } => {
json!({ "type": "tool_result", "content": "" })
}
}
}
/// 将 ToolDefinition 序列化为 OpenAI tools 数组元素。
fn serialize_tool(tool: &ToolDefinition) -> Value {
json!({
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.input_schema
}
})
}
/// 将 OpenAI API 响应 JSON 解析为 ChatResponse。
fn parse_response(response: Value) -> Result<ChatResponse, LlmError> {
let choice = response["choices"][0]
.as_object()
.ok_or_else(|| LlmError::Other("响应中缺少 choices[0]".into()))?;
let msg = choice["message"]
.as_object()
.ok_or_else(|| LlmError::Other("响应中缺少 message".into()))?;
let role = match msg["role"].as_str() {
Some("assistant") => Role::Assistant,
Some(_) => Role::Assistant,
None => Role::Assistant,
};
let mut content_blocks: Vec<ContentBlock> = Vec::new();
// 从 content 字段提取文本和 tool_use
if let Some(content_val) = msg.get("content") {
match content_val {
Value::String(s) if !s.is_empty() => {
content_blocks.push(ContentBlock::Text { text: s.clone() });
}
Value::Array(arr) => {
for item in arr {
if let Some(item_type) = item["type"].as_str() {
match item_type {
"text" => {
if let Some(text) = item["text"].as_str() {
content_blocks
.push(ContentBlock::Text { text: text.into() });
}
}
"tool_use" | "function" => {
let id = item["id"].as_str().unwrap_or("").to_string();
let name = item["name"].as_str().unwrap_or("").to_string();
let input = item.get("input").cloned().unwrap_or(Value::Null);
content_blocks
.push(ContentBlock::ToolUse { id, name, input });
}
_ => {}
}
}
}
}
_ => {}
}
}
// 从 tool_calls 字段提取工具调用(OpenAI 特有格式)
if let Some(tool_calls) = msg.get("tool_calls").and_then(|v| v.as_array()) {
for tc in tool_calls {
let id = tc["id"].as_str().unwrap_or("").to_string();
let name = tc["function"]["name"].as_str().unwrap_or("").to_string();
let input = tc["function"]["arguments"]
.as_str()
.and_then(|s| serde_json::from_str(s).ok())
.unwrap_or(Value::Null);
content_blocks.push(ContentBlock::ToolUse { id, name, input });
}
}
if content_blocks.is_empty() {
content_blocks.push(ContentBlock::Text {
text: String::new(),
});
}
// 解析停止原因
let stop_reason = choice["finish_reason"].as_str().map(|s| match s {
"stop" => StopReason::Stop,
"tool_calls" => StopReason::ToolUse,
"max_tokens" => StopReason::MaxTokens,
"length" => StopReason::Length,
"content_filter" => StopReason::ContentFilter,
other => StopReason::Other(other.into()),
});
// 解析 token 用量
let usage = response["usage"]
.as_object()
.map(|u| Usage {
input_tokens: u.get("prompt_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
output_tokens: u
.get("completion_tokens")
.and_then(|v| v.as_u64())
.unwrap_or(0) as u32,
})
.unwrap_or_default();
Ok(ChatResponse {
message: Message {
role,
content: content_blocks,
},
usage,
stop_reason,
})
}
}
#[async_trait]
impl LlmProvider for OpenaiProvider {
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError> {
let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
let body = self.build_request_body(&request);
let response = self
.http_client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| {
if e.is_timeout() {
LlmError::Timeout {
duration: Duration::from_secs(120),
}
} else if e.is_connect() {
LlmError::Other(format!("连接失败: {}", e))
} else {
LlmError::Other(format!("请求失败: {}", e))
}
})?;
let status = response.status();
let status_code: u16 = status.as_u16();
// 处理非 2xx 响应,将 HTTP 状态码映射为对应的 LlmError 变体
if !status.is_success() {
// 在消费 response body 之前先读取 retry-after 头部
let retry_after = response
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok())
.map(Duration::from_secs);
let body_text = response.text().await.unwrap_or_default();
return match status_code {
401 => Err(LlmError::Authentication(body_text)),
429 => Err(LlmError::RateLimit { retry_after }),
_ if status_code >= 500 => Err(LlmError::Request {
status: status_code,
body: body_text,
}),
_ if status_code == 400 && body_text.contains("context_length_exceeded") => {
Err(LlmError::ContextLength {
actual: 0,
limit: 0,
})
}
_ => Err(LlmError::Request {
status: status_code,
body: body_text,
}),
};
}
let json_body: Value = response
.json()
.await
.map_err(|e| LlmError::Other(format!("响应解析失败: {}", e)))?;
Self::parse_response(json_body)
}
}