feat(tools): 添加工具系统框架与 MCP 协议客户端
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
pub mod llm;
|
||||
pub mod prompt;
|
||||
pub mod tools;
|
||||
|
||||
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
|
||||
|
||||
|
||||
+465
-1
@@ -19,7 +19,8 @@ use crate::llm::hooks::{HookContext, HookExecutor};
|
||||
use crate::llm::provider::LlmProvider;
|
||||
use crate::llm::stream::StreamEvent;
|
||||
use crate::llm::types::{
|
||||
ChatRequest, ChatResponse, OpenaiChatMessage, OpenaiTool, ToolChoice, ToolDefinition,
|
||||
ChatRequest, ChatResponse, FinishReason, OpenaiChatMessage, OpenaiTool, OpenaiToolCall,
|
||||
ToolChoice, ToolDefinition,
|
||||
};
|
||||
|
||||
/// LLM 调用周期配置。
|
||||
@@ -34,6 +35,15 @@ pub struct CycleConfig {
|
||||
pub max_turns: Option<u32>,
|
||||
/// 重试策略配置。
|
||||
pub retry: RetryConfig,
|
||||
/// 自动 tool 循环的最大轮次(独立于 `max_turns`,避免影响现有 `submit()` 语义)。
|
||||
/// 默认 `Some(10)`,防止 LLM 反复调用工具导致无限循环。
|
||||
pub max_tool_turns: Option<u32>,
|
||||
/// 单个工具执行的超时秒数(0 表示不超时)。
|
||||
/// 默认 60 秒。
|
||||
pub tool_timeout_secs: u64,
|
||||
/// 单个工具结果的最大字节数(超过此值将被截断)。
|
||||
/// 默认 65536(64KB),防止大结果导致 token 膨胀。
|
||||
pub max_tool_result_bytes: usize,
|
||||
}
|
||||
|
||||
impl Default for CycleConfig {
|
||||
@@ -44,6 +54,9 @@ impl Default for CycleConfig {
|
||||
temperature: None,
|
||||
max_turns: None,
|
||||
retry: RetryConfig::default(),
|
||||
max_tool_turns: Some(10),
|
||||
tool_timeout_secs: 60,
|
||||
max_tool_result_bytes: 65_536,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -440,4 +453,455 @@ impl LlmCycle {
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// 内部请求方法(与 `submit` 共享重试逻辑,但不 push user message 和 Assistant 响应)。
|
||||
///
|
||||
/// 用于 `submit_with_tools()` 的多轮 tool 循环。
|
||||
async fn submit_request(
|
||||
&mut self,
|
||||
tools: &[ToolDefinition],
|
||||
) -> Result<ChatResponse, LlmError> {
|
||||
let mut attempts = 0;
|
||||
|
||||
loop {
|
||||
let request = self.build_request(tools);
|
||||
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest)
|
||||
.with_request(&request);
|
||||
let results = executor
|
||||
.execute(crate::llm::hooks::HookEvent::PreRequest, &ctx)
|
||||
.await;
|
||||
if results.iter().any(|r| r.should_block) {
|
||||
let reason = results
|
||||
.iter()
|
||||
.find(|r| r.should_block)
|
||||
.and_then(|r| r.reason.clone())
|
||||
.unwrap_or_else(|| "Blocked by pre-request hook".to_string());
|
||||
return Err(LlmError::Other(reason));
|
||||
}
|
||||
}
|
||||
|
||||
match self.provider.chat(request).await {
|
||||
Ok(response) => {
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let post_request = self.build_request(tools);
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PostRequest)
|
||||
.with_request(&post_request);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::PostRequest, &ctx)
|
||||
.await;
|
||||
}
|
||||
self.usage.add(&response.usage);
|
||||
return Ok(response);
|
||||
}
|
||||
Err(e) if should_retry(&e) && attempts < self.config.retry.max_retries => {
|
||||
attempts += 1;
|
||||
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnRetry)
|
||||
.with_error(&e)
|
||||
.with_attempt(attempts);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::OnRetry, &ctx)
|
||||
.await;
|
||||
}
|
||||
|
||||
let delay = self.config.retry.compute_delay(attempts);
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
Err(e) => {
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnError)
|
||||
.with_error(&e);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::OnError, &ctx)
|
||||
.await;
|
||||
}
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 提交消息并自动处理工具调用循环。
|
||||
///
|
||||
/// 流程:
|
||||
/// 1. 发送请求(含工具定义)
|
||||
/// 2. 检查响应中的 finish_reason
|
||||
/// 3. 如果是 ToolCalls → push Assistant 消息 → 执行工具 → 回传结果 → 重复 1
|
||||
/// 4. 如果是 Stop/Length → push Assistant 消息 → 返回最终响应
|
||||
///
|
||||
/// 注意:OpenAI API 要求 tool 消息必须紧跟在对应的 Assistant(tool_calls)消息之后。
|
||||
/// 因此 push 工具结果前必须先 push Assistant 响应,否则 API 拒绝请求。
|
||||
pub async fn submit_with_tools(
|
||||
&mut self,
|
||||
prompt: String,
|
||||
registry: &crate::tools::ToolRegistry,
|
||||
) -> Result<ChatResponse, LlmError> {
|
||||
let tools = registry.definitions();
|
||||
let max_turns = self.config.max_tool_turns.unwrap_or(10);
|
||||
let tool_timeout = self.config.tool_timeout_secs;
|
||||
let max_bytes = self.config.max_tool_result_bytes;
|
||||
|
||||
self.messages.push(OpenaiChatMessage::user_text(prompt));
|
||||
self.maybe_compact();
|
||||
|
||||
let mut turn = 0;
|
||||
|
||||
loop {
|
||||
turn += 1;
|
||||
if turn > max_turns {
|
||||
return Err(LlmError::Other(format!(
|
||||
"达到最大工具循环轮次 ({max_turns})"
|
||||
)));
|
||||
}
|
||||
|
||||
let response = self.submit_request(&tools).await?;
|
||||
|
||||
// 判断是否需要执行工具
|
||||
let should_execute = matches!(response.stop_reason, Some(FinishReason::ToolCalls))
|
||||
&& has_tool_calls_in_message(&response.message);
|
||||
|
||||
// 将 Assistant 响应(含 tool_calls 或最终文本)追加到消息历史
|
||||
self.messages.push(response.message.clone());
|
||||
|
||||
if !should_execute {
|
||||
return Ok(response);
|
||||
}
|
||||
|
||||
// 解析 tool_calls 并执行
|
||||
let tool_calls = extract_tool_calls_from_message(&response.message);
|
||||
let calls: Vec<(String, serde_json::Value)> = tool_calls
|
||||
.into_iter()
|
||||
.map(|(_id, name, args)| {
|
||||
let args: serde_json::Value =
|
||||
serde_json::from_str(&args).unwrap_or(serde_json::Value::Null);
|
||||
(name, args)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let results = registry.invoke_all(calls, tool_timeout).await;
|
||||
|
||||
// 回传工具结果
|
||||
for result in results {
|
||||
let content = match result.output {
|
||||
Ok(value) => {
|
||||
let serialized = serde_json::to_string(&value).unwrap_or_else(|e| {
|
||||
tracing::warn!("工具结果序列化失败: {}", e);
|
||||
"{}".to_string()
|
||||
});
|
||||
truncate_tool_result(&serialized, max_bytes)
|
||||
}
|
||||
Err(e) if e.is_recoverable() => format!("错误: {}", e),
|
||||
Err(e) => {
|
||||
// 不可恢复错误:终止循环
|
||||
return Err(LlmError::Other(format!(
|
||||
"工具 '{}' 不可恢复错误: {}",
|
||||
result.tool_name, e
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
self.messages
|
||||
.push(OpenaiChatMessage::tool_result(result.tool_name, content));
|
||||
}
|
||||
|
||||
// 每轮工具执行后触发 compaction
|
||||
self.maybe_compact();
|
||||
}
|
||||
|
||||
// unreachable: loop returns
|
||||
#[allow(unreachable_code)]
|
||||
{
|
||||
Err(LlmError::Other("unreachable".into()))
|
||||
}
|
||||
}
|
||||
|
||||
/// 在接近上下文窗口时压缩历史消息。
|
||||
fn maybe_compact(&mut self) {
|
||||
if let Some(ref config) = self.compact_config
|
||||
&& should_compact(&self.messages, config, &self.compact_state)
|
||||
{
|
||||
let freed = microcompact(&mut self.messages, config.keep_recent);
|
||||
if freed > 0 {
|
||||
self.compact_state.record_success();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 判断 Assistant 消息是否包含 tool_calls。
|
||||
fn has_tool_calls_in_message(msg: &OpenaiChatMessage) -> bool {
|
||||
matches!(
|
||||
msg,
|
||||
OpenaiChatMessage::Assistant {
|
||||
tool_calls: Some(calls),
|
||||
..
|
||||
} if !calls.is_empty()
|
||||
)
|
||||
}
|
||||
|
||||
/// 提取 Assistant 消息中的 tool_calls。
|
||||
///
|
||||
/// 返回 `(tool_call_id, tool_name, arguments_json_string)` 列表。
|
||||
fn extract_tool_calls_from_message(
|
||||
msg: &OpenaiChatMessage,
|
||||
) -> Vec<(String, String, String)> {
|
||||
if let OpenaiChatMessage::Assistant {
|
||||
tool_calls: Some(calls),
|
||||
..
|
||||
} = msg
|
||||
{
|
||||
calls
|
||||
.iter()
|
||||
.map(|c| match c {
|
||||
OpenaiToolCall::Function { id, function } => {
|
||||
(id.clone(), function.name.clone(), function.arguments.clone())
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// 截断工具结果到指定字节数。
|
||||
fn truncate_tool_result(s: &str, max_bytes: usize) -> String {
|
||||
if s.len() <= max_bytes {
|
||||
return s.to_string();
|
||||
}
|
||||
let mut end = max_bytes;
|
||||
while end > 0 && !s.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
format!("{}\n\n[... truncated, original size: {} bytes ...]", &s[..end], s.len())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::llm::types::{ContentField, OpenaiContentPart};
|
||||
use crate::tools::{BaseTool, ToolRegistry};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// 模拟 Provider —— 预定义响应序列,按调用顺序返回。
|
||||
struct MockProvider {
|
||||
responses: std::sync::Mutex<Vec<ChatResponse>>,
|
||||
call_count: std::sync::Mutex<u32>,
|
||||
}
|
||||
|
||||
impl MockProvider {
|
||||
fn new(responses: Vec<ChatResponse>) -> Self {
|
||||
Self {
|
||||
responses: std::sync::Mutex::new(responses),
|
||||
call_count: std::sync::Mutex::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for MockProvider {
|
||||
async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, LlmError> {
|
||||
let mut count = self.call_count.lock().unwrap();
|
||||
*count += 1;
|
||||
let mut responses = self.responses.lock().unwrap();
|
||||
if responses.is_empty() {
|
||||
return Err(LlmError::Other("no more mock responses".into()));
|
||||
}
|
||||
Ok(responses.remove(0))
|
||||
}
|
||||
}
|
||||
|
||||
fn empty_usage() -> crate::llm::types::Usage {
|
||||
crate::llm::types::Usage::default()
|
||||
}
|
||||
|
||||
fn assistant_text_response(text: &str) -> ChatResponse {
|
||||
ChatResponse {
|
||||
message: OpenaiChatMessage::assistant_text(text),
|
||||
usage: empty_usage(),
|
||||
stop_reason: Some(FinishReason::Stop),
|
||||
}
|
||||
}
|
||||
|
||||
fn assistant_tool_call_response(
|
||||
calls: Vec<(&str, &str, &str)>,
|
||||
) -> ChatResponse {
|
||||
use crate::llm::types::{OpenaiToolCall, FunctionCall};
|
||||
let tool_calls: Vec<OpenaiToolCall> = calls
|
||||
.into_iter()
|
||||
.map(|(id, name, args)| OpenaiToolCall::Function {
|
||||
id: id.to_string(),
|
||||
function: FunctionCall {
|
||||
name: name.to_string(),
|
||||
arguments: args.to_string(),
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
ChatResponse {
|
||||
message: OpenaiChatMessage::Assistant {
|
||||
content: ContentField::Array(vec![OpenaiContentPart::Text {
|
||||
text: String::new(),
|
||||
}]),
|
||||
refusal: None,
|
||||
name: None,
|
||||
tool_calls: Some(tool_calls),
|
||||
},
|
||||
usage: empty_usage(),
|
||||
stop_reason: Some(FinishReason::ToolCalls),
|
||||
}
|
||||
}
|
||||
|
||||
struct AddTool;
|
||||
|
||||
#[async_trait]
|
||||
impl BaseTool for AddTool {
|
||||
fn name(&self) -> &str {
|
||||
"add"
|
||||
}
|
||||
fn description(&self) -> &str {
|
||||
"加法"
|
||||
}
|
||||
fn parameters(&self) -> Value {
|
||||
json!({"type":"object","properties":{"a":{"type":"integer"},"b":{"type":"integer"}}})
|
||||
}
|
||||
async fn execute(
|
||||
&self,
|
||||
args: Value,
|
||||
_ctx: &crate::tools::ToolContext<'_>,
|
||||
) -> Result<Value, crate::tools::ToolError> {
|
||||
let a = args["a"].as_i64().unwrap_or(0);
|
||||
let b = args["b"].as_i64().unwrap_or(0);
|
||||
Ok(json!({"result": a + b}))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_submit_with_tools_single_turn() {
|
||||
// 第一轮:返回 tool_call;第二轮:返回最终文本
|
||||
let responses = vec![
|
||||
assistant_tool_call_response(vec![("call_1", "add", r#"{"a":1,"b":2}"#)]),
|
||||
assistant_text_response("答案是 3"),
|
||||
];
|
||||
let provider = Box::new(MockProvider::new(responses));
|
||||
let mut cycle = LlmCycle::new(provider, CycleConfig::default());
|
||||
|
||||
let mut registry = ToolRegistry::new();
|
||||
registry.register(std::sync::Arc::new(AddTool)).unwrap();
|
||||
|
||||
let response = cycle
|
||||
.submit_with_tools("1+2=?".to_string(), ®istry)
|
||||
.await
|
||||
.unwrap();
|
||||
// 验证最终响应是文本响应
|
||||
assert!(matches!(
|
||||
response.message,
|
||||
OpenaiChatMessage::Assistant { .. }
|
||||
));
|
||||
|
||||
// 验证消息历史:user, assistant(tool_calls), tool, assistant(text)
|
||||
let messages = cycle.messages();
|
||||
assert_eq!(messages.len(), 4);
|
||||
assert!(matches!(messages[0], OpenaiChatMessage::User { .. }));
|
||||
assert!(matches!(messages[1], OpenaiChatMessage::Assistant { .. }));
|
||||
assert!(matches!(messages[2], OpenaiChatMessage::Tool { .. }));
|
||||
assert!(matches!(messages[3], OpenaiChatMessage::Assistant { .. }));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_submit_with_tools_multi_turn() {
|
||||
// 3 轮 tool 调用后给出最终答案
|
||||
let responses = vec![
|
||||
assistant_tool_call_response(vec![("call_1", "add", r#"{"a":1,"b":2}"#)]),
|
||||
assistant_tool_call_response(vec![("call_2", "add", r#"{"a":3,"b":4}"#)]),
|
||||
assistant_tool_call_response(vec![("call_3", "add", r#"{"a":5,"b":6}"#)]),
|
||||
assistant_text_response("完成"),
|
||||
];
|
||||
let provider = Box::new(MockProvider::new(responses));
|
||||
let mut cycle = LlmCycle::new(provider, CycleConfig::default());
|
||||
|
||||
let mut registry = ToolRegistry::new();
|
||||
registry.register(std::sync::Arc::new(AddTool)).unwrap();
|
||||
|
||||
let response = cycle
|
||||
.submit_with_tools("计算总和".to_string(), ®istry)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
response.message,
|
||||
OpenaiChatMessage::Assistant { .. }
|
||||
));
|
||||
|
||||
// user + 3*(assistant + tool) + final assistant = 8
|
||||
let messages = cycle.messages();
|
||||
assert_eq!(messages.len(), 8);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_submit_with_tools_max_turns_exceeded() {
|
||||
// 配置 max_tool_turns = 2
|
||||
let mut config = CycleConfig::default();
|
||||
config.max_tool_turns = Some(2);
|
||||
// 4 轮 tool 调用 + 终止
|
||||
let responses = vec![
|
||||
assistant_tool_call_response(vec![("c1", "add", r#"{"a":1,"b":1}"#)]),
|
||||
assistant_tool_call_response(vec![("c2", "add", r#"{"a":1,"b":1}"#)]),
|
||||
assistant_tool_call_response(vec![("c3", "add", r#"{"a":1,"b":1}"#)]),
|
||||
assistant_text_response("完成"),
|
||||
];
|
||||
let provider = Box::new(MockProvider::new(responses));
|
||||
let mut cycle = LlmCycle::new(provider, config);
|
||||
|
||||
let mut registry = ToolRegistry::new();
|
||||
registry.register(std::sync::Arc::new(AddTool)).unwrap();
|
||||
|
||||
let result = cycle
|
||||
.submit_with_tools("test".to_string(), ®istry)
|
||||
.await;
|
||||
assert!(matches!(result, Err(LlmError::Other(msg)) if msg.contains("达到最大工具循环轮次")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_submit_with_tools_no_tool_call_response() {
|
||||
// LLM 直接给出最终响应(不调用工具)
|
||||
let responses = vec![assistant_text_response("直接回答")];
|
||||
let provider = Box::new(MockProvider::new(responses));
|
||||
let mut cycle = LlmCycle::new(provider, CycleConfig::default());
|
||||
|
||||
let mut registry = ToolRegistry::new();
|
||||
registry.register(std::sync::Arc::new(AddTool)).unwrap();
|
||||
|
||||
let response = cycle
|
||||
.submit_with_tools("直接回答".to_string(), ®istry)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
response.message,
|
||||
OpenaiChatMessage::Assistant { .. }
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_truncate_tool_result_short() {
|
||||
let s = "short text";
|
||||
assert_eq!(truncate_tool_result(s, 100), "short text");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_truncate_tool_result_long() {
|
||||
let s = "a".repeat(1000);
|
||||
let truncated = truncate_tool_result(&s, 50);
|
||||
assert!(truncated.len() < s.len());
|
||||
assert!(truncated.contains("[... truncated,"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_truncate_tool_result_chinese_chars() {
|
||||
let s = "中".repeat(100);
|
||||
let truncated = truncate_tool_result(&s, 50);
|
||||
// 不会在字符中间截断
|
||||
assert!(truncated.starts_with("中"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
//! 工具系统 —— 工具抽象、注册、调用、权限控制与 MCP 集成。
|
||||
|
||||
pub mod base;
|
||||
pub mod error;
|
||||
pub mod mcp;
|
||||
pub mod permission;
|
||||
pub mod registry;
|
||||
|
||||
pub use base::{BaseTool, ToolContext, ToolRef};
|
||||
pub use error::ToolError;
|
||||
pub use mcp::{McpClient, McpTransport};
|
||||
pub use permission::{Permission, PermissionChecker, PermissionConfig};
|
||||
pub use registry::{ToolInvocation, ToolRegistry};
|
||||
@@ -0,0 +1,139 @@
|
||||
//! 工具抽象接口与执行上下文。
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::tools::error::ToolError;
|
||||
use crate::tools::permission::Permission;
|
||||
|
||||
/// 工具执行上下文 —— 携带每次执行的运行时信息。
|
||||
///
|
||||
/// 字段在 Phase 2 即注入 `execute()` 签名中,防止后续扩展时出现
|
||||
/// breaking change。后续阶段可扩展字段(如 `progress`、`shared_state`),
|
||||
/// 但已有工具实现无需修改。
|
||||
#[derive(Debug)]
|
||||
pub struct ToolContext<'a> {
|
||||
/// 当前对话/会话 ID,用于关联性追踪。
|
||||
pub session_id: &'a str,
|
||||
/// 链路追踪 ID,用于跨工具调用的耗时分布。
|
||||
pub trace_id: &'a str,
|
||||
/// 取消令牌,用于优雅取消正在执行的工具。
|
||||
pub cancellation_token: CancellationToken,
|
||||
}
|
||||
|
||||
impl<'a> ToolContext<'a> {
|
||||
/// 创建一个新的工具执行上下文。
|
||||
pub fn new(session_id: &'a str, trace_id: &'a str) -> Self {
|
||||
Self {
|
||||
session_id,
|
||||
trace_id,
|
||||
cancellation_token: CancellationToken::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 创建一个使用给定取消令牌的上下文。
|
||||
pub fn with_cancellation_token(
|
||||
session_id: &'a str,
|
||||
trace_id: &'a str,
|
||||
token: CancellationToken,
|
||||
) -> Self {
|
||||
Self {
|
||||
session_id,
|
||||
trace_id,
|
||||
cancellation_token: token,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 工具抽象接口 —— 所有工具(自定义或 MCP)最终都实现此 trait。
|
||||
#[async_trait]
|
||||
pub trait BaseTool: Send + Sync {
|
||||
/// 工具名称(唯一标识,用于 LLM 的 tool_calls.name 匹配)。
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// 工具描述(LLM 据此决定是否调用此工具)。
|
||||
fn description(&self) -> &str;
|
||||
|
||||
/// 工具参数定义(JSON Schema 格式,传递给 LLM 的 tool.parameters)。
|
||||
fn parameters(&self) -> Value;
|
||||
|
||||
/// 声明工具所需的权限列表。
|
||||
fn required_permissions(&self) -> Vec<Permission> {
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
/// 执行工具调用。
|
||||
///
|
||||
/// `ctx` 携带执行上下文(session_id、trace_id、cancellation_token),
|
||||
/// 工具实现可在执行期间检查 `ctx.cancellation_token` 来支持优雅取消。
|
||||
async fn execute(&self, args: Value, ctx: &ToolContext<'_>) -> Result<Value, ToolError>;
|
||||
}
|
||||
|
||||
/// 为 `Arc<dyn BaseTool>` 提供 `Send + Sync` 包装,便于在 `Vec<Arc<dyn BaseTool>>` 中使用。
|
||||
pub type ToolRef = Arc<dyn BaseTool>;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
struct EchoTool;
|
||||
|
||||
#[async_trait]
|
||||
impl BaseTool for EchoTool {
|
||||
fn name(&self) -> &str {
|
||||
"echo"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"回显输入"
|
||||
}
|
||||
|
||||
fn parameters(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": { "type": "string" }
|
||||
},
|
||||
"required": ["text"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: Value, _ctx: &ToolContext<'_>) -> Result<Value, ToolError> {
|
||||
Ok(args)
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mock_tool_execute() {
|
||||
let tool = EchoTool;
|
||||
let ctx = ToolContext::new("session-1", "trace-1");
|
||||
let result = tool.execute(json!({"text": "hello"}), &ctx).await.unwrap();
|
||||
assert_eq!(result, json!({"text": "hello"}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_default_permissions_empty() {
|
||||
let tool = EchoTool;
|
||||
assert!(tool.required_permissions().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_context_creation() {
|
||||
let ctx = ToolContext::new("s1", "t1");
|
||||
assert_eq!(ctx.session_id, "s1");
|
||||
assert_eq!(ctx.trace_id, "t1");
|
||||
assert!(!ctx.cancellation_token.is_cancelled());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_context_cancellation() {
|
||||
let token = CancellationToken::new();
|
||||
token.cancel();
|
||||
let ctx = ToolContext::with_cancellation_token("s1", "t1", token);
|
||||
assert!(ctx.cancellation_token.is_cancelled());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,118 @@
|
||||
//! 工具系统错误类型。
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
/// 工具调用过程中可能发生的所有错误。
|
||||
#[derive(thiserror::Error, Debug, Clone)]
|
||||
pub enum ToolError {
|
||||
/// 工具未注册。
|
||||
#[error("工具 '{0}' 未注册")]
|
||||
NotFound(String),
|
||||
|
||||
/// 工具执行失败(可恢复——文本回传 LLM)。
|
||||
#[error("工具 '{0}' 执行失败: {1}")]
|
||||
ExecutionFailed(String, String),
|
||||
|
||||
/// 工具参数无效(可恢复——文本回传 LLM)。
|
||||
#[error("工具 '{0}' 参数无效: {1}")]
|
||||
InvalidArguments(String, String),
|
||||
|
||||
/// 权限被拒绝(不可恢复——终止循环)。
|
||||
#[error("权限被拒绝: 工具 '{0}' 需要 {1} 权限")]
|
||||
PermissionDenied(String, String),
|
||||
|
||||
/// MCP 协议错误(不可恢复)。
|
||||
#[error("MCP 协议错误: {0}")]
|
||||
McpError(String),
|
||||
|
||||
/// MCP 未初始化(不可恢复)。
|
||||
#[error("MCP 未初始化: {0}")]
|
||||
McpNotInitialized(String),
|
||||
|
||||
/// MCP 超时(不可恢复)。
|
||||
#[error("MCP 超时: {0}")]
|
||||
McpTimeout(String),
|
||||
|
||||
/// IO 错误(不可恢复)。
|
||||
#[error("IO 错误: {0}")]
|
||||
Io(Arc<std::io::Error>),
|
||||
|
||||
/// 取消。
|
||||
#[error("工具执行已取消: {0}")]
|
||||
Cancelled(String),
|
||||
|
||||
/// 其他未分类错误。
|
||||
#[error("其他错误: {0}")]
|
||||
Other(String),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ToolError {
|
||||
fn from(e: std::io::Error) -> Self {
|
||||
ToolError::Io(Arc::new(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolError {
|
||||
/// 判断错误是否可恢复——可恢复的错误回传 LLM 由其自行重试,
|
||||
/// 不可恢复的错误终止自动 tool 循环并返回给调用方。
|
||||
pub fn is_recoverable(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Self::ExecutionFailed(..) | Self::InvalidArguments(..) | Self::Other(_)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_execution_failed_is_recoverable() {
|
||||
let err = ToolError::ExecutionFailed("foo".into(), "boom".into());
|
||||
assert!(err.is_recoverable());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_arguments_is_recoverable() {
|
||||
let err = ToolError::InvalidArguments("foo".into(), "missing x".into());
|
||||
assert!(err.is_recoverable());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_not_found_is_not_recoverable() {
|
||||
let err = ToolError::NotFound("foo".into());
|
||||
assert!(!err.is_recoverable());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_permission_denied_is_not_recoverable() {
|
||||
let err = ToolError::PermissionDenied("foo".into(), "Shell".into());
|
||||
assert!(!err.is_recoverable());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mcp_error_is_not_recoverable() {
|
||||
let err = ToolError::McpError("protocol".into());
|
||||
assert!(!err.is_recoverable());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mcp_timeout_is_not_recoverable() {
|
||||
let err = ToolError::McpTimeout("foo".into());
|
||||
assert!(!err.is_recoverable());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_io_is_not_recoverable() {
|
||||
let io_err = std::io::Error::new(std::io::ErrorKind::Other, "disk");
|
||||
let err = ToolError::from(io_err);
|
||||
assert!(!err.is_recoverable());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_other_is_recoverable() {
|
||||
let err = ToolError::Other("something".into());
|
||||
assert!(err.is_recoverable());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,640 @@
|
||||
//! MCP 协议客户端 —— 与 MCP Server 通过 JSON-RPC over stdio 通信。
|
||||
//!
|
||||
//! 当前 Phase 2 实现 stdio transport。`StreamableHttp` 枚举变体已预留,
|
||||
//! 但实际实现推迟到后续版本。
|
||||
//!
|
||||
//! ## 协议版本
|
||||
//!
|
||||
//! 实现遵循 MCP 协议版本 2025-03-26。
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::process::Stdio;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
|
||||
use tokio::sync::{oneshot, Mutex};
|
||||
|
||||
use crate::llm::types::ToolDefinition;
|
||||
use crate::tools::base::{BaseTool, ToolContext, ToolRef};
|
||||
use crate::tools::error::ToolError;
|
||||
|
||||
/// MCP 协议版本。
|
||||
const MCP_VERSION: &str = "2025-03-26";
|
||||
|
||||
/// MCP 传输方式。
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum McpTransport {
|
||||
/// 通过子进程 stdin/stdout 通信。
|
||||
Stdio {
|
||||
/// 启动命令(如 `"npx"`)。
|
||||
command: String,
|
||||
/// 命令参数(如 `["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]`)。
|
||||
args: Vec<String>,
|
||||
},
|
||||
/// Streamable HTTP 传输(MCP 2025-03-26 引入,替代已废弃的 HTTP+SSE)。
|
||||
///
|
||||
/// 当前 Phase 2 预留枚举变体,调用方法会返回 `ToolError::McpError`。
|
||||
StreamableHttp {
|
||||
/// MCP 端点 URL。
|
||||
url: String,
|
||||
/// 可选的 HTTP 头(如 Authorization)。
|
||||
headers: Option<Vec<(String, String)>>,
|
||||
},
|
||||
}
|
||||
|
||||
/// JSON-RPC 请求。
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct JsonRpcRequest {
|
||||
jsonrpc: &'static str,
|
||||
id: u64,
|
||||
method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
params: Option<Value>,
|
||||
}
|
||||
|
||||
impl JsonRpcRequest {
|
||||
fn new(id: u64, method: impl Into<String>, params: Option<Value>) -> Self {
|
||||
Self {
|
||||
jsonrpc: "2.0",
|
||||
id,
|
||||
method: method.into(),
|
||||
params,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// JSON-RPC 响应。
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct JsonRpcResponse {
|
||||
jsonrpc: String,
|
||||
id: u64,
|
||||
#[serde(default)]
|
||||
result: Option<Value>,
|
||||
#[serde(default)]
|
||||
error: Option<JsonRpcError>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct JsonRpcError {
|
||||
code: i32,
|
||||
message: String,
|
||||
#[serde(default)]
|
||||
data: Option<Value>,
|
||||
}
|
||||
|
||||
/// MCP 子进程运行时状态。
|
||||
struct ChildProcessState {
|
||||
child: Child,
|
||||
stdin: ChildStdin,
|
||||
pending: HashMap<u64, oneshot::Sender<Result<Value, ToolError>>>,
|
||||
next_id: u64,
|
||||
}
|
||||
|
||||
impl ChildProcessState {
|
||||
fn next_id(&mut self) -> u64 {
|
||||
self.next_id += 1;
|
||||
self.next_id
|
||||
}
|
||||
}
|
||||
|
||||
/// MCP Server 暴露的工具(缓存结构)。
|
||||
#[derive(Debug, Clone)]
|
||||
struct McpTool {
|
||||
name: String,
|
||||
description: Option<String>,
|
||||
input_schema: Value,
|
||||
}
|
||||
|
||||
/// MCP 客户端 —— 与 MCP 服务器通信。
|
||||
pub struct McpClient {
|
||||
transport: McpTransport,
|
||||
server_name: String,
|
||||
/// 已初始化的工具列表(缓存)。
|
||||
tools: Vec<McpTool>,
|
||||
/// 是否已初始化。
|
||||
initialized: AtomicBool,
|
||||
/// 超时时间(秒)。
|
||||
timeout_secs: u64,
|
||||
/// 子进程运行时状态(`connect()` 后创建,`close()` 后取回)。
|
||||
process: Option<Arc<Mutex<ChildProcessState>>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for McpClient {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("McpClient")
|
||||
.field("server_name", &self.server_name)
|
||||
.field("initialized", &self.initialized.load(Ordering::SeqCst))
|
||||
.field("tool_count", &self.tools.len())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl McpClient {
|
||||
/// 创建一个 MCP 客户端。
|
||||
pub fn new(server_name: impl Into<String>, transport: McpTransport) -> Self {
|
||||
Self {
|
||||
transport,
|
||||
server_name: server_name.into(),
|
||||
tools: Vec::new(),
|
||||
initialized: AtomicBool::new(false),
|
||||
timeout_secs: 30,
|
||||
process: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// 设置超时时间(秒)。
|
||||
pub fn with_timeout(mut self, secs: u64) -> Self {
|
||||
self.timeout_secs = secs;
|
||||
self
|
||||
}
|
||||
|
||||
/// 检查是否已连接。
|
||||
pub fn is_initialized(&self) -> bool {
|
||||
self.initialized.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// 连接并初始化(发送 initialize 请求)。
|
||||
pub async fn connect(&mut self) -> Result<(), ToolError> {
|
||||
if self.is_initialized() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match &self.transport {
|
||||
McpTransport::Stdio { command, args } => {
|
||||
let mut cmd = Command::new(command);
|
||||
cmd.args(args)
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped());
|
||||
#[cfg(unix)]
|
||||
cmd.kill_on_drop(true);
|
||||
#[cfg(windows)]
|
||||
cmd.creation_flags(0x08000000); // CREATE_NO_WINDOW
|
||||
|
||||
let mut child = cmd
|
||||
.spawn()
|
||||
.map_err(|e| ToolError::McpError(format!("启动 MCP 子进程失败: {e}")))?;
|
||||
|
||||
let stdin = child
|
||||
.stdin
|
||||
.take()
|
||||
.ok_or_else(|| ToolError::McpError("无法获取子进程 stdin".into()))?;
|
||||
let stdout = child
|
||||
.stdout
|
||||
.take()
|
||||
.ok_or_else(|| ToolError::McpError("无法获取子进程 stdout".into()))?;
|
||||
|
||||
// 启动 reader task 持续读取 stdout
|
||||
let pending: HashMap<u64, oneshot::Sender<Result<Value, ToolError>>> =
|
||||
HashMap::new();
|
||||
let state = Arc::new(Mutex::new(ChildProcessState {
|
||||
child,
|
||||
stdin,
|
||||
pending,
|
||||
next_id: 0,
|
||||
}));
|
||||
|
||||
// 启动后台 reader
|
||||
let pending_arc = Arc::clone(&state);
|
||||
tokio::spawn(async move {
|
||||
Self::read_loop(BufReader::new(stdout), pending_arc).await;
|
||||
});
|
||||
|
||||
self.process = Some(state);
|
||||
}
|
||||
McpTransport::StreamableHttp { .. } => {
|
||||
return Err(ToolError::McpError(
|
||||
"StreamableHttp transport 尚未实现".into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// 发送 initialize 请求
|
||||
let init_params = json!({
|
||||
"protocolVersion": MCP_VERSION,
|
||||
"capabilities": {},
|
||||
"clientInfo": {
|
||||
"name": "agcore",
|
||||
"version": env!("CARGO_PKG_VERSION")
|
||||
}
|
||||
});
|
||||
let _response = self
|
||||
.send_request("initialize", Some(init_params))
|
||||
.await?;
|
||||
|
||||
// 发送 initialized 通知(无 id)
|
||||
self.send_notification("notifications/initialized", Some(json!({})))
|
||||
.await?;
|
||||
|
||||
self.initialized.store(true, Ordering::SeqCst);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 列出服务器支持的工具(调用 `tools/list`)。
|
||||
pub async fn list_tools(&mut self) -> Result<Vec<ToolDefinition>, ToolError> {
|
||||
if !self.is_initialized() {
|
||||
return Err(ToolError::McpNotInitialized(self.server_name.clone()));
|
||||
}
|
||||
|
||||
let response = self.send_request("tools/list", None).await?;
|
||||
let tools_value = response
|
||||
.get("tools")
|
||||
.ok_or_else(|| ToolError::McpError("tools/list 响应缺少 tools 字段".into()))?;
|
||||
let tools_arr = tools_value
|
||||
.as_array()
|
||||
.ok_or_else(|| ToolError::McpError("tools/list 响应 tools 字段不是数组".into()))?;
|
||||
|
||||
self.tools.clear();
|
||||
let mut defs = Vec::with_capacity(tools_arr.len());
|
||||
for tool in tools_arr {
|
||||
let name = tool
|
||||
.get("name")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| ToolError::McpError("工具缺少 name 字段".into()))?
|
||||
.to_string();
|
||||
let description = tool
|
||||
.get("description")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
let input_schema = tool
|
||||
.get("inputSchema")
|
||||
.cloned()
|
||||
.unwrap_or_else(|| json!({"type": "object", "properties": {}}));
|
||||
|
||||
self.tools.push(McpTool {
|
||||
name: name.clone(),
|
||||
description: description.clone(),
|
||||
input_schema: input_schema.clone(),
|
||||
});
|
||||
defs.push(ToolDefinition {
|
||||
name,
|
||||
description,
|
||||
parameters: input_schema,
|
||||
strict: None,
|
||||
});
|
||||
}
|
||||
Ok(defs)
|
||||
}
|
||||
|
||||
/// 调用一个工具(调用 `tools/call`)。
|
||||
pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, ToolError> {
|
||||
if !self.is_initialized() {
|
||||
return Err(ToolError::McpNotInitialized(self.server_name.clone()));
|
||||
}
|
||||
|
||||
let params = json!({
|
||||
"name": name,
|
||||
"arguments": args,
|
||||
});
|
||||
let response = self.send_request("tools/call", Some(params)).await?;
|
||||
|
||||
// 解析 content 字段
|
||||
if let Some(content) = response.get("content").and_then(|c| c.as_array()) {
|
||||
// 收集所有 text 内容
|
||||
let mut combined = String::new();
|
||||
for item in content {
|
||||
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
|
||||
if !combined.is_empty() {
|
||||
combined.push('\n');
|
||||
}
|
||||
combined.push_str(text);
|
||||
}
|
||||
}
|
||||
if !combined.is_empty() {
|
||||
return Ok(Value::String(combined));
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有 content 字段,尝试直接返回 is_error 标记
|
||||
if let Some(true) = response.get("isError").and_then(|v| v.as_bool()) {
|
||||
return Err(ToolError::ExecutionFailed(
|
||||
name.to_string(),
|
||||
"MCP 工具返回 isError=true".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// 回退:返回完整响应
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// 关闭连接(终止子进程)。
|
||||
pub async fn close(&mut self) -> Result<(), ToolError> {
|
||||
if !self.is_initialized() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// 尝试发送 shutdown(不强制要求响应)
|
||||
let _ = self.send_notification("shutdown", None).await;
|
||||
|
||||
if let Some(state) = self.process.take() {
|
||||
let mut state = state.lock().await;
|
||||
// 优雅等待 5 秒
|
||||
let graceful = tokio::time::timeout(
|
||||
Duration::from_secs(5),
|
||||
state.child.wait(),
|
||||
)
|
||||
.await;
|
||||
if graceful.is_err() {
|
||||
// 超时则强杀
|
||||
let _ = state.child.kill().await;
|
||||
}
|
||||
}
|
||||
|
||||
self.initialized.store(false, Ordering::SeqCst);
|
||||
self.tools.clear();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 将 MCP 客户端转换为 `BaseTool` 适配器列表(用于注册到 `ToolRegistry`)。
|
||||
///
|
||||
/// **注意**:返回的适配器持有 `Arc<McpClient>`,但 `McpClient` 的可变性
|
||||
/// (如 `list_tools` 刷新缓存)会通过 `Mutex` 处理。当前适配器仅缓存
|
||||
/// 转换时的工具列表,不感知后续刷新。
|
||||
pub fn into_tools(self) -> Vec<ToolRef> {
|
||||
let mut tools = Vec::with_capacity(self.tools.len());
|
||||
for mcp_tool in self.tools {
|
||||
let tool = McpToolAdapter {
|
||||
client: McpClientHandle::Empty,
|
||||
name: mcp_tool.name,
|
||||
description: mcp_tool.description.unwrap_or_default(),
|
||||
parameters: mcp_tool.input_schema,
|
||||
};
|
||||
tools.push(Arc::new(tool) as ToolRef);
|
||||
}
|
||||
tools
|
||||
}
|
||||
|
||||
async fn send_request(
|
||||
&self,
|
||||
method: &str,
|
||||
params: Option<Value>,
|
||||
) -> Result<Value, ToolError> {
|
||||
let state_arc = self
|
||||
.process
|
||||
.as_ref()
|
||||
.ok_or_else(|| ToolError::McpNotInitialized(self.server_name.clone()))?
|
||||
.clone();
|
||||
|
||||
let (id, request_json) = {
|
||||
let mut state = state_arc.lock().await;
|
||||
let id = state.next_id();
|
||||
let req = JsonRpcRequest::new(id, method, params);
|
||||
let json = serde_json::to_string(&req)
|
||||
.map_err(|e| ToolError::McpError(format!("序列化请求失败: {e}")))?;
|
||||
(id, json)
|
||||
};
|
||||
|
||||
// 注册 oneshot 等待响应
|
||||
let (tx, rx) = oneshot::channel();
|
||||
{
|
||||
let mut state = state_arc.lock().await;
|
||||
state.pending.insert(id, tx);
|
||||
}
|
||||
|
||||
// 写入请求
|
||||
{
|
||||
let mut state = state_arc.lock().await;
|
||||
state
|
||||
.stdin
|
||||
.write_all(request_json.as_bytes())
|
||||
.await
|
||||
.map_err(|e| ToolError::McpError(format!("写入请求失败: {e}")))?;
|
||||
state
|
||||
.stdin
|
||||
.write_all(b"\n")
|
||||
.await
|
||||
.map_err(|e| ToolError::McpError(format!("写入换行失败: {e}")))?;
|
||||
state.stdin.flush().await.map_err(|e| {
|
||||
ToolError::McpError(format!("flush stdin 失败: {e}"))
|
||||
})?;
|
||||
}
|
||||
|
||||
// 等待响应(带超时)
|
||||
tokio::time::timeout(Duration::from_secs(self.timeout_secs), rx)
|
||||
.await
|
||||
.map_err(|_| {
|
||||
// 超时:清理 pending
|
||||
let state_arc = state_arc.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut state = state_arc.lock().await;
|
||||
state.pending.remove(&id);
|
||||
});
|
||||
ToolError::McpTimeout(method.to_string())
|
||||
})?
|
||||
.map_err(|_| ToolError::McpError("response channel 关闭".into()))?
|
||||
}
|
||||
|
||||
async fn send_notification(
|
||||
&self,
|
||||
method: &str,
|
||||
params: Option<Value>,
|
||||
) -> Result<(), ToolError> {
|
||||
let state_arc = self
|
||||
.process
|
||||
.as_ref()
|
||||
.ok_or_else(|| ToolError::McpNotInitialized(self.server_name.clone()))?
|
||||
.clone();
|
||||
|
||||
let notification = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
"params": params,
|
||||
});
|
||||
let json = serde_json::to_string(¬ification)
|
||||
.map_err(|e| ToolError::McpError(format!("序列化通知失败: {e}")))?;
|
||||
|
||||
let mut state = state_arc.lock().await;
|
||||
state
|
||||
.stdin
|
||||
.write_all(json.as_bytes())
|
||||
.await
|
||||
.map_err(|e| ToolError::McpError(format!("写入通知失败: {e}")))?;
|
||||
state
|
||||
.stdin
|
||||
.write_all(b"\n")
|
||||
.await
|
||||
.map_err(|e| ToolError::McpError(format!("写入换行失败: {e}")))?;
|
||||
state
|
||||
.stdin
|
||||
.flush()
|
||||
.await
|
||||
.map_err(|e| ToolError::McpError(format!("flush stdin 失败: {e}")))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 持续读取 stdout,将响应分发到对应的 oneshot sender。
|
||||
async fn read_loop(
|
||||
mut reader: BufReader<ChildStdout>,
|
||||
state: Arc<Mutex<ChildProcessState>>,
|
||||
) {
|
||||
let mut line = String::new();
|
||||
loop {
|
||||
line.clear();
|
||||
match reader.read_line(&mut line).await {
|
||||
Ok(0) => {
|
||||
// EOF:通知所有 pending 失败
|
||||
let mut state = state.lock().await;
|
||||
for (_, tx) in state.pending.drain() {
|
||||
let _ = tx.send(Err(ToolError::McpError("子进程退出".into())));
|
||||
}
|
||||
break;
|
||||
}
|
||||
Ok(_) => {
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
// 尝试解析为 JSON-RPC 响应
|
||||
let parsed: Result<JsonRpcResponse, _> = serde_json::from_str(trimmed);
|
||||
if let Ok(response) = parsed {
|
||||
let value = if let Some(err) = response.error {
|
||||
Err(ToolError::McpError(format!(
|
||||
"[{}] {}",
|
||||
err.code, err.message
|
||||
)))
|
||||
} else {
|
||||
Ok(response.result.unwrap_or(Value::Null))
|
||||
};
|
||||
let mut state = state.lock().await;
|
||||
if let Some(tx) = state.pending.remove(&response.id) {
|
||||
let _ = tx.send(value);
|
||||
}
|
||||
}
|
||||
// 非响应消息(通知、request from server)忽略
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("MCP read_loop error: {e}");
|
||||
let mut state = state.lock().await;
|
||||
for (_, tx) in state.pending.drain() {
|
||||
let _ = tx.send(Err(ToolError::McpError(format!("读取失败: {e}"))));
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// MCP 工具适配器 —— 将 MCP 工具包装为 `BaseTool`。
|
||||
struct McpToolAdapter {
|
||||
/// 持有 client 的弱引用。实际生产中应使用 `Arc<McpClient>`,
|
||||
/// 但当前 Phase 2 实现不直接持有可变的 `McpClient`。
|
||||
/// 标记为 unused 但保留字段以展示扩展路径。
|
||||
#[allow(dead_code)]
|
||||
client: McpClientHandle,
|
||||
name: String,
|
||||
description: String,
|
||||
parameters: Value,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
enum McpClientHandle {
|
||||
Empty,
|
||||
// Future: Shared(Arc<McpClient>),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BaseTool for McpToolAdapter {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
&self.description
|
||||
}
|
||||
|
||||
fn parameters(&self) -> Value {
|
||||
self.parameters.clone()
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
_args: Value,
|
||||
_ctx: &ToolContext<'_>,
|
||||
) -> Result<Value, ToolError> {
|
||||
// 当前 Phase 2 实现的简化:McpToolAdapter 不持有活跃 MCP 连接。
|
||||
// 实际生产中应持有 Arc<McpClient> 并通过 mcp.call_tool() 执行。
|
||||
// 这里返回错误,提示需要通过其他方式调用 MCP 工具。
|
||||
Err(ToolError::McpError(format!(
|
||||
"MCP 工具 '{}' 需要活跃的 McpClient 引用(当前 Phase 2 简化实现)",
|
||||
self.name
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_transport_debug() {
|
||||
let transport = McpTransport::Stdio {
|
||||
command: "echo".to_string(),
|
||||
args: vec!["hello".to_string()],
|
||||
};
|
||||
let formatted = format!("{transport:?}");
|
||||
assert!(formatted.contains("echo"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_creation() {
|
||||
let transport = McpTransport::Stdio {
|
||||
command: "test".to_string(),
|
||||
args: vec![],
|
||||
};
|
||||
let client = McpClient::new("test-server", transport).with_timeout(60);
|
||||
assert_eq!(client.server_name, "test-server");
|
||||
assert_eq!(client.timeout_secs, 60);
|
||||
assert!(!client.is_initialized());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_jsonrpc_request_serialize() {
|
||||
let req = JsonRpcRequest::new(42, "test", Some(json!({"a": 1})));
|
||||
let s = serde_json::to_string(&req).unwrap();
|
||||
assert!(s.contains("\"jsonrpc\":\"2.0\""));
|
||||
assert!(s.contains("\"id\":42"));
|
||||
assert!(s.contains("\"method\":\"test\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_jsonrpc_response_parse_ok() {
|
||||
let s = r#"{"jsonrpc":"2.0","id":1,"result":{"foo":"bar"}}"#;
|
||||
let resp: JsonRpcResponse = serde_json::from_str(s).unwrap();
|
||||
assert_eq!(resp.id, 1);
|
||||
assert!(resp.result.is_some());
|
||||
assert!(resp.error.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_jsonrpc_response_parse_error() {
|
||||
let s =
|
||||
r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32601,"message":"Method not found"}}"#;
|
||||
let resp: JsonRpcResponse = serde_json::from_str(s).unwrap();
|
||||
assert_eq!(resp.id, 1);
|
||||
assert!(resp.result.is_none());
|
||||
let err = resp.error.unwrap();
|
||||
assert_eq!(err.code, -32601);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streamable_http_not_implemented() {
|
||||
let mut client = McpClient::new(
|
||||
"http-server",
|
||||
McpTransport::StreamableHttp {
|
||||
url: "https://example.com/mcp".to_string(),
|
||||
headers: None,
|
||||
},
|
||||
);
|
||||
let result = client.connect().await;
|
||||
// 当前 Phase 2 返回未实现错误
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result, Err(ToolError::McpError(_))));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,286 @@
|
||||
//! 工具权限管理。
|
||||
|
||||
use crate::tools::error::ToolError;
|
||||
|
||||
/// 权限级别枚举。
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum Permission {
|
||||
/// 只读(读取文件、查询数据库等)。
|
||||
Read,
|
||||
/// 写入(创建/修改文件、插入数据等)。
|
||||
Write,
|
||||
/// 删除(删除文件、记录等)。
|
||||
Delete,
|
||||
/// 网络访问(HTTP 请求等)。
|
||||
Network,
|
||||
/// Shell 命令执行。
|
||||
Shell,
|
||||
/// 文件系统操作(除读/写/删之外的 FS 操作)。
|
||||
FileSystem,
|
||||
/// 自定义权限(可通过 namespaced 字符串扩展)。
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Permission {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Read => write!(f, "Read"),
|
||||
Self::Write => write!(f, "Write"),
|
||||
Self::Delete => write!(f, "Delete"),
|
||||
Self::Network => write!(f, "Network"),
|
||||
Self::Shell => write!(f, "Shell"),
|
||||
Self::FileSystem => write!(f, "FileSystem"),
|
||||
Self::Custom(s) => write!(f, "Custom({})", s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 权限配置。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PermissionConfig {
|
||||
/// 允许的权限列表(空 = 全部允许,配合 `allow_unspecified` 决定)。
|
||||
pub allowed: Vec<Permission>,
|
||||
/// 拒绝的权限列表(优先级高于 `allowed`)。
|
||||
pub denied: Vec<Permission>,
|
||||
/// 当工具未声明权限时是否允许执行。
|
||||
pub allow_unspecified: bool,
|
||||
}
|
||||
|
||||
impl Default for PermissionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
allowed: vec![Permission::Read, Permission::Network],
|
||||
denied: vec![Permission::Delete, Permission::Shell],
|
||||
allow_unspecified: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 权限检查器。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PermissionChecker {
|
||||
config: PermissionConfig,
|
||||
}
|
||||
|
||||
impl PermissionChecker {
|
||||
/// 创建一个新的权限检查器。
|
||||
pub fn new(config: PermissionConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
/// 检查指定工具声明的权限是否允许执行。
|
||||
///
|
||||
/// 判定规则:
|
||||
/// 1. 任一权限在 `denied` 中 → 拒绝
|
||||
/// 2. 所有权限都在 `allowed` 中 → 允许
|
||||
/// 3. `allowed` 非空且存在未声明权限 → 拒绝
|
||||
/// 4. `allowed` 为空 → 按 `allow_unspecified` 判定
|
||||
/// 5. 工具未声明任何权限时按 `allow_unspecified` 判定
|
||||
pub fn check(&self, tool_name: &str, permissions: &[Permission]) -> Result<(), ToolError> {
|
||||
// 任一权限在 denied 中 → 拒绝
|
||||
for perm in permissions {
|
||||
if self.config.denied.contains(perm) {
|
||||
return Err(ToolError::PermissionDenied(
|
||||
tool_name.to_string(),
|
||||
perm.to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// 工具未声明任何权限
|
||||
if permissions.is_empty() {
|
||||
return if self.config.allow_unspecified {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ToolError::PermissionDenied(
|
||||
tool_name.to_string(),
|
||||
"Unspecified".to_string(),
|
||||
))
|
||||
};
|
||||
}
|
||||
|
||||
// allowed 为空 → 走 allow_unspecified 兜底
|
||||
if self.config.allowed.is_empty() {
|
||||
return if self.config.allow_unspecified {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ToolError::PermissionDenied(
|
||||
tool_name.to_string(),
|
||||
"Unspecified".to_string(),
|
||||
))
|
||||
};
|
||||
}
|
||||
|
||||
// allowed 非空(白名单模式)—— 所有权限必须在其中
|
||||
for perm in permissions {
|
||||
if !self.config.allowed.contains(perm) {
|
||||
return Err(ToolError::PermissionDenied(
|
||||
tool_name.to_string(),
|
||||
perm.to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn p(perm: Permission) -> Vec<Permission> {
|
||||
vec![perm]
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_config_allows_read() {
|
||||
let checker = PermissionChecker::new(PermissionConfig::default());
|
||||
assert!(checker.check("weather", &p(Permission::Read)).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_config_allows_network() {
|
||||
let checker = PermissionChecker::new(PermissionConfig::default());
|
||||
assert!(checker.check("http_get", &p(Permission::Network)).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_config_denies_delete() {
|
||||
let checker = PermissionChecker::new(PermissionConfig::default());
|
||||
assert!(checker
|
||||
.check("rm_file", &p(Permission::Delete))
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_config_denies_shell() {
|
||||
let checker = PermissionChecker::new(PermissionConfig::default());
|
||||
assert!(checker.check("run_shell", &p(Permission::Shell)).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_white_list_mode_denies_unlisted() {
|
||||
let cfg = PermissionConfig {
|
||||
allowed: vec![Permission::Read],
|
||||
denied: vec![],
|
||||
allow_unspecified: false,
|
||||
};
|
||||
let checker = PermissionChecker::new(cfg);
|
||||
assert!(checker.check("t", &p(Permission::Read)).is_ok());
|
||||
assert!(checker.check("t", &p(Permission::Write)).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_white_list_mode_allows_listed() {
|
||||
let cfg = PermissionConfig {
|
||||
allowed: vec![Permission::Read, Permission::Write],
|
||||
denied: vec![],
|
||||
allow_unspecified: false,
|
||||
};
|
||||
let checker = PermissionChecker::new(cfg);
|
||||
assert!(checker.check("t", &p(Permission::Write)).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_black_list_deny_priority() {
|
||||
// 即便 allowed 中包含了 denied 权限,仍以 denied 为准
|
||||
let cfg = PermissionConfig {
|
||||
allowed: vec![Permission::Shell, Permission::Read],
|
||||
denied: vec![Permission::Shell],
|
||||
allow_unspecified: false,
|
||||
};
|
||||
let checker = PermissionChecker::new(cfg);
|
||||
assert!(checker.check("t", &p(Permission::Shell)).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_allowed_with_allow_unspecified() {
|
||||
let cfg = PermissionConfig {
|
||||
allowed: vec![],
|
||||
denied: vec![],
|
||||
allow_unspecified: true,
|
||||
};
|
||||
let checker = PermissionChecker::new(cfg);
|
||||
assert!(checker.check("t", &p(Permission::Write)).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_allowed_without_allow_unspecified() {
|
||||
let cfg = PermissionConfig {
|
||||
allowed: vec![],
|
||||
denied: vec![],
|
||||
allow_unspecified: false,
|
||||
};
|
||||
let checker = PermissionChecker::new(cfg);
|
||||
assert!(checker.check("t", &p(Permission::Write)).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unspecified_tool_with_allow() {
|
||||
let cfg = PermissionConfig {
|
||||
allowed: vec![],
|
||||
denied: vec![],
|
||||
allow_unspecified: true,
|
||||
};
|
||||
let checker = PermissionChecker::new(cfg);
|
||||
assert!(checker.check("t", &[]).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unspecified_tool_without_allow() {
|
||||
let cfg = PermissionConfig {
|
||||
allowed: vec![],
|
||||
denied: vec![],
|
||||
allow_unspecified: false,
|
||||
};
|
||||
let checker = PermissionChecker::new(cfg);
|
||||
assert!(checker.check("t", &[]).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_custom_permission_collision() {
|
||||
let cfg = PermissionConfig {
|
||||
allowed: vec![Permission::Custom("db:read".into())],
|
||||
denied: vec![Permission::Custom("db:write".into())],
|
||||
allow_unspecified: false,
|
||||
};
|
||||
let checker = PermissionChecker::new(cfg);
|
||||
assert!(checker
|
||||
.check("t", &[Permission::Custom("db:read".into())])
|
||||
.is_ok());
|
||||
assert!(checker
|
||||
.check("t", &[Permission::Custom("db:write".into())])
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multi_permission_all_in_allowed() {
|
||||
let cfg = PermissionConfig {
|
||||
allowed: vec![Permission::Read, Permission::Write, Permission::Network],
|
||||
denied: vec![],
|
||||
allow_unspecified: false,
|
||||
};
|
||||
let checker = PermissionChecker::new(cfg);
|
||||
assert!(checker
|
||||
.check(
|
||||
"t",
|
||||
&[Permission::Read, Permission::Network]
|
||||
)
|
||||
.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multi_permission_one_not_in_allowed() {
|
||||
let cfg = PermissionConfig {
|
||||
allowed: vec![Permission::Read, Permission::Network],
|
||||
denied: vec![],
|
||||
allow_unspecified: false,
|
||||
};
|
||||
let checker = PermissionChecker::new(cfg);
|
||||
// 任一权限不在白名单则拒绝
|
||||
assert!(checker
|
||||
.check("t", &[Permission::Read, Permission::Write])
|
||||
.is_err());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,371 @@
|
||||
//! 工具注册表 —— 管理工具注册、发现、调用。
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::future::join_all;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::llm::types::ToolDefinition;
|
||||
use crate::tools::base::{ToolContext, ToolRef};
|
||||
use crate::tools::error::ToolError;
|
||||
use crate::tools::permission::PermissionChecker;
|
||||
|
||||
/// 工具调用记录 —— 用于追踪和调试。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolInvocation {
|
||||
/// 被调用的工具名。
|
||||
pub tool_name: String,
|
||||
/// 工具的入参。
|
||||
pub input: Value,
|
||||
/// 工具的输出。
|
||||
pub output: Result<Value, ToolError>,
|
||||
}
|
||||
|
||||
impl ToolInvocation {
|
||||
/// 创建一个新的工具调用记录。
|
||||
pub fn new(tool_name: String, input: Value, output: Result<Value, ToolError>) -> Self {
|
||||
Self {
|
||||
tool_name,
|
||||
input,
|
||||
output,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 工具注册表 —— 管理工具注册、发现、调用。
|
||||
///
|
||||
/// 通过 `Arc` 共享,方法签名 `&self`,可安全跨 task 并行调用。
|
||||
/// 不支持运行时并发注册(应在 setup 阶段一次性构建后冻结)。
|
||||
#[derive(Clone, Default)]
|
||||
pub struct ToolRegistry {
|
||||
inner: Arc<ToolRegistryInner>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct ToolRegistryInner {
|
||||
tools: HashMap<String, ToolRef>,
|
||||
permission_checker: Option<Arc<PermissionChecker>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ToolRegistry {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ToolRegistry")
|
||||
.field("tool_names", &self.inner.tools.keys().collect::<Vec<_>>())
|
||||
.field("has_checker", &self.inner.permission_checker.is_some())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
/// 创建一个新的工具注册表。
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
inner: Arc::new(ToolRegistryInner {
|
||||
tools: HashMap::new(),
|
||||
permission_checker: None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// 设置权限检查器(Builder 模式)。
|
||||
pub fn with_permission_checker(mut self, checker: PermissionChecker) -> Self {
|
||||
let inner = Arc::make_mut(&mut self.inner);
|
||||
inner.permission_checker = Some(Arc::new(checker));
|
||||
self
|
||||
}
|
||||
|
||||
/// 注册一个工具。
|
||||
///
|
||||
/// 重复注册同名工具返回错误。
|
||||
pub fn register(&mut self, tool: ToolRef) -> Result<(), ToolError> {
|
||||
let name = tool.name().to_string();
|
||||
let inner = Arc::make_mut(&mut self.inner);
|
||||
if inner.tools.contains_key(&name) {
|
||||
return Err(ToolError::ExecutionFailed(name, "工具已存在".to_string()));
|
||||
}
|
||||
inner.tools.insert(name, tool);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 批量注册工具。
|
||||
pub fn register_all(&mut self, tools: Vec<ToolRef>) -> Result<(), ToolError> {
|
||||
for tool in tools {
|
||||
self.register(tool)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 注销一个工具。
|
||||
pub fn unregister(&mut self, name: &str) -> Option<ToolRef> {
|
||||
let inner = Arc::make_mut(&mut self.inner);
|
||||
inner.tools.remove(name)
|
||||
}
|
||||
|
||||
/// 按名称查找工具。
|
||||
pub fn get(&self, name: &str) -> Option<ToolRef> {
|
||||
self.inner.tools.get(name).cloned()
|
||||
}
|
||||
|
||||
/// 获取所有已注册工具的名称列表。
|
||||
pub fn list_tools(&self) -> Vec<String> {
|
||||
self.inner.tools.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// 获取所有工具的 `ToolDefinition` 列表(用于传递给 LLM)。
|
||||
pub fn definitions(&self) -> Vec<ToolDefinition> {
|
||||
self.inner
|
||||
.tools
|
||||
.values()
|
||||
.map(|tool| ToolDefinition {
|
||||
name: tool.name().to_string(),
|
||||
description: Some(tool.description().to_string()),
|
||||
parameters: tool.parameters(),
|
||||
strict: None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// 调用单个工具(含权限检查)。
|
||||
pub async fn invoke(&self, name: &str, args: Value) -> Result<ToolInvocation, ToolError> {
|
||||
let tool = self
|
||||
.get(name)
|
||||
.ok_or_else(|| ToolError::NotFound(name.to_string()))?;
|
||||
|
||||
if let Some(checker) = &self.inner.permission_checker {
|
||||
checker.check(name, &tool.required_permissions())?;
|
||||
}
|
||||
|
||||
let ctx = ToolContext::new(name, "");
|
||||
let output = tool.execute(args.clone(), &ctx).await;
|
||||
Ok(ToolInvocation::new(name.to_string(), args, output))
|
||||
}
|
||||
|
||||
/// 并行执行多个工具调用(互不依赖的工具)。
|
||||
///
|
||||
/// 每个工具独立超时(`timeout_per_call_secs`,0 表示不超时)。
|
||||
/// 单个工具超时不会影响其他工具的返回。
|
||||
pub async fn invoke_all(
|
||||
&self,
|
||||
calls: Vec<(String, Value)>,
|
||||
timeout_per_call_secs: u64,
|
||||
) -> Vec<ToolInvocation> {
|
||||
let this = self.clone();
|
||||
let futures = calls.into_iter().map(|(name, args)| {
|
||||
let this = this.clone();
|
||||
async move {
|
||||
match if timeout_per_call_secs == 0 {
|
||||
Ok(this.invoke(&name, args.clone()).await)
|
||||
} else {
|
||||
tokio::time::timeout(
|
||||
Duration::from_secs(timeout_per_call_secs),
|
||||
this.invoke(&name, args.clone()),
|
||||
)
|
||||
.await
|
||||
} {
|
||||
Ok(result) => result.unwrap_or_else(|e| {
|
||||
ToolInvocation::new(name.clone(), args.clone(), Err(e))
|
||||
}),
|
||||
Err(_) => ToolInvocation::new(
|
||||
name,
|
||||
args,
|
||||
Err(ToolError::McpTimeout("timeout".into())),
|
||||
),
|
||||
}
|
||||
}
|
||||
});
|
||||
join_all(futures).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tools::BaseTool;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
|
||||
struct AddTool {
|
||||
base: i64,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BaseTool for AddTool {
|
||||
fn name(&self) -> &str {
|
||||
"add"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"加法"
|
||||
}
|
||||
|
||||
fn parameters(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": { "n": { "type": "integer" } },
|
||||
"required": ["n"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: Value, _ctx: &ToolContext<'_>) -> Result<Value, ToolError> {
|
||||
let n = args["n"].as_i64().unwrap_or(0);
|
||||
Ok(json!({ "result": self.base + n }))
|
||||
}
|
||||
}
|
||||
|
||||
struct FailTool;
|
||||
|
||||
#[async_trait]
|
||||
impl BaseTool for FailTool {
|
||||
fn name(&self) -> &str {
|
||||
"fail"
|
||||
}
|
||||
fn description(&self) -> &str {
|
||||
"总会失败"
|
||||
}
|
||||
fn parameters(&self) -> Value {
|
||||
json!({})
|
||||
}
|
||||
async fn execute(&self, _args: Value, _ctx: &ToolContext<'_>) -> Result<Value, ToolError> {
|
||||
Err(ToolError::ExecutionFailed("fail".into(), "boom".into()))
|
||||
}
|
||||
}
|
||||
|
||||
struct ShellTool;
|
||||
|
||||
#[async_trait]
|
||||
impl BaseTool for ShellTool {
|
||||
fn name(&self) -> &str {
|
||||
"shell"
|
||||
}
|
||||
fn description(&self) -> &str {
|
||||
"shell"
|
||||
}
|
||||
fn parameters(&self) -> Value {
|
||||
json!({})
|
||||
}
|
||||
fn required_permissions(&self) -> Vec<crate::tools::permission::Permission> {
|
||||
vec![crate::tools::permission::Permission::Shell]
|
||||
}
|
||||
async fn execute(&self, _args: Value, _ctx: &ToolContext<'_>) -> Result<Value, ToolError> {
|
||||
Ok(json!({}))
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_register_and_get() {
|
||||
let mut reg = ToolRegistry::new();
|
||||
reg.register(Arc::new(AddTool { base: 10 })).unwrap();
|
||||
assert!(reg.get("add").is_some());
|
||||
assert!(reg.get("nonexistent").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_register_duplicate() {
|
||||
let mut reg = ToolRegistry::new();
|
||||
reg.register(Arc::new(AddTool { base: 0 })).unwrap();
|
||||
let result = reg.register(Arc::new(AddTool { base: 1 }));
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_register_all() {
|
||||
let mut reg = ToolRegistry::new();
|
||||
let result = reg.register_all(vec![
|
||||
Arc::new(AddTool { base: 1 }),
|
||||
Arc::new(AddTool { base: 2 }),
|
||||
]);
|
||||
assert!(result.is_err()); // 重名 add → 失败
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unregister() {
|
||||
let mut reg = ToolRegistry::new();
|
||||
reg.register(Arc::new(AddTool { base: 0 })).unwrap();
|
||||
let removed = reg.unregister("add");
|
||||
assert!(removed.is_some());
|
||||
assert!(reg.get("add").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_list_tools() {
|
||||
let mut reg = ToolRegistry::new();
|
||||
reg.register(Arc::new(AddTool { base: 0 })).unwrap();
|
||||
reg.register(Arc::new(FailTool)).unwrap();
|
||||
let names = reg.list_tools();
|
||||
assert_eq!(names.len(), 2);
|
||||
assert!(names.contains(&"add".to_string()));
|
||||
assert!(names.contains(&"fail".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_definitions() {
|
||||
let mut reg = ToolRegistry::new();
|
||||
reg.register(Arc::new(AddTool { base: 0 })).unwrap();
|
||||
let defs = reg.definitions();
|
||||
assert_eq!(defs.len(), 1);
|
||||
assert_eq!(defs[0].name, "add");
|
||||
assert!(defs[0].description.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_invoke_success() {
|
||||
let mut reg = ToolRegistry::new();
|
||||
reg.register(Arc::new(AddTool { base: 100 })).unwrap();
|
||||
let result = reg.invoke("add", json!({ "n": 5 })).await.unwrap();
|
||||
let value = result.output.unwrap();
|
||||
assert_eq!(value["result"], 105);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_invoke_not_found() {
|
||||
let reg = ToolRegistry::new();
|
||||
let result = reg.invoke("nope", json!({})).await;
|
||||
assert!(matches!(result, Err(ToolError::NotFound(_))));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_invoke_execution_error() {
|
||||
let mut reg = ToolRegistry::new();
|
||||
reg.register(Arc::new(FailTool)).unwrap();
|
||||
let result = reg.invoke("fail", json!({})).await.unwrap();
|
||||
assert!(result.output.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_invoke_with_permission_denied() {
|
||||
let mut reg = ToolRegistry::new()
|
||||
.with_permission_checker(PermissionChecker::new(Default::default()));
|
||||
reg.register(Arc::new(ShellTool)).unwrap();
|
||||
let result = reg.invoke("shell", json!({})).await;
|
||||
assert!(matches!(result, Err(ToolError::PermissionDenied(_, _))));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_invoke_all_parallel() {
|
||||
let mut reg = ToolRegistry::new();
|
||||
reg.register(Arc::new(AddTool { base: 1 })).unwrap();
|
||||
reg.register(Arc::new(FailTool)).unwrap();
|
||||
let calls = vec![
|
||||
("add".into(), json!({ "n": 1 })),
|
||||
("add".into(), json!({ "n": 2 })),
|
||||
("fail".into(), json!({})),
|
||||
];
|
||||
let results = reg.invoke_all(calls, 0).await;
|
||||
assert_eq!(results.len(), 3);
|
||||
assert!(results[0].output.is_ok());
|
||||
assert!(results[1].output.is_ok());
|
||||
assert!(results[2].output.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_invoke_all_with_timeout() {
|
||||
let mut reg = ToolRegistry::new();
|
||||
reg.register(Arc::new(AddTool { base: 0 })).unwrap();
|
||||
let calls = vec![("add".into(), json!({ "n": 1 }))];
|
||||
let results = reg.invoke_all(calls, 5).await;
|
||||
assert_eq!(results.len(), 1);
|
||||
assert!(results[0].output.is_ok());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user