Compare commits
24 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 336920554a | |||
| 63c50e1fc7 | |||
| 0c51bb78a6 | |||
| 2ecc0b4001 | |||
| 1fe7f02281 | |||
| 6dc7ee492f | |||
| b571f530f8 | |||
| 59994bf55e | |||
| fb1c530358 | |||
| f818bd59f5 | |||
| 8573c6eb92 | |||
| 692bff5751 | |||
| b6e7acfb0f | |||
| e598f6d3ee | |||
| 5d6bb5e983 | |||
| 0d58d07ab1 | |||
| dd9c5be1fe | |||
| 993ae0eb4b | |||
| 7f5513adf3 | |||
| ea1e5c7f7e | |||
| 32f3edaf19 | |||
| 69b6dd942b | |||
| 99b304e120 | |||
| 0267da93f1 |
+1
-1
@@ -8,7 +8,7 @@ end_of_line = lf
|
|||||||
trim_trailing_whitespace = true
|
trim_trailing_whitespace = true
|
||||||
insert_final_newline = true
|
insert_final_newline = true
|
||||||
charset = utf-8
|
charset = utf-8
|
||||||
max_line_length = 100
|
max_line_length = 120
|
||||||
|
|
||||||
[*.java]
|
[*.java]
|
||||||
indent_size = 4
|
indent_size = 4
|
||||||
|
|||||||
@@ -1032,3 +1032,4 @@ Gemfile.lock
|
|||||||
|
|
||||||
# Specific Project files
|
# Specific Project files
|
||||||
.opencode/**
|
.opencode/**
|
||||||
|
.codegraph/**
|
||||||
|
|||||||
@@ -31,6 +31,8 @@
|
|||||||
### 3. 精准变更 (Surgical Changes)
|
### 3. 精准变更 (Surgical Changes)
|
||||||
**只改动必须改的。只清理你自己造成的混乱。**
|
**只改动必须改的。只清理你自己造成的混乱。**
|
||||||
|
|
||||||
|
- 未明确要求时不修改已有文件
|
||||||
|
- 先确认意图再动手
|
||||||
- 不要"优化"相邻的代码、注释或格式。
|
- 不要"优化"相邻的代码、注释或格式。
|
||||||
- 不要重构没有问题的代码。
|
- 不要重构没有问题的代码。
|
||||||
- 遵循已有风格,即使你自己的写法不同。
|
- 遵循已有风格,即使你自己的写法不同。
|
||||||
@@ -64,9 +66,9 @@
|
|||||||
- 变量/函数命名:`snake_case`
|
- 变量/函数命名:`snake_case`
|
||||||
|
|
||||||
**测试要求**
|
**测试要求**
|
||||||
- 新功能建议附测试;修复 bug 建议附回归测试(不主动编写测试)
|
- 核心业务逻辑需测试(关键算法、边界条件、错误处理)
|
||||||
- 简单明确的逻辑不需要创建测试(如枚举字面值、Getter、无分支的简单转换)
|
- 简单逻辑不需要测试(枚举字面值、Getter、无分支的简单转换)
|
||||||
- 测试结构:AAA 模式 (Arrange-Act-Assert),优先测试边界条件和错误场景
|
- 不主动补测试(除非用户明确要求)
|
||||||
|
|
||||||
**错误处理**
|
**错误处理**
|
||||||
- 优先使用 `Result` 处理错误,避免 `unwrap()`
|
- 优先使用 `Result` 处理错误,避免 `unwrap()`
|
||||||
@@ -76,7 +78,11 @@
|
|||||||
**安全规范**
|
**安全规范**
|
||||||
- 不硬编码密钥,使用环境变量
|
- 不硬编码密钥,使用环境变量
|
||||||
- 用户输入必须验证
|
- 用户输入必须验证
|
||||||
- 依赖包保持更新
|
- 依赖升级策略:
|
||||||
|
- **安全补丁**:立即升级(修复已知漏洞)
|
||||||
|
- **次要版本**:评估后升级(新功能、向后兼容)
|
||||||
|
- **主要版本**:谨慎升级(可能破坏兼容性,需全面测试)
|
||||||
|
- **验证**:升级后运行完整测试套件确保无回归
|
||||||
|
|
||||||
**Git Commit 规范**
|
**Git Commit 规范**
|
||||||
- 使用 Conventional Commits 格式:`<type>(<scope>): <description>`
|
- 使用 Conventional Commits 格式:`<type>(<scope>): <description>`
|
||||||
@@ -91,6 +97,17 @@
|
|||||||
- `test` - 测试相关
|
- `test` - 测试相关
|
||||||
- `chore` - 构建/工具/配置
|
- `chore` - 构建/工具/配置
|
||||||
- 描述使用祈使句、现在时态、句尾无句号
|
- 描述使用祈使句、现在时态、句尾无句号
|
||||||
|
- **Scope 推导规则**:
|
||||||
|
- 从被提交的文件路径中推导出一个最相关的 scope
|
||||||
|
- 一个 commit 只写一个主要 scope,不要罗列多个
|
||||||
|
- 如果改动涉及多个模块,选择影响范围最大的那个或使用更上层抽象名称
|
||||||
|
- 如果改动是全局的,可以省略 scope 或使用 `core`/`global`
|
||||||
|
- **示例**:
|
||||||
|
- `feat(llm): 添加流式响应支持`(修改 `src/llm/stream.rs`)
|
||||||
|
- `fix(memory): 修复向量检索边界条件`(修改 `src/memory/vector_store.rs`)
|
||||||
|
- `refactor(core): 统一错误类型定义`(修改多个模块的错误处理)
|
||||||
|
- `docs: 更新 API 文档`(全局文档更新)
|
||||||
|
- `chore(deps): 升级 tokio 到 1.40`(依赖更新)
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -106,7 +123,47 @@
|
|||||||
|
|
||||||
**模块组织**
|
**模块组织**
|
||||||
- 按功能领域组织模块(一个模块一个职责)
|
- 按功能领域组织模块(一个模块一个职责)
|
||||||
- 使用 Rust 新风格模块结构:`foo.rs` 作为模块根,子模块在 `foo/bar.rs`
|
- **2018+ 版风格**:使用 `foo.rs` 作为模块根,子模块放在 `foo/` 目录下
|
||||||
|
- **扁平优先**:如果模块没有子模块,直接用单个 `foo.rs` 文件
|
||||||
|
- **需要子模块时才用目录**:只有当模块需要组织多个子模块时才创建 `foo/` 目录
|
||||||
|
- **避免 `mod.rs`**:不使用 `foo/mod.rs` 风格,统一使用 `foo.rs` 作为模块根
|
||||||
|
- **公共 API 重导出**:在模块根中使用 `pub use` 重新导出子模块的公共类型,提供清晰的 API 边界
|
||||||
|
|
||||||
|
**示例结构图**:
|
||||||
|
```
|
||||||
|
src/
|
||||||
|
├── lib.rs # crate 根,声明顶层模块
|
||||||
|
├── llm.rs # LLM 模块(无子模块,扁平文件)
|
||||||
|
├── memory.rs # 记忆模块根(有子模块)
|
||||||
|
├── memory/
|
||||||
|
│ ├── conversation.rs # ✓ 子模块
|
||||||
|
│ └── vector_store.rs # ✓ 子模块
|
||||||
|
├── prompt.rs # 提示词模块根(有子模块)
|
||||||
|
├── prompt/
|
||||||
|
│ ├── template.rs # ✓ 子模块
|
||||||
|
│ └── optimizer.rs # ✓ 子模块
|
||||||
|
├── tool.rs # 工具模块根(有子模块)
|
||||||
|
├── tool/
|
||||||
|
│ ├── mcp.rs # ✓ 子模块
|
||||||
|
│ └── registry.rs # ✓ 子模块
|
||||||
|
└── agent.rs # Agent 模块根(有子模块)
|
||||||
|
└── agent/
|
||||||
|
├── runtime.rs # ✓ 子模块
|
||||||
|
└── planner.rs # ✓ 子模块
|
||||||
|
|
||||||
|
# ❌ 避免的风格
|
||||||
|
src/
|
||||||
|
└── memory/
|
||||||
|
└── mod.rs # ❌ 不使用 mod.rs 风格
|
||||||
|
|
||||||
|
# ✅ memory.rs 示例:声明子模块并重导出公共 API
|
||||||
|
pub mod conversation;
|
||||||
|
pub mod vector_store;
|
||||||
|
|
||||||
|
// 重导出常用类型,提供简洁的公共 API
|
||||||
|
pub use conversation::ConversationMemory;
|
||||||
|
pub use vector_store::VectorStore;
|
||||||
|
```
|
||||||
|
|
||||||
**测试文件**: 内联测试(`#[cfg(test)] mod tests {}`)或 `tests/` 目录
|
**测试文件**: 内联测试(`#[cfg(test)] mod tests {}`)或 `tests/` 目录
|
||||||
|
|
||||||
@@ -114,6 +171,22 @@
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## 文档规范
|
||||||
|
|
||||||
|
### 方案规范 (docs/)
|
||||||
|
|
||||||
|
**编号规则**:创建新方案前必须先通过 shell 命令确认当前实际最大编号(Unix: `ls docs/` / Windows: `dir docs\`),禁止使用上下文中缓存的编号,如遇冲突自动递增
|
||||||
|
|
||||||
|
**方案文档结构**(6 项):
|
||||||
|
1. **背景与目标** - 问题描述、预期目标
|
||||||
|
2. **需求分析** - 功能需求、非功能需求
|
||||||
|
3. **方案设计** - 架构设计、模块划分、接口定义
|
||||||
|
4. **实现计划** - 任务拆解、优先级、时间估算
|
||||||
|
5. **风险评估** - 潜在风险、缓解措施
|
||||||
|
6. **验收标准** - 可验证的完成条件
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## 项目特定规则
|
## 项目特定规则
|
||||||
|
|
||||||
### 项目结构
|
### 项目结构
|
||||||
|
|||||||
+9
-1
@@ -5,13 +5,21 @@ edition = "2024"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
tokio = { version = "1", features = ["full"] }
|
tokio = { version = "1", features = ["full"] }
|
||||||
reqwest = { version = "0.12", features = ["json"] }
|
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
thiserror = "2"
|
thiserror = "2"
|
||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
|
tokio-stream = "0.1"
|
||||||
|
futures = "0.3"
|
||||||
|
futures-util = "0.3"
|
||||||
|
futures-core = "0.3"
|
||||||
|
bytes = "1"
|
||||||
|
async-stream = "0.3"
|
||||||
|
tokio-util = { version = "0.7", features = ["rt"] }
|
||||||
|
time = { version = "0.3", features = ["serde"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
dotenvy = "0.15.7"
|
dotenvy = "0.15.7"
|
||||||
|
|||||||
@@ -0,0 +1,375 @@
|
|||||||
|
# Phase 0 剩余模块 — 实施方案
|
||||||
|
|
||||||
|
> 定稿日期:2026-06-02
|
||||||
|
|
||||||
|
## 背景与目标
|
||||||
|
|
||||||
|
AG Core Phase 0(Foundation)已完成核心数据类型、错误体系、Provider 接口、OpenAI 实现、生命周期引擎等基础设施。剩余 4 个子项尚未实现:**ProviderRegistry**、**HookExecutor**、**StreamEvents**、**Auto-compaction**。它们均被 Roadmap 标记为 P0(Must Have),是本阶段不可或缺的底层能力。
|
||||||
|
|
||||||
|
**目标**:完成这 4 个模块的设计与实现,使 Phase 0 全面交付。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 需求分析
|
||||||
|
|
||||||
|
### 功能需求
|
||||||
|
|
||||||
|
| 模块 | 需求 | 验收条件 |
|
||||||
|
|------|------|---------|
|
||||||
|
| ProviderRegistry | 支持注册命名 Provider、按名称查找、设置默认 | 3 个公开方法 + 工厂辅助 |
|
||||||
|
| HookExecutor | 4 个事件点:PreRequest / PostRequest / OnRetry / OnError | Hook trait + HookExecutor 触发 |
|
||||||
|
| StreamEvents | 流式事件枚举 + Provider 流式方法 + Cycle 流式入口 | 6 种事件类型 + SSE 解析 |
|
||||||
|
| Auto-compaction | Token 估算 + 微压缩 + 断路器 | 触发后释放 token 且不改变语义 |
|
||||||
|
|
||||||
|
### 非功能需求
|
||||||
|
|
||||||
|
- 所有公开 API 必须带 `///` 文档注释
|
||||||
|
- 无新增 `unwrap()` 调用
|
||||||
|
- 与现有 `LlmCycle` 集成时保持向后兼容(全为可选/增量添加)
|
||||||
|
- 错误统一使用 `LlmError` 枚举
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 方案设计
|
||||||
|
|
||||||
|
### 1. ProviderRegistry (`src/llm/provider/registry.rs`)
|
||||||
|
|
||||||
|
**职责**:管理多个 LLM Provider 实例,支持按名称注册、发现、切换。
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// src/llm/provider/registry.rs
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use crate::llm::error::LlmError;
|
||||||
|
use crate::llm::provider::{LlmProvider, ProviderConfig, ProviderType, create_provider};
|
||||||
|
|
||||||
|
/// Provider 注册表——管理多个 LLM Provider 实例。
|
||||||
|
pub struct ProviderRegistry {
|
||||||
|
providers: HashMap<String, Box<dyn LlmProvider>>,
|
||||||
|
default_name: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ProviderRegistry {
|
||||||
|
pub fn new() -> Self;
|
||||||
|
|
||||||
|
/// 注册一个已初始化的 Provider 实例。
|
||||||
|
pub fn register(&mut self, name: impl Into<String>, provider: Box<dyn LlmProvider>);
|
||||||
|
|
||||||
|
/// 通过 ProviderType + ProviderConfig 创建并注册。
|
||||||
|
pub fn register_with_config(
|
||||||
|
&mut self,
|
||||||
|
name: impl Into<String>,
|
||||||
|
provider_type: ProviderType,
|
||||||
|
config: ProviderConfig,
|
||||||
|
) -> Result<(), LlmError>;
|
||||||
|
|
||||||
|
/// 设置默认 Provider。
|
||||||
|
pub fn set_default(&mut self, name: &str) -> Result<(), LlmError>;
|
||||||
|
|
||||||
|
/// 按名称查找 Provider。
|
||||||
|
pub fn get(&self, name: &str) -> Option<&dyn LlmProvider>;
|
||||||
|
|
||||||
|
/// 获取默认 Provider。
|
||||||
|
pub fn get_default(&self) -> Option<&dyn LlmProvider>;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**无新增依赖**。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2. HookExecutor (`src/llm/hooks.rs`)
|
||||||
|
|
||||||
|
**职责**:在 LLM 调用生命周期的关键节点插入自定义逻辑。
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// src/llm/hooks.rs
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use crate::llm::error::LlmError;
|
||||||
|
use crate::llm::types::ChatRequest;
|
||||||
|
|
||||||
|
/// 生命周期钩子事件点。
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum HookEvent {
|
||||||
|
/// LLM 请求发起之前(可阻断)。
|
||||||
|
PreRequest,
|
||||||
|
/// 成功响应之后。
|
||||||
|
PostRequest,
|
||||||
|
/// 重试之前(仅可重试错误时触发)。
|
||||||
|
OnRetry,
|
||||||
|
/// 不可恢复错误返回之前。
|
||||||
|
OnError,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 此次钩子调用的上下文。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct HookContext<'a> {
|
||||||
|
pub event: HookEvent,
|
||||||
|
pub request: Option<&'a ChatRequest>,
|
||||||
|
pub error: Option<&'a LlmError>,
|
||||||
|
pub attempt: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 钩子执行结果。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct HookResult {
|
||||||
|
/// 是否阻断后续操作(仅 PreRequest 有效)。
|
||||||
|
pub should_block: bool,
|
||||||
|
/// 阻断/备注原因。
|
||||||
|
pub reason: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 生命周期钩子 trait。
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Hook: Send + Sync {
|
||||||
|
async fn execute(&self, ctx: &HookContext<'_>) -> HookResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 钩子执行器——管理注册与触发。
|
||||||
|
pub struct HookExecutor {
|
||||||
|
hooks: Vec<(HookEvent, Box<dyn Hook>)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HookExecutor {
|
||||||
|
pub fn new() -> Self;
|
||||||
|
pub fn register(&mut self, event: HookEvent, hook: Box<dyn Hook>);
|
||||||
|
pub async fn execute(&self, event: HookEvent, ctx: &HookContext<'_>) -> Vec<HookResult>;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**与 LlmCycle 集成**:
|
||||||
|
- `LlmCycle` 新增字段 `hook_executor: Option<HookExecutor>`
|
||||||
|
- 新增 builder 方法 `with_hook_executor()`
|
||||||
|
- `submit()` 中在 4 个点触发(PreRequest 若阻断则提前返回)
|
||||||
|
|
||||||
|
**无新增依赖**(async-trait 已存在)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3. StreamEvents (`src/llm/stream.rs`)
|
||||||
|
|
||||||
|
**职责**:提供流式 LLM 调用的事件抽象,将原始 SSE chunk 解析为语义化事件。
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// src/llm/stream.rs
|
||||||
|
|
||||||
|
use std::pin::Pin;
|
||||||
|
use tokio_stream::Stream;
|
||||||
|
use crate::llm::error::LlmError;
|
||||||
|
use crate::llm::types::{FinishReason, Usage, OpenaiChatChunk, ToolDefinition};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
/// 流式事件——LLM 调用全生命周期的语义化事件。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum StreamEvent {
|
||||||
|
/// 助手回复文本增量。
|
||||||
|
AssistantTextDelta { text: String },
|
||||||
|
/// 工具调用开始。
|
||||||
|
ToolExecutionStarted { tool_name: String, input: Value },
|
||||||
|
/// 工具调用完成。
|
||||||
|
ToolExecutionCompleted { tool_name: String, output: Value, is_error: bool },
|
||||||
|
/// Token 用量更新。
|
||||||
|
CostUpdate { usage: Usage },
|
||||||
|
/// 一轮会话完成。
|
||||||
|
TurnComplete { reason: FinishReason },
|
||||||
|
/// 可恢复的错误事件。
|
||||||
|
Error { message: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 将原始 OpenaiChatChunk 流解析为 StreamEvent 流。
|
||||||
|
pub fn parse_chunk_stream(
|
||||||
|
chunks: Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||||
|
) -> Pin<Box<dyn Stream<Item = StreamEvent> + Send>>;
|
||||||
|
```
|
||||||
|
|
||||||
|
**Provider 层扩展**(`src/llm/provider.rs`):
|
||||||
|
```rust
|
||||||
|
#[async_trait]
|
||||||
|
pub trait LlmProvider: Send + Sync {
|
||||||
|
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError>;
|
||||||
|
|
||||||
|
/// 流式聊天请求——返回原始 SSE chunk 流。
|
||||||
|
/// 默认实现回退到非流式调用。
|
||||||
|
async fn chat_stream(
|
||||||
|
&self,
|
||||||
|
request: ChatRequest,
|
||||||
|
) -> Result<
|
||||||
|
Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||||
|
LlmError,
|
||||||
|
> {
|
||||||
|
let response = self.chat(request).await?;
|
||||||
|
let chunk = OpenaiChatChunk::from(response);
|
||||||
|
Ok(Box::pin(tokio_stream::once(Ok(chunk))))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**LlmCycle 扩展**:
|
||||||
|
```rust
|
||||||
|
impl LlmCycle {
|
||||||
|
/// 提交请求并返回语义事件流。
|
||||||
|
pub async fn submit_stream(
|
||||||
|
&mut self,
|
||||||
|
prompt: String,
|
||||||
|
tools: Vec<ToolDefinition>,
|
||||||
|
) -> Result<
|
||||||
|
Pin<Box<dyn Stream<Item = StreamEvent> + Send>>,
|
||||||
|
LlmError,
|
||||||
|
>;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**新增依赖**:
|
||||||
|
```toml
|
||||||
|
tokio-stream = "0.1"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 4. Auto-compaction (`src/llm/compact.rs`)
|
||||||
|
|
||||||
|
**职责**:在上下文过长时自动压缩历史消息,避免 ContextLength 错误。
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// src/llm/compact.rs
|
||||||
|
|
||||||
|
use crate::llm::types::{ContentField, Message, OpenaiChatMessage, OpenaiContentPart};
|
||||||
|
|
||||||
|
// === 常量 ===
|
||||||
|
const AUTOCOMPACT_BUFFER_TOKENS: u32 = 13_000;
|
||||||
|
const RESERVED_OUTPUT_TOKENS: u32 = 20_000;
|
||||||
|
const MAX_CONSECUTIVE_FAILURES: u32 = 3;
|
||||||
|
const KEEP_RECENT: usize = 6;
|
||||||
|
|
||||||
|
/// 上下文压缩配置。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct CompactConfig {
|
||||||
|
/// 模型上下文窗口大小(token 数)。
|
||||||
|
pub context_window: u32,
|
||||||
|
/// 为输出预留的 token 数。
|
||||||
|
pub reserved_tokens: u32,
|
||||||
|
/// 微压缩保留的最近消息数。
|
||||||
|
pub keep_recent: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for CompactConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
context_window: 128_000,
|
||||||
|
reserved_tokens: RESERVED_OUTPUT_TOKENS,
|
||||||
|
keep_recent: KEEP_RECENT,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompactConfig {
|
||||||
|
pub fn threshold(&self) -> u32 {
|
||||||
|
self.context_window
|
||||||
|
.saturating_sub(self.reserved_tokens)
|
||||||
|
.saturating_sub(AUTOCOMPACT_BUFFER_TOKENS)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 压缩状态——跟踪连续失败次数(断路器模式)。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct CompactState {
|
||||||
|
consecutive_failures: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompactState {
|
||||||
|
pub fn new() -> Self;
|
||||||
|
pub fn record_success(&mut self);
|
||||||
|
/// 记录失败,返回 true 表示已达断路器上限。
|
||||||
|
pub fn record_failure(&mut self) -> bool;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 粗略估计消息列表的 token 数(基于字符数,4 字符 ≈ 1 token)。
|
||||||
|
pub fn estimate_message_tokens(messages: &[Message]) -> u32;
|
||||||
|
|
||||||
|
/// 判断是否需要触发自动压缩。
|
||||||
|
pub fn should_compact(messages: &[Message], config: &CompactConfig, state: &CompactState) -> bool;
|
||||||
|
|
||||||
|
/// 执行微压缩——用占位符替换旧的 tool result 内容。
|
||||||
|
/// 返回释放的 token 数。
|
||||||
|
pub fn microcompact(messages: &mut Vec<Message>, keep_recent: usize) -> u32;
|
||||||
|
```
|
||||||
|
|
||||||
|
**与 LlmCycle 集成**:
|
||||||
|
- `LlmCycle` 新增字段 `compact_config: Option<CompactConfig>`, `compact_state: CompactState`
|
||||||
|
- 新增 builder 方法 `with_compact_config()`
|
||||||
|
- `submit()` 开始时调用 `should_compact()` → `microcompact()`
|
||||||
|
|
||||||
|
**完整 LLM 摘要压缩**留占位,Phase 2 实现(需要循环内调用 LLM 的能力)。
|
||||||
|
|
||||||
|
**无新增依赖**。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 实现计划
|
||||||
|
|
||||||
|
### Step 1: 先写方案文档
|
||||||
|
|
||||||
|
创建 `docs/3-phase0-remaining.md`(即本文档)。
|
||||||
|
|
||||||
|
### Step 2: ProviderRegistry
|
||||||
|
|
||||||
|
- 创建 `src/llm/provider/registry.rs`
|
||||||
|
- `provider.rs` 添加 `pub mod registry;`
|
||||||
|
- `cargo check` 验证
|
||||||
|
|
||||||
|
### Step 3: HookExecutor
|
||||||
|
|
||||||
|
- 创建 `src/llm/hooks.rs`
|
||||||
|
- `llm.rs` 添加 `pub mod hooks;`
|
||||||
|
- `LlmCycle` 新增字段和方法
|
||||||
|
- `submit()` 中插入钩子触发点
|
||||||
|
- `cargo check` 验证
|
||||||
|
|
||||||
|
### Step 4: StreamEvents
|
||||||
|
|
||||||
|
- `Cargo.toml` 添加 `tokio-stream`
|
||||||
|
- 创建 `src/llm/stream.rs`
|
||||||
|
- `llm.rs` 添加 `pub mod stream;`
|
||||||
|
- `LlmProvider` 添加 `chat_stream()`(默认回退)
|
||||||
|
- `OpenaiProvider` 实现 SSE 解析
|
||||||
|
- `LlmCycle` 添加 `submit_stream()`
|
||||||
|
- `cargo check` 验证
|
||||||
|
|
||||||
|
### Step 5: Auto-compaction
|
||||||
|
|
||||||
|
- 创建 `src/llm/compact.rs`
|
||||||
|
- `llm.rs` 添加 `pub mod compact;`
|
||||||
|
- `LlmCycle` 新增字段和方法
|
||||||
|
- `submit()` 开头插入压缩检查
|
||||||
|
- `cargo check` 验证
|
||||||
|
|
||||||
|
### Step 6: 收尾
|
||||||
|
|
||||||
|
- `cargo clippy` — 无警告
|
||||||
|
- `cargo build --release` — 完整构建
|
||||||
|
- 检查所有新公开 API 有 `///` 注释
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 风险评估
|
||||||
|
|
||||||
|
| 风险 | 概率 | 影响 | 缓解措施 |
|
||||||
|
|------|------|------|---------|
|
||||||
|
| SSE 解析边界情况 | 中 | 中 | 参考 `reqwest` 的 chunked 响应处理;先用简单的 `lines()` 方式逐行读取 |
|
||||||
|
| token 估算不准 | 中 | 低 | 仅用于触发微压缩的阈值判断,保守估算即可;后续可接入 tiktoken-rs |
|
||||||
|
| 钩子阻断语义复杂 | 低 | 中 | PreRequest 阻断后返回明确错误消息;其他事件点只读不阻断 |
|
||||||
|
| 与后续 Phase 冲突 | 低 | 高 | 保持接口向后兼容,全用可选集成(Option) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 验收标准
|
||||||
|
|
||||||
|
1. `cargo check` 编译通过
|
||||||
|
2. `cargo clippy` 无警告
|
||||||
|
3. 4 个模块文件存在且路径正确
|
||||||
|
4. `ProviderRegistry` 支持注册/查找/默认 Provider
|
||||||
|
5. `HookExecutor` 支持 4 个事件点注册与触发
|
||||||
|
6. `StreamEvents` 支持从 SSE chunk 解析为语义事件
|
||||||
|
7. `Auto-compaction` 支持 token 估算与微压缩
|
||||||
|
8. `LlmCycle` 向后兼容(新增字段全为 Option,不影响现有代码)
|
||||||
@@ -0,0 +1,611 @@
|
|||||||
|
# Phase 1: Prompt Engineering — 方案设计
|
||||||
|
|
||||||
|
> 定稿日期:2026-06-02
|
||||||
|
|
||||||
|
## 背景与目标
|
||||||
|
|
||||||
|
AG Core Phase 0(Foundation)已完成 LLM 调用周期的全部基础设施。Phase 1 的目标是补齐**提示词工程**能力,提供提示词的组合、模板化与优化能力,使其能直接服务于 Phase 2(工具系统)和 Phase 4(Agent 运行时)。
|
||||||
|
|
||||||
|
**目标**:实现 `PromptTemplate`(模板引擎)和 `PromptComposer`(提示词组合器)两个核心组件,覆盖变量插值、条件渲染、多消息序列拼接等场景。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 需求分析
|
||||||
|
|
||||||
|
### 功能需求
|
||||||
|
|
||||||
|
| 模块 | 需求 | 验收条件 |
|
||||||
|
|------|------|---------|
|
||||||
|
| `PromptTemplate` | 支持变量插值 `{{ var }}` | 渲染后正确替换所有变量 |
|
||||||
|
| `PromptTemplate` | 支持条件渲染 `{{#if var}}...{{/if}}` | 变量非空/非 false 时渲染块 |
|
||||||
|
| `PromptTemplate` | 支持列表循环 `{{#each list}}...{{/each}}` | 遍历渲染集合元素 |
|
||||||
|
| `PromptTemplate` | 支持嵌套模板引用 `{{> template_name }}` | 引用已注册的模板片段 |
|
||||||
|
| `PromptTemplate` | 支持部分转义(原样输出 `{{ literal }}`) | 提供 raw block 语法 |
|
||||||
|
| `PromptComposer` | 按角色拼接 system/user/assistant 消息序列 | 输出 `Vec<OpenaiChatMessage>` |
|
||||||
|
| `PromptComposer` | 支持插入预编译的 PromptTemplate | 组合器中混合使用静态文本和模板 |
|
||||||
|
| `PromptComposer` | 支持从已有消息列表扩展 | 接收 `Vec<OpenaiChatMessage>` 作为初始状态 |
|
||||||
|
| `PromptComposer` | 支持多模态 ContentPart 构建(图片/音频/文件) | `user_content()`/`system_content()` 等接受 `OpenaiContentPart` |
|
||||||
|
| `PromptComposer` | 支持 Developer 角色(o1 系列模型) | `developer()` / `developer_template()` / `developer_content()` |
|
||||||
|
| `PromptComposer` | 支持 Tool 角色(工具执行结果回传) | `tool()` / `tool_content()` 接受 `tool_call_id` |
|
||||||
|
| `PromptComposer` | 支持 `name` 字段设置(多角色区分) | `with_name()` 链式方法为上一消息设置名称 |
|
||||||
|
| `PromptComposer` | 支持消息序列合法性验证 | `validate_messages()` 检查顺序约束与角色交替 |
|
||||||
|
| `PromptTemplateRegistry` | 模板注册表:按名称注册、查找、文件加载 | 从字符串/文件注册模板,按名称渲染 |
|
||||||
|
| `PromptTemplateRegistry` | 支持延迟编译模式 | `register_lazy()` 存储原始字符串,首次渲染时编译 |
|
||||||
|
| `PromptTemplate` | 实现 `Display` trait | 输出原始模板字符串(用于日志和调试) |
|
||||||
|
| `TemplateContext` | 支持从 JSON 构造 | `from_json()` 递归转换 `serde_json::Value` 为 `TemplateValue` |
|
||||||
|
|
||||||
|
### 非功能需求
|
||||||
|
|
||||||
|
- 所有公开 API 必须带 `///` 文档注释
|
||||||
|
- 无新增 `unwrap()` 调用
|
||||||
|
- **零运行时依赖**(不使用 tera、askama 等模板引擎 crate)
|
||||||
|
- 模板引擎失败时返回结构化错误(`PromptError`)
|
||||||
|
- 与现有 `OpenaiChatMessage` / `ChatRequest` 类型自然集成
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 方案设计
|
||||||
|
|
||||||
|
### 模块结构
|
||||||
|
|
||||||
|
```
|
||||||
|
src/
|
||||||
|
prompt.rs # prompt 模块根:声明子模块 + 重导出公共 API
|
||||||
|
prompt/
|
||||||
|
template.rs # PromptTemplate — 模板引擎
|
||||||
|
template/
|
||||||
|
registry.rs # PromptTemplateRegistry — 模板注册表
|
||||||
|
composer.rs # PromptComposer — 提示词组合器
|
||||||
|
error.rs # PromptError — 错误类型
|
||||||
|
```
|
||||||
|
|
||||||
|
`prompt.rs` 根模块声明:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// prompt.rs
|
||||||
|
pub mod composer;
|
||||||
|
pub mod error;
|
||||||
|
pub mod template;
|
||||||
|
|
||||||
|
pub use composer::PromptComposer;
|
||||||
|
pub use error::PromptError;
|
||||||
|
pub use template::{PromptTemplate, PromptTemplateRegistry};
|
||||||
|
```
|
||||||
|
|
||||||
|
`lib.rs` 添加:
|
||||||
|
|
||||||
|
```diff
|
||||||
|
pub mod llm;
|
||||||
|
+pub mod prompt;
|
||||||
|
```
|
||||||
|
|
||||||
|
### 1. 模板引擎选择:自建轻量
|
||||||
|
|
||||||
|
**决策**:不使用 `tera` / `askama` / `maud` / `minijinja` 等第三方模板 crate。
|
||||||
|
|
||||||
|
**理由**:
|
||||||
|
- Phase 1 模板需求极其简单(变量插值 + 条件 + 列表循环),不需要 Jinja2/Handlebars 全能力
|
||||||
|
- 无依赖 = 编译更快、无安全面、版本冲突为 0
|
||||||
|
- 自建 50-80 行核心逻辑即可覆盖所有需求
|
||||||
|
- Roadmap 估算 400 行总代码,60 行模板引擎足够
|
||||||
|
|
||||||
|
**语法设计**(参考 Handlebars 子集):
|
||||||
|
|
||||||
|
```
|
||||||
|
{{ variable_name }} → 变量插值
|
||||||
|
{{#if var}}...{{/if}} → 条件渲染(var 存在且非空)
|
||||||
|
{{#if var}}...{{else}}...{{/if}} → 条件 + 否则
|
||||||
|
{{#each list}} {{item}} {{/each}} → 列表循环
|
||||||
|
{{#raw}} {{literal}} {{/raw}} → 原始块(不解析内部模板语法)
|
||||||
|
{{> template_name}} → 引用已注册的嵌套模板
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. PromptTemplate — 模板引擎
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// prompt/template.rs
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use crate::prompt::error::PromptError;
|
||||||
|
|
||||||
|
/// 渲染上下文中使用的值类型。
|
||||||
|
pub enum TemplateValue {
|
||||||
|
String(String),
|
||||||
|
Bool(bool),
|
||||||
|
Array(Vec<TemplateValue>),
|
||||||
|
Object(HashMap<String, TemplateValue>),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `TemplateValue` 自动转换(提升 `ctx.insert("name", "Alice")` 的易用性)。
|
||||||
|
impl From<String> for TemplateValue { ... }
|
||||||
|
impl From<&str> for TemplateValue { ... }
|
||||||
|
impl From<bool> for TemplateValue { ... }
|
||||||
|
|
||||||
|
/// 模板变量上下文。
|
||||||
|
pub struct TemplateContext {
|
||||||
|
vars: HashMap<String, TemplateValue>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TemplateContext {
|
||||||
|
pub fn new() -> Self;
|
||||||
|
|
||||||
|
/// 插入变量(支持 `&str` / `String` / `bool` 自动转换)。
|
||||||
|
pub fn insert(&mut self, key: impl Into<String>, value: impl Into<TemplateValue>);
|
||||||
|
|
||||||
|
pub fn get(&self, key: &str) -> Option<&TemplateValue>;
|
||||||
|
|
||||||
|
/// 从 `serde_json::Value` 递归构造(支持嵌套 Object/Array)。
|
||||||
|
pub fn from_json(value: &serde_json::Value) -> Result<Self, PromptError>;
|
||||||
|
|
||||||
|
/// 从 `HashMap` 构造(适用于配置加载场景)。
|
||||||
|
pub fn from_map(map: HashMap<String, TemplateValue>) -> Self;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 预编译的模板。
|
||||||
|
pub struct PromptTemplate {
|
||||||
|
/// 原始模板字符串(用于 debug)。
|
||||||
|
raw: String,
|
||||||
|
/// 编译后的 AST 片段。
|
||||||
|
fragments: Vec<Fragment>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 编译后的 AST 节点(内部枚举)。
|
||||||
|
enum Fragment {
|
||||||
|
Literal(String),
|
||||||
|
Variable { name: String },
|
||||||
|
If { condition: String, body: Vec<Fragment>, else_body: Vec<Fragment> },
|
||||||
|
Each { variable: String, body: Vec<Fragment> },
|
||||||
|
Raw(String),
|
||||||
|
Include(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PromptTemplate {
|
||||||
|
/// 从模板字符串编译。
|
||||||
|
pub fn compile(template: &str) -> Result<Self, PromptError>;
|
||||||
|
|
||||||
|
/// 使用上下文渲染。
|
||||||
|
pub fn render(&self, ctx: &TemplateContext) -> Result<String, PromptError>;
|
||||||
|
|
||||||
|
/// 注册可引用的子模板。
|
||||||
|
pub fn register_partial(&mut self, name: &str, template: PromptTemplate);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `Display` 输出原始模板字符串,便于日志和调试。
|
||||||
|
impl fmt::Display for PromptTemplate {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(f, "{}", self.raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**编译流程**:
|
||||||
|
1. 逐字符扫描模板字符串
|
||||||
|
2. 遇到 `{{` 时解析指令类型(variable/if/each/raw/include)
|
||||||
|
3. 生成 `Fragment` AST 列表
|
||||||
|
4. 返回 `PromptTemplate { raw, fragments }`
|
||||||
|
|
||||||
|
**渲染流程**:
|
||||||
|
1. 遍历 `fragments`
|
||||||
|
2. `Literal` → 直接追加
|
||||||
|
3. `Variable { name }` → `ctx.get(name)` → 追加字符串值
|
||||||
|
4. `If { condition, body, else_body }` → 判断 `ctx.get(condition)` 是否存在且真 → 递归渲染 body/else_body
|
||||||
|
5. `Each { variable, body }` → `ctx.get(variable)` 转为数组 → 为每个元素设置 `item` 变量 → 递归渲染 body
|
||||||
|
6. `Raw(text)` → 原样追加(不解析 `{{ }}`)
|
||||||
|
7. `Include(name)` → 查找已注册的 partials → 递归渲染
|
||||||
|
|
||||||
|
### 3. PromptTemplateRegistry — 模板注册表
|
||||||
|
|
||||||
|
提供轻量的模板管理能力,按名称注册、查找、从文件加载,适用于管理多个 Agent 的提示词模板。
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// prompt/template/registry.rs
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::path::Path;
|
||||||
|
use crate::prompt::error::PromptError;
|
||||||
|
use crate::prompt::template::{PromptTemplate, TemplateContext};
|
||||||
|
|
||||||
|
/// 内部存储的模板(支持延迟编译)。
|
||||||
|
enum StoredTemplate {
|
||||||
|
Compiled(PromptTemplate),
|
||||||
|
Raw(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 模板注册表——管理多模板实例。
|
||||||
|
pub struct PromptTemplateRegistry {
|
||||||
|
templates: HashMap<String, StoredTemplate>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PromptTemplateRegistry {
|
||||||
|
pub fn new() -> Self;
|
||||||
|
|
||||||
|
/// 从模板字符串编译并注册(立即编译)。
|
||||||
|
pub fn register(&mut self, name: &str, template: &str) -> Result<(), PromptError>;
|
||||||
|
|
||||||
|
/// 延迟编译注册:只存储原始字符串,首次渲染时编译。
|
||||||
|
/// 适合模板数量多但并非全部立即使用的场景。
|
||||||
|
pub fn register_lazy(&mut self, name: &str, template: &str);
|
||||||
|
|
||||||
|
/// 从文件读取并编译注册。
|
||||||
|
pub fn register_file(&mut self, name: &str, path: &Path) -> Result<(), PromptError>;
|
||||||
|
|
||||||
|
/// 获取已注册的模板(延迟编译的模板在此首次编译)。
|
||||||
|
pub fn get(&mut self, name: &str) -> Result<&PromptTemplate, PromptError>;
|
||||||
|
|
||||||
|
/// 按名称渲染(`{{> name }}` 引用时自动查找)。
|
||||||
|
pub fn render(&mut self, name: &str, ctx: &TemplateContext) -> Result<String, PromptError>;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**设计约束**:
|
||||||
|
- 不设全局单例,用户自行创建和持有
|
||||||
|
- 不引入文件系统监听、热加载、序列化等复杂能力
|
||||||
|
- 注册表内部模板可用于解析 `{{> partial_name }}` 子模板引用
|
||||||
|
- 用户仍可单独持有 `PromptTemplate` 实例,不强制使用注册表
|
||||||
|
|
||||||
|
### 4. PromptComposer — 提示词组合器
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// prompt/composer.rs
|
||||||
|
|
||||||
|
use crate::llm::types::message::{OpenaiChatMessage, OpenaiContentPart, ContentField};
|
||||||
|
use crate::prompt::error::PromptError;
|
||||||
|
use crate::prompt::template::{PromptTemplate, TemplateContext};
|
||||||
|
|
||||||
|
/// 提示词组合器——构建多角色消息序列。
|
||||||
|
pub struct PromptComposer {
|
||||||
|
messages: Vec<OpenaiChatMessage>,
|
||||||
|
pending_name: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PromptComposer {
|
||||||
|
/// 创建一个空的组合器。
|
||||||
|
pub fn new() -> Self;
|
||||||
|
|
||||||
|
/// 从已有的消息列表初始化。
|
||||||
|
pub fn from_messages(messages: Vec<OpenaiChatMessage>) -> Self;
|
||||||
|
|
||||||
|
// ===== 纯文本消息 =====
|
||||||
|
|
||||||
|
/// 添加一条纯文本 system 消息。
|
||||||
|
pub fn system(mut self, text: impl Into<String>) -> Self;
|
||||||
|
|
||||||
|
/// 添加一条纯文本 user 消息。
|
||||||
|
pub fn user(mut self, text: impl Into<String>) -> Self;
|
||||||
|
|
||||||
|
/// 添加一条纯文本 assistant 消息。
|
||||||
|
pub fn assistant(mut self, text: impl Into<String>) -> Self;
|
||||||
|
|
||||||
|
/// 添加一条纯文本 developer 消息(o1 系列模型使用)。
|
||||||
|
pub fn developer(mut self, text: impl Into<String>) -> Self;
|
||||||
|
|
||||||
|
/// 添加一条 Tool 消息(工具执行结果回传)。
|
||||||
|
pub fn tool(mut self, tool_call_id: impl Into<String>, content: impl Into<String>) -> Self;
|
||||||
|
|
||||||
|
// ===== 模板消息 =====
|
||||||
|
|
||||||
|
/// 使用模板和上下文渲染后添加为 user 消息。
|
||||||
|
pub fn user_template(
|
||||||
|
mut self,
|
||||||
|
template: &PromptTemplate,
|
||||||
|
ctx: &TemplateContext,
|
||||||
|
) -> Result<Self, PromptError>;
|
||||||
|
|
||||||
|
/// 使用模板和上下文渲染后添加为 system 消息。
|
||||||
|
pub fn system_template(
|
||||||
|
mut self,
|
||||||
|
template: &PromptTemplate,
|
||||||
|
ctx: &TemplateContext,
|
||||||
|
) -> Result<Self, PromptError>;
|
||||||
|
|
||||||
|
/// 使用模板和上下文渲染后添加为 assistant 消息。
|
||||||
|
pub fn assistant_template(
|
||||||
|
mut self,
|
||||||
|
template: &PromptTemplate,
|
||||||
|
ctx: &TemplateContext,
|
||||||
|
) -> Result<Self, PromptError>;
|
||||||
|
|
||||||
|
/// 使用模板和上下文渲染后添加为 developer 消息。
|
||||||
|
pub fn developer_template(
|
||||||
|
mut self,
|
||||||
|
template: &PromptTemplate,
|
||||||
|
ctx: &TemplateContext,
|
||||||
|
) -> Result<Self, PromptError>;
|
||||||
|
|
||||||
|
// ===== 多模态 ContentPart =====
|
||||||
|
|
||||||
|
/// 添加一条含指定 ContentPart 的 system 消息。
|
||||||
|
pub fn system_content(mut self, part: OpenaiContentPart) -> Self;
|
||||||
|
|
||||||
|
/// 添加一条含指定 ContentPart 的 user 消息。
|
||||||
|
pub fn user_content(mut self, part: OpenaiContentPart) -> Self;
|
||||||
|
|
||||||
|
/// 添加一条含指定 ContentPart 的 assistant 消息。
|
||||||
|
pub fn assistant_content(mut self, part: OpenaiContentPart) -> Self;
|
||||||
|
|
||||||
|
/// 添加一条含指定 ContentPart 的 developer 消息。
|
||||||
|
pub fn developer_content(mut self, part: OpenaiContentPart) -> Self;
|
||||||
|
|
||||||
|
/// 添加一条含指定 ContentPart 的 Tool 消息。
|
||||||
|
pub fn tool_content(mut self, tool_call_id: impl Into<String>, part: OpenaiContentPart) -> Self;
|
||||||
|
|
||||||
|
/// 批量添加 ContentPart 作为 user 消息。
|
||||||
|
pub fn user_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self;
|
||||||
|
|
||||||
|
/// 批量添加 ContentPart 作为 system 消息。
|
||||||
|
pub fn system_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self;
|
||||||
|
|
||||||
|
/// 批量添加 ContentPart 作为 assistant 消息。
|
||||||
|
pub fn assistant_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self;
|
||||||
|
|
||||||
|
/// 批量添加 ContentPart 作为 developer 消息。
|
||||||
|
pub fn developer_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self;
|
||||||
|
|
||||||
|
// ===== 角色标识 =====
|
||||||
|
|
||||||
|
/// 为上一条添加的消息设置 `name` 字段(多 agent 系统中区分同角色实体)。
|
||||||
|
pub fn with_name(mut self, name: impl Into<String>) -> Self;
|
||||||
|
|
||||||
|
// ===== 构建 =====
|
||||||
|
|
||||||
|
/// 构建最终的消息列表。
|
||||||
|
pub fn build(self) -> Vec<OpenaiChatMessage>;
|
||||||
|
|
||||||
|
/// 构建并直接创建 ChatRequest(需搭配 model 参数)。
|
||||||
|
/// 返回的 `OpenaiChatRequest` 中 `tools`、`temperature`、`max_tokens` 等字段均为 `None`,
|
||||||
|
/// 可通过结构体更新语法补全:`OpenaiChatRequest { tools: Some(...), ..req }`。
|
||||||
|
pub fn build_request(
|
||||||
|
self,
|
||||||
|
model: impl Into<String>,
|
||||||
|
) -> crate::llm::types::request::OpenaiChatRequest;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 验证消息序列是否符合 OpenAI API 要求。
|
||||||
|
/// 检查项:Tool 消息必须紧跟在匹配的 Assistant 消息后、角色交替规则等。
|
||||||
|
pub fn validate_messages(messages: &[OpenaiChatMessage]) -> Result<(), PromptError>;
|
||||||
|
```
|
||||||
|
|
||||||
|
**Builder 模式设计**:
|
||||||
|
- `PromptComposer` 采用链式调用(builder pattern),与 Rust 生态的主流风格一致
|
||||||
|
- 每个 `system()` / `user()` / `assistant()` / `developer()` / `tool()` 方法返回 `Self`,支持连续调用
|
||||||
|
- `with_name()` 作用于上一条消息,内部通过 `pending_name: Option<String>` 暂存,push 消息时消费
|
||||||
|
- `build()` 返回 `Vec<OpenaiChatMessage>`,`build_request()` 创建完整的 `OpenaiChatRequest`
|
||||||
|
- **`ContentField` 类型约定**:所有纯文本消息(`system()` / `user()` / `assistant()` / `developer()`)统一使用 `ContentField::Array(vec![OpenaiContentPart::Text{...}])`,与 OpenAI API 非流式响应的标准格式一致
|
||||||
|
|
||||||
|
**`validate_messages()` 校验规则**:
|
||||||
|
1. `Tool` 角色的消息必须跟在 `Assistant` 角色且含 `tool_calls` 的消息之后
|
||||||
|
2. 禁止连续出现两条同角色的非 Tool 消息(system 除外)
|
||||||
|
3. 消息列表不能为空
|
||||||
|
|
||||||
|
### 5. PromptError — 错误类型
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// prompt/error.rs
|
||||||
|
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum PromptError {
|
||||||
|
#[error("模板解析错误: {0}")]
|
||||||
|
Parse(String),
|
||||||
|
|
||||||
|
#[error("渲染错误: 变量 '{0}' 未找到")]
|
||||||
|
VariableNotFound(String),
|
||||||
|
|
||||||
|
#[error("渲染错误: 引用的子模板 '{0}' 未注册")]
|
||||||
|
PartialNotFound(String),
|
||||||
|
|
||||||
|
#[error("渲染错误: '{0}' 不是数组,无法遍历")]
|
||||||
|
NotAnArray(String),
|
||||||
|
|
||||||
|
#[error("渲染递归超过最大深度限制 ({0})")]
|
||||||
|
MaxDepthReached(u8),
|
||||||
|
|
||||||
|
#[error("渲染错误: {0}")]
|
||||||
|
Render(String),
|
||||||
|
|
||||||
|
#[error("消息序列校验失败: {0}")]
|
||||||
|
InvalidSequence(String),
|
||||||
|
|
||||||
|
#[error("文件读取错误: {0}")]
|
||||||
|
Io(#[from] std::io::Error),
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6. 使用示例
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use agcore::prompt::{PromptComposer, PromptTemplate, TemplateContext};
|
||||||
|
|
||||||
|
// 编译模板
|
||||||
|
let tpl = PromptTemplate::compile(
|
||||||
|
"你是{{role}}。请回答以下问题:\n{{#if context}}参考背景:{{context}}\n{{/if}}提问:{{question}}"
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// 构建上下文
|
||||||
|
let mut ctx = TemplateContext::new();
|
||||||
|
ctx.insert("role", "资深工程师");
|
||||||
|
ctx.insert("question", "Rust 的所有权规则是什么?");
|
||||||
|
ctx.insert("context", "用户有 Java 背景");
|
||||||
|
|
||||||
|
// 使用组合器构建消息序列
|
||||||
|
let messages = PromptComposer::new()
|
||||||
|
.system("你是一个专业的 Rust 助手")
|
||||||
|
.user_template(&tpl, &ctx)?
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// 可选:直接构建 ChatRequest
|
||||||
|
let request = PromptComposer::new()
|
||||||
|
.system("你是一个翻译助手")
|
||||||
|
.user("Hello, world!")
|
||||||
|
.build_request("gpt-4o");
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7. 与 LlmCycle 的集成
|
||||||
|
|
||||||
|
`PromptComposer::build()` 输出 `Vec<OpenaiChatMessage>`,但 **`LlmCycle.messages` 是私有字段**,无法直接赋值。因此**需要对 `LlmCycle` 进行扩展**,提供消息注入入口,使 Composer 能和 `LlmCycle` 的多轮对话循环协同工作。
|
||||||
|
|
||||||
|
**方案**:在 `LlmCycle` 上新增 3 个方法:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// llm/cycle.rs
|
||||||
|
|
||||||
|
impl LlmCycle {
|
||||||
|
/// 直接设置消息历史(覆盖已有消息),支持 Builder 链式调用。
|
||||||
|
pub fn with_messages(mut self, messages: Vec<Message>) -> Self;
|
||||||
|
|
||||||
|
/// 追加消息到历史尾部。
|
||||||
|
pub fn extend_messages(&mut self, messages: Vec<Message>);
|
||||||
|
|
||||||
|
/// 使用预构建消息提交(跳过自动 push user prompt)。
|
||||||
|
/// 与 submit() 不同,不自动添加 user_text(prompt)。
|
||||||
|
pub async fn submit_messages(
|
||||||
|
&mut self,
|
||||||
|
messages: Vec<Message>,
|
||||||
|
tools: Vec<ToolDefinition>,
|
||||||
|
) -> Result<ChatResponse, LlmError>;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**`submit_messages()` 与 `submit()` 的区别**:
|
||||||
|
|
||||||
|
| 维度 | `submit()` | `submit_messages()` |
|
||||||
|
|------|-----------|-------------------|
|
||||||
|
| 输入 | `prompt: String` + `tools` | `messages: Vec<Message>` + `tools` |
|
||||||
|
| 内部操作 | 自动 push `user_text(prompt)` | 不自动添加任何消息 |
|
||||||
|
| 适用场景 | 简单单轮对话 | 多轮/预构建消息序列 |
|
||||||
|
| system_prompt 处理 | 自动插入(如果无 System 消息) | 完全由调用方控制 |
|
||||||
|
|
||||||
|
**`system_prompt` 冲突处理**:`submit_messages()` 关闭 `LlmCycle` 的自动 system prompt 插入逻辑,避免与 `PromptComposer` 已构建的 System/Developer 消息重复。由调用方全权控制消息序列内容。
|
||||||
|
|
||||||
|
**使用示例**:
|
||||||
|
```rust
|
||||||
|
let messages = PromptComposer::new()
|
||||||
|
.system("你是一个专业的 Rust 助手")
|
||||||
|
.user_template(&query_tpl, &ctx)?
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let mut cycle = LlmCycle::new(provider, config)
|
||||||
|
.with_messages(messages);
|
||||||
|
|
||||||
|
let resp = cycle.submit_messages(vec![], tools).await?;
|
||||||
|
```
|
||||||
|
|
||||||
|
`PromptComposer::build_request()` 直接创建 `OpenaiChatRequest`,可用于绕过 `LlmCycle` 直接调用 `LlmProvider` 的场景。
|
||||||
|
|
||||||
|
**注意**:`PromptComposer` 模块 **不** 直接依赖 `LlmCycle`(避免 `prompt → cycle` 的强耦合)。集成方法全部在 `LlmCycle` 侧实现,保持单一职责。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 实现计划
|
||||||
|
|
||||||
|
### Step 1: 创建方案文档
|
||||||
|
|
||||||
|
创建 `docs/4-prompt-engineering.md`(即本文档)。
|
||||||
|
|
||||||
|
### Step 2: PromptError
|
||||||
|
|
||||||
|
- 创建 `src/prompt/error.rs`
|
||||||
|
- 定义 `PromptError` 枚举(Parse / Render / VariableNotFound / PartialNotFound)
|
||||||
|
|
||||||
|
### Step 3: PromptTemplate
|
||||||
|
|
||||||
|
- 创建 `src/prompt.rs`(模块根)
|
||||||
|
- 创建 `src/prompt/template.rs`
|
||||||
|
- 实现编译(`compile()`):逐字符扫描 → 生成 `Vec<Fragment>`
|
||||||
|
- 注意处理:嵌套 `#if` 栈匹配、`{{` 字面量转义、未闭合标签检测
|
||||||
|
- 实现渲染(`render()`):递归遍历 Fragment → 拼接字符串
|
||||||
|
- 定义非字符串值渲染格式:`Bool`→`"true"`/`"false"`、`Array`→JSON、`Object`→JSON
|
||||||
|
- 递归加深度限制(16 层)防止循环引用
|
||||||
|
- 支持功能:变量插值、条件渲染、列表循环、原始块、子模板引用
|
||||||
|
- 实现 `Display for PromptTemplate`(输出原始模板字符串)
|
||||||
|
- 编写 20+ 边界测试覆盖:嵌套 if/each、未闭合标签、空变量、空数组 each、递归深度超限
|
||||||
|
- 运行 `cargo test + cargo check` 验证
|
||||||
|
|
||||||
|
### Step 4: PromptTemplateRegistry
|
||||||
|
|
||||||
|
- 推荐选项:`template` 保持单文件,`PromptTemplateRegistry` 合并同文件(~40 行不值得单独目录)
|
||||||
|
- 内部存储使用 `StoredTemplate` 枚举(支持 `Compiled` 和 `Raw` 两种状态)
|
||||||
|
- 实现:`register()` 立即编译、`register_lazy()` 延迟编译、`register_file()` 文件加载
|
||||||
|
- 实现 `get()` / `render()`(延迟编译的模板首次渲染时编译)
|
||||||
|
- 运行 `cargo check` 验证
|
||||||
|
|
||||||
|
### Step 5: PromptComposer
|
||||||
|
|
||||||
|
- 创建 `src/prompt/composer.rs`
|
||||||
|
- 实现 Builder 链式 API,内部维护 `messages: Vec<OpenaiChatMessage>` + `pending_name: Option<String>`
|
||||||
|
- 角色方法:`system()` / `user()` / `assistant()` / `developer()` / `tool()`
|
||||||
|
- 纯文本消息统一使用 `ContentField::Array([Text])`
|
||||||
|
- `tool()` 需传入 `tool_call_id` 和 `content`
|
||||||
|
- 模板方法:`system_template()` / `user_template()` / `assistant_template()` / `developer_template()`
|
||||||
|
- 多模态方法:`*_content()`(单个 ContentPart)和 `*_contents()`(批量)
|
||||||
|
- `with_name()`:作用于上一条消息的 `name` 字段
|
||||||
|
- `build()` / `build_request()`
|
||||||
|
- `validate_messages()`:独立的纯函数,校验 Tool→Assistant 顺序、角色交替、非空
|
||||||
|
- 运行 `cargo check` 验证
|
||||||
|
|
||||||
|
### Step 6: LlmCycle 扩展
|
||||||
|
|
||||||
|
- `cycle.rs` 新增 3 个方法:
|
||||||
|
- `with_messages(self, messages: Vec<Message>) -> Self` — 链式设置消息历史
|
||||||
|
- `extend_messages(&mut self, messages: Vec<Message>)` — 追加消息
|
||||||
|
- `submit_messages(&mut self, messages: Vec<Message>, tools: Vec<ToolDefinition>) -> Result<...>` — 预构建消息提交
|
||||||
|
- `submit_messages()` 关闭自动 system_prompt 插入(避免与 Composer 的 System/Developer 消息重复)
|
||||||
|
- 运行 `cargo check` 验证
|
||||||
|
|
||||||
|
### Step 7: lib.rs 注册
|
||||||
|
|
||||||
|
- `lib.rs` 添加 `pub mod prompt;`
|
||||||
|
- 运行 `cargo check` 验证
|
||||||
|
|
||||||
|
### Step 8: 收尾
|
||||||
|
|
||||||
|
- `cargo clippy` — 无警告
|
||||||
|
- `cargo build` — 完整构建
|
||||||
|
- 检查所有新公开 API 有 `///` 文档注释
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 术语表
|
||||||
|
|
||||||
|
| 术语 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `TemplateValue` | 模板渲染上下文中使用的值类型枚举 |
|
||||||
|
| `TemplateContext` | 模板变量上下文,持有所有变量 |
|
||||||
|
| `PromptTemplate` | 预编译的模板,持有 AST 片段列表 |
|
||||||
|
| `Fragment` | 编译后的 AST 节点(内部枚举) |
|
||||||
|
| `PromptComposer` | 提示词组合器,构建多角色消息序列 |
|
||||||
|
| `PromptError` | 提示词工程专属错误类型 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 风险评估
|
||||||
|
|
||||||
|
| 风险 | 概率 | 缓解措施 |
|
||||||
|
|------|------|---------|
|
||||||
|
| 模板语法不支持复杂场景(如嵌套 each) | 低 | 当前需求不涉及;后续可引入 tera 替换 |
|
||||||
|
| 自定义模板引擎解析有 bug | 中 | 编写单元测试覆盖所有语法分支 |
|
||||||
|
| 与未来 Prompt Optimizer 冲突 | 低 | PromptOptimizer 只修改模板/上下文,不改模板引擎接口 |
|
||||||
|
| 条件语义不明确(什么是"假"值) | 低 | 明确定义:None / false / 空字符串 / 空数组 均为假 |
|
||||||
|
| 子模板循环引用导致栈溢出 | 低 | 渲染时加递归深度限制(16 层)或已注册集合去重检测 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 验收标准
|
||||||
|
|
||||||
|
1. `cargo check` 编译通过
|
||||||
|
2. `cargo clippy` 无警告
|
||||||
|
3. 模块文件路径正确:`src/prompt.rs` + `src/prompt/{template,composer,error}.rs`
|
||||||
|
4. `PromptTemplate::compile()` 能解析含变量/条件/循环的模板
|
||||||
|
5. `PromptTemplate::render()` 正确渲染所有语法
|
||||||
|
6. `PromptTemplate` 实现 `Display` trait,输出原始模板字符串
|
||||||
|
7. `TemplateContext` 提供 `from_json()` / `from_map()` 构造方式,支持 `From<&str>` 自动转换
|
||||||
|
8. `PromptTemplateRegistry` 支持立即编译(`register()`)、延迟编译(`register_lazy()`)、文件加载(`register_file()`)
|
||||||
|
9. `PromptComposer` 支持链式调用,覆盖 System / User / Assistant / Developer / Tool 五种角色
|
||||||
|
10. `PromptComposer` 支持 `user_content()` / `system_content()` / `assistant_content()` / `developer_content()` / `tool_content()` 多模态方法
|
||||||
|
11. `PromptComposer` 支持 `with_name()` 设置消息角色标识
|
||||||
|
12. `PromptComposer::build_request()` 能创建 `OpenaiChatRequest`
|
||||||
|
13. `validate_messages()` 能校验消息序列合法性
|
||||||
|
14. `LlmCycle` 新增 `with_messages()` / `extend_messages()` / `submit_messages()` 支持 Composer 集成
|
||||||
|
15. `lib.rs` 包含 `pub mod prompt;`
|
||||||
|
16. 所有新公开 API 有文档注释
|
||||||
@@ -0,0 +1,996 @@
|
|||||||
|
# Phase 2: Tool System — 方案设计
|
||||||
|
|
||||||
|
> 定稿日期:2026-06-03
|
||||||
|
|
||||||
|
## 背景与目标
|
||||||
|
|
||||||
|
AG Core Phase 0(Foundation)已完成 LLM 调用周期基础设施,Phase 1(Prompt Engineering)已完成提示词组合与模板化。Phase 2 的目标是补齐**工具系统**能力,实现 LLM 驱动的工具定义、注册、调用、权限控制,以及 MCP 协议集成。
|
||||||
|
|
||||||
|
**核心目标**:让 LLM 能通过 `FinishReason::ToolCalls` 触发工具自动执行,并将结果回传至对话上下文,形成完整的"思考 → 调用 → 反馈"闭环。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 需求分析
|
||||||
|
|
||||||
|
### 功能需求
|
||||||
|
|
||||||
|
| 模块 | 需求 | 验收条件 |
|
||||||
|
|------|------|---------|
|
||||||
|
| `BaseTool` trait | 工具抽象接口:名称、描述、参数、执行、权限声明 | 实现 trait 后可注册到 Registry |
|
||||||
|
| `ToolRegistry` | 工具注册、发现、按名称调用 | 注册 3 个工具后能按名称查找到并执行 |
|
||||||
|
| `McpClient` | MCP 协议客户端(stdio transport) | 能启动 MCP 服务器子进程、列出工具、调用工具 |
|
||||||
|
| `PermissionChecker` | 工具执行前权限校验 | 禁止无权限的工具执行,返回结构化错误 |
|
||||||
|
| 自动 Tool 循环 | LlmCycle 收到 ToolCalls 后自动执行工具并回传 | 一个包含工具调用的对话能完整执行 2+ 轮 |
|
||||||
|
| 流式 Tool 事件 | 流式模式下发射 `ToolExecutionCompleted` 事件 | 流式调用中工具执行完成后触发对应事件 |
|
||||||
|
| 工具调用历史持久化 | 自动工具循环产生的 Tool/Assistant 消息正确追加到 `messages` | 查看 `cycle.messages()` 能获取完整工具交互轨迹 |
|
||||||
|
|
||||||
|
### 非功能需求
|
||||||
|
|
||||||
|
- 所有公开 API 必须带 `///` 文档注释
|
||||||
|
- 无新增 `unwrap()` 调用
|
||||||
|
- `BaseTool` 的 `execute()` 必须为 `async`
|
||||||
|
- 工具执行错误必须结构化为 `ToolError`,不允许 panic
|
||||||
|
- MCP 客户端超时默认 30 秒,可配置
|
||||||
|
- 自定义工具与 MCP 工具通过同一 `ToolRegistry` 管理,对 LlmCycle 透明
|
||||||
|
- 权限检查在工具执行之前,阻断后返回错误而非静默跳过
|
||||||
|
- `BaseTool::execute()` 签名必须预留扩展点(`ToolContext` 注入),确保未来 Skill/Agent 层可在不修改 trait 签名的情况下注入 session_id、cancellation_token 等上下文信息
|
||||||
|
- 自动 tool 循环应考虑 token 消耗——工具定义随每轮请求重复发送,工具结果直接追加到对话历史,需提供结果大小限制和截断策略
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 方案设计
|
||||||
|
|
||||||
|
### 模块结构
|
||||||
|
|
||||||
|
```
|
||||||
|
src/
|
||||||
|
tools.rs # tools 模块根:声明子模块 + 重导出公共 API
|
||||||
|
tools/
|
||||||
|
base.rs # BaseTool trait — 工具抽象接口
|
||||||
|
registry.rs # ToolRegistry — 工具注册表
|
||||||
|
permission.rs # PermissionChecker — 权限校验器
|
||||||
|
mcp.rs # McpClient — MCP 协议客户端
|
||||||
|
error.rs # ToolError — 工具系统错误类型
|
||||||
|
```
|
||||||
|
|
||||||
|
`tools.rs` 根模块声明:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// tools.rs
|
||||||
|
pub mod base;
|
||||||
|
pub mod error;
|
||||||
|
pub mod mcp;
|
||||||
|
pub mod permission;
|
||||||
|
pub mod registry;
|
||||||
|
|
||||||
|
pub use base::{BaseTool, ToolContext};
|
||||||
|
pub use error::ToolError;
|
||||||
|
pub use mcp::McpClient;
|
||||||
|
pub use permission::{Permission, PermissionChecker, PermissionConfig};
|
||||||
|
pub use registry::{ToolEntry, ToolInvocation, ToolRegistry};
|
||||||
|
```
|
||||||
|
|
||||||
|
`lib.rs` 添加:
|
||||||
|
|
||||||
|
```diff
|
||||||
|
pub mod llm;
|
||||||
|
pub mod prompt;
|
||||||
|
+pub mod tools;
|
||||||
|
```
|
||||||
|
|
||||||
|
### 1. BaseTool trait — 工具抽象接口
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// tools/base.rs
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::Value;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
|
use crate::tools::error::ToolError;
|
||||||
|
use crate::tools::permission::Permission;
|
||||||
|
|
||||||
|
/// 工具执行上下文 —— 携带每次执行的运行时信息。
|
||||||
|
/// 新增字段时提供默认值,不破坏已有工具实现。
|
||||||
|
pub struct ToolContext<'a> {
|
||||||
|
/// 当前对话/会话 ID,用于关联性追踪。
|
||||||
|
pub session_id: &'a str,
|
||||||
|
/// 链路追踪 ID,用于跨工具调用的耗时分布。
|
||||||
|
pub trace_id: &'a str,
|
||||||
|
/// 取消令牌,用于优雅取消正在执行的工具。
|
||||||
|
pub cancellation_token: CancellationToken,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 工具抽象接口 —— 所有工具(自定义或 MCP)最终都实现此 trait。
|
||||||
|
#[async_trait]
|
||||||
|
pub trait BaseTool: Send + Sync {
|
||||||
|
/// 工具名称(唯一标识,用于 LLM 的 tool_calls.name 匹配)。
|
||||||
|
fn name(&self) -> &str;
|
||||||
|
|
||||||
|
/// 工具描述(LLM 据此决定是否调用此工具)。
|
||||||
|
fn description(&self) -> &str;
|
||||||
|
|
||||||
|
/// 工具参数定义(JSON Schema 格式,传递给 LLM 的 tool.parameters)。
|
||||||
|
fn parameters(&self) -> Value;
|
||||||
|
|
||||||
|
/// 声明工具所需的权限列表。
|
||||||
|
fn required_permissions(&self) -> Vec<Permission> {
|
||||||
|
Vec::new()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 执行工具调用。
|
||||||
|
/// `ctx` 携带执行上下文(session_id、trace_id 等),Phase 3/4 可扩展字段而不破坏 trait 签名。
|
||||||
|
async fn execute(&self, args: Value, ctx: &ToolContext<'_>) -> Result<Value, ToolError>;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**设计说明**:
|
||||||
|
- `name()` 返回 `&str` 而非 `String`,避免每次调用克隆
|
||||||
|
- `parameters()` 返回 `serde_json::Value`,与现有 `OpenaiToolDefinition.parameters` 类型一致
|
||||||
|
- `required_permissions()` 提供默认空实现,简化无敏感操作的工具定义
|
||||||
|
- `execute()` 接收 `Value`(JSON 对象)+ `ToolContext` 作为参数,返回 `Value` 作为结果,与 OpenAI API 的 arguments/output 格式一致
|
||||||
|
- `ToolContext` 从 Phase 2 即注入 `execute()` 签名,防止后续 breaking change;新增字段用 `Option` 包裹或提供默认值
|
||||||
|
|
||||||
|
### 2. ToolRegistry — 工具注册表
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// tools/registry.rs
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
use crate::llm::types::{OpenaiToolDefinition, ToolDefinition};
|
||||||
|
use crate::tools::base::BaseTool;
|
||||||
|
use crate::tools::error::ToolError;
|
||||||
|
use crate::tools::permission::{Permission, PermissionChecker};
|
||||||
|
|
||||||
|
/// 工具调用记录 —— 用于追踪和调试。
|
||||||
|
pub struct ToolInvocation {
|
||||||
|
pub tool_name: String,
|
||||||
|
pub input: Value,
|
||||||
|
pub output: Result<Value, ToolError>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 工具注册表 —— 管理工具注册、发现、调用。
|
||||||
|
pub struct ToolRegistry {
|
||||||
|
tools: HashMap<String, Arc<dyn BaseTool>>,
|
||||||
|
permission_checker: Option<Arc<PermissionChecker>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolRegistry {
|
||||||
|
pub fn new() -> Self;
|
||||||
|
|
||||||
|
/// 设置权限检查器(可选,不设置则不检查权限)。
|
||||||
|
pub fn with_permission_checker(mut self, checker: PermissionChecker) -> Self;
|
||||||
|
|
||||||
|
/// 注册一个工具。
|
||||||
|
pub fn register(&mut self, tool: Arc<dyn BaseTool>) -> Result<(), ToolError>;
|
||||||
|
|
||||||
|
/// 批量注册工具。
|
||||||
|
pub fn register_all(&mut self, tools: Vec<Arc<dyn BaseTool>>) -> Result<(), ToolError>;
|
||||||
|
|
||||||
|
/// 注销一个工具。
|
||||||
|
pub fn unregister(&mut self, name: &str) -> Option<Arc<dyn BaseTool>>;
|
||||||
|
|
||||||
|
/// 按名称查找工具。
|
||||||
|
pub fn get(&self, name: &str) -> Option<Arc<dyn BaseTool>>;
|
||||||
|
|
||||||
|
/// 获取所有已注册工具的名称列表。
|
||||||
|
pub fn list_tools(&self) -> Vec<String>;
|
||||||
|
|
||||||
|
/// 获取所有工具的 ToolDefinition 列表(用于传递给 LLM)。
|
||||||
|
pub fn definitions(&self) -> Vec<ToolDefinition>;
|
||||||
|
|
||||||
|
/// 调用一个工具(含权限检查)。
|
||||||
|
pub async fn invoke(&self, name: &str, args: Value) -> Result<ToolInvocation, ToolError>;
|
||||||
|
|
||||||
|
/// 批量执行工具调用(并行执行互不依赖的工具)。
|
||||||
|
pub async fn invoke_all(
|
||||||
|
&self,
|
||||||
|
calls: Vec<(String, Value)>,
|
||||||
|
) -> Vec<ToolInvocation>;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**核心逻辑**:
|
||||||
|
- `invoke()`:查找工具 → 权限检查 → 执行 → 返回 `ToolInvocation`
|
||||||
|
- `invoke_all()`:对多个工具调用并行执行(使用 `tokio::join!` 或 `futures::join_all`),适用于 LLM 同时发出多个 tool_calls 的场景
|
||||||
|
- `invoke_all()` 应对每个工具执行添加超时控制(通过 `tokio::time::timeout`),超时时间由 `CycleConfig::tool_timeout_secs` 配置,默认 60 秒,防止单个工具长时间阻塞整个循环
|
||||||
|
- `definitions()`:将注册的工具批量转换为 `Vec<ToolDefinition>`,供 LlmCycle 传递 LLM
|
||||||
|
- `ToolRegistry` 不持有 `PermissionChecker` 的生命周期(使用 `Arc`),允许多个 Registry 共享同一个 Checker
|
||||||
|
|
||||||
|
**使用示例**:
|
||||||
|
```rust
|
||||||
|
let mut registry = ToolRegistry::new()
|
||||||
|
.with_permission_checker(checker);
|
||||||
|
|
||||||
|
registry.register(Arc::new(WeatherTool))?;
|
||||||
|
registry.register(Arc::new(FileReadTool))?;
|
||||||
|
|
||||||
|
// 获取 ToolDefinitions 传递给 LLM
|
||||||
|
let tools = registry.definitions();
|
||||||
|
|
||||||
|
// 收到 LLM 的 tool_calls 后执行
|
||||||
|
let calls = vec![
|
||||||
|
("get_weather".into(), json!({"city": "Beijing"})),
|
||||||
|
("read_file".into(), json!({"path": "/tmp/data.txt"})),
|
||||||
|
];
|
||||||
|
let results = registry.invoke_all(calls).await;
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. PermissionChecker — 权限校验器
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// tools/permission.rs
|
||||||
|
|
||||||
|
/// 权限级别枚举。
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
|
pub enum Permission {
|
||||||
|
/// 只读(读取文件、查询数据库等)。
|
||||||
|
Read,
|
||||||
|
/// 写入(创建/修改文件、插入数据等)。
|
||||||
|
Write,
|
||||||
|
/// 删除(删除文件、记录等)。
|
||||||
|
Delete,
|
||||||
|
/// 网络访问(HTTP 请求等)。
|
||||||
|
Network,
|
||||||
|
/// Shell 命令执行。
|
||||||
|
Shell,
|
||||||
|
/// 文件系统操作(除读/写/删之外的 FS 操作)。
|
||||||
|
FileSystem,
|
||||||
|
/// 自定义权限(可通过 namespaced 字符串扩展)。
|
||||||
|
Custom(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 权限配置。
|
||||||
|
pub struct PermissionConfig {
|
||||||
|
/// 允许的权限列表(空 = 全部允许)。
|
||||||
|
pub allowed: Vec<Permission>,
|
||||||
|
/// 拒绝的权限列表(优先级高于 allowed)。
|
||||||
|
pub denied: Vec<Permission>,
|
||||||
|
/// 是否允许未声明权限的工具执行(默认为 true)。
|
||||||
|
pub allow_unspecified: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for PermissionConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
allowed: vec![Permission::Read, Permission::Network],
|
||||||
|
denied: vec![Permission::Delete, Permission::Shell],
|
||||||
|
allow_unspecified: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 权限检查器。
|
||||||
|
pub struct PermissionChecker {
|
||||||
|
config: PermissionConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PermissionChecker {
|
||||||
|
pub fn new(config: PermissionConfig) -> Self;
|
||||||
|
|
||||||
|
/// 检查指定权限是否允许执行。
|
||||||
|
pub fn check(&self, tool_name: &str, permissions: &[Permission]) -> Result<(), ToolError>;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**权限判定规则**:
|
||||||
|
1. 如果权限在 `denied` 中 → 拒绝
|
||||||
|
2. 如果权限在 `allowed` 中 → 允许
|
||||||
|
3. 如果 `allowed` 非空且权限不在其中 → 拒绝(白名单模式)
|
||||||
|
4. 如果 `allowed` 为空 → 按 `allow_unspecified` 判定
|
||||||
|
|
||||||
|
### 4. McpClient — MCP 协议客户端
|
||||||
|
|
||||||
|
MCP(Model Context Protocol)是一种基于 JSON-RPC 的协议,用于 LLM 与外部工具系统通信。Phase 2 实现其 **最小可行子集**,优先实现 stdio transport。
|
||||||
|
|
||||||
|
> **传输方式说明**:MCP 协议版本 2025-03-26 定义了两种标准传输——`stdio` 和 `Streamable HTTP`。原有的 `HTTP+SSE` 传输(2024-11-05)已被官方废弃,新实现不应采用。`Streamable HTTP` 通过单一 HTTP 端点同时支持 JSON 响应和 SSE 流式升级,是 HTTP 场景的推荐方案。
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// tools/mcp.rs
|
||||||
|
|
||||||
|
use std::process::{Child, Command, Stdio};
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
|
||||||
|
use serde_json::Value;
|
||||||
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||||
|
use tokio::process::{ChildStdin, ChildStdout, Command as TokioCommand};
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
use crate::tools::base::BaseTool;
|
||||||
|
use crate::tools::error::ToolError;
|
||||||
|
|
||||||
|
/// MCP 协议版本。
|
||||||
|
const MCP_VERSION: &str = "2025-03-26";
|
||||||
|
|
||||||
|
/// MCP 传输方式。
|
||||||
|
pub enum McpTransport {
|
||||||
|
/// 通过子进程 stdin/stdout 通信。
|
||||||
|
Stdio {
|
||||||
|
command: String,
|
||||||
|
args: Vec<String>,
|
||||||
|
},
|
||||||
|
/// Streamable HTTP 传输(MCP 2025-03-26 引入,替代已废弃的 HTTP+SSE)。
|
||||||
|
/// 客户端通过单一 HTTP 端点与 MCP Server 通信,支持 JSON 和 SSE 流式响应。
|
||||||
|
StreamableHttp {
|
||||||
|
url: String,
|
||||||
|
headers: Option<Vec<(String, String)>>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
/// MCP 子进程运行时状态(connect() 后创建)。
|
||||||
|
struct ChildProcessState {
|
||||||
|
child: tokio::process::Child,
|
||||||
|
stdin: tokio::io::BufWriter<tokio::process::ChildStdin>,
|
||||||
|
/// 等待响应的请求映射(id → oneshot sender)。
|
||||||
|
pending: HashMap<u64, tokio::sync::oneshot::Sender<Result<Value, ToolError>>>,
|
||||||
|
next_id: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// MCP 客户端 —— 与 MCP 服务器通信。
|
||||||
|
pub struct McpClient {
|
||||||
|
transport: McpTransport,
|
||||||
|
server_name: String,
|
||||||
|
/// 已初始化的工具列表(缓存)。
|
||||||
|
tools: Vec<McpTool>,
|
||||||
|
/// 是否已初始化。
|
||||||
|
initialized: AtomicBool,
|
||||||
|
/// 超时时间(秒)。
|
||||||
|
timeout_secs: u64,
|
||||||
|
/// 子进程运行时状态(connect() 后创建,close() 后取回)。
|
||||||
|
process: Option<tokio::sync::Mutex<ChildProcessState>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// MCP 服务器暴露的工具(缓存结构)。
|
||||||
|
struct McpTool {
|
||||||
|
name: String,
|
||||||
|
description: Option<String>,
|
||||||
|
input_schema: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl McpClient {
|
||||||
|
/// 创建一个 MCP 客户端。
|
||||||
|
pub fn new(server_name: impl Into<String>, transport: McpTransport) -> Self;
|
||||||
|
|
||||||
|
/// 设置超时时间。
|
||||||
|
pub fn with_timeout(mut self, secs: u64) -> Self;
|
||||||
|
|
||||||
|
/// 连接并初始化(发送 initialize 请求,获取服务器能力声明)。
|
||||||
|
/// 启动子进程,创建 ChildProcessState(含 reader task)。
|
||||||
|
pub async fn connect(&mut self) -> Result<(), ToolError>;
|
||||||
|
|
||||||
|
/// 列出服务器支持的工具(调用 tools/list)。
|
||||||
|
pub async fn list_tools(&mut self) -> Result<Vec<ToolDefinition>, ToolError>;
|
||||||
|
|
||||||
|
/// 调用一个工具(调用 tools/call)。
|
||||||
|
/// 通过 Mutex 获取 stdin 写入权限,发送 JSON-RPC 请求,通过 id 匹配响应。
|
||||||
|
/// reader task 持续读取 stdout,解析 JSON-RPC 响应,通过 oneshot 通知调用方。
|
||||||
|
pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, ToolError>;
|
||||||
|
|
||||||
|
/// 关闭连接(终止子进程)。
|
||||||
|
/// 发送 shutdown → 等待 5s 优雅退出 → 超时则 child.kill()。
|
||||||
|
pub async fn close(&mut self) -> Result<(), ToolError>;
|
||||||
|
|
||||||
|
/// 将 MCP 客户端转换为 BaseTool 适配器列表(用于注册到 ToolRegistry)。
|
||||||
|
pub fn into_tools(self) -> Vec<Arc<dyn BaseTool>>;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**MCP 协议交互流程**:
|
||||||
|
|
||||||
|
```
|
||||||
|
客户端 MCP 服务器
|
||||||
|
│ │
|
||||||
|
├── initialize request ──────► │
|
||||||
|
│ { protocolVersion, capabilities }
|
||||||
|
│◄──── initialize response ── │
|
||||||
|
│ { protocolVersion, serverInfo, capabilities }
|
||||||
|
│ │
|
||||||
|
├── initialized notification ──► │
|
||||||
|
│ │
|
||||||
|
├── tools/list ──────────────► │
|
||||||
|
│◄── tools/list result ──────── │
|
||||||
|
│ { tools: [{ name, description, inputSchema }] }
|
||||||
|
│ │
|
||||||
|
├── tools/call ─────────────► │
|
||||||
|
│ { name, arguments }
|
||||||
|
│◄── tools/call result ──────── │
|
||||||
|
│ { content: [{ type, text }] }
|
||||||
|
│ │
|
||||||
|
├── shutdown ───────────────► │
|
||||||
|
│◄── shutdown ───────────── │
|
||||||
|
```
|
||||||
|
|
||||||
|
**关于 stdio transport 实现**:
|
||||||
|
- 使用 `tokio::process::Command` 启动子进程
|
||||||
|
- stdin 写入 JSON-RPC 请求(每行一个 JSON 对象)
|
||||||
|
- stdout 读取 JSON-RPC 响应(使用 `BufReader` 逐行读取)
|
||||||
|
- 每个请求关联一个 `id`(递增整数),通过 `id` 匹配请求和响应
|
||||||
|
- 进程退出时自动关闭
|
||||||
|
|
||||||
|
**关于 MCP Server 的管理**:
|
||||||
|
- Phase 2 **不** 实现 MCP Server 框架,只实现 Client
|
||||||
|
- MCP Server 由外部提供(如 `npx @anthropic/mcp-server-filesystem`)
|
||||||
|
- 用户需要提供 MCP Server 的启动命令和参数
|
||||||
|
|
||||||
|
**工具缓存说明**:
|
||||||
|
- `McpClient` 在 `list_tools()` 时缓存工具列表,避免每次调用都重新请求
|
||||||
|
- 缓存假设:MCP Server 的工具列表在运行时不会频繁变更(如插件式加载场景除外)
|
||||||
|
- 如需刷新,可通过新增 `refresh_tools()` 方法或基于 TTL(如 60 秒)自动失效
|
||||||
|
|
||||||
|
### 5. ToolError — 错误类型
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// tools/error.rs
|
||||||
|
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum ToolError {
|
||||||
|
#[error("工具 '{0}' 未注册")]
|
||||||
|
NotFound(String),
|
||||||
|
|
||||||
|
#[error("工具 '{0}' 执行失败: {1}")]
|
||||||
|
ExecutionFailed(String, String),
|
||||||
|
|
||||||
|
#[error("工具 '{0}' 参数无效: {1}")]
|
||||||
|
InvalidArguments(String, String),
|
||||||
|
|
||||||
|
#[error("权限被拒绝: 工具 '{0}' 需要 {1:?} 权限")]
|
||||||
|
PermissionDenied(String, String),
|
||||||
|
|
||||||
|
#[error("MCP 协议错误: {0}")]
|
||||||
|
McpError(String),
|
||||||
|
|
||||||
|
#[error("MCP 未初始化: {0}")]
|
||||||
|
McpNotInitialized(String),
|
||||||
|
|
||||||
|
#[error("MCP 超时: {0}")]
|
||||||
|
McpTimeout(String),
|
||||||
|
|
||||||
|
#[error("IO 错误: {0}")]
|
||||||
|
Io(#[from] std::io::Error),
|
||||||
|
|
||||||
|
#[error("其他错误: {0}")]
|
||||||
|
Other(String),
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6. LlmCycle 扩展 — 自动 Tool 循环
|
||||||
|
|
||||||
|
**核心设计**:在 `LlmCycle` 中新增 `submit_with_tools()` 方法,自动处理 tool 执行循环。
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// llm/cycle.rs 新增
|
||||||
|
|
||||||
|
use crate::tools::registry::ToolRegistry;
|
||||||
|
use crate::tools::error::ToolError;
|
||||||
|
|
||||||
|
impl LlmCycle {
|
||||||
|
/// 提交消息并自动处理工具调用循环。
|
||||||
|
///
|
||||||
|
/// 流程:
|
||||||
|
/// 1. 发送请求(含工具定义)
|
||||||
|
/// 2. 检查响应中的 finish_reason
|
||||||
|
/// 3. 如果是 ToolCalls → 先 push Assistant 消息 → 执行工具 → 回传结果 → 重复 1
|
||||||
|
/// 4. 如果是 Stop/Length → push Assistant 消息 → 返回最终响应
|
||||||
|
///
|
||||||
|
/// 注意:OpenAI API 要求 tool 消息必须紧跟在对应的 Assistant(tool_calls)消息之后。
|
||||||
|
/// 因此 push 工具结果前必须先 push Assistant 响应,否则 API 拒绝请求。
|
||||||
|
pub async fn submit_with_tools(
|
||||||
|
&mut self,
|
||||||
|
prompt: String,
|
||||||
|
registry: &ToolRegistry,
|
||||||
|
) -> Result<ChatResponse, LlmError> {
|
||||||
|
let tools = registry.definitions();
|
||||||
|
let max_turns = self.config.max_tool_turns.unwrap_or(10);
|
||||||
|
let mut turn = 0;
|
||||||
|
|
||||||
|
self.messages.push(OpenaiChatMessage::user_text(prompt));
|
||||||
|
self.maybe_compact();
|
||||||
|
|
||||||
|
loop {
|
||||||
|
turn += 1;
|
||||||
|
if turn > max_turns {
|
||||||
|
return Err(LlmError::Other(format!(
|
||||||
|
"达到最大工具循环轮次 ({})",
|
||||||
|
max_turns
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
let response = self.submit_request(&tools).await?;
|
||||||
|
|
||||||
|
// 检查是否需要执行工具
|
||||||
|
let should_execute = matches!(
|
||||||
|
response.stop_reason,
|
||||||
|
Some(FinishReason::ToolCalls)
|
||||||
|
) && has_tool_calls(&response.message);
|
||||||
|
|
||||||
|
// 将 Assistant 响应(含 tool_calls 或最终文本)追加到消息历史
|
||||||
|
self.messages.push(response.message.clone());
|
||||||
|
|
||||||
|
if !should_execute {
|
||||||
|
return Ok(response);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析 tool_calls 并执行
|
||||||
|
let tool_calls = extract_tool_calls(&response.message);
|
||||||
|
let results = registry.invoke_all(tool_calls).await;
|
||||||
|
|
||||||
|
// 回传工具结果
|
||||||
|
for result in results {
|
||||||
|
let content = match &result.output {
|
||||||
|
Ok(value) => serde_json::to_string(value)
|
||||||
|
.unwrap_or_else(|e| {
|
||||||
|
tracing::warn!("工具结果序列化失败: {}", e);
|
||||||
|
"{}".to_string()
|
||||||
|
}),
|
||||||
|
Err(e) => format!("错误: {}", e),
|
||||||
|
};
|
||||||
|
|
||||||
|
self.messages.push(
|
||||||
|
OpenaiChatMessage::tool_result(result.tool_name.clone(), content)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 每轮工具执行后触发 compaction,防止 token 快速膨胀
|
||||||
|
self.maybe_compact();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 在接近上下文窗口时压缩历史消息。
|
||||||
|
fn maybe_compact(&mut self) {
|
||||||
|
if let Some(ref config) = self.compact_config
|
||||||
|
&& should_compact(&self.messages, config, &self.compact_state)
|
||||||
|
{
|
||||||
|
let freed = microcompact(&mut self.messages, config.keep_recent);
|
||||||
|
if freed > 0 {
|
||||||
|
self.compact_state.record_success();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 内部请求方法(与 submit 共享重试逻辑,但不 push user message)。
|
||||||
|
async fn submit_request(
|
||||||
|
&mut self,
|
||||||
|
tools: &[ToolDefinition],
|
||||||
|
) -> Result<ChatResponse, LlmError> {
|
||||||
|
// ... 提取 submit() 中的 request → response 逻辑(不含 user prompt push)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**关键设计决策**:
|
||||||
|
|
||||||
|
| 决策 | 选择 | 理由 |
|
||||||
|
|------|------|------|
|
||||||
|
| 循环方式 | 同步循环(单线程串行) | 工具执行依赖前一轮结果,串行更安全 |
|
||||||
|
| 最大轮次 | `CycleConfig.max_tool_turns`,独立于 `max_turns`,默认 `Some(10)` | 防止无限循环(LLM 反复调用工具)。使用独立字段避免影响现有 `submit()`/`submit_messages()` 的 `max_turns` 语义 |
|
||||||
|
| 工具并行 | `invoke_all()` 互不依赖的工具并行 | LLM 可能一次发出多个 tool_calls(parallel_tool_calls) |
|
||||||
|
| 工具超时 | `CycleConfig::tool_timeout_secs`,默认 60 | 防止单个工具长时间阻塞循环。`invoke_all()` 使用 `tokio::time::timeout` 包装 |
|
||||||
|
| 错误处理 | 工具执行错误以文本回传 LLM,而非终止循环 | LLM 可自行从错误中恢复 |
|
||||||
|
| 消息追踪 | 所有工具交互通过 `self.messages` 持久化 | 调用方能通过 `cycle.messages()` 查看完整轨迹 |
|
||||||
|
|
||||||
|
**Token 消耗分析**:
|
||||||
|
|
||||||
|
自动 tool 循环的 token 消耗主要来自三个来源:
|
||||||
|
|
||||||
|
| 来源 | 说明 | 影响程度 |
|
||||||
|
|------|------|---------|
|
||||||
|
| 工具定义重复发送 | `definitions()` 在每轮请求中携带全部工具的 JSON Schema | 注册工具数 × 平均定义大小 × 轮数。20 个工具 × 500B × 5 轮 ≈ 50KB 输入 token |
|
||||||
|
| 工具结果追加历史 | 每次工具执行结果完整追加到 `messages`,后续请求重发全部历史 | 最显著的 token 泄漏源。大结果(如向量搜索 Top-50)单次可能 ~15KB,多轮累加 |
|
||||||
|
| Value→String 序列化 | 工具结果 `serde_json::to_string()` 后 JSON 字符串膨胀 ~20-30% | 线性的常量损耗 |
|
||||||
|
|
||||||
|
**影响估算**:
|
||||||
|
|
||||||
|
| 场景 | 工具相关 token 占比 | 说明 |
|
||||||
|
|------|-------------------|------|
|
||||||
|
| 单次简单查询 | <5% | 可忽略 |
|
||||||
|
| 文件读取+分析(3-4 轮) | ~30% | 工具结果逐步累积 |
|
||||||
|
| 网页搜索+总结(3-5 轮) | ~40% | 工具结果包含页面内容 |
|
||||||
|
| 多工具数据 pipeline(5-10 轮) | ~60%+ | 需关注压缩和限制策略 |
|
||||||
|
|
||||||
|
**缓解方向**(Phase 2 不强制实现,但设计需可扩展):
|
||||||
|
- **结果大小限制**:工具执行结果超过阈值时自动截断(如 `CycleConfig::max_tool_result_bytes`)
|
||||||
|
- **自动压缩**:现有的 Auto-compaction 需感知工具消息,避免压缩掉 LLM 后续依赖的数据
|
||||||
|
- **工具定义缓存**:基础工具定义变化极少,未来可考虑客户端侧缓存(需等 provider 支持)
|
||||||
|
|
||||||
|
**错误分类与处理策略**:
|
||||||
|
|
||||||
|
工具执行错误需要区分"可恢复"和"不可恢复"两类,不可恢复的错误应终止循环而非回传 LLM:
|
||||||
|
|
||||||
|
| 错误类型 | 处理策略 | 理由 |
|
||||||
|
|---------|---------|------|
|
||||||
|
| `ToolError::ExecutionFailed` | 回传 LLM(文本) | LLM 可能下次换参数或换方式重试 |
|
||||||
|
| `ToolError::InvalidArguments` | 回传 LLM(文本) | LLM 可自动修正参数 |
|
||||||
|
| `ToolError::NotFound` | 终止循环,返回 `LlmError` | LLM 无法注册工具,重试无意义 |
|
||||||
|
| `ToolError::PermissionDenied` | 终止循环,返回 `LlmError` | 安全敏感,不应允许重试 |
|
||||||
|
| `ToolError::McpError` | 终止循环,返回 `LlmError` | MCP 链路故障,重试大概率失败 |
|
||||||
|
| `ToolError::McpTimeout` | 终止循环,返回 `LlmError` | 或可考虑重试 1 次后终止 |
|
||||||
|
| `ToolError::Io` | 终止循环,返回 `LlmError` | IO 错误通常是环境问题 |
|
||||||
|
| `ToolError::Other` | 回传 LLM(文本) | 兜底,保守回传 |
|
||||||
|
|
||||||
|
实现上可在 `ToolError` 上添加 `is_recoverable()` 方法,或在 `submit_with_tools()` 中通过 `match` 分支判断。
|
||||||
|
|
||||||
|
**submit_request() 重构说明**:
|
||||||
|
|
||||||
|
提取 `submit_request()` 作为 `submit_with_tools()` 的内部方法时,需确保不影响现有方法的行为。重构后的方法职责矩阵:
|
||||||
|
|
||||||
|
| 方法 | Push user msg | Compaction | Retry | Call provider | Handle response |
|
||||||
|
|------|:---:|:---:|:---:|:---:|:---:|
|
||||||
|
| `submit()` | ✅ | ✅ | ✅ | → `submit_request()` | ✅ |
|
||||||
|
| `submit_messages()` | ❌ | ✅ | ✅ | → `submit_request()` | ✅ |
|
||||||
|
| `submit_with_tools()` | ✅ | ✅ | ✅ | → `submit_request()` | ✅* |
|
||||||
|
| `submit_request()` | ❌ | ❌ | ✅ | ✅ | ✅ |
|
||||||
|
|
||||||
|
*`submit_with_tools()` 在 `submit_request()` 返回后额外检查 `ToolCalls`,执行工具后递归调用自身。
|
||||||
|
|
||||||
|
**流式模式支持**:
|
||||||
|
|
||||||
|
`submit_stream()` 的增强方案:新增 `submit_stream_with_tools()`,在流式事件层面支持自动 tool 循环。
|
||||||
|
|
||||||
|
> **实现复杂度提示**:流式 tool 循环需要自定义 `Stream` 实现 + 内部状态机(`Streaming` → `ExecutingTools` → `Finished`)。每一轮需要:消费当前流 → 收集事件 → 检测 `TurnComplete(ToolCalls)` → 执行工具 → 发射 `ToolExecutionCompleted` → 发起新流 → 继续 yield。不能用简单的 `stream!` 宏实现。
|
||||||
|
>
|
||||||
|
> 建议 Phaes 3 再实现 `submit_stream_with_tools()`,Phase 2 只实现非流式的 `submit_with_tools()`。如果 Phase 2 需要可先返回 "not yet implemented" 错误。
|
||||||
|
|
||||||
|
```rust
|
||||||
|
impl LlmCycle {
|
||||||
|
pub async fn submit_stream_with_tools(
|
||||||
|
&mut self,
|
||||||
|
prompt: String,
|
||||||
|
registry: &ToolRegistry,
|
||||||
|
) -> Result<Pin<Box<dyn Stream<Item = StreamEvent> + Send>>, LlmError> {
|
||||||
|
// 1. 使用 submit_stream() 获取初始事件流
|
||||||
|
// 2. 监听 TurnComplete { reason: ToolCalls }
|
||||||
|
// 3. 触发时:通过 ToolRegistry 执行工具
|
||||||
|
// 4. 发射 ToolExecutionCompleted 事件(由 submit_stream_with_tools 负责,非底层 stream parser)
|
||||||
|
// 5. 将工具结果注入 messages
|
||||||
|
// 6. 自动发起下一轮请求(递归)
|
||||||
|
// 7. 直到 finish_reason 为 Stop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**事件发射时序**:
|
||||||
|
```
|
||||||
|
submit_stream_with_tools("查天气")
|
||||||
|
│
|
||||||
|
├─ AssistantTextDelta "我来查一下北京的天气..." ← 底层 stream parser 发射
|
||||||
|
├─ ToolExecutionStarted { tool_name, input, id } ← submit_stream_with_tools 发射
|
||||||
|
├─ TurnComplete { reason: ToolCalls } ← 底层 stream parser 发射
|
||||||
|
│
|
||||||
|
├── [自动] 执行工具 get_weather({city:"北京"})
|
||||||
|
│
|
||||||
|
├─ ToolExecutionCompleted { tool_name, output, ... } ← submit_stream_with_tools 发射
|
||||||
|
│
|
||||||
|
├─ AssistantTextDelta "北京今天 22°C" ← 底层 stream parser 发射
|
||||||
|
├─ TurnComplete { reason: Stop } ← 底层 stream parser 发射
|
||||||
|
│
|
||||||
|
└─ (流结束)
|
||||||
|
|
||||||
|
**事件发射职责划分**:底层 `parse_chunk_stream()` 负责 LLM 原生事件(`AssistantTextDelta`、`TurnComplete`);`submit_stream_with_tools()` 负责工具层事件(`ToolExecutionStarted`、`ToolExecutionCompleted`),在工具执行前/后手动 `yield` 事件。
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7. 自定义工具示例
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use agcore::tools::{BaseTool, ToolError};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
struct WeatherTool;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl BaseTool for WeatherTool {
|
||||||
|
fn name(&self) -> &str { "get_weather" }
|
||||||
|
fn description(&self) -> &str { "获取指定城市的当前天气" }
|
||||||
|
fn parameters(&self) -> Value {
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": { "type": "string", "description": "城市名称" }
|
||||||
|
},
|
||||||
|
"required": ["city"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: Value) -> Result<Value, ToolError> {
|
||||||
|
let city = args["city"].as_str()
|
||||||
|
.ok_or_else(|| ToolError::InvalidArguments(
|
||||||
|
"get_weather".into(), "缺少 city 参数".into()
|
||||||
|
))?;
|
||||||
|
// 模拟天气查询
|
||||||
|
Ok(serde_json::json!({ "city": city, "temperature": 22, "unit": "°C" }))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 8. 模块依赖关系
|
||||||
|
|
||||||
|
```
|
||||||
|
tools/ 模块内部依赖:
|
||||||
|
base.rs → 无内部依赖(Permission 枚举 + ToolError)
|
||||||
|
permission → 无内部依赖
|
||||||
|
registry → base, error, permission
|
||||||
|
mcp → base, error(需通过 registry 注册)
|
||||||
|
error → 无内部依赖
|
||||||
|
|
||||||
|
跨模块依赖:
|
||||||
|
tools/ → llm/types (ToolDefinition 类型)
|
||||||
|
llm/cycle → tools/registry (自动 tool 循环)
|
||||||
|
llm/cycle → tools/error (ToolError 转换)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 9. 未来工具化路线扩展性分析
|
||||||
|
|
||||||
|
> 本节回答"当前设计是否足以支撑未来常规工具、MCP、Skill、记忆等统一走工具调用路线"。
|
||||||
|
|
||||||
|
#### 设计目标
|
||||||
|
|
||||||
|
未来所有 Agent 可调用的能力(常规工具、MCP 工具、Skill、记忆操作)都应通过 `BaseTool` trait 统一暴露给 LLM,`ToolRegistry` 作为唯一的工具发现和调用入口,对 `LlmCycle` 透明。
|
||||||
|
|
||||||
|
#### 各场景支持度评估
|
||||||
|
|
||||||
|
| 场景 | 当前支持度 | 关键瓶颈 |
|
||||||
|
|------|-----------|---------|
|
||||||
|
| 常规工具(天气/计算器) | ✅ 直接可行 | 无 |
|
||||||
|
| MCP 工具(McpClient→BaseTool 适配器) | ✅ 可行 | 适配器模式优雅,MCP 流式/进度能力被 `Value→Value` 约束 |
|
||||||
|
| Memory CRUD(store/recall/forget/update) | ⚠️ 基本可行 | 检索分页、大量结果返回需额外处理 |
|
||||||
|
| 长时运行工具(数据集查询、文件上传) | ❌ 不可行 | 无进度汇报、无 cancellation 机制 |
|
||||||
|
| 多轮确认工具("是否冻结账户?"审批流程) | ❌ 不可行 | 单次调用→单次返回,无法表达"反问→确认"模式 |
|
||||||
|
| Skill 编排(多步骤组合、嵌套执行) | ❌ 不可行 | 无上下文传播(跨步骤传递中间结果)、无工具组合原语 |
|
||||||
|
| Agent 按场景筛选工具子集 | ⚠️ 部分可行 | 无 tag/category 筛选机制 |
|
||||||
|
|
||||||
|
#### 关键扩展点
|
||||||
|
|
||||||
|
**A. `BaseTool::execute()` 签名——预留 `ToolContext` 注入**
|
||||||
|
|
||||||
|
`BaseTool` 是公开 trait,一旦用户实现并发布 crate,后续 breaking change 成本极高。当前签名:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
async fn execute(&self, args: Value) -> Result<Value, ToolError>;
|
||||||
|
```
|
||||||
|
|
||||||
|
未来扩展路径——新增 `ToolContext` 参数,携带执行上下文:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
async fn execute(&self, args: Value, ctx: &ToolContext<'_>) -> Result<Value, ToolError>;
|
||||||
|
```
|
||||||
|
|
||||||
|
`ToolContext` 初始应包含的字段(Phase 2 实现时不必全部实现,但签名需预留参数位置):
|
||||||
|
|
||||||
|
| 字段 | 用途 | 引入阶段 |
|
||||||
|
|------|------|---------|
|
||||||
|
| `session_id: &str` | 追踪一次对话中所有工具调用的关联性 | Phase 2 |
|
||||||
|
| `trace_id: &str` | 链路追踪,跨工具调用的耗时分布 | Phase 2 |
|
||||||
|
| `cancellation_token: CancellationToken` | 优雅取消正在执行的工具 | Phase 2 |
|
||||||
|
| `progress: Option<UnboundedSender<ProgressEvent>>` | 进度汇报(数据处理到 50%) | Phase 3 |
|
||||||
|
| `shared_state: Option<&HashMap<String, Value>>` | Skill 跨步骤传递中间结果 | Phase 4 |
|
||||||
|
|
||||||
|
这样 Skill/Agent 层在 Phase 4 引入时,`execute` 签名不必改,只需在 `ToolContext` 中增加字段。
|
||||||
|
|
||||||
|
**B. `ToolRegistry` 内部结构——引入 `ToolEntry` 元数据**
|
||||||
|
|
||||||
|
当前内部是 `HashMap<String, Arc<dyn BaseTool>>`,未来扩展为:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct ToolEntry {
|
||||||
|
pub tool: Arc<dyn BaseTool>,
|
||||||
|
pub tags: Vec<String>,
|
||||||
|
pub category: String, // "memory", "data", "communication" 等
|
||||||
|
pub version: Option<String>,
|
||||||
|
pub stats: ToolStats, // 调用次数、平均耗时
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
对应的筛选 API:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub fn find_by_tag(&self, tag: &str) -> Vec<&ToolEntry>;
|
||||||
|
pub fn find_by_category(&self, category: &str) -> Vec<&ToolEntry>;
|
||||||
|
pub fn groups(&self) -> HashMap<&str, Vec<&ToolEntry>>;
|
||||||
|
```
|
||||||
|
|
||||||
|
**C. 工具返回模式——从单一 `Value` 到 `ToolOutput` 枚举**
|
||||||
|
|
||||||
|
当前返回类型 `Result<Value, ToolError>` 只能表达"一次性完整返回"。未来根据需要引入多模式输出:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub enum ToolOutput {
|
||||||
|
/// 一次性返回完整结果
|
||||||
|
Final(Value),
|
||||||
|
/// 通过 channel 逐步流式输出结果
|
||||||
|
Streamed { initial: Value, rx: Receiver<Value> },
|
||||||
|
/// 需要 LLM 进一步确认后再继续
|
||||||
|
AwaitingInput { context: Value, prompt: String },
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
| 返回模式 | 场景示例 |
|
||||||
|
|---------|---------|
|
||||||
|
| `Final(Value)` | 天气查询、文件读取 |
|
||||||
|
| `Streamed { initial, rx }` | 向量搜索 Top-100 逐批返回 |
|
||||||
|
| `AwaitingInput { context, prompt }` | "检测到可疑交易,是否冻结?" |
|
||||||
|
|
||||||
|
#### 各能力的引入时序
|
||||||
|
|
||||||
|
```
|
||||||
|
Phase 2(当前实现)
|
||||||
|
├─ BaseTool trait (Value→Value, 但签名预留 Context 参数位)
|
||||||
|
├─ ToolRegistry (HashMap<String, ToolEntry> + tag/category 筛选)
|
||||||
|
├─ PermissionChecker / McpClient / ToolError
|
||||||
|
├─ submit_with_tools() / submit_stream_with_tools()
|
||||||
|
└─ ToolContext { session_id, trace_id, cancellation_token }
|
||||||
|
|
||||||
|
Phase 3(Memory 工具化)
|
||||||
|
├─ MemoryStore trait(扩展 BaseTool)
|
||||||
|
├─ memory_store / memory_recall / memory_search 等作为工具注册
|
||||||
|
└─ ToolContext.progress 支持(分批返回检索结果)
|
||||||
|
|
||||||
|
Phase 4(Agent + Skill + 编排)
|
||||||
|
├─ ToolContext.shared_state 支持(跨步骤传递中间结果)
|
||||||
|
├─ ToolOutput 枚举支持(如需要流式/确认模式)
|
||||||
|
├─ ToolChain / ToolSelector 工具组合原语
|
||||||
|
└─ Skill 机制(多步骤编排 + 内部状态)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 已识别但推迟的设计决策
|
||||||
|
|
||||||
|
| 决策 | 推迟原因 | 何时需要 |
|
||||||
|
|------|---------|---------|
|
||||||
|
| `ToolOutput` 枚举 | Phase 2 的所有场景(常规工具/MCP)用 `Value` 足够 | Phase 4 Agent 编排或长时工具 |
|
||||||
|
| 工具 DAG 调度 | Agent 场景后才需要复杂编排 | Phase 4 |
|
||||||
|
| Skill 机制 | 需要先有 Agent 使用工具的实践经验 | Phase 4 |
|
||||||
|
| 工具调用审计持久化 | 可先通过 Hook 点实现简单日志 | Phase 4 |
|
||||||
|
| 用户授权(运行时弹窗确认) | `PermissionChecker` 只做静态策略判定,不处理运行时交互。用户授权属于交互流程,应作为 `ToolOutput::AwaitingInput` 由上层 UI/Agent 层实现 | Phase 4 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 实现计划
|
||||||
|
|
||||||
|
### Step 1: 创建方案文档
|
||||||
|
|
||||||
|
创建 `docs/5-tool-system.md`(即本文档)。
|
||||||
|
|
||||||
|
### Step 2: ToolError
|
||||||
|
|
||||||
|
- 创建 `src/tools/error.rs`
|
||||||
|
- 定义 `ToolError` 枚举(NotFound / ExecutionFailed / InvalidArguments / PermissionDenied / McpError / McpTimeout / Io / Other)
|
||||||
|
- 运行 `cargo check` 验证
|
||||||
|
|
||||||
|
### Step 3: Permission
|
||||||
|
|
||||||
|
- 创建 `src/tools/permission.rs`
|
||||||
|
- 定义 `Permission` 枚举 + `PermissionConfig` + `PermissionChecker`
|
||||||
|
- 编写权限判定逻辑(白名单/黑名单/未指定策略)
|
||||||
|
- 编写 5+ 边界测试覆盖:白名单模式、黑名单模式、空列表、自定义权限冲突
|
||||||
|
- 运行 `cargo test` 验证
|
||||||
|
|
||||||
|
### Step 4: BaseTool trait
|
||||||
|
|
||||||
|
- 创建 `src/tools/base.rs`
|
||||||
|
- 定义 `BaseTool` trait(name / description / parameters / required_permissions / execute)
|
||||||
|
- 定义 `ToolContext` 结构体(session_id / trace_id / cancellation_token),注入 `execute()` 作为第二个参数
|
||||||
|
- 创建 `src/tools.rs` 模块根,声明子模块,重导出公共 API
|
||||||
|
- `lib.rs` 添加 `pub mod tools;`
|
||||||
|
- 编写 1 个 MockTool 测试工具并验证 trait 实现
|
||||||
|
- 运行 `cargo check` 验证
|
||||||
|
|
||||||
|
### Step 5: ToolRegistry
|
||||||
|
|
||||||
|
- 创建 `src/tools/registry.rs`
|
||||||
|
- 定义 `ToolInvocation` 结构体 + `ToolEntry` 元数据包装(tool + tags + category + stats)+ `ToolRegistry`
|
||||||
|
- 实现核心方法:register / get / list / definitions / invoke / invoke_all / find_by_tag / find_by_category
|
||||||
|
- `invoke_all()` 使用 `futures::future::join_all` + `tokio::time::timeout` 并行执行互不依赖的工具(每工具独立超时)
|
||||||
|
- `definitions()` 将 `HashMap` 中的工具转换为 `Vec<ToolDefinition>`
|
||||||
|
- `ToolRegistry` 不支持运行时并发注册(setup 阶段一次性构建),如需热注册由调用方通过 `Arc<RwLock<ToolRegistry>>` 包装
|
||||||
|
- 编写 8+ 测试覆盖:注册冲突、空注册表查找、单次调用、批量并行调用、工具执行失败
|
||||||
|
- 运行 `cargo test` 验证
|
||||||
|
|
||||||
|
### Step 6: LlmCycle 扩展(自动 Tool 循环)
|
||||||
|
|
||||||
|
- 新增 `cycle_submit.rs` 子模块(或直接在 `cycle.rs` 中扩增,取决于代码量)
|
||||||
|
- 提取 `submit_request()` 内部方法(将 submit() 中的 request→response 逻辑独立),同时重构 `submit_messages()` 以复用同一路径
|
||||||
|
- 实现 `submit_with_tools()` 方法:
|
||||||
|
- 循环:submit_request → push Assistant 消息 → 检查 finish_reason → 调用 registry.invoke_all → push tool_results → 重复
|
||||||
|
- 在 push tool_results **之前**先 push Assistant(tool_calls)消息(OpenAI API 要求)
|
||||||
|
- `max_tool_turns` 控制(独立于 `max_turns`),达到上限返回错误
|
||||||
|
- 不可恢复的错误(NotFound、PermissionDenied、McpError)终止循环
|
||||||
|
- 可恢复的错误(ExecutionFailed、InvalidArguments)以文本回传 LLM
|
||||||
|
- 每轮执行后触发 `maybe_compact()` 防止 token 膨胀
|
||||||
|
- `submit_stream_with_tools()` 方法:
|
||||||
|
- Phase 2 标记为未实现(返回 `LlmError::Other("流式 tool 循环将在后续版本中支持")`)
|
||||||
|
- 实际实现推迟到 Phase 3(需要自定义 `ToolStream` 状态机)
|
||||||
|
- 更新 `CycleConfig`:
|
||||||
|
- 新增 `max_tool_turns: Option<u32>`,默认 `Some(10)`(不影响 `max_turns` 语义)
|
||||||
|
- 新增 `tool_timeout_secs: u64`,默认值 60
|
||||||
|
- 新增 `max_tool_result_bytes: Option<usize>`,默认 `Some(65536)`(限制单次工具结果大小)
|
||||||
|
- 编写 3+ 集成测试:单轮 tool 调用、多轮 tool 调用、达到 max_tool_turns 终止
|
||||||
|
- 运行 `cargo test` 验证
|
||||||
|
|
||||||
|
### Step 7: McpClient(MCP 协议客户端)
|
||||||
|
|
||||||
|
- 创建 `src/tools/mcp.rs`
|
||||||
|
- 实现 JSON-RPC 消息结构(Request / Response / Error / Notification)
|
||||||
|
- 定义 `ChildProcessState` 结构体,包含运行时字段:`child`/`stdin`/`pending: HashMap<u64, oneshot::Sender>`/`next_id: u64`
|
||||||
|
- reader task 使用 `tokio::select!` 同时监听 stdout 和 cancellation token
|
||||||
|
- `call_tool()` 通过 Mutex 获取 stdin 写入权限,通过 id 匹配响应
|
||||||
|
- 子进程意外退出时通知所有 pending 请求
|
||||||
|
- 实现 stdio transport:
|
||||||
|
- `connect()`:启动子进程,创建 ChildProcessState,发送 initialize 请求
|
||||||
|
- `list_tools()`:调用 tools/list,缓存结果
|
||||||
|
- `call_tool()`:调用 tools/call,解析响应
|
||||||
|
- `close()`:发送 shutdown → 等待 5s 优雅退出 → 超时则 child.kill()
|
||||||
|
- `StreamableHttp` transport 预留枚举变体,当前返回 "not implemented" 错误,不在 Phase 2 实现
|
||||||
|
- 实现 `into_tools()`:将 MCP 工具转换为 `Vec<Arc<dyn BaseTool>>` 适配器
|
||||||
|
- 设置 30 秒默认超时
|
||||||
|
- 编写 MCP 协议消息序列化/反序列化测试 + 模拟子进程集成测试
|
||||||
|
- 运行 `cargo test` 验证
|
||||||
|
|
||||||
|
### Step 8: 收尾
|
||||||
|
|
||||||
|
- 更新 `docs/roadmap.md` 标记 Phase 2 完成
|
||||||
|
- `cargo clippy` — 无警告
|
||||||
|
- `cargo build` — 完整构建
|
||||||
|
- 检查所有新公开 API 有 `///` 文档注释
|
||||||
|
- `cargo test` — 所有测试通过
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 术语表
|
||||||
|
|
||||||
|
| 术语 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `BaseTool` | 工具抽象接口,所有工具需实现此 trait |
|
||||||
|
| `ToolRegistry` | 工具注册表,管理工具注册、发现、调用 |
|
||||||
|
| `ToolInvocation` | 工具调用记录,包含输入、输出和执行结果 |
|
||||||
|
| `Permission` | 权限级别枚举(Read/Write/Delete/Network/Shell 等) |
|
||||||
|
| `PermissionChecker` | 权限校验器,执行前判定是否允许 |
|
||||||
|
| `McpClient` | MCP 协议客户端,通过 stdio 与 MCP Server 通信 |
|
||||||
|
| `ToolDefinition` | 传递给 LLM 的工具定义(同 `OpenaiToolDefinition`) |
|
||||||
|
| 自动 Tool 循环 | LlmCycle 自动执行 LLM 请求的工具调用并回传结果 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 风险评估
|
||||||
|
|
||||||
|
| 风险 | 概率 | 缓解措施 |
|
||||||
|
|------|------|---------|
|
||||||
|
| MCP 协议规范变化 | 中 | 只实现最小子集(initialize/list_tools/call_tool),封装在 `mcp.rs` 中便于适配 |
|
||||||
|
| MCP 子进程异常退出 | 中 | 实现超时机制 + 错误恢复;进程退出时自动标记为不可用 |
|
||||||
|
| 工具执行死循环(LLM 反复调用工具) | 中 | `max_turns` 硬限制,达到上限后终止循环 |
|
||||||
|
| JSON-RPC 消息竞争(stdio 双工) | 中 | 请求和响应通过 `id` 字段匹配,使用 `Mutex` 保护写操作 + `HashMap<u64, OneshotSender>` 等待响应,实现复杂度高于接口示意 |
|
||||||
|
| 权限配置过于复杂 | 低 | PermissionConfig 提供合理默认值(允许 Read/Network,拒绝 Delete/Shell),简单场景无需自定义 |
|
||||||
|
| 工具调用参数类型不匹配 | 低 | `execute()` 接收 `Value`,由实现方自行校验;通过 `ToolError::InvalidArguments` 返回结构化错误 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 验收标准
|
||||||
|
|
||||||
|
1. `cargo check` 编译通过
|
||||||
|
2. `cargo clippy` 无警告
|
||||||
|
3. 模块文件路径正确:`src/tools.rs` + `src/tools/{base,registry,permission,mcp,error}.rs`
|
||||||
|
4. `BaseTool` trait 可被自定义工具实现,`name()` / `description()` / `parameters()` / `execute()` 四个方法正常工作
|
||||||
|
5. `ToolRegistry` 支持注册、查找、列出、注销操作
|
||||||
|
6. `ToolRegistry::definitions()` 返回正确的 `Vec<ToolDefinition>`
|
||||||
|
7. `ToolRegistry::invoke()` 执行工具前进行权限检查
|
||||||
|
8. `ToolRegistry::invoke_all()` 并行执行多个工具调用
|
||||||
|
9. `PermissionChecker` 根据配置正确判定权限(白名单/黑名单/默认策略)
|
||||||
|
10. `LlmCycle::submit_with_tools()` 收到 `FinishReason::ToolCalls` 后自动执行工具并回传结果
|
||||||
|
11. `LlmCycle::submit_with_tools()` 达到 `max_turns` 上限时终止并返回错误
|
||||||
|
12. `LlmCycle::submit_stream_with_tools()` 在流式模式下发射 `ToolExecutionCompleted` 事件
|
||||||
|
13. 自动 tool 循环产生的 Tool 消息正确追加到 `cycle.messages()`
|
||||||
|
14. `McpClient::connect()` 能完成 MCP 协议握手(initialize)
|
||||||
|
15. `McpClient::list_tools()` 能获取 MCP Server 暴露的工具列表
|
||||||
|
16. `McpClient::call_tool()` 能调用 MCP Server 的工具
|
||||||
|
17. `McpClient::into_tools()` 能生成可供 `ToolRegistry` 注册的适配器
|
||||||
|
18. 所有新公开 API 有文档注释
|
||||||
|
19. 测试覆盖率:`cargo test` 全部通过
|
||||||
|
20. `BaseTool::execute()` 签名通过 `ToolContext` 参数预留了扩展点(session_id、cancellation_token),未来 Skill/Agent 层可在不修改 trait 签名的情况下注入上下文
|
||||||
@@ -0,0 +1,655 @@
|
|||||||
|
# 记忆系统设计方案
|
||||||
|
|
||||||
|
> 设计日期:2026-06-07
|
||||||
|
> 状态:待实现
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. 背景与目标
|
||||||
|
|
||||||
|
### 1.1 背景
|
||||||
|
|
||||||
|
AG Core 已完成 Phase 0(LLM 调用周期)、Phase 1(提示词工程)、Phase 2(工具系统)。Phase 3 的目标是构建记忆系统,为 Phase 4(Agent 运行时)提供记忆存储、管理与检索能力。
|
||||||
|
|
||||||
|
### 1.2 目标
|
||||||
|
|
||||||
|
提供一套可插拔的记忆抽象层,支持以下记忆形态:
|
||||||
|
|
||||||
|
- **对话记忆(ConversationMemory)** — 管理多轮对话消息历史,支持 sliding window / 全量策略
|
||||||
|
- **知识库(KnowledgeStore)** — 基于 LLM Wiki 模式的结构化知识管理,Agent 可自主编译和维护知识页面
|
||||||
|
- **检索器(MemoryRetriever)** — 单通道关键词检索,提供统一的记忆查找入口
|
||||||
|
|
||||||
|
### 1.3 设计原则
|
||||||
|
|
||||||
|
- **不引入 embedding 依赖** — 采用 Karpathy's LLM Wiki 模式的 index + keyword 检索,替代传统向量检索
|
||||||
|
- **trait + 轻量默认实现** — 存储抽象接口提供纯内存默认实现(InMemoryStore),满足原型和测试需求
|
||||||
|
- **模块间松耦合** — 记忆系统与 LlmCycle 的集成推迟到 Phase 4 Agent Runtime,Phase 3 只定义接口和数据操作
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. 需求分析
|
||||||
|
|
||||||
|
### 2.1 功能需求
|
||||||
|
|
||||||
|
| ID | 需求 | 优先级 | 说明 |
|
||||||
|
|----|------|--------|------|
|
||||||
|
| F1 | MemoryStore 通用键值存储 | P0 | save/get/delete/list |
|
||||||
|
| F2 | 对话消息管理 | P0 | 按 session 管理,支持 sliding window / full |
|
||||||
|
| F3 | 知识页面 CRUD | P1 | 创建/更新/删除/检索知识页面 |
|
||||||
|
| F4 | 知识页面关键词检索 | P1 | 基于标题/摘要/标签的关键词匹配 |
|
||||||
|
| F5 | 知识页面索引维护 | P1 | 维护可遍历的内容目录(index) |
|
||||||
|
| F6 | 可插拔后端 | P0 | MemoryStore 通过 trait 抽象,下游可实现自定义后端 |
|
||||||
|
| F7 | 记忆淘汰 | P1 | 支持 TTL 过期淘汰、容量上限淘汰 |
|
||||||
|
| F8 | 消息条目级淘汰 | P1 | ConversationMemory 达到上限后删除最旧消息 |
|
||||||
|
|
||||||
|
### 2.2 非功能需求
|
||||||
|
|
||||||
|
| ID | 需求 | 说明 |
|
||||||
|
|----|------|------|
|
||||||
|
| NF1 | 零 embedding 依赖 | 核心库不引入任何向量数据库或 embedding 模型依赖 |
|
||||||
|
| NF2 | 错误体系完善 | MemoryError 枚举,支持 is_recoverable() 分类 |
|
||||||
|
| NF3 | 线程安全 | 所有存储实现满足 Send + Sync |
|
||||||
|
| NF4 | 异步 API | 所有 IO 操作为 async |
|
||||||
|
| NF5 | 模块化 | 各组件独立可替换 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. 方案设计
|
||||||
|
|
||||||
|
### 3.1 总体架构
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
graph TB
|
||||||
|
subgraph Retrieval["检索层"]
|
||||||
|
MR["MemoryRetriever"]
|
||||||
|
end
|
||||||
|
subgraph Logic["逻辑层"]
|
||||||
|
CM["ConversationMemory"]
|
||||||
|
KS["KnowledgeStore"]
|
||||||
|
end
|
||||||
|
subgraph Storage["存储层"]
|
||||||
|
MS["MemoryStore (trait)"]
|
||||||
|
IMS["InMemoryStore (默认)"]
|
||||||
|
end
|
||||||
|
|
||||||
|
MR --> KS
|
||||||
|
CM --> MS
|
||||||
|
KS --> MS
|
||||||
|
MS --> IMS
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.2 模块结构
|
||||||
|
|
||||||
|
```
|
||||||
|
src/
|
||||||
|
memory.rs # 模块根:pub mod + pub use 重导出
|
||||||
|
memory/
|
||||||
|
store.rs # MemoryStore trait + InMemoryStore
|
||||||
|
conversation.rs # ConversationMemory(对话管理)
|
||||||
|
knowledge.rs # KnowledgeStore(具体 struct)
|
||||||
|
retriever.rs # MemoryRetriever(单通道检索)
|
||||||
|
error.rs # MemoryError
|
||||||
|
types.rs # 核心数据类型
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.3 接口定义
|
||||||
|
|
||||||
|
#### MemoryStore — 底层存储抽象
|
||||||
|
|
||||||
|
```rust
|
||||||
|
#[async_trait]
|
||||||
|
pub trait MemoryStore: Send + Sync {
|
||||||
|
/// 保存/覆盖一个 MemoryItem(upsert 语义)。
|
||||||
|
/// - 如果 id 不存在,则插入新条目
|
||||||
|
/// - 如果 id 已存在,则覆盖旧条目
|
||||||
|
async fn save(&self, item: MemoryItem) -> Result<(), MemoryError>;
|
||||||
|
async fn get(&self, id: &str) -> Result<Option<MemoryItem>, MemoryError>;
|
||||||
|
async fn delete(&self, id: &str) -> Result<(), MemoryError>;
|
||||||
|
async fn list(&self, filter: &MemoryFilter) -> Result<Vec<MemoryItem>, MemoryError>;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### InMemoryStore — 默认实现
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct InMemoryStore {
|
||||||
|
items: Mutex<HashMap<String, MemoryItem>>,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
基于 `HashMap<String, MemoryItem>` + `Mutex`,纯内存,线程安全。
|
||||||
|
|
||||||
|
#### ConversationMemory — 对话记忆
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct ConversationMemory {
|
||||||
|
store: Arc<dyn MemoryStore>,
|
||||||
|
session_id: String,
|
||||||
|
config: ConversationMemoryConfig,
|
||||||
|
messages: Vec<OpenaiChatMessage>, // 热缓存,供 compact 直接操作
|
||||||
|
compact_state: CompactState, // 断路器状态
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ConversationMemoryConfig {
|
||||||
|
pub strategy: MemoryStrategy, // sliding_window | full
|
||||||
|
pub max_turns: usize, // sliding window 的最大轮数
|
||||||
|
pub compact_config: Option<CompactConfig>, // 复用现有压缩配置
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
- `add_message(msg)` 写入热缓存 `self.messages`,同时通过 `store.save()` 持久化到后端
|
||||||
|
- `get_history()` 优先从热缓存返回,缓存未命中时从 store 恢复
|
||||||
|
- `compact` 直接在 `self.messages` 上调用 `should_compact()` 和 `microcompact()`
|
||||||
|
- 压缩后同步回 `store`
|
||||||
|
- 使用了 `llm::types::OpenaiChatMessage` 作为内部消息类型
|
||||||
|
- 复用现有 `CompactConfig`(context_window, reserved_tokens, keep_recent)和 `CompactState`
|
||||||
|
|
||||||
|
#### KnowledgeStore — 具体 struct
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct KnowledgeStore {
|
||||||
|
store: Arc<dyn MemoryStore>,
|
||||||
|
index: Mutex<Vec<PageIndexEntry>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl KnowledgeStore {
|
||||||
|
pub fn new(store: Arc<dyn MemoryStore>) -> Self { ... }
|
||||||
|
pub async fn add_page(&self, page: KnowledgePage) -> Result<(), MemoryError> { ... }
|
||||||
|
pub async fn get_page(&self, id: &str) -> Result<Option<KnowledgePage>, MemoryError> { ... }
|
||||||
|
pub async fn update_page(&self, page: KnowledgePage) -> Result<(), MemoryError> { ... }
|
||||||
|
pub async fn delete_page(&self, id: &str) -> Result<(), MemoryError> { ... }
|
||||||
|
pub async fn search(&self, query: &str) -> Result<Vec<KnowledgePage>, MemoryError> { ... }
|
||||||
|
pub async fn get_index(&self) -> Result<Vec<PageIndexEntry>, MemoryError> { ... }
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
- `search()` 优先搜索 index(标题/摘要/标签),全文 content 搜索走 MemoryStore
|
||||||
|
- index 在 add/update/delete 时自动维护,也支持通过 `rebuild_index()` 手动重建
|
||||||
|
- `KnowledgeStore` 以 `"knowledge_{page_id}"` 格式作为 `MemoryItem.id` 前缀,前缀字符串提取为常量 `const KNOWLEDGE_PREFIX: &str = "knowledge_"`
|
||||||
|
|
||||||
|
提供 `rebuild_index()` 方法修复 index 与 store 的不同步问题:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
impl KnowledgeStore {
|
||||||
|
/// 从 MemoryStore 重建 index(修复 index 与 store 的不同步问题)
|
||||||
|
pub async fn rebuild_index(&self) -> Result<(), MemoryError> {
|
||||||
|
let items = self.store.list(&MemoryFilter {
|
||||||
|
prefix: Some("knowledge_".into()),
|
||||||
|
since: None,
|
||||||
|
offset: None,
|
||||||
|
limit: None,
|
||||||
|
}).await?;
|
||||||
|
let mut index = self.index.lock();
|
||||||
|
index.clear();
|
||||||
|
for item in items {
|
||||||
|
let page: KnowledgePage = serde_json::from_str(&item.content)
|
||||||
|
.map_err(|e| MemoryError::Serialization(e.to_string()))?;
|
||||||
|
index.push(PageIndexEntry {
|
||||||
|
id: page.id.clone(),
|
||||||
|
title: page.title,
|
||||||
|
summary: page.summary,
|
||||||
|
tags: page.tags,
|
||||||
|
updated_at: page.updated_at,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### MemoryRetriever — 简化版检索器
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct MemoryRetriever {
|
||||||
|
knowledge_store: KnowledgeStore,
|
||||||
|
config: RetrieverConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct RetrieverConfig {
|
||||||
|
pub max_results: usize, // 默认 20
|
||||||
|
pub min_score: f32, // 默认 0.1
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct RetrievalResult {
|
||||||
|
pub items: Vec<ScoredItem>,
|
||||||
|
pub query: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ScoredItem {
|
||||||
|
pub page: KnowledgePage,
|
||||||
|
pub score: f32, // TextOverlap 评分 [0.0, 1.0]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
检索流程:
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
flowchart TD
|
||||||
|
A["输入: query"]
|
||||||
|
B["1. 关键词提取(split + 过滤停用词)"]
|
||||||
|
C["2. KnowledgeStore.search(keywords)"]
|
||||||
|
D["3. TextOverlap 评分"]
|
||||||
|
E["4. 过滤 score < min_score"]
|
||||||
|
F["5. 降序排序 → 截取 top-N"]
|
||||||
|
G["6. 返回 RetrievalResult"]
|
||||||
|
|
||||||
|
A --> B
|
||||||
|
B --> C
|
||||||
|
C --> D
|
||||||
|
D --> E
|
||||||
|
E --> F
|
||||||
|
F --> G
|
||||||
|
```
|
||||||
|
|
||||||
|
关键词提取在 MemoryRetriever 内部简单实现:按空格/标点分割 → 过滤单字符和停用词 → 返回关键词列表。TextOverlap 计算 query 与页面标题/摘要/内容的 n-gram 重叠度(基于 Dice 系数)。
|
||||||
|
|
||||||
|
TextOverlap 评分基于 Dice 系数(字符 bigram):
|
||||||
|
|
||||||
|
dice(query, text) = 2 × |bigrams(query) ∩ bigrams(text)| / (|bigrams(query)| + |bigrams(text)|)
|
||||||
|
|
||||||
|
多字段加权:
|
||||||
|
score = title_dice × 0.5 + summary_dice × 0.3 + content_dice × 0.2
|
||||||
|
|
||||||
|
中文场景退化:当前版本按字符级 bigram 处理中文,不依赖分词器。
|
||||||
|
|
||||||
|
> **已知限制**:关键词提取基于空格/标点分割,对中文不做语义分词。
|
||||||
|
> 中文场景按字符 bigram 参与 TextOverlap 计算,精度低于专业分词方案。
|
||||||
|
> 如有更高精度需求,可替换 MemoryRetriever 的关键词提取逻辑。
|
||||||
|
|
||||||
|
### 3.4 核心数据类型
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct MemoryItem {
|
||||||
|
pub id: String,
|
||||||
|
pub content: String,
|
||||||
|
pub metadata: serde_json::Value,
|
||||||
|
pub created_at: time::OffsetDateTime,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct MemoryFilter {
|
||||||
|
pub prefix: Option<String>,
|
||||||
|
pub since: Option<time::OffsetDateTime>,
|
||||||
|
pub offset: Option<usize>, // 跳过前 N 条
|
||||||
|
pub limit: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct KnowledgePage {
|
||||||
|
pub id: String,
|
||||||
|
pub title: String,
|
||||||
|
pub summary: String,
|
||||||
|
pub content: String,
|
||||||
|
pub tags: Vec<String>,
|
||||||
|
pub references: Vec<String>,
|
||||||
|
pub created_at: time::OffsetDateTime,
|
||||||
|
pub updated_at: time::OffsetDateTime,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PageIndexEntry {
|
||||||
|
pub id: String,
|
||||||
|
pub title: String,
|
||||||
|
pub summary: String,
|
||||||
|
pub tags: Vec<String>,
|
||||||
|
pub updated_at: time::OffsetDateTime,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum MemoryStrategy {
|
||||||
|
SlidingWindow,
|
||||||
|
Full,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.5 错误类型
|
||||||
|
|
||||||
|
```rust
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum MemoryError {
|
||||||
|
#[error("Item not found: {0}")]
|
||||||
|
NotFound(String),
|
||||||
|
|
||||||
|
#[error("Storage error: {0}")]
|
||||||
|
Storage(String),
|
||||||
|
|
||||||
|
#[error("Serialization error: {0}")]
|
||||||
|
Serialization(String),
|
||||||
|
|
||||||
|
#[error("Invalid input: {0}")]
|
||||||
|
InvalidInput(String),
|
||||||
|
|
||||||
|
#[error("Retrieval error: {0}")]
|
||||||
|
RetrievalError(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MemoryError {
|
||||||
|
pub fn is_recoverable(&self) -> bool {
|
||||||
|
matches!(self, Self::NotFound(_) | Self::RetrievalError(_))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.6 ConversationMemory 与 compact 模块的集成
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
flowchart TD
|
||||||
|
subgraph CM["ConversationMemory"]
|
||||||
|
A["add_message(msg)"]
|
||||||
|
A1["self.messages.push(msg) ← 热缓存"]
|
||||||
|
A2["store.save(to_item(msg)) ← 冷持久化"]
|
||||||
|
B{"len(messages) > max_turns?"}
|
||||||
|
C["should_compact(&messages, &config, &state) ← 直接在热缓存上操作"]
|
||||||
|
D["microcompact(&mut messages, keep_recent) ← 复用 microcompact()"]
|
||||||
|
E["sync_to_store() ← 压缩后同步回 store"]
|
||||||
|
F["return"]
|
||||||
|
G["get_history()"]
|
||||||
|
H["从 self.messages 返回"]
|
||||||
|
I["从 store 恢复 → 重建热缓存"]
|
||||||
|
|
||||||
|
A --> A1
|
||||||
|
A1 --> A2
|
||||||
|
A2 --> B
|
||||||
|
B -->|是| C
|
||||||
|
C --> D
|
||||||
|
D --> E
|
||||||
|
E --> F
|
||||||
|
B -->|否| F
|
||||||
|
G --> H
|
||||||
|
H -->|缓存未命中| I
|
||||||
|
end
|
||||||
|
J["依赖: llm::compact::{CompactConfig, CompactState, should_compact, microcompact}"]
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. 物理存储策略
|
||||||
|
|
||||||
|
### 4.1 存储层次
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
graph TB
|
||||||
|
subgraph RetrievalLayer["检索层"]
|
||||||
|
MR["MemoryRetriever (检索 + 评分,无状态)"]
|
||||||
|
end
|
||||||
|
subgraph AbstractionLayer["存储抽象层"]
|
||||||
|
ABS["MemoryStore\n(存储抽象接口,不感知存储介质)"]
|
||||||
|
end
|
||||||
|
subgraph ImplementationLayer["实现层"]
|
||||||
|
IMS["InMemoryStore (HashMap)\n进程内 volatile\n测试/原型适用"]
|
||||||
|
CUSTOM["下游自定义实现\nFileStore / SqliteStore / RedisStore / ...\n生产环境适用"]
|
||||||
|
end
|
||||||
|
|
||||||
|
MR --> ABS
|
||||||
|
ABS --> IMS
|
||||||
|
ABS --> CUSTOM
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.2 InMemoryStore 的物理存储
|
||||||
|
|
||||||
|
| 组件 | 数据结构 | 存储位置 | 持久化 | 生命周期 |
|
||||||
|
|------|---------|---------|--------|---------|
|
||||||
|
| `InMemoryStore` | `HashMap<String, MemoryItem>` | 进程堆内存 | ❌ | 随进程销毁 |
|
||||||
|
| `KnowledgeStore` | 基于 `InMemoryStore` + `Vec<PageIndexEntry>` | 进程堆内存 | ❌ | 随进程销毁 |
|
||||||
|
|
||||||
|
**适用场景:**
|
||||||
|
- 单元测试和集成测试
|
||||||
|
- 本地快速原型开发
|
||||||
|
- 单次会话的临时 Agent
|
||||||
|
|
||||||
|
**不适合场景:**
|
||||||
|
- 生产部署
|
||||||
|
- 需要跨进程/跨会话共享记忆
|
||||||
|
- 需要数据持久化和恢复
|
||||||
|
|
||||||
|
### 4.3 持久化存储方案(下游实现)
|
||||||
|
|
||||||
|
agcore 核心库**不内置**持久化实现,用户通过实现 `MemoryStore` trait 对接所需后端:
|
||||||
|
|
||||||
|
| 后端 | 实现建议 | 适用场景 | 复杂度 |
|
||||||
|
|------|---------|---------|--------|
|
||||||
|
| **JSON 文件** | MemoryStore trait → 序列化为单文件 JSON | 单机、轻量持久化 | 低 |
|
||||||
|
| **SQLite** | MemoryStore → 关系表 | 单机、中小规模 | 中 |
|
||||||
|
| **PostgreSQL** | MemoryStore → 关系表 | 多进程共享、中等规模 | 中 |
|
||||||
|
| **Redis** | MemoryStore → Hash/JSON 类型 | 高速缓存、会话共享 | 低 |
|
||||||
|
|
||||||
|
### 4.4 对下游实现的约束
|
||||||
|
|
||||||
|
`MemoryStore` trait 对持久化实现无特殊约束:
|
||||||
|
- 方法签名不涉及文件路径、连接字符串等存储细节
|
||||||
|
- 所有方法均为 `async`,持久化实现可自由选择同步(`spawn_blocking`)或异步 driver
|
||||||
|
- 初始化参数在具体实现的构造函数中注入
|
||||||
|
|
||||||
|
### 4.5 序列化
|
||||||
|
|
||||||
|
核心类型均实现 `Serialize` / `Deserialize`(通过 `#[derive(serde)]`),便于持久化实现直接复用:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
pub struct MemoryItem { ... }
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
pub struct KnowledgePage { ... }
|
||||||
|
// ...
|
||||||
|
```
|
||||||
|
|
||||||
|
所有类型基于 `serde_json::Value` 作为 metadata 类型,不引入 protobuf/msgpack 等序列化框架。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. 淘汰策略
|
||||||
|
|
||||||
|
### 5.1 问题
|
||||||
|
|
||||||
|
所有存储组件如果不设上限,会随运行时间无限增长。ConversationMemory 当前的 sliding window 只做 tool result 压缩,不删除消息条目。
|
||||||
|
|
||||||
|
### 5.2 淘汰策略
|
||||||
|
|
||||||
|
在 `MemoryStore` trait 层提供可选的淘汰配置,上层组件按需设置:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct EvictionConfig {
|
||||||
|
pub policy: EvictionPolicy,
|
||||||
|
pub check_interval: usize, // 每写入 N 条后检查一次淘汰条件
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum EvictionPolicy {
|
||||||
|
None, // 不淘汰(默认)
|
||||||
|
Ttl { ttl_secs: u64 }, // 超过存活时间淘汰
|
||||||
|
Capacity { max_items: usize },// 超过容量淘汰最旧
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
`InMemoryStore` 在 `save()` 后检查淘汰条件:
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
flowchart TD
|
||||||
|
A["save(item)"]
|
||||||
|
B["items.insert(id, item)"]
|
||||||
|
C{"writes_since_last_check >= check_interval?"}
|
||||||
|
D{"policy 类型"}
|
||||||
|
E["Ttl → items.retain(created_at > cutoff)"]
|
||||||
|
F["Capacity → 按 created_at 升序排列,截断到 max_items"]
|
||||||
|
G["None → 不淘汰"]
|
||||||
|
|
||||||
|
A --> B
|
||||||
|
B --> C
|
||||||
|
C -->|是| D
|
||||||
|
C -->|否| G
|
||||||
|
D -->|Ttl| E
|
||||||
|
D -->|Capacity| F
|
||||||
|
D -->|None| G
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.3 各组件淘汰策略
|
||||||
|
|
||||||
|
| 组件 | 推荐策略 | 理由 |
|
||||||
|
|------|---------|------|
|
||||||
|
| **ConversationMemory** | `Capacity { max_items }` | 对话是流式的,旧消息价值递减,达到上限后淘汰最旧的消息条目 |
|
||||||
|
| **MemoryStore**(通用) | `Ttl { ttl_secs }` | 通用存储由调用方按场景决定 |
|
||||||
|
| **KnowledgeStore** | `None`(默认不淘汰) | 知识是累积的,新增不淘汰旧 |
|
||||||
|
|
||||||
|
### 5.4 ConversationMemory 淘汰行为
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
flowchart TD
|
||||||
|
A["add_message(msg)"]
|
||||||
|
B["store.save(msg)"]
|
||||||
|
C{"eviction.policy == Capacity?"}
|
||||||
|
D["history = store.list(session)"]
|
||||||
|
E{"while history.len() > max_items"}
|
||||||
|
F["oldest = history.remove(0) ← 删除最旧消息"]
|
||||||
|
G["store.delete(oldest.id)"]
|
||||||
|
H["maybe_compact() ← 复用 microcompact() 做内容压缩"]
|
||||||
|
I["return"]
|
||||||
|
|
||||||
|
A --> B
|
||||||
|
B --> C
|
||||||
|
C -->|是| D
|
||||||
|
D --> E
|
||||||
|
E -->|是| F
|
||||||
|
F --> G
|
||||||
|
G --> E
|
||||||
|
E -->|否| H
|
||||||
|
C -->|否| H
|
||||||
|
H --> I
|
||||||
|
```
|
||||||
|
|
||||||
|
两种机制分层:
|
||||||
|
- **淘汰(eviction)**:删除整条消息,控制条目总数上限
|
||||||
|
- **压缩(compaction)**:压缩剩余消息的 tool result 内容,节省 token
|
||||||
|
|
||||||
|
### 5.5 InMemoryStore 的淘汰实现
|
||||||
|
|
||||||
|
```rust
|
||||||
|
impl InMemoryStore {
|
||||||
|
pub fn with_eviction(config: EvictionConfig) -> Self { ... }
|
||||||
|
}
|
||||||
|
|
||||||
|
// 在 save() 内部:
|
||||||
|
async fn save(&self, item: MemoryItem) -> Result<(), MemoryError> {
|
||||||
|
self.items.lock().insert(item.id.clone(), item);
|
||||||
|
self.maybe_evict().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn maybe_evict(&self) {
|
||||||
|
match &self.eviction.policy {
|
||||||
|
EvictionPolicy::Ttl { ttl_secs } => {
|
||||||
|
let cutoff = Utc::now() - Duration::seconds(*ttl_secs as i64);
|
||||||
|
self.items.lock().retain(|_, v| v.created_at > cutoff);
|
||||||
|
}
|
||||||
|
EvictionPolicy::Capacity { max_items } => {
|
||||||
|
let mut items = self.items.lock();
|
||||||
|
if items.len() > *max_items {
|
||||||
|
let mut vec: Vec<_> = items.drain().collect();
|
||||||
|
vec.select_nth_unstable_by(
|
||||||
|
*max_items,
|
||||||
|
|a, b| b.1.created_at.cmp(&a.1.created_at),
|
||||||
|
);
|
||||||
|
vec.truncate(*max_items);
|
||||||
|
*items = vec.into_iter().collect();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EvictionPolicy::None => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. 实现计划
|
||||||
|
|
||||||
|
### Step 1:基础类型 + MemoryStore(含淘汰机制)
|
||||||
|
|
||||||
|
**文件**:`src/memory.rs`、`src/memory/types.rs`、`src/memory/error.rs`、`src/memory/store.rs`
|
||||||
|
|
||||||
|
- 创建 `memory.rs` + `memory/` 目录
|
||||||
|
- 定义 `MemoryItem`、`MemoryFilter`、`MemoryStrategy` 类型
|
||||||
|
- 定义 `MemoryError` 枚举
|
||||||
|
- 定义 `MemoryStore` trait 与 `InMemoryStore` 实现
|
||||||
|
- 定义 `EvictionConfig`、`EvictionPolicy`(None / Ttl / Capacity)
|
||||||
|
- `InMemoryStore.save()` 内部实现淘汰检查
|
||||||
|
- 单元测试:TTL 过期淘汰、容量上限淘汰、不淘汰(None)
|
||||||
|
- 验收:`cargo build` + `cargo test` 通过
|
||||||
|
|
||||||
|
**依赖**:time(日期时间,启用 `serde` feature)、serde(序列化)
|
||||||
|
|
||||||
|
### Step 2:ConversationMemory(含消息淘汰)
|
||||||
|
|
||||||
|
**文件**:`src/memory/conversation.rs`
|
||||||
|
|
||||||
|
- 定义 `ConversationMemoryConfig`、`MemoryStrategy`
|
||||||
|
- 实现 `ConversationMemory`
|
||||||
|
- `add_message()` → 写入 MemoryStore → 触发容量淘汰(删除最旧消息)
|
||||||
|
- 复用 `llm::compact` 的 `CompactConfig` 和 `microcompact()` 做内容压缩
|
||||||
|
- 单元测试:sliding window 消息淘汰、tool result 内容压缩、full 模式、空 session
|
||||||
|
- 验收:`cargo build` + `cargo test` 通过
|
||||||
|
|
||||||
|
**依赖**:MemoryStore(含 EvictionConfig)+ llm::compact
|
||||||
|
|
||||||
|
### Step 3:KnowledgeStore
|
||||||
|
|
||||||
|
**文件**:`src/memory/knowledge.rs`
|
||||||
|
|
||||||
|
- 定义 `KnowledgePage`、`PageIndexEntry` 类型
|
||||||
|
- 实现 `KnowledgeStore` 具体 struct(非 trait)
|
||||||
|
- 内部使用 `Arc<dyn MemoryStore>` 存储数据
|
||||||
|
- index 自动维护(add/update/delete 时同步)
|
||||||
|
- search 基于标题/摘要/标签的关键词匹配
|
||||||
|
- 单元测试:页面 CRUD、index 一致性、搜索
|
||||||
|
- 验收:`cargo build` + `cargo test` 通过
|
||||||
|
|
||||||
|
### Step 4:MemoryRetriever + 模块整合
|
||||||
|
|
||||||
|
**文件**:`src/memory/retriever.rs`、`src/memory.rs`
|
||||||
|
|
||||||
|
- 实现内存检索器 `MemoryRetriever`
|
||||||
|
- 内部关键词提取:split + 过滤停用词
|
||||||
|
- TextOverlap 评分:基于 Dice 系数计算 query 与页面的文本重叠度
|
||||||
|
- 阈值过滤 → 排序 → 截取 top-N
|
||||||
|
- 在 `memory.rs` 中用 `pub use` 分层重导出:
|
||||||
|
- 高频类型(大多数下游需要):`MemoryStore`、`InMemoryStore`、`ConversationMemory`、`KnowledgeStore`、`MemoryRetriever`、`MemoryError`
|
||||||
|
- 低频类型(配置/高级使用):`MemoryItem`、`MemoryFilter`、`MemoryStrategy`、`KnowledgePage`、`PageIndexEntry`、`EvictionConfig`、`EvictionPolicy`、`ConversationMemoryConfig`、`RetrieverConfig`、`RetrievalResult`、`ScoredItem`
|
||||||
|
- 在 `src/lib.rs` 中声明 `pub mod memory`
|
||||||
|
- 单元测试:关键词提取、TextOverlap 评分正确性、阈值过滤、排序正确性
|
||||||
|
- 集成测试:端到端检索流程
|
||||||
|
- 验收:`cargo build` + `cargo test` 通过
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. 风险评估
|
||||||
|
|
||||||
|
| 风险 | 概率 | 影响 | 缓解措施 |
|
||||||
|
|------|------|------|---------|
|
||||||
|
| KnowledgeStore 的 keyword 检索在大规模下效率低 | 中 | 中 | MemoryStore 实现可替换——下游可使用 SQLite FTS 等更高效的后端 |
|
||||||
|
| ConversationMemory 与 compact 耦合引入循环依赖 | 低 | 高 | 仅引用 `CompactConfig`(纯数据结构)和 `microcompact()`(纯函数),不引用 cycle.rs |
|
||||||
|
| time 增加依赖体积 | 低 | 低 | time 是 Rust 官方维护的时间库,体积小于 chrono |
|
||||||
|
| Phase 4 集成时发现 Memory 设计不合理 | 低 | 高 | 按最小可行接口设计,预留扩展空间 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. 验收标准
|
||||||
|
|
||||||
|
- [ ] `MemoryStore` trait + `InMemoryStore` 通过单元测试
|
||||||
|
- [ ] `EvictionConfig` 支持 None / Ttl / Capacity 三种策略
|
||||||
|
- [ ] `InMemoryStore` 在 save() 后正确执行 TTL 淘汰和容量淘汰
|
||||||
|
- [ ] `ConversationMemory` 支持 sliding window 和 full 两种策略
|
||||||
|
- [ ] `ConversationMemory` sliding window 模式下达到上限后删除最旧消息条目
|
||||||
|
- [ ] `ConversationMemory` 正确复用 `llm::compact` 的压缩逻辑
|
||||||
|
- [ ] `KnowledgeStore` 支持页面 CRUD 和 index 维护
|
||||||
|
- [ ] `MemoryRetriever` 支持基于 TextOverlap 的知识检索
|
||||||
|
- [ ] 无 embedding 相关依赖
|
||||||
|
- [ ] 模块结构:`memory.rs` + `memory/` 目录 + `pub use` 重导出
|
||||||
|
- [ ] `MemoryError` 枚举完善,支持 `is_recoverable()`
|
||||||
|
- [ ] 所有公开 API 有文档注释(`///`)
|
||||||
|
- [ ] `cargo build` 和 `cargo test` 通过
|
||||||
|
- [ ] 单个文件不超过 300 行
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 附录:与 Karpathy's LLM Wiki 的关系
|
||||||
|
|
||||||
|
本方案受 Karpathy's LLM Wiki 模式启发,但做了一些调整以适应 Agent 核心库的定位:
|
||||||
|
|
||||||
|
| Karpathy LLM Wiki | AG Core Memory System | 差异原因 |
|
||||||
|
|-------------------|----------------------|---------|
|
||||||
|
| 三层:Raw → Wiki → Schema | 三组件:MemoryStore → ConversationMemory + KnowledgeStore + MemoryRetriever | Agent 场景需要区分对话记忆和知识记忆 |
|
||||||
|
| index.md + log.md | PageIndexEntry(同 index.md)+ 无 log(Phase 4 Agent 负责) | 日志是工作流层职责,非存储层 |
|
||||||
|
| LLM Agent 全权维护 | KnowledgeStore 提供数据接口,Phase 4 Agent 编排工作流 | core 只提供存储能力,不编排 |
|
||||||
|
| 文件系统为后端 | MemoryStore trait 抽象后端 | 可插拔设计需要 trait 抽象 |
|
||||||
|
| 基于文件系统搜索 | index + keyword 检索 | 文件系统搜索不适合所有后端 |
|
||||||
@@ -0,0 +1,647 @@
|
|||||||
|
# Agent Runtime 方案设计
|
||||||
|
|
||||||
|
> 设计日期:2026-06-09
|
||||||
|
> 状态:待实施
|
||||||
|
> 关联文档:
|
||||||
|
> - `docs/note-agent-runtime-design.md` — 设计决策记录(接口签名、文件清单、决策依据)
|
||||||
|
> - `docs/note-agent-harness-references.md` — 参考项目调研(OpenClaw / Hermes / OpenHuman / OpenHarness)
|
||||||
|
> - `docs/6-memory-system.md` — Phase 3 方案
|
||||||
|
> - `docs/5-tool-system.md` — Phase 2 方案
|
||||||
|
> - `docs/roadmap.md` — 项目总 Roadmap
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. 背景与目标
|
||||||
|
|
||||||
|
### 1.1 背景
|
||||||
|
|
||||||
|
AG Core 已完成 Phase 0(LLM 调用周期)、Phase 1(提示词工程)、Phase 2(工具系统)、Phase 3(记忆系统)共 4 个 phase 的交付。`LlmCycle::submit_with_tools()` 已在 Phase 2 末实现"LLM 决策 → 工具执行 → 回传结果"的单次循环;`ConversationMemory` / `KnowledgeStore` / `MemoryRetriever` 在 Phase 3 提供了完整的记忆抽象。
|
||||||
|
|
||||||
|
当前缺一个**整合层**:把 Phase 0-3 的能力"装配"起来,对上层应用暴露"智能体"的概念。
|
||||||
|
|
||||||
|
### 1.2 目标
|
||||||
|
|
||||||
|
Phase 4 目标是提供一个**薄胶水层 + 一组 trait 抽象**,让上层应用可以基于 AG Core 构建多轮对话、任务规划等智能体行为。具体包括:
|
||||||
|
|
||||||
|
- **`Agent` trait** — 智能体的"角色"抽象(不绑定 session)
|
||||||
|
- **`AgentSession` struct** — 智能体的"会话"实例(绑定 session_id + 状态)
|
||||||
|
- **`TaskAgent` trait** — 任务型智能体的"规划/执行"抽象
|
||||||
|
- **`RuntimeBundle`** — 显式依赖注入容器,集中管理 provider/registry/hook/memory 等依赖
|
||||||
|
- **`AgentBuilder`** — 链式构造入口
|
||||||
|
- **`AgentError`** — 统一错误类型,聚合 LlmError / ToolError / MemoryError
|
||||||
|
|
||||||
|
### 1.3 设计原则
|
||||||
|
|
||||||
|
Phase 4 严格遵循以下原则,所有范围决策都基于这些原则推导:
|
||||||
|
|
||||||
|
| 原则 | 含义 | 推导 |
|
||||||
|
|------|------|------|
|
||||||
|
| **最小范围** | AG Core 是 lib crate,不是产品;不实现业务循环 | 只暴露 trait + 最小 reference impl |
|
||||||
|
| **薄胶水层** | 不在 L1 重写已经做好的能力 | 复用 `LlmCycle::submit_with_tools` 等已有 API |
|
||||||
|
| **依赖注入** | 所有运行时依赖显式打包传递 | 采用 OpenHarness `RuntimeBundle` 模式 |
|
||||||
|
| **实体/会话分离** | 同一角色可被多 session 复用 | `Agent` + `AgentSession` 两层模型 |
|
||||||
|
| **记忆弱引用** | 记忆是"被动能力",不内嵌循环 | `memory_store: Option<Arc<dyn MemoryStore>>` 弱引用 |
|
||||||
|
| **业务可注入** | Plan 拆解是业务能力,不在 core 库实现 | 暴露 `PlanParser` trait,上层注入 |
|
||||||
|
| **借鉴不照搬** | 4 个参考项目均非 Rust 实现 | 只取架构模式,不抄实现细节 |
|
||||||
|
|
||||||
|
### 1.4 与已完成的 Phase 关系
|
||||||
|
|
||||||
|
```
|
||||||
|
Phase 0 (L0/L1) ── LlmProvider / LlmCycle / Hook / Stream / Compact
|
||||||
|
Phase 1 (L2) ── PromptTemplate / PromptComposer
|
||||||
|
Phase 2 (L1) ── ToolRegistry / BaseTool / PermissionChecker / McpClient
|
||||||
|
Phase 3 (L2) ── MemoryStore / ConversationMemory / KnowledgeStore / MemoryRetriever
|
||||||
|
↑
|
||||||
|
│ 复用
|
||||||
|
│
|
||||||
|
Phase 4 (L1→L2) ── Agent trait + AgentSession + TaskAgent + RuntimeBundle(胶水层)
|
||||||
|
↓
|
||||||
|
应用层 (L4) ── 上层 crate / 二进制 / Gateway(不在 Phase 4 范围)
|
||||||
|
```
|
||||||
|
|
||||||
|
详细架构对照见 `docs/note-agent-harness-references.md` §3-5。
|
||||||
|
|
||||||
|
## 2. 需求分析
|
||||||
|
|
||||||
|
### 2.1 功能需求
|
||||||
|
|
||||||
|
| ID | 需求 | 优先级 | 说明 |
|
||||||
|
|----|------|--------|------|
|
||||||
|
| F1 | `Agent` trait 抽象 | P0 | 角色定义:name / system_prompt / 工具集 |
|
||||||
|
| F2 | `AgentSession` 会话实例 | P0 | 绑定 session_id、bundle、turn_index、cost_so_far |
|
||||||
|
| F3 | `submit_turn()` 最小 reference impl | P0 | 组装 LlmCycle → submit → 累计 cost;约 30 行 |
|
||||||
|
| F4 | `TaskAgent::run(goal)` 自主式入口 | P0 | 内部用 LLM 拆 Plan,再调用 `execute_plan` |
|
||||||
|
| F5 | `TaskAgent::execute_plan(plan)` 外部驱动式入口 | P0 | 用户预定义 Plan,逐步执行 |
|
||||||
|
| F6 | `Plan` / `Step` / `StepStatus` 数据结构 | P0 | 含 Pending / Running / Completed / Failed / Skipped 状态机 |
|
||||||
|
| F7 | `PlanParser` trait + `JsonPlanParser` 参考实现 | P0 | 注入式,上层可替换 |
|
||||||
|
| F8 | `RuntimeBundle` 依赖注入容器 | P0 | 聚合 provider/registry/hook/memory/retriever/config |
|
||||||
|
| F9 | `AgentBuilder` 链式构造 | P0 | 构建 `RuntimeBundle`,retriever 存在时自动注册为 tool |
|
||||||
|
| F10 | `AgentError` 统一错误类型 | P0 | 聚合 LlmError / ToolError / MemoryError,含 `is_recoverable()` |
|
||||||
|
| F11 | Hook 事件扩展:OnTurnStart / OnTurnEnd / OnPlanStepComplete | P0 | 在 `llm/hooks.rs` 中追加 3 个事件 + 上下文扩展 2 个字段 |
|
||||||
|
| F12 | 烟雾测试 2-3 个 | P0 | trait 可装配 / RuntimeBundle 可构造 / `submit_turn` 跑通 mock |
|
||||||
|
| F13 | `lib.rs` 导出 `pub mod agent;` | P0 | 一行 |
|
||||||
|
| F14 | 方案文档(本文件)+ 决策记录 | P0 | 已完成 |
|
||||||
|
| F15 | Roadmap 状态翻转 | P0 | 实施完成后做 |
|
||||||
|
|
||||||
|
### 2.2 非功能需求
|
||||||
|
|
||||||
|
| ID | 需求 | 说明 |
|
||||||
|
|----|------|------|
|
||||||
|
| NF1 | 不引入新外部依赖 | 仅使用 Phase 0-3 已有的 `async-trait` / `serde` / `thiserror` / `tokio` 等 |
|
||||||
|
| NF2 | 错误体系完善 | `AgentError` 聚合下层错误,含 `is_recoverable()` 分类 |
|
||||||
|
| NF3 | 线程安全 | 所有公开类型满足 `Send + Sync` |
|
||||||
|
| NF4 | 异步优先 | 涉及 IO 的 API 全部 `async` |
|
||||||
|
| NF5 | 模块化 | 各组件独立可替换,遵循"trait 抽象 + 轻量默认实现"惯例 |
|
||||||
|
| NF6 | 文档注释 | 所有公开 API 必须有 `///` 文档注释 |
|
||||||
|
| NF7 | builder 模式 | 复杂配置走 builder 链式构造 |
|
||||||
|
| NF8 | 显式依赖 | 不引入模块级全局状态,所有依赖通过参数或 bundle 注入 |
|
||||||
|
| NF9 | 不破坏现有 API | Phase 0-3 的公开 API 一字不改;`hooks.rs` 扩展为"追加变体 + 追加字段"(兼容) |
|
||||||
|
| NF10 | 最小测试覆盖 | 核心 trait 至少 1 个烟雾测试;`submit_turn` 至少 1 个 mock 测试;不强求集成测试 |
|
||||||
|
|
||||||
|
## 3. 方案设计
|
||||||
|
|
||||||
|
### 3.1 总体架构
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────────────────────────────────────────────────────────────┐
|
||||||
|
│ 应用层(不在 Phase 4 范围) │
|
||||||
|
│ ┌────────────┐ ┌────────────┐ ┌────────────┐ ┌────────────┐ │
|
||||||
|
│ │ CLI Agent │ │ Feishu Bot │ │ Web Service│ │ TUI App │ │
|
||||||
|
│ └─────┬──────┘ └─────┬──────┘ └─────┬──────┘ └─────┬──────┘ │
|
||||||
|
└─────────┼────────────────┼────────────────┼────────────────┼───────────┘
|
||||||
|
│ │ │ │
|
||||||
|
└────────────────┴────────────────┴────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌──────────────────────────────────────────────────────────────────────┐
|
||||||
|
│ Agent Runtime(Phase 4) │
|
||||||
|
│ │
|
||||||
|
│ ┌────────────────┐ ┌──────────────────┐ │
|
||||||
|
│ │ Agent trait │ 1 ──── * │ AgentSession │ │
|
||||||
|
│ │ (角色) │ │ (会话实例) │ │
|
||||||
|
│ └────────────────┘ └──────┬───────────┘ │
|
||||||
|
│ │ Arc<...> │
|
||||||
|
│ ▼ │
|
||||||
|
│ ┌──────────────────┐ │
|
||||||
|
│ │ RuntimeBundle │ │
|
||||||
|
│ │ - provider │ │
|
||||||
|
│ │ - tool_registry │ │
|
||||||
|
│ │ - hook_executor │ │
|
||||||
|
│ │ - memory_store? │ ◄─ 弱引用 │
|
||||||
|
│ │ - retriever? │ ◄─ 弱引用 │
|
||||||
|
│ │ - config │ │
|
||||||
|
│ └──────┬───────────┘ │
|
||||||
|
│ │ new() 时若 retriever 存在 │
|
||||||
|
│ ▼ │
|
||||||
|
│ ┌──────────────────┐ │
|
||||||
|
│ │ "retrieve" tool │ ◄─ 自动注册 │
|
||||||
|
│ └──────────────────┘ │
|
||||||
|
│ │
|
||||||
|
│ ┌────────────────┐ ┌──────────────────┐ ┌──────────────────┐ │
|
||||||
|
│ │ TaskAgent trait│ │ Plan/Step/Status │ │ PlanParser trait │ │
|
||||||
|
│ │ run() │ │ 状态机 │ │ JsonPlanParser │ │
|
||||||
|
│ │ execute_plan()│ │ │ │ (参考实现 ~20行) │ │
|
||||||
|
│ └────────────────┘ └──────────────────┘ └──────────────────┘ │
|
||||||
|
│ │
|
||||||
|
│ ┌────────────────┐ ┌──────────────────┐ │
|
||||||
|
│ │ AgentError │ │ AgentBuilder │ │
|
||||||
|
│ │ (聚合) │ │ (链式构造) │ │
|
||||||
|
│ └────────────────┘ └──────────────────┘ │
|
||||||
|
└──────────────────────────────────────────────────────────────────────┘
|
||||||
|
│
|
||||||
|
▼ 复用
|
||||||
|
┌──────────────────────────────────────────────────────────────────────┐
|
||||||
|
│ LLM / Tool / Prompt / Memory(Phase 0-3) │
|
||||||
|
│ LlmCycle / ProviderRegistry / ToolRegistry / PermissionChecker / │
|
||||||
|
│ HookExecutor / StreamEvents / CompactConfig / │
|
||||||
|
│ PromptTemplate / PromptComposer / │
|
||||||
|
│ MemoryStore / ConversationMemory / KnowledgeStore / MemoryRetriever│
|
||||||
|
└──────────────────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.2 接口设计
|
||||||
|
|
||||||
|
详细接口签名见 `docs/note-agent-runtime-design.md` §4,本节说明设计意图。
|
||||||
|
|
||||||
|
#### 3.2.1 `Agent` trait
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub trait Agent: Send + Sync {
|
||||||
|
fn name(&self) -> &str;
|
||||||
|
fn system_prompt(&self) -> Option<&str>;
|
||||||
|
/// 列出该 Agent 想要暴露给 LLM 的工具定义。
|
||||||
|
/// 默认实现:从 RuntimeBundle.tool_registry 取全部(最常用)。
|
||||||
|
/// 子 trait 可覆盖做白名单/过滤。
|
||||||
|
fn tool_definitions(&self, bundle: &RuntimeBundle) -> Vec<ToolDefinition>;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**设计意图**:
|
||||||
|
- `name` / `system_prompt` 是 LLM 调用必需的元数据
|
||||||
|
- `tool_definitions` 默认从 bundle 全量取,**Agent 可以在不修改 bundle 的情况下做工具白名单**——这与 Hermes 的"Skill 暴露"机制对齐
|
||||||
|
- 不在 trait 里强制 `submit_turn`——`submit_turn` 是 `AgentSession` 的方法,不应绑死在角色定义上
|
||||||
|
|
||||||
|
#### 3.2.2 `RuntimeBundle`
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct RuntimeBundle {
|
||||||
|
pub provider: Arc<dyn LlmProvider>,
|
||||||
|
pub tool_registry: Arc<ToolRegistry>,
|
||||||
|
pub hook_executor: Arc<HookExecutor>,
|
||||||
|
pub memory_store: Option<Arc<dyn MemoryStore>>, // 弱引用
|
||||||
|
pub retriever: Option<Arc<MemoryRetriever>>, // 弱引用
|
||||||
|
pub config: AgentConfig,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**设计意图**:
|
||||||
|
- 所有运行时依赖**显式打包**(OpenHarness 风格)
|
||||||
|
- `memory_store` / `retriever` 均为 `Option`——上层应用**不传也能跑**(无记忆模式)
|
||||||
|
- 当 `retriever` 存在时,`RuntimeBundle::new()` 内部自动注册一个名为 `"retrieve"` 的 tool(具体实现:在 `ToolRegistry` 里加一个 `RetrieveTool` 包装),让 LLM 在对话中**主动**调用检索能力
|
||||||
|
- `config` 集中管理所有可调参数(max_turns、max_tool_turns、session_ttl、compact_config)
|
||||||
|
|
||||||
|
#### 3.2.3 `AgentSession` 与最小 reference impl
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct AgentSession {
|
||||||
|
pub session_id: String,
|
||||||
|
pub agent_name: String,
|
||||||
|
bundle: Arc<RuntimeBundle>,
|
||||||
|
turn_index: u32,
|
||||||
|
cost_so_far: CostTracker,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AgentSession {
|
||||||
|
/// 最小 reference impl(约 30 行):
|
||||||
|
/// 1. 触发 OnTurnStart hook
|
||||||
|
/// 2. 组装 LlmCycle(注入 system_prompt + messages 历史 + tool definitions)
|
||||||
|
/// 3. submit_with_tools() 跑单轮对话
|
||||||
|
/// 4. 累计 cost
|
||||||
|
/// 5. 触发 OnTurnEnd hook
|
||||||
|
/// 6. turn_index += 1
|
||||||
|
/// 7. 返回 ChatResponse
|
||||||
|
/// 不做 memory 回写(由上层独立 task 处理)
|
||||||
|
pub async fn submit_turn(
|
||||||
|
&mut self,
|
||||||
|
user_input: impl Into<String>,
|
||||||
|
) -> Result<ChatResponse, AgentError>;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**设计意图**:
|
||||||
|
- "最小 reference impl" 只演示**最常见**的对话场景
|
||||||
|
- 业务循环(多轮策略、错误重试、记忆回写时机)由上层应用或具体的 `TaskAgent` 实现决定
|
||||||
|
- `submit_turn` 不持有 `ConversationMemory`——上层应用可独立 new 一个 `ConversationMemory`,在合适的时机(如 OnTurnEnd hook)调 `add_message`
|
||||||
|
|
||||||
|
#### 3.2.4 `TaskAgent` + `Plan` / `Step`
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct Plan {
|
||||||
|
pub id: String,
|
||||||
|
pub goal: String,
|
||||||
|
pub steps: Vec<Step>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Step {
|
||||||
|
pub index: usize,
|
||||||
|
pub description: String,
|
||||||
|
pub status: StepStatus,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum StepStatus {
|
||||||
|
Pending,
|
||||||
|
Running,
|
||||||
|
Completed(ChatResponse),
|
||||||
|
Failed(AgentError),
|
||||||
|
Skipped,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**设计意图**:
|
||||||
|
- `StepStatus` 用 enum 而非简单 bool,便于上层 UI 展示和统计
|
||||||
|
- 状态机转换:`Pending → Running → (Completed | Failed | Skipped)`,单向不可回退(重试由上层新建 Plan)
|
||||||
|
- `Plan` / `Step` 故意保持简单——不引入 `dependencies` / `parallel_group` 等高级字段(v0.3+ 再考虑)
|
||||||
|
|
||||||
|
#### 3.2.5 `PlanParser` trait + `JsonPlanParser` 参考实现
|
||||||
|
|
||||||
|
```rust
|
||||||
|
#[async_trait]
|
||||||
|
pub trait PlanParser: Send + Sync {
|
||||||
|
async fn parse(&self, raw: &str, goal: &str) -> Result<Plan, AgentError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct JsonPlanParser;
|
||||||
|
#[async_trait]
|
||||||
|
impl PlanParser for JsonPlanParser {
|
||||||
|
/// 期望 LLM 输出形如:
|
||||||
|
/// {"steps": [{"description": "..."}, ...]}
|
||||||
|
/// 的 JSON 文本。
|
||||||
|
/// 解析失败返回 AgentError::PlanParse。
|
||||||
|
async fn parse(&self, raw: &str, goal: &str) -> Result<Plan, AgentError> { /* ... */ }
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**设计意图**:
|
||||||
|
- **注入式**:上层应用可以注入自己的 `PlanParser`(如基于 XML / YAML / 自定义 DSL)
|
||||||
|
- `JsonPlanParser` 是**参考实现**,不是默认实现——上层必须显式选择
|
||||||
|
- `JsonPlanParser` 大约 20 行:`serde_json::from_str` 解析 + 字段映射
|
||||||
|
|
||||||
|
#### 3.2.6 `AgentError`
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub enum AgentError {
|
||||||
|
Llm(LlmError),
|
||||||
|
Tool(ToolError),
|
||||||
|
Memory(MemoryError),
|
||||||
|
PlanParse(String),
|
||||||
|
HookBlocked(String),
|
||||||
|
LimitExceeded(String),
|
||||||
|
Config(String),
|
||||||
|
Other(String),
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**设计意图**:
|
||||||
|
- 聚合而非包装下层错误(避免 `Box<dyn Error>` 丢失类型)
|
||||||
|
- `PlanParse` / `HookBlocked` / `LimitExceeded` / `Config` 是 Agent 层特有的错误类型
|
||||||
|
- `is_recoverable()` 根据变体类型判定(如 `Memory(_)` 可恢复、`PlanParse(_)` 不可恢复)
|
||||||
|
|
||||||
|
#### 3.2.7 `AgentConfig` + `AgentBuilder`
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct AgentConfig {
|
||||||
|
pub max_turns: u32,
|
||||||
|
pub max_tool_turns: u32,
|
||||||
|
pub session_ttl: Option<Duration>,
|
||||||
|
pub compact_config: Option<CompactConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AgentBuilder { /* ... */ }
|
||||||
|
impl AgentBuilder {
|
||||||
|
pub fn new() -> Self;
|
||||||
|
pub fn provider(self, p: Arc<dyn LlmProvider>) -> Self;
|
||||||
|
pub fn tool_registry(self, r: Arc<ToolRegistry>) -> Self;
|
||||||
|
pub fn hook_executor(self, h: Arc<HookExecutor>) -> Self;
|
||||||
|
pub fn memory_store(self, m: Arc<dyn MemoryStore>) -> Self; // 选填
|
||||||
|
pub fn retriever(self, r: Arc<MemoryRetriever>) -> Self; // 选填
|
||||||
|
pub fn config(self, c: AgentConfig) -> Self;
|
||||||
|
pub fn build(self) -> Result<RuntimeBundle, AgentError>;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**设计意图**:
|
||||||
|
- `AgentBuilder` 是**唯一**的 `RuntimeBundle` 构造入口
|
||||||
|
- 必填字段在 `build()` 时校验(`provider` / `tool_registry` / `hook_executor` 不可缺)
|
||||||
|
- `memory_store` / `retriever` 选填,对应 §3.2.2 的"无记忆模式"
|
||||||
|
|
||||||
|
### 3.3 状态机
|
||||||
|
|
||||||
|
#### 3.3.1 `StepStatus` 状态转换图
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────┐
|
||||||
|
│ Pending │ ◄── 初始状态
|
||||||
|
└──────┬──────┘
|
||||||
|
│ execute_plan() 进入
|
||||||
|
▼
|
||||||
|
┌─────────────┐
|
||||||
|
│ Running │ ◄── 触发 OnPlanStepComplete(status=Running)
|
||||||
|
└──────┬──────┘
|
||||||
|
│
|
||||||
|
┌────────────────┼────────────────┐
|
||||||
|
│ │ │
|
||||||
|
▼ ▼ ▼
|
||||||
|
┌─────────┐ ┌──────────┐ ┌──────────┐
|
||||||
|
│Completed│ │ Failed │ │ Skipped │
|
||||||
|
└─────────┘ └──────────┘ └──────────┘
|
||||||
|
触发 OnPlanStepComplete(status=Completed)
|
||||||
|
触发 OnPlanStepComplete(status=Failed)
|
||||||
|
触发 OnPlanStepComplete(status=Skipped)
|
||||||
|
```
|
||||||
|
|
||||||
|
**设计约束**:
|
||||||
|
- 状态转换**单向**(Pending → Running → 终态),不回退
|
||||||
|
- 终态(Completed / Failed / Skipped)触发 `OnPlanStepComplete` hook
|
||||||
|
- 重试由上层应用新建 `Plan` 实现(不在 `TaskAgent` 内做自动重试)
|
||||||
|
|
||||||
|
#### 3.3.2 Session 状态
|
||||||
|
|
||||||
|
`AgentSession` 的状态机比 `Step` 简单:
|
||||||
|
|
||||||
|
```
|
||||||
|
创建 (new) ──► turn_index=0 ──► submit_turn() ──► turn_index+=1 ──► ... ──► 销毁
|
||||||
|
```
|
||||||
|
|
||||||
|
`turn_index` 累加,`cost_so_far` 累加,无显式状态枚举(避免过度设计)。
|
||||||
|
|
||||||
|
### 3.4 Hook 扩展设计
|
||||||
|
|
||||||
|
在 `src/llm/hooks.rs` 中追加 3 个事件 + 2 个上下文字段:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub enum HookEvent {
|
||||||
|
// ... 现有 4 个:PreRequest / PostRequest / OnRetry / OnError ...
|
||||||
|
|
||||||
|
// 新增 3 个(Phase 4):
|
||||||
|
OnTurnStart,
|
||||||
|
OnTurnEnd,
|
||||||
|
OnPlanStepComplete,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct HookContext {
|
||||||
|
// ... 现有字段 ...
|
||||||
|
|
||||||
|
// 新增 2 个(Phase 4):
|
||||||
|
pub turn_index: Option<u32>, // OnTurnStart / OnTurnEnd 用
|
||||||
|
pub plan_step_index: Option<usize>, // OnPlanStepComplete 用
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**设计意图**:
|
||||||
|
- **不破坏现有 hook 兼容性**:3 个新事件是 enum 追加,2 个新字段是 `Option<T>` 默认 `None`
|
||||||
|
- 上层应用可通过监听 `OnTurnEnd` 实现"独立 task 回写 ConversationMemory"——呼应"记忆在独立 task 处理"原则
|
||||||
|
- `OnPlanStepComplete` 提供"步骤级别"的可观测性,与 Hermes 的"任务进度回调"对齐
|
||||||
|
|
||||||
|
### 3.5 错误体系
|
||||||
|
|
||||||
|
`AgentError` 与下层错误的关系:
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────────┐
|
||||||
|
│ AgentError │
|
||||||
|
├──────────────────┤
|
||||||
|
│ Llm(LlmError) │──► 透传 Phase 0 错误,含 is_recoverable()
|
||||||
|
│ Tool(ToolError) │──► 透传 Phase 2 错误,含 is_recoverable()
|
||||||
|
│ Memory(MemoryError)│─► 透传 Phase 3 错误
|
||||||
|
│ PlanParse(String) │─► Agent 层特有
|
||||||
|
│ HookBlocked(String)│─► Agent 层特有
|
||||||
|
│ LimitExceeded(String)│► Agent 层特有
|
||||||
|
│ Config(String) │──► Agent 层特有
|
||||||
|
│ Other(String) │──► 兜底
|
||||||
|
└──────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
is_recoverable(): 聚合判定
|
||||||
|
- Llm/Memory 可恢复(重试)
|
||||||
|
- PlanParse / Config 不可恢复(需人工介入)
|
||||||
|
- Tool / HookBlocked / LimitExceeded 按内层错误判定
|
||||||
|
```
|
||||||
|
|
||||||
|
**自动 From 转换**:通过 `#[from]` 宏实现 `From<LlmError>` / `From<ToolError>` / `From<MemoryError>`,让 `submit_turn` 内部可以用 `?` 运算符直接传播。
|
||||||
|
|
||||||
|
### 3.6 与 Phase 0-3 模块的集成
|
||||||
|
|
||||||
|
| Phase 4 组件 | 调用的下层 API | 调用位置 |
|
||||||
|
|-------------|--------------|---------|
|
||||||
|
| `AgentSession::submit_turn` | `LlmCycle::new` + `with_system_prompt` + `with_hook_executor` + `with_compact_config` + `with_messages` + `submit_with_tools` | session.rs |
|
||||||
|
| `AgentSession::submit_turn` | `CostTracker::add`(累计 cost) | session.rs |
|
||||||
|
| `RuntimeBundle::new` | `ToolRegistry::register`(注册 retrieve tool) | runtime.rs |
|
||||||
|
| `TaskAgent::execute_plan` | `AgentSession::submit_turn`(每步调一次) | task.rs |
|
||||||
|
| `JsonPlanParser::parse` | `serde_json::from_str` | task.rs |
|
||||||
|
| `AgentError::from` | `LlmError` / `ToolError` / `MemoryError` | error.rs |
|
||||||
|
| `HookContext` 扩展 | `HookEvent::OnTurnStart/End/OnPlanStepComplete` | llm/hooks.rs |
|
||||||
|
|
||||||
|
**不调用的下层 API**(明确边界):
|
||||||
|
- ❌ `ConversationMemory`(由上层独立 task 管理)
|
||||||
|
- ❌ `KnowledgeStore`(由上层独立 task 管理)
|
||||||
|
- ❌ `McpClient`(已由 `ToolRegistry` 包装)
|
||||||
|
- ❌ `StreamEvents::submit_stream`(v1 暂不暴露流式 `submit_turn`,v0.2 再说)
|
||||||
|
|
||||||
|
## 4. 实施计划
|
||||||
|
|
||||||
|
### 4.1 文件清单
|
||||||
|
|
||||||
|
#### 新增文件(7 个)
|
||||||
|
|
||||||
|
```
|
||||||
|
src/agent.rs # 模块根 + pub use 重导出
|
||||||
|
src/agent/agent.rs # Agent trait
|
||||||
|
src/agent/runtime.rs # RuntimeBundle + AgentConfig
|
||||||
|
src/agent/session.rs # AgentSession(含 submit_turn reference impl)
|
||||||
|
src/agent/task.rs # TaskAgent trait + Plan/Step + PlanParser + JsonPlanParser
|
||||||
|
src/agent/builder.rs # AgentBuilder
|
||||||
|
src/agent/error.rs # AgentError
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 修改文件(3 个)
|
||||||
|
|
||||||
|
```
|
||||||
|
src/lib.rs # + pub mod agent;
|
||||||
|
src/llm/hooks.rs # + 3 个事件变体 + 2 个 HookContext 字段
|
||||||
|
docs/roadmap.md # Phase 4 状态 ❌ 缺失 → ✅
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 关联文档(已完成 / 待写)
|
||||||
|
|
||||||
|
```
|
||||||
|
docs/note-agent-harness-references.md # ✅ 已存在
|
||||||
|
docs/note-agent-runtime-design.md # ✅ 已存在(与本文件配套)
|
||||||
|
docs/7-agent-runtime.md # ✅ 本文件
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.2 任务拆解(按依赖顺序)
|
||||||
|
|
||||||
|
| 顺序 | 任务 | 涉及文件 | 验证 |
|
||||||
|
|------|------|---------|------|
|
||||||
|
| 1 | 修改 `llm/hooks.rs` 追加 3 个事件 + 2 个字段 | `src/llm/hooks.rs` | `cargo build` 通过;现有测试不挂 |
|
||||||
|
| 2 | 新建 `agent/error.rs` 定义 `AgentError` | `src/agent/error.rs` | `cargo build` 通过 |
|
||||||
|
| 3 | 新建 `agent/agent.rs` 定义 `Agent` trait | `src/agent/agent.rs` | `cargo build` 通过 |
|
||||||
|
| 4 | 新建 `agent/runtime.rs` 定义 `RuntimeBundle` + `AgentConfig` | `src/agent/runtime.rs` | `cargo build` 通过 |
|
||||||
|
| 5 | 新建 `agent/builder.rs` 定义 `AgentBuilder` | `src/agent/builder.rs` | `cargo build` 通过 |
|
||||||
|
| 6 | 新建 `agent/session.rs` 定义 `AgentSession` + `submit_turn` | `src/agent/session.rs` | `cargo build` 通过 |
|
||||||
|
| 7 | 新建 `agent/task.rs` 定义 `TaskAgent` + `Plan` / `Step` / `PlanParser` / `JsonPlanParser` | `src/agent/task.rs` | `cargo build` 通过 |
|
||||||
|
| 8 | 新建 `src/agent.rs` 模块根 + `pub use` 重导出 | `src/agent.rs` | `cargo build` 通过 |
|
||||||
|
| 9 | 修改 `lib.rs` 导出 `pub mod agent;` | `src/lib.rs` | `cargo build` 通过 |
|
||||||
|
| 10 | 编写 2-3 个烟雾测试 | `src/agent/*.rs` 内联 | `cargo test` 通过 |
|
||||||
|
| 11 | 更新 `roadmap.md` 状态翻转 | `docs/roadmap.md` | 文档 review |
|
||||||
|
| 12 | 完整 `cargo test` 跑全量回归 | — | 所有已有测试不挂 |
|
||||||
|
|
||||||
|
### 4.3 依赖关系
|
||||||
|
|
||||||
|
```
|
||||||
|
hooks.rs (1) ──┐
|
||||||
|
├──► agent/error.rs (2) ──► agent/agent.rs (3)
|
||||||
|
│ │
|
||||||
|
│ ▼
|
||||||
|
│ agent/runtime.rs (4)
|
||||||
|
│ │
|
||||||
|
│ ▼
|
||||||
|
│ agent/builder.rs (5)
|
||||||
|
│ │
|
||||||
|
│ ▼
|
||||||
|
│ agent/session.rs (6)
|
||||||
|
│ │
|
||||||
|
│ ▼
|
||||||
|
└─────────────────► agent/task.rs (7)
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
src/agent.rs (8)
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
src/lib.rs (9)
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
cargo test (10)
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
roadmap.md (11)
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
回归 (12)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.4 预估工作量
|
||||||
|
|
||||||
|
| 阶段 | 行数 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| 1(hooks 扩展) | ~15 | 3 个变体 + 2 个字段 + 文档 |
|
||||||
|
| 2-7(7 个 agent 文件) | ~600 | 含 import + trait + struct + impl + 文档 |
|
||||||
|
| 8-9(lib.rs + agent.rs 模块根) | ~20 | 主要是 pub use 重导出 |
|
||||||
|
| 10(烟雾测试) | ~100 | 2-3 个测试 |
|
||||||
|
| 11(roadmap 同步) | ~5 | 状态翻转一行 |
|
||||||
|
| **合计** | **~740** | 与 `note-agent-runtime-design.md` §6 预估一致 |
|
||||||
|
|
||||||
|
## 5. 风险评估
|
||||||
|
|
||||||
|
### 5.1 抽象化边界(核心风险)
|
||||||
|
|
||||||
|
**风险描述**:Phase 4 容易"过度抽象"——参考了 OpenHarness / Hermes 后,倾向于把它们的核心能力都搬到 Rust core 库里。
|
||||||
|
|
||||||
|
**缓解措施**:
|
||||||
|
- 严格遵循 §1.3 的 7 条设计原则
|
||||||
|
- 每次添加新 trait / struct 前,先问"这属于 core 库职责吗?"
|
||||||
|
- 业务能力(Plan 拆解、多 Agent 协同、技能加载)一律走 trait 注入或 v0.2+ 延后
|
||||||
|
|
||||||
|
### 5.2 对 Phase 0-3 的侵入风险
|
||||||
|
|
||||||
|
**风险描述**:为实现 Phase 4 需修改 `src/llm/hooks.rs`,可能破坏 Phase 0 的现有测试。
|
||||||
|
|
||||||
|
**缓解措施**:
|
||||||
|
- 只追加 enum 变体和 `Option<T>` 字段(NF9)
|
||||||
|
- 顺序:先跑 `cargo test` 确认 Phase 0 测试不挂,再开始 Phase 4
|
||||||
|
- 详细回归验证:实施完毕后跑全量 `cargo test`
|
||||||
|
|
||||||
|
### 5.3 参考项目语言差异
|
||||||
|
|
||||||
|
**风险描述**:OpenClaw / Hermes / OpenHarness 均为 Python/TypeScript,OpenHuman 虽是 Rust + Tauri 但定位是桌面应用。直接照搬接口形状可能导致 Rust 借用检查问题、async 复杂度增加。
|
||||||
|
|
||||||
|
**缓解措施**:
|
||||||
|
- §1.3 明确"借鉴不照搬"
|
||||||
|
- 反模式列表(见 `docs/note-agent-harness-references.md` §6)作为排除项
|
||||||
|
- 接口设计优先考虑 Rust 惯例(`Arc<dyn Trait>` / `async fn` / `Result<T, E>`)
|
||||||
|
|
||||||
|
### 5.4 trait 设计的稳定性风险
|
||||||
|
|
||||||
|
**风险描述**:Phase 4 是 v0.1 的第一个"复杂 trait 集合",如果 trait 形状不稳定,v0.2+ 添加新能力时会 breaking。
|
||||||
|
|
||||||
|
**缓解措施**:
|
||||||
|
- §3.2 的所有 trait / struct 在 `docs/note-agent-runtime-design.md` §4 已固化草案
|
||||||
|
- 实施时如需调整,应先更新决策记录再改代码
|
||||||
|
- 预留扩展点:`Agent::tool_definitions` 的默认实现可被子 trait 覆盖
|
||||||
|
|
||||||
|
### 5.5 实施进度风险
|
||||||
|
|
||||||
|
**风险描述**:12 项交付物虽然不多,但 `submit_turn` 的 reference impl 需要在"LlmCycle 之上做正确组装",容易卡在细节。
|
||||||
|
|
||||||
|
**缓解措施**:
|
||||||
|
- 任务拆解(§4.2)按依赖顺序排好
|
||||||
|
- 烟雾测试只验证"能跑通"不验证"业务正确"——避免陷入业务循环的细节
|
||||||
|
- 必要时先做 `MockProvider`(Phase 0 已有模式),不依赖真实 LLM
|
||||||
|
|
||||||
|
## 6. 验收标准
|
||||||
|
|
||||||
|
### 6.1 代码验收
|
||||||
|
|
||||||
|
- [ ] `cargo build --release` 0 错误 0 警告(clippy)
|
||||||
|
- [ ] `cargo test` 所有 Phase 0-3 已有测试 + Phase 4 新增测试全部通过
|
||||||
|
- [ ] `cargo doc --no-deps` 所有公开 API 有 `///` 文档注释
|
||||||
|
- [ ] 新增代码 700-750 行(含测试 + 文档注释),与 §4.4 预估一致
|
||||||
|
- [ ] `src/lib.rs` 新增一行 `pub mod agent;`
|
||||||
|
- [ ] `src/llm/hooks.rs` 仅追加(不修改现有变体或字段)
|
||||||
|
|
||||||
|
### 6.2 接口验收
|
||||||
|
|
||||||
|
- [ ] 7 个新文件全部存在(§4.1)
|
||||||
|
- [ ] `Agent` trait 包含 `name` / `system_prompt` / `tool_definitions` 三个方法
|
||||||
|
- [ ] `RuntimeBundle` 包含 6 个字段(provider / tool_registry / hook_executor / memory_store? / retriever? / config)
|
||||||
|
- [ ] `AgentSession::submit_turn` 实现约 30 行,含 OnTurnStart/End hook 触发
|
||||||
|
- [ ] `TaskAgent` 提供双入口 `run` + `execute_plan`
|
||||||
|
- [ ] `JsonPlanParser` 实现约 20 行,基于 `serde_json`
|
||||||
|
- [ ] `AgentError` 聚合 8 个变体,含 `is_recoverable()`
|
||||||
|
- [ ] `AgentBuilder` 提供 6 个 setter + `build()` 校验
|
||||||
|
- [ ] `HookEvent` 新增 3 个变体:`OnTurnStart` / `OnTurnEnd` / `OnPlanStepComplete`
|
||||||
|
- [ ] `HookContext` 新增 2 个 `Option` 字段:`turn_index` / `plan_step_index`
|
||||||
|
|
||||||
|
### 6.3 测试验收
|
||||||
|
|
||||||
|
至少 2-3 个烟雾测试通过:
|
||||||
|
|
||||||
|
- [ ] **测试 1**:`Agent` trait 可实现 + `RuntimeBundle` 可构造(builder 链式调用)
|
||||||
|
- [ ] **测试 2**:`AgentSession::submit_turn` 跑通 mock provider(Phase 0 `MockProvider` 模式)
|
||||||
|
- [ ] **测试 3(可选)**:`JsonPlanParser::parse` 能解析合法 JSON,失败时返回 `AgentError::PlanParse`
|
||||||
|
|
||||||
|
### 6.4 文档验收
|
||||||
|
|
||||||
|
- [ ] `docs/7-agent-runtime.md`(本文件)完整、6 段式结构齐备
|
||||||
|
- [ ] `docs/note-agent-runtime-design.md` 与本文件互相引用一致
|
||||||
|
- [ ] `docs/note-agent-harness-references.md` 与本文件互相引用一致
|
||||||
|
- [ ] `docs/roadmap.md` Phase 4 状态从 ❌ 缺失 改为 ✅,交付物清单更新
|
||||||
|
|
||||||
|
### 6.5 行为验收(人工 review)
|
||||||
|
|
||||||
|
- [ ] `AgentSession::submit_turn` 不持有 `ConversationMemory`(grep 验证无 `use crate::memory::ConversationMemory`)
|
||||||
|
- [ ] `RuntimeBundle::new` 当 `retriever` 为 `Some` 时自动注册 `"retrieve"` tool
|
||||||
|
- [ ] `AgentBuilder::build` 在必填字段缺失时返回 `AgentError::Config`(而非 panic)
|
||||||
|
- [ ] `AgentError::is_recoverable()` 对各变体返回正确分类
|
||||||
|
|
||||||
|
### 6.6 风险验收
|
||||||
|
|
||||||
|
- [ ] 5.1 抽象化边界:trailt 列表中**不包含** Multi-Agent / Skills / TUI / Gateway 等应用层能力
|
||||||
|
- [ ] 5.2 Phase 0-3 侵入:`git diff` 显示 `src/llm/hooks.rs` 仅追加
|
||||||
|
- [ ] 5.3 语言差异:trait 形状符合 Rust 惯例(无 Python 风格的复杂继承)
|
||||||
|
- [ ] 5.4 trait 稳定性:决策记录与最终代码一致
|
||||||
|
- [ ] 5.5 实施进度:实际工作量与 §4.4 预估偏差 < 30%
|
||||||
|
|
||||||
|
## 7. 一句话总结
|
||||||
|
|
||||||
|
> **Phase 4 = 1 个 trait(Agent)+ 1 个容器(RuntimeBundle)+ 1 个会话(AgentSession)+ 1 个任务抽象(TaskAgent)+ 4 个辅助组件(Builder / Error / PlanParser / Hook 扩展),约 740 行代码,把 Phase 0-3 已有能力"装配"成"智能体"的概念。**
|
||||||
@@ -0,0 +1,180 @@
|
|||||||
|
# Agent Harness 参考项目调研笔记
|
||||||
|
|
||||||
|
> 调研日期:2026-06-09
|
||||||
|
> 用途:为 AG Core Phase 4(Agent Runtime)及后续 v0.2+ 扩展提供设计参考。
|
||||||
|
> 关联:`docs/roadmap.md` Phase 4 / 扩展计划(v0.2+)小节。
|
||||||
|
|
||||||
|
本笔记调研了 4 个 2026 年公开的 AI Agent 项目,对比其核心架构与 AG Core 已完成模块的交集,作为 Phase 4 设计与未来扩展的输入。
|
||||||
|
|
||||||
|
> **重要事实**:4 个项目**均非 Rust 写的**(OpenHuman 虽是 Rust + Tauri,但定位是 desktop 应用)。其价值不在"抄代码",而在参考**经过生产验证的 Agent Harness 架构模式**。AG Core 处于"core 库"层,是这些项目"最底层依赖"的角色。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. 项目概览
|
||||||
|
|
||||||
|
| 项目 | 类型 | 语言 | GitHub | Stars(调研时) | 定位 |
|
||||||
|
|------|------|------|--------|--------------|------|
|
||||||
|
| **OpenClaw** | Gateway 网关 | TypeScript / Node 24 | `openclaw/openclaw` | — | 自托管消息平台 ↔ AI Agent 桥接 |
|
||||||
|
| **Hermes Agent** | 自主学习智能体 | Python 3.11 | `NousResearch/hermes-agent` | — | 随使用成长的个人数字员工 |
|
||||||
|
| **OpenHuman** | 桌面助手 | Rust + Tauri | `tinyhumansai/openhuman` | 2.3k+ | 记忆驱动的跨工具私人助理 |
|
||||||
|
| **OpenHarness** | Agent Harness 框架 | Python | `HKUDS/OpenHarness` | 12.2k+ | 对标 Claude Code 的轻量级基础设施 |
|
||||||
|
|
||||||
|
## 2. 核心架构对照
|
||||||
|
|
||||||
|
| 维度 | OpenClaw | Hermes Agent | OpenHuman | OpenHarness |
|
||||||
|
|------|----------|--------------|-----------|-------------|
|
||||||
|
| **Agent Loop 形态** | 外部 Pi 二进制进程 | 内置 while 循环 | 内置循环 | 70 行 `run_query` |
|
||||||
|
| **记忆模型** | 跨平台 session | MEMORY.md + 技能库 | Memory Tree(SQLite 分层摘要) | MEMORY.md + Auto-Compaction |
|
||||||
|
| **工具机制** | MCP + 插件 | 40+ 内置技能 + 自动技能生成 | 118+ 集成 + Native Toolbelt | 43 工具 + BaseTool Pydantic |
|
||||||
|
| **多 Agent** | 消息平台多 gateway | 并行子 Agent + RPC | Agent Coordination | Swarm 子代理委派 |
|
||||||
|
| **权限/治理** | `allowFrom` + 提及规则 | 容器加固 + Cron 审批 | 本地优先 + 隐私 | 三级权限 + 钩子 |
|
||||||
|
| **规划/任务** | 无显式规划 | 自然语言驱动 | 无显式 | 隐式(LLM 自我规划) |
|
||||||
|
| **持久化** | 外部进程状态 | `~/.hermes/` 目录 | `~/.openhuman/` SQLite | MEMORY.md + state |
|
||||||
|
| **Hook 体系** | 渠道适配器 | cron + 自定义钩子 | 集成触发 | PreToolUse / PostToolUse |
|
||||||
|
| **干运行模式** | ❌ | ❌ | ❌ | ✅ `--dry-run` |
|
||||||
|
| **流式 TUI** | ✅ 控制 UI | ✅ 完整 TUI | ✅ 桌面应用 | ✅ React/Ink |
|
||||||
|
|
||||||
|
## 3. 共同分层(5 层架构)
|
||||||
|
|
||||||
|
4 个项目都能切成这 5 层,**OpenHarness 的分层最清晰**:
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────────────────────────────┐
|
||||||
|
│ L4 扩展层 多 Agent / 渠道网关 / 插件 / 任务调度 │
|
||||||
|
├─────────────────────────────────────────────────────────────┤
|
||||||
|
│ L3 治理层 权限 / Hook / 审批 / 安全策略 │
|
||||||
|
├─────────────────────────────────────────────────────────────┤
|
||||||
|
│ L2 知识层 提示词工厂 / 技能库 / 持久记忆 / 摘要 │
|
||||||
|
├─────────────────────────────────────────────────────────────┤
|
||||||
|
│ L1 执行层 Agent Loop / 工具注册 / 流式事件 / 重试 │
|
||||||
|
├─────────────────────────────────────────────────────────────┤
|
||||||
|
│ L0 模型层 LLM Provider / Provider Registry / 鉴权 │
|
||||||
|
└─────────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
**关键观察**:
|
||||||
|
|
||||||
|
- 4 个项目都假定 L0/L1 是"基础设施",不在 core 层重新发明
|
||||||
|
- AG Core 在 Phase 0-3 已完成 L0 / L1 / L2(部分) / L3(部分)
|
||||||
|
- Phase 4 处于 L1 与 L2 的衔接处
|
||||||
|
- L4(Multi-Agent / Gateway)属于应用层,应由上层 crate / 二进制承担
|
||||||
|
|
||||||
|
## 4. AG Core 已具备的对应能力
|
||||||
|
|
||||||
|
对照 5 层架构,AG Core 已就绪情况:
|
||||||
|
|
||||||
|
| 层 | AG Core 对应 | 完成状态 | 文档 |
|
||||||
|
|----|------------|---------|------|
|
||||||
|
| **L0 模型层** | `llm::provider` + `llm::ProviderRegistry` | ✅ Phase 0 | `docs/2-llm-call-lifecycle.md` |
|
||||||
|
| **L1 执行层** | `llm::cycle` + `llm::stream` + `tools::ToolRegistry` | ✅ Phase 0/2 | `docs/5-tool-system.md` |
|
||||||
|
| **L2 知识层** | `prompt` + `memory::store/conversation/knowledge/retriever` | ✅ Phase 1/3 | `docs/4-prompt-engineering.md` / `docs/6-memory-system.md` |
|
||||||
|
| **L3 治理层** | `llm::hooks` + `tools::PermissionChecker` | ✅ Phase 0/2 | `docs/3-phase0-remaining.md` |
|
||||||
|
| **L1→L2 衔接(Agent Runtime)** | — | ❌ **Phase 4 待实现** | — |
|
||||||
|
|
||||||
|
## 5. 借鉴到 Phase 4 的核心模式
|
||||||
|
|
||||||
|
### 5.1 OpenHarness 风格:显式依赖注入容器
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// 核心思想:所有运行时依赖打包成一个对象,沿调用链显式传递
|
||||||
|
pub struct RuntimeBundle {
|
||||||
|
pub provider: Arc<dyn LlmProvider>,
|
||||||
|
pub tool_registry: Arc<ToolRegistry>,
|
||||||
|
pub hook_executor: Arc<HookExecutor>,
|
||||||
|
// ... 可选:memory_store / retriever
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**好处**:测试时可注入 mock bundle;支持同时跑多 session;依赖关系显式可追踪。
|
||||||
|
|
||||||
|
**AG Core 决策**:采纳,命名为 `agent::RuntimeBundle`(详见 Phase 4 设计决策记录)。
|
||||||
|
|
||||||
|
### 5.2 Hermes 风格:实体与会话解耦
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub trait Agent: Send + Sync { /* 角色定义,不绑定 session */ }
|
||||||
|
pub struct AgentSession { /* 绑定 session_id + bundle + 状态 */ }
|
||||||
|
```
|
||||||
|
|
||||||
|
**好处**:同一 `Agent` 可被多个 `AgentSession` 复用(多用户、多会话);session 状态(cost、turn index)独立追踪。
|
||||||
|
|
||||||
|
**AG Core 决策**:采纳,详见 Phase 4 §接口签名草案。
|
||||||
|
|
||||||
|
### 5.3 OpenHarness 风格:Agent Loop 的极简本质
|
||||||
|
|
||||||
|
OpenHarness 的 `run_query` 核心只有 70 行,本质是一个 `while` 循环 + 一个 `if not tool_uses: return` 的判断。
|
||||||
|
|
||||||
|
**AG Core 现状**:`llm::cycle::submit_with_tools()` 已经在 Phase 2 末实现了这个循环,Phase 4 不应重新实现。
|
||||||
|
|
||||||
|
**AG Core 决策**:Phase 4 只在 `AgentSession::submit_turn()` 提供 30 行的 reference impl,组装 `LlmCycle` 并暴露其能力,业务循环留给上层。
|
||||||
|
|
||||||
|
### 5.4 OpenHuman 风格:分层摘要记忆树
|
||||||
|
|
||||||
|
OpenHuman 的 Memory Tree 创新点:
|
||||||
|
- 多源数据(Gmail / Slack / GitHub 等)→ 规范化 Markdown → ≤3k token chunks → 打分 → 折叠成 per-source / per-topic / per-day 摘要树
|
||||||
|
- 存储在本地 SQLite
|
||||||
|
- Auto-fetch 每 20 分钟拉取新数据
|
||||||
|
|
||||||
|
**AG Core 现状**:`memory::KnowledgeStore` 已是 LLM Wiki 风格的抽象层。Phase 4 v1 不引入 SQLite 实现(属于 L4 应用层)。
|
||||||
|
|
||||||
|
**AG Core 决策**:v0.2+ 考虑 `note-knowledge-graph-design.md` 已记录的 KnowledgeGraph / RecallBased 淘汰等深度记忆能力。
|
||||||
|
|
||||||
|
## 6. 反模式(不要照搬)
|
||||||
|
|
||||||
|
| 反模式 | 出现项目 | 不要照搬的理由 |
|
||||||
|
|--------|---------|--------------|
|
||||||
|
| 双进程架构(Node UI + Python 后端) | OpenClaw | 应用层架构,core 库不涉及 |
|
||||||
|
| SQLite 持久化细节 | OpenHuman | 属于 L4 应用层具体实现 |
|
||||||
|
| Pydantic 工具校验 | OpenHarness | Python 生态强项;Rust 已有 `serde_json::Value` + JSON Schema,足够 |
|
||||||
|
| 43 工具内置 | OpenHarness | 应用层选型,core 库应保持"零内置工具" |
|
||||||
|
| 单进程内多平台消息网关 | OpenClaw / Hermes | 属于 L4 应用层 |
|
||||||
|
|
||||||
|
## 7. 与 AG Core 现有模块的接口对齐
|
||||||
|
|
||||||
|
下表列出 4 个项目中被 AG Core **已经覆盖**或**即将在 Phase 4 覆盖**的能力,避免重复造轮子:
|
||||||
|
|
||||||
|
| 4 项目中的能力 | AG Core 对应 | 状态 |
|
||||||
|
|--------------|-------------|------|
|
||||||
|
| 工具注册表 | `tools::ToolRegistry` | ✅ Phase 2 已实现 |
|
||||||
|
| 权限检查 | `tools::PermissionChecker` | ✅ Phase 2 已实现 |
|
||||||
|
| 生命周期钩子 | `llm::HookExecutor` | ✅ Phase 0 已实现,Phase 4 扩展 3 个事件 |
|
||||||
|
| 自动 tool 循环 | `llm::cycle::submit_with_tools()` | ✅ Phase 2 末已实现 |
|
||||||
|
| Auto-Compaction | `llm::compact` | ✅ Phase 0 已实现 |
|
||||||
|
| 对话记忆 | `memory::ConversationMemory` | ✅ Phase 3 已实现 |
|
||||||
|
| 知识库 | `memory::KnowledgeStore` | ✅ Phase 3 已实现 |
|
||||||
|
| 关键词检索 | `memory::MemoryRetriever` | ✅ Phase 3 已实现 |
|
||||||
|
| 提示词模板 | `prompt::PromptTemplate` + `PromptComposer` | ✅ Phase 1 已实现 |
|
||||||
|
| 用量追踪 | `llm::cycle::usage::CostTracker` | ✅ Phase 0 已实现 |
|
||||||
|
| **显式依赖注入容器** | — | ⏳ **Phase 4 新增 `RuntimeBundle`** |
|
||||||
|
| **Agent ↔ Session 分离** | — | ⏳ **Phase 4 新增 `Agent` + `AgentSession`** |
|
||||||
|
| **任务规划** | — | ⏳ **Phase 4 新增 `TaskAgent` + `Plan`** |
|
||||||
|
| **结构化 Plan 解析** | — | ⏳ **Phase 4 新增 `PlanParser` trait** |
|
||||||
|
|
||||||
|
## 8. v0.2+ 扩展项与参考项目的对应
|
||||||
|
|
||||||
|
`docs/roadmap.md` 扩展计划(v0.2+)表中的项,在 4 个项目中的对应实现:
|
||||||
|
|
||||||
|
| 扩展项 | OpenClaw | Hermes | OpenHuman | OpenHarness |
|
||||||
|
|--------|----------|--------|-----------|-------------|
|
||||||
|
| Multi-Agent / Swarm | ❌ | ✅ 并行子 Agent | ✅ Agent Coordination | ✅ Swarm |
|
||||||
|
| Markdown 技能 | ❌ | ✅ SKILL.md | ❌ | ✅ prompts/*.md |
|
||||||
|
| 多通道检索(vector + keyword) | ❌ | ❌ | ✅ Memory Tree | ❌ |
|
||||||
|
| KnowledgeGraph | ❌ | ❌ | ✅ Memory Graph | ❌ |
|
||||||
|
| TokenJuice 智能压缩 | ❌ | ✅ 轨迹压缩 | ✅ TokenJuice | ✅ Auto-Compaction |
|
||||||
|
| TUI / Gateway | ✅ 控制 UI | ✅ 完整 TUI | ✅ 桌面应用 | ✅ React/Ink |
|
||||||
|
| 训练 / RL 轨迹 | ❌ | ✅ Atropos | ❌ | ❌ |
|
||||||
|
| 人类审批(Human-in-the-loop) | ❌ | ✅ Cron 审批 | ❌ | ✅ 权限弹窗 |
|
||||||
|
|
||||||
|
## 9. 参考资源
|
||||||
|
|
||||||
|
- **OpenClaw 文档**:<https://docs.openclaw.ai/zh-CN>
|
||||||
|
- **Hermes Agent 官网**:<https://hermes-agent.org/zh/>
|
||||||
|
- **Hermes Agent GitHub**:<https://github.com/NousResearch/hermes-agent>
|
||||||
|
- **OpenHuman GitHub**:<https://github.com/tinyhumansai/openhuman>
|
||||||
|
- **OpenHuman 中文站**:<https://openhumanai.cn/docs/>
|
||||||
|
- **OpenHarness GitHub**:<https://github.com/HKUDS/OpenHarness>
|
||||||
|
- **OpenHarness 深度学习笔记**:<https://www.joyehuang.me/blog/20260410---openharnessphase1/post>
|
||||||
|
|
||||||
|
## 10. 一句话总结
|
||||||
|
|
||||||
|
> **4 个项目都不在 L0/L1 重新发明轮子——它们都假定基础设施已就绪。AG Core 在 Phase 0-3 已经把这 4 层全做完了。Phase 4 的核心价值是把它们"装配起来",同时为未来 v0.2+ 的 L4 扩展(Multi-Agent / Skills / TUI)留好接口。**
|
||||||
@@ -0,0 +1,335 @@
|
|||||||
|
# Phase 4 Agent Runtime — 设计决策记录
|
||||||
|
|
||||||
|
> 决策固化日期:2026-06-09
|
||||||
|
> 用途:记录 Phase 4 设计阶段的关键决策、接口签名草案、文件清单,作为 `docs/7-agent-runtime.md` 方案文档的输入约束。
|
||||||
|
> 关联:
|
||||||
|
> - `docs/7-agent-runtime.md` — 完整方案文档(待写 / 已写)
|
||||||
|
> - `docs/note-agent-harness-references.md` — 参考项目调研(OpenClaw / Hermes / OpenHuman / OpenHarness)
|
||||||
|
> - `docs/roadmap.md` — 项目总 Roadmap
|
||||||
|
> - `docs/2-llm-call-lifecycle.md` / `3-phase0-remaining.md` / `4-prompt-engineering.md` / `5-tool-system.md` / `6-memory-system.md` — Phase 0-3 方案
|
||||||
|
|
||||||
|
本文件是 Phase 4 设计阶段的"事实基础"——所有决策都有明确的对话出处与依据。后续 Phase 4 实施时应与本记录保持一致;如需调整,应先更新本记录再改代码。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. 设计目标
|
||||||
|
|
||||||
|
AG Core Phase 4 的定位是**「Phase 0-3 的薄胶水层 + 一组 trait 抽象」**,遵循 OpenHarness 的"显式依赖注入"模式 + Hermes 的"两层实体/会话"模型。**不**实现业务循环,**不**做产品级功能,**不**假设上层如何使用 memory。
|
||||||
|
|
||||||
|
## 2. 范围与边界
|
||||||
|
|
||||||
|
### 2.1 必须实现(12 项)
|
||||||
|
|
||||||
|
| # | 交付物 | 文件 | 关键决策 |
|
||||||
|
|---|--------|------|---------|
|
||||||
|
| 1 | `Agent` trait | `src/agent/agent.rs` | 角色定义:name / system_prompt / 工具集 / 引用 session 句柄 |
|
||||||
|
| 2 | `RuntimeBundle` | `src/agent/runtime.rs` | 依赖注入容器(OpenHarness 风格) |
|
||||||
|
| 3 | `AgentSession` | `src/agent/session.rs` | 会话实例 + **最小 reference impl**(`submit_turn` ~30 行) |
|
||||||
|
| 4 | `TaskAgent` + `Plan` / `Step` | `src/agent/task.rs` | 双入口:`run(goal)` 自主式 + `execute_plan(plan)` 外部驱动式 |
|
||||||
|
| 5 | `PlanParser` trait + `JsonPlanParser` 参考实现 | `src/agent/task.rs` | 注入式(澄清 3 选项 C) |
|
||||||
|
| 6 | `AgentError` | `src/agent/error.rs` | 聚合 LlmError / ToolError / MemoryError,含 `is_recoverable()` |
|
||||||
|
| 7 | `AgentConfig` / `AgentBuilder` | `src/agent/builder.rs` | 链式构造 `RuntimeBundle` |
|
||||||
|
| 8 | Hook 事件扩展 | `src/llm/hooks.rs` | 追加 `OnTurnStart` / `OnTurnEnd` / `OnPlanStepComplete` 3 个事件 + 上下文扩展(澄清 4) |
|
||||||
|
| 9 | `lib.rs` 导出 | `src/lib.rs` | 一行 `pub mod agent;` |
|
||||||
|
| 10 | 烟雾测试 | `src/agent/tests.rs` 或内联 | 2-3 个:trait 可装配 / RuntimeBundle 可构造 / `submit_turn` 跑通 mock |
|
||||||
|
| 11 | 方案文档 | `docs/7-agent-runtime.md` | 编号 `7`(最大编号是 `6`,已确认无冲突) |
|
||||||
|
| 12 | Roadmap 同步 | `docs/roadmap.md` | 状态从 ❌ 缺失 改为 ✅ |
|
||||||
|
|
||||||
|
### 2.2 明确不做(v0.2+ 边界)
|
||||||
|
|
||||||
|
| 推迟项 | 理由 |
|
||||||
|
|--------|------|
|
||||||
|
| 完整 `BasicAgent` 多轮 turn 循环 | core 库不假设业务循环 |
|
||||||
|
| `ConversationAgent` 自动回写 | 记忆在独立 task 处理,由上层回写 |
|
||||||
|
| 强绑定 `ConversationMemory` 字段 | 改 `Option<Arc<dyn MemoryStore>>` 弱引用 |
|
||||||
|
| Plan 拆解的提示词模板 | 由上层注入 `PlanParser` |
|
||||||
|
| Multi-Agent / Swarm | 接口未稳定,独立 phase |
|
||||||
|
| Markdown 技能按需加载 | 属于知识层 |
|
||||||
|
| 三级权限模式 UI | 应用层 |
|
||||||
|
| 干运行 / TUI / Gateway | 应用层 |
|
||||||
|
| 完整的集成测试套件 | 2-3 个烟雾测试足够(呼应"最小范围") |
|
||||||
|
|
||||||
|
详细 v0.2+ 候选项见 `docs/roadmap.md` 扩展计划(v0.2+)小节与 `docs/note-agent-harness-references.md` 第 8 节。
|
||||||
|
|
||||||
|
## 3. 核心架构
|
||||||
|
|
||||||
|
### 3.1 分层(与 OpenHarness 5 层一致)
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────────────────────────────┐
|
||||||
|
│ 应用层 (上层 crate / 二进制 / Gateway) │
|
||||||
|
├─────────────────────────────────────────────────────────────┤
|
||||||
|
│ Agent Runtime ← Phase 4:trait + RuntimeBundle + Session │
|
||||||
|
├─────────────────────────────────────────────────────────────┤
|
||||||
|
│ LLM / Tool / Prompt / Memory ← Phase 0/1/2/3(已完成) │
|
||||||
|
└─────────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.2 实体关系
|
||||||
|
|
||||||
|
```
|
||||||
|
┌────────────┐ ┌──────────────────┐
|
||||||
|
│ Agent │ 1 * │ AgentSession │
|
||||||
|
│ (trait) ├────────►│ (struct) │
|
||||||
|
│ - name │ │ - session_id │
|
||||||
|
│ - prompt │ │ - bundle: Arc │
|
||||||
|
│ - tools │ │ - turn_index │
|
||||||
|
└────────────┘ │ - cost_so_far │
|
||||||
|
└──────────────────┘
|
||||||
|
│
|
||||||
|
▼ 共享
|
||||||
|
┌──────────────────┐
|
||||||
|
│ RuntimeBundle │
|
||||||
|
│ - provider │
|
||||||
|
│ - tool_registry │
|
||||||
|
│ - hook_executor │
|
||||||
|
│ - memory_store? │ ◄── 弱引用(澄清 2 选项 B)
|
||||||
|
│ - retriever? │
|
||||||
|
│ - config │
|
||||||
|
└──────────────────┘
|
||||||
|
│
|
||||||
|
▼ 注册为 tool
|
||||||
|
┌──────────────────┐
|
||||||
|
│ "retrieve" tool │ ◄── 如果 retriever 存在则自动注册
|
||||||
|
└──────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.3 决策对照表
|
||||||
|
|
||||||
|
| 决策点 | 选择 | 来源 |
|
||||||
|
|--------|------|------|
|
||||||
|
| 实体 vs 会话 | 两层模型(`Agent` + `AgentSession`) | 讨论第 4 轮第 1 条 / OpenHarness / Hermes |
|
||||||
|
| 范围控制 | trait + 最小 reference impl(~30 行) | 讨论第 4 轮第 2 条 / 澄清 1 选项 B |
|
||||||
|
| 记忆处理 | 弱引用 + 自动注册 retriever 为 tool | 讨论第 4 轮第 3 条 / 澄清 2 选项 B |
|
||||||
|
| Hook 扩展 | 3 个新事件 + 上下文扩展 | 讨论第 4 轮第 4 条 / 澄清 4 |
|
||||||
|
| 方案文档位置 | `docs/7-agent-runtime.md` | 讨论第 4 轮第 5 条 |
|
||||||
|
| TaskAgent 入口 | 双入口(自主 + 外部驱动) | 讨论第 4 轮第 6 条 |
|
||||||
|
| 自主式 Plan 解析 | 注入式 `PlanParser` trait + `JsonPlanParser` 参考实现 | 澄清 3 选项 C |
|
||||||
|
| 依赖注入 | `RuntimeBundle` 显式容器 | 讨论第 4 轮第 7 条 / OpenHarness |
|
||||||
|
| 文档撰写 | 推迟到 Proposal 阶段 | 讨论第 4 轮第 8 条 |
|
||||||
|
|
||||||
|
## 4. 接口签名草案
|
||||||
|
|
||||||
|
> ⚠️ **这些是"设计约束",不是最终代码**。方案文档(`docs/7-agent-runtime.md`)与实施阶段可微调字段顺序、文档注释、错误变体等,但**核心 trait 形状**和**方法名**应保持稳定。
|
||||||
|
|
||||||
|
### 4.1 `Agent` trait
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub trait Agent: Send + Sync {
|
||||||
|
fn name(&self) -> &str;
|
||||||
|
fn system_prompt(&self) -> Option<&str>;
|
||||||
|
/// 列出该 Agent 想要暴露给 LLM 的工具定义。
|
||||||
|
/// 默认实现:从 RuntimeBundle.tool_registry 取全部(最常用)。
|
||||||
|
/// 子 trait 可覆盖做白名单/过滤。
|
||||||
|
fn tool_definitions(&self, bundle: &RuntimeBundle) -> Vec<ToolDefinition>;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.2 `RuntimeBundle`
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct RuntimeBundle {
|
||||||
|
pub provider: Arc<dyn LlmProvider>,
|
||||||
|
pub tool_registry: Arc<ToolRegistry>,
|
||||||
|
pub hook_executor: Arc<HookExecutor>,
|
||||||
|
pub memory_store: Option<Arc<dyn MemoryStore>>, // 弱引用(澄清 2 选项 B)
|
||||||
|
pub retriever: Option<Arc<MemoryRetriever>>, // 弱引用(澄清 2 选项 B)
|
||||||
|
pub config: AgentConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RuntimeBundle {
|
||||||
|
/// 构造时如果 retriever 存在,自动注册为 "retrieve" tool。
|
||||||
|
pub fn new(/* ... */) -> Self;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.3 `AgentSession`
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct AgentSession {
|
||||||
|
pub session_id: String,
|
||||||
|
pub agent_name: String,
|
||||||
|
bundle: Arc<RuntimeBundle>,
|
||||||
|
turn_index: u32,
|
||||||
|
cost_so_far: CostTracker,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AgentSession {
|
||||||
|
pub fn new(agent: &dyn Agent, session_id: impl Into<String>, bundle: Arc<RuntimeBundle>) -> Self;
|
||||||
|
|
||||||
|
/// 最小 reference impl:组装 LlmCycle + submit + 累计 cost。
|
||||||
|
/// 不做 memory 回写(呼应"记忆在独立 task 处理"原则)。
|
||||||
|
pub async fn submit_turn(
|
||||||
|
&mut self,
|
||||||
|
user_input: impl Into<String>,
|
||||||
|
) -> Result<ChatResponse, AgentError>;
|
||||||
|
|
||||||
|
pub fn usage(&self) -> &CostTracker;
|
||||||
|
pub fn turn_index(&self) -> u32;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.4 `TaskAgent` + `Plan` + `Step`
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct Plan {
|
||||||
|
pub id: String,
|
||||||
|
pub goal: String,
|
||||||
|
pub steps: Vec<Step>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Step {
|
||||||
|
pub index: usize,
|
||||||
|
pub description: String,
|
||||||
|
pub status: StepStatus,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum StepStatus {
|
||||||
|
Pending,
|
||||||
|
Running,
|
||||||
|
Completed(ChatResponse),
|
||||||
|
Failed(AgentError),
|
||||||
|
Skipped,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 注入式 Plan 解析器。
|
||||||
|
#[async_trait]
|
||||||
|
pub trait PlanParser: Send + Sync {
|
||||||
|
async fn parse(&self, raw: &str, goal: &str) -> Result<Plan, AgentError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 基于 serde_json 的参考实现(约 20 行)。
|
||||||
|
pub struct JsonPlanParser;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl PlanParser for JsonPlanParser { /* ... */ }
|
||||||
|
|
||||||
|
/// TaskAgent 双入口。
|
||||||
|
#[async_trait]
|
||||||
|
pub trait TaskAgent: Agent {
|
||||||
|
/// 自主式:内部用 LLM 拆 Plan → execute_plan
|
||||||
|
async fn run(&mut self, session: &mut AgentSession, goal: &str) -> Result<Plan, AgentError>;
|
||||||
|
|
||||||
|
/// 外部驱动式:用户预定义 Plan → 逐步执行
|
||||||
|
async fn execute_plan(
|
||||||
|
&mut self,
|
||||||
|
session: &mut AgentSession,
|
||||||
|
plan: Plan,
|
||||||
|
) -> Result<Plan, AgentError>;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.5 `AgentError`
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub enum AgentError {
|
||||||
|
Llm(LlmError),
|
||||||
|
Tool(ToolError),
|
||||||
|
Memory(MemoryError),
|
||||||
|
PlanParse(String),
|
||||||
|
HookBlocked(String),
|
||||||
|
LimitExceeded(String),
|
||||||
|
Config(String),
|
||||||
|
Other(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AgentError {
|
||||||
|
pub fn is_recoverable(&self) -> bool { /* ... */ }
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.6 `AgentConfig` + `AgentBuilder`
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct AgentConfig {
|
||||||
|
pub max_turns: u32,
|
||||||
|
pub max_tool_turns: u32,
|
||||||
|
pub session_ttl: Option<Duration>,
|
||||||
|
pub compact_config: Option<CompactConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AgentBuilder { /* ... */ }
|
||||||
|
|
||||||
|
impl AgentBuilder {
|
||||||
|
pub fn new() -> Self;
|
||||||
|
pub fn provider(self, p: Arc<dyn LlmProvider>) -> Self;
|
||||||
|
pub fn tool_registry(self, r: Arc<ToolRegistry>) -> Self;
|
||||||
|
pub fn hook_executor(self, h: Arc<HookExecutor>) -> Self;
|
||||||
|
pub fn memory_store(self, m: Arc<dyn MemoryStore>) -> Self; // 选填
|
||||||
|
pub fn retriever(self, r: Arc<MemoryRetriever>) -> Self; // 选填
|
||||||
|
pub fn config(self, c: AgentConfig) -> Self;
|
||||||
|
pub fn build(self) -> Result<RuntimeBundle, AgentError>;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.7 Hook 扩展(`src/llm/hooks.rs` 改动)
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub enum HookEvent {
|
||||||
|
// ... 现有 4 个 ...
|
||||||
|
|
||||||
|
// 新增 3 个:
|
||||||
|
OnTurnStart,
|
||||||
|
OnTurnEnd,
|
||||||
|
OnPlanStepComplete,
|
||||||
|
}
|
||||||
|
|
||||||
|
// HookContext 扩展 2 个 Option 字段(澄清 4):
|
||||||
|
pub struct HookContext {
|
||||||
|
// ... 现有字段 ...
|
||||||
|
pub turn_index: Option<u32>, // OnTurnStart/End 用
|
||||||
|
pub plan_step_index: Option<usize>, // OnPlanStepComplete 用
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 5. 文件清单
|
||||||
|
|
||||||
|
### 5.1 新增文件(7 个)
|
||||||
|
|
||||||
|
```
|
||||||
|
src/agent.rs # 模块根 + pub use 重导出
|
||||||
|
src/agent/agent.rs # Agent trait
|
||||||
|
src/agent/runtime.rs # RuntimeBundle + AgentConfig
|
||||||
|
src/agent/session.rs # AgentSession
|
||||||
|
src/agent/task.rs # TaskAgent trait + Plan/Step + PlanParser + JsonPlanParser
|
||||||
|
src/agent/builder.rs # AgentBuilder
|
||||||
|
src/agent/error.rs # AgentError
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.2 修改文件(3 个)
|
||||||
|
|
||||||
|
```
|
||||||
|
src/lib.rs # + pub mod agent;
|
||||||
|
src/llm/hooks.rs # + 3 个事件变体 + 2 个上下文字段(极小改)
|
||||||
|
docs/roadmap.md # 状态翻转 + Phase 4 交付物清单更新(实施时再做)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.3 关联文档(已存在 / 待写)
|
||||||
|
|
||||||
|
```
|
||||||
|
docs/note-agent-harness-references.md # 参考项目调研(已存在)
|
||||||
|
docs/7-agent-runtime.md # 完整方案文档(路径 A 输出)
|
||||||
|
docs/note-agent-runtime-design.md # 本文件
|
||||||
|
```
|
||||||
|
|
||||||
|
## 6. 预估规模
|
||||||
|
|
||||||
|
- **新增代码**:约 **600-700 行**(含 2-3 个烟雾测试)
|
||||||
|
- **修改代码**:约 **10-20 行**(`hooks.rs` 改动 + `lib.rs` + `roadmap.md`)
|
||||||
|
- **方案文档**:约 **450-550 行 Markdown**(沿用 6-memory-system.md 的 6 段式结构)
|
||||||
|
|
||||||
|
## 7. 待办事项(按依赖顺序)
|
||||||
|
|
||||||
|
1. ✅ Phase 4 范围已收窄(§2.1)
|
||||||
|
2. ✅ 核心架构已对齐 OpenHarness / Hermes(§3)
|
||||||
|
3. ✅ 接口签名草案已固化(§4)
|
||||||
|
4. ✅ 文件清单已确定(§5)
|
||||||
|
5. ✅ 编号冲突已验证(最大是 `6`,新文件用 `7`)
|
||||||
|
6. ⏳ 写 `docs/7-agent-runtime.md` 方案文档
|
||||||
|
7. ⏳ 按文档实施 7 个新文件 + 3 个修改
|
||||||
|
8. ⏳ 跑通 2-3 个烟雾测试
|
||||||
|
9. ⏳ 更新 `docs/roadmap.md` 状态翻转
|
||||||
|
|
||||||
|
## 8. 一句话总结
|
||||||
|
|
||||||
|
> **Phase 4 = Phase 0-3 的薄胶水层 + 一组 trait 抽象**。**不**实现业务循环,**不**做产品级功能,**不**假设上层如何使用 memory。借鉴 OpenHarness 的"显式依赖注入容器"与 Hermes 的"实体/会话分离"模型,记忆以弱引用方式接入,`MemoryRetriever` 在 `RuntimeBundle::new()` 时自动注册为 LLM 可调用的 `retrieve` 工具。
|
||||||
@@ -0,0 +1,266 @@
|
|||||||
|
# 知识图谱与高级检索设计(Phase 4 备用)
|
||||||
|
|
||||||
|
> 本文记录 Phase 3 设计过程中裁剪的内容,待 Phase 4(Agent 运行时)制定时参考。
|
||||||
|
> 来源:`docs/6-memory-system.md` v1 版本,2026-06-07
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 背景
|
||||||
|
|
||||||
|
Phase 3 记忆系统方案做减法后,以下设计被推迟到 Phase 4。这些组件需要 Agent 的编排能力(LLM 提取标签、自动维护知识图谱、智能检索策略)才能真正产生价值,因此不适合在 Phase 3 的存储层实现。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. KnowledgeGraph(知识图谱)
|
||||||
|
|
||||||
|
### 1.1 设计意图
|
||||||
|
|
||||||
|
实体-关系图存储,用于关联检索。与 KnowledgeStore(内容/页面级)互补,提供实体级 + 关系维度的检索能力。
|
||||||
|
|
||||||
|
```
|
||||||
|
KnowledgeStore: 页面级内容("什么是 X")
|
||||||
|
KnowledgeGraph: 实体级关系("X 与什么相关")
|
||||||
|
```
|
||||||
|
|
||||||
|
### 1.2 接口设计(原方案)
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct GraphEntity {
|
||||||
|
pub id: String,
|
||||||
|
pub name: String,
|
||||||
|
pub entity_type: String, // "person" | "concept" | "project" | ...
|
||||||
|
pub description: String,
|
||||||
|
pub tags: Vec<String>, // 检索标签(全小写,原子词)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct GraphRelation {
|
||||||
|
pub source_id: String,
|
||||||
|
pub target_id: String,
|
||||||
|
pub relation_type: String, // "works_on" | "part_of" | "related_to" | ...
|
||||||
|
pub weight: f32, // 关系强度 [0.0, 1.0]
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum RelationDirection {
|
||||||
|
Outgoing, // source_id -> target_id(默认)
|
||||||
|
Incoming, // target_id -> source_id
|
||||||
|
Both, // 双向遍历
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ScoredEntity {
|
||||||
|
pub entity: GraphEntity,
|
||||||
|
pub score: f32, // 基于图距离的评分 [0.0, 1.0]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait KnowledgeGraph: Send + Sync {
|
||||||
|
// 实体管理
|
||||||
|
async fn add_entity(&self, entity: GraphEntity) -> Result<(), MemoryError>;
|
||||||
|
async fn get_entity(&self, id: &str) -> Result<Option<GraphEntity>, MemoryError>;
|
||||||
|
async fn remove_entity(&self, id: &str) -> Result<(), MemoryError>;
|
||||||
|
|
||||||
|
// 关系管理
|
||||||
|
async fn add_relation(&self, relation: GraphRelation) -> Result<(), MemoryError>;
|
||||||
|
async fn remove_relation(&self, source_id: &str, target_id: &str, relation_type: &str) -> Result<(), MemoryError>;
|
||||||
|
async fn get_related(
|
||||||
|
&self,
|
||||||
|
entity_id: &str,
|
||||||
|
depth: usize,
|
||||||
|
direction: RelationDirection,
|
||||||
|
relation_types: Option<&[&str]>,
|
||||||
|
) -> Result<Vec<ScoredEntity>, MemoryError>;
|
||||||
|
|
||||||
|
// 检索
|
||||||
|
async fn find_by_keywords(&self, keywords: &[String]) -> Result<Vec<GraphEntity>, MemoryError>;
|
||||||
|
|
||||||
|
// 标签管理
|
||||||
|
async fn find_tags(&self, prefix: &str) -> Result<Vec<String>, MemoryError>;
|
||||||
|
async fn entity_count_by_tag(&self, tag: &str) -> Result<usize, MemoryError>;
|
||||||
|
async fn set_entity_tags(&self, entity_id: &str, tags: Vec<String>) -> Result<usize, MemoryError>;
|
||||||
|
fn tag_constraints(&self) -> TagConstraints;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct TagConstraints {
|
||||||
|
pub max_tags_per_entity: usize, // 默认 8
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 1.3 标签复用原则
|
||||||
|
|
||||||
|
标签不应随意增长,应优先复用已有标签。流程:
|
||||||
|
|
||||||
|
```
|
||||||
|
LLM 提取候选标签 → 对每个候选:
|
||||||
|
graph.find_tags(candidate.lowercase())
|
||||||
|
├─ 命中已有标签 → 复用
|
||||||
|
└─ 无匹配 → 注册新标签
|
||||||
|
```
|
||||||
|
|
||||||
|
### 1.4 标签容量与精炼
|
||||||
|
|
||||||
|
每个实体最多 `max_tags_per_entity`(默认 8)个标签,按关联度降序排列。超出上限时保留关联度最高的标签。
|
||||||
|
|
||||||
|
### 1.5 InMemoryGraph 实现
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct InMemoryGraph {
|
||||||
|
entities: Mutex<HashMap<String, GraphEntity>>,
|
||||||
|
relations: Mutex<Vec<GraphRelation>>,
|
||||||
|
tag_index: Mutex<HashMap<String, HashSet<String>>>, // tag → entity_ids
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
图遍历使用 BFS/DFS 算法,需用 `HashSet<String>` 防环。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. 高级评分策略
|
||||||
|
|
||||||
|
### 2.1 ScoringStrategy
|
||||||
|
|
||||||
|
Phase 3 仅使用内部的简单 TextOverlap 评分(Dice 系数)。Phase 4 可引入以下策略:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct ScoreWeights {
|
||||||
|
pub overlap: f32, // 默认 0.5 — 文本重叠度,以原始 query 为基准
|
||||||
|
pub graph: f32, // 默认 0.2 — 图距离
|
||||||
|
pub temporal: f32, // 默认 0.1 — 时间衰减
|
||||||
|
pub reference: f32, // 默认 0.2 — 引用计数
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum ScoringStrategy {
|
||||||
|
TextOverlap, // 以原始 query 为准绳的文本重叠度(默认)
|
||||||
|
GraphDistance,
|
||||||
|
TemporalWeight,
|
||||||
|
ReferenceCount,
|
||||||
|
Hybrid(ScoreWeights),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ScoreBreakdown {
|
||||||
|
pub overlap_score: f32,
|
||||||
|
pub graph_score: f32,
|
||||||
|
pub temporal_score: f32,
|
||||||
|
pub reference_score: f32,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2.2 TextOverlap 算法
|
||||||
|
|
||||||
|
基于 Dice 系数计算 query 与召回内容的文本重叠度:
|
||||||
|
|
||||||
|
```
|
||||||
|
Dice = 2 × |intersect(bigrams)| / (|bigrams_query| + |bigrams_content|)
|
||||||
|
```
|
||||||
|
|
||||||
|
标题权重大于摘要,摘要权重大于正文。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. 高级检索(MemoryRetriever 双通道版)
|
||||||
|
|
||||||
|
Phase 3 仅保留单通道(只搜 KnowledgeStore)。Phase 4 可恢复双通道:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct MemoryRetriever {
|
||||||
|
knowledge_store: KnowledgeStore,
|
||||||
|
knowledge_graph: Arc<dyn KnowledgeGraph>,
|
||||||
|
keyword_extractor: Arc<dyn KeywordExtractor>,
|
||||||
|
config: RetrieverConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum RetrievalStrategy {
|
||||||
|
Hybrid, // 结合所有通道 + 评分排序(默认)
|
||||||
|
KnowledgeOnly, // 仅 KnowledgeStore
|
||||||
|
GraphOnly, // 仅 KnowledgeGraph
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
检索流程:
|
||||||
|
|
||||||
|
```
|
||||||
|
1. 关键词提取(KeywordExtractor)
|
||||||
|
2. 并行召回:
|
||||||
|
- KnowledgeStore.find_by_keywords(keywords)
|
||||||
|
- KnowledgeGraph.find_by_keywords(keywords) → get_related() 图遍历
|
||||||
|
3. 逐条评分(ScoringStrategy)
|
||||||
|
4. 过滤 score < min_score
|
||||||
|
5. 排序 → 截取 top-N
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. KeywordExtractor
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub trait KeywordExtractor: Send + Sync {
|
||||||
|
fn extract(&self, query: &str) -> Vec<String>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct SimpleKeywordExtractor {
|
||||||
|
stop_words: HashSet<String>,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
默认实现:按非字母数字字符分割,过滤停用词和单字符词。停用词表应包含英语常用停用词(约 80-100 个)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. 基于召回价值的淘汰(RecallBased)
|
||||||
|
|
||||||
|
### 5.1 记忆价值评分
|
||||||
|
|
||||||
|
每条记忆维护召回统计,计算综合价值分数:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub struct RecallStats {
|
||||||
|
pub recall_count: u64, // 累计召回次数
|
||||||
|
pub total_score: f64, // 累计评分(平均分 = total_score / recall_count)
|
||||||
|
pub last_recall_at: i64, // 最后一次被召回的时间戳(秒)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记忆价值公式:
|
||||||
|
// value = ln(1 + recall_count) × w_recall + avg_score × w_score + recency × w_recency
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.2 record_recall()
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// MemoryStore trait 可选方法
|
||||||
|
async fn record_recall(&self, id: &str, score: f32) -> Result<(), MemoryError> {
|
||||||
|
Ok(()) // 默认空实现,需覆盖
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.3 淘汰策略
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub enum EvictionPolicy {
|
||||||
|
// ...Phase 3 已有: None, Ttl, Capacity...
|
||||||
|
|
||||||
|
RecallBased {
|
||||||
|
max_items: usize,
|
||||||
|
recall_weight: f32, // 默认 0.3
|
||||||
|
score_weight: f32, // 默认 0.5
|
||||||
|
recency_weight: f32, // 默认 0.2
|
||||||
|
},
|
||||||
|
Hybrid {
|
||||||
|
ttl_secs: Option<u64>,
|
||||||
|
max_items: Option<usize>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Phase 3 → Phase 4 迁移建议
|
||||||
|
|
||||||
|
| 组件 | Phase 3 状态 | Phase 4 迁移方式 |
|
||||||
|
|------|-------------|-----------------|
|
||||||
|
| KnowledgeStore | 具体 struct(基于 MemoryStore) | 保持 struct,新增知识图谱数据入口 |
|
||||||
|
| KnowledgeGraph | 不存在 | 新建 `memory/graph.rs`,实现 trait + InMemoryGraph |
|
||||||
|
| MemoryRetriever | 单通道(仅 KnowledgeStore) | 增加 KnowledgeGraph 通道,恢复双通道检索 |
|
||||||
|
| ScoringStrategy | 内部 TextOverlap | 恢复枚举策略 + MemoryRetriever 配置 |
|
||||||
|
| KeywordExtractor | MemoryRetriever 内部拆分逻辑 | 抽取为独立 struct |
|
||||||
|
| RecallBased 淘汰 | 不存在 | 恢复 EvictionPolicy 变体 + RecallStats |
|
||||||
|
| 标签管理 | 不存在 | 恢复 tag_index + find_tags + set_entity_tags |
|
||||||
|
|
||||||
|
**关键依赖**:Phase 4 的 Agent 编排是 KnowledgeGraph 和标签管理的驱动者。如果没有 Agent 的 LLM 调用来提取标签、维护知识,KnowledgeGraph 只是一个空的图存储。
|
||||||
+252
@@ -0,0 +1,252 @@
|
|||||||
|
# AG Core Roadmap
|
||||||
|
|
||||||
|
> 定稿日期:2026-05-11
|
||||||
|
> 最后更新:2026-06-09(Phase 4 设计讨论收尾;扩展计划补充 v0.2+ 候选项)
|
||||||
|
|
||||||
|
## 愿景
|
||||||
|
|
||||||
|
AG Core 定位为构建 AI 智能体的底层工具箱,通过模块化、可插拔的架构,提供大模型调用、提示词工程、工具系统、记忆检索四大核心能力,支持快速组合出符合业务需求的智能体应用。
|
||||||
|
|
||||||
|
**当前状态**:Phase 0 基础设施已全部完成,Phase 1 提示词工程已全部完成,Phase 2 工具系统已全部完成,Phase 3 记忆系统已全部完成,等待 Phase 4 启动。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 模块完整性评估
|
||||||
|
|
||||||
|
| 功能领域 | 方案状态 | 文档位置 | 实现优先级 |
|
||||||
|
|---------|---------|---------|-----------|
|
||||||
|
| LLM 调用周期 | ✅ 完整 | `specs/llm-call-lifecycle.md` | P0 |
|
||||||
|
| 提示词工程 | ✅ 完整 | `docs/4-prompt-engineering.md` | P1 |
|
||||||
|
| 工具系统 + 权限 | ✅ 完整 | `docs/5-tool-system.md` | P1 |
|
||||||
|
| 记忆检索 | ✅ 完整 | `docs/6-memory-system.md` | P2 |
|
||||||
|
| Agent 运行时 | ❌ 缺失 | — | P2 |
|
||||||
|
| 生命周期钩子 | ✅ 完整 | `docs/3-phase0-remaining.md` | P0(LLM Cycle 扩展) |
|
||||||
|
| Provider 注册发现 | ✅ 完整 | `docs/3-phase0-remaining.md` | P0(Provider 接口扩展) |
|
||||||
|
| 流式事件系统 | ✅ 完整 | `docs/3-phase0-remaining.md` | P0(流式接口前置) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 分阶段 Roadmap
|
||||||
|
|
||||||
|
### Phase 0 — Foundation(基础设施)
|
||||||
|
|
||||||
|
**目标**:实现 LLM 调用周期的核心功能,作为所有上层模块的基础。
|
||||||
|
|
||||||
|
**交付物**:
|
||||||
|
1. ✅ `llm/types.rs` — 核心数据类型(Message, ContentBlock, ChatRequest/Response, ToolDefinition, StopReason)
|
||||||
|
2. ✅ `llm/error.rs` — 错误体系(LlmError 枚举,可重试/不可重试判断)
|
||||||
|
3. ✅ `llm/provider.rs` + `llm/provider/openai.rs` — Provider 接口 + OpenAI 兼容实现
|
||||||
|
4. ✅ `llm/provider/registry.rs` — ProviderRegistry(多 Provider 注册发现)
|
||||||
|
5. ✅ `llm/cycle.rs` + `llm/cycle/{retry,usage}.rs` — 生命周期引擎(重试策略 + 用量追踪)
|
||||||
|
6. ✅ `llm/hooks.rs` — HookExecutor 接口(生命周期钩子)
|
||||||
|
7. ✅ `llm/stream.rs` — StreamEvents 流式事件系统(AssistantTextDelta, ToolExecutionStarted 等)
|
||||||
|
8. ✅ `llm/compact.rs` — Auto-compaction(上下文自动压缩)
|
||||||
|
9. ✅ `Cargo.toml` — 添加依赖(tokio, reqwest, serde, thiserror, async-trait, tracing)
|
||||||
|
|
||||||
|
**依赖**:无
|
||||||
|
|
||||||
|
**优先级**:Must Have
|
||||||
|
|
||||||
|
**预估规模**:约 1000 行核心代码
|
||||||
|
|
||||||
|
**状态**:✅ Phase 0 全部交付物已完成
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Phase 1 — Prompt Engineering(提示词工程)
|
||||||
|
|
||||||
|
**目标**:提供提示词的组合、模板化与优化能力。
|
||||||
|
|
||||||
|
**交付物**:
|
||||||
|
1. ✅ `prompt.rs` + `prompt/` 模块
|
||||||
|
2. ✅ `PromptTemplate` — 模板引擎(支持变量插值、条件渲染)
|
||||||
|
3. ✅ `PromptComposer` — 提示词组合器(拼接 system/user/assistant 消息)
|
||||||
|
4. ✅ `docs/4-prompt-engineering.md` — 方案文档
|
||||||
|
|
||||||
|
**依赖**:无(可与 Phase 0 并行)
|
||||||
|
|
||||||
|
**优先级**:Should Have
|
||||||
|
|
||||||
|
**预估规模**:约 400 行代码
|
||||||
|
|
||||||
|
**状态**:✅ Phase 1 全部交付物已完成
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Phase 2 — Tool System(工具系统)
|
||||||
|
|
||||||
|
**目标**:实现 MCP 协议集成与自定义工具注册、调用、权限控制。
|
||||||
|
|
||||||
|
**交付物**:
|
||||||
|
1. ✅ `tools.rs` + `tools/` 模块(base/registry/permission/mcp/error)
|
||||||
|
2. ✅ `ToolRegistry` — 工具注册表(注册、发现、调用、并行执行、超时控制)
|
||||||
|
3. ✅ `BaseTool` trait — 工具抽象接口(含 ToolContext 执行上下文)
|
||||||
|
4. ✅ `McpClient` — MCP 协议客户端(stdio transport,StreamableHttp 预留)
|
||||||
|
5. ✅ `PermissionChecker` — 工具执行权限检查(白名单/黑名单/自定义权限)
|
||||||
|
6. ✅ `docs/5-tool-system.md` — 方案设计文档
|
||||||
|
7. ✅ 扩展 `llm/cycle.rs` 支持自动 tool 循环(`submit_with_tools()` + `submit_request()` + `maybe_compact()`)
|
||||||
|
8. ✅ `ToolError` — 结构化错误体系(含 `is_recoverable()` 分类)
|
||||||
|
|
||||||
|
**依赖**:Phase 0(LlmProvider 接口传递 tool definitions)、Phase 1(提示词可能需要注入工具描述)
|
||||||
|
|
||||||
|
**优先级**:Should Have
|
||||||
|
|
||||||
|
**预估规模**:约 900 行代码(实际约 1500 行)
|
||||||
|
|
||||||
|
**状态**:✅ Phase 2 全部交付物已完成
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Phase 3 — Memory System(记忆系统)
|
||||||
|
|
||||||
|
**目标**:提供对话记忆的存储、检索与管理能力。
|
||||||
|
|
||||||
|
**交付物**:
|
||||||
|
1. ✅ `memory.rs` + `memory/` 模块(store / conversation / knowledge / retriever / error / types)
|
||||||
|
2. ✅ `MemoryStore` trait + `InMemoryStore` — 记忆存储抽象(可插拔后端)+ 默认实现
|
||||||
|
3. ✅ `ConversationMemory` — 对话记忆管理(sliding window / 全量),复用 `llm::compact`
|
||||||
|
4. ✅ `KnowledgeStore` — 知识页面存储(具体 struct,非 trait,基于 MemoryStore)
|
||||||
|
5. ✅ `MemoryRetriever` — 记忆检索器(TextOverlap Dice 系数评分,单通道)
|
||||||
|
6. ✅ `docs/6-memory-system.md` — 方案设计文档
|
||||||
|
7. ✅ `docs/note-knowledge-graph-design.md` — KnowledgeGraph 等 Phase 4 备用设计
|
||||||
|
8. ✅ `EvictionPolicy` — 支持 None / Ttl / Capacity 三种淘汰策略
|
||||||
|
|
||||||
|
**依赖**:Phase 0(llm::compact 复用)、Cargo.toml 新增 `time` 依赖
|
||||||
|
|
||||||
|
**优先级**:Could Have
|
||||||
|
|
||||||
|
**预估规模**:约 700 行代码(实际约 1242 行,含测试)
|
||||||
|
|
||||||
|
**状态**:✅ Phase 3 全部交付物已完成
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Phase 4 — Agent Runtime(智能体运行时)
|
||||||
|
|
||||||
|
**目标**:实现多轮对话编排与任务规划。
|
||||||
|
|
||||||
|
**交付物**:
|
||||||
|
1. `agent.rs` + `agent/` 模块
|
||||||
|
2. `Agent` trait — 智能体接口定义
|
||||||
|
3. `ConversationAgent` — 对话型智能体实现
|
||||||
|
4. `TaskAgent` — 任务型智能体(规划 → 执行 → 反馈)
|
||||||
|
5. `specs/agent-runtime.md` — 方案文档
|
||||||
|
|
||||||
|
**依赖**:Phase 0, 1, 2, 3(整合所有模块)
|
||||||
|
|
||||||
|
**优先级**:Could Have
|
||||||
|
|
||||||
|
**预估规模**:约 600 行代码
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 依赖关系图
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
graph BT
|
||||||
|
P0["<b>Phase 0: Foundation</b><br/>LLM Cycle<br/>ProviderRegistry<br/>HookExecutor<br/>StreamEvents<br/>Auto-compaction"]:::done
|
||||||
|
P1["<b>Phase 1: Prompt Engineering</b><br/>PromptTemplate<br/>PromptComposer"]:::done
|
||||||
|
P2["<b>Phase 2: Tool System</b><br/>Tool Registry<br/>PermissionChecker<br/>MCP Client"]:::done
|
||||||
|
P3["<b>Phase 3: Memory System</b><br/>MemoryStore<br/>ConversationMemory<br/>KnowledgeStore"]:::done
|
||||||
|
P4["<b>Phase 4: Agent Runtime</b><br/>ConversationAgent<br/>TaskAgent"]:::pending
|
||||||
|
|
||||||
|
P1 --> P0
|
||||||
|
P2 --> P0
|
||||||
|
P3 --> P0
|
||||||
|
P2 --> P1
|
||||||
|
P4 --> P1
|
||||||
|
P4 --> P2
|
||||||
|
P4 --> P3
|
||||||
|
|
||||||
|
classDef done fill:#4ade80,stroke:#16a34a,color:#1a1a1a
|
||||||
|
classDef pending fill:#fbbf24,stroke:#d97706,color:#1a1a1a
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 扩展计划(v0.2+)
|
||||||
|
|
||||||
|
> 以下功能在已完成的 phase 中已实现基础能力或在 Phase 4 阶段明确了边界,后续可按维度增量扩展。
|
||||||
|
> 设计参考:见 `docs/note-agent-harness-references.md`(OpenClaw / Hermes / OpenHuman / OpenHarness 横向对比)。
|
||||||
|
|
||||||
|
### 已有扩展项(沿用)
|
||||||
|
|
||||||
|
| 扩展项 | 所在模块 | 说明 | 优先级 | 状态 |
|
||||||
|
|-------|---------|------|--------|------|
|
||||||
|
| Prompt Optimizer | `prompt` | 提示词自动优化 | P3 | 待实现 |
|
||||||
|
| 流式接口优化 | `llm/stream` | 流式响应解析与事件化 | P0 | ✅ 已完成基础实现 |
|
||||||
|
|
||||||
|
### v0.2+ 新增扩展项
|
||||||
|
|
||||||
|
> 以下为基于 Phase 4 设计讨论确定的 v0.2+ 候选扩展方向,按维度分组。
|
||||||
|
> 标注为"v0.2 待评估"表示在 Phase 4 完成后再决定是否启动。
|
||||||
|
|
||||||
|
#### Multi-Agent / 协同
|
||||||
|
|
||||||
|
| 扩展项 | 所在模块 | 说明 | 优先级 | 状态 |
|
||||||
|
|-------|---------|------|--------|------|
|
||||||
|
| Multi-Agent 协同(Swarm) | `agent` | 子 Agent 委派、并行子任务、结果聚合 | P2 | v0.2 待评估 |
|
||||||
|
|
||||||
|
#### 技能(Skills)
|
||||||
|
|
||||||
|
| 扩展项 | 所在模块 | 说明 | 优先级 | 状态 |
|
||||||
|
|-------|---------|------|--------|------|
|
||||||
|
| Markdown 技能按需加载 | `agent` / `prompt` | 兼容 `SKILL.md` 格式(Hermes / OpenHarness 风格),按 prompt 上下文动态加载 | P2 | v0.2 待评估 |
|
||||||
|
|
||||||
|
#### 记忆(Memory)
|
||||||
|
|
||||||
|
| 扩展项 | 所在模块 | 说明 | 优先级 | 状态 |
|
||||||
|
|-------|---------|------|--------|------|
|
||||||
|
| 多通道检索(hybrid) | `memory/retriever` | 在 TextOverlap 之上叠加向量检索通道 | P2 | v0.2 待评估 |
|
||||||
|
| KnowledgeGraph 深度记忆 | `memory` | 实体-关系图、`note-knowledge-graph-design.md` 已记录设计 | P3 | v0.2 待评估 |
|
||||||
|
| TokenJuice 智能压缩 | `memory` / `llm/compact` | 借鉴 OpenHuman TokenJuice,对工具结果做语义压缩而非字节截断 | P3 | v0.2 待评估 |
|
||||||
|
|
||||||
|
#### 交互层(TUI / Gateway)
|
||||||
|
|
||||||
|
| 扩展项 | 所在模块 | 说明 | 优先级 | 状态 |
|
||||||
|
|-------|---------|------|--------|------|
|
||||||
|
| TUI / 多平台 Gateway | 应用层 | OpenClaw / Hermes 风格的消息平台桥接(Feishu / Telegram / Discord 等) | P3 | v0.2+ 应用层 |
|
||||||
|
|
||||||
|
#### 训练基础设施
|
||||||
|
|
||||||
|
| 扩展项 | 所在模块 | 说明 | 优先级 | 状态 |
|
||||||
|
|-------|---------|------|--------|------|
|
||||||
|
| RL 轨迹导出 | `agent` | ShareGPT 格式轨迹、Atropos 集成(Hermes 风格) | P3 | v0.3+ 探索 |
|
||||||
|
|
||||||
|
#### 安全治理
|
||||||
|
|
||||||
|
| 扩展项 | 所在模块 | 说明 | 优先级 | 状态 |
|
||||||
|
|-------|---------|------|--------|------|
|
||||||
|
| Human-in-the-loop 审批 | `agent` / `tools/permission` | 高危工具执行前的异步审批回调(OpenHarness `permission_prompt` 模式) | P2 | v0.2 待评估 |
|
||||||
|
|
||||||
|
#### 流式 / 实时
|
||||||
|
|
||||||
|
| 扩展项 | 所在模块 | 说明 | 优先级 | 状态 |
|
||||||
|
|-------|---------|------|--------|------|
|
||||||
|
| 流式 `submit_turn` | `agent/session` | Phase 4 v1 只暴露非流式 `submit_turn()`;v0.2 包装 `LlmCycle::submit_stream` 暴露流式入口 | P2 | v0.2 待评估 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 风险与建议
|
||||||
|
|
||||||
|
1. **Phase 0 已完成**:LLM 调用周期基础设施已全部实现,可以支撑后续模块开发
|
||||||
|
2. **并行可能性**:Phase 0 和 Phase 1 可并行开展(无相互依赖),可加速早期交付
|
||||||
|
3. **MCP 协议复杂性**:MCP 涉及协议握手、session 管理、长期连接,建议预留充足时间调研协议细节
|
||||||
|
4. **Scope 蔓延风险**:当前 specs 只有 1 份文档,建议每个模块上线前都产出对应 spec,避免边实现边设计
|
||||||
|
5. **Phase 4 抽象化边界**:AG Core 定位为"支持库"而非"Agent 产品",Phase 4 需严格控制范围——只暴露 trait + 最小 reference impl,业务循环(多轮 turn 编排、记忆自动回写、Task 拆解策略)留给上层应用,避免与 OpenHarness / Hermes / OpenHuman 等已有 Agent 产品竞争实现细节。详细设计决策见 Phase 4 设计讨论记录(待 `docs/7-agent-runtime.md` 落盘)
|
||||||
|
6. **参考项目语言差异**:OpenClaw / Hermes / OpenHarness 均为 Python/TypeScript 实现,OpenHuman 虽是 Rust + Tauri 但定位是桌面应用。借鉴时**只取架构模式**,不照搬具体实现(如 Pydantic 工具校验、SQLite Memory Tree、Node+Python 双进程等)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 下一步行动
|
||||||
|
|
||||||
|
1. **Phase 4 设计讨论收尾**:Phase 4 范围已收窄为「`Agent` trait + `RuntimeBundle` 依赖注入容器 + `AgentSession` 实体/会话分离 + `TaskAgent` 双入口 + 记忆弱引用 + Hook 事件扩展 3 个」。决策记录已固化,待写 `docs/7-agent-runtime.md` 方案文档后启动编码实现
|
||||||
|
2. **Phase 4 方案文档**:将 Phase 4 设计决策沉淀为方案文档,沿用 `docs/4-prompt-engineering.md` / `5-tool-system.md` / `6-memory-system.md` 的 6 段式结构,文件名 `docs/7-agent-runtime.md`
|
||||||
|
3. **参考项目调研沉淀**:已完成 OpenClaw / Hermes / OpenHuman / OpenHarness 横向调研,结果沉淀至 `docs/note-agent-harness-references.md`,作为 v0.2+ 扩展项的输入
|
||||||
|
4. **Phase 3 备用设计就绪**:`docs/note-knowledge-graph-design.md` 记录了 KnowledgeGraph、高级评分、RecallBased 淘汰等设计,v0.2+ 记忆扩展可直接参考
|
||||||
|
|
||||||
|
**已完成阶段**:
|
||||||
|
- ✅ Phase 0 Foundation — 全部交付物已完成
|
||||||
|
- ✅ Phase 1 Prompt Engineering — 全部交付物已完成
|
||||||
|
- ✅ Phase 2 Tool System — 全部交付物已完成
|
||||||
|
- ✅ Phase 3 Memory System — 全部交付物已完成
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
{
|
||||||
|
"$schema": "https://opencode.ai/config.json",
|
||||||
|
"agent": {
|
||||||
|
"plan": {
|
||||||
|
"permission": {
|
||||||
|
"edit": {
|
||||||
|
"*": "deny",
|
||||||
|
".opencode/plans/**": "allow",
|
||||||
|
"docs/**": "allow"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,179 +0,0 @@
|
|||||||
# AG Core Roadmap
|
|
||||||
|
|
||||||
> 定稿日期:2026-05-11
|
|
||||||
|
|
||||||
## 愿景
|
|
||||||
|
|
||||||
AG Core 定位为构建 AI 智能体的底层工具箱,通过模块化、可插拔的架构,提供大模型调用、提示词工程、工具系统、记忆检索四大核心能力,支持快速组合出符合业务需求的智能体应用。
|
|
||||||
|
|
||||||
**当前状态**:代码为空壳,specs 目录有 1 份方案(LLM 调用周期)。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 模块完整性评估
|
|
||||||
|
|
||||||
| 功能领域 | 方案状态 | 文档位置 | 实现优先级 |
|
|
||||||
|---------|---------|---------|-----------|
|
|
||||||
| LLM 调用周期 | ✅ 完整 | `specs/llm-call-lifecycle.md` | P0 |
|
|
||||||
| 提示词工程 | ❌ 缺失 | — | P1 |
|
|
||||||
| 工具系统 + 权限 | ❌ 缺失 | — | P1 |
|
|
||||||
| 记忆检索 | ❌ 缺失 | — | P2 |
|
|
||||||
| Agent 运行时 | ❌ 缺失 | — | P2 |
|
|
||||||
| 生命周期钩子 | ❌ 缺失 | — | P0(LLM Cycle 扩展) |
|
|
||||||
| Provider 注册发现 | ❌ 缺失 | — | P0(Provider 接口扩展) |
|
|
||||||
| 流式事件系统 | ❌ 缺失 | — | P0(流式接口前置) |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 分阶段 Roadmap
|
|
||||||
|
|
||||||
### Phase 0 — Foundation(基础设施)
|
|
||||||
|
|
||||||
**目标**:实现 LLM 调用周期的核心功能,作为所有上层模块的基础。
|
|
||||||
|
|
||||||
**交付物**:
|
|
||||||
1. `llm/types.rs` — 核心数据类型(Message, ContentBlock, ChatRequest/Response, ToolDefinition, StopReason)
|
|
||||||
2. `llm/error.rs` — 错误体系(LlmError 枚举,可重试/不可重试判断)
|
|
||||||
3. `llm/provider.rs` + `llm/provider/openai.rs` — Provider 接口 + OpenAI 兼容实现
|
|
||||||
4. `llm/provider/registry.rs` — ProviderRegistry(多 Provider 注册发现)
|
|
||||||
5. `llm/cycle.rs` + `llm/cycle/{retry,usage}.rs` — 生命周期引擎(重试策略 + 用量追踪)
|
|
||||||
6. `llm/hooks.rs` — HookExecutor 接口(生命周期钩子)
|
|
||||||
7. `llm/stream.rs` — StreamEvents 流式事件系统(AssistantTextDelta, ToolExecutionStarted 等)
|
|
||||||
8. `llm/compact.rs` — Auto-compaction(上下文自动压缩)
|
|
||||||
9. `Cargo.toml` — 添加依赖(tokio, reqwest, serde, thiserror, async-trait, tracing)
|
|
||||||
|
|
||||||
**依赖**:无
|
|
||||||
|
|
||||||
**优先级**:Must Have
|
|
||||||
|
|
||||||
**预估规模**:约 1000 行核心代码
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Phase 1 — Prompt Engineering(提示词工程)
|
|
||||||
|
|
||||||
**目标**:提供提示词的组合、模板化与优化能力。
|
|
||||||
|
|
||||||
**交付物**:
|
|
||||||
1. `prompt.rs` + `prompt/` 模块
|
|
||||||
2. `PromptTemplate` — 模板引擎(支持变量插值、条件渲染)
|
|
||||||
3. `PromptComposer` — 提示词组合器(拼接 system/user/assistant 消息)
|
|
||||||
4. `specs/prompt-design.md` — 方案文档
|
|
||||||
|
|
||||||
**依赖**:无(可与 Phase 0 并行)
|
|
||||||
|
|
||||||
**优先级**:Should Have
|
|
||||||
|
|
||||||
**预估规模**:约 400 行代码
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Phase 2 — Tool System(工具系统)
|
|
||||||
|
|
||||||
**目标**:实现 MCP 协议集成与自定义工具注册、调用、权限控制。
|
|
||||||
|
|
||||||
**交付物**:
|
|
||||||
1. `tools.rs` + `tools/` 模块
|
|
||||||
2. `ToolRegistry` — 工具注册表(注册、发现、调用)
|
|
||||||
3. `BaseTool` trait — 工具抽象接口
|
|
||||||
4. `McpClient` — MCP 协议客户端
|
|
||||||
5. `PermissionChecker` — 工具执行权限检查(读/写/删除/网络等)
|
|
||||||
6. `specs/tool-call-loop.md` — Tool 自动执行循环设计
|
|
||||||
7. 扩展 `llm/cycle.rs` 支持自动 tool 循环(参考 OpenHarness `run_query()`)
|
|
||||||
|
|
||||||
**依赖**:Phase 0(LlmProvider 接口传递 tool definitions)、Phase 1(提示词可能需要注入工具描述)
|
|
||||||
|
|
||||||
**优先级**:Should Have
|
|
||||||
|
|
||||||
**预估规模**:约 900 行代码
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Phase 3 — Memory System(记忆系统)
|
|
||||||
|
|
||||||
**目标**:提供对话记忆的存储、检索与管理能力。
|
|
||||||
|
|
||||||
**交付物**:
|
|
||||||
1. `memory.rs` + `memory/` 模块
|
|
||||||
2. `MemoryStore` trait — 记忆存储抽象(可插拔后端)
|
|
||||||
3. `VectorStore` — 向量存储实现(支持 embedding 检索)
|
|
||||||
4. `ConversationMemory` — 对话记忆管理(sliding window / 全量)
|
|
||||||
5. `MemoryRetriever` — 记忆检索器(similarity search)
|
|
||||||
6. `specs/memory-system.md` — 方案文档
|
|
||||||
|
|
||||||
**依赖**:Phase 0(LLM 调用可能用于 embedding 生成)
|
|
||||||
|
|
||||||
**优先级**:Could Have
|
|
||||||
|
|
||||||
**预估规模**:约 700 行代码
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Phase 4 — Agent Runtime(智能体运行时)
|
|
||||||
|
|
||||||
**目标**:实现多轮对话编排与任务规划。
|
|
||||||
|
|
||||||
**交付物**:
|
|
||||||
1. `agent.rs` + `agent/` 模块
|
|
||||||
2. `Agent` trait — 智能体接口定义
|
|
||||||
3. `ConversationAgent` — 对话型智能体实现
|
|
||||||
4. `TaskAgent` — 任务型智能体(规划 → 执行 → 反馈)
|
|
||||||
5. `specs/agent-runtime.md` — 方案文档
|
|
||||||
|
|
||||||
**依赖**:Phase 0, 1, 2, 3(整合所有模块)
|
|
||||||
|
|
||||||
**优先级**:Could Have
|
|
||||||
|
|
||||||
**预估规模**:约 600 行代码
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 依赖关系图
|
|
||||||
|
|
||||||
```
|
|
||||||
Phase 4: Agent Runtime
|
|
||||||
│
|
|
||||||
┌─────────────────┼─────────────────┐
|
|
||||||
▼ ▼ ▼
|
|
||||||
Phase 1 Phase 2 Phase 3
|
|
||||||
Prompt Tool System Memory
|
|
||||||
Engineering + Permission System
|
|
||||||
+ HookExecutor
|
|
||||||
│ │ │
|
|
||||||
└────────┬────────┴────────┬────────┘
|
|
||||||
▼ ▼
|
|
||||||
Phase 0 ─────────────────┘
|
|
||||||
LLM Cycle
|
|
||||||
+ ProviderRegistry
|
|
||||||
+ HookExecutor
|
|
||||||
+ StreamEvents
|
|
||||||
+ Auto-compaction
|
|
||||||
(Foundation)
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 扩展计划(v0.2+)
|
|
||||||
|
|
||||||
> 以下功能已在 Phase 0 中实现,流式接口为后续增量优化。
|
|
||||||
|
|
||||||
| 扩展项 | 所在模块 | 说明 | 优先级 |
|
|
||||||
|-------|---------|------|--------|
|
|
||||||
| Prompt Optimizer | `prompt` | 提示词自动优化 | P3 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 风险与建议
|
|
||||||
|
|
||||||
1. **Phase 0 尚未实现**:项目代码是空壳,建议优先完成 LLM 调用周期,避免后续模块依赖不存在的底层
|
|
||||||
2. **并行可能性**:Phase 0 和 Phase 1 可并行开展(无相互依赖),可加速早期交付
|
|
||||||
3. **MCP 协议复杂性**:MCP 涉及协议握手、session 管理、长期连接,建议预留充足时间调研协议细节
|
|
||||||
4. **Scope 蔓延风险**:当前 specs 只有 1 份文档,建议每个模块上线前都产出对应 spec,避免边实现边设计
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 下一步行动
|
|
||||||
|
|
||||||
1. **Phase 0 方案评审**:对齐 LLM 模块设计(`specs/llm-call-lifecycle.md` 已在 2026-05-11 更新)
|
|
||||||
2. **Phase 1 方案启动**:启动 `specs/prompt-design.md` 设计
|
|
||||||
3. **Phase 2 方案启动**:启动 `specs/tool-call-loop.md` 设计(含 PermissionChecker)
|
|
||||||
+3
-3
@@ -1,9 +1,9 @@
|
|||||||
//! agcore —— 智能体(Agent)核心工具箱。
|
//! agcore —— 智能体(Agent)核心工具箱。
|
||||||
//!
|
|
||||||
//! 当前提供 LLM 调用周期控制作为核心底层能力,后续将扩展至
|
|
||||||
//! 提示词工程、记忆系统、工具调用、Agent 运行时等领域。
|
|
||||||
|
|
||||||
pub mod llm;
|
pub mod llm;
|
||||||
|
pub mod memory;
|
||||||
|
pub mod prompt;
|
||||||
|
pub mod tools;
|
||||||
|
|
||||||
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
|
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
|
||||||
|
|
||||||
|
|||||||
+3
-2
@@ -1,8 +1,9 @@
|
|||||||
//! LLM 调用周期 —— 大模型基础调用周期控制。
|
//! LLM 调用周期 —— 大模型基础调用周期控制。
|
||||||
//!
|
|
||||||
//! 包含核心数据类型、Provider 抽象、OpenAI 兼容实现以及生命周期引擎。
|
|
||||||
|
|
||||||
|
pub mod compact;
|
||||||
pub mod cycle;
|
pub mod cycle;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
|
pub mod hooks;
|
||||||
pub mod provider;
|
pub mod provider;
|
||||||
|
pub mod stream;
|
||||||
pub mod types;
|
pub mod types;
|
||||||
|
|||||||
@@ -0,0 +1,159 @@
|
|||||||
|
//! 上下文自动压缩 —— 当对话历史过长时自动压缩。
|
||||||
|
|
||||||
|
use crate::llm::types::{ContentField, OpenaiChatMessage, OpenaiContentPart};
|
||||||
|
|
||||||
|
const AUTOCOMPACT_BUFFER_TOKENS: u32 = 13_000;
|
||||||
|
const RESERVED_OUTPUT_TOKENS: u32 = 20_000;
|
||||||
|
const MAX_CONSECUTIVE_FAILURES: u32 = 3;
|
||||||
|
const KEEP_RECENT: usize = 6;
|
||||||
|
|
||||||
|
/// 上下文压缩配置。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct CompactConfig {
|
||||||
|
/// 模型上下文窗口大小(token 数)。
|
||||||
|
pub context_window: u32,
|
||||||
|
/// 为输出预留的 token 数。
|
||||||
|
pub reserved_tokens: u32,
|
||||||
|
/// 微压缩保留的最近消息数。
|
||||||
|
pub keep_recent: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for CompactConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
context_window: 128_000,
|
||||||
|
reserved_tokens: RESERVED_OUTPUT_TOKENS,
|
||||||
|
keep_recent: KEEP_RECENT,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompactConfig {
|
||||||
|
/// 计算自动压缩触发的阈值。
|
||||||
|
pub fn threshold(&self) -> u32 {
|
||||||
|
self.context_window
|
||||||
|
.saturating_sub(self.reserved_tokens)
|
||||||
|
.saturating_sub(AUTOCOMPACT_BUFFER_TOKENS)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 压缩状态 —— 跟踪连续失败次数(断路器模式)。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct CompactState {
|
||||||
|
consecutive_failures: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for CompactState {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompactState {
|
||||||
|
/// 创建一个新的压缩状态。
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
consecutive_failures: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 记录一次成功的压缩。
|
||||||
|
pub fn record_success(&mut self) {
|
||||||
|
self.consecutive_failures = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 记录一次压缩失败。
|
||||||
|
///
|
||||||
|
/// 返回 `true` 表示已达断路器上限,不再尝试。
|
||||||
|
pub fn record_failure(&mut self) -> bool {
|
||||||
|
self.consecutive_failures += 1;
|
||||||
|
self.consecutive_failures >= MAX_CONSECUTIVE_FAILURES
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 粗略估计消息列表的 token 数(基于字符数,4 字符 ≈ 1 token)。
|
||||||
|
pub fn estimate_message_tokens(messages: &[OpenaiChatMessage]) -> u32 {
|
||||||
|
messages
|
||||||
|
.iter()
|
||||||
|
.map(estimate_single_message_tokens)
|
||||||
|
.sum()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn estimate_single_message_tokens(msg: &OpenaiChatMessage) -> u32 {
|
||||||
|
let role_overhead: u32 = 4;
|
||||||
|
let content_tokens = match msg {
|
||||||
|
OpenaiChatMessage::Developer { content, .. }
|
||||||
|
| OpenaiChatMessage::System { content, .. }
|
||||||
|
| OpenaiChatMessage::User { content, .. }
|
||||||
|
| OpenaiChatMessage::Assistant { content, .. }
|
||||||
|
| OpenaiChatMessage::Function { content, .. } => estimate_content_tokens(content),
|
||||||
|
OpenaiChatMessage::Tool { content, .. } => estimate_content_tokens(content),
|
||||||
|
};
|
||||||
|
role_overhead + content_tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
fn estimate_content_tokens(content: &ContentField) -> u32 {
|
||||||
|
match content {
|
||||||
|
ContentField::String(s) => estimate_text_tokens(s),
|
||||||
|
ContentField::Array(parts) => parts.iter().map(estimate_part_tokens).sum(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn estimate_part_tokens(part: &OpenaiContentPart) -> u32 {
|
||||||
|
match part {
|
||||||
|
OpenaiContentPart::Text { text } => estimate_text_tokens(text),
|
||||||
|
_ => 50,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn estimate_text_tokens(text: &str) -> u32 {
|
||||||
|
if text.is_empty() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
let len = text.len() as u32;
|
||||||
|
(len * 4).div_ceil(3)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 判断是否需要触发自动压缩。
|
||||||
|
pub fn should_compact(
|
||||||
|
messages: &[OpenaiChatMessage],
|
||||||
|
config: &CompactConfig,
|
||||||
|
state: &CompactState,
|
||||||
|
) -> bool {
|
||||||
|
if state.consecutive_failures >= MAX_CONSECUTIVE_FAILURES {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
let tokens = estimate_message_tokens(messages);
|
||||||
|
tokens >= config.threshold()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 执行微压缩 —— 用 `[pruned]` 替换旧的 tool result 内容。
|
||||||
|
///
|
||||||
|
/// 这是最便宜的压缩方式,不需要 LLM 调用。
|
||||||
|
/// 保留最近的 `keep_recent` 条消息不变。
|
||||||
|
///
|
||||||
|
/// 返回释放的估算 token 数。
|
||||||
|
pub fn microcompact(messages: &mut [OpenaiChatMessage], keep_recent: usize) -> u32 {
|
||||||
|
if messages.len() <= keep_recent {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let prune_start = messages.len() - keep_recent;
|
||||||
|
let mut freed_tokens: u32 = 0;
|
||||||
|
|
||||||
|
for msg in &messages[..prune_start] {
|
||||||
|
if matches!(msg, OpenaiChatMessage::Tool { .. }) {
|
||||||
|
freed_tokens += estimate_single_message_tokens(msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for msg in &mut messages[..prune_start] {
|
||||||
|
if let OpenaiChatMessage::Tool { content, .. } = msg {
|
||||||
|
*content = ContentField::Array(vec![OpenaiContentPart::Text {
|
||||||
|
text: "[pruned]".to_string(),
|
||||||
|
}]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
freed_tokens
|
||||||
|
}
|
||||||
+769
-2
@@ -1,22 +1,49 @@
|
|||||||
|
//! LLM 调用周期控制模块。
|
||||||
|
|
||||||
mod retry;
|
mod retry;
|
||||||
pub mod usage;
|
pub mod usage;
|
||||||
|
|
||||||
pub use retry::RetryConfig;
|
pub use retry::RetryConfig;
|
||||||
pub use usage::{CostTracker, Usage};
|
pub use usage::{CostTracker, Usage};
|
||||||
|
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use futures_core::stream::Stream;
|
||||||
|
use async_stream::stream;
|
||||||
|
|
||||||
|
use crate::llm::compact::{should_compact, microcompact, CompactConfig, CompactState};
|
||||||
use crate::llm::cycle::retry::should_retry;
|
use crate::llm::cycle::retry::should_retry;
|
||||||
use crate::llm::error::LlmError;
|
use crate::llm::error::LlmError;
|
||||||
|
use crate::llm::hooks::{HookContext, HookExecutor};
|
||||||
use crate::llm::provider::LlmProvider;
|
use crate::llm::provider::LlmProvider;
|
||||||
|
use crate::llm::stream::StreamEvent;
|
||||||
use crate::llm::types::{
|
use crate::llm::types::{
|
||||||
ChatRequest, ChatResponse, OpenaiChatMessage, OpenaiTool, ToolChoice, ToolDefinition,
|
ChatRequest, ChatResponse, FinishReason, OpenaiChatMessage, OpenaiTool, OpenaiToolCall,
|
||||||
|
ToolChoice, ToolDefinition,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// LLM 调用周期配置。
|
||||||
pub struct CycleConfig {
|
pub struct CycleConfig {
|
||||||
|
/// 模型名称。
|
||||||
pub model: String,
|
pub model: String,
|
||||||
|
/// 最大输出 token 数。
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
|
/// 采样温度。
|
||||||
pub temperature: Option<f32>,
|
pub temperature: Option<f32>,
|
||||||
|
/// 最大对话轮次。
|
||||||
pub max_turns: Option<u32>,
|
pub max_turns: Option<u32>,
|
||||||
|
/// 重试策略配置。
|
||||||
pub retry: RetryConfig,
|
pub retry: RetryConfig,
|
||||||
|
/// 自动 tool 循环的最大轮次(独立于 `max_turns`,避免影响现有 `submit()` 语义)。
|
||||||
|
/// 默认 `Some(10)`,防止 LLM 反复调用工具导致无限循环。
|
||||||
|
pub max_tool_turns: Option<u32>,
|
||||||
|
/// 单个工具执行的超时秒数(0 表示不超时)。
|
||||||
|
/// 默认 60 秒。
|
||||||
|
pub tool_timeout_secs: u64,
|
||||||
|
/// 单个工具结果的最大字节数(超过此值将被截断)。
|
||||||
|
/// 默认 65536(64KB),防止大结果导致 token 膨胀。
|
||||||
|
pub max_tool_result_bytes: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for CycleConfig {
|
impl Default for CycleConfig {
|
||||||
@@ -27,19 +54,27 @@ impl Default for CycleConfig {
|
|||||||
temperature: None,
|
temperature: None,
|
||||||
max_turns: None,
|
max_turns: None,
|
||||||
retry: RetryConfig::default(),
|
retry: RetryConfig::default(),
|
||||||
|
max_tool_turns: Some(10),
|
||||||
|
tool_timeout_secs: 60,
|
||||||
|
max_tool_result_bytes: 65_536,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// LLM 调用周期 —— 管理一次或多次 LLM 请求的生命周期。
|
||||||
pub struct LlmCycle {
|
pub struct LlmCycle {
|
||||||
provider: Box<dyn LlmProvider>,
|
provider: Box<dyn LlmProvider>,
|
||||||
config: CycleConfig,
|
config: CycleConfig,
|
||||||
usage: CostTracker,
|
usage: CostTracker,
|
||||||
messages: Vec<OpenaiChatMessage>,
|
messages: Vec<OpenaiChatMessage>,
|
||||||
system_prompt: Option<String>,
|
system_prompt: Option<String>,
|
||||||
|
hook_executor: Option<Arc<HookExecutor>>,
|
||||||
|
compact_config: Option<CompactConfig>,
|
||||||
|
compact_state: CompactState,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LlmCycle {
|
impl LlmCycle {
|
||||||
|
/// 创建一个新的 LlmCycle。
|
||||||
pub fn new(provider: Box<dyn LlmProvider>, config: CycleConfig) -> Self {
|
pub fn new(provider: Box<dyn LlmProvider>, config: CycleConfig) -> Self {
|
||||||
Self {
|
Self {
|
||||||
provider,
|
provider,
|
||||||
@@ -47,30 +82,139 @@ impl LlmCycle {
|
|||||||
usage: CostTracker::default(),
|
usage: CostTracker::default(),
|
||||||
messages: Vec::new(),
|
messages: Vec::new(),
|
||||||
system_prompt: None,
|
system_prompt: None,
|
||||||
|
hook_executor: None,
|
||||||
|
compact_config: None,
|
||||||
|
compact_state: CompactState::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 设置系统提示词。
|
||||||
pub fn with_system_prompt(mut self, prompt: String) -> Self {
|
pub fn with_system_prompt(mut self, prompt: String) -> Self {
|
||||||
self.system_prompt = Some(prompt);
|
self.system_prompt = Some(prompt);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 设置钩子执行器。
|
||||||
|
pub fn with_hook_executor(mut self, executor: HookExecutor) -> Self {
|
||||||
|
self.hook_executor = Some(Arc::new(executor));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 设置上下文压缩配置。
|
||||||
|
pub fn with_compact_config(mut self, config: CompactConfig) -> Self {
|
||||||
|
self.compact_config = Some(config);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取用量追踪器引用。
|
||||||
pub fn usage(&self) -> &CostTracker {
|
pub fn usage(&self) -> &CostTracker {
|
||||||
&self.usage
|
&self.usage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 获取消息历史引用。
|
||||||
pub fn messages(&self) -> &[OpenaiChatMessage] {
|
pub fn messages(&self) -> &[OpenaiChatMessage] {
|
||||||
&self.messages
|
&self.messages
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 清空消息历史。
|
||||||
pub fn clear_messages(&mut self) {
|
pub fn clear_messages(&mut self) {
|
||||||
self.messages.clear();
|
self.messages.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 重置用量统计。
|
||||||
pub fn reset_usage(&mut self) {
|
pub fn reset_usage(&mut self) {
|
||||||
self.usage.reset();
|
self.usage.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 直接设置消息历史(覆盖已有消息),支持 Builder 链式调用。
|
||||||
|
pub fn with_messages(mut self, messages: Vec<OpenaiChatMessage>) -> Self {
|
||||||
|
self.messages = messages;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 追加消息到历史尾部。
|
||||||
|
pub fn extend_messages(&mut self, messages: Vec<OpenaiChatMessage>) {
|
||||||
|
self.messages.extend(messages);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 使用预构建消息提交(跳过自动 push user prompt)。
|
||||||
|
///
|
||||||
|
/// 与 `submit()` 不同,不自动添加 `user_text(prompt)`,也不自动插入 system prompt。
|
||||||
|
/// 调用方完全控制消息序列内容。
|
||||||
|
pub async fn submit_messages(
|
||||||
|
&mut self,
|
||||||
|
messages: Vec<OpenaiChatMessage>,
|
||||||
|
tools: Vec<ToolDefinition>,
|
||||||
|
) -> Result<ChatResponse, LlmError> {
|
||||||
|
let openai_tools: Option<Vec<OpenaiTool>> = if tools.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(
|
||||||
|
tools
|
||||||
|
.iter()
|
||||||
|
.map(|t| OpenaiTool::Function {
|
||||||
|
function: t.clone(),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let request = ChatRequest {
|
||||||
|
model: self.config.model.clone(),
|
||||||
|
messages,
|
||||||
|
max_tokens: self.config.max_tokens,
|
||||||
|
temperature: self.config.temperature,
|
||||||
|
tools: openai_tools,
|
||||||
|
tool_choice: Some(ToolChoice::Auto),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(ref executor) = self.hook_executor {
|
||||||
|
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest)
|
||||||
|
.with_request(&request);
|
||||||
|
let results = executor
|
||||||
|
.execute(crate::llm::hooks::HookEvent::PreRequest, &ctx)
|
||||||
|
.await;
|
||||||
|
if results.iter().any(|r| r.should_block) {
|
||||||
|
let reason = results
|
||||||
|
.iter()
|
||||||
|
.find(|r| r.should_block)
|
||||||
|
.and_then(|r| r.reason.clone())
|
||||||
|
.unwrap_or_else(|| "Blocked by pre-request hook".to_string());
|
||||||
|
return Err(LlmError::Other(reason));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match self.provider.chat(request).await {
|
||||||
|
Ok(response) => {
|
||||||
|
if let Some(ref executor) = self.hook_executor {
|
||||||
|
let post_request = ChatRequest {
|
||||||
|
model: self.config.model.clone(),
|
||||||
|
messages: vec![],
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PostRequest)
|
||||||
|
.with_request(&post_request);
|
||||||
|
executor
|
||||||
|
.execute(crate::llm::hooks::HookEvent::PostRequest, &ctx)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
self.usage.add(&response.usage);
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
if let Some(ref executor) = self.hook_executor {
|
||||||
|
let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnError).with_error(&e);
|
||||||
|
executor
|
||||||
|
.execute(crate::llm::hooks::HookEvent::OnError, &ctx)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
Err(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 提交用户消息并获取 LLM 响应。
|
||||||
pub async fn submit(
|
pub async fn submit(
|
||||||
&mut self,
|
&mut self,
|
||||||
prompt: String,
|
prompt: String,
|
||||||
@@ -78,31 +222,203 @@ impl LlmCycle {
|
|||||||
) -> Result<ChatResponse, LlmError> {
|
) -> Result<ChatResponse, LlmError> {
|
||||||
self.messages.push(OpenaiChatMessage::user_text(prompt));
|
self.messages.push(OpenaiChatMessage::user_text(prompt));
|
||||||
|
|
||||||
|
if let Some(ref config) = self.compact_config
|
||||||
|
&& should_compact(&self.messages, config, &self.compact_state)
|
||||||
|
{
|
||||||
|
let freed = microcompact(&mut self.messages, config.keep_recent);
|
||||||
|
if freed > 0 {
|
||||||
|
self.compact_state.record_success();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let mut attempts = 0;
|
let mut attempts = 0;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let request = self.build_request(&tools);
|
let request = self.build_request(&tools);
|
||||||
|
|
||||||
|
if let Some(ref executor) = self.hook_executor {
|
||||||
|
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest)
|
||||||
|
.with_request(&request);
|
||||||
|
let results = executor
|
||||||
|
.execute(crate::llm::hooks::HookEvent::PreRequest, &ctx)
|
||||||
|
.await;
|
||||||
|
if results.iter().any(|r| r.should_block) {
|
||||||
|
let reason = results
|
||||||
|
.iter()
|
||||||
|
.find(|r| r.should_block)
|
||||||
|
.and_then(|r| r.reason.clone())
|
||||||
|
.unwrap_or_else(|| "Blocked by pre-request hook".to_string());
|
||||||
|
return Err(LlmError::Other(reason));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
match self.provider.chat(request).await {
|
match self.provider.chat(request).await {
|
||||||
Ok(response) => {
|
Ok(response) => {
|
||||||
self.messages.push(response.message.clone());
|
if let Some(ref executor) = self.hook_executor {
|
||||||
|
let post_request = self.build_request(&tools);
|
||||||
|
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PostRequest)
|
||||||
|
.with_request(&post_request);
|
||||||
|
executor
|
||||||
|
.execute(crate::llm::hooks::HookEvent::PostRequest, &ctx)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.messages.push(response.message.clone());
|
||||||
self.usage.add(&response.usage);
|
self.usage.add(&response.usage);
|
||||||
|
|
||||||
return Ok(response);
|
return Ok(response);
|
||||||
}
|
}
|
||||||
Err(e) if should_retry(&e) && attempts < self.config.retry.max_retries => {
|
Err(e) if should_retry(&e) && attempts < self.config.retry.max_retries => {
|
||||||
attempts += 1;
|
attempts += 1;
|
||||||
|
|
||||||
|
if let Some(ref executor) = self.hook_executor {
|
||||||
|
let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnRetry)
|
||||||
|
.with_error(&e)
|
||||||
|
.with_attempt(attempts);
|
||||||
|
executor
|
||||||
|
.execute(crate::llm::hooks::HookEvent::OnRetry, &ctx)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
let delay = self.config.retry.compute_delay(attempts);
|
let delay = self.config.retry.compute_delay(attempts);
|
||||||
tokio::time::sleep(delay).await;
|
tokio::time::sleep(delay).await;
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
if let Some(ref executor) = self.hook_executor {
|
||||||
|
let ctx =
|
||||||
|
HookContext::new(crate::llm::hooks::HookEvent::OnError).with_error(&e);
|
||||||
|
executor
|
||||||
|
.execute(crate::llm::hooks::HookEvent::OnError, &ctx)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
return Err(e);
|
return Err(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 提交用户消息并返回语义事件流。
|
||||||
|
///
|
||||||
|
/// 与 `submit` 不同,该方法返回流式事件而非完整响应。
|
||||||
|
/// 适用于需要实时处理 LLM 输出的场景。
|
||||||
|
pub async fn submit_stream(
|
||||||
|
&mut self,
|
||||||
|
prompt: String,
|
||||||
|
tools: Vec<ToolDefinition>,
|
||||||
|
) -> Result<Pin<Box<dyn Stream<Item = StreamEvent> + Send>>, LlmError> {
|
||||||
|
self.messages.push(OpenaiChatMessage::user_text(prompt));
|
||||||
|
|
||||||
|
if let Some(ref config) = self.compact_config
|
||||||
|
&& should_compact(&self.messages, config, &self.compact_state)
|
||||||
|
{
|
||||||
|
let freed = microcompact(&mut self.messages, config.keep_recent);
|
||||||
|
if freed > 0 {
|
||||||
|
self.compact_state.record_success();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let request = self.build_request(&tools);
|
||||||
|
|
||||||
|
if let Some(ref executor) = self.hook_executor {
|
||||||
|
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest)
|
||||||
|
.with_request(&request);
|
||||||
|
let results = executor
|
||||||
|
.execute(crate::llm::hooks::HookEvent::PreRequest, &ctx)
|
||||||
|
.await;
|
||||||
|
if results.iter().any(|r| r.should_block) {
|
||||||
|
let reason = results
|
||||||
|
.iter()
|
||||||
|
.find(|r| r.should_block)
|
||||||
|
.and_then(|r| r.reason.clone())
|
||||||
|
.unwrap_or_else(|| "Blocked by pre-request hook".to_string());
|
||||||
|
return Err(LlmError::Other(reason));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let chunk_stream = self.provider.chat_stream(request).await?;
|
||||||
|
let hook_executor = self.hook_executor.clone();
|
||||||
|
let post_request = self.build_request(&tools);
|
||||||
|
|
||||||
|
Ok(Box::pin(stream! {
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
let mut chunk_stream = chunk_stream;
|
||||||
|
|
||||||
|
while let Some(result) = chunk_stream.next().await {
|
||||||
|
match result {
|
||||||
|
Ok(chunk) => {
|
||||||
|
let mut assistant_text = String::new();
|
||||||
|
let mut tool_started: Option<(String, String, String)> = None;
|
||||||
|
|
||||||
|
for choice in &chunk.choices {
|
||||||
|
let delta = &choice.delta;
|
||||||
|
|
||||||
|
if let Some(content) = &delta.content {
|
||||||
|
assistant_text.push_str(content);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(tool_calls) = &delta.tool_calls
|
||||||
|
&& let Some(tc) = tool_calls.first()
|
||||||
|
{
|
||||||
|
let crate::llm::types::OpenaiToolCall::Function { id, function } = tc;
|
||||||
|
tool_started = Some((id.clone(), function.name.clone(), function.arguments.clone()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !assistant_text.is_empty() {
|
||||||
|
yield StreamEvent::AssistantTextDelta { text: assistant_text };
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some((tool_call_id, tool_name, arguments)) = tool_started {
|
||||||
|
let args: serde_json::Value = serde_json::from_str(&arguments)
|
||||||
|
.unwrap_or(serde_json::Value::Null);
|
||||||
|
yield StreamEvent::ToolExecutionStarted {
|
||||||
|
tool_name,
|
||||||
|
input: args,
|
||||||
|
tool_call_id,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
for choice in &chunk.choices {
|
||||||
|
if let Some(finish_reason) = &choice.finish_reason {
|
||||||
|
yield StreamEvent::TurnComplete {
|
||||||
|
reason: *finish_reason,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(usage_info) = &chunk.usage {
|
||||||
|
yield StreamEvent::CostUpdate { usage: *usage_info };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
if let Some(ref executor) = hook_executor {
|
||||||
|
let ctx = crate::llm::hooks::HookContext::new(
|
||||||
|
crate::llm::hooks::HookEvent::OnError,
|
||||||
|
)
|
||||||
|
.with_error(&e);
|
||||||
|
executor
|
||||||
|
.execute(crate::llm::hooks::HookEvent::OnError, &ctx)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
yield StreamEvent::Error { message: e.to_string() };
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref executor) = hook_executor {
|
||||||
|
let ctx = crate::llm::hooks::HookContext::new(
|
||||||
|
crate::llm::hooks::HookEvent::PostRequest,
|
||||||
|
)
|
||||||
|
.with_request(&post_request);
|
||||||
|
executor
|
||||||
|
.execute(crate::llm::hooks::HookEvent::PostRequest, &ctx)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
fn build_request(&self, tools: &[ToolDefinition]) -> ChatRequest {
|
fn build_request(&self, tools: &[ToolDefinition]) -> ChatRequest {
|
||||||
let mut messages = self.messages.clone();
|
let mut messages = self.messages.clone();
|
||||||
|
|
||||||
@@ -137,4 +453,455 @@ impl LlmCycle {
|
|||||||
..Default::default()
|
..Default::default()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 内部请求方法(与 `submit` 共享重试逻辑,但不 push user message 和 Assistant 响应)。
|
||||||
|
///
|
||||||
|
/// 用于 `submit_with_tools()` 的多轮 tool 循环。
|
||||||
|
async fn submit_request(
|
||||||
|
&mut self,
|
||||||
|
tools: &[ToolDefinition],
|
||||||
|
) -> Result<ChatResponse, LlmError> {
|
||||||
|
let mut attempts = 0;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let request = self.build_request(tools);
|
||||||
|
|
||||||
|
if let Some(ref executor) = self.hook_executor {
|
||||||
|
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest)
|
||||||
|
.with_request(&request);
|
||||||
|
let results = executor
|
||||||
|
.execute(crate::llm::hooks::HookEvent::PreRequest, &ctx)
|
||||||
|
.await;
|
||||||
|
if results.iter().any(|r| r.should_block) {
|
||||||
|
let reason = results
|
||||||
|
.iter()
|
||||||
|
.find(|r| r.should_block)
|
||||||
|
.and_then(|r| r.reason.clone())
|
||||||
|
.unwrap_or_else(|| "Blocked by pre-request hook".to_string());
|
||||||
|
return Err(LlmError::Other(reason));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match self.provider.chat(request).await {
|
||||||
|
Ok(response) => {
|
||||||
|
if let Some(ref executor) = self.hook_executor {
|
||||||
|
let post_request = self.build_request(tools);
|
||||||
|
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PostRequest)
|
||||||
|
.with_request(&post_request);
|
||||||
|
executor
|
||||||
|
.execute(crate::llm::hooks::HookEvent::PostRequest, &ctx)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
self.usage.add(&response.usage);
|
||||||
|
return Ok(response);
|
||||||
|
}
|
||||||
|
Err(e) if should_retry(&e) && attempts < self.config.retry.max_retries => {
|
||||||
|
attempts += 1;
|
||||||
|
|
||||||
|
if let Some(ref executor) = self.hook_executor {
|
||||||
|
let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnRetry)
|
||||||
|
.with_error(&e)
|
||||||
|
.with_attempt(attempts);
|
||||||
|
executor
|
||||||
|
.execute(crate::llm::hooks::HookEvent::OnRetry, &ctx)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let delay = self.config.retry.compute_delay(attempts);
|
||||||
|
tokio::time::sleep(delay).await;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
if let Some(ref executor) = self.hook_executor {
|
||||||
|
let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnError)
|
||||||
|
.with_error(&e);
|
||||||
|
executor
|
||||||
|
.execute(crate::llm::hooks::HookEvent::OnError, &ctx)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 提交消息并自动处理工具调用循环。
|
||||||
|
///
|
||||||
|
/// 流程:
|
||||||
|
/// 1. 发送请求(含工具定义)
|
||||||
|
/// 2. 检查响应中的 finish_reason
|
||||||
|
/// 3. 如果是 ToolCalls → push Assistant 消息 → 执行工具 → 回传结果 → 重复 1
|
||||||
|
/// 4. 如果是 Stop/Length → push Assistant 消息 → 返回最终响应
|
||||||
|
///
|
||||||
|
/// 注意:OpenAI API 要求 tool 消息必须紧跟在对应的 Assistant(tool_calls)消息之后。
|
||||||
|
/// 因此 push 工具结果前必须先 push Assistant 响应,否则 API 拒绝请求。
|
||||||
|
pub async fn submit_with_tools(
|
||||||
|
&mut self,
|
||||||
|
prompt: String,
|
||||||
|
registry: &crate::tools::ToolRegistry,
|
||||||
|
) -> Result<ChatResponse, LlmError> {
|
||||||
|
let tools = registry.definitions();
|
||||||
|
let max_turns = self.config.max_tool_turns.unwrap_or(10);
|
||||||
|
let tool_timeout = self.config.tool_timeout_secs;
|
||||||
|
let max_bytes = self.config.max_tool_result_bytes;
|
||||||
|
|
||||||
|
self.messages.push(OpenaiChatMessage::user_text(prompt));
|
||||||
|
self.maybe_compact();
|
||||||
|
|
||||||
|
let mut turn = 0;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
turn += 1;
|
||||||
|
if turn > max_turns {
|
||||||
|
return Err(LlmError::Other(format!(
|
||||||
|
"达到最大工具循环轮次 ({max_turns})"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = self.submit_request(&tools).await?;
|
||||||
|
|
||||||
|
// 判断是否需要执行工具
|
||||||
|
let should_execute = matches!(response.stop_reason, Some(FinishReason::ToolCalls))
|
||||||
|
&& has_tool_calls_in_message(&response.message);
|
||||||
|
|
||||||
|
// 将 Assistant 响应(含 tool_calls 或最终文本)追加到消息历史
|
||||||
|
self.messages.push(response.message.clone());
|
||||||
|
|
||||||
|
if !should_execute {
|
||||||
|
return Ok(response);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析 tool_calls 并执行
|
||||||
|
let tool_calls = extract_tool_calls_from_message(&response.message);
|
||||||
|
let calls: Vec<(String, serde_json::Value)> = tool_calls
|
||||||
|
.into_iter()
|
||||||
|
.map(|(_id, name, args)| {
|
||||||
|
let args: serde_json::Value =
|
||||||
|
serde_json::from_str(&args).unwrap_or(serde_json::Value::Null);
|
||||||
|
(name, args)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let results = registry.invoke_all(calls, tool_timeout).await;
|
||||||
|
|
||||||
|
// 回传工具结果
|
||||||
|
for result in results {
|
||||||
|
let content = match result.output {
|
||||||
|
Ok(value) => {
|
||||||
|
let serialized = serde_json::to_string(&value).unwrap_or_else(|e| {
|
||||||
|
tracing::warn!("工具结果序列化失败: {}", e);
|
||||||
|
"{}".to_string()
|
||||||
|
});
|
||||||
|
truncate_tool_result(&serialized, max_bytes)
|
||||||
|
}
|
||||||
|
Err(e) if e.is_recoverable() => format!("错误: {}", e),
|
||||||
|
Err(e) => {
|
||||||
|
// 不可恢复错误:终止循环
|
||||||
|
return Err(LlmError::Other(format!(
|
||||||
|
"工具 '{}' 不可恢复错误: {}",
|
||||||
|
result.tool_name, e
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
self.messages
|
||||||
|
.push(OpenaiChatMessage::tool_result(result.tool_name, content));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 每轮工具执行后触发 compaction
|
||||||
|
self.maybe_compact();
|
||||||
|
}
|
||||||
|
|
||||||
|
// unreachable: loop returns
|
||||||
|
#[allow(unreachable_code)]
|
||||||
|
{
|
||||||
|
Err(LlmError::Other("unreachable".into()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 在接近上下文窗口时压缩历史消息。
|
||||||
|
fn maybe_compact(&mut self) {
|
||||||
|
if let Some(ref config) = self.compact_config
|
||||||
|
&& should_compact(&self.messages, config, &self.compact_state)
|
||||||
|
{
|
||||||
|
let freed = microcompact(&mut self.messages, config.keep_recent);
|
||||||
|
if freed > 0 {
|
||||||
|
self.compact_state.record_success();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 判断 Assistant 消息是否包含 tool_calls。
|
||||||
|
fn has_tool_calls_in_message(msg: &OpenaiChatMessage) -> bool {
|
||||||
|
matches!(
|
||||||
|
msg,
|
||||||
|
OpenaiChatMessage::Assistant {
|
||||||
|
tool_calls: Some(calls),
|
||||||
|
..
|
||||||
|
} if !calls.is_empty()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 提取 Assistant 消息中的 tool_calls。
|
||||||
|
///
|
||||||
|
/// 返回 `(tool_call_id, tool_name, arguments_json_string)` 列表。
|
||||||
|
fn extract_tool_calls_from_message(
|
||||||
|
msg: &OpenaiChatMessage,
|
||||||
|
) -> Vec<(String, String, String)> {
|
||||||
|
if let OpenaiChatMessage::Assistant {
|
||||||
|
tool_calls: Some(calls),
|
||||||
|
..
|
||||||
|
} = msg
|
||||||
|
{
|
||||||
|
calls
|
||||||
|
.iter()
|
||||||
|
.map(|c| match c {
|
||||||
|
OpenaiToolCall::Function { id, function } => {
|
||||||
|
(id.clone(), function.name.clone(), function.arguments.clone())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
} else {
|
||||||
|
Vec::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 截断工具结果到指定字节数。
|
||||||
|
fn truncate_tool_result(s: &str, max_bytes: usize) -> String {
|
||||||
|
if s.len() <= max_bytes {
|
||||||
|
return s.to_string();
|
||||||
|
}
|
||||||
|
let mut end = max_bytes;
|
||||||
|
while end > 0 && !s.is_char_boundary(end) {
|
||||||
|
end -= 1;
|
||||||
|
}
|
||||||
|
format!("{}\n\n[... truncated, original size: {} bytes ...]", &s[..end], s.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::llm::types::{ContentField, OpenaiContentPart};
|
||||||
|
use crate::tools::{BaseTool, ToolRegistry};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
|
||||||
|
/// 模拟 Provider —— 预定义响应序列,按调用顺序返回。
|
||||||
|
struct MockProvider {
|
||||||
|
responses: std::sync::Mutex<Vec<ChatResponse>>,
|
||||||
|
call_count: std::sync::Mutex<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MockProvider {
|
||||||
|
fn new(responses: Vec<ChatResponse>) -> Self {
|
||||||
|
Self {
|
||||||
|
responses: std::sync::Mutex::new(responses),
|
||||||
|
call_count: std::sync::Mutex::new(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl LlmProvider for MockProvider {
|
||||||
|
async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, LlmError> {
|
||||||
|
let mut count = self.call_count.lock().unwrap();
|
||||||
|
*count += 1;
|
||||||
|
let mut responses = self.responses.lock().unwrap();
|
||||||
|
if responses.is_empty() {
|
||||||
|
return Err(LlmError::Other("no more mock responses".into()));
|
||||||
|
}
|
||||||
|
Ok(responses.remove(0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn empty_usage() -> crate::llm::types::Usage {
|
||||||
|
crate::llm::types::Usage::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn assistant_text_response(text: &str) -> ChatResponse {
|
||||||
|
ChatResponse {
|
||||||
|
message: OpenaiChatMessage::assistant_text(text),
|
||||||
|
usage: empty_usage(),
|
||||||
|
stop_reason: Some(FinishReason::Stop),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn assistant_tool_call_response(
|
||||||
|
calls: Vec<(&str, &str, &str)>,
|
||||||
|
) -> ChatResponse {
|
||||||
|
use crate::llm::types::{OpenaiToolCall, FunctionCall};
|
||||||
|
let tool_calls: Vec<OpenaiToolCall> = calls
|
||||||
|
.into_iter()
|
||||||
|
.map(|(id, name, args)| OpenaiToolCall::Function {
|
||||||
|
id: id.to_string(),
|
||||||
|
function: FunctionCall {
|
||||||
|
name: name.to_string(),
|
||||||
|
arguments: args.to_string(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
ChatResponse {
|
||||||
|
message: OpenaiChatMessage::Assistant {
|
||||||
|
content: ContentField::Array(vec![OpenaiContentPart::Text {
|
||||||
|
text: String::new(),
|
||||||
|
}]),
|
||||||
|
refusal: None,
|
||||||
|
name: None,
|
||||||
|
tool_calls: Some(tool_calls),
|
||||||
|
},
|
||||||
|
usage: empty_usage(),
|
||||||
|
stop_reason: Some(FinishReason::ToolCalls),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct AddTool;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl BaseTool for AddTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"add"
|
||||||
|
}
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"加法"
|
||||||
|
}
|
||||||
|
fn parameters(&self) -> Value {
|
||||||
|
json!({"type":"object","properties":{"a":{"type":"integer"},"b":{"type":"integer"}}})
|
||||||
|
}
|
||||||
|
async fn execute(
|
||||||
|
&self,
|
||||||
|
args: Value,
|
||||||
|
_ctx: &crate::tools::ToolContext<'_>,
|
||||||
|
) -> Result<Value, crate::tools::ToolError> {
|
||||||
|
let a = args["a"].as_i64().unwrap_or(0);
|
||||||
|
let b = args["b"].as_i64().unwrap_or(0);
|
||||||
|
Ok(json!({"result": a + b}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_submit_with_tools_single_turn() {
|
||||||
|
// 第一轮:返回 tool_call;第二轮:返回最终文本
|
||||||
|
let responses = vec![
|
||||||
|
assistant_tool_call_response(vec![("call_1", "add", r#"{"a":1,"b":2}"#)]),
|
||||||
|
assistant_text_response("答案是 3"),
|
||||||
|
];
|
||||||
|
let provider = Box::new(MockProvider::new(responses));
|
||||||
|
let mut cycle = LlmCycle::new(provider, CycleConfig::default());
|
||||||
|
|
||||||
|
let mut registry = ToolRegistry::new();
|
||||||
|
registry.register(std::sync::Arc::new(AddTool)).unwrap();
|
||||||
|
|
||||||
|
let response = cycle
|
||||||
|
.submit_with_tools("1+2=?".to_string(), ®istry)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
// 验证最终响应是文本响应
|
||||||
|
assert!(matches!(
|
||||||
|
response.message,
|
||||||
|
OpenaiChatMessage::Assistant { .. }
|
||||||
|
));
|
||||||
|
|
||||||
|
// 验证消息历史:user, assistant(tool_calls), tool, assistant(text)
|
||||||
|
let messages = cycle.messages();
|
||||||
|
assert_eq!(messages.len(), 4);
|
||||||
|
assert!(matches!(messages[0], OpenaiChatMessage::User { .. }));
|
||||||
|
assert!(matches!(messages[1], OpenaiChatMessage::Assistant { .. }));
|
||||||
|
assert!(matches!(messages[2], OpenaiChatMessage::Tool { .. }));
|
||||||
|
assert!(matches!(messages[3], OpenaiChatMessage::Assistant { .. }));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_submit_with_tools_multi_turn() {
|
||||||
|
// 3 轮 tool 调用后给出最终答案
|
||||||
|
let responses = vec![
|
||||||
|
assistant_tool_call_response(vec![("call_1", "add", r#"{"a":1,"b":2}"#)]),
|
||||||
|
assistant_tool_call_response(vec![("call_2", "add", r#"{"a":3,"b":4}"#)]),
|
||||||
|
assistant_tool_call_response(vec![("call_3", "add", r#"{"a":5,"b":6}"#)]),
|
||||||
|
assistant_text_response("完成"),
|
||||||
|
];
|
||||||
|
let provider = Box::new(MockProvider::new(responses));
|
||||||
|
let mut cycle = LlmCycle::new(provider, CycleConfig::default());
|
||||||
|
|
||||||
|
let mut registry = ToolRegistry::new();
|
||||||
|
registry.register(std::sync::Arc::new(AddTool)).unwrap();
|
||||||
|
|
||||||
|
let response = cycle
|
||||||
|
.submit_with_tools("计算总和".to_string(), ®istry)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(matches!(
|
||||||
|
response.message,
|
||||||
|
OpenaiChatMessage::Assistant { .. }
|
||||||
|
));
|
||||||
|
|
||||||
|
// user + 3*(assistant + tool) + final assistant = 8
|
||||||
|
let messages = cycle.messages();
|
||||||
|
assert_eq!(messages.len(), 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_submit_with_tools_max_turns_exceeded() {
|
||||||
|
// 配置 max_tool_turns = 2
|
||||||
|
let mut config = CycleConfig::default();
|
||||||
|
config.max_tool_turns = Some(2);
|
||||||
|
// 4 轮 tool 调用 + 终止
|
||||||
|
let responses = vec![
|
||||||
|
assistant_tool_call_response(vec![("c1", "add", r#"{"a":1,"b":1}"#)]),
|
||||||
|
assistant_tool_call_response(vec![("c2", "add", r#"{"a":1,"b":1}"#)]),
|
||||||
|
assistant_tool_call_response(vec![("c3", "add", r#"{"a":1,"b":1}"#)]),
|
||||||
|
assistant_text_response("完成"),
|
||||||
|
];
|
||||||
|
let provider = Box::new(MockProvider::new(responses));
|
||||||
|
let mut cycle = LlmCycle::new(provider, config);
|
||||||
|
|
||||||
|
let mut registry = ToolRegistry::new();
|
||||||
|
registry.register(std::sync::Arc::new(AddTool)).unwrap();
|
||||||
|
|
||||||
|
let result = cycle
|
||||||
|
.submit_with_tools("test".to_string(), ®istry)
|
||||||
|
.await;
|
||||||
|
assert!(matches!(result, Err(LlmError::Other(msg)) if msg.contains("达到最大工具循环轮次")));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_submit_with_tools_no_tool_call_response() {
|
||||||
|
// LLM 直接给出最终响应(不调用工具)
|
||||||
|
let responses = vec![assistant_text_response("直接回答")];
|
||||||
|
let provider = Box::new(MockProvider::new(responses));
|
||||||
|
let mut cycle = LlmCycle::new(provider, CycleConfig::default());
|
||||||
|
|
||||||
|
let mut registry = ToolRegistry::new();
|
||||||
|
registry.register(std::sync::Arc::new(AddTool)).unwrap();
|
||||||
|
|
||||||
|
let response = cycle
|
||||||
|
.submit_with_tools("直接回答".to_string(), ®istry)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(matches!(
|
||||||
|
response.message,
|
||||||
|
OpenaiChatMessage::Assistant { .. }
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_truncate_tool_result_short() {
|
||||||
|
let s = "short text";
|
||||||
|
assert_eq!(truncate_tool_result(s, 100), "short text");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_truncate_tool_result_long() {
|
||||||
|
let s = "a".repeat(1000);
|
||||||
|
let truncated = truncate_tool_result(&s, 50);
|
||||||
|
assert!(truncated.len() < s.len());
|
||||||
|
assert!(truncated.contains("[... truncated,"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_truncate_tool_result_chinese_chars() {
|
||||||
|
let s = "中".repeat(100);
|
||||||
|
let truncated = truncate_tool_result(&s, 50);
|
||||||
|
// 不会在字符中间截断
|
||||||
|
assert!(truncated.starts_with("中"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,128 @@
|
|||||||
|
//! 生命周期钩子 —— 在 LLM 调用周期的关键节点插入自定义逻辑。
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
use crate::llm::error::LlmError;
|
||||||
|
use crate::llm::types::ChatRequest;
|
||||||
|
|
||||||
|
/// 生命周期钩子事件点。
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum HookEvent {
|
||||||
|
/// LLM 请求发起之前(可阻断)。
|
||||||
|
PreRequest,
|
||||||
|
/// 成功响应之后。
|
||||||
|
PostRequest,
|
||||||
|
/// 重试之前(仅可重试错误时触发)。
|
||||||
|
OnRetry,
|
||||||
|
/// 不可恢复错误返回之前。
|
||||||
|
OnError,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 此次钩子调用的上下文。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct HookContext<'a> {
|
||||||
|
/// 当前事件点。
|
||||||
|
pub event: HookEvent,
|
||||||
|
/// 当前的请求(部分事件点可用)。
|
||||||
|
pub request: Option<&'a ChatRequest>,
|
||||||
|
/// 当前错误(仅 OnError 和 OnRetry 可用)。
|
||||||
|
pub error: Option<&'a LlmError>,
|
||||||
|
/// 当前重试次数(从 1 开始,仅 OnRetry 可用)。
|
||||||
|
pub attempt: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> HookContext<'a> {
|
||||||
|
pub(crate) fn new(event: HookEvent) -> Self {
|
||||||
|
Self {
|
||||||
|
event,
|
||||||
|
request: None,
|
||||||
|
error: None,
|
||||||
|
attempt: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn with_request(mut self, request: &'a ChatRequest) -> Self {
|
||||||
|
self.request = Some(request);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn with_error(mut self, error: &'a LlmError) -> Self {
|
||||||
|
self.error = Some(error);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn with_attempt(mut self, attempt: u32) -> Self {
|
||||||
|
self.attempt = attempt;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 钩子执行结果。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct HookResult {
|
||||||
|
/// 是否阻断后续操作(仅 PreRequest 有效)。
|
||||||
|
pub should_block: bool,
|
||||||
|
/// 阻断/备注原因。
|
||||||
|
pub reason: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HookResult {
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub(crate) fn allow() -> Self {
|
||||||
|
Self {
|
||||||
|
should_block: false,
|
||||||
|
reason: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub(crate) fn block(reason: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
should_block: true,
|
||||||
|
reason: Some(reason.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 生命周期钩子 trait。
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Hook: Send + Sync {
|
||||||
|
/// 执行钩子逻辑。
|
||||||
|
async fn execute(&self, ctx: &HookContext<'_>) -> HookResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 钩子执行器 —— 管理注册与触发。
|
||||||
|
pub struct HookExecutor {
|
||||||
|
hooks: Vec<(HookEvent, Box<dyn Hook>)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for HookExecutor {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HookExecutor {
|
||||||
|
/// 创建一个空的执行器。
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
hooks: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 注册一个钩子到指定事件点。
|
||||||
|
pub fn register(&mut self, event: HookEvent, hook: Box<dyn Hook>) {
|
||||||
|
self.hooks.push((event, hook));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 执行指定事件点的所有钩子。
|
||||||
|
pub async fn execute(&self, event: HookEvent, ctx: &HookContext<'_>) -> Vec<HookResult> {
|
||||||
|
let mut results = Vec::new();
|
||||||
|
for (e, hook) in &self.hooks {
|
||||||
|
if *e == event {
|
||||||
|
results.push(hook.execute(ctx).await);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results
|
||||||
|
}
|
||||||
|
}
|
||||||
+21
-1
@@ -1,7 +1,12 @@
|
|||||||
pub mod openai;
|
pub mod openai;
|
||||||
|
pub mod registry;
|
||||||
|
|
||||||
|
use std::pin::Pin;
|
||||||
|
|
||||||
|
use tokio_stream::Stream;
|
||||||
|
|
||||||
use crate::llm::error::LlmError;
|
use crate::llm::error::LlmError;
|
||||||
use crate::llm::types::{ChatRequest, ChatResponse};
|
use crate::llm::types::{ChatRequest, ChatResponse, OpenaiChatChunk};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
@@ -57,4 +62,19 @@ pub fn create_provider(
|
|||||||
pub trait LlmProvider: Send + Sync {
|
pub trait LlmProvider: Send + Sync {
|
||||||
/// 发送聊天请求并返回完整响应。
|
/// 发送聊天请求并返回完整响应。
|
||||||
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError>;
|
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError>;
|
||||||
|
|
||||||
|
/// 流式聊天请求 —— 返回原始 SSE chunk 流。
|
||||||
|
///
|
||||||
|
/// 默认实现回退到非流式调用(包装为单元素流)。
|
||||||
|
async fn chat_stream(
|
||||||
|
&self,
|
||||||
|
request: ChatRequest,
|
||||||
|
) -> Result<
|
||||||
|
Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||||
|
LlmError,
|
||||||
|
> {
|
||||||
|
let response = self.chat(request).await?;
|
||||||
|
let chunk = OpenaiChatChunk::from(response);
|
||||||
|
Ok(Box::pin(tokio_stream::once(Ok(chunk))))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+122
-1
@@ -1,12 +1,18 @@
|
|||||||
|
use std::pin::Pin;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use bytes::Bytes;
|
||||||
|
use futures_core::stream::Stream;
|
||||||
|
use futures_util::StreamExt;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use tracing::{debug, error, info};
|
use tracing::{debug, error, info};
|
||||||
|
|
||||||
use super::LlmProvider;
|
use super::LlmProvider;
|
||||||
use crate::llm::error::LlmError;
|
use crate::llm::error::LlmError;
|
||||||
use crate::llm::types::{ChatRequest, ChatResponse, OpenaiChatResponse};
|
use crate::llm::types::{
|
||||||
|
ChatRequest, ChatResponse, OpenaiChatChunk, OpenaiChatResponse, StreamOptions,
|
||||||
|
};
|
||||||
|
|
||||||
pub struct OpenaiProvider {
|
pub struct OpenaiProvider {
|
||||||
http_client: Client,
|
http_client: Client,
|
||||||
@@ -111,4 +117,119 @@ impl LlmProvider for OpenaiProvider {
|
|||||||
|
|
||||||
Ok(ChatResponse::from(chat_response))
|
Ok(ChatResponse::from(chat_response))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn chat_stream(
|
||||||
|
&self,
|
||||||
|
mut request: ChatRequest,
|
||||||
|
) -> Result<
|
||||||
|
Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||||
|
LlmError,
|
||||||
|
> {
|
||||||
|
let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
|
||||||
|
|
||||||
|
request.stream = Some(true);
|
||||||
|
request.stream_options = Some(StreamOptions {
|
||||||
|
include_usage: Some(true),
|
||||||
|
include_obfuscation: None,
|
||||||
|
});
|
||||||
|
|
||||||
|
info!(model = %request.model, "发送 LLM 流式请求");
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.http_client
|
||||||
|
.post(&url)
|
||||||
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
|
.json(&request)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
error!(error = %e, "流式请求失败");
|
||||||
|
Self::map_reqwest_error(e)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
let status_code: u16 = status.as_u16();
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
let body_text = response.text().await.unwrap_or_default();
|
||||||
|
error!(status = status_code, body = %body_text, "流式请求失败");
|
||||||
|
return Err(LlmError::Request {
|
||||||
|
status: status_code,
|
||||||
|
body: body_text,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let byte_stream = response.bytes_stream().map(|r| {
|
||||||
|
r.map_err(|e| LlmError::Other(format!("流式读取失败: {}", e)))
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(Box::pin(SseChunkStream::new(byte_stream)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SseChunkStream<S> {
|
||||||
|
inner: S,
|
||||||
|
buffer: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: Stream<Item = Result<Bytes, LlmError>> + Unpin> SseChunkStream<S> {
|
||||||
|
fn new(stream: S) -> Self {
|
||||||
|
Self {
|
||||||
|
inner: stream,
|
||||||
|
buffer: String::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: Stream<Item = Result<Bytes, LlmError>> + Unpin> Stream for SseChunkStream<S> {
|
||||||
|
type Item = Result<OpenaiChatChunk, LlmError>;
|
||||||
|
|
||||||
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
|
||||||
|
loop {
|
||||||
|
if let Some(pos) = self.buffer.find("\n") {
|
||||||
|
let line = self.buffer.drain(..pos + 1).collect::<String>();
|
||||||
|
let trimmed = line.trim();
|
||||||
|
if trimmed.is_empty()
|
||||||
|
|| trimmed == "data: [DONE]"
|
||||||
|
|| trimmed == "[DONE]"
|
||||||
|
|| trimmed == "data:"
|
||||||
|
{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let data = if let Some(p) = trimmed.strip_prefix("data: ") {
|
||||||
|
p
|
||||||
|
} else {
|
||||||
|
trimmed
|
||||||
|
};
|
||||||
|
match serde_json::from_str::<OpenaiChatChunk>(data) {
|
||||||
|
Ok(chunk) => return std::task::Poll::Ready(Some(Ok(chunk))),
|
||||||
|
Err(e) => {
|
||||||
|
return std::task::Poll::Ready(Some(Err(LlmError::Other(format!(
|
||||||
|
"Chunk 解析失败: {} | raw: {}",
|
||||||
|
e, data
|
||||||
|
)))));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match Pin::new(&mut self.inner).poll_next(cx) {
|
||||||
|
std::task::Poll::Ready(Some(Ok(bytes))) => {
|
||||||
|
if let Ok(s) = std::str::from_utf8(&bytes) {
|
||||||
|
self.buffer.push_str(s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::task::Poll::Ready(Some(Err(e))) => {
|
||||||
|
return std::task::Poll::Ready(Some(Err(e)));
|
||||||
|
}
|
||||||
|
std::task::Poll::Ready(None) => {
|
||||||
|
if self.buffer.is_empty() {
|
||||||
|
return std::task::Poll::Ready(None);
|
||||||
|
}
|
||||||
|
self.buffer.clear();
|
||||||
|
return std::task::Poll::Ready(None);
|
||||||
|
}
|
||||||
|
std::task::Poll::Pending => return std::task::Poll::Pending,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,68 @@
|
|||||||
|
//! Provider 注册表 —— 多 Provider 实例的注册与发现。
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use crate::llm::error::LlmError;
|
||||||
|
use crate::llm::provider::{create_provider, LlmProvider, ProviderConfig, ProviderType};
|
||||||
|
|
||||||
|
/// Provider 注册表 —— 管理多个 LLM Provider 实例。
|
||||||
|
///
|
||||||
|
/// 支持注册命名 Provider、按名称查找、设置默认 Provider。
|
||||||
|
pub struct ProviderRegistry {
|
||||||
|
providers: HashMap<String, Box<dyn LlmProvider>>,
|
||||||
|
default_name: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ProviderRegistry {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ProviderRegistry {
|
||||||
|
/// 创建一个空的注册表。
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
providers: HashMap::new(),
|
||||||
|
default_name: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 注册一个已初始化的 Provider 实例。
|
||||||
|
pub fn register(&mut self, name: impl Into<String>, provider: Box<dyn LlmProvider>) {
|
||||||
|
self.providers.insert(name.into(), provider);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 通过 ProviderType + ProviderConfig 创建并注册。
|
||||||
|
pub fn register_with_config(
|
||||||
|
&mut self,
|
||||||
|
name: impl Into<String>,
|
||||||
|
provider_type: ProviderType,
|
||||||
|
config: ProviderConfig,
|
||||||
|
) -> Result<(), LlmError> {
|
||||||
|
let provider = create_provider(provider_type, config)?;
|
||||||
|
self.register(name, provider);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 设置默认 Provider。
|
||||||
|
pub fn set_default(&mut self, name: &str) -> Result<(), LlmError> {
|
||||||
|
if !self.providers.contains_key(name) {
|
||||||
|
return Err(LlmError::Other(format!("Provider '{}' 不存在", name)));
|
||||||
|
}
|
||||||
|
self.default_name = Some(name.to_string());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 按名称查找 Provider。
|
||||||
|
pub fn get(&self, name: &str) -> Option<&dyn LlmProvider> {
|
||||||
|
self.providers.get(name).map(|p| p.as_ref())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取默认 Provider。
|
||||||
|
pub fn get_default(&self) -> Option<&dyn LlmProvider> {
|
||||||
|
self.default_name
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|name| self.get(name))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,108 @@
|
|||||||
|
//! 流式事件系统 —— 将 LLM 流式响应解析为语义化事件。
|
||||||
|
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
|
||||||
|
use futures_core::stream::Stream;
|
||||||
|
use futures_util::future::poll_fn;
|
||||||
|
use futures_util::FutureExt;
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
use crate::llm::error::LlmError;
|
||||||
|
use crate::llm::types::{FinishReason, OpenaiChatChunk, OpenaiToolCall, Usage};
|
||||||
|
|
||||||
|
/// 流式事件 —— LLM 调用全生命周期的语义化事件。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum StreamEvent {
|
||||||
|
/// 助手回复文本增量。
|
||||||
|
AssistantTextDelta { text: String },
|
||||||
|
/// 工具调用开始。
|
||||||
|
ToolExecutionStarted {
|
||||||
|
tool_name: String,
|
||||||
|
input: Value,
|
||||||
|
tool_call_id: String,
|
||||||
|
},
|
||||||
|
/// 工具调用完成。
|
||||||
|
ToolExecutionCompleted {
|
||||||
|
tool_name: String,
|
||||||
|
output: Value,
|
||||||
|
is_error: bool,
|
||||||
|
},
|
||||||
|
/// Token 用量更新。
|
||||||
|
CostUpdate { usage: Usage },
|
||||||
|
/// 一轮会话完成。
|
||||||
|
TurnComplete { reason: FinishReason },
|
||||||
|
/// 错误事件。
|
||||||
|
Error { message: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StreamEvent {
|
||||||
|
fn error(message: impl Into<String>) -> Self {
|
||||||
|
Self::Error {
|
||||||
|
message: message.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 将原始 OpenaiChatChunk 流解析为 StreamEvent 流。
|
||||||
|
pub fn parse_chunk_stream(
|
||||||
|
chunks: Pin<Box<dyn futures_core::Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||||
|
) -> Pin<Box<dyn futures_core::Stream<Item = StreamEvent> + Send>> {
|
||||||
|
Box::pin(ChunkToEventStream { chunks })
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ChunkToEventStream {
|
||||||
|
chunks: Pin<Box<dyn futures_core::Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Stream for ChunkToEventStream {
|
||||||
|
type Item = StreamEvent;
|
||||||
|
|
||||||
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
|
let this = &mut *self;
|
||||||
|
poll_fn(|cx| match Pin::new(&mut this.chunks).poll_next(cx) {
|
||||||
|
Poll::Ready(Some(Ok(chunk))) => {
|
||||||
|
for choice in &chunk.choices {
|
||||||
|
let delta = &choice.delta;
|
||||||
|
|
||||||
|
if let Some(content) = &delta.content {
|
||||||
|
return Poll::Ready(Some(StreamEvent::AssistantTextDelta {
|
||||||
|
text: content.clone(),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(tool_calls) = &delta.tool_calls
|
||||||
|
&& let Some(tc) = tool_calls.first()
|
||||||
|
{
|
||||||
|
let OpenaiToolCall::Function { id, function } = tc;
|
||||||
|
let args: Value =
|
||||||
|
serde_json::from_str(&function.arguments).unwrap_or(Value::Null);
|
||||||
|
return Poll::Ready(Some(StreamEvent::ToolExecutionStarted {
|
||||||
|
tool_name: function.name.clone(),
|
||||||
|
input: args,
|
||||||
|
tool_call_id: id.clone(),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(finish_reason) = &choice.finish_reason {
|
||||||
|
return Poll::Ready(Some(StreamEvent::TurnComplete {
|
||||||
|
reason: *finish_reason,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(usage) = &chunk.usage {
|
||||||
|
return Poll::Ready(Some(StreamEvent::CostUpdate {
|
||||||
|
usage: *usage,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
Poll::Ready(None)
|
||||||
|
}
|
||||||
|
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(StreamEvent::error(e.to_string()))),
|
||||||
|
Poll::Ready(None) => Poll::Ready(None),
|
||||||
|
Poll::Pending => Poll::Pending,
|
||||||
|
})
|
||||||
|
.poll_unpin(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -43,6 +43,34 @@ impl From<OpenaiChatResponse> for ChatResponse {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<ChatResponse> for OpenaiChatChunk {
|
||||||
|
fn from(response: ChatResponse) -> Self {
|
||||||
|
let delta = Delta::from(response.message.clone());
|
||||||
|
let chunk_choice = ChunkChoice {
|
||||||
|
index: 0,
|
||||||
|
delta,
|
||||||
|
logprobs: None,
|
||||||
|
finish_reason: response.stop_reason,
|
||||||
|
};
|
||||||
|
|
||||||
|
OpenaiChatChunk {
|
||||||
|
id: format!("chunk-{}", std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.map(|d| d.as_nanos())
|
||||||
|
.unwrap_or(0)),
|
||||||
|
object: "chat.completion.chunk".to_string(),
|
||||||
|
created: std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.map(|d| d.as_secs())
|
||||||
|
.unwrap_or(0),
|
||||||
|
model: String::new(),
|
||||||
|
choices: vec![chunk_choice],
|
||||||
|
usage: Some(response.usage),
|
||||||
|
system_fingerprint: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub type ChatRequest = OpenaiChatRequest;
|
pub type ChatRequest = OpenaiChatRequest;
|
||||||
pub type Message = OpenaiChatMessage;
|
pub type Message = OpenaiChatMessage;
|
||||||
pub type ContentBlock = OpenaiContentPart;
|
pub type ContentBlock = OpenaiContentPart;
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ use crate::llm::types::shared::{FinishReason, ServiceTier};
|
|||||||
use crate::llm::types::tool::OpenaiToolCall;
|
use crate::llm::types::tool::OpenaiToolCall;
|
||||||
use crate::llm::types::usage::Usage;
|
use crate::llm::types::usage::Usage;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use crate::llm::types::{ContentField, OpenaiContentPart};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct TokenLogprob {
|
pub struct TokenLogprob {
|
||||||
@@ -115,3 +116,66 @@ pub struct OpenaiChatChunk {
|
|||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub system_fingerprint: Option<String>,
|
pub system_fingerprint: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<OpenaiChatMessage> for Delta {
|
||||||
|
fn from(msg: OpenaiChatMessage) -> Self {
|
||||||
|
match msg {
|
||||||
|
OpenaiChatMessage::Assistant {
|
||||||
|
content,
|
||||||
|
tool_calls,
|
||||||
|
..
|
||||||
|
} => Delta {
|
||||||
|
role: Some("assistant".to_string()),
|
||||||
|
content: match content {
|
||||||
|
ContentField::String(s) => Some(s),
|
||||||
|
ContentField::Array(parts) => {
|
||||||
|
let mut text = String::new();
|
||||||
|
for part in parts {
|
||||||
|
if let OpenaiContentPart::Text { text: t } = part {
|
||||||
|
text.push_str(&t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if text.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
refusal: None,
|
||||||
|
tool_calls,
|
||||||
|
},
|
||||||
|
_ => Delta {
|
||||||
|
role: None,
|
||||||
|
content: None,
|
||||||
|
refusal: None,
|
||||||
|
tool_calls: None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<OpenaiChatResponse> for OpenaiChatChunk {
|
||||||
|
fn from(response: OpenaiChatResponse) -> Self {
|
||||||
|
let choices = response
|
||||||
|
.choices
|
||||||
|
.into_iter()
|
||||||
|
.map(|c| ChunkChoice {
|
||||||
|
index: c.index,
|
||||||
|
delta: Delta::from(c.message),
|
||||||
|
logprobs: c.logprobs,
|
||||||
|
finish_reason: c.finish_reason,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
OpenaiChatChunk {
|
||||||
|
id: response.id,
|
||||||
|
object: "chat.completion.chunk".to_string(),
|
||||||
|
created: response.created,
|
||||||
|
model: response.model,
|
||||||
|
choices,
|
||||||
|
usage: Some(response.usage),
|
||||||
|
system_fingerprint: response.system_fingerprint,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
//! 记忆系统 —— 对话消息管理、知识页面存储与关键词检索。
|
||||||
|
|
||||||
|
pub mod conversation;
|
||||||
|
pub mod error;
|
||||||
|
pub mod knowledge;
|
||||||
|
pub mod retriever;
|
||||||
|
pub mod store;
|
||||||
|
pub mod types;
|
||||||
|
|
||||||
|
// 高频类型(大多数下游需要)
|
||||||
|
pub use conversation::{ConversationMemory, ConversationMemoryConfig};
|
||||||
|
pub use error::MemoryError;
|
||||||
|
pub use knowledge::KnowledgeStore;
|
||||||
|
pub use retriever::MemoryRetriever;
|
||||||
|
pub use store::{InMemoryStore, MemoryStore};
|
||||||
|
|
||||||
|
// 低频类型(配置/高级使用)
|
||||||
|
pub use conversation::MemoryStrategy;
|
||||||
|
pub use knowledge::{PageIndexEntry, KNOWLEDGE_PREFIX};
|
||||||
|
pub use retriever::{RetrieverConfig, RetrievalResult, ScoredItem};
|
||||||
|
pub use store::{EvictionConfig, EvictionPolicy};
|
||||||
|
pub use types::{KnowledgePage, MemoryFilter, MemoryItem};
|
||||||
@@ -0,0 +1,260 @@
|
|||||||
|
//! 对话记忆 —— 多轮对话消息管理,复用 `llm::compact` 的压缩逻辑。
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use time::OffsetDateTime;
|
||||||
|
|
||||||
|
use crate::llm::compact::{CompactConfig, CompactState, microcompact, should_compact};
|
||||||
|
use crate::llm::types::OpenaiChatMessage;
|
||||||
|
use crate::memory::error::MemoryError;
|
||||||
|
use crate::memory::store::MemoryStore;
|
||||||
|
use crate::memory::types::MemoryItem;
|
||||||
|
|
||||||
|
/// 对话消息管理策略。
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
pub enum MemoryStrategy {
|
||||||
|
/// 滑动窗口:达到上限时删除最旧消息。
|
||||||
|
SlidingWindow,
|
||||||
|
/// 保留所有消息(仅压缩,不删除)。
|
||||||
|
Full,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 对话记忆配置。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ConversationMemoryConfig {
|
||||||
|
pub strategy: MemoryStrategy,
|
||||||
|
pub max_turns: usize,
|
||||||
|
pub compact_config: Option<CompactConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ConversationMemoryConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
strategy: MemoryStrategy::SlidingWindow,
|
||||||
|
max_turns: 50,
|
||||||
|
compact_config: Some(CompactConfig::default()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 对话记忆 —— 按 session 管理多轮对话消息历史。
|
||||||
|
///
|
||||||
|
/// 内部维护 `Vec<OpenaiChatMessage>` 热缓存(供 `llm::compact` 直接操作),
|
||||||
|
/// `MemoryStore` 用作冷持久化层。
|
||||||
|
pub struct ConversationMemory {
|
||||||
|
store: Arc<dyn MemoryStore>,
|
||||||
|
session_id: String,
|
||||||
|
config: ConversationMemoryConfig,
|
||||||
|
/// 热缓存:消息列表,供 `llm::compact` 直接操作。
|
||||||
|
messages: Vec<OpenaiChatMessage>,
|
||||||
|
/// 与 `messages` 一一对应的存储 ID(保持稳定以便淘汰时精准删除)。
|
||||||
|
message_ids: Vec<String>,
|
||||||
|
/// 压缩断路器状态。
|
||||||
|
compact_state: CompactState,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConversationMemory {
|
||||||
|
/// 创建一个新的 ConversationMemory。
|
||||||
|
pub fn new(
|
||||||
|
store: Arc<dyn MemoryStore>,
|
||||||
|
session_id: impl Into<String>,
|
||||||
|
config: ConversationMemoryConfig,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
store,
|
||||||
|
session_id: session_id.into(),
|
||||||
|
config,
|
||||||
|
messages: Vec::new(),
|
||||||
|
message_ids: Vec::new(),
|
||||||
|
compact_state: CompactState::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取 session id。
|
||||||
|
pub fn session_id(&self) -> &str {
|
||||||
|
&self.session_id
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取配置。
|
||||||
|
pub fn config(&self) -> &ConversationMemoryConfig {
|
||||||
|
&self.config
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从 MemoryStore 加载历史消息到热缓存。
|
||||||
|
pub async fn load(&mut self) -> Result<(), MemoryError> {
|
||||||
|
let filter = crate::memory::types::MemoryFilter {
|
||||||
|
prefix: Some(self.session_prefix()),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let items = self.store.list(&filter).await?;
|
||||||
|
let mut pairs: Vec<(String, OpenaiChatMessage, OffsetDateTime)> = Vec::with_capacity(items.len());
|
||||||
|
for item in items {
|
||||||
|
match serde_json::from_str::<OpenaiChatMessage>(&item.content) {
|
||||||
|
Ok(msg) => pairs.push((item.id, msg, item.created_at)),
|
||||||
|
Err(e) => {
|
||||||
|
return Err(MemoryError::Serialization(format!(
|
||||||
|
"load message {} failed: {e}",
|
||||||
|
item.id
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 按 created_at 升序排列
|
||||||
|
pairs.sort_by_key(|p| p.2);
|
||||||
|
self.message_ids = pairs.iter().map(|p| p.0.clone()).collect();
|
||||||
|
self.messages = pairs.into_iter().map(|p| p.1).collect();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加一条消息。
|
||||||
|
///
|
||||||
|
/// 写入热缓存并通过 `MemoryStore` 持久化。如有需要,触发淘汰和压缩。
|
||||||
|
pub async fn add_message(&mut self, msg: OpenaiChatMessage) -> Result<(), MemoryError> {
|
||||||
|
let now = OffsetDateTime::now_utc();
|
||||||
|
let index = self.messages.len();
|
||||||
|
let id = self.make_message_id(index, &now);
|
||||||
|
|
||||||
|
// 写入热缓存
|
||||||
|
self.messages.push(msg);
|
||||||
|
self.message_ids.push(id.clone());
|
||||||
|
|
||||||
|
// 同步到冷存储
|
||||||
|
let item = MemoryItem {
|
||||||
|
id: id.clone(),
|
||||||
|
content: serde_json::to_string(self.messages.last().unwrap())
|
||||||
|
.map_err(|e| MemoryError::Serialization(e.to_string()))?,
|
||||||
|
metadata: serde_json::json!({ "session_id": &self.session_id, "index": index }),
|
||||||
|
created_at: now,
|
||||||
|
};
|
||||||
|
self.store.save(item).await?;
|
||||||
|
|
||||||
|
// 触发淘汰和压缩
|
||||||
|
self.maybe_evict_and_compact().await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取完整消息历史。
|
||||||
|
pub fn get_history(&self) -> &[OpenaiChatMessage] {
|
||||||
|
&self.messages
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 清空所有消息。
|
||||||
|
pub async fn clear(&mut self) -> Result<(), MemoryError> {
|
||||||
|
let to_delete = std::mem::take(&mut self.message_ids);
|
||||||
|
self.messages.clear();
|
||||||
|
self.compact_state = CompactState::new();
|
||||||
|
for id in to_delete {
|
||||||
|
self.store.delete(&id).await?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 当前消息数量。
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.messages.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 是否为空。
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.messages.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn session_prefix(&self) -> String {
|
||||||
|
format!("conv:{self}:", self = self.session_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_message_id(&self, index: usize, now: &OffsetDateTime) -> String {
|
||||||
|
format!("{}{:010}_{}", self.session_prefix(), index, now.unix_timestamp_nanos())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn maybe_evict_and_compact(&mut self) {
|
||||||
|
// 1. Sliding window 淘汰:删除最旧消息
|
||||||
|
if self.config.strategy == MemoryStrategy::SlidingWindow {
|
||||||
|
while self.messages.len() > self.config.max_turns {
|
||||||
|
if let Some(removed_id) = self.message_ids.first().cloned() {
|
||||||
|
let _ = self.store.delete(&removed_id).await;
|
||||||
|
}
|
||||||
|
self.messages.remove(0);
|
||||||
|
self.message_ids.remove(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 压缩(复用 llm::compact)
|
||||||
|
if let Some(ref compact_config) = self.config.compact_config {
|
||||||
|
if should_compact(&self.messages, compact_config, &self.compact_state) {
|
||||||
|
let keep_recent = compact_config.keep_recent;
|
||||||
|
let freed = microcompact(&mut self.messages, keep_recent);
|
||||||
|
if freed > 0 {
|
||||||
|
self.compact_state.record_success();
|
||||||
|
} else {
|
||||||
|
// 没有 token 被释放(可能没找到可压缩的 tool result)
|
||||||
|
let _ = self.compact_state.record_failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::llm::types::OpenaiChatMessage;
|
||||||
|
use crate::memory::InMemoryStore;
|
||||||
|
use crate::memory::MemoryStore;
|
||||||
|
|
||||||
|
fn user_text(s: &str) -> OpenaiChatMessage {
|
||||||
|
OpenaiChatMessage::user_text(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn add_and_get_history() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let mut conv = ConversationMemory::new(store, "session1", ConversationMemoryConfig::default());
|
||||||
|
conv.add_message(user_text("hello")).await.unwrap();
|
||||||
|
conv.add_message(user_text("world")).await.unwrap();
|
||||||
|
assert_eq!(conv.len(), 2);
|
||||||
|
assert_eq!(conv.get_history().len(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn sliding_window_evicts_oldest() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let config = ConversationMemoryConfig {
|
||||||
|
strategy: MemoryStrategy::SlidingWindow,
|
||||||
|
max_turns: 3,
|
||||||
|
compact_config: None,
|
||||||
|
};
|
||||||
|
let mut conv = ConversationMemory::new(store, "s1", config);
|
||||||
|
for i in 0..5 {
|
||||||
|
conv.add_message(user_text(&format!("msg-{i}"))).await.unwrap();
|
||||||
|
}
|
||||||
|
assert_eq!(conv.len(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn full_strategy_no_evict() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let config = ConversationMemoryConfig {
|
||||||
|
strategy: MemoryStrategy::Full,
|
||||||
|
max_turns: 3,
|
||||||
|
compact_config: None,
|
||||||
|
};
|
||||||
|
let mut conv = ConversationMemory::new(store, "s1", config);
|
||||||
|
for i in 0..5 {
|
||||||
|
conv.add_message(user_text(&format!("msg-{i}"))).await.unwrap();
|
||||||
|
}
|
||||||
|
// Full 策略不删除消息
|
||||||
|
assert_eq!(conv.len(), 5);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn clear_empties_messages() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let mut conv = ConversationMemory::new(store.clone(), "s1", ConversationMemoryConfig::default());
|
||||||
|
conv.add_message(user_text("hello")).await.unwrap();
|
||||||
|
assert!(!conv.is_empty());
|
||||||
|
conv.clear().await.unwrap();
|
||||||
|
assert!(conv.is_empty());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
//! 记忆系统错误类型。
|
||||||
|
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
/// 记忆系统错误枚举。
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum MemoryError {
|
||||||
|
#[error("Item not found: {0}")]
|
||||||
|
NotFound(String),
|
||||||
|
|
||||||
|
#[error("Storage error: {0}")]
|
||||||
|
Storage(String),
|
||||||
|
|
||||||
|
#[error("Serialization error: {0}")]
|
||||||
|
Serialization(String),
|
||||||
|
|
||||||
|
#[error("Invalid input: {0}")]
|
||||||
|
InvalidInput(String),
|
||||||
|
|
||||||
|
#[error("Retrieval error: {0}")]
|
||||||
|
RetrievalError(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MemoryError {
|
||||||
|
/// 是否为可恢复错误(调用方可重试或调整参数)。
|
||||||
|
pub fn is_recoverable(&self) -> bool {
|
||||||
|
matches!(self, Self::NotFound(_) | Self::RetrievalError(_))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,247 @@
|
|||||||
|
//! 知识库 —— KnowledgePage 存储与关键词检索。
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use time::OffsetDateTime;
|
||||||
|
|
||||||
|
use crate::memory::error::MemoryError;
|
||||||
|
use crate::memory::store::MemoryStore;
|
||||||
|
use crate::memory::types::{KnowledgePage, MemoryFilter, MemoryItem};
|
||||||
|
|
||||||
|
pub use crate::memory::types::PageIndexEntry;
|
||||||
|
|
||||||
|
/// `MemoryItem.id` 中知识页面前缀。
|
||||||
|
pub const KNOWLEDGE_PREFIX: &str = "knowledge_";
|
||||||
|
|
||||||
|
/// 知识库 —— KnowledgePage CRUD + 关键词检索 + 内容索引。
|
||||||
|
///
|
||||||
|
/// 内部以 `MemoryStore` 为后端存储 KnowledgePage(序列化为 JSON),
|
||||||
|
/// 同时维护一个 `Vec<PageIndexEntry>` 索引以加速列表遍历。
|
||||||
|
pub struct KnowledgeStore {
|
||||||
|
store: Arc<dyn MemoryStore>,
|
||||||
|
index: std::sync::Mutex<Vec<PageIndexEntry>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl KnowledgeStore {
|
||||||
|
/// 创建一个新的 KnowledgeStore。
|
||||||
|
pub fn new(store: Arc<dyn MemoryStore>) -> Self {
|
||||||
|
Self {
|
||||||
|
store,
|
||||||
|
index: std::sync::Mutex::new(Vec::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从 MemoryStore 重建索引(修复 index 与 store 的不同步问题)。
|
||||||
|
pub async fn rebuild_index(&self) -> Result<(), MemoryError> {
|
||||||
|
let items = self
|
||||||
|
.store
|
||||||
|
.list(&MemoryFilter {
|
||||||
|
prefix: Some(KNOWLEDGE_PREFIX.to_string()),
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
let mut index = self.index.lock().unwrap();
|
||||||
|
index.clear();
|
||||||
|
for item in items {
|
||||||
|
let page: KnowledgePage = serde_json::from_str(&item.content)
|
||||||
|
.map_err(|e| MemoryError::Serialization(e.to_string()))?;
|
||||||
|
index.push(PageIndexEntry::from(&page));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 创建一个新的知识页面。
|
||||||
|
pub async fn add_page(&self, page: KnowledgePage) -> Result<(), MemoryError> {
|
||||||
|
if page.id.is_empty() {
|
||||||
|
return Err(MemoryError::InvalidInput("page.id is empty".into()));
|
||||||
|
}
|
||||||
|
let now = OffsetDateTime::now_utc();
|
||||||
|
let id = format!("{KNOWLEDGE_PREFIX}{}", page.id);
|
||||||
|
let content = serde_json::to_string(&page)
|
||||||
|
.map_err(|e| MemoryError::Serialization(e.to_string()))?;
|
||||||
|
let item = MemoryItem {
|
||||||
|
id,
|
||||||
|
content,
|
||||||
|
metadata: serde_json::json!({}),
|
||||||
|
created_at: now,
|
||||||
|
};
|
||||||
|
self.store.save(item).await?;
|
||||||
|
let mut index = self.index.lock().unwrap();
|
||||||
|
// 替换或追加
|
||||||
|
if let Some(existing) = index.iter_mut().find(|e| e.id == page.id) {
|
||||||
|
*existing = PageIndexEntry::from(&page);
|
||||||
|
} else {
|
||||||
|
index.push(PageIndexEntry::from(&page));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 根据 page id 获取一个页面。
|
||||||
|
pub async fn get_page(&self, id: &str) -> Result<Option<KnowledgePage>, MemoryError> {
|
||||||
|
let full_id = format!("{KNOWLEDGE_PREFIX}{id}");
|
||||||
|
let item = self.store.get(&full_id).await?;
|
||||||
|
match item {
|
||||||
|
None => Ok(None),
|
||||||
|
Some(item) => {
|
||||||
|
let page: KnowledgePage = serde_json::from_str(&item.content)
|
||||||
|
.map_err(|e| MemoryError::Serialization(e.to_string()))?;
|
||||||
|
Ok(Some(page))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 更新一个已存在的知识页面。
|
||||||
|
pub async fn update_page(&self, page: KnowledgePage) -> Result<(), MemoryError> {
|
||||||
|
if page.id.is_empty() {
|
||||||
|
return Err(MemoryError::InvalidInput("page.id is empty".into()));
|
||||||
|
}
|
||||||
|
// 通过 get_page 检查存在性
|
||||||
|
if self.get_page(&page.id).await?.is_none() {
|
||||||
|
return Err(MemoryError::NotFound(page.id));
|
||||||
|
}
|
||||||
|
self.add_page(page).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 删除一个知识页面。
|
||||||
|
pub async fn delete_page(&self, id: &str) -> Result<(), MemoryError> {
|
||||||
|
let full_id = format!("{KNOWLEDGE_PREFIX}{id}");
|
||||||
|
self.store.delete(&full_id).await?;
|
||||||
|
let mut index = self.index.lock().unwrap();
|
||||||
|
index.retain(|e| e.id != id);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 根据关键词搜索知识页面。
|
||||||
|
///
|
||||||
|
/// 匹配规则:在 `title` / `summary` / `tags` 中查找子串(不区分大小写)。
|
||||||
|
/// 全文 `content` 搜索走 `MemoryStore`。
|
||||||
|
pub async fn search(&self, query: &str) -> Result<Vec<KnowledgePage>, MemoryError> {
|
||||||
|
if query.is_empty() {
|
||||||
|
return Ok(Vec::new());
|
||||||
|
}
|
||||||
|
let needle = query.to_lowercase();
|
||||||
|
let mut results = Vec::new();
|
||||||
|
let index = self.index.lock().unwrap();
|
||||||
|
for entry in index.iter() {
|
||||||
|
if entry.title.to_lowercase().contains(&needle)
|
||||||
|
|| entry.summary.to_lowercase().contains(&needle)
|
||||||
|
|| entry.tags.iter().any(|t| t.to_lowercase().contains(&needle))
|
||||||
|
{
|
||||||
|
if let Some(page) = self.get_page(&entry.id).await? {
|
||||||
|
results.push(page);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(results)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取内容目录(所有页面的轻量级索引条目)。
|
||||||
|
pub fn get_index(&self) -> Vec<PageIndexEntry> {
|
||||||
|
self.index.lock().unwrap().clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::memory::InMemoryStore;
|
||||||
|
use time::OffsetDateTime;
|
||||||
|
|
||||||
|
fn make_page(id: &str, title: &str, tags: &[&str]) -> KnowledgePage {
|
||||||
|
let now = OffsetDateTime::now_utc();
|
||||||
|
KnowledgePage {
|
||||||
|
id: id.to_string(),
|
||||||
|
title: title.to_string(),
|
||||||
|
summary: format!("summary of {title}"),
|
||||||
|
content: format!("full content of {title}"),
|
||||||
|
tags: tags.iter().map(|s| s.to_string()).collect(),
|
||||||
|
references: Vec::new(),
|
||||||
|
created_at: now,
|
||||||
|
updated_at: now,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn add_get_delete_page() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let ks = KnowledgeStore::new(store);
|
||||||
|
ks.add_page(make_page("p1", "LangGraph", &["langgraph", "framework"]))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let got = ks.get_page("p1").await.unwrap();
|
||||||
|
assert!(got.is_some());
|
||||||
|
assert_eq!(got.unwrap().title, "LangGraph");
|
||||||
|
|
||||||
|
ks.delete_page("p1").await.unwrap();
|
||||||
|
assert!(ks.get_page("p1").await.unwrap().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn add_page_rejects_empty_id() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let ks = KnowledgeStore::new(store);
|
||||||
|
let result = ks.add_page(make_page("", "NoId", &[])).await;
|
||||||
|
assert!(matches!(result, Err(MemoryError::InvalidInput(_))));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn update_page_requires_existing() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let ks = KnowledgeStore::new(store);
|
||||||
|
let result = ks.update_page(make_page("nope", "Ghost", &[])).await;
|
||||||
|
assert!(matches!(result, Err(MemoryError::NotFound(_))));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn search_finds_by_title_summary_tag() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let ks = KnowledgeStore::new(store);
|
||||||
|
ks.add_page(make_page("p1", "LangGraph StateGraph", &["llm"]))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
ks.add_page(make_page("p2", "Other", &["knowledge-graph"]))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
ks.add_page(make_page("p3", "Third", &["unrelated"]))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let results = ks.search("stategraph").await.unwrap();
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
assert_eq!(results[0].id, "p1");
|
||||||
|
|
||||||
|
let results = ks.search("knowledge-graph").await.unwrap();
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
assert_eq!(results[0].id, "p2");
|
||||||
|
|
||||||
|
let results = ks.search("nonexistent").await.unwrap();
|
||||||
|
assert!(results.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn get_index_returns_all_pages() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let ks = KnowledgeStore::new(store);
|
||||||
|
ks.add_page(make_page("p1", "A", &[])).await.unwrap();
|
||||||
|
ks.add_page(make_page("p2", "B", &[])).await.unwrap();
|
||||||
|
let index = ks.get_index();
|
||||||
|
assert_eq!(index.len(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn rebuild_index_recovers_from_drift() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let ks = KnowledgeStore::new(store);
|
||||||
|
// 添加页面
|
||||||
|
ks.add_page(make_page("p1", "A", &[])).await.unwrap();
|
||||||
|
ks.add_page(make_page("p2", "B", &[])).await.unwrap();
|
||||||
|
assert_eq!(ks.get_index().len(), 2);
|
||||||
|
|
||||||
|
// 模拟 index 漂移:清空后重建
|
||||||
|
ks.index.lock().unwrap().clear();
|
||||||
|
assert_eq!(ks.get_index().len(), 0);
|
||||||
|
|
||||||
|
ks.rebuild_index().await.unwrap();
|
||||||
|
assert_eq!(ks.get_index().len(), 2);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,297 @@
|
|||||||
|
//! 记忆检索器 —— 基于 TextOverlap (Dice 系数) 的单通道关键词检索。
|
||||||
|
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
use crate::memory::error::MemoryError;
|
||||||
|
use crate::memory::knowledge::KnowledgeStore;
|
||||||
|
use crate::memory::types::KnowledgePage;
|
||||||
|
|
||||||
|
/// 检索器配置。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RetrieverConfig {
|
||||||
|
/// 最大返回条数(默认 20)。
|
||||||
|
pub max_results: usize,
|
||||||
|
/// 最低分数阈值 [0.0, 1.0](默认 0.1)。
|
||||||
|
pub min_score: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for RetrieverConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
max_results: 20,
|
||||||
|
min_score: 0.1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 单条带评分的检索结果。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ScoredItem {
|
||||||
|
pub page: KnowledgePage,
|
||||||
|
/// TextOverlap 评分 [0.0, 1.0]
|
||||||
|
pub score: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 检索结果。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RetrievalResult {
|
||||||
|
pub items: Vec<ScoredItem>,
|
||||||
|
pub query: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 记忆检索器 —— 在 `KnowledgeStore` 中做关键词检索并按 TextOverlap 评分。
|
||||||
|
pub struct MemoryRetriever {
|
||||||
|
knowledge_store: KnowledgeStore,
|
||||||
|
config: RetrieverConfig,
|
||||||
|
/// 停用词表(用于关键词提取)。
|
||||||
|
stop_words: HashSet<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MemoryRetriever {
|
||||||
|
/// 创建一个新的 MemoryRetriever。
|
||||||
|
pub fn new(knowledge_store: KnowledgeStore, config: RetrieverConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
knowledge_store,
|
||||||
|
config,
|
||||||
|
stop_words: default_stop_words(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 替换停用词表。
|
||||||
|
pub fn with_stop_words(mut self, stop_words: HashSet<String>) -> Self {
|
||||||
|
self.stop_words = stop_words;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 检索相关知识页面。
|
||||||
|
pub async fn retrieve(&self, query: &str) -> Result<RetrievalResult, MemoryError> {
|
||||||
|
if query.is_empty() {
|
||||||
|
return Ok(RetrievalResult {
|
||||||
|
items: Vec::new(),
|
||||||
|
query: query.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. 关键词提取
|
||||||
|
let keywords = extract_keywords(query, &self.stop_words);
|
||||||
|
|
||||||
|
// 2. 用关键词在 KnowledgeStore 中搜索
|
||||||
|
let mut pages = Vec::new();
|
||||||
|
for keyword in &keywords {
|
||||||
|
let found = self.knowledge_store.search(keyword).await?;
|
||||||
|
for page in found {
|
||||||
|
if !pages.iter().any(|p: &KnowledgePage| p.id == page.id) {
|
||||||
|
pages.push(page);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. TextOverlap 评分
|
||||||
|
let mut items: Vec<ScoredItem> = pages
|
||||||
|
.into_iter()
|
||||||
|
.map(|page| {
|
||||||
|
let score = text_overlap_score(query, &page);
|
||||||
|
ScoredItem { page, score }
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// 4. 过滤 → 排序 → 截取
|
||||||
|
items.retain(|i| i.score >= self.config.min_score);
|
||||||
|
items.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
|
||||||
|
items.truncate(self.config.max_results);
|
||||||
|
|
||||||
|
Ok(RetrievalResult {
|
||||||
|
items,
|
||||||
|
query: query.to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从 query 中提取关键词:按非字母数字字符分割 → 转小写 → 过滤单字符和停用词。
|
||||||
|
fn extract_keywords(query: &str, stop_words: &HashSet<String>) -> Vec<String> {
|
||||||
|
query
|
||||||
|
.split(|c: char| !c.is_alphanumeric())
|
||||||
|
.filter_map(|s| {
|
||||||
|
let lower = s.to_lowercase();
|
||||||
|
if lower.is_empty() || lower.chars().count() < 2 {
|
||||||
|
None
|
||||||
|
} else if stop_words.contains(&lower) {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(lower)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TextOverlap 评分(基于字符 bigram 的 Dice 系数 + 多字段加权)。
|
||||||
|
///
|
||||||
|
/// 字段权重:title 0.5 + summary 0.3 + content 0.2
|
||||||
|
/// 中文场景按字符级 bigram 处理,不依赖分词器。
|
||||||
|
pub fn text_overlap_score(query: &str, page: &KnowledgePage) -> f32 {
|
||||||
|
let title = text_overlap_dice(query, &page.title);
|
||||||
|
let summary = text_overlap_dice(query, &page.summary);
|
||||||
|
let content = text_overlap_dice(query, &page.content);
|
||||||
|
title * 0.5 + summary * 0.3 + content * 0.2
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Dice 系数(基于字符 bigram)。
|
||||||
|
fn text_overlap_dice(query: &str, text: &str) -> f32 {
|
||||||
|
let q_bigrams = char_bigrams(query);
|
||||||
|
let t_bigrams = char_bigrams(text);
|
||||||
|
if q_bigrams.is_empty() || t_bigrams.is_empty() {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
let q_set: HashSet<&String> = q_bigrams.iter().collect();
|
||||||
|
let t_set: HashSet<&String> = t_bigrams.iter().collect();
|
||||||
|
let intersect = q_set.intersection(&t_set).count();
|
||||||
|
let denom = q_set.len() + t_set.len();
|
||||||
|
if denom == 0 {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
(2.0 * intersect as f32) / denom as f32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 提取字符 bigrams(用于字符级 Dice 系数)。
|
||||||
|
fn char_bigrams(s: &str) -> Vec<String> {
|
||||||
|
let chars: Vec<char> = s.chars().collect();
|
||||||
|
chars.windows(2).map(|w| w.iter().collect()).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_stop_words() -> HashSet<String> {
|
||||||
|
[
|
||||||
|
"the", "a", "an", "is", "are", "was", "were", "be", "been", "being", "have", "has",
|
||||||
|
"had", "do", "does", "did", "will", "would", "should", "could", "may", "might", "shall",
|
||||||
|
"can", "this", "that", "these", "those", "it", "its", "they", "them", "their", "what",
|
||||||
|
"which", "who", "whom", "how", "when", "where", "and", "or", "but", "not", "no", "nor",
|
||||||
|
"so", "if", "then", "else", "with", "without", "for", "to", "from", "in", "on", "at",
|
||||||
|
"by", "of", "as", "into", "through", "during", "before", "after", "above", "below",
|
||||||
|
]
|
||||||
|
.iter()
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::memory::knowledge::KnowledgeStore;
|
||||||
|
use crate::memory::{InMemoryStore, MemoryStore};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use time::OffsetDateTime;
|
||||||
|
|
||||||
|
fn make_page(id: &str, title: &str, summary: &str, content: &str) -> KnowledgePage {
|
||||||
|
let now = OffsetDateTime::now_utc();
|
||||||
|
KnowledgePage {
|
||||||
|
id: id.to_string(),
|
||||||
|
title: title.to_string(),
|
||||||
|
summary: summary.to_string(),
|
||||||
|
content: content.to_string(),
|
||||||
|
tags: Vec::new(),
|
||||||
|
references: Vec::new(),
|
||||||
|
created_at: now,
|
||||||
|
updated_at: now,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn retrieve_empty_query() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let ks = KnowledgeStore::new(store);
|
||||||
|
let retriever = MemoryRetriever::new(ks, RetrieverConfig::default());
|
||||||
|
let result = retriever.retrieve("").await.unwrap();
|
||||||
|
assert!(result.items.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn retrieve_finds_relevant_page() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let ks = KnowledgeStore::new(store);
|
||||||
|
ks.add_page(make_page(
|
||||||
|
"p1",
|
||||||
|
"LangGraph StateGraph",
|
||||||
|
"state management",
|
||||||
|
"LangGraph uses StateGraph for state machines",
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
ks.add_page(make_page("p2", "Other", "unrelated", "nothing matching"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let retriever = MemoryRetriever::new(ks, RetrieverConfig::default());
|
||||||
|
let result = retriever.retrieve("LangGraph state").await.unwrap();
|
||||||
|
assert!(!result.items.is_empty());
|
||||||
|
assert_eq!(result.items[0].page.id, "p1");
|
||||||
|
assert!(result.items[0].score > 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn retrieve_respects_min_score() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let ks = KnowledgeStore::new(store);
|
||||||
|
ks.add_page(make_page("p1", "X", "Y", "Z")).await.unwrap();
|
||||||
|
|
||||||
|
let config = RetrieverConfig {
|
||||||
|
max_results: 10,
|
||||||
|
min_score: 0.99,
|
||||||
|
};
|
||||||
|
let retriever = MemoryRetriever::new(ks, config);
|
||||||
|
let result = retriever.retrieve("totally unrelated content").await.unwrap();
|
||||||
|
assert!(result.items.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn retrieve_respects_max_results() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let ks = KnowledgeStore::new(store);
|
||||||
|
for i in 0..5 {
|
||||||
|
ks.add_page(make_page(
|
||||||
|
&format!("p{i}"),
|
||||||
|
"LangGraph",
|
||||||
|
"framework",
|
||||||
|
"agent runtime",
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
let config = RetrieverConfig {
|
||||||
|
max_results: 2,
|
||||||
|
min_score: 0.0,
|
||||||
|
};
|
||||||
|
let retriever = MemoryRetriever::new(ks, config);
|
||||||
|
let result = retriever.retrieve("LangGraph").await.unwrap();
|
||||||
|
assert_eq!(result.items.len(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn text_overlap_dice_zero_on_empty() {
|
||||||
|
assert_eq!(text_overlap_dice("hello", ""), 0.0);
|
||||||
|
assert_eq!(text_overlap_dice("", "hello"), 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn text_overlap_dice_identical() {
|
||||||
|
let s = "hello world";
|
||||||
|
assert!((text_overlap_dice(s, s) - 1.0).abs() < 0.001);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_keywords_filters_stop_words() {
|
||||||
|
let stop = default_stop_words();
|
||||||
|
let kws = extract_keywords("the quick brown fox is fast", &stop);
|
||||||
|
assert!(!kws.contains(&"the".to_string()));
|
||||||
|
assert!(!kws.contains(&"is".to_string()));
|
||||||
|
assert!(kws.contains(&"quick".to_string()));
|
||||||
|
assert!(kws.contains(&"brown".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_keywords_filters_single_chars() {
|
||||||
|
let stop = default_stop_words();
|
||||||
|
let kws = extract_keywords("a b c dog", &stop);
|
||||||
|
assert!(!kws.contains(&"a".to_string()));
|
||||||
|
assert!(!kws.contains(&"b".to_string()));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,314 @@
|
|||||||
|
//! MemoryStore 抽象接口与默认实现。
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use time::OffsetDateTime;
|
||||||
|
|
||||||
|
use crate::memory::error::MemoryError;
|
||||||
|
use crate::memory::types::{MemoryFilter, MemoryItem};
|
||||||
|
|
||||||
|
/// 底层记忆存储抽象接口。
|
||||||
|
///
|
||||||
|
/// 下游可实现此 trait 以对接持久化后端(JSON 文件、SQLite、Redis 等)。
|
||||||
|
/// 默认实现 [`InMemoryStore`] 基于进程内 HashMap。
|
||||||
|
#[async_trait]
|
||||||
|
pub trait MemoryStore: Send + Sync {
|
||||||
|
/// 保存/覆盖一个 MemoryItem(upsert 语义)。
|
||||||
|
/// - 如果 id 不存在,则插入新条目
|
||||||
|
/// - 如果 id 已存在,则覆盖旧条目
|
||||||
|
async fn save(&self, item: MemoryItem) -> Result<(), MemoryError>;
|
||||||
|
|
||||||
|
/// 根据 id 获取一个 MemoryItem。
|
||||||
|
async fn get(&self, id: &str) -> Result<Option<MemoryItem>, MemoryError>;
|
||||||
|
|
||||||
|
/// 根据 id 删除一个 MemoryItem。
|
||||||
|
async fn delete(&self, id: &str) -> Result<(), MemoryError>;
|
||||||
|
|
||||||
|
/// 根据 filter 列出 MemoryItem。
|
||||||
|
async fn list(&self, filter: &MemoryFilter) -> Result<Vec<MemoryItem>, MemoryError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 淘汰策略。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum EvictionPolicy {
|
||||||
|
/// 不淘汰(默认)。
|
||||||
|
None,
|
||||||
|
/// 超过存活时间(秒)淘汰。
|
||||||
|
Ttl { ttl_secs: u64 },
|
||||||
|
/// 超过容量上限淘汰最旧(基于 created_at)。
|
||||||
|
Capacity { max_items: usize },
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 淘汰配置。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct EvictionConfig {
|
||||||
|
pub policy: EvictionPolicy,
|
||||||
|
/// 每写入 N 条后检查一次淘汰条件。
|
||||||
|
pub check_interval: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for EvictionConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
policy: EvictionPolicy::None,
|
||||||
|
check_interval: 64,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 进程内默认实现 —— 基于 HashMap + Mutex,纯内存。
|
||||||
|
pub struct InMemoryStore {
|
||||||
|
items: Mutex<HashMap<String, MemoryItem>>,
|
||||||
|
eviction: EvictionConfig,
|
||||||
|
/// 自上次淘汰检查以来的写入次数。
|
||||||
|
writes_since_check: Mutex<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl InMemoryStore {
|
||||||
|
/// 创建一个无淘汰策略的 InMemoryStore。
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
items: Mutex::new(HashMap::new()),
|
||||||
|
eviction: EvictionConfig::default(),
|
||||||
|
writes_since_check: Mutex::new(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 创建一个带淘汰配置的 InMemoryStore。
|
||||||
|
pub fn with_eviction(eviction: EvictionConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
items: Mutex::new(HashMap::new()),
|
||||||
|
eviction,
|
||||||
|
writes_since_check: Mutex::new(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn maybe_evict(&self) {
|
||||||
|
// 不使用 .lock().await 跨点,先取计数判断是否需要淘汰
|
||||||
|
let should_check = {
|
||||||
|
let mut counter = self.writes_since_check.lock().unwrap();
|
||||||
|
*counter += 1;
|
||||||
|
if *counter >= self.eviction.check_interval {
|
||||||
|
*counter = 0;
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if !should_check {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let policy = self.eviction.policy.clone();
|
||||||
|
match policy {
|
||||||
|
EvictionPolicy::None => {}
|
||||||
|
EvictionPolicy::Ttl { ttl_secs } => {
|
||||||
|
let cutoff = OffsetDateTime::now_utc() - time::Duration::seconds(ttl_secs as i64);
|
||||||
|
let mut items = self.items.lock().unwrap();
|
||||||
|
items.retain(|_, v| v.created_at > cutoff);
|
||||||
|
}
|
||||||
|
EvictionPolicy::Capacity { max_items } => {
|
||||||
|
let mut items = self.items.lock().unwrap();
|
||||||
|
if items.len() > max_items {
|
||||||
|
let mut vec: Vec<_> = items.drain().collect();
|
||||||
|
// O(n) 部分排序:保留 created_at 最大的 max_items 个
|
||||||
|
vec.select_nth_unstable_by(max_items, |a, b| {
|
||||||
|
b.1.created_at.cmp(&a.1.created_at)
|
||||||
|
});
|
||||||
|
vec.truncate(max_items);
|
||||||
|
*items = vec.into_iter().collect();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for InMemoryStore {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl MemoryStore for InMemoryStore {
|
||||||
|
async fn save(&self, item: MemoryItem) -> Result<(), MemoryError> {
|
||||||
|
{
|
||||||
|
let mut items = self.items.lock().unwrap();
|
||||||
|
items.insert(item.id.clone(), item);
|
||||||
|
}
|
||||||
|
self.maybe_evict();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get(&self, id: &str) -> Result<Option<MemoryItem>, MemoryError> {
|
||||||
|
let items = self.items.lock().unwrap();
|
||||||
|
Ok(items.get(id).cloned())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn delete(&self, id: &str) -> Result<(), MemoryError> {
|
||||||
|
let mut items = self.items.lock().unwrap();
|
||||||
|
items.remove(id);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list(&self, filter: &MemoryFilter) -> Result<Vec<MemoryItem>, MemoryError> {
|
||||||
|
let items = self.items.lock().unwrap();
|
||||||
|
let mut result: Vec<MemoryItem> = items
|
||||||
|
.values()
|
||||||
|
.filter(|v| match &filter.prefix {
|
||||||
|
Some(p) => v.id.starts_with(p),
|
||||||
|
None => true,
|
||||||
|
})
|
||||||
|
.filter(|v| match filter.since {
|
||||||
|
Some(t) => v.created_at > t,
|
||||||
|
None => true,
|
||||||
|
})
|
||||||
|
.cloned()
|
||||||
|
.collect();
|
||||||
|
// 按 created_at 升序排列(最旧在前)
|
||||||
|
result.sort_by_key(|v| v.created_at);
|
||||||
|
// 应用 offset
|
||||||
|
if let Some(offset) = filter.offset {
|
||||||
|
if offset < result.len() {
|
||||||
|
result.drain(..offset);
|
||||||
|
} else {
|
||||||
|
result.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 应用 limit
|
||||||
|
if let Some(limit) = filter.limit {
|
||||||
|
result.truncate(limit);
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use time::OffsetDateTime;
|
||||||
|
|
||||||
|
fn make_item(id: &str) -> MemoryItem {
|
||||||
|
MemoryItem {
|
||||||
|
id: id.to_string(),
|
||||||
|
content: format!("content-{id}"),
|
||||||
|
metadata: serde_json::json!({}),
|
||||||
|
created_at: OffsetDateTime::now_utc(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn save_get_delete_list() {
|
||||||
|
let store = InMemoryStore::new();
|
||||||
|
store.save(make_item("a")).await.unwrap();
|
||||||
|
store.save(make_item("b")).await.unwrap();
|
||||||
|
|
||||||
|
let got = store.get("a").await.unwrap();
|
||||||
|
assert!(got.is_some());
|
||||||
|
assert_eq!(got.unwrap().id, "a");
|
||||||
|
|
||||||
|
let list = store.list(&MemoryFilter::default()).await.unwrap();
|
||||||
|
assert_eq!(list.len(), 2);
|
||||||
|
|
||||||
|
store.delete("a").await.unwrap();
|
||||||
|
assert!(store.get("a").await.unwrap().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn save_is_upsert() {
|
||||||
|
let store = InMemoryStore::new();
|
||||||
|
store.save(make_item("a")).await.unwrap();
|
||||||
|
let mut item = make_item("a");
|
||||||
|
item.content = "updated".to_string();
|
||||||
|
store.save(item).await.unwrap();
|
||||||
|
let got = store.get("a").await.unwrap().unwrap();
|
||||||
|
assert_eq!(got.content, "updated");
|
||||||
|
let list = store.list(&MemoryFilter::default()).await.unwrap();
|
||||||
|
assert_eq!(list.len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn list_with_prefix_and_limit() {
|
||||||
|
let store = InMemoryStore::new();
|
||||||
|
store.save(make_item("foo_a")).await.unwrap();
|
||||||
|
store.save(make_item("foo_b")).await.unwrap();
|
||||||
|
store.save(make_item("bar_a")).await.unwrap();
|
||||||
|
|
||||||
|
let filter = MemoryFilter {
|
||||||
|
prefix: Some("foo_".to_string()),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let list = store.list(&filter).await.unwrap();
|
||||||
|
assert_eq!(list.len(), 2);
|
||||||
|
|
||||||
|
let filter = MemoryFilter {
|
||||||
|
prefix: Some("foo_".to_string()),
|
||||||
|
limit: Some(1),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let list = store.list(&filter).await.unwrap();
|
||||||
|
assert_eq!(list.len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn capacity_eviction() {
|
||||||
|
// 强制每次写入都检查
|
||||||
|
let eviction = EvictionConfig {
|
||||||
|
policy: EvictionPolicy::Capacity { max_items: 2 },
|
||||||
|
check_interval: 1,
|
||||||
|
};
|
||||||
|
let store = InMemoryStore::with_eviction(eviction);
|
||||||
|
// 第一条和第二条共存
|
||||||
|
store.save(make_item("a")).await.unwrap();
|
||||||
|
store.save(make_item("b")).await.unwrap();
|
||||||
|
// 第三条写入触发淘汰:a 或 b 之一被淘汰
|
||||||
|
store.save(make_item("c")).await.unwrap();
|
||||||
|
|
||||||
|
let list = store.list(&MemoryFilter::default()).await.unwrap();
|
||||||
|
assert_eq!(list.len(), 2);
|
||||||
|
// 留下的应该是 b 和 c(最新的两个)
|
||||||
|
let ids: Vec<&str> = list.iter().map(|v| v.id.as_str()).collect();
|
||||||
|
assert!(ids.contains(&"b"));
|
||||||
|
assert!(ids.contains(&"c"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn ttl_eviction() {
|
||||||
|
// TTL 设为 0 会立即过期,但我们想保留 "a" 等待 "b" 写入后被淘汰。
|
||||||
|
// 改用小 TTL + 睡眠:先 save a,sleep,save b 时 a 已过期被淘汰。
|
||||||
|
let eviction = EvictionConfig {
|
||||||
|
policy: EvictionPolicy::Ttl { ttl_secs: 1 },
|
||||||
|
check_interval: 1,
|
||||||
|
};
|
||||||
|
let store = InMemoryStore::with_eviction(eviction);
|
||||||
|
store.save(make_item("a")).await.unwrap();
|
||||||
|
// 等待超过 1 秒
|
||||||
|
std::thread::sleep(std::time::Duration::from_millis(1100));
|
||||||
|
// 触发淘汰:a 已超过 ttl_secs=1,应被淘汰
|
||||||
|
store.save(make_item("b")).await.unwrap();
|
||||||
|
let list = store.list(&MemoryFilter::default()).await.unwrap();
|
||||||
|
// 由于 ttl_secs=1,且 b 刚写入,可能刚好处于临界值。
|
||||||
|
// 我们只断言 list 不包含 "a" 即可。
|
||||||
|
let ids: Vec<&str> = list.iter().map(|v| v.id.as_str()).collect();
|
||||||
|
assert!(
|
||||||
|
!ids.contains(&"a"),
|
||||||
|
"expected 'a' to be evicted, but found in {ids:?}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn none_policy_no_eviction() {
|
||||||
|
let eviction = EvictionConfig {
|
||||||
|
policy: EvictionPolicy::None,
|
||||||
|
check_interval: 1,
|
||||||
|
};
|
||||||
|
let store = InMemoryStore::with_eviction(eviction);
|
||||||
|
for i in 0..100 {
|
||||||
|
store.save(make_item(&format!("item_{i}"))).await.unwrap();
|
||||||
|
}
|
||||||
|
let list = store.list(&MemoryFilter::default()).await.unwrap();
|
||||||
|
assert_eq!(list.len(), 100);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
//! 记忆系统核心数据类型。
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use time::OffsetDateTime;
|
||||||
|
|
||||||
|
/// 记忆条目 —— MemoryStore 存储的基本单元。
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct MemoryItem {
|
||||||
|
/// 唯一标识。
|
||||||
|
pub id: String,
|
||||||
|
/// 内容(通常为 JSON 序列化的具体记忆数据)。
|
||||||
|
pub content: String,
|
||||||
|
/// 任意附加元数据。
|
||||||
|
pub metadata: serde_json::Value,
|
||||||
|
/// 创建时间。
|
||||||
|
pub created_at: OffsetDateTime,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// MemoryStore 列表查询条件。
|
||||||
|
#[derive(Debug, Clone, Default)]
|
||||||
|
pub struct MemoryFilter {
|
||||||
|
/// 按 id 前缀过滤。
|
||||||
|
pub prefix: Option<String>,
|
||||||
|
/// 仅返回该时间之后创建的条目。
|
||||||
|
pub since: Option<OffsetDateTime>,
|
||||||
|
/// 跳过前 N 条。
|
||||||
|
pub offset: Option<usize>,
|
||||||
|
/// 最多返回 N 条。
|
||||||
|
pub limit: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 知识页面 —— 描述一段结构化知识。
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct KnowledgePage {
|
||||||
|
/// 唯一标识。
|
||||||
|
pub id: String,
|
||||||
|
/// 标题。
|
||||||
|
pub title: String,
|
||||||
|
/// 一句话摘要。
|
||||||
|
pub summary: String,
|
||||||
|
/// 完整内容。
|
||||||
|
pub content: String,
|
||||||
|
/// 检索标签。
|
||||||
|
pub tags: Vec<String>,
|
||||||
|
/// 交叉引用的其他页面 ID。
|
||||||
|
pub references: Vec<String>,
|
||||||
|
/// 创建时间。
|
||||||
|
pub created_at: OffsetDateTime,
|
||||||
|
/// 最后更新时间。
|
||||||
|
pub updated_at: OffsetDateTime,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 知识页面索引条目 —— 用于轻量遍历和内容目录。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct PageIndexEntry {
|
||||||
|
pub id: String,
|
||||||
|
pub title: String,
|
||||||
|
pub summary: String,
|
||||||
|
pub tags: Vec<String>,
|
||||||
|
pub updated_at: OffsetDateTime,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&KnowledgePage> for PageIndexEntry {
|
||||||
|
fn from(p: &KnowledgePage) -> Self {
|
||||||
|
Self {
|
||||||
|
id: p.id.clone(),
|
||||||
|
title: p.title.clone(),
|
||||||
|
summary: p.summary.clone(),
|
||||||
|
tags: p.tags.clone(),
|
||||||
|
updated_at: p.updated_at,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
pub mod error;
|
||||||
|
pub mod template;
|
||||||
|
pub mod composer;
|
||||||
|
|
||||||
|
pub use error::PromptError;
|
||||||
|
pub use template::{PromptTemplate, PromptTemplateRegistry, TemplateContext, TemplateValue};
|
||||||
|
pub use composer::PromptComposer;
|
||||||
@@ -0,0 +1,406 @@
|
|||||||
|
use crate::llm::types::message::{ContentField, OpenaiChatMessage, OpenaiContentPart};
|
||||||
|
use crate::llm::types::request::OpenaiChatRequest;
|
||||||
|
use crate::prompt::error::PromptError;
|
||||||
|
use crate::prompt::template::{PromptTemplate, TemplateContext};
|
||||||
|
|
||||||
|
/// 提示词组合器——构建多角色消息序列。
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct PromptComposer {
|
||||||
|
messages: Vec<OpenaiChatMessage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PromptComposer {
|
||||||
|
/// 创建一个空的组合器。
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从已有的消息列表初始化。
|
||||||
|
pub fn from_messages(messages: Vec<OpenaiChatMessage>) -> Self {
|
||||||
|
Self { messages }
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== 纯文本消息 =====
|
||||||
|
|
||||||
|
/// 添加一条纯文本 system 消息。
|
||||||
|
pub fn system(mut self, text: impl Into<String>) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::system_text(text.into()));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加一条纯文本 user 消息。
|
||||||
|
pub fn user(mut self, text: impl Into<String>) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::user_text(text.into()));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加一条纯文本 assistant 消息。
|
||||||
|
pub fn assistant(mut self, text: impl Into<String>) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::assistant_text(text.into()));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加一条纯文本 developer 消息(o1 系列模型使用)。
|
||||||
|
pub fn developer(mut self, text: impl Into<String>) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::developer_text(text.into()));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加一条 Tool 消息(工具执行结果回传)。
|
||||||
|
pub fn tool(mut self, tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::tool_result(
|
||||||
|
tool_call_id.into(),
|
||||||
|
content.into(),
|
||||||
|
));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== 模板消息 =====
|
||||||
|
|
||||||
|
/// 使用模板和上下文渲染后添加为 user 消息。
|
||||||
|
pub fn user_template(
|
||||||
|
mut self,
|
||||||
|
template: &PromptTemplate,
|
||||||
|
ctx: &TemplateContext,
|
||||||
|
) -> Result<Self, PromptError> {
|
||||||
|
let text = template.render(ctx)?;
|
||||||
|
self.push_message(OpenaiChatMessage::user_text(text));
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 使用模板和上下文渲染后添加为 system 消息。
|
||||||
|
pub fn system_template(
|
||||||
|
mut self,
|
||||||
|
template: &PromptTemplate,
|
||||||
|
ctx: &TemplateContext,
|
||||||
|
) -> Result<Self, PromptError> {
|
||||||
|
let text = template.render(ctx)?;
|
||||||
|
self.push_message(OpenaiChatMessage::system_text(text));
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 使用模板和上下文渲染后添加为 assistant 消息。
|
||||||
|
pub fn assistant_template(
|
||||||
|
mut self,
|
||||||
|
template: &PromptTemplate,
|
||||||
|
ctx: &TemplateContext,
|
||||||
|
) -> Result<Self, PromptError> {
|
||||||
|
let text = template.render(ctx)?;
|
||||||
|
self.push_message(OpenaiChatMessage::assistant_text(text));
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 使用模板和上下文渲染后添加为 developer 消息。
|
||||||
|
pub fn developer_template(
|
||||||
|
mut self,
|
||||||
|
template: &PromptTemplate,
|
||||||
|
ctx: &TemplateContext,
|
||||||
|
) -> Result<Self, PromptError> {
|
||||||
|
let text = template.render(ctx)?;
|
||||||
|
self.push_message(OpenaiChatMessage::developer_text(text));
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== 多模态 ContentPart =====
|
||||||
|
|
||||||
|
/// 添加一条含指定 ContentPart 的 system 消息。
|
||||||
|
pub fn system_content(mut self, part: OpenaiContentPart) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::System {
|
||||||
|
content: ContentField::Array(vec![part]),
|
||||||
|
name: None,
|
||||||
|
});
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加一条含指定 ContentPart 的 user 消息。
|
||||||
|
pub fn user_content(mut self, part: OpenaiContentPart) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::User {
|
||||||
|
content: ContentField::Array(vec![part]),
|
||||||
|
name: None,
|
||||||
|
});
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加一条含指定 ContentPart 的 assistant 消息。
|
||||||
|
pub fn assistant_content(mut self, part: OpenaiContentPart) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::Assistant {
|
||||||
|
content: ContentField::Array(vec![part]),
|
||||||
|
refusal: None,
|
||||||
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
});
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加一条含指定 ContentPart 的 developer 消息。
|
||||||
|
pub fn developer_content(mut self, part: OpenaiContentPart) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::Developer {
|
||||||
|
content: ContentField::Array(vec![part]),
|
||||||
|
name: None,
|
||||||
|
});
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加一条含指定 ContentPart 的 Tool 消息。
|
||||||
|
pub fn tool_content(
|
||||||
|
mut self,
|
||||||
|
tool_call_id: impl Into<String>,
|
||||||
|
part: OpenaiContentPart,
|
||||||
|
) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::Tool {
|
||||||
|
content: ContentField::Array(vec![part]),
|
||||||
|
tool_call_id: tool_call_id.into(),
|
||||||
|
});
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 批量添加 ContentPart 作为 user 消息。
|
||||||
|
pub fn user_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::User {
|
||||||
|
content: ContentField::Array(parts),
|
||||||
|
name: None,
|
||||||
|
});
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 批量添加 ContentPart 作为 system 消息。
|
||||||
|
pub fn system_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::System {
|
||||||
|
content: ContentField::Array(parts),
|
||||||
|
name: None,
|
||||||
|
});
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 批量添加 ContentPart 作为 assistant 消息。
|
||||||
|
pub fn assistant_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::Assistant {
|
||||||
|
content: ContentField::Array(parts),
|
||||||
|
refusal: None,
|
||||||
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
});
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 批量添加 ContentPart 作为 developer 消息。
|
||||||
|
pub fn developer_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::Developer {
|
||||||
|
content: ContentField::Array(parts),
|
||||||
|
name: None,
|
||||||
|
});
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== 角色标识 =====
|
||||||
|
|
||||||
|
/// 为上一条添加的消息设置 `name` 字段。
|
||||||
|
pub fn with_name(mut self, name: impl Into<String>) -> Self {
|
||||||
|
let name = name.into();
|
||||||
|
if let Some(msg) = self.messages.last_mut() {
|
||||||
|
set_message_name(msg, name);
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== 构建 =====
|
||||||
|
|
||||||
|
/// 构建最终的消息列表。
|
||||||
|
pub fn build(self) -> Vec<OpenaiChatMessage> {
|
||||||
|
self.messages
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 构建并直接创建 ChatRequest(需搭配 model 参数)。
|
||||||
|
/// 返回的 `OpenaiChatRequest` 中 `tools`、`temperature`、`max_tokens` 等字段均为 `None`,
|
||||||
|
/// 可通过结构体更新语法补全:`OpenaiChatRequest { tools: Some(...), ..req }`。
|
||||||
|
pub fn build_request(self, model: impl Into<String>) -> OpenaiChatRequest {
|
||||||
|
OpenaiChatRequest {
|
||||||
|
model: model.into(),
|
||||||
|
messages: self.messages,
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== 内部方法 =====
|
||||||
|
|
||||||
|
fn push_message(&mut self, msg: OpenaiChatMessage) {
|
||||||
|
self.messages.push(msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_message_name(msg: &mut OpenaiChatMessage, name: String) {
|
||||||
|
match msg {
|
||||||
|
OpenaiChatMessage::Developer { name: n, .. } => *n = Some(name),
|
||||||
|
OpenaiChatMessage::System { name: n, .. } => *n = Some(name),
|
||||||
|
OpenaiChatMessage::User { name: n, .. } => *n = Some(name),
|
||||||
|
OpenaiChatMessage::Assistant { name: n, .. } => *n = Some(name),
|
||||||
|
OpenaiChatMessage::Tool { .. } => {}
|
||||||
|
OpenaiChatMessage::Function { .. } => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 验证消息序列是否符合 OpenAI API 要求。
|
||||||
|
pub fn validate_messages(messages: &[OpenaiChatMessage]) -> Result<(), PromptError> {
|
||||||
|
if messages.is_empty() {
|
||||||
|
return Err(PromptError::InvalidSequence(
|
||||||
|
"消息列表不能为空".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut last_tool_call_ids: Vec<String> = Vec::new();
|
||||||
|
|
||||||
|
for (i, msg) in messages.iter().enumerate() {
|
||||||
|
match msg {
|
||||||
|
OpenaiChatMessage::Tool {
|
||||||
|
tool_call_id,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
if last_tool_call_ids.is_empty() {
|
||||||
|
return Err(PromptError::InvalidSequence(format!(
|
||||||
|
"消息[{i}] Tool 消息前必须有 Assistant 消息且含 tool_calls"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
if !last_tool_call_ids.iter().any(|id| id == tool_call_id) {
|
||||||
|
return Err(PromptError::InvalidSequence(format!(
|
||||||
|
"消息[{i}] Tool 消息的 tool_call_id '{}' 未匹配任何 assistant tool_calls",
|
||||||
|
tool_call_id
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
OpenaiChatMessage::Assistant {
|
||||||
|
tool_calls: Some(calls),
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
last_tool_call_ids.clear();
|
||||||
|
for call in calls {
|
||||||
|
let crate::llm::types::OpenaiToolCall::Function { id, .. } = call;
|
||||||
|
last_tool_call_ids.push(id.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
OpenaiChatMessage::Assistant {
|
||||||
|
tool_calls: None, ..
|
||||||
|
} => {
|
||||||
|
last_tool_call_ids.clear();
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
last_tool_call_ids.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::prompt::TemplateValue;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_composer_basic() {
|
||||||
|
let msgs = PromptComposer::new()
|
||||||
|
.system("You are helpful")
|
||||||
|
.user("Hello")
|
||||||
|
.assistant("Hi there!")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
assert_eq!(msgs.len(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_composer_tool() {
|
||||||
|
let msgs = PromptComposer::new()
|
||||||
|
.system("You are helpful")
|
||||||
|
.user("What's the weather?")
|
||||||
|
.assistant("Let me check")
|
||||||
|
.tool("call_123", "Sunny, 25°C")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
assert_eq!(msgs.len(), 4);
|
||||||
|
match &msgs[3] {
|
||||||
|
OpenaiChatMessage::Tool {
|
||||||
|
tool_call_id,
|
||||||
|
content,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
assert_eq!(tool_call_id, "call_123");
|
||||||
|
match content {
|
||||||
|
ContentField::String(s) => assert_eq!(s, "Sunny, 25°C"),
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => panic!("Expected Tool message"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_messages_ok() {
|
||||||
|
let msgs = PromptComposer::new()
|
||||||
|
.system("You are helpful")
|
||||||
|
.user("Hello")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
assert!(validate_messages(&msgs).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_messages_empty() {
|
||||||
|
let msgs: Vec<OpenaiChatMessage> = vec![];
|
||||||
|
assert!(validate_messages(&msgs).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_template_render() {
|
||||||
|
let tpl = PromptTemplate::compile("Hello {{name}}, you have {{count}} messages").unwrap();
|
||||||
|
let mut ctx = TemplateContext::new();
|
||||||
|
ctx.insert("name", "Alice");
|
||||||
|
ctx.insert("count", "5");
|
||||||
|
|
||||||
|
let result = tpl.render(&ctx).unwrap();
|
||||||
|
assert_eq!(result, "Hello Alice, you have 5 messages");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_template_if() {
|
||||||
|
let tpl = PromptTemplate::compile("Hello {{#if name}}{{name}}{{else}}Guest{{/if}}").unwrap();
|
||||||
|
let mut ctx = TemplateContext::new();
|
||||||
|
ctx.insert("name", "Bob");
|
||||||
|
|
||||||
|
let with_name = tpl.render(&ctx).unwrap();
|
||||||
|
assert_eq!(with_name, "Hello Bob");
|
||||||
|
|
||||||
|
let without_name = tpl.render(&TemplateContext::new()).unwrap();
|
||||||
|
assert_eq!(without_name, "Hello Guest");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_template_each() {
|
||||||
|
let tpl = PromptTemplate::compile("Items: {{#each items}}{{item}}, {{/each}}").unwrap();
|
||||||
|
let mut ctx = TemplateContext::new();
|
||||||
|
ctx.insert("items", TemplateValue::Array(vec![
|
||||||
|
TemplateValue::String("a".to_string()),
|
||||||
|
TemplateValue::String("b".to_string()),
|
||||||
|
TemplateValue::String("c".to_string()),
|
||||||
|
]));
|
||||||
|
|
||||||
|
let result = tpl.render(&ctx).unwrap();
|
||||||
|
assert_eq!(result, "Items: a, b, c, ");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_template_display() {
|
||||||
|
let tpl = PromptTemplate::compile("Hello {{name}}").unwrap();
|
||||||
|
assert_eq!(format!("{}", tpl), "Hello {{name}}");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_context_from_json() {
|
||||||
|
let json: serde_json::Value = serde_json::json!({
|
||||||
|
"name": "Alice",
|
||||||
|
"active": true
|
||||||
|
});
|
||||||
|
let ctx = TemplateContext::from_json(&json).unwrap();
|
||||||
|
assert!(ctx.get("name").is_some());
|
||||||
|
assert!(ctx.get("active").is_some());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum PromptError {
|
||||||
|
#[error("模板解析错误: {0}")]
|
||||||
|
Parse(String),
|
||||||
|
|
||||||
|
#[error("渲染错误: 变量 '{0}' 未找到")]
|
||||||
|
VariableNotFound(String),
|
||||||
|
|
||||||
|
#[error("渲染错误: 引用的子模板 '{0}' 未注册")]
|
||||||
|
PartialNotFound(String),
|
||||||
|
|
||||||
|
#[error("渲染错误: '{0}' 不是数组,无法遍历")]
|
||||||
|
NotAnArray(String),
|
||||||
|
|
||||||
|
#[error("渲染递归超过最大深度限制 ({0})")]
|
||||||
|
MaxDepthReached(u8),
|
||||||
|
|
||||||
|
#[error("渲染错误: {0}")]
|
||||||
|
Render(String),
|
||||||
|
|
||||||
|
#[error("消息序列校验失败: {0}")]
|
||||||
|
InvalidSequence(String),
|
||||||
|
|
||||||
|
#[error("文件读取错误: {0}")]
|
||||||
|
Io(#[from] std::io::Error),
|
||||||
|
}
|
||||||
@@ -0,0 +1,543 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::fmt;
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
use crate::prompt::error::PromptError;
|
||||||
|
|
||||||
|
const MAX_RENDER_DEPTH: u8 = 16;
|
||||||
|
|
||||||
|
// ===== TemplateValue =====
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum TemplateValue {
|
||||||
|
String(String),
|
||||||
|
Bool(bool),
|
||||||
|
Array(Vec<TemplateValue>),
|
||||||
|
Object(HashMap<String, TemplateValue>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<String> for TemplateValue {
|
||||||
|
fn from(s: String) -> Self {
|
||||||
|
TemplateValue::String(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&str> for TemplateValue {
|
||||||
|
fn from(s: &str) -> Self {
|
||||||
|
TemplateValue::String(s.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<bool> for TemplateValue {
|
||||||
|
fn from(b: bool) -> Self {
|
||||||
|
TemplateValue::Bool(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for TemplateValue {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
TemplateValue::String(s) => write!(f, "{}", s),
|
||||||
|
TemplateValue::Bool(b) => write!(f, "{}", b),
|
||||||
|
TemplateValue::Array(arr) => {
|
||||||
|
let strs: Vec<String> = arr.iter().map(|v| format!("{}", v)).collect();
|
||||||
|
write!(f, "[{}]", strs.join(", "))
|
||||||
|
}
|
||||||
|
TemplateValue::Object(map) => {
|
||||||
|
let strs: Vec<String> = map
|
||||||
|
.iter()
|
||||||
|
.map(|(k, v)| format!("\"{}\": {}", k, v))
|
||||||
|
.collect();
|
||||||
|
write!(f, "{{{}}}", strs.join(", "))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TemplateValue {
|
||||||
|
fn is_truthy(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
TemplateValue::String(s) => !s.is_empty(),
|
||||||
|
TemplateValue::Bool(b) => *b,
|
||||||
|
TemplateValue::Array(arr) => !arr.is_empty(),
|
||||||
|
TemplateValue::Object(map) => !map.is_empty(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_array(&self) -> Option<&Vec<TemplateValue>> {
|
||||||
|
match self {
|
||||||
|
TemplateValue::Array(arr) => Some(arr),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== TemplateContext =====
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default)]
|
||||||
|
pub struct TemplateContext {
|
||||||
|
vars: HashMap<String, TemplateValue>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TemplateContext {
|
||||||
|
/// 创建一个空的模板上下文。
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 插入变量(支持 `&str` / `String` / `bool` 自动转换)。
|
||||||
|
pub fn insert(&mut self, key: impl Into<String>, value: impl Into<TemplateValue>) {
|
||||||
|
self.vars.insert(key.into(), value.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 按名称获取变量值。
|
||||||
|
pub fn get(&self, key: &str) -> Option<&TemplateValue> {
|
||||||
|
self.vars.get(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从 `serde_json::Value` 递归构造(支持嵌套 Object/Array)。
|
||||||
|
pub fn from_json(value: &Value) -> Result<Self, PromptError> {
|
||||||
|
let map = value
|
||||||
|
.as_object()
|
||||||
|
.ok_or_else(|| PromptError::Render("JSON 根值必须是对象".to_string()))?;
|
||||||
|
|
||||||
|
let mut ctx = Self::new();
|
||||||
|
for (k, v) in map {
|
||||||
|
ctx.vars.insert(k.clone(), json_to_template_value(v)?);
|
||||||
|
}
|
||||||
|
Ok(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从 `HashMap` 构造(适用于配置加载场景)。
|
||||||
|
pub fn from_map(map: HashMap<String, TemplateValue>) -> Self {
|
||||||
|
Self { vars: map }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn json_to_template_value(v: &Value) -> Result<TemplateValue, PromptError> {
|
||||||
|
match v {
|
||||||
|
Value::Null => Ok(TemplateValue::String(String::new())),
|
||||||
|
Value::Bool(b) => Ok(TemplateValue::Bool(*b)),
|
||||||
|
Value::Number(n) => Ok(TemplateValue::String(n.to_string())),
|
||||||
|
Value::String(s) => Ok(TemplateValue::String(s.clone())),
|
||||||
|
Value::Array(arr) => {
|
||||||
|
let items: Result<Vec<TemplateValue>, _> =
|
||||||
|
arr.iter().map(json_to_template_value).collect();
|
||||||
|
Ok(TemplateValue::Array(items?))
|
||||||
|
}
|
||||||
|
Value::Object(obj) => {
|
||||||
|
let mut map = HashMap::new();
|
||||||
|
for (k, v) in obj {
|
||||||
|
map.insert(k.clone(), json_to_template_value(v)?);
|
||||||
|
}
|
||||||
|
Ok(TemplateValue::Object(map))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== Fragment (AST) =====
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
enum Fragment {
|
||||||
|
Literal(String),
|
||||||
|
Variable { name: String },
|
||||||
|
If {
|
||||||
|
condition: String,
|
||||||
|
body: Vec<Fragment>,
|
||||||
|
else_body: Vec<Fragment>,
|
||||||
|
},
|
||||||
|
Each {
|
||||||
|
variable: String,
|
||||||
|
body: Vec<Fragment>,
|
||||||
|
},
|
||||||
|
Raw(String),
|
||||||
|
Include(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== PromptTemplate =====
|
||||||
|
|
||||||
|
pub struct PromptTemplate {
|
||||||
|
raw: String,
|
||||||
|
fragments: Vec<Fragment>,
|
||||||
|
partials: HashMap<String, PromptTemplate>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PromptTemplate {
|
||||||
|
/// 从模板字符串编译。
|
||||||
|
pub fn compile(template: &str) -> Result<Self, PromptError> {
|
||||||
|
let fragments = compile_fragments(template)?;
|
||||||
|
Ok(Self {
|
||||||
|
raw: template.to_string(),
|
||||||
|
fragments,
|
||||||
|
partials: HashMap::new(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 使用上下文渲染。
|
||||||
|
pub fn render(&self, ctx: &TemplateContext) -> Result<String, PromptError> {
|
||||||
|
let mut output = String::new();
|
||||||
|
render_fragments(&self.fragments, ctx, &self.partials, &mut output, 0)?;
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 使用上下文和外部 partials 渲染。
|
||||||
|
pub fn render_with_partials(
|
||||||
|
&self,
|
||||||
|
ctx: &TemplateContext,
|
||||||
|
partials: &HashMap<String, PromptTemplate>,
|
||||||
|
) -> Result<String, PromptError> {
|
||||||
|
let mut output = String::new();
|
||||||
|
render_fragments(&self.fragments, ctx, partials, &mut output, 0)?;
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 注册可引用的子模板。
|
||||||
|
pub fn register_partial(&mut self, name: &str, template: PromptTemplate) {
|
||||||
|
self.partials.insert(name.to_string(), template);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for PromptTemplate {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(f, "{}", self.raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== Compiler =====
|
||||||
|
|
||||||
|
fn compile_fragments(template: &str) -> Result<Vec<Fragment>, PromptError> {
|
||||||
|
let bytes = template.as_bytes();
|
||||||
|
let len = bytes.len();
|
||||||
|
let mut fragments = Vec::new();
|
||||||
|
let mut i = 0;
|
||||||
|
let mut literal = String::new();
|
||||||
|
|
||||||
|
while i < len {
|
||||||
|
if bytes[i] == b'{' && i + 1 < len && bytes[i + 1] == b'{' {
|
||||||
|
if !literal.is_empty() {
|
||||||
|
fragments.push(Fragment::Literal(literal.clone()));
|
||||||
|
literal.clear();
|
||||||
|
}
|
||||||
|
let (tag_content, end) = parse_tag(bytes, i)?;
|
||||||
|
i = end;
|
||||||
|
|
||||||
|
let tag = tag_content.trim();
|
||||||
|
if let Some(rest) = tag.strip_prefix("#if ") {
|
||||||
|
let (body, else_body, new_i) =
|
||||||
|
parse_block(template, i, "if")?;
|
||||||
|
let condition = rest.trim().to_string();
|
||||||
|
fragments.push(Fragment::If {
|
||||||
|
condition,
|
||||||
|
body,
|
||||||
|
else_body,
|
||||||
|
});
|
||||||
|
i = new_i;
|
||||||
|
} else if let Some(rest) = tag.strip_prefix("#each ") {
|
||||||
|
let (body, new_i) = parse_each_block(template, i)?;
|
||||||
|
let variable = rest.trim().to_string();
|
||||||
|
fragments.push(Fragment::Each { variable, body });
|
||||||
|
i = new_i;
|
||||||
|
} else if tag == "#raw" {
|
||||||
|
let (raw, new_i) = parse_raw_block(template, i)?;
|
||||||
|
fragments.push(Fragment::Raw(raw));
|
||||||
|
i = new_i;
|
||||||
|
} else if let Some(rest) = tag.strip_prefix("> ") {
|
||||||
|
let name = rest.trim().to_string();
|
||||||
|
fragments.push(Fragment::Include(name));
|
||||||
|
} else if tag.starts_with("/") {
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
let name = tag.to_string();
|
||||||
|
fragments.push(Fragment::Variable { name });
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
literal.push(bytes[i] as char);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !literal.is_empty() {
|
||||||
|
fragments.push(Fragment::Literal(literal));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(fragments)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_tag(bytes: &[u8], start: usize) -> Result<(String, usize), PromptError> {
|
||||||
|
let len = bytes.len();
|
||||||
|
let mut i = start + 2;
|
||||||
|
let mut content = String::new();
|
||||||
|
while i < len {
|
||||||
|
if bytes[i] == b'}' && i + 1 < len && bytes[i + 1] == b'}' {
|
||||||
|
return Ok((content, i + 2));
|
||||||
|
}
|
||||||
|
content.push(bytes[i] as char);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
Err(PromptError::Parse("未闭合的 {{ 标签".to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_block(
|
||||||
|
template: &str,
|
||||||
|
start: usize,
|
||||||
|
kind: &str,
|
||||||
|
) -> Result<(Vec<Fragment>, Vec<Fragment>, usize), PromptError> {
|
||||||
|
let bytes = template.as_bytes();
|
||||||
|
let len = bytes.len();
|
||||||
|
let mut depth = 1u32;
|
||||||
|
let mut i = start;
|
||||||
|
let mut body = String::new();
|
||||||
|
let mut else_body = String::new();
|
||||||
|
let mut is_else = false;
|
||||||
|
|
||||||
|
while i < len && depth > 0 {
|
||||||
|
if bytes[i] == b'{' && i + 1 < len && bytes[i + 1] == b'{' {
|
||||||
|
let (tag, end) = parse_tag(bytes, i)?;
|
||||||
|
let tag = tag.trim().to_string();
|
||||||
|
if tag == format!("/{kind}") {
|
||||||
|
depth -= 1;
|
||||||
|
if depth == 0 {
|
||||||
|
let (if_fragments, else_fragments) = if is_else {
|
||||||
|
(compile_fragments(&body)?, compile_fragments(&else_body)?)
|
||||||
|
} else {
|
||||||
|
(compile_fragments(&body)?, compile_fragments("")?)
|
||||||
|
};
|
||||||
|
return Ok((if_fragments, else_fragments, end));
|
||||||
|
} else {
|
||||||
|
body.push_str(&template[i..end]);
|
||||||
|
}
|
||||||
|
i = end;
|
||||||
|
} else if tag == "#if " || tag.starts_with("#if ") {
|
||||||
|
depth += 1;
|
||||||
|
body.push_str(&template[i..end]);
|
||||||
|
i = end;
|
||||||
|
} else if tag == "else" && depth == 1 {
|
||||||
|
is_else = true;
|
||||||
|
i = end;
|
||||||
|
} else {
|
||||||
|
body.push_str(&template[i..end]);
|
||||||
|
i = end;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if is_else {
|
||||||
|
else_body.push(bytes[i] as char);
|
||||||
|
} else {
|
||||||
|
body.push(bytes[i] as char);
|
||||||
|
}
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(PromptError::Parse(format!("未闭合的 {{#{}}} 块", kind)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_each_block(
|
||||||
|
template: &str,
|
||||||
|
start: usize,
|
||||||
|
) -> Result<(Vec<Fragment>, usize), PromptError> {
|
||||||
|
let bytes = template.as_bytes();
|
||||||
|
let len = bytes.len();
|
||||||
|
let mut depth = 1u32;
|
||||||
|
let mut i = start;
|
||||||
|
let mut body = String::new();
|
||||||
|
|
||||||
|
while i < len && depth > 0 {
|
||||||
|
if bytes[i] == b'{' && i + 1 < len && bytes[i + 1] == b'{' {
|
||||||
|
let (tag, end) = parse_tag(bytes, i)?;
|
||||||
|
let tag = tag.trim().to_string();
|
||||||
|
if tag == "/each" {
|
||||||
|
depth -= 1;
|
||||||
|
if depth == 0 {
|
||||||
|
let fragments = compile_fragments(&body)?;
|
||||||
|
return Ok((fragments, end));
|
||||||
|
} else {
|
||||||
|
body.push_str(&template[i..end]);
|
||||||
|
}
|
||||||
|
i = end;
|
||||||
|
} else if tag == "#each " || tag.starts_with("#each ") {
|
||||||
|
depth += 1;
|
||||||
|
body.push_str(&template[i..end]);
|
||||||
|
i = end;
|
||||||
|
} else {
|
||||||
|
body.push_str(&template[i..end]);
|
||||||
|
i = end;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
body.push(bytes[i] as char);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(PromptError::Parse(
|
||||||
|
"未闭合的 {{#each}} 块".to_string(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_raw_block(template: &str, start: usize) -> Result<(String, usize), PromptError> {
|
||||||
|
let bytes = template.as_bytes();
|
||||||
|
let len = bytes.len();
|
||||||
|
let mut i = start;
|
||||||
|
let mut content = String::new();
|
||||||
|
|
||||||
|
while i < len {
|
||||||
|
if bytes[i] == b'{' && i + 1 < len && bytes[i + 1] == b'{' {
|
||||||
|
let (tag, end) = parse_tag(bytes, i)?;
|
||||||
|
let tag = tag.trim().to_string();
|
||||||
|
if tag == "/raw" {
|
||||||
|
return Ok((content, end));
|
||||||
|
} else {
|
||||||
|
content.push_str(&template[i..end]);
|
||||||
|
i = end;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
content.push(bytes[i] as char);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(PromptError::Parse(
|
||||||
|
"未闭合的 {{#raw}} 块".to_string(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== Renderer =====
|
||||||
|
|
||||||
|
fn render_fragments(
|
||||||
|
fragments: &[Fragment],
|
||||||
|
ctx: &TemplateContext,
|
||||||
|
partials: &HashMap<String, PromptTemplate>,
|
||||||
|
output: &mut String,
|
||||||
|
depth: u8,
|
||||||
|
) -> Result<(), PromptError> {
|
||||||
|
if depth > MAX_RENDER_DEPTH {
|
||||||
|
return Err(PromptError::MaxDepthReached(MAX_RENDER_DEPTH));
|
||||||
|
}
|
||||||
|
|
||||||
|
for frag in fragments {
|
||||||
|
match frag {
|
||||||
|
Fragment::Literal(text) => {
|
||||||
|
output.push_str(text);
|
||||||
|
}
|
||||||
|
Fragment::Variable { name } => {
|
||||||
|
match ctx.get(name) {
|
||||||
|
Some(val) => {
|
||||||
|
output.push_str(&format!("{}", val));
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
return Err(PromptError::VariableNotFound(name.clone()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Fragment::If {
|
||||||
|
condition,
|
||||||
|
body,
|
||||||
|
else_body,
|
||||||
|
} => {
|
||||||
|
let truthy = ctx
|
||||||
|
.get(condition)
|
||||||
|
.map(|v| v.is_truthy())
|
||||||
|
.unwrap_or(false);
|
||||||
|
let target = if truthy { body } else { else_body };
|
||||||
|
render_fragments(target, ctx, partials, output, depth + 1)?;
|
||||||
|
}
|
||||||
|
Fragment::Each { variable, body } => {
|
||||||
|
let arr = match ctx.get(variable) {
|
||||||
|
Some(val) => val.as_array().ok_or_else(|| {
|
||||||
|
PromptError::NotAnArray(variable.clone())
|
||||||
|
})?,
|
||||||
|
None => {
|
||||||
|
return Err(PromptError::VariableNotFound(variable.clone()));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
for item in arr {
|
||||||
|
let mut child_ctx = ctx.clone();
|
||||||
|
child_ctx.vars.insert("item".to_string(), item.clone());
|
||||||
|
render_fragments(body, &child_ctx, partials, output, depth + 1)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Fragment::Raw(text) => {
|
||||||
|
output.push_str(text);
|
||||||
|
}
|
||||||
|
Fragment::Include(name) => {
|
||||||
|
if let Some(partial) = partials.get(name) {
|
||||||
|
partial.render_with_partials(ctx, partials)?;
|
||||||
|
} else {
|
||||||
|
return Err(PromptError::PartialNotFound(name.clone()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== PromptTemplateRegistry =====
|
||||||
|
|
||||||
|
/// 内部存储的模板(支持延迟编译)。
|
||||||
|
enum StoredTemplate {
|
||||||
|
Compiled(PromptTemplate),
|
||||||
|
Raw(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 模板注册表——管理多模板实例。
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct PromptTemplateRegistry {
|
||||||
|
templates: HashMap<String, StoredTemplate>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PromptTemplateRegistry {
|
||||||
|
/// 创建一个空的模板注册表。
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
templates: HashMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从模板字符串编译并注册(立即编译)。
|
||||||
|
pub fn register(&mut self, name: &str, template: &str) -> Result<(), PromptError> {
|
||||||
|
let compiled = PromptTemplate::compile(template)?;
|
||||||
|
self.templates
|
||||||
|
.insert(name.to_string(), StoredTemplate::Compiled(compiled));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 延迟编译注册:只存储原始字符串,首次渲染时编译。
|
||||||
|
pub fn register_lazy(&mut self, name: &str, template: &str) {
|
||||||
|
self.templates.insert(
|
||||||
|
name.to_string(),
|
||||||
|
StoredTemplate::Raw(template.to_string()),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从文件读取并编译注册。
|
||||||
|
pub fn register_file(&mut self, name: &str, path: &std::path::Path) -> Result<(), PromptError> {
|
||||||
|
let content = std::fs::read_to_string(path)?;
|
||||||
|
let compiled = PromptTemplate::compile(&content)?;
|
||||||
|
self.templates
|
||||||
|
.insert(name.to_string(), StoredTemplate::Compiled(compiled));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取已注册的模板(延迟编译的模板在此首次编译)。
|
||||||
|
pub fn get(&mut self, name: &str) -> Result<&PromptTemplate, PromptError> {
|
||||||
|
if let Some(stored) = self.templates.get_mut(name) {
|
||||||
|
if let StoredTemplate::Raw(raw) = stored {
|
||||||
|
let compiled = PromptTemplate::compile(raw)?;
|
||||||
|
*stored = StoredTemplate::Compiled(compiled);
|
||||||
|
}
|
||||||
|
match stored {
|
||||||
|
StoredTemplate::Compiled(tpl) => Ok(tpl),
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err(PromptError::PartialNotFound(name.to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 按名称渲染。
|
||||||
|
pub fn render(&mut self, name: &str, ctx: &TemplateContext) -> Result<String, PromptError> {
|
||||||
|
let tpl = self.get(name)?;
|
||||||
|
tpl.render(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
//! 工具系统 —— 工具抽象、注册、调用、权限控制与 MCP 集成。
|
||||||
|
|
||||||
|
pub mod base;
|
||||||
|
pub mod error;
|
||||||
|
pub mod mcp;
|
||||||
|
pub mod permission;
|
||||||
|
pub mod registry;
|
||||||
|
|
||||||
|
pub use base::{BaseTool, ToolContext, ToolRef};
|
||||||
|
pub use error::ToolError;
|
||||||
|
pub use mcp::{McpClient, McpTransport};
|
||||||
|
pub use permission::{Permission, PermissionChecker, PermissionConfig};
|
||||||
|
pub use registry::{ToolInvocation, ToolRegistry};
|
||||||
@@ -0,0 +1,139 @@
|
|||||||
|
//! 工具抽象接口与执行上下文。
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::Value;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
|
use crate::tools::error::ToolError;
|
||||||
|
use crate::tools::permission::Permission;
|
||||||
|
|
||||||
|
/// 工具执行上下文 —— 携带每次执行的运行时信息。
|
||||||
|
///
|
||||||
|
/// 字段在 Phase 2 即注入 `execute()` 签名中,防止后续扩展时出现
|
||||||
|
/// breaking change。后续阶段可扩展字段(如 `progress`、`shared_state`),
|
||||||
|
/// 但已有工具实现无需修改。
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ToolContext<'a> {
|
||||||
|
/// 当前对话/会话 ID,用于关联性追踪。
|
||||||
|
pub session_id: &'a str,
|
||||||
|
/// 链路追踪 ID,用于跨工具调用的耗时分布。
|
||||||
|
pub trace_id: &'a str,
|
||||||
|
/// 取消令牌,用于优雅取消正在执行的工具。
|
||||||
|
pub cancellation_token: CancellationToken,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> ToolContext<'a> {
|
||||||
|
/// 创建一个新的工具执行上下文。
|
||||||
|
pub fn new(session_id: &'a str, trace_id: &'a str) -> Self {
|
||||||
|
Self {
|
||||||
|
session_id,
|
||||||
|
trace_id,
|
||||||
|
cancellation_token: CancellationToken::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 创建一个使用给定取消令牌的上下文。
|
||||||
|
pub fn with_cancellation_token(
|
||||||
|
session_id: &'a str,
|
||||||
|
trace_id: &'a str,
|
||||||
|
token: CancellationToken,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
session_id,
|
||||||
|
trace_id,
|
||||||
|
cancellation_token: token,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 工具抽象接口 —— 所有工具(自定义或 MCP)最终都实现此 trait。
|
||||||
|
#[async_trait]
|
||||||
|
pub trait BaseTool: Send + Sync {
|
||||||
|
/// 工具名称(唯一标识,用于 LLM 的 tool_calls.name 匹配)。
|
||||||
|
fn name(&self) -> &str;
|
||||||
|
|
||||||
|
/// 工具描述(LLM 据此决定是否调用此工具)。
|
||||||
|
fn description(&self) -> &str;
|
||||||
|
|
||||||
|
/// 工具参数定义(JSON Schema 格式,传递给 LLM 的 tool.parameters)。
|
||||||
|
fn parameters(&self) -> Value;
|
||||||
|
|
||||||
|
/// 声明工具所需的权限列表。
|
||||||
|
fn required_permissions(&self) -> Vec<Permission> {
|
||||||
|
Vec::new()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 执行工具调用。
|
||||||
|
///
|
||||||
|
/// `ctx` 携带执行上下文(session_id、trace_id、cancellation_token),
|
||||||
|
/// 工具实现可在执行期间检查 `ctx.cancellation_token` 来支持优雅取消。
|
||||||
|
async fn execute(&self, args: Value, ctx: &ToolContext<'_>) -> Result<Value, ToolError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 为 `Arc<dyn BaseTool>` 提供 `Send + Sync` 包装,便于在 `Vec<Arc<dyn BaseTool>>` 中使用。
|
||||||
|
pub type ToolRef = Arc<dyn BaseTool>;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
struct EchoTool;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl BaseTool for EchoTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"echo"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"回显输入"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters(&self) -> Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"text": { "type": "string" }
|
||||||
|
},
|
||||||
|
"required": ["text"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: Value, _ctx: &ToolContext<'_>) -> Result<Value, ToolError> {
|
||||||
|
Ok(args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_mock_tool_execute() {
|
||||||
|
let tool = EchoTool;
|
||||||
|
let ctx = ToolContext::new("session-1", "trace-1");
|
||||||
|
let result = tool.execute(json!({"text": "hello"}), &ctx).await.unwrap();
|
||||||
|
assert_eq!(result, json!({"text": "hello"}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_default_permissions_empty() {
|
||||||
|
let tool = EchoTool;
|
||||||
|
assert!(tool.required_permissions().is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tool_context_creation() {
|
||||||
|
let ctx = ToolContext::new("s1", "t1");
|
||||||
|
assert_eq!(ctx.session_id, "s1");
|
||||||
|
assert_eq!(ctx.trace_id, "t1");
|
||||||
|
assert!(!ctx.cancellation_token.is_cancelled());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tool_context_cancellation() {
|
||||||
|
let token = CancellationToken::new();
|
||||||
|
token.cancel();
|
||||||
|
let ctx = ToolContext::with_cancellation_token("s1", "t1", token);
|
||||||
|
assert!(ctx.cancellation_token.is_cancelled());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,118 @@
|
|||||||
|
//! 工具系统错误类型。
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// 工具调用过程中可能发生的所有错误。
|
||||||
|
#[derive(thiserror::Error, Debug, Clone)]
|
||||||
|
pub enum ToolError {
|
||||||
|
/// 工具未注册。
|
||||||
|
#[error("工具 '{0}' 未注册")]
|
||||||
|
NotFound(String),
|
||||||
|
|
||||||
|
/// 工具执行失败(可恢复——文本回传 LLM)。
|
||||||
|
#[error("工具 '{0}' 执行失败: {1}")]
|
||||||
|
ExecutionFailed(String, String),
|
||||||
|
|
||||||
|
/// 工具参数无效(可恢复——文本回传 LLM)。
|
||||||
|
#[error("工具 '{0}' 参数无效: {1}")]
|
||||||
|
InvalidArguments(String, String),
|
||||||
|
|
||||||
|
/// 权限被拒绝(不可恢复——终止循环)。
|
||||||
|
#[error("权限被拒绝: 工具 '{0}' 需要 {1} 权限")]
|
||||||
|
PermissionDenied(String, String),
|
||||||
|
|
||||||
|
/// MCP 协议错误(不可恢复)。
|
||||||
|
#[error("MCP 协议错误: {0}")]
|
||||||
|
McpError(String),
|
||||||
|
|
||||||
|
/// MCP 未初始化(不可恢复)。
|
||||||
|
#[error("MCP 未初始化: {0}")]
|
||||||
|
McpNotInitialized(String),
|
||||||
|
|
||||||
|
/// MCP 超时(不可恢复)。
|
||||||
|
#[error("MCP 超时: {0}")]
|
||||||
|
McpTimeout(String),
|
||||||
|
|
||||||
|
/// IO 错误(不可恢复)。
|
||||||
|
#[error("IO 错误: {0}")]
|
||||||
|
Io(Arc<std::io::Error>),
|
||||||
|
|
||||||
|
/// 取消。
|
||||||
|
#[error("工具执行已取消: {0}")]
|
||||||
|
Cancelled(String),
|
||||||
|
|
||||||
|
/// 其他未分类错误。
|
||||||
|
#[error("其他错误: {0}")]
|
||||||
|
Other(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for ToolError {
|
||||||
|
fn from(e: std::io::Error) -> Self {
|
||||||
|
ToolError::Io(Arc::new(e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolError {
|
||||||
|
/// 判断错误是否可恢复——可恢复的错误回传 LLM 由其自行重试,
|
||||||
|
/// 不可恢复的错误终止自动 tool 循环并返回给调用方。
|
||||||
|
pub fn is_recoverable(&self) -> bool {
|
||||||
|
matches!(
|
||||||
|
self,
|
||||||
|
Self::ExecutionFailed(..) | Self::InvalidArguments(..) | Self::Other(_)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_execution_failed_is_recoverable() {
|
||||||
|
let err = ToolError::ExecutionFailed("foo".into(), "boom".into());
|
||||||
|
assert!(err.is_recoverable());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_invalid_arguments_is_recoverable() {
|
||||||
|
let err = ToolError::InvalidArguments("foo".into(), "missing x".into());
|
||||||
|
assert!(err.is_recoverable());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_not_found_is_not_recoverable() {
|
||||||
|
let err = ToolError::NotFound("foo".into());
|
||||||
|
assert!(!err.is_recoverable());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_permission_denied_is_not_recoverable() {
|
||||||
|
let err = ToolError::PermissionDenied("foo".into(), "Shell".into());
|
||||||
|
assert!(!err.is_recoverable());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mcp_error_is_not_recoverable() {
|
||||||
|
let err = ToolError::McpError("protocol".into());
|
||||||
|
assert!(!err.is_recoverable());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mcp_timeout_is_not_recoverable() {
|
||||||
|
let err = ToolError::McpTimeout("foo".into());
|
||||||
|
assert!(!err.is_recoverable());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_io_is_not_recoverable() {
|
||||||
|
let io_err = std::io::Error::new(std::io::ErrorKind::Other, "disk");
|
||||||
|
let err = ToolError::from(io_err);
|
||||||
|
assert!(!err.is_recoverable());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_other_is_recoverable() {
|
||||||
|
let err = ToolError::Other("something".into());
|
||||||
|
assert!(err.is_recoverable());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,640 @@
|
|||||||
|
//! MCP 协议客户端 —— 与 MCP Server 通过 JSON-RPC over stdio 通信。
|
||||||
|
//!
|
||||||
|
//! 当前 Phase 2 实现 stdio transport。`StreamableHttp` 枚举变体已预留,
|
||||||
|
//! 但实际实现推迟到后续版本。
|
||||||
|
//!
|
||||||
|
//! ## 协议版本
|
||||||
|
//!
|
||||||
|
//! 实现遵循 MCP 协议版本 2025-03-26。
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::process::Stdio;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||||
|
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
|
||||||
|
use tokio::sync::{oneshot, Mutex};
|
||||||
|
|
||||||
|
use crate::llm::types::ToolDefinition;
|
||||||
|
use crate::tools::base::{BaseTool, ToolContext, ToolRef};
|
||||||
|
use crate::tools::error::ToolError;
|
||||||
|
|
||||||
|
/// MCP 协议版本。
|
||||||
|
const MCP_VERSION: &str = "2025-03-26";
|
||||||
|
|
||||||
|
/// MCP 传输方式。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum McpTransport {
|
||||||
|
/// 通过子进程 stdin/stdout 通信。
|
||||||
|
Stdio {
|
||||||
|
/// 启动命令(如 `"npx"`)。
|
||||||
|
command: String,
|
||||||
|
/// 命令参数(如 `["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]`)。
|
||||||
|
args: Vec<String>,
|
||||||
|
},
|
||||||
|
/// Streamable HTTP 传输(MCP 2025-03-26 引入,替代已废弃的 HTTP+SSE)。
|
||||||
|
///
|
||||||
|
/// 当前 Phase 2 预留枚举变体,调用方法会返回 `ToolError::McpError`。
|
||||||
|
StreamableHttp {
|
||||||
|
/// MCP 端点 URL。
|
||||||
|
url: String,
|
||||||
|
/// 可选的 HTTP 头(如 Authorization)。
|
||||||
|
headers: Option<Vec<(String, String)>>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
/// JSON-RPC 请求。
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct JsonRpcRequest {
|
||||||
|
jsonrpc: &'static str,
|
||||||
|
id: u64,
|
||||||
|
method: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
params: Option<Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl JsonRpcRequest {
|
||||||
|
fn new(id: u64, method: impl Into<String>, params: Option<Value>) -> Self {
|
||||||
|
Self {
|
||||||
|
jsonrpc: "2.0",
|
||||||
|
id,
|
||||||
|
method: method.into(),
|
||||||
|
params,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// JSON-RPC 响应。
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct JsonRpcResponse {
|
||||||
|
jsonrpc: String,
|
||||||
|
id: u64,
|
||||||
|
#[serde(default)]
|
||||||
|
result: Option<Value>,
|
||||||
|
#[serde(default)]
|
||||||
|
error: Option<JsonRpcError>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct JsonRpcError {
|
||||||
|
code: i32,
|
||||||
|
message: String,
|
||||||
|
#[serde(default)]
|
||||||
|
data: Option<Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// MCP 子进程运行时状态。
|
||||||
|
struct ChildProcessState {
|
||||||
|
child: Child,
|
||||||
|
stdin: ChildStdin,
|
||||||
|
pending: HashMap<u64, oneshot::Sender<Result<Value, ToolError>>>,
|
||||||
|
next_id: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChildProcessState {
|
||||||
|
fn next_id(&mut self) -> u64 {
|
||||||
|
self.next_id += 1;
|
||||||
|
self.next_id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// MCP Server 暴露的工具(缓存结构)。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct McpTool {
|
||||||
|
name: String,
|
||||||
|
description: Option<String>,
|
||||||
|
input_schema: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// MCP 客户端 —— 与 MCP 服务器通信。
|
||||||
|
pub struct McpClient {
|
||||||
|
transport: McpTransport,
|
||||||
|
server_name: String,
|
||||||
|
/// 已初始化的工具列表(缓存)。
|
||||||
|
tools: Vec<McpTool>,
|
||||||
|
/// 是否已初始化。
|
||||||
|
initialized: AtomicBool,
|
||||||
|
/// 超时时间(秒)。
|
||||||
|
timeout_secs: u64,
|
||||||
|
/// 子进程运行时状态(`connect()` 后创建,`close()` 后取回)。
|
||||||
|
process: Option<Arc<Mutex<ChildProcessState>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for McpClient {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("McpClient")
|
||||||
|
.field("server_name", &self.server_name)
|
||||||
|
.field("initialized", &self.initialized.load(Ordering::SeqCst))
|
||||||
|
.field("tool_count", &self.tools.len())
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl McpClient {
|
||||||
|
/// 创建一个 MCP 客户端。
|
||||||
|
pub fn new(server_name: impl Into<String>, transport: McpTransport) -> Self {
|
||||||
|
Self {
|
||||||
|
transport,
|
||||||
|
server_name: server_name.into(),
|
||||||
|
tools: Vec::new(),
|
||||||
|
initialized: AtomicBool::new(false),
|
||||||
|
timeout_secs: 30,
|
||||||
|
process: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 设置超时时间(秒)。
|
||||||
|
pub fn with_timeout(mut self, secs: u64) -> Self {
|
||||||
|
self.timeout_secs = secs;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 检查是否已连接。
|
||||||
|
pub fn is_initialized(&self) -> bool {
|
||||||
|
self.initialized.load(Ordering::SeqCst)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 连接并初始化(发送 initialize 请求)。
|
||||||
|
pub async fn connect(&mut self) -> Result<(), ToolError> {
|
||||||
|
if self.is_initialized() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
match &self.transport {
|
||||||
|
McpTransport::Stdio { command, args } => {
|
||||||
|
let mut cmd = Command::new(command);
|
||||||
|
cmd.args(args)
|
||||||
|
.stdin(Stdio::piped())
|
||||||
|
.stdout(Stdio::piped())
|
||||||
|
.stderr(Stdio::piped());
|
||||||
|
#[cfg(unix)]
|
||||||
|
cmd.kill_on_drop(true);
|
||||||
|
#[cfg(windows)]
|
||||||
|
cmd.creation_flags(0x08000000); // CREATE_NO_WINDOW
|
||||||
|
|
||||||
|
let mut child = cmd
|
||||||
|
.spawn()
|
||||||
|
.map_err(|e| ToolError::McpError(format!("启动 MCP 子进程失败: {e}")))?;
|
||||||
|
|
||||||
|
let stdin = child
|
||||||
|
.stdin
|
||||||
|
.take()
|
||||||
|
.ok_or_else(|| ToolError::McpError("无法获取子进程 stdin".into()))?;
|
||||||
|
let stdout = child
|
||||||
|
.stdout
|
||||||
|
.take()
|
||||||
|
.ok_or_else(|| ToolError::McpError("无法获取子进程 stdout".into()))?;
|
||||||
|
|
||||||
|
// 启动 reader task 持续读取 stdout
|
||||||
|
let pending: HashMap<u64, oneshot::Sender<Result<Value, ToolError>>> =
|
||||||
|
HashMap::new();
|
||||||
|
let state = Arc::new(Mutex::new(ChildProcessState {
|
||||||
|
child,
|
||||||
|
stdin,
|
||||||
|
pending,
|
||||||
|
next_id: 0,
|
||||||
|
}));
|
||||||
|
|
||||||
|
// 启动后台 reader
|
||||||
|
let pending_arc = Arc::clone(&state);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
Self::read_loop(BufReader::new(stdout), pending_arc).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
self.process = Some(state);
|
||||||
|
}
|
||||||
|
McpTransport::StreamableHttp { .. } => {
|
||||||
|
return Err(ToolError::McpError(
|
||||||
|
"StreamableHttp transport 尚未实现".into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送 initialize 请求
|
||||||
|
let init_params = json!({
|
||||||
|
"protocolVersion": MCP_VERSION,
|
||||||
|
"capabilities": {},
|
||||||
|
"clientInfo": {
|
||||||
|
"name": "agcore",
|
||||||
|
"version": env!("CARGO_PKG_VERSION")
|
||||||
|
}
|
||||||
|
});
|
||||||
|
let _response = self
|
||||||
|
.send_request("initialize", Some(init_params))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// 发送 initialized 通知(无 id)
|
||||||
|
self.send_notification("notifications/initialized", Some(json!({})))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
self.initialized.store(true, Ordering::SeqCst);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 列出服务器支持的工具(调用 `tools/list`)。
|
||||||
|
pub async fn list_tools(&mut self) -> Result<Vec<ToolDefinition>, ToolError> {
|
||||||
|
if !self.is_initialized() {
|
||||||
|
return Err(ToolError::McpNotInitialized(self.server_name.clone()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = self.send_request("tools/list", None).await?;
|
||||||
|
let tools_value = response
|
||||||
|
.get("tools")
|
||||||
|
.ok_or_else(|| ToolError::McpError("tools/list 响应缺少 tools 字段".into()))?;
|
||||||
|
let tools_arr = tools_value
|
||||||
|
.as_array()
|
||||||
|
.ok_or_else(|| ToolError::McpError("tools/list 响应 tools 字段不是数组".into()))?;
|
||||||
|
|
||||||
|
self.tools.clear();
|
||||||
|
let mut defs = Vec::with_capacity(tools_arr.len());
|
||||||
|
for tool in tools_arr {
|
||||||
|
let name = tool
|
||||||
|
.get("name")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| ToolError::McpError("工具缺少 name 字段".into()))?
|
||||||
|
.to_string();
|
||||||
|
let description = tool
|
||||||
|
.get("description")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
let input_schema = tool
|
||||||
|
.get("inputSchema")
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_else(|| json!({"type": "object", "properties": {}}));
|
||||||
|
|
||||||
|
self.tools.push(McpTool {
|
||||||
|
name: name.clone(),
|
||||||
|
description: description.clone(),
|
||||||
|
input_schema: input_schema.clone(),
|
||||||
|
});
|
||||||
|
defs.push(ToolDefinition {
|
||||||
|
name,
|
||||||
|
description,
|
||||||
|
parameters: input_schema,
|
||||||
|
strict: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(defs)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 调用一个工具(调用 `tools/call`)。
|
||||||
|
pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, ToolError> {
|
||||||
|
if !self.is_initialized() {
|
||||||
|
return Err(ToolError::McpNotInitialized(self.server_name.clone()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let params = json!({
|
||||||
|
"name": name,
|
||||||
|
"arguments": args,
|
||||||
|
});
|
||||||
|
let response = self.send_request("tools/call", Some(params)).await?;
|
||||||
|
|
||||||
|
// 解析 content 字段
|
||||||
|
if let Some(content) = response.get("content").and_then(|c| c.as_array()) {
|
||||||
|
// 收集所有 text 内容
|
||||||
|
let mut combined = String::new();
|
||||||
|
for item in content {
|
||||||
|
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
|
||||||
|
if !combined.is_empty() {
|
||||||
|
combined.push('\n');
|
||||||
|
}
|
||||||
|
combined.push_str(text);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !combined.is_empty() {
|
||||||
|
return Ok(Value::String(combined));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果没有 content 字段,尝试直接返回 is_error 标记
|
||||||
|
if let Some(true) = response.get("isError").and_then(|v| v.as_bool()) {
|
||||||
|
return Err(ToolError::ExecutionFailed(
|
||||||
|
name.to_string(),
|
||||||
|
"MCP 工具返回 isError=true".into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 回退:返回完整响应
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 关闭连接(终止子进程)。
|
||||||
|
pub async fn close(&mut self) -> Result<(), ToolError> {
|
||||||
|
if !self.is_initialized() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试发送 shutdown(不强制要求响应)
|
||||||
|
let _ = self.send_notification("shutdown", None).await;
|
||||||
|
|
||||||
|
if let Some(state) = self.process.take() {
|
||||||
|
let mut state = state.lock().await;
|
||||||
|
// 优雅等待 5 秒
|
||||||
|
let graceful = tokio::time::timeout(
|
||||||
|
Duration::from_secs(5),
|
||||||
|
state.child.wait(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
if graceful.is_err() {
|
||||||
|
// 超时则强杀
|
||||||
|
let _ = state.child.kill().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.initialized.store(false, Ordering::SeqCst);
|
||||||
|
self.tools.clear();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 将 MCP 客户端转换为 `BaseTool` 适配器列表(用于注册到 `ToolRegistry`)。
|
||||||
|
///
|
||||||
|
/// **注意**:返回的适配器持有 `Arc<McpClient>`,但 `McpClient` 的可变性
|
||||||
|
/// (如 `list_tools` 刷新缓存)会通过 `Mutex` 处理。当前适配器仅缓存
|
||||||
|
/// 转换时的工具列表,不感知后续刷新。
|
||||||
|
pub fn into_tools(self) -> Vec<ToolRef> {
|
||||||
|
let mut tools = Vec::with_capacity(self.tools.len());
|
||||||
|
for mcp_tool in self.tools {
|
||||||
|
let tool = McpToolAdapter {
|
||||||
|
client: McpClientHandle::Empty,
|
||||||
|
name: mcp_tool.name,
|
||||||
|
description: mcp_tool.description.unwrap_or_default(),
|
||||||
|
parameters: mcp_tool.input_schema,
|
||||||
|
};
|
||||||
|
tools.push(Arc::new(tool) as ToolRef);
|
||||||
|
}
|
||||||
|
tools
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_request(
|
||||||
|
&self,
|
||||||
|
method: &str,
|
||||||
|
params: Option<Value>,
|
||||||
|
) -> Result<Value, ToolError> {
|
||||||
|
let state_arc = self
|
||||||
|
.process
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| ToolError::McpNotInitialized(self.server_name.clone()))?
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
let (id, request_json) = {
|
||||||
|
let mut state = state_arc.lock().await;
|
||||||
|
let id = state.next_id();
|
||||||
|
let req = JsonRpcRequest::new(id, method, params);
|
||||||
|
let json = serde_json::to_string(&req)
|
||||||
|
.map_err(|e| ToolError::McpError(format!("序列化请求失败: {e}")))?;
|
||||||
|
(id, json)
|
||||||
|
};
|
||||||
|
|
||||||
|
// 注册 oneshot 等待响应
|
||||||
|
let (tx, rx) = oneshot::channel();
|
||||||
|
{
|
||||||
|
let mut state = state_arc.lock().await;
|
||||||
|
state.pending.insert(id, tx);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 写入请求
|
||||||
|
{
|
||||||
|
let mut state = state_arc.lock().await;
|
||||||
|
state
|
||||||
|
.stdin
|
||||||
|
.write_all(request_json.as_bytes())
|
||||||
|
.await
|
||||||
|
.map_err(|e| ToolError::McpError(format!("写入请求失败: {e}")))?;
|
||||||
|
state
|
||||||
|
.stdin
|
||||||
|
.write_all(b"\n")
|
||||||
|
.await
|
||||||
|
.map_err(|e| ToolError::McpError(format!("写入换行失败: {e}")))?;
|
||||||
|
state.stdin.flush().await.map_err(|e| {
|
||||||
|
ToolError::McpError(format!("flush stdin 失败: {e}"))
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 等待响应(带超时)
|
||||||
|
tokio::time::timeout(Duration::from_secs(self.timeout_secs), rx)
|
||||||
|
.await
|
||||||
|
.map_err(|_| {
|
||||||
|
// 超时:清理 pending
|
||||||
|
let state_arc = state_arc.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut state = state_arc.lock().await;
|
||||||
|
state.pending.remove(&id);
|
||||||
|
});
|
||||||
|
ToolError::McpTimeout(method.to_string())
|
||||||
|
})?
|
||||||
|
.map_err(|_| ToolError::McpError("response channel 关闭".into()))?
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_notification(
|
||||||
|
&self,
|
||||||
|
method: &str,
|
||||||
|
params: Option<Value>,
|
||||||
|
) -> Result<(), ToolError> {
|
||||||
|
let state_arc = self
|
||||||
|
.process
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| ToolError::McpNotInitialized(self.server_name.clone()))?
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
let notification = json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": method,
|
||||||
|
"params": params,
|
||||||
|
});
|
||||||
|
let json = serde_json::to_string(¬ification)
|
||||||
|
.map_err(|e| ToolError::McpError(format!("序列化通知失败: {e}")))?;
|
||||||
|
|
||||||
|
let mut state = state_arc.lock().await;
|
||||||
|
state
|
||||||
|
.stdin
|
||||||
|
.write_all(json.as_bytes())
|
||||||
|
.await
|
||||||
|
.map_err(|e| ToolError::McpError(format!("写入通知失败: {e}")))?;
|
||||||
|
state
|
||||||
|
.stdin
|
||||||
|
.write_all(b"\n")
|
||||||
|
.await
|
||||||
|
.map_err(|e| ToolError::McpError(format!("写入换行失败: {e}")))?;
|
||||||
|
state
|
||||||
|
.stdin
|
||||||
|
.flush()
|
||||||
|
.await
|
||||||
|
.map_err(|e| ToolError::McpError(format!("flush stdin 失败: {e}")))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 持续读取 stdout,将响应分发到对应的 oneshot sender。
|
||||||
|
async fn read_loop(
|
||||||
|
mut reader: BufReader<ChildStdout>,
|
||||||
|
state: Arc<Mutex<ChildProcessState>>,
|
||||||
|
) {
|
||||||
|
let mut line = String::new();
|
||||||
|
loop {
|
||||||
|
line.clear();
|
||||||
|
match reader.read_line(&mut line).await {
|
||||||
|
Ok(0) => {
|
||||||
|
// EOF:通知所有 pending 失败
|
||||||
|
let mut state = state.lock().await;
|
||||||
|
for (_, tx) in state.pending.drain() {
|
||||||
|
let _ = tx.send(Err(ToolError::McpError("子进程退出".into())));
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Ok(_) => {
|
||||||
|
let trimmed = line.trim();
|
||||||
|
if trimmed.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// 尝试解析为 JSON-RPC 响应
|
||||||
|
let parsed: Result<JsonRpcResponse, _> = serde_json::from_str(trimmed);
|
||||||
|
if let Ok(response) = parsed {
|
||||||
|
let value = if let Some(err) = response.error {
|
||||||
|
Err(ToolError::McpError(format!(
|
||||||
|
"[{}] {}",
|
||||||
|
err.code, err.message
|
||||||
|
)))
|
||||||
|
} else {
|
||||||
|
Ok(response.result.unwrap_or(Value::Null))
|
||||||
|
};
|
||||||
|
let mut state = state.lock().await;
|
||||||
|
if let Some(tx) = state.pending.remove(&response.id) {
|
||||||
|
let _ = tx.send(value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 非响应消息(通知、request from server)忽略
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!("MCP read_loop error: {e}");
|
||||||
|
let mut state = state.lock().await;
|
||||||
|
for (_, tx) in state.pending.drain() {
|
||||||
|
let _ = tx.send(Err(ToolError::McpError(format!("读取失败: {e}"))));
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// MCP 工具适配器 —— 将 MCP 工具包装为 `BaseTool`。
|
||||||
|
struct McpToolAdapter {
|
||||||
|
/// 持有 client 的弱引用。实际生产中应使用 `Arc<McpClient>`,
|
||||||
|
/// 但当前 Phase 2 实现不直接持有可变的 `McpClient`。
|
||||||
|
/// 标记为 unused 但保留字段以展示扩展路径。
|
||||||
|
#[allow(dead_code)]
|
||||||
|
client: McpClientHandle,
|
||||||
|
name: String,
|
||||||
|
description: String,
|
||||||
|
parameters: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
enum McpClientHandle {
|
||||||
|
Empty,
|
||||||
|
// Future: Shared(Arc<McpClient>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl BaseTool for McpToolAdapter {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
&self.name
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
&self.description
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters(&self) -> Value {
|
||||||
|
self.parameters.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(
|
||||||
|
&self,
|
||||||
|
_args: Value,
|
||||||
|
_ctx: &ToolContext<'_>,
|
||||||
|
) -> Result<Value, ToolError> {
|
||||||
|
// 当前 Phase 2 实现的简化:McpToolAdapter 不持有活跃 MCP 连接。
|
||||||
|
// 实际生产中应持有 Arc<McpClient> 并通过 mcp.call_tool() 执行。
|
||||||
|
// 这里返回错误,提示需要通过其他方式调用 MCP 工具。
|
||||||
|
Err(ToolError::McpError(format!(
|
||||||
|
"MCP 工具 '{}' 需要活跃的 McpClient 引用(当前 Phase 2 简化实现)",
|
||||||
|
self.name
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_transport_debug() {
|
||||||
|
let transport = McpTransport::Stdio {
|
||||||
|
command: "echo".to_string(),
|
||||||
|
args: vec!["hello".to_string()],
|
||||||
|
};
|
||||||
|
let formatted = format!("{transport:?}");
|
||||||
|
assert!(formatted.contains("echo"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_client_creation() {
|
||||||
|
let transport = McpTransport::Stdio {
|
||||||
|
command: "test".to_string(),
|
||||||
|
args: vec![],
|
||||||
|
};
|
||||||
|
let client = McpClient::new("test-server", transport).with_timeout(60);
|
||||||
|
assert_eq!(client.server_name, "test-server");
|
||||||
|
assert_eq!(client.timeout_secs, 60);
|
||||||
|
assert!(!client.is_initialized());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_jsonrpc_request_serialize() {
|
||||||
|
let req = JsonRpcRequest::new(42, "test", Some(json!({"a": 1})));
|
||||||
|
let s = serde_json::to_string(&req).unwrap();
|
||||||
|
assert!(s.contains("\"jsonrpc\":\"2.0\""));
|
||||||
|
assert!(s.contains("\"id\":42"));
|
||||||
|
assert!(s.contains("\"method\":\"test\""));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_jsonrpc_response_parse_ok() {
|
||||||
|
let s = r#"{"jsonrpc":"2.0","id":1,"result":{"foo":"bar"}}"#;
|
||||||
|
let resp: JsonRpcResponse = serde_json::from_str(s).unwrap();
|
||||||
|
assert_eq!(resp.id, 1);
|
||||||
|
assert!(resp.result.is_some());
|
||||||
|
assert!(resp.error.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_jsonrpc_response_parse_error() {
|
||||||
|
let s =
|
||||||
|
r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32601,"message":"Method not found"}}"#;
|
||||||
|
let resp: JsonRpcResponse = serde_json::from_str(s).unwrap();
|
||||||
|
assert_eq!(resp.id, 1);
|
||||||
|
assert!(resp.result.is_none());
|
||||||
|
let err = resp.error.unwrap();
|
||||||
|
assert_eq!(err.code, -32601);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_streamable_http_not_implemented() {
|
||||||
|
let mut client = McpClient::new(
|
||||||
|
"http-server",
|
||||||
|
McpTransport::StreamableHttp {
|
||||||
|
url: "https://example.com/mcp".to_string(),
|
||||||
|
headers: None,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
let result = client.connect().await;
|
||||||
|
// 当前 Phase 2 返回未实现错误
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(matches!(result, Err(ToolError::McpError(_))));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,286 @@
|
|||||||
|
//! 工具权限管理。
|
||||||
|
|
||||||
|
use crate::tools::error::ToolError;
|
||||||
|
|
||||||
|
/// 权限级别枚举。
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
|
pub enum Permission {
|
||||||
|
/// 只读(读取文件、查询数据库等)。
|
||||||
|
Read,
|
||||||
|
/// 写入(创建/修改文件、插入数据等)。
|
||||||
|
Write,
|
||||||
|
/// 删除(删除文件、记录等)。
|
||||||
|
Delete,
|
||||||
|
/// 网络访问(HTTP 请求等)。
|
||||||
|
Network,
|
||||||
|
/// Shell 命令执行。
|
||||||
|
Shell,
|
||||||
|
/// 文件系统操作(除读/写/删之外的 FS 操作)。
|
||||||
|
FileSystem,
|
||||||
|
/// 自定义权限(可通过 namespaced 字符串扩展)。
|
||||||
|
Custom(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for Permission {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::Read => write!(f, "Read"),
|
||||||
|
Self::Write => write!(f, "Write"),
|
||||||
|
Self::Delete => write!(f, "Delete"),
|
||||||
|
Self::Network => write!(f, "Network"),
|
||||||
|
Self::Shell => write!(f, "Shell"),
|
||||||
|
Self::FileSystem => write!(f, "FileSystem"),
|
||||||
|
Self::Custom(s) => write!(f, "Custom({})", s),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 权限配置。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct PermissionConfig {
|
||||||
|
/// 允许的权限列表(空 = 全部允许,配合 `allow_unspecified` 决定)。
|
||||||
|
pub allowed: Vec<Permission>,
|
||||||
|
/// 拒绝的权限列表(优先级高于 `allowed`)。
|
||||||
|
pub denied: Vec<Permission>,
|
||||||
|
/// 当工具未声明权限时是否允许执行。
|
||||||
|
pub allow_unspecified: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for PermissionConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
allowed: vec![Permission::Read, Permission::Network],
|
||||||
|
denied: vec![Permission::Delete, Permission::Shell],
|
||||||
|
allow_unspecified: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 权限检查器。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct PermissionChecker {
|
||||||
|
config: PermissionConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PermissionChecker {
|
||||||
|
/// 创建一个新的权限检查器。
|
||||||
|
pub fn new(config: PermissionConfig) -> Self {
|
||||||
|
Self { config }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 检查指定工具声明的权限是否允许执行。
|
||||||
|
///
|
||||||
|
/// 判定规则:
|
||||||
|
/// 1. 任一权限在 `denied` 中 → 拒绝
|
||||||
|
/// 2. 所有权限都在 `allowed` 中 → 允许
|
||||||
|
/// 3. `allowed` 非空且存在未声明权限 → 拒绝
|
||||||
|
/// 4. `allowed` 为空 → 按 `allow_unspecified` 判定
|
||||||
|
/// 5. 工具未声明任何权限时按 `allow_unspecified` 判定
|
||||||
|
pub fn check(&self, tool_name: &str, permissions: &[Permission]) -> Result<(), ToolError> {
|
||||||
|
// 任一权限在 denied 中 → 拒绝
|
||||||
|
for perm in permissions {
|
||||||
|
if self.config.denied.contains(perm) {
|
||||||
|
return Err(ToolError::PermissionDenied(
|
||||||
|
tool_name.to_string(),
|
||||||
|
perm.to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 工具未声明任何权限
|
||||||
|
if permissions.is_empty() {
|
||||||
|
return if self.config.allow_unspecified {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(ToolError::PermissionDenied(
|
||||||
|
tool_name.to_string(),
|
||||||
|
"Unspecified".to_string(),
|
||||||
|
))
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// allowed 为空 → 走 allow_unspecified 兜底
|
||||||
|
if self.config.allowed.is_empty() {
|
||||||
|
return if self.config.allow_unspecified {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(ToolError::PermissionDenied(
|
||||||
|
tool_name.to_string(),
|
||||||
|
"Unspecified".to_string(),
|
||||||
|
))
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// allowed 非空(白名单模式)—— 所有权限必须在其中
|
||||||
|
for perm in permissions {
|
||||||
|
if !self.config.allowed.contains(perm) {
|
||||||
|
return Err(ToolError::PermissionDenied(
|
||||||
|
tool_name.to_string(),
|
||||||
|
perm.to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
fn p(perm: Permission) -> Vec<Permission> {
|
||||||
|
vec![perm]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_default_config_allows_read() {
|
||||||
|
let checker = PermissionChecker::new(PermissionConfig::default());
|
||||||
|
assert!(checker.check("weather", &p(Permission::Read)).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_default_config_allows_network() {
|
||||||
|
let checker = PermissionChecker::new(PermissionConfig::default());
|
||||||
|
assert!(checker.check("http_get", &p(Permission::Network)).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_default_config_denies_delete() {
|
||||||
|
let checker = PermissionChecker::new(PermissionConfig::default());
|
||||||
|
assert!(checker
|
||||||
|
.check("rm_file", &p(Permission::Delete))
|
||||||
|
.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_default_config_denies_shell() {
|
||||||
|
let checker = PermissionChecker::new(PermissionConfig::default());
|
||||||
|
assert!(checker.check("run_shell", &p(Permission::Shell)).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_white_list_mode_denies_unlisted() {
|
||||||
|
let cfg = PermissionConfig {
|
||||||
|
allowed: vec![Permission::Read],
|
||||||
|
denied: vec![],
|
||||||
|
allow_unspecified: false,
|
||||||
|
};
|
||||||
|
let checker = PermissionChecker::new(cfg);
|
||||||
|
assert!(checker.check("t", &p(Permission::Read)).is_ok());
|
||||||
|
assert!(checker.check("t", &p(Permission::Write)).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_white_list_mode_allows_listed() {
|
||||||
|
let cfg = PermissionConfig {
|
||||||
|
allowed: vec![Permission::Read, Permission::Write],
|
||||||
|
denied: vec![],
|
||||||
|
allow_unspecified: false,
|
||||||
|
};
|
||||||
|
let checker = PermissionChecker::new(cfg);
|
||||||
|
assert!(checker.check("t", &p(Permission::Write)).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_black_list_deny_priority() {
|
||||||
|
// 即便 allowed 中包含了 denied 权限,仍以 denied 为准
|
||||||
|
let cfg = PermissionConfig {
|
||||||
|
allowed: vec![Permission::Shell, Permission::Read],
|
||||||
|
denied: vec![Permission::Shell],
|
||||||
|
allow_unspecified: false,
|
||||||
|
};
|
||||||
|
let checker = PermissionChecker::new(cfg);
|
||||||
|
assert!(checker.check("t", &p(Permission::Shell)).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_empty_allowed_with_allow_unspecified() {
|
||||||
|
let cfg = PermissionConfig {
|
||||||
|
allowed: vec![],
|
||||||
|
denied: vec![],
|
||||||
|
allow_unspecified: true,
|
||||||
|
};
|
||||||
|
let checker = PermissionChecker::new(cfg);
|
||||||
|
assert!(checker.check("t", &p(Permission::Write)).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_empty_allowed_without_allow_unspecified() {
|
||||||
|
let cfg = PermissionConfig {
|
||||||
|
allowed: vec![],
|
||||||
|
denied: vec![],
|
||||||
|
allow_unspecified: false,
|
||||||
|
};
|
||||||
|
let checker = PermissionChecker::new(cfg);
|
||||||
|
assert!(checker.check("t", &p(Permission::Write)).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_unspecified_tool_with_allow() {
|
||||||
|
let cfg = PermissionConfig {
|
||||||
|
allowed: vec![],
|
||||||
|
denied: vec![],
|
||||||
|
allow_unspecified: true,
|
||||||
|
};
|
||||||
|
let checker = PermissionChecker::new(cfg);
|
||||||
|
assert!(checker.check("t", &[]).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_unspecified_tool_without_allow() {
|
||||||
|
let cfg = PermissionConfig {
|
||||||
|
allowed: vec![],
|
||||||
|
denied: vec![],
|
||||||
|
allow_unspecified: false,
|
||||||
|
};
|
||||||
|
let checker = PermissionChecker::new(cfg);
|
||||||
|
assert!(checker.check("t", &[]).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_custom_permission_collision() {
|
||||||
|
let cfg = PermissionConfig {
|
||||||
|
allowed: vec![Permission::Custom("db:read".into())],
|
||||||
|
denied: vec![Permission::Custom("db:write".into())],
|
||||||
|
allow_unspecified: false,
|
||||||
|
};
|
||||||
|
let checker = PermissionChecker::new(cfg);
|
||||||
|
assert!(checker
|
||||||
|
.check("t", &[Permission::Custom("db:read".into())])
|
||||||
|
.is_ok());
|
||||||
|
assert!(checker
|
||||||
|
.check("t", &[Permission::Custom("db:write".into())])
|
||||||
|
.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_multi_permission_all_in_allowed() {
|
||||||
|
let cfg = PermissionConfig {
|
||||||
|
allowed: vec![Permission::Read, Permission::Write, Permission::Network],
|
||||||
|
denied: vec![],
|
||||||
|
allow_unspecified: false,
|
||||||
|
};
|
||||||
|
let checker = PermissionChecker::new(cfg);
|
||||||
|
assert!(checker
|
||||||
|
.check(
|
||||||
|
"t",
|
||||||
|
&[Permission::Read, Permission::Network]
|
||||||
|
)
|
||||||
|
.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_multi_permission_one_not_in_allowed() {
|
||||||
|
let cfg = PermissionConfig {
|
||||||
|
allowed: vec![Permission::Read, Permission::Network],
|
||||||
|
denied: vec![],
|
||||||
|
allow_unspecified: false,
|
||||||
|
};
|
||||||
|
let checker = PermissionChecker::new(cfg);
|
||||||
|
// 任一权限不在白名单则拒绝
|
||||||
|
assert!(checker
|
||||||
|
.check("t", &[Permission::Read, Permission::Write])
|
||||||
|
.is_err());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,371 @@
|
|||||||
|
//! 工具注册表 —— 管理工具注册、发现、调用。
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use futures::future::join_all;
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
use crate::llm::types::ToolDefinition;
|
||||||
|
use crate::tools::base::{ToolContext, ToolRef};
|
||||||
|
use crate::tools::error::ToolError;
|
||||||
|
use crate::tools::permission::PermissionChecker;
|
||||||
|
|
||||||
|
/// 工具调用记录 —— 用于追踪和调试。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ToolInvocation {
|
||||||
|
/// 被调用的工具名。
|
||||||
|
pub tool_name: String,
|
||||||
|
/// 工具的入参。
|
||||||
|
pub input: Value,
|
||||||
|
/// 工具的输出。
|
||||||
|
pub output: Result<Value, ToolError>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolInvocation {
|
||||||
|
/// 创建一个新的工具调用记录。
|
||||||
|
pub fn new(tool_name: String, input: Value, output: Result<Value, ToolError>) -> Self {
|
||||||
|
Self {
|
||||||
|
tool_name,
|
||||||
|
input,
|
||||||
|
output,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 工具注册表 —— 管理工具注册、发现、调用。
|
||||||
|
///
|
||||||
|
/// 通过 `Arc` 共享,方法签名 `&self`,可安全跨 task 并行调用。
|
||||||
|
/// 不支持运行时并发注册(应在 setup 阶段一次性构建后冻结)。
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
pub struct ToolRegistry {
|
||||||
|
inner: Arc<ToolRegistryInner>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
struct ToolRegistryInner {
|
||||||
|
tools: HashMap<String, ToolRef>,
|
||||||
|
permission_checker: Option<Arc<PermissionChecker>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for ToolRegistry {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("ToolRegistry")
|
||||||
|
.field("tool_names", &self.inner.tools.keys().collect::<Vec<_>>())
|
||||||
|
.field("has_checker", &self.inner.permission_checker.is_some())
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolRegistry {
|
||||||
|
/// 创建一个新的工具注册表。
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
inner: Arc::new(ToolRegistryInner {
|
||||||
|
tools: HashMap::new(),
|
||||||
|
permission_checker: None,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 设置权限检查器(Builder 模式)。
|
||||||
|
pub fn with_permission_checker(mut self, checker: PermissionChecker) -> Self {
|
||||||
|
let inner = Arc::make_mut(&mut self.inner);
|
||||||
|
inner.permission_checker = Some(Arc::new(checker));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 注册一个工具。
|
||||||
|
///
|
||||||
|
/// 重复注册同名工具返回错误。
|
||||||
|
pub fn register(&mut self, tool: ToolRef) -> Result<(), ToolError> {
|
||||||
|
let name = tool.name().to_string();
|
||||||
|
let inner = Arc::make_mut(&mut self.inner);
|
||||||
|
if inner.tools.contains_key(&name) {
|
||||||
|
return Err(ToolError::ExecutionFailed(name, "工具已存在".to_string()));
|
||||||
|
}
|
||||||
|
inner.tools.insert(name, tool);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 批量注册工具。
|
||||||
|
pub fn register_all(&mut self, tools: Vec<ToolRef>) -> Result<(), ToolError> {
|
||||||
|
for tool in tools {
|
||||||
|
self.register(tool)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 注销一个工具。
|
||||||
|
pub fn unregister(&mut self, name: &str) -> Option<ToolRef> {
|
||||||
|
let inner = Arc::make_mut(&mut self.inner);
|
||||||
|
inner.tools.remove(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 按名称查找工具。
|
||||||
|
pub fn get(&self, name: &str) -> Option<ToolRef> {
|
||||||
|
self.inner.tools.get(name).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取所有已注册工具的名称列表。
|
||||||
|
pub fn list_tools(&self) -> Vec<String> {
|
||||||
|
self.inner.tools.keys().cloned().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取所有工具的 `ToolDefinition` 列表(用于传递给 LLM)。
|
||||||
|
pub fn definitions(&self) -> Vec<ToolDefinition> {
|
||||||
|
self.inner
|
||||||
|
.tools
|
||||||
|
.values()
|
||||||
|
.map(|tool| ToolDefinition {
|
||||||
|
name: tool.name().to_string(),
|
||||||
|
description: Some(tool.description().to_string()),
|
||||||
|
parameters: tool.parameters(),
|
||||||
|
strict: None,
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 调用单个工具(含权限检查)。
|
||||||
|
pub async fn invoke(&self, name: &str, args: Value) -> Result<ToolInvocation, ToolError> {
|
||||||
|
let tool = self
|
||||||
|
.get(name)
|
||||||
|
.ok_or_else(|| ToolError::NotFound(name.to_string()))?;
|
||||||
|
|
||||||
|
if let Some(checker) = &self.inner.permission_checker {
|
||||||
|
checker.check(name, &tool.required_permissions())?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let ctx = ToolContext::new(name, "");
|
||||||
|
let output = tool.execute(args.clone(), &ctx).await;
|
||||||
|
Ok(ToolInvocation::new(name.to_string(), args, output))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 并行执行多个工具调用(互不依赖的工具)。
|
||||||
|
///
|
||||||
|
/// 每个工具独立超时(`timeout_per_call_secs`,0 表示不超时)。
|
||||||
|
/// 单个工具超时不会影响其他工具的返回。
|
||||||
|
pub async fn invoke_all(
|
||||||
|
&self,
|
||||||
|
calls: Vec<(String, Value)>,
|
||||||
|
timeout_per_call_secs: u64,
|
||||||
|
) -> Vec<ToolInvocation> {
|
||||||
|
let this = self.clone();
|
||||||
|
let futures = calls.into_iter().map(|(name, args)| {
|
||||||
|
let this = this.clone();
|
||||||
|
async move {
|
||||||
|
match if timeout_per_call_secs == 0 {
|
||||||
|
Ok(this.invoke(&name, args.clone()).await)
|
||||||
|
} else {
|
||||||
|
tokio::time::timeout(
|
||||||
|
Duration::from_secs(timeout_per_call_secs),
|
||||||
|
this.invoke(&name, args.clone()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
} {
|
||||||
|
Ok(result) => result.unwrap_or_else(|e| {
|
||||||
|
ToolInvocation::new(name.clone(), args.clone(), Err(e))
|
||||||
|
}),
|
||||||
|
Err(_) => ToolInvocation::new(
|
||||||
|
name,
|
||||||
|
args,
|
||||||
|
Err(ToolError::McpTimeout("timeout".into())),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
join_all(futures).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::tools::BaseTool;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
struct AddTool {
|
||||||
|
base: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl BaseTool for AddTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"add"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"加法"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters(&self) -> Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": { "n": { "type": "integer" } },
|
||||||
|
"required": ["n"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: Value, _ctx: &ToolContext<'_>) -> Result<Value, ToolError> {
|
||||||
|
let n = args["n"].as_i64().unwrap_or(0);
|
||||||
|
Ok(json!({ "result": self.base + n }))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct FailTool;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl BaseTool for FailTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"fail"
|
||||||
|
}
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"总会失败"
|
||||||
|
}
|
||||||
|
fn parameters(&self) -> Value {
|
||||||
|
json!({})
|
||||||
|
}
|
||||||
|
async fn execute(&self, _args: Value, _ctx: &ToolContext<'_>) -> Result<Value, ToolError> {
|
||||||
|
Err(ToolError::ExecutionFailed("fail".into(), "boom".into()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ShellTool;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl BaseTool for ShellTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"shell"
|
||||||
|
}
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"shell"
|
||||||
|
}
|
||||||
|
fn parameters(&self) -> Value {
|
||||||
|
json!({})
|
||||||
|
}
|
||||||
|
fn required_permissions(&self) -> Vec<crate::tools::permission::Permission> {
|
||||||
|
vec![crate::tools::permission::Permission::Shell]
|
||||||
|
}
|
||||||
|
async fn execute(&self, _args: Value, _ctx: &ToolContext<'_>) -> Result<Value, ToolError> {
|
||||||
|
Ok(json!({}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_register_and_get() {
|
||||||
|
let mut reg = ToolRegistry::new();
|
||||||
|
reg.register(Arc::new(AddTool { base: 10 })).unwrap();
|
||||||
|
assert!(reg.get("add").is_some());
|
||||||
|
assert!(reg.get("nonexistent").is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_register_duplicate() {
|
||||||
|
let mut reg = ToolRegistry::new();
|
||||||
|
reg.register(Arc::new(AddTool { base: 0 })).unwrap();
|
||||||
|
let result = reg.register(Arc::new(AddTool { base: 1 }));
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_register_all() {
|
||||||
|
let mut reg = ToolRegistry::new();
|
||||||
|
let result = reg.register_all(vec![
|
||||||
|
Arc::new(AddTool { base: 1 }),
|
||||||
|
Arc::new(AddTool { base: 2 }),
|
||||||
|
]);
|
||||||
|
assert!(result.is_err()); // 重名 add → 失败
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_unregister() {
|
||||||
|
let mut reg = ToolRegistry::new();
|
||||||
|
reg.register(Arc::new(AddTool { base: 0 })).unwrap();
|
||||||
|
let removed = reg.unregister("add");
|
||||||
|
assert!(removed.is_some());
|
||||||
|
assert!(reg.get("add").is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_list_tools() {
|
||||||
|
let mut reg = ToolRegistry::new();
|
||||||
|
reg.register(Arc::new(AddTool { base: 0 })).unwrap();
|
||||||
|
reg.register(Arc::new(FailTool)).unwrap();
|
||||||
|
let names = reg.list_tools();
|
||||||
|
assert_eq!(names.len(), 2);
|
||||||
|
assert!(names.contains(&"add".to_string()));
|
||||||
|
assert!(names.contains(&"fail".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_definitions() {
|
||||||
|
let mut reg = ToolRegistry::new();
|
||||||
|
reg.register(Arc::new(AddTool { base: 0 })).unwrap();
|
||||||
|
let defs = reg.definitions();
|
||||||
|
assert_eq!(defs.len(), 1);
|
||||||
|
assert_eq!(defs[0].name, "add");
|
||||||
|
assert!(defs[0].description.is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_invoke_success() {
|
||||||
|
let mut reg = ToolRegistry::new();
|
||||||
|
reg.register(Arc::new(AddTool { base: 100 })).unwrap();
|
||||||
|
let result = reg.invoke("add", json!({ "n": 5 })).await.unwrap();
|
||||||
|
let value = result.output.unwrap();
|
||||||
|
assert_eq!(value["result"], 105);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_invoke_not_found() {
|
||||||
|
let reg = ToolRegistry::new();
|
||||||
|
let result = reg.invoke("nope", json!({})).await;
|
||||||
|
assert!(matches!(result, Err(ToolError::NotFound(_))));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_invoke_execution_error() {
|
||||||
|
let mut reg = ToolRegistry::new();
|
||||||
|
reg.register(Arc::new(FailTool)).unwrap();
|
||||||
|
let result = reg.invoke("fail", json!({})).await.unwrap();
|
||||||
|
assert!(result.output.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_invoke_with_permission_denied() {
|
||||||
|
let mut reg = ToolRegistry::new()
|
||||||
|
.with_permission_checker(PermissionChecker::new(Default::default()));
|
||||||
|
reg.register(Arc::new(ShellTool)).unwrap();
|
||||||
|
let result = reg.invoke("shell", json!({})).await;
|
||||||
|
assert!(matches!(result, Err(ToolError::PermissionDenied(_, _))));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_invoke_all_parallel() {
|
||||||
|
let mut reg = ToolRegistry::new();
|
||||||
|
reg.register(Arc::new(AddTool { base: 1 })).unwrap();
|
||||||
|
reg.register(Arc::new(FailTool)).unwrap();
|
||||||
|
let calls = vec![
|
||||||
|
("add".into(), json!({ "n": 1 })),
|
||||||
|
("add".into(), json!({ "n": 2 })),
|
||||||
|
("fail".into(), json!({})),
|
||||||
|
];
|
||||||
|
let results = reg.invoke_all(calls, 0).await;
|
||||||
|
assert_eq!(results.len(), 3);
|
||||||
|
assert!(results[0].output.is_ok());
|
||||||
|
assert!(results[1].output.is_ok());
|
||||||
|
assert!(results[2].output.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_invoke_all_with_timeout() {
|
||||||
|
let mut reg = ToolRegistry::new();
|
||||||
|
reg.register(Arc::new(AddTool { base: 0 })).unwrap();
|
||||||
|
let calls = vec![("add".into(), json!({ "n": 1 }))];
|
||||||
|
let results = reg.invoke_all(calls, 5).await;
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
assert!(results[0].output.is_ok());
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user