feat(llm): 添加 LLM 调用周期核心模块
新增 LLM 调用生命周期引擎,包含 Provider 抽象、OpenAI 兼容实现、可重试机制及 Token 用量追踪。移除原有的占位测试代码。添加所需的 Rust 依赖(tokio、reqwest、serde 等)。
This commit is contained in:
@@ -4,3 +4,10 @@ version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
reqwest = { version = "0.12", features = ["json"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
thiserror = "2"
|
||||
async-trait = "0.1"
|
||||
tracing = "0.1"
|
||||
|
||||
+5
-13
@@ -1,14 +1,6 @@
|
||||
/// Returns the sum of `left` and `right`.
pub fn add(left: u64, right: u64) -> u64 {
    [left, right].into_iter().sum()
}
|
||||
//! agcore —— 智能体(Agent)核心工具箱。
|
||||
//!
|
||||
//! 当前提供 LLM 调用周期控制作为核心底层能力,后续将扩展至
|
||||
//! 提示词工程、记忆系统、工具调用、Agent 运行时等领域。
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::add;

    /// Smoke test: adding two and two yields four.
    #[test]
    fn it_works() {
        assert_eq!(add(2, 2), 4);
    }
}
|
||||
pub mod llm;
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
//! LLM 调用周期 —— 大模型基础调用周期控制。
|
||||
//!
|
||||
//! 包含核心数据类型、Provider 抽象、OpenAI 兼容实现以及生命周期引擎。
|
||||
|
||||
pub mod cycle;
|
||||
pub mod error;
|
||||
pub mod provider;
|
||||
pub mod types;
|
||||
@@ -0,0 +1,149 @@
|
||||
mod retry;
|
||||
pub mod usage;
|
||||
|
||||
pub use retry::RetryConfig;
|
||||
pub use usage::{CostTracker, Usage};
|
||||
|
||||
use crate::llm::cycle::retry::should_retry;
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::provider::LlmProvider;
|
||||
use crate::llm::types::{ChatRequest, ChatResponse, ContentBlock, Message, Role, ToolDefinition};
|
||||
|
||||
/// LLM 生命周期引擎的配置。
|
||||
pub struct CycleConfig {
|
||||
/// 使用的模型名称。
|
||||
pub model: String,
|
||||
/// 最大输出 token 数。
|
||||
pub max_tokens: Option<u32>,
|
||||
/// 采样温度。
|
||||
pub temperature: Option<f32>,
|
||||
/// 最大对话轮数(预留,暂未使用)。
|
||||
pub max_turns: Option<u32>,
|
||||
/// 重试策略配置。
|
||||
pub retry: RetryConfig,
|
||||
}
|
||||
|
||||
impl Default for CycleConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
model: String::from("gpt-4o"),
|
||||
max_tokens: None,
|
||||
temperature: None,
|
||||
max_turns: None,
|
||||
retry: RetryConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// LLM 调用生命周期引擎。
|
||||
///
|
||||
/// 管理一次多轮交互的完整生命周期,包括:
|
||||
/// - 消息历史维护
|
||||
/// - Token 用量追踪
|
||||
/// - 自动重试
|
||||
pub struct LlmCycle {
|
||||
provider: Box<dyn LlmProvider>,
|
||||
config: CycleConfig,
|
||||
usage: CostTracker,
|
||||
messages: Vec<Message>,
|
||||
system_prompt: Option<String>,
|
||||
}
|
||||
|
||||
impl LlmCycle {
|
||||
/// 创建新的生命周期引擎。
|
||||
pub fn new(provider: Box<dyn LlmProvider>, config: CycleConfig) -> Self {
|
||||
Self {
|
||||
provider,
|
||||
config,
|
||||
usage: CostTracker::default(),
|
||||
messages: Vec::new(),
|
||||
system_prompt: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// 设置系统提示词(Builder 模式)。
|
||||
pub fn with_system_prompt(mut self, prompt: String) -> Self {
|
||||
self.system_prompt = Some(prompt);
|
||||
self
|
||||
}
|
||||
|
||||
/// 获取 Token 用量追踪器引用。
|
||||
pub fn usage(&self) -> &CostTracker {
|
||||
&self.usage
|
||||
}
|
||||
|
||||
/// 获取当前消息历史。
|
||||
pub fn messages(&self) -> &[Message] {
|
||||
&self.messages
|
||||
}
|
||||
|
||||
/// 清空消息历史。
|
||||
pub fn clear_messages(&mut self) {
|
||||
self.messages.clear();
|
||||
}
|
||||
|
||||
/// 重置 Token 用量统计。
|
||||
pub fn reset_usage(&mut self) {
|
||||
self.usage.reset();
|
||||
}
|
||||
|
||||
/// 提交一条用户消息并获取模型响应。
|
||||
///
|
||||
/// 流程:
|
||||
/// 1. 将用户消息追加到消息历史
|
||||
/// 2. 构建 ChatRequest
|
||||
/// 3. 使用重试循环调用 provider.chat()
|
||||
/// 4. 将助手回复追加到消息历史
|
||||
/// 5. 累计 token 用量
|
||||
/// 6. 返回 ChatResponse
|
||||
pub async fn submit(
|
||||
&mut self,
|
||||
prompt: String,
|
||||
tools: Vec<ToolDefinition>,
|
||||
) -> Result<ChatResponse, LlmError> {
|
||||
self.messages.push(Message {
|
||||
role: Role::User,
|
||||
content: vec![ContentBlock::Text { text: prompt }],
|
||||
});
|
||||
|
||||
let mut attempts = 0;
|
||||
|
||||
loop {
|
||||
let request = self.build_request(&tools);
|
||||
|
||||
match self.provider.chat(request).await {
|
||||
Ok(response) => {
|
||||
self.messages.push(Message {
|
||||
role: Role::Assistant,
|
||||
content: response.message.content.clone(),
|
||||
});
|
||||
|
||||
self.usage.add(&response.usage);
|
||||
|
||||
return Ok(response);
|
||||
}
|
||||
Err(e) if should_retry(&e) && attempts < self.config.retry.max_retries => {
|
||||
attempts += 1;
|
||||
let delay = self.config.retry.compute_delay(attempts);
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 根据当前状态构建 ChatRequest。
|
||||
fn build_request(&self, tools: &[ToolDefinition]) -> ChatRequest {
|
||||
ChatRequest {
|
||||
model: self.config.model.clone(),
|
||||
messages: self.messages.clone(),
|
||||
system_prompt: self.system_prompt.clone(),
|
||||
tools: tools.to_vec(),
|
||||
max_tokens: self.config.max_tokens,
|
||||
temperature: self.config.temperature,
|
||||
extra_body: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::llm::error::LlmError;
|
||||
|
||||
/// 重试策略配置。
|
||||
///
|
||||
/// 使用指数退避 + jitter 算法计算每次重试的等待时间。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetryConfig {
|
||||
/// 最大重试次数(默认 3)。
|
||||
pub max_retries: u32,
|
||||
/// 初始延迟(默认 1 秒)。
|
||||
pub base_delay: Duration,
|
||||
/// 最大延迟上限(默认 30 秒)。
|
||||
pub max_delay: Duration,
|
||||
/// Jitter 比例因子(默认 0.25)。
|
||||
pub jitter_factor: f64,
|
||||
}
|
||||
|
||||
impl Default for RetryConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_retries: 3,
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_delay: Duration::from_secs(30),
|
||||
jitter_factor: 0.25,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RetryConfig {
|
||||
/// 根据当前重试次数计算等待时间。
|
||||
///
|
||||
/// 算法: `delay = min(base * 2^(attempt-1), max_delay) + random(0, delay * jitter_factor)`
|
||||
pub fn compute_delay(&self, attempt: u32) -> Duration {
|
||||
let base = self.base_delay.as_secs_f64();
|
||||
let exponential = base * (2u64.pow(attempt.saturating_sub(1))) as f64;
|
||||
let capped = exponential.min(self.max_delay.as_secs_f64());
|
||||
let jitter = rand_jitter(capped * self.jitter_factor);
|
||||
|
||||
Duration::from_secs_f64(capped + jitter)
|
||||
}
|
||||
}
|
||||
|
||||
/// 判断错误是否可重试。
|
||||
///
|
||||
/// 可重试条件:
|
||||
/// - RateLimit(429)
|
||||
/// - Timeout
|
||||
/// - Request 且状态码 >= 500 或 == 429
|
||||
pub fn should_retry(err: &LlmError) -> bool {
|
||||
match err {
|
||||
LlmError::RateLimit { .. } => true,
|
||||
LlmError::Timeout { .. } => true,
|
||||
LlmError::Request { status, .. } => *status >= 500 || *status == 429,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a cheap pseudo-random value in `[0, max)`; `0.0` when `max <= 0`.
///
/// The randomness comes from the standard library's randomly keyed
/// `RandomState` hasher instead of the low digits of the wall clock. The
/// clock-based approach only works when the system clock really has
/// nanosecond granularity; with a coarser clock `nanos % 1000` takes very few
/// distinct values (possibly only `0`), which would defeat the purpose of
/// jitter.
fn rand_jitter(max: f64) -> f64 {
    if max <= 0.0 {
        return 0.0;
    }

    // Every `RandomState::new()` is keyed differently, so finishing an empty
    // hasher yields a fresh 64-bit value on each call.
    let state = std::collections::hash_map::RandomState::new();
    let bits = std::hash::Hasher::finish(&std::hash::BuildHasher::build_hasher(&state));

    // Keep the top 53 bits (the f64 mantissa width) to get a fraction in [0, 1).
    let unit = (bits >> 11) as f64 / (1u64 << 53) as f64;
    unit * max
}
|
||||
@@ -0,0 +1,42 @@
|
||||
/// Token usage of a single request.
#[derive(Debug, Clone, Default)]
pub struct Usage {
    /// Tokens consumed by the input (prompt).
    pub input_tokens: u32,
    /// Tokens consumed by the output (generated content).
    pub output_tokens: u32,
}

/// Running total of token usage.
///
/// Accumulates the token consumption of every request in a multi-turn
/// conversation.
#[derive(Debug, Default)]
pub struct CostTracker {
    accumulated: Usage,
}

impl CostTracker {
    /// Adds the usage of one request to the running total.
    ///
    /// Counters saturate at `u32::MAX` instead of overflowing.
    pub fn add(&mut self, usage: &Usage) {
        let total = &mut self.accumulated;
        total.input_tokens = total.input_tokens.saturating_add(usage.input_tokens);
        total.output_tokens = total.output_tokens.saturating_add(usage.output_tokens);
    }

    /// Returns the accumulated usage.
    pub fn total(&self) -> &Usage {
        &self.accumulated
    }

    /// Resets the accumulated usage to zero.
    pub fn reset(&mut self) {
        *self = Self::default();
    }
}
|
||||
@@ -0,0 +1,37 @@
|
||||
use std::time::Duration;
|
||||
|
||||
/// LLM 调用过程中可能发生的所有错误。
|
||||
///
|
||||
/// 错误按可重试性分为两类:
|
||||
/// - **可重试**:`RateLimit`、`Timeout`、状态码 >= 500
|
||||
/// - **不可重试**:`Authentication`、`ContextLength`、状态码 4xx(除 429)
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum LlmError {
|
||||
/// API 认证失败(如 API key 无效)。
|
||||
#[error("认证失败: {0}")]
|
||||
Authentication(String),
|
||||
|
||||
/// 请求被限流,可选地附带重试等待时间。
|
||||
#[error("限流(retry_after={retry_after:?})")]
|
||||
RateLimit { retry_after: Option<Duration> },
|
||||
|
||||
/// HTTP 请求失败,包含状态码和响应体。
|
||||
#[error("请求失败(status={status}): {body}")]
|
||||
Request { status: u16, body: String },
|
||||
|
||||
/// 请求超时。
|
||||
#[error("请求超时(duration={duration:?})")]
|
||||
Timeout { duration: Duration },
|
||||
|
||||
/// 流式响应处理错误(预留)。
|
||||
#[error("流式响应错误: {0}")]
|
||||
Stream(String),
|
||||
|
||||
/// 上下文长度超限。
|
||||
#[error("上下文超限(actual={actual}, limit={limit})")]
|
||||
ContextLength { actual: u32, limit: u32 },
|
||||
|
||||
/// 其他未分类的 LLM 调用失败。
|
||||
#[error("LLM 调用失败: {0}")]
|
||||
Other(String),
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
pub mod openai;
|
||||
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::types::{ChatRequest, ChatResponse};
|
||||
use async_trait::async_trait;
|
||||
|
||||
/// LLM Provider 抽象接口。
|
||||
///
|
||||
/// 所有具体的 LLM 后端实现(OpenAI、Anthropic、Azure 等)
|
||||
/// 均需实现此 trait,以实现可插拔替换。
|
||||
#[async_trait]
|
||||
pub trait LlmProvider: Send + Sync {
|
||||
/// 发送聊天请求并返回完整响应。
|
||||
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError>;
|
||||
}
|
||||
@@ -0,0 +1,354 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use crate::llm::cycle::usage::Usage;
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::types::{
|
||||
ChatRequest, ChatResponse, ContentBlock, Message, Role, StopReason, ToolDefinition,
|
||||
};
|
||||
|
||||
use super::LlmProvider;
|
||||
|
||||
/// OpenAI 兼容 API 的 Provider 实现。
|
||||
///
|
||||
/// 支持任意实现了 `POST /v1/chat/completions` 标准的 API
|
||||
/// (包括 OpenAI、Azure OpenAI、DashScope、vLLM 等)。
|
||||
pub struct OpenaiProvider {
|
||||
http_client: Client,
|
||||
base_url: String,
|
||||
api_key: String,
|
||||
#[allow(dead_code)]
|
||||
model: String,
|
||||
}
|
||||
|
||||
impl OpenaiProvider {
|
||||
/// 创建新的 OpenAI Provider。
|
||||
///
|
||||
/// 默认使用 120 秒超时的 HTTP 客户端。
|
||||
pub fn new(base_url: String, api_key: String, model: String) -> Self {
|
||||
let http_client = Client::builder()
|
||||
.timeout(Duration::from_secs(120))
|
||||
.build()
|
||||
.expect("创建 HTTP 客户端失败");
|
||||
|
||||
Self {
|
||||
http_client,
|
||||
base_url,
|
||||
api_key,
|
||||
model,
|
||||
}
|
||||
}
|
||||
|
||||
/// 替换为自定义的 HTTP 客户端(用于测试或自定义配置)。
|
||||
pub fn with_client(mut self, client: Client) -> Self {
|
||||
self.http_client = client;
|
||||
self
|
||||
}
|
||||
|
||||
/// 将 ChatRequest 构建为 OpenAI API 请求体 JSON。
|
||||
fn build_request_body(&self, request: &ChatRequest) -> Value {
|
||||
let mut body = json!({
|
||||
"model": request.model,
|
||||
"messages": Self::serialize_messages(request),
|
||||
});
|
||||
|
||||
if let Some(max_tokens) = request.max_tokens {
|
||||
body["max_tokens"] = json!(max_tokens);
|
||||
}
|
||||
if let Some(temperature) = request.temperature {
|
||||
body["temperature"] = json!(temperature);
|
||||
}
|
||||
if !request.tools.is_empty() {
|
||||
body["tools"] = json!(
|
||||
request
|
||||
.tools
|
||||
.iter()
|
||||
.map(Self::serialize_tool)
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
|
||||
// 合并 extra_body 中的扩展参数到请求体顶层
|
||||
if let Some(ref extra) = request.extra_body
|
||||
&& let Some(obj) = extra.as_object()
|
||||
{
|
||||
for (k, v) in obj {
|
||||
body[k] = v.clone();
|
||||
}
|
||||
}
|
||||
|
||||
body
|
||||
}
|
||||
|
||||
/// 将请求中的消息列表序列化为 API 消息数组。
|
||||
fn serialize_messages(request: &ChatRequest) -> Vec<Value> {
|
||||
let mut messages: Vec<Value> = Vec::new();
|
||||
|
||||
// system_prompt 作为独立的 system 角色消息放在最前面
|
||||
if let Some(ref system_prompt) = request.system_prompt {
|
||||
messages.push(json!({
|
||||
"role": "system",
|
||||
"content": system_prompt
|
||||
}));
|
||||
}
|
||||
|
||||
for msg in &request.messages {
|
||||
messages.push(Self::serialize_message(msg));
|
||||
}
|
||||
|
||||
messages
|
||||
}
|
||||
|
||||
/// 将单条消息序列化为 API 格式。
|
||||
///
|
||||
/// 处理逻辑:
|
||||
/// - 多个 content block 或包含图片 → 使用数组格式
|
||||
/// - ToolResult → 使用 tool 角色格式
|
||||
/// - 其他 → 使用纯文本格式
|
||||
fn serialize_message(msg: &Message) -> Value {
|
||||
let role_str = match msg.role {
|
||||
Role::User => "user",
|
||||
Role::Assistant => "assistant",
|
||||
Role::System => "system",
|
||||
Role::Tool => "tool",
|
||||
};
|
||||
|
||||
let has_mixed = msg.content.len() > 1
|
||||
|| msg
|
||||
.content
|
||||
.iter()
|
||||
.any(|b| matches!(b, ContentBlock::ImageUrl { .. }));
|
||||
|
||||
if has_mixed {
|
||||
let content: Vec<Value> = msg
|
||||
.content
|
||||
.iter()
|
||||
.map(Self::serialize_content_block)
|
||||
.collect();
|
||||
json!({ "role": role_str, "content": content })
|
||||
} else if let Some(ContentBlock::ToolResult {
|
||||
tool_use_id,
|
||||
content,
|
||||
}) = msg.content.first()
|
||||
{
|
||||
json!({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_use_id,
|
||||
"content": content
|
||||
})
|
||||
} else {
|
||||
let text = msg
|
||||
.content
|
||||
.first()
|
||||
.map(|b| match b {
|
||||
ContentBlock::Text { text } => text.clone(),
|
||||
_ => String::new(),
|
||||
})
|
||||
.unwrap_or_default();
|
||||
json!({ "role": role_str, "content": text })
|
||||
}
|
||||
}
|
||||
|
||||
/// 将 ContentBlock 序列化为 API content parts 数组元素。
|
||||
fn serialize_content_block(block: &ContentBlock) -> Value {
|
||||
match block {
|
||||
ContentBlock::Text { text } => {
|
||||
json!({ "type": "text", "text": text })
|
||||
}
|
||||
ContentBlock::ImageUrl { url } => {
|
||||
json!({ "type": "image_url", "image_url": { "url": url } })
|
||||
}
|
||||
ContentBlock::ToolUse { id, name, input } => {
|
||||
json!({ "type": "tool_use", "id": id, "name": name, "input": input })
|
||||
}
|
||||
ContentBlock::ToolResult { .. } => {
|
||||
json!({ "type": "tool_result", "content": "" })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 将 ToolDefinition 序列化为 OpenAI tools 数组元素。
|
||||
fn serialize_tool(tool: &ToolDefinition) -> Value {
|
||||
json!({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.input_schema
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// 将 OpenAI API 响应 JSON 解析为 ChatResponse。
|
||||
fn parse_response(response: Value) -> Result<ChatResponse, LlmError> {
|
||||
let choice = response["choices"][0]
|
||||
.as_object()
|
||||
.ok_or_else(|| LlmError::Other("响应中缺少 choices[0]".into()))?;
|
||||
|
||||
let msg = choice["message"]
|
||||
.as_object()
|
||||
.ok_or_else(|| LlmError::Other("响应中缺少 message".into()))?;
|
||||
|
||||
let role = match msg["role"].as_str() {
|
||||
Some("assistant") => Role::Assistant,
|
||||
Some(_) => Role::Assistant,
|
||||
None => Role::Assistant,
|
||||
};
|
||||
|
||||
let mut content_blocks: Vec<ContentBlock> = Vec::new();
|
||||
|
||||
// 从 content 字段提取文本和 tool_use
|
||||
if let Some(content_val) = msg.get("content") {
|
||||
match content_val {
|
||||
Value::String(s) if !s.is_empty() => {
|
||||
content_blocks.push(ContentBlock::Text { text: s.clone() });
|
||||
}
|
||||
Value::Array(arr) => {
|
||||
for item in arr {
|
||||
if let Some(item_type) = item["type"].as_str() {
|
||||
match item_type {
|
||||
"text" => {
|
||||
if let Some(text) = item["text"].as_str() {
|
||||
content_blocks
|
||||
.push(ContentBlock::Text { text: text.into() });
|
||||
}
|
||||
}
|
||||
"tool_use" | "function" => {
|
||||
let id = item["id"].as_str().unwrap_or("").to_string();
|
||||
let name = item["name"].as_str().unwrap_or("").to_string();
|
||||
let input = item.get("input").cloned().unwrap_or(Value::Null);
|
||||
content_blocks
|
||||
.push(ContentBlock::ToolUse { id, name, input });
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// 从 tool_calls 字段提取工具调用(OpenAI 特有格式)
|
||||
if let Some(tool_calls) = msg.get("tool_calls").and_then(|v| v.as_array()) {
|
||||
for tc in tool_calls {
|
||||
let id = tc["id"].as_str().unwrap_or("").to_string();
|
||||
let name = tc["function"]["name"].as_str().unwrap_or("").to_string();
|
||||
let input = tc["function"]["arguments"]
|
||||
.as_str()
|
||||
.and_then(|s| serde_json::from_str(s).ok())
|
||||
.unwrap_or(Value::Null);
|
||||
content_blocks.push(ContentBlock::ToolUse { id, name, input });
|
||||
}
|
||||
}
|
||||
|
||||
if content_blocks.is_empty() {
|
||||
content_blocks.push(ContentBlock::Text {
|
||||
text: String::new(),
|
||||
});
|
||||
}
|
||||
|
||||
// 解析停止原因
|
||||
let stop_reason = choice["finish_reason"].as_str().map(|s| match s {
|
||||
"stop" => StopReason::Stop,
|
||||
"tool_calls" => StopReason::ToolUse,
|
||||
"max_tokens" => StopReason::MaxTokens,
|
||||
"length" => StopReason::Length,
|
||||
"content_filter" => StopReason::ContentFilter,
|
||||
other => StopReason::Other(other.into()),
|
||||
});
|
||||
|
||||
// 解析 token 用量
|
||||
let usage = response["usage"]
|
||||
.as_object()
|
||||
.map(|u| Usage {
|
||||
input_tokens: u.get("prompt_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
|
||||
output_tokens: u
|
||||
.get("completion_tokens")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0) as u32,
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
Ok(ChatResponse {
|
||||
message: Message {
|
||||
role,
|
||||
content: content_blocks,
|
||||
},
|
||||
usage,
|
||||
stop_reason,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for OpenaiProvider {
|
||||
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError> {
|
||||
let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
|
||||
let body = self.build_request_body(&request);
|
||||
|
||||
let response = self
|
||||
.http_client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
if e.is_timeout() {
|
||||
LlmError::Timeout {
|
||||
duration: Duration::from_secs(120),
|
||||
}
|
||||
} else if e.is_connect() {
|
||||
LlmError::Other(format!("连接失败: {}", e))
|
||||
} else {
|
||||
LlmError::Other(format!("请求失败: {}", e))
|
||||
}
|
||||
})?;
|
||||
|
||||
let status = response.status();
|
||||
let status_code: u16 = status.as_u16();
|
||||
|
||||
// 处理非 2xx 响应,将 HTTP 状态码映射为对应的 LlmError 变体
|
||||
if !status.is_success() {
|
||||
// 在消费 response body 之前先读取 retry-after 头部
|
||||
let retry_after = response
|
||||
.headers()
|
||||
.get("retry-after")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.parse::<u64>().ok())
|
||||
.map(Duration::from_secs);
|
||||
let body_text = response.text().await.unwrap_or_default();
|
||||
|
||||
return match status_code {
|
||||
401 => Err(LlmError::Authentication(body_text)),
|
||||
429 => Err(LlmError::RateLimit { retry_after }),
|
||||
_ if status_code >= 500 => Err(LlmError::Request {
|
||||
status: status_code,
|
||||
body: body_text,
|
||||
}),
|
||||
_ if status_code == 400 && body_text.contains("context_length_exceeded") => {
|
||||
Err(LlmError::ContextLength {
|
||||
actual: 0,
|
||||
limit: 0,
|
||||
})
|
||||
}
|
||||
_ => Err(LlmError::Request {
|
||||
status: status_code,
|
||||
body: body_text,
|
||||
}),
|
||||
};
|
||||
}
|
||||
|
||||
let json_body: Value = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Other(format!("响应解析失败: {}", e)))?;
|
||||
|
||||
Self::parse_response(json_body)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,100 @@
|
||||
use crate::llm::cycle::usage::Usage;
|
||||
use serde_json::Value;
|
||||
|
||||
/// Role of a conversation message.
///
/// A plain fieldless enum, so it is `Copy` and comparable with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
|
||||
|
||||
/// 消息内容块,支持多模态及工具调用。
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ContentBlock {
|
||||
/// 纯文本内容。
|
||||
Text {
|
||||
text: String,
|
||||
},
|
||||
/// 图片 URL(多模态输入预留)。
|
||||
ImageUrl {
|
||||
url: String,
|
||||
},
|
||||
/// 模型发起的工具调用(预留,暂不实现自动执行)。
|
||||
ToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input: Value,
|
||||
},
|
||||
/// 工具执行结果的回传(预留,暂不实现自动执行)。
|
||||
ToolResult {
|
||||
tool_use_id: String,
|
||||
content: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// 一条对话消息,由角色和内容块列表组成。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Message {
|
||||
pub role: Role,
|
||||
pub content: Vec<ContentBlock>,
|
||||
}
|
||||
|
||||
/// 可供模型调用的工具定义。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolDefinition {
|
||||
/// 工具名称。
|
||||
pub name: String,
|
||||
/// 工具描述,用于模型理解何时调用。
|
||||
pub description: String,
|
||||
/// JSON Schema 格式的输入参数定义。
|
||||
pub input_schema: Value,
|
||||
}
|
||||
|
||||
/// 对 /v1/chat/completions 的完整请求参数。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ChatRequest {
|
||||
/// 模型标识(如 "gpt-4o")。
|
||||
pub model: String,
|
||||
/// 对话历史 + 新消息。
|
||||
pub messages: Vec<Message>,
|
||||
/// 独立的系统提示词,将在序列化时转为 system 角色消息。
|
||||
pub system_prompt: Option<String>,
|
||||
/// 可用的工具定义列表。
|
||||
pub tools: Vec<ToolDefinition>,
|
||||
/// 最大输出 token 数。
|
||||
pub max_tokens: Option<u32>,
|
||||
/// 采样温度。
|
||||
pub temperature: Option<f32>,
|
||||
/// 扩展参数(如 enable_thinking),会合并到请求体顶层。
|
||||
pub extra_body: Option<Value>,
|
||||
}
|
||||
|
||||
/// 模型返回的完整响应。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ChatResponse {
|
||||
/// 助手的回复消息。
|
||||
pub message: Message,
|
||||
/// 本次请求的 token 用量。
|
||||
pub usage: Usage,
|
||||
/// 停止原因。
|
||||
pub stop_reason: Option<StopReason>,
|
||||
}
|
||||
|
||||
/// Reason the model stopped generating.
///
/// Comparable with `==`, so callers can test for a specific reason directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StopReason {
    /// Normal completion.
    Stop,
    /// The model asked to call a tool (reserved).
    ToolUse,
    /// The `max_tokens` limit was reached.
    MaxTokens,
    /// The content was blocked by a safety filter.
    ContentFilter,
    /// Length limit (kept for APIs that report this `finish_reason`).
    Length,
    /// Any other, unclassified reason.
    Other(String),
}
|
||||
Reference in New Issue
Block a user