feat(llm): 实现 Phase 0 剩余四个模块

实现 ProviderRegistry、HookExecutor、StreamEvents 和 Auto-compaction 模块,并集成到 LlmCycle 中
This commit is contained in:
徐涛
2026-06-02 08:51:42 +08:00
parent 69b6dd942b
commit 32f3edaf19
13 changed files with 1299 additions and 9 deletions
+7 -1
View File
@@ -5,13 +5,19 @@ edition = "2024"
[dependencies]
tokio = { version = "1", features = ["full"] }
reqwest = { version = "0.12", features = ["json"] }
reqwest = { version = "0.12", features = ["json", "stream"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
thiserror = "2"
async-trait = "0.1"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tokio-stream = "0.1"
futures = "0.3"
futures-util = "0.3"
futures-core = "0.3"
bytes = "1"
async-stream = "0.3"
[dev-dependencies]
dotenvy = "0.15.7"
+375
View File
@@ -0,0 +1,375 @@
# Phase 0 剩余模块 — 实施方案
> 定稿日期:2026-06-02
## 背景与目标
AG Core Phase 0Foundation)已完成核心数据类型、错误体系、Provider 接口、OpenAI 实现、生命周期引擎等基础设施。剩余 4 个子项尚未实现:**ProviderRegistry**、**HookExecutor**、**StreamEvents**、**Auto-compaction**。它们均被 Roadmap 标记为 P0(Must Have),是本阶段不可或缺的底层能力。
**目标**:完成这 4 个模块的设计与实现,使 Phase 0 全面交付。
---
## 需求分析
### 功能需求
| 模块 | 需求 | 验收条件 |
|------|------|---------|
| ProviderRegistry | 支持注册命名 Provider、按名称查找、设置默认 | 3 个公开方法 + 工厂辅助 |
| HookExecutor | 4 个事件点:PreRequest / PostRequest / OnRetry / OnError | Hook trait + HookExecutor 触发 |
| StreamEvents | 流式事件枚举 + Provider 流式方法 + Cycle 流式入口 | 6 种事件类型 + SSE 解析 |
| Auto-compaction | Token 估算 + 微压缩 + 断路器 | 触发后释放 token 且不改变语义 |
### 非功能需求
- 所有公开 API 必须带 `///` 文档注释
- 无新增 `unwrap()` 调用
- 与现有 `LlmCycle` 集成时保持向后兼容(全为可选/增量添加)
- 错误统一使用 `LlmError` 枚举
---
## 方案设计
### 1. ProviderRegistry (`src/llm/provider/registry.rs`)
**职责**:管理多个 LLM Provider 实例,支持按名称注册、发现、切换。
```rust
// src/llm/provider/registry.rs
use std::collections::HashMap;
use crate::llm::error::LlmError;
use crate::llm::provider::{LlmProvider, ProviderConfig, ProviderType, create_provider};
/// Provider 注册表——管理多个 LLM Provider 实例。
pub struct ProviderRegistry {
providers: HashMap<String, Box<dyn LlmProvider>>,
default_name: Option<String>,
}
impl ProviderRegistry {
pub fn new() -> Self;
/// 注册一个已初始化的 Provider 实例。
pub fn register(&mut self, name: impl Into<String>, provider: Box<dyn LlmProvider>);
/// 通过 ProviderType + ProviderConfig 创建并注册。
pub fn register_with_config(
&mut self,
name: impl Into<String>,
provider_type: ProviderType,
config: ProviderConfig,
) -> Result<(), LlmError>;
/// 设置默认 Provider。
pub fn set_default(&mut self, name: &str) -> Result<(), LlmError>;
/// 按名称查找 Provider。
pub fn get(&self, name: &str) -> Option<&dyn LlmProvider>;
/// 获取默认 Provider。
pub fn get_default(&self) -> Option<&dyn LlmProvider>;
}
```
**无新增依赖**
---
### 2. HookExecutor (`src/llm/hooks.rs`)
**职责**:在 LLM 调用生命周期的关键节点插入自定义逻辑。
```rust
// src/llm/hooks.rs
use async_trait::async_trait;
use crate::llm::error::LlmError;
use crate::llm::types::ChatRequest;
/// 生命周期钩子事件点。
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HookEvent {
/// LLM 请求发起之前(可阻断)。
PreRequest,
/// 成功响应之后。
PostRequest,
/// 重试之前(仅可重试错误时触发)。
OnRetry,
/// 不可恢复错误返回之前。
OnError,
}
/// 此次钩子调用的上下文。
#[derive(Debug, Clone)]
pub struct HookContext<'a> {
pub event: HookEvent,
pub request: Option<&'a ChatRequest>,
pub error: Option<&'a LlmError>,
pub attempt: u32,
}
/// 钩子执行结果。
#[derive(Debug, Clone)]
pub struct HookResult {
/// 是否阻断后续操作(仅 PreRequest 有效)。
pub should_block: bool,
/// 阻断/备注原因。
pub reason: Option<String>,
}
/// 生命周期钩子 trait。
#[async_trait]
pub trait Hook: Send + Sync {
async fn execute(&self, ctx: &HookContext<'_>) -> HookResult;
}
/// 钩子执行器——管理注册与触发。
pub struct HookExecutor {
hooks: Vec<(HookEvent, Box<dyn Hook>)>,
}
impl HookExecutor {
pub fn new() -> Self;
pub fn register(&mut self, event: HookEvent, hook: Box<dyn Hook>);
pub async fn execute(&self, event: HookEvent, ctx: &HookContext<'_>) -> Vec<HookResult>;
}
```
**与 LlmCycle 集成**
- `LlmCycle` 新增字段 `hook_executor: Option<HookExecutor>`
- 新增 builder 方法 `with_hook_executor()`
- `submit()` 中在 4 个点触发(PreRequest 若阻断则提前返回)
**无新增依赖**async-trait 已存在)。
---
### 3. StreamEvents (`src/llm/stream.rs`)
**职责**:提供流式 LLM 调用的事件抽象,将原始 SSE chunk 解析为语义化事件。
```rust
// src/llm/stream.rs
use std::pin::Pin;
use tokio_stream::Stream;
use crate::llm::error::LlmError;
use crate::llm::types::{FinishReason, Usage, OpenaiChatChunk, ToolDefinition};
use serde_json::Value;
/// 流式事件——LLM 调用全生命周期的语义化事件。
#[derive(Debug, Clone)]
pub enum StreamEvent {
/// 助手回复文本增量。
AssistantTextDelta { text: String },
/// 工具调用开始。
ToolExecutionStarted { tool_name: String, input: Value },
/// 工具调用完成。
ToolExecutionCompleted { tool_name: String, output: Value, is_error: bool },
/// Token 用量更新。
CostUpdate { usage: Usage },
/// 一轮会话完成。
TurnComplete { reason: FinishReason },
/// 可恢复的错误事件。
Error { message: String },
}
/// 将原始 OpenaiChatChunk 流解析为 StreamEvent 流。
pub fn parse_chunk_stream(
chunks: Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
) -> Pin<Box<dyn Stream<Item = StreamEvent> + Send>>;
```
**Provider 层扩展**`src/llm/provider.rs`):
```rust
#[async_trait]
pub trait LlmProvider: Send + Sync {
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError>;
/// 流式聊天请求——返回原始 SSE chunk 流。
/// 默认实现回退到非流式调用。
async fn chat_stream(
&self,
request: ChatRequest,
) -> Result<
Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
LlmError,
> {
let response = self.chat(request).await?;
let chunk = OpenaiChatChunk::from(response);
Ok(Box::pin(tokio_stream::once(Ok(chunk))))
}
}
```
**LlmCycle 扩展**
```rust
impl LlmCycle {
/// 提交请求并返回语义事件流。
pub async fn submit_stream(
&mut self,
prompt: String,
tools: Vec<ToolDefinition>,
) -> Result<
Pin<Box<dyn Stream<Item = StreamEvent> + Send>>,
LlmError,
>;
}
```
**新增依赖**
```toml
tokio-stream = "0.1"
```
---
### 4. Auto-compaction (`src/llm/compact.rs`)
**职责**:在上下文过长时自动压缩历史消息,避免 ContextLength 错误。
```rust
// src/llm/compact.rs
use crate::llm::types::{ContentField, Message, OpenaiChatMessage, OpenaiContentPart};
// === 常量 ===
const AUTOCOMPACT_BUFFER_TOKENS: u32 = 13_000;
const RESERVED_OUTPUT_TOKENS: u32 = 20_000;
const MAX_CONSECUTIVE_FAILURES: u32 = 3;
const KEEP_RECENT: usize = 6;
/// 上下文压缩配置。
#[derive(Debug, Clone)]
pub struct CompactConfig {
/// 模型上下文窗口大小(token 数)。
pub context_window: u32,
/// 为输出预留的 token 数。
pub reserved_tokens: u32,
/// 微压缩保留的最近消息数。
pub keep_recent: usize,
}
impl Default for CompactConfig {
fn default() -> Self {
Self {
context_window: 128_000,
reserved_tokens: RESERVED_OUTPUT_TOKENS,
keep_recent: KEEP_RECENT,
}
}
}
impl CompactConfig {
pub fn threshold(&self) -> u32 {
self.context_window
.saturating_sub(self.reserved_tokens)
.saturating_sub(AUTOCOMPACT_BUFFER_TOKENS)
}
}
/// 压缩状态——跟踪连续失败次数(断路器模式)。
#[derive(Debug, Clone)]
pub struct CompactState {
consecutive_failures: u32,
}
impl CompactState {
pub fn new() -> Self;
pub fn record_success(&mut self);
/// 记录失败,返回 true 表示已达断路器上限。
pub fn record_failure(&mut self) -> bool;
}
/// 粗略估计消息列表的 token 数(基于字符数,4 字符 ≈ 1 token)。
pub fn estimate_message_tokens(messages: &[Message]) -> u32;
/// 判断是否需要触发自动压缩。
pub fn should_compact(messages: &[Message], config: &CompactConfig, state: &CompactState) -> bool;
/// 执行微压缩——用占位符替换旧的 tool result 内容。
/// 返回释放的 token 数。
pub fn microcompact(messages: &mut Vec<Message>, keep_recent: usize) -> u32;
```
**与 LlmCycle 集成**
- `LlmCycle` 新增字段 `compact_config: Option<CompactConfig>`, `compact_state: CompactState`
- 新增 builder 方法 `with_compact_config()`
- `submit()` 开始时调用 `should_compact()``microcompact()`
**完整 LLM 摘要压缩**留占位,Phase 2 实现(需要循环内调用 LLM 的能力)。
**无新增依赖**
---
## 实现计划
### Step 1: 先写方案文档
创建 `docs/3-phase0-remaining.md`(即本文档)。
### Step 2: ProviderRegistry
- 创建 `src/llm/provider/registry.rs`
- `provider.rs` 添加 `pub mod registry;`
- `cargo check` 验证
### Step 3: HookExecutor
- 创建 `src/llm/hooks.rs`
- `llm.rs` 添加 `pub mod hooks;`
- `LlmCycle` 新增字段和方法
- `submit()` 中插入钩子触发点
- `cargo check` 验证
### Step 4: StreamEvents
- `Cargo.toml` 添加 `tokio-stream`
- 创建 `src/llm/stream.rs`
- `llm.rs` 添加 `pub mod stream;`
- `LlmProvider` 添加 `chat_stream()`(默认回退)
- `OpenaiProvider` 实现 SSE 解析
- `LlmCycle` 添加 `submit_stream()`
- `cargo check` 验证
### Step 5: Auto-compaction
- 创建 `src/llm/compact.rs`
- `llm.rs` 添加 `pub mod compact;`
- `LlmCycle` 新增字段和方法
- `submit()` 开头插入压缩检查
- `cargo check` 验证
### Step 6: 收尾
- `cargo clippy` — 无警告
- `cargo build --release` — 完整构建
- 检查所有新公开 API 有 `///` 注释
---
## 风险评估
| 风险 | 概率 | 影响 | 缓解措施 |
|------|------|------|---------|
| SSE 解析边界情况 | 中 | 中 | 参考 `reqwest` 的 chunked 响应处理;先用简单的 `lines()` 方式逐行读取 |
| token 估算不准 | 中 | 低 | 仅用于触发微压缩的阈值判断,保守估算即可;后续可接入 tiktoken-rs |
| 钩子阻断语义复杂 | 低 | 中 | PreRequest 阻断后返回明确错误消息;其他事件点只读不阻断 |
| 与后续 Phase 冲突 | 低 | 高 | 保持接口向后兼容,全用可选集成(Option) |
---
## 验收标准
1. `cargo check` 编译通过
2. `cargo clippy` 无警告
3. 4 个模块文件存在且路径正确
4. `ProviderRegistry` 支持注册/查找/默认 Provider
5. `HookExecutor` 支持 4 个事件点注册与触发
6. `StreamEvents` 支持从 SSE chunk 解析为语义事件
7. `Auto-compaction` 支持 token 估算与微压缩
8. `LlmCycle` 向后兼容(新增字段全为 Option,不影响现有代码)
-3
View File
@@ -1,7 +1,4 @@
//! agcore —— 智能体(Agent)核心工具箱。
//!
//! 当前提供 LLM 调用周期控制作为核心底层能力,后续将扩展至
//! 提示词工程、记忆系统、工具调用、Agent 运行时等领域。
pub mod llm;
+3 -2
View File
@@ -1,8 +1,9 @@
//! LLM 调用周期 —— 大模型基础调用周期控制。
//!
//! 包含核心数据类型、Provider 抽象、OpenAI 兼容实现以及生命周期引擎。
pub mod compact;
pub mod cycle;
pub mod error;
pub mod hooks;
pub mod provider;
pub mod stream;
pub mod types;
+159
View File
@@ -0,0 +1,159 @@
//! 上下文自动压缩 —— 当对话历史过长时自动压缩。
use crate::llm::types::{ContentField, OpenaiChatMessage, OpenaiContentPart};
const AUTOCOMPACT_BUFFER_TOKENS: u32 = 13_000;
const RESERVED_OUTPUT_TOKENS: u32 = 20_000;
const MAX_CONSECUTIVE_FAILURES: u32 = 3;
const KEEP_RECENT: usize = 6;
/// 上下文压缩配置。
#[derive(Debug, Clone)]
pub struct CompactConfig {
/// 模型上下文窗口大小(token 数)。
pub context_window: u32,
/// 为输出预留的 token 数。
pub reserved_tokens: u32,
/// 微压缩保留的最近消息数。
pub keep_recent: usize,
}
impl Default for CompactConfig {
fn default() -> Self {
Self {
context_window: 128_000,
reserved_tokens: RESERVED_OUTPUT_TOKENS,
keep_recent: KEEP_RECENT,
}
}
}
impl CompactConfig {
/// 计算自动压缩触发的阈值。
pub fn threshold(&self) -> u32 {
self.context_window
.saturating_sub(self.reserved_tokens)
.saturating_sub(AUTOCOMPACT_BUFFER_TOKENS)
}
}
/// 压缩状态 —— 跟踪连续失败次数(断路器模式)。
#[derive(Debug, Clone)]
pub struct CompactState {
consecutive_failures: u32,
}
impl Default for CompactState {
fn default() -> Self {
Self::new()
}
}
impl CompactState {
/// 创建一个新的压缩状态。
pub fn new() -> Self {
Self {
consecutive_failures: 0,
}
}
/// 记录一次成功的压缩。
pub fn record_success(&mut self) {
self.consecutive_failures = 0;
}
/// 记录一次压缩失败。
///
/// 返回 `true` 表示已达断路器上限,不再尝试。
pub fn record_failure(&mut self) -> bool {
self.consecutive_failures += 1;
self.consecutive_failures >= MAX_CONSECUTIVE_FAILURES
}
}
/// 粗略估计消息列表的 token 数(基于字符数,4 字符 ≈ 1 token)。
pub fn estimate_message_tokens(messages: &[OpenaiChatMessage]) -> u32 {
messages
.iter()
.map(estimate_single_message_tokens)
.sum()
}
fn estimate_single_message_tokens(msg: &OpenaiChatMessage) -> u32 {
let role_overhead: u32 = 4;
let content_tokens = match msg {
OpenaiChatMessage::Developer { content, .. }
| OpenaiChatMessage::System { content, .. }
| OpenaiChatMessage::User { content, .. }
| OpenaiChatMessage::Assistant { content, .. }
| OpenaiChatMessage::Function { content, .. } => estimate_content_tokens(content),
OpenaiChatMessage::Tool { content, .. } => estimate_content_tokens(content),
};
role_overhead + content_tokens
}
fn estimate_content_tokens(content: &ContentField) -> u32 {
match content {
ContentField::String(s) => estimate_text_tokens(s),
ContentField::Array(parts) => parts.iter().map(estimate_part_tokens).sum(),
}
}
fn estimate_part_tokens(part: &OpenaiContentPart) -> u32 {
match part {
OpenaiContentPart::Text { text } => estimate_text_tokens(text),
_ => 50,
}
}
fn estimate_text_tokens(text: &str) -> u32 {
if text.is_empty() {
return 0;
}
let len = text.len() as u32;
(len * 4).div_ceil(3)
}
/// 判断是否需要触发自动压缩。
pub fn should_compact(
messages: &[OpenaiChatMessage],
config: &CompactConfig,
state: &CompactState,
) -> bool {
if state.consecutive_failures >= MAX_CONSECUTIVE_FAILURES {
return false;
}
let tokens = estimate_message_tokens(messages);
tokens >= config.threshold()
}
/// 执行微压缩 —— 用 `[pruned]` 替换旧的 tool result 内容。
///
/// 这是最便宜的压缩方式,不需要 LLM 调用。
/// 保留最近的 `keep_recent` 条消息不变。
///
/// 返回释放的估算 token 数。
pub fn microcompact(messages: &mut [OpenaiChatMessage], keep_recent: usize) -> u32 {
if messages.len() <= keep_recent {
return 0;
}
let prune_start = messages.len() - keep_recent;
let mut freed_tokens: u32 = 0;
for msg in &messages[..prune_start] {
if matches!(msg, OpenaiChatMessage::Tool { .. }) {
freed_tokens += estimate_single_message_tokens(msg);
}
}
for msg in &mut messages[..prune_start] {
if let OpenaiChatMessage::Tool { content, .. } = msg {
*content = ContentField::Array(vec![OpenaiContentPart::Text {
text: "[pruned]".to_string(),
}]);
}
}
freed_tokens
}
+216 -1
View File
@@ -1,21 +1,38 @@
//! LLM 调用周期控制模块。
mod retry;
pub mod usage;
pub use retry::RetryConfig;
pub use usage::{CostTracker, Usage};
use std::pin::Pin;
use std::sync::Arc;
use futures_core::stream::Stream;
use async_stream::stream;
use crate::llm::compact::{should_compact, microcompact, CompactConfig, CompactState};
use crate::llm::cycle::retry::should_retry;
use crate::llm::error::LlmError;
use crate::llm::hooks::{HookContext, HookExecutor};
use crate::llm::provider::LlmProvider;
use crate::llm::stream::StreamEvent;
use crate::llm::types::{
ChatRequest, ChatResponse, OpenaiChatMessage, OpenaiTool, ToolChoice, ToolDefinition,
};
/// LLM 调用周期配置。
pub struct CycleConfig {
/// 模型名称。
pub model: String,
/// 最大输出 token 数。
pub max_tokens: Option<u32>,
/// 采样温度。
pub temperature: Option<f32>,
/// 最大对话轮次。
pub max_turns: Option<u32>,
/// 重试策略配置。
pub retry: RetryConfig,
}
@@ -31,15 +48,20 @@ impl Default for CycleConfig {
}
}
/// LLM 调用周期 —— 管理一次或多次 LLM 请求的生命周期。
pub struct LlmCycle {
provider: Box<dyn LlmProvider>,
config: CycleConfig,
usage: CostTracker,
messages: Vec<OpenaiChatMessage>,
system_prompt: Option<String>,
hook_executor: Option<Arc<HookExecutor>>,
compact_config: Option<CompactConfig>,
compact_state: CompactState,
}
impl LlmCycle {
/// 创建一个新的 LlmCycle。
pub fn new(provider: Box<dyn LlmProvider>, config: CycleConfig) -> Self {
Self {
provider,
@@ -47,30 +69,51 @@ impl LlmCycle {
usage: CostTracker::default(),
messages: Vec::new(),
system_prompt: None,
hook_executor: None,
compact_config: None,
compact_state: CompactState::new(),
}
}
/// 设置系统提示词。
pub fn with_system_prompt(mut self, prompt: String) -> Self {
self.system_prompt = Some(prompt);
self
}
/// 设置钩子执行器。
pub fn with_hook_executor(mut self, executor: HookExecutor) -> Self {
self.hook_executor = Some(Arc::new(executor));
self
}
/// 设置上下文压缩配置。
pub fn with_compact_config(mut self, config: CompactConfig) -> Self {
self.compact_config = Some(config);
self
}
/// 获取用量追踪器引用。
pub fn usage(&self) -> &CostTracker {
&self.usage
}
/// 获取消息历史引用。
pub fn messages(&self) -> &[OpenaiChatMessage] {
&self.messages
}
/// 清空消息历史。
pub fn clear_messages(&mut self) {
self.messages.clear();
}
/// 重置用量统计。
pub fn reset_usage(&mut self) {
self.usage.reset();
}
/// 提交用户消息并获取 LLM 响应。
pub async fn submit(
&mut self,
prompt: String,
@@ -78,31 +121,203 @@ impl LlmCycle {
) -> Result<ChatResponse, LlmError> {
self.messages.push(OpenaiChatMessage::user_text(prompt));
if let Some(ref config) = self.compact_config
&& should_compact(&self.messages, config, &self.compact_state)
{
let freed = microcompact(&mut self.messages, config.keep_recent);
if freed > 0 {
self.compact_state.record_success();
}
}
let mut attempts = 0;
loop {
let request = self.build_request(&tools);
if let Some(ref executor) = self.hook_executor {
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest)
.with_request(&request);
let results = executor
.execute(crate::llm::hooks::HookEvent::PreRequest, &ctx)
.await;
if results.iter().any(|r| r.should_block) {
let reason = results
.iter()
.find(|r| r.should_block)
.and_then(|r| r.reason.clone())
.unwrap_or_else(|| "Blocked by pre-request hook".to_string());
return Err(LlmError::Other(reason));
}
}
match self.provider.chat(request).await {
Ok(response) => {
self.messages.push(response.message.clone());
if let Some(ref executor) = self.hook_executor {
let post_request = self.build_request(&tools);
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PostRequest)
.with_request(&post_request);
executor
.execute(crate::llm::hooks::HookEvent::PostRequest, &ctx)
.await;
}
self.messages.push(response.message.clone());
self.usage.add(&response.usage);
return Ok(response);
}
Err(e) if should_retry(&e) && attempts < self.config.retry.max_retries => {
attempts += 1;
if let Some(ref executor) = self.hook_executor {
let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnRetry)
.with_error(&e)
.with_attempt(attempts);
executor
.execute(crate::llm::hooks::HookEvent::OnRetry, &ctx)
.await;
}
let delay = self.config.retry.compute_delay(attempts);
tokio::time::sleep(delay).await;
}
Err(e) => {
if let Some(ref executor) = self.hook_executor {
let ctx =
HookContext::new(crate::llm::hooks::HookEvent::OnError).with_error(&e);
executor
.execute(crate::llm::hooks::HookEvent::OnError, &ctx)
.await;
}
return Err(e);
}
}
}
}
/// 提交用户消息并返回语义事件流。
///
/// 与 `submit` 不同,该方法返回流式事件而非完整响应。
/// 适用于需要实时处理 LLM 输出的场景。
pub async fn submit_stream(
&mut self,
prompt: String,
tools: Vec<ToolDefinition>,
) -> Result<Pin<Box<dyn Stream<Item = StreamEvent> + Send>>, LlmError> {
self.messages.push(OpenaiChatMessage::user_text(prompt));
if let Some(ref config) = self.compact_config
&& should_compact(&self.messages, config, &self.compact_state)
{
let freed = microcompact(&mut self.messages, config.keep_recent);
if freed > 0 {
self.compact_state.record_success();
}
}
let request = self.build_request(&tools);
if let Some(ref executor) = self.hook_executor {
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest)
.with_request(&request);
let results = executor
.execute(crate::llm::hooks::HookEvent::PreRequest, &ctx)
.await;
if results.iter().any(|r| r.should_block) {
let reason = results
.iter()
.find(|r| r.should_block)
.and_then(|r| r.reason.clone())
.unwrap_or_else(|| "Blocked by pre-request hook".to_string());
return Err(LlmError::Other(reason));
}
}
let chunk_stream = self.provider.chat_stream(request).await?;
let hook_executor = self.hook_executor.clone();
let post_request = self.build_request(&tools);
Ok(Box::pin(stream! {
use futures_util::StreamExt;
let mut chunk_stream = chunk_stream;
while let Some(result) = chunk_stream.next().await {
match result {
Ok(chunk) => {
let mut assistant_text = String::new();
let mut tool_started: Option<(String, String, String)> = None;
for choice in &chunk.choices {
let delta = &choice.delta;
if let Some(content) = &delta.content {
assistant_text.push_str(content);
}
if let Some(tool_calls) = &delta.tool_calls
&& let Some(tc) = tool_calls.first()
{
let crate::llm::types::OpenaiToolCall::Function { id, function } = tc;
tool_started = Some((id.clone(), function.name.clone(), function.arguments.clone()));
}
}
if !assistant_text.is_empty() {
yield StreamEvent::AssistantTextDelta { text: assistant_text };
}
if let Some((tool_call_id, tool_name, arguments)) = tool_started {
let args: serde_json::Value = serde_json::from_str(&arguments)
.unwrap_or(serde_json::Value::Null);
yield StreamEvent::ToolExecutionStarted {
tool_name,
input: args,
tool_call_id,
};
}
for choice in &chunk.choices {
if let Some(finish_reason) = &choice.finish_reason {
yield StreamEvent::TurnComplete {
reason: *finish_reason,
};
}
}
if let Some(usage_info) = &chunk.usage {
yield StreamEvent::CostUpdate { usage: *usage_info };
}
}
Err(e) => {
if let Some(ref executor) = hook_executor {
let ctx = crate::llm::hooks::HookContext::new(
crate::llm::hooks::HookEvent::OnError,
)
.with_error(&e);
executor
.execute(crate::llm::hooks::HookEvent::OnError, &ctx)
.await;
}
yield StreamEvent::Error { message: e.to_string() };
break;
}
}
}
if let Some(ref executor) = hook_executor {
let ctx = crate::llm::hooks::HookContext::new(
crate::llm::hooks::HookEvent::PostRequest,
)
.with_request(&post_request);
executor
.execute(crate::llm::hooks::HookEvent::PostRequest, &ctx)
.await;
}
}))
}
fn build_request(&self, tools: &[ToolDefinition]) -> ChatRequest {
let mut messages = self.messages.clone();
+128
View File
@@ -0,0 +1,128 @@
//! 生命周期钩子 —— 在 LLM 调用周期的关键节点插入自定义逻辑。
use async_trait::async_trait;
use crate::llm::error::LlmError;
use crate::llm::types::ChatRequest;
/// 生命周期钩子事件点。
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HookEvent {
/// LLM 请求发起之前(可阻断)。
PreRequest,
/// 成功响应之后。
PostRequest,
/// 重试之前(仅可重试错误时触发)。
OnRetry,
/// 不可恢复错误返回之前。
OnError,
}
/// 此次钩子调用的上下文。
#[derive(Debug, Clone)]
pub struct HookContext<'a> {
/// 当前事件点。
pub event: HookEvent,
/// 当前的请求(部分事件点可用)。
pub request: Option<&'a ChatRequest>,
/// 当前错误(仅 OnError 和 OnRetry 可用)。
pub error: Option<&'a LlmError>,
/// 当前重试次数(从 1 开始,仅 OnRetry 可用)。
pub attempt: u32,
}
impl<'a> HookContext<'a> {
pub(crate) fn new(event: HookEvent) -> Self {
Self {
event,
request: None,
error: None,
attempt: 0,
}
}
pub(crate) fn with_request(mut self, request: &'a ChatRequest) -> Self {
self.request = Some(request);
self
}
pub(crate) fn with_error(mut self, error: &'a LlmError) -> Self {
self.error = Some(error);
self
}
pub(crate) fn with_attempt(mut self, attempt: u32) -> Self {
self.attempt = attempt;
self
}
}
/// 钩子执行结果。
#[derive(Debug, Clone)]
pub struct HookResult {
/// 是否阻断后续操作(仅 PreRequest 有效)。
pub should_block: bool,
/// 阻断/备注原因。
pub reason: Option<String>,
}
impl HookResult {
#[allow(dead_code)]
pub(crate) fn allow() -> Self {
Self {
should_block: false,
reason: None,
}
}
#[allow(dead_code)]
pub(crate) fn block(reason: impl Into<String>) -> Self {
Self {
should_block: true,
reason: Some(reason.into()),
}
}
}
/// 生命周期钩子 trait。
#[async_trait]
pub trait Hook: Send + Sync {
/// 执行钩子逻辑。
async fn execute(&self, ctx: &HookContext<'_>) -> HookResult;
}
/// 钩子执行器 —— 管理注册与触发。
pub struct HookExecutor {
hooks: Vec<(HookEvent, Box<dyn Hook>)>,
}
impl Default for HookExecutor {
fn default() -> Self {
Self::new()
}
}
impl HookExecutor {
/// 创建一个空的执行器。
pub fn new() -> Self {
Self {
hooks: Vec::new(),
}
}
/// 注册一个钩子到指定事件点。
pub fn register(&mut self, event: HookEvent, hook: Box<dyn Hook>) {
self.hooks.push((event, hook));
}
/// 执行指定事件点的所有钩子。
pub async fn execute(&self, event: HookEvent, ctx: &HookContext<'_>) -> Vec<HookResult> {
let mut results = Vec::new();
for (e, hook) in &self.hooks {
if *e == event {
results.push(hook.execute(ctx).await);
}
}
results
}
}
+21 -1
View File
@@ -1,7 +1,12 @@
pub mod openai;
pub mod registry;
use std::pin::Pin;
use tokio_stream::Stream;
use crate::llm::error::LlmError;
use crate::llm::types::{ChatRequest, ChatResponse};
use crate::llm::types::{ChatRequest, ChatResponse, OpenaiChatChunk};
use async_trait::async_trait;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -57,4 +62,19 @@ pub fn create_provider(
pub trait LlmProvider: Send + Sync {
/// 发送聊天请求并返回完整响应。
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError>;
/// 流式聊天请求 —— 返回原始 SSE chunk 流。
///
/// 默认实现回退到非流式调用(包装为单元素流)。
async fn chat_stream(
&self,
request: ChatRequest,
) -> Result<
Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
LlmError,
> {
let response = self.chat(request).await?;
let chunk = OpenaiChatChunk::from(response);
Ok(Box::pin(tokio_stream::once(Ok(chunk))))
}
}
+122 -1
View File
@@ -1,12 +1,18 @@
use std::pin::Pin;
use std::time::Duration;
use async_trait::async_trait;
use bytes::Bytes;
use futures_core::stream::Stream;
use futures_util::StreamExt;
use reqwest::Client;
use tracing::{debug, error, info};
use super::LlmProvider;
use crate::llm::error::LlmError;
use crate::llm::types::{ChatRequest, ChatResponse, OpenaiChatResponse};
use crate::llm::types::{
ChatRequest, ChatResponse, OpenaiChatChunk, OpenaiChatResponse, StreamOptions,
};
pub struct OpenaiProvider {
http_client: Client,
@@ -111,4 +117,119 @@ impl LlmProvider for OpenaiProvider {
Ok(ChatResponse::from(chat_response))
}
async fn chat_stream(
&self,
mut request: ChatRequest,
) -> Result<
Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
LlmError,
> {
let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
request.stream = Some(true);
request.stream_options = Some(StreamOptions {
include_usage: Some(true),
include_obfuscation: None,
});
info!(model = %request.model, "发送 LLM 流式请求");
let response = self
.http_client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&request)
.send()
.await
.map_err(|e| {
error!(error = %e, "流式请求失败");
Self::map_reqwest_error(e)
})?;
let status = response.status();
let status_code: u16 = status.as_u16();
if !status.is_success() {
let body_text = response.text().await.unwrap_or_default();
error!(status = status_code, body = %body_text, "流式请求失败");
return Err(LlmError::Request {
status: status_code,
body: body_text,
});
}
let byte_stream = response.bytes_stream().map(|r| {
r.map_err(|e| LlmError::Other(format!("流式读取失败: {}", e)))
});
Ok(Box::pin(SseChunkStream::new(byte_stream)))
}
}
struct SseChunkStream<S> {
inner: S,
buffer: String,
}
impl<S: Stream<Item = Result<Bytes, LlmError>> + Unpin> SseChunkStream<S> {
fn new(stream: S) -> Self {
Self {
inner: stream,
buffer: String::new(),
}
}
}
impl<S: Stream<Item = Result<Bytes, LlmError>> + Unpin> Stream for SseChunkStream<S> {
type Item = Result<OpenaiChatChunk, LlmError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
loop {
if let Some(pos) = self.buffer.find("\n") {
let line = self.buffer.drain(..pos + 1).collect::<String>();
let trimmed = line.trim();
if trimmed.is_empty()
|| trimmed == "data: [DONE]"
|| trimmed == "[DONE]"
|| trimmed == "data:"
{
continue;
}
let data = if let Some(p) = trimmed.strip_prefix("data: ") {
p
} else {
trimmed
};
match serde_json::from_str::<OpenaiChatChunk>(data) {
Ok(chunk) => return std::task::Poll::Ready(Some(Ok(chunk))),
Err(e) => {
return std::task::Poll::Ready(Some(Err(LlmError::Other(format!(
"Chunk 解析失败: {} | raw: {}",
e, data
)))));
}
}
}
match Pin::new(&mut self.inner).poll_next(cx) {
std::task::Poll::Ready(Some(Ok(bytes))) => {
if let Ok(s) = std::str::from_utf8(&bytes) {
self.buffer.push_str(s);
}
}
std::task::Poll::Ready(Some(Err(e))) => {
return std::task::Poll::Ready(Some(Err(e)));
}
std::task::Poll::Ready(None) => {
if self.buffer.is_empty() {
return std::task::Poll::Ready(None);
}
self.buffer.clear();
return std::task::Poll::Ready(None);
}
std::task::Poll::Pending => return std::task::Poll::Pending,
}
}
}
}
+68
View File
@@ -0,0 +1,68 @@
//! Provider 注册表 —— 多 Provider 实例的注册与发现。
use std::collections::HashMap;
use crate::llm::error::LlmError;
use crate::llm::provider::{create_provider, LlmProvider, ProviderConfig, ProviderType};
/// Provider 注册表 —— 管理多个 LLM Provider 实例。
///
/// 支持注册命名 Provider、按名称查找、设置默认 Provider。
pub struct ProviderRegistry {
providers: HashMap<String, Box<dyn LlmProvider>>,
default_name: Option<String>,
}
impl Default for ProviderRegistry {
fn default() -> Self {
Self::new()
}
}
impl ProviderRegistry {
/// 创建一个空的注册表。
pub fn new() -> Self {
Self {
providers: HashMap::new(),
default_name: None,
}
}
/// 注册一个已初始化的 Provider 实例。
pub fn register(&mut self, name: impl Into<String>, provider: Box<dyn LlmProvider>) {
self.providers.insert(name.into(), provider);
}
/// 通过 ProviderType + ProviderConfig 创建并注册。
pub fn register_with_config(
&mut self,
name: impl Into<String>,
provider_type: ProviderType,
config: ProviderConfig,
) -> Result<(), LlmError> {
let provider = create_provider(provider_type, config)?;
self.register(name, provider);
Ok(())
}
/// 设置默认 Provider。
pub fn set_default(&mut self, name: &str) -> Result<(), LlmError> {
if !self.providers.contains_key(name) {
return Err(LlmError::Other(format!("Provider '{}' 不存在", name)));
}
self.default_name = Some(name.to_string());
Ok(())
}
/// 按名称查找 Provider。
pub fn get(&self, name: &str) -> Option<&dyn LlmProvider> {
self.providers.get(name).map(|p| p.as_ref())
}
/// 获取默认 Provider。
pub fn get_default(&self) -> Option<&dyn LlmProvider> {
self.default_name
.as_ref()
.and_then(|name| self.get(name))
}
}
+108
View File
@@ -0,0 +1,108 @@
//! 流式事件系统 —— 将 LLM 流式响应解析为语义化事件。
use std::pin::Pin;
use std::task::{Context, Poll};
use futures_core::stream::Stream;
use futures_util::future::poll_fn;
use futures_util::FutureExt;
use serde_json::Value;
use crate::llm::error::LlmError;
use crate::llm::types::{FinishReason, OpenaiChatChunk, OpenaiToolCall, Usage};
/// 流式事件 —— LLM 调用全生命周期的语义化事件。
#[derive(Debug, Clone)]
pub enum StreamEvent {
/// 助手回复文本增量。
AssistantTextDelta { text: String },
/// 工具调用开始。
ToolExecutionStarted {
tool_name: String,
input: Value,
tool_call_id: String,
},
/// 工具调用完成。
ToolExecutionCompleted {
tool_name: String,
output: Value,
is_error: bool,
},
/// Token 用量更新。
CostUpdate { usage: Usage },
/// 一轮会话完成。
TurnComplete { reason: FinishReason },
/// 错误事件。
Error { message: String },
}
impl StreamEvent {
fn error(message: impl Into<String>) -> Self {
Self::Error {
message: message.into(),
}
}
}
/// 将原始 OpenaiChatChunk 流解析为 StreamEvent 流。
pub fn parse_chunk_stream(
chunks: Pin<Box<dyn futures_core::Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
) -> Pin<Box<dyn futures_core::Stream<Item = StreamEvent> + Send>> {
Box::pin(ChunkToEventStream { chunks })
}
struct ChunkToEventStream {
chunks: Pin<Box<dyn futures_core::Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
}
impl Stream for ChunkToEventStream {
type Item = StreamEvent;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = &mut *self;
poll_fn(|cx| match Pin::new(&mut this.chunks).poll_next(cx) {
Poll::Ready(Some(Ok(chunk))) => {
for choice in &chunk.choices {
let delta = &choice.delta;
if let Some(content) = &delta.content {
return Poll::Ready(Some(StreamEvent::AssistantTextDelta {
text: content.clone(),
}));
}
if let Some(tool_calls) = &delta.tool_calls
&& let Some(tc) = tool_calls.first()
{
let OpenaiToolCall::Function { id, function } = tc;
let args: Value =
serde_json::from_str(&function.arguments).unwrap_or(Value::Null);
return Poll::Ready(Some(StreamEvent::ToolExecutionStarted {
tool_name: function.name.clone(),
input: args,
tool_call_id: id.clone(),
}));
}
if let Some(finish_reason) = &choice.finish_reason {
return Poll::Ready(Some(StreamEvent::TurnComplete {
reason: *finish_reason,
}));
}
}
if let Some(usage) = &chunk.usage {
return Poll::Ready(Some(StreamEvent::CostUpdate {
usage: *usage,
}));
}
Poll::Ready(None)
}
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(StreamEvent::error(e.to_string()))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
})
.poll_unpin(cx)
}
}
+28
View File
@@ -43,6 +43,34 @@ impl From<OpenaiChatResponse> for ChatResponse {
}
}
impl From<ChatResponse> for OpenaiChatChunk {
fn from(response: ChatResponse) -> Self {
let delta = Delta::from(response.message.clone());
let chunk_choice = ChunkChoice {
index: 0,
delta,
logprobs: None,
finish_reason: response.stop_reason,
};
OpenaiChatChunk {
id: format!("chunk-{}", std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_nanos())
.unwrap_or(0)),
object: "chat.completion.chunk".to_string(),
created: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0),
model: String::new(),
choices: vec![chunk_choice],
usage: Some(response.usage),
system_fingerprint: None,
}
}
}
pub type ChatRequest = OpenaiChatRequest;
pub type Message = OpenaiChatMessage;
pub type ContentBlock = OpenaiContentPart;
+64
View File
@@ -3,6 +3,7 @@ use crate::llm::types::shared::{FinishReason, ServiceTier};
use crate::llm::types::tool::OpenaiToolCall;
use crate::llm::types::usage::Usage;
use serde::{Deserialize, Serialize};
use crate::llm::types::{ContentField, OpenaiContentPart};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenLogprob {
@@ -115,3 +116,66 @@ pub struct OpenaiChatChunk {
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
}
impl From<OpenaiChatMessage> for Delta {
fn from(msg: OpenaiChatMessage) -> Self {
match msg {
OpenaiChatMessage::Assistant {
content,
tool_calls,
..
} => Delta {
role: Some("assistant".to_string()),
content: match content {
ContentField::String(s) => Some(s),
ContentField::Array(parts) => {
let mut text = String::new();
for part in parts {
if let OpenaiContentPart::Text { text: t } = part {
text.push_str(&t);
}
}
if text.is_empty() {
None
} else {
Some(text)
}
}
},
refusal: None,
tool_calls,
},
_ => Delta {
role: None,
content: None,
refusal: None,
tool_calls: None,
},
}
}
}
impl From<OpenaiChatResponse> for OpenaiChatChunk {
fn from(response: OpenaiChatResponse) -> Self {
let choices = response
.choices
.into_iter()
.map(|c| ChunkChoice {
index: c.index,
delta: Delta::from(c.message),
logprobs: c.logprobs,
finish_reason: c.finish_reason,
})
.collect();
OpenaiChatChunk {
id: response.id,
object: "chat.completion.chunk".to_string(),
created: response.created,
model: response.model,
choices,
usage: Some(response.usage),
system_fingerprint: response.system_fingerprint,
}
}
}