feat(llm): 实现 Phase 0 剩余四个模块
实现 ProviderRegistry、HookExecutor、StreamEvents 和 Auto-compaction 模块,并集成到 LlmCycle 中
This commit is contained in:
+7
-1
@@ -5,13 +5,19 @@ edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
reqwest = { version = "0.12", features = ["json"] }
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
thiserror = "2"
|
||||
async-trait = "0.1"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
tokio-stream = "0.1"
|
||||
futures = "0.3"
|
||||
futures-util = "0.3"
|
||||
futures-core = "0.3"
|
||||
bytes = "1"
|
||||
async-stream = "0.3"
|
||||
|
||||
[dev-dependencies]
|
||||
dotenvy = "0.15.7"
|
||||
|
||||
@@ -0,0 +1,375 @@
|
||||
# Phase 0 剩余模块 — 实施方案
|
||||
|
||||
> 定稿日期:2026-06-02
|
||||
|
||||
## 背景与目标
|
||||
|
||||
AG Core Phase 0(Foundation)已完成核心数据类型、错误体系、Provider 接口、OpenAI 实现、生命周期引擎等基础设施。剩余 4 个子项尚未实现:**ProviderRegistry**、**HookExecutor**、**StreamEvents**、**Auto-compaction**。它们均被 Roadmap 标记为 P0(Must Have),是本阶段不可或缺的底层能力。
|
||||
|
||||
**目标**:完成这 4 个模块的设计与实现,使 Phase 0 全面交付。
|
||||
|
||||
---
|
||||
|
||||
## 需求分析
|
||||
|
||||
### 功能需求
|
||||
|
||||
| 模块 | 需求 | 验收条件 |
|
||||
|------|------|---------|
|
||||
| ProviderRegistry | 支持注册命名 Provider、按名称查找、设置默认 | 3 个公开方法 + 工厂辅助 |
|
||||
| HookExecutor | 4 个事件点:PreRequest / PostRequest / OnRetry / OnError | Hook trait + HookExecutor 触发 |
|
||||
| StreamEvents | 流式事件枚举 + Provider 流式方法 + Cycle 流式入口 | 6 种事件类型 + SSE 解析 |
|
||||
| Auto-compaction | Token 估算 + 微压缩 + 断路器 | 触发后释放 token 且不改变语义 |
|
||||
|
||||
### 非功能需求
|
||||
|
||||
- 所有公开 API 必须带 `///` 文档注释
|
||||
- 无新增 `unwrap()` 调用
|
||||
- 与现有 `LlmCycle` 集成时保持向后兼容(全为可选/增量添加)
|
||||
- 错误统一使用 `LlmError` 枚举
|
||||
|
||||
---
|
||||
|
||||
## 方案设计
|
||||
|
||||
### 1. ProviderRegistry (`src/llm/provider/registry.rs`)
|
||||
|
||||
**职责**:管理多个 LLM Provider 实例,支持按名称注册、发现、切换。
|
||||
|
||||
```rust
|
||||
// src/llm/provider/registry.rs
|
||||
|
||||
use std::collections::HashMap;
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::provider::{LlmProvider, ProviderConfig, ProviderType, create_provider};
|
||||
|
||||
/// Provider 注册表——管理多个 LLM Provider 实例。
|
||||
pub struct ProviderRegistry {
|
||||
providers: HashMap<String, Box<dyn LlmProvider>>,
|
||||
default_name: Option<String>,
|
||||
}
|
||||
|
||||
impl ProviderRegistry {
|
||||
pub fn new() -> Self;
|
||||
|
||||
/// 注册一个已初始化的 Provider 实例。
|
||||
pub fn register(&mut self, name: impl Into<String>, provider: Box<dyn LlmProvider>);
|
||||
|
||||
/// 通过 ProviderType + ProviderConfig 创建并注册。
|
||||
pub fn register_with_config(
|
||||
&mut self,
|
||||
name: impl Into<String>,
|
||||
provider_type: ProviderType,
|
||||
config: ProviderConfig,
|
||||
) -> Result<(), LlmError>;
|
||||
|
||||
/// 设置默认 Provider。
|
||||
pub fn set_default(&mut self, name: &str) -> Result<(), LlmError>;
|
||||
|
||||
/// 按名称查找 Provider。
|
||||
pub fn get(&self, name: &str) -> Option<&dyn LlmProvider>;
|
||||
|
||||
/// 获取默认 Provider。
|
||||
pub fn get_default(&self) -> Option<&dyn LlmProvider>;
|
||||
}
|
||||
```
|
||||
|
||||
**无新增依赖**。
|
||||
|
||||
---
|
||||
|
||||
### 2. HookExecutor (`src/llm/hooks.rs`)
|
||||
|
||||
**职责**:在 LLM 调用生命周期的关键节点插入自定义逻辑。
|
||||
|
||||
```rust
|
||||
// src/llm/hooks.rs
|
||||
|
||||
use async_trait::async_trait;
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::types::ChatRequest;
|
||||
|
||||
/// 生命周期钩子事件点。
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum HookEvent {
|
||||
/// LLM 请求发起之前(可阻断)。
|
||||
PreRequest,
|
||||
/// 成功响应之后。
|
||||
PostRequest,
|
||||
/// 重试之前(仅可重试错误时触发)。
|
||||
OnRetry,
|
||||
/// 不可恢复错误返回之前。
|
||||
OnError,
|
||||
}
|
||||
|
||||
/// 此次钩子调用的上下文。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HookContext<'a> {
|
||||
pub event: HookEvent,
|
||||
pub request: Option<&'a ChatRequest>,
|
||||
pub error: Option<&'a LlmError>,
|
||||
pub attempt: u32,
|
||||
}
|
||||
|
||||
/// 钩子执行结果。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HookResult {
|
||||
/// 是否阻断后续操作(仅 PreRequest 有效)。
|
||||
pub should_block: bool,
|
||||
/// 阻断/备注原因。
|
||||
pub reason: Option<String>,
|
||||
}
|
||||
|
||||
/// 生命周期钩子 trait。
|
||||
#[async_trait]
|
||||
pub trait Hook: Send + Sync {
|
||||
async fn execute(&self, ctx: &HookContext<'_>) -> HookResult;
|
||||
}
|
||||
|
||||
/// 钩子执行器——管理注册与触发。
|
||||
pub struct HookExecutor {
|
||||
hooks: Vec<(HookEvent, Box<dyn Hook>)>,
|
||||
}
|
||||
|
||||
impl HookExecutor {
|
||||
pub fn new() -> Self;
|
||||
pub fn register(&mut self, event: HookEvent, hook: Box<dyn Hook>);
|
||||
pub async fn execute(&self, event: HookEvent, ctx: &HookContext<'_>) -> Vec<HookResult>;
|
||||
}
|
||||
```
|
||||
|
||||
**与 LlmCycle 集成**:
|
||||
- `LlmCycle` 新增字段 `hook_executor: Option<Arc<HookExecutor>>`(使用 `Arc`,以便 `submit_stream()` 返回的事件流也能持有执行器)
|
||||
- 新增 builder 方法 `with_hook_executor()`
|
||||
- `submit()` 中在 4 个点触发(PreRequest 若阻断则提前返回)
|
||||
|
||||
**无新增依赖**(async-trait 已存在)。
|
||||
|
||||
---
|
||||
|
||||
### 3. StreamEvents (`src/llm/stream.rs`)
|
||||
|
||||
**职责**:提供流式 LLM 调用的事件抽象,将原始 SSE chunk 解析为语义化事件。
|
||||
|
||||
```rust
|
||||
// src/llm/stream.rs
|
||||
|
||||
use std::pin::Pin;
|
||||
use tokio_stream::Stream;
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::types::{FinishReason, Usage, OpenaiChatChunk, ToolDefinition};
|
||||
use serde_json::Value;
|
||||
|
||||
/// 流式事件——LLM 调用全生命周期的语义化事件。
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum StreamEvent {
|
||||
/// 助手回复文本增量。
|
||||
AssistantTextDelta { text: String },
|
||||
/// 工具调用开始。
|
||||
ToolExecutionStarted { tool_name: String, input: Value, tool_call_id: String },
|
||||
/// 工具调用完成。
|
||||
ToolExecutionCompleted { tool_name: String, output: Value, is_error: bool },
|
||||
/// Token 用量更新。
|
||||
CostUpdate { usage: Usage },
|
||||
/// 一轮会话完成。
|
||||
TurnComplete { reason: FinishReason },
|
||||
/// 可恢复的错误事件。
|
||||
Error { message: String },
|
||||
}
|
||||
|
||||
/// 将原始 OpenaiChatChunk 流解析为 StreamEvent 流。
|
||||
pub fn parse_chunk_stream(
|
||||
chunks: Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||
) -> Pin<Box<dyn Stream<Item = StreamEvent> + Send>>;
|
||||
```
|
||||
|
||||
**Provider 层扩展**(`src/llm/provider.rs`):
|
||||
```rust
|
||||
#[async_trait]
|
||||
pub trait LlmProvider: Send + Sync {
|
||||
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError>;
|
||||
|
||||
/// 流式聊天请求——返回原始 SSE chunk 流。
|
||||
/// 默认实现回退到非流式调用。
|
||||
async fn chat_stream(
|
||||
&self,
|
||||
request: ChatRequest,
|
||||
) -> Result<
|
||||
Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||
LlmError,
|
||||
> {
|
||||
let response = self.chat(request).await?;
|
||||
let chunk = OpenaiChatChunk::from(response);
|
||||
Ok(Box::pin(tokio_stream::once(Ok(chunk))))
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**LlmCycle 扩展**:
|
||||
```rust
|
||||
impl LlmCycle {
|
||||
/// 提交请求并返回语义事件流。
|
||||
pub async fn submit_stream(
|
||||
&mut self,
|
||||
prompt: String,
|
||||
tools: Vec<ToolDefinition>,
|
||||
) -> Result<
|
||||
Pin<Box<dyn Stream<Item = StreamEvent> + Send>>,
|
||||
LlmError,
|
||||
>;
|
||||
}
|
||||
```
|
||||
|
||||
**新增依赖**:
|
||||
```toml
|
||||
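tokio-stream = "0.1"
futures = "0.3"
futures-util = "0.3"
futures-core = "0.3"
bytes = "1"
async-stream = "0.3"
# 另需为 reqwest 启用 "stream" 特性(`bytes_stream()` 依赖该特性)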
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 4. Auto-compaction (`src/llm/compact.rs`)
|
||||
|
||||
**职责**:在上下文过长时自动压缩历史消息,避免 ContextLength 错误。
|
||||
|
||||
```rust
|
||||
// src/llm/compact.rs
|
||||
|
||||
use crate::llm::types::{ContentField, OpenaiChatMessage, OpenaiContentPart};
|
||||
|
||||
// === 常量 ===
|
||||
const AUTOCOMPACT_BUFFER_TOKENS: u32 = 13_000;
|
||||
const RESERVED_OUTPUT_TOKENS: u32 = 20_000;
|
||||
const MAX_CONSECUTIVE_FAILURES: u32 = 3;
|
||||
const KEEP_RECENT: usize = 6;
|
||||
|
||||
/// 上下文压缩配置。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CompactConfig {
|
||||
/// 模型上下文窗口大小(token 数)。
|
||||
pub context_window: u32,
|
||||
/// 为输出预留的 token 数。
|
||||
pub reserved_tokens: u32,
|
||||
/// 微压缩保留的最近消息数。
|
||||
pub keep_recent: usize,
|
||||
}
|
||||
|
||||
impl Default for CompactConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
context_window: 128_000,
|
||||
reserved_tokens: RESERVED_OUTPUT_TOKENS,
|
||||
keep_recent: KEEP_RECENT,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CompactConfig {
|
||||
pub fn threshold(&self) -> u32 {
|
||||
self.context_window
|
||||
.saturating_sub(self.reserved_tokens)
|
||||
.saturating_sub(AUTOCOMPACT_BUFFER_TOKENS)
|
||||
}
|
||||
}
|
||||
|
||||
/// 压缩状态——跟踪连续失败次数(断路器模式)。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CompactState {
|
||||
consecutive_failures: u32,
|
||||
}
|
||||
|
||||
impl CompactState {
|
||||
pub fn new() -> Self;
|
||||
pub fn record_success(&mut self);
|
||||
/// 记录失败,返回 true 表示已达断路器上限。
|
||||
pub fn record_failure(&mut self) -> bool;
|
||||
}
|
||||
|
||||
/// 粗略估计消息列表的 token 数(基于字符数,4 字符 ≈ 1 token)。
|
||||
pub fn estimate_message_tokens(messages: &[OpenaiChatMessage]) -> u32;
|
||||
|
||||
/// 判断是否需要触发自动压缩。
|
||||
pub fn should_compact(messages: &[OpenaiChatMessage], config: &CompactConfig, state: &CompactState) -> bool;
|
||||
|
||||
/// 执行微压缩——用占位符替换旧的 tool result 内容。
|
||||
/// 返回释放的 token 数。
|
||||
pub fn microcompact(messages: &mut [OpenaiChatMessage], keep_recent: usize) -> u32;
|
||||
```
|
||||
|
||||
**与 LlmCycle 集成**:
|
||||
- `LlmCycle` 新增字段 `compact_config: Option<CompactConfig>`, `compact_state: CompactState`
|
||||
- 新增 builder 方法 `with_compact_config()`
|
||||
- `submit()` 开始时调用 `should_compact()` → `microcompact()`
|
||||
|
||||
**完整 LLM 摘要压缩**留占位,Phase 2 实现(需要循环内调用 LLM 的能力)。
|
||||
|
||||
**无新增依赖**。
|
||||
|
||||
---
|
||||
|
||||
## 实现计划
|
||||
|
||||
### Step 1: 先写方案文档
|
||||
|
||||
创建 `docs/3-phase0-remaining.md`(即本文档)。
|
||||
|
||||
### Step 2: ProviderRegistry
|
||||
|
||||
- 创建 `src/llm/provider/registry.rs`
|
||||
- `provider.rs` 添加 `pub mod registry;`
|
||||
- `cargo check` 验证
|
||||
|
||||
### Step 3: HookExecutor
|
||||
|
||||
- 创建 `src/llm/hooks.rs`
|
||||
- `llm.rs` 添加 `pub mod hooks;`
|
||||
- `LlmCycle` 新增字段和方法
|
||||
- `submit()` 中插入钩子触发点
|
||||
- `cargo check` 验证
|
||||
|
||||
### Step 4: StreamEvents
|
||||
|
||||
- `Cargo.toml` 添加 `tokio-stream`、`futures`、`futures-util`、`futures-core`、`bytes`、`async-stream`,并为 `reqwest` 启用 `stream` 特性
|
||||
- 创建 `src/llm/stream.rs`
|
||||
- `llm.rs` 添加 `pub mod stream;`
|
||||
- `LlmProvider` 添加 `chat_stream()`(默认回退)
|
||||
- `OpenaiProvider` 实现 SSE 解析
|
||||
- `LlmCycle` 添加 `submit_stream()`
|
||||
- `cargo check` 验证
|
||||
|
||||
### Step 5: Auto-compaction
|
||||
|
||||
- 创建 `src/llm/compact.rs`
|
||||
- `llm.rs` 添加 `pub mod compact;`
|
||||
- `LlmCycle` 新增字段和方法
|
||||
- `submit()` 开头插入压缩检查
|
||||
- `cargo check` 验证
|
||||
|
||||
### Step 6: 收尾
|
||||
|
||||
- `cargo clippy` — 无警告
|
||||
- `cargo build --release` — 完整构建
|
||||
- 检查所有新公开 API 有 `///` 注释
|
||||
|
||||
---
|
||||
|
||||
## 风险评估
|
||||
|
||||
| 风险 | 概率 | 影响 | 缓解措施 |
|
||||
|------|------|------|---------|
|
||||
| SSE 解析边界情况 | 中 | 中 | 参考 `reqwest` 的 chunked 响应处理;先用简单的 `lines()` 方式逐行读取 |
|
||||
| token 估算不准 | 中 | 低 | 仅用于触发微压缩的阈值判断,保守估算即可;后续可接入 tiktoken-rs |
|
||||
| 钩子阻断语义复杂 | 低 | 中 | PreRequest 阻断后返回明确错误消息;其他事件点只读不阻断 |
|
||||
| 与后续 Phase 冲突 | 低 | 高 | 保持接口向后兼容,全用可选集成(Option) |
|
||||
|
||||
---
|
||||
|
||||
## 验收标准
|
||||
|
||||
1. `cargo check` 编译通过
|
||||
2. `cargo clippy` 无警告
|
||||
3. 4 个模块文件存在且路径正确
|
||||
4. `ProviderRegistry` 支持注册/查找/默认 Provider
|
||||
5. `HookExecutor` 支持 4 个事件点注册与触发
|
||||
6. `StreamEvents` 支持从 SSE chunk 解析为语义事件
|
||||
7. `Auto-compaction` 支持 token 估算与微压缩
|
||||
8. `LlmCycle` 向后兼容(新增字段为 `Option`,或如 `compact_state` 一样带默认初始值,不影响现有代码)
|
||||
@@ -1,7 +1,4 @@
|
||||
//! agcore —— 智能体(Agent)核心工具箱。
|
||||
//!
|
||||
//! 当前提供 LLM 调用周期控制作为核心底层能力,后续将扩展至
|
||||
//! 提示词工程、记忆系统、工具调用、Agent 运行时等领域。
|
||||
|
||||
pub mod llm;
|
||||
|
||||
|
||||
+3
-2
@@ -1,8 +1,9 @@
|
||||
//! LLM 调用周期 —— 大模型基础调用周期控制。
|
||||
//!
|
||||
//! 包含核心数据类型、Provider 抽象、OpenAI 兼容实现以及生命周期引擎。
|
||||
|
||||
pub mod compact;
|
||||
pub mod cycle;
|
||||
pub mod error;
|
||||
pub mod hooks;
|
||||
pub mod provider;
|
||||
pub mod stream;
|
||||
pub mod types;
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
//! 上下文自动压缩 —— 当对话历史过长时自动压缩。
|
||||
|
||||
use crate::llm::types::{ContentField, OpenaiChatMessage, OpenaiContentPart};
|
||||
|
||||
const AUTOCOMPACT_BUFFER_TOKENS: u32 = 13_000;
|
||||
const RESERVED_OUTPUT_TOKENS: u32 = 20_000;
|
||||
const MAX_CONSECUTIVE_FAILURES: u32 = 3;
|
||||
const KEEP_RECENT: usize = 6;
|
||||
|
||||
/// 上下文压缩配置。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CompactConfig {
|
||||
/// 模型上下文窗口大小(token 数)。
|
||||
pub context_window: u32,
|
||||
/// 为输出预留的 token 数。
|
||||
pub reserved_tokens: u32,
|
||||
/// 微压缩保留的最近消息数。
|
||||
pub keep_recent: usize,
|
||||
}
|
||||
|
||||
impl Default for CompactConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
context_window: 128_000,
|
||||
reserved_tokens: RESERVED_OUTPUT_TOKENS,
|
||||
keep_recent: KEEP_RECENT,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CompactConfig {
|
||||
/// 计算自动压缩触发的阈值。
|
||||
pub fn threshold(&self) -> u32 {
|
||||
self.context_window
|
||||
.saturating_sub(self.reserved_tokens)
|
||||
.saturating_sub(AUTOCOMPACT_BUFFER_TOKENS)
|
||||
}
|
||||
}
|
||||
|
||||
/// 压缩状态 —— 跟踪连续失败次数(断路器模式)。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CompactState {
|
||||
consecutive_failures: u32,
|
||||
}
|
||||
|
||||
impl Default for CompactState {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl CompactState {
|
||||
/// 创建一个新的压缩状态。
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
consecutive_failures: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// 记录一次成功的压缩。
|
||||
pub fn record_success(&mut self) {
|
||||
self.consecutive_failures = 0;
|
||||
}
|
||||
|
||||
/// 记录一次压缩失败。
|
||||
///
|
||||
/// 返回 `true` 表示已达断路器上限,不再尝试。
|
||||
pub fn record_failure(&mut self) -> bool {
|
||||
self.consecutive_failures += 1;
|
||||
self.consecutive_failures >= MAX_CONSECUTIVE_FAILURES
|
||||
}
|
||||
}
|
||||
|
||||
/// 粗略估计消息列表的 token 数(基于字符数,4 字符 ≈ 1 token)。
|
||||
pub fn estimate_message_tokens(messages: &[OpenaiChatMessage]) -> u32 {
|
||||
messages
|
||||
.iter()
|
||||
.map(estimate_single_message_tokens)
|
||||
.sum()
|
||||
}
|
||||
|
||||
fn estimate_single_message_tokens(msg: &OpenaiChatMessage) -> u32 {
|
||||
let role_overhead: u32 = 4;
|
||||
let content_tokens = match msg {
|
||||
OpenaiChatMessage::Developer { content, .. }
|
||||
| OpenaiChatMessage::System { content, .. }
|
||||
| OpenaiChatMessage::User { content, .. }
|
||||
| OpenaiChatMessage::Assistant { content, .. }
|
||||
| OpenaiChatMessage::Function { content, .. } => estimate_content_tokens(content),
|
||||
OpenaiChatMessage::Tool { content, .. } => estimate_content_tokens(content),
|
||||
};
|
||||
role_overhead + content_tokens
|
||||
}
|
||||
|
||||
fn estimate_content_tokens(content: &ContentField) -> u32 {
|
||||
match content {
|
||||
ContentField::String(s) => estimate_text_tokens(s),
|
||||
ContentField::Array(parts) => parts.iter().map(estimate_part_tokens).sum(),
|
||||
}
|
||||
}
|
||||
|
||||
fn estimate_part_tokens(part: &OpenaiContentPart) -> u32 {
|
||||
match part {
|
||||
OpenaiContentPart::Text { text } => estimate_text_tokens(text),
|
||||
_ => 50,
|
||||
}
|
||||
}
|
||||
|
||||
/// Roughly estimates the token count of a piece of text.
///
/// Uses the documented heuristic of about 4 units of text per token, measured
/// here in UTF-8 bytes and rounded up, so any non-empty text costs at least
/// one token. The previous formula (`len * 4 / 3`) inverted that ratio and
/// overestimated by roughly a factor of five, which made auto-compaction fire
/// far too early; it could also overflow `u32` on very large inputs.
///
/// Returns 0 for empty text and saturates at `u32::MAX`.
fn estimate_text_tokens(text: &str) -> u32 {
    const BYTES_PER_TOKEN: usize = 4;

    // Do the arithmetic in `usize` and convert once, so a huge input saturates
    // instead of silently truncating.
    let tokens = text.len().div_ceil(BYTES_PER_TOKEN);
    u32::try_from(tokens).unwrap_or(u32::MAX)
}
|
||||
|
||||
/// 判断是否需要触发自动压缩。
|
||||
pub fn should_compact(
|
||||
messages: &[OpenaiChatMessage],
|
||||
config: &CompactConfig,
|
||||
state: &CompactState,
|
||||
) -> bool {
|
||||
if state.consecutive_failures >= MAX_CONSECUTIVE_FAILURES {
|
||||
return false;
|
||||
}
|
||||
let tokens = estimate_message_tokens(messages);
|
||||
tokens >= config.threshold()
|
||||
}
|
||||
|
||||
/// 执行微压缩 —— 用 `[pruned]` 替换旧的 tool result 内容。
|
||||
///
|
||||
/// 这是最便宜的压缩方式,不需要 LLM 调用。
|
||||
/// 保留最近的 `keep_recent` 条消息不变。
|
||||
///
|
||||
/// 返回释放的估算 token 数。
|
||||
pub fn microcompact(messages: &mut [OpenaiChatMessage], keep_recent: usize) -> u32 {
|
||||
if messages.len() <= keep_recent {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let prune_start = messages.len() - keep_recent;
|
||||
let mut freed_tokens: u32 = 0;
|
||||
|
||||
for msg in &messages[..prune_start] {
|
||||
if matches!(msg, OpenaiChatMessage::Tool { .. }) {
|
||||
freed_tokens += estimate_single_message_tokens(msg);
|
||||
}
|
||||
}
|
||||
|
||||
for msg in &mut messages[..prune_start] {
|
||||
if let OpenaiChatMessage::Tool { content, .. } = msg {
|
||||
*content = ContentField::Array(vec![OpenaiContentPart::Text {
|
||||
text: "[pruned]".to_string(),
|
||||
}]);
|
||||
}
|
||||
}
|
||||
|
||||
freed_tokens
|
||||
}
|
||||
+216
-1
@@ -1,21 +1,38 @@
|
||||
//! LLM 调用周期控制模块。
|
||||
|
||||
mod retry;
|
||||
pub mod usage;
|
||||
|
||||
pub use retry::RetryConfig;
|
||||
pub use usage::{CostTracker, Usage};
|
||||
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures_core::stream::Stream;
|
||||
use async_stream::stream;
|
||||
|
||||
use crate::llm::compact::{should_compact, microcompact, CompactConfig, CompactState};
|
||||
use crate::llm::cycle::retry::should_retry;
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::hooks::{HookContext, HookExecutor};
|
||||
use crate::llm::provider::LlmProvider;
|
||||
use crate::llm::stream::StreamEvent;
|
||||
use crate::llm::types::{
|
||||
ChatRequest, ChatResponse, OpenaiChatMessage, OpenaiTool, ToolChoice, ToolDefinition,
|
||||
};
|
||||
|
||||
/// LLM 调用周期配置。
|
||||
pub struct CycleConfig {
|
||||
/// 模型名称。
|
||||
pub model: String,
|
||||
/// 最大输出 token 数。
|
||||
pub max_tokens: Option<u32>,
|
||||
/// 采样温度。
|
||||
pub temperature: Option<f32>,
|
||||
/// 最大对话轮次。
|
||||
pub max_turns: Option<u32>,
|
||||
/// 重试策略配置。
|
||||
pub retry: RetryConfig,
|
||||
}
|
||||
|
||||
@@ -31,15 +48,20 @@ impl Default for CycleConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// LLM 调用周期 —— 管理一次或多次 LLM 请求的生命周期。
|
||||
pub struct LlmCycle {
|
||||
provider: Box<dyn LlmProvider>,
|
||||
config: CycleConfig,
|
||||
usage: CostTracker,
|
||||
messages: Vec<OpenaiChatMessage>,
|
||||
system_prompt: Option<String>,
|
||||
hook_executor: Option<Arc<HookExecutor>>,
|
||||
compact_config: Option<CompactConfig>,
|
||||
compact_state: CompactState,
|
||||
}
|
||||
|
||||
impl LlmCycle {
|
||||
/// 创建一个新的 LlmCycle。
|
||||
pub fn new(provider: Box<dyn LlmProvider>, config: CycleConfig) -> Self {
|
||||
Self {
|
||||
provider,
|
||||
@@ -47,30 +69,51 @@ impl LlmCycle {
|
||||
usage: CostTracker::default(),
|
||||
messages: Vec::new(),
|
||||
system_prompt: None,
|
||||
hook_executor: None,
|
||||
compact_config: None,
|
||||
compact_state: CompactState::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 设置系统提示词。
|
||||
pub fn with_system_prompt(mut self, prompt: String) -> Self {
|
||||
self.system_prompt = Some(prompt);
|
||||
self
|
||||
}
|
||||
|
||||
/// 设置钩子执行器。
|
||||
pub fn with_hook_executor(mut self, executor: HookExecutor) -> Self {
|
||||
self.hook_executor = Some(Arc::new(executor));
|
||||
self
|
||||
}
|
||||
|
||||
/// 设置上下文压缩配置。
|
||||
pub fn with_compact_config(mut self, config: CompactConfig) -> Self {
|
||||
self.compact_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
/// 获取用量追踪器引用。
|
||||
pub fn usage(&self) -> &CostTracker {
|
||||
&self.usage
|
||||
}
|
||||
|
||||
/// 获取消息历史引用。
|
||||
pub fn messages(&self) -> &[OpenaiChatMessage] {
|
||||
&self.messages
|
||||
}
|
||||
|
||||
/// 清空消息历史。
|
||||
pub fn clear_messages(&mut self) {
|
||||
self.messages.clear();
|
||||
}
|
||||
|
||||
/// 重置用量统计。
|
||||
pub fn reset_usage(&mut self) {
|
||||
self.usage.reset();
|
||||
}
|
||||
|
||||
/// 提交用户消息并获取 LLM 响应。
|
||||
pub async fn submit(
|
||||
&mut self,
|
||||
prompt: String,
|
||||
@@ -78,31 +121,203 @@ impl LlmCycle {
|
||||
) -> Result<ChatResponse, LlmError> {
|
||||
self.messages.push(OpenaiChatMessage::user_text(prompt));
|
||||
|
||||
if let Some(ref config) = self.compact_config
|
||||
&& should_compact(&self.messages, config, &self.compact_state)
|
||||
{
|
||||
let freed = microcompact(&mut self.messages, config.keep_recent);
|
||||
if freed > 0 {
|
||||
self.compact_state.record_success();
|
||||
}
|
||||
}
|
||||
|
||||
let mut attempts = 0;
|
||||
|
||||
loop {
|
||||
let request = self.build_request(&tools);
|
||||
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest)
|
||||
.with_request(&request);
|
||||
let results = executor
|
||||
.execute(crate::llm::hooks::HookEvent::PreRequest, &ctx)
|
||||
.await;
|
||||
if results.iter().any(|r| r.should_block) {
|
||||
let reason = results
|
||||
.iter()
|
||||
.find(|r| r.should_block)
|
||||
.and_then(|r| r.reason.clone())
|
||||
.unwrap_or_else(|| "Blocked by pre-request hook".to_string());
|
||||
return Err(LlmError::Other(reason));
|
||||
}
|
||||
}
|
||||
|
||||
match self.provider.chat(request).await {
|
||||
Ok(response) => {
|
||||
self.messages.push(response.message.clone());
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let post_request = self.build_request(&tools);
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PostRequest)
|
||||
.with_request(&post_request);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::PostRequest, &ctx)
|
||||
.await;
|
||||
}
|
||||
|
||||
self.messages.push(response.message.clone());
|
||||
self.usage.add(&response.usage);
|
||||
|
||||
return Ok(response);
|
||||
}
|
||||
Err(e) if should_retry(&e) && attempts < self.config.retry.max_retries => {
|
||||
attempts += 1;
|
||||
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnRetry)
|
||||
.with_error(&e)
|
||||
.with_attempt(attempts);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::OnRetry, &ctx)
|
||||
.await;
|
||||
}
|
||||
|
||||
let delay = self.config.retry.compute_delay(attempts);
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
Err(e) => {
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let ctx =
|
||||
HookContext::new(crate::llm::hooks::HookEvent::OnError).with_error(&e);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::OnError, &ctx)
|
||||
.await;
|
||||
}
|
||||
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 提交用户消息并返回语义事件流。
|
||||
///
|
||||
/// 与 `submit` 不同,该方法返回流式事件而非完整响应。
|
||||
/// 适用于需要实时处理 LLM 输出的场景。
|
||||
pub async fn submit_stream(
|
||||
&mut self,
|
||||
prompt: String,
|
||||
tools: Vec<ToolDefinition>,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = StreamEvent> + Send>>, LlmError> {
|
||||
self.messages.push(OpenaiChatMessage::user_text(prompt));
|
||||
|
||||
if let Some(ref config) = self.compact_config
|
||||
&& should_compact(&self.messages, config, &self.compact_state)
|
||||
{
|
||||
let freed = microcompact(&mut self.messages, config.keep_recent);
|
||||
if freed > 0 {
|
||||
self.compact_state.record_success();
|
||||
}
|
||||
}
|
||||
|
||||
let request = self.build_request(&tools);
|
||||
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest)
|
||||
.with_request(&request);
|
||||
let results = executor
|
||||
.execute(crate::llm::hooks::HookEvent::PreRequest, &ctx)
|
||||
.await;
|
||||
if results.iter().any(|r| r.should_block) {
|
||||
let reason = results
|
||||
.iter()
|
||||
.find(|r| r.should_block)
|
||||
.and_then(|r| r.reason.clone())
|
||||
.unwrap_or_else(|| "Blocked by pre-request hook".to_string());
|
||||
return Err(LlmError::Other(reason));
|
||||
}
|
||||
}
|
||||
|
||||
let chunk_stream = self.provider.chat_stream(request).await?;
|
||||
let hook_executor = self.hook_executor.clone();
|
||||
let post_request = self.build_request(&tools);
|
||||
|
||||
Ok(Box::pin(stream! {
|
||||
use futures_util::StreamExt;
|
||||
let mut chunk_stream = chunk_stream;
|
||||
|
||||
while let Some(result) = chunk_stream.next().await {
|
||||
match result {
|
||||
Ok(chunk) => {
|
||||
let mut assistant_text = String::new();
|
||||
let mut tool_started: Option<(String, String, String)> = None;
|
||||
|
||||
for choice in &chunk.choices {
|
||||
let delta = &choice.delta;
|
||||
|
||||
if let Some(content) = &delta.content {
|
||||
assistant_text.push_str(content);
|
||||
}
|
||||
|
||||
if let Some(tool_calls) = &delta.tool_calls
|
||||
&& let Some(tc) = tool_calls.first()
|
||||
{
|
||||
let crate::llm::types::OpenaiToolCall::Function { id, function } = tc;
|
||||
tool_started = Some((id.clone(), function.name.clone(), function.arguments.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
if !assistant_text.is_empty() {
|
||||
yield StreamEvent::AssistantTextDelta { text: assistant_text };
|
||||
}
|
||||
|
||||
if let Some((tool_call_id, tool_name, arguments)) = tool_started {
|
||||
let args: serde_json::Value = serde_json::from_str(&arguments)
|
||||
.unwrap_or(serde_json::Value::Null);
|
||||
yield StreamEvent::ToolExecutionStarted {
|
||||
tool_name,
|
||||
input: args,
|
||||
tool_call_id,
|
||||
};
|
||||
}
|
||||
|
||||
for choice in &chunk.choices {
|
||||
if let Some(finish_reason) = &choice.finish_reason {
|
||||
yield StreamEvent::TurnComplete {
|
||||
reason: *finish_reason,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(usage_info) = &chunk.usage {
|
||||
yield StreamEvent::CostUpdate { usage: *usage_info };
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
if let Some(ref executor) = hook_executor {
|
||||
let ctx = crate::llm::hooks::HookContext::new(
|
||||
crate::llm::hooks::HookEvent::OnError,
|
||||
)
|
||||
.with_error(&e);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::OnError, &ctx)
|
||||
.await;
|
||||
}
|
||||
yield StreamEvent::Error { message: e.to_string() };
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref executor) = hook_executor {
|
||||
let ctx = crate::llm::hooks::HookContext::new(
|
||||
crate::llm::hooks::HookEvent::PostRequest,
|
||||
)
|
||||
.with_request(&post_request);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::PostRequest, &ctx)
|
||||
.await;
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
fn build_request(&self, tools: &[ToolDefinition]) -> ChatRequest {
|
||||
let mut messages = self.messages.clone();
|
||||
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
//! 生命周期钩子 —— 在 LLM 调用周期的关键节点插入自定义逻辑。
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::types::ChatRequest;
|
||||
|
||||
/// Points in the LLM call lifecycle at which hooks can run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HookEvent {
    /// Before the LLM request is sent (a hook may block the request).
    PreRequest,
    /// After a successful response.
    PostRequest,
    /// Before a retry (only fired for retryable errors).
    OnRetry,
    /// Before an unrecoverable error is returned.
    OnError,
}
|
||||
|
||||
/// 此次钩子调用的上下文。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HookContext<'a> {
|
||||
/// 当前事件点。
|
||||
pub event: HookEvent,
|
||||
/// 当前的请求(部分事件点可用)。
|
||||
pub request: Option<&'a ChatRequest>,
|
||||
/// 当前错误(仅 OnError 和 OnRetry 可用)。
|
||||
pub error: Option<&'a LlmError>,
|
||||
/// 当前重试次数(从 1 开始,仅 OnRetry 可用)。
|
||||
pub attempt: u32,
|
||||
}
|
||||
|
||||
impl<'a> HookContext<'a> {
|
||||
pub(crate) fn new(event: HookEvent) -> Self {
|
||||
Self {
|
||||
event,
|
||||
request: None,
|
||||
error: None,
|
||||
attempt: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn with_request(mut self, request: &'a ChatRequest) -> Self {
|
||||
self.request = Some(request);
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn with_error(mut self, error: &'a LlmError) -> Self {
|
||||
self.error = Some(error);
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn with_attempt(mut self, attempt: u32) -> Self {
|
||||
self.attempt = attempt;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Outcome of running a single hook.
#[derive(Debug, Clone)]
pub struct HookResult {
    /// Whether the following operation should be blocked (honoured for
    /// `PreRequest` only).
    pub should_block: bool,
    /// Reason for blocking, or a free-form note.
    pub reason: Option<String>,
}

impl HookResult {
    /// Builds a result that lets the operation proceed.
    // NOTE(review): not referenced in the visible code, hence the lint
    // allowance — confirm whether it should become `pub`.
    #[allow(dead_code)]
    pub(crate) fn allow() -> Self {
        let should_block = false;
        Self {
            should_block,
            reason: None,
        }
    }

    /// Builds a result that blocks the operation, recording `reason`.
    // NOTE(review): same as `allow` — unused in the visible code.
    #[allow(dead_code)]
    pub(crate) fn block(reason: impl Into<String>) -> Self {
        let reason = Some(reason.into());
        Self {
            should_block: true,
            reason,
        }
    }
}
|
||||
|
||||
/// 生命周期钩子 trait。
|
||||
#[async_trait]
|
||||
pub trait Hook: Send + Sync {
|
||||
/// 执行钩子逻辑。
|
||||
async fn execute(&self, ctx: &HookContext<'_>) -> HookResult;
|
||||
}
|
||||
|
||||
/// 钩子执行器 —— 管理注册与触发。
|
||||
pub struct HookExecutor {
|
||||
hooks: Vec<(HookEvent, Box<dyn Hook>)>,
|
||||
}
|
||||
|
||||
impl Default for HookExecutor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl HookExecutor {
|
||||
/// 创建一个空的执行器。
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
hooks: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 注册一个钩子到指定事件点。
|
||||
pub fn register(&mut self, event: HookEvent, hook: Box<dyn Hook>) {
|
||||
self.hooks.push((event, hook));
|
||||
}
|
||||
|
||||
/// 执行指定事件点的所有钩子。
|
||||
pub async fn execute(&self, event: HookEvent, ctx: &HookContext<'_>) -> Vec<HookResult> {
|
||||
let mut results = Vec::new();
|
||||
for (e, hook) in &self.hooks {
|
||||
if *e == event {
|
||||
results.push(hook.execute(ctx).await);
|
||||
}
|
||||
}
|
||||
results
|
||||
}
|
||||
}
|
||||
+21
-1
@@ -1,7 +1,12 @@
|
||||
pub mod openai;
|
||||
pub mod registry;
|
||||
|
||||
use std::pin::Pin;
|
||||
|
||||
use tokio_stream::Stream;
|
||||
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::types::{ChatRequest, ChatResponse};
|
||||
use crate::llm::types::{ChatRequest, ChatResponse, OpenaiChatChunk};
|
||||
use async_trait::async_trait;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@@ -57,4 +62,19 @@ pub fn create_provider(
|
||||
pub trait LlmProvider: Send + Sync {
|
||||
/// 发送聊天请求并返回完整响应。
|
||||
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError>;
|
||||
|
||||
/// 流式聊天请求 —— 返回原始 SSE chunk 流。
|
||||
///
|
||||
/// 默认实现回退到非流式调用(包装为单元素流)。
|
||||
async fn chat_stream(
|
||||
&self,
|
||||
request: ChatRequest,
|
||||
) -> Result<
|
||||
Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||
LlmError,
|
||||
> {
|
||||
let response = self.chat(request).await?;
|
||||
let chunk = OpenaiChatChunk::from(response);
|
||||
Ok(Box::pin(tokio_stream::once(Ok(chunk))))
|
||||
}
|
||||
}
|
||||
|
||||
+122
-1
@@ -1,12 +1,18 @@
|
||||
use std::pin::Pin;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use futures_core::stream::Stream;
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::Client;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use super::LlmProvider;
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::types::{ChatRequest, ChatResponse, OpenaiChatResponse};
|
||||
use crate::llm::types::{
|
||||
ChatRequest, ChatResponse, OpenaiChatChunk, OpenaiChatResponse, StreamOptions,
|
||||
};
|
||||
|
||||
pub struct OpenaiProvider {
|
||||
http_client: Client,
|
||||
@@ -111,4 +117,119 @@ impl LlmProvider for OpenaiProvider {
|
||||
|
||||
Ok(ChatResponse::from(chat_response))
|
||||
}
|
||||
|
||||
async fn chat_stream(
|
||||
&self,
|
||||
mut request: ChatRequest,
|
||||
) -> Result<
|
||||
Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||
LlmError,
|
||||
> {
|
||||
let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
|
||||
|
||||
request.stream = Some(true);
|
||||
request.stream_options = Some(StreamOptions {
|
||||
include_usage: Some(true),
|
||||
include_obfuscation: None,
|
||||
});
|
||||
|
||||
info!(model = %request.model, "发送 LLM 流式请求");
|
||||
|
||||
let response = self
|
||||
.http_client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!(error = %e, "流式请求失败");
|
||||
Self::map_reqwest_error(e)
|
||||
})?;
|
||||
|
||||
let status = response.status();
|
||||
let status_code: u16 = status.as_u16();
|
||||
|
||||
if !status.is_success() {
|
||||
let body_text = response.text().await.unwrap_or_default();
|
||||
error!(status = status_code, body = %body_text, "流式请求失败");
|
||||
return Err(LlmError::Request {
|
||||
status: status_code,
|
||||
body: body_text,
|
||||
});
|
||||
}
|
||||
|
||||
let byte_stream = response.bytes_stream().map(|r| {
|
||||
r.map_err(|e| LlmError::Other(format!("流式读取失败: {}", e)))
|
||||
});
|
||||
|
||||
Ok(Box::pin(SseChunkStream::new(byte_stream)))
|
||||
}
|
||||
}
|
||||
|
||||
struct SseChunkStream<S> {
|
||||
inner: S,
|
||||
buffer: String,
|
||||
}
|
||||
|
||||
impl<S: Stream<Item = Result<Bytes, LlmError>> + Unpin> SseChunkStream<S> {
|
||||
fn new(stream: S) -> Self {
|
||||
Self {
|
||||
inner: stream,
|
||||
buffer: String::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Stream<Item = Result<Bytes, LlmError>> + Unpin> Stream for SseChunkStream<S> {
|
||||
type Item = Result<OpenaiChatChunk, LlmError>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
|
||||
loop {
|
||||
if let Some(pos) = self.buffer.find("\n") {
|
||||
let line = self.buffer.drain(..pos + 1).collect::<String>();
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty()
|
||||
|| trimmed == "data: [DONE]"
|
||||
|| trimmed == "[DONE]"
|
||||
|| trimmed == "data:"
|
||||
{
|
||||
continue;
|
||||
}
|
||||
let data = if let Some(p) = trimmed.strip_prefix("data: ") {
|
||||
p
|
||||
} else {
|
||||
trimmed
|
||||
};
|
||||
match serde_json::from_str::<OpenaiChatChunk>(data) {
|
||||
Ok(chunk) => return std::task::Poll::Ready(Some(Ok(chunk))),
|
||||
Err(e) => {
|
||||
return std::task::Poll::Ready(Some(Err(LlmError::Other(format!(
|
||||
"Chunk 解析失败: {} | raw: {}",
|
||||
e, data
|
||||
)))));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match Pin::new(&mut self.inner).poll_next(cx) {
|
||||
std::task::Poll::Ready(Some(Ok(bytes))) => {
|
||||
if let Ok(s) = std::str::from_utf8(&bytes) {
|
||||
self.buffer.push_str(s);
|
||||
}
|
||||
}
|
||||
std::task::Poll::Ready(Some(Err(e))) => {
|
||||
return std::task::Poll::Ready(Some(Err(e)));
|
||||
}
|
||||
std::task::Poll::Ready(None) => {
|
||||
if self.buffer.is_empty() {
|
||||
return std::task::Poll::Ready(None);
|
||||
}
|
||||
self.buffer.clear();
|
||||
return std::task::Poll::Ready(None);
|
||||
}
|
||||
std::task::Poll::Pending => return std::task::Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
//! Provider 注册表 —— 多 Provider 实例的注册与发现。
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::provider::{create_provider, LlmProvider, ProviderConfig, ProviderType};
|
||||
|
||||
/// Provider 注册表 —— 管理多个 LLM Provider 实例。
|
||||
///
|
||||
/// 支持注册命名 Provider、按名称查找、设置默认 Provider。
|
||||
pub struct ProviderRegistry {
|
||||
providers: HashMap<String, Box<dyn LlmProvider>>,
|
||||
default_name: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for ProviderRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderRegistry {
|
||||
/// 创建一个空的注册表。
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
providers: HashMap::new(),
|
||||
default_name: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// 注册一个已初始化的 Provider 实例。
|
||||
pub fn register(&mut self, name: impl Into<String>, provider: Box<dyn LlmProvider>) {
|
||||
self.providers.insert(name.into(), provider);
|
||||
}
|
||||
|
||||
/// 通过 ProviderType + ProviderConfig 创建并注册。
|
||||
pub fn register_with_config(
|
||||
&mut self,
|
||||
name: impl Into<String>,
|
||||
provider_type: ProviderType,
|
||||
config: ProviderConfig,
|
||||
) -> Result<(), LlmError> {
|
||||
let provider = create_provider(provider_type, config)?;
|
||||
self.register(name, provider);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 设置默认 Provider。
|
||||
pub fn set_default(&mut self, name: &str) -> Result<(), LlmError> {
|
||||
if !self.providers.contains_key(name) {
|
||||
return Err(LlmError::Other(format!("Provider '{}' 不存在", name)));
|
||||
}
|
||||
self.default_name = Some(name.to_string());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 按名称查找 Provider。
|
||||
pub fn get(&self, name: &str) -> Option<&dyn LlmProvider> {
|
||||
self.providers.get(name).map(|p| p.as_ref())
|
||||
}
|
||||
|
||||
/// 获取默认 Provider。
|
||||
pub fn get_default(&self) -> Option<&dyn LlmProvider> {
|
||||
self.default_name
|
||||
.as_ref()
|
||||
.and_then(|name| self.get(name))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
//! 流式事件系统 —— 将 LLM 流式响应解析为语义化事件。
|
||||
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use futures_core::stream::Stream;
|
||||
use futures_util::future::poll_fn;
|
||||
use futures_util::FutureExt;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::types::{FinishReason, OpenaiChatChunk, OpenaiToolCall, Usage};
|
||||
|
||||
/// 流式事件 —— LLM 调用全生命周期的语义化事件。
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum StreamEvent {
|
||||
/// 助手回复文本增量。
|
||||
AssistantTextDelta { text: String },
|
||||
/// 工具调用开始。
|
||||
ToolExecutionStarted {
|
||||
tool_name: String,
|
||||
input: Value,
|
||||
tool_call_id: String,
|
||||
},
|
||||
/// 工具调用完成。
|
||||
ToolExecutionCompleted {
|
||||
tool_name: String,
|
||||
output: Value,
|
||||
is_error: bool,
|
||||
},
|
||||
/// Token 用量更新。
|
||||
CostUpdate { usage: Usage },
|
||||
/// 一轮会话完成。
|
||||
TurnComplete { reason: FinishReason },
|
||||
/// 错误事件。
|
||||
Error { message: String },
|
||||
}
|
||||
|
||||
impl StreamEvent {
|
||||
fn error(message: impl Into<String>) -> Self {
|
||||
Self::Error {
|
||||
message: message.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 将原始 OpenaiChatChunk 流解析为 StreamEvent 流。
|
||||
pub fn parse_chunk_stream(
|
||||
chunks: Pin<Box<dyn futures_core::Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||
) -> Pin<Box<dyn futures_core::Stream<Item = StreamEvent> + Send>> {
|
||||
Box::pin(ChunkToEventStream { chunks })
|
||||
}
|
||||
|
||||
struct ChunkToEventStream {
|
||||
chunks: Pin<Box<dyn futures_core::Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||
}
|
||||
|
||||
impl Stream for ChunkToEventStream {
|
||||
type Item = StreamEvent;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = &mut *self;
|
||||
poll_fn(|cx| match Pin::new(&mut this.chunks).poll_next(cx) {
|
||||
Poll::Ready(Some(Ok(chunk))) => {
|
||||
for choice in &chunk.choices {
|
||||
let delta = &choice.delta;
|
||||
|
||||
if let Some(content) = &delta.content {
|
||||
return Poll::Ready(Some(StreamEvent::AssistantTextDelta {
|
||||
text: content.clone(),
|
||||
}));
|
||||
}
|
||||
|
||||
if let Some(tool_calls) = &delta.tool_calls
|
||||
&& let Some(tc) = tool_calls.first()
|
||||
{
|
||||
let OpenaiToolCall::Function { id, function } = tc;
|
||||
let args: Value =
|
||||
serde_json::from_str(&function.arguments).unwrap_or(Value::Null);
|
||||
return Poll::Ready(Some(StreamEvent::ToolExecutionStarted {
|
||||
tool_name: function.name.clone(),
|
||||
input: args,
|
||||
tool_call_id: id.clone(),
|
||||
}));
|
||||
}
|
||||
|
||||
if let Some(finish_reason) = &choice.finish_reason {
|
||||
return Poll::Ready(Some(StreamEvent::TurnComplete {
|
||||
reason: *finish_reason,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(usage) = &chunk.usage {
|
||||
return Poll::Ready(Some(StreamEvent::CostUpdate {
|
||||
usage: *usage,
|
||||
}));
|
||||
}
|
||||
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(StreamEvent::error(e.to_string()))),
|
||||
Poll::Ready(None) => Poll::Ready(None),
|
||||
Poll::Pending => Poll::Pending,
|
||||
})
|
||||
.poll_unpin(cx)
|
||||
}
|
||||
}
|
||||
@@ -43,6 +43,34 @@ impl From<OpenaiChatResponse> for ChatResponse {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ChatResponse> for OpenaiChatChunk {
|
||||
fn from(response: ChatResponse) -> Self {
|
||||
let delta = Delta::from(response.message.clone());
|
||||
let chunk_choice = ChunkChoice {
|
||||
index: 0,
|
||||
delta,
|
||||
logprobs: None,
|
||||
finish_reason: response.stop_reason,
|
||||
};
|
||||
|
||||
OpenaiChatChunk {
|
||||
id: format!("chunk-{}", std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_nanos())
|
||||
.unwrap_or(0)),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0),
|
||||
model: String::new(),
|
||||
choices: vec![chunk_choice],
|
||||
usage: Some(response.usage),
|
||||
system_fingerprint: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type ChatRequest = OpenaiChatRequest;
|
||||
pub type Message = OpenaiChatMessage;
|
||||
pub type ContentBlock = OpenaiContentPart;
|
||||
|
||||
@@ -3,6 +3,7 @@ use crate::llm::types::shared::{FinishReason, ServiceTier};
|
||||
use crate::llm::types::tool::OpenaiToolCall;
|
||||
use crate::llm::types::usage::Usage;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::llm::types::{ContentField, OpenaiContentPart};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TokenLogprob {
|
||||
@@ -115,3 +116,66 @@ pub struct OpenaiChatChunk {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system_fingerprint: Option<String>,
|
||||
}
|
||||
|
||||
impl From<OpenaiChatMessage> for Delta {
|
||||
fn from(msg: OpenaiChatMessage) -> Self {
|
||||
match msg {
|
||||
OpenaiChatMessage::Assistant {
|
||||
content,
|
||||
tool_calls,
|
||||
..
|
||||
} => Delta {
|
||||
role: Some("assistant".to_string()),
|
||||
content: match content {
|
||||
ContentField::String(s) => Some(s),
|
||||
ContentField::Array(parts) => {
|
||||
let mut text = String::new();
|
||||
for part in parts {
|
||||
if let OpenaiContentPart::Text { text: t } = part {
|
||||
text.push_str(&t);
|
||||
}
|
||||
}
|
||||
if text.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(text)
|
||||
}
|
||||
}
|
||||
},
|
||||
refusal: None,
|
||||
tool_calls,
|
||||
},
|
||||
_ => Delta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<OpenaiChatResponse> for OpenaiChatChunk {
|
||||
fn from(response: OpenaiChatResponse) -> Self {
|
||||
let choices = response
|
||||
.choices
|
||||
.into_iter()
|
||||
.map(|c| ChunkChoice {
|
||||
index: c.index,
|
||||
delta: Delta::from(c.message),
|
||||
logprobs: c.logprobs,
|
||||
finish_reason: c.finish_reason,
|
||||
})
|
||||
.collect();
|
||||
|
||||
OpenaiChatChunk {
|
||||
id: response.id,
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created: response.created,
|
||||
model: response.model,
|
||||
choices,
|
||||
usage: Some(response.usage),
|
||||
system_fingerprint: response.system_fingerprint,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user