Compare commits
42 Commits
16c3e14cb6
..
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 829be90d19 | |||
| ce1f1aaca0 | |||
| 4de7db0b2c | |||
| 2b189880a9 | |||
| 59ec0f5597 | |||
| be595a6771 | |||
| 75f8736931 | |||
| b539f37eeb | |||
| 336920554a | |||
| 63c50e1fc7 | |||
| 0c51bb78a6 | |||
| 2ecc0b4001 | |||
| 1fe7f02281 | |||
| 6dc7ee492f | |||
| b571f530f8 | |||
| 59994bf55e | |||
| fb1c530358 | |||
| f818bd59f5 | |||
| 8573c6eb92 | |||
| 692bff5751 | |||
| b6e7acfb0f | |||
| e598f6d3ee | |||
| 5d6bb5e983 | |||
| 0d58d07ab1 | |||
| dd9c5be1fe | |||
| 993ae0eb4b | |||
| 7f5513adf3 | |||
| ea1e5c7f7e | |||
| 32f3edaf19 | |||
| 69b6dd942b | |||
| 99b304e120 | |||
| 0267da93f1 | |||
| af5a580b5e | |||
| f7e73dd561 | |||
| 28635e28d5 | |||
| e22c176643 | |||
| a4b7b3b9f9 | |||
| 54387a5f64 | |||
| 91d32a6a82 | |||
| b21e163be0 | |||
| 2f58c803bf | |||
| 93de86e02d |
@@ -0,0 +1,74 @@
|
||||
# http://editorconfig.org
|
||||
root = true
|
||||
|
||||
[*]
|
||||
indent_style = space
|
||||
indent_size = 2
|
||||
end_of_line = lf
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
charset = utf-8
|
||||
max_line_length = 120
|
||||
|
||||
[*.java]
|
||||
indent_size = 4
|
||||
end_of_line = lf
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
charset = utf-8
|
||||
|
||||
[*.{yml,yaml}]
|
||||
indent_style = space
|
||||
indent_size = 2
|
||||
end_of_line = lf
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
charset = utf-8
|
||||
|
||||
[*.json]
|
||||
indent_style = space
|
||||
indent_size = 2
|
||||
end_of_line = lf
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
charset = utf-8
|
||||
|
||||
[*.sql]
|
||||
indent_style = space
|
||||
indent_size = 2
|
||||
end_of_line = lf
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
charset = utf-8
|
||||
|
||||
[*.go]
|
||||
indent_style = tab
|
||||
indent_size = 4
|
||||
end_of_line = lf
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
charset = utf-8
|
||||
|
||||
[*.py]
|
||||
indent_style = space
|
||||
indent_size = 4
|
||||
end_of_line = lf
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
charset = utf-8
|
||||
|
||||
[*.sh]
|
||||
indent_style = space
|
||||
indent_size = 2
|
||||
end_of_line = lf
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
charset = utf-8
|
||||
|
||||
[*.rs]
|
||||
indent_size = 4
|
||||
indent_style = space
|
||||
end_of_line = lf
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
charset = utf-8
|
||||
+880
-77
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,3 @@
|
||||
**/*.md
|
||||
.env
|
||||
.env.*
|
||||
+12
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"bracketSpacing": true,
|
||||
"bracketSameLine": true,
|
||||
"tabWidth": 2,
|
||||
"useTabs": false,
|
||||
"semi": true,
|
||||
"singleQuote": true,
|
||||
"jsxSingleQuote": false,
|
||||
"trailingComma": "all",
|
||||
"arrowParens": "always",
|
||||
"printWidth": 100
|
||||
}
|
||||
@@ -1,8 +1,11 @@
|
||||
# AGENTS.md
|
||||
|
||||
Karpathy 启发式行为指南,适用于 OpenCode。可根据需要与项目特定指令合并使用。
|
||||
---
|
||||
|
||||
**权衡:** 这些指南倾向于谨慎而非速度。对于简单任务,可自行判断使用。
|
||||
## 交互要求
|
||||
|
||||
- 思考过程全程必须使用中文,包括需求分析、逻辑拆解、方案选择等所有推理环节
|
||||
- 最终输出内容必须全部使用中文,除代码语法本身和英文关键词以外
|
||||
|
||||
---
|
||||
|
||||
@@ -28,6 +31,8 @@ Karpathy 启发式行为指南,适用于 OpenCode。可根据需要与项目
|
||||
### 3. 精准变更 (Surgical Changes)
|
||||
**只改动必须改的。只清理你自己造成的混乱。**
|
||||
|
||||
- 未明确要求时不修改已有文件
|
||||
- 先确认意图再动手
|
||||
- 不要"优化"相邻的代码、注释或格式。
|
||||
- 不要重构没有问题的代码。
|
||||
- 遵循已有风格,即使你自己的写法不同。
|
||||
@@ -61,9 +66,9 @@ Karpathy 启发式行为指南,适用于 OpenCode。可根据需要与项目
|
||||
- 变量/函数命名:`snake_case`
|
||||
|
||||
**测试要求**
|
||||
- 新功能建议附测试;修复 bug 建议附回归测试(不主动编写测试)
|
||||
- 简单明确的逻辑不需要创建测试(如枚举字面值、Getter、无分支的简单转换)
|
||||
- 测试结构:AAA 模式 (Arrange-Act-Assert),优先测试边界条件和错误场景
|
||||
- 核心业务逻辑需测试(关键算法、边界条件、错误处理)
|
||||
- 简单逻辑不需要测试(枚举字面值、Getter、无分支的简单转换)
|
||||
- 不主动补测试(除非用户明确要求)
|
||||
|
||||
**错误处理**
|
||||
- 优先使用 `Result` 处理错误,避免 `unwrap()`
|
||||
@@ -73,10 +78,15 @@ Karpathy 启发式行为指南,适用于 OpenCode。可根据需要与项目
|
||||
**安全规范**
|
||||
- 不硬编码密钥,使用环境变量
|
||||
- 用户输入必须验证
|
||||
- 依赖包保持更新
|
||||
- 依赖升级策略:
|
||||
- **安全补丁**:立即升级(修复已知漏洞)
|
||||
- **次要版本**:评估后升级(新功能、向后兼容)
|
||||
- **主要版本**:谨慎升级(可能破坏兼容性,需全面测试)
|
||||
- **验证**:升级后运行完整测试套件确保无回归
|
||||
|
||||
**Git Commit 规范**
|
||||
- 使用 Conventional Commits 格式:`<type>(<scope>): <description>`
|
||||
- **描述使用中文**
|
||||
- 类型:
|
||||
- `feat` - 新功能
|
||||
- `fix` - Bug 修复
|
||||
@@ -87,6 +97,17 @@ Karpathy 启发式行为指南,适用于 OpenCode。可根据需要与项目
|
||||
- `test` - 测试相关
|
||||
- `chore` - 构建/工具/配置
|
||||
- 描述使用祈使句、现在时态、句尾无句号
|
||||
- **Scope 推导规则**:
|
||||
- 从被提交的文件路径中推导出一个最相关的 scope
|
||||
- 一个 commit 只写一个主要 scope,不要罗列多个
|
||||
- 如果改动涉及多个模块,选择影响范围最大的那个或使用更上层抽象名称
|
||||
- 如果改动是全局的,可以省略 scope 或使用 `core`/`global`
|
||||
- **示例**:
|
||||
- `feat(llm): 添加流式响应支持`(修改 `src/llm/stream.rs`)
|
||||
- `fix(memory): 修复向量检索边界条件`(修改 `src/memory/vector_store.rs`)
|
||||
- `refactor(core): 统一错误类型定义`(修改多个模块的错误处理)
|
||||
- `docs: 更新 API 文档`(全局文档更新)
|
||||
- `chore(deps): 升级 tokio 到 1.40`(依赖更新)
|
||||
|
||||
---
|
||||
|
||||
@@ -102,7 +123,47 @@ Karpathy 启发式行为指南,适用于 OpenCode。可根据需要与项目
|
||||
|
||||
**模块组织**
|
||||
- 按功能领域组织模块(一个模块一个职责)
|
||||
- 使用 Rust 新风格模块结构:`foo.rs` 作为模块根,子模块在 `foo/bar.rs`
|
||||
- **2018+ 版风格**:使用 `foo.rs` 作为模块根,子模块放在 `foo/` 目录下
|
||||
- **扁平优先**:如果模块没有子模块,直接用单个 `foo.rs` 文件
|
||||
- **需要子模块时才用目录**:只有当模块需要组织多个子模块时才创建 `foo/` 目录
|
||||
- **避免 `mod.rs`**:不使用 `foo/mod.rs` 风格,统一使用 `foo.rs` 作为模块根
|
||||
- **公共 API 重导出**:在模块根中使用 `pub use` 重新导出子模块的公共类型,提供清晰的 API 边界
|
||||
|
||||
**示例结构图**:
|
||||
```
|
||||
src/
|
||||
├── lib.rs # crate 根,声明顶层模块
|
||||
├── llm.rs # LLM 模块(无子模块,扁平文件)
|
||||
├── memory.rs # 记忆模块根(有子模块)
|
||||
├── memory/
|
||||
│ ├── conversation.rs # ✓ 子模块
|
||||
│ └── vector_store.rs # ✓ 子模块
|
||||
├── prompt.rs # 提示词模块根(有子模块)
|
||||
├── prompt/
|
||||
│ ├── template.rs # ✓ 子模块
|
||||
│ └── optimizer.rs # ✓ 子模块
|
||||
├── tool.rs # 工具模块根(有子模块)
|
||||
├── tool/
|
||||
│ ├── mcp.rs # ✓ 子模块
|
||||
│ └── registry.rs # ✓ 子模块
|
||||
├── agent.rs # Agent 模块根(有子模块)
|
||||
└── agent/
|
||||
├── runtime.rs # ✓ 子模块
|
||||
└── planner.rs # ✓ 子模块
|
||||
|
||||
# ❌ 避免的风格
|
||||
src/
|
||||
└── memory/
|
||||
└── mod.rs # ❌ 不使用 mod.rs 风格
|
||||
|
||||
# ✅ memory.rs 示例:声明子模块并重导出公共 API
|
||||
pub mod conversation;
|
||||
pub mod vector_store;
|
||||
|
||||
// 重导出常用类型,提供简洁的公共 API
|
||||
pub use conversation::ConversationMemory;
|
||||
pub use vector_store::VectorStore;
|
||||
```
|
||||
|
||||
**测试文件**: 内联测试(`#[cfg(test)] mod tests {}`)或 `tests/` 目录
|
||||
|
||||
@@ -110,6 +171,22 @@ Karpathy 启发式行为指南,适用于 OpenCode。可根据需要与项目
|
||||
|
||||
---
|
||||
|
||||
## 文档规范
|
||||
|
||||
### 方案规范 (docs/)
|
||||
|
||||
**编号规则**:创建新方案前必须先通过 shell 命令确认当前实际最大编号(Unix: `ls docs/` / Windows: `dir docs\`),禁止使用上下文中缓存的编号,如遇冲突自动递增
|
||||
|
||||
**方案文档结构**(6 项):
|
||||
1. **背景与目标** - 问题描述、预期目标
|
||||
2. **需求分析** - 功能需求、非功能需求
|
||||
3. **方案设计** - 架构设计、模块划分、接口定义
|
||||
4. **实现计划** - 任务拆解、优先级、时间估算
|
||||
5. **风险评估** - 潜在风险、缓解措施
|
||||
6. **验收标准** - 可验证的完成条件
|
||||
|
||||
---
|
||||
|
||||
## 项目特定规则
|
||||
|
||||
### 项目结构
|
||||
@@ -143,4 +220,4 @@ agcore 是一个智能体(Agent)核心工具箱,提供:
|
||||
- diff 中不必要的改动更少
|
||||
- 因过度复杂而导致的重写更少
|
||||
- 澄清问题在实现之前提出
|
||||
- 干净、精简的 PR
|
||||
- 干净、精简的 PR
|
||||
|
||||
+19
@@ -4,3 +4,22 @@ version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
thiserror = "2"
|
||||
async-trait = "0.1"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
tokio-stream = "0.1"
|
||||
futures = "0.3"
|
||||
futures-util = "0.3"
|
||||
futures-core = "0.3"
|
||||
bytes = "1"
|
||||
async-stream = "0.3"
|
||||
tokio-util = { version = "0.7", features = ["rt"] }
|
||||
time = { version = "0.3", features = ["serde"] }
|
||||
|
||||
[dev-dependencies]
|
||||
dotenvy = "0.15.7"
|
||||
|
||||
@@ -0,0 +1,243 @@
|
||||
# 方案:重构 `types.rs` 为完整的 OpenAI 兼容 API 类型系统
|
||||
|
||||
## 1. 现状分析
|
||||
|
||||
### 当前问题
|
||||
|
||||
| 问题 | 详细 |
|
||||
|------|------|
|
||||
| **无 serde** | 所有类型只有 `Debug + Clone`,无 `Serialize/Deserialize`,迫使 `OpenaiProvider` 手动构建 JSON(354 行中约 200 行是序列化代码) |
|
||||
| **请求参数不全** | `ChatRequest` 只支持 `model, messages, system_prompt, tools, max_tokens, temperature, extra_body`,缺失 streaming、response_format、tool_choice、stop、reasoning_effort 等 30+ 参数 |
|
||||
| **响应类型太薄** | `ChatResponse` 只返回 `message + usage + stop_reason`,缺失 `id, created, model, choices` 数组、`logprobs`、`system_fingerprint` 等 |
|
||||
| **无流式支持** | 无 `ChatCompletionChunk` 类型,无法处理 SSE 流式响应 |
|
||||
| **反向依赖** | `types.rs` 引用 `cycle::usage::Usage`,造成模块间反向依赖 |
|
||||
| **手动解析易出错** | `parse_response()` 从 `Value` 中逐字段解析,逻辑脆弱,不支持复杂嵌套类型 |
|
||||
|
||||
### OpenAI API 参考文档覆盖范围
|
||||
|
||||
已完整阅读文档(2177 行),涵盖了完整的请求参数(35+ 个顶层参数)和响应结构。
|
||||
|
||||
## 2. 新类型系统设计
|
||||
|
||||
### 架构
|
||||
|
||||
将 `types.rs` 重构为模块目录,按功能领域拆分(注意:下方结构使用了 `types/mod.rs`,与 AGENTS.md 中“避免 `mod.rs`、统一使用 `foo.rs` 作为模块根”的约定不一致,落地时应以 `src/llm/types.rs` 作为模块根):
|
||||
|
||||
```
|
||||
src/llm/
|
||||
├── types/
|
||||
│ ├── mod.rs # 模块根:re-exports + 基础枚举/共用类型
|
||||
│ ├── request.rs # 请求参数(ChatCompletionRequest 等)
|
||||
│ ├── response.rs # 响应类型(ChatCompletionResponse + ChatCompletionChunk)
|
||||
│ ├── message.rs # 消息类型(6 种角色消息 + content parts)
|
||||
│ ├── tool.rs # 工具定义 + 工具调用
|
||||
│ ├── usage.rs # Token 用量(从 cycle/usage.rs 移入,消除反向依赖)
|
||||
│ └── shared.rs # 共用枚举(ReasoningEffort, ServiceTier, ResponseFormat 等)
|
||||
```
|
||||
|
||||
同时,将 `cycle/usage.rs` 中的 `Usage` 和 `CostTracker` **移到** `types/usage.rs`,`cycle/usage.rs` 保留 `pub use` 兼容 re-export。
|
||||
|
||||
### 核心决策
|
||||
|
||||
| 决策 | 选择 | 理由 |
|
||||
|------|------|------|
|
||||
| **序列化方式** | 全部类型 derive `Serialize, Deserialize` | 消除手动 JSON 构建,让 provider 直接 `.json(&req)` / `.json::<Res>()` |
|
||||
| **类型风格** | 直接映射 OpenAI API JSON 形状 | 一目了然,与 API 文档 1:1 对应,调试方便 |
|
||||
| **命名策略** | 添加 `Openai` 前缀(如 `OpenaiChatRequest`) | 明确标注为 OpenAI 兼容类型 |
|
||||
| **字段命名** | `#[serde(rename_all = "snake_case")]` | OpenAI API 使用 snake_case |
|
||||
| **可选字段** | `#[serde(skip_serializing_if = "Option::is_none")]` | 不序列化 None 字段,保持请求体干净 |
|
||||
| **默认值** | `#[serde(default)]` | 反序列化时缺失字段用默认值 |
|
||||
| **后向兼容** | 通过类型别名保持 `ChatRequest`/`ChatResponse` 等名称可用 | LlmProvider/LlmCycle 接口不变 |
|
||||
| **泛化策略** | Anthropic 是独立体系,暂不纳入当前设计 | 保持当前类型系统专注 OpenAI,Provider 层做转换 |
|
||||
|
||||
### 关键类型设计原则
|
||||
|
||||
- **`OpenaiChatRequest`**:统一结构体(不拆分 NonStreaming/Streaming),包含 `stream: Option<bool>` 字段,所有字段均为 `Option`,序列化时通过 `skip_serializing_if` 省略未设置的字段
|
||||
- **`OpenaiChatResponse`**:直接对应 `ChatCompletion`(完整响应),保留完整 choices 数组等所有字段
|
||||
- **`OpenaiChatChunk`**:对应流式 chunk,`object = "chat.completion.chunk"`
|
||||
- **消息系统**:用单个 `OpenaiChatMessage` enum 覆盖 6 种角色消息类型(Developer/System/User/Assistant/Tool/Function),每种内部使用对应 struct
|
||||
- **Content parts**:`OpenaiContentPart` enum 覆盖 text/image_url/input_audio/file/refusal
|
||||
|
||||
## 3. 完整类型清单
|
||||
|
||||
### `types/mod.rs` — 共用类型
|
||||
```
|
||||
Role → enum { Developer, System, User, Assistant, Tool, Function }
|
||||
FinishReason → enum { Stop, Length, ToolCalls, ContentFilter, FunctionCall }
|
||||
ServiceTier → enum { Auto, Default, Flex, Scale, Priority }
|
||||
Modality → enum { Text, Audio }
|
||||
ImageDetail → enum { Auto, Low, High }
|
||||
AudioFormat → enum { Wav, Mp3, Aac, Flac, Opus, Pcm16 }
|
||||
Voice → struct { id: String } 或预定义枚举
|
||||
SearchContextSize → enum { Low, Medium, High }
|
||||
StopSequence → enum { Single(String), Multiple(Vec<String>) }
|
||||
Verbosity → enum { Low, Medium, High }
|
||||
```
|
||||
|
||||
### `types/request.rs` — 请求参数
|
||||
```
|
||||
OpenaiChatRequest → struct (35+ 字段,所有 OpenAI 参数)
|
||||
ResponseFormat → enum { Text, JsonObject { .. }, JsonSchema { .. } }
|
||||
ToolChoice → enum { None, Auto, Required, Named { .. }, AllowedTools { .. } }
|
||||
StreamOptions → struct { include_usage, include_obfuscation }
|
||||
AudioParam → struct { format, voice }
|
||||
PredictionContent → struct { type, content }
|
||||
WebSearchOptions → struct { search_context_size, user_location }
|
||||
UserLocation → struct { type, approximate: Approximate }
|
||||
Approximate → struct { city, country, region, timezone }
|
||||
FunctionCallOption → struct { name } // deprecated
|
||||
FunctionDefinition → struct { name, description, parameters, strict }
|
||||
OpenaiTool → enum { Function { .. }, Custom { .. } }
|
||||
```
|
||||
|
||||
### `types/response.rs` — 响应类型
|
||||
```
|
||||
OpenaiChatResponse → struct { id, object, created, model, choices, usage, system_fingerprint, service_tier }
|
||||
Choice → struct { index, message, finish_reason, logprobs }
|
||||
OpenaiChatMessage → struct { content, refusal, role, tool_calls, function_call, audio, annotations }
|
||||
OpenaiChatChunk → struct { id, object, created, model, choices, usage, system_fingerprint, service_tier }
|
||||
ChunkChoice → struct { index, delta, logprobs, finish_reason }
|
||||
Delta → struct { role, content, tool_calls, function_call }
|
||||
Logprobs → struct { content, refusal }
|
||||
TokenLogprob → struct { token, bytes, logprob, top_logprobs }
|
||||
TopLogprob → struct { token, bytes, logprob }
|
||||
Annotation → struct { type, url_citation }
|
||||
URLCitation → struct { end_index, start_index, title, url }
|
||||
OpenaiAudio → struct { id, data, expires_at, transcript }
|
||||
FunctionCall → struct { name, arguments }
|
||||
OpenaiToolCall → enum { Function { id, function, type }, Custom { id, custom, type } }
|
||||
```
|
||||
|
||||
### `types/message.rs` — 消息类型
|
||||
```
|
||||
OpenaiChatMessage → enum (覆盖 6 种角色消息)
|
||||
DeveloperMessage → struct { content, role, name }
|
||||
SystemMessage → struct { content, role, name }
|
||||
UserMessage → struct { content, role, name }
|
||||
AssistantMessage → struct { content, refusal, role, name, tool_calls, function_call, audio }
|
||||
ToolMessage → struct { content, role, tool_call_id }
|
||||
FunctionMessage → struct { content, role, name }
|
||||
OpenaiContentPart → enum
|
||||
OpenaiContentPartText → struct { type, text }
|
||||
OpenaiContentPartImage → struct { type, image_url: ImageURL }
|
||||
OpenaiContentPartInputAudio → struct { type, input_audio: InputAudio }
|
||||
OpenaiContentPartFile → struct { type, file: FileData }
|
||||
OpenaiContentPartRefusal → struct { type, refusal }
|
||||
ImageURL → struct { url, detail }
|
||||
InputAudio → struct { data, format }
|
||||
FileData → struct { file_data, file_id, filename }
|
||||
```
|
||||
|
||||
### `types/tool.rs` — 工具类型
|
||||
```
|
||||
OpenaiToolDefinition → struct { name, description, parameters, strict }
|
||||
(保留 ToolDefinition 别名保持后向兼容,重定义为包含所有字段)
|
||||
OpenaiToolCall (在请求中使用) → 见 response.rs 中的定义
|
||||
```
|
||||
|
||||
### `types/usage.rs` — Token 用量
|
||||
```
|
||||
Usage → struct { prompt_tokens, completion_tokens, total_tokens,
|
||||
completion_tokens_details, prompt_tokens_details }
|
||||
CompletionTokensDetails → struct { reasoning_tokens, audio_tokens,
|
||||
accepted_prediction_tokens, rejected_prediction_tokens }
|
||||
PromptTokensDetails → struct { audio_tokens, cached_tokens }
|
||||
CostTracker → 从 cycle/usage.rs 移入(累计追踪器)
|
||||
```
|
||||
|
||||
### 删除的旧类型
|
||||
- `ContentBlock` → 被 `OpenaiContentPart` 替代(更准确的 OpenAI API 命名)
|
||||
- `StopReason` → 被 `FinishReason` 替代(与 API 命名一致)
|
||||
- `Message` → 被 `OpenaiChatMessage` 替代
|
||||
|
||||
### 类型别名(后向兼容)
|
||||
```
|
||||
ChatRequest = OpenaiChatRequest
|
||||
ChatResponse = OpenaiChatResponse
|
||||
Message = OpenaiChatMessage
|
||||
ContentBlock = OpenaiContentPart
|
||||
ToolDefinition = OpenaiToolDefinition
|
||||
Role = Role(保持不变,但扩展变体)
|
||||
StopReason = FinishReason
|
||||
```
|
||||
|
||||
## 4. 对其他模块的影响
|
||||
|
||||
### `provider/openai.rs`
|
||||
- **大幅简化**:`build_request_body()` → 直接 `serde_json::to_value(&request)`
|
||||
- `parse_response()` 中 100+ 行手动解析 → 直接 `serde_json::from_value::<OpenaiChatResponse>()`
|
||||
- `serialize_messages()`, `serialize_message()`, `serialize_content_block()`, `serialize_tool()` → **全部删除**
|
||||
- 新增 `chat_stream()` 方法返回 `OpenaiChatChunk` 流
|
||||
- 需要适配新类型的字段名变更(如 `Usage` 中 `input_tokens` → `prompt_tokens`)
|
||||
|
||||
### `provider.rs` (trait)
|
||||
- 接口保持不变,继续使用 `ChatRequest`/`ChatResponse` 类型别名
|
||||
- 调整 `Usage` 类型引用路径
|
||||
|
||||
### `cycle.rs`
|
||||
- `CycleConfig` 扩展支持更多请求参数(至少增加 `tools, tool_choice, response_format, stop, reasoning_effort, seed` 等)
|
||||
- `LlmCycle::submit()` 构建 `ChatRequest` 时使用新类型
|
||||
- `response.usage` 字段类型变更(新 `Usage` 含更多字段)
|
||||
- 此时不添加流式支持
|
||||
|
||||
### `cycle/usage.rs`
|
||||
- `Usage` 结构体**被移走**到 `types/usage.rs`
|
||||
- `cycle/usage.rs` 保留 `pub use crate::llm::types::usage::{Usage, CostTracker};` 作为兼容性 re-export
|
||||
- `CostTracker` 逻辑不变
|
||||
|
||||
### `error.rs`
|
||||
- 无明显变更,错误类型和映射逻辑不变
|
||||
|
||||
## 5. 实施步骤
|
||||
|
||||
### Phase 1: 基础设施
|
||||
```
|
||||
1. [准备] 在 Cargo.toml 中确认 serde 依赖(已有 serde = "1",features = ["derive"])
|
||||
2. [创建] 新建 src/llm/types/ 目录
|
||||
```
|
||||
|
||||
### Phase 2: 类型定义(按依赖顺序)
|
||||
```
|
||||
3. [usage.rs] 从 cycle/usage.rs 迁移 Usage + CostTracker
|
||||
4. [shared.rs] 定义 Role, FinishReason, ServiceTier, Modality, ImageDetail, StopSequence, ResponseFormat
|
||||
5. [message.rs] 定义 OpenaiChatMessage(6种角色)+ OpenaiContentPart + ImageURL + InputAudio
|
||||
6. [tool.rs] 定义 OpenaiToolDefinition + OpenaiToolCall + FunctionCall
|
||||
7. [request.rs] 定义 OpenaiChatRequest(35+ 字段)+ ToolChoice + StreamOptions
|
||||
8. [response.rs] 定义 OpenaiChatResponse + OpenaiChatChunk + Choice + Delta + Logprobs
|
||||
```
|
||||
|
||||
### Phase 3: 模块组装
|
||||
```
|
||||
9. [mod.rs] 创建模块根,re-export 所有类型 + 别名(ChatRequest = OpenaiChatRequest 等)
|
||||
10. [usage.rs] 更新 cycle/usage.rs 为 pub use re-export
|
||||
11. [删除] 删除旧 src/llm/types.rs
|
||||
```
|
||||
|
||||
### Phase 4: Provider 适配
|
||||
```
|
||||
12. [provider/openai.rs] 重写为 serde 序列化(删除 ~200 行手动代码)
|
||||
13. [cycle.rs] 适配新类型字段(prompt_tokens vs input_tokens)
|
||||
```
|
||||
|
||||
### Phase 5: 验证
|
||||
```
|
||||
14. [编译] cargo check 确保编译通过
|
||||
15. [检查] cargo clippy 确保无警告
|
||||
16. [测试] cargo test 确保测试通过
|
||||
```
|
||||
|
||||
## 6. 验证方式
|
||||
|
||||
- `cargo check` — 编译通过
|
||||
- `cargo clippy` — 无警告
|
||||
- `cargo test` — 所有测试通过(如果有集成测试,可能需要调整)
|
||||
- 检查 `OpenaiProvider` 代码量减少(预期从 354 行降至 ~150 行)
|
||||
- 手动验证序列化输出是否符合 OpenAI API 格式
|
||||
|
||||
## 7. 注意事项
|
||||
|
||||
1. **Breaking change**: 某些类型名称变化(如 `StopReason` → `FinishReason`),项目处于早期阶段,可接受
|
||||
2. **后向兼容**: 通过类型别名保持旧名称可用,接口层无需修改
|
||||
3. **Anthropic 处理**: Anthropic 是独立体系,不在当前设计中泛化,单独实现 Provider
|
||||
4. **异步流**: `chat_stream()` 的签名需要仔细设计(`Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>>>>` 或自定义类型)
|
||||
5. **CostTracker 不变**: 虽然 Usage 变复杂了,但 CostTracker 只累计 input/output token 数,逻辑不变
|
||||
@@ -0,0 +1,375 @@
|
||||
# Phase 0 剩余模块 — 实施方案
|
||||
|
||||
> 定稿日期:2026-06-02
|
||||
|
||||
## 背景与目标
|
||||
|
||||
AG Core Phase 0(Foundation)已完成核心数据类型、错误体系、Provider 接口、OpenAI 实现、生命周期引擎等基础设施。剩余 4 个子项尚未实现:**ProviderRegistry**、**HookExecutor**、**StreamEvents**、**Auto-compaction**。它们均被 Roadmap 标记为 P0(Must Have),是本阶段不可或缺的底层能力。
|
||||
|
||||
**目标**:完成这 4 个模块的设计与实现,使 Phase 0 全面交付。
|
||||
|
||||
---
|
||||
|
||||
## 需求分析
|
||||
|
||||
### 功能需求
|
||||
|
||||
| 模块 | 需求 | 验收条件 |
|
||||
|------|------|---------|
|
||||
| ProviderRegistry | 支持注册命名 Provider、按名称查找、设置默认 | 3 个公开方法 + 工厂辅助 |
|
||||
| HookExecutor | 4 个事件点:PreRequest / PostRequest / OnRetry / OnError | Hook trait + HookExecutor 触发 |
|
||||
| StreamEvents | 流式事件枚举 + Provider 流式方法 + Cycle 流式入口 | 6 种事件类型 + SSE 解析 |
|
||||
| Auto-compaction | Token 估算 + 微压缩 + 断路器 | 触发后释放 token 且不改变语义 |
|
||||
|
||||
### 非功能需求
|
||||
|
||||
- 所有公开 API 必须带 `///` 文档注释
|
||||
- 无新增 `unwrap()` 调用
|
||||
- 与现有 `LlmCycle` 集成时保持向后兼容(全为可选/增量添加)
|
||||
- 错误统一使用 `LlmError` 枚举
|
||||
|
||||
---
|
||||
|
||||
## 方案设计
|
||||
|
||||
### 1. ProviderRegistry (`src/llm/provider/registry.rs`)
|
||||
|
||||
**职责**:管理多个 LLM Provider 实例,支持按名称注册、发现、切换。
|
||||
|
||||
```rust
|
||||
// src/llm/provider/registry.rs
|
||||
|
||||
use std::collections::HashMap;
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::provider::{LlmProvider, ProviderConfig, ProviderType, create_provider};
|
||||
|
||||
/// Provider 注册表——管理多个 LLM Provider 实例。
|
||||
pub struct ProviderRegistry {
|
||||
providers: HashMap<String, Box<dyn LlmProvider>>,
|
||||
default_name: Option<String>,
|
||||
}
|
||||
|
||||
impl ProviderRegistry {
|
||||
pub fn new() -> Self;
|
||||
|
||||
/// 注册一个已初始化的 Provider 实例。
|
||||
pub fn register(&mut self, name: impl Into<String>, provider: Box<dyn LlmProvider>);
|
||||
|
||||
/// 通过 ProviderType + ProviderConfig 创建并注册。
|
||||
pub fn register_with_config(
|
||||
&mut self,
|
||||
name: impl Into<String>,
|
||||
provider_type: ProviderType,
|
||||
config: ProviderConfig,
|
||||
) -> Result<(), LlmError>;
|
||||
|
||||
/// 设置默认 Provider。
|
||||
pub fn set_default(&mut self, name: &str) -> Result<(), LlmError>;
|
||||
|
||||
/// 按名称查找 Provider。
|
||||
pub fn get(&self, name: &str) -> Option<&dyn LlmProvider>;
|
||||
|
||||
/// 获取默认 Provider。
|
||||
pub fn get_default(&self) -> Option<&dyn LlmProvider>;
|
||||
}
|
||||
```
|
||||
|
||||
**无新增依赖**。
|
||||
|
||||
---
|
||||
|
||||
### 2. HookExecutor (`src/llm/hooks.rs`)
|
||||
|
||||
**职责**:在 LLM 调用生命周期的关键节点插入自定义逻辑。
|
||||
|
||||
```rust
|
||||
// src/llm/hooks.rs
|
||||
|
||||
use async_trait::async_trait;
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::types::ChatRequest;
|
||||
|
||||
/// 生命周期钩子事件点。
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum HookEvent {
|
||||
/// LLM 请求发起之前(可阻断)。
|
||||
PreRequest,
|
||||
/// 成功响应之后。
|
||||
PostRequest,
|
||||
/// 重试之前(仅可重试错误时触发)。
|
||||
OnRetry,
|
||||
/// 不可恢复错误返回之前。
|
||||
OnError,
|
||||
}
|
||||
|
||||
/// 此次钩子调用的上下文。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HookContext<'a> {
|
||||
pub event: HookEvent,
|
||||
pub request: Option<&'a ChatRequest>,
|
||||
pub error: Option<&'a LlmError>,
|
||||
pub attempt: u32,
|
||||
}
|
||||
|
||||
/// 钩子执行结果。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HookResult {
|
||||
/// 是否阻断后续操作(仅 PreRequest 有效)。
|
||||
pub should_block: bool,
|
||||
/// 阻断/备注原因。
|
||||
pub reason: Option<String>,
|
||||
}
|
||||
|
||||
/// 生命周期钩子 trait。
|
||||
#[async_trait]
|
||||
pub trait Hook: Send + Sync {
|
||||
async fn execute(&self, ctx: &HookContext<'_>) -> HookResult;
|
||||
}
|
||||
|
||||
/// 钩子执行器——管理注册与触发。
|
||||
pub struct HookExecutor {
|
||||
hooks: Vec<(HookEvent, Box<dyn Hook>)>,
|
||||
}
|
||||
|
||||
impl HookExecutor {
|
||||
pub fn new() -> Self;
|
||||
pub fn register(&mut self, event: HookEvent, hook: Box<dyn Hook>);
|
||||
pub async fn execute(&self, event: HookEvent, ctx: &HookContext<'_>) -> Vec<HookResult>;
|
||||
}
|
||||
```
|
||||
|
||||
**与 LlmCycle 集成**:
|
||||
- `LlmCycle` 新增字段 `hook_executor: Option<HookExecutor>`
|
||||
- 新增 builder 方法 `with_hook_executor()`
|
||||
- `submit()` 中在 4 个点触发(PreRequest 若阻断则提前返回)
|
||||
|
||||
**无新增依赖**(async-trait 已存在)。
|
||||
|
||||
---
|
||||
|
||||
### 3. StreamEvents (`src/llm/stream.rs`)
|
||||
|
||||
**职责**:提供流式 LLM 调用的事件抽象,将原始 SSE chunk 解析为语义化事件。
|
||||
|
||||
```rust
|
||||
// src/llm/stream.rs
|
||||
|
||||
use std::pin::Pin;
|
||||
use tokio_stream::Stream;
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::types::{FinishReason, Usage, OpenaiChatChunk, ToolDefinition};
|
||||
use serde_json::Value;
|
||||
|
||||
/// 流式事件——LLM 调用全生命周期的语义化事件。
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum StreamEvent {
|
||||
/// 助手回复文本增量。
|
||||
AssistantTextDelta { text: String },
|
||||
/// 工具调用开始。
|
||||
ToolExecutionStarted { tool_name: String, input: Value },
|
||||
/// 工具调用完成。
|
||||
ToolExecutionCompleted { tool_name: String, output: Value, is_error: bool },
|
||||
/// Token 用量更新。
|
||||
CostUpdate { usage: Usage },
|
||||
/// 一轮会话完成。
|
||||
TurnComplete { reason: FinishReason },
|
||||
/// 可恢复的错误事件。
|
||||
Error { message: String },
|
||||
}
|
||||
|
||||
/// 将原始 OpenaiChatChunk 流解析为 StreamEvent 流。
|
||||
pub fn parse_chunk_stream(
|
||||
chunks: Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||
) -> Pin<Box<dyn Stream<Item = StreamEvent> + Send>>;
|
||||
```
|
||||
|
||||
**Provider 层扩展**(`src/llm/provider.rs`):
|
||||
```rust
|
||||
#[async_trait]
|
||||
pub trait LlmProvider: Send + Sync {
|
||||
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError>;
|
||||
|
||||
/// 流式聊天请求——返回原始 SSE chunk 流。
|
||||
/// 默认实现回退到非流式调用。
|
||||
async fn chat_stream(
|
||||
&self,
|
||||
request: ChatRequest,
|
||||
) -> Result<
|
||||
Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||
LlmError,
|
||||
> {
|
||||
let response = self.chat(request).await?;
|
||||
let chunk = OpenaiChatChunk::from(response);
|
||||
Ok(Box::pin(tokio_stream::once(Ok(chunk))))
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**LlmCycle 扩展**:
|
||||
```rust
|
||||
impl LlmCycle {
|
||||
/// 提交请求并返回语义事件流。
|
||||
pub async fn submit_stream(
|
||||
&mut self,
|
||||
prompt: String,
|
||||
tools: Vec<ToolDefinition>,
|
||||
) -> Result<
|
||||
Pin<Box<dyn Stream<Item = StreamEvent> + Send>>,
|
||||
LlmError,
|
||||
>;
|
||||
}
|
||||
```
|
||||
|
||||
**新增依赖**:
|
||||
```toml
|
||||
tokio-stream = "0.1"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 4. Auto-compaction (`src/llm/compact.rs`)
|
||||
|
||||
**职责**:在上下文过长时自动压缩历史消息,避免 ContextLength 错误。
|
||||
|
||||
```rust
|
||||
// src/llm/compact.rs
|
||||
|
||||
use crate::llm::types::{ContentField, Message, OpenaiChatMessage, OpenaiContentPart};
|
||||
|
||||
// === 常量 ===
|
||||
const AUTOCOMPACT_BUFFER_TOKENS: u32 = 13_000;
|
||||
const RESERVED_OUTPUT_TOKENS: u32 = 20_000;
|
||||
const MAX_CONSECUTIVE_FAILURES: u32 = 3;
|
||||
const KEEP_RECENT: usize = 6;
|
||||
|
||||
/// 上下文压缩配置。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CompactConfig {
|
||||
/// 模型上下文窗口大小(token 数)。
|
||||
pub context_window: u32,
|
||||
/// 为输出预留的 token 数。
|
||||
pub reserved_tokens: u32,
|
||||
/// 微压缩保留的最近消息数。
|
||||
pub keep_recent: usize,
|
||||
}
|
||||
|
||||
impl Default for CompactConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
context_window: 128_000,
|
||||
reserved_tokens: RESERVED_OUTPUT_TOKENS,
|
||||
keep_recent: KEEP_RECENT,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CompactConfig {
|
||||
pub fn threshold(&self) -> u32 {
|
||||
self.context_window
|
||||
.saturating_sub(self.reserved_tokens)
|
||||
.saturating_sub(AUTOCOMPACT_BUFFER_TOKENS)
|
||||
}
|
||||
}
|
||||
|
||||
/// 压缩状态——跟踪连续失败次数(断路器模式)。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CompactState {
|
||||
consecutive_failures: u32,
|
||||
}
|
||||
|
||||
impl CompactState {
|
||||
pub fn new() -> Self;
|
||||
pub fn record_success(&mut self);
|
||||
/// 记录失败,返回 true 表示已达断路器上限。
|
||||
pub fn record_failure(&mut self) -> bool;
|
||||
}
|
||||
|
||||
/// 粗略估计消息列表的 token 数(基于字符数,4 字符 ≈ 1 token)。
|
||||
pub fn estimate_message_tokens(messages: &[Message]) -> u32;
|
||||
|
||||
/// 判断是否需要触发自动压缩。
|
||||
pub fn should_compact(messages: &[Message], config: &CompactConfig, state: &CompactState) -> bool;
|
||||
|
||||
/// 执行微压缩——用占位符替换旧的 tool result 内容。
|
||||
/// 返回释放的 token 数。
|
||||
pub fn microcompact(messages: &mut Vec<Message>, keep_recent: usize) -> u32;
|
||||
```
|
||||
|
||||
**与 LlmCycle 集成**:
|
||||
- `LlmCycle` 新增字段 `compact_config: Option<CompactConfig>`, `compact_state: CompactState`
|
||||
- 新增 builder 方法 `with_compact_config()`
|
||||
- `submit()` 开始时调用 `should_compact()` → `microcompact()`
|
||||
|
||||
**完整 LLM 摘要压缩**留占位,Phase 2 实现(需要循环内调用 LLM 的能力)。
|
||||
|
||||
**无新增依赖**。
|
||||
|
||||
---
|
||||
|
||||
## 实现计划
|
||||
|
||||
### Step 1: 先写方案文档
|
||||
|
||||
创建 `docs/3-phase0-remaining.md`(即本文档)。
|
||||
|
||||
### Step 2: ProviderRegistry
|
||||
|
||||
- 创建 `src/llm/provider/registry.rs`
|
||||
- `provider.rs` 添加 `pub mod registry;`
|
||||
- `cargo check` 验证
|
||||
|
||||
### Step 3: HookExecutor
|
||||
|
||||
- 创建 `src/llm/hooks.rs`
|
||||
- `llm.rs` 添加 `pub mod hooks;`
|
||||
- `LlmCycle` 新增字段和方法
|
||||
- `submit()` 中插入钩子触发点
|
||||
- `cargo check` 验证
|
||||
|
||||
### Step 4: StreamEvents
|
||||
|
||||
- `Cargo.toml` 添加 `tokio-stream`
|
||||
- 创建 `src/llm/stream.rs`
|
||||
- `llm.rs` 添加 `pub mod stream;`
|
||||
- `LlmProvider` 添加 `chat_stream()`(默认回退)
|
||||
- `OpenaiProvider` 实现 SSE 解析
|
||||
- `LlmCycle` 添加 `submit_stream()`
|
||||
- `cargo check` 验证
|
||||
|
||||
### Step 5: Auto-compaction
|
||||
|
||||
- 创建 `src/llm/compact.rs`
|
||||
- `llm.rs` 添加 `pub mod compact;`
|
||||
- `LlmCycle` 新增字段和方法
|
||||
- `submit()` 开头插入压缩检查
|
||||
- `cargo check` 验证
|
||||
|
||||
### Step 6: 收尾
|
||||
|
||||
- `cargo clippy` — 无警告
|
||||
- `cargo build --release` — 完整构建
|
||||
- 检查所有新公开 API 有 `///` 注释
|
||||
|
||||
---
|
||||
|
||||
## 风险评估
|
||||
|
||||
| 风险 | 概率 | 影响 | 缓解措施 |
|
||||
|------|------|------|---------|
|
||||
| SSE 解析边界情况 | 中 | 中 | 参考 `reqwest` 的 chunked 响应处理;先用简单的 `lines()` 方式逐行读取 |
|
||||
| token 估算不准 | 中 | 低 | 仅用于触发微压缩的阈值判断,保守估算即可;后续可接入 tiktoken-rs |
|
||||
| 钩子阻断语义复杂 | 低 | 中 | PreRequest 阻断后返回明确错误消息;其他事件点只读不阻断 |
|
||||
| 与后续 Phase 冲突 | 低 | 高 | 保持接口向后兼容,全用可选集成(Option) |
|
||||
|
||||
---
|
||||
|
||||
## 验收标准
|
||||
|
||||
1. `cargo check` 编译通过
|
||||
2. `cargo clippy` 无警告
|
||||
3. 4 个模块文件存在且路径正确
|
||||
4. `ProviderRegistry` 支持注册/查找/默认 Provider
|
||||
5. `HookExecutor` 支持 4 个事件点注册与触发
|
||||
6. `StreamEvents` 支持从 SSE chunk 解析为语义事件
|
||||
7. `Auto-compaction` 支持 token 估算与微压缩
|
||||
8. `LlmCycle` 向后兼容(新增字段为 Option 或在内部初始化,如 `compact_state`,不影响现有代码)
|
||||
@@ -0,0 +1,611 @@
|
||||
# Phase 1: Prompt Engineering — 方案设计
|
||||
|
||||
> 定稿日期:2026-06-02
|
||||
|
||||
## 背景与目标
|
||||
|
||||
AG Core Phase 0(Foundation)已完成 LLM 调用周期的全部基础设施。Phase 1 的目标是补齐**提示词工程**能力,提供提示词的组合、模板化与优化能力,使其能直接服务于 Phase 2(工具系统)和 Phase 4(Agent 运行时)。
|
||||
|
||||
**目标**:实现 `PromptTemplate`(模板引擎)和 `PromptComposer`(提示词组合器)两个核心组件,覆盖变量插值、条件渲染、多消息序列拼接等场景。
|
||||
|
||||
---
|
||||
|
||||
## 需求分析
|
||||
|
||||
### 功能需求
|
||||
|
||||
| 模块 | 需求 | 验收条件 |
|
||||
|------|------|---------|
|
||||
| `PromptTemplate` | 支持变量插值 `{{ var }}` | 渲染后正确替换所有变量 |
|
||||
| `PromptTemplate` | 支持条件渲染 `{{#if var}}...{{/if}}` | 变量非空/非 false 时渲染块 |
|
||||
| `PromptTemplate` | 支持列表循环 `{{#each list}}...{{/each}}` | 遍历渲染集合元素 |
|
||||
| `PromptTemplate` | 支持嵌套模板引用 `{{> template_name }}` | 引用已注册的模板片段 |
|
||||
| `PromptTemplate` | 支持部分转义(原样输出 `{{ literal }}`) | 提供 raw block 语法 |
|
||||
| `PromptComposer` | 按角色拼接 system/user/assistant 消息序列 | 输出 `Vec<OpenaiChatMessage>` |
|
||||
| `PromptComposer` | 支持插入预编译的 PromptTemplate | 组合器中混合使用静态文本和模板 |
|
||||
| `PromptComposer` | 支持从已有消息列表扩展 | 接收 `Vec<OpenaiChatMessage>` 作为初始状态 |
|
||||
| `PromptComposer` | 支持多模态 ContentPart 构建(图片/音频/文件) | `user_content()`/`system_content()` 等接受 `OpenaiContentPart` |
|
||||
| `PromptComposer` | 支持 Developer 角色(o1 系列模型) | `developer()` / `developer_template()` / `developer_content()` |
|
||||
| `PromptComposer` | 支持 Tool 角色(工具执行结果回传) | `tool()` / `tool_content()` 接受 `tool_call_id` |
|
||||
| `PromptComposer` | 支持 `name` 字段设置(多角色区分) | `with_name()` 链式方法为上一消息设置名称 |
|
||||
| `PromptComposer` | 支持消息序列合法性验证 | `validate_messages()` 检查顺序约束与角色交替 |
|
||||
| `PromptTemplateRegistry` | 模板注册表:按名称注册、查找、文件加载 | 从字符串/文件注册模板,按名称渲染 |
|
||||
| `PromptTemplateRegistry` | 支持延迟编译模式 | `register_lazy()` 存储原始字符串,首次渲染时编译 |
|
||||
| `PromptTemplate` | 实现 `Display` trait | 输出原始模板字符串(用于日志和调试) |
|
||||
| `TemplateContext` | 支持从 JSON 构造 | `from_json()` 递归转换 `serde_json::Value` 为 `TemplateValue` |
|
||||
|
||||
### 非功能需求
|
||||
|
||||
- 所有公开 API 必须带 `///` 文档注释
|
||||
- 无新增 `unwrap()` 调用
|
||||
- **零运行时依赖**(不使用 tera、askama 等模板引擎 crate)
|
||||
- 模板引擎失败时返回结构化错误(`PromptError`)
|
||||
- 与现有 `OpenaiChatMessage` / `ChatRequest` 类型自然集成
|
||||
|
||||
---
|
||||
|
||||
## 方案设计
|
||||
|
||||
### 模块结构
|
||||
|
||||
```
|
||||
src/
|
||||
prompt.rs # prompt 模块根:声明子模块 + 重导出公共 API
|
||||
prompt/
|
||||
template.rs # PromptTemplate — 模板引擎
|
||||
template/
|
||||
registry.rs # PromptTemplateRegistry — 模板注册表
|
||||
composer.rs # PromptComposer — 提示词组合器
|
||||
error.rs # PromptError — 错误类型
|
||||
```
|
||||
|
||||
`prompt.rs` 根模块声明:
|
||||
|
||||
```rust
|
||||
// prompt.rs
|
||||
pub mod composer;
|
||||
pub mod error;
|
||||
pub mod template;
|
||||
|
||||
pub use composer::PromptComposer;
|
||||
pub use error::PromptError;
|
||||
pub use template::{PromptTemplate, PromptTemplateRegistry};
|
||||
```
|
||||
|
||||
`lib.rs` 添加:
|
||||
|
||||
```diff
|
||||
pub mod llm;
|
||||
+pub mod prompt;
|
||||
```
|
||||
|
||||
### 1. 模板引擎选择:自建轻量
|
||||
|
||||
**决策**:不使用 `tera` / `askama` / `maud` / `minijinja` 等第三方模板 crate。
|
||||
|
||||
**理由**:
|
||||
- Phase 1 模板需求极其简单(变量插值 + 条件 + 列表循环),不需要 Jinja2/Handlebars 全能力
|
||||
- 无第三方模板依赖 = 编译更快、不引入额外的攻击面、不会产生版本冲突
|
||||
- 自建 50-80 行核心逻辑即可覆盖所有需求
|
||||
- Roadmap 估算 400 行总代码,60 行模板引擎足够
|
||||
|
||||
**语法设计**(参考 Handlebars 子集):
|
||||
|
||||
```
|
||||
{{ variable_name }} → 变量插值
|
||||
{{#if var}}...{{/if}} → 条件渲染(var 存在且非空)
|
||||
{{#if var}}...{{else}}...{{/if}} → 条件 + 否则
|
||||
{{#each list}} {{item}} {{/each}} → 列表循环
|
||||
{{#raw}} {{literal}} {{/raw}} → 原始块(不解析内部模板语法)
|
||||
{{> template_name}} → 引用已注册的嵌套模板
|
||||
```
|
||||
|
||||
### 2. PromptTemplate — 模板引擎
|
||||
|
||||
```rust
|
||||
// prompt/template.rs
|
||||
|
||||
use std::collections::HashMap;
|
||||
use crate::prompt::error::PromptError;
|
||||
|
||||
/// 渲染上下文中使用的值类型。
|
||||
pub enum TemplateValue {
|
||||
String(String),
|
||||
Bool(bool),
|
||||
Array(Vec<TemplateValue>),
|
||||
Object(HashMap<String, TemplateValue>),
|
||||
}
|
||||
|
||||
/// `TemplateValue` 自动转换(提升 `ctx.insert("name", "Alice")` 的易用性)。
|
||||
impl From<String> for TemplateValue { ... }
|
||||
impl From<&str> for TemplateValue { ... }
|
||||
impl From<bool> for TemplateValue { ... }
|
||||
|
||||
/// 模板变量上下文。
|
||||
pub struct TemplateContext {
|
||||
vars: HashMap<String, TemplateValue>,
|
||||
}
|
||||
|
||||
impl TemplateContext {
|
||||
pub fn new() -> Self;
|
||||
|
||||
/// 插入变量(支持 `&str` / `String` / `bool` 自动转换)。
|
||||
pub fn insert(&mut self, key: impl Into<String>, value: impl Into<TemplateValue>);
|
||||
|
||||
pub fn get(&self, key: &str) -> Option<&TemplateValue>;
|
||||
|
||||
/// 从 `serde_json::Value` 递归构造(支持嵌套 Object/Array)。
|
||||
pub fn from_json(value: &serde_json::Value) -> Result<Self, PromptError>;
|
||||
|
||||
/// 从 `HashMap` 构造(适用于配置加载场景)。
|
||||
pub fn from_map(map: HashMap<String, TemplateValue>) -> Self;
|
||||
}
|
||||
|
||||
/// 预编译的模板。
|
||||
pub struct PromptTemplate {
|
||||
/// 原始模板字符串(用于 debug)。
|
||||
raw: String,
|
||||
/// 编译后的 AST 片段。
|
||||
fragments: Vec<Fragment>,
|
||||
}
|
||||
|
||||
/// 编译后的 AST 节点(内部枚举)。
|
||||
enum Fragment {
|
||||
Literal(String),
|
||||
Variable { name: String },
|
||||
If { condition: String, body: Vec<Fragment>, else_body: Vec<Fragment> },
|
||||
Each { variable: String, body: Vec<Fragment> },
|
||||
Raw(String),
|
||||
Include(String),
|
||||
}
|
||||
|
||||
impl PromptTemplate {
|
||||
/// 从模板字符串编译。
|
||||
pub fn compile(template: &str) -> Result<Self, PromptError>;
|
||||
|
||||
/// 使用上下文渲染。
|
||||
pub fn render(&self, ctx: &TemplateContext) -> Result<String, PromptError>;
|
||||
|
||||
/// 注册可引用的子模板。
|
||||
pub fn register_partial(&mut self, name: &str, template: PromptTemplate);
|
||||
}
|
||||
|
||||
/// `Display` 输出原始模板字符串,便于日志和调试。
|
||||
impl fmt::Display for PromptTemplate {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.raw)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**编译流程**:
|
||||
1. 逐字符扫描模板字符串
|
||||
2. 遇到 `{{` 时解析指令类型(variable/if/each/raw/include)
|
||||
3. 生成 `Fragment` AST 列表
|
||||
4. 返回 `PromptTemplate { raw, fragments }`
|
||||
|
||||
**渲染流程**:
|
||||
1. 遍历 `fragments`
|
||||
2. `Literal` → 直接追加
|
||||
3. `Variable { name }` → `ctx.get(name)` → 追加字符串值
|
||||
4. `If { condition, body, else_body }` → 判断 `ctx.get(condition)` 是否存在且真 → 递归渲染 body/else_body
|
||||
5. `Each { variable, body }` → `ctx.get(variable)` 转为数组 → 为每个元素设置 `item` 变量 → 递归渲染 body
|
||||
6. `Raw(text)` → 原样追加(不解析 `{{ }}`)
|
||||
7. `Include(name)` → 查找已注册的 partials → 递归渲染
|
||||
|
||||
### 3. PromptTemplateRegistry — 模板注册表
|
||||
|
||||
提供轻量的模板管理能力,按名称注册、查找、从文件加载,适用于管理多个 Agent 的提示词模板。
|
||||
|
||||
```rust
|
||||
// prompt/template/registry.rs
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use crate::prompt::error::PromptError;
|
||||
use crate::prompt::template::{PromptTemplate, TemplateContext};
|
||||
|
||||
/// 内部存储的模板(支持延迟编译)。
|
||||
enum StoredTemplate {
|
||||
Compiled(PromptTemplate),
|
||||
Raw(String),
|
||||
}
|
||||
|
||||
/// 模板注册表——管理多模板实例。
|
||||
pub struct PromptTemplateRegistry {
|
||||
templates: HashMap<String, StoredTemplate>,
|
||||
}
|
||||
|
||||
impl PromptTemplateRegistry {
|
||||
pub fn new() -> Self;
|
||||
|
||||
/// 从模板字符串编译并注册(立即编译)。
|
||||
pub fn register(&mut self, name: &str, template: &str) -> Result<(), PromptError>;
|
||||
|
||||
/// 延迟编译注册:只存储原始字符串,首次渲染时编译。
|
||||
/// 适合模板数量多但并非全部立即使用的场景。
|
||||
pub fn register_lazy(&mut self, name: &str, template: &str);
|
||||
|
||||
/// 从文件读取并编译注册。
|
||||
pub fn register_file(&mut self, name: &str, path: &Path) -> Result<(), PromptError>;
|
||||
|
||||
/// 获取已注册的模板(延迟编译的模板在此首次编译)。
|
||||
pub fn get(&mut self, name: &str) -> Result<&PromptTemplate, PromptError>;
|
||||
|
||||
/// 按名称渲染(`{{> name }}` 引用时自动查找)。
|
||||
pub fn render(&mut self, name: &str, ctx: &TemplateContext) -> Result<String, PromptError>;
|
||||
}
|
||||
```
|
||||
|
||||
**设计约束**:
|
||||
- 不设全局单例,用户自行创建和持有
|
||||
- 不引入文件系统监听、热加载、序列化等复杂能力
|
||||
- 注册表内部模板可用于解析 `{{> partial_name }}` 子模板引用
|
||||
- 用户仍可单独持有 `PromptTemplate` 实例,不强制使用注册表
|
||||
|
||||
### 4. PromptComposer — 提示词组合器
|
||||
|
||||
```rust
|
||||
// prompt/composer.rs
|
||||
|
||||
use crate::llm::types::message::{OpenaiChatMessage, OpenaiContentPart, ContentField};
|
||||
use crate::prompt::error::PromptError;
|
||||
use crate::prompt::template::{PromptTemplate, TemplateContext};
|
||||
|
||||
/// 提示词组合器——构建多角色消息序列。
|
||||
pub struct PromptComposer {
|
||||
messages: Vec<OpenaiChatMessage>,
|
||||
pending_name: Option<String>,
|
||||
}
|
||||
|
||||
impl PromptComposer {
|
||||
/// 创建一个空的组合器。
|
||||
pub fn new() -> Self;
|
||||
|
||||
/// 从已有的消息列表初始化。
|
||||
pub fn from_messages(messages: Vec<OpenaiChatMessage>) -> Self;
|
||||
|
||||
// ===== 纯文本消息 =====
|
||||
|
||||
/// 添加一条纯文本 system 消息。
|
||||
pub fn system(mut self, text: impl Into<String>) -> Self;
|
||||
|
||||
/// 添加一条纯文本 user 消息。
|
||||
pub fn user(mut self, text: impl Into<String>) -> Self;
|
||||
|
||||
/// 添加一条纯文本 assistant 消息。
|
||||
pub fn assistant(mut self, text: impl Into<String>) -> Self;
|
||||
|
||||
/// 添加一条纯文本 developer 消息(o1 系列模型使用)。
|
||||
pub fn developer(mut self, text: impl Into<String>) -> Self;
|
||||
|
||||
/// 添加一条 Tool 消息(工具执行结果回传)。
|
||||
pub fn tool(mut self, tool_call_id: impl Into<String>, content: impl Into<String>) -> Self;
|
||||
|
||||
// ===== 模板消息 =====
|
||||
|
||||
/// 使用模板和上下文渲染后添加为 user 消息。
|
||||
pub fn user_template(
|
||||
mut self,
|
||||
template: &PromptTemplate,
|
||||
ctx: &TemplateContext,
|
||||
) -> Result<Self, PromptError>;
|
||||
|
||||
/// 使用模板和上下文渲染后添加为 system 消息。
|
||||
pub fn system_template(
|
||||
mut self,
|
||||
template: &PromptTemplate,
|
||||
ctx: &TemplateContext,
|
||||
) -> Result<Self, PromptError>;
|
||||
|
||||
/// 使用模板和上下文渲染后添加为 assistant 消息。
|
||||
pub fn assistant_template(
|
||||
mut self,
|
||||
template: &PromptTemplate,
|
||||
ctx: &TemplateContext,
|
||||
) -> Result<Self, PromptError>;
|
||||
|
||||
/// 使用模板和上下文渲染后添加为 developer 消息。
|
||||
pub fn developer_template(
|
||||
mut self,
|
||||
template: &PromptTemplate,
|
||||
ctx: &TemplateContext,
|
||||
) -> Result<Self, PromptError>;
|
||||
|
||||
// ===== 多模态 ContentPart =====
|
||||
|
||||
/// 添加一条含指定 ContentPart 的 system 消息。
|
||||
pub fn system_content(mut self, part: OpenaiContentPart) -> Self;
|
||||
|
||||
/// 添加一条含指定 ContentPart 的 user 消息。
|
||||
pub fn user_content(mut self, part: OpenaiContentPart) -> Self;
|
||||
|
||||
/// 添加一条含指定 ContentPart 的 assistant 消息。
|
||||
pub fn assistant_content(mut self, part: OpenaiContentPart) -> Self;
|
||||
|
||||
/// 添加一条含指定 ContentPart 的 developer 消息。
|
||||
pub fn developer_content(mut self, part: OpenaiContentPart) -> Self;
|
||||
|
||||
/// 添加一条含指定 ContentPart 的 Tool 消息。
|
||||
pub fn tool_content(mut self, tool_call_id: impl Into<String>, part: OpenaiContentPart) -> Self;
|
||||
|
||||
/// 批量添加 ContentPart 作为 user 消息。
|
||||
pub fn user_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self;
|
||||
|
||||
/// 批量添加 ContentPart 作为 system 消息。
|
||||
pub fn system_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self;
|
||||
|
||||
/// 批量添加 ContentPart 作为 assistant 消息。
|
||||
pub fn assistant_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self;
|
||||
|
||||
/// 批量添加 ContentPart 作为 developer 消息。
|
||||
pub fn developer_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self;
|
||||
|
||||
// ===== 角色标识 =====
|
||||
|
||||
/// 为上一条添加的消息设置 `name` 字段(多 agent 系统中区分同角色实体)。
|
||||
pub fn with_name(mut self, name: impl Into<String>) -> Self;
|
||||
|
||||
// ===== 构建 =====
|
||||
|
||||
/// 构建最终的消息列表。
|
||||
pub fn build(self) -> Vec<OpenaiChatMessage>;
|
||||
|
||||
/// 构建并直接创建 ChatRequest(需搭配 model 参数)。
|
||||
/// 返回的 `OpenaiChatRequest` 中 `tools`、`temperature`、`max_tokens` 等字段均为 `None`,
|
||||
/// 可通过结构体更新语法补全:`OpenaiChatRequest { tools: Some(...), ..req }`。
|
||||
pub fn build_request(
|
||||
self,
|
||||
model: impl Into<String>,
|
||||
) -> crate::llm::types::request::OpenaiChatRequest;
|
||||
}
|
||||
|
||||
/// 验证消息序列是否符合 OpenAI API 要求。
|
||||
/// 检查项:Tool 消息必须紧跟在匹配的 Assistant 消息后、角色交替规则等。
|
||||
pub fn validate_messages(messages: &[OpenaiChatMessage]) -> Result<(), PromptError>;
|
||||
```
|
||||
|
||||
**Builder 模式设计**:
|
||||
- `PromptComposer` 采用链式调用(builder pattern),与 Rust 生态的主流风格一致
|
||||
- 每个 `system()` / `user()` / `assistant()` / `developer()` / `tool()` 方法返回 `Self`,支持连续调用
|
||||
- `with_name()` 作用于上一条消息,内部通过 `pending_name: Option<String>` 暂存,push 消息时消费
|
||||
- `build()` 返回 `Vec<OpenaiChatMessage>`,`build_request()` 创建完整的 `OpenaiChatRequest`
|
||||
- **`ContentField` 类型约定**:所有纯文本消息(`system()` / `user()` / `assistant()` / `developer()`)统一使用 `ContentField::Array(vec![OpenaiContentPart::Text{...}])`,与 OpenAI Chat Completions 请求中 `content` 的数组(content parts)形式一致,便于和多模态消息走同一条处理路径
|
||||
|
||||
**`validate_messages()` 校验规则**:
|
||||
1. `Tool` 角色的消息必须跟在 `Assistant` 角色且含 `tool_calls` 的消息之后
|
||||
2. 禁止连续出现两条同角色的非 Tool 消息(system 除外)
|
||||
3. 消息列表不能为空
|
||||
|
||||
### 5. PromptError — 错误类型
|
||||
|
||||
```rust
|
||||
// prompt/error.rs
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum PromptError {
|
||||
#[error("模板解析错误: {0}")]
|
||||
Parse(String),
|
||||
|
||||
#[error("渲染错误: 变量 '{0}' 未找到")]
|
||||
VariableNotFound(String),
|
||||
|
||||
#[error("渲染错误: 引用的子模板 '{0}' 未注册")]
|
||||
PartialNotFound(String),
|
||||
|
||||
#[error("渲染错误: '{0}' 不是数组,无法遍历")]
|
||||
NotAnArray(String),
|
||||
|
||||
#[error("渲染递归超过最大深度限制 ({0})")]
|
||||
MaxDepthReached(u8),
|
||||
|
||||
#[error("渲染错误: {0}")]
|
||||
Render(String),
|
||||
|
||||
#[error("消息序列校验失败: {0}")]
|
||||
InvalidSequence(String),
|
||||
|
||||
#[error("文件读取错误: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
}
|
||||
```
|
||||
|
||||
### 6. 使用示例
|
||||
|
||||
```rust
|
||||
use agcore::prompt::{PromptComposer, PromptTemplate, TemplateContext};
|
||||
|
||||
// 编译模板
|
||||
let tpl = PromptTemplate::compile(
|
||||
"你是{{role}}。请回答以下问题:\n{{#if context}}参考背景:{{context}}\n{{/if}}提问:{{question}}"
|
||||
)?;
|
||||
|
||||
// 构建上下文
|
||||
let mut ctx = TemplateContext::new();
|
||||
ctx.insert("role", "资深工程师");
|
||||
ctx.insert("question", "Rust 的所有权规则是什么?");
|
||||
ctx.insert("context", "用户有 Java 背景");
|
||||
|
||||
// 使用组合器构建消息序列
|
||||
let messages = PromptComposer::new()
|
||||
.system("你是一个专业的 Rust 助手")
|
||||
.user_template(&tpl, &ctx)?
|
||||
.build();
|
||||
|
||||
// 可选:直接构建 ChatRequest
|
||||
let request = PromptComposer::new()
|
||||
.system("你是一个翻译助手")
|
||||
.user("Hello, world!")
|
||||
.build_request("gpt-4o");
|
||||
```
|
||||
|
||||
### 7. 与 LlmCycle 的集成
|
||||
|
||||
`PromptComposer::build()` 输出 `Vec<OpenaiChatMessage>`,但 **`LlmCycle.messages` 是私有字段**,无法直接赋值。因此**需要对 `LlmCycle` 进行扩展**,提供消息注入入口,使 Composer 能和 `LlmCycle` 的多轮对话循环协同工作。
|
||||
|
||||
**方案**:在 `LlmCycle` 上新增 3 个方法:
|
||||
|
||||
```rust
|
||||
// llm/cycle.rs
|
||||
|
||||
impl LlmCycle {
|
||||
/// 直接设置消息历史(覆盖已有消息),支持 Builder 链式调用。
|
||||
pub fn with_messages(mut self, messages: Vec<Message>) -> Self;
|
||||
|
||||
/// 追加消息到历史尾部。
|
||||
pub fn extend_messages(&mut self, messages: Vec<Message>);
|
||||
|
||||
/// 使用预构建消息提交(跳过自动 push user prompt)。
|
||||
/// 与 submit() 不同,不自动添加 user_text(prompt)。
|
||||
pub async fn submit_messages(
|
||||
&mut self,
|
||||
messages: Vec<Message>,
|
||||
tools: Vec<ToolDefinition>,
|
||||
) -> Result<ChatResponse, LlmError>;
|
||||
}
|
||||
```
|
||||
|
||||
**`submit_messages()` 与 `submit()` 的区别**:
|
||||
|
||||
| 维度 | `submit()` | `submit_messages()` |
|
||||
|------|-----------|-------------------|
|
||||
| 输入 | `prompt: String` + `tools` | `messages: Vec<Message>` + `tools` |
|
||||
| 内部操作 | 自动 push `user_text(prompt)` | 不自动添加任何消息 |
|
||||
| 适用场景 | 简单单轮对话 | 多轮/预构建消息序列 |
|
||||
| system_prompt 处理 | 自动插入(如果无 System 消息) | 完全由调用方控制 |
|
||||
|
||||
**`system_prompt` 冲突处理**:`submit_messages()` 关闭 `LlmCycle` 的自动 system prompt 插入逻辑,避免与 `PromptComposer` 已构建的 System/Developer 消息重复。由调用方全权控制消息序列内容。
|
||||
|
||||
**使用示例**:
|
||||
```rust
|
||||
let messages = PromptComposer::new()
|
||||
.system("你是一个专业的 Rust 助手")
|
||||
.user_template(&query_tpl, &ctx)?
|
||||
.build();
|
||||
|
||||
let mut cycle = LlmCycle::new(provider, config)
|
||||
.with_messages(messages);
|
||||
|
||||
let resp = cycle.submit_messages(vec![], tools).await?;
|
||||
```
|
||||
|
||||
`PromptComposer::build_request()` 直接创建 `OpenaiChatRequest`,可用于绕过 `LlmCycle` 直接调用 `LlmProvider` 的场景。
|
||||
|
||||
**注意**:`PromptComposer` 模块 **不** 直接依赖 `LlmCycle`(避免 `prompt → cycle` 的强耦合)。集成方法全部在 `LlmCycle` 侧实现,保持单一职责。
|
||||
|
||||
---
|
||||
|
||||
## 实现计划
|
||||
|
||||
### Step 1: 创建方案文档
|
||||
|
||||
创建 `docs/4-prompt-engineering.md`(即本文档)。
|
||||
|
||||
### Step 2: PromptError
|
||||
|
||||
- 创建 `src/prompt/error.rs`
|
||||
- 定义 `PromptError` 枚举(Parse / VariableNotFound / PartialNotFound / NotAnArray / MaxDepthReached / Render / InvalidSequence / Io)
|
||||
|
||||
### Step 3: PromptTemplate
|
||||
|
||||
- 创建 `src/prompt.rs`(模块根)
|
||||
- 创建 `src/prompt/template.rs`
|
||||
- 实现编译(`compile()`):逐字符扫描 → 生成 `Vec<Fragment>`
|
||||
- 注意处理:嵌套 `#if` 栈匹配、`{{` 字面量转义、未闭合标签检测
|
||||
- 实现渲染(`render()`):递归遍历 Fragment → 拼接字符串
|
||||
- 定义非字符串值渲染格式:`Bool`→`"true"`/`"false"`、`Array`→JSON、`Object`→JSON
|
||||
- 递归加深度限制(16 层)防止循环引用
|
||||
- 支持功能:变量插值、条件渲染、列表循环、原始块、子模板引用
|
||||
- 实现 `Display for PromptTemplate`(输出原始模板字符串)
|
||||
- 编写 20+ 边界测试覆盖:嵌套 if/each、未闭合标签、空变量、空数组 each、递归深度超限
|
||||
- 运行 `cargo test + cargo check` 验证
|
||||
|
||||
### Step 4: PromptTemplateRegistry
|
||||
|
||||
- 推荐选项:`template` 保持单文件,`PromptTemplateRegistry` 合并同文件(~40 行不值得单独目录)
|
||||
- 内部存储使用 `StoredTemplate` 枚举(支持 `Compiled` 和 `Raw` 两种状态)
|
||||
- 实现:`register()` 立即编译、`register_lazy()` 延迟编译、`register_file()` 文件加载
|
||||
- 实现 `get()` / `render()`(延迟编译的模板首次渲染时编译)
|
||||
- 运行 `cargo check` 验证
|
||||
|
||||
### Step 5: PromptComposer
|
||||
|
||||
- 创建 `src/prompt/composer.rs`
|
||||
- 实现 Builder 链式 API,内部维护 `messages: Vec<OpenaiChatMessage>` + `pending_name: Option<String>`
|
||||
- 角色方法:`system()` / `user()` / `assistant()` / `developer()` / `tool()`
|
||||
- 纯文本消息统一使用 `ContentField::Array([Text])`
|
||||
- `tool()` 需传入 `tool_call_id` 和 `content`
|
||||
- 模板方法:`system_template()` / `user_template()` / `assistant_template()` / `developer_template()`
|
||||
- 多模态方法:`*_content()`(单个 ContentPart)和 `*_contents()`(批量)
|
||||
- `with_name()`:作用于上一条消息的 `name` 字段
|
||||
- `build()` / `build_request()`
|
||||
- `validate_messages()`:独立的纯函数,校验 Tool→Assistant 顺序、角色交替、非空
|
||||
- 运行 `cargo check` 验证
|
||||
|
||||
### Step 6: LlmCycle 扩展
|
||||
|
||||
- `cycle.rs` 新增 3 个方法:
|
||||
- `with_messages(self, messages: Vec<Message>) -> Self` — 链式设置消息历史
|
||||
- `extend_messages(&mut self, messages: Vec<Message>)` — 追加消息
|
||||
- `submit_messages(&mut self, messages: Vec<Message>, tools: Vec<ToolDefinition>) -> Result<...>` — 预构建消息提交
|
||||
- `submit_messages()` 关闭自动 system_prompt 插入(避免与 Composer 的 System/Developer 消息重复)
|
||||
- 运行 `cargo check` 验证
|
||||
|
||||
### Step 7: lib.rs 注册
|
||||
|
||||
- `lib.rs` 添加 `pub mod prompt;`
|
||||
- 运行 `cargo check` 验证
|
||||
|
||||
### Step 8: 收尾
|
||||
|
||||
- `cargo clippy` — 无警告
|
||||
- `cargo build` — 完整构建
|
||||
- 检查所有新公开 API 有 `///` 文档注释
|
||||
|
||||
---
|
||||
|
||||
## 术语表
|
||||
|
||||
| 术语 | 说明 |
|
||||
|------|------|
|
||||
| `TemplateValue` | 模板渲染上下文中使用的值类型枚举 |
|
||||
| `TemplateContext` | 模板变量上下文,持有所有变量 |
|
||||
| `PromptTemplate` | 预编译的模板,持有 AST 片段列表 |
|
||||
| `Fragment` | 编译后的 AST 节点(内部枚举) |
|
||||
| `PromptComposer` | 提示词组合器,构建多角色消息序列 |
|
||||
| `PromptError` | 提示词工程专属错误类型 |
|
||||
|
||||
---
|
||||
|
||||
## 风险评估
|
||||
|
||||
| 风险 | 概率 | 缓解措施 |
|
||||
|------|------|---------|
|
||||
| 模板语法不支持复杂场景(如嵌套 each) | 低 | 当前需求不涉及;后续可引入 tera 替换 |
|
||||
| 自定义模板引擎解析有 bug | 中 | 编写单元测试覆盖所有语法分支 |
|
||||
| 与未来 Prompt Optimizer 冲突 | 低 | PromptOptimizer 只修改模板/上下文,不改模板引擎接口 |
|
||||
| 条件语义不明确(什么是"假"值) | 低 | 明确定义:None / false / 空字符串 / 空数组 均为假 |
|
||||
| 子模板循环引用导致栈溢出 | 低 | 渲染时加递归深度限制(16 层)或已注册集合去重检测 |
|
||||
|
||||
---
|
||||
|
||||
## 验收标准
|
||||
|
||||
1. `cargo check` 编译通过
|
||||
2. `cargo clippy` 无警告
|
||||
3. 模块文件路径正确:`src/prompt.rs` + `src/prompt/{template,composer,error}.rs`
|
||||
4. `PromptTemplate::compile()` 能解析含变量/条件/循环的模板
|
||||
5. `PromptTemplate::render()` 正确渲染所有语法
|
||||
6. `PromptTemplate` 实现 `Display` trait,输出原始模板字符串
|
||||
7. `TemplateContext` 提供 `from_json()` / `from_map()` 构造方式,支持 `From<&str>` 自动转换
|
||||
8. `PromptTemplateRegistry` 支持立即编译(`register()`)、延迟编译(`register_lazy()`)、文件加载(`register_file()`)
|
||||
9. `PromptComposer` 支持链式调用,覆盖 System / User / Assistant / Developer / Tool 五种角色
|
||||
10. `PromptComposer` 支持 `user_content()` / `system_content()` / `assistant_content()` / `developer_content()` / `tool_content()` 多模态方法
|
||||
11. `PromptComposer` 支持 `with_name()` 设置消息角色标识
|
||||
12. `PromptComposer::build_request()` 能创建 `OpenaiChatRequest`
|
||||
13. `validate_messages()` 能校验消息序列合法性
|
||||
14. `LlmCycle` 新增 `with_messages()` / `extend_messages()` / `submit_messages()` 支持 Composer 集成
|
||||
15. `lib.rs` 包含 `pub mod prompt;`
|
||||
16. 所有新公开 API 有文档注释
|
||||
@@ -0,0 +1,996 @@
|
||||
# Phase 2: Tool System — 方案设计
|
||||
|
||||
> 定稿日期:2026-06-03
|
||||
|
||||
## 背景与目标
|
||||
|
||||
AG Core Phase 0(Foundation)已完成 LLM 调用周期基础设施,Phase 1(Prompt Engineering)已完成提示词组合与模板化。Phase 2 的目标是补齐**工具系统**能力,实现 LLM 驱动的工具定义、注册、调用、权限控制,以及 MCP 协议集成。
|
||||
|
||||
**核心目标**:让 LLM 能通过 `FinishReason::ToolCalls` 触发工具自动执行,并将结果回传至对话上下文,形成完整的"思考 → 调用 → 反馈"闭环。
|
||||
|
||||
---
|
||||
|
||||
## 需求分析
|
||||
|
||||
### 功能需求
|
||||
|
||||
| 模块 | 需求 | 验收条件 |
|
||||
|------|------|---------|
|
||||
| `BaseTool` trait | 工具抽象接口:名称、描述、参数、执行、权限声明 | 实现 trait 后可注册到 Registry |
|
||||
| `ToolRegistry` | 工具注册、发现、按名称调用 | 注册 3 个工具后能按名称查找到并执行 |
|
||||
| `McpClient` | MCP 协议客户端(stdio transport) | 能启动 MCP 服务器子进程、列出工具、调用工具 |
|
||||
| `PermissionChecker` | 工具执行前权限校验 | 禁止无权限的工具执行,返回结构化错误 |
|
||||
| 自动 Tool 循环 | LlmCycle 收到 ToolCalls 后自动执行工具并回传 | 一个包含工具调用的对话能完整执行 2+ 轮 |
|
||||
| 流式 Tool 事件 | 流式模式下发射 `ToolExecutionCompleted` 事件 | 流式调用中工具执行完成后触发对应事件 |
|
||||
| 工具调用历史持久化 | 自动工具循环产生的 Tool/Assistant 消息正确追加到 `messages` | 查看 `cycle.messages()` 能获取完整工具交互轨迹 |
|
||||
|
||||
### 非功能需求
|
||||
|
||||
- 所有公开 API 必须带 `///` 文档注释
|
||||
- 无新增 `unwrap()` 调用
|
||||
- `BaseTool` 的 `execute()` 必须为 `async`
|
||||
- 工具执行错误必须结构化为 `ToolError`,不允许 panic
|
||||
- MCP 客户端超时默认 30 秒,可配置
|
||||
- 自定义工具与 MCP 工具通过同一 `ToolRegistry` 管理,对 LlmCycle 透明
|
||||
- 权限检查在工具执行之前,阻断后返回错误而非静默跳过
|
||||
- `BaseTool::execute()` 签名必须预留扩展点(`ToolContext` 注入),确保未来 Skill/Agent 层可在不修改 trait 签名的情况下注入 session_id、cancellation_token 等上下文信息
|
||||
- 自动 tool 循环应考虑 token 消耗——工具定义随每轮请求重复发送,工具结果直接追加到对话历史,需提供结果大小限制和截断策略
|
||||
|
||||
---
|
||||
|
||||
## 方案设计
|
||||
|
||||
### 模块结构
|
||||
|
||||
```
|
||||
src/
|
||||
tools.rs # tools 模块根:声明子模块 + 重导出公共 API
|
||||
tools/
|
||||
base.rs # BaseTool trait — 工具抽象接口
|
||||
registry.rs # ToolRegistry — 工具注册表
|
||||
permission.rs # PermissionChecker — 权限校验器
|
||||
mcp.rs # McpClient — MCP 协议客户端
|
||||
error.rs # ToolError — 工具系统错误类型
|
||||
```
|
||||
|
||||
`tools.rs` 根模块声明:
|
||||
|
||||
```rust
|
||||
// tools.rs
|
||||
pub mod base;
|
||||
pub mod error;
|
||||
pub mod mcp;
|
||||
pub mod permission;
|
||||
pub mod registry;
|
||||
|
||||
pub use base::{BaseTool, ToolContext};
|
||||
pub use error::ToolError;
|
||||
pub use mcp::McpClient;
|
||||
pub use permission::{Permission, PermissionChecker, PermissionConfig};
|
||||
pub use registry::{ToolEntry, ToolInvocation, ToolRegistry};
|
||||
```
|
||||
|
||||
`lib.rs` 添加:
|
||||
|
||||
```diff
|
||||
pub mod llm;
|
||||
pub mod prompt;
|
||||
+pub mod tools;
|
||||
```
|
||||
|
||||
### 1. BaseTool trait — 工具抽象接口
|
||||
|
||||
```rust
|
||||
// tools/base.rs
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::tools::error::ToolError;
|
||||
use crate::tools::permission::Permission;
|
||||
|
||||
/// 工具执行上下文 —— 携带每次执行的运行时信息。
|
||||
/// 新增字段时提供默认值,不破坏已有工具实现。
|
||||
pub struct ToolContext<'a> {
|
||||
/// 当前对话/会话 ID,用于关联性追踪。
|
||||
pub session_id: &'a str,
|
||||
/// 链路追踪 ID,用于跨工具调用的耗时分布。
|
||||
pub trace_id: &'a str,
|
||||
/// 取消令牌,用于优雅取消正在执行的工具。
|
||||
pub cancellation_token: CancellationToken,
|
||||
}
|
||||
|
||||
/// 工具抽象接口 —— 所有工具(自定义或 MCP)最终都实现此 trait。
|
||||
#[async_trait]
|
||||
pub trait BaseTool: Send + Sync {
|
||||
/// 工具名称(唯一标识,用于 LLM 的 tool_calls.name 匹配)。
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// 工具描述(LLM 据此决定是否调用此工具)。
|
||||
fn description(&self) -> &str;
|
||||
|
||||
/// 工具参数定义(JSON Schema 格式,传递给 LLM 的 tool.parameters)。
|
||||
fn parameters(&self) -> Value;
|
||||
|
||||
/// 声明工具所需的权限列表。
|
||||
fn required_permissions(&self) -> Vec<Permission> {
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
/// 执行工具调用。
|
||||
/// `ctx` 携带执行上下文(session_id、trace_id 等),Phase 3/4 可扩展字段而不破坏 trait 签名。
|
||||
async fn execute(&self, args: Value, ctx: &ToolContext<'_>) -> Result<Value, ToolError>;
|
||||
}
|
||||
```
|
||||
|
||||
**设计说明**:
|
||||
- `name()` 返回 `&str` 而非 `String`,避免每次调用克隆
|
||||
- `parameters()` 返回 `serde_json::Value`,与现有 `OpenaiToolDefinition.parameters` 类型一致
|
||||
- `required_permissions()` 提供默认空实现,简化无敏感操作的工具定义
|
||||
- `execute()` 接收 `Value`(JSON 对象)+ `ToolContext` 作为参数,返回 `Value` 作为结果,与 OpenAI API 的 arguments/output 格式一致
|
||||
- `ToolContext` 从 Phase 2 即注入 `execute()` 签名,防止后续 breaking change;新增字段用 `Option` 包裹或提供默认值
|
||||
|
||||
### 2. ToolRegistry — 工具注册表
|
||||
|
||||
```rust
|
||||
// tools/registry.rs
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::llm::types::{OpenaiToolDefinition, ToolDefinition};
|
||||
use crate::tools::base::BaseTool;
|
||||
use crate::tools::error::ToolError;
|
||||
use crate::tools::permission::{Permission, PermissionChecker};
|
||||
|
||||
/// 工具调用记录 —— 用于追踪和调试。
|
||||
pub struct ToolInvocation {
|
||||
pub tool_name: String,
|
||||
pub input: Value,
|
||||
pub output: Result<Value, ToolError>,
|
||||
}
|
||||
|
||||
/// 工具注册表 —— 管理工具注册、发现、调用。
|
||||
pub struct ToolRegistry {
|
||||
tools: HashMap<String, Arc<dyn BaseTool>>,
|
||||
permission_checker: Option<Arc<PermissionChecker>>,
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
pub fn new() -> Self;
|
||||
|
||||
/// 设置权限检查器(可选,不设置则不检查权限)。
|
||||
pub fn with_permission_checker(mut self, checker: PermissionChecker) -> Self;
|
||||
|
||||
/// 注册一个工具。
|
||||
pub fn register(&mut self, tool: Arc<dyn BaseTool>) -> Result<(), ToolError>;
|
||||
|
||||
/// 批量注册工具。
|
||||
pub fn register_all(&mut self, tools: Vec<Arc<dyn BaseTool>>) -> Result<(), ToolError>;
|
||||
|
||||
/// 注销一个工具。
|
||||
pub fn unregister(&mut self, name: &str) -> Option<Arc<dyn BaseTool>>;
|
||||
|
||||
/// 按名称查找工具。
|
||||
pub fn get(&self, name: &str) -> Option<Arc<dyn BaseTool>>;
|
||||
|
||||
/// 获取所有已注册工具的名称列表。
|
||||
pub fn list_tools(&self) -> Vec<String>;
|
||||
|
||||
/// 获取所有工具的 ToolDefinition 列表(用于传递给 LLM)。
|
||||
pub fn definitions(&self) -> Vec<ToolDefinition>;
|
||||
|
||||
/// 调用一个工具(含权限检查)。
|
||||
pub async fn invoke(&self, name: &str, args: Value) -> Result<ToolInvocation, ToolError>;
|
||||
|
||||
/// 批量执行工具调用(并行执行互不依赖的工具)。
|
||||
pub async fn invoke_all(
|
||||
&self,
|
||||
calls: Vec<(String, Value)>,
|
||||
) -> Vec<ToolInvocation>;
|
||||
}
|
||||
```
|
||||
|
||||
**核心逻辑**:
|
||||
- `invoke()`:查找工具 → 权限检查 → 执行 → 返回 `ToolInvocation`
|
||||
- `invoke_all()`:对多个工具调用并行执行(调用数量在运行时才确定,应使用 `futures::future::join_all`;`tokio::join!` 仅适用于编译期固定数量的 future),适用于 LLM 同时发出多个 tool_calls 的场景
|
||||
- `invoke_all()` 应对每个工具执行添加超时控制(通过 `tokio::time::timeout`),超时时间由 `CycleConfig::tool_timeout_secs` 配置,默认 60 秒,防止单个工具长时间阻塞整个循环
|
||||
- `definitions()`:将注册的工具批量转换为 `Vec<ToolDefinition>`,供 LlmCycle 传递 LLM
|
||||
- `ToolRegistry` 以 `Arc<PermissionChecker>` 的形式持有权限检查器(共享所有权,而非独占或借用),因此允许多个 Registry 共享同一个 Checker
|
||||
|
||||
**使用示例**:
|
||||
```rust
|
||||
let mut registry = ToolRegistry::new()
|
||||
.with_permission_checker(checker);
|
||||
|
||||
registry.register(Arc::new(WeatherTool))?;
|
||||
registry.register(Arc::new(FileReadTool))?;
|
||||
|
||||
// 获取 ToolDefinitions 传递给 LLM
|
||||
let tools = registry.definitions();
|
||||
|
||||
// 收到 LLM 的 tool_calls 后执行
|
||||
let calls = vec![
|
||||
("get_weather".into(), json!({"city": "Beijing"})),
|
||||
("read_file".into(), json!({"path": "/tmp/data.txt"})),
|
||||
];
|
||||
let results = registry.invoke_all(calls).await;
|
||||
```
|
||||
|
||||
### 3. PermissionChecker — 权限校验器
|
||||
|
||||
```rust
|
||||
// tools/permission.rs
|
||||
|
||||
/// 权限级别枚举。
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum Permission {
|
||||
/// 只读(读取文件、查询数据库等)。
|
||||
Read,
|
||||
/// 写入(创建/修改文件、插入数据等)。
|
||||
Write,
|
||||
/// 删除(删除文件、记录等)。
|
||||
Delete,
|
||||
/// 网络访问(HTTP 请求等)。
|
||||
Network,
|
||||
/// Shell 命令执行。
|
||||
Shell,
|
||||
/// 文件系统操作(除读/写/删之外的 FS 操作)。
|
||||
FileSystem,
|
||||
/// 自定义权限(可通过 namespaced 字符串扩展)。
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
/// 权限配置。
|
||||
pub struct PermissionConfig {
|
||||
/// 允许的权限列表(空 = 全部允许)。
|
||||
pub allowed: Vec<Permission>,
|
||||
/// 拒绝的权限列表(优先级高于 allowed)。
|
||||
pub denied: Vec<Permission>,
|
||||
/// 是否允许未声明权限的工具执行(默认为 true)。
|
||||
pub allow_unspecified: bool,
|
||||
}
|
||||
|
||||
impl Default for PermissionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
allowed: vec![Permission::Read, Permission::Network],
|
||||
denied: vec![Permission::Delete, Permission::Shell],
|
||||
allow_unspecified: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 权限检查器。
|
||||
pub struct PermissionChecker {
|
||||
config: PermissionConfig,
|
||||
}
|
||||
|
||||
impl PermissionChecker {
|
||||
pub fn new(config: PermissionConfig) -> Self;
|
||||
|
||||
/// 检查指定权限是否允许执行。
|
||||
pub fn check(&self, tool_name: &str, permissions: &[Permission]) -> Result<(), ToolError>;
|
||||
}
|
||||
```
|
||||
|
||||
**权限判定规则**:
|
||||
1. 如果权限在 `denied` 中 → 拒绝
|
||||
2. 如果权限在 `allowed` 中 → 允许
|
||||
3. 如果 `allowed` 非空且权限不在其中 → 拒绝(白名单模式)
|
||||
4. 如果 `allowed` 为空 → 按 `allow_unspecified` 判定
|
||||
|
||||
### 4. McpClient — MCP 协议客户端
|
||||
|
||||
MCP(Model Context Protocol)是一种基于 JSON-RPC 的协议,用于 LLM 与外部工具系统通信。Phase 2 实现其 **最小可行子集**,优先实现 stdio transport。
|
||||
|
||||
> **传输方式说明**:MCP 协议版本 2025-03-26 定义了两种标准传输——`stdio` 和 `Streamable HTTP`。原有的 `HTTP+SSE` 传输(2024-11-05)已被官方废弃,新实现不应采用。`Streamable HTTP` 通过单一 HTTP 端点同时支持 JSON 响应和 SSE 流式升级,是 HTTP 场景的推荐方案。
|
||||
|
||||
```rust
|
||||
// tools/mcp.rs
|
||||
|
||||
use std::process::{Child, Command, Stdio};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
use serde_json::Value;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::process::{ChildStdin, ChildStdout, Command as TokioCommand};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::tools::base::BaseTool;
|
||||
use crate::tools::error::ToolError;
|
||||
|
||||
/// MCP 协议版本。
|
||||
const MCP_VERSION: &str = "2025-03-26";
|
||||
|
||||
/// MCP 传输方式。
|
||||
pub enum McpTransport {
|
||||
/// 通过子进程 stdin/stdout 通信。
|
||||
Stdio {
|
||||
command: String,
|
||||
args: Vec<String>,
|
||||
},
|
||||
/// Streamable HTTP 传输(MCP 2025-03-26 引入,替代已废弃的 HTTP+SSE)。
|
||||
/// 客户端通过单一 HTTP 端点与 MCP Server 通信,支持 JSON 和 SSE 流式响应。
|
||||
StreamableHttp {
|
||||
url: String,
|
||||
headers: Option<Vec<(String, String)>>,
|
||||
},
|
||||
}
|
||||
|
||||
/// MCP 子进程运行时状态(connect() 后创建)。
|
||||
struct ChildProcessState {
|
||||
child: tokio::process::Child,
|
||||
stdin: tokio::io::BufWriter<tokio::process::ChildStdin>,
|
||||
/// 等待响应的请求映射(id → oneshot sender)。
|
||||
pending: HashMap<u64, tokio::sync::oneshot::Sender<Result<Value, ToolError>>>,
|
||||
next_id: u64,
|
||||
}
|
||||
|
||||
/// MCP 客户端 —— 与 MCP 服务器通信。
|
||||
pub struct McpClient {
|
||||
transport: McpTransport,
|
||||
server_name: String,
|
||||
/// 已初始化的工具列表(缓存)。
|
||||
tools: Vec<McpTool>,
|
||||
/// 是否已初始化。
|
||||
initialized: AtomicBool,
|
||||
/// 超时时间(秒)。
|
||||
timeout_secs: u64,
|
||||
/// 子进程运行时状态(connect() 后创建,close() 后取回)。
|
||||
process: Option<tokio::sync::Mutex<ChildProcessState>>,
|
||||
}
|
||||
|
||||
/// MCP 服务器暴露的工具(缓存结构)。
|
||||
struct McpTool {
|
||||
name: String,
|
||||
description: Option<String>,
|
||||
input_schema: Value,
|
||||
}
|
||||
|
||||
impl McpClient {
|
||||
/// 创建一个 MCP 客户端。
|
||||
pub fn new(server_name: impl Into<String>, transport: McpTransport) -> Self;
|
||||
|
||||
/// 设置超时时间。
|
||||
pub fn with_timeout(mut self, secs: u64) -> Self;
|
||||
|
||||
/// 连接并初始化(发送 initialize 请求,获取服务器能力声明)。
|
||||
/// 启动子进程,创建 ChildProcessState(含 reader task)。
|
||||
pub async fn connect(&mut self) -> Result<(), ToolError>;
|
||||
|
||||
/// 列出服务器支持的工具(调用 tools/list)。
|
||||
pub async fn list_tools(&mut self) -> Result<Vec<ToolDefinition>, ToolError>;
|
||||
|
||||
/// 调用一个工具(调用 tools/call)。
|
||||
/// 通过 Mutex 获取 stdin 写入权限,发送 JSON-RPC 请求,通过 id 匹配响应。
|
||||
/// reader task 持续读取 stdout,解析 JSON-RPC 响应,通过 oneshot 通知调用方。
|
||||
pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, ToolError>;
|
||||
|
||||
/// 关闭连接(终止子进程)。
|
||||
/// 发送 shutdown → 等待 5s 优雅退出 → 超时则 child.kill()。
|
||||
pub async fn close(&mut self) -> Result<(), ToolError>;
|
||||
|
||||
/// 将 MCP 客户端转换为 BaseTool 适配器列表(用于注册到 ToolRegistry)。
|
||||
pub fn into_tools(self) -> Vec<Arc<dyn BaseTool>>;
|
||||
}
|
||||
```
|
||||
|
||||
**MCP 协议交互流程**:
|
||||
|
||||
```
|
||||
客户端 MCP 服务器
|
||||
│ │
|
||||
├── initialize request ──────► │
|
||||
│ { protocolVersion, capabilities }
|
||||
│◄──── initialize response ── │
|
||||
│ { protocolVersion, serverInfo, capabilities }
|
||||
│ │
|
||||
├── initialized notification ──► │
|
||||
│ │
|
||||
├── tools/list ──────────────► │
|
||||
│◄── tools/list result ──────── │
|
||||
│ { tools: [{ name, description, inputSchema }] }
|
||||
│ │
|
||||
├── tools/call ─────────────► │
|
||||
│ { name, arguments }
|
||||
│◄── tools/call result ──────── │
|
||||
│ { content: [{ type, text }] }
|
||||
│ │
|
||||
├── shutdown ───────────────► │
|
||||
│◄── shutdown ───────────── │
|
||||
```
|
||||
|
||||
**关于 stdio transport 实现**:
|
||||
- 使用 `tokio::process::Command` 启动子进程
|
||||
- stdin 写入 JSON-RPC 请求(每行一个 JSON 对象)
|
||||
- stdout 读取 JSON-RPC 响应(使用 `BufReader` 逐行读取)
|
||||
- 每个请求关联一个 `id`(递增整数),通过 `id` 匹配请求和响应
|
||||
- 子进程退出时连接自动关闭;主动关闭时需注意 MCP 规范并未为 stdio 定义专门的 shutdown 消息,应先关闭子进程的 stdin 并等待其退出,超时后再终止进程
|
||||
|
||||
**关于 MCP Server 的管理**:
|
||||
- Phase 2 **不** 实现 MCP Server 框架,只实现 Client
|
||||
- MCP Server 由外部提供(如 `npx @modelcontextprotocol/server-filesystem`)
|
||||
- 用户需要提供 MCP Server 的启动命令和参数
|
||||
|
||||
**工具缓存说明**:
|
||||
- `McpClient` 在 `list_tools()` 时缓存工具列表,避免每次调用都重新请求
|
||||
- 缓存假设:MCP Server 的工具列表在运行时不会频繁变更(如插件式加载场景除外)
|
||||
- 如需刷新,可通过新增 `refresh_tools()` 方法或基于 TTL(如 60 秒)自动失效
|
||||
|
||||
### 5. ToolError — 错误类型
|
||||
|
||||
```rust
|
||||
// tools/error.rs
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ToolError {
|
||||
#[error("工具 '{0}' 未注册")]
|
||||
NotFound(String),
|
||||
|
||||
#[error("工具 '{0}' 执行失败: {1}")]
|
||||
ExecutionFailed(String, String),
|
||||
|
||||
#[error("工具 '{0}' 参数无效: {1}")]
|
||||
InvalidArguments(String, String),
|
||||
|
||||
#[error("权限被拒绝: 工具 '{0}' 需要 {1:?} 权限")]
|
||||
PermissionDenied(String, String),
|
||||
|
||||
#[error("MCP 协议错误: {0}")]
|
||||
McpError(String),
|
||||
|
||||
#[error("MCP 未初始化: {0}")]
|
||||
McpNotInitialized(String),
|
||||
|
||||
#[error("MCP 超时: {0}")]
|
||||
McpTimeout(String),
|
||||
|
||||
#[error("IO 错误: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("其他错误: {0}")]
|
||||
Other(String),
|
||||
}
|
||||
```
|
||||
|
||||
### 6. LlmCycle 扩展 — 自动 Tool 循环
|
||||
|
||||
**核心设计**:在 `LlmCycle` 中新增 `submit_with_tools()` 方法,自动处理 tool 执行循环。
|
||||
|
||||
```rust
|
||||
// llm/cycle.rs 新增
|
||||
|
||||
use crate::tools::registry::ToolRegistry;
|
||||
use crate::tools::error::ToolError;
|
||||
|
||||
impl LlmCycle {
|
||||
/// 提交消息并自动处理工具调用循环。
|
||||
///
|
||||
/// 流程:
|
||||
/// 1. 发送请求(含工具定义)
|
||||
/// 2. 检查响应中的 finish_reason
|
||||
/// 3. 如果是 ToolCalls → 先 push Assistant 消息 → 执行工具 → 回传结果 → 重复 1
|
||||
/// 4. 如果是 Stop/Length → push Assistant 消息 → 返回最终响应
|
||||
///
|
||||
/// 注意:OpenAI API 要求 tool 消息必须紧跟在对应的 Assistant(tool_calls)消息之后。
|
||||
/// 因此 push 工具结果前必须先 push Assistant 响应,否则 API 拒绝请求。
|
||||
pub async fn submit_with_tools(
|
||||
&mut self,
|
||||
prompt: String,
|
||||
registry: &ToolRegistry,
|
||||
) -> Result<ChatResponse, LlmError> {
|
||||
let tools = registry.definitions();
|
||||
let max_turns = self.config.max_tool_turns.unwrap_or(10);
|
||||
let mut turn = 0;
|
||||
|
||||
self.messages.push(OpenaiChatMessage::user_text(prompt));
|
||||
self.maybe_compact();
|
||||
|
||||
loop {
|
||||
turn += 1;
|
||||
if turn > max_turns {
|
||||
return Err(LlmError::Other(format!(
|
||||
"达到最大工具循环轮次 ({})",
|
||||
max_turns
|
||||
)));
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
let response = self.submit_request(&tools).await?;
|
||||
|
||||
// 检查是否需要执行工具
|
||||
let should_execute = matches!(
|
||||
response.stop_reason,
|
||||
Some(FinishReason::ToolCalls)
|
||||
) && has_tool_calls(&response.message);
|
||||
|
||||
// 将 Assistant 响应(含 tool_calls 或最终文本)追加到消息历史
|
||||
self.messages.push(response.message.clone());
|
||||
|
||||
if !should_execute {
|
||||
return Ok(response);
|
||||
}
|
||||
|
||||
// 解析 tool_calls 并执行
|
||||
let tool_calls = extract_tool_calls(&response.message);
|
||||
let results = registry.invoke_all(tool_calls).await;
|
||||
|
||||
// 回传工具结果
|
||||
for result in results {
|
||||
let content = match &result.output {
|
||||
Ok(value) => serde_json::to_string(value)
|
||||
.unwrap_or_else(|e| {
|
||||
tracing::warn!("工具结果序列化失败: {}", e);
|
||||
"{}".to_string()
|
||||
}),
|
||||
Err(e) => format!("错误: {}", e),
|
||||
};
|
||||
|
||||
self.messages.push(
|
||||
OpenaiChatMessage::tool_result(result.tool_name.clone(), content)
|
||||
);
|
||||
}
|
||||
|
||||
// 每轮工具执行后触发 compaction,防止 token 快速膨胀
|
||||
self.maybe_compact();
|
||||
}
|
||||
}
|
||||
|
||||
/// 在接近上下文窗口时压缩历史消息。
|
||||
fn maybe_compact(&mut self) {
|
||||
if let Some(ref config) = self.compact_config
|
||||
&& should_compact(&self.messages, config, &self.compact_state)
|
||||
{
|
||||
let freed = microcompact(&mut self.messages, config.keep_recent);
|
||||
if freed > 0 {
|
||||
self.compact_state.record_success();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 内部请求方法(与 submit 共享重试逻辑,但不 push user message)。
|
||||
async fn submit_request(
|
||||
&mut self,
|
||||
tools: &[ToolDefinition],
|
||||
) -> Result<ChatResponse, LlmError> {
|
||||
// ... 提取 submit() 中的 request → response 逻辑(不含 user prompt push)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**关键设计决策**:
|
||||
|
||||
| 决策 | 选择 | 理由 |
|
||||
|------|------|------|
|
||||
| 循环方式 | 同步循环(单线程串行) | 工具执行依赖前一轮结果,串行更安全 |
|
||||
| 最大轮次 | `CycleConfig.max_tool_turns`,独立于 `max_turns`,默认 `Some(10)` | 防止无限循环(LLM 反复调用工具)。使用独立字段避免影响现有 `submit()`/`submit_messages()` 的 `max_turns` 语义 |
|
||||
| 工具并行 | `invoke_all()` 互不依赖的工具并行 | LLM 可能一次发出多个 tool_calls(parallel_tool_calls) |
|
||||
| 工具超时 | `CycleConfig::tool_timeout_secs`,默认 60 | 防止单个工具长时间阻塞循环。`invoke_all()` 使用 `tokio::time::timeout` 包装 |
|
||||
| 错误处理 | 可恢复的工具执行错误以文本回传 LLM,不可恢复的错误终止循环(详见下文"错误分类与处理策略") | LLM 可自行从可恢复的错误中恢复;不可恢复的错误重试无意义 |
|
||||
| 消息追踪 | 所有工具交互通过 `self.messages` 持久化 | 调用方能通过 `cycle.messages()` 查看完整轨迹 |
|
||||
|
||||
**Token 消耗分析**:
|
||||
|
||||
自动 tool 循环的 token 消耗主要来自三个来源:
|
||||
|
||||
| 来源 | 说明 | 影响程度 |
|
||||
|------|------|---------|
|
||||
| 工具定义重复发送 | `definitions()` 在每轮请求中携带全部工具的 JSON Schema | 注册工具数 × 平均定义大小 × 轮数。20 个工具 × 500B × 5 轮 ≈ 50KB 输入 token |
|
||||
| 工具结果追加历史 | 每次工具执行结果完整追加到 `messages`,后续请求重发全部历史 | 最显著的 token 泄漏源。大结果(如向量搜索 Top-50)单次可能 ~15KB,多轮累加 |
|
||||
| Value→String 序列化 | 工具结果 `serde_json::to_string()` 后 JSON 字符串膨胀 ~20-30% | 线性的常量损耗 |
|
||||
|
||||
**影响估算**:
|
||||
|
||||
| 场景 | 工具相关 token 占比 | 说明 |
|
||||
|------|-------------------|------|
|
||||
| 单次简单查询 | <5% | 可忽略 |
|
||||
| 文件读取+分析(3-4 轮) | ~30% | 工具结果逐步累积 |
|
||||
| 网页搜索+总结(3-5 轮) | ~40% | 工具结果包含页面内容 |
|
||||
| 多工具数据 pipeline(5-10 轮) | ~60%+ | 需关注压缩和限制策略 |
|
||||
|
||||
**缓解方向**(Phase 2 不强制实现,但设计需可扩展):
|
||||
- **结果大小限制**:工具执行结果超过阈值时自动截断(如 `CycleConfig::max_tool_result_bytes`)
|
||||
- **自动压缩**:现有的 Auto-compaction 需感知工具消息,避免压缩掉 LLM 后续依赖的数据
|
||||
- **工具定义缓存**:基础工具定义变化极少,未来可考虑客户端侧缓存(需等 provider 支持)
|
||||
|
||||
**错误分类与处理策略**:
|
||||
|
||||
工具执行错误需要区分"可恢复"和"不可恢复"两类,不可恢复的错误应终止循环而非回传 LLM:
|
||||
|
||||
| 错误类型 | 处理策略 | 理由 |
|
||||
|---------|---------|------|
|
||||
| `ToolError::ExecutionFailed` | 回传 LLM(文本) | LLM 可能下次换参数或换方式重试 |
|
||||
| `ToolError::InvalidArguments` | 回传 LLM(文本) | LLM 可自动修正参数 |
|
||||
| `ToolError::NotFound` | 终止循环,返回 `LlmError` | LLM 无法注册工具,重试无意义 |
|
||||
| `ToolError::PermissionDenied` | 终止循环,返回 `LlmError` | 安全敏感,不应允许重试 |
|
||||
| `ToolError::McpError` | 终止循环,返回 `LlmError` | MCP 链路故障,重试大概率失败 |
|
||||
| `ToolError::McpTimeout` | 终止循环,返回 `LlmError` | 或可考虑重试 1 次后终止 |
|
||||
| `ToolError::Io` | 终止循环,返回 `LlmError` | IO 错误通常是环境问题 |
|
||||
| `ToolError::Other` | 回传 LLM(文本) | 兜底,保守回传 |
|
||||
|
||||
实现上可在 `ToolError` 上添加 `is_recoverable()` 方法,或在 `submit_with_tools()` 中通过 `match` 分支判断。
|
||||
|
||||
**submit_request() 重构说明**:
|
||||
|
||||
提取 `submit_request()` 作为 `submit_with_tools()` 的内部方法时,需确保不影响现有方法的行为。重构后的方法职责矩阵:
|
||||
|
||||
| 方法 | Push user msg | Compaction | Retry | Call provider | Handle response |
|
||||
|------|:---:|:---:|:---:|:---:|:---:|
|
||||
| `submit()` | ✅ | ✅ | ✅ | → `submit_request()` | ✅ |
|
||||
| `submit_messages()` | ❌ | ✅ | ✅ | → `submit_request()` | ✅ |
|
||||
| `submit_with_tools()` | ✅ | ✅ | ✅ | → `submit_request()` | ✅* |
|
||||
| `submit_request()` | ❌ | ❌ | ✅ | ✅ | ✅ |
|
||||
|
||||
*`submit_with_tools()` 在 `submit_request()` 返回后额外检查 `ToolCalls`,执行工具并回传结果后,在同一个 `loop` 中再次调用 `submit_request()`(循环实现,而非递归调用自身)。
|
||||
|
||||
**流式模式支持**:
|
||||
|
||||
`submit_stream()` 的增强方案:新增 `submit_stream_with_tools()`,在流式事件层面支持自动 tool 循环。
|
||||
|
||||
> **实现复杂度提示**:流式 tool 循环需要自定义 `Stream` 实现 + 内部状态机(`Streaming` → `ExecutingTools` → `Finished`)。每一轮需要:消费当前流 → 收集事件 → 检测 `TurnComplete(ToolCalls)` → 执行工具 → 发射 `ToolExecutionCompleted` → 发起新流 → 继续 yield。不能用简单的 `stream!` 宏实现。
|
||||
>
|
||||
> 建议 Phase 3 再实现 `submit_stream_with_tools()`,Phase 2 只实现非流式的 `submit_with_tools()`。如果 Phase 2 需要可先返回 "not yet implemented" 错误。
|
||||
|
||||
```rust
|
||||
impl LlmCycle {
|
||||
pub async fn submit_stream_with_tools(
|
||||
&mut self,
|
||||
prompt: String,
|
||||
registry: &ToolRegistry,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = StreamEvent> + Send>>, LlmError> {
|
||||
// 1. 使用 submit_stream() 获取初始事件流
|
||||
// 2. 监听 TurnComplete { reason: ToolCalls }
|
||||
// 3. 触发时:通过 ToolRegistry 执行工具
|
||||
// 4. 发射 ToolExecutionCompleted 事件(由 submit_stream_with_tools 负责,非底层 stream parser)
|
||||
// 5. 将工具结果注入 messages
|
||||
// 6. 自动发起下一轮请求(递归)
|
||||
// 7. 直到 finish_reason 为 Stop
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**事件发射时序**:
|
||||
```
|
||||
submit_stream_with_tools("查天气")
|
||||
│
|
||||
├─ AssistantTextDelta "我来查一下北京的天气..." ← 底层 stream parser 发射
|
||||
├─ ToolExecutionStarted { tool_name, input, id } ← submit_stream_with_tools 发射
|
||||
├─ TurnComplete { reason: ToolCalls } ← 底层 stream parser 发射
|
||||
│
|
||||
├── [自动] 执行工具 get_weather({city:"北京"})
|
||||
│
|
||||
├─ ToolExecutionCompleted { tool_name, output, ... } ← submit_stream_with_tools 发射
|
||||
│
|
||||
├─ AssistantTextDelta "北京今天 22°C" ← 底层 stream parser 发射
|
||||
├─ TurnComplete { reason: Stop } ← 底层 stream parser 发射
|
||||
│
|
||||
└─ (流结束)
|
||||
```
|
||||
|
||||
**事件发射职责划分**:底层 `parse_chunk_stream()` 负责 LLM 原生事件(`AssistantTextDelta`、`TurnComplete`);`submit_stream_with_tools()` 负责工具层事件(`ToolExecutionStarted`、`ToolExecutionCompleted`),在工具执行前/后手动 `yield` 事件。
|
||||
|
||||
### 7. 自定义工具示例
|
||||
|
||||
```rust
|
||||
use agcore::tools::{BaseTool, ToolError};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
|
||||
struct WeatherTool;
|
||||
|
||||
#[async_trait]
|
||||
impl BaseTool for WeatherTool {
|
||||
fn name(&self) -> &str { "get_weather" }
|
||||
fn description(&self) -> &str { "获取指定城市的当前天气" }
|
||||
fn parameters(&self) -> Value {
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": { "type": "string", "description": "城市名称" }
|
||||
},
|
||||
"required": ["city"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: Value) -> Result<Value, ToolError> {
|
||||
let city = args["city"].as_str()
|
||||
.ok_or_else(|| ToolError::InvalidArguments(
|
||||
"get_weather".into(), "缺少 city 参数".into()
|
||||
))?;
|
||||
// 模拟天气查询
|
||||
Ok(serde_json::json!({ "city": city, "temperature": 22, "unit": "°C" }))
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 8. 模块依赖关系
|
||||
|
||||
```
|
||||
tools/ 模块内部依赖:
|
||||
base.rs → permission, error(使用 Permission 枚举 + ToolError)
|
||||
permission → 无内部依赖
|
||||
registry → base, error, permission
|
||||
mcp → base, error(需通过 registry 注册)
|
||||
error → 无内部依赖
|
||||
|
||||
跨模块依赖:
|
||||
tools/ → llm/types (ToolDefinition 类型)
|
||||
llm/cycle → tools/registry (自动 tool 循环)
|
||||
llm/cycle → tools/error (ToolError 转换)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 9. 未来工具化路线扩展性分析
|
||||
|
||||
> 本节回答"当前设计是否足以支撑未来常规工具、MCP、Skill、记忆等统一走工具调用路线"。
|
||||
|
||||
#### 设计目标
|
||||
|
||||
未来所有 Agent 可调用的能力(常规工具、MCP 工具、Skill、记忆操作)都应通过 `BaseTool` trait 统一暴露给 LLM,`ToolRegistry` 作为唯一的工具发现和调用入口,对 `LlmCycle` 透明。
|
||||
|
||||
#### 各场景支持度评估
|
||||
|
||||
| 场景 | 当前支持度 | 关键瓶颈 |
|
||||
|------|-----------|---------|
|
||||
| 常规工具(天气/计算器) | ✅ 直接可行 | 无 |
|
||||
| MCP 工具(McpClient→BaseTool 适配器) | ✅ 可行 | 适配器模式优雅,MCP 流式/进度能力被 `Value→Value` 约束 |
|
||||
| Memory CRUD(store/recall/forget/update) | ⚠️ 基本可行 | 检索分页、大量结果返回需额外处理 |
|
||||
| 长时运行工具(数据集查询、文件上传) | ❌ 不可行 | 无进度汇报、无 cancellation 机制 |
|
||||
| 多轮确认工具("是否冻结账户?"审批流程) | ❌ 不可行 | 单次调用→单次返回,无法表达"反问→确认"模式 |
|
||||
| Skill 编排(多步骤组合、嵌套执行) | ❌ 不可行 | 无上下文传播(跨步骤传递中间结果)、无工具组合原语 |
|
||||
| Agent 按场景筛选工具子集 | ⚠️ 部分可行 | 无 tag/category 筛选机制 |
|
||||
|
||||
#### 关键扩展点
|
||||
|
||||
**A. `BaseTool::execute()` 签名——预留 `ToolContext` 注入**
|
||||
|
||||
`BaseTool` 是公开 trait,一旦用户实现并发布 crate,后续 breaking change 成本极高。当前签名:
|
||||
|
||||
```rust
|
||||
async fn execute(&self, args: Value) -> Result<Value, ToolError>;
|
||||
```
|
||||
|
||||
未来扩展路径——新增 `ToolContext` 参数,携带执行上下文:
|
||||
|
||||
```rust
|
||||
async fn execute(&self, args: Value, ctx: &ToolContext<'_>) -> Result<Value, ToolError>;
|
||||
```
|
||||
|
||||
`ToolContext` 初始应包含的字段(Phase 2 实现时不必全部实现,但签名需预留参数位置):
|
||||
|
||||
| 字段 | 用途 | 引入阶段 |
|
||||
|------|------|---------|
|
||||
| `session_id: &str` | 追踪一次对话中所有工具调用的关联性 | Phase 2 |
|
||||
| `trace_id: &str` | 链路追踪,跨工具调用的耗时分布 | Phase 2 |
|
||||
| `cancellation_token: CancellationToken` | 优雅取消正在执行的工具 | Phase 2 |
|
||||
| `progress: Option<UnboundedSender<ProgressEvent>>` | 进度汇报(数据处理到 50%) | Phase 3 |
|
||||
| `shared_state: Option<&HashMap<String, Value>>` | Skill 跨步骤传递中间结果 | Phase 4 |
|
||||
|
||||
这样 Skill/Agent 层在 Phase 4 引入时,`execute` 签名不必改,只需在 `ToolContext` 中增加字段。
|
||||
|
||||
**B. `ToolRegistry` 内部结构——引入 `ToolEntry` 元数据**
|
||||
|
||||
当前内部是 `HashMap<String, Arc<dyn BaseTool>>`,未来扩展为:
|
||||
|
||||
```rust
|
||||
pub struct ToolEntry {
|
||||
pub tool: Arc<dyn BaseTool>,
|
||||
pub tags: Vec<String>,
|
||||
pub category: String, // "memory", "data", "communication" 等
|
||||
pub version: Option<String>,
|
||||
pub stats: ToolStats, // 调用次数、平均耗时
|
||||
}
|
||||
```
|
||||
|
||||
对应的筛选 API:
|
||||
|
||||
```rust
|
||||
pub fn find_by_tag(&self, tag: &str) -> Vec<&ToolEntry>;
|
||||
pub fn find_by_category(&self, category: &str) -> Vec<&ToolEntry>;
|
||||
pub fn groups(&self) -> HashMap<&str, Vec<&ToolEntry>>;
|
||||
```
|
||||
|
||||
**C. 工具返回模式——从单一 `Value` 到 `ToolOutput` 枚举**
|
||||
|
||||
当前返回类型 `Result<Value, ToolError>` 只能表达"一次性完整返回"。未来根据需要引入多模式输出:
|
||||
|
||||
```rust
|
||||
pub enum ToolOutput {
|
||||
/// 一次性返回完整结果
|
||||
Final(Value),
|
||||
/// 通过 channel 逐步流式输出结果
|
||||
Streamed { initial: Value, rx: Receiver<Value> },
|
||||
/// 需要 LLM 进一步确认后再继续
|
||||
AwaitingInput { context: Value, prompt: String },
|
||||
}
|
||||
```
|
||||
|
||||
| 返回模式 | 场景示例 |
|
||||
|---------|---------|
|
||||
| `Final(Value)` | 天气查询、文件读取 |
|
||||
| `Streamed { initial, rx }` | 向量搜索 Top-100 逐批返回 |
|
||||
| `AwaitingInput { context, prompt }` | "检测到可疑交易,是否冻结?" |
|
||||
|
||||
#### 各能力的引入时序
|
||||
|
||||
```
|
||||
Phase 2(当前实现)
|
||||
├─ BaseTool trait (Value→Value, 但签名预留 Context 参数位)
|
||||
├─ ToolRegistry (HashMap<String, ToolEntry> + tag/category 筛选)
|
||||
├─ PermissionChecker / McpClient / ToolError
|
||||
├─ submit_with_tools() / submit_stream_with_tools()
|
||||
└─ ToolContext { session_id, trace_id, cancellation_token }
|
||||
|
||||
Phase 3(Memory 工具化)
|
||||
├─ MemoryStore trait(扩展 BaseTool)
|
||||
├─ memory_store / memory_recall / memory_search 等作为工具注册
|
||||
└─ ToolContext.progress 支持(分批返回检索结果)
|
||||
|
||||
Phase 4(Agent + Skill + 编排)
|
||||
├─ ToolContext.shared_state 支持(跨步骤传递中间结果)
|
||||
├─ ToolOutput 枚举支持(如需要流式/确认模式)
|
||||
├─ ToolChain / ToolSelector 工具组合原语
|
||||
└─ Skill 机制(多步骤编排 + 内部状态)
|
||||
```
|
||||
|
||||
#### 已识别但推迟的设计决策
|
||||
|
||||
| 决策 | 推迟原因 | 何时需要 |
|
||||
|------|---------|---------|
|
||||
| `ToolOutput` 枚举 | Phase 2 的所有场景(常规工具/MCP)用 `Value` 足够 | Phase 4 Agent 编排或长时工具 |
|
||||
| 工具 DAG 调度 | Agent 场景后才需要复杂编排 | Phase 4 |
|
||||
| Skill 机制 | 需要先有 Agent 使用工具的实践经验 | Phase 4 |
|
||||
| 工具调用审计持久化 | 可先通过 Hook 点实现简单日志 | Phase 4 |
|
||||
| 用户授权(运行时弹窗确认) | `PermissionChecker` 只做静态策略判定,不处理运行时交互。用户授权属于交互流程,应作为 `ToolOutput::AwaitingInput` 由上层 UI/Agent 层实现 | Phase 4 |
|
||||
|
||||
---
|
||||
|
||||
## 实现计划
|
||||
|
||||
### Step 1: 创建方案文档
|
||||
|
||||
创建 `docs/5-tool-system.md`(即本文档)。
|
||||
|
||||
### Step 2: ToolError
|
||||
|
||||
- 创建 `src/tools/error.rs`
|
||||
- 定义 `ToolError` 枚举(NotFound / ExecutionFailed / InvalidArguments / PermissionDenied / McpError / McpNotInitialized / McpTimeout / Io / Other)
|
||||
- 运行 `cargo check` 验证
|
||||
|
||||
### Step 3: Permission
|
||||
|
||||
- 创建 `src/tools/permission.rs`
|
||||
- 定义 `Permission` 枚举 + `PermissionConfig` + `PermissionChecker`
|
||||
- 编写权限判定逻辑(白名单/黑名单/未指定策略)
|
||||
- 编写 5+ 边界测试覆盖:白名单模式、黑名单模式、空列表、自定义权限冲突
|
||||
- 运行 `cargo test` 验证
|
||||
|
||||
### Step 4: BaseTool trait
|
||||
|
||||
- 创建 `src/tools/base.rs`
|
||||
- 定义 `BaseTool` trait(name / description / parameters / required_permissions / execute)
|
||||
- 定义 `ToolContext` 结构体(session_id / trace_id / cancellation_token),注入 `execute()` 作为第二个参数
|
||||
- 创建 `src/tools.rs` 模块根,声明子模块,重导出公共 API
|
||||
- `lib.rs` 添加 `pub mod tools;`
|
||||
- 编写 1 个 MockTool 测试工具并验证 trait 实现
|
||||
- 运行 `cargo check` 验证
|
||||
|
||||
### Step 5: ToolRegistry
|
||||
|
||||
- 创建 `src/tools/registry.rs`
|
||||
- 定义 `ToolInvocation` 结构体 + `ToolEntry` 元数据包装(tool + tags + category + stats)+ `ToolRegistry`
|
||||
- 实现核心方法:register / get / list / definitions / invoke / invoke_all / find_by_tag / find_by_category
|
||||
- `invoke_all()` 使用 `futures::future::join_all` + `tokio::time::timeout` 并行执行互不依赖的工具(每工具独立超时)
|
||||
- `definitions()` 将 `HashMap` 中的工具转换为 `Vec<ToolDefinition>`
|
||||
- `ToolRegistry` 不支持运行时并发注册(setup 阶段一次性构建),如需热注册由调用方通过 `Arc<RwLock<ToolRegistry>>` 包装
|
||||
- 编写 8+ 测试覆盖:注册冲突、空注册表查找、单次调用、批量并行调用、工具执行失败
|
||||
- 运行 `cargo test` 验证
|
||||
|
||||
### Step 6: LlmCycle 扩展(自动 Tool 循环)
|
||||
|
||||
- 新增 `cycle_submit.rs` 子模块(或直接在 `cycle.rs` 中扩增,取决于代码量)
|
||||
- 提取 `submit_request()` 内部方法(将 submit() 中的 request→response 逻辑独立),同时重构 `submit_messages()` 以复用同一路径
|
||||
- 实现 `submit_with_tools()` 方法:
|
||||
- 循环:submit_request → push Assistant 消息 → 检查 finish_reason → 调用 registry.invoke_all → push tool_results → 重复
|
||||
- 在 push tool_results **之前**先 push Assistant(tool_calls)消息(OpenAI API 要求)
|
||||
- `max_tool_turns` 控制(独立于 `max_turns`),达到上限返回错误
|
||||
- 不可恢复的错误(NotFound、PermissionDenied、McpError、McpTimeout、Io)终止循环
|
||||
- 可恢复的错误(ExecutionFailed、InvalidArguments、Other)以文本回传 LLM
|
||||
- 每轮执行后触发 `maybe_compact()` 防止 token 膨胀
|
||||
- `submit_stream_with_tools()` 方法:
|
||||
- Phase 2 标记为未实现(返回 `LlmError::Other("流式 tool 循环将在后续版本中支持")`)
|
||||
- 实际实现推迟到 Phase 3(需要自定义 `ToolStream` 状态机)
|
||||
- 更新 `CycleConfig`:
|
||||
- 新增 `max_tool_turns: Option<u32>`,默认 `Some(10)`(不影响 `max_turns` 语义)
|
||||
- 新增 `tool_timeout_secs: u64`,默认值 60
|
||||
- 新增 `max_tool_result_bytes: Option<usize>`,默认 `Some(65536)`(限制单次工具结果大小)
|
||||
- 编写 3+ 集成测试:单轮 tool 调用、多轮 tool 调用、达到 max_tool_turns 终止
|
||||
- 运行 `cargo test` 验证
|
||||
|
||||
### Step 7: McpClient(MCP 协议客户端)
|
||||
|
||||
- 创建 `src/tools/mcp.rs`
|
||||
- 实现 JSON-RPC 消息结构(Request / Response / Error / Notification)
|
||||
- 定义 `ChildProcessState` 结构体,包含运行时字段:`child`/`stdin`/`pending: HashMap<u64, oneshot::Sender>`/`next_id: u64`
|
||||
- reader task 使用 `tokio::select!` 同时监听 stdout 和 cancellation token
|
||||
- `call_tool()` 通过 Mutex 获取 stdin 写入权限,通过 id 匹配响应
|
||||
- 子进程意外退出时通知所有 pending 请求
|
||||
- 实现 stdio transport:
|
||||
- `connect()`:启动子进程,创建 ChildProcessState,发送 initialize 请求
|
||||
- `list_tools()`:调用 tools/list,缓存结果
|
||||
- `call_tool()`:调用 tools/call,解析响应
|
||||
- `close()`:发送 shutdown → 等待 5s 优雅退出 → 超时则 child.kill()
|
||||
- `StreamableHttp` transport 预留枚举变体,当前返回 "not implemented" 错误,不在 Phase 2 实现
|
||||
- 实现 `into_tools()`:将 MCP 工具转换为 `Vec<Arc<dyn BaseTool>>` 适配器
|
||||
- 设置 30 秒默认超时
|
||||
- 编写 MCP 协议消息序列化/反序列化测试 + 模拟子进程集成测试
|
||||
- 运行 `cargo test` 验证
|
||||
|
||||
### Step 8: 收尾
|
||||
|
||||
- 更新 `docs/roadmap.md` 标记 Phase 2 完成
|
||||
- `cargo clippy` — 无警告
|
||||
- `cargo build` — 完整构建
|
||||
- 检查所有新公开 API 有 `///` 文档注释
|
||||
- `cargo test` — 所有测试通过
|
||||
|
||||
---
|
||||
|
||||
## 术语表
|
||||
|
||||
| 术语 | 说明 |
|
||||
|------|------|
|
||||
| `BaseTool` | 工具抽象接口,所有工具需实现此 trait |
|
||||
| `ToolRegistry` | 工具注册表,管理工具注册、发现、调用 |
|
||||
| `ToolInvocation` | 工具调用记录,包含输入、输出和执行结果 |
|
||||
| `Permission` | 权限级别枚举(Read/Write/Delete/Network/Shell 等) |
|
||||
| `PermissionChecker` | 权限校验器,执行前判定是否允许 |
|
||||
| `McpClient` | MCP 协议客户端,通过 stdio 与 MCP Server 通信 |
|
||||
| `ToolDefinition` | 传递给 LLM 的工具定义(同 `OpenaiToolDefinition`) |
|
||||
| 自动 Tool 循环 | LlmCycle 自动执行 LLM 请求的工具调用并回传结果 |
|
||||
|
||||
---
|
||||
|
||||
## 风险评估
|
||||
|
||||
| 风险 | 概率 | 缓解措施 |
|
||||
|------|------|---------|
|
||||
| MCP 协议规范变化 | 中 | 只实现最小子集(initialize/list_tools/call_tool),封装在 `mcp.rs` 中便于适配 |
|
||||
| MCP 子进程异常退出 | 中 | 实现超时机制 + 错误恢复;进程退出时自动标记为不可用 |
|
||||
| 工具执行死循环(LLM 反复调用工具) | 中 | `max_tool_turns` 硬限制,达到上限后终止循环 |
|
||||
| JSON-RPC 消息竞争(stdio 双工) | 中 | 请求和响应通过 `id` 字段匹配,使用 `Mutex` 保护写操作 + `HashMap<u64, OneshotSender>` 等待响应,实现复杂度高于接口示意 |
|
||||
| 权限配置过于复杂 | 低 | PermissionConfig 提供合理默认值(允许 Read/Network,拒绝 Delete/Shell),简单场景无需自定义 |
|
||||
| 工具调用参数类型不匹配 | 低 | `execute()` 接收 `Value`,由实现方自行校验;通过 `ToolError::InvalidArguments` 返回结构化错误 |
|
||||
|
||||
---
|
||||
|
||||
## 验收标准
|
||||
|
||||
1. `cargo check` 编译通过
|
||||
2. `cargo clippy` 无警告
|
||||
3. 模块文件路径正确:`src/tools.rs` + `src/tools/{base,registry,permission,mcp,error}.rs`
|
||||
4. `BaseTool` trait 可被自定义工具实现,`name()` / `description()` / `parameters()` / `execute()` 四个方法正常工作
|
||||
5. `ToolRegistry` 支持注册、查找、列出、注销操作
|
||||
6. `ToolRegistry::definitions()` 返回正确的 `Vec<ToolDefinition>`
|
||||
7. `ToolRegistry::invoke()` 执行工具前进行权限检查
|
||||
8. `ToolRegistry::invoke_all()` 并行执行多个工具调用
|
||||
9. `PermissionChecker` 根据配置正确判定权限(白名单/黑名单/默认策略)
|
||||
10. `LlmCycle::submit_with_tools()` 收到 `FinishReason::ToolCalls` 后自动执行工具并回传结果
|
||||
11. `LlmCycle::submit_with_tools()` 达到 `max_tool_turns` 上限时终止并返回错误
|
||||
12. `LlmCycle::submit_stream_with_tools()` 在 Phase 2 返回"尚未实现"错误;在流式模式下发射 `ToolExecutionCompleted` 事件的完整实现推迟到 Phase 3 验收
|
||||
13. 自动 tool 循环产生的 Tool 消息正确追加到 `cycle.messages()`
|
||||
14. `McpClient::connect()` 能完成 MCP 协议握手(initialize)
|
||||
15. `McpClient::list_tools()` 能获取 MCP Server 暴露的工具列表
|
||||
16. `McpClient::call_tool()` 能调用 MCP Server 的工具
|
||||
17. `McpClient::into_tools()` 能生成可供 `ToolRegistry` 注册的适配器
|
||||
18. 所有新公开 API 有文档注释
|
||||
19. 测试覆盖率:`cargo test` 全部通过
|
||||
20. `BaseTool::execute()` 签名通过 `ToolContext` 参数预留了扩展点(session_id、cancellation_token),未来 Skill/Agent 层可在不修改 trait 签名的情况下注入上下文
|
||||
@@ -0,0 +1,655 @@
|
||||
# 记忆系统设计方案
|
||||
|
||||
> 设计日期:2026-06-07
|
||||
> 状态:待实现
|
||||
|
||||
---
|
||||
|
||||
## 1. 背景与目标
|
||||
|
||||
### 1.1 背景
|
||||
|
||||
AG Core 已完成 Phase 0(LLM 调用周期)、Phase 1(提示词工程)、Phase 2(工具系统)。Phase 3 的目标是构建记忆系统,为 Phase 4(Agent 运行时)提供记忆存储、管理与检索能力。
|
||||
|
||||
### 1.2 目标
|
||||
|
||||
提供一套可插拔的记忆抽象层,支持以下记忆形态:
|
||||
|
||||
- **对话记忆(ConversationMemory)** — 管理多轮对话消息历史,支持 sliding window / 全量策略
|
||||
- **知识库(KnowledgeStore)** — 基于 LLM Wiki 模式的结构化知识管理,Agent 可自主编译和维护知识页面
|
||||
- **检索器(MemoryRetriever)** — 单通道关键词检索,提供统一的记忆查找入口
|
||||
|
||||
### 1.3 设计原则
|
||||
|
||||
- **不引入 embedding 依赖** — 采用 Karpathy's LLM Wiki 模式的 index + keyword 检索,替代传统向量检索
|
||||
- **trait + 轻量默认实现** — 存储抽象接口提供纯内存默认实现(InMemoryStore),满足原型和测试需求
|
||||
- **模块间松耦合** — 记忆系统与 LlmCycle 的集成推迟到 Phase 4 Agent Runtime,Phase 3 只定义接口和数据操作
|
||||
|
||||
---
|
||||
|
||||
## 2. 需求分析
|
||||
|
||||
### 2.1 功能需求
|
||||
|
||||
| ID | 需求 | 优先级 | 说明 |
|
||||
|----|------|--------|------|
|
||||
| F1 | MemoryStore 通用键值存储 | P0 | save/get/delete/list |
|
||||
| F2 | 对话消息管理 | P0 | 按 session 管理,支持 sliding window / full |
|
||||
| F3 | 知识页面 CRUD | P1 | 创建/更新/删除/检索知识页面 |
|
||||
| F4 | 知识页面关键词检索 | P1 | 基于标题/摘要/标签的关键词匹配 |
|
||||
| F5 | 知识页面索引维护 | P1 | 维护可遍历的内容目录(index) |
|
||||
| F6 | 可插拔后端 | P0 | MemoryStore 通过 trait 抽象,下游可实现自定义后端 |
|
||||
| F7 | 记忆淘汰 | P1 | 支持 TTL 过期淘汰、容量上限淘汰 |
|
||||
| F8 | 消息条目级淘汰 | P1 | ConversationMemory 达到上限后删除最旧消息 |
|
||||
|
||||
### 2.2 非功能需求
|
||||
|
||||
| ID | 需求 | 说明 |
|
||||
|----|------|------|
|
||||
| NF1 | 零 embedding 依赖 | 核心库不引入任何向量数据库或 embedding 模型依赖 |
|
||||
| NF2 | 错误体系完善 | MemoryError 枚举,支持 is_recoverable() 分类 |
|
||||
| NF3 | 线程安全 | 所有存储实现满足 Send + Sync |
|
||||
| NF4 | 异步 API | 所有 IO 操作为 async |
|
||||
| NF5 | 模块化 | 各组件独立可替换 |
|
||||
|
||||
---
|
||||
|
||||
## 3. 方案设计
|
||||
|
||||
### 3.1 总体架构
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph Retrieval["检索层"]
|
||||
MR["MemoryRetriever"]
|
||||
end
|
||||
subgraph Logic["逻辑层"]
|
||||
CM["ConversationMemory"]
|
||||
KS["KnowledgeStore"]
|
||||
end
|
||||
subgraph Storage["存储层"]
|
||||
MS["MemoryStore (trait)"]
|
||||
IMS["InMemoryStore (默认)"]
|
||||
end
|
||||
|
||||
MR --> KS
|
||||
CM --> MS
|
||||
KS --> MS
|
||||
MS --> IMS
|
||||
```
|
||||
|
||||
### 3.2 模块结构
|
||||
|
||||
```
|
||||
src/
|
||||
memory.rs # 模块根:pub mod + pub use 重导出
|
||||
memory/
|
||||
store.rs # MemoryStore trait + InMemoryStore
|
||||
conversation.rs # ConversationMemory(对话管理)
|
||||
knowledge.rs # KnowledgeStore(具体 struct)
|
||||
retriever.rs # MemoryRetriever(单通道检索)
|
||||
error.rs # MemoryError
|
||||
types.rs # 核心数据类型
|
||||
```
|
||||
|
||||
### 3.3 接口定义
|
||||
|
||||
#### MemoryStore — 底层存储抽象
|
||||
|
||||
```rust
|
||||
#[async_trait]
|
||||
pub trait MemoryStore: Send + Sync {
|
||||
/// 保存/覆盖一个 MemoryItem(upsert 语义)。
|
||||
/// - 如果 id 不存在,则插入新条目
|
||||
/// - 如果 id 已存在,则覆盖旧条目
|
||||
async fn save(&self, item: MemoryItem) -> Result<(), MemoryError>;
|
||||
async fn get(&self, id: &str) -> Result<Option<MemoryItem>, MemoryError>;
|
||||
async fn delete(&self, id: &str) -> Result<(), MemoryError>;
|
||||
async fn list(&self, filter: &MemoryFilter) -> Result<Vec<MemoryItem>, MemoryError>;
|
||||
}
|
||||
```
|
||||
|
||||
#### InMemoryStore — 默认实现
|
||||
|
||||
```rust
|
||||
pub struct InMemoryStore {
|
||||
items: Mutex<HashMap<String, MemoryItem>>,
|
||||
}
|
||||
```
|
||||
|
||||
基于 `HashMap<String, MemoryItem>` + `Mutex`,纯内存,线程安全。
|
||||
|
||||
#### ConversationMemory — 对话记忆
|
||||
|
||||
```rust
|
||||
pub struct ConversationMemory {
|
||||
store: Arc<dyn MemoryStore>,
|
||||
session_id: String,
|
||||
config: ConversationMemoryConfig,
|
||||
messages: Vec<OpenaiChatMessage>, // 热缓存,供 compact 直接操作
|
||||
compact_state: CompactState, // 断路器状态
|
||||
}
|
||||
|
||||
pub struct ConversationMemoryConfig {
|
||||
pub strategy: MemoryStrategy, // sliding_window | full
|
||||
pub max_turns: usize, // sliding window 的最大轮数
|
||||
pub compact_config: Option<CompactConfig>, // 复用现有压缩配置
|
||||
}
|
||||
```
|
||||
|
||||
- `add_message(msg)` 写入热缓存 `self.messages`,同时通过 `store.save()` 持久化到后端
|
||||
- `get_history()` 优先从热缓存返回,缓存未命中时从 store 恢复
|
||||
- `compact` 直接在 `self.messages` 上调用 `should_compact()` 和 `microcompact()`
|
||||
- 压缩后同步回 `store`
|
||||
- 使用了 `llm::types::OpenaiChatMessage` 作为内部消息类型
|
||||
- 复用现有 `CompactConfig`(context_window, reserved_tokens, keep_recent)和 `CompactState`
|
||||
|
||||
#### KnowledgeStore — 具体 struct
|
||||
|
||||
```rust
|
||||
pub struct KnowledgeStore {
|
||||
store: Arc<dyn MemoryStore>,
|
||||
index: Mutex<Vec<PageIndexEntry>>,
|
||||
}
|
||||
|
||||
impl KnowledgeStore {
|
||||
pub fn new(store: Arc<dyn MemoryStore>) -> Self { ... }
|
||||
pub async fn add_page(&self, page: KnowledgePage) -> Result<(), MemoryError> { ... }
|
||||
pub async fn get_page(&self, id: &str) -> Result<Option<KnowledgePage>, MemoryError> { ... }
|
||||
pub async fn update_page(&self, page: KnowledgePage) -> Result<(), MemoryError> { ... }
|
||||
pub async fn delete_page(&self, id: &str) -> Result<(), MemoryError> { ... }
|
||||
pub async fn search(&self, query: &str) -> Result<Vec<KnowledgePage>, MemoryError> { ... }
|
||||
pub async fn get_index(&self) -> Result<Vec<PageIndexEntry>, MemoryError> { ... }
|
||||
}
|
||||
```
|
||||
|
||||
- `search()` 优先搜索 index(标题/摘要/标签),全文 content 搜索走 MemoryStore
|
||||
- index 在 add/update/delete 时自动维护,也支持通过 `rebuild_index()` 手动重建
|
||||
- `KnowledgeStore` 以 `"knowledge_{page_id}"` 格式生成 `MemoryItem.id`,其中前缀字符串提取为常量 `const KNOWLEDGE_PREFIX: &str = "knowledge_"`
|
||||
|
||||
提供 `rebuild_index()` 方法修复 index 与 store 的不同步问题:
|
||||
|
||||
```rust
|
||||
impl KnowledgeStore {
|
||||
/// 从 MemoryStore 重建 index(修复 index 与 store 的不同步问题)
|
||||
pub async fn rebuild_index(&self) -> Result<(), MemoryError> {
|
||||
let items = self.store.list(&MemoryFilter {
|
||||
prefix: Some("knowledge_".into()),
|
||||
since: None,
|
||||
offset: None,
|
||||
limit: None,
|
||||
}).await?;
|
||||
let mut index = self.index.lock();
|
||||
index.clear();
|
||||
for item in items {
|
||||
let page: KnowledgePage = serde_json::from_str(&item.content)
|
||||
.map_err(|e| MemoryError::Serialization(e.to_string()))?;
|
||||
index.push(PageIndexEntry {
|
||||
id: page.id.clone(),
|
||||
title: page.title,
|
||||
summary: page.summary,
|
||||
tags: page.tags,
|
||||
updated_at: page.updated_at,
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### MemoryRetriever — 简化版检索器
|
||||
|
||||
```rust
|
||||
pub struct MemoryRetriever {
|
||||
knowledge_store: KnowledgeStore,
|
||||
config: RetrieverConfig,
|
||||
}
|
||||
|
||||
pub struct RetrieverConfig {
|
||||
pub max_results: usize, // 默认 20
|
||||
pub min_score: f32, // 默认 0.1
|
||||
}
|
||||
|
||||
pub struct RetrievalResult {
|
||||
pub items: Vec<ScoredItem>,
|
||||
pub query: String,
|
||||
}
|
||||
|
||||
pub struct ScoredItem {
|
||||
pub page: KnowledgePage,
|
||||
pub score: f32, // TextOverlap 评分 [0.0, 1.0]
|
||||
}
|
||||
```
|
||||
|
||||
检索流程:
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A["输入: query"]
|
||||
B["1. 关键词提取(split + 过滤停用词)"]
|
||||
C["2. KnowledgeStore.search(keywords)"]
|
||||
D["3. TextOverlap 评分"]
|
||||
E["4. 过滤 score < min_score"]
|
||||
F["5. 降序排序 → 截取 top-N"]
|
||||
G["6. 返回 RetrievalResult"]
|
||||
|
||||
A --> B
|
||||
B --> C
|
||||
C --> D
|
||||
D --> E
|
||||
E --> F
|
||||
F --> G
|
||||
```
|
||||
|
||||
关键词提取在 MemoryRetriever 内部简单实现:按空格/标点分割 → 过滤单字符和停用词 → 返回关键词列表。TextOverlap 计算 query 与页面标题/摘要/内容的 n-gram 重叠度(基于 Dice 系数)。
|
||||
|
||||
TextOverlap 评分基于 Dice 系数(字符 bigram):
|
||||
|
||||
dice(query, text) = 2 × |bigrams(query) ∩ bigrams(text)| / (|bigrams(query)| + |bigrams(text)|)
|
||||
|
||||
多字段加权:
|
||||
score = title_dice × 0.5 + summary_dice × 0.3 + content_dice × 0.2
|
||||
|
||||
中文场景退化:当前版本按字符级 bigram 处理中文,不依赖分词器。
|
||||
|
||||
> **已知限制**:关键词提取基于空格/标点分割,对中文不做语义分词。
|
||||
> 中文场景按字符 bigram 参与 TextOverlap 计算,精度低于专业分词方案。
|
||||
> 如有更高精度需求,可替换 MemoryRetriever 的关键词提取逻辑。
|
||||
|
||||
### 3.4 核心数据类型
|
||||
|
||||
```rust
|
||||
pub struct MemoryItem {
|
||||
pub id: String,
|
||||
pub content: String,
|
||||
pub metadata: serde_json::Value,
|
||||
pub created_at: time::OffsetDateTime,
|
||||
}
|
||||
|
||||
pub struct MemoryFilter {
|
||||
pub prefix: Option<String>,
|
||||
pub since: Option<time::OffsetDateTime>,
|
||||
pub offset: Option<usize>, // 跳过前 N 条
|
||||
pub limit: Option<usize>,
|
||||
}
|
||||
|
||||
pub struct KnowledgePage {
|
||||
pub id: String,
|
||||
pub title: String,
|
||||
pub summary: String,
|
||||
pub content: String,
|
||||
pub tags: Vec<String>,
|
||||
pub references: Vec<String>,
|
||||
pub created_at: time::OffsetDateTime,
|
||||
pub updated_at: time::OffsetDateTime,
|
||||
}
|
||||
|
||||
pub struct PageIndexEntry {
|
||||
pub id: String,
|
||||
pub title: String,
|
||||
pub summary: String,
|
||||
pub tags: Vec<String>,
|
||||
pub updated_at: time::OffsetDateTime,
|
||||
}
|
||||
|
||||
pub enum MemoryStrategy {
|
||||
SlidingWindow,
|
||||
Full,
|
||||
}
|
||||
```
|
||||
|
||||
### 3.5 错误类型
|
||||
|
||||
```rust
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum MemoryError {
|
||||
#[error("Item not found: {0}")]
|
||||
NotFound(String),
|
||||
|
||||
#[error("Storage error: {0}")]
|
||||
Storage(String),
|
||||
|
||||
#[error("Serialization error: {0}")]
|
||||
Serialization(String),
|
||||
|
||||
#[error("Invalid input: {0}")]
|
||||
InvalidInput(String),
|
||||
|
||||
#[error("Retrieval error: {0}")]
|
||||
RetrievalError(String),
|
||||
}
|
||||
|
||||
impl MemoryError {
|
||||
pub fn is_recoverable(&self) -> bool {
|
||||
matches!(self, Self::NotFound(_) | Self::RetrievalError(_))
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 3.6 ConversationMemory 与 compact 模块的集成
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
subgraph CM["ConversationMemory"]
|
||||
A["add_message(msg)"]
|
||||
A1["self.messages.push(msg) ← 热缓存"]
|
||||
A2["store.save(to_item(msg)) ← 冷持久化"]
|
||||
B{"len(messages) > max_turns?"}
|
||||
C["should_compact(&messages, &config, &state) ← 直接在热缓存上操作"]
|
||||
D["microcompact(&mut messages, keep_recent) ← 复用 microcompact()"]
|
||||
E["sync_to_store() ← 压缩后同步回 store"]
|
||||
F["return"]
|
||||
G["get_history()"]
|
||||
H["从 self.messages 返回"]
|
||||
I["从 store 恢复 → 重建热缓存"]
|
||||
|
||||
A --> A1
|
||||
A1 --> A2
|
||||
A2 --> B
|
||||
B -->|是| C
|
||||
C --> D
|
||||
D --> E
|
||||
E --> F
|
||||
B -->|否| F
|
||||
G --> H
|
||||
H -->|缓存未命中| I
|
||||
end
|
||||
J["依赖: llm::compact::{CompactConfig, CompactState, should_compact, microcompact}"]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. 物理存储策略
|
||||
|
||||
### 4.1 存储层次
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph RetrievalLayer["检索层"]
|
||||
MR["MemoryRetriever (检索 + 评分,无状态)"]
|
||||
end
|
||||
subgraph AbstractionLayer["存储抽象层"]
|
||||
ABS["MemoryStore\n(存储抽象接口,不感知存储介质)"]
|
||||
end
|
||||
subgraph ImplementationLayer["实现层"]
|
||||
IMS["InMemoryStore (HashMap)\n进程内 volatile\n测试/原型适用"]
|
||||
CUSTOM["下游自定义实现\nFileStore / SqliteStore / RedisStore / ...\n生产环境适用"]
|
||||
end
|
||||
|
||||
MR --> ABS
|
||||
ABS --> IMS
|
||||
ABS --> CUSTOM
|
||||
```
|
||||
|
||||
### 4.2 InMemoryStore 的物理存储
|
||||
|
||||
| 组件 | 数据结构 | 存储位置 | 持久化 | 生命周期 |
|
||||
|------|---------|---------|--------|---------|
|
||||
| `InMemoryStore` | `HashMap<String, MemoryItem>` | 进程堆内存 | ❌ | 随进程销毁 |
|
||||
| `KnowledgeStore` | 基于 `InMemoryStore` + `Vec<PageIndexEntry>` | 进程堆内存 | ❌ | 随进程销毁 |
|
||||
|
||||
**适用场景:**
|
||||
- 单元测试和集成测试
|
||||
- 本地快速原型开发
|
||||
- 单次会话的临时 Agent
|
||||
|
||||
**不适合场景:**
|
||||
- 生产部署
|
||||
- 需要跨进程/跨会话共享记忆
|
||||
- 需要数据持久化和恢复
|
||||
|
||||
### 4.3 持久化存储方案(下游实现)
|
||||
|
||||
agcore 核心库**不内置**持久化实现,用户通过实现 `MemoryStore` trait 对接所需后端:
|
||||
|
||||
| 后端 | 实现建议 | 适用场景 | 复杂度 |
|
||||
|------|---------|---------|--------|
|
||||
| **JSON 文件** | MemoryStore trait → 序列化为单文件 JSON | 单机、轻量持久化 | 低 |
|
||||
| **SQLite** | MemoryStore → 关系表 | 单机、中小规模 | 中 |
|
||||
| **PostgreSQL** | MemoryStore → 关系表 | 多进程共享、中等规模 | 中 |
|
||||
| **Redis** | MemoryStore → Hash/JSON 类型 | 高速缓存、会话共享 | 低 |
|
||||
|
||||
### 4.4 对下游实现的约束
|
||||
|
||||
`MemoryStore` trait 对持久化实现无特殊约束:
|
||||
- 方法签名不涉及文件路径、连接字符串等存储细节
|
||||
- 所有方法均为 `async`,持久化实现可自由选择同步(`spawn_blocking`)或异步 driver
|
||||
- 初始化参数在具体实现的构造函数中注入
|
||||
|
||||
### 4.5 序列化
|
||||
|
||||
核心类型均实现 `Serialize` / `Deserialize`(通过 `#[derive(Serialize, Deserialize)]`),便于持久化实现直接复用:
|
||||
|
||||
```rust
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct MemoryItem { ... }
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct KnowledgePage { ... }
|
||||
// ...
|
||||
```
|
||||
|
||||
所有类型的 metadata 字段统一使用 `serde_json::Value`,不引入 protobuf/msgpack 等序列化框架。
|
||||
|
||||
---
|
||||
|
||||
## 5. 淘汰策略
|
||||
|
||||
### 5.1 问题
|
||||
|
||||
所有存储组件如果不设上限,会随运行时间无限增长。ConversationMemory 当前的 sliding window 只做 tool result 压缩,不删除消息条目。
|
||||
|
||||
### 5.2 淘汰策略
|
||||
|
||||
在 `MemoryStore` trait 层提供可选的淘汰配置,上层组件按需设置:
|
||||
|
||||
```rust
|
||||
pub struct EvictionConfig {
|
||||
pub policy: EvictionPolicy,
|
||||
pub check_interval: usize, // 每写入 N 条后检查一次淘汰条件
|
||||
}
|
||||
|
||||
pub enum EvictionPolicy {
|
||||
None, // 不淘汰(默认)
|
||||
Ttl { ttl_secs: u64 }, // 超过存活时间淘汰
|
||||
Capacity { max_items: usize },// 超过容量淘汰最旧
|
||||
}
|
||||
```
|
||||
|
||||
`InMemoryStore` 在 `save()` 后检查淘汰条件:
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A["save(item)"]
|
||||
B["items.insert(id, item)"]
|
||||
C{"writes_since_last_check >= check_interval?"}
|
||||
D{"policy 类型"}
|
||||
E["Ttl → items.retain(created_at > cutoff)"]
|
||||
F["Capacity → 按 created_at 降序排列,截断到 max_items(保留最新,淘汰最旧)"]
|
||||
G["None → 不淘汰"]
|
||||
|
||||
A --> B
|
||||
B --> C
|
||||
C -->|是| D
|
||||
C -->|否| G
|
||||
D -->|Ttl| E
|
||||
D -->|Capacity| F
|
||||
D -->|None| G
|
||||
```
|
||||
|
||||
### 5.3 各组件淘汰策略
|
||||
|
||||
| 组件 | 推荐策略 | 理由 |
|
||||
|------|---------|------|
|
||||
| **ConversationMemory** | `Capacity { max_items }` | 对话是流式的,旧消息价值递减,达到上限后淘汰最旧的消息条目 |
|
||||
| **MemoryStore**(通用) | `Ttl { ttl_secs }` | 通用存储由调用方按场景决定 |
|
||||
| **KnowledgeStore** | `None`(默认不淘汰) | 知识是累积的,新增不淘汰旧 |
|
||||
|
||||
### 5.4 ConversationMemory 淘汰行为
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A["add_message(msg)"]
|
||||
B["store.save(msg)"]
|
||||
C{"eviction.policy == Capacity?"}
|
||||
D["history = store.list(session)"]
|
||||
E{"while history.len() > max_items"}
|
||||
F["oldest = history.remove(0) ← 删除最旧消息"]
|
||||
G["store.delete(oldest.id)"]
|
||||
H["maybe_compact() ← 复用 microcompact() 做内容压缩"]
|
||||
I["return"]
|
||||
|
||||
A --> B
|
||||
B --> C
|
||||
C -->|是| D
|
||||
D --> E
|
||||
E -->|是| F
|
||||
F --> G
|
||||
G --> E
|
||||
E -->|否| H
|
||||
C -->|否| H
|
||||
H --> I
|
||||
```
|
||||
|
||||
两种机制分层:
|
||||
- **淘汰(eviction)**:删除整条消息,控制条目总数上限
|
||||
- **压缩(compaction)**:压缩剩余消息的 tool result 内容,节省 token
|
||||
|
||||
### 5.5 InMemoryStore 的淘汰实现
|
||||
|
||||
```rust
|
||||
impl InMemoryStore {
|
||||
pub fn with_eviction(config: EvictionConfig) -> Self { ... }
|
||||
}
|
||||
|
||||
// 在 save() 内部:
|
||||
async fn save(&self, item: MemoryItem) -> Result<(), MemoryError> {
|
||||
self.items.lock().insert(item.id.clone(), item);
|
||||
self.maybe_evict().await;
|
||||
}
|
||||
|
||||
async fn maybe_evict(&self) {
|
||||
match &self.eviction.policy {
|
||||
EvictionPolicy::Ttl { ttl_secs } => {
|
||||
let cutoff = Utc::now() - Duration::seconds(*ttl_secs as i64);
|
||||
self.items.lock().retain(|_, v| v.created_at > cutoff);
|
||||
}
|
||||
EvictionPolicy::Capacity { max_items } => {
|
||||
let mut items = self.items.lock();
|
||||
if items.len() > *max_items {
|
||||
let mut vec: Vec<_> = items.drain().collect();
|
||||
vec.select_nth_unstable_by(
|
||||
*max_items,
|
||||
|a, b| b.1.created_at.cmp(&a.1.created_at),
|
||||
);
|
||||
vec.truncate(*max_items);
|
||||
*items = vec.into_iter().collect();
|
||||
}
|
||||
}
|
||||
EvictionPolicy::None => {}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. 实现计划
|
||||
|
||||
### Step 1:基础类型 + MemoryStore(含淘汰机制)
|
||||
|
||||
**文件**:`src/memory.rs`、`src/memory/types.rs`、`src/memory/error.rs`、`src/memory/store.rs`
|
||||
|
||||
- 创建 `memory.rs` + `memory/` 目录
|
||||
- 定义 `MemoryItem`、`MemoryFilter`、`MemoryStrategy` 类型
|
||||
- 定义 `MemoryError` 枚举
|
||||
- 定义 `MemoryStore` trait 与 `InMemoryStore` 实现
|
||||
- 定义 `EvictionConfig`、`EvictionPolicy`(None / Ttl / Capacity)
|
||||
- `InMemoryStore.save()` 内部实现淘汰检查
|
||||
- 单元测试:TTL 过期淘汰、容量上限淘汰、不淘汰(None)
|
||||
- 验收:`cargo build` + `cargo test` 通过
|
||||
|
||||
**依赖**:time(日期时间,启用 `serde` feature)、serde(序列化)
|
||||
|
||||
### Step 2:ConversationMemory(含消息淘汰)
|
||||
|
||||
**文件**:`src/memory/conversation.rs`
|
||||
|
||||
- 定义 `ConversationMemoryConfig`(`MemoryStrategy` 已在 Step 1 的 `types.rs` 中定义,此处直接复用)
|
||||
- 实现 `ConversationMemory`
|
||||
- `add_message()` → 写入 MemoryStore → 触发容量淘汰(删除最旧消息)
|
||||
- 复用 `llm::compact` 的 `CompactConfig` 和 `microcompact()` 做内容压缩
|
||||
- 单元测试:sliding window 消息淘汰、tool result 内容压缩、full 模式、空 session
|
||||
- 验收:`cargo build` + `cargo test` 通过
|
||||
|
||||
**依赖**:MemoryStore(含 EvictionConfig)+ llm::compact
|
||||
|
||||
### Step 3:KnowledgeStore
|
||||
|
||||
**文件**:`src/memory/knowledge.rs`
|
||||
|
||||
- 定义 `KnowledgePage`、`PageIndexEntry` 类型
|
||||
- 实现 `KnowledgeStore` 具体 struct(非 trait)
|
||||
- 内部使用 `Arc<dyn MemoryStore>` 存储数据
|
||||
- index 自动维护(add/update/delete 时同步)
|
||||
- search 基于标题/摘要/标签的关键词匹配
|
||||
- 单元测试:页面 CRUD、index 一致性、搜索
|
||||
- 验收:`cargo build` + `cargo test` 通过
|
||||
|
||||
### Step 4:MemoryRetriever + 模块整合
|
||||
|
||||
**文件**:`src/memory/retriever.rs`、`src/memory.rs`
|
||||
|
||||
- 实现内存检索器 `MemoryRetriever`
|
||||
- 内部关键词提取:split + 过滤停用词
|
||||
- TextOverlap 评分:基于 Dice 系数计算 query 与页面的文本重叠度
|
||||
- 阈值过滤 → 排序 → 截取 top-N
|
||||
- 在 `memory.rs` 中用 `pub use` 分层重导出:
|
||||
- 高频类型(大多数下游需要):`MemoryStore`、`InMemoryStore`、`ConversationMemory`、`KnowledgeStore`、`MemoryRetriever`、`MemoryError`
|
||||
- 低频类型(配置/高级使用):`MemoryItem`、`MemoryFilter`、`MemoryStrategy`、`KnowledgePage`、`PageIndexEntry`、`EvictionConfig`、`EvictionPolicy`、`ConversationMemoryConfig`、`RetrieverConfig`、`RetrievalResult`、`ScoredItem`
|
||||
- 在 `src/lib.rs` 中声明 `pub mod memory`
|
||||
- 单元测试:关键词提取、TextOverlap 评分正确性、阈值过滤、排序正确性
|
||||
- 集成测试:端到端检索流程
|
||||
- 验收:`cargo build` + `cargo test` 通过
|
||||
|
||||
---
|
||||
|
||||
## 7. 风险评估
|
||||
|
||||
| 风险 | 概率 | 影响 | 缓解措施 |
|
||||
|------|------|------|---------|
|
||||
| KnowledgeStore 的 keyword 检索在大规模下效率低 | 中 | 中 | MemoryStore 实现可替换——下游可使用 SQLite FTS 等更高效的后端 |
|
||||
| ConversationMemory 与 compact 耦合引入循环依赖 | 低 | 高 | 仅引用 `CompactConfig`(纯数据结构)和 `microcompact()`(纯函数),不引用 cycle.rs |
|
||||
| time 增加依赖体积 | 低 | 低 | time 是社区(time-rs)维护、被广泛使用的时间库,体积小于 chrono |
|
||||
| Phase 4 集成时发现 Memory 设计不合理 | 低 | 高 | 按最小可行接口设计,预留扩展空间 |
|
||||
|
||||
---
|
||||
|
||||
## 8. 验收标准
|
||||
|
||||
- [ ] `MemoryStore` trait + `InMemoryStore` 通过单元测试
|
||||
- [ ] `EvictionConfig` 支持 None / Ttl / Capacity 三种策略
|
||||
- [ ] `InMemoryStore` 在 save() 后正确执行 TTL 淘汰和容量淘汰
|
||||
- [ ] `ConversationMemory` 支持 sliding window 和 full 两种策略
|
||||
- [ ] `ConversationMemory` sliding window 模式下达到上限后删除最旧消息条目
|
||||
- [ ] `ConversationMemory` 正确复用 `llm::compact` 的压缩逻辑
|
||||
- [ ] `KnowledgeStore` 支持页面 CRUD 和 index 维护
|
||||
- [ ] `MemoryRetriever` 支持基于 TextOverlap 的知识检索
|
||||
- [ ] 无 embedding 相关依赖
|
||||
- [ ] 模块结构:`memory.rs` + `memory/` 目录 + `pub use` 重导出
|
||||
- [ ] `MemoryError` 枚举完善,支持 `is_recoverable()`
|
||||
- [ ] 所有公开 API 有文档注释(`///`)
|
||||
- [ ] `cargo build` 和 `cargo test` 通过
|
||||
- [ ] 单个文件不超过 300 行
|
||||
|
||||
---
|
||||
|
||||
## 附录:与 Karpathy's LLM Wiki 的关系
|
||||
|
||||
本方案受 Karpathy's LLM Wiki 模式启发,但做了一些调整以适应 Agent 核心库的定位:
|
||||
|
||||
| Karpathy LLM Wiki | AG Core Memory System | 差异原因 |
|
||||
|-------------------|----------------------|---------|
|
||||
| 三层:Raw → Wiki → Schema | 存储底座 MemoryStore + 三个上层组件:ConversationMemory / KnowledgeStore / MemoryRetriever | Agent 场景需要区分对话记忆和知识记忆 |
|
||||
| index.md + log.md | PageIndexEntry(同 index.md)+ 无 log(Phase 4 Agent 负责) | 日志是工作流层职责,非存储层 |
|
||||
| LLM Agent 全权维护 | KnowledgeStore 提供数据接口,Phase 4 Agent 编排工作流 | core 只提供存储能力,不编排 |
|
||||
| 文件系统为后端 | MemoryStore trait 抽象后端 | 可插拔设计需要 trait 抽象 |
|
||||
| 基于文件系统搜索 | index + keyword 检索 | 文件系统搜索不适合所有后端 |
|
||||
@@ -0,0 +1,884 @@
|
||||
# Agent Runtime 方案设计
|
||||
|
||||
> 设计日期:2026-06-09
|
||||
> 状态:待实施
|
||||
> 关联文档:
|
||||
> - `docs/note-agent-runtime-design.md` — 设计决策记录(接口签名、文件清单、决策依据)
|
||||
> - `docs/note-agent-harness-references.md` — 参考项目调研(OpenClaw / Hermes / OpenHuman / OpenHarness)
|
||||
> - `docs/6-memory-system.md` — Phase 3 方案
|
||||
> - `docs/5-tool-system.md` — Phase 2 方案
|
||||
> - `docs/roadmap.md` — 项目总 Roadmap
|
||||
|
||||
---
|
||||
|
||||
## 1. 背景与目标
|
||||
|
||||
### 1.1 背景
|
||||
|
||||
AG Core 已完成 Phase 0(LLM 调用周期)、Phase 1(提示词工程)、Phase 2(工具系统)、Phase 3(记忆系统)共 4 个 phase 的交付。`LlmCycle::submit_with_tools()` 已在 Phase 2 末实现"LLM 决策 → 工具执行 → 回传结果"的单次循环;`ConversationMemory` / `KnowledgeStore` / `MemoryRetriever` 在 Phase 3 提供了完整的记忆抽象。
|
||||
|
||||
当前缺一个**整合层**:把 Phase 0-3 的能力"装配"起来,对上层应用暴露"智能体"的概念。
|
||||
|
||||
### 1.2 目标
|
||||
|
||||
Phase 4 整体目标是提供一个**薄胶水层 + 一组 trait 抽象**,让上层应用可以基于 AG Core 构建多轮对话、任务规划等智能体行为。为控制 scope、降低交付风险,拆分为三个子阶段实施:
|
||||
|
||||
| 子阶段 | 定位 | 交付物 |
|
||||
|--------|------|--------|
|
||||
| **Phase 4a(核心胶水层)** | 最小可用 Agent Runtime | `Agent` + `AgentSession` + `submit_turn` + `RuntimeBundle` / `AgentBuilder` / `AgentError` + `Plan`/`Step` 纯数据 + hooks 扩展 |
|
||||
| **Phase 4b(任务执行)** | 自主任务规划与执行 | `TaskAgent` + `PlanParser` trait + `JsonPlanParser` + `OnPlanStepComplete` hook |
|
||||
| **Phase 4c(会话级记忆)** | 跨 context 信息桥接 | `SessionMemory`(基于 `MemoryStore`)+ AgentSession 接入 + builder 支持 |
|
||||
|
||||
**每个子阶段独立交付**,Phase 4a 完成后上层即可接入;Phase 4b/4c 无相互依赖,可并行或按需延后。
|
||||
|
||||
Phase 4a 具体包括:
|
||||
|
||||
- **`Agent` trait** — 智能体的"角色"抽象(不绑定 session)
|
||||
- **`AgentSession` struct** — 智能体的"会话"实例(绑定 session_id + 状态)
|
||||
- **`RuntimeBundle`** — 显式依赖注入容器,集中管理 provider/registry/hook 等依赖
|
||||
- **`AgentBuilder`** — 链式构造入口
|
||||
- **`AgentError`** — 统一错误类型,聚合 LlmError / ToolError / MemoryError
|
||||
- **`Plan` / `Step` / `StepStatus`** — 任务规划纯数据结构(不做解析逻辑)
|
||||
|
||||
Phase 4b 追加:
|
||||
|
||||
- **`TaskAgent` trait** — 任务型智能体的"规划/执行"抽象
|
||||
- **`PlanParser` trait + `JsonPlanParser`** — Plan 解析接口与参考实现
|
||||
|
||||
Phase 4c 追加:
|
||||
|
||||
- **`SessionMemory`** — 会话级记忆,用于 context 间的信息桥接(基于 `MemoryStore` 后端)
|
||||
|
||||
### 1.3 设计原则
|
||||
|
||||
Phase 4 严格遵循以下原则,所有范围决策都基于这些原则推导:
|
||||
|
||||
| 原则 | 含义 | 推导 |
|
||||
|------|------|------|
|
||||
| **最小范围** | AG Core 是 lib crate,不是产品;不实现业务循环 | 只暴露 trait + 最小 reference impl |
|
||||
| **薄胶水层** | 不在 L1 重写已经做好的能力 | 复用 `LlmCycle::submit_with_tools` 等已有 API |
|
||||
| **依赖注入** | 所有运行时依赖显式打包传递 | 采用 OpenHarness `RuntimeBundle` 模式 |
|
||||
| **实体/会话分离** | 同一角色可被多 session 复用 | `Agent` + `AgentSession` 两层模型 |
|
||||
| **记忆弱引用** | 记忆是"被动能力",不内嵌循环 | `memory_store: Option<Arc<dyn MemoryStore>>` 弱引用 |
|
||||
| **业务可注入** | Plan 拆解是业务能力,不在 core 库实现 | 暴露 `PlanParser` trait,上层注入 |
|
||||
| **会话级记忆** | session 内共享、context 间桥接,不是持久层也不是对话历史 | `SessionMemory` 基于 `MemoryStore`,按 session_id 命名空间隔离 |
|
||||
| **借鉴不照搬** | 4 个参考项目均非 Rust 实现 | 只取架构模式,不抄实现细节 |
|
||||
|
||||
### 1.4 与已完成的 Phase 关系
|
||||
|
||||
```
|
||||
Phase 0 (L0/L1) ── LlmProvider / LlmCycle / Hook / Stream / Compact
|
||||
Phase 1 (L2) ── PromptTemplate / PromptComposer
|
||||
Phase 2 (L1) ── ToolRegistry / BaseTool / PermissionChecker / McpClient
|
||||
Phase 3 (L2) ── MemoryStore / ConversationMemory / KnowledgeStore / MemoryRetriever
|
||||
↑
|
||||
│ 复用
|
||||
│
|
||||
Phase 4a (L1→L2) ── Agent trait + AgentSession + submit_turn + RuntimeBundle + Plan/Step 纯数据(胶水层)
|
||||
Phase 4b (L2) ── TaskAgent + PlanParser + JsonPlanParser(任务执行)
|
||||
Phase 4c (L2) ── SessionMemory(会话级记忆)
|
||||
↓
|
||||
应用层 (L4) ── 上层 crate / 二进制 / Gateway(不在 Phase 4 范围)
|
||||
```
|
||||
|
||||
详细架构对照见 `docs/note-agent-harness-references.md` §3-5。
|
||||
|
||||
## 2. 需求分析
|
||||
|
||||
### 2.1 功能需求
|
||||
|
||||
| ID | 需求 | 优先级 | 归属 | 说明 |
|
||||
|----|------|--------|------|------|
|
||||
| F1 | `Agent` trait 抽象 | P0 | 4a | 角色定义:name / system_prompt / 工具集 |
|
||||
| F2 | `AgentSession` 会话实例 | P0 | 4a | 绑定 session_id、bundle、turn_index、cost_so_far |
|
||||
| F3 | `submit_turn()` 最小 reference impl | P0 | 4a | 组装 LlmCycle → submit → 累计 cost;~30 行 |
|
||||
| F6 | `Plan` / `Step` / `StepStatus` 数据结构 | P0 | 4a | 含 Pending / Running / Completed / Failed / Skipped 状态机 |
|
||||
| F8 | `RuntimeBundle` 依赖注入容器 | P0 | 4a | 聚合 provider/registry/hook/config(不含 session_memory_backend) |
|
||||
| F9 | `AgentBuilder` 链式构造 | P0 | 4a | 构建 `RuntimeBundle`,retriever 存在时自动注册为 tool |
|
||||
| F10 | `AgentError` 统一错误类型 | P0 | 4a | 聚合 LlmError / ToolError / MemoryError,含 `is_recoverable()` |
|
||||
| F11a | Hook 事件扩展:OnTurnStart / OnTurnEnd + turn_index 字段 | P0 | 4a | 在 `llm/hooks.rs` 中追加 2 个事件 + 1 个字段 |
|
||||
| F12a | 烟雾测试 3-4 个(Phase 4a) | P0 | 4a | trait 可装配 / RuntimeBundle 可构造 / submit_turn 跑通 mock / Plan 数据结构 |
|
||||
| F13 | `lib.rs` 导出 `pub mod agent;` | P0 | 4a | 一行 |
|
||||
| F14 | 方案文档(本文件)+ 决策记录 | P0 | — | ✅ 已完成 |
|
||||
| F4 | `TaskAgent::run(goal)` 自主式入口 | P0 | 4b | 内部用 LLM 拆 Plan,再调用 `execute_plan` |
|
||||
| F5 | `TaskAgent::execute_plan(plan)` 外部驱动式入口 | P0 | 4b | 用户预定义 Plan,逐步执行 |
|
||||
| F7 | `PlanParser` trait + `JsonPlanParser` 参考实现 | P0 | 4b | 注入式,上层可替换 |
|
||||
| F11b | Hook 事件扩展:OnPlanStepComplete + plan_step_index 字段 | P0 | 4b | 在 `llm/hooks.rs` 中追加 1 个事件 + 1 个字段 |
|
||||
| F12b | 烟雾测试 2-3 个(Phase 4b) | P0 | 4b | TaskAgent + PlanParser 跑通 mock |
|
||||
| F15a | Roadmap 状态翻转(Phase 4a) | P0 | 4a | 实施完成后做 |
|
||||
| F15b | Roadmap 状态翻转(Phase 4b) | P0 | 4b | 实施完成后做 |
|
||||
| F16 | SessionMemory 会话级记忆 | P0 | 4c | 基于 `MemoryStore`,context 间信息桥接 |
|
||||
| F17 | RuntimeBundle / Builder 扩展 session_memory_backend | P0 | 4c | 追加字段 + setter 方法 |
|
||||
| F18 | AgentSession 接入 SessionMemory | P0 | 4c | 替换内联 HashMap,接入完整 SessionMemory |
|
||||
| F12c | 烟雾测试 2-3 个(Phase 4c) | P0 | 4c | SessionMemory set/get/snapshot |
|
||||
| F15c | Roadmap 状态翻转(Phase 4c) | P0 | 4c | 实施完成后做 |
|
||||
|
||||
### 2.2 非功能需求
|
||||
|
||||
| ID | 需求 | 说明 |
|
||||
|----|------|------|
|
||||
| NF1 | 不引入新外部依赖 | 仅使用 Phase 0-3 已有的 `async-trait` / `serde` / `thiserror` / `tokio` 等 |
|
||||
| NF2 | 错误体系完善 | `AgentError` 聚合下层错误,含 `is_recoverable()` 分类 |
|
||||
| NF3 | 线程安全 | 所有公开类型满足 `Send + Sync` |
|
||||
| NF4 | 异步优先 | 涉及 IO 的 API 全部 `async` |
|
||||
| NF5 | 模块化 | 各组件独立可替换,遵循"trait 抽象 + 轻量默认实现"惯例 |
|
||||
| NF6 | 文档注释 | 所有公开 API 必须有 `///` 文档注释 |
|
||||
| NF7 | builder 模式 | 复杂配置走 builder 链式构造 |
|
||||
| NF8 | 显式依赖 | 不引入模块级全局状态,所有依赖通过参数或 bundle 注入 |
|
||||
| NF9 | 不破坏现有 API | Phase 0-3 的公开 API 一字不改;`hooks.rs` 扩展为"追加变体 + 追加字段"(兼容) |
|
||||
| NF10 | 最小测试覆盖 | 核心 trait 至少 1 个烟雾测试;`submit_turn` 至少 1 个 mock 测试;不强求集成测试 |
|
||||
|
||||
## 3. 方案设计
|
||||
|
||||
### 3.1 总体架构
|
||||
|
||||
```
|
||||
┌──────────────────────────────────────────────────────────────────────┐
|
||||
│ 应用层(不在 Phase 4 范围) │
|
||||
│ ┌────────────┐ ┌────────────┐ ┌────────────┐ ┌────────────┐ │
|
||||
│ │ CLI Agent │ │ Feishu Bot │ │ Web Service│ │ TUI App │ │
|
||||
│ └─────┬──────┘ └─────┬──────┘ └─────┬──────┘ └─────┬──────┘ │
|
||||
└─────────┼────────────────┼────────────────┼────────────────┼───────────┘
|
||||
│ │ │ │
|
||||
└────────────────┴────────────────┴────────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────────────────────────────────────────┐
|
||||
│ Agent Runtime(Phase 4) │
|
||||
│ │
|
||||
│ ┌────────────────┐ ┌──────────────────┐ │
|
||||
│ │ Agent trait │ 1 ──── * │ AgentSession │ │
|
||||
│ │ (角色) │ │ (会话实例) │ │
|
||||
│ └────────────────┘ └──────┬───────────┘ │
|
||||
│ │ Arc<...> │
|
||||
│ ▼ │
|
||||
│ ┌──────────────────┐ │
|
||||
│ │ RuntimeBundle │ │
|
||||
│ │ - provider │ │
|
||||
│ │ - tool_registry │ │
|
||||
│ │ - hook_executor │ │
|
||||
│ │ - memory_store? │ ◄─ 弱引用 │
|
||||
│ │ - retriever? │ ◄─ 弱引用 │
|
||||
│ │ - config │ │
|
||||
│ └──────┬───────────┘ │
|
||||
│ │ new() 时若 retriever 存在 │
|
||||
│ ▼ │
|
||||
│ ┌──────────────────┐ │
|
||||
│ │ "retrieve" tool │ ◄─ 自动注册 │
|
||||
│ └──────────────────┘ │
|
||||
│ │
|
||||
│ ┌────────────────┐ ┌──────────────────┐ ┌──────────────────┐ │
|
||||
│ │ TaskAgent trait│ │ Plan/Step/Status │ │ PlanParser trait │ │
|
||||
│ │ run() │ │ 状态机 │ │ JsonPlanParser │ │
|
||||
│ │ execute_plan()│ │ │ │ (参考实现 ~20行) │ │
|
||||
│ └────────────────┘ └──────────────────┘ └──────────────────┘ │
|
||||
│ │
|
||||
│ ┌────────────────┐ ┌──────────────────┐ │
|
||||
│ │ AgentError │ │ AgentBuilder │ │
|
||||
│ │ (聚合) │ │ (链式构造) │ │
|
||||
│ └────────────────┘ └──────────────────┘ │
|
||||
└──────────────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼ 复用
|
||||
┌──────────────────────────────────────────────────────────────────────┐
|
||||
│ LLM / Tool / Prompt / Memory(Phase 0-3) │
|
||||
│ LlmCycle / ProviderRegistry / ToolRegistry / PermissionChecker / │
|
||||
│ HookExecutor / StreamEvents / CompactConfig / │
|
||||
│ PromptTemplate / PromptComposer / │
|
||||
│ MemoryStore / ConversationMemory / KnowledgeStore / MemoryRetriever│
|
||||
└──────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### 3.2 接口设计
|
||||
|
||||
详细接口签名见 `docs/note-agent-runtime-design.md` §4,本节说明设计意图。
|
||||
|
||||
#### 3.2.1 `Agent` trait
|
||||
|
||||
```rust
|
||||
pub trait Agent: Send + Sync {
|
||||
fn name(&self) -> &str;
|
||||
fn system_prompt(&self) -> Option<&str>;
|
||||
/// 列出该 Agent 想要暴露给 LLM 的工具定义。
|
||||
/// 默认实现:从 RuntimeBundle.tool_registry 取全部(最常用)。
|
||||
/// 实现者可覆盖该方法做白名单/过滤。
|
||||
fn tool_definitions(&self, bundle: &RuntimeBundle) -> Vec<ToolDefinition>;
|
||||
}
|
||||
```
|
||||
|
||||
**设计意图**:
|
||||
- `name` / `system_prompt` 是 LLM 调用必需的元数据
|
||||
- `tool_definitions` 默认从 bundle 全量取,**Agent 可以在不修改 bundle 的情况下做工具白名单**——这与 Hermes 的"Skill 暴露"机制对齐
|
||||
- 不在 trait 里强制 `submit_turn`——`submit_turn` 是 `AgentSession` 的方法,不应绑死在角色定义上
|
||||
|
||||
#### 3.2.2 `RuntimeBundle`
|
||||
|
||||
```rust
|
||||
pub struct RuntimeBundle {
|
||||
pub provider: Arc<dyn LlmProvider>,
|
||||
pub tool_registry: Arc<ToolRegistry>,
|
||||
pub hook_executor: Arc<HookExecutor>,
|
||||
pub memory_store: Option<Arc<dyn MemoryStore>>, // 弱引用
|
||||
pub retriever: Option<Arc<MemoryRetriever>>, // 弱引用
|
||||
pub session_memory_backend: Option<Arc<dyn MemoryStore>>, // SessionMemory 后端(选填)
|
||||
pub config: AgentConfig,
|
||||
}
|
||||
```
|
||||
|
||||
**设计意图**:
|
||||
- 所有运行时依赖**显式打包**(OpenHarness 风格)
|
||||
- `memory_store` / `retriever` 均为 `Option`——上层应用**不传也能跑**(无记忆模式);文中"弱引用"指可选依赖(`Option<Arc<_>>`),并非 `std::sync::Weak`
|
||||
- 当 `retriever` 存在时,`RuntimeBundle::new()` 内部自动注册一个名为 `"retrieve"` 的 tool(具体实现:在 `ToolRegistry` 里加一个 `RetrieveTool` 包装),让 LLM 在对话中**主动**调用检索能力
|
||||
- `session_memory_backend` 是 `SessionMemory` 的持久后端。传入时 `SessionMemory` 使用该后端(支持跨进程共享);不传时 `AgentSession` 内部自动创建 `InMemoryStore` 作为进程级隔离的后端
|
||||
- `config` 集中管理所有可调参数(max_turns、max_tool_turns、session_ttl、compact_config)
|
||||
|
||||
#### 3.2.3 `AgentSession` 与最小 reference impl
|
||||
|
||||
```rust
|
||||
pub struct AgentSession {
|
||||
pub session_id: String,
|
||||
pub agent: Arc<dyn Agent>,
|
||||
bundle: Arc<RuntimeBundle>,
|
||||
turn_index: u32,
|
||||
cost_so_far: CostTracker,
|
||||
session_memory: SessionMemory,
|
||||
}
|
||||
|
||||
impl AgentSession {
|
||||
/// 最小 reference impl(约 30 行):
|
||||
/// 1. 触发 OnTurnStart hook
|
||||
/// 2. 组装 LlmCycle(注入 system_prompt + messages 历史 + tool definitions)
|
||||
/// 3. submit_with_tools() 跑单轮对话
|
||||
/// 4. 累计 cost
|
||||
/// 5. 触发 OnTurnEnd hook
|
||||
/// 6. turn_index += 1
|
||||
/// 7. 返回 ChatResponse
|
||||
/// 不做 memory 回写(由上层独立 task 处理)
|
||||
pub async fn submit_turn(
|
||||
&mut self,
|
||||
user_input: impl Into<String>,
|
||||
) -> Result<ChatResponse, AgentError>;
|
||||
}
|
||||
```
|
||||
|
||||
**设计意图**:
|
||||
- `agent: Arc<dyn Agent>` 而非 `agent_name: String`——`submit_turn` 从 agent 获取 `system_prompt()` 和 `tool_definitions()`,同时为 v0.2+ 的"热切换 agent"预留:替换 `self.agent` 即可切换角色
|
||||
- `session_memory` 是进程内共享的会话级记忆,context 间通过它桥接信息(详见 §3.2.8)
|
||||
- "最小 reference impl" 只演示**最常见**的对话场景
|
||||
- 业务循环(多轮策略、错误重试、记忆回写时机)由上层应用或具体的 `TaskAgent` 实现决定
|
||||
- `submit_turn` 不持有 `ConversationMemory`——上层应用可独立 new 一个 `ConversationMemory`,在合适的时机(如 OnTurnEnd hook)调 `add_message`
|
||||
|
||||
#### 3.2.4 `TaskAgent` + `Plan` / `Step`
|
||||
|
||||
```rust
|
||||
pub struct Plan {
|
||||
pub id: String,
|
||||
pub goal: String,
|
||||
pub steps: Vec<Step>,
|
||||
}
|
||||
|
||||
pub struct Step {
|
||||
pub index: usize,
|
||||
pub description: String,
|
||||
pub status: StepStatus,
|
||||
}
|
||||
|
||||
pub enum StepStatus {
|
||||
Pending,
|
||||
Running,
|
||||
Completed(ChatResponse),
|
||||
Failed(AgentError),
|
||||
Skipped,
|
||||
}
|
||||
```
|
||||
|
||||
**设计意图**:
|
||||
- `StepStatus` 用 enum 而非简单 bool,便于上层 UI 展示和统计
|
||||
- 状态机转换:`Pending → Running → (Completed | Failed | Skipped)`,单向不可回退(重试由上层新建 Plan)
|
||||
- `Plan` / `Step` 故意保持简单——不引入 `dependencies` / `parallel_group` 等高级字段(v0.3+ 再考虑)
|
||||
|
||||
#### 3.2.5 `PlanParser` trait + `JsonPlanParser` 参考实现
|
||||
|
||||
```rust
|
||||
#[async_trait]
|
||||
pub trait PlanParser: Send + Sync {
|
||||
async fn parse(&self, raw: &str, goal: &str) -> Result<Plan, AgentError>;
|
||||
}
|
||||
|
||||
pub struct JsonPlanParser;
|
||||
#[async_trait]
|
||||
impl PlanParser for JsonPlanParser {
|
||||
/// 期望 LLM 输出形如:
|
||||
/// {"steps": [{"description": "..."}, ...]}
|
||||
/// 的 JSON 文本。
|
||||
/// 解析失败返回 AgentError::PlanParse。
|
||||
async fn parse(&self, raw: &str, goal: &str) -> Result<Plan, AgentError> { /* ... */ }
|
||||
}
|
||||
```
|
||||
|
||||
**设计意图**:
|
||||
- **注入式**:上层应用可以注入自己的 `PlanParser`(如基于 XML / YAML / 自定义 DSL)
|
||||
- `JsonPlanParser` 是**参考实现**,不是默认实现——上层必须显式选择
|
||||
- `JsonPlanParser` 大约 20 行:`serde_json::from_str` 解析 + 字段映射
|
||||
|
||||
#### 3.2.6 `AgentError`
|
||||
|
||||
```rust
|
||||
pub enum AgentError {
|
||||
Llm(LlmError),
|
||||
Tool(ToolError),
|
||||
Memory(MemoryError),
|
||||
PlanParse(String),
|
||||
HookBlocked(String),
|
||||
LimitExceeded(String),
|
||||
Config(String),
|
||||
Other(String),
|
||||
}
|
||||
```
|
||||
|
||||
**设计意图**:
|
||||
- 聚合而非包装下层错误(避免 `Box<dyn Error>` 丢失类型)
|
||||
- `PlanParse` / `HookBlocked` / `LimitExceeded` / `Config` 是 Agent 层特有的错误类型
|
||||
- `is_recoverable()` 根据变体类型判定(如 `Memory(_)` 可恢复、`PlanParse(_)` 不可恢复)
|
||||
|
||||
#### 3.2.7 `AgentConfig` + `AgentBuilder`
|
||||
|
||||
```rust
|
||||
pub struct AgentConfig {
|
||||
pub max_turns: u32,
|
||||
pub max_tool_turns: u32,
|
||||
pub session_ttl: Option<Duration>,
|
||||
pub compact_config: Option<CompactConfig>,
|
||||
}
|
||||
|
||||
pub struct AgentBuilder { /* ... */ }
|
||||
impl AgentBuilder {
|
||||
pub fn new() -> Self;
|
||||
pub fn provider(self, p: Arc<dyn LlmProvider>) -> Self;
|
||||
pub fn tool_registry(self, r: Arc<ToolRegistry>) -> Self;
|
||||
pub fn hook_executor(self, h: Arc<HookExecutor>) -> Self;
|
||||
pub fn memory_store(self, m: Arc<dyn MemoryStore>) -> Self; // 选填
|
||||
pub fn retriever(self, r: Arc<MemoryRetriever>) -> Self; // 选填
|
||||
pub fn session_memory_backend(self, s: Arc<dyn MemoryStore>) -> Self; // 选填
|
||||
pub fn config(self, c: AgentConfig) -> Self;
|
||||
pub fn build(self) -> Result<RuntimeBundle, AgentError>;
|
||||
}
|
||||
```
|
||||
|
||||
**设计意图**:
|
||||
- `AgentBuilder` 是**唯一**的 `RuntimeBundle` 构造入口
|
||||
- 必填字段在 `build()` 时校验(`provider` / `tool_registry` / `hook_executor` 不可缺)
|
||||
- `memory_store` / `retriever` / `session_memory_backend` 选填
|
||||
- `session_memory_backend` 不传时,`AgentSession` 内部用 `InMemoryStore` 兜底(进程级隔离)
|
||||
|
||||
#### 3.2.8 `SessionMemory` — 会话级记忆
|
||||
|
||||
```rust
|
||||
pub struct SessionMemory {
|
||||
store: Arc<dyn MemoryStore>,
|
||||
namespace: String,
|
||||
}
|
||||
|
||||
impl SessionMemory {
|
||||
/// 创建新的 session 级记忆实例。
|
||||
/// store:后端存储(可跨进程共享的 MemoryStore 实现)。
|
||||
/// namespace:按 session_id 隔离,防止跨 session 泄漏。
|
||||
pub fn new(store: Arc<dyn MemoryStore>, namespace: &str) -> Self;
|
||||
|
||||
/// 写入一条 key-value 条目。
|
||||
pub async fn set(&self, key: &str, value: &str) -> Result<(), AgentError>;
|
||||
|
||||
/// 读取指定 key 的值。
|
||||
pub async fn get(&self, key: &str) -> Result<Option<String>, AgentError>;
|
||||
|
||||
/// 返回所有条目的格式化快照,适合注入 system prompt。
|
||||
/// 格式:
|
||||
/// <session-context>
|
||||
/// key1: value1
|
||||
/// key2: value2
|
||||
/// </session-context>
|
||||
pub async fn snapshot(&self) -> Result<String, AgentError>;
|
||||
|
||||
/// 删除指定 key。
|
||||
pub async fn remove(&self, key: &str) -> Result<(), AgentError>;
|
||||
|
||||
/// 清空当前 namespace 下所有条目。
|
||||
pub async fn clear(&self) -> Result<(), AgentError>;
|
||||
}
|
||||
```
|
||||
|
||||
**设计意图**:
|
||||
- `SessionMemory` 是**会话级**记忆,不是持久层(`MemoryStore`)也不是对话历史(`ConversationMemory`)——它的定位是 session 内各 context 之间的信息桥接
|
||||
- **复用 Phase 3 `MemoryStore` trait**:不引入新的存储后端机制。单进程场景用 `InMemoryStore`(零序列化开销),跨进程场景换 Redis / SQLite 等实现即可
|
||||
- **按 `namespace` 隔离**:每个 session 一个独立命名空间(`"_session_{session_id}"`),避免跨 session 意外泄漏
|
||||
- **`snapshot()` 格式化为标记文本**:专为注入 system prompt 设计,LLM 可以自然理解 `<session-context>` 标签中的内容
|
||||
- **所有方法为 `async`**:因为后端可能是跨进程的(Redis / DB),虽然 `InMemoryStore` 本身是同步操作
|
||||
- **不引入自己的错误类型**:错误通过 `AgentError::Memory` 传播(复用已有变体)
|
||||
|
||||
**三层记忆体系关系**:
|
||||
|
||||
```
|
||||
持久层(Phase 3) MemoryStore / KnowledgeStore ── 跨 session 持久,长期知识
|
||||
会话层(新增) SessionMemory ── 单 session 内共享,context 桥接
|
||||
对话层(Phase 3) ConversationMemory ── 单 context 内消息历史
|
||||
```
|
||||
|
||||
**典型使用模式**(v0.2+ context 切换场景):
|
||||
|
||||
```
|
||||
context_a (build agent)
|
||||
→ 在对话中决定某个关键结论值得记下来
|
||||
→ 调用 session_memory.set("design_decision", "用 PostgreSQL")
|
||||
→ 继续对话
|
||||
|
||||
创建 context_b (plan agent)
|
||||
→ system_prompt 末尾追加 session_memory.snapshot()
|
||||
→ LLM 看到 "<session-context>\ndesign_decision: 用 PostgreSQL\n</session-context>"
|
||||
→ 无需看 context_a 的 50 轮完整历史,但知道关键上下文
|
||||
```
|
||||
|
||||
### 3.3 状态机
|
||||
|
||||
#### 3.3.1 `StepStatus` 状态转换图
|
||||
|
||||
```
|
||||
┌─────────────┐
|
||||
│ Pending │ ◄── 初始状态
|
||||
└──────┬──────┘
|
||||
│ execute_plan() 进入
|
||||
▼
|
||||
┌─────────────┐
|
||||
│ Running │ ◄── 执行中(不触发 hook;进入终态后才触发 OnPlanStepComplete)
|
||||
└──────┬──────┘
|
||||
│
|
||||
┌────────────────┼────────────────┐
|
||||
│ │ │
|
||||
▼ ▼ ▼
|
||||
┌─────────┐ ┌──────────┐ ┌──────────┐
|
||||
│Completed│ │ Failed │ │ Skipped │
|
||||
└─────────┘ └──────────┘ └──────────┘
|
||||
触发 OnPlanStepComplete(status=Completed)
|
||||
触发 OnPlanStepComplete(status=Failed)
|
||||
触发 OnPlanStepComplete(status=Skipped)
|
||||
```
|
||||
|
||||
**设计约束**:
|
||||
- 状态转换**单向**(Pending → Running → 终态),不回退
|
||||
- 终态(Completed / Failed / Skipped)触发 `OnPlanStepComplete` hook
|
||||
- 重试由上层应用新建 `Plan` 实现(不在 `TaskAgent` 内做自动重试)
|
||||
|
||||
#### 3.3.2 Session 状态
|
||||
|
||||
`AgentSession` 的状态机比 `Step` 简单:
|
||||
|
||||
```
|
||||
创建 (new) ──► turn_index=0 ──► submit_turn() ──► turn_index+=1 ──► ... ──► 销毁
|
||||
```
|
||||
|
||||
`turn_index` 累加,`cost_so_far` 累加,无显式状态枚举(避免过度设计)。
|
||||
|
||||
### 3.4 Hook 扩展设计
|
||||
|
||||
在 `src/llm/hooks.rs` 中追加 3 个事件 + 2 个上下文字段:
|
||||
|
||||
```rust
|
||||
pub enum HookEvent {
|
||||
// ... 现有 4 个:PreRequest / PostRequest / OnRetry / OnError ...
|
||||
|
||||
// 新增 3 个(Phase 4):
|
||||
OnTurnStart,
|
||||
OnTurnEnd,
|
||||
OnPlanStepComplete,
|
||||
}
|
||||
|
||||
pub struct HookContext {
|
||||
// ... 现有字段 ...
|
||||
|
||||
// 新增 2 个(Phase 4):
|
||||
pub turn_index: Option<u32>, // OnTurnStart / OnTurnEnd 用
|
||||
pub plan_step_index: Option<usize>, // OnPlanStepComplete 用
|
||||
}
|
||||
```
|
||||
|
||||
**设计意图**:
|
||||
- **不破坏现有 hook 兼容性**:3 个新事件是 enum 追加,2 个新字段是 `Option<T>` 默认 `None`
|
||||
- 上层应用可通过监听 `OnTurnEnd` 实现"独立 task 回写 ConversationMemory"——呼应"记忆在独立 task 处理"原则
|
||||
- `OnPlanStepComplete` 提供"步骤级别"的可观测性,与 Hermes 的"任务进度回调"对齐
|
||||
|
||||
### 3.5 错误体系
|
||||
|
||||
`AgentError` 与下层错误的关系:
|
||||
|
||||
```
|
||||
┌──────────────────┐
|
||||
│ AgentError │
|
||||
├──────────────────┤
|
||||
│ Llm(LlmError) │──► 透传 Phase 0 错误,含 is_recoverable()
|
||||
│ Tool(ToolError) │──► 透传 Phase 2 错误,含 is_recoverable()
|
||||
│ Memory(MemoryError)│─► 透传 Phase 3 错误
|
||||
│ PlanParse(String) │─► Agent 层特有
|
||||
│ HookBlocked(String)│─► Agent 层特有
|
||||
│ LimitExceeded(String)│► Agent 层特有
|
||||
│ Config(String) │──► Agent 层特有
|
||||
│ Other(String) │──► 兜底
|
||||
└──────────────────┘
|
||||
│
|
||||
▼
|
||||
is_recoverable(): 聚合判定
|
||||
- Llm/Memory 可恢复(重试)
|
||||
- PlanParse / Config 不可恢复(需人工介入)
|
||||
- Tool / HookBlocked / LimitExceeded 按内层错误判定
|
||||
```
|
||||
|
||||
**自动 From 转换**:通过 `#[from]` 宏实现 `From<LlmError>` / `From<ToolError>` / `From<MemoryError>`,让 `submit_turn` 内部可以用 `?` 运算符直接传播。
|
||||
|
||||
### 3.6 与 Phase 0-3 模块的集成
|
||||
|
||||
| Phase 4 组件 | 调用的下层 API | 调用位置 |
|
||||
|-------------|--------------|---------|
|
||||
| `AgentSession::submit_turn` | `LlmCycle::new` + `with_system_prompt` + `with_hook_executor` + `with_compact_config` + `with_messages` + `submit_with_tools` | session.rs |
|
||||
| `AgentSession::submit_turn` | `CostTracker::add`(累计 cost) | session.rs |
|
||||
| `RuntimeBundle::new` | `ToolRegistry::register`(注册 retrieve tool) | runtime.rs |
|
||||
| `TaskAgent::execute_plan` | `AgentSession::submit_turn`(每步调一次) | task.rs |
|
||||
| `JsonPlanParser::parse` | `serde_json::from_str` | task.rs |
|
||||
| `AgentError::from` | `LlmError` / `ToolError` / `MemoryError` | error.rs |
|
||||
| `HookContext` 扩展 | `HookEvent::OnTurnStart/End/OnPlanStepComplete` | llm/hooks.rs |
|
||||
| `SessionMemory::set/get/snapshot` | `MemoryStore::save/load/search` | session_memory.rs |
|
||||
|
||||
**不调用的下层 API**(明确边界):
|
||||
- ❌ `ConversationMemory`(由上层独立 task 管理)
|
||||
- ❌ `KnowledgeStore`(由上层独立 task 管理)
|
||||
- ❌ `McpClient`(已由 `ToolRegistry` 包装)
|
||||
- ❌ `StreamEvents::submit_stream`(v1 暂不暴露流式 `submit_turn`,v0.2 再说)
|
||||
- ❌ 多 context 切换管理(v0.2+ 实现,Phase 4 只预留 `SessionMemory` 桥接通道)
|
||||
- ❌ `"session_memory_set"` 等 session memory tool 自动注册(v0.2+ 可选)
|
||||
|
||||
## 4. 实施计划
|
||||
|
||||
Phase 4 拆分为三个独立子阶段:**Phase 4a(核心胶水层)** → **Phase 4b(任务执行)** → **Phase 4c(会话级记忆)**。每个子阶段独立交付、独立验证,4b 与 4c 无相互依赖。
|
||||
|
||||
### 4.1 文件清单
|
||||
|
||||
#### 新增文件(8 个)
|
||||
|
||||
```
|
||||
src/agent.rs # [4a] 模块根 + pub use 重导出
|
||||
src/agent/agent.rs # [4a] Agent trait
|
||||
src/agent/runtime.rs # [4a] RuntimeBundle + AgentConfig(不含 session_memory_backend)
|
||||
src/agent/session.rs # [4a] AgentSession(submit_turn + 内联 session_data HashMap)
|
||||
src/agent/task.rs # [4a] Plan / Step / StepStatus 纯数据 / [4b] TaskAgent + PlanParser + JsonPlanParser
|
||||
src/agent/builder.rs # [4a] AgentBuilder(不含 session_memory_backend)
|
||||
src/agent/error.rs # [4a] AgentError(不含 PlanParse 变体)/ [4b] 补充 PlanParse 变体
|
||||
src/agent/session_memory.rs # [4c] SessionMemory(基于 MemoryStore)
|
||||
```
|
||||
|
||||
#### 修改文件(2 个)
|
||||
|
||||
```
|
||||
src/lib.rs # [4a] + pub mod agent;
|
||||
src/llm/hooks.rs # [4a] + 2 事件(OnTurnStart/OnTurnEnd)+ 1 字段(turn_index)
|
||||
# [4b] + 1 事件(OnPlanStepComplete)+ 1 字段(plan_step_index)
|
||||
```
|
||||
|
||||
#### 关联文档(已完成)
|
||||
|
||||
```
|
||||
docs/note-agent-harness-references.md # ✅ 已存在
|
||||
docs/note-agent-runtime-design.md # ✅ 已存在(与本文件配套)
|
||||
docs/7-agent-runtime.md # ✅ 本文件
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 4.2 Phase 4a — 核心胶水层(最小 Agent Runtime)
|
||||
|
||||
**范围**:Agent trait + AgentSession + submit_turn(内联 HashMap session_data)+ RuntimeBundle/AgentBuilder + AgentError + Plan/Step 纯数据 + hooks 扩展(OnTurnStart/OnTurnEnd)
|
||||
|
||||
**任务拆解**:
|
||||
|
||||
| 顺序 | 任务 | 涉及文件 | 验证 |
|
||||
|------|------|---------|------|
|
||||
| a1 | 修改 `llm/hooks.rs` 追加 OnTurnStart / OnTurnEnd + turn_index 字段 | `src/llm/hooks.rs` | `cargo build` 通过;Phase 0 测试不挂 |
|
||||
| a2 | 新建 `agent/error.rs` 定义 `AgentError`(不含 PlanParse 变体) | `src/agent/error.rs` | `cargo build` 通过 |
|
||||
| a3 | 新建 `agent/agent.rs` 定义 `Agent` trait | `src/agent/agent.rs` | `cargo build` 通过 |
|
||||
| a4 | 新建 `agent/runtime.rs` 定义 `RuntimeBundle` + `AgentConfig`(不含 session_memory_backend) | `src/agent/runtime.rs` | `cargo build` 通过 |
|
||||
| a5 | 新建 `agent/builder.rs` 定义 `AgentBuilder`(不含 session_memory_backend 方法) | `src/agent/builder.rs` | `cargo build` 通过 |
|
||||
| a6 | 新建 `agent/session.rs` 定义 `AgentSession` + `submit_turn`(内联 `HashMap<String,String>` 做 session_data,不引 MemoryStore) | `src/agent/session.rs` | `cargo build` 通过 |
|
||||
| a7 | 新建 `agent/task.rs` 定义 `Plan` / `Step` / `StepStatus` 纯数据结构(不含 TaskAgent trait,不含 PlanParser) | `src/agent/task.rs` | `cargo build` 通过 |
|
||||
| a8 | 新建 `src/agent.rs` 模块根 + `pub use` 重导出 + 修改 `lib.rs` | `src/agent.rs` + `src/lib.rs` | `cargo build` 通过 |
|
||||
| a9 | 编写烟雾测试 3-4 个(Agent trait 可装配 / RuntimeBundle 可构造 / submit_turn 跑通 mock / Plan 数据结构) | `src/agent/*.rs` 内联 | `cargo test` 通过 |
|
||||
| a10 | 完整 `cargo test` 跑全量回归 + roadmap.md 状态更新 | — | 所有已有测试不挂 |
|
||||
|
||||
**依赖关系**:
|
||||
|
||||
```
|
||||
hooks扩展 (a1) ──┐
|
||||
├──► agent/error.rs (a2) ──► agent/agent.rs (a3)
|
||||
│ │
|
||||
│ ▼
|
||||
│ agent/runtime.rs (a4)
|
||||
│ │
|
||||
│ ▼
|
||||
│ agent/builder.rs (a5)
|
||||
│ │
|
||||
│ ▼
|
||||
│ agent/session.rs (a6)
|
||||
│ │
|
||||
│ ▼
|
||||
│ agent/task.rs (a7) [纯数据]
|
||||
│ │
|
||||
└──────────────────► src/agent.rs + lib.rs (a8)
|
||||
│
|
||||
▼
|
||||
cargo test (a9 → a10)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 4.3 Phase 4b — 任务执行
|
||||
|
||||
**范围**:TaskAgent trait + PlanParser trait + JsonPlanParser 参考实现 + OnPlanStepComplete hook + AgentError PlanParse 变体
|
||||
|
||||
**前置条件**:Phase 4a 已完成并交付。
|
||||
|
||||
**任务拆解**:
|
||||
|
||||
| 顺序 | 任务 | 涉及文件 | 验证 |
|
||||
|------|------|---------|------|
|
||||
| b1 | 修改 `llm/hooks.rs` 追加 OnPlanStepComplete + plan_step_index 字段 | `src/llm/hooks.rs` | `cargo build` 通过;Phase 0 + 4a 测试不挂 |
|
||||
| b2 | `agent/error.rs` 追加 PlanParse 变体 | `src/agent/error.rs` | `cargo build` 通过 |
|
||||
| b3 | `agent/task.rs` 追加 `TaskAgent` trait + `PlanParser` trait + `JsonPlanParser` 参考实现 | `src/agent/task.rs` | `cargo build` 通过 |
|
||||
| b4 | 更新 `agent.rs` 模块根重导出(如有新增公开类型) | `src/agent.rs` | `cargo build` 通过 |
|
||||
| b5 | 编写烟雾测试 2-3 个(TaskAgent mock 执行 / JsonPlanParser 解析 / PlanParse 错误) | `src/agent/task.rs` 内联 | `cargo test` 通过 |
|
||||
| b6 | 完整 `cargo test` 跑全量回归 + roadmap.md 状态更新 | — | 所有已有测试不挂 |
|
||||
|
||||
**依赖关系**:
|
||||
|
||||
```
|
||||
hooks扩展 (b1) ──┐
|
||||
├──► error.rs 追加 (b2) ──► task.rs 追加 (b3)
|
||||
│
|
||||
▼
|
||||
agent.rs 更新 (b4)
|
||||
│
|
||||
▼
|
||||
cargo test (b5 → b6)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 4.4 Phase 4c — 会话级记忆
|
||||
|
||||
**范围**:SessionMemory struct(基于 MemoryStore)+ RuntimeBundle/AgentBuilder 扩展 session_memory_backend + AgentSession 接入(替换内联 HashMap)
|
||||
|
||||
**前置条件**:Phase 4a 已完成并交付(可在 4b 之前、之后或并行实施)。
|
||||
|
||||
**任务拆解**:
|
||||
|
||||
| 顺序 | 任务 | 涉及文件 | 验证 |
|
||||
|------|------|---------|------|
|
||||
| c1 | 新建 `agent/session_memory.rs` 定义 `SessionMemory`(基于 `MemoryStore`,namespace 隔离) | `src/agent/session_memory.rs` | `cargo build` 通过 |
|
||||
| c2 | `agent/runtime.rs` 追加 `session_memory_backend` 字段到 `RuntimeBundle` | `src/agent/runtime.rs` | `cargo build` 通过 |
|
||||
| c3 | `agent/builder.rs` 追加 `.session_memory_backend()` 方法 | `src/agent/builder.rs` | `cargo build` 通过 |
|
||||
| c4 | `agent/session.rs` 替换内联 HashMap 为完整 `SessionMemory` + 更新模块根重导出 | `src/agent/session.rs` + `src/agent.rs` | `cargo build` 通过 |
|
||||
| c5 | 编写烟雾测试 2-3 个(SessionMemory set/get/snapshot 基于 InMemoryStore) | `src/agent/session_memory.rs` 内联 | `cargo test` 通过 |
|
||||
| c6 | 完整 `cargo test` 跑全量回归 + roadmap.md 状态更新 | — | 所有已有测试不挂 |
|
||||
|
||||
**依赖关系**:
|
||||
|
||||
```
|
||||
session_memory.rs (c1) ──► runtime.rs 追加 (c2) ──► builder.rs 追加 (c3)
|
||||
│
|
||||
▼
|
||||
session.rs 修改 (c4)
|
||||
│
|
||||
▼
|
||||
cargo test (c5 → c6)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 4.5 预估工作量(按子阶段)
|
||||
|
||||
| 子阶段 | 文件 | 行数 | 说明 |
|
||||
|--------|------|------|------|
|
||||
| **Phase 4a** | hooks 扩展(2 事件 + 1 字段) | ~10 | 追加变体 + 字段 + 文档 |
|
||||
| | agent/error.rs | ~40 | AgentError 枚举 + From + is_recoverable |
|
||||
| | agent/agent.rs | ~30 | Agent trait + docs |
|
||||
| | agent/runtime.rs | ~60 | RuntimeBundle + AgentConfig |
|
||||
| | agent/builder.rs | ~60 | 链式构造 + build 校验 |
|
||||
| | agent/session.rs | ~100 | AgentSession + submit_turn + 内联 HashMap |
|
||||
| | agent/task.rs(纯数据) | ~40 | Plan / Step / StepStatus |
|
||||
| | src/agent.rs + lib.rs | ~20 | 模块根 + 导出 |
|
||||
| | 烟雾测试 | ~80 | 3-4 个测试 |
|
||||
| | **小计** | **~440** | **核心胶水层** |
|
||||
| **Phase 4b** | hooks 扩展(1 事件 + 1 字段) | ~5 | OnPlanStepComplete + plan_step_index |
|
||||
| | error.rs 追加 PlanParse | ~5 | 1 个变体 |
|
||||
| | task.rs 追加(TaskAgent + PlanParser + JsonPlanParser) | ~130 | trait + 参考实现 + docs |
|
||||
| | 烟雾测试 | ~60 | 2-3 个测试 |
|
||||
| | **小计** | **~200** | **任务执行** |
|
||||
| **Phase 4c** | session_memory.rs | ~40 | 5 个方法 + docs |
|
||||
| | runtime.rs / builder.rs / session.rs 修改 | ~35 | 追加字段 + setter + 替换 HashMap |
|
||||
| | 烟雾测试 | ~40 | 2-3 个测试 |
|
||||
| | **小计** | **~115** | **会话级记忆** |
|
||||
| **合计** | | **~755** | 与原始预估 ~800 基本持平 |
|
||||
|
||||
## 5. 风险评估
|
||||
|
||||
### 5.1 抽象化边界(核心风险)
|
||||
|
||||
**风险描述**:Phase 4 容易"过度抽象"——参考了 OpenHarness / Hermes 后,倾向于把它们的核心能力都搬到 Rust core 库里。
|
||||
|
||||
**缓解措施**:
|
||||
- 严格遵循 §1.3 的 7 条设计原则
|
||||
- 每次添加新 trait / struct 前,先问"这属于 core 库职责吗?"
|
||||
- 业务能力(Plan 拆解、多 Agent 协同、技能加载)一律走 trait 注入或 v0.2+ 延后
|
||||
|
||||
### 5.2 对 Phase 0-3 的侵入风险
|
||||
|
||||
**风险描述**:为实现 Phase 4 需修改 `src/llm/hooks.rs`,可能破坏 Phase 0 的现有测试。
|
||||
|
||||
**缓解措施**:
|
||||
- 只追加 enum 变体和 `Option<T>` 字段(NF9)
|
||||
- 顺序:先跑 `cargo test` 确认 Phase 0 测试不挂,再开始 Phase 4
|
||||
- 详细回归验证:实施完毕后跑全量 `cargo test`
|
||||
|
||||
### 5.3 参考项目语言差异
|
||||
|
||||
**风险描述**:OpenClaw / Hermes / OpenHarness 均为 Python/TypeScript,OpenHuman 虽是 Rust + Tauri 但定位是桌面应用。直接照搬接口形状可能导致 Rust 借用检查问题、async 复杂度增加。
|
||||
|
||||
**缓解措施**:
|
||||
- §1.3 明确"借鉴不照搬"
|
||||
- 反模式列表(见 `docs/note-agent-harness-references.md` §6)作为排除项
|
||||
- 接口设计优先考虑 Rust 惯例(`Arc<dyn Trait>` / `async fn` / `Result<T, E>`)
|
||||
|
||||
### 5.4 trait 设计的稳定性风险
|
||||
|
||||
**风险描述**:Phase 4 是 v0.1 的第一个"复杂 trait 集合",如果 trait 形状不稳定,v0.2+ 添加新能力时会 breaking。
|
||||
|
||||
**缓解措施**:
|
||||
- §3.2 的所有 trait / struct 在 `docs/note-agent-runtime-design.md` §4 已固化草案
|
||||
- 实施时如需调整,应先更新决策记录再改代码
|
||||
- 预留扩展点:`Agent::tool_definitions` 的默认实现可被子 trait 覆盖
|
||||
|
||||
### 5.5 实施进度风险
|
||||
|
||||
**风险描述**:拆为 3 个子阶段后每个阶段任务量降低(4a 约 440 行、4b 约 200 行、4c 约 115 行),但阶段间衔接(4b/4c 对 4a 的依赖)可能产生等待。
|
||||
|
||||
**缓解措施**:
|
||||
- 每个子阶段独立验证,完成即交付,不阻塞后续阶段
|
||||
- 4b 和 4c 无相互依赖,可并行开工
|
||||
- 烟雾测试只验证"能跑通"不验证"业务正确"——避免陷入业务循环的细节
|
||||
- 必要时先做 `MockProvider`(Phase 0 已有模式),不依赖真实 LLM
|
||||
|
||||
## 6. 验收标准
|
||||
|
||||
### 6.1 通用代码验收(每个子阶段必须满足)
|
||||
|
||||
- [ ] `cargo build --release` 0 错误,`cargo clippy` 0 警告
|
||||
- [ ] `cargo test` 所有已有测试 + 本阶段新增测试全部通过
|
||||
- [ ] `cargo doc --no-deps` 所有公开 API 有 `///` 文档注释
|
||||
- [ ] `src/llm/hooks.rs` 仅追加(不修改现有变体或字段)
|
||||
|
||||
### 6.2 Phase 4a 验收
|
||||
|
||||
#### 6.2a 代码验收
|
||||
|
||||
- [ ] 新增代码 ~440 行(含测试 + 文档注释),与 §4.5 预估一致
|
||||
- [ ] `src/lib.rs` 新增一行 `pub mod agent;`
|
||||
- [ ] 新增文件:`agent.rs` / `agent/agent.rs` / `agent/runtime.rs` / `agent/builder.rs` / `agent/session.rs` / `agent/task.rs` / `agent/error.rs`(共 7 个文件;`agent/session_memory.rs` 属于 Phase 4c,不计入)
|
||||
|
||||
#### 6.2b 接口验收
|
||||
|
||||
- [ ] `Agent` trait 包含 `name` / `system_prompt` / `tool_definitions` 三个方法
|
||||
- [ ] `RuntimeBundle` 包含 6 个字段:provider / tool_registry / hook_executor / memory_store? / retriever? / config(不含 session_memory_backend)
|
||||
- [ ] `AgentBuilder` 提供 5 个 setter(不含 session_memory_backend)+ `build()` 校验
|
||||
- [ ] `AgentSession` 持 `Arc<dyn Agent>` 而非 `agent_name: String`
|
||||
- [ ] `AgentSession::submit_turn` 实现约 30 行,含 OnTurnStart/End hook 触发
|
||||
- [ ] `AgentSession` 用内联 `HashMap<String, String>` 做 session_data(不引 `MemoryStore`)
|
||||
- [ ] `Plan` / `Step` / `StepStatus` 纯数据结构存在,状态机正确
|
||||
- [ ] `AgentError` 聚合 7 个变体:Llm / Tool / Memory / HookBlocked / LimitExceeded / Config / Other(不含 PlanParse)
|
||||
- [ ] `AgentError::is_recoverable()` 对各变体返回正确分类
|
||||
- [ ] `HookEvent` 新增 2 个变体:`OnTurnStart` / `OnTurnEnd`
|
||||
- [ ] `HookContext` 新增 1 个 `Option` 字段:`turn_index`
|
||||
|
||||
#### 6.2c 测试验收
|
||||
|
||||
- [ ] **测试 1**:`Agent` trait 可实现 + `RuntimeBundle` 可构造(builder 链式调用)
|
||||
- [ ] **测试 2**:`AgentSession::submit_turn` 跑通 mock provider(Phase 0 `MockProvider` 模式)
|
||||
- [ ] **测试 3**:`Plan` / `Step` / `StepStatus` 状态机转换正确
|
||||
- [ ] **测试 4(可选)**:session_data set/get 基本读写
|
||||
|
||||
#### 6.2d 行为验收
|
||||
|
||||
- [ ] `AgentSession::submit_turn` 不持有 `ConversationMemory`(grep 验证无 `use crate::memory::ConversationMemory`)
|
||||
- [ ] `AgentSession` 持 `Arc<dyn Agent>`,可从 agent 获取 `system_prompt()` / `tool_definitions()`
|
||||
- [ ] `RuntimeBundle::new` 当 `retriever` 为 `Some` 时自动注册 `"retrieve"` tool
|
||||
- [ ] `AgentBuilder::build` 在必填字段缺失时返回 `AgentError::Config`(而非 panic)
|
||||
|
||||
---
|
||||
|
||||
### 6.3 Phase 4b 验收
|
||||
|
||||
#### 6.3a 代码验收
|
||||
|
||||
- [ ] 追加代码 ~200 行(增量,在 Phase 4a 基础上),与 §4.5 预估一致
|
||||
- [ ] `src/llm/hooks.rs` 追加 OnPlanStepComplete + plan_step_index(不修改 Phase 4a 新增内容)
|
||||
|
||||
#### 6.3b 接口验收
|
||||
|
||||
- [ ] `TaskAgent` trait 提供双入口 `run(goal)` + `execute_plan(plan)`
|
||||
- [ ] `PlanParser` trait 可注入,`JsonPlanParser` 参考实现基于 `serde_json`(~20 行)
|
||||
- [ ] `AgentError` 追加 PlanParse 变体(共 8 个变体)
|
||||
- [ ] `HookEvent` 追加 1 个变体:`OnPlanStepComplete`
|
||||
- [ ] `HookContext` 追加 1 个 `Option` 字段:`plan_step_index`
|
||||
|
||||
#### 6.3c 测试验收
|
||||
|
||||
- [ ] **测试 1**:`TaskAgent::execute_plan` 跑通 mock provider
|
||||
- [ ] **测试 2**:`JsonPlanParser::parse` 能解析合法 JSON,失败时返回 `AgentError::PlanParse`
|
||||
- [ ] **测试 3(可选)**:`OnPlanStepComplete` hook 触发正确
|
||||
|
||||
---
|
||||
|
||||
### 6.4 Phase 4c 验收
|
||||
|
||||
#### 6.4a 代码验收
|
||||
|
||||
- [ ] 追加代码 ~115 行(增量,在 Phase 4a 基础上),与 §4.5 预估一致
|
||||
- [ ] 新增文件:`agent/session_memory.rs`
|
||||
|
||||
#### 6.4b 接口验收
|
||||
|
||||
- [ ] `SessionMemory` 包含 5 个方法(set / get / snapshot / remove / clear),基于 `MemoryStore` 实现
|
||||
- [ ] `SessionMemory::snapshot` 返回 `<session-context>` 标签包裹的格式化文本
|
||||
- [ ] `RuntimeBundle` 追加 `session_memory_backend: Option<Arc<dyn MemoryStore>>` 字段
|
||||
- [ ] `AgentBuilder` 追加 `.session_memory_backend()` setter
|
||||
- [ ] `AgentSession` 替换内联 HashMap 为完整 `SessionMemory`,含 `session_memory: SessionMemory` 字段
|
||||
- [ ] `SessionMemory` 在 `session_memory_backend` 未传入时自动使用 `InMemoryStore` 兜底
|
||||
|
||||
#### 6.4c 测试验收
|
||||
|
||||
- [ ] **测试 1**:`SessionMemory` set / get / snapshot 基本读写(基于 `InMemoryStore`)
|
||||
- [ ] **测试 2**:session_data 内联 HashMap ↔ SessionMemory 替换后 submit_turn 行为不变
|
||||
|
||||
---
|
||||
|
||||
### 6.5 文档验收
|
||||
|
||||
- [ ] `docs/7-agent-runtime.md`(本文件)完整,6 段式结构齐备
|
||||
- [ ] `docs/note-agent-runtime-design.md` 与本文件互相引用一致
|
||||
- [ ] `docs/note-agent-harness-references.md` 与本文件互相引用一致
|
||||
- [ ] `docs/roadmap.md` 各子阶段状态按阶段翻转
|
||||
|
||||
### 6.6 风险验收
|
||||
|
||||
- [ ] 5.1 抽象化边界:交付物列表中**不包含** Multi-Agent / Skills / TUI / Gateway 等应用层能力
|
||||
- [ ] 5.2 Phase 0-3 侵入:`git diff` 显示 `src/llm/hooks.rs` 仅追加
|
||||
- [ ] 5.3 语言差异:trait 形状符合 Rust 惯例(无 Python 风格的复杂继承)
|
||||
- [ ] 5.4 trait 稳定性:决策记录与最终代码一致
|
||||
- [ ] 5.5 实施进度:每个子阶段实际工作量与 §4.5 预估偏差 < 30%
|
||||
|
||||
## 7. 一句话总结
|
||||
|
||||
> **Phase 4 = 3 个子阶段:4a(核心胶水层:Agent + AgentSession + submit_turn + RuntimeBundle + Plan/Step 纯数据,~440 行)→ 4b(任务执行:TaskAgent + PlanParser/JsonPlanParser,~200 行)→ 4c(会话级记忆:SessionMemory + 接入,~115 行),合计 ~755 行,分步交付、逐段验证,把 Phase 0-3 已有能力"装配"成"智能体"的概念。**
|
||||
@@ -0,0 +1,434 @@
|
||||
# 示例程序新增方案
|
||||
|
||||
> 作者:Proposal Agent
|
||||
> 日期:2026-06-11
|
||||
> 对应版本:agcore v0.1
|
||||
|
||||
## 背景与目标
|
||||
|
||||
### 问题
|
||||
|
||||
当前 `examples/` 目录下只有一个 `simple_visit.rs`,仅演示了 `LlmCycle::submit()` 的基本 LLM 调用(Phase 0),且依赖真实 API key 才能运行。v0.1 已实现的全部 7 个 Phase 的能力(Phase 0~4c)缺乏可运行、可独立验证的示例展示。
|
||||
|
||||
### 目标
|
||||
|
||||
1. **覆盖全 Phase** — 每个 Phase 核心能力至少有一个示例
|
||||
2. **可离线运行** — 优先选用 MockProvider 和本地逻辑,不强制依赖 API key
|
||||
3. **真实使用模式** — 示例反映库的预期使用方式(Builder 模式、trait 实现、? 错误传播)
|
||||
4. **验收辅助** — 示例跑通 = 对应模块公共 API 可用且装配正确
|
||||
|
||||
### 非目标
|
||||
|
||||
- 不替代单元测试的边界覆盖(内联测试仍负责边界条件)
|
||||
- 不引入第三方依赖(示例只使用 `agcore` 公开 API)
|
||||
- 不追求 UI 或交互式输入
|
||||
|
||||
---
|
||||
|
||||
## 当前状态分析
|
||||
|
||||
```text
|
||||
examples/
|
||||
└── simple_visit.rs # 仅 Phase 0 基础调用,需 API key
|
||||
```
|
||||
|
||||
### 现有示例覆盖缺口
|
||||
|
||||
| Phase | 模块 | 示例覆盖 | 缺口 |
|
||||
|-------|------|---------|------|
|
||||
| Phase 0 | LLM 调用周期 | `simple_visit.rs` | 流式事件、重试逻辑、Auto-compaction 未演示 |
|
||||
| Phase 1 | 提示词工程 | ❌ | 模板变量插值、消息组合、条件渲染 |
|
||||
| Phase 2 | 工具系统 | ❌ | 自定义工具注册、并行调用、权限检查 |
|
||||
| Phase 3 | 记忆系统 | ❌ | 对话记忆滑动窗口、知识页面存储、关键词检索 |
|
||||
| Phase 4a | 核心胶水层 | ❌ | Agent/AgentSession/RuntimeBundle/AgentBuilder 装配 |
|
||||
| Phase 4b | 任务执行 | ❌ | PlanParser/Step 状态机/TaskAgent |
|
||||
| Phase 4c | 会话级记忆 | ❌ | SessionMemory set/get/snapshot |
|
||||
|
||||
---
|
||||
|
||||
## 设计方案
|
||||
|
||||
### 总体架构
|
||||
|
||||
新增示例按三层优先级组织,每个示例为一个独立 `.rs` 文件,统一放在 `examples/` 目录下。
|
||||
|
||||
```
|
||||
examples/
|
||||
├── simple_visit.rs # [已有] 基本 LLM 调用(Phase 0)
|
||||
├── prompt_composer.rs # [新增] 提示词组合(Phase 1)🥇
|
||||
├── custom_tool.rs # [新增] 自定义工具(Phase 2)🥇
|
||||
├── agent_session_demo.rs # [新增] Agent 会话(Phase 4a+4c)🥇
|
||||
├── task_agent_demo.rs # [新增] 任务规划(Phase 4b)🥇
|
||||
├── conversation_memory_demo.rs # [新增] 对话记忆(Phase 3)🥈
|
||||
├── knowledge_search_demo.rs # [新增] 知识检索(Phase 3)🥈
|
||||
├── streaming_events_demo.rs # [新增] 流式事件(Phase 0)🥈
|
||||
└── full_integration.rs # [新增] 全栈集成(Phase 全栈)🥉
|
||||
```
|
||||
|
||||
### 详细设计
|
||||
|
||||
#### 🥇 示例:`prompt_composer.rs`(Phase 1)
|
||||
|
||||
**设计思路**:纯本地运行,不依赖任何外部服务。通过构造模板、组合消息来验证 Prompt Engineering 模块的公共 API。
|
||||
|
||||
**流程**:
|
||||
```
|
||||
TemplateContext 构造 → PromptTemplate 填充变量 → PromptComposer 构建消息链 → 断言验证
|
||||
```
|
||||
|
||||
**关键代码片段示意**:
|
||||
```rust
|
||||
// 1. 构造模板
|
||||
let mut registry = PromptTemplateRegistry::new();
|
||||
registry.register(PromptTemplate::new("weather", "今日{location}天气:{condition},温度{temperature}"));
|
||||
|
||||
// 2. 填充变量
|
||||
let template = registry.get("weather").unwrap();
|
||||
let rendered = template.render(&TemplateContext::from([
|
||||
("location", "北京"),
|
||||
("condition", "晴"),
|
||||
("temperature", "25°C"),
|
||||
]))?;
|
||||
|
||||
// 3. 组合消息
|
||||
let composer = PromptComposer::new()
|
||||
.system("你是一个天气助手")
|
||||
.user(rendered)
|
||||
.assistant(/* 可选历史 */);
|
||||
|
||||
let messages = composer.compose();
|
||||
assert_eq!(messages.len(), 2);
|
||||
```
|
||||
|
||||
**验证点**:
|
||||
- `TemplateContext` 变量插值正确
|
||||
- `PromptComposer` 消息顺序正确
|
||||
- `PromptError` 在缺失变量时正确返回
|
||||
|
||||
**新增代码量**:约 60 行
|
||||
|
||||
---
|
||||
|
||||
#### 🥇 示例:`custom_tool.rs`(Phase 2)
|
||||
|
||||
**设计思路**:实现一个模拟工具(如 `WeatherTool`),注册到 `ToolRegistry`,演示单次调用、并行调用、权限检查。
|
||||
|
||||
**流程**:
|
||||
```
|
||||
实现 BaseTool → 注册到 ToolRegistry → invoke 单次 → invoke_all 并行 → PermissionChecker 白名单过滤
|
||||
```
|
||||
|
||||
**关键代码片段示意**:
|
||||
```rust
|
||||
// 1. 实现工具
|
||||
struct WeatherTool;
|
||||
#[async_trait]
|
||||
impl BaseTool for WeatherTool {
|
||||
fn name(&self) -> &str { "get_weather" }
|
||||
fn parameters(&self) -> Value { json!({"type":"object","properties":{"city":{"type":"string"}}}) }
|
||||
async fn execute(&self, args: Value, _ctx: &ToolContext) -> Result<Value, ToolError> {
|
||||
Ok(json!({"city": args["city"], "temperature": 22, "condition": "晴"}))
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 注册 + 调用
|
||||
let mut registry = ToolRegistry::new();
|
||||
registry.register(WeatherTool.into())?;
|
||||
let result = registry.invoke("get_weather", json!({"city": "北京"})).await?;
|
||||
|
||||
// 3. 并行调用
|
||||
let results = registry.invoke_all(vec![...], 30).await;
|
||||
|
||||
// 4. 权限检查
|
||||
let checker = PermissionChecker::new(PermissionConfig::white_list(vec!["get_weather"]));
|
||||
assert!(checker.check("get_weather").is_ok());
|
||||
assert!(checker.check("delete_file").is_err());
|
||||
```
|
||||
|
||||
**验证点**:
|
||||
- 工具注册/查找/调用完整链路
|
||||
- 并行调用结果数正确
|
||||
- 权限白名单/黑名单行为
|
||||
- `ToolError::NotFound` 未注册工具
|
||||
|
||||
**新增代码量**:约 80 行
|
||||
|
||||
---
|
||||
|
||||
#### 🥇 示例:`agent_session_demo.rs`(Phase 4a + 4c)
|
||||
|
||||
**设计思路**:使用 `MockProvider` 模拟 LLM 响应,完整演示 `Agent → AgentBuilder → RuntimeBundle → AgentSession` 的装配流程及 `SessionMemory` 的读写。
|
||||
|
||||
**流程**:
|
||||
```
|
||||
实现 Agent → AgentBuilder 构造 RuntimeBundle → AgentSession::new → submit_turn → session_data 读写 → snapshot 输出
|
||||
```
|
||||
|
||||
**关键代码片段示意**:
|
||||
```rust
|
||||
// 1. 定义 Agent
|
||||
struct CalculatorAgent;
|
||||
impl Agent for CalculatorAgent {
|
||||
fn name(&self) -> &str { "calculator" }
|
||||
fn system_prompt(&self) -> Option<&str> { Some("你是计算器助手") }
|
||||
}
|
||||
|
||||
// 2. 装配 RuntimeBundle
|
||||
let bundle = AgentBuilder::new()
|
||||
.provider(Arc::new(mock_provider))
|
||||
.tool_registry(Arc::new(tool_registry))
|
||||
.hook_executor(Arc::new(hook_executor))
|
||||
.build()?;
|
||||
|
||||
// 3. 创建会话
|
||||
let mut session = AgentSession::new(Arc::new(CalculatorAgent), "session-1", Arc::new(bundle));
|
||||
|
||||
// 4. 提交对话
|
||||
let response = session.submit_turn("1+1=?").await?;
|
||||
|
||||
// 5. SessionMemory 读写
|
||||
session.set_session_data("last_result", "2").await?;
|
||||
let result = session.get_session_data("last_result").await?;
|
||||
println!("{}", session.session_memory().snapshot().await?);
|
||||
```
|
||||
|
||||
**验证点**:
|
||||
- `AgentBuilder::build()` 必填字段校验
|
||||
- `submit_turn` 流程完整(hook 触发、cost 累计、turn_index 递增)
|
||||
- `SessionMemory` set/get/snapshot 正确
|
||||
- 多个 session 间数据隔离
|
||||
|
||||
**新增代码量**:约 100 行
|
||||
|
||||
---
|
||||
|
||||
#### 🥇 示例:`task_agent_demo.rs`(Phase 4b)
|
||||
|
||||
**设计思路**:使用 `JsonPlanParser` 从预定义 JSON 解析 Plan,驱动 Step 状态机转换,观察状态单向流转。
|
||||
|
||||
**流程**:
|
||||
```
|
||||
构造 JSON 输入 → JsonPlanParser::parse → Plan 数据结构 → 模拟 execute_plan → Step 状态变迁 → Hook 事件
|
||||
```
|
||||
|
||||
**关键代码片段示意**:
|
||||
```rust
|
||||
// 1. 解析 Plan
|
||||
let parser = JsonPlanParser;
|
||||
let input = r#"{"steps": [{"description": "查天气"}, {"description": "算结果"}]}"#;
|
||||
let mut plan = parser.parse(input, "完成今日任务").await?;
|
||||
|
||||
// 2. 模拟 step 执行
|
||||
assert!(plan.steps[0].status.is_pending());
|
||||
step.status = StepStatus::Running;
|
||||
step.status = StepStatus::Completed(response);
|
||||
assert!(step.status.is_terminal());
|
||||
|
||||
// 3. 失败路径
|
||||
step.status = StepStatus::Failed(AgentError::Other("API 不可用".into()));
|
||||
assert!(step.status.is_terminal());
|
||||
```
|
||||
|
||||
**验证点**:
|
||||
- 合法 JSON 解析正确
|
||||
- 非法 JSON / 空步骤 / 缺字段返回 `AgentError::PlanParse`
|
||||
- 状态机单向转换(Pending → Running → Completed/Failed/Skipped)
|
||||
- `is_terminal()` / `is_pending()` 语义正确
|
||||
|
||||
**新增代码量**:约 70 行
|
||||
|
||||
---
|
||||
|
||||
#### 🥈 示例:`conversation_memory_demo.rs`(Phase 3)
|
||||
|
||||
**设计思路**:演示 `ConversationMemory` 的多轮消息写入、滑动窗口淘汰、冷热分离存储。
|
||||
|
||||
**流程**:
|
||||
```
|
||||
ConversationMemory::new → add_message × N → 触发窗口淘汰 → get_history 验证 → MemoryStore 持久化读取
|
||||
```
|
||||
|
||||
**关键代码片段示意**:
|
||||
```rust
|
||||
let config = ConversationMemoryConfig {
|
||||
strategy: MemoryStrategy::SlidingWindow,
|
||||
max_turns: 5,
|
||||
..Default::default()
|
||||
};
|
||||
let mut memory = ConversationMemory::new(store, "session-1", config);
|
||||
|
||||
// 写入 10 条消息
|
||||
for i in 0..10 {
|
||||
memory.add_message(OpenaiChatMessage::user_text(format!("消息 {i}"))).await?;
|
||||
}
|
||||
|
||||
// 验证窗口大小为 5
|
||||
let history = memory.get_history().await?;
|
||||
assert_eq!(history.len(), 5);
|
||||
assert!(history[0].content().contains("消息 5"));
|
||||
```
|
||||
|
||||
**验证点**:
|
||||
- 滑动窗口淘汰旧消息
|
||||
- Full 策略保留全部消息
|
||||
- 冷存储 `MemoryStore` 写入/读取正确
|
||||
- `CompactConfig` 触发自动压缩
|
||||
|
||||
**新增代码量**:约 70 行
|
||||
|
||||
---
|
||||
|
||||
#### 🥈 示例:`knowledge_search_demo.rs`(Phase 3)
|
||||
|
||||
**设计思路**:演示 `KnowledgeStore` 页面存储 + `MemoryRetriever` 关键词检索与 Dice 系数评分。
|
||||
|
||||
**流程**:
|
||||
```
|
||||
KnowledgeStore 创建页面 → MemoryRetriever::search → 评分排序结果输出 → 阈值过滤观察
|
||||
```
|
||||
|
||||
**关键代码片段示意**:
|
||||
```rust
|
||||
let store = KnowledgeStore::new(memory_store);
|
||||
store.save_page("Rust 入门", "Rust 是一门系统编程语言...", vec!["rust", "编程"]).await?;
|
||||
store.save_page("Python 简介", "Python 是动态类型语言...", vec!["python", "动态"]).await?;
|
||||
|
||||
let retriever = MemoryRetriever::new(store, RetrieverConfig::default());
|
||||
let result = retriever.search("Rust 语言").await?;
|
||||
|
||||
for item in &result.items {
|
||||
println!(" 页面: {} (评分: {:.2})", item.page.title, item.score);
|
||||
assert!(item.score >= 0.0 && item.score <= 1.0);
|
||||
}
|
||||
```
|
||||
|
||||
**验证点**:
|
||||
- `KnowledgeStore` 页面存/取/搜索正确
|
||||
- `TextOverlap` Dice 系数在 [0.0, 1.0] 范围内
|
||||
- 停用词过滤正常
|
||||
- 低于 `min_score` 的结果被过滤
|
||||
|
||||
**新增代码量**:约 60 行
|
||||
|
||||
---
|
||||
|
||||
#### 🥈 示例:`streaming_events_demo.rs`(Phase 0 — 流式接口)
|
||||
|
||||
**设计思路**:调用 `LlmCycle::submit_stream()` 获取事件流,展示语义事件的消费模式。可选使用 API key 或 MockProvider。
|
||||
|
||||
**流程**:
|
||||
```
|
||||
LlmCycle::submit_stream → 事件循环 match StreamEvent → 输出类型/内容 → TurnComplete 收尾
|
||||
```
|
||||
|
||||
**关键代码片段示意**:
|
||||
```rust
|
||||
let mut cycle = LlmCycle::new(provider, config);
|
||||
let mut stream = cycle.submit_stream("讲个笑话".into(), vec![]).await?;
|
||||
|
||||
use futures_util::StreamExt;
|
||||
while let Some(event) = stream.next().await {
|
||||
match event {
|
||||
StreamEvent::AssistantTextDelta { text } => print!("{text}"),
|
||||
StreamEvent::TurnComplete { reason } => println!("\n\n完成,原因: {reason:?}"),
|
||||
StreamEvent::Error { message } => eprintln!("错误: {message}"),
|
||||
_ => {} // 其他事件
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**验证点**:
|
||||
- 流式链路完整(请求 → 事件 → 完成)
|
||||
- 事件枚举覆盖所有变体
|
||||
- 错误事件正确处理
|
||||
|
||||
**新增代码量**:约 80 行
|
||||
|
||||
---
|
||||
|
||||
#### 🥉 示例:`full_integration.rs`(Phase 全栈集成)
|
||||
|
||||
**设计思路**:端到端演示,将 v0.1 所有模块装配为一个可运行的智能体。需真实 API key。
|
||||
|
||||
**流程**:
|
||||
```
|
||||
创建 Agent(带 system prompt + 2 个工具)→ 注入 MemoryStore/Retriever → AgentSession → 多轮对话 → 知识检索 → SessionMemory 桥接 → 输出运行摘要
|
||||
```
|
||||
|
||||
**验证点**:
|
||||
- Phase 4 "胶水层"真正将 Phase 0~3 粘合
|
||||
- `submit_turn` 内部调用工具
|
||||
- `ConversationMemory` 回写
|
||||
- 全链路无类型/装配错误
|
||||
|
||||
**新增代码量**:约 150 行
|
||||
|
||||
---
|
||||
|
||||
## 实现计划
|
||||
|
||||
### 阶段一:第一梯队(优先级 🥇)
|
||||
|
||||
| 示例 | 预计代码量 | 可并行实施 |
|
||||
|------|-----------|-----------|
|
||||
| `prompt_composer.rs` | ~60 行 | ✅ 与 2/3/4 并行 |
|
||||
| `custom_tool.rs` | ~80 行 | ✅ 与 1/3/4 并行 |
|
||||
| `agent_session_demo.rs` | ~100 行 | ✅ 与 1/2/4 并行 |
|
||||
| `task_agent_demo.rs` | ~70 行 | ✅ 与 1/2/3 并行 |
|
||||
|
||||
**验证标准**:`cargo run --example <name>` 全部成功退出(code 0)。
|
||||
|
||||
### 阶段二:第二梯队(优先级 🥈)
|
||||
|
||||
| 示例 | 预计代码量 | 前置依赖 |
|
||||
|------|-----------|---------|
|
||||
| `conversation_memory_demo.rs` | ~70 行 | 无 |
|
||||
| `knowledge_search_demo.rs` | ~60 行 | 无 |
|
||||
| `streaming_events_demo.rs` | ~80 行 | 无 |
|
||||
|
||||
### 阶段三:第三梯队(优先级 🥉)
|
||||
|
||||
| 示例 | 预计代码量 | 前置依赖 |
|
||||
|------|-----------|---------|
|
||||
| `full_integration.rs` | ~150 行 | 需 `.env` 配置 API key |
|
||||
|
||||
### 总工作量估算
|
||||
|
||||
| 合计 | 代码行数 | 文件数 |
|
||||
|------|---------|-------|
|
||||
| 第一阶段 | ~310 行 | 4 个 |
|
||||
| 第二阶段 | ~210 行 | 3 个 |
|
||||
| 第三阶段 | ~150 行 | 1 个 |
|
||||
| **总计** | **~670 行** | **8 个文件** |
|
||||
|
||||
---
|
||||
|
||||
## 风险评估
|
||||
|
||||
| 风险 | 影响 | 概率 | 缓解措施 |
|
||||
|------|------|------|---------|
|
||||
| 示例与库 API 不同步(库重构后示例过时) | 高 | 中 | 将示例加入 CI:`cargo test --examples` |
|
||||
| MockProvider 行为与真实 Provider 差异 | 低 | 低 | 示例明确标注离线/在线模式 |
|
||||
| 示例代码量膨胀超过预期 | 低 | 低 | 每个示例控制在 200 行以内,超过则拆分子函数 |
|
||||
| `full_integration.rs` 依赖 API key,CI 会跳过 | 中 | 高 | 用 `#[cfg(not(ci))]` 或 `.env` 存在性判断优雅降级 |
|
||||
|
||||
---
|
||||
|
||||
## 验收标准
|
||||
|
||||
1. **阶段一全部完成时**:
|
||||
- `cargo run --example prompt_composer` → 成功退出
|
||||
- `cargo run --example custom_tool` → 成功退出
|
||||
- `cargo run --example agent_session_demo` → 成功退出
|
||||
- `cargo run --example task_agent_demo` → 成功退出
|
||||
|
||||
2. **阶段二全部完成时**:
|
||||
- 额外 3 个示例均可 `cargo run` 成功
|
||||
|
||||
3. **阶段三完成时**(可选):
|
||||
- `full_integration` 在有 `.env` 配置时成功运行,无配置时友好提示降级
|
||||
|
||||
4. **全局验收**:
|
||||
- `cargo build` 无新增警告
|
||||
- 所有示例输出格式清晰,有说明性 println
|
||||
- 每个示例在文件顶部有 `//!` 注释说明其演示目的
|
||||
@@ -0,0 +1,180 @@
|
||||
# Agent Harness 参考项目调研笔记
|
||||
|
||||
> 调研日期:2026-06-09
|
||||
> 用途:为 AG Core Phase 4(Agent Runtime)及后续 v0.2+ 扩展提供设计参考。
|
||||
> 关联:`docs/roadmap.md` Phase 4 / 扩展计划(v0.2+)小节。
|
||||
|
||||
本笔记调研了 4 个 2026 年公开的 AI Agent 项目,对比其核心架构与 AG Core 已完成模块的交集,作为 Phase 4 设计与未来扩展的输入。
|
||||
|
||||
> **重要事实**:4 个项目**均不是 Rust core 库**(其中 3 个为 Python / TypeScript 实现;OpenHuman 虽是 Rust + Tauri,但定位是 desktop 应用)。其价值不在"抄代码",而在参考**经过生产验证的 Agent Harness 架构模式**。AG Core 处于"core 库"层,是这些项目"最底层依赖"的角色。
|
||||
|
||||
---
|
||||
|
||||
## 1. 项目概览
|
||||
|
||||
| 项目 | 类型 | 语言 | GitHub | Stars(调研时) | 定位 |
|
||||
|------|------|------|--------|--------------|------|
|
||||
| **OpenClaw** | Gateway 网关 | TypeScript / Node 24 | `openclaw/openclaw` | — | 自托管消息平台 ↔ AI Agent 桥接 |
|
||||
| **Hermes Agent** | 自主学习智能体 | Python 3.11 | `NousResearch/hermes-agent` | — | 随使用成长的个人数字员工 |
|
||||
| **OpenHuman** | 桌面助手 | Rust + Tauri | `tinyhumansai/openhuman` | 2.3k+ | 记忆驱动的跨工具私人助理 |
|
||||
| **OpenHarness** | Agent Harness 框架 | Python | `HKUDS/OpenHarness` | 12.2k+ | 对标 Claude Code 的轻量级基础设施 |
|
||||
|
||||
## 2. 核心架构对照
|
||||
|
||||
| 维度 | OpenClaw | Hermes Agent | OpenHuman | OpenHarness |
|
||||
|------|----------|--------------|-----------|-------------|
|
||||
| **Agent Loop 形态** | 外部 Pi 二进制进程 | 内置 while 循环 | 内置循环 | 70 行 `run_query` |
|
||||
| **记忆模型** | 跨平台 session | MEMORY.md + 技能库 | Memory Tree(SQLite 分层摘要) | MEMORY.md + Auto-Compaction |
|
||||
| **工具机制** | MCP + 插件 | 40+ 内置技能 + 自动技能生成 | 118+ 集成 + Native Toolbelt | 43 工具 + BaseTool Pydantic |
|
||||
| **多 Agent** | 消息平台多 gateway | 并行子 Agent + RPC | Agent Coordination | Swarm 子代理委派 |
|
||||
| **权限/治理** | `allowFrom` + 提及规则 | 容器加固 + Cron 审批 | 本地优先 + 隐私 | 三级权限 + 钩子 |
|
||||
| **规划/任务** | 无显式规划 | 自然语言驱动 | 无显式 | 隐式(LLM 自我规划) |
|
||||
| **持久化** | 外部进程状态 | `~/.hermes/` 目录 | `~/.openhuman/` SQLite | MEMORY.md + state |
|
||||
| **Hook 体系** | 渠道适配器 | cron + 自定义钩子 | 集成触发 | PreToolUse / PostToolUse |
|
||||
| **干运行模式** | ❌ | ❌ | ❌ | ✅ `--dry-run` |
|
||||
| **流式 TUI** | ✅ 控制 UI | ✅ 完整 TUI | ✅ 桌面应用 | ✅ React/Ink |
|
||||
|
||||
## 3. 共同分层(5 层架构)
|
||||
|
||||
4 个项目都能切成这 5 层,**OpenHarness 的分层最清晰**:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ L4 扩展层 多 Agent / 渠道网关 / 插件 / 任务调度 │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ L3 治理层 权限 / Hook / 审批 / 安全策略 │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ L2 知识层 提示词工厂 / 技能库 / 持久记忆 / 摘要 │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ L1 执行层 Agent Loop / 工具注册 / 流式事件 / 重试 │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ L0 模型层 LLM Provider / Provider Registry / 鉴权 │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
**关键观察**:
|
||||
|
||||
- 4 个项目都假定 L0/L1 是"基础设施",不在 core 层重新发明
|
||||
- AG Core 在 Phase 0-3 已完成 L0 / L1 / L2(部分) / L3(部分)
|
||||
- Phase 4 处于 L1 与 L2 的衔接处
|
||||
- L4(Multi-Agent / Gateway)属于应用层,应由上层 crate / 二进制承担
|
||||
|
||||
## 4. AG Core 已具备的对应能力
|
||||
|
||||
对照 5 层架构,AG Core 已就绪情况:
|
||||
|
||||
| 层 | AG Core 对应 | 完成状态 | 文档 |
|
||||
|----|------------|---------|------|
|
||||
| **L0 模型层** | `llm::provider` + `llm::ProviderRegistry` | ✅ Phase 0 | `docs/2-llm-call-lifecycle.md` |
|
||||
| **L1 执行层** | `llm::cycle` + `llm::stream` + `tools::ToolRegistry` | ✅ Phase 0/2 | `docs/5-tool-system.md` |
|
||||
| **L2 知识层** | `prompt` + `memory::store/conversation/knowledge/retriever` | ✅ Phase 1/3 | `docs/4-prompt-engineering.md` / `docs/6-memory-system.md` |
|
||||
| **L3 治理层** | `llm::hooks` + `tools::PermissionChecker` | ✅ Phase 0/2 | `docs/3-phase0-remaining.md` |
|
||||
| **L1→L2 衔接(Agent Runtime)** | — | ❌ **Phase 4 待实现** | — |
|
||||
|
||||
## 5. 借鉴到 Phase 4 的核心模式
|
||||
|
||||
### 5.1 OpenHarness 风格:显式依赖注入容器
|
||||
|
||||
```rust
|
||||
// 核心思想:所有运行时依赖打包成一个对象,沿调用链显式传递
|
||||
pub struct RuntimeBundle {
|
||||
pub provider: Arc<dyn LlmProvider>,
|
||||
pub tool_registry: Arc<ToolRegistry>,
|
||||
pub hook_executor: Arc<HookExecutor>,
|
||||
// ... 可选:memory_store / retriever
|
||||
}
|
||||
```
|
||||
|
||||
**好处**:测试时可注入 mock bundle;支持同时跑多 session;依赖关系显式可追踪。
|
||||
|
||||
**AG Core 决策**:采纳,命名为 `agent::RuntimeBundle`(详见 Phase 4 设计决策记录)。
|
||||
|
||||
### 5.2 Hermes 风格:实体与会话解耦
|
||||
|
||||
```rust
|
||||
pub trait Agent: Send + Sync { /* 角色定义,不绑定 session */ }
|
||||
pub struct AgentSession { /* 绑定 session_id + bundle + 状态 */ }
|
||||
```
|
||||
|
||||
**好处**:同一 `Agent` 可被多个 `AgentSession` 复用(多用户、多会话);session 状态(cost、turn index)独立追踪。
|
||||
|
||||
**AG Core 决策**:采纳,详见 Phase 4 §接口签名草案。
|
||||
|
||||
### 5.3 OpenHarness 风格:Agent Loop 的极简本质
|
||||
|
||||
OpenHarness 的 `run_query` 核心只有 70 行,本质是一个 `while` 循环 + 一个 `if not tool_uses: return` 的判断。
|
||||
|
||||
**AG Core 现状**:`llm::cycle::submit_with_tools()` 已经在 Phase 2 末实现了这个循环,Phase 4 不应重新实现。
|
||||
|
||||
**AG Core 决策**:Phase 4 只在 `AgentSession::submit_turn()` 提供 30 行的 reference impl,组装 `LlmCycle` 并暴露其能力,业务循环留给上层。
|
||||
|
||||
### 5.4 OpenHuman 风格:分层摘要记忆树
|
||||
|
||||
OpenHuman 的 Memory Tree 创新点:
|
||||
- 多源数据(Gmail / Slack / GitHub 等)→ 规范化 Markdown → ≤3k token chunks → 打分 → 折叠成 per-source / per-topic / per-day 摘要树
|
||||
- 存储在本地 SQLite
|
||||
- Auto-fetch 每 20 分钟拉取新数据
|
||||
|
||||
**AG Core 现状**:`memory::KnowledgeStore` 已是 LLM Wiki 风格的抽象层。Phase 4 v1 不引入 SQLite 实现(属于 L4 应用层)。
|
||||
|
||||
**AG Core 决策**:KnowledgeGraph / RecallBased 淘汰等深度记忆能力推迟到 v0.2+ 再考虑,相关设计已记录在 `note-knowledge-graph-design.md`。
|
||||
|
||||
## 6. 反模式(不要照搬)
|
||||
|
||||
| 反模式 | 出现项目 | 不要照搬的理由 |
|
||||
|--------|---------|--------------|
|
||||
| 双进程架构(Node UI + Python 后端) | OpenClaw | 应用层架构,core 库不涉及 |
|
||||
| SQLite 持久化细节 | OpenHuman | 属于 L4 应用层具体实现 |
|
||||
| Pydantic 工具校验 | OpenHarness | Python 生态强项;Rust 已有 `serde_json::Value` + JSON Schema,足够 |
|
||||
| 43 工具内置 | OpenHarness | 应用层选型,core 库应保持"零内置工具" |
|
||||
| 单进程内多平台消息网关 | OpenClaw / Hermes | 属于 L4 应用层 |
|
||||
|
||||
## 7. 与 AG Core 现有模块的接口对齐
|
||||
|
||||
下表列出 4 个项目中被 AG Core **已经覆盖**或**即将在 Phase 4 覆盖**的能力,避免重复造轮子:
|
||||
|
||||
| 4 项目中的能力 | AG Core 对应 | 状态 |
|
||||
|--------------|-------------|------|
|
||||
| 工具注册表 | `tools::ToolRegistry` | ✅ Phase 2 已实现 |
|
||||
| 权限检查 | `tools::PermissionChecker` | ✅ Phase 2 已实现 |
|
||||
| 生命周期钩子 | `llm::HookExecutor` | ✅ Phase 0 已实现,Phase 4 扩展 3 个事件 |
|
||||
| 自动 tool 循环 | `llm::cycle::submit_with_tools()` | ✅ Phase 2 末已实现 |
|
||||
| Auto-Compaction | `llm::compact` | ✅ Phase 0 已实现 |
|
||||
| 对话记忆 | `memory::ConversationMemory` | ✅ Phase 3 已实现 |
|
||||
| 知识库 | `memory::KnowledgeStore` | ✅ Phase 3 已实现 |
|
||||
| 关键词检索 | `memory::MemoryRetriever` | ✅ Phase 3 已实现 |
|
||||
| 提示词模板 | `prompt::PromptTemplate` + `PromptComposer` | ✅ Phase 1 已实现 |
|
||||
| 用量追踪 | `llm::cycle::usage::CostTracker` | ✅ Phase 0 已实现 |
|
||||
| **显式依赖注入容器** | — | ⏳ **Phase 4 新增 `RuntimeBundle`** |
|
||||
| **Agent ↔ Session 分离** | — | ⏳ **Phase 4 新增 `Agent` + `AgentSession`** |
|
||||
| **任务规划** | — | ⏳ **Phase 4 新增 `TaskAgent` + `Plan`** |
|
||||
| **结构化 Plan 解析** | — | ⏳ **Phase 4 新增 `PlanParser` trait** |
|
||||
|
||||
## 8. v0.2+ 扩展项与参考项目的对应
|
||||
|
||||
`docs/roadmap.md` 扩展计划(v0.2+)表中的项,在 4 个项目中的对应实现:
|
||||
|
||||
| 扩展项 | OpenClaw | Hermes | OpenHuman | OpenHarness |
|
||||
|--------|----------|--------|-----------|-------------|
|
||||
| Multi-Agent / Swarm | ❌ | ✅ 并行子 Agent | ✅ Agent Coordination | ✅ Swarm |
|
||||
| Markdown 技能 | ❌ | ✅ SKILL.md | ❌ | ✅ prompts/*.md |
|
||||
| 多通道检索(vector + keyword) | ❌ | ❌ | ✅ Memory Tree | ❌ |
|
||||
| KnowledgeGraph | ❌ | ❌ | ✅ Memory Graph | ❌ |
|
||||
| TokenJuice 智能压缩 | ❌ | ✅ 轨迹压缩 | ✅ TokenJuice | ✅ Auto-Compaction |
|
||||
| TUI / Gateway | ✅ 控制 UI | ✅ 完整 TUI | ✅ 桌面应用 | ✅ React/Ink |
|
||||
| 训练 / RL 轨迹 | ❌ | ✅ Atropos | ❌ | ❌ |
|
||||
| 人类审批(Human-in-the-loop) | ❌ | ✅ Cron 审批 | ❌ | ✅ 权限弹窗 |
|
||||
|
||||
## 9. 参考资源
|
||||
|
||||
- **OpenClaw 文档**:<https://docs.openclaw.ai/zh-CN>
|
||||
- **Hermes Agent 官网**:<https://hermes-agent.org/zh/>
|
||||
- **Hermes Agent GitHub**:<https://github.com/NousResearch/hermes-agent>
|
||||
- **OpenHuman GitHub**:<https://github.com/tinyhumansai/openhuman>
|
||||
- **OpenHuman 中文站**:<https://openhumanai.cn/docs/>
|
||||
- **OpenHarness GitHub**:<https://github.com/HKUDS/OpenHarness>
|
||||
- **OpenHarness 深度学习笔记**:<https://www.joyehuang.me/blog/20260410---openharnessphase1/post>
|
||||
|
||||
## 10. 一句话总结
|
||||
|
||||
> **4 个项目都不在 L0/L1 重新发明轮子——它们都假定基础设施已就绪。AG Core 在 Phase 0-3 已经基本完成 L0–L3 这 4 层(其中 L2 / L3 为部分完成)。Phase 4 的核心价值是把它们"装配起来",同时为未来 v0.2+ 的 L4 扩展(Multi-Agent / Skills / TUI)留好接口。**
|
||||
@@ -0,0 +1,335 @@
|
||||
# Phase 4 Agent Runtime — 设计决策记录
|
||||
|
||||
> 决策固化日期:2026-06-09
|
||||
> 用途:记录 Phase 4 设计阶段的关键决策、接口签名草案、文件清单,作为 `docs/7-agent-runtime.md` 方案文档的输入约束。
|
||||
> 关联:
|
||||
> - `docs/7-agent-runtime.md` — 完整方案文档(待写 / 已写)
|
||||
> - `docs/note-agent-harness-references.md` — 参考项目调研(OpenClaw / Hermes / OpenHuman / OpenHarness)
|
||||
> - `docs/roadmap.md` — 项目总 Roadmap
|
||||
> - `docs/2-llm-call-lifecycle.md` / `3-phase0-remaining.md` / `4-prompt-engineering.md` / `5-tool-system.md` / `6-memory-system.md` — Phase 0-3 方案
|
||||
|
||||
本文件是 Phase 4 设计阶段的"事实基础"——所有决策都有明确的对话出处与依据。后续 Phase 4 实施时应与本记录保持一致;如需调整,应先更新本记录再改代码。
|
||||
|
||||
---
|
||||
|
||||
## 1. 设计目标
|
||||
|
||||
AG Core Phase 4 的定位是**「Phase 0-3 的薄胶水层 + 一组 trait 抽象」**,遵循 OpenHarness 的"显式依赖注入"模式 + Hermes 的"两层实体/会话"模型。**不**实现业务循环,**不**做产品级功能,**不**假设上层如何使用 memory。
|
||||
|
||||
## 2. 范围与边界
|
||||
|
||||
### 2.1 必须实现(12 项)
|
||||
|
||||
| # | 交付物 | 文件 | 关键决策 |
|
||||
|---|--------|------|---------|
|
||||
| 1 | `Agent` trait | `src/agent/agent.rs` | 角色定义:name / system_prompt / 工具集(不绑定 session,见 §4.1) |
|
||||
| 2 | `RuntimeBundle` | `src/agent/runtime.rs` | 依赖注入容器(OpenHarness 风格) |
|
||||
| 3 | `AgentSession` | `src/agent/session.rs` | 会话实例 + **最小 reference impl**(`submit_turn` ~30 行) |
|
||||
| 4 | `TaskAgent` + `Plan` / `Step` | `src/agent/task.rs` | 双入口:`run(goal)` 自主式 + `execute_plan(plan)` 外部驱动式 |
|
||||
| 5 | `PlanParser` trait + `JsonPlanParser` 参考实现 | `src/agent/task.rs` | 注入式(澄清 3 选项 C) |
|
||||
| 6 | `AgentError` | `src/agent/error.rs` | 聚合 LlmError / ToolError / MemoryError,含 `is_recoverable()` |
|
||||
| 7 | `AgentConfig` / `AgentBuilder` | `src/agent/builder.rs` | 链式构造 `RuntimeBundle` |
|
||||
| 8 | Hook 事件扩展 | `src/llm/hooks.rs` | 追加 `OnTurnStart` / `OnTurnEnd` / `OnPlanStepComplete` 3 个事件 + 上下文扩展(澄清 4) |
|
||||
| 9 | `lib.rs` 导出 | `src/lib.rs` | 一行 `pub mod agent;` |
|
||||
| 10 | 烟雾测试 | `src/agent/tests.rs` 或内联 | 2-3 个:trait 可装配 / RuntimeBundle 可构造 / `submit_turn` 跑通 mock |
|
||||
| 11 | 方案文档 | `docs/7-agent-runtime.md` | 编号 `7`(最大编号是 `6`,已确认无冲突) |
|
||||
| 12 | Roadmap 同步 | `docs/roadmap.md` | 状态从 ❌ 缺失 改为 ✅ |
|
||||
|
||||
### 2.2 明确不做(v0.2+ 边界)
|
||||
|
||||
| 推迟项 | 理由 |
|
||||
|--------|------|
|
||||
| 完整 `BasicAgent` 多轮 turn 循环 | core 库不假设业务循环 |
|
||||
| `ConversationAgent` 自动回写 | 记忆在独立 task 处理,由上层回写 |
|
||||
| 强绑定 `ConversationMemory` 字段 | 改为 `Option<Arc<dyn MemoryStore>>` 可选依赖(文中"弱引用"指弱耦合的可选字段,并非 `std::sync::Weak`) |
|
||||
| Plan 拆解的提示词模板 | 由上层注入 `PlanParser` |
|
||||
| Multi-Agent / Swarm | 接口未稳定,独立 phase |
|
||||
| Markdown 技能按需加载 | 属于知识层 |
|
||||
| 三级权限模式 UI | 应用层 |
|
||||
| 干运行 / TUI / Gateway | 应用层 |
|
||||
| 完整的集成测试套件 | 2-3 个烟雾测试足够(呼应"最小范围") |
|
||||
|
||||
详细 v0.2+ 候选项见 `docs/roadmap.md` 扩展计划(v0.2+)小节与 `docs/note-agent-harness-references.md` 第 8 节。
|
||||
|
||||
## 3. 核心架构
|
||||
|
||||
### 3.1 分层(与 OpenHarness 5 层一致)
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ 应用层 (上层 crate / 二进制 / Gateway) │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ Agent Runtime ← Phase 4:trait + RuntimeBundle + Session │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ LLM / Tool / Prompt / Memory ← Phase 0/1/2/3(已完成) │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### 3.2 实体关系
|
||||
|
||||
```
|
||||
┌────────────┐ ┌──────────────────┐
|
||||
│ Agent │ 1 * │ AgentSession │
|
||||
│ (trait) ├────────►│ (struct) │
|
||||
│ - name │ │ - session_id │
|
||||
│ - prompt │ │ - bundle: Arc │
|
||||
│ - tools │ │ - turn_index │
|
||||
└────────────┘ │ - cost_so_far │
|
||||
└──────────────────┘
|
||||
│
|
||||
▼ 共享
|
||||
┌──────────────────┐
|
||||
│ RuntimeBundle │
|
||||
│ - provider │
|
||||
│ - tool_registry │
|
||||
│ - hook_executor │
|
||||
│ - memory_store? │ ◄── 弱引用(澄清 2 选项 B)
|
||||
│ - retriever? │
|
||||
│ - config │
|
||||
└──────────────────┘
|
||||
│
|
||||
▼ 注册为 tool
|
||||
┌──────────────────┐
|
||||
│ "retrieve" tool │ ◄── 如果 retriever 存在则自动注册
|
||||
└──────────────────┘
|
||||
```
|
||||
|
||||
### 3.3 决策对照表
|
||||
|
||||
| 决策点 | 选择 | 来源 |
|
||||
|--------|------|------|
|
||||
| 实体 vs 会话 | 两层模型(`Agent` + `AgentSession`) | 讨论第 4 轮第 1 条 / OpenHarness / Hermes |
|
||||
| 范围控制 | trait + 最小 reference impl(~30 行) | 讨论第 4 轮第 2 条 / 澄清 1 选项 B |
|
||||
| 记忆处理 | 弱引用 + 自动注册 retriever 为 tool | 讨论第 4 轮第 3 条 / 澄清 2 选项 B |
|
||||
| Hook 扩展 | 3 个新事件 + 上下文扩展 | 讨论第 4 轮第 4 条 / 澄清 4 |
|
||||
| 方案文档位置 | `docs/7-agent-runtime.md` | 讨论第 4 轮第 5 条 |
|
||||
| TaskAgent 入口 | 双入口(自主 + 外部驱动) | 讨论第 4 轮第 6 条 |
|
||||
| 自主式 Plan 解析 | 注入式 `PlanParser` trait + `JsonPlanParser` 参考实现 | 澄清 3 选项 C |
|
||||
| 依赖注入 | `RuntimeBundle` 显式容器 | 讨论第 4 轮第 7 条 / OpenHarness |
|
||||
| 文档撰写 | 推迟到 Proposal 阶段 | 讨论第 4 轮第 8 条 |
|
||||
|
||||
## 4. 接口签名草案
|
||||
|
||||
> ⚠️ **这些是"设计约束",不是最终代码**。方案文档(`docs/7-agent-runtime.md`)与实施阶段可微调字段顺序、文档注释、错误变体等,但**核心 trait 形状**和**方法名**应保持稳定。
|
||||
|
||||
### 4.1 `Agent` trait
|
||||
|
||||
```rust
|
||||
pub trait Agent: Send + Sync {
|
||||
fn name(&self) -> &str;
|
||||
fn system_prompt(&self) -> Option<&str>;
|
||||
/// 列出该 Agent 想要暴露给 LLM 的工具定义。
|
||||
/// 默认实现:从 RuntimeBundle.tool_registry 取全部(最常用)。
|
||||
/// 实现该 trait 的类型可覆盖此方法,做白名单/过滤。
|
||||
fn tool_definitions(&self, bundle: &RuntimeBundle) -> Vec<ToolDefinition>;
|
||||
}
|
||||
```
|
||||
|
||||
### 4.2 `RuntimeBundle`
|
||||
|
||||
```rust
|
||||
pub struct RuntimeBundle {
|
||||
pub provider: Arc<dyn LlmProvider>,
|
||||
pub tool_registry: Arc<ToolRegistry>,
|
||||
pub hook_executor: Arc<HookExecutor>,
|
||||
pub memory_store: Option<Arc<dyn MemoryStore>>, // 弱引用(澄清 2 选项 B)
|
||||
pub retriever: Option<Arc<MemoryRetriever>>, // 弱引用(澄清 2 选项 B)
|
||||
pub config: AgentConfig,
|
||||
}
|
||||
|
||||
impl RuntimeBundle {
|
||||
/// 构造时如果 retriever 存在,自动注册为 "retrieve" tool。
|
||||
pub fn new(/* ... */) -> Self;
|
||||
}
|
||||
```
|
||||
|
||||
### 4.3 `AgentSession`
|
||||
|
||||
```rust
|
||||
pub struct AgentSession {
|
||||
pub session_id: String,
|
||||
pub agent_name: String,
|
||||
bundle: Arc<RuntimeBundle>,
|
||||
turn_index: u32,
|
||||
cost_so_far: CostTracker,
|
||||
}
|
||||
|
||||
impl AgentSession {
|
||||
pub fn new(agent: &dyn Agent, session_id: impl Into<String>, bundle: Arc<RuntimeBundle>) -> Self;
|
||||
|
||||
/// 最小 reference impl:组装 LlmCycle + submit + 累计 cost。
|
||||
/// 不做 memory 回写(呼应"记忆在独立 task 处理"原则)。
|
||||
pub async fn submit_turn(
|
||||
&mut self,
|
||||
user_input: impl Into<String>,
|
||||
) -> Result<ChatResponse, AgentError>;
|
||||
|
||||
pub fn usage(&self) -> &CostTracker;
|
||||
pub fn turn_index(&self) -> u32;
|
||||
}
|
||||
```
|
||||
|
||||
### 4.4 `TaskAgent` + `Plan` + `Step`
|
||||
|
||||
```rust
|
||||
pub struct Plan {
|
||||
pub id: String,
|
||||
pub goal: String,
|
||||
pub steps: Vec<Step>,
|
||||
}
|
||||
|
||||
pub struct Step {
|
||||
pub index: usize,
|
||||
pub description: String,
|
||||
pub status: StepStatus,
|
||||
}
|
||||
|
||||
pub enum StepStatus {
|
||||
Pending,
|
||||
Running,
|
||||
Completed(ChatResponse),
|
||||
Failed(AgentError),
|
||||
Skipped,
|
||||
}
|
||||
|
||||
/// 注入式 Plan 解析器。
|
||||
#[async_trait]
|
||||
pub trait PlanParser: Send + Sync {
|
||||
async fn parse(&self, raw: &str, goal: &str) -> Result<Plan, AgentError>;
|
||||
}
|
||||
|
||||
/// 基于 serde_json 的参考实现(约 20 行)。
|
||||
pub struct JsonPlanParser;
|
||||
|
||||
#[async_trait]
|
||||
impl PlanParser for JsonPlanParser { /* ... */ }
|
||||
|
||||
/// TaskAgent 双入口。
|
||||
#[async_trait]
|
||||
pub trait TaskAgent: Agent {
|
||||
/// 自主式:内部用 LLM 拆 Plan → execute_plan
|
||||
async fn run(&mut self, session: &mut AgentSession, goal: &str) -> Result<Plan, AgentError>;
|
||||
|
||||
/// 外部驱动式:用户预定义 Plan → 逐步执行
|
||||
async fn execute_plan(
|
||||
&mut self,
|
||||
session: &mut AgentSession,
|
||||
plan: Plan,
|
||||
) -> Result<Plan, AgentError>;
|
||||
}
|
||||
```
|
||||
|
||||
### 4.5 `AgentError`
|
||||
|
||||
```rust
|
||||
pub enum AgentError {
|
||||
Llm(LlmError),
|
||||
Tool(ToolError),
|
||||
Memory(MemoryError),
|
||||
PlanParse(String),
|
||||
HookBlocked(String),
|
||||
LimitExceeded(String),
|
||||
Config(String),
|
||||
Other(String),
|
||||
}
|
||||
|
||||
impl AgentError {
|
||||
pub fn is_recoverable(&self) -> bool { /* ... */ }
|
||||
}
|
||||
```
|
||||
|
||||
### 4.6 `AgentConfig` + `AgentBuilder`
|
||||
|
||||
```rust
|
||||
pub struct AgentConfig {
|
||||
pub max_turns: u32,
|
||||
pub max_tool_turns: u32,
|
||||
pub session_ttl: Option<Duration>,
|
||||
pub compact_config: Option<CompactConfig>,
|
||||
}
|
||||
|
||||
pub struct AgentBuilder { /* ... */ }
|
||||
|
||||
impl AgentBuilder {
|
||||
pub fn new() -> Self;
|
||||
pub fn provider(self, p: Arc<dyn LlmProvider>) -> Self;
|
||||
pub fn tool_registry(self, r: Arc<ToolRegistry>) -> Self;
|
||||
pub fn hook_executor(self, h: Arc<HookExecutor>) -> Self;
|
||||
pub fn memory_store(self, m: Arc<dyn MemoryStore>) -> Self; // 选填
|
||||
pub fn retriever(self, r: Arc<MemoryRetriever>) -> Self; // 选填
|
||||
pub fn config(self, c: AgentConfig) -> Self;
|
||||
pub fn build(self) -> Result<RuntimeBundle, AgentError>;
|
||||
}
|
||||
```
|
||||
|
||||
### 4.7 Hook 扩展(`src/llm/hooks.rs` 改动)
|
||||
|
||||
```rust
|
||||
pub enum HookEvent {
|
||||
// ... 现有 4 个 ...
|
||||
|
||||
// 新增 3 个:
|
||||
OnTurnStart,
|
||||
OnTurnEnd,
|
||||
OnPlanStepComplete,
|
||||
}
|
||||
|
||||
// HookContext 扩展 2 个 Option 字段(澄清 4):
|
||||
pub struct HookContext {
|
||||
// ... 现有字段 ...
|
||||
pub turn_index: Option<u32>, // OnTurnStart/End 用
|
||||
pub plan_step_index: Option<usize>, // OnPlanStepComplete 用
|
||||
}
|
||||
```
|
||||
|
||||
## 5. 文件清单
|
||||
|
||||
### 5.1 新增文件(7 个)
|
||||
|
||||
```
|
||||
src/agent.rs # 模块根 + pub use 重导出
|
||||
src/agent/agent.rs # Agent trait
|
||||
src/agent/runtime.rs # RuntimeBundle + AgentConfig
|
||||
src/agent/session.rs # AgentSession
|
||||
src/agent/task.rs # TaskAgent trait + Plan/Step + PlanParser + JsonPlanParser
|
||||
src/agent/builder.rs # AgentBuilder
|
||||
src/agent/error.rs # AgentError
|
||||
```
|
||||
|
||||
### 5.2 修改文件(3 个)
|
||||
|
||||
```
|
||||
src/lib.rs # + pub mod agent;
|
||||
src/llm/hooks.rs # + 3 个事件变体 + 2 个上下文字段(极小改)
|
||||
docs/roadmap.md # 状态翻转 + Phase 4 交付物清单更新(实施时再做)
|
||||
```
|
||||
|
||||
### 5.3 关联文档(已存在 / 待写)
|
||||
|
||||
```
|
||||
docs/note-agent-harness-references.md # 参考项目调研(已存在)
|
||||
docs/7-agent-runtime.md # 完整方案文档(路径 A 输出)
|
||||
docs/note-agent-runtime-design.md # 本文件
|
||||
```
|
||||
|
||||
## 6. 预估规模
|
||||
|
||||
- **新增代码**:约 **600-700 行**(含 2-3 个烟雾测试)
|
||||
- **修改代码**:约 **10-20 行**(`hooks.rs` 改动 + `lib.rs` + `roadmap.md`)
|
||||
- **方案文档**:约 **450-550 行 Markdown**(沿用 6-memory-system.md 的 6 段式结构)
|
||||
|
||||
## 7. 待办事项(按依赖顺序)
|
||||
|
||||
1. ✅ Phase 4 范围已收窄(§2.1)
|
||||
2. ✅ 核心架构已对齐 OpenHarness / Hermes(§3)
|
||||
3. ✅ 接口签名草案已固化(§4)
|
||||
4. ✅ 文件清单已确定(§5)
|
||||
5. ✅ 编号冲突已验证(最大是 `6`,新文件用 `7`)
|
||||
6. ⏳ 写 `docs/7-agent-runtime.md` 方案文档
|
||||
7. ⏳ 按文档实施 7 个新文件 + 3 个修改
|
||||
8. ⏳ 跑通 2-3 个烟雾测试
|
||||
9. ⏳ 更新 `docs/roadmap.md` 状态翻转
|
||||
|
||||
## 8. 一句话总结
|
||||
|
||||
> **Phase 4 = Phase 0-3 的薄胶水层 + 一组 trait 抽象**。**不**实现业务循环,**不**做产品级功能,**不**假设上层如何使用 memory。借鉴 OpenHarness 的"显式依赖注入容器"与 Hermes 的"实体/会话分离"模型,记忆以弱引用方式接入,`MemoryRetriever` 在 `RuntimeBundle::new()` 时自动注册为 LLM 可调用的 `retrieve` 工具。
|
||||
@@ -0,0 +1,222 @@
|
||||
# Context 切换方案设计备忘
|
||||
|
||||
> 创建日期:2026-06-10
|
||||
> 状态:备忘(Phase 4 不实现)
|
||||
> 关联文档:
|
||||
> - `docs/7-agent-runtime.md` — Phase 4 方案(含 SessionMemory 设计)
|
||||
> - `docs/note-opencode-agent-switching.md` — OpenCode 切换机制调研
|
||||
> - `docs/roadmap.md` — 项目总 Roadmap
|
||||
|
||||
---
|
||||
|
||||
## 1. 背景
|
||||
|
||||
### 1.1 问题
|
||||
|
||||
在调研 OpenCode 的 Agent 切换机制后(详见 `docs/note-opencode-agent-switching.md`),发现其做法是:
|
||||
|
||||
- 切换 agent 时**不动消息历史**
|
||||
- 在 user message 末尾追加 `synthetic: true` 的 `<system-reminder>` 提醒
|
||||
- 同时**完全重新计算** system prompt
|
||||
|
||||
这个方法的问题是:**长上下文中频繁切换 agent 容易给 LLM 造成身份困惑**。同一个消息列表里有 `system: build` 的 identity,又出现 `system: plan` 的 identity,LLM 容易"串味"。
|
||||
|
||||
### 1.2 核心思路
|
||||
|
||||
以一个 session 里存在**多个独立的 context** 来解决,每个 context 有自己独立的 system prompt + 消息列表:
|
||||
|
||||
```
|
||||
OpenCode 模式(一条流):
|
||||
[system: build, user: A, ass: A', user: <切plan>, system: plan, user: B]
|
||||
↑ 身份困惑
|
||||
|
||||
建议的多 context 模式:
|
||||
session {
|
||||
context_a: [system: build, user: A, ass: A'] ← 只有 build 的 identity
|
||||
context_b: [system: plan, user: B, ass: B'] ← 只有 plan 的 identity
|
||||
}
|
||||
```
|
||||
|
||||
### 1.3 适用范围
|
||||
|
||||
| 场景 | 适用性 | 说明 |
|
||||
|------|--------|------|
|
||||
| Agent 切换(build ↔ plan) | ✅ 核心场景 | 同一 session 内更换角色 |
|
||||
| 主从 Agent 协作 | ✅ 核心场景 | primary 委派子任务给 subagent,subagent 独立运作 |
|
||||
| 长 session 上下文压缩 | ✅ 附带收益 | 拆分 context 后,每个 context 独立累积消息,不会互相拖长 |
|
||||
| 并行 context 执行 | ⚠️ 拓展场景 | context_a 和 context_b 可各自独立推进 |
|
||||
|
||||
---
|
||||
|
||||
## 2. 三个候选方案
|
||||
|
||||
### 方案 A:OpenCode 式(system prompt 重算 + synthetic 追加)
|
||||
|
||||
**做法**:
|
||||
- 单一消息列表
|
||||
- 切换时重算 system prompt
|
||||
- user message 末尾追加 `<system-reminder>` 标签
|
||||
|
||||
**优点**:
|
||||
- 实现简单
|
||||
- 消息历史完整可见
|
||||
|
||||
**缺点**:
|
||||
- 长上下文身份困惑
|
||||
- 上下文互相污染(所有 agent 共用同一消息列表,每个 agent 都要看全部历史)
|
||||
|
||||
**结论**:❌ 否决。不解决身份困惑问题。
|
||||
|
||||
### 方案 B:信息池 + 切换不重置(借鉴 OpenCode + 增强)
|
||||
|
||||
**做法**:
|
||||
- 切换时保留历史
|
||||
- 使用 `<system-reminder>` 标签
|
||||
- 靠 prompt 工程让 LLM 理解身份变更
|
||||
|
||||
**优点**:
|
||||
- 历史连贯
|
||||
- 改动最小
|
||||
|
||||
**缺点**:
|
||||
- 仍然有身份困惑风险
|
||||
- 上下文不受控增长
|
||||
|
||||
**结论**:❌ 否决。治标不治本。
|
||||
|
||||
### 方案 C:多 context 隔离 + SessionMemory 桥接(推荐)
|
||||
|
||||
**做法**:
|
||||
- 每个 agent 切换创建一个新的 context(独立消息列表 + 独立 system prompt)
|
||||
- context 之间通过 `SessionMemory` 桥接关键信息
|
||||
- 切换时新 context 的 system prompt 末尾注入 `SessionMemory::snapshot()`
|
||||
|
||||
```
|
||||
context_a (build)
|
||||
→ 对话 50 轮
|
||||
→ 写入 SessionMemory: {"design_decision": "用 PostgreSQL",
|
||||
"files_changed": "src/db.rs"}
|
||||
→ 销毁(或沉睡)
|
||||
|
||||
创建 context_b (plan)
|
||||
→ system_prompt += snapshot()
|
||||
→ "<session-context>
|
||||
design_decision: 用 PostgreSQL
|
||||
files_changed: src/db.rs
|
||||
</session-context>"
|
||||
→ 对话 10 轮(不需要看 context_a 的 50 轮历史)
|
||||
→ 读 SessionMemory: get("design_decision") → "用 PostgreSQL"
|
||||
```
|
||||
|
||||
**优点**:
|
||||
- ✅ 身份稳定:每个 context 只有一套 system prompt
|
||||
- ✅ 上下文隔离:context_b 不受 context_a 的消息量影响
|
||||
- ✅ 信息桥接:关键结论通过 SessionMemory 显式传递
|
||||
- ✅ 并行潜力:两个 context 可各自运行
|
||||
|
||||
**缺点**:
|
||||
- ❌ 实现复杂度:从"一个消息列表"到"多个消息列表 + 桥接"
|
||||
- ❌ 信息完整性:LLM 自主决定"什么值得记",可能遗漏细节
|
||||
- ❌ 上层理解成本:应用层需要理解 context 概念
|
||||
|
||||
**结论**:✅ 推荐。架构上最干净,但 Phase 4 不做全部实现。
|
||||
|
||||
---
|
||||
|
||||
## 3. SessionMemory 桥接机制(方案 C 的核心)
|
||||
|
||||
### 3.1 设计决策
|
||||
|
||||
| 决策 | 结论 | 理由 |
|
||||
|------|------|------|
|
||||
| 复用 Phase 3 `MemoryStore` | ✅ 是 | 不引入新存储机制 |
|
||||
| 跨进程支持 | ✅ 是 | 换后端即可(Redis / SQLite),`InMemoryStore` 兜底 |
|
||||
| namespace 隔离 | ✅ 是 | `_session_{session_id}` 命名空间 |
|
||||
| 谁写 SessionMemory | LLM 通过 tool 显式写(v0.2+)或上层应用 API 写 | 不支持自动写——避免 "写太多 = 噪音,写太少 = 遗漏" |
|
||||
| snapshot 格式 | `<session-context>` XML 风格 | 专为注入 system prompt 设计 |
|
||||
|
||||
### 3.2 谁写 SessionMemory 的三种选项
|
||||
|
||||
| 选项 | 描述 | 评估 |
|
||||
|------|------|------|
|
||||
| **选项 1:AgentSession 自动写** | 每轮对话后自动摘录关键信息 | ❌ 摘录什么?容易变成精简版对话历史,失去"关键信息"的定位 |
|
||||
| **选项 2:LLM 通过 tool 显式写** | 把 `SessionMemory::set` 暴露为 Tool 供 LLM 调用 | ✅ LLM 自主决定什么值得记;v0.2+ 实现自动注册 |
|
||||
| **选项 3:上层应用 API 写** | `agent_session.session_memory.set("k", "v")` | ✅ Phase 4 即可用,最透明 |
|
||||
|
||||
**Phase 4 实现选项 3**,v0.2+ 补充选项 2(tool 自动注册)。
|
||||
|
||||
### 3.3 三层记忆体系
|
||||
|
||||
```
|
||||
持久层(Phase 3) MemoryStore / KnowledgeStore ── 跨 session 持久,长期知识
|
||||
会话层(Phase 4) SessionMemory ── 单 session 内共享,context 桥接
|
||||
对话层(Phase 3) ConversationMemory ── 单 context 内消息历史
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. Phase 4 范围 vs v0.2+ 范围
|
||||
|
||||
### ✅ Phase 4 做
|
||||
|
||||
| 组件 | 状态 | 行数 |
|
||||
|------|------|------|
|
||||
| `SessionMemory` struct | ✅ 做 | ~40 行 |
|
||||
| `AgentSession` + `session_memory` 字段 | ✅ 做 | ~3 行 |
|
||||
| `AgentSession` 持 `Arc<dyn Agent>` 替代 `agent_name: String` | ✅ 做 | ~3 行 |
|
||||
| `RuntimeBundle` + `session_memory_backend` 字段 | ✅ 做 | ~1 行 |
|
||||
| `AgentBuilder` + `.session_memory_backend()` | ✅ 做 | ~3 行 |
|
||||
|
||||
### ❌ 延后到 v0.2+
|
||||
|
||||
| 组件 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| Context 切换管理(`switch_context` / `create_context`) | ❌ 延后 | 需要 `ContextManager` 包装 |
|
||||
| 多 context 生命周期管理 | ❌ 延后 | context 的创建/销毁/切换策略 |
|
||||
| `"session_memory_set"` tool 自动注册 | ❌ 延后 | 在 `ToolRegistry` 里注册特殊 tool |
|
||||
| Context 级别的 `ConversationMemory` 自动管理 | ❌ 延后 | 每个 context 独立消息历史 |
|
||||
|
||||
### 延后的理由
|
||||
|
||||
1. **最小范围原则**:Phase 4 定位是"薄胶水层 + trait 抽象",多 context 管理属于业务编排的范畴
|
||||
2. **稳定 API 优先**:先把 `AgentSession` / `RuntimeBundle` / `SessionMemory` 的 API 定稳,v0.2+ 在上面搭建 context 切换
|
||||
3. **降低实施风险**:Phase 4 已有 13 个交付任务,加 context 切换会增加 2-3 倍复杂度
|
||||
|
||||
---
|
||||
|
||||
## 5. v0.2+ Context 切换的设想接口
|
||||
|
||||
> 以下为未来实现的草案,非承诺。记录在这里避免 v0.2+ 重新设计时丢失上下文。
|
||||
|
||||
```rust
|
||||
pub struct ContextManager {
|
||||
contexts: HashMap<String, AgentSession>,
|
||||
active_context: String,
|
||||
session_memory: SessionMemory,
|
||||
}
|
||||
|
||||
impl ContextManager {
|
||||
/// 创建一个新的 context,绑定指定 agent
|
||||
pub fn create_context(&mut self, id: &str, agent: Arc<dyn Agent>) -> Result<(), AgentError>;
|
||||
|
||||
/// 切换到已有 context
|
||||
pub fn switch_context(&mut self, id: &str) -> Result<&mut AgentSession, AgentError>;
|
||||
|
||||
/// 销毁 context
|
||||
pub fn destroy_context(&mut self, id: &str) -> Result<(), AgentError>;
|
||||
|
||||
/// 从 context_a 桥接关键信息到 context_b 的 system prompt
|
||||
pub fn bridge(&mut self, from: &str, to: &str) -> Result<(), AgentError>;
|
||||
}
|
||||
```
|
||||
|
||||
切换流程:
|
||||
1. `context_manager.create_context("plan", plan_agent)` — 新 context 的 system prompt 自动附加 `session_memory.snapshot()`
|
||||
2. `context_manager.switch_context("plan")` — 返回 context 的 `AgentSession`,应用层调 `submit_turn`
|
||||
3. context 销毁前,关键信息经由 LLM 或上层应用写入 `SessionMemory`
|
||||
|
||||
---
|
||||
|
||||
## 6. 一句话总结
|
||||
|
||||
> **多 context 切换方案 = `SessionMemory`(Phase 4 做信息桥接基础) + `ContextManager`(v0.2+ 做切换管理)。Phase 4 只铺"水管接口",不装"水循环系统"。**
|
||||
@@ -0,0 +1,266 @@
|
||||
# 知识图谱与高级检索设计(Phase 4 备用)
|
||||
|
||||
> 本文记录 Phase 3 设计过程中裁剪的内容,待 Phase 4(Agent 运行时)制定时参考。
|
||||
> 来源:`docs/6-memory-system.md` v1 版本,2026-06-07
|
||||
|
||||
---
|
||||
|
||||
## 背景
|
||||
|
||||
Phase 3 记忆系统方案做减法后,以下设计被推迟到 Phase 4。这些组件需要 Agent 的编排能力(LLM 提取标签、自动维护知识图谱、智能检索策略)才能真正产生价值,因此不适合在 Phase 3 的存储层实现。
|
||||
|
||||
---
|
||||
|
||||
## 1. KnowledgeGraph(知识图谱)
|
||||
|
||||
### 1.1 设计意图
|
||||
|
||||
实体-关系图存储,用于关联检索。与 KnowledgeStore(内容/页面级)互补,提供实体级 + 关系维度的检索能力。
|
||||
|
||||
```
|
||||
KnowledgeStore: 页面级内容("什么是 X")
|
||||
KnowledgeGraph: 实体级关系("X 与什么相关")
|
||||
```
|
||||
|
||||
### 1.2 接口设计(原方案)
|
||||
|
||||
```rust
|
||||
pub struct GraphEntity {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub entity_type: String, // "person" | "concept" | "project" | ...
|
||||
pub description: String,
|
||||
pub tags: Vec<String>, // 检索标签(全小写,原子词)
|
||||
}
|
||||
|
||||
pub struct GraphRelation {
|
||||
pub source_id: String,
|
||||
pub target_id: String,
|
||||
pub relation_type: String, // "works_on" | "part_of" | "related_to" | ...
|
||||
pub weight: f32, // 关系强度 [0.0, 1.0]
|
||||
}
|
||||
|
||||
pub enum RelationDirection {
|
||||
Outgoing, // source_id -> target_id(默认)
|
||||
Incoming, // target_id -> source_id
|
||||
Both, // 双向遍历
|
||||
}
|
||||
|
||||
pub struct ScoredEntity {
|
||||
pub entity: GraphEntity,
|
||||
pub score: f32, // 基于图距离的评分 [0.0, 1.0]
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait KnowledgeGraph: Send + Sync {
|
||||
// 实体管理
|
||||
async fn add_entity(&self, entity: GraphEntity) -> Result<(), MemoryError>;
|
||||
async fn get_entity(&self, id: &str) -> Result<Option<GraphEntity>, MemoryError>;
|
||||
async fn remove_entity(&self, id: &str) -> Result<(), MemoryError>;
|
||||
|
||||
// 关系管理
|
||||
async fn add_relation(&self, relation: GraphRelation) -> Result<(), MemoryError>;
|
||||
async fn remove_relation(&self, source_id: &str, target_id: &str, relation_type: &str) -> Result<(), MemoryError>;
|
||||
async fn get_related(
|
||||
&self,
|
||||
entity_id: &str,
|
||||
depth: usize,
|
||||
direction: RelationDirection,
|
||||
relation_types: Option<&[&str]>,
|
||||
) -> Result<Vec<ScoredEntity>, MemoryError>;
|
||||
|
||||
// 检索
|
||||
async fn find_by_keywords(&self, keywords: &[String]) -> Result<Vec<GraphEntity>, MemoryError>;
|
||||
|
||||
// 标签管理
|
||||
async fn find_tags(&self, prefix: &str) -> Result<Vec<String>, MemoryError>;
|
||||
async fn entity_count_by_tag(&self, tag: &str) -> Result<usize, MemoryError>;
|
||||
async fn set_entity_tags(&self, entity_id: &str, tags: Vec<String>) -> Result<usize, MemoryError>;
|
||||
fn tag_constraints(&self) -> TagConstraints;
|
||||
}
|
||||
|
||||
pub struct TagConstraints {
|
||||
pub max_tags_per_entity: usize, // 默认 8
|
||||
}
|
||||
```
|
||||
|
||||
### 1.3 标签复用原则
|
||||
|
||||
标签不应随意增长,应优先复用已有标签。流程:
|
||||
|
||||
```
|
||||
LLM 提取候选标签 → 对每个候选:
|
||||
graph.find_tags(candidate.to_lowercase())
|
||||
├─ 命中已有标签 → 复用
|
||||
└─ 无匹配 → 注册新标签
|
||||
```
|
||||
|
||||
### 1.4 标签容量与精炼
|
||||
|
||||
每个实体最多 `max_tags_per_entity`(默认 8)个标签,按关联度降序排列。超出上限时保留关联度最高的标签。
|
||||
|
||||
### 1.5 InMemoryGraph 实现
|
||||
|
||||
```rust
|
||||
pub struct InMemoryGraph {
|
||||
entities: Mutex<HashMap<String, GraphEntity>>,
|
||||
relations: Mutex<Vec<GraphRelation>>,
|
||||
tag_index: Mutex<HashMap<String, HashSet<String>>>, // tag → entity_ids
|
||||
}
|
||||
```
|
||||
|
||||
图遍历使用 BFS/DFS 算法,需用 `HashSet<String>` 防环。
|
||||
|
||||
---
|
||||
|
||||
## 2. 高级评分策略
|
||||
|
||||
### 2.1 ScoringStrategy
|
||||
|
||||
Phase 3 仅使用内部的简单 TextOverlap 评分(Dice 系数)。Phase 4 可引入以下策略:
|
||||
|
||||
```rust
|
||||
pub struct ScoreWeights {
|
||||
pub overlap: f32, // 默认 0.5 — 文本重叠度,以原始 query 为基准
|
||||
pub graph: f32, // 默认 0.2 — 图距离
|
||||
pub temporal: f32, // 默认 0.1 — 时间衰减
|
||||
pub reference: f32, // 默认 0.2 — 引用计数
|
||||
}
|
||||
|
||||
pub enum ScoringStrategy {
|
||||
TextOverlap, // 以原始 query 为准绳的文本重叠度(默认)
|
||||
GraphDistance,
|
||||
TemporalWeight,
|
||||
ReferenceCount,
|
||||
Hybrid(ScoreWeights),
|
||||
}
|
||||
|
||||
pub struct ScoreBreakdown {
|
||||
pub overlap_score: f32,
|
||||
pub graph_score: f32,
|
||||
pub temporal_score: f32,
|
||||
pub reference_score: f32,
|
||||
}
|
||||
```
|
||||
|
||||
### 2.2 TextOverlap 算法
|
||||
|
||||
基于 Dice 系数计算 query 与召回内容的文本重叠度:
|
||||
|
||||
```
|
||||
Dice = 2 × |intersect(bigrams)| / (|bigrams_query| + |bigrams_content|)
|
||||
```
|
||||
|
||||
标题权重大于摘要,摘要权重大于正文。
|
||||
|
||||
---
|
||||
|
||||
## 3. 高级检索(MemoryRetriever 双通道版)
|
||||
|
||||
Phase 3 仅保留单通道(只搜 KnowledgeStore)。Phase 4 可恢复双通道:
|
||||
|
||||
```rust
|
||||
pub struct MemoryRetriever {
|
||||
knowledge_store: KnowledgeStore,
|
||||
knowledge_graph: Arc<dyn KnowledgeGraph>,
|
||||
keyword_extractor: Arc<dyn KeywordExtractor>,
|
||||
config: RetrieverConfig,
|
||||
}
|
||||
|
||||
pub enum RetrievalStrategy {
|
||||
Hybrid, // 结合所有通道 + 评分排序(默认)
|
||||
KnowledgeOnly, // 仅 KnowledgeStore
|
||||
GraphOnly, // 仅 KnowledgeGraph
|
||||
}
|
||||
```
|
||||
|
||||
检索流程:
|
||||
|
||||
```
|
||||
1. 关键词提取(KeywordExtractor)
|
||||
2. 并行召回:
|
||||
- KnowledgeStore.find_by_keywords(keywords)
|
||||
- KnowledgeGraph.find_by_keywords(keywords) → get_related() 图遍历
|
||||
3. 逐条评分(ScoringStrategy)
|
||||
4. 过滤 score < min_score
|
||||
5. 排序 → 截取 top-N
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. KeywordExtractor
|
||||
|
||||
```rust
|
||||
pub trait KeywordExtractor: Send + Sync {
|
||||
fn extract(&self, query: &str) -> Vec<String>;
|
||||
}
|
||||
|
||||
pub struct SimpleKeywordExtractor {
|
||||
stop_words: HashSet<String>,
|
||||
}
|
||||
```
|
||||
|
||||
默认实现:按非字母数字字符分割,过滤停用词和单字符词。停用词表应包含英语常用停用词(约 80-100 个)。
|
||||
|
||||
---
|
||||
|
||||
## 5. 基于召回价值的淘汰(RecallBased)
|
||||
|
||||
### 5.1 记忆价值评分
|
||||
|
||||
每条记忆维护召回统计,计算综合价值分数:
|
||||
|
||||
```rust
|
||||
pub struct RecallStats {
|
||||
pub recall_count: u64, // 累计召回次数
|
||||
pub total_score: f64, // 累计评分(平均分 = total_score / recall_count)
|
||||
pub last_recall_at: i64, // 最后一次被召回的时间戳(秒)
|
||||
}
|
||||
|
||||
// 记忆价值公式:
|
||||
// value = ln(1 + recall_count) × w_recall + avg_score × w_score + recency × w_recency
|
||||
```
|
||||
|
||||
### 5.2 record_recall()
|
||||
|
||||
```rust
|
||||
// MemoryStore trait 可选方法
|
||||
async fn record_recall(&self, id: &str, score: f32) -> Result<(), MemoryError> {
|
||||
Ok(()) // 默认空实现;需要召回统计的后端自行覆盖
|
||||
}
|
||||
```
|
||||
|
||||
### 5.3 淘汰策略
|
||||
|
||||
```rust
|
||||
pub enum EvictionPolicy {
|
||||
// ...Phase 3 已有: None, Ttl, Capacity...
|
||||
|
||||
RecallBased {
|
||||
max_items: usize,
|
||||
recall_weight: f32, // 默认 0.3
|
||||
score_weight: f32, // 默认 0.5
|
||||
recency_weight: f32, // 默认 0.2
|
||||
},
|
||||
Hybrid {
|
||||
ttl_secs: Option<u64>,
|
||||
max_items: Option<usize>,
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. Phase 3 → Phase 4 迁移建议
|
||||
|
||||
| 组件 | Phase 3 状态 | Phase 4 迁移方式 |
|
||||
|------|-------------|-----------------|
|
||||
| KnowledgeStore | 具体 struct(基于 MemoryStore) | 保持 struct,新增知识图谱数据入口 |
|
||||
| KnowledgeGraph | 不存在 | 新建 `memory/graph.rs`,实现 trait + InMemoryGraph |
|
||||
| MemoryRetriever | 单通道(仅 KnowledgeStore) | 增加 KnowledgeGraph 通道,恢复双通道检索 |
|
||||
| ScoringStrategy | 内部 TextOverlap | 恢复枚举策略 + MemoryRetriever 配置 |
|
||||
| KeywordExtractor | MemoryRetriever 内部拆分逻辑 | 抽取为独立 struct |
|
||||
| RecallBased 淘汰 | 不存在 | 恢复 EvictionPolicy 变体 + RecallStats |
|
||||
| 标签管理 | 不存在 | 恢复 tag_index + find_tags + set_entity_tags |
|
||||
|
||||
**关键依赖**:Phase 4 的 Agent 编排是 KnowledgeGraph 和标签管理的驱动者。如果没有 Agent 的 LLM 调用来提取标签、维护知识,KnowledgeGraph 只是一个空的图存储。
|
||||
@@ -0,0 +1,174 @@
|
||||
# OpenCode Agent 切换机制调研笔记
|
||||
|
||||
> 调研日期:2026-06-09
|
||||
> 调研方式:直接读 `sst/opencode` 源码(本地路径 `/Users/midnite/Samples/opencode`)
|
||||
> 调研目标:OpenCode 在切换 agent 时,整个上下文(系统提示词 + agent 提示词)是如何注入的
|
||||
> 关联:
|
||||
> - `docs/note-agent-harness-references.md` — 之前的参考项目调研
|
||||
> - `docs/note-agent-runtime-design.md` — AG Core Phase 4 设计决策
|
||||
> - `docs/7-agent-runtime.md` — AG Core Phase 4 方案文档
|
||||
|
||||
---
|
||||
|
||||
## 1. 项目背景
|
||||
|
||||
| 维度 | 详情 |
|
||||
|------|------|
|
||||
| 项目 | `sst/opencode`(GitHub),不是 npm `opencode-ai`(npm 只发编译产物) |
|
||||
| 规模 | 160k+ stars、900+ contributors、13k+ commits |
|
||||
| 语言 | TypeScript + Bun 运行时 + Effect(依赖注入 / Layer) |
|
||||
| 定位 | 开源 AI 编程代理(TUI / Desktop / IDE 插件),与 Claude Code / Cursor 对标 |
|
||||
| Agent 切换 | 终端按 **Tab 键** 在 primary agent 之间循环(build / plan) |
|
||||
|
||||
## 2. Agent 分类
|
||||
|
||||
OpenCode 把 agent 严格分为三类:
|
||||
|
||||
| 类型 | 内置 | 触发方式 | 用途 |
|
||||
|------|------|---------|------|
|
||||
| **Primary(主代理)** | build / plan | Tab 键循环切换 | 用户直接交互 |
|
||||
| **Subagent(子代理)** | general / explore / scout | 主代理自动调 / 用户 `@提及` | 专门任务 |
|
||||
| **Hidden system(隐藏)** | compaction / title / summary | 框架自动调,用户不可见 | 系统级 |
|
||||
|
||||
源码定义在 `packages/opencode/src/agent/agent.ts` 的 `Agent.Info` schema:
|
||||
|
||||
```typescript
|
||||
{ name, description, mode, native, hidden,
|
||||
topP, temperature, color,
|
||||
permission: PermissionV1.Ruleset,
|
||||
model, variant, prompt, options, steps }
|
||||
```
|
||||
|
||||
每个 agent 是**纯配置对象**,可由用户 `opencode.json` 覆盖或自定义 `.md` 文件定义。
|
||||
|
||||
## 3. 核心:System Prompt 完整拼接机制
|
||||
|
||||
OpenCode 把 system prompt 分为 **3 层**,每次 LLM 调用时**完整重新计算**(不缓存)。
|
||||
|
||||
### 3.1 拼接顺序
|
||||
|
||||
源码:`packages/opencode/src/session/llm/request.ts:58-66`
|
||||
|
||||
```typescript
|
||||
const system = [
|
||||
// Layer1:主 agent 提示词
|
||||
...(input.agent.prompt
|
||||
? [input.agent.prompt] // ① agent 自带 prompt(如 PROMPT_EXPLORE)
|
||||
: SystemPrompt.provider(input.model)), // ② 或按 model 选择 provider prompt
|
||||
|
||||
// Layer2:动态上下文(prompt.ts:1408-1414)
|
||||
...input.system, // ③ env + instructions + skills
|
||||
|
||||
// Layer3:用户自定义 system
|
||||
...(input.user.system ? [input.user.system] : []), // ④ 单次 user 消息的 system 字段
|
||||
]
|
||||
.filter(x => x)
|
||||
.join("\n")
|
||||
```
|
||||
|
||||
### 3.2 Layer 2 的内部构成
|
||||
|
||||
源码:`packages/opencode/src/session/prompt.ts:1408-1414`
|
||||
|
||||
```typescript
|
||||
const [skills, env, instructions, modelMsgs] = yield* Effect.all([
|
||||
sys.skills(agent), // 当前 agent 可用的 skills 描述
|
||||
sys.environment(model), // 工作目录、日期、平台、git 状态
|
||||
instruction.system(), // 自动读取 AGENTS.md / CLAUDE.md / CONTEXT.md
|
||||
MessageV2.toModelMessagesEffect(msgs, model),
|
||||
])
|
||||
const system = [...env, ...instructions, ...(skills ? [skills] : [])]
|
||||
```
|
||||
|
||||
**关键发现**:
|
||||
- **AGENTS.md / CLAUDE.md 是 instruction 自动注入**,不是用户手动 @引用
|
||||
- `system.ts` 的 `provider()` 函数**根据 model ID 选择不同 .txt 模板**(如 `PROMPT_ANTHROPIC` / `PROMPT_GEMINI` / `PROMPT_CODEX`)
|
||||
- `environment()` 注入**运行时环境信息**(cwd、平台、日期)
|
||||
|
||||
### 3.3 Agent 配置中的 `prompt` 字段
|
||||
|
||||
源码:`packages/opencode/src/agent/agent.ts`
|
||||
|
||||
```typescript
|
||||
// build / plan 没有 prompt 字段 → 走 SystemPrompt.provider(model)
|
||||
build: { name: "build", mode: "primary", permission: ... },
|
||||
plan: { name: "plan", mode: "primary", permission: ... },
|
||||
|
||||
// explore / compaction / title / summary 有显式 prompt
|
||||
explore: { ..., prompt: PROMPT_EXPLORE, mode: "subagent" },
|
||||
compaction: { ..., prompt: PROMPT_COMPACTION, mode: "primary", hidden: true },
|
||||
title: { ..., prompt: PROMPT_TITLE, mode: "primary", hidden: true },
|
||||
summary: { ..., prompt: PROMPT_SUMMARY, mode: "primary", hidden: true },
|
||||
```
|
||||
|
||||
**结论**:agent 的 `prompt` 字段**只决定 Layer 1 的内容**。Layer 2(env/instructions/skills)和 Layer 3(user.system)始终拼上。
|
||||
|
||||
## 4. 核心:Agent 切换时的 4 个动作
|
||||
|
||||
源码:`packages/opencode/src/session/reminders.ts`(**整个文件 92 行就是答案**)
|
||||
|
||||
OpenCode 用 **`synthetic: true` 的 text part 注入到 user message**,而不是修改 system prompt。
|
||||
|
||||
| 切换方向 | 动作 | 模板文件 | 大小 |
|
||||
|---------|------|---------|------|
|
||||
| 任意 → **plan** | user message 追加 `PROMPT_PLAN` | `session/prompt/plan.txt` | 26 行 |
|
||||
| **plan → build** | user message 追加 `BUILD_SWITCH` | `session/prompt/build-switch.txt` | **5 行** |
|
||||
| build → **plan** (experimental) | user message 追加 `PLAN_MODE` | `session/prompt/plan-mode.txt` | 70 行 |
|
||||
| 任意切换 | system prompt **完全重算** | `request.ts:58-66` | — |
|
||||
|
||||
**最关键的发现——`build-switch.txt` 全文只有 5 行**:
|
||||
|
||||
```
|
||||
<system-reminder>
|
||||
Your operational mode has changed from plan to build.
|
||||
You are no longer in read-only mode.
|
||||
You are permitted to make file changes, run shell commands, and utilize your arsenal of tools as needed.
|
||||
</system-reminder>
|
||||
```
|
||||
|
||||
**机制总结**:
|
||||
1. 切换时**不动 message history**(之前所有 user/assistant/tool 消息完整保留)
|
||||
2. 通过**比较 `msg.info.agent` 字段**判断上一条 assistant 用的哪个 agent
|
||||
3. 在**当前 user message 末尾追加**一个 `synthetic: true` 的 text part
|
||||
4. 同时**重新计算 system prompt**(Layer 1 根据新 agent 的 `prompt` 字段切换)
|
||||
|
||||
## 5. 关键设计决策
|
||||
|
||||
| 决策 | 做法 | 原因推断 |
|
||||
|------|------|---------|
|
||||
| **system prompt 重算而非缓存** | 每次 LLM 调用都重新拼接 | agent / model / instructions 都可能动态变化 |
|
||||
| **历史消息不重置** | 切换 = 追加 synthetic part | 保持上下文连贯,避免"切换即失忆" |
|
||||
| **切换提醒伪装成 user 内容** | `<system-reminder>` 标签 + `synthetic: true` | 大多数 LLM 对 `<system-reminder>` 标签有特殊信任 |
|
||||
| **agent prompt 与 model prompt 二选一** | `agent.prompt ?? SystemPrompt.provider(model)` | build/plan 共享 model prompt,自定义 agent 可独立 prompt |
|
||||
| **AGENTS.md 自动注入** | instruction.service 每轮扫描 | Claude Code 兼容,提升跨工具体验 |
|
||||
|
||||
## 6. 与 AG Core Phase 4 的对应关系
|
||||
|
||||
| OpenCode 机制 | AG Core 对应 | 借鉴价值 |
|
||||
|--------------|-------------|---------|
|
||||
| 3 层 system prompt 拼接 | `AgentSession::submit_turn` 中组装 `LlmCycle.with_system_prompt()` | **高**——可拆分为「base prompt + agent prompt + env context」 |
|
||||
| agent 切换时追加 synthetic part | Phase 4 v1 **不做**(仅保留单 agent 角色) | 中——v0.2+ 才考虑 |
|
||||
| AGENTS.md 自动注入 | `prompt::PromptTemplate` + 文件加载(应用层) | 低——文件 I/O 是应用层职责 |
|
||||
| 权限矩阵三态 (allow/ask/deny) | `tools::PermissionChecker`(Phase 2 已实现) | 已有 |
|
||||
| Hidden system agent (Compaction) | `llm::compact`(Phase 0 已实现) | 已有 |
|
||||
| Tab 键循环切换 | 应用层 UI 概念 | 不在 core 库范围 |
|
||||
|
||||
## 7. 借鉴 / 不借鉴清单
|
||||
|
||||
### ✅ 值得借鉴(v0.2+ 考虑)
|
||||
|
||||
1. **`AgentSession` 应支持"切换 agent 但保留历史"**——目前 Phase 4 v1 不做,但 trait 设计上要预留空间
|
||||
2. **System prompt 拆分为多层**——`base_prompt + agent_prompt + env_context`,便于将来按 agent 类型切换
|
||||
3. **synthetic message 模式**——切换 agent 时插入"状态变更通知"而非修改历史
|
||||
|
||||
### ❌ 不借鉴
|
||||
|
||||
- Tab 键循环切换(应用层 UI 概念)
|
||||
- `.md` agent 定义文件(应用层文件加载)
|
||||
- `mode: primary/subagent` 区分(AG Core 是 lib 不做 UI 角色区分)
|
||||
- Hidden system agent 字段(AG Core 已在 L0/L1 实现等价能力)
|
||||
- AGENTS.md 自动注入(应用层职责)
|
||||
|
||||
## 8. 一句话总结
|
||||
|
||||
> **OpenCode 的 agent 切换机制 = "system prompt 完全重算" + "user message 追加 5 行 synthetic 提醒"。** Agent 切换**不**重置消息历史,**不**改写之前内容,只在末尾追加一条"状态变更通知",并按新 agent 重新组装 system prompt 的 Layer 1(agent 专属 prompt)。
|
||||
+340
@@ -0,0 +1,340 @@
|
||||
# AG Core Roadmap
|
||||
|
||||
> 定稿日期:2026-05-11
|
||||
> 最后更新:2026-06-11(Phase 4c 编码实施完成)
|
||||
|
||||
## 愿景
|
||||
|
||||
AG Core 定位为构建 AI 智能体的底层工具箱,通过模块化、可插拔的架构,提供大模型调用、提示词工程、工具系统、记忆检索四大核心能力,支持快速组合出符合业务需求的智能体应用。
|
||||
|
||||
**当前状态**:Phase 0 基础设施已全部完成,Phase 1 提示词工程已全部完成,Phase 2 工具系统已全部完成,Phase 3 记忆系统已全部完成,Phase 4a 核心胶水层已全部完成,Phase 4b 任务执行已全部完成,Phase 4c 会话级记忆已全部完成(116 个测试通过,0 警告)。
|
||||
|
||||
---
|
||||
|
||||
## 模块完整性评估
|
||||
|
||||
| 功能领域 | 方案状态 | 文档位置 | 实现优先级 |
|
||||
|---------|---------|---------|-----------|
|
||||
| LLM 调用周期 | ✅ 完整 | `specs/llm-call-lifecycle.md` | P0 |
|
||||
| 提示词工程 | ✅ 完整 | `docs/4-prompt-engineering.md` | P1 |
|
||||
| 工具系统 + 权限 | ✅ 完整 | `docs/5-tool-system.md` | P1 |
|
||||
| 记忆检索 | ✅ 完整 | `docs/6-memory-system.md` | P2 |
|
||||
| Agent 运行时(4a 胶水层) | ✅ 已实现 | `docs/7-agent-runtime.md` | P2 |
|
||||
| 生命周期钩子 | ✅ 完整 | `docs/3-phase0-remaining.md` | P0(LLM Cycle 扩展) |
|
||||
| Provider 注册发现 | ✅ 完整 | `docs/3-phase0-remaining.md` | P0(Provider 接口扩展) |
|
||||
| 流式事件系统 | ✅ 完整 | `docs/3-phase0-remaining.md` | P0(流式接口前置) |
|
||||
|
||||
---
|
||||
|
||||
## 分阶段 Roadmap
|
||||
|
||||
### Phase 0 — Foundation(基础设施)
|
||||
|
||||
**目标**:实现 LLM 调用周期的核心功能,作为所有上层模块的基础。
|
||||
|
||||
**交付物**:
|
||||
1. ✅ `llm/types.rs` — 核心数据类型(Message, ContentBlock, ChatRequest/Response, ToolDefinition, StopReason)
|
||||
2. ✅ `llm/error.rs` — 错误体系(LlmError 枚举,可重试/不可重试判断)
|
||||
3. ✅ `llm/provider.rs` + `llm/provider/openai.rs` — Provider 接口 + OpenAI 兼容实现
|
||||
4. ✅ `llm/provider/registry.rs` — ProviderRegistry(多 Provider 注册发现)
|
||||
5. ✅ `llm/cycle.rs` + `llm/cycle/{retry,usage}.rs` — 生命周期引擎(重试策略 + 用量追踪)
|
||||
6. ✅ `llm/hooks.rs` — HookExecutor 接口(生命周期钩子)
|
||||
7. ✅ `llm/stream.rs` — StreamEvents 流式事件系统(AssistantTextDelta, ToolExecutionStarted 等)
|
||||
8. ✅ `llm/compact.rs` — Auto-compaction(上下文自动压缩)
|
||||
9. ✅ `Cargo.toml` — 添加依赖(tokio, reqwest, serde, thiserror, async-trait, tracing)
|
||||
|
||||
**依赖**:无
|
||||
|
||||
**优先级**:Must Have
|
||||
|
||||
**预估规模**:约 1000 行核心代码
|
||||
|
||||
**状态**:✅ Phase 0 全部交付物已完成
|
||||
|
||||
---
|
||||
|
||||
### Phase 1 — Prompt Engineering(提示词工程)
|
||||
|
||||
**目标**:提供提示词的组合、模板化与优化能力。
|
||||
|
||||
**交付物**:
|
||||
1. ✅ `prompt.rs` + `prompt/` 模块
|
||||
2. ✅ `PromptTemplate` — 模板引擎(支持变量插值、条件渲染)
|
||||
3. ✅ `PromptComposer` — 提示词组合器(拼接 system/user/assistant 消息)
|
||||
4. ✅ `docs/4-prompt-engineering.md` — 方案文档
|
||||
|
||||
**依赖**:无(可与 Phase 0 并行)
|
||||
|
||||
**优先级**:Should Have
|
||||
|
||||
**预估规模**:约 400 行代码
|
||||
|
||||
**状态**:✅ Phase 1 全部交付物已完成
|
||||
|
||||
---
|
||||
|
||||
### Phase 2 — Tool System(工具系统)
|
||||
|
||||
**目标**:实现 MCP 协议集成与自定义工具注册、调用、权限控制。
|
||||
|
||||
**交付物**:
|
||||
1. ✅ `tools.rs` + `tools/` 模块(base/registry/permission/mcp/error)
|
||||
2. ✅ `ToolRegistry` — 工具注册表(注册、发现、调用、并行执行、超时控制)
|
||||
3. ✅ `BaseTool` trait — 工具抽象接口(含 ToolContext 执行上下文)
|
||||
4. ✅ `McpClient` — MCP 协议客户端(stdio transport,StreamableHttp 预留)
|
||||
5. ✅ `PermissionChecker` — 工具执行权限检查(白名单/黑名单/自定义权限)
|
||||
6. ✅ `docs/5-tool-system.md` — 方案设计文档
|
||||
7. ✅ 扩展 `llm/cycle.rs` 支持自动 tool 循环(`submit_with_tools()` + `submit_request()` + `maybe_compact()`)
|
||||
8. ✅ `ToolError` — 结构化错误体系(含 `is_recoverable()` 分类)
|
||||
|
||||
**依赖**:Phase 0(LlmProvider 接口传递 tool definitions)、Phase 1(提示词可能需要注入工具描述)
|
||||
|
||||
**优先级**:Should Have
|
||||
|
||||
**预估规模**:约 900 行代码(实际约 1500 行)
|
||||
|
||||
**状态**:✅ Phase 2 全部交付物已完成
|
||||
|
||||
---
|
||||
|
||||
### Phase 3 — Memory System(记忆系统)
|
||||
|
||||
**目标**:提供对话记忆的存储、检索与管理能力。
|
||||
|
||||
**交付物**:
|
||||
1. ✅ `memory.rs` + `memory/` 模块(store / conversation / knowledge / retriever / error / types)
|
||||
2. ✅ `MemoryStore` trait + `InMemoryStore` — 记忆存储抽象(可插拔后端)+ 默认实现
|
||||
3. ✅ `ConversationMemory` — 对话记忆管理(sliding window / 全量),复用 `llm::compact`
|
||||
4. ✅ `KnowledgeStore` — 知识页面存储(具体 struct,非 trait,基于 MemoryStore)
|
||||
5. ✅ `MemoryRetriever` — 记忆检索器(TextOverlap Dice 系数评分,单通道)
|
||||
6. ✅ `docs/6-memory-system.md` — 方案设计文档
|
||||
7. ✅ `docs/note-knowledge-graph-design.md` — KnowledgeGraph 等 Phase 4 备用设计
|
||||
8. ✅ `EvictionPolicy` — 支持 None / Ttl / Capacity 三种淘汰策略
|
||||
|
||||
**依赖**:Phase 0(llm::compact 复用)、Cargo.toml 新增 `time` 依赖
|
||||
|
||||
**优先级**:Could Have
|
||||
|
||||
**预估规模**:约 700 行代码(实际约 1242 行,含测试)
|
||||
|
||||
**状态**:✅ Phase 3 全部交付物已完成
|
||||
|
||||
---
|
||||
|
||||
### Phase 4a — Agent Core Glue(核心胶水层)
|
||||
|
||||
**目标**:提供最小可用的 Agent Runtime——把 Phase 0-3 的能力"装配"成 `AgentSession::submit_turn`。上层可基于 4a 构建多轮对话应用。
|
||||
|
||||
**交付物**:
|
||||
1. ✅ `agent.rs` + `agent/` 模块(7 个文件:agent/error/runtime/builder/session/task + 模块根)
|
||||
2. ✅ `Agent` trait — 智能体角色定义(name / system_prompt / tool_definitions)
|
||||
3. ✅ `AgentSession` — 会话实例(绑定 `Arc<dyn Agent>` + `RuntimeBundle` + 内联 HashMap session_data)
|
||||
4. ✅ `RuntimeBundle` — 显式依赖注入容器(不含 session_memory_backend)
|
||||
5. ✅ `AgentBuilder` — 链式构造入口(不含 session_memory_backend)
|
||||
6. ✅ `AgentError` — 统一错误类型(7 个变体:Llm / Tool / Memory / HookBlocked / LimitExceeded / Config / Other;不含 PlanParse)
|
||||
7. ✅ `Plan` / `Step` / `StepStatus` — 纯数据结构(不含任何解析逻辑)
|
||||
8. ✅ Hook 事件扩展:OnTurnStart / OnTurnEnd + turn_index 字段
|
||||
9. ✅ `docs/7-agent-runtime.md` — 方案设计文档(含 4a/4b/4c 分阶段计划)
|
||||
|
||||
**实际新增**:
|
||||
- 新增文件 7 个(agent.rs + agent/{agent, error, runtime, builder, session, task}.rs)
|
||||
- 修改文件 3 个(lib.rs +1 行;llm/hooks.rs +13 行追加变体/字段;llm/cycle.rs 内部字段 Box→Arc + 新增 `new_with_arc` 公共方法)
|
||||
- 实际代码量约 800 行(含测试;纯实现约 470 行——略高于方案预估 440 行,因 AgentSession 的 tests 模块内联 MockProvider/StubAgent 等辅助结构)
|
||||
- 新增内联测试 22 个;全量测试 84 → 109(0 失败)
|
||||
- clippy 0 警告(agent 模块)
|
||||
- 无新增外部依赖
|
||||
|
||||
**依赖**:Phase 0, 1, 2, 3
|
||||
|
||||
**优先级**:Could Have
|
||||
|
||||
**预估规模**:约 440 行代码
|
||||
|
||||
**状态**:✅ Phase 4a 全部交付物已完成
|
||||
|
||||
---
|
||||
|
||||
### Phase 4b — Task Execution(任务执行)
|
||||
|
||||
**目标**:在 Phase 4a 基础上,赋予智能体"拆解目标 → 逐步执行"的能力。
|
||||
|
||||
**前置条件**:Phase 4a 已完成。
|
||||
|
||||
**交付物**:
|
||||
1. ✅ `TaskAgent` trait — `run(goal)` 自主式 + `execute_plan(plan)` 外部驱动式
|
||||
2. ✅ `PlanParser` trait + `JsonPlanParser` 参考实现
|
||||
3. ✅ `AgentError` 追加 PlanParse 变体(共 8 个变体)
|
||||
4. ✅ Hook 事件扩展:OnPlanStepComplete + plan_step_index 字段
|
||||
|
||||
**依赖**:Phase 4a
|
||||
|
||||
**优先级**:Could Have
|
||||
|
||||
**预估规模**:约 200 行代码(增量)
|
||||
|
||||
**实际新增**:
|
||||
- 修改文件 2 个(llm/hooks.rs +5 行;agent/error.rs +10 行)
|
||||
- 新增代码约 150 行(含测试;纯实现约 90 行)
|
||||
- 新增内联测试 4 个;全量测试 109 → 113(0 失败)
|
||||
- clippy 0 警告
|
||||
- 无新增外部依赖
|
||||
|
||||
**状态**:✅ Phase 4b 全部交付物已完成
|
||||
|
||||
---
|
||||
|
||||
### Phase 4c — Session Memory(会话级记忆)
|
||||
|
||||
**目标**:提供会话级 key-value 记忆,作为 session 内各 context 之间的信息桥接通道。
|
||||
|
||||
**前置条件**:Phase 4a 已完成(可与 Phase 4b 并行)。
|
||||
|
||||
**交付物**:
|
||||
1. ✅ `SessionMemory` struct — 基于 `MemoryStore`,按 session_id namespace 隔离
|
||||
2. ✅ `RuntimeBundle` + `AgentBuilder` 扩展 `session_memory_backend` 字段
|
||||
3. ✅ `AgentSession` 替换内联 HashMap 为完整 `SessionMemory`
|
||||
|
||||
**依赖**:Phase 4a(Phase 3 MemoryStore)
|
||||
|
||||
**优先级**:Could Have
|
||||
|
||||
**预估规模**:约 115 行代码(增量)
|
||||
|
||||
**实际新增**:
|
||||
- 新增文件 1 个(agent/session_memory.rs)
|
||||
- 修改文件 4 个(agent/runtime.rs +5 行;agent/builder.rs +10 行;agent/session.rs +30 行;agent.rs +2 行)
|
||||
- 新增代码约 180 行(含测试;纯实现约 100 行)
|
||||
- 新增内联测试 3 个;全量测试 113 → 116(0 失败)
|
||||
- clippy 0 警告
|
||||
- 无新增外部依赖
|
||||
|
||||
**状态**:✅ Phase 4c 全部交付物已完成
|
||||
|
||||
---
|
||||
|
||||
## 依赖关系图
|
||||
|
||||
```mermaid
|
||||
graph BT
|
||||
P0["<b>Phase 0: Foundation</b><br/>LLM Cycle<br/>ProviderRegistry<br/>HookExecutor<br/>StreamEvents<br/>Auto-compaction"]:::done
|
||||
P1["<b>Phase 1: Prompt Engineering</b><br/>PromptTemplate<br/>PromptComposer"]:::done
|
||||
P2["<b>Phase 2: Tool System</b><br/>Tool Registry<br/>PermissionChecker<br/>MCP Client"]:::done
|
||||
P3["<b>Phase 3: Memory System</b><br/>MemoryStore<br/>ConversationMemory<br/>KnowledgeStore"]:::done
|
||||
P4a["<b>Phase 4a: Core Glue</b><br/>AgentSession<br/>RuntimeBundle<br/>Plan/Step 纯数据"]:::done
|
||||
P4b["<b>Phase 4b: Task Execution</b><br/>TaskAgent<br/>PlanParser<br/>JsonPlanParser"]:::done
|
||||
P4c["<b>Phase 4c: Session Memory</b><br/>SessionMemory"]:::done
|
||||
|
||||
P1 --> P0
|
||||
P2 --> P0
|
||||
P3 --> P0
|
||||
P2 --> P1
|
||||
P4a --> P1
|
||||
P4a --> P2
|
||||
P4a --> P3
|
||||
P4b --> P4a
|
||||
P4c --> P4a
|
||||
|
||||
classDef done fill:#4ade80,stroke:#16a34a,color:#1a1a1a
|
||||
classDef pending fill:#fbbf24,stroke:#d97706,color:#1a1a1a
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 扩展计划(v0.2+)
|
||||
|
||||
> 以下功能在已完成的 phase 中已实现基础能力或在 Phase 4 阶段明确了边界,后续可按维度增量扩展。
|
||||
> 设计参考:见 `docs/note-agent-harness-references.md`(OpenClaw / Hermes / OpenHuman / OpenHarness 横向对比)。
|
||||
> OpenCode 借鉴:见 `docs/note-opencode-agent-switching.md`(Agent 切换 + System Prompt 拼接机制)。
|
||||
|
||||
### 已有扩展项(沿用)
|
||||
|
||||
| 扩展项 | 所在模块 | 说明 | 优先级 | 状态 |
|
||||
|-------|---------|------|--------|------|
|
||||
| Prompt Optimizer | `prompt` | 提示词自动优化 | P3 | 待实现 |
|
||||
| 流式接口优化 | `llm/stream` | 流式响应解析与事件化 | P0 | ✅ 已完成基础实现 |
|
||||
|
||||
### v0.2+ 新增扩展项
|
||||
|
||||
> 以下为基于 Phase 4 设计讨论确定的 v0.2+ 候选扩展方向,按维度分组。
|
||||
> 标注为"v0.2 待评估"表示在 Phase 4 完成后再决定是否启动。
|
||||
|
||||
#### Multi-Agent / 协同
|
||||
|
||||
| 扩展项 | 所在模块 | 说明 | 优先级 | 状态 |
|
||||
|-------|---------|------|--------|------|
|
||||
| Multi-Agent 协同(Swarm) | `agent` | 子 Agent 委派、并行子任务、结果聚合 | P2 | v0.2 待评估 |
|
||||
|
||||
#### 技能(Skills)
|
||||
|
||||
| 扩展项 | 所在模块 | 说明 | 优先级 | 状态 |
|
||||
|-------|---------|------|--------|------|
|
||||
| Markdown 技能按需加载 | `agent` / `prompt` | 兼容 `SKILL.md` 格式(Hermes / OpenHarness 风格),按 prompt 上下文动态加载 | P2 | v0.2 待评估 |
|
||||
|
||||
#### 记忆(Memory)
|
||||
|
||||
| 扩展项 | 所在模块 | 说明 | 优先级 | 状态 |
|
||||
|-------|---------|------|--------|------|
|
||||
| 多通道检索(hybrid) | `memory/retriever` | 在 TextOverlap 之上叠加向量检索通道 | P2 | v0.2 待评估 |
|
||||
| KnowledgeGraph 深度记忆 | `memory` | 实体-关系图、`note-knowledge-graph-design.md` 已记录设计 | P3 | v0.2 待评估 |
|
||||
| TokenJuice 智能压缩 | `memory` / `llm/compact` | 借鉴 OpenHuman TokenJuice,对工具结果做语义压缩而非字节截断 | P3 | v0.2 待评估 |
|
||||
|
||||
#### 交互层(TUI / Gateway)
|
||||
|
||||
| 扩展项 | 所在模块 | 说明 | 优先级 | 状态 |
|
||||
|-------|---------|------|--------|------|
|
||||
| TUI / 多平台 Gateway | 应用层 | OpenClaw / Hermes 风格的消息平台桥接(Feishu / Telegram / Discord 等) | P3 | v0.2+ 应用层 |
|
||||
|
||||
#### 训练基础设施
|
||||
|
||||
| 扩展项 | 所在模块 | 说明 | 优先级 | 状态 |
|
||||
|-------|---------|------|--------|------|
|
||||
| RL 轨迹导出 | `agent` | ShareGPT 格式轨迹、Atropos 集成(Hermes 风格) | P3 | v0.3+ 探索 |
|
||||
|
||||
#### 安全治理
|
||||
|
||||
| 扩展项 | 所在模块 | 说明 | 优先级 | 状态 |
|
||||
|-------|---------|------|--------|------|
|
||||
| Human-in-the-loop 审批 | `agent` / `tools/permission` | 高危工具执行前的异步审批回调(OpenHarness `permission_prompt` 模式) | P2 | v0.2 待评估 |
|
||||
|
||||
#### 流式 / 实时
|
||||
|
||||
| 扩展项 | 所在模块 | 说明 | 优先级 | 状态 |
|
||||
|-------|---------|------|--------|------|
|
||||
| 流式 `submit_turn` | `agent/session` | Phase 4 v1 只暴露非流式 `submit_turn()`;v0.2 包装 `LlmCycle::submit_stream` 暴露流式入口 | P2 | v0.2 待评估 |
|
||||
|
||||
#### Agent 切换 / Prompt 动态(OpenCode 借鉴)
|
||||
|
||||
| 扩展项 | 所在模块 | 说明 | 优先级 | 状态 |
|
||||
|-------|---------|------|--------|------|
|
||||
| Agent 身份切换(角色轮换) | `agent` | 借鉴 OpenCode Tab 键切换 build/plan:同一 `AgentSession` 持有可热替换的 `Agent` 引用,切换时不重置消息历史,在末尾追加 `synthetic: true` 的状态变更消息。详见 `docs/note-opencode-agent-switching.md` §4 | P2 | v0.2 待评估 |
|
||||
| System Prompt 多层动态拼接 | `agent/session` | 借鉴 OpenCode `request.ts:58-66`:拆分 `base_prompt + agent_prompt + env_context` 三层,`AgentSession::submit_turn` 每轮重算(不缓存),便于按 agent 类型动态切换 | P2 | v0.2 待评估 |
|
||||
| **多 Context 切换** | `agent` | **Phase 4c 的 SessionMemory 数据结构已预留信息桥接通道,v0.2+ 在其上包装 `ContextManager` 实现完整的多 context 切换:创建/销毁/切换 context、通过 SessionMemory 桥接关键信息。详见 `docs/note-context-switch-design.md`** | P2 | v0.2 待评估 |
|
||||
|
||||
---
|
||||
|
||||
## 风险与建议
|
||||
|
||||
1. **Phase 0 已完成**:LLM 调用周期基础设施已全部实现,可以支撑后续模块开发
|
||||
2. **并行可能性**:Phase 0 和 Phase 1 可并行开展(无相互依赖),可加速早期交付
|
||||
3. **MCP 协议复杂性**:MCP 涉及协议握手、session 管理、长期连接,建议预留充足时间调研协议细节
|
||||
4. **Scope 蔓延风险**:当前 specs 只有 1 份文档,建议每个模块上线前都产出对应 spec,避免边实现边设计
|
||||
5. **Phase 4 抽象化边界**:AG Core 定位为"支持库"而非"Agent 产品",Phase 4(4a/4b/4c)需严格控制范围——只暴露 trait + 最小 reference impl,业务循环(多轮 turn 编排、对话记忆自动回写、Task 拆解策略)留给上层应用。`SessionMemory`(Phase 4c)提供信息桥接通道但不实现 context 切换逻辑。多 context 切换管理延后至 v0.2+。详细设计决策见 `docs/7-agent-runtime.md`
|
||||
6. **参考项目语言差异**:OpenClaw / Hermes / OpenHarness 均为 Python/TypeScript 实现,OpenHuman 虽是 Rust + Tauri 但定位是桌面应用。借鉴时**只取架构模式**,不照搬具体实现(如 Pydantic 工具校验、SQLite Memory Tree、Node+Python 双进程等)
|
||||
|
||||
---
|
||||
|
||||
## 下一步行动
|
||||
|
||||
1. **Phase 4c 已完成**:Phase 4a + 4b + 4c 已交付(116 测试通过,0 clippy 警告)。可启动 v0.2+ 扩展评估(如多 Context 切换、Multi-Agent 协同等)
|
||||
2. **Context 切换备忘**:`docs/note-context-switch-design.md` 记录了多 context 切换方案讨论,作为 v0.2+ 扩展项的输入
|
||||
3. **参考项目调研沉淀**:已完成 OpenClaw / Hermes / OpenHuman / OpenHarness 横向调研,结果沉淀至 `docs/note-agent-harness-references.md`,作为 v0.2+ 扩展项的输入
|
||||
4. **Phase 3 备用设计就绪**:`docs/note-knowledge-graph-design.md` 记录了 KnowledgeGraph、高级评分、RecallBased 淘汰等设计,v0.2+ 记忆扩展可直接参考
|
||||
|
||||
**已完成 / 进行中阶段**:
|
||||
- ✅ Phase 0 Foundation — 全部交付物已完成
|
||||
- ✅ Phase 1 Prompt Engineering — 全部交付物已完成
|
||||
- ✅ Phase 2 Tool System — 全部交付物已完成
|
||||
- ✅ Phase 3 Memory System — 全部交付物已完成
|
||||
- ✅ Phase 4a Core Glue — 全部交付物已完成
|
||||
- ✅ Phase 4b Task Execution — 全部交付物已完成
|
||||
- ✅ Phase 4c Session Memory — 全部交付物已完成
|
||||
@@ -0,0 +1,75 @@
|
||||
use std::env;
|
||||
|
||||
use agcore::init_tracing;
|
||||
use agcore::llm::{
|
||||
cycle::{CycleConfig, LlmCycle},
|
||||
provider::{create_provider, ProviderConfig, ProviderType},
|
||||
types::{ChatResponse, OpenaiContentPart},
|
||||
};
|
||||
|
||||
fn extract_response_text(response: &ChatResponse) -> &str {
|
||||
match &response.message {
|
||||
agcore::llm::types::OpenaiChatMessage::Assistant { content, .. } => match content {
|
||||
agcore::llm::types::ContentField::String(s) => s,
|
||||
agcore::llm::types::ContentField::Array(parts) => {
|
||||
for part in parts {
|
||||
if let OpenaiContentPart::Text { text } = part {
|
||||
return text;
|
||||
}
|
||||
}
|
||||
"[无文本内容]"
|
||||
}
|
||||
},
|
||||
_ => "[非 assistant 消息]",
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
dotenvy::dotenv().ok();
|
||||
init_tracing();
|
||||
|
||||
let api_key = env::var("OPENAI_API_KEY").expect("未设置 OPENAI_API_KEY 环境变量");
|
||||
let base_url = env::var("OPENAI_BASE_URL").expect("未设置 OPENAI_BASE_URL 环境变量");
|
||||
let model = env::var("OPENAI_MODEL").expect("未设置 OPENAI_MODEL 环境变量");
|
||||
|
||||
let provider_type = env::var("PROVIDER")
|
||||
.unwrap_or_else(|_| "openai".into())
|
||||
.parse::<ProviderType>()
|
||||
.expect("无效的 PROVIDER 值");
|
||||
|
||||
let config = ProviderConfig {
|
||||
base_url,
|
||||
api_key,
|
||||
model: model.clone(),
|
||||
};
|
||||
|
||||
let provider = create_provider(provider_type, config)
|
||||
.expect("创建 Provider 失败");
|
||||
|
||||
let cycle_config = CycleConfig {
|
||||
model,
|
||||
max_tokens: Some(65536),
|
||||
temperature: Some(1.3),
|
||||
..CycleConfig::default()
|
||||
};
|
||||
|
||||
let mut cycle = LlmCycle::new(provider, cycle_config)
|
||||
.with_system_prompt("你是一个简洁的助手,对于任何问题都是用一句话回答。".to_string());
|
||||
|
||||
println!("发送请求...");
|
||||
|
||||
match cycle.submit("介绍一下你自己吧。".to_string(), vec![]).await {
|
||||
Ok(response) => {
|
||||
println!("LLM 回复:{}", extract_response_text(&response));
|
||||
println!(
|
||||
"Token 用量:{} 输入, {} 输出",
|
||||
response.usage.prompt_tokens, response.usage.completion_tokens
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("请求失败:{e}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"agent": {
|
||||
"plan": {
|
||||
"permission": {
|
||||
"edit": {
|
||||
"*": "deny",
|
||||
".opencode/plans/**": "allow",
|
||||
"docs/**": "allow"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
//! Agent Runtime —— 智能体(Agent)核心胶水层。
|
||||
//!
|
||||
//! 把 Phase 0-3 的能力(LlmCycle / ToolRegistry / MemoryStore / HookExecutor)"装配"为
|
||||
//! 上层可用的智能体抽象:`Agent` / `AgentSession` / `RuntimeBundle` / `AgentBuilder` / `Plan`。
|
||||
//!
|
||||
//! **不**实现业务循环,**不**假设上层如何使用 memory。
|
||||
//! 详细设计见 `docs/7-agent-runtime.md`。
|
||||
|
||||
// 模块根文件 `agent.rs` 与子模块 `agent/agent.rs` 同名(项目惯例,与 `llm/cycle.rs` 一致)。
|
||||
#![allow(clippy::module_inception)]
|
||||
|
||||
pub mod agent;
|
||||
pub mod builder;
|
||||
pub mod error;
|
||||
pub mod runtime;
|
||||
pub mod session;
|
||||
pub mod session_memory;
|
||||
pub mod task;
|
||||
|
||||
// 重导出公共 API(按使用频度排序)
|
||||
pub use agent::Agent;
|
||||
pub use builder::AgentBuilder;
|
||||
pub use error::AgentError;
|
||||
pub use runtime::{AgentConfig, RuntimeBundle};
|
||||
pub use session::AgentSession;
|
||||
pub use session_memory::SessionMemory;
|
||||
pub use task::{Plan, PlanParser, Step, StepStatus, TaskAgent};
|
||||
pub use task::JsonPlanParser;
|
||||
@@ -0,0 +1,30 @@
|
||||
//! Agent trait —— 智能体的"角色"抽象。
|
||||
//!
|
||||
//! 设计要点(参见 `docs/7-agent-runtime.md` §3.2.1):
|
||||
//!
|
||||
//! - **角色与会话分离**:`Agent` 定义"做什么、用什么工具",`AgentSession` 维护"当前状态"
|
||||
//! - **工具白名单扩展点**:默认从 `RuntimeBundle.tool_registry` 取全部,子 trait 可覆盖做白名单/过滤
|
||||
//! - **不绑定业务循环**:`submit_turn` 在 `AgentSession` 上,不在 trait 上
|
||||
|
||||
use crate::agent::runtime::RuntimeBundle;
|
||||
use crate::llm::types::ToolDefinition;
|
||||
|
||||
/// Agent 角色抽象。
|
||||
///
|
||||
/// 实现此 trait 即可接入 Agent Runtime。典型实现是 struct 持有静态配置(name、system prompt 模板),
|
||||
/// 也可以是基于配置动态生成的轻量实现。
|
||||
pub trait Agent: Send + Sync {
|
||||
/// 角色名(用于日志、调试、UI 展示)。
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// 系统提示词。无提示词的纯工具型 agent 返回 `None`。
|
||||
fn system_prompt(&self) -> Option<&str>;
|
||||
|
||||
/// 列出该 Agent 想暴露给 LLM 的工具定义。
|
||||
///
|
||||
/// **默认实现**:从 `bundle.tool_registry` 取全部工具(最常用模式)。
|
||||
/// **子 trait / 具体实现可覆盖**:做白名单、过滤、按状态动态调整等。
|
||||
fn tool_definitions(&self, bundle: &RuntimeBundle) -> Vec<ToolDefinition> {
|
||||
bundle.tool_registry.definitions()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
//! AgentBuilder —— `RuntimeBundle` 的链式构造入口。
|
||||
//!
|
||||
//! 设计原则:
|
||||
//!
|
||||
//! - **唯一构造入口**:上层应用不应直接 `RuntimeBundle::new`;用 `AgentBuilder` 保证必填字段
|
||||
//! 校验集中、默认值集中管理
|
||||
//! - **必填字段在 `build()` 时校验**:缺失返回 `AgentError::Config`,不 panic
|
||||
//! - **选填字段独立 setter**:未调用对应 setter 时使用 `None` 兜底
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::agent::error::AgentError;
|
||||
use crate::agent::runtime::{AgentConfig, RuntimeBundle};
|
||||
use crate::llm::hooks::HookExecutor;
|
||||
use crate::llm::provider::LlmProvider;
|
||||
use crate::memory::retriever::MemoryRetriever;
|
||||
use crate::memory::store::MemoryStore;
|
||||
use crate::tools::ToolRegistry;
|
||||
|
||||
/// `RuntimeBundle` 的链式构造器。
|
||||
///
|
||||
/// 使用示例:
|
||||
/// ```ignore
|
||||
/// let bundle = AgentBuilder::new()
|
||||
/// .provider(my_provider)
|
||||
/// .tool_registry(my_registry)
|
||||
/// .hook_executor(my_executor)
|
||||
/// .build()?;
|
||||
/// ```
|
||||
#[derive(Default)]
|
||||
pub struct AgentBuilder {
|
||||
provider: Option<Arc<dyn LlmProvider>>,
|
||||
tool_registry: Option<Arc<ToolRegistry>>,
|
||||
hook_executor: Option<Arc<HookExecutor>>,
|
||||
memory_store: Option<Arc<dyn MemoryStore>>,
|
||||
retriever: Option<Arc<MemoryRetriever>>,
|
||||
session_memory_backend: Option<Arc<dyn MemoryStore>>,
|
||||
config: Option<AgentConfig>,
|
||||
}
|
||||
|
||||
impl AgentBuilder {
|
||||
/// 创建一个空的 builder,所有必填字段均为 `None`。
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// 设置 LLM provider(必填)。
|
||||
pub fn provider(mut self, p: Arc<dyn LlmProvider>) -> Self {
|
||||
self.provider = Some(p);
|
||||
self
|
||||
}
|
||||
|
||||
/// 设置工具注册表(必填)。
|
||||
pub fn tool_registry(mut self, r: Arc<ToolRegistry>) -> Self {
|
||||
self.tool_registry = Some(r);
|
||||
self
|
||||
}
|
||||
|
||||
/// 设置钩子执行器(必填)。
|
||||
pub fn hook_executor(mut self, h: Arc<HookExecutor>) -> Self {
|
||||
self.hook_executor = Some(h);
|
||||
self
|
||||
}
|
||||
|
||||
/// 设置持久化记忆后端(选填,不传也能跑)。
|
||||
pub fn memory_store(mut self, m: Arc<dyn MemoryStore>) -> Self {
|
||||
self.memory_store = Some(m);
|
||||
self
|
||||
}
|
||||
|
||||
/// 设置记忆检索器(选填,不传也能跑)。
|
||||
pub fn retriever(mut self, r: Arc<MemoryRetriever>) -> Self {
|
||||
self.retriever = Some(r);
|
||||
self
|
||||
}
|
||||
|
||||
/// 设置 SessionMemory 后端(选填,不传则 `AgentSession` 内部用 `InMemoryStore` 兜底)。
|
||||
pub fn session_memory_backend(mut self, s: Arc<dyn MemoryStore>) -> Self {
|
||||
self.session_memory_backend = Some(s);
|
||||
self
|
||||
}
|
||||
|
||||
/// 整体覆盖 `AgentConfig`(选填,不传则用默认值)。
|
||||
pub fn config(mut self, c: AgentConfig) -> Self {
|
||||
self.config = Some(c);
|
||||
self
|
||||
}
|
||||
|
||||
/// 构造 `RuntimeBundle`,校验必填字段。
|
||||
///
|
||||
/// **错误**:`provider` / `tool_registry` / `hook_executor` 任一缺失则返回
|
||||
/// `AgentError::Config("missing <field>")`,不 panic。
|
||||
pub fn build(self) -> Result<RuntimeBundle, AgentError> {
|
||||
let provider = self
|
||||
.provider
|
||||
.ok_or_else(|| AgentError::Config("missing provider".into()))?;
|
||||
let tool_registry = self
|
||||
.tool_registry
|
||||
.ok_or_else(|| AgentError::Config("missing tool_registry".into()))?;
|
||||
let hook_executor = self
|
||||
.hook_executor
|
||||
.ok_or_else(|| AgentError::Config("missing hook_executor".into()))?;
|
||||
|
||||
let config = self.config.unwrap_or_default();
|
||||
|
||||
Ok(RuntimeBundle::new(
|
||||
provider,
|
||||
tool_registry,
|
||||
hook_executor,
|
||||
self.memory_store,
|
||||
self.retriever,
|
||||
self.session_memory_backend,
|
||||
config,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::error::LlmError;
    use crate::llm::provider::LlmProvider;
    use crate::llm::types::{ChatRequest, ChatResponse};
    use async_trait::async_trait;

    /// Provider stand-in; `chat` is never called by these builder tests.
    struct StubProvider;

    #[async_trait]
    impl LlmProvider for StubProvider {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, LlmError> {
            unimplemented!()
        }
    }

    // Small fixtures so each test only spells out what it leaves unset.
    fn stub_provider() -> Arc<dyn LlmProvider> {
        Arc::new(StubProvider)
    }

    fn empty_registry() -> Arc<ToolRegistry> {
        Arc::new(ToolRegistry::new())
    }

    fn empty_hooks() -> Arc<HookExecutor> {
        Arc::new(HookExecutor::new())
    }

    /// True when `result` is a `Config` error whose message mentions `field`.
    fn is_config_error_about(result: Result<RuntimeBundle, AgentError>, field: &str) -> bool {
        matches!(result, Err(AgentError::Config(msg)) if msg.contains(field))
    }

    #[test]
    fn build_with_all_required_succeeds() {
        let outcome = AgentBuilder::new()
            .provider(stub_provider())
            .tool_registry(empty_registry())
            .hook_executor(empty_hooks())
            .build();
        assert!(outcome.is_ok());
    }

    #[test]
    fn build_missing_provider_returns_config_error() {
        let outcome = AgentBuilder::new()
            .tool_registry(empty_registry())
            .hook_executor(empty_hooks())
            .build();
        assert!(is_config_error_about(outcome, "provider"));
    }

    #[test]
    fn build_missing_tool_registry_returns_config_error() {
        let outcome = AgentBuilder::new()
            .provider(stub_provider())
            .hook_executor(empty_hooks())
            .build();
        assert!(is_config_error_about(outcome, "tool_registry"));
    }

    #[test]
    fn build_missing_hook_executor_returns_config_error() {
        let outcome = AgentBuilder::new()
            .provider(stub_provider())
            .tool_registry(empty_registry())
            .build();
        assert!(is_config_error_about(outcome, "hook_executor"));
    }

    #[test]
    fn optional_fields_default_to_none() {
        let built = AgentBuilder::new()
            .provider(stub_provider())
            .tool_registry(empty_registry())
            .hook_executor(empty_hooks())
            .build()
            .unwrap();
        assert!(built.memory_store.is_none());
        assert!(built.retriever.is_none());
    }
}
|
||||
@@ -0,0 +1,183 @@
|
||||
//! Agent Runtime 统一错误类型。
|
||||
//!
|
||||
//! `AgentError` 聚合 Phase 0-3 各层错误(LlmError / ToolError / MemoryError),
|
||||
//! 加上 Agent 层特有的错误变体。设计原则:
|
||||
//!
|
||||
//! - 聚合而非包装:保留内层错误的类型信息(避免 `Box<dyn Error>` 丢失上下文)
|
||||
//! - 显式 `From` 实现:让 `?` 运算符能透明传播下层错误
|
||||
//! - `is_recoverable()`:根据变体类型判定可恢复性,便于上层决策
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::memory::error::MemoryError;
|
||||
use crate::tools::error::ToolError;
|
||||
|
||||
/// Agent Runtime 统一错误枚举。
|
||||
///
|
||||
/// **不实现 `Clone`**:透传内层 `LlmError` / `MemoryError`,两者均未派生 `Clone`(保留
|
||||
/// 完整错误信息,传递所有权)。如需在多 session 间共享错误状态,用 `Arc<AgentError>` 包装。
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AgentError {
|
||||
/// LLM 调用错误(透传 Phase 0)。
|
||||
#[error("LLM 错误: {0}")]
|
||||
Llm(#[from] LlmError),
|
||||
|
||||
/// 工具调用错误(透传 Phase 2)。
|
||||
#[error("工具错误: {0}")]
|
||||
Tool(#[from] ToolError),
|
||||
|
||||
/// 记忆系统错误(透传 Phase 3)。
|
||||
#[error("记忆错误: {0}")]
|
||||
Memory(#[from] MemoryError),
|
||||
|
||||
/// Plan 解析失败(Phase 4b 新增)。
|
||||
#[error("Plan 解析错误: {0}")]
|
||||
PlanParse(String),
|
||||
|
||||
/// 钩子阻断操作(Agent 层特有)。
|
||||
#[error("钩子阻断: {0}")]
|
||||
HookBlocked(String),
|
||||
|
||||
/// 达到限制阈值(最大 turn、token 预算等)。
|
||||
#[error("超过限制: {0}")]
|
||||
LimitExceeded(String),
|
||||
|
||||
/// 配置错误(构建 RuntimeBundle / AgentSession 时校验失败)。
|
||||
#[error("配置错误: {0}")]
|
||||
Config(String),
|
||||
|
||||
/// 其他未分类错误(兜底)。
|
||||
#[error("Agent 错误: {0}")]
|
||||
Other(String),
|
||||
}
|
||||
|
||||
impl AgentError {
|
||||
/// 判定错误是否可恢复。
|
||||
///
|
||||
/// - `Llm` / `Memory`:由内层 `is_recoverable()` 决定
|
||||
/// - `Tool`:由内层 `is_recoverable()` 决定
|
||||
/// - `HookBlocked` / `LimitExceeded`:不可恢复(需人工介入或终止循环)
|
||||
/// - `Config` / `Other`:不可恢复
|
||||
pub fn is_recoverable(&self) -> bool {
|
||||
match self {
|
||||
Self::Llm(e) => matches!(
|
||||
e,
|
||||
LlmError::RateLimit { .. } | LlmError::Timeout { .. } | LlmError::Stream(_)
|
||||
),
|
||||
Self::Tool(e) => e.is_recoverable(),
|
||||
Self::Memory(e) => e.is_recoverable(),
|
||||
Self::PlanParse(_) => false,
|
||||
Self::HookBlocked(_) | Self::LimitExceeded(_) | Self::Config(_) | Self::Other(_) => {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushes `source` through the `?` operator, exactly like a call site that returns
    /// `Result<_, AgentError>` would.
    fn lift<E>(source: E) -> Result<(), AgentError>
    where
        AgentError: From<E>,
    {
        Err::<(), E>(source)?;
        Ok(())
    }

    #[test]
    fn llm_recoverable_propagation() {
        let timeout = LlmError::Timeout {
            duration: std::time::Duration::from_secs(30),
        };
        assert!(AgentError::Llm(timeout).is_recoverable());
    }

    #[test]
    fn llm_non_recoverable_propagation() {
        let auth = LlmError::Authentication("bad key".into());
        assert!(!AgentError::Llm(auth).is_recoverable());
    }

    #[test]
    fn tool_recoverable_propagation() {
        let failed = ToolError::ExecutionFailed("foo".into(), "boom".into());
        assert!(AgentError::Tool(failed).is_recoverable());
    }

    #[test]
    fn tool_non_recoverable_propagation() {
        let missing = ToolError::NotFound("foo".into());
        assert!(!AgentError::Tool(missing).is_recoverable());
    }

    #[test]
    fn memory_recoverable_propagation() {
        let missing = MemoryError::NotFound("page".into());
        assert!(AgentError::Memory(missing).is_recoverable());
    }

    #[test]
    fn memory_non_recoverable_propagation() {
        let storage = MemoryError::Storage("disk full".into());
        assert!(!AgentError::Memory(storage).is_recoverable());
    }

    #[test]
    fn hook_blocked_not_recoverable() {
        let err = AgentError::HookBlocked("denied".into());
        assert!(!err.is_recoverable());
    }

    #[test]
    fn limit_exceeded_not_recoverable() {
        let err = AgentError::LimitExceeded("max turns".into());
        assert!(!err.is_recoverable());
    }

    #[test]
    fn config_not_recoverable() {
        let err = AgentError::Config("missing provider".into());
        assert!(!err.is_recoverable());
    }

    #[test]
    fn other_not_recoverable() {
        let err = AgentError::Other("unknown".into());
        assert!(!err.is_recoverable());
    }

    #[test]
    fn plan_parse_not_recoverable() {
        let err = AgentError::PlanParse("bad json".into());
        assert!(!err.is_recoverable());
    }

    #[test]
    fn from_llm_via_question_mark() {
        let err = lift(LlmError::Other("test".into())).unwrap_err();
        assert!(matches!(err, AgentError::Llm(_)));
    }

    #[test]
    fn from_tool_via_question_mark() {
        let err = lift(ToolError::NotFound("x".into())).unwrap_err();
        assert!(matches!(err, AgentError::Tool(_)));
    }

    #[test]
    fn from_memory_via_question_mark() {
        let err = lift(MemoryError::Storage("x".into())).unwrap_err();
        assert!(matches!(err, AgentError::Memory(_)));
    }
}
|
||||
@@ -0,0 +1,121 @@
|
||||
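//! Runtime Bundle — an explicit dependency-injection container (OpenHarness style).
//!
//! Holds every runtime dependency an agent needs in one place: `LlmProvider`, `ToolRegistry`,
//! `HookExecutor`, `MemoryStore` (optional), `MemoryRetriever` (optional) and `AgentConfig`.
//!
//! **Design intent** (see `docs/7-agent-runtime.md` §3.2.2):
//!
//! - All runtime dependencies are bundled explicitly, which makes them easy to share across
//!   `AgentSession`s and easy to replace with mocks in tests.
//! - `memory_store` / `retriever` are `Option`s: an application may omit them and still run
//!   (memory-less mode).
//! - A `Some` retriever is only stored for now (v0.1 placeholder): Phase 4a neither registers
//!   a `"retrieve"` tool nor calls the retriever from `submit_turn`; the scope of Phase 4a is
//!   limited to making the wiring possible, and the real `RetrieveTool` is planned for v0.2.
//! - The provider is held as `Arc<dyn LlmProvider>` rather than `Box<dyn LlmProvider>` so that
//!   several sessions can share it.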
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::llm::compact::CompactConfig;
|
||||
use crate::llm::provider::LlmProvider;
|
||||
use crate::llm::hooks::HookExecutor;
|
||||
use crate::memory::retriever::MemoryRetriever;
|
||||
use crate::memory::store::MemoryStore;
|
||||
use crate::tools::ToolRegistry;
|
||||
|
||||
/// Agent 运行配置。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AgentConfig {
|
||||
/// 单次会话最大 turn 数(含工具循环内部 turn),默认 50。
|
||||
pub max_turns: u32,
|
||||
/// 单次会话最大工具循环轮次(与 LlmCycle 的 `max_tool_turns` 对齐),默认 10。
|
||||
pub max_tool_turns: u32,
|
||||
/// 会话 TTL(None 表示无过期),默认 None。
|
||||
pub session_ttl: Option<Duration>,
|
||||
/// 上下文压缩配置(None 表示不启用自动压缩),默认 None。
|
||||
pub compact_config: Option<CompactConfig>,
|
||||
}
|
||||
|
||||
impl Default for AgentConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_turns: 50,
|
||||
max_tool_turns: 10,
|
||||
session_ttl: None,
|
||||
compact_config: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Agent Runtime 依赖注入容器。
|
||||
///
|
||||
/// 通过 `AgentBuilder::build()` 构造;构造完成后内部为只读视图。
|
||||
/// `Arc` 共享,多个 `AgentSession` 可共用同一个 bundle。
|
||||
#[derive(Clone)]
|
||||
pub struct RuntimeBundle {
|
||||
/// LLM 后端(强引用,多 session 共享)。
|
||||
pub provider: Arc<dyn LlmProvider>,
|
||||
|
||||
/// 工具注册表(强引用,多 session 共享)。
|
||||
pub tool_registry: Arc<ToolRegistry>,
|
||||
|
||||
/// 钩子执行器(强引用,多 session 共享)。
|
||||
pub hook_executor: Arc<HookExecutor>,
|
||||
|
||||
/// 持久化记忆后端(弱引用 —— 不传也能跑)。
|
||||
pub memory_store: Option<Arc<dyn MemoryStore>>,
|
||||
|
||||
/// 记忆检索器(弱引用 —— 不传也能跑)。
|
||||
/// 传入时可在 `submit_turn` 内部将检索能力作为工具暴露给 LLM。
|
||||
pub retriever: Option<Arc<MemoryRetriever>>,
|
||||
|
||||
/// SessionMemory 后端(选填)。
|
||||
/// 传入时 `SessionMemory` 使用该后端(支持跨进程共享);
|
||||
/// 不传时 `AgentSession` 内部自动创建 `InMemoryStore` 作为进程级隔离的后端。
|
||||
pub session_memory_backend: Option<Arc<dyn MemoryStore>>,
|
||||
|
||||
/// 运行时配置。
|
||||
pub config: AgentConfig,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for RuntimeBundle {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("RuntimeBundle")
|
||||
.field("provider_type", &"<dyn LlmProvider>")
|
||||
.field("tool_names", &self.tool_registry.list_tools())
|
||||
.field("has_memory_store", &self.memory_store.is_some())
|
||||
.field("has_retriever", &self.retriever.is_some())
|
||||
.field(
|
||||
"has_session_memory_backend",
|
||||
&self.session_memory_backend.is_some(),
|
||||
)
|
||||
.field("config", &self.config)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl RuntimeBundle {
|
||||
/// 构造一个 `RuntimeBundle`。
|
||||
///
|
||||
/// **Phase 4a 行为**:`retriever` 存在时仅占位记录,不真正注入工具
|
||||
/// (v0.1 不在 `submit_turn` 中启用检索;Phase 4c 之后再决定是否注册成 tool)。
|
||||
/// 真正的工具注入留待 v0.2 接入 `RetrieveTool` 实现。
|
||||
pub fn new(
|
||||
provider: Arc<dyn LlmProvider>,
|
||||
tool_registry: Arc<ToolRegistry>,
|
||||
hook_executor: Arc<HookExecutor>,
|
||||
memory_store: Option<Arc<dyn MemoryStore>>,
|
||||
retriever: Option<Arc<MemoryRetriever>>,
|
||||
session_memory_backend: Option<Arc<dyn MemoryStore>>,
|
||||
config: AgentConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
provider,
|
||||
tool_registry,
|
||||
hook_executor,
|
||||
memory_store,
|
||||
retriever,
|
||||
session_memory_backend,
|
||||
config,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,356 @@
|
||||
//! AgentSession —— 智能体"会话"实例。
|
||||
//!
|
||||
//! 设计要点(参见 `docs/7-agent-runtime.md` §3.2.3):
|
||||
//!
|
||||
//! - **会话 = 角色 + 状态**:绑定 `session_id` / `agent` / `bundle`,累计 `turn_index` 和 `cost_so_far`
|
||||
//! - **最小 reference impl**:`submit_turn` 演示"组装 LlmCycle → submit_with_tools → 累计 cost"的标准流程
|
||||
//! - **不做业务循环**:多轮策略、错误重试、记忆回写由上层应用或具体 `TaskAgent` 决定
|
||||
//! - **不持有 ConversationMemory**:上层可独立 new 一个 `ConversationMemory`,在合适的时机调 `add_message`
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::agent::agent::Agent;
|
||||
use crate::agent::error::AgentError;
|
||||
use crate::agent::runtime::RuntimeBundle;
|
||||
use crate::agent::session_memory::SessionMemory;
|
||||
use crate::llm::cycle::{CostTracker, CycleConfig, LlmCycle};
|
||||
use crate::llm::hooks::{HookContext, HookEvent};
|
||||
use crate::llm::types::ChatResponse;
|
||||
use crate::memory::store::InMemoryStore;
|
||||
|
||||
/// Agent 会话实例。
|
||||
///
|
||||
/// 同一 `Agent` 可被多个 `AgentSession` 复用(不同 session_id 互不干扰)。
|
||||
/// `submit_turn` 一次只跑一轮 LLM 调用(含自动 tool 循环)。
|
||||
///
|
||||
/// **不实现 `Clone`**:session 持有累计 `turn_index` / `cost_so_far` / `session_memory`,
|
||||
/// 共享这些状态需要显式 sync 语义;如果上层需要并发访问,自己用 `Arc<Mutex<_>>` 包装。
|
||||
pub struct AgentSession {
|
||||
/// 会话 ID(由调用方指定,用于日志/追踪/记忆关联)。
|
||||
pub session_id: String,
|
||||
/// 角色(可热切换为同 bundle 下的其他角色)。
|
||||
pub agent: Arc<dyn Agent>,
|
||||
bundle: Arc<RuntimeBundle>,
|
||||
turn_index: u32,
|
||||
cost_so_far: CostTracker,
|
||||
/// 会话级记忆(Phase 4c 替换内联 HashMap)。
|
||||
pub session_memory: SessionMemory,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AgentSession {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("AgentSession")
|
||||
.field("session_id", &self.session_id)
|
||||
.field("agent", &self.agent.name())
|
||||
.field("turn_index", &self.turn_index)
|
||||
.field("cost_so_far", &self.cost_so_far.total())
|
||||
.field("session_memory", &"<SessionMemory>")
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentSession {
|
||||
/// 创建一个新的会话实例。
|
||||
///
|
||||
/// `agent` 与 `bundle` 共同决定 `submit_turn` 行为:system_prompt / 工具集 / LLM 后端均来自它们。
|
||||
pub fn new(
|
||||
agent: Arc<dyn Agent>,
|
||||
session_id: impl Into<String>,
|
||||
bundle: Arc<RuntimeBundle>,
|
||||
) -> Self {
|
||||
let session_id_str = session_id.into();
|
||||
let backend = bundle
|
||||
.session_memory_backend
|
||||
.clone()
|
||||
.unwrap_or_else(|| Arc::new(InMemoryStore::new()));
|
||||
let session_memory = SessionMemory::new(backend, &session_id_str);
|
||||
Self {
|
||||
session_id: session_id_str,
|
||||
agent,
|
||||
bundle,
|
||||
turn_index: 0,
|
||||
cost_so_far: CostTracker::default(),
|
||||
session_memory,
|
||||
}
|
||||
}
|
||||
|
||||
/// 当前 turn 序号(0-based:第一次 `submit_turn` 完成后变 1)。
|
||||
pub fn turn_index(&self) -> u32 {
|
||||
self.turn_index
|
||||
}
|
||||
|
||||
/// 累计用量(跨所有 turn)。
|
||||
pub fn usage(&self) -> &CostTracker {
|
||||
&self.cost_so_far
|
||||
}
|
||||
|
||||
/// 会话级记忆引用。
|
||||
pub fn session_memory(&self) -> &SessionMemory {
|
||||
&self.session_memory
|
||||
}
|
||||
|
||||
/// 写入一条会话级数据(覆盖同名 key)。
|
||||
pub async fn set_session_data(
|
||||
&mut self,
|
||||
key: impl Into<String>,
|
||||
value: impl Into<String>,
|
||||
) -> Result<(), AgentError> {
|
||||
self.session_memory.set(&key.into(), &value.into()).await
|
||||
}
|
||||
|
||||
/// 读取一条会话级数据。
|
||||
pub async fn get_session_data(&self, key: &str) -> Result<Option<String>, AgentError> {
|
||||
self.session_memory.get(key).await
|
||||
}
|
||||
|
||||
/// 提交一轮对话(含自动 tool 循环),返回 LLM 响应。
|
||||
///
|
||||
/// 流程:
|
||||
/// 1. 触发 `OnTurnStart` hook
|
||||
/// 2. 组装 `LlmCycle`(注入 system_prompt / hook_executor / compact_config / 消息历史)
|
||||
/// 3. `submit_with_tools` 跑单轮对话
|
||||
/// 4. 累计 `cost_so_far`
|
||||
/// 5. 触发 `OnTurnEnd` hook
|
||||
/// 6. `turn_index += 1`
|
||||
///
|
||||
/// **不做**:
|
||||
/// - 不持有 `ConversationMemory`(由上层独立 task 决定何时回写)
|
||||
/// - 不做 Plan 拆解(Phase 4b 才加 `TaskAgent`)
|
||||
/// - 不做 session_data 持久化(Phase 4c 替换为 `SessionMemory`)
|
||||
pub async fn submit_turn(
|
||||
&mut self,
|
||||
user_input: impl Into<String>,
|
||||
) -> Result<ChatResponse, AgentError> {
|
||||
let turn_index = self.turn_index;
|
||||
let hook_executor = Arc::clone(&self.bundle.hook_executor);
|
||||
|
||||
// 1. 触发 OnTurnStart hook
|
||||
let start_ctx =
|
||||
HookContext::new(HookEvent::OnTurnStart).with_turn_index(turn_index);
|
||||
hook_executor
|
||||
.execute(HookEvent::OnTurnStart, &start_ctx)
|
||||
.await;
|
||||
|
||||
// 2. 组装 LlmCycle —— 共享 bundle 中的 provider 句柄
|
||||
// 工具列表从 agent.tool_definitions(bundle) 派生(默认 = bundle 全量);
|
||||
// submit_with_tools 内部从 registry 自行取 definitions,此处仅消费以触发
|
||||
// 子 trait 覆盖(白名单/过滤)的副作用。
|
||||
let _ = self.agent.tool_definitions(&self.bundle);
|
||||
let mut cycle = LlmCycle::new_with_arc(Arc::clone(&self.bundle.provider), CycleConfig::default())
|
||||
.with_messages(Vec::new());
|
||||
if let Some(prompt) = self.agent.system_prompt() {
|
||||
cycle = cycle.with_system_prompt(prompt.to_string());
|
||||
}
|
||||
if let Some(cfg) = self.bundle.config.compact_config.clone() {
|
||||
cycle = cycle.with_compact_config(cfg);
|
||||
}
|
||||
|
||||
// 3. 提交(HookExecutor 不在这里传——内部 hook 由 LlmCycle 在 PreRequest/PostRequest 触发)
|
||||
let response = cycle
|
||||
.submit_with_tools(user_input.into(), &self.bundle.tool_registry)
|
||||
.await?;
|
||||
|
||||
// 4. 累计 cost
|
||||
self.cost_so_far.add(&response.usage);
|
||||
|
||||
// 5. 触发 OnTurnEnd hook
|
||||
let end_ctx = HookContext::new(HookEvent::OnTurnEnd).with_turn_index(turn_index);
|
||||
hook_executor.execute(HookEvent::OnTurnEnd, &end_ctx).await;
|
||||
|
||||
// 6. turn_index 递增
|
||||
self.turn_index += 1;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::builder::AgentBuilder;
    use crate::llm::hooks::{Hook, HookContext, HookExecutor, HookResult};
    use crate::llm::provider::LlmProvider;
    use crate::llm::types::{
        ChatRequest, ChatResponse, FinishReason, OpenaiChatMessage,
    };
    use crate::tools::ToolRegistry;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Hook that bumps a shared counter every time it runs and always allows the operation.
    struct CountHook(Arc<AtomicU32>);

    #[async_trait]
    impl Hook for CountHook {
        async fn execute(&self, _ctx: &HookContext<'_>) -> HookResult {
            self.0.fetch_add(1, Ordering::SeqCst);
            HookResult::allow()
        }
    }

    /// Provider that hands out the prepared responses in order and fails once they run out.
    struct MockProvider {
        responses: std::sync::Mutex<std::collections::VecDeque<ChatResponse>>,
    }

    impl MockProvider {
        fn new(responses: Vec<ChatResponse>) -> Self {
            Self {
                responses: std::sync::Mutex::new(responses.into()),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, crate::llm::error::LlmError> {
            let next = self.responses.lock().unwrap().pop_front();
            next.ok_or_else(|| crate::llm::error::LlmError::Other("no more mock responses".into()))
        }
    }

    /// Minimal agent: a name plus an optional system prompt.
    struct StubAgent {
        name: String,
        prompt: Option<String>,
    }

    impl Agent for StubAgent {
        fn name(&self) -> &str {
            &self.name
        }

        fn system_prompt(&self) -> Option<&str> {
            self.prompt.as_deref()
        }
    }

    /// Agent named "stub" with the given system prompt.
    fn stub_agent(prompt: Option<&str>) -> Arc<StubAgent> {
        Arc::new(StubAgent {
            name: "stub".into(),
            prompt: prompt.map(Into::into),
        })
    }

    /// Assistant text response with a usage of 10 input / 5 output tokens.
    fn assistant_text(text: &str) -> ChatResponse {
        ChatResponse {
            message: OpenaiChatMessage::assistant_text(text),
            usage: crate::llm::types::Usage::from_input_output(10, 5),
            stop_reason: Some(FinishReason::Stop),
        }
    }

    /// Bundle with an empty tool registry, the given hooks and a mock provider that replays
    /// `responses`.
    fn bundle_with(responses: Vec<ChatResponse>, hooks: HookExecutor) -> Arc<RuntimeBundle> {
        let provider = Arc::new(MockProvider::new(responses));
        let built = AgentBuilder::new()
            .provider(provider)
            .tool_registry(Arc::new(ToolRegistry::new()))
            .hook_executor(Arc::new(hooks))
            .build()
            .unwrap();
        Arc::new(built)
    }

    /// Smoke test 1: `AgentSession::submit_turn` works end to end against the mock provider.
    #[tokio::test]
    async fn submit_turn_runs_with_mock_provider() {
        let bundle = bundle_with(vec![assistant_text("hello back")], HookExecutor::new());
        let mut session =
            AgentSession::new(stub_agent(Some("you are a test agent")), "s1", bundle);
        assert_eq!(session.turn_index(), 0);

        let response = session.submit_turn("hi").await.unwrap();
        let text = match &response.message {
            OpenaiChatMessage::Assistant {
                content: crate::llm::types::ContentField::String(body),
                ..
            } => body.clone(),
            _ => String::new(),
        };

        assert_eq!(text, "hello back");
        assert_eq!(session.turn_index(), 1);
        let total = session.usage().total();
        assert_eq!(total.prompt_tokens, 10);
        assert_eq!(total.completion_tokens, 5);
    }

    /// Smoke test 2: session data can be written, read and overwritten.
    #[tokio::test]
    async fn session_data_set_get() {
        let bundle = bundle_with(vec![], HookExecutor::new());
        let mut session = AgentSession::new(stub_agent(None), "s2", bundle);

        assert!(session.get_session_data("k").await.unwrap().is_none());

        session.set_session_data("k", "v").await.unwrap();
        assert_eq!(session.get_session_data("k").await.unwrap(), Some("v".into()));

        // Overwrite.
        session.set_session_data("k", "v2").await.unwrap();
        assert_eq!(session.get_session_data("k").await.unwrap(), Some("v2".into()));
    }

    /// Smoke test 3: every `submit_turn` fires `OnTurnStart` and `OnTurnEnd` once.
    #[tokio::test]
    async fn submit_turn_triggers_turn_hooks() {
        let started = Arc::new(AtomicU32::new(0));
        let ended = Arc::new(AtomicU32::new(0));

        let mut hooks = HookExecutor::new();
        hooks.register(
            HookEvent::OnTurnStart,
            Box::new(CountHook(Arc::clone(&started))),
        );
        hooks.register(
            HookEvent::OnTurnEnd,
            Box::new(CountHook(Arc::clone(&ended))),
        );

        let bundle = bundle_with(vec![assistant_text("ok"), assistant_text("ok 2")], hooks);
        let mut session = AgentSession::new(stub_agent(None), "s3", bundle);

        session.submit_turn("hi").await.unwrap();
        assert_eq!(started.load(Ordering::SeqCst), 1);
        assert_eq!(ended.load(Ordering::SeqCst), 1);

        session.submit_turn("hi again").await.unwrap();
        assert_eq!(started.load(Ordering::SeqCst), 2);
        assert_eq!(ended.load(Ordering::SeqCst), 2);
    }
}
|
||||
@@ -0,0 +1,184 @@
|
||||
//! SessionMemory —— 会话级记忆,用于 context 间的信息桥接。
|
||||
//!
|
||||
//! 设计要点(参见 `docs/7-agent-runtime.md` §3.2.8):
|
||||
//!
|
||||
//! - **会话级**:单 session 内共享,跨 context 桥接信息(不是持久层,也不是对话历史)
|
||||
//! - **复用 Phase 3 `MemoryStore`**:不引入新的存储后端机制
|
||||
//! - **按 `namespace` 隔离**:每个 session 一个独立命名空间,防止跨 session 泄漏
|
||||
//! - **`snapshot()` 格式化为标记文本**:专为注入 system prompt 设计
|
||||
//! - **所有方法为 `async`**:因为后端可能是跨进程的(Redis / DB)
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use time::OffsetDateTime;
|
||||
|
||||
use crate::agent::error::AgentError;
|
||||
use crate::memory::store::MemoryStore;
|
||||
use crate::memory::types::{MemoryFilter, MemoryItem};
|
||||
|
||||
/// 会话级记忆实例。
|
||||
///
|
||||
/// 基于 [`MemoryStore`] 后端,按 `namespace` 隔离键值数据。
|
||||
/// 适用于 session 内各 context 之间的信息桥接(如将关键结论传递给后续 context)。
|
||||
pub struct SessionMemory {
|
||||
store: Arc<dyn MemoryStore>,
|
||||
namespace: String,
|
||||
}
|
||||
|
||||
impl SessionMemory {
|
||||
/// 创建新的 session 级记忆实例。
|
||||
///
|
||||
/// - `store`:后端存储(可跨进程共享的 `MemoryStore` 实现)。
|
||||
/// - `namespace`:按 session_id 隔离,防止跨 session 泄漏。
|
||||
/// 内部会自动添加 `"_session_"` 前缀。
|
||||
pub fn new(store: Arc<dyn MemoryStore>, namespace: &str) -> Self {
|
||||
Self {
|
||||
store,
|
||||
namespace: format!("_session_{namespace}"),
|
||||
}
|
||||
}
|
||||
|
||||
/// 内部 key 格式:`"{namespace}:{key}"`。
|
||||
fn internal_key(&self, key: &str) -> String {
|
||||
format!("{}:{}", self.namespace, key)
|
||||
}
|
||||
|
||||
/// 写入一条 key-value 条目(覆盖同名 key)。
|
||||
pub async fn set(&self, key: &str, value: &str) -> Result<(), AgentError> {
|
||||
let item = MemoryItem {
|
||||
id: self.internal_key(key),
|
||||
content: value.to_string(),
|
||||
metadata: serde_json::json!({}),
|
||||
created_at: OffsetDateTime::now_utc(),
|
||||
};
|
||||
self.store.save(item).await.map_err(AgentError::Memory)
|
||||
}
|
||||
|
||||
/// 读取指定 key 的值。
|
||||
pub async fn get(&self, key: &str) -> Result<Option<String>, AgentError> {
|
||||
let item = self
|
||||
.store
|
||||
.get(&self.internal_key(key))
|
||||
.await
|
||||
.map_err(AgentError::Memory)?;
|
||||
Ok(item.map(|i| i.content))
|
||||
}
|
||||
|
||||
/// 返回所有条目的格式化快照,适合注入 system prompt。
|
||||
///
|
||||
/// 格式:
|
||||
/// ```text
|
||||
/// <session-context>
|
||||
/// key1: value1
|
||||
/// key2: value2
|
||||
/// </session-context>
|
||||
/// ```
|
||||
pub async fn snapshot(&self) -> Result<String, AgentError> {
|
||||
let filter = MemoryFilter {
|
||||
prefix: Some(format!("{}:", self.namespace)),
|
||||
..Default::default()
|
||||
};
|
||||
let items = self
|
||||
.store
|
||||
.list(&filter)
|
||||
.await
|
||||
.map_err(AgentError::Memory)?;
|
||||
|
||||
let mut lines = Vec::with_capacity(items.len() + 2);
|
||||
lines.push("<session-context>".to_string());
|
||||
for item in items {
|
||||
// 从 id 中提取原始 key(去掉 namespace 前缀)
|
||||
let key = item
|
||||
.id
|
||||
.strip_prefix(&format!("{}:", self.namespace))
|
||||
.unwrap_or(&item.id);
|
||||
lines.push(format!("{}: {}", key, item.content));
|
||||
}
|
||||
lines.push("</session-context>".to_string());
|
||||
|
||||
Ok(lines.join("\n"))
|
||||
}
|
||||
|
||||
/// 删除指定 key。
|
||||
pub async fn remove(&self, key: &str) -> Result<(), AgentError> {
|
||||
self.store
|
||||
.delete(&self.internal_key(key))
|
||||
.await
|
||||
.map_err(AgentError::Memory)
|
||||
}
|
||||
|
||||
/// 清空当前 namespace 下所有条目。
|
||||
pub async fn clear(&self) -> Result<(), AgentError> {
|
||||
let filter = MemoryFilter {
|
||||
prefix: Some(format!("{}:", self.namespace)),
|
||||
..Default::default()
|
||||
};
|
||||
let items = self
|
||||
.store
|
||||
.list(&filter)
|
||||
.await
|
||||
.map_err(AgentError::Memory)?;
|
||||
|
||||
for item in items {
|
||||
self.store
|
||||
.delete(&item.id)
|
||||
.await
|
||||
.map_err(AgentError::Memory)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::store::InMemoryStore;

    /// Fresh in-memory backend behind the trait object `SessionMemory` expects.
    fn make_store() -> Arc<dyn MemoryStore> {
        Arc::new(InMemoryStore::new())
    }

    /// Smoke test 1: basic `set` / `get` / `remove` round trip.
    #[tokio::test]
    async fn set_get_remove() {
        let memory = SessionMemory::new(make_store(), "test-session");

        let before = memory.get("k").await.unwrap();
        assert!(before.is_none());

        memory.set("k", "v").await.unwrap();
        let stored = memory.get("k").await.unwrap();
        assert_eq!(stored, Some("v".into()));

        memory.remove("k").await.unwrap();
        let after = memory.get("k").await.unwrap();
        assert!(after.is_none());
    }

    /// Smoke test 2: `snapshot` wraps the entries in the `<session-context>` markers.
    #[tokio::test]
    async fn snapshot_format() {
        let memory = SessionMemory::new(make_store(), "s1");
        for (key, value) in [("design", "PostgreSQL"), ("lang", "Rust")] {
            memory.set(key, value).await.unwrap();
        }

        let rendered = memory.snapshot().await.unwrap();
        let expected_parts = [
            "<session-context>",
            "</session-context>",
            "design: PostgreSQL",
            "lang: Rust",
        ];
        for part in expected_parts {
            assert!(rendered.contains(part));
        }
    }

    /// Smoke test 3: `clear` only removes entries of its own namespace.
    #[tokio::test]
    async fn clear_only_affects_own_namespace() {
        let shared = make_store();
        let first = SessionMemory::new(Arc::clone(&shared), "a");
        let second = SessionMemory::new(Arc::clone(&shared), "b");

        first.set("key", "val_a").await.unwrap();
        second.set("key", "val_b").await.unwrap();

        first.clear().await.unwrap();

        assert!(first.get("key").await.unwrap().is_none());
        assert_eq!(second.get("key").await.unwrap(), Some("val_b".into()));
    }
}
|
||||
@@ -0,0 +1,243 @@
|
||||
//! 任务规划数据结构 + Phase 4b 任务执行 trait。
|
||||
//!
|
||||
//! Phase 4a 范围:仅 `Plan` / `Step` / `StepStatus` 纯数据结构。
|
||||
//! Phase 4b 在此文件追加 `TaskAgent` trait / `PlanParser` trait / `JsonPlanParser` 参考实现。
|
||||
//!
|
||||
//! 设计意图(参见 `docs/7-agent-runtime.md` §3.2.4、§3.3.1):
|
||||
//!
|
||||
//! - `StepStatus` 用 enum 而非简单 bool,便于 UI 展示和统计
|
||||
//! - 状态机单向:`Pending → Running → (Completed | Failed | Skipped)`,不回退
|
||||
//! - 重试由上层新建 `Plan` 实现,`TaskAgent` 不做自动重试
|
||||
|
||||
use crate::agent::error::AgentError;
|
||||
use crate::llm::types::ChatResponse;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
/// 任务规划 —— 一组有序的 Step。
|
||||
#[derive(Debug)]
|
||||
pub struct Plan {
|
||||
/// 规划唯一标识。
|
||||
pub id: String,
|
||||
/// 规划目标(人类可读)。
|
||||
pub goal: String,
|
||||
/// 步骤列表。
|
||||
pub steps: Vec<Step>,
|
||||
}
|
||||
|
||||
/// 任务步骤。
|
||||
#[derive(Debug)]
|
||||
pub struct Step {
|
||||
/// 步骤在 Plan 中的位置(0-based)。
|
||||
pub index: usize,
|
||||
/// 步骤描述(注入 LLM 作为 user prompt)。
|
||||
pub description: String,
|
||||
/// 当前状态。
|
||||
pub status: StepStatus,
|
||||
}
|
||||
|
||||
impl Step {
|
||||
/// 创建一个初始为 `Pending` 的步骤。
|
||||
pub fn new(index: usize, description: impl Into<String>) -> Self {
|
||||
Self {
|
||||
index,
|
||||
description: description.into(),
|
||||
status: StepStatus::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 步骤状态机。
|
||||
///
|
||||
/// 转换路径:`Pending → Running → (Completed | Failed | Skipped)`,单向不回退。
|
||||
///
|
||||
/// **不实现 `Clone`**:`Failed` 变体携带 `AgentError`,下层 `LlmError` / `MemoryError`
|
||||
/// 均未派生 `Clone`(保留原始错误信息,传递所有权而非克隆)。如需复制 `Plan`,
|
||||
/// 只能 clone 处于 `Pending` / `Running` / `Completed` / `Skipped` 状态的步骤。
|
||||
#[derive(Debug)]
|
||||
pub enum StepStatus {
|
||||
/// 初始状态 —— 等待执行。
|
||||
Pending,
|
||||
/// 正在执行(`TaskAgent::execute_plan` 进入)。
|
||||
Running,
|
||||
/// 已完成(含 LLM 响应)。
|
||||
Completed(ChatResponse),
|
||||
/// 失败(含错误)。
|
||||
Failed(AgentError),
|
||||
/// 跳过(上层主动跳过)。
|
||||
Skipped,
|
||||
}
|
||||
|
||||
impl StepStatus {
|
||||
/// 状态是否处于"未完成"。
|
||||
pub fn is_pending(&self) -> bool {
|
||||
matches!(self, Self::Pending)
|
||||
}
|
||||
|
||||
/// 状态是否处于终态。
|
||||
pub fn is_terminal(&self) -> bool {
|
||||
matches!(self, Self::Completed(_) | Self::Failed(_) | Self::Skipped)
|
||||
}
|
||||
}
|
||||
|
||||
/// Plan 解析接口 —— 将 LLM 原始输出转换为 `Plan` 数据结构。
|
||||
///
|
||||
/// **注入式**:上层应用可以注入自定义解析器(如基于 XML / YAML / 自定义 DSL),
|
||||
/// `JsonPlanParser` 是参考实现而非默认实现。
|
||||
#[async_trait]
|
||||
pub trait PlanParser: Send + Sync {
|
||||
/// 将 LLM 原始输出解析为 `Plan`。
|
||||
///
|
||||
/// - `raw`:LLM 返回的原始文本
|
||||
/// - `goal`:规划目标(用于填充 `Plan.goal`)
|
||||
async fn parse(&self, raw: &str, goal: &str) -> Result<Plan, AgentError>;
|
||||
}
|
||||
|
||||
/// JSON 格式的 Plan 解析器(参考实现)。
|
||||
///
|
||||
/// 期望 LLM 输出形如:
|
||||
/// ```json
|
||||
/// {"steps": [{"description": "..."}, ...]}
|
||||
/// ```
|
||||
/// 的 JSON 文本。解析失败返回 `AgentError::PlanParse`。
|
||||
pub struct JsonPlanParser;
|
||||
|
||||
#[async_trait]
|
||||
impl PlanParser for JsonPlanParser {
|
||||
async fn parse(&self, raw: &str, goal: &str) -> Result<Plan, AgentError> {
|
||||
let parsed: serde_json::Value = serde_json::from_str(raw)
|
||||
.map_err(|e| AgentError::PlanParse(format!("JSON 解析失败: {e}")))?;
|
||||
|
||||
let steps_array = parsed
|
||||
.get("steps")
|
||||
.and_then(|v| v.as_array())
|
||||
.ok_or_else(|| AgentError::PlanParse("缺少 'steps' 数组".into()))?;
|
||||
|
||||
let steps: Vec<Step> = steps_array
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, item)| {
|
||||
let description = item
|
||||
.get("description")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| {
|
||||
AgentError::PlanParse(format!("步骤 {i} 缺少 'description' 字段"))
|
||||
})?;
|
||||
Ok(Step::new(i, description))
|
||||
})
|
||||
.collect::<Result<Vec<_>, AgentError>>()?;
|
||||
|
||||
if steps.is_empty() {
|
||||
return Err(AgentError::PlanParse(
|
||||
"Plan 至少需要一个步骤".into(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Plan {
|
||||
id: uuid(),
|
||||
goal: goal.to_string(),
|
||||
steps,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// 任务型智能体 —— 自主规划与执行。
|
||||
///
|
||||
/// 与基础 `Agent` trait 分离:`Agent` 定义"角色"(system prompt + 工具集),
|
||||
/// `TaskAgent` 定义"规划/执行"行为(如何拆 Plan、如何执行 Plan)。
|
||||
#[async_trait]
|
||||
pub trait TaskAgent: Send + Sync {
|
||||
/// 自主式入口:根据目标生成 Plan 并执行。
|
||||
///
|
||||
/// 实现内部应调用 `PlanParser::parse` 从 LLM 输出生成 Plan,
|
||||
/// 然后调用 `execute_plan` 执行。
|
||||
async fn run(&mut self, goal: &str) -> Result<Plan, AgentError>;
|
||||
|
||||
/// 外部驱动式入口:执行预定义的 Plan。
|
||||
///
|
||||
/// 逐步调用 `AgentSession::submit_turn`,每步完成后触发
|
||||
/// `OnPlanStepComplete` hook,更新步骤状态。
|
||||
async fn execute_plan(&mut self, plan: &mut Plan) -> Result<(), AgentError>;
|
||||
}
|
||||
|
||||
/// Generates a simple process-unique id (only used to label plans; not cryptographically
/// secure and not an RFC 4122 UUID).
///
/// The id has the form `"{nanos:x}-{seq:x}"`: the nanoseconds since the Unix epoch followed
/// by a process-wide sequence number. The sequence number keeps ids distinct even when two
/// calls observe the same clock value (coarse clocks, concurrent callers) or when the clock
/// reads earlier than the epoch, in which case the timestamp part is 0.
fn uuid() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Only atomicity of the increment matters for uniqueness, so `Relaxed` is sufficient.
    static NEXT_SEQ: AtomicU64 = AtomicU64::new(0);

    let ts = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    let seq = NEXT_SEQ.fetch_add(1, Ordering::Relaxed);
    format!("{ts:x}-{seq:x}")
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn step_initial_state_is_pending() {
        let step = Step::new(0, "do something");
        assert_eq!(step.index, 0);
        assert_eq!(step.description, "do something");
        assert!(step.status.is_pending());
        assert!(!step.status.is_terminal());
    }

    #[test]
    fn terminal_states_classified() {
        let failed = StepStatus::Failed(AgentError::Other("x".into()));
        assert!(failed.is_terminal());
        assert!(StepStatus::Skipped.is_terminal());
    }

    #[test]
    fn running_is_not_terminal() {
        let running = StepStatus::Running;
        assert!(!running.is_terminal());
        assert!(!running.is_pending());
    }

    #[test]
    fn plan_holds_steps() {
        let steps: Vec<Step> = ["first", "second"]
            .into_iter()
            .enumerate()
            .map(|(index, text)| Step::new(index, text))
            .collect();
        let plan = Plan {
            id: "p1".into(),
            goal: "test goal".into(),
            steps,
        };

        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.steps[0].index, 0);
        assert_eq!(plan.steps[1].index, 1);
    }

    /// Smoke test 1: `JsonPlanParser` accepts well-formed JSON.
    #[tokio::test]
    async fn json_plan_parser_success() {
        let raw = r#"{"steps": [{"description": "step one"}, {"description": "step two"}]}"#;
        let plan = JsonPlanParser.parse(raw, "my goal").await.unwrap();

        assert_eq!(plan.goal, "my goal");
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.steps[0].description, "step one");
        assert_eq!(plan.steps[1].description, "step two");
        for step in &plan.steps {
            assert!(step.status.is_pending());
        }
    }

    /// Smoke test 2: text that is not JSON yields `AgentError::PlanParse`.
    #[tokio::test]
    async fn json_plan_parser_invalid_json() {
        let outcome = JsonPlanParser.parse("not json", "goal").await;
        assert!(matches!(outcome.unwrap_err(), AgentError::PlanParse(_)));
    }

    /// Smoke test 3: an empty `steps` array yields `AgentError::PlanParse`.
    #[tokio::test]
    async fn json_plan_parser_empty_steps() {
        let outcome = JsonPlanParser.parse(r#"{"steps": []}"#, "goal").await;
        assert!(matches!(outcome.unwrap_err(), AgentError::PlanParse(_)));
    }
}
|
||||
+19
-11
@@ -1,14 +1,22 @@
|
||||
pub fn add(left: u64, right: u64) -> u64 {
|
||||
left + right
|
||||
}
|
||||
//! agcore —— 智能体(Agent)核心工具箱。
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
pub mod agent;
|
||||
pub mod llm;
|
||||
pub mod memory;
|
||||
pub mod prompt;
|
||||
pub mod tools;
|
||||
|
||||
#[test]
|
||||
fn it_works() {
|
||||
let result = add(2, 2);
|
||||
assert_eq!(result, 4);
|
||||
}
|
||||
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
|
||||
|
||||
static INIT: std::sync::Once = std::sync::Once::new();
|
||||
|
||||
pub fn init_tracing() {
|
||||
INIT.call_once(|| {
|
||||
let filter =
|
||||
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("agcore=info"));
|
||||
tracing_subscriber::registry()
|
||||
.with(fmt::layer())
|
||||
.with(filter)
|
||||
.init();
|
||||
});
|
||||
}
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
//! LLM 调用周期 —— 大模型基础调用周期控制。
|
||||
|
||||
pub mod compact;
|
||||
pub mod cycle;
|
||||
pub mod error;
|
||||
pub mod hooks;
|
||||
pub mod provider;
|
||||
pub mod stream;
|
||||
pub mod types;
|
||||
@@ -0,0 +1,159 @@
|
||||
//! 上下文自动压缩 —— 当对话历史过长时自动压缩。
|
||||
|
||||
use crate::llm::types::{ContentField, OpenaiChatMessage, OpenaiContentPart};
|
||||
|
||||
/// Safety margin (in tokens) subtracted from the usable window when computing
/// the auto-compaction threshold.
const AUTOCOMPACT_BUFFER_TOKENS: u32 = 13_000;
/// Default number of tokens reserved for the model's output.
const RESERVED_OUTPUT_TOKENS: u32 = 20_000;
/// Number of consecutive compaction failures after which compaction is no
/// longer attempted.
const MAX_CONSECUTIVE_FAILURES: u32 = 3;
/// Default number of most recent messages that micro-compaction leaves alone.
const KEEP_RECENT: usize = 6;

/// Configuration for automatic context compaction.
#[derive(Debug, Clone)]
pub struct CompactConfig {
    /// Size of the model's context window, in tokens.
    pub context_window: u32,
    /// Tokens reserved for the model's output.
    pub reserved_tokens: u32,
    /// Number of most recent messages kept untouched by micro-compaction.
    pub keep_recent: usize,
}

impl Default for CompactConfig {
    /// A 128 000-token window with the default output reservation and the
    /// default number of protected recent messages.
    fn default() -> Self {
        Self {
            keep_recent: KEEP_RECENT,
            reserved_tokens: RESERVED_OUTPUT_TOKENS,
            context_window: 128_000,
        }
    }
}

impl CompactConfig {
    /// Returns the estimated token count at which auto-compaction triggers.
    ///
    /// This is `context_window - reserved_tokens - AUTOCOMPACT_BUFFER_TOKENS`,
    /// with each subtraction saturating at zero.
    pub fn threshold(&self) -> u32 {
        let usable = self.context_window.saturating_sub(self.reserved_tokens);
        usable.saturating_sub(AUTOCOMPACT_BUFFER_TOKENS)
    }
}
|
||||
|
||||
/// 压缩状态 —— 跟踪连续失败次数(断路器模式)。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CompactState {
|
||||
consecutive_failures: u32,
|
||||
}
|
||||
|
||||
impl Default for CompactState {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl CompactState {
|
||||
/// 创建一个新的压缩状态。
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
consecutive_failures: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// 记录一次成功的压缩。
|
||||
pub fn record_success(&mut self) {
|
||||
self.consecutive_failures = 0;
|
||||
}
|
||||
|
||||
/// 记录一次压缩失败。
|
||||
///
|
||||
/// 返回 `true` 表示已达断路器上限,不再尝试。
|
||||
pub fn record_failure(&mut self) -> bool {
|
||||
self.consecutive_failures += 1;
|
||||
self.consecutive_failures >= MAX_CONSECUTIVE_FAILURES
|
||||
}
|
||||
}
|
||||
|
||||
/// 粗略估计消息列表的 token 数(基于字符数,4 字符 ≈ 1 token)。
|
||||
pub fn estimate_message_tokens(messages: &[OpenaiChatMessage]) -> u32 {
|
||||
messages
|
||||
.iter()
|
||||
.map(estimate_single_message_tokens)
|
||||
.sum()
|
||||
}
|
||||
|
||||
fn estimate_single_message_tokens(msg: &OpenaiChatMessage) -> u32 {
|
||||
let role_overhead: u32 = 4;
|
||||
let content_tokens = match msg {
|
||||
OpenaiChatMessage::Developer { content, .. }
|
||||
| OpenaiChatMessage::System { content, .. }
|
||||
| OpenaiChatMessage::User { content, .. }
|
||||
| OpenaiChatMessage::Assistant { content, .. }
|
||||
| OpenaiChatMessage::Function { content, .. } => estimate_content_tokens(content),
|
||||
OpenaiChatMessage::Tool { content, .. } => estimate_content_tokens(content),
|
||||
};
|
||||
role_overhead + content_tokens
|
||||
}
|
||||
|
||||
fn estimate_content_tokens(content: &ContentField) -> u32 {
|
||||
match content {
|
||||
ContentField::String(s) => estimate_text_tokens(s),
|
||||
ContentField::Array(parts) => parts.iter().map(estimate_part_tokens).sum(),
|
||||
}
|
||||
}
|
||||
|
||||
fn estimate_part_tokens(part: &OpenaiContentPart) -> u32 {
|
||||
match part {
|
||||
OpenaiContentPart::Text { text } => estimate_text_tokens(text),
|
||||
_ => 50,
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimates the token count of a piece of text from its UTF-8 byte length.
///
/// The estimate is `ceil(len * 4 / 3)`; an empty string yields 0. The
/// arithmetic is done in `u64` and the result saturates at `u32::MAX`, so very
/// large inputs can neither be truncated by a narrowing cast nor overflow.
///
/// NOTE(review): elsewhere this file describes the heuristic as "4 characters
/// ≈ 1 token", which would be `len / 4`; this function charges 4/3 tokens per
/// byte. The ratio is kept unchanged here — confirm which one is intended.
fn estimate_text_tokens(text: &str) -> u32 {
    let bytes = u64::try_from(text.len()).unwrap_or(u64::MAX);
    let estimate = bytes.saturating_mul(4).div_ceil(3);
    u32::try_from(estimate).unwrap_or(u32::MAX)
}
|
||||
|
||||
/// 判断是否需要触发自动压缩。
|
||||
pub fn should_compact(
|
||||
messages: &[OpenaiChatMessage],
|
||||
config: &CompactConfig,
|
||||
state: &CompactState,
|
||||
) -> bool {
|
||||
if state.consecutive_failures >= MAX_CONSECUTIVE_FAILURES {
|
||||
return false;
|
||||
}
|
||||
let tokens = estimate_message_tokens(messages);
|
||||
tokens >= config.threshold()
|
||||
}
|
||||
|
||||
/// 执行微压缩 —— 用 `[pruned]` 替换旧的 tool result 内容。
|
||||
///
|
||||
/// 这是最便宜的压缩方式,不需要 LLM 调用。
|
||||
/// 保留最近的 `keep_recent` 条消息不变。
|
||||
///
|
||||
/// 返回释放的估算 token 数。
|
||||
pub fn microcompact(messages: &mut [OpenaiChatMessage], keep_recent: usize) -> u32 {
|
||||
if messages.len() <= keep_recent {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let prune_start = messages.len() - keep_recent;
|
||||
let mut freed_tokens: u32 = 0;
|
||||
|
||||
for msg in &messages[..prune_start] {
|
||||
if matches!(msg, OpenaiChatMessage::Tool { .. }) {
|
||||
freed_tokens += estimate_single_message_tokens(msg);
|
||||
}
|
||||
}
|
||||
|
||||
for msg in &mut messages[..prune_start] {
|
||||
if let OpenaiChatMessage::Tool { content, .. } = msg {
|
||||
*content = ContentField::Array(vec![OpenaiContentPart::Text {
|
||||
text: "[pruned]".to_string(),
|
||||
}]);
|
||||
}
|
||||
}
|
||||
|
||||
freed_tokens
|
||||
}
|
||||
@@ -0,0 +1,917 @@
|
||||
//! LLM 调用周期控制模块。
|
||||
|
||||
mod retry;
|
||||
pub mod usage;
|
||||
|
||||
pub use retry::RetryConfig;
|
||||
pub use usage::{CostTracker, Usage};
|
||||
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures_core::stream::Stream;
|
||||
use async_stream::stream;
|
||||
|
||||
use crate::llm::compact::{should_compact, microcompact, CompactConfig, CompactState};
|
||||
use crate::llm::cycle::retry::should_retry;
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::hooks::{HookContext, HookExecutor};
|
||||
use crate::llm::provider::LlmProvider;
|
||||
use crate::llm::stream::StreamEvent;
|
||||
use crate::llm::types::{
|
||||
ChatRequest, ChatResponse, FinishReason, OpenaiChatMessage, OpenaiTool, OpenaiToolCall,
|
||||
ToolChoice, ToolDefinition,
|
||||
};
|
||||
|
||||
/// LLM 调用周期配置。
|
||||
pub struct CycleConfig {
|
||||
/// 模型名称。
|
||||
pub model: String,
|
||||
/// 最大输出 token 数。
|
||||
pub max_tokens: Option<u32>,
|
||||
/// 采样温度。
|
||||
pub temperature: Option<f32>,
|
||||
/// 最大对话轮次。
|
||||
pub max_turns: Option<u32>,
|
||||
/// 重试策略配置。
|
||||
pub retry: RetryConfig,
|
||||
/// 自动 tool 循环的最大轮次(独立于 `max_turns`,避免影响现有 `submit()` 语义)。
|
||||
/// 默认 `Some(10)`,防止 LLM 反复调用工具导致无限循环。
|
||||
pub max_tool_turns: Option<u32>,
|
||||
/// 单个工具执行的超时秒数(0 表示不超时)。
|
||||
/// 默认 60 秒。
|
||||
pub tool_timeout_secs: u64,
|
||||
/// 单个工具结果的最大字节数(超过此值将被截断)。
|
||||
/// 默认 65536(64KB),防止大结果导致 token 膨胀。
|
||||
pub max_tool_result_bytes: usize,
|
||||
}
|
||||
|
||||
impl Default for CycleConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
model: String::from("gpt-4o"),
|
||||
max_tokens: None,
|
||||
temperature: None,
|
||||
max_turns: None,
|
||||
retry: RetryConfig::default(),
|
||||
max_tool_turns: Some(10),
|
||||
tool_timeout_secs: 60,
|
||||
max_tool_result_bytes: 65_536,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// LLM call cycle: manages the lifecycle of one or more LLM requests.
///
/// Holds the provider handle, the cycle configuration, accumulated usage,
/// the message history and the optional system prompt, hooks and
/// context-compaction settings.
|
||||
pub struct LlmCycle {
|
||||
// Shared handle to the provider that performs the actual chat calls.
provider: Arc<dyn LlmProvider>,
|
||||
// Model, sampling, retry and tool-loop settings.
config: CycleConfig,
|
||||
// Usage accumulated from the responses of this cycle.
usage: CostTracker,
|
||||
// Conversation history sent with each request.
messages: Vec<OpenaiChatMessage>,
|
||||
// Optional system prompt injected by `build_request` when the history has none.
system_prompt: Option<String>,
|
||||
// Optional hook executor run around requests, retries and errors.
hook_executor: Option<Arc<HookExecutor>>,
|
||||
// Context compaction settings; `None` disables automatic compaction.
compact_config: Option<CompactConfig>,
|
||||
// Failure streak used to decide whether compaction is still attempted.
compact_state: CompactState,
|
||||
}
|
||||
|
||||
impl LlmCycle {
|
||||
/// 创建一个新的 LlmCycle(持有 `Box<dyn LlmProvider>` 的独占所有权)。
|
||||
///
|
||||
/// 内部将 Box 转为 `Arc<dyn LlmProvider>` 以便 `new_with_arc` 复用句柄。
|
||||
/// 公共签名保持不变,向后兼容。
|
||||
pub fn new(provider: Box<dyn LlmProvider>, config: CycleConfig) -> Self {
|
||||
Self::new_with_arc(Arc::from(provider), config)
|
||||
}
|
||||
|
||||
/// 创建一个新的 LlmCycle,共享传入的 `Arc<dyn LlmProvider>` 句柄。
|
||||
///
|
||||
/// **新增**(Phase 4a 引入):用于 `AgentSession::submit_turn` 在多 session 间共享 provider。
|
||||
pub fn new_with_arc(provider: Arc<dyn LlmProvider>, config: CycleConfig) -> Self {
|
||||
Self {
|
||||
provider,
|
||||
config,
|
||||
usage: CostTracker::default(),
|
||||
messages: Vec::new(),
|
||||
system_prompt: None,
|
||||
hook_executor: None,
|
||||
compact_config: None,
|
||||
compact_state: CompactState::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 设置系统提示词。
|
||||
pub fn with_system_prompt(mut self, prompt: String) -> Self {
|
||||
self.system_prompt = Some(prompt);
|
||||
self
|
||||
}
|
||||
|
||||
/// 设置钩子执行器。
|
||||
pub fn with_hook_executor(mut self, executor: HookExecutor) -> Self {
|
||||
self.hook_executor = Some(Arc::new(executor));
|
||||
self
|
||||
}
|
||||
|
||||
/// 设置上下文压缩配置。
|
||||
pub fn with_compact_config(mut self, config: CompactConfig) -> Self {
|
||||
self.compact_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns a shared reference to this cycle's usage tracker.
|
||||
pub fn usage(&self) -> &CostTracker {
|
||||
&self.usage
|
||||
}
|
||||
|
||||
/// Returns the message history as a slice.
|
||||
pub fn messages(&self) -> &[OpenaiChatMessage] {
|
||||
&self.messages
|
||||
}
|
||||
|
||||
/// Removes every message from the history.
|
||||
pub fn clear_messages(&mut self) {
|
||||
self.messages.clear();
|
||||
}
|
||||
|
||||
/// Resets the usage statistics by calling `reset()` on the usage tracker.
|
||||
pub fn reset_usage(&mut self) {
|
||||
self.usage.reset();
|
||||
}
|
||||
|
||||
/// 直接设置消息历史(覆盖已有消息),支持 Builder 链式调用。
|
||||
pub fn with_messages(mut self, messages: Vec<OpenaiChatMessage>) -> Self {
|
||||
self.messages = messages;
|
||||
self
|
||||
}
|
||||
|
||||
/// 追加消息到历史尾部。
|
||||
pub fn extend_messages(&mut self, messages: Vec<OpenaiChatMessage>) {
|
||||
self.messages.extend(messages);
|
||||
}
|
||||
|
||||
/// 使用预构建消息提交(跳过自动 push user prompt)。
|
||||
///
|
||||
/// 与 `submit()` 不同,不自动添加 `user_text(prompt)`,也不自动插入 system prompt。
|
||||
/// 调用方完全控制消息序列内容。
|
||||
pub async fn submit_messages(
|
||||
&mut self,
|
||||
messages: Vec<OpenaiChatMessage>,
|
||||
tools: Vec<ToolDefinition>,
|
||||
) -> Result<ChatResponse, LlmError> {
|
||||
let openai_tools: Option<Vec<OpenaiTool>> = if tools.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
tools
|
||||
.iter()
|
||||
.map(|t| OpenaiTool::Function {
|
||||
function: t.clone(),
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
};
|
||||
|
||||
let request = ChatRequest {
|
||||
model: self.config.model.clone(),
|
||||
messages,
|
||||
max_tokens: self.config.max_tokens,
|
||||
temperature: self.config.temperature,
|
||||
tools: openai_tools,
|
||||
tool_choice: Some(ToolChoice::Auto),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest)
|
||||
.with_request(&request);
|
||||
let results = executor
|
||||
.execute(crate::llm::hooks::HookEvent::PreRequest, &ctx)
|
||||
.await;
|
||||
if results.iter().any(|r| r.should_block) {
|
||||
let reason = results
|
||||
.iter()
|
||||
.find(|r| r.should_block)
|
||||
.and_then(|r| r.reason.clone())
|
||||
.unwrap_or_else(|| "Blocked by pre-request hook".to_string());
|
||||
return Err(LlmError::Other(reason));
|
||||
}
|
||||
}
|
||||
|
||||
match self.provider.chat(request).await {
|
||||
Ok(response) => {
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let post_request = ChatRequest {
|
||||
model: self.config.model.clone(),
|
||||
messages: vec![],
|
||||
..Default::default()
|
||||
};
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PostRequest)
|
||||
.with_request(&post_request);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::PostRequest, &ctx)
|
||||
.await;
|
||||
}
|
||||
self.usage.add(&response.usage);
|
||||
Ok(response)
|
||||
}
|
||||
Err(e) => {
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnError).with_error(&e);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::OnError, &ctx)
|
||||
.await;
|
||||
}
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 提交用户消息并获取 LLM 响应。
|
||||
pub async fn submit(
|
||||
&mut self,
|
||||
prompt: String,
|
||||
tools: Vec<ToolDefinition>,
|
||||
) -> Result<ChatResponse, LlmError> {
|
||||
self.messages.push(OpenaiChatMessage::user_text(prompt));
|
||||
|
||||
if let Some(ref config) = self.compact_config
|
||||
&& should_compact(&self.messages, config, &self.compact_state)
|
||||
{
|
||||
let freed = microcompact(&mut self.messages, config.keep_recent);
|
||||
if freed > 0 {
|
||||
self.compact_state.record_success();
|
||||
}
|
||||
}
|
||||
|
||||
let mut attempts = 0;
|
||||
|
||||
loop {
|
||||
let request = self.build_request(&tools);
|
||||
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest)
|
||||
.with_request(&request);
|
||||
let results = executor
|
||||
.execute(crate::llm::hooks::HookEvent::PreRequest, &ctx)
|
||||
.await;
|
||||
if results.iter().any(|r| r.should_block) {
|
||||
let reason = results
|
||||
.iter()
|
||||
.find(|r| r.should_block)
|
||||
.and_then(|r| r.reason.clone())
|
||||
.unwrap_or_else(|| "Blocked by pre-request hook".to_string());
|
||||
return Err(LlmError::Other(reason));
|
||||
}
|
||||
}
|
||||
|
||||
match self.provider.chat(request).await {
|
||||
Ok(response) => {
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let post_request = self.build_request(&tools);
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PostRequest)
|
||||
.with_request(&post_request);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::PostRequest, &ctx)
|
||||
.await;
|
||||
}
|
||||
|
||||
self.messages.push(response.message.clone());
|
||||
self.usage.add(&response.usage);
|
||||
|
||||
return Ok(response);
|
||||
}
|
||||
Err(e) if should_retry(&e) && attempts < self.config.retry.max_retries => {
|
||||
attempts += 1;
|
||||
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnRetry)
|
||||
.with_error(&e)
|
||||
.with_attempt(attempts);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::OnRetry, &ctx)
|
||||
.await;
|
||||
}
|
||||
|
||||
let delay = self.config.retry.compute_delay(attempts);
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
Err(e) => {
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let ctx =
|
||||
HookContext::new(crate::llm::hooks::HookEvent::OnError).with_error(&e);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::OnError, &ctx)
|
||||
.await;
|
||||
}
|
||||
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 提交用户消息并返回语义事件流。
|
||||
///
|
||||
/// 与 `submit` 不同,该方法返回流式事件而非完整响应。
|
||||
/// 适用于需要实时处理 LLM 输出的场景。
|
||||
pub async fn submit_stream(
|
||||
&mut self,
|
||||
prompt: String,
|
||||
tools: Vec<ToolDefinition>,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = StreamEvent> + Send>>, LlmError> {
|
||||
self.messages.push(OpenaiChatMessage::user_text(prompt));
|
||||
|
||||
if let Some(ref config) = self.compact_config
|
||||
&& should_compact(&self.messages, config, &self.compact_state)
|
||||
{
|
||||
let freed = microcompact(&mut self.messages, config.keep_recent);
|
||||
if freed > 0 {
|
||||
self.compact_state.record_success();
|
||||
}
|
||||
}
|
||||
|
||||
let request = self.build_request(&tools);
|
||||
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest)
|
||||
.with_request(&request);
|
||||
let results = executor
|
||||
.execute(crate::llm::hooks::HookEvent::PreRequest, &ctx)
|
||||
.await;
|
||||
if results.iter().any(|r| r.should_block) {
|
||||
let reason = results
|
||||
.iter()
|
||||
.find(|r| r.should_block)
|
||||
.and_then(|r| r.reason.clone())
|
||||
.unwrap_or_else(|| "Blocked by pre-request hook".to_string());
|
||||
return Err(LlmError::Other(reason));
|
||||
}
|
||||
}
|
||||
|
||||
let chunk_stream = self.provider.chat_stream(request).await?;
|
||||
let hook_executor = self.hook_executor.clone();
|
||||
let post_request = self.build_request(&tools);
|
||||
|
||||
Ok(Box::pin(stream! {
|
||||
use futures_util::StreamExt;
|
||||
let mut chunk_stream = chunk_stream;
|
||||
|
||||
while let Some(result) = chunk_stream.next().await {
|
||||
match result {
|
||||
Ok(chunk) => {
|
||||
let mut assistant_text = String::new();
|
||||
let mut tool_started: Option<(String, String, String)> = None;
|
||||
|
||||
for choice in &chunk.choices {
|
||||
let delta = &choice.delta;
|
||||
|
||||
if let Some(content) = &delta.content {
|
||||
assistant_text.push_str(content);
|
||||
}
|
||||
|
||||
if let Some(tool_calls) = &delta.tool_calls
|
||||
&& let Some(tc) = tool_calls.first()
|
||||
{
|
||||
let crate::llm::types::OpenaiToolCall::Function { id, function } = tc;
|
||||
tool_started = Some((id.clone(), function.name.clone(), function.arguments.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
if !assistant_text.is_empty() {
|
||||
yield StreamEvent::AssistantTextDelta { text: assistant_text };
|
||||
}
|
||||
|
||||
if let Some((tool_call_id, tool_name, arguments)) = tool_started {
|
||||
let args: serde_json::Value = serde_json::from_str(&arguments)
|
||||
.unwrap_or(serde_json::Value::Null);
|
||||
yield StreamEvent::ToolExecutionStarted {
|
||||
tool_name,
|
||||
input: args,
|
||||
tool_call_id,
|
||||
};
|
||||
}
|
||||
|
||||
for choice in &chunk.choices {
|
||||
if let Some(finish_reason) = &choice.finish_reason {
|
||||
yield StreamEvent::TurnComplete {
|
||||
reason: *finish_reason,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(usage_info) = &chunk.usage {
|
||||
yield StreamEvent::CostUpdate { usage: *usage_info };
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
if let Some(ref executor) = hook_executor {
|
||||
let ctx = crate::llm::hooks::HookContext::new(
|
||||
crate::llm::hooks::HookEvent::OnError,
|
||||
)
|
||||
.with_error(&e);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::OnError, &ctx)
|
||||
.await;
|
||||
}
|
||||
yield StreamEvent::Error { message: e.to_string() };
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref executor) = hook_executor {
|
||||
let ctx = crate::llm::hooks::HookContext::new(
|
||||
crate::llm::hooks::HookEvent::PostRequest,
|
||||
)
|
||||
.with_request(&post_request);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::PostRequest, &ctx)
|
||||
.await;
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
fn build_request(&self, tools: &[ToolDefinition]) -> ChatRequest {
|
||||
let mut messages = self.messages.clone();
|
||||
|
||||
if let Some(sys_prompt) = &self.system_prompt
|
||||
&& !messages
|
||||
.iter()
|
||||
.any(|m| matches!(m, OpenaiChatMessage::System { .. }))
|
||||
{
|
||||
messages.insert(0, OpenaiChatMessage::system_text(sys_prompt));
|
||||
}
|
||||
|
||||
let openai_tools: Option<Vec<OpenaiTool>> = if tools.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
tools
|
||||
.iter()
|
||||
.map(|t| OpenaiTool::Function {
|
||||
function: t.clone(),
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
};
|
||||
|
||||
ChatRequest {
|
||||
model: self.config.model.clone(),
|
||||
messages,
|
||||
max_tokens: self.config.max_tokens,
|
||||
temperature: self.config.temperature,
|
||||
tools: openai_tools,
|
||||
tool_choice: Some(ToolChoice::Auto),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// 内部请求方法(与 `submit` 共享重试逻辑,但不 push user message 和 Assistant 响应)。
|
||||
///
|
||||
/// 用于 `submit_with_tools()` 的多轮 tool 循环。
|
||||
async fn submit_request(
|
||||
&mut self,
|
||||
tools: &[ToolDefinition],
|
||||
) -> Result<ChatResponse, LlmError> {
|
||||
let mut attempts = 0;
|
||||
|
||||
loop {
|
||||
let request = self.build_request(tools);
|
||||
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest)
|
||||
.with_request(&request);
|
||||
let results = executor
|
||||
.execute(crate::llm::hooks::HookEvent::PreRequest, &ctx)
|
||||
.await;
|
||||
if results.iter().any(|r| r.should_block) {
|
||||
let reason = results
|
||||
.iter()
|
||||
.find(|r| r.should_block)
|
||||
.and_then(|r| r.reason.clone())
|
||||
.unwrap_or_else(|| "Blocked by pre-request hook".to_string());
|
||||
return Err(LlmError::Other(reason));
|
||||
}
|
||||
}
|
||||
|
||||
match self.provider.chat(request).await {
|
||||
Ok(response) => {
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let post_request = self.build_request(tools);
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PostRequest)
|
||||
.with_request(&post_request);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::PostRequest, &ctx)
|
||||
.await;
|
||||
}
|
||||
self.usage.add(&response.usage);
|
||||
return Ok(response);
|
||||
}
|
||||
Err(e) if should_retry(&e) && attempts < self.config.retry.max_retries => {
|
||||
attempts += 1;
|
||||
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnRetry)
|
||||
.with_error(&e)
|
||||
.with_attempt(attempts);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::OnRetry, &ctx)
|
||||
.await;
|
||||
}
|
||||
|
||||
let delay = self.config.retry.compute_delay(attempts);
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
Err(e) => {
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnError)
|
||||
.with_error(&e);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::OnError, &ctx)
|
||||
.await;
|
||||
}
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 提交消息并自动处理工具调用循环。
|
||||
///
|
||||
/// 流程:
|
||||
/// 1. 发送请求(含工具定义)
|
||||
/// 2. 检查响应中的 finish_reason
|
||||
/// 3. 如果是 ToolCalls → push Assistant 消息 → 执行工具 → 回传结果 → 重复 1
|
||||
/// 4. 如果是 Stop/Length → push Assistant 消息 → 返回最终响应
|
||||
///
|
||||
/// 注意:OpenAI API 要求 tool 消息必须紧跟在对应的 Assistant(tool_calls)消息之后。
|
||||
/// 因此 push 工具结果前必须先 push Assistant 响应,否则 API 拒绝请求。
|
||||
pub async fn submit_with_tools(
|
||||
&mut self,
|
||||
prompt: String,
|
||||
registry: &crate::tools::ToolRegistry,
|
||||
) -> Result<ChatResponse, LlmError> {
|
||||
let tools = registry.definitions();
|
||||
let max_turns = self.config.max_tool_turns.unwrap_or(10);
|
||||
let tool_timeout = self.config.tool_timeout_secs;
|
||||
let max_bytes = self.config.max_tool_result_bytes;
|
||||
|
||||
self.messages.push(OpenaiChatMessage::user_text(prompt));
|
||||
self.maybe_compact();
|
||||
|
||||
let mut turn = 0;
|
||||
|
||||
loop {
|
||||
turn += 1;
|
||||
if turn > max_turns {
|
||||
return Err(LlmError::Other(format!(
|
||||
"达到最大工具循环轮次 ({max_turns})"
|
||||
)));
|
||||
}
|
||||
|
||||
let response = self.submit_request(&tools).await?;
|
||||
|
||||
// 判断是否需要执行工具
|
||||
let should_execute = matches!(response.stop_reason, Some(FinishReason::ToolCalls))
|
||||
&& has_tool_calls_in_message(&response.message);
|
||||
|
||||
// 将 Assistant 响应(含 tool_calls 或最终文本)追加到消息历史
|
||||
self.messages.push(response.message.clone());
|
||||
|
||||
if !should_execute {
|
||||
return Ok(response);
|
||||
}
|
||||
|
||||
// 解析 tool_calls 并执行
|
||||
let tool_calls = extract_tool_calls_from_message(&response.message);
|
||||
let calls: Vec<(String, serde_json::Value)> = tool_calls
|
||||
.into_iter()
|
||||
.map(|(_id, name, args)| {
|
||||
let args: serde_json::Value =
|
||||
serde_json::from_str(&args).unwrap_or(serde_json::Value::Null);
|
||||
(name, args)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let results = registry.invoke_all(calls, tool_timeout).await;
|
||||
|
||||
// 回传工具结果
|
||||
for result in results {
|
||||
let content = match result.output {
|
||||
Ok(value) => {
|
||||
let serialized = serde_json::to_string(&value).unwrap_or_else(|e| {
|
||||
tracing::warn!("工具结果序列化失败: {}", e);
|
||||
"{}".to_string()
|
||||
});
|
||||
truncate_tool_result(&serialized, max_bytes)
|
||||
}
|
||||
Err(e) if e.is_recoverable() => format!("错误: {}", e),
|
||||
Err(e) => {
|
||||
// 不可恢复错误:终止循环
|
||||
return Err(LlmError::Other(format!(
|
||||
"工具 '{}' 不可恢复错误: {}",
|
||||
result.tool_name, e
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
self.messages
|
||||
.push(OpenaiChatMessage::tool_result(result.tool_name, content));
|
||||
}
|
||||
|
||||
// 每轮工具执行后触发 compaction
|
||||
self.maybe_compact();
|
||||
}
|
||||
|
||||
// unreachable: loop returns
|
||||
#[allow(unreachable_code)]
|
||||
{
|
||||
Err(LlmError::Other("unreachable".into()))
|
||||
}
|
||||
}
|
||||
|
||||
/// 在接近上下文窗口时压缩历史消息。
|
||||
fn maybe_compact(&mut self) {
|
||||
if let Some(ref config) = self.compact_config
|
||||
&& should_compact(&self.messages, config, &self.compact_state)
|
||||
{
|
||||
let freed = microcompact(&mut self.messages, config.keep_recent);
|
||||
if freed > 0 {
|
||||
self.compact_state.record_success();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 判断 Assistant 消息是否包含 tool_calls。
|
||||
fn has_tool_calls_in_message(msg: &OpenaiChatMessage) -> bool {
|
||||
matches!(
|
||||
msg,
|
||||
OpenaiChatMessage::Assistant {
|
||||
tool_calls: Some(calls),
|
||||
..
|
||||
} if !calls.is_empty()
|
||||
)
|
||||
}
|
||||
|
||||
/// 提取 Assistant 消息中的 tool_calls。
|
||||
///
|
||||
/// 返回 `(tool_call_id, tool_name, arguments_json_string)` 列表。
|
||||
fn extract_tool_calls_from_message(
|
||||
msg: &OpenaiChatMessage,
|
||||
) -> Vec<(String, String, String)> {
|
||||
if let OpenaiChatMessage::Assistant {
|
||||
tool_calls: Some(calls),
|
||||
..
|
||||
} = msg
|
||||
{
|
||||
calls
|
||||
.iter()
|
||||
.map(|c| match c {
|
||||
OpenaiToolCall::Function { id, function } => {
|
||||
(id.clone(), function.name.clone(), function.arguments.clone())
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Truncates a tool result to at most `max_bytes` bytes of original text.
///
/// Strings no longer than `max_bytes` are returned unchanged. Otherwise the
/// cut is moved back to the nearest UTF-8 character boundary at or below
/// `max_bytes`, and a notice carrying the original size is appended, so the
/// returned string can be longer than `max_bytes`.
fn truncate_tool_result(s: &str, max_bytes: usize) -> String {
    let original_len = s.len();
    if original_len <= max_bytes {
        return s.to_string();
    }

    // Index 0 is always a boundary, so the search cannot come up empty.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    let kept = &s[..cut];

    format!("{}\n\n[... truncated, original size: {} bytes ...]", kept, original_len)
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::llm::types::{ContentField, OpenaiContentPart};
|
||||
use crate::tools::{BaseTool, ToolRegistry};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// 模拟 Provider —— 预定义响应序列,按调用顺序返回。
|
||||
struct MockProvider {
|
||||
responses: std::sync::Mutex<Vec<ChatResponse>>,
|
||||
call_count: std::sync::Mutex<u32>,
|
||||
}
|
||||
|
||||
impl MockProvider {
|
||||
fn new(responses: Vec<ChatResponse>) -> Self {
|
||||
Self {
|
||||
responses: std::sync::Mutex::new(responses),
|
||||
call_count: std::sync::Mutex::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for MockProvider {
|
||||
async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, LlmError> {
|
||||
let mut count = self.call_count.lock().unwrap();
|
||||
*count += 1;
|
||||
let mut responses = self.responses.lock().unwrap();
|
||||
if responses.is_empty() {
|
||||
return Err(LlmError::Other("no more mock responses".into()));
|
||||
}
|
||||
Ok(responses.remove(0))
|
||||
}
|
||||
}
|
||||
|
||||
fn empty_usage() -> crate::llm::types::Usage {
|
||||
crate::llm::types::Usage::default()
|
||||
}
|
||||
|
||||
fn assistant_text_response(text: &str) -> ChatResponse {
|
||||
ChatResponse {
|
||||
message: OpenaiChatMessage::assistant_text(text),
|
||||
usage: empty_usage(),
|
||||
stop_reason: Some(FinishReason::Stop),
|
||||
}
|
||||
}
|
||||
|
||||
fn assistant_tool_call_response(
|
||||
calls: Vec<(&str, &str, &str)>,
|
||||
) -> ChatResponse {
|
||||
use crate::llm::types::{OpenaiToolCall, FunctionCall};
|
||||
let tool_calls: Vec<OpenaiToolCall> = calls
|
||||
.into_iter()
|
||||
.map(|(id, name, args)| OpenaiToolCall::Function {
|
||||
id: id.to_string(),
|
||||
function: FunctionCall {
|
||||
name: name.to_string(),
|
||||
arguments: args.to_string(),
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
ChatResponse {
|
||||
message: OpenaiChatMessage::Assistant {
|
||||
content: ContentField::Array(vec![OpenaiContentPart::Text {
|
||||
text: String::new(),
|
||||
}]),
|
||||
refusal: None,
|
||||
name: None,
|
||||
tool_calls: Some(tool_calls),
|
||||
},
|
||||
usage: empty_usage(),
|
||||
stop_reason: Some(FinishReason::ToolCalls),
|
||||
}
|
||||
}
|
||||
|
||||
struct AddTool;
|
||||
|
||||
#[async_trait]
|
||||
impl BaseTool for AddTool {
|
||||
fn name(&self) -> &str {
|
||||
"add"
|
||||
}
|
||||
fn description(&self) -> &str {
|
||||
"加法"
|
||||
}
|
||||
fn parameters(&self) -> Value {
|
||||
json!({"type":"object","properties":{"a":{"type":"integer"},"b":{"type":"integer"}}})
|
||||
}
|
||||
async fn execute(
|
||||
&self,
|
||||
args: Value,
|
||||
_ctx: &crate::tools::ToolContext<'_>,
|
||||
) -> Result<Value, crate::tools::ToolError> {
|
||||
let a = args["a"].as_i64().unwrap_or(0);
|
||||
let b = args["b"].as_i64().unwrap_or(0);
|
||||
Ok(json!({"result": a + b}))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_submit_with_tools_single_turn() {
|
||||
// 第一轮:返回 tool_call;第二轮:返回最终文本
|
||||
let responses = vec![
|
||||
assistant_tool_call_response(vec![("call_1", "add", r#"{"a":1,"b":2}"#)]),
|
||||
assistant_text_response("答案是 3"),
|
||||
];
|
||||
let provider = Box::new(MockProvider::new(responses));
|
||||
let mut cycle = LlmCycle::new(provider, CycleConfig::default());
|
||||
|
||||
let mut registry = ToolRegistry::new();
|
||||
registry.register(std::sync::Arc::new(AddTool)).unwrap();
|
||||
|
||||
let response = cycle
|
||||
.submit_with_tools("1+2=?".to_string(), ®istry)
|
||||
.await
|
||||
.unwrap();
|
||||
// 验证最终响应是文本响应
|
||||
assert!(matches!(
|
||||
response.message,
|
||||
OpenaiChatMessage::Assistant { .. }
|
||||
));
|
||||
|
||||
// 验证消息历史:user, assistant(tool_calls), tool, assistant(text)
|
||||
let messages = cycle.messages();
|
||||
assert_eq!(messages.len(), 4);
|
||||
assert!(matches!(messages[0], OpenaiChatMessage::User { .. }));
|
||||
assert!(matches!(messages[1], OpenaiChatMessage::Assistant { .. }));
|
||||
assert!(matches!(messages[2], OpenaiChatMessage::Tool { .. }));
|
||||
assert!(matches!(messages[3], OpenaiChatMessage::Assistant { .. }));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_submit_with_tools_multi_turn() {
|
||||
// 3 轮 tool 调用后给出最终答案
|
||||
let responses = vec![
|
||||
assistant_tool_call_response(vec![("call_1", "add", r#"{"a":1,"b":2}"#)]),
|
||||
assistant_tool_call_response(vec![("call_2", "add", r#"{"a":3,"b":4}"#)]),
|
||||
assistant_tool_call_response(vec![("call_3", "add", r#"{"a":5,"b":6}"#)]),
|
||||
assistant_text_response("完成"),
|
||||
];
|
||||
let provider = Box::new(MockProvider::new(responses));
|
||||
let mut cycle = LlmCycle::new(provider, CycleConfig::default());
|
||||
|
||||
let mut registry = ToolRegistry::new();
|
||||
registry.register(std::sync::Arc::new(AddTool)).unwrap();
|
||||
|
||||
let response = cycle
|
||||
.submit_with_tools("计算总和".to_string(), ®istry)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
response.message,
|
||||
OpenaiChatMessage::Assistant { .. }
|
||||
));
|
||||
|
||||
// user + 3*(assistant + tool) + final assistant = 8
|
||||
let messages = cycle.messages();
|
||||
assert_eq!(messages.len(), 8);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_submit_with_tools_max_turns_exceeded() {
|
||||
// 配置 max_tool_turns = 2
|
||||
let mut config = CycleConfig::default();
|
||||
config.max_tool_turns = Some(2);
|
||||
// 4 轮 tool 调用 + 终止
|
||||
let responses = vec![
|
||||
assistant_tool_call_response(vec![("c1", "add", r#"{"a":1,"b":1}"#)]),
|
||||
assistant_tool_call_response(vec![("c2", "add", r#"{"a":1,"b":1}"#)]),
|
||||
assistant_tool_call_response(vec![("c3", "add", r#"{"a":1,"b":1}"#)]),
|
||||
assistant_text_response("完成"),
|
||||
];
|
||||
let provider = Box::new(MockProvider::new(responses));
|
||||
let mut cycle = LlmCycle::new(provider, config);
|
||||
|
||||
let mut registry = ToolRegistry::new();
|
||||
registry.register(std::sync::Arc::new(AddTool)).unwrap();
|
||||
|
||||
let result = cycle
|
||||
.submit_with_tools("test".to_string(), ®istry)
|
||||
.await;
|
||||
assert!(matches!(result, Err(LlmError::Other(msg)) if msg.contains("达到最大工具循环轮次")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_submit_with_tools_no_tool_call_response() {
|
||||
// LLM 直接给出最终响应(不调用工具)
|
||||
let responses = vec![assistant_text_response("直接回答")];
|
||||
let provider = Box::new(MockProvider::new(responses));
|
||||
let mut cycle = LlmCycle::new(provider, CycleConfig::default());
|
||||
|
||||
let mut registry = ToolRegistry::new();
|
||||
registry.register(std::sync::Arc::new(AddTool)).unwrap();
|
||||
|
||||
let response = cycle
|
||||
.submit_with_tools("直接回答".to_string(), ®istry)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
response.message,
|
||||
OpenaiChatMessage::Assistant { .. }
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
fn test_truncate_tool_result_short() {
    // Text below the limit must come back unchanged.
    let input = "short text";
    let output = truncate_tool_result(input, 100);
    assert_eq!(output, "short text");
}
|
||||
|
||||
#[test]
fn test_truncate_tool_result_long() {
    // Text far above the limit is shortened and carries the truncation marker.
    let original = "a".repeat(1000);
    let shortened = truncate_tool_result(&original, 50);
    let is_shorter = shortened.len() < original.len();
    assert!(is_shorter);
    assert!(shortened.contains("[... truncated,"));
}
|
||||
|
||||
#[test]
fn test_truncate_tool_result_chinese_chars() {
    // Multi-byte input: the cut must not land in the middle of a character.
    let original = "中".repeat(100);
    let shortened = truncate_tool_result(&original, 50);
    assert!(shortened.starts_with("中"));
}
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::llm::error::LlmError;
|
||||
|
||||
/// Retry policy configuration.
|
||||
///
|
||||
/// Wait times between retries are computed with exponential backoff plus jitter.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetryConfig {
|
||||
/// Maximum number of retries (default 3).
|
||||
pub max_retries: u32,
|
||||
/// Initial delay (default 1 second).
|
||||
pub base_delay: Duration,
|
||||
/// Upper bound on the delay (default 30 seconds).
|
||||
pub max_delay: Duration,
|
||||
/// Jitter scale factor (default 0.25).
|
||||
pub jitter_factor: f64,
|
||||
}
|
||||
|
||||
impl Default for RetryConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_retries: 3,
|
||||
base_delay: Duration::from_secs(1),
|
||||
max_delay: Duration::from_secs(30),
|
||||
jitter_factor: 0.25,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RetryConfig {
|
||||
/// 根据当前重试次数计算等待时间。
|
||||
///
|
||||
/// 算法: `delay = min(base * 2^(attempt-1), max_delay) + random(0, delay * jitter_factor)`
|
||||
pub fn compute_delay(&self, attempt: u32) -> Duration {
|
||||
let base = self.base_delay.as_secs_f64();
|
||||
let exponential = base * (2u64.pow(attempt.saturating_sub(1))) as f64;
|
||||
let capped = exponential.min(self.max_delay.as_secs_f64());
|
||||
let jitter = rand_jitter(capped * self.jitter_factor);
|
||||
|
||||
Duration::from_secs_f64(capped + jitter)
|
||||
}
|
||||
}
|
||||
|
||||
/// 判断错误是否可重试。
|
||||
///
|
||||
/// 可重试条件:
|
||||
/// - RateLimit(429)
|
||||
/// - Timeout
|
||||
/// - Request 且状态码 >= 500 或 == 429
|
||||
pub fn should_retry(err: &LlmError) -> bool {
|
||||
match err {
|
||||
LlmError::RateLimit { .. } => true,
|
||||
LlmError::Timeout { .. } => true,
|
||||
LlmError::Request { status, .. } => *status >= 500 || *status == 429,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a pseudo-random value in `[0, max)`.
///
/// Returns `0.0` when `max` is zero, negative, NaN or infinite.
///
/// The randomness comes from the standard library's randomly keyed `RandomState`
/// hasher rather than from the low digits of the wall clock, so callers that retry
/// at the same instant do not receive the same jitter and the result does not depend
/// on the clock's resolution.
fn rand_jitter(max: f64) -> f64 {
    // `!(max > 0.0)` is also true for NaN, which `max <= 0.0` would let through.
    if !(max > 0.0) || !max.is_finite() {
        return 0.0;
    }
    let state = std::collections::hash_map::RandomState::new();
    let hasher = std::hash::BuildHasher::build_hasher(&state);
    let bits = std::hash::Hasher::finish(&hasher);
    // Keep the top 53 bits (the f64 mantissa width) to obtain a fraction in [0, 1).
    let unit = (bits >> 11) as f64 / (1u64 << 53) as f64;
    unit * max
}
|
||||
@@ -0,0 +1,3 @@
|
||||
pub use crate::llm::types::usage::{
|
||||
CompletionTokensDetails, CostTracker, PromptTokensDetails, Usage, Usage as LlmUsage,
|
||||
};
|
||||
@@ -0,0 +1,37 @@
|
||||
use std::time::Duration;
|
||||
|
||||
/// All errors that can occur while calling an LLM.
|
||||
///
|
||||
/// Errors fall into two groups by retryability:
|
||||
/// - **Retryable**: `RateLimit`, `Timeout`, status codes >= 500
|
||||
/// - **Not retryable**: `Authentication`, `ContextLength`, 4xx status codes (except 429)
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum LlmError {
|
||||
/// API authentication failed (for example an invalid API key).
|
||||
#[error("认证失败: {0}")]
|
||||
Authentication(String),
|
||||
|
||||
/// The request was rate limited; may carry the suggested wait before retrying.
|
||||
#[error("限流(retry_after={retry_after:?})")]
|
||||
RateLimit { retry_after: Option<Duration> },
|
||||
|
||||
/// The HTTP request failed; carries the status code and the response body.
|
||||
#[error("请求失败(status={status}): {body}")]
|
||||
Request { status: u16, body: String },
|
||||
|
||||
/// The request timed out.
|
||||
#[error("请求超时(duration={duration:?})")]
|
||||
Timeout { duration: Duration },
|
||||
|
||||
/// Error while processing a streaming response (reserved).
|
||||
#[error("流式响应错误: {0}")]
|
||||
Stream(String),
|
||||
|
||||
/// The context length limit was exceeded.
|
||||
#[error("上下文超限(actual={actual}, limit={limit})")]
|
||||
ContextLength { actual: u32, limit: u32 },
|
||||
|
||||
/// Any other, unclassified LLM call failure.
|
||||
#[error("LLM 调用失败: {0}")]
|
||||
Other(String),
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
//! 生命周期钩子 —— 在 LLM 调用周期的关键节点插入自定义逻辑。
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::types::ChatRequest;
|
||||
|
||||
/// Points in the call lifecycle at which hooks can fire.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum HookEvent {
|
||||
/// Before an LLM request is issued (may block the request).
|
||||
PreRequest,
|
||||
/// After a successful response.
|
||||
PostRequest,
|
||||
/// Before a retry (fires only for retryable errors).
|
||||
OnRetry,
|
||||
/// Before an unrecoverable error is returned.
|
||||
OnError,
|
||||
/// Before an agent session starts a turn (added in Phase 4a).
|
||||
OnTurnStart,
|
||||
/// After an agent session finishes a turn (added in Phase 4a).
|
||||
OnTurnEnd,
|
||||
/// Fired after a TaskAgent completes a plan step (added in Phase 4b).
|
||||
OnPlanStepComplete,
|
||||
}
|
||||
|
||||
/// Context passed to a single hook invocation.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HookContext<'a> {
|
||||
/// The event being fired.
|
||||
pub event: HookEvent,
|
||||
/// The current request (available at some event points only).
|
||||
pub request: Option<&'a ChatRequest>,
|
||||
/// The current error (available for OnError and OnRetry only).
|
||||
pub error: Option<&'a LlmError>,
|
||||
/// The current retry number (starts at 1, meaningful for OnRetry only).
|
||||
pub attempt: u32,
|
||||
/// The current turn index (0-based, OnTurnStart / OnTurnEnd only, added in Phase 4a).
|
||||
pub turn_index: Option<u32>,
|
||||
/// The current plan step index (0-based, OnPlanStepComplete only, added in Phase 4b).
|
||||
pub plan_step_index: Option<usize>,
|
||||
}
|
||||
|
||||
impl<'a> HookContext<'a> {
|
||||
pub(crate) fn new(event: HookEvent) -> Self {
|
||||
Self {
|
||||
event,
|
||||
request: None,
|
||||
error: None,
|
||||
attempt: 0,
|
||||
turn_index: None,
|
||||
plan_step_index: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn with_request(mut self, request: &'a ChatRequest) -> Self {
|
||||
self.request = Some(request);
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn with_error(mut self, error: &'a LlmError) -> Self {
|
||||
self.error = Some(error);
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn with_attempt(mut self, attempt: u32) -> Self {
|
||||
self.attempt = attempt;
|
||||
self
|
||||
}
|
||||
|
||||
/// 设置 turn 序号(仅 OnTurnStart / OnTurnEnd 使用)。
|
||||
pub(crate) fn with_turn_index(mut self, turn_index: u32) -> Self {
|
||||
self.turn_index = Some(turn_index);
|
||||
self
|
||||
}
|
||||
|
||||
/// 设置 plan step 序号(仅 OnPlanStepComplete 使用,Phase 4b 新增)。
|
||||
pub(crate) fn with_plan_step_index(mut self, plan_step_index: usize) -> Self {
|
||||
self.plan_step_index = Some(plan_step_index);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Outcome of running a hook.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HookResult {
|
||||
/// Whether to block the following operation (honoured for PreRequest only).
|
||||
pub should_block: bool,
|
||||
/// Reason for blocking, or a free-form remark.
|
||||
pub reason: Option<String>,
|
||||
}
|
||||
|
||||
impl HookResult {
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn allow() -> Self {
|
||||
Self {
|
||||
should_block: false,
|
||||
reason: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn block(reason: impl Into<String>) -> Self {
|
||||
Self {
|
||||
should_block: true,
|
||||
reason: Some(reason.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait implemented by lifecycle hooks.
|
||||
#[async_trait]
|
||||
pub trait Hook: Send + Sync {
|
||||
/// Runs the hook logic for the given context and reports the outcome.
|
||||
async fn execute(&self, ctx: &HookContext<'_>) -> HookResult;
|
||||
}
|
||||
|
||||
/// Hook executor: keeps the registered hooks and fires them.
|
||||
pub struct HookExecutor {
|
||||
// Each hook is stored together with the event it was registered for, in registration order.
hooks: Vec<(HookEvent, Box<dyn Hook>)>,
|
||||
}
|
||||
|
||||
impl Default for HookExecutor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl HookExecutor {
|
||||
/// 创建一个空的执行器。
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
hooks: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 注册一个钩子到指定事件点。
|
||||
pub fn register(&mut self, event: HookEvent, hook: Box<dyn Hook>) {
|
||||
self.hooks.push((event, hook));
|
||||
}
|
||||
|
||||
/// 执行指定事件点的所有钩子。
|
||||
pub async fn execute(&self, event: HookEvent, ctx: &HookContext<'_>) -> Vec<HookResult> {
|
||||
let mut results = Vec::new();
|
||||
for (e, hook) in &self.hooks {
|
||||
if *e == event {
|
||||
results.push(hook.execute(ctx).await);
|
||||
}
|
||||
}
|
||||
results
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
pub mod openai;
|
||||
pub mod registry;
|
||||
|
||||
use std::pin::Pin;
|
||||
|
||||
use tokio_stream::Stream;
|
||||
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::types::{ChatRequest, ChatResponse, OpenaiChatChunk};
|
||||
use async_trait::async_trait;
|
||||
|
||||
/// Kinds of LLM backend that can be selected by name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ProviderType {
|
||||
/// OpenAI backend.
OpenAI,
|
||||
/// DeepSeek backend.
DeepSeek,
|
||||
/// Qwen backend.
Qwen,
|
||||
}
|
||||
|
||||
impl std::str::FromStr for ProviderType {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"openai" => Ok(ProviderType::OpenAI),
|
||||
"deepseek" => Ok(ProviderType::DeepSeek),
|
||||
"qwen" | "dashscope" | "tongyi" => Ok(ProviderType::Qwen),
|
||||
_ => Err(format!("未知的 Provider 类型: {}", s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Settings handed to `create_provider` when constructing a provider.
pub struct ProviderConfig {
|
||||
/// Base URL of the provider's API.
pub base_url: String,
|
||||
/// API key used to authenticate requests.
pub api_key: String,
|
||||
/// Model name.
pub model: String,
|
||||
}
|
||||
|
||||
pub fn create_provider(
|
||||
provider_type: ProviderType,
|
||||
config: ProviderConfig,
|
||||
) -> Result<Box<dyn LlmProvider>, LlmError> {
|
||||
match provider_type {
|
||||
ProviderType::OpenAI => Ok(Box::new(openai::OpenaiProvider::new(
|
||||
config.base_url,
|
||||
config.api_key,
|
||||
config.model,
|
||||
))),
|
||||
ProviderType::DeepSeek => {
|
||||
unimplemented!("DeepSeek Provider 尚未实现")
|
||||
}
|
||||
ProviderType::Qwen => {
|
||||
unimplemented!("Qwen Provider 尚未实现")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Abstract interface for an LLM provider.
|
||||
///
|
||||
/// Every concrete LLM backend (OpenAI, Anthropic, Azure, ...)
|
||||
/// implements this trait so that backends are interchangeable.
|
||||
#[async_trait]
|
||||
pub trait LlmProvider: Send + Sync {
|
||||
/// Sends a chat request and returns the complete response.
|
||||
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError>;
|
||||
|
||||
/// Streaming chat request: returns a stream of chunks.
|
||||
///
|
||||
/// The default implementation falls back to the non-streaming call and wraps
/// its response in a single-element stream.
|
||||
async fn chat_stream(
|
||||
&self,
|
||||
request: ChatRequest,
|
||||
) -> Result<
|
||||
Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||
LlmError,
|
||||
> {
|
||||
let response = self.chat(request).await?;
|
||||
let chunk = OpenaiChatChunk::from(response);
|
||||
Ok(Box::pin(tokio_stream::once(Ok(chunk))))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,235 @@
|
||||
use std::pin::Pin;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use futures_core::stream::Stream;
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::Client;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use super::LlmProvider;
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::types::{
|
||||
ChatRequest, ChatResponse, OpenaiChatChunk, OpenaiChatResponse, StreamOptions,
|
||||
};
|
||||
|
||||
/// Provider that talks to a `/chat/completions` HTTP endpoint under `base_url`.
pub struct OpenaiProvider {
|
||||
// HTTP client used for every request.
http_client: Client,
|
||||
// API root; a trailing '/' is stripped before the endpoint path is appended.
base_url: String,
|
||||
// Sent as a Bearer token in the Authorization header.
api_key: String,
|
||||
}
|
||||
|
||||
impl OpenaiProvider {
|
||||
pub fn new(base_url: String, api_key: String, _model: String) -> Self {
|
||||
let http_client = Client::builder()
|
||||
.timeout(Duration::from_secs(120))
|
||||
.build()
|
||||
.expect("创建 HTTP 客户端失败");
|
||||
|
||||
Self {
|
||||
http_client,
|
||||
base_url,
|
||||
api_key,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_client(mut self, client: Client) -> Self {
|
||||
self.http_client = client;
|
||||
self
|
||||
}
|
||||
|
||||
fn map_reqwest_error(e: reqwest::Error) -> LlmError {
|
||||
if e.is_timeout() {
|
||||
LlmError::Timeout {
|
||||
duration: Duration::from_secs(120),
|
||||
}
|
||||
} else if e.is_connect() {
|
||||
LlmError::Other(format!("连接失败: {}", e))
|
||||
} else {
|
||||
LlmError::Other(format!("请求失败: {}", e))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for OpenaiProvider {
|
||||
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError> {
|
||||
let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
|
||||
|
||||
info!(model = %request.model, max_tokens = request.max_tokens, temperature = request.temperature, "发送 LLM 请求");
|
||||
|
||||
let response = self
|
||||
.http_client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!(error = %e, "请求失败");
|
||||
Self::map_reqwest_error(e)
|
||||
})?;
|
||||
|
||||
let status = response.status();
|
||||
let status_code: u16 = status.as_u16();
|
||||
|
||||
if !status.is_success() {
|
||||
let retry_after = response
|
||||
.headers()
|
||||
.get("retry-after")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.parse::<u64>().ok())
|
||||
.map(Duration::from_secs);
|
||||
let body_text = response.text().await.unwrap_or_default();
|
||||
|
||||
error!(status = status_code, body = %body_text, "请求失败");
|
||||
|
||||
return match status_code {
|
||||
401 => Err(LlmError::Authentication(body_text)),
|
||||
429 => Err(LlmError::RateLimit { retry_after }),
|
||||
_ if status_code >= 500 => Err(LlmError::Request {
|
||||
status: status_code,
|
||||
body: body_text,
|
||||
}),
|
||||
_ if status_code == 400 && body_text.contains("context_length_exceeded") => {
|
||||
Err(LlmError::ContextLength {
|
||||
actual: 0,
|
||||
limit: 0,
|
||||
})
|
||||
}
|
||||
_ => Err(LlmError::Request {
|
||||
status: status_code,
|
||||
body: body_text,
|
||||
}),
|
||||
};
|
||||
}
|
||||
|
||||
let body_text = response.text().await.unwrap_or_default();
|
||||
debug!(body = %body_text, "收到响应体");
|
||||
|
||||
let chat_response: OpenaiChatResponse = serde_json::from_str(&body_text).map_err(|e| {
|
||||
error!(error = %e, body = %body_text, "响应解析失败");
|
||||
LlmError::Other(format!("响应解析失败: {}", e))
|
||||
})?;
|
||||
|
||||
debug!(response = ?chat_response, "收到 LLM 响应");
|
||||
|
||||
Ok(ChatResponse::from(chat_response))
|
||||
}
|
||||
|
||||
async fn chat_stream(
|
||||
&self,
|
||||
mut request: ChatRequest,
|
||||
) -> Result<
|
||||
Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||
LlmError,
|
||||
> {
|
||||
let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
|
||||
|
||||
request.stream = Some(true);
|
||||
request.stream_options = Some(StreamOptions {
|
||||
include_usage: Some(true),
|
||||
include_obfuscation: None,
|
||||
});
|
||||
|
||||
info!(model = %request.model, "发送 LLM 流式请求");
|
||||
|
||||
let response = self
|
||||
.http_client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!(error = %e, "流式请求失败");
|
||||
Self::map_reqwest_error(e)
|
||||
})?;
|
||||
|
||||
let status = response.status();
|
||||
let status_code: u16 = status.as_u16();
|
||||
|
||||
if !status.is_success() {
|
||||
let body_text = response.text().await.unwrap_or_default();
|
||||
error!(status = status_code, body = %body_text, "流式请求失败");
|
||||
return Err(LlmError::Request {
|
||||
status: status_code,
|
||||
body: body_text,
|
||||
});
|
||||
}
|
||||
|
||||
let byte_stream = response.bytes_stream().map(|r| {
|
||||
r.map_err(|e| LlmError::Other(format!("流式读取失败: {}", e)))
|
||||
});
|
||||
|
||||
Ok(Box::pin(SseChunkStream::new(byte_stream)))
|
||||
}
|
||||
}
|
||||
|
||||
struct SseChunkStream<S> {
|
||||
inner: S,
|
||||
buffer: String,
|
||||
}
|
||||
|
||||
impl<S: Stream<Item = Result<Bytes, LlmError>> + Unpin> SseChunkStream<S> {
|
||||
fn new(stream: S) -> Self {
|
||||
Self {
|
||||
inner: stream,
|
||||
buffer: String::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Stream<Item = Result<Bytes, LlmError>> + Unpin> Stream for SseChunkStream<S> {
|
||||
type Item = Result<OpenaiChatChunk, LlmError>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
|
||||
loop {
|
||||
if let Some(pos) = self.buffer.find("\n") {
|
||||
let line = self.buffer.drain(..pos + 1).collect::<String>();
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty()
|
||||
|| trimmed == "data: [DONE]"
|
||||
|| trimmed == "[DONE]"
|
||||
|| trimmed == "data:"
|
||||
{
|
||||
continue;
|
||||
}
|
||||
let data = if let Some(p) = trimmed.strip_prefix("data: ") {
|
||||
p
|
||||
} else {
|
||||
trimmed
|
||||
};
|
||||
match serde_json::from_str::<OpenaiChatChunk>(data) {
|
||||
Ok(chunk) => return std::task::Poll::Ready(Some(Ok(chunk))),
|
||||
Err(e) => {
|
||||
return std::task::Poll::Ready(Some(Err(LlmError::Other(format!(
|
||||
"Chunk 解析失败: {} | raw: {}",
|
||||
e, data
|
||||
)))));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match Pin::new(&mut self.inner).poll_next(cx) {
|
||||
std::task::Poll::Ready(Some(Ok(bytes))) => {
|
||||
if let Ok(s) = std::str::from_utf8(&bytes) {
|
||||
self.buffer.push_str(s);
|
||||
}
|
||||
}
|
||||
std::task::Poll::Ready(Some(Err(e))) => {
|
||||
return std::task::Poll::Ready(Some(Err(e)));
|
||||
}
|
||||
std::task::Poll::Ready(None) => {
|
||||
if self.buffer.is_empty() {
|
||||
return std::task::Poll::Ready(None);
|
||||
}
|
||||
self.buffer.clear();
|
||||
return std::task::Poll::Ready(None);
|
||||
}
|
||||
std::task::Poll::Pending => return std::task::Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
//! Provider 注册表 —— 多 Provider 实例的注册与发现。
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::provider::{create_provider, LlmProvider, ProviderConfig, ProviderType};
|
||||
|
||||
/// Provider registry: manages multiple LLM provider instances.
|
||||
///
|
||||
/// Supports registering named providers, looking them up by name, and choosing a default.
|
||||
pub struct ProviderRegistry {
|
||||
// Registered providers keyed by name.
providers: HashMap<String, Box<dyn LlmProvider>>,
|
||||
// Name of the default provider, if one has been set.
default_name: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for ProviderRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderRegistry {
|
||||
/// 创建一个空的注册表。
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
providers: HashMap::new(),
|
||||
default_name: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// 注册一个已初始化的 Provider 实例。
|
||||
pub fn register(&mut self, name: impl Into<String>, provider: Box<dyn LlmProvider>) {
|
||||
self.providers.insert(name.into(), provider);
|
||||
}
|
||||
|
||||
/// 通过 ProviderType + ProviderConfig 创建并注册。
|
||||
pub fn register_with_config(
|
||||
&mut self,
|
||||
name: impl Into<String>,
|
||||
provider_type: ProviderType,
|
||||
config: ProviderConfig,
|
||||
) -> Result<(), LlmError> {
|
||||
let provider = create_provider(provider_type, config)?;
|
||||
self.register(name, provider);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 设置默认 Provider。
|
||||
pub fn set_default(&mut self, name: &str) -> Result<(), LlmError> {
|
||||
if !self.providers.contains_key(name) {
|
||||
return Err(LlmError::Other(format!("Provider '{}' 不存在", name)));
|
||||
}
|
||||
self.default_name = Some(name.to_string());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 按名称查找 Provider。
|
||||
pub fn get(&self, name: &str) -> Option<&dyn LlmProvider> {
|
||||
self.providers.get(name).map(|p| p.as_ref())
|
||||
}
|
||||
|
||||
/// 获取默认 Provider。
|
||||
pub fn get_default(&self) -> Option<&dyn LlmProvider> {
|
||||
self.default_name
|
||||
.as_ref()
|
||||
.and_then(|name| self.get(name))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
//! 流式事件系统 —— 将 LLM 流式响应解析为语义化事件。
|
||||
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use futures_core::stream::Stream;
|
||||
use futures_util::future::poll_fn;
|
||||
use futures_util::FutureExt;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::types::{FinishReason, OpenaiChatChunk, OpenaiToolCall, Usage};
|
||||
|
||||
/// Stream events: semantic events covering the lifecycle of an LLM call.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum StreamEvent {
|
||||
/// An increment of assistant reply text.
|
||||
AssistantTextDelta { text: String },
|
||||
/// A tool call has started.
|
||||
ToolExecutionStarted {
|
||||
tool_name: String,
|
||||
input: Value,
|
||||
tool_call_id: String,
|
||||
},
|
||||
/// A tool call has completed.
|
||||
ToolExecutionCompleted {
|
||||
tool_name: String,
|
||||
output: Value,
|
||||
is_error: bool,
|
||||
},
|
||||
/// Token usage update.
|
||||
CostUpdate { usage: Usage },
|
||||
/// A conversation turn has completed.
|
||||
TurnComplete { reason: FinishReason },
|
||||
/// An error occurred.
|
||||
Error { message: String },
|
||||
}
|
||||
|
||||
impl StreamEvent {
|
||||
fn error(message: impl Into<String>) -> Self {
|
||||
Self::Error {
|
||||
message: message.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 将原始 OpenaiChatChunk 流解析为 StreamEvent 流。
|
||||
pub fn parse_chunk_stream(
|
||||
chunks: Pin<Box<dyn futures_core::Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||
) -> Pin<Box<dyn futures_core::Stream<Item = StreamEvent> + Send>> {
|
||||
Box::pin(ChunkToEventStream { chunks })
|
||||
}
|
||||
|
||||
struct ChunkToEventStream {
|
||||
chunks: Pin<Box<dyn futures_core::Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||
}
|
||||
|
||||
impl Stream for ChunkToEventStream {
|
||||
type Item = StreamEvent;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = &mut *self;
|
||||
poll_fn(|cx| match Pin::new(&mut this.chunks).poll_next(cx) {
|
||||
Poll::Ready(Some(Ok(chunk))) => {
|
||||
for choice in &chunk.choices {
|
||||
let delta = &choice.delta;
|
||||
|
||||
if let Some(content) = &delta.content {
|
||||
return Poll::Ready(Some(StreamEvent::AssistantTextDelta {
|
||||
text: content.clone(),
|
||||
}));
|
||||
}
|
||||
|
||||
if let Some(tool_calls) = &delta.tool_calls
|
||||
&& let Some(tc) = tool_calls.first()
|
||||
{
|
||||
let OpenaiToolCall::Function { id, function } = tc;
|
||||
let args: Value =
|
||||
serde_json::from_str(&function.arguments).unwrap_or(Value::Null);
|
||||
return Poll::Ready(Some(StreamEvent::ToolExecutionStarted {
|
||||
tool_name: function.name.clone(),
|
||||
input: args,
|
||||
tool_call_id: id.clone(),
|
||||
}));
|
||||
}
|
||||
|
||||
if let Some(finish_reason) = &choice.finish_reason {
|
||||
return Poll::Ready(Some(StreamEvent::TurnComplete {
|
||||
reason: *finish_reason,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(usage) = &chunk.usage {
|
||||
return Poll::Ready(Some(StreamEvent::CostUpdate {
|
||||
usage: *usage,
|
||||
}));
|
||||
}
|
||||
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(StreamEvent::error(e.to_string()))),
|
||||
Poll::Ready(None) => Poll::Ready(None),
|
||||
Poll::Pending => Poll::Pending,
|
||||
})
|
||||
.poll_unpin(cx)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,165 @@
|
||||
use crate::llm::types::shared::{AudioFormat, ImageDetail};
|
||||
use crate::llm::types::tool::OpenaiToolCall;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
/// Message content, which is either a plain string or a list of content parts.
#[derive(Debug, Clone)]
|
||||
pub enum ContentField {
|
||||
/// Content given as a single string.
String(String),
|
||||
/// Content given as an array of typed parts.
Array(Vec<OpenaiContentPart>),
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for ContentField {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let value = Value::deserialize(deserializer)?;
|
||||
match value {
|
||||
Value::String(s) => Ok(ContentField::String(s)),
|
||||
Value::Array(arr) => {
|
||||
let parts: Result<Vec<OpenaiContentPart>, _> =
|
||||
serde_json::from_value(Value::Array(arr));
|
||||
match parts {
|
||||
Ok(parts) => Ok(ContentField::Array(parts)),
|
||||
Err(e) => Err(serde::de::Error::custom(e)),
|
||||
}
|
||||
}
|
||||
_ => Err(serde::de::Error::custom("content must be string or array")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for ContentField {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
match self {
|
||||
ContentField::String(s) => s.serialize(serializer),
|
||||
ContentField::Array(arr) => arr.serialize(serializer),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Image reference used inside an image content part.
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ImageURL {
|
||||
/// Location of the image.
pub url: String,
|
||||
/// Optional detail setting; left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub detail: Option<ImageDetail>,
|
||||
}
|
||||
|
||||
/// Audio payload used inside an audio content part.
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct InputAudio {
|
||||
/// Audio data as a string. NOTE(review): presumably base64-encoded; confirm against the API.
pub data: String,
|
||||
/// Format of the audio data.
pub format: AudioFormat,
|
||||
}
|
||||
|
||||
/// File payload used inside a file content part.
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FileData {
|
||||
/// File contents as a string. NOTE(review): presumably encoded data; confirm against the API.
pub file_data: String,
|
||||
/// Optional file identifier; left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub file_id: Option<String>,
|
||||
/// Optional file name; left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub filename: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type")]
|
||||
pub enum OpenaiContentPart {
|
||||
Text {
|
||||
text: String,
|
||||
},
|
||||
Image {
|
||||
image_url: ImageURL,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
detail: Option<ImageDetail>,
|
||||
},
|
||||
InputAudio {
|
||||
input_audio: InputAudio,
|
||||
},
|
||||
File {
|
||||
file: FileData,
|
||||
},
|
||||
Refusal {
|
||||
refusal: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// A chat message, internally tagged by its `role` (snake_case variant name).
///
/// Optional fields are left out of the serialized form when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "role")]
|
||||
pub enum OpenaiChatMessage {
|
||||
/// Message with role `developer`.
Developer {
|
||||
content: ContentField,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
name: Option<String>,
|
||||
},
|
||||
/// Message with role `system`.
System {
|
||||
content: ContentField,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
name: Option<String>,
|
||||
},
|
||||
/// Message with role `user`.
User {
|
||||
content: ContentField,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
name: Option<String>,
|
||||
},
|
||||
/// Message with role `assistant`; may carry a refusal and tool calls.
///
/// NOTE(review): `content` is mandatory in this variant; confirm how responses that
/// carry only `tool_calls` (where the API may send a null content) are expected to load.
Assistant {
|
||||
content: ContentField,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
refusal: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_calls: Option<Vec<OpenaiToolCall>>,
|
||||
},
|
||||
/// Message with role `tool`: the result for the tool call identified by `tool_call_id`.
Tool {
|
||||
content: ContentField,
|
||||
tool_call_id: String,
|
||||
},
|
||||
/// Message with role `function`, identified by function `name`.
Function {
|
||||
content: ContentField,
|
||||
name: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl OpenaiChatMessage {
|
||||
pub fn user_text<S: Into<String>>(text: S) -> Self {
|
||||
OpenaiChatMessage::User {
|
||||
content: ContentField::Array(vec![OpenaiContentPart::Text { text: text.into() }]),
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn assistant_text<S: Into<String>>(text: S) -> Self {
|
||||
OpenaiChatMessage::Assistant {
|
||||
content: ContentField::String(text.into()),
|
||||
refusal: None,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn system_text<S: Into<String>>(text: S) -> Self {
|
||||
OpenaiChatMessage::System {
|
||||
content: ContentField::Array(vec![OpenaiContentPart::Text { text: text.into() }]),
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn developer_text<S: Into<String>>(text: S) -> Self {
|
||||
OpenaiChatMessage::Developer {
|
||||
content: ContentField::Array(vec![OpenaiContentPart::Text { text: text.into() }]),
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool_result<S: Into<String>>(tool_call_id: String, content: S) -> Self {
|
||||
OpenaiChatMessage::Tool {
|
||||
content: ContentField::Array(vec![OpenaiContentPart::Text {
|
||||
text: content.into(),
|
||||
}]),
|
||||
tool_call_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
pub mod message;
|
||||
pub mod request;
|
||||
pub mod response;
|
||||
pub mod shared;
|
||||
pub mod tool;
|
||||
pub mod usage;
|
||||
|
||||
pub use message::{
|
||||
ContentField, FileData, ImageURL, InputAudio, OpenaiChatMessage, OpenaiContentPart,
|
||||
};
|
||||
pub use request::{OpenaiChatRequest, OpenaiTool, StreamOptions, ToolChoice};
|
||||
pub use response::{
|
||||
Annotation, Choice, ChunkChoice, Delta, Logprobs, OpenaiAudio, OpenaiChatChunk,
|
||||
OpenaiChatResponse, TokenLogprob, TopLogprob, URLCitation,
|
||||
};
|
||||
pub use shared::{
|
||||
AudioFormat, FinishReason, ImageDetail, Modality, ResponseFormat, Role, ServiceTier,
|
||||
StopSequence,
|
||||
};
|
||||
pub use tool::{FunctionCall, OpenaiToolCall, OpenaiToolDefinition};
|
||||
pub use usage::{CompletionTokensDetails, CostTracker, PromptTokensDetails, Usage};
|
||||
|
||||
/// Simplified chat response: one message, its token usage and why generation stopped.
#[derive(Debug, Clone)]
|
||||
pub struct ChatResponse {
|
||||
/// The returned message.
pub message: OpenaiChatMessage,
|
||||
/// Token usage reported with the response.
pub usage: Usage,
|
||||
/// Finish reason, when one was reported.
pub stop_reason: Option<FinishReason>,
|
||||
}
|
||||
|
||||
impl From<OpenaiChatResponse> for ChatResponse {
|
||||
fn from(response: OpenaiChatResponse) -> Self {
|
||||
let message = response
|
||||
.choices
|
||||
.first()
|
||||
.map(|c| c.message.clone())
|
||||
.unwrap_or_else(|| OpenaiChatMessage::assistant_text(""));
|
||||
let stop_reason = response.choices.first().and_then(|c| c.finish_reason);
|
||||
ChatResponse {
|
||||
message,
|
||||
usage: response.usage,
|
||||
stop_reason,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ChatResponse> for OpenaiChatChunk {
|
||||
fn from(response: ChatResponse) -> Self {
|
||||
let delta = Delta::from(response.message.clone());
|
||||
let chunk_choice = ChunkChoice {
|
||||
index: 0,
|
||||
delta,
|
||||
logprobs: None,
|
||||
finish_reason: response.stop_reason,
|
||||
};
|
||||
|
||||
OpenaiChatChunk {
|
||||
id: format!("chunk-{}", std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_nanos())
|
||||
.unwrap_or(0)),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0),
|
||||
model: String::new(),
|
||||
choices: vec![chunk_choice],
|
||||
usage: Some(response.usage),
|
||||
system_fingerprint: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Provider-neutral names for the OpenAI-shaped types defined in this module.
pub type ChatRequest = OpenaiChatRequest;
|
||||
pub type Message = OpenaiChatMessage;
|
||||
pub type ContentBlock = OpenaiContentPart;
|
||||
pub type ToolDefinition = OpenaiToolDefinition;
|
||||
pub type StopReason = FinishReason;
|
||||
@@ -0,0 +1,181 @@
|
||||
use crate::llm::types::shared::{ResponseFormat, ServiceTier, StopSequence};
|
||||
use crate::llm::types::tool::OpenaiToolDefinition;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StreamOptions {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub include_usage: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub include_obfuscation: Option<bool>,
|
||||
}
|
||||
|
||||
/// Tool-selection setting of a chat request.
///
/// Serialization and deserialization are hand-written (see the `Serialize` /
/// `Deserialize` impls for this type) rather than derived.
#[derive(Clone, Debug)]
pub enum ToolChoice {
    /// Written to the wire as the string `"none"`.
    None,
    /// Written to the wire as the string `"auto"`.
    Auto,
    /// Written to the wire as the string `"required"`.
    Required,
    /// Selects one function by name.
    Named { name: String },
    /// Carries a list of tool names.
    AllowedTools { tool_names: Vec<String> },
}
|
||||
|
||||
impl Serialize for ToolChoice {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
match self {
|
||||
ToolChoice::None => serializer.serialize_str("none"),
|
||||
ToolChoice::Auto => serializer.serialize_str("auto"),
|
||||
ToolChoice::Required => serializer.serialize_str("required"),
|
||||
ToolChoice::Named { name } => {
|
||||
let obj = serde_json::json!({
|
||||
"type": "function",
|
||||
"function": { "name": name }
|
||||
});
|
||||
obj.serialize(serializer)
|
||||
}
|
||||
ToolChoice::AllowedTools { tool_names } => {
|
||||
let obj = serde_json::json!({
|
||||
"type": "function",
|
||||
"function": { "name": tool_names.first().cloned().unwrap_or_default() }
|
||||
});
|
||||
obj.serialize(serializer)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for ToolChoice {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let value = Value::deserialize(deserializer)?;
|
||||
match value {
|
||||
Value::String(s) => match s.as_str() {
|
||||
"none" => Ok(ToolChoice::None),
|
||||
"auto" => Ok(ToolChoice::Auto),
|
||||
"required" => Ok(ToolChoice::Required),
|
||||
_ => Err(serde::de::Error::custom(format!(
|
||||
"unknown tool choice: {s}"
|
||||
))),
|
||||
},
|
||||
Value::Object(obj) => {
|
||||
let typ = obj.get("type").and_then(|v| v.as_str()).ok_or_else(|| {
|
||||
serde::de::Error::custom("missing 'type' field in tool_choice")
|
||||
})?;
|
||||
if typ == "function" {
|
||||
let func =
|
||||
obj.get("function")
|
||||
.and_then(|v| v.as_object())
|
||||
.ok_or_else(|| {
|
||||
serde::de::Error::custom("missing 'function' field in tool_choice")
|
||||
})?;
|
||||
let name = func.get("name").and_then(|v| v.as_str()).ok_or_else(|| {
|
||||
serde::de::Error::custom("missing 'function.name' in tool_choice")
|
||||
})?;
|
||||
Ok(ToolChoice::Named {
|
||||
name: name.to_string(),
|
||||
})
|
||||
} else {
|
||||
Err(serde::de::Error::custom(format!(
|
||||
"unknown tool_choice type: {typ}"
|
||||
)))
|
||||
}
|
||||
}
|
||||
_ => Err(serde::de::Error::custom(
|
||||
"tool_choice must be a string or object",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type")]
|
||||
pub enum OpenaiTool {
|
||||
Function { function: OpenaiToolDefinition },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AudioParam {
|
||||
pub format: String,
|
||||
pub voice: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PredictionContent {
|
||||
#[serde(rename = "type")]
|
||||
pub pred_type: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UserLocation {
|
||||
#[serde(rename = "type")]
|
||||
pub loc_type: String,
|
||||
pub approximate: Approximate,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Approximate {
|
||||
pub city: String,
|
||||
pub country: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub region: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub timezone: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WebSearchOptions {
|
||||
pub search_context_size: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub user_location: Option<UserLocation>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct OpenaiChatRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<crate::llm::types::message::OpenaiChatMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub frequency_penalty: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub logit_bias: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub n: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub presence_penalty: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub response_format: Option<ResponseFormat>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub seed: Option<i64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub service_tier: Option<ServiceTier>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop: Option<StopSequence>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stream: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stream_options: Option<StreamOptions>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<OpenaiTool>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_choice: Option<ToolChoice>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub parallel_tool_calls: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub user: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub extra_headers: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub extra_body: Option<Value>,
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
use crate::llm::types::message::OpenaiChatMessage;
|
||||
use crate::llm::types::shared::{FinishReason, ServiceTier};
|
||||
use crate::llm::types::tool::OpenaiToolCall;
|
||||
use crate::llm::types::usage::Usage;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::llm::types::{ContentField, OpenaiContentPart};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TokenLogprob {
|
||||
pub token: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub bytes: Option<Vec<u32>>,
|
||||
pub logprob: f64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_logprobs: Option<Vec<TopLogprob>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TopLogprob {
|
||||
pub token: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub bytes: Option<Vec<u32>>,
|
||||
pub logprob: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Logprobs {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<Vec<TokenLogprob>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub refusal: Option<Vec<TokenLogprob>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct URLCitation {
|
||||
pub end_index: u32,
|
||||
pub start_index: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Annotation {
|
||||
#[serde(rename = "type")]
|
||||
pub ann_type: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub url_citation: Option<URLCitation>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenaiAudio {
|
||||
pub id: String,
|
||||
pub data: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub expires_at: Option<i64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub transcript: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Choice {
|
||||
pub index: u32,
|
||||
pub message: OpenaiChatMessage,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub finish_reason: Option<FinishReason>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub logprobs: Option<Logprobs>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenaiChatResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<Choice>,
|
||||
pub usage: Usage,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system_fingerprint: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub service_tier: Option<ServiceTier>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Delta {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub role: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub refusal: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<OpenaiToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChunkChoice {
|
||||
pub index: u32,
|
||||
pub delta: Delta,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub logprobs: Option<Logprobs>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub finish_reason: Option<FinishReason>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenaiChatChunk {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChunkChoice>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub usage: Option<Usage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system_fingerprint: Option<String>,
|
||||
}
|
||||
|
||||
impl From<OpenaiChatMessage> for Delta {
|
||||
fn from(msg: OpenaiChatMessage) -> Self {
|
||||
match msg {
|
||||
OpenaiChatMessage::Assistant {
|
||||
content,
|
||||
tool_calls,
|
||||
..
|
||||
} => Delta {
|
||||
role: Some("assistant".to_string()),
|
||||
content: match content {
|
||||
ContentField::String(s) => Some(s),
|
||||
ContentField::Array(parts) => {
|
||||
let mut text = String::new();
|
||||
for part in parts {
|
||||
if let OpenaiContentPart::Text { text: t } = part {
|
||||
text.push_str(&t);
|
||||
}
|
||||
}
|
||||
if text.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(text)
|
||||
}
|
||||
}
|
||||
},
|
||||
refusal: None,
|
||||
tool_calls,
|
||||
},
|
||||
_ => Delta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<OpenaiChatResponse> for OpenaiChatChunk {
|
||||
fn from(response: OpenaiChatResponse) -> Self {
|
||||
let choices = response
|
||||
.choices
|
||||
.into_iter()
|
||||
.map(|c| ChunkChoice {
|
||||
index: c.index,
|
||||
delta: Delta::from(c.message),
|
||||
logprobs: c.logprobs,
|
||||
finish_reason: c.finish_reason,
|
||||
})
|
||||
.collect();
|
||||
|
||||
OpenaiChatChunk {
|
||||
id: response.id,
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created: response.created,
|
||||
model: response.model,
|
||||
choices,
|
||||
usage: Some(response.usage),
|
||||
system_fingerprint: response.system_fingerprint,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Role {
|
||||
Developer,
|
||||
System,
|
||||
User,
|
||||
Assistant,
|
||||
Tool,
|
||||
Function,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum FinishReason {
|
||||
Stop,
|
||||
Length,
|
||||
ToolCalls,
|
||||
ContentFilter,
|
||||
FunctionCall,
|
||||
#[serde(other)]
|
||||
Other,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ServiceTier {
|
||||
Auto,
|
||||
Default,
|
||||
#[serde(other)]
|
||||
Other,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Modality {
|
||||
Text,
|
||||
Audio,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ImageDetail {
|
||||
Auto,
|
||||
Low,
|
||||
High,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AudioFormat {
|
||||
Wav,
|
||||
Mp3,
|
||||
Aac,
|
||||
Flac,
|
||||
Opus,
|
||||
Pcm16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum StopSequence {
|
||||
Single(String),
|
||||
Multiple(Vec<String>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type")]
|
||||
pub enum ResponseFormat {
|
||||
Text,
|
||||
JsonObject,
|
||||
JsonSchema {
|
||||
schema: serde_json::Value,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
strict: Option<bool>,
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenaiToolDefinition {
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
pub parameters: Value,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub strict: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionCall {
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type")]
|
||||
pub enum OpenaiToolCall {
|
||||
Function { id: String, function: FunctionCall },
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub completion_tokens_details: Option<CompletionTokensDetails>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt_tokens_details: Option<PromptTokensDetails>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
|
||||
pub struct CompletionTokensDetails {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub audio_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub accepted_prediction_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub rejected_prediction_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
|
||||
pub struct PromptTokensDetails {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub audio_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cached_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct CostTracker {
|
||||
accumulated: Usage,
|
||||
}
|
||||
|
||||
impl CostTracker {
|
||||
pub fn add(&mut self, usage: &Usage) {
|
||||
self.accumulated.prompt_tokens = self
|
||||
.accumulated
|
||||
.prompt_tokens
|
||||
.saturating_add(usage.prompt_tokens);
|
||||
self.accumulated.completion_tokens = self
|
||||
.accumulated
|
||||
.completion_tokens
|
||||
.saturating_add(usage.completion_tokens);
|
||||
self.accumulated.total_tokens = self
|
||||
.accumulated
|
||||
.total_tokens
|
||||
.saturating_add(usage.total_tokens);
|
||||
}
|
||||
|
||||
pub fn total(&self) -> &Usage {
|
||||
&self.accumulated
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
self.accumulated = Usage::default();
|
||||
}
|
||||
}
|
||||
|
||||
impl Usage {
|
||||
pub fn from_input_output(input: u32, output: u32) -> Self {
|
||||
let total = input.saturating_add(output);
|
||||
Usage {
|
||||
prompt_tokens: input,
|
||||
completion_tokens: output,
|
||||
total_tokens: total,
|
||||
completion_tokens_details: None,
|
||||
prompt_tokens_details: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
//! 记忆系统 —— 对话消息管理、知识页面存储与关键词检索。
|
||||
|
||||
pub mod conversation;
|
||||
pub mod error;
|
||||
pub mod knowledge;
|
||||
pub mod retriever;
|
||||
pub mod store;
|
||||
pub mod types;
|
||||
|
||||
// 高频类型(大多数下游需要)
|
||||
pub use conversation::{ConversationMemory, ConversationMemoryConfig};
|
||||
pub use error::MemoryError;
|
||||
pub use knowledge::KnowledgeStore;
|
||||
pub use retriever::MemoryRetriever;
|
||||
pub use store::{InMemoryStore, MemoryStore};
|
||||
|
||||
// 低频类型(配置/高级使用)
|
||||
pub use conversation::MemoryStrategy;
|
||||
pub use knowledge::{PageIndexEntry, KNOWLEDGE_PREFIX};
|
||||
pub use retriever::{RetrieverConfig, RetrievalResult, ScoredItem};
|
||||
pub use store::{EvictionConfig, EvictionPolicy};
|
||||
pub use types::{KnowledgePage, MemoryFilter, MemoryItem};
|
||||
@@ -0,0 +1,260 @@
|
||||
//! 对话记忆 —— 多轮对话消息管理,复用 `llm::compact` 的压缩逻辑。
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use time::OffsetDateTime;
|
||||
|
||||
use crate::llm::compact::{CompactConfig, CompactState, microcompact, should_compact};
|
||||
use crate::llm::types::OpenaiChatMessage;
|
||||
use crate::memory::error::MemoryError;
|
||||
use crate::memory::store::MemoryStore;
|
||||
use crate::memory::types::MemoryItem;
|
||||
|
||||
/// 对话消息管理策略。
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum MemoryStrategy {
|
||||
/// 滑动窗口:达到上限时删除最旧消息。
|
||||
SlidingWindow,
|
||||
/// 保留所有消息(仅压缩,不删除)。
|
||||
Full,
|
||||
}
|
||||
|
||||
/// 对话记忆配置。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConversationMemoryConfig {
|
||||
pub strategy: MemoryStrategy,
|
||||
pub max_turns: usize,
|
||||
pub compact_config: Option<CompactConfig>,
|
||||
}
|
||||
|
||||
impl Default for ConversationMemoryConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
strategy: MemoryStrategy::SlidingWindow,
|
||||
max_turns: 50,
|
||||
compact_config: Some(CompactConfig::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 对话记忆 —— 按 session 管理多轮对话消息历史。
|
||||
///
|
||||
/// 内部维护 `Vec<OpenaiChatMessage>` 热缓存(供 `llm::compact` 直接操作),
|
||||
/// `MemoryStore` 用作冷持久化层。
|
||||
pub struct ConversationMemory {
|
||||
store: Arc<dyn MemoryStore>,
|
||||
session_id: String,
|
||||
config: ConversationMemoryConfig,
|
||||
/// 热缓存:消息列表,供 `llm::compact` 直接操作。
|
||||
messages: Vec<OpenaiChatMessage>,
|
||||
/// 与 `messages` 一一对应的存储 ID(保持稳定以便淘汰时精准删除)。
|
||||
message_ids: Vec<String>,
|
||||
/// 压缩断路器状态。
|
||||
compact_state: CompactState,
|
||||
}
|
||||
|
||||
impl ConversationMemory {
|
||||
/// 创建一个新的 ConversationMemory。
|
||||
pub fn new(
|
||||
store: Arc<dyn MemoryStore>,
|
||||
session_id: impl Into<String>,
|
||||
config: ConversationMemoryConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
store,
|
||||
session_id: session_id.into(),
|
||||
config,
|
||||
messages: Vec::new(),
|
||||
message_ids: Vec::new(),
|
||||
compact_state: CompactState::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取 session id。
|
||||
pub fn session_id(&self) -> &str {
|
||||
&self.session_id
|
||||
}
|
||||
|
||||
/// 获取配置。
|
||||
pub fn config(&self) -> &ConversationMemoryConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// 从 MemoryStore 加载历史消息到热缓存。
|
||||
pub async fn load(&mut self) -> Result<(), MemoryError> {
|
||||
let filter = crate::memory::types::MemoryFilter {
|
||||
prefix: Some(self.session_prefix()),
|
||||
..Default::default()
|
||||
};
|
||||
let items = self.store.list(&filter).await?;
|
||||
let mut pairs: Vec<(String, OpenaiChatMessage, OffsetDateTime)> = Vec::with_capacity(items.len());
|
||||
for item in items {
|
||||
match serde_json::from_str::<OpenaiChatMessage>(&item.content) {
|
||||
Ok(msg) => pairs.push((item.id, msg, item.created_at)),
|
||||
Err(e) => {
|
||||
return Err(MemoryError::Serialization(format!(
|
||||
"load message {} failed: {e}",
|
||||
item.id
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
// 按 created_at 升序排列
|
||||
pairs.sort_by_key(|p| p.2);
|
||||
self.message_ids = pairs.iter().map(|p| p.0.clone()).collect();
|
||||
self.messages = pairs.into_iter().map(|p| p.1).collect();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 添加一条消息。
|
||||
///
|
||||
/// 写入热缓存并通过 `MemoryStore` 持久化。如有需要,触发淘汰和压缩。
|
||||
pub async fn add_message(&mut self, msg: OpenaiChatMessage) -> Result<(), MemoryError> {
|
||||
let now = OffsetDateTime::now_utc();
|
||||
let index = self.messages.len();
|
||||
let id = self.make_message_id(index, &now);
|
||||
|
||||
// 写入热缓存
|
||||
self.messages.push(msg);
|
||||
self.message_ids.push(id.clone());
|
||||
|
||||
// 同步到冷存储
|
||||
let item = MemoryItem {
|
||||
id: id.clone(),
|
||||
content: serde_json::to_string(self.messages.last().unwrap())
|
||||
.map_err(|e| MemoryError::Serialization(e.to_string()))?,
|
||||
metadata: serde_json::json!({ "session_id": &self.session_id, "index": index }),
|
||||
created_at: now,
|
||||
};
|
||||
self.store.save(item).await?;
|
||||
|
||||
// 触发淘汰和压缩
|
||||
self.maybe_evict_and_compact().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取完整消息历史。
|
||||
pub fn get_history(&self) -> &[OpenaiChatMessage] {
|
||||
&self.messages
|
||||
}
|
||||
|
||||
/// 清空所有消息。
|
||||
pub async fn clear(&mut self) -> Result<(), MemoryError> {
|
||||
let to_delete = std::mem::take(&mut self.message_ids);
|
||||
self.messages.clear();
|
||||
self.compact_state = CompactState::new();
|
||||
for id in to_delete {
|
||||
self.store.delete(&id).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 当前消息数量。
|
||||
pub fn len(&self) -> usize {
|
||||
self.messages.len()
|
||||
}
|
||||
|
||||
/// 是否为空。
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.messages.is_empty()
|
||||
}
|
||||
|
||||
fn session_prefix(&self) -> String {
|
||||
format!("conv:{self}:", self = self.session_id)
|
||||
}
|
||||
|
||||
fn make_message_id(&self, index: usize, now: &OffsetDateTime) -> String {
|
||||
format!("{}{:010}_{}", self.session_prefix(), index, now.unix_timestamp_nanos())
|
||||
}
|
||||
|
||||
async fn maybe_evict_and_compact(&mut self) {
|
||||
// 1. Sliding window 淘汰:删除最旧消息
|
||||
if self.config.strategy == MemoryStrategy::SlidingWindow {
|
||||
while self.messages.len() > self.config.max_turns {
|
||||
if let Some(removed_id) = self.message_ids.first().cloned() {
|
||||
let _ = self.store.delete(&removed_id).await;
|
||||
}
|
||||
self.messages.remove(0);
|
||||
self.message_ids.remove(0);
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 压缩(复用 llm::compact)
|
||||
if let Some(ref compact_config) = self.config.compact_config {
|
||||
if should_compact(&self.messages, compact_config, &self.compact_state) {
|
||||
let keep_recent = compact_config.keep_recent;
|
||||
let freed = microcompact(&mut self.messages, keep_recent);
|
||||
if freed > 0 {
|
||||
self.compact_state.record_success();
|
||||
} else {
|
||||
// 没有 token 被释放(可能没找到可压缩的 tool result)
|
||||
let _ = self.compact_state.record_failure();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::types::OpenaiChatMessage;
    use crate::memory::InMemoryStore;
    use crate::memory::MemoryStore;

    /// Shorthand for a user text message.
    fn user_text(s: &str) -> OpenaiChatMessage {
        OpenaiChatMessage::user_text(s)
    }

    /// Fresh in-memory store behind the trait object the memory expects.
    fn fresh_store() -> Arc<dyn MemoryStore> {
        Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>
    }

    /// Config with the given strategy and window, compaction disabled.
    fn config_without_compaction(
        strategy: MemoryStrategy,
        max_turns: usize,
    ) -> ConversationMemoryConfig {
        ConversationMemoryConfig {
            strategy,
            max_turns,
            compact_config: None,
        }
    }

    /// Adds `msg-0` .. `msg-{count-1}` to `conv`.
    async fn push_numbered(conv: &mut ConversationMemory, count: usize) {
        for i in 0..count {
            conv.add_message(user_text(&format!("msg-{i}")))
                .await
                .unwrap();
        }
    }

    #[tokio::test]
    async fn add_and_get_history() {
        let mut conv = ConversationMemory::new(
            fresh_store(),
            "session1",
            ConversationMemoryConfig::default(),
        );
        conv.add_message(user_text("hello")).await.unwrap();
        conv.add_message(user_text("world")).await.unwrap();
        assert_eq!(conv.len(), 2);
        assert_eq!(conv.get_history().len(), 2);
    }

    #[tokio::test]
    async fn sliding_window_evicts_oldest() {
        let config = config_without_compaction(MemoryStrategy::SlidingWindow, 3);
        let mut conv = ConversationMemory::new(fresh_store(), "s1", config);
        push_numbered(&mut conv, 5).await;
        assert_eq!(conv.len(), 3);
    }

    #[tokio::test]
    async fn full_strategy_no_evict() {
        let config = config_without_compaction(MemoryStrategy::Full, 3);
        let mut conv = ConversationMemory::new(fresh_store(), "s1", config);
        push_numbered(&mut conv, 5).await;
        // The Full strategy never deletes messages.
        assert_eq!(conv.len(), 5);
    }

    #[tokio::test]
    async fn clear_empties_messages() {
        let store = fresh_store();
        let mut conv = ConversationMemory::new(
            Arc::clone(&store),
            "s1",
            ConversationMemoryConfig::default(),
        );
        conv.add_message(user_text("hello")).await.unwrap();
        assert!(!conv.is_empty());
        conv.clear().await.unwrap();
        assert!(conv.is_empty());
    }
}
|
||||
@@ -0,0 +1,29 @@
|
||||
//! 记忆系统错误类型。
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// 记忆系统错误枚举。
|
||||
#[derive(Debug, Error)]
|
||||
pub enum MemoryError {
|
||||
#[error("Item not found: {0}")]
|
||||
NotFound(String),
|
||||
|
||||
#[error("Storage error: {0}")]
|
||||
Storage(String),
|
||||
|
||||
#[error("Serialization error: {0}")]
|
||||
Serialization(String),
|
||||
|
||||
#[error("Invalid input: {0}")]
|
||||
InvalidInput(String),
|
||||
|
||||
#[error("Retrieval error: {0}")]
|
||||
RetrievalError(String),
|
||||
}
|
||||
|
||||
impl MemoryError {
|
||||
/// 是否为可恢复错误(调用方可重试或调整参数)。
|
||||
pub fn is_recoverable(&self) -> bool {
|
||||
matches!(self, Self::NotFound(_) | Self::RetrievalError(_))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,247 @@
|
||||
//! 知识库 —— KnowledgePage 存储与关键词检索。
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use time::OffsetDateTime;
|
||||
|
||||
use crate::memory::error::MemoryError;
|
||||
use crate::memory::store::MemoryStore;
|
||||
use crate::memory::types::{KnowledgePage, MemoryFilter, MemoryItem};
|
||||
|
||||
pub use crate::memory::types::PageIndexEntry;
|
||||
|
||||
/// Prefix prepended to a page id to form the `MemoryItem.id` of a knowledge page.
|
||||
pub const KNOWLEDGE_PREFIX: &str = "knowledge_";
|
||||
|
||||
/// 知识库 —— KnowledgePage CRUD + 关键词检索 + 内容索引。
|
||||
///
|
||||
/// 内部以 `MemoryStore` 为后端存储 KnowledgePage(序列化为 JSON),
|
||||
/// 同时维护一个 `Vec<PageIndexEntry>` 索引以加速列表遍历。
|
||||
pub struct KnowledgeStore {
|
||||
store: Arc<dyn MemoryStore>,
|
||||
index: std::sync::Mutex<Vec<PageIndexEntry>>,
|
||||
}
|
||||
|
||||
impl KnowledgeStore {
|
||||
/// 创建一个新的 KnowledgeStore。
|
||||
pub fn new(store: Arc<dyn MemoryStore>) -> Self {
|
||||
Self {
|
||||
store,
|
||||
index: std::sync::Mutex::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// 从 MemoryStore 重建索引(修复 index 与 store 的不同步问题)。
|
||||
pub async fn rebuild_index(&self) -> Result<(), MemoryError> {
|
||||
let items = self
|
||||
.store
|
||||
.list(&MemoryFilter {
|
||||
prefix: Some(KNOWLEDGE_PREFIX.to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
let mut index = self.index.lock().unwrap();
|
||||
index.clear();
|
||||
for item in items {
|
||||
let page: KnowledgePage = serde_json::from_str(&item.content)
|
||||
.map_err(|e| MemoryError::Serialization(e.to_string()))?;
|
||||
index.push(PageIndexEntry::from(&page));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 创建一个新的知识页面。
|
||||
pub async fn add_page(&self, page: KnowledgePage) -> Result<(), MemoryError> {
|
||||
if page.id.is_empty() {
|
||||
return Err(MemoryError::InvalidInput("page.id is empty".into()));
|
||||
}
|
||||
let now = OffsetDateTime::now_utc();
|
||||
let id = format!("{KNOWLEDGE_PREFIX}{}", page.id);
|
||||
let content = serde_json::to_string(&page)
|
||||
.map_err(|e| MemoryError::Serialization(e.to_string()))?;
|
||||
let item = MemoryItem {
|
||||
id,
|
||||
content,
|
||||
metadata: serde_json::json!({}),
|
||||
created_at: now,
|
||||
};
|
||||
self.store.save(item).await?;
|
||||
let mut index = self.index.lock().unwrap();
|
||||
// 替换或追加
|
||||
if let Some(existing) = index.iter_mut().find(|e| e.id == page.id) {
|
||||
*existing = PageIndexEntry::from(&page);
|
||||
} else {
|
||||
index.push(PageIndexEntry::from(&page));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 根据 page id 获取一个页面。
|
||||
pub async fn get_page(&self, id: &str) -> Result<Option<KnowledgePage>, MemoryError> {
|
||||
let full_id = format!("{KNOWLEDGE_PREFIX}{id}");
|
||||
let item = self.store.get(&full_id).await?;
|
||||
match item {
|
||||
None => Ok(None),
|
||||
Some(item) => {
|
||||
let page: KnowledgePage = serde_json::from_str(&item.content)
|
||||
.map_err(|e| MemoryError::Serialization(e.to_string()))?;
|
||||
Ok(Some(page))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 更新一个已存在的知识页面。
|
||||
pub async fn update_page(&self, page: KnowledgePage) -> Result<(), MemoryError> {
|
||||
if page.id.is_empty() {
|
||||
return Err(MemoryError::InvalidInput("page.id is empty".into()));
|
||||
}
|
||||
// 通过 get_page 检查存在性
|
||||
if self.get_page(&page.id).await?.is_none() {
|
||||
return Err(MemoryError::NotFound(page.id));
|
||||
}
|
||||
self.add_page(page).await
|
||||
}
|
||||
|
||||
/// 删除一个知识页面。
|
||||
pub async fn delete_page(&self, id: &str) -> Result<(), MemoryError> {
|
||||
let full_id = format!("{KNOWLEDGE_PREFIX}{id}");
|
||||
self.store.delete(&full_id).await?;
|
||||
let mut index = self.index.lock().unwrap();
|
||||
index.retain(|e| e.id != id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 根据关键词搜索知识页面。
|
||||
///
|
||||
/// 匹配规则:在 `title` / `summary` / `tags` 中查找子串(不区分大小写)。
|
||||
/// 全文 `content` 搜索走 `MemoryStore`。
|
||||
pub async fn search(&self, query: &str) -> Result<Vec<KnowledgePage>, MemoryError> {
|
||||
if query.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let needle = query.to_lowercase();
|
||||
let mut results = Vec::new();
|
||||
let index = self.index.lock().unwrap();
|
||||
for entry in index.iter() {
|
||||
if entry.title.to_lowercase().contains(&needle)
|
||||
|| entry.summary.to_lowercase().contains(&needle)
|
||||
|| entry.tags.iter().any(|t| t.to_lowercase().contains(&needle))
|
||||
{
|
||||
if let Some(page) = self.get_page(&entry.id).await? {
|
||||
results.push(page);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// 获取内容目录(所有页面的轻量级索引条目)。
|
||||
pub fn get_index(&self) -> Vec<PageIndexEntry> {
|
||||
self.index.lock().unwrap().clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::InMemoryStore;
    use time::OffsetDateTime;

    /// Builds a `KnowledgeStore` on top of a fresh `InMemoryStore`.
    fn new_knowledge_store() -> KnowledgeStore {
        let backend: Arc<dyn MemoryStore> = Arc::new(InMemoryStore::new());
        KnowledgeStore::new(backend)
    }

    /// Builds a page whose summary and content are derived from `title`.
    fn make_page(id: &str, title: &str, tags: &[&str]) -> KnowledgePage {
        let timestamp = OffsetDateTime::now_utc();
        KnowledgePage {
            id: id.to_string(),
            title: title.to_string(),
            summary: format!("summary of {title}"),
            content: format!("full content of {title}"),
            tags: tags.iter().map(|tag| tag.to_string()).collect(),
            references: Vec::new(),
            created_at: timestamp,
            updated_at: timestamp,
        }
    }

    #[tokio::test]
    async fn add_get_delete_page() {
        let ks = new_knowledge_store();
        ks.add_page(make_page("p1", "LangGraph", &["langgraph", "framework"]))
            .await
            .unwrap();

        let fetched = ks.get_page("p1").await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().title, "LangGraph");

        ks.delete_page("p1").await.unwrap();
        assert!(ks.get_page("p1").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn add_page_rejects_empty_id() {
        let ks = new_knowledge_store();
        let outcome = ks.add_page(make_page("", "NoId", &[])).await;
        assert!(matches!(outcome, Err(MemoryError::InvalidInput(_))));
    }

    #[tokio::test]
    async fn update_page_requires_existing() {
        let ks = new_knowledge_store();
        let outcome = ks.update_page(make_page("nope", "Ghost", &[])).await;
        assert!(matches!(outcome, Err(MemoryError::NotFound(_))));
    }

    #[tokio::test]
    async fn search_finds_by_title_summary_tag() {
        let ks = new_knowledge_store();
        ks.add_page(make_page("p1", "LangGraph StateGraph", &["llm"]))
            .await
            .unwrap();
        ks.add_page(make_page("p2", "Other", &["knowledge-graph"]))
            .await
            .unwrap();
        ks.add_page(make_page("p3", "Third", &["unrelated"]))
            .await
            .unwrap();

        let by_title = ks.search("stategraph").await.unwrap();
        assert_eq!(by_title.len(), 1);
        assert_eq!(by_title[0].id, "p1");

        let by_tag = ks.search("knowledge-graph").await.unwrap();
        assert_eq!(by_tag.len(), 1);
        assert_eq!(by_tag[0].id, "p2");

        let no_hits = ks.search("nonexistent").await.unwrap();
        assert!(no_hits.is_empty());
    }

    #[tokio::test]
    async fn get_index_returns_all_pages() {
        let ks = new_knowledge_store();
        ks.add_page(make_page("p1", "A", &[])).await.unwrap();
        ks.add_page(make_page("p2", "B", &[])).await.unwrap();
        assert_eq!(ks.get_index().len(), 2);
    }

    #[tokio::test]
    async fn rebuild_index_recovers_from_drift() {
        let ks = new_knowledge_store();
        // Add pages.
        ks.add_page(make_page("p1", "A", &[])).await.unwrap();
        ks.add_page(make_page("p2", "B", &[])).await.unwrap();
        assert_eq!(ks.get_index().len(), 2);

        // Simulate index drift: wipe the index, then rebuild it.
        ks.index.lock().unwrap().clear();
        assert_eq!(ks.get_index().len(), 0);

        ks.rebuild_index().await.unwrap();
        assert_eq!(ks.get_index().len(), 2);
    }
}
|
||||
@@ -0,0 +1,297 @@
|
||||
//! 记忆检索器 —— 基于 TextOverlap (Dice 系数) 的单通道关键词检索。
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
use crate::memory::error::MemoryError;
|
||||
use crate::memory::knowledge::KnowledgeStore;
|
||||
use crate::memory::types::KnowledgePage;
|
||||
|
||||
/// Retriever configuration.
#[derive(Debug, Clone)]
pub struct RetrieverConfig {
    /// Maximum number of results to return (default 20).
    pub max_results: usize,
    /// Minimum score threshold in [0.0, 1.0] (default 0.1).
    pub min_score: f32,
}

impl Default for RetrieverConfig {
    /// Returns the defaults: at most 20 results, minimum score 0.1.
    fn default() -> Self {
        RetrieverConfig {
            max_results: 20,
            min_score: 0.1,
        }
    }
}
|
||||
|
||||
/// 单条带评分的检索结果。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ScoredItem {
|
||||
pub page: KnowledgePage,
|
||||
/// TextOverlap 评分 [0.0, 1.0]
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
/// 检索结果。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetrievalResult {
|
||||
pub items: Vec<ScoredItem>,
|
||||
pub query: String,
|
||||
}
|
||||
|
||||
/// 记忆检索器 —— 在 `KnowledgeStore` 中做关键词检索并按 TextOverlap 评分。
|
||||
pub struct MemoryRetriever {
|
||||
knowledge_store: KnowledgeStore,
|
||||
config: RetrieverConfig,
|
||||
/// 停用词表(用于关键词提取)。
|
||||
stop_words: HashSet<String>,
|
||||
}
|
||||
|
||||
impl MemoryRetriever {
|
||||
/// 创建一个新的 MemoryRetriever。
|
||||
pub fn new(knowledge_store: KnowledgeStore, config: RetrieverConfig) -> Self {
|
||||
Self {
|
||||
knowledge_store,
|
||||
config,
|
||||
stop_words: default_stop_words(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 替换停用词表。
|
||||
pub fn with_stop_words(mut self, stop_words: HashSet<String>) -> Self {
|
||||
self.stop_words = stop_words;
|
||||
self
|
||||
}
|
||||
|
||||
/// 检索相关知识页面。
|
||||
pub async fn retrieve(&self, query: &str) -> Result<RetrievalResult, MemoryError> {
|
||||
if query.is_empty() {
|
||||
return Ok(RetrievalResult {
|
||||
items: Vec::new(),
|
||||
query: query.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// 1. 关键词提取
|
||||
let keywords = extract_keywords(query, &self.stop_words);
|
||||
|
||||
// 2. 用关键词在 KnowledgeStore 中搜索
|
||||
let mut pages = Vec::new();
|
||||
for keyword in &keywords {
|
||||
let found = self.knowledge_store.search(keyword).await?;
|
||||
for page in found {
|
||||
if !pages.iter().any(|p: &KnowledgePage| p.id == page.id) {
|
||||
pages.push(page);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. TextOverlap 评分
|
||||
let mut items: Vec<ScoredItem> = pages
|
||||
.into_iter()
|
||||
.map(|page| {
|
||||
let score = text_overlap_score(query, &page);
|
||||
ScoredItem { page, score }
|
||||
})
|
||||
.collect();
|
||||
|
||||
// 4. 过滤 → 排序 → 截取
|
||||
items.retain(|i| i.score >= self.config.min_score);
|
||||
items.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
|
||||
items.truncate(self.config.max_results);
|
||||
|
||||
Ok(RetrievalResult {
|
||||
items,
|
||||
query: query.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Extracts keywords from `query`: split on non-alphanumeric characters,
/// lowercase each token, then drop tokens shorter than two characters and
/// tokens present in `stop_words`.
///
/// Order is preserved and duplicates are kept.
fn extract_keywords(query: &str, stop_words: &HashSet<String>) -> Vec<String> {
    query
        .split(|c: char| !c.is_alphanumeric())
        .map(str::to_lowercase)
        // Length is measured in characters, so empty tokens are dropped too.
        .filter(|token| token.chars().count() >= 2 && !stop_words.contains(token))
        .collect()
}
|
||||
|
||||
/// TextOverlap 评分(基于字符 bigram 的 Dice 系数 + 多字段加权)。
|
||||
///
|
||||
/// 字段权重:title 0.5 + summary 0.3 + content 0.2
|
||||
/// 中文场景按字符级 bigram 处理,不依赖分词器。
|
||||
pub fn text_overlap_score(query: &str, page: &KnowledgePage) -> f32 {
|
||||
let title = text_overlap_dice(query, &page.title);
|
||||
let summary = text_overlap_dice(query, &page.summary);
|
||||
let content = text_overlap_dice(query, &page.content);
|
||||
title * 0.5 + summary * 0.3 + content * 0.2
|
||||
}
|
||||
|
||||
/// Dice coefficient over the sets of distinct character bigrams of `query`
/// and `text`: `2 * |Q ∩ T| / (|Q| + |T|)`.
///
/// Both inputs are lowercased first so scoring agrees with the
/// case-insensitive keyword search that selects the candidates. Returns 0.0
/// when either input has fewer than two characters (no bigrams).
fn text_overlap_dice(query: &str, text: &str) -> f32 {
    let q_set: HashSet<String> = char_bigrams(&query.to_lowercase()).into_iter().collect();
    let t_set: HashSet<String> = char_bigrams(&text.to_lowercase()).into_iter().collect();
    if q_set.is_empty() || t_set.is_empty() {
        return 0.0;
    }
    let shared = q_set.intersection(&t_set).count();
    // Both sets are non-empty here, so the denominator is never zero.
    (2.0 * shared as f32) / (q_set.len() + t_set.len()) as f32
}

/// Extracts the character bigrams of `s` in order of appearance (used for the
/// character-level Dice coefficient). Yields nothing for fewer than two chars.
fn char_bigrams(s: &str) -> Vec<String> {
    let chars: Vec<char> = s.chars().collect();
    chars.windows(2).map(|pair| pair.iter().collect()).collect()
}
|
||||
|
||||
/// Builds the default English stop-word set used for keyword extraction.
fn default_stop_words() -> HashSet<String> {
    const WORDS: &[&str] = &[
        "the", "a", "an", "is", "are", "was", "were", "be", "been", "being", "have", "has",
        "had", "do", "does", "did", "will", "would", "should", "could", "may", "might", "shall",
        "can", "this", "that", "these", "those", "it", "its", "they", "them", "their", "what",
        "which", "who", "whom", "how", "when", "where", "and", "or", "but", "not", "no", "nor",
        "so", "if", "then", "else", "with", "without", "for", "to", "from", "in", "on", "at",
        "by", "of", "as", "into", "through", "during", "before", "after", "above", "below",
    ];
    WORDS.iter().copied().map(String::from).collect()
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::knowledge::KnowledgeStore;
    use crate::memory::{InMemoryStore, MemoryStore};
    use std::sync::Arc;
    use time::OffsetDateTime;

    /// Builds a `KnowledgeStore` on top of a fresh `InMemoryStore`.
    fn new_knowledge_store() -> KnowledgeStore {
        let backend: Arc<dyn MemoryStore> = Arc::new(InMemoryStore::new());
        KnowledgeStore::new(backend)
    }

    /// Builds an untagged page with the given text fields.
    fn make_page(id: &str, title: &str, summary: &str, content: &str) -> KnowledgePage {
        let timestamp = OffsetDateTime::now_utc();
        KnowledgePage {
            id: id.to_string(),
            title: title.to_string(),
            summary: summary.to_string(),
            content: content.to_string(),
            tags: Vec::new(),
            references: Vec::new(),
            created_at: timestamp,
            updated_at: timestamp,
        }
    }

    #[tokio::test]
    async fn retrieve_empty_query() {
        let retriever = MemoryRetriever::new(new_knowledge_store(), RetrieverConfig::default());
        let outcome = retriever.retrieve("").await.unwrap();
        assert!(outcome.items.is_empty());
    }

    #[tokio::test]
    async fn retrieve_finds_relevant_page() {
        let ks = new_knowledge_store();
        ks.add_page(make_page(
            "p1",
            "LangGraph StateGraph",
            "state management",
            "LangGraph uses StateGraph for state machines",
        ))
        .await
        .unwrap();
        ks.add_page(make_page("p2", "Other", "unrelated", "nothing matching"))
            .await
            .unwrap();

        let retriever = MemoryRetriever::new(ks, RetrieverConfig::default());
        let outcome = retriever.retrieve("LangGraph state").await.unwrap();
        assert!(!outcome.items.is_empty());
        assert_eq!(outcome.items[0].page.id, "p1");
        assert!(outcome.items[0].score > 0.0);
    }

    #[tokio::test]
    async fn retrieve_respects_min_score() {
        let ks = new_knowledge_store();
        ks.add_page(make_page("p1", "X", "Y", "Z")).await.unwrap();

        let strict = RetrieverConfig {
            max_results: 10,
            min_score: 0.99,
        };
        let retriever = MemoryRetriever::new(ks, strict);
        let outcome = retriever.retrieve("totally unrelated content").await.unwrap();
        assert!(outcome.items.is_empty());
    }

    #[tokio::test]
    async fn retrieve_respects_max_results() {
        let ks = new_knowledge_store();
        for i in 0..5 {
            ks.add_page(make_page(
                &format!("p{i}"),
                "LangGraph",
                "framework",
                "agent runtime",
            ))
            .await
            .unwrap();
        }
        let capped = RetrieverConfig {
            max_results: 2,
            min_score: 0.0,
        };
        let retriever = MemoryRetriever::new(ks, capped);
        let outcome = retriever.retrieve("LangGraph").await.unwrap();
        assert_eq!(outcome.items.len(), 2);
    }

    #[test]
    fn text_overlap_dice_zero_on_empty() {
        assert_eq!(text_overlap_dice("hello", ""), 0.0);
        assert_eq!(text_overlap_dice("", "hello"), 0.0);
    }

    #[test]
    fn text_overlap_dice_identical() {
        let sample = "hello world";
        assert!((text_overlap_dice(sample, sample) - 1.0).abs() < 0.001);
    }

    #[test]
    fn extract_keywords_filters_stop_words() {
        let stop = default_stop_words();
        let keywords = extract_keywords("the quick brown fox is fast", &stop);
        assert!(!keywords.contains(&"the".to_string()));
        assert!(!keywords.contains(&"is".to_string()));
        assert!(keywords.contains(&"quick".to_string()));
        assert!(keywords.contains(&"brown".to_string()));
    }

    #[test]
    fn extract_keywords_filters_single_chars() {
        let stop = default_stop_words();
        let keywords = extract_keywords("a b c dog", &stop);
        assert!(!keywords.contains(&"a".to_string()));
        assert!(!keywords.contains(&"b".to_string()));
    }
}
|
||||
@@ -0,0 +1,314 @@
|
||||
//! MemoryStore 抽象接口与默认实现。
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use time::OffsetDateTime;
|
||||
|
||||
use crate::memory::error::MemoryError;
|
||||
use crate::memory::types::{MemoryFilter, MemoryItem};
|
||||
|
||||
/// 底层记忆存储抽象接口。
|
||||
///
|
||||
/// 下游可实现此 trait 以对接持久化后端(JSON 文件、SQLite、Redis 等)。
|
||||
/// 默认实现 [`InMemoryStore`] 基于进程内 HashMap。
|
||||
#[async_trait]
|
||||
pub trait MemoryStore: Send + Sync {
|
||||
/// 保存/覆盖一个 MemoryItem(upsert 语义)。
|
||||
/// - 如果 id 不存在,则插入新条目
|
||||
/// - 如果 id 已存在,则覆盖旧条目
|
||||
async fn save(&self, item: MemoryItem) -> Result<(), MemoryError>;
|
||||
|
||||
/// 根据 id 获取一个 MemoryItem。
|
||||
async fn get(&self, id: &str) -> Result<Option<MemoryItem>, MemoryError>;
|
||||
|
||||
/// 根据 id 删除一个 MemoryItem。
|
||||
async fn delete(&self, id: &str) -> Result<(), MemoryError>;
|
||||
|
||||
/// 根据 filter 列出 MemoryItem。
|
||||
async fn list(&self, filter: &MemoryFilter) -> Result<Vec<MemoryItem>, MemoryError>;
|
||||
}
|
||||
|
||||
/// Eviction policy.
#[derive(Debug, Clone)]
pub enum EvictionPolicy {
    /// Never evict (the default).
    None,
    /// Evict items older than the given time-to-live, in seconds.
    Ttl { ttl_secs: u64 },
    /// Evict the oldest items (by `created_at`) once the capacity limit is exceeded.
    Capacity { max_items: usize },
}
|
||||
|
||||
/// 淘汰配置。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EvictionConfig {
|
||||
pub policy: EvictionPolicy,
|
||||
/// 每写入 N 条后检查一次淘汰条件。
|
||||
pub check_interval: usize,
|
||||
}
|
||||
|
||||
impl Default for EvictionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
policy: EvictionPolicy::None,
|
||||
check_interval: 64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 进程内默认实现 —— 基于 HashMap + Mutex,纯内存。
|
||||
pub struct InMemoryStore {
|
||||
items: Mutex<HashMap<String, MemoryItem>>,
|
||||
eviction: EvictionConfig,
|
||||
/// 自上次淘汰检查以来的写入次数。
|
||||
writes_since_check: Mutex<usize>,
|
||||
}
|
||||
|
||||
impl InMemoryStore {
|
||||
/// 创建一个无淘汰策略的 InMemoryStore。
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
items: Mutex::new(HashMap::new()),
|
||||
eviction: EvictionConfig::default(),
|
||||
writes_since_check: Mutex::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// 创建一个带淘汰配置的 InMemoryStore。
|
||||
pub fn with_eviction(eviction: EvictionConfig) -> Self {
|
||||
Self {
|
||||
items: Mutex::new(HashMap::new()),
|
||||
eviction,
|
||||
writes_since_check: Mutex::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
fn maybe_evict(&self) {
|
||||
// 不使用 .lock().await 跨点,先取计数判断是否需要淘汰
|
||||
let should_check = {
|
||||
let mut counter = self.writes_since_check.lock().unwrap();
|
||||
*counter += 1;
|
||||
if *counter >= self.eviction.check_interval {
|
||||
*counter = 0;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
};
|
||||
if !should_check {
|
||||
return;
|
||||
}
|
||||
|
||||
let policy = self.eviction.policy.clone();
|
||||
match policy {
|
||||
EvictionPolicy::None => {}
|
||||
EvictionPolicy::Ttl { ttl_secs } => {
|
||||
let cutoff = OffsetDateTime::now_utc() - time::Duration::seconds(ttl_secs as i64);
|
||||
let mut items = self.items.lock().unwrap();
|
||||
items.retain(|_, v| v.created_at > cutoff);
|
||||
}
|
||||
EvictionPolicy::Capacity { max_items } => {
|
||||
let mut items = self.items.lock().unwrap();
|
||||
if items.len() > max_items {
|
||||
let mut vec: Vec<_> = items.drain().collect();
|
||||
// O(n) 部分排序:保留 created_at 最大的 max_items 个
|
||||
vec.select_nth_unstable_by(max_items, |a, b| {
|
||||
b.1.created_at.cmp(&a.1.created_at)
|
||||
});
|
||||
vec.truncate(max_items);
|
||||
*items = vec.into_iter().collect();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for InMemoryStore {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MemoryStore for InMemoryStore {
|
||||
async fn save(&self, item: MemoryItem) -> Result<(), MemoryError> {
|
||||
{
|
||||
let mut items = self.items.lock().unwrap();
|
||||
items.insert(item.id.clone(), item);
|
||||
}
|
||||
self.maybe_evict();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get(&self, id: &str) -> Result<Option<MemoryItem>, MemoryError> {
|
||||
let items = self.items.lock().unwrap();
|
||||
Ok(items.get(id).cloned())
|
||||
}
|
||||
|
||||
async fn delete(&self, id: &str) -> Result<(), MemoryError> {
|
||||
let mut items = self.items.lock().unwrap();
|
||||
items.remove(id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn list(&self, filter: &MemoryFilter) -> Result<Vec<MemoryItem>, MemoryError> {
|
||||
let items = self.items.lock().unwrap();
|
||||
let mut result: Vec<MemoryItem> = items
|
||||
.values()
|
||||
.filter(|v| match &filter.prefix {
|
||||
Some(p) => v.id.starts_with(p),
|
||||
None => true,
|
||||
})
|
||||
.filter(|v| match filter.since {
|
||||
Some(t) => v.created_at > t,
|
||||
None => true,
|
||||
})
|
||||
.cloned()
|
||||
.collect();
|
||||
// 按 created_at 升序排列(最旧在前)
|
||||
result.sort_by_key(|v| v.created_at);
|
||||
// 应用 offset
|
||||
if let Some(offset) = filter.offset {
|
||||
if offset < result.len() {
|
||||
result.drain(..offset);
|
||||
} else {
|
||||
result.clear();
|
||||
}
|
||||
}
|
||||
// 应用 limit
|
||||
if let Some(limit) = filter.limit {
|
||||
result.truncate(limit);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use time::OffsetDateTime;

    /// Builds an item created "now" whose content is derived from `id`.
    fn make_item(id: &str) -> MemoryItem {
        MemoryItem {
            id: id.to_string(),
            content: format!("content-{id}"),
            metadata: serde_json::json!({}),
            created_at: OffsetDateTime::now_utc(),
        }
    }

    /// Builds a store that runs the eviction check on every write.
    fn store_checking_every_write(policy: EvictionPolicy) -> InMemoryStore {
        InMemoryStore::with_eviction(EvictionConfig {
            policy,
            check_interval: 1,
        })
    }

    /// Lists every item in `store` with the default filter.
    async fn list_all(store: &InMemoryStore) -> Vec<MemoryItem> {
        store.list(&MemoryFilter::default()).await.unwrap()
    }

    #[tokio::test]
    async fn save_get_delete_list() {
        let store = InMemoryStore::new();
        store.save(make_item("a")).await.unwrap();
        store.save(make_item("b")).await.unwrap();

        let fetched = store.get("a").await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().id, "a");

        assert_eq!(list_all(&store).await.len(), 2);

        store.delete("a").await.unwrap();
        assert!(store.get("a").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn save_is_upsert() {
        let store = InMemoryStore::new();
        store.save(make_item("a")).await.unwrap();

        let mut replacement = make_item("a");
        replacement.content = "updated".to_string();
        store.save(replacement).await.unwrap();

        let fetched = store.get("a").await.unwrap().unwrap();
        assert_eq!(fetched.content, "updated");
        assert_eq!(list_all(&store).await.len(), 1);
    }

    #[tokio::test]
    async fn list_with_prefix_and_limit() {
        let store = InMemoryStore::new();
        for id in ["foo_a", "foo_b", "bar_a"] {
            store.save(make_item(id)).await.unwrap();
        }

        let by_prefix = MemoryFilter {
            prefix: Some("foo_".to_string()),
            ..Default::default()
        };
        assert_eq!(store.list(&by_prefix).await.unwrap().len(), 2);

        let by_prefix_limited = MemoryFilter {
            prefix: Some("foo_".to_string()),
            limit: Some(1),
            ..Default::default()
        };
        assert_eq!(store.list(&by_prefix_limited).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn capacity_eviction() {
        // Force an eviction check on every write.
        let store = store_checking_every_write(EvictionPolicy::Capacity { max_items: 2 });
        // The first two items coexist.
        store.save(make_item("a")).await.unwrap();
        store.save(make_item("b")).await.unwrap();
        // The third write triggers eviction: one of the earlier items is dropped.
        store.save(make_item("c")).await.unwrap();

        let remaining = list_all(&store).await;
        assert_eq!(remaining.len(), 2);
        // The survivors should be b and c (the two newest).
        let ids: Vec<&str> = remaining.iter().map(|item| item.id.as_str()).collect();
        assert!(ids.contains(&"b"));
        assert!(ids.contains(&"c"));
    }

    #[tokio::test]
    async fn ttl_eviction() {
        // A TTL of 0 would expire items immediately, but "a" has to survive
        // until "b" is written. So use a small TTL plus a sleep: save a, sleep,
        // and by the time b is saved a has expired and gets evicted.
        let store = store_checking_every_write(EvictionPolicy::Ttl { ttl_secs: 1 });
        store.save(make_item("a")).await.unwrap();
        // Wait for more than one second.
        std::thread::sleep(std::time::Duration::from_millis(1100));
        // Trigger eviction: a is now older than ttl_secs=1 and should be dropped.
        store.save(make_item("b")).await.unwrap();

        let remaining = list_all(&store).await;
        // With ttl_secs=1 and b freshly written, b may sit right at the boundary,
        // so only assert that the list does not contain "a".
        let ids: Vec<&str> = remaining.iter().map(|item| item.id.as_str()).collect();
        assert!(
            !ids.contains(&"a"),
            "expected 'a' to be evicted, but found in {ids:?}"
        );
    }

    #[tokio::test]
    async fn none_policy_no_eviction() {
        let store = store_checking_every_write(EvictionPolicy::None);
        for i in 0..100 {
            store.save(make_item(&format!("item_{i}"))).await.unwrap();
        }
        assert_eq!(list_all(&store).await.len(), 100);
    }
}
|
||||
@@ -0,0 +1,73 @@
|
||||
//! 记忆系统核心数据类型。
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use time::OffsetDateTime;
|
||||
|
||||
/// 记忆条目 —— MemoryStore 存储的基本单元。
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MemoryItem {
|
||||
/// 唯一标识。
|
||||
pub id: String,
|
||||
/// 内容(通常为 JSON 序列化的具体记忆数据)。
|
||||
pub content: String,
|
||||
/// 任意附加元数据。
|
||||
pub metadata: serde_json::Value,
|
||||
/// 创建时间。
|
||||
pub created_at: OffsetDateTime,
|
||||
}
|
||||
|
||||
/// MemoryStore 列表查询条件。
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct MemoryFilter {
|
||||
/// 按 id 前缀过滤。
|
||||
pub prefix: Option<String>,
|
||||
/// 仅返回该时间之后创建的条目。
|
||||
pub since: Option<OffsetDateTime>,
|
||||
/// 跳过前 N 条。
|
||||
pub offset: Option<usize>,
|
||||
/// 最多返回 N 条。
|
||||
pub limit: Option<usize>,
|
||||
}
|
||||
|
||||
/// 知识页面 —— 描述一段结构化知识。
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct KnowledgePage {
|
||||
/// 唯一标识。
|
||||
pub id: String,
|
||||
/// 标题。
|
||||
pub title: String,
|
||||
/// 一句话摘要。
|
||||
pub summary: String,
|
||||
/// 完整内容。
|
||||
pub content: String,
|
||||
/// 检索标签。
|
||||
pub tags: Vec<String>,
|
||||
/// 交叉引用的其他页面 ID。
|
||||
pub references: Vec<String>,
|
||||
/// 创建时间。
|
||||
pub created_at: OffsetDateTime,
|
||||
/// 最后更新时间。
|
||||
pub updated_at: OffsetDateTime,
|
||||
}
|
||||
|
||||
/// 知识页面索引条目 —— 用于轻量遍历和内容目录。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PageIndexEntry {
|
||||
pub id: String,
|
||||
pub title: String,
|
||||
pub summary: String,
|
||||
pub tags: Vec<String>,
|
||||
pub updated_at: OffsetDateTime,
|
||||
}
|
||||
|
||||
impl From<&KnowledgePage> for PageIndexEntry {
|
||||
fn from(p: &KnowledgePage) -> Self {
|
||||
Self {
|
||||
id: p.id.clone(),
|
||||
title: p.title.clone(),
|
||||
summary: p.summary.clone(),
|
||||
tags: p.tags.clone(),
|
||||
updated_at: p.updated_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
pub mod error;
|
||||
pub mod template;
|
||||
pub mod composer;
|
||||
|
||||
pub use error::PromptError;
|
||||
pub use template::{PromptTemplate, PromptTemplateRegistry, TemplateContext, TemplateValue};
|
||||
pub use composer::PromptComposer;
|
||||
@@ -0,0 +1,406 @@
|
||||
use crate::llm::types::message::{ContentField, OpenaiChatMessage, OpenaiContentPart};
|
||||
use crate::llm::types::request::OpenaiChatRequest;
|
||||
use crate::prompt::error::PromptError;
|
||||
use crate::prompt::template::{PromptTemplate, TemplateContext};
|
||||
|
||||
/// 提示词组合器——构建多角色消息序列。
|
||||
#[derive(Default)]
|
||||
pub struct PromptComposer {
|
||||
messages: Vec<OpenaiChatMessage>,
|
||||
}
|
||||
|
||||
impl PromptComposer {
|
||||
/// 创建一个空的组合器。
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// 从已有的消息列表初始化。
|
||||
pub fn from_messages(messages: Vec<OpenaiChatMessage>) -> Self {
|
||||
Self { messages }
|
||||
}
|
||||
|
||||
// ===== 纯文本消息 =====
|
||||
|
||||
/// 添加一条纯文本 system 消息。
|
||||
pub fn system(mut self, text: impl Into<String>) -> Self {
|
||||
self.push_message(OpenaiChatMessage::system_text(text.into()));
|
||||
self
|
||||
}
|
||||
|
||||
/// 添加一条纯文本 user 消息。
|
||||
pub fn user(mut self, text: impl Into<String>) -> Self {
|
||||
self.push_message(OpenaiChatMessage::user_text(text.into()));
|
||||
self
|
||||
}
|
||||
|
||||
/// 添加一条纯文本 assistant 消息。
|
||||
pub fn assistant(mut self, text: impl Into<String>) -> Self {
|
||||
self.push_message(OpenaiChatMessage::assistant_text(text.into()));
|
||||
self
|
||||
}
|
||||
|
||||
/// 添加一条纯文本 developer 消息(o1 系列模型使用)。
|
||||
pub fn developer(mut self, text: impl Into<String>) -> Self {
|
||||
self.push_message(OpenaiChatMessage::developer_text(text.into()));
|
||||
self
|
||||
}
|
||||
|
||||
/// 添加一条 Tool 消息(工具执行结果回传)。
|
||||
pub fn tool(mut self, tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
self.push_message(OpenaiChatMessage::tool_result(
|
||||
tool_call_id.into(),
|
||||
content.into(),
|
||||
));
|
||||
self
|
||||
}
|
||||
|
||||
// ===== 模板消息 =====
|
||||
|
||||
/// 使用模板和上下文渲染后添加为 user 消息。
|
||||
pub fn user_template(
|
||||
mut self,
|
||||
template: &PromptTemplate,
|
||||
ctx: &TemplateContext,
|
||||
) -> Result<Self, PromptError> {
|
||||
let text = template.render(ctx)?;
|
||||
self.push_message(OpenaiChatMessage::user_text(text));
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// 使用模板和上下文渲染后添加为 system 消息。
|
||||
pub fn system_template(
|
||||
mut self,
|
||||
template: &PromptTemplate,
|
||||
ctx: &TemplateContext,
|
||||
) -> Result<Self, PromptError> {
|
||||
let text = template.render(ctx)?;
|
||||
self.push_message(OpenaiChatMessage::system_text(text));
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// 使用模板和上下文渲染后添加为 assistant 消息。
|
||||
pub fn assistant_template(
|
||||
mut self,
|
||||
template: &PromptTemplate,
|
||||
ctx: &TemplateContext,
|
||||
) -> Result<Self, PromptError> {
|
||||
let text = template.render(ctx)?;
|
||||
self.push_message(OpenaiChatMessage::assistant_text(text));
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// 使用模板和上下文渲染后添加为 developer 消息。
|
||||
pub fn developer_template(
|
||||
mut self,
|
||||
template: &PromptTemplate,
|
||||
ctx: &TemplateContext,
|
||||
) -> Result<Self, PromptError> {
|
||||
let text = template.render(ctx)?;
|
||||
self.push_message(OpenaiChatMessage::developer_text(text));
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
// ===== 多模态 ContentPart =====
|
||||
|
||||
/// 添加一条含指定 ContentPart 的 system 消息。
|
||||
pub fn system_content(mut self, part: OpenaiContentPart) -> Self {
|
||||
self.push_message(OpenaiChatMessage::System {
|
||||
content: ContentField::Array(vec![part]),
|
||||
name: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// 添加一条含指定 ContentPart 的 user 消息。
|
||||
pub fn user_content(mut self, part: OpenaiContentPart) -> Self {
|
||||
self.push_message(OpenaiChatMessage::User {
|
||||
content: ContentField::Array(vec![part]),
|
||||
name: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// 添加一条含指定 ContentPart 的 assistant 消息。
|
||||
pub fn assistant_content(mut self, part: OpenaiContentPart) -> Self {
|
||||
self.push_message(OpenaiChatMessage::Assistant {
|
||||
content: ContentField::Array(vec![part]),
|
||||
refusal: None,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// 添加一条含指定 ContentPart 的 developer 消息。
|
||||
pub fn developer_content(mut self, part: OpenaiContentPart) -> Self {
|
||||
self.push_message(OpenaiChatMessage::Developer {
|
||||
content: ContentField::Array(vec![part]),
|
||||
name: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// 添加一条含指定 ContentPart 的 Tool 消息。
|
||||
pub fn tool_content(
|
||||
mut self,
|
||||
tool_call_id: impl Into<String>,
|
||||
part: OpenaiContentPart,
|
||||
) -> Self {
|
||||
self.push_message(OpenaiChatMessage::Tool {
|
||||
content: ContentField::Array(vec![part]),
|
||||
tool_call_id: tool_call_id.into(),
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// 批量添加 ContentPart 作为 user 消息。
|
||||
pub fn user_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
|
||||
self.push_message(OpenaiChatMessage::User {
|
||||
content: ContentField::Array(parts),
|
||||
name: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// 批量添加 ContentPart 作为 system 消息。
|
||||
pub fn system_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
|
||||
self.push_message(OpenaiChatMessage::System {
|
||||
content: ContentField::Array(parts),
|
||||
name: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// 批量添加 ContentPart 作为 assistant 消息。
|
||||
pub fn assistant_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
|
||||
self.push_message(OpenaiChatMessage::Assistant {
|
||||
content: ContentField::Array(parts),
|
||||
refusal: None,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// 批量添加 ContentPart 作为 developer 消息。
|
||||
pub fn developer_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
|
||||
self.push_message(OpenaiChatMessage::Developer {
|
||||
content: ContentField::Array(parts),
|
||||
name: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
// ===== 角色标识 =====
|
||||
|
||||
/// 为上一条添加的消息设置 `name` 字段。
|
||||
pub fn with_name(mut self, name: impl Into<String>) -> Self {
|
||||
let name = name.into();
|
||||
if let Some(msg) = self.messages.last_mut() {
|
||||
set_message_name(msg, name);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
// ===== 构建 =====
|
||||
|
||||
/// 构建最终的消息列表。
|
||||
pub fn build(self) -> Vec<OpenaiChatMessage> {
|
||||
self.messages
|
||||
}
|
||||
|
||||
/// 构建并直接创建 ChatRequest(需搭配 model 参数)。
|
||||
/// 返回的 `OpenaiChatRequest` 中 `tools`、`temperature`、`max_tokens` 等字段均为 `None`,
|
||||
/// 可通过结构体更新语法补全:`OpenaiChatRequest { tools: Some(...), ..req }`。
|
||||
pub fn build_request(self, model: impl Into<String>) -> OpenaiChatRequest {
|
||||
OpenaiChatRequest {
|
||||
model: model.into(),
|
||||
messages: self.messages,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
// ===== 内部方法 =====
|
||||
|
||||
// Single append point used by every role helper.
fn push_message(&mut self, msg: OpenaiChatMessage) {
    let messages = &mut self.messages;
    messages.push(msg);
}
|
||||
}
|
||||
|
||||
fn set_message_name(msg: &mut OpenaiChatMessage, name: String) {
|
||||
match msg {
|
||||
OpenaiChatMessage::Developer { name: n, .. } => *n = Some(name),
|
||||
OpenaiChatMessage::System { name: n, .. } => *n = Some(name),
|
||||
OpenaiChatMessage::User { name: n, .. } => *n = Some(name),
|
||||
OpenaiChatMessage::Assistant { name: n, .. } => *n = Some(name),
|
||||
OpenaiChatMessage::Tool { .. } => {}
|
||||
OpenaiChatMessage::Function { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
/// 验证消息序列是否符合 OpenAI API 要求。
|
||||
pub fn validate_messages(messages: &[OpenaiChatMessage]) -> Result<(), PromptError> {
|
||||
if messages.is_empty() {
|
||||
return Err(PromptError::InvalidSequence(
|
||||
"消息列表不能为空".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut last_tool_call_ids: Vec<String> = Vec::new();
|
||||
|
||||
for (i, msg) in messages.iter().enumerate() {
|
||||
match msg {
|
||||
OpenaiChatMessage::Tool {
|
||||
tool_call_id,
|
||||
..
|
||||
} => {
|
||||
if last_tool_call_ids.is_empty() {
|
||||
return Err(PromptError::InvalidSequence(format!(
|
||||
"消息[{i}] Tool 消息前必须有 Assistant 消息且含 tool_calls"
|
||||
)));
|
||||
}
|
||||
if !last_tool_call_ids.iter().any(|id| id == tool_call_id) {
|
||||
return Err(PromptError::InvalidSequence(format!(
|
||||
"消息[{i}] Tool 消息的 tool_call_id '{}' 未匹配任何 assistant tool_calls",
|
||||
tool_call_id
|
||||
)));
|
||||
}
|
||||
}
|
||||
OpenaiChatMessage::Assistant {
|
||||
tool_calls: Some(calls),
|
||||
..
|
||||
} => {
|
||||
last_tool_call_ids.clear();
|
||||
for call in calls {
|
||||
let crate::llm::types::OpenaiToolCall::Function { id, .. } = call;
|
||||
last_tool_call_ids.push(id.clone());
|
||||
}
|
||||
}
|
||||
OpenaiChatMessage::Assistant {
|
||||
tool_calls: None, ..
|
||||
} => {
|
||||
last_tool_call_ids.clear();
|
||||
}
|
||||
_ => {
|
||||
last_tool_call_ids.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prompt::TemplateValue;

    /// Three chained role helpers produce three messages.
    #[test]
    fn test_composer_basic() {
        let composer = PromptComposer::new()
            .system("You are helpful")
            .user("Hello")
            .assistant("Hi there!");
        let msgs = composer.build();

        assert_eq!(msgs.len(), 3);
    }

    /// A tool reply is stored last and keeps its call id.
    #[test]
    fn test_composer_tool() {
        let msgs = PromptComposer::new()
            .system("You are helpful")
            .user("What's the weather?")
            .assistant("Let me check")
            .tool("call_123", "Sunny, 25°C")
            .build();

        assert_eq!(msgs.len(), 4);
        match &msgs[3] {
            OpenaiChatMessage::Tool {
                tool_call_id,
                content,
                ..
            } => {
                assert_eq!(tool_call_id, "call_123");
                // NOTE(review): non-string content passes silently here; confirm that is intended.
                if let ContentField::String(s) = content {
                    assert_eq!(s, "Sunny, 25°C");
                }
            }
            _ => panic!("Expected Tool message"),
        }
    }

    /// A plain system + user exchange validates.
    #[test]
    fn test_validate_messages_ok() {
        let msgs = PromptComposer::new()
            .system("You are helpful")
            .user("Hello")
            .build();

        let outcome = validate_messages(&msgs);
        assert!(outcome.is_ok());
    }

    /// An empty message list is rejected.
    #[test]
    fn test_validate_messages_empty() {
        let msgs: Vec<OpenaiChatMessage> = Vec::new();
        let outcome = validate_messages(&msgs);
        assert!(outcome.is_err());
    }

    /// Variables are substituted in place.
    #[test]
    fn test_template_render() {
        let tpl = PromptTemplate::compile("Hello {{name}}, you have {{count}} messages").unwrap();
        let mut ctx = TemplateContext::new();
        ctx.insert("name", "Alice");
        ctx.insert("count", "5");

        let result = tpl.render(&ctx).unwrap();
        assert_eq!(result, "Hello Alice, you have 5 messages");
    }

    /// `#if` picks the body when the variable is truthy, else the `else` branch.
    #[test]
    fn test_template_if() {
        let tpl = PromptTemplate::compile("Hello {{#if name}}{{name}}{{else}}Guest{{/if}}").unwrap();

        let mut named = TemplateContext::new();
        named.insert("name", "Bob");
        let with_name = tpl.render(&named).unwrap();
        assert_eq!(with_name, "Hello Bob");

        let empty = TemplateContext::new();
        let without_name = tpl.render(&empty).unwrap();
        assert_eq!(without_name, "Hello Guest");
    }

    /// `#each` renders the body once per element, bound to `item`.
    #[test]
    fn test_template_each() {
        let tpl = PromptTemplate::compile("Items: {{#each items}}{{item}}, {{/each}}").unwrap();
        let items: Vec<TemplateValue> = ["a", "b", "c"]
            .iter()
            .map(|s| TemplateValue::String(s.to_string()))
            .collect();
        let mut ctx = TemplateContext::new();
        ctx.insert("items", TemplateValue::Array(items));

        let result = tpl.render(&ctx).unwrap();
        assert_eq!(result, "Items: a, b, c, ");
    }

    /// `Display` echoes the original template source.
    #[test]
    fn test_template_display() {
        let tpl = PromptTemplate::compile("Hello {{name}}").unwrap();
        assert_eq!(tpl.to_string(), "Hello {{name}}");
    }

    /// Top-level JSON object keys become context variables.
    #[test]
    fn test_context_from_json() {
        let json: serde_json::Value = serde_json::json!({
            "name": "Alice",
            "active": true
        });
        let ctx = TemplateContext::from_json(&json).unwrap();
        for key in ["name", "active"].iter() {
            assert!(ctx.get(key).is_some());
        }
    }
}
|
||||
@@ -0,0 +1,28 @@
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum PromptError {
|
||||
#[error("模板解析错误: {0}")]
|
||||
Parse(String),
|
||||
|
||||
#[error("渲染错误: 变量 '{0}' 未找到")]
|
||||
VariableNotFound(String),
|
||||
|
||||
#[error("渲染错误: 引用的子模板 '{0}' 未注册")]
|
||||
PartialNotFound(String),
|
||||
|
||||
#[error("渲染错误: '{0}' 不是数组,无法遍历")]
|
||||
NotAnArray(String),
|
||||
|
||||
#[error("渲染递归超过最大深度限制 ({0})")]
|
||||
MaxDepthReached(u8),
|
||||
|
||||
#[error("渲染错误: {0}")]
|
||||
Render(String),
|
||||
|
||||
#[error("消息序列校验失败: {0}")]
|
||||
InvalidSequence(String),
|
||||
|
||||
#[error("文件读取错误: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
}
|
||||
@@ -0,0 +1,543 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::prompt::error::PromptError;
|
||||
|
||||
/// Largest `depth` value `render_fragments` accepts; a deeper call fails with
/// `PromptError::MaxDepthReached`.
const MAX_RENDER_DEPTH: u8 = 16;
|
||||
|
||||
// ===== TemplateValue =====
|
||||
|
||||
/// A dynamically typed value that can be bound to a template variable.
#[derive(Debug, Clone)]
pub enum TemplateValue {
    String(String),
    Bool(bool),
    Array(Vec<TemplateValue>),
    Object(HashMap<String, TemplateValue>),
}

impl From<String> for TemplateValue {
    fn from(s: String) -> Self {
        Self::String(s)
    }
}

impl From<&str> for TemplateValue {
    fn from(s: &str) -> Self {
        Self::String(s.to_owned())
    }
}

impl From<bool> for TemplateValue {
    fn from(b: bool) -> Self {
        Self::Bool(b)
    }
}

/// Text form used when a value is interpolated into a template:
/// strings verbatim, booleans as `true`/`false`, arrays as `[a, b]`,
/// objects as `{"key": value, ...}` in the map's iteration order.
impl fmt::Display for TemplateValue {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::String(text) => f.write_str(text),
            Self::Bool(flag) => write!(f, "{flag}"),
            Self::Array(items) => {
                f.write_str("[")?;
                for (idx, item) in items.iter().enumerate() {
                    if idx > 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "{item}")?;
                }
                f.write_str("]")
            }
            Self::Object(entries) => {
                f.write_str("{")?;
                for (idx, (key, value)) in entries.iter().enumerate() {
                    if idx > 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "\"{key}\": {value}")?;
                }
                f.write_str("}")
            }
        }
    }
}

impl TemplateValue {
    /// Truthiness used by `#if`: `true`, or any non-empty string/array/object.
    fn is_truthy(&self) -> bool {
        match self {
            Self::Bool(flag) => *flag,
            Self::String(text) => !text.is_empty(),
            Self::Array(items) => !items.is_empty(),
            Self::Object(entries) => !entries.is_empty(),
        }
    }

    /// Returns the inner vector when this value is an array.
    fn as_array(&self) -> Option<&Vec<TemplateValue>> {
        if let Self::Array(items) = self {
            Some(items)
        } else {
            None
        }
    }
}
|
||||
|
||||
// ===== TemplateContext =====
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct TemplateContext {
|
||||
vars: HashMap<String, TemplateValue>,
|
||||
}
|
||||
|
||||
impl TemplateContext {
|
||||
/// 创建一个空的模板上下文。
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// 插入变量(支持 `&str` / `String` / `bool` 自动转换)。
|
||||
pub fn insert(&mut self, key: impl Into<String>, value: impl Into<TemplateValue>) {
|
||||
self.vars.insert(key.into(), value.into());
|
||||
}
|
||||
|
||||
/// 按名称获取变量值。
|
||||
pub fn get(&self, key: &str) -> Option<&TemplateValue> {
|
||||
self.vars.get(key)
|
||||
}
|
||||
|
||||
/// 从 `serde_json::Value` 递归构造(支持嵌套 Object/Array)。
|
||||
pub fn from_json(value: &Value) -> Result<Self, PromptError> {
|
||||
let map = value
|
||||
.as_object()
|
||||
.ok_or_else(|| PromptError::Render("JSON 根值必须是对象".to_string()))?;
|
||||
|
||||
let mut ctx = Self::new();
|
||||
for (k, v) in map {
|
||||
ctx.vars.insert(k.clone(), json_to_template_value(v)?);
|
||||
}
|
||||
Ok(ctx)
|
||||
}
|
||||
|
||||
/// 从 `HashMap` 构造(适用于配置加载场景)。
|
||||
pub fn from_map(map: HashMap<String, TemplateValue>) -> Self {
|
||||
Self { vars: map }
|
||||
}
|
||||
}
|
||||
|
||||
fn json_to_template_value(v: &Value) -> Result<TemplateValue, PromptError> {
|
||||
match v {
|
||||
Value::Null => Ok(TemplateValue::String(String::new())),
|
||||
Value::Bool(b) => Ok(TemplateValue::Bool(*b)),
|
||||
Value::Number(n) => Ok(TemplateValue::String(n.to_string())),
|
||||
Value::String(s) => Ok(TemplateValue::String(s.clone())),
|
||||
Value::Array(arr) => {
|
||||
let items: Result<Vec<TemplateValue>, _> =
|
||||
arr.iter().map(json_to_template_value).collect();
|
||||
Ok(TemplateValue::Array(items?))
|
||||
}
|
||||
Value::Object(obj) => {
|
||||
let mut map = HashMap::new();
|
||||
for (k, v) in obj {
|
||||
map.insert(k.clone(), json_to_template_value(v)?);
|
||||
}
|
||||
Ok(TemplateValue::Object(map))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===== Fragment (AST) =====
|
||||
|
||||
/// One node of a compiled template.
#[derive(Debug, Clone)]
enum Fragment {
    /// Plain text copied to the output as-is.
    Literal(String),
    /// `{{name}}`: replaced by the variable's display text.
    Variable { name: String },
    /// `{{#if condition}}...{{else}}...{{/if}}`.
    If {
        condition: String,
        body: Vec<Fragment>,
        else_body: Vec<Fragment>,
    },
    /// `{{#each variable}}...{{/each}}`.
    Each {
        variable: String,
        body: Vec<Fragment>,
    },
    /// `{{#raw}}...{{/raw}}`: content emitted without interpretation.
    Raw(String),
    /// `{{> name}}`: reference to a partial template.
    Include(String),
}
|
||||
|
||||
// ===== PromptTemplate =====
|
||||
|
||||
pub struct PromptTemplate {
|
||||
raw: String,
|
||||
fragments: Vec<Fragment>,
|
||||
partials: HashMap<String, PromptTemplate>,
|
||||
}
|
||||
|
||||
impl PromptTemplate {
|
||||
/// 从模板字符串编译。
|
||||
pub fn compile(template: &str) -> Result<Self, PromptError> {
|
||||
let fragments = compile_fragments(template)?;
|
||||
Ok(Self {
|
||||
raw: template.to_string(),
|
||||
fragments,
|
||||
partials: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
/// 使用上下文渲染。
|
||||
pub fn render(&self, ctx: &TemplateContext) -> Result<String, PromptError> {
|
||||
let mut output = String::new();
|
||||
render_fragments(&self.fragments, ctx, &self.partials, &mut output, 0)?;
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// 使用上下文和外部 partials 渲染。
|
||||
pub fn render_with_partials(
|
||||
&self,
|
||||
ctx: &TemplateContext,
|
||||
partials: &HashMap<String, PromptTemplate>,
|
||||
) -> Result<String, PromptError> {
|
||||
let mut output = String::new();
|
||||
render_fragments(&self.fragments, ctx, partials, &mut output, 0)?;
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// 注册可引用的子模板。
|
||||
pub fn register_partial(&mut self, name: &str, template: PromptTemplate) {
|
||||
self.partials.insert(name.to_string(), template);
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for PromptTemplate {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.raw)
|
||||
}
|
||||
}
|
||||
|
||||
// ===== Compiler =====
|
||||
|
||||
fn compile_fragments(template: &str) -> Result<Vec<Fragment>, PromptError> {
|
||||
let bytes = template.as_bytes();
|
||||
let len = bytes.len();
|
||||
let mut fragments = Vec::new();
|
||||
let mut i = 0;
|
||||
let mut literal = String::new();
|
||||
|
||||
while i < len {
|
||||
if bytes[i] == b'{' && i + 1 < len && bytes[i + 1] == b'{' {
|
||||
if !literal.is_empty() {
|
||||
fragments.push(Fragment::Literal(literal.clone()));
|
||||
literal.clear();
|
||||
}
|
||||
let (tag_content, end) = parse_tag(bytes, i)?;
|
||||
i = end;
|
||||
|
||||
let tag = tag_content.trim();
|
||||
if let Some(rest) = tag.strip_prefix("#if ") {
|
||||
let (body, else_body, new_i) =
|
||||
parse_block(template, i, "if")?;
|
||||
let condition = rest.trim().to_string();
|
||||
fragments.push(Fragment::If {
|
||||
condition,
|
||||
body,
|
||||
else_body,
|
||||
});
|
||||
i = new_i;
|
||||
} else if let Some(rest) = tag.strip_prefix("#each ") {
|
||||
let (body, new_i) = parse_each_block(template, i)?;
|
||||
let variable = rest.trim().to_string();
|
||||
fragments.push(Fragment::Each { variable, body });
|
||||
i = new_i;
|
||||
} else if tag == "#raw" {
|
||||
let (raw, new_i) = parse_raw_block(template, i)?;
|
||||
fragments.push(Fragment::Raw(raw));
|
||||
i = new_i;
|
||||
} else if let Some(rest) = tag.strip_prefix("> ") {
|
||||
let name = rest.trim().to_string();
|
||||
fragments.push(Fragment::Include(name));
|
||||
} else if tag.starts_with("/") {
|
||||
break;
|
||||
} else {
|
||||
let name = tag.to_string();
|
||||
fragments.push(Fragment::Variable { name });
|
||||
}
|
||||
} else {
|
||||
literal.push(bytes[i] as char);
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if !literal.is_empty() {
|
||||
fragments.push(Fragment::Literal(literal));
|
||||
}
|
||||
|
||||
Ok(fragments)
|
||||
}
|
||||
|
||||
fn parse_tag(bytes: &[u8], start: usize) -> Result<(String, usize), PromptError> {
|
||||
let len = bytes.len();
|
||||
let mut i = start + 2;
|
||||
let mut content = String::new();
|
||||
while i < len {
|
||||
if bytes[i] == b'}' && i + 1 < len && bytes[i + 1] == b'}' {
|
||||
return Ok((content, i + 2));
|
||||
}
|
||||
content.push(bytes[i] as char);
|
||||
i += 1;
|
||||
}
|
||||
Err(PromptError::Parse("未闭合的 {{ 标签".to_string()))
|
||||
}
|
||||
|
||||
fn parse_block(
|
||||
template: &str,
|
||||
start: usize,
|
||||
kind: &str,
|
||||
) -> Result<(Vec<Fragment>, Vec<Fragment>, usize), PromptError> {
|
||||
let bytes = template.as_bytes();
|
||||
let len = bytes.len();
|
||||
let mut depth = 1u32;
|
||||
let mut i = start;
|
||||
let mut body = String::new();
|
||||
let mut else_body = String::new();
|
||||
let mut is_else = false;
|
||||
|
||||
while i < len && depth > 0 {
|
||||
if bytes[i] == b'{' && i + 1 < len && bytes[i + 1] == b'{' {
|
||||
let (tag, end) = parse_tag(bytes, i)?;
|
||||
let tag = tag.trim().to_string();
|
||||
if tag == format!("/{kind}") {
|
||||
depth -= 1;
|
||||
if depth == 0 {
|
||||
let (if_fragments, else_fragments) = if is_else {
|
||||
(compile_fragments(&body)?, compile_fragments(&else_body)?)
|
||||
} else {
|
||||
(compile_fragments(&body)?, compile_fragments("")?)
|
||||
};
|
||||
return Ok((if_fragments, else_fragments, end));
|
||||
} else {
|
||||
body.push_str(&template[i..end]);
|
||||
}
|
||||
i = end;
|
||||
} else if tag == "#if " || tag.starts_with("#if ") {
|
||||
depth += 1;
|
||||
body.push_str(&template[i..end]);
|
||||
i = end;
|
||||
} else if tag == "else" && depth == 1 {
|
||||
is_else = true;
|
||||
i = end;
|
||||
} else {
|
||||
body.push_str(&template[i..end]);
|
||||
i = end;
|
||||
}
|
||||
} else {
|
||||
if is_else {
|
||||
else_body.push(bytes[i] as char);
|
||||
} else {
|
||||
body.push(bytes[i] as char);
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
Err(PromptError::Parse(format!("未闭合的 {{#{}}} 块", kind)))
|
||||
}
|
||||
|
||||
fn parse_each_block(
|
||||
template: &str,
|
||||
start: usize,
|
||||
) -> Result<(Vec<Fragment>, usize), PromptError> {
|
||||
let bytes = template.as_bytes();
|
||||
let len = bytes.len();
|
||||
let mut depth = 1u32;
|
||||
let mut i = start;
|
||||
let mut body = String::new();
|
||||
|
||||
while i < len && depth > 0 {
|
||||
if bytes[i] == b'{' && i + 1 < len && bytes[i + 1] == b'{' {
|
||||
let (tag, end) = parse_tag(bytes, i)?;
|
||||
let tag = tag.trim().to_string();
|
||||
if tag == "/each" {
|
||||
depth -= 1;
|
||||
if depth == 0 {
|
||||
let fragments = compile_fragments(&body)?;
|
||||
return Ok((fragments, end));
|
||||
} else {
|
||||
body.push_str(&template[i..end]);
|
||||
}
|
||||
i = end;
|
||||
} else if tag == "#each " || tag.starts_with("#each ") {
|
||||
depth += 1;
|
||||
body.push_str(&template[i..end]);
|
||||
i = end;
|
||||
} else {
|
||||
body.push_str(&template[i..end]);
|
||||
i = end;
|
||||
}
|
||||
} else {
|
||||
body.push(bytes[i] as char);
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
Err(PromptError::Parse(
|
||||
"未闭合的 {{#each}} 块".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn parse_raw_block(template: &str, start: usize) -> Result<(String, usize), PromptError> {
|
||||
let bytes = template.as_bytes();
|
||||
let len = bytes.len();
|
||||
let mut i = start;
|
||||
let mut content = String::new();
|
||||
|
||||
while i < len {
|
||||
if bytes[i] == b'{' && i + 1 < len && bytes[i + 1] == b'{' {
|
||||
let (tag, end) = parse_tag(bytes, i)?;
|
||||
let tag = tag.trim().to_string();
|
||||
if tag == "/raw" {
|
||||
return Ok((content, end));
|
||||
} else {
|
||||
content.push_str(&template[i..end]);
|
||||
i = end;
|
||||
}
|
||||
} else {
|
||||
content.push(bytes[i] as char);
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
Err(PromptError::Parse(
|
||||
"未闭合的 {{#raw}} 块".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
// ===== Renderer =====
|
||||
|
||||
fn render_fragments(
|
||||
fragments: &[Fragment],
|
||||
ctx: &TemplateContext,
|
||||
partials: &HashMap<String, PromptTemplate>,
|
||||
output: &mut String,
|
||||
depth: u8,
|
||||
) -> Result<(), PromptError> {
|
||||
if depth > MAX_RENDER_DEPTH {
|
||||
return Err(PromptError::MaxDepthReached(MAX_RENDER_DEPTH));
|
||||
}
|
||||
|
||||
for frag in fragments {
|
||||
match frag {
|
||||
Fragment::Literal(text) => {
|
||||
output.push_str(text);
|
||||
}
|
||||
Fragment::Variable { name } => {
|
||||
match ctx.get(name) {
|
||||
Some(val) => {
|
||||
output.push_str(&format!("{}", val));
|
||||
}
|
||||
None => {
|
||||
return Err(PromptError::VariableNotFound(name.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
Fragment::If {
|
||||
condition,
|
||||
body,
|
||||
else_body,
|
||||
} => {
|
||||
let truthy = ctx
|
||||
.get(condition)
|
||||
.map(|v| v.is_truthy())
|
||||
.unwrap_or(false);
|
||||
let target = if truthy { body } else { else_body };
|
||||
render_fragments(target, ctx, partials, output, depth + 1)?;
|
||||
}
|
||||
Fragment::Each { variable, body } => {
|
||||
let arr = match ctx.get(variable) {
|
||||
Some(val) => val.as_array().ok_or_else(|| {
|
||||
PromptError::NotAnArray(variable.clone())
|
||||
})?,
|
||||
None => {
|
||||
return Err(PromptError::VariableNotFound(variable.clone()));
|
||||
}
|
||||
};
|
||||
|
||||
for item in arr {
|
||||
let mut child_ctx = ctx.clone();
|
||||
child_ctx.vars.insert("item".to_string(), item.clone());
|
||||
render_fragments(body, &child_ctx, partials, output, depth + 1)?;
|
||||
}
|
||||
}
|
||||
Fragment::Raw(text) => {
|
||||
output.push_str(text);
|
||||
}
|
||||
Fragment::Include(name) => {
|
||||
if let Some(partial) = partials.get(name) {
|
||||
partial.render_with_partials(ctx, partials)?;
|
||||
} else {
|
||||
return Err(PromptError::PartialNotFound(name.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ===== PromptTemplateRegistry =====
|
||||
|
||||
/// 内部存储的模板(支持延迟编译)。
|
||||
enum StoredTemplate {
|
||||
Compiled(PromptTemplate),
|
||||
Raw(String),
|
||||
}
|
||||
|
||||
/// 模板注册表——管理多模板实例。
|
||||
#[derive(Default)]
|
||||
pub struct PromptTemplateRegistry {
|
||||
templates: HashMap<String, StoredTemplate>,
|
||||
}
|
||||
|
||||
impl PromptTemplateRegistry {
|
||||
/// 创建一个空的模板注册表。
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
templates: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 从模板字符串编译并注册(立即编译)。
|
||||
pub fn register(&mut self, name: &str, template: &str) -> Result<(), PromptError> {
|
||||
let compiled = PromptTemplate::compile(template)?;
|
||||
self.templates
|
||||
.insert(name.to_string(), StoredTemplate::Compiled(compiled));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 延迟编译注册:只存储原始字符串,首次渲染时编译。
|
||||
pub fn register_lazy(&mut self, name: &str, template: &str) {
|
||||
self.templates.insert(
|
||||
name.to_string(),
|
||||
StoredTemplate::Raw(template.to_string()),
|
||||
);
|
||||
}
|
||||
|
||||
/// 从文件读取并编译注册。
|
||||
pub fn register_file(&mut self, name: &str, path: &std::path::Path) -> Result<(), PromptError> {
|
||||
let content = std::fs::read_to_string(path)?;
|
||||
let compiled = PromptTemplate::compile(&content)?;
|
||||
self.templates
|
||||
.insert(name.to_string(), StoredTemplate::Compiled(compiled));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取已注册的模板(延迟编译的模板在此首次编译)。
|
||||
pub fn get(&mut self, name: &str) -> Result<&PromptTemplate, PromptError> {
|
||||
if let Some(stored) = self.templates.get_mut(name) {
|
||||
if let StoredTemplate::Raw(raw) = stored {
|
||||
let compiled = PromptTemplate::compile(raw)?;
|
||||
*stored = StoredTemplate::Compiled(compiled);
|
||||
}
|
||||
match stored {
|
||||
StoredTemplate::Compiled(tpl) => Ok(tpl),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
} else {
|
||||
Err(PromptError::PartialNotFound(name.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// 按名称渲染。
|
||||
pub fn render(&mut self, name: &str, ctx: &TemplateContext) -> Result<String, PromptError> {
|
||||
let tpl = self.get(name)?;
|
||||
tpl.render(ctx)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
//! 工具系统 —— 工具抽象、注册、调用、权限控制与 MCP 集成。
|
||||
|
||||
pub mod base;
|
||||
pub mod error;
|
||||
pub mod mcp;
|
||||
pub mod permission;
|
||||
pub mod registry;
|
||||
|
||||
pub use base::{BaseTool, ToolContext, ToolRef};
|
||||
pub use error::ToolError;
|
||||
pub use mcp::{McpClient, McpTransport};
|
||||
pub use permission::{Permission, PermissionChecker, PermissionConfig};
|
||||
pub use registry::{ToolInvocation, ToolRegistry};
|
||||
@@ -0,0 +1,139 @@
|
||||
//! 工具抽象接口与执行上下文。
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::tools::error::ToolError;
|
||||
use crate::tools::permission::Permission;
|
||||
|
||||
/// 工具执行上下文 —— 携带每次执行的运行时信息。
|
||||
///
|
||||
/// 字段在 Phase 2 即注入 `execute()` 签名中,防止后续扩展时出现
|
||||
/// breaking change。后续阶段可扩展字段(如 `progress`、`shared_state`),
|
||||
/// 但已有工具实现无需修改。
|
||||
#[derive(Debug)]
|
||||
pub struct ToolContext<'a> {
|
||||
/// 当前对话/会话 ID,用于关联性追踪。
|
||||
pub session_id: &'a str,
|
||||
/// 链路追踪 ID,用于跨工具调用的耗时分布。
|
||||
pub trace_id: &'a str,
|
||||
/// 取消令牌,用于优雅取消正在执行的工具。
|
||||
pub cancellation_token: CancellationToken,
|
||||
}
|
||||
|
||||
impl<'a> ToolContext<'a> {
|
||||
/// 创建一个新的工具执行上下文。
|
||||
pub fn new(session_id: &'a str, trace_id: &'a str) -> Self {
|
||||
Self {
|
||||
session_id,
|
||||
trace_id,
|
||||
cancellation_token: CancellationToken::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 创建一个使用给定取消令牌的上下文。
|
||||
pub fn with_cancellation_token(
|
||||
session_id: &'a str,
|
||||
trace_id: &'a str,
|
||||
token: CancellationToken,
|
||||
) -> Self {
|
||||
Self {
|
||||
session_id,
|
||||
trace_id,
|
||||
cancellation_token: token,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 工具抽象接口 —— 所有工具(自定义或 MCP)最终都实现此 trait。
|
||||
#[async_trait]
|
||||
pub trait BaseTool: Send + Sync {
|
||||
/// 工具名称(唯一标识,用于 LLM 的 tool_calls.name 匹配)。
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// 工具描述(LLM 据此决定是否调用此工具)。
|
||||
fn description(&self) -> &str;
|
||||
|
||||
/// 工具参数定义(JSON Schema 格式,传递给 LLM 的 tool.parameters)。
|
||||
fn parameters(&self) -> Value;
|
||||
|
||||
/// 声明工具所需的权限列表。
|
||||
fn required_permissions(&self) -> Vec<Permission> {
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
/// 执行工具调用。
|
||||
///
|
||||
/// `ctx` 携带执行上下文(session_id、trace_id、cancellation_token),
|
||||
/// 工具实现可在执行期间检查 `ctx.cancellation_token` 来支持优雅取消。
|
||||
async fn execute(&self, args: Value, ctx: &ToolContext<'_>) -> Result<Value, ToolError>;
|
||||
}
|
||||
|
||||
/// Shared handle to a tool object: an alias for `Arc<dyn BaseTool>`, convenient for `Vec<ToolRef>`. `Send + Sync` comes from the `BaseTool` supertrait bounds.
|
||||
pub type ToolRef = Arc<dyn BaseTool>;
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal tool that returns its arguments unchanged.
    struct EchoTool;

    #[async_trait]
    impl BaseTool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "回显输入"
        }

        fn parameters(&self) -> Value {
            json!({
                "type": "object",
                "properties": {
                    "text": { "type": "string" }
                },
                "required": ["text"]
            })
        }

        async fn execute(&self, args: Value, _ctx: &ToolContext<'_>) -> Result<Value, ToolError> {
            Ok(args)
        }
    }

    /// `execute` on the echo tool returns exactly what it was given.
    #[tokio::test]
    async fn test_mock_tool_execute() {
        let ctx = ToolContext::new("session-1", "trace-1");
        let payload = json!({"text": "hello"});

        let result = EchoTool.execute(payload.clone(), &ctx).await.unwrap();
        assert_eq!(result, payload);
    }

    /// The default `required_permissions` is empty.
    #[tokio::test]
    async fn test_default_permissions_empty() {
        let permissions = EchoTool.required_permissions();
        assert!(permissions.is_empty());
    }

    /// `new` stores both ids and starts with an un-cancelled token.
    #[test]
    fn test_tool_context_creation() {
        let ctx = ToolContext::new("s1", "t1");
        assert_eq!(ctx.session_id, "s1");
        assert_eq!(ctx.trace_id, "t1");
        assert!(!ctx.cancellation_token.is_cancelled());
    }

    /// A pre-cancelled token is visible through the context.
    #[test]
    fn test_tool_context_cancellation() {
        let token = CancellationToken::new();
        token.cancel();

        let ctx = ToolContext::with_cancellation_token("s1", "t1", token);
        assert!(ctx.cancellation_token.is_cancelled());
    }
}
|
||||
@@ -0,0 +1,118 @@
|
||||
//! 工具系统错误类型。
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
/// 工具调用过程中可能发生的所有错误。
|
||||
#[derive(thiserror::Error, Debug, Clone)]
|
||||
pub enum ToolError {
|
||||
/// 工具未注册。
|
||||
#[error("工具 '{0}' 未注册")]
|
||||
NotFound(String),
|
||||
|
||||
/// 工具执行失败(可恢复——文本回传 LLM)。
|
||||
#[error("工具 '{0}' 执行失败: {1}")]
|
||||
ExecutionFailed(String, String),
|
||||
|
||||
/// 工具参数无效(可恢复——文本回传 LLM)。
|
||||
#[error("工具 '{0}' 参数无效: {1}")]
|
||||
InvalidArguments(String, String),
|
||||
|
||||
/// 权限被拒绝(不可恢复——终止循环)。
|
||||
#[error("权限被拒绝: 工具 '{0}' 需要 {1} 权限")]
|
||||
PermissionDenied(String, String),
|
||||
|
||||
/// MCP 协议错误(不可恢复)。
|
||||
#[error("MCP 协议错误: {0}")]
|
||||
McpError(String),
|
||||
|
||||
/// MCP 未初始化(不可恢复)。
|
||||
#[error("MCP 未初始化: {0}")]
|
||||
McpNotInitialized(String),
|
||||
|
||||
/// MCP 超时(不可恢复)。
|
||||
#[error("MCP 超时: {0}")]
|
||||
McpTimeout(String),
|
||||
|
||||
/// IO 错误(不可恢复)。
|
||||
#[error("IO 错误: {0}")]
|
||||
Io(Arc<std::io::Error>),
|
||||
|
||||
/// 取消。
|
||||
#[error("工具执行已取消: {0}")]
|
||||
Cancelled(String),
|
||||
|
||||
/// 其他未分类错误。
|
||||
#[error("其他错误: {0}")]
|
||||
Other(String),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ToolError {
|
||||
fn from(e: std::io::Error) -> Self {
|
||||
ToolError::Io(Arc::new(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolError {
|
||||
/// 判断错误是否可恢复——可恢复的错误回传 LLM 由其自行重试,
|
||||
/// 不可恢复的错误终止自动 tool 循环并返回给调用方。
|
||||
pub fn is_recoverable(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Self::ExecutionFailed(..) | Self::InvalidArguments(..) | Self::Other(_)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Execution failures are recoverable.
    #[test]
    fn test_execution_failed_is_recoverable() {
        let err = ToolError::ExecutionFailed("foo".into(), "boom".into());
        assert!(err.is_recoverable());
    }

    /// Invalid arguments are recoverable.
    #[test]
    fn test_invalid_arguments_is_recoverable() {
        let err = ToolError::InvalidArguments("foo".into(), "missing x".into());
        assert!(err.is_recoverable());
    }

    /// A missing tool is not recoverable.
    #[test]
    fn test_not_found_is_not_recoverable() {
        let err = ToolError::NotFound("foo".into());
        assert!(!err.is_recoverable());
    }

    /// A denied permission is not recoverable.
    #[test]
    fn test_permission_denied_is_not_recoverable() {
        let err = ToolError::PermissionDenied("foo".into(), "Shell".into());
        assert!(!err.is_recoverable());
    }

    /// MCP protocol errors are not recoverable.
    #[test]
    fn test_mcp_error_is_not_recoverable() {
        let err = ToolError::McpError("protocol".into());
        assert!(!err.is_recoverable());
    }

    /// MCP timeouts are not recoverable.
    #[test]
    fn test_mcp_timeout_is_not_recoverable() {
        let err = ToolError::McpTimeout("foo".into());
        assert!(!err.is_recoverable());
    }

    /// I/O errors converted through `From` are not recoverable.
    #[test]
    fn test_io_is_not_recoverable() {
        let io_err = std::io::Error::new(std::io::ErrorKind::Other, "disk");
        let err: ToolError = io_err.into();
        assert!(!err.is_recoverable());
    }

    /// Unclassified errors are recoverable.
    #[test]
    fn test_other_is_recoverable() {
        let err = ToolError::Other("something".into());
        assert!(err.is_recoverable());
    }
}
|
||||
@@ -0,0 +1,640 @@
|
||||
//! MCP 协议客户端 —— 与 MCP Server 通过 JSON-RPC over stdio 通信。
|
||||
//!
|
||||
//! 当前 Phase 2 实现 stdio transport。`StreamableHttp` 枚举变体已预留,
|
||||
//! 但实际实现推迟到后续版本。
|
||||
//!
|
||||
//! ## 协议版本
|
||||
//!
|
||||
//! 实现遵循 MCP 协议版本 2025-03-26。
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::process::Stdio;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
|
||||
use tokio::sync::{oneshot, Mutex};
|
||||
|
||||
use crate::llm::types::ToolDefinition;
|
||||
use crate::tools::base::{BaseTool, ToolContext, ToolRef};
|
||||
use crate::tools::error::ToolError;
|
||||
|
||||
/// MCP protocol version implemented by this client.
|
||||
const MCP_VERSION: &str = "2025-03-26";
|
||||
|
||||
/// How the client talks to an MCP server.
#[derive(Debug, Clone)]
pub enum McpTransport {
    /// Communicate through a child process's stdin/stdout.
    Stdio {
        /// Launch command (e.g. `"npx"`).
        command: String,
        /// Command arguments (e.g. `["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]`).
        args: Vec<String>,
    },
    /// Streamable HTTP transport (introduced in MCP 2025-03-26, replacing the
    /// deprecated HTTP+SSE transport).
    ///
    /// Reserved in Phase 2: per the original note, calls made with this variant
    /// return `ToolError::McpError`.
    StreamableHttp {
        /// MCP endpoint URL.
        url: String,
        /// Optional HTTP headers (e.g. Authorization).
        headers: Option<Vec<(String, String)>>,
    },
}
|
||||
|
||||
/// JSON-RPC 请求。
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct JsonRpcRequest {
|
||||
jsonrpc: &'static str,
|
||||
id: u64,
|
||||
method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
params: Option<Value>,
|
||||
}
|
||||
|
||||
impl JsonRpcRequest {
|
||||
fn new(id: u64, method: impl Into<String>, params: Option<Value>) -> Self {
|
||||
Self {
|
||||
jsonrpc: "2.0",
|
||||
id,
|
||||
method: method.into(),
|
||||
params,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// JSON-RPC 响应。
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct JsonRpcResponse {
|
||||
jsonrpc: String,
|
||||
id: u64,
|
||||
#[serde(default)]
|
||||
result: Option<Value>,
|
||||
#[serde(default)]
|
||||
error: Option<JsonRpcError>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct JsonRpcError {
|
||||
code: i32,
|
||||
message: String,
|
||||
#[serde(default)]
|
||||
data: Option<Value>,
|
||||
}
|
||||
|
||||
/// MCP 子进程运行时状态。
|
||||
struct ChildProcessState {
|
||||
child: Child,
|
||||
stdin: ChildStdin,
|
||||
pending: HashMap<u64, oneshot::Sender<Result<Value, ToolError>>>,
|
||||
next_id: u64,
|
||||
}
|
||||
|
||||
impl ChildProcessState {
|
||||
fn next_id(&mut self) -> u64 {
|
||||
self.next_id += 1;
|
||||
self.next_id
|
||||
}
|
||||
}
|
||||
|
||||
/// MCP Server 暴露的工具(缓存结构)。
|
||||
#[derive(Debug, Clone)]
|
||||
struct McpTool {
|
||||
name: String,
|
||||
description: Option<String>,
|
||||
input_schema: Value,
|
||||
}
|
||||
|
||||
/// MCP 客户端 —— 与 MCP 服务器通信。
|
||||
pub struct McpClient {
|
||||
transport: McpTransport,
|
||||
server_name: String,
|
||||
/// 已初始化的工具列表(缓存)。
|
||||
tools: Vec<McpTool>,
|
||||
/// 是否已初始化。
|
||||
initialized: AtomicBool,
|
||||
/// 超时时间(秒)。
|
||||
timeout_secs: u64,
|
||||
/// 子进程运行时状态(`connect()` 后创建,`close()` 后取回)。
|
||||
process: Option<Arc<Mutex<ChildProcessState>>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for McpClient {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("McpClient")
|
||||
.field("server_name", &self.server_name)
|
||||
.field("initialized", &self.initialized.load(Ordering::SeqCst))
|
||||
.field("tool_count", &self.tools.len())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl McpClient {
|
||||
/// 创建一个 MCP 客户端。
|
||||
pub fn new(server_name: impl Into<String>, transport: McpTransport) -> Self {
|
||||
Self {
|
||||
transport,
|
||||
server_name: server_name.into(),
|
||||
tools: Vec::new(),
|
||||
initialized: AtomicBool::new(false),
|
||||
timeout_secs: 30,
|
||||
process: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// 设置超时时间(秒)。
|
||||
pub fn with_timeout(mut self, secs: u64) -> Self {
|
||||
self.timeout_secs = secs;
|
||||
self
|
||||
}
|
||||
|
||||
/// 检查是否已连接。
|
||||
pub fn is_initialized(&self) -> bool {
|
||||
self.initialized.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// 连接并初始化(发送 initialize 请求)。
|
||||
pub async fn connect(&mut self) -> Result<(), ToolError> {
|
||||
if self.is_initialized() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match &self.transport {
|
||||
McpTransport::Stdio { command, args } => {
|
||||
let mut cmd = Command::new(command);
|
||||
cmd.args(args)
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped());
|
||||
#[cfg(unix)]
|
||||
cmd.kill_on_drop(true);
|
||||
#[cfg(windows)]
|
||||
cmd.creation_flags(0x08000000); // CREATE_NO_WINDOW
|
||||
|
||||
let mut child = cmd
|
||||
.spawn()
|
||||
.map_err(|e| ToolError::McpError(format!("启动 MCP 子进程失败: {e}")))?;
|
||||
|
||||
let stdin = child
|
||||
.stdin
|
||||
.take()
|
||||
.ok_or_else(|| ToolError::McpError("无法获取子进程 stdin".into()))?;
|
||||
let stdout = child
|
||||
.stdout
|
||||
.take()
|
||||
.ok_or_else(|| ToolError::McpError("无法获取子进程 stdout".into()))?;
|
||||
|
||||
// 启动 reader task 持续读取 stdout
|
||||
let pending: HashMap<u64, oneshot::Sender<Result<Value, ToolError>>> =
|
||||
HashMap::new();
|
||||
let state = Arc::new(Mutex::new(ChildProcessState {
|
||||
child,
|
||||
stdin,
|
||||
pending,
|
||||
next_id: 0,
|
||||
}));
|
||||
|
||||
// 启动后台 reader
|
||||
let pending_arc = Arc::clone(&state);
|
||||
tokio::spawn(async move {
|
||||
Self::read_loop(BufReader::new(stdout), pending_arc).await;
|
||||
});
|
||||
|
||||
self.process = Some(state);
|
||||
}
|
||||
McpTransport::StreamableHttp { .. } => {
|
||||
return Err(ToolError::McpError(
|
||||
"StreamableHttp transport 尚未实现".into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// 发送 initialize 请求
|
||||
let init_params = json!({
|
||||
"protocolVersion": MCP_VERSION,
|
||||
"capabilities": {},
|
||||
"clientInfo": {
|
||||
"name": "agcore",
|
||||
"version": env!("CARGO_PKG_VERSION")
|
||||
}
|
||||
});
|
||||
let _response = self
|
||||
.send_request("initialize", Some(init_params))
|
||||
.await?;
|
||||
|
||||
// 发送 initialized 通知(无 id)
|
||||
self.send_notification("notifications/initialized", Some(json!({})))
|
||||
.await?;
|
||||
|
||||
self.initialized.store(true, Ordering::SeqCst);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 列出服务器支持的工具(调用 `tools/list`)。
|
||||
pub async fn list_tools(&mut self) -> Result<Vec<ToolDefinition>, ToolError> {
|
||||
if !self.is_initialized() {
|
||||
return Err(ToolError::McpNotInitialized(self.server_name.clone()));
|
||||
}
|
||||
|
||||
let response = self.send_request("tools/list", None).await?;
|
||||
let tools_value = response
|
||||
.get("tools")
|
||||
.ok_or_else(|| ToolError::McpError("tools/list 响应缺少 tools 字段".into()))?;
|
||||
let tools_arr = tools_value
|
||||
.as_array()
|
||||
.ok_or_else(|| ToolError::McpError("tools/list 响应 tools 字段不是数组".into()))?;
|
||||
|
||||
self.tools.clear();
|
||||
let mut defs = Vec::with_capacity(tools_arr.len());
|
||||
for tool in tools_arr {
|
||||
let name = tool
|
||||
.get("name")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| ToolError::McpError("工具缺少 name 字段".into()))?
|
||||
.to_string();
|
||||
let description = tool
|
||||
.get("description")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
let input_schema = tool
|
||||
.get("inputSchema")
|
||||
.cloned()
|
||||
.unwrap_or_else(|| json!({"type": "object", "properties": {}}));
|
||||
|
||||
self.tools.push(McpTool {
|
||||
name: name.clone(),
|
||||
description: description.clone(),
|
||||
input_schema: input_schema.clone(),
|
||||
});
|
||||
defs.push(ToolDefinition {
|
||||
name,
|
||||
description,
|
||||
parameters: input_schema,
|
||||
strict: None,
|
||||
});
|
||||
}
|
||||
Ok(defs)
|
||||
}
|
||||
|
||||
/// 调用一个工具(调用 `tools/call`)。
|
||||
pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, ToolError> {
|
||||
if !self.is_initialized() {
|
||||
return Err(ToolError::McpNotInitialized(self.server_name.clone()));
|
||||
}
|
||||
|
||||
let params = json!({
|
||||
"name": name,
|
||||
"arguments": args,
|
||||
});
|
||||
let response = self.send_request("tools/call", Some(params)).await?;
|
||||
|
||||
// 解析 content 字段
|
||||
if let Some(content) = response.get("content").and_then(|c| c.as_array()) {
|
||||
// 收集所有 text 内容
|
||||
let mut combined = String::new();
|
||||
for item in content {
|
||||
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
|
||||
if !combined.is_empty() {
|
||||
combined.push('\n');
|
||||
}
|
||||
combined.push_str(text);
|
||||
}
|
||||
}
|
||||
if !combined.is_empty() {
|
||||
return Ok(Value::String(combined));
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有 content 字段,尝试直接返回 is_error 标记
|
||||
if let Some(true) = response.get("isError").and_then(|v| v.as_bool()) {
|
||||
return Err(ToolError::ExecutionFailed(
|
||||
name.to_string(),
|
||||
"MCP 工具返回 isError=true".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// 回退:返回完整响应
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// 关闭连接(终止子进程)。
|
||||
pub async fn close(&mut self) -> Result<(), ToolError> {
|
||||
if !self.is_initialized() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// 尝试发送 shutdown(不强制要求响应)
|
||||
let _ = self.send_notification("shutdown", None).await;
|
||||
|
||||
if let Some(state) = self.process.take() {
|
||||
let mut state = state.lock().await;
|
||||
// 优雅等待 5 秒
|
||||
let graceful = tokio::time::timeout(
|
||||
Duration::from_secs(5),
|
||||
state.child.wait(),
|
||||
)
|
||||
.await;
|
||||
if graceful.is_err() {
|
||||
// 超时则强杀
|
||||
let _ = state.child.kill().await;
|
||||
}
|
||||
}
|
||||
|
||||
self.initialized.store(false, Ordering::SeqCst);
|
||||
self.tools.clear();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 将 MCP 客户端转换为 `BaseTool` 适配器列表(用于注册到 `ToolRegistry`)。
|
||||
///
|
||||
/// **注意**:返回的适配器持有 `Arc<McpClient>`,但 `McpClient` 的可变性
|
||||
/// (如 `list_tools` 刷新缓存)会通过 `Mutex` 处理。当前适配器仅缓存
|
||||
/// 转换时的工具列表,不感知后续刷新。
|
||||
pub fn into_tools(self) -> Vec<ToolRef> {
|
||||
let mut tools = Vec::with_capacity(self.tools.len());
|
||||
for mcp_tool in self.tools {
|
||||
let tool = McpToolAdapter {
|
||||
client: McpClientHandle::Empty,
|
||||
name: mcp_tool.name,
|
||||
description: mcp_tool.description.unwrap_or_default(),
|
||||
parameters: mcp_tool.input_schema,
|
||||
};
|
||||
tools.push(Arc::new(tool) as ToolRef);
|
||||
}
|
||||
tools
|
||||
}
|
||||
|
||||
async fn send_request(
|
||||
&self,
|
||||
method: &str,
|
||||
params: Option<Value>,
|
||||
) -> Result<Value, ToolError> {
|
||||
let state_arc = self
|
||||
.process
|
||||
.as_ref()
|
||||
.ok_or_else(|| ToolError::McpNotInitialized(self.server_name.clone()))?
|
||||
.clone();
|
||||
|
||||
let (id, request_json) = {
|
||||
let mut state = state_arc.lock().await;
|
||||
let id = state.next_id();
|
||||
let req = JsonRpcRequest::new(id, method, params);
|
||||
let json = serde_json::to_string(&req)
|
||||
.map_err(|e| ToolError::McpError(format!("序列化请求失败: {e}")))?;
|
||||
(id, json)
|
||||
};
|
||||
|
||||
// 注册 oneshot 等待响应
|
||||
let (tx, rx) = oneshot::channel();
|
||||
{
|
||||
let mut state = state_arc.lock().await;
|
||||
state.pending.insert(id, tx);
|
||||
}
|
||||
|
||||
// 写入请求
|
||||
{
|
||||
let mut state = state_arc.lock().await;
|
||||
state
|
||||
.stdin
|
||||
.write_all(request_json.as_bytes())
|
||||
.await
|
||||
.map_err(|e| ToolError::McpError(format!("写入请求失败: {e}")))?;
|
||||
state
|
||||
.stdin
|
||||
.write_all(b"\n")
|
||||
.await
|
||||
.map_err(|e| ToolError::McpError(format!("写入换行失败: {e}")))?;
|
||||
state.stdin.flush().await.map_err(|e| {
|
||||
ToolError::McpError(format!("flush stdin 失败: {e}"))
|
||||
})?;
|
||||
}
|
||||
|
||||
// 等待响应(带超时)
|
||||
tokio::time::timeout(Duration::from_secs(self.timeout_secs), rx)
|
||||
.await
|
||||
.map_err(|_| {
|
||||
// 超时:清理 pending
|
||||
let state_arc = state_arc.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut state = state_arc.lock().await;
|
||||
state.pending.remove(&id);
|
||||
});
|
||||
ToolError::McpTimeout(method.to_string())
|
||||
})?
|
||||
.map_err(|_| ToolError::McpError("response channel 关闭".into()))?
|
||||
}
|
||||
|
||||
async fn send_notification(
|
||||
&self,
|
||||
method: &str,
|
||||
params: Option<Value>,
|
||||
) -> Result<(), ToolError> {
|
||||
let state_arc = self
|
||||
.process
|
||||
.as_ref()
|
||||
.ok_or_else(|| ToolError::McpNotInitialized(self.server_name.clone()))?
|
||||
.clone();
|
||||
|
||||
let notification = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
"params": params,
|
||||
});
|
||||
let json = serde_json::to_string(¬ification)
|
||||
.map_err(|e| ToolError::McpError(format!("序列化通知失败: {e}")))?;
|
||||
|
||||
let mut state = state_arc.lock().await;
|
||||
state
|
||||
.stdin
|
||||
.write_all(json.as_bytes())
|
||||
.await
|
||||
.map_err(|e| ToolError::McpError(format!("写入通知失败: {e}")))?;
|
||||
state
|
||||
.stdin
|
||||
.write_all(b"\n")
|
||||
.await
|
||||
.map_err(|e| ToolError::McpError(format!("写入换行失败: {e}")))?;
|
||||
state
|
||||
.stdin
|
||||
.flush()
|
||||
.await
|
||||
.map_err(|e| ToolError::McpError(format!("flush stdin 失败: {e}")))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 持续读取 stdout,将响应分发到对应的 oneshot sender。
|
||||
async fn read_loop(
|
||||
mut reader: BufReader<ChildStdout>,
|
||||
state: Arc<Mutex<ChildProcessState>>,
|
||||
) {
|
||||
let mut line = String::new();
|
||||
loop {
|
||||
line.clear();
|
||||
match reader.read_line(&mut line).await {
|
||||
Ok(0) => {
|
||||
// EOF:通知所有 pending 失败
|
||||
let mut state = state.lock().await;
|
||||
for (_, tx) in state.pending.drain() {
|
||||
let _ = tx.send(Err(ToolError::McpError("子进程退出".into())));
|
||||
}
|
||||
break;
|
||||
}
|
||||
Ok(_) => {
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
// 尝试解析为 JSON-RPC 响应
|
||||
let parsed: Result<JsonRpcResponse, _> = serde_json::from_str(trimmed);
|
||||
if let Ok(response) = parsed {
|
||||
let value = if let Some(err) = response.error {
|
||||
Err(ToolError::McpError(format!(
|
||||
"[{}] {}",
|
||||
err.code, err.message
|
||||
)))
|
||||
} else {
|
||||
Ok(response.result.unwrap_or(Value::Null))
|
||||
};
|
||||
let mut state = state.lock().await;
|
||||
if let Some(tx) = state.pending.remove(&response.id) {
|
||||
let _ = tx.send(value);
|
||||
}
|
||||
}
|
||||
// 非响应消息(通知、request from server)忽略
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("MCP read_loop error: {e}");
|
||||
let mut state = state.lock().await;
|
||||
for (_, tx) in state.pending.drain() {
|
||||
let _ = tx.send(Err(ToolError::McpError(format!("读取失败: {e}"))));
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// MCP 工具适配器 —— 将 MCP 工具包装为 `BaseTool`。
|
||||
struct McpToolAdapter {
|
||||
/// 持有 client 的弱引用。实际生产中应使用 `Arc<McpClient>`,
|
||||
/// 但当前 Phase 2 实现不直接持有可变的 `McpClient`。
|
||||
/// 标记为 unused 但保留字段以展示扩展路径。
|
||||
#[allow(dead_code)]
|
||||
client: McpClientHandle,
|
||||
name: String,
|
||||
description: String,
|
||||
parameters: Value,
|
||||
}
|
||||
|
||||
/// Handle an adapter uses to reach its MCP client. Only the empty placeholder exists so far.
#[allow(dead_code)]
enum McpClientHandle {
    /// No client attached.
    Empty,
    // Planned extension: Shared(Arc<McpClient>),
}
|
||||
|
||||
#[async_trait]
|
||||
impl BaseTool for McpToolAdapter {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
&self.description
|
||||
}
|
||||
|
||||
fn parameters(&self) -> Value {
|
||||
self.parameters.clone()
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
_args: Value,
|
||||
_ctx: &ToolContext<'_>,
|
||||
) -> Result<Value, ToolError> {
|
||||
// 当前 Phase 2 实现的简化:McpToolAdapter 不持有活跃 MCP 连接。
|
||||
// 实际生产中应持有 Arc<McpClient> 并通过 mcp.call_tool() 执行。
|
||||
// 这里返回错误,提示需要通过其他方式调用 MCP 工具。
|
||||
Err(ToolError::McpError(format!(
|
||||
"MCP 工具 '{}' 需要活跃的 McpClient 引用(当前 Phase 2 简化实现)",
|
||||
self.name
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The `Debug` rendering of a stdio transport mentions its command.
    #[test]
    fn test_transport_debug() {
        let rendered = format!(
            "{:?}",
            McpTransport::Stdio {
                command: String::from("echo"),
                args: vec![String::from("hello")],
            }
        );
        assert!(rendered.contains("echo"));
    }

    /// A new client keeps its name, honours `with_timeout` and starts uninitialized.
    #[test]
    fn test_client_creation() {
        let client = McpClient::new(
            "test-server",
            McpTransport::Stdio {
                command: String::from("test"),
                args: Vec::new(),
            },
        )
        .with_timeout(60);

        assert_eq!(client.server_name, "test-server");
        assert_eq!(client.timeout_secs, 60);
        assert!(!client.is_initialized());
    }

    /// A serialized request contains the version marker, the id and the method.
    #[test]
    fn test_jsonrpc_request_serialize() {
        let serialized =
            serde_json::to_string(&JsonRpcRequest::new(42, "test", Some(json!({"a": 1}))))
                .unwrap();
        for fragment in ["\"jsonrpc\":\"2.0\"", "\"id\":42", "\"method\":\"test\""] {
            assert!(serialized.contains(fragment));
        }
    }

    /// A success response parses with a result and without an error.
    #[test]
    fn test_jsonrpc_response_parse_ok() {
        let resp: JsonRpcResponse =
            serde_json::from_str(r#"{"jsonrpc":"2.0","id":1,"result":{"foo":"bar"}}"#).unwrap();
        assert_eq!(resp.id, 1);
        assert!(resp.result.is_some());
        assert!(resp.error.is_none());
    }

    /// An error response parses with an error object and without a result.
    #[test]
    fn test_jsonrpc_response_parse_error() {
        let raw =
            r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32601,"message":"Method not found"}}"#;
        let resp: JsonRpcResponse = serde_json::from_str(raw).unwrap();
        assert_eq!(resp.id, 1);
        assert!(resp.result.is_none());
        let err = resp.error.unwrap();
        assert_eq!(err.code, -32601);
    }

    /// Connecting over the reserved Streamable HTTP transport reports an MCP error.
    #[tokio::test]
    async fn test_streamable_http_not_implemented() {
        let transport = McpTransport::StreamableHttp {
            url: String::from("https://example.com/mcp"),
            headers: None,
        };
        let mut client = McpClient::new("http-server", transport);

        let outcome = client.connect().await;
        // Phase 2 answers with a "not implemented" error.
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(ToolError::McpError(_))));
    }
}
|
||||
@@ -0,0 +1,286 @@
|
||||
//! 工具权限管理。
|
||||
|
||||
use crate::tools::error::ToolError;
|
||||
|
||||
/// Kinds of capability a tool may require.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum Permission {
    /// Read-only access (reading files, querying a database, ...).
    Read,
    /// Write access (creating or modifying files, inserting data, ...).
    Write,
    /// Deletion (removing files, records, ...).
    Delete,
    /// Network access (HTTP requests, ...).
    Network,
    /// Shell command execution.
    Shell,
    /// File-system operations other than read/write/delete.
    FileSystem,
    /// Custom permission, extensible through a namespaced string.
    Custom(String),
}

impl std::fmt::Display for Permission {
    /// Writes the variant name; a custom permission is rendered as `Custom(<tag>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Custom(tag) => return write!(f, "Custom({tag})"),
            Self::Read => "Read",
            Self::Write => "Write",
            Self::Delete => "Delete",
            Self::Network => "Network",
            Self::Shell => "Shell",
            Self::FileSystem => "FileSystem",
        };
        f.write_str(label)
    }
}
|
||||
|
||||
/// 权限配置。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PermissionConfig {
|
||||
/// 允许的权限列表(空 = 全部允许,配合 `allow_unspecified` 决定)。
|
||||
pub allowed: Vec<Permission>,
|
||||
/// 拒绝的权限列表(优先级高于 `allowed`)。
|
||||
pub denied: Vec<Permission>,
|
||||
/// 当工具未声明权限时是否允许执行。
|
||||
pub allow_unspecified: bool,
|
||||
}
|
||||
|
||||
impl Default for PermissionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
allowed: vec![Permission::Read, Permission::Network],
|
||||
denied: vec![Permission::Delete, Permission::Shell],
|
||||
allow_unspecified: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 权限检查器。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PermissionChecker {
|
||||
config: PermissionConfig,
|
||||
}
|
||||
|
||||
impl PermissionChecker {
|
||||
/// 创建一个新的权限检查器。
|
||||
pub fn new(config: PermissionConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
/// 检查指定工具声明的权限是否允许执行。
|
||||
///
|
||||
/// 判定规则:
|
||||
/// 1. 任一权限在 `denied` 中 → 拒绝
|
||||
/// 2. 所有权限都在 `allowed` 中 → 允许
|
||||
/// 3. `allowed` 非空且存在未声明权限 → 拒绝
|
||||
/// 4. `allowed` 为空 → 按 `allow_unspecified` 判定
|
||||
/// 5. 工具未声明任何权限时按 `allow_unspecified` 判定
|
||||
pub fn check(&self, tool_name: &str, permissions: &[Permission]) -> Result<(), ToolError> {
|
||||
// 任一权限在 denied 中 → 拒绝
|
||||
for perm in permissions {
|
||||
if self.config.denied.contains(perm) {
|
||||
return Err(ToolError::PermissionDenied(
|
||||
tool_name.to_string(),
|
||||
perm.to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// 工具未声明任何权限
|
||||
if permissions.is_empty() {
|
||||
return if self.config.allow_unspecified {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ToolError::PermissionDenied(
|
||||
tool_name.to_string(),
|
||||
"Unspecified".to_string(),
|
||||
))
|
||||
};
|
||||
}
|
||||
|
||||
// allowed 为空 → 走 allow_unspecified 兜底
|
||||
if self.config.allowed.is_empty() {
|
||||
return if self.config.allow_unspecified {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ToolError::PermissionDenied(
|
||||
tool_name.to_string(),
|
||||
"Unspecified".to_string(),
|
||||
))
|
||||
};
|
||||
}
|
||||
|
||||
// allowed 非空(白名单模式)—— 所有权限必须在其中
|
||||
for perm in permissions {
|
||||
if !self.config.allowed.contains(perm) {
|
||||
return Err(ToolError::PermissionDenied(
|
||||
tool_name.to_string(),
|
||||
perm.to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Checker built from explicit lists.
    fn checker(
        allowed: Vec<Permission>,
        denied: Vec<Permission>,
        allow_unspecified: bool,
    ) -> PermissionChecker {
        PermissionChecker::new(PermissionConfig {
            allowed,
            denied,
            allow_unspecified,
        })
    }

    /// Checker built from the default configuration.
    fn default_checker() -> PermissionChecker {
        PermissionChecker::new(PermissionConfig::default())
    }

    #[test]
    fn test_default_config_allows_read() {
        assert!(default_checker()
            .check("weather", &[Permission::Read])
            .is_ok());
    }

    #[test]
    fn test_default_config_allows_network() {
        assert!(default_checker()
            .check("http_get", &[Permission::Network])
            .is_ok());
    }

    #[test]
    fn test_default_config_denies_delete() {
        assert!(default_checker()
            .check("rm_file", &[Permission::Delete])
            .is_err());
    }

    #[test]
    fn test_default_config_denies_shell() {
        assert!(default_checker()
            .check("run_shell", &[Permission::Shell])
            .is_err());
    }

    #[test]
    fn test_white_list_mode_denies_unlisted() {
        let checker = checker(vec![Permission::Read], vec![], false);
        assert!(checker.check("t", &[Permission::Read]).is_ok());
        assert!(checker.check("t", &[Permission::Write]).is_err());
    }

    #[test]
    fn test_white_list_mode_allows_listed() {
        let checker = checker(vec![Permission::Read, Permission::Write], vec![], false);
        assert!(checker.check("t", &[Permission::Write]).is_ok());
    }

    #[test]
    fn test_black_list_deny_priority() {
        // A permission present on both lists is still denied.
        let checker = checker(
            vec![Permission::Shell, Permission::Read],
            vec![Permission::Shell],
            false,
        );
        assert!(checker.check("t", &[Permission::Shell]).is_err());
    }

    #[test]
    fn test_empty_allowed_with_allow_unspecified() {
        let checker = checker(vec![], vec![], true);
        assert!(checker.check("t", &[Permission::Write]).is_ok());
    }

    #[test]
    fn test_empty_allowed_without_allow_unspecified() {
        let checker = checker(vec![], vec![], false);
        assert!(checker.check("t", &[Permission::Write]).is_err());
    }

    #[test]
    fn test_unspecified_tool_with_allow() {
        let checker = checker(vec![], vec![], true);
        assert!(checker.check("t", &[]).is_ok());
    }

    #[test]
    fn test_unspecified_tool_without_allow() {
        let checker = checker(vec![], vec![], false);
        assert!(checker.check("t", &[]).is_err());
    }

    #[test]
    fn test_custom_permission_collision() {
        let read = Permission::Custom("db:read".into());
        let write = Permission::Custom("db:write".into());
        let checker = checker(vec![read.clone()], vec![write.clone()], false);
        assert!(checker.check("t", &[read]).is_ok());
        assert!(checker.check("t", &[write]).is_err());
    }

    #[test]
    fn test_multi_permission_all_in_allowed() {
        let checker = checker(
            vec![Permission::Read, Permission::Write, Permission::Network],
            vec![],
            false,
        );
        assert!(checker
            .check("t", &[Permission::Read, Permission::Network])
            .is_ok());
    }

    #[test]
    fn test_multi_permission_one_not_in_allowed() {
        let checker = checker(vec![Permission::Read, Permission::Network], vec![], false);
        // One permission outside the allow list is enough to deny.
        assert!(checker
            .check("t", &[Permission::Read, Permission::Write])
            .is_err());
    }
}
|
||||
@@ -0,0 +1,371 @@
|
||||
//! 工具注册表 —— 管理工具注册、发现、调用。
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::future::join_all;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::llm::types::ToolDefinition;
|
||||
use crate::tools::base::{ToolContext, ToolRef};
|
||||
use crate::tools::error::ToolError;
|
||||
use crate::tools::permission::PermissionChecker;
|
||||
|
||||
/// 工具调用记录 —— 用于追踪和调试。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolInvocation {
|
||||
/// 被调用的工具名。
|
||||
pub tool_name: String,
|
||||
/// 工具的入参。
|
||||
pub input: Value,
|
||||
/// 工具的输出。
|
||||
pub output: Result<Value, ToolError>,
|
||||
}
|
||||
|
||||
impl ToolInvocation {
|
||||
/// 创建一个新的工具调用记录。
|
||||
pub fn new(tool_name: String, input: Value, output: Result<Value, ToolError>) -> Self {
|
||||
Self {
|
||||
tool_name,
|
||||
input,
|
||||
output,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 工具注册表 —— 管理工具注册、发现、调用。
|
||||
///
|
||||
/// 通过 `Arc` 共享,方法签名 `&self`,可安全跨 task 并行调用。
|
||||
/// 不支持运行时并发注册(应在 setup 阶段一次性构建后冻结)。
|
||||
#[derive(Clone, Default)]
|
||||
pub struct ToolRegistry {
|
||||
inner: Arc<ToolRegistryInner>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct ToolRegistryInner {
|
||||
tools: HashMap<String, ToolRef>,
|
||||
permission_checker: Option<Arc<PermissionChecker>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ToolRegistry {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ToolRegistry")
|
||||
.field("tool_names", &self.inner.tools.keys().collect::<Vec<_>>())
|
||||
.field("has_checker", &self.inner.permission_checker.is_some())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
/// 创建一个新的工具注册表。
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
inner: Arc::new(ToolRegistryInner {
|
||||
tools: HashMap::new(),
|
||||
permission_checker: None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// 设置权限检查器(Builder 模式)。
|
||||
pub fn with_permission_checker(mut self, checker: PermissionChecker) -> Self {
|
||||
let inner = Arc::make_mut(&mut self.inner);
|
||||
inner.permission_checker = Some(Arc::new(checker));
|
||||
self
|
||||
}
|
||||
|
||||
/// 注册一个工具。
|
||||
///
|
||||
/// 重复注册同名工具返回错误。
|
||||
pub fn register(&mut self, tool: ToolRef) -> Result<(), ToolError> {
|
||||
let name = tool.name().to_string();
|
||||
let inner = Arc::make_mut(&mut self.inner);
|
||||
if inner.tools.contains_key(&name) {
|
||||
return Err(ToolError::ExecutionFailed(name, "工具已存在".to_string()));
|
||||
}
|
||||
inner.tools.insert(name, tool);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 批量注册工具。
|
||||
pub fn register_all(&mut self, tools: Vec<ToolRef>) -> Result<(), ToolError> {
|
||||
for tool in tools {
|
||||
self.register(tool)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 注销一个工具。
|
||||
pub fn unregister(&mut self, name: &str) -> Option<ToolRef> {
|
||||
let inner = Arc::make_mut(&mut self.inner);
|
||||
inner.tools.remove(name)
|
||||
}
|
||||
|
||||
/// 按名称查找工具。
|
||||
pub fn get(&self, name: &str) -> Option<ToolRef> {
|
||||
self.inner.tools.get(name).cloned()
|
||||
}
|
||||
|
||||
/// 获取所有已注册工具的名称列表。
|
||||
pub fn list_tools(&self) -> Vec<String> {
|
||||
self.inner.tools.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// 获取所有工具的 `ToolDefinition` 列表(用于传递给 LLM)。
|
||||
pub fn definitions(&self) -> Vec<ToolDefinition> {
|
||||
self.inner
|
||||
.tools
|
||||
.values()
|
||||
.map(|tool| ToolDefinition {
|
||||
name: tool.name().to_string(),
|
||||
description: Some(tool.description().to_string()),
|
||||
parameters: tool.parameters(),
|
||||
strict: None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// 调用单个工具(含权限检查)。
|
||||
pub async fn invoke(&self, name: &str, args: Value) -> Result<ToolInvocation, ToolError> {
|
||||
let tool = self
|
||||
.get(name)
|
||||
.ok_or_else(|| ToolError::NotFound(name.to_string()))?;
|
||||
|
||||
if let Some(checker) = &self.inner.permission_checker {
|
||||
checker.check(name, &tool.required_permissions())?;
|
||||
}
|
||||
|
||||
let ctx = ToolContext::new(name, "");
|
||||
let output = tool.execute(args.clone(), &ctx).await;
|
||||
Ok(ToolInvocation::new(name.to_string(), args, output))
|
||||
}
|
||||
|
||||
/// 并行执行多个工具调用(互不依赖的工具)。
|
||||
///
|
||||
/// 每个工具独立超时(`timeout_per_call_secs`,0 表示不超时)。
|
||||
/// 单个工具超时不会影响其他工具的返回。
|
||||
pub async fn invoke_all(
|
||||
&self,
|
||||
calls: Vec<(String, Value)>,
|
||||
timeout_per_call_secs: u64,
|
||||
) -> Vec<ToolInvocation> {
|
||||
let this = self.clone();
|
||||
let futures = calls.into_iter().map(|(name, args)| {
|
||||
let this = this.clone();
|
||||
async move {
|
||||
match if timeout_per_call_secs == 0 {
|
||||
Ok(this.invoke(&name, args.clone()).await)
|
||||
} else {
|
||||
tokio::time::timeout(
|
||||
Duration::from_secs(timeout_per_call_secs),
|
||||
this.invoke(&name, args.clone()),
|
||||
)
|
||||
.await
|
||||
} {
|
||||
Ok(result) => result.unwrap_or_else(|e| {
|
||||
ToolInvocation::new(name.clone(), args.clone(), Err(e))
|
||||
}),
|
||||
Err(_) => ToolInvocation::new(
|
||||
name,
|
||||
args,
|
||||
Err(ToolError::McpTimeout("timeout".into())),
|
||||
),
|
||||
}
|
||||
}
|
||||
});
|
||||
join_all(futures).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Unit tests for the tool registry: registration, lookup, removal,
    //! definition export, single invocation and batched invocation.

    use super::*;
    use crate::tools::BaseTool;
    use async_trait::async_trait;
    use serde_json::json;

    /// Fixture tool named `add` that adds the integer argument `n` to a fixed `base`.
    struct AddTool {
        base: i64,
    }

    #[async_trait]
    impl BaseTool for AddTool {
        fn name(&self) -> &str {
            "add"
        }

        fn description(&self) -> &str {
            "加法"
        }

        fn parameters(&self) -> Value {
            json!({
                "type": "object",
                "properties": { "n": { "type": "integer" } },
                "required": ["n"]
            })
        }

        /// Produces `{"result": base + n}`; a missing or non-integer `n` is treated as 0.
        async fn execute(&self, args: Value, _ctx: &ToolContext<'_>) -> Result<Value, ToolError> {
            let addend = args.get("n").and_then(Value::as_i64).unwrap_or(0);
            let sum = self.base + addend;
            Ok(json!({ "result": sum }))
        }
    }

    /// Fixture tool named `fail` whose execution always returns an error.
    struct FailTool;

    #[async_trait]
    impl BaseTool for FailTool {
        fn name(&self) -> &str {
            "fail"
        }

        fn description(&self) -> &str {
            "总会失败"
        }

        fn parameters(&self) -> Value {
            json!({})
        }

        /// Always yields `ToolError::ExecutionFailed("fail", "boom")`.
        async fn execute(&self, _args: Value, _ctx: &ToolContext<'_>) -> Result<Value, ToolError> {
            let failure = ToolError::ExecutionFailed("fail".into(), "boom".into());
            Err(failure)
        }
    }

    /// Fixture tool named `shell` that declares the `Shell` permission requirement.
    struct ShellTool;

    #[async_trait]
    impl BaseTool for ShellTool {
        fn name(&self) -> &str {
            "shell"
        }

        fn description(&self) -> &str {
            "shell"
        }

        fn parameters(&self) -> Value {
            json!({})
        }

        fn required_permissions(&self) -> Vec<crate::tools::permission::Permission> {
            Vec::from([crate::tools::permission::Permission::Shell])
        }

        /// Succeeds with an empty JSON object.
        async fn execute(&self, _args: Value, _ctx: &ToolContext<'_>) -> Result<Value, ToolError> {
            Ok(json!({}))
        }
    }

    /// A registered tool can be looked up by name; an unknown name yields nothing.
    #[test]
    fn test_register_and_get() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(AddTool { base: 10 })).unwrap();

        assert!(registry.get("add").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    /// Registering a second tool under an already-used name is rejected.
    #[test]
    fn test_register_duplicate() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(AddTool { base: 0 })).unwrap();

        let second_attempt = registry.register(Arc::new(AddTool { base: 1 }));
        assert!(second_attempt.is_err());
    }

    /// Bulk registration fails when the batch contains two tools with the same name.
    #[test]
    fn test_register_all() {
        let mut registry = ToolRegistry::new();
        let outcome = registry.register_all(vec![
            Arc::new(AddTool { base: 1 }),
            Arc::new(AddTool { base: 2 }),
        ]);

        // Duplicate name `add` -> failure.
        assert!(outcome.is_err());
    }

    /// Unregistering returns the removed tool and makes it unavailable afterwards.
    #[test]
    fn test_unregister() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(AddTool { base: 0 })).unwrap();

        assert!(registry.unregister("add").is_some());
        assert!(registry.get("add").is_none());
    }

    /// `list_tools` reports the name of every registered tool.
    #[test]
    fn test_list_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(AddTool { base: 0 })).unwrap();
        registry.register(Arc::new(FailTool)).unwrap();

        let tool_names = registry.list_tools();
        assert_eq!(tool_names.len(), 2);
        for expected in ["add", "fail"] {
            assert!(tool_names.contains(&expected.to_string()));
        }
    }

    /// `definitions` exposes one entry per tool, carrying its name and a description.
    #[test]
    fn test_definitions() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(AddTool { base: 0 })).unwrap();

        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 1);

        let first = &definitions[0];
        assert_eq!(first.name, "add");
        assert!(first.description.is_some());
    }

    /// Invoking a registered tool returns its computed output.
    #[tokio::test]
    async fn test_invoke_success() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(AddTool { base: 100 })).unwrap();

        let invocation = registry.invoke("add", json!({ "n": 5 })).await.unwrap();
        let payload = invocation.output.unwrap();
        assert_eq!(payload["result"], 105);
    }

    /// Invoking an unknown tool name yields `ToolError::NotFound`.
    #[tokio::test]
    async fn test_invoke_not_found() {
        let registry = ToolRegistry::new();

        let outcome = registry.invoke("nope", json!({})).await;
        assert!(matches!(outcome, Err(ToolError::NotFound(_))));
    }

    /// A tool that fails during execution still yields an invocation record,
    /// with the error stored in its `output`.
    #[tokio::test]
    async fn test_invoke_execution_error() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(FailTool)).unwrap();

        let invocation = registry.invoke("fail", json!({})).await.unwrap();
        assert!(invocation.output.is_err());
    }

    /// With a default-configured permission checker, a tool requiring the
    /// `Shell` permission is refused with `ToolError::PermissionDenied`.
    #[tokio::test]
    async fn test_invoke_with_permission_denied() {
        let checker = PermissionChecker::new(Default::default());
        let mut registry = ToolRegistry::new().with_permission_checker(checker);
        registry.register(Arc::new(ShellTool)).unwrap();

        let outcome = registry.invoke("shell", json!({})).await;
        assert!(matches!(outcome, Err(ToolError::PermissionDenied(_, _))));
    }

    /// A batch of calls produces one result per call, in call order, and a
    /// failing call does not affect the others.
    #[tokio::test]
    async fn test_invoke_all_parallel() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(AddTool { base: 1 })).unwrap();
        registry.register(Arc::new(FailTool)).unwrap();

        let batch = vec![
            ("add".into(), json!({ "n": 1 })),
            ("add".into(), json!({ "n": 2 })),
            ("fail".into(), json!({})),
        ];
        let outcomes = registry.invoke_all(batch, 0).await;

        assert_eq!(outcomes.len(), 3);
        // The two `add` calls succeed; the trailing `fail` call reports an error.
        assert!(outcomes.iter().take(2).all(|entry| entry.output.is_ok()));
        assert!(outcomes[2].output.is_err());
    }

    /// A quick call completes successfully when a per-call timeout of 5 is supplied.
    #[tokio::test]
    async fn test_invoke_all_with_timeout() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(AddTool { base: 0 })).unwrap();

        let batch = vec![("add".into(), json!({ "n": 1 }))];
        let outcomes = registry.invoke_all(batch, 5).await;

        assert_eq!(outcomes.len(), 1);
        assert!(outcomes[0].output.is_ok());
    }
}
|
||||
Reference in New Issue
Block a user