docs(5-tool-system): 更新工具系统方案文档,完善 Phase 2 实现细节

- 在 BaseTool 中添加 ToolContext 执行上下文参数,包含 session_id、trace_id 和取消令牌
- 更新 LlmCycle 工具循环逻辑:修正消息推送顺序,新增 max_tool_turns 独立字段
- 补充 McpClient 子进程运行时状态 ChildProcessState 设计
- 添加消息压缩 maybe_compact() 方法描述
- 明确 submit_stream_with_tools() 推迟至 Phase 3 实现
- 更新实现计划中各步骤的详细变更
This commit is contained in:
徐涛
2026-06-07 10:23:36 +08:00
parent 5d6bb5e983
commit e598f6d3ee
+89 -21
View File
@@ -85,10 +85,22 @@ pub use registry::{ToolEntry, ToolInvocation, ToolRegistry};
use async_trait::async_trait; use async_trait::async_trait;
use serde_json::Value; use serde_json::Value;
use tokio_util::sync::CancellationToken;
use crate::tools::error::ToolError; use crate::tools::error::ToolError;
use crate::tools::permission::Permission; use crate::tools::permission::Permission;
/// 工具执行上下文 —— 携带每次执行的运行时信息。
/// 新增字段时提供默认值,不破坏已有工具实现。
pub struct ToolContext<'a> {
/// 当前对话/会话 ID,用于关联性追踪。
pub session_id: &'a str,
/// 链路追踪 ID,用于跨工具调用的耗时分布。
pub trace_id: &'a str,
/// 取消令牌,用于优雅取消正在执行的工具。
pub cancellation_token: CancellationToken,
}
/// 工具抽象接口 —— 所有工具(自定义或 MCP)最终都实现此 trait。 /// 工具抽象接口 —— 所有工具(自定义或 MCP)最终都实现此 trait。
#[async_trait] #[async_trait]
pub trait BaseTool: Send + Sync { pub trait BaseTool: Send + Sync {
@@ -107,7 +119,8 @@ pub trait BaseTool: Send + Sync {
} }
/// 执行工具调用。 /// 执行工具调用。
async fn execute(&self, args: Value) -> Result<Value, ToolError>; /// `ctx` 携带执行上下文(session_id、trace_id 等),Phase 3/4 可扩展字段而不破坏 trait 签名。
async fn execute(&self, args: Value, ctx: &ToolContext<'_>) -> Result<Value, ToolError>;
} }
``` ```
@@ -115,7 +128,8 @@ pub trait BaseTool: Send + Sync {
- `name()` 返回 `&str` 而非 `String`,避免每次调用克隆 - `name()` 返回 `&str` 而非 `String`,避免每次调用克隆
- `parameters()` 返回 `serde_json::Value`,与现有 `OpenaiToolDefinition.parameters` 类型一致 - `parameters()` 返回 `serde_json::Value`,与现有 `OpenaiToolDefinition.parameters` 类型一致
- `required_permissions()` 提供默认空实现,简化无敏感操作的工具定义 - `required_permissions()` 提供默认空实现,简化无敏感操作的工具定义
- `execute()` 接收 `Value`JSON 对象)作为参数,返回 `Value` 作为结果,与 OpenAI API 的 arguments/output 格式一致 - `execute()` 接收 `Value`JSON 对象)+ `ToolContext` 作为参数,返回 `Value` 作为结果,与 OpenAI API 的 arguments/output 格式一致
- `ToolContext` 从 Phase 2 即注入 `execute()` 签名,防止后续 breaking change;新增字段用 `Option` 包裹或提供默认值
### 2. ToolRegistry — 工具注册表 ### 2. ToolRegistry — 工具注册表
@@ -308,6 +322,15 @@ pub enum McpTransport {
}, },
} }
/// MCP 子进程运行时状态(connect() 后创建)。
struct ChildProcessState {
child: tokio::process::Child,
stdin: tokio::io::BufWriter<tokio::process::ChildStdin>,
/// 等待响应的请求映射(id → oneshot sender)。
pending: HashMap<u64, tokio::sync::oneshot::Sender<Result<Value, ToolError>>>,
next_id: u64,
}
/// MCP 客户端 —— 与 MCP 服务器通信。 /// MCP 客户端 —— 与 MCP 服务器通信。
pub struct McpClient { pub struct McpClient {
transport: McpTransport, transport: McpTransport,
@@ -318,6 +341,8 @@ pub struct McpClient {
initialized: AtomicBool, initialized: AtomicBool,
/// 超时时间(秒)。 /// 超时时间(秒)。
timeout_secs: u64, timeout_secs: u64,
/// 子进程运行时状态(connect() 后创建,close() 后取回)。
process: Option<tokio::sync::Mutex<ChildProcessState>>,
} }
/// MCP 服务器暴露的工具(缓存结构)。 /// MCP 服务器暴露的工具(缓存结构)。
@@ -335,15 +360,19 @@ impl McpClient {
pub fn with_timeout(mut self, secs: u64) -> Self; pub fn with_timeout(mut self, secs: u64) -> Self;
/// 连接并初始化(发送 initialize 请求,获取服务器能力声明)。 /// 连接并初始化(发送 initialize 请求,获取服务器能力声明)。
/// 启动子进程,创建 ChildProcessState(含 reader task)。
pub async fn connect(&mut self) -> Result<(), ToolError>; pub async fn connect(&mut self) -> Result<(), ToolError>;
/// 列出服务器支持的工具(调用 tools/list)。 /// 列出服务器支持的工具(调用 tools/list)。
pub async fn list_tools(&mut self) -> Result<Vec<ToolDefinition>, ToolError>; pub async fn list_tools(&mut self) -> Result<Vec<ToolDefinition>, ToolError>;
/// 调用一个工具(调用 tools/call)。 /// 调用一个工具(调用 tools/call)。
/// 通过 Mutex 获取 stdin 写入权限,发送 JSON-RPC 请求,通过 id 匹配响应。
/// reader task 持续读取 stdout,解析 JSON-RPC 响应,通过 oneshot 通知调用方。
pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, ToolError>; pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, ToolError>;
/// 关闭连接(终止子进程)。 /// 关闭连接(终止子进程)。
/// 发送 shutdown → 等待 5s 优雅退出 → 超时则 child.kill()。
pub async fn close(&mut self) -> Result<(), ToolError>; pub async fn close(&mut self) -> Result<(), ToolError>;
/// 将 MCP 客户端转换为 BaseTool 适配器列表(用于注册到 ToolRegistry)。 /// 将 MCP 客户端转换为 BaseTool 适配器列表(用于注册到 ToolRegistry)。
@@ -447,18 +476,22 @@ impl LlmCycle {
/// 流程: /// 流程:
/// 1. 发送请求(含工具定义) /// 1. 发送请求(含工具定义)
/// 2. 检查响应中的 finish_reason /// 2. 检查响应中的 finish_reason
/// 3. 如果是 ToolCalls → 执行工具 → 回传结果 → 重复 1 /// 3. 如果是 ToolCalls → 先 push Assistant 消息 → 执行工具 → 回传结果 → 重复 1
/// 4. 如果是 Stop/Length → 返回最终响应 /// 4. 如果是 Stop/Length → push Assistant 消息 → 返回最终响应
///
/// 注意:OpenAI API 要求 tool 消息必须紧跟在对应的 Assistanttool_calls)消息之后。
/// 因此 push 工具结果前必须先 push Assistant 响应,否则 API 拒绝请求。
pub async fn submit_with_tools( pub async fn submit_with_tools(
&mut self, &mut self,
prompt: String, prompt: String,
registry: &ToolRegistry, registry: &ToolRegistry,
) -> Result<ChatResponse, LlmError> { ) -> Result<ChatResponse, LlmError> {
let tools = registry.definitions(); let tools = registry.definitions();
let max_turns = self.config.max_turns.unwrap_or(10); // 注:CycleConfig.max_turns 默认值为 None,实现时需修改 Default 为 Some(10) let max_turns = self.config.max_tool_turns.unwrap_or(10);
let mut turn = 0; let mut turn = 0;
self.messages.push(OpenaiChatMessage::user_text(prompt)); self.messages.push(OpenaiChatMessage::user_text(prompt));
self.maybe_compact();
loop { loop {
turn += 1; turn += 1;
@@ -475,9 +508,12 @@ impl LlmCycle {
// 检查是否需要执行工具 // 检查是否需要执行工具
let should_execute = matches!( let should_execute = matches!(
response.stop_reason, response.stop_reason,
Some(FinishReason::ToolCalls) | None Some(FinishReason::ToolCalls)
) && has_tool_calls(&response.message); ) && has_tool_calls(&response.message);
// 将 Assistant 响应(含 tool_calls 或最终文本)追加到消息历史
self.messages.push(response.message.clone());
if !should_execute { if !should_execute {
return Ok(response); return Ok(response);
} }
@@ -490,7 +526,10 @@ impl LlmCycle {
for result in results { for result in results {
let content = match &result.output { let content = match &result.output {
Ok(value) => serde_json::to_string(value) Ok(value) => serde_json::to_string(value)
.unwrap_or_else(|_| "{}".to_string()), .unwrap_or_else(|e| {
tracing::warn!("工具结果序列化失败: {}", e);
"{}".to_string()
}),
Err(e) => format!("错误: {}", e), Err(e) => format!("错误: {}", e),
}; };
@@ -498,6 +537,21 @@ impl LlmCycle {
OpenaiChatMessage::tool_result(result.tool_name.clone(), content) OpenaiChatMessage::tool_result(result.tool_name.clone(), content)
); );
} }
// 每轮工具执行后触发 compaction,防止 token 快速膨胀
self.maybe_compact();
}
}
/// 在接近上下文窗口时压缩历史消息。
fn maybe_compact(&mut self) {
if let Some(ref config) = self.compact_config
&& should_compact(&self.messages, config, &self.compact_state)
{
let freed = microcompact(&mut self.messages, config.keep_recent);
if freed > 0 {
self.compact_state.record_success();
}
} }
} }
@@ -516,7 +570,7 @@ impl LlmCycle {
| 决策 | 选择 | 理由 | | 决策 | 选择 | 理由 |
|------|------|------| |------|------|------|
| 循环方式 | 同步循环(单线程串行) | 工具执行依赖前一轮结果,串行更安全 | | 循环方式 | 同步循环(单线程串行) | 工具执行依赖前一轮结果,串行更安全 |
| 最大轮次 | `CycleConfig.max_turns`,默认 `Some(10)` | 防止无限循环(LLM 反复调用工具)。**注意**:当前 `CycleConfig` 默认值为 `None`,实现时需将 `Default` 改为 `Some(10)` | | 最大轮次 | `CycleConfig.max_tool_turns`,独立于 `max_turns`,默认 `Some(10)` | 防止无限循环(LLM 反复调用工具)。使用独立字段避免影响现有 `submit()`/`submit_messages()` `max_turns` 语义 |
| 工具并行 | `invoke_all()` 互不依赖的工具并行 | LLM 可能一次发出多个 tool_callsparallel_tool_calls | | 工具并行 | `invoke_all()` 互不依赖的工具并行 | LLM 可能一次发出多个 tool_callsparallel_tool_calls |
| 工具超时 | `CycleConfig::tool_timeout_secs`,默认 60 | 防止单个工具长时间阻塞循环。`invoke_all()` 使用 `tokio::time::timeout` 包装 | | 工具超时 | `CycleConfig::tool_timeout_secs`,默认 60 | 防止单个工具长时间阻塞循环。`invoke_all()` 使用 `tokio::time::timeout` 包装 |
| 错误处理 | 工具执行错误以文本回传 LLM,而非终止循环 | LLM 可自行从错误中恢复 | | 错误处理 | 工具执行错误以文本回传 LLM,而非终止循环 | LLM 可自行从错误中恢复 |
@@ -580,6 +634,10 @@ impl LlmCycle {
`submit_stream()` 的增强方案:新增 `submit_stream_with_tools()`,在流式事件层面支持自动 tool 循环。 `submit_stream()` 的增强方案:新增 `submit_stream_with_tools()`,在流式事件层面支持自动 tool 循环。
> **实现复杂度提示**:流式 tool 循环需要自定义 `Stream` 实现 + 内部状态机(`Streaming` → `ExecutingTools` → `Finished`)。每一轮需要:消费当前流 → 收集事件 → 检测 `TurnComplete(ToolCalls)` → 执行工具 → 发射 `ToolExecutionCompleted` → 发起新流 → 继续 yield。不能用简单的 `stream!` 宏实现。
>
> 建议 Phaes 3 再实现 `submit_stream_with_tools()`Phase 2 只实现非流式的 `submit_with_tools()`。如果 Phase 2 需要可先返回 "not yet implemented" 错误。
```rust ```rust
impl LlmCycle { impl LlmCycle {
pub async fn submit_stream_with_tools( pub async fn submit_stream_with_tools(
@@ -830,36 +888,46 @@ Phase 4Agent + Skill + 编排)
- 创建 `src/tools/registry.rs` - 创建 `src/tools/registry.rs`
- 定义 `ToolInvocation` 结构体 + `ToolEntry` 元数据包装(tool + tags + category + stats+ `ToolRegistry` - 定义 `ToolInvocation` 结构体 + `ToolEntry` 元数据包装(tool + tags + category + stats+ `ToolRegistry`
- 实现核心方法:register / get / list / definitions / invoke / invoke_all / find_by_tag / find_by_category - 实现核心方法:register / get / list / definitions / invoke / invoke_all / find_by_tag / find_by_category
- `invoke_all()` 使用 `futures::future::join_all` 并行执行互不依赖的工具 - `invoke_all()` 使用 `futures::future::join_all` + `tokio::time::timeout` 并行执行互不依赖的工具(每工具独立超时)
- `definitions()``HashMap` 中的工具转换为 `Vec<ToolDefinition>` - `definitions()``HashMap` 中的工具转换为 `Vec<ToolDefinition>`
- `ToolRegistry` 不支持运行时并发注册(setup 阶段一次性构建),如需热注册由调用方通过 `Arc<RwLock<ToolRegistry>>` 包装
- 编写 8+ 测试覆盖:注册冲突、空注册表查找、单次调用、批量并行调用、工具执行失败 - 编写 8+ 测试覆盖:注册冲突、空注册表查找、单次调用、批量并行调用、工具执行失败
- 运行 `cargo test` 验证 - 运行 `cargo test` 验证
### Step 6: LlmCycle 扩展(自动 Tool 循环) ### Step 6: LlmCycle 扩展(自动 Tool 循环)
- 新增 `cycle_submit.rs` 子模块(或直接在 `cycle.rs` 中扩增,取决于代码量) - 新增 `cycle_submit.rs` 子模块(或直接在 `cycle.rs` 中扩增,取决于代码量)
- 提取 `submit_request()` 内部方法(将 submit() 中的 request→response 逻辑独立) - 提取 `submit_request()` 内部方法(将 submit() 中的 request→response 逻辑独立),同时重构 `submit_messages()` 以复用同一路径
- 实现 `submit_with_tools()` 方法: - 实现 `submit_with_tools()` 方法:
- 循环:submit_request → 检查 finish_reason → 调用 registry.invoke_all → 回传结果 - 循环:submit_request → push Assistant 消息 → 检查 finish_reason → 调用 registry.invoke_all → push tool_results → 重复
- `max_turns` 控制,达到上限返回错误 - 在 push tool_results **之前**先 push Assistanttool_calls)消息(OpenAI API 要求)
- 工具执行错误以文本回传(LLM 可恢复) - `max_tool_turns` 控制(独立于 `max_turns`),达到上限返回错误
- 实现 `submit_stream_with_tools()` 方法: - 不可恢复的错误(NotFound、PermissionDenied、McpError)终止循环
- 组合流式事件流和自动 tool 循环 - 可恢复的错误(ExecutionFailed、InvalidArguments)以文本回传 LLM
- 在 TurnComplete(ToolCalls) 后发射 ToolExecutionCompleted - 每轮执行后触发 `maybe_compact()` 防止 token 膨胀
- 更新 `CycleConfig` 文档注释,新增 `tool_timeout_secs` 字段,默认值 60 - `submit_stream_with_tools()` 方法:
-`CycleConfig::max_turns` 默认值由 `None` 改为 `Some(10)` - Phase 2 标记为未实现(返回 `LlmError::Other("流式 tool 循环将在后续版本中支持")`
- 编写 3+ 集成测试:单轮 tool 调用、多轮 tool 调用、达到 max_turns 终止 - 实际实现推迟到 Phase 3(需要自定义 `ToolStream` 状态机)
- 更新 `CycleConfig`
- 新增 `max_tool_turns: Option<u32>`,默认 `Some(10)`(不影响 `max_turns` 语义)
- 新增 `tool_timeout_secs: u64`,默认值 60
- 新增 `max_tool_result_bytes: Option<usize>`,默认 `Some(65536)`(限制单次工具结果大小)
- 编写 3+ 集成测试:单轮 tool 调用、多轮 tool 调用、达到 max_tool_turns 终止
- 运行 `cargo test` 验证 - 运行 `cargo test` 验证
### Step 7: McpClientMCP 协议客户端) ### Step 7: McpClientMCP 协议客户端)
- 创建 `src/tools/mcp.rs` - 创建 `src/tools/mcp.rs`
- 实现 JSON-RPC 消息结构(Request / Response / Error / Notification - 实现 JSON-RPC 消息结构(Request / Response / Error / Notification
- 定义 `ChildProcessState` 结构体,包含运行时字段:`child`/`stdin`/`pending: HashMap<u64, oneshot::Sender>`/`next_id: u64`
- reader task 使用 `tokio::select!` 同时监听 stdout 和 cancellation token
- `call_tool()` 通过 Mutex 获取 stdin 写入权限,通过 id 匹配响应
- 子进程意外退出时通知所有 pending 请求
- 实现 stdio transport - 实现 stdio transport
- `connect()`:启动子进程,发送 initialize 请求 - `connect()`:启动子进程,创建 ChildProcessState发送 initialize 请求
- `list_tools()`:调用 tools/list,缓存结果 - `list_tools()`:调用 tools/list,缓存结果
- `call_tool()`:调用 tools/call,解析响应 - `call_tool()`:调用 tools/call,解析响应
- `close()`:发送 shutdown 请求,终止子进程 - `close()`:发送 shutdown → 等待 5s 优雅退出 → 超时则 child.kill()
- `StreamableHttp` transport 预留枚举变体,当前返回 "not implemented" 错误,不在 Phase 2 实现 - `StreamableHttp` transport 预留枚举变体,当前返回 "not implemented" 错误,不在 Phase 2 实现
- 实现 `into_tools()`:将 MCP 工具转换为 `Vec<Arc<dyn BaseTool>>` 适配器 - 实现 `into_tools()`:将 MCP 工具转换为 `Vec<Arc<dyn BaseTool>>` 适配器
- 设置 30 秒默认超时 - 设置 30 秒默认超时