feat(tools): 添加工具系统框架与 MCP 协议客户端

This commit is contained in:
徐涛
2026-06-07 10:57:15 +08:00
parent e598f6d3ee
commit b6e7acfb0f
9 changed files with 2034 additions and 1 deletions
+1
View File
@@ -18,6 +18,7 @@ futures-util = "0.3"
futures-core = "0.3"
bytes = "1"
async-stream = "0.3"
tokio-util = { version = "0.7", features = ["rt"] }
[dev-dependencies]
dotenvy = "0.15.7"
+1
View File
@@ -2,6 +2,7 @@
pub mod llm;
pub mod prompt;
pub mod tools;
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
+465 -1
View File
@@ -19,7 +19,8 @@ use crate::llm::hooks::{HookContext, HookExecutor};
use crate::llm::provider::LlmProvider;
use crate::llm::stream::StreamEvent;
use crate::llm::types::{
ChatRequest, ChatResponse, OpenaiChatMessage, OpenaiTool, ToolChoice, ToolDefinition,
ChatRequest, ChatResponse, FinishReason, OpenaiChatMessage, OpenaiTool, OpenaiToolCall,
ToolChoice, ToolDefinition,
};
/// LLM 调用周期配置。
@@ -34,6 +35,15 @@ pub struct CycleConfig {
pub max_turns: Option<u32>,
/// 重试策略配置。
pub retry: RetryConfig,
/// 自动 tool 循环的最大轮次(独立于 `max_turns`,避免影响现有 `submit()` 语义)。
/// 默认 `Some(10)`,防止 LLM 反复调用工具导致无限循环。
pub max_tool_turns: Option<u32>,
/// 单个工具执行的超时秒数(0 表示不超时)。
/// 默认 60 秒。
pub tool_timeout_secs: u64,
/// 单个工具结果的最大字节数(超过此值将被截断)。
/// 默认 65536(64KB),防止大结果导致 token 膨胀。
pub max_tool_result_bytes: usize,
}
impl Default for CycleConfig {
@@ -44,6 +54,9 @@ impl Default for CycleConfig {
temperature: None,
max_turns: None,
retry: RetryConfig::default(),
max_tool_turns: Some(10),
tool_timeout_secs: 60,
max_tool_result_bytes: 65_536,
}
}
}
@@ -440,4 +453,455 @@ impl LlmCycle {
..Default::default()
}
}
/// Sends one chat request built from the current history, retrying retryable failures.
///
/// This method pushes neither a user message nor the assistant response onto the history;
/// `submit_with_tools()` calls it once per round of its tool loop and manages the history
/// itself.
///
/// On success the response's usage is added to `self.usage` and the response is returned.
///
/// # Errors
/// - `LlmError::Other` when a pre-request hook reports `should_block`.
/// - The provider's error once it is not retryable (per `should_retry`) or
///   `retry.max_retries` attempts have been used.
async fn submit_request(
&mut self,
tools: &[ToolDefinition],
) -> Result<ChatResponse, LlmError> {
let mut attempts = 0;
loop {
// The request is rebuilt on every attempt because `provider.chat` consumes it.
let request = self.build_request(tools);
if let Some(ref executor) = self.hook_executor {
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest)
.with_request(&request);
let results = executor
.execute(crate::llm::hooks::HookEvent::PreRequest, &ctx)
.await;
// Any blocking hook aborts the call; the first blocking hook's reason is reported.
if results.iter().any(|r| r.should_block) {
let reason = results
.iter()
.find(|r| r.should_block)
.and_then(|r| r.reason.clone())
.unwrap_or_else(|| "Blocked by pre-request hook".to_string());
return Err(LlmError::Other(reason));
}
}
match self.provider.chat(request).await {
Ok(response) => {
if let Some(ref executor) = self.hook_executor {
// `request` was moved into `chat`, so an equivalent one is rebuilt for the hook.
let post_request = self.build_request(tools);
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PostRequest)
.with_request(&post_request);
executor
.execute(crate::llm::hooks::HookEvent::PostRequest, &ctx)
.await;
}
self.usage.add(&response.usage);
return Ok(response);
}
// Retryable error with attempts left: notify hooks, back off, then loop again.
Err(e) if should_retry(&e) && attempts < self.config.retry.max_retries => {
attempts += 1;
if let Some(ref executor) = self.hook_executor {
let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnRetry)
.with_error(&e)
.with_attempt(attempts);
executor
.execute(crate::llm::hooks::HookEvent::OnRetry, &ctx)
.await;
}
let delay = self.config.retry.compute_delay(attempts);
tokio::time::sleep(delay).await;
}
// Non-retryable error, or retries exhausted: notify hooks and give up.
Err(e) => {
if let Some(ref executor) = self.hook_executor {
let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnError)
.with_error(&e);
executor
.execute(crate::llm::hooks::HookEvent::OnError, &ctx)
.await;
}
return Err(e);
}
}
}
}
/// Submits a user prompt and drives the automatic tool-call loop until a final answer.
///
/// Flow:
/// 1. Send a request that carries the registry's tool definitions.
/// 2. Inspect the response's `stop_reason`.
/// 3. If it is `ToolCalls` and the message really contains tool calls: push the assistant
///    message, run the tools, push one tool-result message per invocation, go back to 1.
/// 4. Otherwise: push the assistant message and return the response.
///
/// The OpenAI API requires tool messages to directly follow the assistant message that
/// carries the matching `tool_calls`, which is why the assistant response is pushed onto the
/// history before any tool result.
///
/// # Errors
/// - `LlmError::Other` when more than `max_tool_turns` requests would be needed
///   (falls back to 10 when the option is `None`).
/// - Any error returned by `submit_request`.
/// - `LlmError::Other` when a tool fails with an error that is not recoverable; recoverable
///   tool errors are sent back to the model as text instead.
pub async fn submit_with_tools(
    &mut self,
    prompt: String,
    registry: &crate::tools::ToolRegistry,
) -> Result<ChatResponse, LlmError> {
    let tools = registry.definitions();
    let max_turns = self.config.max_tool_turns.unwrap_or(10);
    let tool_timeout = self.config.tool_timeout_secs;
    let max_bytes = self.config.max_tool_result_bytes;

    self.messages.push(OpenaiChatMessage::user_text(prompt));
    self.maybe_compact();

    let mut turn = 0;
    // The loop has no `break`: every exit is an explicit `return`, so its type is `!` and no
    // trailing fallback expression is needed.
    loop {
        turn += 1;
        if turn > max_turns {
            return Err(LlmError::Other(format!(
                "达到最大工具循环轮次 ({max_turns})"
            )));
        }

        let response = self.submit_request(&tools).await?;

        // Tools run only when the provider says so AND the message actually lists calls.
        let should_execute = matches!(response.stop_reason, Some(FinishReason::ToolCalls))
            && has_tool_calls_in_message(&response.message);

        // The assistant message (tool calls or final text) always goes into the history first.
        self.messages.push(response.message.clone());

        if !should_execute {
            return Ok(response);
        }

        // NOTE(review): the tool_call id is discarded here and the tool message below is
        // built from `result.tool_name`. Confirm `OpenaiChatMessage::tool_result` does not
        // expect the tool_call id as its first argument.
        let calls: Vec<(String, serde_json::Value)> =
            extract_tool_calls_from_message(&response.message)
                .into_iter()
                .map(|(_id, name, raw_args)| {
                    let args = serde_json::from_str::<serde_json::Value>(&raw_args)
                        .unwrap_or_else(|e| {
                            // Malformed arguments still fall back to `null`, but no longer
                            // silently: the tool will most likely reject them.
                            tracing::warn!("工具 '{}' 的参数不是合法 JSON,按 null 处理: {}", name, e);
                            serde_json::Value::Null
                        });
                    (name, args)
                })
                .collect();

        let results = registry.invoke_all(calls, tool_timeout).await;

        // Feed every tool outcome back to the model.
        for result in results {
            let content = match result.output {
                Ok(value) => {
                    let serialized = serde_json::to_string(&value).unwrap_or_else(|e| {
                        tracing::warn!("工具结果序列化失败: {}", e);
                        "{}".to_string()
                    });
                    truncate_tool_result(&serialized, max_bytes)
                }
                // Recoverable: let the model see the error text and decide what to do.
                Err(e) if e.is_recoverable() => format!("错误: {}", e),
                // Not recoverable: abort the whole loop.
                Err(e) => {
                    return Err(LlmError::Other(format!(
                        "工具 '{}' 不可恢复错误: {}",
                        result.tool_name, e
                    )));
                }
            };
            self.messages
                .push(OpenaiChatMessage::tool_result(result.tool_name, content));
        }

        // Tool output can be large, so give compaction a chance after every round.
        self.maybe_compact();
    }
}
/// 在接近上下文窗口时压缩历史消息。
fn maybe_compact(&mut self) {
if let Some(ref config) = self.compact_config
&& should_compact(&self.messages, config, &self.compact_state)
{
let freed = microcompact(&mut self.messages, config.keep_recent);
if freed > 0 {
self.compact_state.record_success();
}
}
}
}
/// Returns `true` when `msg` is an assistant message carrying at least one tool call.
fn has_tool_calls_in_message(msg: &OpenaiChatMessage) -> bool {
    match msg {
        OpenaiChatMessage::Assistant {
            tool_calls: Some(calls),
            ..
        } => !calls.is_empty(),
        _ => false,
    }
}
/// Collects the tool calls of an assistant message.
///
/// Returns one `(tool_call_id, tool_name, arguments_json_string)` tuple per call, in message
/// order. Any message that is not an assistant message with `tool_calls` yields an empty list.
fn extract_tool_calls_from_message(
    msg: &OpenaiChatMessage,
) -> Vec<(String, String, String)> {
    let OpenaiChatMessage::Assistant {
        tool_calls: Some(calls),
        ..
    } = msg
    else {
        return Vec::new();
    };

    let mut extracted = Vec::with_capacity(calls.len());
    for call in calls {
        let triple = match call {
            OpenaiToolCall::Function { id, function } => {
                (id.clone(), function.name.clone(), function.arguments.clone())
            }
        };
        extracted.push(triple);
    }
    extracted
}
/// Limits a tool result to at most `max_bytes` bytes of payload.
///
/// Input that already fits is returned unchanged. Otherwise the text is cut at the closest
/// UTF-8 character boundary at or below `max_bytes`, and a marker stating the original size
/// is appended (so the returned string may be longer than `max_bytes`).
fn truncate_tool_result(s: &str, max_bytes: usize) -> String {
    let original_len = s.len();
    if original_len <= max_bytes {
        return s.to_owned();
    }
    // Offset 0 is always a character boundary, so the search cannot come up empty.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    let kept = &s[..cut];
    format!("{kept}\n\n[... truncated, original size: {original_len} bytes ...]")
}
#[cfg(test)]
mod tests {
use super::*;
use crate::llm::types::{ContentField, OpenaiContentPart};
use crate::tools::{BaseTool, ToolRegistry};
use async_trait::async_trait;
use serde_json::{json, Value};
/// Mock provider that replays a predefined list of responses, one per `chat` call, in order.
struct MockProvider {
// Remaining scripted responses; the front element is returned next.
responses: std::sync::Mutex<Vec<ChatResponse>>,
// Number of `chat` calls received so far.
call_count: std::sync::Mutex<u32>,
}
impl MockProvider {
/// Creates a mock that will return `responses` front to back.
fn new(responses: Vec<ChatResponse>) -> Self {
Self {
responses: std::sync::Mutex::new(responses),
call_count: std::sync::Mutex::new(0),
}
}
}
#[async_trait]
impl LlmProvider for MockProvider {
/// Ignores the request and pops the next scripted response; errors once the script is empty.
async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, LlmError> {
let mut count = self.call_count.lock().unwrap();
*count += 1;
let mut responses = self.responses.lock().unwrap();
if responses.is_empty() {
return Err(LlmError::Other("no more mock responses".into()));
}
Ok(responses.remove(0))
}
}
/// Returns a default-constructed `Usage` for mock responses.
fn empty_usage() -> crate::llm::types::Usage {
crate::llm::types::Usage::default()
}
/// Builds a final (non-tool) assistant response with `FinishReason::Stop`.
fn assistant_text_response(text: &str) -> ChatResponse {
ChatResponse {
message: OpenaiChatMessage::assistant_text(text),
usage: empty_usage(),
stop_reason: Some(FinishReason::Stop),
}
}
/// Builds an assistant response that requests tool calls (`FinishReason::ToolCalls`).
///
/// Each tuple in `calls` is `(tool_call_id, tool_name, arguments_json)`.
fn assistant_tool_call_response(
calls: Vec<(&str, &str, &str)>,
) -> ChatResponse {
use crate::llm::types::{OpenaiToolCall, FunctionCall};
let tool_calls: Vec<OpenaiToolCall> = calls
.into_iter()
.map(|(id, name, args)| OpenaiToolCall::Function {
id: id.to_string(),
function: FunctionCall {
name: name.to_string(),
arguments: args.to_string(),
},
})
.collect();
ChatResponse {
message: OpenaiChatMessage::Assistant {
// The message text is a single empty text part; only the tool calls matter here.
content: ContentField::Array(vec![OpenaiContentPart::Text {
text: String::new(),
}]),
refusal: None,
name: None,
tool_calls: Some(tool_calls),
},
usage: empty_usage(),
stop_reason: Some(FinishReason::ToolCalls),
}
}
/// Test tool that adds two integers, `a` and `b`.
struct AddTool;
#[async_trait]
impl BaseTool for AddTool {
fn name(&self) -> &str {
"add"
}
fn description(&self) -> &str {
"加法"
}
fn parameters(&self) -> Value {
json!({"type":"object","properties":{"a":{"type":"integer"},"b":{"type":"integer"}}})
}
/// Returns `{"result": a + b}`; a missing or non-integer operand counts as 0.
async fn execute(
&self,
args: Value,
_ctx: &crate::tools::ToolContext<'_>,
) -> Result<Value, crate::tools::ToolError> {
let a = args["a"].as_i64().unwrap_or(0);
let b = args["b"].as_i64().unwrap_or(0);
Ok(json!({"result": a + b}))
}
}
/// One tool round followed by a final text answer.
#[tokio::test]
async fn test_submit_with_tools_single_turn() {
// First response: a tool call; second response: the final text.
let responses = vec![
assistant_tool_call_response(vec![("call_1", "add", r#"{"a":1,"b":2}"#)]),
assistant_text_response("答案是 3"),
];
let provider = Box::new(MockProvider::new(responses));
let mut cycle = LlmCycle::new(provider, CycleConfig::default());
let mut registry = ToolRegistry::new();
registry.register(std::sync::Arc::new(AddTool)).unwrap();
let response = cycle
.submit_with_tools("1+2=?".to_string(), &registry)
.await
.unwrap();
// The returned response must be an assistant message.
assert!(matches!(
response.message,
OpenaiChatMessage::Assistant { .. }
));
// Expected history: user, assistant(tool_calls), tool, assistant(text).
let messages = cycle.messages();
assert_eq!(messages.len(), 4);
assert!(matches!(messages[0], OpenaiChatMessage::User { .. }));
assert!(matches!(messages[1], OpenaiChatMessage::Assistant { .. }));
assert!(matches!(messages[2], OpenaiChatMessage::Tool { .. }));
assert!(matches!(messages[3], OpenaiChatMessage::Assistant { .. }));
}
/// Three consecutive tool rounds followed by a final text answer.
#[tokio::test]
async fn test_submit_with_tools_multi_turn() {
// Three tool-call responses, then the final answer.
let responses = vec![
assistant_tool_call_response(vec![("call_1", "add", r#"{"a":1,"b":2}"#)]),
assistant_tool_call_response(vec![("call_2", "add", r#"{"a":3,"b":4}"#)]),
assistant_tool_call_response(vec![("call_3", "add", r#"{"a":5,"b":6}"#)]),
assistant_text_response("完成"),
];
let provider = Box::new(MockProvider::new(responses));
let mut cycle = LlmCycle::new(provider, CycleConfig::default());
let mut registry = ToolRegistry::new();
registry.register(std::sync::Arc::new(AddTool)).unwrap();
let response = cycle
.submit_with_tools("计算总和".to_string(), &registry)
.await
.unwrap();
assert!(matches!(
response.message,
OpenaiChatMessage::Assistant { .. }
));
// user + 3 * (assistant + tool) + final assistant = 8
let messages = cycle.messages();
assert_eq!(messages.len(), 8);
}
/// The loop must stop with an error once `max_tool_turns` requests have been used.
#[tokio::test]
async fn test_submit_with_tools_max_turns_exceeded() {
// Allow only two requests.
let mut config = CycleConfig::default();
config.max_tool_turns = Some(2);
// Three tool-call responses plus a final answer: more rounds than the limit allows.
let responses = vec![
assistant_tool_call_response(vec![("c1", "add", r#"{"a":1,"b":1}"#)]),
assistant_tool_call_response(vec![("c2", "add", r#"{"a":1,"b":1}"#)]),
assistant_tool_call_response(vec![("c3", "add", r#"{"a":1,"b":1}"#)]),
assistant_text_response("完成"),
];
let provider = Box::new(MockProvider::new(responses));
let mut cycle = LlmCycle::new(provider, config);
let mut registry = ToolRegistry::new();
registry.register(std::sync::Arc::new(AddTool)).unwrap();
let result = cycle
.submit_with_tools("test".to_string(), &registry)
.await;
assert!(matches!(result, Err(LlmError::Other(msg)) if msg.contains("达到最大工具循环轮次")));
}
/// A response without tool calls is returned directly.
#[tokio::test]
async fn test_submit_with_tools_no_tool_call_response() {
// The model answers right away without calling any tool.
let responses = vec![assistant_text_response("直接回答")];
let provider = Box::new(MockProvider::new(responses));
let mut cycle = LlmCycle::new(provider, CycleConfig::default());
let mut registry = ToolRegistry::new();
registry.register(std::sync::Arc::new(AddTool)).unwrap();
let response = cycle
.submit_with_tools("直接回答".to_string(), &registry)
.await
.unwrap();
assert!(matches!(
response.message,
OpenaiChatMessage::Assistant { .. }
));
}
/// Input shorter than the limit comes back unchanged.
#[test]
fn test_truncate_tool_result_short() {
let s = "short text";
assert_eq!(truncate_tool_result(s, 100), "short text");
}
/// Input longer than the limit is shortened and carries the truncation marker.
#[test]
fn test_truncate_tool_result_long() {
let s = "a".repeat(1000);
let truncated = truncate_tool_result(&s, 50);
assert!(truncated.len() < s.len());
assert!(truncated.contains("[... truncated,"));
}
/// Truncation must never split a multi-byte character.
#[test]
fn test_truncate_tool_result_chinese_chars() {
    // U+4E2D is 3 bytes in UTF-8, so byte offset 50 falls inside the 17th character and the
    // cut has to back off to 48 bytes (16 whole characters).
    let s = "\u{4E2D}".repeat(100);
    let truncated = truncate_tool_result(&s, 50);
    let kept = "\u{4E2D}".repeat(16);
    assert!(truncated.starts_with(&kept));
    // The marker follows immediately: no partial 17th character was kept.
    assert!(truncated[kept.len()..].starts_with("\n\n[... truncated,"));
}
}
+13
View File
@@ -0,0 +1,13 @@
//! 工具系统 —— 工具抽象、注册、调用、权限控制与 MCP 集成。
pub mod base;
pub mod error;
pub mod mcp;
pub mod permission;
pub mod registry;
pub use base::{BaseTool, ToolContext, ToolRef};
pub use error::ToolError;
pub use mcp::{McpClient, McpTransport};
pub use permission::{Permission, PermissionChecker, PermissionConfig};
pub use registry::{ToolInvocation, ToolRegistry};
+139
View File
@@ -0,0 +1,139 @@
//! 工具抽象接口与执行上下文。
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::Value;
use tokio_util::sync::CancellationToken;
use crate::tools::error::ToolError;
use crate::tools::permission::Permission;
/// Execution context passed to every tool invocation.
///
/// The context is part of the `execute()` signature from the start so that later phases can
/// extend it (the original notes name `progress` and `shared_state` as candidates) without
/// changing the signature that tool implementations use.
#[derive(Debug)]
pub struct ToolContext<'a> {
/// ID of the current conversation/session, used to correlate calls.
pub session_id: &'a str,
/// Trace ID, used to follow timing across tool calls.
pub trace_id: &'a str,
/// Cancellation token that lets a running tool be cancelled gracefully.
pub cancellation_token: CancellationToken,
}
impl<'a> ToolContext<'a> {
    /// Creates a context with a fresh cancellation token that has not been cancelled.
    pub fn new(session_id: &'a str, trace_id: &'a str) -> Self {
        Self::with_cancellation_token(session_id, trace_id, CancellationToken::new())
    }

    /// Creates a context that observes the caller-supplied cancellation `token`.
    pub fn with_cancellation_token(
        session_id: &'a str,
        trace_id: &'a str,
        token: CancellationToken,
    ) -> Self {
        let cancellation_token = token;
        Self {
            cancellation_token,
            trace_id,
            session_id,
        }
    }
}
/// Tool abstraction: every tool, custom or MCP-backed, implements this trait.
#[async_trait]
pub trait BaseTool: Send + Sync {
/// Tool name: the unique identifier matched against the name in the model's tool calls.
fn name(&self) -> &str;
/// Tool description: the model uses it to decide whether to call the tool.
fn description(&self) -> &str;
/// Parameter definition as a JSON Schema value, sent to the model as the tool's parameters.
fn parameters(&self) -> Value;
/// Permissions the tool declares it needs; defaults to none.
fn required_permissions(&self) -> Vec<Permission> {
Vec::new()
}
/// Executes the tool with the given arguments.
///
/// `ctx` carries the execution context (`session_id`, `trace_id`, `cancellation_token`);
/// implementations may poll `ctx.cancellation_token` while running to support graceful
/// cancellation.
async fn execute(&self, args: Value, ctx: &ToolContext<'_>) -> Result<Value, ToolError>;
}
/// Shared handle to a tool trait object: a plain alias for `Arc<dyn BaseTool>`.
///
/// `BaseTool` requires `Send + Sync`, so handles can be stored in collections such as
/// `Vec<ToolRef>` and moved across threads.
pub type ToolRef = Arc<dyn BaseTool>;
// Unit tests for `BaseTool` defaults and `ToolContext` construction.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
/// Minimal tool that returns its arguments unchanged.
struct EchoTool;
#[async_trait]
impl BaseTool for EchoTool {
fn name(&self) -> &str {
"echo"
}
fn description(&self) -> &str {
"回显输入"
}
fn parameters(&self) -> Value {
json!({
"type": "object",
"properties": {
"text": { "type": "string" }
},
"required": ["text"]
})
}
async fn execute(&self, args: Value, _ctx: &ToolContext<'_>) -> Result<Value, ToolError> {
Ok(args)
}
}
/// `execute` hands the arguments back untouched.
#[tokio::test]
async fn test_mock_tool_execute() {
let tool = EchoTool;
let ctx = ToolContext::new("session-1", "trace-1");
let result = tool.execute(json!({"text": "hello"}), &ctx).await.unwrap();
assert_eq!(result, json!({"text": "hello"}));
}
/// The default `required_permissions` implementation declares nothing.
#[tokio::test]
async fn test_default_permissions_empty() {
let tool = EchoTool;
assert!(tool.required_permissions().is_empty());
}
/// `ToolContext::new` stores both IDs and starts with a token that is not cancelled.
#[test]
fn test_tool_context_creation() {
let ctx = ToolContext::new("s1", "t1");
assert_eq!(ctx.session_id, "s1");
assert_eq!(ctx.trace_id, "t1");
assert!(!ctx.cancellation_token.is_cancelled());
}
/// A token cancelled before construction is observed as cancelled through the context.
#[test]
fn test_tool_context_cancellation() {
let token = CancellationToken::new();
token.cancel();
let ctx = ToolContext::with_cancellation_token("s1", "t1", token);
assert!(ctx.cancellation_token.is_cancelled());
}
}
+118
View File
@@ -0,0 +1,118 @@
//! 工具系统错误类型。
use std::sync::Arc;
/// Every error that can occur while invoking a tool.
///
/// `is_recoverable()` decides, per variant, whether the error text is fed back to the model
/// or the automatic tool loop is terminated.
#[derive(thiserror::Error, Debug, Clone)]
pub enum ToolError {
/// The tool is not registered. Payload: tool name.
#[error("工具 '{0}' 未注册")]
NotFound(String),
/// Tool execution failed (recoverable: reported back to the model as text).
/// Payload: tool name, failure detail.
#[error("工具 '{0}' 执行失败: {1}")]
ExecutionFailed(String, String),
/// Tool arguments are invalid (recoverable: reported back to the model as text).
/// Payload: tool name, detail.
#[error("工具 '{0}' 参数无效: {1}")]
InvalidArguments(String, String),
/// Permission denied (not recoverable: terminates the loop).
/// Payload: tool name, missing permission.
#[error("权限被拒绝: 工具 '{0}' 需要 {1} 权限")]
PermissionDenied(String, String),
/// MCP protocol error (not recoverable).
#[error("MCP 协议错误: {0}")]
McpError(String),
/// The MCP client has not been initialized (not recoverable).
#[error("MCP 未初始化: {0}")]
McpNotInitialized(String),
/// An MCP request timed out (not recoverable).
#[error("MCP 超时: {0}")]
McpTimeout(String),
/// I/O error (not recoverable). Wrapped in `Arc` so the enum can derive `Clone`.
#[error("IO 错误: {0}")]
Io(Arc<std::io::Error>),
/// The tool execution was cancelled.
#[error("工具执行已取消: {0}")]
Cancelled(String),
/// Any other, unclassified error.
#[error("其他错误: {0}")]
Other(String),
}
impl From<std::io::Error> for ToolError {
fn from(e: std::io::Error) -> Self {
ToolError::Io(Arc::new(e))
}
}
impl ToolError {
    /// Tells whether the automatic tool loop may continue after this error.
    ///
    /// Recoverable errors are sent back to the model, which can retry on its own;
    /// all others terminate the loop and are returned to the caller.
    pub fn is_recoverable(&self) -> bool {
        match self {
            Self::ExecutionFailed(..) | Self::InvalidArguments(..) | Self::Other(_) => true,
            Self::NotFound(_)
            | Self::PermissionDenied(..)
            | Self::McpError(_)
            | Self::McpNotInitialized(_)
            | Self::McpTimeout(_)
            | Self::Io(_)
            | Self::Cancelled(_) => false,
        }
    }
}
// Pins the recoverable / non-recoverable classification of `ToolError` variants.
#[cfg(test)]
mod tests {
use super::*;
/// `ExecutionFailed` is fed back to the model.
#[test]
fn test_execution_failed_is_recoverable() {
let err = ToolError::ExecutionFailed("foo".into(), "boom".into());
assert!(err.is_recoverable());
}
/// `InvalidArguments` is fed back to the model.
#[test]
fn test_invalid_arguments_is_recoverable() {
let err = ToolError::InvalidArguments("foo".into(), "missing x".into());
assert!(err.is_recoverable());
}
/// `NotFound` terminates the loop.
#[test]
fn test_not_found_is_not_recoverable() {
let err = ToolError::NotFound("foo".into());
assert!(!err.is_recoverable());
}
/// `PermissionDenied` terminates the loop.
#[test]
fn test_permission_denied_is_not_recoverable() {
let err = ToolError::PermissionDenied("foo".into(), "Shell".into());
assert!(!err.is_recoverable());
}
/// `McpError` terminates the loop.
#[test]
fn test_mcp_error_is_not_recoverable() {
let err = ToolError::McpError("protocol".into());
assert!(!err.is_recoverable());
}
/// `McpTimeout` terminates the loop.
#[test]
fn test_mcp_timeout_is_not_recoverable() {
let err = ToolError::McpTimeout("foo".into());
assert!(!err.is_recoverable());
}
/// I/O errors converted through `From` terminate the loop.
#[test]
fn test_io_is_not_recoverable() {
let io_err = std::io::Error::new(std::io::ErrorKind::Other, "disk");
let err = ToolError::from(io_err);
assert!(!err.is_recoverable());
}
/// `Other` is fed back to the model.
#[test]
fn test_other_is_recoverable() {
let err = ToolError::Other("something".into());
assert!(err.is_recoverable());
}
}
+640
View File
@@ -0,0 +1,640 @@
//! MCP 协议客户端 —— 与 MCP Server 通过 JSON-RPC over stdio 通信。
//!
//! 当前 Phase 2 实现 stdio transport。`StreamableHttp` 枚举变体已预留,
//! 但实际实现推迟到后续版本。
//!
//! ## 协议版本
//!
//! 实现遵循 MCP 协议版本 2025-03-26。
use std::collections::HashMap;
use std::process::Stdio;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use tokio::sync::{oneshot, Mutex};
use crate::llm::types::ToolDefinition;
use crate::tools::base::{BaseTool, ToolContext, ToolRef};
use crate::tools::error::ToolError;
/// MCP protocol version sent as `protocolVersion` in the `initialize` request.
const MCP_VERSION: &str = "2025-03-26";
/// How the client talks to an MCP server.
#[derive(Debug, Clone)]
pub enum McpTransport {
/// Communicate through a child process's stdin/stdout.
Stdio {
/// Command to launch (e.g. `"npx"`).
command: String,
/// Command arguments (e.g. `["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]`).
args: Vec<String>,
},
/// Streamable HTTP transport (introduced by MCP 2025-03-26, replacing the deprecated
/// HTTP+SSE transport).
///
/// Reserved for now: `connect()` returns `ToolError::McpError` for this variant.
StreamableHttp {
/// MCP endpoint URL.
url: String,
/// Optional HTTP headers (e.g. Authorization).
headers: Option<Vec<(String, String)>>,
},
}
/// JSON-RPC request envelope.
#[derive(Debug, Serialize, Deserialize)]
struct JsonRpcRequest {
// Protocol marker; `new` always sets it to "2.0".
jsonrpc: &'static str,
// Request id, used to match the response to its waiter.
id: u64,
method: String,
// Left out of the serialized JSON entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<Value>,
}
impl JsonRpcRequest {
/// Builds a JSON-RPC 2.0 request with the given id, method and optional params.
fn new(id: u64, method: impl Into<String>, params: Option<Value>) -> Self {
Self {
jsonrpc: "2.0",
id,
method: method.into(),
params,
}
}
}
/// JSON-RPC response envelope.
///
/// `id` is required, so messages without a numeric id (e.g. server notifications) do not
/// deserialize into this type.
#[derive(Debug, Serialize, Deserialize)]
struct JsonRpcResponse {
jsonrpc: String,
id: u64,
// Present on success; defaults to `None` when the field is missing.
#[serde(default)]
result: Option<Value>,
// Present on failure; defaults to `None` when the field is missing.
#[serde(default)]
error: Option<JsonRpcError>,
}
/// Error object carried by a failed JSON-RPC response.
#[derive(Debug, Serialize, Deserialize)]
struct JsonRpcError {
// Numeric error code reported by the server.
code: i32,
// Human-readable error message.
message: String,
// Optional extra payload; `None` when the field is missing.
#[serde(default)]
data: Option<Value>,
}
/// Runtime state of the MCP child process.
struct ChildProcessState {
// Handle of the spawned server process.
child: Child,
// Write end used to send requests and notifications.
stdin: ChildStdin,
// Waiters for in-flight requests, keyed by request id.
pending: HashMap<u64, oneshot::Sender<Result<Value, ToolError>>>,
// Last request id handed out (0 means none yet).
next_id: u64,
}
impl ChildProcessState {
/// Allocates the next request id; the counter is incremented first, so ids start at 1.
fn next_id(&mut self) -> u64 {
self.next_id += 1;
self.next_id
}
}
/// A tool exposed by the MCP server, as cached by `list_tools`.
#[derive(Debug, Clone)]
struct McpTool {
name: String,
// `None` when the server did not provide a string description.
description: Option<String>,
// The tool's `inputSchema` (JSON Schema).
input_schema: Value,
}
/// MCP client: talks to a single MCP server.
pub struct McpClient {
transport: McpTransport,
server_name: String,
/// Tool list cached by the last `list_tools` call.
tools: Vec<McpTool>,
/// Whether the initialization handshake has completed.
initialized: AtomicBool,
/// Per-request timeout in seconds.
timeout_secs: u64,
/// Child-process runtime state (created by `connect()`, taken back by `close()`).
process: Option<Arc<Mutex<ChildProcessState>>>,
}
/// Manual `Debug` that prints only the server name, the initialized flag and the number of
/// cached tools; the transport and process state are left out.
impl std::fmt::Debug for McpClient {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("McpClient")
.field("server_name", &self.server_name)
.field("initialized", &self.initialized.load(Ordering::SeqCst))
.field("tool_count", &self.tools.len())
.finish()
}
}
impl McpClient {
/// Creates a disconnected client for the server called `server_name`.
///
/// The client starts with no cached tools, no child process and a 30-second request timeout.
pub fn new(server_name: impl Into<String>, transport: McpTransport) -> Self {
    let server_name = server_name.into();
    Self {
        server_name,
        transport,
        timeout_secs: 30,
        initialized: AtomicBool::new(false),
        process: None,
        tools: Vec::new(),
    }
}
/// 设置超时时间(秒)。
pub fn with_timeout(mut self, secs: u64) -> Self {
self.timeout_secs = secs;
self
}
/// Returns whether the initialization handshake has completed (and `close()` has not reset it).
pub fn is_initialized(&self) -> bool {
self.initialized.load(Ordering::SeqCst)
}
/// Connects to the server and performs the MCP initialization handshake.
///
/// For the stdio transport this spawns the configured command with piped stdin/stdout,
/// starts a background task that reads responses from stdout, sends the `initialize`
/// request and then the `notifications/initialized` notification (which carries no id).
/// Calling it on an already-initialized client is a no-op.
///
/// # Errors
/// - `ToolError::McpError` when the transport is `StreamableHttp` (not implemented yet),
///   the process cannot be spawned, or its stdin/stdout handles are unavailable.
/// - Any error from the handshake itself. In that case the spawned process is killed and
///   the client is left disconnected, so `connect()` can be retried without leaking it.
pub async fn connect(&mut self) -> Result<(), ToolError> {
    if self.is_initialized() {
        return Ok(());
    }

    match &self.transport {
        McpTransport::Stdio { command, args } => {
            let mut cmd = Command::new(command);
            cmd.args(args)
                .stdin(Stdio::piped())
                .stdout(Stdio::piped())
                // Nothing in this client reads the server's stderr. A piped but unread
                // stderr would block the server once the pipe buffer fills, so discard it.
                .stderr(Stdio::null())
                // Make sure a dropped handle never leaves an orphaned server behind.
                .kill_on_drop(true);
            #[cfg(windows)]
            cmd.creation_flags(0x08000000); // CREATE_NO_WINDOW

            let mut child = cmd
                .spawn()
                .map_err(|e| ToolError::McpError(format!("启动 MCP 子进程失败: {e}")))?;
            let stdin = child
                .stdin
                .take()
                .ok_or_else(|| ToolError::McpError("无法获取子进程 stdin".into()))?;
            let stdout = child
                .stdout
                .take()
                .ok_or_else(|| ToolError::McpError("无法获取子进程 stdout".into()))?;

            let state = Arc::new(Mutex::new(ChildProcessState {
                child,
                stdin,
                pending: HashMap::new(),
                next_id: 0,
            }));

            // Background reader: dispatches every response line to its waiting request.
            let reader_state = Arc::clone(&state);
            tokio::spawn(async move {
                Self::read_loop(BufReader::new(stdout), reader_state).await;
            });

            self.process = Some(state);
        }
        McpTransport::StreamableHttp { .. } => {
            return Err(ToolError::McpError(
                "StreamableHttp transport 尚未实现".into(),
            ));
        }
    }

    // Handshake: `initialize` request, then the `notifications/initialized` notification.
    let init_params = json!({
        "protocolVersion": MCP_VERSION,
        "capabilities": {},
        "clientInfo": {
            "name": "agcore",
            "version": env!("CARGO_PKG_VERSION")
        }
    });
    let handshake = match self.send_request("initialize", Some(init_params)).await {
        Ok(_response) => {
            self.send_notification("notifications/initialized", Some(json!({})))
                .await
        }
        Err(e) => Err(e),
    };

    if let Err(e) = handshake {
        // The reader task keeps the process state alive, so simply dropping our handle
        // would not stop the server. Kill it explicitly; the reader then sees EOF and exits.
        if let Some(state) = self.process.take() {
            let _ = state.lock().await.child.kill().await;
        }
        return Err(e);
    }

    self.initialized.store(true, Ordering::SeqCst);
    Ok(())
}
/// 列出服务器支持的工具(调用 `tools/list`)。
pub async fn list_tools(&mut self) -> Result<Vec<ToolDefinition>, ToolError> {
if !self.is_initialized() {
return Err(ToolError::McpNotInitialized(self.server_name.clone()));
}
let response = self.send_request("tools/list", None).await?;
let tools_value = response
.get("tools")
.ok_or_else(|| ToolError::McpError("tools/list 响应缺少 tools 字段".into()))?;
let tools_arr = tools_value
.as_array()
.ok_or_else(|| ToolError::McpError("tools/list 响应 tools 字段不是数组".into()))?;
self.tools.clear();
let mut defs = Vec::with_capacity(tools_arr.len());
for tool in tools_arr {
let name = tool
.get("name")
.and_then(|v| v.as_str())
.ok_or_else(|| ToolError::McpError("工具缺少 name 字段".into()))?
.to_string();
let description = tool
.get("description")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let input_schema = tool
.get("inputSchema")
.cloned()
.unwrap_or_else(|| json!({"type": "object", "properties": {}}));
self.tools.push(McpTool {
name: name.clone(),
description: description.clone(),
input_schema: input_schema.clone(),
});
defs.push(ToolDefinition {
name,
description,
parameters: input_schema,
strict: None,
});
}
Ok(defs)
}
/// Invokes a tool on the server (calls `tools/call`).
///
/// All `text` items of the result's `content` array are joined with newlines. When the
/// result is flagged `isError: true` the call fails with `ToolError::ExecutionFailed`,
/// carrying that text (MCP reports tool-level failures inside a normal result, usually
/// together with text content describing the problem). Otherwise the joined text is
/// returned as a JSON string, or the full raw result when there is no text content.
///
/// # Errors
/// - `ToolError::McpNotInitialized` when `connect()` has not completed.
/// - `ToolError::ExecutionFailed` when the server flags the result as an error.
/// - Any error from sending the request.
pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, ToolError> {
    if !self.is_initialized() {
        return Err(ToolError::McpNotInitialized(self.server_name.clone()));
    }

    let params = json!({
        "name": name,
        "arguments": args,
    });
    let response = self.send_request("tools/call", Some(params)).await?;

    // Gather every text item of `content`, separated by newlines.
    let mut combined = String::new();
    if let Some(content) = response.get("content").and_then(Value::as_array) {
        for text in content
            .iter()
            .filter_map(|item| item.get("text").and_then(Value::as_str))
        {
            if !combined.is_empty() {
                combined.push('\n');
            }
            combined.push_str(text);
        }
    }

    // The error flag must be checked before the text is returned: an error result
    // normally carries text too, and would otherwise be reported as a success.
    let is_error = response
        .get("isError")
        .and_then(Value::as_bool)
        .unwrap_or(false);
    if is_error {
        let detail = if combined.is_empty() {
            "MCP 工具返回 isError=true".to_string()
        } else {
            combined
        };
        return Err(ToolError::ExecutionFailed(name.to_string(), detail));
    }

    if !combined.is_empty() {
        return Ok(Value::String(combined));
    }
    // Fallback: no text content, hand back the whole result.
    Ok(response)
}
/// Closes the connection and terminates the child process.
///
/// Works whenever a process handle exists, including one left behind by a handshake that
/// failed before the client became initialized. The process gets 5 seconds to exit on its
/// own and is killed afterwards. Calling it on a client without a process is a no-op.
///
/// NOTE(review): the child's stdin stays open while waiting, so a server that only exits
/// on stdin EOF will always run into the 5-second kill. Confirm whether stdin should be
/// closed first.
pub async fn close(&mut self) -> Result<(), ToolError> {
    if self.process.is_none() {
        return Ok(());
    }

    if self.is_initialized() {
        // Best-effort courtesy message; a failure here must not prevent the cleanup below.
        let _ = self.send_notification("shutdown", None).await;
    }

    if let Some(state) = self.process.take() {
        let mut guard = state.lock().await;
        // Give the server 5 seconds to exit by itself.
        let graceful =
            tokio::time::timeout(Duration::from_secs(5), guard.child.wait()).await;
        if graceful.is_err() {
            // Timed out: force-kill.
            let _ = guard.child.kill().await;
        }
    }

    self.initialized.store(false, Ordering::SeqCst);
    self.tools.clear();
    Ok(())
}
/// Consumes the client and turns its cached tool list into `BaseTool` adapters that can be
/// registered in a `ToolRegistry`.
///
/// Each adapter is a snapshot of one cached tool (name, description, schema) taken at
/// conversion time. The adapters are created with `McpClientHandle::Empty`: they hold no
/// connection to the client, and a missing description becomes an empty string.
pub fn into_tools(self) -> Vec<ToolRef> {
    self.tools
        .into_iter()
        .map(|cached| {
            let adapter: ToolRef = Arc::new(McpToolAdapter {
                client: McpClientHandle::Empty,
                name: cached.name,
                description: cached.description.unwrap_or_default(),
                parameters: cached.input_schema,
            });
            adapter
        })
        .collect()
}
async fn send_request(
&self,
method: &str,
params: Option<Value>,
) -> Result<Value, ToolError> {
let state_arc = self
.process
.as_ref()
.ok_or_else(|| ToolError::McpNotInitialized(self.server_name.clone()))?
.clone();
let (id, request_json) = {
let mut state = state_arc.lock().await;
let id = state.next_id();
let req = JsonRpcRequest::new(id, method, params);
let json = serde_json::to_string(&req)
.map_err(|e| ToolError::McpError(format!("序列化请求失败: {e}")))?;
(id, json)
};
// 注册 oneshot 等待响应
let (tx, rx) = oneshot::channel();
{
let mut state = state_arc.lock().await;
state.pending.insert(id, tx);
}
// 写入请求
{
let mut state = state_arc.lock().await;
state
.stdin
.write_all(request_json.as_bytes())
.await
.map_err(|e| ToolError::McpError(format!("写入请求失败: {e}")))?;
state
.stdin
.write_all(b"\n")
.await
.map_err(|e| ToolError::McpError(format!("写入换行失败: {e}")))?;
state.stdin.flush().await.map_err(|e| {
ToolError::McpError(format!("flush stdin 失败: {e}"))
})?;
}
// 等待响应(带超时)
tokio::time::timeout(Duration::from_secs(self.timeout_secs), rx)
.await
.map_err(|_| {
// 超时:清理 pending
let state_arc = state_arc.clone();
tokio::spawn(async move {
let mut state = state_arc.lock().await;
state.pending.remove(&id);
});
ToolError::McpTimeout(method.to_string())
})?
.map_err(|_| ToolError::McpError("response channel 关闭".into()))?
}
async fn send_notification(
&self,
method: &str,
params: Option<Value>,
) -> Result<(), ToolError> {
let state_arc = self
.process
.as_ref()
.ok_or_else(|| ToolError::McpNotInitialized(self.server_name.clone()))?
.clone();
let notification = json!({
"jsonrpc": "2.0",
"method": method,
"params": params,
});
let json = serde_json::to_string(&notification)
.map_err(|e| ToolError::McpError(format!("序列化通知失败: {e}")))?;
let mut state = state_arc.lock().await;
state
.stdin
.write_all(json.as_bytes())
.await
.map_err(|e| ToolError::McpError(format!("写入通知失败: {e}")))?;
state
.stdin
.write_all(b"\n")
.await
.map_err(|e| ToolError::McpError(format!("写入换行失败: {e}")))?;
state
.stdin
.flush()
.await
.map_err(|e| ToolError::McpError(format!("flush stdin 失败: {e}")))?;
Ok(())
}
/// Reads the child's stdout line by line and hands each response to the waiter registered
/// under its id.
///
/// The loop ends on EOF or on a read error; in both cases every still-pending request is
/// failed so that no caller waits for its full timeout.
async fn read_loop(
mut reader: BufReader<ChildStdout>,
state: Arc<Mutex<ChildProcessState>>,
) {
// One buffer reused for every line.
let mut line = String::new();
loop {
line.clear();
match reader.read_line(&mut line).await {
Ok(0) => {
// EOF: fail every pending request.
let mut state = state.lock().await;
for (_, tx) in state.pending.drain() {
let _ = tx.send(Err(ToolError::McpError("子进程退出".into())));
}
break;
}
Ok(_) => {
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
// Try to parse the line as a JSON-RPC response.
let parsed: Result<JsonRpcResponse, _> = serde_json::from_str(trimmed);
if let Ok(response) = parsed {
// A JSON-RPC error object becomes `McpError`; a missing result becomes `Null`.
let value = if let Some(err) = response.error {
Err(ToolError::McpError(format!(
"[{}] {}",
err.code, err.message
)))
} else {
Ok(response.result.unwrap_or(Value::Null))
};
let mut state = state.lock().await;
// A response whose id has no waiter (e.g. already timed out) is dropped.
if let Some(tx) = state.pending.remove(&response.id) {
let _ = tx.send(value);
}
}
// Anything that is not a response (notifications, requests from the server) is ignored.
}
Err(e) => {
tracing::warn!("MCP read_loop error: {e}");
let mut state = state.lock().await;
for (_, tx) in state.pending.drain() {
let _ = tx.send(Err(ToolError::McpError(format!("读取失败: {e}"))));
}
break;
}
}
}
}
}
/// MCP tool adapter: exposes a cached MCP tool description through the `BaseTool` trait.
struct McpToolAdapter {
/// Placeholder for a handle to the owning client.
///
/// Only `McpClientHandle::Empty` exists today, so the adapter holds no client at all.
/// The field is unused and kept to show the intended extension path (a shared
/// `Arc<McpClient>`).
#[allow(dead_code)]
client: McpClientHandle,
name: String,
description: String,
parameters: Value,
}
/// Handle from an adapter back to its MCP client.
///
/// Only the empty placeholder exists for now; a shared-client variant is planned.
#[allow(dead_code)]
enum McpClientHandle {
// No client attached.
Empty,
// Future: Shared(Arc<McpClient>),
}
/// `BaseTool` view of a cached MCP tool; metadata is served from the snapshot.
#[async_trait]
impl BaseTool for McpToolAdapter {
fn name(&self) -> &str {
&self.name
}
fn description(&self) -> &str {
&self.description
}
fn parameters(&self) -> Value {
self.parameters.clone()
}
/// Always fails with `ToolError::McpError`: the adapter has no live MCP connection.
async fn execute(
&self,
_args: Value,
_ctx: &ToolContext<'_>,
) -> Result<Value, ToolError> {
// Simplification of the current phase: the adapter does not hold an active connection.
// A complete implementation would hold an Arc<McpClient> and call `call_tool()` on it.
// Until then an error tells the caller to invoke the MCP tool some other way.
Err(ToolError::McpError(format!(
"MCP 工具 '{}' 需要活跃的 McpClient 引用(当前 Phase 2 简化实现)",
self.name
)))
}
}
// Unit tests for the MCP client's data types; none of them spawns a real server.
#[cfg(test)]
mod tests {
use super::*;
/// The derived `Debug` output of a transport contains its command.
#[test]
fn test_transport_debug() {
let transport = McpTransport::Stdio {
command: "echo".to_string(),
args: vec!["hello".to_string()],
};
let formatted = format!("{transport:?}");
assert!(formatted.contains("echo"));
}
/// A new client keeps its name, honours `with_timeout` and starts uninitialized.
#[test]
fn test_client_creation() {
let transport = McpTransport::Stdio {
command: "test".to_string(),
args: vec![],
};
let client = McpClient::new("test-server", transport).with_timeout(60);
assert_eq!(client.server_name, "test-server");
assert_eq!(client.timeout_secs, 60);
assert!(!client.is_initialized());
}
/// Requests serialize with the version marker, id and method.
#[test]
fn test_jsonrpc_request_serialize() {
let req = JsonRpcRequest::new(42, "test", Some(json!({"a": 1})));
let s = serde_json::to_string(&req).unwrap();
assert!(s.contains("\"jsonrpc\":\"2.0\""));
assert!(s.contains("\"id\":42"));
assert!(s.contains("\"method\":\"test\""));
}
/// A success response parses with `result` set and `error` empty.
#[test]
fn test_jsonrpc_response_parse_ok() {
let s = r#"{"jsonrpc":"2.0","id":1,"result":{"foo":"bar"}}"#;
let resp: JsonRpcResponse = serde_json::from_str(s).unwrap();
assert_eq!(resp.id, 1);
assert!(resp.result.is_some());
assert!(resp.error.is_none());
}
/// An error response parses with `error` set and `result` empty.
#[test]
fn test_jsonrpc_response_parse_error() {
let s =
r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32601,"message":"Method not found"}}"#;
let resp: JsonRpcResponse = serde_json::from_str(s).unwrap();
assert_eq!(resp.id, 1);
assert!(resp.result.is_none());
let err = resp.error.unwrap();
assert_eq!(err.code, -32601);
}
/// Connecting over the reserved HTTP transport fails with `McpError`.
#[tokio::test]
async fn test_streamable_http_not_implemented() {
let mut client = McpClient::new(
"http-server",
McpTransport::StreamableHttp {
url: "https://example.com/mcp".to_string(),
headers: None,
},
);
let result = client.connect().await;
// The transport is not implemented yet, so an error is expected.
assert!(result.is_err());
assert!(matches!(result, Err(ToolError::McpError(_))));
}
}
+286
View File
@@ -0,0 +1,286 @@
//! 工具权限管理。
use crate::tools::error::ToolError;
/// Permission levels a tool can declare.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Permission {
    /// Read-only access (reading files, querying a database, ...).
    Read,
    /// Write access (creating or modifying files, inserting data, ...).
    Write,
    /// Deletion (removing files, records, ...).
    Delete,
    /// Network access (HTTP requests, ...).
    Network,
    /// Shell command execution.
    Shell,
    /// File-system operations other than read, write and delete.
    FileSystem,
    /// Custom permission, extensible through a namespaced string.
    Custom(String),
}

impl std::fmt::Display for Permission {
    /// Writes the variant name; `Custom` is written as `Custom(<inner string>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only `Custom` needs formatting; every other variant is a fixed label.
        let label = match self {
            Self::Custom(tag) => return write!(f, "Custom({})", tag),
            Self::Read => "Read",
            Self::Write => "Write",
            Self::Delete => "Delete",
            Self::Network => "Network",
            Self::Shell => "Shell",
            Self::FileSystem => "FileSystem",
        };
        f.write_str(label)
    }
}
/// Permission configuration consumed by `PermissionChecker`.
#[derive(Debug, Clone)]
pub struct PermissionConfig {
    /// Whitelist of permitted permissions.
    ///
    /// When non-empty, every permission a tool declares must appear here.
    /// When empty, no whitelist is applied and the decision falls back to
    /// `allow_unspecified`.
    pub allowed: Vec<Permission>,
    /// Blacklist of refused permissions; takes precedence over `allowed`.
    pub denied: Vec<Permission>,
    /// Fallback decision used when a tool declares no permissions at all, and
    /// also when `allowed` is empty.
    pub allow_unspecified: bool,
}
impl Default for PermissionConfig {
fn default() -> Self {
Self {
allowed: vec![Permission::Read, Permission::Network],
denied: vec![Permission::Delete, Permission::Shell],
allow_unspecified: true,
}
}
}
/// Permission checker: decides whether a tool may run given the permissions it declares.
#[derive(Debug, Clone)]
pub struct PermissionChecker {
    /// Policy the checker evaluates declared permissions against.
    config: PermissionConfig,
}
impl PermissionChecker {
/// 创建一个新的权限检查器。
pub fn new(config: PermissionConfig) -> Self {
Self { config }
}
/// 检查指定工具声明的权限是否允许执行。
///
/// 判定规则:
/// 1. 任一权限在 `denied` 中 → 拒绝
/// 2. 所有权限都在 `allowed` 中 → 允许
/// 3. `allowed` 非空且存在未声明权限 → 拒绝
/// 4. `allowed` 为空 → 按 `allow_unspecified` 判定
/// 5. 工具未声明任何权限时按 `allow_unspecified` 判定
pub fn check(&self, tool_name: &str, permissions: &[Permission]) -> Result<(), ToolError> {
// 任一权限在 denied 中 → 拒绝
for perm in permissions {
if self.config.denied.contains(perm) {
return Err(ToolError::PermissionDenied(
tool_name.to_string(),
perm.to_string(),
));
}
}
// 工具未声明任何权限
if permissions.is_empty() {
return if self.config.allow_unspecified {
Ok(())
} else {
Err(ToolError::PermissionDenied(
tool_name.to_string(),
"Unspecified".to_string(),
))
};
}
// allowed 为空 → 走 allow_unspecified 兜底
if self.config.allowed.is_empty() {
return if self.config.allow_unspecified {
Ok(())
} else {
Err(ToolError::PermissionDenied(
tool_name.to_string(),
"Unspecified".to_string(),
))
};
}
// allowed 非空(白名单模式)—— 所有权限必须在其中
for perm in permissions {
if !self.config.allowed.contains(perm) {
return Err(ToolError::PermissionDenied(
tool_name.to_string(),
perm.to_string(),
));
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a single permission in a `Vec`.
    fn p(perm: Permission) -> Vec<Permission> {
        vec![perm]
    }

    /// Checker built from `PermissionConfig::default()`.
    fn default_checker() -> PermissionChecker {
        PermissionChecker::new(PermissionConfig::default())
    }

    /// Checker built from the three configuration fields.
    fn checker_with(
        allowed: Vec<Permission>,
        denied: Vec<Permission>,
        allow_unspecified: bool,
    ) -> PermissionChecker {
        PermissionChecker::new(PermissionConfig {
            allowed,
            denied,
            allow_unspecified,
        })
    }

    #[test]
    fn test_default_config_allows_read() {
        let outcome = default_checker().check("weather", &p(Permission::Read));
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_default_config_allows_network() {
        let outcome = default_checker().check("http_get", &p(Permission::Network));
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_default_config_denies_delete() {
        let outcome = default_checker().check("rm_file", &p(Permission::Delete));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_default_config_denies_shell() {
        let outcome = default_checker().check("run_shell", &p(Permission::Shell));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_white_list_mode_denies_unlisted() {
        let checker = checker_with(vec![Permission::Read], vec![], false);
        assert!(checker.check("t", &p(Permission::Read)).is_ok());
        assert!(checker.check("t", &p(Permission::Write)).is_err());
    }

    #[test]
    fn test_white_list_mode_allows_listed() {
        let checker = checker_with(vec![Permission::Read, Permission::Write], vec![], false);
        assert!(checker.check("t", &p(Permission::Write)).is_ok());
    }

    #[test]
    fn test_black_list_deny_priority() {
        // A permission present in both lists is still refused: `denied` wins.
        let checker = checker_with(
            vec![Permission::Shell, Permission::Read],
            vec![Permission::Shell],
            false,
        );
        assert!(checker.check("t", &p(Permission::Shell)).is_err());
    }

    #[test]
    fn test_empty_allowed_with_allow_unspecified() {
        let checker = checker_with(vec![], vec![], true);
        assert!(checker.check("t", &p(Permission::Write)).is_ok());
    }

    #[test]
    fn test_empty_allowed_without_allow_unspecified() {
        let checker = checker_with(vec![], vec![], false);
        assert!(checker.check("t", &p(Permission::Write)).is_err());
    }

    #[test]
    fn test_unspecified_tool_with_allow() {
        let checker = checker_with(vec![], vec![], true);
        assert!(checker.check("t", &[]).is_ok());
    }

    #[test]
    fn test_unspecified_tool_without_allow() {
        let checker = checker_with(vec![], vec![], false);
        assert!(checker.check("t", &[]).is_err());
    }

    #[test]
    fn test_custom_permission_collision() {
        let db_read = Permission::Custom("db:read".into());
        let db_write = Permission::Custom("db:write".into());
        let checker = checker_with(vec![db_read.clone()], vec![db_write.clone()], false);
        assert!(checker.check("t", &[db_read]).is_ok());
        assert!(checker.check("t", &[db_write]).is_err());
    }

    #[test]
    fn test_multi_permission_all_in_allowed() {
        let checker = checker_with(
            vec![Permission::Read, Permission::Write, Permission::Network],
            vec![],
            false,
        );
        let declared = [Permission::Read, Permission::Network];
        assert!(checker.check("t", &declared).is_ok());
    }

    #[test]
    fn test_multi_permission_one_not_in_allowed() {
        let checker = checker_with(vec![Permission::Read, Permission::Network], vec![], false);
        // One permission outside the whitelist is enough to refuse the tool.
        let declared = [Permission::Read, Permission::Write];
        assert!(checker.check("t", &declared).is_err());
    }
}
+371
View File
@@ -0,0 +1,371 @@
//! 工具注册表 —— 管理工具注册、发现、调用。
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use futures::future::join_all;
use serde_json::Value;
use crate::llm::types::ToolDefinition;
use crate::tools::base::{ToolContext, ToolRef};
use crate::tools::error::ToolError;
use crate::tools::permission::PermissionChecker;
/// Record of a single tool call, kept for tracing and debugging.
#[derive(Debug, Clone)]
pub struct ToolInvocation {
    /// Name of the tool that was called.
    pub tool_name: String,
    /// Arguments the tool was called with.
    pub input: Value,
    /// What the tool produced: its value on success, or the error it failed with.
    pub output: Result<Value, ToolError>,
}
impl ToolInvocation {
/// 创建一个新的工具调用记录。
pub fn new(tool_name: String, input: Value, output: Result<Value, ToolError>) -> Self {
Self {
tool_name,
input,
output,
}
}
}
/// Tool registry: manages tool registration, discovery and invocation.
///
/// The state lives behind an `Arc` and the lookup/invocation methods take
/// `&self`, so the registry is intended to be shared and called in parallel
/// across tasks. Concurrent registration at runtime is not supported: build
/// the registry once during setup and treat it as frozen afterwards.
#[derive(Clone, Default)]
pub struct ToolRegistry {
    /// Shared registry state; cloning the registry clones only this `Arc`.
    inner: Arc<ToolRegistryInner>,
}
/// State shared by every clone of a `ToolRegistry`.
#[derive(Clone, Default)]
struct ToolRegistryInner {
    /// Registered tools, keyed by tool name.
    tools: HashMap<String, ToolRef>,
    /// Optional checker consulted before a tool is executed; `None` means no check.
    permission_checker: Option<Arc<PermissionChecker>>,
}
impl std::fmt::Debug for ToolRegistry {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ToolRegistry")
.field("tool_names", &self.inner.tools.keys().collect::<Vec<_>>())
.field("has_checker", &self.inner.permission_checker.is_some())
.finish()
}
}
impl ToolRegistry {
    /// Creates an empty registry: no tools and no permission checker.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the permission checker (builder style) and returns the registry.
    ///
    /// The state is updated through `Arc::make_mut`, so if this registry has
    /// already been cloned the change applies to a private copy of the state
    /// and is not seen by the other clones.
    pub fn with_permission_checker(mut self, checker: PermissionChecker) -> Self {
        Arc::make_mut(&mut self.inner).permission_checker = Some(Arc::new(checker));
        self
    }

    /// Registers a tool under the name reported by `tool.name()`.
    ///
    /// # Errors
    ///
    /// Returns `ToolError::ExecutionFailed` if a tool with the same name is
    /// already registered; the existing tool is left in place.
    pub fn register(&mut self, tool: ToolRef) -> Result<(), ToolError> {
        let name = tool.name().to_string();
        // The entry API checks and inserts with a single hash lookup.
        match Arc::make_mut(&mut self.inner).tools.entry(name) {
            std::collections::hash_map::Entry::Occupied(slot) => Err(ToolError::ExecutionFailed(
                slot.key().clone(),
                "工具已存在".to_string(),
            )),
            std::collections::hash_map::Entry::Vacant(slot) => {
                slot.insert(tool);
                Ok(())
            }
        }
    }

    /// Registers several tools in order.
    ///
    /// # Errors
    ///
    /// Stops at the first tool whose name is already taken and returns that
    /// error. Tools registered before the failing one stay registered; the
    /// batch is not rolled back.
    pub fn register_all(&mut self, tools: Vec<ToolRef>) -> Result<(), ToolError> {
        tools.into_iter().try_for_each(|tool| self.register(tool))
    }

    /// Removes a tool by name and returns it, or `None` if it was not registered.
    pub fn unregister(&mut self, name: &str) -> Option<ToolRef> {
        Arc::make_mut(&mut self.inner).tools.remove(name)
    }

    /// Looks up a tool by name, returning a clone of its `ToolRef`.
    pub fn get(&self, name: &str) -> Option<ToolRef> {
        self.inner.tools.get(name).cloned()
    }

    /// Returns the names of all registered tools, sorted alphabetically.
    ///
    /// Sorting makes the result independent of `HashMap` iteration order,
    /// which differs from one process to the next.
    pub fn list_tools(&self) -> Vec<String> {
        let mut names: Vec<String> = self.inner.tools.keys().cloned().collect();
        names.sort_unstable();
        names
    }

    /// Returns a `ToolDefinition` for every registered tool (to hand to the LLM),
    /// sorted by tool name.
    ///
    /// A stable order keeps the tool list sent to the model identical between
    /// runs instead of following `HashMap` iteration order.
    pub fn definitions(&self) -> Vec<ToolDefinition> {
        let mut defs: Vec<ToolDefinition> = self
            .inner
            .tools
            .values()
            .map(|tool| ToolDefinition {
                name: tool.name().to_string(),
                description: Some(tool.description().to_string()),
                parameters: tool.parameters(),
                strict: None,
            })
            .collect();
        defs.sort_unstable_by(|a, b| a.name.cmp(&b.name));
        defs
    }

    /// Invokes a single tool, running the permission check first when a
    /// checker is configured.
    ///
    /// A failure of the tool itself is not returned as `Err`; it is recorded
    /// in the `output` field of the returned `ToolInvocation`.
    ///
    /// # Errors
    ///
    /// Returns `ToolError::NotFound` if no tool is registered under `name`,
    /// or the checker's error if the tool's required permissions are refused.
    pub async fn invoke(&self, name: &str, args: Value) -> Result<ToolInvocation, ToolError> {
        let tool = self
            .get(name)
            .ok_or_else(|| ToolError::NotFound(name.to_string()))?;
        if let Some(checker) = &self.inner.permission_checker {
            checker.check(name, &tool.required_permissions())?;
        }
        let ctx = ToolContext::new(name, "");
        let output = tool.execute(args.clone(), &ctx).await;
        Ok(ToolInvocation::new(name.to_string(), args, output))
    }

    /// Runs several independent tool calls concurrently.
    ///
    /// Each call gets its own timeout of `timeout_per_call_secs` seconds
    /// (`0` disables the timeout), so one call timing out does not affect the
    /// others. Results are returned in the same order as `calls`.
    ///
    /// This method never fails as a whole: a lookup or permission error, or a
    /// timeout (reported as `ToolError::McpTimeout`), is stored in the `output`
    /// of the corresponding `ToolInvocation`.
    pub async fn invoke_all(
        &self,
        calls: Vec<(String, Value)>,
        timeout_per_call_secs: u64,
    ) -> Vec<ToolInvocation> {
        let limit =
            (timeout_per_call_secs > 0).then(|| Duration::from_secs(timeout_per_call_secs));
        let pending = calls.into_iter().map(|(name, args)| async move {
            // `invoke` consumes its arguments; keep `args` for the error records below.
            let call = self.invoke(&name, args.clone());
            let outcome = match limit {
                Some(limit) => tokio::time::timeout(limit, call).await,
                None => Ok(call.await),
            };
            match outcome {
                Ok(Ok(invocation)) => invocation,
                Ok(Err(err)) => ToolInvocation::new(name, args, Err(err)),
                Err(_elapsed) => ToolInvocation::new(
                    name,
                    args,
                    Err(ToolError::McpTimeout("timeout".into())),
                ),
            }
        });
        join_all(pending).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::permission::Permission;
    use crate::tools::BaseTool;
    use async_trait::async_trait;
    use serde_json::json;

    /// Test tool that adds the integer argument `n` to a fixed base.
    struct AddTool {
        base: i64,
    }

    #[async_trait]
    impl BaseTool for AddTool {
        fn name(&self) -> &str {
            "add"
        }

        fn description(&self) -> &str {
            "加法"
        }

        fn parameters(&self) -> Value {
            json!({
                "type": "object",
                "properties": { "n": { "type": "integer" } },
                "required": ["n"]
            })
        }

        async fn execute(&self, args: Value, _ctx: &ToolContext<'_>) -> Result<Value, ToolError> {
            // A missing or non-integer `n` counts as 0.
            let addend = args["n"].as_i64().unwrap_or(0);
            Ok(json!({ "result": self.base + addend }))
        }
    }

    /// Test tool whose execution always fails.
    struct FailTool;

    #[async_trait]
    impl BaseTool for FailTool {
        fn name(&self) -> &str {
            "fail"
        }

        fn description(&self) -> &str {
            "总会失败"
        }

        fn parameters(&self) -> Value {
            json!({})
        }

        async fn execute(&self, _args: Value, _ctx: &ToolContext<'_>) -> Result<Value, ToolError> {
            Err(ToolError::ExecutionFailed("fail".into(), "boom".into()))
        }
    }

    /// Test tool that declares the `Shell` permission.
    struct ShellTool;

    #[async_trait]
    impl BaseTool for ShellTool {
        fn name(&self) -> &str {
            "shell"
        }

        fn description(&self) -> &str {
            "shell"
        }

        fn parameters(&self) -> Value {
            json!({})
        }

        fn required_permissions(&self) -> Vec<Permission> {
            vec![Permission::Shell]
        }

        async fn execute(&self, _args: Value, _ctx: &ToolContext<'_>) -> Result<Value, ToolError> {
            Ok(json!({}))
        }
    }

    /// `AddTool` with the given base, as a `ToolRef`.
    fn add_tool(base: i64) -> ToolRef {
        Arc::new(AddTool { base })
    }

    /// `FailTool` as a `ToolRef`.
    fn fail_tool() -> ToolRef {
        Arc::new(FailTool)
    }

    /// Registers `tools` one by one into `reg`, panicking on the first failure.
    fn fill(mut reg: ToolRegistry, tools: Vec<ToolRef>) -> ToolRegistry {
        for tool in tools {
            reg.register(tool).unwrap();
        }
        reg
    }

    #[test]
    fn test_register_and_get() {
        let reg = fill(ToolRegistry::new(), vec![add_tool(10)]);
        assert!(reg.get("add").is_some());
        assert!(reg.get("nonexistent").is_none());
    }

    #[test]
    fn test_register_duplicate() {
        let mut reg = fill(ToolRegistry::new(), vec![add_tool(0)]);
        let second = reg.register(add_tool(1));
        assert!(second.is_err());
    }

    #[test]
    fn test_register_all() {
        let mut reg = ToolRegistry::new();
        // Both tools are named "add", so the batch has to fail.
        let outcome = reg.register_all(vec![add_tool(1), add_tool(2)]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_unregister() {
        let mut reg = fill(ToolRegistry::new(), vec![add_tool(0)]);
        assert!(reg.unregister("add").is_some());
        assert!(reg.get("add").is_none());
    }

    #[test]
    fn test_list_tools() {
        let reg = fill(ToolRegistry::new(), vec![add_tool(0), fail_tool()]);
        let names = reg.list_tools();
        assert_eq!(names.len(), 2);
        for expected in ["add", "fail"] {
            assert!(names.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_definitions() {
        let reg = fill(ToolRegistry::new(), vec![add_tool(0)]);
        let defs = reg.definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "add");
        assert!(defs[0].description.is_some());
    }

    #[tokio::test]
    async fn test_invoke_success() {
        let reg = fill(ToolRegistry::new(), vec![add_tool(100)]);
        let invocation = reg.invoke("add", json!({ "n": 5 })).await.unwrap();
        let value = invocation.output.unwrap();
        assert_eq!(value["result"], 105);
    }

    #[tokio::test]
    async fn test_invoke_not_found() {
        let outcome = ToolRegistry::new().invoke("nope", json!({})).await;
        assert!(matches!(outcome, Err(ToolError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_invoke_execution_error() {
        let reg = fill(ToolRegistry::new(), vec![fail_tool()]);
        let invocation = reg.invoke("fail", json!({})).await.unwrap();
        assert!(invocation.output.is_err());
    }

    #[tokio::test]
    async fn test_invoke_with_permission_denied() {
        let guarded = ToolRegistry::new()
            .with_permission_checker(PermissionChecker::new(Default::default()));
        let shell: ToolRef = Arc::new(ShellTool);
        let reg = fill(guarded, vec![shell]);
        let outcome = reg.invoke("shell", json!({})).await;
        assert!(matches!(outcome, Err(ToolError::PermissionDenied(_, _))));
    }

    #[tokio::test]
    async fn test_invoke_all_parallel() {
        let reg = fill(ToolRegistry::new(), vec![add_tool(1), fail_tool()]);
        let calls = vec![
            ("add".into(), json!({ "n": 1 })),
            ("add".into(), json!({ "n": 2 })),
            ("fail".into(), json!({})),
        ];
        let results = reg.invoke_all(calls, 0).await;
        assert_eq!(results.len(), 3);
        assert!(results[0].output.is_ok());
        assert!(results[1].output.is_ok());
        assert!(results[2].output.is_err());
    }

    #[tokio::test]
    async fn test_invoke_all_with_timeout() {
        let reg = fill(ToolRegistry::new(), vec![add_tool(0)]);
        let calls = vec![("add".into(), json!({ "n": 1 }))];
        let results = reg.invoke_all(calls, 5).await;
        assert_eq!(results.len(), 1);
        assert!(results[0].output.is_ok());
    }
}