From 32f3edaf19e54d0bbf9e13df220a000120a4edee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E6=B6=9B?= Date: Tue, 2 Jun 2026 08:51:42 +0800 Subject: [PATCH] =?UTF-8?q?feat(llm):=20=E5=AE=9E=E7=8E=B0=20Phase=200=20?= =?UTF-8?q?=E5=89=A9=E4=BD=99=E5=9B=9B=E4=B8=AA=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 实现 ProviderRegistry、HookExecutor、StreamEvents 和 Auto-compaction 模块,并集成到 LlmCycle 中 --- Cargo.toml | 8 +- docs/3-phase0-remaining.md | 375 +++++++++++++++++++++++++++++++++++ src/lib.rs | 3 - src/llm.rs | 5 +- src/llm/compact.rs | 159 +++++++++++++++ src/llm/cycle.rs | 217 +++++++++++++++++++- src/llm/hooks.rs | 128 ++++++++++++ src/llm/provider.rs | 22 +- src/llm/provider/openai.rs | 123 +++++++++++- src/llm/provider/registry.rs | 68 +++++++ src/llm/stream.rs | 108 ++++++++++ src/llm/types/mod.rs | 28 +++ src/llm/types/response.rs | 64 ++++++ 13 files changed, 1299 insertions(+), 9 deletions(-) create mode 100644 docs/3-phase0-remaining.md create mode 100644 src/llm/compact.rs create mode 100644 src/llm/hooks.rs create mode 100644 src/llm/provider/registry.rs create mode 100644 src/llm/stream.rs diff --git a/Cargo.toml b/Cargo.toml index 6fa76b8..002a5c4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,13 +5,19 @@ edition = "2024" [dependencies] tokio = { version = "1", features = ["full"] } -reqwest = { version = "0.12", features = ["json"] } +reqwest = { version = "0.12", features = ["json", "stream"] } serde = { version = "1", features = ["derive"] } serde_json = "1" thiserror = "2" async-trait = "0.1" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } +tokio-stream = "0.1" +futures = "0.3" +futures-util = "0.3" +futures-core = "0.3" +bytes = "1" +async-stream = "0.3" [dev-dependencies] dotenvy = "0.15.7" diff --git a/docs/3-phase0-remaining.md b/docs/3-phase0-remaining.md new file mode 100644 index 0000000..43f5a6f --- /dev/null +++ b/docs/3-phase0-remaining.md @@ -0,0 +1,375 @@ +# Phase 0 剩余模块 — 实施方案 + +> 定稿日期:2026-06-02 + +## 背景与目标 + +AG Core Phase 0(Foundation)已完成核心数据类型、错误体系、Provider 接口、OpenAI 实现、生命周期引擎等基础设施。剩余 4 个子项尚未实现:**ProviderRegistry**、**HookExecutor**、**StreamEvents**、**Auto-compaction**。它们均被 Roadmap 标记为 P0(Must Have),是本阶段不可或缺的底层能力。 + +**目标**:完成这 4 个模块的设计与实现,使 Phase 0 全面交付。 + +--- + +## 需求分析 + +### 功能需求 + +| 模块 | 需求 | 验收条件 | +|------|------|---------| +| ProviderRegistry | 支持注册命名 Provider、按名称查找、设置默认 | 3 个公开方法 + 工厂辅助 | +| HookExecutor | 4 个事件点:PreRequest / PostRequest / OnRetry / OnError | Hook trait + HookExecutor 触发 | +| StreamEvents | 流式事件枚举 + Provider 流式方法 + Cycle 流式入口 | 6 种事件类型 + SSE 解析 | +| Auto-compaction | Token 估算 + 微压缩 + 断路器 | 触发后释放 token 且不改变语义 | + +### 非功能需求 + +- 所有公开 API 必须带 `///` 文档注释 +- 无新增 `unwrap()` 调用 +- 与现有 `LlmCycle` 集成时保持向后兼容(全为可选/增量添加) +- 错误统一使用 `LlmError` 枚举 + +--- + +## 方案设计 + +### 1. ProviderRegistry (`src/llm/provider/registry.rs`) + +**职责**:管理多个 LLM Provider 实例,支持按名称注册、发现、切换。 + +```rust +// src/llm/provider/registry.rs + +use std::collections::HashMap; +use crate::llm::error::LlmError; +use crate::llm::provider::{LlmProvider, ProviderConfig, ProviderType, create_provider}; + +/// Provider 注册表——管理多个 LLM Provider 实例。 +pub struct ProviderRegistry { + providers: HashMap>, + default_name: Option, +} + +impl ProviderRegistry { + pub fn new() -> Self; + + /// 注册一个已初始化的 Provider 实例。 + pub fn register(&mut self, name: impl Into, provider: Box); + + /// 通过 ProviderType + ProviderConfig 创建并注册。 + pub fn register_with_config( + &mut self, + name: impl Into, + provider_type: ProviderType, + config: ProviderConfig, + ) -> Result<(), LlmError>; + + /// 设置默认 Provider。 + pub fn set_default(&mut self, name: &str) -> Result<(), LlmError>; + + /// 按名称查找 Provider。 + pub fn get(&self, name: &str) -> Option<&dyn LlmProvider>; + + /// 获取默认 Provider。 + pub fn get_default(&self) -> Option<&dyn LlmProvider>; +} +``` + +**无新增依赖**。 + +--- + +### 2. HookExecutor (`src/llm/hooks.rs`) + +**职责**:在 LLM 调用生命周期的关键节点插入自定义逻辑。 + +```rust +// src/llm/hooks.rs + +use async_trait::async_trait; +use crate::llm::error::LlmError; +use crate::llm::types::ChatRequest; + +/// 生命周期钩子事件点。 +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum HookEvent { + /// LLM 请求发起之前(可阻断)。 + PreRequest, + /// 成功响应之后。 + PostRequest, + /// 重试之前(仅可重试错误时触发)。 + OnRetry, + /// 不可恢复错误返回之前。 + OnError, +} + +/// 此次钩子调用的上下文。 +#[derive(Debug, Clone)] +pub struct HookContext<'a> { + pub event: HookEvent, + pub request: Option<&'a ChatRequest>, + pub error: Option<&'a LlmError>, + pub attempt: u32, +} + +/// 钩子执行结果。 +#[derive(Debug, Clone)] +pub struct HookResult { + /// 是否阻断后续操作(仅 PreRequest 有效)。 + pub should_block: bool, + /// 阻断/备注原因。 + pub reason: Option, +} + +/// 生命周期钩子 trait。 +#[async_trait] +pub trait Hook: Send + Sync { + async fn execute(&self, ctx: &HookContext<'_>) -> HookResult; +} + +/// 钩子执行器——管理注册与触发。 +pub struct HookExecutor { + hooks: Vec<(HookEvent, Box)>, +} + +impl HookExecutor { + pub fn new() -> Self; + pub fn register(&mut self, event: HookEvent, hook: Box); + pub async fn execute(&self, event: HookEvent, ctx: &HookContext<'_>) -> Vec; +} +``` + +**与 LlmCycle 集成**: +- `LlmCycle` 新增字段 `hook_executor: Option` +- 新增 builder 方法 `with_hook_executor()` +- `submit()` 中在 4 个点触发(PreRequest 若阻断则提前返回) + +**无新增依赖**(async-trait 已存在)。 + +--- + +### 3. StreamEvents (`src/llm/stream.rs`) + +**职责**:提供流式 LLM 调用的事件抽象,将原始 SSE chunk 解析为语义化事件。 + +```rust +// src/llm/stream.rs + +use std::pin::Pin; +use tokio_stream::Stream; +use crate::llm::error::LlmError; +use crate::llm::types::{FinishReason, Usage, OpenaiChatChunk, ToolDefinition}; +use serde_json::Value; + +/// 流式事件——LLM 调用全生命周期的语义化事件。 +#[derive(Debug, Clone)] +pub enum StreamEvent { + /// 助手回复文本增量。 + AssistantTextDelta { text: String }, + /// 工具调用开始。 + ToolExecutionStarted { tool_name: String, input: Value }, + /// 工具调用完成。 + ToolExecutionCompleted { tool_name: String, output: Value, is_error: bool }, + /// Token 用量更新。 + CostUpdate { usage: Usage }, + /// 一轮会话完成。 + TurnComplete { reason: FinishReason }, + /// 可恢复的错误事件。 + Error { message: String }, +} + +/// 将原始 OpenaiChatChunk 流解析为 StreamEvent 流。 +pub fn parse_chunk_stream( + chunks: Pin> + Send>>, +) -> Pin + Send>>; +``` + +**Provider 层扩展**(`src/llm/provider.rs`): +```rust +#[async_trait] +pub trait LlmProvider: Send + Sync { + async fn chat(&self, request: ChatRequest) -> Result; + + /// 流式聊天请求——返回原始 SSE chunk 流。 + /// 默认实现回退到非流式调用。 + async fn chat_stream( + &self, + request: ChatRequest, + ) -> Result< + Pin> + Send>>, + LlmError, + > { + let response = self.chat(request).await?; + let chunk = OpenaiChatChunk::from(response); + Ok(Box::pin(tokio_stream::once(Ok(chunk)))) + } +} +``` + +**LlmCycle 扩展**: +```rust +impl LlmCycle { + /// 提交请求并返回语义事件流。 + pub async fn submit_stream( + &mut self, + prompt: String, + tools: Vec, + ) -> Result< + Pin + Send>>, + LlmError, + >; +} +``` + +**新增依赖**: +```toml +tokio-stream = "0.1" +``` + +--- + +### 4. Auto-compaction (`src/llm/compact.rs`) + +**职责**:在上下文过长时自动压缩历史消息,避免 ContextLength 错误。 + +```rust +// src/llm/compact.rs + +use crate::llm::types::{ContentField, Message, OpenaiChatMessage, OpenaiContentPart}; + +// === 常量 === +const AUTOCOMPACT_BUFFER_TOKENS: u32 = 13_000; +const RESERVED_OUTPUT_TOKENS: u32 = 20_000; +const MAX_CONSECUTIVE_FAILURES: u32 = 3; +const KEEP_RECENT: usize = 6; + +/// 上下文压缩配置。 +#[derive(Debug, Clone)] +pub struct CompactConfig { + /// 模型上下文窗口大小(token 数)。 + pub context_window: u32, + /// 为输出预留的 token 数。 + pub reserved_tokens: u32, + /// 微压缩保留的最近消息数。 + pub keep_recent: usize, +} + +impl Default for CompactConfig { + fn default() -> Self { + Self { + context_window: 128_000, + reserved_tokens: RESERVED_OUTPUT_TOKENS, + keep_recent: KEEP_RECENT, + } + } +} + +impl CompactConfig { + pub fn threshold(&self) -> u32 { + self.context_window + .saturating_sub(self.reserved_tokens) + .saturating_sub(AUTOCOMPACT_BUFFER_TOKENS) + } +} + +/// 压缩状态——跟踪连续失败次数(断路器模式)。 +#[derive(Debug, Clone)] +pub struct CompactState { + consecutive_failures: u32, +} + +impl CompactState { + pub fn new() -> Self; + pub fn record_success(&mut self); + /// 记录失败,返回 true 表示已达断路器上限。 + pub fn record_failure(&mut self) -> bool; +} + +/// 粗略估计消息列表的 token 数(基于字符数,4 字符 ≈ 1 token)。 +pub fn estimate_message_tokens(messages: &[Message]) -> u32; + +/// 判断是否需要触发自动压缩。 +pub fn should_compact(messages: &[Message], config: &CompactConfig, state: &CompactState) -> bool; + +/// 执行微压缩——用占位符替换旧的 tool result 内容。 +/// 返回释放的 token 数。 +pub fn microcompact(messages: &mut Vec, keep_recent: usize) -> u32; +``` + +**与 LlmCycle 集成**: +- `LlmCycle` 新增字段 `compact_config: Option`, `compact_state: CompactState` +- 新增 builder 方法 `with_compact_config()` +- `submit()` 开始时调用 `should_compact()` → `microcompact()` + +**完整 LLM 摘要压缩**留占位,Phase 2 实现(需要循环内调用 LLM 的能力)。 + +**无新增依赖**。 + +--- + +## 实现计划 + +### Step 1: 先写方案文档 + +创建 `docs/3-phase0-remaining.md`(即本文档)。 + +### Step 2: ProviderRegistry + +- 创建 `src/llm/provider/registry.rs` +- `provider.rs` 添加 `pub mod registry;` +- `cargo check` 验证 + +### Step 3: HookExecutor + +- 创建 `src/llm/hooks.rs` +- `llm.rs` 添加 `pub mod hooks;` +- `LlmCycle` 新增字段和方法 +- `submit()` 中插入钩子触发点 +- `cargo check` 验证 + +### Step 4: StreamEvents + +- `Cargo.toml` 添加 `tokio-stream` +- 创建 `src/llm/stream.rs` +- `llm.rs` 添加 `pub mod stream;` +- `LlmProvider` 添加 `chat_stream()`(默认回退) +- `OpenaiProvider` 实现 SSE 解析 +- `LlmCycle` 添加 `submit_stream()` +- `cargo check` 验证 + +### Step 5: Auto-compaction + +- 创建 `src/llm/compact.rs` +- `llm.rs` 添加 `pub mod compact;` +- `LlmCycle` 新增字段和方法 +- `submit()` 开头插入压缩检查 +- `cargo check` 验证 + +### Step 6: 收尾 + +- `cargo clippy` — 无警告 +- `cargo build --release` — 完整构建 +- 检查所有新公开 API 有 `///` 注释 + +--- + +## 风险评估 + +| 风险 | 概率 | 影响 | 缓解措施 | +|------|------|------|---------| +| SSE 解析边界情况 | 中 | 中 | 参考 `reqwest` 的 chunked 响应处理;先用简单的 `lines()` 方式逐行读取 | +| token 估算不准 | 中 | 低 | 仅用于触发微压缩的阈值判断,保守估算即可;后续可接入 tiktoken-rs | +| 钩子阻断语义复杂 | 低 | 中 | PreRequest 阻断后返回明确错误消息;其他事件点只读不阻断 | +| 与后续 Phase 冲突 | 低 | 高 | 保持接口向后兼容,全用可选集成(Option) | + +--- + +## 验收标准 + +1. `cargo check` 编译通过 +2. `cargo clippy` 无警告 +3. 4 个模块文件存在且路径正确 +4. `ProviderRegistry` 支持注册/查找/默认 Provider +5. `HookExecutor` 支持 4 个事件点注册与触发 +6. `StreamEvents` 支持从 SSE chunk 解析为语义事件 +7. `Auto-compaction` 支持 token 估算与微压缩 +8. `LlmCycle` 向后兼容(新增字段全为 Option,不影响现有代码) diff --git a/src/lib.rs b/src/lib.rs index 7b85201..b7061ac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,4 @@ //! agcore —— 智能体(Agent)核心工具箱。 -//! -//! 当前提供 LLM 调用周期控制作为核心底层能力,后续将扩展至 -//! 提示词工程、记忆系统、工具调用、Agent 运行时等领域。 pub mod llm; diff --git a/src/llm.rs b/src/llm.rs index 9e88e23..d306bfe 100644 --- a/src/llm.rs +++ b/src/llm.rs @@ -1,8 +1,9 @@ //! LLM 调用周期 —— 大模型基础调用周期控制。 -//! -//! 包含核心数据类型、Provider 抽象、OpenAI 兼容实现以及生命周期引擎。 +pub mod compact; pub mod cycle; pub mod error; +pub mod hooks; pub mod provider; +pub mod stream; pub mod types; diff --git a/src/llm/compact.rs b/src/llm/compact.rs new file mode 100644 index 0000000..53a07e4 --- /dev/null +++ b/src/llm/compact.rs @@ -0,0 +1,159 @@ +//! 上下文自动压缩 —— 当对话历史过长时自动压缩。 + +use crate::llm::types::{ContentField, OpenaiChatMessage, OpenaiContentPart}; + +const AUTOCOMPACT_BUFFER_TOKENS: u32 = 13_000; +const RESERVED_OUTPUT_TOKENS: u32 = 20_000; +const MAX_CONSECUTIVE_FAILURES: u32 = 3; +const KEEP_RECENT: usize = 6; + +/// 上下文压缩配置。 +#[derive(Debug, Clone)] +pub struct CompactConfig { + /// 模型上下文窗口大小(token 数)。 + pub context_window: u32, + /// 为输出预留的 token 数。 + pub reserved_tokens: u32, + /// 微压缩保留的最近消息数。 + pub keep_recent: usize, +} + +impl Default for CompactConfig { + fn default() -> Self { + Self { + context_window: 128_000, + reserved_tokens: RESERVED_OUTPUT_TOKENS, + keep_recent: KEEP_RECENT, + } + } +} + +impl CompactConfig { + /// 计算自动压缩触发的阈值。 + pub fn threshold(&self) -> u32 { + self.context_window + .saturating_sub(self.reserved_tokens) + .saturating_sub(AUTOCOMPACT_BUFFER_TOKENS) + } +} + +/// 压缩状态 —— 跟踪连续失败次数(断路器模式)。 +#[derive(Debug, Clone)] +pub struct CompactState { + consecutive_failures: u32, +} + +impl Default for CompactState { + fn default() -> Self { + Self::new() + } +} + +impl CompactState { + /// 创建一个新的压缩状态。 + pub fn new() -> Self { + Self { + consecutive_failures: 0, + } + } + + /// 记录一次成功的压缩。 + pub fn record_success(&mut self) { + self.consecutive_failures = 0; + } + + /// 记录一次压缩失败。 + /// + /// 返回 `true` 表示已达断路器上限,不再尝试。 + pub fn record_failure(&mut self) -> bool { + self.consecutive_failures += 1; + self.consecutive_failures >= MAX_CONSECUTIVE_FAILURES + } +} + +/// 粗略估计消息列表的 token 数(基于字符数,4 字符 ≈ 1 token)。 +pub fn estimate_message_tokens(messages: &[OpenaiChatMessage]) -> u32 { + messages + .iter() + .map(estimate_single_message_tokens) + .sum() +} + +fn estimate_single_message_tokens(msg: &OpenaiChatMessage) -> u32 { + let role_overhead: u32 = 4; + let content_tokens = match msg { + OpenaiChatMessage::Developer { content, .. } + | OpenaiChatMessage::System { content, .. } + | OpenaiChatMessage::User { content, .. } + | OpenaiChatMessage::Assistant { content, .. } + | OpenaiChatMessage::Function { content, .. } => estimate_content_tokens(content), + OpenaiChatMessage::Tool { content, .. } => estimate_content_tokens(content), + }; + role_overhead + content_tokens +} + +fn estimate_content_tokens(content: &ContentField) -> u32 { + match content { + ContentField::String(s) => estimate_text_tokens(s), + ContentField::Array(parts) => parts.iter().map(estimate_part_tokens).sum(), + } +} + +fn estimate_part_tokens(part: &OpenaiContentPart) -> u32 { + match part { + OpenaiContentPart::Text { text } => estimate_text_tokens(text), + _ => 50, + } +} + +fn estimate_text_tokens(text: &str) -> u32 { + if text.is_empty() { + return 0; + } + let len = text.len() as u32; + (len * 4).div_ceil(3) +} + +/// 判断是否需要触发自动压缩。 +pub fn should_compact( + messages: &[OpenaiChatMessage], + config: &CompactConfig, + state: &CompactState, +) -> bool { + if state.consecutive_failures >= MAX_CONSECUTIVE_FAILURES { + return false; + } + let tokens = estimate_message_tokens(messages); + tokens >= config.threshold() +} + +/// 执行微压缩 —— 用 `[pruned]` 替换旧的 tool result 内容。 +/// +/// 这是最便宜的压缩方式,不需要 LLM 调用。 +/// 保留最近的 `keep_recent` 条消息不变。 +/// +/// 返回释放的估算 token 数。 +pub fn microcompact(messages: &mut [OpenaiChatMessage], keep_recent: usize) -> u32 { + if messages.len() <= keep_recent { + return 0; + } + + let prune_start = messages.len() - keep_recent; + let mut freed_tokens: u32 = 0; + + for msg in &messages[..prune_start] { + if matches!(msg, OpenaiChatMessage::Tool { .. }) { + freed_tokens += estimate_single_message_tokens(msg); + } + } + + for msg in &mut messages[..prune_start] { + if let OpenaiChatMessage::Tool { content, .. } = msg { + *content = ContentField::Array(vec![OpenaiContentPart::Text { + text: "[pruned]".to_string(), + }]); + } + } + + freed_tokens +} diff --git a/src/llm/cycle.rs b/src/llm/cycle.rs index 0b52c9e..529a2fd 100644 --- a/src/llm/cycle.rs +++ b/src/llm/cycle.rs @@ -1,21 +1,38 @@ +//! LLM 调用周期控制模块。 + mod retry; pub mod usage; pub use retry::RetryConfig; pub use usage::{CostTracker, Usage}; +use std::pin::Pin; +use std::sync::Arc; + +use futures_core::stream::Stream; +use async_stream::stream; + +use crate::llm::compact::{should_compact, microcompact, CompactConfig, CompactState}; use crate::llm::cycle::retry::should_retry; use crate::llm::error::LlmError; +use crate::llm::hooks::{HookContext, HookExecutor}; use crate::llm::provider::LlmProvider; +use crate::llm::stream::StreamEvent; use crate::llm::types::{ ChatRequest, ChatResponse, OpenaiChatMessage, OpenaiTool, ToolChoice, ToolDefinition, }; +/// LLM 调用周期配置。 pub struct CycleConfig { + /// 模型名称。 pub model: String, + /// 最大输出 token 数。 pub max_tokens: Option, + /// 采样温度。 pub temperature: Option, + /// 最大对话轮次。 pub max_turns: Option, + /// 重试策略配置。 pub retry: RetryConfig, } @@ -31,15 +48,20 @@ impl Default for CycleConfig { } } +/// LLM 调用周期 —— 管理一次或多次 LLM 请求的生命周期。 pub struct LlmCycle { provider: Box, config: CycleConfig, usage: CostTracker, messages: Vec, system_prompt: Option, + hook_executor: Option>, + compact_config: Option, + compact_state: CompactState, } impl LlmCycle { + /// 创建一个新的 LlmCycle。 pub fn new(provider: Box, config: CycleConfig) -> Self { Self { provider, @@ -47,30 +69,51 @@ impl LlmCycle { usage: CostTracker::default(), messages: Vec::new(), system_prompt: None, + hook_executor: None, + compact_config: None, + compact_state: CompactState::new(), } } + /// 设置系统提示词。 pub fn with_system_prompt(mut self, prompt: String) -> Self { self.system_prompt = Some(prompt); self } + /// 设置钩子执行器。 + pub fn with_hook_executor(mut self, executor: HookExecutor) -> Self { + self.hook_executor = Some(Arc::new(executor)); + self + } + + /// 设置上下文压缩配置。 + pub fn with_compact_config(mut self, config: CompactConfig) -> Self { + self.compact_config = Some(config); + self + } + + /// 获取用量追踪器引用。 pub fn usage(&self) -> &CostTracker { &self.usage } + /// 获取消息历史引用。 pub fn messages(&self) -> &[OpenaiChatMessage] { &self.messages } + /// 清空消息历史。 pub fn clear_messages(&mut self) { self.messages.clear(); } + /// 重置用量统计。 pub fn reset_usage(&mut self) { self.usage.reset(); } + /// 提交用户消息并获取 LLM 响应。 pub async fn submit( &mut self, prompt: String, @@ -78,31 +121,203 @@ impl LlmCycle { ) -> Result { self.messages.push(OpenaiChatMessage::user_text(prompt)); + if let Some(ref config) = self.compact_config + && should_compact(&self.messages, config, &self.compact_state) + { + let freed = microcompact(&mut self.messages, config.keep_recent); + if freed > 0 { + self.compact_state.record_success(); + } + } + let mut attempts = 0; loop { let request = self.build_request(&tools); + if let Some(ref executor) = self.hook_executor { + let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest) + .with_request(&request); + let results = executor + .execute(crate::llm::hooks::HookEvent::PreRequest, &ctx) + .await; + if results.iter().any(|r| r.should_block) { + let reason = results + .iter() + .find(|r| r.should_block) + .and_then(|r| r.reason.clone()) + .unwrap_or_else(|| "Blocked by pre-request hook".to_string()); + return Err(LlmError::Other(reason)); + } + } + match self.provider.chat(request).await { Ok(response) => { - self.messages.push(response.message.clone()); + if let Some(ref executor) = self.hook_executor { + let post_request = self.build_request(&tools); + let ctx = HookContext::new(crate::llm::hooks::HookEvent::PostRequest) + .with_request(&post_request); + executor + .execute(crate::llm::hooks::HookEvent::PostRequest, &ctx) + .await; + } + self.messages.push(response.message.clone()); self.usage.add(&response.usage); return Ok(response); } Err(e) if should_retry(&e) && attempts < self.config.retry.max_retries => { attempts += 1; + + if let Some(ref executor) = self.hook_executor { + let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnRetry) + .with_error(&e) + .with_attempt(attempts); + executor + .execute(crate::llm::hooks::HookEvent::OnRetry, &ctx) + .await; + } + let delay = self.config.retry.compute_delay(attempts); tokio::time::sleep(delay).await; } Err(e) => { + if let Some(ref executor) = self.hook_executor { + let ctx = + HookContext::new(crate::llm::hooks::HookEvent::OnError).with_error(&e); + executor + .execute(crate::llm::hooks::HookEvent::OnError, &ctx) + .await; + } + return Err(e); } } } } + /// 提交用户消息并返回语义事件流。 + /// + /// 与 `submit` 不同,该方法返回流式事件而非完整响应。 + /// 适用于需要实时处理 LLM 输出的场景。 + pub async fn submit_stream( + &mut self, + prompt: String, + tools: Vec, + ) -> Result + Send>>, LlmError> { + self.messages.push(OpenaiChatMessage::user_text(prompt)); + + if let Some(ref config) = self.compact_config + && should_compact(&self.messages, config, &self.compact_state) + { + let freed = microcompact(&mut self.messages, config.keep_recent); + if freed > 0 { + self.compact_state.record_success(); + } + } + + let request = self.build_request(&tools); + + if let Some(ref executor) = self.hook_executor { + let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest) + .with_request(&request); + let results = executor + .execute(crate::llm::hooks::HookEvent::PreRequest, &ctx) + .await; + if results.iter().any(|r| r.should_block) { + let reason = results + .iter() + .find(|r| r.should_block) + .and_then(|r| r.reason.clone()) + .unwrap_or_else(|| "Blocked by pre-request hook".to_string()); + return Err(LlmError::Other(reason)); + } + } + + let chunk_stream = self.provider.chat_stream(request).await?; + let hook_executor = self.hook_executor.clone(); + let post_request = self.build_request(&tools); + + Ok(Box::pin(stream! { + use futures_util::StreamExt; + let mut chunk_stream = chunk_stream; + + while let Some(result) = chunk_stream.next().await { + match result { + Ok(chunk) => { + let mut assistant_text = String::new(); + let mut tool_started: Option<(String, String, String)> = None; + + for choice in &chunk.choices { + let delta = &choice.delta; + + if let Some(content) = &delta.content { + assistant_text.push_str(content); + } + + if let Some(tool_calls) = &delta.tool_calls + && let Some(tc) = tool_calls.first() + { + let crate::llm::types::OpenaiToolCall::Function { id, function } = tc; + tool_started = Some((id.clone(), function.name.clone(), function.arguments.clone())); + } + } + + if !assistant_text.is_empty() { + yield StreamEvent::AssistantTextDelta { text: assistant_text }; + } + + if let Some((tool_call_id, tool_name, arguments)) = tool_started { + let args: serde_json::Value = serde_json::from_str(&arguments) + .unwrap_or(serde_json::Value::Null); + yield StreamEvent::ToolExecutionStarted { + tool_name, + input: args, + tool_call_id, + }; + } + + for choice in &chunk.choices { + if let Some(finish_reason) = &choice.finish_reason { + yield StreamEvent::TurnComplete { + reason: *finish_reason, + }; + } + } + + if let Some(usage_info) = &chunk.usage { + yield StreamEvent::CostUpdate { usage: *usage_info }; + } + } + Err(e) => { + if let Some(ref executor) = hook_executor { + let ctx = crate::llm::hooks::HookContext::new( + crate::llm::hooks::HookEvent::OnError, + ) + .with_error(&e); + executor + .execute(crate::llm::hooks::HookEvent::OnError, &ctx) + .await; + } + yield StreamEvent::Error { message: e.to_string() }; + break; + } + } + } + + if let Some(ref executor) = hook_executor { + let ctx = crate::llm::hooks::HookContext::new( + crate::llm::hooks::HookEvent::PostRequest, + ) + .with_request(&post_request); + executor + .execute(crate::llm::hooks::HookEvent::PostRequest, &ctx) + .await; + } + })) + } + fn build_request(&self, tools: &[ToolDefinition]) -> ChatRequest { let mut messages = self.messages.clone(); diff --git a/src/llm/hooks.rs b/src/llm/hooks.rs new file mode 100644 index 0000000..96ee840 --- /dev/null +++ b/src/llm/hooks.rs @@ -0,0 +1,128 @@ +//! 生命周期钩子 —— 在 LLM 调用周期的关键节点插入自定义逻辑。 + +use async_trait::async_trait; + +use crate::llm::error::LlmError; +use crate::llm::types::ChatRequest; + +/// 生命周期钩子事件点。 +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum HookEvent { + /// LLM 请求发起之前(可阻断)。 + PreRequest, + /// 成功响应之后。 + PostRequest, + /// 重试之前(仅可重试错误时触发)。 + OnRetry, + /// 不可恢复错误返回之前。 + OnError, +} + +/// 此次钩子调用的上下文。 +#[derive(Debug, Clone)] +pub struct HookContext<'a> { + /// 当前事件点。 + pub event: HookEvent, + /// 当前的请求(部分事件点可用)。 + pub request: Option<&'a ChatRequest>, + /// 当前错误(仅 OnError 和 OnRetry 可用)。 + pub error: Option<&'a LlmError>, + /// 当前重试次数(从 1 开始,仅 OnRetry 可用)。 + pub attempt: u32, +} + +impl<'a> HookContext<'a> { + pub(crate) fn new(event: HookEvent) -> Self { + Self { + event, + request: None, + error: None, + attempt: 0, + } + } + + pub(crate) fn with_request(mut self, request: &'a ChatRequest) -> Self { + self.request = Some(request); + self + } + + pub(crate) fn with_error(mut self, error: &'a LlmError) -> Self { + self.error = Some(error); + self + } + + pub(crate) fn with_attempt(mut self, attempt: u32) -> Self { + self.attempt = attempt; + self + } +} + +/// 钩子执行结果。 +#[derive(Debug, Clone)] +pub struct HookResult { + /// 是否阻断后续操作(仅 PreRequest 有效)。 + pub should_block: bool, + /// 阻断/备注原因。 + pub reason: Option, +} + +impl HookResult { + #[allow(dead_code)] + pub(crate) fn allow() -> Self { + Self { + should_block: false, + reason: None, + } + } + + #[allow(dead_code)] + pub(crate) fn block(reason: impl Into) -> Self { + Self { + should_block: true, + reason: Some(reason.into()), + } + } +} + +/// 生命周期钩子 trait。 +#[async_trait] +pub trait Hook: Send + Sync { + /// 执行钩子逻辑。 + async fn execute(&self, ctx: &HookContext<'_>) -> HookResult; +} + +/// 钩子执行器 —— 管理注册与触发。 +pub struct HookExecutor { + hooks: Vec<(HookEvent, Box)>, +} + +impl Default for HookExecutor { + fn default() -> Self { + Self::new() + } +} + +impl HookExecutor { + /// 创建一个空的执行器。 + pub fn new() -> Self { + Self { + hooks: Vec::new(), + } + } + + /// 注册一个钩子到指定事件点。 + pub fn register(&mut self, event: HookEvent, hook: Box) { + self.hooks.push((event, hook)); + } + + /// 执行指定事件点的所有钩子。 + pub async fn execute(&self, event: HookEvent, ctx: &HookContext<'_>) -> Vec { + let mut results = Vec::new(); + for (e, hook) in &self.hooks { + if *e == event { + results.push(hook.execute(ctx).await); + } + } + results + } +} diff --git a/src/llm/provider.rs b/src/llm/provider.rs index 28445c8..cb61b86 100644 --- a/src/llm/provider.rs +++ b/src/llm/provider.rs @@ -1,7 +1,12 @@ pub mod openai; +pub mod registry; + +use std::pin::Pin; + +use tokio_stream::Stream; use crate::llm::error::LlmError; -use crate::llm::types::{ChatRequest, ChatResponse}; +use crate::llm::types::{ChatRequest, ChatResponse, OpenaiChatChunk}; use async_trait::async_trait; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -57,4 +62,19 @@ pub fn create_provider( pub trait LlmProvider: Send + Sync { /// 发送聊天请求并返回完整响应。 async fn chat(&self, request: ChatRequest) -> Result; + + /// 流式聊天请求 —— 返回原始 SSE chunk 流。 + /// + /// 默认实现回退到非流式调用(包装为单元素流)。 + async fn chat_stream( + &self, + request: ChatRequest, + ) -> Result< + Pin> + Send>>, + LlmError, + > { + let response = self.chat(request).await?; + let chunk = OpenaiChatChunk::from(response); + Ok(Box::pin(tokio_stream::once(Ok(chunk)))) + } } diff --git a/src/llm/provider/openai.rs b/src/llm/provider/openai.rs index 6cd59ea..fa72968 100644 --- a/src/llm/provider/openai.rs +++ b/src/llm/provider/openai.rs @@ -1,12 +1,18 @@ +use std::pin::Pin; use std::time::Duration; use async_trait::async_trait; +use bytes::Bytes; +use futures_core::stream::Stream; +use futures_util::StreamExt; use reqwest::Client; use tracing::{debug, error, info}; use super::LlmProvider; use crate::llm::error::LlmError; -use crate::llm::types::{ChatRequest, ChatResponse, OpenaiChatResponse}; +use crate::llm::types::{ + ChatRequest, ChatResponse, OpenaiChatChunk, OpenaiChatResponse, StreamOptions, +}; pub struct OpenaiProvider { http_client: Client, @@ -111,4 +117,119 @@ impl LlmProvider for OpenaiProvider { Ok(ChatResponse::from(chat_response)) } + + async fn chat_stream( + &self, + mut request: ChatRequest, + ) -> Result< + Pin> + Send>>, + LlmError, + > { + let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/')); + + request.stream = Some(true); + request.stream_options = Some(StreamOptions { + include_usage: Some(true), + include_obfuscation: None, + }); + + info!(model = %request.model, "发送 LLM 流式请求"); + + let response = self + .http_client + .post(&url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .json(&request) + .send() + .await + .map_err(|e| { + error!(error = %e, "流式请求失败"); + Self::map_reqwest_error(e) + })?; + + let status = response.status(); + let status_code: u16 = status.as_u16(); + + if !status.is_success() { + let body_text = response.text().await.unwrap_or_default(); + error!(status = status_code, body = %body_text, "流式请求失败"); + return Err(LlmError::Request { + status: status_code, + body: body_text, + }); + } + + let byte_stream = response.bytes_stream().map(|r| { + r.map_err(|e| LlmError::Other(format!("流式读取失败: {}", e))) + }); + + Ok(Box::pin(SseChunkStream::new(byte_stream))) + } +} + +struct SseChunkStream { + inner: S, + buffer: String, +} + +impl> + Unpin> SseChunkStream { + fn new(stream: S) -> Self { + Self { + inner: stream, + buffer: String::new(), + } + } +} + +impl> + Unpin> Stream for SseChunkStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + loop { + if let Some(pos) = self.buffer.find("\n") { + let line = self.buffer.drain(..pos + 1).collect::(); + let trimmed = line.trim(); + if trimmed.is_empty() + || trimmed == "data: [DONE]" + || trimmed == "[DONE]" + || trimmed == "data:" + { + continue; + } + let data = if let Some(p) = trimmed.strip_prefix("data: ") { + p + } else { + trimmed + }; + match serde_json::from_str::(data) { + Ok(chunk) => return std::task::Poll::Ready(Some(Ok(chunk))), + Err(e) => { + return std::task::Poll::Ready(Some(Err(LlmError::Other(format!( + "Chunk 解析失败: {} | raw: {}", + e, data + ))))); + } + } + } + + match Pin::new(&mut self.inner).poll_next(cx) { + std::task::Poll::Ready(Some(Ok(bytes))) => { + if let Ok(s) = std::str::from_utf8(&bytes) { + self.buffer.push_str(s); + } + } + std::task::Poll::Ready(Some(Err(e))) => { + return std::task::Poll::Ready(Some(Err(e))); + } + std::task::Poll::Ready(None) => { + if self.buffer.is_empty() { + return std::task::Poll::Ready(None); + } + self.buffer.clear(); + return std::task::Poll::Ready(None); + } + std::task::Poll::Pending => return std::task::Poll::Pending, + } + } + } } diff --git a/src/llm/provider/registry.rs b/src/llm/provider/registry.rs new file mode 100644 index 0000000..82890a9 --- /dev/null +++ b/src/llm/provider/registry.rs @@ -0,0 +1,68 @@ +//! Provider 注册表 —— 多 Provider 实例的注册与发现。 + +use std::collections::HashMap; + +use crate::llm::error::LlmError; +use crate::llm::provider::{create_provider, LlmProvider, ProviderConfig, ProviderType}; + +/// Provider 注册表 —— 管理多个 LLM Provider 实例。 +/// +/// 支持注册命名 Provider、按名称查找、设置默认 Provider。 +pub struct ProviderRegistry { + providers: HashMap>, + default_name: Option, +} + +impl Default for ProviderRegistry { + fn default() -> Self { + Self::new() + } +} + +impl ProviderRegistry { + /// 创建一个空的注册表。 + pub fn new() -> Self { + Self { + providers: HashMap::new(), + default_name: None, + } + } + + /// 注册一个已初始化的 Provider 实例。 + pub fn register(&mut self, name: impl Into, provider: Box) { + self.providers.insert(name.into(), provider); + } + + /// 通过 ProviderType + ProviderConfig 创建并注册。 + pub fn register_with_config( + &mut self, + name: impl Into, + provider_type: ProviderType, + config: ProviderConfig, + ) -> Result<(), LlmError> { + let provider = create_provider(provider_type, config)?; + self.register(name, provider); + Ok(()) + } + + /// 设置默认 Provider。 + pub fn set_default(&mut self, name: &str) -> Result<(), LlmError> { + if !self.providers.contains_key(name) { + return Err(LlmError::Other(format!("Provider '{}' 不存在", name))); + } + self.default_name = Some(name.to_string()); + Ok(()) + } + + /// 按名称查找 Provider。 + pub fn get(&self, name: &str) -> Option<&dyn LlmProvider> { + self.providers.get(name).map(|p| p.as_ref()) + } + + /// 获取默认 Provider。 + pub fn get_default(&self) -> Option<&dyn LlmProvider> { + self.default_name + .as_ref() + .and_then(|name| self.get(name)) + } +} diff --git a/src/llm/stream.rs b/src/llm/stream.rs new file mode 100644 index 0000000..d15b0cf --- /dev/null +++ b/src/llm/stream.rs @@ -0,0 +1,108 @@ +//! 流式事件系统 —— 将 LLM 流式响应解析为语义化事件。 + +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures_core::stream::Stream; +use futures_util::future::poll_fn; +use futures_util::FutureExt; +use serde_json::Value; + +use crate::llm::error::LlmError; +use crate::llm::types::{FinishReason, OpenaiChatChunk, OpenaiToolCall, Usage}; + +/// 流式事件 —— LLM 调用全生命周期的语义化事件。 +#[derive(Debug, Clone)] +pub enum StreamEvent { + /// 助手回复文本增量。 + AssistantTextDelta { text: String }, + /// 工具调用开始。 + ToolExecutionStarted { + tool_name: String, + input: Value, + tool_call_id: String, + }, + /// 工具调用完成。 + ToolExecutionCompleted { + tool_name: String, + output: Value, + is_error: bool, + }, + /// Token 用量更新。 + CostUpdate { usage: Usage }, + /// 一轮会话完成。 + TurnComplete { reason: FinishReason }, + /// 错误事件。 + Error { message: String }, +} + +impl StreamEvent { + fn error(message: impl Into) -> Self { + Self::Error { + message: message.into(), + } + } +} + +/// 将原始 OpenaiChatChunk 流解析为 StreamEvent 流。 +pub fn parse_chunk_stream( + chunks: Pin> + Send>>, +) -> Pin + Send>> { + Box::pin(ChunkToEventStream { chunks }) +} + +struct ChunkToEventStream { + chunks: Pin> + Send>>, +} + +impl Stream for ChunkToEventStream { + type Item = StreamEvent; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = &mut *self; + poll_fn(|cx| match Pin::new(&mut this.chunks).poll_next(cx) { + Poll::Ready(Some(Ok(chunk))) => { + for choice in &chunk.choices { + let delta = &choice.delta; + + if let Some(content) = &delta.content { + return Poll::Ready(Some(StreamEvent::AssistantTextDelta { + text: content.clone(), + })); + } + + if let Some(tool_calls) = &delta.tool_calls + && let Some(tc) = tool_calls.first() + { + let OpenaiToolCall::Function { id, function } = tc; + let args: Value = + serde_json::from_str(&function.arguments).unwrap_or(Value::Null); + return Poll::Ready(Some(StreamEvent::ToolExecutionStarted { + tool_name: function.name.clone(), + input: args, + tool_call_id: id.clone(), + })); + } + + if let Some(finish_reason) = &choice.finish_reason { + return Poll::Ready(Some(StreamEvent::TurnComplete { + reason: *finish_reason, + })); + } + } + + if let Some(usage) = &chunk.usage { + return Poll::Ready(Some(StreamEvent::CostUpdate { + usage: *usage, + })); + } + + Poll::Ready(None) + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(StreamEvent::error(e.to_string()))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + }) + .poll_unpin(cx) + } +} diff --git a/src/llm/types/mod.rs b/src/llm/types/mod.rs index 342f36d..e448dbc 100644 --- a/src/llm/types/mod.rs +++ b/src/llm/types/mod.rs @@ -43,6 +43,34 @@ impl From for ChatResponse { } } +impl From for OpenaiChatChunk { + fn from(response: ChatResponse) -> Self { + let delta = Delta::from(response.message.clone()); + let chunk_choice = ChunkChoice { + index: 0, + delta, + logprobs: None, + finish_reason: response.stop_reason, + }; + + OpenaiChatChunk { + id: format!("chunk-{}", std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_nanos()) + .unwrap_or(0)), + object: "chat.completion.chunk".to_string(), + created: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0), + model: String::new(), + choices: vec![chunk_choice], + usage: Some(response.usage), + system_fingerprint: None, + } + } +} + pub type ChatRequest = OpenaiChatRequest; pub type Message = OpenaiChatMessage; pub type ContentBlock = OpenaiContentPart; diff --git a/src/llm/types/response.rs b/src/llm/types/response.rs index d3eba1c..ece17fe 100644 --- a/src/llm/types/response.rs +++ b/src/llm/types/response.rs @@ -3,6 +3,7 @@ use crate::llm::types::shared::{FinishReason, ServiceTier}; use crate::llm::types::tool::OpenaiToolCall; use crate::llm::types::usage::Usage; use serde::{Deserialize, Serialize}; +use crate::llm::types::{ContentField, OpenaiContentPart}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TokenLogprob { @@ -115,3 +116,66 @@ pub struct OpenaiChatChunk { #[serde(skip_serializing_if = "Option::is_none")] pub system_fingerprint: Option, } + +impl From for Delta { + fn from(msg: OpenaiChatMessage) -> Self { + match msg { + OpenaiChatMessage::Assistant { + content, + tool_calls, + .. + } => Delta { + role: Some("assistant".to_string()), + content: match content { + ContentField::String(s) => Some(s), + ContentField::Array(parts) => { + let mut text = String::new(); + for part in parts { + if let OpenaiContentPart::Text { text: t } = part { + text.push_str(&t); + } + } + if text.is_empty() { + None + } else { + Some(text) + } + } + }, + refusal: None, + tool_calls, + }, + _ => Delta { + role: None, + content: None, + refusal: None, + tool_calls: None, + }, + } + } +} + +impl From for OpenaiChatChunk { + fn from(response: OpenaiChatResponse) -> Self { + let choices = response + .choices + .into_iter() + .map(|c| ChunkChoice { + index: c.index, + delta: Delta::from(c.message), + logprobs: c.logprobs, + finish_reason: c.finish_reason, + }) + .collect(); + + OpenaiChatChunk { + id: response.id, + object: "chat.completion.chunk".to_string(), + created: response.created, + model: response.model, + choices, + usage: Some(response.usage), + system_fingerprint: response.system_fingerprint, + } + } +}