diff --git a/Cargo.toml b/Cargo.toml index a7d2155..e50301f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,3 +4,10 @@ version = "0.1.0" edition = "2024" [dependencies] +tokio = { version = "1", features = ["full"] } +reqwest = { version = "0.12", features = ["json"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +thiserror = "2" +async-trait = "0.1" +tracing = "0.1" diff --git a/src/lib.rs b/src/lib.rs index b93cf3f..cb4039d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,14 +1,6 @@ -pub fn add(left: u64, right: u64) -> u64 { - left + right -} +//! agcore —— 智能体(Agent)核心工具箱。 +//! +//! 当前提供 LLM 调用周期控制作为核心底层能力,后续将扩展至 +//! 提示词工程、记忆系统、工具调用、Agent 运行时等领域。 -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); - } -} +pub mod llm; diff --git a/src/llm.rs b/src/llm.rs new file mode 100644 index 0000000..9e88e23 --- /dev/null +++ b/src/llm.rs @@ -0,0 +1,8 @@ +//! LLM 调用周期 —— 大模型基础调用周期控制。 +//! +//! 包含核心数据类型、Provider 抽象、OpenAI 兼容实现以及生命周期引擎。 + +pub mod cycle; +pub mod error; +pub mod provider; +pub mod types; diff --git a/src/llm/cycle.rs b/src/llm/cycle.rs new file mode 100644 index 0000000..f710818 --- /dev/null +++ b/src/llm/cycle.rs @@ -0,0 +1,149 @@ +mod retry; +pub mod usage; + +pub use retry::RetryConfig; +pub use usage::{CostTracker, Usage}; + +use crate::llm::cycle::retry::should_retry; +use crate::llm::error::LlmError; +use crate::llm::provider::LlmProvider; +use crate::llm::types::{ChatRequest, ChatResponse, ContentBlock, Message, Role, ToolDefinition}; + +/// LLM 生命周期引擎的配置。 +pub struct CycleConfig { + /// 使用的模型名称。 + pub model: String, + /// 最大输出 token 数。 + pub max_tokens: Option, + /// 采样温度。 + pub temperature: Option, + /// 最大对话轮数(预留,暂未使用)。 + pub max_turns: Option, + /// 重试策略配置。 + pub retry: RetryConfig, +} + +impl Default for CycleConfig { + fn default() -> Self { + Self { + model: String::from("gpt-4o"), + max_tokens: None, + temperature: None, + max_turns: None, + retry: RetryConfig::default(), + } + } +} + +/// LLM 调用生命周期引擎。 +/// +/// 管理一次多轮交互的完整生命周期,包括: +/// - 消息历史维护 +/// - Token 用量追踪 +/// - 自动重试 +pub struct LlmCycle { + provider: Box, + config: CycleConfig, + usage: CostTracker, + messages: Vec, + system_prompt: Option, +} + +impl LlmCycle { + /// 创建新的生命周期引擎。 + pub fn new(provider: Box, config: CycleConfig) -> Self { + Self { + provider, + config, + usage: CostTracker::default(), + messages: Vec::new(), + system_prompt: None, + } + } + + /// 设置系统提示词(Builder 模式)。 + pub fn with_system_prompt(mut self, prompt: String) -> Self { + self.system_prompt = Some(prompt); + self + } + + /// 获取 Token 用量追踪器引用。 + pub fn usage(&self) -> &CostTracker { + &self.usage + } + + /// 获取当前消息历史。 + pub fn messages(&self) -> &[Message] { + &self.messages + } + + /// 清空消息历史。 + pub fn clear_messages(&mut self) { + self.messages.clear(); + } + + /// 重置 Token 用量统计。 + pub fn reset_usage(&mut self) { + self.usage.reset(); + } + + /// 提交一条用户消息并获取模型响应。 + /// + /// 流程: + /// 1. 将用户消息追加到消息历史 + /// 2. 构建 ChatRequest + /// 3. 使用重试循环调用 provider.chat() + /// 4. 将助手回复追加到消息历史 + /// 5. 累计 token 用量 + /// 6. 返回 ChatResponse + pub async fn submit( + &mut self, + prompt: String, + tools: Vec, + ) -> Result { + self.messages.push(Message { + role: Role::User, + content: vec![ContentBlock::Text { text: prompt }], + }); + + let mut attempts = 0; + + loop { + let request = self.build_request(&tools); + + match self.provider.chat(request).await { + Ok(response) => { + self.messages.push(Message { + role: Role::Assistant, + content: response.message.content.clone(), + }); + + self.usage.add(&response.usage); + + return Ok(response); + } + Err(e) if should_retry(&e) && attempts < self.config.retry.max_retries => { + attempts += 1; + let delay = self.config.retry.compute_delay(attempts); + tokio::time::sleep(delay).await; + } + Err(e) => { + return Err(e); + } + } + } + } + + /// 根据当前状态构建 ChatRequest。 + fn build_request(&self, tools: &[ToolDefinition]) -> ChatRequest { + ChatRequest { + model: self.config.model.clone(), + messages: self.messages.clone(), + system_prompt: self.system_prompt.clone(), + tools: tools.to_vec(), + max_tokens: self.config.max_tokens, + temperature: self.config.temperature, + extra_body: None, + } + } +} diff --git a/src/llm/cycle/retry.rs b/src/llm/cycle/retry.rs new file mode 100644 index 0000000..83ca9fa --- /dev/null +++ b/src/llm/cycle/retry.rs @@ -0,0 +1,71 @@ +use std::time::Duration; + +use crate::llm::error::LlmError; + +/// 重试策略配置。 +/// +/// 使用指数退避 + jitter 算法计算每次重试的等待时间。 +#[derive(Debug, Clone)] +pub struct RetryConfig { + /// 最大重试次数(默认 3)。 + pub max_retries: u32, + /// 初始延迟(默认 1 秒)。 + pub base_delay: Duration, + /// 最大延迟上限(默认 30 秒)。 + pub max_delay: Duration, + /// Jitter 比例因子(默认 0.25)。 + pub jitter_factor: f64, +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + max_retries: 3, + base_delay: Duration::from_secs(1), + max_delay: Duration::from_secs(30), + jitter_factor: 0.25, + } + } +} + +impl RetryConfig { + /// 根据当前重试次数计算等待时间。 + /// + /// 算法: `delay = min(base * 2^(attempt-1), max_delay) + random(0, delay * jitter_factor)` + pub fn compute_delay(&self, attempt: u32) -> Duration { + let base = self.base_delay.as_secs_f64(); + let exponential = base * (2u64.pow(attempt.saturating_sub(1))) as f64; + let capped = exponential.min(self.max_delay.as_secs_f64()); + let jitter = rand_jitter(capped * self.jitter_factor); + + Duration::from_secs_f64(capped + jitter) + } +} + +/// 判断错误是否可重试。 +/// +/// 可重试条件: +/// - RateLimit(429) +/// - Timeout +/// - Request 且状态码 >= 500 或 == 429 +pub fn should_retry(err: &LlmError) -> bool { + match err { + LlmError::RateLimit { .. } => true, + LlmError::Timeout { .. } => true, + LlmError::Request { status, .. } => *status >= 500 || *status == 429, + _ => false, + } +} + +/// 基于纳秒时间戳的简单伪随机数,范围 [0, max)。 +fn rand_jitter(max: f64) -> f64 { + if max <= 0.0 { + return 0.0; + } + let t = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_nanos(); + let r = (t % 1000) as f64 / 1000.0; + r * max +} diff --git a/src/llm/cycle/usage.rs b/src/llm/cycle/usage.rs new file mode 100644 index 0000000..32da25e --- /dev/null +++ b/src/llm/cycle/usage.rs @@ -0,0 +1,42 @@ +/// 单次请求的 Token 用量。 +#[derive(Debug, Clone, Default)] +pub struct Usage { + /// 输入(提示词)消耗的 token 数。 + pub input_tokens: u32, + /// 输出(生成内容)消耗的 token 数。 + pub output_tokens: u32, +} + +/// Token 用量累计追踪器。 +/// +/// 在多轮对话中累计所有请求的 token 消耗。 +#[derive(Debug, Default)] +pub struct CostTracker { + accumulated: Usage, +} + +impl CostTracker { + /// 累加一次请求的用量。 + /// + /// 使用 saturating_add 防止溢出。 + pub fn add(&mut self, usage: &Usage) { + self.accumulated.input_tokens = self + .accumulated + .input_tokens + .saturating_add(usage.input_tokens); + self.accumulated.output_tokens = self + .accumulated + .output_tokens + .saturating_add(usage.output_tokens); + } + + /// 获取累计用量。 + pub fn total(&self) -> &Usage { + &self.accumulated + } + + /// 重置累计用量。 + pub fn reset(&mut self) { + self.accumulated = Usage::default(); + } +} diff --git a/src/llm/error.rs b/src/llm/error.rs new file mode 100644 index 0000000..fe17b30 --- /dev/null +++ b/src/llm/error.rs @@ -0,0 +1,37 @@ +use std::time::Duration; + +/// LLM 调用过程中可能发生的所有错误。 +/// +/// 错误按可重试性分为两类: +/// - **可重试**:`RateLimit`、`Timeout`、状态码 >= 500 +/// - **不可重试**:`Authentication`、`ContextLength`、状态码 4xx(除 429) +#[derive(thiserror::Error, Debug)] +pub enum LlmError { + /// API 认证失败(如 API key 无效)。 + #[error("认证失败: {0}")] + Authentication(String), + + /// 请求被限流,可选地附带重试等待时间。 + #[error("限流(retry_after={retry_after:?})")] + RateLimit { retry_after: Option }, + + /// HTTP 请求失败,包含状态码和响应体。 + #[error("请求失败(status={status}): {body}")] + Request { status: u16, body: String }, + + /// 请求超时。 + #[error("请求超时(duration={duration:?})")] + Timeout { duration: Duration }, + + /// 流式响应处理错误(预留)。 + #[error("流式响应错误: {0}")] + Stream(String), + + /// 上下文长度超限。 + #[error("上下文超限(actual={actual}, limit={limit})")] + ContextLength { actual: u32, limit: u32 }, + + /// 其他未分类的 LLM 调用失败。 + #[error("LLM 调用失败: {0}")] + Other(String), +} diff --git a/src/llm/provider.rs b/src/llm/provider.rs new file mode 100644 index 0000000..d7d0917 --- /dev/null +++ b/src/llm/provider.rs @@ -0,0 +1,15 @@ +pub mod openai; + +use crate::llm::error::LlmError; +use crate::llm::types::{ChatRequest, ChatResponse}; +use async_trait::async_trait; + +/// LLM Provider 抽象接口。 +/// +/// 所有具体的 LLM 后端实现(OpenAI、Anthropic、Azure 等) +/// 均需实现此 trait,以实现可插拔替换。 +#[async_trait] +pub trait LlmProvider: Send + Sync { + /// 发送聊天请求并返回完整响应。 + async fn chat(&self, request: ChatRequest) -> Result; +} diff --git a/src/llm/provider/openai.rs b/src/llm/provider/openai.rs new file mode 100644 index 0000000..e0b72b1 --- /dev/null +++ b/src/llm/provider/openai.rs @@ -0,0 +1,354 @@ +use std::time::Duration; + +use async_trait::async_trait; +use reqwest::Client; +use serde_json::{json, Value}; + +use crate::llm::cycle::usage::Usage; +use crate::llm::error::LlmError; +use crate::llm::types::{ + ChatRequest, ChatResponse, ContentBlock, Message, Role, StopReason, ToolDefinition, +}; + +use super::LlmProvider; + +/// OpenAI 兼容 API 的 Provider 实现。 +/// +/// 支持任意实现了 `POST /v1/chat/completions` 标准的 API +/// (包括 OpenAI、Azure OpenAI、DashScope、vLLM 等)。 +pub struct OpenaiProvider { + http_client: Client, + base_url: String, + api_key: String, + #[allow(dead_code)] + model: String, +} + +impl OpenaiProvider { + /// 创建新的 OpenAI Provider。 + /// + /// 默认使用 120 秒超时的 HTTP 客户端。 + pub fn new(base_url: String, api_key: String, model: String) -> Self { + let http_client = Client::builder() + .timeout(Duration::from_secs(120)) + .build() + .expect("创建 HTTP 客户端失败"); + + Self { + http_client, + base_url, + api_key, + model, + } + } + + /// 替换为自定义的 HTTP 客户端(用于测试或自定义配置)。 + pub fn with_client(mut self, client: Client) -> Self { + self.http_client = client; + self + } + + /// 将 ChatRequest 构建为 OpenAI API 请求体 JSON。 + fn build_request_body(&self, request: &ChatRequest) -> Value { + let mut body = json!({ + "model": request.model, + "messages": Self::serialize_messages(request), + }); + + if let Some(max_tokens) = request.max_tokens { + body["max_tokens"] = json!(max_tokens); + } + if let Some(temperature) = request.temperature { + body["temperature"] = json!(temperature); + } + if !request.tools.is_empty() { + body["tools"] = json!( + request + .tools + .iter() + .map(Self::serialize_tool) + .collect::>() + ); + } + + // 合并 extra_body 中的扩展参数到请求体顶层 + if let Some(ref extra) = request.extra_body + && let Some(obj) = extra.as_object() + { + for (k, v) in obj { + body[k] = v.clone(); + } + } + + body + } + + /// 将请求中的消息列表序列化为 API 消息数组。 + fn serialize_messages(request: &ChatRequest) -> Vec { + let mut messages: Vec = Vec::new(); + + // system_prompt 作为独立的 system 角色消息放在最前面 + if let Some(ref system_prompt) = request.system_prompt { + messages.push(json!({ + "role": "system", + "content": system_prompt + })); + } + + for msg in &request.messages { + messages.push(Self::serialize_message(msg)); + } + + messages + } + + /// 将单条消息序列化为 API 格式。 + /// + /// 处理逻辑: + /// - 多个 content block 或包含图片 → 使用数组格式 + /// - ToolResult → 使用 tool 角色格式 + /// - 其他 → 使用纯文本格式 + fn serialize_message(msg: &Message) -> Value { + let role_str = match msg.role { + Role::User => "user", + Role::Assistant => "assistant", + Role::System => "system", + Role::Tool => "tool", + }; + + let has_mixed = msg.content.len() > 1 + || msg + .content + .iter() + .any(|b| matches!(b, ContentBlock::ImageUrl { .. })); + + if has_mixed { + let content: Vec = msg + .content + .iter() + .map(Self::serialize_content_block) + .collect(); + json!({ "role": role_str, "content": content }) + } else if let Some(ContentBlock::ToolResult { + tool_use_id, + content, + }) = msg.content.first() + { + json!({ + "role": "tool", + "tool_call_id": tool_use_id, + "content": content + }) + } else { + let text = msg + .content + .first() + .map(|b| match b { + ContentBlock::Text { text } => text.clone(), + _ => String::new(), + }) + .unwrap_or_default(); + json!({ "role": role_str, "content": text }) + } + } + + /// 将 ContentBlock 序列化为 API content parts 数组元素。 + fn serialize_content_block(block: &ContentBlock) -> Value { + match block { + ContentBlock::Text { text } => { + json!({ "type": "text", "text": text }) + } + ContentBlock::ImageUrl { url } => { + json!({ "type": "image_url", "image_url": { "url": url } }) + } + ContentBlock::ToolUse { id, name, input } => { + json!({ "type": "tool_use", "id": id, "name": name, "input": input }) + } + ContentBlock::ToolResult { .. } => { + json!({ "type": "tool_result", "content": "" }) + } + } + } + + /// 将 ToolDefinition 序列化为 OpenAI tools 数组元素。 + fn serialize_tool(tool: &ToolDefinition) -> Value { + json!({ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.input_schema + } + }) + } + + /// 将 OpenAI API 响应 JSON 解析为 ChatResponse。 + fn parse_response(response: Value) -> Result { + let choice = response["choices"][0] + .as_object() + .ok_or_else(|| LlmError::Other("响应中缺少 choices[0]".into()))?; + + let msg = choice["message"] + .as_object() + .ok_or_else(|| LlmError::Other("响应中缺少 message".into()))?; + + let role = match msg["role"].as_str() { + Some("assistant") => Role::Assistant, + Some(_) => Role::Assistant, + None => Role::Assistant, + }; + + let mut content_blocks: Vec = Vec::new(); + + // 从 content 字段提取文本和 tool_use + if let Some(content_val) = msg.get("content") { + match content_val { + Value::String(s) if !s.is_empty() => { + content_blocks.push(ContentBlock::Text { text: s.clone() }); + } + Value::Array(arr) => { + for item in arr { + if let Some(item_type) = item["type"].as_str() { + match item_type { + "text" => { + if let Some(text) = item["text"].as_str() { + content_blocks + .push(ContentBlock::Text { text: text.into() }); + } + } + "tool_use" | "function" => { + let id = item["id"].as_str().unwrap_or("").to_string(); + let name = item["name"].as_str().unwrap_or("").to_string(); + let input = item.get("input").cloned().unwrap_or(Value::Null); + content_blocks + .push(ContentBlock::ToolUse { id, name, input }); + } + _ => {} + } + } + } + } + _ => {} + } + } + + // 从 tool_calls 字段提取工具调用(OpenAI 特有格式) + if let Some(tool_calls) = msg.get("tool_calls").and_then(|v| v.as_array()) { + for tc in tool_calls { + let id = tc["id"].as_str().unwrap_or("").to_string(); + let name = tc["function"]["name"].as_str().unwrap_or("").to_string(); + let input = tc["function"]["arguments"] + .as_str() + .and_then(|s| serde_json::from_str(s).ok()) + .unwrap_or(Value::Null); + content_blocks.push(ContentBlock::ToolUse { id, name, input }); + } + } + + if content_blocks.is_empty() { + content_blocks.push(ContentBlock::Text { + text: String::new(), + }); + } + + // 解析停止原因 + let stop_reason = choice["finish_reason"].as_str().map(|s| match s { + "stop" => StopReason::Stop, + "tool_calls" => StopReason::ToolUse, + "max_tokens" => StopReason::MaxTokens, + "length" => StopReason::Length, + "content_filter" => StopReason::ContentFilter, + other => StopReason::Other(other.into()), + }); + + // 解析 token 用量 + let usage = response["usage"] + .as_object() + .map(|u| Usage { + input_tokens: u.get("prompt_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32, + output_tokens: u + .get("completion_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as u32, + }) + .unwrap_or_default(); + + Ok(ChatResponse { + message: Message { + role, + content: content_blocks, + }, + usage, + stop_reason, + }) + } +} + +#[async_trait] +impl LlmProvider for OpenaiProvider { + async fn chat(&self, request: ChatRequest) -> Result { + let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/')); + let body = self.build_request_body(&request); + + let response = self + .http_client + .post(&url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .json(&body) + .send() + .await + .map_err(|e| { + if e.is_timeout() { + LlmError::Timeout { + duration: Duration::from_secs(120), + } + } else if e.is_connect() { + LlmError::Other(format!("连接失败: {}", e)) + } else { + LlmError::Other(format!("请求失败: {}", e)) + } + })?; + + let status = response.status(); + let status_code: u16 = status.as_u16(); + + // 处理非 2xx 响应,将 HTTP 状态码映射为对应的 LlmError 变体 + if !status.is_success() { + // 在消费 response body 之前先读取 retry-after 头部 + let retry_after = response + .headers() + .get("retry-after") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.parse::().ok()) + .map(Duration::from_secs); + let body_text = response.text().await.unwrap_or_default(); + + return match status_code { + 401 => Err(LlmError::Authentication(body_text)), + 429 => Err(LlmError::RateLimit { retry_after }), + _ if status_code >= 500 => Err(LlmError::Request { + status: status_code, + body: body_text, + }), + _ if status_code == 400 && body_text.contains("context_length_exceeded") => { + Err(LlmError::ContextLength { + actual: 0, + limit: 0, + }) + } + _ => Err(LlmError::Request { + status: status_code, + body: body_text, + }), + }; + } + + let json_body: Value = response + .json() + .await + .map_err(|e| LlmError::Other(format!("响应解析失败: {}", e)))?; + + Self::parse_response(json_body) + } +} diff --git a/src/llm/types.rs b/src/llm/types.rs new file mode 100644 index 0000000..950ac44 --- /dev/null +++ b/src/llm/types.rs @@ -0,0 +1,100 @@ +use crate::llm::cycle::usage::Usage; +use serde_json::Value; + +/// 对话消息的角色。 +#[derive(Debug, Clone)] +pub enum Role { + User, + Assistant, + System, + Tool, +} + +/// 消息内容块,支持多模态及工具调用。 +#[derive(Debug, Clone)] +pub enum ContentBlock { + /// 纯文本内容。 + Text { + text: String, + }, + /// 图片 URL(多模态输入预留)。 + ImageUrl { + url: String, + }, + /// 模型发起的工具调用(预留,暂不实现自动执行)。 + ToolUse { + id: String, + name: String, + input: Value, + }, + /// 工具执行结果的回传(预留,暂不实现自动执行)。 + ToolResult { + tool_use_id: String, + content: String, + }, +} + +/// 一条对话消息,由角色和内容块列表组成。 +#[derive(Debug, Clone)] +pub struct Message { + pub role: Role, + pub content: Vec, +} + +/// 可供模型调用的工具定义。 +#[derive(Debug, Clone)] +pub struct ToolDefinition { + /// 工具名称。 + pub name: String, + /// 工具描述,用于模型理解何时调用。 + pub description: String, + /// JSON Schema 格式的输入参数定义。 + pub input_schema: Value, +} + +/// 对 /v1/chat/completions 的完整请求参数。 +#[derive(Debug, Clone)] +pub struct ChatRequest { + /// 模型标识(如 "gpt-4o")。 + pub model: String, + /// 对话历史 + 新消息。 + pub messages: Vec, + /// 独立的系统提示词,将在序列化时转为 system 角色消息。 + pub system_prompt: Option, + /// 可用的工具定义列表。 + pub tools: Vec, + /// 最大输出 token 数。 + pub max_tokens: Option, + /// 采样温度。 + pub temperature: Option, + /// 扩展参数(如 enable_thinking),会合并到请求体顶层。 + pub extra_body: Option, +} + +/// 模型返回的完整响应。 +#[derive(Debug, Clone)] +pub struct ChatResponse { + /// 助手的回复消息。 + pub message: Message, + /// 本次请求的 token 用量。 + pub usage: Usage, + /// 停止原因。 + pub stop_reason: Option, +} + +/// 模型停止生成的原因。 +#[derive(Debug, Clone)] +pub enum StopReason { + /// 正常结束。 + Stop, + /// 模型请求调用工具(预留)。 + ToolUse, + /// 达到 max_tokens 上限。 + MaxTokens, + /// 内容被安全过滤。 + ContentFilter, + /// 长度限制(兼容某些 API 的 finish_reason)。 + Length, + /// 其他未分类的原因。 + Other(String), +}